From 62bb0ddf4c6fe09fce4d880354b198166e208729 Mon Sep 17 00:00:00 2001 From: zhangxiao <44601329+zhangxiao-stack@users.noreply.github.com> Date: Thu, 5 Jan 2023 18:03:05 +0800 Subject: [PATCH 1/4] fix hip support --- tensorflow/compiler/xla/service/gpu/BUILD | 6 ++-- .../xla/service/gpu/amdgpu_compiler.cc | 5 +-- .../gpu/llvm_gpu_backend/gpu_backend_lib.cc | 20 ++++++++--- .../compiler/xla/stream_executor/rocm/BUILD | 3 -- .../xla/stream_executor/rocm/rocm_fft.h | 2 +- .../stream_executor/rocm/rocm_gpu_executor.cc | 33 ++++++++++--------- tensorflow/core/kernels/gpu_prim.h | 2 +- tensorflow/stream_executor/rocm/BUILD | 7 +++- .../devel.usertools/nvidia.bazelrc | 2 +- .../tsl/platform/default/build_config/BUILD | 9 +++-- tensorflow/tsl/platform/default/dso_loader.cc | 6 ++-- third_party/gpus/rocm_configure.bzl | 6 ++-- 12 files changed, 62 insertions(+), 39 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 429bbdd2952..b11e12de385 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -776,8 +776,10 @@ cc_library( "//tensorflow/compiler/xla/stream_executor", "//tensorflow/compiler/xla/stream_executor:device_memory", "//tensorflow/compiler/xla/stream_executor:stream_executor_headers", - "//tensorflow/compiler/xla/stream_executor/gpu:gpu_asm_opts", - ]) + ["//tensorflow/tsl/platform:status"], + "//tensorflow/compiler/xla/stream_executor/gpu:gpu_asm_opts" + ]) + ["//tensorflow/tsl/platform:status"] + if_dcu([ + "//tensorflow/stream_executor/rocm:hipsolver_wrapper", + ]), ) # TODO(ezhulenev): Extract `RunTriangularSolve` into a separate library. diff --git a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc index dbbfe18819b..bdc2418f64f 100644 --- a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc @@ -69,8 +69,9 @@ std::string GetROCDLDir(const HloModuleConfig& config) { } checked = true; } - -#if (TF_ROCM_VERSION >= 30900 || TENSORFLOW_USE_DCU) +#if TENSORFLOW_USE_DCU + std::string libdevice_dir = tsl::io::JoinPath(rocm_path, "amdgcn/bitcode"); +#elif (TF_ROCM_VERSION >= 30900) && (!TENSORFLOW_USE_DCU) std::string libdevice_dir = tensorflow::io::JoinPath(rocm_path, "amdgcn/bitcode"); #else std::string libdevice_dir = tensorflow::io::JoinPath(rocm_path, "lib"); diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index 048640fb046..e4bf7b70c57 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -752,7 +752,11 @@ StatusOr<std::vector<uint8_t>> EmitModuleToHsaco( // Locate lld. // TODO(whchung@gmail.com): change to tensorflow::ROCmRoot() after // ROCm-Device-Libs PR.
+#if TENSORFLOW_USE_DCU + std::string lld_path = tsl::io::JoinPath("/opt/dtk", "llvm/bin"); +#else std::string lld_path = tsl::io::JoinPath("/opt/rocm", "llvm/bin"); +#endif auto lld_program = llvm::sys::findProgramByName("ld.lld", {lld_path}); if (!lld_program) { return xla::InternalError("unable to find ld.lld in PATH: %s", @@ -811,8 +815,11 @@ Status AMDGPUTargetModuleLinker(llvm::Module* module, GpuVersion gpu_version, if (!compute_capability) { return xla::InternalError("Incompatible compute capability was specified."); } - +#if TENSORFLOW_USE_DCU + std::string gcn_arch_name = "gfx906"; +#else std::string gcn_arch_name = compute_capability->gcn_arch_name(); +#endif TF_RETURN_IF_ERROR( LinkROCDLIfNecessary(module, gcn_arch_name, device_bitcode_dir_path)); @@ -887,8 +894,11 @@ std::unique_ptr<llvm::TargetMachine> AMDGPUGetTargetMachine( const HloModuleConfig& hlo_module_config) { auto compute_capability = std::get_if<se::RocmComputeCapability>(&gpu_version); - +#if TENSORFLOW_USE_DCU + std::string gcn_arch_name = "gfx906"; +#else std::string gcn_arch_name = compute_capability->gcn_arch_name(); +#endif auto arch = GetFeatureStrFromGCNArchName(gcn_arch_name); return GetTargetMachine(std::move(target_triple), arch.first, hlo_module_config, arch.second); @@ -958,9 +968,11 @@ StatusOr<std::vector<uint8_t>> CompileToHsaco( return xla::InternalError( "Incompatible compute capability was specified."); } - +#if TENSORFLOW_USE_DCU + std::string gcn_arch_name = "gfx906"; +#else std::string gcn_arch_name = compute_capability->gcn_arch_name(); - +#endif uint64_t hash; if (HsacoCache::Find(str, hash, gcn_arch_name, hsaco)) { VLOG(1) << "HSACO cache hit"; diff --git a/tensorflow/compiler/xla/stream_executor/rocm/BUILD b/tensorflow/compiler/xla/stream_executor/rocm/BUILD index a187b45f1ea..46da0252e8f 100644 --- a/tensorflow/compiler/xla/stream_executor/rocm/BUILD +++ b/tensorflow/compiler/xla/stream_executor/rocm/BUILD @@ -459,9 +459,6 @@ cc_library( ":rocm_driver", ":rocm_platform", ":rocm_helpers", - ]) + if_dcu([ - ":rocfft_plugin", - ], [ ":hipfft_plugin", ]), alwayslink = 1, diff --git a/tensorflow/compiler/xla/stream_executor/rocm/rocm_fft.h b/tensorflow/compiler/xla/stream_executor/rocm/rocm_fft.h index beb54a328bd..892393dcc37 100644 --- a/tensorflow/compiler/xla/stream_executor/rocm/rocm_fft.h +++ b/tensorflow/compiler/xla/stream_executor/rocm/rocm_fft.h @@ -23,7 +23,7 @@ limitations under the License.
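The three gpu_backend_lib.cc hunks above repeat the same #if TENSORFLOW_USE_DCU / "gfx906" fallback. A minimal sketch of hoisting that choice into one place follows; the helper name is hypothetical and not part of this patch, and it assumes the DCU branch intentionally pins the ISA to gfx906 exactly as the hunks do:

// Hypothetical helper for gpu_backend_lib.cc (not in the patch): one place to
// apply the DCU override instead of repeating the #if in three functions.
static std::string GetGcnArchNameOrDcuDefault(
    const se::RocmComputeCapability* compute_capability) {
#if TENSORFLOW_USE_DCU
  // DCU builds pin the ISA; the device-reported arch is intentionally ignored.
  return "gfx906";
#else
  return compute_capability->gcn_arch_name();
#endif
}

AMDGPUTargetModuleLinker, AMDGPUGetTargetMachine, and CompileToHsaco could then each call this helper instead of carrying their own copy of the guard.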
#if TENSORFLOW_USE_ROCM #include "rocm/rocm_config.h" -#if TF_ROCM_VERSION < 40100 || TENSORFLOW_USE_DCU +#if TF_ROCM_VERSION < 40100 #include "rocm/include/rocfft/hipfft.h" #else #include "rocm/include/hipfft/hipfft.h" diff --git a/tensorflow/compiler/xla/stream_executor/rocm/rocm_gpu_executor.cc b/tensorflow/compiler/xla/stream_executor/rocm/rocm_gpu_executor.cc index 2d740d7a6fc..f8d0cd839ca 100644 --- a/tensorflow/compiler/xla/stream_executor/rocm/rocm_gpu_executor.cc +++ b/tensorflow/compiler/xla/stream_executor/rocm/rocm_gpu_executor.cc @@ -326,25 +326,26 @@ port::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, // prepare kernargs // KernelArgsArrayBase keeps the pointer of arguments // deference them here - std::vector<void*> kernargs; - KernelArgIterator iter = args.arg_iterator(); - while (iter.has_next()) { - KernelArg arg = iter.next(); - VLOG(2) << "*(arg.address): " - << reinterpret_cast<void*>( - *static_cast<const uint64_t*>(arg.address)); - kernargs.push_back( - reinterpret_cast<void*>(*static_cast<const uint64_t*>(arg.address))); - } - - size_t size = sizeof(void*) * kernargs.size(); - void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, kernargs.data(), - HIP_LAUNCH_PARAM_BUFFER_SIZE, &size, HIP_LAUNCH_PARAM_END}; - +// std::vector<void*> kernargs; +// KernelArgIterator iter = args.arg_iterator(); +// while (iter.has_next()) { +// KernelArg arg = iter.next(); +// VLOG(2) << "*(arg.address): " +// << reinterpret_cast<void*>( +// *static_cast<const uint64_t*>(arg.address)); +// kernargs.push_back( +// reinterpret_cast<void*>(*static_cast<const uint64_t*>(arg.address))); +// } + +// size_t size = sizeof(void*) * kernargs.size(); +// void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, kernargs.data(), +// HIP_LAUNCH_PARAM_BUFFER_SIZE, &size, HIP_LAUNCH_PARAM_END}; + void** kernel_params = const_cast<void**>(args.argument_addresses().data()); return GpuDriver::LaunchKernel( GetGpuContext(stream), kernel.name(), hipfunc, block_dims.x, block_dims.y, block_dims.z, thread_dims.x, thread_dims.y, thread_dims.z, - args.number_of_shared_bytes(), hipstream, nullptr, (void**)&config); + args.number_of_shared_bytes(), hipstream, kernel_params, nullptr); +// args.number_of_shared_bytes(), hipstream, nullptr, (void**)&config); } int GpuExecutor::CalculateOccupancy(const DeviceDescription& device_description, diff --git a/tensorflow/core/kernels/gpu_prim.h b/tensorflow/core/kernels/gpu_prim.h index 2e6df7d5b6b..9c334899d78 100644 --- a/tensorflow/core/kernels/gpu_prim.h +++ b/tensorflow/core/kernels/gpu_prim.h @@ -87,7 +87,7 @@ namespace gpuprim = ::hipcub; namespace rocprim { namespace detail { -#if (TF_ROCM_VERSION >= 50200) +#if (TF_ROCM_VERSION >= 50200) && (!TENSORFLOW_USE_DCU) template <> struct float_bit_mask<Eigen::half> { static constexpr uint16_t sign_bit = 0x8000; diff --git a/tensorflow/stream_executor/rocm/BUILD b/tensorflow/stream_executor/rocm/BUILD index 82409bdea10..6d4757f9211 100644 --- a/tensorflow/stream_executor/rocm/BUILD +++ b/tensorflow/stream_executor/rocm/BUILD @@ -6,7 +6,12 @@ load( "//tensorflow/compiler/xla/stream_executor:build_defs.bzl", "stream_executor_friends", ) -load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") +load( + "@local_config_rocm//rocm:build_defs.bzl", + "if_rocm_is_configured", + "if_dcu", +) +load("//tensorflow/tsl/platform:build_config_root.bzl", "if_static") package( default_visibility = [":friends"], diff --git a/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/nvidia.bazelrc b/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/nvidia.bazelrc index 63d35330f1e..18ae0c3a90e 120000 ---
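For context on the rocm_gpu_executor.cc change above: the underlying HIP driver call (hipModuleLaunchKernel, which GpuDriver::LaunchKernel wraps on ROCm) accepts kernel arguments either as a kernelParams array holding the address of each argument, or as a packed buffer passed through the extra parameter with the HIP_LAUNCH_PARAM_* markers; the patch switches from the extra/config form to passing the argument addresses directly. A minimal sketch of the two conventions for a single int argument, assuming an already-loaded hipFunction_t (illustrative only, not code from the patch):

#include <hip/hip_runtime.h>

// Launches `f` twice: once via kernelParams, once via the `extra` buffer.
hipError_t LaunchBothWays(hipFunction_t f, hipStream_t stream, int value) {
  // Convention 1: kernelParams is an array of pointers, one per argument.
  void* kernel_params[] = {&value};
  hipError_t err =
      hipModuleLaunchKernel(f, /*gridDim=*/1, 1, 1, /*blockDim=*/1, 1, 1,
                            /*sharedMemBytes=*/0, stream, kernel_params,
                            /*extra=*/nullptr);
  if (err != hipSuccess) return err;

  // Convention 2: `extra` describes one packed buffer holding the argument
  // values themselves (not their addresses).
  size_t size = sizeof(value);
  void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, &value,
                    HIP_LAUNCH_PARAM_BUFFER_SIZE, &size, HIP_LAUNCH_PARAM_END};
  return hipModuleLaunchKernel(f, 1, 1, 1, 1, 1, 1, /*sharedMemBytes=*/0,
                               stream, /*kernelParams=*/nullptr, config);
}

Since StreamExecutor's args.argument_addresses() is already an array of per-argument addresses, the kernelParams form used by the new code is the more direct fit.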
a/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/nvidia.bazelrc +++ b/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/nvidia.bazelrc @@ -1 +1 @@ -gpu.bazelrc \ No newline at end of file +gpu.bazelrc diff --git a/tensorflow/tsl/platform/default/build_config/BUILD b/tensorflow/tsl/platform/default/build_config/BUILD index de18198b017..b98a758f277 100644 --- a/tensorflow/tsl/platform/default/build_config/BUILD +++ b/tensorflow/tsl/platform/default/build_config/BUILD @@ -3,6 +3,10 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") load("//tensorflow/tsl:tsl.bzl", "tsl_copts") +load( + "@local_config_cuda//cuda:build_defs.bzl", + "if_cuda", +) package(default_visibility = ["//tensorflow/tsl:internal"]) @@ -112,7 +116,6 @@ cc_library( linkstatic = 1, deps = [], ) - cc_library( name = "cuda", data = [ @@ -128,7 +131,7 @@ cc_library( "-Wl,-rpath,../local_config_cuda/cuda/extras/CUPTI/lib64", ], }), - deps = [ + deps = if_cuda([ "@local_config_cuda//cuda:cudart", - ], + ]), ) diff --git a/tensorflow/tsl/platform/default/dso_loader.cc b/tensorflow/tsl/platform/default/dso_loader.cc index 645bd3614fc..961af20e6b3 100644 --- a/tensorflow/tsl/platform/default/dso_loader.cc +++ b/tensorflow/tsl/platform/default/dso_loader.cc @@ -144,7 +144,7 @@ StatusOr<void*> GetRocblasDsoHandle() { return GetDsoHandle("rocblas", ""); } StatusOr<void*> GetMiopenDsoHandle() { return GetDsoHandle("MIOpen", ""); } StatusOr<void*> GetHipfftDsoHandle() { -#if TF_ROCM_VERSION < 40100 || TENSORFLOW_USE_DCU +#if TF_ROCM_VERSION < 40100 return GetDsoHandle("rocfft", ""); #else return GetDsoHandle("hipfft", ""); @@ -157,7 +157,7 @@ StatusOr<void*> GetRocsolverDsoHandle() { return GetDsoHandle("rocsolver", ""); } -#if TF_ROCM_VERSION >= 40500 +#if TF_ROCM_VERSION >= 40500 || TENSORFLOW_USE_DCU StatusOr<void*> GetHipsolverDsoHandle() { return GetDsoHandle("hipsolver", ""); } @@ -256,7 +256,7 @@ StatusOr<void*> GetRocsolverDsoHandle() { return *result; } -#if TF_ROCM_VERSION >= 40500 +#if TF_ROCM_VERSION >= 40500 || TENSORFLOW_USE_DCU StatusOr<void*> GetHipsolverDsoHandle() { static auto result = new auto(DsoLoader::GetHipsolverDsoHandle()); return *result; diff --git a/third_party/gpus/rocm_configure.bzl b/third_party/gpus/rocm_configure.bzl index 1a34b0006ed..f0b9b61caa0 100644 --- a/third_party/gpus/rocm_configure.bzl +++ b/third_party/gpus/rocm_configure.bzl @@ -330,7 +330,7 @@ def _select_rocm_lib_paths(repository_ctx, libs_paths, bash_bin): return libs -def _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, miopen_path, rccl_path, bash_bin): +def _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, miopen_path, rccl_path, bash_bin, enable_dcu): """Returns the ROCm libraries on the system.
Args: @@ -561,6 +561,7 @@ def _create_local_rocm_repository(repository_ctx): bash_bin = get_bash_bin(repository_ctx) rocm_config = _get_rocm_config(repository_ctx, bash_bin, find_rocm_config_script) + enable_dcu = get_host_environ(repository_ctx, "TF_NEED_DCU") == '1' # For ROCm 4.1 and above use hipfft, older ROCm versions use rocfft rocm_version_number = int(rocm_config.rocm_version_number) @@ -620,7 +621,7 @@ def _create_local_rocm_repository(repository_ctx): ), ) - rocm_libs = _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, miopen_path, rccl_path, bash_bin) + rocm_libs = _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, miopen_path, rccl_path, bash_bin, enable_dcu) rocm_lib_srcs = [] rocm_lib_outs = [] for lib in rocm_libs.values(): @@ -753,6 +754,7 @@ def _create_local_rocm_repository(repository_ctx): "%{hip_runtime_library}": "amdhip64", "%{crosstool_verbose}": _crosstool_verbose(repository_ctx), "%{gcc_host_compiler_path}": str(cc), + "%{using_dcu}": str(enable_dcu), }, ) From 5612a8cd8597b89d08b996697719e6ea0851b9e4 Mon Sep 17 00:00:00 2001 From: zhangxiao <44601329+zhangxiao-stack@users.noreply.github.com> Date: Tue, 7 Feb 2023 11:29:20 +0800 Subject: [PATCH 2/4] decompose disc-compiler (#27) (#28) * decompose disc-compiler * update Co-authored-by: Yan Xu --- tensorflow/cc/BUILD | 2 +- tensorflow/compiler/jit/BUILD | 5 +---- .../compiler/mlir/tensorflow/ir/tf_generated_ops.td | 2 +- tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td | 2 +- tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td | 4 ++-- tensorflow/compiler/mlir/xla/BUILD | 1 - tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc | 1 - .../compiler/mlir/xla/transforms/xla_legalize_tf.cc | 2 -- tensorflow/compiler/xla/BUILD | 8 ++++---- tensorflow/compiler/xla/service/BUILD | 2 +- tensorflow/compiler/xla/service/gpu/BUILD | 2 +- .../compiler/xla/service/gpu/llvm_gpu_backend/BUILD | 2 +- tensorflow/compiler/xla/stream_executor/BUILD | 4 +--- tensorflow/compiler/xla/stream_executor/cuda/BUILD | 2 +- tensorflow/compiler/xla/stream_executor/gpu/BUILD | 9 +-------- tensorflow/compiler/xla/stream_executor/lib/BUILD | 2 +- tensorflow/core/BUILD | 8 ++------ tensorflow/core/common_runtime/BUILD | 2 +- tensorflow/stream_executor/gpu/BUILD | 2 +- tensorflow/tsl/platform/BUILD | 1 + 20 files changed, 22 insertions(+), 41 deletions(-) diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index 04796a71711..df7472e08c7 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -759,7 +759,7 @@ tf_gen_op_wrappers_cc( "function_ops", ], pkg = "//tensorflow/core", - visibility = ["//tensorflow:internal"], + visibility = ["//visibility:public"], ) tf_gen_op_wrappers_cc( diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 0114be58534..d6ed1108fd4 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -50,10 +50,7 @@ filegroup( # Target that bundles up the XLA CPU and GPU JIT devices. cc_library( name = "jit", - visibility = [ - ":friends", - "//learning/tfx:__subpackages__", - ], + visibility = ["//visibility:public"], deps = [ ":xla_cpu_device", ":xla_cpu_jit", diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 734dc1e2a18..4597ec07cc4 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -31,7 +31,7 @@ limitations under the License. // // Ops in this file are sorted alphabetically. 
-include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td" +include "external/org_tensorflow/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/IR/OpAsmInterface.td" diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td index 672e7d867d3..9e05ae6c42b 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td @@ -25,7 +25,7 @@ limitations under the License. include "mlir/IR/OpBase.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" -include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td" +include "external/org_tensorflow/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td" //===----------------------------------------------------------------------===// // TensorFlow dialect definitions diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index 14fa5387d43..8f14ca318d1 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -27,8 +27,8 @@ limitations under the License. #ifndef TF_OPS #define TF_OPS -include "tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td" -include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td" +include "external/org_tensorflow/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td" +include "external/org_tensorflow/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index 8fe52c64d1a..dbaff4a8552 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -177,7 +177,6 @@ cc_library( ":tf_xla_passes_inc_gen", ":xla_legalize_tf_passes_inc_gen", ":xla_passes_inc_gen", - "//tensorflow/compiler/mlir/disc:mhlo_disc", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", "//tensorflow/compiler/xla:xla_data_proto_cc", diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index f4ddee6bbbf..f7aaa1bae3a 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -49,7 +49,6 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "stablehlo/dialect/ChloOps.h" // from @stablehlo #include "mlir/Transforms/DialectConversion.h" -#include "tensorflow/compiler/mlir/disc/IR/hlo_disc_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf.cc index 68fdc943641..e3d18b37f17 100644 --- a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf.cc @@ -37,7 +37,6 @@ limitations under the License. 
#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "stablehlo/dialect/ChloOps.h" // from @stablehlo -#include "tensorflow/compiler/mlir/disc/IR/hlo_disc_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" @@ -454,7 +453,6 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion, target.addLegalDialect(); target.addLegalDialect(); target.addLegalDialect(); - target.addLegalDialect(); target.addLegalDialect(); target.addLegalDialect(); target.addLegalOp(); diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 08a9ecdcd06..877ae56dc88 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -9,7 +9,7 @@ load( load("//third_party/compute_library:build_defs.bzl", "if_enable_acl") package( - default_visibility = ["//tensorflow:internal"], + default_visibility = ["//visibility:public"], licenses = ["notice"], ) @@ -579,14 +579,14 @@ cc_library( cc_library( name = "error_spec", hdrs = ["error_spec.h"], - visibility = [":friends"], + visibility = ["//visibility:public"], ) cc_library( name = "literal_comparison", srcs = ["literal_comparison.cc"], hdrs = ["literal_comparison.h"], - visibility = [":friends"], + visibility = ["//visibility:public"], deps = [ ":error_spec", ":literal", @@ -995,7 +995,7 @@ cc_library( ], hdrs = ["debug_options_flags.h"], copts = if_enable_acl(["-DXLA_CPU_USE_ACL=1"]), - visibility = [":friends"], + visibility = ["//visibility:public"], deps = [ ":parse_flags_from_env", diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index faa1deb3741..4b3e7451012 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -23,7 +23,7 @@ load("//tensorflow/compiler/xla/stream_executor:build_defs.bzl", "if_gpu_is_conf load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") package( - default_visibility = [":friends"], + default_visibility = ["//visibility:public"], licenses = ["notice"], ) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index b11e12de385..44753512b71 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -32,7 +32,7 @@ load( load("//tensorflow:tensorflow.default.bzl", "filegroup", "get_compatible_with_cloud", "if_nccl") package( - default_visibility = [":friends"], + default_visibility = ["//visibility:public"], licenses = ["notice"], ) diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD index 892c9b6e85c..d097deab752 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD @@ -9,7 +9,7 @@ load( ) package( - default_visibility = [":friends"], + default_visibility = ["//visibility:public"], licenses = ["notice"], ) diff --git a/tensorflow/compiler/xla/stream_executor/BUILD b/tensorflow/compiler/xla/stream_executor/BUILD index e931bd9afe9..2b5149b44e6 100644 --- a/tensorflow/compiler/xla/stream_executor/BUILD +++ b/tensorflow/compiler/xla/stream_executor/BUILD @@ -13,9 +13,7 @@ load("//tensorflow/compiler/xla/stream_executor:build_defs.bzl", "stream_executo load("//tensorflow/tsl:tsl.default.bzl", "filegroup") package( - 
default_visibility = [ - ":friends", - ], + default_visibility = ["//visibility:public"], licenses = ["notice"], ) diff --git a/tensorflow/compiler/xla/stream_executor/cuda/BUILD b/tensorflow/compiler/xla/stream_executor/cuda/BUILD index 4316e0d6a57..dc64ad43080 100644 --- a/tensorflow/compiler/xla/stream_executor/cuda/BUILD +++ b/tensorflow/compiler/xla/stream_executor/cuda/BUILD @@ -30,7 +30,7 @@ load( ) package( - default_visibility = [":friends"], + default_visibility = ["//visibility:public"], features = ["-layering_check"], licenses = ["notice"], ) diff --git a/tensorflow/compiler/xla/stream_executor/gpu/BUILD b/tensorflow/compiler/xla/stream_executor/gpu/BUILD index a658ba9e7be..b986c0e7f05 100644 --- a/tensorflow/compiler/xla/stream_executor/gpu/BUILD +++ b/tensorflow/compiler/xla/stream_executor/gpu/BUILD @@ -269,14 +269,7 @@ cc_library( srcs = if_gpu_is_configured(["asm_compiler.cc"]), hdrs = if_gpu_is_configured(["asm_compiler.h"]), copts = tsl_copts(), - visibility = [ - "//tensorflow/compiler/mlir/disc:__subpackages__", - "//tensorflow/compiler/mlir/tools/kernel_gen:__subpackages__", - "//tensorflow/compiler/xla/service/gpu:__subpackages__", - "//tensorflow/compiler/xla/stream_executor:__subpackages__", - "//tensorflow/core/kernels:__subpackages__", - "//tensorflow/stream_executor:__subpackages__", - ], + visibility = ["//visibility:public"], deps = if_gpu_is_configured([ ":gpu_asm_opts", ":gpu_driver_header", diff --git a/tensorflow/compiler/xla/stream_executor/lib/BUILD b/tensorflow/compiler/xla/stream_executor/lib/BUILD index 3108b8b84f5..161e432adf8 100644 --- a/tensorflow/compiler/xla/stream_executor/lib/BUILD +++ b/tensorflow/compiler/xla/stream_executor/lib/BUILD @@ -4,7 +4,7 @@ load("//tensorflow/compiler/xla/stream_executor:build_defs.bzl", "stream_executo load("//tensorflow/tsl/platform:build_config_root.bzl", "if_static") package( - default_visibility = [":friends"], + default_visibility = ["//visibility:public"], licenses = ["notice"], ) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 9256c31c17d..2d6359c968b 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -113,11 +113,7 @@ load( ) package( - default_visibility = [ - ":dependency_whitelist", - "//tensorflow:internal", - "//tensorflow_models:__subpackages__", - ], + default_visibility = ["//visibility:public"], features = if_google([ "-layering_check", "-parse_headers", @@ -1838,7 +1834,7 @@ cc_library( alias( name = "test_main", actual = "//tensorflow/tsl/platform:test_main", - visibility = ["//tensorflow:internal"], + visibility = ["//visibility:public"], ) tf_cc_tests( diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD index 77e86d90f9c..c4f3d743c0a 100644 --- a/tensorflow/core/common_runtime/BUILD +++ b/tensorflow/core/common_runtime/BUILD @@ -42,7 +42,7 @@ default_package_visibility = [ ] package( - default_visibility = default_package_visibility, + default_visibility = ["//visibility:public"], features = if_google( [ "-layering_check", diff --git a/tensorflow/stream_executor/gpu/BUILD b/tensorflow/stream_executor/gpu/BUILD index ccf6559e9c4..93503f74f47 100644 --- a/tensorflow/stream_executor/gpu/BUILD +++ b/tensorflow/stream_executor/gpu/BUILD @@ -148,7 +148,7 @@ cc_library( hdrs = if_gpu_is_configured(["asm_compiler.h"]), copts = tf_copts(), visibility = [ - "//tensorflow/compiler/mlir/disc:__subpackages__", + "//mlir/disc:__subpackages__", "//tensorflow/compiler/mlir/tools/kernel_gen:__subpackages__", 
"//tensorflow/compiler/xla/service/gpu:__subpackages__", "//tensorflow/compiler/xla/stream_executor:__subpackages__", diff --git a/tensorflow/tsl/platform/BUILD b/tensorflow/tsl/platform/BUILD index 1d0f57e80b1..0555910512f 100644 --- a/tensorflow/tsl/platform/BUILD +++ b/tensorflow/tsl/platform/BUILD @@ -1279,6 +1279,7 @@ cc_library( "@com_google_absl//absl/strings", ], alwayslink = 1, + visibility = ["//visibility:public"], ) tsl_cc_test( From bb83ab49055bea7b3ae00c6080fb9bb153ffcf36 Mon Sep 17 00:00:00 2001 From: zhangxiao <44601329+zhangxiao-stack@users.noreply.github.com> Date: Tue, 18 Apr 2023 12:01:50 +0800 Subject: [PATCH 3/4] merge master (#31) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [xla][mlir][sparse] override sparse shape behavior for xla runtime path PiperOrigin-RevId: 506126261 * Create a codec class for built-in `TypeSpec`s to register, to make `TypeSpec` classes follow the codec structure used by the rest of `nested_structure_coder.py`. Also remove `nested_structure_coder.py`'s dependency on `dataset_ops.DatasetSpec`, `values.PerReplicaSpec`, `iterator_ops.IteratorSpec`, and `optional_ops.OptionalSpec`. PiperOrigin-RevId: 506126332 * Fixes shape inference of LookupTableImportV2 to handle scalar values. PiperOrigin-RevId: 506126405 * Update Android full build script Use `configure` script instead of obsolete `configure_android_workspace` PiperOrigin-RevId: 506130660 * Refactor keras/metrics to be modular. PiperOrigin-RevId: 506144312 * Internal change to the ARM build. PiperOrigin-RevId: 506145147 * gpu_delegate: Allow undefined symbol PiperOrigin-RevId: 506148959 * opencl_wrapper: Update build rule to use opencl icd loader if necessary PiperOrigin-RevId: 506152314 * TensorSpec experimental_get_compiler_ir improve the captured_input support. Major changes include: * Enable the compiler_ir.from_concrete_function support speicialize_flat_input. * Improve experimental_get_compiler_ir functionality: support captured_input PiperOrigin-RevId: 506158256 * [StableHLO to MHLO] Improve Python bindings for MHLO StableHLO PR: https://github.com/openxla/stablehlo/pull/283. PiperOrigin-RevId: 506161080 * Avoid unnecessary polling in EventMgr. The TF EventMgr lets you enqueue a std::function to be run when an se::Stream finishes all the work that's currently enqueued on it. It does this by creating se::Event's on the streams and periodically polling all of them to see if they're completed. This poll loop is very expensive for some clients. If you have two se::Event's enqueued on the same se::Stream and the first event has not been hit yet, then you can be sure that the second one also hasn't been hit: A Stream's work runs in strict FIFO order. Previously EventMgr would check all of the events on every stream, doing unnecessary work. This CL changes it so it stops after the first event on a stream that hasn't been hit yet. If there are often multiple events pending on a particular stream, this should save significant CPU. While we're here, we also cleaned up EventMgr. Previously it had additional functionality about freeing tensors, but this was ripped out a while ago. Cleaning this up allows us to simplify the code somewhat. PiperOrigin-RevId: 506161538 * [StableHLO to MHLO] Relax dimension checks in TriangularSolveOp StableHLO PR: https://github.com/openxla/stablehlo/pull/893. 
PiperOrigin-RevId: 506162066 * [XLA] Use the async copy elapsed instead of prefetch interval picker to decide whether to disable end-of-program prefetch optimization. The shape override introduced in cl/504951495 caused the heuristic that disables end-of-program prefetch optimization to break since it was using the prefetch interval picker to gauge how long the cross-program prefetch is going to be live. This CL changes the logic to use the cost analysis directly. PiperOrigin-RevId: 506172259 * Minor touch up in release notes for 2.12. PiperOrigin-RevId: 506185475 * [StableHLO to MHLO] Handle bounds in the WhileOp shape function PiperOrigin-RevId: 506186744 * [XLA] Fix HLO parser for attribute allow_spmd_sharding_propagation_to_output. PiperOrigin-RevId: 506195622 * [StableHLO to MHLO] Remove AllShapesMatch from DynamicUpdateSliceOp StableHLO PR: https://github.com/openxla/stablehlo/pull/892. PiperOrigin-RevId: 506199864 * Implement functions for retrieving initializer functions in `tf_saved_model` dialect. Retrieving initializer functions is a common operation done in TensorFlow graph transformation passes. This change provides functions for this in the `tf_saved_model` dialect. This also replaces initializer function retrieval codes with the new functions. PiperOrigin-RevId: 506201497 * Removed `ParallelTensor` from `TensorWithLayout` and used `TensorHandlePtr`. PiperOrigin-RevId: 506209442 * [xla:cpu] Add debug info to XLA CPU pipeline This adds a pass that provides some debug info with which basic line number info can be generated. Adapted from Flang's AddDebugFoundationPass. PiperOrigin-RevId: 506213461 * update fuzztest dependency PiperOrigin-RevId: 506217195 * Remove references to stream_executor/lib PiperOrigin-RevId: 506225078 * [XLA:GPU] Handle device buffers more safely in run_hlo_module This fixes double-free errors or memory leaks for example when the running of the HLO is unsuccessful. The old code-path is also left there, as a lot of our code depends on the ability to run the same HLO multiple times without reallocating the input buffers. PiperOrigin-RevId: 506238363 * compat: Update forward compatibility horizon to 2023-02-01 PiperOrigin-RevId: 506239134 * Update GraphDef version to 1394. PiperOrigin-RevId: 506239156 * Fix a typo in the documentation in preemption_watcher.py PiperOrigin-RevId: 506240202 * Rollback of PR #58763 PiperOrigin-RevId: 506243978 * [GmlSt] Group tiling passes for cpu, gpu and triton. PiperOrigin-RevId: 506244287 * Propagate quantize_params in prepare_pass PiperOrigin-RevId: 506252805 * [GmlSt] Remove bufferization test pass. Use hlo-one-shot-bufferize instead. PiperOrigin-RevId: 506260068 * Fix build breakage for DPB. PiperOrigin-RevId: 506261904 * Integrate LLVM at llvm/llvm-project@00ce96b02e87 Updates LLVM usage to match [00ce96b02e87](https://github.com/llvm/llvm-project/commit/00ce96b02e87) PiperOrigin-RevId: 506269173 * Implement clamping of dynamic{_update,}slice start indices. PiperOrigin-RevId: 506270039 * [xla:gpu] Add verbose logging to cuda graph library Optionally print captured graphs dot files to help with debugging PiperOrigin-RevId: 506270560 * Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/89dc2707c7195dc2b839c7a1a987309d91fc89c7. PiperOrigin-RevId: 506270854 * [GmlSt] Split vectorization.cc into vectorize_copy/vectorize_for_cpu,gpu files. PiperOrigin-RevId: 506273813 * Fix bounds checks. - transfer_{read,write} was checking memory bounds incorrectly. - check all buffer accesses. 
- make invalid accesses interpreter failures instead of asserting. PiperOrigin-RevId: 506286548 * Manage snapshot streams assignments in tf.data service dispatcher. Related changes: - Added `DispatcherService::GetSnapshotStreams`, a new readonly API for seeing the state of snapshot stream assignments from the dispatcher's perspective. - Made `DispatcherConfig.worker_timeout_ms` configurable. PiperOrigin-RevId: 506287683 * Remove multiple defines of XlaLegalizeTFNoFallback This occurred because xla_legalize_tf_passes.h.inc technically depends on all passes listed in the .td file being defined. However, the no-fallback pass is intentionally supposed to be in a separate target. For now, depend on no-fallback, so xla_legalize_tf is correct, but xla_legalize_tf_no_fallback should be fully moved to a separate .td/.h file, so it doesn't surface unsupported methods. PiperOrigin-RevId: 506290313 * Skip invalid candidates, add flag for no canonicalization, bisect for errors. Don't ask me how long it took me to realize that canonicalization goof while debugging canonicalization. PiperOrigin-RevId: 506291648 * Fix hybrid indy lstm by forwarding `recurrent_to-*` parameters to `ComputeRowSums`. PiperOrigin-RevId: 506312178 * Fix a bug in which an invalidated reference to a hash table element is used after a potential rehash. `emplace` can cause a rehash that invalidates references to elements in the hashtable. PiperOrigin-RevId: 506313210 * Add path to snapshot-level done file in tf.data service snapshot on-disk state. PiperOrigin-RevId: 506317430 * Identify the "file_prefix" tensor by matching the unique `tf_saved_model.index_path` attribute. Currently the `file_prefix` tensor, which is used to identify the directory to the checkpoint file from which the variables are restored, is identified by relying on the fact that it is used as an input to the restore op. Doing so makes some assumptions (the name of the restore op) and is prone to accidental conflict. We can find the file_prefix tensor with more certainty by seeing whether the `tf_saved_model.index_path` attribute matches `__tf_file_prefix`. PiperOrigin-RevId: 506318827 * Add abstract base types for common `dataset_ops` types. The presently added types do not define any abstract methods, attributes, properties etc. for their equivalent `dataset_ops` concrete types. I.E., they do not currently define the "shape" of the type and are primarily intended for use in `isinstance` checks to avoid a direct dependency on the concrete type. The types are currently only exported under the internal namespace. PiperOrigin-RevId: 506320622 * TF Lite Runtime support for Python 3.10 under glibc 2.31 * Improve DPB documentation. PiperOrigin-RevId: 506333596 * Update ANDROID_NDK_API_LEVEL default in configure.py PiperOrigin-RevId: 506335783 * #tf-data Ramp down `stage_based_autotune` to do analysis based on the data collected. PiperOrigin-RevId: 506340159 * Integrate LLVM at llvm/llvm-project@0ece2050da3e Updates LLVM usage to match [0ece2050da3e](https://github.com/llvm/llvm-project/commit/0ece2050da3e) PiperOrigin-RevId: 506340480 * [xla][mlir][sparse] allow sparse shapes at construction time only PiperOrigin-RevId: 506342697 * Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/afca233650bc0ce402e8a9a07787732b04bef7aa. PiperOrigin-RevId: 506343516 * [GmlSt] Add dim(gml_st.parallel) and gml_st.parallel(tensor.cast) patterns. Additional canonicalization patterns for gml_st.parallel loop. 
PiperOrigin-RevId: 506344025 * Temporarily disable the flaky test for Kokoro build. PiperOrigin-RevId: 506345945 * small cleanup of fuzz helper PiperOrigin-RevId: 506350422 * #tf-data-service Clean up checkpoints after completing the snapshot. PiperOrigin-RevId: 506355658 * Check that RefineDynamicShapes doesn't leave dynamic shapes around It is expected that RefineDynamicShapes in the XlaCallModuleOp kernel fully specializes the StableHLO program to static shapes. However, we aren't checking that, so specialization failures may go unnoticed and manifest downstream in weird ways where they are harder to debug. This CL introduces an early check for this. This is a second attempt at landing this CL. The first attempt broke some tests and got rolled back. Now the broken test is disabled because it was relying on wrong behavior that we started detecting thanks to the increased scrutiny implemented here. PiperOrigin-RevId: 506356516 * Expand applicability of real_dynamic_slice canonicalizers At the moment, RealDynamicSliceOp => SliceOp canonicalization only works when start_indices, limit_indices and strides are all of type arith::ConstantOp. This CL extends canonicalization to handle any kind of m_Constant ops. Furthermore, this CL supersedes the RealDynamicSliceIsStatic C++ pattern with the RealDSliceToSlice TableGen pattern. I'm not sure why both of these patterns were enabled when they are doing roughly the same thing. PiperOrigin-RevId: 506356645 * Add warning about assumed input_signatures PiperOrigin-RevId: 506357398 * [GmlSt] Use upstream patterns to collapse extract/insert_slice. PiperOrigin-RevId: 506358242 * Modify LiteralTestUtil to ensure dynamic dimensions are equivalent when checking equality. Previously the LiteralTestUtil would consider two dynamic literals equal as long as they had identical elements (even if they had different dynamic dimensions). PiperOrigin-RevId: 506359222 * feat: update boringssl to fix aarch64 build failures PiperOrigin-RevId: 506366004 * [TF:PJRT] Use PjRtDeviceContext in XlaDevice. - Use AsyncValueAllocator as the allocator when PjRtDeviceContext is used. - Update places that use XlaDeviceContext as signature to DeviceContext. - Change GetXlaOpsCommonFlags to return XlaOpsCommonFlags* so that the flag tf_xla_use_device_api can be set in the test. - Implement Name() in AsyncValueAllocator which is a virtual function. PiperOrigin-RevId: 506369982 * Remove time (AKA time_fraction) field, since it's no longer used. We now compute this in the frontend to avoid storing this redundant field in the protobuf. PiperOrigin-RevId: 506372540 * Fix crash in simplifyDynamicGatherToGather DynamicGatherOp's slice_sizes is 1DTensorOf<[HLO_DimensionValue]> where HLO_DimensionValue is AnyTypeOf<[Index, HLO_Int]>. However, GatherOp's slice_sizes is I64ElementsAttr. If there's a mismatch in element types, canonicalization from DynamicGatherOp to GatherOp will crash, so in that case we need to explicitly convert the elements. PiperOrigin-RevId: 506374817 * Remove legacy references from `ops.py`. This is done to eventually remove the lazy loads in `indexed_slices.py`. PiperOrigin-RevId: 506375428 * Upgrade clang toolchain to use clang-16. PiperOrigin-RevId: 506381712 * [xla:gpu] Remove check on XLA_FLAGS when doing deserialization We no longer need this because XLA Runtime is enabled by default. PiperOrigin-RevId: 506382068 * Adding profiler assertions for TPU. PiperOrigin-RevId: 506382173 * Fix a operand does not dominate bug caused by tpu_extract_outside_compilation. 
tpu_extract_outside_compilation can create a _XlaRecvAtHostOp or a _XlaRecvAtHostV2Op op to receive at the host side. Its operand identification function (GetStaticExternalOperands) avoids including operands that are already on the host by checking if they are set by a recv, since a recv would have been created already. The bug was that only _XlaRecvAtHostV2Op was counted as a recv, not _XlaRecvAtHostOp. PiperOrigin-RevId: 506383107 * Add copybara config tests. PiperOrigin-RevId: 506396032 * Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/250ad8a0ccdab9d6882931d0dcdfa8fa73eceadf. PiperOrigin-RevId: 506399106 * [IFRT] Add additional ArrayImpl tests with various host buffer semantics Additional tests verify that the `Array` implementation meet the API contract as defined by `HostBufferSemantics`. This change also adds a revised version of the `PjRtBuffer::HostBufferSemantics` comment. It does not yet define a new IFRT `HostBufferSemantics` type yet for a JAX compatibility. PiperOrigin-RevId: 506401836 * [GmlSt] Use the gml-st-cpu-tiling-pipeline to test transformations. We used to test separate transformation patterns that don't include vectorization. These tests check the transformation + vectorization. Later there will be additional CHECKs in the same files for bufferization to verify that we don't allocate inside the loops. Reduce and matmul will be updated in a follow-up. PiperOrigin-RevId: 506408069 * Silence some pytype errors. PiperOrigin-RevId: 506409026 * Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/c6df8512943f31fbf1d2cf3fcdcbc6bc1aa747db. PiperOrigin-RevId: 506409909 * Add a DistributedValue that is backed by DTensor instance. This can be used as the input and output value for a strategy.run function. PiperOrigin-RevId: 506414540 * Redirect usages of `convert_variables_to_constants` from `graph_util_impl.py` to `convert_to_constants.py` to remove a cycle. PiperOrigin-RevId: 506414914 * Add out of bounds array check to dynamic_stitch_op. PiperOrigin-RevId: 506418249 * Partial rollforward of PR #59315. Bring back two following fixes for TF MacOS + Metal PluggableDevice: - TensorList op exclusion for MacOS - Temporary hack to avoid jit_compile on MacOS. Eigen buffer alignment fix is not included in this rollforward and will be in a separate commit. END_PUBLIC *** Reason for rollback *** Partial rollforward of PR #59315. *** Original change description *** Automated g4 rollback of changelist 504212615. Rollback PR #59315. Breaks MacOS tests. For eg: tensorflow/core/framework:tensor_test PiperOrigin-RevId: 506419803 * [XLA] Create skeleton for a partition assignment pass, which annotates the given module with (good) shardings, by adding: - an HLO pass: PartitionAssignment - a base class: PartitioningAlgorithm, - a no-op derived class extending PartitioningAlgorithm: NoopPartitioning, and - a flag to determine the algorithm (kind/type): xla_partitioning_algorithm. PiperOrigin-RevId: 506423268 * [XLA:CPU] Add concat benchmarks PiperOrigin-RevId: 506427653 * Fix memory corruption vulnerability in reverse_sequence_op. PiperOrigin-RevId: 506433062 * [XLA] Support merging partially replicated dimension in complex resharding code. PiperOrigin-RevId: 506433374 * Implement Tensorflow verification pass that ensures no TF dialect ops remain. Required so that we can remove the allow_partial_conversion check in LegalizeTF, which is required to only call LegalizeTF once. 
PiperOrigin-RevId: 506434351 * Delete SetPayload(absl::string_view, absl::string_view); PiperOrigin-RevId: 506438247 * Integrate LLVM at llvm/llvm-project@dbd02002dd0c Updates LLVM usage to match [dbd02002dd0c](https://github.com/llvm/llvm-project/commit/dbd02002dd0c) PiperOrigin-RevId: 506440373 * [Tensorflow] Fix security vulnerability with TensorListSplitOp PiperOrigin-RevId: 506441188 * Remove unused code in cost analysis PiperOrigin-RevId: 506441280 * Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/851d62673267a061aab673a33fa9ad37a5aa39fb. PiperOrigin-RevId: 506442730 * Replace `error_message()` with `message()` since we have upgraded to a newer protobuf PiperOrigin-RevId: 506443040 * #tf-data-service Add a test util for `choose_from_datasets`. PiperOrigin-RevId: 506444429 * #tf-data-service Add a check for infinite datasets. The next step is to support `repeat`, for example: datasets = [tf.data.Dataset.from_tensors("a").repeat(10), tf.data.Dataset.from_tensors("b").repeat(10), tf.data.Dataset.from_tensors("c").repeat(10)] choice_dataset = tf.data.Dataset.range(3).repeat() dataset = tf.data.Dataset.choose_from_datasets(datasets, choice_dataset) PiperOrigin-RevId: 506448078 * Recognize empty input_signatures with default value parameters PiperOrigin-RevId: 506449019 * Limit the thread pool size of the TFE context used for constant folding PiperOrigin-RevId: 506454804 * [jax] Skip compilation cache test for older jaxlibs PiperOrigin-RevId: 506460144 * Add back CLANG_CUDA_COMPILER_PATH to gpu_clang.bazelrc. PiperOrigin-RevId: 506468121 * Rollback the change to add GPU PJRT client. PiperOrigin-RevId: 506477686 * - Add _cast() to TraceType - Implement _cast() to default types and TensorSpec PiperOrigin-RevId: 506479924 * Cast `status.message()` explicitly to `std::string` PiperOrigin-RevId: 506502241 * Canonicalize RealDynamicSliceOp to DynamicSliceOp We know how to canonicalize RealDynamicSliceOp to SliceOp (when all attributes are static), but there is one more case when RealDynamicSliceOp can be canonicalized to a simpler op // Before rewrite %slice_sizes = mhlo.constant ... %limit_indices = mhlo.add %start_indices, %slice_sizes %strides = mhlo.constant dense<1> %result = mhlo.real_dynamic_slice %operand, %start_indices, %limit_indices, %strides // After rewrite %result = "mhlo.dynamic_slice"(%operand, %start_indices0, ...) { slice_sizes = ... } PiperOrigin-RevId: 506504799 * Disable `tensorflow/dtensor/python/tests:spmd_test` on Python 3.8 PiperOrigin-RevId: 506505212 * Disable `tensorflow/dtensor/python/tests:multi_client_test_nccl_local` on OSS PiperOrigin-RevId: 506507742 * Add Python specific disable tags to the bazel configs PiperOrigin-RevId: 506508975 * [Tensorflow] Fix security vulnerability with DenseBincountOp PiperOrigin-RevId: 506514542 * Update Eigen to commit:3460f3558e7b469efb8a225894e21929c8c77629 CHANGELOG ========= 3460f3558 - Use VERIFY_IS_EQUAL to compare to zeros. 13a1f25da - Revert StlIterators edit from "Fix undefined behavior..." fd2fd4870 - Update file ForwardDeclarations.h 37b2e9717 - Tweak special case handling in atan2. a1cdcdb03 - Fix undefined behavior in Block access 4a58f30aa - Fix pre-POWER8_VECTOR bugs in pcmp_lt and pnegate and reactivate psqrt. 12ad99ce6 - Remove unused variables from GenericPacketMathFunctions.h 6987a200b - Fix stupid sparse bugs with outerSize == 0 0471e61b4 - Optimize various mathematical packet ops 1aa6dc200 - Fix sparse warnings 17ae83a96 - Fix bugs exposed by enabling GPU asserts. 
ab8725d94 - Turn off vectorize version of rsqrt - doesn't match generic version 6d9f662a7 - Tweak atan2 6fc9de7d9 - Fix slowdown in bfloat16 MMA when rows is not a multiple of 8 or columns is not a multiple of 4. 6d4221af7 - Revert qr tests 7f58bc98b - Refactor sparse 576448572 - More fixes for __GNUC_PATCHLEVEL__. 164ddf75a - Use __GNUC_PATCHLEVEL__ rather than __GNUC_PATCH__, according to the documentation https://gcc.gnu.org/onlinedocs/cpp/Common-Predefined-Macros.html 5a7ca681d - Fix sparse insert 08c961e83 - Add custom ODR-safe assert. 3fe8c5110 - Replace the Deprecated `$` with `$` d70b4864d - issue #2581: review and cleanup of compiler version checks b52312068 - [SYCL-2020 Support] Enabling Intel DPCPP Compiler support to Eigen bae119bb7 - Support per-thread is_malloc_allowed() state fa0bd2c34 - improve sparse permutations 2e61c0c6b - Add missing EIGEN_DEVICE_FUNC in a few places when called by asserts. 4aca06f63 - avoid move assignment in ColPivHouseholderQR 68082b822 - Fix QR, again 4d0576534 - Altivec fixes for Darwin: do not use unsupported VSX insns PiperOrigin-RevId: 506525228 * Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/ec4a1c4d591c9c5be3ae207551452f2f667177c7. PiperOrigin-RevId: 506526946 * Update GraphDef version to 1395. PiperOrigin-RevId: 506547190 * compat: Update forward compatibility horizon to 2023-02-02 PiperOrigin-RevId: 506547933 * [XLA] Use wide accumulator for integer types in HloEvaluator. Generally, this should not affect the operations, as the results are downcasted to ReturnT. Some integer operations (SHR, CLZ, popcnt) were updated, as they didn't previously support cases where ReturnT != ElementwiseT For convolutions, clamp the result to the ReturnT range, as discarding the high bits doesn't make sense. This allows to enable convolution tests that would otherwise fail (cl/506267884). PiperOrigin-RevId: 506548096 * Add `diagonal_recurrent_tensors` attribute to UNIDIRECTIONAL_SEQUENCE_LSTM op. PiperOrigin-RevId: 506553811 * Add test for tf.TensorScatterAdd PiperOrigin-RevId: 506561719 * Reference the `benchmark_model` instructions from the delegate performance benchmark README. Running `benchmark_model` can be useful for quick feedback during the early stages of development. PiperOrigin-RevId: 506568968 * Add convolution tests for int8x32 cuDNN vectorized layout PiperOrigin-RevId: 506573468 * [GmlSt] Split and clean-up codegen tests for matmul. PiperOrigin-RevId: 506574062 * Also log the execution time in run_hlo_module. replay_computation has this functionality, and the goal is to replace it with run_hlo_module. PiperOrigin-RevId: 506584404 * Integrate LLVM at llvm/llvm-project@7d3a181c8c18 Updates LLVM usage to match [7d3a181c8c18](https://github.com/llvm/llvm-project/commit/7d3a181c8c18) PiperOrigin-RevId: 506591519 * Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/49fef8924ba03c721f5f1125df217d902c42d1c3. PiperOrigin-RevId: 506593472 * [GmlSt] Split and clean-up codegen tests for reduce. PiperOrigin-RevId: 506601849 * [XLA:TPU] Avoids serializing large literals that can cause high compilation latencies. PiperOrigin-RevId: 506622024 * hide extra devices instead of raising an error. This change relaxed DTensor's safety check in the NCCL path too, to conform that tf's set visible device doesn't affect the physical device list. PiperOrigin-RevId: 506643035 * Cleanup: rename initializers_v2.py to initializers.py. PiperOrigin-RevId: 506645119 * Suppress a noisy log line. 
PiperOrigin-RevId: 506651054 * Provide a better error message in case of compilation failure PiperOrigin-RevId: 506656750 * Register a custom codec for `extension_type.ExtensionTypeSpec` to remove `nested_structure_coder.py`'s dependency on `extension_type.py`. PiperOrigin-RevId: 506657587 * Fix use-after-move in iterator_ops.cc PiperOrigin-RevId: 506659126 * [XNNPACK] Fix some error logging in delegate logging_context can be nullptr, so use a different macro for error logging PiperOrigin-RevId: 506664135 * Register codecs for `row_partition.RowPartitionSpec` and `ragged_tensor.RaggedTensorSpec` to remove `nested_structure_coder.py`'s dependency on them. PiperOrigin-RevId: 506666103 * Change how -Xcuda-fatbinary is passed depending on the compiler used. PiperOrigin-RevId: 506667062 * Register a codec for `resource_variable_ops.VariableSpec` to remove `nested_structure_coder.py`'s dependency on `resource_variable_ops.py`. PiperOrigin-RevId: 506676601 * [xla:cpu:next] Add remove-copies-to-out-params pass To remove redundant allocations and subsequent copies to out parameters, which come from buffer allocation. The reason why these exist is that during bufferization we must allocate a buffer for each returned result. It is only post-bufferization that we run BufferResultsToOutParams, which inserts copies to those "out" buffers from the allocated ones. The pass added here detects this pattern and removes the allocation and copy, using each output buffer directly. Example input: ``` func.func @main(%arg0: tensor<1024xf64>) -> tensor<1024xf64> { %0 = mhlo.add %arg0, %arg0 : tensor<1024xf64> return %0 : tensor<1024xf64> } ``` $ xla-opt -split-input-file -hlo-xla-runtime-pipeline %s - Before: ``` module { func.func @main(%arg0: memref<1024xf64>, %arg1: memref<1024xf64>) { %c1024 = arith.constant 1024 : index %c0 = arith.constant 0 : index %c8 = arith.constant 8 : index %cst = arith.constant 0.000000e+00 : f64 %alloc = memref.alloc() {alignment = 64 : i64} : memref<1024xf64> scf.parallel (%arg2) = (%c0) to (%c1024) step (%c8) { %subview = memref.subview %alloc[%arg2] [8] [1] : memref<1024xf64> to memref<8xf64, strided<[1], offset: ?>> %0 = vector.transfer_read %arg0[%arg2], %cst {in_bounds = [true]} : memref<1024xf64>, vector<8xf64> %1 = arith.addf %0, %0 : vector<8xf64> vector.transfer_write %1, %subview[%c0] {in_bounds = [true]} : vector<8xf64>, memref<8xf64, strided<[1], offset: ?>> scf.yield } memref.copy %alloc, %arg1 : memref<1024xf64> to memref<1024xf64> return } } ``` - After: ``` module { func.func @main(%arg0: memref<1024xf64>, %arg1: memref<1024xf64>) { %c1024 = arith.constant 1024 : index %c0 = arith.constant 0 : index %c8 = arith.constant 8 : index %cst = arith.constant 0.000000e+00 : f64 scf.parallel (%arg2) = (%c0) to (%c1024) step (%c8) { %subview = memref.subview %arg1[%arg2] [8] [1] : memref<1024xf64> to memref<8xf64, strided<[1], offset: ?>> %0 = vector.transfer_read %arg0[%arg2], %cst {in_bounds = [true]} : memref<1024xf64>, vector<8xf64> %1 = arith.addf %0, %0 : vector<8xf64> vector.transfer_write %1, %subview[%c0] {in_bounds = [true]} : vector<8xf64>, memref<8xf64, strided<[1], offset: ?>> scf.yield } return } } ``` PiperOrigin-RevId: 506678216 * Register a codec for `tensor_array_ops.TensorArraySpec` to remove `nested_structure_coder.py`'s dependency on `tensor_array_ops.py`. PiperOrigin-RevId: 506681769 * [mhlo] Remove the undefined AllReduceOp build(). 
PiperOrigin-RevId: 506683695 * Use graph export pipeline V2 in TPU Bridge This new graph export pipeline can avoid to generate some unnecessary control dependencies, bring better performance and make the control dependencies more readable. PiperOrigin-RevId: 506687026 * Move custom codecs for TensorSpec and BoundedTensorSpec to `tensor_spec.py`. Register a codec for `sparse_tensor.SparseTensorSpec`. PiperOrigin-RevId: 506690720 * Factor out get_default_ops and make get_ops_from_nodedef a public method in TF selective_registration_header_lib. PiperOrigin-RevId: 506697634 * Integrate LLVM at llvm/llvm-project@6dd84983d0c1 Updates LLVM usage to match [6dd84983d0c1](https://github.com/llvm/llvm-project/commit/6dd84983d0c1) PiperOrigin-RevId: 506708273 * Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/4487e42c7e3dc1f6d641bb1c98b01990fbbbc167. PiperOrigin-RevId: 506711487 * Register a codec for `indexed_slices.IndexedSlicesSpec` to remove `nested_structure_coder.py`'s dependency on `indexed_slices.py`. PiperOrigin-RevId: 506715268 * Disable tsan for distributed snapshot fault tolerance tests. PiperOrigin-RevId: 506724161 * #tf-data-service Update the default protocol in DistributedSaveOp. PiperOrigin-RevId: 506725241 * use common string for profiler lock contention detection. PiperOrigin-RevId: 506726080 * gpu_delegate: Link nativewindow PiperOrigin-RevId: 506727041 * Call XNNPACK Transpose from TFLite kernel. PiperOrigin-RevId: 506737055 * Integrate LLVM at llvm/llvm-project@16c8709cf61b Updates LLVM usage to match [16c8709cf61b](https://github.com/llvm/llvm-project/commit/16c8709cf61b) PiperOrigin-RevId: 506742350 * Fix dimension mismatch bug in MultinomialOp GPU implementation. PiperOrigin-RevId: 506744108 * Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/6e37e534eaa88a022470d77d457722249235d331. PiperOrigin-RevId: 506745467 * #tf-data-service Write tmp files in the same file system as the snapshot. `rename` requires the source and destination files be in the same file system. The temp files are named similar to https://github.com/tensorflow/tensorflow/blob/33722bc185e676c99f738790ef35db8479f2f7d4/tensorflow/core/data/snapshot_utils.cc#L950. PiperOrigin-RevId: 506746696 * Add a metric in the eager function runtime to measure when a tf.function should be compiled. This metric will cover all TF2 jit_compilation paths including TPU to give an accurate number for the number of tf.functions that will be compiled per device type. PiperOrigin-RevId: 506750567 * Integrate LLVM at llvm/llvm-project@10939d1d580b Updates LLVM usage to match [10939d1d580b](https://github.com/llvm/llvm-project/commit/10939d1d580b) PiperOrigin-RevId: 506761372 * #tf-data-service Use absl::Duration in dispatcher_impl. PiperOrigin-RevId: 506762070 * Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/91d765cad5599f9710973d3e34d4dc22583e2e79. PiperOrigin-RevId: 506763472 * Fix for the "bfc_allocator_delay" metric being registered multiple times. PiperOrigin-RevId: 506778911 * support big-endian for numpy type descriptor * Patch llvm to fix Windows build. PiperOrigin-RevId: 506800859 * Allow batch function to avoid padding the inputs. PiperOrigin-RevId: 506820270 * [XLA] Speed up constant folding by optimizing and inlining a few simple index/shape/layout utility functions. 
PiperOrigin-RevId: 506827782 * Integrate LLVM at llvm/llvm-project@9b7e57470155 Updates LLVM usage to match [9b7e57470155](https://github.com/llvm/llvm-project/commit/9b7e57470155) PiperOrigin-RevId: 506834755 * Update GraphDef version to 1396. PiperOrigin-RevId: 506834807 * compat: Update forward compatibility horizon to 2023-02-03 PiperOrigin-RevId: 506834823 * Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/27d9da88424935b171d2f28b63658d3ee85bfe5c. PiperOrigin-RevId: 506836991 * Move ReplicateTensorListInitOps pass before Canonicalizer pass to clean up uninitialized tensor lists. PiperOrigin-RevId: 506838420 * Implement the `tf.AssignVariableOp(tf.VarHandleOp, tf.ConstOp)` removal pattern as a separate pass. This was part of the `InsertRestoreOpPass`, which replaces variable initialization patterns that use consts with initialization that uses the RestoreV2 op. In order to support future plans to insert a SaveV2 op and to make the passes more modular, this change splits the pattern-removal part out into a separate pass named `RemoveVariableInitializationByConstPass`. The new pass is placed right after the `InsertRestoreOpPass`, and the resulting exported model should be the same as without this change. PiperOrigin-RevId: 506863386 * Use the new MLIR for the 16x8 unidirectional LSTM operation PiperOrigin-RevId: 506864247 * Add tests to rigorously check channel dimension attributes Checks that all relevant attributes are properly changed for the per_channel and per_tensor cases. PiperOrigin-RevId: 506873551 * Fix a bug in the validation_model build macro. This build macro uses a temporary file named $(@D)/tmp, which is a problem because the macro is instantiated twice in the same package: both generated rules are executed in parallel during the build, both try to write to the same tmp file, and then both try to remove it. This leads to one of the rules failing because the file is not found, it having been removed by the other rule. The fix is to instead use $(@D)/<name>.tflite.tmp as the name of the temporary file, where <name> is the name of the rule. This is sufficient to ensure that the temporary file names used by the different instantiations of this build macro are distinct. PiperOrigin-RevId: 506882948 * Fixes crashes due to fuzz input for ApproxTopK PiperOrigin-RevId: 506898015 * [JITRT] Add scatter benchmark. PiperOrigin-RevId: 506909548 * [XLA] Reduce some unnecessary string creations in xla::Printer callsites. PiperOrigin-RevId: 506909676 * Always instrument at the module level regardless of whether trace is enabled. PiperOrigin-RevId: 506915862 * In ExportXlaFunctions, iterate all ops under the xla function instead of only the ones that inherit from SymbolUserOpInterface. Ideally all ops referencing functions should inherit from SymbolUserOpInterface, but that will take some time to happen. PiperOrigin-RevId: 506918556 * Move `_TensorShapeCodec` to `tensor_shape.py` to remove `nested_structure_coder.py`'s dependency on `tensor_shape.py`. PiperOrigin-RevId: 506923585 * Adds a method to build datasets on workers without creating an iterator when doing ParameterServerStrategy training. PiperOrigin-RevId: 506923786 * adding TensorShape fuzzers PiperOrigin-RevId: 506927213 * Turn lazy load of `nested_structure_coder` in `type_spec.py` into a regular import. PiperOrigin-RevId: 506927612 * Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/2796b3e7ea8a3e7614029f9e307c0113f8d6bb90.
PiperOrigin-RevId: 506929191 * Set shared_name to node name if it is empty in BatchFunctionFallback op PiperOrigin-RevId: 506931118 * Add some missing BUILD dependencies of `structured_tensor.py`. PiperOrigin-RevId: 506932682 * Remove dependency on indexed_slices.internal_convert_to_tensor_or_indexed_slices. PiperOrigin-RevId: 506934163 * [XLA] Add way to allow propagation to output only to a subset of root instruction tuple shardings. PiperOrigin-RevId: 506935285 * Support float64 under CPU/GPU CopyToMesh PiperOrigin-RevId: 506939058 * Turn lazy loads related to `saved_model/nested_structure_coder.py` in `data/ops/` into regular imports. PiperOrigin-RevId: 506940249 * Delete legacy objects from `ops.py`. I have moved all references to these objects to reference their respective files. PiperOrigin-RevId: 506945566 * Add a new replica context for DTensor related strategy. Since DTensor operate in a global context, we mostly raise an explicit error about methods are not available in the context. This will lead to a behavior discrepancy between the existing strategy and new one. We should consider this carefully for the future plan when rebasing the strategy on top of DTensor (eg using replicated_run/spmd_run). PiperOrigin-RevId: 506945673 * Replace `nest` usage in `FuncGraph.capture_call_time_value()` with `TraceType._cast()` PiperOrigin-RevId: 506949659 * Add implementation for strategy.gather() for new MirroredStrategy. PiperOrigin-RevId: 506952602 * [Tensorflow] Fix security vulnerability with UniqueV2. The bug is that the axis index should be canonicalized when it's negative. PiperOrigin-RevId: 506966510 * [XLA:CPU] Scalarize scf.if op PiperOrigin-RevId: 506969655 * [tflite-gpu] Fix OpenGL slice calculation bug. PiperOrigin-RevId: 506971865 * Don't use xnnpack in kernel with the flag --define=tflite_with_xnnpack=false PiperOrigin-RevId: 506975198 * [XLA:TPU] Speed up constant propagation by inlining a few methods in comparison_util.h. PiperOrigin-RevId: 506977511 * Made `TensorWithLayout` an abstract class and defined `TensorWithLayoutTf` to hold the current implementation using TF runtime. Also used `ConstValueNode` to capture the information useful for TF runtime. PiperOrigin-RevId: 506978870 * Skip license check for pybind11_abseil PiperOrigin-RevId: 506979746 * Moving control_captures out of FuncGraph PiperOrigin-RevId: 506985311 * Use gfile over the native open for an implementation based on TensorFlow's C++ FileSystem API. PiperOrigin-RevId: 506985831 * Make the HloModule constructor that takes CompilationEnvironments public. PiperOrigin-RevId: 506993058 * Add float16 and float64 input&output type support for TFLite operator 'cast' Type float16 and float64 input/output for TensorFlow 'cast' operator is used in some Federated Learning models, thus adding these type supports to TFLite 'cast' op can make these operators converted to TFLite build-in ops instead of flex ops. PiperOrigin-RevId: 506997479 * color adjust for timeline PiperOrigin-RevId: 507002833 * [TF:PLUGIN] Fix a dependency issue. PiperOrigin-RevId: 507003433 * Add isinstance check for eager execution. PiperOrigin-RevId: 507003564 * -Add 4 bit support to depthwise_conv.cc and fully_connected.cc in TfLite using the reference kernels 4bit functions for those op . 
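The UniqueV2 fix above is about canonicalizing a negative axis before it is used to index into the shape; a small standalone sketch of that check (hypothetical helper, not the actual kernel code):

```
#include <cstdint>
#include <stdexcept>

// Maps a possibly negative axis (Python-style, e.g. -1 == last dimension)
// into [0, rank) and rejects anything out of bounds. Skipping the bounds
// check is exactly the kind of bug that lets a hostile axis value index
// past the end of the shape array.
int64_t CanonicalizeAxis(int64_t axis, int64_t rank) {
  if (axis < -rank || axis >= rank) {
    throw std::invalid_argument("axis out of range");
  }
  return axis < 0 ? axis + rank : axis;
}

int main() {
  // For a rank-3 tensor, axis -1 refers to the last dimension (index 2).
  return CanonicalizeAxis(-1, 3) == 2 ? 0 : 1;
}
```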
Also added/changed supporting functions to get the tests to run in fully_connected_test.cc. -Added a 4bit test (Simple4bit3x3FilterTest) to depthwise_conv_test.cc in TFLite, ported from the existing Simple3x3FilterTest with adjusted per-channel quantization scales for 4bit input. -Added a 4bit test (SimpleTestQuantizedInt4) to fully_connected_test.cc in TFLite, ported from the existing SimpleTestQuantizedInt8 with adjusted outputs for 4bit. PiperOrigin-RevId: 507003918 * Add Keras metrics FBetaScore and F1Score. PiperOrigin-RevId: 507013486 * Update the TensorFlow RELEASE.md on master. (We cut the branch for 2.12.0. Insert new blurb for new release notes TF 2.13.0) PiperOrigin-RevId: 507017354 * Implement functional<->regional transformation for `CaseOp` and `CaseRegionOp` Even if we already have `CaseRegionOp` as a region version of `CaseOp`, the associated transformations were missing in functional<->regional control flow transformation passes. This CL implements them. PiperOrigin-RevId: 507017912 * In the code, we have some modules that are "based" off other modules but not exact clones. Change these code locations, so these modules retain the CompilationEnvironments from the original module. PiperOrigin-RevId: 507020873 * Update the TensorFlow RELEASE.md on master. (We cut the branch for 2.12.0. Insert new blurb for new release notes TF 2.13.0) PiperOrigin-RevId: 507021191 * Break dependency between tensorflow/core/function/transform:transform and python/saved_model. PiperOrigin-RevId: 507021697 * Handle snapshot and stream completion in tf.data service dispatcher. PiperOrigin-RevId: 507023442 * Fix release notes. Because the automation that updates the release notes in master after the branch cut for a release has been destroyed, and the step has not been done manually in time, we have commits such as https://github.com/tensorflow/tensorflow/commit/9fbf1137044ac63e296ebf73c61b1e8513149b1c# and https://github.com/tensorflow/tensorflow/commit/ba1372a41ed90aba0efa5763b06350dd0ee7074b that write the wrong contents to the release notes. PiperOrigin-RevId: 507025073 * Implement functional<->regional transformation for `CaseOp` and `CaseRegionOp` Even if we already have `CaseRegionOp` as a region version of `CaseOp`, the associated transformations were missing in functional<->regional control flow transformation passes. This CL implements them. PiperOrigin-RevId: 507030461 * #tf-data-service Wait for the snapshot DONE file in unit tests. PiperOrigin-RevId: 507030943 * Expose TF2XLA MLIR pipeline for reuse PiperOrigin-RevId: 507033096 * Update master version numbers to 2.13.0. Branch for TF 2.12 releases has been cut. Switch to new version. PiperOrigin-RevId: 507037877 * [PJRT C API] Add C API for PjRtCApiClient::LookupAddressableDevice. wrapped_device_map_ and GetCApiDevice can be removed after LookupAddressableDevice is added. PiperOrigin-RevId: 507060995 * Work around compiler bug returning an optional unique_ptr. It looks like some compilers (e.g. gcc-7) don't like returning a moveable value directly when the return type is `std::optional` (i.e. it fails to treat the returned value as an r-value and automatically construct an optional instance around it). Explicitly creating the `std::optional` and returning _that_ seems to work around the issue. PiperOrigin-RevId: 507062621 * Update GraphDef version to 1397.
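The `std::optional<std::unique_ptr<...>>` workaround described above can be reduced to a few lines; a minimal sketch with a placeholder `Widget` type:

```
#include <memory>
#include <optional>
#include <utility>

struct Widget {};

// Some older compilers (e.g. gcc-7) fail to treat `result` as an r-value
// here and cannot wrap the move-only unique_ptr into the optional:
//
//   std::optional<std::unique_ptr<Widget>> Make() {
//     auto result = std::make_unique<Widget>();
//     return result;  // rejected by the affected compilers
//   }
//
// Workaround: construct the std::optional explicitly and return that.
std::optional<std::unique_ptr<Widget>> Make() {
  auto result = std::make_unique<Widget>();
  return std::optional<std::unique_ptr<Widget>>(std::move(result));
}

int main() { return Make().has_value() ? 0 : 1; }
```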
PiperOrigin-RevId: 507090576 * compat: Update forward compatibility horizon to 2023-02-04 PiperOrigin-RevId: 507090630 * Update the version of Estimator nightly and Keras nightly used in TensorFlow after the corresponding nightly packages with the next version are released PiperOrigin-RevId: 507168035 * Add tfstreamz for input spec mismatch cases. PiperOrigin-RevId: 507180105 * Implement shape inference for `CaseOp` and `CaseRegionOp` PiperOrigin-RevId: 507201660 * compat: Update forward compatibility horizon to 2023-02-05 PiperOrigin-RevId: 507239797 * Update GraphDef version to 1398. PiperOrigin-RevId: 507239865 * #tf-data-service Use a mock dispatcher in the split provider test. The original test works by manipulating the files. It makes the test depend on the snapshot_manager's state. When the snapshot_manager implementation changes, it could affect this test because of the structure of the test not because of a bug. Decoupling it from the dispatcher makes the test cleaner, more stable, and less likely to be flaky. PiperOrigin-RevId: 507312461 * Preserve the linear index when computing the operand of concatenate. If the Concatenate op concatenates the fastest varying dimension, we can relatively cheaply preserve the linear index. This is a HLO snippet where we see a 30% improvement with this change: HloModule concatenate ENTRY main { param = f32[100,11,12,13,160]{4,3,2,1,0} parameter(0) param2 = f32[27456000]{0} parameter(1) reshape = f32[100,11,12,13,160]{4,3,2,1,0} reshape(param2) ROOT concat = f32[100,11,12,13,320]{4,3,2,1,0} concatenate(param, reshape), dimensions={4} } PiperOrigin-RevId: 507391165 * compat: Update forward compatibility horizon to 2023-02-06 PiperOrigin-RevId: 507406616 * Update GraphDef version to 1399. PiperOrigin-RevId: 507406617 * Add math dialect. PiperOrigin-RevId: 507412764 * Integrate LLVM at llvm/llvm-project@8c712296fb75 Updates LLVM usage to match [8c712296fb75](https://github.com/llvm/llvm-project/commit/8c712296fb75) PiperOrigin-RevId: 507429387 * Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/e94b53450349f6837d11cc39f614af86c825ef94. PiperOrigin-RevId: 507431532 * Track memref allocations, deallocations and peak heap size. PiperOrigin-RevId: 507449826 * [TF/MLIR] Supports lowering mhlo.reduce_window when there is reshape/broadcast in the divisor. PiperOrigin-RevId: 507455712 * updated TF patch PiperOrigin-RevId: 507456411 * [XLA:GPU] Do not expand Scatter ops that are deterministic when running with --xla_gpu_deterministic_ops. Currently all Scatter ops are expanded when deterministic ops are enforced. However, scatter ops on unique indices cannot have data races irrespective of the implementation. Similarly, scatter ops with associative combiner functions will compute deterministic results irrespective of the order in which the combiner function is applied. In both cases, scatter will be deterministic and expanding it is thus not required. This reduces slowdowns due to the xla_gpu_deterministic_ops flag. PiperOrigin-RevId: 507460239 * BladeDISC patch 20221101 1, Build related changes: * No force -std=c++17 for cuda. 
https://github.com/pai-disc/tensorflow/commit/af4d5a07589c1d30c14c76aba6592554210451a5 * workaround a compilation error on GCC 7.3.1 (#19), like: undefined reference to `std::function (xla::Shape const&, bool, mlir::XlaLayoutPreference)>::function()' * [to #35355928] fix build issue when enabling MLIR_GPU_TO_CUBIN_PASS_ENABLE * disable `-Werror=unused-result` * disable `noincompatible_remove_legacy_whole_archive` * add missing dependency `//tensorflow/compiler/xla/stream_executor:dnn_proto_cc_impl` 2, hlo related changes: * [to #35377611] feat: bufferize DotOp and DotGeneralOp; remove the community DotGeneralOp bufferizer as well Link: https://code.aone.alibaba-inc.com/algo/tensorflow_google/codereview/5885662 * [to #35377611] feat: bufferize ConvOp and DynamicConvOp. Link: https://code.aone.alibaba-inc.com/algo/tensorflow_google/codereview/5910553 * [to #37276187] feat: bufferize mhlo.log1pOp Link: https://code.aone.alibaba-inc.com/algo/tensorflow_google/codereview/6675686 * [to #36574644] [MLIR] [DISC] Add reverse op in lmhlo Link: https://code.aone.alibaba-inc.com/algo/tensorflow_google/codereview/6364791 * support RoundNearestEvenOp (#20) * support RoundOp (#21) * add const folding support for mhlo.TransposeOp * disable some static checks of mhlo.dot_general `tf.BatchMatmul(tensor, tensor<4x?x?xf32>)` is valid, but the `tf.BatchMatmul` tf2mhlo converter does not handle shape propagation between the lhs & rhs, causing some of the static checks of `dot_general` to fail. Just disable the check as a workaround. * [to #35355928] fix missing Elementwise traits of some special ops * [to #35355928] fix a bug in the lmhlo structured interface * enhance `maybeCastTo` to support casting between i64 and i32. * cast to i64 only if index type in DynamicIotaBroadcast pattern. * add a patch not to fold UnrealizedConversionCastOp with ui/si type 3, TF related changes: * lower tf.GatherV2 op to mhlo in dynamic shape * lower tf.DepthwiseConv2DNative op to mhlo in dynamic shape * lower tf.StridedSlice op to mhlo in dynamic shape * lower tf.DynamicStitchOp to mhlo in dynamic shape * lower tf.BatchMatMulOp to mhlo in dynamic shape * lower tf.Conv2DBackpropInputOp/tf.Conv2DBackpropFilterOp to mhlo in dynamic shape * support tf.Range with negative stride * support tf.StridedSlice with new_axis_mask * add mhlo_disc dependency in xla legalize_tf * legalize quantized tf const before lowering to mhlo * add tf2mhlo support for tf.BatchMatMul * bugfix: only handle non-const begin/end in ConvertStridedSliceOpDynamic * bugfix: use the static tf.TileOp tf2mhlo conversion only when all ins/outs have static shape * bugfix: size of begin/end/strides < the rank of input * bugfix: disable TF_RandomUniformOp in tf->mhlo * fix a bug in the tf.SigmoidGradOp legalize pattern * fix a bug in the ConvertSplitOp pattern * fix a bug in the ConvertUnpackOp pattern * [to #36775150] feat: support multiple StreamExecutors by using the stream as the cache key In the original design of StreamExecutor, one StreamExecutor maps to each device ordinal and owns one cudnnHandle. This means that in multi-stream applications there is only one StreamExecutor for each GPU device. We observed dramatic performance degradation due to the lock in the SE; refer to https://yuque.antfin-inc.com/pai/blade/irdw7g. This commit revises the executor cache so that there are multiple StreamExecutors, one per GPU stream.
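The per-stream StreamExecutor change above amounts to widening the executor cache key from the device ordinal to (device ordinal, stream); a minimal standalone sketch with hypothetical types (not the BladeDISC or StreamExecutor code):

```
#include <cstdint>
#include <map>
#include <memory>
#include <mutex>
#include <utility>

// Hypothetical stand-ins for the real executor/stream types.
struct GpuStreamExecutor {
  int device_ordinal;
  const void* stream;  // the GPU stream this executor is bound to
};

// Before: the cache was keyed by device ordinal only, so every stream on a
// device shared one executor (and its lock). After: keying by
// (device ordinal, stream) gives each stream its own executor.
class ExecutorCache {
 public:
  GpuStreamExecutor* GetOrCreate(int device_ordinal, const void* stream) {
    std::lock_guard<std::mutex> lock(mu_);
    const auto key = std::make_pair(
        device_ordinal, reinterpret_cast<std::uintptr_t>(stream));
    auto& slot = cache_[key];
    if (slot == nullptr) {
      slot = std::make_unique<GpuStreamExecutor>(
          GpuStreamExecutor{device_ordinal, stream});
    }
    return slot.get();
  }

 private:
  std::mutex mu_;
  std::map<std::pair<int, std::uintptr_t>, std::unique_ptr<GpuStreamExecutor>>
      cache_;
};

int main() {
  ExecutorCache cache;
  int stream_a = 0, stream_b = 0;
  // Two streams on the same device now get two distinct executors.
  return cache.GetOrCreate(0, &stream_a) != cache.GetOrCreate(0, &stream_b)
             ? 0
             : 1;
}
```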
Link: https://code.aone.alibaba-inc.com/algo/tensorflow_google/codereview/6455216 * [to #36574492]feat: dynamic strided slice op supports negative strides Link: https://code.aone.alibaba-inc.com/algo/tensorflow_google/codereview/6425899 3, ROCM/DCU related changes: * [to #37531008] feat: Support building for DCU Link: https://code.aone.alibaba-inc.com/algo/tensorflow_google/codereview/6897356 * [to #37531008] allow asigned stream in se for rocm backend Link: https://code.aone.alibaba-inc.com/algo/tensorflow_google/codereview/6918944 * [to #37383705] patch: Add shfl.sync.bfly lowering to ROCm Link: https://code.aone.alibaba-inc.com/algo/tensorflow_google/codereview/6937184 * [to #39814120] Fix config for Macro and support gfx90a * [rocm] remove -fno-canonical-system-headers as hipcc does not support * [to #37531008] expose some more methods in gpu_backend_lib.h 4, others: * update gitignore * Add llvm patch to fix error GVN on shared memory load. (#4) add folder for mhlo::ClampOp import iree in workspace file increase `kFoldOpEltLimit` to 64GB decompose disc-compiler (#27) * decompose disc-compiler * update * fix some compilation erros * Fix dynamic shape reifyReturnTypeShapes * fix dynamic shapes & mhlo op operands & proto_alls * fix llvm patch * remove iree * reverse unsigned type lowering workaround (#30) --------- Co-authored-by: Aart Bik Co-authored-by: Fiona Lang Co-authored-by: Peng Wang Co-authored-by: Terry Heo Co-authored-by: Francois Chollet Co-authored-by: A. Unique TensorFlower Co-authored-by: John QiangZhang Co-authored-by: Eugene Burmako Co-authored-by: Justin Lebar Co-authored-by: Berkin Ilbeyi Co-authored-by: Xin Zhou Co-authored-by: Ce Zheng Co-authored-by: Dan Suh Co-authored-by: Dateng Lin Co-authored-by: Son Tuan Vu Co-authored-by: David Dunleavy Co-authored-by: Alexander Belyaev Co-authored-by: Ian Hua Co-authored-by: Johannes Reifferscheid Co-authored-by: Eugene Zhulenev Co-authored-by: Matt Callanan Co-authored-by: Tres Popp Co-authored-by: samypr100 <3933065+samypr100@users.noreply.github.com> Co-authored-by: Wilsin Gosti Co-authored-by: Yang Chen Co-authored-by: Faizan Muhammad Co-authored-by: Vadym Matsishevskyi Co-authored-by: Jieying Luo Co-authored-by: Juan Martinez Castellanos Co-authored-by: Brian Wieder Co-authored-by: Anlun Xu Co-authored-by: Roshani Narasimhan Co-authored-by: Michael Delorimier Co-authored-by: Hyeontaek Lim Co-authored-by: Rebecca Chen Co-authored-by: Scott Zhu Co-authored-by: Mason Chang Co-authored-by: Penporn Koanantakool Co-authored-by: Frederik Gossen Co-authored-by: Matthias Kramm Co-authored-by: Marcello Maggioni Co-authored-by: Jean-Baptiste Lespiau Co-authored-by: Jorge Gorbe Moya Co-authored-by: Jian Cai Co-authored-by: Kuangyuan Chen Co-authored-by: Nitin Srinivasan Co-authored-by: Zhufeng Pan Co-authored-by: Sergey Kozub Co-authored-by: Aliia Khasanova Co-authored-by: Matt Kreileder Co-authored-by: Adrian Kuegel Co-authored-by: Andrew Audibert Co-authored-by: Zhi An Ng Co-authored-by: Emilio Cota Co-authored-by: Jared Junyoung Lim Co-authored-by: Jie Sun Co-authored-by: Grant Jensen Co-authored-by: Antonio Sanchez Co-authored-by: Ken Franko Co-authored-by: Fabien Hertschuh Co-authored-by: Kazuaki Ishizaki Co-authored-by: Tomás Longeri Co-authored-by: Bangda Zhou Co-authored-by: Fergus Henderson Co-authored-by: Felix Chern Co-authored-by: Yifan Jiang Co-authored-by: James Mullenbach Co-authored-by: Justin Szaday Co-authored-by: Youchuan Hu Co-authored-by: Chuanhao Zhuge Co-authored-by: Vinila Settem Co-authored-by: Mihai 
Maruseac Co-authored-by: Ashish Shenoy Co-authored-by: Yiming Zhang Co-authored-by: Thomas Joerg Co-authored-by: Wenyi Zhao Co-authored-by: TanyoKwok --- .bazelrc | 33 +- .bazelversion | 3 +- .../tensorflow_issue_template.yaml | 12 + .github/ISSUE_TEMPLATE/tflite-other.md | 62 + .github/bot_config.yml | 5 +- .github/workflows/arm-cd.yml | 11 +- .github/workflows/arm-ci-extended.yml | 2 +- .github/workflows/arm-ci.yml | 9 +- .github/workflows/cffconvert.yml | 4 +- .github/workflows/issue-on-pr-rollback.yml | 4 +- .github/workflows/pylint-presubmit.yml | 6 +- .../workflows/release-branch-cherrypick.yml | 8 +- .github/workflows/scorecards-analysis.yml | 13 +- .github/workflows/sigbuild-docker-branch.yml | 16 +- .../workflows/sigbuild-docker-presubmit.yml | 16 +- .github/workflows/sigbuild-docker.yml | 16 +- .github/workflows/trusted-partners.yml | 7 +- .github/workflows/trusted_partners.js | 37 +- .github/workflows/update-nightly.yml | 2 +- .github/workflows/update-rbe.yml | 34 +- CONTRIBUTING.md | 2 +- README.md | 13 + RELEASE.md | 369 +- SECURITY.md | 12 +- configure.py | 14 +- fuzztest.bazelrc | 47 + tensorflow/BUILD | 306 +- tensorflow/api_template.__init__.py | 7 +- tensorflow/api_template_v1.__init__.py | 7 +- tensorflow/c/BUILD | 88 +- tensorflow/c/c_api.cc | 11 +- tensorflow/c/c_api_experimental.cc | 6 + tensorflow/c/c_api_experimental.h | 6 + tensorflow/c/c_api_experimental_test.cc | 106 + tensorflow/c/c_api_function.cc | 2 +- tensorflow/c/c_api_internal.h | 8 + tensorflow/c/eager/BUILD | 20 +- tensorflow/c/eager/c_api.cc | 8 +- tensorflow/c/eager/c_api_experimental.cc | 39 +- tensorflow/c/eager/c_api_experimental.h | 19 +- tensorflow/c/eager/c_api_experimental_test.cc | 12 +- tensorflow/c/eager/c_api_internal.h | 5 - .../immediate_execution_distributed_manager.h | 25 +- tensorflow/c/eager/parallel_device/BUILD | 1 + .../parallel_device/parallel_device_lib.cc | 33 +- .../parallel_device/parallel_device_lib.h | 20 +- tensorflow/c/eager/tfe_executor_internal.h | 7 +- tensorflow/c/experimental/filesystem/BUILD | 2 + .../filesystem/modular_filesystem.cc | 6 +- .../experimental/filesystem/plugins/gcs/BUILD | 1 + .../filesystem/plugins/posix/BUILD | 1 + .../filesystem/plugins/windows/BUILD | 1 + tensorflow/c/experimental/gradients/BUILD | 18 +- .../c/experimental/gradients/tape/BUILD | 1 + tensorflow/c/experimental/grappler/BUILD | 1 + .../experimental/next_pluggable_device/BUILD | 34 + .../next_pluggable_device/c_api.cc | 333 + .../next_pluggable_device/c_api.h | 153 + tensorflow/c/experimental/ops/BUILD | 1 + tensorflow/c/experimental/ops/gen/BUILD | 1 + .../c/experimental/ops/gen/common/BUILD | 1 + tensorflow/c/experimental/ops/gen/cpp/BUILD | 1 + .../c/experimental/ops/gen/cpp/golden/BUILD | 1 + .../experimental/ops/gen/cpp/renderers/BUILD | 1 + .../c/experimental/ops/gen/cpp/views/BUILD | 1 + tensorflow/c/experimental/ops/gen/model/BUILD | 1 + .../c/experimental/pluggable_profiler/BUILD | 3 +- .../pluggable_profiler/pluggable_profiler.cc | 2 +- .../pluggable_profiler_internal.h | 2 +- .../c/experimental/saved_model/core/BUILD | 2 + .../c/experimental/saved_model/core/ops/BUILD | 1 + .../saved_model/core/revived_types/BUILD | 2 + .../saved_model/core/saved_model_utils.cc | 97 +- .../saved_model/core/saved_model_utils.h | 9 + .../c/experimental/saved_model/internal/BUILD | 5 +- .../saved_model/internal/testdata/BUILD | 1 + .../c/experimental/saved_model/public/BUILD | 1 + .../c/experimental/stream_executor/BUILD | 4 +- .../stream_executor/stream_executor.cc | 126 +- 
.../stream_executor_internal.h | 32 +- .../stream_executor/stream_executor_test.cc | 24 +- .../c/experimental/stream_executor/test/BUILD | 1 + tensorflow/c/kernels.cc | 33 + tensorflow/c/kernels.h | 16 + tensorflow/c/kernels/BUILD | 1 + tensorflow/c/tf_datatype.h | 5 +- tensorflow/c/tf_status.cc | 36 +- tensorflow/c/tf_status.h | 41 +- tensorflow/c/tf_status_helper.cc | 65 +- tensorflow/c/tf_status_helper.h | 4 +- tensorflow/c/tf_status_internal.h | 9 +- tensorflow/c/tf_tensor.cc | 2 +- tensorflow/c/tf_tstring.h | 2 +- tensorflow/cc/BUILD | 1 + tensorflow/cc/client/client_session_test.cc | 15 +- tensorflow/cc/experimental/base/public/BUILD | 1 + tensorflow/cc/experimental/base/tests/BUILD | 1 + tensorflow/cc/experimental/libexport/BUILD | 1 + tensorflow/cc/experimental/libtf/BUILD | 1 + tensorflow/cc/experimental/libtf/impl/BUILD | 1 + tensorflow/cc/experimental/libtf/mlir/BUILD | 1 + tensorflow/cc/experimental/libtf/object.h | 2 +- .../cc/experimental/libtf/runtime/BUILD | 1 + .../cc/experimental/libtf/runtime/core/BUILD | 1 + .../cc/experimental/libtf/runtime/tfrt/BUILD | 1 + .../libtf/tests/runtime_test_core.cc | 2 +- tensorflow/cc/framework/cc_op_gen.cc | 1 + tensorflow/cc/framework/cc_op_gen.h | 2 + tensorflow/cc/framework/cc_op_gen_util.h | 1 + tensorflow/cc/framework/fuzzing/BUILD | 13 +- .../cc/framework/fuzzing/cc_op_fuzz_gen.cc | 103 +- .../cc/framework/fuzzing/cc_op_fuzz_gen.h | 8 +- .../framework/fuzzing/cc_op_fuzz_gen_main.cc | 47 +- .../cc/framework/fuzzing/op_fuzzing.bzl | 222 +- tensorflow/cc/framework/grad_op_registry.h | 2 + tensorflow/cc/framework/gradient_checker.cc | 3 + tensorflow/cc/framework/gradient_checker.h | 2 + tensorflow/cc/framework/gradients.cc | 20 +- tensorflow/cc/framework/gradients.h | 2 + tensorflow/cc/framework/ops.h | 3 + tensorflow/cc/framework/scope_internal.h | 6 + tensorflow/cc/framework/testutil.h | 2 + tensorflow/cc/framework/while_gradients.cc | 2 + tensorflow/cc/framework/while_gradients.h | 2 + tensorflow/cc/saved_model/BUILD | 5 + tensorflow/cc/saved_model/bundle_v2.cc | 13 +- tensorflow/cc/saved_model/bundle_v2_test.cc | 6 + .../cc/saved_model/experimental/public/BUILD | 1 + .../cc/saved_model/experimental/tests/BUILD | 1 + tensorflow/cc/saved_model/fingerprinting.cc | 49 +- tensorflow/cc/saved_model/fingerprinting.h | 19 +- .../cc/saved_model/fingerprinting_test.cc | 49 +- tensorflow/cc/saved_model/loader.cc | 14 +- tensorflow/cc/saved_model/metrics.cc | 40 + tensorflow/cc/saved_model/metrics.h | 17 + tensorflow/cc/saved_model/metrics_test.cc | 32 + tensorflow/cc/saved_model/python/BUILD | 1 + tensorflow/cc/saved_model/reader.cc | 2 +- .../cc/saved_model/saved_model_bundle_test.cc | 20 + .../fingerprint.pb | 1 + tensorflow/cc/tools/BUILD | 1 + tensorflow/compiler/aot/BUILD | 6 +- .../compiler/aot/embedded_protocol_buffers.cc | 7 +- tensorflow/compiler/aot/tests/BUILD | 1 + tensorflow/compiler/jit/BUILD | 310 +- .../clone_constants_for_better_clustering.cc | 49 +- .../clone_constants_for_better_clustering.h | 1 - ...ne_constants_for_better_clustering_test.cc | 27 +- tensorflow/compiler/jit/deadness_analysis.cc | 2 +- tensorflow/compiler/jit/deadness_analysis.h | 5 +- .../compiler/jit/deadness_analysis_test.cc | 2 +- .../compiler/jit/device_compilation_cache.h | 212 + .../jit/device_compilation_cache_test.cc | 220 + .../device_compilation_cluster_signature.cc | 139 + .../device_compilation_cluster_signature.h | 56 + ...ice_compilation_cluster_signature_test.cc} | 41 +- .../jit/device_compilation_profiler.cc | 229 + 
.../jit/device_compilation_profiler.h | 100 + .../jit/device_compilation_profiler_test.cc | 243 + tensorflow/compiler/jit/device_compiler.h | 492 ++ .../compiler/jit/device_compiler_client.cc | 46 + .../compiler/jit/device_compiler_client.h | 75 + .../jit/device_compiler_client_test.cc | 64 + ...est.cc => device_compiler_disable_test.cc} | 50 +- .../jit/device_executable_persistor.h | 336 + .../jit/device_executable_persistor_test.cc | 483 ++ tensorflow/compiler/jit/encapsulate_util.cc | 3 +- tensorflow/compiler/jit/encapsulate_util.h | 4 +- .../jit/extract_outside_compilation_pass.cc | 1 - tensorflow/compiler/jit/flags.cc | 61 +- tensorflow/compiler/jit/flags.h | 33 +- tensorflow/compiler/jit/get_compiler_ir.cc | 275 +- tensorflow/compiler/jit/get_compiler_ir.h | 18 +- .../introduce_floating_point_jitter_pass.cc | 153 - .../introduce_floating_point_jitter_pass.h | 35 - ...duce_floating_point_jitter_pass_internal.h | 27 - ...troduce_floating_point_jitter_pass_test.cc | 197 - .../jit/jit_compilation_pass_registration.cc | 4 - tensorflow/compiler/jit/kernels/BUILD | 12 +- tensorflow/compiler/jit/kernels/xla_ops.cc | 285 +- tensorflow/compiler/jit/kernels/xla_ops.h | 6 +- .../compiler/jit/mark_for_compilation_pass.cc | 2 + tensorflow/compiler/jit/ops/BUILD | 1 + .../jit/pjrt_device_compiler_client.cc | 81 + .../jit/pjrt_device_compiler_client.h | 73 + .../compiler/jit/pjrt_device_context.cc | 124 + tensorflow/compiler/jit/pjrt_device_context.h | 44 + tensorflow/compiler/jit/tests/BUILD | 26 +- ...device_compiler_serialize_options_test.cc} | 12 +- ...t.cc => device_compiler_serialize_test.cc} | 12 +- ...lper.cc => device_compiler_test_helper.cc} | 11 +- ...helper.h => device_compiler_test_helper.h} | 33 +- .../tests/keras_imagenet_main.golden_summary | 10 +- ...as_imagenet_main_graph_mode.golden_summary | 12 +- ...pens2s_gnmt_mixed_precision.golden_summary | 102 +- .../compiler/jit/tf_graph_to_hlo_compiler.cc | 36 + .../compiler/jit/tf_graph_to_hlo_compiler.h | 59 + tensorflow/compiler/jit/tf_to_hlo_compiler.h | 52 + tensorflow/compiler/jit/xla_cluster_util.h | 1 - .../compiler/jit/xla_compilation_cache.cc | 951 --- .../compiler/jit/xla_compilation_cache.h | 357 - .../compiler/jit/xla_compile_on_demand_op.cc | 59 +- .../compiler/jit/xla_compile_on_demand_op.h | 9 +- tensorflow/compiler/jit/xla_compile_util.cc | 76 +- tensorflow/compiler/jit/xla_compile_util.h | 39 +- .../compiler/jit/xla_compile_util_test.cc | 77 +- tensorflow/compiler/jit/xla_device.cc | 10 +- tensorflow/compiler/jit/xla_device.h | 12 +- .../jit/xla_device_compiler_client.cc | 114 + .../compiler/jit/xla_device_compiler_client.h | 68 + tensorflow/compiler/jit/xla_device_context.cc | 2 +- tensorflow/compiler/jit/xla_gpu_device.cc | 11 +- tensorflow/compiler/jit/xla_launch_util.cc | 64 +- tensorflow/compiler/jit/xla_launch_util.h | 37 +- tensorflow/compiler/jit/xla_platform_info.cc | 55 +- tensorflow/compiler/jit/xla_platform_info.h | 19 +- tensorflow/compiler/jit/xla_tpu_device.cc | 4 +- tensorflow/compiler/mlir/BUILD | 29 +- .../mlir/g3doc/_includes/tf_passes.md | 23 + .../compiler/mlir/g3doc/xla_gpu_codegen.md | 269 - tensorflow/compiler/mlir/lite/BUILD | 46 +- .../compiler/mlir/lite/converter_gen.cc | 6 +- .../mlir/lite/experimental/common/BUILD | 18 + .../experimental/common/outline_operations.cc | 210 + .../experimental/common/outline_operations.h | 132 + .../compiler/mlir/lite/experimental/tac/BUILD | 7 +- .../lite/experimental/tac/common/subgraph.h | 3 +- .../lite/experimental/tac/common/targets.h | 14 +- 
.../mlir/lite/experimental/tac/examples/BUILD | 1 + .../tac/execution_metadata_exporter.cc | 23 +- .../tac/execution_metadata_exporter_test.cc | 3 +- .../lite/experimental/tac/hardwares/BUILD | 1 + .../lite/experimental/tac/py_wrapper/BUILD | 4 + .../mlir/lite/experimental/tac/tac_module.cc | 2 +- .../mlir/lite/experimental/tac/tests/BUILD | 6 +- .../lite/experimental/tac/tests/e2e/BUILD | 1 + .../tac/tests/fold-constants-to-subgraph.mlir | 25 + .../tac/tests/raise-target-subgraphs.mlir | 342 +- .../experimental/tac/tflite_import_export.cc | 9 +- .../tac/transforms/compute_cost.cc | 2 +- .../experimental/tac/transforms/cost_model.cc | 2 +- .../transforms/fold_constants_to_subgraph.cc | 23 +- .../transforms/get_alternative_subgraph.cc | 14 +- .../tac/transforms/pick_subgraphs.cc | 15 +- .../tac/transforms/raise_target_subgraphs.cc | 428 +- .../mlir/lite/experimental/tac/utils/BUILD | 1 + .../mlir/lite/experimental/tac/utils/utils.cc | 13 +- .../mlir/lite/experimental/tac/utils/utils.h | 4 +- .../compiler/mlir/lite/flatbuffer_export.cc | 203 +- .../compiler/mlir/lite/flatbuffer_import.cc | 21 +- .../compiler/mlir/lite/flatbuffer_operator.cc | 8 +- .../compiler/mlir/lite/flatbuffer_operator.h | 4 +- .../mlir/lite/ir/tfl_op_interfaces.td | 3 +- tensorflow/compiler/mlir/lite/ir/tfl_ops.cc | 420 +- tensorflow/compiler/mlir/lite/ir/tfl_ops.td | 88 +- tensorflow/compiler/mlir/lite/metrics/BUILD | 4 +- .../lite/metrics/error_collector_inst_test.cc | 14 +- tensorflow/compiler/mlir/lite/python/BUILD | 12 +- .../lite/python/graphdef_to_tfl_flatbuffer.cc | 7 +- .../mlir/lite/python/jax_to_tfl_flatbuffer.cc | 7 +- .../python/saved_model_to_tfl_flatbuffer.cc | 3 +- .../lite/python/tf_tfl_flatbuffer_helpers.cc | 18 +- .../lite/python/tf_tfl_flatbuffer_helpers.h | 2 +- .../compiler/mlir/lite/quantization/BUILD | 4 + .../mlir/lite/quantization/device_target.cc | 2 +- .../mlir/lite/quantization/device_target.h | 3 +- .../compiler/mlir/lite/quantization/ir/BUILD | 1 + .../mlir/lite/quantization/ir/QuantOps.cc | 2 +- .../mlir/lite/quantization/ir/QuantOpsBase.td | 3 +- .../mlir/lite/quantization/lite/BUILD | 2 + .../lite/quantization/lite/quantize_model.cc | 10 +- .../lite/quantization/lite/quantize_model.h | 7 +- .../quantization/lite/quantize_weights.cc | 12 +- .../lite/quantization/lite/quantize_weights.h | 20 +- .../lite/quantization/quantization_config.cc | 2 +- .../lite/quantization/quantization_config.h | 13 +- .../lite/quantization/quantization_utils.cc | 2 +- .../lite/quantization/quantization_utils.h | 156 +- .../mlir/lite/quantization/tensorflow/BUILD | 1 + .../lite/quantization/tensorflow/tests/BUILD | 5 +- .../quantization/tensorflow/tf_to_quant.cc | 18 +- .../mlir/lite/quantization/tests/BUILD | 5 +- .../tflite_op_coverage_spec_getters_gen.cc | 64 +- tensorflow/compiler/mlir/lite/sparsity/BUILD | 24 + .../mlir/lite/sparsity/sparsify_model.cc | 15 +- .../mlir/lite/sparsity/sparsify_model.h | 6 +- .../mlir/lite/sparsity/sparsify_model_test.cc | 87 + tensorflow/compiler/mlir/lite/stablehlo/BUILD | 271 +- .../mlir/lite/stablehlo/odml_to_stablehlo.cc | 47 +- .../compiler/mlir/lite/stablehlo/tests/BUILD | 6 +- .../lite/stablehlo/tests/fold_broadcast.mlir | 44 + .../tests/fuse_mhlo_convolution.mlir | 18 + .../lite/stablehlo/tests/legalize-acos.mlir | 38 - ...legalize-inplaceupdate-tf_mhlo_tflite.mlir | 19 - .../tests/legalize-inplaceupdate.mlir | 12 +- .../tests/legalize-mhlo-tf-fb-tf.mlir | 17 - .../legalize-mhlo-tfl-broadcast_in_dim.mlir | 15 - .../tests/legalize-mhlo-tfl-compare.mlir | 19 - 
.../tests/legalize-mhlo-tfl-constant.mlir | 17 - .../tests/legalize-mhlo-tfl-reshape.mlir | 15 - .../tests/legalize-mhlo-tfl-rsqrt.mlir | 15 - .../stablehlo/tests/legalize-mhlo-tfl.mlir | 17 - .../lite/stablehlo/tests/legalize-poly.mlir | 31 - .../tests/legalize-skip-quantization-ops.mlir | 4 +- .../tests/legalize-stablehlo-tf-fb-tf.mlir | 17 + ...b.mlir => legalize-stablehlo-tfl-add.mlir} | 6 +- ...galize-stablehlo-tfl-broadcast_in_dim.mlir | 15 + ...mlir => legalize-stablehlo-tfl-clamp.mlir} | 6 +- .../tests/legalize-stablehlo-tfl-compare.mlir | 19 + ...lir => legalize-stablehlo-tfl-concat.mlir} | 6 +- .../legalize-stablehlo-tfl-constant.mlir | 17 + ....mlir => legalize-stablehlo-tfl-conv.mlir} | 10 +- ...t.mlir => legalize-stablehlo-tfl-dot.mlir} | 8 +- ...lir => legalize-stablehlo-tfl-gather.mlir} | 8 +- ...x.mlir => legalize-stablehlo-tfl-max.mlir} | 6 +- ...d.mlir => legalize-stablehlo-tfl-mul.mlir} | 6 +- ...d.mlir => legalize-stablehlo-tfl-pad.mlir} | 6 +- .../tests/legalize-stablehlo-tfl-reshape.mlir | 15 + .../tests/legalize-stablehlo-tfl-rsqrt.mlir | 15 + ...ir => legalize-stablehlo-tfl-scatter.mlir} | 10 +- ...l.mlir => legalize-stablehlo-tfl-sub.mlir} | 6 +- .../tests/legalize-stablehlo-tfl.mlir | 17 + .../tests/legalize-tfl-mhlo-broadcast.mlir | 15 - .../tests/legalize-tfl-mhlo-concat.mlir | 15 - .../tests/legalize-tfl-mhlo-constant.mlir | 15 - .../tests/legalize-tfl-mhlo-conv.mlir | 14 - .../tests/legalize-tfl-mhlo-pad.mlir | 16 - .../tests/legalize-tfl-mhlo-reshape.mlir | 15 - .../tests/legalize-tfl-mhlo-rsqrt.mlir | 15 - .../stablehlo/tests/legalize-tfl-mhlo.mlir | 17 - ...b.mlir => legalize-tfl-stablehlo-add.mlir} | 6 +- .../legalize-tfl-stablehlo-broadcast.mlir | 15 + ...mlir => legalize-tfl-stablehlo-clamp.mlir} | 6 +- .../tests/legalize-tfl-stablehlo-concat.mlir | 15 + .../legalize-tfl-stablehlo-constant.mlir | 15 + .../tests/legalize-tfl-stablehlo-conv.mlir | 14 + ...d.mlir => legalize-tfl-stablehlo-max.mlir} | 6 +- ...x.mlir => legalize-tfl-stablehlo-mul.mlir} | 6 +- .../tests/legalize-tfl-stablehlo-pad.mlir | 16 + .../tests/legalize-tfl-stablehlo-reshape.mlir | 15 + .../tests/legalize-tfl-stablehlo-rsqrt.mlir | 15 + ...l.mlir => legalize-tfl-stablehlo-sub.mlir} | 6 +- .../tests/legalize-tfl-stablehlo.mlir | 17 + .../stablehlo/tests/odml-stablehlo-tfl.mlir | 8 +- .../tests/odml-to-stablehlo-allow-tf.mlir | 4 +- .../mlir/lite/stablehlo/tests/optimize.mlir | 246 + .../tf-tfl-translate-serialize-stablehlo.mlir | 22 + .../tests/tf-tfl-translate-tf-quantize.mlir | 10 +- .../tests/unfuse_mhlo_batch_norm.mlir | 126 + .../transforms/check_accepted_ops_pass.cc | 14 +- .../transforms/drop_savedmodel_semantics.cc | 17 +- .../transforms/drop_savedmodel_semantics.h | 6 +- .../transforms/fold_broadcast_pass.cc | 259 + .../transforms/fuse_convolution_pass.cc | 148 + .../transforms/mhlo_tfl_legalize_patterns.td | 46 - .../lite/stablehlo/transforms/mhlo_util.cc | 77 - .../lite/stablehlo/transforms/mhlo_util.h | 107 - .../lite/stablehlo/transforms/op_stat_pass.cc | 16 +- .../lite/stablehlo/transforms/optimize.cc | 547 ++ .../mlir/lite/stablehlo/transforms/passes.h | 43 + .../transforms/rename_entrypoint_to_main.cc | 6 +- .../transforms/rename_entrypoint_to_main.h | 6 +- .../transforms/smuggle_disallowed_ops.cc | 14 +- .../transforms/smuggle_disallowed_ops.h | 6 +- ...mhlo_tfl_pass.cc => stablehlo_tfl_pass.cc} | 87 +- .../{mhlo_tfl_pass.h => stablehlo_tfl_pass.h} | 16 +- .../stablehlo/transforms/stablehlo_util.cc | 45 + .../stablehlo/transforms/stablehlo_util.h | 41 + 
.../lite/stablehlo/transforms/tf_mhlo_pass.cc | 116 - .../lite/stablehlo/transforms/tf_mhlo_pass.h | 38 - .../stablehlo/transforms/tf_mhlo_tfl_pass.cc | 199 - .../lite/stablehlo/transforms/tf_poly_pass.cc | 204 - .../lite/stablehlo/transforms/tf_poly_pass.h | 42 - .../stablehlo/transforms/tf_stablehlo_pass.cc | 165 + ...tf_mhlo_tfl_pass.h => tf_stablehlo_pass.h} | 26 +- ...tfl_mhlo_pass.cc => tfl_stablehlo_pass.cc} | 64 +- .../{tfl_mhlo_pass.h => tfl_stablehlo_pass.h} | 16 +- .../lite/stablehlo/transforms/transforms.cc | 32 +- .../transforms/unfuse_batch_norm_pass.cc | 198 + tensorflow/compiler/mlir/lite/tests/BUILD | 13 +- .../mlir/lite/tests/canonicalize.mlir | 2 +- .../compiler/mlir/lite/tests/debuginfo/BUILD | 2 + .../tests/decompose-hybrid-quantization.mlir | 4 +- .../compiler/mlir/lite/tests/end2end/BUILD | 5 + .../mlir/lite/tests/flatbuffer2mlir/BUILD | 2 + .../flatbuffer2mlir/importer_test_min_max.cc | 8 +- .../mlir/lite/tests/flatbuffer2mlir/lstm.mlir | 2 +- .../flatbuffer2mlir/multi_output_op.json | 5 +- .../mlir/lite/tests/get-arithmetic-count.mlir | 2 +- .../compiler/mlir/lite/tests/legalize-tf.mlir | 18 +- .../mlir/lite/tests/lift_tflite_flex_ops.mlir | 24 + .../compiler/mlir/lite/tests/mlir2exec/BUILD | 2 + .../mlir/lite/tests/mlir2flatbuffer/BUILD | 2 + .../transpose_conv_optional.mlir | 2 +- tensorflow/compiler/mlir/lite/tests/ops.mlir | 15 +- .../compiler/mlir/lite/tests/optimize.mlir | 10 +- .../mlir/lite/tests/post-quantize.mlir | 4 +- .../tests/prepare-composite-functions-tf.mlir | 236 +- .../tests/prepare-quantize-dynamic-range.mlir | 2 +- ...prepare-quantize-post-training-16bits.mlir | 213 + .../lite/tests/prepare-quantize-signed.mlir | 4 +- .../lite/tests/quantize-dynamic-range.mlir | 2 +- .../mlir/lite/tests/quantize-variables.mlir | 190 + .../compiler/mlir/lite/tests/quantize.mlir | 59 + .../mlir/lite/tests/shape-inference.mlir | 11 + .../mlir/lite/tests/tfl_while_outline.mlir | 30 - .../compiler/mlir/lite/tf_tfl_passes.cc | 32 +- .../compiler/mlir/lite/tf_tfl_translate.cc | 8 +- .../mlir/lite/tf_to_tfl_flatbuffer.cc | 32 +- .../compiler/mlir/lite/tf_to_tfl_flatbuffer.h | 8 +- .../mlir/lite/transforms/dilated_conv.h | 51 +- .../lite/transforms/insert_call_once_op.cc | 17 +- .../lite/transforms/legalize_hashtables.cc | 26 +- .../lite/transforms/legalize_jax_random.cc | 2 +- .../mlir/lite/transforms/legalize_patterns.td | 3 +- .../mlir/lite/transforms/legalize_tf.cc | 220 +- .../mlir/lite/transforms/legalize_tf_while.cc | 2 +- .../lite/transforms/lift_tflite_flex_ops.cc | 24 +- .../transforms/lower_static_tensor_list.cc | 32 +- .../compiler/mlir/lite/transforms/optimize.cc | 35 +- .../transforms/optimize_functional_ops.cc | 12 +- .../mlir/lite/transforms/optimize_patterns.td | 32 +- .../compiler/mlir/lite/transforms/passes.h | 12 +- .../compiler/mlir/lite/transforms/passes.td | 8 + .../prepare_composite_functions_tf.cc | 15 + .../mlir/lite/transforms/prepare_quantize.cc | 31 +- .../prepare_quantize_dynamic_range.cc | 2 +- .../transforms/prepare_quantize_helper.cc | 43 + .../lite/transforms/prepare_quantize_helper.h | 38 +- .../mlir/lite/transforms/prepare_tf.cc | 126 +- .../compiler/mlir/lite/transforms/quantize.cc | 50 +- .../lite/transforms/quantize_variables.cc | 208 + .../lite/transforms/reduce_while_operands.cc | 5 +- .../lite/transforms/while_loop_outline.cc | 8 +- .../mlir/lite/utils/constant_utils.cc | 2 +- .../compiler/mlir/lite/utils/constant_utils.h | 4 +- .../mlir/lite/utils/fake_quant_utils.h | 14 +- .../compiler/mlir/lite/utils/lstm_utils.cc | 53 +- 
.../compiler/mlir/lite/utils/lstm_utils.h | 3 + .../mlir/lite/utils/lstm_utils_test.cc | 4 +- .../mlir/lite/utils/perception_ops_utils.cc | 6 + .../lite/utils/size_utils.cc} | 18 +- .../compiler/mlir/lite/utils/size_utils.h | 32 + .../mlir/lite/utils/size_utils_test.cc | 33 + .../mlir/mlir_graph_optimization_pass.cc | 64 +- .../mlir/mlir_graph_optimization_pass_test.cc | 29 + tensorflow/compiler/mlir/python/BUILD | 5 +- tensorflow/compiler/mlir/python/mlir.cc | 41 +- tensorflow/compiler/mlir/python/mlir.h | 4 + .../compiler/mlir/python/mlir_wrapper/BUILD | 5 +- .../mlir/quantization/stablehlo/BUILD | 28 + .../internal_visibility_allowlist.bzl | 10 + .../stablehlo/quantization_options.proto | 107 + .../mlir/quantization/tensorflow/BUILD | 91 +- .../quantization/tensorflow/calibrator/BUILD | 38 +- .../custom_aggregator_op_test.py | 12 +- .../mlir/quantization/tensorflow/cc/BUILD | 93 + .../tensorflow/cc/const_op_size.cc | 79 + .../tensorflow/cc/const_op_size.h | 32 + .../tensorflow/cc/const_op_size_test.cc | 151 + .../tensorflow/cc/save_variables.cc | 135 + .../tensorflow/cc/save_variables.h | 40 + .../tensorflow/cc/save_variables_test.cc | 385 ++ .../quantization/tensorflow/debugging/BUILD | 46 + .../tensorflow/debugging/mlir_dump.cc | 145 + .../tensorflow/debugging/mlir_dump.h | 43 + .../tensorflow/debugging/mlir_dump_test.cc | 80 + .../tensorflow/exported_model.proto | 40 + .../gen_quantized_function_library.py | 85 +- .../tensorflow/ops/tf_op_quant_spec.cc | 23 +- .../tensorflow/ops/tf_op_quant_spec.h | 1 - .../tensorflow/ops/uniform_op_quant_spec.cc | 39 + .../tensorflow/ops/uniform_op_quant_spec.h | 35 + .../passes/convert_tf_quant_ops_to_mhlo.cc | 2 +- .../duplicate_shape_determining_constants.cc | 369 + .../passes/insert_custom_aggregation_ops.cc | 12 +- .../tensorflow/passes/insert_main_function.cc | 326 +- .../passes/insert_quantized_functions.cc | 55 +- .../tensorflow/passes/insert_restore_op.cc | 213 + .../lift_quantizable_spots_as_functions.cc | 94 +- .../lift_quantizable_spots_as_functions.td | 54 + ...lift_quantizable_spots_as_functions_drq.cc | 11 +- ...lift_quantizable_spots_as_functions_drq.td | 36 + .../passes/mark_functions_noinline.cc | 124 + .../merge_initializer_function_ops_to_main.cc | 292 +- .../tensorflow/passes/optimize.td | 3 +- .../quantization/tensorflow/passes/passes.h | 61 +- .../tensorflow/passes/post_quantize.cc | 26 +- .../tensorflow/passes/post_quantize.td | 4 +- .../tensorflow/passes/prepare_lifting.cc | 214 +- .../tensorflow/passes/prepare_lifting.td | 201 +- .../tensorflow/passes/prepare_quantize.cc | 40 +- .../tensorflow/passes/prepare_quantize_drq.cc | 114 +- .../tensorflow/passes/preprocess_op.cc | 209 + .../tensorflow/passes/preprocess_op.td | 28 + .../tensorflow/passes/quantize.cc | 249 +- .../passes/quantize_composite_functions.cc | 559 +- .../passes/quantized_function_library.mlir | 186 +- .../quantized_function_library_tf_drq.mlir | 140 +- ...ed_function_library_uniform_quantized.mlir | 242 +- ...unction_library_uniform_quantized_drq.mlir | 78 +- .../passes/remove_var_init_by_const.cc | 121 + .../replace_cast_hacks_with_tf_xla_ops.cc | 378 +- .../replace_cast_hacks_with_tf_xla_ops.td | 278 +- .../tensorflow/passes/tf_quant_ops.td | 75 - .../tensorflow/passes/unfreeze_constants.cc | 141 +- .../quantization/tensorflow/passes/utils.cc | 17 +- .../quantization/tensorflow/passes/utils.h | 24 +- .../quantization/tensorflow/passes/utils.td | 33 +- .../mlir/quantization/tensorflow/python/BUILD | 159 +- .../integration_test/concurrency_test.py | 42 +- 
.../integration_test/quantize_model_test.py | 4030 ++++++++--- .../quantize_model_test_base.py | 801 ++- .../python/pywrap_quantize_model.cc | 223 +- .../python/pywrap_quantize_model_test.py | 51 + .../tensorflow/python/quantize_model.cc | 685 +- .../tensorflow/python/quantize_model.h | 55 +- .../tensorflow/python/quantize_model.py | 684 +- .../python/quantize_model_wrapper.cc | 158 - .../python/quantize_model_wrapper.h | 59 - .../python/representative_dataset.py | 50 +- .../python/representative_dataset_test.py | 114 +- .../tensorflow/python/save_model.py | 413 ++ .../tensorflow/quantization_options.proto | 38 +- .../tensorflow/quantize_passes.cc | 95 +- .../tensorflow/quantize_preprocess.cc | 84 +- .../tensorflow/quantize_preprocess.h | 39 +- .../mlir/quantization/tensorflow/tests/BUILD | 3 + ...duplicate_shape_determining_constants.mlir | 223 + .../tensorflow/tests/fake_quant_e2e_xla.mlir | 21 +- .../tests/insert_main_function.mlir | 132 +- .../tests/insert_quantized_functions.mlir | 35 +- .../tests/insert_quantized_functions_drq.mlir | 12 +- .../tensorflow/tests/insert_restore_op.mlir | 170 + ...ft_quantizable_spots_as_functions_drq.mlir | 115 + ...e_spots_as_functions_drq_min_elements.mlir | 22 +- ...ft_quantizable_spots_as_functions_xla.mlir | 30 + .../tests/mark_functions_noinline.mlir | 24 + ...erge_initializer_function_ops_to_main.mlir | 226 +- .../tensorflow/tests/prepare_lifting.mlir | 95 +- .../tensorflow/tests/prepare_quantize.mlir | 6 +- .../tests/prepare_quantize_drq.mlir | 86 +- .../prepare_quantize_drq_per_channel.mlir | 90 + .../tests/prepare_quantize_ptq.mlir | 4 +- .../prepare_quantize_ptq_per_channel.mlir | 47 + .../tensorflow/tests/preprocess_op.mlir | 39 + .../tensorflow/tests/quantize.mlir | 6 +- .../tests/quantize_composite_functions.mlir | 62 +- .../quantize_composite_functions_drq.mlir | 132 +- ...ntize_composite_functions_weight_only.mlir | 101 + .../quantize_composite_functions_xla.mlir | 35 +- .../tensorflow/tests/quantize_drq.mlir | 14 +- .../tensorflow/tests/quantize_xla.mlir | 136 + .../tests/remove_var_init_by_const.mlir | 150 + .../replace_cast_hacks_with_tf_xla_ops.mlir | 496 +- ...hacks_with_tf_xla_ops_large_constants.mlir | 2 +- .../tensorflow/tests/unfreeze_constants.mlir | 122 +- .../mlir/quantization/tensorflow/utils/BUILD | 18 + .../tensorflow/utils/fake_quant_utils.h | 23 +- .../utils/lift_as_function_call_utils.cc | 5 +- .../utils/lift_as_function_call_utils.h | 1 - .../utils/tf_to_uniform_attribute_utils.cc | 241 + .../utils/tf_to_uniform_attribute_utils.h | 44 + .../utils/tf_to_xla_attribute_utils.h | 4 +- tensorflow/compiler/mlir/runlit.cfg.py | 49 +- tensorflow/compiler/mlir/runlit.site.cfg.py | 28 +- tensorflow/compiler/mlir/tensorflow/BUILD | 107 +- .../analysis/resource_alias_analysis.cc | 28 +- .../analysis/resource_value_typed_analyzer.cc | 18 +- .../analysis/side_effect_analysis.cc | 44 +- .../analysis/side_effect_analysis.h | 15 +- tensorflow/compiler/mlir/tensorflow/c/BUILD | 1 + .../c/c_api_unified_experimental_mlir.cc | 7 +- .../mlir/tensorflow/ir/tf_arith_ops_folder.h | 8 +- .../compiler/mlir/tensorflow/ir/tf_device.cc | 10 +- .../mlir/tensorflow/ir/tf_device_ops.td | 6 +- .../mlir/tensorflow/ir/tf_executor.cc | 6 +- .../mlir/tensorflow/ir/tf_executor_ops.td | 2 +- .../mlir/tensorflow/ir/tf_generated_ops.td | 1113 ++- .../compiler/mlir/tensorflow/ir/tf_op_base.td | 22 +- .../mlir/tensorflow/ir/tf_op_interfaces.h | 2 +- .../mlir/tensorflow/ir/tf_op_interfaces.td | 16 +- .../compiler/mlir/tensorflow/ir/tf_ops.cc | 8 +- 
.../compiler/mlir/tensorflow/ir/tf_ops.td | 139 +- .../compiler/mlir/tensorflow/ir/tf_ops_a_m.cc | 899 +-- .../mlir/tensorflow/ir/tf_ops_layout_helper.h | 9 +- .../compiler/mlir/tensorflow/ir/tf_ops_n_z.cc | 752 +- .../tensorflow/ir/tf_ops_tensor_helper.cc | 16 +- .../mlir/tensorflow/ir/tf_remaining_ops.cc | 45 +- .../mlir/tensorflow/ir/tf_saved_model.cc | 91 +- .../mlir/tensorflow/ir/tf_saved_model.h | 25 +- .../mlir/tensorflow/ir/tf_saved_model_ops.td | 3 +- .../mlir/tensorflow/ir/tf_saved_model_test.cc | 189 + .../mlir/tensorflow/ir/tf_side_effects.h | 5 + .../compiler/mlir/tensorflow/ir/tf_structs.cc | 4 +- .../compiler/mlir/tensorflow/ir/tf_traits.h | 6 +- .../compiler/mlir/tensorflow/ir/tf_types.def | 2 + .../compiler/mlir/tensorflow/ir/tfrt_ops.cc | 8 +- .../ir/tpu_embedding_ops_registry.cc | 34 + .../ir/tpu_embedding_ops_registry.h | 59 + .../compiler/mlir/tensorflow/tests/BUILD | 10 +- .../mlir/tensorflow/tests/canonicalize.mlir | 2 +- .../tensorflow/tests/cluster_formation.mlir | 80 +- .../tensorflow/tests/compile_mlir_util/BUILD | 2 + .../replicate-tensor-list-init-ops.mlir | 26 + .../compile_mlir_util/stablehlo_add.mlir | 20 + .../mlir/tensorflow/tests/constant-fold.mlir | 52 +- ...nvert_session_initializer_to_function.mlir | 51 + .../tensorflow/tests/device_canonicalize.mlir | 2 +- .../mlir/tensorflow/tests/einsum.mlir | 10 + .../tests/executor_canonicalize.mlir | 2 +- .../executor_island_materialize_const.mlir | 2 +- .../executor_tpuv1_island_coarsening/BUILD | 2 + .../executor_tpuv1_island_inlining/BUILD | 2 + .../tests/executor_tpuv1_outline_island/BUILD | 2 + .../tensorflow/tests/freeze_variables.mlir | 44 + .../mlir/tensorflow/tests/graphdef2mlir/BUILD | 2 + .../batch_use_same_function/BUILD | 2 + .../tests/guarantee-all-funcs-one-use.mlir | 18 + .../tests/launch_to_device_attribute.mlir | 55 +- .../launch_to_device_attribute_legacy.mlir | 121 + .../mlir/tensorflow/tests/legalize_hlo.mlir | 140 + .../mlir/tensorflow/tests/lower_tf.mlir | 21 +- .../mark_ops_for_outside_compilation.mlir | 9 + .../mlir/tensorflow/tests/mlir2graphdef/BUILD | 2 + .../tensorflow/tests/order_by_dialect.mlir | 68 +- .../tests/parallel_execute_to_islands.mlir | 84 +- .../parallel_execute_to_islands_legacy.mlir | 182 + .../tests/remove_unused_arguments.mlir | 15 + .../tests/replicate_tensor_list_init_ops.mlir | 139 + .../tensorflow/tests/replicate_to_island.mlir | 150 +- .../tests/replicate_to_island_legacy.mlir | 277 + .../tensorflow/tests/shape_inference.mlir | 73 + .../tests/side-effect-analysis-test.mlir | 247 +- .../tests/split_into_island_per_op.mlir | 20 + .../mlir/tensorflow/tests/tf-ops.mlir | 353 +- .../tensorflow/tests/tf_saved_model/BUILD | 1 + .../tensorflow/tests/tf_saved_model/common.py | 6 +- .../control_flow_upgrade_legacy_v1.py | 2 + .../tests/tf_saved_model/debug_info.py | 6 +- .../tensorflow/tests/tf_to_hlo_pipeline/BUILD | 2 + .../mlir/tensorflow/tests/tpu_bridge_v1/BUILD | 2 + .../tests/tpu_cluster_formation.mlir | 10 +- .../tpu_extract_outside_compilation.mlir | 71 + .../tests/tpu_partitioned_op_conversion.mlir | 106 + ...rder_replicate_and_partitioned_inputs.mlir | 109 +- .../tests/tpu_resource_partitioning.mlir | 113 +- .../mlir/tensorflow/tests/tpu_rewrite.mlir | 56 +- .../tests/tpu_sharding_identification.mlir | 28 +- .../tensorflow/tests/unroll-batch-matmul.mlir | 11 + .../tests/update_control_dependencies.mlir | 163 +- .../tests/xla_cluster_formation.mlir | 54 +- .../mlir/tensorflow/tests/xla_rewrite.mlir | 60 +- .../transforms/batchmatmul_to_einsum.cc | 8 +- 
.../mlir/tensorflow/transforms/bridge.cc | 83 +- .../mlir/tensorflow/transforms/bridge.h | 3 + .../tensorflow/transforms/call_graph_util.h | 64 + .../transforms/cluster_formation.cc | 134 +- .../transforms/cluster_ops_by_policy.cc | 10 +- .../transforms/cluster_tf_ops_pass.cc | 10 +- .../transforms/collection_ops_util.cc | 21 +- .../tensorflow/transforms/constant_fold.cc | 6 +- .../convert_control_to_data_outputs.cc | 9 +- .../convert_launch_func_to_tf_call.cc | 2 +- ...convert_session_initializer_to_function.cc | 100 + .../convert_tf_control_flow_to_scf.cc | 16 +- .../transforms/decode_attributes_hook.cc | 2 +- .../transforms/decompose_reduce_dataset.cc | 16 +- .../transforms/decompose_resource_ops.cc | 12 +- .../transforms/decompose_resource_ops_pass.cc | 9 +- .../transforms/device_index_selector.cc | 2 +- .../mlir/tensorflow/transforms/einsum.cc | 49 +- .../executor_tpuv1_inline_tpu_island.cc | 2 +- .../executor_tpuv1_island_coarsening.cc | 28 +- .../tensorflow/transforms/fold_broadcast.cc | 14 +- .../transforms/freeze_global_tensors.cc | 2 + .../transforms/freeze_saved_model_assets.cc | 7 +- .../functional_control_flow_to_cfg.cc | 8 +- .../functional_control_flow_to_regions.cc | 18 +- .../transforms/fused_kernel_matcher.cc | 20 +- .../mlir/tensorflow/transforms/gpu_fusion.cc | 16 +- .../transforms/hoist_loop_invariant.cc | 4 +- ...ist_replicate_invariant_resource_writes.cc | 6 +- .../transforms/init_text_file_to_import.cc | 10 +- .../initialize_variables_in_session_init.cc | 47 +- .../transforms/launch_to_device_attribute.cc | 37 +- .../transforms/layout_optimization.cc | 21 +- .../tensorflow/transforms/legalize_hlo.cc | 294 +- .../transforms/legalize_hlo_patterns.td | 384 +- .../transforms/localize_var_handles.cc | 18 +- .../transforms/lower_globals_to_ml_program.cc | 16 +- .../mlir/tensorflow/transforms/lower_tf.cc | 106 +- .../mlir/tensorflow/transforms/lower_tf.td | 17 +- .../lower_variable_ops_to_ml_program.cc | 20 +- .../transforms/mark_input_output_aliases.cc | 4 +- .../mark_ops_for_outside_compilation.cc | 39 +- .../materialize_mlir_passthrough_op.cc | 2 +- .../transforms/merge_control_flow.cc | 50 +- .../mlir/tensorflow/transforms/mlprogram.cc | 13 +- .../transforms/name_anonymous_iterators.cc | 2 +- .../mlir/tensorflow/transforms/optimize.cc | 17 +- .../tensorflow/transforms/order_by_dialect.cc | 127 +- .../transforms/parallel_execute_to_islands.cc | 113 +- .../mlir/tensorflow/transforms/passes.h | 43 +- .../prepare_tpu_computation_for_tf_export.cc | 12 +- .../transforms/promote_resources_to_args.cc | 18 +- .../readonly_references_to_resources.cc | 3 +- .../region_control_flow_to_functional.cc | 71 +- .../transforms/remove_unused_arguments.cc | 13 +- .../transforms/remove_unused_while_results.cc | 6 +- .../replicate_invariant_op_hoisting.cc | 10 +- .../replicate_tensor_list_init_ops_pass.cc | 79 + .../transforms/replicate_to_island.cc | 125 +- .../transforms/resource_device_inference.cc | 12 +- .../transforms/resource_op_lifting.cc | 40 +- .../transforms/resource_op_lifting_cleanup.cc | 31 +- .../transforms/rewrite_tpu_embedding_ops.cc | 6 +- .../tensorflow/transforms/rewrite_util.cc | 7 +- .../transforms/set_tpu_infeed_layout.cc | 2 +- .../tensorflow/transforms/shape_inference.cc | 415 +- .../transforms/shape_inference_pass.cc | 2 +- .../tensorflow/transforms/sink_constant.cc | 2 +- .../transforms/stack_ops_decomposition.cc | 37 +- .../transforms/strip_saved_module_metadata.cc | 4 +- .../transforms/strip_tf_attributes.cc | 4 +- .../tensor_array_ops_decomposition.cc | 131 +- 
.../tensor_list_ops_decomposition.cc | 135 +- .../transforms/tf_data_optimization.cc | 18 +- .../tensorflow/transforms/tf_device_passes.td | 76 +- .../transforms/tf_graph_optimization_pass.cc | 2 +- .../mlir/tensorflow/transforms/tf_passes.td | 53 +- .../tf_saved_model_freeze_variables.cc | 141 +- .../transforms/tf_saved_model_passes.h | 4 + .../transforms/tf_savedmodel_passes.td | 16 + .../mlir/tensorflow/transforms/tfg-to-tfe.cc | 2 +- .../transforms/tpu_cluster_formation.cc | 19 +- .../tpu_colocate_composite_resource_ops.cc | 6 +- .../transforms/tpu_dynamic_layout_pass.cc | 16 +- .../tpu_extract_outside_compilation.cc | 117 +- .../tpu_merge_variables_with_execute.cc | 37 +- ...pu_parallel_execute_sink_resource_write.cc | 9 +- .../tpu_partitioned_op_conversion.cc | 147 + ...eorder_replicate_and_partitioned_inputs.cc | 104 +- .../transforms/tpu_resource_partitioning.cc | 118 +- .../transforms/tpu_resource_read_for_write.cc | 12 +- .../tensorflow/transforms/tpu_rewrite_pass.cc | 57 +- .../tpu_sharding_identification_pass.cc | 59 +- .../transforms/tpu_space_to_depth_pass.cc | 62 +- .../tpu_variable_runtime_reformatting.cc | 22 +- .../transforms/unroll_batch_matmul.cc | 41 +- .../transforms/update_control_dependencies.cc | 348 +- .../transforms/xla_cluster_formation.cc | 62 +- .../mlir/tensorflow/transforms/xla_rewrite.cc | 70 +- .../tensorflow/translate/breakup-islands.cc | 2 +- .../tensorflow/translate/export_graphdef.cc | 9 +- .../tensorflow/translate/export_graphdef.h | 5 +- .../translate/export_tf_dialect_op.cc | 4 +- .../translate/export_tf_dialect_op.h | 1 - .../mlir/tensorflow/translate/import_model.cc | 74 +- .../mlir/tensorflow/translate/import_model.h | 55 +- .../translate/mlir_roundtrip_flags.cc | 5 +- .../translate/mlir_roundtrip_pass.cc | 7 +- .../split_into_island_per_op_pass.cc | 35 +- .../split_into_island_per_op_pass.h} | 18 +- .../tensorflow/translate/tf_mlir_translate.cc | 3 + .../tensorflow/translate/tf_mlir_translate.h | 5 +- .../tf_mlir_translate_registration.cc | 5 +- .../mlir/tensorflow/translate/upgrade_graph.h | 1 - .../mlir/tensorflow/utils/attribute_utils.cc | 36 + .../mlir/tensorflow/utils/attribute_utils.h | 66 + .../tensorflow/utils/compile_mlir_util.cc | 162 +- .../mlir/tensorflow/utils/compile_mlir_util.h | 36 +- .../mlir/tensorflow/utils/convert_attr.cc | 3 +- .../mlir/tensorflow/utils/convert_attr.h | 4 +- .../mlir/tensorflow/utils/convert_tensor.cc | 63 +- .../mlir/tensorflow/utils/convert_tensor.h | 3 +- .../tensorflow/utils/convert_tensor_test.cc | 8 +- .../mlir/tensorflow/utils/convert_type.cc | 17 +- .../mlir/tensorflow/utils/convert_type.h | 3 +- .../tensorflow/utils/convert_type_test.cc | 1 - .../mlir/tensorflow/utils/dump_graph.h | 5 +- .../mlir/tensorflow/utils/dump_mlir_util.cc | 5 +- .../tensorflow/utils/dump_mlir_util_test.cc | 59 +- .../tensorflow/utils/dynamic_shape_utils.cc | 2 +- .../mlir/tensorflow/utils/eval_util.cc | 1 - .../mlir/tensorflow/utils/export_utils.h | 3 +- .../tensorflow/utils/parallel_execute_util.cc | 2 +- .../mlir/tensorflow/utils/session_utils.cc | 6 +- .../tensorflow/utils/shape_inference_utils.cc | 4 +- .../tensorflow/utils/shape_inference_utils.h | 2 +- .../tensorflow/utils/tf_xla_mlir_translate.cc | 14 +- .../mlir/tensorflow/utils/topological_sort.cc | 157 + .../mlir/tensorflow/utils/topological_sort.h | 70 + .../utils/tpu_rewrite_device_util.cc | 1 - .../utils/tpu_rewrite_device_util.h | 4 +- .../mlir/tensorflow/utils/translate_utils.cc | 3 +- .../mlir/tensorflow/utils/translate_utils.h | 5 +- 
.../tensorflow/utils/verification_utils.cc | 2 +- .../tensorflow/utils/xla_sharding_util.cc | 40 +- tensorflow/compiler/mlir/tf2xla/BUILD | 2 + .../mlir/tf2xla/mlir_bridge_rollout_policy.h | 3 - tensorflow/compiler/mlir/tf_mlir_opt_main.cc | 32 +- .../compiler/mlir/tf_mlir_reduce_main.cc | 11 +- .../compiler/mlir/tf_mlir_translate_main.cc | 8 +- tensorflow/compiler/mlir/tfr/BUILD | 7 +- .../mlir/tfr/examples/customization/BUILD | 1 + .../examples/customization/test_ops_test.py | 2 +- .../compiler/mlir/tfr/examples/mnist/BUILD | 1 + .../compiler/mlir/tfr/examples/pad/BUILD | 1 + .../tfr/integration/graph_decompose_pass.cc | 2 +- .../tfr/integration/graph_decompose_pass.h | 2 +- .../tfr/integration/node_expansion_pass.cc | 2 +- .../tfr/integration/node_expansion_pass.h | 2 +- .../mlir/tfr/integration/tfr_decompose_ctx.cc | 2 +- .../mlir/tfr/integration/tfr_decompose_ctx.h | 4 +- .../tfr/integration/tfr_decompose_ctx_test.cc | 2 +- tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc | 36 +- tensorflow/compiler/mlir/tfr/ir/tfr_ops.td | 11 +- .../compiler/mlir/tfr/passes/canonicalize.cc | 4 +- tensorflow/compiler/mlir/tfr/resources/BUILD | 1 + .../compiler/mlir/tfr/tests/canonicalize.mlir | 2 +- tensorflow/compiler/mlir/tfr/tests/ops.mlir | 2 +- tensorflow/compiler/mlir/tfrt/BUILD | 99 +- .../mlir/tfrt/analysis/cost_analysis.cc | 28 +- .../mlir/tfrt/analysis/cost_analysis.h | 5 +- .../tensor_array_side_effect_analysis.cc | 4 +- .../tfrt/analysis/test_cost_analysis_pass.cc | 3 +- .../compiler/mlir/tfrt/benchmarks/BUILD | 83 +- .../mlir/tfrt/benchmarks/benchmark.cc | 11 + .../compiler/mlir/tfrt/benchmarks/benchmark.h | 5 +- .../benchmarks/benchmark_mlir_function.cc | 4 + .../mlir/tfrt/benchmarks/concat_benchmark.cc | 435 ++ .../benchmarks/fused_map_bcast_benchmark.cc | 128 + .../mlir/tfrt/benchmarks/map_op_benchmark.cc | 115 + .../tfrt/benchmarks/matmul_op_benchmark.cc | 111 +- .../tfrt/benchmarks/matmul_op_benchmark.h | 68 +- .../tfrt/benchmarks/reverse_op_benchmark.cc | 207 + .../tfrt/benchmarks/scatter_op_benchmark.cc | 130 + .../tfrt/benchmarks/softmax_op_benchmark.cc | 9 +- .../compiler/mlir/tfrt/function/function.cc | 6 +- .../compiler/mlir/tfrt/function/function.h | 3 - tensorflow/compiler/mlir/tfrt/ir/BUILD | 61 + tensorflow/compiler/mlir/tfrt/ir/gpu_ops.cc | 41 + tensorflow/compiler/mlir/tfrt/ir/gpu_ops.h | 40 + tensorflow/compiler/mlir/tfrt/ir/gpu_ops.td | 91 + .../compiler/mlir/tfrt/ir/tfrt_fallback.td | 3 +- .../mlir/tfrt/ir/tfrt_fallback_async.cc | 6 +- .../mlir/tfrt/ir/tfrt_fallback_async.td | 3 +- .../mlir/tfrt/ir/tfrt_fallback_sync.cc | 30 - .../mlir/tfrt/ir/tfrt_fallback_sync.td | 40 +- .../compiler/mlir/tfrt/jit/default/BUILD | 1 + .../mlir/tfrt/jit/opdefs/tf_jitrt_ops.cc | 4 +- .../mlir/tfrt/jit/opdefs/tf_jitrt_ops.td | 23 +- .../mlir/tfrt/jit/python_binding/BUILD | 1 + .../jit/python_binding/tf_jitrt_executor.cc | 21 +- .../jit/python_binding/tf_jitrt_executor.h | 4 +- .../mlir/tfrt/jit/tf_jitrt_kernels.cc | 5 +- .../mlir/tfrt/jit/tf_jitrt_pipeline.cc | 109 +- .../mlir/tfrt/jit/tf_jitrt_pipeline.h | 11 + .../compiler/mlir/tfrt/jit/transforms/BUILD | 10 +- .../jit/transforms/tf_jitrt_clustering.cc | 70 +- .../transforms/tf_jitrt_clustering_pass.cc | 4 +- .../transforms/tf_jitrt_detensorize_linalg.cc | 130 - .../tfrt/jit/transforms/tf_jitrt_fission.cc | 8 +- ...tf_jitrt_fuse_fill_into_tiled_reduction.cc | 340 - .../tfrt/jit/transforms/tf_jitrt_fusion.cc | 3 +- .../tf_jitrt_lower_vector_transpose.cc | 67 - .../transforms/tf_jitrt_math_approximation.cc | 2 +- 
.../tfrt/jit/transforms/tf_jitrt_passes.cc | 26 - .../tfrt/jit/transforms/tf_jitrt_passes.h | 65 +- .../tfrt/jit/transforms/tf_jitrt_passes.td | 148 +- .../transforms/tf_jitrt_peel_tiled_loops.cc | 102 - .../tf_jitrt_symbolic_shape_optimization.cc | 327 - .../jit/transforms/tf_jitrt_tile_cwise.cc | 174 - .../jit/transforms/tf_jitrt_tile_reduction.cc | 420 -- .../jit/transforms/tf_jitrt_tile_transpose.cc | 179 - .../compiler/mlir/tfrt/lhlo-tfrt-opt.cc | 4 +- .../compiler/mlir/tfrt/python_tests/BUILD | 26 + .../tfrt/python_tests/python_test_attrs.td | 3 +- .../regression_tests/broadcasting_25.mlir | 15 + .../mlir/tfrt/python_tests/tf_matmul_test.py | 12 +- .../tfrt/python_tests/tf_reduction_test.py | 4 +- .../mlir/tfrt/python_tests/tf_reverse_test.py | 114 + .../mlir/tfrt/python_tests/tf_scatter_test.py | 54 + .../runtime_fallback_executor.cc | 2 +- .../runtime_fallback_executor.h | 2 +- .../runtime_fallback/runtime_fallback_ops.td | 3 +- .../mlir/tfrt/saved_model/saved_model.cc | 7 +- tensorflow/compiler/mlir/tfrt/tests/BUILD | 5 +- .../compiler/mlir/tfrt/tests/analysis/BUILD | 28 +- .../tfrt/tests/analysis/testdata/test.mlir | 7 + .../update_op_cost_in_tfrt_mlir_test.cc | 84 + ..._resource_variable_as_captured_tensor.mlir | 2 +- .../tfrt/tests/batch_function_lowering.mlir | 2 +- .../fuse_tpu_compile_and_execute_ops.mlir | 18 + .../mlir/tfrt/tests/hoist_invariant_ops.mlir | 50 +- tensorflow/compiler/mlir/tfrt/tests/ir/BUILD | 5 +- tensorflow/compiler/mlir/tfrt/tests/jit/BUILD | 6 +- .../tfrt/tests/jit/detensorize_linalg.mlir | 26 - .../jit/symbolic_shape_optimization.mlir | 203 - .../tfrt/tests/jit/tf_jitrt_benchmark_test.cc | 4 + .../tests/jit/tf_jitrt_codegen_transpose.mlir | 185 +- .../tf_jitrt_codegen_transpose_detection.mlir | 72 - ..._jitrt_fuse_fill_into_tiled_reduction.mlir | 230 - .../tests/jit/tf_jitrt_peel_tiled_loops.mlir | 373 - .../tfrt/tests/jit/tf_jitrt_pipeline.mlir | 12 +- .../tests/jit/tf_jitrt_pipeline_one_shot.mlir | 440 -- .../jit/tf_jitrt_pipeline_vectorized.mlir | 48 +- .../tfrt/tests/jit/tf_jitrt_tile_cwise.mlir | 61 - .../tfrt/tests/jit/tf_jitrt_tile_fill.mlir | 19 - .../tfrt/tests/jit/tf_jitrt_tile_matmul.mlir | 94 - .../tests/jit/tf_jitrt_tile_reduction.mlir | 238 - .../mlir/tfrt/tests/lhlo_to_jitrt/BUILD | 5 +- .../tfrt/tests/remote_run_encapsulate.mlir | 47 - .../mlir/tfrt/tests/saved_model/BUILD | 7 +- .../tests/saved_model/saved_model_test.cc | 62 + .../saved_model/testdata/xla_launch.mlir | 25 + .../xla_launch_xla_reduce_window.mlir | 22 + .../tfrt/tests/sink_in_invariant_ops.mlir | 248 + .../mlir/tfrt/tests/tf_to_corert/BUILD | 8 +- .../tfrt/tests/tf_to_corert/attributes.mlir | 6 +- .../tfrt/tests/tf_to_corert/auto-fusion.mlir | 2 +- .../mlir/tfrt/tests/tf_to_corert/basic.mlir | 2 +- .../tests/tf_to_corert/device_conversion.mlir | 16 +- .../tfrt/tests/tf_to_corert/fallback.mlir | 4 +- .../fallback_canonicalization.mlir | 2 +- .../tests/tf_to_corert/func_attributes.mlir | 2 +- .../func_attributes_multiple_callers.mlir | 40 + .../tf_to_corert/tf_to_corert_pipeline.mlir | 24 +- .../tf_to_corert_pipeline_cpurt.mlir | 1 - .../tf_to_corert_pipeline_refvar.mlir | 2 +- .../mlir/tfrt/tests/tf_to_corert/whileop.mlir | 33 + .../tfrt/tests/tf_to_tfrt_data/batch.mlir | 24 - .../tfrt/tests/tf_to_tfrt_data/range.mlir | 18 - .../mlir/tfrt/tests/xla_launch_fallback.mlir | 82 + tensorflow/compiler/mlir/tfrt/tf-tfrt-opt.cc | 6 +- .../mlir/tfrt/transforms/corert_converter.cc | 95 +- .../mlir/tfrt/transforms/corert_converter.h | 33 +- 
.../transforms/deduplicate_batch_function.cc | 4 +- .../fuse_tpu_compile_and_execute_ops.cc | 59 +- .../mlir/tfrt/transforms/gpu_passes.cc | 32 + .../mlir/tfrt/transforms/gpu_passes.h | 37 + .../mlir/tfrt/transforms/lower_saved_model.cc | 48 +- .../mlir/tfrt/transforms/merge_tf_if_ops.cc | 21 +- .../compiler/mlir/tfrt/transforms/optimize.cc | 2 +- .../optimize_tf_control_flow_side_effect.cc | 10 +- .../compiler/mlir/tfrt/transforms/passes.h | 39 +- .../tfrt/transforms/remote_run_encapsulate.cc | 245 - .../transforms/remove_device_attribute.cc | 2 +- .../transforms/remove_tf_if_const_args.cc | 11 +- .../mlir/tfrt/transforms/reorder_assert.cc | 3 +- .../set_shape_invariant_in_while_ops.cc | 2 +- .../tfrt/transforms/sink_in_invariant_ops.cc | 186 + .../mlir/tfrt/transforms/tf_to_tfrt.cc | 428 +- .../mlir/tfrt/transforms/tf_to_tfrt_data.cc | 349 - .../mlir/tfrt/transforms/tf_to_tfrt_data.h | 32 - .../transforms/update_op_cost_in_tfrt_mlir.cc | 44 + .../transforms/update_op_cost_in_tfrt_mlir.h | 32 + .../compiler/mlir/tfrt/transforms/utils.cc | 59 + .../compiler/mlir/tfrt/transforms/utils.h | 31 + .../mlir/tfrt/translate/import_model.cc | 98 +- .../mlir/tfrt/translate/import_model.h | 9 +- .../tfrt/translate/tfrt_compile_options.cc | 20 +- .../tfrt/translate/tfrt_compile_options.h | 34 +- .../compiler/mlir/tools/kernel_gen/BUILD | 2 +- .../compiler/mlir/tools/kernel_gen/ir/BUILD | 1 + .../tools/kernel_gen/ir/tf_framework_ops.td | 2 +- .../mlir/tools/kernel_gen/kernel_creator.cc | 39 +- .../mlir/tools/kernel_gen/kernel_creator.h | 4 +- .../mlir/tools/kernel_gen/tests/BUILD | 5 +- .../tools/kernel_gen/tests/copy_cleanup.mlir | 12 +- .../jit_i64_indexed_for_large_tensors.mlir | 44 - .../tools/kernel_gen/tests/print_memrefs.mlir | 55 - .../tests/tf_to_jit_invocations.mlir | 120 +- .../tools/kernel_gen/tests/tf_to_kernel/BUILD | 5 +- .../kernel_gen/tf_framework_c_interface.cc | 4 +- .../mlir/tools/kernel_gen/tf_to_kernel.cc | 21 +- .../tools/kernel-gen-opt/kernel-gen-opt.cc | 8 +- .../mlir/tools/kernel_gen/transforms/BUILD | 2 +- .../transforms/buffer_reuse_pass.cc | 2 +- .../tools/kernel_gen/transforms/bufferize.cc | 2 +- .../kernel_gen/transforms/bufferize_pass.cc | 2 +- .../transforms/embed_memref_prints.cc | 196 - .../transforms/gpu_kernel_to_blob_pass.cc | 20 +- .../mlir/tools/kernel_gen/transforms/passes.h | 4 - .../tools/kernel_gen/transforms/passes.td | 6 - .../transforms/same_shape_propagation.cc | 7 +- .../tensorflow_abi_knowledge_propagation.cc | 2 +- .../tf_framework_legalize_to_llvm.cc | 31 +- .../transforms/tf_kernel_to_llvm_pass.cc | 2 +- .../transforms/tf_to_jit_invocations.cc | 147 +- tensorflow/compiler/mlir/tosa/BUILD | 2 + tensorflow/compiler/mlir/tosa/tests/BUILD | 5 +- .../mlir/tosa/tests/tf-to-tosa-pipeline.mlir | 217 +- .../tests/tfl-to-tosa-pipeline-filtered.mlir | 2 +- .../mlir/tosa/tests/tfl-to-tosa-pipeline.mlir | 516 +- .../mlir/tosa/transforms/convert_tfl_uint8.cc | 12 +- .../mlir/tosa/transforms/fuse_bias_tf.cc | 23 +- .../mlir/tosa/transforms/legalize_common.cc | 403 +- .../mlir/tosa/transforms/legalize_common.h | 6 +- .../mlir/tosa/transforms/legalize_tf.cc | 700 +- .../mlir/tosa/transforms/legalize_tfl.cc | 633 +- .../mlir/tosa/transforms/legalize_utils.cc | 64 +- .../mlir/tosa/transforms/legalize_utils.h | 13 +- tensorflow/compiler/mlir/xla/BUILD | 268 +- tensorflow/compiler/mlir/xla/tests/BUILD | 11 +- .../mlir/xla/tests/adjust-layout.mlir | 2 +- .../xla/tests/convert-mhlo-quant-to-int.mlir | 51 + .../xla/tests/hlo_xla_runtime_pipeline.mlir | 102 +- 
.../xla/tests/hlo_xla_sparsification.mlir | 33 + .../xla/tests/legalize-tf-control-flow.mlir | 223 - .../tests/legalize-tf-no-tf2xla-fallback.mlir | 98 +- .../xla/tests/legalize-tf-prefer-tf2xla.mlir | 8 +- .../xla/tests/legalize-tf-with-tf2xla.mlir | 73 +- .../compiler/mlir/xla/tests/legalize-tf.mlir | 650 +- .../verify-tfxla-legalization-no-chlo.mlir | 11 + .../xla/tests/verify-tfxla-legalization.mlir | 39 + .../mlir/xla/transforms/adjust_layout.cc | 14 +- .../transforms/convert_mhlo_quant_to_int.cc | 240 + .../mlir/xla/transforms/legalize_tf.cc | 1671 ++--- .../xla/transforms/legalize_tf_collective.cc | 37 +- .../transforms/legalize_tf_communication.cc | 14 +- .../transforms/legalize_tf_control_flow.cc | 452 -- .../xla/transforms/legalize_tf_patterns.td | 251 +- .../mlir/xla/transforms/legalize_tf_types.cc | 3 +- .../xla/transforms/legalize_tf_with_tf2xla.cc | 620 +- .../compiler/mlir/xla/transforms/passes.h | 27 +- .../mlir/xla/transforms/tf_xla_passes.td | 19 - .../compiler/mlir/xla/transforms/utils.cc | 2 +- .../compiler/mlir/xla/transforms/utils.h | 2 +- .../transforms/verify_tfxla_legalization.cc | 76 + .../xla/transforms/xla_legalize_targets.cc | 56 + .../xla/transforms/xla_legalize_targets.h | 34 + .../transforms/xla_legalize_targets_test.cc | 96 + .../mlir/xla/transforms/xla_legalize_tf.cc | 494 +- .../transforms/xla_legalize_tf_no_fallback.cc | 2 +- .../xla/transforms/xla_legalize_tf_passes.td | 35 +- tensorflow/compiler/mlir/xla/xla_opt_main.cc | 16 +- tensorflow/compiler/plugin/BUILD | 1 + tensorflow/compiler/tests/BUILD | 118 +- tensorflow/compiler/tests/approx_topk_test.py | 43 +- tensorflow/compiler/tests/async_comp_test.py | 2 + tensorflow/compiler/tests/binary_ops_test.py | 10 +- tensorflow/compiler/tests/bincount_op_test.py | 40 + tensorflow/compiler/tests/build_defs.bzl | 9 +- tensorflow/compiler/tests/const_test.py | 60 + tensorflow/compiler/tests/fft_test.py | 30 +- tensorflow/compiler/tests/fifo_queue_test.py | 6 +- ...onst_op_test.py => giant_const_op_test.py} | 6 +- tensorflow/compiler/tests/image_ops_test.py | 5 +- tensorflow/compiler/tests/pooling_ops_test.py | 30 + tensorflow/compiler/tests/random_ops_test.py | 6 + .../tests/reverse_sequence_op_args_test.py | 52 + .../tests/segment_reduction_ops_test.py | 40 + .../tests/stateless_random_ops_test.py | 225 +- .../compiler/tests/tensor_list_ops_test.py | 11 + tensorflow/compiler/tests/unary_ops_test.py | 86 +- tensorflow/compiler/tests/unique_ops_test.py | 46 + .../compiler/tests/xla_call_module_test.py | 376 +- tensorflow/compiler/tests/xla_ops_test.py | 4 +- tensorflow/compiler/tf2tensorrt/BUILD | 8 +- .../tf2tensorrt/convert/convert_graph.cc | 10 +- .../tf2tensorrt/convert/convert_nodes.cc | 30 +- .../tf2tensorrt/convert/convert_nodes.h | 3 +- .../tf2tensorrt/convert/convert_nodes_test.cc | 46 +- .../convert/op_converter_registry.cc | 38 +- .../tf2tensorrt/convert/ops/fill_ops.cc | 2 +- .../tf2tensorrt/convert/ops/slice_ops.cc | 86 +- .../compiler/tf2tensorrt/tensorrt_test.cc | 4 +- .../compiler/tf2tensorrt/utils/py_utils.cc | 5 +- .../tf2tensorrt/utils/trt_engine_utils.h | 3 +- tensorflow/compiler/tf2xla/BUILD | 20 +- tensorflow/compiler/tf2xla/cc/BUILD | 2 + tensorflow/compiler/tf2xla/kernels/BUILD | 1 + .../compiler/tf2xla/kernels/binary_ops.cc | 19 +- .../compiler/tf2xla/kernels/bincount_op.cc | 17 +- .../tf2xla/kernels/conv_op_helpers.cc | 6 +- .../compiler/tf2xla/kernels/conv_op_helpers.h | 7 +- .../compiler/tf2xla/kernels/conv_ops.cc | 52 +- .../tf2xla/kernels/dynamic_stitch_op.cc | 4 + 
.../kernels/extract_image_patches_op.cc | 8 +- .../compiler/tf2xla/kernels/if_while_utils.cc | 3 +- .../compiler/tf2xla/kernels/image_ops.cc | 28 +- .../kernels/light_outside_compilation.cc | 28 +- .../compiler/tf2xla/kernels/mirror_pad_op.cc | 1 - .../compiler/tf2xla/kernels/pooling_ops.cc | 108 +- .../compiler/tf2xla/kernels/resampler_ops.h | 2 +- .../tf2xla/kernels/reverse_sequence_op.cc | 5 + .../compiler/tf2xla/kernels/scatter_nd_op.cc | 10 +- .../tf2xla/kernels/segment_reduction_ops.cc | 43 +- .../tf2xla/kernels/sparse_to_dense_op.cc | 1 + .../compiler/tf2xla/kernels/split_op.cc | 30 +- .../compiler/tf2xla/kernels/stack_ops.cc | 2 +- .../tf2xla/kernels/stateful_random_ops.cc | 27 +- .../tf2xla/kernels/stateless_random_ops.cc | 40 +- .../tf2xla/kernels/stateless_random_ops_v2.cc | 31 +- .../tf2xla/kernels/tensor_array_ops.cc | 2 +- .../tf2xla/kernels/tensor_list_ops.cc | 2 + .../compiler/tf2xla/kernels/transpose_op.cc | 5 +- .../compiler/tf2xla/kernels/unary_ops.cc | 2 +- .../compiler/tf2xla/kernels/unique_op.cc | 4 + .../compiler/tf2xla/kernels/variable_ops.cc | 2 +- .../tf2xla/kernels/xla_call_module_op.cc | 140 +- tensorflow/compiler/tf2xla/lib/BUILD | 1 + tensorflow/compiler/tf2xla/lib/scatter.cc | 3 +- tensorflow/compiler/tf2xla/lib/scatter.h | 1 + ...ht_outside_compilation_kernels_for_test.cc | 2 +- .../compiler/tf2xla/mlir_bridge_pass.cc | 58 +- tensorflow/compiler/tf2xla/mlir_tf2xla.cc | 2 +- .../compiler/tf2xla/mlir_xla_op_kernel.cc | 5 +- tensorflow/compiler/tf2xla/ops/BUILD | 1 + tensorflow/compiler/tf2xla/ops/xla_ops.cc | 3 +- tensorflow/compiler/tf2xla/python/BUILD | 1 + tensorflow/compiler/tf2xla/python/xla.py | 3 +- tensorflow/compiler/tf2xla/resource_util.cc | 3 +- tensorflow/compiler/tf2xla/resource_util.h | 1 - tensorflow/compiler/tf2xla/shape_util.cc | 2 +- tensorflow/compiler/tf2xla/type_util.cc | 10 + tensorflow/compiler/tf2xla/xla_argument.h | 2 +- tensorflow/compiler/tf2xla/xla_compiler.cc | 84 +- tensorflow/compiler/tf2xla/xla_compiler.h | 11 + .../compiler/tf2xla/xla_compiler_test.cc | 70 +- tensorflow/compiler/tf2xla/xla_expression.cc | 5 +- tensorflow/compiler/tf2xla/xla_expression.h | 2 +- tensorflow/compiler/tf2xla/xla_helpers.h | 2 +- .../xla_jit_compiled_cpu_function_test.cc | 13 +- tensorflow/compiler/tf2xla/xla_op_kernel.cc | 13 +- tensorflow/compiler/tf2xla/xla_op_kernel.h | 4 + tensorflow/compiler/tf2xla/xla_op_registry.cc | 2 +- tensorflow/compiler/tf2xla/xla_op_registry.h | 26 +- .../compiler/tf2xla/xla_op_registry_test.cc | 19 + tensorflow/compiler/tf2xla/xla_resource.cc | 6 +- tensorflow/compiler/xla/BUILD | 175 +- tensorflow/compiler/xla/README.md | 88 + tensorflow/compiler/xla/array.h | 27 +- tensorflow/compiler/xla/array2d.h | 6 +- .../compiler/xla/autotune_results.proto | 51 + tensorflow/compiler/xla/autotune_serialize.cc | 63 + tensorflow/compiler/xla/autotune_serialize.h | 70 + .../compiler/xla/backends/interpreter/BUILD | 13 +- .../xla/backends/interpreter/compiler.h | 2 +- .../xla/backends/interpreter/executable.cc | 4 +- .../xla/backends/interpreter/executable.h | 2 +- .../backends/interpreter/executable_base.cc | 4 +- .../backends/interpreter/executable_base.h | 2 +- .../xla/backends/interpreter/executor.cc | 41 +- .../xla/backends/interpreter/executor.h | 92 +- .../interpreter_transfer_manager.h | 12 + .../xla/backends/interpreter/platform.cc | 20 +- .../xla/backends/interpreter/platform.h | 10 +- .../compiler/xla/backends/profiler/BUILD | 22 +- .../compiler/xla/backends/profiler/cpu/BUILD | 116 + 
.../xla/backends/profiler/cpu/host_tracer.cc | 122 + .../xla/backends/profiler/cpu/host_tracer.h | 43 + .../profiler}/cpu/host_tracer_factory.cc | 16 +- .../profiler}/cpu/metadata_collector.cc | 36 +- .../backends/profiler/cpu/metadata_utils.h | 55 + .../backends/profiler}/cpu/python_tracer.cc | 56 +- .../xla/backends/profiler/cpu/python_tracer.h | 43 + .../profiler}/cpu/python_tracer_factory.cc | 16 +- .../compiler/xla/backends/profiler/gpu/BUILD | 290 + .../backends/profiler}/gpu/cuda_test.cu.cc | 8 +- .../xla/backends/profiler/gpu/cuda_test.h | 55 + .../backends/profiler}/gpu/cupti_collector.cc | 99 +- .../backends/profiler/gpu/cupti_collector.h | 277 + .../profiler}/gpu/cupti_error_manager.cc | 10 +- .../profiler/gpu/cupti_error_manager.h | 277 + .../profiler}/gpu/cupti_error_manager_test.cc | 36 +- .../backends/profiler/gpu/cupti_interface.h | 204 + .../backends/profiler}/gpu/cupti_tracer.cc | 182 +- .../xla/backends/profiler/gpu/cupti_tracer.h | 156 + .../xla/backends/profiler}/gpu/cupti_utils.cc | 10 +- .../backends/profiler}/gpu/cupti_wrapper.cc | 9 +- .../xla/backends/profiler/gpu/cupti_wrapper.h | 185 + .../profiler}/gpu/device_tracer_cuda.cc | 52 +- .../profiler}/gpu/device_tracer_rocm.cc | 111 +- .../xla/backends/profiler/gpu/mock_cupti.h | 168 + .../xla/backends/profiler}/gpu/nvtx_utils.cc | 8 +- .../xla/backends/profiler/gpu/nvtx_utils.h | 58 + .../xla/backends/profiler}/gpu/rocm_tracer.cc | 136 +- .../xla/backends/profiler/gpu/rocm_tracer.h | 395 ++ .../compiler/xla/backends/profiler/tpu/BUILD | 9 +- .../xla/backends/profiler/tpu/tpu_tracer.cc | 24 +- tensorflow/compiler/xla/c/BUILD | 1 + tensorflow/compiler/xla/client/BUILD | 24 +- .../xla/client/executable_build_options.cc | 39 +- .../xla/client/executable_build_options.h | 30 +- tensorflow/compiler/xla/client/lib/BUILD | 3 +- tensorflow/compiler/xla/client/lib/math.cc | 4 - tensorflow/compiler/xla/client/lib/math.h | 3 - .../compiler/xla/client/lib/math_test.cc | 4 + tensorflow/compiler/xla/client/lib/prng.cc | 156 +- tensorflow/compiler/xla/client/padding.cc | 10 + .../compiler/xla/client/value_inference.cc | 10 +- .../compiler/xla/client/value_inference.h | 6 +- tensorflow/compiler/xla/client/xla_builder.cc | 354 +- tensorflow/compiler/xla/client/xla_builder.h | 147 +- .../compiler/xla/client/xla_builder_test.cc | 73 +- tensorflow/compiler/xla/comparison_util.cc | 11 + tensorflow/compiler/xla/comparison_util.h | 12 +- .../compiler/xla/debug_options_flags.cc | 561 +- tensorflow/compiler/xla/debug_options_flags.h | 13 +- tensorflow/compiler/xla/examples/axpy/BUILD | 29 + .../compiler/xla/examples/axpy/README.md | 221 + .../xla/examples/axpy/stablehlo_axpy.mlir | 9 + .../examples/axpy/stablehlo_compile_test.cc | 148 + .../compiler/xla/executable_run_options.cc | 10 + .../compiler/xla/executable_run_options.h | 70 +- .../xla/experimental/conv_emitter/BUILD | 10 +- .../experimental/conv_emitter/conv_emitter.cc | 4 +- .../experimental/conv_emitter/conv_emitter.h | 8 +- .../conv_emitter/conv_emitter_test.cc | 4 +- .../conv_emitter/conv_emitter_transforms.cc | 2 +- .../conv_emitter/conv_emitter_transforms.h | 6 +- .../conv_emitter/g3doc/conv_emitter.md | 0 .../hlo_opcode.h => frontend_attributes.cc} | 23 +- tensorflow/compiler/xla/frontend_attributes.h | 38 + .../xla/g3doc/images/batch_group_counts.svg | 1 + .../compiler/xla/g3doc/operation_semantics.md | 21 +- .../g3doc/tutorials/autoclustering_xla.ipynb | 6 +- .../xla/g3doc/tutorials/jit_compile.ipynb | 3 +- tensorflow/compiler/xla/glob_lit_test.bzl | 117 + 
tensorflow/compiler/xla/hlo/evaluator/BUILD | 14 +- .../xla/hlo/evaluator/hlo_evaluator.cc | 1005 +-- .../xla/hlo/evaluator/hlo_evaluator.h | 25 +- .../xla/hlo/evaluator/hlo_evaluator_test.cc | 141 +- .../evaluator/hlo_evaluator_typed_visitor.h | 594 +- .../hlo_evaluator_typed_visitor_float8.cc | 22 + .../hlo_evaluator_typed_visitor_int16.cc | 2 +- .../hlo_evaluator_typed_visitor_int32.cc | 2 +- .../hlo_evaluator_typed_visitor_int8.cc | 2 +- .../hlo_evaluator_typed_visitor_uint16.cc | 2 +- .../hlo_evaluator_typed_visitor_uint32.cc | 2 +- .../hlo_evaluator_typed_visitor_uint8.cc | 2 +- .../xla/hlo/experimental/auto_sharding/BUILD | 118 +- .../auto_sharding/auto_sharding.cc | 1050 ++- .../auto_sharding/auto_sharding.h | 100 +- .../auto_sharding/auto_sharding_cost_graph.h | 12 +- .../auto_sharding_dot_handler.cc | 224 +- .../auto_sharding/auto_sharding_runner.cc | 2 +- .../auto_sharding_solver_option.h | 107 + .../auto_sharding/auto_sharding_strategy.h | 607 +- .../auto_sharding/auto_sharding_util.cc | 811 +-- .../auto_sharding/auto_sharding_util.h | 249 +- .../auto_sharding/cluster_environment.cc | 248 + .../auto_sharding/cluster_environment.h | 170 + .../hlo/experimental/auto_sharding/matrix.h | 116 + .../hlo/experimental/auto_sharding/metrics.cc | 48 + .../hlo/experimental/auto_sharding/metrics.h | 31 + .../auto_sharding/profiling_result.h | 159 + tensorflow/compiler/xla/hlo/ir/BUILD | 17 +- .../compiler/xla/hlo/ir/dfs_hlo_visitor.h | 77 +- .../compiler/xla/hlo/ir/hlo_computation.cc | 159 +- .../compiler/xla/hlo/ir/hlo_computation.h | 35 +- .../compiler/xla/hlo/ir/hlo_instruction.cc | 181 +- .../compiler/xla/hlo/ir/hlo_instruction.h | 45 +- .../compiler/xla/hlo/ir/hlo_instructions.cc | 156 +- .../compiler/xla/hlo/ir/hlo_instructions.h | 66 +- tensorflow/compiler/xla/hlo/ir/hlo_module.cc | 141 +- tensorflow/compiler/xla/hlo/ir/hlo_module.h | 82 +- .../{service => hlo/ir}/hlo_module_group.cc | 2 +- .../{service => hlo/ir}/hlo_module_group.h | 8 +- tensorflow/compiler/xla/hlo/ir/hlo_opcode.h | 37 +- .../compiler/xla/hlo/ir/hlo_schedule.cc | 24 +- .../compiler/xla/hlo/ir/hlo_sharding.cc | 93 +- tensorflow/compiler/xla/hlo/ir/hlo_sharding.h | 65 +- .../xla/hlo/ir/hlo_sharding_metadata.cc | 6 +- tensorflow/compiler/xla/hlo/transforms/BUILD | 39 + .../transforms}/hlo_constant_splitter.cc | 2 +- .../transforms}/hlo_constant_splitter.h | 6 +- .../transforms}/hlo_constant_splitter_test.cc | 2 +- tensorflow/compiler/xla/index_util.cc | 64 - tensorflow/compiler/xla/index_util.h | 59 +- tensorflow/compiler/xla/layout.cc | 127 +- tensorflow/compiler/xla/layout.h | 21 +- tensorflow/compiler/xla/layout_util.cc | 126 +- tensorflow/compiler/xla/layout_util.h | 41 +- tensorflow/compiler/xla/layout_util_test.cc | 7 + tensorflow/compiler/xla/lazy.h | 45 + tensorflow/compiler/xla/literal.cc | 409 +- tensorflow/compiler/xla/literal.h | 56 +- tensorflow/compiler/xla/literal_comparison.cc | 125 +- tensorflow/compiler/xla/literal_comparison.h | 5 + tensorflow/compiler/xla/literal_test.cc | 124 +- tensorflow/compiler/xla/literal_util.cc | 5 + .../xla/mlir/{tools => backends/cpu}/BUILD | 33 +- .../cpu => backends/cpu/transforms}/BUILD | 19 +- .../cpu/transforms/legalize_collective_ops.cc | 303 + .../legalize_i1_vector_transfers.cc | 139 + .../cpu/transforms/lmhlo_to_cpu_runtime.cc | 515 ++ .../cpu => backends/cpu/transforms}/passes.h | 20 +- .../cpu => backends/cpu/transforms}/passes.td | 57 + .../transforms/remove_copies_to_out_params.cc | 129 + .../mlir/backends/cpu/transforms/tests/BUILD | 24 + 
.../cpu/transforms/tests/collective_ops.mlir | 256 + .../tests/collective_ops_to_cpu_runtime.mlir | 102 + .../backends/cpu/transforms/tests/fft.mlir | 16 + .../tests/legalize_i1_vector_transfers.mlir | 35 + .../transforms}/tests/lmhlo_custom_call.mlir | 27 +- .../cpu/transforms/tests/lmhlo_infeed.mlir | 13 + .../tests/remove_copies_to_out_params.mlir | 127 + .../transforms/tests/rng_bit_generator.mlir | 16 + .../tests/xla_abi_legalization.mlir | 20 +- .../xla_cpu_memref_element_cast_to_llvm.mlir | 48 + .../cpu/transforms/tests/xla_cpu_outfeed.mlir | 37 + .../cpu/transforms}/xla_abi_legalization.cc | 35 +- .../xla_cpu_memref_element_cast_to_llvm.cc | 117 + .../cpu/xla-cpu-opt.cc} | 34 +- .../compiler/xla/mlir/backends/gpu/BUILD | 22 + .../gpu => backends/gpu/transforms}/BUILD | 10 +- .../transforms}/add_hlo_trace_annotations.cc | 14 +- .../gpu/transforms}/gpu_to_gpu_runtime.cc | 28 +- .../transforms}/lmhlo_gpu_to_gpu_runtime.cc | 122 +- .../gpu/transforms}/lmhlo_to_gpu_launch.cc | 160 +- .../gpu/transforms}/lmhlo_to_gpu_runtime.cc | 347 +- .../transforms}/memref_get_global_to_arg.cc | 4 +- .../gpu/transforms/outline_cuda_graphs.cc | 354 + .../gpu => backends/gpu/transforms}/passes.cc | 27 +- .../gpu => backends/gpu/transforms}/passes.h | 26 +- .../gpu => backends/gpu/transforms}/passes.td | 26 +- .../mlir/backends/gpu/transforms/tests/BUILD | 24 + .../gpu/transforms}/tests/add_hlo_trace.mlir | 4 +- .../gpu/transforms}/tests/gpu_launch.mlir | 22 +- .../gpu/transforms}/tests/gpu_memcpy.mlir | 0 .../gpu/transforms}/tests/gpu_memset.mlir | 0 .../gpu/transforms}/tests/lmhlo_case.mlir | 0 .../transforms}/tests/lmhlo_custom_call.mlir | 4 +- .../gpu/transforms}/tests/lmhlo_fft.mlir | 1 + .../transforms}/tests/lmhlo_gpu_cholesky.mlir | 0 .../gpu/transforms}/tests/lmhlo_gpu_conv.mlir | 0 .../tests/lmhlo_gpu_cublas_lt_matmul.mlir | 10 +- .../gpu/transforms}/tests/lmhlo_gpu_gemm.mlir | 0 .../gpu/transforms}/tests/lmhlo_infeed.mlir | 0 .../gpu/transforms}/tests/lmhlo_outfeed.mlir | 0 .../gpu/transforms/tests/lmhlo_send_recv.mlir | 102 + .../gpu/transforms}/tests/lmhlo_while.mlir | 0 .../tests/memref_get_global_to_arg.mlir | 2 +- .../transforms/tests/outline_cuda_graphs.mlir | 290 + .../gpu/transforms}/uid_generator.h | 6 +- .../gpu/xla-gpu-opt.cc} | 13 +- .../compiler/xla/mlir/framework/ir/BUILD | 74 + .../mlir/framework}/ir/xla_framework.cc | 12 +- .../mlir/framework}/ir/xla_framework.h | 12 +- .../mlir/framework}/ir/xla_framework_ops.td | 2 +- .../compiler/xla/mlir/framework/tests/BUILD | 28 + .../tests/legalize-xla-framework.mlir | 2 +- .../tests/outline-with-xla-framework.mlir | 2 +- .../mlir/framework}/tests/xla-framework.mlir | 2 +- .../xla/mlir/framework/transforms/BUILD | 57 + .../transforms/outline_with_xla_framework.cc | 17 +- .../mlir/framework/transforms/passes.h} | 8 +- .../mlir/framework/transforms/passes.td} | 0 .../transforms/xla_framework_to_llvm_pass.cc | 12 +- tensorflow/compiler/xla/mlir/math/BUILD | 16 + .../math => math/transforms}/BUILD | 20 +- .../math/transforms/math_approximation.cc | 293 + .../mlir/math/transforms/math_legalization.cc | 75 + .../transforms}/math_optimization.cc | 6 +- .../math => math/transforms}/passes.h | 22 +- .../math => math/transforms}/passes.td | 35 +- .../math => math/transforms}/tests/BUILD | 9 +- .../transforms/tests/math_legalization.mlir | 19 + .../transforms}/tests/math_optimization.mlir | 0 tensorflow/compiler/xla/mlir/memref/BUILD | 16 + .../memref => memref/transforms}/BUILD | 8 +- .../transforms}/aligned_allocations.cc | 6 +- .../memref => 
memref/transforms}/passes.h | 12 +- .../memref => memref/transforms}/passes.td | 2 +- .../memref => memref/transforms}/tests/BUILD | 9 +- .../tests/aligned_allocations.mlir | 0 tensorflow/compiler/xla/mlir/runtime/BUILD | 12 +- tensorflow/compiler/xla/mlir/runtime/ir/BUILD | 30 +- .../xla/mlir/runtime/ir/rt_dialect.cc | 2 +- .../xla/mlir/runtime/ir/rt_dialect.td | 9 +- .../compiler/xla/mlir/runtime/ir/rt_ops.cc | 20 +- .../compiler/xla/mlir/runtime/ir/rt_ops.td | 17 +- .../compiler/xla/mlir/runtime/ir/tests/BUILD | 6 +- .../xla/mlir/runtime/ir/tests/ops.mlir | 8 +- .../xla/mlir/runtime/ir/tests/ops_verify.mlir | 2 +- .../xla/mlir/runtime/ir/tests/testlib.td | 3 +- .../xla/mlir/runtime/transforms/BUILD | 42 +- .../transforms/compilation_pipeline_cpu.cc | 42 +- .../transforms/compilation_pipeline_gpu.cc | 29 +- .../transforms/custom_call_encoding.cc | 485 +- .../runtime/transforms/custom_call_encoding.h | 193 +- .../mlir/runtime/transforms/jit_compiler.cc | 135 +- .../mlir/runtime/transforms/jit_compiler.h | 23 +- .../xla/mlir/runtime/transforms/rt_to_llvm.cc | 288 +- .../mlir/runtime/transforms/specialization.cc | 19 +- .../mlir/runtime/transforms/specialization.h | 4 +- .../xla/mlir/runtime/transforms/tests/BUILD | 11 +- .../tests/convert_custom_calls.mlir | 4 +- .../runtime/transforms/tests/rt_to_llvm.mlir | 155 +- .../transforms/tests/testlib_pipeline.cc | 7 +- .../mlir/runtime/transforms/type_converter.cc | 2 + .../mlir/runtime/transforms/type_converter.h | 9 +- .../compiler/xla/mlir/runtime/utils/BUILD | 3 +- .../mlir/runtime/utils/async_runtime_api.cc | 8 +- .../xla/mlir/runtime/utils/constraints.cc | 6 +- .../xla/mlir/runtime/utils/constraints.h | 6 +- .../xla/mlir/runtime/utils/custom_calls.cc | 2 +- .../xla/mlir/runtime/utils/custom_calls.h | 3 +- .../xla/mlir/runtime/xla-runtime-opt.cc | 12 +- .../compiler/xla/mlir/tools/mlir_bisect/BUILD | 57 + .../xla/mlir/tools/mlir_bisect/README.md | 85 + .../xla/mlir/tools/mlir_bisect/bisect_lib.cc | 81 + .../xla/mlir/tools/mlir_bisect/bisect_lib.h | 96 + .../xla/mlir/tools/mlir_bisect/mlir_bisect.cc | 346 + .../xla/mlir/tools/mlir_bisect/rewrites/BUILD | 24 + .../mlir/tools/mlir_bisect/rewrites/func.cc | 75 + .../tools/mlir_bisect/rewrites/general.cc | 173 + .../mlir/tools/mlir_bisect/rewrites/gml_st.cc | 54 + .../mlir/tools/mlir_bisect/rewrites/scf.cc | 102 + .../mlir_bisect/rewrites}/tests/BUILD | 11 +- .../tests/erase-op-without-results.mlir | 12 + .../rewrites/tests/inline-scf-while.mlir | 40 + .../tests/reduce-gml-st-parallel-bounds.mlir | 19 + .../tests/replace-op-with-constant.mlir | 26 + .../rewrites/tests/replace-op-with-value.mlir | 16 + .../tests/replace-operand-with-constant.mlir | 28 + ...eturn-operands-of-terminator-operands.mlir | 15 + .../rewrites/tests/truncate-function.mlir | 31 + .../xla/mlir/tools/mlir_bisect/test_passes.cc | 48 + .../xla/mlir/tools/mlir_bisect/test_passes.h | 29 + .../xla/mlir/tools/mlir_bisect/tests/BUILD | 24 + .../mlir/tools/mlir_bisect/tests/bisect.mlir | 46 + .../mlir/tools/mlir_bisect/tests/no-bug.mlir | 10 + .../tools/mlir_bisect/tests/snapshot.mlir | 12 + .../tools/mlir_bisect/tests/snapshot.mlir.pb | Bin 0 -> 68 bytes .../compiler/xla/mlir/tools/mlir_replay/BUILD | 54 + .../xla/mlir/tools/mlir_replay/README.md | 48 + .../xla/mlir/tools/mlir_replay/mlir_replay.cc | 230 + .../mlir/tools/mlir_replay/mlir_replay_lib.cc | 187 + .../mlir/tools/mlir_replay/mlir_replay_lib.h | 39 + .../xla/mlir/tools/mlir_replay/public/BUILD | 69 + .../mlir/tools/mlir_replay/public/README.md | 10 + 
.../mlir_replay/public/compiler_trace.proto | 31 + .../public/compiler_trace_instrumentation.cc | 61 + .../public/compiler_trace_instrumentation.h | 48 + .../mlir_replay/public/execution_trace.proto | 72 + .../public/execution_trace_utils.cc | 429 ++ .../public/execution_trace_utils.h | 73 + .../public/execution_trace_utils_test.cc | 138 + .../transforms/cpu/lmhlo_to_cpu_runtime.cc | 159 - .../gpu/launch_func_to_cuda_graph.cc | 185 - .../gpu/tests/gpu_launch_to_cuda_graph.mlir | 149 - tensorflow/compiler/xla/mlir/utils/BUILD | 5 +- tensorflow/compiler/xla/mlir/xla_cpu/ir/BUILD | 106 + .../compiler/xla/mlir/xla_cpu/ir/xla_cpu.cc | 164 + .../compiler/xla/mlir/xla_cpu/ir/xla_cpu.h | 38 + .../xla/mlir/xla_cpu/ir/xla_cpu_dialect.td | 33 + .../xla/mlir/xla_cpu/ir/xla_cpu_enums.td | 39 + .../xla/mlir/xla_cpu/ir/xla_cpu_ops.td | 339 + .../mlir/xla_cpu/tests}/BUILD | 14 +- .../xla/mlir/xla_cpu/tests/bufferize.mlir | 129 + .../xla/mlir/xla_cpu/tests/invalid.mlir | 7 + .../compiler/xla/mlir/xla_cpu/tests/ops.mlir | 16 + tensorflow/compiler/xla/mlir_hlo/BUILD | 930 +-- .../compiler/xla/mlir_hlo/CMakeLists.txt | 19 +- .../{lib/Analysis => analysis}/CMakeLists.txt | 4 +- .../test_userange_analysis.cc | 6 +- .../userange_analysis.cc | 2 +- .../Analysis => analysis}/userange_analysis.h | 7 +- .../xla/mlir_hlo/bindings/CMakeLists.txt | 5 + .../{lib/CAPI => bindings/c}/Attributes.cc | 152 +- .../mlir-hlo-c => bindings/c}/Attributes.h | 83 +- .../{lib/CAPI => bindings/c}/CMakeLists.txt | 0 .../{lib/CAPI => bindings/c}/Dialects.cc | 4 +- .../mlir-hlo-c => bindings/c}/Dialects.h | 6 +- .../{lib/CAPI => bindings/c}/Passes.cc | 4 +- .../mlir-hlo-c => bindings/c}/Passes.h | 6 +- .../{lib/CAPI => bindings/c}/Types.cc | 4 +- .../mlir-hlo-c => bindings/c}/Types.h | 6 +- .../{ => bindings}/python/CMakeLists.txt | 0 .../{ => bindings}/python/MlirHloModule.cc | 130 +- .../python/mlir/dialects/MhloOps.td | 2 +- .../python/mlir/dialects/mhlo.py | 0 .../xla/mlir_hlo/cmake/modules/CMakeLists.txt | 1 + .../{lib/Dialect => }/gml_st/CMakeLists.txt | 2 + .../Dialect => }/gml_st/IR/CMakeLists.txt | 24 + .../xla/mlir_hlo/gml_st/IR/gml_st_ops.cc | 1460 ++++ .../Dialect => }/gml_st/IR/gml_st_ops.h | 14 +- .../Dialect => }/gml_st/IR/gml_st_ops.td | 496 +- .../Dialect => }/gml_st/IR/gml_st_ops_base.td | 6 +- .../mlir-hlo/Dialect => }/gml_st/README.md | 0 .../IR => gml_st/interfaces}/CMakeLists.txt | 19 +- .../bufferizable_op_interface_impl.cc | 441 +- .../bufferizable_op_interface_impl.h | 6 +- .../mlir_hlo/gml_st/transforms/CMakeLists.txt | 114 + .../add_debug_info/add_debug_info.cc | 74 + .../collapse_shape/collapse_shape.cc | 351 + .../compose_extract_insert_slice.cc | 55 + .../cpu_tiling/cpu_tiling_pipeline.cc | 51 + .../cpu_tiling}/transform_map_for_cpu.cc | 106 +- .../cpu_tiling/transform_matmul_for_cpu.cc | 788 +++ .../cpu_tiling/transform_reduce_for_cpu.cc | 529 ++ .../cpu_tiling/transform_reverse_for_cpu.cc | 162 + .../cpu_tiling}/transform_scatter_for_cpu.cc | 84 +- .../cpu_tiling/transform_sort_for_cpu.cc | 118 + .../transform_transpose_for_cpu.cc | 51 +- .../gml_st/transforms/fusion/fusion.cc | 537 ++ .../gml_st/transforms/fusion/fusion.h | 86 + .../gml_st_simtfy/gml_st_simtfy.cc} | 346 +- .../transforms/gml_st_to_gpu/gml_st_to_gpu.cc | 364 + .../gml_st_to_scf}/gml_st_to_scf.cc | 88 +- .../transforms/gpu_tiling/greedy_fusion.cc | 159 + .../transforms/gpu_tiling}/tiling_cwise.cc | 35 +- .../transforms/gpu_tiling}/tiling_gpu_warp.cc | 269 +- .../xla/mlir_hlo/gml_st/transforms/passes.h | 202 + 
.../xla/mlir_hlo/gml_st/transforms/passes.td | 322 + .../gml_st/transforms/peeling/peeling.cc | 180 + .../gml_st/transforms/peeling/peeling.h | 57 + .../rewrite_vector_contract.cc | 132 + .../rewrite_vector_multi_reduction.cc} | 51 +- .../rewrite_vector_transpose.cc | 66 + .../scalarization}/scalarization.cc | 440 +- .../mlir_hlo/gml_st/transforms/test_passes.cc | 95 + .../gml_st/transforms/test_passes.h | 16 +- .../transforms/test_passes.td} | 9 +- .../transforms/tiling}/tiling.cc | 173 +- .../transforms/tiling}/tiling.h | 21 +- .../tiling_softmax}/tiling_softmax.cc | 110 +- .../mlir_hlo/gml_st/transforms/transforms.cc | 312 + .../mlir_hlo/gml_st/transforms/transforms.h | 116 + .../transform_matmul_for_triton.cc | 192 + .../transforms/vectorization/vectorization.cc | 93 + .../transforms/vectorization/vectorization.h | 61 + .../vectorization/vectorize_copy.cc | 108 + .../vectorization/vectorize_for_cpu.cc | 264 + .../vectorization/vectorize_for_gpu.cc} | 754 +- .../mlir-hlo => gml_st/utils}/CMakeLists.txt | 9 +- .../xla/mlir_hlo/gml_st/utils/linalg_utils.cc | 70 + .../utils}/linalg_utils.h | 12 +- .../utils}/vector_utils.h | 4 +- .../include/mlir-hlo/Dialect/CMakeLists.txt | 20 - .../mlir-hlo/Dialect/gml_st/CMakeLists.txt | 16 - .../Dialect/gml_st/IR/gml_st_legacy_ops.td | 339 - .../Dialect/gml_st/transforms/CMakeLists.txt | 28 - .../Dialect/gml_st/transforms/fusion.h | 38 - .../Dialect/gml_st/transforms/passes.h | 89 - .../Dialect/gml_st/transforms/passes.td | 159 - .../Dialect/gml_st/transforms/rewriters.h | 45 - .../Dialect/gml_st/transforms/test_passes.td | 52 - .../gml_st/transforms/tiling_interface.td | 113 - .../Dialect/gml_st/transforms/transforms.h | 109 - .../Dialect/lhlo/transforms/CMakeLists.txt | 19 - .../mlir-hlo/Dialect/mhlo/CMakeLists.txt | 17 - .../mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt | 27 - .../mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h | 127 - .../mlir-hlo/Dialect/mhlo/IR/hlo_ops_enums.td | 237 - .../mlir-hlo/Transforms/CMakeLists.txt | 23 - .../mlir-hlo/Transforms/gml_st_pipeline.h | 41 - .../mlir-hlo/Dialect => }/lhlo/CMakeLists.txt | 0 .../Dialect => }/lhlo/IR/CMakeLists.txt | 28 + .../Dialect => }/lhlo/IR/lhlo_dialect.td | 3 +- .../{lib/Dialect => }/lhlo/IR/lhlo_ops.cc | 44 +- .../mlir-hlo/Dialect => }/lhlo/IR/lhlo_ops.h | 14 +- .../mlir-hlo/Dialect => }/lhlo/IR/lhlo_ops.td | 146 +- .../Dialect => }/lhlo/IR/lhlo_ops_base.td | 12 +- .../Dialect => }/lhlo/IR/lhlo_ops_structs.h | 8 +- .../Dialect => }/lhlo/IR/lhlo_ops_structs.td | 2 +- .../lhlo/IR/lhlo_structured_interface.cc | 4 +- .../lhlo/IR/lhlo_structured_interface.h | 8 +- .../lhlo/IR/lhlo_structured_interface.td | 0 .../lhlo/transforms/CMakeLists.txt | 15 +- .../legalize_to_tensor_op.cc | 4 +- .../lhlo/transforms/lhlo_elemental_utils.cc | 12 +- .../lhlo/transforms/lhlo_elemental_utils.h | 6 +- .../lhlo_legalize_to_affine.cc | 11 +- .../lhlo_legalize_to_gpu.cc | 16 +- .../lhlo_legalize_to_parallel_loops.cc | 16 +- .../lhlo/transforms/lmhlo_passes.td | 12 - .../lhlo/transforms/map_hlo_to_lhlo_op.h | 11 +- .../lhlo/transforms/map_lhlo_to_hlo_op.h | 11 +- .../lhlo/transforms/map_lmhlo_to_scalar_op.h | 10 +- .../Dialect => }/lhlo/transforms/passes.h | 22 +- .../mlir-hlo => lhlo}/utils/lhlo_utils.h | 16 +- .../Dialect => }/lhlo_gpu/CMakeLists.txt | 0 .../Dialect => }/lhlo_gpu/IR/CMakeLists.txt | 16 + .../Dialect => }/lhlo_gpu/IR/lhlo_gpu_ops.cc | 30 +- .../Dialect => }/lhlo_gpu/IR/lhlo_gpu_ops.h | 16 +- .../Dialect => }/lhlo_gpu/IR/lhlo_gpu_ops.td | 72 +- .../lhlo_gpu/IR/lhlo_gpu_ops_base.td | 2 +- 
.../lhlo_gpu/IR/lhlo_gpu_ops_enums.td | 13 +- .../compiler/xla/mlir_hlo/lib/CMakeLists.txt | 20 - .../xla/mlir_hlo/lib/Dialect/CMakeLists.txt | 19 - .../lib/Dialect/gml_st/IR/CMakeLists.txt | 36 - .../lib/Dialect/gml_st/IR/gml_st_ops.cc | 1959 ------ .../Dialect/gml_st/transforms/CMakeLists.txt | 134 - .../transforms/collapse_materialize_ops.cc | 183 - .../lib/Dialect/gml_st/transforms/fusion.cc | 302 - .../Dialect/gml_st/transforms/linalg_utils.cc | 187 - .../Dialect/gml_st/transforms/test_passes.cc | 208 - .../transforms/tiling_interface_impl.cc | 182 - .../transforms/transform_matmul_for_cpu.cc | 116 - .../Dialect/gml_st/transforms/transforms.cc | 396 -- .../mlir_hlo/lib/Dialect/lhlo/CMakeLists.txt | 17 - .../lhlo/transforms/lhlo_fuse_linalg.cc | 224 - .../lib/Dialect/lhlo_gpu/CMakeLists.txt | 16 - .../transforms/hlo_legalize_to_stablehlo.cc | 226 - .../mhlo/transforms/lower_complex_patterns.td | 136 - .../lib/Transforms/alloc_to_arg_pass.cc | 71 - .../lib/Transforms/gml_st_pipeline.cc | 66 - .../lib/Transforms/hlo_to_gpu_pipeline.cc | 140 - .../lib/Transforms/inline_fusion_pass.cc | 71 - .../{lib/Dialect => }/mhlo/CMakeLists.txt | 2 + .../{lib/Dialect => }/mhlo/IR/CMakeLists.txt | 17 +- .../Dialect => }/mhlo/IR/chlo_canonicalize.td | 8 +- .../compiler/xla/mlir_hlo/mhlo/IR/hlo_base.td | 122 + .../{lib/Dialect => }/mhlo/IR/hlo_ops.cc | 3950 ++++------- .../mlir-hlo/Dialect => }/mhlo/IR/hlo_ops.h | 16 +- .../mlir-hlo/Dialect => }/mhlo/IR/hlo_ops.td | 1076 +-- .../Dialect => }/mhlo/IR/hlo_ops_attrs.td | 163 +- .../Dialect => }/mhlo/IR/hlo_ops_common.cc | 8 +- .../xla/mlir_hlo/mhlo/IR/hlo_ops_common.h | 57 + .../Dialect => }/mhlo/IR/hlo_ops_common.td | 12 +- .../xla/mlir_hlo/mhlo/IR/hlo_ops_enums.td | 264 + .../Dialect => }/mhlo/IR/hlo_ops_typedefs.td | 10 +- .../{lib/Dialect => }/mhlo/IR/hlo_patterns.td | 40 +- .../Dialect => }/mhlo/IR/hlo_utils.td | 8 +- .../{lib/Dialect => }/mhlo/IR/init.cc | 4 +- .../Dialect => }/mhlo/IR/mhlo_bytecode.cc | 4 +- .../Dialect => }/mhlo/IR/mhlo_bytecode.h | 6 +- .../Dialect => }/mhlo/IR/mhlo_canonicalize.td | 42 +- .../mlir-hlo/Dialect => }/mhlo/IR/register.h | 0 .../xla/mlir_hlo/mhlo/analysis/CMakeLists.txt | 24 + .../analysis}/shape_component_analysis.cc | 13 +- .../analysis}/shape_component_analysis.h | 8 +- .../test_shape_component_analysis.cc | 6 +- .../bufferizable_op_interface_impl.h | 6 +- .../mhlo/transforms/CMakeLists.txt | 178 +- .../broadcast_propagation.cc | 8 +- .../chlo_legalize_to_hlo.cc | 59 +- .../chlo_legalize_to_hlo_pass.cc | 8 +- .../chlo_legalize_to_hlo_patterns.td | 285 +- .../collapse_elementwise_map.cc | 12 +- .../constraint_fusion_pass.cc | 4 +- .../convert_to_signless_pass.cc | 6 +- .../expand_hlo_tuples}/expand_hlo_tuples.cc | 25 +- .../expand_ops_simplifier.cc | 229 + .../group_reduction_dimensions.cc | 8 +- .../hlo_legalize_shape_ops_to_standard.cc | 12 +- .../hlo_legalize_to_arithmetic.cc | 28 +- .../hlo_legalize_to_lhlo.cc | 43 +- .../hlo_legalize_to_memref.cc | 167 +- .../hlo_legalize_to_stablehlo.cc | 384 ++ .../hlo_legalize_to_stablehlo_pass.cc | 14 +- .../legalize_control_flow.cc | 6 +- .../legalize_einsum_to_dot_general.cc | 8 +- .../legalize_gather_to_torch_index_select.cc | 8 +- .../legalize_mhlo_to_thlo.cc | 111 +- .../legalize_shape_computations.cc | 8 +- .../legalize_sort}/legalize_sort.cc | 16 +- .../legalize_to_linalg}/legalize_to_linalg.cc | 852 ++- .../legalize_to_standard.cc | 15 +- .../legalize_to_standard_patterns.td | 32 +- ...legalize_trigonometric_to_approximation.cc | 6 +- 
.../lower_complex}/lower_complex.cc | 12 +- .../lower_complex/lower_complex_patterns.td | 136 + .../lower_general_dot}/lower_general_dot.cc | 71 +- .../mhlo/transforms/map_chlo_to_hlo_op.h | 14 +- .../mhlo/transforms/map_mhlo_to_scalar_op.h | 74 +- .../mhlo/transforms/map_stablehlo_to_hlo_op.h | 17 +- .../materialize_broadcasts.cc | 2 +- .../materialize_broadcasts_pass.cc | 6 +- .../merge_assuming_ops}/merge_assuming_ops.cc | 22 +- .../mhlo_canonicalize_gather.cc | 22 +- .../mhlo_canonicalize_reduction.cc | 11 +- .../mhlo_canonicalize_scatter.cc | 20 +- .../mhlo_flatten_tuple}/mhlo_flatten_tuple.cc | 6 +- .../mhlo/transforms/mhlo_passes.td | 23 + .../optimize_mhlo}/optimize_mhlo.cc | 6 +- .../optimize_mhlo}/optimize_mhlo_pass.cc | 8 +- .../Dialect => }/mhlo/transforms/passes.h | 28 +- .../prepare_for_export}/prepare_for_export.cc | 12 +- .../rank_specialization.cc | 72 +- .../restrict_max_rank}/restrict_max_rank.cc | 6 +- .../Dialect => }/mhlo/transforms/rewriters.h | 12 +- .../shape_reification_pass.cc | 6 +- .../shape_simplification.cc | 22 +- .../sink_constants_to_control_flow.cc | 4 +- .../sparse_chlo_legalize_to_linalg.cc | 8 +- .../sparse_rewriting}/sparse_rewriting.cc | 8 +- .../stablehlo_legalize_to_hlo.cc | 46 +- .../stablehlo_legalize_to_hlo_pass.cc | 10 +- .../symbolic_shape_optimization.cc | 115 +- .../test_infer_shaped_type_pass.cc | 2 +- .../unfuse_batch_norm}/unfuse_batch_norm.cc | 2 +- .../unfuse_batch_norm_pass.cc | 6 +- .../lhlo/IR => mhlo/utils}/CMakeLists.txt | 50 +- .../utils}/legalize_to_linalg_utils.cc | 34 +- .../utils}/legalize_to_linalg_utils.h | 44 +- .../utils}/mhlo_scatter_gather_utils.cc | 6 +- .../utils}/mhlo_scatter_gather_utils.h | 2 +- .../utils}/type_conversion.cc | 15 +- .../utils}/type_conversion.h | 6 +- tensorflow/compiler/xla/mlir_hlo/tests/BUILD | 5 +- .../chlo/chlo_legalize_to_hlo_broadcasts.mlir | 18 +- .../tests/Dialect/gml_st/add_debug_info.mlir | 22 + .../tests/Dialect/gml_st/bufferization.mlir | 231 +- .../tests/Dialect/gml_st/canonicalize.mlir | 595 +- .../tests/Dialect/gml_st/collapse-shape.mlir | 288 + .../gml_st/collapse_materialize_ops.mlir | 26 - .../gml_st/compose_extract_insert_slice.mlir | 21 + .../gml_st/cpu_tiling/map_bcast_map.mlir | 35 + .../Dialect/gml_st/cpu_tiling/map_matmul.mlir | 70 + .../gml_st/cpu_tiling/map_reduce_map.mlir | 24 + .../Dialect/gml_st/cpu_tiling/matmul.mlir | 206 + .../Dialect/gml_st/cpu_tiling/reduce_1d.mlir | 52 + .../Dialect/gml_st/cpu_tiling/reduce_2d.mlir | 55 + .../Dialect/gml_st/cpu_tiling/reverse.mlir | 40 + .../scatter.mlir} | 7 +- .../tests/Dialect/gml_st/cpu_tiling/sort.mlir | 22 + .../Dialect/gml_st/cpu_tiling/transpose.mlir | 41 + .../mlir_hlo/tests/Dialect/gml_st/fusion.mlir | 195 +- .../tests/Dialect/gml_st/gml_st_simtfy.mlir | 221 + .../tests/Dialect/gml_st/gml_st_to_gpu.mlir | 273 +- .../tests/Dialect/gml_st/gml_st_to_scf.mlir | 214 - .../gml_st/{ => gpu_tiling}/tiling_cwise.mlir | 29 +- .../gml_st/gpu_tiling/tiling_gpu_warp.mlir | 316 + .../tests/Dialect/gml_st/greedy_fusion.mlir | 131 + .../gml_st/greedy_tiling_and_fusion.mlir | 143 + .../tests/Dialect/gml_st/invalid.mlir | 220 +- .../Dialect/gml_st/legacy_bufferization.mlir | 133 - .../Dialect/gml_st/legacy_loop_tiling.mlir | 93 - .../tests/Dialect/gml_st/legacy_peeling.mlir | 171 - .../tests/Dialect/gml_st/nested_tiling.mlir | 18 +- .../Dialect/gml_st/nested_tiling_cwise.mlir | 37 +- .../Dialect/gml_st/nested_tiling_softmax.mlir | 155 +- .../mlir_hlo/tests/Dialect/gml_st/ops.mlir | 351 +- .../gml_st/rewrite_vector_contract.mlir | 166 + 
.../rewrite_vector_multi_reduction.mlir} | 17 +- .../Dialect/gml_st/simplify_dead_copy.mlir | 198 + .../mlir_hlo/tests/Dialect/gml_st/tiling.mlir | 400 +- .../Dialect/gml_st/tiling_and_fusion.mlir | 18 +- .../tests/Dialect/gml_st/tiling_gpu_warp.mlir | 389 -- .../tests/Dialect/gml_st/tiling_softmax.mlir | 317 +- .../Dialect/gml_st/transform_map_for_cpu.mlir | 45 - .../gml_st/transform_matmul_for_cpu.mlir | 126 - .../gml_st/transform_transpose_for_cpu.mlir | 26 - .../transform_matmul_for_triton.mlir | 109 + .../tests/Dialect/gml_st/vectorization.mlir | 295 - .../tests/Dialect/gml_st/vectorize.mlir | 213 - .../tests/Dialect/gml_st/vectorize_copy.mlir | 28 + .../Dialect/gml_st/vectorize_for_cpu.mlir | 224 + .../Dialect/gml_st/vectorize_for_gpu.mlir | 300 + .../gml_st/vectorize_for_gpu_distributed.mlir | 235 + .../Dialect/gml_st/vectorize_gml_st.mlir | 249 - .../tests/Dialect/gml_st/warp_reduce.mlir | 226 + .../tests/Dialect/lhlo/lhlo-fuse-linalg.mlir | 429 -- .../lhlo-legalize-select-and-scatter.mlir | 2 +- .../xla/mlir_hlo/tests/Dialect/lhlo/ops.mlir | 78 +- .../mlir_hlo/tests/Dialect/mhlo/attrs.mlir | 35 + .../Dialect/mhlo/canonicalize/bitcast.mlir | 71 + .../mhlo/canonicalize/canonicalize.mlir | 157 +- .../mhlo/canonicalize/concatenate.mlir | 2 +- .../Dialect/mhlo/canonicalize/convert.mlir | 2 +- .../mhlo/canonicalize/convolution.mlir | 28 +- .../mhlo/canonicalize/custom_call.mlir | 2 +- .../mhlo/canonicalize/folder_limit.mlir | 2 +- .../Dialect/mhlo/canonicalize/reduce.mlir | 2 +- .../Dialect/mhlo/canonicalize/reshape.mlir | 2 +- .../Dialect/mhlo/canonicalize/reverse.mlir | 2 +- .../Dialect/mhlo/canonicalize/scatter.mlir | 2 +- .../Dialect/mhlo/canonicalize/transpose.mlir | 2 +- .../Dialect/mhlo/canonicalize/tuple.mlir | 2 +- .../tests/Dialect/mhlo/expand_hlo_tuples.mlir | 27 +- .../Dialect/mhlo/expand_ops_simplifier.mlir | 57 + .../Dialect/mhlo/hlo-legalize-to-lhlo.mlir | 6 +- .../Dialect/mhlo/hlo-legalize-to-linalg.mlir | 566 +- .../Dialect/mhlo/hlo-legalize-to-memref.mlir | 51 +- ...lo-legalize-to-stablehlo-experimental.mlir | 50 + .../mhlo/hlo-legalize-to-stablehlo.mlir | 174 +- .../mlir_hlo/tests/Dialect/mhlo/invalid.mlir | 65 + .../Dialect/mhlo/legalize-mhlo-to-thlo.mlir | 75 +- .../tests/Dialect/mhlo/lower-general-dot.mlir | 16 +- .../mhlo/mhlo_canonicalize_scatter.mlir | 31 + .../Dialect/mhlo/mhlo_flatten_tuple.mlir | 4 +- .../mhlo/mhlo_infer_shape_type_methods.mlir | 964 ++- .../Dialect/mhlo/mhlo_ops_prettyprint.mlir | 30 +- .../xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir | 1101 ++- .../tests/Dialect/mhlo/optimize-hlo.mlir | 2 +- .../mhlo/stablehlo-legalize-to-hlo.mlir | 132 +- .../mhlo/symbolic-shape-optimization.mlir | 148 +- .../tests/Dialect/mhlo/verifier_bounds.mlir | 13 + .../tests/Dialect/mhlo/verifier_conv_op.mlir | 199 +- .../Dialect/mhlo/verifier_reduce_op.mlir | 17 + .../Dialect/mhlo/verifier_scatter_op.mlir | 2 +- .../mhlo/verifier_select_and_scatter_op.mlir | 4 +- .../tests/Dialect/mhlo/verifier_while_op.mlir | 2 +- .../tests/Dialect/thlo/bufferize.mlir | 14 +- .../tests/Dialect/thlo/canonicalize.mlir | 15 + .../mlir_hlo/tests/Dialect/thlo/invalid.mlir | 30 +- .../tests/Dialect/thlo/legalize_sort.mlir | 3 +- .../xla/mlir_hlo/tests/Dialect/thlo/ops.mlir | 37 +- .../xla/mlir_hlo/tests/alloc_to_arg.mlir | 16 +- .../mlir_hlo/tests/bufferize_one_shot.mlir | 210 - .../compiler/xla/mlir_hlo/tests/capi_test.c | 8 +- .../mlir_hlo/tests/gpu_fusion_rewrite.mlir | 98 +- .../mlir_hlo/tests/hlo_to_gpu_pipeline.mlir | 25 +- .../tests/hlo_to_gpu_pipeline_softmax.mlir | 44 +- 
.../tests/hlo_to_triton_pipeline_softmax.mlir | 56 + .../tests/index_type_llvm_lowering.mlir | 12 +- .../xla/mlir_hlo/tests/inline_fusion.mlir | 15 - .../xla/mlir_hlo/tests/lower_index_cast.mlir | 28 + .../xla/mlir_hlo/tests/python/attributes.py | 279 +- .../xla/mlir_hlo/tests/python/types.py | 15 +- .../xla/mlir_hlo/tests/scalarization.mlir | 380 +- .../xla/mlir_hlo/tests/warp_reduce.mlir | 101 - .../{lib/Dialect => }/thlo/CMakeLists.txt | 1 + .../Dialect => }/thlo/IR/CMakeLists.txt | 21 + .../{lib/Dialect => }/thlo/IR/thlo_ops.cc | 760 ++- .../mlir-hlo/Dialect => }/thlo/IR/thlo_ops.h | 13 +- .../mlir-hlo/Dialect => }/thlo/IR/thlo_ops.td | 185 +- .../IR => thlo/interfaces}/CMakeLists.txt | 17 +- .../bufferizable_op_interface_impl.cc | 14 +- .../bufferizable_op_interface_impl.h | 6 +- .../thlo/transforms/CMakeLists.txt | 18 +- .../legalize_sort}/legalize_sort.cc | 294 +- .../Dialect => }/thlo/transforms/passes.h | 10 +- .../thlo/transforms/thlo_passes.td | 0 .../tools/mlir-hlo-opt/CMakeLists.txt | 2 + .../tools/mlir-hlo-opt/mlir-hlo-opt.cc | 42 +- .../tools/mlir_interpreter/dialects/affine.cc | 49 + .../tools/mlir_interpreter/dialects/arith.cc | 237 + .../dialects/bufferization.cc | 69 + .../mlir_interpreter/dialects/builtin.cc | 50 + .../mlir_interpreter/dialects/comparators.h | 104 + .../mlir_interpreter/dialects/cwise_math.h | 193 + .../tools/mlir_interpreter/dialects/func.cc | 116 + .../tools/mlir_interpreter/dialects/gml_st.cc | 178 + .../tools/mlir_interpreter/dialects/linalg.cc | 253 + .../tools/mlir_interpreter/dialects/math.cc | 47 + .../tools/mlir_interpreter/dialects/memref.cc | 230 + .../tools/mlir_interpreter/dialects/mhlo.cc | 776 +++ .../dialects/mhlo_binary_cwise.cc | 45 + .../dialects/mhlo_unary_cwise.cc | 112 + .../tools/mlir_interpreter/dialects/scf.cc | 159 + .../tools/mlir_interpreter/dialects/tensor.cc | 209 + .../mlir_interpreter/dialects}/tests/BUILD | 11 +- .../dialects/tests/affine/apply.mlir | 52 + .../dialects/tests/affine/minmax.mlir | 36 + .../dialects/tests/arith/cmpf.mlir | 118 + .../dialects/tests/arith/cmpi.mlir | 147 + .../dialects/tests/arith/constant.mlir | 37 + .../dialects/tests/arith/extf.mlir | 11 + .../dialects/tests/arith/index_cast.mlir | 28 + .../dialects/tests/arith/int_math.mlir | 67 + .../dialects/tests/arith/minmax.mlir | 25 + .../dialects/tests/arith/negf.mlir | 21 + .../dialects/tests/arith/remf.mlir | 12 + .../dialects/tests/arith/select.mlir | 15 + .../dialects/tests/arith/sitofp.mlir | 21 + .../dialects/tests/arith/uitofp.mlir | 21 + .../dialects/tests/arith/vector_math.mlir | 12 + .../tests/bufferization/alloc_tensor.mlir | 30 + .../dialects/tests/bufferization/clone.mlir | 14 + .../tests/bufferization/to_memref.mlir | 10 + .../tests/bufferization/to_tensor.mlir | 11 + .../builtin/unrealized_conversion_cast.mlir | 21 + .../dialects/tests/func/call.mlir | 48 + .../dialects/tests/gml_st/for.mlir | 20 + .../dialects/tests/gml_st/parallel.mlir | 20 + .../dialects/tests/linalg/broadcast.mlir | 30 + .../dialects/tests/linalg/fill.mlir | 24 + .../dialects/tests/linalg/generic.mlir | 75 + .../dialects/tests/linalg/map.mlir | 74 + .../dialects/tests/linalg/matmul.mlir | 41 + .../dialects/tests/linalg/reduce.mlir | 36 + .../dialects/tests/linalg/transpose.mlir | 27 + .../dialects/tests/math/math.mlir | 102 + .../dialects/tests/memref/alloc.mlir | 47 + .../dialects/tests/memref/collapse_shape.mlir | 33 + .../dialects/tests/memref/copy.mlir | 39 + .../dialects/tests/memref/dim.mlir | 12 + .../dialects/tests/memref/expand_shape.mlir | 52 + 
.../dialects/tests/memref/get_global.mlir | 12 + .../dialects/tests/memref/invalid.mlir | 50 + .../dialects/tests/memref/load.mlir | 12 + .../dialects/tests/memref/subview.mlir | 131 + .../dialects/tests/mhlo/broadcast_in_dim.mlir | 20 + .../dialects/tests/mhlo/case.mlir | 17 + .../dialects/tests/mhlo/compare.mlir | 143 + .../dialects/tests/mhlo/complex_math.mlir | 100 + .../tests/mhlo/compute_reshape_shape.mlir | 26 + .../dialects/tests/mhlo/constant.mlir | 25 + .../dialects/tests/mhlo/convert.mlir | 21 + .../dialects/tests/mhlo/dot.mlir | 37 + .../dialects/tests/mhlo/dot_general.mlir | 73 + .../dialects/tests/mhlo/dynamic_slice.mlir | 32 + .../tests/mhlo/dynamic_update_slice.mlir | 34 + .../dialects/tests/mhlo/float_math.mlir | 200 + .../dialects/tests/mhlo/gather.mlir | 78 + .../dialects/tests/mhlo/int_math.mlir | 358 + .../dialects/tests/mhlo/iota.mlir | 30 + .../dialects/tests/mhlo/pad.mlir | 56 + .../dialects/tests/mhlo/reduce.mlir | 17 + .../dialects/tests/mhlo/reshape.mlir | 34 + .../dialects/tests/mhlo/scatter.mlir | 55 + .../dialects/tests/mhlo/select.mlir | 14 + .../dialects/tests/mhlo/slice.mlir | 16 + .../dialects/tests/mhlo/subtract.mlir | 10 + .../dialects/tests/mhlo/transpose.mlir | 28 + .../dialects/tests/mhlo/tuple.mlir | 30 + .../dialects/tests/mhlo/while.mlir | 25 + .../dialects/tests/scf/for.mlir | 50 + .../dialects/tests/scf/if.mlir | 69 + .../dialects/tests/scf/parallel.mlir | 44 + .../dialects/tests/scf/while.mlir | 45 + .../dialects/tests/tensor/collapse_shape.mlir | 31 + .../dialects/tests/tensor/dim.mlir | 12 + .../dialects/tests/tensor/empty.mlir | 21 + .../dialects/tests/tensor/expand_shape.mlir | 30 + .../dialects/tests/tensor/extract.mlir | 13 + .../dialects/tests/tensor/extract_slice.mlir | 21 + .../dialects/tests/tensor/from_elements.mlir | 25 + .../dialects/tests/tensor/generate.mlir | 29 + .../dialects/tests/tensor/insert.mlir | 14 + .../dialects/tests/tensor/insert_slice.mlir | 25 + .../dialects/tests/tensor/pad.mlir | 38 + .../dialects/tests/vector/bitcast.mlir | 32 + .../dialects/tests/vector/broadcast.mlir | 51 + .../dialects/tests/vector/compressstore.mlir | 16 + .../dialects/tests/vector/constant_mask.mlir | 14 + .../dialects/tests/vector/contract.mlir | 141 + .../dialects/tests/vector/create_mask.mlir | 16 + .../dialects/tests/vector/expandload.mlir | 19 + .../dialects/tests/vector/extract.mlir | 52 + .../tests/vector/extract_strided_slice.mlir | 18 + .../dialects/tests/vector/extractelement.mlir | 22 + .../dialects/tests/vector/flat_transpose.mlir | 23 + .../dialects/tests/vector/fma.mlir | 13 + .../dialects/tests/vector/gather.mlir | 50 + .../dialects/tests/vector/insert.mlir | 57 + .../tests/vector/insert_strided_slice.mlir | 17 + .../dialects/tests/vector/insertelement.mlir | 24 + .../dialects/tests/vector/invalid.mlir | 27 + .../dialects/tests/vector/load.mlir | 27 + .../dialects/tests/vector/maskedload.mlir | 19 + .../dialects/tests/vector/maskedstore.mlir | 18 + .../tests/vector/multi_reduction.mlir | 46 + .../dialects/tests/vector/outerproduct.mlir | 155 + .../dialects/tests/vector/reduction.mlir | 235 + .../dialects/tests/vector/shape_cast.mlir | 23 + .../dialects/tests/vector/shuffle.mlir | 34 + .../dialects/tests/vector/splat.mlir | 11 + .../dialects/tests/vector/store.mlir | 39 + .../dialects/tests/vector/transfer_read.mlir | 103 + .../dialects/tests/vector/transfer_write.mlir | 91 + .../dialects/tests/vector/transpose.mlir | 28 + .../dialects/tests/vector/type_cast.mlir | 11 + .../dialects/tests/vector/vscale.mlir | 12 + 
.../tools/mlir_interpreter/dialects/util.cc | 163 + .../tools/mlir_interpreter/dialects/util.h | 77 + .../tools/mlir_interpreter/dialects/vector.cc | 856 +++ .../mlir_interpreter/framework/interpreter.cc | 126 + .../mlir_interpreter/framework/interpreter.h | 191 + .../framework/interpreter_value.cc | 361 + .../framework/interpreter_value.h | 224 + .../framework/interpreter_value_util.h | 176 + .../framework/registration.cc | 120 + .../mlir_interpreter/framework/registration.h | 225 + .../framework/tensor_or_memref.cc | 156 + .../framework/tensor_or_memref.h | 291 + .../mlir_interpreter/framework/tests/BUILD | 22 + .../framework/tests/interpreter_value_test.cc | 241 + .../framework/tests/tensor_or_memref_test.cc | 104 + .../mlir-interpreter-runner.cc | 139 + tensorflow/compiler/xla/mlir_hlo/tosa/BUILD | 66 +- .../compiler/xla/mlir_hlo/tosa/CMakeLists.txt | 12 +- .../xla/mlir_hlo/tosa/mhlo_tosa_opt.cc | 5 +- .../compiler/xla/mlir_hlo/tosa/tests/BUILD | 7 +- .../xla/mlir_hlo/tosa/tests/binary.mlir | 62 +- .../xla/mlir_hlo/tosa/tests/nullary.mlir | 4 +- .../xla/mlir_hlo/tosa/tests/prepare-mhlo.mlir | 52 + .../xla/mlir_hlo/tosa/tests/unary.mlir | 31 +- .../xla/mlir_hlo/tosa/transforms/BUILD | 80 + .../Transforms => transforms}/CMakeLists.txt | 10 +- .../legalize_mhlo}/legalize_mhlo.cc | 174 +- .../legalize_mhlo}/legalize_mhlo.pdll | 2 +- .../Transforms => transforms}/passes.h | 11 +- .../xla/mlir_hlo/tosa/transforms/passes.td | 32 + .../transforms/prepare_mhlo/prepare_mhlo.cc | 57 + .../Transforms => transforms}/CMakeLists.txt | 15 +- .../mlir_hlo/transforms/alloc_to_arg_pass.cc | 111 + .../buffer_packing.cc | 8 +- .../Transforms => transforms}/buffer_reuse.cc | 6 +- .../Transforms => transforms}/bufferize.cc | 31 +- .../bufferize_pass.cc | 48 +- .../collapse_parallel_loops_to_1d_pass.cc | 6 +- .../Transforms => transforms}/copy_removal.cc | 8 +- .../detensorize_scf_ops.cc | 4 +- .../generic_host_to_llvm.cc | 9 +- .../gpu_fusion_rewrite.cc | 367 +- .../gpu_kernel_lowering_passes.cc | 27 +- .../Transforms => transforms}/gpu_passes.h | 10 +- .../Transforms => transforms}/gpu_passes.td | 0 .../transforms/hlo_to_gpu_pipeline.cc | 291 + .../transforms/hlo_to_triton_pipeline.cc | 112 + .../lower_index_cast_pass.cc | 43 +- .../Transforms => transforms}/passes.h | 16 +- .../Transforms => transforms}/passes.td | 35 +- .../propagate_static_shapes_to_kernel.cc | 6 +- .../Transforms => transforms}/rewriters.h | 0 .../tile_loops_pass.cc | 6 +- .../unbufferize_pass.cc | 8 +- .../Transforms => transforms}/unroll_loops.cc | 4 +- .../mlir_hlo/{lib => }/utils/CMakeLists.txt | 0 .../mlir_hlo/{lib => }/utils/codegen_utils.cc | 2 +- .../mlir-hlo => }/utils/codegen_utils.h | 0 .../{lib => }/utils/convert_op_folder.cc | 2 +- .../mlir-hlo => }/utils/convert_op_folder.h | 0 .../{lib => }/utils/cycle_detector.cc | 2 +- .../mlir-hlo => }/utils/cycle_detector.h | 0 .../{lib => }/utils/cycle_detector_test.cc | 2 +- .../xla/mlir_hlo/{lib => }/utils/hlo_utils.cc | 2 +- .../{include/mlir-hlo => }/utils/hlo_utils.h | 0 .../mlir-hlo => }/utils/placement_utils.h | 0 .../compiler/xla/parse_flags_from_env.h | 2 +- tensorflow/compiler/xla/pjrt/BUILD | 133 +- tensorflow/compiler/xla/pjrt/c/BUILD | 32 + tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h | 313 +- .../compiler/xla/pjrt/c/pjrt_c_api_cpu.cc | 51 + .../pjrt_api.h => pjrt/c/pjrt_c_api_cpu.h} | 22 +- .../xla/pjrt/c/pjrt_c_api_cpu_test.cc | 93 + .../compiler/xla/pjrt/c/pjrt_c_api_helpers.cc | 99 + .../compiler/xla/pjrt/c/pjrt_c_api_helpers.h | 44 +- 
.../compiler/xla/pjrt/c/pjrt_c_api_tpu.h | 5 +- .../xla/pjrt/c/pjrt_c_api_wrapper_impl.cc | 422 +- .../xla/pjrt/c/pjrt_c_api_wrapper_impl.h | 168 +- .../compiler/xla/pjrt/compile_options.proto | 14 +- tensorflow/compiler/xla/pjrt/cpu_device.cc | 68 - tensorflow/compiler/xla/pjrt/cpu_device.h | 35 - .../compiler/xla/pjrt/distributed/BUILD | 37 +- .../compiler/xla/pjrt/distributed/client.cc | 74 +- .../compiler/xla/pjrt/distributed/client.h | 26 +- .../pjrt/distributed/client_server_test.cc | 87 + .../xla/pjrt/distributed/protocol.proto | 8 +- .../compiler/xla/pjrt/distributed/service.cc | 52 +- tensorflow/compiler/xla/pjrt/gpu/BUILD | 35 +- .../compiler/xla/pjrt/gpu/gpu_helpers.cc | 2 +- .../xla/pjrt/gpu/se_gpu_pjrt_client.cc | 56 +- .../xla/pjrt/gpu/se_gpu_pjrt_client.h | 5 +- .../xla/pjrt/gpu/se_gpu_pjrt_client_test.cc | 216 + tensorflow/compiler/xla/pjrt/mlir_to_hlo.cc | 20 +- tensorflow/compiler/xla/pjrt/pjrt_api.cc | 102 + tensorflow/compiler/xla/pjrt/pjrt_api.h | 45 + tensorflow/compiler/xla/pjrt/pjrt_api_test.cc | 46 + .../compiler/xla/pjrt/pjrt_c_api_client.cc | 547 +- .../compiler/xla/pjrt/pjrt_c_api_client.h | 135 +- tensorflow/compiler/xla/pjrt/pjrt_client.cc | 43 + tensorflow/compiler/xla/pjrt/pjrt_client.h | 159 +- .../compiler/xla/pjrt/pjrt_client_test.cc | 29 + tensorflow/compiler/xla/pjrt/pjrt_compiler.cc | 10 + tensorflow/compiler/xla/pjrt/pjrt_compiler.h | 8 + .../compiler/xla/pjrt/pjrt_executable.h | 10 +- tensorflow/compiler/xla/pjrt/pjrt_future.h | 4 +- .../xla/pjrt/pjrt_stream_executor_client.cc | 333 +- .../xla/pjrt/pjrt_stream_executor_client.h | 43 +- tensorflow/compiler/xla/pjrt/plugin/BUILD | 42 + .../compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc | 155 +- .../compiler/xla/pjrt/tfrt_cpu_pjrt_client.h | 39 +- .../xla/pjrt/tfrt_cpu_pjrt_client_test.cc | 71 + tensorflow/compiler/xla/pjrt/tpu_client.cc | 42 +- tensorflow/compiler/xla/pjrt/tpu_client.h | 7 +- .../tracked_tfrt_cpu_device_buffer_test.cc | 9 +- tensorflow/compiler/xla/pjrt/utils.cc | 49 +- tensorflow/compiler/xla/pjrt/utils.h | 14 +- tensorflow/compiler/xla/primitive_util.cc | 11 +- tensorflow/compiler/xla/primitive_util.h | 31 +- .../compiler/xla/primitive_util_test.cc | 66 + .../lib/demangle.h => printer.cc} | 25 +- tensorflow/compiler/xla/printer.h | 64 + tensorflow/compiler/xla/python/BUILD | 127 +- tensorflow/compiler/xla/python/callback.h | 13 +- .../xla/python/custom_call_sharding.cc | 2 +- tensorflow/compiler/xla/python/dlpack.cc | 62 +- tensorflow/compiler/xla/python/exceptions.h | 13 +- tensorflow/compiler/xla/python/ifrt/BUILD | 262 + tensorflow/compiler/xla/python/ifrt/README.md | 36 + tensorflow/compiler/xla/python/ifrt/array.cc | 37 + tensorflow/compiler/xla/python/ifrt/array.h | 141 + .../xla/python/ifrt/array_impl_test_lib.cc | 496 ++ .../compiler/xla/python/ifrt/array_test.cc | 86 + .../xla/python/ifrt/client.cc} | 11 +- tensorflow/compiler/xla/python/ifrt/client.h | 147 + .../xla/python/ifrt/client_impl_test_lib.cc | 64 + .../xla/python/ifrt/compiler.cc} | 11 +- .../compiler/xla/python/ifrt/compiler.h | 56 + tensorflow/compiler/xla/python/ifrt/device.h | 65 + tensorflow/compiler/xla/python/ifrt/dtype.cc | 129 + tensorflow/compiler/xla/python/ifrt/dtype.h | 114 + .../compiler/xla/python/ifrt/executable.cc | 25 + .../compiler/xla/python/ifrt/executable.h | 174 + .../python/ifrt/executable_impl_test_lib.cc | 152 + tensorflow/compiler/xla/python/ifrt/future.cc | 65 + tensorflow/compiler/xla/python/ifrt/future.h | 51 + .../compiler/xla/python/ifrt/future_test.cc | 137 + 
tensorflow/compiler/xla/python/ifrt/index.cc | 35 + tensorflow/compiler/xla/python/ifrt/index.h | 99 + .../compiler/xla/python/ifrt/index_domain.cc | 36 + .../compiler/xla/python/ifrt/index_domain.h | 83 + .../xla/python/ifrt/index_domain_test.cc | 57 + .../compiler/xla/python/ifrt/index_test.cc | 64 + .../xla/python/ifrt/no_impl_test_main.cc | 33 + tensorflow/compiler/xla/python/ifrt/shape.cc | 43 + tensorflow/compiler/xla/python/ifrt/shape.h | 64 + .../compiler/xla/python/ifrt/shape_test.cc | 80 + .../compiler/xla/python/ifrt/sharding.cc | 135 + .../compiler/xla/python/ifrt/sharding.h | 158 + .../compiler/xla/python/ifrt/sharding_test.cc | 74 + .../compiler/xla/python/ifrt/test_util.cc | 71 + .../compiler/xla/python/ifrt/test_util.h | 43 + .../ifrt/tuple.cc} | 15 +- tensorflow/compiler/xla/python/ifrt/tuple.h | 64 + .../xla/python/ifrt/tuple_impl_test_lib.cc | 136 + tensorflow/compiler/xla/python/ifrt/value.cc | 24 + tensorflow/compiler/xla/python/ifrt/value.h | 68 + tensorflow/compiler/xla/python/jax_jit.cc | 425 +- tensorflow/compiler/xla/python/jax_jit.h | 19 +- tensorflow/compiler/xla/python/mlir.cc | 132 +- tensorflow/compiler/xla/python/ops.cc | 24 +- .../compiler/xla/python/outfeed_receiver.cc | 37 +- .../compiler/xla/python/outfeed_receiver.h | 6 +- .../xla/python/outfeed_receiver_py.cc | 8 +- .../xla/python/outfeed_receiver_test.cc | 153 +- tensorflow/compiler/xla/python/pjit.cc | 636 +- .../compiler/xla/python/pjrt_ifrt/BUILD | 126 + .../xla/python/pjrt_ifrt/pjrt_array.cc | 329 + .../xla/python/pjrt_ifrt/pjrt_array.h | 151 + .../pjrt_array_impl_test_tfrt_cpu.cc | 32 + .../xla/python/pjrt_ifrt/pjrt_client.cc | 139 + .../xla/python/pjrt_ifrt/pjrt_client.h | 167 + .../xla/python/pjrt_ifrt/pjrt_compiler.cc | 51 + .../xla/python/pjrt_ifrt/pjrt_compiler.h | 54 + .../xla/python/pjrt_ifrt/pjrt_executable.cc | 502 ++ .../xla/python/pjrt_ifrt/pjrt_executable.h | 278 + .../pjrt_executable_impl_test_tfrt_cpu.cc | 29 + .../xla/python/pjrt_ifrt/pjrt_tuple.cc | 101 + .../xla/python/pjrt_ifrt/pjrt_tuple.h | 82 + .../pjrt_ifrt/tfrt_cpu_client_test_lib.cc | 40 + tensorflow/compiler/xla/python/pmap_lib.cc | 370 +- .../xla/python/pprof_profile_builder.cc | 3 +- .../xla/python/pprof_profile_builder.h | 2 +- tensorflow/compiler/xla/python/profiler.cc | 40 +- .../xla/python/profiler/internal/BUILD | 22 +- .../python/profiler/internal/python_hooks.h | 1 - .../profiler/internal/traceme_wrapper.h | 20 +- tensorflow/compiler/xla/python/py_array.cc | 200 +- tensorflow/compiler/xla/python/py_array.h | 61 +- tensorflow/compiler/xla/python/py_buffer.cc | 321 +- tensorflow/compiler/xla/python/py_buffer.h | 149 +- tensorflow/compiler/xla/python/py_client.cc | 272 +- tensorflow/compiler/xla/python/py_client.h | 60 +- .../compiler/xla/python/py_executable.cc | 224 +- .../compiler/xla/python/py_executable.h | 62 +- tensorflow/compiler/xla/python/py_values.cc | 179 +- tensorflow/compiler/xla/python/py_values.h | 26 +- tensorflow/compiler/xla/python/pytree.cc | 91 +- tensorflow/compiler/xla/python/pytree.h | 11 +- .../xla/python/sharded_device_array.cc | 75 +- .../xla/python/sharded_device_array.h | 9 +- tensorflow/compiler/xla/python/sharding.cc | 52 +- tensorflow/compiler/xla/python/sharding.h | 50 +- .../compiler/xla/python/tpu_driver/BUILD | 12 +- .../xla/python/tpu_driver/client/BUILD | 9 +- .../tpu_driver/client/tpu_client_extension.cc | 15 +- tensorflow/compiler/xla/python/traceback.cc | 11 +- tensorflow/compiler/xla/python/traceback.h | 5 + tensorflow/compiler/xla/python/types.cc | 54 +- 
tensorflow/compiler/xla/python/types.h | 2 + .../cpu_device_test.cc => python/util.cc} | 38 +- tensorflow/compiler/xla/python/util.h | 66 + tensorflow/compiler/xla/python/xla.cc | 140 +- tensorflow/compiler/xla/python/xla_client.py | 80 +- tensorflow/compiler/xla/python/xla_client.pyi | 6 +- .../xla_client_backend_independent_test.py | 53 + .../compiler/xla/python/xla_client_test.py | 214 +- .../compiler/xla/python/xla_compiler.cc | 127 +- .../xla/python/xla_extension/__init__.pyi | 56 +- .../xla/python/xla_extension/mlir.pyi | 2 + .../compiler/xla/python/xla_extension/ops.pyi | 4 +- .../python/xla_extension/outfeed_receiver.pyi | 3 +- .../xla/python/xla_extension/pmap_lib.pyi | 1 - tensorflow/compiler/xla/python_api/BUILD | 6 +- tensorflow/compiler/xla/pytype.default.bzl | 14 + tensorflow/compiler/xla/reference_util.cc | 2 +- tensorflow/compiler/xla/rpc/BUILD | 26 +- .../compiler/xla/rpc/grpc_client_test.cc | 7 +- tensorflow/compiler/xla/runlit.cfg.py | 85 + tensorflow/compiler/xla/runlit.site.cfg.py | 73 + tensorflow/compiler/xla/runtime/BUILD | 142 +- tensorflow/compiler/xla/runtime/arguments.cc | 2 +- tensorflow/compiler/xla/runtime/arguments.h | 7 +- .../compiler/xla/runtime/async_runtime.cc | 106 +- .../compiler/xla/runtime/async_runtime.h | 25 +- .../xla/runtime/async_runtime_test.cc | 182 + .../compiler/xla/runtime/custom_call.cc | 51 +- tensorflow/compiler/xla/runtime/custom_call.h | 679 +- .../compiler/xla/runtime/custom_call_test.cc | 951 ++- tensorflow/compiler/xla/runtime/default/BUILD | 3 +- tensorflow/compiler/xla/runtime/executable.cc | 81 +- tensorflow/compiler/xla/runtime/executable.h | 49 +- .../compiler/xla/runtime/executable_test.cc | 448 +- .../compiler/xla/runtime/execution_engine.cc | 101 +- tensorflow/compiler/xla/runtime/ffi.cc | 481 ++ tensorflow/compiler/xla/runtime/ffi.h | 93 + tensorflow/compiler/xla/runtime/ffi/BUILD | 29 + tensorflow/compiler/xla/runtime/ffi/ffi_abi.h | 62 + tensorflow/compiler/xla/runtime/ffi/ffi_api.h | 1137 +++ .../compiler/xla/runtime/ffi/ffi_c_api.h | 269 + tensorflow/compiler/xla/runtime/ffi_test.cc | 400 ++ .../compiler/xla/runtime/jit_executable.cc | 4 +- tensorflow/compiler/xla/runtime/map_by_type.h | 33 +- tensorflow/compiler/xla/runtime/module.h | 185 + .../compiler/xla/runtime/module_registry.cc | 76 + .../compiler/xla/runtime/module_registry.h | 75 + .../compiler/xla/runtime/module_test.cc | 96 + tensorflow/compiler/xla/runtime/results.h | 46 + tensorflow/compiler/xla/runtime/runner/BUILD | 85 + .../compiler/xla/runtime/runner/runner.cc | 342 + .../compiler/xla/runtime/runner/runner.h | 56 + .../compiler/xla/runtime/runner/runner.proto | 59 + .../compiler/xla/runtime/runner/runner.py | 149 + .../xla/runtime/runner/testlib_runner.cc | 37 + .../xla/runtime/runner/testlib_runner_test.py | 85 + tensorflow/compiler/xla/runtime/state.h | 182 + tensorflow/compiler/xla/runtime/state_test.cc | 108 + .../compiler/xla/runtime/symbolic_shape.cc | 10 +- .../compiler/xla/runtime/symbolic_shape.h | 2 +- .../xla/runtime/symbolic_shape_test.cc | 79 +- tensorflow/compiler/xla/runtime/tracing.h | 6 +- tensorflow/compiler/xla/runtime/types.h | 20 +- tensorflow/compiler/xla/service/BUILD | 1328 ++-- .../xla/service/algebraic_simplifier.cc | 90 +- .../xla/service/algebraic_simplifier.h | 10 +- .../xla/service/algebraic_simplifier_test.cc | 171 +- .../service/all_gather_broadcast_reorder.cc | 8 +- .../service/all_gather_broadcast_reorder.h | 2 +- .../xla/service/all_gather_combiner.cc | 8 +- .../xla/service/all_gather_combiner.h | 2 +- 
.../xla/service/all_gather_combiner_test.cc | 8 +- .../xla/service/all_gather_decomposer.cc | 10 +- .../xla/service/all_gather_decomposer.h | 4 +- .../xla/service/all_gather_decomposer_test.cc | 10 +- .../xla/service/all_reduce_combiner.cc | 8 +- .../xla/service/all_reduce_combiner.h | 2 +- .../xla/service/all_reduce_combiner_test.cc | 8 +- .../xla/service/all_reduce_contiguous.cc | 8 +- .../xla/service/all_reduce_contiguous.h | 2 +- .../xla/service/all_reduce_contiguous_test.cc | 6 +- .../compiler/xla/service/all_reduce_folder.cc | 10 +- .../compiler/xla/service/all_reduce_folder.h | 2 +- .../xla/service/all_reduce_folder_test.cc | 6 +- .../compiler/xla/service/all_reduce_key.cc | 4 +- .../compiler/xla/service/all_reduce_key.h | 2 +- .../xla/service/all_reduce_promotion.cc | 69 + .../xla/service/all_reduce_promotion.h | 40 + .../xla/service/all_reduce_promotion_test.cc | 92 + .../xla/service/all_reduce_reassociate.cc | 6 +- .../xla/service/all_reduce_reassociate.h | 2 +- .../service/all_reduce_reassociate_test.cc | 4 +- .../xla/service/all_reduce_simplifier.cc | 6 +- .../xla/service/all_reduce_simplifier.h | 2 +- .../xla/service/all_reduce_simplifier_test.cc | 6 +- .../xla/service/all_to_all_decomposer.cc | 10 +- .../xla/service/all_to_all_decomposer.h | 6 +- .../compiler/xla/service/ar_crs_combiner.cc | 8 +- .../compiler/xla/service/ar_crs_combiner.h | 2 +- .../xla/service/async_collective_creator.cc | 14 +- .../service/async_collective_creator_test.cc | 6 +- .../xla/service/async_op_canonicalizer.h | 2 +- .../xla/service/batch_dot_simplification.cc | 2 +- .../xla/service/batch_dot_simplification.h | 2 +- .../xla/service/batchnorm_expander.cc | 8 +- .../compiler/xla/service/batchnorm_expander.h | 2 +- .../xla/service/batchnorm_expander_test.cc | 6 +- .../service/bfloat16_conversion_folding.cc | 8 +- .../xla/service/bfloat16_conversion_folding.h | 2 +- .../bfloat16_conversion_folding_test.cc | 9 +- .../xla/service/bfloat16_normalization.cc | 8 +- .../xla/service/bfloat16_normalization.h | 2 +- .../service/bfloat16_normalization_test.cc | 9 +- .../xla/service/bfloat16_propagation.cc | 8 +- .../xla/service/bfloat16_propagation.h | 4 +- .../xla/service/bfloat16_propagation_test.cc | 8 +- .../compiler/xla/service/bfloat16_support.cc | 7 +- .../compiler/xla/service/bfloat16_support.h | 4 +- .../xla/service/bitcast_dtypes_expander.cc | 12 +- .../xla/service/bitcast_dtypes_expander.h | 4 +- .../compiler/xla/service/buffer_assignment.cc | 131 +- .../compiler/xla/service/buffer_assignment.h | 34 +- .../xla/service/buffer_assignment_test.cc | 247 +- .../compiler/xla/service/buffer_value.cc | 9 +- .../compiler/xla/service/buffer_value.h | 2 +- tensorflow/compiler/xla/service/call_graph.h | 6 +- .../compiler/xla/service/call_graph_test.cc | 2 +- .../compiler/xla/service/call_inliner.cc | 6 +- .../compiler/xla/service/call_inliner_test.cc | 6 +- .../xla/service/change_op_data_type.cc | 45 +- .../xla/service/change_op_data_type.h | 33 +- .../xla/service/change_op_data_type_test.cc | 59 +- .../compiler/xla/service/channel_tracker.cc | 4 +- .../compiler/xla/service/channel_tracker.h | 2 +- .../xla/service/collective_combiner_utils.h | 6 +- .../service/collective_decomposer_utils.cc | 8 +- .../xla/service/collective_decomposer_utils.h | 2 +- .../xla/service/collective_ops_utils.cc | 4 +- .../xla/service/collective_ops_utils.h | 4 +- .../collectives_schedule_linearizer.cc | 66 +- .../service/collectives_schedule_linearizer.h | 2 +- .../collectives_schedule_linearizer_test.cc | 52 +- 
.../xla/service/comparison_expander.cc | 10 +- .../xla/service/comparison_expander.h | 2 +- .../xla/service/compilation_environments.cc | 145 +- .../xla/service/compilation_environments.h | 81 +- .../service/compilation_environments_test.cc | 78 +- .../compiler/xla/service/compilation_stats.cc | 6 + .../compiler/xla/service/compilation_stats.h | 4 + .../compiler/xla/service/compile_time_cap.h | 2 +- tensorflow/compiler/xla/service/compiler.cc | 21 +- tensorflow/compiler/xla/service/compiler.h | 33 +- .../xla/service/computation_layout.cc | 23 +- .../compiler/xla/service/computation_layout.h | 4 + .../xla/service/conditional_canonicalizer.cc | 4 +- .../xla/service/conditional_canonicalizer.h | 2 +- .../service/conditional_canonicalizer_test.cc | 8 +- .../xla/service/conditional_code_motion.cc | 27 +- .../xla/service/conditional_code_motion.h | 2 +- .../service/conditional_code_motion_test.cc | 6 +- .../xla/service/conditional_simplifier.cc | 10 +- .../xla/service/conditional_simplifier.h | 2 +- .../service/conditional_simplifier_test.cc | 7 +- .../xla/service/conditional_to_select.cc | 6 +- .../xla/service/conditional_to_select.h | 2 +- .../xla/service/conditional_to_select_test.cc | 6 +- .../compiler/xla/service/convert_mover.cc | 2 +- .../xla/service/convert_operand_folding.h | 2 +- .../xla/service/convolution_4d_expander.cc | 2 +- .../xla/service/convolution_4d_expander.h | 2 +- .../service/convolution_4d_expander_test.cc | 6 +- .../service/convolution_group_converter.cc | 8 +- .../xla/service/convolution_group_converter.h | 2 +- .../convolution_group_converter_test.cc | 6 +- .../xla/service/convolution_pred_expander.cc | 4 +- .../xla/service/convolution_pred_expander.h | 2 +- .../compiler/xla/service/copy_insertion.cc | 36 +- .../compiler/xla/service/copy_insertion.h | 6 +- .../xla/service/copy_insertion_test.cc | 86 +- tensorflow/compiler/xla/service/cpu/BUILD | 118 +- .../compiler/xla/service/cpu/build_defs.bzl | 4 +- .../xla/service/cpu/compiler_functor.cc | 2 +- .../xla/service/cpu/conv_canonicalization.cc | 6 +- .../xla/service/cpu/conv_canonicalization.h | 2 +- .../service/cpu/conv_canonicalization_test.cc | 9 +- .../compiler/xla/service/cpu/cpu_compiler.cc | 332 +- .../compiler/xla/service/cpu/cpu_compiler.h | 21 +- .../xla/service/cpu/cpu_executable.cc | 42 +- .../compiler/xla/service/cpu/cpu_executable.h | 37 +- .../xla/service/cpu/cpu_instruction_fusion.cc | 2 +- .../xla/service/cpu/cpu_instruction_fusion.h | 2 +- .../service/cpu/cpu_layout_assignment_test.cc | 8 +- .../compiler/xla/service/cpu/cpu_runtime.cc | 647 +- .../compiler/xla/service/cpu/cpu_runtime.h | 2 - .../xla/service/cpu/cpu_transfer_manager.cc | 2 +- .../xla/service/cpu/cpu_transfer_manager.h | 3 +- .../compiler/xla/service/cpu/cpu_xfeed.cc | 13 +- .../compiler/xla/service/cpu/cpu_xfeed.h | 3 +- .../xla/service/cpu/dot_op_emitter.cc | 12 +- .../compiler/xla/service/cpu/dot_op_emitter.h | 2 +- .../xla/service/cpu/dot_op_emitter_internal.h | 4 +- .../xla/service/cpu/elemental_ir_emitter.cc | 8 +- .../xla/service/cpu/elemental_ir_emitter.h | 2 +- .../service/cpu/hlo_xla_runtime_pipeline.cc | 124 +- .../service/cpu/hlo_xla_runtime_pipeline.h | 8 +- .../xla/service/cpu/ir_emission_utils.cc | 2 +- .../xla/service/cpu/ir_emission_utils.h | 2 +- .../compiler/xla/service/cpu/ir_emitter.cc | 28 +- .../compiler/xla/service/cpu/ir_emitter.h | 14 +- .../compiler/xla/service/cpu/mlir_emitter.cc | 2 +- .../service/cpu/parallel_task_assignment.cc | 6 +- .../service/cpu/parallel_task_assignment.h | 2 +- 
.../compiler/xla/service/cpu/runtime/BUILD | 65 + .../xla/service/cpu/runtime/collectives.cc | 356 + .../xla/service/cpu/runtime/collectives.h | 29 + .../xla/service/cpu/runtime/custom_call.cc | 7 +- .../xla/service/cpu/runtime/fft_call.cc | 102 + .../xla/service/cpu/runtime/fft_call.h | 28 + .../compiler/xla/service/cpu/runtime/rng.cc | 226 + .../compiler/xla/service/cpu/runtime/rng.h | 29 + .../compiler/xla/service/cpu/runtime/xfeed.cc | 184 + .../compiler/xla/service/cpu/runtime/xfeed.h | 29 + .../xla/service/cpu/runtime_conv_impl.h | 4 +- .../xla/service/cpu/runtime_matmul.cc | 2 +- .../cpu/runtime_single_threaded_matmul.cc | 2 +- .../xla/service/cpu/simple_orc_jit.cc | 7 +- .../compiler/xla/service/cpu/simple_orc_jit.h | 2 + .../compiler/xla/service/cpu/tests/BUILD | 71 +- .../cpu/tests/cpu_eigen_dot_operation_test.cc | 2 +- .../cpu/tests/cpu_external_constants_test.cc | 6 +- .../xla/service/cpu/tests/cpu_fusion_test.cc | 8 +- .../service/cpu/tests/cpu_intrinsic_test.cc | 2 +- .../xla/service/cpu/tests/cpu_noalias_test.cc | 6 +- .../service/cpu/tests/cpu_profiling_test.cc | 2 +- .../cpu/tests/cpu_vectorization_test.cc | 2 +- .../cpu/tests/tree_reduction_rewriter_test.cc | 4 +- .../xla/service/cpu/tiled_dot_emitter.cc | 2 +- ...ed_reduce_with_no_vector_registers_test.cc | 2 +- .../xla/service/custom_call_sharding_helper.h | 4 +- tensorflow/compiler/xla/service/defuser.cc | 6 +- tensorflow/compiler/xla/service/defuser.h | 2 +- .../compiler/xla/service/despecializer.h | 6 +- .../xla/service/despecializer_test.cc | 2 +- .../compiler/xla/service/dfs_hlo_visitor.h | 33 - .../service/dfs_hlo_visitor_with_default.h | 33 - .../dfs_hlo_visitor_with_default_test.cc | 10 +- .../xla/service/dot_as_convolution_util.cc | 4 +- .../xla/service/dot_as_convolution_util.h | 2 +- .../compiler/xla/service/dot_decomposer.cc | 6 +- .../compiler/xla/service/dot_decomposer.h | 2 +- tensorflow/compiler/xla/service/dot_merger.cc | 10 +- tensorflow/compiler/xla/service/dump.cc | 27 +- tensorflow/compiler/xla/service/dump.h | 7 +- .../service/dynamic_dimension_inference.cc | 33 +- .../xla/service/dynamic_dimension_inference.h | 4 +- .../dynamic_dimension_inference_test.cc | 12 +- .../service/dynamic_dimension_simplifier.cc | 4 +- .../service/dynamic_dimension_simplifier.h | 2 +- .../dynamic_dimension_simplifier_test.cc | 10 +- .../xla/service/dynamic_index_splitter.cc | 10 +- .../xla/service/dynamic_index_splitter.h | 2 +- .../service/dynamic_index_splitter_test.cc | 6 +- .../compiler/xla/service/dynamic_padder.cc | 14 +- .../xla/service/dynamic_padder_test.cc | 34 +- .../xla/service/dynamic_parameter_binding.h | 30 - .../service/dynamic_parameter_binding_test.cc | 8 +- .../xla/service/dynamic_window_utils.cc | 6 +- .../xla/service/dynamic_window_utils.h | 4 +- .../xla/service/elemental_ir_emitter.cc | 492 +- .../xla/service/elemental_ir_emitter.h | 13 +- .../xla/service/elemental_ir_emitter_test.cc | 11 + tensorflow/compiler/xla/service/executable.h | 2 +- .../xla/service/flatten_call_graph.cc | 6 +- .../xla/service/flatten_call_graph_test.cc | 2 +- .../fusion_node_indexing_evaluation.cc | 4 +- .../service/fusion_node_indexing_evaluation.h | 2 +- .../fusion_node_indexing_evaluation_test.cc | 4 +- .../compiler/xla/service/fusion_queue.h | 2 +- .../compiler/xla/service/g3doc/gpu_backend.md | 6 +- .../compiler/xla/service/gather_expander.cc | 2 +- .../xla/service/gather_scatter_utils.h | 2 +- .../compiler/xla/service/gather_simplifier.cc | 4 +- .../xla/service/generic_transfer_manager.cc | 66 +- 
.../xla/service/generic_transfer_manager.h | 5 +- tensorflow/compiler/xla/service/gpu/BUILD | 711 +- tensorflow/compiler/xla/service/gpu/README.md | 3 + .../service/gpu/alias_passthrough_params.cc | 4 +- .../service/gpu/alias_passthrough_params.h | 2 +- .../xla/service/gpu/all_reduce_blueconnect.cc | 8 +- .../xla/service/gpu/all_reduce_blueconnect.h | 2 +- .../gpu/all_reduce_blueconnect_test.cc | 6 +- .../xla/service/gpu/amdgpu_compiler.cc | 19 +- .../xla/service/gpu/amdgpu_compiler.h | 12 +- .../xla/service/gpu/backend_configs.proto | 6 +- .../xla/service/gpu/cholesky_thunk.cc | 98 +- .../compiler/xla/service/gpu/cholesky_thunk.h | 2 +- .../xla/service/gpu/conditional_thunk.cc | 2 +- .../xla/service/gpu/conditional_thunk.h | 2 +- .../service/gpu/conv_layout_normalization.cc | 4 +- .../service/gpu/conv_layout_normalization.h | 4 +- .../gpu/conv_layout_normalization_test.cc | 6 +- .../xla/service/gpu/convolution_thunk.cc | 2 +- .../xla/service/gpu/convolution_thunk.h | 4 +- .../compiler/xla/service/gpu/copy_thunk.h | 2 +- .../compiler/xla/service/gpu/cublas_cudnn.cc | 22 +- .../compiler/xla/service/gpu/cublas_cudnn.h | 23 +- .../xla/service/gpu/cublas_lt_matmul_thunk.cc | 38 +- .../xla/service/gpu/cublas_lt_matmul_thunk.h | 14 +- .../xla/service/gpu/cublas_pad_for_gemms.cc | 6 +- .../service/gpu/cudnn_fused_conv_rewriter.cc | 206 +- .../service/gpu/cudnn_fused_conv_rewriter.h | 8 +- .../gpu/cudnn_fused_conv_rewriter_test.cc | 505 +- .../service/gpu/cudnn_pad_for_convolutions.cc | 2 +- .../service/gpu/cudnn_pad_for_convolutions.h | 2 +- .../gpu/cudnn_pad_for_convolutions_test.cc | 48 +- .../xla/service/gpu/cudnn_simplify_padding.cc | 61 +- .../gpu/cudnn_simplify_padding_test.cc | 141 +- .../xla/service/gpu/cudnn_support_utils.h | 2 +- .../service/gpu/cudnn_support_utils_test.cc | 6 +- .../gpu/cudnn_vectorize_convolutions.cc | 4 +- .../gpu/cudnn_vectorize_convolutions_test.cc | 20 +- .../xla/service/gpu/cusolver_context.cc | 119 +- .../xla/service/gpu/cusolver_context.h | 49 +- .../xla/service/gpu/cusolver_rewriter.cc | 11 +- .../xla/service/gpu/cusolver_rewriter.h | 4 +- .../xla/service/gpu/custom_call_test.cc | 183 +- .../xla/service/gpu/dot_dimension_sorter.cc | 135 + .../xla/service/gpu/dot_dimension_sorter.h | 49 + .../service/gpu/dot_dimension_sorter_test.cc | 160 + .../xla/service/gpu/elemental_ir_emitter.cc | 36 +- .../xla/service/gpu/elemental_ir_emitter.h | 19 +- .../compiler/xla/service/gpu/executable.proto | 9 + .../compiler/xla/service/gpu/fft_thunk.h | 2 +- .../compiler/xla/service/gpu/for_thunk.h | 2 +- .../compiler/xla/service/gpu/fusion_merger.cc | 297 +- .../compiler/xla/service/gpu/fusion_merger.h | 14 +- .../xla/service/gpu/fusion_merger_test.cc | 592 +- .../xla/service/gpu/gemm_algorithm_picker.cc | 367 +- .../xla/service/gpu/gemm_algorithm_picker.h | 80 +- .../service/gpu/gemm_algorithm_picker_test.cc | 168 + .../gpu/gemm_broadcast_folding_rewriter.cc | 13 +- .../gpu/gemm_broadcast_folding_rewriter.h | 4 +- .../compiler/xla/service/gpu/gemm_rewriter.cc | 910 ++- .../compiler/xla/service/gpu/gemm_rewriter.h | 4 +- .../service/gpu/gpu_aot_compilation_test.cc | 53 + .../xla/service/gpu/gpu_autotuning.proto | 2 +- .../compiler/xla/service/gpu/gpu_compiler.cc | 613 +- .../compiler/xla/service/gpu/gpu_compiler.h | 86 +- .../service/gpu/gpu_conv_algorithm_picker.cc | 193 +- .../service/gpu/gpu_conv_algorithm_picker.h | 48 +- .../gpu/gpu_conv_algorithm_picker_test.cc | 100 + .../gpu/gpu_conv_padding_legalization.cc | 2 +- .../gpu/gpu_conv_padding_legalization_test.cc | 20 +- 
.../xla/service/gpu/gpu_conv_rewriter.cc | 6 +- .../xla/service/gpu/gpu_conv_rewriter.h | 2 +- .../xla/service/gpu/gpu_conv_rewriter_test.cc | 106 +- .../xla/service/gpu/gpu_conv_runner.cc | 57 +- .../xla/service/gpu/gpu_conv_runner.h | 7 +- .../xla/service/gpu/gpu_device_info.cc | 46 +- .../xla/service/gpu/gpu_device_info.h | 43 +- ..._helper.h => gpu_device_info_for_tests.cc} | 35 +- .../service/gpu/gpu_device_info_for_tests.h | 32 + .../gpu/gpu_device_info_test.cc} | 44 +- .../xla/service/gpu/gpu_executable.cc | 217 +- .../compiler/xla/service/gpu/gpu_executable.h | 8 +- .../compiler/xla/service/gpu/gpu_fusible.cc | 118 +- .../compiler/xla/service/gpu/gpu_fusible.h | 16 +- .../xla/service/gpu/gpu_fusible_test.cc | 2 +- .../xla/service/gpu/gpu_hlo_cost_analysis.cc | 212 +- .../xla/service/gpu/gpu_hlo_cost_analysis.h | 33 +- .../service/gpu/gpu_hlo_cost_analysis_test.cc | 339 +- .../xla/service/gpu/gpu_hlo_schedule.cc | 100 +- .../xla/service/gpu/gpu_hlo_schedule.h | 11 +- .../xla/service/gpu/gpu_hlo_schedule_test.cc | 94 +- .../xla/service/gpu/gpu_layout_assignment.cc | 15 +- .../xla/service/gpu/gpu_layout_assignment.h | 2 +- .../service/gpu/gpu_layout_assignment_test.cc | 45 +- .../xla/service/gpu/gpu_performance_model.cc | 214 + .../xla/service/gpu/gpu_performance_model.h | 45 + .../service/gpu/gpu_performance_model_test.cc | 850 +++ .../service/gpu/gpu_reduce_scatter_creator.cc | 10 +- .../gpu/gpu_sanitize_constant_names.cc | 4 +- .../service/gpu/gpu_sanitize_constant_names.h | 2 +- .../xla/service/gpu/gpu_scatter_expander.cc | 6 +- .../xla/service/gpu/gpu_transfer_manager.cc | 8 +- .../xla/service/gpu/gpu_transfer_manager.h | 3 +- .../xla/service/gpu/hlo_algorithm_denylist.h | 2 +- .../gpu/hlo_algorithm_denylist_test.cc | 8 +- .../xla/service/gpu/hlo_fusion_stats.cc | 9 +- .../xla/service/gpu/hlo_fusion_stats.h | 4 +- .../xla/service/gpu/hlo_to_ir_bindings.cc | 2 +- .../xla/service/gpu/hlo_to_ir_bindings.h | 2 +- .../service/gpu/horizontal_input_fusion.cc | 14 +- .../xla/service/gpu/horizontal_input_fusion.h | 10 +- .../gpu/horizontal_input_fusion_test.cc | 17 +- .../xla/service/gpu/horizontal_loop_fusion.cc | 378 +- .../xla/service/gpu/horizontal_loop_fusion.h | 41 +- .../gpu/horizontal_loop_fusion_test.cc | 230 +- .../compiler/xla/service/gpu/infeed_thunk.h | 2 +- .../xla/service/gpu/instruction_fusion.cc | 27 +- .../xla/service/gpu/instruction_fusion.h | 20 +- .../service/gpu/instruction_fusion_test.cc | 168 +- .../xla/service/gpu/ir_emission_utils.cc | 43 +- .../xla/service/gpu/ir_emission_utils.h | 28 +- .../xla/service/gpu/ir_emission_utils_test.cc | 2 +- .../compiler/xla/service/gpu/ir_emitter.cc | 9 +- .../compiler/xla/service/gpu/ir_emitter.h | 6 +- .../xla/service/gpu/ir_emitter_nested.cc | 8 +- .../xla/service/gpu/ir_emitter_unnested.cc | 639 +- .../xla/service/gpu/ir_emitter_unnested.h | 43 +- .../xla/service/gpu/jitrt_custom_calls.cc | 1557 ----- .../xla/service/gpu/jitrt_custom_calls.h | 117 - .../xla/service/gpu/kernel_mapping_scheme.h | 8 +- .../compiler/xla/service/gpu/kernel_thunk.cc | 3 +- .../compiler/xla/service/gpu/kernel_thunk.h | 2 +- .../xla/service/gpu/launch_dimensions.h | 9 + .../xla/service/gpu/llvm_gpu_backend/BUILD | 8 +- .../gpu/llvm_gpu_backend/dump_ir_pass.cc | 104 - .../gpu/llvm_gpu_backend/dump_ir_pass.h | 51 - .../gpu/llvm_gpu_backend/gpu_backend_lib.cc | 292 +- .../gpu/llvm_gpu_backend/gpu_backend_lib.h | 5 +- .../gpu/llvm_gpu_backend/utils_test.cc | 8 +- .../compiler/xla/service/gpu/matmul_utils.cc | 391 +- 
.../compiler/xla/service/gpu/matmul_utils.h | 83 +- .../compiler/xla/service/gpu/memset_thunk.h | 2 +- .../xla/service/gpu/move_copy_to_users.cc | 4 +- .../xla/service/gpu/move_copy_to_users.h | 2 +- .../xla/service/gpu/multi_output_fusion.cc | 40 +- .../xla/service/gpu/multi_output_fusion.h | 14 +- .../service/gpu/multi_output_fusion_test.cc | 288 +- .../xla/service/gpu/nccl_all_gather_thunk.h | 4 +- .../xla/service/gpu/nccl_all_reduce_thunk.cc | 84 +- .../xla/service/gpu/nccl_all_reduce_thunk.h | 32 +- .../xla/service/gpu/nccl_all_to_all_thunk.cc | 3 +- .../xla/service/gpu/nccl_all_to_all_thunk.h | 4 +- .../gpu/nccl_collective_permute_thunk.cc | 136 +- .../gpu/nccl_collective_permute_thunk.h | 72 +- .../xla/service/gpu/nccl_collective_thunk.cc | 42 + .../xla/service/gpu/nccl_collective_thunk.h | 30 + .../compiler/xla/service/gpu/nccl_utils.cc | 56 +- .../compiler/xla/service/gpu/nccl_utils.h | 6 +- .../xla/service/gpu/nvptx_compiler.cc | 198 +- .../compiler/xla/service/gpu/nvptx_compiler.h | 18 +- .../xla/service/gpu/nvptx_compiler_test.cc | 22 +- .../compiler/xla/service/gpu/nvptx_helper.cc | 68 - .../compiler/xla/service/gpu/outfeed_thunk.cc | 2 +- .../compiler/xla/service/gpu/outfeed_thunk.h | 2 +- .../gpu/reduction_degenerate_dim_remover.cc | 9 +- .../gpu/reduction_degenerate_dim_remover.h | 4 +- .../gpu/reduction_dimension_grouper.cc | 8 +- .../service/gpu/reduction_dimension_grouper.h | 4 +- .../gpu/reduction_layout_normalizer.cc | 9 +- .../service/gpu/reduction_layout_normalizer.h | 4 +- .../xla/service/gpu/reduction_splitter.cc | 4 +- .../xla/service/gpu/reduction_splitter.h | 2 +- .../compiler/xla/service/gpu/runtime/BUILD | 201 +- .../xla/service/gpu/runtime/cholesky.cc | 87 + .../xla/service/gpu/runtime/cholesky.h | 30 + .../xla/service/gpu/runtime/collectives.cc | 635 ++ .../xla/service/gpu/runtime/collectives.h | 75 + .../compiler/xla/service/gpu/runtime/conv.cc | 479 +- .../compiler/xla/service/gpu/runtime/conv.h | 103 +- .../service/gpu/runtime/cublas_lt_matmul.cc | 260 +- .../service/gpu/runtime/cublas_lt_matmul.h | 43 +- .../xla/service/gpu/runtime/custom_call.cc | 152 + .../xla/service/gpu/runtime/custom_call.h | 29 + .../xla/service/gpu/runtime/executable.cc | 250 +- .../xla/service/gpu/runtime/executable.h | 60 +- .../compiler/xla/service/gpu/runtime/fft.cc | 108 +- .../compiler/xla/service/gpu/runtime/fft.h | 22 +- .../compiler/xla/service/gpu/runtime/gemm.cc | 160 + .../compiler/xla/service/gpu/runtime/gemm.h | 36 + .../xla/service/gpu/runtime/graph_launch.cc | 248 +- .../xla/service/gpu/runtime/graph_launch.h | 53 + .../xla/service/gpu/runtime/io_feed.cc | 77 +- .../xla/service/gpu/runtime/kernel_launch.cc | 134 +- .../xla/service/gpu/runtime/kernel_launch.h | 22 +- .../xla/service/gpu/runtime/memcpy.cc | 90 + .../compiler/xla/service/gpu/runtime/memcpy.h | 30 + .../xla/service/gpu/runtime/memset.cc | 146 + .../compiler/xla/service/gpu/runtime/memset.h | 30 + .../xla/service/gpu/runtime/send_recv.cc | 310 + .../xla/service/gpu/runtime/send_recv.h | 56 + .../xla/service/gpu/runtime/support.h | 57 +- .../xla/service/gpu/runtime/tracing.cc | 73 +- .../service/gpu/runtime/triangular_solve.cc | 141 + .../service/gpu/runtime/triangular_solve.h | 52 + .../service/gpu/scatter_slice_simplifier.cc | 8 +- .../xla/service/gpu/sequential_thunk.h | 2 +- .../xla/service/gpu/softmax_fusion.cc | 343 +- .../xla/service/gpu/softmax_fusion_test.cc | 918 ++- .../xla/service/gpu/stream_executor_util.cc | 41 +- .../xla/service/gpu/stream_executor_util.h | 10 +- 
.../compiler/xla/service/gpu/target_util.cc | 23 +- .../compiler/xla/service/gpu/target_util.h | 5 +- .../compiler/xla/service/gpu/tests/BUILD | 228 +- .../gpu/tests/dynamic_shared_memory_test.cc | 108 + .../gpu/tests/elemental_ir_emitter_test.cc | 65 + .../service/gpu/tests/gemm_rewrite_test.cc | 1620 ++++- .../xla/service/gpu/tests/gpu_atomic_test.cc | 2 +- .../xla/service/gpu/tests/gpu_codegen_test.cc | 1 - .../tests/gpu_compilation_parallelism_test.cc | 2 +- .../service/gpu/tests/gpu_copy_alone_test.cc | 2 +- .../xla/service/gpu/tests/gpu_copy_test.cc | 8 +- .../service/gpu/tests/gpu_dyn_shape_test.cc | 2 +- .../gpu/tests/gpu_fusion_pipeline_test.cc | 15 +- .../xla/service/gpu/tests/gpu_fusion_test.cc | 13 +- .../xla/service/gpu/tests/gpu_index_test.cc | 8 +- .../gpu/tests/gpu_kernel_tiling_test.cc | 79 +- .../xla/service/gpu/tests/gpu_ldg_test.cc | 6 +- .../xla/service/gpu/tests/gpu_noalias_test.cc | 6 +- .../tests/gpu_reduce_scatter_creator_test.cc | 8 +- .../service/gpu/tests/gpu_unrolling_test.cc | 29 +- .../xla/service/gpu/tests/hlo_to_llvm_ir.cc | 12 +- .../service/gpu/tests/launch_dimensions.hlo | 2 +- .../xla/service/gpu/tests/mlir_gemm_test.cc | 50 +- .../service/gpu/tests/mlir_gpu_test_base.cc | 5 +- .../compiler/xla/service/gpu/tests/mnist.py | 86 - .../reduction_degenerate_dim_remover_test.cc | 3 +- .../tests/reduction_dimension_grouper_test.cc | 2 +- .../gpu/tests/reduction_emitter_test.cc | 59 + .../tests/reduction_layout_normalizer_test.cc | 2 +- .../gpu/tests/reduction_vectorization_test.cc | 15 +- .../xla/service/gpu/tests/sorting_test.cc | 3 +- .../gpu/tests/swap_conv_operands_test.cc | 20 +- .../gpu/tests/tree_reduction_rewriter_test.cc | 3 +- tensorflow/compiler/xla/service/gpu/thunk.cc | 13 +- tensorflow/compiler/xla/service/gpu/thunk.h | 4 +- .../service/gpu/tree_reduction_rewriter.cc | 12 +- .../xla/service/gpu/tree_reduction_rewriter.h | 2 +- .../service/gpu/triangular_solve_rewriter.h | 2 +- .../xla/service/gpu/triangular_solve_thunk.cc | 4 +- .../xla/service/gpu/triangular_solve_thunk.h | 4 +- .../xla/service/gpu/variadic_op_splitter.cc | 6 +- .../service/gpu/variadic_op_splitter_test.cc | 6 +- .../compiler/xla/service/gpu/while_thunk.h | 2 +- .../compiler/xla/service/graphcycles/BUILD | 10 +- .../compiler/xla/service/heap_simulator.cc | 2 +- .../compiler/xla/service/heap_simulator.h | 15 +- .../xla/service/heap_simulator_test.cc | 6 +- tensorflow/compiler/xla/service/hlo.proto | 50 +- .../xla/service/hlo_activation_analysis.cc | 127 + .../xla/service/hlo_activation_analysis.h} | 19 +- .../service/hlo_activation_analysis_test.cc | 298 + .../xla/service/hlo_alias_analysis.cc | 4 +- .../compiler/xla/service/hlo_alias_analysis.h | 4 +- .../xla/service/hlo_alias_analysis_test.cc | 2 +- tensorflow/compiler/xla/service/hlo_buffer.cc | 2 +- .../xla/service/hlo_casting_utils_test.cc | 4 +- .../compiler/xla/service/hlo_computation.h | 46 - .../service/hlo_computation_deduplicator.cc | 4 +- .../hlo_computation_deduplicator_test.cc | 6 +- .../xla/service/hlo_computation_test.cc | 8 +- .../xla/service/hlo_constant_folding.cc | 8 +- .../xla/service/hlo_constant_folding.h | 2 +- .../xla/service/hlo_constant_folding_test.cc | 6 +- .../compiler/xla/service/hlo_cost_analysis.cc | 382 +- .../compiler/xla/service/hlo_cost_analysis.h | 391 +- .../xla/service/hlo_cost_analysis_test.cc | 125 +- .../xla/service/hlo_creation_utils.cc | 6 +- .../compiler/xla/service/hlo_creation_utils.h | 4 +- .../xla/service/hlo_creation_utils_test.cc | 4 +- tensorflow/compiler/xla/service/hlo_cse.cc | 25 +- 
tensorflow/compiler/xla/service/hlo_cse.h | 2 +- .../compiler/xla/service/hlo_cse_test.cc | 83 +- .../xla/service/hlo_dataflow_analysis.cc | 43 +- .../xla/service/hlo_dataflow_analysis.h | 4 +- .../xla/service/hlo_dataflow_analysis_test.cc | 58 +- tensorflow/compiler/xla/service/hlo_dce.cc | 12 +- tensorflow/compiler/xla/service/hlo_dce.h | 6 +- .../compiler/xla/service/hlo_dce_test.cc | 12 +- .../xla/service/hlo_domain_isolator.cc | 8 +- .../xla/service/hlo_domain_isolator.h | 4 +- .../compiler/xla/service/hlo_domain_map.cc | 2 +- .../compiler/xla/service/hlo_domain_map.h | 8 +- .../xla/service/hlo_domain_remover.cc | 6 +- .../compiler/xla/service/hlo_domain_remover.h | 6 +- .../compiler/xla/service/hlo_domain_test.cc | 39 +- .../xla/service/hlo_domain_verifier.cc | 6 +- .../xla/service/hlo_domain_verifier.h | 4 +- .../xla/service/hlo_element_type_converter.cc | 8 +- .../xla/service/hlo_element_type_converter.h | 2 +- .../xla/service/hlo_execution_profile.cc | 4 +- .../compiler/xla/service/hlo_graph_dumper.cc | 13 +- .../compiler/xla/service/hlo_graph_dumper.h | 2 +- .../xla/service/hlo_graph_dumper_test.cc | 8 +- .../service/hlo_input_output_alias_config.h | 28 - .../hlo_input_output_alias_config_test.cc | 8 +- .../compiler/xla/service/hlo_instruction.h | 64 - .../xla/service/hlo_instruction_test.cc | 20 +- .../compiler/xla/service/hlo_instructions.h | 38 - .../compiler/xla/service/hlo_live_range.cc | 26 +- .../compiler/xla/service/hlo_live_range.h | 6 +- .../xla/service/hlo_live_range_test.cc | 77 +- .../xla/service/hlo_liveness_analysis.cc | 8 +- .../xla/service/hlo_liveness_analysis.h | 4 +- .../xla/service/hlo_liveness_analysis_test.cc | 4 +- .../compiler/xla/service/hlo_matchers.cc | 6 +- .../compiler/xla/service/hlo_matchers.h | 7 +- .../xla/service/hlo_memory_scheduler.cc | 8 +- .../xla/service/hlo_memory_scheduler.h | 6 +- .../xla/service/hlo_memory_scheduler_test.cc | 8 +- tensorflow/compiler/xla/service/hlo_module.h | 49 - .../compiler/xla/service/hlo_module_config.cc | 301 +- .../compiler/xla/service/hlo_module_config.h | 87 +- .../xla/service/hlo_module_config_test.cc | 72 + .../compiler/xla/service/hlo_module_dce.cc | 8 +- .../compiler/xla/service/hlo_module_dce.h | 2 +- .../xla/service/hlo_module_dce_test.cc | 8 +- .../xla/service/hlo_module_group_metadata.cc | 6 +- .../xla/service/hlo_module_group_metadata.h | 6 +- .../xla/service/hlo_module_group_test.cc | 2 +- .../xla/service/hlo_module_group_util.cc | 6 +- .../xla/service/hlo_module_group_util.h | 4 +- .../xla/service/hlo_module_metadata.h | 28 - .../xla/service/hlo_module_metadata_test.cc | 3 +- .../compiler/xla/service/hlo_module_test.cc | 210 +- .../compiler/xla/service/hlo_module_util.cc | 25 +- .../compiler/xla/service/hlo_op_metadata.h | 24 - .../compiler/xla/service/hlo_opcode_test.cc | 2 +- .../compiler/xla/service/hlo_ordering.cc | 15 +- .../compiler/xla/service/hlo_ordering.h | 6 +- .../compiler/xla/service/hlo_ordering_test.cc | 55 +- tensorflow/compiler/xla/service/hlo_parser.cc | 326 +- tensorflow/compiler/xla/service/hlo_parser.h | 6 +- .../compiler/xla/service/hlo_parser_test.cc | 252 +- .../compiler/xla/service/hlo_pass_fix.h | 4 +- .../compiler/xla/service/hlo_pass_interface.h | 4 +- .../compiler/xla/service/hlo_pass_pipeline.cc | 27 +- .../compiler/xla/service/hlo_pass_pipeline.h | 2 +- .../xla/service/hlo_pass_pipeline_test.cc | 6 +- .../compiler/xla/service/hlo_phi_graph.h | 4 +- .../compiler/xla/service/hlo_proto_util.h | 2 +- .../xla/service/hlo_proto_util_test.cc | 2 +- 
tensorflow/compiler/xla/service/hlo_query.cc | 6 +- tensorflow/compiler/xla/service/hlo_query.h | 8 +- .../compiler/xla/service/hlo_reachability.cc | 122 +- .../compiler/xla/service/hlo_reachability.h | 118 +- .../xla/service/hlo_reachability_test.cc | 2 +- .../xla/service/hlo_rematerialization.cc | 211 +- .../xla/service/hlo_rematerialization.h | 11 +- .../xla/service/hlo_rematerialization_test.cc | 22 +- .../hlo_rematerialization_test_utils.h | 6 +- .../hlo_rematerialization_test_utils_test.cc | 6 +- .../xla/service/hlo_replication_analysis.cc | 21 +- .../xla/service/hlo_replication_analysis.h | 6 +- .../service/hlo_replication_analysis_test.cc | 6 +- tensorflow/compiler/xla/service/hlo_runner.cc | 140 +- tensorflow/compiler/xla/service/hlo_runner.h | 28 +- .../xla/service/hlo_runner_interface.cc | 2 +- .../xla/service/hlo_runner_interface.h | 4 +- .../compiler/xla/service/hlo_runner_pjrt.cc | 148 +- .../compiler/xla/service/hlo_runner_pjrt.h | 22 +- .../compiler/xla/service/hlo_schedule.h | 31 - .../compiler/xla/service/hlo_schedule_test.cc | 8 +- .../compiler/xla/service/hlo_sharding.h | 35 - .../xla/service/hlo_sharding_metadata.h | 30 - .../compiler/xla/service/hlo_sharding_test.cc | 14 +- .../compiler/xla/service/hlo_sharding_util.cc | 241 +- .../compiler/xla/service/hlo_sharding_util.h | 30 +- .../xla/service/hlo_sharding_util_test.cc | 111 + tensorflow/compiler/xla/service/hlo_value.cc | 17 +- tensorflow/compiler/xla/service/hlo_value.h | 26 +- .../compiler/xla/service/hlo_verifier.cc | 164 +- .../compiler/xla/service/hlo_verifier.h | 8 + .../compiler/xla/service/hlo_verifier_test.cc | 127 +- .../xla/service/indexed_array_analysis.h | 4 +- .../xla/service/instruction_fusion.cc | 5 +- .../compiler/xla/service/instruction_fusion.h | 6 +- .../xla/service/instruction_hoister.h | 2 +- .../xla/service/latency_hiding_scheduler.cc | 330 +- .../xla/service/latency_hiding_scheduler.h | 42 +- .../service/latency_hiding_scheduler_test.cc | 117 +- .../compiler/xla/service/layout_assignment.cc | 14 +- .../compiler/xla/service/layout_assignment.h | 6 +- .../xla/service/layout_assignment_test.cc | 8 +- .../xla/service/layout_normalization.cc | 243 +- .../xla/service/layout_normalization.h | 4 +- .../xla/service/layout_normalization_test.cc | 107 + tensorflow/compiler/xla/service/llvm_ir/BUILD | 20 +- .../xla/service/llvm_ir/alias_analysis.h | 2 +- .../llvm_ir/dynamic_update_slice_util.h | 2 +- .../xla/service/llvm_ir/fused_ir_emitter.cc | 6 +- .../xla/service/llvm_ir/fused_ir_emitter.h | 2 +- .../xla/service/llvm_ir/ir_builder_mixin.h | 6 +- .../compiler/xla/service/llvm_ir/llvm_util.cc | 5 + .../compiler/xla/service/llvm_ir/llvm_util.h | 2 +- .../compiler/xla/service/local_service.cc | 4 +- .../compiler/xla/service/logical_buffer.cc | 4 +- .../compiler/xla/service/logical_buffer.h | 2 +- .../xla/service/logical_buffer_analysis.cc | 26 +- .../xla/service/logical_buffer_analysis.h | 7 +- .../compiler/xla/service/logistic_expander.cc | 8 +- .../compiler/xla/service/logistic_expander.h | 2 +- .../xla/service/logistic_expander_test.cc | 10 +- .../xla/service/loop_schedule_linearizer.cc | 26 +- .../xla/service/loop_schedule_linearizer.h | 6 +- .../service/loop_schedule_linearizer_test.cc | 8 +- .../compiler/xla/service/map_inliner.cc | 8 +- tensorflow/compiler/xla/service/map_inliner.h | 2 +- .../compiler/xla/service/map_inliner_test.cc | 6 +- .../xla/service/memory_space_assignment.cc | 213 +- .../xla/service/memory_space_assignment.h | 61 +- ...mory_space_assignment_best_fit_repacker.cc | 19 + 
.../service/memory_space_assignment_test.cc | 203 +- .../memory_space_assignment_tuning_utils.h | 2 +- .../service/memory_space_assignment_utils.cc | 17 +- .../xla/service/memory_space_propagation.h | 2 +- tensorflow/compiler/xla/service/metrics.proto | 26 + .../xla/service/metrics_hook_interface.h | 56 + .../xla/service/multi_output_fusion.cc | 4 +- .../xla/service/multi_output_fusion.h | 2 +- .../compiler/xla/service/op_expander_pass.cc | 2 +- .../compiler/xla/service/operand_upcaster.cc | 2 +- .../compiler/xla/service/operand_upcaster.h | 2 +- .../optimize_input_output_buffer_alias.cc | 4 +- .../optimize_input_output_buffer_alias.h | 4 +- .../compiler/xla/service/pattern_matcher.h | 212 +- .../xla/service/pattern_matcher_gmock.h | 3 +- .../xla/service/pattern_matcher_test.cc | 272 +- .../xla/service/real_imag_expander_test.cc | 4 +- .../compiler/xla/service/reduce_decomposer.cc | 6 +- .../xla/service/reduce_scatter_combiner.cc | 8 +- .../xla/service/reduce_scatter_combiner.h | 2 +- .../service/reduce_scatter_combiner_test.cc | 4 +- .../xla/service/reduce_scatter_decomposer.cc | 12 +- .../xla/service/reduce_scatter_decomposer.h | 2 +- .../service/reduce_scatter_decomposer_test.cc | 4 +- .../xla/service/reduce_scatter_reassociate.cc | 8 +- .../xla/service/reduce_scatter_reassociate.h | 2 +- .../reduce_scatter_reassociate_test.cc | 4 +- .../xla/service/reduce_scatter_utils.cc | 4 +- .../xla/service/reduce_scatter_utils.h | 2 +- .../xla/service/reshape_decomposer.cc | 2 +- .../xla/service/reshape_mover_test.cc | 6 +- .../compiler/xla/service/result_caster.h | 2 +- .../xla/service/rng_bit_generator_expander.cc | 11 +- .../xla/service/rng_bit_generator_expander.h | 4 +- .../xla/service/root_instruction_sinker.h | 2 +- .../compiler/xla/service/scatter_expander.cc | 53 +- .../compiler/xla/service/scatter_expander.h | 7 + .../xla/service/scatter_expander_test.cc | 96 +- .../xla/service/scatter_simplifier.cc | 6 +- .../service/select_and_scatter_expander.cc | 6 +- .../select_and_scatter_expander_test.cc | 2 +- tensorflow/compiler/xla/service/service.cc | 12 +- tensorflow/compiler/xla/service/service.h | 2 +- .../compiler/xla/service/shape_inference.cc | 1 + .../compiler/xla/service/shape_inference.h | 4 +- .../xla/service/shape_inference_test.cc | 2 +- .../xla/service/sharding_propagation.cc | 190 +- .../xla/service/sharding_propagation.h | 17 +- .../xla/service/sharding_propagation_test.cc | 379 +- .../compiler/xla/service/sharding_remover.cc | 6 +- .../compiler/xla/service/sharding_remover.h | 2 +- .../xla/service/simplify_fp_conversions.cc | 6 +- .../xla/service/simplify_fp_conversions.h | 2 +- .../service/simplify_fp_conversions_test.cc | 2 +- .../compiler/xla/service/slice_sinker_test.cc | 8 +- .../xla/service/slow_operation_alarm.cc | 2 +- .../compiler/xla/service/sort_simplifier.cc | 4 +- .../compiler/xla/service/sort_simplifier.h | 2 +- .../xla/service/space_to_batch_converter.cc | 27 +- .../xla/service/space_to_batch_converter.h | 2 +- .../service/space_to_batch_converter_test.cc | 6 +- tensorflow/compiler/xla/service/spmd/BUILD | 85 +- .../spmd/canonicalize_all_gather_for_cse.cc | 8 +- .../spmd/canonicalize_all_gather_for_cse.h | 2 +- .../canonicalize_all_gather_for_cse_test.cc | 2 +- .../service/spmd/collective_permute_motion.cc | 316 + .../service/spmd/collective_permute_motion.h | 41 + .../spmd/collective_permute_motion_test.cc | 251 + .../xla/service/spmd/convolution_handler.cc | 73 +- .../xla/service/spmd/convolution_handler.h | 6 +- .../xla/service/spmd/custom_call_handler.cc | 10 +- 
.../xla/service/spmd/custom_call_handler.h | 2 +- .../compiler/xla/service/spmd/dot_handler.cc | 182 +- .../compiler/xla/service/spmd/fft_handler.cc | 8 +- .../service/spmd/gather_scatter_handler.cc | 21 +- .../xla/service/spmd/partition_assignment.cc | 97 + .../xla/service/spmd/partition_assignment.h | 115 + .../service/spmd/partition_assignment_test.cc | 52 + .../spmd/schedule_aware_collective_ops_cse.cc | 10 +- .../spmd/schedule_aware_collective_ops_cse.h | 2 +- .../schedule_aware_collective_ops_cse_test.cc | 2 +- .../xla/service/spmd/spmd_partitioner.cc | 247 +- .../xla/service/spmd/spmd_partitioner.h | 21 +- .../xla/service/spmd/spmd_partitioner_test.cc | 609 +- .../xla/service/spmd/spmd_partitioner_util.cc | 18 +- .../xla/service/spmd/spmd_partitioner_util.h | 10 +- .../spmd/stateful_rng_spmd_partitioner.cc | 4 +- .../spmd/stateful_rng_spmd_partitioner.h | 6 +- .../stateful_rng_spmd_partitioner_test.cc | 8 +- .../xla/service/stable_sort_expander.cc | 8 +- .../xla/service/stable_sort_expander.h | 4 +- .../service/stochastic_convert_decomposer.cc | 157 + .../service/stochastic_convert_decomposer.h | 39 + .../stochastic_convert_decomposer_test.cc | 122 + .../compiler/xla/service/topk_rewriter.cc | 4 +- .../compiler/xla/service/topk_rewriter.h | 6 +- .../xla/service/topk_rewriter_test.cc | 2 +- .../xla/service/tpu_computation_placer.cc | 13 +- .../compiler/xla/service/transfer_manager.cc | 101 +- .../compiler/xla/service/transfer_manager.h | 41 +- .../compiler/xla/service/transpose_folding.cc | 10 +- .../compiler/xla/service/transpose_folding.h | 2 +- .../xla/service/transpose_folding_test.cc | 8 +- .../xla/service/tree_reduction_rewriter.cc | 7 +- .../xla/service/tree_reduction_rewriter.h | 2 +- .../xla/service/tuple_points_to_analysis.cc | 40 +- .../xla/service/tuple_points_to_analysis.h | 7 +- .../service/tuple_points_to_analysis_test.cc | 6 +- .../compiler/xla/service/tuple_simplifier.cc | 25 +- .../compiler/xla/service/tuple_simplifier.h | 4 +- .../xla/service/tuple_simplifier_test.cc | 33 +- tensorflow/compiler/xla/service/tuple_util.cc | 3 +- tensorflow/compiler/xla/service/tuple_util.h | 2 +- .../while_loop_all_reduce_code_motion.cc | 12 +- .../while_loop_all_reduce_code_motion.h | 2 +- .../while_loop_all_reduce_code_motion_test.cc | 8 +- .../xla/service/while_loop_analysis.cc | 6 +- .../xla/service/while_loop_analysis.h | 2 +- .../service/while_loop_concat_code_motion.cc | 13 +- .../service/while_loop_concat_code_motion.h | 2 +- .../while_loop_concat_code_motion_test.cc | 8 +- .../xla/service/while_loop_constant_sinking.h | 2 +- ...le_loop_expensive_invariant_code_motion.cc | 4 +- ...ile_loop_expensive_invariant_code_motion.h | 2 +- ...op_expensive_invariant_code_motion_test.cc | 42 + .../while_loop_invariant_code_motion.cc | 4 +- .../while_loop_invariant_code_motion.h | 2 +- .../while_loop_invariant_code_motion_test.cc | 51 + .../xla/service/while_loop_simplifier.cc | 8 +- .../xla/service/while_loop_simplifier.h | 2 +- .../xla/service/while_loop_simplifier_test.cc | 2 +- .../service/while_loop_trip_count_annotator.h | 2 +- tensorflow/compiler/xla/service/while_util.cc | 4 +- tensorflow/compiler/xla/service/while_util.h | 2 +- .../xla/service/xla_aot_compile_cpu_test.cc | 82 + .../xla/service/xla_aot_compile_gpu_test.cc | 292 + .../xla_aot_compile_stablehlo_cpu_test.cc | 90 + .../xla_aot_compile_stablehlo_test.mlir | 10 + .../xla/service/xla_aot_compile_test.mlir | 6 + ...aot_compile_test_autotune_results.prototxt | 37 + .../xla_aot_compile_test_constant.mlir | 7 + 
.../xla_aot_compile_test_convolution.mlir | 21 + .../service/xla_aot_compile_test_gemm.mlir | 12 + ...ot_compile_test_gpu_target_config.prototxt | 39 + .../compiler/xla/service/xla_compile.bzl | 112 + .../compiler/xla/service/xla_compile_main.cc | 222 + .../xla/service/xla_debug_info_manager.h | 2 +- .../xla/service/zero_sized_hlo_elimination.cc | 4 +- .../xla/service/zero_sized_hlo_elimination.h | 2 +- .../zero_sized_hlo_elimination_test.cc | 8 +- tensorflow/compiler/xla/shape.cc | 24 +- tensorflow/compiler/xla/shape.h | 13 +- tensorflow/compiler/xla/shape_layout.cc | 12 +- tensorflow/compiler/xla/shape_layout.h | 4 + tensorflow/compiler/xla/shape_util.cc | 309 +- tensorflow/compiler/xla/shape_util.h | 120 +- tensorflow/compiler/xla/shape_util_test.cc | 163 +- tensorflow/compiler/xla/side_effect_util.cc | 2 + tensorflow/compiler/xla/side_effect_util.h | 1 + tensorflow/compiler/xla/status.h | 6 +- tensorflow/compiler/xla/statusor.h | 2 +- tensorflow/compiler/xla/stream_executor/BUILD | 164 +- .../compiler/xla/stream_executor/blas.h | 174 +- .../xla/stream_executor/build_defs.bzl | 5 + .../compiler/xla/stream_executor/cuda/BUILD | 88 +- .../stream_executor/cuda/cuda_asm_compiler.cc | 96 +- .../xla/stream_executor/cuda/cuda_blas.cc | 210 +- .../xla/stream_executor/cuda/cuda_blas.h | 11 +- .../xla/stream_executor/cuda/cuda_blas_lt.cc | 143 +- .../xla/stream_executor/cuda/cuda_blas_lt.h | 105 +- .../stream_executor/cuda/cuda_blas_utils.cc | 12 +- .../stream_executor/cuda/cuda_blas_utils.h | 4 +- .../stream_executor/cuda/cuda_diagnostics.cc | 92 +- .../stream_executor/cuda/cuda_diagnostics.h | 4 +- .../xla/stream_executor/cuda/cuda_dnn.cc | 618 +- .../xla/stream_executor/cuda/cuda_dnn.h | 129 +- .../xla/stream_executor/cuda/cuda_driver.cc | 313 +- .../xla/stream_executor/cuda/cuda_event.cc | 4 +- .../xla/stream_executor/cuda/cuda_fft.cc | 90 +- .../xla/stream_executor/cuda/cuda_fft.h | 22 +- .../stream_executor/cuda/cuda_gpu_executor.cc | 221 +- .../xla/stream_executor/cuda/cuda_graph.cc | 179 + .../xla/stream_executor/cuda/cuda_graph.h | 91 + .../xla/stream_executor/cuda/cuda_platform.cc | 30 +- .../xla/stream_executor/cuda/cuda_platform.h | 16 +- .../xla/stream_executor/cuda/cuda_rng.cc | 7 +- .../cuda/redzone_allocator_test.cc | 5 +- .../cuda/stream_search_test.cc | 8 +- .../compiler/xla/stream_executor/data_type.h | 10 + .../xla/stream_executor/device_description.cc | 23 +- .../xla/stream_executor/device_description.h | 47 +- .../stream_executor/device_description.proto | 69 + .../stream_executor/device_host_allocator.h | 10 +- .../xla/stream_executor}/device_id_utils.h | 79 +- .../stream_executor/device_mem_allocator.h | 10 +- .../stream_executor/device_memory_allocator.h | 40 +- .../compiler/xla/stream_executor/dnn.cc | 40 +- tensorflow/compiler/xla/stream_executor/dnn.h | 176 +- .../compiler/xla/stream_executor/dnn.proto | 167 +- .../xla/stream_executor/executor_cache.cc | 22 +- .../xla/stream_executor/executor_cache.h | 10 +- tensorflow/compiler/xla/stream_executor/fft.h | 2 +- .../compiler/xla/stream_executor/gpu/BUILD | 135 +- .../xla/stream_executor/gpu/asm_compiler.cc | 84 +- .../xla/stream_executor/gpu/asm_compiler.h | 37 +- .../gpu/gpu_cudamallocasync_allocator.cc | 78 +- .../gpu/gpu_cudamallocasync_allocator.h | 25 +- .../xla/stream_executor/gpu/gpu_diagnostics.h | 16 +- .../xla/stream_executor/gpu/gpu_driver.h | 171 +- .../xla/stream_executor/gpu/gpu_event.cc | 8 +- .../xla/stream_executor/gpu/gpu_event.h | 8 +- .../xla/stream_executor/gpu/gpu_executor.h | 89 +- 
.../xla/stream_executor}/gpu/gpu_init.cc | 29 +- .../xla/stream_executor}/gpu/gpu_init.h | 18 +- .../xla/stream_executor/gpu/gpu_stream.cc | 4 +- .../xla/stream_executor/gpu/gpu_timer.cc | 12 +- .../stream_executor/gpu/redzone_allocator.cc | 21 +- .../stream_executor/gpu/redzone_allocator.h | 4 +- .../compiler/xla/stream_executor/host/BUILD | 19 +- .../stream_executor/host/host_gpu_executor.cc | 68 +- .../stream_executor/host/host_gpu_executor.h | 73 +- .../xla/stream_executor/host/host_platform.cc | 21 +- .../xla/stream_executor/host/host_platform.h | 11 +- .../xla/stream_executor/host/host_stream.cc | 30 +- .../xla/stream_executor/host/host_stream.h | 17 +- .../stream_executor/host/host_stream_test.cc | 8 +- .../compiler/xla/stream_executor/kernel.cc | 4 +- .../compiler/xla/stream_executor/kernel.h | 12 +- .../xla/stream_executor/lazy_op_runner.h | 63 +- .../compiler/xla/stream_executor/lib/BUILD | 26 - .../xla/stream_executor/lib/array_slice.h | 43 - .../xla/stream_executor/lib/demangle.cc | 51 - .../compiler/xla/stream_executor/lib/env.h | 47 - .../xla/stream_executor/lib/human_readable.h | 74 - .../xla/stream_executor/lib/initialize.h | 21 - .../xla/stream_executor/lib/mathutil.h | 101 - .../xla/stream_executor/lib/numbers.cc | 43 - .../xla/stream_executor/lib/numbers.h | 35 - .../compiler/xla/stream_executor/lib/path.cc | 62 - .../compiler/xla/stream_executor/lib/path.h | 57 - .../xla/stream_executor/lib/process_state.cc | 57 - .../xla/stream_executor/lib/process_state.h | 30 - .../xla/stream_executor/lib/stacktrace.h | 30 - .../compiler/xla/stream_executor/lib/status.h | 61 - .../xla/stream_executor/lib/thread_options.h | 29 - .../xla/stream_executor/lib/threadpool.h | 32 - .../xla/stream_executor/module_spec.h | 9 +- .../stream_executor/multi_platform_manager.cc | 93 +- .../stream_executor/multi_platform_manager.h | 34 +- .../compiler/xla/stream_executor/platform.cc | 16 +- .../compiler/xla/stream_executor/platform.h | 21 +- .../xla/stream_executor/platform/BUILD | 1 + .../stream_executor/platform/default/BUILD | 11 +- .../platform/default/dso_loader.h | 4 +- .../xla/stream_executor/plugin_registry.cc | 38 +- .../xla/stream_executor/plugin_registry.h | 54 +- .../compiler/xla/stream_executor/rocm/BUILD | 32 +- .../stream_executor/rocm/hipsolver_wrapper.h | 6 +- .../stream_executor/rocm/hipsparse_wrapper.h | 6 +- .../stream_executor/rocm/rocblas_wrapper.h | 7 +- .../xla/stream_executor/rocm/rocm_blas.cc | 219 +- .../xla/stream_executor/rocm/rocm_blas.h | 23 +- .../stream_executor/rocm/rocm_diagnostics.cc | 69 +- .../stream_executor/rocm/rocm_diagnostics.h | 4 +- .../xla/stream_executor/rocm/rocm_dnn.cc | 383 +- .../xla/stream_executor/rocm/rocm_dnn.h | 75 +- .../xla/stream_executor/rocm/rocm_driver.cc | 293 +- .../rocm/rocm_driver_wrapper.h | 34 +- .../xla/stream_executor/rocm/rocm_event.cc | 3 +- .../xla/stream_executor/rocm/rocm_fft.cc | 164 +- .../xla/stream_executor/rocm/rocm_fft.h | 22 +- .../stream_executor/rocm/rocm_gpu_executor.cc | 176 +- .../xla/stream_executor/rocm/rocm_platform.cc | 29 +- .../xla/stream_executor/rocm/rocm_platform.h | 15 +- .../xla/stream_executor/rocm/rocm_rng.cc | 13 +- .../stream_executor/rocm/rocsolver_wrapper.h | 12 +- .../stream_executor/rocm/roctracer_wrapper.h | 30 +- .../xla/stream_executor/scratch_allocator.cc | 4 +- .../xla/stream_executor/scratch_allocator.h | 8 +- .../compiler/xla/stream_executor/stream.cc | 336 +- .../compiler/xla/stream_executor/stream.h | 304 +- .../stream_executor_internal.cc | 33 - .../stream_executor_internal.h | 102 +- 
.../stream_executor/stream_executor_pimpl.cc | 179 +- .../stream_executor/stream_executor_pimpl.h | 118 +- .../temporary_memory_manager.cc | 8 +- .../temporary_memory_manager.h | 12 +- .../stream_executor/tf_allocator_adapter.cc | 9 +- .../stream_executor/tf_allocator_adapter.h | 22 +- .../compiler/xla/stream_executor/tpu/BUILD | 102 +- .../stream_executor/tpu/c_api_conversions.cc | 15 +- .../stream_executor/tpu/c_api_conversions.h | 2 +- .../xla/stream_executor/tpu/c_api_decl.h | 10 +- .../xla/stream_executor/tpu/status_helper.h | 13 +- .../xla/stream_executor/tpu/tpu_api.cc | 4 +- .../xla/stream_executor/tpu/tpu_api.h | 4 +- .../xla/stream_executor/tpu/tpu_event.h | 6 +- .../xla/stream_executor/tpu/tpu_executable.cc | 2 +- .../xla/stream_executor/tpu/tpu_executable.h | 5 +- .../tpu/tpu_executable_interface.cc | 26 +- .../tpu/tpu_executable_interface.h | 9 +- .../xla/stream_executor/tpu/tpu_executor.cc | 138 +- .../xla/stream_executor/tpu/tpu_executor.h | 88 +- .../stream_executor/tpu/tpu_executor_api.cc | 4 +- .../stream_executor/tpu/tpu_executor_api.h | 4 +- .../tpu/tpu_executor_init_fns.inc | 4 +- .../tpu/tpu_executor_interface.h | 9 +- .../tpu/tpu_initializer_helper.cc | 45 +- .../tpu/tpu_initializer_helper.h | 8 +- .../tpu/tpu_library_init_fns.inc | 13 +- .../stream_executor/tpu/tpu_node_context.cc | 30 +- .../stream_executor/tpu/tpu_node_context.h | 13 +- .../tpu/tpu_on_demand_compiler.cc | 14 +- .../stream_executor/tpu/tpu_op_executable.cc | 66 +- .../stream_executor/tpu/tpu_op_executable.h | 8 +- .../xla/stream_executor/tpu/tpu_ops_c_api.h | 32 +- .../xla/stream_executor/tpu/tpu_platform.cc | 94 +- .../xla/stream_executor/tpu/tpu_platform.h | 11 +- .../tpu/tpu_platform_interface.h | 8 +- .../xla/stream_executor/tpu/tpu_stream.h | 28 +- .../tpu/tpu_stream_interface.h | 4 +- .../xla/stream_executor/tpu/tpu_timer.h | 8 +- .../xla/stream_executor/tpu/tpu_topology.cc | 78 +- .../tpu/tpu_transfer_manager.cc | 152 +- .../tpu/tpu_transfer_manager.h | 31 +- .../xla/stream_executor/trace_listener.h | 8 +- tensorflow/compiler/xla/test.h | 2 +- tensorflow/compiler/xla/tests/BUILD | 231 +- .../xla/tests/array_elementwise_ops_test.cc | 110 +- .../xla/tests/bitcast_convert_test.cc | 11 + tensorflow/compiler/xla/tests/build_defs.bzl | 4 +- .../xla/tests/client_library_test_base.h | 13 + .../compiler/xla/tests/collective_ops_test.cc | 376 +- .../compiler/xla/tests/constants_test.cc | 18 + .../conv_depthwise_backprop_filter_test.cc | 13 + tensorflow/compiler/xla/tests/convert_test.cc | 149 +- .../xla/tests/convolution_cudnn_test.cc | 65 + .../xla/tests/convolution_variants_test.cc | 12 + .../compiler/xla/tests/cpu_gpu_fusion_test.cc | 5 +- .../compiler/xla/tests/custom_call_test.cc | 70 + .../compiler/xla/tests/dot_operation_test.cc | 362 +- .../compiler/xla/tests/exhaustive/BUILD | 141 + .../exhaustive_binary_16_bit_test.cc | 2 +- .../exhaustive_binary_test_f32_f64.cc | 2 +- .../exhaustive_op_test_utils.cc | 2 +- .../exhaustive_op_test_utils.h | 6 +- .../exhaustive_unary_test_complex.cc | 2 +- .../exhaustive_unary_test_f32_or_smaller.cc | 108 +- .../exhaustive_unary_test_f64.cc | 2 +- .../xla/tests/gpu_dump_mlir_passes_test.cc | 76 + .../compiler/xla/tests/hlo_test_base.cc | 138 +- tensorflow/compiler/xla/tests/hlo_test_base.h | 21 + .../xla/tests/literal_test_util_test.cc | 274 + .../compiler/xla/tests/llvm_compiler_test.cc | 16 +- .../xla/tests/local_client_execute_test.cc | 5 + .../xla/tests/matrix_ops_simple_test.cc | 13 + tensorflow/compiler/xla/tests/params_test.cc | 13 +- 
.../xla/tests/pjrt_client_registry.cc | 21 +- .../compiler/xla/tests/pjrt_client_registry.h | 37 +- .../compiler/xla/tests/sample_file_test.cc | 8 +- tensorflow/compiler/xla/tests/test_macros.cc | 2 +- tensorflow/compiler/xla/tests/test_utils.cc | 56 +- tensorflow/compiler/xla/tests/test_utils.h | 4 + .../compiler/xla/tests/token_hlo_test.cc | 27 +- tensorflow/compiler/xla/tests/tuple_test.cc | 18 + .../xla/tests/vector_ops_simple_test.cc | 2 +- tensorflow/compiler/xla/tests/xla_ffi_test.cc | 147 + .../xla/tests/xla_hlo_profile_test.cc | 9 +- tensorflow/compiler/xla/tools/BUILD | 124 +- tensorflow/compiler/xla/tools/data/add.hlo | 6 + .../compiler/xla/tools/data/must_alias.hlo | 6 + .../tools/data/must_alias_with_sharding.hlo | 7 + tensorflow/compiler/xla/tools/driver.cc | 10 +- .../dumped_computation_to_operation_list.cc | 2 +- .../compiler/xla/tools/hlo_bisect/BUILD | 93 + .../xla/tools/hlo_bisect/hlo_bisect.cc | 137 + .../xla/tools/hlo_bisect/hlo_bisect_state.cc | 347 + .../xla/tools/hlo_bisect/hlo_bisect_state.h | 96 + .../tools/hlo_bisect/hlo_bisect_state_test.cc | 202 + .../xla/tools/hlo_bisect/hlo_bisect_utils.cc | 384 ++ .../xla/tools/hlo_bisect/hlo_bisect_utils.h | 99 + .../xla/tools/hlo_control_flow_flattening.cc | 14 +- .../xla/tools/hlo_control_flow_flattening.h | 2 +- .../tools/hlo_control_flow_flattening_test.cc | 10 +- .../compiler/xla/tools/hlo_extractor.cc | 11 +- tensorflow/compiler/xla/tools/hlo_extractor.h | 4 +- .../compiler/xla/tools/hlo_module_loader.cc | 46 +- .../compiler/xla/tools/hlo_module_loader.h | 20 +- .../compiler/xla/tools/hlo_proto_to_json.cc | 2 +- .../tools/interactive_graphviz_bin_test.cc | 66 + .../xla/tools/prepare_reference_module.cc | 4 +- .../xla/tools/prepare_reference_module.h | 4 +- .../compiler/xla/tools/replay_computation.cc | 2 + .../xla/tools/replay_computation_bin_test.cc | 65 + .../compiler/xla/tools/run_hlo_module.cc | 20 +- .../compiler/xla/tools/run_hlo_module.h | 6 +- .../xla/tools/run_hlo_module_bin_test.cc | 89 + .../compiler/xla/tools/run_hlo_module_main.cc | 10 +- tensorflow/compiler/xla/translate/BUILD | 12 +- .../compiler/xla/translate/hlo_to_mhlo/BUILD | 31 +- .../hlo_to_mhlo/attribute_importer.cc | 43 +- .../hlo_to_mhlo/attribute_importer.h | 14 +- .../hlo_to_mhlo/hlo_function_importer.cc | 332 +- .../hlo_to_mhlo/hlo_function_importer.h | 32 +- .../hlo_to_mhlo/hlo_module_importer.cc | 64 +- .../hlo_to_mhlo/hlo_module_importer.h | 5 +- .../xla/translate/hlo_to_mhlo/hlo_utils.cc | 23 +- .../xla/translate/hlo_to_mhlo/hlo_utils.h | 10 +- .../translate/hlo_to_mhlo/hlo_utils_test.cc | 7 +- .../translate/hlo_to_mhlo/location_importer.h | 2 +- .../translate/hlo_to_mhlo/mlir_hlo_builder.cc | 141 +- .../translate/hlo_to_mhlo/mlir_hlo_builder.h | 31 +- .../hlo_to_mhlo/mlir_hlo_builder_test.cc | 170 +- .../xla/translate/hlo_to_mhlo/tests/BUILD | 5 +- .../hlo_to_mhlo/tests/dynamic_param.hlo | 104 + .../tests/frontend_attributes.hlotxt | 15 + .../translate/hlo_to_mhlo/tests/fusion.hlotxt | 8 +- .../translate/hlo_to_mhlo/tests/import.hlotxt | 210 +- .../hlo_to_mhlo/tests/import_async.hlotxt | 31 +- .../hlo_to_mhlo/tests/location.hlotxt | 14 +- .../hlo_to_mhlo/tests/module_attributes.hlo | 214 + .../tests/spmd_module_sharding.hlo | 135 + .../translate/hlo_to_mhlo/tests/while.hlotxt | 2 +- .../xla/translate/hlo_to_mhlo/translate.cc | 2 +- .../compiler/xla/translate/mhlo_to_hlo/BUILD | 22 +- .../mhlo_to_hlo/attribute_exporter.cc | 30 +- .../mhlo_to_hlo/attribute_exporter.h | 11 +- .../xla/translate/mhlo_to_hlo/layout_util.h | 2 +- 
.../mhlo_to_hlo/location_exporter.cc | 8 +- .../translate/mhlo_to_hlo/location_exporter.h | 3 +- .../translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc | 639 +- .../translate/mhlo_to_hlo/mlir_hlo_to_hlo.h | 5 +- .../mhlo_to_hlo/operator_writer_gen.cc | 10 +- .../xla/translate/mhlo_to_hlo/tests/BUILD | 5 +- .../translate/mhlo_to_hlo/tests/export.mlir | 1619 +---- .../tests/export_and_check_layouts.mlir | 17 +- .../translate/mhlo_to_hlo/tests/fusion.mlir | 20 +- .../tests/location_to_op_metadata.mlir | 9 +- .../mhlo_to_hlo/tests/module_attributes.mlir | 180 + .../translate/mhlo_to_hlo/tests/sharding.mlir | 12 + .../mhlo_to_hlo/tests/unsupported_type.mlir | 7 + .../xla/translate/mhlo_to_hlo/translate.cc | 25 +- .../xla/translate/mhlo_to_hlo/translate.h | 2 +- .../mhlo_to_hlo/translate_registration.cc | 20 +- .../translate/mhlo_to_hlo/type_to_shape.cc | 18 +- .../mhlo_to_hlo/type_to_shape_test.cc | 4 +- .../xla/translate/mhlo_to_lhlo_with_xla/BUILD | 104 + .../mhlo_to_lhlo_with_xla.cc | 547 +- .../mhlo_to_lhlo_with_xla.h | 119 +- .../mhlo_to_lhlo_with_xla/tests}/BUILD | 15 +- .../mhlo_to_lhlo_with_xla/tests}/gpu_ops.mlir | 2 +- .../tests}/hlo_text_to_lhlo_no_opt.hlotxt | 177 +- .../tests}/no_opt_ops.hlotxt | 2 +- .../tests}/non_identity_layouts.hlotxt | 2 +- .../mhlo_to_lhlo_with_xla/tests}/ops.mlir | 2 +- .../tests}/passthrough.mlir | 2 +- .../translate_registration.cc} | 2 +- .../xla_translate_opt_main.cc | 44 + .../xla/translate/xla_translate_main.cc | 29 +- tensorflow/compiler/xla/util.cc | 17 +- tensorflow/compiler/xla/util.h | 41 +- tensorflow/compiler/xla/util_test.cc | 13 + tensorflow/compiler/xla/xla.bzl | 92 + tensorflow/compiler/xla/xla.proto | 146 +- tensorflow/compiler/xla/xla_data.proto | 33 +- tensorflow/compiler/xrt/BUILD | 3 +- tensorflow/compiler/xrt/cc/BUILD | 1 + tensorflow/compiler/xrt/kernels/BUILD | 5 +- .../compiler/xrt/kernels/tpu_compile_ops.cc | 2 +- .../compiler/xrt/kernels/tpu_execute_op.cc | 2 +- .../compiler/xrt/kernels/xrt_execute_op.cc | 5 +- tensorflow/compiler/xrt/tests/BUILD | 1 + tensorflow/compiler/xrt/tests/raw_api_test.cc | 11 +- tensorflow/compiler/xrt/xrt_util.cc | 3 + tensorflow/compiler/xrt/xrt_util.h | 2 +- tensorflow/core/BUILD | 87 +- tensorflow/core/activity_watcher/BUILD | 1 + tensorflow/core/activity_watcher/activity.cc | 2 +- tensorflow/core/activity_watcher/activity.h | 7 +- tensorflow/core/api_def/BUILD | 1 + tensorflow/core/api_def/base_api/BUILD | 1 + .../core/api_def/base_api/api_def_Angle.pbtxt | 2 +- .../api_def_CollectiveReduceScatterV2.pbtxt | 5 + .../api_def_ComputeDedupDataTupleMask.pbtxt | 27 + .../api_def_Conv2DBackpropFilterV2.pbtxt | 75 + .../api_def_Conv2DBackpropInputV2.pbtxt | 75 + .../base_api/api_def_DistributedSave.pbtxt | 4 + .../base_api/api_def_DynamicPartition.pbtxt | 7 + .../api_def_FakeQuantWithMinMaxArgs.pbtxt | 28 +- .../base_api/api_def_MergeDedupData.pbtxt | 56 + .../base_api/api_def_RandomDatasetV2.pbtxt | 44 + .../base_api/api_def_SegmentProdV2.pbtxt | 56 + .../base_api/api_def_SegmentSumV2.pbtxt | 38 + .../base_api/api_def_SplitDedupData.pbtxt | 55 + .../api_def/base_api/api_def_SyncDevice.pbtxt | 9 + .../api_def_TPUPartitionedInputV2.pbtxt | 30 + .../api_def_TPUPartitionedOutputV2.pbtxt | 26 + .../base_api/api_def_TensorScatterMax.pbtxt | 2 +- .../api_def/base_api/api_def_Timestamp.pbtxt | 11 +- .../base_api/api_def_TruncateDiv.pbtxt | 2 +- .../api_def_UniformQuantizedAdd.pbtxt | 161 + .../api_def_UniformQuantizedConvolution.pbtxt | 238 + .../base_api/api_def_UniqueWithCountsV2.pbtxt | 6 +- 
...api_def_XlaSendTPUEmbeddingGradients.pbtxt | 12 + tensorflow/core/api_def/java_api/BUILD | 1 + tensorflow/core/api_def/python_api/BUILD | 5 +- .../python_api/api_def_SegmentProdV2.pbtxt | 4 + .../python_api/api_def_SegmentSumV2.pbtxt | 4 + .../api_def/python_api/api_def_Xdivy.pbtxt | 4 +- tensorflow/core/common_runtime/BUILD | 206 +- .../core/common_runtime/arg_ret_placement.cc | 223 + .../core/common_runtime/arg_ret_placement.h | 136 + .../common_runtime/arg_ret_placement_test.cc | 328 + .../base_collective_executor.cc | 18 +- .../collective_param_resolver_local.cc | 5 +- .../common_runtime/collective_test_util.cc | 12 +- .../common_runtime/collective_test_util.h | 4 +- tensorflow/core/common_runtime/device/BUILD | 13 +- .../common_runtime/device/device_event_mgr.cc | 172 +- .../common_runtime/device/device_event_mgr.h | 94 +- .../device/device_event_mgr_test.cc | 34 +- tensorflow/core/common_runtime/eager/BUILD | 89 +- .../core/common_runtime/eager/context.cc | 45 +- .../core/common_runtime/eager/context.h | 72 +- .../eager/context_distributed_manager.h | 8 +- .../core/common_runtime/eager/context_test.cc | 31 +- .../eager/custom_device_test.cc | 9 +- .../common_runtime/eager/eager_executor.cc | 32 +- .../common_runtime/eager/eager_executor.h | 10 +- .../eager/eager_executor_test.cc | 285 + .../eager/eager_op_rewrite_registry_test.cc | 3 +- .../eager/eager_operation_test.cc | 6 +- .../core/common_runtime/eager/execute.cc | 96 +- .../common_runtime/eager/execute_node_test.cc | 6 +- .../common_runtime/eager/kernel_and_device.cc | 46 +- .../common_runtime/eager/kernel_and_device.h | 42 +- .../eager/mkl_eager_op_rewrite.cc | 3 + .../eager/mkl_eager_op_rewrite_test.cc | 3 +- .../common_runtime/eager/placement_test.cc | 11 +- .../common_runtime/eager/placement_utils.h | 1 - .../eager/placement_utils_test.cc | 255 + .../eager/tensor_handle_data_test.cc | 161 + .../eager/tensor_handle_test.cc | 292 +- .../eager/zen_eager_op_rewrite.cc | 180 + tensorflow/core/common_runtime/executor.cc | 39 +- tensorflow/core/common_runtime/executor.h | 7 +- tensorflow/core/common_runtime/function.cc | 21 +- tensorflow/core/common_runtime/gpu/BUILD | 94 +- .../gpu/gpu_bfc_allocator_test.cc | 8 +- .../gpu/gpu_cudamalloc_allocator.cc | 8 +- .../common_runtime/gpu/gpu_debug_allocator.cc | 12 +- .../gpu/gpu_debug_allocator_test.cc | 8 +- .../core/common_runtime/gpu/gpu_device.cc | 87 +- .../core/common_runtime/gpu/gpu_device.h | 29 +- .../common_runtime/gpu/gpu_device_factory.cc | 14 +- .../common_runtime/gpu/gpu_device_test.cc | 57 +- .../common_runtime/gpu/gpu_process_state.cc | 66 +- .../common_runtime/gpu/gpu_process_state.h | 10 +- .../core/common_runtime/gpu/gpu_util.cc | 7 +- .../gpu/gpu_virtual_mem_allocator.cc | 17 +- .../gpu/gpu_virtual_mem_allocator.h | 14 +- .../gpu/gpu_virtual_mem_allocator_test.cc | 16 +- .../common_runtime/gpu/pool_allocator_test.cc | 12 +- .../core/common_runtime/graph_constructor.cc | 9 +- .../core/common_runtime/graph_constructor.h | 7 +- .../common_runtime/graph_execution_state.cc | 3 +- .../core/common_runtime/int32_fulltype.cc | 133 + .../core/common_runtime/int32_fulltype.h | 53 + .../common_runtime/int32_fulltype_test.cc | 189 + .../core/common_runtime/layout_pass_util.cc | 145 + .../core/common_runtime/layout_pass_util.h | 82 + .../common_runtime/local_session_selection.h | 2 +- .../common_runtime/lower_functional_ops.cc | 4 +- .../core/common_runtime/mkl_cpu_allocator.h | 2 +- .../core/common_runtime/mkl_layout_pass.cc | 2 + .../next_pluggable_device/BUILD | 349 + 
.../next_pluggable_device/c/BUILD | 56 + .../next_pluggable_device/c/example_plugin.cc | 86 + .../next_pluggable_device/c/example_plugin.h | 49 + .../next_pluggable_device/c/plugin_c_api.h | 165 + .../c/plugin_c_api_test.cc | 141 + .../c/test_next_pluggable_device_plugin.cc | 40 + .../c_plugin_coordination_service_agent.cc | 58 + .../c_plugin_coordination_service_agent.h | 52 + .../c_plugin_op_kernel.cc | 335 + .../c_plugin_op_kernel.h | 162 + .../c_plugin_variable.cc | 60 + .../next_pluggable_device/c_plugin_variable.h | 50 + ...direct_plugin_coordination_service_agent.h | 58 + .../direct_plugin_op_kernel.cc | 143 + .../direct_plugin_op_kernel.h | 182 + .../direct_plugin_variable.cc | 29 + .../direct_plugin_variable.h | 52 + .../next_pluggable_device.cc | 131 + .../next_pluggable_device.h | 85 + .../next_pluggable_device_allocator.cc | 46 + .../next_pluggable_device_allocator.h | 54 + .../next_pluggable_device_api.cc | 45 + .../next_pluggable_device_api.h | 37 + .../next_pluggable_device_context.cc | 137 + .../next_pluggable_device_context.h | 52 + .../next_pluggable_device_factory.cc | 78 + .../next_pluggable_device_factory.h | 52 + .../pjrt_compile_on_demand_op.cc | 289 + .../pjrt_compile_on_demand_op.h | 57 + .../plugin_coordination_service_agent.h | 46 + ...plugin_coordination_service_agent_helper.h | 42 + .../next_pluggable_device/plugin_op_kernel.h | 168 + .../plugin_op_kernel_helper.h | 125 + .../next_pluggable_device/plugin_resource.cc} | 10 +- .../next_pluggable_device/plugin_resource.h | 56 + .../next_pluggable_device/plugin_variable.h | 46 + .../next_pluggable_device/utils.cc | 54 + .../next_pluggable_device/utils.h | 28 + .../common_runtime/optimization_registry.cc | 6 + .../optimize_cross_host_control_deps.cc | 20 +- .../optimize_function_graph_utils.cc | 497 ++ .../optimize_function_graph_utils.h | 65 + .../optimize_function_graph_utils_test.cc | 107 + .../optimized_function_graph_info.cc | 74 + .../optimized_function_graph_info.h | 58 + .../optimized_function_graph_info_test.cc | 179 + .../common_runtime/pluggable_device/BUILD | 43 +- .../pluggable_device/pluggable_device.cc | 4 +- .../pluggable_device/pluggable_device.h | 1 - .../pluggable_device_factory.cc | 14 +- .../pluggable_device_plugin_init.cc | 46 + .../pluggable_device_process_state.cc | 25 +- .../core/common_runtime/pool_allocator.cc | 6 + .../process_function_library_runtime.cc | 608 +- .../process_function_library_runtime.h | 73 +- .../process_function_library_runtime_test.cc | 39 + .../core/common_runtime/propagator_state.cc | 28 + .../core/common_runtime/propagator_state.h | 42 +- .../replicate_per_replica_nodes.cc | 6 + .../core/common_runtime/shape_refiner.cc | 5 +- .../core/common_runtime/shape_refiner.h | 4 +- .../stats_publisher_interface.cc | 27 + .../stats_publisher_interface.h | 26 +- .../core/common_runtime/zen_layout_pass.cc | 1166 ++++ tensorflow/core/config/BUILD | 5 +- tensorflow/core/data/BUILD | 108 +- tensorflow/core/data/captured_function.cc | 7 +- tensorflow/core/data/compression_utils.cc | 181 +- tensorflow/core/data/compression_utils.h | 7 - .../core/data/compression_utils_test.cc | 42 +- tensorflow/core/data/dataset_utils_test.cc | 3 + tensorflow/core/data/finalization_utils.h | 2 +- tensorflow/core/data/metric_utils.cc | 4 +- tensorflow/core/data/metric_utils.h | 1 + tensorflow/core/data/root_dataset.cc | 21 +- tensorflow/core/data/serialization_utils.cc | 57 +- tensorflow/core/data/serialization_utils.h | 13 + .../core/data/serialization_utils_test.cc | 17 + 
tensorflow/core/data/service/BUILD | 87 +- tensorflow/core/data/service/client/BUILD | 10 +- tensorflow/core/data/service/client/common.h | 20 - .../data/service/client/data_service_client.h | 1 + .../client/data_service_client_test.cc | 1 + tensorflow/core/data/service/common.h | 24 +- tensorflow/core/data/service/common.proto | 46 + tensorflow/core/data/service/dispatcher.proto | 96 +- .../core/data/service/dispatcher_client.cc | 132 +- .../core/data/service/dispatcher_client.h | 19 +- .../data/service/dispatcher_client_test.cc | 140 +- .../core/data/service/dispatcher_impl.cc | 147 +- .../core/data/service/dispatcher_impl.h | 42 +- .../core/data/service/dispatcher_state.cc | 7 + .../core/data/service/dispatcher_state.h | 10 + .../data/service/dispatcher_state_test.cc | 26 +- .../core/data/service/grpc_dispatcher_impl.cc | 3 + .../core/data/service/grpc_dispatcher_impl.h | 3 + tensorflow/core/data/service/grpc_util.cc | 9 + .../core/data/service/grpc_worker_impl.cc | 1 + .../core/data/service/grpc_worker_impl.h | 1 + tensorflow/core/data/service/journal.proto | 8 +- tensorflow/core/data/service/server_lib.cc | 35 +- tensorflow/core/data/service/server_lib.h | 30 +- tensorflow/core/data/service/snapshot/BUILD | 353 + .../snapshot/distributed_snapshot_test.cc | 198 + .../core/data/service/snapshot/file_utils.cc | 98 + .../core/data/service/snapshot/file_utils.h | 52 + .../data/service/snapshot/file_utils_test.cc | 106 + .../core/data/service/snapshot/path_utils.cc | 129 + .../core/data/service/snapshot/path_utils.h | 83 + .../data/service/snapshot/path_utils_test.cc | 126 + .../data/service/snapshot/snapshot_manager.cc | 446 ++ .../data/service/snapshot/snapshot_manager.h | 185 + .../data/service/snapshot/snapshot_reader.cc | 92 + .../data/service/snapshot/snapshot_reader.h | 101 + .../service/snapshot/snapshot_reader_test.cc | 154 + .../snapshot/snapshot_split_provider.cc | 231 + .../snapshot/snapshot_split_provider.h | 100 + .../snapshot/snapshot_split_provider_test.cc | 190 + .../snapshot/snapshot_stream_writer.cc | 382 ++ .../service/snapshot/snapshot_stream_writer.h | 230 + .../snapshot_stream_writer_checkpoint_test.cc | 270 + .../snapshot/snapshot_stream_writer_test.cc | 318 + .../core/data/service/snapshot/test_utils.cc | 199 + .../core/data/service/snapshot/test_utils.h | 108 + .../core/data/service/snapshot/utils.cc | 37 + tensorflow/core/data/service/snapshot/utils.h | 31 + .../core/data/service/snapshot/utils_test.cc | 69 + .../core/data/service/split_provider.cc | 22 +- tensorflow/core/data/service/split_provider.h | 12 +- tensorflow/core/data/service/task_runner.cc | 7 + tensorflow/core/data/service/task_runner.h | 18 + tensorflow/core/data/service/test_cluster.cc | 10 +- tensorflow/core/data/service/test_cluster.h | 6 +- tensorflow/core/data/service/test_util.cc | 30 + tensorflow/core/data/service/test_util.h | 16 + .../core/data/service/test_util_test.cc | 83 +- tensorflow/core/data/service/testdata/BUILD | 1 + .../testdata/choose_from_datasets.pbtxt | 661 ++ tensorflow/core/data/service/worker.proto | 11 + tensorflow/core/data/service/worker_client.cc | 16 +- tensorflow/core/data/service/worker_impl.cc | 160 +- tensorflow/core/data/service/worker_impl.h | 50 +- tensorflow/core/data/snapshot_utils.cc | 25 + tensorflow/core/data/snapshot_utils.h | 30 +- tensorflow/core/data/snapshot_utils_test.cc | 25 + tensorflow/core/data/standalone.cc | 65 +- tensorflow/core/data/standalone.h | 20 +- .../core/data/standalone_save_restore_test.cc | 112 + 
tensorflow/core/data/tfdataz_metrics.cc | 154 + tensorflow/core/data/tfdataz_metrics.h | 144 + tensorflow/core/data/tfdataz_metrics_test.cc | 224 + tensorflow/core/debug/BUILD | 8 +- tensorflow/core/distributed_runtime/BUILD | 97 +- .../base_rendezvous_mgr.cc | 25 +- .../cluster_function_library_runtime.cc | 10 +- .../cluster_function_library_runtime.h | 4 + .../cluster_function_library_runtime_test.cc | 4 +- .../distributed_runtime/coordination/BUILD | 90 +- .../coordination_service_barrier_proxy.cc | 4 +- .../coordination_service_barrier_proxy.h | 9 +- ...coordination_service_barrier_proxy_test.cc | 9 +- .../coordination_service_rpc_handler.h | 73 +- .../core/distributed_runtime/eager/BUILD | 14 +- .../eager/eager_service_impl.cc | 56 +- .../eager/eager_service_impl_test.cc | 10 +- .../eager/remote_copy_node.cc | 4 +- .../distributed_runtime/eager/remote_mgr.cc | 2 +- .../eager/remote_mgr_test.cc | 3 +- .../core/distributed_runtime/graph_mgr.cc | 6 +- .../core/distributed_runtime/graph_mgr.h | 9 +- .../integration_test/BUILD | 41 +- .../c_api_coordination_test.cc | 134 - .../c_api_recoverable_jobs_test.cc | 252 + ...coordination_test_opkernel_registration.cc | 5 +- tensorflow/core/distributed_runtime/master.cc | 3 + .../core/distributed_runtime/preemption/BUILD | 46 +- .../preemption/check_preemption_op_kernel.cc | 13 +- .../preemption/preemption_sync_manager.h | 49 +- tensorflow/core/distributed_runtime/rpc/BUILD | 52 +- .../rpc/coordination/BUILD | 34 +- .../coordination/grpc_coordination_client.h | 13 +- .../grpc_coordination_service_impl.h | 90 +- .../core/distributed_runtime/rpc/eager/BUILD | 5 + .../rpc/grpc_client_cq_tag.h | 24 +- .../rpc/grpc_remote_worker.cc | 10 +- .../rpc/grpc_server_lib.cc | 16 +- .../distributed_runtime/rpc/grpc_server_lib.h | 7 +- .../distributed_runtime/rpc/grpc_state.cc | 3 +- .../core/distributed_runtime/rpc/grpc_state.h | 205 +- .../core/distributed_runtime/rpc/grpc_util.cc | 55 +- .../core/distributed_runtime/rpc/grpc_util.h | 23 +- .../distributed_runtime/rpc/grpc_util_test.cc | 193 - .../rpc/grpc_worker_service.cc | 4 + .../rpc/rpc_rendezvous_mgr_test.cc | 4 +- .../core/distributed_runtime/server_lib.h | 7 +- .../core/distributed_runtime/session_mgr.cc | 55 +- .../core/distributed_runtime/session_mgr.h | 19 +- .../distributed_runtime/session_mgr_test.cc | 3 +- .../distributed_runtime/worker_session.cc | 24 +- .../core/distributed_runtime/worker_session.h | 23 +- tensorflow/core/example/BUILD | 15 + tensorflow/core/example/feature_util.cc | 31 - tensorflow/core/example/feature_util.h | 61 + tensorflow/core/example/testdata/BUILD | 1 + tensorflow/core/framework/BUILD | 63 +- tensorflow/core/framework/collective.h | 1 + tensorflow/core/framework/common_shape_fns.cc | 2 +- tensorflow/core/framework/common_shape_fns.h | 2 +- tensorflow/core/framework/dataset.cc | 58 +- tensorflow/core/framework/dataset.h | 470 +- tensorflow/core/framework/dataset.proto | 20 +- .../core/framework/dataset_options.proto | 16 +- tensorflow/core/framework/float8.cc | 458 -- tensorflow/core/framework/float8.h | 351 - tensorflow/core/framework/full_type.proto | 11 + tensorflow/core/framework/function.h | 12 +- tensorflow/core/framework/function_testlib.cc | 18 + tensorflow/core/framework/function_testlib.h | 3 + .../framework/graph_to_functiondef_test.cc | 256 + tensorflow/core/framework/metrics.cc | 14 +- tensorflow/core/framework/model.cc | 71 +- tensorflow/core/framework/model.h | 17 +- tensorflow/core/framework/model_test.cc | 359 +- 
tensorflow/core/framework/node_def_builder.cc | 5 + .../core/framework/node_def_util_test.cc | 3 +- .../core/framework/op_def_builder_test.cc | 6 +- tensorflow/core/framework/op_kernel.h | 18 +- tensorflow/core/framework/op_kernel_test.cc | 140 +- tensorflow/core/framework/op_requires.h | 2 +- .../framework/optimized_function_graph.proto | 26 + .../framework/partial_tensor_shape_test.cc | 166 +- tensorflow/core/framework/register_types.h | 11 +- .../core/framework/register_types_traits.h | 6 +- tensorflow/core/framework/registration/BUILD | 1 + tensorflow/core/framework/tensor.cc | 64 +- tensorflow/core/framework/tensor.proto | 4 + tensorflow/core/framework/tensor_shape.cc | 6 +- tensorflow/core/framework/tensor_shape.h | 7 +- .../core/framework/tensor_shape_fuzz.cc | 89 + .../core/framework/tensor_shape_test.cc | 321 + tensorflow/core/framework/tensor_test.cc | 418 +- tensorflow/core/framework/tensor_testutil.cc | 4 + tensorflow/core/framework/tensor_util.cc | 4 +- tensorflow/core/framework/types.cc | 14 + tensorflow/core/framework/types.h | 14 +- tensorflow/core/framework/types.proto | 7 +- tensorflow/core/framework/types_test.cc | 7 + tensorflow/core/function/capture/BUILD | 62 + .../function/capture/by_ref_capture_test.py | 106 + .../function/capture/capture_container.py | 148 + .../capture/capture_container_test.py | 120 + .../core/function/capture/restore_captures.py | 133 + .../core/function/integration_test/BUILD | 3 +- tensorflow/core/function/polymorphism/BUILD | 32 + .../polymorphism/function_cache_test.py | 9 + .../function/polymorphism/function_type.proto | 36 + .../function/polymorphism/function_type.py | 243 +- .../polymorphism/function_type_test.py | 135 +- .../polymorphism/type_dispatch_test.py | 3 + .../core/function/{ => runtime_client}/BUILD | 2 + .../{ => runtime_client}/runtime_client.cc | 2 +- .../{ => runtime_client}/runtime_client.h | 6 +- .../{ => runtime_client}/runtime_client.py | 2 +- .../runtime_client_pybind.cc | 2 +- .../runtime_client_test.cc | 2 +- .../runtime_client_test.py | 2 +- tensorflow/core/function/testing/BUILD | 6 +- tensorflow/core/function/trace_type/BUILD | 3 + .../core/function/trace_type/__init__.py | 2 + .../core/function/trace_type/default_types.py | 161 +- .../function/trace_type/default_types_test.py | 3 + .../function/trace_type/trace_type_builder.py | 92 +- .../function/trace_type/trace_type_test.py | 109 +- tensorflow/core/function/transform/BUILD | 8 +- .../core/function/transform/transform.py | 117 +- .../core/function/transform/transform_test.py | 229 + tensorflow/core/graph/BUILD | 68 +- tensorflow/core/graph/costmodel.cc | 40 +- tensorflow/core/graph/costmodel.h | 2 +- tensorflow/core/graph/costmodel_test.cc | 571 +- tensorflow/core/graph/mkl_graph_util.h | 7 +- tensorflow/core/graph/mkl_testlib.cc | 9 + tensorflow/core/graph/mkl_testlib.h | 2 + tensorflow/core/graph/regularization/BUILD | 1 + tensorflow/core/graph/validate_test.cc | 77 + tensorflow/core/graph/zen_graph_util.h | 83 + tensorflow/core/grappler/BUILD | 3 +- tensorflow/core/grappler/clusters/BUILD | 1 + tensorflow/core/grappler/costs/BUILD | 2 + .../core/grappler/costs/cost_estimator.cc | 2 + .../core/grappler/costs/cost_estimator.h | 5 + .../core/grappler/costs/graph_properties.cc | 37 +- .../costs/graph_properties_testdata/BUILD | 1 + tensorflow/core/grappler/devices.cc | 8 +- tensorflow/core/grappler/graph_analyzer/BUILD | 1 + .../core/grappler/grappler_item_builder.cc | 2 +- tensorflow/core/grappler/inputs/BUILD | 1 + .../core/grappler/inputs/testdata/BUILD | 1 + 
.../core/grappler/mutable_graph_view.cc | 15 +- tensorflow/core/grappler/optimizers/BUILD | 12 +- .../optimizers/auto_mixed_precision_lists.h | 13 +- .../optimizers/auto_mixed_precision_test.cc | 2 +- .../optimizers/constant_folding_test.cc | 107 +- .../core/grappler/optimizers/data/BUILD | 1 + .../grappler/optimizers/data/auto_shard.cc | 8 +- .../grappler/optimizers/function_optimizer.cc | 9 +- .../optimizers/generic_layout_optimizer.cc | 97 +- .../generic_layout_optimizer_transposer.cc | 5 + .../core/grappler/optimizers/inference/BUILD | 1 + .../grappler/optimizers/mkl_remapper_test.cc | 225 +- .../core/grappler/optimizers/remapper.cc | 400 +- .../optimizers/tfg_optimizer_hook_test.cc | 5 +- tensorflow/core/grappler/utils/BUILD | 1 + .../core/grappler/utils/pattern_utils.h | 12 + tensorflow/core/grappler/verifiers/BUILD | 1 + tensorflow/core/ir/BUILD | 2 +- tensorflow/core/ir/dialect.h | 76 +- tensorflow/core/ir/dialect.td | 4 +- tensorflow/core/ir/importexport/BUILD | 3 + .../ir/importexport/convert_attributes.cc | 3 +- .../core/ir/importexport/convert_tensor.cc | 96 +- .../core/ir/importexport/convert_tensor.h | 11 + .../core/ir/importexport/convert_types.cc | 16 +- .../ir/importexport/functiondef_import.cc | 40 +- .../core/ir/importexport/graphdef_export.cc | 24 +- .../core/ir/importexport/graphdef_import.cc | 2 +- tensorflow/core/ir/importexport/tests/BUILD | 1 + .../importexport/tests/graphdef_to_mlir/BUILD | 1 + ...id_generic_function_named_edge_index.pbtxt | 52 + .../importexport/tests/mlir_to_graphdef/BUILD | 1 + .../ir/importexport/tests/roundtrip/BUILD | 1 + .../importexport/tests/roundtrip/roundtrip.h | 1 - .../ir/importexport/tests/saved_model/BUILD | 1 + .../core/ir/importexport/tfg-translate.cc | 2 + tensorflow/core/ir/ops.cc | 206 +- tensorflow/core/ir/ops.td | 12 +- tensorflow/core/ir/tests/BUILD | 1 + tensorflow/core/ir/tests/types.mlir | 4 + tensorflow/core/ir/tf_op_names.cc | 5 + tensorflow/core/ir/tf_op_names.inc | 3 + tensorflow/core/ir/types/BUILD | 3 +- tensorflow/core/ir/types/attributes.td | 21 +- tensorflow/core/ir/types/dialect.cc | 109 +- tensorflow/core/ir/types/dialect.h | 4 +- tensorflow/core/ir/types/dialect.td | 1 + tensorflow/core/ir/types/types.def | 2 + .../core/ir/utils/shape_inference_utils.cc | 46 +- .../core/ir/utils/shape_inference_utils.h | 8 +- tensorflow/core/kernels/BUILD | 279 +- .../core/kernels/aggregate_ops_gpu.cu.cc | 3 +- tensorflow/core/kernels/argmax_op.cc | 5 +- tensorflow/core/kernels/avgpooling_op.cc | 30 +- .../core/kernels/avgpooling_op_gpu.cu.cc | 32 +- tensorflow/core/kernels/batch_kernels.cc | 11 +- tensorflow/core/kernels/batching_util/BUILD | 1 + .../adaptive_shared_batch_scheduler.h | 3 + .../batching_util/batch_resource_base.cc | 22 +- .../batching_util/batch_resource_base.h | 13 +- .../batching_util/shared_batch_scheduler.h | 44 +- .../shared_batch_scheduler_test.cc | 6 +- tensorflow/core/kernels/batchtospace_op.cc | 28 +- tensorflow/core/kernels/bias_op_gpu.cu.cc | 8 +- .../core/kernels/broadcast_to_op_gpu.cu.cc | 2 +- .../core/kernels/candidate_sampler_ops.cc | 17 +- tensorflow/core/kernels/cast_op.cc | 28 + tensorflow/core/kernels/cast_op.h | 108 +- tensorflow/core/kernels/cast_op_gpu.cu.cc | 13 + tensorflow/core/kernels/cast_op_impl.h | 8 + .../core/kernels/cast_op_impl_bfloat.cc | 13 +- tensorflow/core/kernels/cast_op_impl_float.cc | 4 + .../core/kernels/cast_op_impl_float8.cc | 61 + tensorflow/core/kernels/cast_op_impl_half.cc | 8 +- tensorflow/core/kernels/collective_nccl.cc | 6 + 
.../kernels/collective_nccl_all_to_all.cc | 34 + .../core/kernels/collective_nccl_all_to_all.h | 35 + .../core/kernels/collective_nccl_reducer.cc | 1 + .../core/kernels/collective_nccl_reducer.h | 9 + .../core/kernels/collective_nccl_test.cc | 514 +- tensorflow/core/kernels/collective_ops.cc | 107 +- tensorflow/core/kernels/concat_lib.h | 1 - tensorflow/core/kernels/concat_lib_gpu.cc | 1 - tensorflow/core/kernels/concat_lib_gpu.h | 1 - .../core/kernels/concat_lib_gpu_impl.cu.cc | 4 - tensorflow/core/kernels/concat_op.cc | 3 +- tensorflow/core/kernels/constant_op.cc | 2 +- tensorflow/core/kernels/conv_2d.h | 26 +- .../core/kernels/conv_2d_gpu_bfloat16.cu.cc | 51 + .../core/kernels/conv_grad_filter_ops.cc | 129 +- .../core/kernels/conv_grad_input_ops.cc | 146 +- tensorflow/core/kernels/conv_grad_input_ops.h | 2 +- tensorflow/core/kernels/conv_grad_ops_3d.cc | 1191 ++-- .../core/kernels/conv_grad_shape_utils.cc | 5 +- tensorflow/core/kernels/conv_ops.cc | 167 +- tensorflow/core/kernels/conv_ops_3d.cc | 644 +- .../kernels/conv_ops_fused_image_transform.cc | 9 +- tensorflow/core/kernels/conv_ops_fused_impl.h | 42 +- .../core/kernels/conv_ops_fused_int8.cc | 18 +- tensorflow/core/kernels/conv_ops_gpu.cc | 86 +- tensorflow/core/kernels/conv_ops_gpu.h | 23 +- tensorflow/core/kernels/conv_ops_test.cc | 17 +- .../core/kernels/conv_ops_using_gemm.cc | 6 +- tensorflow/core/kernels/cudnn_pooling_gpu.cc | 180 +- tensorflow/core/kernels/cudnn_rnn_ops.cc | 14 +- tensorflow/core/kernels/cwise_op_abs.cc | 2 + tensorflow/core/kernels/cwise_op_add_1.cc | 3 + tensorflow/core/kernels/cwise_op_clip.cc | 110 +- tensorflow/core/kernels/cwise_op_div.cc | 4 + .../core/kernels/cwise_op_gpu_abs.cu.cc | 2 + .../core/kernels/cwise_op_gpu_add.cu.cc | 2 + .../core/kernels/cwise_op_gpu_div.cu.cc | 4 + .../core/kernels/cwise_op_gpu_inverse.cu.cc | 4 +- .../core/kernels/cwise_op_gpu_maximum.cu.cc | 2 + .../core/kernels/cwise_op_gpu_minimum.cu.cc | 2 + .../core/kernels/cwise_op_gpu_mul.cu.cc | 5 +- .../core/kernels/cwise_op_gpu_sign.cu.cc | 2 + .../core/kernels/cwise_op_gpu_square.cu.cc | 2 + .../cwise_op_gpu_squared_difference.cu.cc | 2 + .../core/kernels/cwise_op_gpu_sub.cu.cc | 2 + .../core/kernels/cwise_op_gpu_xdivy.cu.cc | 2 + tensorflow/core/kernels/cwise_op_maximum.cc | 2 + tensorflow/core/kernels/cwise_op_minimum.cc | 2 + tensorflow/core/kernels/cwise_op_mul_1.cc | 3 + .../core/kernels/cwise_op_reciprocal.cc | 10 +- tensorflow/core/kernels/cwise_op_sign.cc | 2 + tensorflow/core/kernels/cwise_op_square.cc | 2 + .../kernels/cwise_op_squared_difference.cc | 2 + tensorflow/core/kernels/cwise_op_sub.cc | 2 + tensorflow/core/kernels/cwise_op_xdivy.cc | 5 +- tensorflow/core/kernels/cwise_op_xlog1py.cc | 4 +- tensorflow/core/kernels/cwise_op_xlogy.cc | 4 +- tensorflow/core/kernels/cwise_ops.h | 99 +- tensorflow/core/kernels/cwise_ops_common.h | 27 +- tensorflow/core/kernels/data/BUILD | 8 + .../core/kernels/data/batch_dataset_op.cc | 13 +- .../kernels/data/concatenate_dataset_op.cc | 43 +- .../core/kernels/data/experimental/BUILD | 22 + .../assert_cardinality_dataset_op.cc | 2 + .../directed_interleave_dataset_op.cc | 38 +- .../data/experimental/distributed_save_op.cc | 107 + .../data/experimental/distributed_save_op.h | 46 + .../data/experimental/list_dataset_op.cc | 2 + .../experimental/map_and_batch_dataset_op.cc | 15 +- .../data/experimental/random_access_ops.h | 19 +- .../data/experimental/random_dataset_op.cc | 199 +- .../data/experimental/random_dataset_op.h | 4 + .../experimental/random_dataset_op_test.cc | 228 +- 
.../data/experimental/scan_dataset_op.cc | 31 +- .../core/kernels/data/experimental/sql/BUILD | 1 + .../experimental/take_while_dataset_op.cc | 14 +- .../data/experimental/unbatch_dataset_op.cc | 14 +- .../core/kernels/data/filter_dataset_op.cc | 27 +- .../core/kernels/data/flat_map_dataset_op.cc | 45 +- .../kernels/data/interleave_dataset_op.cc | 26 +- tensorflow/core/kernels/data/iterator_ops.cc | 118 +- tensorflow/core/kernels/data/iterator_ops.h | 31 +- .../core/kernels/data/map_dataset_op.cc | 10 +- .../kernels/data/padded_batch_dataset_op.cc | 22 +- .../kernels/data/parallel_map_dataset_op.cc | 82 +- .../core/kernels/data/prefetch_autotuner.cc | 26 +- .../core/kernels/data/prefetch_autotuner.h | 34 +- .../kernels/data/prefetch_autotuner_test.cc | 35 + .../core/kernels/data/prefetch_dataset_op.cc | 24 +- .../core/kernels/data/range_dataset_op.cc | 24 +- .../kernels/data/range_dataset_op_test.cc | 20 +- .../core/kernels/data/repeat_dataset_op.cc | 36 +- .../core/kernels/data/shard_dataset_op.cc | 13 +- .../core/kernels/data/skip_dataset_op.cc | 14 +- .../core/kernels/data/take_dataset_op.cc | 14 +- .../core/kernels/data/tensor_dataset_op.cc | 10 +- .../kernels/data/tensor_slice_dataset_op.cc | 2 + .../core/kernels/data/zip_dataset_op.cc | 21 +- .../core/kernels/decode_padded_raw_op.cc | 21 +- tensorflow/core/kernels/decode_proto_op.cc | 4 +- tensorflow/core/kernels/decode_raw_op.cc | 4 +- .../kernels/dense_update_functor_gpu.cu.cc | 4 +- tensorflow/core/kernels/dense_update_ops.cc | 1 - tensorflow/core/kernels/depthtospace_op.cc | 15 +- .../core/kernels/depthtospace_op_gpu.cu.cc | 6 + .../core/kernels/depthwise_conv_grad_op.cc | 77 +- tensorflow/core/kernels/depthwise_conv_op.cc | 38 +- tensorflow/core/kernels/depthwise_conv_op.h | 6 + .../core/kernels/depthwise_conv_op_gpu.h | 4 + .../depthwise_conv_op_gpu_bfloat16.cu.cc | 30 + tensorflow/core/kernels/diag_op.cc | 6 +- .../core/kernels/dynamic_partition_op.cc | 4 +- .../kernels/dynamic_partition_op_gpu.cu.cc | 36 +- tensorflow/core/kernels/dynamic_stitch_op.cc | 21 +- tensorflow/core/kernels/edit_distance_op.cc | 5 +- .../eigen_backward_spatial_convolutions.h | 2 +- tensorflow/core/kernels/eigen_benchmark.h | 2 +- .../core/kernels/eigen_cuboid_convolution.h | 4 +- .../kernels/eigen_cuboid_convolutions_test.cc | 1309 ++++ .../eigen_mkldnn_contraction_kernel_test.cc | 2 +- tensorflow/core/kernels/eigen_pooling.h | 23 +- tensorflow/core/kernels/encode_proto_op.cc | 3 +- .../core/kernels/example_parsing_ops.cc | 10 +- tensorflow/core/kernels/fft_ops.cc | 203 +- tensorflow/core/kernels/fill_functor.cc | 4 + .../core/kernels/fractional_avg_pool_op.cc | 5 +- .../core/kernels/fractional_max_pool_op.cc | 2 +- tensorflow/core/kernels/function_ops.cc | 6 +- tensorflow/core/kernels/functional_ops.cc | 10 +- .../core/kernels/fused_batch_norm_op.cc | 407 +- .../core/kernels/fused_batch_norm_op.cu.cc | 3 + tensorflow/core/kernels/fuzzing/BUILD | 1 + tensorflow/core/kernels/gather_functor.cc | 1 + .../core/kernels/gather_functor_batched.cc | 1 + .../kernels/gather_functor_batched_gpu.cu.cc | 1 + .../core/kernels/gather_functor_gpu.cu.cc | 1 + tensorflow/core/kernels/gather_nd_op.cc | 1 + tensorflow/core/kernels/gather_op.cc | 8 +- tensorflow/core/kernels/gemm_functors.h | 3 +- tensorflow/core/kernels/gpu_prim.h | 2 - tensorflow/core/kernels/gpu_utils.cc | 4 +- tensorflow/core/kernels/histogram_op.cc | 21 +- tensorflow/core/kernels/image/BUILD | 1 + .../core/kernels/image/crop_and_resize_op.cc | 4 +- .../image/crop_and_resize_op_gpu.cu.cc | 4 +- 
.../core/kernels/image/decode_image_op.cc | 5 +- .../core/kernels/image/mirror_pad_op.cc | 2 +- .../kernels/image/non_max_suppression_op.cc | 11 +- .../image/non_max_suppression_op.cu.cc | 29 +- .../core/kernels/image/resize_area_op_test.cc | 22 +- .../image/resize_bilinear_op_gpu.cu.cc | 4 +- .../kernels/image/scale_and_translate_op.cc | 11 +- tensorflow/core/kernels/inplace_ops.cc | 8 +- .../kernels/inplace_ops_functor_gpu.cu.cc | 3 + tensorflow/core/kernels/l2loss_op_gpu.cu.cc | 1 + tensorflow/core/kernels/linalg/BUILD | 1 + .../linalg/banded_triangular_solve_op.cc | 4 +- .../core/kernels/linalg/cholesky_op_gpu.cu.cc | 4 +- .../core/kernels/linalg/einsum_op_impl.h | 21 +- .../kernels/linalg/einsum_op_impl_bfloat16.cc | 6 + .../core/kernels/linalg/linalg_ops_common.cc | 3 +- tensorflow/core/kernels/linalg/lu_op.cc | 5 +- .../core/kernels/linalg/matrix_diag_op.cc | 11 +- .../core/kernels/linalg/matrix_inverse_op.cc | 46 + .../core/kernels/linalg/matrix_set_diag_op.cc | 6 +- .../linalg/matrix_triangular_solve_op_impl.h | 4 +- tensorflow/core/kernels/list_kernels.cc | 10 +- tensorflow/core/kernels/list_kernels.cu.cc | 4 +- tensorflow/core/kernels/lrn_op.cc | 76 +- tensorflow/core/kernels/matmul_op.h | 2 +- tensorflow/core/kernels/matmul_op_fused.cc | 33 +- tensorflow/core/kernels/matmul_op_impl.h | 40 +- tensorflow/core/kernels/matmul_op_real.cc | 2 + tensorflow/core/kernels/matmul_util.h | 8 +- tensorflow/core/kernels/maxpooling_op.cc | 68 +- .../core/kernels/maxpooling_op_gpu.cu.cc | 3 +- tensorflow/core/kernels/mkl/BUILD | 24 +- .../core/kernels/mkl/mkl_avgpooling_op.cc | 38 +- .../kernels/mkl/mkl_conv_grad_filter_ops.cc | 22 +- tensorflow/core/kernels/mkl/mkl_conv_ops.cc | 28 +- tensorflow/core/kernels/mkl/mkl_conv_ops.h | 22 +- tensorflow/core/kernels/mkl/mkl_einsum_op.cc | 47 +- .../core/kernels/mkl/mkl_kernel_util.cc | 93 + tensorflow/core/kernels/mkl/mkl_kernel_util.h | 79 +- .../core/kernels/mkl/mkl_layer_norm_op.cc | 10 +- .../core/kernels/mkl/mkl_matmul_ops_common.h | 2 +- .../core/kernels/mkl/mkl_maxpooling_op.cc | 13 + tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc | 60 +- .../core/kernels/mkl/mkl_qmatmul_op_test.cc | 80 +- .../mkl/mkl_quantized_conv_ops_test.cc | 2 - tensorflow/core/kernels/mkl/mkl_softmax_op.cc | 109 +- .../core/kernels/mkl/mkl_swish_op_test.cc | 1 - .../kernels/mkl/onednn_nn_ops_benchmark.cc | 115 + tensorflow/core/kernels/mlir_generated/BUILD | 1592 +++-- .../kernels/mlir_generated/base_ops_test.cc | 5 + .../kernels/mlir_generated/base_ops_test.h | 11 +- .../mlir_generated/base_unary_ops_test.h | 3 +- .../kernels/mlir_generated/build_defs.bzl | 2 +- .../mlir_generated/gpu_binary_ops_test.cc | 19 + .../mlir_generated/gpu_op_truncate_div.cc | 3 + .../gpu_unary_ops_large_tensor_test.cc | 67 + .../core/kernels/multinomial_op_gpu.cu.cc | 12 +- tensorflow/core/kernels/nccl_ops.cc | 4 + tensorflow/core/kernels/nth_element_op.cc | 2 +- tensorflow/core/kernels/one_hot_op.cc | 1 + tensorflow/core/kernels/ops_testutil.cc | 2 + tensorflow/core/kernels/pack_op.cc | 1 - tensorflow/core/kernels/pad_op.cc | 1 - tensorflow/core/kernels/pad_op_gpu.cu.cc | 3 +- tensorflow/core/kernels/pooling_ops_3d.cc | 32 +- tensorflow/core/kernels/pooling_ops_3d.h | 2 +- .../core/kernels/pooling_ops_3d_gpu.cu.cc | 3 +- tensorflow/core/kernels/pooling_ops_common.cc | 297 +- tensorflow/core/kernels/pooling_ops_common.h | 12 +- .../core/kernels/quantized_pooling_ops.cc | 9 +- tensorflow/core/kernels/ragged_range_op.cc | 12 +- tensorflow/core/kernels/range_sampler.cc | 41 +- 
tensorflow/core/kernels/range_sampler.h | 14 +- tensorflow/core/kernels/range_sampler_test.cc | 58 +- .../core/kernels/reduction_gpu_kernels.cu.h | 130 +- tensorflow/core/kernels/reduction_ops.h | 1 + .../core/kernels/reduction_ops_euclidean.cc | 1 - .../kernels/reduction_ops_gpu_bfloat16.cu.cc | 75 + tensorflow/core/kernels/reduction_ops_max.cc | 1 + tensorflow/core/kernels/reduction_ops_mean.cc | 1 - tensorflow/core/kernels/reduction_ops_min.cc | 1 + tensorflow/core/kernels/reduction_ops_prod.cc | 1 - tensorflow/core/kernels/relu_op.cc | 88 +- tensorflow/core/kernels/reshape_util.cc | 4 +- .../core/kernels/resource_variable_ops.cc | 41 +- tensorflow/core/kernels/reverse_op.cc | 3 + tensorflow/core/kernels/reverse_op_gpu.cu.cc | 3 +- .../core/kernels/reverse_sequence_op.cc | 1 + tensorflow/core/kernels/risc/BUILD | 1 + .../core/kernels/risc/experimental/BUILD | 1 + tensorflow/core/kernels/rnn/BUILD | 3 +- tensorflow/core/kernels/rnn/blas_gemm.h | 2 +- tensorflow/core/kernels/rnn/gru_ops.cc | 88 +- tensorflow/core/kernels/rnn/lstm_ops.cc | 8 + tensorflow/core/kernels/scan_ops.cc | 8 +- .../core/kernels/scan_ops_gpu_bfloat16.cu.cc | 33 + tensorflow/core/kernels/scatter_functor.cc | 1 + .../core/kernels/scatter_functor_gpu.cu.cc | 1 + tensorflow/core/kernels/scatter_nd_op.cc | 2 +- tensorflow/core/kernels/scatter_nd_util.h | 17 + tensorflow/core/kernels/scatter_op.cc | 2 - tensorflow/core/kernels/scatter_op_gpu.cu.cc | 1 + tensorflow/core/kernels/sdca_internal.cc | 2 +- .../kernels/segment_reduction_ops_gpu.cu.h | 5 + .../core/kernels/segment_reduction_ops_impl.h | 4 +- tensorflow/core/kernels/sequence_ops.cc | 4 +- tensorflow/core/kernels/set_kernels.cc | 13 +- tensorflow/core/kernels/slice_op.cc | 4 +- tensorflow/core/kernels/slice_op_gpu.cu.cc | 4 +- tensorflow/core/kernels/sobol_op.cc | 4 + tensorflow/core/kernels/softmax_op_gpu.cu.cc | 50 +- tensorflow/core/kernels/softplus_op.cc | 1 + tensorflow/core/kernels/softplus_op_gpu.cu.cc | 3 +- tensorflow/core/kernels/softsign_op.cc | 2 +- tensorflow/core/kernels/softsign_op_gpu.cu.cc | 3 +- .../kernels/spacetobatch_functor_gpu.cu.cc | 2 +- tensorflow/core/kernels/spacetodepth_op.cc | 15 +- .../core/kernels/spacetodepth_op_gpu.cu.cc | 6 + tensorflow/core/kernels/sparse/BUILD | 2 + .../sparse/dense_to_csr_sparse_matrix_op.cc | 9 +- tensorflow/core/kernels/sparse/kernels.cc | 53 +- tensorflow/core/kernels/sparse/kernels.h | 2 +- .../core/kernels/sparse/kernels_test.cc | 93 +- tensorflow/core/kernels/sparse/mat_mul_op.cc | 56 +- tensorflow/core/kernels/sparse/nnz_op.cc | 3 +- .../core/kernels/sparse/sparse_mat_mul_op.cc | 177 +- .../sparse_tensor_to_csr_sparse_matrix_op.cc | 7 +- .../kernels/sparse_dense_binary_op_shared.cc | 2 +- tensorflow/core/kernels/sparse_matmul_op.cc | 2 +- .../kernels/sparse_sparse_binary_op_shared.cc | 43 +- tensorflow/core/kernels/sparse_to_dense_op.cc | 4 +- .../core/kernels/sparse_to_dense_op_gpu.cu.cc | 6 +- tensorflow/core/kernels/sparse_xent_op.cc | 2 + .../core/kernels/sparse_xent_op_gpu.cu.cc | 5 +- tensorflow/core/kernels/special_math/BUILD | 1 + .../core/kernels/spectrogram_op_test.cc | 41 + .../core/kernels/spectrogram_test_data/BUILD | 1 + tensorflow/core/kernels/split_lib_gpu.cu.cc | 7 +- tensorflow/core/kernels/split_lib_gpu.h | 1 - tensorflow/core/kernels/split_op.cc | 3 - tensorflow/core/kernels/split_v_op.cc | 60 +- tensorflow/core/kernels/string_util.h | 2 +- tensorflow/core/kernels/sync_ops.cc | 59 + tensorflow/core/kernels/tensor_array.h | 1 + tensorflow/core/kernels/tensor_array_ops.cc | 10 +- 
tensorflow/core/kernels/tensor_cord.h | 2 +- .../kernels/tile_functor_gpu_bfloat16.cu.cc} | 20 +- tensorflow/core/kernels/tile_ops.cc | 5 + tensorflow/core/kernels/tile_ops_gpu_impl.h | 25 +- tensorflow/core/kernels/topk_op.cc | 1 + tensorflow/core/kernels/topk_op_gpu.h | 5 +- .../kernels/topk_op_gpu_bfloat16.cu.cc} | 15 +- tensorflow/core/kernels/transpose_op.cc | 2 +- .../core/kernels/uniform_quant_ops/BUILD | 26 +- .../kernels/uniform_quant_ops/math_utils.h | 127 + .../uniform_quantized_add_op.cc | 268 + .../uniform_quantized_add_op_test.cc | 491 ++ .../uniform_quantized_convolution_ops.cc | 307 + .../uniform_quantized_convolution_ops_test.cc | 800 +++ .../uniform_requantize_op.cc | 125 - tensorflow/core/kernels/unpack_op.cc | 1 - tensorflow/core/kernels/while_op_test.cc | 2 +- tensorflow/core/kernels/xent_op.cc | 21 +- tensorflow/core/kernels/xent_op_gpu.cu.cc | 3 +- tensorflow/core/lib/bfloat16/BUILD | 1 + tensorflow/core/lib/bmp/BUILD | 1 + tensorflow/core/lib/bmp/testdata/BUILD | 1 + tensorflow/core/lib/core/BUILD | 1 + tensorflow/core/lib/core/status_test.cc | 8 +- tensorflow/core/lib/db/BUILD | 1 + tensorflow/core/lib/gif/BUILD | 1 + tensorflow/core/lib/gif/gif_io.cc | 14 +- tensorflow/core/lib/gif/gif_io_test.cc | 6 +- .../core/lib/gif/testdata/3g_multiframe.gif | Bin 0 -> 22703 bytes tensorflow/core/lib/gif/testdata/BUILD | 2 + tensorflow/core/lib/gtl/BUILD | 1 + tensorflow/core/lib/gtl/subtle/BUILD | 1 + tensorflow/core/lib/hash/BUILD | 3 + tensorflow/core/lib/histogram/BUILD | 1 + tensorflow/core/lib/io/BUILD | 1 + tensorflow/core/lib/jpeg/BUILD | 1 + tensorflow/core/lib/jpeg/testdata/BUILD | 1 + tensorflow/core/lib/llvm_rtti/BUILD | 1 + tensorflow/core/lib/lmdb/BUILD | 1 + tensorflow/core/lib/lmdb/testdata/BUILD | 1 + tensorflow/core/lib/math/BUILD | 1 + tensorflow/core/lib/monitoring/BUILD | 1 + .../core/lib/monitoring/cell_reader_test.cc | 2 +- tensorflow/core/lib/png/BUILD | 1 + tensorflow/core/lib/png/testdata/BUILD | 1 + tensorflow/core/lib/psnr/BUILD | 1 + tensorflow/core/lib/psnr/testdata/BUILD | 1 + tensorflow/core/lib/random/BUILD | 1 + tensorflow/core/lib/ssim/BUILD | 1 + tensorflow/core/lib/ssim/testdata/BUILD | 1 + tensorflow/core/lib/strings/BUILD | 3 + tensorflow/core/lib/wav/BUILD | 1 + tensorflow/core/nccl/BUILD | 5 +- .../core/nccl/collective_communicator.cc | 16 + tensorflow/core/nccl/nccl_manager.cc | 71 + tensorflow/core/nccl/nccl_manager.h | 14 + tensorflow/core/nccl/nccl_manager_test.cc | 4 +- tensorflow/core/ops/BUILD | 19 +- tensorflow/core/ops/array_ops.cc | 39 +- tensorflow/core/ops/array_ops_test.cc | 56 + tensorflow/core/ops/audio_ops.cc | 10 + tensorflow/core/ops/collective_ops.cc | 29 + tensorflow/core/ops/compat/BUILD | 1 + .../core/ops/compat/ops_history.v2.pbtxt | 100 + .../core/ops/compat/ops_history_v1/Acos.pbtxt | 25 + .../core/ops/compat/ops_history_v1/Asin.pbtxt | 25 + .../core/ops/compat/ops_history_v1/Atan.pbtxt | 26 + .../core/ops/compat/ops_history_v1/BUILD | 1 + .../core/ops/compat/ops_history_v1/Tan.pbtxt | 25 + .../compat/ops_history_v2/AccumulateNV2.pbtxt | 51 + .../AccumulatorApplyGradient.pbtxt | 43 + .../AccumulatorTakeGradient.pbtxt | 43 + .../core/ops/compat/ops_history_v2/Acos.pbtxt | 25 + .../core/ops/compat/ops_history_v2/AddN.pbtxt | 48 + .../ops/compat/ops_history_v2/AllToAll.pbtxt | 56 + .../compat/ops_history_v2/ApplyAdaMax.pbtxt | 81 + .../compat/ops_history_v2/ApplyAdadelta.pbtxt | 73 + .../compat/ops_history_v2/ApplyAdagrad.pbtxt | 67 + .../ops_history_v2/ApplyAdagradDA.pbtxt | 77 + 
.../ops_history_v2/ApplyAdagradV2.pbtxt | 71 + .../ops/compat/ops_history_v2/ApplyAdam.pbtxt | 92 + .../compat/ops_history_v2/ApplyAddSign.pbtxt | 72 + .../ops_history_v2/ApplyCenteredRMSProp.pbtxt | 82 + .../ops/compat/ops_history_v2/ApplyFtrl.pbtxt | 84 + .../compat/ops_history_v2/ApplyFtrlV2.pbtxt | 88 + .../ops_history_v2/ApplyGradientDescent.pbtxt | 55 + .../compat/ops_history_v2/ApplyMomentum.pbtxt | 71 + .../ops_history_v2/ApplyPowerSign.pbtxt | 72 + .../ops_history_v2/ApplyProximalAdagrad.pbtxt | 68 + .../ApplyProximalGradientDescent.pbtxt | 63 + .../compat/ops_history_v2/ApplyRMSProp.pbtxt | 77 + .../ops_history_v2/ApproximateEqual.pbtxt | 50 + .../ops/compat/ops_history_v2/ArgMax.pbtxt | 156 +- .../ops/compat/ops_history_v2/ArgMin.pbtxt | 146 +- .../core/ops/compat/ops_history_v2/Asin.pbtxt | 25 + .../ops/compat/ops_history_v2/AssignAdd.pbtxt | 51 + .../ops/compat/ops_history_v2/AssignSub.pbtxt | 51 + .../core/ops/compat/ops_history_v2/Atan.pbtxt | 25 + .../core/ops/compat/ops_history_v2/BUILD | 1 + .../compat/ops_history_v2/BatchMatMulV2.pbtxt | 50 + .../BatchNormWithGlobalNormalization.pbtxt | 65 + ...BatchNormWithGlobalNormalizationGrad.pbtxt | 81 + .../ops/compat/ops_history_v2/BiasAdd.pbtxt | 55 + .../compat/ops_history_v2/BiasAddGrad.pbtxt | 51 + .../ops/compat/ops_history_v2/BiasAddV1.pbtxt | 42 + .../compat/ops_history_v2/ClipByValue.pbtxt | 46 + .../ops_history_v2/CollectivePermute.pbtxt | 43 + .../CollectiveReduceScatterV2.pbtxt | 95 + .../ComputeDedupDataTupleMask.pbtxt | 12 + .../ConditionalAccumulator.pbtxt | 67 + .../Conv2DBackpropFilterV2.pbtxt | 86 + .../Conv2DBackpropInputV2.pbtxt | 87 + .../ops/compat/ops_history_v2/Cumprod.pbtxt | 69 + .../ops/compat/ops_history_v2/Cumsum.pbtxt | 69 + .../ops_history_v2/CumulativeLogsumexp.pbtxt | 54 + .../ops_history_v2/DistributedSave.pbtxt | 23 + .../compat/ops_history_v2/EuclideanNorm.pbtxt | 62 + .../ops/compat/ops_history_v2/MatMul.pbtxt | 49 + .../core/ops/compat/ops_history_v2/Mean.pbtxt | 62 + .../ops_history_v2/MergeDedupData.pbtxt | 42 + .../core/ops/compat/ops_history_v2/Prod.pbtxt | 62 + .../ops_history_v2/RandomDatasetV2.pbtxt | 66 + .../ResourceAccumulatorApplyGradient.pbtxt | 43 + .../ResourceAccumulatorTakeGradient.pbtxt | 43 + .../ops_history_v2/ResourceApplyAdaMax.pbtxt | 74 + .../ResourceApplyAdadelta.pbtxt | 66 + .../ops_history_v2/ResourceApplyAdagrad.pbtxt | 61 + .../ResourceApplyAdagradDA.pbtxt | 70 + .../ResourceApplyAdagradV2.pbtxt | 65 + .../ops_history_v2/ResourceApplyAdam.pbtxt | 85 + .../ResourceApplyAdamWithAmsgrad.pbtxt | 82 + .../ops_history_v2/ResourceApplyAddSign.pbtxt | 66 + .../ResourceApplyCenteredRMSProp.pbtxt | 74 + .../ops_history_v2/ResourceApplyFtrl.pbtxt | 77 + .../ops_history_v2/ResourceApplyFtrlV2.pbtxt | 81 + .../ResourceApplyGradientDescent.pbtxt | 50 + .../ResourceApplyKerasMomentum.pbtxt | 65 + .../ResourceApplyMomentum.pbtxt | 65 + .../ResourceApplyPowerSign.pbtxt | 66 + .../ResourceApplyProximalAdagrad.pbtxt | 62 + ...ResourceApplyProximalGradientDescent.pbtxt | 58 + .../ops_history_v2/ResourceApplyRMSProp.pbtxt | 70 + .../ResourceConditionalAccumulator.pbtxt | 66 + .../ops_history_v2/ResourceScatterAdd.pbtxt | 53 + .../ops_history_v2/ResourceScatterDiv.pbtxt | 53 + .../ops_history_v2/ResourceScatterMax.pbtxt | 53 + .../ops_history_v2/ResourceScatterMin.pbtxt | 53 + .../ops_history_v2/ResourceScatterMul.pbtxt | 53 + .../ops_history_v2/ResourceScatterSub.pbtxt | 53 + .../ResourceSparseApplyAdadelta.pbtxt | 80 + .../ResourceSparseApplyAdagrad.pbtxt | 75 + 
.../ResourceSparseApplyAdagradDA.pbtxt | 84 + .../ResourceSparseApplyAdagradV2.pbtxt | 79 + .../ResourceSparseApplyCenteredRMSProp.pbtxt | 88 + .../ResourceSparseApplyFtrl.pbtxt | 91 + .../ResourceSparseApplyFtrlV2.pbtxt | 95 + .../ResourceSparseApplyKerasMomentum.pbtxt | 79 + .../ResourceSparseApplyMomentum.pbtxt | 79 + .../ResourceSparseApplyProximalAdagrad.pbtxt | 76 + ...ceSparseApplyProximalGradientDescent.pbtxt | 72 + .../ResourceSparseApplyRMSProp.pbtxt | 84 + .../compat/ops_history_v2/ScatterAdd.pbtxt | 65 + .../compat/ops_history_v2/ScatterDiv.pbtxt | 65 + .../compat/ops_history_v2/ScatterMul.pbtxt | 65 + .../compat/ops_history_v2/ScatterNdAdd.pbtxt | 65 + .../compat/ops_history_v2/ScatterNdMax.pbtxt | 65 + .../compat/ops_history_v2/ScatterNdMin.pbtxt | 65 + .../ScatterNdNonAliasingAdd.pbtxt | 57 + .../compat/ops_history_v2/ScatterNdSub.pbtxt | 65 + .../compat/ops_history_v2/ScatterSub.pbtxt | 65 + .../compat/ops_history_v2/SegmentMean.pbtxt | 52 + .../compat/ops_history_v2/SegmentProd.pbtxt | 52 + .../compat/ops_history_v2/SegmentProdV2.pbtxt | 69 + .../compat/ops_history_v2/SegmentSum.pbtxt | 52 + .../compat/ops_history_v2/SegmentSumV2.pbtxt | 69 + .../ops/compat/ops_history_v2/Shape.pbtxt | 35 + .../SparseAccumulatorApplyGradient.pbtxt | 55 + .../SparseAccumulatorTakeGradient.pbtxt | 51 + .../ops/compat/ops_history_v2/SparseAdd.pbtxt | 90 + .../compat/ops_history_v2/SparseAddGrad.pbtxt | 54 + .../ops_history_v2/SparseApplyAdadelta.pbtxt | 87 + .../ops_history_v2/SparseApplyAdagrad.pbtxt | 81 + .../ops_history_v2/SparseApplyAdagradDA.pbtxt | 91 + .../ops_history_v2/SparseApplyAdagradV2.pbtxt | 85 + .../SparseApplyCenteredRMSProp.pbtxt | 96 + .../ops_history_v2/SparseApplyFtrl.pbtxt | 98 + .../ops_history_v2/SparseApplyFtrlV2.pbtxt | 102 + .../ops_history_v2/SparseApplyMomentum.pbtxt | 85 + .../SparseApplyProximalAdagrad.pbtxt | 82 + .../SparseApplyProximalGradientDescent.pbtxt | 77 + .../ops_history_v2/SparseApplyRMSProp.pbtxt | 91 + .../SparseConditionalAccumulator.pbtxt | 67 + .../ops_history_v2/SparseDenseCwiseAdd.pbtxt | 50 + .../ops_history_v2/SparseDenseCwiseDiv.pbtxt | 50 + .../ops_history_v2/SparseDenseCwiseMul.pbtxt | 50 + .../ops_history_v2/SparseReduceSum.pbtxt | 57 + .../SparseReduceSumSparse.pbtxt | 65 + .../ops_history_v2/SparseSliceGrad.pbtxt | 50 + .../ops_history_v2/SparseSparseMinimum.pbtxt | 62 + .../ops_history_v2/SparseTensorDenseAdd.pbtxt | 60 + .../ops_history_v2/SplitDedupData.pbtxt | 42 + .../core/ops/compat/ops_history_v2/Sum.pbtxt | 62 + .../compat/ops_history_v2/SyncDevice.pbtxt | 4 + .../TPUPartitionedInputV2.pbtxt | 33 + .../TPUPartitionedOutputV2.pbtxt | 26 + .../core/ops/compat/ops_history_v2/Tan.pbtxt | 25 + .../ops_history_v2/UniformQuantizedAdd.pbtxt | 93 + .../UniformQuantizedConvolution.pbtxt | 159 + .../ops_history_v2/UnsortedSegmentProd.pbtxt | 69 + .../ops_history_v2/UnsortedSegmentSum.pbtxt | 69 + .../ops/compat/ops_history_v2/Where.pbtxt | 42 + .../ops/compat/ops_history_v2/Xdivy.pbtxt | 29 + .../ops/compat/ops_history_v2/Xlog1py.pbtxt | 29 + .../ops/compat/ops_history_v2/Xlogy.pbtxt | 29 + tensorflow/core/ops/dataset_ops.cc | 49 - .../core/ops/experimental_dataset_ops.cc | 31 +- tensorflow/core/ops/image_ops.cc | 11 +- tensorflow/core/ops/lookup_ops.cc | 5 +- tensorflow/core/ops/math_ops.cc | 59 +- tensorflow/core/ops/mkl_nn_ops.cc | 20 + tensorflow/core/ops/nn_ops.cc | 68 +- tensorflow/core/ops/ops.pbtxt | 1193 +++- tensorflow/core/ops/optional_ops.cc | 74 + .../core/ops/random_index_shuffle_ops.cc | 88 +- 
tensorflow/core/ops/string_ops.cc | 2 +- tensorflow/core/ops/sync_ops.cc | 27 + tensorflow/core/ops/uniform_quant_ops.cc | 170 +- tensorflow/core/ops/uniform_quant_ops_test.cc | 112 +- tensorflow/core/platform/BUILD | 42 +- tensorflow/core/platform/build_config.bzl | 2 - tensorflow/core/platform/cloud/BUILD | 5 +- tensorflow/core/platform/cloud/testdata/BUILD | 1 + tensorflow/core/platform/cpu_feature_guard.cc | 23 +- .../core/platform/default/build_config/BUILD | 26 + tensorflow/core/platform/error_payloads.cc | 2 +- tensorflow/core/platform/float8.h | 26 + tensorflow/core/platform/profile_utils/BUILD | 1 + tensorflow/core/platform/stream_executor.h | 1 - .../core/platform/stream_executor_no_cuda.h | 1 - tensorflow/core/platform/testdata/BUILD | 1 + tensorflow/core/platform/types.h | 1 + tensorflow/core/profiler/BUILD | 30 +- tensorflow/core/profiler/backends/cpu/BUILD | 42 +- .../core/profiler/backends/cpu/host_tracer.cc | 118 - .../core/profiler/backends/cpu/host_tracer.h | 18 +- .../profiler/backends/cpu/metadata_utils.h | 26 +- .../profiler/backends/cpu/python_tracer.h | 18 +- tensorflow/core/profiler/backends/gpu/BUILD | 80 +- .../core/profiler/backends/gpu/cuda_test.h | 38 +- .../profiler/backends/gpu/cupti_collector.h | 249 +- .../backends/gpu/cupti_error_manager.h | 240 +- .../profiler/backends/gpu/cupti_interface.h | 173 +- .../core/profiler/backends/gpu/cupti_tracer.h | 125 +- .../backends/gpu/device_tracer_test.cc | 12 +- .../core/profiler/backends/gpu/mock_cupti.h | 136 +- .../core/profiler/backends/gpu/nvtx_utils.h | 29 +- .../core/profiler/backends/gpu/rocm_tracer.h | 380 +- tensorflow/core/profiler/builds/BUILD | 7 +- .../core/profiler/builds/build_config.bzl | 17 +- tensorflow/core/profiler/builds/oss/BUILD | 3 - tensorflow/core/profiler/convert/BUILD | 136 +- .../convert/hlo_proto_to_graph_view.cc | 49 +- .../convert/hlo_proto_to_graph_view.h | 3 +- .../convert/hlo_proto_to_graph_view_test.cc | 34 +- ...hlo_proto_to_memory_visualization_utils.cc | 1140 ++-- .../hlo_proto_to_memory_visualization_utils.h | 17 +- ...roto_to_memory_visualization_utils_test.cc | 15 +- .../profiler/convert/hlo_to_tools_data.cc | 10 +- .../core/profiler/convert/hlo_to_tools_data.h | 2 +- .../convert/multi_xplanes_to_op_stats.cc | 66 + .../convert/multi_xplanes_to_op_stats.h | 38 + .../profiler/convert/op_metrics_to_record.h | 38 +- .../profiler/convert/op_profile_builder.cc | 92 +- .../profiler/convert/op_profile_builder.h | 3 +- .../profiler/convert/op_stats_combiner.cc | 19 +- .../convert/op_stats_to_op_profile.cc | 10 +- .../core/profiler/convert/tool_options.h | 25 - .../profiler/convert/xplane_to_op_stats.cc | 84 +- .../profiler/convert/xplane_to_op_stats.h | 13 +- .../convert/xplane_to_op_stats_test.cc | 179 +- .../convert/xplane_to_step_events_test.cc | 18 +- .../convert/xplane_to_tf_data_stats.cc | 15 +- .../profiler/convert/xplane_to_tool_names.cc | 26 +- .../convert/xplane_to_tool_names_test.cc | 130 + .../profiler/convert/xplane_to_tools_data.cc | 8 +- tensorflow/core/profiler/internal/BUILD | 1 + .../core/profiler/internal/advisor/BUILD | 1 + .../core/profiler/internal/testdata/BUILD | 1 + tensorflow/core/profiler/lib/BUILD | 65 +- .../core/profiler/lib/connected_traceme.h | 91 +- .../profiler/lib/device_profiler_session.h | 11 +- .../core/profiler/lib/profiler_factory.h | 2 +- .../core/profiler/lib/profiler_session.h | 73 +- .../lib/scoped_memory_debug_annotation.h | 86 +- tensorflow/core/profiler/profile.proto | 69 +- .../core/profiler/profiler_analysis.proto | 80 +- 
.../core/profiler/profiler_options.proto | 87 +- .../core/profiler/profiler_service.proto | 120 +- .../profiler_service_monitor_result.proto | 38 +- tensorflow/core/profiler/protobuf/BUILD | 24 + .../protobuf/memory_viewer_preprocess.proto | 2 + .../core/profiler/protobuf/op_metrics.proto | 31 + .../core/profiler/protobuf/op_profile.proto | 16 +- .../core/profiler/protobuf/op_stats.proto | 12 +- .../core/profiler/protobuf/pod_viewer.proto | 12 +- .../core/profiler/protobuf/topology.proto | 18 + .../core/profiler/protobuf/trace_events.proto | 71 +- tensorflow/core/profiler/rpc/BUILD | 14 +- tensorflow/core/profiler/rpc/client/BUILD | 100 +- .../profiler/rpc/client/capture_profile.h | 23 +- .../profiler/rpc/client/profiler_client.h | 70 +- .../client/remote_profiler_session_manager.h | 49 +- .../core/profiler/rpc/client/save_profile.h | 34 +- tensorflow/core/profiler/rpc/oss/BUILD | 5 +- .../core/profiler/rpc/profiler_server.h | 14 +- .../core/profiler/rpc/profiler_service_impl.h | 5 +- tensorflow/core/profiler/tfprof_log.proto | 3 +- tensorflow/core/profiler/utils/BUILD | 62 +- .../core/profiler/utils/derived_timeline.cc | 40 +- .../core/profiler/utils/file_system_utils.h | 37 +- tensorflow/core/profiler/utils/format_utils.h | 45 +- .../core/profiler/utils/hlo_module_map.cc | 125 + .../core/profiler/utils/hlo_module_map.h | 145 + .../core/profiler/utils/hlo_proto_map.cc | 10 + .../core/profiler/utils/hlo_proto_map.h | 9 +- .../profiler/utils/hlo_proto_to_module.cc | 2 +- .../core/profiler/utils/hlo_proto_to_module.h | 2 +- .../core/profiler/utils/kernel_stats_utils.cc | 2 +- tensorflow/core/protobuf/BUILD | 11 +- tensorflow/core/protobuf/autotuning.proto | 103 +- tensorflow/core/protobuf/config.proto | 48 +- tensorflow/core/protobuf/fingerprint.proto | 4 +- tensorflow/core/protobuf/service_config.proto | 5 +- tensorflow/core/protobuf/snapshot.proto | 11 + tensorflow/core/protobuf/tpu/BUILD | 1 + .../tpu/optimization_parameters.proto | 5 +- tensorflow/core/public/BUILD | 1 + tensorflow/core/public/version.h | 4 +- tensorflow/core/runtime_fallback/BUILD | 2 +- .../core/runtime_fallback/conversion/BUILD | 1 + .../runtime_fallback/conversion/conversion.cc | 2 +- tensorflow/core/runtime_fallback/kernel/BUILD | 30 + .../runtime_fallback/kernel/gpurt_kernels.cc | 151 + .../kernel_fallback_compat_request_state.h | 16 +- .../kernel/kernel_fallback_execute.cc | 5 +- .../kernel/kernel_fallback_execute_compat.cc | 118 +- .../kernel/kernel_fallback_execute_compat.h | 4 +- .../runtime_fallback/kernel/tensor_util.cc | 88 +- .../runtime_fallback/kernel/tensor_util.h | 96 +- .../core/runtime_fallback/runtime/BUILD | 1 + .../runtime/conversion_function.cc | 4 +- .../runtime_fallback_batch_tf_opkernels.cc | 34 +- .../runtime/runtime_fallback_kernels.cc | 6 +- .../runtime/runtime_fallback_tensor.h | 3 +- tensorflow/core/runtime_fallback/util/BUILD | 1 + .../core/runtime_fallback/util/attr_util.cc | 4 +- .../core/runtime_fallback/util/tensor_util.h | 4 +- .../runtime_fallback/util/tensor_util_test.cc | 2 +- tensorflow/core/summary/BUILD | 1 + tensorflow/core/summary/summary_db_writer.cc | 2 +- tensorflow/core/tfrt/BUILD | 2 + tensorflow/core/tfrt/common/BUILD | 29 +- .../core/tfrt/common/async_value_tensor.h | 1 + tensorflow/core/tfrt/common/pjrt_state.cc | 2 - tensorflow/core/tfrt/common/pjrt_state.h | 2 - tensorflow/core/tfrt/common/pjrt_util.cc | 37 +- tensorflow/core/tfrt/common/pjrt_util.h | 17 +- tensorflow/core/tfrt/common/pjrt_util_test.cc | 24 +- tensorflow/core/tfrt/eager/BUILD | 52 +- 
tensorflow/core/tfrt/eager/backends/BUILD | 1 + tensorflow/core/tfrt/eager/backends/cpu/BUILD | 1 + tensorflow/core/tfrt/eager/backends/gpu/BUILD | 1 + tensorflow/core/tfrt/eager/c_api_tfrt.cc | 74 +- tensorflow/core/tfrt/eager/c_api_tfrt.h | 20 +- .../tfrt/eager/c_api_tfrt_distributed_impl.cc | 218 - .../eager/c_api_tfrt_distributed_interface.h | 37 - tensorflow/core/tfrt/eager/core_runtime/BUILD | 1 + .../tfrt/eager/cpp_tests/core_runtime/BUILD | 2 + .../core_runtime/op_handler_selector_test.cc | 3 +- .../core/tfrt/eager/function_cache_test.cc | 26 +- .../tfrt/eager/transform_graph_function.cc | 3 +- tensorflow/core/tfrt/fallback/BUILD | 8 +- .../core/tfrt/fallback/cost_recorder.cc | 29 +- tensorflow/core/tfrt/fallback/cost_recorder.h | 38 +- .../core/tfrt/fallback/cost_recorder_test.cc | 37 +- .../core/tfrt/fallback/fallback_state.cc | 4 + .../core/tfrt/fallback/fallback_state.h | 3 + .../core/tfrt/fallback/op_cost_map.proto | 2 +- .../core/tfrt/fallback/op_kernel_runner.cc | 17 +- .../core/tfrt/fallback/op_kernel_runner.h | 28 +- .../tfrt/fallback/op_kernel_runner_cache.cc | 14 +- .../tfrt/fallback/op_kernel_runner_cache.h | 1 - .../tfrt/fallback/op_kernel_runner_test.cc | 86 +- tensorflow/core/tfrt/graph_executor/BUILD | 29 +- .../graph_executor/graph_execution_options.cc | 19 +- .../graph_executor/graph_execution_options.h | 10 +- .../tfrt/graph_executor/graph_executor.cc | 359 +- .../core/tfrt/graph_executor/graph_executor.h | 122 +- .../graph_executor/graph_executor_test.cc | 178 +- .../synchronous_graph_executor.cc | 21 +- .../synchronous_graph_executor.h | 9 +- .../synchronous_graph_executor_test.cc | 32 +- tensorflow/core/tfrt/mla/BUILD | 1 + .../core/tfrt/run_handler_thread_pool/BUILD | 2 +- .../run_handler_thread_pool/run_handler.h | 5 +- .../run_handler_concurrent_work_queue.cc | 12 +- .../run_handler_concurrent_work_queue.h | 4 +- .../run_handler_concurrent_work_queue_test.cc | 7 +- tensorflow/core/tfrt/runtime/BUILD | 2 +- tensorflow/core/tfrt/runtime/runtime.h | 18 + .../tf_threadpool_concurrent_work_queue.cc | 12 +- .../tf_threadpool_concurrent_work_queue.h | 6 +- ...f_threadpool_concurrent_work_queue_test.cc | 6 +- .../core/tfrt/runtime/work_queue_interface.cc | 71 +- .../core/tfrt/runtime/work_queue_interface.h | 18 +- .../tfrt/runtime/work_queue_interface_test.cc | 11 +- tensorflow/core/tfrt/saved_model/BUILD | 6 +- .../core/tfrt/saved_model/saved_model.cc | 672 +- .../core/tfrt/saved_model/saved_model.h | 40 +- .../saved_model/saved_model_import_input.cc | 4 +- .../saved_model/saved_model_import_input.h | 2 +- .../tfrt/saved_model/saved_model_mira_impl.h | 4 +- .../tfrt/saved_model/saved_model_testutil.cc | 19 +- .../tfrt/saved_model/saved_model_testutil.h | 2 +- tensorflow/core/tfrt/saved_model/tests/BUILD | 2 +- .../saved_model/tests/saved_model_gpu_test.cc | 13 +- .../saved_model/tests/saved_model_test.cc | 320 +- tensorflow/core/tfrt/tpu/BUILD | 1 + tensorflow/core/tfrt/utils/BUILD | 21 + .../core/tfrt/utils/device_variables_table.h | 98 + .../core/tfrt/utils/gpu_variables_table.h | 42 + .../tfrt/utils/tfrt_graph_execution_state.cc | 2 +- .../tfrt/utils/tfrt_graph_execution_state.h | 3 + .../utils/tfrt_graph_execution_state_test.cc | 9 +- tensorflow/core/tpu/BUILD | 6 +- tensorflow/core/tpu/graph_rewrite/BUILD | 1 + .../distributed_tpu_rewrite_pass.cc | 47 +- .../encapsulate_tpu_computations_pass.cc | 4 +- ...ing_software_deduplication_rewrite_pass.cc | 1 - tensorflow/core/tpu/kernels/BUILD | 22 +- tensorflow/core/tpu/kernels/infeed_ops.cc | 4 +- 
.../core/tpu/kernels/sharding_util_ops.cc | 3 +- .../tpu_compilation_cache_interface.cc | 5 +- tensorflow/core/tpu/kernels/tpu_compile_op.cc | 2 +- .../core/tpu/kernels/tpu_compile_op_common.cc | 21 +- .../core/tpu/kernels/tpu_compile_op_common.h | 6 +- .../core/tpu/kernels/tpu_compile_op_impl.cc | 2 +- .../tpu/kernels/tpu_compile_op_support.cc | 4 +- .../core/tpu/kernels/tpu_compile_op_support.h | 28 +- .../core/tpu/kernels/tpu_configuration_ops.cc | 28 +- .../core/tpu/kernels/tpu_configuration_ops.h | 1 - .../tpu_embedding_configuration_ops.cc | 28 +- .../tpu_embedding_engine_state_interface.h | 12 +- .../tpu/kernels/tpu_embedding_enqueue_ops.cc | 11 +- .../tpu_embedding_load_retrieve_ops.cc | 8 +- .../core/tpu/kernels/tpu_embedding_ops.cc | 458 +- .../core/tpu/kernels/tpu_functional_ops.cc | 26 +- .../core/tpu/kernels/tpu_functional_ops.h | 2 +- .../core/tpu/kernels/tpu_handle_to_key_op.cc | 2 +- .../tpu/kernels/tpu_mesh_state_interface.h | 13 +- tensorflow/core/tpu/kernels/tpu_op_util.cc | 13 +- .../core/tpu/kernels/tpu_ordinal_selector.h | 16 +- tensorflow/core/tpu/kernels/tpu_pod_state.cc | 14 +- .../core/tpu/kernels/tpu_program_group.cc | 77 +- .../core/tpu/kernels/tpu_program_group.h | 2 +- tensorflow/core/tpu/kernels/xla/BUILD | 1 + tensorflow/core/tpu/kernels/xla/infeed_op.cc | 4 +- tensorflow/core/tpu/ops/BUILD | 4 +- tensorflow/core/tpu/ops/tpu_embedding_ops.cc | 96 + tensorflow/core/tpu/ops/tpu_embedding_ops.h | 6 + .../core/tpu/ops/tpu_partitioned_input_op.cc | 131 +- .../core/tpu/ops/tpu_partitioned_output_op.cc | 82 +- .../core/tpu/tpu_api_dlsym_initializer.cc | 1 + tensorflow/core/tpu/tpu_embedding_errors.cc | 3 +- tensorflow/core/tpu/tpu_embedding_errors.h | 3 +- ...embedding_optimization_parameters_utils.cc | 2 +- tensorflow/core/tpu/tpu_execute.cc | 34 +- .../core/tpu/tpu_model_server_initializer.cc | 2 +- tensorflow/core/transforms/BUILD | 1 + tensorflow/core/transforms/cf_sink/BUILD | 1 + .../core/transforms/consolidate_attrs/BUILD | 2 + .../core/transforms/consolidate_attrs/pass.cc | 1 + .../core/transforms/const_dedupe_hoist/BUILD | 1 + .../core/transforms/constant_folding/BUILD | 1 + .../core/transforms/constant_folding/pass.cc | 15 +- tensorflow/core/transforms/cse/BUILD | 1 + tensorflow/core/transforms/cse/tests/cse.mlir | 4 +- .../drop_unregistered_attribute/BUILD | 1 + .../eliminate_passthrough_iter_args/BUILD | 1 + .../core/transforms/func_to_graph/BUILD | 1 + .../func_to_graph/tests/round_trip.mlir | 8 +- .../transforms/functional_to_region/BUILD | 1 + .../transforms/functional_to_region/impl.cc | 14 +- .../core/transforms/graph_compactor/BUILD | 1 + .../core/transforms/graph_compactor/pass.cc | 5 +- .../graph_compactor/tests/rename.mlir | 2 +- .../graph_compactor/tests/rename_lots.mlir | 2 +- .../core/transforms/graph_to_func/BUILD | 1 + .../tests/same_feeds_and_fetches.mlir | 2 +- .../graph_to_func/tests/simple.mlir | 4 +- .../transforms/graph_transform_wrapper.cc | 2 +- tensorflow/core/transforms/legacy_call/BUILD | 1 + .../transforms/region_to_functional/BUILD | 1 + .../transforms/region_to_functional/impl.cc | 11 +- .../transforms/region_to_functional/pass.cc | 2 +- .../tests/idempotence_arg_count.mlir | 2 +- .../tests/idempotence_arg_reorder.mlir | 2 +- .../tests/sink_respecialize.mlir | 4 +- tensorflow/core/transforms/remapper/BUILD | 1 + tensorflow/core/transforms/remapper/pass.cc | 79 + .../remapper/tests/onednn_mish.mlir | 16 + .../core/transforms/shape_inference/BUILD | 2 + .../core/transforms/shape_inference/pass.cc | 10 +- 
tensorflow/core/transforms/toposort/BUILD | 1 + .../transforms/toposort/tests/toposort.mlir | 2 +- .../toposort/tests/toposort_regions.mlir | 2 +- tensorflow/core/user_ops/BUILD | 1 + tensorflow/core/util/BUILD | 37 +- tensorflow/core/util/autotune_maps/BUILD | 35 +- .../util/autotune_maps/autotune_maps_utils.cc | 102 - .../util/autotune_maps/autotune_maps_utils.h | 53 - .../util/autotune_maps/autotune_serialize.cc | 91 +- .../autotune_maps/autotune_serialize_test.cc | 17 +- .../util/autotune_maps/conv_parameters.cc | 47 +- .../core/util/autotune_maps/conv_parameters.h | 29 +- tensorflow/core/util/ctc/BUILD | 1 + tensorflow/core/util/cuda_solvers.cc | 12 + tensorflow/core/util/cuda_sparse.cc | 111 +- tensorflow/core/util/cuda_sparse.h | 190 +- .../core/util/example_proto_helper_test.cc | 41 + .../util/fake_clock_env.cc} | 28 +- tensorflow/core/util/fake_clock_env.h | 57 + tensorflow/core/util/fake_clock_env_test.cc | 65 + tensorflow/core/util/gpu_device_functions.h | 37 + tensorflow/core/util/gpu_kernel_helper.h | 1 + .../core/util/gpu_kernel_helper_test.cu.cc | 12 +- tensorflow/core/util/gpu_launch_config.h | 24 +- tensorflow/core/util/gpu_solvers.h | 14 +- tensorflow/core/util/image_resizer_state.h | 2 +- tensorflow/core/util/mkl_threadpool.h | 34 +- tensorflow/core/util/mkl_util.h | 1 - tensorflow/core/util/onednn_env_vars.cc | 11 + tensorflow/core/util/onednn_env_vars.h | 5 + tensorflow/core/util/port.cc | 66 +- tensorflow/core/util/port.h | 3 + tensorflow/core/util/proto/BUILD | 1 + tensorflow/core/util/quantization/BUILD | 1 + .../quantization/uniform_quant_ops_params.cc | 2 +- tensorflow/core/util/rocm_solvers.cc | 9 +- tensorflow/core/util/rocm_sparse.cc | 13 +- tensorflow/core/util/sparse/BUILD | 44 + tensorflow/core/util/sparse/dim_comparator.h | 1 + tensorflow/core/util/strided_slice_op_test.cc | 432 ++ tensorflow/core/util/tensor_bundle/BUILD | 1 + .../util/tensor_bundle/byte_swap_tensor.cc | 85 +- .../util/tensor_bundle/byte_swap_tensor.h | 12 +- .../core/util/tensor_bundle/testdata/BUILD | 1 + .../testdata/old_string_tensors/BUILD | 1 + tensorflow/core/util/tensor_format.h | 51 +- tensorflow/core/util/tensor_slice_reader.cc | 11 + tensorflow/core/util/zen_util.h | 51 + .../distribute/experimental/rpc/kernels/BUILD | 2 + .../experimental/rpc/kernels/oss/BUILD | 3 +- .../distribute/experimental/rpc/proto/BUILD | 1 + tensorflow/dtensor/BUILD | 3 + tensorflow/dtensor/build_defs.bzl | 173 + tensorflow/dtensor/cc/BUILD | 83 +- tensorflow/dtensor/cc/constants.h | 12 +- .../dtensor/cc/default_parallel_executor.cc | 5 +- tensorflow/dtensor/cc/dstatus.h | 4 +- tensorflow/dtensor/cc/dtensor_device.cc | 588 +- tensorflow/dtensor/cc/dtensor_device.h | 26 +- tensorflow/dtensor/cc/dtensor_device_util.cc | 354 +- tensorflow/dtensor/cc/dtensor_device_util.h | 544 +- .../dtensor/cc/dtensor_graph_to_mlir_pass.cc | 43 +- .../dtensor/cc/dtensor_graph_to_mlir_pass.h | 17 +- tensorflow/dtensor/cc/dtensor_meta_ops.cc | 2 +- tensorflow/dtensor/cc/dtensor_tpu_kernels.cc | 10 +- tensorflow/dtensor/cc/dtensor_utils.cc | 30 + tensorflow/dtensor/cc/dtensor_utils.h | 9 + tensorflow/dtensor/cc/mesh_type.h | 33 + tensorflow/dtensor/cc/parallel_executor.h | 29 +- tensorflow/dtensor/cc/save_restore_util.cc | 2 +- .../dtensor/cc/small_constant_optimization.cc | 19 +- .../dtensor/cc/small_constant_optimization.h | 11 +- tensorflow/dtensor/cc/tensor_layout.cc | 128 +- tensorflow/dtensor/cc/tensor_layout.h | 47 +- tensorflow/dtensor/cc/tensor_with_layout.h | 174 + .../cc/xla_spmd/layout_to_xla_sharding.cc | 289 + 
.../cc/xla_spmd/layout_to_xla_sharding.h | 36 + tensorflow/dtensor/mlir/BUILD | 8 +- tensorflow/dtensor/mlir/Passes.td | 122 + tensorflow/dtensor/mlir/collectives.cc | 6 +- .../dtensor/mlir/create_dtensor_mlir_passes.h | 10 + .../mlir/designate_resource_handle_mesh.cc | 13 - .../dtensor_allreduce_combine_optimization.cc | 18 +- .../dtensor_allreduce_scatter_optimization.cc | 16 +- .../dtensor_allreduce_sum_optimization.cc | 42 +- tensorflow/dtensor/mlir/dtensor_dialect/BUILD | 1 + .../dtensor_dialect/ir/dtensor_dialect.td | 3 +- .../mlir/dtensor_mixed_precision_reduce.cc | 20 +- .../dtensor/mlir/dtensor_mlir_passes.cc | 13 + .../mlir/dtensor_remove_dtensorlayout.cc | 48 + .../dtensor_replace_auxiliary_layout_op.cc | 48 + tensorflow/dtensor/mlir/dtensor_send_recv.cc | 104 +- tensorflow/dtensor/mlir/dtensor_send_recv.h | 16 +- .../dtensor/mlir/dtensor_set_hlo_sharding.cc | 141 + .../mlir/expansions/argmax_spmd_expander.cc | 18 +- .../mlir/expansions/bias_add_spmd_expander.cc | 6 +- .../expansions/broadcast_to_spmd_expander.cc | 24 +- .../expansions/control_flow_spmd_expander.cc | 8 +- .../mlir/expansions/conv_spmd_expander.cc | 418 +- .../mlir/expansions/cumsum_spmd_expander.cc | 4 +- .../expansions/dataparallel_spmd_expander.cc | 2 +- .../expansions/dtensor_op_spmd_expander.cc | 20 +- .../mlir/expansions/einsum_spmd_expander.cc | 10 +- .../expansions/expanddims_spmd_expander.cc | 13 +- .../mlir/expansions/fill_spmd_expander.cc | 23 +- .../mlir/expansions/gather_spmd_expander.cc | 396 +- .../mlir/expansions/gather_spmd_expander.h | 181 +- .../mlir/expansions/in_top_k_spmd_expander.cc | 8 +- .../mlir/expansions/iterator_spmd_expander.cc | 148 + .../mlir/expansions/iterator_spmd_expander.h | 54 + .../mlir/expansions/matmul_spmd_expander.cc | 16 +- .../mlir/expansions/meta_spmd_expander.cc | 96 +- .../mlir/expansions/nullary_spmd_expander.cc | 4 +- .../mlir/expansions/optional_spmd_expander.cc | 104 + .../mlir/expansions/optional_spmd_expander.h | 54 + .../expansions/random_op_spmd_expander.cc | 43 +- .../mlir/expansions/reduce_spmd_expander.cc | 10 +- .../mlir/expansions/resource_spmd_expander.cc | 80 +- .../expansions/save_restore_spmd_expander.cc | 28 +- .../mlir/expansions/scatter_spmd_expander.cc | 210 +- .../mlir/expansions/scatter_spmd_expander.h | 12 + .../expansions/segmentation_spmd_expander.cc | 16 +- .../mlir/expansions/slice_spmd_expander.cc | 160 +- .../mlir/expansions/softmax_spmd_expander.cc | 26 +- .../mlir/expansions/split_spmd_expander.cc | 12 +- .../tensorlist_reserve_spmd_expander.cc | 2 +- .../mlir/expansions/top_k_spmd_expander.cc | 6 +- .../unsupported_op_spmd_expander.cc | 44 + .../expansions/unsupported_op_spmd_expander.h | 46 + .../dtensor/mlir/group_assignment_test.cc | 1 - .../mlir/handle_cross_cluster_dependencies.cc | 16 +- .../dtensor/mlir/handle_sparsetensors.cc | 13 +- tensorflow/dtensor/mlir/ir/tf_dtensor.cc | 22 +- tensorflow/dtensor/mlir/ir/tf_dtensor.td | 4 +- tensorflow/dtensor/mlir/layout_parsing.cc | 49 +- tensorflow/dtensor/mlir/layout_parsing.h | 7 +- .../dtensor/mlir/layout_propagation_v2.cc | 99 +- tensorflow/dtensor/mlir/merge_clusters.cc | 22 +- tensorflow/dtensor/mlir/mesh_propagation.cc | 8 +- .../dtensor/mlir/move_compilation_to_host.cc | 10 +- .../dtensor/mlir/op_to_device_cluster.cc | 4 +- tensorflow/dtensor/mlir/op_utils.cc | 54 +- tensorflow/dtensor/mlir/op_utils.h | 26 +- .../dtensor/mlir/propagate_default_layout.cc | 8 +- .../propagate_device_id_to_function_args.cc | 16 +- .../dtensor/mlir/restore_shape_inference.cc | 20 +- 
tensorflow/dtensor/mlir/shape_utils.cc | 12 +- .../dynamic_enqueue_sparse_expander.cc | 12 +- .../matmul_sparse_expander.cc | 2 +- tensorflow/dtensor/mlir/spmd_expander.cc | 29 +- .../dtensor/mlir/spmd_expander_common.cc | 37 +- tensorflow/dtensor/mlir/spmd_expanders.cc | 40 + tensorflow/dtensor/mlir/spmd_expansion.cc | 22 +- tensorflow/dtensor/mlir/tests/BUILD | 60 + .../mlir/tests/annotate_global_shape.mlir | 70 + .../tests/cluster_function_conversion.mlir | 113 + .../dtensor/mlir/tests/constant_folding.mlir | 19 + .../dtensor/mlir/tests/cpu_layout.pbtxt | 18 + .../tests/designate_resource_handle_mesh.mlir | 48 + .../tests/device_mesh_cluster_coarsening.mlir | 167 + .../mlir/tests/dtensor_all_gather.mlir | 31 + .../mlir/tests/dtensor_all_scatter.mlir | 21 + ...tensor_allreduce_combine_optimization.mlir | 148 + .../tests/dtensor_allreduce_lowering.mlir | 58 + ...tensor_allreduce_scatter_optimization.mlir | 112 + .../dtensor_allreduce_sum_optimization.mlir | 293 + .../tests/dtensor_embedding_checkpoint.mlir | 18 + .../mlir/tests/dtensor_embedding_v2.mlir | 117 + .../tests/dtensor_mixed_precision_reduce.mlir | 111 + .../mlir/tests/dtensor_mlir_opt_main.cc | 45 + .../dtensor_reduce_scatter_lowering.mlir | 140 + .../tests/dtensor_remove_dtensorlayout.mlir | 24 + .../dtensor_replace_auxiliary_layout_op.mlir | 12 + .../mlir/tests/dtensor_set_hlo_sharding.mlir | 27 + .../dtensor_set_hlo_sharding_default.mlir | 17 + .../tests/dtensor_xla_spmd_integration.mlir | 26 + .../elide_identity_before_copy_to_mesh.mlir | 17 + .../mlir/tests/embedding_optimizer.mlir | 93 + .../dtensor/mlir/tests/function_renaming.mlir | 17 + .../handle_cross_cluster_dependencies.mlir | 246 + .../mlir/tests/handle_sparsetensors.mlir | 81 + .../mlir/tests/layout_propagation.mlir | 535 ++ .../mlir/tests/layout_propagation_v2.mlir | 1040 +++ .../dtensor/mlir/tests/lower_send_recv.mlir | 78 + .../dtensor/mlir/tests/merge_clusters.mlir | 556 ++ .../dtensor/mlir/tests/mesh_propagation.mlir | 547 ++ .../mlir/tests/move_compilation_to_host.mlir | 391 ++ .../mlir/tests/op_to_device_cluster.mlir | 72 + .../mlir/tests/propagate_default_layout.mlir | 77 + .../propagate_device_id_to_function.mlir | 35 + .../mlir/tests/restore_and_assign.mlir | 109 + .../mlir/tests/restore_shape_inference.mlir | 62 + .../mlir/tests/set_default_sharding.mlir | 30 + .../dtensor/mlir/tests/sparse_expansion.mlir | 57 + .../mlir/tests/spmd_batchparallel.mlir | 38 + .../dtensor/mlir/tests/spmd_concat.mlir | 190 + tensorflow/dtensor/mlir/tests/spmd_conv.mlir | 696 ++ .../dtensor/mlir/tests/spmd_dtensor_ops.mlir | 323 + .../dtensor/mlir/tests/spmd_einsum.mlir | 214 + .../dtensor/mlir/tests/spmd_embedding.mlir | 110 + .../dtensor/mlir/tests/spmd_expansion.mlir | 712 ++ tensorflow/dtensor/mlir/tests/spmd_fill.mlir | 53 + .../dtensor/mlir/tests/spmd_io_ops.mlir | 21 + .../dtensor/mlir/tests/spmd_iterator.mlir | 107 + .../dtensor/mlir/tests/spmd_matmul.mlir | 180 + .../dtensor/mlir/tests/spmd_metadata.mlir | 111 + .../dtensor/mlir/tests/spmd_random.mlir | 62 + .../dtensor/mlir/tests/spmd_reduction.mlir | 91 + .../dtensor/mlir/tests/spmd_save_restore.mlir | 189 + .../dtensor/mlir/tests/spmd_segment_sum.mlir | 29 + tensorflow/dtensor/mlir/tests/spmd_slice.mlir | 155 + .../dtensor/mlir/tests/spmd_softmax_loss.mlir | 134 + .../dtensor/mlir/tests/spmd_squeeze.mlir | 55 + tensorflow/dtensor/mlir/tests/spmd_tile.mlir | 62 + .../dtensor/mlir/tests/spmd_var_handle.mlir | 33 + .../dtensor/mlir/tests/tf_dtensor_ops.mlir | 109 + .../tpu_add_resource_device_attribute.mlir | 74 + 
.../dtensor/mlir/tests/tpu_integration.mlir | 36 + .../tests/undo_merge_const_across_mesh.mlir | 17 + .../mlir/tests/update_tpu_metadata.mlir | 83 + .../mlir/tpu_add_resource_device_attribute.cc | 4 +- tensorflow/dtensor/mlir/tpu_integration.cc | 2 +- .../mlir/undo_merge_const_across_mesh.cc | 2 +- tensorflow/dtensor/mlir/utils/BUILD | 14 +- .../dtensor/mlir/utils/collective_lowering.cc | 676 +- .../mlir/utils/dtensor_embedding_stub.cc | 2 +- .../mlir/utils/dtensor_embedding_v2_stub.cc | 14 +- .../utils/dtensor_mlir_passes_internal.cc | 3 +- .../dtensor/mlir/utils/update_tpu_metadata.cc | 50 +- tensorflow/dtensor/mlir/value_utils.cc | 16 +- tensorflow/dtensor/mlir/value_utils.h | 5 + tensorflow/dtensor/proto/BUILD | 1 + tensorflow/dtensor/proto/layout.proto | 3 + tensorflow/dtensor/python/BUILD | 21 +- tensorflow/dtensor/python/accelerator_util.py | 26 +- tensorflow/dtensor/python/api.py | 42 +- tensorflow/dtensor/python/config.py | 5 + tensorflow/dtensor/python/d_checkpoint.py | 18 +- tensorflow/dtensor/python/dtensor_device.py | 55 +- tensorflow/dtensor/python/layout.py | 59 +- tensorflow/dtensor/python/mesh_util.py | 52 +- tensorflow/dtensor/python/numpy_util.py | 115 + tensorflow/dtensor/python/tests/BUILD | 339 + .../dtensor/python/tests/collective_test.py | 448 ++ .../dtensor/python/tests/config_test.py | 92 + .../dtensor/python/tests/multi_client_test.py | 339 + .../dtensor/python/tests/numpy_util_test.py | 127 + tensorflow/dtensor/python/tests/spmd_test.py | 3508 ++++++++++ .../dtensor/python/tests/test_backend_name.py | 39 + .../dtensor/python/tests/test_backend_util.py | 67 + tensorflow/dtensor/python/tests/test_util.py | 384 ++ .../dtensor/python/tests/test_util_ops.py | 658 ++ tensorflow/dtensor/python/tpu_util.py | 32 +- tensorflow/dtensor/tests/BUILD | 15 + .../tests/layout_to_xla_sharding_test.cc | 377 + .../dtensor/tests/tensor_layout_test.cc | 68 +- tensorflow/examples/adding_an_op/BUILD | 9 +- .../custom_ops_doc/multiplex_1/README.md | 402 ++ .../multiplex_1/multiplex_1_test.py | 4 +- .../multiplex_2/multiplex_2_test.py | 4 +- .../custom_ops_doc/simple_hash_table/BUILD | 3 +- tensorflow/examples/label_image/BUILD | 1 + tensorflow/examples/multibox_detector/BUILD | 5 +- tensorflow/examples/speech_commands/BUILD | 1 + .../examples/speech_commands/freeze_test.py | 4 +- tensorflow/examples/wav_to_spectrogram/BUILD | 5 +- tensorflow/go/core/framework/BUILD | 1 + tensorflow/go/core/protobuf/BUILD | 1 + .../go/example_inception_inference_test.go | 96 +- tensorflow/go/genop/BUILD | 1 + tensorflow/go/genop/internal/BUILD | 1 + tensorflow/go/graph_test.go | 7 + tensorflow/go/op/BUILD | 1 + tensorflow/go/op/wrappers.go | 1027 ++- tensorflow/go/saved_model_test.go | 81 +- tensorflow/go/stream_executor/BUILD | 1 + tensorflow/go/tensor.go | 48 +- tensorflow/go/tensor_handle_test.go | 9 +- .../go/testdata/label_image/grace_hopper.jpg | Bin 0 -> 73746 bytes .../half_plus_two/00000123/assets/foo.txt | 1 + .../half_plus_two/00000123/saved_model.pb | Bin 0 -> 9935 bytes .../variables/variables.data-00000-of-00001 | Bin 0 -> 12 bytes .../00000123/variables/variables.index | Bin 0 -> 151 bytes tensorflow/go/tsl/profiler/protobuf/BUILD | 1 + tensorflow/go/tsl/protobuf/BUILD | 1 + tensorflow/java/BUILD | 1 + tensorflow/java/LEGACY.md | 88 + tensorflow/java/README.md | 95 +- tensorflow/java/build_defs.bzl | 245 +- .../main/java/org/tensorflow/examples/BUILD | 1 + .../java/src/main/native/exception_jni.h | 4 +- tensorflow/js/BUILD | 1 + tensorflow/lite/BUILD | 235 +- 
tensorflow/lite/CMakeLists.txt | 34 +- tensorflow/lite/allocation.h | 1 + tensorflow/lite/arena_planner.cc | 145 +- tensorflow/lite/arena_planner.h | 15 +- tensorflow/lite/arena_planner_test.cc | 132 +- tensorflow/lite/build_def.bzl | 30 +- tensorflow/lite/builtin_op_data.h | 2 +- tensorflow/lite/c/BUILD | 117 +- tensorflow/lite/c/CMakeLists.txt | 18 +- tensorflow/lite/c/README.md | 3 +- tensorflow/lite/c/builtin_op_data.h | 509 +- tensorflow/lite/c/c_api_experimental.h | 389 +- tensorflow/lite/c/c_api_internal.h | 10 +- tensorflow/lite/c/c_api_opaque.cc | 238 +- tensorflow/lite/c/c_api_opaque.h | 307 +- tensorflow/lite/c/c_api_opaque_internal.cc | 2 +- tensorflow/lite/c/c_api_opaque_internal.h | 2 +- .../lite/c/c_api_signature_runner_test.cc | 2 +- tensorflow/lite/c/c_api_types.h | 127 +- tensorflow/lite/c/c_test.c | 14 +- tensorflow/lite/c/common.h | 1089 +-- tensorflow/lite/c/common_internal.cc | 16 +- tensorflow/lite/c/common_internal.h | 4 +- .../lite/cmake/DownloadPThreadPool.cmake | 30 + tensorflow/lite/context.h | 2 +- tensorflow/lite/context_util.h | 3 +- tensorflow/lite/core/BUILD | 121 +- tensorflow/lite/core/api/BUILD | 9 +- tensorflow/lite/core/api/error_reporter.h | 15 +- .../lite/core/api/flatbuffer_conversions.cc | 9 +- .../lite/core/api/flatbuffer_conversions.h | 2 +- .../core/api/flatbuffer_conversions_test.cc | 2 +- tensorflow/lite/core/api/op_resolver.cc | 2 +- tensorflow/lite/core/api/op_resolver.h | 17 +- .../lite/core/api/op_resolver_internal.h | 1 + .../core/api/op_resolver_internal_test.cc | 4 +- tensorflow/lite/core/api/profiler.h | 29 +- tensorflow/lite/core/api/tensor_utils.cc | 2 +- tensorflow/lite/core/api/tensor_utils.h | 2 +- tensorflow/lite/core/api/verifier.h | 1 + tensorflow/lite/core/async/BUILD | 152 + tensorflow/lite/core/async/README.md | 16 + .../lite/core/async/async_kernel_internal.h | 139 + .../lite/core/async/async_signature_runner.cc | 159 + .../lite/core/async/async_signature_runner.h | 178 + .../core/async/async_signature_runner_test.cc | 115 + tensorflow/lite/core/async/async_subgraph.cc | 186 + tensorflow/lite/core/async/async_subgraph.h | 174 + .../lite/core/async/async_subgraph_test.cc | 163 + .../async/backend_async_kernel_interface.cc | 163 + .../async/backend_async_kernel_interface.h | 202 + .../backend_async_kernel_interface_test.cc | 61 + tensorflow/lite/core/async/c/BUILD | 47 + tensorflow/lite/core/async/c/task.cc | 100 + tensorflow/lite/core/async/c/task.h | 136 + tensorflow/lite/core/async/c/task_test.cc | 116 + tensorflow/lite/core/async/c/types.h | 33 + .../core/async/common.h} | 21 +- tensorflow/lite/core/async/interop/BUILD | 70 + .../lite/core/async/interop/attribute_keys.h | 51 + .../async/interop/attribute_map_internal.cc | 48 + .../async/interop/attribute_map_internal.h | 113 + .../interop/attribute_map_internal_test.cc | 125 + tensorflow/lite/core/async/interop/c/BUILD | 61 + .../core/async/interop/c/attribute_map.cc | 78 + .../lite/core/async/interop/c/attribute_map.h | 94 + .../async/interop/c/attribute_map_test.cc | 96 + .../lite/core/async/interop/c/constants.cc | 21 + .../lite/core/async/interop/c/constants.h | 45 + tensorflow/lite/core/async/interop/c/types.cc | 55 + tensorflow/lite/core/async/interop/c/types.h | 95 + .../lite/core/async/interop/c/types_test.cc | 54 + .../lite/core/async/interop/reconcile_fns.cc | 187 + .../lite/core/async/interop/reconcile_fns.h | 49 + .../core/async/interop/reconcile_fns_test.cc | 297 + tensorflow/lite/core/async/interop/variant.cc | 47 + 
tensorflow/lite/core/async/interop/variant.h | 117 + .../lite/core/async/interop/variant_test.cc | 123 + tensorflow/lite/core/async/task_internal.cc | 107 + tensorflow/lite/core/async/task_internal.h | 151 + .../lite/core/async/task_internal_test.cc | 102 + tensorflow/lite/core/async/testing/BUILD | 34 + .../core/async/testing/mock_async_kernel.h | 67 + .../lite/core/async/testing/test_backend.cc | 91 + .../lite/core/async/testing/test_backend.h | 63 + tensorflow/lite/core/c/BUILD | 276 +- tensorflow/lite/core/c/builtin_op_data.h | 537 ++ .../lite/{ => core}/c/builtin_op_data_test.cc | 3 +- tensorflow/lite/core/c/c_api.cc | 65 +- tensorflow/lite/core/c/c_api.h | 558 +- .../lite/{ => core}/c/c_api_experimental.cc | 11 +- tensorflow/lite/core/c/c_api_experimental.h | 419 ++ .../{ => core}/c/c_api_experimental_test.cc | 14 +- tensorflow/lite/core/c/c_api_test.cc | 435 +- tensorflow/lite/core/c/c_api_types.h | 168 + tensorflow/lite/{ => core}/c/common.cc | 63 +- tensorflow/lite/core/c/common.h | 1156 ++++ tensorflow/lite/{ => core}/c/common_test.cc | 202 +- tensorflow/lite/core/c/special_rules.bzl | 10 +- .../acceleration/configuration/BUILD | 33 +- .../acceleration/configuration/c/BUILD | 128 +- .../configuration/c/delegate_plugin.h | 105 + .../configuration/c/gpu_plugin.cc | 4 +- .../acceleration/configuration/c/gpu_plugin.h | 50 + .../configuration/c/gpu_plugin_test.cc | 4 +- .../configuration/c/nnapi_plugin.cc | 4 +- .../configuration/c/nnapi_plugin.h | 50 + .../configuration/c/nnapi_plugin_test.cc | 4 +- .../{vendor_delegate.h => stable_delegate.h} | 24 +- .../configuration/c/xnnpack_plugin.cc | 4 +- .../configuration/c/xnnpack_plugin.h | 50 + .../configuration/c/xnnpack_plugin_test.cc | 4 +- .../configuration/delegate_registry.h | 2 +- .../configuration/stable_delegate_registry.cc | 58 + .../configuration/stable_delegate_registry.h | 57 + .../stable_delegate_registry_test.cc | 51 + tensorflow/lite/core/interpreter.cc | 28 +- tensorflow/lite/core/interpreter.h | 203 +- tensorflow/lite/core/interpreter_builder.cc | 73 +- tensorflow/lite/core/interpreter_builder.h | 39 +- .../lite/core/interpreter_experimental.cc | 36 +- tensorflow/lite/core/kernels/BUILD | 20 +- .../lite/core/kernels/builtin_op_kernels.h | 2 +- tensorflow/lite/core/kernels/register.cc | 35 +- tensorflow/lite/core/kernels/register.h | 6 + tensorflow/lite/core/model.h | 1 + tensorflow/lite/core/model_builder.cc | 20 +- tensorflow/lite/core/model_builder.h | 14 +- tensorflow/lite/core/model_test.cc | 10 +- tensorflow/lite/core/shims/BUILD | 113 +- .../lite/core/shims/c/builtin_op_data.h | 2 +- .../lite/core/shims/c/c_api_experimental.h | 2 +- .../core/shims/c/c_api_opaque.h} | 9 +- tensorflow/lite/core/shims/c/c_api_types.h | 2 +- tensorflow/lite/core/shims/c/common.h | 2 +- .../configuration/delegate_plugin.h | 16 +- .../acceleration/configuration/gpu_plugin.h | 2 +- .../acceleration/configuration/nnapi_plugin.h | 2 +- .../configuration/xnnpack_plugin.h | 2 +- .../lite/core/shims/cc/create_op_resolver.h | 2 +- .../configuration/delegate_registry.h | 2 +- tensorflow/lite/core/shims/cc/interpreter.h | 2 +- .../lite/core/shims/cc/interpreter_builder.h | 4 +- .../shims/cc/kernels/builtin_op_kernels.h | 2 +- .../lite/core/shims/cc/kernels/register.h | 2 +- tensorflow/lite/core/shims/cc/model_builder.h | 4 +- .../lite/core/shims/cc/tools/verifier.h | 2 +- .../core/shims/cc/tools/verifier_internal.h | 2 +- tensorflow/lite/core/special_rules.bzl | 27 + tensorflow/lite/core/subgraph.cc | 157 +- tensorflow/lite/core/subgraph.h | 44 +- 
tensorflow/lite/core/tools/BUILD | 12 +- tensorflow/lite/core/tools/verifier.h | 2 +- .../lite/core/tools/verifier_internal_test.cc | 2 +- tensorflow/lite/core/tools/verifier_test.cc | 2 +- tensorflow/lite/create_op_resolver.h | 12 +- .../create_op_resolver_with_builtin_ops.cc | 4 +- .../create_op_resolver_with_selected_ops.cc | 2 +- tensorflow/lite/delegates/BUILD | 65 +- tensorflow/lite/delegates/coreml/BUILD | 16 +- tensorflow/lite/delegates/coreml/README.md | 24 +- .../lite/delegates/coreml/builders/BUILD | 8 + .../builders/activation_layer_builder.cc | 2 +- .../builders/activation_layer_builder.h | 2 +- .../coreml/builders/add_op_builder.cc | 7 +- .../builders/concatenation_op_builder.cc | 2 +- .../coreml/builders/convolution_op_builder.h | 2 +- .../coreml/builders/dummy_op_builder.h | 2 +- .../builders/fully_connected_op_builder.cc | 2 +- .../coreml/builders/hardswish_op_builder.cc | 2 +- .../coreml/builders/mul_op_builder.cc | 6 +- .../delegates/coreml/builders/op_builder.cc | 2 +- .../delegates/coreml/builders/op_factory.h | 2 +- .../delegates/coreml/builders/op_validator.h | 2 +- .../coreml/builders/pad_op_builder.cc | 2 +- .../coreml/builders/pooling_layer_builder.cc | 2 +- .../coreml/builders/reshape_op_builder.cc | 2 +- .../builders/resize_bilinear_op_builder.cc | 2 +- .../lite/delegates/coreml/builders/util.cc | 23 +- .../lite/delegates/coreml/builders/util.h | 4 + .../delegates/coreml/builders/util_test.cc | 32 + .../lite/delegates/coreml/coreml_delegate.h | 10 + .../lite/delegates/coreml/coreml_delegate.mm | 15 +- .../coreml/coreml_delegate_kernel.mm | 12 +- tensorflow/lite/delegates/delegate_test.cc | 153 +- .../lite/delegates/delegate_test_util.cc | 6 +- .../lite/delegates/delegate_test_util.h | 15 +- tensorflow/lite/delegates/external/BUILD | 11 + .../delegates/external/external_delegate.cc | 26 +- .../external/external_delegate_interface.h | 79 + tensorflow/lite/delegates/flex/BUILD | 28 +- .../delegates/flex/allowlisted_flex_ops.cc | 1 + .../lite/delegates/flex/buffer_map_test.cc | 2 +- .../lite/delegates/flex/buffer_map_util.h | 5 + tensorflow/lite/delegates/flex/build_def.bzl | 10 +- tensorflow/lite/delegates/flex/delegate.cc | 4 + .../lite/delegates/flex/delegate_symbol.cc | 2 +- .../main/java/org/tensorflow/lite/flex/BUILD | 2 + .../delegates/flex/java/src/main/native/BUILD | 5 +- tensorflow/lite/delegates/flex/kernel.cc | 4 +- tensorflow/lite/delegates/flex/test/BUILD | 1 + .../delegates/flex/tflite_subgraph_execute.cc | 203 +- tensorflow/lite/delegates/flex/util_test.cc | 6 +- tensorflow/lite/delegates/gpu/BUILD | 18 +- tensorflow/lite/delegates/gpu/api.h | 5 + tensorflow/lite/delegates/gpu/build_defs.bzl | 14 +- tensorflow/lite/delegates/gpu/cl/BUILD | 17 +- tensorflow/lite/delegates/gpu/cl/api.cc | 2 + .../lite/delegates/gpu/cl/cl_arguments.cc | 6 + .../lite/delegates/gpu/cl/cl_arguments.h | 19 + .../lite/delegates/gpu/cl/cl_operation.h | 4 + .../lite/delegates/gpu/cl/default/BUILD | 1 + .../delegates/gpu/cl/inference_context.cc | 27 + .../lite/delegates/gpu/cl/kernels/BUILD | 1 + .../delegates/gpu/cl/kernels/converter.cc | 13 +- .../gpu/cl/kernels/elementwise_test.cc | 195 +- .../gpu/cl/kernels/lstm_full_test.cc | 2 +- .../delegates/gpu/cl/kernels/winograd_test.cc | 5 + .../lite/delegates/gpu/cl/opencl_wrapper.cc | 10 + .../lite/delegates/gpu/cl/testing/BUILD | 5 +- .../gpu/cl/testing/memory_sharing_sample.cc | 4 +- .../gpu/cl/testing/performance_profiling.cc | 6 +- tensorflow/lite/delegates/gpu/common/BUILD | 19 +- .../lite/delegates/gpu/common/default/BUILD | 
1 + .../lite/delegates/gpu/common/google/BUILD | 1 + .../lite/delegates/gpu/common/gpu_info.cc | 8 + .../lite/delegates/gpu/common/gpu_info.h | 3 + .../gpu/common/gpu_model_test_util.cc | 102 + .../gpu/common/gpu_model_test_util.h | 9 + .../lite/delegates/gpu/common/lstm_parser.cc | 2 +- .../lite/delegates/gpu/common/lstm_parser.h | 2 +- .../delegates/gpu/common/model_builder.cc | 77 +- .../lite/delegates/gpu/common/model_builder.h | 2 +- .../gpu/common/model_builder_helper.cc | 2 +- .../gpu/common/model_builder_helper.h | 2 +- .../lite/delegates/gpu/common/operations.cc | 9 +- .../lite/delegates/gpu/common/operations.h | 29 +- .../lite/delegates/gpu/common/selectors/BUILD | 1 + .../gpu/common/selectors/default/BUILD | 1 + .../common/selectors/operation_selector.cc | 46 +- .../gpu/common/selectors/special_selector.cc | 3 +- .../lite/delegates/gpu/common/task/BUILD | 7 +- .../delegates/gpu/common/task/buffer_desc.cc | 3 +- .../delegates/gpu/common/task/buffer_desc.h | 3 +- .../gpu/common/task/gpu_object_desc.h | 25 +- .../gpu/common/task/gpu_operation.cc | 5 +- .../gpu/common/task/qcom_thin_filter_desc.cc | 4 +- .../gpu/common/task/qcom_thin_filter_desc.h | 3 +- .../delegates/gpu/common/task/tensor_desc.cc | 79 +- .../delegates/gpu/common/task/tensor_desc.h | 80 +- .../lite/delegates/gpu/common/tasks/BUILD | 1 + .../gpu/common/tasks/conv_generic.cc | 11 +- .../delegates/gpu/common/tasks/elementwise.cc | 92 +- .../delegates/gpu/common/tasks/elementwise.h | 21 +- .../gpu/common/tasks/elementwise_test_util.cc | 90 + .../gpu/common/tasks/elementwise_test_util.h | 2 + .../lite/delegates/gpu/common/tasks/reduce.cc | 48 +- .../gpu/common/tasks/reduce_test_util.cc | 55 +- .../delegates/gpu/common/tasks/special/BUILD | 1 + .../delegates/gpu/common/tasks/winograd.cc | 116 +- .../gpu/common/tasks/winograd_test_util.cc | 69 + .../gpu/common/tasks/winograd_test_util.h | 1 + .../lite/delegates/gpu/common/testing/BUILD | 5 +- .../gpu/common/testing/feature_parity/BUILD | 5 +- .../testing/feature_parity/generators/BUILD | 1 + .../common/testing/feature_parity/utils.cc | 4 +- .../gpu/common/testing/feature_parity/utils.h | 2 +- .../gpu/common/testing/interpreter_utils.cc | 4 +- .../gpu/common/transformations/BUILD | 1 + tensorflow/lite/delegates/gpu/delegate.cc | 35 + .../lite/delegates/gpu/delegate_options.h | 5 + tensorflow/lite/delegates/gpu/gl/BUILD | 1 + .../lite/delegates/gpu/gl/compiler/BUILD | 1 + .../lite/delegates/gpu/gl/converters/BUILD | 1 + .../lite/delegates/gpu/gl/kernels/BUILD | 1 + .../lite/delegates/gpu/gl/kernels/resize.cc | 3 +- .../lite/delegates/gpu/gl/kernels/slice.cc | 52 +- .../lite/delegates/gpu/gl/runtime/BUILD | 1 + .../lite/delegates/gpu/gl/workgroups/BUILD | 1 + .../main/java/org/tensorflow/lite/gpu/BUILD | 1 + .../org/tensorflow/lite/gpu/GpuDelegate.java | 6 +- .../lite/gpu/GpuDelegateFactory.java | 33 + .../delegates/gpu/java/src/main/native/BUILD | 2 + .../java/src/main/native/gpu_delegate_jni.cc | 15 +- tensorflow/lite/delegates/gpu/metal/BUILD | 1 + .../delegates/gpu/metal/benchmarking/BUILD | 5 +- .../delegates/gpu/metal/benchmarking/main.mm | 10 +- .../lite/delegates/gpu/metal/kernels/BUILD | 2 + .../gpu/metal/kernels/elementwise_test.mm | 10 + tensorflow/lite/delegates/hexagon/BUILD | 5 + .../lite/delegates/hexagon/builders/BUILD | 7 +- .../hexagon/builders/activation_builder.cc | 2 +- .../hexagon/builders/arithmetic_builder.cc | 2 +- .../hexagon/builders/cast_builder.cc | 2 +- .../hexagon/builders/concat_builder.cc | 2 +- .../hexagon/builders/conv_2d_builder.cc | 2 +- 
.../hexagon/builders/conv_2d_helpers.cc | 2 +- .../hexagon/builders/hardswish_builder.cc | 2 +- .../builders/l2_normalization_builder.cc | 2 +- .../hexagon/builders/matmul_builder.cc | 2 +- .../hexagon/builders/mirror_pad_builder.cc | 2 +- .../delegates/hexagon/builders/op_builder.h | 2 +- .../hexagon/builders/pack_builder.cc | 2 +- .../delegates/hexagon/builders/pad_builder.cc | 2 +- .../hexagon/builders/pool_2d_builder.cc | 2 +- .../hexagon/builders/quantize_builder.cc | 2 +- .../hexagon/builders/reduce_builder.cc | 2 +- .../hexagon/builders/reshape_builder.cc | 2 +- .../resize_nearest_neighbor_builder.cc | 2 +- .../hexagon/builders/softmax_builder.cc | 2 +- .../builders/space_to_depth_builder.cc | 2 +- .../hexagon/builders/split_builder.cc | 2 +- .../delegates/hexagon/builders/tests/BUILD | 7 +- .../tests/hexagon_delegate_op_model.h | 4 +- .../builders/transpose_conv_2d_builder.cc | 2 +- .../hexagon/hexagon_delegate_kernel.cc | 2 +- .../hexagon/hexagon_delegate_kernel.h | 2 +- tensorflow/lite/delegates/hexagon/java/BUILD | 1 + .../src/main/java/org/tensorflow/lite/BUILD | 2 + .../hexagon/java/src/main/native/BUILD | 5 +- tensorflow/lite/delegates/hexagon/utils.cc | 2 +- tensorflow/lite/delegates/nnapi/BUILD | 20 +- .../main/java/org/tensorflow/lite/nnapi/BUILD | 1 + .../nnapi/java/src/main/native/BUILD | 1 + .../lite/delegates/nnapi/nnapi_delegate.cc | 71 +- .../lite/delegates/nnapi/nnapi_delegate.h | 12 + .../nnapi_delegate_device_selection_test.cc | 2 +- .../nnapi/nnapi_delegate_errno_test.cc | 2 +- ...pi_delegate_nnapi_failure_handling_test.cc | 2 +- ...nnapi_delegate_signed_quantization_test.cc | 2 +- .../delegates/nnapi/nnapi_delegate_test.cc | 2 +- .../lite/delegates/opaque_delegate_test.cc | 306 + tensorflow/lite/delegates/utils.cc | 19 +- tensorflow/lite/delegates/utils.h | 29 +- tensorflow/lite/delegates/utils/BUILD | 36 +- .../lite/delegates/utils/dummy_delegate/BUILD | 16 + .../utils/dummy_delegate/dummy_delegate.cc | 1 + .../utils/dummy_delegate/dummy_delegate.h | 1 + .../external_delegate_adaptor.cc | 30 +- .../dummy_delegate/external_delegate_test.sh | 35 + .../experimental/sample_stable_delegate/BUILD | 125 + .../sample_stable_delegate/README.md | 160 + .../sample_stable_delegate.cc | 210 + .../sample_stable_delegate.h | 65 + .../sample_stable_delegate_external.cc | 60 + .../sample_stable_delegate_external_test.cc | 117 + .../sample_stable_delegate_test.cc | 91 + .../utils/experimental/stable_delegate/BUILD | 126 + .../stable_delegate/delegate_loader.cc | 77 + .../stable_delegate/delegate_loader.h | 49 + .../stable_delegate/delegate_loader_test.cc | 76 + .../stable_delegate_interface.h | 59 + .../stable_xnnpack_delegate.cc | 33 + .../test_xnnpack_settings.json | 11 + .../tflite_settings_json_parser.cc | 76 + .../tflite_settings_json_parser.h | 70 + .../tflite_settings_json_parser_test.cc | 71 + .../stable_delegate/version_script.lds | 28 + .../delegates/utils/simple_delegate_test.cc | 2 +- .../delegates/utils/simple_opaque_delegate.cc | 30 +- .../delegates/utils/simple_opaque_delegate.h | 29 +- .../utils/simple_opaque_delegate_test.cc | 241 +- tensorflow/lite/delegates/xnnpack/BUILD | 217 +- tensorflow/lite/delegates/xnnpack/README.md | 204 +- .../xnnpack/binary_elementwise_tester.cc | 4 +- .../xnnpack/binary_elementwise_tester.h | 2 +- .../delegates/xnnpack/concatenation_tester.cc | 4 +- .../delegates/xnnpack/concatenation_tester.h | 2 +- .../lite/delegates/xnnpack/conv_2d_tester.cc | 4 +- .../lite/delegates/xnnpack/conv_2d_tester.h | 2 +- 
.../xnnpack/depth_to_space_tester.cc | 4 +- .../delegates/xnnpack/depth_to_space_tester.h | 2 +- .../xnnpack/depthwise_conv_2d_tester.cc | 4 +- .../xnnpack/depthwise_conv_2d_tester.h | 2 +- .../delegates/xnnpack/dequantize_tester.cc | 4 +- .../delegates/xnnpack/dequantize_tester.h | 2 +- .../xnnpack/fully_connected_tester.cc | 4 +- .../xnnpack/fully_connected_tester.h | 2 +- .../delegates/xnnpack/leaky_relu_tester.cc | 4 +- .../delegates/xnnpack/leaky_relu_tester.h | 2 +- .../lite/delegates/xnnpack/pad_tester.cc | 4 +- .../lite/delegates/xnnpack/pad_tester.h | 2 +- .../lite/delegates/xnnpack/pool_2d_tester.cc | 4 +- .../lite/delegates/xnnpack/pool_2d_tester.h | 2 +- .../lite/delegates/xnnpack/prelu_tester.cc | 4 +- .../lite/delegates/xnnpack/prelu_tester.h | 2 +- .../lite/delegates/xnnpack/quantize_tester.cc | 4 +- .../lite/delegates/xnnpack/quantize_tester.h | 2 +- .../quantized_binary_elementwise_tester.cc | 4 +- .../quantized_binary_elementwise_tester.h | 2 +- .../xnnpack/quantized_conv_2d_tester.cc | 4 +- .../xnnpack/quantized_conv_2d_tester.h | 2 +- .../quantized_depthwise_conv_2d_tester.cc | 4 +- .../quantized_depthwise_conv_2d_tester.h | 2 +- .../quantized_fully_connected_tester.cc | 4 +- .../quantized_fully_connected_tester.h | 2 +- .../xnnpack/quantized_leaky_relu_tester.cc | 4 +- .../xnnpack/quantized_leaky_relu_tester.h | 2 +- .../delegates/xnnpack/quantized_pad_tester.cc | 4 +- .../delegates/xnnpack/quantized_pad_tester.h | 2 +- .../xnnpack/quantized_pool_2d_tester.cc | 4 +- .../xnnpack/quantized_pool_2d_tester.h | 2 +- .../xnnpack/quantized_reduce_tester.cc | 4 +- .../xnnpack/quantized_reduce_tester.h | 2 +- .../quantized_resize_bilinear_tester.cc | 4 +- .../quantized_resize_bilinear_tester.h | 2 +- .../quantized_transpose_conv_tester.cc | 6 +- .../xnnpack/quantized_transpose_conv_tester.h | 5 +- .../quantized_unary_elementwise_tester.cc | 4 +- .../quantized_unary_elementwise_tester.h | 2 +- .../lite/delegates/xnnpack/reduce_tester.cc | 4 +- .../lite/delegates/xnnpack/reduce_tester.h | 2 +- .../lite/delegates/xnnpack/reshape_test.cc | 20 +- .../lite/delegates/xnnpack/reshape_tester.cc | 117 +- .../lite/delegates/xnnpack/reshape_tester.h | 12 +- .../xnnpack/resize_bilinear_tester.cc | 4 +- .../xnnpack/resize_bilinear_tester.h | 2 +- .../xnnpack/signed_quantized_reshape_test.cc | 226 + .../lite/delegates/xnnpack/slice_tester.cc | 2 +- .../lite/delegates/xnnpack/slice_tester.h | 2 +- .../lite/delegates/xnnpack/softmax_tester.cc | 4 +- .../lite/delegates/xnnpack/softmax_tester.h | 2 +- .../xnnpack/space_to_depth_tester.cc | 2 +- .../delegates/xnnpack/space_to_depth_tester.h | 2 +- .../lite/delegates/xnnpack/split_tester.cc | 4 +- .../lite/delegates/xnnpack/split_tester.h | 2 +- .../delegates/xnnpack/strided_slice_tester.cc | 2 +- .../delegates/xnnpack/strided_slice_tester.h | 2 +- .../xnnpack/transpose_conv_tester.cc | 6 +- .../delegates/xnnpack/transpose_conv_tester.h | 2 +- .../delegates/xnnpack/transpose_tester.cc | 4 +- .../lite/delegates/xnnpack/transpose_tester.h | 2 +- .../xnnpack/unary_elementwise_tester.cc | 4 +- .../xnnpack/unary_elementwise_tester.h | 2 +- .../unsigned_quantized_reshape_test.cc | 226 + .../delegates/xnnpack/variable_ops_tester.cc | 4 +- .../delegates/xnnpack/variable_ops_tester.h | 2 +- .../delegates/xnnpack/weights_cache_test.cc | 4 +- .../delegates/xnnpack/xnnpack_delegate.cc | 128 +- .../lite/delegates/xnnpack/xnnpack_delegate.h | 8 +- .../examples/experimental_new_converter/BUILD | 1 + tensorflow/lite/examples/label_image/BUILD | 2 + 
.../examples/label_image/bitmap_helpers.cc | 11 +- tensorflow/lite/examples/minimal/BUILD | 1 + tensorflow/lite/examples/python/BUILD | 1 + .../acceleration/compatibility/BUILD | 1 + .../acceleration/configuration/BUILD | 127 +- .../acceleration/configuration/build_defs.bzl | 43 + .../acceleration/configuration/c/BUILD | 71 +- .../configuration/c/delegate_plugin.h | 44 +- .../acceleration/configuration/c/gpu_plugin.h | 18 +- .../configuration/c/nnapi_plugin.h | 18 +- .../{vendor_delegate.h => stable_delegate.h} | 8 +- .../configuration/c/xnnpack_plugin.h | 18 +- .../configuration/configuration.proto | 79 +- .../configuration/configuration_generated.h | 591 +- .../configuration/coreml_plugin.cc | 2 +- .../configuration/delegate_registry.h | 18 +- .../configuration/flatbuffer_to_proto.cc | 34 +- .../configuration/flatbuffer_to_proto_test.cc | 15 + .../acceleration/configuration/gpu_plugin.h | 2 +- .../configuration/gpu_plugin_test.cc | 2 +- .../configuration/hexagon_plugin.cc | 2 +- .../acceleration/configuration/nnapi_plugin.h | 6 +- .../configuration/nnapi_plugin_test.cc | 8 +- .../prev_is_different_than_current_test.sh} | 18 +- .../configuration/proto_to_flatbuffer.cc | 97 +- .../configuration/stable_delegate_plugin.cc | 27 + .../configuration/stable_delegate_plugin.h | 92 + .../stable_delegate_plugin_test.cc | 109 + .../testdata/configuration.old.fbs | 355 + .../testdata/configuration.proto_prev | 785 +++ .../configuration/xnnpack_plugin.cc | 2 +- .../configuration/xnnpack_plugin_test.cc | 2 +- .../acceleration/mini_benchmark/BUILD | 145 +- .../benchmark_result_evaluator.cc | 16 +- .../benchmark_result_evaluator.h | 36 +- .../blocking_validator_runner.cc | 62 +- .../blocking_validator_runner.h | 7 +- .../blocking_validator_runner_test.cc | 67 +- .../mini_benchmark/build_defs.bzl | 22 +- .../acceleration/mini_benchmark/c/BUILD | 15 +- .../acceleration/mini_benchmark/c/c_api.cc | 161 +- .../acceleration/mini_benchmark/c/c_api.h | 82 +- .../mini_benchmark/c/c_api_test.cc | 198 +- .../mini_benchmark/c/c_api_types.h | 90 + .../acceleration/mini_benchmark/call.cc | 2 +- .../mini_benchmark/call_register.h | 2 +- .../acceleration/mini_benchmark/call_test.cc | 6 +- .../mini_benchmark/decode_jpeg.cc | 4 +- .../mini_benchmark/decode_jpeg_register.h | 2 +- .../mini_benchmark/decode_jpeg_status.h | 2 +- .../acceleration/mini_benchmark/fb_storage.cc | 2 +- .../acceleration/mini_benchmark/fb_storage.h | 4 +- .../mini_benchmark/fb_storage_test.cc | 2 +- .../mini_benchmark/gpu_module_plugin.cc | 110 + .../mini_benchmark/gpu_module_plugin.h | 63 + .../mini_benchmark/jpeg_header_parser.cc | 6 +- .../mini_benchmark/jpeg_header_parser.h | 2 +- .../mini_benchmark/jpeg_header_parser_test.cc | 14 +- .../mini_benchmark/libjpeg_decoder.cc | 2 +- .../mini_benchmark/libjpeg_decoder.h | 2 +- .../mini_benchmark/libjpeg_decoder_test.cc | 49 +- .../mini_benchmark/libjpeg_handle.cc | 2 +- .../mini_benchmark/libjpeg_handle_test.cc | 2 +- .../acceleration/mini_benchmark/metrics/BUILD | 1 + .../mini_benchmark/mini_benchmark_test.cc | 2 +- .../mini_benchmark/model_modifier/BUILD | 20 +- .../custom_validation_embedder.cc | 19 +- .../custom_validation_embedder.h | 11 +- .../custom_validation_embedder_test.cc | 25 +- .../mini_benchmark/model_modifier/embedder.cc | 9 +- .../mini_benchmark/model_modifier/embedder.h | 2 +- .../model_modifier/embedder_main.cc | 2 +- .../validation_graph_builder.cc | 2 +- .../mini_benchmark/model_validation_test.cc | 13 +- .../acceleration/mini_benchmark/models/BUILD | 1 + 
.../runner_test_entry_points.cc | 15 +- .../mini_benchmark/status_codes.h | 30 +- .../acceleration/mini_benchmark/validator.cc | 130 +- .../acceleration/mini_benchmark/validator.h | 29 +- .../mini_benchmark/validator_runner.cc | 6 +- .../mini_benchmark/validator_runner.h | 2 +- .../validator_runner_entrypoint.cc | 130 +- .../validator_runner_entrypoint_test.cc | 9 +- .../mini_benchmark/validator_runner_impl.cc | 219 +- .../mini_benchmark/validator_runner_impl.h | 40 +- .../validator_runner_impl_test.cc | 158 +- .../mini_benchmark/validator_runner_options.h | 15 +- .../mini_benchmark/validator_runner_test.cc | 8 + .../mini_benchmark/validator_test.cc | 77 +- .../lite/experimental/microfrontend/BUILD | 3 +- .../microfrontend/audio_microfrontend_test.cc | 2 +- .../lite/experimental/microfrontend/lib/BUILD | 1 + .../audio_microfrontend_op_test.py | 22 +- tensorflow/lite/experimental/remat/BUILD | 1 + .../lite/experimental/remat/metadata_util.h | 1 + tensorflow/lite/experimental/resource/BUILD | 9 +- .../resource/initialization_status.h | 2 +- .../experimental/resource/lookup_interfaces.h | 2 +- .../resource/resource_variable.cc | 2 +- .../experimental/resource/resource_variable.h | 2 +- .../resource/resource_variable_test.cc | 4 +- .../experimental/resource/static_hashtable.h | 2 +- .../lite/external_cpu_backend_context.cc | 2 +- .../lite/external_cpu_backend_context.h | 2 +- tensorflow/lite/g3doc/android/lite_build.md | 6 +- .../lite/g3doc/android/play_services.md | 4 +- tensorflow/lite/g3doc/api_docs/index.md | 1 + .../examples/audio_classification/overview.md | 4 +- tensorflow/lite/g3doc/guide/build_cmake.md | 1 + .../lite/g3doc/guide/build_cmake_arm.md | 17 +- tensorflow/lite/g3doc/guide/inference.md | 38 + tensorflow/lite/g3doc/guide/ios.md | 6 +- .../lite/g3doc/guide/op_select_allowlist.md | 1 + .../lite/g3doc/guide/ops_compatibility.md | 2 +- tensorflow/lite/g3doc/guide/ops_custom.md | 75 +- tensorflow/lite/g3doc/guide/signatures.ipynb | 4 +- .../task_library/image_embedder.md | 6 +- .../microcontrollers/get_started_low_level.md | 2 +- .../lite/g3doc/microcontrollers/index.md | 4 +- .../lite/g3doc/microcontrollers/library.md | 4 +- tensorflow/lite/g3doc/performance/gpu.md | 1 + .../performance/post_training_quantization.md | 9 +- tensorflow/lite/g3doc/tools/BUILD | 1 + tensorflow/lite/generate-pc.sh | 73 + tensorflow/lite/graph_info.cc | 2 +- tensorflow/lite/graph_info.h | 6 +- tensorflow/lite/graph_info_test.cc | 7 +- tensorflow/lite/internal/BUILD | 1 + tensorflow/lite/interpreter.h | 13 +- tensorflow/lite/interpreter_builder.h | 12 +- tensorflow/lite/interpreter_options.h | 1 + tensorflow/lite/interpreter_test.cc | 23 +- tensorflow/lite/ios/BUILD.apple | 26 +- tensorflow/lite/ios/TensorFlowLiteC.h | 7 +- tensorflow/lite/ios/TensorFlowLiteC.podspec | 4 +- .../ios/TensorFlowLiteSelectTfOps.podspec | 4 +- tensorflow/lite/java/AndroidManifest.xml | 5 +- tensorflow/lite/java/BUILD | 11 +- tensorflow/lite/java/demo/app/src/main/BUILD | 1 + .../lite/java/demo/app/src/main/assets/BUILD | 1 + tensorflow/lite/java/ovic/BUILD | 1 + tensorflow/lite/java/ovic/demo/app/BUILD | 1 + tensorflow/lite/java/ovic/src/testdata/BUILD | 1 + .../java/org/tensorflow/lite/Interpreter.java | 22 +- .../org/tensorflow/lite/InterpreterApi.java | 41 + .../org/tensorflow/lite/InterpreterImpl.java | 6 - .../lite/NativeInterpreterWrapper.java | 20 +- .../ValidatedAccelerationConfig.java | 37 + tensorflow/lite/java/src/main/native/BUILD | 1 + .../native/interpreter_factory_impl_jni.cc | 5 +- 
.../native/op_resolver_lazy_delegate_proxy.cc | 9 +- .../native/op_resolver_lazy_delegate_proxy.h | 2 +- .../tensorflow/lite/InterpreterApiTest.java | 43 +- .../tensorflow/lite/gpu/GpuDelegateTest.java | 18 + tensorflow/lite/java/src/test/native/BUILD | 5 +- .../src/test/native/interpreter_test_jni.cc | 2 +- .../testhelper/java/org/tensorflow/lite/BUILD | 1 + tensorflow/lite/kernels/BUILD | 142 +- tensorflow/lite/kernels/CMakeLists.txt | 20 +- tensorflow/lite/kernels/activations.cc | 8 +- tensorflow/lite/kernels/add.cc | 4 +- tensorflow/lite/kernels/add_n.cc | 2 +- tensorflow/lite/kernels/arg_min_max.cc | 6 +- tensorflow/lite/kernels/assign_variable.cc | 2 +- tensorflow/lite/kernels/atan2.cc | 2 +- tensorflow/lite/kernels/atan2_custom.cc | 2 +- tensorflow/lite/kernels/audio_spectrogram.cc | 2 +- tensorflow/lite/kernels/basic_rnn.cc | 4 +- tensorflow/lite/kernels/batch_matmul.cc | 4 +- tensorflow/lite/kernels/batch_to_space_nd.cc | 2 +- .../lite/kernels/batch_to_space_nd_test.cc | 2 +- .../kernels/bidirectional_sequence_lstm.cc | 23 +- .../kernels/bidirectional_sequence_rnn.cc | 8 +- tensorflow/lite/kernels/broadcast_args.cc | 2 +- .../lite/kernels/broadcast_args_test.cc | 6 +- tensorflow/lite/kernels/broadcast_to.cc | 4 +- tensorflow/lite/kernels/broadcast_to_test.cc | 6 +- tensorflow/lite/kernels/bucketize.cc | 12 +- tensorflow/lite/kernels/bucketize_test.cc | 2 +- tensorflow/lite/kernels/builtin_op_kernels.h | 21 +- tensorflow/lite/kernels/builtin_ops_list.inc | 173 + tensorflow/lite/kernels/call_once.cc | 4 +- tensorflow/lite/kernels/cast.cc | 51 +- tensorflow/lite/kernels/ceil.cc | 2 +- tensorflow/lite/kernels/comparisons.cc | 2 +- tensorflow/lite/kernels/complex_support.cc | 2 +- tensorflow/lite/kernels/concatenation.cc | 5 +- tensorflow/lite/kernels/concatenation_test.cc | 2 +- tensorflow/lite/kernels/conv.cc | 51 +- tensorflow/lite/kernels/conv3d.cc | 4 +- tensorflow/lite/kernels/conv3d_transpose.cc | 4 +- tensorflow/lite/kernels/conv_mem_test.cc | 6 +- tensorflow/lite/kernels/conv_test.cc | 49 +- .../lite/kernels/cpu_backend_context.cc | 12 +- tensorflow/lite/kernels/cpu_backend_context.h | 10 +- tensorflow/lite/kernels/ctc/BUILD | 7 +- .../kernels/ctc/ctc_beam_search_decoder.cc | 2 +- .../ctc/ctc_beam_search_decoder_test.cc | 4 +- tensorflow/lite/kernels/cumsum.cc | 5 +- tensorflow/lite/kernels/custom_ops_register.h | 2 +- tensorflow/lite/kernels/densify.cc | 2 +- tensorflow/lite/kernels/densify_test.cc | 2 +- tensorflow/lite/kernels/depth_to_space.cc | 4 +- .../lite/kernels/depth_to_space_test.cc | 2 +- tensorflow/lite/kernels/depthwise_conv.cc | 80 +- .../lite/kernels/depthwise_conv_test.cc | 103 + tensorflow/lite/kernels/dequantize.cc | 2 +- tensorflow/lite/kernels/dequantize.h | 2 +- .../lite/kernels/detection_postprocess.cc | 2 +- tensorflow/lite/kernels/div.cc | 4 +- .../lite/kernels/dynamic_update_slice.cc | 9 +- .../lite/kernels/dynamic_update_slice_test.cc | 14 + tensorflow/lite/kernels/eigen_support.cc | 2 +- tensorflow/lite/kernels/eigen_support.h | 2 +- tensorflow/lite/kernels/eigen_support_test.cc | 2 +- tensorflow/lite/kernels/elementwise.cc | 7 +- tensorflow/lite/kernels/elementwise_test.cc | 24 + tensorflow/lite/kernels/embedding_lookup.cc | 2 +- .../lite/kernels/embedding_lookup_sparse.cc | 4 +- .../lite/kernels/embedding_lookup_test.cc | 36 +- tensorflow/lite/kernels/exp.cc | 2 +- tensorflow/lite/kernels/expand_dims.cc | 9 +- tensorflow/lite/kernels/fake_quant.cc | 4 +- tensorflow/lite/kernels/fill.cc | 2 +- tensorflow/lite/kernels/floor.cc | 2 +- 
tensorflow/lite/kernels/floor_div.cc | 2 +- tensorflow/lite/kernels/floor_mod.cc | 2 +- tensorflow/lite/kernels/fully_connected.cc | 52 +- tensorflow/lite/kernels/fully_connected.h | 2 +- .../lite/kernels/fully_connected_test.cc | 53 +- tensorflow/lite/kernels/gather.cc | 4 +- tensorflow/lite/kernels/gather_nd.cc | 4 +- tensorflow/lite/kernels/gradient/BUILD | 3 +- .../lite/kernels/gradient/bcast_grad_args.cc | 4 +- .../lite/kernels/gradient/bcast_grad_args.h | 2 +- tensorflow/lite/kernels/hashtable.cc | 4 +- tensorflow/lite/kernels/hashtable_find.cc | 2 +- tensorflow/lite/kernels/hashtable_import.cc | 2 +- tensorflow/lite/kernels/hashtable_lookup.cc | 2 +- tensorflow/lite/kernels/hashtable_ops_test.cc | 2 +- tensorflow/lite/kernels/hashtable_size.cc | 2 +- tensorflow/lite/kernels/if.cc | 6 +- tensorflow/lite/kernels/internal/BUILD | 28 +- tensorflow/lite/kernels/internal/common.h | 184 +- .../lite/kernels/internal/kernel_utils.h | 2 +- .../kernels/internal/optimized/batch_matmul.h | 2 +- .../optimized/eigen_spatial_convolutions.h | 2 +- .../optimized/integer_ops/transpose_conv.h | 9 +- .../internal/optimized/multithreaded_conv.h | 3 +- .../internal/optimized/neon_tensor_utils.cc | 4 +- .../internal/optimized/optimized_ops.h | 26 +- .../internal/optimized/resize_bilinear.h | 2 +- .../optimized/sparse_ops/fully_connected.h | 2 +- .../lite/kernels/internal/portable_tensor.h | 2 +- .../kernels/internal/portable_tensor_utils.cc | 2 +- .../kernels/internal/portable_tensor_utils.h | 4 +- .../internal/quantization_util_test.cc | 4 +- .../kernels/internal/reference/comparisons.h | 2 +- .../lite/kernels/internal/reference/densify.h | 2 +- .../lite/kernels/internal/reference/gather.h | 3 +- .../reference/integer_ops/depthwise_conv.h | 16 + .../reference/integer_ops/fully_connected.h | 138 +- .../internal/reference/integer_ops/mul.h | 10 +- .../reference/integer_ops/transpose_conv.h | 8 +- .../internal/reference/l2normalization.h | 2 +- .../internal/reference/legacy_reference_ops.h | 4 + .../reference/portable_tensor_utils.cc | 2 +- .../lite/kernels/internal/reference/reduce.h | 16 + .../internal/reference/reference_ops.h | 4 +- .../internal/reference/strided_slice.h | 8 +- .../internal/reference/string_comparisons.h | 2 +- .../lite/kernels/internal/reference/svdf.h | 4 +- .../internal/reference/transpose_conv.h | 22 +- .../lite/kernels/internal/tensor_ctypes.h | 2 +- .../lite/kernels/internal/tensor_utils.h | 2 +- .../kernels/internal/tensor_utils_test.cc | 8 +- tensorflow/lite/kernels/internal/utils/BUILD | 4 +- .../utils/sparsity_format_converter.h | 2 +- .../utils/sparsity_format_converter_test.cc | 2 +- tensorflow/lite/kernels/irfft2d.cc | 2 +- tensorflow/lite/kernels/kernel_util.cc | 4 +- tensorflow/lite/kernels/kernel_util.h | 4 +- tensorflow/lite/kernels/kernel_util_test.cc | 4 +- tensorflow/lite/kernels/l2norm.cc | 4 +- .../lite/kernels/local_response_norm.cc | 4 +- tensorflow/lite/kernels/logical.cc | 2 +- tensorflow/lite/kernels/lsh_projection.cc | 4 +- tensorflow/lite/kernels/lstm.cc | 16 +- tensorflow/lite/kernels/lstm_eval.cc | 208 +- tensorflow/lite/kernels/lstm_eval.h | 13 +- tensorflow/lite/kernels/lstm_eval_test.cc | 10 +- tensorflow/lite/kernels/lstm_test.cc | 2 +- tensorflow/lite/kernels/matrix_diag.cc | 2 +- tensorflow/lite/kernels/matrix_set_diag.cc | 2 +- tensorflow/lite/kernels/maximum_minimum.cc | 2 +- tensorflow/lite/kernels/mfcc.cc | 2 +- tensorflow/lite/kernels/mirror_pad.cc | 4 +- tensorflow/lite/kernels/mul.cc | 4 +- tensorflow/lite/kernels/multinomial.cc | 2 +- 
tensorflow/lite/kernels/neg.cc | 2 +- .../lite/kernels/non_max_suppression.cc | 2 +- tensorflow/lite/kernels/numeric_verify.cc | 2 +- tensorflow/lite/kernels/one_hot.cc | 4 +- tensorflow/lite/kernels/pack.cc | 4 +- tensorflow/lite/kernels/pad.cc | 2 +- tensorflow/lite/kernels/pad_test.cc | 8 +- tensorflow/lite/kernels/padding.h | 2 +- tensorflow/lite/kernels/parse_example/BUILD | 18 +- .../kernels/parse_example/parse_example.cc | 2 +- .../parse_example/parse_example_test.cc | 8 +- tensorflow/lite/kernels/perception/BUILD | 3 +- .../kernels/perception/dense_image_warp.cc | 4 +- .../perception/max_pool_with_argmax.cc | 4 +- .../kernels/perception/max_unpooling_2d.cc | 4 +- tensorflow/lite/kernels/pooling.cc | 4 +- tensorflow/lite/kernels/pooling3d.cc | 2 +- tensorflow/lite/kernels/pooling3d_test.cc | 2 +- tensorflow/lite/kernels/pooling_test.cc | 2 +- tensorflow/lite/kernels/pow.cc | 2 +- tensorflow/lite/kernels/quantize.cc | 2 +- tensorflow/lite/kernels/random_ops.cc | 2 +- .../kernels/random_standard_normal_custom.cc | 2 +- .../lite/kernels/random_uniform_custom.cc | 2 +- tensorflow/lite/kernels/range.cc | 7 +- tensorflow/lite/kernels/rank.cc | 2 +- tensorflow/lite/kernels/read_variable.cc | 2 +- tensorflow/lite/kernels/reduce.cc | 14 +- tensorflow/lite/kernels/register.h | 16 +- tensorflow/lite/kernels/register_ref.cc | 2 +- tensorflow/lite/kernels/register_ref.h | 2 +- tensorflow/lite/kernels/reshape.cc | 11 +- tensorflow/lite/kernels/reshape_test.cc | 8 +- tensorflow/lite/kernels/resize_bilinear.cc | 4 +- .../lite/kernels/resize_nearest_neighbor.cc | 4 +- tensorflow/lite/kernels/reverse.cc | 2 +- tensorflow/lite/kernels/reverse_sequence.cc | 4 +- tensorflow/lite/kernels/rfft2d.cc | 2 +- tensorflow/lite/kernels/roll.cc | 2 +- tensorflow/lite/kernels/roll_test.cc | 2 +- tensorflow/lite/kernels/round.cc | 2 +- tensorflow/lite/kernels/scatter_nd.cc | 4 +- tensorflow/lite/kernels/segment_sum.cc | 2 +- tensorflow/lite/kernels/select.cc | 2 +- tensorflow/lite/kernels/shape.cc | 4 +- tensorflow/lite/kernels/shim/BUILD | 38 +- tensorflow/lite/kernels/shim/README.md | 3 + tensorflow/lite/kernels/shim/op_kernel.h | 33 +- tensorflow/lite/kernels/shim/tensor_view.h | 16 +- tensorflow/lite/kernels/shim/test_op/BUILD | 81 +- .../lite/kernels/shim/test_op/README.md | 11 +- .../lite/kernels/shim/test_op/simple_op.h | 3 + .../kernels/shim/test_op/simple_tflite_op.cc | 2 +- .../kernels/shim/test_op/simple_tflite_op.h | 2 +- .../lite/kernels/shim/test_op/tmpl_op.h | 104 + .../lite/kernels/shim/test_op/tmpl_tf_op.cc | 41 + .../lite/kernels/shim/test_op/tmpl_tf_op.h | 33 + .../kernels/shim/test_op/tmpl_tf_op_test.cc | 78 + .../kernels/shim/test_op/tmpl_tflite_op.cc | 49 + .../kernels/shim/test_op/tmpl_tflite_op.h} | 30 +- .../shim/test_op/tmpl_tflite_op_test.cc | 141 + tensorflow/lite/kernels/shim/test_util.h | 2 +- tensorflow/lite/kernels/shim/tf_op_shim.h | 12 +- .../lite/kernels/shim/tflite_op_shim.cc | 8 +- tensorflow/lite/kernels/shim/tflite_op_shim.h | 13 +- .../lite/kernels/shim/tflite_op_wrapper.h | 369 + .../kernels/shim/tflite_op_wrapper_test.cc | 571 ++ .../lite/kernels/shim/tflite_tensor_view.cc | 2 +- .../lite/kernels/shim/tflite_tensor_view.h | 2 +- tensorflow/lite/kernels/sign.cc | 6 +- tensorflow/lite/kernels/sign_custom.cc | 2 +- tensorflow/lite/kernels/sign_test.cc | 46 +- tensorflow/lite/kernels/skip_gram.cc | 6 +- tensorflow/lite/kernels/slice.cc | 2 +- tensorflow/lite/kernels/slice_test.cc | 2 +- tensorflow/lite/kernels/space_to_batch_nd.cc | 2 +- 
.../lite/kernels/space_to_batch_nd_test.cc | 4 +- tensorflow/lite/kernels/space_to_depth.cc | 4 +- .../lite/kernels/space_to_depth_test.cc | 2 +- tensorflow/lite/kernels/sparse_to_dense.cc | 2 +- tensorflow/lite/kernels/split.cc | 4 +- tensorflow/lite/kernels/split_v.cc | 4 +- tensorflow/lite/kernels/squared_difference.cc | 2 +- tensorflow/lite/kernels/squeeze.cc | 11 +- tensorflow/lite/kernels/strided_slice.cc | 86 +- tensorflow/lite/kernels/strided_slice_test.cc | 2 +- tensorflow/lite/kernels/sub.cc | 4 +- tensorflow/lite/kernels/subgraph_test_util.cc | 6 +- tensorflow/lite/kernels/svdf.cc | 4 +- tensorflow/lite/kernels/table.cc | 2 +- tensorflow/lite/kernels/test_util.cc | 6 +- tensorflow/lite/kernels/test_util.h | 100 +- tensorflow/lite/kernels/tile.cc | 2 +- tensorflow/lite/kernels/topk_v2.cc | 4 +- tensorflow/lite/kernels/transpose.cc | 78 +- tensorflow/lite/kernels/transpose_conv.cc | 13 +- .../lite/kernels/transpose_conv_test.cc | 293 +- tensorflow/lite/kernels/transpose_test.cc | 4 +- .../kernels/unidirectional_sequence_gru.cc | 2 +- .../kernels/unidirectional_sequence_lstm.cc | 91 +- .../unidirectional_sequence_lstm_test.cc | 441 +- .../unidirectional_sequence_lstm_test_util.h | 125 +- .../kernels/unidirectional_sequence_rnn.cc | 4 +- tensorflow/lite/kernels/unique.cc | 4 +- tensorflow/lite/kernels/unpack.cc | 4 +- tensorflow/lite/kernels/unsorted_segment.cc | 2 +- .../lite/kernels/unsorted_segment_test.h | 2 +- tensorflow/lite/kernels/var_handle.cc | 5 +- tensorflow/lite/kernels/variable_ops_test.cc | 4 +- tensorflow/lite/kernels/where.cc | 4 +- tensorflow/lite/kernels/where_test.cc | 2 +- tensorflow/lite/kernels/while.cc | 22 +- tensorflow/lite/kernels/zeros_like.cc | 2 +- tensorflow/lite/memory_planner.h | 2 +- tensorflow/lite/model.h | 12 +- tensorflow/lite/model_builder.h | 9 +- tensorflow/lite/model_flex_test.cc | 6 +- tensorflow/lite/model_xnnpack_test.cc | 6 +- tensorflow/lite/mutable_op_resolver.cc | 2 +- tensorflow/lite/mutable_op_resolver.h | 2 +- tensorflow/lite/mutable_op_resolver_test.cc | 2 +- tensorflow/lite/nnapi/BUILD | 3 +- tensorflow/lite/nnapi/sl/BUILD | 1 + tensorflow/lite/objc/BUILD.apple | 3 +- .../lite/objc/TensorFlowLiteObjC.podspec | 5 +- .../objc/TensorFlowLiteObjC.podspec.template | 6 - tensorflow/lite/objc/sources/TFLCommonUtil.mm | 4 + .../lite/objc/sources/TFLInterpreter.mm | 6 +- .../lite/objc/sources/TFLSignatureRunner.mm | 6 +- tensorflow/lite/optional_debug_tools.cc | 24 +- tensorflow/lite/optional_debug_tools.h | 3 +- tensorflow/lite/optional_debug_tools_test.cc | 6 +- tensorflow/lite/portable_type_to_tflitetype.h | 2 +- tensorflow/lite/profiling/BUILD | 8 + tensorflow/lite/profiling/buffered_profiler.h | 10 +- tensorflow/lite/profiling/memory_info.cc | 9 +- .../lite/profiling/profile_summarizer.cc | 19 +- .../lite/profiling/profile_summarizer_test.cc | 2 +- .../lite/profiling/signpost_profiler.mm | 2 + tensorflow/lite/profiling/telemetry/BUILD | 40 +- tensorflow/lite/profiling/telemetry/c/BUILD | 50 + .../lite/profiling/telemetry/c/profiler.h | 85 + .../profiling/telemetry/c/telemetry_setting.h | 103 + .../telemetry/c/telemetry_setting_internal.cc | 99 + .../telemetry/c/telemetry_setting_internal.h | 60 + .../lite/profiling/telemetry/profiler.cc | 83 +- .../lite/profiling/telemetry/profiler.h | 15 +- .../lite/profiling/telemetry/profiler_test.cc | 156 + .../lite/profiling/telemetry/telemetry.cc | 20 +- .../lite/profiling/telemetry/telemetry.h | 15 +- .../profiling/telemetry/telemetry_settings.h | 49 - 
.../profiling/telemetry/telemetry_status.h | 24 +- .../profiling/telemetry/telemetry_test.cc | 20 +- tensorflow/lite/python/BUILD | 2 + tensorflow/lite/python/analyzer_wrapper/BUILD | 3 +- .../python/analyzer_wrapper/model_analyzer.cc | 46 +- tensorflow/lite/python/authoring/BUILD | 1 + tensorflow/lite/python/convert.py | 39 +- tensorflow/lite/python/convert_phase.py | 2 +- tensorflow/lite/python/interpreter.py | 54 +- tensorflow/lite/python/interpreter_test.py | 2 +- .../lite/python/interpreter_wrapper/BUILD | 9 +- .../interpreter_wrapper.cc | 109 +- .../interpreter_wrapper/interpreter_wrapper.h | 27 +- .../interpreter_wrapper_pybind11.cc | 46 +- .../lite/python/interpreter_wrapper/numpy.h | 4 +- tensorflow/lite/python/lite.py | 83 +- tensorflow/lite/python/lite_v2_test.py | 86 +- tensorflow/lite/python/lite_v2_test_util.py | 42 + tensorflow/lite/python/metrics/BUILD | 1 + tensorflow/lite/python/optimize/BUILD | 7 +- .../python/optimize/calibration_wrapper.cc | 6 +- .../python/optimize/calibration_wrapper.h | 6 +- tensorflow/lite/python/testdata/BUILD | 3 +- .../control_flow_v1_saved_model/BUILD | 1 + .../lite/python/testdata/test_delegate.cc | 2 +- tensorflow/lite/python/tflite_convert.py | 5 +- tensorflow/lite/python/util.py | 8 +- tensorflow/lite/python/wrap_toco.py | 14 +- tensorflow/lite/schema/BUILD | 9 +- .../lite/schema/builtin_ops_header/BUILD | 1 + tensorflow/lite/schema/builtin_ops_list/BUILD | 1 + .../lite/schema/conversion_metadata.fbs | 2 + .../schema/conversion_metadata_generated.h | 665 ++ tensorflow/lite/schema/schema.fbs | 7 + tensorflow/lite/schema/schema_generated.h | 40 +- tensorflow/lite/signature_runner.cc | 2 +- tensorflow/lite/signature_runner.h | 6 +- tensorflow/lite/signature_runner_test.cc | 4 +- tensorflow/lite/simple_memory_arena.cc | 6 +- tensorflow/lite/simple_memory_arena.h | 2 +- tensorflow/lite/simple_memory_arena_test.cc | 2 +- tensorflow/lite/simple_planner.cc | 34 +- tensorflow/lite/simple_planner.h | 2 +- tensorflow/lite/simple_planner_test.cc | 12 +- tensorflow/lite/stderr_reporter.h | 2 +- tensorflow/lite/string_util.cc | 43 +- tensorflow/lite/string_util.h | 29 +- tensorflow/lite/string_util_test.cc | 45 +- tensorflow/lite/swift/BUILD.apple | 1 + .../lite/swift/Sources/Interpreter.swift | 33 +- tensorflow/lite/swift/Sources/Model.swift | 18 + .../lite/swift/TensorFlowLiteSwift.podspec | 5 +- .../lite/swift/Tests/InterpreterTests.swift | 16 + tensorflow/lite/swift/Tests/ModelTests.swift | 22 +- tensorflow/lite/tensorflow_profiler_logger.cc | 224 +- tensorflow/lite/tensorflow_profiler_logger.h | 6 + .../lite/tensorflow_profiler_logger_shim.cc | 8 + .../test_util.h} | 21 +- tensorflow/lite/testing/BUILD | 12 +- tensorflow/lite/testing/build_def.bzl | 11 - tensorflow/lite/testing/generate_examples.py | 4 +- .../lite/testing/generate_examples_lib.py | 2 +- tensorflow/lite/testing/kernel_test/BUILD | 15 +- .../lite/testing/kernel_test/diff_analyzer.cc | 2 +- .../lite/testing/kernel_test/diff_analyzer.h | 2 +- .../testing/kernel_test/input_generator.cc | 6 +- .../testing/kernel_test/input_generator.h | 6 +- tensorflow/lite/testing/kernel_test/util.h | 2 +- tensorflow/lite/testing/op_tests/abs.py | 2 +- tensorflow/lite/testing/op_tests/add_n.py | 2 +- .../lite/testing/op_tests/arg_min_max.py | 6 +- tensorflow/lite/testing/op_tests/atan2.py | 2 +- .../testing/op_tests/batch_to_space_nd.py | 4 +- tensorflow/lite/testing/op_tests/binary_op.py | 2 +- tensorflow/lite/testing/op_tests/cast.py | 144 +- tensorflow/lite/testing/op_tests/ceil.py | 2 +- 
.../lite/testing/op_tests/complex_abs.py | 2 +- tensorflow/lite/testing/op_tests/concat.py | 2 +- tensorflow/lite/testing/op_tests/cond.py | 11 +- tensorflow/lite/testing/op_tests/constant.py | 2 +- .../lite/testing/op_tests/control_dep.py | 7 +- tensorflow/lite/testing/op_tests/conv.py | 6 +- .../lite/testing/op_tests/conv2d_transpose.py | 2 +- tensorflow/lite/testing/op_tests/conv3d.py | 2 +- .../lite/testing/op_tests/conv3d_transpose.py | 2 +- .../lite/testing/op_tests/conv_activation.py | 6 +- .../testing/op_tests/conv_bias_activation.py | 14 +- ...nv_to_depthwiseconv_with_shared_weights.py | 10 +- .../op_tests/conv_with_shared_weights.py | 10 +- tensorflow/lite/testing/op_tests/cos.py | 2 +- tensorflow/lite/testing/op_tests/cumsum.py | 2 +- .../lite/testing/op_tests/depth_to_space.py | 2 +- .../lite/testing/op_tests/depthwiseconv.py | 8 +- .../lite/testing/op_tests/dynamic_rnn.py | 10 +- .../testing/op_tests/dynamic_update_slice.py | 6 +- .../lite/testing/op_tests/elementwise.py | 2 +- tensorflow/lite/testing/op_tests/elu.py | 2 +- .../lite/testing/op_tests/embedding_lookup.py | 4 +- tensorflow/lite/testing/op_tests/equal.py | 2 +- tensorflow/lite/testing/op_tests/exp.py | 2 +- .../lite/testing/op_tests/expand_dims.py | 2 +- tensorflow/lite/testing/op_tests/expm1.py | 4 +- tensorflow/lite/testing/op_tests/eye.py | 2 +- tensorflow/lite/testing/op_tests/fill.py | 2 +- tensorflow/lite/testing/op_tests/floor.py | 2 +- .../lite/testing/op_tests/fully_connected.py | 2 +- .../lite/testing/op_tests/fused_batch_norm.py | 2 +- tensorflow/lite/testing/op_tests/gather.py | 2 +- tensorflow/lite/testing/op_tests/gather_nd.py | 2 +- .../testing/op_tests/gather_with_constant.py | 2 +- .../testing/op_tests/global_batch_norm.py | 2 +- tensorflow/lite/testing/op_tests/greater.py | 2 +- .../lite/testing/op_tests/greater_equal.py | 2 +- tensorflow/lite/testing/op_tests/hardswish.py | 2 +- .../testing/op_tests/identify_dilated_conv.py | 6 +- .../op_tests/identify_dilated_conv1d.py | 6 +- tensorflow/lite/testing/op_tests/identity.py | 2 +- tensorflow/lite/testing/op_tests/imag.py | 2 +- tensorflow/lite/testing/op_tests/irfft2d.py | 2 +- tensorflow/lite/testing/op_tests/is_finite.py | 2 +- tensorflow/lite/testing/op_tests/l2norm.py | 2 +- .../testing/op_tests/l2norm_shared_epsilon.py | 2 +- .../lite/testing/op_tests/leaky_relu.py | 2 +- tensorflow/lite/testing/op_tests/less.py | 2 +- .../lite/testing/op_tests/less_equal.py | 2 +- .../testing/op_tests/local_response_norm.py | 2 +- .../lite/testing/op_tests/log_softmax.py | 2 +- tensorflow/lite/testing/op_tests/logic.py | 2 +- tensorflow/lite/testing/op_tests/lstm.py | 6 +- .../lite/testing/op_tests/matrix_band_part.py | 2 +- .../lite/testing/op_tests/matrix_diag.py | 2 +- .../lite/testing/op_tests/matrix_set_diag.py | 2 +- tensorflow/lite/testing/op_tests/maximum.py | 2 +- tensorflow/lite/testing/op_tests/minimum.py | 2 +- .../lite/testing/op_tests/mirror_pad.py | 4 +- .../lite/testing/op_tests/multinomial.py | 4 +- .../lite/testing/op_tests/nearest_upsample.py | 2 +- tensorflow/lite/testing/op_tests/neg.py | 2 +- tensorflow/lite/testing/op_tests/not_equal.py | 2 +- tensorflow/lite/testing/op_tests/one_hot.py | 2 +- tensorflow/lite/testing/op_tests/pack.py | 2 +- tensorflow/lite/testing/op_tests/pad.py | 4 +- tensorflow/lite/testing/op_tests/padv2.py | 4 +- .../lite/testing/op_tests/parse_example.py | 18 +- .../op_tests/placeholder_with_default.py | 2 +- tensorflow/lite/testing/op_tests/pool.py | 10 +- tensorflow/lite/testing/op_tests/pool3d.py | 2 +- 
tensorflow/lite/testing/op_tests/prelu.py | 2 +- .../op_tests/random_standard_normal.py | 4 +- .../lite/testing/op_tests/random_uniform.py | 4 +- tensorflow/lite/testing/op_tests/range.py | 2 +- tensorflow/lite/testing/op_tests/rank.py | 2 +- tensorflow/lite/testing/op_tests/real.py | 2 +- .../lite/testing/op_tests/reciprocal.py | 2 +- tensorflow/lite/testing/op_tests/reduce.py | 2 +- tensorflow/lite/testing/op_tests/relu.py | 2 +- tensorflow/lite/testing/op_tests/relu1.py | 2 +- tensorflow/lite/testing/op_tests/relu6.py | 2 +- tensorflow/lite/testing/op_tests/reshape.py | 2 +- .../lite/testing/op_tests/resize_bilinear.py | 2 +- .../op_tests/resize_nearest_neighbor.py | 4 +- .../resolve_constant_strided_slice.py | 2 +- .../lite/testing/op_tests/reverse_sequence.py | 4 +- .../lite/testing/op_tests/reverse_v2.py | 2 +- tensorflow/lite/testing/op_tests/rfft.py | 2 +- tensorflow/lite/testing/op_tests/rfft2d.py | 2 +- tensorflow/lite/testing/op_tests/roll.py | 2 +- tensorflow/lite/testing/op_tests/round.py | 2 +- .../lite/testing/op_tests/scatter_nd.py | 2 +- .../lite/testing/op_tests/segment_sum.py | 4 +- tensorflow/lite/testing/op_tests/shape.py | 4 +- .../op_tests/shape_to_strided_slice.py | 4 +- tensorflow/lite/testing/op_tests/sigmoid.py | 2 +- .../lite/testing/op_tests/sigmoid_grad.py | 2 +- tensorflow/lite/testing/op_tests/sign.py | 2 +- tensorflow/lite/testing/op_tests/slice.py | 2 +- tensorflow/lite/testing/op_tests/softmax.py | 4 +- tensorflow/lite/testing/op_tests/softplus.py | 2 +- .../testing/op_tests/space_to_batch_nd.py | 4 +- .../lite/testing/op_tests/space_to_depth.py | 5 +- .../lite/testing/op_tests/sparse_to_dense.py | 4 +- tensorflow/lite/testing/op_tests/split.py | 2 +- tensorflow/lite/testing/op_tests/splitv.py | 2 +- tensorflow/lite/testing/op_tests/squeeze.py | 2 +- .../testing/op_tests/squeeze_transpose.py | 4 +- .../lite/testing/op_tests/static_hashtable.py | 5 +- .../static_rnn_with_control_flow_v2.py | 12 +- tensorflow/lite/testing/op_tests/stft.py | 2 +- .../lite/testing/op_tests/strided_slice.py | 2 +- .../op_tests/strided_slice_np_style.py | 2 +- tensorflow/lite/testing/op_tests/tanh.py | 2 +- .../testing/op_tests/tensor_list_concat.py | 4 +- .../op_tests/tensor_list_dynamic_shape.py | 7 +- .../testing/op_tests/tensor_list_get_item.py | 4 +- .../testing/op_tests/tensor_list_length.py | 4 +- .../testing/op_tests/tensor_list_resize.py | 4 +- .../testing/op_tests/tensor_list_set_item.py | 6 +- .../testing/op_tests/tensor_scatter_add.py | 2 +- .../testing/op_tests/tensor_scatter_update.py | 2 +- tensorflow/lite/testing/op_tests/tile.py | 2 +- tensorflow/lite/testing/op_tests/topk.py | 2 +- tensorflow/lite/testing/op_tests/transpose.py | 9 +- .../lite/testing/op_tests/transpose_conv.py | 6 +- .../lite/testing/op_tests/unfused_gru.py | 2 +- tensorflow/lite/testing/op_tests/unique.py | 2 +- tensorflow/lite/testing/op_tests/unpack.py | 2 +- .../testing/op_tests/unroll_batch_matmul.py | 2 +- .../lite/testing/op_tests/unsorted_segment.py | 2 +- tensorflow/lite/testing/op_tests/where.py | 5 +- tensorflow/lite/testing/op_tests/where_v2.py | 4 +- .../lite/testing/op_tests/while_loop.py | 9 +- .../lite/testing/op_tests/zeros_like.py | 2 +- .../lite/testing/selective_build_test.cc | 2 +- tensorflow/lite/testing/split.h | 12 + tensorflow/lite/testing/tflite_driver.cc | 20 +- tensorflow/lite/testing/tflite_driver.h | 6 +- tensorflow/lite/testing/zip_test_utils.py | 28 +- tensorflow/lite/tflite_with_xnnpack.cc | 6 +- .../lite/tflite_with_xnnpack_optional.cc | 23 +- 
.../lite/tflite_with_xnnpack_optional.h | 5 +- tensorflow/lite/toco/BUILD | 3 + .../toco/graph_transformations/tests/BUILD | 1 + tensorflow/lite/toco/logging/BUILD | 1 + tensorflow/lite/toco/logging/testdata/BUILD | 1 + tensorflow/lite/toco/python/BUILD | 4 +- .../lite/toco/python/toco_python_api.cc | 17 +- tensorflow/lite/toco/python/toco_python_api.h | 3 +- .../lite/toco/tensorflow_graph_matching/BUILD | 1 + tensorflow/lite/toco/tflite/BUILD | 4 +- tensorflow/lite/toco/tflite/import.cc | 4 +- tensorflow/lite/toco/toco_flags.proto | 9 +- tensorflow/lite/tools/BUILD | 74 +- tensorflow/lite/tools/benchmark/BUILD | 18 +- .../lite/tools/benchmark/CMakeLists.txt | 5 +- tensorflow/lite/tools/benchmark/README.md | 56 +- tensorflow/lite/tools/benchmark/android/BUILD | 1 + .../lite/tools/benchmark/benchmark_model.h | 2 +- .../benchmark_performance_options.cc | 4 +- .../benchmark/benchmark_performance_options.h | 2 +- .../lite/tools/benchmark/benchmark_test.cc | 108 +- .../tools/benchmark/benchmark_tflite_model.cc | 36 +- .../tools/benchmark/benchmark_tflite_model.h | 4 +- .../lite/tools/benchmark/experimental/c/BUILD | 3 +- .../experimental/c/benchmark_c_api.h | 2 +- .../android/AndroidManifest.xml | 6 +- .../delegate_performance/android/BUILD | 106 +- .../delegate_performance/android/README.md | 287 +- .../android/jni/accuracy_benchmark.cc | 172 - .../android/jni/accuracy_benchmark.h | 63 - .../android/jni/accuracy_benchmark_test.cc | 156 - .../jni/delegate_performance_benchmark_jni.cc | 75 - .../android/jni/latency_benchmark.cc | 216 - .../delegate_performance/android/models/BUILD | 85 + .../models/mobilenet_v1_1.0_224.textproto} | 13 +- .../mobilenet_v1_1.0_224_quant.textproto} | 12 +- .../delegate_performance/android/proto.bzl | 85 + .../delegate_performance/android/proto/BUILD | 33 + .../proto/default_latency_criteria.textproto} | 12 +- .../android/proto/delegate_performance.proto | 115 + .../BenchmarkAccuracyActivity.java | 54 + .../BenchmarkAccuracyImpl.java | 155 + .../BenchmarkLatencyActivity.java | 56 + .../BenchmarkLatencyImpl.java | 365 + .../BenchmarkResultType.java | 39 + .../delegateperformance/CsvWriter.java | 107 + .../DelegatePerformanceBenchmark.java | 216 + .../TfLiteSettingsListEntry.java | 123 + .../android/src/main/native/BUILD | 81 + .../src/main/native/accuracy_benchmark.cc | 120 + .../src/main/native/accuracy_benchmark.h | 49 + .../delegate_performance_benchmark_jni.cc | 149 + .../src/main/native/latency_benchmark.cc | 230 + .../src/main/native/latency_benchmark.h | 49 + .../main/native/status_codes.h} | 24 +- .../BenchmarkAccuracyActivity.java | 98 - .../BenchmarkLatencyActivity.java | 77 - .../DelegatePerformanceBenchmark.java | 54 - .../android/src/test/native/BUILD | 60 + .../test/native/accuracy_benchmark_test.cc | 150 + .../src/test/native/latency_benchmark_test.cc | 115 + .../experimental/firebase/android/BUILD | 1 + .../benchmark/experimental/ios/BUILD.apple | 3 +- .../lite/tools/cmake/download_toolchains.sh | 21 +- .../tools/cmake/modules/Findfarmhash.cmake | 30 +- .../cmake/modules/Findgoogle_benchmark.cmake | 46 + .../lite/tools/cmake/modules/cpuinfo.cmake | 2 +- .../tools/cmake/modules/egl_headers.cmake | 2 +- .../lite/tools/cmake/modules/eigen.cmake | 2 +- .../tools/cmake/modules/flatbuffers.cmake | 2 +- .../tools/cmake/modules/fp16_headers.cmake | 2 +- .../tools/cmake/modules/opencl_headers.cmake | 2 +- .../tools/cmake/modules/opengl_headers.cmake | 2 +- .../tools/cmake/modules/vulkan_headers.cmake | 2 +- .../lite/tools/cmake/modules/xnnpack.cmake | 4 +- 
tensorflow/lite/tools/delegates/BUILD | 2 + tensorflow/lite/tools/delegates/README.md | 96 +- .../delegates/compatibility/common/BUILD | 7 +- .../delegate_compatibility_checker_base.h | 2 +- .../tools/delegates/compatibility/gpu/BUILD | 1 + .../tools/delegates/compatibility/nnapi/BUILD | 6 +- .../nnapi/nnapi_compatibility_lib_test.cc | 2 +- .../nnapi_delegate_compatibility_checker.cc | 2 +- .../delegates/compatibility/protos/BUILD | 1 + .../delegates/coreml_delegate_provider.cc | 6 + .../delegates/default_execution_provider.cc | 18 + .../experimental/stable_delegate/BUILD | 63 + .../stable_delegate_provider.cc | 127 + .../stable_delegate_provider_test.cc | 96 + .../test_invalid_settings.json | 4 + .../test_missing_delegate_path_settings.json | 6 + ...test_missing_stable_delegate_settings.json | 4 + .../test_sample_stable_delegate_settings.json | 11 + .../test_stable_xnnpack_settings.json | 14 + tensorflow/lite/tools/evaluation/BUILD | 9 +- .../lite/tools/evaluation/evaluation_stage.h | 2 +- tensorflow/lite/tools/evaluation/proto/BUILD | 1 + tensorflow/lite/tools/evaluation/stages/BUILD | 17 +- .../stages/inference_profiler_stage.cc | 25 +- .../stages/inference_profiler_stage.h | 5 +- .../stages/inference_profiler_stage_test.cc | 2 +- .../stages/object_detection_stage.cc | 2 +- .../stages/tflite_inference_stage.cc | 2 +- .../stages/tflite_inference_stage.h | 6 +- .../stages/tflite_inference_stage_test.cc | 4 +- .../lite/tools/evaluation/stages/utils/BUILD | 1 + tensorflow/lite/tools/evaluation/tasks/BUILD | 24 + .../tasks/coco_object_detection/BUILD | 3 +- .../tasks/coco_object_detection/run_eval.cc | 2 +- .../tasks/imagenet_image_classification/BUILD | 3 +- .../imagenet_image_classification/run_eval.cc | 2 +- .../evaluation/tasks/inference_diff/BUILD | 3 +- .../evaluation/tasks/inference_diff/README.md | 8 +- .../tasks/inference_diff/run_eval.cc | 2 +- .../tools/evaluation/tasks/ios/BUILD.apple | 37 + .../lite/tools/evaluation/tasks/ios/README.md | 43 + .../project.pbxproj | 413 ++ .../TFLiteEvaluation/AppDelegate.h | 22 + .../TFLiteEvaluation/AppDelegate.m | 27 + .../AccentColor.colorset/Contents.json | 11 + .../AppIcon.appiconset/Contents.json | 98 + .../Assets.xcassets/Contents.json | 6 + .../Base.lproj/Main.storyboard | 57 + .../EvaluationViewController.h | 22 + .../EvaluationViewController.mm | 141 + .../TFLiteEvaluation/Info.plist | 43 + .../evaluation_data/evaluation_params.json | 8 + .../TFLiteEvaluation/TFLiteEvaluation/main.m | 26 + .../tasks/ios/build_evaluation_framework.sh | 62 + .../evaluation/tasks/task_executor_c_api.cc | 99 + .../evaluation/tasks/task_executor_c_api.h | 93 + tensorflow/lite/tools/evaluation/utils.h | 4 +- tensorflow/lite/tools/flatbuffer_utils.py | 75 +- .../lite/tools/flatbuffer_utils_test.py | 14 + tensorflow/lite/tools/gen_op_registration.cc | 2 +- tensorflow/lite/tools/gen_op_registration.h | 2 +- tensorflow/lite/tools/list_flex_ops.h | 2 +- .../mini_benchmark => tools}/model_loader.cc | 54 +- .../mini_benchmark => tools}/model_loader.h | 39 +- .../model_loader_test.cc | 42 +- tensorflow/lite/tools/optimize/BUILD | 19 +- .../lite/tools/optimize/calibration/BUILD | 19 +- .../calibration/builtin_logging_ops/lstm.cc | 2 +- .../calibration/builtin_logging_ops/lstm.h | 2 +- .../optimize/calibration/calibration_logger.h | 2 +- .../optimize/calibration/calibration_reader.h | 2 +- .../tools/optimize/calibration/calibrator.cc | 6 +- .../tools/optimize/calibration/calibrator.h | 2 +- .../optimize/calibration/calibrator_test.cc | 4 +- 
.../calibration/custom_logging_ops/lstm.cc | 2 +- .../calibration/custom_logging_ops/lstm.h | 2 +- .../tools/optimize/calibration/logging_op.h | 2 +- .../tools/optimize/debugging/python/BUILD | 1 + .../optimize/debugging/python/debugger.py | 7 +- tensorflow/lite/tools/optimize/model_utils.cc | 2 +- tensorflow/lite/tools/optimize/model_utils.h | 2 +- .../lite/tools/optimize/model_utils_test.cc | 2 +- .../tools/optimize/modify_model_interface.cc | 4 +- .../tools/optimize/modify_model_interface.h | 2 +- .../optimize/modify_model_interface_test.cc | 2 +- .../lite/tools/optimize/operator_property.cc | 400 +- .../lite/tools/optimize/operator_property.h | 7 +- tensorflow/lite/tools/optimize/python/BUILD | 1 + .../lite/tools/optimize/quantization_utils.cc | 39 +- .../lite/tools/optimize/quantization_utils.h | 3 +- .../tools/optimize/quantization_utils_test.cc | 17 +- .../optimize/quantization_wrapper_utils.cc | 5 +- .../optimize/quantization_wrapper_utils.h | 2 +- .../quantization_wrapper_utils_custom_test.cc | 2 +- .../quantization_wrapper_utils_test.cc | 2 +- .../lite/tools/optimize/quantize_model.cc | 2 +- .../lite/tools/optimize/quantize_model.h | 2 +- .../tools/optimize/quantize_model_test.cc | 239 +- .../lite/tools/optimize/quantize_weights.cc | 2 +- .../lite/tools/optimize/quantize_weights.h | 2 +- .../optimize/quantize_weights_portable.cc | 4 +- .../tools/optimize/quantize_weights_test.cc | 2 +- .../reduced_precision_support_test.cc | 2 +- tensorflow/lite/tools/optimize/sparsity/BUILD | 3 +- .../format_converter_wrapper_pybind11.cc | 2 +- .../lite/tools/pip_package/Dockerfile.py3 | 10 +- tensorflow/lite/tools/pip_package/Makefile | 14 +- .../build_pip_package_with_bazel.sh | 21 +- .../tools/pip_package/setup_with_binary.py | 8 +- tensorflow/lite/tools/randomize_weights.py | 13 + tensorflow/lite/tools/serialization/BUILD | 28 +- .../serialization/option_writer_generator.cc | 2 +- tensorflow/lite/tools/serialization/writer.cc | 4 +- .../lite/tools/serialization/writer_lib.cc | 6 +- .../lite/tools/serialization/writer_lib.h | 2 +- .../tools/serialization/writer_lib_test.cc | 10 +- .../lite/tools/serialization/writer_test.cc | 4 +- tensorflow/lite/tools/signature/BUILD | 15 +- .../tools/signature/signature_def_util.cc | 3 +- .../lite/tools/signature/signature_def_util.h | 2 +- .../signature/signature_def_util_test.cc | 4 +- .../signature_def_util_wrapper_pybind11.cc | 2 +- tensorflow/lite/tools/strip_buffers/BUILD | 10 +- .../reconstitute_buffers_into_fb.cc | 4 +- .../strip_buffers/strip_buffers_from_fb.cc | 4 +- .../lite/tools/strip_buffers/stripping_lib.cc | 4 +- .../lite/tools/strip_buffers/stripping_lib.h | 2 +- tensorflow/lite/tools/test_utils.py | 85 +- tensorflow/lite/tools/utils.h | 2 +- tensorflow/lite/tools/verifier.h | 12 +- tensorflow/lite/tools/verifier_internal.h | 12 +- tensorflow/lite/tools/versioning/BUILD | 15 +- .../tools/versioning/gpu_compatibility.cc | 1 + .../lite/tools/versioning/gpu_compatibility.h | 2 +- .../versioning/gpu_compatibility_test.cc | 2 +- .../lite/tools/versioning/op_signature.cc | 13 +- .../lite/tools/versioning/op_signature.h | 4 +- .../tools/versioning/op_signature_test.cc | 2 +- .../lite/tools/versioning/op_version.cc | 78 +- .../lite/tools/versioning/op_version_test.cc | 51 + .../lite/tools/versioning/runtime_version.cc | 11 +- .../tools/versioning/runtime_version_test.cc | 2 +- tensorflow/lite/tools/visualize_test.py | 2 +- tensorflow/lite/tutorials/BUILD | 1 + tensorflow/lite/type_to_tflitetype.h | 2 +- tensorflow/lite/type_to_tflitetype_test.cc | 
2 +- tensorflow/lite/util.cc | 31 +- tensorflow/lite/util.h | 8 +- tensorflow/lite/util_test.cc | 30 +- tensorflow/opensource_only.files | 49 +- tensorflow/python/BUILD | 199 +- tensorflow/python/autograph/BUILD | 1 + tensorflow/python/autograph/CONTRIBUTING.md | 7 +- tensorflow/python/autograph/converters/BUILD | 1 + .../converters/return_statements_test.py | 2 +- tensorflow/python/autograph/core/BUILD | 1 + .../autograph/g3doc/reference/control_flow.md | 2 +- tensorflow/python/autograph/impl/BUILD | 1 + .../python/autograph/impl/testing/BUILD | 1 + tensorflow/python/autograph/lang/BUILD | 1 + tensorflow/python/autograph/operators/BUILD | 3 +- .../autograph/operators/control_flow.py | 185 +- .../autograph/operators/control_flow_test.py | 4 +- .../python/autograph/operators/py_builtins.py | 260 +- tensorflow/python/autograph/pyct/BUILD | 1 + .../autograph/pyct/common_transformers/BUILD | 1 + .../python/autograph/pyct/inspect_utils.py | 5 +- .../autograph/pyct/static_analysis/BUILD | 1 + .../pyct/static_analysis/liveness.py | 8 - .../python/autograph/pyct/testing/BUILD | 1 + .../python/autograph/pyct/transpiler.py | 8 +- tensorflow/python/autograph/tests/BUILD | 2 + tensorflow/python/autograph/utils/BUILD | 11 +- .../python/autograph/utils/type_registry.py | 62 + tensorflow/python/checkpoint/BUILD | 47 +- tensorflow/python/checkpoint/__init__.py | 3 - .../checkpoint/async_checkpoint_helper.py | 561 ++ .../python/checkpoint/benchmarks_test.py | 21 +- tensorflow/python/checkpoint/checkpoint.py | 209 +- .../checkpoint/checkpoint_management.py | 2 + .../checkpoint/checkpoint_management_test.py | 2 +- .../python/checkpoint/checkpoint_options.py | 28 +- .../python/checkpoint/checkpoint_test.py | 94 +- tensorflow/python/checkpoint/restore.py | 22 +- tensorflow/python/checkpoint/restore_test.py | 11 +- tensorflow/python/checkpoint/save_util.py | 4 +- tensorflow/python/checkpoint/save_util_v1.py | 7 +- .../python/checkpoint/saveable_compat_test.py | 6 +- .../python/checkpoint/tensor_callable.py | 41 + .../python/checkpoint/tensor_callable_test.py | 73 + tensorflow/python/client/BUILD | 16 +- tensorflow/python/client/session_test.py | 60 +- tensorflow/python/client/tf_session_helper.h | 4 +- .../python/client/tf_session_wrapper.cc | 29 +- tensorflow/python/compat/BUILD | 1 + tensorflow/python/compat/compat.py | 4 +- tensorflow/python/compiler/BUILD | 1 + tensorflow/python/compiler/mlir/BUILD | 1 + tensorflow/python/compiler/mlir/mlir_test.py | 3 +- tensorflow/python/compiler/tensorrt/BUILD | 2 + .../compiler/tensorrt/model_tests/BUILD | 1 + .../tensorrt/model_tests/model_handler.py | 15 +- .../tensorrt/model_tests/result_analyzer.py | 87 +- .../tensorrt/model_tests/run_models.py | 67 +- .../python/compiler/tensorrt/test/BUILD | 1 + .../test/tf_trt_integration_test_base.py | 6 +- tensorflow/python/compiler/xla/BUILD | 1 + .../python/compiler/xla/experimental/BUILD | 1 + tensorflow/python/compiler/xla/xla.py | 39 +- tensorflow/python/data/BUILD | 5 +- tensorflow/python/data/benchmarks/BUILD | 2 + .../python/data/benchmarks/map_benchmark.py | 5 +- tensorflow/python/data/experimental/BUILD | 1 + .../python/data/experimental/__init__.py | 4 +- .../python/data/experimental/benchmarks/BUILD | 1 + .../data/experimental/kernel_tests/BUILD | 50 +- .../kernel_tests/assert_cardinality_test.py | 25 +- .../kernel_tests/auto_shard_dataset_test.py | 10 +- .../kernel_tests/distributed_save_test.py | 96 + .../kernel_tests/from_list_test.py | 24 +- .../kernel_tests/index_shuffle_test.py | 230 + 
.../kernel_tests/map_and_batch_test.py | 48 +- .../kernel_tests/model_dataset_test.py | 3 +- .../kernel_tests/optimization/BUILD | 7 +- .../kernel_tests/parallel_interleave_test.py | 368 - .../kernel_tests/replicate_test.py | 3 - .../experimental/kernel_tests/service/BUILD | 20 + .../service/coordinated_read_test.py | 8 +- .../service/cross_trainer_cache_test.py | 4 +- .../service/dynamic_sharding_test.py | 2 +- .../service/fault_tolerance_test.py | 7 +- .../service/local_workers_test.py | 8 +- .../kernel_tests/service/snapshot_ft_test.py | 229 + .../kernel_tests/service/test_base.py | 40 +- .../kernel_tests/service/worker_tags_test.py | 8 +- tensorflow/python/data/experimental/ops/BUILD | 29 + .../python/data/experimental/ops/batching.py | 1 + .../python/data/experimental/ops/counter.py | 3 +- .../data/experimental/ops/data_service_ops.py | 23 +- .../experimental/ops/distributed_save_op.py | 72 + .../python/data/experimental/ops/error_ops.py | 3 +- .../data/experimental/ops/parsing_ops.py | 5 +- .../data/experimental/ops/random_ops.py | 4 +- .../python/data/experimental/ops/readers.py | 5 +- .../data/experimental/ops/shuffle_ops.py | 171 +- .../python/data/experimental/ops/snapshot.py | 7 +- .../python/data/experimental/service/BUILD | 1 + .../data/experimental/service/server_lib.py | 71 +- .../service/server_lib_wrapper.cc | 79 +- tensorflow/python/data/kernel_tests/BUILD | 79 +- .../python/data/kernel_tests/batch_test.py | 31 +- .../data/kernel_tests/checkpoint_test_base.py | 4 +- .../kernel_tests/choose_from_datasets_test.py | 172 + .../data/kernel_tests/concatenate_test.py | 24 +- .../python/data/kernel_tests/counter_test.py | 28 + .../python/data/kernel_tests/dataset_test.py | 5 +- .../data/kernel_tests/enumerate_test.py | 26 + .../python/data/kernel_tests/filter_test.py | 29 +- .../python/data/kernel_tests/flat_map_test.py | 15 +- .../data/kernel_tests/from_generator_test.py | 3 +- .../kernel_tests/from_tensor_slices_test.py | 21 +- .../data/kernel_tests/from_tensors_test.py | 21 +- .../data/kernel_tests/interleave_test.py | 18 +- .../python/data/kernel_tests/iterator_test.py | 6 +- .../python/data/kernel_tests/map_test.py | 55 +- .../multi_device_iterator_test.py | 3 +- .../python/data/kernel_tests/options_test.py | 6 - .../data/kernel_tests/padded_batch_test.py | 37 +- .../python/data/kernel_tests/prefetch_test.py | 24 +- .../python/data/kernel_tests/random_test.py | 128 +- .../python/data/kernel_tests/range_test.py | 20 +- .../kernel_tests/rejection_resample_test.py | 6 +- .../python/data/kernel_tests/repeat_test.py | 95 +- .../sample_from_datasets_test.py} | 238 +- .../python/data/kernel_tests/scan_test.py | 25 +- .../python/data/kernel_tests/shard_test.py | 15 +- .../python/data/kernel_tests/skip_test.py | 23 +- .../python/data/kernel_tests/take_test.py | 21 +- .../data/kernel_tests/take_while_test.py | 18 +- .../python/data/kernel_tests/unbatch_test.py | 29 +- .../python/data/kernel_tests/window_test.py | 6 +- .../python/data/kernel_tests/zip_test.py | 19 +- tensorflow/python/data/ops/BUILD | 192 +- tensorflow/python/data/ops/batch_op.py | 31 +- tensorflow/python/data/ops/cache_op.py | 49 + .../data/ops/choose_from_datasets_op.py | 53 + tensorflow/python/data/ops/concatenate_op.py | 63 + tensorflow/python/data/ops/counter_op.py | 2 +- .../python/data/ops/dataset_autograph.py | 221 + tensorflow/python/data/ops/dataset_ops.py | 1928 +----- tensorflow/python/data/ops/debug_mode.py | 77 + .../python/data/ops/directed_interleave_op.py | 72 + tensorflow/python/data/ops/filter_op.py 
| 8 +- tensorflow/python/data/ops/flat_map_op.py | 57 + .../python/data/ops/from_generator_op.py | 400 ++ .../data/ops/from_sparse_tensor_slices_op.py | 53 + .../python/data/ops/from_tensor_slices_op.py | 8 +- tensorflow/python/data/ops/from_tensors_op.py | 43 + .../python/data/ops/group_by_window_op.py | 132 + .../python/data/ops/ignore_errors_op.py | 6 +- tensorflow/python/data/ops/interleave_op.py | 170 + .../python/data/ops/iterator_autograph.py | 119 + tensorflow/python/data/ops/iterator_ops.py | 67 +- tensorflow/python/data/ops/load_op.py | 13 +- tensorflow/python/data/ops/map_op.py | 182 + .../data/ops/multi_device_iterator_ops.py | 4 +- tensorflow/python/data/ops/optional_ops.py | 27 +- tensorflow/python/data/ops/options.py | 15 + tensorflow/python/data/ops/padded_batch_op.py | 26 +- tensorflow/python/data/ops/prefetch_op.py | 51 + tensorflow/python/data/ops/ragged_batch_op.py | 18 +- tensorflow/python/data/ops/random_op.py | 64 + tensorflow/python/data/ops/range_op.py | 70 + tensorflow/python/data/ops/readers.py | 6 +- tensorflow/python/data/ops/rebatch_op.py | 10 +- tensorflow/python/data/ops/repeat_op.py | 44 + .../data/ops/sample_from_datasets_op.py | 121 + tensorflow/python/data/ops/save_op.py | 22 +- tensorflow/python/data/ops/scan_op.py | 160 + tensorflow/python/data/ops/shard_op.py | 43 + tensorflow/python/data/ops/shuffle_op.py | 72 + tensorflow/python/data/ops/skip_op.py | 38 + tensorflow/python/data/ops/snapshot_op.py | 119 + tensorflow/python/data/ops/sparse_batch_op.py | 4 +- .../python/data/ops/structured_function.py | 6 +- tensorflow/python/data/ops/take_op.py | 39 + tensorflow/python/data/ops/take_while_op.py | 58 + tensorflow/python/data/ops/unbatch_op.py | 58 + tensorflow/python/data/ops/unique_op.py | 42 + tensorflow/python/data/ops/window_op.py | 76 + tensorflow/python/data/ops/zip_op.py | 2 +- tensorflow/python/data/util/BUILD | 2 + tensorflow/python/data/util/structure.py | 28 +- tensorflow/python/debug/BUILD | 1 + tensorflow/python/debug/cli/BUILD | 1 + tensorflow/python/debug/cli/curses_widgets.py | 3 +- tensorflow/python/debug/examples/v1/BUILD | 6 +- tensorflow/python/debug/examples/v2/BUILD | 5 +- tensorflow/python/debug/lib/BUILD | 1 + .../debug/lib/debug_events_monitors_test.py | 51 +- .../python/debug/lib/dumping_callback_test.py | 159 +- tensorflow/python/debug/wrappers/BUILD | 1 + tensorflow/python/distribute/BUILD | 15 +- .../python/distribute/cluster_resolver/BUILD | 1 + .../slurm_cluster_resolver.py | 6 +- .../slurm_cluster_resolver_test.py | 3 + .../distribute/cluster_resolver/tpu/BUILD | 1 + .../tpu/tpu_cluster_resolver.py | 10 + .../python/distribute/coordinator/BUILD | 22 + .../coordinator/cluster_coordinator.py | 3 +- .../coordinator/cluster_coordinator_test.py | 20 + .../coordinator/get_task_states_test.py | 164 + .../distribute/coordinator/metric_utils.py | 17 +- .../python/distribute/coordinator/values.py | 21 + .../distribute/cross_device_ops_test.py | 17 + .../python/distribute/cross_device_utils.py | 6 +- .../python/distribute/experimental/BUILD | 80 + .../distribute/experimental/__init__.py | 1 + .../distribute/experimental/dtensor_util.py | 99 + .../experimental/dtensor_util_test.py | 109 + .../experimental/mirrored_strategy.py | 381 ++ .../experimental/mirrored_strategy_test.py | 476 ++ .../python/distribute/experimental/rpc/BUILD | 1 + .../python/distribute/failure_handling/BUILD | 18 +- .../distribute/failure_handling/__init__.py | 1 + .../failure_handling/failure_handler_test.py | 85 +- .../failure_handling/failure_handling.py | 
225 +- .../failure_handling/failure_handling_util.py | 28 +- .../gce_failure_handler_test.py | 74 +- .../failure_handling/preemption_watcher.py | 107 + tensorflow/python/distribute/input_lib.py | 18 + .../python/distribute/input_lib_test.py | 26 + .../python/distribute/integration_test/BUILD | 5 +- tensorflow/python/distribute/mirrored_run.py | 9 +- .../python/distribute/mirrored_strategy.py | 2 +- .../distribute/multi_process_runner_test.py | 35 + .../multi_worker_continuous_run_test.py | 25 +- .../distribute/packed_distributed_variable.py | 3 +- .../python/distribute/parallel_device/BUILD | 5 +- tensorflow/python/distribute/ps_values.py | 29 +- .../python/distribute/sharded_variable.py | 20 +- .../distribute/tpu_replicated_variable.py | 42 +- .../tpu_strategy_model_parallelism_test.py | 23 +- tensorflow/python/distribute/v1/BUILD | 1 + tensorflow/python/distribute/values.py | 83 +- tensorflow/python/distribute/values_v2.py | 4 +- tensorflow/python/dlpack/BUILD | 1 + tensorflow/python/eager/BUILD | 24 +- tensorflow/python/eager/backprop.py | 53 +- tensorflow/python/eager/backprop_test.py | 27 +- tensorflow/python/eager/backprop_util.py | 42 + tensorflow/python/eager/benchmarks/BUILD | 1 + .../python/eager/benchmarks/resnet50/BUILD | 4 + tensorflow/python/eager/context.py | 65 +- tensorflow/python/eager/executor.py | 7 +- tensorflow/python/eager/forwardprop_test.py | 8 + tensorflow/python/eager/function.py | 4 +- tensorflow/python/eager/memory_tests/BUILD | 1 + tensorflow/python/eager/ops_test.py | 4 + .../python/eager/polymorphic_function/BUILD | 78 +- .../eager/polymorphic_function/attributes.py | 109 + .../eager/polymorphic_function/compiler_ir.py | 110 + .../polymorphic_function/compiler_ir_test.py | 240 + .../composite_tensor_utils.py | 28 - .../polymorphic_function/function_context.py | 28 +- .../polymorphic_function/function_spec.py | 669 +- .../function_spec_test.py | 164 +- .../monomorphic_function.py | 104 +- .../polymorphic_function.py | 156 +- .../polymorphic_function_test.py | 320 +- .../polymorphic_function_xla_jit_test.py | 59 +- .../eager/polymorphic_function/quarantine.py | 340 +- .../polymorphic_function/quarantine_test.py | 323 +- .../saved_model_exported_concrete.py | 116 + .../polymorphic_function/saved_model_utils.py | 195 - .../polymorphic_function/tracing_compiler.py | 139 +- .../eager/pywrap_gradient_exclusions.cc | 3 +- tensorflow/python/eager/pywrap_tensor.cc | 11 +- tensorflow/python/eager/pywrap_tensor.h | 6 +- tensorflow/python/eager/remote.py | 18 +- tensorflow/python/estimator/BUILD | 1 + tensorflow/python/feature_column/BUILD | 1 + .../python/feature_column/feature_column.py | 30 +- .../feature_column/feature_column_v2.py | 85 +- .../feature_column/sequence_feature_column.py | 39 +- .../python/feature_column/serialization.py | 18 +- tensorflow/python/framework/BUILD | 178 +- .../python/framework/auto_control_deps.py | 11 +- tensorflow/python/framework/config.py | 54 +- tensorflow/python/framework/config_test.py | 16 - tensorflow/python/framework/constant_op.py | 9 +- .../python/framework/convert_to_constants.py | 81 +- tensorflow/python/framework/dtypes.py | 87 +- tensorflow/python/framework/dtypes_test.py | 25 +- tensorflow/python/framework/errors_impl.py | 120 +- .../python/framework/errors_test_helper.cc | 4 +- .../python/framework/experimental/BUILD | 1 + tensorflow/python/framework/extension_type.py | 92 +- .../python/framework/extension_type_test.py | 16 +- .../python/framework/fast_tensor_util.pyx | 5 + tensorflow/python/framework/framework_lib.py | 
6 +- tensorflow/python/framework/func_graph.py | 408 +- tensorflow/python/framework/function.py | 4 +- .../python/framework/function_def_to_graph.py | 8 +- .../python/framework/graph_util_impl.py | 28 +- .../python/framework/graph_util_test.py | 39 + tensorflow/python/framework/importer.py | 20 +- tensorflow/python/framework/indexed_slices.py | 9 + .../python/framework/kythe_metadata.proto | 142 + tensorflow/python/framework/offset_counter.cc | 70 + .../python/framework/offset_counter_helper.cc | 59 + .../framework/offset_counter_helper.h} | 22 +- .../framework/offset_counter_helper_test.cc | 57 + .../python/framework/op_def_library_test.py | 4 +- .../python/framework/op_reg_offset.proto | 21 + tensorflow/python/framework/ops.py | 139 +- tensorflow/python/framework/ops_test.py | 3 +- tensorflow/python/framework/python_op_gen.cc | 92 +- tensorflow/python/framework/python_op_gen.h | 22 +- .../framework/python_op_gen_annotator.cc | 83 + .../framework/python_op_gen_annotator.h | 66 + .../framework/python_op_gen_annotator_test.cc | 152 + .../framework/python_op_gen_internal.cc | 14 +- .../python/framework/python_op_gen_internal.h | 6 +- .../python/framework/python_op_gen_main.cc | 238 +- .../python/framework/python_op_gen_test.cc | 86 +- .../python/framework/smart_cond_test.py | 8 +- tensorflow/python/framework/sparse_tensor.py | 12 +- .../python/framework/summary_test_util.py | 58 + tensorflow/python/framework/tensor_shape.py | 88 +- tensorflow/python/framework/tensor_spec.py | 212 +- .../python/framework/tensor_spec_test.py | 89 + tensorflow/python/framework/tensor_util.py | 117 +- .../python/framework/tensor_util_test.py | 36 + tensorflow/python/framework/test_ops.cc | 21 + tensorflow/python/framework/test_ops.cu.cc | 47 + .../core/numpy.cc => framework/test_ops.h} | 18 +- tensorflow/python/framework/test_util.py | 110 +- tensorflow/python/framework/test_util_test.py | 54 + tensorflow/python/framework/type_spec.py | 139 +- .../python/framework/type_spec_registry.py | 87 + tensorflow/python/framework/type_spec_test.py | 62 +- tensorflow/python/grappler/BUILD | 10 +- .../grappler/auto_mixed_precision_test.py | 5 +- tensorflow/python/grappler/remapper_test.py | 10 +- tensorflow/python/integration_testing/BUILD | 5 +- tensorflow/python/keras/BUILD | 1 + tensorflow/python/keras/distribute/BUILD | 1 + tensorflow/python/keras/engine/BUILD | 1 + tensorflow/python/keras/initializers/BUILD | 1 + tensorflow/python/keras/layers/BUILD | 1 + .../python/keras/layers/legacy_rnn/BUILD | 1 + .../python/keras/legacy_tf_layers/BUILD | 1 + tensorflow/python/keras/mixed_precision/BUILD | 2 + .../mixed_precision/autocast_variable.py | 15 +- tensorflow/python/keras/optimizer_v2/BUILD | 1 + tensorflow/python/keras/protobuf/BUILD | 1 + tensorflow/python/keras/saving/BUILD | 1 + .../python/keras/saving/saved_model/BUILD | 1 + tensorflow/python/keras/saving/utils_v1/BUILD | 1 + tensorflow/python/keras/utils/BUILD | 1 + tensorflow/python/kernel_tests/BUILD | 5 +- .../python/kernel_tests/array_ops/BUILD | 8 +- .../kernel_tests/array_ops/array_ops_test.py | 77 +- .../array_ops/batchtospace_op_test.py | 9 +- .../array_ops/broadcast_to_ops_test.py | 4 +- .../kernel_tests/array_ops/cast_op_test.py | 62 + .../kernel_tests/array_ops/concat_op_test.py | 11 + .../array_ops/constant_op_test.py | 6 +- .../array_ops/depthtospace_op_test.py | 15 +- .../kernel_tests/array_ops/diag_op_test.py | 26 +- .../array_ops/edit_distance_op_test.py | 63 +- .../array_ops/gather_nd_op_test.py | 2 + .../kernel_tests/array_ops/gather_op_test.py | 
14 +- .../array_ops/inplace_ops_test.py | 8 +- .../kernel_tests/array_ops/manip_ops_test.py | 2 +- .../kernel_tests/array_ops/one_hot_op_test.py | 101 +- .../kernel_tests/array_ops/pad_op_test.py | 34 +- .../kernel_tests/array_ops/reshape_op_test.py | 8 + .../array_ops/reverse_sequence_op_test.py | 6 + .../array_ops/scatter_nd_ops_test.py | 32 +- .../array_ops/scatter_ops_test.py | 25 +- .../kernel_tests/array_ops/shape_ops_test.py | 2 + .../kernel_tests/array_ops/slice_op_test.py | 2 + .../array_ops/spacetobatch_op_test.py | 28 +- .../array_ops/spacetodepth_op_test.py | 4 +- .../kernel_tests/array_ops/split_op_test.py | 2 +- .../kernel_tests/array_ops/stack_op_test.py | 5 +- .../python/kernel_tests/control_flow/BUILD | 8 +- .../kernel_tests/control_flow/cond_v2_test.py | 97 +- .../control_flow/scan_ops_test.py | 27 +- .../python/kernel_tests/custom_ops/BUILD | 5 +- .../python/kernel_tests/data_structures/BUILD | 2 +- .../dynamic_partition_op_test.py | 9 +- .../data_structures/dynamic_stitch_op_test.py | 20 +- .../data_structures/lookup_ops_test.py | 15 + .../data_structures/tensor_array_ops_test.py | 29 + .../python/kernel_tests/distributions/BUILD | 5 +- .../python/kernel_tests/histogram_ops_test.py | 38 +- .../python/kernel_tests/image_ops/BUILD | 5 +- .../extract_image_patches_op_test.py | 25 +- .../extract_volume_patches_op_test.py | 8 +- tensorflow/python/kernel_tests/io_ops/BUILD | 5 +- .../kernel_tests/io_ops/parsing_ops_test.py | 43 +- tensorflow/python/kernel_tests/linalg/BUILD | 6 +- .../linear_operator_composition_test.py | 26 + .../linear_operator_lower_triangular_test.py | 35 + .../linalg/linear_operator_test.py | 15 +- .../linalg/linear_operator_util_test.py | 59 + .../linalg/matrix_exponential_op_test.py | 1 + .../linalg/matrix_inverse_op_test.py | 22 +- .../python/kernel_tests/linalg/sparse/BUILD | 1 + .../python/kernel_tests/linalg/testdata/BUILD | 1 + tensorflow/python/kernel_tests/math_ops/BUILD | 9 +- .../math_ops/aggregate_ops_test.py | 16 +- .../kernel_tests/math_ops/approx_topk_test.py | 33 +- .../kernel_tests/math_ops/argmax_op_test.py | 39 +- .../math_ops/batch_matmul_op_test.py | 29 +- .../kernel_tests/math_ops/bincount_op_test.py | 2 +- .../kernel_tests/math_ops/clip_ops_test.py | 30 +- .../math_ops/cumulative_logsumexp_test.py | 10 +- .../math_ops/cwise_ops_binary_test.py | 50 + .../math_ops/cwise_ops_unary_test.py | 9 +- .../kernel_tests/math_ops/matmul_op_test.py | 46 +- .../math_ops/reduction_ops_test.py | 133 +- .../math_ops/segment_reduction_ops_test.py | 48 +- .../python/kernel_tests/math_ops/sets_test.py | 10 + .../kernel_tests/math_ops/topk_op_test.py | 6 +- .../math_ops/transpose_op_test.py | 15 +- tensorflow/python/kernel_tests/nn_ops/BUILD | 11 +- .../kernel_tests/nn_ops/bias_op_base.py | 34 +- .../nn_ops/conv1d_transpose_test.py | 1 + .../nn_ops/conv2d_transpose_test.py | 12 + .../nn_ops/conv3d_transpose_test.py | 17 +- .../kernel_tests/nn_ops/conv_ops_test.py | 2 +- .../nn_ops/cudnn_deterministic_ops_test.py | 3 + .../nn_ops/depthwise_conv_op_base.py | 248 +- .../nn_ops/fractional_avg_pool_op_test.py | 12 +- .../nn_ops/fractional_max_pool_op_test.py | 12 +- .../nn_ops/morphological_ops_test.py | 149 +- .../nn_ops/pooling_ops_3d_test.py | 214 +- .../kernel_tests/nn_ops/pooling_ops_test.py | 104 +- .../kernel_tests/nn_ops/relu_op_test.py | 35 +- .../kernel_tests/nn_ops/rnn_cell_test.py | 38 + .../kernel_tests/nn_ops/softmax_op_test.py | 12 + .../kernel_tests/nn_ops/softplus_op_test.py | 19 +- .../kernel_tests/nn_ops/softsign_op_test.py | 8 +- 
.../kernel_tests/nn_ops/xent_op_test_base.py | 9 +- tensorflow/python/kernel_tests/proto/BUILD | 1 + .../kernel_tests/quantization_ops/BUILD | 1 + tensorflow/python/kernel_tests/random/BUILD | 1 + .../kernel_tests/random/random_grad_test.py | 25 + .../random/random_index_shuffle_test.py | 14 + .../random/stateless_random_ops_test.py | 3 +- tensorflow/python/kernel_tests/signal/BUILD | 1 + .../python/kernel_tests/sparse_ops/BUILD | 1 + .../sparse_ops/sparse_ops_test.py | 13 + .../sparse_ops/sparse_reorder_op_test.py | 11 +- .../sparse_ops/sparse_to_dense_op_py_test.py | 14 +- .../sparse_ops/sparse_xent_op_test_base.py | 8 + .../python/kernel_tests/strings_ops/BUILD | 5 +- .../python/kernel_tests/summary_ops/BUILD | 5 +- .../kernel_tests/tensor_priority_test.py | 3 +- .../python/kernel_tests/v1_compat_tests/BUILD | 2 +- .../python/kernel_tests/variables/BUILD | 6 +- .../variables/dense_update_ops_test.py | 6 +- .../variables/resource_variable_ops_test.py | 24 +- .../variables/variable_ops_test.py | 8 +- tensorflow/python/layers/BUILD | 1 + tensorflow/python/lib/BUILD | 1 + tensorflow/python/lib/core/BUILD | 75 +- tensorflow/python/lib/core/bfloat16.cc | 1859 ----- tensorflow/python/lib/core/bfloat16.h | 21 +- .../python/lib/core/bfloat16_wrapper.cc | 8 +- .../python/lib/core/custom_casts_wrapper.cc | 21 + ...{bfloat16_test.py => custom_float_test.py} | 297 +- tensorflow/python/lib/core/float8_e4m3b11.cc | 87 - tensorflow/python/lib/core/float8_e4m3b11.h | 64 - .../lib/core/float8_wrapper.cc} | 15 +- tensorflow/python/lib/core/ndarray_tensor.cc | 38 +- .../python/lib/core/ndarray_tensor_bridge.cc | 13 +- .../python/lib/core/ndarray_tensor_bridge.h | 6 +- tensorflow/python/lib/core/numpy.h | 50 - tensorflow/python/lib/core/py_func.cc | 10 +- tensorflow/python/lib/core/py_seq_tensor.cc | 6 +- tensorflow/python/lib/io/BUILD | 25 +- .../python/lite/toco_python_api_wrapper.cc | 6 +- tensorflow/python/mlir_wrapper.cc | 8 + tensorflow/python/module/BUILD | 1 + tensorflow/python/modules_with_exports.py | 1 + tensorflow/python/ops/array_ops.py | 50 +- tensorflow/python/ops/array_ops_test.py | 15 + tensorflow/python/ops/boosted_trees_ops.py | 64 +- .../python/ops/candidate_sampling_ops.py | 157 +- tensorflow/python/ops/check_ops.py | 266 +- tensorflow/python/ops/cond_v2.py | 55 +- tensorflow/python/ops/control_flow_ops.py | 226 +- tensorflow/python/ops/control_flow_util_v2.py | 19 +- tensorflow/python/ops/ctc_ops.py | 109 +- tensorflow/python/ops/custom_gradient.py | 7 +- tensorflow/python/ops/distributions/BUILD | 1 + tensorflow/python/ops/embedding_ops.py | 4 +- tensorflow/python/ops/gradients_util.py | 24 +- tensorflow/python/ops/handle_data_util.py | 5 +- tensorflow/python/ops/image_ops_impl.py | 80 +- tensorflow/python/ops/image_ops_test.py | 54 +- tensorflow/python/ops/linalg/BUILD | 2 + .../ops/linalg/cholesky_registrations.py | 77 +- .../python/ops/linalg/linear_operator.py | 3 +- .../ops/linalg/linear_operator_composition.py | 46 + .../ops/linalg/linear_operator_test_util.py | 5 + .../python/ops/linalg/linear_operator_util.py | 26 + tensorflow/python/ops/linalg/sparse/BUILD | 8 + tensorflow/python/ops/losses/BUILD | 1 + tensorflow/python/ops/math_ops.py | 177 +- tensorflow/python/ops/math_ops_test.py | 6 + tensorflow/python/ops/memory_tests/BUILD | 5 +- .../python/ops/nn_fused_batchnorm_test.py | 49 +- tensorflow/python/ops/nn_grad.py | 5 +- tensorflow/python/ops/nn_impl.py | 11 +- tensorflow/python/ops/nn_ops.py | 136 +- tensorflow/python/ops/nn_test.py | 56 +- 
tensorflow/python/ops/numerics.py | 10 + tensorflow/python/ops/numpy_ops/BUILD | 5 +- .../ops/numpy_ops/integration_test/BUILD | 2 + .../integration_test/benchmarks/BUILD | 5 +- .../python/ops/numpy_ops/np_math_ops.py | 12 +- .../python/ops/numpy_ops/np_math_ops_test.py | 8 + tensorflow/python/ops/optional_grad.py | 9 +- tensorflow/python/ops/parallel_for/BUILD | 1 + .../ops/parallel_for/control_flow_ops_test.py | 17 +- tensorflow/python/ops/parallel_for/pfor.py | 16 +- tensorflow/python/ops/ragged/BUILD | 8 + ...vert_to_tensor_or_ragged_tensor_op_test.py | 6 +- .../python/ops/ragged/ragged_const_op_test.py | 7 +- .../ragged/ragged_constant_value_op_test.py | 7 +- .../ops/ragged/ragged_one_hot_op_test.py | 4 +- .../python/ops/ragged/ragged_operators.py | 267 +- .../python/ops/ragged/ragged_range_op_test.py | 10 + tensorflow/python/ops/ragged/ragged_tensor.py | 12 +- .../python/ops/ragged/ragged_tensor_test.py | 7 +- tensorflow/python/ops/ragged/row_partition.py | 14 +- tensorflow/python/ops/random_grad.py | 61 +- .../python/ops/resource_variable_ops.py | 106 +- tensorflow/python/ops/risc/BUILD | 1 + tensorflow/python/ops/script_ops.py | 6 +- tensorflow/python/ops/signal/BUILD | 1 + tensorflow/python/ops/sobol_ops_test.py | 10 + tensorflow/python/ops/sparse_ops.py | 128 +- tensorflow/python/ops/stateless_random_ops.py | 111 +- tensorflow/python/ops/structured/BUILD | 8 +- tensorflow/python/ops/template.py | 5 +- tensorflow/python/ops/tensor_array_ops.py | 18 +- tensorflow/python/ops/v1_compat_tests/BUILD | 5 +- tensorflow/python/ops/variable_scope.py | 3 +- tensorflow/python/ops/variables.py | 42 +- tensorflow/python/ops/while_v2.py | 6 +- tensorflow/python/platform/BUILD | 24 +- tensorflow/python/platform/benchmark.py | 18 +- tensorflow/python/platform/sysconfig.py | 26 +- tensorflow/python/platform/sysconfig_test.py | 8 +- tensorflow/python/platform/tf_logging.py | 37 +- tensorflow/python/profiler/BUILD | 1 + .../python/profiler/integration_test/BUILD | 1 + tensorflow/python/profiler/internal/BUILD | 37 +- .../profiler/internal/flops_registry.py | 40 +- .../profiler/internal/flops_registry_test.py | 51 + .../profiler/internal/profiler_pywrap_impl.cc | 8 +- .../profiler/internal/traceme_wrapper.cc | 4 +- .../python/profiler/profiler_v2_test.py | 8 +- tensorflow/python/pywrap_dtensor_device.cc | 37 +- tensorflow/python/pywrap_mlir.py | 4 + tensorflow/python/saved_model/BUILD | 73 +- tensorflow/python/saved_model/builder_impl.py | 60 +- .../python/saved_model/fingerprinting.md | 26 + .../python/saved_model/fingerprinting.py | 97 + .../python/saved_model/fingerprinting_test.py | 39 +- .../saved_model/function_deserialization.py | 4 +- tensorflow/python/saved_model/load.py | 30 +- tensorflow/python/saved_model/load_test.py | 1845 +++-- .../python/saved_model/load_v1_in_v2.py | 7 +- tensorflow/python/saved_model/loader_impl.py | 5 +- tensorflow/python/saved_model/metrics_test.py | 19 +- .../python/saved_model/model_utils/BUILD | 1 + .../saved_model/nested_structure_coder.py | 219 +- .../nested_structure_coder_test.py | 35 +- tensorflow/python/saved_model/path_helpers.py | 82 + .../pywrap_saved_model_fingerprinting.cc | 36 +- .../pywrap_saved_model_fingerprinting_test.py | 63 +- .../saved_model/pywrap_saved_model_metrics.cc | 30 + .../pywrap_saved_model_metrics_test.py | 9 + .../python/saved_model/registration/BUILD | 1 + tensorflow/python/saved_model/save.py | 29 +- tensorflow/python/saved_model/save_test.py | 21 + tensorflow/python/saved_model/saved_model.py | 2 + 
.../python/saved_model/tracing_utils.py | 8 + .../python/saved_model/tracing_utils_test.py | 9 +- tensorflow/python/saved_model/utils_impl.py | 67 - tensorflow/python/summary/BUILD | 5 +- tensorflow/python/summary/writer/BUILD | 5 +- tensorflow/python/tfe_wrapper.cc | 105 +- tensorflow/python/tools/BUILD | 2 + tensorflow/python/tools/api/generator/BUILD | 1 + .../tools/api/generator/api_init_files.bzl | 5 + .../tools/api/generator/api_init_files_v1.bzl | 7 + tensorflow/python/tools/freeze_graph.py | 6 +- tensorflow/python/tools/saved_model_cli.py | 635 +- .../python/tools/saved_model_cli_test.py | 603 +- .../selective_registration_header_lib.py | 50 +- tensorflow/python/tpu/BUILD | 37 + tensorflow/python/tpu/client/BUILD | 1 + .../python/tpu/client/pip_package/BUILD | 5 +- tensorflow/python/tpu/experimental/BUILD | 5 +- tensorflow/python/tpu/feature_column_v2.py | 4 +- tensorflow/python/tpu/ops/BUILD | 1 + tensorflow/python/tpu/profiler/BUILD | 1 + tensorflow/python/tpu/tests/BUILD | 1 + tensorflow/python/tpu/tpu.py | 2 +- .../python/tpu/tpu_embedding_for_serving.py | 23 +- tensorflow/python/tpu/tpu_embedding_v2.py | 27 +- .../python/tpu/tpu_embedding_v2_utils.py | 188 +- .../python/tpu/tpu_embedding_v2_utils_test.py | 5 +- .../tpu/tpu_outside_compilation_test.py | 21 + tensorflow/python/tpu/tpu_replication.py | 692 ++ tensorflow/python/tpu/tpu_strategy_util.py | 50 +- tensorflow/python/trackable/BUILD | 2 + tensorflow/python/trackable/asset.py | 14 +- tensorflow/python/trackable/base.py | 50 +- .../python/trackable/data_structures.py | 14 +- tensorflow/python/trackable/resource.py | 9 +- tensorflow/python/training/BUILD | 6 +- .../python/training/checkpoint_utils.py | 30 +- tensorflow/python/training/experimental/BUILD | 1 + tensorflow/python/training/input.py | 13 +- tensorflow/python/training/optimizer.py | 4 +- tensorflow/python/training/saver.py | 5 + tensorflow/python/training/saving/BUILD | 2 + .../training/saving/saveable_object_util.py | 123 +- .../saving/saveable_object_util_test.py | 27 +- tensorflow/python/training/tracking/BUILD | 3 +- tensorflow/python/training/training.py | 1 + tensorflow/python/types/BUILD | 4 +- tensorflow/python/types/core.py | 14 +- .../types/data.py} | 26 +- tensorflow/python/types/internal.py | 14 + tensorflow/python/types/trace.py | 67 +- tensorflow/python/user_ops/BUILD | 1 + tensorflow/python/util/BUILD | 1 + tensorflow/python/util/dispatch.py | 23 +- tensorflow/python/util/memory.py | 41 - tensorflow/python/util/protobuf/BUILD | 1 + tensorflow/python/util/stack_trace.h | 10 +- tensorflow/python/util/tf_decorator.py | 96 +- tensorflow/python/util/tf_decorator_test.py | 66 +- tensorflow/python/util/tf_export.py | 19 +- tensorflow/python/util/tf_export_test.py | 42 - tensorflow/python/util/tf_inspect.py | 35 +- tensorflow/python/util/tf_inspect_test.py | 91 +- tensorflow/security/README.md | 27 + tensorflow/security/advisory/tfsa-2022-144.md | 30 + tensorflow/security/advisory/tfsa-2022-145.md | 44 + tensorflow/security/advisory/tfsa-2022-146.md | 34 + tensorflow/security/advisory/tfsa-2022-147.md | 33 + tensorflow/security/advisory/tfsa-2022-148.md | 30 + tensorflow/security/advisory/tfsa-2022-149.md | 29 + tensorflow/security/advisory/tfsa-2022-150.md | 29 + tensorflow/security/advisory/tfsa-2022-151.md | 28 + tensorflow/security/advisory/tfsa-2022-152.md | 30 + tensorflow/security/advisory/tfsa-2022-153.md | 27 + tensorflow/security/advisory/tfsa-2022-154.md | 28 + tensorflow/security/advisory/tfsa-2022-155.md | 29 + 
tensorflow/security/advisory/tfsa-2022-156.md | 52 + tensorflow/security/advisory/tfsa-2022-157.md | 27 + tensorflow/security/advisory/tfsa-2022-158.md | 32 + tensorflow/security/advisory/tfsa-2022-159.md | 32 + tensorflow/security/advisory/tfsa-2022-160.md | 27 + tensorflow/security/advisory/tfsa-2022-161.md | 42 + tensorflow/security/advisory/tfsa-2022-162.md | 30 + tensorflow/security/advisory/tfsa-2022-163.md | 30 + tensorflow/security/advisory/tfsa-2022-164.md | 25 + tensorflow/security/advisory/tfsa-2022-165.md | 20 + tensorflow/security/advisory/tfsa-2022-166.md | 18 + tensorflow/security/advisory/tfsa-2022-167.md | 16 + tensorflow/security/advisory/tfsa-2022-168.md | 35 + tensorflow/security/advisory/tfsa-2022-169.md | 29 + tensorflow/security/advisory/tfsa-2022-170.md | 30 + tensorflow/security/fuzzing/BUILD | 1 + .../fuzzing/cc/AreAttrValuesEqual_fuzz.cc | 2 +- tensorflow/security/fuzzing/cc/BUILD | 75 +- .../fuzzing/cc/ParseAttrValue_fuzz.cc | 2 +- .../security/fuzzing/cc/arg_def_case_fuzz.cc | 2 +- tensorflow/security/fuzzing/cc/base64_fuzz.cc | 2 +- .../security/fuzzing/cc/bfloat16_fuzz.cc | 2 +- .../fuzzing/cc/checkpoint_reader_fuzz.cc | 2 +- .../security/fuzzing/cc/cleanpath_fuzz.cc | 2 +- .../fuzzing/cc/consume_leading_digits_fuzz.cc | 2 +- .../security/fuzzing/cc/core/framework/BUILD | 14 + .../cc/core/framework/tensor_shape_domains.cc | 54 + .../cc/core/framework/tensor_shape_domains.h} | 21 +- .../security/fuzzing/cc/core/function/BUILD | 14 + .../cc/core/function/runtime_client_fuzz.cc | 97 + tensorflow/security/fuzzing/cc/fuzz_domains.h | 39 + tensorflow/security/fuzzing/cc/fuzz_session.h | 2 +- .../security/fuzzing/cc/joinpath_fuzz.cc | 2 +- tensorflow/security/fuzzing/cc/ops/BUILD | 46 +- .../security/fuzzing/cc/ops/add_fuzz.cc | 2 +- .../security/fuzzing/cc/ops/bincount_fuzz.cc | 58 + .../security/fuzzing/cc/ops/concat_fuzz.cc | 2 +- .../security/fuzzing/cc/ops/identity_fuzz.cc | 2 +- .../security/fuzzing/cc/ops/matmul_fuzz.cc | 2 +- .../fuzzing/cc/ops/string_to_number_fuzz.cc | 47 + .../security/fuzzing/cc/parseURI_fuzz.cc | 2 +- tensorflow/security/fuzzing/cc/status_fuzz.cc | 12 +- .../security/fuzzing/cc/status_group_fuzz.cc | 10 +- .../fuzzing/cc/string_replace_fuzz.cc | 2 +- .../security/fuzzing/cc/stringprintf_fuzz.cc | 2 +- .../security/fuzzing/cc/tstring_fuzz.cc | 2 +- tensorflow/security/fuzzing/tf_fuzzing.bzl | 28 + tensorflow/stream_executor/BUILD | 418 -- tensorflow/stream_executor/allocator_stats.h | 21 - tensorflow/stream_executor/blas.h | 21 - tensorflow/stream_executor/cuda/BUILD | 381 -- .../stream_executor/cuda/cuda_activation.h | 21 - tensorflow/stream_executor/cuda/cuda_blas.h | 21 - .../stream_executor/cuda/cuda_diagnostics.h | 21 - tensorflow/stream_executor/cuda/cuda_dnn.h | 21 - tensorflow/stream_executor/cuda/cuda_driver.h | 21 - tensorflow/stream_executor/cuda/cuda_event.h | 21 - tensorflow/stream_executor/cuda/cuda_fft.h | 21 - .../stream_executor/cuda/cuda_gpu_executor.h | 21 - .../stream_executor/cuda/cuda_helpers.h | 21 - tensorflow/stream_executor/cuda/cuda_kernel.h | 21 - .../stream_executor/cuda/cuda_platform.h | 21 - .../stream_executor/cuda/cuda_platform_id.h | 21 - tensorflow/stream_executor/cuda/cuda_rng.h | 21 - tensorflow/stream_executor/cuda/cuda_stream.h | 21 - tensorflow/stream_executor/cuda/cuda_timer.h | 21 - tensorflow/stream_executor/data_type.h | 21 - .../stream_executor/device_description.h | 25 - tensorflow/stream_executor/device_memory.h | 21 - .../stream_executor/device_memory_allocator.h | 21 - 
tensorflow/stream_executor/device_options.h | 21 - tensorflow/stream_executor/dnn.h | 21 - tensorflow/stream_executor/event.h | 21 - tensorflow/stream_executor/executor_cache.h | 21 - tensorflow/stream_executor/gpu/BUILD | 176 - tensorflow/stream_executor/gpu/gpu_asm_opts.h | 21 - tensorflow/stream_executor/gpu/gpu_driver.h | 21 - tensorflow/stream_executor/gpu/gpu_event.h | 21 - tensorflow/stream_executor/gpu/gpu_executor.h | 21 - tensorflow/stream_executor/gpu/gpu_helpers.h | 21 - tensorflow/stream_executor/gpu/gpu_kernel.h | 21 - tensorflow/stream_executor/gpu/gpu_stream.h | 21 - tensorflow/stream_executor/gpu/gpu_timer.h | 21 - tensorflow/stream_executor/gpu/gpu_types.h | 21 - .../stream_executor/gpu/redzone_allocator.h | 21 - tensorflow/stream_executor/gpu_launch_dim.h | 21 - tensorflow/stream_executor/host/BUILD | 100 - .../stream_executor/host/host_gpu_executor.h | 21 - .../stream_executor/host/host_platform.h | 21 - .../stream_executor/host/host_platform_id.h | 21 - tensorflow/stream_executor/host/host_stream.h | 21 - tensorflow/stream_executor/host/host_timer.h | 21 - .../stream_executor/host_or_device_scalar.h | 21 - tensorflow/stream_executor/kernel.h | 21 - .../stream_executor/kernel_cache_config.h | 21 - tensorflow/stream_executor/kernel_spec.h | 21 - tensorflow/stream_executor/launch_dim.h | 21 - tensorflow/stream_executor/lib/BUILD | 36 - tensorflow/stream_executor/lib/array_slice.h | 21 - tensorflow/stream_executor/lib/demangle.h | 21 - tensorflow/stream_executor/lib/env.h | 21 - tensorflow/stream_executor/lib/error.h | 21 - .../stream_executor/lib/human_readable.h | 21 - tensorflow/stream_executor/lib/initialize.h | 21 - tensorflow/stream_executor/lib/mathutil.h | 21 - tensorflow/stream_executor/lib/numbers.h | 21 - tensorflow/stream_executor/lib/path.h | 21 - .../stream_executor/lib/process_state.h | 21 - tensorflow/stream_executor/lib/stacktrace.h | 21 - .../stream_executor/lib/static_threadlocal.h | 21 - tensorflow/stream_executor/lib/status.h | 21 - tensorflow/stream_executor/lib/statusor.h | 21 - .../stream_executor/lib/thread_options.h | 21 - tensorflow/stream_executor/lib/threadpool.h | 21 - tensorflow/stream_executor/module_spec.h | 21 - .../stream_executor/multi_platform_manager.h | 21 - tensorflow/stream_executor/platform.h | 21 - tensorflow/stream_executor/platform/BUILD | 38 - .../stream_executor/platform/default/BUILD | 32 - .../platform/default/dso_loader.h | 21 - .../stream_executor/platform/dso_loader.h | 21 - .../stream_executor/platform/initialize.h | 21 - tensorflow/stream_executor/platform/logging.h | 21 - .../stream_executor/platform/platform.h | 21 - tensorflow/stream_executor/platform/port.h | 21 - tensorflow/stream_executor/plugin.h | 21 - tensorflow/stream_executor/plugin_registry.h | 21 - tensorflow/stream_executor/rng.h | 21 - tensorflow/stream_executor/rocm/BUILD | 211 - .../stream_executor/rocm/hipsolver_wrapper.h | 21 - .../stream_executor/rocm/hipsparse_wrapper.h | 21 - .../stream_executor/rocm/rocblas_wrapper.h | 21 - .../stream_executor/rocm/rocm_activation.h | 21 - tensorflow/stream_executor/rocm/rocm_blas.h | 21 - .../stream_executor/rocm/rocm_diagnostics.h | 21 - tensorflow/stream_executor/rocm/rocm_dnn.h | 21 - .../rocm/rocm_driver_wrapper.h | 21 - tensorflow/stream_executor/rocm/rocm_fft.h | 21 - .../stream_executor/rocm/rocm_platform.h | 21 - .../stream_executor/rocm/rocm_platform_id.h | 21 - .../stream_executor/rocm/rocsolver_wrapper.h | 21 - .../stream_executor/rocm/roctracer_wrapper.h | 21 - .../stream_executor/scratch_allocator.h 
| 21 - tensorflow/stream_executor/stream.h | 21 - tensorflow/stream_executor/stream_executor.h | 21 - .../stream_executor_internal.h | 21 - .../stream_executor/stream_executor_pimpl.h | 21 - .../stream_executor/temporary_device_memory.h | 21 - .../temporary_memory_manager.h | 21 - .../stream_executor/tf_allocator_adapter.h | 21 - tensorflow/stream_executor/timer.h | 21 - tensorflow/stream_executor/tpu/BUILD | 229 - tensorflow/stream_executor/tpu/c_api_decl.h | 21 - tensorflow/stream_executor/tpu/c_api_defn.h | 21 - .../stream_executor/tpu/noncopyable_buffer.h | 21 - tensorflow/stream_executor/tpu/proto_helper.h | 21 - .../stream_executor/tpu/status_helper.h | 21 - tensorflow/stream_executor/tpu/tpu_event.h | 21 - .../stream_executor/tpu/tpu_executable.h | 21 - .../tpu/tpu_executable_interface.h | 21 - tensorflow/stream_executor/tpu/tpu_executor.h | 21 - .../stream_executor/tpu/tpu_executor_c_api.h | 21 - .../stream_executor/tpu/tpu_node_context.h | 21 - .../stream_executor/tpu/tpu_op_executable.h | 21 - tensorflow/stream_executor/tpu/tpu_platform.h | 21 - .../stream_executor/tpu/tpu_platform_id.h | 21 - .../tpu/tpu_platform_interface.h | 21 - tensorflow/stream_executor/tpu/tpu_stream.h | 21 - .../tpu/tpu_stream_interface.h | 21 - tensorflow/stream_executor/tpu/tpu_timer.h | 21 - tensorflow/stream_executor/tpu/tpu_topology.h | 21 - .../tpu/tpu_transfer_manager.h | 21 - .../tpu/tpu_transfer_manager_interface.h | 21 - tensorflow/stream_executor/trace_listener.h | 21 - tensorflow/tensorflow.bzl | 363 +- tensorflow/tensorflow.default.bzl | 6 - tensorflow/tf_exported_symbols.lds | 3 + tensorflow/tf_version_script.lds | 3 + .../tools/android/inference_interface/BUILD | 1 + tensorflow/tools/android/test/BUILD | 1 + tensorflow/tools/android/test/assets/BUILD | 1 + tensorflow/tools/api/golden/BUILD | 1 + ...nsorflow.-config-proto.-experimental.pbtxt | 20 +- .../golden/v1/tensorflow.-config-proto.pbtxt | 20 +- .../api/golden/v1/tensorflow.-d-type.pbtxt | 4 + .../golden/v1/tensorflow.-g-p-u-options.pbtxt | 12 + .../v1/tensorflow.-indexed-slices-spec.pbtxt | 5 + .../golden/v1/tensorflow.-optional-spec.pbtxt | 5 + .../v1/tensorflow.-ragged-tensor-spec.pbtxt | 5 + .../v1/tensorflow.-sparse-tensor-spec.pbtxt | 5 + .../v1/tensorflow.-tensor-array-spec.pbtxt | 5 + .../golden/v1/tensorflow.-tensor-shape.pbtxt | 4 + .../golden/v1/tensorflow.-tensor-spec.pbtxt | 6 + .../api/golden/v1/tensorflow.-type-spec.pbtxt | 5 + .../golden/v1/tensorflow.__internal__.pbtxt | 7 + ...low.__internal__.types.data.-dataset.pbtxt | 9 + .../tensorflow.__internal__.types.data.pbtxt | 7 + .../v1/tensorflow.__internal__.types.pbtxt | 7 + .../v1/tensorflow.config.experimental.pbtxt | 8 - .../v1/tensorflow.data.-dataset-spec.pbtxt | 5 + .../golden/v1/tensorflow.data.-dataset.pbtxt | 4 +- ...ow.data.-fixed-length-record-dataset.pbtxt | 4 +- .../golden/v1/tensorflow.data.-options.pbtxt | 4 + .../tensorflow.data.-t-f-record-dataset.pbtxt | 4 +- .../tensorflow.data.-text-line-dataset.pbtxt | 4 +- ...rflow.data.experimental.-csv-dataset.pbtxt | 4 +- ...data.experimental.-dataset-structure.pbtxt | 5 + ...ata.experimental.-optional-structure.pbtxt | 5 + ...ow.data.experimental.-random-dataset.pbtxt | 6 +- ...rflow.data.experimental.-sql-dataset.pbtxt | 4 +- ...sorflow.data.experimental.-structure.pbtxt | 5 + ...erimental.service.-dispatcher-config.pbtxt | 4 + ...ter_resolver.-t-p-u-cluster-resolver.pbtxt | 4 + .../golden/v1/tensorflow.dtypes.-d-type.pbtxt | 4 + .../v1/tensorflow.dtypes.experimental.pbtxt | 11 + 
.../api/golden/v1/tensorflow.dtypes.pbtxt | 4 + ...rimental.-dynamic-ragged-shape.-spec.pbtxt | 5 + ...xperimental.-structured-tensor.-spec.pbtxt | 5 + ...nsorflow.experimental.extension_type.pbtxt | 7 + .../golden/v1/tensorflow.experimental.pbtxt | 12 + ...ensorflow.layers.-average-pooling1-d.pbtxt | 8 + ...ensorflow.layers.-average-pooling2-d.pbtxt | 8 + ...ensorflow.layers.-average-pooling3-d.pbtxt | 8 + ...nsorflow.layers.-batch-normalization.pbtxt | 10 +- .../v1/tensorflow.layers.-conv1-d.pbtxt | 8 + ...tensorflow.layers.-conv2-d-transpose.pbtxt | 8 + .../v1/tensorflow.layers.-conv2-d.pbtxt | 8 + ...tensorflow.layers.-conv3-d-transpose.pbtxt | 8 + .../v1/tensorflow.layers.-conv3-d.pbtxt | 8 + .../golden/v1/tensorflow.layers.-dense.pbtxt | 8 + .../v1/tensorflow.layers.-dropout.pbtxt | 8 + .../v1/tensorflow.layers.-flatten.pbtxt | 8 + .../golden/v1/tensorflow.layers.-layer.pbtxt | 8 + .../tensorflow.layers.-max-pooling1-d.pbtxt | 8 + .../tensorflow.layers.-max-pooling2-d.pbtxt | 8 + .../tensorflow.layers.-max-pooling3-d.pbtxt | 8 + ...tensorflow.layers.-separable-conv1-d.pbtxt | 8 + ...tensorflow.layers.-separable-conv2-d.pbtxt | 8 + .../golden/v1/tensorflow.lite.-ops-set.pbtxt | 4 + .../v1/tensorflow.nn.experimental.pbtxt | 4 + .../tools/api/golden/v1/tensorflow.nn.pbtxt | 6 +- ...flow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt | 8 + ...orflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt | 8 + ...nsorflow.nn.rnn_cell.-device-wrapper.pbtxt | 8 + ...sorflow.nn.rnn_cell.-dropout-wrapper.pbtxt | 8 + .../tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt | 8 + ...tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt | 8 + ...orflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt | 8 + .../tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt | 8 + ...orflow.nn.rnn_cell.-residual-wrapper.pbtxt | 8 + .../tools/api/golden/v1/tensorflow.pbtxt | 14 +- .../api/golden/v1/tensorflow.random.pbtxt | 8 + .../api/golden/v1/tensorflow.raw_ops.pbtxt | 48 + .../v1/tensorflow.test.experimental.pbtxt | 7 + .../tools/api/golden/v1/tensorflow.test.pbtxt | 4 + ...rimental.embedding.-adagrad-momentum.pbtxt | 2 +- ....tpu.experimental.embedding.-adagrad.pbtxt | 2 +- ...low.tpu.experimental.embedding.-adam.pbtxt | 2 +- ....tpu.experimental.embedding.-f-t-r-l.pbtxt | 2 +- ...ow.tpu.experimental.embedding.-s-g-d.pbtxt | 2 +- ...tensorflow.train.-checkpoint-options.pbtxt | 6 +- .../api/golden/v2/tensorflow.-d-type.pbtxt | 4 + .../v2/tensorflow.-indexed-slices-spec.pbtxt | 5 + .../golden/v2/tensorflow.-optional-spec.pbtxt | 5 + .../v2/tensorflow.-ragged-tensor-spec.pbtxt | 5 + .../v2/tensorflow.-sparse-tensor-spec.pbtxt | 5 + .../v2/tensorflow.-tensor-array-spec.pbtxt | 5 + .../golden/v2/tensorflow.-tensor-shape.pbtxt | 4 + .../golden/v2/tensorflow.-tensor-spec.pbtxt | 6 + .../api/golden/v2/tensorflow.-type-spec.pbtxt | 5 + .../tensorflow.__internal__.-func-graph.pbtxt | 4 - ...flow.__internal__.function.-function.pbtxt | 2 +- .../v2/tensorflow.__internal__.tracking.pbtxt | 2 +- ...low.__internal__.types.data.-dataset.pbtxt | 8 + .../tensorflow.__internal__.types.data.pbtxt | 7 + .../v2/tensorflow.__internal__.types.pbtxt | 4 + .../v2/tensorflow.config.experimental.pbtxt | 8 - .../v2/tensorflow.data.-dataset-spec.pbtxt | 5 + .../golden/v2/tensorflow.data.-dataset.pbtxt | 4 +- ...ow.data.-fixed-length-record-dataset.pbtxt | 4 +- .../v2/tensorflow.data.-iterator-spec.pbtxt | 5 + .../golden/v2/tensorflow.data.-options.pbtxt | 4 + .../tensorflow.data.-t-f-record-dataset.pbtxt | 4 +- .../tensorflow.data.-text-line-dataset.pbtxt | 4 +- ...rflow.data.experimental.-csv-dataset.pbtxt 
| 4 +- ...ow.data.experimental.-random-dataset.pbtxt | 8 +- ...rflow.data.experimental.-sql-dataset.pbtxt | 4 +- .../v2/tensorflow.data.experimental.pbtxt | 4 + ...erimental.service.-dispatcher-config.pbtxt | 4 + ...ter_resolver.-t-p-u-cluster-resolver.pbtxt | 4 + ...ute.experimental.-preemption-watcher.pbtxt | 13 + ...ute.experimental.-termination-config.pbtxt | 2 +- .../tensorflow.distribute.experimental.pbtxt | 4 + .../golden/v2/tensorflow.dtypes.-d-type.pbtxt | 4 + .../v2/tensorflow.dtypes.experimental.pbtxt | 11 + .../api/golden/v2/tensorflow.dtypes.pbtxt | 4 + ...rimental.-dynamic-ragged-shape.-spec.pbtxt | 5 + ...xperimental.-structured-tensor.-spec.pbtxt | 5 + ...perimental.dtensor.-d-tensor-dataset.pbtxt | 4 +- ...sorflow.experimental.dtensor.-layout.pbtxt | 4 - ...ensorflow.experimental.dtensor.-mesh.pbtxt | 6 +- .../v2/tensorflow.experimental.dtensor.pbtxt | 18 +- ...nsorflow.experimental.extension_type.pbtxt | 7 + .../golden/v2/tensorflow.experimental.pbtxt | 12 + .../tensorflow.initializers.-constant.pbtxt | 4 +- ...nsorflow.initializers.-glorot-normal.pbtxt | 6 +- ...sorflow.initializers.-glorot-uniform.pbtxt | 6 +- .../tensorflow.initializers.-he-normal.pbtxt | 6 +- .../tensorflow.initializers.-he-uniform.pbtxt | 6 +- .../tensorflow.initializers.-identity.pbtxt | 4 +- ...tensorflow.initializers.-initializer.pbtxt | 2 +- ...ensorflow.initializers.-lecun-normal.pbtxt | 6 +- ...nsorflow.initializers.-lecun-uniform.pbtxt | 6 +- .../v2/tensorflow.initializers.-ones.pbtxt | 4 +- .../tensorflow.initializers.-orthogonal.pbtxt | 4 +- ...nsorflow.initializers.-random-normal.pbtxt | 4 +- ...sorflow.initializers.-random-uniform.pbtxt | 4 +- ...rflow.initializers.-truncated-normal.pbtxt | 4 +- ...rflow.initializers.-variance-scaling.pbtxt | 4 +- .../v2/tensorflow.initializers.-zeros.pbtxt | 4 +- .../v2/tensorflow.initializers.constant.pbtxt | 4 +- ...ensorflow.initializers.glorot_normal.pbtxt | 6 +- ...nsorflow.initializers.glorot_uniform.pbtxt | 6 +- .../tensorflow.initializers.he_normal.pbtxt | 6 +- .../tensorflow.initializers.he_uniform.pbtxt | 6 +- .../v2/tensorflow.initializers.identity.pbtxt | 4 +- ...tensorflow.initializers.lecun_normal.pbtxt | 6 +- ...ensorflow.initializers.lecun_uniform.pbtxt | 6 +- .../v2/tensorflow.initializers.ones.pbtxt | 4 +- .../tensorflow.initializers.orthogonal.pbtxt | 4 +- .../golden/v2/tensorflow.initializers.pbtxt | 4 +- ...ensorflow.initializers.random_normal.pbtxt | 4 +- ...nsorflow.initializers.random_uniform.pbtxt | 4 +- ...orflow.initializers.truncated_normal.pbtxt | 4 +- ...orflow.initializers.variance_scaling.pbtxt | 4 +- .../v2/tensorflow.initializers.zeros.pbtxt | 4 +- .../golden/v2/tensorflow.lite.-ops-set.pbtxt | 4 + .../api/golden/v2/tensorflow.losses.pbtxt | 4 +- .../golden/v2/tensorflow.metrics.-a-u-c.pbtxt | 10 +- .../v2/tensorflow.metrics.-accuracy.pbtxt | 10 +- .../tensorflow.metrics.-binary-accuracy.pbtxt | 10 +- ...sorflow.metrics.-binary-crossentropy.pbtxt | 10 +- .../v2/tensorflow.metrics.-binary-io-u.pbtxt | 14 +- ...orflow.metrics.-categorical-accuracy.pbtxt | 10 +- ...ow.metrics.-categorical-crossentropy.pbtxt | 10 +- ...ensorflow.metrics.-categorical-hinge.pbtxt | 10 +- ...ensorflow.metrics.-cosine-similarity.pbtxt | 10 +- .../v2/tensorflow.metrics.-f-beta-score.pbtxt | 255 + .../v2/tensorflow.metrics.-f1-score.pbtxt | 256 + .../tensorflow.metrics.-false-negatives.pbtxt | 12 +- .../tensorflow.metrics.-false-positives.pbtxt | 12 +- .../golden/v2/tensorflow.metrics.-hinge.pbtxt | 10 +- 
.../golden/v2/tensorflow.metrics.-io-u.pbtxt | 12 +- .../tensorflow.metrics.-k-l-divergence.pbtxt | 10 +- .../tensorflow.metrics.-log-cosh-error.pbtxt | 10 +- ...sorflow.metrics.-mean-absolute-error.pbtxt | 10 +- ...rics.-mean-absolute-percentage-error.pbtxt | 10 +- .../v2/tensorflow.metrics.-mean-io-u.pbtxt | 14 +- ...sorflow.metrics.-mean-metric-wrapper.pbtxt | 8 + ...sorflow.metrics.-mean-relative-error.pbtxt | 10 +- ...nsorflow.metrics.-mean-squared-error.pbtxt | 10 +- ...rics.-mean-squared-logarithmic-error.pbtxt | 10 +- .../v2/tensorflow.metrics.-mean-tensor.pbtxt | 8 + .../golden/v2/tensorflow.metrics.-mean.pbtxt | 8 + .../v2/tensorflow.metrics.-metric.pbtxt | 8 + .../v2/tensorflow.metrics.-one-hot-io-u.pbtxt | 14 +- ...ensorflow.metrics.-one-hot-mean-io-u.pbtxt | 16 +- .../v2/tensorflow.metrics.-poisson.pbtxt | 10 +- ...sorflow.metrics.-precision-at-recall.pbtxt | 12 +- .../v2/tensorflow.metrics.-precision.pbtxt | 10 +- ...sorflow.metrics.-recall-at-precision.pbtxt | 12 +- .../v2/tensorflow.metrics.-recall.pbtxt | 10 +- ...low.metrics.-root-mean-squared-error.pbtxt | 10 +- ....metrics.-sensitivity-at-specificity.pbtxt | 12 +- ...metrics.-sparse-categorical-accuracy.pbtxt | 10 +- ...ics.-sparse-categorical-crossentropy.pbtxt | 10 +- ...s.-sparse-top-k-categorical-accuracy.pbtxt | 10 +- ....metrics.-specificity-at-sensitivity.pbtxt | 12 +- .../tensorflow.metrics.-squared-hinge.pbtxt | 10 +- .../golden/v2/tensorflow.metrics.-sum.pbtxt | 8 + ....metrics.-top-k-categorical-accuracy.pbtxt | 10 +- .../tensorflow.metrics.-true-negatives.pbtxt | 12 +- .../tensorflow.metrics.-true-positives.pbtxt | 12 +- .../api/golden/v2/tensorflow.metrics.pbtxt | 12 +- .../v2/tensorflow.nn.experimental.pbtxt | 4 + .../v2/tensorflow.optimizers.-adadelta.pbtxt | 14 +- .../v2/tensorflow.optimizers.-adafactor.pbtxt | 81 + .../v2/tensorflow.optimizers.-adagrad.pbtxt | 14 +- .../v2/tensorflow.optimizers.-adam-w.pbtxt | 81 + .../v2/tensorflow.optimizers.-adam.pbtxt | 14 +- .../v2/tensorflow.optimizers.-adamax.pbtxt | 14 +- .../v2/tensorflow.optimizers.-ftrl.pbtxt | 14 +- .../v2/tensorflow.optimizers.-nadam.pbtxt | 14 +- .../v2/tensorflow.optimizers.-optimizer.pbtxt | 12 +- .../v2/tensorflow.optimizers.-r-m-sprop.pbtxt | 14 +- .../v2/tensorflow.optimizers.-s-g-d.pbtxt | 16 +- ...ow.optimizers.experimental.-adadelta.pbtxt | 14 +- ...w.optimizers.experimental.-adafactor.pbtxt | 14 +- ...low.optimizers.experimental.-adagrad.pbtxt | 14 +- ...flow.optimizers.experimental.-adam-w.pbtxt | 14 +- ...orflow.optimizers.experimental.-adam.pbtxt | 14 +- ...flow.optimizers.experimental.-adamax.pbtxt | 14 +- ...orflow.optimizers.experimental.-ftrl.pbtxt | 14 +- ...rflow.optimizers.experimental.-nadam.pbtxt | 14 +- ...w.optimizers.experimental.-optimizer.pbtxt | 12 +- ...w.optimizers.experimental.-r-m-sprop.pbtxt | 14 +- ...rflow.optimizers.experimental.-s-g-d.pbtxt | 16 +- ...nsorflow.optimizers.legacy.-adadelta.pbtxt | 4 +- ...ensorflow.optimizers.legacy.-adagrad.pbtxt | 4 +- .../tensorflow.optimizers.legacy.-adam.pbtxt | 4 +- ...tensorflow.optimizers.legacy.-adamax.pbtxt | 4 +- .../tensorflow.optimizers.legacy.-ftrl.pbtxt | 4 +- .../tensorflow.optimizers.legacy.-nadam.pbtxt | 4 +- ...sorflow.optimizers.legacy.-optimizer.pbtxt | 2 +- ...sorflow.optimizers.legacy.-r-m-sprop.pbtxt | 4 +- .../tensorflow.optimizers.legacy.-s-g-d.pbtxt | 4 +- .../api/golden/v2/tensorflow.optimizers.pbtxt | 12 +- .../v2/tensorflow.optimizers.schedules.pbtxt | 4 +- .../tools/api/golden/v2/tensorflow.pbtxt | 14 +- 
.../api/golden/v2/tensorflow.random.pbtxt | 8 + .../api/golden/v2/tensorflow.raw_ops.pbtxt | 48 + ...aved_model.experimental.-fingerprint.pbtxt | 9 + .../tensorflow.saved_model.experimental.pbtxt | 8 + .../v2/tensorflow.test.experimental.pbtxt | 7 + .../tools/api/golden/v2/tensorflow.test.pbtxt | 4 + ...rimental.embedding.-adagrad-momentum.pbtxt | 2 +- ....tpu.experimental.embedding.-adagrad.pbtxt | 2 +- ...low.tpu.experimental.embedding.-adam.pbtxt | 2 +- ....tpu.experimental.embedding.-f-t-r-l.pbtxt | 2 +- ...ow.tpu.experimental.embedding.-s-g-d.pbtxt | 2 +- ...tensorflow.train.-checkpoint-options.pbtxt | 6 +- ...rflow.types.experimental.-trace-type.pbtxt | 4 + tensorflow/tools/api/lib/BUILD | 1 + tensorflow/tools/api/tests/BUILD | 1 + .../tools/api/tests/api_compatibility_test.py | 73 +- tensorflow/tools/benchmark/BUILD | 1 + tensorflow/tools/benchmark/README.md | 4 +- tensorflow/tools/benchmark/download_models.sh | 6 +- .../onednn_benchmark_config.sh} | 32 +- tensorflow/tools/benchmark/run_models.sh | 38 +- .../tools/benchmark/run_onednn_benchmarks.sh | 7 +- .../tools/ci_build/Dockerfile.cpu.arm64 | 4 +- ...dnn8-ubuntu20.04-manylinux2014-multipython | 14 +- ...n8.2-ubuntu20.04-manylinux2014-multipython | 10 +- ...n8.6-ubuntu20.04-manylinux2014-multipython | 50 + ...rocm-ubuntu20.04-manylinux2014-multipython | 16 +- tensorflow/tools/ci_build/Dockerfile.rocm | 7 +- .../ci_build/build_scripts/ARM_SKIP_TESTS.sh | 18 +- tensorflow/tools/ci_build/builds/android.sh | 4 +- .../tools/ci_build/builds/android_full.sh | 4 +- tensorflow/tools/ci_build/builds/pip_new.sh | 1 + tensorflow/tools/ci_build/ctpu/BUILD | 1 + .../ci_build/devtoolset/build_devtoolset.sh | 2 +- tensorflow/tools/ci_build/gpu_build/BUILD | 1 + tensorflow/tools/ci_build/install/BUILD | 5 +- .../tools/ci_build/install/install_bazel.sh | 3 +- .../install/install_bazel_from_source.sh | 3 +- .../install_pip_packages_by_version.sh | 50 +- .../ci_build/linux/rocm/rocm_py310_pip.sh | 83 + .../ci_build/linux/rocm/rocm_py37_pip.sh | 16 +- .../ci_build/linux/rocm/rocm_py38_pip.sh | 16 +- .../ci_build/linux/rocm/rocm_py39_pip.sh | 14 +- .../ci_build/linux/rocm/run_gpu_multi.sh | 2 +- .../ci_build/linux/rocm/run_gpu_single.sh | 3 +- .../nightly_release/ubuntu/cpu_py37.sh | 41 - .../nightly_release/ubuntu/gpu_py37.sh | 42 - .../ci_build/rel/macos/cpu_py37_nonpip.sh | 49 - .../ci_build/rel/ubuntu/cpu_arm64_pip.sh | 5 +- .../ci_build/rel/ubuntu/cpu_py37_nonpip.sh | 46 - .../tools/ci_build/rel/ubuntu/cpu_py37_pip.sh | 43 - .../ci_build/rel/ubuntu/gpu_py37_nonpip.sh | 48 - .../tools/ci_build/rel/ubuntu/gpu_py37_pip.sh | 51 - .../tools/ci_build/rel/windows/cpu_py37.bat | 24 - .../tools/ci_build/rel/windows/gpu_py37.bat | 25 - .../ci_build/rel/windows_cuda114/cpu_py37.bat | 24 - .../ci_build/rel/windows_cuda114/gpu_py37.bat | 25 - tensorflow/tools/ci_build/release/common.sh | 9 +- .../tools/ci_build/release/common_win.bat | 16 +- .../ci_build/release/requirements_common.txt | 29 +- .../ci_build/release/requirements_mac.txt | 6 +- .../ci_build/windows/bazel/bazel_test_lib.sh | 7 - .../ci_build/windows/bazel/common_env.sh | 12 - tensorflow/tools/common/BUILD | 1 + tensorflow/tools/compatibility/BUILD | 1 + .../tools/compatibility/all_renames_v2.py | 78 +- tensorflow/tools/compatibility/ast_edits.py | 9 +- tensorflow/tools/compatibility/renames_v2.py | 136 +- tensorflow/tools/compatibility/reorders_v2.py | 182 +- .../tools/compatibility/tf_upgrade_v2.py | 34 +- .../tools/compatibility/tf_upgrade_v2_test.py | 201 +- 
tensorflow/tools/compatibility/update/BUILD | 8 +- .../update/generate_v2_renames_map.py | 60 +- .../update/generate_v2_reorders_map.py | 125 +- .../def_file_filter/def_file_filter.py.tpl | 33 +- .../tools/def_file_filter/symbols_pybind.txt | 68 +- .../devel-cpu-arm64v8-jupyter.Dockerfile | 1 + .../arm64v8/devel-cpu-arm64v8.Dockerfile | 1 + .../dockerfiles/devel-cpu-jupyter.Dockerfile | 4 +- .../dockerfiles/devel-cpu.Dockerfile | 4 +- .../dockerfiles/devel-gpu-jupyter.Dockerfile | 24 +- .../dockerfiles/devel-gpu.Dockerfile | 24 +- .../dockerfiles/gpu-jupyter.Dockerfile | 30 +- .../dockerfiles/dockerfiles/gpu.Dockerfile | 30 +- .../devel-cpu-ppc64le-jupyter.Dockerfile | 1 + .../ppc64le/devel-cpu-ppc64le.Dockerfile | 1 + .../devel-gpu-ppc64le-jupyter.Dockerfile | 21 +- .../ppc64le/devel-gpu-ppc64le.Dockerfile | 21 +- .../ppc64le/gpu-ppc64le-jupyter.Dockerfile | 30 +- .../ppc64le/gpu-ppc64le.Dockerfile | 30 +- .../ubuntu/devel-nvidia.partial.Dockerfile | 20 +- .../partials/ubuntu/nvidia.partial.Dockerfile | 30 +- tensorflow/tools/dockerfiles/spec.yml | 1 - .../tools/dockerfiles/tests/build-cpu.sh | 2 +- .../tools/dockerfiles/tests/build-gpu.sh | 2 +- .../dockerfiles/tflite-android.Dockerfile | 2 +- tensorflow/tools/docs/BUILD | 1 + tensorflow/tools/docs/generate2.py | 8 +- tensorflow/tools/git/BUILD | 1 + tensorflow/tools/graph_transforms/BUILD | 1 + tensorflow/tools/lib_package/BUILD | 16 +- tensorflow/tools/mlpbtxt/BUILD | 1 + tensorflow/tools/optimization/BUILD | 1 + tensorflow/tools/pip_package/BUILD | 16 +- .../tools/pip_package/build_pip_package.sh | 13 +- .../redundant_tensorflow_gpu/README.md | 30 + .../redundant_tensorflow_gpu/setup.cfg | 49 + .../redundant_tensorflow_gpu/setup.py | 40 + .../redundant_tf_nightly_gpu/README.md | 29 + .../redundant_tf_nightly_gpu/setup.cfg | 49 + .../redundant_tf_nightly_gpu/setup.py | 40 + tensorflow/tools/pip_package/setup.py | 40 +- tensorflow/tools/proto_text/BUILD | 1 + .../tensorflow_builder/compat_checker/BUILD | 1 + .../tensorflow_builder/config_detector/BUILD | 5 +- .../config_detector/data/golden/BUILD | 1 + tensorflow/tools/test/BUILD | 2 + .../tools/test/run_and_gather_logs_lib.py | 4 +- .../tools/tf_sig_build_dockerfiles/Dockerfile | 9 +- .../builder.devtoolset/build_devtoolset.sh | 11 +- .../builder.devtoolset/glibc2.17-inline.patch | 11 + .../devel.packages.txt | 38 +- .../devel.requirements.txt | 39 +- .../devel.usertools/code_check_full.bats | 2 + .../devel.usertools/cpu.bazelrc | 26 +- .../devel.usertools/cpu_clang.bazelrc | 123 + .../devel.usertools/gpu.bazelrc | 48 +- .../devel.usertools/gpu_clang.bazelrc | 153 + .../devel.usertools/test.requirements.txt | 9 +- .../devel.usertools/wheel_verification.bats | 2 +- .../tf_sig_build_dockerfiles/setup.sources.sh | 7 + tensorflow/tools/tfg_graph_transforms/BUILD | 5 +- .../tools/toolchains/cpus/aarch64/README.md | 21 + .../tools/toolchains/cpus/aarch64/aarch64.bzl | 242 + .../aarch64/aarch64_compiler_configure.bzl | 70 + .../toolchains/cpus/aarch64/crosstool/BUILD | 0 .../cpus/aarch64/crosstool/BUILD.tpl | 78 + .../crosstool/cc_toolchain_config.bzl.tpl | 640 ++ .../arm-linux/aarch64-linux-toolchain.BUILD | 20 +- .../arm-linux/armhf-linux-toolchain.BUILD | 20 +- .../embedded/arm-linux/cc_config.bzl.tpl | 74 +- .../toolchains/remote_config/configs.bzl | 147 +- .../toolchains/remote_config/containers.bzl | 14 +- tensorflow/tools/toolchains/win/BUILD | 2 +- .../toolchains/win/tf_win_01112023/BUILD | 630 ++ .../armeabi_cc_toolchain_config.bzl | 82 + .../builtin_include_directory_paths_msvc | 6 
+ .../win/tf_win_01112023/toolchain_image_info | 2 + .../windows_cc_toolchain_config.bzl | 1392 ++++ .../toolchains/win/tf_win_01232023/BUILD | 630 ++ .../armeabi_cc_toolchain_config.bzl | 82 + .../builtin_include_directory_paths_msvc | 6 + .../win/tf_win_01232023/toolchain_image_info | 2 + .../windows_cc_toolchain_config.bzl | 1392 ++++ tensorflow/tsl/BUILD | 23 + tensorflow/tsl/c/BUILD | 131 + tensorflow/tsl/c/tsl_status.cc | 54 + tensorflow/tsl/c/tsl_status.h | 84 + tensorflow/tsl/c/tsl_status_helper.cc | 91 + tensorflow/tsl/c/tsl_status_helper.h | 44 + .../c/tsl_status_helper_test.cc} | 26 +- .../gpu_rng.h => tsl/c/tsl_status_internal.h} | 15 +- .../c/tsl_status_test.cc} | 28 +- tensorflow/tsl/concurrency/BUILD | 81 + tensorflow/tsl/concurrency/async_value.cc | 278 + tensorflow/tsl/concurrency/async_value.h | 993 +++ .../tsl/concurrency/async_value_ptr_test.cc | 78 + tensorflow/tsl/concurrency/async_value_ref.cc | 38 + tensorflow/tsl/concurrency/async_value_ref.h | 468 ++ .../tsl/concurrency/async_value_ref_test.cc | 190 + .../tsl/concurrency/async_value_test.cc | 182 + .../fft.h => tsl/concurrency/chain.h} | 14 +- .../tsl/concurrency/concurrent_vector.h | 171 + .../tsl/concurrency/concurrent_vector_test.cc | 103 + tensorflow/tsl/concurrency/ref_count.h | 260 + tensorflow/tsl/cuda/BUILD | 1 + tensorflow/tsl/cuda/cuda_12_0.inc | 3323 +++++++++ tensorflow/tsl/cuda/cuda_runtime_11_8.inc | 2771 ++++++++ tensorflow/tsl/cuda/cuda_runtime_12_0.inc | 2676 ++++++++ tensorflow/tsl/cuda/cuda_stub.cc | 39 +- tensorflow/tsl/cuda/cudart_stub.cc | 10 +- tensorflow/tsl/cuda/cusparse_12_0.inc | 6080 +++++++++++++++++ tensorflow/tsl/cuda/cusparse_stub.cc | 4 +- tensorflow/tsl/distributed_runtime/BUILD | 1 + .../distributed_runtime/coordination/BUILD | 97 + .../coordination/coordination_service.cc | 77 +- .../coordination/coordination_service.h | 10 +- .../coordination_service_agent.cc | 95 +- .../coordination/coordination_service_agent.h | 92 +- .../coordination_service_agent_test.cc | 50 +- .../coordination_service_error_util.h | 8 +- ...ordination_service_recoverable_job_test.cc | 54 +- .../coordination_service_rpc_handler.cc | 140 +- .../coordination_service_rpc_handler.h | 102 + .../coordination/coordination_service_test.cc | 32 +- .../tsl/distributed_runtime/preemption/BUILD | 53 +- .../preemption/preemption_sync_manager.cc | 70 +- .../preemption/preemption_sync_manager.h | 65 + .../preemption_sync_manager_test.cc | 39 +- tensorflow/tsl/distributed_runtime/rpc/BUILD | 45 +- .../rpc/coordination/BUILD | 47 + .../coordination/grpc_coordination_client.cc | 63 +- .../coordination/grpc_coordination_client.h} | 24 +- .../grpc_coordination_service_impl.cc | 40 +- .../grpc_coordination_service_impl.h | 116 + .../rpc/grpc_client_cq_tag.h | 40 + .../tsl/distributed_runtime/rpc/grpc_state.h | 251 + .../tsl/distributed_runtime/rpc/grpc_util.cc | 55 + .../tsl/distributed_runtime/rpc/grpc_util.h | 24 +- .../distributed_runtime/rpc/grpc_util_test.cc | 175 +- .../rpc/test_request.proto | 23 + tensorflow/tsl/framework/BUILD | 12 +- tensorflow/tsl/framework/allocator.h | 10 +- tensorflow/tsl/framework/bfc_allocator.cc | 46 +- tensorflow/tsl/framework/bfc_allocator.h | 5 +- tensorflow/tsl/framework/contraction/BUILD | 144 + .../contraction}/eigen_contraction_kernel.cc | 2 +- .../contraction}/eigen_contraction_kernel.h | 11 +- tensorflow/tsl/framework/convolution/BUILD | 112 + .../convolution}/eigen_convolution_helpers.h | 11 +- .../eigen_spatial_convolutions-inl.h | 8 +- .../convolution}/eigen_spatial_convolutions.h 
| 10 +- .../eigen_spatial_convolutions_test.cc | 620 +- .../tsl/framework/cpu_allocator_impl.cc | 11 +- tensorflow/tsl/framework/fixedpoint/BUILD | 1 + tensorflow/tsl/framework/type_traits.h | 4 +- tensorflow/tsl/lib/core/BUILD | 1 + tensorflow/tsl/lib/gtl/BUILD | 1 + tensorflow/tsl/lib/gtl/subtle/BUILD | 1 + tensorflow/tsl/lib/hash/BUILD | 1 + tensorflow/tsl/lib/histogram/BUILD | 1 + tensorflow/tsl/lib/io/BUILD | 2 + tensorflow/tsl/lib/io/snappy/BUILD | 1 + tensorflow/tsl/lib/math/BUILD | 1 + tensorflow/tsl/lib/monitoring/BUILD | 5 +- tensorflow/tsl/lib/random/BUILD | 1 + tensorflow/tsl/lib/strings/BUILD | 2 + tensorflow/tsl/mkl/BUILD | 146 + {third_party => tensorflow/tsl}/mkl/LICENSE | 0 .../tsl}/mkl/MKL_LICENSE | 0 tensorflow/tsl/mkl/build_defs.bzl | 147 + tensorflow/tsl/platform/BUILD | 64 +- tensorflow/tsl/platform/build_config.bzl | 2 - tensorflow/tsl/platform/cloud/BUILD | 9 +- .../cloud/google_auth_provider_test.cc | 11 +- .../tsl/platform/cloud/oauth_client_test.cc | 11 +- tensorflow/tsl/platform/cloud/testdata/BUILD | 1 + tensorflow/tsl/platform/default/BUILD | 23 +- .../tsl/platform/default/build_config.bzl | 77 +- .../tsl/platform/default/build_config/BUILD | 17 +- tensorflow/tsl/platform/default/dso_loader.cc | 5 +- tensorflow/tsl/platform/default/logging.cc | 1 + tensorflow/tsl/platform/default/logging.h | 23 + tensorflow/tsl/platform/default/port.cc | 8 +- .../tsl/platform/default/resource_loader.cc | 8 +- tensorflow/tsl/platform/default/status.h | 70 + tensorflow/tsl/platform/default/test.cc | 48 +- tensorflow/tsl/platform/errors.h | 78 +- tensorflow/tsl/platform/errors_test.cc | 16 +- tensorflow/tsl/platform/float8.h | 884 +++ .../platform/float8_test.cu.cc} | 306 +- tensorflow/tsl/platform/profile_utils/BUILD | 1 + tensorflow/tsl/platform/protobuf.h | 38 +- .../lib => tsl/platform}/static_threadlocal.h | 41 +- tensorflow/tsl/platform/status.cc | 54 +- tensorflow/tsl/platform/status.h | 57 +- tensorflow/tsl/platform/status_test.cc | 39 +- tensorflow/tsl/platform/strcat.cc | 30 +- tensorflow/tsl/platform/strcat_test.cc | 39 + tensorflow/tsl/platform/subprocess_test.cc | 15 +- tensorflow/tsl/platform/test.h | 17 +- tensorflow/tsl/platform/testdata/BUILD | 1 + tensorflow/tsl/platform/types.h | 1 + tensorflow/tsl/platform/windows/BUILD | 1 + .../platform/windows/windows_file_system.cc | 121 +- .../platform/windows/windows_file_system.h | 4 + tensorflow/tsl/profiler/BUILD | 15 + tensorflow/tsl/profiler/backends/cpu/BUILD | 23 +- .../backends/cpu/host_tracer_utils.cc | 2 +- .../profiler/backends/cpu/host_tracer_utils.h | 2 +- tensorflow/tsl/profiler/builds/BUILD | 7 + .../tsl/profiler/builds/build_config.bzl | 16 + tensorflow/tsl/profiler/builds/oss/BUILD | 6 + .../profiler/builds/oss/build_config.bzl | 7 +- tensorflow/tsl/profiler/convert/BUILD | 85 +- .../convert/post_process_single_host_xplane.h | 2 +- .../profiler/convert/trace_events_to_json.cc | 21 +- .../profiler/convert/trace_events_to_json.h | 16 +- .../convert/trace_events_to_json_test.cc | 14 +- .../convert/xplane_to_trace_events.cc | 48 +- .../profiler/convert/xplane_to_trace_events.h | 26 +- .../convert/xplane_to_trace_events_test.cc | 40 +- tensorflow/tsl/profiler/lib/BUILD | 161 +- .../tsl/profiler/lib/connected_traceme.h | 118 + tensorflow/tsl/profiler/lib/context_types.h | 1 + tensorflow/tsl/profiler/lib/nvtx_utils.h | 83 + .../profiler/lib/profiler_collection.cc | 12 +- .../profiler/lib/profiler_collection.h | 16 +- .../tsl/profiler/lib/profiler_controller.cc | 2 +- .../tsl/profiler/lib/profiler_controller.h | 
2 +- .../tsl/profiler/lib/profiler_factory.cc | 7 +- .../tsl/profiler/lib/profiler_factory.h | 7 +- .../tsl/profiler/lib/profiler_factory_test.cc | 2 +- tensorflow/tsl/profiler/lib/profiler_lock.cc | 2 +- tensorflow/tsl/profiler/lib/profiler_lock.h | 5 + .../profiler/lib/profiler_session.cc | 49 +- .../tsl/profiler/lib/profiler_session.h | 93 + .../tsl/profiler/lib/scoped_annotation.h | 83 +- .../profiler/lib/scoped_annotation_stack.h | 35 +- .../lib/scoped_memory_debug_annotation.cc | 6 +- .../lib/scoped_memory_debug_annotation.h | 112 + tensorflow/tsl/profiler/lib/traceme_encode.h | 18 +- .../tsl/profiler/lib/traceme_encode_test.cc | 22 + tensorflow/tsl/profiler/protobuf/BUILD | 67 +- .../tsl/profiler/protobuf/profile.proto | 71 + .../profiler/protobuf/profiler_analysis.proto | 81 + .../profiler/protobuf/profiler_options.proto | 88 + .../profiler/protobuf/profiler_service.proto | 121 + .../profiler_service_monitor_result.proto | 39 + .../tsl/profiler/protobuf/trace_events.proto | 72 + tensorflow/tsl/profiler/rpc/BUILD | 79 + tensorflow/tsl/profiler/rpc/client/BUILD | 214 + .../profiler/rpc/client/capture_profile.cc | 70 +- .../tsl/profiler/rpc/client/capture_profile.h | 51 + .../profiler/rpc/client/profiler_client.cc | 38 +- .../tsl/profiler/rpc/client/profiler_client.h | 102 + .../rpc/client/profiler_client_test.cc | 27 +- .../rpc/client/profiler_client_test_util.h | 28 +- .../client/remote_profiler_session_manager.cc | 25 +- .../client/remote_profiler_session_manager.h | 85 + .../remote_profiler_session_manager_test.cc | 43 +- .../profiler/rpc/client/save_profile.cc | 60 +- .../tsl/profiler/rpc/client/save_profile.h | 58 + .../profiler/rpc/profiler_server.cc | 14 +- tensorflow/tsl/profiler/rpc/profiler_server.h | 41 + .../profiler/rpc/profiler_service_impl.cc | 52 +- .../profiler/rpc/profiler_service_impl.h} | 20 +- tensorflow/tsl/profiler/utils/BUILD | 71 +- .../profiler/utils/buffer_pool.cc | 12 +- .../profiler/utils/buffer_pool.h | 14 +- .../profiler/utils/buffer_pool_test.cc | 8 +- .../tsl/profiler/utils/file_system_utils.h | 69 + tensorflow/tsl/profiler/utils/format_utils.h | 63 + tensorflow/tsl/profiler/utils/group_events.cc | 109 +- tensorflow/tsl/profiler/utils/group_events.h | 11 +- .../tsl/profiler/utils/group_events_test.cc | 110 +- .../tsl/profiler/utils/xplane_schema.cc | 4 + tensorflow/tsl/profiler/utils/xplane_schema.h | 2 + .../tsl/profiler/utils/xplane_utils_test.cc | 4 +- tensorflow/tsl/protobuf/BUILD | 18 + tensorflow/tsl/protobuf/autotuning.proto | 106 + tensorflow/tsl/protobuf/dnn.proto | 173 + tensorflow/tsl/python/lib/core/BUILD | 136 + tensorflow/tsl/python/lib/core/bfloat16.cc | 248 + .../python => tsl/python/lib/core}/bfloat16.h | 13 +- .../tsl/python/lib/core/custom_casts.cc | 99 + tensorflow/tsl/python/lib/core/custom_casts.h | 28 + .../python/lib/core/custom_float.h} | 261 +- tensorflow/tsl/python/lib/core/float8.cc | 239 + tensorflow/tsl/python/lib/core/float8.h | 43 + .../python/lib/core}/float8_e4m3b11.cc | 6 +- .../python/lib/core}/float8_e4m3b11.h | 10 +- .../python => tsl/python/lib/core}/numpy.cc | 6 +- .../python => tsl/python/lib/core}/numpy.h | 10 +- tensorflow/tsl/tsl.bzl | 358 +- tensorflow/tsl/tsl.default.bzl | 11 +- tensorflow/tsl/util/BUILD | 1 + tensorflow/tsl/util/command_line_flags.h | 3 - tensorflow/tsl/util/proto/BUILD | 1 + tensorflow/tsl/util/stats_calculator.cc | 10 +- tensorflow/workspace2.bzl | 140 +- tensorflow/workspace3.bzl | 4 +- third_party/FP16/BUILD | 1 + third_party/absl/BUILD | 1 + 
...m_google_absl_fix_mac_and_nvcc_build.patch | 13 + third_party/benchmark/BUILD | 1 + third_party/common.bzl | 42 - .../acl_depthwise_updateable_weights.patch | 101 - .../acl_fixup_SVE_merges.patch | 3122 --------- .../compute_library/acl_openmp_fix.patch | 46 + .../compute_library/compute_library.patch | 2 +- third_party/cudnn_frontend_header_fix.patch | 16 +- third_party/curl.BUILD | 41 +- third_party/dlpack/BUILD | 1 + third_party/eigen3/workspace.bzl | 4 +- third_party/farmhash/BUILD | 1 + third_party/flatbuffers/BUILD | 2 + third_party/flatbuffers/build_defs.bzl | 2 +- third_party/gemmlowp/BUILD | 1 + .../crosstool/cc_toolchain_config.bzl.tpl | 1 - third_party/gpus/cuda/build_defs.bzl.tpl | 27 +- third_party/gpus/cuda/cuda_config.h.tpl | 1 + third_party/gpus/cuda_configure.bzl | 35 +- third_party/gpus/find_rocm_config.py | 160 +- .../gpus/find_rocm_config.py.gz.base64 | 2 +- third_party/gpus/rocm_configure.bzl | 2 + third_party/hexagon/BUILD | 1 + third_party/highwayhash/BUILD | 2 + third_party/hwloc/BUILD | 5 +- third_party/hwloc/hwloc.BUILD | 24 +- third_party/icu/BUILD | 2 + third_party/jpeg/BUILD | 2 + third_party/jpeg/jpeg.BUILD | 17 + third_party/kissfft/BUILD | 2 + third_party/libprotobuf_mutator/BUILD | 1 + ...-Fix-error-GVN-on-shared-memory-load.patch | 23 +- third_party/llvm/BUILD | 1 + third_party/llvm/generated.patch | 33 + third_party/llvm/toolchains.patch | 10 +- third_party/llvm/workspace.bzl | 6 +- third_party/llvm_openmp/BUILD | 6 +- third_party/llvm_openmp/openmp.bzl | 8 +- third_party/mkl/BUILD | 114 +- third_party/mkl/build_defs.bzl | 149 +- third_party/mkl_dnn/BUILD | 1 + third_party/mkl_dnn/build_defs.bzl | 16 - third_party/mkl_dnn/mkldnn.BUILD | 15 +- third_party/mkl_dnn/mkldnn_acl.BUILD | 18 +- third_party/mkl_dnn/mkldnn_v1.BUILD | 45 +- .../onednn_acl_threadpool_scheduler.patch | 28 + third_party/nasm/BUILD | 2 + third_party/nccl/archive.BUILD | 39 +- third_party/nccl/archive.patch | 13 + third_party/nccl/build_defs.bzl.tpl | 5 +- third_party/nccl/nccl_configure.bzl | 11 +- third_party/nvtx.BUILD | 20 + third_party/opencl_headers/BUILD | 1 + third_party/ortools/ortools.patch | 23 +- third_party/pasta/BUILD | 2 + third_party/png.BUILD | 2 +- third_party/pprof.BUILD | 2 +- third_party/protobuf/protobuf.patch | 304 +- third_party/psimd/BUILD | 2 + third_party/py/numpy/tf_numpy_api/BUILD | 1 + third_party/pybind11_abseil/BUILD | 2 + .../pybind11_abseil/remove_license.patch | 13 + third_party/pybind11_abseil/workspace.bzl | 1 + third_party/ruy/BUILD | 1 + third_party/snappy.BUILD | 4 +- third_party/sobol_data/BUILD | 2 + third_party/sqlite.BUILD | 6 +- third_party/stablehlo/BUILD | 262 + third_party/stablehlo/temporary.patch | 24 + third_party/stablehlo/workspace.bzl | 6 +- .../systemlibs/grpc.bazel.generate_cc.bzl | 4 +- third_party/tensorrt/tensorrt_configure.bzl | 37 +- third_party/tf_runtime/workspace.bzl | 4 +- third_party/triton/BUILD | 1 + third_party/triton/workspace.bzl | 16 + third_party/vulkan_headers/BUILD | 1 + .../vulkan_headers/vulkan_headers.BUILD | 8 +- third_party/wrapt.BUILD | 1 + third_party/zlib.BUILD | 2 +- 7788 files changed, 338679 insertions(+), 120761 deletions(-) create mode 100644 .github/ISSUE_TEMPLATE/tflite-other.md create mode 100644 fuzztest.bazelrc create mode 100644 tensorflow/c/experimental/next_pluggable_device/BUILD create mode 100644 tensorflow/c/experimental/next_pluggable_device/c_api.cc create mode 100644 tensorflow/c/experimental/next_pluggable_device/c_api.h create mode 100644 
tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/fingerprint.pb create mode 100644 tensorflow/compiler/jit/device_compilation_cache.h create mode 100644 tensorflow/compiler/jit/device_compilation_cache_test.cc create mode 100644 tensorflow/compiler/jit/device_compilation_cluster_signature.cc create mode 100644 tensorflow/compiler/jit/device_compilation_cluster_signature.h rename tensorflow/compiler/jit/{xla_compilation_cache_test.cc => device_compilation_cluster_signature_test.cc} (73%) create mode 100644 tensorflow/compiler/jit/device_compilation_profiler.cc create mode 100644 tensorflow/compiler/jit/device_compilation_profiler.h create mode 100644 tensorflow/compiler/jit/device_compilation_profiler_test.cc create mode 100644 tensorflow/compiler/jit/device_compiler.h create mode 100644 tensorflow/compiler/jit/device_compiler_client.cc create mode 100644 tensorflow/compiler/jit/device_compiler_client.h create mode 100644 tensorflow/compiler/jit/device_compiler_client_test.cc rename tensorflow/compiler/jit/{xla_compilation_cache_disable_test.cc => device_compiler_disable_test.cc} (58%) create mode 100644 tensorflow/compiler/jit/device_executable_persistor.h create mode 100644 tensorflow/compiler/jit/device_executable_persistor_test.cc delete mode 100644 tensorflow/compiler/jit/introduce_floating_point_jitter_pass.cc delete mode 100644 tensorflow/compiler/jit/introduce_floating_point_jitter_pass.h delete mode 100644 tensorflow/compiler/jit/introduce_floating_point_jitter_pass_internal.h delete mode 100644 tensorflow/compiler/jit/introduce_floating_point_jitter_pass_test.cc create mode 100644 tensorflow/compiler/jit/pjrt_device_compiler_client.cc create mode 100644 tensorflow/compiler/jit/pjrt_device_compiler_client.h create mode 100644 tensorflow/compiler/jit/pjrt_device_context.cc create mode 100644 tensorflow/compiler/jit/pjrt_device_context.h rename tensorflow/compiler/jit/tests/{xla_compilation_cache_serialize_options_test.cc => device_compiler_serialize_options_test.cc} (87%) rename tensorflow/compiler/jit/tests/{xla_compilation_cache_serialize_test.cc => device_compiler_serialize_test.cc} (87%) rename tensorflow/compiler/jit/tests/{xla_compilation_cache_test_helper.cc => device_compiler_test_helper.cc} (94%) rename tensorflow/compiler/jit/tests/{xla_compilation_cache_test_helper.h => device_compiler_test_helper.h} (74%) create mode 100644 tensorflow/compiler/jit/tf_graph_to_hlo_compiler.cc create mode 100644 tensorflow/compiler/jit/tf_graph_to_hlo_compiler.h create mode 100644 tensorflow/compiler/jit/tf_to_hlo_compiler.h delete mode 100644 tensorflow/compiler/jit/xla_compilation_cache.cc delete mode 100644 tensorflow/compiler/jit/xla_compilation_cache.h create mode 100644 tensorflow/compiler/jit/xla_device_compiler_client.cc create mode 100644 tensorflow/compiler/jit/xla_device_compiler_client.h delete mode 100644 tensorflow/compiler/mlir/g3doc/xla_gpu_codegen.md create mode 100644 tensorflow/compiler/mlir/lite/experimental/common/BUILD create mode 100644 tensorflow/compiler/mlir/lite/experimental/common/outline_operations.cc create mode 100644 tensorflow/compiler/mlir/lite/experimental/common/outline_operations.h create mode 100644 tensorflow/compiler/mlir/lite/sparsity/sparsify_model_test.cc create mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/fold_broadcast.mlir create mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/fuse_mhlo_convolution.mlir delete mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-acos.mlir delete mode 100644 
tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-inplaceupdate-tf_mhlo_tflite.mlir delete mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tf-fb-tf.mlir delete mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-broadcast_in_dim.mlir delete mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-compare.mlir delete mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-constant.mlir delete mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-reshape.mlir delete mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-rsqrt.mlir delete mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl.mlir delete mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-poly.mlir create mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tf-fb-tf.mlir rename tensorflow/compiler/mlir/lite/stablehlo/tests/{legalize-mhlo-tfl-sub.mlir => legalize-stablehlo-tfl-add.mlir} (69%) create mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-broadcast_in_dim.mlir rename tensorflow/compiler/mlir/lite/stablehlo/tests/{legalize-mhlo-tfl-clamp.mlir => legalize-stablehlo-tfl-clamp.mlir} (51%) create mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-compare.mlir rename tensorflow/compiler/mlir/lite/stablehlo/tests/{legalize-mhlo-tfl-concat.mlir => legalize-stablehlo-tfl-concat.mlir} (53%) create mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-constant.mlir rename tensorflow/compiler/mlir/lite/stablehlo/tests/{legalize-mhlo-tfl-conv.mlir => legalize-stablehlo-tfl-conv.mlir} (51%) rename tensorflow/compiler/mlir/lite/stablehlo/tests/{legalize-mhlo-tfl-dot.mlir => legalize-stablehlo-tfl-dot.mlir} (64%) rename tensorflow/compiler/mlir/lite/stablehlo/tests/{legalize-mhlo-tfl-gather.mlir => legalize-stablehlo-tfl-gather.mlir} (60%) rename tensorflow/compiler/mlir/lite/stablehlo/tests/{legalize-mhlo-tfl-max.mlir => legalize-stablehlo-tfl-max.mlir} (58%) rename tensorflow/compiler/mlir/lite/stablehlo/tests/{legalize-mhlo-tfl-add.mlir => legalize-stablehlo-tfl-mul.mlir} (58%) rename tensorflow/compiler/mlir/lite/stablehlo/tests/{legalize-mhlo-tfl-pad.mlir => legalize-stablehlo-tfl-pad.mlir} (62%) create mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-reshape.mlir create mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-rsqrt.mlir rename tensorflow/compiler/mlir/lite/stablehlo/tests/{legalize-mhlo-tfl-scatter.mlir => legalize-stablehlo-tfl-scatter.mlir} (58%) rename tensorflow/compiler/mlir/lite/stablehlo/tests/{legalize-mhlo-tfl-mul.mlir => legalize-stablehlo-tfl-sub.mlir} (58%) create mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl.mlir delete mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-broadcast.mlir delete mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-concat.mlir delete mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-constant.mlir delete mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-conv.mlir delete mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-pad.mlir delete mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-reshape.mlir delete mode 100644 
tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-rsqrt.mlir delete mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo.mlir rename tensorflow/compiler/mlir/lite/stablehlo/tests/{legalize-tfl-mhlo-sub.mlir => legalize-tfl-stablehlo-add.mlir} (72%) create mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-broadcast.mlir rename tensorflow/compiler/mlir/lite/stablehlo/tests/{legalize-tfl-mhlo-clamp.mlir => legalize-tfl-stablehlo-clamp.mlir} (51%) create mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-concat.mlir create mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-constant.mlir create mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-conv.mlir rename tensorflow/compiler/mlir/lite/stablehlo/tests/{legalize-tfl-mhlo-add.mlir => legalize-tfl-stablehlo-max.mlir} (57%) rename tensorflow/compiler/mlir/lite/stablehlo/tests/{legalize-tfl-mhlo-max.mlir => legalize-tfl-stablehlo-mul.mlir} (57%) create mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-pad.mlir create mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-reshape.mlir create mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-rsqrt.mlir rename tensorflow/compiler/mlir/lite/stablehlo/tests/{legalize-tfl-mhlo-mul.mlir => legalize-tfl-stablehlo-sub.mlir} (57%) create mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo.mlir create mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/optimize.mlir create mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/tf-tfl-translate-serialize-stablehlo.mlir create mode 100644 tensorflow/compiler/mlir/lite/stablehlo/tests/unfuse_mhlo_batch_norm.mlir create mode 100644 tensorflow/compiler/mlir/lite/stablehlo/transforms/fold_broadcast_pass.cc create mode 100644 tensorflow/compiler/mlir/lite/stablehlo/transforms/fuse_convolution_pass.cc delete mode 100644 tensorflow/compiler/mlir/lite/stablehlo/transforms/mhlo_tfl_legalize_patterns.td delete mode 100644 tensorflow/compiler/mlir/lite/stablehlo/transforms/mhlo_util.cc delete mode 100644 tensorflow/compiler/mlir/lite/stablehlo/transforms/mhlo_util.h create mode 100644 tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize.cc create mode 100644 tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h rename tensorflow/compiler/mlir/lite/stablehlo/transforms/{mhlo_tfl_pass.cc => stablehlo_tfl_pass.cc} (75%) rename tensorflow/compiler/mlir/lite/stablehlo/transforms/{mhlo_tfl_pass.h => stablehlo_tfl_pass.h} (68%) create mode 100644 tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.cc create mode 100644 tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.h delete mode 100644 tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_mhlo_pass.cc delete mode 100644 tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_mhlo_pass.h delete mode 100644 tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_mhlo_tfl_pass.cc delete mode 100644 tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_poly_pass.cc delete mode 100644 tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_poly_pass.h create mode 100644 tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.cc rename tensorflow/compiler/mlir/lite/stablehlo/transforms/{tf_mhlo_tfl_pass.h => tf_stablehlo_pass.h} (55%) rename 
tensorflow/compiler/mlir/lite/stablehlo/transforms/{tfl_mhlo_pass.cc => tfl_stablehlo_pass.cc} (80%) rename tensorflow/compiler/mlir/lite/stablehlo/transforms/{tfl_mhlo_pass.h => tfl_stablehlo_pass.h} (72%) create mode 100644 tensorflow/compiler/mlir/lite/stablehlo/transforms/unfuse_batch_norm_pass.cc create mode 100644 tensorflow/compiler/mlir/lite/tests/prepare-quantize-post-training-16bits.mlir create mode 100644 tensorflow/compiler/mlir/lite/tests/quantize-variables.mlir create mode 100644 tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.cc create mode 100644 tensorflow/compiler/mlir/lite/transforms/quantize_variables.cc rename tensorflow/compiler/{xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/tiling_interface_impl.h => mlir/lite/utils/size_utils.cc} (67%) create mode 100644 tensorflow/compiler/mlir/lite/utils/size_utils.h create mode 100644 tensorflow/compiler/mlir/lite/utils/size_utils_test.cc create mode 100644 tensorflow/compiler/mlir/quantization/stablehlo/BUILD create mode 100644 tensorflow/compiler/mlir/quantization/stablehlo/internal_visibility_allowlist.bzl create mode 100644 tensorflow/compiler/mlir/quantization/stablehlo/quantization_options.proto create mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD create mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/cc/const_op_size.cc create mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/cc/const_op_size.h create mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/cc/const_op_size_test.cc create mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables.cc create mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables.h create mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables_test.cc create mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD create mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.cc create mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.h create mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump_test.cc create mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/exported_model.proto create mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/ops/uniform_op_quant_spec.cc create mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/ops/uniform_op_quant_spec.h create mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/passes/duplicate_shape_determining_constants.cc create mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_restore_op.cc create mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/passes/mark_functions_noinline.cc create mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.cc create mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.td create mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/passes/remove_var_init_by_const.cc create mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model_test.py delete mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model_wrapper.cc delete mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model_wrapper.h create mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/python/save_model.py create mode 100644 
tensorflow/compiler/mlir/quantization/tensorflow/tests/duplicate_shape_determining_constants.mlir create mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_restore_op.mlir create mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/tests/mark_functions_noinline.mlir create mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_quantize_drq_per_channel.mlir create mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_quantize_ptq_per_channel.mlir create mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/tests/preprocess_op.mlir create mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions_weight_only.mlir create mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_xla.mlir create mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/tests/remove_var_init_by_const.mlir create mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_uniform_attribute_utils.cc create mode 100644 tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_uniform_attribute_utils.h create mode 100644 tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_test.cc create mode 100644 tensorflow/compiler/mlir/tensorflow/ir/tpu_embedding_ops_registry.cc create mode 100644 tensorflow/compiler/mlir/tensorflow/ir/tpu_embedding_ops_registry.h create mode 100644 tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/replicate-tensor-list-init-ops.mlir create mode 100644 tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/stablehlo_add.mlir create mode 100644 tensorflow/compiler/mlir/tensorflow/tests/convert_session_initializer_to_function.mlir create mode 100644 tensorflow/compiler/mlir/tensorflow/tests/launch_to_device_attribute_legacy.mlir create mode 100644 tensorflow/compiler/mlir/tensorflow/tests/parallel_execute_to_islands_legacy.mlir create mode 100644 tensorflow/compiler/mlir/tensorflow/tests/replicate_tensor_list_init_ops.mlir create mode 100644 tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island_legacy.mlir create mode 100644 tensorflow/compiler/mlir/tensorflow/tests/tpu_partitioned_op_conversion.mlir create mode 100644 tensorflow/compiler/mlir/tensorflow/transforms/call_graph_util.h create mode 100644 tensorflow/compiler/mlir/tensorflow/transforms/convert_session_initializer_to_function.cc create mode 100644 tensorflow/compiler/mlir/tensorflow/transforms/replicate_tensor_list_init_ops_pass.cc create mode 100644 tensorflow/compiler/mlir/tensorflow/transforms/tpu_partitioned_op_conversion.cc rename tensorflow/{stream_executor/gpu/asm_compiler.h => compiler/mlir/tensorflow/translate/split_into_island_per_op_pass.h} (53%) create mode 100644 tensorflow/compiler/mlir/tensorflow/utils/topological_sort.cc create mode 100644 tensorflow/compiler/mlir/tensorflow/utils/topological_sort.h create mode 100644 tensorflow/compiler/mlir/tfrt/benchmarks/concat_benchmark.cc create mode 100644 tensorflow/compiler/mlir/tfrt/benchmarks/fused_map_bcast_benchmark.cc create mode 100644 tensorflow/compiler/mlir/tfrt/benchmarks/map_op_benchmark.cc create mode 100644 tensorflow/compiler/mlir/tfrt/benchmarks/reverse_op_benchmark.cc create mode 100644 tensorflow/compiler/mlir/tfrt/benchmarks/scatter_op_benchmark.cc create mode 100644 tensorflow/compiler/mlir/tfrt/ir/gpu_ops.cc create mode 100644 tensorflow/compiler/mlir/tfrt/ir/gpu_ops.h create mode 100644 tensorflow/compiler/mlir/tfrt/ir/gpu_ops.td delete mode 100644 
tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_detensorize_linalg.cc delete mode 100644 tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_fuse_fill_into_tiled_reduction.cc delete mode 100644 tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_lower_vector_transpose.cc delete mode 100644 tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_peel_tiled_loops.cc delete mode 100644 tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_symbolic_shape_optimization.cc delete mode 100644 tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_tile_cwise.cc delete mode 100644 tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_tile_reduction.cc delete mode 100644 tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_tile_transpose.cc create mode 100644 tensorflow/compiler/mlir/tfrt/python_tests/regression_tests/broadcasting_25.mlir create mode 100644 tensorflow/compiler/mlir/tfrt/python_tests/tf_reverse_test.py create mode 100644 tensorflow/compiler/mlir/tfrt/python_tests/tf_scatter_test.py create mode 100644 tensorflow/compiler/mlir/tfrt/tests/analysis/testdata/test.mlir create mode 100644 tensorflow/compiler/mlir/tfrt/tests/analysis/update_op_cost_in_tfrt_mlir_test.cc delete mode 100644 tensorflow/compiler/mlir/tfrt/tests/jit/detensorize_linalg.mlir delete mode 100644 tensorflow/compiler/mlir/tfrt/tests/jit/symbolic_shape_optimization.mlir delete mode 100644 tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_codegen_transpose_detection.mlir delete mode 100644 tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_fuse_fill_into_tiled_reduction.mlir delete mode 100644 tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_peel_tiled_loops.mlir delete mode 100644 tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_pipeline_one_shot.mlir delete mode 100644 tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_tile_cwise.mlir delete mode 100644 tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_tile_fill.mlir delete mode 100644 tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_tile_matmul.mlir delete mode 100644 tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_tile_reduction.mlir delete mode 100644 tensorflow/compiler/mlir/tfrt/tests/remote_run_encapsulate.mlir create mode 100644 tensorflow/compiler/mlir/tfrt/tests/saved_model/testdata/xla_launch.mlir create mode 100644 tensorflow/compiler/mlir/tfrt/tests/saved_model/testdata/xla_launch_xla_reduce_window.mlir create mode 100644 tensorflow/compiler/mlir/tfrt/tests/sink_in_invariant_ops.mlir create mode 100644 tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/func_attributes_multiple_callers.mlir create mode 100644 tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/whileop.mlir delete mode 100644 tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data/batch.mlir delete mode 100644 tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data/range.mlir create mode 100644 tensorflow/compiler/mlir/tfrt/tests/xla_launch_fallback.mlir create mode 100644 tensorflow/compiler/mlir/tfrt/transforms/gpu_passes.cc create mode 100644 tensorflow/compiler/mlir/tfrt/transforms/gpu_passes.h delete mode 100644 tensorflow/compiler/mlir/tfrt/transforms/remote_run_encapsulate.cc create mode 100644 tensorflow/compiler/mlir/tfrt/transforms/sink_in_invariant_ops.cc delete mode 100644 tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt_data.cc delete mode 100644 tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt_data.h create mode 100644 tensorflow/compiler/mlir/tfrt/transforms/update_op_cost_in_tfrt_mlir.cc create mode 100644 tensorflow/compiler/mlir/tfrt/transforms/update_op_cost_in_tfrt_mlir.h create 
mode 100644 tensorflow/compiler/mlir/tfrt/transforms/utils.cc create mode 100644 tensorflow/compiler/mlir/tfrt/transforms/utils.h delete mode 100644 tensorflow/compiler/mlir/tools/kernel_gen/tests/jit_i64_indexed_for_large_tensors.mlir delete mode 100644 tensorflow/compiler/mlir/tools/kernel_gen/tests/print_memrefs.mlir delete mode 100644 tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_memref_prints.cc create mode 100644 tensorflow/compiler/mlir/xla/tests/convert-mhlo-quant-to-int.mlir create mode 100644 tensorflow/compiler/mlir/xla/tests/hlo_xla_sparsification.mlir delete mode 100644 tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir create mode 100644 tensorflow/compiler/mlir/xla/tests/verify-tfxla-legalization-no-chlo.mlir create mode 100644 tensorflow/compiler/mlir/xla/tests/verify-tfxla-legalization.mlir create mode 100644 tensorflow/compiler/mlir/xla/transforms/convert_mhlo_quant_to_int.cc delete mode 100644 tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc create mode 100644 tensorflow/compiler/mlir/xla/transforms/verify_tfxla_legalization.cc create mode 100644 tensorflow/compiler/mlir/xla/transforms/xla_legalize_targets.cc create mode 100644 tensorflow/compiler/mlir/xla/transforms/xla_legalize_targets.h create mode 100644 tensorflow/compiler/mlir/xla/transforms/xla_legalize_targets_test.cc create mode 100644 tensorflow/compiler/tests/bincount_op_test.py create mode 100644 tensorflow/compiler/tests/const_test.py rename tensorflow/compiler/tests/{const_op_test.py => giant_const_op_test.py} (97%) create mode 100644 tensorflow/compiler/tests/reverse_sequence_op_args_test.py create mode 100644 tensorflow/compiler/tests/unique_ops_test.py create mode 100644 tensorflow/compiler/xla/autotune_results.proto create mode 100644 tensorflow/compiler/xla/autotune_serialize.cc create mode 100644 tensorflow/compiler/xla/autotune_serialize.h create mode 100644 tensorflow/compiler/xla/backends/profiler/cpu/BUILD create mode 100644 tensorflow/compiler/xla/backends/profiler/cpu/host_tracer.cc create mode 100644 tensorflow/compiler/xla/backends/profiler/cpu/host_tracer.h rename tensorflow/{core/profiler/backends => compiler/xla/backends/profiler}/cpu/host_tracer_factory.cc (72%) rename tensorflow/{core/profiler/backends => compiler/xla/backends/profiler}/cpu/metadata_collector.cc (69%) create mode 100644 tensorflow/compiler/xla/backends/profiler/cpu/metadata_utils.h rename tensorflow/{core/profiler/backends => compiler/xla/backends/profiler}/cpu/python_tracer.cc (56%) create mode 100644 tensorflow/compiler/xla/backends/profiler/cpu/python_tracer.h rename tensorflow/{core/profiler/backends => compiler/xla/backends/profiler}/cpu/python_tracer_factory.cc (74%) create mode 100644 tensorflow/compiler/xla/backends/profiler/gpu/BUILD rename tensorflow/{core/profiler/backends => compiler/xla/backends/profiler}/gpu/cuda_test.cu.cc (96%) create mode 100644 tensorflow/compiler/xla/backends/profiler/gpu/cuda_test.h rename tensorflow/{core/profiler/backends => compiler/xla/backends/profiler}/gpu/cupti_collector.cc (89%) create mode 100644 tensorflow/compiler/xla/backends/profiler/gpu/cupti_collector.h rename tensorflow/{core/profiler/backends => compiler/xla/backends/profiler}/gpu/cupti_error_manager.cc (99%) create mode 100644 tensorflow/compiler/xla/backends/profiler/gpu/cupti_error_manager.h rename tensorflow/{core/profiler/backends => compiler/xla/backends/profiler}/gpu/cupti_error_manager_test.cc (88%) create mode 100644 
tensorflow/compiler/xla/backends/profiler/gpu/cupti_interface.h rename tensorflow/{core/profiler/backends => compiler/xla/backends/profiler}/gpu/cupti_tracer.cc (93%) create mode 100644 tensorflow/compiler/xla/backends/profiler/gpu/cupti_tracer.h rename tensorflow/{core/profiler/backends => compiler/xla/backends/profiler}/gpu/cupti_utils.cc (77%) rename tensorflow/{core/profiler/backends => compiler/xla/backends/profiler}/gpu/cupti_wrapper.cc (97%) create mode 100644 tensorflow/compiler/xla/backends/profiler/gpu/cupti_wrapper.h rename tensorflow/{core/profiler/backends => compiler/xla/backends/profiler}/gpu/device_tracer_cuda.cc (83%) rename tensorflow/{core/profiler/backends => compiler/xla/backends/profiler}/gpu/device_tracer_rocm.cc (93%) create mode 100644 tensorflow/compiler/xla/backends/profiler/gpu/mock_cupti.h rename tensorflow/{core/profiler/backends => compiler/xla/backends/profiler}/gpu/nvtx_utils.cc (85%) create mode 100644 tensorflow/compiler/xla/backends/profiler/gpu/nvtx_utils.h rename tensorflow/{core/profiler/backends => compiler/xla/backends/profiler}/gpu/rocm_tracer.cc (94%) create mode 100644 tensorflow/compiler/xla/backends/profiler/gpu/rocm_tracer.h create mode 100644 tensorflow/compiler/xla/examples/axpy/BUILD create mode 100644 tensorflow/compiler/xla/examples/axpy/README.md create mode 100644 tensorflow/compiler/xla/examples/axpy/stablehlo_axpy.mlir create mode 100644 tensorflow/compiler/xla/examples/axpy/stablehlo_compile_test.cc rename tensorflow/compiler/{mlir => }/xla/experimental/conv_emitter/BUILD (91%) rename tensorflow/compiler/{mlir => }/xla/experimental/conv_emitter/conv_emitter.cc (99%) rename tensorflow/compiler/{mlir => }/xla/experimental/conv_emitter/conv_emitter.h (85%) rename tensorflow/compiler/{mlir => }/xla/experimental/conv_emitter/conv_emitter_test.cc (97%) rename tensorflow/compiler/{mlir => }/xla/experimental/conv_emitter/conv_emitter_transforms.cc (98%) rename tensorflow/compiler/{mlir => }/xla/experimental/conv_emitter/conv_emitter_transforms.h (93%) rename tensorflow/compiler/{mlir => }/xla/experimental/conv_emitter/g3doc/conv_emitter.md (100%) rename tensorflow/compiler/xla/{service/hlo_opcode.h => frontend_attributes.cc} (58%) create mode 100644 tensorflow/compiler/xla/frontend_attributes.h create mode 100644 tensorflow/compiler/xla/g3doc/images/batch_group_counts.svg create mode 100644 tensorflow/compiler/xla/glob_lit_test.bzl create mode 100644 tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc create mode 100644 tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_solver_option.h create mode 100644 tensorflow/compiler/xla/hlo/experimental/auto_sharding/cluster_environment.cc create mode 100644 tensorflow/compiler/xla/hlo/experimental/auto_sharding/cluster_environment.h create mode 100644 tensorflow/compiler/xla/hlo/experimental/auto_sharding/matrix.h create mode 100644 tensorflow/compiler/xla/hlo/experimental/auto_sharding/metrics.cc create mode 100644 tensorflow/compiler/xla/hlo/experimental/auto_sharding/metrics.h create mode 100644 tensorflow/compiler/xla/hlo/experimental/auto_sharding/profiling_result.h rename tensorflow/compiler/xla/{service => hlo/ir}/hlo_module_group.cc (98%) rename tensorflow/compiler/xla/{service => hlo/ir}/hlo_module_group.h (94%) create mode 100644 tensorflow/compiler/xla/hlo/transforms/BUILD rename tensorflow/compiler/xla/{service => hlo/transforms}/hlo_constant_splitter.cc (98%) rename tensorflow/compiler/xla/{service => hlo/transforms}/hlo_constant_splitter.h 
(88%) rename tensorflow/compiler/xla/{service => hlo/transforms}/hlo_constant_splitter_test.cc (98%) create mode 100644 tensorflow/compiler/xla/lazy.h rename tensorflow/compiler/xla/mlir/{tools => backends/cpu}/BUILD (54%) rename tensorflow/compiler/xla/mlir/{transforms/cpu => backends/cpu/transforms}/BUILD (57%) create mode 100644 tensorflow/compiler/xla/mlir/backends/cpu/transforms/legalize_collective_ops.cc create mode 100644 tensorflow/compiler/xla/mlir/backends/cpu/transforms/legalize_i1_vector_transfers.cc create mode 100644 tensorflow/compiler/xla/mlir/backends/cpu/transforms/lmhlo_to_cpu_runtime.cc rename tensorflow/compiler/xla/mlir/{transforms/cpu => backends/cpu/transforms}/passes.h (67%) rename tensorflow/compiler/xla/mlir/{transforms/cpu => backends/cpu/transforms}/passes.td (53%) create mode 100644 tensorflow/compiler/xla/mlir/backends/cpu/transforms/remove_copies_to_out_params.cc create mode 100644 tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/BUILD create mode 100644 tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/collective_ops.mlir create mode 100644 tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/collective_ops_to_cpu_runtime.mlir create mode 100644 tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/fft.mlir create mode 100644 tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/legalize_i1_vector_transfers.mlir rename tensorflow/compiler/xla/mlir/{transforms/cpu => backends/cpu/transforms}/tests/lmhlo_custom_call.mlir (72%) create mode 100644 tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/lmhlo_infeed.mlir create mode 100644 tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/remove_copies_to_out_params.mlir create mode 100644 tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/rng_bit_generator.mlir rename tensorflow/compiler/xla/mlir/{transforms/cpu => backends/cpu/transforms}/tests/xla_abi_legalization.mlir (90%) create mode 100644 tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/xla_cpu_memref_element_cast_to_llvm.mlir create mode 100644 tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/xla_cpu_outfeed.mlir rename tensorflow/compiler/xla/mlir/{transforms/cpu => backends/cpu/transforms}/xla_abi_legalization.cc (90%) create mode 100644 tensorflow/compiler/xla/mlir/backends/cpu/transforms/xla_cpu_memref_element_cast_to_llvm.cc rename tensorflow/compiler/xla/mlir/{tools/xla_cpu_opt.cc => backends/cpu/xla-cpu-opt.cc} (50%) create mode 100644 tensorflow/compiler/xla/mlir/backends/gpu/BUILD rename tensorflow/compiler/xla/mlir/{transforms/gpu => backends/gpu/transforms}/BUILD (87%) rename tensorflow/compiler/xla/mlir/{transforms/gpu => backends/gpu/transforms}/add_hlo_trace_annotations.cc (80%) rename tensorflow/compiler/xla/mlir/{transforms/gpu => backends/gpu/transforms}/gpu_to_gpu_runtime.cc (86%) rename tensorflow/compiler/xla/mlir/{transforms/gpu => backends/gpu/transforms}/lmhlo_gpu_to_gpu_runtime.cc (76%) rename tensorflow/compiler/xla/mlir/{transforms/gpu => backends/gpu/transforms}/lmhlo_to_gpu_launch.cc (75%) rename tensorflow/compiler/xla/mlir/{transforms/gpu => backends/gpu/transforms}/lmhlo_to_gpu_runtime.cc (73%) rename tensorflow/compiler/xla/mlir/{transforms/gpu => backends/gpu/transforms}/memref_get_global_to_arg.cc (97%) create mode 100644 tensorflow/compiler/xla/mlir/backends/gpu/transforms/outline_cuda_graphs.cc rename tensorflow/compiler/xla/mlir/{transforms/gpu => backends/gpu/transforms}/passes.cc (74%) rename tensorflow/compiler/xla/mlir/{transforms/gpu => 
backends/gpu/transforms}/passes.h (80%) rename tensorflow/compiler/xla/mlir/{transforms/gpu => backends/gpu/transforms}/passes.td (89%) create mode 100644 tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/BUILD rename tensorflow/compiler/xla/mlir/{transforms/gpu => backends/gpu/transforms}/tests/add_hlo_trace.mlir (77%) rename tensorflow/compiler/xla/mlir/{transforms/gpu => backends/gpu/transforms}/tests/gpu_launch.mlir (68%) rename tensorflow/compiler/xla/mlir/{transforms/gpu => backends/gpu/transforms}/tests/gpu_memcpy.mlir (100%) rename tensorflow/compiler/xla/mlir/{transforms/gpu => backends/gpu/transforms}/tests/gpu_memset.mlir (100%) rename tensorflow/compiler/xla/mlir/{transforms/gpu => backends/gpu/transforms}/tests/lmhlo_case.mlir (100%) rename tensorflow/compiler/xla/mlir/{transforms/gpu => backends/gpu/transforms}/tests/lmhlo_custom_call.mlir (95%) rename tensorflow/compiler/xla/mlir/{transforms/gpu => backends/gpu/transforms}/tests/lmhlo_fft.mlir (96%) rename tensorflow/compiler/xla/mlir/{transforms/gpu => backends/gpu/transforms}/tests/lmhlo_gpu_cholesky.mlir (100%) rename tensorflow/compiler/xla/mlir/{transforms/gpu => backends/gpu/transforms}/tests/lmhlo_gpu_conv.mlir (100%) rename tensorflow/compiler/xla/mlir/{transforms/gpu => backends/gpu/transforms}/tests/lmhlo_gpu_cublas_lt_matmul.mlir (94%) rename tensorflow/compiler/xla/mlir/{transforms/gpu => backends/gpu/transforms}/tests/lmhlo_gpu_gemm.mlir (100%) rename tensorflow/compiler/xla/mlir/{transforms/gpu => backends/gpu/transforms}/tests/lmhlo_infeed.mlir (100%) rename tensorflow/compiler/xla/mlir/{transforms/gpu => backends/gpu/transforms}/tests/lmhlo_outfeed.mlir (100%) create mode 100644 tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/lmhlo_send_recv.mlir rename tensorflow/compiler/xla/mlir/{transforms/gpu => backends/gpu/transforms}/tests/lmhlo_while.mlir (100%) rename tensorflow/compiler/xla/mlir/{transforms/gpu => backends/gpu/transforms}/tests/memref_get_global_to_arg.mlir (99%) create mode 100644 tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/outline_cuda_graphs.mlir rename tensorflow/compiler/xla/mlir/{transforms/gpu => backends/gpu/transforms}/uid_generator.h (83%) rename tensorflow/compiler/xla/mlir/{tools/xla_gpu_opt.cc => backends/gpu/xla-gpu-opt.cc} (72%) create mode 100644 tensorflow/compiler/xla/mlir/framework/ir/BUILD rename tensorflow/compiler/{mlir/xla => xla/mlir/framework}/ir/xla_framework.cc (74%) rename tensorflow/compiler/{mlir/xla => xla/mlir/framework}/ir/xla_framework.h (72%) rename tensorflow/compiler/{mlir/xla => xla/mlir/framework}/ir/xla_framework_ops.td (98%) create mode 100644 tensorflow/compiler/xla/mlir/framework/tests/BUILD rename tensorflow/compiler/{mlir/xla => xla/mlir/framework}/tests/legalize-xla-framework.mlir (98%) rename tensorflow/compiler/{mlir/xla => xla/mlir/framework}/tests/outline-with-xla-framework.mlir (86%) rename tensorflow/compiler/{mlir/xla => xla/mlir/framework}/tests/xla-framework.mlir (93%) create mode 100644 tensorflow/compiler/xla/mlir/framework/transforms/BUILD rename tensorflow/compiler/{mlir/xla => xla/mlir/framework}/transforms/outline_with_xla_framework.cc (91%) rename tensorflow/compiler/{mlir/xla/transforms/xla_passes.h => xla/mlir/framework/transforms/passes.h} (87%) rename tensorflow/compiler/{mlir/xla/transforms/xla_passes.td => xla/mlir/framework/transforms/passes.td} (100%) rename tensorflow/compiler/{mlir/xla => xla/mlir/framework}/transforms/xla_framework_to_llvm_pass.cc (96%) create mode 100644 
tensorflow/compiler/xla/mlir/math/BUILD rename tensorflow/compiler/xla/mlir/{transforms/math => math/transforms}/BUILD (66%) create mode 100644 tensorflow/compiler/xla/mlir/math/transforms/math_approximation.cc create mode 100644 tensorflow/compiler/xla/mlir/math/transforms/math_legalization.cc rename tensorflow/compiler/xla/mlir/{transforms/math => math/transforms}/math_optimization.cc (91%) rename tensorflow/compiler/xla/mlir/{transforms/math => math/transforms}/passes.h (58%) rename tensorflow/compiler/xla/mlir/{transforms/math => math/transforms}/passes.td (55%) rename tensorflow/compiler/xla/mlir/{transforms/math => math/transforms}/tests/BUILD (66%) create mode 100644 tensorflow/compiler/xla/mlir/math/transforms/tests/math_legalization.mlir rename tensorflow/compiler/xla/mlir/{transforms/math => math/transforms}/tests/math_optimization.mlir (100%) create mode 100644 tensorflow/compiler/xla/mlir/memref/BUILD rename tensorflow/compiler/xla/mlir/{transforms/memref => memref/transforms}/BUILD (81%) rename tensorflow/compiler/xla/mlir/{transforms/memref => memref/transforms}/aligned_allocations.cc (92%) rename tensorflow/compiler/xla/mlir/{transforms/memref => memref/transforms}/passes.h (74%) rename tensorflow/compiler/xla/mlir/{transforms/memref => memref/transforms}/passes.td (94%) rename tensorflow/compiler/xla/mlir/{transforms/memref => memref/transforms}/tests/BUILD (66%) rename tensorflow/compiler/xla/mlir/{transforms/memref => memref/transforms}/tests/aligned_allocations.mlir (100%) create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_bisect/BUILD create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_bisect/README.md create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_bisect/bisect_lib.cc create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_bisect/bisect_lib.h create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_bisect/mlir_bisect.cc create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/BUILD create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/func.cc create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/general.cc create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/gml_st.cc create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/scf.cc rename tensorflow/compiler/xla/mlir/{transforms/cpu => tools/mlir_bisect/rewrites}/tests/BUILD (58%) create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/erase-op-without-results.mlir create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/inline-scf-while.mlir create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/reduce-gml-st-parallel-bounds.mlir create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/replace-op-with-constant.mlir create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/replace-op-with-value.mlir create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/replace-operand-with-constant.mlir create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/return-operands-of-terminator-operands.mlir create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/truncate-function.mlir create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_bisect/test_passes.cc create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_bisect/test_passes.h create mode 100644 
tensorflow/compiler/xla/mlir/tools/mlir_bisect/tests/BUILD create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_bisect/tests/bisect.mlir create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_bisect/tests/no-bug.mlir create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_bisect/tests/snapshot.mlir create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_bisect/tests/snapshot.mlir.pb create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_replay/BUILD create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_replay/README.md create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_replay/mlir_replay.cc create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_replay/mlir_replay_lib.cc create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_replay/mlir_replay_lib.h create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_replay/public/BUILD create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_replay/public/README.md create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_replay/public/compiler_trace.proto create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_replay/public/compiler_trace_instrumentation.cc create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_replay/public/compiler_trace_instrumentation.h create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_replay/public/execution_trace.proto create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_replay/public/execution_trace_utils.cc create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_replay/public/execution_trace_utils.h create mode 100644 tensorflow/compiler/xla/mlir/tools/mlir_replay/public/execution_trace_utils_test.cc delete mode 100644 tensorflow/compiler/xla/mlir/transforms/cpu/lmhlo_to_cpu_runtime.cc delete mode 100644 tensorflow/compiler/xla/mlir/transforms/gpu/launch_func_to_cuda_graph.cc delete mode 100644 tensorflow/compiler/xla/mlir/transforms/gpu/tests/gpu_launch_to_cuda_graph.mlir create mode 100644 tensorflow/compiler/xla/mlir/xla_cpu/ir/BUILD create mode 100644 tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu.cc create mode 100644 tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu.h create mode 100644 tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu_dialect.td create mode 100644 tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu_enums.td create mode 100644 tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu_ops.td rename tensorflow/compiler/{mlir/tfrt/tests/tf_to_tfrt_data => xla/mlir/xla_cpu/tests}/BUILD (55%) create mode 100644 tensorflow/compiler/xla/mlir/xla_cpu/tests/bufferize.mlir create mode 100644 tensorflow/compiler/xla/mlir/xla_cpu/tests/invalid.mlir create mode 100644 tensorflow/compiler/xla/mlir/xla_cpu/tests/ops.mlir rename tensorflow/compiler/xla/mlir_hlo/{lib/Analysis => analysis}/CMakeLists.txt (84%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Analysis => analysis}/test_userange_analysis.cc (91%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Analysis => analysis}/userange_analysis.cc (99%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Analysis => analysis}/userange_analysis.h (97%) create mode 100644 tensorflow/compiler/xla/mlir_hlo/bindings/CMakeLists.txt rename tensorflow/compiler/xla/mlir_hlo/{lib/CAPI => bindings/c}/Attributes.cc (79%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo-c => bindings/c}/Attributes.h (86%) rename tensorflow/compiler/xla/mlir_hlo/{lib/CAPI => bindings/c}/CMakeLists.txt (100%) rename tensorflow/compiler/xla/mlir_hlo/{lib/CAPI => bindings/c}/Dialects.cc (90%) rename 
tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo-c => bindings/c}/Dialects.h (88%) rename tensorflow/compiler/xla/mlir_hlo/{lib/CAPI => bindings/c}/Passes.cc (89%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo-c => bindings/c}/Passes.h (88%) rename tensorflow/compiler/xla/mlir_hlo/{lib/CAPI => bindings/c}/Types.cc (92%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo-c => bindings/c}/Types.h (90%) rename tensorflow/compiler/xla/mlir_hlo/{ => bindings}/python/CMakeLists.txt (100%) rename tensorflow/compiler/xla/mlir_hlo/{ => bindings}/python/MlirHloModule.cc (80%) rename tensorflow/compiler/xla/mlir_hlo/{ => bindings}/python/mlir/dialects/MhloOps.td (94%) rename tensorflow/compiler/xla/mlir_hlo/{ => bindings}/python/mlir/dialects/mhlo.py (100%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect => }/gml_st/CMakeLists.txt (92%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/gml_st/IR/CMakeLists.txt (72%) create mode 100644 tensorflow/compiler/xla/mlir_hlo/gml_st/IR/gml_st_ops.cc rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/gml_st/IR/gml_st_ops.h (78%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/gml_st/IR/gml_st_ops.td (51%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/gml_st/IR/gml_st_ops_base.td (87%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/gml_st/README.md (100%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/lhlo_gpu/IR => gml_st/interfaces}/CMakeLists.txt (73%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/gml_st/transforms => gml_st/interfaces}/bufferizable_op_interface_impl.cc (57%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect/gml_st/transforms => gml_st/interfaces}/bufferizable_op_interface_impl.h (78%) create mode 100644 tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/CMakeLists.txt create mode 100644 tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/add_debug_info/add_debug_info.cc create mode 100644 tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/collapse_shape/collapse_shape.cc create mode 100644 tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/compose_extract_insert_slice/compose_extract_insert_slice.cc create mode 100644 tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/cpu_tiling/cpu_tiling_pipeline.cc rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/gml_st/transforms => gml_st/transforms/cpu_tiling}/transform_map_for_cpu.cc (51%) create mode 100644 tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_matmul_for_cpu.cc create mode 100644 tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_reduce_for_cpu.cc create mode 100644 tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_reverse_for_cpu.cc rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/gml_st/transforms => gml_st/transforms/cpu_tiling}/transform_scatter_for_cpu.cc (52%) create mode 100644 tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_sort_for_cpu.cc rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/gml_st/transforms => gml_st/transforms/cpu_tiling}/transform_transpose_for_cpu.cc (75%) create mode 100644 tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/fusion/fusion.cc create mode 100644 tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/fusion/fusion.h rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/gml_st/transforms/gml_st_to_gpu.cc => gml_st/transforms/gml_st_simtfy/gml_st_simtfy.cc} (51%) create mode 100644 
tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/gml_st_to_gpu/gml_st_to_gpu.cc rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/gml_st/transforms => gml_st/transforms/gml_st_to_scf}/gml_st_to_scf.cc (51%) create mode 100644 tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/gpu_tiling/greedy_fusion.cc rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/gml_st/transforms => gml_st/transforms/gpu_tiling}/tiling_cwise.cc (77%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/gml_st/transforms => gml_st/transforms/gpu_tiling}/tiling_gpu_warp.cc (51%) create mode 100644 tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/passes.h create mode 100644 tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/passes.td create mode 100644 tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/peeling/peeling.cc create mode 100644 tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/peeling/peeling.h create mode 100644 tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/rewrite_vector_ops/rewrite_vector_contract.cc rename tensorflow/compiler/{mlir/tfrt/jit/transforms/tf_jitrt_rewrite_vector_multi_reduction.cc => xla/mlir_hlo/gml_st/transforms/rewrite_vector_ops/rewrite_vector_multi_reduction.cc} (55%) create mode 100644 tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/rewrite_vector_ops/rewrite_vector_transpose.cc rename tensorflow/compiler/xla/mlir_hlo/{lib/Transforms => gml_st/transforms/scalarization}/scalarization.cc (53%) create mode 100644 tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/test_passes.cc rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/gml_st/transforms/test_passes.h (62%) rename tensorflow/compiler/xla/mlir_hlo/{tosa/include/mhlo_tosa/Transforms/passes.td => gml_st/transforms/test_passes.td} (75%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/gml_st/transforms => gml_st/transforms/tiling}/tiling.cc (70%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect/gml_st/transforms => gml_st/transforms/tiling}/tiling.h (74%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/gml_st/transforms => gml_st/transforms/tiling_softmax}/tiling_softmax.cc (71%) create mode 100644 tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/transforms.cc create mode 100644 tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/transforms.h create mode 100644 tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/triton_tiling/transform_matmul_for_triton.cc create mode 100644 tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/vectorization/vectorization.cc create mode 100644 tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/vectorization/vectorization.h create mode 100644 tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/vectorization/vectorize_copy.cc create mode 100644 tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/vectorization/vectorize_for_cpu.cc rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/gml_st/transforms/vectorization.cc => gml_st/transforms/vectorization/vectorize_for_gpu.cc} (62%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo => gml_st/utils}/CMakeLists.txt (85%) create mode 100644 tensorflow/compiler/xla/mlir_hlo/gml_st/utils/linalg_utils.cc rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect/gml_st/transforms => gml_st/utils}/linalg_utils.h (76%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect/gml_st/transforms => gml_st/utils}/vector_utils.h (94%) delete mode 100644 tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/CMakeLists.txt delete mode 100644 
tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/CMakeLists.txt delete mode 100644 tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_legacy_ops.td delete mode 100644 tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/CMakeLists.txt delete mode 100644 tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/fusion.h delete mode 100644 tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/passes.h delete mode 100644 tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/passes.td delete mode 100644 tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/rewriters.h delete mode 100644 tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/test_passes.td delete mode 100644 tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/tiling_interface.td delete mode 100644 tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/transforms.h delete mode 100644 tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/transforms/CMakeLists.txt delete mode 100644 tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/CMakeLists.txt delete mode 100644 tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt delete mode 100644 tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h delete mode 100644 tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_enums.td delete mode 100644 tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Transforms/CMakeLists.txt delete mode 100644 tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Transforms/gml_st_pipeline.h rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/lhlo/CMakeLists.txt (100%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/lhlo/IR/CMakeLists.txt (75%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/lhlo/IR/lhlo_dialect.td (94%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect => }/lhlo/IR/lhlo_ops.cc (91%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/lhlo/IR/lhlo_ops.h (84%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/lhlo/IR/lhlo_ops.td (93%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/lhlo/IR/lhlo_ops_base.td (82%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/lhlo/IR/lhlo_ops_structs.h (81%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/lhlo/IR/lhlo_ops_structs.td (97%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect => }/lhlo/IR/lhlo_structured_interface.cc (84%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/lhlo/IR/lhlo_structured_interface.h (74%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/lhlo/IR/lhlo_structured_interface.td (100%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect => }/lhlo/transforms/CMakeLists.txt (73%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/lhlo/transforms => lhlo/transforms/legalize_to_tensor_op}/legalize_to_tensor_op.cc (96%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect => }/lhlo/transforms/lhlo_elemental_utils.cc (96%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/lhlo/transforms/lhlo_elemental_utils.h (92%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/lhlo/transforms => 
lhlo/transforms/lhlo_legalize_to_affine}/lhlo_legalize_to_affine.cc (99%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/lhlo/transforms => lhlo/transforms/lhlo_legalize_to_gpu}/lhlo_legalize_to_gpu.cc (94%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/lhlo/transforms => lhlo/transforms/lhlo_legalize_to_parallel_loops}/lhlo_legalize_to_parallel_loops.cc (98%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/lhlo/transforms/lmhlo_passes.td (74%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/lhlo/transforms/map_hlo_to_lhlo_op.h (91%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/lhlo/transforms/map_lhlo_to_hlo_op.h (91%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/lhlo/transforms/map_lmhlo_to_scalar_op.h (88%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/lhlo/transforms/passes.h (68%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo => lhlo}/utils/lhlo_utils.h (81%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/lhlo_gpu/CMakeLists.txt (100%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/lhlo_gpu/IR/CMakeLists.txt (83%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect => }/lhlo_gpu/IR/lhlo_gpu_ops.cc (79%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/lhlo_gpu/IR/lhlo_gpu_ops.h (72%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/lhlo_gpu/IR/lhlo_gpu_ops.td (75%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/lhlo_gpu/IR/lhlo_gpu_ops_base.td (94%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/lhlo_gpu/IR/lhlo_gpu_ops_enums.td (86%) delete mode 100644 tensorflow/compiler/xla/mlir_hlo/lib/CMakeLists.txt delete mode 100644 tensorflow/compiler/xla/mlir_hlo/lib/Dialect/CMakeLists.txt delete mode 100644 tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/IR/CMakeLists.txt delete mode 100644 tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/IR/gml_st_ops.cc delete mode 100644 tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/CMakeLists.txt delete mode 100644 tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/collapse_materialize_ops.cc delete mode 100644 tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/fusion.cc delete mode 100644 tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/linalg_utils.cc delete mode 100644 tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/test_passes.cc delete mode 100644 tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/tiling_interface_impl.cc delete mode 100644 tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/transform_matmul_for_cpu.cc delete mode 100644 tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/transforms.cc delete mode 100644 tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo/CMakeLists.txt delete mode 100644 tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo/transforms/lhlo_fuse_linalg.cc delete mode 100644 tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo_gpu/CMakeLists.txt delete mode 100644 tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_stablehlo.cc delete mode 100644 tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/lower_complex_patterns.td delete mode 100644 tensorflow/compiler/xla/mlir_hlo/lib/Transforms/alloc_to_arg_pass.cc delete mode 100644 tensorflow/compiler/xla/mlir_hlo/lib/Transforms/gml_st_pipeline.cc delete 
mode 100644 tensorflow/compiler/xla/mlir_hlo/lib/Transforms/hlo_to_gpu_pipeline.cc delete mode 100644 tensorflow/compiler/xla/mlir_hlo/lib/Transforms/inline_fusion_pass.cc rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect => }/mhlo/CMakeLists.txt (92%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect => }/mhlo/IR/CMakeLists.txt (68%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect => }/mhlo/IR/chlo_canonicalize.td (83%) create mode 100644 tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_base.td rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect => }/mhlo/IR/hlo_ops.cc (68%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/mhlo/IR/hlo_ops.h (88%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/mhlo/IR/hlo_ops.td (75%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/mhlo/IR/hlo_ops_attrs.td (58%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect => }/mhlo/IR/hlo_ops_common.cc (98%) create mode 100644 tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops_common.h rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/mhlo/IR/hlo_ops_common.td (77%) create mode 100644 tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops_enums.td rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/mhlo/IR/hlo_ops_typedefs.td (86%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect => }/mhlo/IR/hlo_patterns.td (64%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/mhlo/IR/hlo_utils.td (92%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect => }/mhlo/IR/init.cc (92%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect => }/mhlo/IR/mhlo_bytecode.cc (99%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/mhlo/IR/mhlo_bytecode.h (85%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect => }/mhlo/IR/mhlo_canonicalize.td (67%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/mhlo/IR/register.h (100%) create mode 100644 tensorflow/compiler/xla/mlir_hlo/mhlo/analysis/CMakeLists.txt rename tensorflow/compiler/xla/mlir_hlo/{lib/Analysis => mhlo/analysis}/shape_component_analysis.cc (99%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Analysis => mhlo/analysis}/shape_component_analysis.h (96%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Analysis => mhlo/analysis}/test_shape_component_analysis.cc (94%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect/mhlo/transforms => mhlo/interfaces}/bufferizable_op_interface_impl.h (81%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect => }/mhlo/transforms/CMakeLists.txt (56%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/broadcast_propagation}/broadcast_propagation.cc (98%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/chlo_legalize_to_hlo}/chlo_legalize_to_hlo.cc (97%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/chlo_legalize_to_hlo}/chlo_legalize_to_hlo_pass.cc (93%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/chlo_legalize_to_hlo}/chlo_legalize_to_hlo_patterns.td (57%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/collapse_elementwise_map}/collapse_elementwise_map.cc (92%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/constraint_fusion}/constraint_fusion_pass.cc (99%) rename 
tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/convert_to_signless}/convert_to_signless_pass.cc (96%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/expand_hlo_tuples}/expand_hlo_tuples.cc (89%) create mode 100644 tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/expand_ops_simplifier/expand_ops_simplifier.cc rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/group_reduction_dimensions}/group_reduction_dimensions.cc (98%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/hlo_legalize_shape_ops_to_standard}/hlo_legalize_shape_ops_to_standard.cc (96%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/hlo_legalize_to_arithmetic}/hlo_legalize_to_arithmetic.cc (92%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/hlo_legalize_to_lhlo}/hlo_legalize_to_lhlo.cc (94%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/hlo_legalize_to_memref}/hlo_legalize_to_memref.cc (70%) create mode 100644 tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/hlo_legalize_to_stablehlo}/hlo_legalize_to_stablehlo_pass.cc (83%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/legalize_control_flow}/legalize_control_flow.cc (98%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/legalize_einsum_to_dot_general}/legalize_einsum_to_dot_general.cc (97%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/legalize_gather_to_torch_index_select}/legalize_gather_to_torch_index_select.cc (96%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/legalize_mhlo_to_thlo}/legalize_mhlo_to_thlo.cc (85%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/legalize_shape_computations}/legalize_shape_computations.cc (97%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/legalize_sort}/legalize_sort.cc (98%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/legalize_to_linalg}/legalize_to_linalg.cc (86%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/legalize_to_standard}/legalize_to_standard.cc (95%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/legalize_to_standard}/legalize_to_standard_patterns.td (75%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/legalize_trigonometric_to_approximation}/legalize_trigonometric_to_approximation.cc (97%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/lower_complex}/lower_complex.cc (89%) create mode 100644 tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/lower_complex/lower_complex_patterns.td rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/lower_general_dot}/lower_general_dot.cc (86%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/mhlo/transforms/map_chlo_to_hlo_op.h (94%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/mhlo/transforms/map_mhlo_to_scalar_op.h (96%) rename 
tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/mhlo/transforms/map_stablehlo_to_hlo_op.h (93%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/materialize_broadcasts}/materialize_broadcasts.cc (98%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/materialize_broadcasts}/materialize_broadcasts_pass.cc (92%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/merge_assuming_ops}/merge_assuming_ops.cc (96%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/mhlo_canonicalize_gather}/mhlo_canonicalize_gather.cc (92%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/mhlo_canonicalize_reduction}/mhlo_canonicalize_reduction.cc (97%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/mhlo_canonicalize_scatter}/mhlo_canonicalize_scatter.cc (95%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/mhlo_flatten_tuple}/mhlo_flatten_tuple.cc (96%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/mhlo/transforms/mhlo_passes.td (94%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/optimize_mhlo}/optimize_mhlo.cc (97%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/optimize_mhlo}/optimize_mhlo_pass.cc (89%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/mhlo/transforms/passes.h (89%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/prepare_for_export}/prepare_for_export.cc (96%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/rank_specialization}/rank_specialization.cc (95%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/restrict_max_rank}/restrict_max_rank.cc (98%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/mhlo/transforms/rewriters.h (96%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/shape_reification}/shape_reification_pass.cc (97%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Transforms => mhlo/transforms/shape_simplification}/shape_simplification.cc (94%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/sink_constants_to_control_flow}/sink_constants_to_control_flow.cc (96%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/sparse_chlo_legalize_to_linalg}/sparse_chlo_legalize_to_linalg.cc (96%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/sparse_rewriting}/sparse_rewriting.cc (95%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/stablehlo_legalize_to_hlo}/stablehlo_legalize_to_hlo.cc (80%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/stablehlo_legalize_to_hlo}/stablehlo_legalize_to_hlo_pass.cc (89%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Transforms => mhlo/transforms/symbolic_shape_optimization}/symbolic_shape_optimization.cc (87%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/test_infer_shaped_type}/test_infer_shaped_type_pass.cc (98%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/unfuse_batch_norm}/unfuse_batch_norm.cc 
(99%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/transforms/unfuse_batch_norm}/unfuse_batch_norm_pass.cc (91%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/lhlo/IR => mhlo/utils}/CMakeLists.txt (57%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/utils}/legalize_to_linalg_utils.cc (79%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect/mhlo/transforms => mhlo/utils}/legalize_to_linalg_utils.h (81%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/utils}/mhlo_scatter_gather_utils.cc (96%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect/mhlo/transforms => mhlo/utils}/mhlo_scatter_gather_utils.h (98%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/mhlo/transforms => mhlo/utils}/type_conversion.cc (95%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect/mhlo/transforms => mhlo/utils}/type_conversion.h (95%) create mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/add_debug_info.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/collapse-shape.mlir delete mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/collapse_materialize_ops.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/compose_extract_insert_slice.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/map_bcast_map.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/map_matmul.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/map_reduce_map.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/matmul.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/reduce_1d.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/reduce_2d.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/reverse.mlir rename tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/{transform_scatter_for_cpu.mlir => cpu_tiling/scatter.mlir} (77%) create mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/sort.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/transpose.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/gml_st_simtfy.mlir rename tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/{ => gpu_tiling}/tiling_cwise.mlir (70%) create mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/gpu_tiling/tiling_gpu_warp.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/greedy_fusion.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/greedy_tiling_and_fusion.mlir delete mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/legacy_bufferization.mlir delete mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/legacy_loop_tiling.mlir delete mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/legacy_peeling.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/rewrite_vector_contract.mlir rename tensorflow/compiler/{mlir/tfrt/tests/jit/tf_jitrt_rewrite_vector_multi_reduction.mlir => xla/mlir_hlo/tests/Dialect/gml_st/rewrite_vector_multi_reduction.mlir} (53%) create mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/simplify_dead_copy.mlir delete mode 
100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/tiling_gpu_warp.mlir delete mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/transform_map_for_cpu.mlir delete mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/transform_matmul_for_cpu.mlir delete mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/transform_transpose_for_cpu.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/triton_tiling/transform_matmul_for_triton.mlir delete mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/vectorization.mlir delete mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/vectorize.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/vectorize_copy.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/vectorize_for_cpu.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/vectorize_for_gpu.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/vectorize_for_gpu_distributed.mlir delete mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/vectorize_gml_st.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/warp_reduce.mlir delete mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/lhlo/lhlo-fuse-linalg.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/attrs.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/bitcast.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/expand_ops_simplifier.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo-experimental.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/verifier_bounds.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/Dialect/thlo/canonicalize.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/hlo_to_triton_pipeline_softmax.mlir delete mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/inline_fusion.mlir delete mode 100644 tensorflow/compiler/xla/mlir_hlo/tests/warp_reduce.mlir rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect => }/thlo/CMakeLists.txt (95%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/thlo/IR/CMakeLists.txt (73%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect => }/thlo/IR/thlo_ops.cc (59%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/thlo/IR/thlo_ops.h (75%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/thlo/IR/thlo_ops.td (59%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/thlo/IR => thlo/interfaces}/CMakeLists.txt (77%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/thlo/transforms => thlo/interfaces}/bufferizable_op_interface_impl.cc (92%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect/thlo/transforms => thlo/interfaces}/bufferizable_op_interface_impl.h (79%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect => }/thlo/transforms/CMakeLists.txt (82%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Dialect/thlo/transforms => thlo/transforms/legalize_sort}/legalize_sort.cc (69%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/thlo/transforms/passes.h (79%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Dialect => }/thlo/transforms/thlo_passes.td (100%) create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/affine.cc create mode 
100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/arith.cc create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/bufferization.cc create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/builtin.cc create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/comparators.h create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/cwise_math.h create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/func.cc create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/gml_st.cc create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/linalg.cc create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/math.cc create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/memref.cc create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/mhlo.cc create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/mhlo_binary_cwise.cc create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/mhlo_unary_cwise.cc create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/scf.cc create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tensor.cc rename tensorflow/compiler/xla/{mlir/transforms/gpu => mlir_hlo/tools/mlir_interpreter/dialects}/tests/BUILD (58%) create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/affine/apply.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/affine/minmax.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/arith/cmpf.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/arith/cmpi.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/arith/constant.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/arith/extf.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/arith/index_cast.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/arith/int_math.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/arith/minmax.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/arith/negf.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/arith/remf.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/arith/select.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/arith/sitofp.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/arith/uitofp.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/arith/vector_math.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/bufferization/alloc_tensor.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/bufferization/clone.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/bufferization/to_memref.mlir create mode 100644 
tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/bufferization/to_tensor.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/builtin/unrealized_conversion_cast.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/func/call.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/gml_st/for.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/gml_st/parallel.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/linalg/broadcast.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/linalg/fill.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/linalg/generic.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/linalg/map.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/linalg/matmul.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/linalg/reduce.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/linalg/transpose.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/math/math.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/memref/alloc.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/memref/collapse_shape.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/memref/copy.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/memref/dim.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/memref/expand_shape.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/memref/get_global.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/memref/invalid.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/memref/load.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/memref/subview.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/mhlo/broadcast_in_dim.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/mhlo/case.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/mhlo/compare.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/mhlo/complex_math.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/mhlo/compute_reshape_shape.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/mhlo/constant.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/mhlo/convert.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/mhlo/dot.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/mhlo/dot_general.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/mhlo/dynamic_slice.mlir create mode 100644 
tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/mhlo/dynamic_update_slice.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/mhlo/float_math.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/mhlo/gather.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/mhlo/int_math.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/mhlo/iota.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/mhlo/pad.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/mhlo/reduce.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/mhlo/reshape.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/mhlo/scatter.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/mhlo/select.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/mhlo/slice.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/mhlo/subtract.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/mhlo/transpose.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/mhlo/tuple.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/mhlo/while.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/scf/for.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/scf/if.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/scf/parallel.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/scf/while.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/tensor/collapse_shape.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/tensor/dim.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/tensor/empty.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/tensor/expand_shape.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/tensor/extract.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/tensor/extract_slice.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/tensor/from_elements.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/tensor/generate.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/tensor/insert.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/tensor/insert_slice.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/tensor/pad.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/vector/bitcast.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/vector/broadcast.mlir create mode 100644 
tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/vector/compressstore.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/vector/constant_mask.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/vector/contract.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/vector/create_mask.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/vector/expandload.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/vector/extract.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/vector/extract_strided_slice.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/vector/extractelement.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/vector/flat_transpose.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/vector/fma.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/vector/gather.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/vector/insert.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/vector/insert_strided_slice.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/vector/insertelement.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/vector/invalid.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/vector/load.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/vector/maskedload.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/vector/maskedstore.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/vector/multi_reduction.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/vector/outerproduct.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/vector/reduction.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/vector/shape_cast.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/vector/shuffle.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/vector/splat.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/vector/store.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/vector/transfer_read.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/vector/transfer_write.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/vector/transpose.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/vector/type_cast.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/tests/vector/vscale.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/util.cc create mode 100644 
tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/util.h create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/vector.cc create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/framework/interpreter.cc create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/framework/interpreter.h create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/framework/interpreter_value.cc create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/framework/interpreter_value.h create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/framework/interpreter_value_util.h create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/framework/registration.cc create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/framework/registration.h create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/framework/tensor_or_memref.cc create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/framework/tensor_or_memref.h create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/framework/tests/BUILD create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/framework/tests/interpreter_value_test.cc create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/framework/tests/tensor_or_memref_test.cc create mode 100644 tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/mlir-interpreter-runner.cc create mode 100644 tensorflow/compiler/xla/mlir_hlo/tosa/tests/prepare-mhlo.mlir create mode 100644 tensorflow/compiler/xla/mlir_hlo/tosa/transforms/BUILD rename tensorflow/compiler/xla/mlir_hlo/tosa/{lib/Transforms => transforms}/CMakeLists.txt (76%) rename tensorflow/compiler/xla/mlir_hlo/tosa/{lib/Transforms => transforms/legalize_mhlo}/legalize_mhlo.cc (68%) rename tensorflow/compiler/xla/mlir_hlo/tosa/{lib/Transforms => transforms/legalize_mhlo}/legalize_mhlo.pdll (99%) rename tensorflow/compiler/xla/mlir_hlo/tosa/{include/mhlo_tosa/Transforms => transforms}/passes.h (76%) create mode 100644 tensorflow/compiler/xla/mlir_hlo/tosa/transforms/passes.td create mode 100644 tensorflow/compiler/xla/mlir_hlo/tosa/transforms/prepare_mhlo/prepare_mhlo.cc rename tensorflow/compiler/xla/mlir_hlo/{lib/Transforms => transforms}/CMakeLists.txt (79%) create mode 100644 tensorflow/compiler/xla/mlir_hlo/transforms/alloc_to_arg_pass.cc rename tensorflow/compiler/xla/mlir_hlo/{lib/Transforms => transforms}/buffer_packing.cc (99%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Transforms => transforms}/buffer_reuse.cc (98%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Transforms => transforms}/bufferize.cc (96%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Transforms => transforms}/bufferize_pass.cc (91%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Transforms => transforms}/collapse_parallel_loops_to_1d_pass.cc (93%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Transforms => transforms}/copy_removal.cc (98%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Transforms => transforms}/detensorize_scf_ops.cc (98%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Transforms => transforms}/generic_host_to_llvm.cc (92%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Transforms => transforms}/gpu_fusion_rewrite.cc (75%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Transforms => transforms}/gpu_kernel_lowering_passes.cc (78%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Transforms => transforms}/gpu_passes.h (86%) 
rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Transforms => transforms}/gpu_passes.td (100%) create mode 100644 tensorflow/compiler/xla/mlir_hlo/transforms/hlo_to_gpu_pipeline.cc create mode 100644 tensorflow/compiler/xla/mlir_hlo/transforms/hlo_to_triton_pipeline.cc rename tensorflow/compiler/xla/mlir_hlo/{lib/Transforms => transforms}/lower_index_cast_pass.cc (60%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Transforms => transforms}/passes.h (88%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Transforms => transforms}/passes.td (88%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Transforms => transforms}/propagate_static_shapes_to_kernel.cc (98%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo/Transforms => transforms}/rewriters.h (100%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Transforms => transforms}/tile_loops_pass.cc (97%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Transforms => transforms}/unbufferize_pass.cc (95%) rename tensorflow/compiler/xla/mlir_hlo/{lib/Transforms => transforms}/unroll_loops.cc (95%) rename tensorflow/compiler/xla/mlir_hlo/{lib => }/utils/CMakeLists.txt (100%) rename tensorflow/compiler/xla/mlir_hlo/{lib => }/utils/codegen_utils.cc (98%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo => }/utils/codegen_utils.h (100%) rename tensorflow/compiler/xla/mlir_hlo/{lib => }/utils/convert_op_folder.cc (98%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo => }/utils/convert_op_folder.h (100%) rename tensorflow/compiler/xla/mlir_hlo/{lib => }/utils/cycle_detector.cc (99%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo => }/utils/cycle_detector.h (100%) rename tensorflow/compiler/xla/mlir_hlo/{lib => }/utils/cycle_detector_test.cc (98%) rename tensorflow/compiler/xla/mlir_hlo/{lib => }/utils/hlo_utils.cc (99%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo => }/utils/hlo_utils.h (100%) rename tensorflow/compiler/xla/mlir_hlo/{include/mlir-hlo => }/utils/placement_utils.h (100%) create mode 100644 tensorflow/compiler/xla/pjrt/c/pjrt_c_api_cpu.cc rename tensorflow/compiler/xla/{stream_executor/tpu/pjrt_api.h => pjrt/c/pjrt_c_api_cpu.h} (65%) create mode 100644 tensorflow/compiler/xla/pjrt/c/pjrt_c_api_cpu_test.cc delete mode 100644 tensorflow/compiler/xla/pjrt/cpu_device.cc delete mode 100644 tensorflow/compiler/xla/pjrt/cpu_device.h create mode 100644 tensorflow/compiler/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc create mode 100644 tensorflow/compiler/xla/pjrt/pjrt_api.cc create mode 100644 tensorflow/compiler/xla/pjrt/pjrt_api.h create mode 100644 tensorflow/compiler/xla/pjrt/pjrt_api_test.cc create mode 100644 tensorflow/compiler/xla/pjrt/plugin/BUILD rename tensorflow/compiler/xla/{stream_executor/lib/demangle.h => printer.cc} (54%) create mode 100644 tensorflow/compiler/xla/printer.h create mode 100644 tensorflow/compiler/xla/python/ifrt/BUILD create mode 100644 tensorflow/compiler/xla/python/ifrt/README.md create mode 100644 tensorflow/compiler/xla/python/ifrt/array.cc create mode 100644 tensorflow/compiler/xla/python/ifrt/array.h create mode 100644 tensorflow/compiler/xla/python/ifrt/array_impl_test_lib.cc create mode 100644 tensorflow/compiler/xla/python/ifrt/array_test.cc rename tensorflow/{stream_executor/cuda/cuda_blas_lt.h => compiler/xla/python/ifrt/client.cc} (73%) create mode 100644 tensorflow/compiler/xla/python/ifrt/client.h create mode 100644 tensorflow/compiler/xla/python/ifrt/client_impl_test_lib.cc rename tensorflow/{stream_executor/cuda/cuda_blas_utils.h => 
compiler/xla/python/ifrt/compiler.cc} (72%) create mode 100644 tensorflow/compiler/xla/python/ifrt/compiler.h create mode 100644 tensorflow/compiler/xla/python/ifrt/device.h create mode 100644 tensorflow/compiler/xla/python/ifrt/dtype.cc create mode 100644 tensorflow/compiler/xla/python/ifrt/dtype.h create mode 100644 tensorflow/compiler/xla/python/ifrt/executable.cc create mode 100644 tensorflow/compiler/xla/python/ifrt/executable.h create mode 100644 tensorflow/compiler/xla/python/ifrt/executable_impl_test_lib.cc create mode 100644 tensorflow/compiler/xla/python/ifrt/future.cc create mode 100644 tensorflow/compiler/xla/python/ifrt/future.h create mode 100644 tensorflow/compiler/xla/python/ifrt/future_test.cc create mode 100644 tensorflow/compiler/xla/python/ifrt/index.cc create mode 100644 tensorflow/compiler/xla/python/ifrt/index.h create mode 100644 tensorflow/compiler/xla/python/ifrt/index_domain.cc create mode 100644 tensorflow/compiler/xla/python/ifrt/index_domain.h create mode 100644 tensorflow/compiler/xla/python/ifrt/index_domain_test.cc create mode 100644 tensorflow/compiler/xla/python/ifrt/index_test.cc create mode 100644 tensorflow/compiler/xla/python/ifrt/no_impl_test_main.cc create mode 100644 tensorflow/compiler/xla/python/ifrt/shape.cc create mode 100644 tensorflow/compiler/xla/python/ifrt/shape.h create mode 100644 tensorflow/compiler/xla/python/ifrt/shape_test.cc create mode 100644 tensorflow/compiler/xla/python/ifrt/sharding.cc create mode 100644 tensorflow/compiler/xla/python/ifrt/sharding.h create mode 100644 tensorflow/compiler/xla/python/ifrt/sharding_test.cc create mode 100644 tensorflow/compiler/xla/python/ifrt/test_util.cc create mode 100644 tensorflow/compiler/xla/python/ifrt/test_util.h rename tensorflow/compiler/xla/{mlir_hlo/lib/Dialect/gml_st/transforms/tiling_interface.cc => python/ifrt/tuple.cc} (75%) create mode 100644 tensorflow/compiler/xla/python/ifrt/tuple.h create mode 100644 tensorflow/compiler/xla/python/ifrt/tuple_impl_test_lib.cc create mode 100644 tensorflow/compiler/xla/python/ifrt/value.cc create mode 100644 tensorflow/compiler/xla/python/ifrt/value.h create mode 100644 tensorflow/compiler/xla/python/pjrt_ifrt/BUILD create mode 100644 tensorflow/compiler/xla/python/pjrt_ifrt/pjrt_array.cc create mode 100644 tensorflow/compiler/xla/python/pjrt_ifrt/pjrt_array.h create mode 100644 tensorflow/compiler/xla/python/pjrt_ifrt/pjrt_array_impl_test_tfrt_cpu.cc create mode 100644 tensorflow/compiler/xla/python/pjrt_ifrt/pjrt_client.cc create mode 100644 tensorflow/compiler/xla/python/pjrt_ifrt/pjrt_client.h create mode 100644 tensorflow/compiler/xla/python/pjrt_ifrt/pjrt_compiler.cc create mode 100644 tensorflow/compiler/xla/python/pjrt_ifrt/pjrt_compiler.h create mode 100644 tensorflow/compiler/xla/python/pjrt_ifrt/pjrt_executable.cc create mode 100644 tensorflow/compiler/xla/python/pjrt_ifrt/pjrt_executable.h create mode 100644 tensorflow/compiler/xla/python/pjrt_ifrt/pjrt_executable_impl_test_tfrt_cpu.cc create mode 100644 tensorflow/compiler/xla/python/pjrt_ifrt/pjrt_tuple.cc create mode 100644 tensorflow/compiler/xla/python/pjrt_ifrt/pjrt_tuple.h create mode 100644 tensorflow/compiler/xla/python/pjrt_ifrt/tfrt_cpu_client_test_lib.cc rename tensorflow/{ => compiler/xla}/python/profiler/internal/traceme_wrapper.h (83%) rename tensorflow/compiler/xla/{pjrt/cpu_device_test.cc => python/util.cc} (51%) create mode 100644 tensorflow/compiler/xla/pytype.default.bzl create mode 100644 tensorflow/compiler/xla/runlit.cfg.py create mode 100644 
tensorflow/compiler/xla/runlit.site.cfg.py create mode 100644 tensorflow/compiler/xla/runtime/async_runtime_test.cc create mode 100644 tensorflow/compiler/xla/runtime/ffi.cc create mode 100644 tensorflow/compiler/xla/runtime/ffi.h create mode 100644 tensorflow/compiler/xla/runtime/ffi/BUILD create mode 100644 tensorflow/compiler/xla/runtime/ffi/ffi_abi.h create mode 100644 tensorflow/compiler/xla/runtime/ffi/ffi_api.h create mode 100644 tensorflow/compiler/xla/runtime/ffi/ffi_c_api.h create mode 100644 tensorflow/compiler/xla/runtime/ffi_test.cc create mode 100644 tensorflow/compiler/xla/runtime/module.h create mode 100644 tensorflow/compiler/xla/runtime/module_registry.cc create mode 100644 tensorflow/compiler/xla/runtime/module_registry.h create mode 100644 tensorflow/compiler/xla/runtime/module_test.cc create mode 100644 tensorflow/compiler/xla/runtime/runner/BUILD create mode 100644 tensorflow/compiler/xla/runtime/runner/runner.cc create mode 100644 tensorflow/compiler/xla/runtime/runner/runner.h create mode 100644 tensorflow/compiler/xla/runtime/runner/runner.proto create mode 100644 tensorflow/compiler/xla/runtime/runner/runner.py create mode 100644 tensorflow/compiler/xla/runtime/runner/testlib_runner.cc create mode 100644 tensorflow/compiler/xla/runtime/runner/testlib_runner_test.py create mode 100644 tensorflow/compiler/xla/runtime/state.h create mode 100644 tensorflow/compiler/xla/runtime/state_test.cc create mode 100644 tensorflow/compiler/xla/service/all_reduce_promotion.cc create mode 100644 tensorflow/compiler/xla/service/all_reduce_promotion.h create mode 100644 tensorflow/compiler/xla/service/all_reduce_promotion_test.cc create mode 100644 tensorflow/compiler/xla/service/cpu/runtime/collectives.cc create mode 100644 tensorflow/compiler/xla/service/cpu/runtime/collectives.h create mode 100644 tensorflow/compiler/xla/service/cpu/runtime/fft_call.cc create mode 100644 tensorflow/compiler/xla/service/cpu/runtime/fft_call.h create mode 100644 tensorflow/compiler/xla/service/cpu/runtime/rng.cc create mode 100644 tensorflow/compiler/xla/service/cpu/runtime/rng.h create mode 100644 tensorflow/compiler/xla/service/cpu/runtime/xfeed.cc create mode 100644 tensorflow/compiler/xla/service/cpu/runtime/xfeed.h delete mode 100644 tensorflow/compiler/xla/service/dfs_hlo_visitor.h delete mode 100644 tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h delete mode 100644 tensorflow/compiler/xla/service/dynamic_parameter_binding.h create mode 100644 tensorflow/compiler/xla/service/gpu/README.md create mode 100644 tensorflow/compiler/xla/service/gpu/dot_dimension_sorter.cc create mode 100644 tensorflow/compiler/xla/service/gpu/dot_dimension_sorter.h create mode 100644 tensorflow/compiler/xla/service/gpu/dot_dimension_sorter_test.cc create mode 100644 tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker_test.cc create mode 100644 tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker_test.cc rename tensorflow/compiler/xla/service/gpu/{nvptx_helper.h => gpu_device_info_for_tests.cc} (50%) create mode 100644 tensorflow/compiler/xla/service/gpu/gpu_device_info_for_tests.h rename tensorflow/compiler/xla/{stream_executor/cuda/cuda_gpu_executor_test.cc => service/gpu/gpu_device_info_test.cc} (61%) create mode 100644 tensorflow/compiler/xla/service/gpu/gpu_performance_model.cc create mode 100644 tensorflow/compiler/xla/service/gpu/gpu_performance_model.h create mode 100644 tensorflow/compiler/xla/service/gpu/gpu_performance_model_test.cc delete mode 100644 
tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.cc delete mode 100644 tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.h delete mode 100644 tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc delete mode 100644 tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h delete mode 100644 tensorflow/compiler/xla/service/gpu/nvptx_helper.cc create mode 100644 tensorflow/compiler/xla/service/gpu/runtime/cholesky.cc create mode 100644 tensorflow/compiler/xla/service/gpu/runtime/cholesky.h create mode 100644 tensorflow/compiler/xla/service/gpu/runtime/collectives.cc create mode 100644 tensorflow/compiler/xla/service/gpu/runtime/collectives.h create mode 100644 tensorflow/compiler/xla/service/gpu/runtime/custom_call.cc create mode 100644 tensorflow/compiler/xla/service/gpu/runtime/custom_call.h create mode 100644 tensorflow/compiler/xla/service/gpu/runtime/gemm.cc create mode 100644 tensorflow/compiler/xla/service/gpu/runtime/gemm.h create mode 100644 tensorflow/compiler/xla/service/gpu/runtime/memcpy.cc create mode 100644 tensorflow/compiler/xla/service/gpu/runtime/memcpy.h create mode 100644 tensorflow/compiler/xla/service/gpu/runtime/memset.cc create mode 100644 tensorflow/compiler/xla/service/gpu/runtime/memset.h create mode 100644 tensorflow/compiler/xla/service/gpu/runtime/send_recv.cc create mode 100644 tensorflow/compiler/xla/service/gpu/runtime/send_recv.h create mode 100644 tensorflow/compiler/xla/service/gpu/runtime/triangular_solve.cc create mode 100644 tensorflow/compiler/xla/service/gpu/runtime/triangular_solve.h create mode 100644 tensorflow/compiler/xla/service/gpu/tests/dynamic_shared_memory_test.cc create mode 100644 tensorflow/compiler/xla/service/gpu/tests/elemental_ir_emitter_test.cc delete mode 100644 tensorflow/compiler/xla/service/gpu/tests/mnist.py create mode 100644 tensorflow/compiler/xla/service/gpu/tests/reduction_emitter_test.cc create mode 100644 tensorflow/compiler/xla/service/hlo_activation_analysis.cc rename tensorflow/{stream_executor/gpu/gpu_activation.h => compiler/xla/service/hlo_activation_analysis.h} (52%) create mode 100644 tensorflow/compiler/xla/service/hlo_activation_analysis_test.cc delete mode 100644 tensorflow/compiler/xla/service/hlo_computation.h delete mode 100644 tensorflow/compiler/xla/service/hlo_input_output_alias_config.h delete mode 100644 tensorflow/compiler/xla/service/hlo_instruction.h delete mode 100644 tensorflow/compiler/xla/service/hlo_instructions.h delete mode 100644 tensorflow/compiler/xla/service/hlo_module.h create mode 100644 tensorflow/compiler/xla/service/hlo_module_config_test.cc delete mode 100644 tensorflow/compiler/xla/service/hlo_module_metadata.h delete mode 100644 tensorflow/compiler/xla/service/hlo_op_metadata.h delete mode 100644 tensorflow/compiler/xla/service/hlo_schedule.h delete mode 100644 tensorflow/compiler/xla/service/hlo_sharding.h delete mode 100644 tensorflow/compiler/xla/service/hlo_sharding_metadata.h create mode 100644 tensorflow/compiler/xla/service/metrics.proto create mode 100644 tensorflow/compiler/xla/service/metrics_hook_interface.h create mode 100644 tensorflow/compiler/xla/service/spmd/collective_permute_motion.cc create mode 100644 tensorflow/compiler/xla/service/spmd/collective_permute_motion.h create mode 100644 tensorflow/compiler/xla/service/spmd/collective_permute_motion_test.cc create mode 100644 tensorflow/compiler/xla/service/spmd/partition_assignment.cc create mode 100644 tensorflow/compiler/xla/service/spmd/partition_assignment.h create mode 
100644 tensorflow/compiler/xla/service/spmd/partition_assignment_test.cc create mode 100644 tensorflow/compiler/xla/service/stochastic_convert_decomposer.cc create mode 100644 tensorflow/compiler/xla/service/stochastic_convert_decomposer.h create mode 100644 tensorflow/compiler/xla/service/stochastic_convert_decomposer_test.cc create mode 100644 tensorflow/compiler/xla/service/xla_aot_compile_cpu_test.cc create mode 100644 tensorflow/compiler/xla/service/xla_aot_compile_gpu_test.cc create mode 100644 tensorflow/compiler/xla/service/xla_aot_compile_stablehlo_cpu_test.cc create mode 100644 tensorflow/compiler/xla/service/xla_aot_compile_stablehlo_test.mlir create mode 100644 tensorflow/compiler/xla/service/xla_aot_compile_test.mlir create mode 100644 tensorflow/compiler/xla/service/xla_aot_compile_test_autotune_results.prototxt create mode 100644 tensorflow/compiler/xla/service/xla_aot_compile_test_constant.mlir create mode 100644 tensorflow/compiler/xla/service/xla_aot_compile_test_convolution.mlir create mode 100644 tensorflow/compiler/xla/service/xla_aot_compile_test_gemm.mlir create mode 100644 tensorflow/compiler/xla/service/xla_aot_compile_test_gpu_target_config.prototxt create mode 100644 tensorflow/compiler/xla/service/xla_compile.bzl create mode 100644 tensorflow/compiler/xla/service/xla_compile_main.cc create mode 100644 tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.cc create mode 100644 tensorflow/compiler/xla/stream_executor/cuda/cuda_graph.h create mode 100644 tensorflow/compiler/xla/stream_executor/device_description.proto rename tensorflow/{core/common_runtime/device => compiler/xla/stream_executor}/device_id_utils.h (57%) rename tensorflow/{core/common_runtime => compiler/xla/stream_executor}/gpu/gpu_cudamallocasync_allocator.cc (89%) rename tensorflow/{core/common_runtime => compiler/xla/stream_executor}/gpu/gpu_cudamallocasync_allocator.h (85%) rename tensorflow/{core/common_runtime => compiler/xla/stream_executor}/gpu/gpu_init.cc (58%) rename tensorflow/{core/common_runtime => compiler/xla/stream_executor}/gpu/gpu_init.h (79%) delete mode 100644 tensorflow/compiler/xla/stream_executor/lib/array_slice.h delete mode 100644 tensorflow/compiler/xla/stream_executor/lib/demangle.cc delete mode 100644 tensorflow/compiler/xla/stream_executor/lib/env.h delete mode 100644 tensorflow/compiler/xla/stream_executor/lib/human_readable.h delete mode 100644 tensorflow/compiler/xla/stream_executor/lib/initialize.h delete mode 100644 tensorflow/compiler/xla/stream_executor/lib/mathutil.h delete mode 100644 tensorflow/compiler/xla/stream_executor/lib/numbers.cc delete mode 100644 tensorflow/compiler/xla/stream_executor/lib/numbers.h delete mode 100644 tensorflow/compiler/xla/stream_executor/lib/path.cc delete mode 100644 tensorflow/compiler/xla/stream_executor/lib/path.h delete mode 100644 tensorflow/compiler/xla/stream_executor/lib/process_state.cc delete mode 100644 tensorflow/compiler/xla/stream_executor/lib/process_state.h delete mode 100644 tensorflow/compiler/xla/stream_executor/lib/stacktrace.h delete mode 100644 tensorflow/compiler/xla/stream_executor/lib/status.h delete mode 100644 tensorflow/compiler/xla/stream_executor/lib/thread_options.h delete mode 100644 tensorflow/compiler/xla/stream_executor/lib/threadpool.h delete mode 100644 tensorflow/compiler/xla/stream_executor/stream_executor_internal.cc create mode 100644 tensorflow/compiler/xla/tests/convolution_cudnn_test.cc create mode 100644 tensorflow/compiler/xla/tests/exhaustive/BUILD rename 
tensorflow/compiler/xla/tests/{ => exhaustive}/exhaustive_binary_16_bit_test.cc (98%) rename tensorflow/compiler/xla/tests/{ => exhaustive}/exhaustive_binary_test_f32_f64.cc (99%) rename tensorflow/compiler/xla/tests/{ => exhaustive}/exhaustive_op_test_utils.cc (99%) rename tensorflow/compiler/xla/tests/{ => exhaustive}/exhaustive_op_test_utils.h (99%) rename tensorflow/compiler/xla/tests/{ => exhaustive}/exhaustive_unary_test_complex.cc (99%) rename tensorflow/compiler/xla/tests/{ => exhaustive}/exhaustive_unary_test_f32_or_smaller.cc (87%) rename tensorflow/compiler/xla/tests/{ => exhaustive}/exhaustive_unary_test_f64.cc (98%) create mode 100644 tensorflow/compiler/xla/tests/gpu_dump_mlir_passes_test.cc create mode 100644 tensorflow/compiler/xla/tests/xla_ffi_test.cc create mode 100644 tensorflow/compiler/xla/tools/data/add.hlo create mode 100644 tensorflow/compiler/xla/tools/data/must_alias.hlo create mode 100644 tensorflow/compiler/xla/tools/data/must_alias_with_sharding.hlo create mode 100644 tensorflow/compiler/xla/tools/hlo_bisect/BUILD create mode 100644 tensorflow/compiler/xla/tools/hlo_bisect/hlo_bisect.cc create mode 100644 tensorflow/compiler/xla/tools/hlo_bisect/hlo_bisect_state.cc create mode 100644 tensorflow/compiler/xla/tools/hlo_bisect/hlo_bisect_state.h create mode 100644 tensorflow/compiler/xla/tools/hlo_bisect/hlo_bisect_state_test.cc create mode 100644 tensorflow/compiler/xla/tools/hlo_bisect/hlo_bisect_utils.cc create mode 100644 tensorflow/compiler/xla/tools/hlo_bisect/hlo_bisect_utils.h create mode 100644 tensorflow/compiler/xla/tools/interactive_graphviz_bin_test.cc create mode 100644 tensorflow/compiler/xla/tools/replay_computation_bin_test.cc create mode 100644 tensorflow/compiler/xla/tools/run_hlo_module_bin_test.cc create mode 100644 tensorflow/compiler/xla/translate/hlo_to_mhlo/tests/dynamic_param.hlo create mode 100644 tensorflow/compiler/xla/translate/hlo_to_mhlo/tests/frontend_attributes.hlotxt create mode 100644 tensorflow/compiler/xla/translate/hlo_to_mhlo/tests/module_attributes.hlo create mode 100644 tensorflow/compiler/xla/translate/hlo_to_mhlo/tests/spmd_module_sharding.hlo create mode 100644 tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/module_attributes.mlir create mode 100644 tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/unsupported_type.mlir create mode 100644 tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/BUILD rename tensorflow/compiler/{mlir/xla/transforms => xla/translate/mhlo_to_lhlo_with_xla}/mhlo_to_lhlo_with_xla.cc (78%) rename tensorflow/compiler/{mlir/xla/transforms => xla/translate/mhlo_to_lhlo_with_xla}/mhlo_to_lhlo_with_xla.h (73%) rename tensorflow/compiler/{mlir/xla/tests/hlo_to_lhlo_with_xla => xla/translate/mhlo_to_lhlo_with_xla/tests}/BUILD (60%) rename tensorflow/compiler/{mlir/xla/tests/hlo_to_lhlo_with_xla => xla/translate/mhlo_to_lhlo_with_xla/tests}/gpu_ops.mlir (94%) rename tensorflow/compiler/{mlir/xla/tests/hlo_to_lhlo_with_xla => xla/translate/mhlo_to_lhlo_with_xla/tests}/hlo_text_to_lhlo_no_opt.hlotxt (79%) rename tensorflow/compiler/{mlir/xla/tests/hlo_to_lhlo_with_xla => xla/translate/mhlo_to_lhlo_with_xla/tests}/no_opt_ops.hlotxt (94%) rename tensorflow/compiler/{mlir/xla/tests/hlo_to_lhlo_with_xla => xla/translate/mhlo_to_lhlo_with_xla/tests}/non_identity_layouts.hlotxt (92%) rename tensorflow/compiler/{mlir/xla/tests/hlo_to_lhlo_with_xla => xla/translate/mhlo_to_lhlo_with_xla/tests}/ops.mlir (99%) rename tensorflow/compiler/{mlir/xla/tests/hlo_to_lhlo_with_xla => 
xla/translate/mhlo_to_lhlo_with_xla/tests}/passthrough.mlir (90%) rename tensorflow/compiler/{mlir/xla/xla_mlir_translate_registration.cc => xla/translate/mhlo_to_lhlo_with_xla/translate_registration.cc} (94%) create mode 100644 tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/xla_translate_opt_main.cc create mode 100644 tensorflow/core/api_def/base_api/api_def_CollectiveReduceScatterV2.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_ComputeDedupDataTupleMask.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_Conv2DBackpropFilterV2.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_Conv2DBackpropInputV2.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_DistributedSave.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_MergeDedupData.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_RandomDatasetV2.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_SegmentProdV2.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_SegmentSumV2.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_SplitDedupData.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_SyncDevice.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_TPUPartitionedInputV2.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_TPUPartitionedOutputV2.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_UniformQuantizedAdd.pbtxt create mode 100644 tensorflow/core/api_def/base_api/api_def_UniformQuantizedConvolution.pbtxt create mode 100644 tensorflow/core/api_def/python_api/api_def_SegmentProdV2.pbtxt create mode 100644 tensorflow/core/api_def/python_api/api_def_SegmentSumV2.pbtxt create mode 100644 tensorflow/core/common_runtime/arg_ret_placement.cc create mode 100644 tensorflow/core/common_runtime/arg_ret_placement.h create mode 100644 tensorflow/core/common_runtime/arg_ret_placement_test.cc create mode 100644 tensorflow/core/common_runtime/eager/eager_executor_test.cc create mode 100644 tensorflow/core/common_runtime/eager/placement_utils_test.cc create mode 100644 tensorflow/core/common_runtime/eager/tensor_handle_data_test.cc create mode 100644 tensorflow/core/common_runtime/eager/zen_eager_op_rewrite.cc create mode 100644 tensorflow/core/common_runtime/int32_fulltype.cc create mode 100644 tensorflow/core/common_runtime/int32_fulltype.h create mode 100644 tensorflow/core/common_runtime/int32_fulltype_test.cc create mode 100644 tensorflow/core/common_runtime/layout_pass_util.cc create mode 100644 tensorflow/core/common_runtime/layout_pass_util.h create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/BUILD create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/c/BUILD create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/c/example_plugin.cc create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/c/example_plugin.h create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/c/plugin_c_api.h create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/c/plugin_c_api_test.cc create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/c/test_next_pluggable_device_plugin.cc create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent.cc create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent.h create mode 100644 
tensorflow/core/common_runtime/next_pluggable_device/c_plugin_op_kernel.cc create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/c_plugin_op_kernel.h create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/c_plugin_variable.cc create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/c_plugin_variable.h create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_coordination_service_agent.h create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_op_kernel.cc create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_op_kernel.h create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_variable.cc create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_variable.h create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device.cc create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device.h create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_allocator.cc create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_allocator.h create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_api.cc create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_api.h create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_context.cc create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_context.h create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_factory.cc create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_factory.h create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/pjrt_compile_on_demand_op.cc create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/pjrt_compile_on_demand_op.h create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/plugin_coordination_service_agent.h create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/plugin_coordination_service_agent_helper.h create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/plugin_op_kernel.h create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/plugin_op_kernel_helper.h rename tensorflow/{compiler/xla/stream_executor/tpu/pjrt_api.cc => core/common_runtime/next_pluggable_device/plugin_resource.cc} (75%) create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/plugin_resource.h create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/plugin_variable.h create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/utils.cc create mode 100644 tensorflow/core/common_runtime/next_pluggable_device/utils.h create mode 100644 tensorflow/core/common_runtime/optimize_function_graph_utils.cc create mode 100644 tensorflow/core/common_runtime/optimize_function_graph_utils.h create mode 100644 tensorflow/core/common_runtime/optimize_function_graph_utils_test.cc create mode 100644 tensorflow/core/common_runtime/optimized_function_graph_info.cc create mode 100644 tensorflow/core/common_runtime/optimized_function_graph_info.h create mode 100644 tensorflow/core/common_runtime/optimized_function_graph_info_test.cc create mode 100644 tensorflow/core/common_runtime/zen_layout_pass.cc 
create mode 100644 tensorflow/core/data/service/snapshot/BUILD create mode 100644 tensorflow/core/data/service/snapshot/distributed_snapshot_test.cc create mode 100644 tensorflow/core/data/service/snapshot/file_utils.cc create mode 100644 tensorflow/core/data/service/snapshot/file_utils.h create mode 100644 tensorflow/core/data/service/snapshot/file_utils_test.cc create mode 100644 tensorflow/core/data/service/snapshot/path_utils.cc create mode 100644 tensorflow/core/data/service/snapshot/path_utils.h create mode 100644 tensorflow/core/data/service/snapshot/path_utils_test.cc create mode 100644 tensorflow/core/data/service/snapshot/snapshot_manager.cc create mode 100644 tensorflow/core/data/service/snapshot/snapshot_manager.h create mode 100644 tensorflow/core/data/service/snapshot/snapshot_reader.cc create mode 100644 tensorflow/core/data/service/snapshot/snapshot_reader.h create mode 100644 tensorflow/core/data/service/snapshot/snapshot_reader_test.cc create mode 100644 tensorflow/core/data/service/snapshot/snapshot_split_provider.cc create mode 100644 tensorflow/core/data/service/snapshot/snapshot_split_provider.h create mode 100644 tensorflow/core/data/service/snapshot/snapshot_split_provider_test.cc create mode 100644 tensorflow/core/data/service/snapshot/snapshot_stream_writer.cc create mode 100644 tensorflow/core/data/service/snapshot/snapshot_stream_writer.h create mode 100644 tensorflow/core/data/service/snapshot/snapshot_stream_writer_checkpoint_test.cc create mode 100644 tensorflow/core/data/service/snapshot/snapshot_stream_writer_test.cc create mode 100644 tensorflow/core/data/service/snapshot/test_utils.cc create mode 100644 tensorflow/core/data/service/snapshot/test_utils.h create mode 100644 tensorflow/core/data/service/snapshot/utils.cc create mode 100644 tensorflow/core/data/service/snapshot/utils.h create mode 100644 tensorflow/core/data/service/snapshot/utils_test.cc create mode 100644 tensorflow/core/data/service/testdata/choose_from_datasets.pbtxt create mode 100644 tensorflow/core/data/standalone_save_restore_test.cc create mode 100644 tensorflow/core/data/tfdataz_metrics.cc create mode 100644 tensorflow/core/data/tfdataz_metrics.h create mode 100644 tensorflow/core/data/tfdataz_metrics_test.cc create mode 100644 tensorflow/core/distributed_runtime/integration_test/c_api_recoverable_jobs_test.cc delete mode 100644 tensorflow/core/distributed_runtime/rpc/grpc_util_test.cc delete mode 100644 tensorflow/core/framework/float8.cc delete mode 100644 tensorflow/core/framework/float8.h create mode 100644 tensorflow/core/framework/optimized_function_graph.proto create mode 100644 tensorflow/core/framework/tensor_shape_fuzz.cc create mode 100644 tensorflow/core/function/capture/by_ref_capture_test.py create mode 100644 tensorflow/core/function/capture/capture_container.py create mode 100644 tensorflow/core/function/capture/capture_container_test.py create mode 100644 tensorflow/core/function/capture/restore_captures.py create mode 100644 tensorflow/core/function/polymorphism/function_type.proto rename tensorflow/core/function/{ => runtime_client}/BUILD (97%) rename tensorflow/core/function/{ => runtime_client}/runtime_client.cc (99%) rename tensorflow/core/function/{ => runtime_client}/runtime_client.h (93%) rename tensorflow/core/function/{ => runtime_client}/runtime_client.py (95%) rename tensorflow/core/function/{ => runtime_client}/runtime_client_pybind.cc (97%) rename tensorflow/core/function/{ => runtime_client}/runtime_client_test.cc (99%) rename 
tensorflow/core/function/{ => runtime_client}/runtime_client_test.py (98%) create mode 100644 tensorflow/core/graph/zen_graph_util.h create mode 100644 tensorflow/core/ir/importexport/tests/graphdef_to_mlir/invalid_generic_function_named_edge_index.pbtxt create mode 100644 tensorflow/core/kernels/cast_op_impl_float8.cc create mode 100644 tensorflow/core/kernels/collective_nccl_all_to_all.cc create mode 100644 tensorflow/core/kernels/collective_nccl_all_to_all.h create mode 100644 tensorflow/core/kernels/conv_2d_gpu_bfloat16.cu.cc create mode 100644 tensorflow/core/kernels/data/experimental/distributed_save_op.cc create mode 100644 tensorflow/core/kernels/data/experimental/distributed_save_op.h create mode 100644 tensorflow/core/kernels/depthwise_conv_op_gpu_bfloat16.cu.cc create mode 100644 tensorflow/core/kernels/eigen_cuboid_convolutions_test.cc create mode 100644 tensorflow/core/kernels/mkl/mkl_kernel_util.cc create mode 100644 tensorflow/core/kernels/mkl/onednn_nn_ops_benchmark.cc create mode 100644 tensorflow/core/kernels/mlir_generated/gpu_unary_ops_large_tensor_test.cc create mode 100644 tensorflow/core/kernels/reduction_ops_gpu_bfloat16.cu.cc create mode 100644 tensorflow/core/kernels/scan_ops_gpu_bfloat16.cu.cc create mode 100644 tensorflow/core/kernels/sync_ops.cc rename tensorflow/{compiler/xla/service/hlo_casting_utils.h => core/kernels/tile_functor_gpu_bfloat16.cu.cc} (60%) rename tensorflow/{stream_executor/platform/default/initialize.h => core/kernels/topk_op_gpu_bfloat16.cu.cc} (66%) create mode 100644 tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_add_op.cc create mode 100644 tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_add_op_test.cc create mode 100644 tensorflow/core/lib/gif/testdata/3g_multiframe.gif create mode 100644 tensorflow/core/ops/compat/ops_history_v2/CollectiveReduceScatterV2.pbtxt create mode 100644 tensorflow/core/ops/compat/ops_history_v2/ComputeDedupDataTupleMask.pbtxt create mode 100644 tensorflow/core/ops/compat/ops_history_v2/Conv2DBackpropFilterV2.pbtxt create mode 100644 tensorflow/core/ops/compat/ops_history_v2/Conv2DBackpropInputV2.pbtxt create mode 100644 tensorflow/core/ops/compat/ops_history_v2/DistributedSave.pbtxt create mode 100644 tensorflow/core/ops/compat/ops_history_v2/MergeDedupData.pbtxt create mode 100644 tensorflow/core/ops/compat/ops_history_v2/RandomDatasetV2.pbtxt create mode 100644 tensorflow/core/ops/compat/ops_history_v2/SegmentProdV2.pbtxt create mode 100644 tensorflow/core/ops/compat/ops_history_v2/SegmentSumV2.pbtxt create mode 100644 tensorflow/core/ops/compat/ops_history_v2/SplitDedupData.pbtxt create mode 100644 tensorflow/core/ops/compat/ops_history_v2/SyncDevice.pbtxt create mode 100644 tensorflow/core/ops/compat/ops_history_v2/TPUPartitionedInputV2.pbtxt create mode 100644 tensorflow/core/ops/compat/ops_history_v2/TPUPartitionedOutputV2.pbtxt create mode 100644 tensorflow/core/ops/compat/ops_history_v2/UniformQuantizedAdd.pbtxt create mode 100644 tensorflow/core/ops/compat/ops_history_v2/UniformQuantizedConvolution.pbtxt create mode 100644 tensorflow/core/ops/optional_ops.cc create mode 100644 tensorflow/core/ops/sync_ops.cc create mode 100644 tensorflow/core/platform/default/build_config/BUILD create mode 100644 tensorflow/core/platform/float8.h delete mode 100644 tensorflow/core/profiler/backends/cpu/host_tracer.cc delete mode 100644 tensorflow/core/profiler/builds/oss/BUILD create mode 100644 tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.cc create mode 100644 
tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.h create mode 100644 tensorflow/core/profiler/convert/xplane_to_tool_names_test.cc create mode 100644 tensorflow/core/profiler/protobuf/topology.proto create mode 100644 tensorflow/core/profiler/utils/hlo_module_map.cc create mode 100644 tensorflow/core/profiler/utils/hlo_module_map.h create mode 100644 tensorflow/core/runtime_fallback/kernel/gpurt_kernels.cc delete mode 100644 tensorflow/core/tfrt/eager/c_api_tfrt_distributed_impl.cc delete mode 100644 tensorflow/core/tfrt/eager/c_api_tfrt_distributed_interface.h create mode 100644 tensorflow/core/tfrt/utils/device_variables_table.h create mode 100644 tensorflow/core/tfrt/utils/gpu_variables_table.h create mode 100644 tensorflow/core/transforms/remapper/tests/onednn_mish.mlir delete mode 100644 tensorflow/core/util/autotune_maps/autotune_maps_utils.cc delete mode 100644 tensorflow/core/util/autotune_maps/autotune_maps_utils.h rename tensorflow/{dtensor/cc/default_parallel_executor.h => core/util/fake_clock_env.cc} (63%) create mode 100644 tensorflow/core/util/fake_clock_env.h create mode 100644 tensorflow/core/util/fake_clock_env_test.cc create mode 100644 tensorflow/core/util/zen_util.h create mode 100644 tensorflow/dtensor/build_defs.bzl create mode 100644 tensorflow/dtensor/cc/mesh_type.h create mode 100644 tensorflow/dtensor/cc/tensor_with_layout.h create mode 100644 tensorflow/dtensor/cc/xla_spmd/layout_to_xla_sharding.cc create mode 100644 tensorflow/dtensor/cc/xla_spmd/layout_to_xla_sharding.h create mode 100644 tensorflow/dtensor/mlir/dtensor_remove_dtensorlayout.cc create mode 100644 tensorflow/dtensor/mlir/dtensor_replace_auxiliary_layout_op.cc create mode 100644 tensorflow/dtensor/mlir/dtensor_set_hlo_sharding.cc create mode 100644 tensorflow/dtensor/mlir/expansions/iterator_spmd_expander.cc create mode 100644 tensorflow/dtensor/mlir/expansions/iterator_spmd_expander.h create mode 100644 tensorflow/dtensor/mlir/expansions/optional_spmd_expander.cc create mode 100644 tensorflow/dtensor/mlir/expansions/optional_spmd_expander.h create mode 100644 tensorflow/dtensor/mlir/expansions/unsupported_op_spmd_expander.cc create mode 100644 tensorflow/dtensor/mlir/expansions/unsupported_op_spmd_expander.h create mode 100644 tensorflow/dtensor/mlir/tests/BUILD create mode 100644 tensorflow/dtensor/mlir/tests/annotate_global_shape.mlir create mode 100644 tensorflow/dtensor/mlir/tests/cluster_function_conversion.mlir create mode 100644 tensorflow/dtensor/mlir/tests/constant_folding.mlir create mode 100644 tensorflow/dtensor/mlir/tests/cpu_layout.pbtxt create mode 100644 tensorflow/dtensor/mlir/tests/designate_resource_handle_mesh.mlir create mode 100644 tensorflow/dtensor/mlir/tests/device_mesh_cluster_coarsening.mlir create mode 100644 tensorflow/dtensor/mlir/tests/dtensor_all_gather.mlir create mode 100644 tensorflow/dtensor/mlir/tests/dtensor_all_scatter.mlir create mode 100644 tensorflow/dtensor/mlir/tests/dtensor_allreduce_combine_optimization.mlir create mode 100644 tensorflow/dtensor/mlir/tests/dtensor_allreduce_lowering.mlir create mode 100644 tensorflow/dtensor/mlir/tests/dtensor_allreduce_scatter_optimization.mlir create mode 100644 tensorflow/dtensor/mlir/tests/dtensor_allreduce_sum_optimization.mlir create mode 100644 tensorflow/dtensor/mlir/tests/dtensor_embedding_checkpoint.mlir create mode 100644 tensorflow/dtensor/mlir/tests/dtensor_embedding_v2.mlir create mode 100644 tensorflow/dtensor/mlir/tests/dtensor_mixed_precision_reduce.mlir create mode 100644 
tensorflow/dtensor/mlir/tests/dtensor_mlir_opt_main.cc create mode 100644 tensorflow/dtensor/mlir/tests/dtensor_reduce_scatter_lowering.mlir create mode 100644 tensorflow/dtensor/mlir/tests/dtensor_remove_dtensorlayout.mlir create mode 100644 tensorflow/dtensor/mlir/tests/dtensor_replace_auxiliary_layout_op.mlir create mode 100644 tensorflow/dtensor/mlir/tests/dtensor_set_hlo_sharding.mlir create mode 100644 tensorflow/dtensor/mlir/tests/dtensor_set_hlo_sharding_default.mlir create mode 100644 tensorflow/dtensor/mlir/tests/dtensor_xla_spmd_integration.mlir create mode 100644 tensorflow/dtensor/mlir/tests/elide_identity_before_copy_to_mesh.mlir create mode 100644 tensorflow/dtensor/mlir/tests/embedding_optimizer.mlir create mode 100644 tensorflow/dtensor/mlir/tests/function_renaming.mlir create mode 100644 tensorflow/dtensor/mlir/tests/handle_cross_cluster_dependencies.mlir create mode 100644 tensorflow/dtensor/mlir/tests/handle_sparsetensors.mlir create mode 100644 tensorflow/dtensor/mlir/tests/layout_propagation.mlir create mode 100644 tensorflow/dtensor/mlir/tests/layout_propagation_v2.mlir create mode 100644 tensorflow/dtensor/mlir/tests/lower_send_recv.mlir create mode 100644 tensorflow/dtensor/mlir/tests/merge_clusters.mlir create mode 100644 tensorflow/dtensor/mlir/tests/mesh_propagation.mlir create mode 100644 tensorflow/dtensor/mlir/tests/move_compilation_to_host.mlir create mode 100644 tensorflow/dtensor/mlir/tests/op_to_device_cluster.mlir create mode 100644 tensorflow/dtensor/mlir/tests/propagate_default_layout.mlir create mode 100644 tensorflow/dtensor/mlir/tests/propagate_device_id_to_function.mlir create mode 100644 tensorflow/dtensor/mlir/tests/restore_and_assign.mlir create mode 100644 tensorflow/dtensor/mlir/tests/restore_shape_inference.mlir create mode 100644 tensorflow/dtensor/mlir/tests/set_default_sharding.mlir create mode 100644 tensorflow/dtensor/mlir/tests/sparse_expansion.mlir create mode 100644 tensorflow/dtensor/mlir/tests/spmd_batchparallel.mlir create mode 100644 tensorflow/dtensor/mlir/tests/spmd_concat.mlir create mode 100644 tensorflow/dtensor/mlir/tests/spmd_conv.mlir create mode 100644 tensorflow/dtensor/mlir/tests/spmd_dtensor_ops.mlir create mode 100644 tensorflow/dtensor/mlir/tests/spmd_einsum.mlir create mode 100644 tensorflow/dtensor/mlir/tests/spmd_embedding.mlir create mode 100644 tensorflow/dtensor/mlir/tests/spmd_expansion.mlir create mode 100644 tensorflow/dtensor/mlir/tests/spmd_fill.mlir create mode 100644 tensorflow/dtensor/mlir/tests/spmd_io_ops.mlir create mode 100644 tensorflow/dtensor/mlir/tests/spmd_iterator.mlir create mode 100644 tensorflow/dtensor/mlir/tests/spmd_matmul.mlir create mode 100644 tensorflow/dtensor/mlir/tests/spmd_metadata.mlir create mode 100644 tensorflow/dtensor/mlir/tests/spmd_random.mlir create mode 100644 tensorflow/dtensor/mlir/tests/spmd_reduction.mlir create mode 100644 tensorflow/dtensor/mlir/tests/spmd_save_restore.mlir create mode 100644 tensorflow/dtensor/mlir/tests/spmd_segment_sum.mlir create mode 100644 tensorflow/dtensor/mlir/tests/spmd_slice.mlir create mode 100644 tensorflow/dtensor/mlir/tests/spmd_softmax_loss.mlir create mode 100644 tensorflow/dtensor/mlir/tests/spmd_squeeze.mlir create mode 100644 tensorflow/dtensor/mlir/tests/spmd_tile.mlir create mode 100644 tensorflow/dtensor/mlir/tests/spmd_var_handle.mlir create mode 100644 tensorflow/dtensor/mlir/tests/tf_dtensor_ops.mlir create mode 100644 tensorflow/dtensor/mlir/tests/tpu_add_resource_device_attribute.mlir create mode 100644 
tensorflow/dtensor/mlir/tests/tpu_integration.mlir create mode 100644 tensorflow/dtensor/mlir/tests/undo_merge_const_across_mesh.mlir create mode 100644 tensorflow/dtensor/mlir/tests/update_tpu_metadata.mlir create mode 100644 tensorflow/dtensor/python/numpy_util.py create mode 100644 tensorflow/dtensor/python/tests/BUILD create mode 100644 tensorflow/dtensor/python/tests/collective_test.py create mode 100644 tensorflow/dtensor/python/tests/config_test.py create mode 100644 tensorflow/dtensor/python/tests/multi_client_test.py create mode 100644 tensorflow/dtensor/python/tests/numpy_util_test.py create mode 100644 tensorflow/dtensor/python/tests/spmd_test.py create mode 100644 tensorflow/dtensor/python/tests/test_backend_name.py create mode 100644 tensorflow/dtensor/python/tests/test_backend_util.py create mode 100644 tensorflow/dtensor/python/tests/test_util.py create mode 100644 tensorflow/dtensor/python/tests/test_util_ops.py create mode 100644 tensorflow/dtensor/tests/layout_to_xla_sharding_test.cc create mode 100644 tensorflow/examples/custom_ops_doc/multiplex_1/README.md create mode 100644 tensorflow/go/core/framework/BUILD create mode 100644 tensorflow/go/core/protobuf/BUILD create mode 100644 tensorflow/go/genop/BUILD create mode 100644 tensorflow/go/genop/internal/BUILD create mode 100644 tensorflow/go/op/BUILD create mode 100644 tensorflow/go/stream_executor/BUILD create mode 100644 tensorflow/go/testdata/label_image/grace_hopper.jpg create mode 100644 tensorflow/go/testdata/saved_model/half_plus_two/00000123/assets/foo.txt create mode 100644 tensorflow/go/testdata/saved_model/half_plus_two/00000123/saved_model.pb create mode 100644 tensorflow/go/testdata/saved_model/half_plus_two/00000123/variables/variables.data-00000-of-00001 create mode 100644 tensorflow/go/testdata/saved_model/half_plus_two/00000123/variables/variables.index create mode 100644 tensorflow/go/tsl/profiler/protobuf/BUILD create mode 100644 tensorflow/go/tsl/protobuf/BUILD create mode 100644 tensorflow/java/LEGACY.md create mode 100644 tensorflow/lite/cmake/DownloadPThreadPool.cmake create mode 100644 tensorflow/lite/core/async/BUILD create mode 100644 tensorflow/lite/core/async/README.md create mode 100644 tensorflow/lite/core/async/async_kernel_internal.h create mode 100644 tensorflow/lite/core/async/async_signature_runner.cc create mode 100644 tensorflow/lite/core/async/async_signature_runner.h create mode 100644 tensorflow/lite/core/async/async_signature_runner_test.cc create mode 100644 tensorflow/lite/core/async/async_subgraph.cc create mode 100644 tensorflow/lite/core/async/async_subgraph.h create mode 100644 tensorflow/lite/core/async/async_subgraph_test.cc create mode 100644 tensorflow/lite/core/async/backend_async_kernel_interface.cc create mode 100644 tensorflow/lite/core/async/backend_async_kernel_interface.h create mode 100644 tensorflow/lite/core/async/backend_async_kernel_interface_test.cc create mode 100644 tensorflow/lite/core/async/c/BUILD create mode 100644 tensorflow/lite/core/async/c/task.cc create mode 100644 tensorflow/lite/core/async/c/task.h create mode 100644 tensorflow/lite/core/async/c/task_test.cc create mode 100644 tensorflow/lite/core/async/c/types.h rename tensorflow/{compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/tiling_interface.h => lite/core/async/common.h} (61%) create mode 100644 tensorflow/lite/core/async/interop/BUILD create mode 100644 tensorflow/lite/core/async/interop/attribute_keys.h create mode 100644 
tensorflow/lite/core/async/interop/attribute_map_internal.cc create mode 100644 tensorflow/lite/core/async/interop/attribute_map_internal.h create mode 100644 tensorflow/lite/core/async/interop/attribute_map_internal_test.cc create mode 100644 tensorflow/lite/core/async/interop/c/BUILD create mode 100644 tensorflow/lite/core/async/interop/c/attribute_map.cc create mode 100644 tensorflow/lite/core/async/interop/c/attribute_map.h create mode 100644 tensorflow/lite/core/async/interop/c/attribute_map_test.cc create mode 100644 tensorflow/lite/core/async/interop/c/constants.cc create mode 100644 tensorflow/lite/core/async/interop/c/constants.h create mode 100644 tensorflow/lite/core/async/interop/c/types.cc create mode 100644 tensorflow/lite/core/async/interop/c/types.h create mode 100644 tensorflow/lite/core/async/interop/c/types_test.cc create mode 100644 tensorflow/lite/core/async/interop/reconcile_fns.cc create mode 100644 tensorflow/lite/core/async/interop/reconcile_fns.h create mode 100644 tensorflow/lite/core/async/interop/reconcile_fns_test.cc create mode 100644 tensorflow/lite/core/async/interop/variant.cc create mode 100644 tensorflow/lite/core/async/interop/variant.h create mode 100644 tensorflow/lite/core/async/interop/variant_test.cc create mode 100644 tensorflow/lite/core/async/task_internal.cc create mode 100644 tensorflow/lite/core/async/task_internal.h create mode 100644 tensorflow/lite/core/async/task_internal_test.cc create mode 100644 tensorflow/lite/core/async/testing/BUILD create mode 100644 tensorflow/lite/core/async/testing/mock_async_kernel.h create mode 100644 tensorflow/lite/core/async/testing/test_backend.cc create mode 100644 tensorflow/lite/core/async/testing/test_backend.h create mode 100644 tensorflow/lite/core/c/builtin_op_data.h rename tensorflow/lite/{ => core}/c/builtin_op_data_test.cc (98%) rename tensorflow/lite/{ => core}/c/c_api_experimental.cc (95%) create mode 100644 tensorflow/lite/core/c/c_api_experimental.h rename tensorflow/lite/{ => core}/c/c_api_experimental_test.cc (95%) create mode 100644 tensorflow/lite/core/c/c_api_types.h rename tensorflow/lite/{ => core}/c/common.cc (81%) create mode 100644 tensorflow/lite/core/c/common.h rename tensorflow/lite/{ => core}/c/common_test.cc (50%) create mode 100644 tensorflow/lite/core/experimental/acceleration/configuration/c/delegate_plugin.h rename tensorflow/lite/{ => core}/experimental/acceleration/configuration/c/gpu_plugin.cc (94%) create mode 100644 tensorflow/lite/core/experimental/acceleration/configuration/c/gpu_plugin.h rename tensorflow/lite/{ => core}/experimental/acceleration/configuration/c/gpu_plugin_test.cc (94%) rename tensorflow/lite/{ => core}/experimental/acceleration/configuration/c/nnapi_plugin.cc (94%) create mode 100644 tensorflow/lite/core/experimental/acceleration/configuration/c/nnapi_plugin.h rename tensorflow/lite/{ => core}/experimental/acceleration/configuration/c/nnapi_plugin_test.cc (95%) rename tensorflow/lite/core/experimental/acceleration/configuration/c/{vendor_delegate.h => stable_delegate.h} (71%) rename tensorflow/lite/{ => core}/experimental/acceleration/configuration/c/xnnpack_plugin.cc (92%) create mode 100644 tensorflow/lite/core/experimental/acceleration/configuration/c/xnnpack_plugin.h rename tensorflow/lite/{ => core}/experimental/acceleration/configuration/c/xnnpack_plugin_test.cc (95%) create mode 100644 tensorflow/lite/core/experimental/acceleration/configuration/stable_delegate_registry.cc create mode 100644 
tensorflow/lite/core/experimental/acceleration/configuration/stable_delegate_registry.h create mode 100644 tensorflow/lite/core/experimental/acceleration/configuration/stable_delegate_registry_test.cc rename tensorflow/{stream_executor/tpu/c_api_conversions.h => lite/core/shims/c/c_api_opaque.h} (76%) create mode 100644 tensorflow/lite/core/special_rules.bzl create mode 100644 tensorflow/lite/delegates/external/external_delegate_interface.h create mode 100644 tensorflow/lite/delegates/opaque_delegate_test.cc create mode 100755 tensorflow/lite/delegates/utils/dummy_delegate/external_delegate_test.sh create mode 100644 tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/BUILD create mode 100644 tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/README.md create mode 100644 tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_stable_delegate.cc create mode 100644 tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_stable_delegate.h create mode 100644 tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_stable_delegate_external.cc create mode 100644 tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_stable_delegate_external_test.cc create mode 100644 tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_stable_delegate_test.cc create mode 100644 tensorflow/lite/delegates/utils/experimental/stable_delegate/BUILD create mode 100644 tensorflow/lite/delegates/utils/experimental/stable_delegate/delegate_loader.cc create mode 100644 tensorflow/lite/delegates/utils/experimental/stable_delegate/delegate_loader.h create mode 100644 tensorflow/lite/delegates/utils/experimental/stable_delegate/delegate_loader_test.cc create mode 100644 tensorflow/lite/delegates/utils/experimental/stable_delegate/stable_delegate_interface.h create mode 100644 tensorflow/lite/delegates/utils/experimental/stable_delegate/stable_xnnpack_delegate.cc create mode 100644 tensorflow/lite/delegates/utils/experimental/stable_delegate/test_xnnpack_settings.json create mode 100644 tensorflow/lite/delegates/utils/experimental/stable_delegate/tflite_settings_json_parser.cc create mode 100644 tensorflow/lite/delegates/utils/experimental/stable_delegate/tflite_settings_json_parser.h create mode 100644 tensorflow/lite/delegates/utils/experimental/stable_delegate/tflite_settings_json_parser_test.cc create mode 100644 tensorflow/lite/delegates/utils/experimental/stable_delegate/version_script.lds create mode 100644 tensorflow/lite/delegates/xnnpack/signed_quantized_reshape_test.cc create mode 100644 tensorflow/lite/delegates/xnnpack/unsigned_quantized_reshape_test.cc create mode 100644 tensorflow/lite/experimental/acceleration/configuration/build_defs.bzl rename tensorflow/lite/experimental/acceleration/configuration/c/{vendor_delegate.h => stable_delegate.h} (90%) mode change 100755 => 100644 tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h rename tensorflow/{compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/CMakeLists.txt => lite/experimental/acceleration/configuration/prev_is_different_than_current_test.sh} (57%) mode change 100644 => 100755 create mode 100644 tensorflow/lite/experimental/acceleration/configuration/stable_delegate_plugin.cc create mode 100644 tensorflow/lite/experimental/acceleration/configuration/stable_delegate_plugin.h create mode 100644 tensorflow/lite/experimental/acceleration/configuration/stable_delegate_plugin_test.cc create 
mode 100644 tensorflow/lite/experimental/acceleration/configuration/testdata/configuration.old.fbs create mode 100644 tensorflow/lite/experimental/acceleration/configuration/testdata/configuration.proto_prev create mode 100644 tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api_types.h create mode 100644 tensorflow/lite/experimental/acceleration/mini_benchmark/gpu_module_plugin.cc create mode 100644 tensorflow/lite/experimental/acceleration/mini_benchmark/gpu_module_plugin.h create mode 100644 tensorflow/lite/generate-pc.sh create mode 100644 tensorflow/lite/java/src/main/java/org/tensorflow/lite/acceleration/ValidatedAccelerationConfig.java create mode 100644 tensorflow/lite/kernels/builtin_ops_list.inc create mode 100644 tensorflow/lite/kernels/shim/test_op/tmpl_op.h create mode 100644 tensorflow/lite/kernels/shim/test_op/tmpl_tf_op.cc create mode 100644 tensorflow/lite/kernels/shim/test_op/tmpl_tf_op.h create mode 100644 tensorflow/lite/kernels/shim/test_op/tmpl_tf_op_test.cc create mode 100644 tensorflow/lite/kernels/shim/test_op/tmpl_tflite_op.cc rename tensorflow/{core/tfrt/eager/c_api_tfrt_distributed_impl.h => lite/kernels/shim/test_op/tmpl_tflite_op.h} (52%) create mode 100644 tensorflow/lite/kernels/shim/test_op/tmpl_tflite_op_test.cc create mode 100644 tensorflow/lite/kernels/shim/tflite_op_wrapper.h create mode 100644 tensorflow/lite/kernels/shim/tflite_op_wrapper_test.cc create mode 100644 tensorflow/lite/profiling/telemetry/c/BUILD create mode 100644 tensorflow/lite/profiling/telemetry/c/profiler.h create mode 100644 tensorflow/lite/profiling/telemetry/c/telemetry_setting.h create mode 100644 tensorflow/lite/profiling/telemetry/c/telemetry_setting_internal.cc create mode 100644 tensorflow/lite/profiling/telemetry/c/telemetry_setting_internal.h create mode 100644 tensorflow/lite/profiling/telemetry/profiler_test.cc delete mode 100644 tensorflow/lite/profiling/telemetry/telemetry_settings.h create mode 100644 tensorflow/lite/schema/conversion_metadata_generated.h rename tensorflow/{stream_executor/tpu/tpu_executor_interface.h => lite/test_util.h} (62%) delete mode 100644 tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/jni/accuracy_benchmark.cc delete mode 100644 tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/jni/accuracy_benchmark.h delete mode 100644 tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/jni/accuracy_benchmark_test.cc delete mode 100644 tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/jni/delegate_performance_benchmark_jni.cc delete mode 100644 tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/jni/latency_benchmark.cc create mode 100644 tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/models/BUILD rename tensorflow/{compiler/xla/mlir_hlo/tosa/include/mhlo_tosa/Transforms/CMakeLists.txt => lite/tools/benchmark/experimental/delegate_performance/android/models/mobilenet_v1_1.0_224.textproto} (55%) rename tensorflow/{compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/thlo/transforms/CMakeLists.txt => lite/tools/benchmark/experimental/delegate_performance/android/models/mobilenet_v1_1.0_224_quant.textproto} (55%) create mode 100644 tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/proto.bzl create mode 100644 tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/proto/BUILD rename tensorflow/{compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/thlo/CMakeLists.txt => 
lite/tools/benchmark/experimental/delegate_performance/android/proto/default_latency_criteria.textproto} (55%) create mode 100644 tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/proto/delegate_performance.proto create mode 100644 tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkAccuracyActivity.java create mode 100644 tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkAccuracyImpl.java create mode 100644 tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkLatencyActivity.java create mode 100644 tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkLatencyImpl.java create mode 100644 tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/java/org/tensorflow/lite/benchmark/delegateperformance/BenchmarkResultType.java create mode 100644 tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/java/org/tensorflow/lite/benchmark/delegateperformance/CsvWriter.java create mode 100644 tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/java/org/tensorflow/lite/benchmark/delegateperformance/DelegatePerformanceBenchmark.java create mode 100644 tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/java/org/tensorflow/lite/benchmark/delegateperformance/TfLiteSettingsListEntry.java create mode 100644 tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/BUILD create mode 100644 tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/accuracy_benchmark.cc create mode 100644 tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/accuracy_benchmark.h create mode 100644 tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/delegate_performance_benchmark_jni.cc create mode 100644 tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/latency_benchmark.cc create mode 100644 tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/latency_benchmark.h rename tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/{jni/latency_benchmark.h => src/main/native/status_codes.h} (65%) delete mode 100644 tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/org/tensorflow/lite/benchmark/delegate_performance/BenchmarkAccuracyActivity.java delete mode 100644 tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/org/tensorflow/lite/benchmark/delegate_performance/BenchmarkLatencyActivity.java delete mode 100644 tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/org/tensorflow/lite/benchmark/delegate_performance/DelegatePerformanceBenchmark.java create mode 100644 tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/test/native/BUILD create mode 100644 tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/test/native/accuracy_benchmark_test.cc create mode 100644 
tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/test/native/latency_benchmark_test.cc create mode 100644 tensorflow/lite/tools/cmake/modules/Findgoogle_benchmark.cmake create mode 100644 tensorflow/lite/tools/delegates/experimental/stable_delegate/BUILD create mode 100644 tensorflow/lite/tools/delegates/experimental/stable_delegate/stable_delegate_provider.cc create mode 100644 tensorflow/lite/tools/delegates/experimental/stable_delegate/stable_delegate_provider_test.cc create mode 100755 tensorflow/lite/tools/delegates/experimental/stable_delegate/test_invalid_settings.json create mode 100755 tensorflow/lite/tools/delegates/experimental/stable_delegate/test_missing_delegate_path_settings.json create mode 100755 tensorflow/lite/tools/delegates/experimental/stable_delegate/test_missing_stable_delegate_settings.json create mode 100755 tensorflow/lite/tools/delegates/experimental/stable_delegate/test_sample_stable_delegate_settings.json create mode 100755 tensorflow/lite/tools/delegates/experimental/stable_delegate/test_stable_xnnpack_settings.json create mode 100644 tensorflow/lite/tools/evaluation/tasks/ios/BUILD.apple create mode 100644 tensorflow/lite/tools/evaluation/tasks/ios/README.md create mode 100644 tensorflow/lite/tools/evaluation/tasks/ios/TFLiteEvaluation/TFLiteEvaluation.xcodeproj/project.pbxproj create mode 100644 tensorflow/lite/tools/evaluation/tasks/ios/TFLiteEvaluation/TFLiteEvaluation/AppDelegate.h create mode 100644 tensorflow/lite/tools/evaluation/tasks/ios/TFLiteEvaluation/TFLiteEvaluation/AppDelegate.m create mode 100644 tensorflow/lite/tools/evaluation/tasks/ios/TFLiteEvaluation/TFLiteEvaluation/Assets.xcassets/AccentColor.colorset/Contents.json create mode 100644 tensorflow/lite/tools/evaluation/tasks/ios/TFLiteEvaluation/TFLiteEvaluation/Assets.xcassets/AppIcon.appiconset/Contents.json create mode 100644 tensorflow/lite/tools/evaluation/tasks/ios/TFLiteEvaluation/TFLiteEvaluation/Assets.xcassets/Contents.json create mode 100644 tensorflow/lite/tools/evaluation/tasks/ios/TFLiteEvaluation/TFLiteEvaluation/Base.lproj/Main.storyboard create mode 100644 tensorflow/lite/tools/evaluation/tasks/ios/TFLiteEvaluation/TFLiteEvaluation/EvaluationViewController.h create mode 100644 tensorflow/lite/tools/evaluation/tasks/ios/TFLiteEvaluation/TFLiteEvaluation/EvaluationViewController.mm create mode 100644 tensorflow/lite/tools/evaluation/tasks/ios/TFLiteEvaluation/TFLiteEvaluation/Info.plist create mode 100644 tensorflow/lite/tools/evaluation/tasks/ios/TFLiteEvaluation/TFLiteEvaluation/evaluation_data/evaluation_params.json create mode 100644 tensorflow/lite/tools/evaluation/tasks/ios/TFLiteEvaluation/TFLiteEvaluation/main.m create mode 100644 tensorflow/lite/tools/evaluation/tasks/ios/build_evaluation_framework.sh create mode 100644 tensorflow/lite/tools/evaluation/tasks/task_executor_c_api.cc create mode 100644 tensorflow/lite/tools/evaluation/tasks/task_executor_c_api.h rename tensorflow/lite/{experimental/acceleration/mini_benchmark => tools}/model_loader.cc (75%) rename tensorflow/lite/{experimental/acceleration/mini_benchmark => tools}/model_loader.h (79%) rename tensorflow/lite/{experimental/acceleration/mini_benchmark => tools}/model_loader_test.cc (74%) create mode 100644 tensorflow/python/autograph/utils/type_registry.py delete mode 100644 tensorflow/python/checkpoint/__init__.py create mode 100644 tensorflow/python/checkpoint/async_checkpoint_helper.py create mode 100644 tensorflow/python/checkpoint/tensor_callable.py create mode 100644 
tensorflow/python/checkpoint/tensor_callable_test.py create mode 100644 tensorflow/python/data/experimental/kernel_tests/distributed_save_test.py create mode 100644 tensorflow/python/data/experimental/kernel_tests/index_shuffle_test.py create mode 100644 tensorflow/python/data/experimental/kernel_tests/service/snapshot_ft_test.py create mode 100644 tensorflow/python/data/experimental/ops/distributed_save_op.py create mode 100644 tensorflow/python/data/kernel_tests/choose_from_datasets_test.py rename tensorflow/python/data/{experimental/kernel_tests/directed_interleave_dataset_test.py => kernel_tests/sample_from_datasets_test.py} (59%) create mode 100644 tensorflow/python/data/ops/cache_op.py create mode 100644 tensorflow/python/data/ops/choose_from_datasets_op.py create mode 100644 tensorflow/python/data/ops/concatenate_op.py create mode 100644 tensorflow/python/data/ops/dataset_autograph.py create mode 100644 tensorflow/python/data/ops/debug_mode.py create mode 100644 tensorflow/python/data/ops/directed_interleave_op.py create mode 100644 tensorflow/python/data/ops/flat_map_op.py create mode 100644 tensorflow/python/data/ops/from_generator_op.py create mode 100644 tensorflow/python/data/ops/from_sparse_tensor_slices_op.py create mode 100644 tensorflow/python/data/ops/from_tensors_op.py create mode 100644 tensorflow/python/data/ops/group_by_window_op.py create mode 100644 tensorflow/python/data/ops/interleave_op.py create mode 100644 tensorflow/python/data/ops/iterator_autograph.py create mode 100644 tensorflow/python/data/ops/map_op.py create mode 100644 tensorflow/python/data/ops/prefetch_op.py create mode 100644 tensorflow/python/data/ops/random_op.py create mode 100644 tensorflow/python/data/ops/range_op.py create mode 100644 tensorflow/python/data/ops/repeat_op.py create mode 100644 tensorflow/python/data/ops/sample_from_datasets_op.py create mode 100644 tensorflow/python/data/ops/scan_op.py create mode 100644 tensorflow/python/data/ops/shard_op.py create mode 100644 tensorflow/python/data/ops/shuffle_op.py create mode 100644 tensorflow/python/data/ops/skip_op.py create mode 100644 tensorflow/python/data/ops/snapshot_op.py create mode 100644 tensorflow/python/data/ops/take_op.py create mode 100644 tensorflow/python/data/ops/take_while_op.py create mode 100644 tensorflow/python/data/ops/unbatch_op.py create mode 100644 tensorflow/python/data/ops/unique_op.py create mode 100644 tensorflow/python/data/ops/window_op.py create mode 100644 tensorflow/python/distribute/coordinator/get_task_states_test.py create mode 100644 tensorflow/python/distribute/experimental/dtensor_util.py create mode 100644 tensorflow/python/distribute/experimental/dtensor_util_test.py create mode 100644 tensorflow/python/distribute/experimental/mirrored_strategy.py create mode 100644 tensorflow/python/distribute/experimental/mirrored_strategy_test.py create mode 100644 tensorflow/python/distribute/failure_handling/preemption_watcher.py create mode 100644 tensorflow/python/eager/polymorphic_function/attributes.py create mode 100644 tensorflow/python/eager/polymorphic_function/compiler_ir.py create mode 100644 tensorflow/python/eager/polymorphic_function/compiler_ir_test.py create mode 100644 tensorflow/python/eager/polymorphic_function/saved_model_exported_concrete.py create mode 100644 tensorflow/python/framework/kythe_metadata.proto create mode 100644 tensorflow/python/framework/offset_counter.cc create mode 100644 tensorflow/python/framework/offset_counter_helper.cc rename 
tensorflow/{compiler/xla/service/hlo_domain_metadata.h => python/framework/offset_counter_helper.h} (58%) create mode 100644 tensorflow/python/framework/offset_counter_helper_test.cc create mode 100644 tensorflow/python/framework/op_reg_offset.proto create mode 100644 tensorflow/python/framework/python_op_gen_annotator.cc create mode 100644 tensorflow/python/framework/python_op_gen_annotator.h create mode 100644 tensorflow/python/framework/python_op_gen_annotator_test.cc create mode 100644 tensorflow/python/framework/summary_test_util.py create mode 100644 tensorflow/python/framework/test_ops.cu.cc rename tensorflow/python/{lib/core/numpy.cc => framework/test_ops.h} (64%) create mode 100644 tensorflow/python/framework/type_spec_registry.py delete mode 100644 tensorflow/python/lib/core/bfloat16.cc create mode 100644 tensorflow/python/lib/core/custom_casts_wrapper.cc rename tensorflow/python/lib/core/{bfloat16_test.py => custom_float_test.py} (73%) delete mode 100644 tensorflow/python/lib/core/float8_e4m3b11.cc delete mode 100644 tensorflow/python/lib/core/float8_e4m3b11.h rename tensorflow/{stream_executor/gpu/gpu_diagnostics.h => python/lib/core/float8_wrapper.cc} (63%) delete mode 100644 tensorflow/python/lib/core/numpy.h create mode 100644 tensorflow/python/profiler/internal/flops_registry_test.py create mode 100644 tensorflow/python/saved_model/fingerprinting.md create mode 100644 tensorflow/python/saved_model/fingerprinting.py create mode 100644 tensorflow/python/saved_model/path_helpers.py create mode 100644 tensorflow/python/tpu/tpu_replication.py rename tensorflow/{compiler/xla/tools/interactive_graphviz_test.sh => python/types/data.py} (56%) mode change 100755 => 100644 delete mode 100644 tensorflow/python/util/memory.py create mode 100644 tensorflow/security/advisory/tfsa-2022-144.md create mode 100644 tensorflow/security/advisory/tfsa-2022-145.md create mode 100644 tensorflow/security/advisory/tfsa-2022-146.md create mode 100644 tensorflow/security/advisory/tfsa-2022-147.md create mode 100644 tensorflow/security/advisory/tfsa-2022-148.md create mode 100644 tensorflow/security/advisory/tfsa-2022-149.md create mode 100644 tensorflow/security/advisory/tfsa-2022-150.md create mode 100644 tensorflow/security/advisory/tfsa-2022-151.md create mode 100644 tensorflow/security/advisory/tfsa-2022-152.md create mode 100644 tensorflow/security/advisory/tfsa-2022-153.md create mode 100644 tensorflow/security/advisory/tfsa-2022-154.md create mode 100644 tensorflow/security/advisory/tfsa-2022-155.md create mode 100644 tensorflow/security/advisory/tfsa-2022-156.md create mode 100644 tensorflow/security/advisory/tfsa-2022-157.md create mode 100644 tensorflow/security/advisory/tfsa-2022-158.md create mode 100644 tensorflow/security/advisory/tfsa-2022-159.md create mode 100644 tensorflow/security/advisory/tfsa-2022-160.md create mode 100644 tensorflow/security/advisory/tfsa-2022-161.md create mode 100644 tensorflow/security/advisory/tfsa-2022-162.md create mode 100644 tensorflow/security/advisory/tfsa-2022-163.md create mode 100644 tensorflow/security/advisory/tfsa-2022-164.md create mode 100644 tensorflow/security/advisory/tfsa-2022-165.md create mode 100644 tensorflow/security/advisory/tfsa-2022-166.md create mode 100644 tensorflow/security/advisory/tfsa-2022-167.md create mode 100644 tensorflow/security/advisory/tfsa-2022-168.md create mode 100644 tensorflow/security/advisory/tfsa-2022-169.md create mode 100644 tensorflow/security/advisory/tfsa-2022-170.md create mode 100644 
tensorflow/security/fuzzing/cc/core/framework/BUILD create mode 100644 tensorflow/security/fuzzing/cc/core/framework/tensor_shape_domains.cc rename tensorflow/{compiler/xla/stream_executor/lib/error.h => security/fuzzing/cc/core/framework/tensor_shape_domains.h} (53%) create mode 100644 tensorflow/security/fuzzing/cc/core/function/BUILD create mode 100644 tensorflow/security/fuzzing/cc/core/function/runtime_client_fuzz.cc create mode 100644 tensorflow/security/fuzzing/cc/fuzz_domains.h create mode 100644 tensorflow/security/fuzzing/cc/ops/bincount_fuzz.cc create mode 100644 tensorflow/security/fuzzing/cc/ops/string_to_number_fuzz.cc delete mode 100644 tensorflow/stream_executor/BUILD delete mode 100644 tensorflow/stream_executor/allocator_stats.h delete mode 100644 tensorflow/stream_executor/blas.h delete mode 100644 tensorflow/stream_executor/cuda/BUILD delete mode 100644 tensorflow/stream_executor/cuda/cuda_activation.h delete mode 100644 tensorflow/stream_executor/cuda/cuda_blas.h delete mode 100644 tensorflow/stream_executor/cuda/cuda_diagnostics.h delete mode 100644 tensorflow/stream_executor/cuda/cuda_dnn.h delete mode 100644 tensorflow/stream_executor/cuda/cuda_driver.h delete mode 100644 tensorflow/stream_executor/cuda/cuda_event.h delete mode 100644 tensorflow/stream_executor/cuda/cuda_fft.h delete mode 100644 tensorflow/stream_executor/cuda/cuda_gpu_executor.h delete mode 100644 tensorflow/stream_executor/cuda/cuda_helpers.h delete mode 100644 tensorflow/stream_executor/cuda/cuda_kernel.h delete mode 100644 tensorflow/stream_executor/cuda/cuda_platform.h delete mode 100644 tensorflow/stream_executor/cuda/cuda_platform_id.h delete mode 100644 tensorflow/stream_executor/cuda/cuda_rng.h delete mode 100644 tensorflow/stream_executor/cuda/cuda_stream.h delete mode 100644 tensorflow/stream_executor/cuda/cuda_timer.h delete mode 100644 tensorflow/stream_executor/data_type.h delete mode 100644 tensorflow/stream_executor/device_description.h delete mode 100644 tensorflow/stream_executor/device_memory.h delete mode 100644 tensorflow/stream_executor/device_memory_allocator.h delete mode 100644 tensorflow/stream_executor/device_options.h delete mode 100644 tensorflow/stream_executor/dnn.h delete mode 100644 tensorflow/stream_executor/event.h delete mode 100644 tensorflow/stream_executor/executor_cache.h delete mode 100644 tensorflow/stream_executor/gpu/BUILD delete mode 100644 tensorflow/stream_executor/gpu/gpu_asm_opts.h delete mode 100644 tensorflow/stream_executor/gpu/gpu_driver.h delete mode 100644 tensorflow/stream_executor/gpu/gpu_event.h delete mode 100644 tensorflow/stream_executor/gpu/gpu_executor.h delete mode 100644 tensorflow/stream_executor/gpu/gpu_helpers.h delete mode 100644 tensorflow/stream_executor/gpu/gpu_kernel.h delete mode 100644 tensorflow/stream_executor/gpu/gpu_stream.h delete mode 100644 tensorflow/stream_executor/gpu/gpu_timer.h delete mode 100644 tensorflow/stream_executor/gpu/gpu_types.h delete mode 100644 tensorflow/stream_executor/gpu/redzone_allocator.h delete mode 100644 tensorflow/stream_executor/gpu_launch_dim.h delete mode 100644 tensorflow/stream_executor/host/BUILD delete mode 100644 tensorflow/stream_executor/host/host_gpu_executor.h delete mode 100644 tensorflow/stream_executor/host/host_platform.h delete mode 100644 tensorflow/stream_executor/host/host_platform_id.h delete mode 100644 tensorflow/stream_executor/host/host_stream.h delete mode 100644 tensorflow/stream_executor/host/host_timer.h delete mode 100644 
tensorflow/stream_executor/host_or_device_scalar.h delete mode 100644 tensorflow/stream_executor/kernel.h delete mode 100644 tensorflow/stream_executor/kernel_cache_config.h delete mode 100644 tensorflow/stream_executor/kernel_spec.h delete mode 100644 tensorflow/stream_executor/launch_dim.h delete mode 100644 tensorflow/stream_executor/lib/BUILD delete mode 100644 tensorflow/stream_executor/lib/array_slice.h delete mode 100644 tensorflow/stream_executor/lib/demangle.h delete mode 100644 tensorflow/stream_executor/lib/env.h delete mode 100644 tensorflow/stream_executor/lib/error.h delete mode 100644 tensorflow/stream_executor/lib/human_readable.h delete mode 100644 tensorflow/stream_executor/lib/initialize.h delete mode 100644 tensorflow/stream_executor/lib/mathutil.h delete mode 100644 tensorflow/stream_executor/lib/numbers.h delete mode 100644 tensorflow/stream_executor/lib/path.h delete mode 100644 tensorflow/stream_executor/lib/process_state.h delete mode 100644 tensorflow/stream_executor/lib/stacktrace.h delete mode 100644 tensorflow/stream_executor/lib/static_threadlocal.h delete mode 100644 tensorflow/stream_executor/lib/status.h delete mode 100644 tensorflow/stream_executor/lib/statusor.h delete mode 100644 tensorflow/stream_executor/lib/thread_options.h delete mode 100644 tensorflow/stream_executor/lib/threadpool.h delete mode 100644 tensorflow/stream_executor/module_spec.h delete mode 100644 tensorflow/stream_executor/multi_platform_manager.h delete mode 100644 tensorflow/stream_executor/platform.h delete mode 100644 tensorflow/stream_executor/platform/BUILD delete mode 100644 tensorflow/stream_executor/platform/default/BUILD delete mode 100644 tensorflow/stream_executor/platform/default/dso_loader.h delete mode 100644 tensorflow/stream_executor/platform/dso_loader.h delete mode 100644 tensorflow/stream_executor/platform/initialize.h delete mode 100644 tensorflow/stream_executor/platform/logging.h delete mode 100644 tensorflow/stream_executor/platform/platform.h delete mode 100644 tensorflow/stream_executor/platform/port.h delete mode 100644 tensorflow/stream_executor/plugin.h delete mode 100644 tensorflow/stream_executor/plugin_registry.h delete mode 100644 tensorflow/stream_executor/rng.h delete mode 100644 tensorflow/stream_executor/rocm/BUILD delete mode 100644 tensorflow/stream_executor/rocm/hipsolver_wrapper.h delete mode 100644 tensorflow/stream_executor/rocm/hipsparse_wrapper.h delete mode 100644 tensorflow/stream_executor/rocm/rocblas_wrapper.h delete mode 100644 tensorflow/stream_executor/rocm/rocm_activation.h delete mode 100644 tensorflow/stream_executor/rocm/rocm_blas.h delete mode 100644 tensorflow/stream_executor/rocm/rocm_diagnostics.h delete mode 100644 tensorflow/stream_executor/rocm/rocm_dnn.h delete mode 100644 tensorflow/stream_executor/rocm/rocm_driver_wrapper.h delete mode 100644 tensorflow/stream_executor/rocm/rocm_fft.h delete mode 100644 tensorflow/stream_executor/rocm/rocm_platform.h delete mode 100644 tensorflow/stream_executor/rocm/rocm_platform_id.h delete mode 100644 tensorflow/stream_executor/rocm/rocsolver_wrapper.h delete mode 100644 tensorflow/stream_executor/rocm/roctracer_wrapper.h delete mode 100644 tensorflow/stream_executor/scratch_allocator.h delete mode 100644 tensorflow/stream_executor/stream.h delete mode 100644 tensorflow/stream_executor/stream_executor.h delete mode 100644 tensorflow/stream_executor/stream_executor_internal.h delete mode 100644 tensorflow/stream_executor/stream_executor_pimpl.h delete mode 100644 
tensorflow/stream_executor/temporary_device_memory.h delete mode 100644 tensorflow/stream_executor/temporary_memory_manager.h delete mode 100644 tensorflow/stream_executor/tf_allocator_adapter.h delete mode 100644 tensorflow/stream_executor/timer.h delete mode 100644 tensorflow/stream_executor/tpu/BUILD delete mode 100644 tensorflow/stream_executor/tpu/c_api_decl.h delete mode 100644 tensorflow/stream_executor/tpu/c_api_defn.h delete mode 100644 tensorflow/stream_executor/tpu/noncopyable_buffer.h delete mode 100644 tensorflow/stream_executor/tpu/proto_helper.h delete mode 100644 tensorflow/stream_executor/tpu/status_helper.h delete mode 100644 tensorflow/stream_executor/tpu/tpu_event.h delete mode 100644 tensorflow/stream_executor/tpu/tpu_executable.h delete mode 100644 tensorflow/stream_executor/tpu/tpu_executable_interface.h delete mode 100644 tensorflow/stream_executor/tpu/tpu_executor.h delete mode 100644 tensorflow/stream_executor/tpu/tpu_executor_c_api.h delete mode 100644 tensorflow/stream_executor/tpu/tpu_node_context.h delete mode 100644 tensorflow/stream_executor/tpu/tpu_op_executable.h delete mode 100644 tensorflow/stream_executor/tpu/tpu_platform.h delete mode 100644 tensorflow/stream_executor/tpu/tpu_platform_id.h delete mode 100644 tensorflow/stream_executor/tpu/tpu_platform_interface.h delete mode 100644 tensorflow/stream_executor/tpu/tpu_stream.h delete mode 100644 tensorflow/stream_executor/tpu/tpu_stream_interface.h delete mode 100644 tensorflow/stream_executor/tpu/tpu_timer.h delete mode 100644 tensorflow/stream_executor/tpu/tpu_topology.h delete mode 100644 tensorflow/stream_executor/tpu/tpu_transfer_manager.h delete mode 100644 tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.h delete mode 100644 tensorflow/stream_executor/trace_listener.h create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.__internal__.pbtxt create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.__internal__.types.data.-dataset.pbtxt create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.__internal__.types.data.pbtxt create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.__internal__.types.pbtxt create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.dtypes.experimental.pbtxt create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.experimental.extension_type.pbtxt create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.test.experimental.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.__internal__.types.data.-dataset.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.__internal__.types.data.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-preemption-watcher.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.dtypes.experimental.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.experimental.extension_type.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.metrics.-f-beta-score.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.metrics.-f1-score.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adafactor.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.optimizers.-adam-w.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.saved_model.experimental.-fingerprint.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.test.experimental.pbtxt rename tensorflow/tools/{ci_build/rel/macos/cpu_py37_pip.sh => benchmark/onednn_benchmark_config.sh} 
(50%) create mode 100644 tensorflow/tools/ci_build/Dockerfile.rbe.cuda11.8-cudnn8.6-ubuntu20.04-manylinux2014-multipython create mode 100755 tensorflow/tools/ci_build/linux/rocm/rocm_py310_pip.sh delete mode 100644 tensorflow/tools/ci_build/nightly_release/ubuntu/cpu_py37.sh delete mode 100644 tensorflow/tools/ci_build/nightly_release/ubuntu/gpu_py37.sh delete mode 100644 tensorflow/tools/ci_build/rel/macos/cpu_py37_nonpip.sh delete mode 100644 tensorflow/tools/ci_build/rel/ubuntu/cpu_py37_nonpip.sh delete mode 100644 tensorflow/tools/ci_build/rel/ubuntu/cpu_py37_pip.sh delete mode 100644 tensorflow/tools/ci_build/rel/ubuntu/gpu_py37_nonpip.sh delete mode 100644 tensorflow/tools/ci_build/rel/ubuntu/gpu_py37_pip.sh delete mode 100644 tensorflow/tools/ci_build/rel/windows/cpu_py37.bat delete mode 100644 tensorflow/tools/ci_build/rel/windows/gpu_py37.bat delete mode 100644 tensorflow/tools/ci_build/rel/windows_cuda114/cpu_py37.bat delete mode 100644 tensorflow/tools/ci_build/rel/windows_cuda114/gpu_py37.bat create mode 100644 tensorflow/tools/pip_package/redundant_tensorflow_gpu/README.md create mode 100644 tensorflow/tools/pip_package/redundant_tensorflow_gpu/setup.cfg create mode 100644 tensorflow/tools/pip_package/redundant_tensorflow_gpu/setup.py create mode 100644 tensorflow/tools/pip_package/redundant_tf_nightly_gpu/README.md create mode 100644 tensorflow/tools/pip_package/redundant_tf_nightly_gpu/setup.cfg create mode 100644 tensorflow/tools/pip_package/redundant_tf_nightly_gpu/setup.py create mode 100644 tensorflow/tools/tf_sig_build_dockerfiles/builder.devtoolset/glibc2.17-inline.patch create mode 100644 tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/cpu_clang.bazelrc create mode 100644 tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/gpu_clang.bazelrc create mode 100644 tensorflow/tools/toolchains/cpus/aarch64/README.md create mode 100644 tensorflow/tools/toolchains/cpus/aarch64/aarch64.bzl create mode 100644 tensorflow/tools/toolchains/cpus/aarch64/aarch64_compiler_configure.bzl rename third_party/llvm/temporary.patch => tensorflow/tools/toolchains/cpus/aarch64/crosstool/BUILD (100%) create mode 100644 tensorflow/tools/toolchains/cpus/aarch64/crosstool/BUILD.tpl create mode 100644 tensorflow/tools/toolchains/cpus/aarch64/crosstool/cc_toolchain_config.bzl.tpl create mode 100644 tensorflow/tools/toolchains/win/tf_win_01112023/BUILD create mode 100644 tensorflow/tools/toolchains/win/tf_win_01112023/armeabi_cc_toolchain_config.bzl create mode 100644 tensorflow/tools/toolchains/win/tf_win_01112023/builtin_include_directory_paths_msvc create mode 100644 tensorflow/tools/toolchains/win/tf_win_01112023/toolchain_image_info create mode 100644 tensorflow/tools/toolchains/win/tf_win_01112023/windows_cc_toolchain_config.bzl create mode 100644 tensorflow/tools/toolchains/win/tf_win_01232023/BUILD create mode 100644 tensorflow/tools/toolchains/win/tf_win_01232023/armeabi_cc_toolchain_config.bzl create mode 100644 tensorflow/tools/toolchains/win/tf_win_01232023/builtin_include_directory_paths_msvc create mode 100644 tensorflow/tools/toolchains/win/tf_win_01232023/toolchain_image_info create mode 100644 tensorflow/tools/toolchains/win/tf_win_01232023/windows_cc_toolchain_config.bzl create mode 100644 tensorflow/tsl/c/BUILD create mode 100644 tensorflow/tsl/c/tsl_status.cc create mode 100644 tensorflow/tsl/c/tsl_status.h create mode 100644 tensorflow/tsl/c/tsl_status_helper.cc create mode 100644 tensorflow/tsl/c/tsl_status_helper.h rename tensorflow/{c/tf_status_helper_test.cc 
=> tsl/c/tsl_status_helper_test.cc} (68%) rename tensorflow/{stream_executor/gpu/gpu_rng.h => tsl/c/tsl_status_internal.h} (65%) rename tensorflow/{c/tf_status_test.cc => tsl/c/tsl_status_test.cc} (62%) create mode 100644 tensorflow/tsl/concurrency/BUILD create mode 100644 tensorflow/tsl/concurrency/async_value.cc create mode 100644 tensorflow/tsl/concurrency/async_value.h create mode 100644 tensorflow/tsl/concurrency/async_value_ptr_test.cc create mode 100644 tensorflow/tsl/concurrency/async_value_ref.cc create mode 100644 tensorflow/tsl/concurrency/async_value_ref.h create mode 100644 tensorflow/tsl/concurrency/async_value_ref_test.cc create mode 100644 tensorflow/tsl/concurrency/async_value_test.cc rename tensorflow/{stream_executor/fft.h => tsl/concurrency/chain.h} (70%) create mode 100644 tensorflow/tsl/concurrency/concurrent_vector.h create mode 100644 tensorflow/tsl/concurrency/concurrent_vector_test.cc create mode 100644 tensorflow/tsl/concurrency/ref_count.h create mode 100644 tensorflow/tsl/cuda/cuda_12_0.inc create mode 100644 tensorflow/tsl/cuda/cuda_runtime_11_8.inc create mode 100644 tensorflow/tsl/cuda/cuda_runtime_12_0.inc create mode 100644 tensorflow/tsl/cuda/cusparse_12_0.inc rename tensorflow/{core => tsl}/distributed_runtime/coordination/coordination_service_agent.cc (88%) rename tensorflow/{core => tsl}/distributed_runtime/coordination/coordination_service_agent.h (80%) rename tensorflow/{core => tsl}/distributed_runtime/coordination/coordination_service_agent_test.cc (92%) rename tensorflow/{core => tsl}/distributed_runtime/coordination/coordination_service_recoverable_job_test.cc (82%) rename tensorflow/{core => tsl}/distributed_runtime/coordination/coordination_service_rpc_handler.cc (69%) create mode 100644 tensorflow/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.h rename tensorflow/{core => tsl}/distributed_runtime/preemption/preemption_sync_manager.cc (83%) create mode 100644 tensorflow/tsl/distributed_runtime/preemption/preemption_sync_manager.h rename tensorflow/{core => tsl}/distributed_runtime/preemption/preemption_sync_manager_test.cc (89%) create mode 100644 tensorflow/tsl/distributed_runtime/rpc/coordination/BUILD rename tensorflow/{core => tsl}/distributed_runtime/rpc/coordination/grpc_coordination_client.cc (84%) rename tensorflow/{compiler/xla/stream_executor/lib/statusor.h => tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.h} (50%) rename tensorflow/{core => tsl}/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.cc (71%) create mode 100644 tensorflow/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h create mode 100644 tensorflow/tsl/distributed_runtime/rpc/grpc_client_cq_tag.h create mode 100644 tensorflow/tsl/distributed_runtime/rpc/grpc_state.h create mode 100644 tensorflow/tsl/distributed_runtime/rpc/test_request.proto create mode 100644 tensorflow/tsl/framework/contraction/BUILD rename tensorflow/{core/kernels => tsl/framework/contraction}/eigen_contraction_kernel.cc (96%) rename tensorflow/{core/kernels => tsl/framework/contraction}/eigen_contraction_kernel.h (99%) create mode 100644 tensorflow/tsl/framework/convolution/BUILD rename tensorflow/{core/kernels => tsl/framework/convolution}/eigen_convolution_helpers.h (90%) rename tensorflow/{core/kernels => tsl/framework/convolution}/eigen_spatial_convolutions-inl.h (99%) rename tensorflow/{core/kernels => tsl/framework/convolution}/eigen_spatial_convolutions.h (97%) rename tensorflow/{core/kernels => 
tsl/framework/convolution}/eigen_spatial_convolutions_test.cc (71%) create mode 100644 tensorflow/tsl/mkl/BUILD rename {third_party => tensorflow/tsl}/mkl/LICENSE (100%) rename {third_party => tensorflow/tsl}/mkl/MKL_LICENSE (100%) create mode 100644 tensorflow/tsl/mkl/build_defs.bzl create mode 100644 tensorflow/tsl/platform/default/status.h create mode 100644 tensorflow/tsl/platform/float8.h rename tensorflow/{core/framework/float8_test.cc => tsl/platform/float8_test.cu.cc} (50%) rename tensorflow/{compiler/xla/stream_executor/lib => tsl/platform}/static_threadlocal.h (56%) create mode 100644 tensorflow/tsl/profiler/builds/oss/BUILD rename tensorflow/{core => tsl}/profiler/builds/oss/build_config.bzl (75%) rename tensorflow/{core => tsl}/profiler/convert/trace_events_to_json.cc (90%) rename tensorflow/{core => tsl}/profiler/convert/trace_events_to_json.h (67%) rename tensorflow/{core => tsl}/profiler/convert/trace_events_to_json_test.cc (92%) rename tensorflow/{core => tsl}/profiler/convert/xplane_to_trace_events.cc (82%) rename tensorflow/{core => tsl}/profiler/convert/xplane_to_trace_events.h (50%) rename tensorflow/{core => tsl}/profiler/convert/xplane_to_trace_events_test.cc (73%) create mode 100644 tensorflow/tsl/profiler/lib/connected_traceme.h create mode 100644 tensorflow/tsl/profiler/lib/nvtx_utils.h rename tensorflow/{core => tsl}/profiler/lib/profiler_collection.cc (84%) rename tensorflow/{core => tsl}/profiler/lib/profiler_collection.h (75%) rename tensorflow/{core => tsl}/profiler/lib/profiler_session.cc (71%) create mode 100644 tensorflow/tsl/profiler/lib/profiler_session.h rename tensorflow/{core => tsl}/profiler/lib/scoped_memory_debug_annotation.cc (87%) create mode 100644 tensorflow/tsl/profiler/lib/scoped_memory_debug_annotation.h create mode 100644 tensorflow/tsl/profiler/protobuf/profile.proto create mode 100644 tensorflow/tsl/profiler/protobuf/profiler_analysis.proto create mode 100644 tensorflow/tsl/profiler/protobuf/profiler_options.proto create mode 100644 tensorflow/tsl/profiler/protobuf/profiler_service.proto create mode 100644 tensorflow/tsl/profiler/protobuf/profiler_service_monitor_result.proto create mode 100644 tensorflow/tsl/profiler/protobuf/trace_events.proto create mode 100644 tensorflow/tsl/profiler/rpc/BUILD create mode 100644 tensorflow/tsl/profiler/rpc/client/BUILD rename tensorflow/{core => tsl}/profiler/rpc/client/capture_profile.cc (80%) create mode 100644 tensorflow/tsl/profiler/rpc/client/capture_profile.h rename tensorflow/{core => tsl}/profiler/rpc/client/profiler_client.cc (83%) create mode 100644 tensorflow/tsl/profiler/rpc/client/profiler_client.h rename tensorflow/{core => tsl}/profiler/rpc/client/profiler_client_test.cc (90%) rename tensorflow/{core => tsl}/profiler/rpc/client/profiler_client_test_util.h (78%) rename tensorflow/{core => tsl}/profiler/rpc/client/remote_profiler_session_manager.cc (87%) create mode 100644 tensorflow/tsl/profiler/rpc/client/remote_profiler_session_manager.h rename tensorflow/{core => tsl}/profiler/rpc/client/remote_profiler_session_manager_test.cc (85%) rename tensorflow/{core => tsl}/profiler/rpc/client/save_profile.cc (75%) create mode 100644 tensorflow/tsl/profiler/rpc/client/save_profile.h rename tensorflow/{core => tsl}/profiler/rpc/profiler_server.cc (83%) create mode 100644 tensorflow/tsl/profiler/rpc/profiler_server.h rename tensorflow/{core => tsl}/profiler/rpc/profiler_service_impl.cc (73%) rename tensorflow/{compiler/xla/service/hlo_clone_context.h => tsl/profiler/rpc/profiler_service_impl.h} 
(60%) rename tensorflow/{core => tsl}/profiler/utils/buffer_pool.cc (91%) rename tensorflow/{core => tsl}/profiler/utils/buffer_pool.h (85%) rename tensorflow/{core => tsl}/profiler/utils/buffer_pool_test.cc (95%) create mode 100644 tensorflow/tsl/profiler/utils/file_system_utils.h create mode 100644 tensorflow/tsl/profiler/utils/format_utils.h create mode 100644 tensorflow/tsl/protobuf/autotuning.proto create mode 100644 tensorflow/tsl/protobuf/dnn.proto create mode 100644 tensorflow/tsl/python/lib/core/BUILD create mode 100644 tensorflow/tsl/python/lib/core/bfloat16.cc rename tensorflow/{compiler/xla/python => tsl/python/lib/core}/bfloat16.h (78%) create mode 100644 tensorflow/tsl/python/lib/core/custom_casts.cc create mode 100644 tensorflow/tsl/python/lib/core/custom_casts.h rename tensorflow/{compiler/xla/python/bfloat16.cc => tsl/python/lib/core/custom_float.h} (89%) create mode 100644 tensorflow/tsl/python/lib/core/float8.cc create mode 100644 tensorflow/tsl/python/lib/core/float8.h rename tensorflow/{compiler/xla/python => tsl/python/lib/core}/float8_e4m3b11.cc (96%) rename tensorflow/{compiler/xla/python => tsl/python/lib/core}/float8_e4m3b11.h (88%) rename tensorflow/{compiler/xla/python => tsl/python/lib/core}/numpy.cc (90%) rename tensorflow/{compiler/xla/python => tsl/python/lib/core}/numpy.h (88%) delete mode 100644 third_party/common.bzl delete mode 100644 third_party/compute_library/acl_depthwise_updateable_weights.patch delete mode 100644 third_party/compute_library/acl_fixup_SVE_merges.patch create mode 100644 third_party/compute_library/acl_openmp_fix.patch create mode 100644 third_party/llvm/generated.patch create mode 100644 third_party/mkl_dnn/onednn_acl_threadpool_scheduler.patch create mode 100644 third_party/nvtx.BUILD create mode 100644 third_party/pybind11_abseil/remove_license.patch create mode 100644 third_party/triton/BUILD create mode 100644 third_party/triton/workspace.bzl diff --git a/.bazelrc b/.bazelrc index 0322618b53f..ff910cd186e 100644 --- a/.bazelrc +++ b/.bazelrc @@ -149,6 +149,12 @@ build --experimental_cc_shared_library # cc_shared_library ensures no library is linked statically more than once. build --experimental_link_static_libraries_once=false +# Prevent regressions on those two incompatible changes +# TODO: remove those flags when they are flipped in the default Bazel version TF uses. +build --incompatible_enforce_config_setting_visibility +# TODO: also enable this flag after fixing the visbility violations +# build --incompatible_config_setting_private_default_visibility + # Default options should come above this line. # Allow builds using libc++ as a linker library @@ -324,7 +330,9 @@ build:linux --copt="-Wunused-result" # build:linux --copt="-Werror=unused-result" # Add switch as an error on Linux. 
build:linux --copt="-Wswitch" -# build:linux --copt="-Werror=switch" +build:linux --copt="-Werror=switch" +# Required for building with clang +build:linux --copt="-Wno-error=unused-but-set-variable" # On Windows, `__cplusplus` is wrongly defined without this switch # See https://devblogs.microsoft.com/cppblog/msvc-now-correctly-reports-__cplusplus/ @@ -382,8 +390,8 @@ build:windows --host_copt=-DNOGDI # MSVC (Windows): Standards-conformant preprocessor mode # See https://docs.microsoft.com/en-us/cpp/preprocessor/preprocessor-experimental-overview -build:windows --copt=/experimental:preprocessor -build:windows --host_copt=/experimental:preprocessor +build:windows --copt=/Zc:preprocessor +build:windows --host_copt=/Zc:preprocessor # Misc build options we need for windows. build:windows --linkopt=/DEBUG @@ -559,8 +567,8 @@ build:rbe_linux_py3_base --python_path="/usr/local/bin/python3.9" build:rbe_linux_py3_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.9" build:rbe_win --config=rbe -build:rbe_win --crosstool_top="//tensorflow/tools/toolchains/win/tf_win_06152022:toolchain" -build:rbe_win --extra_toolchains="//tensorflow/tools/toolchains/win/tf_win_06152022:cc-toolchain-x64_windows" +build:rbe_win --crosstool_top="//tensorflow/tools/toolchains/win/tf_win_01232023:toolchain" +build:rbe_win --extra_toolchains="//tensorflow/tools/toolchains/win/tf_win_01232023:cc-toolchain-x64_windows" build:rbe_win --extra_execution_platforms="//tensorflow/tools/toolchains/win:rbe_windows_ltsc2019" build:rbe_win --host_platform="//tensorflow/tools/toolchains/win:rbe_windows_ltsc2019" build:rbe_win --platforms="//tensorflow/tools/toolchains/win:rbe_windows_ltsc2019" @@ -672,6 +680,7 @@ build:asan --copt -g build:asan --copt -O3 build:asan --copt -fno-omit-frame-pointer build:asan --linkopt -fsanitize=address +build:asan --@libjpeg_turbo//:noasm=yes # Memory sanitizer # CC=clang bazel build --config msan @@ -695,7 +704,17 @@ build:ubsan --linkopt -fsanitize=undefined build:ubsan --linkopt -lubsan # Disable TFRT integration for now unless --config=tfrt is specified. 
-build --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/tfrt/common,tensorflow/core/tfrt/eager,tensorflow/core/tfrt/eager/backends/cpu,tensorflow/core/tfrt/eager/backends/gpu,tensorflow/core/tfrt/eager/core_runtime,tensorflow/core/tfrt/eager/cpp_tests/core_runtime,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils +build --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/tfrt/eager,tensorflow/core/tfrt/eager/backends/cpu,tensorflow/core/tfrt/eager/backends/gpu,tensorflow/core/tfrt/eager/core_runtime,tensorflow/core/tfrt/eager/cpp_tests/core_runtime,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils # TODO(b/240450920): We are in the process of migrating JitRt backend to XLA # and while we are doing this we can't keep it buildable/testable in OSS. 
-build:tfrt --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/tfrt/common,tensorflow/core/tfrt/eager,tensorflow/core/tfrt/eager/backends/cpu,tensorflow/core/tfrt/eager/backends/gpu,tensorflow/core/tfrt/eager/core_runtime,tensorflow/core/tfrt/eager/cpp_tests/core_runtime,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils +build:tfrt --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/tfrt/eager,tensorflow/core/tfrt/eager/backends/cpu,tensorflow/core/tfrt/eager/backends/gpu,tensorflow/core/tfrt/eager/core_runtime,tensorflow/core/tfrt/eager/cpp_tests/core_runtime,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils + +# TF Fuzztest config +try-import fuzztest.bazelrc +run:tf_fuzztest --config=fuzztest +# Should aim to remove these +build:tf_fuzztest --action_env=CC=clang +build:tf_fuzztest --action_env=CXX=clang++ +build:tf_fuzztest --spawn_strategy=sandboxed +build:tf_fuzztest --config=monolithic +build:tf_fuzztest --@libjpeg_turbo//:noasm=yes diff --git a/.bazelversion b/.bazelversion index e230c8396d1..f53152b50eb 100644 --- a/.bazelversion +++ b/.bazelversion @@ -1 +1,2 @@ -5.3.0 \ No newline at end of file +5.3.0 +# NOTE: Update Bazel version in tensorflow/tools/ci_build/release/common.sh.oss \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/tensorflow_issue_template.yaml 
b/.github/ISSUE_TEMPLATE/tensorflow_issue_template.yaml index 6e4753d8674..70bdc6160cb 100644 --- a/.github/ISSUE_TEMPLATE/tensorflow_issue_template.yaml +++ b/.github/ISSUE_TEMPLATE/tensorflow_issue_template.yaml @@ -23,6 +23,17 @@ body: value: | Please make sure that this is a bug. As per our [GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md),we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub. + - type: dropdown + id: tf-nightly + attributes: + label: Have you reproduced the bug with TF nightly? + description: It is strongly suggested that you have reproduced the bug with [TF nightly](https://www.tensorflow.org/install/pip#nightly) + options: + - "Yes" + - "No" + validations: + required: true + - type: markdown attributes: value: | @@ -38,6 +49,7 @@ body: - binary validations: required: true + - type: input id: tfversion attributes: diff --git a/.github/ISSUE_TEMPLATE/tflite-other.md b/.github/ISSUE_TEMPLATE/tflite-other.md new file mode 100644 index 00000000000..8b8246f2b72 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/tflite-other.md @@ -0,0 +1,62 @@ +name: TensorFlow Lite Other Issue description: Use this template to report any +issue in TensorFlow Lite that is not about Converters, Play Services or Ops +body: - type: dropdown id: issue-type attributes: label: Issue Type description: +What type of issue would you like to report? multiple: false options: - Bug - +Build/Install - Performance - Support - Feature Request - Documentation Feature +Request - Documentation Bug - Others validations: required: true - type: +markdown attributes: value: | Please make sure that this is a bug. As per our +[GitHub Policy](https://github.com/tensorflow/tensorflow/blob/master/ISSUES.md),we +only address code/doc bugs, performance issues, feature requests and +build/installation issues on GitHub. + +- type: markdown + attributes: + value: | + You can collect some of this information using our environment capture [script](https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh) You can also obtain the TensorFlow version with:
1. TF 1.0: `python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"`
2. TF 2.0: `python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"` + +- type: dropdown id: source attributes: label: Source description: Tensorflow + installed from options: - source - binary validations: required: true + +- type: input id: tfversion attributes: label: Tensorflow Version description: + placeholder: ex,. tf 2.8 validations: required: true + +- type: dropdown id: Code attributes: label: Custom Code description: + options: - "Yes" - "No" validations: required: true + +- type: input id: OS attributes: label: OS Platform and Distribution + description: placeholder: e.g., Linux Ubuntu 16.04 validations: required: + false + +- type: input id: Mobile attributes: label: Mobile device description: + placeholder: e.g., Linux Ubuntu 16.04 validations: required: false + +- type: input id: Python attributes: label: Python version description: + placeholder: e.g., 3.9 validations: required: false + +- type: input id: Bazel attributes: label: Bazel version description: if + compiling from source placeholder: validations: required: false + +- type: input id: Compiler attributes: label: GCC/Compiler version + description: if compiling from source placeholder: validations: required: + false + +- type: input id: Cuda attributes: label: CUDA/cuDNN version description: + placeholder: validations: required: false + +- type: input id: Gpu attributes: label: GPU model and memory description: if + compiling from source placeholder: validations: required: false + +- type: textarea id: what-happened attributes: label: Current Behaviour? + description: Also tell us, what did you expect to happen? placeholder: Tell + us what you see! value: "A bug happened!" render: shell validations: + required: true + +- type: textarea id: code-to-reproduce attributes: label: Standalone code to + reproduce the issue description: Provide a reproducible test case that is + the bare minimum necessary to generate the problem. If possible, please + share a link to Colab/Jupyter/any notebook. placeholder: Tell us what you + see! value: render: shell validations: required: true + +- type: textarea id: logs attributes: label: Relevant log output description: + Please copy and paste any relevant log output. This will be automatically + formatted into code, so no need for backticks. render: shell diff --git a/.github/bot_config.yml b/.github/bot_config.yml index 3f039b9e176..bab88af1a8e 100644 --- a/.github/bot_config.yml +++ b/.github/bot_config.yml @@ -15,9 +15,10 @@ # A list of assignees assignees: - - tilakrayal + - synandi - tiruk007 - - Mohantym + - gaikwadrahul8 + - pjpratik # A list of assignees for compiler folder compiler_assignees: - joker-eph diff --git a/.github/workflows/arm-cd.yml b/.github/workflows/arm-cd.yml index 1698cf0f0b3..b601b0054c7 100644 --- a/.github/workflows/arm-cd.yml +++ b/.github/workflows/arm-cd.yml @@ -26,9 +26,14 @@ jobs: build: if: github.repository == 'tensorflow/tensorflow' # Don't do this in forks runs-on: [self-hosted, linux, ARM64] + continue-on-error: ${{ matrix.experimental }} strategy: matrix: - pyver: ['3.7', '3.8', '3.9', '3.10'] + pyver: ['3.8', '3.9', '3.10'] + experimental: [false] + include: + - pyver: '3.11' + experimental: true steps: - name: Stop old running containers (if any) shell: bash @@ -46,12 +51,12 @@ jobs: run: find /home/ubuntu/actions-runner/_work/tensorflow/tensorflow/. -name . 
-o -prune -exec sudo rm -rf -- {} + || true - name: Checkout repository for nightly (skipped for releases) if: ${{ github.event_name == 'schedule' }} - uses: actions/checkout@v3 + uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0 with: ref: 'nightly' - name: Checkout repository for releases (skipped for nightly) if: ${{ github.event_name == 'push' }} - uses: actions/checkout@v3 + uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0 - name: Build and test pip wheel shell: bash run: | diff --git a/.github/workflows/arm-ci-extended.yml b/.github/workflows/arm-ci-extended.yml index 0fcf49e340a..1592f4ed18a 100644 --- a/.github/workflows/arm-ci-extended.yml +++ b/.github/workflows/arm-ci-extended.yml @@ -50,7 +50,7 @@ jobs: shell: bash run: find /home/ubuntu/actions-runner/_work/tensorflow/tensorflow/. -name . -o -prune -exec sudo rm -rf -- {} + || true - name: Checkout repository - uses: actions/checkout@v3 + uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0 - name: Build binary and run non-pip tests shell: bash run: | diff --git a/.github/workflows/arm-ci.yml b/.github/workflows/arm-ci.yml index 067e29131e7..e6ddbb9eec9 100644 --- a/.github/workflows/arm-ci.yml +++ b/.github/workflows/arm-ci.yml @@ -21,14 +21,15 @@ on: - master - r2.** pull_request: - types: [opened, synchronize, reopened] + types: [labeled, opened, synchronize, reopened] branches: - master - r2.** jobs: build: - if: github.repository == 'tensorflow/tensorflow' # Don't do this in forks + # Don't do this in forks, and if labeled, only for 'kokoro:force-run' + if: github.repository == 'tensorflow/tensorflow' && (github.event.action != 'labeled' || (github.event.action == 'labeled' && github.event.label.name == 'kokoro:force-run')) runs-on: [self-hosted, linux, ARM64] strategy: matrix: @@ -49,14 +50,14 @@ jobs: shell: bash run: find /home/ubuntu/actions-runner/_work/tensorflow/tensorflow/. -name . 
-o -prune -exec sudo rm -rf -- {} + || true - name: Checkout repository - uses: actions/checkout@v3 + uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0 - name: Build and test pip wheel shell: bash run: | CI_DOCKER_BUILD_EXTRA_PARAMS='--build-arg py_major_minor_version=${{ matrix.pyver }}' \ ./tensorflow/tools/ci_build/ci_build.sh cpu.arm64 bash tensorflow/tools/ci_build/rel/ubuntu/cpu_arm64_pip.sh - name: Upload pip wheel to GitHub - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@83fd05a356d7e2593de66fc9913b3002723633cb # v3.1.1 with: name: tensorflow_py${{ matrix.pyver }}_wheel path: /home/ubuntu/actions-runner/_work/tensorflow/tensorflow/whl/*.whl diff --git a/.github/workflows/cffconvert.yml b/.github/workflows/cffconvert.yml index fdae2ac19e6..21ac759f3ef 100644 --- a/.github/workflows/cffconvert.yml +++ b/.github/workflows/cffconvert.yml @@ -27,9 +27,9 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out a copy of the repository - uses: actions/checkout@v2 + uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0 - name: Check whether the citation metadata from CITATION.cff is valid - uses: citation-file-format/cffconvert-github-action@2.0.0 + uses: citation-file-format/cffconvert-github-action@4cf11baa70a673bfdf9dad0acc7ee33b3f4b6084 # v2.0.0 with: args: "--validate" diff --git a/.github/workflows/issue-on-pr-rollback.yml b/.github/workflows/issue-on-pr-rollback.yml index ce0182bedc2..fa76923a2ba 100644 --- a/.github/workflows/issue-on-pr-rollback.yml +++ b/.github/workflows/issue-on-pr-rollback.yml @@ -27,9 +27,9 @@ jobs: startsWith(github.event.head_commit.message, 'Rollback of PR #') steps: - name: Checkout repo - uses: actions/checkout@v2 + uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0 - name: Create a new Github Issue - uses: actions/github-script@v5 + uses: actions/github-script@d556feaca394842dc55e4734bf3bb9f685482fa0 # v6.3.3 with: github-token: ${{secrets.GITHUB_TOKEN}} script: | diff --git a/.github/workflows/pylint-presubmit.yml b/.github/workflows/pylint-presubmit.yml index f1b539f551b..e97f34472d8 100644 --- a/.github/workflows/pylint-presubmit.yml +++ b/.github/workflows/pylint-presubmit.yml @@ -25,17 +25,17 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout code - uses: actions/checkout@v2 + uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0 - name: Get file changes id: get_file_changes - uses: trilom/file-changes-action@v1.2.4 + uses: trilom/file-changes-action@a6ca26c14274c33b15e6499323aac178af06ad4b # v1.2.4 with: output: ' ' - name: Report list of changed files run: | echo Changed files: ${{ steps.get_file_changes.outputs.files }} - name: Set up Python 3.9 - uses: actions/setup-python@v2 + uses: actions/setup-python@2c3dd9e7e29afd70cc0950079bde6c979d1f69f9 # v4.3.1 with: python-version: "3.9" - name: Install Python dependencies diff --git a/.github/workflows/release-branch-cherrypick.yml b/.github/workflows/release-branch-cherrypick.yml index a57852a9644..5ff69e46805 100644 --- a/.github/workflows/release-branch-cherrypick.yml +++ b/.github/workflows/release-branch-cherrypick.yml @@ -42,7 +42,7 @@ jobs: if: github.repository == 'tensorflow/tensorflow' # Don't do this in forks steps: - name: Checkout code - uses: actions/checkout@v2 + uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0 with: ref: ${{ github.event.inputs.release_branch }} - name: Get some helpful info for formatting @@ -52,10 +52,10 @@ jobs: git config --global 
user.email "jenkins@tensorflow.org" git fetch origin master git cherry-pick ${{ github.event.inputs.git_commit }} - echo ::set-output name=SHORTSHA::$(git log -1 ${{ github.event.inputs.git_commit }} --format="%h") - echo ::set-output name=TITLE::$(git log -1 ${{ github.event.inputs.git_commit }} --format="%s") + echo "SHORTSHA=$(git log -1 ${{ github.event.inputs.git_commit }} --format="%h")" >> "$GITHUB_OUTPUT" + echo "TITLE=$(git log -1 ${{ github.event.inputs.git_commit }} --format="%s")" >> "$GITHUB_OUTPUT" - name: Create Pull Request with changes - uses: peter-evans/create-pull-request@v3 + uses: peter-evans/create-pull-request@2b011faafdcbc9ceb11414d64d0573f37c774b04 # v4.2.3 with: title: '${{ github.event.inputs.release_branch }} cherry-pick: ${{ steps.cherrypick.outputs.SHORTSHA }} "${{ steps.cherrypick.outputs.TITLE }}"' committer: TensorFlow Release Automation diff --git a/.github/workflows/scorecards-analysis.yml b/.github/workflows/scorecards-analysis.yml index 8f9dab872b6..1c520aa86fd 100644 --- a/.github/workflows/scorecards-analysis.yml +++ b/.github/workflows/scorecards-analysis.yml @@ -34,23 +34,18 @@ jobs: # Needed to upload the results to code-scanning dashboard. security-events: write id-token: write - actions: read - contents: read steps: - name: "Checkout code" - uses: actions/checkout@ec3a7ce113134d7a93b817d10a8272cb61118579 # v2.4.0 + uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0 with: persist-credentials: false - name: "Run analysis" - uses: ossf/scorecard-action@08dd0cebb088ac0fd6364339b1b3b68b75041ea8 # v2.0.0-alpha.2 + uses: ossf/scorecard-action@15c10fcf1cf912bd22260bfec67569a359ab87da # v2.1.1 with: results_file: results.sarif results_format: sarif - # Read-only PAT token. To create it, - # follow the steps in https://github.com/ossf/scorecard-action#pat-token-creation. - repo_token: ${{ secrets.SCORECARD_READ_TOKEN }} # Publish the results to enable scorecard badges. For more details, see # https://github.com/ossf/scorecard-action#publishing-results. # For private repositories, `publish_results` will automatically be set to `false`, @@ -59,7 +54,7 @@ jobs: # Upload the results as artifacts (optional). - name: "Upload artifact" - uses: actions/upload-artifact@82c141cc518b40d92cc801eee768e7aafc9c2fa2 # v2.3.1 + uses: actions/upload-artifact@83fd05a356d7e2593de66fc9913b3002723633cb # v3.1.1 with: name: SARIF file path: results.sarif @@ -67,6 +62,6 @@ jobs: # Upload the results to GitHub's code scanning dashboard. 
- name: "Upload to code-scanning" - uses: github/codeql-action/upload-sarif@5f532563584d71fdef14ee64d17bafb34f751ce5 # v1.0.26 + uses: github/codeql-action/upload-sarif@896079047b4bb059ba6f150a5d87d47dde99e6e5 # v2.11.6 with: sarif_file: results.sarif diff --git a/.github/workflows/sigbuild-docker-branch.yml b/.github/workflows/sigbuild-docker-branch.yml index 41b0fe5a13a..c898381efd5 100644 --- a/.github/workflows/sigbuild-docker-branch.yml +++ b/.github/workflows/sigbuild-docker-branch.yml @@ -31,23 +31,23 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [python3.7, python3.8, python3.9, python3.10] + python-version: [python3.8, python3.9, python3.10, python3.11] steps: - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0 - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v1 + uses: docker/setup-buildx-action@8c0edbc76e98fa90f69d9a2c020dcb50019dc325 # v2.2.1 - name: Login to DockerHub - uses: docker/login-action@v1 + uses: docker/login-action@f4ef78c080cd8ba55a85445d5b36e214a81df20a # v2.1.0 with: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} - name: Login to GCR - uses: docker/login-action@v1 + uses: docker/login-action@f4ef78c080cd8ba55a85445d5b36e214a81df20a # v2.1.0 with: registry: gcr.io username: _json_key @@ -55,14 +55,14 @@ jobs: - name: Generate variables for cache busting and tag naming run: | - echo "::set-output name=DATE::$(date +'%Y-%m-%d')" + echo "DATE=$(date +'%Y-%m-%d')" >> "$GITHUB_OUTPUT" # Converts r2.9 to just 2.9 - echo "::set-output name=REF::$(echo $GITHUB_REF_NAME | sed 's/r//g')" + echo "REF=$(echo $GITHUB_REF_NAME | sed 's/r//g')" >> "$GITHUB_OUTPUT" id: vars - name: Build and push id: docker_build - uses: docker/build-push-action@v2 + uses: docker/build-push-action@c56af957549030174b10d6867f20e78cfd7debc5 # v3.2.0 with: push: true context: ./tensorflow/tools/tf_sig_build_dockerfiles diff --git a/.github/workflows/sigbuild-docker-presubmit.yml b/.github/workflows/sigbuild-docker-presubmit.yml index c77c0d66311..065fd91319e 100644 --- a/.github/workflows/sigbuild-docker-presubmit.yml +++ b/.github/workflows/sigbuild-docker-presubmit.yml @@ -29,18 +29,18 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [python3.7, python3.8, python3.9, python3.10] + python-version: [python3.8, python3.9, python3.10, python3.11] steps: - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0 - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v1 + uses: docker/setup-buildx-action@8c0edbc76e98fa90f69d9a2c020dcb50019dc325 # v2.2.1 - name: Login to GCR if: contains(github.event.pull_request.labels.*.name, 'build and push to gcr.io for staging') - uses: docker/login-action@v1 + uses: docker/login-action@f4ef78c080cd8ba55a85445d5b36e214a81df20a # v2.1.0 with: registry: gcr.io username: _json_key @@ -48,12 +48,12 @@ jobs: - name: Grab the date to do cache busting (assumes same day OK to keep) run: | - echo "::set-output name=DATE::$(date +'%Y-%m-%d')" + echo "DATE=$(date +'%Y-%m-%d')" >> "$GITHUB_OUTPUT" id: date - name: Build containers, and push to GCR only if the 'build and push to gcr.io for staging' label is applied id: docker_build - uses: docker/build-push-action@v2 + uses: docker/build-push-action@c56af957549030174b10d6867f20e78cfd7debc5 # v3.2.0 with: push: ${{ contains(github.event.pull_request.labels.*.name, 'build and push to gcr.io for staging') 
}} context: ./tensorflow/tools/tf_sig_build_dockerfiles @@ -69,17 +69,17 @@ jobs: cache-to: type=inline - name: Add a comment with the pushed containers - uses: mshick/add-pr-comment@v1 + uses: mshick/add-pr-comment@a65df5f64fc741e91c59b8359a4bc56e57aaf5b1 # v2 if: contains(github.event.pull_request.labels.*.name, 'build and push to gcr.io for staging') with: repo-token: ${{ secrets.GITHUB_TOKEN }} message: | I pushed these containers: + - `gcr.io/tensorflow-sigs/build:${{ github.event.number }}-python3.11` - `gcr.io/tensorflow-sigs/build:${{ github.event.number }}-python3.10` - `gcr.io/tensorflow-sigs/build:${{ github.event.number }}-python3.9` - `gcr.io/tensorflow-sigs/build:${{ github.event.number }}-python3.8` - - `gcr.io/tensorflow-sigs/build:${{ github.event.number }}-python3.7` Re-apply the `build and push to gcr.io for staging` label to rebuild and push again. This comment will only be posted once. - diff --git a/.github/workflows/sigbuild-docker.yml b/.github/workflows/sigbuild-docker.yml index 276a0abc242..c9b12a39076 100644 --- a/.github/workflows/sigbuild-docker.yml +++ b/.github/workflows/sigbuild-docker.yml @@ -34,23 +34,23 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [python3.7, python3.8, python3.9, python3.10] + python-version: [python3.8, python3.9, python3.10, python3.11] steps: - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0 - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v1 + uses: docker/setup-buildx-action@8c0edbc76e98fa90f69d9a2c020dcb50019dc325 # v2.2.1 - name: Login to DockerHub - uses: docker/login-action@v1 + uses: docker/login-action@f4ef78c080cd8ba55a85445d5b36e214a81df20a # v2.1.0 with: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} - name: Login to GCR - uses: docker/login-action@v1 + uses: docker/login-action@f4ef78c080cd8ba55a85445d5b36e214a81df20a # v2.1.0 with: registry: gcr.io username: _json_key @@ -61,15 +61,15 @@ jobs: # [[:digit:]] searches for numbers and \+ joins them together major_version=$(grep "^#define TF_MAJOR_VERSION" ./tensorflow/core/public/version.h | grep -o "[[:digit:]]\+") minor_version=$(grep "^#define TF_MINOR_VERSION" ./tensorflow/core/public/version.h | grep -o "[[:digit:]]\+") - echo ::set-output name=TF_VERSION::${major_version}.${minor_version} + echo "TF_VERSION=${major_version}.${minor_version}" >> "$GITHUB_OUTPUT" # Also get the current date to do cache busting. 
Assumes one day # is an ok range for rebuilds - echo "::set-output name=DATE::$(date +'%Y-%m-%d')" + echo "DATE=$(date +'%Y-%m-%d')" >> "$GITHUB_OUTPUT" id: tf-version - name: Build and push id: docker_build - uses: docker/build-push-action@v2 + uses: docker/build-push-action@c56af957549030174b10d6867f20e78cfd7debc5 # v3.2.0 with: push: true context: ./tensorflow/tools/tf_sig_build_dockerfiles diff --git a/.github/workflows/trusted-partners.yml b/.github/workflows/trusted-partners.yml index abf62dd2b8a..7c2fb863d15 100644 --- a/.github/workflows/trusted-partners.yml +++ b/.github/workflows/trusted-partners.yml @@ -30,9 +30,9 @@ jobs: github.event.sender.type == 'User' steps: - name: Checkout repo - uses: actions/checkout@v2 + uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0 - name: Trusted-Partners-PR - uses: actions/github-script@v6 + uses: actions/github-script@d556feaca394842dc55e4734bf3bb9f685482fa0 # v6.3.3 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | @@ -49,6 +49,9 @@ jobs: case "nvidia.com": console.log(await script.filter({github, context, domain})); break; + case "linaro.org": + console.log(await script.filter({github, context, domain})); + break; case "google.com": console.log("Googler. No action necessary"); break; diff --git a/.github/workflows/trusted_partners.js b/.github/workflows/trusted_partners.js index 6b6de25946e..60de918108d 100644 --- a/.github/workflows/trusted_partners.js +++ b/.github/workflows/trusted_partners.js @@ -39,9 +39,9 @@ const get_email_domain = async ({github, username}) => { return domain; }; -/** For trusted parters like Intel, we want to auto-run tests and mark the PR as ready to pull +/** For trusted parters like Intel, we want to auto-run tests This allows us to reduce the delay to external partners - Add Labels - kokoro:force-run, ready to pull + Add Labels - kokoro:force-run The PR is also assigned to specific teams to fast track review Additional reviewers can be added manually based on PR contents @param {!object} @@ -50,34 +50,41 @@ const get_email_domain = async ({github, username}) => { @return {string} Returns the message with labels attached and assignees added */ const filter_action = async ({github, context, domain}) => { - const labels = ['kokoro:force-run', 'ready to pull']; + const labels = ['kokoro:force-run']; let assignees = []; const title = context.payload.pull_request && context.payload.pull_request.title; + const lowercased_title = (title || '').toLowerCase(); const onednn_assignees = ['penpornk']; - if (title && title.toLowerCase().includes("onednn")) - assignees = onednn_assignees; + if (lowercased_title.includes('onednn')) assignees = onednn_assignees; const intel_windows_assignees = ['nitins17', 'learning-to-play']; - if (title && title.toLowerCase().includes('intel') && - title.toLowerCase().includes('windows') && domain.includes('intel.com')) + if (lowercased_title.includes('intel') && + lowercased_title.includes('windows') && domain.includes('intel.com')) assignees = intel_windows_assignees; const apple_silicon_assignees = ['penpornk', 'nitins17']; - if (title && title.toLowerCase().includes('apple') && - title.toLowerCase().includes('silicon') && domain.includes('apple.com')) + if (lowercased_title.includes('apple') && + lowercased_title.includes('silicon') && domain.includes('apple.com')) assignees = apple_silicon_assignees; - if (title && title.toLowerCase().includes('nvidia') && - domain.includes('nvidia.com')) { - if (title.toLowerCase().includes('jax')) { + if 
(lowercased_title.includes('tf-trt') && domain.includes('nvidia.com')) { + assignees.push( + 'DEKHTIARJonathan', 'meena-at-work', 'nluehr', 'pjannaty', 'poulsbo'); + } else if ( + lowercased_title.includes('nvidia') && domain.includes('nvidia.com')) { + if (lowercased_title.includes('jax')) { assignees.push('hawkinsp', 'yashk2810', 'skye'); } - if (title.toLowerCase().includes('xla') || - title.toLowerCase().includes('gpu')) { + if (lowercased_title.includes('xla') || lowercased_title.includes('gpu')) { assignees.push('cheshire', 'gcforster', 'reedwm', 'chsigg', 'xla-rotation'); } - if (title.toLowerCase().includes('tf')) { + if (lowercased_title.includes('tf')) { assignees.push('rohan100jain', 'bfontain'); } } + if (lowercased_title.includes('linaro') && domain.includes('linaro.org')) { + if (lowercased_title.includes('arm_ci')) { + assignees.push('nitins17', 'penpornk'); + } + } const resp_label = await github.rest.issues.addLabels({ issue_number: context.issue.number, diff --git a/.github/workflows/update-nightly.yml b/.github/workflows/update-nightly.yml index 0265ffbebe2..60372fddd27 100644 --- a/.github/workflows/update-nightly.yml +++ b/.github/workflows/update-nightly.yml @@ -23,7 +23,7 @@ jobs: if: github.repository == 'tensorflow/tensorflow' # Don't do this in forks runs-on: ubuntu-latest steps: - - uses: zofrex/mirror-branch@v1 + - uses: zofrex/mirror-branch@a8809f0b42f9dfe9b2c5c2162a46327c23d15266 # v1.0.3 name: Set nightly branch to master HEAD with: target-branch: 'nightly' diff --git a/.github/workflows/update-rbe.yml b/.github/workflows/update-rbe.yml index 2f86ff2b2e5..ce31d59868a 100644 --- a/.github/workflows/update-rbe.yml +++ b/.github/workflows/update-rbe.yml @@ -27,7 +27,7 @@ jobs: if: github.repository == 'tensorflow/tensorflow' # Don't do this in forks steps: - name: Checkout code - uses: actions/checkout@v2 + uses: actions/checkout@755da8c3cf115ac066823e79a1e1788f8940201b # v3.2.0 - name: Update the RBE Configs run: | function map() { @@ -48,28 +48,40 @@ jobs: # See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/toolchains/remote_config/configs.bzl # This is a mapping of name_container_map keys under sigbuild_tf_configs # to tag names on gcr.io/tensorflow-sigs/build. 
+ # TF 2.9 map sigbuild-r2.9 2.9-python3.9 - map sigbuild-r2.9-python3.7 2.9-python3.7 map sigbuild-r2.9-python3.8 2.9-python3.8 map sigbuild-r2.9-python3.9 2.9-python3.9 map sigbuild-r2.9-python3.10 2.9-python3.10 + # TF 2.10 map sigbuild-r2.10 2.10-python3.9 - map sigbuild-r2.10-python3.7 2.10-python3.7 map sigbuild-r2.10-python3.8 2.10-python3.8 map sigbuild-r2.10-python3.9 2.10-python3.9 map sigbuild-r2.10-python3.10 2.10-python3.10 - map sigbuild-128 128-python3.9 - map sigbuild-128-python3.7 128-python3.7 - map sigbuild-128-python3.8 128-python3.8 - map sigbuild-128-python3.9 128-python3.9 - map sigbuild-128-python3.10 128-python3.10 + # TF 2.11 map sigbuild-r2.11 2.11-python3.9 - map sigbuild-r2.11-python3.7 2.11-python3.7 map sigbuild-r2.11-python3.8 2.11-python3.8 map sigbuild-r2.11-python3.9 2.11-python3.9 - map sigbuild-r2.11-python3.11 2.11-python3.10 + map sigbuild-r2.11-python3.10 2.11-python3.10 + # WIP Clang Containers, used by TVCs + map sigbuild-57469 57469-python3.9 + map sigbuild-57469-python3.8 57469-python3.8 + map sigbuild-57469-python3.9 57469-python3.9 + map sigbuild-57469-python3.10 57469-python3.10 + # TF 2.12 + map sigbuild-r2.12 2.12-python3.9 + map sigbuild-r2.12-python3.8 2.12-python3.8 + map sigbuild-r2.12-python3.9 2.12-python3.9 + map sigbuild-r2.12-python3.10 2.12-python3.10 + map sigbuild-r2.12-python3.11 2.12-python3.11 + # TF 2.12 + Clang (containers are the same, but env vars in configs.bzl are different) + map sigbuild-r2.12-clang 2.12-python3.9 + map sigbuild-r2.12-clang-python3.8 2.12-python3.8 + map sigbuild-r2.12-clang-python3.9 2.12-python3.9 + map sigbuild-r2.12-clang-python3.10 2.12-python3.10 + map sigbuild-r2.12-clang-python3.11 2.12-python3.11 - name: Create Pull Request with changes - uses: peter-evans/create-pull-request@v3 + uses: peter-evans/create-pull-request@2b011faafdcbc9ceb11414d64d0573f37c774b04 # v4.2.3 with: title: Update the RBE images to the latest container versions committer: TensorFlow Release Automation diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 01e20da7c87..ccc170b5c6e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -243,7 +243,7 @@ There are two ways to run TensorFlow unit tests. For a single component e.g. softmax op: ```bash - bazel test ${flags} tensorflow/python/kernel_tests:softmax_op_test + bazel test ${flags} tensorflow/python/kernel_tests/nn_ops:softmax_op_test ``` For a single/parameterized test e.g. `test_capture_variables` in diff --git a/README.md b/README.md index 73e75c1df81..c94227d26d7 100644 --- a/README.md +++ b/README.md @@ -104,6 +104,19 @@ for general questions and discussion, and please direct specific questions to The TensorFlow project strives to abide by generally accepted best practices in open-source software development. +## Patching guidelines + +Follow these steps to patch a specific version of TensorFlow, for example, to +apply fixes to bugs or security vulnerabilities: + +* Clone the TensorFlow repo and switch to the corresponding branch for your + desired TensorFlow version, for example, branch `r2.8` for version 2.8. +* Apply (that is, cherry pick) the desired changes and resolve any code + conflicts. +* Run TensorFlow tests and ensure they pass. +* [Build](https://www.tensorflow.org/install/source) the TensorFlow pip + package from source. 
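For readers following the new README patching guidelines added just above, the bullet list corresponds to a short sequence of git and bazel commands. Below is a minimal sketch written as a Python driver; the release branch (`r2.8`), the commit to cherry-pick, and the test target are illustrative placeholders rather than values taken from this patch:

```python
# Minimal sketch of the README patching workflow described above.
# Placeholder branch, commit, and test target; adjust to the version you patch.
import subprocess

def run(args, cwd=None):
    print("+", " ".join(args))
    subprocess.run(args, cwd=cwd, check=True)

run(["git", "clone", "https://github.com/tensorflow/tensorflow.git"])
run(["git", "checkout", "r2.8"], cwd="tensorflow")                 # branch for TF 2.8.x
run(["git", "cherry-pick", "COMMIT_WITH_FIX"], cwd="tensorflow")   # resolve conflicts by hand
run(["bazel", "test",
     "//tensorflow/python/kernel_tests/nn_ops:softmax_op_test"],   # or whichever tests apply
    cwd="tensorflow")
# Finally, build the pip package from source as described at
# https://www.tensorflow.org/install/source
```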
+ ## Continuous build status You can find more community-supported platforms and configurations in the diff --git a/RELEASE.md b/RELEASE.md index 40320f2a172..ea4ab08237e 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,18 +1,114 @@ -# Release 2.12.0 +# Release 2.13.0 -* `tf.keras`: +# Breaking Changes - * Added `jit_compile` as a settable property to `tf.keras.Model`. - * Added `synchronized` optional parameter to `layers.BatchNormalization`. - * Added deprecation warning to - `layers.experimental.SyncBatchNormalization` and suggested to use - `layers.BatchNormalization` with `synchronized=True` instead. +* +* + +# Known Caveats + +* +* +* + +# Major Features and Improvements + +* `tf.lite`: + + * Add 16-bit and 64-bit float type support for built-in op `cast`. + +* `tf.keras` + + * Added Keras metrics `tf.keras.metrics.FBetaScore` and + `tf.keras.metrics.F1Score`. + +# Bug Fixes and Other Changes + +* +* +* + +# Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +, , , , , + + +# Release 2.12.0 # Breaking Changes * * +* Build, Compilation and Packaging + + * Removal of redundant packages: the `tensorflow-gpu` and `tf-nightly-gpu` + packages have been effectively removed and replaced with packages that + direct users to switch to `tensorflow` or `tf-nightly` respectively. + The naming difference was the only difference between the two sets of + packages ever since TensorFlow 2.1, so there is no loss of functionality + or GPU support. See + https://pypi.org/project/tensorflow-gpu for more details. + +* `tf.function`: + + * tf.function now uses the Python inspect library directly for parsing + the signature of the Python function it is decorated on. + * This can break certain cases that were previously ignored where the + signature is malformed, e.g. + * Using functools.wraps on a function with different signature + * Using functools.partial with an invalid tf.function input + * tf.function now enforces input parameter names to be valid Python + identifiers. Incompatible names are automatically sanitized similarly to + existing SavedModel signature behavior. + * Parameterless tf.functions are assumed to have an empty input_signature + instead of an undefined one even if the input_signature is unspecified. + * tf.types.experimental.TraceType now requires an additional + `placeholder_value` method to be defined. + * tf.function now traces with placeholder values generated by TraceType + instead of the value itself. + +* `tf.config.experimental.enable_mlir_graph_optimization`: + + * Experimental API removed. + +* `tf.config.experimental.disable_mlir_graph_optimization`: + + * Experimental API removed. + +* `tf.keras` + + * Moved all saving-related utilities to a new namespace, `keras.saving`, + i.e. `keras.saving.load_model`, `keras.saving.save_model`, + `keras.saving.custom_object_scope`, `keras.saving.get_custom_objects`, + `keras.saving.register_keras_serializable`, + `keras.saving.get_registered_name` and + `keras.saving.get_registered_object`. + The previous API locations (in `keras.utils` and `keras.models`) will + stay available indefinitely, but we recommend that you update your code + to point to the new API locations. + * Improvements and fixes in Keras loss masking: + * Whether you represent a ragged tensor as a `tf.RaggedTensor` or using + [keras masking](https://www.tensorflow.org/guide/keras/masking_and_padding), + the returned loss values should be the identical to each other. 
+ In previous versions Keras may have silently ignored the mask. + * If you use masked losses with Keras the loss values may be different + in TensorFlow `2.12` compared to previous versions. + * In cases where the mask was previously ignored, you will now get + an error if you pass a mask with an incompatible shape. + +* `tf.SavedModel` + + * Introduce new class `tf.saved_model.experimental.Fingerprint` that + contains the fingerprint of the SavedModel. See the + [SavedModel Fingerprinting RFC](https://github.com/tensorflow/community/pull/415) + for details. + * Introduce API `tf.saved_model.experimental.read_fingerprint(export_dir)` + for reading the fingerprint of a SavedModel. + + # Known Caveats * @@ -25,13 +121,90 @@ * Add 16-bit float type support for built-in op `fill`. * Transpose now supports 6D tensors. + * Float LSTM now supports diagonal recurrent tensors: + https://arxiv.org/abs/1903.08023 * `tf.keras`: + * The new Keras model saving format (`.keras`) is available. You can start + using it via `model.save(f"{fname}.keras", save_format="keras_v3")`. In + the future it will become the default for all files with the `.keras` + extension. This file format targets the Python runtime only and makes + it possible to reload Python objects identical to the saved originals. + The format supports non-numerical state such as vocabulary files and + lookup tables, and it is easy to customize in the case of custom layers + with exotic elements of state (e.g. a FIFOQueue). The format + does not rely on bytecode or pickling, and is safe by default. Note + that as a result, Python `lambdas` are disallowed at loading time. If + you want to use `lambdas`, you can pass `safe_mode=False` to the loading + method (only do this if you trust the source of the model). + * Added a `model.export(filepath)` API to create a lightweight SavedModel + artifact that can be used for inference (e.g. with TF-Serving). + * Added `keras.export.ExportArchive` class for low-level customization of + the process of exporting SavedModel artifacts for inference. + Both ways of exporting models are based on `tf.function` tracing + and produce a TF program composed of TF ops. They are meant primarily + for environments where the TF runtime is available, + but not the Python interpreter, as is typical + for production with TF Serving. + * Added utility `tf.keras.utils.FeatureSpace`, a one-stop shop for + structured data preprocessing and encoding. * Added `tf.SparseTensor` input support to `tf.keras.layers.Embedding` layer. The layer now accepts a new boolean argument `sparse`. If `sparse` is set to True, the layer returns a SparseTensor instead of a dense Tensor. Defaults to False. + * Added `jit_compile` as a settable property to `tf.keras.Model`. + * Added `synchronized` optional parameter to `layers.BatchNormalization`. + * Added deprecation warning to + `layers.experimental.SyncBatchNormalization` and suggested to use + `layers.BatchNormalization` with `synchronized=True` instead. + * Updated `tf.keras.layers.BatchNormalization` to support masking of the + inputs (`mask` argument) when computing the mean and variance. + * Add `tf.keras.layers.Identity`, a placeholder pass-through layer. + * Add `show_trainable` option to `tf.keras.utils.model_to_dot` to display + layer trainable status in model plots. + * Add ability to save a `tf.keras.utils.FeatureSpace` object, via + `feature_space.save("myfeaturespace.keras")`, and reload it via + `feature_space = tf.keras.models.load_model("myfeaturespace.keras")`. 
+ * Added utility `tf.keras.utils.to_ordinal` to convert class vector to + ordinal regression / classification matrix. + +* `tf.experimental.dtensor`: + + * Coordination service now works with + `dtensor.initialize_accelerator_system`, and enabled by default. + * Add `tf.experimental.dtensor.is_dtensor` to check if a tensor is a + DTensor instance. + +* `tf.data`: + + * Added support for alternative checkpointing protocol which makes it + possible to checkpoint the state of the input pipeline without having to + store the contents of internal buffers. The new functionality can be + enabled through the `experimental_symbolic_checkpoint` option of + `tf.data.Options()`. + * Added a new `rerandomize_each_iteration` argument for the + `tf.data.Dataset.random()` operation, which controls whether the + sequence of generated random numbers should be re-randomized every epoch + or not (the default behavior). If `seed` is set and + `rerandomize_each_iteration=True`, the `random()` operation will produce + a different (deterministic) sequence of numbers every epoch. + * Added a new `rerandomize_each_iteration` argument for the + `tf.data.Dataset.sample_from_datasets()` operation, which controls + whether the sequence of generated random numbers used for sampling + should be re-randomized every epoch or not. If `seed` is set and + `rerandomize_each_iteration=True`, the `sample_from_datasets()` + operation will use a different (deterministic) sequence of numbers every + epoch. + +* `tf.test`: + + * Added `tf.test.experimental.sync_devices`, which is useful for + accurately measuring performance in benchmarks. + +* `tf.experimental.dtensor`: + + * Added experimental support to ReduceScatter fuse on GPU (NCCL). # Bug Fixes and Other Changes @@ -39,6 +212,29 @@ * * +* `tf.random` + * Added non-experimental aliases for `tf.random.split` and + `tf.random.fold_in`, the experimental endpoints are still available + so no code changes are necessary. +* `tf.experimental.ExtensionType` + * Added function `experimental.extension_type.as_dict()`, which converts an + instance of `tf.experimental.ExtensionType` to a `dict` representation. +* `stream_executor` + * Top level `stream_executor` directory has been deleted, users should use + equivalent headers and targets under `compiler/xla/stream_executor`. +* `tf.nn` + * Added `tf.nn.experimental.general_dropout`, which is similar to + `tf.random.experimental.stateless_dropout` but accepts a custom sampler + function. +* `tf.types.experimental.GenericFunction` + * The `experimental_get_compiler_ir` method supports tf.TensorSpec + compilation arguments. +* `tf.config.experimental.mlir_bridge_rollout` + * Removed enums `MLIR_BRIDGE_ROLLOUT_SAFE_MODE_ENABLED` and + `MLIR_BRIDGE_ROLLOUT_SAFE_MODE_FALLBACK_ENABLED` which are no longer used by + the tf2xla bridge + + # Thanks to our Contributors This release contains contributions from many people at Google, as well as: @@ -47,12 +243,6 @@ This release contains contributions from many people at Google, as well as: # Release 2.11.0 - - -* `StatusOr::ConsumeValueOrDie` and `StatusOr::ValueOrDie`, both deprecated in - TF 2.10 has been removed. - - ## Breaking Changes * `tf.keras.optimizers.Optimizer` now points to the new Keras optimizer, and old optimizers have moved to the `tf.keras.optimizers.legacy` namespace. @@ -106,12 +296,6 @@ This release contains contributions from many people at Google, as well as: only be implemented based on `tf.keras.optimizers.Optimizer`, the new base class. 
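The optimizer change in the 2.11 breaking changes above (the public `tf.keras.optimizers` names now resolve to the new optimizer classes, while the old implementations live under `tf.keras.optimizers.legacy`) is easiest to see side by side. A rough sketch, not part of the original release notes, using a trivial placeholder model:

```python
import tensorflow as tf

# Placeholder model, only to give the optimizers something to attach to.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])

# TF 2.11+: the public name resolves to the new Keras optimizer implementation.
new_opt = tf.keras.optimizers.Adam(learning_rate=1e-3)

# Code that still depends on pre-2.11 behaviour (for example, old optimizer
# checkpoints) can keep using the previous implementation explicitly.
legacy_opt = tf.keras.optimizers.legacy.Adam(learning_rate=1e-3)

model.compile(optimizer=legacy_opt, loss="mse")
```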
-## Known Caveats - -* -* -* - ## Major Features and Improvements * `tf.lite`: @@ -160,7 +344,7 @@ This release contains contributions from many people at Google, as well as: file is a protobuf containing the "fingerprint" of the SavedModel. See the [RFC](https://github.com/tensorflow/community/pull/415) for more details regarding its design and properties. - + * `tf.data`: * Graduated experimental APIs: * [`tf.data.Dataset.ragged_batch`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset/#ragged_batch), which batches elements of `tf.data.Dataset`s into `tf.RaggedTensor`s. @@ -185,11 +369,152 @@ This release contains contributions from many people at Google, as well as: * `tf.SparseTensor`: * Introduced `set_shape`, which sets the static dense shape of the sparse tensor and has the same semantics as `tf.Tensor.set_shape`. +## Security + +* TF is currently using giflib 5.2.1 which has [CVE-2022-28506](https://nvd.nist.gov/vuln/detail/CVE-2022-28506). TF is not affected by the CVE as it does not use `DumpScreen2RGB` at all. +* Fixes an OOB seg fault in `DynamicStitch` due to missing validation ([CVE-2022-41883](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41883)) +* Fixes an overflow in `tf.keras.losses.poisson` ([CVE-2022-41887](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41887)) +* Fixes a heap OOB failure in `ThreadUnsafeUnigramCandidateSampler` caused by missing validation ([CVE-2022-41880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41880)) +* Fixes a segfault in `ndarray_tensor_bridge` ([CVE-2022-41884](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41884)) +* Fixes an overflow in `FusedResizeAndPadConv2D` ([CVE-2022-41885](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41885)) +* Fixes a overflow in `ImageProjectiveTransformV2` ([CVE-2022-41886](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41886)) +* Fixes an FPE in `tf.image.generate_bounding_box_proposals` on GPU ([CVE-2022-41888](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41888)) +* Fixes a segfault in `pywrap_tfe_src` caused by invalid attributes ([CVE-2022-41889](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41889)) +* Fixes a `CHECK` fail in `BCast` ([CVE-2022-41890](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41890)) +* Fixes a segfault in `TensorListConcat` ([CVE-2022-41891](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41891)) +* Fixes a `CHECK_EQ` fail in `TensorListResize` ([CVE-2022-41893](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41893)) +* Fixes an overflow in `CONV_3D_TRANSPOSE` on TFLite ([CVE-2022-41894](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41894)) +* Fixes a heap OOB in `MirrorPadGrad` ([CVE-2022-41895](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41895)) +* Fixes a crash in `Mfcc` ([CVE-2022-41896](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41896)) +* Fixes a heap OOB in `FractionalMaxPoolGrad` ([CVE-2022-41897](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41897)) +* Fixes a `CHECK` fail in `SparseFillEmptyRowsGrad` ([CVE-2022-41898](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41898)) +* Fixes a `CHECK` fail in `SdcaOptimizer` ([CVE-2022-41899](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41899)) +* Fixes a heap OOB in `FractionalAvgPool` and `FractionalMaxPool`([CVE-2022-41900](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41900)) +* Fixes a `CHECK_EQ` in `SparseMatrixNNZ` 
([CVE-2022-41901](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41901)) +* Fixes an OOB write in grappler ([CVE-2022-41902](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41902)) +* Fixes a overflow in `ResizeNearestNeighborGrad` ([CVE-2022-41907](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41907)) +* Fixes a `CHECK` fail in `PyFunc` ([CVE-2022-41908](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41908)) +* Fixes a segfault in `CompositeTensorVariantToComponents` ([CVE-2022-41909](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41909)) +* Fixes a invalid char to bool conversion in printing a tensor ([CVE-2022-41911](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41911)) +* Fixes a heap overflow in `QuantizeAndDequantizeV2` ([CVE-2022-41910](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41910)) +* Fixes a `CHECK` failure in `SobolSample` via missing validation ([CVE-2022-35935](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-35935)) +* Fixes a `CHECK` fail in `TensorListScatter` and `TensorListScatterV2` in eager mode ([CVE-2022-35935](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-35935)) + ## Thanks to our Contributors This release contains contributions from many people at Google, as well as: -, , , , , +103yiran, 8bitmp3, Aakar Dwivedi, Alexander Grund, alif_elham, Aman Agarwal, +amoitra, Andrei Ivanov, andreii, Andrew Goodbody, angerson, Ashay Rane, +Azeem Shaikh, Ben Barsdell, bhack, Bhavani Subramanian, Cedric Nugteren, +Chandra Kumar Ramasamy, Christopher Bate, CohenAriel, Cotarou, cramasam, +Enrico Minack, Francisco Unda, Frederic Bastien, gadagashwini, Gauri1 Deshpande, +george, Jake, Jeff, Jerry Ge, Jingxuan He, Jojimon Varghese, Jonathan Dekhtiar, +Kaixi Hou, Kanvi Khanna, kcoul, Keith Smiley, Kevin Hu, Kun Lu, kushanam, +Lianmin Zheng, liuyuanqiang, Louis Sugy, Mahmoud Abuzaina, Marius Brehler, +mdfaijul, Meenakshi Venkataraman, Milos Puzovic, mohantym, Namrata-Ibm, +Nathan John Sircombe, Nathan Luehr, Olaf Lipinski, Om Thakkar, Osman F Bayram, +Patrice Vignola, Pavani Majety, Philipp Hack, Prianka Liz Kariat, Rahul Batra, +RajeshT, Renato Golin, riestere, Roger Iyengar, Rohit Santhanam, Rsanthanam-Amd, +Sadeed Pv, Samuel Marks, Shimokawa, Naoaki, Siddhesh Kothadi, Simengliu-Nv, +Sindre Seppola, snadampal, Srinivasan Narayanamoorthy, sushreebarsa, +syedshahbaaz, Tamas Bela Feher, Tatwai Chong, Thibaut Goetghebuer-Planchon, +tilakrayal, Tom Anderson, Tomohiro Endo, Trevor Morris, vibhutisawant, +Victor Zhang, Vremold, Xavier Bonaventura, Yanming Wang, Yasir Modak, +Yimei Sun, Yong Tang, Yulv-Git, zhuoran.liu, zotanika + +# Release 2.10.1 + +This release introduces several vulnerability fixes: + +* Fixes an OOB seg fault in `DynamicStitch` due to missing validation ([CVE-2022-41883](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41883)) +* Fixes an overflow in `tf.keras.losses.poisson` ([CVE-2022-41887](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41887)) +* Fixes a heap OOB failure in `ThreadUnsafeUnigramCandidateSampler` caused by missing validation ([CVE-2022-41880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41880)) +* Fixes a segfault in `ndarray_tensor_bridge` ([CVE-2022-41884](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41884)) +* Fixes an overflow in `FusedResizeAndPadConv2D` ([CVE-2022-41885](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41885)) +* Fixes a overflow in `ImageProjectiveTransformV2` 
([CVE-2022-41886](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41886)) +* Fixes an FPE in `tf.image.generate_bounding_box_proposals` on GPU ([CVE-2022-41888](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41888)) +* Fixes a segfault in `pywrap_tfe_src` caused by invalid attributes ([CVE-2022-41889](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41889)) +* Fixes a `CHECK` fail in `BCast` ([CVE-2022-41890](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41890)) +* Fixes a segfault in `TensorListConcat` ([CVE-2022-41891](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41891)) +* Fixes a `CHECK_EQ` fail in `TensorListResize` ([CVE-2022-41893](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41893)) +* Fixes an overflow in `CONV_3D_TRANSPOSE` on TFLite ([CVE-2022-41894](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41894)) +* Fixes a heap OOB in `MirrorPadGrad` ([CVE-2022-41895](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41895)) +* Fixes a crash in `Mfcc` ([CVE-2022-41896](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41896)) +* Fixes a heap OOB in `FractionalMaxPoolGrad` ([CVE-2022-41897](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41897)) +* Fixes a `CHECK` fail in `SparseFillEmptyRowsGrad` ([CVE-2022-41898](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41898)) +* Fixes a `CHECK` fail in `SdcaOptimizer` ([CVE-2022-41899](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41899)) +* Fixes a heap OOB in `FractionalAvgPool` and `FractionalMaxPool`([CVE-2022-41900](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41900)) +* Fixes a `CHECK_EQ` in `SparseMatrixNNZ` ([CVE-2022-41901](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41901)) +* Fixes an OOB write in grappler ([CVE-2022-41902](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41902)) +* Fixes a overflow in `ResizeNearestNeighborGrad` ([CVE-2022-41907](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41907)) +* Fixes a `CHECK` fail in `PyFunc` ([CVE-2022-41908](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41908)) +* Fixes a segfault in `CompositeTensorVariantToComponents` ([CVE-2022-41909](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41909)) +* Fixes a invalid char to bool conversion in printing a tensor ([CVE-2022-41911](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41911)) +* Fixes a heap overflow in `QuantizeAndDequantizeV2` ([CVE-2022-41910](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41910)) +* Fixes a `CHECK` failure in `SobolSample` via missing validation ([CVE-2022-35935](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-35935)) +* Fixes a `CHECK` fail in `TensorListScatter` and `TensorListScatterV2` in eager mode ([CVE-2022-35935](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-35935)) + +# Release 2.9.3 + +This release introduces several vulnerability fixes: + +* Fixes an overflow in `tf.keras.losses.poisson` ([CVE-2022-41887](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41887)) +* Fixes a heap OOB failure in `ThreadUnsafeUnigramCandidateSampler` caused by missing validation ([CVE-2022-41880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41880)) +* Fixes a segfault in `ndarray_tensor_bridge` ([CVE-2022-41884](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41884)) +* Fixes an overflow in `FusedResizeAndPadConv2D` ([CVE-2022-41885](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41885)) +* Fixes a 
overflow in `ImageProjectiveTransformV2` ([CVE-2022-41886](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41886)) +* Fixes an FPE in `tf.image.generate_bounding_box_proposals` on GPU ([CVE-2022-41888](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41888)) +* Fixes a segfault in `pywrap_tfe_src` caused by invalid attributes ([CVE-2022-41889](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41889)) +* Fixes a `CHECK` fail in `BCast` ([CVE-2022-41890](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41890)) +* Fixes a segfault in `TensorListConcat` ([CVE-2022-41891](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41891)) +* Fixes a `CHECK_EQ` fail in `TensorListResize` ([CVE-2022-41893](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41893)) +* Fixes an overflow in `CONV_3D_TRANSPOSE` on TFLite ([CVE-2022-41894](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41894)) +* Fixes a heap OOB in `MirrorPadGrad` ([CVE-2022-41895](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41895)) +* Fixes a crash in `Mfcc` ([CVE-2022-41896](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41896)) +* Fixes a heap OOB in `FractionalMaxPoolGrad` ([CVE-2022-41897](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41897)) +* Fixes a `CHECK` fail in `SparseFillEmptyRowsGrad` ([CVE-2022-41898](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41898)) +* Fixes a `CHECK` fail in `SdcaOptimizer` ([CVE-2022-41899](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41899)) +* Fixes a heap OOB in `FractionalAvgPool` and `FractionalMaxPool`([CVE-2022-41900](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41900)) +* Fixes a `CHECK_EQ` in `SparseMatrixNNZ` ([CVE-2022-41901](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41901)) +* Fixes an OOB write in grappler ([CVE-2022-41902](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41902)) +* Fixes a overflow in `ResizeNearestNeighborGrad` ([CVE-2022-41907](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41907)) +* Fixes a `CHECK` fail in `PyFunc` ([CVE-2022-41908](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41908)) +* Fixes a segfault in `CompositeTensorVariantToComponents` ([CVE-2022-41909](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41909)) +* Fixes a invalid char to bool conversion in printing a tensor ([CVE-2022-41911](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41911)) +* Fixes a heap overflow in `QuantizeAndDequantizeV2` ([CVE-2022-41910](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41910)) +* Fixes a `CHECK` failure in `SobolSample` via missing validation ([CVE-2022-35935](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-35935)) +* Fixes a `CHECK` fail in `TensorListScatter` and `TensorListScatterV2` in eager mode ([CVE-2022-35935](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-35935)) + +# Release 2.8.4 + +This release introduces several vulnerability fixes: + +* Fixes a heap OOB failure in `ThreadUnsafeUnigramCandidateSampler` caused by missing validation ([CVE-2022-41880](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41880)) +* Fixes a segfault in `ndarray_tensor_bridge` ([CVE-2022-41884](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41884)) +* Fixes an overflow in `FusedResizeAndPadConv2D` ([CVE-2022-41885](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41885)) +* Fixes a overflow in `ImageProjectiveTransformV2` 
([CVE-2022-41886](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41886)) +* Fixes an FPE in `tf.image.generate_bounding_box_proposals` on GPU ([CVE-2022-41888](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41888)) +* Fixes a segfault in `pywrap_tfe_src` caused by invalid attributes ([CVE-2022-41889](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41889)) +* Fixes a `CHECK` fail in `BCast` ([CVE-2022-41890](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41890)) +* Fixes a segfault in `TensorListConcat` ([CVE-2022-41891](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41891)) +* Fixes a `CHECK_EQ` fail in `TensorListResize` ([CVE-2022-41893](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41893)) +* Fixes an overflow in `CONV_3D_TRANSPOSE` on TFLite ([CVE-2022-41894](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41894)) +* Fixes a heap OOB in `MirrorPadGrad` ([CVE-2022-41895](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41895)) +* Fixes a crash in `Mfcc` ([CVE-2022-41896](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41896)) +* Fixes a heap OOB in `FractionalMaxPoolGrad` ([CVE-2022-41897](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41897)) +* Fixes a `CHECK` fail in `SparseFillEmptyRowsGrad` ([CVE-2022-41898](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41898)) +* Fixes a `CHECK` fail in `SdcaOptimizer` ([CVE-2022-41899](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41899)) +* Fixes a heap OOB in `FractionalAvgPool` and `FractionalMaxPool`([CVE-2022-41900](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41900)) +* Fixes a `CHECK_EQ` in `SparseMatrixNNZ` ([CVE-2022-41901](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41901)) +* Fixes an OOB write in grappler ([CVE-2022-41902](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41902)) +* Fixes a overflow in `ResizeNearestNeighborGrad` ([CVE-2022-41907](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41907)) +* Fixes a `CHECK` fail in `PyFunc` ([CVE-2022-41908](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41908)) +* Fixes a segfault in `CompositeTensorVariantToComponents` ([CVE-2022-41909](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41909)) +* Fixes a invalid char to bool conversion in printing a tensor ([CVE-2022-41911](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41911)) +* Fixes a heap overflow in `QuantizeAndDequantizeV2` ([CVE-2022-41910](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-41910)) +* Fixes a `CHECK` failure in `SobolSample` via missing validation ([CVE-2022-35935](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-35935)) +* Fixes a `CHECK` fail in `TensorListScatter` and `TensorListScatterV2` in eager mode ([CVE-2022-35935](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2022-35935)) # Release 2.10.0 @@ -10654,3 +10979,5 @@ answered questions, and were part of inspiring discussions. # Release 0.5.0 Initial release of TensorFlow. + + diff --git a/SECURITY.md b/SECURITY.md index f6d414794c0..d6d47c4e635 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -273,21 +273,11 @@ TensorFlow uses the following disclosure process: * An advisory is prepared (but not published) which details the problem and steps for mitigation. * The vulnerability is fixed and potential workarounds are identified. 
-* We will attempt to cherry-pick the fix to the release branches used for all - releases of TensorFlow that are at most one year old (though sometimes we - might not patch all of them). The cherry-picks will occur as soon as possible - and the patch releases will come at the same time as the next quarterly - release. -* Whenever patch releases are finalized, we will notify discuss@tensorflow.org. * We will publish a security advisory for all fixed vulnerabilities. For each vulnerability, we try to ingress it as soon as possible, given the size of the team and the number of reports. Vulnerabilities will, in general, be -batched to be fixed at the same time as a quarterly release. An exception to -this rule is for high impact vulnerabilities where exploitation of models used -for inference in products (i.e., not models created just to showcase a -vulnerability) is possible. In these cases, we will attempt to do patch releases -within an accelerated timeline, not waiting for the next quarterly release. +batched to be fixed at the same time as a quarterly release. Past security advisories are listed [here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/README.md). diff --git a/configure.py b/configure.py index 135001ed103..6abde63a28a 100644 --- a/configure.py +++ b/configure.py @@ -36,7 +36,7 @@ _DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,7.0' _SUPPORTED_ANDROID_NDK_VERSIONS = [ - 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21 + 19, 20, 21 ] _DEFAULT_PROMPT_ASK_ATTEMPTS = 10 @@ -619,7 +619,7 @@ def prompt_loop_or_load_from_env(environ_cp, 'Assuming to be a scripting mistake.' % (var_name, n_ask_attempts)) - if resolve_symlinks and os.path.islink(val): + if resolve_symlinks: val = os.path.realpath(val) environ_cp[var_name] = val return val @@ -718,7 +718,8 @@ def valid_build_tools(version): def get_ndk_api_level(environ_cp, android_ndk_home_path): - """Gets the appropriate NDK API level to use for the provided Android NDK path.""" + """Gets the appropriate NDK API level to use for the provided Android NDK path. + """ # First check to see if we're using a blessed version of the NDK. properties_path = '%s/source.properties' % android_ndk_home_path @@ -756,7 +757,7 @@ def valid_api_level(api_level): android_ndk_api_level = prompt_loop_or_load_from_env( environ_cp, var_name='ANDROID_NDK_API_LEVEL', - var_default='21', # 21 is required for ARM64 support. + var_default='26', # 26 is required to support AHardwareBuffer. ask_for_var=('Please specify the (min) Android NDK API level to use. ' '[Available levels: %s]') % api_levels, check_success=valid_api_level, @@ -1188,6 +1189,9 @@ def main(): gcc_env = get_gcc_compiler(environ_cp) if gcc_env is not None: + # Use gold linker if 'gcc' and if 'ppc64le' + write_to_bazelrc('build --linkopt="-fuse-ld=gold"') + # Get the linker version ld_version = run_shell([gcc_env, '-Wl,-version']).split() @@ -1215,8 +1219,6 @@ def main(): if (environ_cp.get('TF_NEED_ROCM') == '1' and environ_cp.get('ROCM_PATH')): write_action_env_to_bazelrc('ROCM_PATH', environ_cp.get('ROCM_PATH')) - write_action_env_to_bazelrc('ROCBLAS_TENSILE_LIBPATH', - environ_cp.get('ROCM_PATH') + '/lib/library') if (environ_cp.get('TF_NEED_ROCM') == '1' and environ_cp.get('HIP_PLATFORM')): write_action_env_to_bazelrc('HIP_PLATFORM', environ_cp.get('HIP_PLATFORM')) diff --git a/fuzztest.bazelrc b/fuzztest.bazelrc new file mode 100644 index 00000000000..360b3484ee9 --- /dev/null +++ b/fuzztest.bazelrc @@ -0,0 +1,47 @@ +### DO NOT EDIT. Generated file. 
+# +# To regenerate, run the following from your project's workspace: +# +# bazel run @com_google_fuzztest//bazel:setup_configs > fuzztest.bazelrc +# +# And don't forget to add the following to your project's .bazelrc: +# +# try-import %workspace%/fuzztest.bazelrc + + +### Common options. +# +# Do not use directly. + +# Link with Address Sanitizer (ASAN). +build:fuzztest-common --linkopt=-fsanitize=address + +# Standard define for "ifdef-ing" any fuzz test specific code. +build:fuzztest-common --copt=-DFUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION + +# In fuzz tests, we want to catch assertion violations even in optimized builds. +build:fuzztest-common --copt=-UNDEBUG + +# Enable libc++ assertions. +# See https://libcxx.llvm.org/UsingLibcxx.html#enabling-the-safe-libc-mode +build:fuzztest-common --copt=-D_LIBCPP_ENABLE_ASSERTIONS=1 + + +### FuzzTest build configuration. +# +# Use with: --config=fuzztest + +build:fuzztest --config=fuzztest-common + +# Link statically. +build:fuzztest --dynamic_mode=off + +# We rely on the following flag instead of the compiler provided +# __has_feature(address_sanitizer) to know that we have an ASAN build even in +# the uninstrumented runtime. +build:fuzztest --copt=-DADDRESS_SANITIZER + +# We apply coverage tracking and ASAN instrumentation to everything but the +# FuzzTest framework itself (including GoogleTest and GoogleMock). +build:fuzztest --per_file_copt=+//,-//fuzztest:,-googletest/.*,-googlemock/.*@-fsanitize=address,-fsanitize-coverage=inline-8bit-counters,-fsanitize-coverage=trace-cmp + diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 46879082f93..0d27a8294f5 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -8,16 +8,26 @@ load( "//tensorflow:tensorflow.bzl", "VERSION", "VERSION_MAJOR", + "check_deps", "if_google", "if_oss", + "if_xla_available", "tf_cc_shared_object", "tf_custom_op_library_additional_deps_impl", + "tf_monitoring_python_deps", "tf_native_cc_binary", + "tsl_async_value_deps", ) load( "//tensorflow/core/platform:build_config.bzl", "tf_additional_binary_deps", ) +load( + "//tensorflow/core/platform:build_config_root.bzl", + "if_static", + "tf_additional_plugin_deps", + "tf_additional_profiler_deps", +) load( "//third_party/mkl:build_defs.bzl", "if_mkl_ml", @@ -28,6 +38,7 @@ load( "ADDITIONAL_API_INDEXABLE_SETTINGS", "tf_cc_shared_library", ) +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") # copybara:uncomment_begin # load("//tools/build_defs/license:license.bzl", "license") @@ -95,12 +106,25 @@ PACKAGE_STATIC_DEPS = [ "@mkl_dnn_acl_compatible//:__subpackages__", "@mkl_dnn_v1//:__subpackages__", "@nccl_archive//:__subpackages__", + "@nvtx_archive//:__subpackages__", "@org_sqlite//:__subpackages__", "@platforms//:__subpackages__", "@snappy//:__subpackages__", "@upb//:__subpackages__", "@zlib//:__subpackages__", -] + "@dlpack//:__subpackages__", + "@arm_neon_2_x86_sse//:__subpackages__", + "@cpuinfo//:__subpackages__", + "@ruy//:__subpackages__", + "@XNNPACK//:__subpackages__", + "@pthreadpool//:__subpackages__", + "@FXdiv//:__subpackages__", + "@FP16//:__subpackages__", + "@clog//:__subpackages__", + "@flatbuffers//:__subpackages__", + "@nccl_archive//:__subpackages__", + "@triton//:__subpackages__", +] + tsl_async_value_deps() package( # copybara:uncomment default_applicable_licenses = [":license"], @@ -918,23 +942,21 @@ config_setting( visibility = ["//visibility:public"], ) -# copybara:uncomment_begin(configurable API loading) -# bool_flag( -# name = "enable_api_indexable", -# build_setting_default = False, -# ) 
-# -# config_setting( -# name = "api_indexable_flag", -# flag_values = {":enable_api_indexable": "True"}, -# ) -# -# selects.config_setting_group( -# name = "api_indexable", -# match_any = [":api_indexable_flag"] + ADDITIONAL_API_INDEXABLE_SETTINGS, -# visibility = ["//visibility:public"], -# ) -# copybara:uncomment_end +bool_flag( + name = "enable_api_indexable", + build_setting_default = False, +) + +config_setting( + name = "api_indexable_flag", + flag_values = {":enable_api_indexable": "True"}, +) + +selects.config_setting_group( + name = "api_indexable", + match_any = [":api_indexable_flag"] + ADDITIONAL_API_INDEXABLE_SETTINGS, + visibility = ["//visibility:public"], +) # DO NOT ADD ANY NEW EXCEPTIONS TO THIS LIST! # Instead, please use public APIs or public build rules TF provides. @@ -949,6 +971,8 @@ package_group( "//learning/brain/tfrt/...", "//learning/lib/ami/simple_ml/...", "//learning/pathways/...", + "//learning/serving/contrib/tfrt/mlir/canonical_ops/...", + "//perftools/accelerators/xprof/integration_tests/...", "//smartass/brain/configure/...", "//tensorflow/...", "//tensorflow_decision_forests/...", @@ -967,12 +991,12 @@ package_group(name = "ndarray_tensor_allow_list") # Packages that use private types symbols, until they are exported. # TODO(b/154650521) Remove. # If this is modified, then copy.bara.sky must also be modified. -package_group(name = "types_whitelist") +package_group(name = "types_allowlist") # Packages that use StructuredTensors. # TODO(b/159007891) Remove this package once StructuredTensor is exported. # LINT.IfChange -package_group(name = "structured_tensor_whitelist") +package_group(name = "structured_tensor_allowlist") # LINT.ThenChange(copy.bara.sky) filegroup( @@ -1081,28 +1105,38 @@ tf_cc_shared_library( linkstatic = 1, per_os_targets = True, roots = [ - "//tensorflow/c/experimental/filesystem:filesystem_interface", - "//tensorflow/c/experimental/stream_executor:stream_executor", - "//tensorflow/c:env", - "//tensorflow/c:kernels", - "//tensorflow/c:kernels_experimental", - "//tensorflow/c:logging", - "//tensorflow/c:ops", - "//tensorflow/cc/saved_model:fingerprinting_impl", - "//tensorflow/cc/saved_model:loader_lite_impl", - "//tensorflow/cc/saved_model:metrics_impl", - "//tensorflow/compiler/tf2tensorrt:op_converter_registry_impl", - "//tensorflow/core/common_runtime:core_cpu_impl", - "//tensorflow/core:framework_internal_impl", - "//tensorflow/core/common_runtime/gpu:gpu_runtime_impl", - "//tensorflow/core/common_runtime/pluggable_device:pluggable_device_runtime_impl", - "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl", - "//tensorflow/core:lib_internal_impl", - "//tensorflow/core/profiler:profiler_impl", - "//tensorflow/core/util:determinism", # Must be linked and exported to libtensorflow_framework.so. 
- "//tensorflow/lite/kernels/shim:tf_kernel_shim", - "//tensorflow/compiler/xla/stream_executor:stream_executor_impl", - ] + tf_additional_binary_deps(), + "//tensorflow/c/experimental/filesystem:filesystem_interface", + "//tensorflow/c/experimental/stream_executor:stream_executor", + "//tensorflow/c:env", + "//tensorflow/c:kernels", + "//tensorflow/c:kernels_experimental", + "//tensorflow/c:logging", + "//tensorflow/c:ops", + "//tensorflow/cc/saved_model:fingerprinting_impl", + "//tensorflow/cc/saved_model:loader_lite_impl", + "//tensorflow/cc/saved_model:metrics_impl", + "//tensorflow/compiler/tf2tensorrt:op_converter_registry_impl", + "//tensorflow/core/common_runtime:core_cpu_impl", + "//tensorflow/core/common_runtime/gpu:gpu_runtime_impl", + "//tensorflow/core/common_runtime/pluggable_device:pluggable_device_runtime_impl", + "//tensorflow/core:framework_internal_impl", + "//tensorflow/core/framework:tensor", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl", + "//tensorflow/core:lib_internal_impl", + "//tensorflow/core/profiler:profiler_impl", + "//tensorflow/core/util:determinism", # Must be linked and exported to libtensorflow_framework.so. + "//tensorflow/lite/kernels/shim:tf_kernel_shim", + "//tensorflow/compiler/xla/stream_executor:stream_executor_impl", + "//tensorflow/tsl/framework:bfc_allocator", + "//tensorflow/tsl/framework:metrics", + ] + tf_additional_binary_deps() + + # TODO(b/259305727): Remove this select and include captured_function in macos builds. + select({ + "//tensorflow:macos": [], + "//conditions:default": [ + "//tensorflow/core/data:captured_function", + ], + }), soversion = VERSION, static_deps = PACKAGE_STATIC_DEPS, visibility = ["//visibility:public"], @@ -1193,6 +1227,9 @@ tf_cc_shared_library( "//tensorflow:macos": ["//tensorflow:libtensorflow_framework.%s.dylib" % VERSION], "//conditions:default": ["//tensorflow:libtensorflow_framework.so.%s" % VERSION], }), + exports_filter = [ + "//:__subpackages__", + ], framework_so = [], linkopts = select({ "//tensorflow:macos": [ @@ -1206,41 +1243,168 @@ tf_cc_shared_library( }), per_os_targets = True, roots = [ + "//tensorflow/c:c_api", + "//tensorflow/c/eager:c_api", "//tensorflow/cc:cc_ops", "//tensorflow/cc:client_session", - "//tensorflow/cc:const_op", "//tensorflow/cc:scope", - ], + "//tensorflow/core:tensorflow", + "//tensorflow/core/data:standalone", + # Exports for pywrap_tensorflow_internal. Many of these are transitive + # depedencies of the above, but must be explicitly listed for + # cc_shared_library to work. 
+ "//tensorflow/c/eager:c_api_experimental", + "//tensorflow/c/eager:c_api_internal", + "//tensorflow/c/eager:dlpack", + "//tensorflow/c/eager:tape", + "//tensorflow/c/eager:tfe_context_internal", + "//tensorflow/c/eager:tfe_op_internal", + "//tensorflow/c/eager:tfe_tensorhandle_internal", + "//tensorflow/c/experimental/gradients", + "//tensorflow/c/experimental/gradients/tape", + "//tensorflow/c/experimental/ops", + "//tensorflow/c:c_api_experimental", + "//tensorflow/c:c_api_internal", + "//tensorflow/c:c_api_no_xla", + "//tensorflow/c:checkpoint_reader", + "//tensorflow/c:tensor_interface", + "//tensorflow/c:tf_status_helper", + "//tensorflow/c:tf_tensor_internal", + "//tensorflow/cc/saved_model:loader", + "//tensorflow/compiler/mlir/lite/metrics:error_collector", + "//tensorflow/compiler/mlir/lite/python:flatbuffer_to_mlir", + "//tensorflow/compiler/mlir/lite/python:graphdef_to_tfl_flatbuffer", + "//tensorflow/compiler/mlir/lite/python:jax_to_tfl_flatbuffer", + "//tensorflow/compiler/mlir/lite/python:saved_model_to_tfl_flatbuffer", + "//tensorflow/compiler/mlir/lite/quantization/lite:quantize_model", + "//tensorflow/compiler/mlir/lite/quantization:quantization_config", + "//tensorflow/compiler/mlir/lite/sparsity:sparsify_model", + "//tensorflow/compiler/mlir/python:mlir", + "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:custom_aggregator_op", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:quantize_model_cc_impl", + "//tensorflow/compiler/mlir/quantization/tensorflow:passes", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", + "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/mlir/tensorflow:export_graphdef", + "//tensorflow/compiler/mlir/tensorflow:mlir_import_options", + "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", + "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes", + "//tensorflow/compiler/mlir/tensorflow:tf_saved_model_passes", + "//tensorflow/compiler/mlir/tensorflow:translate_lib", + "//tensorflow/compiler/xla/service:computation_placer", + "//tensorflow/core", + "//tensorflow/core/common_runtime/eager:context", + "//tensorflow/core/common_runtime/eager:tensor_handle", + "//tensorflow/core/config:flag_defs", + "//tensorflow/core/config:flags", + "//tensorflow/core/data/service:dispatcher_client", + "//tensorflow/core/data/service:grpc_util", + "//tensorflow/core/data/service:py_utils", + "//tensorflow/core/data/service:server_lib", + "//tensorflow/core/debug", + "//tensorflow/core/distributed_runtime:server_lib", + "//tensorflow/core/function/runtime_client:runtime_client_cc", + "//tensorflow/core/grappler/clusters:cluster", + "//tensorflow/core/grappler/clusters:single_machine", + "//tensorflow/core/grappler/clusters:virtual_cluster", + "//tensorflow/core/grappler/costs:graph_memory", + "//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool", + "//tensorflow/core/grappler/optimizers:meta_optimizer", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:grappler_item_builder", + "//tensorflow/core/kernels:data_service_ops", + "//tensorflow/core/kernels:dataset_ops", + "//tensorflow/core/platform:logging", + "//tensorflow/core/platform:path", + "//tensorflow/core/platform:stacktrace_handler", + "//tensorflow/core/platform:statusor", + "//tensorflow/core/platform:stringpiece", + "//tensorflow/core/platform:types", + 
"//tensorflow/core/profiler/internal:print_model_analysis", + "//tensorflow/core/profiler/lib:traceme", + "//tensorflow/core/profiler/rpc/client:profiler_client_impl", + "//tensorflow/core/profiler/rpc:profiler_server_impl", + "//tensorflow/core/util:managed_stack_trace", + "//tensorflow/core:all_kernels", + "//tensorflow/core:core_cpu", + "//tensorflow/core:direct_session", + "//tensorflow/core:framework_internal", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:ops", + "//tensorflow/core:reader_base", + "//tensorflow/core:script_ops_op_lib", + "//tensorflow/distribute/experimental/rpc/kernels:rpc_ops", + "//tensorflow/dtensor/cc:dtensor_device_cc", + "//tensorflow/dtensor/cc:tensor_layout", + "//tensorflow/lite/c:common", + "//tensorflow/lite/core/api", + "//tensorflow/lite/delegates/flex:delegate", + "//tensorflow/lite/kernels/internal:compatibility", + "//tensorflow/lite/kernels:builtin_ops", + "//tensorflow/lite/kernels:reference_ops", + "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/toco/logging:conversion_log_util", + "//tensorflow/lite/toco/logging:toco_conversion_log_proto_cc", + "//tensorflow/lite/toco:model_flags_proto_cc", + "//tensorflow/lite/toco:toco_convert", + "//tensorflow/lite/toco:toco_flags_proto_cc", + "//tensorflow/lite/toco:toco_graphviz_dump_options", + "//tensorflow/lite/toco:toco_port", + "//tensorflow/lite/toco:toco_tooling", + "//tensorflow/lite/toco:tooling_util", + "//tensorflow/lite/toco:types_proto_cc", + "//tensorflow/lite:framework", + "//tensorflow/lite:shared_library", + "//tensorflow/lite:stateful_error_reporter", + "//tensorflow/lite:string_util", + "//tensorflow/lite:util", + "//tensorflow/python/grappler:cost_analyzer_lib", + "//tensorflow/tools/graph_transforms:transform_graph_lib", + ] + (tf_monitoring_python_deps() + + tf_additional_plugin_deps() + + tf_additional_profiler_deps()) + if_xla_available([ + "//tensorflow/compiler/aot:tfcompile_lib", + ]) + if_static(extra_deps = [ + "//tensorflow/core/platform:tensor_float_32_utils", + "//tensorflow/core/platform:enable_tf2_utils", + ]) + if_oss([ + "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", + "//tensorflow/core/distributed_runtime/rpc:grpc_session", + ]), soversion = VERSION, static_deps = PACKAGE_STATIC_DEPS, visibility = ["//visibility:public"], win_def_file = ":tensorflow_filtered_def_file", - deps = [ - "//tensorflow/c:c_api", - "//tensorflow/c:env", - "//tensorflow/c:kernels", - "//tensorflow/c:kernels_experimental", - "//tensorflow/c:logging", - "//tensorflow/c:ops", - "//tensorflow/c/eager:c_api", - "//tensorflow/c/experimental/filesystem:filesystem_interface", - "//tensorflow/c/experimental/stream_executor:stream_executor", - "//tensorflow/cc/saved_model:fingerprinting_impl", - "//tensorflow/cc/saved_model:loader_lite_impl", - "//tensorflow/cc/saved_model:metrics_impl", - "//tensorflow/core:framework_internal_impl", - "//tensorflow/core:lib_internal_impl", - "//tensorflow/core:tensorflow", - "//tensorflow/core/data:standalone", - "//tensorflow/core/common_runtime:core_cpu_impl", - "//tensorflow/core/common_runtime/gpu:gpu_runtime_impl", - "//tensorflow/core/common_runtime/pluggable_device:pluggable_device_runtime_impl", - "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl", - "//tensorflow/core/profiler:profiler_impl", - "//tensorflow/core/util:determinism", - "//tensorflow/lite/kernels/shim:tf_kernel_shim", - 
"//tensorflow/compiler/xla/stream_executor:stream_executor_impl", - ] + tf_additional_binary_deps(), +) + +# To avoid duplication, check that the C++ or python library does not depend on +# the stream executor cuda plugins. Targets that want to use cuda APIs should +# instead depend on the dummy plugins in //tensorflow/tsl/platform/default/build_config +# and use header only targets. +# TODO(ddunleavy): This seems completely broken. :tensorflow_cc depends on +# cuda_platform from tf_additional_binary_deps and this doesn't break. +check_deps( + name = "cuda_plugins_check_deps", + disallowed_deps = if_static( + [], + otherwise = [ + "//tensorflow/compiler/xla/stream_executor/cuda:all_runtime", + "//tensorflow/compiler/xla/stream_executor/cuda:cuda_driver", + "//tensorflow/compiler/xla/stream_executor/cuda:cuda_platform", + "//tensorflow/compiler/xla/stream_executor/cuda:cudnn_plugin", + "//tensorflow/compiler/xla/stream_executor/cuda:cufft_plugin", + "//tensorflow/compiler/xla/stream_executor/cuda:curand_plugin", + "//tensorflow/compiler/xla/stream_executor:cuda_platform", + ], + ), + deps = if_cuda([ + "//tensorflow:tensorflow_cc", + "//tensorflow/python:pywrap_tensorflow_internal", + ]), ) # ** Targets for Windows build (start) ** @@ -1344,7 +1508,7 @@ genrule( "//tensorflow/c/eager:headers", "//tensorflow/cc:headers", "//tensorflow/core:headers", - "//tensorflow/stream_executor:stream_executor_install_hdrs", + "//tensorflow/compiler/xla/stream_executor:stream_executor_install_hdrs", ], outs = ["include"], cmd = """ diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index 3bb0bb91ba6..cd3cbac7a96 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -103,8 +103,6 @@ # Load all plugin libraries from site-packages/tensorflow-plugins if we are # running under pip. -# TODO(gunan): Enable setting an environment variable to define arbitrary plugin -# directories. # TODO(gunan): Find a better location for this code snippet. from tensorflow.python.framework import load_library as _ll from tensorflow.python.lib.io import file_io as _fi @@ -146,6 +144,11 @@ def _running_from_pip_package(): # Load Pluggable Device Library _ll.load_pluggable_device_library(_plugin_dir) +if _os.getenv("TF_PLUGGABLE_DEVICE_LIBRARY_PATH", ""): + _ll.load_pluggable_device_library( + _os.getenv("TF_PLUGGABLE_DEVICE_LIBRARY_PATH") + ) + # Add module aliases if hasattr(_current_module, 'keras'): # It is possible that keras is a lazily loaded module, which might break when diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py index f11fedce109..6c42fea562f 100644 --- a/tensorflow/api_template_v1.__init__.py +++ b/tensorflow/api_template_v1.__init__.py @@ -145,8 +145,6 @@ # Load all plugin libraries from site-packages/tensorflow-plugins if we are # running under pip. -# TODO(gunan): Enable setting an environment variable to define arbitrary plugin -# directories. # TODO(gunan): Find a better location for this code snippet. from tensorflow.python.framework import load_library as _ll from tensorflow.python.lib.io import file_io as _fi @@ -187,6 +185,11 @@ def _running_from_pip_package(): # Load Pluggable Device Library _ll.load_pluggable_device_library(_plugin_dir) +if _os.getenv("TF_PLUGGABLE_DEVICE_LIBRARY_PATH", ""): + _ll.load_pluggable_device_library( + _os.getenv("TF_PLUGGABLE_DEVICE_LIBRARY_PATH") + ) + # Explicitly import lazy-loaded modules to support autocompletion. 
# pylint: disable=g-import-not-at-top if _typing.TYPE_CHECKING: diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index dbd90e1d01f..3c1568b7091 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -2,7 +2,7 @@ # C API for TensorFlow, for use by client language bindings. load("@bazel_skylib//lib:selects.bzl", "selects") -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") +load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", "check_deps", @@ -18,6 +18,7 @@ load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt") load("//tensorflow:tensorflow.default.bzl", "filegroup", "tf_cuda_cc_test") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) @@ -39,6 +40,7 @@ filegroup( "tf_tensor.h", "tf_tstring.h", "//tensorflow/core/platform:ctstring", + "//tensorflow/tsl/c:headers", ] + if_tensorrt([ "//tensorflow/compiler/tf2tensorrt:headers", ]), @@ -60,7 +62,8 @@ filegroup( "*test*", ], ) + [ - "//tensorflow/core/platform:ctstring", + "//tensorflow/tsl/c:srcs", + "//tensorflow/tsl/platform:ctstring", "//tensorflow/cc:srcs_no_runtime", "//tensorflow/core/distributed_runtime:server_lib.h", ], @@ -79,6 +82,7 @@ cc_library( "tf_buffer_internal.h", "tf_status_internal.h", "tf_tensor_internal.h", + "//tensorflow/tsl/c:tsl_status_internal_headers", ], visibility = [ "//tensorflow/core:__pkg__", @@ -86,6 +90,22 @@ cc_library( ], ) +cc_library( + name = "c_api_headers", + hdrs = [ + "c_api.h", + "c_api_macros.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":tf_attrtype", + ":tf_buffer", + ":tf_datatype", + ":tf_status_headers", + ":tf_tstring", + ], +) + tf_cuda_library( name = "c_api_internal", hdrs = [ @@ -184,6 +204,7 @@ tf_cuda_library( ":tf_tensor_internal", ":tf_tstring", "//tensorflow/core/platform:tstring", + "//tensorflow/tsl/c:tsl_status", ] + select({ "//tensorflow:with_xla_support": [ "//tensorflow/compiler/tf2xla:xla_compiler", @@ -213,7 +234,7 @@ tf_cuda_library( ], copts = tf_copts(), visibility = [ - "//tensorflow/c:__subpackages__", + "//tensorflow:__subpackages__", "//tensorflow/python:__subpackages__", ], deps = [ @@ -273,6 +294,7 @@ tf_cuda_library( hdrs = [ "tf_status.h", "tf_status_internal.h", + "//tensorflow/tsl/c:tsl_status_internal_headers", ], visibility = [ "//tensorflow/c:__subpackages__", @@ -285,7 +307,11 @@ tf_cuda_library( "//tensorflow/compiler/mlir/tensorflow/c:__subpackages__", "//tensorflow/core/transforms:__subpackages__", ], - deps = select({ + deps = [ + "//tensorflow/tsl/platform:status", + "//tensorflow/tsl/c:tsl_status", + "//tensorflow/tsl/c:tsl_status_internal", + ] + select({ "//tensorflow:android": [ "//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs ], @@ -297,7 +323,10 @@ tf_cuda_library( filegroup( name = "tf_status_internal_headers", - srcs = ["tf_status_internal.h"], + srcs = [ + "tf_status_internal.h", + "//tensorflow/tsl/c:tsl_status_internal_headers", + ], visibility = [ "//tensorflow/python:__subpackages__", ], @@ -331,9 +360,11 @@ cc_library( name = "tf_status", srcs = ["tf_status.cc"], hdrs = ["tf_status.h"], + copts = tf_copts(), visibility = ["//visibility:public"], deps = [ ":tf_status_internal", + "//tensorflow/tsl/c:tsl_status", ] + select({ "//tensorflow:android": [ "//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs @@ -344,22 +375,13 @@ cc_library( }), ) -tf_cc_test( - name = "tf_status_test", - srcs = ["tf_status_test.cc"], - deps 
= [ - ":tf_status", - ":tf_status_internal", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - cc_library( name = "tf_status_headers", hdrs = ["tf_status.h"], visibility = ["//visibility:public"], + deps = [ + "//tensorflow/tsl/c:tsl_status", + ], ) cc_library( @@ -374,10 +396,12 @@ cc_library( "tf_tensor.h", "tf_tstring.h", ], + copts = tf_copts(), visibility = ["//visibility:public"], deps = [ "//tensorflow/core/platform:status", "//tensorflow/core/platform:tstring", + "//tensorflow/tsl/c:tsl_status", ], ) @@ -406,6 +430,7 @@ cc_library( name = "tf_datatype", srcs = ["tf_datatype.cc"], hdrs = ["tf_datatype.h"], + copts = tf_copts(), visibility = ["//visibility:public"], deps = select({ "//tensorflow:android": [ @@ -422,6 +447,7 @@ cc_library( name = "tf_tensor", srcs = ["tf_tensor.cc"], hdrs = ["tf_tensor.h"], + copts = tf_copts(), visibility = ["//visibility:public"], deps = [ ":c_api_macros", @@ -475,6 +501,7 @@ cc_library( hdrs = [ "tf_buffer.h", ], + copts = tf_copts(), visibility = ["//visibility:public"], deps = [ ":tf_buffer_internal", @@ -570,24 +597,9 @@ tf_cuda_library( deps = [ ":tf_status", ":tf_status_internal", - ] + select({ - "//tensorflow:android": [ - "//tensorflow/core:portable_tensorflow_lib_lite", # TODO(annarev): exclude runtime srcs - ], - "//conditions:default": [ - "//tensorflow/core:lib", - ], - }), -) - -tf_cc_test( - name = "tf_status_helper_test", - srcs = ["tf_status_helper_test.cc"], - deps = [ - ":tf_status_helper", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", + "//tensorflow/core/platform:errors", + "//tensorflow/core/platform:status", + "//tensorflow/tsl/c:tsl_status_helper", ], ) @@ -804,7 +816,6 @@ tf_cuda_cc_test( ], # We must ensure that the dependencies can be dynamically linked since # the shared library must be able to use core:framework. - # linkstatic = tf_kernel_tests_linkstatic(), deps = [ ":c_api", ":c_api_internal", @@ -849,6 +860,7 @@ tf_cc_test( data = [ "testdata/tf_record", "//tensorflow/c/experimental/stream_executor/test:test_pluggable_device.so", + "//tensorflow/core/common_runtime/next_pluggable_device/c:test_next_pluggable_device_plugin.so", ], extra_copts = if_google(["-DTENSORFLOW_NO_SHARED_OBJECTS=1"]), linkopts = select({ @@ -861,7 +873,6 @@ tf_cc_test( ], # We must ensure that the dependencies can be dynamically linked since # the shared library must be able to use core:framework. - # linkstatic = tf_kernel_tests_linkstatic(), deps = [ ":c_api", ":c_api_experimental", @@ -934,7 +945,6 @@ tf_cuda_cc_test( tags = ["noasan"], # We must ensure that the dependencies can be dynamically linked since # the shared library must be able to use core:framework. - # linkstatic = tf_kernel_tests_linkstatic(), deps = [ ":c_api", ":env", @@ -955,7 +965,6 @@ tf_cuda_cc_test( tags = ["no_cuda_on_cpu_tap"], # We must ensure that the dependencies can be dynamically linked since # the shared library must be able to use core:framework. - # linkstatic = tf_kernel_tests_linkstatic(), deps = [ ":c_api", ":kernels", @@ -982,7 +991,6 @@ tf_cc_test( tags = ["noasan"], # We must ensure that the dependencies can be dynamically linked since # the shared library must be able to use core:framework. 
- # linkstatic = tf_kernel_tests_linkstatic(), deps = [ ":c_api", ":ops", diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 96e1268f62d..da62fc35bc0 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -117,7 +117,13 @@ const char* TF_Version() { return TF_VERSION_STRING; } // -------------------------------------------------------------------------- // -------------------------------------------------------------------------- -TF_SessionOptions* TF_NewSessionOptions() { return new TF_SessionOptions; } +TF_SessionOptions* TF_NewSessionOptions() { + TF_SessionOptions* out = new TF_SessionOptions; + // Disable optimizations for static graph to allow calls to Session::Extend. + out->options.config.mutable_experimental() + ->set_disable_optimize_for_static_graph(true); + return out; +} void TF_DeleteSessionOptions(TF_SessionOptions* opt) { delete opt; } void TF_SetTarget(TF_SessionOptions* options, const char* target) { @@ -129,6 +135,9 @@ void TF_SetConfig(TF_SessionOptions* options, const void* proto, if (!options->options.config.ParseFromArray(proto, proto_len)) { status->status = InvalidArgument("Unparseable ConfigProto"); } + // Disable optimizations for static graph to allow calls to Session::Extend. + options->options.config.mutable_experimental() + ->set_disable_optimize_for_static_graph(true); } void TF_TensorFromProto(const TF_Buffer* from, TF_Tensor* to, diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index 523f5c6e609..3a05e1e64db 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -758,3 +758,9 @@ TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename, void TF_DeletePluggableDeviceLibraryHandle(TF_Library* lib_handle) { delete lib_handle; } + +void TF_GraphRemoveFunction(TF_Graph* g, const char* func_name, + TF_Status* status) { + tensorflow::mutex_lock l(g->mu); + status->status = g->graph.mutable_flib_def()->RemoveFunction(func_name); +} diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index ac41bb5a9ca..aec1e875eaf 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -329,6 +329,12 @@ TF_CAPI_EXPORT extern TF_Library* TF_LoadPluggableDeviceLibrary( TF_CAPI_EXPORT extern void TF_DeletePluggableDeviceLibraryHandle( TF_Library* lib_handle); +// Removes `func_name` from `g`. If `func_name` is not in `g`, an error will be +// returned. +TF_CAPI_EXPORT extern void TF_GraphRemoveFunction(TF_Graph* g, + const char* func_name, + TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc index e47b7d0b0f7..63013c3fe46 100644 --- a/tensorflow/c/c_api_experimental_test.cc +++ b/tensorflow/c/c_api_experimental_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/c/c_api_experimental.h" #include "absl/types/optional.h" +#include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/c_test_util.h" #include "tensorflow/c/eager/c_api.h" @@ -255,5 +256,110 @@ TEST(CAPI_EXPERIMENTAL, LibraryPluggableDeviceLoadFunctions) { #endif // !defined(PLATFORM_WINDOWS) } +TEST(CAPI_EXPERIMENTAL, LibraryNextPluggableDeviceLoadFunctions) { + // TODO(penpornk): Enable this test on Windows. +#if !defined(PLATFORM_WINDOWS) +#if !defined(TENSORFLOW_NO_SHARED_OBJECTS) + // Load the library. 
+ TF_Status* status = TF_NewStatus(); + string lib_path = + tensorflow::GetDataDependencyFilepath(tensorflow::io::JoinPath( + "tensorflow", "core", "common_runtime", "next_pluggable_device", "c", + "test_next_pluggable_device_plugin.so")); + TF_Library* lib = TF_LoadPluggableDeviceLibrary(lib_path.c_str(), status); + TF_Code code = TF_GetCode(status); + string status_msg(TF_Message(status)); + TF_DeleteStatus(status); + ASSERT_EQ(TF_OK, code) << status_msg; + TF_DeletePluggableDeviceLibraryHandle(lib); +#endif // !defined(TENSORFLOW_NO_SHARED_OBJECTS) +#endif // !defined(PLATFORM_WINDOWS) +} + +void DefineFunction(const char* name, TF_Function** func, + const char* description = nullptr, + bool append_hash = false) { + std::unique_ptr func_graph( + TF_NewGraph(), TF_DeleteGraph); + std::unique_ptr s(TF_NewStatus(), + TF_DeleteStatus); + + TF_Operation* feed = Placeholder(func_graph.get(), s.get()); + TF_Operation* neg = Neg(feed, func_graph.get(), s.get()); + + TF_Output inputs[] = {{feed, 0}}; + TF_Output outputs[] = {{neg, 0}}; + *func = TF_GraphToFunction(func_graph.get(), name, append_hash, -1, + /*opers=*/nullptr, 1, inputs, 1, outputs, + /*output_names=*/nullptr, + /*opts=*/nullptr, description, s.get()); + ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get()); + ASSERT_NE(*func, nullptr); +} + +class CApiExperimentalFunctionTest : public ::testing::Test { + protected: + CApiExperimentalFunctionTest() + : s_(TF_NewStatus()), func_graph_(TF_NewGraph()), func_(nullptr) {} + + void SetUp() override {} + + ~CApiExperimentalFunctionTest() override { + TF_DeleteFunction(func_); + TF_DeleteGraph(func_graph_); + TF_DeleteStatus(s_); + } + + const char* func_name_ = "MyFunc"; + TF_Status* s_; + TF_Graph* func_graph_; + TF_Function* func_; +}; + +TEST_F(CApiExperimentalFunctionTest, GraphRemoveFunction) { + TF_Function* funcs[1]; + DefineFunction(func_name_, &func_); + + TF_GraphCopyFunction(func_graph_, func_, nullptr, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + EXPECT_EQ(TF_GraphNumFunctions(func_graph_), 1); + EXPECT_EQ(TF_GraphGetFunctions(func_graph_, funcs, 1, s_), 1); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_GraphRemoveFunction(func_graph_, func_name_, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + EXPECT_EQ(TF_GraphNumFunctions(func_graph_), 0); + EXPECT_EQ(TF_GraphGetFunctions(func_graph_, funcs, 1, s_), 0); + + TF_DeleteFunction(funcs[0]); +} + +TEST_F(CApiExperimentalFunctionTest, EmptyGraphRemoveNonExistentFunction) { + TF_GraphRemoveFunction(func_graph_, "wrong_name", s_); + EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); + EXPECT_EQ(string("Tried to remove non-existent function 'wrong_name'."), + string(TF_Message(s_))); +} + +TEST_F(CApiExperimentalFunctionTest, GraphRemoveNonExistentFunction) { + TF_Function* funcs[1]; + DefineFunction(func_name_, &func_); + + TF_GraphCopyFunction(func_graph_, func_, nullptr, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + EXPECT_EQ(TF_GraphNumFunctions(func_graph_), 1); + EXPECT_EQ(TF_GraphGetFunctions(func_graph_, funcs, 1, s_), 1); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_GraphRemoveFunction(func_graph_, "wrong_name", s_); + EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); + EXPECT_EQ(string("Tried to remove non-existent function 'wrong_name'."), + string(TF_Message(s_))); + TF_DeleteFunction(funcs[0]); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index 
537b61f3558..a13a1458553 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -185,7 +185,7 @@ TF_Function* TF_GraphToFunctionWithControlOutputs( if (control_output_names) { control_output_names_vec.reserve(ncontrol_outputs); for (int i = 0; i < ncontrol_outputs; ++i) { - control_output_names_vec.push_back(string(output_names[i])); + control_output_names_vec.push_back(string(control_output_names[i])); } } diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index 9722841691f..79d2841b724 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -211,6 +211,14 @@ bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) std::string getTF_OutputDebugString(TF_Output node); +// Set whether to propagate assigned device information when constructing a new +// Graph from a GraphDef. By default assigned device information is not copied +// and is re-computed by the runtime. +inline void TF_ImportGraphDefOptionsSetPropagateDeviceSpec( + TF_ImportGraphDefOptions* opts, unsigned char propagate_device_spec) { + opts->opts.propagate_device_spec = propagate_device_spec; +} + } // end namespace tensorflow #endif // TENSORFLOW_C_C_API_INTERNAL_H_ diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index c1aeb831bce..43dfe5155de 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -9,16 +9,13 @@ load( "tf_cuda_library", ) load("//tensorflow:tensorflow.default.bzl", "cc_header_only_library", "filegroup", "internal_tfrt_deps") -load( - "//tensorflow/core/platform:build_config.bzl", - "tf_kernel_tests_linkstatic", -) load( "//tensorflow/core/platform:build_config_root.bzl", "tf_cuda_tests_tags", ) package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) @@ -131,7 +128,7 @@ filegroup( "tfe_tensorhandle_internal.h", ], visibility = [ - "//tensorflow/core/function:__pkg__", + "//tensorflow/core/function/runtime_client:__pkg__", "//tensorflow/python:__subpackages__", ], ) @@ -256,7 +253,6 @@ tf_cuda_cc_test( "gradients_test.cc", ], args = ["--heap_check="], - linkstatic = tf_kernel_tests_linkstatic(), tags = tf_cuda_tests_tags() + ["nomac"], deps = [ ":abstract_context", @@ -293,7 +289,6 @@ tf_cuda_cc_test( "unified_api_test.cc", ], args = ["--heap_check="], - linkstatic = tf_kernel_tests_linkstatic(), tags = tf_cuda_tests_tags() + ["no_cuda_asan"], # b/173654156 deps = [ ":c_api_experimental", @@ -337,7 +332,6 @@ tf_cuda_cc_test( "gradient_checker_test.cc", ], args = ["--heap_check="], - linkstatic = tf_kernel_tests_linkstatic(), tags = tf_cuda_tests_tags() + [ "no_cuda_asan", # b/175330074 ], @@ -755,7 +749,10 @@ tf_cuda_cc_test( tags = [ "no_oss", # TODO(b/200848572) "no_windows", + # TODO(b/136478427): sanitizers report issues due to unclean exit. "noasan", # leaks gRPC server instances + "nomsan", # b/229991646: use of destructed memory due to unclean exit. + "notsan", # b/259602430: race on destructed mutex due to unclean exit. 
], deps = [ ":c_api", @@ -885,9 +882,9 @@ tf_cuda_library( }) + [ "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", "@com_google_absl//absl/container:flat_hash_map", "//tensorflow/c:tf_status_helper", - "//tensorflow/core/distributed_runtime/coordination:coordination_service_agent", "//tensorflow/core/distributed_runtime/coordination:coordination_service_error_util", "//tensorflow/core/distributed_runtime/eager:eager_client", "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", @@ -900,6 +897,7 @@ tf_cuda_library( "//tensorflow/core/distributed_runtime:server_lib", "//tensorflow/core/distributed_runtime:worker_env", "//tensorflow/core:gpu_runtime", + "//tensorflow/tsl/distributed_runtime/coordination:coordination_service_agent", ], alwayslink = 1, ) @@ -911,7 +909,6 @@ tf_cuda_cc_test( "c_api_experimental_test.cc", ], args = ["--heap_check="], - linkstatic = tf_kernel_tests_linkstatic(), tags = tf_cuda_tests_tags() + ["nomac"], deps = [ ":c_api", @@ -934,7 +931,6 @@ tf_cuda_cc_test( "c_api_unified_experimental_test.cc", ], args = ["--heap_check="], - linkstatic = tf_kernel_tests_linkstatic(), tags = tf_cuda_tests_tags() + ["nomac"], deps = [ ":c_api", @@ -1015,7 +1011,7 @@ cc_library( name = "dlpack", srcs = ["dlpack.cc"], hdrs = ["dlpack.h"], - copts = [ + copts = tf_copts() + [ "-fexceptions", "-fno-strict-aliasing", ], diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 13a9c797235..e3199b204f6 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -68,7 +68,6 @@ limitations under the License. #if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE) && \ !defined(PLATFORM_FUCHSIA) #include "tensorflow/core/tfrt/eager/c_api_tfrt.h" -#include "tensorflow/core/tfrt/eager/c_api_tfrt_distributed_impl.h" #endif // PLATFORM_GOOGLE && !LIBTPU_ON_GCE && !PLATFORM_FUCHSIA #if !defined(IS_MOBILE_PLATFORM) @@ -123,12 +122,7 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { opts->session_options.options, static_cast( opts->device_placement_policy), - opts->async, opts->use_tfrt_distributed_runtime); -#if !defined(IS_MOBILE_PLATFORM) - tfrt_context->SetDistributedManager( - tfrt::tf::CreateDistributedManagerContext( - tfrt_context->GetCoreRuntime()->GetHostContext())); -#endif // !IS_MOBILE_PLATFORM + opts->async); return tensorflow::wrap(tfrt_context); #else status->status = tensorflow::errors::Unimplemented("TFRT is not supported"); diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index 149c6062d23..7eb22ed2c7c 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "absl/strings/match.h" +#include "absl/time/time.h" #include "tensorflow/c/c_api.h" #include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/tfe_context_internal.h" @@ -27,7 +28,6 @@ limitations under the License. #include "tensorflow/core/common_runtime/composite_device.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h" -#include "tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h" #include "tensorflow/core/distributed_runtime/coordination/coordination_service_error_util.h" #include "tensorflow/core/lib/monitoring/counter.h" #include "tensorflow/core/lib/monitoring/gauge.h" @@ -36,6 +36,7 @@ limitations under the License. 
#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/strcat.h" +#include "tensorflow/tsl/distributed_runtime/coordination/coordination_service_agent.h" using tensorflow::string; @@ -539,11 +540,6 @@ void TFE_ContextOptionsSetTfrt(TFE_ContextOptions* options, bool use_tfrt) { options->use_tfrt = use_tfrt; } -void TFE_ContextOptionsSetTfrtDistributedRuntime( - TFE_ContextOptions* options, bool use_tfrt_distributed_runtime) { - options->use_tfrt_distributed_runtime = use_tfrt_distributed_runtime; -} - TFE_CancellationManager* TFE_NewCancellationManager() { return tensorflow::wrap(new tensorflow::CancellationManager); } @@ -571,8 +567,10 @@ void TFE_OpSetCancellationManager(TFE_Op* op, status->status = ::tensorflow::OkStatus(); } -TFE_Executor* TFE_NewExecutor(bool is_async, bool enable_streaming_enqueue) { - return new TFE_Executor(is_async, enable_streaming_enqueue); +TFE_Executor* TFE_NewExecutor(bool is_async, bool enable_streaming_enqueue, + int in_flight_nodes_limit) { + return new TFE_Executor(is_async, enable_streaming_enqueue, + in_flight_nodes_limit); } void TFE_DeleteExecutor(TFE_Executor* executor) { delete executor; } @@ -785,7 +783,7 @@ void TFE_InsertConfigKeyValue(TFE_Context* ctx, const char* key, const char* value, TF_Status* status) { tensorflow::ImmediateExecutionDistributedManager* dist_mgr = tensorflow::unwrap(ctx)->GetDistributedManager(); - tensorflow::CoordinationServiceAgent* coord_agent = + tsl::CoordinationServiceAgent* coord_agent = dist_mgr->GetCoordinationServiceAgent(); if (coord_agent == nullptr) { status->status = tensorflow::errors::FailedPrecondition( @@ -799,7 +797,7 @@ void TFE_GetConfigKeyValue(TFE_Context* ctx, const char* key, TF_Buffer* value_buf, TF_Status* status) { tensorflow::ImmediateExecutionDistributedManager* dist_mgr = tensorflow::unwrap(ctx)->GetDistributedManager(); - tensorflow::CoordinationServiceAgent* coord_agent = + tsl::CoordinationServiceAgent* coord_agent = dist_mgr->GetCoordinationServiceAgent(); if (coord_agent == nullptr) { status->status = tensorflow::errors::FailedPrecondition( @@ -824,7 +822,7 @@ void TFE_DeleteConfigKeyValue(TFE_Context* ctx, const char* key, TF_Status* status) { tensorflow::ImmediateExecutionDistributedManager* dist_mgr = tensorflow::unwrap(ctx)->GetDistributedManager(); - tensorflow::CoordinationServiceAgent* coord_agent = + tsl::CoordinationServiceAgent* coord_agent = dist_mgr->GetCoordinationServiceAgent(); if (coord_agent == nullptr) { status->status = tensorflow::errors::FailedPrecondition( @@ -838,7 +836,7 @@ void TFE_ReportErrorToCluster(TFE_Context* ctx, int error_code, const char* error_message, TF_Status* status) { tensorflow::ImmediateExecutionDistributedManager* dist_mgr = tensorflow::unwrap(ctx)->GetDistributedManager(); - tensorflow::CoordinationServiceAgent* coord_agent = + tsl::CoordinationServiceAgent* coord_agent = dist_mgr->GetCoordinationServiceAgent(); if (coord_agent == nullptr) { status->status = tensorflow::errors::FailedPrecondition( @@ -854,7 +852,7 @@ void TFE_GetTaskStates(TFE_Context* ctx, const TF_Buffer& tasks, void* states, TF_Status* status) { tensorflow::ImmediateExecutionDistributedManager* dist_mgr = tensorflow::unwrap(ctx)->GetDistributedManager(); - tensorflow::CoordinationServiceAgent* coord_agent = + tsl::CoordinationServiceAgent* coord_agent = dist_mgr->GetCoordinationServiceAgent(); if (coord_agent == nullptr) { status->status = tensorflow::errors::FailedPrecondition( @@ -890,3 +888,18 @@ void 
TFE_GetTaskStates(TFE_Context* ctx, const TF_Buffer& tasks, void* states, } status->status = tensorflow::OkStatus(); } + +void TFE_WaitAtBarrier(TFE_Context* ctx, const char* barrier_id, + int64_t barrier_timeout_in_ms, TF_Status* status) { + tensorflow::ImmediateExecutionDistributedManager* dist_mgr = + tensorflow::unwrap(ctx)->GetDistributedManager(); + tsl::CoordinationServiceAgent* coord_agent = + dist_mgr->GetCoordinationServiceAgent(); + if (coord_agent == nullptr) { + status->status = tensorflow::errors::FailedPrecondition( + "Coordination service is not enabled."); + return; + } + status->status = coord_agent->WaitAtBarrier( + barrier_id, absl::Milliseconds(barrier_timeout_in_ms), {}); +} diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h index 704a093fbab..95d833f6f47 100644 --- a/tensorflow/c/eager/c_api_experimental.h +++ b/tensorflow/c/eager/c_api_experimental.h @@ -294,10 +294,6 @@ TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2( TF_CAPI_EXPORT extern void TFE_ContextOptionsSetTfrt(TFE_ContextOptions*, bool use_tfrt); -// Sets whether to use TFRT distributed runtime -TF_CAPI_EXPORT extern void TFE_ContextOptionsSetTfrtDistributedRuntime( - TFE_ContextOptions* options, bool use_tfrt_distributed_runtime); - // Returns the context_id from the EagerContext which is used by the // EagerService to maintain consistency between client and worker. The // context_id is initialized with a dummy value and is later set when the worker @@ -333,8 +329,16 @@ typedef struct TFE_Executor TFE_Executor; // Creates a new eager Executor. Nodes in one executor are guaranteed to be // executed in sequence. Assigning nodes to different executors allows executing // nodes in parallel. +// in_flight_nodes_limit: when is_async is true, this value controls the +// maximum number of in flight async nodes. Enqueuing of additional async ops +// after the limit is reached blocks until some inflight nodes finishes. +// The effect is bounding the memory held by inflight TensorHandles that are +// referenced by the inflight nodes. +// A recommended value has not been established. +// A value of 0 removes the limit, which is the behavior of TensorFlow 2.11. +// When is_async is false, the value is ignored. TF_CAPI_EXPORT extern TFE_Executor* TFE_NewExecutor( - bool is_async, bool enable_streaming_enqueue); + bool is_async, bool enable_streaming_enqueue, int in_flight_nodes_limit); // Deletes the eager Executor without waiting for enqueued nodes. 
Please call // TFE_ExecutorWaitForAllPendingNodes before calling this API if you want to @@ -724,6 +728,11 @@ TF_CAPI_EXPORT extern void TFE_GetTaskStates(TFE_Context* ctx, const TF_Buffer& tasks, void* states, TF_Status* status); +TF_CAPI_EXPORT extern void TFE_WaitAtBarrier(TFE_Context* ctx, + const char* barrier_id, + int64_t barrier_timeout_in_ms, + TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/eager/c_api_experimental_test.cc b/tensorflow/c/eager/c_api_experimental_test.cc index f900de59060..68dbafc4d2a 100644 --- a/tensorflow/c/eager/c_api_experimental_test.cc +++ b/tensorflow/c/eager/c_api_experimental_test.cc @@ -220,7 +220,8 @@ TEST(CAPI, ExecutorContextDestructionOrder) { ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); TFE_DeleteContextOptions(opts); TFE_Executor* executor = TFE_NewExecutor( - /*is_async=*/false, /*enable_streaming_enqueue=*/true); + /*is_async=*/false, /*enable_streaming_enqueue=*/true, + /*in_flight_nodes_limit=*/0); TFE_ContextSetExecutorForThread(ctx, executor); TFE_DeleteContext(ctx); @@ -233,7 +234,8 @@ TEST(CAPI, ExecutorContextDestructionOrder) { ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); TFE_DeleteContextOptions(opts); TFE_Executor* executor = TFE_NewExecutor( - /*is_async=*/false, /*enable_streaming_enqueue=*/true); + /*is_async=*/false, /*enable_streaming_enqueue=*/true, + /*in_flight_nodes_limit=*/0); TFE_ContextSetExecutorForThread(ctx, executor); TFE_DeleteExecutor(executor); @@ -275,7 +277,8 @@ TEST(CAPI, Function_ident_CPU) { for (bool async : {false, true, false}) { TFE_Executor* old_executor = TFE_ContextGetExecutorForThread(ctx); TFE_Executor* executor = TFE_NewExecutor( - /*is_async=*/async, /*enable_streaming_enqueue=*/true); + /*is_async=*/async, /*enable_streaming_enqueue=*/true, + /*in_flight_nodes_limit=*/0); TFE_ContextSetExecutorForThread(ctx, executor); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); @@ -327,7 +330,8 @@ void Executor_MatMul_CPU(bool async) { TFE_Executor* old_executor = TFE_ContextGetExecutorForThread(ctx); TFE_Executor* executor = TFE_NewExecutor( - /*is_async=*/async, /*enable_streaming_enqueue=*/true); + /*is_async=*/async, /*enable_streaming_enqueue=*/true, + /*in_flight_nodes_limit=*/0); TFE_ContextSetExecutorForThread(ctx, executor); TFE_TensorHandle* m = TestMatrixTensorHandle(ctx); diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 8bec998681e..eff96826822 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -34,11 +34,6 @@ struct TFE_ContextOptions { TFE_DEVICE_PLACEMENT_SILENT}; // If true, use TFRT backend bool use_tfrt = false; - // This option is effective only when use_tfrt is true. If true, TFRT will use - // native TFRT distributed runtime. Otherwise, TFRT will use current runtime's - // distributed runtime. Note that TFRT distributed runtime is in development - // and not functionally complete. - bool use_tfrt_distributed_runtime = false; // Whether to run elementary eager ops wrapped in a call op. bool run_eager_op_as_function = false; // Whether to rewrite jit_compile functions. 
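
The extended TFE_NewExecutor signature above adds an in_flight_nodes_limit argument that bounds how many async nodes may be pending at once; enqueues past the limit block until earlier nodes finish, which in turn bounds the memory held by in-flight TensorHandles. A minimal usage sketch, assuming an otherwise default eager setup and an arbitrary limit of 4 (error checks elided; a value of 0 keeps the previous unbounded behavior):

  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TFE_DeleteContextOptions(opts);
  // Async executor whose queue is capped at 4 in-flight nodes.
  TFE_Executor* executor = TFE_NewExecutor(
      /*is_async=*/true, /*enable_streaming_enqueue=*/true,
      /*in_flight_nodes_limit=*/4);
  TFE_ContextSetExecutorForThread(ctx, executor);
  // ... enqueue eager ops on this thread; enqueues past the limit wait ...
  TFE_ExecutorWaitForAllPendingNodes(executor, status);
  TFE_DeleteExecutor(executor);
  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
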
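TFE_WaitAtBarrier, also declared above, is intended to block the calling task until the other tasks in the cluster reach the barrier with the same id or the timeout expires; per the implementation it fails with a FailedPrecondition error when the coordination service is not enabled. A short sketch, assuming ctx is an existing TFE_Context* for a cluster configured with the coordination service, and using an illustrative barrier id and timeout:

  TF_Status* status = TF_NewStatus();
  // Wait up to 60 seconds for all tasks to reach the "init_done" barrier.
  TFE_WaitAtBarrier(ctx, "init_done", /*barrier_timeout_in_ms=*/60000, status);
  // TF_GetCode(status) is TF_FAILED_PRECONDITION if the coordination service
  // is not enabled; otherwise it reflects the barrier result.
  TF_DeleteStatus(status);
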
diff --git a/tensorflow/c/eager/immediate_execution_distributed_manager.h b/tensorflow/c/eager/immediate_execution_distributed_manager.h index 9efb2fa85d6..4f96992e739 100644 --- a/tensorflow/c/eager/immediate_execution_distributed_manager.h +++ b/tensorflow/c/eager/immediate_execution_distributed_manager.h @@ -20,8 +20,11 @@ limitations under the License. #include "tensorflow/core/platform/status.h" -namespace tensorflow { +namespace tsl { class CoordinationServiceAgent; +} + +namespace tensorflow { class ImmediateExecutionContext; class ServerDef; class WorkerEnv; @@ -32,19 +35,19 @@ class ImmediateExecutionDistributedManager { virtual ~ImmediateExecutionDistributedManager() {} // Set up distributed execution environment on local and remote tasks. - // When `reset_context` is true, initialize new cluster context state based on - // cluster configurations provided in `server_def`; otherwise, update existing - // context state with the provided `server_def`. - // Contexts created on remote tasks will be considered stale and garbage - // collected after `keep_alive_secs` of inactivity. + // When `reset_context` is true, initialize new cluster context state based + // on cluster configurations provided in `server_def`; otherwise, update + // existing context state with the provided `server_def`. Contexts created + // on remote tasks will be considered stale and garbage collected after + // `keep_alive_secs` of inactivity. virtual Status SetOrUpdateServerDef(const ServerDef& server_def, bool reset_context, int keep_alive_secs) = 0; - // Set up a multi-client distributed execution environment. Must be called on - // all tasks in the cluster. - // This call internally coordinates with other tasks to initialize the eager - // context and TF server for multi-client execution. + // Set up a multi-client distributed execution environment. Must be called + // on all tasks in the cluster. This call internally coordinates with other + // tasks to initialize the eager context and TF server for multi-client + // execution. virtual Status EnableCollectiveOps(const ServerDef& server_def) = 0; // Check if the remote task is alive. @@ -52,7 +55,7 @@ class ImmediateExecutionDistributedManager { bool* is_alive) = 0; // Get pointer to the coordination service agent instance. - virtual CoordinationServiceAgent* GetCoordinationServiceAgent() = 0; + virtual tsl::CoordinationServiceAgent* GetCoordinationServiceAgent() = 0; }; } // namespace tensorflow diff --git a/tensorflow/c/eager/parallel_device/BUILD b/tensorflow/c/eager/parallel_device/BUILD index e528a7070ab..0de029ff449 100644 --- a/tensorflow/c/eager/parallel_device/BUILD +++ b/tensorflow/c/eager/parallel_device/BUILD @@ -6,6 +6,7 @@ load( ) package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib.cc b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc index 727c1f83396..fd054c9af9a 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_lib.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc @@ -66,7 +66,8 @@ using ExecutorPtr = std::unique_ptr; class DeviceThread { public: // Starts a background thread waiting for `StartExecute`. 
- explicit DeviceThread(const std::string& device, const bool is_async) + explicit DeviceThread(const std::string& device, const bool is_async, + const int in_flight_nodes_limit) : status_(TF_NewStatus()), // If the context's default exector is set to async, re-using that in // each thread would cause collectives to deadlock. For consistency we @@ -75,7 +76,9 @@ class DeviceThread { // TODO(allenl): We should have an async API that works with the // parallel device. device_(device), - executor_(TFE_NewExecutor(is_async, /*enable_streaming_enqueue=*/true)), + executor_( + TFE_NewExecutor(is_async, /*enable_streaming_enqueue=*/true, + /*in_flight_nodes_limit=*/in_flight_nodes_limit)), op_(nullptr), thread_(tensorflow::Env::Default()->StartThread( tensorflow::ThreadOptions(), "parallel_device_execute", @@ -282,13 +285,13 @@ void DeviceThread::Execute(TFE_Context* context, const char* operation_name, } ParallelDevice::ParallelDevice(const std::vector& devices, - const bool is_async) + bool is_async, int in_flight_nodes_limit) : underlying_devices_(devices), default_cancellation_manager_(absl::make_unique()) { device_threads_.reserve(devices.size()); for (int device_index = 0; device_index < devices.size(); ++device_index) { - device_threads_.emplace_back( - new DeviceThread(devices[device_index].c_str(), is_async)); + device_threads_.emplace_back(new DeviceThread( + devices[device_index].c_str(), is_async, in_flight_nodes_limit)); } } @@ -365,6 +368,26 @@ void ParallelDevice::StartExecute(TFE_Context* context, } } +void ParallelDevice::StartExecute( + TFE_Context* context, const std::vector& inputs, + const char* operation_name, const TFE_OpAttrs* attributes, + int expected_max_outputs, CancellationManager& cancellation_manager, + absl::optional step_id) const { + for (int device_index = 0; device_index < underlying_devices_.size(); + ++device_index) { + DeviceThread* device_thread = device_threads_[device_index].get(); + std::vector device_inputs; + device_inputs.reserve(inputs.size()); + for (int input_index = 0; input_index < inputs.size(); ++input_index) { + // Parallel tensors are divided between operations by device. + device_inputs.push_back(inputs[input_index][device_index].get()); + } + device_thread->StartExecute( + context, operation_name, std::move(device_inputs), attributes, + expected_max_outputs, cancellation_manager, step_id); + } +} + void ParallelDevice::AsyncWait(TFE_Context* context, TF_Status* status) const { StatusPtr first_bad_status(nullptr); diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib.h b/tensorflow/c/eager/parallel_device/parallel_device_lib.h index 80f81dd47a4..01581f40e05 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_lib.h +++ b/tensorflow/c/eager/parallel_device/parallel_device_lib.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_ #include +#include #include #include @@ -44,6 +45,8 @@ class TensorHandleDeleter { } }; +// TODO(b/256016071): Replace this with `Safe_TFE_TensorHandlePtr` when +// `Safe_TFE_TensorHandlePtr` is marked to be compatible on non-prod env. using TensorHandlePtr = std::unique_ptr; class ParallelTensor; @@ -56,7 +59,7 @@ class ParallelDevice { // Eager async execution is only supported when remote eager is not in use // (b/157523095). 
explicit ParallelDevice(const std::vector& devices, - const bool is_async = false); + bool is_async = false, int in_flight_nodes_limit = 0); ~ParallelDevice(); @@ -118,12 +121,24 @@ class ParallelDevice { // // Set step_id to configure the step id used for rendezvous creation. step id // of value -1 is reserved for global rendezvous and should not be set here. + // + // This function is overloaded so that if the inputs are constructed from + // `TensorWithLayout` we can use the one with `TensorHandlePtr` but + // if the inputs are directly `ParallelTensor` (for example, in the case of + // custom device execution) we can use the one with `ParallelTensor`. void StartExecute(TFE_Context* context, const std::vector& inputs, const char* operation_name, const TFE_OpAttrs* attributes, int expected_max_outputs, CancellationManager& cancellation_manager, - absl::optional step_id = absl::nullopt) const; + std::optional step_id = std::nullopt) const; + + void StartExecute(TFE_Context* context, + const std::vector& inputs, + const char* operation_name, const TFE_OpAttrs* attributes, + int expected_max_outputs, + CancellationManager& cancellation_manager, + std::optional step_id = std::nullopt) const; // Blocks until the previous `StartExecute` has run `TFE_Execute` on each // device. If is_async=false (constructor argument) this means the ops have @@ -189,6 +204,7 @@ class ParallelTensor { size_t num_tensors() const { return tensors_.size(); } TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); } + const TensorHandlePtr* tensor_data() const { return tensors_.data(); } // If the `shape` argument to `FromTensorHandles` is specified, returns that. // diff --git a/tensorflow/c/eager/tfe_executor_internal.h b/tensorflow/c/eager/tfe_executor_internal.h index 081b139bd34..7f55532af56 100644 --- a/tensorflow/c/eager/tfe_executor_internal.h +++ b/tensorflow/c/eager/tfe_executor_internal.h @@ -20,9 +20,10 @@ limitations under the License. #include "tensorflow/core/common_runtime/eager/eager_executor.h" struct TFE_Executor { - explicit TFE_Executor(bool async, bool enable_streaming_enqueue) - : owned_executor( - new tensorflow::EagerExecutor(async, enable_streaming_enqueue)) {} + explicit TFE_Executor(bool async, bool enable_streaming_enqueue, + int in_flight_nodes_limit) + : owned_executor(new tensorflow::EagerExecutor( + async, enable_streaming_enqueue, in_flight_nodes_limit)) {} explicit TFE_Executor(tensorflow::EagerExecutor* executor) : owned_executor(nullptr), unowned_executor(executor) {} diff --git a/tensorflow/c/experimental/filesystem/BUILD b/tensorflow/c/experimental/filesystem/BUILD index 4d8ff231ce7..6c5c43fbb46 100644 --- a/tensorflow/c/experimental/filesystem/BUILD +++ b/tensorflow/c/experimental/filesystem/BUILD @@ -6,6 +6,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) @@ -42,6 +43,7 @@ cc_library( "//tensorflow/core/platform:env", "//tensorflow/core/platform:errors", "//tensorflow/core/platform:status", + "//tensorflow/tsl/platform:errors", ], ) diff --git a/tensorflow/c/experimental/filesystem/modular_filesystem.cc b/tensorflow/c/experimental/filesystem/modular_filesystem.cc index 32b06697d77..b47748374fe 100644 --- a/tensorflow/c/experimental/filesystem/modular_filesystem.cc +++ b/tensorflow/c/experimental/filesystem/modular_filesystem.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/file_system_helper.h" #include "tensorflow/core/util/ptr_util.h" +#include "tensorflow/tsl/platform/errors.h" // TODO(b/139060984): After all filesystems are converted, all calls to // methods from `FileSystem` will have to be replaced to calls to private @@ -561,8 +562,9 @@ Status RegisterFilesystemPlugin(const std::string& dso_path) { // Step 2: Load symbol for `TF_InitPlugin` void* dso_symbol; - TF_RETURN_IF_ERROR( - env->GetSymbolFromLibrary(dso_handle, "TF_InitPlugin", &dso_symbol)); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + env->GetSymbolFromLibrary(dso_handle, "TF_InitPlugin", &dso_symbol), + "Failed to load TF_InitPlugin symbol for DSO: ", dso_path); // Step 3: Call `TF_InitPlugin` TF_FilesystemPluginInfo info; diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD b/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD index 1d9bfc1a15f..bd2041b1d43 100644 --- a/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD @@ -4,6 +4,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object", "tf_cc_test") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) diff --git a/tensorflow/c/experimental/filesystem/plugins/posix/BUILD b/tensorflow/c/experimental/filesystem/plugins/posix/BUILD index 9d655fd43b5..90acb2bf389 100644 --- a/tensorflow/c/experimental/filesystem/plugins/posix/BUILD +++ b/tensorflow/c/experimental/filesystem/plugins/posix/BUILD @@ -4,6 +4,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:private"], licenses = ["notice"], ) diff --git a/tensorflow/c/experimental/filesystem/plugins/windows/BUILD b/tensorflow/c/experimental/filesystem/plugins/windows/BUILD index fb2f99f44ff..2ac57f6a731 100644 --- a/tensorflow/c/experimental/filesystem/plugins/windows/BUILD +++ b/tensorflow/c/experimental/filesystem/plugins/windows/BUILD @@ -4,6 +4,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) diff --git a/tensorflow/c/experimental/gradients/BUILD b/tensorflow/c/experimental/gradients/BUILD index 90a99b05e38..1788cbd6551 100644 --- a/tensorflow/c/experimental/gradients/BUILD +++ b/tensorflow/c/experimental/gradients/BUILD @@ -5,10 +5,6 @@ load( "if_libtpu", "tf_cuda_cc_test", ) -load( - "//tensorflow/core/platform:build_config.bzl", - "tf_kernel_tests_linkstatic", -) load( "//tensorflow/core/platform:build_config_root.bzl", "tf_cuda_tests_tags", @@ -16,6 +12,7 @@ load( # Library of gradient functions. package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) @@ -59,7 +56,7 @@ cc_library( "nn_grad.h", ], visibility = [ - "//tensorflow:internal", + "//visibility:private", # Only private by automation, not intent. Owner may accept CLs adding visibility. See go/scheuklappen#explicit-private. 
], deps = [ "//tensorflow/c/eager:abstract_tensor_handle", @@ -118,7 +115,6 @@ tf_cuda_cc_test( "custom_gradient_test.cc", ], args = ["--heap_check="], # TODO(b/174752220): Remove - linkstatic = tf_kernel_tests_linkstatic(), tags = tf_cuda_tests_tags(), deps = [ "//tensorflow/c:tf_status_helper", @@ -144,10 +140,7 @@ filegroup( "nn_grad.h", "not_differentiable.h", ], - visibility = [ - "//tensorflow/core:__pkg__", - "//tensorflow/python:__pkg__", - ], + visibility = ["//tensorflow/python:__pkg__"], ) cc_library( @@ -156,7 +149,7 @@ cc_library( srcs = ["grad_test_helper.cc"], hdrs = ["grad_test_helper.h"], visibility = [ - "//tensorflow:internal", + "//visibility:private", # Only private by automation, not intent. Owner may accept CLs adding visibility. See go/scheuklappen#explicit-private. ], deps = [ "//tensorflow/c/eager:gradient_checker", @@ -175,7 +168,6 @@ tf_cuda_cc_test( "nn_grad_test.cc", ], args = ["--heap_check="], # TODO(b/174752220): Remove - linkstatic = tf_kernel_tests_linkstatic(), tags = tf_cuda_tests_tags() + ["no_cuda_asan"], # b/173654156, deps = [ ":grad_test_helper", @@ -202,7 +194,6 @@ tf_cuda_cc_test( "math_grad_test.cc", ], args = ["--heap_check="], # TODO(b/174752220): Remove - linkstatic = tf_kernel_tests_linkstatic(), tags = tf_cuda_tests_tags() + ["no_cuda_asan"], # b/173654156, deps = [ ":grad_test_helper", @@ -229,7 +220,6 @@ tf_cuda_cc_test( "array_grad_test.cc", ], args = ["--heap_check="], # TODO(b/174752220): Remove - linkstatic = tf_kernel_tests_linkstatic(), tags = tf_cuda_tests_tags() + ["no_cuda_asan"], # b/173654156, deps = [ ":grad_test_helper", diff --git a/tensorflow/c/experimental/gradients/tape/BUILD b/tensorflow/c/experimental/gradients/tape/BUILD index 123f1908020..c29b7929d43 100644 --- a/tensorflow/c/experimental/gradients/tape/BUILD +++ b/tensorflow/c/experimental/gradients/tape/BUILD @@ -2,6 +2,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) diff --git a/tensorflow/c/experimental/grappler/BUILD b/tensorflow/c/experimental/grappler/BUILD index 68bdcdcda70..482ec08efed 100644 --- a/tensorflow/c/experimental/grappler/BUILD +++ b/tensorflow/c/experimental/grappler/BUILD @@ -8,6 +8,7 @@ load( ) package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) diff --git a/tensorflow/c/experimental/next_pluggable_device/BUILD b/tensorflow/c/experimental/next_pluggable_device/BUILD new file mode 100644 index 00000000000..890477266ea --- /dev/null +++ b/tensorflow/c/experimental/next_pluggable_device/BUILD @@ -0,0 +1,34 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) + +cc_library( + name = "c_api", + srcs = ["c_api.cc"], + hdrs = ["c_api.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/c:c_api", + "//tensorflow/c:kernels", + "//tensorflow/c:kernels_experimental_hdrs", + "//tensorflow/c:tf_status_helper", + "//tensorflow/c:tf_status_internal", + "//tensorflow/c:tf_tensor_internal", + "//tensorflow/compiler/jit:xla_launch_util", + "//tensorflow/compiler/xla/pjrt:pjrt_c_api_client", + "//tensorflow/compiler/xla/pjrt:pjrt_client", + "//tensorflow/compiler/xla/pjrt/c:pjrt_c_api_hdrs", + "//tensorflow/core:framework", + "//tensorflow/core/common_runtime/next_pluggable_device", + 
"//tensorflow/core/common_runtime/next_pluggable_device:plugin_resource", + "//tensorflow/core/platform:status", + "//tensorflow/core/tfrt/common:async_value_tensor", + "//tensorflow/core/tfrt/common:pjrt_util", + "//tensorflow/tsl/distributed_runtime/coordination:coordination_service_agent", + "//tensorflow/tsl/platform:errors", + "//tensorflow/tsl/platform:statusor", + ], +) diff --git a/tensorflow/c/experimental/next_pluggable_device/c_api.cc b/tensorflow/c/experimental/next_pluggable_device/c_api.cc new file mode 100644 index 00000000000..1ff6e091507 --- /dev/null +++ b/tensorflow/c/experimental/next_pluggable_device/c_api.cc @@ -0,0 +1,333 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/next_pluggable_device/c_api.h" + +#include +#include +#include +#include +#include + +#include "tensorflow/c/kernels_experimental.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/c/tf_status_internal.h" +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/c/tf_tensor_internal.h" +#include "tensorflow/compiler/jit/xla_launch_util.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_c_api_client.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device.h" +#include "tensorflow/core/common_runtime/next_pluggable_device/plugin_resource.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/tfrt/common/async_value_tensor.h" +#include "tensorflow/core/tfrt/common/pjrt_util.h" +#include "tensorflow/tsl/distributed_runtime/coordination/coordination_service_agent.h" +#include "tensorflow/tsl/platform/errors.h" +#include "tensorflow/tsl/platform/statusor.h" + +TF_Device* TF_GetDevice(TF_OpKernelContext* ctx) { + auto* cc_ctx = reinterpret_cast(ctx); + return reinterpret_cast(cc_ctx->device()); +} + +size_t TF_GetDeviceOrdinal(TF_Device* device) { + // TODO(chuanhao): make GetDeviceOrdinal a virtual member function in the base + // device class, instead of casting to `NextPluggableDevice`. 
+ auto cc_device = reinterpret_cast(device); + return cc_device->GetDeviceOrdinal(); +} + +// -------------------------- Resource --------------------------------------- +void TF_CreatePluginResource(TF_OpKernelContext* ctx, + const char* container_name, + const char* plugin_resource_name, + void* plugin_resource, void (*delete_func)(void*), + TF_Status* status) { + auto* cc_ctx = reinterpret_cast(ctx); + tensorflow::PluginResource* cc_resource_ptr = new tensorflow::PluginResource( + plugin_resource, plugin_resource_name, delete_func); + auto cc_status = + cc_ctx->resource_manager()->Create( + container_name, plugin_resource_name, cc_resource_ptr); + Set_TF_Status_from_Status(status, cc_status); +} + +void TF_LookupOrCreatePluginResource( + TF_OpKernelContext* ctx, const char* container_name, + const char* plugin_resource_name, void** result_plugin_resource, + void* (*create_func)(void*), void* create_func_args, + void (*delete_func)(void*), TF_Status* status) { + auto* cc_ctx = reinterpret_cast(ctx); + auto* resource_mgr = cc_ctx->resource_manager(); + tensorflow::core::RefCountPtr + tf_plugin_resource_ptr; + tensorflow::PluginResource* tf_plugin_resource = nullptr; + + auto cc_status = resource_mgr->LookupOrCreate( + container_name, plugin_resource_name, &tf_plugin_resource, + [plugin_resource_name, create_func, create_func_args, + delete_func](tensorflow::PluginResource** new_resource) { + void* opaque_plugin_resource = create_func(create_func_args); + *new_resource = new tensorflow::PluginResource( + opaque_plugin_resource, plugin_resource_name, delete_func); + return tensorflow::OkStatus(); + }); + + if (cc_status.ok()) { + tf_plugin_resource_ptr.reset(tf_plugin_resource); + *result_plugin_resource = tf_plugin_resource_ptr->GetOpaquePluginResource(); + } else { + *result_plugin_resource = nullptr; + } + Set_TF_Status_from_Status(status, cc_status); +} + +// ------------------------- VariableInfo ------------------------------------ +struct TF_VariableInfo { + TF_VariableInfo() = delete; + // TF_VariableInfo is constructed here by TensorFlow, and will be passed to + // plugin as a opaque pointer. Plugin will need to call C APIs below to + // operate on TF_VaribleInfo (such as allocate temp tensor for the `var` held + // by the underlying tensorflow::VariableInfo. 
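// Typical plugin-side lifecycle for the variable helpers below (sketch):
//   TF_CreateVariableInfoFromContext -> TF_LockVariableInfos ->
//   TF_AllocateTempForVariableInfo -> TF_GetTensorFromVariableInfo ->
//   TF_DeleteVariableInfo.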
+ TF_VariableInfo(int index, const std::string& name, tensorflow::Var* var) { + var_info = tensorflow::VariableInfo{index, name, var}; + } + + tensorflow::VariableInfo var_info{0, "", nullptr}; +}; + +TF_VariableInfo* TF_CreateVariableInfoFromContext(TF_OpKernelContext* ctx, + int index, + TF_Status* status) { + auto* cc_ctx = reinterpret_cast(ctx); + const tensorflow::Tensor& arg_tensor = cc_ctx->input(index); + tsl::Status cc_status; + if (arg_tensor.dtype() != tensorflow::DT_RESOURCE) { + cc_status = tsl::errors::InvalidArgument( + "Trying to obtain resource handle from Input[", index, + "], which is not type DT_RESOURCE."); + Set_TF_Status_from_Status(status, cc_status); + return nullptr; + } + const tensorflow::ResourceHandle& handle = + arg_tensor.flat()(0); + tensorflow::Var* variable; + cc_status = tensorflow::LookupResource(cc_ctx, handle, &variable); + return new TF_VariableInfo(index, handle.name(), variable); +} + +void TF_LockVariableInfos(TF_VariableInfo** vars, int num_vars, + TF_Status* status) { + std::vector variable_ptrs; + variable_ptrs.reserve(num_vars); + for (int i = 0; i < num_vars; ++i) { + variable_ptrs.push_back(&(vars[i]->var_info)); + } + tsl::Status cc_status = LockVariables(absl::MakeSpan(variable_ptrs)); + tsl::Set_TF_Status_from_Status(status, cc_status); +} + +void TF_AllocateTempForVariableInfo(TF_OpKernelContext* ctx, + TF_VariableInfo* var_info, + TF_Status* status) { + auto* cc_ctx = reinterpret_cast(ctx); + tsl::Status cc_status; + if (var_info == nullptr) { + cc_status = tsl::errors::InvalidArgument("TF_VariableInfo is NULL."); + Set_TF_Status_from_Status(status, cc_status); + return; + } + if (var_info->var_info.var() == nullptr) { + cc_status = tsl::errors::InvalidArgument( + "VariableInfo does not track a resource variable."); + Set_TF_Status_from_Status(status, cc_status); + return; + } + + cc_status = cc_ctx->allocate_temp(var_info->var_info.var()->tensor()->dtype(), + var_info->var_info.var()->tensor()->shape(), + var_info->var_info.var()->tensor()); + Set_TF_Status_from_Status(status, cc_status); +} + +TF_Tensor* TF_GetTensorFromVariableInfo(TF_VariableInfo* var_info, + TF_Status* status) { + tsl::Status cc_status; + if (var_info == nullptr) { + cc_status = tsl::errors::InvalidArgument("TF_VariableInfo is NULL."); + Set_TF_Status_from_Status(status, cc_status); + return nullptr; + } + if (var_info->var_info.var() == nullptr) { + cc_status = tsl::errors::InvalidArgument( + "VariableInfo does not track a resource variable."); + Set_TF_Status_from_Status(status, cc_status); + return nullptr; + } + + tensorflow::Tensor* tensor = var_info->var_info.var()->tensor(); + TF_Tensor* result_tensor = + tensorflow::TF_TensorFromTensor(*tensor, &cc_status); + Set_TF_Status_from_Status(status, cc_status); + return result_tensor; +} + +void TF_DeleteVariableInfo(TF_VariableInfo* var_info) { + if (var_info != nullptr) { + delete var_info; + } +} + +// --------------------- Coordination service -------------------------------- +TF_CoordinationServiceAgent* TF_GetCoordinationServiceAgent( + TF_OpKernelContext* ctx) { + auto* cc_ctx = reinterpret_cast(ctx); + return reinterpret_cast( + cc_ctx->coordination_service_agent()); +} + +bool TF_CoordinationServiceIsInitialized(TF_CoordinationServiceAgent* agent) { + if (agent == nullptr) return false; + auto* cc_agent = reinterpret_cast(agent); + return cc_agent->IsInitialized(); +} + +void TF_CoordinationServiceInsertKeyValue(const char* key, const char* value, + TF_CoordinationServiceAgent* agent, + TF_Status* status) 
{ + auto* cc_agent = reinterpret_cast(agent); + tsl::Status cc_status = cc_agent->InsertKeyValue(key, value); + tsl::Set_TF_Status_from_Status(status, cc_status); +} + +TF_Buffer* TF_CoordinationServiceGetKeyValue(const char* key, + TF_CoordinationServiceAgent* agent, + TF_Status* status) { + auto* cc_agent = reinterpret_cast(agent); + auto value = cc_agent->GetKeyValue(key); + tsl::Set_TF_Status_from_Status(status, value.status()); + if (!value.ok()) { + return nullptr; + } + // Caller is responsible to call `TF_DeleteBuffer` to release the buffer. + TF_Buffer* result = TF_NewBuffer(); + const std::string& value_str = *value; + void* data = malloc(value_str.length()); + value_str.copy(static_cast(data), value_str.length(), 0); + result->data = data; + result->length = value_str.length(); + result->data_deallocator = [](void* data, size_t length) { free(data); }; + return result; +} + +void TF_CoordinationServiceDeleteKeyValue(const char* key, + TF_CoordinationServiceAgent* agent, + TF_Status* status) { + auto* cc_agent = reinterpret_cast(agent); + tsl::Status cc_status = cc_agent->DeleteKeyValue(key); + tsl::Set_TF_Status_from_Status(status, cc_status); +} + +// ---------------------------- PJRT ----------------------------------------- +void TF_CreateAndSetPjRtCApiClient(const char* device_type, TF_Status* status) { + tsl::StatusOr> pjrt_client = + xla::GetCApiClient(device_type); + if (!pjrt_client.ok()) { + tensorflow::Set_TF_Status_from_Status(status, pjrt_client.status()); + return; + } + + tsl::Status s = tensorflow::SetPjRtClientInTFGlobalResourceManager( + tensorflow::DeviceType(device_type), std::move(*pjrt_client)); + tsl::Set_TF_Status_from_Status(status, s); +} + +PJRT_Client* TF_GetPjRtCClient(const char* device_type, TF_Status* status) { + tsl::StatusOr pjrt_client = + tensorflow::GetOrCreatePjRtClient(tensorflow::DeviceType(device_type)); + if (!pjrt_client.ok()) { + tensorflow::Set_TF_Status_from_Status(status, pjrt_client.status()); + return nullptr; + } + auto* pjrt_c_api_client = + tensorflow::down_cast(*pjrt_client); + if (pjrt_c_api_client == nullptr) { + tensorflow::Set_TF_Status_from_Status( + status, tsl::errors::Internal("PjRtClient for ", device_type, + " is not type PjRtCApiClient")); + return nullptr; + } + TF_SetStatus(status, TF_OK, ""); + return pjrt_c_api_client->pjrt_c_client(); +} + +PJRT_Buffer* TF_GetPjRtCBuffer(TF_Tensor* c_tensor, TF_Status* status) { + tensorflow::Tensor tensor; + auto s = tensorflow::TF_TensorToTensor(c_tensor, &tensor); + if (!s.ok()) { + tensorflow::Set_TF_Status_from_Status(status, s); + return nullptr; + } + tensorflow::AsyncValueTensor* av_tensor = + tensorflow::AsyncValueTensor::FromTensor(&tensor); + if (av_tensor == nullptr || av_tensor->GetBuffer() == nullptr) { + tensorflow::Set_TF_Status_from_Status( + status, + tsl::errors::Internal("Input tensor does not have PjRtBuffer.")); + return nullptr; + } + auto* c_api_buffer = + tensorflow::down_cast(av_tensor->GetBuffer().get()); + if (c_api_buffer == nullptr) { + tensorflow::Set_TF_Status_from_Status( + status, + tsl::errors::Internal( + "The PjRtBuffer in the tensor is not type PjRtCApiBuffer.")); + return nullptr; + } + TF_SetStatus(status, TF_OK, ""); + return c_api_buffer->c_buffer(); +} + +void TF_CreatePjRtBuffer(TF_Tensor* c_tensor, PJRT_Buffer* c_buffer, + const char* device_type, TF_Status* status) { + tensorflow::Tensor tensor; + auto s = tensorflow::TF_TensorToTensor(c_tensor, &tensor); + if (!s.ok()) { + tensorflow::Set_TF_Status_from_Status(status, s); + return; 
+ } + auto pjrt_client = + tensorflow::GetOrCreatePjRtClient(tensorflow::DeviceType(device_type)); + if (!pjrt_client.ok()) { + tensorflow::Set_TF_Status_from_Status(status, pjrt_client.status()); + return; + } + auto* pjrt_c_api_client = + tensorflow::down_cast(*pjrt_client); + if (pjrt_c_api_client == nullptr) { + tensorflow::Set_TF_Status_from_Status( + status, tsl::errors::Internal("PjRtClient for ", device_type, + " is not type PjRtCApiClient")); + return; + } + tensorflow::AsyncValueTensor* av_tensor = + tensorflow::AsyncValueTensor::FromTensor(&tensor); + av_tensor->SetBuffer( + std::make_unique(pjrt_c_api_client, c_buffer)); + TF_SetStatus(status, TF_OK, ""); +} diff --git a/tensorflow/c/experimental/next_pluggable_device/c_api.h b/tensorflow/c/experimental/next_pluggable_device/c_api.h new file mode 100644 index 00000000000..e577f02a595 --- /dev/null +++ b/tensorflow/c/experimental/next_pluggable_device/c_api.h @@ -0,0 +1,153 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_NEXT_PLUGGABLE_DEVICE_C_API_H_ +#define TENSORFLOW_C_EXPERIMENTAL_NEXT_PLUGGABLE_DEVICE_C_API_H_ + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/kernels.h" +#include "tensorflow/c/kernels_experimental.h" +#include "tensorflow/c/tf_buffer.h" +#include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h" + +// -------------------------------------------------------------------------- +// C API for device. The API is under active development and eventually +// should allow registering a plugin device with TensorFlow. + +// Macro to control visibility of exported symbols in the shared library (.so, +// .dylib, .dll). +// This duplicates the TF_EXPORT macro definition in +// tensorflow/core/platform/macros.h in order to keep this .h file independent +// of any other includes. +#ifdef SWIG +#define TF_CAPI_EXPORT +#else +#if defined(_WIN32) +#ifdef TF_COMPILE_LIBRARY +#define TF_CAPI_EXPORT __declspec(dllexport) +#else +#define TF_CAPI_EXPORT __declspec(dllimport) +#endif // TF_COMPILE_LIBRARY +#else +#define TF_CAPI_EXPORT __attribute__((visibility("default"))) +#endif // _WIN32 +#endif // SWIG + +#ifdef __cplusplus +extern "C" { +#endif + +// TF_Device is a C wrapper to the C++ TF Device class. This is to be passed +// through TF_OpKernelContext, and is opaque to plugin. +typedef struct TF_Device TF_Device; + +typedef struct TF_VariableInfo TF_VariableInfo; + +// Returns a `TF_Device` pointer, which actually points to a C++ `Device`. +// Currently we only allow `NextPluggableDevice` to be casted as `TF_Device`, +// but in theory every this is a C API for every kind of device. 
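// Typical use from a plugin kernel (sketch; `ctx` is the TF_OpKernelContext
// handed to the kernel's compute function):
//   TF_Device* device = TF_GetDevice(ctx);
//   size_t ordinal = TF_GetDeviceOrdinal(device);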
+TF_CAPI_EXPORT extern TF_Device* TF_GetDevice(TF_OpKernelContext* ctx); + +TF_CAPI_EXPORT extern size_t TF_GetDeviceOrdinal(TF_Device* device); + +// -------------------------- Resource --------------------------------------- +// Create a `tensorflow::PluginResource` to the ResourceMgr provided by the +// `ctx`. The `tensorflow::PluginResource` wraps a resource by plugin (as a +// opaque pointer, since TensorFlow cannot parse it). `delete_func` is needed +// for ResourceMgr to clean up the resource. `status` will be set. +TF_CAPI_EXPORT extern void TF_CreatePluginResource( + TF_OpKernelContext* ctx, const char* container_name, + const char* plugin_resource_name, void* plugin_resource, + void (*delete_func)(void*), TF_Status* status); + +// If the ResourceMgr provided by the `ctx` has a resource +// `plugin_resource_name`, returns it in `*result_plugin_resource`. Otherwise, +// invokes create_func to create the resource. `delete_func` is needed for +// ResourceMgr to clean up the resource. `status` will be set. If `status` is +// not OK, `*result_plugin_resource` will be set as nullptr. +// +// Caller does not take ownership of the `plugin_resource`. +TF_CAPI_EXPORT extern void TF_LookupOrCreatePluginResource( + TF_OpKernelContext* ctx, const char* container_name, + const char* plugin_resource_name, void** result_plugin_resource, + void* (*create_func)(void*), void* create_func_args, + void (*delete_func)(void*), TF_Status* status); + +// ------------------------- VariableInfo ------------------------------------ +TF_CAPI_EXPORT extern TF_VariableInfo* TF_CreateVariableInfoFromContext( + TF_OpKernelContext* ctx, int index, TF_Status* status); + +TF_CAPI_EXPORT extern void TF_LockVariableInfos(TF_VariableInfo** vars, + int num_vars, + TF_Status* status); + +TF_CAPI_EXPORT extern void TF_AllocateTempForVariableInfo( + TF_OpKernelContext* ctx, TF_VariableInfo* var_info, TF_Status* status); + +TF_CAPI_EXPORT extern TF_Tensor* TF_GetTensorFromVariableInfo( + TF_VariableInfo* var_info, TF_Status* status); + +TF_CAPI_EXPORT extern void TF_DeleteVariableInfo(TF_VariableInfo* var_info); + +// --------------------- Coordination service -------------------------------- +// Returns a not owning pointer to the coordination service agent, which is +// opaque to plugin. Plugin OpKernels need to use the accompanying C APIs to +// access coordination service functionalities. +TF_CAPI_EXPORT extern TF_CoordinationServiceAgent* +TF_GetCoordinationServiceAgent(TF_OpKernelContext* ctx); + +// Returns true if the coordination service agent has been initialized. +TF_CAPI_EXPORT extern bool TF_CoordinationServiceIsInitialized( + TF_CoordinationServiceAgent* agent); + +TF_CAPI_EXPORT extern void TF_CoordinationServiceInsertKeyValue( + const char* key, const char* value, TF_CoordinationServiceAgent* agent, + TF_Status* status); + +// Obtains key-value from coorination service agent. The returned `TF_Buffer` +// is a newly allocated buffer to hold the string key-value, and caller is +// responsible for managing the lifetime. If error, `status` will be set and a +// nullptr will be returned. 
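// Example round trip (sketch; the key and value strings are illustrative and
// `agent` is obtained from TF_GetCoordinationServiceAgent above):
//   TF_Status* status = TF_NewStatus();
//   TF_CoordinationServiceInsertKeyValue("worker/0/addr", "10.0.0.1:8470",
//                                        agent, status);
//   TF_Buffer* value =
//       TF_CoordinationServiceGetKeyValue("worker/0/addr", agent, status);
//   if (TF_GetCode(status) == TF_OK) {
//     // value->data / value->length hold the string; release it when done.
//     TF_DeleteBuffer(value);
//   }
//   TF_CoordinationServiceDeleteKeyValue("worker/0/addr", agent, status);
//   TF_DeleteStatus(status);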
+TF_CAPI_EXPORT extern TF_Buffer* TF_CoordinationServiceGetKeyValue( + const char* key, TF_CoordinationServiceAgent* agent, TF_Status* status); + +TF_CAPI_EXPORT extern void TF_CoordinationServiceDeleteKeyValue( + const char* key, TF_CoordinationServiceAgent* agent, TF_Status* status); + +// ---------------------------- PJRT ----------------------------------------- +TF_CAPI_EXPORT extern void TF_CreateAndSetPjRtCApiClient( + const char* device_type, TF_Status* status); + +// Gets the `PJRT_Client*` stored in TF global ResourceManager. +TF_CAPI_EXPORT extern PJRT_Client* TF_GetPjRtCClient(const char* device_type, + TF_Status* status); + +// Gets the `PJRT_Buffer*` stored in the tensor. The status will contain error +// if the tensor does not have a `PjRtCApiBuffer`. +TF_CAPI_EXPORT extern PJRT_Buffer* TF_GetPjRtCBuffer(TF_Tensor* c_tensor, + TF_Status* status); + +// Creates a `PjRtCApiBuffer` with the `PJRT_Buffer*` passed in and set to the +// tensor. +TF_CAPI_EXPORT extern void TF_CreatePjRtBuffer(TF_Tensor* c_tensor, + PJRT_Buffer* c_buffer, + const char* device_type, + TF_Status* status); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TENSORFLOW_C_EXPERIMENTAL_NEXT_PLUGGABLE_DEVICE_C_API_H_ diff --git a/tensorflow/c/experimental/ops/BUILD b/tensorflow/c/experimental/ops/BUILD index e5cf1c39f65..13f1c808d45 100644 --- a/tensorflow/c/experimental/ops/BUILD +++ b/tensorflow/c/experimental/ops/BUILD @@ -3,6 +3,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") # Experimental ops. These will eventually be replaced by machine-generated versions. package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) diff --git a/tensorflow/c/experimental/ops/gen/BUILD b/tensorflow/c/experimental/ops/gen/BUILD index 21e855dceb9..7ab0a9f49c5 100644 --- a/tensorflow/c/experimental/ops/gen/BUILD +++ b/tensorflow/c/experimental/ops/gen/BUILD @@ -4,6 +4,7 @@ load( ) package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:public"], licenses = ["notice"], ) diff --git a/tensorflow/c/experimental/ops/gen/common/BUILD b/tensorflow/c/experimental/ops/gen/common/BUILD index 2dcbc644cf0..a5618623bbd 100644 --- a/tensorflow/c/experimental/ops/gen/common/BUILD +++ b/tensorflow/c/experimental/ops/gen/common/BUILD @@ -4,6 +4,7 @@ load( ) package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:public"], licenses = ["notice"], ) diff --git a/tensorflow/c/experimental/ops/gen/cpp/BUILD b/tensorflow/c/experimental/ops/gen/cpp/BUILD index 7b9aa347198..d2fd0294adb 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/BUILD +++ b/tensorflow/c/experimental/ops/gen/cpp/BUILD @@ -4,6 +4,7 @@ load( ) package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:private"], licenses = ["notice"], ) diff --git a/tensorflow/c/experimental/ops/gen/cpp/golden/BUILD b/tensorflow/c/experimental/ops/gen/cpp/golden/BUILD index 5180b86cece..86880db388b 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/golden/BUILD +++ b/tensorflow/c/experimental/ops/gen/cpp/golden/BUILD @@ -1,4 +1,5 @@ package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:public"], licenses = ["notice"], ) diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/BUILD b/tensorflow/c/experimental/ops/gen/cpp/renderers/BUILD index 
2d41ae84512..7589ea2d2f2 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/BUILD +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/BUILD @@ -4,6 +4,7 @@ load( ) package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:private"], licenses = ["notice"], ) diff --git a/tensorflow/c/experimental/ops/gen/cpp/views/BUILD b/tensorflow/c/experimental/ops/gen/cpp/views/BUILD index 455c6cac143..46f61c89d8e 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/views/BUILD +++ b/tensorflow/c/experimental/ops/gen/cpp/views/BUILD @@ -1,4 +1,5 @@ package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:private"], licenses = ["notice"], ) diff --git a/tensorflow/c/experimental/ops/gen/model/BUILD b/tensorflow/c/experimental/ops/gen/model/BUILD index 04df5d61748..918acaabb6b 100644 --- a/tensorflow/c/experimental/ops/gen/model/BUILD +++ b/tensorflow/c/experimental/ops/gen/model/BUILD @@ -1,4 +1,5 @@ package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//tensorflow/c/experimental/ops/gen:__subpackages__"], licenses = ["notice"], ) diff --git a/tensorflow/c/experimental/pluggable_profiler/BUILD b/tensorflow/c/experimental/pluggable_profiler/BUILD index 9fd79348de6..4e3de6a46c1 100644 --- a/tensorflow/c/experimental/pluggable_profiler/BUILD +++ b/tensorflow/c/experimental/pluggable_profiler/BUILD @@ -5,6 +5,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.default.bzl", "filegroup") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) @@ -61,8 +62,8 @@ cc_library( "//tensorflow/c:tf_status", "//tensorflow/c:tf_status_helper", "//tensorflow/core/platform:status", - "//tensorflow/core/profiler:profiler_options_proto_cc", "//tensorflow/core/profiler/lib:profiler_interface", "//tensorflow/core/profiler/protobuf:xplane_proto_cc", + "//tensorflow/tsl/profiler/protobuf:profiler_options_proto_cc", ], ) diff --git a/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler.cc b/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler.cc index 6e8cc32e556..0efa257723b 100644 --- a/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler.cc +++ b/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/core/platform/status.h" #include "tensorflow/core/profiler/lib/profiler_factory.h" #include "tensorflow/core/profiler/lib/profiler_interface.h" -#include "tensorflow/core/profiler/profiler_options.pb.h" +#include "tensorflow/tsl/profiler/protobuf/profiler_options.pb.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler_internal.h b/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler_internal.h index 103c0905f08..6dbbe4549ff 100644 --- a/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler_internal.h +++ b/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler_internal.h @@ -18,8 +18,8 @@ limitations under the License. 
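The pluggable-profiler hunks here only repoint the ProfileOptions protobuf at its new location under tensorflow/tsl. A minimal sketch of consuming it from the relocated header; the `host_tracer_level` field is assumed to carry over unchanged from the existing proto definition:

#include "tensorflow/tsl/profiler/protobuf/profiler_options.pb.h"

// Builds an illustrative set of profiler options from the relocated proto.
tensorflow::ProfileOptions MakeProfileOptions() {
  tensorflow::ProfileOptions options;
  options.set_host_tracer_level(2);  // illustrative verbosity level
  return options;
}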
#include "tensorflow/c/tf_status_helper.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/profiler/lib/profiler_interface.h" -#include "tensorflow/core/profiler/profiler_options.pb.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" +#include "tensorflow/tsl/profiler/protobuf/profiler_options.pb.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/c/experimental/saved_model/core/BUILD b/tensorflow/c/experimental/saved_model/core/BUILD index 394c7de8b59..d72cf86a7bc 100644 --- a/tensorflow/c/experimental/saved_model/core/BUILD +++ b/tensorflow/c/experimental/saved_model/core/BUILD @@ -11,7 +11,9 @@ load( ) package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ + # copybara:uncomment() "//learning/brain/tfrt/aot:__pkg__", "//tensorflow/c:__subpackages__", "//tensorflow/c/experimental/saved_model/internal:__pkg__", "//tensorflow/cc/experimental/libtf:__pkg__", diff --git a/tensorflow/c/experimental/saved_model/core/ops/BUILD b/tensorflow/c/experimental/saved_model/core/ops/BUILD index 14fa051a4ab..cce725db3fc 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/BUILD +++ b/tensorflow/c/experimental/saved_model/core/ops/BUILD @@ -8,6 +8,7 @@ load( ) package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ # Restricting visibility for now "//tensorflow/c/experimental/saved_model/core:__subpackages__", diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD index 3c2050e79ec..ab7de9bae06 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD +++ b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD @@ -3,9 +3,11 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") # This package contains classes corresponding to Revived SavedObjectGraph types # used by SavedModel. See https://cs.opensource.google/tensorflow/tensorflow/+/c575e2ba93c442121d98d3f125d83fed1339924d:tensorflow/core/protobuf/saved_object_graph.proto;l=56-62 package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ # Restricting visibility for now "//tensorflow/c/experimental/saved_model/core:__pkg__", + # copybara:uncomment "//learning/brain/tfrt/aot:__pkg__", ], licenses = ["notice"], ) diff --git a/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc b/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc index 2a4297e2b67..660a417be8f 100644 --- a/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc +++ b/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc @@ -80,51 +80,6 @@ Status ConstantFromSavedConstant( return internal::TensorProtoToConstant(ctx, tensor_proto, output); } -// Finds the "signatures" object in the object graph, and fills a mapping of -// each signature's name to the corresponding function's node in the object -// graph. 
-Status GetSignaturesMap(const SavedObjectGraph& saved_objects, - gtl::FlatMap* signatures_map) { - if (saved_objects.nodes().empty()) { - return errors::FailedPrecondition("Saved Object Graph was empty."); - } - const SavedObject& root = saved_objects.nodes(0); - const SavedObject* signatures = nullptr; - for (const auto& child : root.children()) { - if (child.local_name() == "signatures") { - if (child.node_id() >= saved_objects.nodes().size()) { - return errors::FailedPrecondition( - "Signature object had child node id ", child.node_id(), - " which exceeds the size of the set of nodes"); - } - signatures = &saved_objects.nodes(child.node_id()); - } - } - - // Some basic sanity checks that this object is actually our "signatures" map - if (signatures == nullptr) { - // This is where the "signatures" attribute is always set: - // https://github.com/tensorflow/tensorflow/blob/a2c542a0d83227568f9214a2af9a38ae3625976f/tensorflow/python/saved_model/save.py#L1106-L1109 - return errors::FailedPrecondition( - "SavedObjectGraph's root object must have a child 'signatures' object"); - } - if (signatures->kind_case() != SavedObject::kUserObject) { - return errors::FailedPrecondition( - "Signatures must be a SavedObject of type UserObject."); - } - if (signatures->user_object().identifier() != "signature_map") { - // This is where the string comes from: - // https://github.com/tensorflow/tensorflow/blob/c59af2913aaec235d883f50428efef1086f4c0e6/tensorflow/python/saved_model/signature_serialization.py#L220 - return errors::FailedPrecondition( - "Signatures SavedObject must have identifier 'signature_map'."); - } - - for (const auto& child : signatures->children()) { - (*signatures_map)[child.local_name()] = child.node_id(); - } - return Status(); -} - // Perform some basic sanity checks on SavedConcreteFunction's input and // output signatures with respect to the corresponding FunctionDef's input // and output args. 
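GetSignaturesMap is removed from the anonymous namespace here and re-added as an exported helper in the next hunk, with a matching declaration added to saved_model_utils.h further below. A minimal usage sketch, assuming the map is keyed by signature name with object-graph node ids as values and that the caller lives in the same namespace as the helper:

#include <string>

#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
#include "tensorflow/core/platform/errors.h"

// Sketch only: the FlatMap key/value types below are assumptions.
Status ListSignatureNodes(const SavedObjectGraph& saved_objects) {
  gtl::FlatMap<std::string, int> signatures_map;
  TF_RETURN_IF_ERROR(GetSignaturesMap(saved_objects, &signatures_map));
  for (const auto& name_and_node : signatures_map) {
    // name_and_node.first is the signature name, .second its node id.
  }
  return Status();
}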
@@ -183,6 +138,50 @@ Status ValidateSavedFunctionCompatibleWithFunctionDef( return Status(); } +} // namespace + +Status GetSignaturesMap(const SavedObjectGraph& saved_objects, + gtl::FlatMap* signatures_map) { + if (saved_objects.nodes().empty()) { + return errors::FailedPrecondition("Saved Object Graph was empty."); + } + const SavedObject& root = saved_objects.nodes(0); + const SavedObject* signatures = nullptr; + for (const auto& child : root.children()) { + if (child.local_name() == "signatures") { + if (child.node_id() >= saved_objects.nodes().size()) { + return errors::FailedPrecondition( + "Signature object had child node id ", child.node_id(), + " which exceeds the size of the set of nodes"); + } + signatures = &saved_objects.nodes(child.node_id()); + } + } + + // Some basic sanity checks that this object is actually our "signatures" map + if (signatures == nullptr) { + // This is where the "signatures" attribute is always set: + // https://github.com/tensorflow/tensorflow/blob/a2c542a0d83227568f9214a2af9a38ae3625976f/tensorflow/python/saved_model/save.py#L1106-L1109 + return errors::FailedPrecondition( + "SavedObjectGraph's root object must have a child 'signatures' object"); + } + if (signatures->kind_case() != SavedObject::kUserObject) { + return errors::FailedPrecondition( + "Signatures must be a SavedObject of type UserObject."); + } + if (signatures->user_object().identifier() != "signature_map") { + // This is where the string comes from: + // https://github.com/tensorflow/tensorflow/blob/c59af2913aaec235d883f50428efef1086f4c0e6/tensorflow/python/saved_model/signature_serialization.py#L220 + return errors::FailedPrecondition( + "Signatures SavedObject must have identifier 'signature_map'."); + } + + for (const auto& child : signatures->children()) { + (*signatures_map)[child.local_name()] = child.node_id(); + } + return Status(); +} + Status ValidateSingleConcreteFunction(const SavedFunction& saved_function) { // We only allow loading functions that have an annotated input signature, // which means there is 1:1 correspondence between tf.function @@ -198,8 +197,6 @@ Status ValidateSingleConcreteFunction(const SavedFunction& saved_function) { return Status(); } -} // namespace - Status LoadSavedAsset(ImmediateExecutionContext* ctx, const SavedAsset& asset, const std::string& saved_model_dir, absl::Span assets, @@ -438,9 +435,11 @@ Status PartiallyReviveSavedModelObjects(const MetaGraphDef& metagraph, resource_revival_state.device = node.resource().device(); objects->restored_resources[i] = std::move(resource_revival_state); } else if (node.kind_case() == SavedObject::kFunction) { - // Get the SavedFunction node and validate it has a single concrete func. + // Get the SavedFunction node and skip if it has no concrete functions. const SavedFunction& saved_function = node.function(); - TF_RETURN_IF_ERROR(ValidateSingleConcreteFunction(saved_function)); + if (saved_function.concrete_functions_size() < 1) { + continue; + } // Retrieve related function information. 
const std::string& function_name = saved_function.concrete_functions(0); diff --git a/tensorflow/c/experimental/saved_model/core/saved_model_utils.h b/tensorflow/c/experimental/saved_model/core/saved_model_utils.h index db45e28087f..34b4499621c 100644 --- a/tensorflow/c/experimental/saved_model/core/saved_model_utils.h +++ b/tensorflow/c/experimental/saved_model/core/saved_model_utils.h @@ -94,6 +94,15 @@ gtl::FlatMap NodeToAttrMap( gtl::FlatMap FunctionNameToFunctionDefMap(const FunctionDefLibrary& library); +// Finds the "signatures" object in the object graph, and fills a mapping of +// each signature's name to the corresponding function's node in the object +// graph. +Status GetSignaturesMap(const SavedObjectGraph& saved_objects, + gtl::FlatMap* signatures_map); + +// Validates the `saved_function`. +Status ValidateSingleConcreteFunction(const SavedFunction& saved_function); + // Walks through the SavedObjectGraph in metagraph, and restores all nodes // (except "UserDefinedObjects") with their corresponding type in // "PartiallyRevivedObjects". diff --git a/tensorflow/c/experimental/saved_model/internal/BUILD b/tensorflow/c/experimental/saved_model/internal/BUILD index 2647a822f93..d6dc1f202b0 100644 --- a/tensorflow/c/experimental/saved_model/internal/BUILD +++ b/tensorflow/c/experimental/saved_model/internal/BUILD @@ -20,7 +20,10 @@ load( "tf_copts", ) -package(licenses = ["notice"]) +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) cc_library( name = "concrete_function", diff --git a/tensorflow/c/experimental/saved_model/internal/testdata/BUILD b/tensorflow/c/experimental/saved_model/internal/testdata/BUILD index 49acc9274fc..ab1a6e3689e 100644 --- a/tensorflow/c/experimental/saved_model/internal/testdata/BUILD +++ b/tensorflow/c/experimental/saved_model/internal/testdata/BUILD @@ -2,6 +2,7 @@ load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow:strict.default.bzl", "py_strict_binary") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) diff --git a/tensorflow/c/experimental/saved_model/public/BUILD b/tensorflow/c/experimental/saved_model/public/BUILD index 71fd46ab889..6a711ae1738 100644 --- a/tensorflow/c/experimental/saved_model/public/BUILD +++ b/tensorflow/c/experimental/saved_model/public/BUILD @@ -11,6 +11,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], # This is intentionally public default_visibility = [ "//visibility:public", diff --git a/tensorflow/c/experimental/stream_executor/BUILD b/tensorflow/c/experimental/stream_executor/BUILD index 849c0f2c22b..d06c536f671 100644 --- a/tensorflow/c/experimental/stream_executor/BUILD +++ b/tensorflow/c/experimental/stream_executor/BUILD @@ -9,6 +9,7 @@ load( load("//tensorflow:tensorflow.default.bzl", "filegroup") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) @@ -43,13 +44,12 @@ cc_library( "//tensorflow/compiler/xla/stream_executor:executor_cache", "//tensorflow/compiler/xla/stream_executor:multi_platform_manager", "//tensorflow/compiler/xla/stream_executor:platform", - "//tensorflow/compiler/xla/stream_executor:stream_executor_internal", "//tensorflow/compiler/xla/stream_executor:stream_executor_pimpl", "//tensorflow/compiler/xla/stream_executor:timer", "//tensorflow/core:lib", 
"//tensorflow/core/common_runtime/device:device_utils", - "//tensorflow/core/platform:regexp", "//tensorflow/core/platform:strcat", + "@com_google_absl//absl/functional:any_invocable", ], ) diff --git a/tensorflow/c/experimental/stream_executor/stream_executor.cc b/tensorflow/c/experimental/stream_executor/stream_executor.cc index c8a9670156b..2ba7d3cc953 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor.cc +++ b/tensorflow/c/experimental/stream_executor/stream_executor.cc @@ -22,7 +22,9 @@ limitations under the License. #include "tensorflow/c/experimental/stream_executor/stream_executor.h" #include +#include +#include "absl/functional/any_invocable.h" #include "tensorflow/c/c_api_macros.h" #include "tensorflow/c/c_api_macros_internal.h" #include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h" @@ -51,7 +53,7 @@ using tensorflow::StringPiece; using OwnedTFStatus = tensorflow::TF_StatusPtr; namespace { -port::Status ValidateSPPlatform(const SP_Platform& platform) { +tsl::Status ValidateSPPlatform(const SP_Platform& platform) { TF_VALIDATE_STRUCT_SIZE(SP_Platform, platform, SP_PLATFORM_STRUCT_SIZE); TF_VALIDATE_NOT_NULL(SP_Platform, platform, name); TF_VALIDATE_NOT_NULL(SP_Platform, platform, type); @@ -63,7 +65,7 @@ port::Status ValidateSPPlatform(const SP_Platform& platform) { return ::tensorflow::OkStatus(); } -port::Status ValidateSPPlatformFns(const SP_PlatformFns& platform_fns) { +tsl::Status ValidateSPPlatformFns(const SP_PlatformFns& platform_fns) { TF_VALIDATE_STRUCT_SIZE(SP_PlatformFns, platform_fns, SP_PLATFORM_FNS_STRUCT_SIZE); TF_VALIDATE_NOT_NULL(SP_PlatformFns, platform_fns, create_device); @@ -77,40 +79,40 @@ port::Status ValidateSPPlatformFns(const SP_PlatformFns& platform_fns) { return ::tensorflow::OkStatus(); } -port::Status ValidateSPTimerFns(const SP_TimerFns& timer_fns) { +tsl::Status ValidateSPTimerFns(const SP_TimerFns& timer_fns) { TF_VALIDATE_STRUCT_SIZE(SP_TimerFns, timer_fns, SP_TIMER_FNS_STRUCT_SIZE); TF_VALIDATE_NOT_NULL(SP_TimerFns, timer_fns, nanoseconds); return ::tensorflow::OkStatus(); } -port::Status ValidateSPAllocatorStats(const SP_AllocatorStats& stats) { +tsl::Status ValidateSPAllocatorStats(const SP_AllocatorStats& stats) { TF_VALIDATE_STRUCT_SIZE(SP_AllocatorStats, stats, SP_ALLOCATORSTATS_STRUCT_SIZE); // All other fields could theoretically be zero/null. return ::tensorflow::OkStatus(); } -port::Status ValidateSPDeviceMemoryBase(const SP_DeviceMemoryBase& mem) { +tsl::Status ValidateSPDeviceMemoryBase(const SP_DeviceMemoryBase& mem) { TF_VALIDATE_STRUCT_SIZE(SP_DeviceMemoryBase, mem, SP_DEVICE_MEMORY_BASE_STRUCT_SIZE); // All other fields could theoretically be zero/null. return ::tensorflow::OkStatus(); } -port::Status ValidateSPDevice(const SP_Device& device) { +tsl::Status ValidateSPDevice(const SP_Device& device) { TF_VALIDATE_STRUCT_SIZE(SP_Device, device, SP_DEVICE_STRUCT_SIZE); // All other fields could theoretically be zero/null. return ::tensorflow::OkStatus(); } -port::Status ValidateSPDeviceFns(const SP_DeviceFns& device_fns) { +tsl::Status ValidateSPDeviceFns(const SP_DeviceFns& device_fns) { TF_VALIDATE_STRUCT_SIZE(SP_DeviceFns, device_fns, SP_DEVICE_FNS_STRUCT_SIZE); // All other fields could theoretically be zero/null. 
return ::tensorflow::OkStatus(); } -port::Status ValidateSPStreamExecutor(const SP_StreamExecutor& se, - const SP_Platform& platform) { +tsl::Status ValidateSPStreamExecutor(const SP_StreamExecutor& se, + const SP_Platform& platform) { TF_VALIDATE_STRUCT_SIZE(SP_StreamExecutor, se, SP_STREAM_EXECUTOR_STRUCT_SIZE); TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, allocate); @@ -149,7 +151,7 @@ port::Status ValidateSPStreamExecutor(const SP_StreamExecutor& se, return ::tensorflow::OkStatus(); } -port::Status ValidateSEPlatformRegistrationParams( +tsl::Status ValidateSEPlatformRegistrationParams( const SE_PlatformRegistrationParams& params) { TF_VALIDATE_STRUCT_SIZE(SE_PlatformRegistrationParams, params, SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE); @@ -193,7 +195,7 @@ DeviceMemoryBase DeviceMemoryBaseFromC(const SP_DeviceMemoryBase& mem) { // Wrapper that allows passing std::function across C API. struct HostCallbackContext { - std::function callback; + absl::AnyInvocable callback; }; // This wrapper allows calling `HostCallbackContext::callback` across C API. @@ -201,7 +203,7 @@ struct HostCallbackContext { // `callback_fn` to `host_callback` in `SP_StreamExecutor`. void HostCallbackTrampoline(void* ctx, TF_Status* status) { HostCallbackContext* host_ctx = static_cast(ctx); - port::Status s = host_ctx->callback(); + tsl::Status s = std::move(host_ctx->callback)(); Set_TF_Status_from_Status(status, s); delete host_ctx; } @@ -226,14 +228,14 @@ class CStreamExecutor : public internal::StreamExecutorInterface { platform_fns_->destroy_device(platform_, &device_); } - port::Status Init(int device_ordinal, DeviceOptions device_options) override { + tsl::Status Init(int device_ordinal, DeviceOptions device_options) override { return ::tensorflow::OkStatus(); } DeviceMemoryBase Allocate(uint64 size, int64_t memory_space) override { SP_DeviceMemoryBase mem = {SP_DEVICE_MEMORY_BASE_STRUCT_SIZE}; stream_executor_->allocate(&device_, size, memory_space, &mem); - port::Status status = ValidateSPDeviceMemoryBase(mem); + tsl::Status status = ValidateSPDeviceMemoryBase(mem); if (!status.ok()) { LOG(ERROR) << status.error_message(); } @@ -280,7 +282,7 @@ class CStreamExecutor : public internal::StreamExecutorInterface { if (!has_stats) { return absl::nullopt; } - port::Status status = ValidateSPAllocatorStats(c_stats); + tsl::Status status = ValidateSPAllocatorStats(c_stats); if (!status.ok()) { LOG(ERROR) << status.error_message(); return absl::nullopt; @@ -310,38 +312,37 @@ class CStreamExecutor : public internal::StreamExecutorInterface { } return true; } - port::Status SynchronousMemZero(DeviceMemoryBase* location, - uint64 size) override { + tsl::Status SynchronousMemZero(DeviceMemoryBase* location, + uint64 size) override { // TODO(annarev): figure out if we should support memzero/memset // functionality by allocating on host and then copying to device. 
- return port::UnimplementedError( + return tsl::errors::Unimplemented( "SynchronousMemZero is not supported by pluggable device."); } - port::Status SynchronousMemSet(DeviceMemoryBase* location, int value, - uint64 size) override { - return port::UnimplementedError( + tsl::Status SynchronousMemSet(DeviceMemoryBase* location, int value, + uint64 size) override { + return tsl::errors::Unimplemented( "SynchronousMemSet is not supported by pluggable device."); } - port::Status SynchronousMemcpy(DeviceMemoryBase* gpu_dst, - const void* host_src, uint64 size) override { + tsl::Status SynchronousMemcpy(DeviceMemoryBase* gpu_dst, const void* host_src, + uint64 size) override { OwnedTFStatus c_status(TF_NewStatus()); SP_DeviceMemoryBase device_memory_base = DeviceMemoryBaseToC(gpu_dst); stream_executor_->sync_memcpy_htod(&device_, &device_memory_base, host_src, size, c_status.get()); return StatusFromTF_Status(c_status.get()); } - port::Status SynchronousMemcpy(void* host_dst, - const DeviceMemoryBase& gpu_src, - uint64 size) override { + tsl::Status SynchronousMemcpy(void* host_dst, const DeviceMemoryBase& gpu_src, + uint64 size) override { OwnedTFStatus c_status(TF_NewStatus()); SP_DeviceMemoryBase device_memory_base = DeviceMemoryBaseToC(&gpu_src); stream_executor_->sync_memcpy_dtoh(&device_, host_dst, &device_memory_base, size, c_status.get()); return StatusFromTF_Status(c_status.get()); } - port::Status SynchronousMemcpyDeviceToDevice(DeviceMemoryBase* gpu_dst, - const DeviceMemoryBase& gpu_src, - uint64 size) override { + tsl::Status SynchronousMemcpyDeviceToDevice(DeviceMemoryBase* gpu_dst, + const DeviceMemoryBase& gpu_src, + uint64 size) override { OwnedTFStatus c_status(TF_NewStatus()); SP_DeviceMemoryBase device_mem_dst = DeviceMemoryBaseToC(gpu_dst); SP_DeviceMemoryBase device_mem_src = DeviceMemoryBaseToC(&gpu_src); @@ -349,8 +350,8 @@ class CStreamExecutor : public internal::StreamExecutorInterface { &device_mem_src, size, c_status.get()); return StatusFromTF_Status(c_status.get()); } - port::Status MemZero(Stream* stream, DeviceMemoryBase* location, - uint64 size) override { + tsl::Status MemZero(Stream* stream, DeviceMemoryBase* location, + uint64 size) override { OwnedTFStatus c_status(TF_NewStatus()); SP_Stream stream_handle = static_cast(stream->implementation())->Handle(); @@ -359,8 +360,8 @@ class CStreamExecutor : public internal::StreamExecutorInterface { c_status.get()); return StatusFromTF_Status(c_status.get()); } - port::Status Memset(Stream* stream, DeviceMemoryBase* location, uint8 pattern, - uint64 size) override { + tsl::Status Memset(Stream* stream, DeviceMemoryBase* location, uint8 pattern, + uint64 size) override { OwnedTFStatus c_status(TF_NewStatus()); SP_Stream stream_handle = static_cast(stream->implementation())->Handle(); @@ -369,8 +370,8 @@ class CStreamExecutor : public internal::StreamExecutorInterface { size, c_status.get()); return StatusFromTF_Status(c_status.get()); } - port::Status Memset32(Stream* stream, DeviceMemoryBase* location, - uint32 pattern, uint64 size) override { + tsl::Status Memset32(Stream* stream, DeviceMemoryBase* location, + uint32 pattern, uint64 size) override { OwnedTFStatus c_status(TF_NewStatus()); SP_Stream stream_handle = static_cast(stream->implementation())->Handle(); @@ -424,27 +425,27 @@ class CStreamExecutor : public internal::StreamExecutorInterface { return true; } bool HostCallback(Stream* stream, - std::function callback) override { + absl::AnyInvocable callback) override { SP_Stream stream_handle = 
static_cast(stream->implementation())->Handle(); - HostCallbackContext* ctx = new HostCallbackContext{callback}; + HostCallbackContext* ctx = new HostCallbackContext{std::move(callback)}; return stream_executor_->host_callback(&device_, stream_handle, &HostCallbackTrampoline, ctx); } - port::Status AllocateEvent(Event* event) override { + tsl::Status AllocateEvent(Event* event) override { DCHECK(event != nullptr); return static_cast(event->implementation())->Create(); } - port::Status DeallocateEvent(Event* event) override { + tsl::Status DeallocateEvent(Event* event) override { static_cast(event->implementation())->Destroy(); return ::tensorflow::OkStatus(); } - port::Status RecordEvent(Stream* stream, Event* event) override { + tsl::Status RecordEvent(Stream* stream, Event* event) override { SP_Stream stream_handle = static_cast(stream->implementation())->Handle(); return static_cast(event->implementation())->Record(stream_handle); } - port::Status WaitForEvent(Stream* stream, Event* event) override { + tsl::Status WaitForEvent(Stream* stream, Event* event) override { SP_Stream stream_handle = static_cast(stream->implementation())->Handle(); SP_Event event_handle = @@ -452,7 +453,7 @@ class CStreamExecutor : public internal::StreamExecutorInterface { OwnedTFStatus c_status(TF_NewStatus()); stream_executor_->wait_for_event(&device_, stream_handle, event_handle, c_status.get()); - port::Status s = StatusFromTF_Status(c_status.get()); + tsl::Status s = StatusFromTF_Status(c_status.get()); return s; } Event::Status PollForEventStatus(Event* event) override { @@ -464,7 +465,7 @@ class CStreamExecutor : public internal::StreamExecutorInterface { } bool AllocateStream(Stream* stream) override { DCHECK(stream != nullptr); - port::Status status = + tsl::Status status = static_cast(stream->implementation())->Create(); // TODO(annarev): update AllocateStream to return status instead // (similar to AllocateEvent). @@ -488,7 +489,7 @@ class CStreamExecutor : public internal::StreamExecutorInterface { return true; } bool AllocateTimer(Timer* timer) override { - port::Status status = + tsl::Status status = static_cast(timer->implementation())->Create(); // TODO(annarev): change return value of AllocateTimer // to status (similar to AllocateEvent). 
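The stream_executor changes above replace `std::function` with `absl::AnyInvocable` for host callbacks: the callable is moved into a heap-allocated context and a C-style trampoline invokes it exactly once before deleting the context. A small self-contained sketch of that ownership pattern; the rvalue-qualified signature and all names are assumptions based on the hunks above:

#include <utility>

#include "absl/functional/any_invocable.h"
#include "tensorflow/tsl/platform/status.h"

namespace {

// Owns a move-only, call-once callback, mirroring HostCallbackContext above.
struct CallbackContext {
  absl::AnyInvocable<tsl::Status() &&> callback;
};

// C-style trampoline: recover the context, invoke the callback once, clean up.
void CallbackTrampoline(void* raw_ctx) {
  auto* ctx = static_cast<CallbackContext*>(raw_ctx);
  tsl::Status s = std::move(ctx->callback)();
  if (!s.ok()) {
    // The real code forwards `s` into a TF_Status via Set_TF_Status_from_Status.
  }
  delete ctx;
}

}  // namespace

// Registration side (sketch): `submit` stands in for the plugin's
// host_callback entry point and is hypothetical.
void RegisterHostCallback(absl::AnyInvocable<tsl::Status() &&> callback,
                          void (*submit)(void (*)(void*), void*)) {
  auto* ctx = new CallbackContext{std::move(callback)};
  submit(&CallbackTrampoline, ctx);
}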
@@ -525,7 +526,7 @@ class CStreamExecutor : public internal::StreamExecutorInterface { } return true; } - port::Status BlockHostForEvent(Stream* stream, Event* event) { + tsl::Status BlockHostForEvent(Stream* stream, Event* event) { OwnedTFStatus c_status(TF_NewStatus()); SP_Event event_handle = static_cast(event->implementation())->Handle(); @@ -534,7 +535,7 @@ class CStreamExecutor : public internal::StreamExecutorInterface { return StatusFromTF_Status(c_status.get()); } - port::Status BlockHostUntilDone(Stream* stream) override { + tsl::Status BlockHostUntilDone(Stream* stream) override { OwnedTFStatus c_status(TF_NewStatus()); SP_Stream stream_handle = static_cast(stream->implementation())->Handle(); @@ -551,7 +552,7 @@ class CStreamExecutor : public internal::StreamExecutorInterface { TF_RETURN_IF_ERROR(StatusFromTF_Status(c_status.get())); stream_executor_->record_event(&device_, stream_handle, event_handle, c_status.get()); - port::Status s = StatusFromTF_Status(c_status.get()); + tsl::Status s = StatusFromTF_Status(c_status.get()); if (!s.ok()) { stream_executor_->destroy_event(&device_, event_handle); return s; @@ -562,7 +563,7 @@ class CStreamExecutor : public internal::StreamExecutorInterface { return StatusFromTF_Status(c_status.get()); } - port::Status GetStatus(Stream* stream) override { + tsl::Status GetStatus(Stream* stream) override { OwnedTFStatus c_status(TF_NewStatus()); SP_Stream stream_handle = static_cast(stream->implementation())->Handle(); @@ -571,8 +572,8 @@ class CStreamExecutor : public internal::StreamExecutorInterface { return StatusFromTF_Status(c_status.get()); } int PlatformDeviceCount() override { return visible_device_count_; } - port::Status EnablePeerAccessTo(StreamExecutorInterface* other) override { - return port::UnimplementedError( + tsl::Status EnablePeerAccessTo(StreamExecutorInterface* other) override { + return tsl::errors::Unimplemented( "EnablePeerAccessTo is not supported by pluggable device."); } bool CanEnablePeerAccessTo(StreamExecutorInterface* other) override { @@ -587,7 +588,7 @@ class CStreamExecutor : public internal::StreamExecutorInterface { // Creates a new DeviceDescription object. // Ownership is transferred to the caller. - port::StatusOr> CreateDeviceDescription() + tsl::StatusOr> CreateDeviceDescription() const override { OwnedTFStatus c_status(TF_NewStatus()); @@ -679,7 +680,7 @@ CPlatform::~CPlatform() { destroy_platform_fns_(&platform_fns_); } -port::StatusOr> +tsl::StatusOr> CPlatform::DescriptionForDevice(int ordinal) const { // TODO(annarev): see if we can get StreamExecutor instance // and call GetDeviceDescription. 
executor_cache_.Get would need @@ -688,24 +689,24 @@ CPlatform::DescriptionForDevice(int ordinal) const { builder.set_name(name_); return builder.Build(); } -port::StatusOr CPlatform::ExecutorForDevice(int ordinal) { +tsl::StatusOr CPlatform::ExecutorForDevice(int ordinal) { stream_executor::StreamExecutorConfig config; config.ordinal = ordinal; return GetExecutor(config); } -port::StatusOr CPlatform::ExecutorForDeviceWithPluginConfig( +tsl::StatusOr CPlatform::ExecutorForDeviceWithPluginConfig( int ordinal, const PluginConfig& plugin_config) { StreamExecutorConfig config; config.ordinal = ordinal; config.plugin_config = plugin_config; return GetExecutor(config); } -port::StatusOr CPlatform::GetExecutor( +tsl::StatusOr CPlatform::GetExecutor( const StreamExecutorConfig& config) { return executor_cache_.GetOrCreate( config, [&]() { return GetUncachedExecutor(config); }); } -port::StatusOr> CPlatform::GetUncachedExecutor( +tsl::StatusOr> CPlatform::GetUncachedExecutor( const StreamExecutorConfig& config) { // Fill device creation params SE_CreateDeviceParams device_params{SE_CREATE_DEVICE_PARAMS_STRUCT_SIZE}; @@ -734,9 +735,8 @@ port::StatusOr> CPlatform::GetUncachedExecutor( return result; } -port::Status InitStreamExecutorPlugin(void* dso_handle, - std::string* device_type, - std::string* platform_name) { +tsl::Status InitStreamExecutorPlugin(void* dso_handle, std::string* device_type, + std::string* platform_name) { tensorflow::Env* env = tensorflow::Env::Default(); // Step 1: Load symbol for `TF_InitPlugin` @@ -749,9 +749,9 @@ port::Status InitStreamExecutorPlugin(void* dso_handle, return InitStreamExecutorPlugin(init_fn, device_type, platform_name); } -port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn, - std::string* device_type, - std::string* platform_name) { +tsl::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn, + std::string* device_type, + std::string* platform_name) { SE_PlatformRegistrationParams params{ SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE}; SP_Platform platform{SP_PLATFORM_STRUCT_SIZE}; @@ -804,7 +804,7 @@ port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn, std::move(platform), params.destroy_platform, std::move(platform_fns), params.destroy_platform_fns, std::move(device_fns), std::move(se), std::move(timer_fns))); - SE_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform( + TF_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform( std::move(cplatform))); // TODO(annarev): Return `use_bfc_allocator` value in some way so that it is // available in `PluggableDeviceProcessState` once the latter is checked in. diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h index 7246dde2660..ad8a77d61fa 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h +++ b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/c/experimental/stream_executor/stream_executor.h" #include "tensorflow/c/tf_status_helper.h" #include "tensorflow/compiler/xla/stream_executor/executor_cache.h" -#include "tensorflow/compiler/xla/stream_executor/lib/status.h" #include "tensorflow/compiler/xla/stream_executor/platform.h" namespace stream_executor { @@ -33,15 +32,14 @@ typedef void (*SEInitPluginFn)(SE_PlatformRegistrationParams* const, // Registers StreamExecutor platform. `device_type` and `platform_name` are // output parameters. 
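// Example (sketch; `dso_handle` is the handle obtained when the plugin's
// shared object was loaded):
//   std::string device_type, platform_name;
//   TF_CHECK_OK(stream_executor::InitStreamExecutorPlugin(
//       dso_handle, &device_type, &platform_name));
// On success the platform is registered with MultiPlatformManager and the two
// output strings are filled in from the plugin's SP_Platform.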
-port::Status InitStreamExecutorPlugin(void* dso_handle, - std::string* device_type, - std::string* platform_name); +tsl::Status InitStreamExecutorPlugin(void* dso_handle, std::string* device_type, + std::string* platform_name); // Allow registering a StreamExecutor plugin using a function (used for // testing). -port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn, - std::string* device_type, - std::string* platform_name); +tsl::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn, + std::string* device_type, + std::string* platform_name); // This file implements core stream executor base classes in terms of // the C API defined in stream_executor.h. A class "CSomething" represents a @@ -71,14 +69,14 @@ class CPlatform : public Platform { } bool UseBfcAllocator() const { return platform_.use_bfc_allocator; } bool ForceMemoryGrowth() const { return platform_.force_memory_growth; } - port::StatusOr> DescriptionForDevice( + tsl::StatusOr> DescriptionForDevice( int ordinal) const override; - port::StatusOr ExecutorForDevice(int ordinal) override; - port::StatusOr ExecutorForDeviceWithPluginConfig( + tsl::StatusOr ExecutorForDevice(int ordinal) override; + tsl::StatusOr ExecutorForDeviceWithPluginConfig( int ordinal, const PluginConfig& plugin_config) override; - port::StatusOr GetExecutor( + tsl::StatusOr GetExecutor( const StreamExecutorConfig& config) override; - port::StatusOr> GetUncachedExecutor( + tsl::StatusOr> GetUncachedExecutor( const StreamExecutorConfig& config) override; // Trace listener is not supported @@ -110,10 +108,10 @@ class CStream : public internal::StreamInterface { stream_handle_(nullptr) {} ~CStream() override { Destroy(); } - port::Status Create() { + tsl::Status Create() { tensorflow::TF_StatusPtr c_status(TF_NewStatus()); stream_executor_->create_stream(device_, &stream_handle_, c_status.get()); - port::Status s = tensorflow::StatusFromTF_Status(c_status.get()); + tsl::Status s = tensorflow::StatusFromTF_Status(c_status.get()); return s; } @@ -140,13 +138,13 @@ class CEvent : public internal::EventInterface { event_handle_(nullptr) {} ~CEvent() override { Destroy(); } - port::Status Create() { + tsl::Status Create() { tensorflow::TF_StatusPtr c_status(TF_NewStatus()); stream_executor_->create_event(device_, &event_handle_, c_status.get()); return tensorflow::StatusFromTF_Status(c_status.get()); } - port::Status Record(SP_Stream stream_handle) { + tsl::Status Record(SP_Stream stream_handle) { tensorflow::TF_StatusPtr c_status(TF_NewStatus()); stream_executor_->record_event(device_, stream_handle, event_handle_, c_status.get()); @@ -178,7 +176,7 @@ class CTimer : public internal::TimerInterface { timer_fns_(timer_fns) {} ~CTimer() override { Destroy(); } - port::Status Create() { + tsl::Status Create() { tensorflow::TF_StatusPtr c_status(TF_NewStatus()); stream_executor_->create_timer(device_, &timer_handle_, c_status.get()); return tensorflow::StatusFromTF_Status(c_status.get()); diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_test.cc b/tensorflow/c/experimental/stream_executor/stream_executor_test.cc index 8b82121c51d..cf21374c48f 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor_test.cc +++ b/tensorflow/c/experimental/stream_executor/stream_executor_test.cc @@ -38,17 +38,17 @@ TEST(StreamExecutor, SuccessfulRegistration) { test_util::PopulateDefaultPlatformRegistrationParams(params); }; std::string device_type, platform_name; - port::Status status = + tsl::Status status = InitStreamExecutorPlugin(plugin_init, 
&device_type, &platform_name); TF_ASSERT_OK(status); - port::StatusOr maybe_platform = + tsl::StatusOr maybe_platform = MultiPlatformManager::PlatformWithName("MY_DEVICE"); TF_ASSERT_OK(maybe_platform.status()); Platform* platform = std::move(maybe_platform).value(); ASSERT_EQ(platform->Name(), test_util::kDeviceName); ASSERT_EQ(platform->VisibleDeviceCount(), test_util::kDeviceCount); - port::StatusOr maybe_executor = + tsl::StatusOr maybe_executor = platform->ExecutorForDevice(0); TF_ASSERT_OK(maybe_executor.status()); } @@ -62,7 +62,7 @@ TEST(StreamExecutor, NameNotSet) { }; std::string device_type, platform_name; - port::Status status = + tsl::Status status = InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name); ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); ASSERT_EQ(status.error_message(), "'name' field in SP_Platform must be set."); @@ -77,7 +77,7 @@ TEST(StreamExecutor, InvalidNameWithSemicolon) { }; std::string device_type, platform_name; - port::Status status = + tsl::Status status = InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name); ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); EXPECT_THAT( @@ -94,7 +94,7 @@ TEST(StreamExecutor, InvalidNameWithSlash) { }; std::string device_type, platform_name; - port::Status status = + tsl::Status status = InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name); ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); EXPECT_THAT(status.error_message(), @@ -110,7 +110,7 @@ TEST(StreamExecutor, CreateDeviceNotSet) { }; std::string device_type, platform_name; - port::Status status = + tsl::Status status = InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name); ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); ASSERT_EQ(status.error_message(), @@ -126,7 +126,7 @@ TEST(StreamExecutor, UnifiedMemoryAllocateNotSet) { }; std::string device_type, platform_name; - port::Status status = + tsl::Status status = InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name); ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); ASSERT_EQ( @@ -152,7 +152,7 @@ class StreamExecutorTest : public ::testing::Test { platform_, test_util::DestroyPlatform, platform_fns_, test_util::DestroyPlatformFns, device_fns_, se_, timer_fns_); } - port::StatusOr maybe_executor = + tsl::StatusOr maybe_executor = cplatform_->ExecutorForDevice(ordinal); TF_CHECK_OK(maybe_executor.status()); return std::move(maybe_executor).value(); @@ -724,7 +724,7 @@ TEST_F(StreamExecutorTest, HostCallbackOk) { StreamExecutor* executor = GetExecutor(0); Stream stream(executor); stream.Init(); - std::function callback = []() -> port::Status { + std::function callback = []() -> tsl::Status { return ::tensorflow::OkStatus(); }; stream.ThenDoHostCallbackWithStatus(callback); @@ -744,8 +744,8 @@ TEST_F(StreamExecutorTest, HostCallbackError) { StreamExecutor* executor = GetExecutor(0); Stream stream(executor); stream.Init(); - std::function callback = []() -> port::Status { - return port::UnimplementedError("Unimplemented"); + std::function callback = []() -> tsl::Status { + return tsl::errors::Unimplemented("Unimplemented"); }; stream.ThenDoHostCallbackWithStatus(callback); ASSERT_FALSE(stream.ok()); diff --git a/tensorflow/c/experimental/stream_executor/test/BUILD b/tensorflow/c/experimental/stream_executor/test/BUILD index e3795a2715b..2a4d40b3e79 100644 --- a/tensorflow/c/experimental/stream_executor/test/BUILD +++ 
b/tensorflow/c/experimental/stream_executor/test/BUILD @@ -6,6 +6,7 @@ load( ) package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc index c3a54a46b3c..85b2433ac43 100644 --- a/tensorflow/c/kernels.cc +++ b/tensorflow/c/kernels.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/c/kernels.h" #include +#include #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/c_api_macros.h" @@ -26,6 +27,7 @@ limitations under the License. #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/resource_handle.pb.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/types.h" // Required for IS_MOBILE_PLATFORM definition @@ -295,6 +297,13 @@ void TF_InputRange(TF_OpKernelContext* ctx, const char* name, tensorflow::Set_TF_Status_from_Status(args->status, status); } +TF_DataType TF_InputDatatype(TF_OpKernelContext* ctx, int index) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); + CHECK_GE(index, 0); // Crash OK + CHECK_LT(index, cc_ctx->num_inputs()); // Crash OK + return static_cast(cc_ctx->input_dtype(index)); +} + void TF_SetOutput(TF_OpKernelContext* ctx, int i, const TF_Tensor* tensor, TF_Status* status) { auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); @@ -354,6 +363,18 @@ void TF_GetSerializedConfigProto(TF_OpKernelContext* ctx, tensorflow::Set_TF_Status_from_Status(status, cc_status); } +void TF_GetSerializedResourceHandleProto( + TF_OpKernelContext* ctx, int i, TF_Buffer* serialized_resource_handle_proto, + TF_Status* status) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); + const tensorflow::ResourceHandle& handle = HandleFromInput(cc_ctx, i); + tensorflow::ResourceHandleProto handle_proto; + handle.AsProto(&handle_proto); + auto cc_status = tensorflow::MessageToBuffer( + handle_proto, serialized_resource_handle_proto); + tensorflow::Set_TF_Status_from_Status(status, cc_status); +} + void TF_OpKernelConstruction_Failure(TF_OpKernelConstruction* ctx, TF_Status* status) { auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); @@ -652,6 +673,18 @@ int64_t TF_GetIterId(TF_OpKernelContext* ctx) { .iter_id; } +int64_t TF_GetStepId(TF_OpKernelContext* ctx) { + return reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)->step_id(); +} + +int TF_GetDeviceId(TF_OpKernelContext* ctx) { + // TensorFlow always sets device in OpKernelContext. + auto* device = + reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)->device(); + if (!device->parsed_name().has_id) return -1; + return device->parsed_name().id; +} + TF_StringView TF_GetOpKernelName(TF_OpKernelContext* ctx) { auto cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); TF_StringView opkernel_name_sv; diff --git a/tensorflow/c/kernels.h b/tensorflow/c/kernels.h index e85dc9f252a..2e765b7dfaa 100644 --- a/tensorflow/c/kernels.h +++ b/tensorflow/c/kernels.h @@ -190,6 +190,11 @@ TF_CAPI_EXPORT extern void TF_InputRange(TF_OpKernelContext* ctx, const char* name, TF_InputRange_Args* args); +// Returns the data type of the index-th input. If index < 0 or index >= +// TF_NumInputs(ctx), the program aborts. 
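
For context, a hedged usage sketch of the context accessors implemented above and declared just below; the kernel, its name, and the printf are hypothetical and not part of the patch:

#include <stdint.h>
#include <stdio.h>
#include "tensorflow/c/kernels.h"
#include "tensorflow/c/tf_datatype.h"

// Hypothetical compute callback for a plugin kernel registered through the
// TF_KernelBuilder API.
static void MyKernel_Compute(void* kernel, TF_OpKernelContext* ctx) {
  TF_DataType dt = TF_InputDatatype(ctx, 0);  // aborts on an out-of-range index
  int64_t step = TF_GetStepId(ctx);
  int device_id = TF_GetDeviceId(ctx);        // -1 if the device has no parsed id
  printf("input 0: dtype=%d step=%lld device=%d\n", (int)dt, (long long)step,
         device_id);
}
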
+TF_CAPI_EXPORT extern TF_DataType TF_InputDatatype(TF_OpKernelContext* ctx, + int index); + // Sets the ith output of ctx to tensor. If TF_GetCode(status) is anything but // TF_OK, ctx is left unmodified. // @@ -216,6 +221,11 @@ TF_CAPI_EXPORT extern void TF_GetSerializedConfigProto( TF_OpKernelContext* ctx, TF_Buffer* serialized_config_proto, TF_Status* status); +// Retrieves a serialized ResourceHandleProto. Status will be set. +TF_CAPI_EXPORT extern void TF_GetSerializedResourceHandleProto( + TF_OpKernelContext* ctx, int i, TF_Buffer* serialized_resource_handle_proto, + TF_Status* status); + // Notifies the given OpKernelConstruction that kernel construction has failed. TF_CAPI_EXPORT extern void TF_OpKernelConstruction_Failure( TF_OpKernelConstruction* ctx, TF_Status* status); @@ -253,6 +263,12 @@ TF_CAPI_EXPORT extern uint64_t TF_GetFrameId(TF_OpKernelContext* ctx); // Returns the Iter ID of the given context. TF_CAPI_EXPORT extern int64_t TF_GetIterId(TF_OpKernelContext* ctx); +// Returns the Step ID of the given context. +TF_CAPI_EXPORT extern int64_t TF_GetStepId(TF_OpKernelContext* ctx); + +// Returns the Device ID of the device that the context possesses. +TF_CAPI_EXPORT extern int TF_GetDeviceId(TF_OpKernelContext* ctx); + // Returns the graph def version of the given context. TF_CAPI_EXPORT extern int TF_GetGraphDefVersion(TF_OpKernelContext* ctx); diff --git a/tensorflow/c/kernels/BUILD b/tensorflow/c/kernels/BUILD index 99fbcfabab4..93ed9a7880b 100644 --- a/tensorflow/c/kernels/BUILD +++ b/tensorflow/c/kernels/BUILD @@ -3,6 +3,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:public"], licenses = ["notice"], ) diff --git a/tensorflow/c/tf_datatype.h b/tensorflow/c/tf_datatype.h index 3e6121bf989..df0c1fb45b0 100644 --- a/tensorflow/c/tf_datatype.h +++ b/tensorflow/c/tf_datatype.h @@ -59,7 +59,7 @@ typedef enum TF_DataType { TF_QINT8 = 11, // Quantized int8 TF_QUINT8 = 12, // Quantized uint8 TF_QINT32 = 13, // Quantized int32 - TF_BFLOAT16 = 14, // Float32 truncated to 16 bits. Only for cast ops. + TF_BFLOAT16 = 14, // Float32 truncated to 16 bits. TF_QINT16 = 15, // Quantized int16 TF_QUINT16 = 16, // Quantized uint16 TF_UINT16 = 17, @@ -69,6 +69,9 @@ typedef enum TF_DataType { TF_VARIANT = 21, TF_UINT32 = 22, TF_UINT64 = 23, + TF_FLOAT8_E5M2 = 24, // 5 exponent bits, 2 mantissa bits. + TF_FLOAT8_E4M3FN = 25, // 4 exponent bits, 3 mantissa bits, finite-only, with + // 2 NaNs (0bS1111111). } TF_DataType; // TF_DataTypeSize returns the sizeof() for the underlying type corresponding diff --git a/tensorflow/c/tf_status.cc b/tensorflow/c/tf_status.cc index 2f774fa7977..686e09508ac 100644 --- a/tensorflow/c/tf_status.cc +++ b/tensorflow/c/tf_status.cc @@ -16,39 +16,21 @@ limitations under the License. #include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status_internal.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/status.h" -using ::tensorflow::Status; -using ::tensorflow::error::Code; -using ::tensorflow::errors::IOError; - -TF_Status* TF_NewStatus() { return new TF_Status; } - -void TF_DeleteStatus(TF_Status* s) { delete s; } +// Trampoline implementation to redirect to TSL. Kept here for backward +// compatibility only. 
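
The forwarding definitions that follow keep the public contract intact, so existing C callers should observe identical behavior. A minimal check, assuming only the public header (not part of the patch):

#include <stdio.h>
#include "tensorflow/c/tf_status.h"

int main(void) {
  TF_Status* s = TF_NewStatus();  // now backed by a TSL_Status
  TF_SetStatus(s, TF_INVALID_ARGUMENT, "bad ndims");
  printf("%d %s\n", (int)TF_GetCode(s), TF_Message(s));  // prints: 3 bad ndims
  TF_DeleteStatus(s);
  return 0;
}
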
+TF_Status* TF_NewStatus() { return TSL_NewStatus(); } +void TF_DeleteStatus(TF_Status* s) { TSL_DeleteStatus(s); } void TF_SetStatus(TF_Status* s, TF_Code code, const char* msg) { - if (code == TF_OK) { - s->status = ::tensorflow::OkStatus(); - return; - } - s->status = Status(static_cast(code), tensorflow::StringPiece(msg)); + TSL_SetStatus(s, TSL_Code(code), msg); } - void TF_SetPayload(TF_Status* s, const char* key, const char* value) { - s->status.SetPayload(key, value); + TSL_SetPayload(s, key, value); } - void TF_SetStatusFromIOError(TF_Status* s, int error_code, const char* context) { - // TODO(b/139060984): Handle windows when changing its filesystem - s->status = IOError(context, error_code); -} - -TF_Code TF_GetCode(const TF_Status* s) { - return static_cast(s->status.code()); -} - -const char* TF_Message(const TF_Status* s) { - return s->status.error_message().c_str(); + TSL_SetStatusFromIOError(s, error_code, context); } +TF_Code TF_GetCode(const TF_Status* s) { return TF_Code(TSL_GetCode(s)); } +const char* TF_Message(const TF_Status* s) { return TSL_Message(s); } diff --git a/tensorflow/c/tf_status.h b/tensorflow/c/tf_status.h index 4616ee434d9..db1d32bf8e7 100644 --- a/tensorflow/c/tf_status.h +++ b/tensorflow/c/tf_status.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_C_TF_STATUS_H_ #define TENSORFLOW_C_TF_STATUS_H_ +#include "tensorflow/tsl/c/tsl_status.h" + #ifdef SWIG #define TF_CAPI_EXPORT #else @@ -34,30 +36,29 @@ limitations under the License. extern "C" { #endif -typedef struct TF_Status TF_Status; +typedef struct TSL_Status TF_Status; // -------------------------------------------------------------------------- // TF_Code holds an error code. The enum values here are identical to // corresponding values in error_codes.proto. -typedef enum TF_Code { - TF_OK = 0, - TF_CANCELLED = 1, - TF_UNKNOWN = 2, - TF_INVALID_ARGUMENT = 3, - TF_DEADLINE_EXCEEDED = 4, - TF_NOT_FOUND = 5, - TF_ALREADY_EXISTS = 6, - TF_PERMISSION_DENIED = 7, - TF_UNAUTHENTICATED = 16, - TF_RESOURCE_EXHAUSTED = 8, - TF_FAILED_PRECONDITION = 9, - TF_ABORTED = 10, - TF_OUT_OF_RANGE = 11, - TF_UNIMPLEMENTED = 12, - TF_INTERNAL = 13, - TF_UNAVAILABLE = 14, - TF_DATA_LOSS = 15, -} TF_Code; +typedef TSL_Code TF_Code; +#define TF_OK TSL_OK +#define TF_CANCELLED TSL_CANCELLED +#define TF_UNKNOWN TSL_UNKNOWN +#define TF_INVALID_ARGUMENT TSL_INVALID_ARGUMENT +#define TF_DEADLINE_EXCEEDED TSL_DEADLINE_EXCEEDED +#define TF_NOT_FOUND TSL_NOT_FOUND +#define TF_ALREADY_EXISTS TSL_ALREADY_EXISTS +#define TF_PERMISSION_DENIED TSL_PERMISSION_DENIED +#define TF_UNAUTHENTICATED TSL_UNAUTHENTICATED +#define TF_RESOURCE_EXHAUSTED TSL_RESOURCE_EXHAUSTED +#define TF_FAILED_PRECONDITION TSL_FAILED_PRECONDITION +#define TF_ABORTED TSL_ABORTED +#define TF_OUT_OF_RANGE TSL_OUT_OF_RANGE +#define TF_UNIMPLEMENTED TSL_UNIMPLEMENTED +#define TF_INTERNAL TSL_INTERNAL +#define TF_UNAVAILABLE TSL_UNAVAILABLE +#define TF_DATA_LOSS TSL_DATA_LOSS // -------------------------------------------------------------------------- diff --git a/tensorflow/c/tf_status_helper.cc b/tensorflow/c/tf_status_helper.cc index 1e4360d5531..9155d9dde8b 100644 --- a/tensorflow/c/tf_status_helper.cc +++ b/tensorflow/c/tf_status_helper.cc @@ -17,75 +17,16 @@ limitations under the License. 
#include "tensorflow/c/tf_status_internal.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/tsl/c/tsl_status_helper.h" namespace tsl { void Set_TF_Status_from_Status(TF_Status* tf_status, const Status& status) { - tensorflow::error::Code code = status.code(); - const char* message(status.error_message().c_str()); - - switch (code) { - case tensorflow::error::OK: - assert(TF_GetCode(tf_status) == TF_OK); - break; - case tensorflow::error::CANCELLED: - TF_SetStatus(tf_status, TF_CANCELLED, message); - break; - case tensorflow::error::UNKNOWN: - TF_SetStatus(tf_status, TF_UNKNOWN, message); - break; - case tensorflow::error::INVALID_ARGUMENT: - TF_SetStatus(tf_status, TF_INVALID_ARGUMENT, message); - break; - case tensorflow::error::DEADLINE_EXCEEDED: - TF_SetStatus(tf_status, TF_DEADLINE_EXCEEDED, message); - break; - case tensorflow::error::NOT_FOUND: - TF_SetStatus(tf_status, TF_NOT_FOUND, message); - break; - case tensorflow::error::ALREADY_EXISTS: - TF_SetStatus(tf_status, TF_ALREADY_EXISTS, message); - break; - case tensorflow::error::PERMISSION_DENIED: - TF_SetStatus(tf_status, TF_PERMISSION_DENIED, message); - break; - case tensorflow::error::UNAUTHENTICATED: - TF_SetStatus(tf_status, TF_UNAUTHENTICATED, message); - break; - case tensorflow::error::RESOURCE_EXHAUSTED: - TF_SetStatus(tf_status, TF_RESOURCE_EXHAUSTED, message); - break; - case tensorflow::error::FAILED_PRECONDITION: - TF_SetStatus(tf_status, TF_FAILED_PRECONDITION, message); - break; - case tensorflow::error::ABORTED: - TF_SetStatus(tf_status, TF_ABORTED, message); - break; - case tensorflow::error::OUT_OF_RANGE: - TF_SetStatus(tf_status, TF_OUT_OF_RANGE, message); - break; - case tensorflow::error::UNIMPLEMENTED: - TF_SetStatus(tf_status, TF_UNIMPLEMENTED, message); - break; - case tensorflow::error::INTERNAL: - TF_SetStatus(tf_status, TF_INTERNAL, message); - break; - case tensorflow::error::UNAVAILABLE: - TF_SetStatus(tf_status, TF_UNAVAILABLE, message); - break; - case tensorflow::error::DATA_LOSS: - TF_SetStatus(tf_status, TF_DATA_LOSS, message); - break; - default: - assert(0); - break; - } - - errors::CopyPayloads(status, tf_status->status); + Set_TSL_Status_from_Status(tf_status, status); } Status StatusFromTF_Status(const TF_Status* tf_status) { - return tf_status->status; + return StatusFromTSL_Status(tf_status); } } // namespace tsl diff --git a/tensorflow/c/tf_status_helper.h b/tensorflow/c/tf_status_helper.h index 4c3c8af6864..df4600b85dc 100644 --- a/tensorflow/c/tf_status_helper.h +++ b/tensorflow/c/tf_status_helper.h @@ -21,10 +21,10 @@ limitations under the License. namespace tsl { // Set the attribute of "tf_status" from the attributes of "status". -void Set_TF_Status_from_Status(TF_Status* tf_status, const tsl::Status& status); +void Set_TF_Status_from_Status(TF_Status* tf_status, const Status& status); // Returns a "status" from "tf_status". -tensorflow::Status StatusFromTF_Status(const TF_Status* tf_status); +Status StatusFromTF_Status(const TF_Status* tf_status); } // namespace tsl namespace tensorflow { diff --git a/tensorflow/c/tf_status_internal.h b/tensorflow/c/tf_status_internal.h index 1e0f99819ff..7a40d6f518e 100644 --- a/tensorflow/c/tf_status_internal.h +++ b/tensorflow/c/tf_status_internal.h @@ -16,13 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_C_TF_STATUS_INTERNAL_H_ #define TENSORFLOW_C_TF_STATUS_INTERNAL_H_ -#include "tensorflow/core/platform/status.h" +#include "tensorflow/tsl/c/tsl_status_internal.h" -// Internal structures used by the status C API. 
These are likely to change -// and should not be depended on. - -struct TF_Status { - tensorflow::Status status; -}; +typedef struct TSL_Status TF_Status; #endif // TENSORFLOW_C_TF_STATUS_INTERNAL_H_ diff --git a/tensorflow/c/tf_tensor.cc b/tensorflow/c/tf_tensor.cc index 7bf662d81e0..e007af200c4 100644 --- a/tensorflow/c/tf_tensor.cc +++ b/tensorflow/c/tf_tensor.cc @@ -247,7 +247,7 @@ Status TensorInterface::BitcastFrom(const TensorInterface& from, DataType type, const int64_t* new_dims, int num_new_dims) { tensorflow::TensorShape s; for (int i = 0; i < num_new_dims; ++i) { - s.AddDim(new_dims[i]); + TF_RETURN_IF_ERROR(s.AddDimWithStatus(new_dims[i])); } return tensor_.BitcastFrom(from.tensor_, type, s); } diff --git a/tensorflow/c/tf_tstring.h b/tensorflow/c/tf_tstring.h index 5dc29f23d59..f9fb2fe083f 100644 --- a/tensorflow/c/tf_tstring.h +++ b/tensorflow/c/tf_tstring.h @@ -59,4 +59,4 @@ TF_CAPI_EXPORT extern void TF_StringDealloc(TF_TString *tstr); } /* end extern "C" */ #endif -#endif // THIRD_PARTY_TENSORFLOW_C_TF_TSTRING_H_ +#endif // TENSORFLOW_C_TF_TSTRING_H_ diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index df7472e08c7..4fc555871af 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -13,6 +13,7 @@ load( load("//tensorflow:tensorflow.default.bzl", "filegroup", "tf_gen_op_wrappers_cc") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:public"], licenses = ["notice"], ) diff --git a/tensorflow/cc/client/client_session_test.cc b/tensorflow/cc/client/client_session_test.cc index 27ec4c0871d..3c5357f739e 100644 --- a/tensorflow/cc/client/client_session_test.cc +++ b/tensorflow/cc/client/client_session_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/cc/client/client_session.h" +#include #include #include "absl/synchronization/barrier.h" @@ -39,6 +40,14 @@ using ops::Mul; using ops::Placeholder; using ops::Sub; +tensorflow::SessionOptions GetSessionOptions() { + tensorflow::SessionOptions options; + // Disable optimizations for static graph to allow calls to Session::Extend. 
+ options.config.mutable_experimental()->set_disable_optimize_for_static_graph( + true); + return options; +} + class CustomThreadPoolImpl : public thread::ThreadPoolInterface { public: explicit CustomThreadPoolImpl(int numThreads) { @@ -100,7 +109,7 @@ TEST(ClientSessionTest, Extend) { Scope root = Scope::NewRootScope(); auto a = Placeholder(root, DT_INT32, Placeholder::Shape({2})); auto c = Add(root, a, {2, 2}); - ClientSession session(root); + ClientSession session(root, GetSessionOptions()); std::vector outputs; TF_EXPECT_OK(session.Run({{a, {1, 1}}}, {c}, &outputs)); @@ -116,7 +125,7 @@ TEST(ClientSessionTest, MultiThreadedWithDefaultThreadpool) { Scope root = Scope::NewRootScope(); auto a = Add(root, {1, 2}, {3, 4}); auto b = Mul(root, {1, 2}, {3, 4}); - ClientSession session(root); + ClientSession session(root, GetSessionOptions()); { thread::ThreadPool thread_pool(Env::Default(), "pool", 2); thread_pool.Schedule([&session, a]() { @@ -143,7 +152,7 @@ TEST(ClientSessionTest, MultiThreadedWithCustomThreadpool) { int num_threads = 3; auto a = Add(root, {1, 2}, {3, 4}); auto b = Mul(root, {1, 2}, {3, 4}); - ClientSession session(root); + ClientSession session(root, GetSessionOptions()); auto inter_op_threadpool = absl::make_unique(num_threads); diff --git a/tensorflow/cc/experimental/base/public/BUILD b/tensorflow/cc/experimental/base/public/BUILD index 5313b502bf5..7c1a040960f 100644 --- a/tensorflow/cc/experimental/base/public/BUILD +++ b/tensorflow/cc/experimental/base/public/BUILD @@ -11,6 +11,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], # This is intentionally public default_visibility = [ "//visibility:public", diff --git a/tensorflow/cc/experimental/base/tests/BUILD b/tensorflow/cc/experimental/base/tests/BUILD index 5f442faa77c..e749d2433bd 100644 --- a/tensorflow/cc/experimental/base/tests/BUILD +++ b/tensorflow/cc/experimental/base/tests/BUILD @@ -4,6 +4,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) diff --git a/tensorflow/cc/experimental/libexport/BUILD b/tensorflow/cc/experimental/libexport/BUILD index 5533cf76431..910ab930440 100644 --- a/tensorflow/cc/experimental/libexport/BUILD +++ b/tensorflow/cc/experimental/libexport/BUILD @@ -5,6 +5,7 @@ load( ) package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ "//tensorflow/cc/experimental/libtf:__subpackages__", diff --git a/tensorflow/cc/experimental/libtf/BUILD b/tensorflow/cc/experimental/libtf/BUILD index e9529725d94..e281672de9e 100644 --- a/tensorflow/cc/experimental/libtf/BUILD +++ b/tensorflow/cc/experimental/libtf/BUILD @@ -12,6 +12,7 @@ load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow:strict.default.bzl", "py_strict_binary") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ "//tensorflow/cc/experimental/libtf:__subpackages__", ], diff --git a/tensorflow/cc/experimental/libtf/impl/BUILD b/tensorflow/cc/experimental/libtf/impl/BUILD index 8231a25102e..0eae5a1f05c 100644 --- a/tensorflow/cc/experimental/libtf/impl/BUILD +++ b/tensorflow/cc/experimental/libtf/impl/BUILD @@ -10,6 +10,7 @@ load( ) package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], 
default_visibility = [ "//tensorflow/cc/experimental/libtf:__subpackages__", ], diff --git a/tensorflow/cc/experimental/libtf/mlir/BUILD b/tensorflow/cc/experimental/libtf/mlir/BUILD index 2d42d855dae..51336186510 100644 --- a/tensorflow/cc/experimental/libtf/mlir/BUILD +++ b/tensorflow/cc/experimental/libtf/mlir/BUILD @@ -6,6 +6,7 @@ load( ) package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ "//tensorflow/cc/experimental/libtf:__subpackages__", ], diff --git a/tensorflow/cc/experimental/libtf/object.h b/tensorflow/cc/experimental/libtf/object.h index 8f510def431..72d05aaf430 100644 --- a/tensorflow/cc/experimental/libtf/object.h +++ b/tensorflow/cc/experimental/libtf/object.h @@ -166,7 +166,7 @@ class Object : public Handle { if (class_dict_maybe.type() == TaggedValue::DICT) { auto& dict = class_dict_maybe.dict(); auto it = dict.find(key.value_); - if (it != value_.dict().end()) { + if (it != dict.end()) { return Cast(Handle(it->second)); } } diff --git a/tensorflow/cc/experimental/libtf/runtime/BUILD b/tensorflow/cc/experimental/libtf/runtime/BUILD index 75f81a5a8a2..b20c0e6e3f9 100644 --- a/tensorflow/cc/experimental/libtf/runtime/BUILD +++ b/tensorflow/cc/experimental/libtf/runtime/BUILD @@ -4,6 +4,7 @@ load( ) package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ "//tensorflow/cc/experimental/libtf:__subpackages__", ], diff --git a/tensorflow/cc/experimental/libtf/runtime/core/BUILD b/tensorflow/cc/experimental/libtf/runtime/core/BUILD index cb750c4c7a4..83f61ee11ba 100644 --- a/tensorflow/cc/experimental/libtf/runtime/core/BUILD +++ b/tensorflow/cc/experimental/libtf/runtime/core/BUILD @@ -1,4 +1,5 @@ package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ "//tensorflow/cc/experimental/libtf:__subpackages__", ], diff --git a/tensorflow/cc/experimental/libtf/runtime/tfrt/BUILD b/tensorflow/cc/experimental/libtf/runtime/tfrt/BUILD index 6350e007875..586ef6b9523 100644 --- a/tensorflow/cc/experimental/libtf/runtime/tfrt/BUILD +++ b/tensorflow/cc/experimental/libtf/runtime/tfrt/BUILD @@ -1,4 +1,5 @@ package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ "//tensorflow/cc/experimental/libtf:__subpackages__", ], diff --git a/tensorflow/cc/experimental/libtf/tests/runtime_test_core.cc b/tensorflow/cc/experimental/libtf/tests/runtime_test_core.cc index 0be93c31a28..59952002522 100644 --- a/tensorflow/cc/experimental/libtf/tests/runtime_test_core.cc +++ b/tensorflow/cc/experimental/libtf/tests/runtime_test_core.cc @@ -21,7 +21,7 @@ namespace runtime { INSTANTIATE_TEST_SUITE_P(TF2CAPI, RuntimeTest, ::testing::Values(core::Runtime)); - +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(RuntimeTest); } // namespace runtime } // namespace libtf } // namespace tf diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index d0cd220f112..031451d3d2d 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/cc/framework/cc_op_gen.h" +#include #include #include #include diff --git a/tensorflow/cc/framework/cc_op_gen.h b/tensorflow/cc/framework/cc_op_gen.h index 9af3b9ce1e3..7b348365b33 100644 --- a/tensorflow/cc/framework/cc_op_gen.h +++ b/tensorflow/cc/framework/cc_op_gen.h @@ -16,6 +16,8 @@ limitations under the License. 
#ifndef TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_ #define TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_ +#include + #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/op_gen_lib.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/cc/framework/cc_op_gen_util.h b/tensorflow/cc/framework/cc_op_gen_util.h index d6c729f2dc9..8fb90356841 100644 --- a/tensorflow/cc/framework/cc_op_gen_util.h +++ b/tensorflow/cc/framework/cc_op_gen_util.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_UTIL_H_ #define TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_UTIL_H_ +#include #include #include #include diff --git a/tensorflow/cc/framework/fuzzing/BUILD b/tensorflow/cc/framework/fuzzing/BUILD index 4c6b0d80baf..c14b324fdf2 100644 --- a/tensorflow/cc/framework/fuzzing/BUILD +++ b/tensorflow/cc/framework/fuzzing/BUILD @@ -7,6 +7,8 @@ load( ) load("//tensorflow:tensorflow.bzl", "tf_copts") +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) + cc_library( name = "cc_op_fuzz_gen_main", srcs = [ @@ -28,6 +30,7 @@ cc_library( "//tensorflow/core/platform:hash", "//tensorflow/tsl/platform:status", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", ], ) @@ -36,15 +39,8 @@ cc_library( # tf_gen_op_wrappers_fuzz( # name = "array_ops_fuzz", # api_def_srcs = ["//tensorflow/core/api_def:base_api_def"], -# extra_gen_deps = ["//tensorflow/c/kernels:bitcast_op_lib"], -# op_lib_names = [ -# "array_ops", -# ], -# pkg = "//tensorflow/core", -# deps = [ -# "//third_party/mediapipe/framework/port:parse_text_proto", +# kernel_deps = [ # "//tensorflow/c/kernels:bitcast_op", -# "//tensorflow/cc:cc_ops", # "//tensorflow/core/kernels:array", # "//tensorflow/core/kernels:check_numerics_op", # "//tensorflow/core/kernels:fake_quant_ops", @@ -57,6 +53,7 @@ cc_library( # "//tensorflow/core/kernels/linalg:matrix_diag_op", # "//tensorflow/core/kernels/linalg:matrix_set_diag_op", # ], +# op_def_src = "//tensorflow/core/ops:array_ops_op_lib", # ) # copybara:uncomment_end diff --git a/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.cc b/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.cc index 02af4b4aa86..416bb56e820 100644 --- a/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.cc +++ b/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.cc @@ -202,11 +202,58 @@ string WriteFuzzTest(const OpInfo& op_info) { })); } +string FuzzerFileStart() { + const string fuzz_namespace_begin = R"namespace( +namespace tensorflow { +namespace fuzzing { + +)namespace"; + + const string fuzz_header = strings::StrCat( + R"include(// This file is MACHINE GENERATED! Do not edit. + +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/security/fuzzing/cc/fuzz_session.h" +#include "third_party/mediapipe/framework/port/parse_text_proto.h" +)include", + fuzz_namespace_begin); + + return fuzz_header; +} + +string FuzzerFileEnd() { + const string fuzz_footer = R"footer( +} // namespace fuzzing +} // namespace tensorflow +)footer"; + + return fuzz_footer; +} + +} // namespace + bool OpFuzzingIsOk(const OpInfo& op_info) { + // Skip deprecated ops. + if (op_info.graph_op_def.has_deprecation() && + op_info.graph_op_def.deprecation().version() <= TF_GRAPH_DEF_VERSION) { + std::cout << "NOT fuzzing: " << op_info.graph_op_def.name() + << " is deprecated.\n"; + return false; + } + // TODO(unda, b/249347507): should we hide fuzzers for hidden ops? 
- if (op_info.api_def.visibility() == ApiDef::HIDDEN) return false; + if (op_info.api_def.visibility() == ApiDef::HIDDEN) { + std::cout << "NOT fuzzing: " << op_info.graph_op_def.name() + << " is hidden.\n"; + return false; + } - if (op_info.api_def.visibility() == ApiDef::SKIP) return false; + if (op_info.api_def.visibility() == ApiDef::SKIP) { + std::cout << "NOT fuzzing: " << op_info.graph_op_def.name() + << " is skipped.\n"; + return false; + } // TODO(unda) : zero input ops std::set zero_input_ops = {"Placeholder", "ImmutableConst"}; @@ -272,56 +319,10 @@ bool OpFuzzingIsOk(const OpInfo& op_info) { return true; } -string FuzzerFileStart() { - const string fuzz_namespace_begin = R"namespace( -namespace tensorflow { -namespace fuzzing { - -)namespace"; - - const string fuzz_header = strings::StrCat( - R"include(// This file is MACHINE GENERATED! Do not edit. - -#include "tensorflow/cc/ops/const_op.h" -#include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/security/fuzzing/cc/fuzz_session.h" -#include "third_party/mediapipe/framework/port/parse_text_proto.h" -)include", - fuzz_namespace_begin); - - return fuzz_header; -} - -string FuzzerFileEnd() { - const string fuzz_footer = R"footer( -} // namespace fuzzing -} // namespace tensorflow -)footer"; - - return fuzz_footer; -} - -} // namespace - -string WriteFuzzers(const OpList& ops, const ApiDefMap& api_def_map) { +string WriteSingleFuzzer(const OpInfo& op_info, bool is_fuzzable) { return absl::StrCat( - FuzzerFileStart(), - absl::StrJoin( - ops.op(), "", - [&api_def_map](string* out, const OpDef& op_def) { - // Skip deprecated ops. - bool skip = op_def.has_deprecation() && - op_def.deprecation().version() <= TF_GRAPH_DEF_VERSION; - const auto* api_Def = api_def_map.GetApiDef(op_def.name()); - OpInfo op_info(op_def, *api_Def, std::vector()); - skip |= !OpFuzzingIsOk(op_info); - if (!skip) { - out->append(WriteClassFuzzDef(op_info)); - out->append(WriteFuzzTest(op_info)); - out->append("\n"); - } - }), - FuzzerFileEnd()); + FuzzerFileStart(), is_fuzzable ? WriteClassFuzzDef(op_info) : "", + is_fuzzable ? WriteFuzzTest(op_info) : "", FuzzerFileEnd()); } } // namespace cc_op diff --git a/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.h b/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.h index 6770430ad69..c11c9635d6d 100644 --- a/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.h +++ b/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CC_FRAMEWORK_FUZZING_CC_OP_FUZZ_GEN_H_ #define TENSORFLOW_CC_FRAMEWORK_FUZZING_CC_OP_FUZZ_GEN_H_ +#include "tensorflow/cc/framework/cc_op_gen_util.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/op_gen_lib.h" #include "tensorflow/core/platform/types.h" @@ -23,8 +24,11 @@ limitations under the License. namespace tensorflow { namespace cc_op { -/// String with fuzzer file contents. -string WriteFuzzers(const OpList& ops, const ApiDefMap& api_def_map); +// String with single fuzzer file content. 
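
A hedged sketch of driving the per-op entry points declared below; the EmitFuzzerFor wrapper and its output-directory handling are illustrative and simply mirror the loop in cc_op_fuzz_gen_main.cc later in this patch:

#include <string>
#include <vector>
#include "tensorflow/cc/framework/cc_op_gen_util.h"
#include "tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/platform/env.h"

// Write <Op>_fuzz.cc for one op; non-fuzzable ops still get a (mostly empty)
// file so the generated build targets stay uniform.
tensorflow::Status EmitFuzzerFor(const tensorflow::OpDef& op_def,
                                 const tensorflow::ApiDef& api_def,
                                 const std::string& out_dir) {
  tensorflow::cc_op::OpInfo op_info(op_def, api_def,
                                    std::vector<std::string>());
  const std::string source = tensorflow::cc_op::WriteSingleFuzzer(
      op_info, tensorflow::cc_op::OpFuzzingIsOk(op_info));
  return tensorflow::WriteStringToFile(
      tensorflow::Env::Default(), out_dir + "/" + op_def.name() + "_fuzz.cc",
      source);
}
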
+string WriteSingleFuzzer(const OpInfo& op_info, bool is_fuzzable); + +// Do we have all we need to create a fuzzer +bool OpFuzzingIsOk(const OpInfo& op_info); } // namespace cc_op } // namespace tensorflow diff --git a/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen_main.cc b/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen_main.cc index 0a1de103d37..99388eb8847 100644 --- a/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen_main.cc +++ b/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen_main.cc @@ -14,10 +14,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include +#include "absl/status/status.h" #include "tensorflow/cc/framework/cc_op_gen_util.h" #include "tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.h" #include "tensorflow/core/framework/op_def.pb.h" @@ -33,20 +35,33 @@ namespace tensorflow { namespace cc_op { namespace { -void WriteAllFuzzers(const std::string& file_name, bool include_internal, - const std::vector& api_def_dirs) { +void WriteAllFuzzers(string root_location, std::vector api_def_dirs, + std::vector op_names) { OpList ops; - StatusOr api_def_map = - LoadOpsAndApiDefs(ops, include_internal, api_def_dirs); + StatusOr api_def_map = LoadOpsAndApiDefs(ops, false, api_def_dirs); TF_CHECK_OK(api_def_map.status()); - WriteFuzzers(ops, api_def_map.value()); Env* env = Env::Default(); + tsl::Status status; std::unique_ptr fuzz_file = nullptr; - auto status = env->NewWritableFile(file_name, &fuzz_file); - status.Update(fuzz_file->Append(WriteFuzzers(ops, api_def_map.value()))); - status.Update(fuzz_file->Close()); + for (const OpDef& op_def : ops.op()) { + if (std::find(op_names.begin(), op_names.end(), op_def.name()) == + op_names.end()) + continue; + + const ApiDef* api_def = api_def_map->GetApiDef(op_def.name()); + if (api_def == nullptr) { + continue; + } + + OpInfo op_info(op_def, *api_def, std::vector()); + status.Update(env->NewWritableFile( + root_location + "/" + op_def.name() + "_fuzz.cc", &fuzz_file)); + status.Update( + fuzz_file->Append(WriteSingleFuzzer(op_info, OpFuzzingIsOk(op_info)))); + status.Update(fuzz_file->Close()); + } TF_CHECK_OK(status); } @@ -60,17 +75,17 @@ int main(int argc, char* argv[]) { for (int i = 1; i < argc; ++i) { fprintf(stderr, "Arg %d = %s\n", i, argv[i]); } - fprintf(stderr, - "Usage: %s out include_internal " - "api_def_dirs1,api_def_dir2 ...\n" - " include_internal: 1 means include internal ops\n", + fprintf(stderr, "Usage: %s location api_def1,api_def2 op1,op2,op3\n", argv[0]); exit(1); } - - bool include_internal = tensorflow::StringPiece("1") == argv[2]; - std::vector api_def_dirs = tensorflow::str_util::Split( + for (int i = 1; i < argc; ++i) { + fprintf(stdout, "Arg %d = %s\n", i, argv[i]); + } + std::vector api_def_srcs = tensorflow::str_util::Split( + argv[2], ",", tensorflow::str_util::SkipEmpty()); + std::vector op_names = tensorflow::str_util::Split( argv[3], ",", tensorflow::str_util::SkipEmpty()); - tensorflow::cc_op::WriteAllFuzzers(argv[1], include_internal, api_def_dirs); + tensorflow::cc_op::WriteAllFuzzers(argv[1], api_def_srcs, op_names); return 0; } diff --git a/tensorflow/cc/framework/fuzzing/op_fuzzing.bzl b/tensorflow/cc/framework/fuzzing/op_fuzzing.bzl index aac616f8928..2dfb4d08589 100644 --- a/tensorflow/cc/framework/fuzzing/op_fuzzing.bzl +++ 
b/tensorflow/cc/framework/fuzzing/op_fuzzing.bzl @@ -12,108 +12,160 @@ load( "cc_test", ) -def tf_gen_op_wrapper_fuzz( +def tf_gen_op_wrappers_fuzz( name, - out_ops_file, - pkg = "", - deps = None, - include_internal_ops = 0, - api_def_srcs = []): + op_def_src, + api_def_srcs = [], + kernel_deps = []): """ - Generates a file with fuzzers for a subset of ops. + Generates fuzzers for several groups of ops. + + For each one we need the corresponding OpDef, ApiDef and KernelDef, + since they all can contain constraints for the inputs. Args: - name: name of the op class - out_ops_file: prefix for file generation - pkg: where to find op registrations - deps: depedencies - include_internal_ops: true if we should generate internal ops - api_def_srcs: which op definitions to use + name: the name of the fuzz artifact + op_def_src: op definitions + api_def_srcs: api definitions + kernel_deps: op kernel dependencies """ - tool = out_ops_file + "_gen_fuzz" - if deps == None: - deps = [pkg + ":" + name + "_op_lib"] + # Create tool to generate .cc fuzzer files. tf_cc_binary( - name = tool, + name = "op_fuzz_gen_tool", copts = tf_copts(), linkopts = if_not_windows(["-lm", "-Wl,-ldl"]) + lrt_if_needed(), linkstatic = 1, # Faster to link this one-time-use binary dynamically deps = [ "//tensorflow/cc/framework/fuzzing:cc_op_fuzz_gen_main", - ] + deps, + op_def_src, + ] + kernel_deps, ) - srcs = api_def_srcs[:] + # Add relevant locations to look for api_defs. + api_def_src_locations = ",".join(["$$(dirname $$(echo $(locations " + api_def_src + ") | cut -d\" \" -f1))" for api_def_src in api_def_srcs]) - if not api_def_srcs: - api_def_args_str = "," - else: - api_def_args = [] - for api_def_src in api_def_srcs: - # Add directory of the first ApiDef source to args. - # We are assuming all ApiDefs in a single api_def_src are in the - # same directory. - api_def_args.append( - " $$(dirname $$(echo $(locations " + api_def_src + - ") | cut -d\" \" -f1))", - ) - api_def_args_str = ",".join(api_def_args) - - out_fuzz_file = out_ops_file + "_fuzz.cc" + out_fuzz_files = [op_name + "_fuzz.cc" for op_name in op_names] native.genrule( name = name + "_genrule", - outs = [ - out_fuzz_file, - ], - srcs = srcs, - tools = [":" + tool], # + tf_binary_additional_srcs(), - cmd = ("$(location :" + tool + ") $(location :" + out_fuzz_file + ") " + - str(include_internal_ops) + " " + api_def_args_str), + outs = out_fuzz_files, + srcs = api_def_srcs, + tools = [":op_fuzz_gen_tool"], + cmd = ("$(location :op_fuzz_gen_tool) " + + " $$(dirname $(location " + out_fuzz_files[0] + "))" + + " " + api_def_src_locations + " " + (",".join(op_names))), ) -def tf_gen_op_wrappers_fuzz( - name, - op_lib_names = [], - pkg = "", - deps = [ - "//tensorflow/cc:ops", - "//tensorflow/cc:scope", - "//tensorflow/cc:const_op", - ], - include_internal_ops = 0, - api_def_srcs = [], - extra_gen_deps = []): - """ - Generates fuzzers for several groups of ops. 
- - Args: - name: the name of the fuzz artifact - op_lib_names: which op libraries to fuzz - pkg: where to find op registrations - deps: dependencies - include_internal_ops: true if we should generate internal ops - api_def_srcs: where to find the op definitions - extra_gen_deps: extra dependencies for generation - """ - fuzzsrcs = [] - for n in op_lib_names: - tf_gen_op_wrapper_fuzz( - n, - "fuzzers/" + n, - api_def_srcs = api_def_srcs, - include_internal_ops = include_internal_ops, - pkg = pkg, - deps = [pkg + ":" + n + "_op_lib"] + extra_gen_deps, + for op_name in op_names: + cc_test( + name = op_name.lower() + "_fuzz", + srcs = [op_name + "_fuzz.cc"], + deps = kernel_deps + + [ + "//tensorflow/security/fuzzing/cc:fuzz_session", + "@com_google_googletest//:gtest_main", + "@com_google_fuzztest//fuzztest", + "//tensorflow/cc:cc_ops", + "//third_party/mediapipe/framework/port:parse_text_proto", + ], ) - fuzzsrcs.append("fuzzers/" + n + "_fuzz.cc") - cc_test( - name = name, - srcs = fuzzsrcs, - deps = deps + - [ - "//tensorflow/security/fuzzing/cc:fuzz_session", - "@com_google_googletest//:gtest_main", - "@com_google_fuzztest//fuzztest", - ], - ) + +op_names = [ + "BatchMatrixBandPart", + "BatchMatrixDiag", + "BatchMatrixDiagPart", + "BatchMatrixSetDiag", + "BatchToSpace", + "BatchToSpaceND", + "Bitcast", + "BroadcastArgs", + "BroadcastTo", + "CheckNumerics", + "ConcatV2", + "ConjugateTranspose", + "DebugGradientIdentity", + "DeepCopy", + "DepthToSpace", + "Dequantize", + "EditDistance", + "Empty", + "EnsureShape", + "ExpandDims", + "ExtractImagePatches", + "ExtractVolumePatches", + "FakeQuantWithMinMaxArgs", + "FakeQuantWithMinMaxArgsGradient", + "FakeQuantWithMinMaxVars", + "FakeQuantWithMinMaxVarsGradient", + "FakeQuantWithMinMaxVarsPerChannel", + "FakeQuantWithMinMaxVarsPerChannelGradient", + "Fill", + "Fingerprint", + "Gather", + "GuaranteeConst", + "Identity", + "IdentityN", + "InplaceAdd", + "InplaceSub", + "InplaceUpdate", + "InvertPermutation", + "ListDiff", + "MatrixBandPart", + "MatrixDiag", + "MatrixDiagPart", + "MatrixDiagPartV2", + "MatrixDiagPartV3", + "MatrixDiagV2", + "MatrixDiagV3", + "MatrixSetDiag", + "MatrixSetDiagV2", + "MatrixSetDiagV3", + "MirrorPad", + "OneHot", + "OnesLike", + "Pack", + "Pad", + "PadV2", + "ParallelConcat", + "PlaceholderV2", + "PlaceholderWithDefault", + "PreventGradient", + "QuantizeAndDequantize", + "QuantizeV2", + "Rank", + "Reshape", + "ResourceStridedSliceAssign", + "ReverseSequence", + "ReverseV2", + "ScatterNdNonAliasingAdd", + "Shape", + "ShapeN", + "Size", + "Slice", + "Snapshot", + "SpaceToBatch", + "SpaceToBatchND", + "SpaceToDepth", + "Split", + "SplitV", + "Squeeze", + "StopGradient", + "StridedSlice", + "StridedSliceGrad", + "TensorScatterAdd", + "TensorScatterMax", + "TensorScatterMin", + "TensorScatterSub", + "TensorStridedSliceUpdate", + "Tile", + "TileGrad", + "Transpose", + "Unique", + "UniqueV2", + "UniqueWithCounts", + "UniqueWithCountsV2", + "Unpack", + "UnravelIndex", + "Where", + "ZerosLike", +] diff --git a/tensorflow/cc/framework/grad_op_registry.h b/tensorflow/cc/framework/grad_op_registry.h index 0fc5abb20c8..951144cf8ce 100644 --- a/tensorflow/cc/framework/grad_op_registry.h +++ b/tensorflow/cc/framework/grad_op_registry.h @@ -16,7 +16,9 @@ limitations under the License. 
#ifndef TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_ #define TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_ +#include #include +#include #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" diff --git a/tensorflow/cc/framework/gradient_checker.cc b/tensorflow/cc/framework/gradient_checker.cc index 0013ea732df..0c026cf9a0c 100644 --- a/tensorflow/cc/framework/gradient_checker.cc +++ b/tensorflow/cc/framework/gradient_checker.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/cc/framework/gradient_checker.h" +#include +#include + #include "tensorflow/cc/client/client_session.h" #include "tensorflow/cc/framework/gradients.h" #include "tensorflow/cc/ops/standard_ops.h" diff --git a/tensorflow/cc/framework/gradient_checker.h b/tensorflow/cc/framework/gradient_checker.h index 1aa215a9088..b8db767f77c 100644 --- a/tensorflow/cc/framework/gradient_checker.h +++ b/tensorflow/cc/framework/gradient_checker.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_ #define TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_ +#include + #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/cc/framework/gradients.cc b/tensorflow/cc/framework/gradients.cc index cdb3d0c7d68..3dd2ab3ab82 100644 --- a/tensorflow/cc/framework/gradients.cc +++ b/tensorflow/cc/framework/gradients.cc @@ -16,6 +16,11 @@ limitations under the License. #include "tensorflow/cc/framework/gradients.h" #include +#include +#include +#include +#include +#include #include #include "tensorflow/cc/framework/grad_op_registry.h" @@ -35,9 +40,7 @@ namespace tensorflow { namespace { struct OutputHash { - uint64 operator()(const Output& x) const { - return x.hash(); - } + uint64 operator()(const Output& x) const { return x.hash(); } }; struct OutputEq { @@ -343,8 +346,8 @@ Status SymbolicGradientBuilder::Initialize() { Status SymbolicGradientBuilder::SumGradients(const Output& src, Output* grad) { auto iter = backprops_.find(src); if (iter == backprops_.end()) { - return errors::Internal( - "Unable to find backprop list for node.id ", src.node()->name()); + return errors::Internal("Unable to find backprop list for node.id ", + src.node()->name()); } const auto& grads = iter->second; // Filter any backpropped 'NoGradient' Outputs from 'grads' (if needed). @@ -378,8 +381,7 @@ bool SymbolicGradientBuilder::IsPrimitiveOpWithNoGrad(const string& opname) { } Status SymbolicGradientBuilder::CallGradFunction( - const Operation& op, - const std::vector& grad_inputs, + const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { ops::GradFunc grad_fn; TF_RETURN_IF_ERROR(registry_->Lookup(op.node()->type_string(), &grad_fn)); @@ -526,8 +528,8 @@ Status SymbolicGradientBuilder::AddGradients() { if (e->IsControlEdge()) continue; size_t dx_index = e->dst_input(); if (dx_index >= dx.size()) { - return errors::Internal( - "Invalid gradient output index: ", dx_index, " size: ", dx.size()); + return errors::Internal("Invalid gradient output index: ", dx_index, + " size: ", dx.size()); } TF_RETURN_IF_ERROR( BackpropAlongEdge(dx[dx_index], {e->src(), e->src_output()})); diff --git a/tensorflow/cc/framework/gradients.h b/tensorflow/cc/framework/gradients.h index 0a377ad56d1..d404bd34c4a 100644 --- a/tensorflow/cc/framework/gradients.h +++ b/tensorflow/cc/framework/gradients.h @@ -16,6 +16,8 @@ limitations under the License. 
#ifndef TENSORFLOW_CC_FRAMEWORK_GRADIENTS_H_ #define TENSORFLOW_CC_FRAMEWORK_GRADIENTS_H_ +#include + #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" diff --git a/tensorflow/cc/framework/ops.h b/tensorflow/cc/framework/ops.h index 08527b213e3..d19b895654b 100644 --- a/tensorflow/cc/framework/ops.h +++ b/tensorflow/cc/framework/ops.h @@ -16,7 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_CC_FRAMEWORK_OPS_H_ #define TENSORFLOW_CC_FRAMEWORK_OPS_H_ +#include #include +#include +#include #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" diff --git a/tensorflow/cc/framework/scope_internal.h b/tensorflow/cc/framework/scope_internal.h index 5db7eab2b81..586165ee4eb 100644 --- a/tensorflow/cc/framework/scope_internal.h +++ b/tensorflow/cc/framework/scope_internal.h @@ -16,6 +16,12 @@ limitations under the License. #ifndef TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_ #define TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_ +#include +#include +#include +#include +#include + #include "tensorflow/cc/framework/scope.h" namespace tensorflow { diff --git a/tensorflow/cc/framework/testutil.h b/tensorflow/cc/framework/testutil.h index 7ad6fb4a676..2464b491d79 100644 --- a/tensorflow/cc/framework/testutil.h +++ b/tensorflow/cc/framework/testutil.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_CC_FRAMEWORK_TESTUTIL_H_ #define TENSORFLOW_CC_FRAMEWORK_TESTUTIL_H_ +#include + #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" diff --git a/tensorflow/cc/framework/while_gradients.cc b/tensorflow/cc/framework/while_gradients.cc index a907fa9677a..e28306e5a33 100644 --- a/tensorflow/cc/framework/while_gradients.cc +++ b/tensorflow/cc/framework/while_gradients.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/cc/framework/while_gradients.h" +#include + #include "tensorflow/cc/framework/gradients.h" #include "tensorflow/cc/framework/scope_internal.h" #include "tensorflow/cc/ops/control_flow_ops_internal.h" diff --git a/tensorflow/cc/framework/while_gradients.h b/tensorflow/cc/framework/while_gradients.h index cb4e579c854..6d33d49dbb3 100644 --- a/tensorflow/cc/framework/while_gradients.h +++ b/tensorflow/cc/framework/while_gradients.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_ #define TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_ +#include + #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/core/graph/while_context.h" diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index b07af98fbef..18466df3691 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -17,6 +17,7 @@ load( ) package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:public"], licenses = ["notice"], ) @@ -121,6 +122,7 @@ cc_library( hdrs = ["loader.h"], deps = [ ":constants", + ":fingerprinting", ":loader_util", ":reader", ] + if_not_mobile([ @@ -142,6 +144,7 @@ cc_library( hdrs = ["bundle_v2.h"], deps = [ ":constants", + ":fingerprinting", ":metrics", ":reader", ":util", @@ -322,6 +325,7 @@ cc_library( visibility = [ "//tensorflow:__pkg__", "//tensorflow/python:__pkg__", + "//tensorflow/security/fuzzing/cc/ops:__pkg__", # TODO(b/261455394): Remove. 
], deps = if_not_mobile(["//tensorflow/core:lib"]) + if_android(["//tensorflow/core:portable_tensorflow_lib_lite"]), alwayslink = True, @@ -380,6 +384,7 @@ cc_library( visibility = [ "//tensorflow:__pkg__", "//tensorflow/python:__pkg__", + "//tensorflow/security/fuzzing/cc/ops:__pkg__", # TODO(b/261455394): Remove. ], deps = [ ":constants", diff --git a/tensorflow/cc/saved_model/bundle_v2.cc b/tensorflow/cc/saved_model/bundle_v2.cc index 4785f6f32be..0edc3469e89 100644 --- a/tensorflow/cc/saved_model/bundle_v2.cc +++ b/tensorflow/cc/saved_model/bundle_v2.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/cc/saved_model/constants.h" +#include "tensorflow/cc/saved_model/fingerprinting.h" #include "tensorflow/cc/saved_model/metrics.h" #include "tensorflow/cc/saved_model/reader.h" #include "tensorflow/cc/saved_model/util.h" @@ -121,6 +122,7 @@ Status SavedModelV2Bundle::Load(const std::string& export_dir, metrics::SavedModelReadApi(kCCLoadBundleV2Label).IncrementBy(1); SavedModel saved_model_proto; TF_RETURN_IF_ERROR(ReadSavedModelProto(export_dir, &saved_model_proto)); + metrics::SavedModelReadPath().Set(export_dir); // Load MetaGraphDef. // In version 2 SavedModels, there is only one MetaGraphDef. @@ -136,7 +138,8 @@ Status SavedModelV2Bundle::Load(const std::string& export_dir, // Correct the endiness of Tensor content on big-endian system if (!port::kLittleEndian) { - TF_RETURN_IF_ERROR(ByteSwapTensorContent(&(bundle->meta_graph_def_))); + TF_RETURN_IF_ERROR( + ByteSwapTensorContentInMetaGraphDef(&(bundle->meta_graph_def_))); } // Load GraphDebugInfo. @@ -163,6 +166,14 @@ Status SavedModelV2Bundle::Load(const std::string& export_dir, TF_RETURN_IF_ERROR(ReadCheckpointObjectGraph( bundle->variable_reader_.get(), &bundle->trackable_object_graph_)); } + // Read the fingerprint. + auto fingerprint_proto = + saved_model::fingerprinting::ReadSavedModelFingerprint(export_dir); + if (fingerprint_proto.ok()) { + // Set gauge cell with saved_model_checksum. + metrics::SavedModelReadFingerprint().Set( + std::to_string(fingerprint_proto->saved_model_checksum())); + } return OkStatus(); } diff --git a/tensorflow/cc/saved_model/bundle_v2_test.cc b/tensorflow/cc/saved_model/bundle_v2_test.cc index 506a4fbd9b4..a70244b04fd 100644 --- a/tensorflow/cc/saved_model/bundle_v2_test.cc +++ b/tensorflow/cc/saved_model/bundle_v2_test.cc @@ -28,6 +28,8 @@ namespace tensorflow { namespace { constexpr char kTestData[] = "cc/saved_model/testdata"; +// This is the value in testdata/VarsAndArithmeticObjectGraph/fingerprint.pb +constexpr char kV2ModuleSavedModelChecksum[] = "15788619162413586750"; class BundleV2Test : public ::testing::Test { protected: @@ -114,6 +116,10 @@ TEST_F(BundleV2Test, UpdatesMetrics) { EXPECT_EQ(metrics::SavedModelRead("2").value(), read_count + 1); EXPECT_EQ(metrics::SavedModelReadApi(kCCLoadBundleV2Label).value(), api_count + 1); + // Check that the gauge contains the fingerprint. 
+ EXPECT_EQ(metrics::SavedModelReadFingerprint().value(), + kV2ModuleSavedModelChecksum); + EXPECT_EQ(metrics::SavedModelReadPath().value(), export_dir); } } // namespace diff --git a/tensorflow/cc/saved_model/experimental/public/BUILD b/tensorflow/cc/saved_model/experimental/public/BUILD index 2b91a34d538..a26eabfe5ec 100644 --- a/tensorflow/cc/saved_model/experimental/public/BUILD +++ b/tensorflow/cc/saved_model/experimental/public/BUILD @@ -4,6 +4,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], # This is intentionally public default_visibility = [ "//visibility:public", diff --git a/tensorflow/cc/saved_model/experimental/tests/BUILD b/tensorflow/cc/saved_model/experimental/tests/BUILD index ebdcdf02887..3818412b19f 100644 --- a/tensorflow/cc/saved_model/experimental/tests/BUILD +++ b/tensorflow/cc/saved_model/experimental/tests/BUILD @@ -2,6 +2,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) diff --git a/tensorflow/cc/saved_model/fingerprinting.cc b/tensorflow/cc/saved_model/fingerprinting.cc index 90eda79971c..7d2893f6199 100644 --- a/tensorflow/cc/saved_model/fingerprinting.cc +++ b/tensorflow/cc/saved_model/fingerprinting.cc @@ -16,7 +16,9 @@ limitations under the License. #include "tensorflow/cc/saved_model/fingerprinting.h" #include +#include #include +#include #include "absl/container/btree_map.h" #include "absl/strings/strip.h" @@ -36,13 +38,20 @@ limitations under the License. #include "tensorflow/core/protobuf/saved_model.pb.h" #include "tensorflow/core/protobuf/saved_object_graph.pb.h" #include "tensorflow/core/util/tensor_bundle/naming.h" +#include "tensorflow/tsl/lib/strings/proto_serialization.h" namespace tensorflow::saved_model::fingerprinting { // Version of the code that produced the fingerprint. -const int kFingerprintProducer = 0; +const int kFingerprintProducer = 1; namespace { +uint64 HashSavedModel(const SavedModel& saved_model) { + std::string saved_model_string; + SerializeToStringDeterministic(saved_model, &saved_model_string); + return tensorflow::Fingerprint64(saved_model_string); +} + uint64 RegularizeAndHashSignatureDefs( const google::protobuf::Map& signature_def_map) { // Sort `signature_def_map`, which is an unordered map from string keys to @@ -113,15 +122,14 @@ uint64 HashCheckpointIndexFile(absl::string_view model_dir) { } // namespace -FingerprintDef CreateFingerprintDef(const MetaGraphDef& metagraph, +FingerprintDef CreateFingerprintDef(const SavedModel& saved_model, absl::string_view export_dir) { // Create a copy of `metagraph` which will be used and mutated for fingerprint // computation. - MetaGraphDef metagraph_copy = metagraph; + MetaGraphDef metagraph_copy = saved_model.meta_graphs(0); FingerprintDef fingerprint_def; // Set fingerprint field #1. - fingerprint_def.set_graph_def_checksum( - graph_regularization::ComputeHash(metagraph_copy.graph_def())); + fingerprint_def.set_saved_model_checksum(HashSavedModel(saved_model)); // Set fingerprint field #2. 
   graph_regularization::SimpleDelete(*metagraph_copy.mutable_graph_def());
   fingerprint_def.set_graph_def_program_hash(
@@ -143,4 +151,35 @@ FingerprintDef CreateFingerprintDef(const MetaGraphDef& metagraph,
   return fingerprint_def;
 }
 
+StatusOr<FingerprintDef> ReadSavedModelFingerprint(
+    absl::string_view export_dir) {
+  const string fingerprint_pb_path =
+      io::JoinPath(export_dir, kFingerprintFilenamePb);
+  Status found_pb = Env::Default()->FileExists(fingerprint_pb_path);
+  if (found_pb.ok()) {
+    FingerprintDef fingerprint_proto;
+    Status result = ReadBinaryProto(Env::Default(), fingerprint_pb_path,
+                                    &fingerprint_proto);
+    if (result.ok()) {
+      return fingerprint_proto;
+    }
+    return result;
+  }
+  return found_pb;
+}
+
+std::unordered_map<std::string, uint64> MakeFingerprintMap(
+    const FingerprintDef& fingerprint) {
+  std::unordered_map<std::string, uint64> fingerprint_map;
+  fingerprint_map["saved_model_checksum"] = fingerprint.saved_model_checksum();
+  fingerprint_map["graph_def_program_hash"] =
+      fingerprint.graph_def_program_hash();
+  fingerprint_map["signature_def_hash"] = fingerprint.signature_def_hash();
+  fingerprint_map["saved_object_graph_hash"] =
+      fingerprint.saved_object_graph_hash();
+  fingerprint_map["checkpoint_hash"] = fingerprint.checkpoint_hash();
+  fingerprint_map["version"] = fingerprint.version().producer();
+  return fingerprint_map;
+}
+
 }  // namespace tensorflow::saved_model::fingerprinting
diff --git a/tensorflow/cc/saved_model/fingerprinting.h b/tensorflow/cc/saved_model/fingerprinting.h
index a827e1cb32d..15790ed61e9 100644
--- a/tensorflow/cc/saved_model/fingerprinting.h
+++ b/tensorflow/cc/saved_model/fingerprinting.h
@@ -16,17 +16,30 @@ limitations under the License.
 #ifndef TENSORFLOW_CC_SAVED_MODEL_FINGERPRINTING_H_
 #define TENSORFLOW_CC_SAVED_MODEL_FINGERPRINTING_H_
 
+#include <string>
+#include <unordered_map>
+
 #include "absl/strings/string_view.h"
+#include "tensorflow/core/platform/statusor.h"
 #include "tensorflow/core/protobuf/fingerprint.pb.h"
-#include "tensorflow/core/protobuf/meta_graph.pb.h"
+#include "tensorflow/core/protobuf/saved_model.pb.h"
 
 namespace tensorflow::saved_model::fingerprinting {
 
-// Creates a FingerprintDef proto from a MetaGraph and the checkpoint meta file
+// Creates a FingerprintDef proto from a SavedModel and the checkpoint meta file
 // (.index) in `export_dir`.
-FingerprintDef CreateFingerprintDef(const MetaGraphDef& metagraph,
+FingerprintDef CreateFingerprintDef(const SavedModel& saved_model,
                                     absl::string_view export_dir);
 
+// Loads the `fingerprint.pb` from `export_dir`, returns an error if there is
+// none.
+StatusOr<FingerprintDef> ReadSavedModelFingerprint(
+    absl::string_view export_dir);
+
+// Converts the fingerprint into a dictionary mapping field names to values.
+std::unordered_map<std::string, uint64> MakeFingerprintMap(
+    const FingerprintDef& fingerprint);
+
 }  // namespace tensorflow::saved_model::fingerprinting
 
 #endif  // TENSORFLOW_CC_SAVED_MODEL_FINGERPRINTING_H_
diff --git a/tensorflow/cc/saved_model/fingerprinting_test.cc b/tensorflow/cc/saved_model/fingerprinting_test.cc
index ee337f4337d..0db31bbf17a 100644
--- a/tensorflow/cc/saved_model/fingerprinting_test.cc
+++ b/tensorflow/cc/saved_model/fingerprinting_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/fingerprint.pb.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" #include "tensorflow/core/protobuf/saved_model.pb.h" @@ -52,9 +53,9 @@ TEST(FingerprintingTest, TestCreateFingerprint) { TF_ASSERT_OK_AND_ASSIGN(SavedModel saved_model_pb, ReadSavedModel(export_dir)); FingerprintDef fingerprint_def = - CreateFingerprintDef(saved_model_pb.meta_graphs(0), export_dir); + CreateFingerprintDef(saved_model_pb, export_dir); - EXPECT_GT(fingerprint_def.graph_def_checksum(), 0); + EXPECT_GT(fingerprint_def.saved_model_checksum(), 0); EXPECT_EQ(fingerprint_def.graph_def_program_hash(), 10127142238652115842U); EXPECT_EQ(fingerprint_def.signature_def_hash(), 5693392539583495303); EXPECT_EQ(fingerprint_def.saved_object_graph_hash(), 3678101440349108924); @@ -72,14 +73,14 @@ TEST(FingerprintingTest, TestCompareFingerprintForTwoModelSavedTwice) { TF_ASSERT_OK_AND_ASSIGN(SavedModel saved_model_pb, ReadSavedModel(export_dir)); FingerprintDef fingerprint_def = - CreateFingerprintDef(saved_model_pb.meta_graphs(0), export_dir); + CreateFingerprintDef(saved_model_pb, export_dir); const std::string export_dir2 = io::JoinPath( testing::TensorFlowSrcRoot(), "cc/saved_model/testdata", "bert2"); TF_ASSERT_OK_AND_ASSIGN(SavedModel saved_model_pb2, ReadSavedModel(export_dir2)); FingerprintDef fingerprint_def2 = - CreateFingerprintDef(saved_model_pb2.meta_graphs(0), export_dir2); + CreateFingerprintDef(saved_model_pb2, export_dir2); EXPECT_EQ(fingerprint_def.graph_def_program_hash(), fingerprint_def2.graph_def_program_hash()); @@ -95,12 +96,12 @@ TEST(FingerprintingTest, TestFingerprintComputationDoesNotMutateModel) { TF_ASSERT_OK_AND_ASSIGN(SavedModel saved_model_pb, ReadSavedModel(export_dir)); FingerprintDef fingerprint_def = - CreateFingerprintDef(saved_model_pb.meta_graphs(0), export_dir); + CreateFingerprintDef(saved_model_pb, export_dir); FingerprintDef fingerprint_def2 = - CreateFingerprintDef(saved_model_pb.meta_graphs(0), export_dir); + CreateFingerprintDef(saved_model_pb, export_dir); - EXPECT_EQ(fingerprint_def.graph_def_checksum(), - fingerprint_def2.graph_def_checksum()); + EXPECT_EQ(fingerprint_def.saved_model_checksum(), + fingerprint_def2.saved_model_checksum()); } TEST(FingerprintingTest, TestFingerprintHasVersion) { @@ -109,8 +110,8 @@ TEST(FingerprintingTest, TestFingerprintHasVersion) { TF_ASSERT_OK_AND_ASSIGN(SavedModel saved_model_pb, ReadSavedModel(export_dir)); FingerprintDef fingerprint_def = - CreateFingerprintDef(saved_model_pb.meta_graphs(0), export_dir); - EXPECT_EQ(fingerprint_def.version().producer(), 0); + CreateFingerprintDef(saved_model_pb, export_dir); + EXPECT_EQ(fingerprint_def.version().producer(), 1); } TEST(FingerprintingTest, TestHashCheckpointForModelWithNoVariables) { @@ -119,9 +120,35 @@ TEST(FingerprintingTest, TestHashCheckpointForModelWithNoVariables) { TF_ASSERT_OK_AND_ASSIGN(SavedModel saved_model_pb, ReadSavedModel(export_dir)); FingerprintDef fingerprint_def = - CreateFingerprintDef(saved_model_pb.meta_graphs(0), export_dir); + CreateFingerprintDef(saved_model_pb, export_dir); EXPECT_EQ(fingerprint_def.checkpoint_hash(), 0); } +TEST(FingerprintingTest, TestReadValidFingerprint) { + const std::string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), "cc/saved_model/testdata", + "VarsAndArithmeticObjectGraph"); + TF_ASSERT_OK_AND_ASSIGN(FingerprintDef fingerprint_pb, + 
ReadSavedModelFingerprint(export_dir)); + EXPECT_EQ(fingerprint_pb.saved_model_checksum(), 15788619162413586750u); +} + +TEST(FingerprintingTest, TestReadNonexistentFingerprint) { + const std::string export_dir = io::JoinPath( + testing::TensorFlowSrcRoot(), "cc/saved_model/testdata", "AssetModule"); + EXPECT_FALSE(ReadSavedModelFingerprint(export_dir).ok()); +} + +TEST(FingerprintingTest, TestMakeFingerprintMap) { + const std::string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), "cc/saved_model/testdata", + "VarsAndArithmeticObjectGraph"); + TF_ASSERT_OK_AND_ASSIGN(FingerprintDef fingerprint_pb, + ReadSavedModelFingerprint(export_dir)); + auto fingerprint_map = MakeFingerprintMap(fingerprint_pb); + EXPECT_EQ(fingerprint_pb.saved_model_checksum(), + fingerprint_map["saved_model_checksum"]); +} + } // namespace } // namespace tensorflow::saved_model::fingerprinting diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index 2f87d6da6fe..75869afe687 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -15,9 +15,11 @@ limitations under the License. #include "tensorflow/cc/saved_model/loader.h" +#include #include #include "tensorflow/cc/saved_model/constants.h" +#include "tensorflow/cc/saved_model/fingerprinting.h" #include "tensorflow/cc/saved_model/loader_util.h" #include "tensorflow/cc/saved_model/metrics.h" #include "tensorflow/cc/saved_model/reader.h" @@ -63,8 +65,8 @@ auto* load_latency_by_stage = monitoring::Sampler<2>::New( "model_path", "stage", }, - // Scale of 10, power of 1.8 with bucket count 33 (~20 minutes). - monitoring::Buckets::Exponential(10, 1.8, 33)); + // Scale of 10, power of 1.8 with bucket count 37 (~258 minutes). + monitoring::Buckets::Exponential(10, 1.8, 37)); constexpr char kLoadAttemptFail[] = "fail"; constexpr char kLoadAttemptSuccess[] = "success"; @@ -296,6 +298,13 @@ Status LoadSavedModel(const SessionOptions& session_options, const std::unordered_set& tags, SavedModelBundle* const bundle) { metrics::SavedModelReadApi(kCCLoadLabel).IncrementBy(1); + auto fingerprint_proto = + saved_model::fingerprinting::ReadSavedModelFingerprint(export_dir); + if (fingerprint_proto.ok()) { + // Set gauge cell with saved_model_checksum. + metrics::SavedModelReadFingerprint().Set( + std::to_string(fingerprint_proto->saved_model_checksum())); + } // TODO(robson): Add tests for the counters. const uint64 start_microseconds = Env::Default()->NowMicros(); @@ -309,6 +318,7 @@ Status LoadSavedModel(const SessionOptions& session_options, }; if (status.ok()) { log_and_count(kLoadAttemptSuccess); + metrics::SavedModelReadPath().Set(export_dir); } else { log_and_count(kLoadAttemptFail); } diff --git a/tensorflow/cc/saved_model/metrics.cc b/tensorflow/cc/saved_model/metrics.cc index fc04e2c7725..14e84c93510 100644 --- a/tensorflow/cc/saved_model/metrics.cc +++ b/tensorflow/cc/saved_model/metrics.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/core/lib/monitoring/counter.h" +#include "tensorflow/core/lib/monitoring/gauge.h" #include "tensorflow/core/lib/monitoring/sampler.h" namespace tensorflow { @@ -51,6 +52,29 @@ auto* saved_model_read_api = monitoring::Counter<1>::New( "/tensorflow/core/saved_model/read/api", "The API used to load the SavedModel.", "api_label"); +// Gauge that contains the fingerprint (saved_model_checksum) of the newly +// written SavedModel. 
+auto* saved_model_write_fingerprint = monitoring::Gauge<std::string, 0>::New(
+    "/tensorflow/core/saved_model/write/fingerprint",
+    "The fingerprint (saved_model_checksum) of the exported SavedModel.");
+
+// Gauge that contains the path (saved_model_path) of the newly written
+// SavedModel.
+auto* saved_model_write_path = monitoring::Gauge<std::string, 0>::New(
+    "/tensorflow/core/saved_model/write/path",
+    "The path (saved_model_path) of the exported SavedModel.");
+
+// Gauge that contains the fingerprint (saved_model_checksum) of the loaded
+// SavedModel.
+auto* saved_model_read_fingerprint = monitoring::Gauge<std::string, 0>::New(
+    "/tensorflow/core/saved_model/read/fingerprint",
+    "The fingerprint (saved_model_checksum) of the loaded SavedModel.");
+
+// Gauge that contains the path (saved_model_path) of the loaded SavedModel.
+auto* saved_model_read_path = monitoring::Gauge<std::string, 0>::New(
+    "/tensorflow/core/saved_model/read/path",
+    "The path (saved_model_path) of the loaded SavedModel.");
+
 // Distribution of checkpoint write durations.
 auto* checkpoint_write_durations = monitoring::Sampler<1>::New(
     {
@@ -121,6 +145,22 @@ monitoring::CounterCell& SavedModelReadApi(absl::string_view api_label) {
   return *saved_model_read_api->GetCell(std::string(api_label));
 }
 
+monitoring::GaugeCell<std::string>& SavedModelReadFingerprint() {
+  return *saved_model_read_fingerprint->GetCell();
+}
+
+monitoring::GaugeCell<std::string>& SavedModelReadPath() {
+  return *saved_model_read_path->GetCell();
+}
+
+monitoring::GaugeCell<std::string>& SavedModelWriteFingerprint() {
+  return *saved_model_write_fingerprint->GetCell();
+}
+
+monitoring::GaugeCell<std::string>& SavedModelWritePath() {
+  return *saved_model_write_path->GetCell();
+}
+
 monitoring::SamplerCell& CheckpointReadDuration(absl::string_view api_label) {
   return *checkpoint_read_durations->GetCell(std::string(api_label));
 }
diff --git a/tensorflow/cc/saved_model/metrics.h b/tensorflow/cc/saved_model/metrics.h
index 7ae41285494..4971bb25077 100644
--- a/tensorflow/cc/saved_model/metrics.h
+++ b/tensorflow/cc/saved_model/metrics.h
@@ -23,6 +23,7 @@ limitations under the License.
 #include 
 
 #include "tensorflow/core/lib/monitoring/counter.h"
+#include "tensorflow/core/lib/monitoring/gauge.h"
 #include "tensorflow/core/lib/monitoring/sampler.h"
 
 namespace tensorflow {
@@ -40,6 +41,22 @@ monitoring::CounterCell& SavedModelWrite(absl::string_view write_version);
 // incremented when a SavedModel has been successfully read.
 monitoring::CounterCell& SavedModelRead(absl::string_view write_version);
 
+// Returns "/tensorflow/core/saved_model/write/fingerprint" cell, which contains
+// the saved_model_checksum of the SM's fingerprint when it is exported.
+monitoring::GaugeCell<std::string>& SavedModelWriteFingerprint();
+
+// Returns "/tensorflow/core/saved_model/write/path" cell, which contains
+// the saved_model_path of the SM when it is exported.
+monitoring::GaugeCell<std::string>& SavedModelWritePath();
+
+// Returns "/tensorflow/core/saved_model/read/fingerprint" cell, which contains
+// the saved_model_checksum of the SM's fingerprint when it is imported.
+monitoring::GaugeCell<std::string>& SavedModelReadFingerprint();
+
+// Returns "/tensorflow/core/saved_model/read/path" cell, which contains
+// the saved_model_path of the SM when it is imported.
+monitoring::GaugeCell<std::string>& SavedModelReadPath();
+
 // Returns "/tensorflow/core/saved_model/write/api" cell. This metric has 1
 // field "api_label" which corresponds to a SavedModel write API. The cell for
 // `foo` should be incremented when the write API `foo` is called.
diff --git a/tensorflow/cc/saved_model/metrics_test.cc b/tensorflow/cc/saved_model/metrics_test.cc index 1f6d9a8658f..b88af2f73f1 100644 --- a/tensorflow/cc/saved_model/metrics_test.cc +++ b/tensorflow/cc/saved_model/metrics_test.cc @@ -73,5 +73,37 @@ TEST(MetricsTest, TestCheckpointSize) { EXPECT_EQ(CheckpointSize("foo", 10).value(), 1); } +TEST(MetricsTest, TestWriteFingerprint) { + EXPECT_EQ(SavedModelWriteFingerprint().value(), ""); + SavedModelWriteFingerprint().Set("foo"); + EXPECT_EQ(SavedModelWriteFingerprint().value(), "foo"); + SavedModelWriteFingerprint().Set("bar"); + EXPECT_EQ(SavedModelWriteFingerprint().value(), "bar"); +} + +TEST(MetricsTest, TestWritePath) { + EXPECT_EQ(SavedModelWritePath().value(), ""); + SavedModelWritePath().Set("foo"); + EXPECT_EQ(SavedModelWritePath().value(), "foo"); + SavedModelWritePath().Set("bar"); + EXPECT_EQ(SavedModelWritePath().value(), "bar"); +} + +TEST(MetricsTest, TestReadFingerprint) { + EXPECT_EQ(SavedModelReadFingerprint().value(), ""); + SavedModelReadFingerprint().Set("foo"); + EXPECT_EQ(SavedModelReadFingerprint().value(), "foo"); + SavedModelReadFingerprint().Set("bar"); + EXPECT_EQ(SavedModelReadFingerprint().value(), "bar"); +} + +TEST(MetricsTest, TestReadPath) { + EXPECT_EQ(SavedModelReadPath().value(), ""); + SavedModelReadPath().Set("foo"); + EXPECT_EQ(SavedModelReadPath().value(), "foo"); + SavedModelReadPath().Set("bar"); + EXPECT_EQ(SavedModelReadPath().value(), "bar"); +} + } // namespace metrics } // namespace tensorflow diff --git a/tensorflow/cc/saved_model/python/BUILD b/tensorflow/cc/saved_model/python/BUILD index f59f381bb58..7d4737e2c3f 100644 --- a/tensorflow/cc/saved_model/python/BUILD +++ b/tensorflow/cc/saved_model/python/BUILD @@ -4,6 +4,7 @@ load("//tensorflow/core/platform:build_config.bzl", "tf_py_clif_cc") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:public"], licenses = ["notice"], ) diff --git a/tensorflow/cc/saved_model/reader.cc b/tensorflow/cc/saved_model/reader.cc index ba5672e982c..7dbe9c97504 100644 --- a/tensorflow/cc/saved_model/reader.cc +++ b/tensorflow/cc/saved_model/reader.cc @@ -99,7 +99,7 @@ Status FindMetaGraphDef(const std::unordered_set& tags, *meta_graph_def = std::move(graph_def); // Correct the endiness of Tensor content on big-endian system if (!port::kLittleEndian) { - TF_RETURN_IF_ERROR(ByteSwapTensorContent(meta_graph_def)); + TF_RETURN_IF_ERROR(ByteSwapTensorContentInMetaGraphDef(meta_graph_def)); } return OkStatus(); } diff --git a/tensorflow/cc/saved_model/saved_model_bundle_test.cc b/tensorflow/cc/saved_model/saved_model_bundle_test.cc index 6d17f0663de..c89b6c5736b 100644 --- a/tensorflow/cc/saved_model/saved_model_bundle_test.cc +++ b/tensorflow/cc/saved_model/saved_model_bundle_test.cc @@ -50,6 +50,10 @@ constexpr char kTestFuzzGeneratedBadNodeAttr[] = "cc/saved_model/testdata/fuzz_generated/bad_node_attr"; constexpr char kTestCyclicModule[] = "cc/saved_model/testdata/CyclicModule"; constexpr char kTestSimpleV1Model[] = "cc/saved_model/testdata/SimpleV1Model"; +constexpr char kVarsAndArithmeticObjectGraph[] = + "cc/saved_model/testdata/VarsAndArithmeticObjectGraph"; +// This is the value in testdata/VarsAndArithmeticObjectGraph/fingerprint.pb +constexpr char kV2ModuleSavedModelChecksum[] = "15788619162413586750"; class LoaderTest : public ::testing::Test { protected: @@ -385,5 +389,21 @@ TEST_F(LoaderTest, UpdateMetricsV1) { EXPECT_EQ(metrics::SavedModelReadApi(kCCLoadLabel).value(), api_count + 
1); } +TEST_F(LoaderTest, UpdateFingerprintMetrics) { + SavedModelBundle bundle; + SessionOptions session_options; + RunOptions run_options; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kVarsAndArithmeticObjectGraph); + TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir, + {kSavedModelTagServe}, &bundle)); + + EXPECT_EQ(metrics::SavedModelReadPath().value(), export_dir); + + EXPECT_EQ(metrics::SavedModelReadFingerprint().value(), + kV2ModuleSavedModelChecksum); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/fingerprint.pb b/tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/fingerprint.pb new file mode 100644 index 00000000000..a5b79c3c288 --- /dev/null +++ b/tensorflow/cc/saved_model/testdata/VarsAndArithmeticObjectGraph/fingerprint.pb @@ -0,0 +1 @@ +쟯 ˿O ɧ(ە2 \ No newline at end of file diff --git a/tensorflow/cc/tools/BUILD b/tensorflow/cc/tools/BUILD index 8d1e1602dba..510e7f589fd 100644 --- a/tensorflow/cc/tools/BUILD +++ b/tensorflow/cc/tools/BUILD @@ -8,6 +8,7 @@ load( ) package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:public"], licenses = ["notice"], ) diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index c214d7f5706..c0ea0f038c4 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -4,6 +4,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") load("//tensorflow/core/platform:build_config.bzl", "if_llvm_aarch64_available", "if_llvm_system_z_available") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:private"], licenses = ["notice"], ) @@ -40,7 +41,10 @@ cc_library( defines = if_llvm_aarch64_available(["TF_LLVM_AARCH64_AVAILABLE=1"]) + if_llvm_system_z_available([ "TF_LLVM_S390X_AVAILABLE=1", ]), - visibility = ["//tensorflow/python:__pkg__"], + visibility = [ + "//tensorflow:__pkg__", + "//tensorflow/python:__pkg__", + ], deps = [ ":aot_only_var_handle_op", ":embedded_protocol_buffers", diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.cc b/tensorflow/compiler/aot/embedded_protocol_buffers.cc index f96bd0a9189..849a5227349 100644 --- a/tensorflow/compiler/aot/embedded_protocol_buffers.cc +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/aot/embedded_protocol_buffers.h" #include +#include #include #include "absl/memory/memory.h" @@ -106,7 +107,7 @@ GetTargetMachineFromTriple(absl::string_view target_triple) { return absl::WrapUnique(target->createTargetMachine( normalized_triple, /*CPU=*/"", - /*Features=*/"", llvm::TargetOptions(), llvm::None)); + /*Features=*/"", llvm::TargetOptions(), std::nullopt)); } StatusOr CreateEmbeddedProtocolBuffers( @@ -116,8 +117,8 @@ StatusOr CreateEmbeddedProtocolBuffers( GetTargetMachineFromTriple(target_triple)); llvm::LLVMContext llvm_context; - std::unique_ptr module_with_serialized_proto = - absl::make_unique("embedded_data_module", llvm_context); + auto module_with_serialized_proto = + std::make_unique("embedded_data_module", llvm_context); EmbeddedProtocolBuffers result; diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index 23357df0883..e145ef37107 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -4,6 +4,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:private"], licenses = ["notice"], ) diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index d6ed1108fd4..793e2890454 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -10,6 +10,7 @@ load( ) package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ ":internal", "//third_party/cloud_tpu/inference_converter:__pkg__", @@ -151,9 +152,9 @@ cc_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/compiler/xla/stream_executor/gpu:gpu_init", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", - "//tensorflow/core/common_runtime/gpu:gpu_init", ] + if_libtpu( if_false = [ "//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep @@ -300,8 +301,13 @@ cc_library( # Public visibility is needed for external TF/XLA backends. visibility = ["//visibility:public"], deps = XLA_DEVICE_DEPS + [ + ":device_compilation_cache", + ":device_compilation_profiler", + ":device_compiler_client", + ":device_executable_persistor", ":flags_headers", - ":xla_compilation_cache", + ":device_compiler", + ":xla_device_compiler_client", "//tensorflow/core/tpu:tpu_defs", ], alwayslink = 1, @@ -318,6 +324,7 @@ cc_library( # Public visibility is needed for external TF/XLA backends. visibility = ["//visibility:public"], deps = XLA_DEVICE_DEPS + [ + ":device_compilation_profiler", ":jit_compilation_passes", ":xla_device_no_jit_rewrite_registration", ], @@ -401,16 +408,17 @@ cc_library( ":internal", # We reuse VariableInfo in TFRT's implementation of TpuExecuteOp. 
"//learning/brain/tfrt/tf_tpu:__pkg__", + "//learning/brain/tfrt/tpu_plugin:__pkg__", "//learning/brain/tfrt/tpu_common:__pkg__", "//tensorflow/core/common_runtime/next_pluggable_device:__pkg__", ], deps = [ ":common", - ":xla_compilation_cache", ":xla_tensor", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", @@ -438,9 +446,10 @@ cc_library( "//tensorflow/core/common_runtime/next_pluggable_device:__pkg__", ], deps = [ - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/xla/client:executable_build_options", - "//tensorflow/core:framework", + ":flags_headers", + "//tensorflow/compiler/tf2xla:xla_argument", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core/util:determinism", ], ) @@ -451,6 +460,7 @@ tf_cc_test( ], deps = [ ":xla_compile_util", + "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/core:test_main", "//tensorflow/core/framework:fake_input", "//tensorflow/core/kernels:identity_op", @@ -468,76 +478,46 @@ tf_proto_library( ) cc_library( - name = "xla_compilation_cache", - srcs = ["xla_compilation_cache.cc"], - hdrs = ["xla_compilation_cache.h"], + name = "device_compiler", + hdrs = ["device_compiler.h"], copts = tf_copts(), + visibility = [ + ":internal", + "//tensorflow/core/common_runtime/next_pluggable_device:__pkg__", + ], deps = [ - ":flags", - ":xla_activity_listener", - ":xla_activity_proto_cc", - ":xla_cluster_util", - ":xla_compilation_cache_proto_cc", + ":device_compilation_cache", + ":device_compilation_cluster_signature", + ":device_compilation_profiler", + ":device_compiler_client", + ":device_executable_persistor", + ":flags_headers", + ":tf_graph_to_hlo_compiler", ":xla_compile_util", - "//tensorflow/compiler/mlir:array_container_utils", - "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes", - "//tensorflow/compiler/mlir/tf2xla:mlir_bridge_rollout_policy", - "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/tf2xla:xla_context", - "//tensorflow/compiler/xla:protobuf_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/service:compiler", - "//tensorflow/compiler/xla/service:hlo_proto_cc", - "//tensorflow/core:core_cpu", - "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", + "//tensorflow/core:framework_lite", "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/platform:logging", - "//tensorflow/core/tpu:tpu_defs", + "//tensorflow/core/platform:thread_annotations", "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", - "@com_google_absl//absl/types:variant", ], ) tf_cc_test( - name = "xla_compilation_cache_test", + name = "device_compiler_disable_test", srcs = [ - "xla_compilation_cache_test.cc", + "device_compiler_disable_test.cc", ], deps = [ + ":device_compilation_profiler", + 
":device_compiler", ":flags", - ":xla_compilation_cache", ":xla_cpu_jit", - "//tensorflow/compiler/tf2xla:common", - "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - -tf_cc_test( - name = "xla_compilation_cache_disable_test", - srcs = [ - "xla_compilation_cache_disable_test.cc", - ], - deps = [ - ":flags", - ":xla_compilation_cache", - ":xla_cpu_jit", - "//tensorflow/compiler/tf2xla:common", + ":xla_device_compiler_client", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/core:test", @@ -564,20 +544,21 @@ cc_library( hdrs = ["get_compiler_ir.h"], visibility = [":internal"], deps = [ - ":common", ":compilability_check_util", - ":flags", + ":device_compiler", ":xla_device_no_jit_rewrite_registration", ":xla_launch_util", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/client:executable_build_options", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:hlo_graph_dumper", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "//tensorflow/core/common_runtime:core_cpu_internal", "//tensorflow/core/common_runtime/eager:tensor_handle", - "@com_google_absl//absl/memory", + "//tensorflow/tsl/platform:status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -592,10 +573,9 @@ cc_library( textual_hdrs = ["get_compiler_ir.h"], visibility = [":internal"], deps = [ - "//tensorflow/compiler/xla:statusor", - "@com_google_absl//absl/memory", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/core/platform:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", ], ) @@ -812,7 +792,6 @@ cc_library( "extract_outside_compilation_pass.cc", "force_xla_constants_on_host_pass.cc", "increase_dynamism_for_auto_jit_pass.cc", - "introduce_floating_point_jitter_pass.cc", "mark_for_compilation_pass.cc", "mark_for_compilation_pass_test_helper.cc", "partially_decluster_pass.cc", @@ -828,7 +807,6 @@ cc_library( "extract_outside_compilation_pass.h", "force_xla_constants_on_host_pass.h", "increase_dynamism_for_auto_jit_pass.h", - "introduce_floating_point_jitter_pass.h", "mark_for_compilation_pass.h", "mark_for_compilation_pass_test_helper.h", "partially_decluster_pass.h", @@ -998,8 +976,6 @@ tf_cc_test( "extract_outside_compilation_pass_test.cc", "force_xla_constants_on_host_pass_test.cc", "increase_dynamism_for_auto_jit_pass_test.cc", - "introduce_floating_point_jitter_pass_internal.h", - "introduce_floating_point_jitter_pass_test.cc", "mark_for_compilation_pass_test.cc", "partially_decluster_pass_test.cc", "rearrange_function_argument_pass_test.cc", @@ -1243,3 +1219,205 @@ cc_library( ], alwayslink = 1, ) + +cc_library( + name = "tf_to_hlo_compiler", + hdrs = ["tf_to_hlo_compiler.h"], + deps = [ + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/core:framework", + ], +) + +cc_library( + name = "tf_graph_to_hlo_compiler", + srcs = ["tf_graph_to_hlo_compiler.cc"], + hdrs = ["tf_graph_to_hlo_compiler.h"], + deps = [ + ":tf_to_hlo_compiler", + "//tensorflow/compiler/tf2xla:xla_compiler", + ], +) + +cc_library( + name = "device_compilation_profiler", + srcs = ["device_compilation_profiler.cc"], + hdrs = ["device_compilation_profiler.h"], + deps = [ + ":xla_activity_listener", + 
":xla_activity_proto_cc", + ":xla_compile_util", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:errors", + "//tensorflow/core/platform:status", + "//tensorflow/tsl/platform:mutex", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "device_compiler_client", + srcs = ["device_compiler_client.cc"], + hdrs = ["device_compiler_client.h"], + visibility = [ + ":internal", + "//tensorflow/core/common_runtime/next_pluggable_device:__pkg__", + ], + deps = [ + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla/client:executable_build_options", + "//tensorflow/core/util:determinism", + ], +) + +tf_cc_test( + name = "device_compiler_client_test", + srcs = ["device_compiler_client_test.cc"], + deps = [ + ":device_compiler_client", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "xla_device_compiler_client", + srcs = ["xla_device_compiler_client.cc"], + hdrs = ["xla_device_compiler_client.h"], + deps = [ + ":device_compiler_client", + "//tensorflow/compiler/xla/client:local_client", + ], +) + +cc_library( + name = "device_executable_persistor", + hdrs = ["device_executable_persistor.h"], + deps = [ + ":xla_compilation_cache_proto_cc", + ":xla_device_compiler_client", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo_proto_cc", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core:lib_headers_for_pybind", + "//tensorflow/core/platform:path", + "//tensorflow/core/platform:status", + "//tensorflow/core/platform:statusor", + "//tensorflow/tsl/platform:statusor", + ], +) + +cc_library( + name = "device_compilation_cache", + hdrs = ["device_compilation_cache.h"], + deps = [ + ":device_compilation_cluster_signature", + ":device_compiler_client", + ":xla_compile_util", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/core:framework_lite", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "device_compilation_cluster_signature", + srcs = ["device_compilation_cluster_signature.cc"], + hdrs = ["device_compilation_cluster_signature.h"], + deps = ["//tensorflow/compiler/tf2xla:xla_compiler"], +) + +cc_library( + name = "pjrt_device_compiler_client", + srcs = ["pjrt_device_compiler_client.cc"], + hdrs = ["pjrt_device_compiler_client.h"], + deps = [ + ":device_compiler_client", + "//tensorflow/compiler/xla/pjrt:pjrt_client", + ], +) + +cc_library( + name = "pjrt_device_context", + srcs = [ + "pjrt_device_context.cc", + ], + hdrs = [ + "pjrt_device_context.h", + ], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/xla/pjrt:pjrt_client", + "//tensorflow/core:framework", + "//tensorflow/core/platform:status", + "//tensorflow/core/profiler/lib:traceme", + "//tensorflow/core/tfrt/common:async_value_tensor", + "//tensorflow/core/tfrt/common:pjrt_util", + ], +) + +tf_cc_test( + name = "device_compilation_cluster_signature_test", + srcs = [ + "device_compilation_cluster_signature_test.cc", + ], + deps = [ + ":device_compilation_cluster_signature", + ":flags", + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "device_compilation_profiler_test", + srcs = ["device_compilation_profiler_test.cc"], + deps = [ + ":device_compilation_profiler", + ":xla_activity_proto_cc", + 
"//tensorflow/compiler/jit/tests:device_compiler_test_helper", + "//tensorflow/core:protos_all_cc", + "@com_google_googletest//:gtest_main", + ], +) + +tf_cc_test( + name = "device_executable_persistor_test", + srcs = ["device_executable_persistor_test.cc"], + deps = [ + ":device_compiler_client", + ":device_executable_persistor", + ":xla_compilation_cache_proto_cc", + ":xla_cpu_device", + ":xla_cpu_jit", + ":xla_device_compiler_client", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:math_ops", + "//tensorflow/cc:scope", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:executable_build_options", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/core:test", + "//tensorflow/core/platform:errors", + "//tensorflow/core/platform:status_matchers", + "//tensorflow/core/platform:statusor", + "@com_google_googletest//:gtest_main", + ], +) + +tf_cc_test( + name = "device_compilation_cache_test", + srcs = ["device_compilation_cache_test.cc"], + deps = [ + ":device_compilation_cache", + "//tensorflow/core:test", + "//tensorflow/core/platform:errors", + "//tensorflow/tsl/protobuf:error_codes_proto_impl_cc", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tensorflow/compiler/jit/clone_constants_for_better_clustering.cc b/tensorflow/compiler/jit/clone_constants_for_better_clustering.cc index e2dda95e4ef..5a1619d79ff 100644 --- a/tensorflow/compiler/jit/clone_constants_for_better_clustering.cc +++ b/tensorflow/compiler/jit/clone_constants_for_better_clustering.cc @@ -26,7 +26,7 @@ limitations under the License. namespace tensorflow { -using se::port::StatusOr; +using tsl::StatusOr; class CloneConstantsForBetterClusteringPassImpl { public: @@ -35,12 +35,12 @@ class CloneConstantsForBetterClusteringPassImpl { Status Run(); private: - Status CloneSmallHostConstantInputs( - const absl::flat_hash_set& name_set, Node* n); + Status CloneSmallConstantInputs(const absl::flat_hash_set& name_set, + Node* n); string GenerateUniqueName(const absl::flat_hash_set& name_set, absl::string_view prefix); - se::port::StatusOr CloneNode( - const absl::flat_hash_set& name_set, Node* n); + tsl::StatusOr CloneNode(const absl::flat_hash_set& name_set, + Node* n); Graph* graph_; int unique_name_counter_; @@ -75,25 +75,10 @@ StatusOr CloneConstantsForBetterClusteringPassImpl::CloneNode( } namespace { -StatusOr IsConstantOnHost(Node* n) { - if (n->output_type(0) == DT_INT32) { - // TensorFlow always puts int32 tensors on the host. - return true; - } - - DeviceNameUtils::ParsedName parsed; - TF_RET_CHECK( - DeviceNameUtils::ParseFullName(n->assigned_device_name(), &parsed)); - return parsed.type == DEVICE_CPU; -} - StatusOr IsConstantSmall(Node* n) { const TensorProto* proto = nullptr; TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "value", &proto)); - // TODO(sanjoy): It may make sense to combine this threshold with XLA's "large - // constant" threshold, if there is one. - const int kSmallTensorThreshold = 16; int64_t total_elements = 1; for (const auto& dim : proto->tensor_shape().dim()) { if (dim.size() < 0) { @@ -102,21 +87,20 @@ StatusOr IsConstantSmall(Node* n) { } total_elements *= dim.size(); } + + // TODO(sanjoy): It may make sense to combine this threshold with XLA's "large + // constant" threshold, if there is one. 
+ const int kSmallTensorThreshold = 16; return total_elements < kSmallTensorThreshold; } -// We only clone host constants for now since we want to avoid increasing memory +// We only clone small constants since we want to avoid increasing memory // pressure on GPUs. -StatusOr IsSmallHostConstant(Node* n) { +StatusOr IsSmallConstant(Node* n) { if (!n->IsConstant()) { return false; } - TF_ASSIGN_OR_RETURN(bool is_constant_on_host, IsConstantOnHost(n)); - if (!is_constant_on_host) { - return false; - } - return IsConstantSmall(n); } @@ -126,7 +110,7 @@ bool IsInPlaceOp(absl::string_view op_name) { } } // namespace -Status CloneConstantsForBetterClusteringPassImpl::CloneSmallHostConstantInputs( +Status CloneConstantsForBetterClusteringPassImpl::CloneSmallConstantInputs( const absl::flat_hash_set& name_set, Node* n) { std::vector in_edges; // Get the edges and sort them so we clone in a deterministic order. @@ -136,10 +120,9 @@ Status CloneConstantsForBetterClusteringPassImpl::CloneSmallHostConstantInputs( }); for (const Edge* e : in_edges) { Node* input = e->src(); - TF_ASSIGN_OR_RETURN(bool is_small_host_constant, - IsSmallHostConstant(input)); - if (is_small_host_constant && input->out_edges().size() != 1) { - VLOG(2) << "Cloning small host constant " << input->name(); + TF_ASSIGN_OR_RETURN(bool is_small_constant, IsSmallConstant(input)); + if (is_small_constant && input->out_edges().size() != 1) { + VLOG(2) << "Cloning small constant " << input->name(); TF_ASSIGN_OR_RETURN(Node* const input_cloned, CloneNode(name_set, input)); if (e->IsControlEdge()) { graph_->AddControlEdge(input_cloned, e->dst()); @@ -210,7 +193,7 @@ Status CloneConstantsForBetterClusteringPassImpl::Run() { // Iterate over a copy of the nodes to avoid iterating over g->nodes() while // creating more nodes. for (Node* n : nodes) { - TF_RETURN_IF_ERROR(CloneSmallHostConstantInputs(name_set, n)); + TF_RETURN_IF_ERROR(CloneSmallConstantInputs(name_set, n)); } return OkStatus(); } diff --git a/tensorflow/compiler/jit/clone_constants_for_better_clustering.h b/tensorflow/compiler/jit/clone_constants_for_better_clustering.h index b0900d7f1cd..19e6c49ec44 100644 --- a/tensorflow/compiler/jit/clone_constants_for_better_clustering.h +++ b/tensorflow/compiler/jit/clone_constants_for_better_clustering.h @@ -17,7 +17,6 @@ limitations under the License. 
#define TENSORFLOW_COMPILER_JIT_CLONE_CONSTANTS_FOR_BETTER_CLUSTERING_H_ #include "absl/container/flat_hash_set.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/core/common_runtime/optimization_registry.h" namespace tensorflow { diff --git a/tensorflow/compiler/jit/clone_constants_for_better_clustering_test.cc b/tensorflow/compiler/jit/clone_constants_for_better_clustering_test.cc index 29ce9700e38..468b1eab82b 100644 --- a/tensorflow/compiler/jit/clone_constants_for_better_clustering_test.cc +++ b/tensorflow/compiler/jit/clone_constants_for_better_clustering_test.cc @@ -61,6 +61,29 @@ Status CloneConstantsForBetterClustering(const Scope& s, const char* kCPU = "/job:localhost/replica:0/task:0/device:CPU:0"; const char* kGPU = "/job:localhost/replica:0/task:0/device:GPU:0"; +TEST(CloneConstantsForBetterClusteringTest, ScalarConstantPlacedOnGpu) { + Scope root = Scope::NewRootScope().ExitOnError(); + Scope on_gpu = root.WithAssignedDevice(kGPU).WithDevice(kGPU); + + Output in = ops::Placeholder(on_gpu.WithOpName("in"), DT_FLOAT); + Output c = ops::Const(on_gpu.WithOpName("const"), 1.0f, {}); + Output add1 = ops::AddV2(on_gpu.WithOpName("add1"), in, c); + Output add2 = ops::AddV2(on_gpu.WithOpName("add2"), add1, c); + + std::unique_ptr result; + TF_ASSERT_OK(CloneConstantsForBetterClustering(root, &result)); + + OutputTensor add1_operand; + TF_ASSERT_OK( + FindNodeByName(result.get(), "add1")->input_tensor(1, &add1_operand)); + + OutputTensor add2_operand; + TF_ASSERT_OK( + FindNodeByName(result.get(), "add2")->input_tensor(1, &add2_operand)); + + EXPECT_NE(add1_operand.node, add2_operand.node); +} + TEST(CloneConstantsForBetterClusteringTest, HostConstantPlacedOnCpu) { Scope root = Scope::NewRootScope().ExitOnError(); Scope on_gpu = root.WithAssignedDevice(kGPU).WithDevice(kGPU); @@ -114,7 +137,7 @@ TEST(CloneConstantsForBetterClusteringTest, HostConstantPlacedOnGpu) { EXPECT_NE(tr0_perm.node, tr1_perm.node); } -TEST(CloneConstantsForBetterClusteringTest, DontCloneNonHostConstants) { +TEST(CloneConstantsForBetterClusteringTest, CloneSmallDeviceConstants) { Scope root = Scope::NewRootScope().ExitOnError(); Scope on_gpu = root.WithAssignedDevice(kGPU).WithDevice(kGPU); @@ -143,7 +166,7 @@ TEST(CloneConstantsForBetterClusteringTest, DontCloneNonHostConstants) { TF_ASSERT_OK( FindNodeByName(result.get(), "perm_cast_1")->input_tensor(0, &tr1_perm)); - EXPECT_EQ(tr0_perm.node, tr1_perm.node); + EXPECT_NE(tr0_perm.node, tr1_perm.node); } TEST(CloneConstantsForBetterClusteringTest, DontCloneLargeConstants) { diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index fe86b35943a..9e62619354a 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -115,7 +115,7 @@ namespace tensorflow { namespace { -using se::port::StatusOr; +using tsl::StatusOr; // Represents a logical predicate, used as described in the algorithm overview // above. diff --git a/tensorflow/compiler/jit/deadness_analysis.h b/tensorflow/compiler/jit/deadness_analysis.h index fbd2a36c4f5..72b446f165a 100644 --- a/tensorflow/compiler/jit/deadness_analysis.h +++ b/tensorflow/compiler/jit/deadness_analysis.h @@ -16,7 +16,6 @@ limitations under the License. 
 #ifndef TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_H_
 #define TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_H_
 
-#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h"
 #include "tensorflow/core/graph/graph.h"
 
 namespace tensorflow {
@@ -74,8 +73,8 @@ class DeadnessAnalysis {
     friend class DeadnessAnalysis;
   };
 
-  virtual se::port::StatusOr<DeadnessPredicate> GetPredicateFor(
-      Node* n, int oidx) const = 0;
+  virtual tsl::StatusOr<DeadnessPredicate> GetPredicateFor(Node* n,
+                                                           int oidx) const = 0;
 
   // Prints out the internal state of this instance. For debugging purposes
   // only.
diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc
index 02263f7b292..33cb716623f 100644
--- a/tensorflow/compiler/jit/deadness_analysis_test.cc
+++ b/tensorflow/compiler/jit/deadness_analysis_test.cc
@@ -36,7 +36,7 @@ limitations under the License.
 namespace tensorflow {
 namespace {
 
-se::port::StatusOr<bool> HasInputsWithMismatchingDeadness(
+tsl::StatusOr<bool> HasInputsWithMismatchingDeadness(
     const DeadnessAnalysis& deadness_analysis, const Node& n) {
   std::optional<DeadnessAnalysis::DeadnessPredicate> pred;
   for (const Edge* edge : n.in_edges()) {
diff --git a/tensorflow/compiler/jit/device_compilation_cache.h b/tensorflow/compiler/jit/device_compilation_cache.h
new file mode 100644
index 00000000000..872c9c6c956
--- /dev/null
+++ b/tensorflow/compiler/jit/device_compilation_cache.h
@@ -0,0 +1,212 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_DEVICE_COMPILATION_CACHE_H_
+#define TENSORFLOW_COMPILER_JIT_DEVICE_COMPILATION_CACHE_H_
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include "absl/strings/str_cat.h"
+#include "tensorflow/compiler/jit/device_compilation_cluster_signature.h"
+#include "tensorflow/compiler/jit/xla_compile_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+// Cache to store compiled HLO, executables and related metadata keyed by
+// `DeviceCompilationClusterSignature`. The cache owns the stored
+// CompilationResults and Executables.
+// Currently no cache eviction policy is implemented and the cache grows without
+// bound.
+template <typename ExecutableType>
+class DeviceCompilationCache {
+ public:
+  DeviceCompilationCache() = default;
+  ~DeviceCompilationCache() = default;
+
+  using Key = DeviceCompilationClusterSignature;
+  struct Value {
+    DeviceCompileState compile_state = DeviceCompileState::kUncompiled;
+    Status compilation_status;
+    int64_t request_count = 0;
+    const XlaCompiler::CompilationResult* compilation_result = nullptr;
+    ExecutableType* executable = nullptr;
+  };
+
+  // Returns std::nullopt if value for the supplied key is not found. If a value
+  // is found, `request_count` is incremented before returning the value.
+  std::optional<Value> Lookup(const Key& key) const;
+
+  // Inserts an empty value if value is not found and returns it. If a value is
+  // found, `request_count` is incremented before returning the value.
+  Value LookupOrCreate(const Key& key);
+
+  // Caches `compile_state`, `compilation_status`, `compilation_result` and
+  // `executable` and associates them with the provided `key`. Takes ownership
+  // of `compilation_result` and `executable`. Does not increment the
+  // corresponding `request_count`. Only arguments that are not std::nullopt are
+  // updated in the cache.
+  void Store(const Key& key, std::optional<DeviceCompileState> compile_state,
+             std::optional<Status> compilation_status,
+             std::optional<std::unique_ptr<XlaCompiler::CompilationResult>>
+                 compilation_result,
+             std::optional<std::unique_ptr<ExecutableType>> executable);
+
+  std::string DebugString() const;
+
+ private:
+  // The value associated with a cache entry.
+  struct Entry {
+    mutable mutex mu;
+
+    // The current compilation state for this entry.
+    DeviceCompileState compile_state TF_GUARDED_BY(mu) =
+        DeviceCompileState::kUncompiled;
+
+    // The number of times a compilation with this signature has been requested.
+    int64_t request_count TF_GUARDED_BY(mu) = 0;
+
+    // Did compilation succeed?
+    Status compilation_status TF_GUARDED_BY(mu);
+
+    // Output of the XlaCompiler.
+    std::unique_ptr<XlaCompiler::CompilationResult> compilation_result
+        TF_GUARDED_BY(mu);
+
+    // The XLA executable compiled from `compilation_result`. May be null if no
+    // executable has been built.
+    std::unique_ptr<ExecutableType> executable TF_GUARDED_BY(mu);
+
+    std::string DebugString() const {
+      mutex_lock lock(mu);
+      return absl::StrCat(
+          "{compile_state: ", compile_state, ", request_count: ", request_count,
+          ", compilation_status: ", compilation_status.ToString(),
+          ", compilation_result?: ", compilation_result != nullptr,
+          ", executable?: ", executable != nullptr, "}");
+    }
+  };
+
+  mutable mutex compile_cache_mu_;
+  absl::flat_hash_map<Key, std::unique_ptr<Entry>, Key::Hash> cache_
+      TF_GUARDED_BY(compile_cache_mu_);
+
+  TF_DISALLOW_COPY_AND_ASSIGN(DeviceCompilationCache);
+};
+
+template <typename ExecutableType>
+std::optional<typename DeviceCompilationCache<ExecutableType>::Value>
+DeviceCompilationCache<ExecutableType>::Lookup(const Key& key) const {
+  // The outer lock protects the existence of the cache entry. It does not
+  // protect the contents of the cache entry.
+  Entry* entry;
+  {
+    mutex_lock lock(compile_cache_mu_);
+    // Find cache entry.
+    auto it = cache_.find(key);
+    if (it == cache_.cend()) {
+      return std::nullopt;
+    }
+
+    entry = it->second.get();
+  }
+
+  mutex_lock lock(entry->mu);
+  Value value = {/*compile_state=*/entry->compile_state,
+                 /*compilation_status=*/entry->compilation_status,
+                 /*request_count=*/++entry->request_count,
+                 /*compilation_result=*/entry->compilation_result.get(),
+                 /*executable=*/entry->executable.get()};
+  return value;
+}
+
+template <typename ExecutableType>
+typename DeviceCompilationCache<ExecutableType>::Value
+DeviceCompilationCache<ExecutableType>::LookupOrCreate(const Key& key) {
+  // The outer lock protects the existence of the cache entry. It does not
+  // protect the contents of the cache entry.
+  Entry* entry;
+  {
+    mutex_lock lock(compile_cache_mu_);
+    // Emplace empty cache entry if not found.
+ auto it = cache_.emplace(key, std::make_unique()).first; + entry = it->second.get(); + } + + mutex_lock lock(entry->mu); + Value value = {/*compile_state=*/entry->compile_state, + /*compilation_status=*/entry->compilation_status, + /*request_count=*/++entry->request_count, + /*compilation_result=*/entry->compilation_result.get(), + /*executable=*/entry->executable.get()}; + return value; +} + +template +void DeviceCompilationCache::Store( + const Key& key, std::optional compile_state, + std::optional compilation_status, + std::optional> + compilation_result, + std::optional> executable) { + Entry* entry; + { + mutex_lock lock(compile_cache_mu_); + // Emplace empty cache entry if not found. + auto it = cache_.emplace(key, std::make_unique()).first; + entry = it->second.get(); + } + + { + mutex_lock lock(entry->mu); + if (compile_state.has_value()) { + entry->compile_state = *compile_state; + } + if (compilation_status.has_value()) { + entry->compilation_status = *compilation_status; + } + if (compilation_result.has_value()) { + entry->compilation_result = std::move(*compilation_result); + } + if (executable.has_value()) { + entry->executable = std::move(*executable); + } + } +} + +template +std::string DeviceCompilationCache::DebugString() const { + std::string s = "DeviceCompilationCache {\n"; + { + mutex_lock lock(compile_cache_mu_); + for (const auto& [key, entry] : cache_) { + absl::StrAppend(&s, key.HumanString(), " : ", entry->DebugString(), + ",\n"); + } + } + absl::StrAppend(&s, "}"); + + return s; +} + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_DEVICE_COMPILATION_CACHE_H_ diff --git a/tensorflow/compiler/jit/device_compilation_cache_test.cc b/tensorflow/compiler/jit/device_compilation_cache_test.cc new file mode 100644 index 00000000000..b755ea0f362 --- /dev/null +++ b/tensorflow/compiler/jit/device_compilation_cache_test.cc @@ -0,0 +1,220 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/device_compilation_cache.h" + +#include +#include +#include +#include +#include + +#include +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/tsl/protobuf/error_codes.pb.h" + +namespace tensorflow { +namespace { +struct FakeExecutable { + std::string data; + explicit FakeExecutable(const std::string& s) : data(s) {} +}; + +using Cache = DeviceCompilationCache; +using Signature = DeviceCompilationClusterSignature; + +StatusOr BuildSampleSignature(const std::string& fn_name) { + NameAttrList fn; + fn.set_name(fn_name); + std::vector args(1); + args[0].kind = XlaCompiler::Argument::kConstant; + args[0].type = DT_INT32; + args[0].shape = TensorShape({4, 0}); + args[0].constant_value = Tensor(DT_INT32, {4, 0}); + return Signature::Build(fn, args); +} + +TEST(DeviceCompilationCacheTest, LookupEntryDoesntExist) { + auto cache = std::make_unique(); + + TF_ASSERT_OK_AND_ASSIGN(auto key, BuildSampleSignature("foo")); + auto cache_value = cache->Lookup(key); + + EXPECT_FALSE(cache_value.has_value()); +} + +TEST(DeviceCompilationCacheTest, LookupOrCreateEntryDoesntExist) { + auto cache = std::make_unique(); + + TF_ASSERT_OK_AND_ASSIGN(auto key, BuildSampleSignature("foo")); + Cache::Value cache_value = cache->LookupOrCreate(key); + + EXPECT_EQ(cache_value.compile_state, DeviceCompileState::kUncompiled); + EXPECT_EQ(cache_value.request_count, 1); + EXPECT_EQ(cache_value.compilation_result, nullptr); + EXPECT_EQ(cache_value.executable, nullptr); +} + +TEST(DeviceCompilationCacheTest, IncrementRequestCountOnLookup) { + auto cache = std::make_unique(); + + TF_ASSERT_OK_AND_ASSIGN(auto key, BuildSampleSignature("foo")); + Cache::Value cache_value = cache->LookupOrCreate(key); + EXPECT_EQ(cache_value.request_count, 1); + + cache_value = cache->LookupOrCreate(key); + EXPECT_EQ(cache_value.request_count, 2); + + cache_value = cache->LookupOrCreate(key); + EXPECT_EQ(cache_value.request_count, 3); +} + +TEST(DeviceCompilationCacheTest, RequestCountUnchangedOnStore) { + auto cache = std::make_unique(); + + TF_ASSERT_OK_AND_ASSIGN(auto key, BuildSampleSignature("foo")); + Cache::Value cache_value = cache->LookupOrCreate(key); + EXPECT_EQ(cache_value.request_count, 1); + + cache_value = cache->LookupOrCreate(key); + EXPECT_EQ(cache_value.request_count, 2); + + cache_value = cache->LookupOrCreate(key); + EXPECT_EQ(cache_value.request_count, 3); + + auto compilation_result = std::make_unique(); + cache->Store(key, DeviceCompileState::kCompiled, OkStatus(), + std::move(compilation_result), std::nullopt); + cache_value = cache->LookupOrCreate(key); + + EXPECT_EQ(cache_value.request_count, 4); +} + +TEST(DeviceCompilationCacheTest, StoreLookup) { + auto cache = std::make_unique(); + + TF_ASSERT_OK_AND_ASSIGN(auto key, BuildSampleSignature("foo")); + auto compilation_result = std::make_unique(); + auto executable = std::make_unique("foo_exe"); + cache->Store(key, DeviceCompileState::kCompiled, OkStatus(), + std::move(compilation_result), std::move(executable)); + auto cache_value = cache->Lookup(key); + + EXPECT_EQ(cache_value->compile_state, DeviceCompileState::kCompiled); + EXPECT_EQ(cache_value->request_count, 1); + EXPECT_TRUE(cache_value->compilation_status.ok()); + EXPECT_TRUE(cache_value->compilation_result != nullptr); + EXPECT_TRUE(cache_value->executable != nullptr); + EXPECT_EQ(cache_value->executable->data, "foo_exe"); +} + 
+TEST(DeviceCompilationCacheTest, StoreLookupOrCreate) { + auto cache = std::make_unique(); + + TF_ASSERT_OK_AND_ASSIGN(auto key, BuildSampleSignature("foo")); + auto compilation_result = std::make_unique(); + auto executable = std::make_unique("foo_exe"); + cache->Store(key, DeviceCompileState::kCompiled, OkStatus(), + std::move(compilation_result), std::move(executable)); + auto cache_value = cache->LookupOrCreate(key); + + EXPECT_EQ(cache_value.compile_state, DeviceCompileState::kCompiled); + EXPECT_EQ(cache_value.request_count, 1); + EXPECT_TRUE(cache_value.compilation_status.ok()); + EXPECT_TRUE(cache_value.compilation_result != nullptr); + EXPECT_TRUE(cache_value.executable != nullptr); + EXPECT_EQ(cache_value.executable->data, "foo_exe"); +} + +TEST(DeviceCompilationCacheTest, StoreOptionalArgs) { + auto cache = std::make_unique(); + + TF_ASSERT_OK_AND_ASSIGN(auto key, BuildSampleSignature("foo")); + + auto compilation_result = std::make_unique(); + auto executable = std::make_unique("foo_exe"); + + cache->Store(key, DeviceCompileState::kCompiled, std::nullopt, std::nullopt, + std::nullopt); + auto cache_value = cache->Lookup(key); + + EXPECT_EQ(cache_value->compile_state, DeviceCompileState::kCompiled); + EXPECT_TRUE(cache_value->compilation_status.ok()); + EXPECT_TRUE(cache_value->compilation_result == nullptr); + EXPECT_TRUE(cache_value->executable == nullptr); + + cache->Store(key, std::nullopt, errors::InvalidArgument("Couldn't compile."), + std::nullopt, std::nullopt); + cache_value = cache->Lookup(key); + + EXPECT_EQ(cache_value->compile_state, DeviceCompileState::kCompiled); + EXPECT_EQ(cache_value->compilation_status.code(), error::INVALID_ARGUMENT); + EXPECT_TRUE(cache_value->compilation_result == nullptr); + EXPECT_TRUE(cache_value->executable == nullptr); + + cache->Store(key, std::nullopt, std::nullopt, std::move(compilation_result), + std::nullopt); + cache_value = cache->Lookup(key); + + EXPECT_EQ(cache_value->compile_state, DeviceCompileState::kCompiled); + EXPECT_EQ(cache_value->compilation_status.code(), error::INVALID_ARGUMENT); + EXPECT_TRUE(cache_value->compilation_result != nullptr); + EXPECT_TRUE(cache_value->executable == nullptr); + + cache->Store(key, std::nullopt, std::nullopt, std::nullopt, + std::move(executable)); + cache_value = cache->Lookup(key); + + EXPECT_EQ(cache_value->compile_state, DeviceCompileState::kCompiled); + EXPECT_EQ(cache_value->compilation_status.code(), error::INVALID_ARGUMENT); + EXPECT_TRUE(cache_value->compilation_result != nullptr); + EXPECT_TRUE(cache_value->executable != nullptr); + EXPECT_EQ(cache_value->executable->data, "foo_exe"); +} + +TEST(DeviceCompilationCacheTest, StoreMultipleEntries) { + auto cache = std::make_unique(); + + TF_ASSERT_OK_AND_ASSIGN(auto key1, BuildSampleSignature("foo")); + TF_ASSERT_OK_AND_ASSIGN(auto key2, BuildSampleSignature("bar")); + + auto compilation_result1 = std::make_unique(); + auto compilation_result2 = std::make_unique(); + auto executable1 = std::make_unique("foo_exe"); + auto executable2 = std::make_unique("bar_exe"); + cache->Store(key1, DeviceCompileState::kCompiled, + errors::InvalidArgument("Invalid argument."), + std::move(compilation_result1), std::move(executable1)); + cache->Store(key2, DeviceCompileState::kCompiling, OkStatus(), + std::move(compilation_result2), std::move(executable2)); + auto cache_value_1 = cache->Lookup(key1); + auto cache_value_2 = cache->Lookup(key2); + + EXPECT_EQ(cache_value_1->compile_state, DeviceCompileState::kCompiled); + 
+  EXPECT_EQ(cache_value_1->compilation_status.code(), error::INVALID_ARGUMENT);
+  EXPECT_TRUE(cache_value_1->compilation_result != nullptr);
+  EXPECT_TRUE(cache_value_1->executable != nullptr);
+  EXPECT_EQ(cache_value_1->executable->data, "foo_exe");
+
+  EXPECT_EQ(cache_value_2->compile_state, DeviceCompileState::kCompiling);
+  EXPECT_TRUE(cache_value_2->compilation_status.ok());
+  EXPECT_TRUE(cache_value_2->compilation_result != nullptr);
+  EXPECT_TRUE(cache_value_2->executable != nullptr);
+  EXPECT_EQ(cache_value_2->executable->data, "bar_exe");
+}
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/device_compilation_cluster_signature.cc b/tensorflow/compiler/jit/device_compilation_cluster_signature.cc
new file mode 100644
index 00000000000..8ca571b104b
--- /dev/null
+++ b/tensorflow/compiler/jit/device_compilation_cluster_signature.cc
@@ -0,0 +1,139 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/device_compilation_cluster_signature.h"
+
+#include <string>
+#include <utility>
+#include <variant>
+
+namespace tensorflow {
+namespace {
+using Signature = DeviceCompilationClusterSignature;
+using TensorTypeAndShape = Signature::TensorTypeAndShape;
+
+// Functor that converts a Signature's arg to a human readable string.
+struct SignatureHumanStringAppender {
+  explicit SignatureHumanStringAppender(std::string* dest) : dest(dest) {}
+  std::string* dest;
+  void operator()(const Tensor& arg) {
+    absl::StrAppend(dest, "; ", arg.DebugString());
+  }
+  void operator()(const TensorTypeAndShape& arg) {
+    absl::StrAppend(dest, ",", DataTypeString(arg.first));
+    absl::StrAppend(dest, " [", absl::StrJoin(arg.second, ","), "]");
+  }
+};
+
+// Functor that compares the arg values of two different signatures. Returns
+// true when the args are not equal.
+struct SignatureNotEqual {
+  bool operator()(const Tensor& arg, const Tensor& other) {
+    return arg.dtype() != other.dtype() || arg.shape() != other.shape() ||
+           arg.tensor_data() != other.tensor_data();
+  }
+  bool operator()(const TensorTypeAndShape& arg,
+                  const TensorTypeAndShape& other) {
+    return arg.first != other.first || arg.second != other.second;
+  }
+  bool operator()(const Tensor& arg, const TensorTypeAndShape& other) {
+    return true;
+  }
+  bool operator()(const TensorTypeAndShape& arg, const Tensor& other) {
+    return true;
+  }
+};
+
+// Functor that incrementally computes a Signature's hash given its current hash
+// and one of its args.
+struct SignatureHashCombiner {
+  explicit SignatureHashCombiner(const uint64 h) : h(h) {}
+  uint64 h;
+  uint64 operator()(const Tensor& arg) {
+    h = Hash64Combine(h, std::hash<int>()(static_cast<int>(arg.dtype())));
+    h = Hash64Combine(
+        h, Hash64(arg.tensor_data().data(), arg.tensor_data().size()));
+    for (int dim = 0; dim < arg.dims(); ++dim) {
+      h = Hash64Combine(h, std::hash<int>()(arg.dim_size(dim)));
+    }
+    return h;
+  }
+  uint64 operator()(const TensorTypeAndShape& arg) {
+    h = Hash64Combine(h, std::hash<int>()(static_cast<int>(arg.first)));
+    h = Hash64Combine(h, std::hash<int>()(arg.second.size()));
+    for (int dim : arg.second) {
+      h = Hash64Combine(h, std::hash<int>()(dim));
+    }
+    return h;
+  }
+};
+}  // namespace
+
+// Compute a string signature which encodes the shapes of the
+// arguments in the supplied list.
+std::string Signature::HumanString() const {
+  std::string result = name;
+  for (const auto& arg : args) {
+    std::visit(SignatureHumanStringAppender(&result), arg);
+  }
+  return result;
+}
+
+bool Signature::operator==(const Signature& other) const {
+  if (name != other.name) return false;
+  if (args.size() != other.args.size()) return false;
+  for (int i = 0, end = args.size(); i < end; ++i) {
+    if (std::visit(SignatureNotEqual(), args[i], other.args[i])) {
+      return false;
+    }
+  }
+  return true;
+}
+
+uint64 Signature::Hash::operator()(const Signature& signature) const {
+  uint64 h = std::hash<string>()(signature.name);
+  for (const auto& arg : signature.args) {
+    h = std::visit(SignatureHashCombiner(h), arg);
+  }
+  return h;
+}
+
+StatusOr<Signature> Signature::Build(
+    const NameAttrList& function,
+    absl::Span<const XlaCompiler::Argument> args) {
+  Signature signature;
+  signature.name = Canonicalize(function.name(), AttrSlice(&function.attr()));
+
+  for (const XlaCompiler::Argument& arg : args) {
+    switch (arg.kind) {
+      case XlaCompiler::Argument::kConstant:
+      case XlaCompiler::Argument::kConstantResource:
+        signature.args.push_back(arg.constant_value);
+        break;
+      case XlaCompiler::Argument::kParameter:
+      case XlaCompiler::Argument::kResource:
+        signature.args.push_back(
+            TensorTypeAndShape(arg.type, arg.DimensionSizesAsInlinedVector()));
+        break;
+      default:
+        return errors::InvalidArgument(
+            "Unhandled argument kind in XlaCompilationCache: ",
+            arg.HumanString());
+    }
+  }
+  return std::move(signature);
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/device_compilation_cluster_signature.h b/tensorflow/compiler/jit/device_compilation_cluster_signature.h
new file mode 100644
index 00000000000..76a8daa0d95
--- /dev/null
+++ b/tensorflow/compiler/jit/device_compilation_cluster_signature.h
@@ -0,0 +1,56 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_DEVICE_COMPILATION_CLUSTER_SIGNATURE_H_
+#define TENSORFLOW_COMPILER_JIT_DEVICE_COMPILATION_CLUSTER_SIGNATURE_H_
+
+#include <utility>
+#include <variant>
+
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+
+namespace tensorflow {
+
+// Describes the types, shapes and any compile-time constant arguments
+// to a kernel. Key that uniquely identifies a compilation output.
+struct DeviceCompilationClusterSignature {
+  // Name of the cluster, built from the function name and its attributes.
+  string name;
+
+  // List of args (either as a TensorTypeAndShape or as a Tensor value)
+  // for compile-time constant arguments to the compilation, ordered by
+  // argument number. Tensors must be in host memory.
+  using TensorTypeAndShape =
+      std::pair<DataType, absl::InlinedVector<int64_t, 4>>;
+  absl::InlinedVector<std::variant<Tensor, TensorTypeAndShape>, 8> args;
+
+  bool operator==(const DeviceCompilationClusterSignature& other) const;
+
+  struct Hash {
+    uint64 operator()(const DeviceCompilationClusterSignature& signature) const;
+  };
+
+  // Returns a human-readable description of the signature.
+  string HumanString() const;
+
+  // Builds the signature for a compilation.
+  static StatusOr<DeviceCompilationClusterSignature> Build(
+      const NameAttrList& function,
+      absl::Span<const XlaCompiler::Argument> args);
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_JIT_DEVICE_COMPILATION_CLUSTER_SIGNATURE_H_
diff --git a/tensorflow/compiler/jit/xla_compilation_cache_test.cc b/tensorflow/compiler/jit/device_compilation_cluster_signature_test.cc
similarity index 73%
rename from tensorflow/compiler/jit/xla_compilation_cache_test.cc
rename to tensorflow/compiler/jit/device_compilation_cluster_signature_test.cc
index d806e1b2d7d..39758c71580 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache_test.cc
+++ b/tensorflow/compiler/jit/device_compilation_cluster_signature_test.cc
@@ -1,4 +1,4 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/compiler/jit/xla_compilation_cache.h"
+#include "tensorflow/compiler/jit/device_compilation_cluster_signature.h"
+
+#include <utility>
+#include <vector>
 
 #include "tensorflow/compiler/jit/flags.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
@@ -23,10 +26,9 @@ limitations under the License.
namespace tensorflow { namespace { +using SignatureHash = DeviceCompilationClusterSignature::Hash; -using SignatureHash = XlaCompilationCache::Signature::Hash; - -TEST(XlaCompilationCacheTest, SignatureEquality) { +TEST(DeviceCompilationClusterSignatureTest, SignatureEquality) { NameAttrList fn; fn.set_name("afunction"); std::vector args(1); @@ -34,20 +36,20 @@ TEST(XlaCompilationCacheTest, SignatureEquality) { args[0].type = DT_INT32; args[0].shape = TensorShape({4, 0}); args[0].constant_value = Tensor(DT_INT32, {4, 0}); - TF_ASSERT_OK_AND_ASSIGN(XlaCompilationCache::Signature s1, - XlaCompilationCache::BuildSignature(fn, args)); + TF_ASSERT_OK_AND_ASSIGN(DeviceCompilationClusterSignature s1, + DeviceCompilationClusterSignature::Build(fn, args)); args[0].type = DT_FLOAT; args[0].constant_value = Tensor(DT_FLOAT, {4, 0}); - TF_ASSERT_OK_AND_ASSIGN(XlaCompilationCache::Signature s2, - XlaCompilationCache::BuildSignature(fn, args)); + TF_ASSERT_OK_AND_ASSIGN(DeviceCompilationClusterSignature s2, + DeviceCompilationClusterSignature::Build(fn, args)); args[0].shape = TensorShape({0, 4}); args[0].constant_value = Tensor(DT_FLOAT, {0, 4}); - TF_ASSERT_OK_AND_ASSIGN(XlaCompilationCache::Signature s3, - XlaCompilationCache::BuildSignature(fn, args)); + TF_ASSERT_OK_AND_ASSIGN(DeviceCompilationClusterSignature s3, + DeviceCompilationClusterSignature::Build(fn, args)); - std::vector signatures = {s1, s2, s3}; + std::vector signatures = {s1, s2, s3}; for (int i = 0; i < signatures.size(); ++i) { for (int j = 0; j < signatures.size(); ++j) { EXPECT_EQ(i == j, signatures[i] == signatures[j]) @@ -67,7 +69,7 @@ TEST(XlaCompilationCacheTest, SignatureEquality) { } } -TEST(XlaCompilationCacheTest, SignatureUniqueness) { +TEST(DeviceCompilationClusterSignatureTest, SignatureUniqueness) { NameAttrList fn; fn.set_name("afunction"); std::vector args(2); @@ -79,13 +81,13 @@ TEST(XlaCompilationCacheTest, SignatureUniqueness) { args[1].type = DT_INT32; args[1].shape = TensorShape({4, 0}); - TF_ASSERT_OK_AND_ASSIGN(XlaCompilationCache::Signature s1, - XlaCompilationCache::BuildSignature(fn, args)); + TF_ASSERT_OK_AND_ASSIGN(DeviceCompilationClusterSignature s1, + DeviceCompilationClusterSignature::Build(fn, args)); using std::swap; // go/using-std-swap swap(args[0], args[1]); - TF_ASSERT_OK_AND_ASSIGN(XlaCompilationCache::Signature s2, - XlaCompilationCache::BuildSignature(fn, args)); + TF_ASSERT_OK_AND_ASSIGN(DeviceCompilationClusterSignature s2, + DeviceCompilationClusterSignature::Build(fn, args)); EXPECT_NE(s1.HumanString(), s2.HumanString()); EXPECT_NE(SignatureHash()(s1), SignatureHash()(s2)); @@ -110,10 +112,9 @@ void BM_BuildSignature(::testing::benchmark::State& state) { } for (auto i : state) { - StatusOr s = - XlaCompilationCache::BuildSignature(fn, args); + auto s = DeviceCompilationClusterSignature::Build(fn, args); CHECK(s.ok()); - XlaCompilationCache::Signature sig = std::move(s.value()); + DeviceCompilationClusterSignature sig = std::move(s.value()); } } BENCHMARK(BM_BuildSignature)->Arg(0)->Arg(1)->Arg(2)->Arg(5)->Arg(10); diff --git a/tensorflow/compiler/jit/device_compilation_profiler.cc b/tensorflow/compiler/jit/device_compilation_profiler.cc new file mode 100644 index 00000000000..3ff54b6ff35 --- /dev/null +++ b/tensorflow/compiler/jit/device_compilation_profiler.cc @@ -0,0 +1,229 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/device_compilation_profiler.h" + +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/jit/xla_activity.pb.h" +#include "tensorflow/compiler/jit/xla_activity_listener.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/metrics.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/tsl/platform/mutex.h" + +namespace tensorflow { +namespace { +bool ShouldBeMegamorphic(int64_t compile_count, int64_t execution_count) { + const int64_t kCompileThreshold = 10; + const int64_t kMinExecutionsPerCompile = 50; + + // This heuristic is trying to capture the following property: have we sunk a + // certain minimum amount of compile time into the cluster that didn't quite + // "pay off"? + return compile_count > kCompileThreshold && + execution_count < kMinExecutionsPerCompile * compile_count; +} + +void RegisterExecutionForCluster( + const NameAttrList& function, + DeviceCompilationProfiler::ClusterCompileStats* stats) { + ++stats->execution_count; + + // The is_megamorphic bit is "sticky". We assume clusters that have been + // observed to be megamorphic once stay megamorphic forever. + if (!stats->is_megamorphic && + ShouldBeMegamorphic(stats->compile_count, stats->execution_count)) { + VLOG(1) << "Marking " << function.name() + << " as megamorphic, compile_count=" << stats->compile_count + << " execution_count=" << stats->execution_count; + stats->is_megamorphic = true; + } +} + +// The number of times a lazy compilation must be requested for a specific +// signature before we attempt to compile it. +constexpr int64_t kDefaultCompilationThreshold = 2; + +// Maximum number of ongoing compilations. +constexpr int64_t kMaxNumOngoingCompilations = kNumAsyncDeviceCompilerThreads; + +} // namespace + +DeviceCompilationProfiler::~DeviceCompilationProfiler() { + mutex_lock lock(mu_); + cluster_compile_stats_.clear(); +} + +StatusOr +DeviceCompilationProfiler::GetCompileStats(const NameAttrList& function) const { + mutex_lock lock(mu_); + + if (auto it = cluster_compile_stats_.find(function.name()); + it != cluster_compile_stats_.end()) { + return it->second; + } + + return errors::NotFound("Couldn't find compilation stats for cluster: ", + function.name()); +} + +void DeviceCompilationProfiler::RegisterExecution( + const NameAttrList& function) { + mutex_lock lock(mu_); + auto it = + cluster_compile_stats_.emplace(function.name(), ClusterCompileStats{}) + .first; + RegisterExecutionForCluster(function, &it->second); +} + +Status DeviceCompilationProfiler::RegisterCompilation( + const NameAttrList& function, int64_t compile_time_us, + bool used_persistent_cache) { + metrics::UpdateXlaCompilationTime(compile_time_us); + + const std::string& function_name = function.name(); + + mutex_lock lock(mu_); + // Create a stats entry if it doesn't already exist. 
+ auto it = + cluster_compile_stats_.emplace(function.name(), ClusterCompileStats{}) + .first; + + const uint64 compile_time_s = compile_time_us / 1.0e6; + it->second.compile_count++; + it->second.cumulative_compile_time_us += compile_time_us; + VLOG(1) << "Compiled " << function_name << " " << it->second.compile_count + << " times, compile time: " << compile_time_us + << " us, cumulative: " << it->second.cumulative_compile_time_us + << " us (" + << tensorflow::strings::HumanReadableElapsedTime(compile_time_s) + << " / " + << tensorflow::strings::HumanReadableElapsedTime( + it->second.cumulative_compile_time_us / 1.0e6) + << ")"; + + XlaJitCompilationActivity jit_compilation_activity; + jit_compilation_activity.set_cluster_name(function_name); + jit_compilation_activity.set_compile_count(it->second.compile_count); + jit_compilation_activity.set_compile_time_us(compile_time_us); + jit_compilation_activity.set_cumulative_compile_time_us( + it->second.cumulative_compile_time_us); + jit_compilation_activity.set_used_persistent_cache(used_persistent_cache); + return BroadcastXlaActivity(std::move(jit_compilation_activity)); +} + +bool DeviceCompilationProfiler::ShouldCompileCluster( + const NameAttrList& function, DeviceCompileMode compile_mode, + int64_t current_request_count) { + std::optional compile_threshold; + if (compile_mode == DeviceCompileMode::kLazy) { + compile_threshold = kDefaultCompilationThreshold; + } else if (compile_mode == DeviceCompileMode::kAsync) { + compile_threshold = 0; // for now, always compile right away. + } + + if (compile_mode == DeviceCompileMode::kStrict) { + // Lazy compilation is disabled. + return true; + } + + mutex_lock lock(mu_); + // Create a stats entry if one isn't found and register an execution. + // Determine eligibility assuming this is the first execution of the cluster + // and this cluster has never been compiled before. + auto [it, cluster_not_found] = + cluster_compile_stats_.emplace(function.name(), ClusterCompileStats{}); + if (cluster_not_found) { + RegisterExecutionForCluster(function, &it->second); + } + + // We avoid compiling clusters that have "gone megamorphic" i.e. have an + // excessive amount of shape dynamism. + if (it->second.is_megamorphic) { + BroadcastOptimizationRemark(XlaOptimizationRemark::MEGAMORPHIC_FUNCTION, + function.name()) + .IgnoreError(); + VLOG(2) << "Not compiling cluster " << function.name() + << " because it is megamorphic."; + return false; + } + + // We always compile a cluster the very first time it is executed. This is an + // optimistic guess that pays off for statically shaped TensorFlow graphs + // (since they get the benefit of XLA right away without waiting for warmup) + // and doesn't hurt much for dynamically shaped TensorFlow graphs (we "pay" at + // most one cluster-compilation's worth of compile time). + if (it->second.execution_count == 1) { + return true; + } + + if (compile_mode == DeviceCompileMode::kAsync) { + // Asynchronous compilation is enabled. 
+ if (num_ongoing_compilations_ >= kMaxNumOngoingCompilations) { + VLOG(2) << "Not asynchronously compiling cluster " << function.name() + << " because of too many ongoing compilations."; + return false; + } + } + + bool reached_compile_threshold = current_request_count >= *compile_threshold; + if (!reached_compile_threshold) { + VLOG(2) << "Not compiling cluster " << function.name() + << " because it has not reached compile threshold; threshold is " + << *compile_threshold << " execution count " + << current_request_count << "."; + } + return reached_compile_threshold; +} + +void DeviceCompilationProfiler::IncrementOngoingAsyncCompilations() { + mutex_lock lock(mu_); + num_ongoing_compilations_++; +} + +void DeviceCompilationProfiler::DecrementOngoingAsyncCompilations() { + mutex_lock lock(mu_); + num_ongoing_compilations_--; +} + +int64_t DeviceCompilationProfiler::GetNumOngoingAsyncCompilations() const { + mutex_lock lock(mu_); + return num_ongoing_compilations_; +} + +std::string DeviceCompilationProfiler::DebugString() const { + std::string debug_string = + "DeviceCompilationProfiler {\ncluster_compile_stats: {\n"; + { + mutex_lock lock(mu_); + + for (const auto& [key, stats] : cluster_compile_stats_) { + absl::StrAppend(&debug_string, key, ": ", stats.DebugString(), "\n"); + } + } + + absl::StrAppend(&debug_string, "}\nnum_ongoing_compilations=", + GetNumOngoingAsyncCompilations(), "\n}\n"); + + return debug_string; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/device_compilation_profiler.h b/tensorflow/compiler/jit/device_compilation_profiler.h new file mode 100644 index 00000000000..0f7fdf568bf --- /dev/null +++ b/tensorflow/compiler/jit/device_compilation_profiler.h @@ -0,0 +1,100 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_DEVICE_COMPILATION_PROFILER_H_ +#define TENSORFLOW_COMPILER_JIT_DEVICE_COMPILATION_PROFILER_H_ + +#include +#include + +#include "tensorflow/compiler/jit/xla_compile_util.h" +#include "tensorflow/core/framework/attr_value.pb.h" + +namespace tensorflow { + +// Tracks statistics for device compilation and uses these to determine whether +// the given cluster should be compiled or not. +class DeviceCompilationProfiler : public ResourceBase { + public: + DeviceCompilationProfiler() = default; + ~DeviceCompilationProfiler() final; + + struct ClusterCompileStats { + // Number of times the cluster has been (re-)compiled. + int64_t compile_count = 0; + + // The number of times this cluster has been executed. + int64_t execution_count = 0; + + // Cumulative time spent compiling the cluster. + int64_t cumulative_compile_time_us = 0; + + // True if we have decided that this cluster is too dynamic (i.e. its shapes + // change too frequently) to profitably JIT compile. Once a cluster is + // tagged megamorphic, it stays megamorphic forever. 
+ bool is_megamorphic = false; + + std::string DebugString() const { + return absl::StrCat( + "DeviceCompilationProfiler::ClusterCompileStats {compile_count=", + compile_count, ", execution_count=", execution_count, + ", cumulative_compile_time_us=", cumulative_compile_time_us, + ", is_megamorphic=", is_megamorphic, "}"); + } + }; + + // Returns the compilation statistics for the given cluster. + StatusOr GetCompileStats( + const NameAttrList& function) const; + + // Determines whether the cluster should be compiled. Creates and inserts an + // entry into stats (also calls `RegisterExecution`) for `function` if it + // doesn't already exist. + bool ShouldCompileCluster(const NameAttrList& function, + DeviceCompileMode compile_mode, + int64_t current_request_count); + + // Registers a cluster execution. Increments the execution count for the given + // cluster and also determines whether the cluster has gone megamorphic (and + // sets the megamorphic bit accordingly). + void RegisterExecution(const NameAttrList& function); + + // Registers a cluster compilation. Increments the compilation count and + // accumulates the compile time for the given cluster. Also broadcasts an + // XlaJitCompilationActivity. + Status RegisterCompilation(const NameAttrList& function, + int64_t compile_time_us, + bool used_persistent_cache); + + void IncrementOngoingAsyncCompilations(); + void DecrementOngoingAsyncCompilations(); + int64_t GetNumOngoingAsyncCompilations() const; + std::string DebugString() const override; + + private: + mutable mutex mu_; + + // Maps cluster names to compilation statistics for said cluster. + absl::flat_hash_map cluster_compile_stats_ + TF_GUARDED_BY(mu_); + + int64_t num_ongoing_compilations_ TF_GUARDED_BY(mu_) = 0; + + TF_DISALLOW_COPY_AND_ASSIGN(DeviceCompilationProfiler); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_DEVICE_COMPILATION_PROFILER_H_ diff --git a/tensorflow/compiler/jit/device_compilation_profiler_test.cc b/tensorflow/compiler/jit/device_compilation_profiler_test.cc new file mode 100644 index 00000000000..317858e84b7 --- /dev/null +++ b/tensorflow/compiler/jit/device_compilation_profiler_test.cc @@ -0,0 +1,243 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/device_compilation_profiler.h" + +#include +#include +#include + +#include +#include +#include "tensorflow/compiler/jit/tests/device_compiler_test_helper.h" +#include "tensorflow/compiler/jit/xla_activity.pb.h" +#include "tensorflow/core/framework/attr_value.pb.h" + +namespace tensorflow { +namespace { + +TEST(DeviceCompilationProfilerTest, RegisterExecution) { + DeviceCompilationProfiler* profiler = new DeviceCompilationProfiler(); + core::ScopedUnref profiler_ref(profiler); + + NameAttrList function; + function.set_name("TestFunc"); + + for (int i = 0; i < 5; ++i) { + profiler->RegisterExecution(function); + } + TF_ASSERT_OK_AND_ASSIGN(auto stats, profiler->GetCompileStats(function)); + EXPECT_EQ(stats.execution_count, 5); +} + +TEST(DeviceCompilationProfilerTest, RegisterCompilation) { + DeviceCompilationProfiler* profiler = new DeviceCompilationProfiler(); + core::ScopedUnref profiler_ref(profiler); + + auto listener = std::make_unique(); + auto listener_ptr = listener.get(); + RegisterXlaActivityListener(std::move(listener)); + + NameAttrList function; + function.set_name("TestFunc"); + + std::vector expected_activities; + for (int i = 0; i < 5; ++i) { + EXPECT_TRUE(profiler->RegisterCompilation(function, 4, false).ok()); + + TF_ASSERT_OK_AND_ASSIGN(auto stats, profiler->GetCompileStats(function)); + XlaJitCompilationActivity expected_activity; + expected_activity.set_cluster_name(function.name()); + expected_activity.set_compile_count(stats.compile_count); + expected_activity.set_compile_time_us(4); + expected_activity.set_cumulative_compile_time_us( + stats.cumulative_compile_time_us); + expected_activity.set_used_persistent_cache(false); + expected_activities.push_back(expected_activity); + } + + TF_ASSERT_OK_AND_ASSIGN(auto stats, profiler->GetCompileStats(function)); + EXPECT_EQ(stats.compile_count, 5); + EXPECT_EQ(stats.cumulative_compile_time_us, 5 * 4); + + // TODO(b/255826209): Use ::testing::EqualsProto once b/135192747 is fixed. 
+ const auto& actual_activities = listener_ptr->GetListenerHistory(); + EXPECT_EQ(actual_activities.size(), expected_activities.size()); + for (size_t i = 0; i < actual_activities.size(); ++i) { + EXPECT_EQ(actual_activities[i].SerializeAsString(), + expected_activities[i].SerializeAsString()); + } +} + +TEST(DeviceCompilationProfilerTest, OngoingAsyncCompilations) { + DeviceCompilationProfiler* profiler = new DeviceCompilationProfiler(); + core::ScopedUnref profiler_ref(profiler); + + for (int i = 0; i < 5; ++i) { + profiler->IncrementOngoingAsyncCompilations(); + } + + EXPECT_EQ(profiler->GetNumOngoingAsyncCompilations(), 5); + + for (int i = 0; i < 5; ++i) { + profiler->DecrementOngoingAsyncCompilations(); + } + + EXPECT_EQ(profiler->GetNumOngoingAsyncCompilations(), 0); + + for (int i = 0; i < 5; ++i) { + profiler->IncrementOngoingAsyncCompilations(); + profiler->DecrementOngoingAsyncCompilations(); + } + + EXPECT_EQ(profiler->GetNumOngoingAsyncCompilations(), 0); +} + +TEST(DeviceCompilationProfilerTest, ShouldCompileClusterNotFound) { + DeviceCompilationProfiler* profiler = new DeviceCompilationProfiler(); + core::ScopedUnref profiler_ref(profiler); + + NameAttrList function; + function.set_name("TestFunc"); + + EXPECT_TRUE( + profiler->ShouldCompileCluster(function, DeviceCompileMode::kAsync, 0)); + EXPECT_TRUE( + profiler->ShouldCompileCluster(function, DeviceCompileMode::kLazy, 0)); + EXPECT_TRUE( + profiler->ShouldCompileCluster(function, DeviceCompileMode::kStrict, 0)); +} + +TEST(DeviceCompilationProfilerTest, ShouldCompileClusterFirstExecution) { + DeviceCompilationProfiler* profiler = new DeviceCompilationProfiler(); + core::ScopedUnref profiler_ref(profiler); + + NameAttrList function; + function.set_name("TestFunc"); + + profiler->RegisterExecution(function); + + EXPECT_TRUE( + profiler->ShouldCompileCluster(function, DeviceCompileMode::kAsync, 0)); + EXPECT_TRUE( + profiler->ShouldCompileCluster(function, DeviceCompileMode::kLazy, 0)); +} + +TEST(DeviceCompilationProfilerTest, ShouldCompileClusterMegamorphic) { + DeviceCompilationProfiler* profiler = new DeviceCompilationProfiler(); + core::ScopedUnref profiler_ref(profiler); + + NameAttrList function; + function.set_name("TestFunc"); + + const int64_t kCompileThreshold = 10; + const int64_t kMinExecutionsPerCompile = 50; + + // Register compilation enough times (without registering executions enough + // times) so that the function is marked megamorphic. + for (int i = 0; i < kCompileThreshold + 1; ++i) { + EXPECT_TRUE(profiler->RegisterCompilation(function, 1, false).ok()); + } + profiler->RegisterExecution(function); + + // Shouldn't compile cluster since it has gone megamorphic. + EXPECT_FALSE( + profiler->ShouldCompileCluster(function, DeviceCompileMode::kAsync, 0)); + EXPECT_FALSE( + profiler->ShouldCompileCluster(function, DeviceCompileMode::kLazy, 0)); + TF_ASSERT_OK_AND_ASSIGN(auto stats, profiler->GetCompileStats(function)); + EXPECT_TRUE(stats.is_megamorphic); + + // Always compile for strict compile mode. + EXPECT_TRUE( + profiler->ShouldCompileCluster(function, DeviceCompileMode::kStrict, 0)); + + // Once a cluster has gone megamorphic, it remains megamorphic (even though + // it's being executed more frequently now) and shouldn't be compiled again. 
+ for (int i = 0; i < kCompileThreshold * kMinExecutionsPerCompile + 1; ++i) { + profiler->RegisterExecution(function); + } + + EXPECT_FALSE( + profiler->ShouldCompileCluster(function, DeviceCompileMode::kAsync, 0)); + EXPECT_FALSE( + profiler->ShouldCompileCluster(function, DeviceCompileMode::kLazy, 0)); + TF_ASSERT_OK_AND_ASSIGN(stats, profiler->GetCompileStats(function)); + EXPECT_TRUE(stats.is_megamorphic); + + // Always compile for strict compile mode. + EXPECT_TRUE( + profiler->ShouldCompileCluster(function, DeviceCompileMode::kStrict, 0)); +} + +TEST(DeviceCompilationProfilerTest, ShouldCompileClusterAsync) { + DeviceCompilationProfiler* profiler = new DeviceCompilationProfiler(); + core::ScopedUnref profiler_ref(profiler); + + NameAttrList function; + function.set_name("TestFunc"); + + const int64_t kMaxNumOngoingCompilations = 10; + for (int i = 0; i < kMaxNumOngoingCompilations; ++i) { + profiler->IncrementOngoingAsyncCompilations(); + } + + // Should allow compilation since this is the first execution. + profiler->RegisterExecution(function); + EXPECT_TRUE( + profiler->ShouldCompileCluster(function, DeviceCompileMode::kAsync, 0)); + + // Should not allow compilation since this is not the first execution and + // we've already reached the maximum number of ongoing compilations allowed. + profiler->RegisterExecution(function); + EXPECT_FALSE( + profiler->ShouldCompileCluster(function, DeviceCompileMode::kAsync, 0)); + + profiler->DecrementOngoingAsyncCompilations(); + // Should allow compilation since we've decremented the number of ongoing + // compilations. + EXPECT_TRUE( + profiler->ShouldCompileCluster(function, DeviceCompileMode::kAsync, 0)); +} + +TEST(DeviceCompilationProfilerTest, ShouldCompileClusterLazy) { + DeviceCompilationProfiler* profiler = new DeviceCompilationProfiler(); + core::ScopedUnref profiler_ref(profiler); + + NameAttrList function; + function.set_name("TestFunc"); + + constexpr int64_t kDefaultCompilationThreshold = 2; + + // Should allow compilation since this is the first execution. + profiler->RegisterExecution(function); + EXPECT_TRUE( + profiler->ShouldCompileCluster(function, DeviceCompileMode::kLazy, 0)); + + // Shouldn't allow compilation until compilation has been requested at least + // kDefaultCompilationThreshold times. + profiler->RegisterExecution(function); + for (int current_request_count = 0; + current_request_count < kDefaultCompilationThreshold; + ++current_request_count) { + EXPECT_FALSE(profiler->ShouldCompileCluster( + function, DeviceCompileMode::kLazy, current_request_count)); + } + EXPECT_TRUE(profiler->ShouldCompileCluster(function, DeviceCompileMode::kLazy, + kDefaultCompilationThreshold)); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/device_compiler.h b/tensorflow/compiler/jit/device_compiler.h new file mode 100644 index 00000000000..e942169c519 --- /dev/null +++ b/tensorflow/compiler/jit/device_compiler.h @@ -0,0 +1,492 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_DEVICE_COMPILER_H_ +#define TENSORFLOW_COMPILER_JIT_DEVICE_COMPILER_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/call_once.h" +#include "absl/container/flat_hash_map.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/jit/device_compilation_cache.h" +#include "tensorflow/compiler/jit/device_compilation_cluster_signature.h" +#include "tensorflow/compiler/jit/device_compilation_profiler.h" +#include "tensorflow/compiler/jit/device_compiler_client.h" +#include "tensorflow/compiler/jit/device_executable_persistor.h" +#include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/compiler/jit/tf_graph_to_hlo_compiler.h" +#include "tensorflow/compiler/jit/xla_compile_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace tensorflow { + +// Compiles/lowers a given Tensorflow graph/function/cluster into a compiled XLA +// compilation (HLO) using the XlaCompiler and compiles the resulting +// XlaCompilationResult into an `ExecutableType` (eg. xla::LocalExecutable) by +// calling `ClientType` (eg. xla::LocalClient). +// +// Caches the compiled XlaCompilationResult and Executable using a +// DeviceCompilationCache. Compilation is done only when there's a cache miss. +// +// Uses the DeviceExecutablePersistor class for persistence and tries to load a +// serialized executable from disk upon a request for compilation. If the +// appropriate executable isn't found on disk, compiles the given Tensorflow +// function/graph/cluster into an XlaCompilationResult (HLO) and +// `ExecutableType` and tries saving/persisting the compiled HLO and executable +// to disk. +// +// Since XLA computations must have static shapes, DeviceCompiler generates a +// new XLA computation for each new set of input shapes. +// TODO(b/255826209): De-templatize once we've moved to Device API completely. +template +class DeviceCompiler : public ResourceBase { + public: + DeviceCompiler( + std::unique_ptr> + persistor, + std::unique_ptr> + compiler_client); + ~DeviceCompiler() override; + + enum class CompileScope { + kOp, + kFunction, + }; + + // Compiles a function into a XlaCompiler::CompilationResult that can be used + // to execute an XLA Computation. Compilation results are cached. Compilation + // is skipped if there is a cache hit. `function` is the name of a Tensorflow + // function to compile. `args` is a description of the arguments to the + // computation. + // + // `compile_mode` controls the behavior of the compilation cache on a cache + // miss. If `compile_mode` is `kLazy` then, based on some profitability + // heuristics, the compilation cache may decide not to compile the cluster at + // this time. In this case it returns null into both `out_compilation_result` + // and `out_executable`. If `compile_mode` is `kStrict` then the compilation + // cache always attempts the compilation on a cache miss. If compilation mode + // is 'kAsync' compilation of the cluster happens in the background while the + // fallback path executes. 
+ // + // The result of compilation is written to `*out_compilation_result`, which + // must be non-null. If `out_executable` is non-null, also builds an + // `ExecutableType` and sets `out_executable` to point to it. The + // resulting executable pointer may be null if the computation has no + // non-constant outputs. + Status CompileIfNeeded( + const XlaCompiler::Options& options, const NameAttrList& function, + const std::vector& args, + const XlaCompiler::CompileOptions& compile_options, + DeviceCompileMode compile_mode, DeviceCompilationProfiler* profiler, + const XlaCompiler::CompilationResult** out_compilation_result, + ExecutableType** out_executable); + + // As above, but for a single op. + Status CompileSingleOpIfNeeded( + const XlaCompiler::Options& options, + const std::vector& args, + const XlaCompiler::CompileOptions& compile_options, OpKernelContext* ctx, + DeviceCompilationProfiler* profiler, + const XlaCompiler::CompilationResult** out_compilation_result, + ExecutableType** out_executable); + + ClientType* client() const { return compiler_client_->client(); } + const DeviceType& device_type() const { return persistor_->device_type(); } + DeviceCompilationCache* cache() { return cache_.get(); } + + string DebugString() const override; + + private: + // Common implementation of Compile and CompileSingleOp. The `OpKernelContext` + // parameter is always null for the former. + Status CompileImpl( + const XlaCompiler::CompileOptions& compile_options, + const XlaCompiler::Options& options, const NameAttrList& function, + const std::vector& args, CompileScope scope, + DeviceCompileMode compile_mode, OpKernelContext* ctx, + DeviceCompilationProfiler* profiler, + const XlaCompiler::CompilationResult** out_compilation_result, + ExecutableType** out_executable); + + StatusOr::Value> + CompileStrict( + const DeviceCompilationClusterSignature& sig, + const XlaCompiler::CompileOptions& compile_options, + const XlaCompiler::Options& options, + const std::vector& args, + const NameAttrList& function, + typename DeviceCompilationCache::Value cache_value, + CompileScope scope, OpKernelContext* ctx, + DeviceCompilationProfiler* profiler, mutex* mu) + TF_EXCLUSIVE_LOCKS_REQUIRED(*mu); + + Status CompileAsynchronous(const DeviceCompilationClusterSignature& sig, + const XlaCompiler::CompileOptions& compile_options, + const XlaCompiler::Options& options, + const std::vector& args, + const NameAttrList& function, CompileScope scope, + OpKernelContext* ctx, + DeviceCompilationProfiler* profiler); + + std::unique_ptr> + persistor_; + std::unique_ptr> + compiler_client_; + std::unique_ptr> cache_; + + // Pool of threads for asynchronous compilations. + std::unique_ptr async_compiler_threads_; + + mutex cluster_mutexes_mu_; + absl::flat_hash_map, + DeviceCompilationClusterSignature::Hash> + cluster_mutexes_ TF_GUARDED_BY(cluster_mutexes_mu_); + + TF_DISALLOW_COPY_AND_ASSIGN(DeviceCompiler); +}; + +namespace device_compiler_internal { +// Print something that users can search for to definitively ascertain that XLA +// was used for their TF model. +// Prints only once to avoid spamming LOG(INFO). +inline void LogOnceXlaCompiledFirstCluster() { + static absl::once_flag log_once; + absl::call_once(log_once, [] { + LOG(INFO) << "Compiled cluster using XLA! 
This line is logged at most " + "once for the lifetime of the process."; + }); +} + +inline Status EligibleToPersist(DeviceCompileState compile_state, + const xla::LocalExecutable* executable) { + if (compile_state != DeviceCompileState::kCompiled) { + return errors::FailedPrecondition( + "Cache entry to serialize is not compiled."); + } + if (executable == nullptr) { + return errors::FailedPrecondition( + "LocalExecutable not found for cache entry to serialize."); + } + return OkStatus(); +} +} // namespace device_compiler_internal + +template +DeviceCompiler::DeviceCompiler( + std::unique_ptr> + persistor, + std::unique_ptr> + compiler_client) + : persistor_(std::move(persistor)), + compiler_client_(std::move(compiler_client)) { + cache_ = std::make_unique>(); + async_compiler_threads_ = std::make_unique( + tensorflow::Env::Default(), "async_compiler_threads", + kNumAsyncDeviceCompilerThreads); +} + +template +DeviceCompiler::~DeviceCompiler() { + // Since programs are owned by the cache, ensure any use of our programs have + // completed by waiting for all stream executors to complete. + compiler_client_->WaitForProgramsToFinish(); + // Wait for all outstanding compilations to finish. + // Resetting the pointer explicitly in the top level destructor. + // Without this, the pointer would be reset when the AsyncCompilationState + // is destructed, which is dependent on the order of the members in the + // DeviceCompiler class, which is error prone if the order changes. + async_compiler_threads_.reset(); + // TODO(b/110813685): Think about the program ownership model. Programs are + // currently owned by the compilation cache which means we must wait for + // program completion in the destructor. There are multiple compilation caches + // around, which complicates things a little. Perhaps having programs be + // shared_ptrs (an invasive change) would make the model easier to reason + // about? +} + +template +string DeviceCompiler::DebugString() const { + return "DeviceCompiler"; +} + +template +Status DeviceCompiler::CompileIfNeeded( + const XlaCompiler::Options& options, const NameAttrList& function, + const std::vector& args, + const XlaCompiler::CompileOptions& compile_options, + DeviceCompileMode compile_mode, DeviceCompilationProfiler* profiler, + const XlaCompiler::CompilationResult** out_compilation_result, + ExecutableType** out_executable) { + return CompileImpl(compile_options, options, function, args, + CompileScope::kFunction, compile_mode, /*ctx=*/nullptr, + profiler, out_compilation_result, out_executable); +} + +template +Status DeviceCompiler::CompileSingleOpIfNeeded( + const XlaCompiler::Options& options, + const std::vector& args, + const XlaCompiler::CompileOptions& compile_options, OpKernelContext* ctx, + DeviceCompilationProfiler* profiler, + const XlaCompiler::CompilationResult** out_compilation_result, + ExecutableType** out_executable) { + const NodeDef& def = ctx->op_kernel().def(); + NameAttrList name; + name.set_name(def.op()); + *name.mutable_attr() = def.attr(); + // Remove the "_class" attribute from the attribute set used to create the + // compilation cache key. This attribute is information for the colocator + // and causes false uniqueness between nodes. 
+ name.mutable_attr()->erase("_class"); + return CompileImpl(compile_options, options, name, args, CompileScope::kOp, + DeviceCompileMode::kStrict, ctx, profiler, + out_compilation_result, out_executable); +} + +template +StatusOr::Value> +DeviceCompiler::CompileStrict( + const DeviceCompilationClusterSignature& sig, + const XlaCompiler::CompileOptions& compile_options, + const XlaCompiler::Options& options, + const std::vector& args, + const NameAttrList& function, + typename DeviceCompilationCache::Value cache_value, + CompileScope scope, OpKernelContext* ctx, + DeviceCompilationProfiler* profiler, mutex* mu) { + tensorflow::Env* env = tensorflow::Env::Default(); + const uint64 compile_start_us = env->NowMicros(); + + TfGraphToHloCompiler compiler(options); + cache_value.compile_state = DeviceCompileState::kCompiled; + + std::unique_ptr out_executable; + auto out_compilation_result = + std::make_unique(); + + if (scope == CompileScope::kOp) { + cache_value.compilation_status = compiler.CompileSingleOp( + compile_options, ctx, args, out_compilation_result.get()); + } else { + CHECK(scope == CompileScope::kFunction); // Crash OK + cache_value.compilation_status = compiler.Compile( + compile_options, function, args, out_compilation_result.get()); + } + TF_RETURN_IF_ERROR(cache_value.compilation_status); + TF_RET_CHECK(cache_value.executable == nullptr); + TF_RET_CHECK(out_compilation_result->computation != nullptr); + + auto loaded_executable = persistor_->TryToLoadExecutable( + DeviceCompilationClusterSignature::Hash()(sig), sig.HumanString(), + options, *out_compilation_result, compiler_client_.get()); + + if (loaded_executable.has_value()) { + cache_value.compilation_status = loaded_executable->status(); + if (loaded_executable->ok()) { + out_executable = *std::move(*loaded_executable); + } + } else { + auto built_executable = + compiler_client_->BuildExecutable(options, *out_compilation_result); + TF_RETURN_IF_ERROR(built_executable.status()); + out_executable = *std::move(built_executable); + + TF_RETURN_IF_ERROR(device_compiler_internal::EligibleToPersist( + cache_value.compile_state, out_executable.get())); + TF_RETURN_IF_ERROR(persistor_->TryToPersistExecutable( + DeviceCompilationClusterSignature::Hash()(sig), sig.HumanString(), + options, *out_compilation_result, *out_executable, + compiler_client_.get())); + } + + cache_value.compilation_result = out_compilation_result.get(); + cache_value.executable = out_executable.get(); + cache_->Store(sig, cache_value.compile_state, cache_value.compilation_status, + std::move(out_compilation_result), std::move(out_executable)); + + const uint64 compile_end_us = env->NowMicros(); + const uint64 compile_time_us = compile_end_us - compile_start_us; + + device_compiler_internal::LogOnceXlaCompiledFirstCluster(); + TF_RETURN_IF_ERROR(profiler->RegisterCompilation( + function, compile_time_us, loaded_executable.has_value())); + return cache_value; +} + +template +Status DeviceCompiler::CompileAsynchronous( + const DeviceCompilationClusterSignature& signature, + const XlaCompiler::CompileOptions& compile_options, + const XlaCompiler::Options& options, + const std::vector& args, + const NameAttrList& function, CompileScope scope, OpKernelContext* ctx, + DeviceCompilationProfiler* profiler) { + // Explicitly capture all required data by value for async compilation. + // Update compilation state in cache. 
+ cache_->Store(signature, DeviceCompileState::kCompiling, std::nullopt, + std::nullopt, std::nullopt); + profiler->IncrementOngoingAsyncCompilations(); + // Don't move the above code into the thread function as it synchronously + // updates the async compilation state! + + // When the ThreadPool for the compilation cache is destroyed, it waits for + // compilations to have finished. This means that both 'entry' and 'this' will + // be alive for the duration of the compilation. + // !!Pay attention when additional variables must be captured by this lambda!! + // All values are captured by value. Make sure that all pointer values (like + // entry) do not get freed until the lambda has finished. + const std::string& function_name = function.name(); + async_compiler_threads_->Schedule([=] { + VLOG(2) << "Starting asynchronous compilation of cluster " << function_name + << '.'; + // We don't need to lock mu, but do it anyway to satisfy thread safety + // analysis. + mutex mu; + mutex_lock lock(mu); + auto cache_value = typename DeviceCompilationCache::Value(); + auto s = CompileStrict(signature, compile_options, options, args, function, + cache_value, scope, ctx, profiler, &mu); + VLOG(2) << "Finished asynchronous compililation of cluster " + << function_name << '.'; + profiler->DecrementOngoingAsyncCompilations(); + // Update compilation status in cache. + if (!s.ok()) { + cache_->Store(signature, std::nullopt, s.status(), std::nullopt, + std::nullopt); + } + }); + return OkStatus(); +} + +template +Status DeviceCompiler::CompileImpl( + const XlaCompiler::CompileOptions& compile_options, + const XlaCompiler::Options& options, const NameAttrList& function, + const std::vector& args, CompileScope scope, + DeviceCompileMode compile_mode, OpKernelContext* ctx, + DeviceCompilationProfiler* profiler, + const XlaCompiler::CompilationResult** out_compilation_result, + ExecutableType** out_executable) { + DCHECK_NE(out_executable, nullptr); + VLOG(2) << "DeviceCompiler::Compile " << DebugString(); + + if (VLOG_IS_ON(2)) { + VLOG(2) << "num_inputs=" << args.size(); + for (int i = 0, end = args.size(); i < end; i++) { + VLOG(3) << i << ": " << args[i].HumanString(); + } + } + TF_ASSIGN_OR_RETURN(auto signature, + DeviceCompilationClusterSignature::Build(function, args)); + + // The outer lock protects the existence of the mutex in the map. + mutex* cluster_mutex; + { + mutex_lock lock(cluster_mutexes_mu_); + auto it = + cluster_mutexes_.emplace(signature, std::make_unique()).first; + cluster_mutex = it->second.get(); + } + + profiler->RegisterExecution(function); + + string human_signature; + if (VLOG_IS_ON(2)) { + human_signature = VLOG_IS_ON(3) ? signature.HumanString() : function.name(); + VLOG(2) << "DeviceCompilationClusterSignature: " << human_signature; + } + + // Acquire the cache entry lock and compile, if necessary. + // TODO(phawkins): this locking will need to be restructured when we implement + // cache eviction. + mutex_lock cluster_compile_lock(*cluster_mutex); + auto cache_value = cache_->LookupOrCreate(signature); + + int64_t current_request_count = cache_value.request_count; + VLOG(2) << "Compilation cache entry hit: " + << static_cast(cache_value.compile_state) + << " signature: " << human_signature << " with request count " + << current_request_count; + + DeviceCompileState state = cache_value.compile_state; + *out_compilation_result = nullptr; + *out_executable = nullptr; + + // Check if the requested entry is uncompiled and return an error if + // compilation is disabled. 
This will raise an error for kLazy even if we have + // not yet hit the compilation threshold and no compilation happens this + // round. This is to avoid non-determanism of when compilation is disallowed, + // for example by changing the threshold. + if (state == DeviceCompileState::kUncompiled && FailOnXlaCompilation()) { + VLOG(1) << "XLA compilation disabled: " << function.name() << "\n" + << absl::StrJoin( + args, "\n", + [](std::string* out, const XlaCompiler::Argument& arg) { + absl::StrAppend(out, " arg: ", arg.HumanString()); + }); + + return errors::Internal("XLA compilation disabled"); + } + + if (state == DeviceCompileState::kUncompiled) { + XLA_SCOPED_LOGGING_TIMER("Compilation of XLA executable"); + if (!profiler->ShouldCompileCluster(function, compile_mode, + current_request_count)) { + VLOG(2) << "Not compiling for signature: " << human_signature; + return OkStatus(); + } else if (compile_mode == DeviceCompileMode::kAsync) { + VLOG(2) << "Queueing asynchronous compilation for signature: " + << human_signature; + TF_RETURN_IF_ERROR(CompileAsynchronous(signature, compile_options, + options, args, function, scope, + ctx, profiler)); + return OkStatus(); + } else { + VLOG(2) << "Instantly compiling for signature: " << human_signature; + TF_ASSIGN_OR_RETURN( + cache_value, + CompileStrict(signature, compile_options, options, args, function, + cache_value, scope, ctx, profiler, cluster_mutex)); + } + } else if (state == DeviceCompileState::kCompiling) { + VLOG(2) << "Ongoing asynchronous compilation for signature: " + << human_signature; + return OkStatus(); + } else if (state == DeviceCompileState::kCompiled) { + VLOG(2) << "Already Compiled for signature: " << human_signature; + } + + TF_RETURN_IF_ERROR(cache_value.compilation_status); + *out_compilation_result = cache_value.compilation_result; + *out_executable = cache_value.executable; + return OkStatus(); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_DEVICE_COMPILER_H_ diff --git a/tensorflow/compiler/jit/device_compiler_client.cc b/tensorflow/compiler/jit/device_compiler_client.cc new file mode 100644 index 00000000000..e84906126b5 --- /dev/null +++ b/tensorflow/compiler/jit/device_compiler_client.cc @@ -0,0 +1,46 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/device_compiler_client.h" + +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/core/util/determinism.h" + +namespace tensorflow { + +xla::ExecutableBuildOptions GetExecutableBuildOptions( + const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& result, int default_device_ordinal) { + xla::ExecutableBuildOptions build_options; + if (result.collective_info) { + build_options.set_num_replicas(result.collective_info->group_size); + } + if (options.device_ordinal != -1) { + build_options.set_device_ordinal(options.device_ordinal); + } else if (default_device_ordinal != -1) { + build_options.set_device_ordinal(default_device_ordinal); + } + build_options.set_result_layout(result.xla_output_shape); + build_options.set_device_allocator(options.device_allocator.get()); + build_options.set_alias_passthrough_params(options.alias_passthrough_params); + build_options.mutable_debug_options()->set_xla_detailed_logging_and_dumping( + options.detailed_logging); + if (tensorflow::OpDeterminismRequired()) { + build_options.mutable_debug_options()->set_xla_gpu_deterministic_ops(true); + } + return build_options; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/device_compiler_client.h b/tensorflow/compiler/jit/device_compiler_client.h new file mode 100644 index 00000000000..44da761ee28 --- /dev/null +++ b/tensorflow/compiler/jit/device_compiler_client.h @@ -0,0 +1,75 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_DEVICE_COMPILER_CLIENT_H_ +#define TENSORFLOW_COMPILER_JIT_DEVICE_COMPILER_CLIENT_H_ + +#include +#include +#include + +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/xla/client/executable_build_options.h" + +namespace tensorflow { + +template +class DeviceCompilerClient { + public: + DeviceCompilerClient() = default; + virtual ~DeviceCompilerClient() = default; + + // Compiles `result` (HLO) to an `ExecutableType` using `ClientType` and + // returns it. + virtual StatusOr> BuildExecutable( + const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& result) = 0; + + // Serializes an available `executable` to string using `ClientType` and + // returns it. + virtual StatusOr SerializeExecutable( + const ExecutableType& executable) = 0; + + // Compiles `result` (HLO) to a serializable executable (eg. + // xla::AotCompilationResult) using `ClientType`, serializes it to string and + // returns it. + virtual StatusOr BuildSerializedExecutable( + const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& result) = 0; + + // Loads `serialized_executable` into an `ExecutableType` using `ClientType`. 
+ virtual StatusOr> LoadExecutable( + const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& result, + const std::string& serialized_executable) = 0; + + // Waits for the underlying `ClientType` backend's programs to finish + // executing before returning. + virtual void WaitForProgramsToFinish() = 0; + + virtual ClientType* client() const = 0; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(DeviceCompilerClient); +}; + +// Generates the ExecutableBuildOptions for compilation from HLO to +// executable. +xla::ExecutableBuildOptions GetExecutableBuildOptions( + const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& result, int default_device_ordinal); +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_DEVICE_COMPILER_CLIENT_H_ diff --git a/tensorflow/compiler/jit/device_compiler_client_test.cc b/tensorflow/compiler/jit/device_compiler_client_test.cc new file mode 100644 index 00000000000..104b0b0f651 --- /dev/null +++ b/tensorflow/compiler/jit/device_compiler_client_test.cc @@ -0,0 +1,64 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/device_compiler_client.h" + +#include + +namespace tensorflow { +namespace { + +TEST(GetExecutableOptionTest, Basic) { + XlaCompiler::Options options; + options.device_ordinal = 0; + options.alias_passthrough_params = true; + options.detailed_logging = true; + XlaCompiler::CompilationResult result; + xla::Shape xla_output_shape; + result.xla_output_shape = xla_output_shape; + + auto build_option = + GetExecutableBuildOptions(options, result, /*default_device_ordinal=*/-1); + + EXPECT_EQ(build_option.device_ordinal(), 0); + EXPECT_EQ(build_option.result_layout()->ToString(), + xla_output_shape.ToString()); + EXPECT_EQ(build_option.alias_passthrough_params(), true); + EXPECT_EQ(build_option.debug_options().xla_detailed_logging_and_dumping(), + true); +} + +TEST(GetExecutableOptionTest, DefaultDeviceOrdinal) { + XlaCompiler::Options options; + XlaCompiler::CompilationResult result; + + auto build_option = + GetExecutableBuildOptions(options, result, /*default_device_ordinal=*/0); + + EXPECT_EQ(build_option.device_ordinal(), 0); +} + +TEST(GetExecutableOptionTest, DeviceOrdinalNotSet) { + XlaCompiler::Options options; + XlaCompiler::CompilationResult result; + + auto build_option = + GetExecutableBuildOptions(options, result, /*default_device_ordinal=*/-1); + + EXPECT_EQ(build_option.device_ordinal(), -1); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compilation_cache_disable_test.cc b/tensorflow/compiler/jit/device_compiler_disable_test.cc similarity index 58% rename from tensorflow/compiler/jit/xla_compilation_cache_disable_test.cc rename to tensorflow/compiler/jit/device_compiler_disable_test.cc index 3114974bd00..cf4b5461861 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache_disable_test.cc +++ 
b/tensorflow/compiler/jit/device_compiler_disable_test.cc @@ -13,19 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include + +#include "tensorflow/compiler/jit/device_compilation_profiler.h" #include "tensorflow/compiler/jit/flags.h" -#include "tensorflow/compiler/jit/xla_compilation_cache.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/jit/device_compiler.h" +#include "tensorflow/compiler/jit/xla_device_compiler_client.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/test_benchmark.h" namespace tensorflow { namespace { -// This test is kept separate because it disables XLA compilation globaly. -TEST(XlaCompilationCacheTest, TestDisabledXlaCompilation) { +// This test is kept separate because it disables XLA compilation globally. +TEST(DeviceCompilerTest, TestDisabledXlaCompilation) { NameAttrList fn; fn.set_name("afunction"); @@ -46,33 +50,39 @@ TEST(XlaCompilationCacheTest, TestDisabledXlaCompilation) { const XlaCompiler::CompilationResult* compilation_result; xla::LocalExecutable* executable; - auto cache = new XlaCompilationCache(XlaCompilationCache::Config(), client, - device_type); - core::ScopedUnref cache_ref(cache); + using XlaDeviceExecutablePersistor = + DeviceExecutablePersistor; + auto persistor = std::make_unique( + XlaDeviceExecutablePersistor::Config(), device_type); + auto compiler_client = std::make_unique(client); + auto xla_device_compiler = + new DeviceCompiler( + std::move(persistor), std::move(compiler_client)); + core::ScopedUnref xla_device_compiler_ref(xla_device_compiler); + + auto profiler = new DeviceCompilationProfiler(); + core::ScopedUnref profiler_ref(profiler); // Check that strict compilation is disallowed. - Status status = cache->Compile(XlaCompiler::Options{}, fn, args, - XlaCompiler::CompileOptions{}, - XlaCompilationCache::CompileMode::kStrict, - &compilation_result, &executable); + Status status = xla_device_compiler->CompileIfNeeded( + XlaCompiler::Options{}, fn, args, XlaCompiler::CompileOptions{}, + DeviceCompileMode::kStrict, profiler, &compilation_result, &executable); EXPECT_FALSE(status.ok()); EXPECT_TRUE( absl::StrContains(status.error_message(), "XLA compilation disabled")); // Check that async compilation is disallowed. - status = cache->Compile(XlaCompiler::Options{}, fn, args, - XlaCompiler::CompileOptions{}, - XlaCompilationCache::CompileMode::kAsync, - &compilation_result, &executable); + status = xla_device_compiler->CompileIfNeeded( + XlaCompiler::Options{}, fn, args, XlaCompiler::CompileOptions{}, + DeviceCompileMode::kAsync, profiler, &compilation_result, &executable); EXPECT_FALSE(status.ok()); EXPECT_TRUE( absl::StrContains(status.error_message(), "XLA compilation disabled")); // Check that lazy compilation is disallowed. 
- status = cache->Compile(XlaCompiler::Options{}, fn, args, - XlaCompiler::CompileOptions{}, - XlaCompilationCache::CompileMode::kLazy, - &compilation_result, &executable); + status = xla_device_compiler->CompileIfNeeded( + XlaCompiler::Options{}, fn, args, XlaCompiler::CompileOptions{}, + DeviceCompileMode::kLazy, profiler, &compilation_result, &executable); EXPECT_FALSE(status.ok()); EXPECT_TRUE( absl::StrContains(status.error_message(), "XLA compilation disabled")); diff --git a/tensorflow/compiler/jit/device_executable_persistor.h b/tensorflow/compiler/jit/device_executable_persistor.h new file mode 100644 index 00000000000..ced6e2c43e1 --- /dev/null +++ b/tensorflow/compiler/jit/device_executable_persistor.h @@ -0,0 +1,336 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_DEVICE_EXECUTABLE_PERSISTOR_H_ +#define TENSORFLOW_COMPILER_JIT_DEVICE_EXECUTABLE_PERSISTOR_H_ + +#include +#include + +#include "tensorflow/compiler/jit/xla_compilation_cache.pb.h" +#include "tensorflow/compiler/jit/xla_device_compiler_client.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/framework/device.h" +#include "tensorflow/core/lib/strings/proto_serialization.h" +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/tsl/platform/statusor.h" + +namespace tensorflow { + +// Offers a way to persist and/or load compiled `ExecutableType`s along with the +// corresponding HLO (`CompilationResult`) to/from `persistent_cache_directory` +// (if one was provided during construction) on disk using `ClientType`. +template +class DeviceExecutablePersistor { + public: + // Configuration for setting up persistence (directory, filename prefix, etc). + struct Config { + Config() = default; + explicit Config(absl::string_view persistent_cache_directory, + bool disable_strict_signature_checks, + absl::string_view persistence_prefix) + : persistent_cache_directory(persistent_cache_directory), + disable_strict_signature_checks(disable_strict_signature_checks), + persistence_prefix(persistence_prefix) {} + + // If non-empty, JIT-compiled executables are saved to and loaded from the + // specified file system directory path. + std::string persistent_cache_directory; + + // Disable strict signature checks for entries loaded into the cache from + // external sources. + bool disable_strict_signature_checks = false; + + // The cache persistence prefix to use if serializing/deserialzing entries. + std::string persistence_prefix; + }; + + DeviceExecutablePersistor(const Config& config, + const DeviceType& device_type); + + // Returns std::nullopt if persistence is not enabled (i.e. 
+ // `persistent_cache_directory_` is empty) or if the serialized entry is not + // found on disk. Otherwise, loads and returns the serialized executable + // (or returns a status). + // TODO(b/255826209): Take in Signature instead of hash and string once cache + // is refactored. + std::optional>> TryToLoadExecutable( + uint64 signature_hash, const std::string& signature_str, + const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& compilation_result, + DeviceCompilerClient* client) const; + + // Tries to serialize an already built `executable` and persist it on disk. If + // unable to do so, tries to build a serialized executable using the AOT + // pipeline and persists that to disk. + // TODO(b/255826209): Take in Signature instead hash and string once cache + // is refactored. + Status TryToPersistExecutable( + uint64 signature_hash, const std::string& signature_str, + const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& compilation_result, + const ExecutableType& executable, + DeviceCompilerClient* client) const; + + const DeviceType& device_type() const { return device_type_; } + const std::string& persistence_prefix() const { return persistence_prefix_; } + const std::string& persistent_cache_directory() const { + return persistent_cache_directory_; + } + + private: + // Returns a cache key proto that identifies an entry in the compilation + // cache. + XlaSerializedCacheKey BuildSerializedCacheKey( + uint64 signature_hash, const xla::HloModuleProto& hlo_module) const; + + // Serializes the signature and its corresponding entry to a proto message. + StatusOr SerializeEntry( + uint64 signature_hash, const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& compilation_result, + const ExecutableType& executable, + DeviceCompilerClient* compiler_client) const; + + // Saves the cache entry in the file directory supplied during the + // construction of this class. Overwrites existing entries. + Status SaveSerializedEntry(const XlaSerializedCacheEntry& entry) const; + + // Tries to read a cache entry given a `key` by searching the file directory + // supplied during the construction of this class. Returns std::nullopt if no + // cache entry is found. + StatusOr> TryToReadSerializedEntry( + const XlaSerializedCacheKey& key) const; + + // Checks if the loaded `entry` matches the expected `key` and `hlo_module`. + Status VerifyLoadedCacheEntry(const XlaSerializedCacheKey& key, + const xla::HloModuleProto& hlo_module, + const XlaSerializedCacheEntry& entry) const; + + std::string XlaSerializedCacheKeyToString( + const XlaSerializedCacheKey& key) const; + std::string GetFilePath(const XlaSerializedCacheKey& key) const; + + const DeviceType device_type_; + const bool disable_strict_signature_checks_; + const std::string persistence_prefix_; + + // If non-empty, JIT-compiled executables are saved to and loaded from the + // specified file system directory path. 
+ const std::string persistent_cache_directory_; + + TF_DISALLOW_COPY_AND_ASSIGN(DeviceExecutablePersistor); +}; + +template +DeviceExecutablePersistor:: + DeviceExecutablePersistor(const Config& config, + const DeviceType& device_type) + : device_type_(device_type), + disable_strict_signature_checks_(config.disable_strict_signature_checks), + persistence_prefix_(config.persistence_prefix), + persistent_cache_directory_(config.persistent_cache_directory) {} + +template +std::string DeviceExecutablePersistor:: + XlaSerializedCacheKeyToString(const XlaSerializedCacheKey& key) const { + static constexpr char kXlaSerializedCacheKeySeparator[] = "__"; + return absl::StrCat( + key.prefix(), key.prefix().empty() ? "" : kXlaSerializedCacheKeySeparator, + key.signature_fingerprint(), kXlaSerializedCacheKeySeparator, + key.cluster_fingerprint(), kXlaSerializedCacheKeySeparator, + key.device_type()); +} + +template +std::string DeviceExecutablePersistor::GetFilePath( + const XlaSerializedCacheKey& key) const { + const std::string file_name = + absl::StrCat(XlaSerializedCacheKeyToString(key), ".pb"); + return io::JoinPath(persistent_cache_directory_, file_name); +} + +template +XlaSerializedCacheKey +DeviceExecutablePersistor::BuildSerializedCacheKey( + uint64 signature_hash, const xla::HloModuleProto& hlo_module) const { + XlaSerializedCacheKey serialized_cache_key; + serialized_cache_key.set_signature_fingerprint(signature_hash); + serialized_cache_key.set_cluster_fingerprint( + DeterministicProtoHash64(hlo_module)); + serialized_cache_key.set_device_type(device_type().type_string()); + serialized_cache_key.set_prefix(persistence_prefix()); + return serialized_cache_key; +} + +template +StatusOr> +DeviceExecutablePersistor::TryToReadSerializedEntry( + const XlaSerializedCacheKey& key) const { + Env* env = Env::Default(); + const std::string file_path = GetFilePath(key); + if (!env->FileExists(file_path).ok()) { + return StatusOr>(std::nullopt); + } + + XlaSerializedCacheEntry entry; + TF_RETURN_IF_ERROR(ReadTextOrBinaryProto(env, file_path, &entry)); + return std::optional(entry); +} + +template +Status +DeviceExecutablePersistor::VerifyLoadedCacheEntry( + const XlaSerializedCacheKey& key, const xla::HloModuleProto& hlo_module, + const XlaSerializedCacheEntry& entry) const { + XLA_SCOPED_LOGGING_TIMER(absl::StrCat("Verifying loaded cache entry: ", + hlo_module.entry_computation_name())); + + if (!AreSerializedProtosEqual(key, entry.key())) { + VLOG(2) << "Serialized cache key does not match:\n" + << "got:\n" + << entry.key().DebugString() << "\nexpected:\n" + << key.DebugString() << "\n"; + return errors::InvalidArgument("Serialized cache key does not match."); + } + + // Perform a stricter (slower) check of the snapshot to verify that they + // match exactly. 
+ if (!disable_strict_signature_checks_) { + if (!AreSerializedProtosEqual(hlo_module, entry.hlo_module())) { + VLOG(2) << "HLOs do not match:\n" + << "got:\n" + << hlo_module.DebugString() << "\nexpected:\n" + << entry.hlo_module().DebugString() << "\n"; + return errors::InvalidArgument("Serialized HLO does not match."); + } + } + + if (entry.executable().empty()) { + return errors::InvalidArgument("No binary found in serialized entry."); + } + return OkStatus(); +} + +template +Status +DeviceExecutablePersistor::SaveSerializedEntry( + const XlaSerializedCacheEntry& entry) const { + Env* env = Env::Default(); + TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(persistent_cache_directory_)); + const std::string file_path = GetFilePath(entry.key()); + return WriteBinaryProto(env, file_path, entry); +} + +template +StatusOr +DeviceExecutablePersistor::SerializeEntry( + uint64 signature_hash, const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& compilation_result, + const ExecutableType& executable, + DeviceCompilerClient* compiler_client) const { + XlaSerializedCacheEntry serialized_entry; + const xla::HloModuleProto& hlo_module = + compilation_result.computation->proto(); + *serialized_entry.mutable_key() = + BuildSerializedCacheKey(signature_hash, hlo_module); + *serialized_entry.mutable_hlo_module() = hlo_module; + + // XLA compiler supports exporting executables as an AOT compilation result + // to avoid running potentially expensive compilation pipeline twice. + // Check if XLA compiler can export available executable. + if (auto serialized_executable = + compiler_client->SerializeExecutable(executable); + serialized_executable.ok()) { + serialized_entry.set_executable(std::move(*serialized_executable)); + return serialized_entry; + } else if (serialized_executable.status().code() == error::UNIMPLEMENTED) { + VLOG(1) << "Executable export is not implemented"; + } else { + return serialized_executable.status(); + } + + TF_ASSIGN_OR_RETURN( + auto serialized_executable, + compiler_client->BuildSerializedExecutable(options, compilation_result)); + serialized_entry.set_executable(std::move(serialized_executable)); + return serialized_entry; +} + +template +std::optional>> +DeviceExecutablePersistor::TryToLoadExecutable( + uint64 signature_hash, const std::string& signature_str, + const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& compilation_result, + DeviceCompilerClient* compiler_client) const { + if (persistent_cache_directory_.empty()) { + return std::nullopt; + } + + const xla::HloModuleProto& hlo_module = + compilation_result.computation->proto(); + + XlaSerializedCacheKey cache_key = + BuildSerializedCacheKey(signature_hash, hlo_module); + + std::optional serialized_entry; + { + XLA_SCOPED_LOGGING_TIMER( + absl::StrCat("Try loading serialized cache entry:", signature_str)); + TF_ASSIGN_OR_RETURN(serialized_entry, TryToReadSerializedEntry(cache_key)); + } + + if (!serialized_entry.has_value()) { + return std::nullopt; + } + + TF_RETURN_IF_ERROR( + VerifyLoadedCacheEntry(cache_key, hlo_module, *serialized_entry)); + + VLOG(1) << "Loading cached entry for: " << signature_str; + return compiler_client->LoadExecutable(options, compilation_result, + serialized_entry->executable()); +} + +template +Status +DeviceExecutablePersistor::TryToPersistExecutable( + uint64 signature_hash, const std::string& signature_str, + const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& compilation_result, + const ExecutableType& executable, 
+ DeviceCompilerClient* client) const { + if (persistent_cache_directory_.empty()) { + VLOG(1) << "Not persisting executable. No `persistent_cache_directory` " + "provided."; + return OkStatus(); + } + + XLA_SCOPED_LOGGING_TIMER( + absl::StrCat("Serializing and saving cache entry: ", signature_str)); + TF_ASSIGN_OR_RETURN(XlaSerializedCacheEntry serialized_entry, + SerializeEntry(signature_hash, options, + compilation_result, executable, client)); + TF_RETURN_IF_ERROR(SaveSerializedEntry(std::move(serialized_entry))); + return OkStatus(); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_DEVICE_EXECUTABLE_PERSISTOR_H_ diff --git a/tensorflow/compiler/jit/device_executable_persistor_test.cc b/tensorflow/compiler/jit/device_executable_persistor_test.cc new file mode 100644 index 00000000000..d79b93f04f5 --- /dev/null +++ b/tensorflow/compiler/jit/device_executable_persistor_test.cc @@ -0,0 +1,483 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/device_executable_persistor.h" + +#include + +#include +#include +#include +#include + +#include +#include +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/math_ops.h" +#include "tensorflow/compiler/jit/device_compiler_client.h" +#include "tensorflow/compiler/jit/xla_compilation_cache.pb.h" +#include "tensorflow/compiler/jit/xla_device_compiler_client.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/executable_build_options.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status_matchers.h" +#include "tensorflow/core/platform/statusor.h" + +namespace tensorflow { +namespace { + +using ::testing::_; +using ::testing::ByMove; +using ::testing::Return; +using XlaDeviceExecutablePersistor = + DeviceExecutablePersistor; + +class DeviceExecutionPersistorTest : public ::testing::Test { + protected: + void SetUp() override { + compiler_client_ = std::make_unique( + xla::ClientLibrary::LocalClientOrDie()); + + XlaOpRegistry::RegisterCompilationKernels(); + + flib_def_ = std::make_unique( + OpRegistry::Global(), FunctionDefLibrary()); + + cache_dir_ = testing::TmpDir(); + TF_ASSERT_OK_AND_ASSIGN(compilation_result_add_, + BuildSampleCompilationResult()); + } + + StatusOr> BuildSampleExecutable() { + return compiler_client_->BuildExecutable(DefaultOptions(), + compilation_result_add_); + } + + StatusOr BuildSampleCompilationResult( + bool mul = false) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1); + if (mul) { + auto c = 
ops::Mul(scope.WithOpName("C"), a, b); + auto d = ops::_Retval(scope.WithOpName("D"), c, 0); + TF_RETURN_IF_ERROR(scope.ToGraph(graph.get())); + } else { + auto c = ops::Add(scope.WithOpName("C"), a, b); + auto d = ops::_Retval(scope.WithOpName("D"), c, 0); + TF_RETURN_IF_ERROR(scope.ToGraph(graph.get())); + } + + // Builds a description of the arguments. + std::vector args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2}); + args[1].kind = XlaCompiler::Argument::kParameter; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2}); + + // Compiles the graph. + XlaCompiler compiler(DefaultOptions()); + + XlaCompiler::CompilationResult compilation_result; + TF_RETURN_IF_ERROR(compiler.CompileGraph(XlaCompiler::CompileOptions(), + "graph", std::move(graph), args, + &compilation_result)); + return compilation_result; + } + + XlaCompiler::Options DefaultOptions() { + XlaCompiler::Options options; + options.device_type = DeviceType(DEVICE_CPU_XLA_JIT); + options.client = compiler_client_->client(); + options.flib_def = flib_def_.get(); + return options; + } + + std::unique_ptr flib_def_; + std::unique_ptr compiler_client_; + XlaCompiler::CompilationResult compilation_result_add_; + std::string serialized_executable_ = "serialized_executable"; + std::string cache_dir_; +}; + +// Using a mock to make testing different branches and triggering errors easier. +// Currently the `XlaDeviceCompilerClient`'s load/serialize functions don't work +// with the current test setup. +// TODO(b/255826209): Look into using a real object for most tests. +class MockCompilerClient : public XlaDeviceCompilerClient { + public: + MockCompilerClient() : XlaDeviceCompilerClient(nullptr) {} + MOCK_METHOD(StatusOr, SerializeExecutable, + (const xla::LocalExecutable& executable), (override)); + MOCK_METHOD(StatusOr, BuildSerializedExecutable, + (const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& result), + (override)); + MOCK_METHOD(StatusOr>, LoadExecutable, + (const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& result, + const std::string& serialized_executable), + (override)); +}; + +std::string GetFilePath(XlaSerializedCacheKey key, + const std::string& persistent_cache_dir) { + static constexpr char kXlaSerializedCacheKeySeparator[] = "__"; + + std::string file_name = absl::StrCat( + key.prefix(), key.prefix().empty() ? 
"" : kXlaSerializedCacheKeySeparator, + key.signature_fingerprint(), kXlaSerializedCacheKeySeparator, + key.cluster_fingerprint(), kXlaSerializedCacheKeySeparator, + key.device_type(), ".pb"); + + return io::JoinPath(persistent_cache_dir, file_name); +} + +StatusOr ReadCacheEntryFromFile( + XlaSerializedCacheKey key, const std::string& persistent_cache_dir) { + std::string file_path = GetFilePath(key, persistent_cache_dir); + XlaSerializedCacheEntry entry; + TF_RETURN_IF_ERROR(ReadTextOrBinaryProto(Env::Default(), file_path, &entry)); + return entry; +} + +XlaSerializedCacheKey CreateCacheKey( + uint64 signature_hash, + const XlaCompiler::CompilationResult& compilation_result, + const DeviceType& device_type, const std::string& persistence_prefix) { + XlaSerializedCacheKey key; + key.set_signature_fingerprint(signature_hash); + key.set_cluster_fingerprint( + DeterministicProtoHash64(compilation_result.computation->proto())); + key.set_device_type(device_type.type_string()); + key.set_prefix(persistence_prefix); + return key; +} + +TEST_F(DeviceExecutionPersistorTest, PersistCacheDirNotSet) { + XlaDeviceExecutablePersistor::Config config( + /*persistent_cache_directory=*/"", + /*disable_strict_signature_checks=*/false, + /*persistence_prefix=*/"xla"); + XlaDeviceExecutablePersistor persistor(config, DefaultOptions().device_type); + + MockCompilerClient mock_client; + TF_ASSERT_OK_AND_ASSIGN(auto executable, BuildSampleExecutable()); + TF_EXPECT_OK(persistor.TryToPersistExecutable( + /*signature_hash=*/123, "signature_string", DefaultOptions(), + compilation_result_add_, *executable, &mock_client)); + + auto key = + CreateCacheKey(/*signature_hash=*/123, compilation_result_add_, + persistor.device_type(), persistor.persistence_prefix()); + auto entry = ReadCacheEntryFromFile(key, ""); + EXPECT_FALSE(entry.ok()); +} + +TEST_F(DeviceExecutionPersistorTest, PersistSerializeAlreadyBuiltExecutable) { + XlaDeviceExecutablePersistor::Config config( + /*persistent_cache_directory=*/cache_dir_, + /*disable_strict_signature_checks=*/false, + /*persistence_prefix=*/"xla"); + XlaDeviceExecutablePersistor persistor(config, DefaultOptions().device_type); + + MockCompilerClient mock_client; + EXPECT_CALL(mock_client, SerializeExecutable(_)) + .WillOnce(Return(StatusOr(serialized_executable_))); + + TF_ASSERT_OK_AND_ASSIGN(auto executable, BuildSampleExecutable()); + TF_EXPECT_OK(persistor.TryToPersistExecutable( + /*signature_hash=*/123, "signature_string", DefaultOptions(), + compilation_result_add_, *executable, &mock_client)); + + auto key = + CreateCacheKey(/*signature_hash=*/123, compilation_result_add_, + persistor.device_type(), persistor.persistence_prefix()); + TF_ASSERT_OK_AND_ASSIGN(auto entry, ReadCacheEntryFromFile(key, cache_dir_)); + + EXPECT_EQ(entry.executable(), serialized_executable_); +} + +TEST_F(DeviceExecutionPersistorTest, PersistBuildSerializedExecutable) { + XlaDeviceExecutablePersistor::Config config( + /*persistent_cache_directory=*/cache_dir_, + /*disable_strict_signature_checks=*/false, + /*persistence_prefix=*/"xla"); + XlaDeviceExecutablePersistor persistor(config, DefaultOptions().device_type); + + MockCompilerClient mock_client; + EXPECT_CALL(mock_client, SerializeExecutable(_)) + .WillOnce(Return(errors::Unimplemented("Unimplemented."))); + EXPECT_CALL(mock_client, BuildSerializedExecutable(_, _)) + .WillOnce(Return(serialized_executable_)); + + TF_ASSERT_OK_AND_ASSIGN(auto executable, BuildSampleExecutable()); + TF_EXPECT_OK(persistor.TryToPersistExecutable( + 
/*signature_hash=*/123, "signature_string", DefaultOptions(), + compilation_result_add_, *executable, &mock_client)); + + auto key = + CreateCacheKey(/*signature_hash=*/123, compilation_result_add_, + persistor.device_type(), persistor.persistence_prefix()); + TF_ASSERT_OK_AND_ASSIGN(auto entry, ReadCacheEntryFromFile(key, cache_dir_)); + + EXPECT_EQ(entry.executable(), serialized_executable_); +} + +TEST_F(DeviceExecutionPersistorTest, PersistSerializeExecutableError) { + XlaDeviceExecutablePersistor::Config config( + /*persistent_cache_directory=*/cache_dir_, + /*disable_strict_signature_checks=*/false, + /*persistence_prefix=*/"xla"); + XlaDeviceExecutablePersistor persistor(config, DefaultOptions().device_type); + + MockCompilerClient mock_client; + EXPECT_CALL(mock_client, SerializeExecutable(_)) + .WillOnce(Return(errors::InvalidArgument("InvalidArgument."))); + + TF_ASSERT_OK_AND_ASSIGN(auto executable, BuildSampleExecutable()); + EXPECT_THAT(persistor.TryToPersistExecutable( + /*signature_hash=*/123, "signature_string", DefaultOptions(), + compilation_result_add_, *executable, &mock_client), + testing::StatusIs(error::INVALID_ARGUMENT)); +} + +TEST_F(DeviceExecutionPersistorTest, PersistExecutableEmpty) { + XlaDeviceExecutablePersistor::Config config( + /*persistent_cache_directory=*/cache_dir_, + /*disable_strict_signature_checks=*/false, + /*persistence_prefix=*/"xla"); + XlaDeviceExecutablePersistor persistor(config, DefaultOptions().device_type); + + MockCompilerClient mock_client; + xla::LocalExecutable empty_executable( + nullptr, nullptr, + GetExecutableBuildOptions(DefaultOptions(), compilation_result_add_, 0)); + EXPECT_CALL(mock_client, SerializeExecutable(_)) + .WillOnce(Return(errors::FailedPrecondition("Failed precondition."))); + + EXPECT_THAT(persistor.TryToPersistExecutable( + /*signature_hash=*/123, "signature_string", DefaultOptions(), + compilation_result_add_, empty_executable, &mock_client), + testing::StatusIs(error::FAILED_PRECONDITION)); +} + +TEST_F(DeviceExecutionPersistorTest, LoadCacheDirNotSet) { + XlaDeviceExecutablePersistor::Config config( + /*persistent_cache_directory=*/"", + /*disable_strict_signature_checks=*/false, + /*persistence_prefix=*/"xla"); + XlaDeviceExecutablePersistor persistor(config, DefaultOptions().device_type); + + MockCompilerClient mock_client; + auto executable = + persistor.TryToLoadExecutable(123, "signature_string", DefaultOptions(), + compilation_result_add_, &mock_client); + EXPECT_FALSE(executable.has_value()); +} + +TEST_F(DeviceExecutionPersistorTest, LoadSuccess) { + XlaDeviceExecutablePersistor::Config config( + /*persistent_cache_directory=*/cache_dir_, + /*disable_strict_signature_checks=*/false, + /*persistence_prefix=*/"xla"); + XlaDeviceExecutablePersistor persistor(config, DefaultOptions().device_type); + + MockCompilerClient mock_client; + TF_ASSERT_OK_AND_ASSIGN(auto executable, BuildSampleExecutable()); + EXPECT_CALL(mock_client, LoadExecutable(_, _, serialized_executable_)) + .WillOnce(Return(ByMove(std::move(executable)))); + + auto loaded_executable = persistor.TryToLoadExecutable( + /*signature_hash=*/123, "signature_string", DefaultOptions(), + compilation_result_add_, &mock_client); + + EXPECT_TRUE(loaded_executable.has_value()); + EXPECT_TRUE(loaded_executable.value().ok()); + EXPECT_TRUE((*loaded_executable.value())->executable() != nullptr); +} + +TEST_F(DeviceExecutionPersistorTest, LoadFileDoesntExist) { + XlaDeviceExecutablePersistor::Config config( + /*persistent_cache_directory=*/cache_dir_, + 
/*disable_strict_signature_checks=*/false, + /*persistence_prefix=*/"xla"); + XlaDeviceExecutablePersistor persistor(config, DefaultOptions().device_type); + + MockCompilerClient mock_client; + // Try to load an executable for a different signature hash (which hasn't been + // persisted). + auto loaded_executable = persistor.TryToLoadExecutable( + /*signature_hash=*/12345, "different_signature", DefaultOptions(), + compilation_result_add_, &mock_client); + + EXPECT_FALSE(loaded_executable.has_value()); +} + +TEST_F(DeviceExecutionPersistorTest, LoadSerializedKeyMismatch) { + XlaDeviceExecutablePersistor::Config config( + /*persistent_cache_directory=*/cache_dir_, + /*disable_strict_signature_checks=*/false, + /*persistence_prefix=*/"xla"); + XlaDeviceExecutablePersistor persistor(config, DefaultOptions().device_type); + + auto key1 = + CreateCacheKey(/*signature_hash=*/123, compilation_result_add_, + persistor.device_type(), persistor.persistence_prefix()); + auto key2 = + CreateCacheKey(/*signature_hash=*/456, compilation_result_add_, + persistor.device_type(), persistor.persistence_prefix()); + // File for key2 contains the same content as key1. + TF_ASSERT_OK(Env::Default()->CopyFile( + GetFilePath(key1, persistor.persistent_cache_directory()), + GetFilePath(key2, persistor.persistent_cache_directory()))); + + MockCompilerClient mock_client; + // Try to load an executable from file corresponding to key2 (whose file + // content corresponds to key1). + auto loaded_executable = persistor.TryToLoadExecutable( + /*signature_hash=*/456, "different_signature", DefaultOptions(), + compilation_result_add_, &mock_client); + + EXPECT_TRUE(loaded_executable.has_value()); + EXPECT_FALSE(loaded_executable->ok()); + EXPECT_THAT(loaded_executable.value(), + testing::StatusIs(error::INVALID_ARGUMENT)); +} + +TEST_F(DeviceExecutionPersistorTest, LoadSerializedHloMismatch) { + XlaDeviceExecutablePersistor::Config config( + /*persistent_cache_directory=*/cache_dir_, + /*disable_strict_signature_checks=*/false, + /*persistence_prefix=*/"xla"); + XlaDeviceExecutablePersistor persistor(config, DefaultOptions().device_type); + + TF_ASSERT_OK_AND_ASSIGN(auto compilation_result_mul, + BuildSampleCompilationResult(true)); + + auto key1 = + CreateCacheKey(/*signature_hash=*/123, compilation_result_add_, + persistor.device_type(), persistor.persistence_prefix()); + auto key2 = + CreateCacheKey(/*signature_hash=*/123, compilation_result_mul, + persistor.device_type(), persistor.persistence_prefix()); + + // Read serialized entry corresponding to key1. + XlaSerializedCacheEntry entry; + TF_ASSERT_OK(ReadTextOrBinaryProto( + Env::Default(), GetFilePath(key1, persistor.persistent_cache_directory()), + &entry)); + // Change the entry's key to key2. + *entry.mutable_key() = key2; + // Write the modified entry to file corresponding to key2. + TF_ASSERT_OK(WriteBinaryProto( + Env::Default(), GetFilePath(key2, persistor.persistent_cache_directory()), + entry)); + + MockCompilerClient mock_client; + // Try to load executable corresponding to key2 (whose file contains HLO + // corresponding to key1). 
+ auto loaded_executable = persistor.TryToLoadExecutable( + /*signature_hash=*/123, "signature", DefaultOptions(), + compilation_result_mul, &mock_client); + + EXPECT_TRUE(loaded_executable.has_value()); + EXPECT_FALSE(loaded_executable->ok()); + EXPECT_THAT(loaded_executable.value(), + testing::StatusIs(error::INVALID_ARGUMENT)); +} + +TEST_F(DeviceExecutionPersistorTest, LoadStrictChecksDisabled) { + XlaDeviceExecutablePersistor::Config config( + /*persistent_cache_directory=*/cache_dir_, + /*disable_strict_signature_checks=*/true, + /*persistence_prefix=*/"xla"); + XlaDeviceExecutablePersistor persistor(config, DefaultOptions().device_type); + + TF_ASSERT_OK_AND_ASSIGN(auto compilation_result_mul, + BuildSampleCompilationResult(true)); + + auto key1 = + CreateCacheKey(/*signature_hash=*/123, compilation_result_add_, + persistor.device_type(), persistor.persistence_prefix()); + auto key2 = + CreateCacheKey(/*signature_hash=*/123, compilation_result_mul, + persistor.device_type(), persistor.persistence_prefix()); + + // Read serialized entry corresponding to key1. + XlaSerializedCacheEntry entry; + TF_ASSERT_OK(ReadTextOrBinaryProto( + Env::Default(), GetFilePath(key1, persistor.persistent_cache_directory()), + &entry)); + // Change the entry's key to key2. + *entry.mutable_key() = key2; + // Write the modified entry to file corresponding to key2. + TF_ASSERT_OK(WriteBinaryProto( + Env::Default(), GetFilePath(key2, persistor.persistent_cache_directory()), + entry)); + + MockCompilerClient mock_client; + TF_ASSERT_OK_AND_ASSIGN(auto executable, BuildSampleExecutable()); + EXPECT_CALL(mock_client, LoadExecutable(_, _, serialized_executable_)) + .WillOnce(Return(ByMove(std::move(executable)))); + + auto loaded_executable = persistor.TryToLoadExecutable( + 123, "signature", DefaultOptions(), compilation_result_mul, &mock_client); + + EXPECT_TRUE(loaded_executable.has_value()); + EXPECT_TRUE(loaded_executable->ok()); +} + +TEST_F(DeviceExecutionPersistorTest, LoadSerializedExecutableEmpty) { + XlaDeviceExecutablePersistor::Config config( + /*persistent_cache_directory=*/cache_dir_, + /*disable_strict_signature_checks=*/false, + /*persistence_prefix=*/"xla"); + XlaDeviceExecutablePersistor persistor(config, DefaultOptions().device_type); + + auto key = + CreateCacheKey(/*signature_hash=*/123, compilation_result_add_, + persistor.device_type(), persistor.persistence_prefix()); + + // Read serialized entry. + XlaSerializedCacheEntry entry; + TF_ASSERT_OK(ReadTextOrBinaryProto( + Env::Default(), GetFilePath(key, persistor.persistent_cache_directory()), + &entry)); + entry.clear_executable(); + // Write entry to another file. + TF_ASSERT_OK(WriteBinaryProto( + Env::Default(), GetFilePath(key, persistor.persistent_cache_directory()), + entry)); + + MockCompilerClient mock_client; + auto loaded_executable = persistor.TryToLoadExecutable( + /*signature_hash=*/123, "signature", DefaultOptions(), + compilation_result_add_, &mock_client); + + EXPECT_TRUE(loaded_executable.has_value()); + EXPECT_FALSE(loaded_executable->ok()); + EXPECT_THAT(loaded_executable.value(), + testing::StatusIs(error::INVALID_ARGUMENT)); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_util.cc b/tensorflow/compiler/jit/encapsulate_util.cc index 3a5db290025..d761483a1f4 100644 --- a/tensorflow/compiler/jit/encapsulate_util.cc +++ b/tensorflow/compiler/jit/encapsulate_util.cc @@ -24,12 +24,11 @@ limitations under the License. 
#include "absl/types/optional.h" #include "tensorflow/compiler/jit/shape_inference.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/protobuf/error_codes.pb.h" -using stream_executor::port::StatusOr; +using tsl::StatusOr; namespace tensorflow { diff --git a/tensorflow/compiler/jit/encapsulate_util.h b/tensorflow/compiler/jit/encapsulate_util.h index e7d61561efe..304e317ee25 100644 --- a/tensorflow/compiler/jit/encapsulate_util.h +++ b/tensorflow/compiler/jit/encapsulate_util.h @@ -20,7 +20,6 @@ limitations under the License. #define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_ #include "absl/container/flat_hash_map.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/core/graph/graph.h" namespace tensorflow { @@ -117,8 +116,7 @@ struct XlaClusterInfo { // dependencies and control dependencies. cluster_deps maps the name name of an // outside compilation cluster to a set of names of outside compilation clusters // that it depends on. -stream_executor::port::StatusOr< - std::unique_ptr>>> +tsl::StatusOr>>> OutsideCompilationClusterDependencies( const Graph* g, const string& outside_compilation_attr_name); diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc index f0edee10499..047ebe24e3d 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/function.h" diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc index be2ec0efc2d..2c216500c5c 100644 --- a/tensorflow/compiler/jit/flags.cc +++ b/tensorflow/compiler/jit/flags.cc @@ -34,7 +34,6 @@ BuildXlaOpsPassFlags* build_ops_flags; MarkForCompilationPassFlags* mark_for_compilation_flags; XlaDeviceFlags* device_flags; XlaOpsCommonFlags* ops_flags; -IntroduceFloatingPointJitterPassFlags* jitter_flags; MlirCommonFlags* mlir_flags; JitRtFlags* jitrt_flags; std::vector* jitrt_flag_list; @@ -163,17 +162,22 @@ void AllocateAndParseJitRtFlags() { jitrt_flags = new JitRtFlags; jitrt_flags->always_specialize = false; jitrt_flags->cost_driven_async_parallel_for = false; + jitrt_flags->enable_crash_reproducer = false; + jitrt_flags->enable_xla_cpu_transformations = false; jitrt_flags->log_query_of_death = false; + jitrt_flags->pack_matmul = false; jitrt_flags->vectorize = false; - jitrt_flags->enable_crash_reproducer = false; jitrt_flag_list = new std::vector({ Flag("always_specialize", &jitrt_flags->always_specialize, ""), Flag("cost_driven_async_parallel_for", &jitrt_flags->cost_driven_async_parallel_for, ""), - Flag("log_query_of_death", &jitrt_flags->log_query_of_death, ""), - Flag("vectorize", &jitrt_flags->vectorize, ""), Flag("enable_crash_reproducer", &jitrt_flags->enable_crash_reproducer, ""), + Flag("enable_xla_cpu_transformations", + &jitrt_flags->enable_xla_cpu_transformations, ""), + Flag("log_query_of_death", 
&jitrt_flags->log_query_of_death, ""), + Flag("pack_matmul", &jitrt_flags->pack_matmul, ""), + Flag("vectorize", &jitrt_flags->vectorize, ""), }); xla::ParseFlagsFromEnvAndDieIfUnknown("TF_JITRT_FLAGS", *jitrt_flag_list); } @@ -214,9 +218,7 @@ void AllocateAndParseFlags() { ops_flags = new XlaOpsCommonFlags; ops_flags->tf_xla_always_defer_compilation = false; ops_flags->tf_xla_async_compilation = false; - - jitter_flags = new IntroduceFloatingPointJitterPassFlags; - jitter_flags->jitter_amount = 1e-5; + ops_flags->tf_xla_use_device_api = false; // The `enable_mlir_bridge` flag allows the user to explicitly request that // their program is (or isn't) compiled using the MLIR-based TF-to-XLA bridge. @@ -228,13 +230,8 @@ void AllocateAndParseFlags() { // bridge, on a per-graph basis). bool enable_mlir_bridge = false; bool enable_mlir_bridge_is_explicit = false; - bool mlir_bridge_safe_mode = false; bool enable_mlir_merge_control_flow_pass = true; bool enable_mlir_convert_control_to_data_outputs_pass = false; - auto setter_for_jitter_tensor_names = [](string sequence) { - jitter_flags->tensor_names = absl::StrSplit(sequence, ','); - return true; - }; // Dump graphs in TFG dialect. bool use_tfg_graph_dumper = false; @@ -274,15 +271,9 @@ void AllocateAndParseFlags() { "When lazy compilation is enabled, asynchronous compilation starts " "the cluster compilation in the background, and the fallback path " "is executed until the compilation has finished."), - - Flag("tf_introduce_floating_point_jitter_to_tensors", - setter_for_jitter_tensor_names, "", - "The Tensors to add the jitter to. The tensors are named in the " - "TensorId format of :."), - Flag("tf_introduce_floating_point_jitter_amount", - &jitter_flags->jitter_amount, - "The amount of jitter to introduce. This amount is added to each " - "element in the tensors named in `tensor_names."), + Flag("tf_xla_use_device_api", &ops_flags->tf_xla_use_device_api, + "If true, uses the Device API (PjRt) for single device compilation." + " Defaults to false."), Flag("tf_mlir_enable_mlir_bridge", &enable_mlir_bridge, "Enables experimental MLIR-Based TensorFlow Compiler Bridge.", @@ -295,12 +286,6 @@ void AllocateAndParseFlags() { &enable_mlir_convert_control_to_data_outputs_pass, "Enables `tf-executor-convert-control-to-data-outputs` pass for " "MLIR-Based TensorFlow Compiler Bridge."), - Flag( - "tf_mlir_bridge_safe_mode", &mlir_bridge_safe_mode, - "When tf_mlir_enable_mlir_bridge is true, this field can enable " - "the MLIR bridge's safe mode. When the MLIR bridge is in safe mode, " - "it only runs for graphs that use features MLIR bridge currently " - "supports."), Flag("tf_dump_graphs_in_tfg", &use_tfg_graph_dumper, "When tf_dump_graphs_in_tfg is true, graphs after transformations " "are dumped in MLIR TFG dialect and not in GraphDef")}); @@ -311,15 +296,10 @@ void AllocateAndParseFlags() { mlir_flags = new MlirCommonFlags; if (!enable_mlir_bridge_is_explicit) { mlir_flags->tf_mlir_enable_mlir_bridge = - (mlir_bridge_safe_mode) - ? ConfigProto::Experimental:: - MLIR_BRIDGE_ROLLOUT_SAFE_MODE_FALLBACK_ENABLED - : ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED; + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED; } else if (enable_mlir_bridge) { mlir_flags->tf_mlir_enable_mlir_bridge = - (mlir_bridge_safe_mode) - ? 
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_SAFE_MODE_ENABLED - : ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED; + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED; } else { mlir_flags->tf_mlir_enable_mlir_bridge = ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED; @@ -342,7 +322,6 @@ void ResetFlags() { delete mark_for_compilation_flags; delete device_flags; delete ops_flags; - delete jitter_flags; delete mlir_flags; delete flag_list; delete jitrt_flags; @@ -372,15 +351,9 @@ XlaDeviceFlags* GetXlaDeviceFlags() { return device_flags; } -const XlaOpsCommonFlags& GetXlaOpsCommonFlags() { - absl::call_once(flags_init, &AllocateAndParseFlags); - return *ops_flags; -} - -const IntroduceFloatingPointJitterPassFlags& -GetIntroduceFloatingPointJitterPassFlags() { +XlaOpsCommonFlags* GetXlaOpsCommonFlags() { absl::call_once(flags_init, &AllocateAndParseFlags); - return *jitter_flags; + return ops_flags; } MlirCommonFlags* GetMlirCommonFlags() { @@ -435,6 +408,8 @@ static std::atomic xla_compilation_disabled(false); void DisableXlaCompilation() { xla_compilation_disabled = true; } +void EnableXlaCompilation() { xla_compilation_disabled = false; } + bool FailOnXlaCompilation() { return xla_compilation_disabled; } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h index 1cbfdb9caf5..fc319ac3e0e 100644 --- a/tensorflow/compiler/jit/flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_FLAGS_H_ #define TENSORFLOW_COMPILER_JIT_FLAGS_H_ +#include #include #include @@ -122,6 +123,9 @@ struct XlaOpsCommonFlags { // If true, _XlaCompile compiles the cluster asynchronously with respect to // the main execution. The fallback path is taken while compilation happens. bool tf_xla_async_compilation; + // If true, uses Device API (PjRt) for single device compilation. Defaults to + // false. + bool tf_xla_use_device_api; }; // Flags for the build_xla_ops pass. @@ -147,17 +151,6 @@ struct BuildXlaOpsPassFlags { bool tf_xla_disable_constant_folding; }; -// Flags for the IntroduceFloatingPointJitter pass. -struct IntroduceFloatingPointJitterPassFlags { - // The amount of jitter to introduce. This amount is added to each element in - // the tensors named in `tensor_names. - float jitter_amount; - - // The Tensors to add the jitter to. The tensors are named in the TensorId - // format of :. - std::vector tensor_names; -}; - // Flags for common MLIR configurations. struct MlirCommonFlags { ConfigProto::Experimental::MlirBridgeRollout tf_mlir_enable_mlir_bridge; @@ -175,8 +168,17 @@ struct JitRtFlags { // "query of death". See TfJitRtQueryOfDeathLogger. bool log_query_of_death; + // Enable vectorization, which requires tiling and peeling on different ops. bool vectorize; + // Enable tiling/fusion transformations shared with XLA:CPU Next. + bool enable_xla_cpu_transformations; + + // Enable packing for matmul, which lowers the matmul op into linalg.mmt4d, to + // hopefully get the most optimized layout for matmul inputs, hence accelerate + // accesses to these during matmul computation. + bool pack_matmul; + // Enables crash reproducer for JitRt MLIR pass manager. 
bool enable_crash_reproducer; }; @@ -191,10 +193,7 @@ struct JitRtFlags { MarkForCompilationPassFlags* GetMarkForCompilationPassFlags(); BuildXlaOpsPassFlags* GetBuildXlaOpsPassFlags(); XlaDeviceFlags* GetXlaDeviceFlags(); -const XlaOpsCommonFlags& GetXlaOpsCommonFlags(); - -const IntroduceFloatingPointJitterPassFlags& -GetIntroduceFloatingPointJitterPassFlags(); +XlaOpsCommonFlags* GetXlaOpsCommonFlags(); MlirCommonFlags* GetMlirCommonFlags(); @@ -218,6 +217,10 @@ void AppendMarkForCompilationPassFlags( // be used by a server to ensure that JIT compilation is opt-in. void DisableXlaCompilation(); +// Enables XLA compilation. Can be used with `DisableXlaCompilation` to +// enable/disable JIT compilation at different stages. +void EnableXlaCompilation(); + // Returns `false` unless `DisableXlaCompilation` was called. bool FailOnXlaCompilation(); diff --git a/tensorflow/compiler/jit/get_compiler_ir.cc b/tensorflow/compiler/jit/get_compiler_ir.cc index 4b06371078a..db3642e6111 100644 --- a/tensorflow/compiler/jit/get_compiler_ir.cc +++ b/tensorflow/compiler/jit/get_compiler_ir.cc @@ -15,29 +15,37 @@ limitations under the License. #include "tensorflow/compiler/jit/get_compiler_ir.h" +#include +#include +#include #include #include #include +#include -#include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/jit/compilability_check_util.h" -#include "tensorflow/compiler/jit/defs.h" -#include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/compiler/jit/device_compiler.h" #include "tensorflow/compiler/jit/xla_launch_util.h" #include "tensorflow/compiler/jit/xla_platform_info.h" -#include "tensorflow/compiler/tf2xla/const_analysis.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/eager/tensor_handle.h" -#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/function.h" -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/resource_handle.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/statusor.h" -#include "tensorflow/core/util/ptr_util.h" +#include "tensorflow/tsl/platform/errors.h" +#include "tensorflow/tsl/platform/status.h" namespace tensorflow { @@ -75,10 +83,153 @@ static StatusOr> BuildExecutable( return std::move(executables[0]); } +static StatusOr BuildHLOString( + IrExportStage stage, const XlaCompiler::CompilationResult& result, + xla::LocalClient* local_client, const XlaCompiler::Options& options) { + switch (stage) { + case IrExportStage::HLO: + case IrExportStage::HLO_NO_METADATA: + case IrExportStage::HLO_SERIALIZED: { + TF_ASSIGN_OR_RETURN(xla::ProgramShape program_shape, + result.computation->GetProgramShape()); + xla::HloModuleConfig config(program_shape); + TF_ASSIGN_OR_RETURN( + std::unique_ptr new_module, + xla::HloModule::CreateFromProto(result.computation->proto(), config)); + + xla::HloPrintOptions opts; + if (stage == IrExportStage::HLO_NO_METADATA) { + 
opts.set_print_metadata(false); + } + + if (stage == IrExportStage::HLO_SERIALIZED) { + return new_module->ToProto().SerializeAsString(); + } else { + return new_module->ToString(opts); + } + } + case IrExportStage::OPTIMIZED_HLO: + case IrExportStage::OPTIMIZED_HLO_SERIALIZED: { + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + BuildExecutable(local_client, result, options)); + xla::Executable* new_executable = executable->executable(); + if (stage == IrExportStage::OPTIMIZED_HLO_SERIALIZED) { + return new_executable->module().ToProto().SerializeAsString(); + } else { + return new_executable->module().ToString(); + } + } + case IrExportStage::OPTIMIZED_HLO_PROTO_SERIALIZED: { + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + BuildExecutable(local_client, result, options, + /*xla_embed_ir_in_executable=*/true)); + return executable->executable()->hlo_proto()->SerializeAsString(); + } + case IrExportStage::OPTIMIZED_HLO_DOT: { + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + BuildExecutable(local_client, result, options)); + StatusOr graph = xla::RenderGraph( + *executable->executable()->module().entry_computation(), + "Visualization", + /*debug_options=*/{}, xla::RenderedGraphFormat::kDot, + /*hlo_render_options=*/{}); + TF_RETURN_IF_ERROR(graph.status()); + return *graph; + } + } +} + +static StatusOr> +BuildXlaCompilerArgumentFromTensorSpec( + const FunctionBody* fbody, absl::Span must_be_constant_idxs, + absl::Span inputs, + absl::Span variable_args, Device* device, + absl::Span flat_arg_shape_and_dtype) { + TF_RET_CHECK(fbody != nullptr); + auto& input_args = fbody->fdef.signature().input_arg(); + int input_arg_size = input_args.size(); + std::vector args; + args.reserve(input_arg_size); + + for (auto& arg_info : flat_arg_shape_and_dtype) { + XlaCompiler::Argument arg; + arg.kind = XlaCompiler::Argument::kParameter; + arg.type = arg_info.dtype; + arg.shape = arg_info.shape; + args.push_back(arg); + } + + // Build Xla Compiler Arguments from concrete_fn.captured_inputs + absl::flat_hash_map variable_info_lookup; + TF_RETURN_IF_ERROR( + CreateVariableInfoLookup(variable_args, variable_info_lookup)); + + for (const VariableInfo& info : variable_args) { + TF_RET_CHECK(!info.var() || info.lock_held() || info.shared_lock_held()) + << "Need to hold the lock on resource variables " + "before calling BuildXlaCompilerArguments"; + variable_info_lookup.emplace(info.index(), &info); + } + + int offset = flat_arg_shape_and_dtype.size(); + // Here it takes in the concrete_fn.captured_inputs and builds the appropriate + // XLA compiler arguments. + for (int64_t input_num = offset; input_num < input_arg_size; ++input_num) { + const Tensor* input = inputs[input_num]; + + XlaCompiler::Argument arg; + if (variable_info_lookup.count(input_num)) { + // Handles tf.resource variables. + TF_RET_CHECK(input->dtype() == DT_RESOURCE); + const VariableInfo& variable = *variable_info_lookup[input_num]; + arg.kind = XlaCompiler::Argument::kResource; + arg.resource_kind = XlaResource::kVariable; + arg.definition_stack_trace = variable.definition_stack_trace(); + TF_RET_CHECK(variable.var() && variable.var()->is_initialized); + const Tensor* value = variable.var()->tensor(); + arg.type = value->dtype(); + arg.shape = value->shape(); + arg.initialized = true; + } else { + // Instead of embedding constant into HLO, + // we handle tf.constant as parameter to reduce size. 
+ arg.kind = XlaCompiler::Argument::kParameter; + arg.type = input->dtype(); + arg.shape = input->shape(); + } + args.push_back(arg); + } + + for (int64_t i = 0; i < input_arg_size; ++i) { + args[i].name = input_args[i].name(); + } + + return args; +} + +/** + * Clarifies the different meanings of 'input_arg_shape_and_dtype' and + * 'input_handles' in different cases. + * + * For TENSOR_SPEC case: + * - `input_arg_shape_and_dtype`: Contains the shape and dtype of + * concrete_fn input args. + * - `input_handles`: Contains the concrete_fn.captured_input tensors. + * + * For CONCRETE_INPUT case: + * - `input_arg_shape_and_dtype`: it is empty. + * - `input_handles`: Contains all concrete_fn inputs tensors, including + * captured inputs. + */ StatusOr GetCompilerIr( IrExportStage stage, ProcessFunctionLibraryRuntime* pflr, absl::string_view func_name, Device* dev, EagerContext* context, - absl::Span inputs_handles) { + absl::Span input_arg_shape_and_dtype, + absl::Span input_handles, + CompilerArgSource compiler_arg_source) { + using XlaDeviceCompiler = + DeviceCompiler; + auto is_tfrt_tpu_supported_stage = [](IrExportStage stage) { return stage == IrExportStage::HLO || stage == IrExportStage::HLO_NO_METADATA || @@ -104,15 +255,18 @@ StatusOr GetCompilerIr( TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources( flr, function, &fbody, &constant_arg_indices, &resource_arg_indices)); - MemoryTypeVector input_memory_types = - GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices); - MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody); + // `input_args` includes both concrete_fn input args and captured_input here. + auto& input_args = fbody->fdef.signature().input_arg(); + // Here input_arg_size = len(flat_args) + len(captured_input) + int input_arg_size = input_args.size(); + std::vector inputs(input_arg_size); std::deque inputs_storage; - std::vector inputs; - inputs.reserve(inputs_handles.size()); - for (int i = 0; i < inputs_handles.size(); i++) { - const TensorHandle* th = inputs_handles[i]; + std::vector variable_infos; + int offset = input_arg_shape_and_dtype.size(); + + for (int i = 0; i < input_handles.size(); i++) { + const TensorHandle* th = input_handles[i]; const Tensor* t; // Handle owns the tensor. 
TF_RETURN_IF_ERROR(th->Tensor(&t)); @@ -121,27 +275,26 @@ StatusOr GetCompilerIr( inputs_storage.emplace_back(t->dtype(), t->shape()); TF_RETURN_IF_ERROR( th->CopyToDevice(*context, /*d=*/nullptr, &inputs_storage.back())); - inputs.push_back(&inputs_storage.back()); + inputs[i + offset] = &inputs_storage.back(); } else { - inputs.push_back(t); + inputs[i + offset] = t; } } - std::vector variable_infos; TF_RETURN_IF_ERROR(GetVariableInfosFromInputs( rmgr, dev, inputs, resource_arg_indices, &variable_infos)); TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos))); XlaPlatformInfo platform_info = XlaPlatformInfoFromDevice(dev); - XlaCompilationCache* cache; - TF_RETURN_IF_ERROR(rmgr->LookupOrCreate( - rmgr->default_container(), "xla_cache", &cache, - [&](XlaCompilationCache** cache_write_into) { - return BuildXlaCompilationCache(dev, flr, platform_info, - cache_write_into); + XlaDeviceCompiler* xla_device_compiler; + TF_RETURN_IF_ERROR(rmgr->LookupOrCreate( + rmgr->default_container(), "xla_device_compiler", &xla_device_compiler, + [&](XlaDeviceCompiler** xla_device_compiler) { + return BuildXlaDeviceCompiler(dev, flr, platform_info, + xla_device_compiler); })); - core::ScopedUnref cache_ref(cache); + core::ScopedUnref xla_device_compiler_ref(xla_device_compiler); se::Stream* stream = nullptr; if (const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info = @@ -151,9 +304,10 @@ StatusOr GetCompilerIr( XlaCompiler::Options options; if (platform_info.device_type() == DEVICE_TPU && stream == nullptr) { - options = GenerateTfrtTpuCompilerOptions(*cache, *flr); + options = GenerateTfrtTpuCompilerOptions(*xla_device_compiler, *flr); } else { - options = GenerateCompilerOptions(*cache, *flr, dev, stream, platform_info, + options = GenerateCompilerOptions(*xla_device_compiler, *flr, dev, stream, + platform_info, /*has_ref_vars=*/false); } @@ -163,67 +317,24 @@ StatusOr GetCompilerIr( XlaCompiler compiler(options); - StatusOr> args = - XlaComputationLaunchContext::BuildXlaCompilerArguments( - constant_arg_indices, inputs, variable_infos, dev); + StatusOr> args; + + if (compiler_arg_source == CompilerArgSource::TENSOR_SPEC) { + args = BuildXlaCompilerArgumentFromTensorSpec(fbody, constant_arg_indices, + inputs, variable_infos, dev, + input_arg_shape_and_dtype); + } else if (compiler_arg_source == CompilerArgSource::CONCRETE_INPUT) { + args = XlaComputationLaunchContext::BuildXlaCompilerArguments( + constant_arg_indices, inputs, variable_infos, dev); + } TF_RETURN_IF_ERROR(args.status()); - xla::LocalClient* local_client = cache->client(); + xla::LocalClient* local_client = xla_device_compiler->client(); XlaCompiler::CompilationResult result; TF_RETURN_IF_ERROR( compiler.CompileFunction(compile_options, function, *args, &result)); - switch (stage) { - case IrExportStage::HLO: - case IrExportStage::HLO_NO_METADATA: - case IrExportStage::HLO_SERIALIZED: { - TF_ASSIGN_OR_RETURN(xla::ProgramShape program_shape, - result.computation->GetProgramShape()); - xla::HloModuleConfig config(program_shape); - TF_ASSIGN_OR_RETURN( - std::unique_ptr new_module, - xla::HloModule::CreateFromProto(result.computation->proto(), config)); - - xla::HloPrintOptions opts; - if (stage == IrExportStage::HLO_NO_METADATA) { - opts.set_print_metadata(false); - } - - if (stage == IrExportStage::HLO_SERIALIZED) { - return new_module->ToProto().SerializeAsString(); - } else { - return new_module->ToString(opts); - } - } - case IrExportStage::OPTIMIZED_HLO: - case IrExportStage::OPTIMIZED_HLO_SERIALIZED: { - 
TF_ASSIGN_OR_RETURN(std::unique_ptr executable, - BuildExecutable(local_client, result, options)); - xla::Executable* new_executable = executable->executable(); - if (stage == IrExportStage::OPTIMIZED_HLO_SERIALIZED) { - return new_executable->module().ToProto().SerializeAsString(); - } else { - return new_executable->module().ToString(); - } - } - case IrExportStage::OPTIMIZED_HLO_PROTO_SERIALIZED: { - TF_ASSIGN_OR_RETURN(std::unique_ptr executable, - BuildExecutable(local_client, result, options, - /*xla_embed_ir_in_executable=*/true)); - return executable->executable()->hlo_proto()->SerializeAsString(); - } - case IrExportStage::OPTIMIZED_HLO_DOT: { - TF_ASSIGN_OR_RETURN(std::unique_ptr executable, - BuildExecutable(local_client, result, options)); - StatusOr graph = xla::RenderGraph( - *executable->executable()->module().entry_computation(), - "Visualization", - /*debug_options=*/{}, xla::RenderedGraphFormat::kDot, - /*hlo_render_options=*/{}); - TF_RETURN_IF_ERROR(graph.status()); - return *graph; - } - } + return BuildHLOString(stage, result, local_client, options); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/get_compiler_ir.h b/tensorflow/compiler/jit/get_compiler_ir.h index 10b37b54a7c..107c7a002b1 100644 --- a/tensorflow/compiler/jit/get_compiler_ir.h +++ b/tensorflow/compiler/jit/get_compiler_ir.h @@ -15,8 +15,12 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_GET_COMPILER_IR_H_ #define TENSORFLOW_COMPILER_JIT_GET_COMPILER_IR_H_ +#include +#include + #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/core/platform/statusor.h" namespace tensorflow { @@ -37,12 +41,24 @@ enum class IrExportStage { OPTIMIZED_HLO_DOT }; +struct ArgShapeAndDType { + TensorShape shape; + DataType dtype; +}; + +enum class CompilerArgSource { + TENSOR_SPEC, + CONCRETE_INPUT, +}; + // Returns the IR format of the selected stage for a given function `func_name` // using library runtime `runtime` on a device `dev` with given `inputs`. StatusOr GetCompilerIr( IrExportStage stage, ProcessFunctionLibraryRuntime* pflr, absl::string_view func_name, Device* dev, EagerContext* context, - absl::Span inputs); + absl::Span flat_arg_shape_and_dtype_or_empty, + absl::Span input_handles, + CompilerArgSource compiler_arg_source); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/introduce_floating_point_jitter_pass.cc b/tensorflow/compiler/jit/introduce_floating_point_jitter_pass.cc deleted file mode 100644 index 64370a609d8..00000000000 --- a/tensorflow/compiler/jit/introduce_floating_point_jitter_pass.cc +++ /dev/null @@ -1,153 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/jit/introduce_floating_point_jitter_pass.h" -#include "absl/algorithm/container.h" -#include "absl/container/flat_hash_map.h" -#include "tensorflow/cc/framework/scope_internal.h" -#include "tensorflow/cc/ops/const_op.h" -#include "tensorflow/cc/ops/math_ops.h" -#include "tensorflow/compiler/jit/flags.h" -#include "tensorflow/core/graph/tensor_id.h" - -namespace tensorflow { -namespace { -std::vector>> GetNodesToModify( - const Graph& g, absl::Span tensor_names) { - absl::flat_hash_map name_to_node; - for (Node* n : g.op_nodes()) { - name_to_node[n->name()] = n; - } - - absl::flat_hash_map> nodes_to_modify_map; - - for (const string& tensor_name : tensor_names) { - TensorId tensor_id = ParseTensorName(tensor_name); - auto it = name_to_node.find(tensor_id.node()); - DCHECK(it != name_to_node.end()); - nodes_to_modify_map[it->second].push_back(tensor_id.index()); - } - - std::vector>> nodes_to_modify; - absl::c_copy(nodes_to_modify_map, std::back_inserter(nodes_to_modify)); - - absl::c_sort(nodes_to_modify, - [](const std::pair>& a, - const std::pair>& b) { - return a.first->id() < b.first->id(); - }); - - for (auto& p : nodes_to_modify) { - absl::c_sort(p.second); - p.second.erase(std::unique(p.second.begin(), p.second.end()), - p.second.end()); - } - - return nodes_to_modify; -} - -Status IntroduceJitterToTensor( - Graph* g, Node* n, int oidx, float jitter_amount, - absl::flat_hash_map, Output>* - node_to_jitter_constant) { - std::vector edges_to_update; - absl::c_copy_if(n->out_edges(), std::back_inserter(edges_to_update), - [&](const Edge* e) { return e->src_output() == oidx; }); - - if (edges_to_update.empty()) { - VLOG(1) << "No users for " << TensorId(n->name(), oidx).ToString(); - return OkStatus(); - } - - VLOG(1) << "Updating " << edges_to_update.size() << " users for " - << TensorId(n->name(), oidx).ToString(); - - Status status; - Scope s = NewInternalScope(g, &status, /*refiner=*/nullptr) - .NewSubScope(absl::StrCat(n->name(), "/jitter")); - - Output node_out(n, oidx); - Output jitter_constant; - DataType dtype = n->output_type(oidx); - auto it = node_to_jitter_constant->find({dtype, n}); - if (it == node_to_jitter_constant->end()) { - Tensor constant_tensor; - if (dtype == DT_FLOAT) { - constant_tensor = Tensor(static_cast(jitter_amount)); - } else if (dtype == DT_HALF) { - constant_tensor = Tensor(Eigen::half(jitter_amount)); - } else { - return errors::Unimplemented("Only float and half are supported"); - } - - jitter_constant = - ops::Const(s.WithOpName("jitter_amount"), constant_tensor); - (*node_to_jitter_constant)[{dtype, n}] = jitter_constant; - } else { - jitter_constant = it->second; - } - - Output jittered_output = - ops::Add(s.NewSubScope(absl::StrCat(oidx)).WithOpName("jittered_output"), - jitter_constant, node_out); - - TF_RETURN_IF_ERROR(status); - - for (const Edge* e : edges_to_update) { - VLOG(3) << "Updating " << e->dst()->name(); - TF_RETURN_IF_ERROR( - g->UpdateEdge(jittered_output.node(), 0, e->dst(), e->dst_input())); - } - - // Add a control edge to make sure that the two inputs to jittered_output are - // from the same frame. 
- g->AddControlEdge(n, jitter_constant.node()); - - return OkStatus(); -} -} // namespace - -Status IntroduceFloatingPointJitter(Graph* graph, - absl::Span tensor_names, - float jitter_amount) { - if (tensor_names.empty()) { - VLOG(3) << "Nothing to do"; - return OkStatus(); - } - - std::vector>> nodes_to_modify = - GetNodesToModify(*graph, tensor_names); - - absl::flat_hash_map, Output> - node_to_jitter_constant; - for (const auto& p : nodes_to_modify) { - for (int oidx : p.second) { - TF_RETURN_IF_ERROR(IntroduceJitterToTensor( - graph, p.first, oidx, jitter_amount, &node_to_jitter_constant)); - } - } - - return OkStatus(); -} - -Status IntroduceFloatingPointJitterPass::Run( - const GraphOptimizationPassOptions& options) { - const IntroduceFloatingPointJitterPassFlags& flags = - GetIntroduceFloatingPointJitterPassFlags(); - - return IntroduceFloatingPointJitter(options.graph->get(), flags.tensor_names, - flags.jitter_amount); -} -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/introduce_floating_point_jitter_pass.h b/tensorflow/compiler/jit/introduce_floating_point_jitter_pass.h deleted file mode 100644 index 115f72a6eea..00000000000 --- a/tensorflow/compiler/jit/introduce_floating_point_jitter_pass.h +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_JIT_INTRODUCE_FLOATING_POINT_JITTER_PASS_H_ -#define TENSORFLOW_COMPILER_JIT_INTRODUCE_FLOATING_POINT_JITTER_PASS_H_ - -#include "tensorflow/core/common_runtime/optimization_registry.h" - -namespace tensorflow { -// A debug-only pass that introduces error into outputs of specific TF nodes. -// This can be used to check the sensitivity of a TF graph to floating point -// rounding differences. -// -// This pass is controlled by TF_XLA_FLAGS. Please see -// IntroduceFloatingPointJitterPassFlags for information on how to use this. -class IntroduceFloatingPointJitterPass : public GraphOptimizationPass { - public: - IntroduceFloatingPointJitterPass() = default; - - Status Run(const GraphOptimizationPassOptions& options) override; -}; -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_JIT_INTRODUCE_FLOATING_POINT_JITTER_PASS_H_ diff --git a/tensorflow/compiler/jit/introduce_floating_point_jitter_pass_internal.h b/tensorflow/compiler/jit/introduce_floating_point_jitter_pass_internal.h deleted file mode 100644 index ea7261bc872..00000000000 --- a/tensorflow/compiler/jit/introduce_floating_point_jitter_pass_internal.h +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_JIT_INTRODUCE_FLOATING_POINT_JITTER_PASS_INTERNAL_H_ -#define TENSORFLOW_COMPILER_JIT_INTRODUCE_FLOATING_POINT_JITTER_PASS_INTERNAL_H_ - -#include "absl/types/span.h" -#include "tensorflow/core/common_runtime/optimization_registry.h" - -namespace tensorflow { -Status IntroduceFloatingPointJitter(Graph* graph, - absl::Span tensor_names, - float jitter_amount); -} - -#endif // TENSORFLOW_COMPILER_JIT_INTRODUCE_FLOATING_POINT_JITTER_PASS_INTERNAL_H_ diff --git a/tensorflow/compiler/jit/introduce_floating_point_jitter_pass_test.cc b/tensorflow/compiler/jit/introduce_floating_point_jitter_pass_test.cc deleted file mode 100644 index 25155a133d7..00000000000 --- a/tensorflow/compiler/jit/introduce_floating_point_jitter_pass_test.cc +++ /dev/null @@ -1,197 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/jit/introduce_floating_point_jitter_pass_internal.h" - -#include "tensorflow/cc/framework/ops.h" -#include "tensorflow/cc/ops/array_ops.h" -#include "tensorflow/cc/ops/const_op.h" -#include "tensorflow/cc/ops/linalg_ops.h" -#include "tensorflow/cc/ops/math_ops.h" -#include "tensorflow/compiler/jit/node_matchers.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace { - -using testing::matchers::Const; -using testing::matchers::Inputs; -using testing::matchers::Name; -using testing::matchers::NodeWith; -using testing::matchers::Op; -using testing::matchers::Out; - -TEST(IntroduceFloatingPointJitterTest, SingleOutputFP32) { - Scope root = Scope::NewRootScope().ExitOnError(); - - Output input_a = ops::Placeholder(root.WithOpName("input_a"), DT_FLOAT); - Output input_b = ops::Placeholder(root.WithOpName("input_b"), DT_FLOAT); - - Output sigmoid_a = ops::Sigmoid(root.WithOpName("sigmoid_a"), input_a); - Output sigmoid_b = ops::Sigmoid(root.WithOpName("sigmoid_b"), input_b); - - Output tanh_a = ops::Tanh(root.WithOpName("tanh_a"), sigmoid_a); - Output tanh_b = ops::Tanh(root.WithOpName("tanh_b"), sigmoid_b); - - auto graph = std::make_unique(OpRegistry::Global()); - TF_ASSERT_OK(root.ToGraph(graph.get())); - - std::vector tensor_names; - tensor_names.push_back("sigmoid_a"); - tensor_names.push_back("sigmoid_b"); - - TF_ASSERT_OK(IntroduceFloatingPointJitter(graph.get(), tensor_names, 0.01f)); - VLOG(1) << graph->ToGraphDefDebug().DebugString(); - - auto m_sigmoid_a = Out(NodeWith(Name("sigmoid_a"))); - auto m_sigmoid_a_with_jitter = - NodeWith(Op("Add"), Inputs(Const(0.01f), m_sigmoid_a)); - auto m_tanh_a = NodeWith(Op("Tanh"), Inputs(Out(m_sigmoid_a_with_jitter))); - - auto m_sigmoid_b = Out(NodeWith(Name("sigmoid_b"))); - auto m_sigmoid_b_with_jitter = - NodeWith(Op("Add"), Inputs(Const(0.01f), m_sigmoid_b)); - auto m_tanh_b = NodeWith(Op("Tanh"), Inputs(Out(m_sigmoid_b_with_jitter))); - - Node* tanh_a_transformed = testing::FindNodeByName(graph.get(), "tanh_a"); - Node* tanh_b_transformed = testing::FindNodeByName(graph.get(), "tanh_b"); - - ASSERT_NE(tanh_a_transformed, nullptr); - ASSERT_NE(tanh_b_transformed, nullptr); - - EXPECT_THAT(tanh_a_transformed, m_tanh_a); - EXPECT_THAT(tanh_b_transformed, m_tanh_b); -} - -TEST(IntroduceFloatingPointJitterTest, TwoNodesOneUser) { - Scope root = Scope::NewRootScope().ExitOnError(); - - Output input_a = ops::Placeholder(root.WithOpName("input_a"), DT_FLOAT); - Output input_b = ops::Placeholder(root.WithOpName("input_b"), DT_FLOAT); - - Output sigmoid_a = ops::Sigmoid(root.WithOpName("sigmoid_a"), input_a); - Output sigmoid_b = ops::Sigmoid(root.WithOpName("sigmoid_b"), input_b); - - Output add = ops::Add(root.WithOpName("add"), sigmoid_a, sigmoid_b); - - auto graph = std::make_unique(OpRegistry::Global()); - TF_ASSERT_OK(root.ToGraph(graph.get())); - - std::vector tensor_names; - tensor_names.push_back("sigmoid_a"); - tensor_names.push_back("sigmoid_b"); - - TF_ASSERT_OK(IntroduceFloatingPointJitter(graph.get(), tensor_names, 0.01f)); - VLOG(1) << graph->ToGraphDefDebug().DebugString(); - - auto m_sigmoid_a = Out(NodeWith(Name("sigmoid_a"))); - auto m_sigmoid_a_with_jitter = - NodeWith(Op("Add"), Inputs(Const(0.01f), m_sigmoid_a)); - - auto m_sigmoid_b = Out(NodeWith(Name("sigmoid_b"))); - auto m_sigmoid_b_with_jitter = - NodeWith(Op("Add"), 
Inputs(Const(0.01f), m_sigmoid_b)); - - auto m_add = NodeWith(Op("Add"), Inputs(Out(m_sigmoid_a_with_jitter), - Out(m_sigmoid_b_with_jitter))); - - Node* add_transformed = testing::FindNodeByName(graph.get(), "add"); - - ASSERT_NE(add_transformed, nullptr); - - EXPECT_THAT(add_transformed, m_add); -} - -TEST(IntroduceFloatingPointJitterTest, NotFP32) { - Scope root = Scope::NewRootScope().ExitOnError(); - - Output input = ops::Placeholder(root.WithOpName("input"), DT_HALF); - - Output sigmoid = ops::Sigmoid(root.WithOpName("sigmoid"), input); - - Output tanh = ops::Tanh(root.WithOpName("tanh"), sigmoid); - - auto graph = std::make_unique(OpRegistry::Global()); - TF_ASSERT_OK(root.ToGraph(graph.get())); - - std::vector tensor_names; - tensor_names.push_back("sigmoid"); - - TF_ASSERT_OK(IntroduceFloatingPointJitter(graph.get(), tensor_names, 0.01f)); - VLOG(1) << graph->ToGraphDefDebug().DebugString(); - - auto m_sigmoid = Out(NodeWith(Name("sigmoid"))); - auto m_sigmoid_with_jitter = - NodeWith(Op("Add"), Inputs(Const(Tensor(Eigen::half(0.01f))), m_sigmoid)); - auto m_tanh = NodeWith(Op("Tanh"), Inputs(Out(m_sigmoid_with_jitter))); - - Node* tanh_transformed = testing::FindNodeByName(graph.get(), "tanh"); - - ASSERT_NE(tanh_transformed, nullptr); - - EXPECT_THAT(tanh_transformed, m_tanh); -} - -TEST(IntroduceFloatingPointJitterTest, MultiOutput) { - Scope root = Scope::NewRootScope().ExitOnError(); - - Output input = ops::Placeholder(root.WithOpName("input"), DT_HALF); - - ops::Svd svd(root.WithOpName("svd"), input); - - Output tanh_s = ops::Tanh(root.WithOpName("tanh_s"), svd.s); - Output tanh_u = ops::Tanh(root.WithOpName("tanh_u"), svd.u); - Output tanh_v = ops::Tanh(root.WithOpName("tanh_v"), svd.v); - - auto graph = std::make_unique(OpRegistry::Global()); - TF_ASSERT_OK(root.ToGraph(graph.get())); - - std::vector tensor_names; - tensor_names.push_back("svd:0"); - tensor_names.push_back("svd:2"); - - TF_ASSERT_OK(IntroduceFloatingPointJitter(graph.get(), tensor_names, 0.01f)); - VLOG(1) << graph->ToGraphDefDebug().DebugString(); - - auto m_svd_s = Out(0, NodeWith(Name("svd"))); - auto m_svd_s_with_jitter = Out( - NodeWith(Op("Add"), Inputs(Const(Tensor(Eigen::half(0.01f))), m_svd_s))); - - auto m_svd_u = Out(1, NodeWith(Name("svd"))); - - auto m_svd_v = Out(2, NodeWith(Name("svd"))); - auto m_svd_v_with_jitter = Out( - NodeWith(Op("Add"), Inputs(Const(Tensor(Eigen::half(0.01f))), m_svd_v))); - - auto m_tanh_s = NodeWith(Op("Tanh"), Inputs(m_svd_s_with_jitter)); - auto m_tanh_u = NodeWith(Op("Tanh"), Inputs(m_svd_u)); - auto m_tanh_v = NodeWith(Op("Tanh"), Inputs(m_svd_v_with_jitter)); - - Node* tanh_s_transformed = testing::FindNodeByName(graph.get(), "tanh_s"); - ASSERT_NE(tanh_s_transformed, nullptr); - - Node* tanh_u_transformed = testing::FindNodeByName(graph.get(), "tanh_u"); - ASSERT_NE(tanh_u_transformed, nullptr); - - Node* tanh_v_transformed = testing::FindNodeByName(graph.get(), "tanh_v"); - ASSERT_NE(tanh_v_transformed, nullptr); - - EXPECT_THAT(tanh_s_transformed, m_tanh_s); - EXPECT_THAT(tanh_u_transformed, m_tanh_u); - EXPECT_THAT(tanh_v_transformed, m_tanh_v); -} -} // namespace -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc index 8adfff34c66..a4a9c7d7e11 100644 --- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc +++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc @@ -20,7 +20,6 @@ limitations under the License. 
#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" #include "tensorflow/compiler/jit/force_xla_constants_on_host_pass.h" #include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h" -#include "tensorflow/compiler/jit/introduce_floating_point_jitter_pass.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/jit/partially_decluster_pass.h" #include "tensorflow/compiler/jit/report_clustering_info_pass.h" @@ -35,9 +34,6 @@ namespace tensorflow { REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 36, EncapsulateXlaComputationsPass); -REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 35, - IntroduceFloatingPointJitterPass); - // from // tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc // FunctionalizeControlFlowPass: 27 diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 8f1d9844c24..8fb1aecf8c5 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -1,6 +1,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ "//tensorflow/compiler/tf2xla:internal", "//tensorflow/core/tpu:__subpackages__", @@ -17,7 +18,7 @@ XLA_OPS_DEPS = [ "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/jit:xla_activity_listener", "//tensorflow/compiler/jit:xla_activity_proto_cc", - "//tensorflow/compiler/jit:xla_compilation_cache", + "//tensorflow/compiler/jit:device_compiler", "//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration", "//tensorflow/compiler/jit:xla_cluster_util", "//tensorflow/compiler/jit:xla_launch_util", @@ -48,7 +49,14 @@ cc_library( name = "xla_ops_no_jit_rewrite_registration", srcs = ["xla_ops.cc"], hdrs = ["xla_ops.h"], - deps = XLA_OPS_DEPS + ["//tensorflow/core/platform:refcount"], + deps = XLA_OPS_DEPS + [ + "//tensorflow/compiler/jit:device_compilation_cache", + "//tensorflow/compiler/jit:device_compilation_profiler", + "//tensorflow/compiler/jit:tf_graph_to_hlo_compiler", + "//tensorflow/compiler/jit:tf_to_hlo_compiler", + "//tensorflow/compiler/jit:xla_compile_util", + "//tensorflow/core/platform:refcount", + ], alwayslink = 1, ) diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 9bebf2be3b9..8868a4c8229 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -18,16 +18,19 @@ limitations under the License. #include #include #include +#include #include #include #include #include "absl/container/flat_hash_map.h" #include "absl/types/optional.h" +#include "tensorflow/compiler/jit/device_compilation_profiler.h" +#include "tensorflow/compiler/jit/device_compiler.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/xla_activity_listener.h" -#include "tensorflow/compiler/jit/xla_compilation_cache.h" +#include "tensorflow/compiler/jit/xla_compile_util.h" #include "tensorflow/compiler/jit/xla_platform_info.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -66,6 +69,8 @@ limitations under the License. 
namespace tensorflow { namespace { +using XlaDeviceCompiler = + DeviceCompiler; auto* xla_launch_counter = monitoring::Counter<1>::New( "/tensorflow/core/xla_launch_counter", @@ -225,102 +230,6 @@ GetXlaCompilerArgsAndSnapshotVariables( return result; } -struct CompilationResults { - const XlaCompiler::CompilationResult* compilation_result = nullptr; - xla::LocalExecutable* executable = nullptr; - xla::LocalClient* client = nullptr; - // This needs to be a shared pointer because it needs to be captured in an - // std::function, and std::function requires all captured arguments to be - // copyable. - std::shared_ptr> variable_infos; -}; - -struct CompilerArgsAndVariableSnapshots { - std::vector xla_compiler_args; - ResourceVarsSnapshot variable_snapshots; -}; - -// CompilationResults is filled when the XLA cluster is already compiled. -// CompilerArgsAndVariableSnapshots is filled when the XLA cluster is not -// compiled yet. -using CompilationResultOrXlaCompilerArgsAndVariableSnapshots = - std::variant; - -// If the XLA Cluster is already compiled, we don't snapshot the variables and -// return the compilation result and executable. -// If the XLA Cluster is not compiled yet, we snapshot the variables and return -// the variable snapshots and the XLA compiler arguments. -StatusOr -GetCompilationResultOrGetXlaCompilerArgsAndSnapshotVariables( - const XlaPlatformInfo& platform_info, const NameAttrList& function, - absl::Span variable_indices, - absl::Span must_be_constant_idxs, - absl::Span inputs, OpKernelContext* ctx) { - CompilationResults compilation_results; - CompilerArgsAndVariableSnapshots compiler_args_and_variable_snapshots; - - compilation_results.variable_infos = - std::make_shared>(); - std::vector& variable_infos = - *compilation_results.variable_infos; - TF_RETURN_IF_ERROR( - GetVariableInfosFromInputs(ctx->resource_manager(), ctx->device(), inputs, - variable_indices, &variable_infos)); - TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos))); - - TF_ASSIGN_OR_RETURN(compiler_args_and_variable_snapshots.xla_compiler_args, - XlaComputationLaunchContext::BuildXlaCompilerArguments( - must_be_constant_idxs, inputs, variable_infos, - static_cast(ctx->device()))); - - // We store information about the JIT-compiled XLA computation - // in the ResourceMgr. - ResourceMgr* rm = ctx->resource_manager(); - if (!rm) { - return errors::Internal("No resource manager."); - } - - XlaCompilationCache* cache; - TF_RETURN_IF_ERROR(rm->LookupOrCreate( - rm->default_container(), "xla_cache", &cache, - [&](XlaCompilationCache** cache) { - return BuildXlaCompilationCache(ctx->device(), ctx->function_library(), - platform_info, cache); - })); - // Hold the reference to the JIT during evaluation. (We could probably - // free it sooner because the ResourceMgr will retain a reference, but - // this is more obviously correct.) - core::ScopedUnref cache_ref(cache); - - TF_ASSIGN_OR_RETURN( - auto compilation_result_and_executable, - cache->GetCompilationResultIfAlreadyCompiled( - function, compiler_args_and_variable_snapshots.xla_compiler_args)); - - if (compilation_result_and_executable.executable != nullptr) { - // The XLA signature is already compiled. We return the compilation result - // and the locks on the variables(contained in variable_infos). 
- compilation_results.compilation_result = - compilation_result_and_executable.compilation_result; - compilation_results.executable = - compilation_result_and_executable.executable; - compilation_results.client = - static_cast(cache->client()); - return CompilationResultOrXlaCompilerArgsAndVariableSnapshots( - std::move(compilation_results)); - } - - // The XLA signature is not compiled yet. We snapshot the variables and - // release the locks on the variables(by destructing variable_infos). We - // return the compiler arguments and variable snapshots. - TF_RETURN_IF_ERROR(SnapshotResourceVariables( - ctx, variable_indices, variable_infos, - &compiler_args_and_variable_snapshots.variable_snapshots)); - - return CompilationResultOrXlaCompilerArgsAndVariableSnapshots( - std::move(compiler_args_and_variable_snapshots)); -} - } // namespace XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx, @@ -339,8 +248,8 @@ static Status CompileToLocalExecutable( OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars, const XlaPlatformInfo& platform_info, const std::vector& args, - XlaCompilationCache::CompileMode compile_mode, - bool may_alias_resource_update, xla::LocalClient** client, + DeviceCompileMode compile_mode, bool may_alias_resource_update, + xla::LocalClient** client, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable) { // We store information about the JIT-compiled XLA computation @@ -350,23 +259,31 @@ static Status CompileToLocalExecutable( return errors::Internal("No resource manager."); } - XlaCompilationCache* cache; - TF_RETURN_IF_ERROR(rm->LookupOrCreate( - rm->default_container(), "xla_cache", &cache, - [&](XlaCompilationCache** cache) { - return BuildXlaCompilationCache(ctx->device(), ctx->function_library(), - platform_info, cache); + XlaDeviceCompiler* xla_device_compiler; + TF_RETURN_IF_ERROR(rm->LookupOrCreate( + rm->default_container(), "xla_device_compiler", &xla_device_compiler, + [&](XlaDeviceCompiler** xla_device_compiler) { + return BuildXlaDeviceCompiler(ctx->device(), ctx->function_library(), + platform_info, xla_device_compiler); + })); + DeviceCompilationProfiler* profiler; + TF_RETURN_IF_ERROR(rm->LookupOrCreate( + rm->default_container(), "device_compilation_profiler", &profiler, + [](DeviceCompilationProfiler** profiler) { + *profiler = new DeviceCompilationProfiler(); + return OkStatus(); })); - // Hold the reference to the JIT during evaluation. (We could probably - // free it sooner because the ResourceMgr will retain a reference, but - // this is more obviously correct.) - core::ScopedUnref cache_ref(cache); + // Hold the reference to the XLA device compiler and profiler during + // evaluation. (We could probably free them sooner because the ResourceMgr + // will retain references, but this is more obviously correct.) 
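+  // ScopedUnref releases the reference that LookupOrCreate added once this
+  // function returns; the ResourceMgr still holds its own reference to both
+  // resources.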
+ core::ScopedUnref xla_device_compiler_ref(xla_device_compiler); + core::ScopedUnref profiler_ref(profiler); - *client = static_cast(cache->client()); + *client = static_cast(xla_device_compiler->client()); - XlaCompiler::Options options = - GenerateCompilerOptions(*cache, *ctx->function_library(), ctx->device(), - GetStream(ctx), platform_info, has_ref_vars); + XlaCompiler::Options options = GenerateCompilerOptions( + *xla_device_compiler, *ctx->function_library(), ctx->device(), + GetStream(ctx), platform_info, has_ref_vars); XlaCompiler::CompileOptions compile_options; compile_options.is_entry_computation = true; @@ -376,8 +293,9 @@ static Status CompileToLocalExecutable( compile_options.alias_resource_update = !has_ref_vars && may_alias_resource_update; - return cache->Compile(options, function, args, compile_options, compile_mode, - compilation_result, executable); + return xla_device_compiler->CompileIfNeeded( + options, function, args, compile_options, compile_mode, profiler, + compilation_result, executable); } // Get-or-create thread pool for a given collective. @@ -410,67 +328,61 @@ void XlaLocalLaunchBase::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { xla::LocalClient* client; const XlaCompiler::CompilationResult* compilation_result; xla::LocalExecutable* executable; + std::vector xla_compiler_args; - auto compilation_result_or_args_and_variables_snapshot = - GetCompilationResultOrGetXlaCompilerArgsAndSnapshotVariables( - platform_info_, function_, resources_, constants_, inputs, ctx); - OP_REQUIRES_OK_ASYNC( - ctx, compilation_result_or_args_and_variables_snapshot.status(), done); - - std::shared_ptr> variable_infos; - ResourceVarsSnapshot variables_snapshot; - if (std::holds_alternative( - *compilation_result_or_args_and_variables_snapshot)) { - auto& compilation_results = std::get( - *compilation_result_or_args_and_variables_snapshot); - // This is when the signature is already compiled. - client = compilation_results.client; - compilation_result = compilation_results.compilation_result; - executable = compilation_results.executable; - variable_infos = compilation_results.variable_infos; - } else { - auto& compiler_args_and_variable_snapshots = - std::get( - *compilation_result_or_args_and_variables_snapshot); - // This is when the signature is not compiled yet. - const std::vector& args = - compiler_args_and_variable_snapshots.xla_compiler_args; - variables_snapshot = - std::move(compiler_args_and_variable_snapshots.variable_snapshots); - - const Status s = CompileToLocalExecutable( - ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_, args, - XlaCompilationCache::CompileMode::kStrict, - /*may_alias_resource_update=*/true, &client, &compilation_result, - &executable); - OP_REQUIRES_OK_ASYNC(ctx, s, done); + // Note that here we assume the shape of the variables don't change between + // compilation and execution. The locks on the variables are released before + // compilation so that we can achieve parallel compilation of different batch + // sizes during warm-up. + { + // Creating a scope so that the locks on the variables are released when + // variable_infos goes out of scope. + std::vector variable_infos; + std::set variables_updated; + // Here we only need to reader-lock the variables, so we pass an empty + // variables_updated set here. 
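+    // Passing an empty variables_updated set means none of these variables
+    // are marked as written, so the locks taken below can be shared (reader)
+    // locks and warm-up compilations for other batch sizes are not
+    // serialized on them.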
+ Status status = GetVariableInfosFromInputs( + ctx->resource_manager(), ctx->device(), inputs, resources_, + &variables_updated, &variable_infos); + OP_REQUIRES_OK_ASYNC(ctx, status, done); + status = LockVariables(absl::MakeSpan(variable_infos)); + OP_REQUIRES_OK_ASYNC(ctx, status, done); + auto status_or_xla_compiler_args = + XlaComputationLaunchContext::BuildXlaCompilerArguments( + constants_, inputs, variable_infos, + static_cast(ctx->device())); + OP_REQUIRES_OK_ASYNC(ctx, status_or_xla_compiler_args.status(), done); + xla_compiler_args = std::move(status_or_xla_compiler_args.value()); } + Status status = CompileToLocalExecutable( + ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_, + xla_compiler_args, DeviceCompileMode::kStrict, + /*may_alias_resource_update=*/true, &client, &compilation_result, + &executable); + OP_REQUIRES_OK_ASYNC(ctx, status, done); // Continuation of the execution, may be run in a different thread. - // - // Note that only one of variables_snapshot and variable_infos contains - // value. variables_snapshot contains value when the signature is not - // compiled yet before XlaLocalLaunchBase::ComputeAsync is called. Otherwise - // variable_infos contains value. variable_infos also contains locks on the - // variables. - auto run_xla_cluster = [ctx, - variables_snapshot = std::move(variables_snapshot), - variable_infos = std::move(variable_infos), client, - executable, compilation_result, done, inputs, - resources = resources_]() { + auto run_xla_cluster = [ctx, client, executable, compilation_result, done, + inputs, resources = resources_]() { auto platform_info = XlaPlatformInfoFromDevice(ctx->device()); - std::map resource_var_ptrs; - if (variable_infos.get() == nullptr) { - for (const auto& [variable_index, variable_tensor] : variables_snapshot) { - resource_var_ptrs.emplace(variable_index, variable_tensor.has_value() - ? &variable_tensor.value() - : nullptr); - } - } else { - for (int i = 0; i < resources.size(); i++) { - resource_var_ptrs[resources[i]] = (*variable_infos)[i].var()->tensor(); + std::vector variable_infos; + std::set variables_updated; + for (const auto& resource_update : compilation_result->resource_updates) { + if (resource_update.modified) { + variables_updated.insert(resource_update.input_index); } } + OP_REQUIRES_OK_ASYNC(ctx, + GetVariableInfosFromInputs( + ctx->resource_manager(), ctx->device(), inputs, + resources, &variables_updated, &variable_infos), + done); + OP_REQUIRES_OK_ASYNC(ctx, LockVariables(absl::MakeSpan(variable_infos)), + done); + std::map resource_var_ptrs; + for (int i = 0; i < resources.size(); i++) { + resource_var_ptrs[resources[i]] = variable_infos[i].var()->tensor(); + } std::shared_ptr allocator = GetAllocator(ctx->device(), GetStream(ctx), platform_info); @@ -509,30 +421,11 @@ void XlaLocalLaunchBase::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { OP_REQUIRES_ASYNC(ctx, execution_output.ok(), execution_output.status(), done); - std::vector local_variable_infos; - std::vector& variable_infos_ref = - variable_infos.get() == nullptr ? local_variable_infos - : *variable_infos; - - // We need to hold the locks on the variables since we are going - // to write to them. If variable_infos is null, then we need to acquire the - // locks using local_variable_infos. If variable_infos is not null, then we - // are already holding the locks on the variables through variable_infos. 
- if (variable_infos.get() == nullptr) { - OP_REQUIRES_OK_ASYNC( - ctx, - GetVariableInfosFromInputs(ctx->resource_manager(), ctx->device(), - inputs, resources, &local_variable_infos), - done); - OP_REQUIRES_OK_ASYNC( - ctx, LockVariables(absl::MakeSpan(local_variable_infos)), done); - } - OP_REQUIRES_OK_ASYNC( ctx, launch_context.PopulateOutputs( ctx, compilation_result, execution_output->ConsumeResult(), - /*missing_ctx_input_prefix=*/0, absl::MakeSpan(variable_infos_ref), + /*missing_ctx_input_prefix=*/0, absl::MakeSpan(variable_infos), input_output_alias, resource_var_ptrs), done); VLOG(1) << "Done"; @@ -633,16 +526,16 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { mutex_lock guard(cannot_compile_cluster_mu_); cannot_compile_cluster = cannot_compile_cluster_; } - XlaCompilationCache::CompileMode compile_mode = [&] { + DeviceCompileMode compile_mode = [&] { if (must_compile_) { - return XlaCompilationCache::CompileMode::kStrict; + return DeviceCompileMode::kStrict; } - return GetXlaOpsCommonFlags().tf_xla_async_compilation - ? XlaCompilationCache::CompileMode::kAsync - : XlaCompilationCache::CompileMode::kLazy; + return GetXlaOpsCommonFlags()->tf_xla_async_compilation + ? DeviceCompileMode::kAsync + : DeviceCompileMode::kLazy; }(); - if (GetXlaOpsCommonFlags().tf_xla_always_defer_compilation || + if (GetXlaOpsCommonFlags()->tf_xla_always_defer_compilation || cannot_compile_cluster) { executable = nullptr; } else { @@ -658,7 +551,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { const Status status = CompileToLocalExecutable( ctx, function_, has_ref_vars_, platform_info_, args, compile_mode, /*may_alias_resource_update=*/false, &client, &kernel, &executable); - if (compile_mode != XlaCompilationCache::CompileMode::kLazy || + if (compile_mode != DeviceCompileMode::kLazy || status.code() != error::UNIMPLEMENTED) { OP_REQUIRES_OK(ctx, status); } diff --git a/tensorflow/compiler/jit/kernels/xla_ops.h b/tensorflow/compiler/jit/kernels/xla_ops.h index a5e77259d0f..03a9fa1e85d 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.h +++ b/tensorflow/compiler/jit/kernels/xla_ops.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "tensorflow/compiler/jit/xla_compilation_cache.h" +#include "tensorflow/compiler/jit/device_compiler.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_launch_util.h" #include "tensorflow/compiler/jit/xla_platform_info.h" @@ -72,7 +72,7 @@ class XlaLocalLaunchBase : public AsyncOpKernel { // which will be compiled and executed using XLA. The XlaLocalLaunchOp is // responsible for handling interactions with the TensorFlow executor. // Once all inputs are present, and their shapes are known, the op can -// use a 'XlaCompilationCache' to compile and execute code which is specific +// use a 'DeviceCompiler' to compile and execute code which is specific // to the shapes of input Tensors. 
// XlaLocalLaunchOp uses xla::LocalClient::Compile() and // xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device @@ -136,4 +136,4 @@ class XlaMergeOp : public OpKernel { } // namespace tensorflow -#endif // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_ +#endif // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_ diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index ab93e9f8ec3..2cb35d3e144 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -2144,6 +2144,8 @@ absl::flat_hash_set GetKnownXLAAllowlistOp() { "RngSkip", "Roll", "ScatterNd", + "SegmentSumV2", + "SegmentProdV2", "SelfAdjointEigV2", "SoftmaxCrossEntropyWithLogits", "SpaceToBatch", diff --git a/tensorflow/compiler/jit/ops/BUILD b/tensorflow/compiler/jit/ops/BUILD index 5144584adab..e70b5c2525d 100644 --- a/tensorflow/compiler/jit/ops/BUILD +++ b/tensorflow/compiler/jit/ops/BUILD @@ -2,6 +2,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//tensorflow/compiler/tf2xla:internal"], licenses = ["notice"], ) diff --git a/tensorflow/compiler/jit/pjrt_device_compiler_client.cc b/tensorflow/compiler/jit/pjrt_device_compiler_client.cc new file mode 100644 index 00000000000..605af0e6865 --- /dev/null +++ b/tensorflow/compiler/jit/pjrt_device_compiler_client.cc @@ -0,0 +1,81 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/pjrt_device_compiler_client.h" + +#include +#include +#include + +namespace tensorflow { + +namespace { +xla::CompileOptions GetPjRtCompileOptions( + const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& result) { + xla::CompileOptions pjrt_compile_options; + pjrt_compile_options.argument_layouts = result.xla_input_shapes; + pjrt_compile_options.executable_build_options = + GetExecutableBuildOptions(options, result, /*default_device_ordinal=*/-1); + // Compile portable executable for single device compilation. 
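+  // A portable executable defers device assignment to execution time, so the
+  // same compiled program can run on any addressable device of this client.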
+ pjrt_compile_options.compile_portable_executable = true; + return pjrt_compile_options; +} +} // namespace + +StatusOr> +PjRtDeviceCompilerClient::BuildExecutable( + const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& result) { + VLOG(2) << "Compiling to xla::PjRtLoadedExecutable."; + + TF_ASSIGN_OR_RETURN(auto executable, + client_->Compile(*result.computation, + GetPjRtCompileOptions(options, result))); + + VLOG(2) << "Compiled PJRT executable " << executable->name() + << " num_replicas " << executable->num_replicas() + << " num_partitions " << executable->num_partitions(); + + return std::move(executable); +} + +StatusOr PjRtDeviceCompilerClient::BuildSerializedExecutable( + const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& result) { + VLOG(1) << "PJRT currently doesn't support AOT compilation. Compiling to " + "xla::PjRtLoadedExecutable and serializing it"; + TF_ASSIGN_OR_RETURN(auto executable, BuildExecutable(options, result)); + return executable->SerializeExecutable(); +} + +StatusOr> +PjRtDeviceCompilerClient::LoadExecutable( + const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& result, + const std::string& serialized_executable) { + VLOG(1) << "Deserializing from string to xla::PjRtLoadedExecutable."; + return client_->DeserializeExecutable(serialized_executable, + GetPjRtCompileOptions(options, result)); +} + +void PjRtDeviceCompilerClient::WaitForProgramsToFinish() { + // TODO(b/255826209): Modify this if PjRtClient exposes a function to wait for + // programs to finish. + LOG(INFO) << "Unimplemented: PJRT uses futures and waiting for programs to " + "finish isn't necessary."; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/pjrt_device_compiler_client.h b/tensorflow/compiler/jit/pjrt_device_compiler_client.h new file mode 100644 index 00000000000..bd03d577f83 --- /dev/null +++ b/tensorflow/compiler/jit/pjrt_device_compiler_client.h @@ -0,0 +1,73 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_PJRT_DEVICE_COMPILER_CLIENT_H_ +#define TENSORFLOW_COMPILER_JIT_PJRT_DEVICE_COMPILER_CLIENT_H_ + +#include +#include +#include + +#include "tensorflow/compiler/jit/device_compiler_client.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" + +namespace tensorflow { + +// Calls into PjRtClient to provide functionality for building, serializing and +// loading PjRtLoadedExecutables. +class PjRtDeviceCompilerClient + : public DeviceCompilerClient { + public: + explicit PjRtDeviceCompilerClient(xla::PjRtClient* client) + : client_(client) {} + + StatusOr> BuildExecutable( + const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& result) override; + + // PjRt doesn't support AOT compilation yet. Builds a PjRtLoadedExecutable and + // serializes it to string. 
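+  // The returned string is produced by
+  // PjRtLoadedExecutable::SerializeExecutable and is only expected to
+  // round-trip through LoadExecutable on the same platform and version.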
+ StatusOr BuildSerializedExecutable( + const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& result) override; + + // Deserializes a serialized executable as produced by + // PjRtExecutable::SerializeExecutable(). `serialized_executable` must have + // been produced by a compiler of the same platform and version as this one. + // + // PjRt doesn't support AOT compilation yet. Loading a serialized executable + // is currently only implemented for TfrtTpuPjrtClient and hence, this + // function doesn't use PjRtClient::LoadSerializedExecutable() and uses + // PjRtClient::DeserializeExecutable() instead. + StatusOr> LoadExecutable( + const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& result, + const std::string& serialized_executable) override; + + // No-op. PJRT uses futures and waiting for programs to finish isn't + // necessary. + void WaitForProgramsToFinish() override; + + xla::PjRtClient* client() const override { return client_; } + + private: + xla::PjRtClient* const client_; + + TF_DISALLOW_COPY_AND_ASSIGN(PjRtDeviceCompilerClient); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_PJRT_DEVICE_COMPILER_CLIENT_H_ diff --git a/tensorflow/compiler/jit/pjrt_device_context.cc b/tensorflow/compiler/jit/pjrt_device_context.cc new file mode 100644 index 00000000000..e6d8157115a --- /dev/null +++ b/tensorflow/compiler/jit/pjrt_device_context.cc @@ -0,0 +1,124 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/pjrt_device_context.h" + +#include +#include + +#include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/core/framework/device.h" +#include "tensorflow/core/profiler/lib/traceme.h" +#include "tensorflow/core/tfrt/common/async_value_tensor.h" +#include "tensorflow/core/tfrt/common/pjrt_util.h" + +namespace tensorflow { +namespace { + +StatusOr> HostTensorToPjRtBuffer( + const tensorflow::Tensor* cpu_tensor, tensorflow::Device* device, + xla::PjRtClient* pjrt_client) { + // TODO(b/262472386): Consider layout_preference_fn and + // shape_representation_fn. 
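+  // The host tensor is handed to PJRT with zero-copy semantics below; the
+  // on_done_with_host_buffer callback captures the Tensor by value so the
+  // underlying host memory stays alive until PJRT releases it.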
+ xla::Shape shape; + TF_RETURN_IF_ERROR( + TensorShapeToXLAShape(cpu_tensor->dtype(), cpu_tensor->shape(), &shape)); + TF_ASSIGN_OR_RETURN(xla::PjRtDevice * pjrt_device, + pjrt_client->LookupDevice(device->parsed_name().id)); + TF_ASSIGN_OR_RETURN( + std::unique_ptr buffer, + pjrt_client->BufferFromHostBuffer( + cpu_tensor->data(), shape.element_type(), shape.dimensions(), + /*byte_strides=*/std::nullopt, + xla::PjRtClient::HostBufferSemantics::kZeroCopy, + /*on_done_with_host_buffer=*/ + [cpu_tensor = *cpu_tensor]() { /* frees tensor */ }, pjrt_device)); + return buffer; +} + +} // namespace + +void PjRtDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, + absl::string_view tensor_name, + Device* device, + Tensor* cpu_tensor, + StatusCallback done) { + profiler::TraceMe traceme("PjRtDeviceContext::CopyDeviceTensorToCPU"); + if (device_tensor->NumElements() == 0) { + VLOG(2) << "CopyDeviceTensorToCPU empty tensor"; + done(OkStatus()); + return; + } + auto literal = std::make_unique(); + auto status = tensorflow::HostTensorToMutableBorrowingLiteral(cpu_tensor, + literal.get()); + if (!status.ok()) { + done(status); + } + std::shared_ptr device_buffer = + tensorflow::AsyncValueTensor::FromTensor(device_tensor)->GetBuffer(); + xla::PjRtFuture future = device_buffer->ToLiteral(literal.get()); + future.OnReady([literal = std::move(literal), done = std::move(done), + device_buffer = std::move(device_buffer)]( + const tensorflow::Status& status) { done(status); }); +} + +void PjRtDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, + Device* device, + Tensor* device_tensor, + StatusCallback done, + bool sync_dst_compute) const { + profiler::TraceMe traceme("PjRtDeviceContext::CopyCPUTensorToDevice"); + if (cpu_tensor->NumElements() == 0) { + VLOG(2) << "CopyCPUTensorToDevice empty tensor"; + done(OkStatus()); + return; + } + AsyncValueTensor* result_tensor = + tensorflow::AsyncValueTensor::FromTensor(device_tensor); + // The result tensor should be newly allocated, which does not point to a + // valid buffer yet. + CHECK(!result_tensor->GetBuffer()); // Crash OK + // TODO(b/252887149): figure out how to cache PJRT client. + StatusOr pjrt_client = + GetOrCreatePjRtClient(DeviceType(device->device_type())); + if (!pjrt_client.ok()) { + done(pjrt_client.status()); + return; + } + StatusOr> buffer_or = + HostTensorToPjRtBuffer(cpu_tensor, device, *pjrt_client); + if (!buffer_or.ok()) { + done(buffer_or.status()); + return; + } + std::unique_ptr device_buffer = std::move(buffer_or.value()); + // TODO(b/244666476): evaluate the performance impact of marking ready when + // the data in device buffer is computed. In `tpu_device_context`, it is + // marked done when the allocation finished. + device_buffer->GetReadyFuture().OnReady(std::move(done)); + result_tensor->SetBuffer(std::move(device_buffer)); +} + +void PjRtDeviceContext::CopyTensorInSameDevice(const Tensor* input_tensor, + Device* device, + Tensor* output_tensor, + StatusCallback done) const { + done(errors::Unimplemented("Same-device copies not implemented.")); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/pjrt_device_context.h b/tensorflow/compiler/jit/pjrt_device_context.h new file mode 100644 index 00000000000..42e72dbd9d7 --- /dev/null +++ b/tensorflow/compiler/jit/pjrt_device_context.h @@ -0,0 +1,44 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_PJRT_DEVICE_CONTEXT_H_ +#define TENSORFLOW_COMPILER_JIT_PJRT_DEVICE_CONTEXT_H_ + +#include + +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// Helper class for managing data transfers between host and accelerator +// devices using PjRt. +class PjRtDeviceContext : public DeviceContext { + public: + void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, + Tensor* device_tensor, StatusCallback done, + bool sync_dst_compute) const override; + void CopyDeviceTensorToCPU(const Tensor* device_tensor, + absl::string_view tensor_name, Device* device, + Tensor* cpu_tensor, StatusCallback done) override; + void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device, + Tensor* output_tensor, + StatusCallback done) const override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_PJRT_DEVICE_CONTEXT_H_ diff --git a/tensorflow/compiler/jit/tests/BUILD b/tensorflow/compiler/jit/tests/BUILD index a4311584a6f..17e47dd9a81 100644 --- a/tensorflow/compiler/jit/tests/BUILD +++ b/tensorflow/compiler/jit/tests/BUILD @@ -3,7 +3,10 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test") licenses(["notice"]) -package(default_visibility = ["//visibility:private"]) +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//visibility:private"], +) cc_library( name = "auto_clustering_test_helper", @@ -51,10 +54,13 @@ tf_cc_test( ) cc_library( - name = "xla_compilation_cache_test_helper", + name = "device_compiler_test_helper", testonly = True, - srcs = ["xla_compilation_cache_test_helper.cc"], - hdrs = ["xla_compilation_cache_test_helper.h"], + srcs = ["device_compiler_test_helper.cc"], + hdrs = ["device_compiler_test_helper.h"], + visibility = [ + "//tensorflow/compiler/jit:__pkg__", + ], deps = [ "//tensorflow/compiler/jit:xla_activity_listener", "//tensorflow/compiler/jit:xla_compilation_cache_proto_cc", @@ -78,9 +84,9 @@ cc_library( ) tf_cc_test( - name = "xla_compilation_cache_serialize_test", + name = "device_compiler_serialize_test", srcs = [ - "xla_compilation_cache_serialize_test.cc", + "device_compiler_serialize_test.cc", ], env = { "XLA_FLAGS": "--xla_gpu_enable_xla_runtime_executable", @@ -92,7 +98,7 @@ tf_cc_test( "xla", ], deps = [ - ":xla_compilation_cache_test_helper", + ":device_compiler_test_helper", "//tensorflow/compiler/jit:compilation_passes", "//tensorflow/compiler/jit:flags", "//tensorflow/core:test", @@ -100,9 +106,9 @@ tf_cc_test( ) tf_cc_test( - name = "xla_compilation_cache_serialize_options_test", + name = "device_compiler_serialize_options_test", srcs = [ - "xla_compilation_cache_serialize_options_test.cc", + "device_compiler_serialize_options_test.cc", ], env = { "XLA_FLAGS": "--xla_gpu_enable_xla_runtime_executable", 
@@ -114,7 +120,7 @@ tf_cc_test( "xla", ], deps = [ - ":xla_compilation_cache_test_helper", + ":device_compiler_test_helper", "//tensorflow/compiler/jit:compilation_passes", "//tensorflow/compiler/jit:flags", "//tensorflow/core:test", diff --git a/tensorflow/compiler/jit/tests/xla_compilation_cache_serialize_options_test.cc b/tensorflow/compiler/jit/tests/device_compiler_serialize_options_test.cc similarity index 87% rename from tensorflow/compiler/jit/tests/xla_compilation_cache_serialize_options_test.cc rename to tensorflow/compiler/jit/tests/device_compiler_serialize_options_test.cc index c2b10cdc133..3eaa8202261 100644 --- a/tensorflow/compiler/jit/tests/xla_compilation_cache_serialize_options_test.cc +++ b/tensorflow/compiler/jit/tests/device_compiler_serialize_options_test.cc @@ -15,13 +15,13 @@ limitations under the License. #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" -#include "tensorflow/compiler/jit/tests/xla_compilation_cache_test_helper.h" +#include "tensorflow/compiler/jit/tests/device_compiler_test_helper.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace tensorflow { namespace { -TEST_F(XlaCompilationCacheSerializeTest, PersistentCacheOptionsTest) { +TEST_F(DeviceCompilerSerializeTest, PersistentCacheOptionsTest) { GraphDef graph = GetTestGraph({-1, 4}); // Warmup the persistent cache(s) with multiple runs. 4 is a magic number to @@ -30,8 +30,8 @@ TEST_F(XlaCompilationCacheSerializeTest, PersistentCacheOptionsTest) { for (int b = 1; b < 4; ++b) { TF_ASSERT_OK(ExecuteWithBatch(graph, b)); } - TF_ASSERT_OK( - listener()->VerifyListenerHistory(/*expect_persistent_cache_use=*/false)); + TF_ASSERT_OK(listener()->VerifyPersistentCacheUseListenerHistory( + /*expect_persistent_cache_use=*/false)); // Reset the cluster numbering between sessions so we can get the same // cluster numbering. @@ -54,8 +54,8 @@ TEST_F(XlaCompilationCacheSerializeTest, PersistentCacheOptionsTest) { for (int b = 1; b < 4; ++b) { TF_ASSERT_OK(ExecuteWithBatch(graph, b)); } - TF_ASSERT_OK( - listener()->VerifyListenerHistory(/*expect_persistent_cache_use=*/true)); + TF_ASSERT_OK(listener()->VerifyPersistentCacheUseListenerHistory( + /*expect_persistent_cache_use=*/true)); } } // namespace diff --git a/tensorflow/compiler/jit/tests/xla_compilation_cache_serialize_test.cc b/tensorflow/compiler/jit/tests/device_compiler_serialize_test.cc similarity index 87% rename from tensorflow/compiler/jit/tests/xla_compilation_cache_serialize_test.cc rename to tensorflow/compiler/jit/tests/device_compiler_serialize_test.cc index 8828cd4d4bb..984b9852535 100644 --- a/tensorflow/compiler/jit/tests/xla_compilation_cache_serialize_test.cc +++ b/tensorflow/compiler/jit/tests/device_compiler_serialize_test.cc @@ -15,13 +15,13 @@ limitations under the License. #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" -#include "tensorflow/compiler/jit/tests/xla_compilation_cache_test_helper.h" +#include "tensorflow/compiler/jit/tests/device_compiler_test_helper.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace tensorflow { namespace { -TEST_F(XlaCompilationCacheSerializeTest, PersistentCacheTest) { +TEST_F(DeviceCompilerSerializeTest, PersistentCacheTest) { GraphDef graph = GetTestGraph({-1, 4}); // Warmup the persistent cache(s) with multiple runs. 
4 is a magic number to @@ -30,8 +30,8 @@ TEST_F(XlaCompilationCacheSerializeTest, PersistentCacheTest) { for (int b = 1; b < 4; ++b) { TF_ASSERT_OK(ExecuteWithBatch(graph, b)); } - TF_ASSERT_OK( - listener()->VerifyListenerHistory(/*expect_persistent_cache_use=*/false)); + TF_ASSERT_OK(listener()->VerifyPersistentCacheUseListenerHistory( + /*expect_persistent_cache_use=*/false)); // Reset the cluster numbering between sessions so we can get the same // cluster numbering. @@ -42,8 +42,8 @@ TEST_F(XlaCompilationCacheSerializeTest, PersistentCacheTest) { for (int b = 1; b < 4; ++b) { TF_ASSERT_OK(ExecuteWithBatch(graph, b)); } - TF_ASSERT_OK( - listener()->VerifyListenerHistory(/*expect_persistent_cache_use=*/true)); + TF_ASSERT_OK(listener()->VerifyPersistentCacheUseListenerHistory( + /*expect_persistent_cache_use=*/true)); // Reset the cluster numbering between sessions so we can get the same // cluster numbering. diff --git a/tensorflow/compiler/jit/tests/xla_compilation_cache_test_helper.cc b/tensorflow/compiler/jit/tests/device_compiler_test_helper.cc similarity index 94% rename from tensorflow/compiler/jit/tests/xla_compilation_cache_test_helper.cc rename to tensorflow/compiler/jit/tests/device_compiler_test_helper.cc index cfa6efd4d71..95d75b67cac 100644 --- a/tensorflow/compiler/jit/tests/xla_compilation_cache_test_helper.cc +++ b/tensorflow/compiler/jit/tests/device_compiler_test_helper.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/jit/tests/xla_compilation_cache_test_helper.h" +#include "tensorflow/compiler/jit/tests/device_compiler_test_helper.h" #include @@ -52,7 +52,7 @@ NodeDef MakeNode( } // namespace -GraphDef XlaCompilationCacheSerializeTest::GetTestGraph( +GraphDef DeviceCompilerSerializeTest::GetTestGraph( const PartialTensorShape& input_shape) { FunctionDef make_test_fn = FunctionDefHelper::Define( "TestFn", {"a:float", "b:float", "c:float"}, {"m:float"}, {}, @@ -80,8 +80,8 @@ GraphDef XlaCompilationCacheSerializeTest::GetTestGraph( return graph; } -Status XlaCompilationCacheSerializeTest::ExecuteWithBatch(const GraphDef& graph, - int batch) { +Status DeviceCompilerSerializeTest::ExecuteWithBatch(const GraphDef& graph, + int batch) { const TensorShape shape({batch, 4}); // Compute the golden output tensor @@ -134,8 +134,7 @@ Status XlaCompilationCacheSerializeTest::ExecuteWithBatch(const GraphDef& graph, return OkStatus(); } -Status -XlaCompilationCacheSerializeTest::AlterPersistentCacheEntryHloModuleNames( +Status DeviceCompilerSerializeTest::AlterPersistentCacheEntryHloModuleNames( absl::string_view persistent_cache_dir_path, absl::string_view file_prefix) { Env* env = Env::Default(); diff --git a/tensorflow/compiler/jit/tests/xla_compilation_cache_test_helper.h b/tensorflow/compiler/jit/tests/device_compiler_test_helper.h similarity index 74% rename from tensorflow/compiler/jit/tests/xla_compilation_cache_test_helper.h rename to tensorflow/compiler/jit/tests/device_compiler_test_helper.h index df68e0da82a..edb6be6a0ff 100644 --- a/tensorflow/compiler/jit/tests/xla_compilation_cache_test_helper.h +++ b/tensorflow/compiler/jit/tests/device_compiler_test_helper.h @@ -12,11 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_JIT_TESTS_XLA_COMPILATION_CACHE_TEST_HELPER_H_ -#define TENSORFLOW_COMPILER_JIT_TESTS_XLA_COMPILATION_CACHE_TEST_HELPER_H_ +#ifndef TENSORFLOW_COMPILER_JIT_TESTS_DEVICE_COMPILER_TEST_HELPER_H_ +#define TENSORFLOW_COMPILER_JIT_TESTS_DEVICE_COMPILER_TEST_HELPER_H_ #include #include +#include #include "absl/strings/string_view.h" #include "tensorflow/compiler/jit/xla_activity_listener.h" @@ -36,8 +37,7 @@ class JitCompilationListener : public XlaActivityListener { Status Listen( const XlaJitCompilationActivity& jit_compilation_activity) override { - used_persistent_cache_.push_back( - jit_compilation_activity.used_persistent_cache()); + activity_history_.push_back(jit_compilation_activity); return OkStatus(); } @@ -45,28 +45,33 @@ class JitCompilationListener : public XlaActivityListener { return OkStatus(); } - ~JitCompilationListener() override {} + ~JitCompilationListener() override = default; - Status VerifyListenerHistory(bool expect_persistent_cache_use) { - for (bool used_persistent_cache : used_persistent_cache_) { - if (used_persistent_cache != expect_persistent_cache_use) { + Status VerifyPersistentCacheUseListenerHistory( + bool expect_persistent_cache_use) { + for (const auto& activity : activity_history_) { + if (activity.used_persistent_cache() != expect_persistent_cache_use) { return errors::FailedPrecondition("Unexpected listener history."); } } return OkStatus(); } - void ClearListenerHistory() { used_persistent_cache_.clear(); } + std::vector<XlaJitCompilationActivity> GetListenerHistory() { + return activity_history_; + } + + void ClearListenerHistory() { activity_history_.clear(); } private: - std::vector<bool> used_persistent_cache_; + std::vector<XlaJitCompilationActivity> activity_history_; }; // Fixture for testing XLA compilation cache serialization.
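An illustrative sketch (not drawn from the patch itself) of how a test built on the fixture below might exercise the renamed listener API; the test name and batch size are hypothetical:

TEST_F(DeviceCompilerSerializeTest, ListenerHistorySketch) {
  GraphDef graph = GetTestGraph({-1, 4});
  // A fresh run compiles from scratch, so no persistent-cache hits are
  // expected in the recorded activity.
  TF_ASSERT_OK(ExecuteWithBatch(graph, /*batch=*/2));
  TF_ASSERT_OK(listener()->VerifyPersistentCacheUseListenerHistory(
      /*expect_persistent_cache_use=*/false));
  // The full XlaJitCompilationActivity protos are now retained, allowing
  // finer-grained checks than the old bool-only history.
  for (const auto& activity : listener()->GetListenerHistory()) {
    VLOG(1) << activity.cluster_name() << " compiled "
            << activity.compile_count() << " time(s)";
  }
  listener()->ClearListenerHistory();
}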
-class XlaCompilationCacheSerializeTest : public ::testing::Test { +class DeviceCompilerSerializeTest : public ::testing::Test { protected: - XlaCompilationCacheSerializeTest() { - auto listener = absl::make_unique(); + DeviceCompilerSerializeTest() { + auto listener = std::make_unique(); listener_ = listener.get(); RegisterXlaActivityListener(std::move(listener)); } @@ -94,4 +99,4 @@ class XlaCompilationCacheSerializeTest : public ::testing::Test { } // namespace tensorflow -#endif // TENSORFLOW_COMPILER_JIT_TESTS_XLA_COMPILATION_CACHE_TEST_HELPER_H_ +#endif // TENSORFLOW_COMPILER_JIT_TESTS_DEVICE_COMPILER_TEST_HELPER_H_ diff --git a/tensorflow/compiler/jit/tests/keras_imagenet_main.golden_summary b/tensorflow/compiler/jit/tests/keras_imagenet_main.golden_summary index c264867cd6a..630836b86c2 100644 --- a/tensorflow/compiler/jit/tests/keras_imagenet_main.golden_summary +++ b/tensorflow/compiler/jit/tests/keras_imagenet_main.golden_summary @@ -1,4 +1,4 @@ -Clustered nodes: 2301 +Clustered nodes: 2725 Unclustered nodes: 606 Number of clusters: 2 @@ -13,7 +13,7 @@ unclustered size 606 Switch 1 _Arg 435 _Retval 2 -cluster 0 size 1645 +cluster 0 size 1910 Add 16 AddN 71 ArgMax 1 @@ -21,7 +21,7 @@ cluster 0 size 1645 BiasAdd 1 BiasAddGrad 1 Cast 115 - Const 142 + Const 407 Conv2D 53 Conv2DBackpropFilter 53 Conv2DBackpropInput 52 @@ -46,11 +46,11 @@ cluster 0 size 1645 Sum 1 Tile 1 Transpose 1 -cluster 1 size 656 +cluster 1 size 815 AddN 1 AssignAddVariableOp 1 AssignSubVariableOp 106 - Const 61 + Const 220 DivNoNan 1 Identity 1 Mul 161 diff --git a/tensorflow/compiler/jit/tests/keras_imagenet_main_graph_mode.golden_summary b/tensorflow/compiler/jit/tests/keras_imagenet_main_graph_mode.golden_summary index 9683f9c03ee..a1692ccef28 100644 --- a/tensorflow/compiler/jit/tests/keras_imagenet_main_graph_mode.golden_summary +++ b/tensorflow/compiler/jit/tests/keras_imagenet_main_graph_mode.golden_summary @@ -1,17 +1,17 @@ -Clustered nodes: 1968 -Unclustered nodes: 445 +Clustered nodes: 2178 +Unclustered nodes: 446 Number of clusters: 1 -unclustered size 445 +unclustered size 446 AssignAddVariableOp 2 - Const 1 + Const 2 DivNoNan 1 Identity 1 NoOp 1 ReadVariableOp 2 VarHandleOp 435 _Retval 2 -cluster 0 size 1968 +cluster 0 size 2178 Add 17 AddN 72 ArgMax 1 @@ -20,7 +20,7 @@ cluster 0 size 1968 BiasAdd 1 BiasAddGrad 1 Cast 3 - Const 147 + Const 357 Conv2D 53 Conv2DBackpropFilter 53 Conv2DBackpropInput 52 diff --git a/tensorflow/compiler/jit/tests/opens2s_gnmt_mixed_precision.golden_summary b/tensorflow/compiler/jit/tests/opens2s_gnmt_mixed_precision.golden_summary index a3ffbf7dbbc..2ad145e0147 100644 --- a/tensorflow/compiler/jit/tests/opens2s_gnmt_mixed_precision.golden_summary +++ b/tensorflow/compiler/jit/tests/opens2s_gnmt_mixed_precision.golden_summary @@ -1,8 +1,8 @@ -Clustered nodes: 2452 -Unclustered nodes: 3894 -Number of clusters: 31 +Clustered nodes: 2385 +Unclustered nodes: 4221 +Number of clusters: 30 -unclustered size 3894 +unclustered size 4221 Add 17 AddN 1 All 1 @@ -12,8 +12,9 @@ unclustered size 3894 AssignAdd 2 AssignSub 2 BroadcastGradientArgs 44 + Cast 38 ConcatV2 3 - Const 662 + Const 875 ControlTrigger 5 Enter 874 Equal 4 @@ -30,9 +31,9 @@ unclustered size 3894 LogicalAnd 3 LoopCond 8 Max 4 - Maximum 6 + Maximum 44 Merge 145 - Minimum 5 + Minimum 43 Mul 8 NextIteration 136 RandomUniform 14 @@ -138,10 +139,28 @@ cluster 4 size 11 ReverseSequence 1 Slice 2 Transpose 3 -cluster 5 size 32 +cluster 5 size 21 + All 1 + ConcatV2 1 + Const 11 + Equal 1 + ExpandDims 1 + ReverseSequence 1 
+ Shape 1 + StridedSlice 1 + Transpose 3 +cluster 6 size 11 + Cast 1 + Const 5 + GatherV2 1 + Shape 1 + StridedSlice 1 + Transpose 1 + ZerosLike 1 +cluster 7 size 33 All 2 Cast 1 - Const 16 + Const 17 Equal 2 ExpandDims 2 GatherV2 1 @@ -150,15 +169,7 @@ cluster 5 size 32 Shape 2 StridedSlice 2 Transpose 2 -cluster 6 size 11 - Cast 1 - Const 5 - GatherV2 1 - Shape 1 - StridedSlice 1 - Transpose 1 - ZerosLike 1 -cluster 7 size 11 +cluster 8 size 11 Cast 1 Const 4 Floor 1 @@ -166,26 +177,26 @@ cluster 7 size 11 Mul 2 Pow 1 Sub 1 -cluster 8 size 5 +cluster 9 size 5 All 1 Const 1 Less 1 LogicalAnd 1 LogicalNot 1 -cluster 9 size 9 +cluster 10 size 9 All 1 Const 4 Equal 1 LessEqual 1 LogicalOr 1 Max 1 -cluster 10 size 272 +cluster 11 size 302 Add 24 BatchMatMulV2 1 BiasAdd 8 Cast 8 ConcatV2 16 - Const 51 + Const 81 ExpandDims 3 Fill 1 GreaterEqual 8 @@ -207,13 +218,13 @@ cluster 10 size 272 StridedSlice 1 Sum 2 Tanh 17 -cluster 11 size 6 +cluster 12 size 6 Add 1 All 1 Const 2 GreaterEqual 1 LogicalOr 1 -cluster 14 size 614 +cluster 15 size 614 Add 22 AddN 41 BatchMatMulV2 2 @@ -239,7 +250,7 @@ cluster 14 size 614 TanhGrad 17 Tile 2 ZerosLike 1 -cluster 15 size 22 +cluster 16 size 22 Add 2 BiasAdd 1 ConcatV2 1 @@ -252,7 +263,7 @@ cluster 15 size 22 Sigmoid 3 Split 1 Tanh 2 -cluster 16 size 60 +cluster 17 size 60 Add 2 AddN 4 BiasAddGrad 1 @@ -270,16 +281,6 @@ cluster 16 size 60 Slice 2 Sum 9 TanhGrad 2 -cluster 17 size 20 - All 1 - ConcatV2 1 - Const 10 - Equal 1 - ExpandDims 1 - ReverseSequence 1 - Shape 1 - StridedSlice 1 - Transpose 3 cluster 18 size 22 Add 2 BiasAdd 1 @@ -311,12 +312,12 @@ cluster 19 size 60 Slice 2 Sum 9 TanhGrad 2 -cluster 21 size 27 +cluster 21 size 29 Add 2 BiasAdd 1 Cast 1 ConcatV2 1 - Const 5 + Const 7 GreaterEqual 2 MatMul 1 Mul 5 @@ -325,12 +326,12 @@ cluster 21 size 27 Snapshot 1 Split 1 Tanh 2 -cluster 22 size 25 +cluster 22 size 28 Add 3 BiasAdd 1 Cast 1 ConcatV2 1 - Const 2 + Const 5 GreaterEqual 1 MatMul 1 Mul 5 @@ -361,12 +362,12 @@ cluster 24 size 4 Const 1 Shape 2 Transpose 1 -cluster 25 size 24 +cluster 25 size 27 Add 3 BiasAdd 1 Cast 1 ConcatV2 1 - Const 2 + Const 5 GreaterEqual 1 MatMul 1 Mul 5 @@ -375,12 +376,12 @@ cluster 25 size 24 Snapshot 1 Split 1 Tanh 2 -cluster 26 size 24 +cluster 26 size 27 Add 3 BiasAdd 1 Cast 1 ConcatV2 1 - Const 2 + Const 5 GreaterEqual 1 MatMul 1 Mul 5 @@ -389,12 +390,12 @@ cluster 26 size 24 Snapshot 1 Split 1 Tanh 2 -cluster 27 size 24 +cluster 27 size 27 Add 3 BiasAdd 1 Cast 1 ConcatV2 1 - Const 2 + Const 5 GreaterEqual 1 MatMul 1 Mul 5 @@ -403,12 +404,12 @@ cluster 27 size 24 Snapshot 1 Split 1 Tanh 2 -cluster 28 size 24 +cluster 28 size 27 Add 3 BiasAdd 1 Cast 1 ConcatV2 1 - Const 2 + Const 5 GreaterEqual 1 MatMul 1 Mul 5 @@ -435,8 +436,3 @@ cluster 31 size 4 cluster 32 size 4 Mul 3 UnsortedSegmentSum 1 -cluster 33 size 116 - Cast 38 - Const 2 - Maximum 38 - Minimum 38 diff --git a/tensorflow/compiler/jit/tf_graph_to_hlo_compiler.cc b/tensorflow/compiler/jit/tf_graph_to_hlo_compiler.cc new file mode 100644 index 00000000000..3f96a0f2aa9 --- /dev/null +++ b/tensorflow/compiler/jit/tf_graph_to_hlo_compiler.cc @@ -0,0 +1,36 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/tf_graph_to_hlo_compiler.h" + +#include + +namespace tensorflow { + +Status TfGraphToHloCompiler::Compile(const XlaCompiler::CompileOptions& options, + const NameAttrList& function, + absl::Span args, + XlaCompilationResult* result) { + return xla_compiler_.CompileFunction(options, function, args, result); +} + +Status TfGraphToHloCompiler::CompileSingleOp( + const XlaCompiler::CompileOptions& options, const OpKernelContext* ctx, + absl::Span args, XlaCompilationResult* result) { + return xla_compiler_.CompileSingleOp( + options, XlaCompiler::SingleOpCompileArgument(*ctx), args, result); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/tf_graph_to_hlo_compiler.h b/tensorflow/compiler/jit/tf_graph_to_hlo_compiler.h new file mode 100644 index 00000000000..c927a90486b --- /dev/null +++ b/tensorflow/compiler/jit/tf_graph_to_hlo_compiler.h @@ -0,0 +1,59 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_TF_GRAPH_TO_HLO_COMPILER_H_ +#define TENSORFLOW_COMPILER_JIT_TF_GRAPH_TO_HLO_COMPILER_H_ + +#include +#include + +#include "tensorflow/compiler/jit/tf_to_hlo_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" + +namespace tensorflow { + +class TfGraphToHloCompiler : public TfToHloCompiler { + public: + TfGraphToHloCompiler() = delete; + + explicit TfGraphToHloCompiler(const XlaCompiler::Options& options) + : xla_compiler_(options) {} + + // Compiles a Tensorflow `function` into an HloModuleProto stored in the + // XlaCompilationResult pointed to by `result` by calling + // XlaCompiler::CompileFunction. + Status Compile(const XlaCompiler::CompileOptions& options, + const NameAttrList& function, + absl::Span args, + XlaCompilationResult* result) override; + + // Compiles a Tensorflow single op into an HloModuleProto stored in the + // XlaCompilationResult pointed to by `result` by calling + // XlaCompiler::CompileSingleOp. 
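  // Illustrative usage sketch for this class (hypothetical, not part of the
  // header): given an already-populated XlaCompiler::Options `opts` and the
  // usual compile inputs, a caller might write
  //
  //   TfGraphToHloCompiler compiler(opts);
  //   XlaCompilationResult result;
  //   Status s = compiler.Compile(compile_options, function, args, &result);
  //
  // and, on success, read the generated HloModuleProto out of `result`.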
+ Status CompileSingleOp(const XlaCompiler::CompileOptions& options, + const OpKernelContext* ctx, + absl::Span args, + XlaCompilationResult* result) override; + + private: + XlaCompiler xla_compiler_; + + TF_DISALLOW_COPY_AND_ASSIGN(TfGraphToHloCompiler); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_TF_GRAPH_TO_HLO_COMPILER_H_ diff --git a/tensorflow/compiler/jit/tf_to_hlo_compiler.h b/tensorflow/compiler/jit/tf_to_hlo_compiler.h new file mode 100644 index 00000000000..cf6245639f2 --- /dev/null +++ b/tensorflow/compiler/jit/tf_to_hlo_compiler.h @@ -0,0 +1,52 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_TF_TO_HLO_COMPILER_H_ +#define TENSORFLOW_COMPILER_JIT_TF_TO_HLO_COMPILER_H_ + +#include +#include + +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +class TfToHloCompiler { + public: + TfToHloCompiler() = default; + virtual ~TfToHloCompiler() = default; + + // Compiles a Tensorflow `function` to an HloModuleProto stored in the + // XlaCompilationResult pointed to by `result`. + virtual Status Compile(const XlaCompiler::CompileOptions& options, + const NameAttrList& function, + absl::Span args, + XlaCompilationResult* result) = 0; + + // Compiles a Tensorflow single op to an HloModuleProto stored in the + // XlaCompilationResult pointed to by `result`. + virtual Status CompileSingleOp(const XlaCompiler::CompileOptions& options, + const OpKernelContext* ctx, + absl::Span args, + XlaCompilationResult* result) = 0; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(TfToHloCompiler); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_TF_TO_HLO_COMPILER_H_ diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h index 3c4450a8fe8..e65059bb27b 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.h +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -26,7 +26,6 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_activity.pb.h" #include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/graph/algorithm.h" diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc deleted file mode 100644 index f24b146d0dc..00000000000 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ /dev/null @@ -1,951 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/jit/xla_compilation_cache.h" - -#include -#include -#include -#include -#include - -#include "tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h" -#include "absl/base/call_once.h" -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "absl/types/variant.h" -#include "tensorflow/compiler/jit/flags.h" -#include "tensorflow/compiler/jit/xla_activity.pb.h" -#include "tensorflow/compiler/jit/xla_activity_listener.h" -#include "tensorflow/compiler/jit/xla_cluster_util.h" -#include "tensorflow/compiler/jit/xla_compilation_cache.pb.h" -#include "tensorflow/compiler/jit/xla_compile_util.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" -#include "tensorflow/compiler/mlir/utils/array_container_utils.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/tf2xla/xla_context.h" -#include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/protobuf_util.h" -#include "tensorflow/compiler/xla/service/compiler.h" -#include "tensorflow/compiler/xla/service/hlo.pb.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/common_runtime/device.h" -#include "tensorflow/core/common_runtime/function.h" -#include "tensorflow/core/common_runtime/graph_constructor.h" -#include "tensorflow/core/common_runtime/graph_optimizer.h" -#include "tensorflow/core/framework/attr_value_util.h" -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/metrics.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/graph/algorithm.h" -#include "tensorflow/core/graph/node_builder.h" -#include "tensorflow/core/lib/hash/hash.h" -#include "tensorflow/core/lib/strings/proto_serialization.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/fingerprint.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/path.h" -#include "tensorflow/core/platform/status.h" -#include "tensorflow/core/platform/statusor.h" -#include "tensorflow/core/protobuf/debug_event.pb.h" -#include "tensorflow/core/protobuf/error_codes.pb.h" -#include "tensorflow/core/protobuf/graph_debug_info.pb.h" -#include "tensorflow/core/public/version.h" -#include "tensorflow/core/tpu/tpu_defs.h" -#include "tensorflow/core/util/determinism.h" -#include "tensorflow/core/util/dump_graph.h" - -namespace tensorflow { -namespace { - -using TensorTypeAndShape = XlaCompilationCache::Signature::TensorTypeAndShape; - -constexpr char kXlaSerializedCacheKeySeparator[] = "__"; - -// Functor that converts a Signature's arg to a human readable string. 
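For reference, an illustrative reconstruction (not text from the original source): for a signature named cluster_0 with a single DT_FLOAT argument of shape [2,3], the appender below would produce a human-readable string along the lines of

  cluster_0,float [2,3]

with any compile-time constant Tensor arguments appended as "; " followed by their DebugString().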
-struct SignatureHumanStringAppender { - explicit SignatureHumanStringAppender(string* dest) : dest(dest) {} - string* dest; - void operator()(const Tensor& arg) { - absl::StrAppend(dest, "; ", arg.DebugString()); - } - void operator()(const TensorTypeAndShape& arg) { - absl::StrAppend(dest, ",", DataTypeString(arg.first)); - absl::StrAppend(dest, " [", absl::StrJoin(arg.second, ","), "]"); - } -}; - -// Functor that compares the arg values of two different signatures. Returns -// true when the args are not equal. -struct SignatureNotEqual { - bool operator()(const Tensor& arg, const Tensor& other) { - return arg.dtype() != other.dtype() || arg.shape() != other.shape() || - arg.tensor_data() != other.tensor_data(); - } - bool operator()(const TensorTypeAndShape& arg, - const TensorTypeAndShape& other) { - return arg.first != other.first || arg.second != other.second; - } - bool operator()(const Tensor& arg, const TensorTypeAndShape& other) { - return true; - } - bool operator()(const TensorTypeAndShape& arg, const Tensor& other) { - return true; - } -}; - -// Functor that incrementally computes a Signature's hash given its current hash -// and one of its args. -struct SignatureHashCombiner { - explicit SignatureHashCombiner(const uint64 h) : h(h) {} - uint64 h; - uint64 operator()(const Tensor& arg) { - h = Hash64Combine(h, std::hash()(static_cast(arg.dtype()))); - h = Hash64Combine( - h, Hash64(arg.tensor_data().data(), arg.tensor_data().size())); - for (int dim = 0; dim < arg.dims(); ++dim) { - h = Hash64Combine(h, std::hash()(arg.dim_size(dim))); - } - return h; - } - uint64 operator()(const TensorTypeAndShape& arg) { - h = Hash64Combine(h, std::hash()(static_cast(arg.first))); - h = Hash64Combine(h, std::hash()(arg.second.size())); - for (int dim : arg.second) { - h = Hash64Combine(h, std::hash()(dim)); - } - return h; - } -}; - -std::string XlaSerializedCacheKeyToString(const XlaSerializedCacheKey& key) { - return absl::StrCat( - key.prefix(), key.prefix().empty() ? "" : kXlaSerializedCacheKeySeparator, - key.signature_fingerprint(), kXlaSerializedCacheKeySeparator, - key.cluster_fingerprint(), kXlaSerializedCacheKeySeparator, - key.device_type()); -} - -} // namespace - -constexpr int64_t XlaCompilationCache::kDefaultCompilationThreshold; -constexpr int64_t - XlaCompilationCache::AsyncCompilationState::kNumCompilerThreads; -constexpr int64_t - XlaCompilationCache::AsyncCompilationState::kMaxNumOngoingCompilations; - -XlaCompilationCache::XlaCompilationCache(Config config, - xla::LocalClient* client, - DeviceType device_type) - : client_(client), - device_type_(std::move(device_type)), - disable_strict_signature_checks_(config.disable_strict_signature_checks), - persistance_prefix_(config.persistance_prefix), - persistent_cache_directory_(config.persistent_cache_directory) {} - -XlaCompilationCache::~XlaCompilationCache() { - // Ensure any use of our programs have completed by waiting for all stream - // executors to complete. - for (auto* executor : client_->backend().stream_executors()) { - bool ok = executor->SynchronizeAllActivity(); - if (!ok) { - LOG(ERROR) << "Error synchronizing activity while waiting for all " - "programs to complete"; - } - } - // Wait for all outstanding compilations to finish. - // Resetting the pointer explicitly in the top level destructor. 
- // Without this, the pointer would be reset when the AsyncCompilationState - // is destructed, which is dependent on the order of the members in the - // XlaCompilationCache class, which is error prone if the order changes. - async_compilation_state_.compiler_threads.reset(); - // TODO(b/110813685): Think about the program ownership model. Programs are - // currently owned by the compilation cache which means we must wait for - // program completion in the destructor. There are multiple compilation caches - // around, which complicates things a little. Perhaps having programs be - // shared_ptrs (an invasive change) would make the model easier to reason - // about? -} - -string XlaCompilationCache::DebugString() const { - return "XLA JIT compilation cache"; -} - -// Compute a string signature which encodes the shapes of the -// arguments in the supplied list. -string XlaCompilationCache::Signature::HumanString() const { - string result = name; - for (const auto& a : args) { - absl::visit(SignatureHumanStringAppender(&result), a); - } - return result; -} - -bool XlaCompilationCache::Signature::operator==(const Signature& other) const { - if (name != other.name) return false; - if (args.size() != other.args.size()) return false; - for (int i = 0, end = args.size(); i < end; ++i) { - if (absl::visit(SignatureNotEqual(), args[i], other.args[i])) { - return false; - } - } - return true; -} - -uint64 XlaCompilationCache::Signature::Hash::operator()( - const XlaCompilationCache::Signature& signature) const { - uint64 h = std::hash()(signature.name); - for (const auto& arg : signature.args) { - h = absl::visit(SignatureHashCombiner(h), arg); - } - return h; -} - -StatusOr XlaCompilationCache::BuildSignature( - const NameAttrList& function, - absl::Span args) { - Signature signature; - signature.name = Canonicalize(function.name(), AttrSlice(&function.attr())); - - for (const XlaCompiler::Argument& arg : args) { - switch (arg.kind) { - case XlaCompiler::Argument::kConstant: - case XlaCompiler::Argument::kConstantResource: - signature.args.push_back(arg.constant_value); - break; - case XlaCompiler::Argument::kParameter: - case XlaCompiler::Argument::kResource: - signature.args.push_back( - TensorTypeAndShape(arg.type, arg.DimensionSizesAsInlinedVector())); - break; - default: - return errors::InvalidArgument( - "Unhandled argument kind in XlaCompilationCache: ", - arg.HumanString()); - } - } - return std::move(signature); -} - -static std::vector GetShapePointers( - absl::Span shapes) { - std::vector shape_ptrs; - shape_ptrs.reserve(shapes.size()); - for (const auto& shape : shapes) { - shape_ptrs.push_back(&shape); - } - return shape_ptrs; -} - -Status XlaCompilationCache::BuildExecutable( - const XlaCompiler::Options& options, - const XlaCompiler::CompilationResult& result, - std::unique_ptr* executable) { - VLOG(2) << "Compiling to local executable"; - - std::vector argument_layouts = - GetShapePointers(result.xla_input_shapes); - xla::ExecutableBuildOptions build_options = GetExecutableBuildOptions( - options, result, client_->default_device_ordinal()); - TF_ASSIGN_OR_RETURN( - auto executables, - client_->Compile(*result.computation, argument_layouts, build_options)); - TF_RET_CHECK(executables.size() == 1); - *executable = std::move(executables[0]); - return OkStatus(); -} - -StatusOr> -XlaCompilationCache::BuildSerializedExecutable( - const XlaCompiler::Options& options, - const XlaCompiler::CompilationResult& result) { - VLOG(2) << "Compiling to local executable"; - - std::vector 
argument_layouts = - GetShapePointers(result.xla_input_shapes); - xla::ExecutableBuildOptions build_options = GetExecutableBuildOptions( - options, result, client_->default_device_ordinal()); - TF_ASSIGN_OR_RETURN( - std::vector> aot_results, - client_->CompileAheadOfTime(*result.computation, argument_layouts, - build_options)); - TF_RET_CHECK(aot_results.size() == 1); - return std::move(aot_results[0]); -} - -StatusOr> -XlaCompilationCache::LoadExecutable( - const XlaCompiler::Options& options, - const XlaCompiler::CompilationResult& result, - const std::string& serialized_aot_result) { - VLOG(2) << "Loading local executable using BEF."; - - xla::ExecutableBuildOptions build_options = GetExecutableBuildOptions( - options, result, client_->default_device_ordinal()); - return client_->Load(serialized_aot_result, build_options); -} - -Status XlaCompilationCache::Compile( - const XlaCompiler::Options& options, const NameAttrList& function, - const std::vector& args, - const XlaCompiler::CompileOptions& compile_options, - CompileMode compile_mode, - const XlaCompiler::CompilationResult** out_compilation_result, - xla::LocalExecutable** out_executable) { - return CompileImpl(compile_options, options, function, args, - /*ctx=*/nullptr, CompileScope::kFunction, compile_mode, - out_compilation_result, out_executable); -} - -static bool ShouldBeMegamorphic(int64_t compile_count, - int64_t execution_count) { - const int64_t kCompileThreshold = 10; - const int64_t kMinExecutionsPerCompile = 50; - - // This heuristic is trying to capture the following property: have we sunk a - // certain minimum amount of compile time into the cluster that didn't quite - // "pay off"? - return compile_count > kCompileThreshold && - execution_count < kMinExecutionsPerCompile * compile_count; -} - -StatusOr> CreateGraph( - const NodeDef& node_def, absl::Span args, - absl::Span result_types) { - // TODO(b/74182462): We implement this by creating a new dummy Graph including - // _Arg nodes, and let CompileGraph walk it. This could be optimized. - std::unique_ptr graph(new Graph(OpRegistry::Global())); - - // First create the actual node we care about computing. - TF_ASSIGN_OR_RETURN(Node * main_node, graph->AddNode(node_def)); - - // Create dummy _Arg nodes. Link these to `node` and also via a control - // dependency edge to the _SOURCE node. - for (int64_t i = 0, end = args.size(); i < end; ++i) { - Node* node; - string arg_name = absl::StrCat("_arg", i); - Status status = - NodeBuilder(arg_name, FunctionLibraryDefinition::kArgOp) - .ControlInput(graph->source_node()) - .Attr("T", args[i].kind == XlaCompiler::Argument::kResource - ? DT_RESOURCE - : args[i].type) - .Attr("index", i) - .Finalize(graph.get(), &node); - TF_RETURN_IF_ERROR(status); - graph->AddEdge(node, 0, main_node, i); - } - - // Similarly with return values, create dummy _Retval nodes fed by `node`. 
- for (int64_t i = 0, end = result_types.size(); i < end; ++i) { - Node* node; - string retval_name = absl::StrCat("_retval", i); - Status status = NodeBuilder(retval_name, FunctionLibraryDefinition::kRetOp) - .Input(main_node, i) - .Attr("T", result_types[i]) - .Attr("index", i) - .Finalize(graph.get(), &node); - TF_RETURN_IF_ERROR(status); - } - FixupSourceAndSinkEdges(graph.get()); - return graph; -} - -Status XlaSingleOpToHlo( - XlaCompiler* compiler, const XlaCompiler::Options& options, - const std::vector& args, - const XlaCompiler::SingleOpCompileArgument& single_op_compile_argument, - const XlaCompiler::CompileOptions& compile_options, - XlaCompiler::CompilationResult* compilation_result) { - const std::vector& result_dtypes = - single_op_compile_argument.output_dtypes; - const NodeDef& node_def = single_op_compile_argument.node_def; - TF_ASSIGN_OR_RETURN( - auto graph, - CreateGraph(node_def, args, single_op_compile_argument.output_dtypes)); - - auto compile_with_old_bridge = [&]() { - *compilation_result = {}; - return compiler->CompileGraph(compile_options, node_def.name(), - std::move(graph), args, compilation_result); - }; - - const ConfigProto* config = &(single_op_compile_argument.config_proto); - auto bridge_rollout = GetMlirBridgeRolloutState( - config ? std::optional(*config) : std::nullopt); - if (bridge_rollout == - ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED || - node_def.op() == "VarIsInitializedOp" || - (bridge_rollout != - ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED && - options.device_type.type_string() != DEVICE_TPU_XLA_JIT)) { - return compile_with_old_bridge(); - } - - GraphDebugInfo debug_info; - std::vector control_rets; - if (result_dtypes.empty()) { - control_rets.push_back(node_def.name()); - } - - bool mlir_enabled = (bridge_rollout == - ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED); - VLOG(1) << "Attempting MLIR bridge." - << (mlir_enabled ? " MLIR is explicitly enabled." : ""); - auto mlir_result = CompileGraphToXlaHlo( - *graph, mlir::SpanToArrayRef(args), control_rets, - options.device_type.type_string(), compile_options.use_tuple_arg, - /*analyse_graph=*/!mlir_enabled, *options.flib_def, debug_info, - options.shape_determination_fns, compilation_result); - - if (mlir_result.ok() || mlir_enabled) { - return mlir_result; - } - - VLOG(2) << "Failed second phase of the MLIR bridge. Will " - "retry with the old bridge. MLIR bridge compilation status: " - << mlir_result; - return compile_with_old_bridge(); -} - -Status XlaCompilationCache::CompileSingleOp( - const XlaCompiler::Options& options, - const std::vector& args, OpKernelContext* ctx, - const XlaCompiler::CompileOptions& compile_options, - const XlaCompiler::CompilationResult** out_compilation_result, - xla::LocalExecutable** out_executable) { - const NodeDef& def = ctx->op_kernel().def(); - NameAttrList name; - name.set_name(def.op()); - *name.mutable_attr() = def.attr(); - // Remove the "_class" attribute from the attribute set used to create the - // compilation cache key. This attribute is information for the colocator - // and causes false uniqueness between nodes. - name.mutable_attr()->erase("_class"); - return CompileImpl(compile_options, options, name, args, ctx, - CompileScope::kOp, CompileMode::kStrict, - out_compilation_result, out_executable); -} - -namespace { -// Print something that users can search for to definitively ascertain that XLA -// was used for their TF model. -// -// Prints only once to avoid spamming LOG(INFO). 
-void LogOnceXlaCompiledFirstCluster() { - static absl::once_flag log_once; - absl::call_once(log_once, [] { - LOG(INFO) << "Compiled cluster using XLA! This line is logged at most " - "once for the lifetime of the process."; - }); -} -} // namespace - -Status XlaCompilationCache::CompileStrict( - const Signature& sig, Entry* entry, - const XlaCompiler::CompileOptions& compile_options, - const XlaCompiler::Options& options, - const std::vector& args, - const NameAttrList& function, OpKernelContext* ctx, CompileScope scope) { - tensorflow::Env* env = tensorflow::Env::Default(); - const uint64 compile_start_us = env->NowMicros(); - - XlaCompiler compiler(options); - entry->compile_state = CompileState::kCompiled; - entry->compilation_status = [&] { - if (scope == CompileScope::kOp) { - return XlaSingleOpToHlo(&compiler, options, args, - BuildSingleOpCompileArgument(ctx), - compile_options, &entry->compilation_result); - - } else { - CHECK(scope == CompileScope::kFunction); // Crash OK - return compiler.CompileFunction(compile_options, function, args, - &entry->compilation_result); - } - }(); - TF_RETURN_IF_ERROR(entry->compilation_status); - TF_RET_CHECK(entry->executable.get() == nullptr); - TF_RET_CHECK(entry->compilation_result.computation != nullptr); - - std::optional serialized_entry; - if (!persistent_cache_directory_.empty()) { - const xla::HloModuleProto& hlo_module = - entry->compilation_result.computation->proto(); - - XlaSerializedCacheKey cache_key = BuildSerializedCacheKey(sig, hlo_module); - - { - XLA_SCOPED_LOGGING_TIMER(absl::StrCat( - "Try loading serialized cache entry:", sig.HumanString())); - TF_ASSIGN_OR_RETURN(serialized_entry, TryLoadSerializedEntry(cache_key)); - } - - if (serialized_entry.has_value()) { - TF_RETURN_IF_ERROR( - VerifyLoadedCacheEntry(cache_key, hlo_module, *serialized_entry)); - } - } - - if (serialized_entry.has_value()) { - VLOG(1) << "Loading cached entry for: " << sig.HumanString(); - StatusOr> executable = LoadExecutable( - options, entry->compilation_result, serialized_entry->executable()); - entry->compilation_status = executable.status(); - if (executable.ok()) { - entry->executable = *std::move(executable); - } - } else { - entry->compilation_status = - BuildExecutable(options, entry->compilation_result, &entry->executable); - - // Caching is done regardless of the entry->compilation_status. To take - // advantage of newer compilation code, a cache flush is required. 
- if (!persistent_cache_directory_.empty()) { - XLA_SCOPED_LOGGING_TIMER(absl::StrCat( - "Serializing and saving cache entry: ", sig.HumanString())); - TF_ASSIGN_OR_RETURN(XlaSerializedCacheEntry serialized_entry, - SerializeEntry(options, sig, *entry)); - TF_RETURN_IF_ERROR(SaveSerializedEntry(std::move(serialized_entry))); - } - } - - const uint64 compile_end_us = env->NowMicros(); - const uint64 compile_time_us = compile_end_us - compile_start_us; - metrics::UpdateXlaCompilationTime(compile_time_us); - - mutex_lock lock(cluster_compile_stats_mu_); - const std::string& function_name = function.name(); - auto it = cluster_compile_stats_.find(function_name); - const uint64 compile_time_s = compile_time_us / 1.0e6; - it->second.compile_count++; - it->second.cumulative_compile_time_us += compile_time_us; - LogOnceXlaCompiledFirstCluster(); - VLOG(1) << "compiled " << function_name << " " << it->second.compile_count - << " times, compile time: " << compile_time_us - << " us, cumulative: " << it->second.cumulative_compile_time_us - << " us (" - << tensorflow::strings::HumanReadableElapsedTime(compile_time_s) - << " / " - << tensorflow::strings::HumanReadableElapsedTime( - it->second.cumulative_compile_time_us / 1.0e6) - << ")"; - - XlaJitCompilationActivity jit_compilation_activity; - jit_compilation_activity.set_cluster_name(function_name); - jit_compilation_activity.set_compile_count(it->second.compile_count); - jit_compilation_activity.set_compile_time_us(compile_time_us); - jit_compilation_activity.set_cumulative_compile_time_us( - it->second.cumulative_compile_time_us); - jit_compilation_activity.set_used_persistent_cache( - serialized_entry.has_value()); - TF_RETURN_IF_ERROR(BroadcastXlaActivity(std::move(jit_compilation_activity))); - - return OkStatus(); -} - -Status XlaCompilationCache::CompileAsynchronous( - const Signature& signature, Entry* entry, - const XlaCompiler::CompileOptions& compile_options, - const XlaCompiler::Options& options, - const std::vector& args, - const NameAttrList& function, OpKernelContext* ctx, CompileScope scope) { - // Explicitly capture all required data by value for async compilation. - entry->compile_state = CompileState::kCompiling; - { - mutex_lock lock(async_compilation_state_.async_compilation_state_mu); - async_compilation_state_.num_ongoing_compilations++; - } - // Don't move the above code into the thread function as it synchronously - // updates the async compilation state! - - // When the ThreadPool for the compilation cache is destroyed, it waits for - // compilations to have finished. This means that both 'entry' and 'this' will - // be alive for the duration of the compilation. - // !!Pay attention when additional variables must be captured by this lambda!! - // All values are captured by value. Make sure that all pointer values (like - // entry) do not get freed until the lambda has finished,\. - const std::string& function_name = function.name(); - async_compilation_state_.compiler_threads->Schedule([=] { - Entry local_entry; - VLOG(2) << "Starting asynchronous compilation of cluster " << function_name - << '.'; - // We don't need to lock local_entry.mu, but do it anyway to satisfy - // thread safety analysis. 
- mutex_lock entry_lock(local_entry.mu); - Status s = CompileStrict(signature, &local_entry, compile_options, options, - args, function, ctx, scope); - VLOG(2) << "Finished asynchronous compililation of cluster " - << function_name << '.'; - { - mutex_lock lock(async_compilation_state_.async_compilation_state_mu); - async_compilation_state_.num_ongoing_compilations--; - } - { // Populate original entry with compilation result. - mutex_lock entry_lock(entry->mu); - if (!s.ok()) { - entry->compilation_status = s; - } else { - entry->compilation_status = local_entry.compilation_status; - } - entry->compilation_result = local_entry.compilation_result; - entry->compile_state = local_entry.compile_state; - entry->executable = std::move(local_entry.executable); - } - }); - return OkStatus(); -} - -bool XlaCompilationCache::ShouldCompileCluster(CompileMode compile_mode, - bool is_megamorphic, - bool is_first_execution, - int64_t current_request_count, - const NameAttrList& function) { - std::optional compile_threshold; - if (compile_mode == CompileMode::kLazy) { - compile_threshold = kDefaultCompilationThreshold; - } else if (compile_mode == CompileMode::kAsync) { - compile_threshold = 0; // for now, always compile right away. - } - - if (compile_mode == CompileMode::kStrict) { - // Lazy compilation is disabled. - return true; - } - - if (is_megamorphic) { - BroadcastOptimizationRemark(XlaOptimizationRemark::MEGAMORPHIC_FUNCTION, - function.name()) - .IgnoreError(); - VLOG(2) << "Not compiling cluster " << function.name() - << " because it is megamorphic."; - return false; - } - - if (is_first_execution) { - return true; - } - - if (compile_mode == CompileMode::kAsync) { - // Asynchronous compilation is enabled. - mutex_lock lock(async_compilation_state_.async_compilation_state_mu); - if (async_compilation_state_.num_ongoing_compilations >= - async_compilation_state_.kMaxNumOngoingCompilations) { - VLOG(2) << "Not asynchronously compiling cluster " << function.name() - << " because of too many ongoing compilations."; - return false; - } - } - - bool reached_compile_threshold = current_request_count >= *compile_threshold; - if (!reached_compile_threshold) { - VLOG(2) << "Not compiling cluster " << function.name() - << " because it has not reached compile threshold; threshold is " - << *compile_threshold << " execution count " - << current_request_count << "."; - } - return reached_compile_threshold; -} - -StatusOr -XlaCompilationCache::GetCompilationResultIfAlreadyCompiled( - const NameAttrList& function, - absl::Span args) { - CompilationResultAndExecutable result{nullptr, nullptr}; - - TF_ASSIGN_OR_RETURN(Signature signature, BuildSignature(function, args)); - - // The outer lock protects the existence of the cache entry. It does not - // protect the contents of the cache entry. - Entry* entry; - { - mutex_lock lock(compile_cache_mu_); - // Try to find a cache entry. - auto cache_entry = cache_.find(signature); - if (cache_entry == cache_.end()) { - return result; - } - entry = cache_entry->second.get(); - } - - // Acquire the cache entry lock. - // TODO(phawkins): this locking will need to be restructured when we implement - // cache eviction. 
- mutex_lock entry_lock(entry->mu); - - const CompileState state = entry->compile_state; - if (state != CompileState::kCompiled) { - return result; - } - - int64_t current_request_count = ++entry->request_count; - - VLOG(2) << "Compilation cache entry hit and is already compiled : " - << static_cast(entry->compile_state) - << " signature: " << signature.HumanString() << " with request count " - << current_request_count; - - result.compilation_result = &entry->compilation_result; - result.executable = entry->executable.get(); - - return result; -} - -Status XlaCompilationCache::CompileImpl( - const XlaCompiler::CompileOptions& compile_options, - const XlaCompiler::Options& options, const NameAttrList& function, - const std::vector& args, OpKernelContext* ctx, - CompileScope scope, CompileMode compile_mode, - const XlaCompiler::CompilationResult** out_compilation_result, - xla::LocalExecutable** out_executable) { - DCHECK_NE(out_executable, nullptr); - VLOG(2) << "XlaCompilationCache::Compile " << DebugString(); - - if (VLOG_IS_ON(2)) { - VLOG(2) << "num_inputs=" << args.size(); - for (int i = 0, end = args.size(); i < end; i++) { - VLOG(3) << i << ": " << args[i].HumanString(); - } - } - TF_ASSIGN_OR_RETURN(Signature signature, BuildSignature(function, args)); - - // The outer lock protects the existence of the cache entry. It does not - // protect the contents of the cache entry. - Entry* entry; - { - mutex_lock lock(compile_cache_mu_); - // Find or create a cache entry. - auto cache_entry = cache_.find(signature); - if (cache_entry == cache_.end()) { - auto inserted_entry = - cache_.emplace(signature, std::make_unique()); - cache_entry = inserted_entry.first; - } - entry = cache_entry->second.get(); - } - - // We always compile a cluster the very first time it is executed. This is an - // optimistic guess that pays off for statically shaped TensorFlow graphs - // (since they get the benefit of XLA right away without waiting for warmup) - // and doesn't hurt much for dynamically shaped TensorFlow graphs (we "pay" at - // most one cluster-compilation's worth of compile time). - bool is_first_execution; - - // We avoid compiling clusters that have "gone megamorphic" i.e. have an - // excessive amount of shape dynamism. - bool is_megamorphic; - - { - mutex_lock lock(cluster_compile_stats_mu_); - auto it = - cluster_compile_stats_.emplace(function.name(), ClusterCompileStats{}) - .first; - is_first_execution = it->second.execution_count++ == 0; - - // The is_megamorphic bit is "sticky". We assume clusters that have been - // observed to be megamorphic once stay megamorphic forever. - if (!it->second.is_megamorphic && - ShouldBeMegamorphic(/*compile_count=*/it->second.compile_count, - /*execution_count=*/it->second.execution_count)) { - VLOG(1) << "Marking " << function.name() - << " as megamorphic, compile_count=" << it->second.compile_count - << " execution_count=" << it->second.execution_count; - it->second.is_megamorphic = true; - } - - is_megamorphic = it->second.is_megamorphic; - } - - string human_signature; - if (VLOG_IS_ON(2)) { - human_signature = VLOG_IS_ON(3) ? signature.HumanString() : function.name(); - VLOG(2) << "Signature: " << human_signature; - } - - // Acquire the cache entry lock and compile, if necessary. - // TODO(phawkins): this locking will need to be restructured when we implement - // cache eviction. 
- mutex_lock entry_lock(entry->mu); - int64_t current_request_count = ++entry->request_count; - VLOG(2) << "Compilation cache entry hit: " - << static_cast(entry->compile_state) - << " signature: " << human_signature << " with request count " - << current_request_count; - - CompileState state = entry->compile_state; - *out_compilation_result = nullptr; - *out_executable = nullptr; - - // Check if the requested entry is uncompiled and return an error if - // compilation is disabled. This will raise an error for kLazy even if we have - // not yet hit the compilation threshold and no compilation happens this - // round. This is to avoid non-determanism of when compilation is disallowed, - // for example by changing the threshold. - if (state == CompileState::kUncompiled && FailOnXlaCompilation()) { - VLOG(1) << "XLA compilation disabled: " << function.name() << "\n" - << absl::StrJoin( - args, "\n", - [](std::string* out, const XlaCompiler::Argument& arg) { - absl::StrAppend(out, " arg: ", arg.HumanString()); - }); - - return errors::Internal("XLA compilation disabled"); - } - - if (state == CompileState::kUncompiled) { - XLA_SCOPED_LOGGING_TIMER("Compilation of XLA executable"); - if (!ShouldCompileCluster(compile_mode, is_megamorphic, is_first_execution, - current_request_count, function)) { - VLOG(2) << "Not compiling for signature: " << human_signature; - return OkStatus(); - } else if (compile_mode == CompileMode::kAsync) { - VLOG(2) << "Queueing asynchronous compilation for signature: " - << human_signature; - TF_RETURN_IF_ERROR(CompileAsynchronous(signature, entry, compile_options, - options, args, function, ctx, - scope)); - return OkStatus(); - } else { - VLOG(2) << "Instantly compiling for signature: " << human_signature; - TF_RETURN_IF_ERROR(CompileStrict(signature, entry, compile_options, - options, args, function, ctx, scope)); - } - } else if (state == CompileState::kCompiling) { - VLOG(2) << "Ongoing asynchronous compilation for signature: " - << human_signature; - return OkStatus(); - } else if (state == CompileState::kCompiled) { - VLOG(2) << "Already Compiled for signature: " << human_signature; - } - - TF_RETURN_IF_ERROR(entry->compilation_status); - *out_compilation_result = &entry->compilation_result; - *out_executable = entry->executable.get(); - return OkStatus(); -} - -XlaSerializedCacheKey XlaCompilationCache::BuildSerializedCacheKey( - const Signature& sig, const xla::HloModuleProto& hlo_module) const { - XlaSerializedCacheKey serialized_cache_key; - serialized_cache_key.set_signature_fingerprint(Signature::Hash()(sig)); - serialized_cache_key.set_cluster_fingerprint( - DeterministicProtoHash64(hlo_module)); - serialized_cache_key.set_device_type(device_type_.type_string()); - serialized_cache_key.set_prefix(persistance_prefix_); - return serialized_cache_key; -} - -Status XlaCompilationCache::VerifyLoadedCacheEntry( - const XlaSerializedCacheKey& key, const xla::HloModuleProto& hlo_module, - const XlaSerializedCacheEntry& entry) { - XLA_SCOPED_LOGGING_TIMER(absl::StrCat("Verifying loaded cache entry: ", - hlo_module.entry_computation_name())); - - if (!AreSerializedProtosEqual(key, entry.key())) { - VLOG(2) << "Serialized cache key does not match:\n" - << "got:\n" - << entry.key().DebugString() << "\nexpected:\n" - << key.DebugString() << "\n"; - return errors::InvalidArgument("Serialized cache key does not match."); - } - - // Perform a stricter (slower) check of the snapshot to verify that they - // match exactly. 
- if (!disable_strict_signature_checks_) { - if (!AreSerializedProtosEqual(hlo_module, entry.hlo_module())) { - VLOG(2) << "HLOs do not match:\n" - << "got:\n" - << hlo_module.DebugString() << "\nexpected:\n" - << entry.hlo_module().DebugString() << "\n"; - return errors::InvalidArgument("Serialized HLO does not match."); - } - } - - if (entry.executable().empty()) { - return errors::InvalidArgument("No binary found in serialized entry."); - } - return OkStatus(); -} - -StatusOr XlaCompilationCache::SerializeEntry( - const XlaCompiler::Options& options, const Signature& sig, - const Entry& entry) { - if (entry.compile_state != CompileState::kCompiled) { - return errors::FailedPrecondition( - "Cache entry to serialize is not compiled."); - } - if (entry.executable == nullptr) { - return errors::FailedPrecondition( - "LocalExecutable not found for cache entry to serialize."); - } - if (entry.executable->executable() == nullptr) { - return errors::FailedPrecondition( - "Executable not found for cache entry to serialize."); - } - - XlaSerializedCacheEntry serialized_entry; - const xla::HloModuleProto& hlo_module = - entry.compilation_result.computation->proto(); - *serialized_entry.mutable_key() = BuildSerializedCacheKey(sig, hlo_module); - *serialized_entry.mutable_hlo_module() = hlo_module; - - TF_ASSIGN_OR_RETURN( - std::unique_ptr aot_result, - BuildSerializedExecutable(options, entry.compilation_result)); - TF_ASSIGN_OR_RETURN(std::string serialized, aot_result->SerializeAsString()); - serialized_entry.set_executable(std::move(serialized)); - return serialized_entry; -} - -namespace { - -std::string GetFilePath(const XlaSerializedCacheKey& key, - absl::string_view persistent_cache_directory) { - const std::string file_name = - absl::StrCat(XlaSerializedCacheKeyToString(key), ".pb"); - return io::JoinPath(persistent_cache_directory, file_name); -} - -} // namespace - -Status XlaCompilationCache::SaveSerializedEntry( - const XlaSerializedCacheEntry& entry) { - Env* env = Env::Default(); - TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(persistent_cache_directory_)); - const std::string file_path = - GetFilePath(entry.key(), persistent_cache_directory_); - return WriteBinaryProto(env, file_path, entry); -} - -StatusOr> -XlaCompilationCache::TryLoadSerializedEntry(const XlaSerializedCacheKey& key) { - Env* env = Env::Default(); - const std::string file_path = GetFilePath(key, persistent_cache_directory_); - if (!env->FileExists(file_path).ok()) { - return StatusOr>(std::nullopt); - } - - XlaSerializedCacheEntry entry; - TF_RETURN_IF_ERROR(ReadTextOrBinaryProto(env, file_path, &entry)); - return StatusOr>(entry); -} - -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h deleted file mode 100644 index 22ca15d2a0e..00000000000 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ /dev/null @@ -1,357 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ -#define TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ - -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/container/inlined_vector.h" -#include "absl/types/optional.h" -#include "absl/types/span.h" -#include "absl/types/variant.h" -#include "tensorflow/compiler/jit/xla_compilation_cache.pb.h" -#include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/tf2xla/xla_context.h" -#include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/service/hlo.pb.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/common_runtime/device.h" -#include "tensorflow/core/common_runtime/device_mgr.h" -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/protobuf/meta_graph.pb.h" - -namespace tensorflow { - -// The XlaCompilationCache class caches the results of the XlaCompiler class, -// which converts a Tensorflow graph into a compiled XLA compilation. -// -// Since XLA computations must have static shapes, the cache generates a new -// XLA computation for each new set of input shapes. -// -// Currently no cache eviction policy is implemented and the cache grows without -// bound. -class XlaCompilationCache : public ResourceBase { - public: - struct Config { - Config() {} - explicit Config(absl::string_view persistent_cache_directory, - bool disable_strict_signature_checks, - absl::string_view persistance_prefix) - : persistent_cache_directory(persistent_cache_directory), - disable_strict_signature_checks(disable_strict_signature_checks), - persistance_prefix(persistance_prefix) {} - - // If non-empty, JIT-compiled executables are saved to and loaded from the - // specified file system directory path. - std::string persistent_cache_directory; - - // Disable strict signature checks for entries loaded into the cache from - // external sources. - bool disable_strict_signature_checks = false; - - // The cache persistence prefix to use if serializing/deserialzing entries. - std::string persistance_prefix; - }; - XlaCompilationCache(Config config, xla::LocalClient* client, - DeviceType device_type); - ~XlaCompilationCache() override; - - enum class CompileMode { - kLazy, - kStrict, - kAsync, - }; - - enum class CompileState { kUncompiled, kCompiling, kCompiled }; - - enum class CompileScope { - kOp, - kFunction, - }; - - // Compiles a function into a XlaCompiler::CompilationResult that can be used - // to execute an XLA Computation. Compilation results are cached. - // `function` is the name of a Tensorflow function to compile. - // `args` is a description of the arguments to the computation. - // - // `compile_mode` controls the behavior of the compilation cache on a cache - // miss. If `compile_mode` is `kLazy` then, based on some profitability - // heuristics, the compilation cache may decide not to compile the cluster at - // this time. In this case it returns null into both `out_compilation_result` - // and `out_executable`. If `compile_mode` is `kStrict` then the compilation - // cache always attempts the compilation on a cache miss. 
If compilation mode - // is 'kAsync' compilation of the cluster happens in the background while the - // fallback path executes. - // - // The result of compilation is written to `*out_compilation_result`, which - // must be non-null. If `out_executable` is non-null, also builds an - // xla::LocalExecutable and sets `out_executable` to point to it. The - // resulting executable pointer may be null if the computation has no - // non-constant outputs. - Status Compile(const XlaCompiler::Options& options, - const NameAttrList& function, - const std::vector& args, - const XlaCompiler::CompileOptions& compile_options, - CompileMode compile_mode, - const XlaCompiler::CompilationResult** out_compilation_result, - xla::LocalExecutable** out_executable); - - // As above, but calls XlaCompiler::CompileSingleOp instead of - // XlaCompiler::CompileFunction. If MLIR bridge is enabled through ConfigProto - // in OpKernelContext, then uses MLIR bridge for compilation instead of - // XlaCompiler, if possible. - Status CompileSingleOp( - const XlaCompiler::Options& options, - const std::vector& args, OpKernelContext* ctx, - const XlaCompiler::CompileOptions& compile_options, - const XlaCompiler::CompilationResult** out_compilation_result, - xla::LocalExecutable** out_executable); - - struct CompilationResultAndExecutable { - const XlaCompiler::CompilationResult* compilation_result; - xla::LocalExecutable* executable; - }; - - // Returns CompilationResultAndExecutable with non-null compilation_result and - // executable if the signature is already compiled. - // If the signature has not been compiled yet, this function returns a - // CompilationResultAndExecutable instance with only nullptrs in it. - // Non-ok status means something other than the 2 circumstances above - // happened. - StatusOr - GetCompilationResultIfAlreadyCompiled( - const NameAttrList& function, - absl::Span args); - - xla::LocalClient* client() const { return client_; } - const DeviceType& device_type() const { return device_type_; } - - string DebugString() const override; - - // Describes the types, shapes and any compile-time constant arguments - // to a kernel. Key that uniquely identifies a compilation output. - struct Signature { - string name; - - // List of args (either as a TensorTypeAndShape or as a Tensor value) - // for compile-time constant arguments to the compilation, ordered by - // argument number. Tensors must be in host memory. - using TensorTypeAndShape = - std::pair>; - absl::InlinedVector, 8> args; - - bool operator==(const Signature& other) const; - - struct Hash { - uint64 operator()(const Signature& signature) const; - }; - - // Returns a human-readable description of the signature. - string HumanString() const; - }; - - // Builds the signature for a compilation. - static StatusOr BuildSignature( - const NameAttrList& function, - absl::Span args); - - private: - // Common implementation of Compile and CompileSingleOp. The `OpKernelContext` - // parameter is always null for the former. - Status CompileImpl( - const XlaCompiler::CompileOptions& compile_options, - const XlaCompiler::Options& options, const NameAttrList& function, - const std::vector& args, OpKernelContext* ctx, - CompileScope scope, CompileMode compile_mode, - const XlaCompiler::CompilationResult** out_compilation_result, - xla::LocalExecutable** out_executable); - - // Takes `result` which has been compiled from a Tensorflow subgraph to a - // XLA computation already, and generates an XLA LocalExecutable `executable`. 
- Status BuildExecutable(const XlaCompiler::Options& options, - const XlaCompiler::CompilationResult& result, - std::unique_ptr* executable); - - // Like BuildExecutable above, except that it generates an XLA - // AotCompilationResult (instead of LocalExecutable), which can be persisted - // to later load a LocalExecutable using the LoadExecutable() method below. - StatusOr> - BuildSerializedExecutable(const XlaCompiler::Options& options, - const XlaCompiler::CompilationResult& result); - - // Returns an XLA LocalExecutable loaded from a serialized XLA - // AotCompilationResult. - StatusOr> LoadExecutable( - const XlaCompiler::Options& options, - const XlaCompiler::CompilationResult& result, - const std::string& serialized_aot_result); - - // Determines whether the cluster should be compiled. - bool ShouldCompileCluster(CompileMode compile_mode, bool is_megamorphic, - bool is_first_execution, - int64_t current_request_count, - const NameAttrList& function); - - xla::LocalClient* const client_; - const DeviceType device_type_; - bool disable_strict_signature_checks_; - std::string persistance_prefix_; - - // The value associated with a cache entry. - struct Entry { - mutex mu; - - // The current compilation state for this entry. - CompileState compile_state = CompileState::kUncompiled; - - // The number of times a compilation with this signature has been requested. - int64_t request_count = 0; - - // Did compilation succeed? - Status compilation_status TF_GUARDED_BY(mu); - - // Output of the XlaCompiler. - XlaCompiler::CompilationResult compilation_result TF_GUARDED_BY(mu); - - // The XLA executable compiled from . May be null if no - // executable has been built. - std::unique_ptr executable TF_GUARDED_BY(mu); - }; - - // Returns a cache key proto that identifies an entry in the compilation - // cache. - XlaSerializedCacheKey BuildSerializedCacheKey( - const Signature& sig, const xla::HloModuleProto& hlo_module) const; - - // Serializes the signature and its corresponding entry to a proto message. - StatusOr SerializeEntry( - const XlaCompiler::Options& options, const Signature& sig, - const Entry& entry) TF_EXCLUSIVE_LOCKS_REQUIRED(entry.mu); - - // Checks if the loaded `entry` matches the expected `key` and `hlo_module`. - Status VerifyLoadedCacheEntry(const XlaSerializedCacheKey& key, - const xla::HloModuleProto& hlo_module, - const XlaSerializedCacheEntry& entry); - - Status CompileStrict(const Signature& sig, Entry* entry, - const XlaCompiler::CompileOptions& compile_options, - const XlaCompiler::Options& options, - const std::vector& args, - const NameAttrList& function, OpKernelContext* ctx, - CompileScope scope) - TF_EXCLUSIVE_LOCKS_REQUIRED(entry->mu); - Status CompileAsynchronous(const Signature& sig, Entry* entry, - const XlaCompiler::CompileOptions& compile_options, - const XlaCompiler::Options& options, - const std::vector& args, - const NameAttrList& function, OpKernelContext* ctx, - CompileScope scope); - - // Saves the cache entry in the file directory supplied during the - // construction of this class. Overwrites existing entries. - Status SaveSerializedEntry(const XlaSerializedCacheEntry& entry); - - // Tries to load a cache entry given a `key` by searching the file directory - // supplied during the construction of this class. Returns std::nullopt if no - // cache entry is found. 
- StatusOr> TryLoadSerializedEntry( - const XlaSerializedCacheKey& key); - - mutex compile_cache_mu_; - absl::flat_hash_map, Signature::Hash> cache_ - TF_GUARDED_BY(compile_cache_mu_); - - struct ClusterCompileStats { - // Number of times the cluster has been (re-)compiled. - int64_t compile_count = 0; - - // The number of times this cluster has been executed. - int64_t execution_count = 0; - - // Cumulative time spent compiling the cluster. - int64_t cumulative_compile_time_us = 0; - - // True if we have decided that this cluster is too dynamic (i.e. its shapes - // change too frequently) to profitably JIT compile. Once a cluster is - // tagged megamorphic, it stays megamorphic forever. - bool is_megamorphic = false; - }; - - mutex cluster_compile_stats_mu_; - - // Maps cluster names to compilation statistics for said cluster. - absl::flat_hash_map cluster_compile_stats_ - TF_GUARDED_BY(cluster_compile_stats_mu_); - - struct AsyncCompilationState { - mutex async_compilation_state_mu; - - // Number of threads for asynchronous compilations. - static constexpr int64_t kNumCompilerThreads = 10; - - // Maximum number of ongoing compilations. - static constexpr int64_t kMaxNumOngoingCompilations = kNumCompilerThreads; - - // Number of ongoing compilations. - int64_t num_ongoing_compilations TF_GUARDED_BY(async_compilation_state_mu) = - 0; - - // Pool of threads for asynchronous compilations. - std::unique_ptr compiler_threads; - - AsyncCompilationState() { - compiler_threads = std::make_unique( - tensorflow::Env::Default(), "async_compiler_threads", - kNumCompilerThreads); - } - - } async_compilation_state_; - - // The number of times a lazy compilation must be requested for a specific - // signature before we attempt to compile it. - static constexpr int64_t kDefaultCompilationThreshold = 2; - - // If non-empty, JIT-compiled executables are saved to and loaded from the - // specified file system directory path. - std::string persistent_cache_directory_; - - TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache); -}; - -// Creates a single-node graph using the specified node_def as the only op apart -// from the arg and retval nodes. -StatusOr> CreateGraph( - const NodeDef& node_def, absl::Span args, - absl::Span result_types); - -// Use XlaCompiler to compile a single op into HLO. -Status XlaSingleOpToHlo( - XlaCompiler* compiler, const XlaCompiler::Options& options, - const std::vector& args, - const XlaCompiler::SingleOpCompileArgument& single_op_compile_argument, - const XlaCompiler::CompileOptions& compile_options, - XlaCompiler::CompilationResult* compilation_result); - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index 1415a572622..18347058806 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -17,9 +17,11 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" +#include #include #include "absl/memory/memory.h" +#include "tensorflow/compiler/jit/device_compilation_profiler.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_launch_util.h" #include "tensorflow/compiler/jit/xla_platform_info.h" @@ -31,13 +33,18 @@ limitations under the License. 
#include "tensorflow/core/lib/core/refcount.h" namespace tensorflow { +namespace { +using XlaDeviceCompiler = + DeviceCompiler; +} // namespace Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, - XlaCompilationCache* cache, + XlaDeviceCompiler* xla_device_compiler, const XlaCompiler::CompilationResult* result, xla::LocalExecutable* executable, const ResourceVarsSnapshot& variable_args) { - xla::LocalClient* client = static_cast(cache->client()); + xla::LocalClient* client = + static_cast(xla_device_compiler->client()); se::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; @@ -97,7 +104,8 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, Status XlaCompileOnDemandOp::Compile( OpKernelContext* ctx, const XlaCompiler::CompilationResult** result, - XlaCompilationCache** cache, ResourceVarsSnapshot* variable_args, + XlaDeviceCompiler** xla_device_compiler, + DeviceCompilationProfiler** profiler, ResourceVarsSnapshot* variable_args, xla::LocalExecutable** executable) { TF_ASSIGN_OR_RETURN(std::vector constant_input_indices, GetConstantInputIndicesFromContext(ctx)); @@ -108,15 +116,22 @@ Status XlaCompileOnDemandOp::Compile( ResourceMgr* rm = ctx->resource_manager(); CHECK(rm); - TF_RETURN_IF_ERROR(rm->LookupOrCreate( - rm->default_container(), "xla_cache", cache, - [&](XlaCompilationCache** write_into_cache) { - return BuildXlaCompilationCache(ctx->device(), ctx->function_library(), - platform_info_, write_into_cache); + TF_RETURN_IF_ERROR(rm->LookupOrCreate( + rm->default_container(), "xla_device_compiler", xla_device_compiler, + [&](XlaDeviceCompiler** xla_device_compiler) { + return BuildXlaDeviceCompiler(ctx->device(), ctx->function_library(), + platform_info_, xla_device_compiler); + })); + + TF_RETURN_IF_ERROR(rm->LookupOrCreate( + rm->default_container(), "device_compilation_profiler", profiler, + [](DeviceCompilationProfiler** profiler) { + *profiler = new DeviceCompilationProfiler(); + return OkStatus(); })); XlaCompiler::Options options = GenerateCompilerOptions( - **cache, *ctx->function_library(), ctx->device(), + **xla_device_compiler, *ctx->function_library(), ctx->device(), ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr, platform_info_, /*has_ref_vars=*/true); // No detailed logging from on demand op. @@ -146,25 +161,29 @@ Status XlaCompileOnDemandOp::Compile( TF_RETURN_IF_ERROR(args.status()); } - return (*cache)->CompileSingleOp(options, *args, ctx, compile_options, result, - executable); + return (*xla_device_compiler) + ->CompileSingleOpIfNeeded(options, *args, compile_options, ctx, *profiler, + result, executable); } void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) { const XlaCompiler::CompilationResult* result; xla::LocalExecutable* executable; ResourceVarsSnapshot variable_args; - XlaCompilationCache* cache; + XlaDeviceCompiler* xla_device_compiler; + DeviceCompilationProfiler* profiler; OP_REQUIRES(ctx, ctx->function_library(), errors::Internal("Function library missing")); - OP_REQUIRES_OK(ctx, - Compile(ctx, &result, &cache, &variable_args, &executable)); - - // Hold the reference to the JIT during evaluation. (We could probably - // free it sooner because the ResourceMgr will retain a reference, but - // this is more obviously correct.) 
- core::ScopedUnref cache_ref(cache); - OP_REQUIRES_OK(ctx, Run(ctx, cache, result, executable, variable_args)); + OP_REQUIRES_OK(ctx, Compile(ctx, &result, &xla_device_compiler, &profiler, + &variable_args, &executable)); + + // Hold the reference to the XLA device compiler and profiler during + // evaluation. (We could probably free them sooner because the ResourceMgr + // will retain references, but this is more obviously correct.) + core::ScopedUnref xla_device_compiler_ref(xla_device_compiler); + core::ScopedUnref profiler_ref(profiler); + OP_REQUIRES_OK( + ctx, Run(ctx, xla_device_compiler, result, executable, variable_args)); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.h b/tensorflow/compiler/jit/xla_compile_on_demand_op.h index 598ea7d3093..800b78d3286 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.h +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.h @@ -19,6 +19,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILE_ON_DEMAND_OP_H_ #define TENSORFLOW_COMPILER_JIT_XLA_COMPILE_ON_DEMAND_OP_H_ +#include "tensorflow/compiler/jit/device_compilation_profiler.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_launch_util.h" #include "tensorflow/compiler/jit/xla_platform_info.h" @@ -45,11 +46,15 @@ class XlaCompileOnDemandOp : public OpKernel { XlaCompiler::Argument CreateCompilerArgument(OpKernelContext* ctx, int64_t i); Status Compile(OpKernelContext* ctx, const XlaCompiler::CompilationResult** result, - XlaCompilationCache** cache, + DeviceCompiler** + xla_device_compiler, + DeviceCompilationProfiler** profiler, ResourceVarsSnapshot* variable_args, xla::LocalExecutable** executable); - Status Run(OpKernelContext* ctx, XlaCompilationCache* cache, + Status Run(OpKernelContext* ctx, + DeviceCompiler* + xla_device_compiler, const XlaCompiler::CompilationResult* result, xla::LocalExecutable* executable, const ResourceVarsSnapshot& variable_args); diff --git a/tensorflow/compiler/jit/xla_compile_util.cc b/tensorflow/compiler/jit/xla_compile_util.cc index 4fdd7834b22..8d72d20ba55 100644 --- a/tensorflow/compiler/jit/xla_compile_util.cc +++ b/tensorflow/compiler/jit/xla_compile_util.cc @@ -15,49 +15,59 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_compile_util.h" +#include #include +#include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/util/determinism.h" namespace tensorflow { -xla::ExecutableBuildOptions GetExecutableBuildOptions( - const XlaCompiler::Options& options, - const XlaCompiler::CompilationResult& result, int default_device_ordinal) { - xla::ExecutableBuildOptions build_options; - if (result.collective_info) { - build_options.set_num_replicas(result.collective_info->group_size); - } - if (options.device_ordinal != -1) { - build_options.set_device_ordinal(options.device_ordinal); - } else if (default_device_ordinal != -1) { - build_options.set_device_ordinal(default_device_ordinal); +StatusOr> CreateSingleOpGraph( + const NodeDef& node_def, absl::Span args, + absl::Span result_types) { + // TODO(b/74182462): We implement this by creating a new dummy Graph including + // _Arg nodes, and let CompileGraph walk it. This could be optimized. + std::unique_ptr graph(new Graph(OpRegistry::Global())); + + // First create the actual node we care about computing. 
+ TF_ASSIGN_OR_RETURN(Node * main_node, graph->AddNode(node_def)); + + // Create dummy _Arg nodes. Link these to `node` and also via a control + // dependency edge to the _SOURCE node. + for (int64_t i = 0, end = args.size(); i < end; ++i) { + Node* node; + string arg_name = absl::StrCat("_arg", i); + Status status = + NodeBuilder(arg_name, FunctionLibraryDefinition::kArgOp) + .ControlInput(graph->source_node()) + .Attr("T", args[i].kind == XlaArgument::kResource ? DT_RESOURCE + : args[i].type) + .Attr("index", i) + .Finalize(graph.get(), &node); + TF_RETURN_IF_ERROR(status); + graph->AddEdge(node, 0, main_node, i); } - build_options.set_result_layout(result.xla_output_shape); - build_options.set_device_allocator(options.device_allocator.get()); - build_options.set_alias_passthrough_params(options.alias_passthrough_params); - build_options.mutable_debug_options()->set_xla_detailed_logging_and_dumping( - options.detailed_logging); - if (tensorflow::OpDeterminismRequired()) { - build_options.mutable_debug_options()->set_xla_gpu_deterministic_ops(true); + + // Similarly with return values, create dummy _Retval nodes fed by `node`. + for (int64_t i = 0, end = result_types.size(); i < end; ++i) { + Node* node; + string retval_name = absl::StrCat("_retval", i); + Status status = NodeBuilder(retval_name, FunctionLibraryDefinition::kRetOp) + .Input(main_node, i) + .Attr("T", result_types[i]) + .Attr("index", i) + .Finalize(graph.get(), &node); + TF_RETURN_IF_ERROR(status); } - return build_options; + FixupSourceAndSinkEdges(graph.get()); + return graph; } -XlaCompiler::SingleOpCompileArgument BuildSingleOpCompileArgument( - OpKernelContext* ctx) { - XlaCompiler::SingleOpCompileArgument single_op_arg; - std::vector output_dtypes(ctx->num_outputs()); - for (int i = 0; i < output_dtypes.size(); ++i) { - output_dtypes[i] = ctx->expected_output_dtype(i); - } - single_op_arg.output_dtypes = output_dtypes; - single_op_arg.node_def = ctx->op_kernel().def(); - auto* config_proto = ctx->function_library()->config_proto(); - if (config_proto != nullptr) { - single_op_arg.config_proto = *config_proto; - } - return single_op_arg; +bool UsePjRtForSingleDeviceCompilation() { + return GetXlaOpsCommonFlags()->tf_xla_use_device_api; } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compile_util.h b/tensorflow/compiler/jit/xla_compile_util.h index d7b7fc6e6eb..bdc0ebafad5 100644 --- a/tensorflow/compiler/jit/xla_compile_util.h +++ b/tensorflow/compiler/jit/xla_compile_util.h @@ -16,20 +16,35 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILE_UTIL_H_ #define TENSORFLOW_COMPILER_JIT_XLA_COMPILE_UTIL_H_ -#include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/xla/client/executable_build_options.h" -#include "tensorflow/core/framework/op_kernel.h" +#include -namespace tensorflow { - -// Generates the ExecutableBuildOptions for compliation from HLO to executable. -xla::ExecutableBuildOptions GetExecutableBuildOptions( - const XlaCompiler::Options& options, - const XlaCompiler::CompilationResult& result, int default_device_ordinal); - -XlaCompiler::SingleOpCompileArgument BuildSingleOpCompileArgument( - OpKernelContext* ctx); +#include "tensorflow/compiler/tf2xla/xla_argument.h" +#include "tensorflow/core/graph/graph.h" +namespace tensorflow { +// The number of compiler threads to use for asynchronous device compilation. 
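As a quick illustration of the new `CreateSingleOpGraph` helper (the unit test further below exercises the same path), a sketch that wraps a single `Identity` node. It assumes the usual TensorFlow includes and `namespace tensorflow`; `BuildIdentityGraph` is a hypothetical caller.

```
// Sketch: wrap one Identity op in an _Arg -> op -> _Retval graph.
StatusOr<std::unique_ptr<Graph>> BuildIdentityGraph() {
  NodeDef node_def;
  node_def.set_name("identity_op");
  node_def.set_op("Identity");
  (*node_def.mutable_attr())["T"].set_type(DT_FLOAT);

  std::vector<XlaArgument> args(1);
  args[0].kind = XlaArgument::kParameter;  // fed through _arg0
  args[0].type = DT_FLOAT;
  args[0].shape = TensorShape({1, 2});

  // One DT_FLOAT result: the helper wires _arg0 -> identity_op -> _retval0
  // and fixes up the source/sink edges.
  return CreateSingleOpGraph(node_def, args, {DT_FLOAT});
}
```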
+inline constexpr int64_t kNumAsyncDeviceCompilerThreads = 10; + +enum class DeviceCompileMode { + kLazy, + kStrict, + kAsync, +}; + +enum class DeviceCompileState { + kUncompiled, + kCompiling, + kCompiled, +}; + +// Creates a single-node graph using the specified `node_def` as the only op +// apart from the arg and retval nodes corresponding to `args` and +// `result_types` respectively. +StatusOr> CreateSingleOpGraph( + const NodeDef& node_def, absl::Span args, + absl::Span result_types); + +bool UsePjRtForSingleDeviceCompilation(); } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_XLA_COMPILE_UTIL_H_ diff --git a/tensorflow/compiler/jit/xla_compile_util_test.cc b/tensorflow/compiler/jit/xla_compile_util_test.cc index d1fb89fcd5e..0e971a6b4db 100644 --- a/tensorflow/compiler/jit/xla_compile_util_test.cc +++ b/tensorflow/compiler/jit/xla_compile_util_test.cc @@ -15,73 +15,62 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_compile_util.h" #include +#include -#include #include +#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/kernels/ops_testutil.h" namespace tensorflow { namespace { -using ::testing::ElementsAreArray; - -TEST_F(OpsTestBase, Basic) { +TEST_F(OpsTestBase, CreateSingleOpGraph) { TF_EXPECT_OK(NodeDefBuilder("identity_op", "Identity") .Input(FakeInput(DT_FLOAT)) .Attr("T", DT_FLOAT) .Finalize(node_def())); TF_EXPECT_OK(InitOp()); - AddInputFromArray(TensorShape({1, 2}), {0, 1}); + AddInputFromArray(TensorShape({1, 2}), {6.9, 4.2}); TF_EXPECT_OK(RunOpKernel()); - auto arg = BuildSingleOpCompileArgument(context_.get()); + XlaCompiler::SingleOpCompileArgument single_op_arg(*context_); - EXPECT_THAT(arg.output_dtypes, ElementsAreArray({DT_FLOAT})); - EXPECT_EQ(arg.node_def.SerializeAsString(), - context_->op_kernel().def().SerializeAsString()); - EXPECT_EQ(arg.config_proto.ByteSizeLong(), 0); -} + std::vector args(1); + args[0].kind = XlaArgument::kConstant; + args[0].type = DT_FLOAT; + args[0].shape = TensorShape({1, 2}); + args[0].constant_value = GetInput(0); + args[0].initialized = true; -TEST(GetExecutableOptionTest, Basic) { - XlaCompiler::Options options; - options.device_ordinal = 0; - options.alias_passthrough_params = true; - options.detailed_logging = true; - XlaCompiler::CompilationResult result; - xla::Shape xla_output_shape; - result.xla_output_shape = xla_output_shape; - - auto build_option = - GetExecutableBuildOptions(options, result, /*default_device_ordinal=*/-1); - - EXPECT_EQ(build_option.device_ordinal(), 0); - EXPECT_EQ(build_option.result_layout()->ToString(), - xla_output_shape.ToString()); - EXPECT_EQ(build_option.alias_passthrough_params(), true); - EXPECT_EQ(build_option.debug_options().xla_detailed_logging_and_dumping(), - true); - LOG(ERROR) << build_option.ToString(); -} + TF_ASSERT_OK_AND_ASSIGN( + auto graph, + CreateSingleOpGraph(*node_def(), args, single_op_arg.output_dtypes)); -TEST(GetExecutableOptionTest, DefaultDeviceOrdinal) { - XlaCompiler::Options options; - XlaCompiler::CompilationResult result; + const auto& node_name_index = graph->BuildNodeNameIndex(); - auto build_option = - GetExecutableBuildOptions(options, result, /*default_device_ordinal=*/0); + const Node* identity_node = node_name_index.at("identity_op"); + EXPECT_EQ(identity_node->op_def().name(), "Identity"); + EXPECT_EQ(identity_node->attrs().FindByString("T")->type(), DT_FLOAT); - EXPECT_EQ(build_option.device_ordinal(), 0); -} + 
EXPECT_EQ(identity_node->num_inputs(), 1); + const Node* identity_input_node = nullptr; + TF_EXPECT_OK(identity_node->input_node(0, &identity_input_node)); + EXPECT_EQ(identity_input_node->name(), "_arg0"); -TEST(GetExecutableOptionTest, DeviceOrdinalNotSet) { - XlaCompiler::Options options; - XlaCompiler::CompilationResult result; + const Node* arg_node = node_name_index.at("_arg0"); + EXPECT_EQ(arg_node->op_def().name(), "_Arg"); + EXPECT_EQ(arg_node->attrs().FindByString("T")->type(), DT_FLOAT); - auto build_option = - GetExecutableBuildOptions(options, result, /*default_device_ordinal=*/-1); + const Node* retval_node = node_name_index.at("_retval0"); + EXPECT_EQ(retval_node->op_def().name(), "_Retval"); + EXPECT_EQ(retval_node->attrs().FindByString("T")->type(), DT_FLOAT); - EXPECT_EQ(build_option.device_ordinal(), -1); + EXPECT_EQ(identity_node->num_outputs(), 1); + EXPECT_EQ(retval_node->num_inputs(), 1); + const Node* retval_input_node = nullptr; + TF_EXPECT_OK(retval_node->input_node(0, &retval_input_node)); + EXPECT_EQ(retval_input_node->name(), "identity_op"); } } // namespace diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 9f1832af61c..4e03e45769d 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -299,7 +299,7 @@ Status XlaDevice::EnsureStreamOkLocked(xla::Backend* backend, return OkStatus(); } -StatusOr> XlaDevice::GetDeviceContextLocked() { +StatusOr> XlaDevice::GetDeviceContextLocked() { TF_ASSIGN_OR_RETURN(xla::LocalClient * client, GetOrCreateClient()); xla::Backend* backend = client->mutable_backend(); @@ -397,13 +397,13 @@ StatusOr> XlaDevice::GetDeviceContextLocked() { return device_contexts_; } -StatusOr XlaDevice::GetDeviceContextWithIndex(int index) { +StatusOr XlaDevice::GetDeviceContextWithIndex(int index) { mutex_lock lock(mu_); TF_ASSIGN_OR_RETURN(auto device_contexts, GetDeviceContextLocked()); return device_contexts.at(index); } -StatusOr XlaDevice::GetDeviceContextDefault() { +StatusOr XlaDevice::GetDeviceContextDefault() { return GetDeviceContextWithIndex(0); } @@ -502,7 +502,7 @@ void XlaDevice::Sync(const DoneCallback& done) { }); } -Status XlaDevice::MakeTensorFromProto(XlaDeviceContext* device_context, +Status XlaDevice::MakeTensorFromProto(DeviceContext* device_context, const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs, Tensor* tensor) { @@ -534,7 +534,7 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs, Tensor* tensor) { VLOG(1) << "XlaDevice::MakeTensorFromProto"; - XlaDeviceContext* device_context; + DeviceContext* device_context; TF_ASSIGN_OR_RETURN(device_context, GetDeviceContextDefault()); return MakeTensorFromProto(device_context, tensor_proto, alloc_attrs, tensor); } diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index 6207932a0ae..fd200e57e06 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -166,7 +166,7 @@ class XlaDevice : public LocalDevice { const AllocatorAttributes alloc_attrs, Tensor* tensor) override TF_LOCKS_EXCLUDED(mu_); - Status MakeTensorFromProto(XlaDeviceContext* device_context, + Status MakeTensorFromProto(DeviceContext* device_context, const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs, Tensor* tensor); @@ -184,9 +184,9 @@ class XlaDevice : public LocalDevice { // Two convenient methods to get the underlying device context. 
// Get the default device context, created by the first // shape_representation_fn. - StatusOr GetDeviceContextDefault(); + StatusOr GetDeviceContextDefault(); // Get the device context given the index. - StatusOr GetDeviceContextWithIndex(int index); + StatusOr GetDeviceContextWithIndex(int index); // Instructs this XlaDevice to set a AcceleratorDeviceInfo, which holds extra // information for GPU and TPU devices. @@ -214,7 +214,7 @@ class XlaDevice : public LocalDevice { // Return a vector of device context, ordered by the sequence in the given // shape_representation_fns. - StatusOr> GetDeviceContextLocked() + StatusOr> GetDeviceContextLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Handles error when RefreshStatus sees !status.ok(). @@ -260,8 +260,8 @@ class XlaDevice : public LocalDevice { // calls to EnsureDeviceContextOk. The number of device conetexts is based on // the number of shape representation functions in XlaDevice::Options. If // accelerator_device_info_ is non-null, this pointer is also filled in to - // that struct. XlaDeviceContext is a ref-counted object. - std::vector device_contexts_ TF_GUARDED_BY(mu_); + // that struct. DeviceContext is a ref-counted object. + std::vector device_contexts_ TF_GUARDED_BY(mu_); // Holds extra information for GPU and TPU devices, e.g. the device context. bool use_accelerator_device_info_ TF_GUARDED_BY(mu_) = false; diff --git a/tensorflow/compiler/jit/xla_device_compiler_client.cc b/tensorflow/compiler/jit/xla_device_compiler_client.cc new file mode 100644 index 00000000000..37ebffd95dc --- /dev/null +++ b/tensorflow/compiler/jit/xla_device_compiler_client.cc @@ -0,0 +1,114 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_device_compiler_client.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/client/local_client.h" + +namespace tensorflow { +namespace { +std::vector GetShapePointers( + absl::Span shapes) { + std::vector shape_ptrs; + shape_ptrs.reserve(shapes.size()); + for (const auto& shape : shapes) { + shape_ptrs.push_back(&shape); + } + return shape_ptrs; +} +} // namespace + +StatusOr> +XlaDeviceCompilerClient::BuildExecutable( + const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& result) { + VLOG(2) << "Compiling to xla::LocalExecutable."; + + std::vector argument_layouts = + GetShapePointers(result.xla_input_shapes); + xla::ExecutableBuildOptions build_options = GetExecutableBuildOptions( + options, result, client_->default_device_ordinal()); + TF_ASSIGN_OR_RETURN( + auto executables, + client_->Compile(*result.computation, argument_layouts, build_options)); + TF_RET_CHECK(executables.size() == 1); + return std::move(executables[0]); +} + +StatusOr XlaDeviceCompilerClient::SerializeExecutable( + const xla::LocalExecutable& executable) { + if (executable.executable() == nullptr) { + return errors::FailedPrecondition( + "Executable not found for serialization."); + } + + VLOG(1) + << "Exporting xla::LocalExecutable as an xla::AotCompilationResult and " + "serializing it to string."; + xla::Compiler* compiler = client_->backend().compiler(); + auto exported = compiler->Export(executable.executable()); + if (exported.ok()) { + return (*exported)->SerializeAsString(); + } + + return exported.status(); +} + +StatusOr XlaDeviceCompilerClient::BuildSerializedExecutable( + const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& result) { + VLOG(2) << "Compiling to xla::AotCompilationResult and serializing it"; + + std::vector argument_layouts = + GetShapePointers(result.xla_input_shapes); + xla::ExecutableBuildOptions build_options = GetExecutableBuildOptions( + options, result, client_->default_device_ordinal()); + TF_ASSIGN_OR_RETURN( + std::vector> aot_results, + client_->CompileAheadOfTime(*result.computation, argument_layouts, + build_options)); + TF_RET_CHECK(aot_results.size() == 1); + return aot_results[0]->SerializeAsString(); +} + +StatusOr> +XlaDeviceCompilerClient::LoadExecutable( + const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& result, + const std::string& serialized_executable) { + VLOG(2) << "Loading xla::LocalExecutable from a serialized " + "xla::AotCompilationResult"; + + xla::ExecutableBuildOptions build_options = GetExecutableBuildOptions( + options, result, client_->default_device_ordinal()); + return client_->Load(serialized_executable, build_options); +} + +void XlaDeviceCompilerClient::WaitForProgramsToFinish() { + for (auto* executor : client_->backend().stream_executors()) { + bool ok = executor->SynchronizeAllActivity(); + if (!ok) { + LOG(ERROR) << "Error synchronizing activity while waiting for all " + "programs to complete"; + } + } +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_device_compiler_client.h b/tensorflow/compiler/jit/xla_device_compiler_client.h new file mode 100644 index 00000000000..01325cfcd62 --- /dev/null +++ b/tensorflow/compiler/jit/xla_device_compiler_client.h @@ -0,0 +1,68 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_COMPILER_CLIENT_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_COMPILER_CLIENT_H_ + +#include +#include +#include + +#include "tensorflow/compiler/jit/device_compiler_client.h" +#include "tensorflow/compiler/xla/client/local_client.h" + +namespace tensorflow { + +class XlaDeviceCompilerClient + : public DeviceCompilerClient { + public: + explicit XlaDeviceCompilerClient(xla::LocalClient* client) + : client_(client) {} + + StatusOr> BuildExecutable( + const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& result) override; + + // Returns a serialized AOT result obtained by exporting the available + // `executable` using the XlaCompiler. + StatusOr SerializeExecutable( + const xla::LocalExecutable& executable) override; + + // Returns a serialized AOT result obtained by compiling `result` into an AOT + // result. + StatusOr BuildSerializedExecutable( + const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& result) override; + + // Loads a serialized AOT result (`serialized_executable`) into an + // xla::LocalExecutable and returns it. + StatusOr> LoadExecutable( + const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& result, + const std::string& serialized_executable) override; + + void WaitForProgramsToFinish() override; + + xla::LocalClient* client() const override { return client_; } + + private: + xla::LocalClient* const client_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaDeviceCompilerClient); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_COMPILER_CLIENT_H_ diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 0a1428d8e65..ee00464178f 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -227,7 +227,7 @@ void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, if (device_to_host_stream_) { device_to_host_stream = device_to_host_stream_; } else { - stream_executor::port::StatusOr ptr_or_status = + tsl::StatusOr ptr_or_status = client_->mutable_backend()->BorrowStream( stream_->parent()->device_ordinal()); if (!ptr_or_status.status().ok()) { diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index 38f086cd1d1..573399b34e2 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -16,6 +16,7 @@ limitations under the License. // Registers the XLA_GPU device, which is an XlaDevice instantiation that runs // operators using XLA via the XLA "CUDA" or "ROCM" (GPU) backend. +#include #include #include "absl/memory/memory.h" @@ -29,8 +30,8 @@ limitations under the License. 
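To make the intent of the new `XlaDeviceCompilerClient` above concrete, here is a sketch of the serialize/reload round-trip it supports. Usual includes and `namespace tensorflow` are assumed; `RoundTripExecutable` is hypothetical and the step that actually persists the bytes is elided.

```
// Sketch: AOT-serialize a compiled cluster, then load it back as a
// LocalExecutable without recompiling.
Status RoundTripExecutable(xla::LocalClient* client,
                           const XlaCompiler::Options& options,
                           const XlaCompiler::CompilationResult& result) {
  XlaDeviceCompilerClient compiler_client(client);

  // Compile straight to a serialized xla::AotCompilationResult.
  TF_ASSIGN_OR_RETURN(std::string serialized,
                      compiler_client.BuildSerializedExecutable(options, result));

  // ... persist `serialized` somewhere (e.g. via a DeviceExecutablePersistor) ...

  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::LocalExecutable> executable,
                      compiler_client.LoadExecutable(options, result, serialized));
  return executable != nullptr ? OkStatus()
                               : errors::Internal("Loaded executable is null.");
}
```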
#include "tensorflow/compiler/jit/xla_platform_info.h" #include "tensorflow/compiler/tf2xla/layout_util.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_init.h" #include "tensorflow/core/common_runtime/device_factory.h" -#include "tensorflow/core/common_runtime/gpu/gpu_init.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { @@ -51,7 +52,7 @@ Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector* devices) { } auto platform = - se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName()); + se::MultiPlatformManager::PlatformWithName(se::GpuPlatformName()); if (!platform.ok()) { // Treat failures as non-fatal; there might not be a GPU in the machine. VLOG(1) << "Failed to create XLA_GPU device: " << platform.status(); @@ -100,7 +101,7 @@ Status XlaGpuDeviceFactory::CreateDevices( (void)registrations; auto platform = - se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName()); + se::MultiPlatformManager::PlatformWithName(se::GpuPlatformName()); if (!platform.ok()) { // Treat failures as non-fatal; there might not be a GPU in the machine. VLOG(1) << "Failed to create XLA_GPU device: " << platform.status(); @@ -155,10 +156,10 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory); // Kernel registrations -constexpr std::array kAllXlaGpuTypes = { +constexpr std::array kAllXlaGpuTypes = { {DT_UINT8, DT_QUINT8, DT_UINT16, DT_INT8, DT_QINT8, DT_INT16, DT_INT32, DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, - DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}}; + DT_COMPLEX128, DT_BOOL, DT_BFLOAT16, DT_FLOAT8_E5M2, DT_FLOAT8_E4M3FN}}; REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes); REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_GPU, XlaCompileOp, kAllXlaGpuTypes); diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 3c08eb29fcb..450478d6ef9 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -17,7 +17,9 @@ limitations under the License. #include #include +#include #include +#include #include "absl/algorithm/container.h" #include "absl/cleanup/cleanup.h" @@ -29,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/common_runtime/function.h" @@ -42,7 +45,9 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/util/stream_executor_util.h" +#include "tensorflow/tsl/platform/status.h" namespace tensorflow { namespace { @@ -99,6 +104,9 @@ VariableInfo::~VariableInfo() { if (lock_held()) { var()->mu()->unlock(); } + if (shared_lock_held()) { + var()->mu()->unlock_shared(); + } // Unref the variable so it can be released by ResourceManager. 
var()->Unref(); @@ -109,6 +117,15 @@ Status GetVariableInfosFromInputs(ResourceMgr* rm, DeviceBase* dev, absl::Span inputs, absl::Span variable_indices, std::vector* result) { + return GetVariableInfosFromInputs(rm, dev, inputs, variable_indices, nullptr, + result); +} + +Status GetVariableInfosFromInputs(ResourceMgr* rm, DeviceBase* dev, + absl::Span inputs, + absl::Span variable_indices, + const std::set* variables_updated, + std::vector* result) { result->clear(); result->reserve(variable_indices.size()); for (int var_idx : variable_indices) { @@ -131,8 +148,12 @@ Status GetVariableInfosFromInputs(ResourceMgr* rm, DeviceBase* dev, *ptr = new Var(DT_INVALID); return OkStatus(); })); - result->emplace_back(var_idx, handle.name(), variable, - handle.definition_stack_trace()); + VariableInfo& variable_info = result->emplace_back( + var_idx, handle.name(), variable, handle.definition_stack_trace()); + if (variables_updated != nullptr && + variables_updated->find(var_idx) == variables_updated->end()) { + variable_info.set_read_only(); + } } return OkStatus(); } @@ -181,10 +202,17 @@ Status LockVariables(absl::Span variables) { // TODO(b/128495870) Add support for passing aliased resource variables. return errors::Unimplemented("Duplicate variable passed to XLA cluster"); } - VLOG(4) << "Acquiring lock for variable " - << reinterpret_cast(variable); - mu->lock(); - variables[i]->set_lock_held(); + if (variables[i]->read_only()) { + VLOG(4) << "Acquiring reader lock for variable " + << reinterpret_cast(variable); + mu->lock_shared(); + variables[i]->set_shared_lock_held(); + } else { + VLOG(4) << "Acquiring lock for variable " + << reinterpret_cast(variable); + mu->lock(); + variables[i]->set_lock_held(); + } prev = mu; } VLOG(4) << "Finished acquiring variable locks."; @@ -512,7 +540,7 @@ Status XlaComputationLaunchContext::PopulateOutputs( } std::shared_ptr definition_event; - if (use_multiple_streams_) { + if (use_multiple_streams_ && stream) { definition_event = std::make_shared(stream->parent()); if (!definition_event->Init()) { return errors::Internal("Failed to initialize tensor definition event."); @@ -634,6 +662,20 @@ Status XlaComputationLaunchContext::PopulateOutputs( return OkStatus(); } +Status CreateVariableInfoLookup( + absl::Span variable_args, + absl::flat_hash_map& variable_info_lookup) { + for (const VariableInfo& info : variable_args) { + if (!(!info.var() || info.lock_held() || info.shared_lock_held())) { + return errors::Internal( + "Need to hold the lock on resource variables " + "before calling BuildXlaCompilerArguments"); + } + variable_info_lookup.emplace(info.index(), &info); + } + return OkStatus(); +} + StatusOr> XlaComputationLaunchContext::BuildXlaCompilerArguments( absl::Span must_be_constant_idxs, @@ -662,13 +704,7 @@ XlaComputationLaunchContext::BuildXlaCompilerArguments( } absl::flat_hash_map variable_info_lookup; - for (const VariableInfo& info : variable_args) { - CHECK(!info.var() || info.lock_held()) - << "Need to hold the lock on resource variables " - "before calling BuildXlaCompilerArguments"; - variable_info_lookup.emplace(info.index(), &info); - } - + TF_CHECK_OK(CreateVariableInfoLookup(variable_args, variable_info_lookup)); for (int64_t input_num = 0; input_num < inputs.size(); ++input_num) { const Tensor* input = inputs[input_num]; diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 0f35d9d020f..e09c910c0ea 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ 
b/tensorflow/compiler/jit/xla_launch_util.h @@ -18,7 +18,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_ #define TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_ -#include "tensorflow/compiler/jit/xla_compilation_cache.h" +#include +#include + #include "tensorflow/compiler/jit/xla_tensor.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -70,6 +72,14 @@ class VariableInfo { bool lock_held() const { return lock_held_; } void set_lock_held() { lock_held_ = true; } + // Returns true if the resource variable reader lock was successfully acquired + // by this thread. + bool shared_lock_held() const { return shared_lock_held_; } + void set_shared_lock_held() { shared_lock_held_ = true; } + + bool read_only() const { return read_only_; } + void set_read_only() { read_only_ = true; } + const std::optional& definition_stack_trace() const { return definition_stack_trace_; } @@ -86,6 +96,11 @@ class VariableInfo { // thread safety analysis. Instead we use a boolean flag and release the lock // in the VariableInfo destructor. bool lock_held_ = false; + bool shared_lock_held_ = false; + + // Whether this variable is going to be mutated. Left false if the caller + // doesn't provide this information. + bool read_only_ = false; }; // Creates a list of updated resource variables. @@ -113,6 +128,8 @@ Status SnapshotResourceVariables(OpKernelContext* ctx, // // `variables` is allowed to contain instances that don't track a resource // variable (i.e. variables[i].var() can be null for some i). +// +// If the variable is read_only(), only acquires reader locks. Status LockVariables(absl::Span variables) TF_EXCLUSIVE_LOCK_FUNCTION(); Status LockVariables(absl::Span variables) @@ -121,11 +138,25 @@ Status LockVariables(absl::Span variables) // Returns a vector of VariableInfo instances for the resource variable inputs, // given that *all* inputs are in `inputs`. The input indices for the resource // variable inputs are in `variable_indices`. +// +// When using the VariableInfos generated by this version, all variables would +// be writer-locked. Status GetVariableInfosFromInputs(ResourceMgr* rm, DeviceBase* dev, absl::Span inputs, absl::Span variable_indices, std::vector* result); +// variables_updated is a set containing the indices of the variables that are +// going to be mutated. If variables_updated is empty, then in LockVariables all +// variables would only be reader-locked. If variables_updated is null, then we +// consider this information unknown and will acquire writer-lock for all +// variables. +Status GetVariableInfosFromInputs(ResourceMgr* rm, DeviceBase* dev, + absl::Span inputs, + absl::Span variable_indices, + const std::set* variables_updated, + std::vector* result); + // Returns pointers to inputs stored in `ctx`. std::vector InputsFromContext(OpKernelContext* ctx); @@ -247,6 +278,10 @@ class XlaTensorBuffer : public TensorBuffer { Allocator* allocator_; }; +Status CreateVariableInfoLookup( + absl::Span variable_args, + absl::flat_hash_map& variable_info_lookup); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_ diff --git a/tensorflow/compiler/jit/xla_platform_info.cc b/tensorflow/compiler/jit/xla_platform_info.cc index 0cad9e63669..0ed41b41833 100644 --- a/tensorflow/compiler/jit/xla_platform_info.cc +++ b/tensorflow/compiler/jit/xla_platform_info.cc @@ -15,13 +15,24 @@ limitations under the License. 
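A small sketch of how a launch path is expected to use the read-only variable support from the `xla_launch_util` hunks above: pass the set of updated variable indices so that everything else is only reader-locked. Usual includes and `namespace tensorflow` are assumed; `LockClusterVariables` is a hypothetical helper.

```
// Sketch: variables the cluster never writes get a shared (reader) lock.
Status LockClusterVariables(OpKernelContext* ctx,
                            absl::Span<const int> variable_indices,
                            const std::set<int>& updated_indices,
                            std::vector<VariableInfo>* variable_infos) {
  // Infos whose index is missing from `updated_indices` are marked read_only().
  TF_RETURN_IF_ERROR(GetVariableInfosFromInputs(
      ctx->resource_manager(), ctx->device(), InputsFromContext(ctx),
      variable_indices, &updated_indices, variable_infos));
  // LockVariables() takes mu->lock_shared() for read_only() entries and
  // mu->lock() for the rest; the VariableInfo destructor releases either kind.
  return LockVariables(absl::MakeSpan(*variable_infos));
}
```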
#include "tensorflow/compiler/jit/xla_platform_info.h" +#include +#include #include +#include +#include "tensorflow/compiler/jit/device_compiler_client.h" +#include "tensorflow/compiler/jit/device_executable_persistor.h" #include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/compiler/jit/xla_device_compiler_client.h" #include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/core/tpu/tpu_defs.h" namespace tensorflow { +namespace { +using XlaDeviceCompiler = + DeviceCompiler; +} // namespace xla::StatusOr>> ParseVisibleDeviceList( absl::string_view visible_device_list) { @@ -44,26 +55,35 @@ xla::StatusOr>> ParseVisibleDeviceList( return {{gpu_ids}}; } -Status BuildXlaCompilationCache(DeviceBase* device, FunctionLibraryRuntime* flr, - const XlaPlatformInfo& platform_info, - XlaCompilationCache** cache) { - XlaCompilationCache::Config cache_config( +Status BuildXlaDeviceCompiler(DeviceBase* device, FunctionLibraryRuntime* flr, + const XlaPlatformInfo& platform_info, + XlaDeviceCompiler** xla_device_compiler) { + using XlaDeviceExecutablePersistor = + DeviceExecutablePersistor; + XlaDeviceExecutablePersistor::Config persistor_config( GetMarkForCompilationPassFlags()->tf_xla_persistent_cache_directory, GetMarkForCompilationPassFlags()->tf_xla_disable_strict_signature_checks, GetMarkForCompilationPassFlags()->tf_xla_persistent_cache_prefix); if (platform_info.xla_device_metadata()) { - *cache = new XlaCompilationCache( - std::move(cache_config), platform_info.xla_device_metadata()->client(), + auto persistor = std::make_unique( + std::move(persistor_config), platform_info.xla_device_metadata()->jit_device_type()); + auto compiler_client = std::make_unique( + platform_info.xla_device_metadata()->client()); + *xla_device_compiler = + new XlaDeviceCompiler(std::move(persistor), std::move(compiler_client)); return OkStatus(); } // TFRT-TPU is used if device type is `DEVICE_TPU` and platform_info does not // have `xla_device_metadata`. 
if (platform_info.device_type() == DEVICE_TPU) { - *cache = new XlaCompilationCache(std::move(cache_config), nullptr, - DeviceType(DEVICE_TPU_XLA_JIT)); + auto persistor = std::make_unique( + std::move(persistor_config), DeviceType(DEVICE_TPU_XLA_JIT)); + auto compiler_client = std::make_unique(nullptr); + *xla_device_compiler = + new XlaDeviceCompiler(std::move(persistor), std::move(compiler_client)); return OkStatus(); } @@ -117,9 +137,14 @@ Status BuildXlaCompilationCache(DeviceBase* device, FunctionLibraryRuntime* flr, return errors::InvalidArgument("No JIT device registered for ", platform_info.device_type().type()); } - *cache = new XlaCompilationCache( - std::move(cache_config), client.value(), + + auto persistor = std::make_unique( + std::move(persistor_config), DeviceType(registration->compilation_device_name)); + auto compiler_client = + std::make_unique(client.value()); + *xla_device_compiler = + new XlaDeviceCompiler(std::move(persistor), std::move(compiler_client)); return OkStatus(); } @@ -174,16 +199,16 @@ std::shared_ptr GetAllocator( } XlaCompiler::Options GenerateCompilerOptions( - const XlaCompilationCache& cache, + const XlaDeviceCompiler& xla_device_compiler, const FunctionLibraryRuntime& function_library, DeviceBase* device, se::Stream* stream, const XlaPlatformInfo& platform_info, bool has_ref_vars) { XlaCompiler::Options options; - options.client = static_cast(cache.client()); + options.client = static_cast(xla_device_compiler.client()); if (stream != nullptr) { options.device_ordinal = stream->parent()->device_ordinal(); } - options.device_type = cache.device_type(); + options.device_type = xla_device_compiler.device_type(); options.flib_def = function_library.GetFunctionLibraryDefinition(); options.graph_def_version = function_library.graph_def_version(); options.allow_cpu_custom_calls = @@ -201,11 +226,11 @@ XlaCompiler::Options GenerateCompilerOptions( } XlaCompiler::Options GenerateTfrtTpuCompilerOptions( - const XlaCompilationCache& cache, + const XlaDeviceCompiler& xla_device_compiler, const FunctionLibraryRuntime& function_library) { XlaCompiler::Options options; // TODO(b/238830423): consider device_ordinal and shape_determination_fns. - options.device_type = cache.device_type(); + options.device_type = xla_device_compiler.device_type(); options.flib_def = function_library.GetFunctionLibraryDefinition(); options.graph_def_version = function_library.graph_def_version(); options.allow_cpu_custom_calls = false; diff --git a/tensorflow/compiler/jit/xla_platform_info.h b/tensorflow/compiler/jit/xla_platform_info.h index 4341bc8e394..86ccee99bce 100644 --- a/tensorflow/compiler/jit/xla_platform_info.h +++ b/tensorflow/compiler/jit/xla_platform_info.h @@ -16,7 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_XLA_PLATFORM_INFO_H_ #define TENSORFLOW_COMPILER_JIT_XLA_PLATFORM_INFO_H_ -#include "tensorflow/compiler/jit/xla_compilation_cache.h" +#include +#include + +#include "tensorflow/compiler/jit/device_compiler.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/xla/stream_executor/tf_allocator_adapter.h" @@ -88,9 +91,11 @@ StatusOr>> ParseVisibleDeviceList( absl::string_view visible_device_list); // Returns created XLA compilation cache. 
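Putting the pieces in this file together, a sketch of how an XLA device compiler is assembled outside of `BuildXlaDeviceCompiler`. Usual includes and `namespace tensorflow` are assumed; `MakeXlaDeviceCompiler` is hypothetical, the comment labels on the config arguments are descriptive only, and an empty cache directory simply leaves persistence off.

```
// Sketch: persistor + compiler client -> DeviceCompiler.
using XlaDeviceCompiler =
    DeviceCompiler<xla::LocalExecutable, xla::LocalClient>;
using XlaDeviceExecutablePersistor =
    DeviceExecutablePersistor<xla::LocalExecutable, xla::LocalClient>;

XlaDeviceCompiler* MakeXlaDeviceCompiler(xla::LocalClient* client,
                                         const DeviceType& compilation_device) {
  XlaDeviceExecutablePersistor::Config persistor_config(
      /*persistent_cache_directory=*/"",
      /*disable_strict_signature_checks=*/false,
      /*persistence_prefix=*/"xla");
  auto persistor = std::make_unique<XlaDeviceExecutablePersistor>(
      std::move(persistor_config), compilation_device);
  auto compiler_client = std::make_unique<XlaDeviceCompilerClient>(client);
  return new XlaDeviceCompiler(std::move(persistor), std::move(compiler_client));
}
```

The resulting object is what `GenerateCompilerOptions` above reads its `client()` and `device_type()` from.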
-Status BuildXlaCompilationCache(DeviceBase* dev, FunctionLibraryRuntime* flr, - const XlaPlatformInfo& platform_info, - XlaCompilationCache** cache); +Status BuildXlaDeviceCompiler( + DeviceBase* dev, FunctionLibraryRuntime* flr, + const XlaPlatformInfo& platform_info, + DeviceCompiler** + xla_device_compiler); // Returns information about the platform from kernel context. XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device); @@ -109,14 +114,16 @@ std::shared_ptr GetAllocator( // Returns created options for the XLA compiler, and writes the used allocator // into `tf_allocator_adapter`. XlaCompiler::Options GenerateCompilerOptions( - const XlaCompilationCache& cache, + const DeviceCompiler& + xla_device_compiler, const FunctionLibraryRuntime& function_library, DeviceBase* device, se::Stream* stream, const XlaPlatformInfo& platform_info, bool has_ref_vars); // Returns created options for XLA compiler when TFRT-TPU is used. XlaCompiler::Options GenerateTfrtTpuCompilerOptions( - const XlaCompilationCache& cache, + const DeviceCompiler& + xla_device_compiler, const FunctionLibraryRuntime& function_library); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_tpu_device.cc b/tensorflow/compiler/jit/xla_tpu_device.cc index a7830e51312..e6047e68bde 100644 --- a/tensorflow/compiler/jit/xla_tpu_device.cc +++ b/tensorflow/compiler/jit/xla_tpu_device.cc @@ -62,7 +62,7 @@ StatusOr TpuShapeRepresentation( ApiConverter::StackHelper se_shape(xla_shape); ApiConverter::StackHelper tpu_shape; StatusHelper status; - tpu::ExecutorApiFn()->XlaShapeToTpuShapeRepresentationFn( + stream_executor::tpu::ExecutorApiFn()->XlaShapeToTpuShapeRepresentationFn( &se_shape.value, type, use_fast_memory, &tpu_shape.value, status.c_status); if (!status.status().ok()) { @@ -93,7 +93,7 @@ Status TpuPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) { StatusHelper status; ApiConverter::StackHelper se_shape(on_device_shape); ApiConverter::StackHelper tpu_shape; - tpu::ExecutorApiFn()->XlaShapeToTpuPaddedShapeFn( + stream_executor::tpu::ExecutorApiFn()->XlaShapeToTpuPaddedShapeFn( &se_shape.value, &tpu_shape.value, status.c_status); if (!status.ok()) { return status.status(); diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index 42c78ff6143..6995615d10f 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -10,6 +10,7 @@ load( ) package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:public"], licenses = ["notice"], ) @@ -74,7 +75,7 @@ cc_library( deps = [ ":init_mlir", "//tensorflow/compiler/mlir/lite:tensorflow_lite", - "//tensorflow/compiler/mlir/lite:tf_tfl_passes", + "//tensorflow/compiler/mlir/lite:tf_tfl_passes", # buildcleaner:keep "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util", "//tensorflow/compiler/mlir/tensorflow:mlprogram_util", @@ -82,20 +83,22 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tensorflow_test_passes", "//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass", "//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo", - "//tensorflow/compiler/mlir/tensorflow:tf_saved_model_passes", + "//tensorflow/compiler/mlir/tensorflow:tf_saved_model_passes", # buildcleaner:keep "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops", "//tensorflow/compiler/mlir/tosa:tf_passes", "//tensorflow/compiler/mlir/tosa:tf_tfl_passes", "//tensorflow/compiler/mlir/tosa:tfl_passes", 
"//tensorflow/compiler/mlir/xla:tf_xla_passes", "//tensorflow/compiler/mlir/xla:xla_legalize_tf", - "//tensorflow/compiler/mlir/xla:xla_passes", + "//tensorflow/compiler/xla/mlir/framework/transforms:passes", "//tensorflow/compiler/xla/mlir_hlo:all_passes", "//tensorflow/compiler/xla/mlir_hlo:hlo_dialect_registration", - "//tensorflow/core:lib", - "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TosaDialect", + "@llvm-project//mlir:Transforms", "@stablehlo//:register", ], ) @@ -188,7 +191,7 @@ tf_cc_binary( "//tensorflow/compiler/mlir/tensorflow:tensorflow_test_passes", "//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass", "//tensorflow/compiler/mlir/xla:tf_xla_passes", - "//tensorflow/compiler/mlir/xla:xla_passes", + "//tensorflow/compiler/xla/mlir/framework/transforms:passes", ], ) @@ -201,16 +204,14 @@ tf_cc_binary( ":passes", "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow:tensorflow_reduce_patterns_inc_gen", "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops", "//tensorflow/compiler/mlir/xla:tf_xla_passes", "//tensorflow/compiler/mlir/xla:xla_legalize_tf", - "//tensorflow/compiler/mlir/xla:xla_passes", + "//tensorflow/compiler/xla/mlir/framework/transforms:passes", "//tensorflow/compiler/xla/mlir_hlo:all_passes", "//tensorflow/compiler/xla/mlir_hlo:hlo_dialect_registration", - "//tensorflow/core:lib", "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:IR", "@llvm-project//mlir:MlirReduceLib", @@ -225,25 +226,18 @@ tf_cc_binary( srcs = ["tf_mlir_translate_main.cc"], deps = [ ":init_mlir", - "//tensorflow/compiler/mlir/tensorflow:export_graphdef", - "//tensorflow/compiler/mlir/tensorflow:import_model", - "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/compiler/mlir/tensorflow:tf_xla_mlir_translate", "//tensorflow/compiler/mlir/tensorflow:translate_cl_options", "//tensorflow/compiler/mlir/tensorflow:translate_lib", "//tensorflow/compiler/mlir/tensorflow:translate_registration", "//tensorflow/compiler/mlir/tensorflow:translate_tf_dialect_op", - "//tensorflow/compiler/mlir/xla:translate_cl_registration", - "//tensorflow/compiler/xla/stream_executor/lib", "//tensorflow/compiler/xla/translate/hlo_to_mhlo:translate_registration", "//tensorflow/compiler/xla/translate/mhlo_to_hlo:translate_registration", - "//tensorflow/core:framework", + "//tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla:translate_registration", "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", "//tensorflow/core:tensorflow", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", - "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", "@llvm-project//mlir:TranslateLib", @@ -258,6 +252,7 @@ tf_cc_test( "//tensorflow/core:ops", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/framework:tensor_testutil", "@llvm-project//mlir:IR", ], ) diff --git a/tensorflow/compiler/mlir/g3doc/_includes/tf_passes.md b/tensorflow/compiler/mlir/g3doc/_includes/tf_passes.md index c0f9508ef76..a7a8bbdf953 100644 --- a/tensorflow/compiler/mlir/g3doc/_includes/tf_passes.md +++ 
b/tensorflow/compiler/mlir/g3doc/_includes/tf_passes.md @@ -931,6 +931,11 @@ Would become the following ops (unimportant attribute, type are omitted): }) {num_cores_per_replica = 1, topology = "", device_assignment = []} ``` ### `-tf-parallel-execute-to-islands`: Lowers device parallel_execute to executor islands + +#### Options +``` +-legacy-graph-export : Determines whether or not this pass should execute logic that is reserved for the legacy graph export pipeline to maintain expected invariants. In the case of this pass, that means manually propagating controls to lifted parallel execute regions to the graph fetch to ensure the ops execute. +``` ### `-tf-promote-resources-to-args`: Promote resources reads/writes to function inputs/outputs. This pass promotes resource accesses in function(s) (by default, the main) to input arguments and outputs of the function(s). @@ -1111,7 +1116,24 @@ tf_device.replicate([%0, %1] as %ri: tensor<*x!tf_type.resource>) {n = 2 : i32} tf_device.return } ``` +### `-tf-replicate-tensor-list-init-ops`: Replicate TensorList init ops for correct shape assignments in shape inference +If we pass the same TensorList to a while op as multiple arguments, or just use +the same TensorList at multiple places and assign different +TensorListSetItem to elements of TensorList, then shape inference is +unable to identify the shape of these args and thus the input TensorList +shape is unidentifiable. +All of these args are supposed to be independent and not related to the original +creation of the TensorList. + +This pass will create multiple instances of TensorList for each arg of the +while op and each use, and thus there will not be a conflict in resolving the +shape of these different inputs. ### `-tf-replicate-to-island`: Lowers device replicate to executor islands + +#### Options +``` +-legacy-graph-export : Determines whether or not this pass should execute logic that is reserved for the legacy graph export pipeline to maintain expected invariants. In the case of this pass, that means manually propagating controls to lifted parallel execute regions to the graph fetch to ensure the ops execute, as well as determining whether or not the islands created by this pass should be split after the replicated ops have been lifted. +``` ### `-tf-resource-device-inference`: Propagates the device attribute on resources from callers to callees. A pass that propagates device assignment of resources on a module. It performs in-function propagation, as well as cross-function propagation from @@ -1549,6 +1571,7 @@ The transformation happens only for on-device variables. The above transformation requires `%arg0`, `%arg1` to have the same device assignment as the `TPUExecute` op. ### `-tf-tpu-parallel-execute-sink-resource-write`: Moves tf.AssignVariableOp consumers of tf_device.parallel_execute into tf_device.parallel_execute regions +### `-tf-tpu-partitioned-op-conversion`: Rewrite all TPU Partitioned ops into their V2 counterparts. ### `-tf-tpu-reorder-replicate-partitioned-inputs`: Reorder replicated and partitioned input ops. This pass rewrites how data parallelism and model parallelism is expressed for inputs.
It reorders `tf.TPUPartitionedInput` (model parallelism) and diff --git a/tensorflow/compiler/mlir/g3doc/xla_gpu_codegen.md b/tensorflow/compiler/mlir/g3doc/xla_gpu_codegen.md deleted file mode 100644 index 1130199fbae..00000000000 --- a/tensorflow/compiler/mlir/g3doc/xla_gpu_codegen.md +++ /dev/null @@ -1,269 +0,0 @@ -# MLIR CodeGen for XLA - - - -XLA operates on `HloInstruction` and performs many optimizations on this -representation, sharing a lot of these between targeted devices. As some point a -linear schedule is computed and the memory buffer is assigned to each value -statically. The device specific codegen operates by traversing this sequence and -calling "emitters" to generate a representation suitable for the device (for -example a single LLVM function per XLA computation on CPU, or a sequence of -"thunks" encapsulating GPU operations and possibly generated PTX when targeting -GPU). - -As a staging step, we're currently in the process of intercepting the process -right after XLA completes the buffer-assignment phase and emit instead an MLIR -module in the `lhlo` dialect. From there we perform the codegen using MLIR -components (Linalg, affine, and GPU dialect mainly) depending on the device. - -Below is the plan of record to incrementally migrate XLA/GPU by using `lhlo` as -the codegen input. - -## Tasks - -| | Host | Device -| ------------- | ------------------------ | ------------------------ -| Input format | HloInstruction* (Task 1) | HloInstruction* (Task 1) -| Output format | xla::Thunk (Task 2) | LLVM IR (Task 3) - -* **Task 1** changes both host and device input format from HloInstruction* to - LHLO. -* **Task 2** changes output format of host from thunks to "some landing pad - for host" (see below). -* **Task 3** migrates device output from LLVM IR to some form of MLIR. It's - optional to this project, and see the section "Migrating Device LLVM IR" for - details. - -This project prioritizes having end-to-end runnable models with LHLO-emitters -enabled as much as possible. This implies that the following order list of -objectives by priority: - -* Make XLA/GPU runnable with LHLO emitters, with existing Thunks and emitters - unmodified. -* Eliminate the references to HloInstruction\* in LHLO, case by case: - * Switch a legacy emitter to an MLIR-based emitter (e.g. Linalg), or - * Mechanically translate the existing emitter to take MLIR representation - (migrate to Standard with GPU Dialect). - -## Migrating Thunks (Task 2) - -xla::gpu::Thunk is a data structure that: - -* Can be called into from the host (xla::gpu::Thunk::ExecuteOnStream()). -* Carries various data in its subclasses. -* Interacts with BufferAllocation::Slice and StreamExecutor. -* Launches kernels -* Calls into all runtime libraries. - -The cost of that includes: - -* Representing op-specific configuration data (e.g. convolution configs). -* Migrating op shape and operand shapes. -* Representing a tree of thunks (while, condition, etc). - -The migration work is independent from LHLO / emitter migration. Under limited -resources, it's prioritized behind LHLO / emitter migration. - -We have several choices on how to lower the host-side part from LHLO: - -* TFRT - * (Pro) great CUDA and HIP wrappers for use. - * (Pro) easy to implement library calls (cuDNN, cuBLAS, cuFFT, etc), as - TFRT ops are interpreted by C++ code. - * (Con) host side is under development and not tested. -* Jitted CPU code - * (Pro) great lower-ability. Create a few loops and conditions and it's - done. 
- * (Con) GPUDialect doesn't yet model chains/streams/asynchronicity/device - allocation. - * (Con) CUDA / HIP runtime support is minimal (toolkit path, version, - dynamic loading, etc). -* Existing (interpreting) XLA runtime - -Decision: adopt TFRT, but also support jitting CPU code in TFRT. - -## Migrating Device LLVM IR (Task 3) - -An elemental emitter generates target op by filling it element by element. Each -output element depends on a set of elements from the operands. All elements are -described by combining the buffer with dynamic indices. It's sufficient to -describe almost all "math" ops, but for performance reasons only a large subset -of "math" ops are implemented directly in (Cpu|Gpu)ElementalIrEmitter. - -ElementalIrEmitter is unique in that: - -* A large portion of the code is shared between XLA/GPU and CPU. -* It represents a large portion of ops seen in models, including all - element-wise ops. -* Most fusions solely depend on ElementalIrEmitter. -* It's structurally simple, as it describes a data dependency DAG between op - elements and operand elements. -* It's mostly portable and high-level (e.g. unlike GPU kReduce and GPU kCopy). -* Dynamic shape support is easy for at least element-wise ops. - -Now, for all ops, elementally-emitted or not, there are several flavors of the -end state of each XLA op: - -1. Device code stays as LLVM IR. -1. Refactor the old emitter to be like LHLO -> MLIR LLVM Dialect: - * (Cost) Will be throw-away work if we want to ultimately migrate to - Standard. - * (Benefit) It is easy and mechanical. Can be done in a short period. - * (Benefit) It doesn't benefit more compared to (1). -1. Refactor old emitters to be like LHLO -> MLIR GPU + Standard + Loops: - * (Cost) Lifting existing emitters to Standard introduces some challenges. - Pointers and GEPs need to be converted to MemRefs and SubViews. Ensuring - amdgpu completeness is another one. - * (Cost) XLA/GPU heavily relies on LLVM metadata: - * `range` for block/thread indices. - * `align`, `dereferenceable`, `invariant.load`, `alias.scope`, - `noalias` for load/stores. - * `llvm.loop.unroll.disable`, `llvm.loop.unroll.full`, - `llvm.loop.vectorize.enable` for sequential loops. - * (Benefit) Can be long-term. More portable. -1. Refactor old emitters to be LHLO -> Linalg, and write new Linalg emitters - * (Cost) This is case by case. Compared to previous options, a new - implementation that matches XLA's performance needs to go through the - benchmark <-> optimize workflow, which can be a significant cost for - some ops. - * (Benefit) unified stack; community support; portability; more - optimization potentials. - -Conclusions: - -* Don't go for (2). (1) or (3) are just better than (2). (2) costs more than - (1), since it requires a lot of mechanical refactoring. With (1) we can - still achieve the goal of enabling XLA to pick up MLIR emitters. This is by - doing LHLO -> LLVM IR -> run legacy device emitters. -* ElementalIrEmitter ops go for (4), but not incrementally. There is no way to - do it op by op, because all elementally-emitted ops are connected into the - same graph. This work can also serve as a unification point of several - on-going forces (the kernel generator, Linalg). -* All other ops go for (1). As a stretch goal, they might be migrated to (3) - or (4). - -## Prioritization - -While all three tasks mentioned above are parallelizable, under limited -resources they have to be serialized. The prioritization focuses on visible -results for completion of each task. 
- -The prioritization is: Task1 (LHLO for legacy emitters) > Task 2 (Thunks) > Task -3 (MLIR emitters). - -By the end of Task 1, users of XLA can generate an LHLO (e.g. kernel generator) -and execute them. The compilation format will not be serializable MLIR. - -By the end of Task 2, LHLO lowers to proper, serializable MLIR. This enables -offline compilation. - -By the end of Task 3, all XLA emitters are MLIR-based in its implementation. - -## Detailed Design - -### Step 1: (Task 1) Complete LHLO and Make Legacy Emitters Take LHLO - -This step makes all existing XLA/GPU emitters interact with MLIR ops. This step -is pure refactoring and NFC. - -This step is mostly mechanical, but it's worth noticing the following -discrepancies between an unnested HloComputation and LHLO: - -* Each HloInstruction has direct access to its operands (a data-flow DAG). On - contrary, each LHLO op only has access to its operand buffers (a bipartite - between ops and buffers). LHLO ops have to go through use-def chains to - access their operand ops. -* Unnested legacy emitters empirically almost never access their operands. The - only exception is kReduce. -* Unnested legacy emitters access BufferAssignment only for getting slices, - not for accessing aux data structures like dataflow\_analysis() or - alias\_analysis(). llvm\_ir builds its own alias\_analysis() based on slice - information. - -The conclusion is that LHLO should fit right-in without major hassle. - -### Step 2: (Optional) Profiling Support - -**This step is only needed if we start to discard some of the XLA Thunk logic -(see the next step).** - -Before actually turning on any MLIR-based emitters, we need profiling for -MLIR-based emitters. - -Currently XLA performs its own profiling by calling into StreamExecutor's timer. -The timer under the hood inserts two events before and after a kernel launch, -and measures the sync time between these two events. - -There are roughly three approaches to support profiling in MLIR: - -* Run a profiler end-to-end -* Add a profile op for each op in LHLO, using an injected profiler. - -The "end-to-end" approach is transparent to MLIR, but suffers the same problem -that makes XLA not use it in the first place: library calls collected by a -profiler (nvprof/...) can't easily relate to HLO ops. For example, cuDNN -launches multiple kernels for each HLO, and it's hard to tell which kernels -correspond to which HLO. - -The "injected profiler" approach requires: - -* LHLO to take a profiler as a parameter. -* inserting profile.start / profile.end before and after each op. -* a pass from that lowers profile.{start,end} to a C++ implementation. - -The exact profiling can't be easily done for MLIR-generated ops, since: - -* MLIR doesn't have a timer, nor it depends on TFRT / StreamExecutor. -* MLIR doesn't easily call into C functions with complicated parameters. - -### Step 3: (Task 2) Migrating Thunks - -As a note, there are roughly three kinds of thunks: -* KernelThunk, which launches a kernel. -* Control flow thunks, which has host control flow logic (conditional, while, - for, sequence) and launch body kernels. -* Library thunks: cuDNN, cuBLAS, cuFFT, NCCL, etc. - -The plan is: -* Make Thunks (de)serializable. -* Help improve TFRT to a state where it can support these semantics. -* As the state improves, migrate individual thunks incrementally. - -These action items are only partially ordered. The actual execution order / -engineering parallelism is to be evaluated as it goes. 
- -### Step 4: (Task 3) Migrated ElementalIrEmitter - -Once profiling is ready, we can complete and tune all ElementalIrEmitter-based -emitters in MLIR. Then we turn them on by default, assuming that all of these -MLIR-based emitters use a single stream. - -Notice that it's beneficial to migrate XLA/CPU's ElementalIrEmitter as well, -since they share a large portion of the code. - -With all benchmarking and performance hunting done (TODO: define performance -parity), we turn on the new MLIR-based elemental emitter, and delete the legacy -ElementalIrEmitter. - -This step also provides easy fusion transitions (nested ops) for the later -migration. - -### Step 5: Multi-Stream Support or Drop - -We can't delete -[some of the emitters](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/service/gpu/stream_assignment.cc#L140) -until we support it in MLIR, or we drop the feature. It's a relatively large -amount of work in MLIR and a small amount of gain for XLA. We should investigate -current users of multi-stream XLA/GPU users, and try to delete this feature if -reasonable. - -### Step 6: (Task 3) Migrated Device Ops - -This step migrates all unnested ops, then we can delete all unnested emitters. - -This calls on a rewrite/refactor for kCopy and kReduce. kReduce is already -worked on for plenty, so the actual amount of work that needs to be done remains -to be seen. diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 1840df2b1fc..9ce3e6913ad 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -5,6 +5,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_native_cc_ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ # TODO(jpienaar): Make the visibility more restrictive. 
":friends", @@ -19,7 +20,6 @@ package_group( packages = [ "//learning/brain/experimental/mlir/tflite/tfmrt/...", "//learning/brain/mlir/...", - "//third_party/auroraml/...", "//third_party/iree/...", "//tensorflow/compiler/mlir/...", "//tensorflow/lite/python/...", @@ -312,6 +312,7 @@ cc_library( ], deps = [ ":cost_estimators", + ":size_utils", ":tensorflow_lite_op_enums_inc_gen", ":tensorflow_lite_op_interfaces_inc_gen", ":tensorflow_lite_ops_inc_gen", @@ -363,6 +364,31 @@ cc_library( ], ) +cc_library( + name = "size_utils", + srcs = [ + "utils/size_utils.cc", + ], + hdrs = [ + "utils/size_utils.h", + ], + deps = [ + "@llvm-project//mlir:IR", + ], +) + +tf_cc_test( + name = "size_utils_test", + size = "small", + srcs = ["utils/size_utils_test.cc"], + deps = [ + ":size_utils", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@llvm-project//mlir:IR", + ], +) + cc_library( name = "cost_estimators", hdrs = [ @@ -385,13 +411,12 @@ cc_library( "utils/constant_utils.h", ], deps = [ - "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:mangling_util", "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes", - "//tensorflow/compiler/xla/stream_executor/lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:status", + "//tensorflow/tsl/platform:statusor", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", @@ -615,6 +640,7 @@ cc_library( ":lstm_utils", ":nms_utils", ":perception_ops_utils", + ":size_utils", ":stateful_ops_utils", ":tensorflow_lite", ":tensorflow_lite_passes_inc_gen", @@ -706,7 +732,9 @@ cc_library( "transforms/post_quantize.cc", "transforms/prepare_quantize.cc", "transforms/prepare_quantize_dynamic_range.cc", + "transforms/prepare_quantize_helper.cc", "transforms/quantize.cc", + "transforms/quantize_variables.cc", "utils/generated_op_quant_spec_getters.inc", ], hdrs = [ @@ -722,6 +750,7 @@ cc_library( "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/lite/quantization/lite:tfl_to_std", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/compiler/xla/mlir_hlo", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -983,6 +1012,7 @@ cc_library( ":convert_type", ":flatbuffer_tflite_operator_lib", ":low_bit_utils", + ":size_utils", ":tensorflow_lite", "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", @@ -1139,12 +1169,13 @@ tf_cc_binary( "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo", "//tensorflow/compiler/mlir/tensorflow:translate_cl_options", - "//tensorflow/compiler/xla/stream_executor/lib", "//tensorflow/compiler/xla/translate/hlo_to_mhlo:translate", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:errors", "//tensorflow/lite:framework", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/tsl/platform:status", + "//tensorflow/tsl/platform:statusor", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", @@ -1227,15 +1258,12 @@ cc_library( ":common", ":flatbuffer_translate_lib", ":tensorflow_lite", - ":tensorflow_lite_legalize_tf", - ":tensorflow_lite_optimize", - ":tensorflow_lite_quantize", ":tf_tfl_passes", "//tensorflow/cc/saved_model:loader", 
"//tensorflow/compiler/mlir/lite/metrics:error_collector_inst", "//tensorflow/compiler/mlir/lite/quantization:quantization_config", - "//tensorflow/compiler/mlir/lite/stablehlo:mhlo_tfl", "//tensorflow/compiler/mlir/lite/stablehlo:op_stat_pass", + "//tensorflow/compiler/mlir/lite/stablehlo:stablehlo_tfl", "//tensorflow/compiler/mlir/lite/stablehlo:transforms", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:quantize_passes", @@ -1249,12 +1277,12 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tf_saved_model_freeze_variables", "//tensorflow/compiler/mlir/tensorflow:tf_saved_model_passes", "//tensorflow/compiler/mlir/tensorflow:translate_lib", - "//tensorflow/compiler/xla/stream_executor/lib", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/lite/toco:toco_flags_proto_cc", "//tensorflow/lite/tools/optimize:quantize_weights", "//tensorflow/lite/tools/optimize:reduced_precision_support", + "//tensorflow/tsl/platform:statusor", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", diff --git a/tensorflow/compiler/mlir/lite/converter_gen.cc b/tensorflow/compiler/mlir/lite/converter_gen.cc index 994002b0a9b..6a8b8c1f656 100644 --- a/tensorflow/compiler/mlir/lite/converter_gen.cc +++ b/tensorflow/compiler/mlir/lite/converter_gen.cc @@ -249,7 +249,7 @@ static void EmitGetBuiltinOpCode(const std::vector &defs, << " return tflite::BuiltinOperator_" << operator_name << ";\n"; } - os << " return llvm::None;\n" + os << " return std::nullopt;\n" "}\n"; } @@ -335,7 +335,7 @@ static void EmitBuildOperator(const std::vector &defs, << "fbb);\n"; } - os << " return llvm::None;\n" + os << " return std::nullopt;\n" "}\n"; } @@ -499,7 +499,7 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) { << "::VerifyTflRuntimeConstraints(::mlir::Operation *op, bool " "emit_error_on_verify_fail) {\n"; os << " auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n"; - verify_ctx.withOp("top"); + verify_ctx.addSubst("_op", "top"); for (int i = 0, e = op.getNumOperands(); i < e; ++i) { auto &value = op.getOperand(i); diff --git a/tensorflow/compiler/mlir/lite/experimental/common/BUILD b/tensorflow/compiler/mlir/lite/experimental/common/BUILD new file mode 100644 index 00000000000..f94c2478774 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/experimental/common/BUILD @@ -0,0 +1,18 @@ +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") + +cc_library( + name = "outline_operations", + srcs = ["outline_operations.cc"], + hdrs = ["outline_operations.h"], + compatible_with = get_compatible_with_cloud(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + ], +) diff --git a/tensorflow/compiler/mlir/lite/experimental/common/outline_operations.cc b/tensorflow/compiler/mlir/lite/experimental/common/outline_operations.cc new file mode 100644 index 00000000000..d7142f7e6b5 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/experimental/common/outline_operations.cc @@ -0,0 +1,210 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/lite/experimental/common/outline_operations.h" + +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir { +namespace TFL { +namespace common { + +bool IsConstantOrNone(Operation* op) { + return (op->getNumResults() == 1 && + op->getResult(0).getType().isa()) || + matchPattern(op, m_Constant()); +} + +// Pre-order traversal, adding results and BlockArgs to `been_defined` and +// collecting operands not contained within `been_defined`. If we encounter an +// operand that references a Value that has been defined (and added to +// `been_defined`), it is guaranteed that the Value definition is not contained +// in a descendant node of the reference, and given that the input DAG is valid, the +// definition is self-contained within `op` so it is not depended upon. +// Otherwise, the operand must have been defined somewhere above the Subgraph, +// so union with other operand dependencies. +llvm::SmallVector AccumulateOperandsDefinedAbove( + const llvm::SetVector& partition_ops) { + // Assuming that all are topologically sorted.
+ llvm::SetVector been_defined; + llvm::SetVector results; + auto update_from_op = [&](Operation* op) { + been_defined.insert(op->getResults().begin(), op->getResults().end()); + for (Value input : op->getOperands()) { + if (been_defined.contains(input)) { + continue; + } + results.insert(input); + } + }; + for (Operation* op : partition_ops) { + update_from_op(op); + op->walk([&](Block* nested_block) { + been_defined.insert(nested_block->getArguments().begin(), + nested_block->getArguments().end()); + for (Operation& op : nested_block->getOperations()) update_from_op(&op); + }); + } + return SmallVector(results.getArrayRef()); +} + +llvm::SmallVector AccumulateResultsDefinedWithin( + const llvm::SetVector& partition_ops) { + llvm::SmallVector values_for_results; + for (Operation* op : partition_ops) { + if (IsConstantOrNone(op)) { + continue; + } + for (Value output : op->getResults()) { + bool output_consumed_outside_subgraph = false; + for (Operation* consumer : output.getUsers()) { + if (llvm::all_of(partition_ops, [&](Operation* op) { + return !op->isAncestor(consumer); + })) { + output_consumed_outside_subgraph = true; + } + } + if (output_consumed_outside_subgraph) { + values_for_results.push_back(output); + } + } + } + return values_for_results; +} + +// Compute the signature for the raised func from the arguments and outputs of +// the Operation partition. +llvm::SmallVector TypesFromValues( + const llvm::SmallVector& values) { + llvm::SmallVector types; + for (auto value : values) { + types.push_back(value.getType()); + } + return types; +} + +func::FuncOp BuildFuncOp(const Subgraph& subgraph, OpBuilder& builder, + ModuleOp& module, OpsAdded& ops_added) { + // The parameters of the new MLIR function are taken to be the union + // of all operands referenced by Operations within the subgraph. + // Likewise the results of the function are any Value(s) that are defined + // within the subgraph and are referenced outside the subgraph. + llvm::SmallVector input_types = + TypesFromValues(subgraph.FuncArguments()); + llvm::SmallVector return_types = + TypesFromValues(subgraph.FuncOutputs()); + + FunctionType function_type = + builder.getFunctionType(input_types, return_types); + + std::string function_name = absl::StrCat("func_", subgraph.subgraph_id_); + + func::FuncOp new_func = func::FuncOp::create(builder.getUnknownLoc(), + function_name, function_type); + new_func.setVisibility(func::FuncOp::Visibility::Private); + new_func.addEntryBlock(); + + // To form the body of the new function we need to clone each + // Operation along with its respective operands and result Value(s). + // The semantics of `Operation::clone` are to copy the given entity *into* this + // entity. The new FuncOp body is populated by cloning partitioned ops into + // it. Cloning Operation(s) will create cloned Value(s) for the results of a + // cloned op, but it needs a reference to the new operand Value(s) which are + // the result of the cloned ops. The approach is to traverse the subgraph in + // order, accumulating clones of defined Values into an `IRMapping` + // and pass that map to calls to clone ops. + OpBuilder function_builder(new_func.getBody()); + // Preferred data structure for mapping MLIR values. + IRMapping values_in_scope; + // Function arguments can appear as operands, so the clone should + // be aware of them.
+ assert(subgraph.FuncArguments().size() == new_func.getNumArguments()); + for (int i = 0; i < subgraph.FuncArguments().size(); ++i) { + Value original_value = subgraph.FuncArguments()[i]; + Value new_func_arg = new_func.getArgument(i); + values_in_scope.map(original_value, new_func_arg); + } + + for (Operation* op : subgraph.partition_ops_) { + function_builder.clone(*op, values_in_scope); + } + SmallVector return_operands; + for (Value result : subgraph.FuncOutputs()) { + Value cloned_output = values_in_scope.lookup(result); + return_operands.push_back(cloned_output); + } + function_builder.create(new_func.getLoc(), + return_operands); + ops_added.func_op = new_func; + module.push_back(new_func); + return new_func; +} + +void ExtractSubgraphToFunc(const Subgraph& subgraph, OpBuilder& builder, + ModuleOp& module, OpsAdded& ops_added) { + func::FuncOp func = BuildFuncOp(subgraph, builder, module, ops_added); + + // We just use the location of the last op in the subgraph as the location + // for the call_op. + Operation* last_output = subgraph.partition_ops_.back(); + + builder.setInsertionPoint(last_output); + auto call_op = builder.create(last_output->getLoc(), func, + subgraph.FuncArguments()); + ops_added.call_op = call_op; + // FuncOutputs refer to the original `Values` in the input module, which are now + // invalid after pulling out the defining ops. The values in + // `call_op.getResult` refer to the clones of the original `Values`, which are now + // returned by the new `FuncOp`. We can replace each value in `FuncOutputs` with + // its clone in `call_op` to fix up. + for (int i = 0; i < subgraph.FuncOutputs().size(); ++i) { + Value output = subgraph.FuncOutputs()[i]; + output.replaceAllUsesWith(call_op.getResult(i)); + } + + // Clear the subgraph. + // Those ops should be removed. + for (auto* op : subgraph.partition_ops_) { + if (IsConstantOrNone(op)) { + continue; + } + op->dropAllDefinedValueUses(); + op->dropAllReferences(); + op->erase(); + } +} + +} // namespace common +} // namespace TFL +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/experimental/common/outline_operations.h b/tensorflow/compiler/mlir/lite/experimental/common/outline_operations.h new file mode 100644 index 00000000000..143a37d12c6 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/experimental/common/outline_operations.h @@ -0,0 +1,132 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_COMMON_OUTLINE_OPERATIONS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_COMMON_OUTLINE_OPERATIONS_H_ + +#include +#include +#include + +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/raw_os_ostream.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/UseDefLists.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" + +namespace mlir { +namespace TFL { +namespace common { + +// Returns true if the `op` is a constant-like op or produces a none type. +bool IsConstantOrNone(Operation* op); + +// Computes the list of Value(s) referenced by Subgraph Operations that are +// not defined within the Subgraph. Any such Value(s) +// are validly in-scope for the initial Operation. They must be either +// defined above the subgraph or appear as an argument to the containing func. +// These Value(s) are taken to be the arguments of the new raised func. +// An operand dependency is a Value referenced anywhere in an Op +// that is defined above the Op. All SSA Values are assigned/defined in a +// BlockArg or as a result of an Operation. +llvm::SmallVector AccumulateOperandsDefinedAbove( + const llvm::SetVector& partition_ops); + +// Similar to `AccumulateOperandsDefinedAbove()`, computes the Value(s) that are +// defined within a Subgraph and referenced in a descendant Operation. These +// Value(s) are to be returned by the new raised function. +llvm::SmallVector AccumulateResultsDefinedWithin( + const llvm::SetVector& partition_ops); + +// Represents a view of a set of mlir Operations that form a subgraph of the +// entire Module's DAG. `Subgraph` can be thought of as a segment of sequential +// Operations within a func definition. Additional facts: +// 1. Subgraphs are restricted to a single Block. They do not span +// branching instructions. Thus the subgraph is a simple 1-degree path. +// 2. All Operations in a subgraph belong to the same block in a +// function body. +// 3. Function bodies are assumed to have only one block in some places. +class Subgraph { + // Set vector preserves insertion order, must insert Ops in topological order. + public: + const llvm::SetVector partition_ops_; + + // Subgraphs are given a unique incremented integer id based on when + // they were encountered in this pass.
+ const int subgraph_id_; + + const llvm::StringRef dialect_namespace_; + + Subgraph(const llvm::SetVector partition_ops, int num_subgraphs) + : partition_ops_(partition_ops), + subgraph_id_(num_subgraphs), + func_arguments_(AccumulateOperandsDefinedAbove(partition_ops)), + func_outputs_(AccumulateResultsDefinedWithin(partition_ops)) {} + + const llvm::SmallVector& FuncArguments() const { + // `Value`s in the MLIR library are implemented as having "value semantics", + // see "llvm/llvm-project/mlir/include/mlir/IR/Value.h", so copying is fine. + return func_arguments_; + } + const llvm::SmallVector& FuncOutputs() const { return func_outputs_; } + + private: + // Compute once at construction and save as field. + const llvm::SmallVector func_arguments_; + const llvm::SmallVector func_outputs_; +}; + +// Helper data structure for output parameters to `ExtractSubgraphToFunc`. +// `ExtractSubgraphToFunc` adds exactly two "new" `Operations`, a FuncOp and +// a CallOp. Pass these back to the caller for setting more specific attributes +// after graph mutation has taken place. +struct OpsAdded { + mlir::func::FuncOp func_op; + mlir::func::CallOp call_op; +}; + +// Given a `Subgraph` containing a sequence of adjacent `Operations` from +// the `module`, raise these `Operations` (and any ops nested within them) +// to the body of a new separate root-level function. Replace them in their current +// location with a `CallOp` which invokes said `FuncOp`. The inputs to +// this new function are taken to be the `Values` that appear as operands +// to ops in the subgraph, which are not self-contained within the subgraph. +// The outputs of this function are taken to be the results of ops in the +// subgraph which are referenced as operands outside of the subgraph. +// Also refer to the documentation of `AccumulateOperandsDefinedAbove` & +// `AccumulateResultsDefinedWithin`.
+void ExtractSubgraphToFunc(const Subgraph& subgraph, OpBuilder& builder, + ModuleOp& module, OpsAdded& ops_added); + +} // namespace common +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_COMMON_OUTLINE_OPERATIONS_H_ diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/BUILD b/tensorflow/compiler/mlir/lite/experimental/tac/BUILD index 4808005008a..989aa30e1ed 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/tac/BUILD @@ -8,6 +8,7 @@ load( ) package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ "//visibility:public", ], @@ -245,12 +246,13 @@ cc_library( ":tac_importer_exporter", "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/lite:tf_tfl_passes", + "//tensorflow/compiler/mlir/lite/experimental/common:outline_operations", "//tensorflow/compiler/mlir/lite/experimental/tac/hardwares:target_hardware", - "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:error_util", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", @@ -267,7 +269,6 @@ cc_library( ":target_aware_conversion", "//tensorflow/compiler/mlir:tf_mlir_opt_main", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_pass_registration", - "//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass", ], alwayslink = 1, ) @@ -287,6 +288,7 @@ tf_cc_binary( testonly = True, deps = [ ":tac-opt_lib", + "//tensorflow/compiler/mlir/lite/experimental/common:outline_operations", "//tensorflow/compiler/mlir/lite/experimental/tac/hardwares:all-target-hardwares", ], ) @@ -303,6 +305,7 @@ cc_library( ":tflite_importer_exporter", "//tensorflow/compiler/mlir:init_mlir", "//tensorflow/compiler/mlir/lite:tensorflow_lite_legalize_tf", + "//tensorflow/compiler/mlir/lite:tensorflow_lite_optimize", "//tensorflow/compiler/mlir/lite/experimental/tac/hardwares:target_hardware", "//tensorflow/compiler/mlir/lite/experimental/tac/utils", "//tensorflow/compiler/mlir/tensorflow", diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/common/subgraph.h b/tensorflow/compiler/mlir/lite/experimental/tac/common/subgraph.h index 60f6c38e396..ecbfcd8bb2c 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/common/subgraph.h +++ b/tensorflow/compiler/mlir/lite/experimental/tac/common/subgraph.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_COMMON_SUBGRAPH_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_COMMON_SUBGRAPH_H_ +#include #include #include "llvm/ADT/StringRef.h" @@ -40,7 +41,7 @@ constexpr char kInterfaceNameAttr[] = "tac.interface_name"; inline llvm::Optional GetInterFaceName(Operation* op) { auto name_attr = op->getAttrOfType(kInterfaceNameAttr); - if (!name_attr) return llvm::None; + if (!name_attr) return std::nullopt; return name_attr.getValue().str(); } diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h b/tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h index 36246df45a7..5674708f697 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h +++ b/tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h @@ -18,11 +18,11 @@ limitations under the License. 
#include #include +#include #include #include #include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/StringRef.h" #include "mlir/IR/Operation.h" // from @llvm-project @@ -88,7 +88,7 @@ inline std::string GetCanonicalHardwareName(const std::string& hardware_name) { // Get the target annotation form the op. inline llvm::Optional GetTargetAnnotation(Operation* op) { auto device = op->getAttrOfType(kDevice); - if (device == nullptr || device.getValue().empty()) return llvm::None; + if (device == nullptr || device.getValue().empty()) return std::nullopt; return GetCanonicalHardwareName(device.getValue().str()); } @@ -96,7 +96,7 @@ inline llvm::Optional GetTargetAnnotation(Operation* op) { // Get inference type attribute from the operation if available. inline llvm::Optional GetInferenceTypeAnnotation(Operation* op) { auto inference_type = op->getAttrOfType(kInferenceType); - if (inference_type == nullptr) return llvm::None; + if (inference_type == nullptr) return std::nullopt; llvm::StringRef device_name_str = inference_type.getValue(); return GetInferenceTypeEnum(device_name_str); @@ -129,14 +129,14 @@ struct InferenceDeviceType { inline llvm::Optional GetInferenceDeviceTypeForOp( Operation* op) { auto hardware = GetTargetAnnotation(op); - if (!hardware.has_value()) return llvm::None; + if (!hardware.has_value()) return std::nullopt; auto inference_type = GetInferenceTypeAnnotation(op); - if (!inference_type.has_value()) return llvm::None; + if (!inference_type.has_value()) return std::nullopt; InferenceDeviceType inference_device_type; - inference_device_type.hardware = hardware.getValue(); - inference_device_type.inference_type = inference_type.getValue(); + inference_device_type.hardware = hardware.value(); + inference_device_type.inference_type = inference_type.value(); return inference_device_type; } diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/examples/BUILD b/tensorflow/compiler/mlir/lite/experimental/tac/examples/BUILD index 998b7770251..57fb5ea9eef 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/examples/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/tac/examples/BUILD @@ -2,6 +2,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ "//visibility:public", ], diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/execution_metadata_exporter.cc b/tensorflow/compiler/mlir/lite/experimental/tac/execution_metadata_exporter.cc index a76de268f5f..7fcd4d8b37d 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/execution_metadata_exporter.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/execution_metadata_exporter.cc @@ -16,12 +16,12 @@ #include #include +#include #include #include #include #include "flatbuffers/flatbuffers.h" // from @flatbuffers -#include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" #include "llvm/Support/Casting.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project @@ -56,17 +56,17 @@ bool HasValidHardwareTarget(mlir::Operation* op) { } llvm::Optional GetDeviceName(mlir::Operation* op) { - if (IsConst(op)) return llvm::None; + if (IsConst(op)) return std::nullopt; // The model may contain quant stats op which is unrelevant to the // execution. 
if (llvm::isa(op)) - return llvm::None; + return std::nullopt; - if (!HasValidHardwareTarget(op)) return llvm::None; + if (!HasValidHardwareTarget(op)) return std::nullopt; auto device = op->getAttrOfType(mlir::TFL::tac::kDevice); - if (device == nullptr) return llvm::None; + if (device == nullptr) return std::nullopt; llvm::StringRef device_name_str = device.getValue(); return device_name_str.str(); @@ -76,13 +76,13 @@ llvm::Optional> GetPerDeviceCosts( const std::map& hardware_map, mlir::Operation* op) { auto device_costs_attr = op->getAttrOfType("per_device_costs"); - if (device_costs_attr == nullptr) return llvm::None; + if (device_costs_attr == nullptr) return std::nullopt; std::vector device_costs(hardware_map.size(), -1.f); for (const auto& kv : hardware_map) { auto cost_attr = device_costs_attr.getNamed(kv.first); - if (!cost_attr.has_value()) return llvm::None; + if (!cost_attr.has_value()) return std::nullopt; float cost = cost_attr->getValue() .dyn_cast_or_null() .getValueAsDouble(); @@ -116,13 +116,12 @@ flatbuffers::Offset CreateSubgraphMetadata( flatbuffers::Offset> per_device_cost_offset; if (per_device_cost.has_value()) { - per_device_cost_offset = - builder->CreateVector(per_device_cost.getValue()); + per_device_cost_offset = builder->CreateVector(*per_device_cost); } OpMetadataBuilder op_builder(*builder); op_builder.add_index(index); - uint8_t hardware = hardware_map.at(device_name.getValue()); + uint8_t hardware = hardware_map.at(*device_name); op_builder.add_hardware(hardware); if (per_device_cost.has_value()) { @@ -147,9 +146,9 @@ CreateHardwareMetadataAndPopulateLookupTable( auto device_name = GetDeviceName(op); if (!device_name.has_value()) return; - auto iter = hardware_names->find(device_name.getValue()); + auto iter = hardware_names->find(*device_name); if (iter == hardware_names->end()) { - hardware_names->insert({device_name.getValue(), index++}); + hardware_names->insert({*device_name, index++}); } }); } diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/execution_metadata_exporter_test.cc b/tensorflow/compiler/mlir/lite/experimental/tac/execution_metadata_exporter_test.cc index 19b041e3b35..6d17c7f6ff6 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/execution_metadata_exporter_test.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/execution_metadata_exporter_test.cc @@ -114,8 +114,7 @@ func.func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32> mlir::parseSourceString(kMLIR, &context)); auto module_op = module.get(); auto serialized_result_fb = ExportRuntimeMetadata(module_op); - const auto* result = - GetRuntimeMetadata(serialized_result_fb.getValue().c_str()); + const auto* result = GetRuntimeMetadata(serialized_result_fb.value().c_str()); const auto* expected = GetRuntimeMetadata(kExpectedFB.c_str()); ASSERT_TRUE(result != nullptr); ASSERT_TRUE(result->subgraph_metadata() != nullptr); diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/BUILD b/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/BUILD index c04afe663bc..b62f6c5cf4b 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/BUILD @@ -1,6 +1,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ "//visibility:public", ], diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/BUILD 
b/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/BUILD index b5a65f58aaa..154afe2d923 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/BUILD @@ -2,6 +2,7 @@ load("//tensorflow:tensorflow.default.bzl", "pybind_extension") load("//tensorflow:tensorflow.bzl", "VERSION") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ "//tensorflow/compiler/mlir/lite/experimental/tac:__subpackages__", ], @@ -59,6 +60,8 @@ pybind_extension( "@farmhash_gpu_archive//:__subpackages__", "@fft2d//:__subpackages__", "@flatbuffers//:__subpackages__", + "@FP16//:__subpackages__", + "@FXdiv//:__subpackages__", "@gemmlowp//:__subpackages__", "@gif//:__subpackages__", "@highwayhash//:__subpackages__", @@ -84,6 +87,7 @@ pybind_extension( "@org_sqlite//:__subpackages__", "@platforms//:__subpackages__", "@png//:__subpackages__", + "@pthreadpool//:__subpackages__", "@pybind11//:__subpackages__", "@ruy//:__subpackages__", "@snappy//:__subpackages__", diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/tac_module.cc b/tensorflow/compiler/mlir/lite/experimental/tac/tac_module.cc index b0b278f2ad4..805b7802517 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/tac_module.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/tac_module.cc @@ -80,7 +80,7 @@ const tac::TargetHardware* TacModule::GetTargetHardware( } absl::Status TacModule::RunTacPasses(mlir::ModuleOp* module, bool debug_mode) { - mlir::PassManager pm(module->getContext(), + mlir::PassManager pm((*module)->getName(), mlir::OpPassManager::Nesting::Implicit); AddTACPass(&pm, options_.hardware_backends); if (!debug_mode) { diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/tests/BUILD b/tensorflow/compiler/mlir/lite/experimental/tac/tests/BUILD index 060d24b5000..1ae5f737d37 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/tests/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/tac/tests/BUILD @@ -1,7 +1,10 @@ load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") -package(licenses = ["notice"]) +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) glob_lit_tests( data = [":test_utilities"], @@ -14,6 +17,7 @@ filegroup( name = "test_utilities", testonly = True, data = [ + "//tensorflow/compiler/mlir/lite/experimental/common:outline_operations", "//tensorflow/compiler/mlir/lite/experimental/tac:tac-opt-all-backends", "@llvm-project//llvm:FileCheck", "@llvm-project//llvm:not", diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/tests/e2e/BUILD b/tensorflow/compiler/mlir/lite/experimental/tac/tests/e2e/BUILD index 18dd066d9ec..8fef794a866 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/tests/e2e/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/tac/tests/e2e/BUILD @@ -2,6 +2,7 @@ load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ "//visibility:public", ], diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/tests/fold-constants-to-subgraph.mlir b/tensorflow/compiler/mlir/lite/experimental/tac/tests/fold-constants-to-subgraph.mlir index 94534bd7c97..e8a30755a8c 100644 --- 
a/tensorflow/compiler/mlir/lite/experimental/tac/tests/fold-constants-to-subgraph.mlir +++ b/tensorflow/compiler/mlir/lite/experimental/tac/tests/fold-constants-to-subgraph.mlir @@ -114,3 +114,28 @@ func.func @fold_all_test(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3x // ALL: return %[[VAL_5]] : tensor<256x30x30x16xf32> // ALL: } } + +// ----- + +module { + +func.func @main(%arg0: tensor<4x384x32xf32>) -> tensor<1x384x32xf32> { + %0 = arith.constant dense<0> : tensor<3xi32> + %1 = arith.constant dense<[1, 384, 32]> : tensor<3xi32> + %2 = func.call @simple_test(%arg0, %0, %1) {tac.interface_name = "func1"} : (tensor<4x384x32xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x384x32xf32> + func.return %2 : tensor<1x384x32xf32> +} + +// PARTIAL-LABEL: @simple_test +func.func @simple_test(%arg0: tensor<4x384x32xf32>, %arg1: tensor<3xi32>, %arg2: tensor<3xi32>) -> tensor<1x384x32xf32> attributes {tac.interface_name = "func1"} { + %0 = "tfl.slice"(%arg0, %arg1, %arg2) : (tensor<4x384x32xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x384x32xf32> + func.return %0 : tensor<1x384x32xf32> +} + +// PARTIAL: func @simple_test(%[[VAL_0:.*]]: tensor<4x384x32xf32>, %[[VAL_1:.*]]: tensor<3xi32>, %[[VAL_2:.*]]: tensor<3xi32>) -> tensor<1x384x32xf32> attributes {tac.interface_name = "func1"} { +// PARTIAL: %[[VAL_3:.*]] = arith.constant dense<[1, 384, 32]> : tensor<3xi32> +// PARTIAL: %[[VAL_4:.*]] = arith.constant dense<0> : tensor<3xi32> +// PARTIAL: %[[VAL_5:.*]] = "tfl.slice"(%[[VAL_0]], %[[VAL_4]], %[[VAL_3]]) : (tensor<4x384x32xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x384x32xf32> +// PARTIAL: return %[[VAL_5]] : tensor<1x384x32xf32> +// PARTIAL: } +} diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/tests/raise-target-subgraphs.mlir b/tensorflow/compiler/mlir/lite/experimental/tac/tests/raise-target-subgraphs.mlir index 71e9c0bd69c..07639de6bca 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/tests/raise-target-subgraphs.mlir +++ b/tensorflow/compiler/mlir/lite/experimental/tac/tests/raise-target-subgraphs.mlir @@ -1,4 +1,94 @@ -// RUN: tac-opt-all-backends -tfl-raise-target-subgraphs %s -split-input-file -verify-diagnostics | FileCheck %s +// RUN: tac-opt-all-backends -tfl-raise-target-subgraphs %s -split-input-file | FileCheck %s + +module { +func.func @simpleWhile(%arg0: tensor) -> tensor { + %0 = "tfl.while"(%arg0) ({ + ^bb0(%block: tensor): + "tfl.yield"(%block) : (tensor) -> () + },{ + ^bb0(%block: tensor): + "tfl.yield"(%block) : (tensor) -> () + }) {tac.device = "CPU", fused_activation_function = "RELU6", tac.inference_type = "FLOAT"} : (tensor) -> tensor + func.return %0 : tensor +} +} + +// CHECK: func.func @simpleWhile(%arg0: tensor) -> tensor { +// CHECK: %0 = call @func_0_CPU_FLOAT(%arg0) {tac.device = "CPU", tac.inference_type = "FLOAT", tac.interface_name = "func_0"} : (tensor) -> tensor +// CHECK: return %0 : tensor +// CHECK: } +// CHECK: func.func private @func_0_CPU_FLOAT(%arg0: tensor) -> tensor attributes {tac.device = "CPU", tac.inference_type = "FLOAT", tac.interface_name = "func_0"} { +// CHECK: %0 = "tfl.while"(%arg0) ({ +// CHECK: ^bb0(%arg1: tensor): +// CHECK: "tfl.yield"(%arg1) : (tensor) -> () +// CHECK: }, { +// CHECK: ^bb0(%arg1: tensor): +// CHECK: "tfl.yield"(%arg1) : (tensor) -> () +// CHECK: }) {fused_activation_function = "RELU6", tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor) -> tensor +// CHECK: return %0 : tensor +// CHECK: } + +// ----- + +module { +func.func @whileWithNested(%arg0: tensor) -> tensor { + %0 = 
"tfl.while"(%arg0) ({ + ^bb0(%block: tensor): + %1 = "tfl.add"(%arg0, %arg0) { fused_activation_function = "NONE", tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor, tensor) -> (tensor) + %2 = "tfl.add"(%1, %1) { fused_activation_function = "NONE", tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor, tensor) -> (tensor) + "tfl.yield"(%2) : (tensor) -> () + },{ + ^bb0(%block: tensor): + "tfl.yield"(%block) : (tensor) -> () + }) {tac.device = "CPU", fused_activation_function = "RELU6", tac.inference_type = "FLOAT"} : (tensor) -> tensor + func.return %0 : tensor +} +} + +// CHECK: func.func @whileWithNested(%arg0: tensor) -> tensor { +// CHECK: %0 = call @func_0_CPU_FLOAT(%arg0) {tac.device = "CPU", tac.inference_type = "FLOAT", tac.interface_name = "func_0"} : (tensor) -> tensor +// CHECK: return %0 : tensor +// CHECK: } +// CHECK: func.func private @func_0_CPU_FLOAT(%arg0: tensor) -> tensor attributes {tac.device = "CPU", tac.inference_type = "FLOAT", tac.interface_name = "func_0"} { +// CHECK: %0 = "tfl.while"(%arg0) ({ +// CHECK: ^bb0(%arg1: tensor): +// CHECK: %1 = tfl.add %arg0, %arg0 {fused_activation_function = "NONE", tac.device = "CPU", tac.inference_type = "FLOAT"} : tensor +// CHECK: %2 = func.call @func_1_GPU_FLOAT(%1) {tac.device = "GPU", tac.inference_type = "FLOAT", tac.interface_name = "func_1"} : (tensor) -> tensor +// CHECK: "tfl.yield"(%2) : (tensor) -> () +// CHECK: }, { +// CHECK: ^bb0(%arg1: tensor): +// CHECK: "tfl.yield"(%arg1) : (tensor) -> () +// CHECK: }) {fused_activation_function = "RELU6", tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor) -> tensor +// CHECK: return %0 : tensor +// CHECK: } +// CHECK: func.func private @func_1_GPU_FLOAT(%arg0: tensor) -> tensor attributes {tac.device = "GPU", tac.inference_type = "FLOAT", tac.interface_name = "func_1"} { +// CHECK: %0 = tfl.add %arg0, %arg0 {fused_activation_function = "NONE", tac.device = "GPU", tac.inference_type = "FLOAT"} : tensor +// CHECK: return %0 : tensor +// CHECK: } + + + + + +// ----- + +module { +func.func @degenerateCase(%arg0: tensor<1xf32>) -> tensor<1xf32> { + %0 = "tfl.add"(%arg0, %arg0) {tac.device = "GPU", fused_activation_function = "RELU6", tac.inference_type = "FLOAT"} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + func.return %0 : tensor<1xf32> +} +} + +// CHECK: func.func @degenerateCase(%arg0: tensor<1xf32>) -> tensor<1xf32> { +// CHECK: %0 = call @func_0_GPU_FLOAT(%arg0) {tac.device = "GPU", tac.inference_type = "FLOAT", tac.interface_name = "func_0"} : (tensor<1xf32>) -> tensor<1xf32> +// CHECK: return %0 : tensor<1xf32> +// CHECK: } +// CHECK: func.func private @func_0_GPU_FLOAT(%arg0: tensor<1xf32>) -> tensor<1xf32> attributes {tac.device = "GPU", tac.inference_type = "FLOAT", tac.interface_name = "func_0"} { +// CHECK: %0 = tfl.add %arg0, %arg0 {fused_activation_function = "RELU6", tac.device = "GPU", tac.inference_type = "FLOAT"} : tensor<1xf32> +// CHECK: return %0 : tensor<1xf32> +// CHECK: } + +// ----- module { func.func @simpleTest(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>, %arg3: tensor<1xf32>) -> tensor<2x1xf32> { @@ -11,12 +101,12 @@ func.func @simpleTest(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>, %arg2: tensor< } // CHECK: func @simpleTest(%[[VAL_0:.*]]: tensor<1xf32>, %[[VAL_1:.*]]: tensor<1xf32>, %[[VAL_2:.*]]: tensor<1xf32>, %[[VAL_3:.*]]: tensor<1xf32>) -> tensor<2x1xf32> { -// CHECK: %[[VAL_4:.*]]:2 = call @func_0_GPU_FLOAT(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_0]], %[[VAL_3]]) {tac.device = "GPU", 
tac.inference_type = "FLOAT", tac.interface_name = "func_0"} : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> (tensor<1xf32>, tensor<1xf32>) +// CHECK: %[[VAL_4:.*]]:2 = call @func_0_GPU_FLOAT(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_3]]) {tac.device = "GPU", tac.inference_type = "FLOAT", tac.interface_name = "func_0"} : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> (tensor<1xf32>, tensor<1xf32>) // CHECK: %[[VAL_5:.*]] = call @func_1_CPU_FLOAT(%[[VAL_4]]#0, %[[VAL_4]]#1) {tac.device = "CPU", tac.inference_type = "FLOAT", tac.interface_name = "func_1"} : (tensor<1xf32>, tensor<1xf32>) -> tensor<2x1xf32> // CHECK: return %[[VAL_5]] : tensor<2x1xf32> // CHECK: } -// CHECK: func private @func_0_GPU_FLOAT(%[[VAL_0:.*]]: tensor<1xf32>, %[[VAL_1:.*]]: tensor<1xf32>, %[[VAL_2:.*]]: tensor<1xf32>, %[[VAL_3:.*]]: tensor<1xf32>, %[[VAL_4:.*]]: tensor<1xf32>) -> (tensor<1xf32>, tensor<1xf32>) attributes {tac.device = "GPU", tac.inference_type = "FLOAT", tac.interface_name = "func_0"} { +// CHECK: func private @func_0_GPU_FLOAT(%[[VAL_0:.*]]: tensor<1xf32>, %[[VAL_1:.*]]: tensor<1xf32>, %[[VAL_2:.*]]: tensor<1xf32>, %[[VAL_4:.*]]: tensor<1xf32>) -> (tensor<1xf32>, tensor<1xf32>) attributes {tac.device = "GPU", tac.inference_type = "FLOAT", tac.interface_name = "func_0"} { // CHECK: %[[VAL_5:.*]] = tfl.add %[[VAL_0]], %[[VAL_1]] {fused_activation_function = "RELU6", tac.device = "GPU", tac.inference_type = "FLOAT"} : tensor<1xf32> // CHECK: %[[VAL_6:.*]] = tfl.mul %[[VAL_5]], %[[VAL_2]] {fused_activation_function = "RELU6", tac.device = "GPU", tac.inference_type = "FLOAT"} : tensor<1xf32> // CHECK: %[[VAL_7:.*]] = tfl.add %[[VAL_0]], %[[VAL_4]] {fused_activation_function = "RELU6", tac.device = "GPU", tac.inference_type = "FLOAT"} : tensor<1xf32> @@ -109,6 +199,7 @@ func.func @norm2(%arg0: tensor<1x128x128xf32>) -> tensor<1x128x128xf32> { %10 = "tfl.add"(%1, %9) {tac.device = "GPU", tac.inference_type = "FLOAT", fused_activation_function = "NONE"} : (tensor<1x128x128xf32>, tensor<1x128x128xf32>) -> tensor<1x128x128xf32> func.return %10 : tensor<1x128x128xf32> } +} // CHECK: func @norm2(%[[VAL_0:.*]]: tensor<1x128x128xf32>) -> tensor<1x128x128xf32> { // CHECK-DAG: %[[VAL_1:.*]] = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<128xf32>} : () -> tensor<128xf32> @@ -122,25 +213,23 @@ func.func @norm2(%arg0: tensor<1x128x128xf32>) -> tensor<1x128x128xf32> { // CHECK: return %[[VAL_8]] : tensor<1x128x128xf32> // CHECK: } -// CHECK: func private @func_0_GPU_FLOAT(%[[VAL_0:.*]]: tensor<1x128x128xf32>, %[[VAL_1:.*]]: tensor<128xf32>, %[[VAL_2:.*]]: tensor<2xi32>) -> (tensor<1x128x128xf32>, tensor<128x128xf32>) attributes {tac.device = "GPU", tac.inference_type = "FLOAT", tac.interface_name = "func_0"} { +// CHECK: func.func private @func_0_GPU_FLOAT(%[[VAL_0:.*]]: tensor<1x128x128xf32>, %[[VAL_1:.*]]: tensor<128xf32>, %[[VAL_2:.*]]: tensor<2xi32>) -> (tensor<1x128x128xf32>, tensor<128x128xf32>) attributes {tac.device = "GPU", tac.inference_type = "FLOAT", tac.interface_name = "func_0"} { // CHECK: %[[VAL_3:.*]] = tfl.add(%[[VAL_0]], %[[VAL_1]]) {fused_activation_function = "NONE", tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<1x128x128xf32>, tensor<128xf32>) -> tensor<1x128x128xf32> // CHECK: %[[VAL_4:.*]] = "tfl.reshape"(%[[VAL_3]], %[[VAL_2]]) {tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<1x128x128xf32>, tensor<2xi32>) -> tensor<128x128xf32> // CHECK: %[[VAL_5:.*]] = "tfl.relu"(%[[VAL_4]]) {tac.device = "GPU", 
tac.inference_type = "FLOAT"} : (tensor<128x128xf32>) -> tensor<128x128xf32> // CHECK: return %[[VAL_3]], %[[VAL_5]] : tensor<1x128x128xf32>, tensor<128x128xf32> // CHECK: } -// CHECK: func private @func_2_GPU_FLOAT(%[[VAL_0:.*]]: tensor<128x128xf32>, %[[VAL_1:.*]]: tensor<3xi32>, %[[VAL_2:.*]]: tensor<1x128x128xf32>) -> tensor<1x128x128xf32> attributes {tac.device = "GPU", tac.inference_type = "FLOAT", tac.interface_name = "func_2"} { -// CHECK: %[[VAL_3:.*]] = "tfl.reshape"(%[[VAL_0]], %[[VAL_1]]) {tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<128x128xf32>, tensor<3xi32>) -> tensor<1x128x128xf32> -// CHECK: %[[VAL_4:.*]] = tfl.add %[[VAL_2]], %[[VAL_3]] {fused_activation_function = "NONE", tac.device = "GPU", tac.inference_type = "FLOAT"} : tensor<1x128x128xf32> -// CHECK: return %[[VAL_4]] : tensor<1x128x128xf32> -// CHECK: } - -// CHECK: func private @func_1_CPU_FLOAT(%[[VAL_0:.*]]: tensor<128x128xf32>, %[[VAL_1:.*]]: tensor<128x128xf32>, %[[VAL_2:.*]]: tensor<128xf32>) -> tensor<128x128xf32> attributes {tac.device = "CPU", tac.inference_type = "FLOAT", tac.interface_name = "func_1"} { +// CHECK: func.func private @func_1_CPU_FLOAT(%[[VAL_0:.*]]: tensor<128x128xf32>, %[[VAL_1:.*]]: tensor<128x128xf32>, %[[VAL_2:.*]]: tensor<128xf32>) -> tensor<128x128xf32> attributes {tac.device = "CPU", tac.inference_type = "FLOAT", tac.interface_name = "func_1"} { // CHECK: %[[VAL_3:.*]] = "tfl.fully_connected"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) {fused_activation_function = "NONE", keep_num_dims = false, tac.device = "CPU", tac.inference_type = "FLOAT", weights_format = "DEFAULT"} : (tensor<128x128xf32>, tensor<128x128xf32>, tensor<128xf32>) -> tensor<128x128xf32> // CHECK: return %[[VAL_3]] : tensor<128x128xf32> // CHECK: } -} +// CHECK: func.func private @func_2_GPU_FLOAT(%[[VAL_0:.*]]: tensor<128x128xf32>, %[[VAL_1:.*]]: tensor<3xi32>, %[[VAL_2:.*]]: tensor<1x128x128xf32>) -> tensor<1x128x128xf32> attributes {tac.device = "GPU", tac.inference_type = "FLOAT", tac.interface_name = "func_2"} { +// CHECK: %[[VAL_3:.*]] = "tfl.reshape"(%[[VAL_0]], %[[VAL_1]]) {tac.device = "GPU", tac.inference_type = "FLOAT"} : (tensor<128x128xf32>, tensor<3xi32>) -> tensor<1x128x128xf32> +// CHECK: %[[VAL_4:.*]] = tfl.add %[[VAL_2]], %[[VAL_3]] {fused_activation_function = "NONE", tac.device = "GPU", tac.inference_type = "FLOAT"} : tensor<1x128x128xf32> +// CHECK: return %[[VAL_4]] : tensor<1x128x128xf32> +// CHECK: } // ----- @@ -159,11 +248,11 @@ func.func @quantizedOpOnly(%arg0: tensor<1x!quant.uniform>, // CHECK: func @quantizedOpOnly(%[[VAL_0:.*]]: tensor<1x!quant.uniform>, %[[VAL_1:.*]]: tensor<1x!quant.uniform>) -> tensor<2x1x!quant.uniform> { // CHECK: %[[VAL_2:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<1x!quant.uniform>, value = dense<127> : tensor<1xi8>} : () -> tensor<1x!quant.uniform> // CHECK: %[[VAL_3:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<1x!quant.uniform>, value = dense<127> : tensor<1xi8>} : () -> tensor<1x!quant.uniform> -// CHECK: %[[VAL_4:.*]] = call @func_0_CPU_QUANTIZED_INT8(%[[VAL_0]], %[[VAL_2]], %[[VAL_3]], %[[VAL_1]], %[[VAL_2]]) {tac.device = "CPU", tac.inference_type = "QUANTIZED_INT8", tac.interface_name = "func_0"} : (tensor<1x!quant.uniform>, tensor<1x!quant.uniform>, tensor<1x!quant.uniform>, tensor<1x!quant.uniform>, tensor<1x!quant.uniform>) -> tensor<2x1x!quant.uniform> +// CHECK: %[[VAL_4:.*]] = call @func_0_CPU_QUANTIZED_INT8(%[[VAL_0]], %[[VAL_2]], %[[VAL_3]], %[[VAL_1]]) {tac.device = "CPU", tac.inference_type = "QUANTIZED_INT8", tac.interface_name = 
"func_0"} : (tensor<1x!quant.uniform>, tensor<1x!quant.uniform>, tensor<1x!quant.uniform>, tensor<1x!quant.uniform>) -> tensor<2x1x!quant.uniform> // CHECK: return %[[VAL_4]] : tensor<2x1x!quant.uniform> // CHECK: } -// CHECK: func private @func_0_CPU_QUANTIZED_INT8(%[[VAL_0:.*]]: tensor<1x!quant.uniform>, %[[VAL_1:.*]]: tensor<1x!quant.uniform>, %[[VAL_2:.*]]: tensor<1x!quant.uniform>, %[[VAL_3:.*]]: tensor<1x!quant.uniform>, %[[VAL_4:.*]]: tensor<1x!quant.uniform>) -> tensor<2x1x!quant.uniform> attributes {tac.device = "CPU", tac.inference_type = "QUANTIZED_INT8", tac.interface_name = "func_0"} { +// CHECK: func private @func_0_CPU_QUANTIZED_INT8(%[[VAL_0:.*]]: tensor<1x!quant.uniform>, %[[VAL_1:.*]]: tensor<1x!quant.uniform>, %[[VAL_2:.*]]: tensor<1x!quant.uniform>, %[[VAL_3:.*]]: tensor<1x!quant.uniform>) -> tensor<2x1x!quant.uniform> attributes {tac.device = "CPU", tac.inference_type = "QUANTIZED_INT8", tac.interface_name = "func_0"} { // CHECK: %[[VAL_5:.*]] = tfl.mul %[[VAL_0]], %[[VAL_1]] {fused_activation_function = "NONE", tac.device = "CPU", tac.inference_type = "QUANTIZED_INT8"} : tensor<1x!quant.uniform> // CHECK: %[[VAL_6:.*]] = tfl.add %[[VAL_5]], %[[VAL_2]] {fused_activation_function = "NONE", tac.device = "CPU", tac.inference_type = "QUANTIZED_INT8"} : tensor<1x!quant.uniform> // CHECK: %[[VAL_7:.*]] = tfl.add %[[VAL_3]], %[[VAL_1]] {fused_activation_function = "NONE", tac.device = "CPU", tac.inference_type = "QUANTIZED_INT8"} : tensor<1x!quant.uniform> @@ -186,6 +275,7 @@ func.func @quantizationWithFloat(%arg0: tensor<1x1x384x!quant.uniform>, tensor<1x384x384x!quant.uniform>) -> tensor<1x384x384x!quant.uniform> func.return %6: tensor<1x384x384x!quant.uniform> } +} // CHECK: func @quantizationWithFloat(%[[VAL_0:.*]]: tensor<1x1x384x!quant.uniform>, %[[VAL_1:.*]]: tensor<1x1x384x!quant.uniform>) -> tensor<1x384x384x!quant.uniform> { // CHECK: %[[VAL_2:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<1x384x1x!quant.uniform>, value = dense<127> : tensor<1x384x1xi8>} : () -> tensor<1x384x1x!quant.uniform> @@ -203,14 +293,228 @@ func.func @quantizationWithFloat(%arg0: tensor<1x1x384x!quant.uniform> // CHECK: } +// CHECK: func private @func_1_GPU_FLOAT(%[[VAL_0:.*]]: tensor<1x384x384xf32>, %[[VAL_1:.*]]: tensor<1x384x384xf32>) -> tensor<1x384x384xf32> attributes {tac.device = "GPU", tac.inference_type = "FLOAT", tac.interface_name = "func_1"} { +// CHECK: %[[VAL_2:.*]] = tfl.add %[[VAL_0]], %[[VAL_1]] {fused_activation_function = "NONE", tac.device = "GPU", tac.inference_type = "FLOAT"} : tensor<1x384x384xf32> +// CHECK: return %[[VAL_2]] : tensor<1x384x384xf32> +// CHECK: } + // CHECK: func private @func_2_CPU_QUANTIZED_INT8(%[[VAL_0:.*]]: tensor<1x1x384x!quant.uniform>, %[[VAL_1:.*]]: tensor<1x384x384x!quant.uniform>) -> tensor<1x384x384x!quant.uniform> attributes {tac.device = "CPU", tac.inference_type = "QUANTIZED_INT8", tac.interface_name = "func_2"} { // CHECK: %[[VAL_2:.*]] = tfl.mul(%[[VAL_0]], %[[VAL_1]]) {fused_activation_function = "NONE", tac.device = "CPU", tac.inference_type = "QUANTIZED_INT8"} : (tensor<1x1x384x!quant.uniform>, tensor<1x384x384x!quant.uniform>) -> tensor<1x384x384x!quant.uniform> // CHECK: return %[[VAL_2]] : tensor<1x384x384x!quant.uniform> // CHECK: } -// CHECK: func private @func_1_GPU_FLOAT(%[[VAL_0:.*]]: tensor<1x384x384xf32>, %[[VAL_1:.*]]: tensor<1x384x384xf32>) -> tensor<1x384x384xf32> attributes {tac.device = "GPU", tac.inference_type = "FLOAT", tac.interface_name = "func_1"} { -// CHECK: %[[VAL_2:.*]] = tfl.add %[[VAL_0]], %[[VAL_1]] 
{fused_activation_function = "NONE", tac.device = "GPU", tac.inference_type = "FLOAT"} : tensor<1x384x384xf32> -// CHECK: return %[[VAL_2]] : tensor<1x384x384xf32> -// CHECK: } +// ----- -} +func.func @cond_false_72730(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor) -> (tensor, tensor) { + %cst = arith.constant dense<[1, 0]> : tensor<2xi32> + %cst_0 = arith.constant dense<0> : tensor + %cst_1 = arith.constant dense<-1> : tensor<1xi32> + %cst_2 = arith.constant dense<1> : tensor<1xi32> + %cst_3 = arith.constant dense<0> : tensor<1xi32> + %0 = "tfl.shape"(%arg2) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor) -> tensor<2xi32> + %1 = "tfl.strided_slice"(%0, %cst_3, %cst_2, %cst_2) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %2 = "tfl.custom"(%cst_1, %1) {custom_code = "FlexTensorListReserve", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor<1xi32>, tensor) -> tensor>> + %3 = "tfl.custom"(%cst_1, %1) {custom_code = "FlexTensorListReserve", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor<1xi32>, tensor) -> tensor>> + %4:8 = "tfl.while"(%cst_0, %cst_0, %arg5, %arg6, %2, %2, %3, %3) ({ + ^bb0(%arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor>>, %arg12: tensor>>, %arg13: tensor>>, %arg14: tensor>>): + %9 = tfl.less(%arg8, %1) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor + %10 = tfl.less(%arg7, %1) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor + %11 = tfl.logical_and %10, %9 {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : tensor + "tfl.yield"(%11) : (tensor) -> () + }, { + ^bb0(%arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor>>, %arg12: tensor>>, %arg13: tensor>>, %arg14: tensor>>): + %cst_4 = arith.constant dense<[0, 0, 1, 1, 1]> : tensor<5xi32> + %cst_5 = arith.constant dense<[0, 1, 0, 1, 1]> : tensor<5xi32> + %cst_6 = arith.constant dense<2> : tensor + %cst_7 = arith.constant dense<"*"> : tensor + %cst_8 = arith.constant dense<-1> : tensor<1xi32> + %cst_9 = arith.constant dense<-1> : tensor + %cst_10 = arith.constant dense<2> : tensor<1xi32> + %cst_11 = arith.constant dense<1> : tensor + %cst_12 = arith.constant dense<0> : tensor + %cst_13 = arith.constant dense<1> : tensor<1xi32> + %cst_14 = arith.constant dense<0> : tensor<1xi32> + %9 = "tfl.shape"(%arg1) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor) -> tensor<2xi32> + %10 = "tfl.strided_slice"(%9, %cst_14, %cst_13, %cst_13) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %11 = "tfl.range"(%cst_12, %10, %cst_11) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor, tensor) -> tensor + %12 = "tfl.pack"(%10, %cst_11) {axis = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT", values_count = 2 : i32} : (tensor, tensor) -> tensor<2xi32> + %13 = "tfl.strided_slice"(%9, %cst_13, %cst_10, %cst_13) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32, 
tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %14 = tfl.mul(%11, %13) {fused_activation_function = "NONE", tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor + %15 = "tfl.reshape"(%14, %12) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor<2xi32>) -> tensor + %16 = "tfl.strided_slice"(%9, %cst_14, %cst_10, %cst_13) {begin_mask = 1 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %17 = "tfl.reduce_prod"(%16, %cst_14) {keep_dims = true, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>) -> tensor<1xi32> + %18 = "tfl.reshape"(%arg1, %17) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor<1xi32>) -> tensor + %19 = "tfl.shape"(%arg0) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor) -> tensor<2xi32> + %20 = "tfl.strided_slice"(%19, %cst_14, %cst_13, %cst_13) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %21 = "tfl.range"(%cst_12, %20, %cst_11) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor, tensor) -> tensor + %22 = "tfl.pack"(%20, %cst_11) {axis = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT", values_count = 2 : i32} : (tensor, tensor) -> tensor<2xi32> + %23 = "tfl.strided_slice"(%19, %cst_13, %cst_10, %cst_13) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %24 = tfl.mul(%21, %23) {fused_activation_function = "NONE", tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor + %25 = "tfl.reshape"(%24, %22) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor<2xi32>) -> tensor + %26 = "tfl.strided_slice"(%19, %cst_14, %cst_10, %cst_13) {begin_mask = 1 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %27 = "tfl.reduce_prod"(%26, %cst_14) {keep_dims = true, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>) -> tensor<1xi32> + %28 = "tfl.reshape"(%arg0, %27) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor<1xi32>) -> tensor + %29 = tfl.add %arg8, %cst_11 {fused_activation_function = "NONE", tac.device = "DARWINN", tac.inference_type = "FLOAT"} : tensor + %30 = "tfl.expand_dims"(%arg9, %cst_9) {tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor + %31 = tfl.add %30, %25 {fused_activation_function = "NONE", tac.device = "DARWINN", tac.inference_type = "FLOAT"} : tensor + %32 = "tfl.reshape"(%31, %cst_8) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor<1xi32>) -> tensor + %33 = "tfl.gather"(%28, %32) {axis = 0 : i32, batch_dims = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor + %34 = 
"tfl.reshape"(%33, %cst_8) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor<1xi32>) -> tensor + %35 = "tfl.shape"(%34) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor) -> tensor<1xi32> + %36 = "tfl.fill"(%35, %cst_7) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<1xi32>, tensor) -> tensor + %37 = "tfl.expand_dims"(%arg10, %cst_9) {tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor + %38 = tfl.add %37, %15 {fused_activation_function = "NONE", tac.device = "DARWINN", tac.inference_type = "FLOAT"} : tensor + %39 = "tfl.reshape"(%38, %cst_8) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor<1xi32>) -> tensor + %40 = "tfl.gather"(%18, %39) {axis = 0 : i32, batch_dims = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor + %41 = "tfl.reshape"(%40, %cst_8) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor<1xi32>) -> tensor + %42 = "tfl.shape"(%41) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor) -> tensor<1xi32> + %43 = "tfl.fill"(%42, %cst_7) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<1xi32>, tensor) -> tensor + %44 = "tfl.gather"(%arg2, %arg8) {axis = 0 : i32, batch_dims = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor + %45 = "tfl.equal"(%44, %cst_6) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor + %46 = "tfl.custom"(%45, %36, %34) {custom_code = "FlexSelect", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor, tensor, tensor) -> tensor + %47 = "tfl.custom"(%arg11, %arg8, %46) {custom_code = "FlexTensorListSetItem", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor>>, tensor, tensor) -> tensor>> + %48 = "tfl.equal"(%44, %cst_11) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor + %49 = "tfl.custom"(%48, %43, %41) {custom_code = "FlexSelect", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor, tensor, tensor) -> tensor + %50 = "tfl.custom"(%arg12, %arg8, %49) {custom_code = "FlexTensorListSetItem", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor>>, tensor, tensor) -> tensor>> + %51 = "tfl.gather"(%cst_5, %44) {axis = 0 : i32, batch_dims = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<5xi32>, tensor) -> tensor + %52 = tfl.add %arg9, %51 {fused_activation_function = "NONE", tac.device = "DARWINN", tac.inference_type = "FLOAT"} : tensor + %53 = "tfl.custom"(%arg13, %arg8, %52) {custom_code = "FlexTensorListSetItem", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor>>, tensor, tensor) -> tensor>> + %54 = "tfl.gather"(%cst_4, %44) {axis = 0 : i32, batch_dims = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<5xi32>, tensor) -> tensor + %55 = tfl.add %arg10, %54 {fused_activation_function = "NONE", tac.device = "DARWINN", tac.inference_type = "FLOAT"} : tensor + %56 = "tfl.custom"(%arg14, %arg8, %55) {custom_code = "FlexTensorListSetItem", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor>>, tensor, tensor) -> tensor>> + %57 = tfl.add %arg7, %cst_11 {fused_activation_function = "NONE", tac.device = "DARWINN", tac.inference_type = "FLOAT"} : tensor + "tfl.yield"(%57, %29, %52, %55, %47, %50, %53, %56) : (tensor, tensor, tensor, tensor, 
tensor>>, tensor>>, tensor>>, tensor>>) -> () + }) {tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor, tensor, tensor, tensor, tensor>>, tensor>>, tensor>>, tensor>>) -> (tensor, tensor, tensor, tensor, tensor>>, tensor>>, tensor>>, tensor>>) + %5 = "tfl.custom"(%4#4, %cst_1) {custom_code = "FlexTensorListStack", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor>>, tensor<1xi32>) -> tensor + %6 = "tfl.custom"(%5, %cst) {custom_code = "FlexTranspose", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor, tensor<2xi32>) -> tensor + %7 = "tfl.custom"(%4#5, %cst_1) {custom_code = "FlexTensorListStack", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor>>, tensor<1xi32>) -> tensor + %8 = "tfl.custom"(%7, %cst) {custom_code = "FlexTranspose", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor, tensor<2xi32>) -> tensor + return %6, %8 : tensor, tensor + } + +// CHECK: func.func @cond_false_72730(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor) -> (tensor, tensor) { +// CHECK: %cst = arith.constant dense<[1, 0]> : tensor<2xi32> +// CHECK: %cst_0 = arith.constant dense<0> : tensor +// CHECK: %cst_1 = arith.constant dense<-1> : tensor<1xi32> +// CHECK: %cst_2 = arith.constant dense<1> : tensor<1xi32> +// CHECK: %cst_3 = arith.constant dense<0> : tensor<1xi32> +// CHECK: %0 = call @func_0_DARWINN_FLOAT(%arg2, %cst_3, %cst_2) {tac.device = "DARWINN", tac.inference_type = "FLOAT", tac.interface_name = "func_0"} : (tensor, tensor<1xi32>, tensor<1xi32>) -> tensor +// CHECK: %1:2 = call @func_1_CPU_FLOAT(%cst_1, %0, %cst_0, %arg5, %arg6, %arg1, %arg0, %arg2, %cst) {tac.device = "CPU", tac.inference_type = "FLOAT", tac.interface_name = "func_1"} : (tensor<1xi32>, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor<2xi32>) -> (tensor, tensor) +// CHECK: return %1#0, %1#1 : tensor, tensor +// CHECK: } +// CHECK: func.func private @func_0_DARWINN_FLOAT(%arg0: tensor, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>) -> tensor attributes {tac.device = "DARWINN", tac.inference_type = "FLOAT", tac.interface_name = "func_0"} { +// CHECK: %0 = "tfl.shape"(%arg0) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor) -> tensor<2xi32> +// CHECK: %1 = "tfl.strided_slice"(%0, %arg1, %arg2, %arg2) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor +// CHECK: return %1 : tensor +// CHECK: } +// CHECK: func.func private @func_1_CPU_FLOAT(%arg0: tensor<1xi32>, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor<2xi32>) -> (tensor, tensor) attributes {tac.device = "CPU", tac.inference_type = "FLOAT", tac.interface_name = "func_1"} { +// CHECK: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "FlexTensorListReserve", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor<1xi32>, tensor) -> tensor>> +// CHECK: %1 = "tfl.custom"(%arg0, %arg1) {custom_code = "FlexTensorListReserve", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor<1xi32>, tensor) -> tensor>> +// CHECK: %2:8 = "tfl.while"(%arg2, %arg2, %arg3, %arg4, %0, %0, %1, %1) ({ +// CHECK: ^bb0(%arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: 
tensor, %arg13: tensor>>, %arg14: tensor>>, %arg15: tensor>>, %arg16: tensor>>): +// CHECK: %7 = func.call @func_2_DARWINN_FLOAT(%arg10, %arg1, %arg9) {tac.device = "DARWINN", tac.inference_type = "FLOAT", tac.interface_name = "func_2"} : (tensor, tensor, tensor) -> tensor +// CHECK: "tfl.yield"(%7) : (tensor) -> () +// CHECK: }, { +// CHECK: ^bb0(%arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor>>, %arg14: tensor>>, %arg15: tensor>>, %arg16: tensor>>): +// CHECK: %cst = arith.constant dense<[0, 0, 1, 1, 1]> : tensor<5xi32> +// CHECK: %cst_0 = arith.constant dense<[0, 1, 0, 1, 1]> : tensor<5xi32> +// CHECK: %cst_1 = arith.constant dense<2> : tensor +// CHECK: %cst_2 = arith.constant dense<"*"> : tensor +// CHECK: %cst_3 = arith.constant dense<-1> : tensor<1xi32> +// CHECK: %cst_4 = arith.constant dense<-1> : tensor +// CHECK: %cst_5 = arith.constant dense<2> : tensor<1xi32> +// CHECK: %cst_6 = arith.constant dense<1> : tensor +// CHECK: %cst_7 = arith.constant dense<0> : tensor +// CHECK: %cst_8 = arith.constant dense<1> : tensor<1xi32> +// CHECK: %cst_9 = arith.constant dense<0> : tensor<1xi32> +// CHECK: %7:2 = func.call @func_3_DARWINN_FLOAT(%arg5, %cst_9, %cst_8, %cst_7, %cst_6, %cst_5) {tac.device = "DARWINN", tac.inference_type = "FLOAT", tac.interface_name = "func_3"} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor, tensor, tensor<1xi32>) -> (tensor, tensor<2xi32>) +// CHECK: %8 = "tfl.reduce_prod"(%7#1, %cst_9) {keep_dims = true, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>) -> tensor<1xi32> +// CHECK: %9:3 = func.call @func_4_DARWINN_FLOAT(%arg5, %8, %arg6, %cst_9, %cst_8, %cst_7, %cst_6, %cst_5) {tac.device = "DARWINN", tac.inference_type = "FLOAT", tac.interface_name = "func_4"} : (tensor, tensor<1xi32>, tensor, tensor<1xi32>, tensor<1xi32>, tensor, tensor, tensor<1xi32>) -> (tensor, tensor, tensor<2xi32>) +// CHECK: %10 = "tfl.reduce_prod"(%9#2, %cst_9) {keep_dims = true, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>) -> tensor<1xi32> +// CHECK: %11 = "tfl.expand_dims"(%arg11, %cst_4) {tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor +// CHECK: %12 = "tfl.expand_dims"(%arg12, %cst_4) {tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor +// CHECK: %13:7 = func.call @func_5_DARWINN_FLOAT(%arg6, %10, %arg10, %cst_6, %11, %9#1, %cst_3, %cst_2, %12, %7#0, %9#0, %arg7, %cst_1) {tac.device = "DARWINN", tac.inference_type = "FLOAT", tac.interface_name = "func_5"} : (tensor, tensor<1xi32>, tensor, tensor, tensor, tensor, tensor<1xi32>, tensor, tensor, tensor, tensor, tensor, tensor) -> (tensor, tensor, tensor, tensor, tensor, tensor, tensor) +// CHECK: %14 = "tfl.custom"(%13#6, %13#2, %13#1) {custom_code = "FlexSelect", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor, tensor, tensor) -> tensor +// CHECK: %15 = "tfl.custom"(%arg13, %arg10, %14) {custom_code = "FlexTensorListSetItem", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor>>, tensor, tensor) -> tensor>> +// CHECK: %16 = func.call @func_6_DARWINN_FLOAT(%13#5, %cst_6) {tac.device = "DARWINN", tac.inference_type = "FLOAT", tac.interface_name = "func_6"} : (tensor, tensor) -> tensor +// CHECK: %17 = "tfl.custom"(%16, %13#4, %13#3) {custom_code = "FlexSelect", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor, tensor, tensor) -> tensor +// CHECK: %18 = "tfl.custom"(%arg14, %arg10, 
%17) {custom_code = "FlexTensorListSetItem", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor>>, tensor, tensor) -> tensor>> +// CHECK: %19 = func.call @func_7_DARWINN_FLOAT(%cst_0, %13#5, %arg11) {tac.device = "DARWINN", tac.inference_type = "FLOAT", tac.interface_name = "func_7"} : (tensor<5xi32>, tensor, tensor) -> tensor +// CHECK: %20 = "tfl.custom"(%arg15, %arg10, %19) {custom_code = "FlexTensorListSetItem", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor>>, tensor, tensor) -> tensor>> +// CHECK: %21 = func.call @func_8_DARWINN_FLOAT(%cst, %13#5, %arg12) {tac.device = "DARWINN", tac.inference_type = "FLOAT", tac.interface_name = "func_8"} : (tensor<5xi32>, tensor, tensor) -> tensor +// CHECK: %22 = "tfl.custom"(%arg16, %arg10, %21) {custom_code = "FlexTensorListSetItem", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor>>, tensor, tensor) -> tensor>> +// CHECK: %23 = func.call @func_9_DARWINN_FLOAT(%arg9, %cst_6) {tac.device = "DARWINN", tac.inference_type = "FLOAT", tac.interface_name = "func_9"} : (tensor, tensor) -> tensor +// CHECK: "tfl.yield"(%23, %13#0, %19, %21, %15, %18, %20, %22) : (tensor, tensor, tensor, tensor, tensor>>, tensor>>, tensor>>, tensor>>) -> () +// CHECK: }) {tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor, tensor, tensor, tensor, tensor>>, tensor>>, tensor>>, tensor>>) -> (tensor, tensor, tensor, tensor, tensor>>, tensor>>, tensor>>, tensor>>) +// CHECK: %3 = "tfl.custom"(%2#4, %arg0) {custom_code = "FlexTensorListStack", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor>>, tensor<1xi32>) -> tensor +// CHECK: %4 = "tfl.custom"(%3, %arg8) {custom_code = "FlexTranspose", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor, tensor<2xi32>) -> tensor +// CHECK: %5 = "tfl.custom"(%2#5, %arg0) {custom_code = "FlexTensorListStack", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor>>, tensor<1xi32>) -> tensor +// CHECK: %6 = "tfl.custom"(%5, %arg8) {custom_code = "FlexTranspose", custom_option = #tfl, tac.device = "CPU", tac.inference_type = "FLOAT"} : (tensor, tensor<2xi32>) -> tensor +// CHECK: return %4, %6 : tensor, tensor +// CHECK: } +// CHECK: func.func private @func_2_DARWINN_FLOAT(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tac.device = "DARWINN", tac.inference_type = "FLOAT", tac.interface_name = "func_2"} { +// CHECK: %0 = tfl.less(%arg0, %arg1) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor +// CHECK: %1 = tfl.less(%arg2, %arg1) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor +// CHECK: %2 = tfl.logical_and %1, %0 {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : tensor +// CHECK: return %2 : tensor +// CHECK: } +// CHECK: func.func private @func_3_DARWINN_FLOAT(%arg0: tensor, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor, %arg4: tensor, %arg5: tensor<1xi32>) -> (tensor, tensor<2xi32>) attributes {tac.device = "DARWINN", tac.inference_type = "FLOAT", tac.interface_name = "func_3"} { +// CHECK: %0 = "tfl.shape"(%arg0) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor) -> tensor<2xi32> +// CHECK: %1 = "tfl.strided_slice"(%0, %arg1, %arg2, %arg2) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : 
(tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor +// CHECK: %2 = "tfl.range"(%arg3, %1, %arg4) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor, tensor) -> tensor +// CHECK: %3 = "tfl.pack"(%1, %arg4) {axis = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT", values_count = 2 : i32} : (tensor, tensor) -> tensor<2xi32> +// CHECK: %4 = "tfl.strided_slice"(%0, %arg2, %arg5, %arg2) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor +// CHECK: %5 = tfl.mul(%2, %4) {fused_activation_function = "NONE", tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor +// CHECK: %6 = "tfl.reshape"(%5, %3) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor<2xi32>) -> tensor +// CHECK: %7 = "tfl.strided_slice"(%0, %arg1, %arg5, %arg2) {begin_mask = 1 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> +// CHECK: return %6, %7 : tensor, tensor<2xi32> +// CHECK: } +// CHECK: func.func private @func_4_DARWINN_FLOAT(%arg0: tensor, %arg1: tensor<1xi32>, %arg2: tensor, %arg3: tensor<1xi32>, %arg4: tensor<1xi32>, %arg5: tensor, %arg6: tensor, %arg7: tensor<1xi32>) -> (tensor, tensor, tensor<2xi32>) attributes {tac.device = "DARWINN", tac.inference_type = "FLOAT", tac.interface_name = "func_4"} { +// CHECK: %0 = "tfl.reshape"(%arg0, %arg1) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor<1xi32>) -> tensor +// CHECK: %1 = "tfl.shape"(%arg2) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor) -> tensor<2xi32> +// CHECK: %2 = "tfl.strided_slice"(%1, %arg3, %arg4, %arg4) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor +// CHECK: %3 = "tfl.range"(%arg5, %2, %arg6) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor, tensor) -> tensor +// CHECK: %4 = "tfl.pack"(%2, %arg6) {axis = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT", values_count = 2 : i32} : (tensor, tensor) -> tensor<2xi32> +// CHECK: %5 = "tfl.strided_slice"(%1, %arg4, %arg7, %arg4) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor +// CHECK: %6 = tfl.mul(%3, %5) {fused_activation_function = "NONE", tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor +// CHECK: %7 = "tfl.reshape"(%6, %4) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor<2xi32>) -> tensor +// CHECK: %8 = "tfl.strided_slice"(%1, %arg3, %arg7, %arg4) {begin_mask = 1 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> +// CHECK: return %0, %7, %8 : tensor, tensor, tensor<2xi32> +// CHECK: } +// CHECK: func.func 
private @func_5_DARWINN_FLOAT(%arg0: tensor, %arg1: tensor<1xi32>, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor<1xi32>, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor) -> (tensor, tensor, tensor, tensor, tensor, tensor, tensor) attributes {tac.device = "DARWINN", tac.inference_type = "FLOAT", tac.interface_name = "func_5"} { +// CHECK: %0 = "tfl.reshape"(%arg0, %arg1) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor<1xi32>) -> tensor +// CHECK: %1 = tfl.add %arg2, %arg3 {fused_activation_function = "NONE", tac.device = "DARWINN", tac.inference_type = "FLOAT"} : tensor +// CHECK: %2 = tfl.add %arg4, %arg5 {fused_activation_function = "NONE", tac.device = "DARWINN", tac.inference_type = "FLOAT"} : tensor +// CHECK: %3 = "tfl.reshape"(%2, %arg6) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor<1xi32>) -> tensor +// CHECK: %4 = "tfl.gather"(%0, %3) {axis = 0 : i32, batch_dims = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor +// CHECK: %5 = "tfl.reshape"(%4, %arg6) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor<1xi32>) -> tensor +// CHECK: %6 = "tfl.shape"(%5) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor) -> tensor<1xi32> +// CHECK: %7 = "tfl.fill"(%6, %arg7) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<1xi32>, tensor) -> tensor +// CHECK: %8 = tfl.add %arg8, %arg9 {fused_activation_function = "NONE", tac.device = "DARWINN", tac.inference_type = "FLOAT"} : tensor +// CHECK: %9 = "tfl.reshape"(%8, %arg6) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor<1xi32>) -> tensor +// CHECK: %10 = "tfl.gather"(%arg10, %9) {axis = 0 : i32, batch_dims = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor +// CHECK: %11 = "tfl.reshape"(%10, %arg6) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor<1xi32>) -> tensor +// CHECK: %12 = "tfl.shape"(%11) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor) -> tensor<1xi32> +// CHECK: %13 = "tfl.fill"(%12, %arg7) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<1xi32>, tensor) -> tensor +// CHECK: %14 = "tfl.gather"(%arg11, %arg2) {axis = 0 : i32, batch_dims = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor +// CHECK: %15 = "tfl.equal"(%14, %arg12) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor +// CHECK: return %1, %5, %7, %11, %13, %14, %15 : tensor, tensor, tensor, tensor, tensor, tensor, tensor +// CHECK: } +// CHECK: func.func private @func_6_DARWINN_FLOAT(%arg0: tensor, %arg1: tensor) -> tensor attributes {tac.device = "DARWINN", tac.inference_type = "FLOAT", tac.interface_name = "func_6"} { +// CHECK: %0 = "tfl.equal"(%arg0, %arg1) {tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor, tensor) -> tensor +// CHECK: return %0 : tensor +// CHECK: } +// CHECK: func.func private @func_7_DARWINN_FLOAT(%arg0: tensor<5xi32>, %arg1: tensor, %arg2: tensor) -> tensor attributes {tac.device = "DARWINN", tac.inference_type = "FLOAT", tac.interface_name = "func_7"} { +// CHECK: %0 = "tfl.gather"(%arg0, %arg1) {axis = 0 : i32, batch_dims = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<5xi32>, tensor) -> tensor +// CHECK: %1 = tfl.add %arg2, %0 {fused_activation_function = "NONE", tac.device = "DARWINN", 
tac.inference_type = "FLOAT"} : tensor +// CHECK: return %1 : tensor +// CHECK: } +// CHECK: func.func private @func_8_DARWINN_FLOAT(%arg0: tensor<5xi32>, %arg1: tensor, %arg2: tensor) -> tensor attributes {tac.device = "DARWINN", tac.inference_type = "FLOAT", tac.interface_name = "func_8"} { +// CHECK: %0 = "tfl.gather"(%arg0, %arg1) {axis = 0 : i32, batch_dims = 0 : i32, tac.device = "DARWINN", tac.inference_type = "FLOAT"} : (tensor<5xi32>, tensor) -> tensor +// CHECK: %1 = tfl.add %arg2, %0 {fused_activation_function = "NONE", tac.device = "DARWINN", tac.inference_type = "FLOAT"} : tensor +// CHECK: return %1 : tensor +// CHECK: } +// CHECK: func.func private @func_9_DARWINN_FLOAT(%arg0: tensor, %arg1: tensor) -> tensor attributes {tac.device = "DARWINN", tac.inference_type = "FLOAT", tac.interface_name = "func_9"} { +// CHECK: %0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE", tac.device = "DARWINN", tac.inference_type = "FLOAT"} : tensor +// CHECK: return %0 : tensor +// CHECK: } \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/tflite_import_export.cc b/tensorflow/compiler/mlir/lite/experimental/tac/tflite_import_export.cc index b227a8e35d4..e579b60869b 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/tflite_import_export.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/tflite_import_export.cc @@ -75,8 +75,10 @@ void AttachCostPerDevice(mlir::ModuleOp module, absl::StatusOr> TfLiteImporter::Import() { source_mgr_handler_ = std::make_unique( source_mgr_, &context_); - return ImportFlatbufferOrMlir(options_.file_name, options_.input_mlir, - &source_mgr_, &context_); + return ImportFlatbufferOrMlir( + options_.file_name, options_.input_mlir, + /*experimental_prune_unreachable_nodes_unconditionally=*/true, + &source_mgr_, &context_); } //////////// Exporter //////////// @@ -107,7 +109,8 @@ absl::Status TfLiteExporter::Export(mlir::ModuleOp module) { } return mlir::TFL::tac::ExportFlatbufferOrMlir(options_.output_file_name, - options_.output_mlir, module); + options_.output_mlir, module, + /*enable_select_tf_ops=*/false); } } // namespace tac diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/compute_cost.cc b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/compute_cost.cc index 92ccb17f895..f399af7da68 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/compute_cost.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/compute_cost.cc @@ -96,7 +96,7 @@ void ComputeCostPass::runOnOperation() { signalPassFailure(); } - float total_cost = GetCostForFunc(&func, target.getValue()); + float total_cost = GetCostForFunc(&func, *target); OpBuilder builder(func); UpdateCost(func, total_cost, &builder); } diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/cost_model.cc b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/cost_model.cc index a66fc48c559..15fb7e66477 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/cost_model.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/cost_model.cc @@ -106,7 +106,7 @@ void GetOpCostPass::runOnOperation() { !llvm::isa(op)) { auto hardware = GetTargetAnnotation(op); if (!hardware) return; - float cost = GetCostForOp(op, hardware.getValue()); + float cost = GetCostForOp(op, *hardware); UpdateCost(op, cost, &builder); } }); diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/fold_constants_to_subgraph.cc 
b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/fold_constants_to_subgraph.cc index c15bf2d2be4..b6c544a8f69 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/fold_constants_to_subgraph.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/fold_constants_to_subgraph.cc @@ -23,9 +23,11 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project @@ -87,8 +89,9 @@ class FoldConstantsToSubgraphPass void CopyConstantIntoFunc(int argument_index, Operation* const_op, func::FuncOp func) { - assert((llvm::isa(const_op)) && - "Expect QConst or Const op."); + assert( + (llvm::isa(const_op)) && + "Expect QConst or Const op."); OpBuilder builder(func.getBody()); auto cloned_const_op = const_op->clone(); cloned_const_op->setLoc(func.getBody().getLoc()); @@ -99,13 +102,16 @@ void CopyConstantIntoFunc(int argument_index, Operation* const_op, } bool IsConstOrQConstInt(Operation* op) { - if (!llvm::isa(op)) return false; + if (!llvm::isa(op)) + return false; - if (auto const_op = dyn_cast_or_null(op)) { + if (auto arith_const_op = dyn_cast_or_null(op)) { + // arith ConstOp path. + auto type = arith_const_op.getType().cast().getElementType(); + if (!type.isInteger(32) && !type.isInteger(64)) return false; + } else if (auto const_op = dyn_cast_or_null(op)) { // ConstOp path. - auto type = const_op.getType() - .dyn_cast_or_null() - .getElementType(); + auto type = const_op.getType().cast().getElementType(); if (!type.isInteger(32) && !type.isInteger(64)) return false; } else { // QConstOp path. @@ -124,7 +130,8 @@ void FoldConstantsToSubgraphPass::runOnOperation() { for (auto fn : module.getOps()) { fn.walk([&](Operation* op) { - if (!llvm::isa(op)) return; + if (!llvm::isa(op)) + return; // We only fold int32/int64 for Const and i32 for QConst if not specify // all constants flag. (Since they're more like "configs" or i32 biases.) diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/get_alternative_subgraph.cc b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/get_alternative_subgraph.cc index ba79879aa45..42852425741 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/get_alternative_subgraph.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/get_alternative_subgraph.cc @@ -167,11 +167,10 @@ void AlternativeSubgraphPass::GetAlternativeGraphForFunc( } const InferenceDeviceType current_device_type( - {current_device.getValue(), current_inference_type.getValue()}); + {*current_device, *current_inference_type}); const std::vector& all_inference_device_type = - GetAllAlternativeInferenceDeviceType(current_inference_type.getValue(), - devices); + GetAllAlternativeInferenceDeviceType(*current_inference_type, devices); for (const auto& device_inference_type : all_inference_device_type) { if (device_inference_type != current_device_type) { @@ -183,7 +182,8 @@ void AlternativeSubgraphPass::GetAlternativeGraphForFunc( // see if we need to erase the func op. 
// Ideally it would be nice if we can utilize dynamic illegal op to do // the job. - if (!IsAllSupportedbySpec(cloned_func, device_inference_type)) { + if (device_inference_type.hardware != "CPU" && + !IsAllSupportedbySpec(cloned_func, device_inference_type)) { cloned_func.erase(); } } @@ -193,7 +193,7 @@ void AlternativeSubgraphPass::GetAlternativeGraphForFunc( // We need to run the optimization for the current device last because we // need to avoid any changes made the current graph polluting other // alternative graph views. - Optimize(func, current_device.getValue()); + Optimize(func, *current_device); } bool AlternativeSubgraphPass::IsAllSupportedbySpec( @@ -237,8 +237,8 @@ func::FuncOp AlternativeSubgraphPass::GetAlternativeViewForSpec( cloned_func->setAttr(kInferenceType, builder->getStringAttr(GetInferenceString( target_device_inference_type.inference_type))); - std::string new_function_name = GetFunctionImplName( - interface_name.getValue(), target_device_inference_type); + std::string new_function_name = + GetFunctionImplName(*interface_name, target_device_inference_type); cloned_func.setName(new_function_name); // If it's quantized -> float, we need to wrap all the ops around with dequant diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/pick_subgraphs.cc b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/pick_subgraphs.cc index bd7f43d36ca..58940205edf 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/pick_subgraphs.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/pick_subgraphs.cc @@ -318,7 +318,7 @@ void PickSubgraphsPass::BuildSubgraphs( // Build the subgraph. Subgraph subgraph; subgraph.call = call_op; - auto impl_iter = func_impls.find(interface_name.getValue()); + auto impl_iter = func_impls.find(*interface_name); if (impl_iter == func_impls.end()) { call_op.emitError( "we cannot find corresponding implementation for this call op"); @@ -331,8 +331,7 @@ void PickSubgraphsPass::BuildSubgraphs( impl.emitError("we cannot find inference device type for this func"); signalPassFailure(); } - subgraph.available_choices.emplace(inference_device_type.getValue(), - impl); + subgraph.available_choices.emplace(*inference_device_type, impl); } // Insert in the subgraphs. @@ -352,11 +351,10 @@ PickSubgraphsPass::CollectSubgraphFuncs(ModuleOp module) { for (auto func : module.getOps()) { auto interface_name = GetInterFaceName(func); if (interface_name.has_value()) { - auto impls_iter = func_impls.find(interface_name.getValue()); + auto impls_iter = func_impls.find(*interface_name); if (impls_iter == func_impls.end()) impls_iter = - func_impls - .emplace(interface_name.getValue(), std::vector()) + func_impls.emplace(*interface_name, std::vector()) .first; impls_iter->second.push_back(func); } @@ -424,12 +422,11 @@ void PickSubgraphsPass::RewireSubgraphs( const InferenceDeviceType& preferred_inference_device_type = kv.second; // We need to rewire the call. - std::string interface_name = GetInterFaceName(call).getValue(); + std::string interface_name = *GetInterFaceName(call); for (auto impl : collected_impl_funcs.find(interface_name)->second) { const auto& impl_inference_device_type = GetInferenceDeviceTypeForOp(impl); - if (impl_inference_device_type.getValue() == - preferred_inference_device_type) { + if (*impl_inference_device_type == preferred_inference_device_type) { if (call.getCallee() != impl.getName()) { // We need to rebuild the call op. 
:( builder->setInsertionPoint(call); diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/raise_target_subgraphs.cc b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/raise_target_subgraphs.cc index 6fb2fcaeeab..457487c3f2f 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/raise_target_subgraphs.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/raise_target_subgraphs.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include #include @@ -33,10 +34,13 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/experimental/common/outline_operations.h" #include "tensorflow/compiler/mlir/lite/experimental/tac/common/subgraph.h" #include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h" #include "tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h" @@ -48,51 +52,11 @@ namespace TFL { namespace tac { namespace { -// Subgraph here is actually an intermediate data structure holder for the ops: -// The ops within share the same "target", they're topologically sorted. -// The subgraph here will be later populated to generate func ops. -// All the subgraphs should not create cyclic dependencies: -// So we should not have: -// subgraph1 -// \ -// subgraph2 -// / -// subgraph1 -struct Subgraph { - // All ops must be inserted in it's topological order. - llvm::SetVector all_ops; - int subgraph_id; - InferenceDeviceType inference_device_type; -}; - -// This will exclude arguments & consts & quantize/dequantize ops. -inline bool IsNonConstQuantizeOp(Operation* op) { - return IsNonConstOp(op) && NotTFLQuantDequantizeOp(op) && !IsTerminatorOp(op); -} +using ::mlir::TFL::common::OpsAdded; +using ::mlir::TFL::common::Subgraph; -// This pass will group those ops (non-const TFL dialect ops) have the same -// target together and raise them as FuncOps. 
-// See the following Example: -// -// op1 (GPU) -// \ op2 (GPU) -// \ | -// \ op3 (GPU) -// \ / -// op4 (CPU) -// -// will be raised as 3 subgraphs: -// Subgraph 1: {op1}, GPU -> Func_1_GPU -// Subgraph 2: {op2, op3}, GPU -> Func_2_GPU -// Subgraph 3: {op4} CPU -> Func_3_CPU -// -// MainFunc: -// %0 = call @Func_1_GPU -// %1 = call @Func_2_GPU -// %2 = call @Func_3_CPU(%0, %1) class RaiseTargetSubgraphsPass - : public mlir::PassWrapper> { + : public PassWrapper> { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RaiseTargetSubgraphsPass) @@ -106,272 +70,148 @@ class RaiseTargetSubgraphsPass } void runOnOperation() override; - void RaiseTargetSubgraphsForBlock(Block* block, OpBuilder* builder, - ModuleOp module); - - void ExtractSubgraphToFunc(Subgraph* subgraph, OpBuilder* builder, - ModuleOp module); - - func::FuncOp BuildFuncOp(Subgraph* subgraph, OpBuilder* builder, - ModuleOp module_op, SmallVector* inputs, - SmallVector* outputs, - InferenceDeviceType* inference_device_type); - - int subgraph_count_ = 0; + void RaiseTargetSubgraphsForBlock(Block& block, OpBuilder& builder, + ModuleOp module, bool skip_cpu, + int& func_count); }; -// This is to collect input arguments for the given set of ops. -// See the example: -// -// value1 value2 -// \ / -// op1 -// \ value3 -// \ / -// op2 -// | -// op3 -// -// Then the arguments will be {value1, value2, value3} -void CollectInputs(const llvm::SetVector& all_ops, - SmallVector* inputs) { - for (Operation* op : all_ops) { - for (Value input : op->getOperands()) { - Operation* input_op = input.getDefiningOp(); - const bool input_within_subgraph = - (input_op && all_ops.count(input_op) == 1); - if (!input_within_subgraph) { - inputs->push_back(input); - } - } - } +// After raising ops and adding the Func & Call op, call this function +// to set attributes specific to this pass. +void AddAttrs(OpsAdded& ops_added, OpBuilder& builder, int func_count) { + func::FuncOp& added_func_op = ops_added.func_op; + func::CallOp& added_call_op = ops_added.call_op; + StringAttr interface_name = + builder.getStringAttr(absl::StrCat("func_", func_count)); + + added_func_op->setAttr(kInterfaceNameAttr, interface_name); + added_call_op->setAttr(kInterfaceNameAttr, interface_name); + + StringAttr device = added_func_op->getRegion(0) + .getBlocks() + .front() + .front() + .getAttr(kDevice) + .cast(); + StringAttr inference_type = added_func_op->getRegion(0) + .getBlocks() + .front() + .front() + .getAttr(kInferenceType) + .cast(); + added_call_op->setAttr(kDevice, device); + added_call_op->setAttr(kInferenceType, inference_type); + added_func_op->setAttr(kDevice, device); + added_func_op->setAttr(kInferenceType, inference_type); + + std::string function_name = absl::StrCat(interface_name.getValue().str(), "_", + device.getValue().str(), "_", + inference_type.getValue().str()); + added_func_op.setName(builder.getStringAttr(function_name)); + added_call_op.setCallee(builder.getStringAttr(function_name)); } -// This is to collect outputs arguments for the given set of ops. 
-// See the example:
-//
-// op1
-//   /  \
-// value1 \
-//         op2
-//        |   \
-//       op3   value2
-//        |
-//      value3
-//
-// Then the arguments will be {value1, value2, value3}
-void CollectOutputs(const llvm::SetVector& all_ops,
-                    SmallVector* outputs) {
-  for (Operation* op : all_ops) {
-    for (Value output : op->getResults()) {
-      bool output_consumed_outside_subgraph = false;
-      for (Operation* consumer : output.getUsers()) {
-        if (all_ops.count(consumer) == 0) {
-          output_consumed_outside_subgraph = true;
+// Raises partitioned sequential `Operations` from a block to a new function
+// definition. `Operations` are partitioned into classes from the Cartesian
+// product of possible devices and inference datatypes. For example, we might
+// raise a chunk of sequential operations from a block, all having the
+// attributes `{ tac.device = "GPU", tac.inference_type = "FLOAT"}`, to a
+// function with the matching attributes. It is assumed that device type "CPU"
+// is the only device that is allowed to call other devices, i.e. ancestors of
+// a "CPU" `Operation` may only be `Operations` without a device or other "CPU"
+// `Operations`. This implies that "CPU" ops may contain subgraphs of different
+// device types, which also need to be raised.
+void RaiseTargetSubgraphsPass::RaiseTargetSubgraphsForBlock(Block& block,
+                                                            OpBuilder& builder,
+                                                            ModuleOp module,
+                                                            bool skip_cpu,
+                                                            int& func_count) {
+  llvm::SetVector<Operation*> partition_ops;
+
+  auto device_is = [&](InferenceDeviceType t, llvm::StringRef device) -> bool {
+    return t.hardware == device;
+  };
+
+  auto op_has_device = [&](Operation& op, InferenceDeviceType& device) -> bool {
+    Optional<InferenceDeviceType> op_device = GetInferenceDeviceTypeForOp(&op);
+    if (!op_device.has_value()) return false;
+    device = op_device.value();
+    return true;
+  };
+
+  auto op_device_is = [&](Operation& op, llvm::StringRef device) -> bool {
+    InferenceDeviceType device_type;
+    if (!op_has_device(op, device_type)) return false;
+    return device_is(device_type, device);
+  };
+
+  // Given a list of `Operation`s to partition, raise them to a new
+  // function. If the partition is of type "CPU" then it may contain
+  // other device subgraphs that need to be raised. We recur on
+  // any nested blocks of "CPU" ops and skip raising "CPU" ops for the
+  // remainder of that recursive call.
+  auto extract = [&](llvm::SetVector<Operation*>& partition_ops) -> void {
+    if (partition_ops.empty()) return;
+    InferenceDeviceType device =
+        GetInferenceDeviceTypeForOp(partition_ops.front()).value();
+    Subgraph old_subgraph(partition_ops, ++func_count);
+    OpsAdded ops_added;
+    ExtractSubgraphToFunc(old_subgraph, builder, module, ops_added);
+    AddAttrs(ops_added, builder, func_count);
+    // Ops in "CPU" subgraphs may contain nested regions with other device
+    // subgraphs. We recur into these nested blocks to raise those as well. We
+    // don't raise "CPU" ops that are themselves nested within a "CPU" op, so
+    // set `skip_cpu` to true.
+    if (device_is(device, "CPU")) {
+      for (auto& block : ops_added.func_op->getRegion(0).getBlocks())
+        for (auto& op : block) {
+          auto op_device = GetInferenceDeviceTypeForOp(&op);
+          if (op_device_is(op, "CPU"))
+            // The recently raised func has device type "CPU" and `op` is a
+            // "CPU" op. Recursively call again to raise any non-"CPU"
+            // subgraphs contained within nested regions of `op`.
+ for (auto& region : op.getRegions()) + for (auto& block : region.getBlocks()) + RaiseTargetSubgraphsForBlock(block, builder, module, + /*skip_cpu=*/true, func_count); } - } - if (output_consumed_outside_subgraph) { - outputs->push_back(output); - } } - } -} - -void BuildTypes(const SmallVector& values, - SmallVector* types) { - for (auto value : values) { - types->push_back(value.getType()); - } -} - -void GetFunctionName(const Subgraph& subgrpah, std::string* function_name, - std::string* interface_name) { - *interface_name = absl::StrCat("func_", std::to_string(subgrpah.subgraph_id)); - *function_name = absl::StrCat( - (*interface_name), "_", subgrpah.inference_device_type.hardware, "_", - GetInferenceString(subgrpah.inference_device_type.inference_type)); -} - -func::FuncOp RaiseTargetSubgraphsPass::BuildFuncOp( - Subgraph* subgraph, OpBuilder* builder, ModuleOp module_op, - SmallVector* inputs, SmallVector* outputs, - InferenceDeviceType* inference_device_type) { - CollectInputs(subgraph->all_ops, inputs); - CollectOutputs(subgraph->all_ops, outputs); - - SmallVector input_types; - SmallVector return_types; - - BuildTypes(*inputs, &input_types); - BuildTypes(*outputs, &return_types); - - FunctionType function_type = - builder->getFunctionType(input_types, return_types); - - SmallVector attrs; - // Function name. - std::string function_name; - std::string interface_name; - GetFunctionName(*subgraph, &function_name, &interface_name); - attrs.push_back(builder->getNamedAttr( - kInterfaceNameAttr, builder->getStringAttr(interface_name))); - - // Inference Device type. - attrs.push_back(builder->getNamedAttr( - kDevice, - builder->getStringAttr(subgraph->inference_device_type.hardware))); - attrs.push_back(builder->getNamedAttr( - kInferenceType, builder->getStringAttr(GetInferenceString( - subgraph->inference_device_type.inference_type)))); - *inference_device_type = subgraph->inference_device_type; - - func::FuncOp new_func = - func::FuncOp::create(builder->getUnknownLoc(), function_name, - function_type, llvm::makeArrayRef(attrs)); - new_func.setPrivate(); - - new_func.addEntryBlock(); - - // Function argument mapping. - llvm::DenseMap function_argument_mapping; - for (int i = 0; i < inputs->size(); ++i) { - function_argument_mapping.insert({(*inputs)[i], i}); - } - - OpBuilder function_builder(new_func.getBody()); - - llvm::DenseMap op_cloned_op_mapping; - llvm::DenseMap output_cloned_op_output_mapping; - for (Operation* op : subgraph->all_ops) { - Operation* cloned_op = function_builder.clone(*op); - op_cloned_op_mapping.insert({op, cloned_op}); - for (int i = 0; i < op->getNumResults(); ++i) { - Value op_output = op->getResult(i); - Value cloned_op_output = cloned_op->getResult(i); - output_cloned_op_output_mapping.insert({op_output, cloned_op_output}); - } - } - - for (Operation* op : subgraph->all_ops) { - Operation* cloned_op = op_cloned_op_mapping.find(op)->second; - for (int i = 0; i < op->getNumOperands(); ++i) { - Value input = op->getOperand(i); - Value cloned_op_input; - // If the input is actually a function argument. - if (function_argument_mapping.count(input) > 0) { - int function_argument = function_argument_mapping.find(input)->second; - cloned_op_input = new_func.getArgument(function_argument); - } else { - // The input is actually with in the subgraph. - cloned_op_input = output_cloned_op_output_mapping.find(input)->second; + partition_ops.clear(); + }; + + // Given a block, partition into lists of similar `Operations` as described. 
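+  // Illustrative sketch (attribute values are hypothetical): a block whose ops
+  // are annotated, in order, GPU, GPU, CPU, GPU is split into three partitions
+  // and raised into functions named roughly @func_0_GPU_FLOAT,
+  // @func_1_CPU_FLOAT and @func_2_GPU_FLOAT (see AddAttrs above), leaving the
+  // block as a sequence of func.call ops. Ops that are not raised only force
+  // the current partition to be extracted when they consume one of its
+  // results.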
+ Optional current_device_type = std::nullopt; + for (Operation& current_op : block) { + auto next_device_type = GetInferenceDeviceTypeForOp(¤t_op); + if (!next_device_type.has_value() || + (skip_cpu && device_is(next_device_type.value(), "CPU"))) { + // If we aren't raising this op, we only need to raise the current + // partition if this op depends on one the the partitioned ops results. + for (Value operand : current_op.getOperands()) { + if (partition_ops.contains(operand.getDefiningOp())) + extract(partition_ops); } - cloned_op->setOperand(i, cloned_op_input); + continue; } - } - - SmallVector final_outputs; - for (auto output : *outputs) { - auto cloned_output = output_cloned_op_output_mapping.find(output)->second; - final_outputs.push_back(cloned_output); - } - function_builder.create(new_func.getLoc(), - final_outputs); - - module_op.push_back(new_func); - return new_func; -} - -void RaiseTargetSubgraphsPass::ExtractSubgraphToFunc(Subgraph* subgraph, - OpBuilder* builder, - ModuleOp module) { - SmallVector func_inputs; - SmallVector func_outputs; - - InferenceDeviceType inference_device_type; - func::FuncOp func = BuildFuncOp(subgraph, builder, module, &func_inputs, - &func_outputs, &inference_device_type); - - // We just use the location of the last ops in the subgraph as the location - // for the call_op. - Operation* last_output = subgraph->all_ops.back(); - - // TODO(renjieliu): we should add func attributes to the call op. - builder->setInsertionPoint(last_output); - auto call_op = - builder->create(last_output->getLoc(), func, func_inputs); - - auto interface_name = GetInterFaceName(func); - - // Set call op attribute: interface_name, hardware. - call_op->setAttr(kInterfaceNameAttr, - builder->getStringAttr(interface_name.getValue())); - call_op->setAttr(kDevice, - builder->getStringAttr(inference_device_type.hardware)); - call_op->setAttr(kInferenceType, builder->getStringAttr(GetInferenceString( - inference_device_type.inference_type))); - - // Rewire the outputs. - if (call_op.getNumResults() != func_outputs.size()) { - module.emitError("the constructed func op has mismatched returns"); - signalPassFailure(); - } - - for (int i = 0; i < func_outputs.size(); ++i) { - Value output = func_outputs[i]; - output.replaceAllUsesWith(call_op.getResult(i)); - } - - // Clear the subgraph. - // Those ops should be removed. - for (auto* op : subgraph->all_ops) { - op->dropAllDefinedValueUses(); - op->dropAllReferences(); - op->erase(); - } -} - -// TODO(renjieliu): We may need to consider about side effect ops: we may leave -// those ops alone when building the subgraph. -void RaiseTargetSubgraphsPass::RaiseTargetSubgraphsForBlock(Block* block, - OpBuilder* builder, - ModuleOp module) { - // This is a very naive implementation: - // It will greedily group adjacent ops that have the same inference type to a - // subgraph. - llvm::DenseMap all_subgraphs; - llvm::Optional previous_device_type = llvm::None; - int current_subgraph_id = -1; - for (auto& op : *block) { - if (IsNonConstQuantizeOp(&op) && !IsTerminatorOp(&op) && - !llvm::isa(op)) { - auto current_device_type = GetInferenceDeviceTypeForOp(&op); - if (!(current_device_type.has_value() && - current_device_type == previous_device_type)) { - // We should start a new subgraph. 
- Subgraph new_subgraph; - new_subgraph.inference_device_type = current_device_type.getValue(); - new_subgraph.subgraph_id = subgraph_count_++; - all_subgraphs.insert({new_subgraph.subgraph_id, new_subgraph}); - current_subgraph_id = new_subgraph.subgraph_id; - } - previous_device_type = current_device_type; - all_subgraphs.find(current_subgraph_id)->second.all_ops.insert(&op); + if (next_device_type == current_device_type) { + partition_ops.insert(¤t_op); + continue; } + extract(partition_ops); + partition_ops.insert(¤t_op); + current_device_type = next_device_type; } - - // Create FuncOp & replace with current uses based on those subgraphs. - for (auto& subgraph : all_subgraphs) { - ExtractSubgraphToFunc(&subgraph.second, builder, module); - } + extract(partition_ops); } void RaiseTargetSubgraphsPass::runOnOperation() { - auto module = getOperation(); - SmallVector funcs(module.getOps()); + ModuleOp module = getOperation(); + SmallVector funcs(module.getOps()); + int func_count = -1; for (auto func : funcs) { for (auto& block : func) { - auto builder = OpBuilder::atBlockBegin(&block); - RaiseTargetSubgraphsForBlock(&block, &builder, module); + OpBuilder builder = OpBuilder::atBlockBegin(&block); + RaiseTargetSubgraphsForBlock(block, builder, module, /*skip_cpu=*/false, + func_count); } } } diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/utils/BUILD b/tensorflow/compiler/mlir/lite/experimental/tac/utils/BUILD index d319b6094d1..b4546ba2120 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/utils/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/tac/utils/BUILD @@ -1,6 +1,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ "//visibility:public", ], diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/utils/utils.cc b/tensorflow/compiler/mlir/lite/experimental/tac/utils/utils.cc index 97779c3d3ff..20c81962e5a 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/utils/utils.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/utils/utils.cc @@ -43,6 +43,7 @@ namespace tac { absl::StatusOr> ImportFlatbufferOrMlir( const std::string& input_filename, bool input_mlir, + bool experimental_prune_unreachable_nodes_unconditionally, llvm::SourceMgr* source_mgr, mlir::MLIRContext* context) { std::string error; std::unique_ptr buffer = @@ -70,11 +71,12 @@ absl::StatusOr> ImportFlatbufferOrMlir( return tflite::FlatBufferToMlir( absl::string_view(buffer->getBufferStart(), buffer->getBufferSize()), context, loc, /*use_external_constant=*/false, inputs, outputs, - /*experimental_prune_unreachable_nodes_unconditionally=*/true); + experimental_prune_unreachable_nodes_unconditionally); } absl::Status ExportFlatbufferOrMlir(const std::string& output_filename, - bool output_mlir, mlir::ModuleOp module) { + bool output_mlir, mlir::ModuleOp module, + bool enable_select_tf_ops) { std::string error_msg; auto output = mlir::openOutputFile(output_filename, &error_msg); if (output == nullptr) { @@ -90,8 +92,13 @@ absl::Status ExportFlatbufferOrMlir(const std::string& output_filename, } else { tflite::FlatbufferExportOptions options; options.toco_flags.set_force_select_tf_ops(false); - options.toco_flags.set_enable_select_tf_ops(false); options.toco_flags.set_allow_custom_ops(true); + if (enable_select_tf_ops) { + options.toco_flags.set_enable_select_tf_ops(true); + options.toco_flags.set_allow_all_select_tf_ops(true); + } else { + 
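+      // Select TF (Flex) ops stay disabled on this path; custom ops are still
+      // allowed via set_allow_custom_ops(true) above.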
options.toco_flags.set_enable_select_tf_ops(false); + } if (!tflite::MlirToFlatBufferTranslateFunction(module, options, &result)) { return absl::UnknownError("Failed to export tflite file."); } diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/utils/utils.h b/tensorflow/compiler/mlir/lite/experimental/tac/utils/utils.h index 4af9f005606..af5732683e9 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/utils/utils.h +++ b/tensorflow/compiler/mlir/lite/experimental/tac/utils/utils.h @@ -32,11 +32,13 @@ namespace tac { // Import the file as mlir module, the input maybe flatbuffer or mlir file. absl::StatusOr> ImportFlatbufferOrMlir( const std::string& input_filename, bool input_mlir, + bool experimental_prune_unreachable_nodes_unconditionally, llvm::SourceMgr* source_mgr, mlir::MLIRContext* context); // Export the module to file, can be either mlir or flatbuffer. absl::Status ExportFlatbufferOrMlir(const std::string& output_filename, - bool output_mlir, mlir::ModuleOp module); + bool output_mlir, mlir::ModuleOp module, + bool enable_select_tf_ops); } // namespace tac } // namespace TFL diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index 39942fa0f5f..0223cd93a79 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -37,7 +38,6 @@ limitations under the License. #include "flatbuffers/flexbuffers.h" // from @flatbuffers #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" @@ -70,6 +70,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" @@ -436,9 +437,9 @@ static std::unique_ptr<::tensorflow::NodeDef> GetTensorFlowNodeDef( } // Converts a mlir padding StringRef to TfLitePadding. -// Returns llvm::None if conversion fails. -static Optional GetTflitePadding(Operation* inst, - llvm::StringRef padding) { +// Returns std::nullopt if conversion fails. +static std::optional GetTflitePadding(Operation* inst, + llvm::StringRef padding) { const tflite::Padding padding_attr = std::move(llvm::StringSwitch(padding) .Case("SAME", tflite::Padding_SAME) @@ -451,16 +452,16 @@ static Optional GetTflitePadding(Operation* inst, } return inst->emitOpError() << "Invalid padding attribute: " << padding, - llvm::None; + std::nullopt; } // Extracts TfLitePoolParams from a TFL custom op. // Template parameter, TFLOp, should be a TFL custom op containing attributes // generated from TfLitePoolParams. -// Returns llvm::None if conversion fails. +// Returns std::nullopt if conversion fails. 
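+// Illustrative example (values are hypothetical): a custom pooling op carrying
+// stride_h = 2, stride_w = 2 and padding "SAME" yields a TfLitePoolParams with
+// stride_height = 2, stride_width = 2 and padding == kTfLitePaddingSame.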
template -static Optional GetTflitePoolParams(Operation* inst, - TFLOp op) { +static std::optional GetTflitePoolParams(Operation* inst, + TFLOp op) { TfLitePoolParams pool_params; pool_params.stride_height = op.stride_h().getSExtValue(); pool_params.stride_width = op.stride_w().getSExtValue(); @@ -474,11 +475,14 @@ static Optional GetTflitePoolParams(Operation* inst, return pool_params; } - return llvm::None; + return std::nullopt; } namespace { +using ::mlir::tf_saved_model::kTfSavedModelExportedNamesAttr; +using ::mlir::tf_saved_model::kTfSavedModelIndexPathAttr; + // Helper struct that wraps inputs/outputs of a single SignatureDef. struct SignatureDefData { // Note, we are using maps here to make order deterministic @@ -498,9 +502,9 @@ struct SignatureDefData { class Translator { public: // Translates the given MLIR module into TFLite FlatBuffer format and returns - // the serialized output. Returns llvm::None on unsupported, invalid inputs or - // internal error. - static Optional Translate( + // the serialized output. Returns std::nullopt on unsupported, invalid inputs + // or internal error. + static std::optional Translate( ModuleOp module, const toco::TocoFlags& toco_flags, const std::unordered_set& tags, OpOrArgNameMapper* op_or_arg_name_mapper, @@ -543,29 +547,29 @@ class Translator { ->getOrLoadDialect(); } - Optional TranslateInternal(); + std::optional TranslateInternal(); // Returns TFLite buffer populated with constant value if the operation is // TFLite constant operation. Otherwise, returns an empty buffer. Emits error - // and returns llvm::None on failure. - Optional> BuildBuffer(Value value); + // and returns std::nullopt on failure. + std::optional> BuildBuffer(Value value); // Build TFLite tensor from the given type. This function is for tfl.lstm // intermediates, which should have UniformQuantizedType. - Optional> BuildTensorFromType( + std::optional> BuildTensorFromType( mlir::Type type, const std::string& name); - // Builds TF::VariantType from the given element type. Returns llvm::None if + // Builds TF::VariantType from the given element type. Returns std::nullopt if // failure. Returns empty vector if the element type is not TF::VariantType or // there is empty TensorType in the TF::VariantType. - Optional>> + std::optional>> BuildTFVariantType(mlir::Type element_type); // Builds TFLite tensor from the given value. `buffer_idx` is index of the - // corresponding buffer. Emits error and returns llvm::None on failure. - Optional> BuildTensor( + // corresponding buffer. Emits error and returns std::nullopt on failure. + std::optional> BuildTensor( Value value, const std::string& name, unsigned buffer_idx, - const Optional>& + const std::optional>& quant_parameters); // TODO(b/137395003): Legalize tf.IfOp to TFLite dialect, and change the @@ -593,10 +597,10 @@ class Translator { const std::vector& operands, const std::vector& results); - Optional CreateFlexOpCustomOptions( + std::optional CreateFlexOpCustomOptions( const ::tensorflow::NodeDef& node_def, const mlir::Location& loc); - Optional CreateCustomOpCustomOptions( + std::optional CreateCustomOpCustomOptions( const ::tensorflow::NodeDef& node_def, const mlir::Location& loc); std::unique_ptr CreateFlexBuilderWithNodeAttrs( @@ -609,8 +613,8 @@ class Translator { tflite::BuiltinOperator builtin); // Builds operator for the given operation with specified operand and result - // tensor indices. Emits an error and returns llvm::None on failure. - Optional> BuildOperator( + // tensor indices. 
Emits an error and returns std::nullopt on failure. + llvm::Optional> BuildOperator( Operation* inst, std::vector operands, const std::vector& results, const std::vector& intermediates); @@ -622,7 +626,7 @@ class Translator { // Build a subgraph with a given name out of the region either corresponding // to a function's body or while op. Modifies *region by calling // ExtractControlEdges. - Optional> BuildSubGraph( + std::optional> BuildSubGraph( const std::string& name, Region* region, const int index); // Modifies *block by unwrapping all ControlNodeOps. The DAG of the control @@ -636,11 +640,11 @@ class Translator { // Encodes the `tfl.metadata` dictionary attribute of the module to the // metadata section in the final model. - Optional>> + std::optional>> CreateMetadataVector(); // Builds and returns list of tfl.SignatureDef sections in the model. - Optional>> + std::optional>> CreateSignatureDefs(const std::vector& signature_defs); // Returns list of offsets for the passed 'items' in TensorMap structure @@ -747,7 +751,7 @@ std::string Translator::UniqueName(mlir::Value val) { return std::string(name_mapper_.GetUniqueName(val)); } -Optional> Translator::BuildBuffer( +std::optional> Translator::BuildBuffer( mlir::Value value) { auto inst = value.getDefiningOp(); ElementsAttr attr; @@ -756,7 +760,7 @@ Optional> Translator::BuildBuffer( // TFLite module. attr = cst.getValue().cast(); } else if (auto cst = dyn_cast(inst)) { - attr = cst.value(); + attr = cst.getValue(); } else if (auto cst = dyn_cast(inst)) { attr = cst.getValue(); } else if (auto cst = dyn_cast(inst)) { @@ -793,7 +797,7 @@ Optional> Translator::BuildBuffer( inst->emitError( Twine("failed to convert value attribute to tensor with error: " + status.ToString())); - return llvm::None; + return std::nullopt; } // TensorFlow and TensorFlow Lite use different string encoding formats. @@ -819,7 +823,7 @@ Optional> Translator::BuildBuffer( return tflite::CreateBuffer(builder_, buffer_data); } -Optional>> +std::optional>> Translator::BuildTFVariantType(mlir::Type element_type) { std::vector> variant_params; auto variant_type = element_type.dyn_cast(); @@ -829,7 +833,7 @@ Translator::BuildTFVariantType(mlir::Type element_type) { // We only support up to one nested type in tf_type.variant_type. if (variant_type.getSubtypes().size() > 1) { - return llvm::None; + return std::nullopt; } if (variant_type.getSubtypes().empty()) { return variant_params; @@ -849,7 +853,7 @@ Translator::BuildTFVariantType(mlir::Type element_type) { return variant_params; } -Optional> Translator::BuildTensorFromType( +std::optional> Translator::BuildTensorFromType( mlir::Type type, const std::string& name) { auto tensor_type = type.cast(); @@ -861,17 +865,17 @@ Optional> Translator::BuildTensorFromType( shape_ref = tensor_type.getShape(); shape = std::vector(shape_ref.begin(), shape_ref.end()); } else { - return llvm::None; + return std::nullopt; } } auto element_type = tensor_type.getElementType(); tflite::TensorType tflite_element_type = GetTFLiteType(tensor_type.getElementType()).value(); - Optional>> variant_params = - BuildTFVariantType(element_type); - if (!variant_params.hasValue()) { - return llvm::None; + std::optional>> + variant_params = BuildTFVariantType(element_type); + if (!variant_params.has_value()) { + return std::nullopt; } BufferOffset q_params = 0; if (auto qtype = element_type.dyn_cast()) { @@ -897,9 +901,9 @@ Optional> Translator::BuildTensorFromType( variant_params->empty() ? 
0 : builder_.CreateVector(*variant_params)); } -Optional> Translator::BuildTensor( +std::optional> Translator::BuildTensor( Value value, const std::string& name, unsigned buffer_idx, - const Optional>& + const std::optional>& quant_parameters) { auto type = value.getType().cast(); @@ -924,7 +928,7 @@ Optional> Translator::BuildTensor( auto* inst = value.getDefiningOp(); if (type.hasStaticShape()) { llvm::ArrayRef shape_ref = type.getShape(); - if (mlir::failed(check_shape(shape_ref))) return llvm::None; + if (mlir::failed(check_shape(shape_ref))) return std::nullopt; shape = std::vector(shape_ref.begin(), shape_ref.end()); } else if (inst && IsConst(inst)) { @@ -934,21 +938,21 @@ Optional> Translator::BuildTensor( auto tensor_attr = inst->getAttr("value").cast(); llvm::ArrayRef shape_ref = tensor_attr.getType().cast().getShape(); - if (mlir::failed(check_shape(shape_ref))) return llvm::None; + if (mlir::failed(check_shape(shape_ref))) return std::nullopt; shape = std::vector(shape_ref.begin(), shape_ref.end()); } else if (type.hasRank()) { llvm::ArrayRef shape_ref = type.getShape(); - if (mlir::failed(check_shape(shape_ref))) return llvm::None; + if (mlir::failed(check_shape(shape_ref))) return std::nullopt; shape.reserve(shape_ref.size()); for (auto& dim : shape_ref) { // translate dynamic shapes from mlir to tfl values shape.push_back( - dim == mlir::ShapedType::kDynamicSize ? 1 : static_cast(dim)); + dim == mlir::ShapedType::kDynamic ? 1 : static_cast(dim)); shape_signature.push_back(static_cast( - dim == mlir::ShapedType::kDynamicSize ? tensorflow::kTFDynamicSize - : dim)); + dim == mlir::ShapedType::kDynamic ? tensorflow::kTFDynamicSize + : dim)); } } @@ -965,10 +969,10 @@ Optional> Translator::BuildTensor( tflite::TensorType tflite_element_type = GetTFLiteType(type.getElementType()).value(); - Optional>> variant_params = - BuildTFVariantType(element_type); - if (!variant_params.hasValue()) { - return llvm::None; + std::optional>> + variant_params = BuildTFVariantType(element_type); + if (!variant_params.has_value()) { + return std::nullopt; } BufferOffset q_params; @@ -993,7 +997,7 @@ Optional> Translator::BuildTensor( tflite::QuantizationDetails_NONE, /*details=*/0, qtype.getQuantizedDimension()); } else if (quant_parameters.has_value()) { - q_params = quant_parameters.getValue(); + q_params = quant_parameters.value(); } else { q_params = tflite::CreateQuantizationParameters(builder_); } @@ -1032,8 +1036,8 @@ BufferOffset Translator::BuildIfOperator( mlir::TF::IfOp op, const std::vector& operands, const std::vector& results) { auto opcode_index = GetOpcodeIndex("if", tflite::BuiltinOperator_IF); - int then_subgraph_index = subgraph_index_map_.at(op.then_branch().str()); - int else_subgraph_index = subgraph_index_map_.at(op.else_branch().str()); + int then_subgraph_index = subgraph_index_map_.at(op.getThenBranch().str()); + int else_subgraph_index = subgraph_index_map_.at(op.getElseBranch().str()); auto builtin_options = tflite::CreateIfOptions(builder_, then_subgraph_index, else_subgraph_index) .Union(); @@ -1060,21 +1064,21 @@ BufferOffset Translator::BuildCallOnceOperator( builtin_options); } -Optional> Translator::BuildWhileOperator( +llvm::Optional> Translator::BuildWhileOperator( mlir::TFL::WhileOp op, const std::vector& operands, const std::vector& results) { auto opcode_index = GetOpcodeIndex("while", tflite::BuiltinOperator_WHILE); - auto get_call_index = [&](mlir::Block& b) -> Optional { - if (b.getOperations().size() != 2) return llvm::None; + auto get_call_index = 
[&](mlir::Block& b) -> std::optional { + if (b.getOperations().size() != 2) return std::nullopt; if (auto call_op = dyn_cast(b.front())) return subgraph_index_map_.at(call_op.getCallee().str()); - return llvm::None; + return std::nullopt; }; auto body_subgraph_index = get_call_index(op.getBody().front()); auto cond_subgraph_index = get_call_index(op.getCond().front()); if (!body_subgraph_index || !cond_subgraph_index) return op.emitOpError("only single call cond/body while export supported"), - llvm::None; + std::nullopt; auto builtin_options = tflite::CreateWhileOptions(builder_, *cond_subgraph_index, *body_subgraph_index) @@ -1125,12 +1129,12 @@ BufferOffset Translator::BuildCustomOperator( tflite::CustomOptionsFormat_FLEXBUFFERS); } -Optional Translator::CreateFlexOpCustomOptions( +std::optional Translator::CreateFlexOpCustomOptions( const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) { std::string node_def_str; if (!node_def.SerializeToString(&node_def_str)) { return emitError(loc, "failed to serialize tensorflow node_def"), - llvm::None; + std::nullopt; } auto flex_builder = std::make_unique(); @@ -1142,7 +1146,7 @@ Optional Translator::CreateFlexOpCustomOptions( return builder_.CreateVector(flex_builder->GetBuffer()); } -Optional Translator::CreateCustomOpCustomOptions( +std::optional Translator::CreateCustomOpCustomOptions( const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) { auto flex_builder = CreateFlexBuilderWithNodeAttrs(node_def, loc); return builder_.CreateVector(flex_builder->GetBuffer()); @@ -1239,14 +1243,14 @@ uint32_t Translator::GetOpcodeIndex(const std::string& op_name, return it.first->second; } -Optional> Translator::BuildOperator( +llvm::Optional> Translator::BuildOperator( Operation* inst, std::vector operands, const std::vector& results, const std::vector& intermediates) { const auto* dialect = inst->getDialect(); if (!dialect) { inst->emitOpError("dialect is not registered"); - return llvm::None; + return std::nullopt; } // If TFLite built in op, create operator as a builtin op. @@ -1256,7 +1260,7 @@ Optional> Translator::BuildOperator( if (!enabled_op_types_.contains(OpType::kTfliteBuiltin)) { return inst->emitOpError( "is a TFLite builtin op but builtin emission is not enabled"), - llvm::None; + std::nullopt; } auto builtin_code = GetBuiltinOpCode(inst); @@ -1272,13 +1276,13 @@ Optional> Translator::BuildOperator( inst->emitOpError( "number of operands and results don't match, only canonical " "TFL While supported"); - return llvm::None; + return std::nullopt; } return BuildWhileOperator(whileOp, operands, results); } inst->emitOpError("is not a supported TFLite op"); - return llvm::None; + return std::nullopt; } if (*builtin_code == tflite::BuiltinOperator_CALL_ONCE) { @@ -1327,7 +1331,7 @@ Optional> Translator::BuildOperator( // we emit the op as custom. auto node_def = GetTensorFlowNodeDef(inst); if (!node_def) { - return llvm::None; + return std::nullopt; } std::string op_name = node_def->op(); @@ -1358,7 +1362,7 @@ Optional> Translator::BuildOperator( if (auto options = CreateFlexOpCustomOptions(*node_def, inst->getLoc())) { custom_options = *options; } else { - return llvm::None; + return std::nullopt; } // Gather flex ops. @@ -1371,7 +1375,7 @@ Optional> Translator::BuildOperator( CreateCustomOpCustomOptions(*node_def, inst->getLoc())) { custom_options = *options; } else { - return llvm::None; + return std::nullopt; } // Gather custom ops. 
@@ -1389,7 +1393,7 @@ Optional> Translator::BuildOperator( inst->emitOpError("is neither a custom op nor a flex op"), tflite::metrics::ConverterErrorData::ERROR_NEEDS_CUSTOM_OPS); } - return llvm::None; + return std::nullopt; } uint32_t opcode_index = @@ -1408,7 +1412,7 @@ Optional> Translator::BuildOperator( return inst->emitOpError( "is not any of a builtin TFLite op, a flex TensorFlow op or a " "custom TensorFlow op"), - llvm::None; + std::nullopt; } void Translator::InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr) { @@ -1457,12 +1461,12 @@ BufferOffset Translator::GetQuantizationForQuantStatsOpOutput( mlir::quantfork::StatisticsOp stats_op) { auto layer_stats = stats_op.getLayerStats().cast(); - Optional axis_stats = stats_op.getAxisStats(); - Optional axis = stats_op.getAxis(); + std::optional axis_stats = stats_op.getAxisStats(); + std::optional axis = stats_op.getAxis(); std::vector mins, maxs; mlir::DenseFPElementsAttr min_max_attr = axis_stats.has_value() - ? axis_stats.getValue().cast() + ? axis_stats.value().cast() : layer_stats; for (const auto& index_and_value : @@ -1479,10 +1483,10 @@ Translator::GetQuantizationForQuantStatsOpOutput( builder_, builder_.CreateVector(mins), builder_.CreateVector(maxs), /*scale=*/0, /*zero_point=*/0, tflite::QuantizationDetails_NONE, /*details=*/0, - /*quantized_dimension=*/axis.has_value() ? axis.getValue() : 0); + /*quantized_dimension=*/axis.has_value() ? axis.value() : 0); } -Optional> Translator::BuildSubGraph( +std::optional> Translator::BuildSubGraph( const std::string& name, Region* region, const int index) { const auto control_edges = ExtractControlEdges(®ion->front()); bool has_input_attr = false; @@ -1503,7 +1507,8 @@ Optional> Translator::BuildSubGraph( tensor_index_map.insert({value, tensors.size()}); tensor_index_map_[subgraph_index][tensor_name] = tensors.size(); - Optional> quant_parameters; + std::optional> + quant_parameters; if (value.hasOneUse()) { auto stats_op = llvm::dyn_cast(*value.user_begin()); @@ -1546,7 +1551,7 @@ Optional> Translator::BuildSubGraph( if (has_input_attr) tensor_name = std::string(name_mapper_.GetUniqueName(arg)); if (tensor_name.empty()) tensor_name = absl::StrCat("arg", i); - if (!build_tensor_and_buffer(arg, index, tensor_name)) return llvm::None; + if (!build_tensor_and_buffer(arg, index, tensor_name)) return std::nullopt; } bool failed_once = false; @@ -1579,7 +1584,7 @@ Optional> Translator::BuildSubGraph( continue; } else { intermediates.push_back(tensors.size()); - tensors.push_back(tensor_or.getValue()); + tensors.push_back(tensor_or.value()); } } } @@ -1598,7 +1603,8 @@ Optional> Translator::BuildSubGraph( tensor_name = "NumericVerify/" + UniqueName(quantized_op_val) + ":" + std::to_string(tensor_index_map[quantized_op_val]); } - if (!build_tensor_and_buffer(val, index, tensor_name)) return llvm::None; + if (!build_tensor_and_buffer(val, index, tensor_name)) + return std::nullopt; } // Skip constant ops as they don't represent a TFLite operator. @@ -1650,7 +1656,7 @@ Optional> Translator::BuildSubGraph( subgraph_op_inst_map_.resize(index + 1); } subgraph_op_inst_map_[index] = operators_in_mlir; - if (failed_once) return llvm::None; + if (failed_once) return std::nullopt; // Get input and output tensor indices for the subgraph. 
std::vector inputs, outputs; @@ -1687,7 +1693,7 @@ BufferOffset Translator::BuildMetadata(StringRef name, return tflite::CreateMetadataDirect(builder_, name.data(), buffer_index); } -Optional>> +std::optional>> Translator::CreateMetadataVector() { auto dict_attr = module_->getAttrOfType("tfl.metadata"); std::vector> metadata; @@ -1701,7 +1707,7 @@ Translator::CreateMetadataVector() { module_.emitError( "all values in tfl.metadata's dictionary key-value pairs should be " "string attributes"); - return llvm::None; + return std::nullopt; } } } @@ -1750,14 +1756,14 @@ llvm::SmallVector GetStringsFromAttrWithSeparator( // Attribute identified by 'attr_name'. std::vector GetStringsFromDictionaryAttr( const llvm::SmallVector& dict_attrs, - const std::string& attr_name) { + const StringRef attr_name) { std::vector result; for (const auto& arg_attr : dict_attrs) { if (!arg_attr) continue; auto attrs = arg_attr.getValue(); for (const auto attr : attrs) { - if (attr.getName().str() == attr_name) { + if (attr.getName() == attr_name) { auto array_attr = attr.getValue().dyn_cast_or_null(); if (!array_attr || array_attr.empty()) continue; auto string_attr = array_attr[0].dyn_cast_or_null(); @@ -1772,7 +1778,6 @@ std::vector GetStringsFromDictionaryAttr( std::vector BuildSignaturedef( FuncOp main_op, const std::string& saved_model_tag, const uint32_t subgraph_index, tensorflow::OpOrArgNameMapper& name_mapper) { - static const char kSignatureDefIndexPath[] = "tf_saved_model.index_path"; static const char kEntryFunctionAttributes[] = "tf.entry_function"; // Fetch inputs and outputs from the signature. @@ -1780,9 +1785,9 @@ std::vector BuildSignaturedef( main_op.getAllArgAttrs(arg_attrs); main_op.getAllResultAttrs(res_attrs); std::vector sig_def_inputs = - GetStringsFromDictionaryAttr(arg_attrs, kSignatureDefIndexPath); + GetStringsFromDictionaryAttr(arg_attrs, kTfSavedModelIndexPathAttr); std::vector sig_def_outputs = - GetStringsFromDictionaryAttr(res_attrs, kSignatureDefIndexPath); + GetStringsFromDictionaryAttr(res_attrs, kTfSavedModelIndexPathAttr); // If no defined saved model signature, then return empty list. // This can happen when we are converting model not from SavedModel. @@ -1823,7 +1828,7 @@ std::vector BuildSignaturedef( } // Exported method name. 
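+  // For instance (illustrative value), a SavedModel entry point typically
+  // carries tf_saved_model.exported_names = ["serving_default"], and the
+  // exported name is used as the key of the generated SignatureDef.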
auto exported_name = - main_op->getAttrOfType("tf_saved_model.exported_names"); + main_op->getAttrOfType(kTfSavedModelExportedNamesAttr); if (exported_name.empty()) { main_op.emitError("Empty exported names for main Function"); return {}; @@ -1861,7 +1866,7 @@ std::vector> Translator::GetList( return result; } -Optional>> +std::optional>> Translator::CreateSignatureDefs( const std::vector& signature_defs) { std::vector> signature_defs_buffer; @@ -1914,7 +1919,7 @@ bool UpdateEntryFunction(ModuleOp module) { return true; } -Optional Translator::Translate( +std::optional Translator::Translate( ModuleOp module, const toco::TocoFlags& toco_flags, const std::unordered_set& tags, OpOrArgNameMapper* op_or_arg_name_mapper, @@ -1922,8 +1927,8 @@ Optional Translator::Translate( OpOrArgLocNameMapper default_op_or_arg_name_mapper; if (!op_or_arg_name_mapper) op_or_arg_name_mapper = &default_op_or_arg_name_mapper; - if (!UpdateEntryFunction(module)) return llvm::None; - if (!IsValidTFLiteMlirModule(module)) return llvm::None; + if (!UpdateEntryFunction(module)) return std::nullopt; + if (!IsValidTFLiteMlirModule(module)) return std::nullopt; Translator translator(module, toco_flags, tags, op_or_arg_name_mapper, metadata); return translator.TranslateInternal(); @@ -1955,7 +1960,7 @@ bool Translator::CheckGpuDelegateCompatibility(uint8_t* model_buffer_pointer) { return gpu_compatibile; } -Optional Translator::TranslateInternal() { +std::optional Translator::TranslateInternal() { // A list of named regions in the module with main function being the first in // the list. The main function is required as the first subgraph in the model // is entry point for the model. @@ -2084,7 +2089,7 @@ Optional Translator::TranslateInternal() { return failed_region.second->getParentOp()->emitError() << "failed while converting: '" << failed_region.first << "': " << err, - llvm::None; + std::nullopt; } // Log MAC count. @@ -2123,7 +2128,7 @@ Optional Translator::TranslateInternal() { auto description = builder_.CreateString(model_description.data()); VectorBufferOffset metadata_buffer = 0; // Deprecated auto metadata = CreateMetadataVector(); - if (!metadata) return llvm::None; + if (!metadata) return std::nullopt; std::vector signature_defs_vec; subgraph_index = 0; @@ -2152,13 +2157,13 @@ Optional Translator::TranslateInternal() { // There is a limit of 2GB for a flatbuffer. if (builder_.GetSize() > 2147483648) { LOG(ERROR) << "Model size is bigger than 2gb"; - return llvm::None; + return std::nullopt; } tflite::UpdateOpVersion(builder_.GetBufferPointer()); tflite::UpdateMinimumRuntimeVersionForModel(builder_.GetBufferPointer()); if (supported_backends_.find("GPU") != supported_backends_.end()) { if (!CheckGpuDelegateCompatibility(builder_.GetBufferPointer())) { - return llvm::None; + return std::nullopt; } } diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index 7dc0ce17c4c..532d342f360 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -70,8 +70,10 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "tensorflow/compiler/mlir/lite/utils/low_bit_utils.h" +#include "tensorflow/compiler/mlir/lite/utils/size_utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" @@ -111,6 +113,9 @@ namespace tfl = mlir::TFL; namespace { +using ::mlir::tf_saved_model::kTfSavedModelExportedNamesAttr; +using ::mlir::tf_saved_model::kTfSavedModelIndexPathAttr; + bool IsQuantized(const TensorT& tensor) { return (tensor.quantization != nullptr) && !tensor.quantization->zero_point.empty(); @@ -849,7 +854,8 @@ StatusOr ConvertOp( mlir::SmallVector shape; for (auto s : new_shape) { - shape.push_back(builder.getI32IntegerAttr(static_cast(s))); + shape.push_back( + builder.getI32IntegerAttr(mlir::TFL::ConvertToTfliteSize(s))); } auto output_shape = DenseElementsAttr::get(shape_type, shape); auto shape_op = builder.create(loc, output_shape); @@ -906,9 +912,8 @@ StatusOr ConvertOp( int32_t dim_size = 0; for (const auto& dim : llvm::enumerate(shape_attr.getValues())) { - const int64_t size = dim.value().getSExtValue(); - shape.push_back( - builder.getI32IntegerAttr(static_cast(size))); + shape.push_back(builder.getI32IntegerAttr( + mlir::TFL::ConvertToTfliteSize(dim.value().getSExtValue()))); ++dim_size; } auto shape_type = tensorflow::GetTypeFromTFTensorShape( @@ -1118,8 +1123,6 @@ void SetSignature( FuncOp func, const tflite::SignatureDefT* signature, const std::vector>& tensors) { auto* context = func->getContext(); - static const char kSignatureDefIndexPath[] = "tf_saved_model.index_path"; - static const char kExportedNameAttr[] = "tf_saved_model.exported_names"; static const char kEntryFunctionAttributes[] = "tf.entry_function"; auto dict_attr = @@ -1140,7 +1143,7 @@ void SetSignature( return; } func.setArgAttr( - arg_index, kSignatureDefIndexPath, + arg_index, kTfSavedModelIndexPathAttr, mlir::ArrayAttr::get(context, {mlir::StringAttr::get( context, input_pair.value()->name)})); } @@ -1155,14 +1158,14 @@ void SetSignature( func->emitWarning("Invalid signature tensors specified."); return; } - func.setResultAttr(arg_index, kSignatureDefIndexPath, + func.setResultAttr(arg_index, kTfSavedModelIndexPathAttr, mlir::ArrayAttr::get( context, {mlir::StringAttr::get( context, output_pair.value()->name)})); seen_indices.insert(arg_index); } func->setAttr( - kExportedNameAttr, + kTfSavedModelExportedNamesAttr, mlir::ArrayAttr::get( context, {mlir::StringAttr::get(context, signature->signature_key)})); } diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc index 4153e3aa88a..35475091aa8 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc @@ -178,11 +178,11 @@ static bool ConvertBoolAttrForOptionWriter( return b; } -// Overloading of ConvertBoolAttrForOptionWriter which takes Optional as -// an input. If value is not specified, false is set for the attribute. +// Overloading of ConvertBoolAttrForOptionWriter which takes std::optional +// as an input. 
If value is not specified, false is set for the attribute. static bool ConvertBoolAttrForOptionWriter( - mlir::Optional b, flatbuffers::FlatBufferBuilder* builder) { - return b.has_value() ? b.getValue() : false; + std::optional b, flatbuffers::FlatBufferBuilder* builder) { + return b.has_value() ? b.value() : false; } static flatbuffers::Offset ConvertStrAttrForOptionWriter( diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_operator.h b/tensorflow/compiler/mlir/lite/flatbuffer_operator.h index 600c2740da1..8d1200d33db 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_operator.h +++ b/tensorflow/compiler/mlir/lite/flatbuffer_operator.h @@ -41,12 +41,12 @@ namespace mlir { std::string GetMlirOpNameFromOpCode(const ::tflite::OperatorCodeT &op_code); // Returns the builtin op code for the given MLIR operation on success; emits -// error and returns llvm::None on failure. +// error and returns std::nullopt on failure. llvm::Optional GetBuiltinOpCode(Operation *mlir_op); // Packs the given MLIR operation into a TFLite FlatBuffer operator object. // Returns the FlatBuffer offset for the operator on success; emits error and -// returns llvm::None on failure. +// returns std::nullopt on failure. llvm::Optional> CreateFlatBufferOperator( Operation *mlir_op, uint32_t opcode_index, const std::vector &operands, const std::vector &results, diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td b/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td index cd7505b9427..5a83b906a68 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td @@ -36,8 +36,6 @@ def TFL_Dialect : Dialect { let cppNamespace = "::mlir::TFL"; - let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; - let useDefaultAttributePrinterParser = 1; let useDefaultTypePrinterParser = 1; @@ -52,6 +50,7 @@ def TFL_Dialect : Dialect { Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) override; }]; + let useFoldAPI = kEmitFoldAdaptorFolder; } //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index 233999f0129..177af6feb9b 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -56,6 +56,7 @@ limitations under the License. #include "mlir/Transforms/InliningUtils.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/utils/arithmetic_count_util.h" +#include "tensorflow/compiler/mlir/lite/utils/size_utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" @@ -80,8 +81,8 @@ INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SquareOp); namespace { -ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, - OperationState &result) { +ParseResult parseOneResultSameOperandTypeOp(OpAsmParser& parser, + OperationState& result) { SmallVector ops; Type type; // If the operand list is in-between parentheses, then we have a generic form. 
@@ -109,7 +110,7 @@ ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, parser.addTypeToList(type, result.types)); } -void printOneResultOp(Operation *op, OpAsmPrinter &p) { +void printOneResultOp(Operation* op, OpAsmPrinter& p) { assert(op->getNumResults() == 1 && "op should have one result"); // If not all the operand and result types are the same, just use the @@ -128,8 +129,8 @@ void printOneResultOp(Operation *op, OpAsmPrinter &p) { p << " : " << resultType; } -Operation *getDefiningBroadcastArgsOp(Value operand) { - auto *defining_op = operand.getDefiningOp(); +Operation* getDefiningBroadcastArgsOp(Value operand) { + auto* defining_op = operand.getDefiningOp(); if (!llvm::dyn_cast_or_null(defining_op) && !llvm::dyn_cast_or_null(defining_op)) { return nullptr; @@ -137,7 +138,7 @@ Operation *getDefiningBroadcastArgsOp(Value operand) { Value broadcast_shape = defining_op->getOperand( 1); // Broadcasted shape operand of BroadcastTo op. - Operation *parent_of_defining_op = broadcast_shape.getDefiningOp(); + Operation* parent_of_defining_op = broadcast_shape.getDefiningOp(); if (!llvm::dyn_cast_or_null(parent_of_defining_op) && !llvm::dyn_cast_or_null(parent_of_defining_op)) { return nullptr; @@ -164,7 +165,7 @@ bool VerifyCompatibleShapesSameElementType(TypeRange lhs, TypeRange rhs) { // non-static and maximum rank is within the given rank, this method returns // true. bool VerifyOperandsHaveSameShapesOrBroadcastableShape( - Operation *op, ArrayRef indices, int max_bcast_rank) { + Operation* op, ArrayRef indices, int max_bcast_rank) { if (indices.empty()) return true; // First, it checks there are any inputs that has unknown rank. @@ -228,9 +229,9 @@ bool VerifyOperandsHaveSameShapesOrBroadcastableShape( // Checks if all operands are broadcasted by BroadcastTo ops with the shape // is calculated from the same BroadcastArgs op. In such case, all operands // will have the same shape. - Operation *broadcast_args_pivot = nullptr; + Operation* broadcast_args_pivot = nullptr; for (unsigned index : indices) { - Operation *parent_broadcast_args = + Operation* parent_broadcast_args = getDefiningBroadcastArgsOp(op->getOperand(index)); if (parent_broadcast_args == nullptr) { return false; @@ -306,7 +307,7 @@ struct RemoveOptionalZeroBias : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ConcreteOpType op, - PatternRewriter &rewriter) const override { + PatternRewriter& rewriter) const override { if (EqualsZero(op.getBias())) { auto none_value = rewriter.create( rewriter.getUnknownLoc(), rewriter.getNoneType(), @@ -405,17 +406,17 @@ struct TensorFlowLiteInlinerInterface : public DialectInlinerInterface { //===--------------------------------------------------------------------===// // Allow all call operations to be inlined. - bool isLegalToInline(Operation *call, Operation *callable, + bool isLegalToInline(Operation* call, Operation* callable, bool wouldBeCloned) const final { return true; } - bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, - BlockAndValueMapping &) const final { + bool isLegalToInline(Operation* op, Region* dest, bool wouldBeCloned, + IRMapping&) const final { // No TFLite op restricts inlining today, revise as needed in the future. 
return true; } - bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, - BlockAndValueMapping &valueMapping) const final { + bool isLegalToInline(Region* dest, Region* src, bool wouldBeCloned, + IRMapping& valueMapping) const final { return isa(dest->getParentOp()); } }; @@ -429,12 +430,12 @@ struct TensorFlowLiteDialectFoldInterface : public DialectFoldInterface { // materializing constants. // In the TFLite dialect we materialize inside a while regions as slightly // more efficient computationally. - bool shouldMaterializeInto(Region *region) const final { + bool shouldMaterializeInto(Region* region) const final { return isa(region->getParentOp()); } }; -void TFLDialect::printType(Type type, DialectAsmPrinter &os) const { +void TFLDialect::printType(Type type, DialectAsmPrinter& os) const { if (type.isa()) { os << "control"; return; @@ -442,7 +443,7 @@ void TFLDialect::printType(Type type, DialectAsmPrinter &os) const { os << ""; } -Type TFLDialect::parseType(DialectAsmParser &parser) const { +Type TFLDialect::parseType(DialectAsmParser& parser) const { StringRef data_type; if (parser.parseKeyword(&data_type)) return Type(); if (data_type == "control") return ControlType::get(getContext()); @@ -507,8 +508,8 @@ inline std::vector GetPaddedShape(ArrayRef old_shape, // Helper method that given and 'current_index' representing // index in broadcasted tensor, get the index in the flat original tensor. // 'shape' is the original shape with padding to match result shape. -int64_t GetElementIndex(const std::vector &shape, - const std::vector ¤t_index) { +int64_t GetElementIndex(const std::vector& shape, + const std::vector& current_index) { int64_t ind = 0; int64_t mul = 1; for (int i = shape.size() - 1; i >= 0; --i) { @@ -521,8 +522,8 @@ int64_t GetElementIndex(const std::vector &shape, // Helper method that increment index represented in 'current_index_ptr' // in the shape of 'result_shape'. 
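+// For example (illustrative), with result_shape = [2, 3] repeated calls
+// advance the index [0,0] -> [0,1] -> [0,2] -> [1,0] -> [1,1] -> [1,2],
+// carrying from the innermost dimension outward like an odometer.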
void IncrementIndex(ArrayRef result_shape, - std::vector *current_index_ptr) { - std::vector ¤t_index = *current_index_ptr; + std::vector* current_index_ptr) { + std::vector& current_index = *current_index_ptr; for (int i = result_shape.size() - 1; i >= 0; --i) { current_index[i]++; if (current_index[i] == result_shape[i]) { @@ -543,7 +544,7 @@ template > Attribute ConstFoldBinaryOpDenseDense(Type result_type, DenseElementsAttr lhs, DenseElementsAttr rhs, - const CalculationT &calculate) { + const CalculationT& calculate) { auto type = OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType()) .dyn_cast_or_null(); if (!type) { @@ -606,7 +607,7 @@ template > Attribute ConstFoldBinaryOp(Type result_type, Attribute operand1, - Attribute operand2, const CalculationT &calculate) { + Attribute operand2, const CalculationT& calculate) { if (operand1.dyn_cast_or_null() && operand2.dyn_cast_or_null()) { return ConstFoldBinaryOpDenseDense( @@ -663,7 +664,7 @@ Attribute ConstFoldUnaryOp(Type result_type, Attribute operand, const int num_elements = result_shape_type.getNumElements(); new_values.reserve(num_elements); - for (const APFloat &old_value : dense_elements.getValues()) { + for (const APFloat& old_value : dense_elements.getValues()) { new_values.push_back(calculate(old_value)); } @@ -673,7 +674,7 @@ Attribute ConstFoldUnaryOp(Type result_type, Attribute operand, return {}; } -void buildComparisonBinOp(Builder *builder, OperationState &result, Value lhs, +void buildComparisonBinOp(Builder* builder, OperationState& result, Value lhs, Value rhs) { auto result_type = OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType()); @@ -692,7 +693,7 @@ void buildComparisonBinOp(Builder *builder, OperationState &result, Value lhs, } } -void buildFusedBroadcastableBinOp(Builder *builder, OperationState &result, +void buildFusedBroadcastableBinOp(Builder* builder, OperationState& result, Value lhs, Value rhs, StringAttr fused_activation_function) { auto result_type = @@ -714,7 +715,8 @@ void buildFusedBroadcastableBinOp(Builder *builder, OperationState &result, // AddOp //===----------------------------------------------------------------------===// -OpFoldResult AddOp::fold(ArrayRef operands) { +OpFoldResult AddOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); // TODO(b/142478136): Handle fused ops. if (getFusedActivationFunction() != "NONE") return {}; return ConstFoldBinaryOp( @@ -722,7 +724,7 @@ OpFoldResult AddOp::fold(ArrayRef operands) { [](APInt a, APInt b) { return a + b; }); } -int64_t AddOp::GetArithmeticCount(Operation *op) { +int64_t AddOp::GetArithmeticCount(Operation* op) { int64_t count; if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) return count; @@ -755,28 +757,28 @@ int64_t GetConcatenationOpAxis(ConcatenationOp op) { // // Note: If an operand has unranked tensor type or has dynamic dimension size, // those dimensions will be skipped. 
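+// For example (illustrative): concatenating tensor<2x3xf32> and tensor<2x5xf32>
+// along axis = 1 must produce tensor<2x8xf32>; an unranked operand or a dynamic
+// dimension size makes the corresponding result dimension unverifiable, so that
+// dimension is skipped by the check.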
-LogicalResult VerifyConcatenationOpTypes(Operation *op, +LogicalResult VerifyConcatenationOpTypes(Operation* op, RankedTensorType output_type, ArrayRef operand_types, int64_t axis) { const int64_t output_rank = output_type.getRank(); SmallVector result_dim_sizes_loc(output_rank, - ShapedType::kDynamicSize); + ShapedType::kDynamic); SmallVector result_dim_sizes(output_type.getShape().begin(), output_type.getShape().end()); result_dim_sizes[axis] = 0; auto FormatLoc = [&result_dim_sizes_loc](int64_t dim) { const int64_t loc = result_dim_sizes_loc[dim]; - if (loc == ShapedType::kDynamicSize) return std::string("output"); + if (loc == ShapedType::kDynamic) return std::string("output"); return llvm::formatv("operand #{0}", loc).str(); }; - for (const auto &operand : llvm::enumerate(operand_types)) { + for (const auto& operand : llvm::enumerate(operand_types)) { auto operand_type = operand.value().dyn_cast(); if (!operand_type) { - result_dim_sizes[axis] = ShapedType::kDynamicSize; + result_dim_sizes[axis] = ShapedType::kDynamic; continue; } @@ -793,7 +795,7 @@ LogicalResult VerifyConcatenationOpTypes(Operation *op, if (dim == axis) { if (ShapedType::isDynamic(operand_dim_size) || ShapedType::isDynamic(result_dim_size)) { - result_dim_sizes[axis] = ShapedType::kDynamicSize; + result_dim_sizes[axis] = ShapedType::kDynamic; } else { result_dim_sizes[axis] += operand_dim_size; } @@ -900,7 +902,8 @@ LogicalResult ConcatenationOp::verify() { operand_types, axis); } -OpFoldResult ConcatenationOp::fold(ArrayRef operands) { +OpFoldResult ConcatenationOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); if (getFusedActivationFunction() == "NONE") { if (auto output_type = getOutput().getType().dyn_cast()) { const int64_t axis = GetConcatenationOpAxis(*this); @@ -952,13 +955,13 @@ mlir::LogicalResult CustomOp::verify() { //===----------------------------------------------------------------------===// LogicalResult CustomTfOp::inferReturnTypes( - MLIRContext *, Optional location, ValueRange operands, + MLIRContext*, std::optional location, ValueRange operands, DictionaryAttr attr, RegionRange ranges, - SmallVectorImpl &inferredReturnTypes) { + SmallVectorImpl& inferredReturnTypes) { CustomTfOpAdaptor op(operands, attr, ranges); if (op.getRegions().empty()) return success(); - auto *real_op = &op.getBody().front().front(); + auto* real_op = &op.getBody().front().front(); if (llvm::isa(real_op)) { Value input = *operands.begin(); @@ -1002,7 +1005,7 @@ struct ConvertBroadcastToReshape : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(BroadcastToOp op, - PatternRewriter &rewriter) const override { + PatternRewriter& rewriter) const override { auto input_type = op.getInput().getType().cast(); auto output_type = op.getType().cast(); if (!input_type.hasStaticShape() || !output_type.hasStaticShape() || @@ -1023,8 +1026,8 @@ struct ConvertBroadcastToReshape : public OpRewritePattern { } }; -void BroadcastToOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { +void BroadcastToOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { results.add(context); } @@ -1079,9 +1082,9 @@ LogicalResult FullyConnectedOp::verify() { return mlir::success(); } -LogicalResult FullyConnectedOp::fold(ArrayRef operands, - SmallVectorImpl &results) { - assert(operands.size() == 3); +LogicalResult FullyConnectedOp::fold(FoldAdaptor adaptor, + SmallVectorImpl& results) { + assert(adaptor.getOperands().size() 
== 3); // Folding not implemented with any activation function or any weight type // besides the default. @@ -1177,12 +1180,12 @@ LogicalResult FullyConnectedOp::fold(ArrayRef operands, return success(); } -void FullyConnectedOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { +void FullyConnectedOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { results.add>(context); } -int64_t FullyConnectedOp::GetArithmeticCount(Operation *op) { +int64_t FullyConnectedOp::GetArithmeticCount(Operation* op) { int64_t count; if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp( op, &count)) @@ -1195,8 +1198,8 @@ int64_t FullyConnectedOp::GetArithmeticCount(Operation *op) { // Conv2DOp //===----------------------------------------------------------------------===// -void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { +void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { // TODO(b/180121750): Enable the pattern after the integration tests are // fixed. // results.add>(context); @@ -1204,7 +1207,7 @@ void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results, static LogicalResult ComputeConvWindowedOutputSize( int64_t input_size, int64_t filter_size, int64_t dilation_rate, - int64_t stride, tensorflow::Padding padding, int64_t *output_size) { + int64_t stride, tensorflow::Padding padding, int64_t* output_size) { int64_t pad_low; int64_t pad_high; @@ -1217,9 +1220,9 @@ static LogicalResult ComputeConvWindowedOutputSize( } LogicalResult Conv2DOp::inferReturnTypes( - MLIRContext *, Optional location, ValueRange operands, + MLIRContext*, std::optional location, ValueRange operands, DictionaryAttr attr, RegionRange, - SmallVectorImpl &inferredReturnTypes) { + SmallVectorImpl& inferredReturnTypes) { Conv2DOpAdaptor op(operands, attr); const Value input = op.getInput(); @@ -1262,7 +1265,7 @@ LogicalResult Conv2DOp::inferReturnTypes( // Output always have rank 4. All dimensions are initialized to // dynamic size and can be partially inferred. // TFL's conv2d is always NHWC format & the filter is OHWI. - SmallVector return_shape(4, ShapedType::kDynamicSize); + SmallVector return_shape(4, ShapedType::kDynamic); return_shape[0] = input_ty.getDimSize(0); return_shape[3] = filter_ty.getDimSize(0); @@ -1304,7 +1307,7 @@ bool Conv2DOp::isCompatibleReturnTypes(TypeRange lhs, TypeRange rhs) { return true; } -int64_t Conv2DOp::GetArithmeticCount(Operation *op) { +int64_t Conv2DOp::GetArithmeticCount(Operation* op) { int64_t count; if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp( op, &count)) @@ -1317,14 +1320,14 @@ int64_t Conv2DOp::GetArithmeticCount(Operation *op) { // DepthwiseConv2DO //===----------------------------------------------------------------------===// -void DepthwiseConv2DOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { +void DepthwiseConv2DOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { // TODO(b/180121750): Enable the pattern after the integration tests are // fixed. 
// results.add>(context); } -int64_t DepthwiseConv2DOp::GetArithmeticCount(Operation *op) { +int64_t DepthwiseConv2DOp::GetArithmeticCount(Operation* op) { int64_t count; if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp( op, &count)) @@ -1337,7 +1340,7 @@ int64_t DepthwiseConv2DOp::GetArithmeticCount(Operation *op) { // GatherOp //===----------------------------------------------------------------------===// -static void BuildGatherOp(OpBuilder *builder, OperationState &result, +static void BuildGatherOp(OpBuilder* builder, OperationState& result, Value params, Value indices, IntegerAttr axis, IntegerAttr batch_dims) { auto params_type = params.getType().cast(); @@ -1481,7 +1484,7 @@ mlir::LogicalResult ScatterNdOp::verify() { // Checks whether the last `(shape_type.getDimSize(0) - outermost_dim)` // dimensions of `updates` and `shape` are equal. - for (const auto &shape_it : llvm::enumerate(shape_value)) { + for (const auto& shape_it : llvm::enumerate(shape_value)) { int64_t i = shape_it.index(); auto value = shape_it.value().getSExtValue(); if (i >= outermost_dim) { @@ -1497,7 +1500,7 @@ mlir::LogicalResult ScatterNdOp::verify() { // Checks if the output has the shape specified by `shape`. if (output_type.hasStaticShape()) { - for (const auto &shape_it : llvm::enumerate(shape_value)) { + for (const auto& shape_it : llvm::enumerate(shape_value)) { int i = shape_it.index(); auto value = shape_it.value().getSExtValue(); if (output_type.getDimSize(i) != value) { @@ -1515,7 +1518,8 @@ mlir::LogicalResult ScatterNdOp::verify() { // MulOp //===----------------------------------------------------------------------===// -OpFoldResult MulOp::fold(ArrayRef operands) { +OpFoldResult MulOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); // TODO(b/142478136): Handle fused ops. if (getFusedActivationFunction() != "NONE") return {}; @@ -1544,7 +1548,7 @@ OpFoldResult MulOp::fold(ArrayRef operands) { [](APInt a, APInt b) { return a * b; }); } -int64_t MulOp::GetArithmeticCount(Operation *op) { +int64_t MulOp::GetArithmeticCount(Operation* op) { int64_t count; if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) return count; @@ -1555,7 +1559,8 @@ int64_t MulOp::GetArithmeticCount(Operation *op) { // DivOp //===----------------------------------------------------------------------===// -OpFoldResult DivOp::fold(ArrayRef operands) { +OpFoldResult DivOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); // TODO(b/142478136): Handle fused ops. if (getFusedActivationFunction() != "NONE") return {}; return ConstFoldBinaryOp( @@ -1563,7 +1568,7 @@ OpFoldResult DivOp::fold(ArrayRef operands) { [](APInt a, APInt b) { return a.sdiv(b); }); } -int64_t DivOp::GetArithmeticCount(Operation *op) { +int64_t DivOp::GetArithmeticCount(Operation* op) { int64_t count; if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) return count; @@ -1661,16 +1666,16 @@ namespace { // TODO(antiagainst): This pattern probably should be moved to the peephole // category, after we have the infra for peephole passes. 
struct RemoveAdjacentReshape : public RewritePattern { - explicit RemoveAdjacentReshape(MLIRContext *context) + explicit RemoveAdjacentReshape(MLIRContext* context) : RewritePattern(ReshapeOp::getOperationName(), 1, context) {} - LogicalResult match(Operation *op) const override { + LogicalResult match(Operation* op) const override { auto thisOp = cast(op); auto prevOp = thisOp.getOperand(0).getDefiningOp(); return isa_and_nonnull(prevOp) ? success() : failure(); } - void rewrite(Operation *op, PatternRewriter &rewriter) const override { + void rewrite(Operation* op, PatternRewriter& rewriter) const override { auto thisOp = cast(op); auto prevOp = cast(thisOp.getOperand(0).getDefiningOp()); @@ -1693,7 +1698,7 @@ struct ConvertShapeTo1D : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ReshapeOp reshape, - PatternRewriter &rewriter) const override { + PatternRewriter& rewriter) const override { if (!reshape.getShape().hasOneUse()) return failure(); DenseIntElementsAttr shape; @@ -1733,7 +1738,8 @@ bool InputOutputHasSameShape(mlir::Type input_type, mlir::Type output_type) { } // end anonymous namespace -OpFoldResult ReshapeOp::fold(ArrayRef operands) { +OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); // Remove identity reshape with both static result and input shape. auto result_type = getType().cast(); auto input_type = getOperand(0).getType().cast(); @@ -1748,7 +1754,7 @@ OpFoldResult ReshapeOp::fold(ArrayRef operands) { if (!shape_elements) return nullptr; SmallVector shape_data; - for (const auto &it : shape_elements.getValues()) { + for (const auto& it : shape_elements.getValues()) { shape_data.push_back(it.getSExtValue()); } result_type = tensorflow::GetTypeFromTFTensorShape( @@ -1760,17 +1766,17 @@ OpFoldResult ReshapeOp::fold(ArrayRef operands) { return nullptr; } -void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { +void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { results.add(context); } using ReshapeErrorHandler = - llvm::function_ref; + llvm::function_ref; LogicalResult GetReshapeOutputType(Value input, Value shape, ReshapeErrorHandler error_handler, - TensorType &output_ty) { + TensorType& output_ty) { auto input_ty = input.getType().cast(); auto element_ty = input_ty.getElementType(); output_ty = UnrankedTensorType::get(element_ty); @@ -1787,7 +1793,7 @@ LogicalResult GetReshapeOutputType(Value input, Value shape, // shape. 
if (shape_ty.hasStaticShape()) { llvm::SmallVector dynamic_shape(shape_ty.getDimSize(0), - ShapedType::kDynamicSize); + ShapedType::kDynamic); output_ty = tensorflow::GetTypeFromTFTensorShape(dynamic_shape, element_ty); } @@ -1801,10 +1807,10 @@ LogicalResult GetReshapeOutputType(Value input, Value shape, int64_t shape_ty_size = 1; llvm::SmallVector output_ty_shape; output_ty_shape.reserve(shape_attr.getNumElements()); - for (const auto &dim : llvm::enumerate(shape_attr.getValues())) { + for (const auto& dim : llvm::enumerate(shape_attr.getValues())) { const int64_t size = dim.value().getSExtValue(); if (size == tensorflow::kTFDynamicSize || // NOLINT - size == ShapedType::kDynamicSize) { // NOLINT + size == ShapedType::kDynamic) { // NOLINT if (unknown_index != -1) return error_handler(llvm::formatv( "requires 'shape' to have at most one dynamic dimension, but got " @@ -1839,7 +1845,7 @@ LogicalResult GetReshapeOutputType(Value input, Value shape, // Compute number of elements in tensor shape. int64_t input_ty_size = 1; bool input_ty_zero_dim = false; - for (const auto &dim : input_ty.getShape()) { + for (const auto& dim : input_ty.getShape()) { if (dim > 0 || !shape_ty_zero_dim) { input_ty_size *= dim; } else { @@ -1866,7 +1872,7 @@ LogicalResult GetReshapeOutputType(Value input, Value shape, mlir::LogicalResult ReshapeOp::verify() { ReshapeOp op = *this; - auto error_handler = [&op](const llvm::Twine &message) -> LogicalResult { + auto error_handler = [&op](const llvm::Twine& message) -> LogicalResult { return op.emitOpError() << message; }; TensorType expected_ty; @@ -1895,14 +1901,14 @@ mlir::LogicalResult ReshapeOp::verify() { } LogicalResult ReshapeOp::inferReturnTypes( - MLIRContext *context, Optional location, ValueRange operands, + MLIRContext* context, std::optional location, ValueRange operands, DictionaryAttr attr, RegionRange, - SmallVectorImpl &inferredReturnTypes) { + SmallVectorImpl& inferredReturnTypes) { ReshapeOpAdaptor op(operands, attr); const Value input = op.getInput(); const Value shape = op.getShape(); - auto error_handler = [&](const llvm::Twine &message) -> LogicalResult { + auto error_handler = [&](const llvm::Twine& message) -> LogicalResult { // A dummy error handler. // Errors when computing the output shape will be raised in // ReshapeOp::verify call. @@ -1948,13 +1954,13 @@ bool ReshapeOp::isCompatibleReturnTypes(TypeRange lhs, TypeRange rhs) { // => Value [5, 8, 9] // TODO(b/133341698): Move to tablegen when variadic is supported. struct RemoveRedundantUnpackPack : public RewritePattern { - explicit RemoveRedundantUnpackPack(MLIRContext *context) + explicit RemoveRedundantUnpackPack(MLIRContext* context) : RewritePattern(PackOp::getOperationName(), 2, context) {} - LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override { TFL::PackOp pack_op = cast(op); - Operation *first_input = pack_op.getOperand(0).getDefiningOp(); + Operation* first_input = pack_op.getOperand(0).getDefiningOp(); if (!first_input) return failure(); auto input_unpack_op = dyn_cast_or_null(first_input); if (!input_unpack_op) return failure(); @@ -1986,10 +1992,10 @@ struct RemoveRedundantUnpackPack : public RewritePattern { // Replace PackOp with a reshape when there is only one operand. 
struct ReplacePackWithReshape : public RewritePattern { - explicit ReplacePackWithReshape(MLIRContext *context) + explicit ReplacePackWithReshape(MLIRContext* context) : RewritePattern(PackOp::getOperationName(), 2, context) {} - LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override { TFL::PackOp pack_op = cast(op); if (pack_op.getNumOperands() != 1) return failure(); @@ -2000,7 +2006,7 @@ struct ReplacePackWithReshape : public RewritePattern { // This is to workaround the unnecessary cast i64 -> i32. SmallVector new_shape_array; for (auto size : output_type.getShape()) { - new_shape_array.push_back(static_cast(size)); + new_shape_array.push_back(ConvertToTfliteSize(size)); } auto new_shape = rewriter.create( @@ -2016,8 +2022,8 @@ struct ReplacePackWithReshape : public RewritePattern { } }; -void PackOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { +void PackOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { results.add(context); } @@ -2046,7 +2052,7 @@ mlir::LogicalResult SliceOp::verify() { DenseIntElementsAttr begin; if (matchPattern(op.getBegin(), m_Constant(&begin))) { int axis = 0; - for (const auto &begin_i : llvm::enumerate(begin)) { + for (const auto& begin_i : llvm::enumerate(begin)) { if (begin_i.value().getSExtValue() < 0) { return op.emitError( llvm::formatv("begin[{0}] cannot be negative", axis)); @@ -2058,7 +2064,7 @@ mlir::LogicalResult SliceOp::verify() { DenseIntElementsAttr size; if (matchPattern(op.getSize(), m_Constant(&size))) { int axis = 0; - for (const auto &size_i : llvm::enumerate(size)) { + for (const auto& size_i : llvm::enumerate(size)) { if (size_i.value().getSExtValue() < -1) { return op.emitError( llvm::formatv("size[{0}] cannot be negative other than -1", axis)); @@ -2087,9 +2093,9 @@ mlir::LogicalResult SliceOp::verify() { return success(); } -TFL::ConstOp NarrowDownInt64InputValuesForOp(Operation *input_op, +TFL::ConstOp NarrowDownInt64InputValuesForOp(Operation* input_op, RankedTensorType value_type, - Location loc, OpBuilder *builder) { + Location loc, OpBuilder* builder) { if (input_op == nullptr) return nullptr; mlir::DenseIntElementsAttr attr; @@ -2102,8 +2108,8 @@ TFL::ConstOp NarrowDownInt64InputValuesForOp(Operation *input_op, SmallVector value_i32; value_i32.reserve(value_type.getRank()); - for (const auto &size : attr) { - value_i32.push_back(static_cast(size.getSExtValue())); + for (const auto& size : attr) { + value_i32.push_back(ConvertToTfliteSize(size.getSExtValue())); } auto new_value_i32_attr = mlir::DenseIntElementsAttr::get(value_shape_type, value_i32); @@ -2117,7 +2123,7 @@ struct CastDonwInt64BeginEndToInt32 : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TFL::SliceOp slice_op, - PatternRewriter &rewriter) const override { + PatternRewriter& rewriter) const override { auto begin = slice_op.getBegin(); auto size = slice_op.getSize(); auto begin_type = begin.getType().dyn_cast_or_null(); @@ -2151,8 +2157,8 @@ struct CastDonwInt64BeginEndToInt32 : public OpRewritePattern { } }; -void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { +void SliceOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { results.add(context); } @@ -2160,7 +2166,7 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results, // 
SqueezeOp //===----------------------------------------------------------------------===// -OpFoldResult SqueezeOp::fold(ArrayRef operands) { +OpFoldResult SqueezeOp::fold(FoldAdaptor) { auto input_ty = getInput().getType().dyn_cast(); auto result_ty = getType().dyn_cast(); @@ -2173,7 +2179,8 @@ OpFoldResult SqueezeOp::fold(ArrayRef operands) { // SubOp //===----------------------------------------------------------------------===// -OpFoldResult SubOp::fold(ArrayRef operands) { +OpFoldResult SubOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); // TODO(b/142478136): Handle fused ops. if (getFusedActivationFunction() != "NONE") return {}; return ConstFoldBinaryOp( @@ -2181,7 +2188,7 @@ OpFoldResult SubOp::fold(ArrayRef operands) { [](APInt a, APInt b) { return a - b; }); } -int64_t SubOp::GetArithmeticCount(Operation *op) { +int64_t SubOp::GetArithmeticCount(Operation* op) { int64_t count; if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) return count; @@ -2192,7 +2199,7 @@ int64_t SubOp::GetArithmeticCount(Operation *op) { // TopKOp //===----------------------------------------------------------------------===// -static void BuildTopKOp(OpBuilder *builder, OperationState &result, Value input, +static void BuildTopKOp(OpBuilder* builder, OperationState& result, Value input, Value k) { // Output size is only known if k is a constant value. A negative dimension is // considered dynamic so use -1 here if k is not a constant value. @@ -2225,7 +2232,7 @@ static void BuildTopKOp(OpBuilder *builder, OperationState &result, Value input, //===----------------------------------------------------------------------===// // Return true if the op has non-empty "minmax" attribute. -static inline bool HasValidMinMaxAttribute(Operation *op) { +static inline bool HasValidMinMaxAttribute(Operation* op) { auto minmax = op->getAttrOfType("minmax"); return minmax && minmax.getValue().size() == 2; } @@ -2235,31 +2242,31 @@ namespace { /// This pattern matches and removes a tfl.fake_quant if all the users of this op /// and itself have "minmax" attribute set. struct DropFakeQuant : public RewritePattern { - explicit DropFakeQuant(MLIRContext *context) + explicit DropFakeQuant(MLIRContext* context) : RewritePattern(FakeQuantOp::getOperationName(), 1, context) {} - LogicalResult match(Operation *op) const override { + LogicalResult match(Operation* op) const override { // We only match the op with valid "minmax" attribute. if (!HasValidMinMaxAttribute(op)) return failure(); // If all the users of this op have valid "minmax" attributes, it is matched // and can be removed. auto fakeQuantOp = cast(op); - for (auto *operand : fakeQuantOp.getResult().getUsers()) + for (auto* operand : fakeQuantOp.getResult().getUsers()) if (!HasValidMinMaxAttribute(operand)) return failure(); return success(); } - void rewrite(Operation *op, PatternRewriter &rewriter) const override { + void rewrite(Operation* op, PatternRewriter& rewriter) const override { // Replace the matched FakeQuantOp by its primary operand. 
rewriter.replaceOp(op, op->getOperand(0)); } }; } // end anonymous namespace -void FakeQuantOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { +void FakeQuantOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { results.add(context); } @@ -2270,9 +2277,9 @@ void FakeQuantOp::getCanonicalizationPatterns(RewritePatternSet &results, // TODO(b/133486129): Implement shape inference for unpack LogicalResult UnpackOp::inferReturnTypes( - MLIRContext *context, Optional loc, ValueRange operands, + MLIRContext* context, std::optional loc, ValueRange operands, DictionaryAttr attributes, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { + SmallVectorImpl& inferredReturnTypes) { UnpackOpAdaptor op(operands, attributes); // TODO(jpienaar): Refactor verify if (failed(op.verify(loc.has_value() ? *loc : UnknownLoc::get(context)))) @@ -2342,7 +2349,7 @@ bool UnpackOp::isCompatibleReturnTypes(TypeRange lhs, TypeRange rhs) { // Extracts and returns the signed integer constant in a 0-rank integer tensor // or 1-element 1-rank integer tensor if 'value' is a constant. -static llvm::Optional ExtractConstantIntFromTensor(Value value) { +static std::optional ExtractConstantIntFromTensor(Value value) { ElementsAttr attr; if (!matchPattern(value, m_Constant(&attr))) return {}; if (attr.getNumElements() != 1) return {}; @@ -2365,7 +2372,7 @@ static RankedTensorType SubstituteRankedTensorTypeDimSize( // Verifies the output tensor types of SplitOp or SplitVOp. template static LogicalResult VerifySplitOpOutputTypes( - Operation *op, int64_t num_splits, + Operation* op, int64_t num_splits, ExpectedOutputTypeGetter get_expected_output_type) { for (int64_t i = 0; i < num_splits; ++i) { auto expected_output_type = get_expected_output_type(i); @@ -2385,7 +2392,7 @@ mlir::LogicalResult SplitOp::verify() { return op.emitOpError("output count should match 'num_splits' attribute"); // If 'split_dim' is not a constant, there are no other checks. - llvm::Optional split_dim_opt = + std::optional split_dim_opt = ExtractConstantIntFromTensor(op.getSplitDim()); if (!split_dim_opt) return success(); @@ -2393,7 +2400,7 @@ mlir::LogicalResult SplitOp::verify() { auto input_type = op.getValue().getType().dyn_cast(); if (!input_type) return success(); - int64_t split_dim = split_dim_opt.getValue(); + int64_t split_dim = split_dim_opt.value(); const int64_t rank = input_type.getRank(); if (split_dim < 0) split_dim += rank; if (split_dim < 0 || split_dim >= rank) @@ -2422,7 +2429,7 @@ mlir::LogicalResult SplitVOp::verify() { return op.emitOpError("output count should match 'num_splits' attribute"); // If 'split_dim' is not a constant, there are no other checks. 
- llvm::Optional split_dim_opt = + std::optional split_dim_opt = ExtractConstantIntFromTensor(op.getSplitDim()); if (!split_dim_opt) return success(); @@ -2430,7 +2437,7 @@ mlir::LogicalResult SplitVOp::verify() { auto input_type = op.getValue().getType().dyn_cast(); if (!input_type) return success(); - int64_t split_dim = split_dim_opt.getValue(); + int64_t split_dim = split_dim_opt.value(); const int64_t rank = input_type.getRank(); if (split_dim < 0) split_dim += rank; if (split_dim < 0 || split_dim >= rank) @@ -2588,7 +2595,7 @@ struct RemoveLSTMOpZeroBias : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(LSTMOp op, - PatternRewriter &rewriter) const override { + PatternRewriter& rewriter) const override { if (EqualsZero(op.getInputGateBias())) { auto none_value = rewriter.create( rewriter.getUnknownLoc(), rewriter.getNoneType(), @@ -2609,8 +2616,8 @@ struct RemoveLSTMOpZeroBias : public OpRewritePattern { } // namespace -void LSTMOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { +void LSTMOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { results.add(context); } @@ -2629,8 +2636,9 @@ mlir::LogicalResult UnidirectionalSequenceLSTMOp::verify() { } LogicalResult UnidirectionalSequenceLSTMOp::inferReturnTypes( - MLIRContext *, Optional, ValueRange operands, DictionaryAttr attr, - RegionRange, SmallVectorImpl &inferredReturnTypes) { + MLIRContext*, std::optional, ValueRange operands, + DictionaryAttr attr, RegionRange, + SmallVectorImpl& inferredReturnTypes) { Value input = operands[0]; auto input_type = input.getType().dyn_cast_or_null(); @@ -2664,10 +2672,11 @@ LogicalResult UnidirectionalSequenceLSTMOp::inferReturnTypes( time_major_attr ? time_major_attr->getValue().cast().getValue() : false; - int batch = + int64_t batch = time_majored ? input_type.getDimSize(1) : input_type.getDimSize(0); - int time = time_majored ? input_type.getDimSize(0) : input_type.getDimSize(1); - int n_output = output_state_type.getDimSize(1); + int64_t time = + time_majored ? input_type.getDimSize(0) : input_type.getDimSize(1); + int64_t n_output = output_state_type.getDimSize(1); // Build the output shape. SmallVector output_shape; @@ -2736,7 +2745,8 @@ mlir::LogicalResult SVDFOp::verify() { // AbsOp //===----------------------------------------------------------------------===// -OpFoldResult AbsOp::fold(ArrayRef operands) { +OpFoldResult AbsOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); Type result_type = getType(); // Only constant fold for tensor of f32 is implemented. if (!IsF32ShapedType(result_type)) return nullptr; @@ -2749,7 +2759,8 @@ OpFoldResult AbsOp::fold(ArrayRef operands) { // NegOp //===----------------------------------------------------------------------===// -OpFoldResult NegOp::fold(ArrayRef operands) { +OpFoldResult NegOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); Type result_type = getType(); // Only constant fold for tensor of f32 is implemented. if (!IsF32ShapedType(result_type)) return nullptr; @@ -2762,7 +2773,8 @@ OpFoldResult NegOp::fold(ArrayRef operands) { // SinOp //===----------------------------------------------------------------------===// -OpFoldResult SinOp::fold(ArrayRef operands) { +OpFoldResult SinOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); Type result_type = getType(); // Only constant fold for tensor of f32 is implemented. 
if (!IsF32ShapedType(result_type)) return nullptr; @@ -2779,7 +2791,8 @@ OpFoldResult SinOp::fold(ArrayRef operands) { // CosOp //===----------------------------------------------------------------------===// -OpFoldResult CosOp::fold(ArrayRef operands) { +OpFoldResult CosOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); Type result_type = getType(); // Only constant fold for tensor of f32 is implemented. if (!IsF32ShapedType(result_type)) return nullptr; @@ -2796,7 +2809,8 @@ OpFoldResult CosOp::fold(ArrayRef operands) { // LogOp //===----------------------------------------------------------------------===// -OpFoldResult LogOp::fold(ArrayRef operands) { +OpFoldResult LogOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); Type result_type = getType(); // Only constant fold for tensor of f32 is implemented. if (!IsF32ShapedType(result_type)) return nullptr; @@ -2813,7 +2827,7 @@ OpFoldResult LogOp::fold(ArrayRef operands) { // ShapeOp //===----------------------------------------------------------------------===// -OpFoldResult ShapeOp::fold(ArrayRef operands) { +OpFoldResult ShapeOp::fold(FoldAdaptor) { auto input_type = getInput().getType().cast(); if (!input_type.hasStaticShape()) return nullptr; @@ -2836,7 +2850,8 @@ OpFoldResult ShapeOp::fold(ArrayRef operands) { // SqrtOp //===----------------------------------------------------------------------===// -OpFoldResult SqrtOp::fold(ArrayRef operands) { +OpFoldResult SqrtOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); Type result_type = getType(); // Only constant fold for tensor of f32 is implemented. if (!IsF32ShapedType(result_type)) return nullptr; @@ -2853,7 +2868,8 @@ OpFoldResult SqrtOp::fold(ArrayRef operands) { // RsqrtOp //===----------------------------------------------------------------------===// -OpFoldResult RsqrtOp::fold(ArrayRef operands) { +OpFoldResult RsqrtOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); Type result_type = getType(); // Only constant fold for tensor of f32/bf16 is implemented. if (!IsF32ShapedType(result_type) && !IsBF16ShapedType(result_type)) @@ -2861,7 +2877,7 @@ OpFoldResult RsqrtOp::fold(ArrayRef operands) { auto compute = [](APFloat value) -> APFloat { bool loseInfo; - const llvm::fltSemantics &original_float_semantics = value.getSemantics(); + const llvm::fltSemantics& original_float_semantics = value.getSemantics(); value.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &loseInfo); float f = value.convertToFloat(); @@ -2877,7 +2893,8 @@ OpFoldResult RsqrtOp::fold(ArrayRef operands) { // SquareOp //===----------------------------------------------------------------------===// -OpFoldResult SquareOp::fold(ArrayRef operands) { +OpFoldResult SquareOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); Type result_type = getType(); // Only constant fold for tensor of f32 is implemented. 
if (!IsF32ShapedType(result_type)) return nullptr; @@ -2890,7 +2907,8 @@ OpFoldResult SquareOp::fold(ArrayRef operands) { // RankOp //===----------------------------------------------------------------------===// -OpFoldResult RankOp::fold(ArrayRef operands) { +OpFoldResult RankOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); assert(operands.size() == 1); auto result_type = getType().cast(); if (auto elements_attr = operands[0].dyn_cast_or_null()) { @@ -2916,7 +2934,9 @@ OpFoldResult RankOp::fold(ArrayRef operands) { // ConstOp //===----------------------------------------------------------------------===// -OpFoldResult ConstOp::fold(ArrayRef operands) { +OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); + (void)operands; assert(operands.empty() && "constant has no operands"); // Return the held attribute value. return getValue(); @@ -2937,7 +2957,7 @@ struct FoldPseudoConstOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ConstOp const_op, - PatternRewriter &rewriter) const override { + PatternRewriter& rewriter) const override { if (arith::ConstantOp::isBuildableWith(const_op.getValue(), const_op.getType())) { rewriter.replaceOpWithNewOp(const_op, @@ -2956,8 +2976,8 @@ struct FoldPseudoConstOp : public OpRewritePattern { } // namespace -void ConstOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { +void ConstOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { results.add(context); } @@ -2965,7 +2985,8 @@ void ConstOp::getCanonicalizationPatterns(RewritePatternSet &results, // CastOp //===----------------------------------------------------------------------===// -OpFoldResult CastOp::fold(ArrayRef operands) { +OpFoldResult CastOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); assert(operands.size() == 1); if (getElementTypeOrSelf(getInput()) == getElementTypeOrSelf(getType())) { return getInput(); @@ -3016,7 +3037,7 @@ OpFoldResult CastOp::fold(ArrayRef operands) { // SelectV2Op //===----------------------------------------------------------------------===// -static void BuildSelectV2Op(Builder *builder, OperationState &result, +static void BuildSelectV2Op(Builder* builder, OperationState& result, Value cond, Value x, Value y) { auto operand_type = OpTrait::util::getBroadcastedType(x.getType(), y.getType()); @@ -3106,7 +3127,8 @@ DenseElementsAttr BuildConstRangeTensor(Type result_elem_type, int num_elements, } } // namespace -OpFoldResult RangeOp::fold(ArrayRef operands) { +OpFoldResult RangeOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); assert(operands.size() == 3); auto start_tensor = operands[0].dyn_cast_or_null(); auto limit_tensor = operands[1].dyn_cast_or_null(); @@ -3178,7 +3200,7 @@ mlir::LogicalResult TransposeConvOp::verify() { return success(); } -int64_t TransposeConvOp::GetArithmeticCount(Operation *op) { +int64_t TransposeConvOp::GetArithmeticCount(Operation* op) { int64_t count = -1; auto transpose_conv = llvm::dyn_cast(op); auto input_type = transpose_conv.getInput() @@ -3240,7 +3262,7 @@ LogicalResult StridedSliceOp::verify() { return success(); } -OpFoldResult StridedSliceOp::fold(ArrayRef operands) { +OpFoldResult StridedSliceOp::fold(FoldAdaptor) { // Currently only support all masks being 0. 
if (getBeginMask() != 0 || getEndMask() != 0 || getEllipsisMask() != 0 || getNewAxisMask() != 0 || getShrinkAxisMask() != 0) @@ -3298,8 +3320,8 @@ namespace { // `new_values`. void ComputePermutation(ElementsAttr input_tensor, ArrayRef perm, ArrayRef output_shape, int num_dimensions, - int output_axis, std::vector *input_indices, - std::vector *new_values) { + int output_axis, std::vector* input_indices, + std::vector* new_values) { // Refer to the implementation of `Transpose` function in // tensorflow/lite/kernels/internal/reference/reference_ops.h assert(output_axis < num_dimensions); @@ -3322,7 +3344,8 @@ void ComputePermutation(ElementsAttr input_tensor, ArrayRef perm, } // namespace -OpFoldResult TransposeOp::fold(ArrayRef operands) { +OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); assert(operands.size() == 2); auto input_tensor = operands[0].dyn_cast_or_null(); auto perm_tensor = operands[1].dyn_cast_or_null(); @@ -3380,7 +3403,7 @@ mlir::LogicalResult TransposeOp::verify() { int index = 0; llvm::SmallVector axes; - for (const auto &axis_int : perm.getValues()) { + for (const auto& axis_int : perm.getValues()) { const int64_t axis = axis_int.getSExtValue(); if (axis < 0 || (input_type.hasRank() && axis >= input_type.getRank())) { return op.emitOpError( @@ -3430,7 +3453,7 @@ mlir::LogicalResult TransposeOp::verify() { return success(); } -static void BuildTransposeOp(OpBuilder *builder, OperationState &result, +static void BuildTransposeOp(OpBuilder* builder, OperationState& result, Value input, Value perm) { // Output size is only known if input is ranked and perm is a constant. auto input_type = input.getType().cast(); @@ -3490,9 +3513,9 @@ static void BuildTransposeOp(OpBuilder *builder, OperationState &result, /// during the flow of control. `operands` is a set of optional attributes that /// correspond to a constant value for each operand, or null if that operand is /// not a constant. -void IfOp::getSuccessorRegions(Optional index, +void IfOp::getSuccessorRegions(std::optional index, ArrayRef operands, - SmallVectorImpl ®ions) { + SmallVectorImpl& regions) { // The `then` and the `else` region branch back to the parent operation. if (index.has_value()) { regions.push_back(RegionSuccessor(getResults())); @@ -3500,7 +3523,7 @@ void IfOp::getSuccessorRegions(Optional index, } // Don't consider the else region if it is empty. - Region *else_reg = &getElseRegion(); + Region* else_reg = &getElseRegion(); if (else_reg->empty()) else_reg = nullptr; // Otherwise, the successor is dependent on the condition. @@ -3531,7 +3554,7 @@ struct PolyCallResultOperandsMatchAndImplicitCapture using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(PolyCallOp while_op, - PatternRewriter &rewriter) const override { + PatternRewriter& rewriter) const override { // Finish this. return success(); } @@ -3539,14 +3562,14 @@ struct PolyCallResultOperandsMatchAndImplicitCapture } // namespace -void PolyCallOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { +void PolyCallOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { results.add(context); } void PolyCallOp::getSuccessorRegions( - Optional index, ArrayRef operands, - SmallVectorImpl ®ions) { + std::optional index, ArrayRef operands, + SmallVectorImpl& regions) { // Defaults to first region for TFLite execution. 
} @@ -3588,15 +3611,15 @@ struct WhileResultOperandsMatchAndImplicitCapture using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(WhileOp while_op, - PatternRewriter &rewriter) const override { + PatternRewriter& rewriter) const override { // Replace values simply passed through the body with extern values // (in both body and condition regions as well as while result). The // block arguments of body and while match and so the corresponding cond // argument can be easily found. bool unchanged = true; - auto &body_block = while_op.getBody().front(); - auto &cond_block = while_op.getCond().front(); - auto &yield = *body_block.getTerminator(); + auto& body_block = while_op.getBody().front(); + auto& cond_block = while_op.getCond().front(); + auto& yield = *body_block.getTerminator(); for (auto ba : body_block.getArguments()) { int arg_no = ba.getArgNumber(); // Skip removing resources that are not read-only variables. @@ -3666,8 +3689,8 @@ struct WhileResultOperandsMatchAndImplicitCapture if (unchanged) return failure(); // Replace with new While with matching operands and results. - Operation *op = while_op.getOperation(); - Operation *new_op = rewriter.insert( + Operation* op = while_op.getOperation(); + Operation* new_op = rewriter.insert( Operation::create(op->getLoc(), op->getName(), types, new_operands, op->getAttrs(), {}, /*numRegions=*/2)); @@ -3680,7 +3703,7 @@ struct WhileResultOperandsMatchAndImplicitCapture } rewriter.eraseOp(op); - Block &new_body_block = cast(new_op).getBody().front(); + Block& new_body_block = cast(new_op).getBody().front(); rewriter.setInsertionPointToEnd(&new_body_block); rewriter.replaceOpWithNewOp(new_body_block.getTerminator(), new_body_yield); @@ -3691,12 +3714,12 @@ struct WhileResultOperandsMatchAndImplicitCapture } // namespace -void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { +void WhileOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { results.add(context); } -Region &WhileOp::getLoopBody() { return getBody(); } +Region& WhileOp::getLoopBody() { return getBody(); } bool WhileOp::isDefinedOutsideOfLoop(Value value) { // TODO(jpienaar): This is overly conservative and disables anything other @@ -3708,7 +3731,7 @@ bool WhileOp::isDefinedOutsideOfLoop(Value value) { // LogisticOp //===----------------------------------------------------------------------===// -int64_t LogisticOp::GetArithmeticCount(Operation *op) { +int64_t LogisticOp::GetArithmeticCount(Operation* op) { int64_t count; // As a very rough ballpark, the cost of evaluating a math function // such as tanh or logistic is about 32 multiplications, and about as @@ -3724,7 +3747,7 @@ int64_t LogisticOp::GetArithmeticCount(Operation *op) { // LogSoftmaxOp //===----------------------------------------------------------------------===// -int64_t LogSoftmaxOp::GetArithmeticCount(Operation *op) { +int64_t LogSoftmaxOp::GetArithmeticCount(Operation* op) { int64_t count; // As a very rough ballpark, the cost of evaluating a math function // such as tanh or logistic is about 32 multiplications, and about as @@ -3740,7 +3763,7 @@ int64_t LogSoftmaxOp::GetArithmeticCount(Operation *op) { // SoftmaxOp //===----------------------------------------------------------------------===// -int64_t SoftmaxOp::GetArithmeticCount(Operation *op) { +int64_t SoftmaxOp::GetArithmeticCount(Operation* op) { int64_t count; // As a very rough ballpark, the cost of evaluating a math function // such as tanh or logistic is 
about 32 multiplications, and about as @@ -3756,7 +3779,7 @@ int64_t SoftmaxOp::GetArithmeticCount(Operation *op) { // TanhOp //===----------------------------------------------------------------------===// -int64_t TanhOp::GetArithmeticCount(Operation *op) { +int64_t TanhOp::GetArithmeticCount(Operation* op) { int64_t count; // As a very rough ballpark, the cost of evaluating a math function // such as tanh or logistic is about 32 multiplications, and about as @@ -3772,7 +3795,7 @@ int64_t TanhOp::GetArithmeticCount(Operation *op) { // AddNOp //===----------------------------------------------------------------------===// -int64_t AddNOp::GetArithmeticCount(Operation *op) { +int64_t AddNOp::GetArithmeticCount(Operation* op) { int64_t count; if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) { // AddN cost is roughly the same cost as N-1 Adds. @@ -3787,7 +3810,7 @@ int64_t AddNOp::GetArithmeticCount(Operation *op) { // AveragePool2DOp //===----------------------------------------------------------------------===// -int64_t AveragePool2DOp::GetArithmeticCount(Operation *op) { +int64_t AveragePool2DOp::GetArithmeticCount(Operation* op) { int64_t count; if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) { auto avg_pool = llvm::dyn_cast(op); @@ -3801,7 +3824,7 @@ int64_t AveragePool2DOp::GetArithmeticCount(Operation *op) { // MaxPool2DOp //===----------------------------------------------------------------------===// -int64_t MaxPool2DOp::GetArithmeticCount(Operation *op) { +int64_t MaxPool2DOp::GetArithmeticCount(Operation* op) { int64_t count; if (ArithmeticCountUtilHelper::GetFirstOutputCount(op, &count)) { auto max_pool = llvm::dyn_cast(op); @@ -3815,7 +3838,7 @@ int64_t MaxPool2DOp::GetArithmeticCount(Operation *op) { // L2NormalizationOp //===----------------------------------------------------------------------===// -int64_t L2NormalizationOp::GetArithmeticCount(Operation *op) { +int64_t L2NormalizationOp::GetArithmeticCount(Operation* op) { int64_t count; // Computing the squared L2 norm is N multiply-adds so 2N ops, // then the single inverse-sqrt is negligible, then we multiply each @@ -3831,7 +3854,7 @@ int64_t L2NormalizationOp::GetArithmeticCount(Operation *op) { // PadOp //===----------------------------------------------------------------------===// -OpFoldResult PadOp::fold(ArrayRef operands) { +OpFoldResult PadOp::fold(FoldAdaptor) { if (InputOutputHasSameShape(getInput().getType(), getOutput().getType())) return getInput(); @@ -3842,7 +3865,7 @@ OpFoldResult PadOp::fold(ArrayRef operands) { // PadV2Op //===----------------------------------------------------------------------===// -OpFoldResult PadV2Op::fold(ArrayRef operands) { +OpFoldResult PadV2Op::fold(FoldAdaptor) { if (InputOutputHasSameShape(getInput().getType(), getOutput().getType())) return getInput(); @@ -3853,9 +3876,7 @@ OpFoldResult PadV2Op::fold(ArrayRef operands) { // NoValueOp //===----------------------------------------------------------------------===// -OpFoldResult NoValueOp::fold(ArrayRef operands) { - return getValueAttr(); -} +OpFoldResult NoValueOp::fold(FoldAdaptor) { return getValueAttr(); } bool NoValueOp::isBuildableWith(Attribute value, Type type) { return value.isa() && type.isa(); @@ -3871,7 +3892,7 @@ bool ControlNodeOp::WrapsSinglePerfectlyForwardedOp() { auto body = GetBody().without_terminator(); if (!hasSingleElement(body)) return false; - Operation &controlled_op = *body.begin(); + Operation& controlled_op = *body.begin(); YieldOp yield = GetYield(); return 
controlled_op.getNumResults() == yield.getNumOperands() && std::equal(controlled_op.getResults().begin(), @@ -3884,7 +3905,7 @@ mlir::LogicalResult ControlNodeOp::verify() { if (!control_node.GetBody().args_empty()) return control_node.emitOpError() << "expects body without any arguments"; - Operation &yield = control_node.GetBody().back(); + Operation& yield = control_node.GetBody().back(); if (!isa(yield)) return yield.emitOpError() << "invalid TFL.control_node terminator, yield expected"; @@ -3907,7 +3928,7 @@ mlir::LogicalResult ControlNodeOp::verify() { return success(); } -void ControlNodeOp::print(OpAsmPrinter &p) { +void ControlNodeOp::print(OpAsmPrinter& p) { if (getNumOperands()) { // These are always control operand, no explicit type needed. p << '('; @@ -3918,7 +3939,7 @@ void ControlNodeOp::print(OpAsmPrinter &p) { // control_node contains a single operation and the results of this operation // are perfectly forwarded to the yield. if (getOperation()->getAttrs().empty() && WrapsSinglePerfectlyForwardedOp()) { - Operation &controlled_op = GetBody().front(); + Operation& controlled_op = GetBody().front(); // The "controls" syntax only encodes a single location. YieldOp yield_op = GetYield(); // In order to correctly round-trip, we can only use this syntax when all @@ -3934,7 +3955,7 @@ void ControlNodeOp::print(OpAsmPrinter &p) { p.printOptionalAttrDict(getOperation()->getAttrs()); } -ParseResult ControlNodeOp::parse(OpAsmParser &parser, OperationState &result) { +ParseResult ControlNodeOp::parse(OpAsmParser& parser, OperationState& result) { // Parse the body region. llvm::SMLoc loc = parser.getCurrentLocation(); Type control_type = ControlType::get(parser.getBuilder().getContext()); @@ -3949,15 +3970,15 @@ ParseResult ControlNodeOp::parse(OpAsmParser &parser, OperationState &result) { return failure(); } - Region &body = *result.addRegion(); + Region& body = *result.addRegion(); if (succeeded(parser.parseOptionalKeyword("controls"))) { // If we parse the short version of the control node, we have an operation // in the generic form that follows the "controls" keyword. Parse it inside // the region and forward all of its results as-is to the yield operation. body.push_back(new Block); - Block &block = body.back(); - Operation *controlled_op = + Block& block = body.back(); + Operation* controlled_op = parser.parseGenericOperation(&block, block.begin()); if (!controlled_op) return failure(); OpBuilder builder(parser.getBuilder().getContext()); @@ -3972,7 +3993,7 @@ ParseResult ControlNodeOp::parse(OpAsmParser &parser, OperationState &result) { ControlNodeOp::ensureTerminator(body, parser.getBuilder(), result.location); // Get the results type for the control node from the terminator operands. 
- Operation &yield = body.back().back(); + Operation& yield = body.back().back(); result.types.reserve(yield.getNumOperands() + 1); result.types.append(yield.operand_type_begin(), yield.operand_type_end()); result.types.push_back(control_type); @@ -3986,7 +4007,8 @@ ParseResult ControlNodeOp::parse(OpAsmParser &parser, OperationState &result) { // EmbeddingLookupOp //===----------------------------------------------------------------------===// -OpFoldResult EmbeddingLookupOp::fold(ArrayRef operands) { +OpFoldResult EmbeddingLookupOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); auto lookup_attr = operands[0].dyn_cast_or_null(); auto value_attr = operands[1].dyn_cast_or_null(); if (!lookup_attr || !value_attr) { @@ -4031,7 +4053,7 @@ OpFoldResult EmbeddingLookupOp::fold(ArrayRef operands) { // ConstBytesAttr //===----------------------------------------------------------------------===// -Attribute ConstBytesAttr::parse(AsmParser &parser, Type type) { +Attribute ConstBytesAttr::parse(AsmParser& parser, Type type) { if (parser.parseColon()) { return nullptr; } @@ -4049,7 +4071,7 @@ Attribute ConstBytesAttr::parse(AsmParser &parser, Type type) { return ConstBytesAttr::get(parser.getBuilder().getContext(), bytes_data); } -void ConstBytesAttr::print(mlir::AsmPrinter &printer) const { +void ConstBytesAttr::print(mlir::AsmPrinter& printer) const { StringRef bytes_str = getValue(); printer << " : \"0x" << llvm::toHex(bytes_str) << "\""; } @@ -4060,7 +4082,7 @@ void ConstBytesAttr::print(mlir::AsmPrinter &printer) const { #include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.cc.inc" -static FailureOr> parseI32Array(AsmParser &parser) { +static FailureOr> parseI32Array(AsmParser& parser) { SmallVector elements; auto elementParser = [&]() { int32_t element; @@ -4089,7 +4111,7 @@ namespace TFL { #include "tensorflow/compiler/mlir/lite/runtime_verifiers.inc" -Operation *TFLDialect::materializeConstant(OpBuilder &builder, Attribute value, +Operation* TFLDialect::materializeConstant(OpBuilder& builder, Attribute value, Type type, Location loc) { // If this is a constant bytes attribute or the result type doesn't match the // attribute type, then generate a tfl.pseudo_const. diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index ec8938aa358..b624618ea35 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -508,9 +508,9 @@ value of each element in `x`. For example, if x is an input element and y is an output element, this operation computes \\(y = |x|\\). 
}]; - let arguments = (ins TFL_TensorOf<[I16, F32, QI8, QI16]>:$x); + let arguments = (ins TFL_TensorOf<[I16, I32, F32, QI8, QI16]>:$x); - let results = (outs TFL_TensorOf<[I16, F32, QI8, QI16]>:$y); + let results = (outs TFL_TensorOf<[I16, I32, F32, QI8, QI16]>:$y); let hasFolder = 1; } @@ -654,7 +654,8 @@ def TFL_TransposeConvOp: TFL_Op<"transpose_conv", [ TFL_TensorOfOrNone<[F32, QI32, I64]>:$bias, TFL_PaddingAttr:$padding, ConfinedAttr:$stride_h, - ConfinedAttr:$stride_w + ConfinedAttr:$stride_w, + TFL_AFAttr:$fused_activation_function ); let results = (outs TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$output); @@ -996,7 +997,7 @@ def TFL_DepthwiseConv2DOp : DynamicRangeQuantizedOpInterface]> { let arguments = ( ins TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$input, - TFL_TensorOf<[F32, QI8, QUI8]>:$filter, + TFL_TensorOf<[F32, QI4, QI8, QUI8]>:$filter, TFL_1DTensorOfOrNone<[F32, I32, I64]>:$bias, I32Attr:$dilation_h_factor, I32Attr:$dilation_w_factor, @@ -1036,7 +1037,7 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [ let arguments = (ins TFL_TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$input, - TFL_TensorOf<[F32, QI8, QUI8, QI16]>:$filter, + TFL_TensorOf<[F32, QI4, QI8, QUI8, QI16]>:$filter, TFL_TensorOfOrNone<[F32, QI32, QUI32]>:$bias, TFL_AFAttr:$fused_activation_function, @@ -2883,6 +2884,24 @@ def TFL_Relu6Op: TFL_Op<"relu6", [ ]; } +def TFL_Relu0To1Op: TFL_Op<"relu_0_to_1", [ + PredOpTrait<"x and y must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + Pure, + QuantizableResult, + SameOperandsAndResultShape]> { + let summary = "Relu0To1 operator"; + + let description = [{ + Element-wise Relu0To1 operator + x -> max(0, min(1, x)) + }]; + + let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8]>:$x); + + let results = (outs TFL_TensorOf<[F32, QUI8, QI8]>:$y); +} + def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [ PredOpTrait<"x and y must have same element type", TFL_TCresVTEtIsSameAsOp<0, 0>>, @@ -3866,10 +3885,10 @@ def TFL_CastOp : TFL_Op<"cast", [ }]; let arguments = (ins - TFL_TensorOf<[F32, I1, I16, UI16, I32, UI32, I64, TFL_Quint8, UI8, I8, Complex>]>:$input + TFL_TensorOf<[F16, F32, F64, I1, I16, UI16, I32, UI32, I64, TFL_Quint8, UI8, I8, Complex>]>:$input ); - let results = (outs TFL_TensorOf<[F32, I1, I16, UI16, I32, UI32, I64, TFL_Quint8, UI8, I8, Complex>]>:$output); + let results = (outs TFL_TensorOf<[F16, F32, F64, I1, I16, UI16, I32, UI32, I64, TFL_Quint8, UI8, I8, Complex>]>:$output); // TFLite's cast op does not utilize CastOptions, instead derives types // from the TfLiteTensors. @@ -3998,7 +4017,7 @@ def TFL_DynamicUpdateSliceOp: TFL_Op<"dynamic_update_slice", [ //===----------------------------------------------------------------------===// // Quantization ops. //===----------------------------------------------------------------------===// -def TFL_DequantizeOp: TFL_Op<"dequantize", []> { +def TFL_DequantizeOp: TFL_Op<"dequantize", [NoMemoryEffect]> { let summary = "Dequantize operator"; let description = [{ @@ -4104,7 +4123,7 @@ def TFL_SparseQConstOp : Op { + SameOperandsAndResultShape, NoMemoryEffect]> { let summary = "Quantize operator"; let description = [{ @@ -4277,7 +4296,7 @@ Ba et al. 
'Layer Normalization' }]; let arguments = ( - ins TFL_TensorOf<[F32, QI8]>:$input, + ins TFL_TensorOf<[F32, QI8, QI16]>:$input, // Weights TFL_TensorOfOrNone<[F32, QI8]>:$input_to_input_weights, @@ -4377,24 +4396,24 @@ def TFL_UnidirectionalSequenceLSTMOp : LstmProjectionWeightBiasConstraint, LstmCifgInputConstraint, LstmResultConstraint, - TFL_OperandHasRankAtLeast<0, 2>, // input - TFL_OperandIsNoneOrHasRank<1, 2>, // input_to_input_weights - TFL_OperandHasRank<2, 2>, // input_to_forget_weights - TFL_OperandHasRank<3, 2>, // input_to_cell_weights - TFL_OperandHasRank<4, 2>, // input_to_output_weights - TFL_OperandIsNoneOrHasRank<5, 2>, // recurrent_to_input_weights - TFL_OperandHasRank<6, 2>, // recurrent_to_forget_weights - TFL_OperandHasRank<7, 2>, // recurrent_to_cell_weights - TFL_OperandHasRank<8, 2>, // recurrent_to_output_weights - TFL_OperandIsNoneOrHasRank<9, 1>, // cell_to_input_weights - TFL_OperandIsNoneOrHasRank<10, 1>, // cell_to_forget_weights - TFL_OperandIsNoneOrHasRank<11, 1>, // cell_to_output_weights - TFL_OperandIsNoneOrHasRank<12, 1>, // input_gate_bias - TFL_OperandHasRank<13, 1>, // forget_gate_bias - TFL_OperandHasRank<14, 1>, // cell_gate_bias - TFL_OperandHasRank<15, 1>, // output_gate_bias - TFL_OperandIsNoneOrHasRank<16, 2>, // projection_weights - TFL_OperandIsNoneOrHasRank<17, 1>, // projection_bias + TFL_OperandHasRankAtLeast<0, 2>, // input + TFL_OperandIsNoneOrHasRank<1, 2>, // input_to_input_weights + TFL_OperandHasRank<2, 2>, // input_to_forget_weights + TFL_OperandHasRank<3, 2>, // input_to_cell_weights + TFL_OperandHasRank<4, 2>, // input_to_output_weights + TFL_OperandIsNoneOrHasRankAtMost<5, 2>,// recurrent_to_input_weights + TFL_OperandHasRankAtMost<6, 2>, // recurrent_to_forget_weights + TFL_OperandHasRankAtMost<7, 2>, // recurrent_to_cell_weights + TFL_OperandHasRankAtMost<8, 2>, // recurrent_to_output_weights + TFL_OperandIsNoneOrHasRank<9, 1>, // cell_to_input_weights + TFL_OperandIsNoneOrHasRank<10, 1>, // cell_to_forget_weights + TFL_OperandIsNoneOrHasRank<11, 1>, // cell_to_output_weights + TFL_OperandIsNoneOrHasRank<12, 1>, // input_gate_bias + TFL_OperandHasRank<13, 1>, // forget_gate_bias + TFL_OperandHasRank<14, 1>, // cell_gate_bias + TFL_OperandHasRank<15, 1>, // output_gate_bias + TFL_OperandIsNoneOrHasRank<16, 2>, // projection_weights + TFL_OperandIsNoneOrHasRank<17, 1>, // projection_bias TFL_StatefulOp, DeclareOpInterfaceMethods, QuantizableResult, @@ -4464,6 +4483,9 @@ def TFL_UnidirectionalSequenceLSTMOp : // Used in post-training dynamic range quantization. If the value is true, // input activations are asymmetrically quantized. OptionalAttr:$asymmetric_quantize_inputs, + // IndyLSTM optimizations (i.e. optimizations enabled because of diagonal + // recurrent weight matrices that are provided as vectors) + OptionalAttr:$diagonal_recurrent_tensors, // Types of the optional intermediate tensors, which exist for fully // quantized op and hold the ranges of the intermediate tensors. @@ -4995,7 +5017,7 @@ def TFL_Atan2Op: TFL_Op<"atan2", [ let summary = "Atan2 operation"; let description = [{ - The "atan2" operation computes the arctangent of y/x element-wise, + The "atan2" operation computes the arctangent of y/x element-wise, respecting signs of the arguments. }]; @@ -5014,18 +5036,18 @@ def TFL_SignOp: TFL_Op<"sign", [ SameOperandsAndResultShape, SameOperandsAndResultElementType ]> { - + let summary = "Sign operation"; let description = [{ Returns NaN if x is NaN, 0 if x is 0, -1 if x < 0 and 1 if x > 0. 
}]; let arguments = (ins - TFL_TensorOf<[F32, F64]>:$x + TFL_TensorOf<[F32, F64, I32]>:$x ); let results = (outs - TFL_TensorOf<[F32, F64]>:$output + TFL_TensorOf<[F32, F64, I32]>:$output ); } @@ -5041,7 +5063,7 @@ def TFL_YieldOp : Op:$operands); + let arguments = (ins Variadic); // Default builder needed for ensureTerminator let builders = [ diff --git a/tensorflow/compiler/mlir/lite/metrics/BUILD b/tensorflow/compiler/mlir/lite/metrics/BUILD index b6df8dcabac..4173130bafe 100644 --- a/tensorflow/compiler/mlir/lite/metrics/BUILD +++ b/tensorflow/compiler/mlir/lite/metrics/BUILD @@ -2,7 +2,9 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ + "//tensorflow:__pkg__", "//tensorflow/compiler/mlir/lite:__subpackages__", "//tensorflow/lite/python:__subpackages__", "//tensorflow/lite/toco/python:__subpackages__", @@ -45,9 +47,9 @@ tf_cc_test( ":error_collector_inst", ":types_util", "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/xla/stream_executor/lib", "//tensorflow/core:test", "//tensorflow/core/platform:resource_loader", + "//tensorflow/tsl/platform:statusor", "@com_google_googletest//:gtest_main", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", diff --git a/tensorflow/compiler/mlir/lite/metrics/error_collector_inst_test.cc b/tensorflow/compiler/mlir/lite/metrics/error_collector_inst_test.cc index 824c464fd33..7874f6c5f3c 100644 --- a/tensorflow/compiler/mlir/lite/metrics/error_collector_inst_test.cc +++ b/tensorflow/compiler/mlir/lite/metrics/error_collector_inst_test.cc @@ -33,14 +33,14 @@ limitations under the License. #include "mlir/Support/FileUtilities.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/metrics/types_util.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/core/platform/resource_loader.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/tsl/platform/statusor.h" namespace mlir { namespace TFL { namespace { -using stream_executor::port::StatusOr; +using tsl::StatusOr; // MockSuccessPass reports errors but doesn't fail. 
class MockSuccessPass @@ -109,13 +109,14 @@ TEST(ErrorCollectorTest, TessSuccessPass) { "tensorflow/compiler/mlir/lite/metrics/testdata/strided_slice.mlir"); MLIRContext context; context.getOrLoadDialect(); - context.allowUnregisteredDialects(); + context.getOrLoadDialect(); context.enableMultithreading(); auto module = LoadModule(&context, input_file); EXPECT_EQ(module.ok(), true); - PassManager pm(&context, OpPassManager::Nesting::Implicit); + PassManager pm(module.value().get()->getName(), + OpPassManager::Nesting::Implicit); pm.addPass(std::make_unique()); pm.addInstrumentation( @@ -131,18 +132,19 @@ TEST(ErrorCollectorTest, TessFailurePass) { using tflite::metrics::ConverterErrorData; MLIRContext context; context.getOrLoadDialect(); + context.getOrLoadDialect(); const std::string input_file = "tensorflow/compiler/mlir/lite/metrics/testdata/strided_slice.mlir"; auto input_file_id = StringAttr::get(&context, input_file); - context.allowUnregisteredDialects(); context.enableMultithreading(); auto module = LoadModule(&context, tensorflow::GetDataDependencyFilepath(input_file)); EXPECT_EQ(module.ok(), true); - PassManager pm(&context, OpPassManager::Nesting::Implicit); + PassManager pm(module.value().get()->getName(), + OpPassManager::Nesting::Implicit); pm.addPass(std::make_unique()); pm.addPass(std::make_unique()); diff --git a/tensorflow/compiler/mlir/lite/python/BUILD b/tensorflow/compiler/mlir/lite/python/BUILD index c68359214b4..51618d4826e 100644 --- a/tensorflow/compiler/mlir/lite/python/BUILD +++ b/tensorflow/compiler/mlir/lite/python/BUILD @@ -2,7 +2,13 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") licenses(["notice"]) -package(default_visibility = [":friends"]) +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + ":friends", + "//tensorflow:__pkg__", + ], +) package_group( name = "friends", @@ -26,7 +32,6 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", "//tensorflow/compiler/mlir/tensorflow:import_model", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", - "//tensorflow/compiler/xla/stream_executor/lib", "//tensorflow/core:core_cpu_base", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -58,7 +63,6 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:import_model", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", - "//tensorflow/compiler/xla/stream_executor/lib", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/lite/toco:model_flags_proto_cc", @@ -90,7 +94,6 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:import_model", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", - "//tensorflow/compiler/xla/stream_executor/lib", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/lite/toco:model_flags_proto_cc", @@ -122,7 +125,6 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_proto_cc", - "//tensorflow/compiler/xla/stream_executor/lib", "//tensorflow/compiler/xla/translate/hlo_to_mhlo:hlo_to_mlir_hlo", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc index e8833159316..fae07ee2644 100644 --- 
a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc @@ -15,11 +15,11 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h" +#include #include #include #include -#include "llvm/ADT/None.h" #include "llvm/Support/ToolOutputFile.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project @@ -34,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" @@ -43,6 +42,7 @@ limitations under the License. #include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/toco_flags.pb.h" #include "tensorflow/lite/toco/types.pb.h" +#include "tensorflow/tsl/platform/statusor.h" namespace tensorflow { Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, @@ -113,11 +113,12 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, pass_config.preserve_assert_op = toco_flags.preserve_assert_op(); pass_config.guarantee_all_funcs_one_use = toco_flags.guarantee_all_funcs_one_use(); + pass_config.enable_stablehlo_conversion = toco_flags.convert_to_stablehlo(); return internal::ConvertMLIRToTFLiteFlatBuffer( model_flags, toco_flags, std::move(module), pass_config, /*saved_model_tags=*/{}, result, - /*session=*/llvm::None); + /*session=*/std::nullopt); } } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.cc index d310867fed7..ef8536ccfb6 100644 --- a/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.cc @@ -15,12 +15,12 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.h" #include +#include #include #include #include "absl/strings/str_join.h" #include "absl/types/span.h" -#include "llvm/ADT/None.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringSet.h" #include "llvm/Support/ToolOutputFile.h" @@ -42,7 +42,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/compiler/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/types.pb.h" @@ -53,6 +52,7 @@ limitations under the License. 
#include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/toco_flags.pb.h" #include "tensorflow/lite/toco/types.pb.h" +#include "tensorflow/tsl/platform/statusor.h" namespace tensorflow { namespace { @@ -160,6 +160,7 @@ Status ConvertJaxToTFLiteFlatBuffer(const std::string& input, pass_config.unfold_large_splat_constant = toco_flags.unfold_large_splat_constant(); pass_config.enable_hlo_to_tf_conversion = true; + pass_config.enable_stablehlo_conversion = toco_flags.convert_to_stablehlo(); mlir::OwningOpRef module; if (model_flags.hlo_file_type() == toco::ModelFlags::HLO_TEXT) { @@ -192,7 +193,7 @@ Status ConvertJaxToTFLiteFlatBuffer(const std::string& input, auto status = internal::ConvertMLIRToTFLiteFlatBuffer( model_flags, toco_flags, std::move(module), pass_config, /*saved_model_tags=*/{}, result, - /*session=*/llvm::None); + /*session=*/std::nullopt); return status; } diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc index eae882586ba..de9e5eadb10 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc @@ -38,7 +38,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" @@ -47,6 +46,7 @@ limitations under the License. #include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/toco_flags.pb.h" #include "tensorflow/lite/toco/types.pb.h" +#include "tensorflow/tsl/platform/statusor.h" namespace tensorflow { @@ -198,6 +198,7 @@ Status ConvertSavedModelToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, pass_config.preserve_assert_op = toco_flags.preserve_assert_op(); pass_config.guarantee_all_funcs_one_use = toco_flags.guarantee_all_funcs_one_use(); + pass_config.enable_stablehlo_conversion = toco_flags.convert_to_stablehlo(); // TODO(b/153507667): Pass the session object when importing logic is removed. auto status = internal::ConvertMLIRToTFLiteFlatBuffer( diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc index 2e10602e7f6..1ddba633947 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h" +#include #include #include #include @@ -33,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" @@ -43,8 +43,9 @@ limitations under the License. 
#include "tensorflow/lite/toco/toco_flags.pb.h" #include "tensorflow/lite/toco/types.pb.h" #include "tensorflow/lite/tools/optimize/reduced_precision_support.h" +#include "tensorflow/tsl/platform/statusor.h" -using stream_executor::port::StatusOr; +using tsl::StatusOr; namespace tensorflow { namespace internal { @@ -234,7 +235,7 @@ Status PopulateQuantizationSpecs( DataType_Name(ConvertIODataTypeToDataType(toco_data_type))); } if (flag.shape().unknown_rank()) { - node_shapes->push_back(llvm::None); + node_shapes->push_back(std::nullopt); } else { node_shapes->push_back(std::vector(flag.shape().dims().begin(), flag.shape().dims().end())); @@ -248,8 +249,8 @@ Status PopulateQuantizationSpecs( node_mins->push_back(min_max.first); node_maxs->push_back(min_max.second); } else { - node_mins->push_back(llvm::None); - node_maxs->push_back(llvm::None); + node_mins->push_back(std::nullopt); + node_maxs->push_back(std::nullopt); } } } @@ -306,9 +307,10 @@ Status PopulateQuantizationSpecs( if (toco_flags.has_default_ranges_max()) { quant_specs->default_ranges.second = toco_flags.default_ranges_max(); } - if (toco_flags.enable_mlir_dynamic_range_quantizer()) { - quant_specs->enable_mlir_dynamic_range_quantizer = true; - } + quant_specs->enable_mlir_dynamic_range_quantizer = + toco_flags.enable_mlir_dynamic_range_quantizer(); + quant_specs->enable_mlir_variable_quantization = + toco_flags.enable_mlir_variable_quantization(); return OkStatus(); } diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h index c9b5b49ca88..37d79b679c0 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h @@ -25,11 +25,11 @@ limitations under the License. 
#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/core/public/session.h" #include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/toco_flags.pb.h" #include "tensorflow/lite/toco/types.pb.h" +#include "tensorflow/tsl/platform/statusor.h" namespace tensorflow { namespace internal { diff --git a/tensorflow/compiler/mlir/lite/quantization/BUILD b/tensorflow/compiler/mlir/lite/quantization/BUILD index 1c5328d3080..33d54e4e449 100644 --- a/tensorflow/compiler/mlir/lite/quantization/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/BUILD @@ -5,8 +5,10 @@ load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ ":friends", + "//tensorflow:__pkg__", ], licenses = ["notice"], ) @@ -158,6 +160,8 @@ tf_native_cc_binary( "tools/tflite_op_coverage_spec_getters_gen.cc", ], deps = [ + "//tensorflow/tsl/platform:logging", + "//tensorflow/tsl/platform:regexp", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//llvm:TableGen", diff --git a/tensorflow/compiler/mlir/lite/quantization/device_target.cc b/tensorflow/compiler/mlir/lite/quantization/device_target.cc index 914c6f5f419..5297e55f599 100644 --- a/tensorflow/compiler/mlir/lite/quantization/device_target.cc +++ b/tensorflow/compiler/mlir/lite/quantization/device_target.cc @@ -55,7 +55,7 @@ DeviceTarget::DeviceTarget(MLIRContext* ctx) : ctx_(ctx) { Optional DeviceTarget::GetKernelSpec( llvm::StringRef kernel, const KernelSpecs::Signature& signature) const { auto kernel_specs_it = specs_.find(kernel); - if (kernel_specs_it == specs_.end()) return llvm::None; + if (kernel_specs_it == specs_.end()) return std::nullopt; return kernel_specs_it->getValue().Find(signature); } diff --git a/tensorflow/compiler/mlir/lite/quantization/device_target.h b/tensorflow/compiler/mlir/lite/quantization/device_target.h index a74a696da80..a49035a24db 100644 --- a/tensorflow/compiler/mlir/lite/quantization/device_target.h +++ b/tensorflow/compiler/mlir/lite/quantization/device_target.h @@ -17,6 +17,7 @@ limitations under the License. 
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_DEVICE_TARGET_H_ #include +#include #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/Hashing.h" @@ -78,7 +79,7 @@ class KernelSpecs { if (spec_it != all_signatures_.end()) { return spec_it->second; } else { - return llvm::None; + return std::nullopt; } } diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/BUILD b/tensorflow/compiler/mlir/lite/quantization/ir/BUILD index 04dd6034cfd..bd718c14d23 100644 --- a/tensorflow/compiler/mlir/lite/quantization/ir/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/ir/BUILD @@ -3,6 +3,7 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:public"], licenses = ["notice"], ) diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.cc b/tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.cc index f9cb74826ec..3bd80ad4a7b 100644 --- a/tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.cc +++ b/tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.cc @@ -40,7 +40,7 @@ void QuantizationForkDialect::initialize() { >(); } -OpFoldResult StorageCastOp::fold(ArrayRef operands) { +OpFoldResult StorageCastOp::fold(FoldAdaptor) { // Matches x -> [scast -> scast] -> y, replacing the second scast with the // value of x if the casts invert each other. auto srcScastOp = getArg().getDefiningOp(); diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/QuantOpsBase.td b/tensorflow/compiler/mlir/lite/quantization/ir/QuantOpsBase.td index e6851daf260..54e730d2b80 100644 --- a/tensorflow/compiler/mlir/lite/quantization/ir/QuantOpsBase.td +++ b/tensorflow/compiler/mlir/lite/quantization/ir/QuantOpsBase.td @@ -27,8 +27,7 @@ include "mlir/IR/OpBase.td" def QuantizationFork_Dialect : Dialect { let name = "quantfork"; let cppNamespace = "::mlir::quantfork"; - - let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; + let useFoldAPI = kEmitFoldAdaptorFolder; } #endif // QUANT_FORK_BASE diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD index 789b56a00ed..20346adde81 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD @@ -2,8 +2,10 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ ":friends", + "//tensorflow:__pkg__", ], licenses = ["notice"], ) diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc index 4195dfdc44a..e21105fc5c4 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include +#include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Debug.h" @@ -56,9 +57,11 @@ TfLiteStatus QuantizeModel( flatbuffers::FlatBufferBuilder* builder, tflite::ErrorReporter* error_reporter, bool verify_numeric, bool whole_model_verify, bool legacy_float_scale, - const StringSet& denylisted_ops, const StringSet& denylisted_nodes) { + const absl::flat_hash_set& denylisted_ops, + const absl::flat_hash_set& denylisted_nodes, + const bool enable_variable_quantization) { // Translate TFLite names to mlir op names. - StringSet denylisted_mlir_op_names; + absl::flat_hash_set denylisted_mlir_op_names; for (const auto& entry : denylisted_ops) { denylisted_mlir_op_names.insert(TfLiteToMlir(entry)); } @@ -87,7 +90,7 @@ TfLiteStatus QuantizeModel( } // Apply quantization passes. - PassManager pm(module->getContext(), OpPassManager::Nesting::Implicit); + PassManager pm((*module)->getName(), OpPassManager::Nesting::Implicit); quant::QuantizationSpecs quant_specs; quant_specs.inference_type = tflite::TflTypeToTfType(inference_type); quant_specs.post_training_quantization = true; @@ -97,6 +100,7 @@ TfLiteStatus QuantizeModel( quant_specs.legacy_float_scale = legacy_float_scale; quant_specs.ops_blocklist = denylisted_mlir_op_names; quant_specs.nodes_blocklist = denylisted_nodes; + quant_specs.enable_mlir_variable_quantization = enable_variable_quantization; llvm::dbgs() << "fully_quantize: " << fully_quantize << ", inference_type: " << quant_specs.inference_type diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h index eba106589ed..243af219da6 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h @@ -27,8 +27,6 @@ limitations under the License. namespace mlir { namespace lite { -using StringSet = absl::flat_hash_set; - // Quantize the `input_model` and write the result to a flatbuffer `builder`. // The `input_type`, `output_type` and `inference_type` can be // float32/qint8/int8/int16. @@ -47,8 +45,9 @@ TfLiteStatus QuantizeModel( flatbuffers::FlatBufferBuilder* builder, tflite::ErrorReporter* error_reporter, bool verify_numeric = false, bool whole_model_verify = false, bool legacy_float_scale = true, - const StringSet& denylisted_ops = {}, - const StringSet& denylisted_nodes = {}); + const absl::flat_hash_set& denylisted_ops = {}, + const absl::flat_hash_set& denylisted_nodes = {}, + bool enable_variable_quantization = false); } // namespace lite } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.cc index b60ba282a49..ce87e5d8f92 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Debug.h" @@ -51,7 +52,7 @@ using llvm::StringRef; // Convert op represented in TFLite builtin_code to its corresponding MLIR // OperationName. 
void TfLiteBuiltinOpToMlir(const BuiltinOperatorSet& tflite_builtin_codes, - StringSet& mlir_op_names) { + absl::flat_hash_set& mlir_op_names) { for (const auto& entry : tflite_builtin_codes) { StringRef tflite_op_name = EnumNameBuiltinOperator(entry); std::string mlir_name = llvm::Twine("tfl.", tflite_op_name.lower()).str(); @@ -77,12 +78,13 @@ std::unique_ptr CreateMutableModelFromFile( TfLiteStatus QuantizeWeights( flatbuffers::FlatBufferBuilder* builder, const tflite::Model* input_model, tflite::ErrorReporter* error_reporter, - const tflite::TensorType& inference_type, const StringSet& denylisted_ops, + const tflite::TensorType& inference_type, + const absl::flat_hash_set& denylisted_ops, const CustomOpMap& custom_op_map, int64_t minimum_elements_for_weights, bool disable_per_channel, bool weight_only_quantization, bool legacy_float_scale) { // Translate TFLite names to mlir op names. - StringSet denylisted_mlir_op_names; + absl::flat_hash_set denylisted_mlir_op_names; for (auto& entry : denylisted_ops) { denylisted_mlir_op_names.insert(TfLiteToMlir(entry)); } @@ -106,7 +108,7 @@ TfLiteStatus QuantizeWeights( serialized_model, &context, UnknownLoc::get(&context)); // Apply quantization passes. - PassManager pm(module->getContext(), OpPassManager::Nesting::Implicit); + PassManager pm((*module)->getName(), OpPassManager::Nesting::Implicit); quant::QuantizationSpecs quant_specs; quant_specs.inference_type = tflite::TflTypeToTfType(inference_type); quant_specs.weight_quantization = true; @@ -215,7 +217,7 @@ TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, tflite::StderrReporter error_reporter; const tflite::TensorType inference_type = tflite::TensorType_INT8; - StringSet mlir_op_denylist; + absl::flat_hash_set mlir_op_denylist; TfLiteBuiltinOpToMlir(op_denylist, mlir_op_denylist); return QuantizeWeights( diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.h b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.h index 7fac2b5ab64..d7cb5ab1fe6 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.h +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.h @@ -44,7 +44,6 @@ struct CustomOpInfo { bool no_side_effect = true; }; -using StringSet = absl::flat_hash_set; using BuiltinOperatorSet = absl::flat_hash_set; // Map from custom op code to custom op quantization information. using CustomOpMap = std::unordered_map; @@ -57,16 +56,15 @@ using CustomOpMap = std::unordered_map; // third_party/tensorflow/lite/tools/optimize/quantize_weights.h. // TODO(b/202468183): Selective quantization + quant debugger support for // dynamic range quantization for verify_numeric and whole_model_verify flags. 
-TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, - const tflite::Model* input_model, - tflite::ErrorReporter* error_reporter, - const tflite::TensorType& inference_type, - const StringSet& denylisted_ops, - const CustomOpMap& custom_op_map, - int64_t minimum_elements_for_weights = 1024, - bool disable_per_channel = false, - bool weight_only_quantization = false, - bool legacy_float_scale = false); +TfLiteStatus QuantizeWeights( + flatbuffers::FlatBufferBuilder* builder, const tflite::Model* input_model, + tflite::ErrorReporter* error_reporter, + const tflite::TensorType& inference_type, + const absl::flat_hash_set& denylisted_ops, + const CustomOpMap& custom_op_map, + int64_t minimum_elements_for_weights = 1024, + bool disable_per_channel = false, bool weight_only_quantization = false, + bool legacy_float_scale = false); // Overloading methods to support old quantizer versions API TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_config.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_config.cc index c89bcac032c..adbed4d6eb0 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_config.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_config.cc @@ -66,7 +66,7 @@ void ParseCustomOpSpecs(absl::string_view node_names, auto node_specification = node_infos[1]; CustomOpInfo new_node_info; switch (update_option) { - case CustomOpUpdateOptions::kINputIndices: { + case CustomOpUpdateOptions::kInputIndices: { std::vector indices = absl::StrSplit(node_specification, '-'); for (auto& cur_index : indices) { diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_config.h b/tensorflow/compiler/mlir/lite/quantization/quantization_config.h index 1cdf17f9c24..a523c9009f3 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_config.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_config.h @@ -42,9 +42,8 @@ struct CustomOpInfo { }; using ::tflite::optimize::ReducedPrecisionSupport; -using StringSet = absl::flat_hash_set; using CustomOpMap = std::unordered_map; -enum CustomOpUpdateOptions { kINputIndices, kWeightOnly, kNoSideEffect }; +enum CustomOpUpdateOptions { kInputIndices, kWeightOnly, kNoSideEffect }; struct QuantizationSpecs { // Which function this node quant specifications belong to. @@ -92,6 +91,12 @@ struct QuantizationSpecs { // quantization. bool disable_infer_tensor_range = false; + // Whether use the unfrozen variable quantization in MLIR. Typically, + // variables are frozen for passing passes, but some variables aren't frozen. + // If it is true, QuantizeVariables pass will be added after the + // PrepareQuantizePass. + bool enable_mlir_variable_quantization = false; + // The node type when the model is exported. Currently this is limited to // DT_FLOAT, DT_HALF, DT_QINT8, and DT_QUINT8. When DT_HALF is used, the // `weight_quantization` flag needs to set to true. When DT_QUINT8 is used, @@ -194,10 +199,10 @@ struct QuantizationSpecs { // Names of ops to block from quantization. Used in QuantizePass. // For dynamic range quantization, ops in blocklist are quantized in weight- // only manner. - StringSet ops_blocklist; + absl::flat_hash_set ops_blocklist; // Names of locations to block from quantization. Used in QuantizePass. - StringSet nodes_blocklist; + absl::flat_hash_set nodes_blocklist; // Map from custom op code to custom op quantization information. 
// For dynamic range quantization, among the custom ops in the graph those diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc index 7241b78fb52..18a812fea8e 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc @@ -50,8 +50,8 @@ namespace mlir { #include "tensorflow/compiler/mlir/lite/quantization/quantization_interface.cc.inc" namespace quant { - namespace { + constexpr double kSmallestHalfRange = kNearZeroTolerance / 2; using QType = quant::QuantizedType; diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h index bf90c38a07f..100e911b8bb 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h @@ -21,7 +21,9 @@ limitations under the License. #include #include +#include #include +#include #include #include "absl/container/flat_hash_set.h" @@ -35,7 +37,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project @@ -74,7 +76,7 @@ constexpr absl::string_view QuantTraitValues[] = {"fully_quantizable", constexpr double kNearZeroTolerance = 1.0e-6; -using QuantParams = mlir::quant::QuantizedType; +using QuantParams = QuantizedType; using QuantSpec = QuantizationSpecs; using SignedInteger = std::pair; // bitwidth and sign using QuantParamsForResults = llvm::SmallVector; @@ -89,7 +91,6 @@ using RequiredSameOperandsAndResultsScaleFunc = std::function; // bool RequiredSameQuantizedAxes() using RequiredSameQuantizedAxesFunc = std::function; -using StringSet = absl::flat_hash_set; using CustomMap = quant::CustomOpMap; // Quantization spec of an op, driving the quantization algorithm. @@ -107,10 +108,10 @@ struct OpQuantSpec { llvm::DenseMap restricted_output_params; // Coefficient operand index and whether supporting per-channel quantization. - // For QAT, this information is carried by the FakeQuant*/QDQ ops, but - // post-training quantization, the quantization parameters need to be inferred - // from the tensor content and op property. A "-1" value indicates the - // operand doesn't support per-channel quantization. + // For QAT, this information is carried by the FakeQuant*/Quantize/Dequantize + // ops, but post-training quantization, the quantization parameters need to be + // inferred from the tensor content and op property. A "-1" value indicates + // the operand doesn't support per-channel quantization. llvm::DenseMap coeff_op_quant_dim; // Indices of quantizable operands. Biases are not included in this field, @@ -118,6 +119,11 @@ struct OpQuantSpec { absl::flat_hash_set quantizable_operands; }; +// A function signature for getting the particular OpQuantSpec for the provided +// op. +using OpQuantSpecGetter = + std::function(Operation*)>; + // Quantization scale spec of an op. 
The information defined in the MLIR // interfaces FixedOutputRangeInterface and SameOperandsAndResultsScale should // be checked first if present. @@ -137,6 +143,11 @@ struct OpQuantScaleSpec { }; }; +// A function signature for getting the particular OpQuantScaleSpec for the +// provided op. +using OpQuantScaleSpecGetter = + std::function(Operation*)>; + // Used in TFL Numeric Verify struct NumericVerifySpec { // Whether to enable numeric verification @@ -162,14 +173,6 @@ struct QuantPassSpec { QuantSpec quant_spec; }; -// A function signature for getting the particular OpQuantSpec for the provided -// op. -typedef std::unique_ptr (*OpQuantSpecGetter)(Operation* op); -// A function signature for getting the particular OpQuantScaleSpec for the -// provided op. -typedef std::unique_ptr (*OpQuantScaleSpecGetter)( - Operation* op); - // Re-calculates scales again in float instead of simply downcasting existing // scales. quant::QuantizedType DownCastScale(quant::QuantizedType type, @@ -190,7 +193,7 @@ inline std::string GetTensorNameFromLoc(Location loc) { return ""; } -template +template struct ConvertStatsToQDQs : public OpRewritePattern { ConvertStatsToQDQs(int num_bits, bool narrow_range, bool is_signed, bool legacy_float_scale, MLIRContext* context) @@ -207,6 +210,7 @@ struct ConvertStatsToQDQs : public OpRewritePattern { SmallVector mins, maxs; if (op.getAxisStats().has_value()) { + // Per axis quantization (or per channel quantization) int stats_num = op.getAxisStats()->getNumElements(); if (stats_num == 0 || stats_num % 2 != 0) return failure(); auto stats = op.getAxisStats()->dyn_cast(); @@ -220,6 +224,13 @@ struct ConvertStatsToQDQs : public OpRewritePattern { // So here we adjust the range to include 0.0. rmin = std::min(rmin, 0.0); rmax = std::max(rmax, 0.0); + if (num_bits == 16) { + // TODO(b/266536261): Since the kernel implementation assumes that + // 16x8 integer quantization is symmetric, this MLIR quantizer + // supports only symmetric quantization. + rmax = std::max(std::abs(rmin), std::abs(rmax)); + rmin = -rmax; + } TensorRangeSanityCheck(op, rmin, rmax); mins.push_back(rmin); maxs.push_back(rmax); @@ -232,6 +243,7 @@ struct ConvertStatsToQDQs : public OpRewritePattern { } } else if (auto stats = op.getLayerStats().dyn_cast()) { + // Per tensor quantization auto statValues = stats.getValues(); double rmin = FloatAttr::getValueAsDouble(statValues[0]); double rmax = FloatAttr::getValueAsDouble(statValues[1]); @@ -240,6 +252,13 @@ struct ConvertStatsToQDQs : public OpRewritePattern { // So here we adjust the range to include 0.0. rmin = std::min(rmin, 0.0); rmax = std::max(rmax, 0.0); + if (num_bits == 16) { + // TODO(b/266536261): Since the kernel implementation assumes that + // 16x8 integer quantization is symmetric, this MLIR quantizer supports + // only symmetric quantization. 
+ rmax = std::max(std::abs(rmin), std::abs(rmax)); + rmin = -rmax; + } TensorRangeSanityCheck(op, rmin, rmax); quant_type = quantfork::fakeQuantAttrsToType(op.getLoc(), num_bits, rmin, rmax, @@ -253,10 +272,11 @@ struct ConvertStatsToQDQs : public OpRewritePattern { rewriter.setInsertionPointAfter(op.getOperation()); Type result_type = quant_type.castFromExpressedType(op.getType()); - auto q = rewriter.create(op.getLoc(), result_type, op.getArg()); + auto q = + rewriter.create(op.getLoc(), result_type, op.getArg()); q->setAttr(kVolatileOpAttrName, rewriter.getUnitAttr()); - auto dq = rewriter.create(op.getLoc(), op.getType(), q); + auto dq = rewriter.create(op.getLoc(), op.getType(), q); op.getResult().replaceAllUsesWith(dq); q.getOperation()->replaceUsesOfWith(dq, op.getArg()); op.erase(); @@ -330,8 +350,8 @@ inline void CreateVerifier(Operation* quantizing_op, // A base rewrite pattern which matches any N-in-M-out operations with // quantization parameters propagated to at least one of its operands. The -// quantization parameters are annotated by the Q/DQ op pairs. Each -// matched pattern are rewritten by its quantized alternatives. +// quantization parameters are annotated by the QuantizeOp/DequantizeOp pairs. +// Each matched pattern are rewritten by its quantized alternatives. // // The concrete pattern, extends from this base pattern, can specify whether it // allows dynamic range quantized operands and results for the operations in the @@ -346,16 +366,17 @@ inline void CreateVerifier(Operation* quantizing_op, // Full integer quantization disallows "DynamicRangeQuantized" operands or // results. Dynamic range quantization allows "DynamicRangeQuantized" operands // and results. -template +template class QuantizationPattern : public RewritePattern { public: - using BaseType = QuantizationPattern; + using BaseType = QuantizationPattern; explicit QuantizationPattern(MLIRContext* context, const QuantPassSpec& quant_params) // Set the score to a large number so it is always preferred. - : RewritePattern(RootOp::getOperationName(), 300, context), + : RewritePattern(RootOpT::getOperationName(), 300, context), quant_params_(quant_params) {} LogicalResult matchAndRewrite(Operation* op, @@ -363,24 +384,25 @@ class QuantizationPattern : public RewritePattern { llvm::SmallVector quantizing_ops; // Collect all the ops to quantize, as the user / producer of the root op. - if (std::is_same::value) { + if constexpr (std::is_same_v) { if (op->getNumResults() != 1) { return failure(); } auto users = op->getResult(0).getUsers(); quantizing_ops.append(users.begin(), users.end()); - } else if (std::is_same::value) { + } else if constexpr (std::is_same_v) { if (op->getNumOperands() != 1) { return failure(); } Value quantize_operand = op->getOperand(0); if (QuantizedType::getQuantizedElementType(quantize_operand.getType())) { - // The input of this Q op has already been quantized, i.e. rescale. + // The input of this QuantizeOp has already been quantized, i.e. + // rescale. return failure(); } DenseFPElementsAttr attr; if (matchPattern(quantize_operand, m_Constant(&attr))) { - // Const->Q pattern will be handled separately. + // Const-> QuantizeOp pattern will be handled separately. 
return failure(); } if (Operation* quantizing_op = quantize_operand.getDefiningOp()) { @@ -395,15 +417,17 @@ class QuantizationPattern : public RewritePattern { bool enable_verify = quant_params_.numeric_verify_spec.verify_numeric; bool enable_whole_model_verify = quant_params_.numeric_verify_spec.whole_model_verify; - StringSet ops_blocklist = quant_params_.quant_spec.ops_blocklist; - StringSet nodes_blocklist = quant_params_.quant_spec.nodes_blocklist; + absl::flat_hash_set ops_blocklist = + quant_params_.quant_spec.ops_blocklist; + absl::flat_hash_set nodes_blocklist = + quant_params_.quant_spec.nodes_blocklist; CustomMap custom_map = quant_params_.quant_spec.custom_map; // Rewrite the floating-point ops to the quantized version, by fusing // preceding dequantize ops and succeding quantize ops. for (Operation* quantizing_op : quantizing_ops) { // If it is requantize op, we shouldn't rewrite this op. - if (llvm::isa(quantizing_op)) { + if (llvm::isa(quantizing_op)) { return failure(); } @@ -416,7 +440,7 @@ class QuantizationPattern : public RewritePattern { } if (IsOpNotQuantizable(quantizing_op) && - !static_cast(this)->IsQuantizableCustomOp( + !static_cast(this)->IsQuantizableCustomOp( quantizing_op, custom_map)) { if (!(enable_verify && enable_whole_model_verify)) { return failure(); @@ -455,7 +479,7 @@ class QuantizationPattern : public RewritePattern { // An op with float inputs and outputs are expected when it's used by a // NumericVerify op. Skip this op. - if (enable_verify && UsedBy(quantizing_op)) { + if (enable_verify && UsedBy(quantizing_op)) { continue; } @@ -471,17 +495,17 @@ class QuantizationPattern : public RewritePattern { } auto ele_type = operand.getType().cast().getElementType(); - if (static_cast(this) + if (static_cast(this) ->AllowDynamicRangeQuantizedOperand(quantizing_op, custom_map)) { - auto dq_op = dyn_cast_or_null(operand.getDefiningOp()); + auto dq_op = dyn_cast_or_null(operand.getDefiningOp()); if (dq_op && inference_type == tensorflow::DT_QINT8 && - !static_cast(this)->IsWeightOnlyOp( + !static_cast(this)->IsWeightOnlyOp( quantizing_op, ops_blocklist, weight_only_quantization, custom_map)) { - // Dynamic range quantization is applied by having Q as an input. - // Only int8 weight is supported for now. + // Dynamic range quantization is applied by having QuantizeOp as an + // input. Only int8 weight is supported for now. inputs.push_back(dq_op.getOperand()); } else { // Otherwise, it's the case where the operand is activations or the @@ -489,11 +513,12 @@ class QuantizationPattern : public RewritePattern { inputs.push_back(operand); } } else { - if (auto dq_op = dyn_cast_or_null(operand.getDefiningOp())) { + if (auto dq_op = + dyn_cast_or_null(operand.getDefiningOp())) { inputs.push_back(dq_op.getOperand()); } else if (!ele_type.isF32()) { // If the operand is an integer tensor, then it doesn't require the - // DQ op in the pattern. + // DequantizeOp in the pattern. inputs.push_back(operand); } else { return failure(); @@ -519,9 +544,10 @@ class QuantizationPattern : public RewritePattern { } Type result_ele_type = result.getType().cast().getElementType(); - // If the user is the Quantize op, it must be the only user. - if (result.hasOneUse() && llvm::isa(*result.user_begin())) { - auto user = llvm::cast(*result.user_begin()); + // If the user is the QuantizeOp, it must be the only user. 
+ if (result.hasOneUse() && + llvm::isa(*result.user_begin())) { + auto user = llvm::cast(*result.user_begin()); outputs_replaced.insert( {user.getResult(), enumerated_result.index()}); output_types.push_back(user.getType()); @@ -530,7 +556,7 @@ class QuantizationPattern : public RewritePattern { // D op in the pattern. outputs_replaced.insert({result, enumerated_result.index()}); output_types.push_back(result.getType()); - } else if (static_cast(this) + } else if (static_cast(this) ->AllowDynamicRangeQuantizedResult(quantizing_op, custom_map)) { outputs_replaced.insert({result, enumerated_result.index()}); @@ -553,7 +579,7 @@ class QuantizationPattern : public RewritePattern { llvm::enumerate(quantizing_op->getRegions())) { Region& target_region = quantized_op->getRegion(indexed_regions.index()); - BlockAndValueMapping mapping; + IRMapping mapping; indexed_regions.value().cloneInto(&target_region, mapping); } } @@ -565,12 +591,12 @@ class QuantizationPattern : public RewritePattern { // To verify the numericals, the original floating-point ops are // preserved in the graph. The result of these floating-point ops are sent // to a numeric verifier op as the reference. - if (enable_verify && !std::is_same()) { + if (enable_verify && !std::is_same_v) { // For constant operands, the floating-point constant is duplicated in // case it is quantized. for (int i = 0, e = quantized_op->getNumOperands(); i < e; ++i) { auto def = quantized_op->getOperand(i).getDefiningOp(); - if (auto q = llvm::dyn_cast_or_null(def)) { + if (auto q = llvm::dyn_cast_or_null(def)) { DenseFPElementsAttr attr; if (!matchPattern(q.getOperand(), m_Constant(&attr))) { continue; @@ -589,8 +615,8 @@ class QuantizationPattern : public RewritePattern { .isa()) { continue; } - CreateVerifier(quantizing_op, quantized_op, rewriter, i, - quant_params_); + CreateVerifier(quantizing_op, quantized_op, rewriter, i, + quant_params_); if (enable_whole_model_verify) { RewireFloatModelBackbone(quantized_op, quantizing_op); @@ -623,11 +649,11 @@ class QuantizationPattern : public RewritePattern { if (IsOpNotQuantizable(float_op)) { // For not quantizable ops, search for dequantize attached to the // quantized op of the output. - if (Operation* quantize_op = dyn_cast_or_null( + if (Operation* quantize_op = dyn_cast_or_null( *quantized_op->getResult(i).getUsers().begin())) { result = quantize_op->getResult(0); } else { - quantize_op->emitError() + quantized_op->emitError() << "Output[" << i << "] is expected to have only one user [QUANTIZE]"; return; @@ -638,12 +664,12 @@ class QuantizationPattern : public RewritePattern { for (auto user : result.getUsers()) { // Skip the Requantize op and set the user to the following dequantize // op. This happens when the quantizer tries to match the scale conflict - // with Q - Q(requant) - DQ op triples. The correct float op should be - // the user of the last DQ op. - if (llvm::isa(user)) { + // with QuantizeOp - QuantizeOp(requant) - DequantizeOp triples. The + // correct float op should be the user of the last DequantizeOp. + if (llvm::isa(user)) { user = *user->getResult(0).getUsers().begin(); } - if (auto dequantize = llvm::dyn_cast(user)) { + if (auto dequantize = llvm::dyn_cast(user)) { // Replace all uses, except not quantizable ops that are being used in // the float backbone. 
dequantize.getResult().replaceUsesWithIf( @@ -674,15 +700,15 @@ Type ConvertSignedQuantizedToUnsigned(Type signed_tensor_type, Location loc); // Converts quantize ops with unsigned quantized types to these with signed // quantized types and preserves the scales. -template -struct ConvertUnsignedToSigned : public OpRewritePattern { - using BaseType = ConvertUnsignedToSigned; +template +struct ConvertUnsignedToSigned : public OpRewritePattern { + using BaseType = ConvertUnsignedToSigned; using QType = quant::QuantizedType; explicit ConvertUnsignedToSigned(MLIRContext* context) - : OpRewritePattern(context, 1) {} + : OpRewritePattern(context, 1) {} - LogicalResult matchAndRewrite(Q op, + LogicalResult matchAndRewrite(QuantizeOpT op, PatternRewriter& rewriter) const override { Type output_type = op.getResult().getType(); auto qtype = QType::getQuantizedElementType(output_type); @@ -728,18 +754,18 @@ struct ConvertUnsignedToSigned : public OpRewritePattern { if (!new_qtype) return failure(); Type new_output_type = new_qtype.castFromExpressedType( QType::castToExpressedType(output_type)); - rewriter.replaceOpWithNewOp(op, new_output_type, op.getArg()); + rewriter.replaceOpWithNewOp(op, new_output_type, op.getArg()); return success(); } }; // Fold Extra Requantize ops if the preceding ops has free scale requirement. -template -struct FoldTrivalRequantizeOp : public OpRewritePattern { +template +struct FoldTrivalRequantizeOp : public OpRewritePattern { explicit FoldTrivalRequantizeOp(MLIRContext* context) - : OpRewritePattern(context, 1) {} + : OpRewritePattern(context, 1) {} - LogicalResult matchAndRewrite(RQ op, + LogicalResult matchAndRewrite(RequantizeOpT op, PatternRewriter& rewriter) const override { Value pre_quantized = op->getOperand(0); auto pre_quantized_type = diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD b/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD index 4cf3e5a0346..9652196367f 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD @@ -3,6 +3,7 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ ":friends", ], diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/BUILD b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/BUILD index dc36508a866..03332e19f6a 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tests/BUILD @@ -1,7 +1,10 @@ load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") -package(licenses = ["notice"]) +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) glob_lit_tests( data = [":test_utilities"], diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc index b6d4d8f5633..6afc81e8ce9 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc +++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc @@ -93,23 +93,23 @@ struct InsertQuantOpsAfterTFFakeQuantOp LogicalResult matchAndRewrite(TFFakeQuantOp tf_op, PatternRewriter &rewriter) const override { // 
We don't want to insert quantize/dequantize if the quantize op exists. - auto res = tf_op.outputs(); + auto res = tf_op.getOutputs(); if (!res.hasOneUse() || isa(*res.user_begin())) return failure(); // Extract the min/max constant values from the operands. We also consider // a special case that there are tf.Identity ops between the min/max // constants and the tf.FakeQuantWithMinMaxVarsOp. - Value min = tf_op.min(), max = tf_op.max(); + Value min = tf_op.getMin(), max = tf_op.getMax(); DenseFPElementsAttr min_value, max_value; if (auto id1 = dyn_cast_or_null(min.getDefiningOp())) { - id1.replaceAllUsesWith(id1.input()); - min = tf_op.min(); + id1.replaceAllUsesWith(id1.getInput()); + min = tf_op.getMin(); rewriter.eraseOp(id1); } if (auto id2 = dyn_cast_or_null(max.getDefiningOp())) { - id2.replaceAllUsesWith(id2.input()); - max = tf_op.max(); + id2.replaceAllUsesWith(id2.getInput()); + max = tf_op.getMax(); rewriter.eraseOp(id2); } if (!matchPattern(min, m_Constant(&min_value))) return failure(); @@ -124,8 +124,8 @@ struct InsertQuantOpsAfterTFFakeQuantOp // Use the min/max from the operands and the num_bits and narrow_range // attribute to create the quantization parameter for the new quantize op. rewriter.setInsertionPointAfter(tf_op.getOperation()); - IntegerAttr num_bits = rewriter.getI64IntegerAttr(tf_op.num_bits()); - BoolAttr narrow_range = rewriter.getBoolAttr(tf_op.narrow_range()); + IntegerAttr num_bits = rewriter.getI64IntegerAttr(tf_op.getNumBits()); + BoolAttr narrow_range = rewriter.getBoolAttr(tf_op.getNarrowRange()); Type res_type = tf_op.getType(); TypeAttr qtype = quant::GetQuantizedTypeAttr( rewriter, res_type, min_value, max_value, quant_dim, num_bits, @@ -135,7 +135,7 @@ struct InsertQuantOpsAfterTFFakeQuantOp // Finally, use the quantization parameter to create the quantize and // dequantize ops, and insert them between the tf.FakeQuantWithMinMaxVarsOp // and its users. - Value value = tf_op.outputs(); + Value value = tf_op.getOutputs(); auto quantize = rewriter.create( tf_op.getLoc(), qtype.getValue(), value); auto dequantize = rewriter.create( diff --git a/tensorflow/compiler/mlir/lite/quantization/tests/BUILD b/tensorflow/compiler/mlir/lite/quantization/tests/BUILD index dc36508a866..03332e19f6a 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tests/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/tests/BUILD @@ -1,7 +1,10 @@ load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") -package(licenses = ["notice"]) +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) glob_lit_tests( data = [":test_utilities"], diff --git a/tensorflow/compiler/mlir/lite/quantization/tools/tflite_op_coverage_spec_getters_gen.cc b/tensorflow/compiler/mlir/lite/quantization/tools/tflite_op_coverage_spec_getters_gen.cc index ce0cded8bb5..1cbf9cb71af 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tools/tflite_op_coverage_spec_getters_gen.cc +++ b/tensorflow/compiler/mlir/lite/quantization/tools/tflite_op_coverage_spec_getters_gen.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include -#include // NOLINT #include #include "absl/strings/match.h" @@ -24,6 +23,8 @@ limitations under the License. 
#include "llvm/TableGen/Main.h" #include "llvm/TableGen/Record.h" #include "mlir/TableGen/Operator.h" // from @llvm-project +#include "tensorflow/tsl/platform/logging.h" +#include "tensorflow/tsl/platform/regexp.h" using llvm::LessRecord; using llvm::raw_ostream; @@ -143,7 +144,8 @@ bool CheckTypeConstraints(llvm::Init *input_value, void GenerateStaticQuantOp(std::vector &defs, std::vector &result, - InputDataType act_type, bool per_axis) { + InputDataType act_type, const bool per_axis, + const bool is_toco) { std::list required_types = { GetTypeToStringRepresentation().at("F32")}; @@ -169,8 +171,8 @@ void GenerateStaticQuantOp(std::vector &defs, // Dimension equals to -1 means per-channel quantization is not supported for // the op. Therefore check whether the return value is positive integer as // well. - std::regex per_channel_support_regex( - "(.*)(int GetQuantizationDimIndex\\(\\) \\{ return (\\d*); \\})(.*)"); + static const LazyRE2 per_channel_support_regex = { + "int GetQuantizationDimIndex\\(\\) \\{ return (\\d*); \\}"}; for (const auto *def : defs) { Operator op(def); @@ -188,11 +190,22 @@ void GenerateStaticQuantOp(std::vector &defs, per_axis)) { std::string op_name = op.getCppClassName().str(); + // TODO(b/197195711): Please add the additional operations for 16x8 MLIR + // quantizer. This code is temporary until 16x8 is fully supported in MLIR + // quantizer. + if (act_type == InputDataType::INT16) { + if (absl::StrContains(op_name, "LSTMOp") && is_toco) { + continue; + } else if (!absl::StrContains(op_name, "LSTMOp") && !is_toco) { + continue; + } + } + if (per_axis) { std::string op_extra_declaration = op.getExtraClassDeclaration().str(); - bool per_axis_support = std::regex_match( + bool per_axis_support = RE2::PartialMatch( absl::StrReplaceAll(op_extra_declaration, {{"\n", " "}}), - per_channel_support_regex); + *per_channel_support_regex); if (per_axis_support) result.emplace_back(op_name); } else { result.emplace_back(op_name); @@ -209,7 +222,8 @@ void EmitStaticInt8PerAxisQuantOp(std::vector &defs, os.indent(4) << "new std::set({\n"; std::vector result; - GenerateStaticQuantOp(defs, result, InputDataType::INT8, true); + GenerateStaticQuantOp(defs, result, InputDataType::INT8, /*per_axis=*/true, + /*is_toco=*/false); for (const auto &op_name : result) { os.indent(6) << "\"" << op_name << "\",\n"; @@ -228,7 +242,8 @@ void EmitStaticInt8PerTensorQuantOp(std::vector &defs, os.indent(4) << "new std::set({\n"; std::vector result; - GenerateStaticQuantOp(defs, result, InputDataType::INT8, false); + GenerateStaticQuantOp(defs, result, InputDataType::INT8, /*per_axis=*/false, + /*is_toco=*/false); for (const auto &op_name : result) { os.indent(6) << "\"" << op_name << "\",\n"; @@ -247,7 +262,8 @@ void EmitStaticUInt8PerAxisQuantOp(std::vector &defs, os.indent(4) << "new std::set({\n"; std::vector result; - GenerateStaticQuantOp(defs, result, InputDataType::UINT8, true); + GenerateStaticQuantOp(defs, result, InputDataType::UINT8, /*per_axis=*/true, + /*is_toco=*/false); for (const auto &op_name : result) { os.indent(6) << "\"" << op_name << "\",\n"; @@ -266,7 +282,8 @@ void EmitStaticUInt8PerTensorQuantOp(std::vector &defs, os.indent(4) << "new std::set({\n"; std::vector result; - GenerateStaticQuantOp(defs, result, InputDataType::UINT8, false); + GenerateStaticQuantOp(defs, result, InputDataType::UINT8, /*per_axis=*/false, + /*is_toco=*/false); for (const auto &op_name : result) { os.indent(6) << "\"" << op_name << "\",\n"; @@ -298,7 +315,31 @@ void 
EmitStaticQuantWithInt16ActOp(std::vector &defs, os.indent(4) << "new std::set({\n"; std::vector result; - GenerateStaticQuantOp(defs, result, InputDataType::INT16, false); + GenerateStaticQuantOp(defs, result, InputDataType::INT16, /*per_axis=*/false, + /*is_toco=*/false); + + for (const auto &op_name : result) { + os.indent(6) << "\"" << op_name << "\",\n"; + } + + os.indent(4) << "});"; + os.indent(2) << "return *result;\n"; + os.indent(0) << "}\n"; +} + +void EmitStaticQuantWithInt16ActTocoOp(std::vector &defs, + raw_ostream *ostream) { + raw_ostream &os = *ostream; + llvm::sort(defs, LessRecord()); + + os.indent(0) << "const std::set " + "&ExportStaticInt8WithInt16ActTocoSpec() {\n"; + os.indent(2) << "static const std::set * result =\n"; + os.indent(4) << "new std::set({\n"; + + std::vector result; + GenerateStaticQuantOp(defs, result, InputDataType::INT16, /*per_axis=*/false, + /*is_toco=*/true); for (const auto &op_name : result) { os.indent(6) << "\"" << op_name << "\",\n"; @@ -315,6 +356,7 @@ static bool TFLiteOpCoverageSpecWritersMain(raw_ostream &os, EmitStaticQuantOp(op_defs, &os); EmitDynamicRangeOp(op_defs, &os); EmitStaticQuantWithInt16ActOp(op_defs, &os); + EmitStaticQuantWithInt16ActTocoOp(op_defs, &os); EmitSparseOp(op_defs, &os); return false; } diff --git a/tensorflow/compiler/mlir/lite/sparsity/BUILD b/tensorflow/compiler/mlir/lite/sparsity/BUILD index f32f3669b77..4f2e681a986 100644 --- a/tensorflow/compiler/mlir/lite/sparsity/BUILD +++ b/tensorflow/compiler/mlir/lite/sparsity/BUILD @@ -1,8 +1,11 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") +load("//tensorflow:tensorflow.bzl", "tf_cc_test") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ ":friends", + "//tensorflow:__pkg__", ], licenses = ["notice"], ) @@ -31,11 +34,32 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/lite:framework", "//tensorflow/lite/core/api", + "//tensorflow/lite/core/c:private_c_api_types", "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/tools/optimize:reduced_precision_support", "@com_google_absl//absl/strings", + "@flatbuffers", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", ], ) + +tf_cc_test( + name = "sparsify_model_test", + srcs = ["sparsify_model_test.cc"], + data = [ + "//tensorflow/lite:testdata/sparse_tensor.bin", + ], + deps = [ + ":sparsify_model", + "//tensorflow/lite/core:model_builder", + "//tensorflow/lite/core/api:error_reporter", + "//tensorflow/lite/core/c:private_c_api_types", + "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/tools/optimize:reduced_precision_support", + "@com_google_googletest//:gtest_main", + "@flatbuffers", + ], +) diff --git a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc index 67ba8907d4c..a9614c0e62c 100644 --- a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc +++ b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc @@ -32,6 +32,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/lite/tools/optimize/reduced_precision_support.h" namespace mlir { namespace lite { @@ -60,7 +61,7 @@ TfLiteStatus SparsifyModel(const tflite::ModelT& input_model, return kTfLiteError; } - PassManager pm(module->getContext(), OpPassManager::Nesting::Implicit); + PassManager pm((*module)->getName(), OpPassManager::Nesting::Implicit); pm.addPass(TFL::CreateDenseToSparsePass()); if (failed(pm.run(module.get()))) { @@ -75,6 +76,18 @@ TfLiteStatus SparsifyModel(const tflite::ModelT& input_model, options.toco_flags.set_force_select_tf_ops(false); options.toco_flags.set_enable_select_tf_ops(true); options.toco_flags.set_allow_custom_ops(true); + + // Copy metadata for Reduced Precision Support from input model if it exists + for (const auto& metadata : input_model.metadata) { + if (metadata->name != tflite::optimize::kTfLiteReducedPrecisionKey) { + continue; + } + + const auto& data = input_model.buffers[metadata->buffer]->data; + options.metadata[metadata->name] = std::string(data.begin(), data.end()); + break; + } + if (!tflite::MlirToFlatBufferTranslateFunction(module.get(), options, &result)) { error_reporter->Report("Failed to export MLIR to flatbuffer."); diff --git a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.h b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.h index 0689a7031f9..53deff6d990 100644 --- a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.h +++ b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.h @@ -15,11 +15,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_SPARSITY_SPARSIFY_MODEL_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_SPARSITY_SPARSIFY_MODEL_H_ -#include -#include - +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers #include "tensorflow/lite/core/api/error_reporter.h" -#include "tensorflow/lite/model.h" +#include "tensorflow/lite/core/c/c_api_types.h" #include "tensorflow/lite/schema/schema_generated.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model_test.cc b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model_test.cc new file mode 100644 index 00000000000..861a02be9ca --- /dev/null +++ b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model_test.cc @@ -0,0 +1,87 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/lite/sparsity/sparsify_model.h" + +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/lite/core/c/c_api_types.h" +#include "tensorflow/lite/core/model_builder.h" +#include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/tools/optimize/reduced_precision_support.h" + +namespace mlir { +namespace lite { +namespace { + +class NoopErrorReporter : public ::tflite::ErrorReporter { + public: + int Report(const char* format, std::va_list args) override { return 0; } +}; + +TEST(SparsifyModelTest, MetadataIsAddedToOutputModel) { + std::string expected_key = tflite::optimize::kTfLiteReducedPrecisionKey; + std::string expected_value = "test_data"; + + // Load input model + auto input_fbm = tflite::FlatBufferModel::BuildFromFile( + "tensorflow/lite/testdata/sparse_tensor.bin"); + tflite::ModelT input_model; + input_fbm->GetModel()->UnPackTo(&input_model); + + // Populate input metadata + auto model_metadata_buffer = std::make_unique(); + model_metadata_buffer->data = + std::vector(expected_value.begin(), expected_value.end()); + input_model.buffers.push_back(std::move(model_metadata_buffer)); + auto metadata_t = std::make_unique(); + metadata_t->name = tflite::optimize::kTfLiteReducedPrecisionKey; + metadata_t->buffer = input_model.buffers.size() - 1; + input_model.metadata.push_back(std::move(metadata_t)); + + // Sparsify and create output model + flatbuffers::FlatBufferBuilder output_builder; + NoopErrorReporter reporter; + ASSERT_EQ(SparsifyModel(input_model, &output_builder, &reporter), kTfLiteOk); + auto output_fbm = tflite::FlatBufferModel::BuildFromBuffer( + reinterpret_cast(output_builder.GetCurrentBufferPointer()), + output_builder.GetSize()); + tflite::ModelT output_model; + output_fbm->GetModel()->UnPackTo(&output_model); + + // Extract output metadata + std::map output_metadata; + for (const auto& metadata : output_model.metadata) { + const auto& data = output_model.buffers[metadata->buffer]->data; + output_metadata[metadata->name] = std::string(data.begin(), data.end()); + } + + EXPECT_THAT(output_metadata, + testing::Contains(testing::Pair(expected_key, expected_value))); +} + +} // namespace +} // namespace lite +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/BUILD index 3cba4ad3123..bb4eaaef136 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/BUILD @@ -1,9 +1,8 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") -load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ "//visibility:public", ], @@ -22,7 +21,7 @@ cc_library( "-Ithird_party", ], deps = [ - ":mhlo_util", + ":stablehlo_util", "//tensorflow/compiler/mlir/tensorflow", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", @@ -32,85 +31,43 @@ cc_library( alwayslink = 1, ) -gentbl_cc_library( - name = "mhlo_tfl_legalize_patterns_inc_gen", - compatible_with = get_compatible_with_cloud(), - includes = [ - 
"//tensorflow/compiler/xla/mlir_hlo/include/", - ], - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_mhlo_tfl_legalize_patterns.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "transforms/mhlo_tfl_legalize_patterns.td", - deps = [ - "//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", - "//tensorflow/compiler/xla/mlir_hlo:hlo_ops_td_files", - "@llvm-project//mlir:ArithOpsTdFiles", - "@llvm-project//mlir:FuncTdFiles", - ], -) - cc_library( - name = "mhlo_tfl", + name = "stablehlo_tfl", srcs = [ - "transforms/mhlo_tfl_pass.cc", + "transforms/stablehlo_tfl_pass.cc", ], hdrs = [ - "transforms/mhlo_tfl_pass.h", + "transforms/stablehlo_tfl_pass.h", ], copts = [ "-Ithird_party", ], deps = [ "//tensorflow/compiler/mlir/lite:tensorflow_lite", - "//tensorflow/compiler/xla/mlir_hlo", - "//tensorflow/compiler/xla/mlir_hlo:hlo_dialect_registration", "@flatbuffers", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", - "@stablehlo//:chlo_ops", - "@stablehlo//:register", + "@stablehlo//:stablehlo_ops", ], alwayslink = 1, ) cc_library( - name = "mhlo_util", + name = "stablehlo_util", srcs = [ - "transforms/mhlo_util.cc", + "transforms/stablehlo_util.cc", ], hdrs = [ - "transforms/mhlo_util.h", + "transforms/stablehlo_util.h", ], copts = [ "-Ithird_party", ], deps = [ - "//tensorflow/compiler/mlir/lite:tensorflow_lite", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib", - "//tensorflow/compiler/mlir/xla:xla_legalize_tf", - "//tensorflow/compiler/xla/mlir_hlo", - "//tensorflow/compiler/xla/mlir_hlo:hlo_dialect_registration", - "//tensorflow/compiler/xla/mlir_hlo:mhlo_passes", "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:ShapeDialect", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TransformUtils", - "@stablehlo//:chlo_ops", - "@stablehlo//:register", ], alwayslink = 1, ) @@ -127,7 +84,7 @@ cc_library( "-Ithird_party", ], deps = [ - ":mhlo_util", + ":stablehlo_util", "//tensorflow/compiler/mlir/tensorflow", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", @@ -148,54 +105,31 @@ cc_library( "-Ithird_party", ], deps = [ - ":mhlo_util", + ":stablehlo_util", "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/xla/mlir_hlo", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Transforms", + "@stablehlo//:stablehlo_ops", ], alwayslink = 1, ) cc_library( - name = "tf_mhlo", - srcs = [ - "transforms/tf_mhlo_pass.cc", - ], - hdrs = [ - "transforms/tf_mhlo_pass.h", - ], - copts = [ - "-Ithird_party", - ], - deps = [ - ":mhlo_util", - "//tensorflow/compiler/mlir/tensorflow", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - ], - alwayslink = 1, -) - -cc_library( - name = "tf_mhlo_tfl", + name = "tf_stablehlo", srcs = [ - "transforms/tf_mhlo_tfl_pass.cc", + "transforms/tf_stablehlo_pass.cc", ], hdrs = [ - "transforms/tf_mhlo_tfl_pass.h", + "transforms/tf_stablehlo_pass.h", ], copts = [ "-Ithird_party", ], deps = [ - ":mhlo_tfl_legalize_patterns_inc_gen", - ":mhlo_util", - "//tensorflow/compiler/mlir/lite:tensorflow_lite", + ":stablehlo_util", 
"//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib", "//tensorflow/compiler/mlir/xla:xla_legalize_tf", @@ -208,45 +142,8 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:ShapeDialect", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TransformUtils", - "@stablehlo//:chlo_ops", - "@stablehlo//:register", - ], - alwayslink = 1, -) - -cc_library( - name = "tf_poly", - srcs = [ - "transforms/tf_poly_pass.cc", - ], - hdrs = [ - "transforms/tf_poly_pass.h", - ], - copts = [ - "-Ithird_party", - ], - deps = [ - ":mhlo_tfl_legalize_patterns_inc_gen", - ":mhlo_util", - ":tf_mhlo", - "//tensorflow/compiler/mlir/lite:tensorflow_lite", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib", - "//tensorflow/compiler/mlir/xla:xla_legalize_tf", - "//tensorflow/compiler/xla/mlir_hlo", - "//tensorflow/compiler/xla/mlir_hlo:hlo_dialect_registration", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:ShapeDialect", - "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", "@stablehlo//:chlo_ops", "@stablehlo//:register", ], @@ -254,12 +151,12 @@ cc_library( ) cc_library( - name = "tfl_mhlo", + name = "tfl_stablehlo", srcs = [ - "transforms/tfl_mhlo_pass.cc", + "transforms/tfl_stablehlo_pass.cc", ], hdrs = [ - "transforms/tfl_mhlo_pass.h", + "transforms/tfl_stablehlo_pass.h", ], copts = [ "-Ithird_party", @@ -267,16 +164,14 @@ cc_library( deps = [ "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", - "//tensorflow/compiler/xla/mlir_hlo", - "//tensorflow/compiler/xla/mlir_hlo:hlo_dialect_registration", "@flatbuffers", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", - "@stablehlo//:chlo_ops", "@stablehlo//:register", + "@stablehlo//:stablehlo_ops", ], alwayslink = 1, ) @@ -294,13 +189,18 @@ cc_library( ], deps = [ ":drop_savedmodel_semantics", + ":fold_broadcast_pass", + ":fuse_convolution_pass", + ":optimize", ":rename_entrypoint_to_main", ":smuggle_disallowed_ops", - ":tf_mhlo", + ":tf_stablehlo", + ":unfuse_batch_norm_pass", "//tensorflow/compiler/mlir/quantization/tensorflow:passes", "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow:tf_saved_model_passes", "//tensorflow/compiler/mlir/xla:tf_xla_passes", + "//tensorflow/compiler/xla/mlir_hlo:mhlo_passes", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Transforms", ], @@ -318,7 +218,8 @@ cc_library( "-Ithird_party", ], deps = [ - ":mhlo_util", + ":stablehlo_util", + "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Transforms", ], @@ -336,13 +237,106 @@ cc_library( "-Ithird_party", ], deps = [ - ":mhlo_util", + ":stablehlo_util", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "unfuse_batch_norm_pass", + srcs = [ + "transforms/unfuse_batch_norm_pass.cc", + ], + hdrs = [ + "transforms/passes.h", + ], + copts = [ + 
"-Ithird_party", + ], + deps = [ + "//tensorflow/compiler/xla/mlir_hlo", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:Transforms", + ], + alwayslink = 1, +) + +cc_library( + name = "fold_broadcast_pass", + srcs = [ + "transforms/fold_broadcast_pass.cc", + ], + hdrs = [ + "transforms/passes.h", + ], + copts = [ + "-Ithird_party", + ], + deps = [ + "//tensorflow/compiler/xla/mlir_hlo", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:Transforms", + ], + alwayslink = 1, +) + +cc_library( + name = "fuse_convolution_pass", + srcs = [ + "transforms/fuse_convolution_pass.cc", + ], + hdrs = [ + "transforms/passes.h", + ], + copts = [ + "-Ithird_party", + ], + deps = [ + "//tensorflow/compiler/mlir/lite:validators", + "//tensorflow/compiler/xla/mlir_hlo", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Transforms", ], + alwayslink = 1, +) + +cc_library( + name = "optimize", + srcs = [ + "transforms/optimize.cc", + ], + hdrs = [ + "transforms/passes.h", + ], + copts = [ + "-Ithird_party", + ], + deps = [ + "//tensorflow/compiler/xla/mlir_hlo", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Transforms", + ], + alwayslink = 1, ) tf_cc_binary( @@ -353,15 +347,14 @@ tf_cc_binary( copts = ["-O3"], deps = [ ":check_accepted_ops_pass", - ":mhlo_tfl", ":op_stat_pass", + ":stablehlo_tfl", ":transforms", "//tensorflow/cc/saved_model:loader", "//tensorflow/compiler/mlir:init_mlir", "//tensorflow/compiler/mlir:passes", "//tensorflow/compiler/mlir/lite:flatbuffer_export", "//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer", - "//tensorflow/compiler/mlir/quantization/tensorflow:quantize_passes", "//tensorflow/compiler/mlir/quantization/tensorflow:quantize_preprocess", "//tensorflow/compiler/mlir/quantization/tensorflow:tf_quant_ops", "//tensorflow/compiler/mlir/tensorflow", @@ -369,7 +362,7 @@ tf_cc_binary( "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass", "//tensorflow/compiler/mlir/xla:legalize_tf", - "//tensorflow/compiler/mlir/xla:xla_passes", + "//tensorflow/compiler/xla/mlir/framework/transforms:passes", "//tensorflow/compiler/xla/mlir_hlo:all_passes", "//tensorflow/compiler/xla/mlir_hlo:hlo_dialect_registration", "//tensorflow/core:lib", @@ -387,15 +380,17 @@ tf_cc_binary( ) tf_cc_binary( - name = "tf-mhlo-tfl-opt", + name = "odml-to-stablehlo-opt", testonly = True, tags = ["hostonly"], deps = [ - ":mhlo_tfl", - ":tf_mhlo", - ":tf_mhlo_tfl", - ":tf_poly", - ":tfl_mhlo", + ":fold_broadcast_pass", + ":fuse_convolution_pass", + ":optimize", + ":stablehlo_tfl", + ":tf_stablehlo", + ":tfl_stablehlo", + ":unfuse_batch_norm_pass", "//tensorflow/compiler/mlir:passes", "//tensorflow/compiler/mlir:tf_mlir_opt_main", ], diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc index 58a6de36c26..d1d8e605566 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc @@ 
-16,6 +16,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -50,8 +51,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/check_accepted_ops_pass.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/mhlo_tfl_pass.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_tfl_pass.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.h" #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.h" @@ -61,10 +62,10 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.h" #include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" -#include "tensorflow/compiler/mlir/xla/transforms/xla_passes.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/transforms/passes.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir/framework/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/register.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/passes.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/statusor.h" @@ -133,7 +134,7 @@ opt skip_resize( // NOLINTNEXTLINE opt smuggle_disallowed_ops( "smuggle-disallowed-ops", - llvm::cl::desc("Smuggle disallowed ops via mhlo.custom_calls."), + llvm::cl::desc("Smuggle disallowed ops via stablehlo.custom_calls."), llvm::cl::Optional, llvm::cl::init(false)); // NOLINTNEXTLINE @@ -180,9 +181,9 @@ tensorflow::StatusOr> ImportSavedModelOrMLIR( tensorflow::Status ConvertStableHLOToFlatbuffer(mlir::ModuleOp module, std::string* flatbuffer_str) { // Convert StableHLO MLIR to TFLite Custom Op MLIR - mlir::PassManager mhlo_tfl_pm(module->getContext()); - mhlo_tfl_pm.addNestedPass(TFL::mhlo::CreateMhloToTflPass()); - if (failed(mhlo_tfl_pm.run(module))) { + mlir::PassManager pm(module->getContext()); + pm.addNestedPass(CreateStablehloToTflPass()); + if (failed(pm.run(module))) { return tensorflow::errors::Aborted("HLO to TFL passes failed."); } @@ -271,7 +272,7 @@ tensorflow::Status ConvertTFToStableHLO( mlir::odml::AddStablehloOptimizationPasses(pm); if (failed(pm.run(tf_module))) { - return tensorflow::errors::Aborted("Lowering to Compute IR failed."); + return tensorflow::errors::Aborted("Lowering to StableHLO failed."); } return ::tensorflow::OkStatus(); @@ -314,7 +315,7 @@ tensorflow::Status RunConverter(const PassPipelineCLParser& pass_pipeline) { elide_large_elements_attrs)); } - llvm::Optional session = llvm::None; + llvm::Optional session = std::nullopt; if (bundle) session = bundle->GetSession(); // NOMUTANTS--it should pass. 
if (freeze_tf_graph) { @@ -332,10 +333,23 @@ tensorflow::Status RunConverter(const PassPipelineCLParser& pass_pipeline) { } auto conversion_status = ConvertTFToStableHLO(*module, pass_pipeline); - auto export_path = conversion_status.ok() - ? output_path - : absl::StrCat(verbose_dir, "/debug_mhlo.mlir"); - return ExportModule(*module, export_path, elide_large_elements_attrs); + auto output_export_status = + ExportModule(*module, output_path, elide_large_elements_attrs); + if (!conversion_status.ok()) { + LOG(ERROR) << "TF to StableHLO conversion failed: " + << conversion_status.error_message(); + + auto debug_export_status = ExportModule( + *module, absl::StrCat(verbose_dir, "/debug_stablehlo.mlir"), + elide_large_elements_attrs); + if (!debug_export_status.ok()) { + LOG(ERROR) << "Failed to export debug_stablehlo.mlir: " + << debug_export_status.error_message(); + } + + return conversion_status; + } + return output_export_status; } // All MLIR and TF passes are registered here, similar to mlirOptMain. @@ -355,9 +369,8 @@ void initAllPasses() { mlir::lmhlo::registerAllLmhloPasses(); // These are in compiler/mlir/xla and not part of the above MHLO passes. mlir::mhlo::registerTfXlaPasses(); - mlir::mhlo::registerXlaPasses(); + mlir::mhlo::registerXlaFrameworkPasses(); mlir::mhlo::registerLegalizeTFPass(); - mlir::mhlo::registerLegalizeTFControlFlowPass(); mlir::mhlo::registerLegalizeTfTypesPassPass(); tensorflow::RegisterConvertMlirToXlaHloPipelineWithDefaults(); tensorflow::RegisterGraphOptimizationPasses(); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/tests/BUILD index 140bb674f4c..7002dd57dda 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/BUILD @@ -2,6 +2,7 @@ load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ "//visibility:public", ], @@ -11,6 +12,9 @@ package( glob_lit_tests( data = [":test_utilities"], driver = "//tensorflow/compiler/mlir/lite/stablehlo:run_lit.sh", + size_override = { + "legalize-skip-quantization-ops.mlir": "medium", + }, test_file_exts = [ "mlir", "cc", @@ -24,8 +28,8 @@ filegroup( data = [ "//tensorflow/compiler/mlir/lite:flatbuffer_translate", "//tensorflow/compiler/mlir/lite:tf_tfl_translate", + "//tensorflow/compiler/mlir/lite/stablehlo:odml-to-stablehlo-opt", "//tensorflow/compiler/mlir/lite/stablehlo:odml_to_stablehlo", - "//tensorflow/compiler/mlir/lite/stablehlo:tf-mhlo-tfl-opt", "@llvm-project//llvm:FileCheck", "@llvm-project//mlir:run_lit.sh", ], diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/fold_broadcast.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/fold_broadcast.mlir new file mode 100644 index 00000000000..18526fe94c7 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/fold_broadcast.mlir @@ -0,0 +1,44 @@ +// RUN: odml-to-stablehlo-opt %s -constant-fold-broadcast-pass -cse -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: @foldBroadcastInDimBeforeMulOp_bcast_dim_1D_float +func.func @foldBroadcastInDimBeforeMulOp_bcast_dim_1D_float() -> (tensor<1x1x2x4xf32>) { + // CHECK-DAG: %[[RES:.*]] = mhlo.constant dense<{{\[\[\[\[}}1.000000e+00, 4.000000e+00, 9.000000e+00, 1.600000e+01], [5.000000e+00, 1.200000e+01, 2.100000e+01, 3.200000e+01]]]]> : tensor<1x1x2x4xf32> + %cst0 = mhlo.constant dense<[1.0, 2.0, 
3.0, 4.0]> : tensor<4xf32> + %cst1 = mhlo.constant dense<[[[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]]]> : tensor<1x1x2x4xf32> + %0 = "mhlo.broadcast_in_dim"(%cst0) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<1x1x2x4xf32> + %1 = mhlo.multiply %0, %cst1 : tensor<1x1x2x4xf32> + // CHECK: return %[[RES]] : tensor<1x1x2x4xf32> + func.return %1 : tensor<1x1x2x4xf32> +} + +// CHECK-LABEL: @foldBroadcastInDimBeforeMulOp_bcast_dim_2D_float +func.func @foldBroadcastInDimBeforeMulOp_bcast_dim_2D_float() -> (tensor<1x2x2x3xf32>) { + // CHECK-DAG: %[[RES:.*]] = mhlo.constant dense<{{\[\[\[\[}}1.000000e+00, 4.000000e+00, 9.000000e+00], [4.000000e+00, 1.000000e+01, 1.800000e+01]], {{\[\[}}2.800000e+01, 4.000000e+01, 5.400000e+01], [4.000000e+01, 5.500000e+01, 7.200000e+01]]]]> : tensor<1x2x2x3xf32> + %cst0 = mhlo.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32> + %cst1 = mhlo.constant dense<[[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]]> : tensor<1x2x2x3xf32> + %0 = "mhlo.broadcast_in_dim"(%cst0) {broadcast_dimensions = dense<[1, 3]> : tensor<2xi64>} : (tensor<2x3xf32>) -> tensor<1x2x2x3xf32> + %1 = mhlo.multiply %0, %cst1 : tensor<1x2x2x3xf32> + // CHECK: return %[[RES]] : tensor<1x2x2x3xf32> + func.return %1 : tensor<1x2x2x3xf32> +} + +// CHECK-LABEL: @foldBroadcastInDimBeforeMulOp_bcast_dim_1D_int +func.func @foldBroadcastInDimBeforeMulOp_bcast_dim_1D_int() -> (tensor<1x1x2x4xi32>) { + // CHECK-DAG: %[[RES:.*]] = mhlo.constant dense<{{\[\[\[\[}}1, 4, 9, 16], [5, 12, 21, 32]]]]> : tensor<1x1x2x4xi32> + %cst0 = mhlo.constant dense<[1, 2, 3, 4]> : tensor<4xi32> + %cst1 = mhlo.constant dense<[[[[1, 2, 3, 4], [5, 6, 7, 8]]]]> : tensor<1x1x2x4xi32> + %0 = "mhlo.broadcast_in_dim"(%cst0) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<1x1x2x4xi32> + %1 = mhlo.multiply %0, %cst1 : tensor<1x1x2x4xi32> + // CHECK: return %[[RES]] : tensor<1x1x2x4xi32> + func.return %1 : tensor<1x1x2x4xi32> +} + +// CHECK-LABEL: @foldBroadcastInDimBeforeMulOp_bcast_dim_4D_int +func.func @foldBroadcastInDimBeforeMulOp_bcast_dim_4D_int(%arg0: tensor<1x2x1x4xi32>) -> tensor<1x2x1x4xi32> { + // CHECK-DAG: %[[RES:.*]] = mhlo.constant dense<{{\[\[\[\[}}0, 1, 2, 3]], {{\[\[}}0, 1, 2, 3]]]]> : tensor<1x2x1x4xi32> + %0 = mhlo.constant dense<[[[[0, 1, 2, 3]]]]> : tensor<1x1x1x4xi32> + %1 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} : (tensor<1x1x1x4xi32>) -> tensor<1x2x1x4xi32> + // CHECK: mhlo.multiply %[[ARG0:.*]], %[[RES]] : tensor<1x2x1x4xi32> + %2 = mhlo.multiply %arg0, %1 : tensor<1x2x1x4xi32> + return %2 : tensor<1x2x1x4xi32> +} diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/fuse_mhlo_convolution.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/fuse_mhlo_convolution.mlir new file mode 100644 index 00000000000..a05f4ee36fb --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/fuse_mhlo_convolution.mlir @@ -0,0 +1,18 @@ +// RUN: odml-to-stablehlo-opt %s -fuse-mhlo-convolution-pass -cse | FileCheck %s + +// CHECK-LABEL: @fuseMulAndConv2D +// CHECK-SAME: %[[INPUT:[^:[:space:]]+]] +func.func @fuseMulAndConv2D(%input: tensor<1x256x256x3xf32>) -> (tensor<1x256x256x2xf32>) { + // CHECK-DAG: %[[FILTER:.+]] = mhlo.constant dense<{{\[\[\[\[}}1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00], [5.000000e+00, 6.000000e+00]]]]> : tensor<1x1x3x2xf32> + // CHECK-DAG: %[[CST:.+]] = mhlo.constant dense<[1.000000e-01, 2.000000e-01]> : tensor<2xf32> + // 
CHECK-DAG: %[[CST_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[CST]]) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<2xf32>) -> tensor<1x1x3x2xf32> + // CHECK-DAG: %[[NEW_FILTER:.+]] = mhlo.multiply %[[CST_BCAST]], %[[FILTER]] : tensor<1x1x3x2xf32> + // CHECK-DAG: %[[RESULT:.+]] = mhlo.convolution(%[[INPUT]], %[[NEW_FILTER]]) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = {{\[\[}}0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x256x256x3xf32>, tensor<1x1x3x2xf32>) -> tensor<1x256x256x2xf32> + %filter = mhlo.constant dense<[[[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]]]> : tensor<1x1x3x2xf32> + %cst = mhlo.constant dense<[0.1, 0.2]> : tensor<2xf32> + %0 = mhlo.convolution(%input, %filter) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x256x256x3xf32>, tensor<1x1x3x2xf32>) -> tensor<1x256x256x2xf32> + %1 = "mhlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<2xf32>) -> tensor<1x256x256x2xf32> + %2 = mhlo.multiply %0, %1 : tensor<1x256x256x2xf32> + // CHECK-DAG: return %[[RESULT]] + func.return %2 : tensor<1x256x256x2xf32> +} diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-acos.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-acos.mlir deleted file mode 100644 index 2e558fec4cc..00000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-acos.mlir +++ /dev/null @@ -1,38 +0,0 @@ -// RUN: tf-mhlo-tfl-opt %s -tf-mhlo-tfl | FileCheck %s - -// Convert Acos to TFL via MHLO, but leave the unsupported tf.Cos untouched. -func.func @convertAcos(%arg0: tensor<3xf32>) -> tensor<3xf32> { - %1 = "tf.Acos"(%arg0) {device = ""} : (tensor<3xf32>) -> tensor<3xf32> - %2 = "tf.Cos"(%1) : (tensor<3xf32>) -> tensor<3xf32> - func.return %2: tensor<3xf32> -} - -// CHECK-LABEL: @convertAcos -// CHECK-SAME: %arg0: tensor<3xf32> -// CHECK: %[[CST:.*]] = arith.constant dense<-1.000000e+00> : tensor<3xf32> -// CHECK: %[[TMP1:.*]] = tfl.not_equal(%arg0, %[[CST]]) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> -// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : tensor<3xf32> -// CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1.000000e+00> : tensor<3xf32> -// CHECK: %[[TMP2:.*]] = tfl.mul %arg0, %arg0 {fused_activation_function = "NONE"} : tensor<3xf32> -// CHECK: %[[TMP3:.*]] = tfl.sub %[[CST1]], %[[TMP2]] {fused_activation_function = "NONE"} : tensor<3xf32> -// CHECK: %[[TMP4:.*]] = "tfl.sqrt"(%[[TMP3]]) : (tensor<3xf32>) -> tensor<3xf32> -// CHECK: %[[CST2:.*]] = arith.constant dense<1.000000e+00> : tensor<3xf32> -// CHECK: %[[TMP5:.*]] = tfl.add %[[CST2]], %arg0 {fused_activation_function = "NONE"} : tensor<3xf32> -// CHECK: %[[TMP6:.*]] = "tfl.custom"(%[[TMP4]], %[[TMP5]]) {custom_code = "atan2", custom_option = #tfl} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> -// CHECK: %[[TMP7:.*]] = tfl.mul %[[CST0]], %[[TMP6]] {fused_activation_function = "NONE"} : tensor<3xf32> -// CHECK: %[[CST3:.*]] = arith.constant dense<3.14159274> : tensor<3xf32> -// CHECK: %[[RES1:.*]] = "tfl.select"(%[[TMP1]], %[[TMP7]], %[[CST3]]) : (tensor<3xi1>, tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> -// CHECK: %[[RES2:.*]] = "tf.Cos"(%[[RES1]]) : (tensor<3xf32>) -> tensor<3xf32> -// CHECK: return %[[RES2]] : tensor<3xf32> - - -// Leave unsupported tf.Cos untouched in TF dialect. 
-func.func @cosUnconverted(%arg0: tensor<3xf32>) -> tensor<3xf32> { - %0 = "tf.Cos"(%arg0) : (tensor<3xf32>) -> tensor<3xf32> - func.return %0 : tensor<3xf32> -} - -// CHECK-LABEL: @cosUnconverted -// CHECK-SAME: %arg0: tensor<3xf32> -// CHECK: %[[RES:.*]] = "tf.Cos"(%arg0) : (tensor<3xf32>) -> tensor<3xf32> -// CHECK: return %[[RES]] : tensor<3xf32> diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-inplaceupdate-tf_mhlo_tflite.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-inplaceupdate-tf_mhlo_tflite.mlir deleted file mode 100644 index 27653a005ab..00000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-inplaceupdate-tf_mhlo_tflite.mlir +++ /dev/null @@ -1,19 +0,0 @@ -// RUN: tf-mhlo-tfl-opt %s -tf-mhlo-tfl | FileCheck %s - -module attributes {tf.versions = {producer = 888 : i32}} { - -func.func @tfInplaceUpdate(%arg0: tensor<2x1x2xf32>) -> tensor<2x1x2xf32> { - %1 = arith.constant dense<1> : tensor<1xi32> - %2 = arith.constant dense<2.0> : tensor<1x1x2xf32> - %3 = "tf.InplaceUpdate"(%arg0, %1, %2) {device = ""} - : (tensor<2x1x2xf32>, tensor<1xi32>, tensor<1x1x2xf32>) -> tensor<2x1x2xf32> - func.return %3 : tensor<2x1x2xf32> -} - -} - -// CHECK-LABEL: @tfInplaceUpdate -// CHECK-DAG: %cst = arith.constant dense<1> : tensor<1xi32> -// CHECK-DAG: %cst_0 = arith.constant dense<2.000000e+00> : tensor<1x1x2xf32> -// CHECK: %0 = "tf.InplaceUpdate"(%arg0, %cst, %cst_0) {device = ""} : (tensor<2x1x2xf32>, tensor<1xi32>, tensor<1x1x2xf32>) -> tensor<2x1x2xf32> -// CHECK: return %0 : tensor<2x1x2xf32> diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-inplaceupdate.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-inplaceupdate.mlir index efb8c5e20be..4355ab7de08 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-inplaceupdate.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-inplaceupdate.mlir @@ -1,4 +1,4 @@ -// RUN: tf-mhlo-tfl-opt %s -tf-mhlo | FileCheck %s +// RUN: odml-to-stablehlo-opt %s -tf-stablehlo | FileCheck %s module attributes {tf.versions = {producer = 888 : i32}} { @@ -13,10 +13,8 @@ func.func @tfInplaceUpdate(%arg0: tensor<2x1x2xf32>) -> tensor<2x1x2xf32> { } // CHECK-LABEL: @tfInplaceUpdate -// CHECK-DAG: %{{.*}} = arith.constant dense<1> : tensor<1xi32> -// CHECK-DAG: %{{.*}} = arith.constant dense<2.000000e+00> : tensor<1x1x2xf32> -// CHECK-DAG: %[[CST0:.*]] = mhlo.constant dense<1> : tensor -// CHECK-DAG: %[[CST1:.*]] = mhlo.constant dense<0> : tensor -// CHECK-DAG: %[[CST2:.*]] = mhlo.constant dense<2.000000e+00> : tensor<1x1x2xf32> -// CHECK: %[[RES:.*]] = mhlo.dynamic_update_slice %arg0, %[[CST2]], %[[CST0]], %[[CST1]], %[[CST1]] : (tensor<2x1x2xf32>, tensor<1x1x2xf32>, tensor, tensor, tensor) -> tensor<2x1x2xf32> +// CHECK-DAG: %[[CST0:.*]] = stablehlo.constant dense<1> : tensor +// CHECK-DAG: %[[CST1:.*]] = stablehlo.constant dense<0> : tensor +// CHECK-DAG: %[[CST2:.*]] = stablehlo.constant dense<2.000000e+00> : tensor<1x1x2xf32> +// CHECK: %[[RES:.*]] = stablehlo.dynamic_update_slice %arg0, %[[CST2]], %[[CST0]], %[[CST1]], %[[CST1]] : (tensor<2x1x2xf32>, tensor<1x1x2xf32>, tensor, tensor, tensor) -> tensor<2x1x2xf32> // CHECK: return %[[RES]] : tensor<2x1x2xf32> diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tf-fb-tf.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tf-fb-tf.mlir deleted file mode 100644 index 840d4164658..00000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tf-fb-tf.mlir +++ 
/dev/null @@ -1,17 +0,0 @@ -// RUN: tf-mhlo-tfl-opt %s -mhlo-tfl | flatbuffer_translate -mlir-to-tflite-flatbuffer - -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s - -module { -func.func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { - %0 = mhlo.add %arg0, %arg0 : tensor<2xi32> - %1 = mhlo.subtract %0, %arg0 : tensor<2xi32> - func.return %1 : tensor<2xi32> -} -} - -// CHECK: module attributes {tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} { -// CHECK-NEXT: func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> attributes {tf.entry_function = {inputs = "arg0", outputs = "tfl.custom1"}} { -// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg0) {custom_code = "mhlo.add", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK-NEXT: %1 = "tfl.custom"(%0, %arg0) {custom_code = "mhlo.subtract", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK-NEXT: return %1 : tensor<2xi32> -// CHECK-NEXT: } -// CHECK-NEXT:} diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-broadcast_in_dim.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-broadcast_in_dim.mlir deleted file mode 100644 index 15f2c6e1f2e..00000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-broadcast_in_dim.mlir +++ /dev/null @@ -1,15 +0,0 @@ -// RUN: tf-mhlo-tfl-opt %s -mhlo-tfl | FileCheck %s - -module { -func.func @main(%arg0: tensor<1x2xi32>) -> tensor<1x2x2xi32> { - %0= "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x2xi32> - func.return %0 : tensor<1x2x2xi32> -} -} - -// CHECK: module { -// CHECK-NEXT: func @main(%arg0: tensor<1x2xi32>) -> tensor<1x2x2xi32> { -// CHECK-NEXT: %0 = "tfl.custom"(%arg0) {custom_code = "mhlo.broadcast_in_dim", custom_option = #tfl} : (tensor<1x2xi32>) -> tensor<1x2x2xi32> -// CHECK-NEXT: return %0 : tensor<1x2x2xi32> -// CHECK-NEXT: } -// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-compare.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-compare.mlir deleted file mode 100644 index b815be086fe..00000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-compare.mlir +++ /dev/null @@ -1,19 +0,0 @@ -// RUN: tf-mhlo-tfl-opt %s -mhlo-tfl | FileCheck %s - -module { -func.func @main(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2: tensor<2xf32>, %arg3: tensor<2xf32>) -> tensor<2xi1> { - %0 = mhlo.compare LT, %arg0, %arg1 : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - %1 = mhlo.compare LT, %arg0, %arg1, TOTALORDER : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> - %2 = mhlo.compare GT, %arg2, %arg3 : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> - func.return %2 : tensor<2xi1> -} -} - -// CHECK: module { -// CHECK-NEXT: func.func @main(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2: tensor<2xf32>, %arg3: tensor<2xf32>) -> tensor<2xi1> { -// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "mhlo.compare", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> -// CHECK-NEXT: %1 = "tfl.custom"(%arg0, %arg1) {custom_code = "mhlo.compare", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> -// CHECK-NEXT: %2 = "tfl.custom"(%arg2, %arg3) {custom_code = "mhlo.compare", custom_option = #tfl} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> -// CHECK-NEXT: return %2 : tensor<2xi1> -// CHECK-NEXT: } -// CHECK-NEXT: } diff --git 
a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-constant.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-constant.mlir deleted file mode 100644 index 93e1f8ed1be..00000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-constant.mlir +++ /dev/null @@ -1,17 +0,0 @@ -// RUN: tf-mhlo-tfl-opt %s -mhlo-tfl | FileCheck %s - -module { -func.func @main() -> tensor<2xf32> { - %0 = mhlo.constant dense<2> : tensor - %1 = mhlo.constant dense<[10.0, 11.0]> : tensor<2xf32> - func.return %1 : tensor<2xf32> -} -} - -// CHECK: module { -// CHECK-NEXT: func.func @main() -> tensor<2xf32> { -// CHECK-NEXT: %0 = "tfl.custom"() {custom_code = "mhlo.constant", custom_option = #tfl} : () -> tensor -// CHECK-NEXT: %1 = "tfl.custom"() {custom_code = "mhlo.constant", custom_option = #tfl} : () -> tensor<2xf32> -// CHECK-NEXT: return %1 : tensor<2xf32> -// CHECK-NEXT: } -// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-reshape.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-reshape.mlir deleted file mode 100644 index 4d500bae675..00000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-reshape.mlir +++ /dev/null @@ -1,15 +0,0 @@ -// RUN: tf-mhlo-tfl-opt %s -mhlo-tfl | FileCheck %s - -module { -func.func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { - %0 = "mhlo.reshape"(%arg0) : (tensor<2xi32>) -> tensor<2xi32> - func.return %0 : tensor<2xi32> -} -} - -// CHECK: module { -// CHECK-NEXT: func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { -// CHECK-NEXT: %0 = "tfl.custom"(%arg0) {custom_code = "mhlo.reshape", custom_option = #tfl} : (tensor<2xi32>) -> tensor<2xi32> -// CHECK-NEXT: return %0 : tensor<2xi32> -// CHECK-NEXT: } -// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-rsqrt.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-rsqrt.mlir deleted file mode 100644 index 08c8c788b44..00000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-rsqrt.mlir +++ /dev/null @@ -1,15 +0,0 @@ -// RUN: tf-mhlo-tfl-opt %s -mhlo-tfl | FileCheck %s - -module { -func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> { - %0 = "mhlo.rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - func.return %0 : tensor<2xf32> -} -} - -// CHECK: module { -// CHECK-NEXT: func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> { -// CHECK-NEXT: %0 = "tfl.custom"(%arg0) {custom_code = "mhlo.rsqrt", custom_option = #tfl} : (tensor<2xf32>) -> tensor<2xf32> -// CHECK-NEXT: return %0 : tensor<2xf32> -// CHECK-NEXT: } -// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl.mlir deleted file mode 100644 index 12163c26c87..00000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl.mlir +++ /dev/null @@ -1,17 +0,0 @@ -// RUN: tf-mhlo-tfl-opt %s -mhlo-tfl | FileCheck %s - -module { -func.func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { - %0 = mhlo.add %arg0, %arg0 : tensor<2xi32> - %1 = mhlo.subtract %0, %arg0 : tensor<2xi32> - func.return %1 : tensor<2xi32> -} -} - -// CHECK: module { -// CHECK-NEXT: func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { -// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg0) {custom_code = "mhlo.add", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK-NEXT: %1 = "tfl.custom"(%0, %arg0) {custom_code = "mhlo.subtract", custom_option = 
#tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK-NEXT: return %1 : tensor<2xi32> -// CHECK-NEXT: } -// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-poly.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-poly.mlir deleted file mode 100644 index 36b04e7951c..00000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-poly.mlir +++ /dev/null @@ -1,31 +0,0 @@ -// RUN: tf-mhlo-tfl-opt %s -tf-poly | FileCheck %s - -module attributes {tf.versions = {producer = 888 : i32}} { - -func.func @tfInplaceUpdate(%arg0: tensor<2x1x2xf32>) -> tensor<2x1x2xf32> { - %1 = arith.constant dense<1> : tensor<1xi32> - %2 = arith.constant dense<2.0> : tensor<1x1x2xf32> - %3 = "tf.InplaceUpdate"(%arg0, %1, %2) {device = ""} - : (tensor<2x1x2xf32>, tensor<1xi32>, tensor<1x1x2xf32>) -> tensor<2x1x2xf32> - func.return %3 : tensor<2x1x2xf32> -} - -} - -// CHECK-LABEL: @tfInplaceUpdate -// CHECK-NEXT: %cst = arith.constant dense<1> : tensor<1xi32> -// CHECK-NEXT: %cst_0 = arith.constant dense<2.000000e+00> : tensor<1x1x2xf32> -// CHECK-NEXT: %0 = "tfl.poly_call"(%arg0, %cst, %cst_0) ({ -// CHECK-NEXT: ^bb0(%arg1: tensor<2x1x2xf32>, %arg2: tensor<1xi32>, %arg3: tensor<1x1x2xf32>): -// CHECK-NEXT: %1 = "tf.InplaceUpdate"(%arg1, %arg2, %arg3) {device = ""} : (tensor<2x1x2xf32>, tensor<1xi32>, tensor<1x1x2xf32>) -> tensor<2x1x2xf32> -// CHECK-NEXT: "tfl.yield"(%1) : (tensor<2x1x2xf32>) -> () -// CHECK-NEXT: }, { -// CHECK-NEXT: ^bb0(%arg1: tensor<2x1x2xf32>, %arg2: tensor<1xi32>, %arg3: tensor<1x1x2xf32>): -// CHECK-NEXT: %1 = "mhlo.slice"(%arg2) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<1xi32>) -> tensor<1xi32> -// CHECK-NEXT: %2 = mhlo.reshape %1 : (tensor<1xi32>) -> tensor -// CHECK-NEXT: %3 = mhlo.constant dense<0> : tensor -// CHECK-NEXT: %4 = "mhlo.slice"(%arg3) {limit_indices = dense<[1, 1, 2]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<1x1x2xf32>) -> tensor<1x1x2xf32> -// CHECK-NEXT: %5 = mhlo.dynamic_update_slice %arg1, %4, %2, %3, %3 : (tensor<2x1x2xf32>, tensor<1x1x2xf32>, tensor, tensor, tensor) -> tensor<2x1x2xf32> -// CHECK-NEXT: "tfl.yield"(%5) : (tensor<2x1x2xf32>) -> () -// CHECK-NEXT: }) {device = ""} : (tensor<2x1x2xf32>, tensor<1xi32>, tensor<1x1x2xf32>) -> tensor<2x1x2xf32> -// CHECK-NEXT: return %0 : tensor<2x1x2xf32> diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-skip-quantization-ops.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-skip-quantization-ops.mlir index dec004d600c..30ad5f7fce9 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-skip-quantization-ops.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-skip-quantization-ops.mlir @@ -1,5 +1,5 @@ -// RUN: tf-mhlo-tfl-opt %s --tf-mhlo=skip-quantization-ops=true | FileCheck %s --check-prefix=CHECK-SKIP -// RUN: tf-mhlo-tfl-opt %s --tf-mhlo=skip-quantization-ops=false | FileCheck %s --check-prefix=CHECK-NOSKIP +// RUN: odml-to-stablehlo-opt %s --tf-stablehlo=skip-quantization-ops=true | FileCheck %s --check-prefix=CHECK-SKIP +// RUN: odml-to-stablehlo-opt %s --tf-stablehlo=skip-quantization-ops=false | FileCheck %s --check-prefix=CHECK-NOSKIP func.func @fake_quant_with_min_max_vars(%arg0: tensor<1x1x28x48xf32>, %arg1: tensor, %arg2: tensor) -> tensor<1x1x28x48xf32> { %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {device = "", narrow_range = true, 
num_bits = 8 : i64} : (tensor<1x1x28x48xf32>, tensor, tensor) -> tensor<1x1x28x48xf32> diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tf-fb-tf.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tf-fb-tf.mlir new file mode 100644 index 00000000000..f50c399deb5 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tf-fb-tf.mlir @@ -0,0 +1,17 @@ +// RUN: odml-to-stablehlo-opt %s -stablehlo-tfl | flatbuffer_translate -mlir-to-tflite-flatbuffer - -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s + +module { +func.func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { + %0 = stablehlo.add %arg0, %arg0 : tensor<2xi32> + %1 = stablehlo.subtract %0, %arg0 : tensor<2xi32> + func.return %1 : tensor<2xi32> +} +} + +// CHECK: module attributes {tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} { +// CHECK-NEXT: func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> attributes {tf.entry_function = {inputs = "arg0", outputs = "tfl.custom1"}} { +// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg0) {custom_code = "stablehlo.add", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK-NEXT: %1 = "tfl.custom"(%0, %arg0) {custom_code = "stablehlo.subtract", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK-NEXT: return %1 : tensor<2xi32> +// CHECK-NEXT: } +// CHECK-NEXT:} diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-sub.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-add.mlir similarity index 69% rename from tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-sub.mlir rename to tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-add.mlir index f47d382ab63..b0eb02192f4 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-sub.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-add.mlir @@ -1,15 +1,15 @@ -// RUN: tf-mhlo-tfl-opt %s -mhlo-tfl | FileCheck %s +// RUN: odml-to-stablehlo-opt %s -stablehlo-tfl | FileCheck %s module { func.func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { - %0 = mhlo.subtract %arg0, %arg0 : tensor<2xi32> + %0 = stablehlo.add %arg0, %arg0 : tensor<2xi32> func.return %0 : tensor<2xi32> } } // CHECK: module { // CHECK-NEXT: func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { -// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg0) {custom_code = "mhlo.subtract", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg0) {custom_code = "stablehlo.add", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> // CHECK-NEXT: return %0 : tensor<2xi32> // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-broadcast_in_dim.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-broadcast_in_dim.mlir new file mode 100644 index 00000000000..3b5dae4706e --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-broadcast_in_dim.mlir @@ -0,0 +1,15 @@ +// RUN: odml-to-stablehlo-opt %s -stablehlo-tfl | FileCheck %s + +module { +func.func @main(%arg0: tensor<1x2xi32>) -> tensor<1x2x2xi32> { + %0= "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x2xi32> + func.return %0 : tensor<1x2x2xi32> +} +} + +// CHECK: module { +// CHECK-NEXT: func 
@main(%arg0: tensor<1x2xi32>) -> tensor<1x2x2xi32> { +// CHECK-NEXT: %0 = "tfl.custom"(%arg0) {custom_code = "stablehlo.broadcast_in_dim", custom_option = #tfl} : (tensor<1x2xi32>) -> tensor<1x2x2xi32> +// CHECK-NEXT: return %0 : tensor<1x2x2xi32> +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-clamp.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-clamp.mlir similarity index 51% rename from tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-clamp.mlir rename to tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-clamp.mlir index a98e0f52f77..2d0051afde9 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-clamp.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-clamp.mlir @@ -1,15 +1,15 @@ -// RUN: tf-mhlo-tfl-opt %s -mhlo-tfl | FileCheck %s +// RUN: odml-to-stablehlo-opt %s -stablehlo-tfl | FileCheck %s module { func.func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { - %0 = "mhlo.clamp"(%arg0, %arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %0 = "stablehlo.clamp"(%arg0, %arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> func.return %0 : tensor<2xi32> } } // CHECK: module { // CHECK-NEXT: func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { -// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg0, %arg0) {custom_code = "mhlo.clamp", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg0, %arg0) {custom_code = "stablehlo.clamp", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> // CHECK-NEXT: return %0 : tensor<2xi32> // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-compare.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-compare.mlir new file mode 100644 index 00000000000..44b69ab9330 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-compare.mlir @@ -0,0 +1,19 @@ +// RUN: odml-to-stablehlo-opt %s -stablehlo-tfl | FileCheck %s + +module { +func.func @main(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2: tensor<2xf32>, %arg3: tensor<2xf32>) -> tensor<2xi1> { + %0 = stablehlo.compare LT, %arg0, %arg1 : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + %1 = stablehlo.compare LT, %arg0, %arg1, TOTALORDER : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + %2 = stablehlo.compare GT, %arg2, %arg3 : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> + func.return %2 : tensor<2xi1> +} +} + +// CHECK: module { +// CHECK-NEXT: func.func @main(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2: tensor<2xf32>, %arg3: tensor<2xf32>) -> tensor<2xi1> { +// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "stablehlo.compare", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +// CHECK-NEXT: %1 = "tfl.custom"(%arg0, %arg1) {custom_code = "stablehlo.compare", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +// CHECK-NEXT: %2 = "tfl.custom"(%arg2, %arg3) {custom_code = "stablehlo.compare", custom_option = #tfl} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> +// CHECK-NEXT: return %2 : tensor<2xi1> +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-concat.mlir 
b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-concat.mlir similarity index 53% rename from tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-concat.mlir rename to tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-concat.mlir index 7f0ae066aab..4be83175a41 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-concat.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-concat.mlir @@ -1,15 +1,15 @@ -// RUN: tf-mhlo-tfl-opt %s -mhlo-tfl | FileCheck %s +// RUN: odml-to-stablehlo-opt %s -stablehlo-tfl | FileCheck %s module { func.func @main(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> { - %1 = "mhlo.concatenate"(%arg0, %arg1) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32> + %1 = "stablehlo.concatenate"(%arg0, %arg1) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32> func.return %1 : tensor<6x3xf32> } } // CHECK: module { // CHECK-NEXT: func @main(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> { -// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "mhlo.concatenate", custom_option = #tfl} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32> +// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "stablehlo.concatenate", custom_option = #tfl} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32> // CHECK-NEXT: return %0 : tensor<6x3xf32> // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-constant.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-constant.mlir new file mode 100644 index 00000000000..62c2253869c --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-constant.mlir @@ -0,0 +1,17 @@ +// RUN: odml-to-stablehlo-opt %s -stablehlo-tfl | FileCheck %s + +module { +func.func @main() -> tensor<2xf32> { + %0 = stablehlo.constant dense<2> : tensor + %1 = stablehlo.constant dense<[10.0, 11.0]> : tensor<2xf32> + func.return %1 : tensor<2xf32> +} +} + +// CHECK: module { +// CHECK-NEXT: func.func @main() -> tensor<2xf32> { +// CHECK-NEXT: %0 = "tfl.custom"() {custom_code = "stablehlo.constant", custom_option = #tfl} : () -> tensor +// CHECK-NEXT: %1 = "tfl.custom"() {custom_code = "stablehlo.constant", custom_option = #tfl} : () -> tensor<2xf32> +// CHECK-NEXT: return %1 : tensor<2xf32> +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-conv.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-conv.mlir similarity index 51% rename from tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-conv.mlir rename to tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-conv.mlir index c21b33a9db5..40305064722 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-conv.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-conv.mlir @@ -1,9 +1,9 @@ -// RUN: tf-mhlo-tfl-opt %s -mhlo-tfl | FileCheck -dump-input always %s +// RUN: odml-to-stablehlo-opt %s -stablehlo-tfl | FileCheck -dump-input always %s module { func.func @main(%arg0: tensor<8x8x1x207xf32>, %arg1: tensor<3x3x16x207xf32>) -> tensor<16x8x8x1xf32> { - %0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, - dimension_numbers = #mhlo.conv, %arg1: tensor<3x3x16x207xf32>) -> output_batch_dimension = 3, output_feature_dimension = 0, 
output_spatial_dimensions = [1, 2] - >, feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, precision_config = [#mhlo, #mhlo], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : + >, feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, precision_config = [#stablehlo, #stablehlo], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : (tensor<8x8x1x207xf32>, tensor<3x3x16x207xf32>) -> tensor<16x8x8x1xf32> func.return %0 : tensor<16x8x8x1xf32> } @@ -21,7 +21,7 @@ func.func @main(%arg0: tensor<8x8x1x207xf32>, %arg1: tensor<3x3x16x207xf32>) -> // CHECK: module { // CHECK-NEXT: func @main(%arg0: tensor<8x8x1x207xf32>, %arg1: tensor<3x3x16x207xf32>) -> tensor<16x8x8x1xf32> { -// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "mhlo.convolution", custom_option = #tfl} : (tensor<8x8x1x207xf32>, tensor<3x3x16x207xf32>) -> tensor<16x8x8x1xf32> +// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "stablehlo.convolution", custom_option = #tfl} : (tensor<8x8x1x207xf32>, tensor<3x3x16x207xf32>) -> tensor<16x8x8x1xf32> // CHECK-NEXT: return %0 : tensor<16x8x8x1xf32> // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-dot.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-dot.mlir similarity index 64% rename from tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-dot.mlir rename to tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-dot.mlir index 17ad4b9dd15..ef715f778e8 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-dot.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-dot.mlir @@ -1,9 +1,9 @@ -// RUN: tf-mhlo-tfl-opt %s -mhlo-tfl | FileCheck %s +// RUN: odml-to-stablehlo-opt %s -stablehlo-tfl | FileCheck %s module { func.func @main(%arg0: tensor<72x2048xf32>, %arg1: tensor<2048x512xf32>) -> tensor<72x512xf32> { - %0 = "mhlo.dot"(%arg0, %arg1) { - dimension_numbers = #mhlo.dot< + %0 = "stablehlo.dot"(%arg0, %arg1) { + dimension_numbers = #stablehlo.dot< lhs_batching_dimensions = [0, 1], rhs_batching_dimensions = [1, 2], lhs_contracting_dimensions = [0, 1], @@ -16,7 +16,7 @@ func.func @main(%arg0: tensor<72x2048xf32>, %arg1: tensor<2048x512xf32>) -> tens // CHECK: module { // CHECK-NEXT: func.func @main(%arg0: tensor<72x2048xf32>, %arg1: tensor<2048x512xf32>) -> tensor<72x512xf32> { -// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "mhlo.dot", custom_option = #tfl} : (tensor<72x2048xf32>, tensor<2048x512xf32>) -> tensor<72x512xf32> +// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "stablehlo.dot", custom_option = #tfl} : (tensor<72x2048xf32>, tensor<2048x512xf32>) -> tensor<72x512xf32> // CHECK-NEXT: return %0 : tensor<72x512xf32> // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-gather.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-gather.mlir similarity index 60% rename from tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-gather.mlir rename to tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-gather.mlir index 5235221b4e2..5fb78f0540c 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-gather.mlir +++ 
b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-gather.mlir @@ -1,9 +1,9 @@ -// RUN: tf-mhlo-tfl-opt %s -mhlo-tfl | FileCheck %s +// RUN: odml-to-stablehlo-opt %s -stablehlo-tfl | FileCheck %s module { func.func @main(%arg0: tensor<1x128x256xf32>, %arg1: tensor<30x1x2xi32>) -> tensor<30x1x256xf32> { - %0 = "mhlo.gather"(%arg0, %arg1) { - dimension_numbers = #mhlo.gather< + %0 = "stablehlo.gather"(%arg0, %arg1) { + dimension_numbers = #stablehlo.gather< offset_dims = [2], collapsed_slice_dims = [0, 1], start_index_map = [0, 1], @@ -17,7 +17,7 @@ func.func @main(%arg0: tensor<1x128x256xf32>, %arg1: tensor<30x1x2xi32>) -> tens // CHECK: module { // CHECK-NEXT: func.func @main(%arg0: tensor<1x128x256xf32>, %arg1: tensor<30x1x2xi32>) -> tensor<30x1x256xf32> { -// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "mhlo.gather", custom_option = #tfl} : (tensor<1x128x256xf32>, tensor<30x1x2xi32>) -> tensor<30x1x256xf32> +// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "stablehlo.gather", custom_option = #tfl} : (tensor<1x128x256xf32>, tensor<30x1x2xi32>) -> tensor<30x1x256xf32> // CHECK-NEXT: return %0 : tensor<30x1x256xf32> // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-max.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-max.mlir similarity index 58% rename from tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-max.mlir rename to tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-max.mlir index 99f0edf7562..e8ccfcaee07 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-max.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-max.mlir @@ -1,15 +1,15 @@ -// RUN: tf-mhlo-tfl-opt %s -mhlo-tfl | FileCheck %s +// RUN: odml-to-stablehlo-opt %s -stablehlo-tfl | FileCheck %s module { func.func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { - %0 = mhlo.maximum %arg0, %arg0 : tensor<2xi32> + %0 = stablehlo.maximum %arg0, %arg0 : tensor<2xi32> func.return %0 : tensor<2xi32> } } // CHECK: module { // CHECK-NEXT: func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { -// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg0) {custom_code = "mhlo.maximum", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg0) {custom_code = "stablehlo.maximum", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> // CHECK-NEXT: return %0 : tensor<2xi32> // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-add.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-mul.mlir similarity index 58% rename from tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-add.mlir rename to tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-mul.mlir index 4c925611687..b4bcbc455f2 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-add.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-mul.mlir @@ -1,15 +1,15 @@ -// RUN: tf-mhlo-tfl-opt %s -mhlo-tfl | FileCheck %s +// RUN: odml-to-stablehlo-opt %s -stablehlo-tfl | FileCheck %s module { func.func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { - %0 = mhlo.add %arg0, %arg0 : tensor<2xi32> + %0 = stablehlo.multiply %arg0, %arg0 : tensor<2xi32> func.return %0 : tensor<2xi32> } } // CHECK: module { // CHECK-NEXT: func @main(%arg0: tensor<2xi32>) -> 
tensor<2xi32> { -// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg0) {custom_code = "mhlo.add", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg0) {custom_code = "stablehlo.multiply", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> // CHECK-NEXT: return %0 : tensor<2xi32> // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-pad.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-pad.mlir similarity index 62% rename from tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-pad.mlir rename to tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-pad.mlir index a0ef6d95dc0..f5f69b1cf18 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-pad.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-pad.mlir @@ -1,8 +1,8 @@ -// RUN: tf-mhlo-tfl-opt %s -mhlo-tfl | FileCheck %s +// RUN: odml-to-stablehlo-opt %s -stablehlo-tfl | FileCheck %s module { func.func @main(%arg0: tensor<8x128xf32>, %arg1: tensor) -> tensor<11x131xf32> { - %0 = "mhlo.pad"(%arg0, %arg1) { + %0 = "stablehlo.pad"(%arg0, %arg1) { edge_padding_low = dense<[1, 0]> : tensor<2xi64>, edge_padding_high = dense<[2, 3]> : tensor<2xi64>, interior_padding = dense<0> : tensor<2xi64> @@ -13,7 +13,7 @@ func.func @main(%arg0: tensor<8x128xf32>, %arg1: tensor) -> tensor<11x131xf // CHECK: module { // CHECK-NEXT: func @main(%arg0: tensor<8x128xf32>, %arg1: tensor) -> tensor<11x131xf32> { -// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "mhlo.pad", custom_option = #tfl} : (tensor<8x128xf32>, tensor) -> tensor<11x131xf32> +// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "stablehlo.pad", custom_option = #tfl} : (tensor<8x128xf32>, tensor) -> tensor<11x131xf32> // CHECK-NEXT: return %0 : tensor<11x131xf32> // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-reshape.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-reshape.mlir new file mode 100644 index 00000000000..281f14bf8b8 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-reshape.mlir @@ -0,0 +1,15 @@ +// RUN: odml-to-stablehlo-opt %s -stablehlo-tfl | FileCheck %s + +module { +func.func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { + %0 = "stablehlo.reshape"(%arg0) : (tensor<2xi32>) -> tensor<2xi32> + func.return %0 : tensor<2xi32> +} +} + +// CHECK: module { +// CHECK-NEXT: func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { +// CHECK-NEXT: %0 = "tfl.custom"(%arg0) {custom_code = "stablehlo.reshape", custom_option = #tfl} : (tensor<2xi32>) -> tensor<2xi32> +// CHECK-NEXT: return %0 : tensor<2xi32> +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-rsqrt.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-rsqrt.mlir new file mode 100644 index 00000000000..f352e19959c --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-rsqrt.mlir @@ -0,0 +1,15 @@ +// RUN: odml-to-stablehlo-opt %s -stablehlo-tfl | FileCheck %s + +module { +func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "stablehlo.rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} +} + +// CHECK: module { +// CHECK-NEXT: func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> { 
+// CHECK-NEXT: %0 = "tfl.custom"(%arg0) {custom_code = "stablehlo.rsqrt", custom_option = #tfl} : (tensor<2xf32>) -> tensor<2xf32> +// CHECK-NEXT: return %0 : tensor<2xf32> +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-scatter.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-scatter.mlir similarity index 58% rename from tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-scatter.mlir rename to tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-scatter.mlir index 78bd4de693b..5bd79227f57 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-scatter.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-scatter.mlir @@ -1,12 +1,12 @@ -// RUN: tf-mhlo-tfl-opt %s -mhlo-tfl | FileCheck %s +// RUN: odml-to-stablehlo-opt %s -stablehlo-tfl | FileCheck %s module { func.func @main(%arg0: tensor<3xi32>, %arg1: tensor<1x1xi32>, %arg2: tensor<1xi32>) -> tensor<3xi32> { - %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({ + %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({ ^bb0(%arg3: tensor, %arg4: tensor): - "mhlo.return"(%arg4) : (tensor) -> () + "stablehlo.return"(%arg4) : (tensor) -> () }) { - scatter_dimension_numbers = #mhlo.scatter< + scatter_dimension_numbers = #stablehlo.scatter< update_window_dims = [], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], @@ -20,7 +20,7 @@ func.func @main(%arg0: tensor<3xi32>, %arg1: tensor<1x1xi32>, %arg2: tensor<1xi3 // CHECK: module { // CHECK-NEXT: func.func @main(%arg0: tensor<3xi32>, %arg1: tensor<1x1xi32>, %arg2: tensor<1xi32>) -> tensor<3xi32> { -// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg1, %arg2) {custom_code = "mhlo.scatter", custom_option = #tfl} : (tensor<3xi32>, tensor<1x1xi32>, tensor<1xi32>) -> tensor<3xi32> +// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg1, %arg2) {custom_code = "stablehlo.scatter", custom_option = #tfl} : (tensor<3xi32>, tensor<1x1xi32>, tensor<1xi32>) -> tensor<3xi32> // CHECK-NEXT: return %0 : tensor<3xi32> // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-mul.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-sub.mlir similarity index 58% rename from tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-mul.mlir rename to tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-sub.mlir index e55f83a742b..bc4f72fd2bc 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-mhlo-tfl-mul.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-sub.mlir @@ -1,15 +1,15 @@ -// RUN: tf-mhlo-tfl-opt %s -mhlo-tfl | FileCheck %s +// RUN: odml-to-stablehlo-opt %s -stablehlo-tfl | FileCheck %s module { func.func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { - %0 = mhlo.multiply %arg0, %arg0 : tensor<2xi32> + %0 = stablehlo.subtract %arg0, %arg0 : tensor<2xi32> func.return %0 : tensor<2xi32> } } // CHECK: module { // CHECK-NEXT: func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { -// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg0) {custom_code = "mhlo.multiply", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg0) {custom_code = "stablehlo.subtract", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> // CHECK-NEXT: return %0 : tensor<2xi32> // CHECK-NEXT: } // CHECK-NEXT: } diff --git 
a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl.mlir new file mode 100644 index 00000000000..8898fac4288 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl.mlir @@ -0,0 +1,17 @@ +// RUN: odml-to-stablehlo-opt %s -stablehlo-tfl | FileCheck %s + +module { +func.func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { + %0 = stablehlo.add %arg0, %arg0 : tensor<2xi32> + %1 = stablehlo.subtract %0, %arg0 : tensor<2xi32> + func.return %1 : tensor<2xi32> +} +} + +// CHECK: module { +// CHECK-NEXT: func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { +// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg0) {custom_code = "stablehlo.add", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK-NEXT: %1 = "tfl.custom"(%0, %arg0) {custom_code = "stablehlo.subtract", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK-NEXT: return %1 : tensor<2xi32> +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-broadcast.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-broadcast.mlir deleted file mode 100644 index aa14caeb1c9..00000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-broadcast.mlir +++ /dev/null @@ -1,15 +0,0 @@ -// RUN: tf-mhlo-tfl-opt %s -tfl-parse-mhlo-ops | FileCheck %s - -module { - func.func @main(%arg0: tensor<1x2xi32>) -> tensor<1x2x2xi32> { - %0 = "tfl.custom"(%arg0) {custom_code = "mhlo.broadcast_in_dim", custom_option = #tfl} : (tensor<1x2xi32>) -> tensor<1x2x2xi32> - func.return %0 : tensor<1x2x2xi32> - } -} - -// CHECK: module { -// CHECK-NEXT: func @main(%arg0: tensor<1x2xi32>) -> tensor<1x2x2xi32> { -// CHECK-NEXT: %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x2xi32> -// CHECK-NEXT: return %0 : tensor<1x2x2xi32> -// CHECK-NEXT: } -// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-concat.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-concat.mlir deleted file mode 100644 index d2327e2a898..00000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-concat.mlir +++ /dev/null @@ -1,15 +0,0 @@ -// RUN: tf-mhlo-tfl-opt %s -tfl-parse-mhlo-ops | FileCheck %s - -module { -func.func @main(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> { - %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "mhlo.concatenate", custom_option = #tfl} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32> - func.return %0 : tensor<6x3xf32> -} -} - -// CHECK: module { -// CHECK-NEXT: func @main(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> { -// CHECK-NEXT: %0 = "mhlo.concatenate"(%arg0, %arg1) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32> -// CHECK-NEXT: return %0 : tensor<6x3xf32> -// CHECK-NEXT: } -// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-constant.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-constant.mlir deleted file mode 100644 index ea8deda4d71..00000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-constant.mlir +++ /dev/null @@ -1,15 +0,0 @@ -// RUN: tf-mhlo-tfl-opt %s -tfl-parse-mhlo-ops | FileCheck -dump-input always %s - -module { - func.func @main() -> tensor<1xi64> { - %0 = 
"tfl.custom"() {custom_code = "mhlo.constant", custom_option = #tfl} : () -> tensor<1xi64> - func.return %0 : tensor<1xi64> - } -} - -// CHECK: module { -// CHECK-NEXT: func @main() -> tensor<1xi64> { -// CHECK-NEXT: %0 = mhlo.constant dense<2> : tensor<1xi64> -// CHECK-NEXT: return %0 : tensor<1xi64> -// CHECK-NEXT: } -// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-conv.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-conv.mlir deleted file mode 100644 index e95f1e22764..00000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-conv.mlir +++ /dev/null @@ -1,14 +0,0 @@ -// RUN: tf-mhlo-tfl-opt %s -tfl-parse-mhlo-ops | FileCheck -dump-input always %s - -module { - func.func @main(%arg0: tensor<8x8x1x207xf32>, %arg1: tensor<3x3x16x207xf32>) -> tensor<16x8x8x1xf32> { - %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "mhlo.convolution", custom_option = #tfl} : (tensor<8x8x1x207xf32>, tensor<3x3x16x207xf32>) -> tensor<16x8x8x1xf32> - func.return %0 : tensor<16x8x8x1xf32> - } -} - - -// CHECK: module { -// CHECK: func @main(%arg0: tensor<8x8x1x207xf32>, %arg1: tensor<3x3x16x207xf32>) -> tensor<16x8x8x1xf32> { -// CHECK: %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [0, 1, b, f]x[0, 1, o, i]->[f, 0, 1, b], window = {stride = [1, 1], pad = [1, 1, 1, 1], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#mhlo, #mhlo]} : (tensor<8x8x1x207xf32>, tensor<3x3x16x207xf32>) -> tensor<16x8x8x1xf32> -// CHECK: return %0 : tensor<16x8x8x1xf32> diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-pad.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-pad.mlir deleted file mode 100644 index 87ab7f83276..00000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-pad.mlir +++ /dev/null @@ -1,16 +0,0 @@ -// RUN: tf-mhlo-tfl-opt %s -tfl-parse-mhlo-ops | FileCheck -dump-input always %s - -module { - func.func @main(%arg0: tensor<8x128xf32>, %arg1: tensor) -> tensor<11x131xf32> { - %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "mhlo.pad", custom_option = #tfl} : (tensor<8x128xf32>, tensor) -> tensor<11x131xf32> - func.return %0 : tensor<11x131xf32> - } - } - -// CHECK: module { -// CHECK-NEXT: func @main(%arg0: tensor<8x128xf32>, %arg1: tensor) -> tensor<11x131xf32> { -// CHECK-NEXT: %0 = "mhlo.pad"(%arg0, %arg1) {edge_padding_high = dense<[2, 3]> : tensor<2xi64>, edge_padding_low = dense<[1, 0]> : tensor<2xi64>, interior_padding = dense<0> : tensor<2xi64>} : (tensor<8x128xf32>, tensor) -> tensor<11x131xf32> -// CHECK-NEXT: return %0 : tensor<11x131xf32> -// CHECK-NEXT: } -// CHECK-NEXT: } - diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-reshape.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-reshape.mlir deleted file mode 100644 index e6777e9c9c0..00000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-reshape.mlir +++ /dev/null @@ -1,15 +0,0 @@ -// RUN: tf-mhlo-tfl-opt %s -tfl-parse-mhlo-ops | FileCheck -dump-input always %s - -module { - func.func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { - %0 = "tfl.custom"(%arg0) {custom_code = "mhlo.reshape", custom_option = #tfl} : (tensor<2xi32>) -> tensor<2xi32> - func.return %0 : tensor<2xi32> - } -} - -// CHECK: module { -// CHECK-NEXT: func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { -// CHECK-NEXT: %0 = mhlo.reshape %arg0 : (tensor<2xi32>) -> 
tensor<2xi32> -// CHECK-NEXT: return %0 : tensor<2xi32> -// CHECK-NEXT: } -// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-rsqrt.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-rsqrt.mlir deleted file mode 100644 index 16ea2618697..00000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-rsqrt.mlir +++ /dev/null @@ -1,15 +0,0 @@ -// RUN: tf-mhlo-tfl-opt %s -tfl-parse-mhlo-ops | FileCheck -dump-input always %s - -module { -func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> { - %0 = "tfl.custom"(%arg0) {custom_code = "mhlo.rsqrt", custom_option = #tfl} : (tensor<2xf32>) -> tensor<2xf32> - func.return %0 : tensor<2xf32> -} -} - -// CHECK: module -// CHECK-NEXT: func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> { -// CHECK-NEXT: %0 = mhlo.rsqrt %arg0 : tensor<2xf32> -// CHECK-NEXT: return %0 : tensor<2xf32> -// CHECK-NEXT: } -// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo.mlir deleted file mode 100644 index 379b24890e4..00000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo.mlir +++ /dev/null @@ -1,17 +0,0 @@ -// RUN: tf-mhlo-tfl-opt %s -tfl-parse-mhlo-ops | FileCheck %s - -module { -func.func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> attributes {tf.entry_function = {inputs = "arg0", outputs = "tfl.custom1"}} { - %0 = "tfl.custom"(%arg0, %arg0) {custom_code = "mhlo.add", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - %1 = "tfl.custom"(%0, %arg0) {custom_code = "mhlo.subtract", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - func.return %1 : tensor<2xi32> -} -} - -// CHECK: module { -// CHECK-NEXT: func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> attributes {tf.entry_function = {inputs = "arg0", outputs = "tfl.custom1"}} { -// CHECK-NEXT: %0 = mhlo.add %arg0, %arg0 : tensor<2xi32> -// CHECK-NEXT: %1 = mhlo.subtract %0, %arg0 : tensor<2xi32> -// CHECK-NEXT: return %1 : tensor<2xi32> -// CHECK-NEXT: } -// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-sub.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-add.mlir similarity index 72% rename from tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-sub.mlir rename to tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-add.mlir index 98ddf5a38f0..298e87b2a40 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-sub.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-add.mlir @@ -1,15 +1,15 @@ -// RUN: tf-mhlo-tfl-opt %s -tfl-parse-mhlo-ops | FileCheck %s +// RUN: odml-to-stablehlo-opt %s -tfl-parse-stablehlo-ops | FileCheck %s module { func.func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> attributes {tf.entry_function = {inputs = "arg0", outputs = "tfl.custom1"}} { - %0 = "tfl.custom"(%arg0, %arg0) {custom_code = "mhlo.subtract", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %0 = "tfl.custom"(%arg0, %arg0) {custom_code = "stablehlo.add", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> func.return %0 : tensor<2xi32> } } // CHECK: module { // CHECK-NEXT: func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> attributes {tf.entry_function = {inputs = "arg0", outputs = "tfl.custom1"}} { -// CHECK-NEXT: %0 = mhlo.subtract %arg0, %arg0 : tensor<2xi32> +// CHECK-NEXT: 
%0 = stablehlo.add %arg0, %arg0 : tensor<2xi32> // CHECK-NEXT: return %0 : tensor<2xi32> // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-broadcast.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-broadcast.mlir new file mode 100644 index 00000000000..b2477312ef1 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-broadcast.mlir @@ -0,0 +1,15 @@ +// RUN: odml-to-stablehlo-opt %s -tfl-parse-stablehlo-ops | FileCheck %s + +module { + func.func @main(%arg0: tensor<1x2xi32>) -> tensor<1x2x2xi32> { + %0 = "tfl.custom"(%arg0) {custom_code = "stablehlo.broadcast_in_dim", custom_option = #tfl} : (tensor<1x2xi32>) -> tensor<1x2x2xi32> + func.return %0 : tensor<1x2x2xi32> + } +} + +// CHECK: module { +// CHECK-NEXT: func @main(%arg0: tensor<1x2xi32>) -> tensor<1x2x2xi32> { +// CHECK-NEXT: %0 = stablehlo.broadcast_in_dim %arg0, dims = [1, 2] : (tensor<1x2xi32>) -> tensor<1x2x2xi32> +// CHECK-NEXT: return %0 : tensor<1x2x2xi32> +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-clamp.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-clamp.mlir similarity index 51% rename from tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-clamp.mlir rename to tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-clamp.mlir index c5cf3a9cd21..6456a802809 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-clamp.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-clamp.mlir @@ -1,15 +1,15 @@ -// RUN: tf-mhlo-tfl-opt %s -tfl-parse-mhlo-ops | FileCheck %s +// RUN: odml-to-stablehlo-opt %s -tfl-parse-stablehlo-ops | FileCheck %s module { func.func @main(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> { - %0 = "tfl.custom"(%arg0, %arg1, %arg2) {custom_code = "mhlo.clamp", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %0 = "tfl.custom"(%arg0, %arg1, %arg2) {custom_code = "stablehlo.clamp", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> func.return %0 : tensor<2xi32> } } // CHECK: module { // CHECK-NEXT: func @main(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> { -// CHECK-NEXT: %0 = mhlo.clamp %arg0, %arg1, %arg2 : tensor<2xi32> +// CHECK-NEXT: %0 = stablehlo.clamp %arg0, %arg1, %arg2 : tensor<2xi32> // CHECK-NEXT: return %0 : tensor<2xi32> // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-concat.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-concat.mlir new file mode 100644 index 00000000000..6e783dca9dd --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-concat.mlir @@ -0,0 +1,15 @@ +// RUN: odml-to-stablehlo-opt %s -tfl-parse-stablehlo-ops | FileCheck %s + +module { +func.func @main(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> { + %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "stablehlo.concatenate", custom_option = #tfl} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32> + func.return %0 : tensor<6x3xf32> +} +} + +// CHECK: module { +// CHECK-NEXT: func @main(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> { +// CHECK-NEXT: %0 = stablehlo.concatenate %arg0, %arg1, dim = 0 : (tensor<3x3xf32>, 
tensor<3x3xf32>) -> tensor<6x3xf32> +// CHECK-NEXT: return %0 : tensor<6x3xf32> +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-constant.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-constant.mlir new file mode 100644 index 00000000000..d4d3b0abf01 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-constant.mlir @@ -0,0 +1,15 @@ +// RUN: odml-to-stablehlo-opt %s -tfl-parse-stablehlo-ops | FileCheck -dump-input always %s + +module { + func.func @main() -> tensor<1xi64> { + %0 = "tfl.custom"() {custom_code = "stablehlo.constant", custom_option = #tfl} : () -> tensor<1xi64> + func.return %0 : tensor<1xi64> + } +} + +// CHECK: module { +// CHECK-NEXT: func @main() -> tensor<1xi64> { +// CHECK-NEXT: %0 = stablehlo.constant dense<2> : tensor<1xi64> +// CHECK-NEXT: return %0 : tensor<1xi64> +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-conv.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-conv.mlir new file mode 100644 index 00000000000..65f988cebe7 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-conv.mlir @@ -0,0 +1,14 @@ +// RUN: odml-to-stablehlo-opt %s -tfl-parse-stablehlo-ops | FileCheck -dump-input always %s + +module { + func.func @main(%arg0: tensor<8x8x1x207xf32>, %arg1: tensor<3x3x16x207xf32>) -> tensor<16x8x8x1xf32> { + %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "stablehlo.convolution", custom_option = #tfl} : (tensor<8x8x1x207xf32>, tensor<3x3x16x207xf32>) -> tensor<16x8x8x1xf32> + func.return %0 : tensor<16x8x8x1xf32> + } +} + + +// CHECK: module { +// CHECK: func @main(%arg0: tensor<8x8x1x207xf32>, %arg1: tensor<3x3x16x207xf32>) -> tensor<16x8x8x1xf32> { +// CHECK: %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [0, 1, b, f]x[0, 1, o, i]->[f, 0, 1, b], window = {stride = [1, 1], pad = {{\[}}[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} : (tensor<8x8x1x207xf32>, tensor<3x3x16x207xf32>) -> tensor<16x8x8x1xf32> +// CHECK: return %0 : tensor<16x8x8x1xf32> diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-add.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-max.mlir similarity index 57% rename from tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-add.mlir rename to tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-max.mlir index 173559af740..98be6176789 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-add.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-max.mlir @@ -1,15 +1,15 @@ -// RUN: tf-mhlo-tfl-opt %s -tfl-parse-mhlo-ops | FileCheck %s +// RUN: odml-to-stablehlo-opt %s -tfl-parse-stablehlo-ops | FileCheck %s module { func.func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> attributes {tf.entry_function = {inputs = "arg0", outputs = "tfl.custom1"}} { - %0 = "tfl.custom"(%arg0, %arg0) {custom_code = "mhlo.add", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %0 = "tfl.custom"(%arg0, %arg0) {custom_code = "stablehlo.maximum", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> func.return %0 : tensor<2xi32> } } // CHECK: module { // CHECK-NEXT: func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> 
attributes {tf.entry_function = {inputs = "arg0", outputs = "tfl.custom1"}} { -// CHECK-NEXT: %0 = mhlo.add %arg0, %arg0 : tensor<2xi32> +// CHECK-NEXT: %0 = stablehlo.maximum %arg0, %arg0 : tensor<2xi32> // CHECK-NEXT: return %0 : tensor<2xi32> // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-max.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-mul.mlir similarity index 57% rename from tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-max.mlir rename to tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-mul.mlir index 0e129dad19c..338657415d1 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-max.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-mul.mlir @@ -1,15 +1,15 @@ -// RUN: tf-mhlo-tfl-opt %s -tfl-parse-mhlo-ops | FileCheck %s +// RUN: odml-to-stablehlo-opt %s -tfl-parse-stablehlo-ops | FileCheck %s module { func.func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> attributes {tf.entry_function = {inputs = "arg0", outputs = "tfl.custom1"}} { - %0 = "tfl.custom"(%arg0, %arg0) {custom_code = "mhlo.maximum", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %0 = "tfl.custom"(%arg0, %arg0) {custom_code = "stablehlo.multiply", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> func.return %0 : tensor<2xi32> } } // CHECK: module { // CHECK-NEXT: func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> attributes {tf.entry_function = {inputs = "arg0", outputs = "tfl.custom1"}} { -// CHECK-NEXT: %0 = mhlo.maximum %arg0, %arg0 : tensor<2xi32> +// CHECK-NEXT: %0 = stablehlo.multiply %arg0, %arg0 : tensor<2xi32> // CHECK-NEXT: return %0 : tensor<2xi32> // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-pad.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-pad.mlir new file mode 100644 index 00000000000..482a7f9e176 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-pad.mlir @@ -0,0 +1,16 @@ +// RUN: odml-to-stablehlo-opt %s -tfl-parse-stablehlo-ops | FileCheck -dump-input always %s + +module { + func.func @main(%arg0: tensor<8x128xf32>, %arg1: tensor) -> tensor<11x131xf32> { + %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "stablehlo.pad", custom_option = #tfl} : (tensor<8x128xf32>, tensor) -> tensor<11x131xf32> + func.return %0 : tensor<11x131xf32> + } + } + +// CHECK: module { +// CHECK-NEXT: func @main(%arg0: tensor<8x128xf32>, %arg1: tensor) -> tensor<11x131xf32> { +// CHECK-NEXT: %0 = stablehlo.pad %arg0, %arg1, low = [1, 0], high = [2, 3], interior = [0, 0] : (tensor<8x128xf32>, tensor) -> tensor<11x131xf32> +// CHECK-NEXT: return %0 : tensor<11x131xf32> +// CHECK-NEXT: } +// CHECK-NEXT: } + diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-reshape.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-reshape.mlir new file mode 100644 index 00000000000..76e77c00c59 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-reshape.mlir @@ -0,0 +1,15 @@ +// RUN: odml-to-stablehlo-opt %s -tfl-parse-stablehlo-ops | FileCheck -dump-input always %s + +module { + func.func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { + %0 = "tfl.custom"(%arg0) {custom_code = "stablehlo.reshape", custom_option = #tfl} : (tensor<2xi32>) -> tensor<2xi32> + func.return %0 : tensor<2xi32> + } +} 
+ +// CHECK: module { +// CHECK-NEXT: func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { +// CHECK-NEXT: %0 = stablehlo.reshape %arg0 : (tensor<2xi32>) -> tensor<2xi32> +// CHECK-NEXT: return %0 : tensor<2xi32> +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-rsqrt.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-rsqrt.mlir new file mode 100644 index 00000000000..5c24fb2e354 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-rsqrt.mlir @@ -0,0 +1,15 @@ +// RUN: odml-to-stablehlo-opt %s -tfl-parse-stablehlo-ops | FileCheck -dump-input always %s + +module { +func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "tfl.custom"(%arg0) {custom_code = "stablehlo.rsqrt", custom_option = #tfl} : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} +} + +// CHECK: module +// CHECK-NEXT: func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> { +// CHECK-NEXT: %0 = stablehlo.rsqrt %arg0 : tensor<2xf32> +// CHECK-NEXT: return %0 : tensor<2xf32> +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-mul.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-sub.mlir similarity index 57% rename from tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-mul.mlir rename to tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-sub.mlir index 6caeed24216..73c3e919163 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-mhlo-mul.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo-sub.mlir @@ -1,15 +1,15 @@ -// RUN: tf-mhlo-tfl-opt %s -tfl-parse-mhlo-ops | FileCheck %s +// RUN: odml-to-stablehlo-opt %s -tfl-parse-stablehlo-ops | FileCheck %s module { func.func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> attributes {tf.entry_function = {inputs = "arg0", outputs = "tfl.custom1"}} { - %0 = "tfl.custom"(%arg0, %arg0) {custom_code = "mhlo.multiply", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %0 = "tfl.custom"(%arg0, %arg0) {custom_code = "stablehlo.subtract", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> func.return %0 : tensor<2xi32> } } // CHECK: module { // CHECK-NEXT: func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> attributes {tf.entry_function = {inputs = "arg0", outputs = "tfl.custom1"}} { -// CHECK-NEXT: %0 = mhlo.multiply %arg0, %arg0 : tensor<2xi32> +// CHECK-NEXT: %0 = stablehlo.subtract %arg0, %arg0 : tensor<2xi32> // CHECK-NEXT: return %0 : tensor<2xi32> // CHECK-NEXT: } // CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo.mlir new file mode 100644 index 00000000000..73dc4307a86 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tfl-stablehlo.mlir @@ -0,0 +1,17 @@ +// RUN: odml-to-stablehlo-opt %s -tfl-parse-stablehlo-ops | FileCheck %s + +module { +func.func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> attributes {tf.entry_function = {inputs = "arg0", outputs = "tfl.custom1"}} { + %0 = "tfl.custom"(%arg0, %arg0) {custom_code = "stablehlo.add", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %1 = "tfl.custom"(%0, %arg0) {custom_code = "stablehlo.subtract", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + func.return %1 : tensor<2xi32> +} +} + +// 
CHECK: module { +// CHECK-NEXT: func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> attributes {tf.entry_function = {inputs = "arg0", outputs = "tfl.custom1"}} { +// CHECK-NEXT: %0 = stablehlo.add %arg0, %arg0 : tensor<2xi32> +// CHECK-NEXT: %1 = stablehlo.subtract %0, %arg0 : tensor<2xi32> +// CHECK-NEXT: return %1 : tensor<2xi32> +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/odml-stablehlo-tfl.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/odml-stablehlo-tfl.mlir index 0bc46ac0ea2..46e0748b348 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/odml-stablehlo-tfl.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/odml-stablehlo-tfl.mlir @@ -2,16 +2,16 @@ module { func.func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { - %0 = mhlo.add %arg0, %arg0 : tensor<2xi32> - %1 = mhlo.subtract %0, %arg0 : tensor<2xi32> + %0 = stablehlo.add %arg0, %arg0 : tensor<2xi32> + %1 = stablehlo.subtract %0, %arg0 : tensor<2xi32> func.return %1 : tensor<2xi32> } } // CHECK: module attributes {tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} { // CHECK-NEXT: func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> attributes {tf.entry_function = {inputs = "arg0", outputs = "tfl.custom1"}} { -// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg0) {custom_code = "mhlo.add", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> -// CHECK-NEXT: %1 = "tfl.custom"(%0, %arg0) {custom_code = "mhlo.subtract", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK-NEXT: %0 = "tfl.custom"(%arg0, %arg0) {custom_code = "stablehlo.add", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> +// CHECK-NEXT: %1 = "tfl.custom"(%0, %arg0) {custom_code = "stablehlo.subtract", custom_option = #tfl} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> // CHECK-NEXT: return %1 : tensor<2xi32> // CHECK-NEXT: } // CHECK-NEXT:} diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/odml-to-stablehlo-allow-tf.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/odml-to-stablehlo-allow-tf.mlir index e2c8d1bfb70..608b90d54a7 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/odml-to-stablehlo-allow-tf.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/odml-to-stablehlo-allow-tf.mlir @@ -1,9 +1,9 @@ -// RUN: odml_to_stablehlo %s --allow-tf=false -o /tmp/temp.mlir; [ ! 
-f /tmp/temp.mlir ] +// RUN: odml_to_stablehlo %s --allow-tf=false -o /tmp/temp.mlir; [ -f /tmp/temp.mlir ]; [ -f /tmp/debug_stablehlo.mlir ] // RUN: odml_to_stablehlo %s --allow-tf=true -o /tmp/temp2.mlir; [ -f /tmp/temp2.mlir ] module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 975 : i32}, tf_saved_model.semantics} { func.func @serving_default(%arg0: tensor<1x20x20x28xf32> {tf_saved_model.index_path = ["a"]}) -> (tensor<1x40x40x28xf32> {tf_saved_model.index_path = ["b"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "c:0", outputs = "d:0"}, tf_saved_model.exported_names = ["serving_default"]} { - %0 = mhlo.constant dense<40> : tensor<2xi32> + %0 = stablehlo.constant dense<40> : tensor<2xi32> %1 = "tf.UnconvertedOp"(%arg0, %0) {align_corners = false, half_pixel_centers = false} : (tensor<1x20x20x28xf32>, tensor<2xi32>) -> tensor<1x40x40x28xf32> func.return %1 : tensor<1x40x40x28xf32> } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/optimize.mlir new file mode 100644 index 00000000000..d59c5488240 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/optimize.mlir @@ -0,0 +1,246 @@ +// RUN: odml-to-stablehlo-opt %s -split-input-file -mhlo-optimize | FileCheck %s + +// CHECK-LABEL: testDotToDotGeneralVectorVector +func.func @testDotToDotGeneralVectorVector(%arg0: tensor<3072xf32>, %arg1: tensor<3072xf32>) -> tensor { + %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<3072xf32>, tensor<3072xf32>) -> tensor + func.return %0 : tensor + +// CHECK: %[[RES:.*]] = "mhlo.dot_general"(%arg0, %arg1) { +// CHECK-SAME: dot_dimension_numbers = #mhlo.dot< +// CHECK-SAME: lhs_contracting_dimensions = [0], +// CHECK-SAME: rhs_contracting_dimensions = [0] +// CHECK-SAME: >} : (tensor<3072xf32>, tensor<3072xf32>) -> tensor +// CHECK: return %[[RES]] : tensor +} + +// ----- + +// CHECK-LABEL: testDotToDotGeneralVectorMatrix +func.func @testDotToDotGeneralVectorMatrix(%arg0: tensor<3072xf32>, %arg1: tensor<3072x512xf32>) -> tensor<512xf32> { + %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<3072xf32>, tensor<3072x512xf32>) -> tensor<512xf32> + func.return %0 : tensor<512xf32> + +// CHECK: %[[RES:.*]] = "mhlo.dot_general"(%arg0, %arg1) { +// CHECK-SAME: dot_dimension_numbers = #mhlo.dot< +// CHECK-SAME: lhs_contracting_dimensions = [0], +// CHECK-SAME: rhs_contracting_dimensions = [0] +// CHECK-SAME: >} : (tensor<3072xf32>, tensor<3072x512xf32>) -> tensor<512xf32> +// CHECK: return %[[RES]] : tensor<512xf32> +} + +// ----- + +// CHECK-LABEL: testDotToDotGeneralMatrixVector +func.func @testDotToDotGeneralMatrixVector(%arg0: tensor<2x3072xf32>, %arg1: tensor<3072xf32>) -> tensor<2xf32> { + %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<2x3072xf32>, tensor<3072xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> + +// CHECK: %[[RES:.*]] = "mhlo.dot_general"(%arg0, %arg1) { +// CHECK-SAME: dot_dimension_numbers = #mhlo.dot< +// CHECK-SAME: lhs_contracting_dimensions = [1], +// CHECK-SAME: rhs_contracting_dimensions = [0] +// CHECK-SAME: >} : (tensor<2x3072xf32>, tensor<3072xf32>) -> tensor<2xf32> +// CHECK: return %[[RES]] : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: testDotToDotGeneralMatrixMatrix +func.func @testDotToDotGeneralMatrixMatrix(%arg0: tensor<2x3072xf32>, %arg1: tensor<3072x512xf32>) -> tensor<2x512xf32> { + %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<2x3072xf32>, tensor<3072x512xf32>) -> tensor<2x512xf32> + func.return %0 : tensor<2x512xf32> + +// CHECK: %[[RES:.*]] = 
"mhlo.dot_general"(%arg0, %arg1) { +// CHECK-SAME: dot_dimension_numbers = #mhlo.dot< +// CHECK-SAME: lhs_contracting_dimensions = [1], +// CHECK-SAME: rhs_contracting_dimensions = [0] +// CHECK-SAME: >} : (tensor<2x3072xf32>, tensor<3072x512xf32>) -> tensor<2x512xf32> +// CHECK: return %[[RES]] : tensor<2x512xf32> +} + +// ----- + +// CHECK-LABEL: testRemoveReshapeAroundDotGeneral +func.func @testRemoveReshapeAroundDotGeneral(%arg0: tensor<3x72x1x2048xf32>, %arg1: tensor<3x2048x512xf32>) -> tensor<3x72x1x512xf32> { + %0 = "mhlo.reshape"(%arg0) : (tensor<3x72x1x2048xf32>) -> tensor<3x72x2048xf32> + %1 = "mhlo.dot_general"(%0, %arg1) { + dot_dimension_numbers = #mhlo.dot< + lhs_batching_dimensions = [0], + rhs_batching_dimensions = [0], + lhs_contracting_dimensions = [2], + rhs_contracting_dimensions = [1] + >} : (tensor<3x72x2048xf32>, tensor<3x2048x512xf32>) -> tensor<3x72x512xf32> + %2 = "mhlo.reshape"(%1) : (tensor<3x72x512xf32>) -> tensor<3x72x1x512xf32> + func.return %2 : tensor<3x72x1x512xf32> + +// CHECK: %[[RES:.*]] = "mhlo.dot_general"(%arg0, %arg1) { +// CHECK-SAME: dot_dimension_numbers = #mhlo.dot< +// CHECK-SAME: lhs_batching_dimensions = [0], +// CHECK-SAME: rhs_batching_dimensions = [0], +// CHECK-SAME: lhs_contracting_dimensions = [3], +// CHECK-SAME: rhs_contracting_dimensions = [1] +// CHECK-SAME: >} : (tensor<3x72x1x2048xf32>, tensor<3x2048x512xf32>) -> tensor<3x72x1x512xf32> +// CHECK: return %[[RES]] : tensor<3x72x1x512xf32> +} + +// ----- + +// CHECK-LABEL: testRemoveReshapeAroundDot +func.func @testRemoveReshapeAroundDot(%arg0: tensor<1x1x512xf32>, %arg1: tensor<512x13x!quant.uniform>) -> tensor<1x1x13xf32> { + %0 = "mhlo.reshape"(%arg0) : (tensor<1x1x512xf32>) -> tensor<1x512xf32> + %1 = "mhlo.dot"(%0, %arg1) : (tensor<1x512xf32>, tensor<512x13x!quant.uniform>) -> tensor<1x13xf32> + %2 = "mhlo.reshape"(%1) : (tensor<1x13xf32>) -> tensor<1x1x13xf32> + func.return %2 : tensor<1x1x13xf32> + +// CHECK: %[[RES:.*]] = "mhlo.dot_general"(%arg0, %arg1) { +// CHECK-SAME: dot_dimension_numbers = #mhlo.dot< +// CHECK-SAME: lhs_contracting_dimensions = [2], +// CHECK-SAME: rhs_contracting_dimensions = [0] +// CHECK-SAME: >} : (tensor<1x1x512xf32>, tensor<512x13x!quant.uniform>) -> tensor<1x1x13xf32> +// CHECK: return %[[RES]] : tensor<1x1x13xf32> +} + +// ----- + +// CHECK-LABEL: testLiftDotConcatLHSSimple +func.func @testLiftDotConcatLHSSimple(%arg0: tensor<1x1x512xf32>, %arg1: tensor<2x1x512xf32>, %arg2: tensor<3x1x512xf32>, %arg3: tensor<512x13xf32>) -> tensor<6x1x13xf32> { + %0 = "mhlo.dot_general"(%arg0, %arg3) { + dot_dimension_numbers = #mhlo.dot< + lhs_contracting_dimensions = [2], + rhs_contracting_dimensions = [0] + >} : (tensor<1x1x512xf32>, tensor<512x13xf32>) -> tensor<1x1x13xf32> + %1 = "mhlo.dot_general"(%arg1, %arg3) { + dot_dimension_numbers = #mhlo.dot< + lhs_contracting_dimensions = [2], + rhs_contracting_dimensions = [0] + >} : (tensor<2x1x512xf32>, tensor<512x13xf32>) -> tensor<2x1x13xf32> + %2 = "mhlo.dot_general"(%arg2, %arg3) { + dot_dimension_numbers = #mhlo.dot< + lhs_contracting_dimensions = [2], + rhs_contracting_dimensions = [0] + >} : (tensor<3x1x512xf32>, tensor<512x13xf32>) -> tensor<3x1x13xf32> + %r = "mhlo.concatenate"(%0, %1, %2) {dimension = 0 : i64} : (tensor<1x1x13xf32>, tensor<2x1x13xf32>, tensor<3x1x13xf32>) -> tensor<6x1x13xf32> + func.return %r : tensor<6x1x13xf32> + +// CHECK: %[[R0:.*]] = "mhlo.concatenate"(%arg0, %arg1, %arg2) {dimension = 0 : i64} : (tensor<1x1x512xf32>, tensor<2x1x512xf32>, tensor<3x1x512xf32>) -> 
tensor<6x1x512xf32> +// CHECK: %[[R1:.*]] = "mhlo.dot_general"(%[[R0]], %arg3) { +// CHECK-SAME: dot_dimension_numbers = #mhlo.dot< +// CHECK-SAME: lhs_contracting_dimensions = [2], +// CHECK-SAME: rhs_contracting_dimensions = [0] +// CHECK-SAME: >} : (tensor<6x1x512xf32>, tensor<512x13xf32>) -> tensor<6x1x13xf32> +// CHECK: return %[[R1]] : tensor<6x1x13xf32> +} + +// ----- + +// CHECK-LABEL: testLiftDotConcatLHSComplex +func.func @testLiftDotConcatLHSComplex(%arg0: tensor<1x9x2x3x8x4x10xf32>, %arg1: tensor<1x9x2x3x8x100x10xf32>, %arg2: tensor<9x2x1x5x10x5x8x7xf32>) -> tensor<1x2x3x104x5x5x7xf32> { + %0 = "mhlo.dot_general"(%arg0, %arg2) { + dot_dimension_numbers = #mhlo.dot< + lhs_batching_dimensions = [0, 2], + rhs_batching_dimensions = [2, 1], + lhs_contracting_dimensions = [4, 1, 6], + rhs_contracting_dimensions = [6, 0, 4] + >} : (tensor<1x9x2x3x8x4x10xf32>, tensor<9x2x1x5x10x5x8x7xf32>) -> tensor<1x2x3x4x5x5x7xf32> + %1 = "mhlo.dot_general"(%arg1, %arg2) { + dot_dimension_numbers = #mhlo.dot< + lhs_batching_dimensions = [0, 2], + rhs_batching_dimensions = [2, 1], + lhs_contracting_dimensions = [4, 1, 6], + rhs_contracting_dimensions = [6, 0, 4] + >} : (tensor<1x9x2x3x8x100x10xf32>, tensor<9x2x1x5x10x5x8x7xf32>) -> tensor<1x2x3x100x5x5x7xf32> + %r = "mhlo.concatenate"(%0, %1) {dimension = 3 : i64} : (tensor<1x2x3x4x5x5x7xf32>, tensor<1x2x3x100x5x5x7xf32>) -> tensor<1x2x3x104x5x5x7xf32> + func.return %r : tensor<1x2x3x104x5x5x7xf32> + +// CHECK: %[[R0:.*]] = "mhlo.concatenate"(%arg0, %arg1) {dimension = 5 : i64} : (tensor<1x9x2x3x8x4x10xf32>, tensor<1x9x2x3x8x100x10xf32>) -> tensor<1x9x2x3x8x104x10xf32> +// CHECK: %[[R1:.*]] = "mhlo.dot_general"(%[[R0]], %arg2) { +// CHECK-SAME: dot_dimension_numbers = #mhlo.dot< +// CHECK-SAME: lhs_batching_dimensions = [0, 2], +// CHECK-SAME: rhs_batching_dimensions = [2, 1], +// CHECK-SAME: lhs_contracting_dimensions = [4, 1, 6], +// CHECK-SAME: rhs_contracting_dimensions = [6, 0, 4] +// CHECK-SAME: >} : (tensor<1x9x2x3x8x104x10xf32>, tensor<9x2x1x5x10x5x8x7xf32>) -> tensor<1x2x3x104x5x5x7xf32> +// CHECK: return %[[R1]] : tensor<1x2x3x104x5x5x7xf32> +} + +// ----- + +// CHECK-LABEL: testLiftDotConcatLHSAndRHS +func.func @testLiftDotConcatLHSAndRHS(%arg0: tensor<1x72x128xf32>, %arg1: tensor<1x128x72xf32>, %arg2: tensor<1x72x128xf32>, %arg3: tensor<1x128x72xf32>, %arg4: tensor<1x72x128xf32>, %arg5: tensor<1x128x72xf32>, %arg6: tensor<1x72x128xf32>, %arg7: tensor<1x128x72xf32>) -> tensor<4x72x72xf32> { + %0 = "mhlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #mhlo.dot< + lhs_batching_dimensions = [0], + rhs_batching_dimensions = [0], + lhs_contracting_dimensions = [2], + rhs_contracting_dimensions = [1] + >} : (tensor<1x72x128xf32>, tensor<1x128x72xf32>) -> tensor<1x72x72xf32> + %1 = "mhlo.dot_general"(%arg2, %arg3) { + dot_dimension_numbers = #mhlo.dot< + lhs_batching_dimensions = [0], + rhs_batching_dimensions = [0], + lhs_contracting_dimensions = [2], + rhs_contracting_dimensions = [1] + >} : (tensor<1x72x128xf32>, tensor<1x128x72xf32>) -> tensor<1x72x72xf32> + %2 = "mhlo.dot_general"(%arg4, %arg5) { + dot_dimension_numbers = #mhlo.dot< + lhs_batching_dimensions = [0], + rhs_batching_dimensions = [0], + lhs_contracting_dimensions = [2], + rhs_contracting_dimensions = [1] + >} : (tensor<1x72x128xf32>, tensor<1x128x72xf32>) -> tensor<1x72x72xf32> + %3 = "mhlo.dot_general"(%arg6, %arg7) { + dot_dimension_numbers = #mhlo.dot< + lhs_batching_dimensions = [0], + rhs_batching_dimensions = [0], + lhs_contracting_dimensions = [2], + 
rhs_contracting_dimensions = [1] + >} : (tensor<1x72x128xf32>, tensor<1x128x72xf32>) -> tensor<1x72x72xf32> + %4 = "mhlo.concatenate"(%0, %1, %2, %3) {dimension = 0 : i64} : (tensor<1x72x72xf32>, tensor<1x72x72xf32>, tensor<1x72x72xf32>, tensor<1x72x72xf32>) -> tensor<4x72x72xf32> + func.return %4 : tensor<4x72x72xf32> + +// CHECK: %[[R0:.*]] = "mhlo.concatenate"(%arg0, %arg2, %arg4, %arg6) {dimension = 0 : i64} : (tensor<1x72x128xf32>, tensor<1x72x128xf32>, tensor<1x72x128xf32>, tensor<1x72x128xf32>) -> tensor<4x72x128xf32> +// CHECK: %[[R1:.*]] = "mhlo.concatenate"(%arg1, %arg3, %arg5, %arg7) {dimension = 0 : i64} : (tensor<1x128x72xf32>, tensor<1x128x72xf32>, tensor<1x128x72xf32>, tensor<1x128x72xf32>) -> tensor<4x128x72xf32> +// CHECK: %[[R2:.*]] = "mhlo.dot_general"(%[[R0]], %[[R1]]) { +// CHECK-SAME: dot_dimension_numbers = #mhlo.dot< +// CHECK-SAME: lhs_batching_dimensions = [0], +// CHECK-SAME: rhs_batching_dimensions = [0], +// CHECK-SAME: lhs_contracting_dimensions = [2], +// CHECK-SAME: rhs_contracting_dimensions = [1] +// CHECK-SAME: >} : (tensor<4x72x128xf32>, tensor<4x128x72xf32>) -> tensor<4x72x72xf32> +// CHECK: return %[[R2]] : tensor<4x72x72xf32> +} + +// ----- + +// CHECK-LABEL: testSliceConcat +func.func @testSliceConcat(%arg0: tensor<3x1x512xf32>) -> tensor<3x1x512xf32> { + %0 = "mhlo.slice"(%arg0) {limit_indices = dense<[1, 1, 512]> : tensor<3xi64>, start_indices = dense<[0, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<3x1x512xf32>) -> tensor<1x1x512xf32> + %1 = "mhlo.slice"(%arg0) {limit_indices = dense<[2, 1, 512]> : tensor<3xi64>, start_indices = dense<[1, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<3x1x512xf32>) -> tensor<1x1x512xf32> + %2 = "mhlo.slice"(%arg0) {limit_indices = dense<[3, 1, 512]> : tensor<3xi64>, start_indices = dense<[2, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<3x1x512xf32>) -> tensor<1x1x512xf32> + %r = "mhlo.concatenate"(%0, %1, %2) {dimension = 0 : i64} : (tensor<1x1x512xf32>, tensor<1x1x512xf32>, tensor<1x1x512xf32>) -> tensor<3x1x512xf32> + func.return %r : tensor<3x1x512xf32> + +// CHECK: return %arg0 : tensor<3x1x512xf32> +} + +// ----- + +// CHECK-LABEL: testConvertReshapeDotRhsToBatchedDot +func.func @testConvertReshapeDotRhsToBatchedDot(%arg0: tensor<1x72x72xf32>, %arg1: tensor<1x72x128xf32>) -> tensor<1x72x128xf32> { + %0 = mhlo.reshape %arg1 : (tensor<1x72x128xf32>) -> tensor<72x128xf32> + %1 = "mhlo.dot_general"(%arg0, %0) { + dot_dimension_numbers = #mhlo.dot< + lhs_contracting_dimensions = [2], + rhs_contracting_dimensions = [0] + >} : (tensor<1x72x72xf32>, tensor<72x128xf32>) -> tensor<1x72x128xf32> + func.return %1 : tensor<1x72x128xf32> + +// CHECK: %[[R:.*]] = "mhlo.dot_general"(%arg0, %arg1) { +// CHECK-SAME: dot_dimension_numbers = #mhlo.dot< +// CHECK-SAME: lhs_batching_dimensions = [0], +// CHECK-SAME: rhs_batching_dimensions = [0], +// CHECK-SAME: lhs_contracting_dimensions = [2], +// CHECK-SAME: rhs_contracting_dimensions = [1] +// CHECK-SAME: >} : (tensor<1x72x72xf32>, tensor<1x72x128xf32>) -> tensor<1x72x128xf32> +// CHECK: return %[[R]] : tensor<1x72x128xf32> +} diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/tf-tfl-translate-serialize-stablehlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/tf-tfl-translate-serialize-stablehlo.mlir new file mode 100644 index 00000000000..d131149d7d1 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/tf-tfl-translate-serialize-stablehlo.mlir @@ -0,0 +1,22 @@ +//RUN: 
tf_tfl_translate --enable-stablehlo-conversion --input-mlir %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s + + +module { +func.func @tfInplaceUpdate(%arg0: tensor<2x1x2xf32>) -> tensor<2x1x2xf32> { + %1 = arith.constant dense<1> : tensor<1xi32> + %2 = arith.constant dense<2.0> : tensor<1x1x2xf32> + %3 = "tf.InplaceUpdate"(%arg0, %1, %2) {device = ""} + : (tensor<2x1x2xf32>, tensor<1xi32>, tensor<1x1x2xf32>) -> tensor<2x1x2xf32> + func.return %3 : tensor<2x1x2xf32> +} +} + +//CHECK: module attributes {tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} { +//CHECK-NEXT: func.func @main(%arg0: tensor<2x1x2xf32>) -> tensor<2x1x2xf32> attributes {tf.entry_function = {inputs = "arg0", outputs = "tfl.custom3"}} { +//CHECK-NEXT: %[[cst0:.*]] = "tfl.custom"() {custom_code = "stablehlo.constant", custom_option = #tfl} : () -> tensor +//CHECK-NEXT: %[[cst1:.*]] = "tfl.custom"() {custom_code = "stablehlo.constant", custom_option = #tfl} : () -> tensor +//CHECK-NEXT: %[[cst2:.*]] = "tfl.custom"() {custom_code = "stablehlo.constant", custom_option = #tfl} : () -> tensor<1x1x2xf32> +//CHECK-NEXT: %[[dus:.*]] = "tfl.custom"(%arg0, %[[cst2]], %[[cst0]], %[[cst1]], %[[cst1]]) {custom_code = "stablehlo.dynamic_update_slice", custom_option = #tfl} : (tensor<2x1x2xf32>, tensor<1x1x2xf32>, tensor, tensor, tensor) -> tensor<2x1x2xf32> +//CHECK-NEXT: return %[[dus]] : tensor<2x1x2xf32> +//CHECK-NEXT: } +//CHECK-NEXT:} diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/tf-tfl-translate-tf-quantize.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/tf-tfl-translate-tf-quantize.mlir index 748037580c8..d23de7ce50c 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/tf-tfl-translate-tf-quantize.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/tf-tfl-translate-tf-quantize.mlir @@ -13,10 +13,10 @@ func.func @tfInplaceUpdate(%arg0: tensor<2x1x2xf32>) -> tensor<2x1x2xf32> { //CHECK: module { //CHECK-NEXT: func.func @main(%arg0: tensor<2x1x2xf32>) -> tensor<2x1x2xf32> { -//CHECK-DAG: %0 = mhlo.constant dense<2.000000e+00> : tensor<1x1x2xf32> -//CHECK-DAG: %1 = mhlo.constant dense<1> : tensor -//CHECK-DAG: %2 = mhlo.constant dense<0> : tensor -//CHECK-NEXT: %3 = mhlo.dynamic_update_slice %arg0, %0, %1, %2, %2 : (tensor<2x1x2xf32>, tensor<1x1x2xf32>, tensor, tensor, tensor) -> tensor<2x1x2xf32> +//CHECK-DAG: %0 = stablehlo.constant dense<2.000000e+00> : tensor<1x1x2xf32> +//CHECK-DAG: %1 = stablehlo.constant dense<1> : tensor +//CHECK-DAG: %2 = stablehlo.constant dense<0> : tensor +//CHECK-NEXT: %3 = stablehlo.dynamic_update_slice %arg0, %0, %1, %2, %2 : (tensor<2x1x2xf32>, tensor<1x1x2xf32>, tensor, tensor, tensor) -> tensor<2x1x2xf32> //CHECK-NEXT: return %3 : tensor<2x1x2xf32> //CHECK-NEXT: } -//CHECK-NEXT:} \ No newline at end of file +//CHECK-NEXT:} diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/unfuse_mhlo_batch_norm.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/unfuse_mhlo_batch_norm.mlir new file mode 100644 index 00000000000..6a25c98a768 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/unfuse_mhlo_batch_norm.mlir @@ -0,0 +1,126 @@ +// RUN: odml-to-stablehlo-opt %s -unfuse-mhlo-batch-norm-pass -cse -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: @batchNormInference_2D_inner_features +// CHECK-SAME: %[[X:[^:[:space:]]+]] +// CHECK-SAME: %[[SCALE:[^:[:space:]]+]] +// CHECK-SAME: %[[OFFSET:[^:[:space:]]+]] +// CHECK-SAME: %[[MEAN:[^:[:space:]]+]] +// CHECK-SAME: %[[VARIANCE:[^:[:space:]]+]] +func.func 
@batchNormInference_2D_inner_features( + %x: tensor<4x256xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>, + %mean: tensor<256xf32>, %variance: tensor<256xf32>) + -> (tensor<4x256xf32>) { + // CHECK-DAG: %[[EPS_BCAST:.+]] = mhlo.constant dense<1.001000e-05> : tensor<256xf32> + // CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<256xf32> + // CHECK-DAG: %[[VARIANCE_EPS_RSQRT:.+]] = mhlo.rsqrt %[[VARIANCE_EPS]] : tensor<256xf32> + // CHECK-DAG: %[[MULTIPLIER:.+]] = mhlo.multiply %[[VARIANCE_EPS_RSQRT]], %[[SCALE]] : tensor<256xf32> + // CHECK-DAG: %[[MUL_MEAN:.+]] = mhlo.multiply %[[MULTIPLIER]], %[[MEAN]] : tensor<256xf32> + // CHECK-DAG: %[[RHS:.+]] = mhlo.subtract %[[OFFSET]], %[[MUL_MEAN]] : tensor<256xf32> + // CHECK-DAG: %[[MULTIPLIER_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[MULTIPLIER]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: %[[RHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[RHS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: %[[X_NORMED:.+]] = mhlo.multiply %[[X]], %[[MULTIPLIER_BCAST]] : tensor<4x256xf32> + // CHECK-DAG: %[[RESULT:.+]] = mhlo.add %[[X_NORMED]], %[[RHS_BCAST]] : tensor<4x256xf32> + %0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) + {epsilon = 1.001000e-05 : f32, feature_index = 1 : i64} : + (tensor<4x256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>, + tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: return %[[RESULT]] + func.return %0 : tensor<4x256xf32> +} + +// CHECK-LABEL: @batchNormInference_4D_middle_features +// CHECK-SAME: %[[X:[^:[:space:]]+]] +// CHECK-SAME: %[[SCALE:[^:[:space:]]+]] +// CHECK-SAME: %[[OFFSET:[^:[:space:]]+]] +// CHECK-SAME: %[[MEAN:[^:[:space:]]+]] +// CHECK-SAME: %[[VARIANCE:[^:[:space:]]+]] +func.func @batchNormInference_4D_middle_features( + %x: tensor<3x4x256x6xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>, + %mean: tensor<256xf32>, %variance: tensor<256xf32>) + -> (tensor<3x4x256x6xf32>) { + // CHECK-DAG: %[[EPS_BCAST:.+]] = mhlo.constant dense<1.001000e-05> : tensor<256xf32> + // CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<256xf32> + // CHECK-DAG: %[[VARIANCE_EPS_RSQRT:.+]] = mhlo.rsqrt %[[VARIANCE_EPS]] : tensor<256xf32> + // CHECK-DAG: %[[MULTIPLIER:.+]] = mhlo.multiply %[[VARIANCE_EPS_RSQRT]], %[[SCALE]] : tensor<256xf32> + // CHECK-DAG: %[[MUL_MEAN:.+]] = mhlo.multiply %[[MULTIPLIER]], %[[MEAN]] : tensor<256xf32> + // CHECK-DAG: %[[RHS:.+]] = mhlo.subtract %[[OFFSET]], %[[MUL_MEAN]] : tensor<256xf32> + // CHECK-DAG: %[[MULTIPLIER_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[MULTIPLIER]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<3x4x256x6xf32> + // CHECK-DAG: %[[RHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[RHS]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<3x4x256x6xf32> + %0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) + {epsilon = 1.001000e-05 : f32, feature_index = 2 : i64} : + (tensor<3x4x256x6xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>, + tensor<256xf32>) -> tensor<3x4x256x6xf32> + func.return %0 : tensor<3x4x256x6xf32> +} + +// CHECK-LABEL: @batchNormInference_dynamic_shape +// Validate that dynamic shapes are handled properly. 
+// CHECK-SAME: %[[X:[^:[:space:]]+]] +// CHECK-SAME: %[[SCALE:[^:[:space:]]+]] +// CHECK-SAME: %[[OFFSET:[^:[:space:]]+]] +// CHECK-SAME: %[[MEAN:[^:[:space:]]+]] +// CHECK-SAME: %[[VARIANCE:[^:[:space:]]+]] +func.func @batchNormInference_dynamic_shape( + %x: tensor, %scale: tensor, %offset: tensor, + %mean: tensor, %variance: tensor) + -> tensor { + // CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e-03> : tensor + // CHECK-DAG: %[[VAR_SHAPE:.+]] = shape.shape_of %[[VARIANCE]] : tensor -> tensor<1xindex> + // CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[VAR_SHAPE]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor + // CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor + // CHECK-DAG: %[[R_STDDEV:.+]] = mhlo.rsqrt %[[VARIANCE_EPS]] : tensor + // CHECK-DAG: %[[MULTIPLIER:.+]] = mhlo.multiply %[[R_STDDEV]], %[[SCALE]] : tensor + // CHECK-DAG: %[[MUL_MEAN:.+]] = mhlo.multiply %[[MULTIPLIER]], %[[MEAN]] : tensor + // CHECK-DAG: %[[RHS:.+]] = mhlo.subtract %[[OFFSET]], %[[MUL_MEAN]] : tensor + // CHECK-DAG: %[[X_SHAPE:.+]] = shape.shape_of %[[X]] : tensor -> tensor<4xindex> + // CHECK-DAG: %[[MULTIPLIER_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[MULTIPLIER]], %[[X_SHAPE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[RHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[RHS]], %[[X_SHAPE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[X_NORMED:.+]] = mhlo.multiply %[[X]], %[[MULTIPLIER_BCAST]] : tensor + // CHECK-DAG: %[[RESULT:.+]] = mhlo.add %[[X_NORMED]], %[[RHS_BCAST]] : tensor + %0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) + {epsilon = 0.001 : f32, feature_index = 1 : i64} : + (tensor, tensor, tensor, tensor, + tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: @batchNormInference_f64 +// Validate that epsilon is properly promoted to f64 +// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e+00> : tensor<256xf64> +func.func @batchNormInference_f64( + %x: tensor<4x256xf64>, %scale: tensor<256xf64>, %offset: tensor<256xf64>, + %mean: tensor<256xf64>, %variance: tensor<256xf64>) + -> (tensor<4x256xf64>) { + %0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) + {epsilon = 1.0 : f32, feature_index = 1 : i64} : + (tensor<4x256xf64>, tensor<256xf64>, tensor<256xf64>, tensor<256xf64>, + tensor<256xf64>) -> tensor<4x256xf64> + func.return %0 : tensor<4x256xf64> +} + +// CHECK-LABEL: @batchNormInference_f16 +// Validate that epsilon is properly down to f16 +// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e+00> : tensor<256xf16> +func.func @batchNormInference_f16( + %x: tensor<4x256xf16>, %scale: tensor<256xf16>, %offset: tensor<256xf16>, + %mean: tensor<256xf16>, %variance: tensor<256xf16>) + -> (tensor<4x256xf16>) { + %0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) + {epsilon = 1.0 : f32, feature_index = 1 : i64} : + (tensor<4x256xf16>, tensor<256xf16>, tensor<256xf16>, tensor<256xf16>, + tensor<256xf16>) -> tensor<4x256xf16> + func.return %0 : tensor<4x256xf16> +} + +// Validate that epsilon is overflow +func.func @batchNormInference_f16_overflow( + %x: tensor<4x256xf16>, %scale: tensor<256xf16>, %offset: tensor<256xf16>, + %mean: tensor<256xf16>, %variance: tensor<256xf16>) + -> (tensor<4x256xf16>) { + // expected-warning @+1 {{Could not convert batch_norm epsilon to 
target fp type: opStatus = 24}} + %0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) + {epsilon = 0.00000001 : f32, feature_index = 1 : i64} : + (tensor<4x256xf16>, tensor<256xf16>, tensor<256xf16>, tensor<256xf16>, + tensor<256xf16>) -> tensor<4x256xf16> + func.return %0 : tensor<4x256xf16> +} diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/check_accepted_ops_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/check_accepted_ops_pass.cc index 1cf7083332c..159d6529163 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/check_accepted_ops_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/check_accepted_ops_pass.cc @@ -19,7 +19,11 @@ limitations under the License. #include #include -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/mhlo_util.h" +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.h" namespace mlir { namespace odml { @@ -33,7 +37,7 @@ class CheckAcceptedOpsPass explicit CheckAcceptedOpsPass( const std::vector &optional_accepted_dialects) - : accepted_dialects_(TFL::mhlo::GetAcceptedDialects()), + : accepted_dialects_(GetAcceptedDialects()), optional_accepted_dialects_(optional_accepted_dialects) {} // Check if TF dialect ops exist over the module. @@ -49,10 +53,10 @@ void CheckAcceptedOpsPass::runOnOperation() { getOperation()->walk([&](Operation *op) { auto dialect_name = op->getDialect()->getNamespace(); auto op_name = op->getName().stripDialect(); - if (TFL::mhlo::IsAcceptedOp(dialect_name, op_name, accepted_dialects_)) { + if (IsAcceptedOp(dialect_name, op_name, accepted_dialects_)) { // If given op is in the `accepted_dialects_`, it's ok. - } else if (TFL::mhlo::IsAcceptedOp(dialect_name, op_name, - optional_accepted_dialects_)) { + } else if (IsAcceptedOp(dialect_name, op_name, + optional_accepted_dialects_)) { // If the given op is in the `optional_accepted_dialects_`, let's warn it. op->emitWarning() << op->getName().getStringRef() << " op is temporarily " << "accepted, but it should be removed in the end."; diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/drop_savedmodel_semantics.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/drop_savedmodel_semantics.cc index 0d45c812a23..d9664a681b6 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/drop_savedmodel_semantics.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/drop_savedmodel_semantics.cc @@ -26,8 +26,13 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" namespace mlir { -namespace TFL { -namespace mhlo { +namespace odml { +namespace { + +using ::mlir::tf_saved_model::kTfSavedModelExportedNamesAttr; +using ::mlir::tf_saved_model::kTfSavedModelIndexPathAttr; + +} // namespace class DropSavedModelSemanticsPass : public PassWrapper> { @@ -47,9 +52,8 @@ class DropSavedModelSemanticsPass // Clean up functions from tf_saved_model attributes. 
OpBuilder builder(module); auto bound_input = builder.getStringAttr("tf_saved_model.bound_input"); - auto exported_names = - builder.getStringAttr("tf_saved_model.exported_names"); - auto index_path = builder.getStringAttr("tf_saved_model.index_path"); + auto exported_names = builder.getStringAttr(kTfSavedModelExportedNamesAttr); + auto index_path = builder.getStringAttr(kTfSavedModelIndexPathAttr); module.walk([&](func::FuncOp func) { func->removeAttr(exported_names); for (unsigned i = 0, e = func.getNumArguments(); i != e; ++i) { @@ -75,6 +79,5 @@ std::unique_ptr CreateDropSavedModelSemanticsPass() { static PassRegistration pass; -} // namespace mhlo -} // namespace TFL +} // namespace odml } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/drop_savedmodel_semantics.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/drop_savedmodel_semantics.h index fa8ce65d88f..444a3c4632f 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/drop_savedmodel_semantics.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/drop_savedmodel_semantics.h @@ -21,13 +21,11 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project namespace mlir { -namespace TFL { -namespace mhlo { +namespace odml { std::unique_ptr CreateDropSavedModelSemanticsPass(); -} // namespace mhlo -} // namespace TFL +} // namespace odml } // namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_DROP_SAVEDMODEL_SEMANTICS_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/fold_broadcast_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/fold_broadcast_pass.cc new file mode 100644 index 00000000000..88ddc74e75c --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/fold_broadcast_pass.cc @@ -0,0 +1,259 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include + +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/APSInt.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir { +namespace odml { + +static const APFloat &addSign(const APFloat &v, Type) { return v; } +static APSInt addSign(const APInt &v, Type t) { + // Add signedness information to the value, treating signless as signed, + // unless it's i1. + return APSInt(v, t.isUnsignedInteger() || t.isSignlessInteger(1)); +} + +// Helper method that given 'shape' and 'current_index' representing +// index in broadcasted tensor, get the index in the flat original tensor. +// 'shape' is computed from the original shape and the broadcast dimensions to +// match result shape. +int64_t GetElementIndex(llvm::SmallVectorImpl &shape, + llvm::SmallVectorImpl ¤t_index) { + int64_t ind = 0; + int64_t mul = 1; + for (int i = shape.size() - 1; i >= 0; --i) { + ind += (current_index[i] % shape[i]) * mul; + mul *= shape[i]; + } + return ind; +} + +// Helper method that increment index represented in 'current_index_ptr' +// in the shape of 'result_shape'. +void IncrementIndex(ArrayRef result_shape, + llvm::SmallVectorImpl ¤t_index) { + for (int i = result_shape.size() - 1; i >= 0; --i) { + current_index[i]++; + if (current_index[i] == result_shape[i]) { + current_index[i] = 0; + } else { + break; + } + } +} + +template +Attribute ConstFoldBroadcastInDim(ShapedType result_type, + DenseElementsAttr operand, + DenseIntElementsAttr bcast_dims) { + auto dimensions = llvm::to_vector(bcast_dims.getValues()); + const auto result_shape = result_type.getShape(); + // Index for the broadcasted matrix. + llvm::SmallVector current_index(result_type.getRank(), 0); + // Computes the new operand shape using the original shape and the broadcast + // dimensions to match result shape. 
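+  // Illustrative sketch of the index math below (hypothetical shapes, not
+  // taken from any test): broadcasting an operand of shape [3] to a result
+  // of shape [2x3x4] with broadcast_dimensions = [1] gives operand_new_shape
+  // = [1, 3, 1]. For the result index [1, 2, 3], GetElementIndex computes
+  // (3 % 1) * 1 + (2 % 3) * 1 + (1 % 1) * 3 = 2, i.e. the broadcasted
+  // dimensions collapse via the modulo and only the original dimension
+  // contributes to the flat operand index.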
+  llvm::SmallVector operand_new_shape(result_type.getRank(), 1);
+  for (int i = 0; i < dimensions.size(); ++i) {
+    operand_new_shape[dimensions[i]] = operand.getType().getDimSize(i);
+  }
+
+  llvm::SmallVector new_values;
+  auto num_elements = result_type.getNumElements();
+  new_values.reserve(num_elements);
+  auto operand_values = operand.getValues();
+  for (int64_t i = 0; i < num_elements; ++i) {
+    const int64_t operand_index =
+        GetElementIndex(operand_new_shape, current_index);
+    new_values.push_back(*(operand_values.begin() + operand_index));
+    IncrementIndex(result_shape, current_index);
+  }
+  return DenseElementsAttr::get(result_type,
+                                ArrayRef(new_values));
+}
+
+template
+static Attribute BinaryFolder(Op *op) {
+  auto lhs_op = op->getLhs().template getDefiningOp();
+  auto rhs_op = op->getRhs().template getDefiningOp();
+  if (!lhs_op || !rhs_op) return {};
+
+  auto lhs = dyn_cast_or_null(lhs_op.getValue());
+  auto rhs = dyn_cast_or_null(rhs_op.getValue());
+  if (!lhs || !rhs) return {};
+
+  ShapedType type = op->getType().template cast();
+  if (!type.hasStaticShape()) {
+    return {};
+  }
+
+  Type etype = type.getElementType();
+
+  // Evaluate for element types.
+  if (!etype.isa()) {
+    return {};
+  }
+
+  // Special case for folding splats no matter how large.
+  // Only covers the case of both attrs being splats; operation-specific cases
+  // like adding a zero or multiplying by one are handled elsewhere.
+  SplatElementsAttr splatLhs = lhs.template dyn_cast();
+  SplatElementsAttr splatRhs = rhs.template dyn_cast();
+  if (splatLhs && splatRhs) {
+    auto signedLhs = addSign(splatLhs.getSplatValue(), etype);
+    auto signedRhs = addSign(splatRhs.getSplatValue(), etype);
+    FailureOr result(Convert()(signedLhs, signedRhs));
+    return succeeded(result) ? SplatElementsAttr::get(type, *result)
+                             : Attribute();
+  }
+
+  SmallVector values;
+  values.reserve(lhs.getNumElements());
+  for (const auto zip : llvm::zip(lhs.template getValues(),
+                                  rhs.template getValues())) {
+    auto signedLhs = addSign(std::get<0>(zip), etype);
+    auto signedRhs = addSign(std::get<1>(zip), etype);
+    FailureOr result(Convert()(signedLhs, signedRhs));
+    if (failed(result)) {
+      return {};
+    }
+    values.push_back(std::move(*result));
+  }
+
+  return DenseElementsAttr::get(type, values);
+}
+
+template
+class FoldBroadcastInDimBeforeBinaryElementwiseOp
+    : public OpRewritePattern {
+ public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(BinaryOpType binary_op,
+                                PatternRewriter &rewriter) const override {
+    auto lhs = binary_op.getLhs();
+    auto rhs = binary_op.getRhs();
+    auto lhs_bcast_op = lhs.template getDefiningOp();
+    auto rhs_bcast_op = rhs.template getDefiningOp();
+    if ((lhs_bcast_op && rhs_bcast_op) || (!lhs_bcast_op && !rhs_bcast_op)) {
+      return rewriter.notifyMatchFailure(
+          binary_op, "Operands should have exactly one BroadcastInDim op.");
+    }
+    auto bcast_op = lhs_bcast_op ? 
lhs_bcast_op : rhs_bcast_op; + auto const_op = + bcast_op.getOperand().template getDefiningOp(); + if (!const_op) return failure(); + auto const_val = dyn_cast_or_null(const_op.getValue()); + if (!const_val) return failure(); + + auto result_type = + dyn_cast_or_null(bcast_op.getResult().getType()); + if (!result_type || !result_type.hasStaticShape()) + return rewriter.notifyMatchFailure(binary_op, + "Result type must have static shape."); + + auto bcast_dims = bcast_op.getBroadcastDimensions(); + auto elem_type = const_val.getElementType(); + Attribute result; + if (elem_type.template isa()) { + result = ConstFoldBroadcastInDim(result_type, const_val, + bcast_dims); + } else if (elem_type.template isa()) { + result = ConstFoldBroadcastInDim(result_type, const_val, + bcast_dims); + } else { + return rewriter.notifyMatchFailure(bcast_op, "Unsupported element type."); + } + Value new_const_op = + rewriter.create(bcast_op.getLoc(), result); + rewriter.replaceOp(bcast_op, {new_const_op}); + return success(); + } +}; + +using FoldBroadcastInDimBeforeMulOp = + FoldBroadcastInDimBeforeBinaryElementwiseOp; + +// Constant folds mhlo.mul, this folder doesn't have an upper limit on how many +// elements can be folded. +LogicalResult ConstantFoldMul(mhlo::MulOp op, PatternRewriter &rewriter) { + ShapedType type = op.getType().dyn_cast(); + Type etype = type.getElementType(); + Attribute result = {}; + if (etype.isa()) { + result = + BinaryFolder>( + &op); + } else if (etype.isa()) { + result = + BinaryFolder>( + &op); + } + if (result == Attribute()) return failure(); + Value new_const_op = rewriter.create(op.getLoc(), result); + rewriter.replaceOp(op, {new_const_op}); + return success(); +} + +class FoldBroadcastPass + : public PassWrapper> { + public: + StringRef getArgument() const final { return "constant-fold-broadcast-pass"; } + StringRef getDescription() const final { + return "Constant folds BroadcastInDimOp before binary elementwise ops"; + } + void getDependentDialects(::mlir::DialectRegistry ®istry) const override {} + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + patterns.add(ConstantFoldMul); + if (failed(applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + +std::unique_ptr createFoldBroadcastPass() { + return std::make_unique(); +} + +static PassRegistration pass; + +} // namespace odml +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/fuse_convolution_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/fuse_convolution_pass.cc new file mode 100644 index 00000000000..45c8edc1ec5 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/fuse_convolution_pass.cc @@ -0,0 +1,148 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/utils/validators.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir { +namespace odml { + +class FuseMhloMulAndConvolutionPattern : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mhlo::MulOp mul_op, + PatternRewriter &rewriter) const override { + // Variables for capturing values and attributes used while creating ops. + mhlo::ConvolutionOp conv_op; + mhlo::BroadcastInDimOp broadcast_op; + mhlo::ConstantOp filter; + mhlo::ConstantOp multiplier; + mlir::ElementsAttr filter_value, mul_value; + mlir::DenseIntElementsAttr broadcast_dims; + + // Match and capture values/attributes. + Value lhs = mul_op.getLhs(); + Value rhs = mul_op.getRhs(); + conv_op = lhs.getDefiningOp(); + if (conv_op == nullptr) { + return failure(); + } + filter = conv_op.getRhs().getDefiningOp(); + if (filter == nullptr) { + return failure(); + } + broadcast_op = rhs.getDefiningOp(); + multiplier = + (broadcast_op == nullptr) + ? rhs.getDefiningOp() + : broadcast_op.getOperand().getDefiningOp(); + if (multiplier == nullptr) { + return failure(); + } + auto result_type = OpTrait::util::getBroadcastedType(filter.getType(), + multiplier.getType()); + if (!result_type) { + return rewriter.notifyMatchFailure(mul_op, [&](::mlir::Diagnostic &diag) { + diag << "entities 'filter, multiplier' failed to satisfy constraint: " + "non-broadcastable operands"; + }); + } + filter_value = filter.getValue(); + mul_value = multiplier.getValue(); + // In MHLO, Conv filter is in HWIO format, Depthwise conv filter is in HW1O + // format and backprop input conv filter is in HWOI format. + // Only fuses multiplier if all dimensions other than the out channel + // dimension are equal to 1. 
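+    // For example (illustrative shapes only): a multiplier of shape 1x1x1x64
+    // or a rank-1 multiplier of shape 64 broadcast onto an HWIO filter can be
+    // folded into the filter, while a multiplier of shape 3x3x1x64 cannot,
+    // because its spatial dimensions are not degenerate.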
+ if (!TFL::IsDimensionsDegenerateExceptLastOne( + mul_value.getType().getShape())) { + return rewriter.notifyMatchFailure(mul_op, [&](::mlir::Diagnostic &diag) { + diag << "entities 'mul_value' failed to satisfy constraint: " + "unsupported dimensions"; + }); + } + if (!((*conv_op.getODSResults(0).begin()).hasOneUse())) { + return rewriter.notifyMatchFailure(mul_op, [&](::mlir::Diagnostic &diag) { + diag << "entities 'conv' failed to satisfy constraint: has one use"; + }); + } + + // Rewrite + broadcast_dims = broadcast_op.getBroadcastDimensions(); + if (broadcast_dims == nullptr) { + const auto filter_rank = filter_value.getType().getRank(); + auto dimsType = RankedTensorType::get({1}, rewriter.getIntegerType(64)); + broadcast_dims = DenseIntElementsAttr::get(dimsType, {filter_rank - 1}); + } + Value broadcast_multiplier = rewriter.create( + mul_op.getLoc(), filter.getType(), multiplier, broadcast_dims); + Value new_filter = rewriter.create( + mul_op.getLoc(), filter.getType(), filter, broadcast_multiplier); + Value new_conv = rewriter.create( + mul_op.getLoc(), conv_op.getType(), conv_op.getLhs(), new_filter, + conv_op.getWindowStridesAttr(), conv_op.getPaddingAttr(), + conv_op.getLhsDilationAttr(), conv_op.getRhsDilationAttr(), + conv_op.getWindowReversalAttr(), conv_op.getDimensionNumbers(), + conv_op.getFeatureGroupCount(), conv_op.getBatchGroupCount(), + conv_op.getPrecisionConfigAttr()); + rewriter.replaceOp(mul_op, {new_conv}); + + return success(); + } +}; + +class FuseMhloConvolutionPass + : public PassWrapper> { + public: + StringRef getArgument() const final { return "fuse-mhlo-convolution-pass"; } + StringRef getDescription() const final { + return "Fuses MHLO binary element-wise ops and convolution op"; + } + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + if (failed(applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + +std::unique_ptr createFuseConvolutionPass() { + return std::make_unique(); +} + +static PassRegistration pass; + +} // namespace odml +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/mhlo_tfl_legalize_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/mhlo_tfl_legalize_patterns.td deleted file mode 100644 index 44157c7cd5b..00000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/mhlo_tfl_legalize_patterns.td +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -include "mlir/IR/OpBase.td" -include "mlir/Dialect/Arith/IR/ArithOps.td" -include "mlir/Dialect/Func/IR/FuncOps.td" -include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" -include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" -include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" - - -// Patterns to legalize mhlo to tfl. 
- -// TODO(pulkitb): Both TFL_ConstOp and ConstantOp seem to work. Decide which -// one is more apropos. -def LegalizeConst : Pat<(HLO_ConstantOp $value), - (Arith_ConstantOp $value)>; - -def LegalizeMul : Pat<(HLO_MulOp $lhs, $rhs), - (TFL_MulOp $lhs, $rhs, TFL_AF_None)>; - -def LegalizeDiv : Pat<(HLO_DivOp $lhs, $rhs), - (TFL_DivOp $lhs, $rhs, TFL_AF_None)>; - -def LegalizeAdd : Pat<(HLO_AddOp $lhs, $rhs), - (TFL_AddOp $lhs, $rhs, TFL_AF_None)>; - -def LegalizeSub : Pat<(HLO_SubtractOp $lhs, $rhs), - (TFL_SubOp $lhs, $rhs, TFL_AF_None)>; - -def LegalizeSqrt : Pat<(HLO_SqrtOp $arg), (TFL_SqrtOp $arg)>; - -def LegalizeSelect : Pat<(HLO_SelectOp $cond, $x, $y), - (TFL_SelectOp $cond, $x, $y)>; diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/mhlo_util.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/mhlo_util.cc deleted file mode 100644 index e8ae0140f65..00000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/mhlo_util.cc +++ /dev/null @@ -1,77 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/mhlo_util.h" - -#include -#include - -#include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/StringRef.h" - -namespace mlir { -namespace TFL { -namespace mhlo { - -std::vector GetAcceptedDialects() { - // It returns the default list of accepted dialects. - std::vector accepted_dialects({"mhlo", "builtin", "func"}); - return accepted_dialects; -} - -bool IsAcceptedDialect(llvm::StringRef dialect_name, - const std::vector& accepted_dialects) { - return std::find(accepted_dialects.begin(), accepted_dialects.end(), - dialect_name) != accepted_dialects.end(); -} - -bool IsMhloOpAllowed(StringRef op_name) { - // As per go/compute-ir-ops-v01. - static DenseSet* denylist = new DenseSet{ - // (R2) Part 1: Internal ops. - "bitcast", "fusion", - // (R2) Part 2: Modularity ops. - // NOTE: These ops were proposed to be excluded from Compute IR - // because we didn't want to necessarily tie the specification to MLIR. - // In an MLIR-based implementation such as MHLO, these ops are fine. - // "get_tuple_element", "return", "tuple", - // (R3) Part 1: Distribution ops. - "after_all", "all_gather", "all_reduce", "all_to_all", - "collective_permute", "create_token", "cross-replica-sum", "infeed", - "outfeed", "print", "recv", "reduce_scatter", "replica_id", "send", - "trace", - // (R3) Part 2: Dynamism ops. - "compute_reshape_shape", "cstr_reshapable", "dynamic_broadcast_in_dim", - "dynamic_conv", "dynamic_gather", "dynamic_iota", "dynamic_pad", - "dynamic_reshape", "get_dimension_size", "real_dynamic_slice", - "set_dimension_size" - // NOTE: These ops were proposed to be excluded from Compute IR for now - // because we wanted to unify them with slice and real_dynamic_slice. - // In the meanwhile, they are very practically important to MHLO, - // so we should keep them on the allowlist. 
- // "dynamic-slice", "dynamic-update-slice" - }; - return !denylist->contains(op_name); -} - -bool IsAcceptedOp(llvm::StringRef dialect_name, llvm::StringRef op_name, - const std::vector& accepted_dialects) { - if (!IsAcceptedDialect(dialect_name, accepted_dialects)) return false; - return dialect_name != "mhlo" || IsMhloOpAllowed(op_name); -} - -} // namespace mhlo -} // namespace TFL -} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/mhlo_util.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/mhlo_util.h deleted file mode 100644 index b18081a7fca..00000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/mhlo_util.h +++ /dev/null @@ -1,107 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_MHLO_UTIL_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_MHLO_UTIL_H_ - -#include - -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/None.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/Casting.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project -#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/Block.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/OperationSupport.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Pass/PassRegistry.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "stablehlo/dialect/ChloOps.h" // from @stablehlo -#include "stablehlo/dialect/Register.h" // from @stablehlo -#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" -#include "tensorflow/compiler/mlir/xla/transforms/passes.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" - -namespace mlir { -namespace TFL { -namespace mhlo { - -std::vector GetAcceptedDialects(); - -// Can we find the given `dialect_name` in the `accepted_dialects`? 
-bool IsAcceptedDialect(llvm::StringRef dialect_name, - const std::vector &accepted_dialects); - -// Is MHLO op allowed in the TF to MHLO conversion result? -bool IsMhloOpAllowed(StringRef op_name); - -// The consolidated logic to verify if each final op is acceptable or not. -// Also see `PrintOpStatsPass` and `CheckAcceptedOpsPass`. -bool IsAcceptedOp(llvm::StringRef dialect_name, llvm::StringRef op_name, - const std::vector &accepted_dialects); - -// Adds patterns which map TF Ops to MHLO Ops. -inline void PopulateTFToMhloPatterns( - MLIRContext *context, bool legalize_chlo, - llvm::Optional tf2xla_fallback_device_type, bool prefer_tf2xla, - RewritePatternSet *patterns) { - // Add TF->HLO legalization patterns. - ::mlir::mhlo::PopulateLegalizeTfPatterns(context, patterns); - - // Add TF->TF lowering patterns. - TF::PopulateTFLoweringBeforeHLOPatterns(context, patterns); - - if (tf2xla_fallback_device_type) { - // Adding fallback Tf2XlaPatterns is needed to make the patterns work. - // Add TF->HLO legalization patterns via TF2XLA fallback. - ::mlir::mhlo::PopulateLegalizeTfWithTf2XlaPatterns( - tf2xla_fallback_device_type.getValue(), *patterns, context, - prefer_tf2xla); - } - - // Populate with CHLO->HLO lowerings to account for TF ops legalized to - // client HLO (CHLO) first. - // https://github.com/tensorflow/mlir-hlo - if (legalize_chlo) { - chlo::populateDecomposeChloPatterns(context, patterns); - chlo::populateChloBroadcastingPatterns(context, patterns); - } - // ConstantLike op is convenient to create splat constants, but is - // canonicalized to plain HLO constant if statically shaped. Add the - // canonicalization pattern to pattern list to enable multi-hop lowering. - ::mlir::chlo::ConstantLikeOp::getCanonicalizationPatterns(*patterns, context); -} - -} // namespace mhlo -} // namespace TFL -} // namespace mlir - -#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_MHLO_UTIL_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.cc index 7c4371caf1e..53a3ba563bc 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.cc @@ -24,8 +24,11 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/mhlo_util.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.h" namespace mlir { namespace odml { @@ -38,9 +41,7 @@ class PrintOpStatsPass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrintOpStatsPass) explicit PrintOpStatsPass(raw_ostream *os = &llvm::errs()) - : accepted_dialects_(TFL::mhlo::GetAcceptedDialects()), - os_(os), - total_ops_(0) {} + : accepted_dialects_(GetAcceptedDialects()), os_(os), total_ops_(0) {} // Prints the resultant operation statistics pos_t iterating over the module. void runOnOperation() override; @@ -118,7 +119,7 @@ void PrintOpStatsPass::printSummary() { num_dialect = 0; // Print the number of unconverted ops in the non-accepted dialects. 
for (const auto &dialect_name : sorted_dialect) { - if (!TFL::mhlo::IsAcceptedDialect(dialect_name, accepted_dialects_)) { + if (!IsAcceptedDialect(dialect_name, accepted_dialects_)) { *os_ << absl::StrFormat("%d %s ops", dialect_count_[dialect_name], absl::AsciiStrToUpper(dialect_name)); if (++num_dialect < num_unaccepted) { @@ -130,9 +131,8 @@ void PrintOpStatsPass::printSummary() { *os_ << "\n\n"; for (const auto &op_with_dialect_name : sorted_op) { - if (!TFL::mhlo::IsAcceptedOp(dialect_name_of_[op_with_dialect_name], - op_name_of_[op_with_dialect_name], - accepted_dialects_)) { + if (!IsAcceptedOp(dialect_name_of_[op_with_dialect_name], + op_name_of_[op_with_dialect_name], accepted_dialects_)) { *os_ << absl::StrFormat("- %s: %4d occurrences \n", op_with_dialect_name, op_with_dialect_count_[op_with_dialect_name]); } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize.cc new file mode 100644 index 00000000000..8392c307fb9 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize.cc @@ -0,0 +1,547 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir { +namespace odml { + +// Convert mhlo.dot to mhlo.dot_general. 
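+// For example (sketch, not a lit test): a plain matrix product
+//   %r = "mhlo.dot"(%lhs, %rhs) : (tensor<4x8xf32>, tensor<8x16xf32>) -> tensor<4x16xf32>
+// is rewritten to an equivalent mhlo.dot_general with no batching dimensions,
+// lhs_contracting_dimensions = [1] and rhs_contracting_dimensions = [0].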
+LogicalResult ConvertDotToDotGeneral(mhlo::DotOp op, + PatternRewriter &rewriter) { + auto lhs_type = op.getLhs().getType().cast(); + auto rhs_type = op.getRhs().getType().cast(); + if (!lhs_type.hasRank() || !rhs_type.hasRank()) { + return rewriter.notifyMatchFailure(op, "unsupported unranked input type"); + } + if (lhs_type.getRank() < 1 || 2 < lhs_type.getRank() || + rhs_type.getRank() < 1 || 2 < rhs_type.getRank()) { + return rewriter.notifyMatchFailure( + op, + "unsupported dot operation type; operands must be vectors or " + "matrices"); + } + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getLhs(), op.getRhs(), + mhlo::DotDimensionNumbersAttr::get( + op.getContext(), + /*lhsBatchingDimensions=*/{}, + /*rhsBatchingDimensions=*/{}, + /*lhsContractingDimensions=*/{lhs_type.getRank() - 1}, + /*rhsContractingDimensions=*/{0}), + op.getPrecisionConfigAttr()); + return success(); +} + +// Convert reshape(dot_general(reshape(%y), %z)) to +// dot_general(%y, %z) if possible. +LogicalResult RemoveReshapeAroundDotGeneral(mhlo::ReshapeOp reshape_after, + PatternRewriter &rewriter) { + auto dot = dyn_cast_or_null( + reshape_after.getOperand().getDefiningOp()); + if (!dot) return failure(); + + auto reshape_before = + dyn_cast_or_null(dot.getLhs().getDefiningOp()); + if (!reshape_before) return failure(); + + if (!dot.getLhs().getType().hasStaticShape() || + !dot.getRhs().getType().hasStaticShape() || + !reshape_before.getOperand().getType().hasStaticShape() || + !dot.getType().hasStaticShape() || + !reshape_after.getType().hasStaticShape()) { + return rewriter.notifyMatchFailure(reshape_after, + "dynamic shapes not supported"); + } + + const auto range = [](int64_t begin, int n) { + SmallVector result; + result.reserve(n); + for (int i = 0; i < n; ++i) { + result.push_back(i + begin); + } + return result; + }; + + // We only support the mhlo.dot style input layouts, i.e.: + // lhs: [batch, other dims, contract dims] + // rhs: [batch, contract dims, other dims] + auto dim_nums = dot.getDotDimensionNumbers(); + int batch_dims_count = dim_nums.getLhsBatchingDimensions().size(); + int contracting_dims_count = dim_nums.getLhsContractingDimensions().size(); + if (dim_nums.getLhsBatchingDimensions() != + ArrayRef(range(0, batch_dims_count)) || + dim_nums.getRhsBatchingDimensions() != + ArrayRef(range(0, batch_dims_count)) || + dim_nums.getLhsContractingDimensions() != + ArrayRef( + range(dot.getLhs().getType().getRank() - contracting_dims_count, + contracting_dims_count)) || + dim_nums.getRhsContractingDimensions() != + ArrayRef(range(batch_dims_count, contracting_dims_count))) { + return rewriter.notifyMatchFailure(reshape_after, + "unsupported dot_general layout"); + } + + // (B = batch dims, C = contracting dims, Y/Z = other dims) + // + // This pattern converts: + // %before = "mhlo.reshape"(%lhs : BxY1xC) : BxY2xC + // %dot = "mhlo.dot_general"(%before, %rhs : BxCxZ) : BxY2xZ + // %after = "mhlo.reshape"(%dot) : BxY1xZ + // to: + // %dot : "mhlo.dot_general"(%lhs : BxY1xC, %rhs : BxCxZ) : BxY1xZ + + // Extract B, Y1, C from %lhs. + ArrayRef shape_lhs = + reshape_before.getOperand().getType().getShape(); + ArrayRef shape_b = shape_lhs.take_front(batch_dims_count); + ArrayRef shape_c = shape_lhs.take_back(contracting_dims_count); + ArrayRef shape_y1 = + shape_lhs.drop_front(shape_b.size()).drop_back(shape_c.size()); + + // Check %before shape, and extract Y2 from it. 
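+  // Concrete instance of the shapes above (illustrative only): with
+  // %lhs : 2x3x4x5 (B = 2, Y1 = 3x4, C = 5) reshaped to %before : 2x12x5
+  // (Y2 = 12) and %rhs : 2x5x7, the dot produces 2x12x7 and the trailing
+  // reshape restores 2x3x4x7; the rewrite emits a single dot_general on
+  // %lhs and %rhs that yields 2x3x4x7 directly.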
+ ArrayRef shape_before = reshape_before.getType().getShape(); + if (shape_before.take_front(shape_b.size()) != shape_b || + shape_before.take_back(shape_c.size()) != shape_c) { + return failure(); + } + ArrayRef shape_y2 = + shape_before.drop_front(shape_b.size()).drop_back(shape_c.size()); + + // No need to check %dot; dot_general verifier ensures correct shapes. + // Extract Z from %dot. + ArrayRef shape_z = + dot.getType().getShape().drop_front(shape_b.size() + shape_y2.size()); + + // Check %after shape. + if (reshape_after.getType().getShape() != + ArrayRef(llvm::to_vector( + llvm::concat(shape_b, shape_y1, shape_z)))) { + return failure(); + } + + rewriter.replaceOpWithNewOp( + reshape_after, reshape_after.getType(), reshape_before.getOperand(), + dot.getRhs(), + mhlo::DotDimensionNumbersAttr::get( + reshape_after.getContext(), + /*lhsBatchingDimensions=*/range(0, batch_dims_count), + /*rhsBatchingDimensions=*/range(0, batch_dims_count), + /*lhsContractingDimensions=*/ + range(batch_dims_count + shape_y1.size(), contracting_dims_count), + /*rhsContractingDimensions=*/ + range(batch_dims_count, contracting_dims_count)), + dot.getPrecisionConfigAttr()); + return success(); +} + +// Convert: +// %y0 = dot_general(%x0, %w) +// %y1 = dot_general(%x1, %w) +// ... +// concatenate(%y0, %y1, ...) +// To: +// %x = concatenate(%x0, %x1, ...) +// dot_general(%x, %w) +LogicalResult LiftDotConcatLHS(mhlo::ConcatenateOp concat, + PatternRewriter &rewriter) { + if (concat.getVal().size() < 2) + return rewriter.notifyMatchFailure( + concat, "Concatenate op should have at least two operands"); + + auto first_dot = concat.getVal()[0].getDefiningOp(); + if (!first_dot) + return rewriter.notifyMatchFailure(concat, "Operand is not dot_general"); + if (!first_dot.getLhs().getType().hasStaticShape()) + return rewriter.notifyMatchFailure( + first_dot, "All dot_general LHS must be statically shaped"); + if (!first_dot->hasOneUse()) + return rewriter.notifyMatchFailure(first_dot, "Op has multiple uses"); + + SmallVector all_dot_lhs; + all_dot_lhs.reserve(concat.getVal().size()); + all_dot_lhs.push_back(first_dot.getLhs()); + + const uint64_t batch_dims_count = + first_dot.getDotDimensionNumbers().getLhsBatchingDimensions().size(); + const uint64_t contracting_dims_count = + first_dot.getDotDimensionNumbers().getLhsContractingDimensions().size(); + const uint64_t lhs_other_dims_count = first_dot.getLhs().getType().getRank() - + batch_dims_count - + contracting_dims_count; + + // This pattern only supports concating on LHS other dims (neither batch nor + // contracting). 
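+  // For example (illustrative shapes): with %x0 : BxMxK, %w : BxKxN and
+  // outputs BxMxN, only a concatenation along the M dimension (dimension 1
+  // of the dot_general result) can be lifted; concatenating along the batch
+  // dimension 0 or the RHS-other dimension 2 is rejected below.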
+ if (concat.getDimension() < batch_dims_count || + concat.getDimension() >= batch_dims_count + lhs_other_dims_count) { + return rewriter.notifyMatchFailure(concat, + "Not concating on LHS other dims"); + } + + for (auto value : concat.getVal().drop_front()) { + auto dot = value.getDefiningOp(); + if (!dot) + return rewriter.notifyMatchFailure(concat, "Operand is not dot_general"); + + if (dot.getRhs() != first_dot.getRhs()) + return rewriter.notifyMatchFailure( + dot, "dot_general ops have different rhs parameters"); + if (dot.getDotDimensionNumbers() != first_dot.getDotDimensionNumbers()) + return rewriter.notifyMatchFailure( + dot, "dot_general ops have different dimension numbers"); + if (dot.getPrecisionConfig() != first_dot.getPrecisionConfig()) + return rewriter.notifyMatchFailure( + dot, "dot_general ops have different precision configs"); + if (!dot.getLhs().getType().hasStaticShape()) + return rewriter.notifyMatchFailure( + dot, "all dot_general LHS must be statically shaped"); + if (dot.getLhs().getType().getElementType() != + first_dot.getLhs().getType().getElementType() || + dot.getType().getElementType() != first_dot.getType().getElementType()) + return rewriter.notifyMatchFailure( + dot, "all dot_general ops must have the same element type"); + if (!dot->hasOneUse()) + return rewriter.notifyMatchFailure(dot, "Op has multiple uses"); + + all_dot_lhs.push_back(dot.getLhs()); + } + + const auto is_lhs_batch_or_contracting_dim = [&](uint64_t dim) { + auto dim_nums = first_dot.getDotDimensionNumbers(); + return llvm::is_contained(dim_nums.getLhsBatchingDimensions(), dim) || + llvm::is_contained(dim_nums.getLhsContractingDimensions(), dim); + }; + + // dot_general outputs are always in the + // [batch dims, LHS other dims, RHS other dims] + // layout, so the new concat dim is where the n-th (base-0 counting) LHS other + // dim appears in the original LHS layout, where: + // n = old_concat_dim - batch_dims_count + uint64_t n = concat.getDimension() - batch_dims_count; + + // Now try to answer where the n-th LHS other dim was originally placed. + // This is the dimension we should now concat on. + int new_concat_dim = -1; + for (int i = 0; i < first_dot.getLhs().getType().getRank(); ++i) { + if (!is_lhs_batch_or_contracting_dim(i) && n-- == 0) { + new_concat_dim = i; + break; + } + } + + // Now get the output shape of the lifted concat op. + SmallVector new_concat_shape( + first_dot.getLhs().getType().getShape()); + new_concat_shape[new_concat_dim] = 0; + for (auto v : all_dot_lhs) { + new_concat_shape[new_concat_dim] += + v.getType().dyn_cast().getShape()[new_concat_dim]; + } + + auto new_concat = rewriter.create( + concat->getLoc(), concat.getType().clone(new_concat_shape), all_dot_lhs, + rewriter.getI64IntegerAttr(new_concat_dim)); + rewriter.replaceOpWithNewOp( + concat, concat.getType(), new_concat, first_dot.getRhs(), + first_dot.getDotDimensionNumbers(), first_dot.getPrecisionConfigAttr()); + return success(); +} + +// Convert: +// %y0 = dot_general(%x0, %w0) +// %y1 = dot_general(%x1, %w1) +// ... +// concatenate(%y0, %y1, ...) +// To: +// %x = concatenate(%x0, %x1, ...) +// %w = concatenate(%w0, %w1, ...) +// dot_general(%x, %w) +// +// To simplify the implementation, we only handle the case where the final +// concat is on the only batching dim. 
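+// For example (illustrative shapes): %x0 : 2x4x8 and %x1 : 3x4x8 with
+// %w0 : 2x8x16 and %w1 : 3x8x16 (batching dimension 0, contracting
+// dimension of size 8) produce %y0 : 2x4x16 and %y1 : 3x4x16; concatenating
+// the results on dimension 0 is equivalent to one dot_general over the
+// concatenated inputs 5x4x8 and 5x8x16.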
+LogicalResult LiftDotConcatLHSAndRHS(mhlo::ConcatenateOp concat, + PatternRewriter &rewriter) { + if (concat.getVal().size() < 2) + return rewriter.notifyMatchFailure( + concat, "Concatenate op should have at least two operands"); + + auto first_dot = concat.getVal()[0].getDefiningOp(); + if (!first_dot) + return rewriter.notifyMatchFailure(concat, "Operand is not dot_general"); + if (!first_dot.getLhs().getType().hasStaticShape()) + return rewriter.notifyMatchFailure( + first_dot, "All dot_general LHS must be statically shaped"); + if (!first_dot->hasOneUse()) + return rewriter.notifyMatchFailure(first_dot, "Op has multiple uses"); + + SmallVector all_dot_lhs; + all_dot_lhs.reserve(concat.getVal().size()); + all_dot_lhs.push_back(first_dot.getLhs()); + SmallVector all_dot_rhs; + all_dot_rhs.reserve(concat.getVal().size()); + all_dot_rhs.push_back(first_dot.getRhs()); + + if (first_dot.getDotDimensionNumbers().getLhsBatchingDimensions().size() != 1) + return rewriter.notifyMatchFailure(first_dot, "One batching dim required"); + if (concat.getDimension() != 0) + return rewriter.notifyMatchFailure( + concat, "Not concating on the first batching dim"); + + for (auto value : concat.getVal().drop_front()) { + auto dot = value.getDefiningOp(); + if (!dot) + return rewriter.notifyMatchFailure(concat, "Operand is not dot_general"); + + if (dot.getDotDimensionNumbers() != first_dot.getDotDimensionNumbers()) + return rewriter.notifyMatchFailure( + dot, "dot_general ops have different dimension numbers"); + if (dot.getPrecisionConfig() != first_dot.getPrecisionConfig()) + return rewriter.notifyMatchFailure( + dot, "dot_general ops have different precision configs"); + if (!dot.getLhs().getType().hasStaticShape() || + !dot.getRhs().getType().hasStaticShape()) + return rewriter.notifyMatchFailure( + dot, "all dot_general operands must be statically shaped"); + if (dot.getLhs().getType().getElementType() != + first_dot.getLhs().getType().getElementType() || + dot.getRhs().getType().getElementType() != + first_dot.getRhs().getType().getElementType() || + dot.getType().getElementType() != first_dot.getType().getElementType()) + return rewriter.notifyMatchFailure( + dot, "all dot_general ops must have the same element type"); + if (!dot->hasOneUse()) + return rewriter.notifyMatchFailure(dot, "Op has multiple uses"); + + all_dot_lhs.push_back(dot.getLhs()); + all_dot_rhs.push_back(dot.getRhs()); + } + + // Now get the output shapes of the lifted concat ops. 
+ const int64_t lhs_batch_dim = + first_dot.getDotDimensionNumbers().getLhsBatchingDimensions()[0]; + SmallVector lhs_new_concat_shape( + first_dot.getLhs().getType().getShape()); + lhs_new_concat_shape[lhs_batch_dim] = 0; + for (auto v : all_dot_lhs) { + lhs_new_concat_shape[lhs_batch_dim] += + v.getType().dyn_cast().getShape()[lhs_batch_dim]; + } + const int64_t rhs_batch_dim = + first_dot.getDotDimensionNumbers().getRhsBatchingDimensions()[0]; + SmallVector rhs_new_concat_shape( + first_dot.getRhs().getType().getShape()); + rhs_new_concat_shape[rhs_batch_dim] = 0; + for (auto v : all_dot_rhs) { + rhs_new_concat_shape[rhs_batch_dim] += + v.getType().dyn_cast().getShape()[rhs_batch_dim]; + } + + auto lhs_new_concat = rewriter.create( + concat->getLoc(), concat.getType().clone(lhs_new_concat_shape), + all_dot_lhs, rewriter.getI64IntegerAttr(lhs_batch_dim)); + auto rhs_new_concat = rewriter.create( + concat->getLoc(), concat.getType().clone(rhs_new_concat_shape), + all_dot_rhs, rewriter.getI64IntegerAttr(rhs_batch_dim)); + rewriter.replaceOpWithNewOp( + concat, concat.getType(), lhs_new_concat, rhs_new_concat, + first_dot.getDotDimensionNumbers(), first_dot.getPrecisionConfigAttr()); + return success(); +} + +// Convert: +// %y0 = slice(%x, start=0, limit=2) +// %y1 = slice(%x, start=2, limit=3) +// concat(%y0, %y1, ...) +// To: +// %y = slice(%x, start=0, limit=3) +// concat(%y, ...) +LogicalResult FuseSliceConcat(mhlo::ConcatenateOp concat, + PatternRewriter &rewriter) { + if (concat.getVal().size() < 2) + return rewriter.notifyMatchFailure( + concat, "Concatenate op should have at least two operands"); + + auto first = concat.getVal()[0].getDefiningOp(); + auto second = concat.getVal()[1].getDefiningOp(); + if (!first || !second) + return rewriter.notifyMatchFailure(concat, "operands are not slice ops"); + if (first.getOperand() != second.getOperand()) + return rewriter.notifyMatchFailure(concat, "slice not on the same input"); + if (!first.getStrides().isSplat() || + first.getStrides().getSplatValue().getInt() != 1 || + first.getStrides() != second.getStrides()) + return rewriter.notifyMatchFailure(concat, "slice ops must have stride=1"); + if (!first->hasOneUse() || !second->hasOneUse()) + return rewriter.notifyMatchFailure(concat, "slice ops are used elsewhere"); + + SmallVector new_start; + SmallVector new_limit; + SmallVector new_slice_shape; + new_start.reserve(first.getStrides().size()); + new_limit.reserve(first.getStrides().size()); + new_slice_shape.reserve(first.getStrides().size()); + + for (int i = 0; i < first.getStrides().size(); ++i) { + const int64_t first_start = + first.getStartIndicesAttr().getValues()[i].getInt(); + const int64_t first_limit = + first.getLimitIndicesAttr().getValues()[i].getInt(); + const int64_t second_start = + second.getStartIndicesAttr().getValues()[i].getInt(); + const int64_t second_limit = + second.getLimitIndicesAttr().getValues()[i].getInt(); + + if (i == concat.getDimension()) { + if (first_limit != second_start) + return rewriter.notifyMatchFailure( + second, "slice is not continuous with previous slice"); + } else { + if (first_start != second_start || first_limit != second_limit) + return rewriter.notifyMatchFailure( + second, "non-concat dims have mismatching slice bounds"); + } + + new_start.push_back(first_start); + new_limit.push_back(second_limit); + new_slice_shape.push_back(second_limit - first_start); + } + + auto new_slice = rewriter.create( + FusedLoc::get(first->getContext(), {first.getLoc(), second.getLoc()}), + 
first.getType().clone(new_slice_shape), first.getOperand(), + /*start_indices=*/rewriter.getI64TensorAttr(new_start), + /*limit_indices=*/rewriter.getI64TensorAttr(new_limit), + /*strides=*/first.getStrides()); + + SmallVector new_concat_values; + new_concat_values.reserve(concat.getVal().size() - 1); + new_concat_values.push_back(new_slice); + llvm::append_range(new_concat_values, concat.getVal().drop_front(2)); + + rewriter.replaceOpWithNewOp( + concat, concat.getType(), new_concat_values, concat.getDimension()); + return success(); +} + +// Convert: +// %input : 1xYxC +// %1 = mhlo.reshape %param : (1xCxZ) -> CxZ +// mhlo.dot_general %input, %1 {batch_dims = []} +// To: +// mhlo.dot_general %input, %param {batch_dims = [0]} +// +// This usage will mostly come from tf-unroll-batch-matmul, so it's fine to only +// handle the case where batching dim is the leftmost dim. +LogicalResult ConvertReshapeDotRhsToBatchedDot(mhlo::DotGeneralOp dot, + PatternRewriter &rewriter) { + mhlo::DotDimensionNumbersAttr dim_nums = dot.getDotDimensionNumbers(); + if (!dim_nums.getLhsBatchingDimensions().empty()) return failure(); + + auto reshape = dot.getRhs().getDefiningOp(); + if (!reshape) return failure(); + if (!reshape->hasOneUse()) + return rewriter.notifyMatchFailure(reshape, "reshape has multiple usages"); + if (!reshape.getType().hasStaticShape() || + !reshape.getOperand().getType().hasStaticShape() || + !dot.getLhs().getType().hasStaticShape()) { + return rewriter.notifyMatchFailure(dot, "dynamic shaping not supported"); + } + + ArrayRef orig_param_shape = + reshape.getOperand().getType().getShape(); + ArrayRef dot_param_shape = reshape.getType().getShape(); + if (orig_param_shape.size() != dot_param_shape.size() + 1 || + orig_param_shape.front() != 1) { + return rewriter.notifyMatchFailure(reshape, "unsupported reshape pattern"); + } + + int lhs_first_other_dim = -1; + for (int i = 0; i < dot.getLhs().getType().getRank(); ++i) { + if (!llvm::is_contained(dim_nums.getLhsContractingDimensions(), i)) { + lhs_first_other_dim = i; + break; + } + } + if (lhs_first_other_dim == -1 || + dot.getLhs().getType().getShape()[lhs_first_other_dim] != 1) { + return rewriter.notifyMatchFailure(dot, "unsupported LHS shape"); + } + + SmallVector new_rhs_contracting_dims; + new_rhs_contracting_dims.reserve( + dim_nums.getRhsContractingDimensions().size()); + for (int64_t d : dim_nums.getRhsContractingDimensions()) { + new_rhs_contracting_dims.push_back(d + 1); + } + + rewriter.replaceOpWithNewOp( + dot, dot.getType(), dot.getLhs(), reshape.getOperand(), + mhlo::DotDimensionNumbersAttr::get( + dot.getContext(), + /*lhsBatchingDimensions=*/{lhs_first_other_dim}, + /*rhsBatchingDimensions=*/{0}, + /*lhsContractingDimensions=*/dim_nums.getLhsContractingDimensions(), + /*rhsContractingDimensions=*/new_rhs_contracting_dims), + dot.getPrecisionConfigAttr()); + return success(); +} + +class OptimizePass + : public PassWrapper> { + public: + StringRef getArgument() const final { return "mhlo-optimize"; } + StringRef getDescription() const final { + return "Applies various optimizations on MHLO IR"; + } + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + patterns.add(ConvertDotToDotGeneral); + patterns.add(RemoveReshapeAroundDotGeneral); + patterns.add(LiftDotConcatLHS); + patterns.add(LiftDotConcatLHSAndRHS); + patterns.add(FuseSliceConcat); + patterns.add(ConvertReshapeDotRhsToBatchedDot); + if (failed(applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + return 
signalPassFailure(); + } + } +}; + +std::unique_ptr createOptimizePass() { + return std::make_unique(); +} + +static PassRegistration pass; + +} // namespace odml +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h new file mode 100644 index 00000000000..9da4a1b5008 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h @@ -0,0 +1,43 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_PASSES_H_ + +#include +#include + +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace odml { + +// Creates a pass which unfuses MHLO batch norm inference op into arithmetic +// ops. +std::unique_ptr createUnfuseBatchNormPass(); + +// Creates a pass which constant folds broadcast_in_dim op conditionally. +std::unique_ptr createFoldBroadcastPass(); + +// Creates a pass which fuses MHLO binary element-wise ops and convolution op. +std::unique_ptr createFuseConvolutionPass(); + +// Creates a pass which applies various optimizations on MHLO IR. +std::unique_ptr createOptimizePass(); + +} // namespace odml +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_PASSES_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.cc index f802b161fb9..72ae7cc1c00 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.cc @@ -25,8 +25,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" namespace mlir { -namespace TFL { -namespace mhlo { +namespace odml { class RenameEntrypointToMainPass : public PassWrapper> { @@ -86,6 +85,5 @@ std::unique_ptr CreateRenameEntrypointToMainPass() { static PassRegistration pass; -} // namespace mhlo -} // namespace TFL +} // namespace odml } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.h index 2b1fac5d305..e56b7130132 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.h @@ -21,13 +21,11 @@ limitations under the License. 
#include "mlir/Pass/Pass.h" // from @llvm-project namespace mlir { -namespace TFL { -namespace mhlo { +namespace odml { std::unique_ptr CreateRenameEntrypointToMainPass(); -} // namespace mhlo -} // namespace TFL +} // namespace odml } // namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_RENAME_ENTRYPOINT_TO_MAIN_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/smuggle_disallowed_ops.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/smuggle_disallowed_ops.cc index 7ba5821514d..7b050d281fd 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/smuggle_disallowed_ops.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/smuggle_disallowed_ops.cc @@ -23,19 +23,18 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" namespace mlir { -namespace TFL { -namespace mhlo { +namespace odml { namespace { LogicalResult SmuggleOp(Operation* op, PatternRewriter& rewriter) { auto call_target = rewriter.getNamedAttr("call_target_name", op->getName().getIdentifier()); - auto custom_call = rewriter.create( + auto custom_call = rewriter.create( op->getLoc(), op->getResultTypes(), op->getOperands(), ArrayRef{call_target}); rewriter.replaceOp(op, custom_call.getResults()); @@ -61,7 +60,7 @@ class SmuggleDisallowedOpsPass public: StringRef getArgument() const final { return "smuggle-disallowed-ops-pass"; } StringRef getDescription() const final { - return "Smuggle disallowed ops via mhlo.custom_calls"; + return "Smuggle disallowed ops via stablehlo.custom_calls"; } void runOnOperation() override { @@ -70,7 +69,7 @@ class SmuggleDisallowedOpsPass ConversionTarget target(getContext()); target.addIllegalOp(); - target.addLegalDialect(); + target.addLegalDialect(); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { signalPassFailure(); @@ -84,6 +83,5 @@ std::unique_ptr CreateSmuggleDisallowedOpsPass() { static PassRegistration pass; -} // namespace mhlo -} // namespace TFL +} // namespace odml } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/smuggle_disallowed_ops.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/smuggle_disallowed_ops.h index f46a369628f..61e076e8099 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/smuggle_disallowed_ops.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/smuggle_disallowed_ops.h @@ -21,13 +21,11 @@ limitations under the License. 
#include "mlir/Pass/Pass.h" // from @llvm-project namespace mlir { -namespace TFL { -namespace mhlo { +namespace odml { std::unique_ptr CreateSmuggleDisallowedOpsPass(); -} // namespace mhlo -} // namespace TFL +} // namespace odml } // namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_SMUGGLE_DISALLOWED_OPS_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/mhlo_tfl_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_tfl_pass.cc similarity index 75% rename from tensorflow/compiler/mlir/lite/stablehlo/transforms/mhlo_tfl_pass.cc rename to tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_tfl_pass.cc index c8db922d33f..858fe15a7f4 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/mhlo_tfl_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_tfl_pass.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/mhlo_tfl_pass.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_tfl_pass.h" #include #include @@ -37,41 +37,32 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "stablehlo/dialect/ChloOps.h" // from @stablehlo -#include "stablehlo/dialect/Register.h" // from @stablehlo +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h" namespace mlir { -namespace TFL { -namespace mhlo { +namespace odml { -class MhloToTflPass - : public mlir::PassWrapper> { public: - explicit MhloToTflPass() : PassWrapper() {} - StringRef getArgument() const final { return "mhlo-tfl"; } + explicit StablehloToTflPass() : PassWrapper() {} + StringRef getArgument() const final { return "stablehlo-tfl"; } StringRef getDescription() const final { - return "This pass will legalize MHLO Ops to TFLite custom Ops."; + return "This pass will legalize StableHLO Ops to TFLite custom Ops."; } private: void runOnOperation() override; void getDependentDialects(DialectRegistry& registry) const override { - mlir::mhlo::registerAllMhloDialects(registry); - mlir::stablehlo::registerAllDialects(registry); - registry.insert(); - registry.insert<::mlir::mhlo::MhloDialect>(); - registry.insert(); - registry.insert(); + registry.insert(); } - inline ConstBytesAttr CustomOption(OpBuilder* builder, - const std::string& content) { - return ConstBytesAttr::get(builder->getContext(), - StringRef(content.data(), content.size())); + inline TFL::ConstBytesAttr CustomOption(OpBuilder* builder, + const std::string& content) { + return TFL::ConstBytesAttr::get(builder->getContext(), + StringRef(content.data(), content.size())); } void AddIntegerArray(flexbuffers::Builder* fbb, @@ -84,12 +75,12 @@ class MhloToTflPass } }; -void MhloToTflPass::runOnOperation() { +void StablehloToTflPass::runOnOperation() { func::FuncOp fn = getOperation(); OpBuilder builder(fn.getContext()); fn.walk([&](Operation* op) { - // Process only MHLO ops. 
- if (op->getDialect()->getNamespace() != "mhlo") return; + // Process only StableHLO ops. + if (op->getDialect()->getNamespace() != "stablehlo") return; // Build options. std::string custom_option_buffer; @@ -143,16 +134,16 @@ void MhloToTflPass::runOnOperation() { auto start = fbb->StartVector(key); auto array_attr = attr.dyn_cast(); if (array_attr.size() > 1 && !array_attr[0].isa() && - !array_attr[0].isa()) { + !array_attr[0].isa()) { emitWarning(op->getLoc(), "serialization of ArrayAttr for ") << key << " only supports Strings."; continue; } for (auto value : array_attr) { - if (value.isa()) { + if (value.isa()) { auto string_value = - mlir::mhlo::stringifyPrecision( - value.cast().getValue()) + mlir::stablehlo::stringifyPrecision( + value.cast().getValue()) .data(); fbb->Add(string_value); } else { @@ -165,9 +156,9 @@ void MhloToTflPass::runOnOperation() { continue; } - if (attr.isa<::mlir::mhlo::ConvDimensionNumbersAttr>()) { + if (attr.isa<::mlir::stablehlo::ConvDimensionNumbersAttr>()) { auto dimension_attr = - attr.dyn_cast<::mlir::mhlo::ConvDimensionNumbersAttr>(); + attr.dyn_cast<::mlir::stablehlo::ConvDimensionNumbersAttr>(); auto start = fbb->StartVector(key); fbb->Add(dimension_attr.getInputBatchDimension()); fbb->Add(dimension_attr.getInputFeatureDimension()); @@ -182,9 +173,9 @@ void MhloToTflPass::runOnOperation() { continue; } - if (attr.isa<::mlir::mhlo::GatherDimensionNumbersAttr>()) { + if (attr.isa<::mlir::stablehlo::GatherDimensionNumbersAttr>()) { auto dimension_attr = - attr.dyn_cast<::mlir::mhlo::GatherDimensionNumbersAttr>(); + attr.dyn_cast<::mlir::stablehlo::GatherDimensionNumbersAttr>(); auto start = fbb->StartVector(key); AddIntegerArray(fbb.get(), dimension_attr.getOffsetDims()); AddIntegerArray(fbb.get(), dimension_attr.getCollapsedSliceDims()); @@ -194,9 +185,9 @@ void MhloToTflPass::runOnOperation() { continue; } - if (attr.isa<::mlir::mhlo::ScatterDimensionNumbersAttr>()) { + if (attr.isa<::mlir::stablehlo::ScatterDimensionNumbersAttr>()) { auto dimension_attr = - attr.dyn_cast<::mlir::mhlo::ScatterDimensionNumbersAttr>(); + attr.dyn_cast<::mlir::stablehlo::ScatterDimensionNumbersAttr>(); auto start = fbb->StartVector(key); AddIntegerArray(fbb.get(), dimension_attr.getUpdateWindowDims()); AddIntegerArray(fbb.get(), dimension_attr.getInsertedWindowDims()); @@ -207,9 +198,9 @@ void MhloToTflPass::runOnOperation() { continue; } - if (attr.isa<::mlir::mhlo::DotDimensionNumbersAttr>()) { + if (attr.isa<::mlir::stablehlo::DotDimensionNumbersAttr>()) { auto dimension_attr = - attr.dyn_cast<::mlir::mhlo::DotDimensionNumbersAttr>(); + attr.dyn_cast<::mlir::stablehlo::DotDimensionNumbersAttr>(); auto start = fbb->StartVector(key); AddIntegerArray(fbb.get(), dimension_attr.getLhsBatchingDimensions()); AddIntegerArray(fbb.get(), dimension_attr.getRhsBatchingDimensions()); @@ -221,19 +212,20 @@ void MhloToTflPass::runOnOperation() { continue; } - if (attr.isa<::mlir::mhlo::ComparisonDirectionAttr>()) { + if (attr.isa<::mlir::stablehlo::ComparisonDirectionAttr>()) { auto string_value = - mlir::mhlo::stringifyComparisonDirection( - attr.cast().getValue()) + mlir::stablehlo::stringifyComparisonDirection( + attr.cast() + .getValue()) .str(); fbb->String(key, string_value); continue; } - if (attr.isa<::mlir::mhlo::ComparisonTypeAttr>()) { + if (attr.isa<::mlir::stablehlo::ComparisonTypeAttr>()) { auto string_value = - mlir::mhlo::stringifyComparisonType( - attr.cast().getValue()) + mlir::stablehlo::stringifyComparisonType( + attr.cast().getValue()) .str(); 
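Note on the enum attributes handled above: precision, comparison direction, and comparison type are written into the custom-option flexbuffer as strings via the generated stringify*/symbolize* helpers, so the reverse TFL-to-StableHLO pass can rebuild the enum later. The following is a minimal standalone sketch of that string round-trip; the ComparisonDirection enum and helper names here are illustrative stand-ins, not the generated StableHLO API.

#include <cassert>
#include <optional>
#include <string>

// Hypothetical stand-in for a generated StableHLO enum attribute and its
// stringify/symbolize helpers; names and values are illustrative only.
enum class ComparisonDirection { EQ, NE, GE, GT, LE, LT };

std::string stringifyComparisonDirection(ComparisonDirection d) {
  switch (d) {
    case ComparisonDirection::EQ: return "EQ";
    case ComparisonDirection::NE: return "NE";
    case ComparisonDirection::GE: return "GE";
    case ComparisonDirection::GT: return "GT";
    case ComparisonDirection::LE: return "LE";
    case ComparisonDirection::LT: return "LT";
  }
  return "";
}

std::optional<ComparisonDirection> symbolizeComparisonDirection(
    const std::string &s) {
  if (s == "EQ") return ComparisonDirection::EQ;
  if (s == "NE") return ComparisonDirection::NE;
  if (s == "GE") return ComparisonDirection::GE;
  if (s == "GT") return ComparisonDirection::GT;
  if (s == "LE") return ComparisonDirection::LE;
  if (s == "LT") return ComparisonDirection::LT;
  return std::nullopt;  // Unknown string, mirroring the optional the real helper returns.
}

int main() {
  // Serialize (StableHLO -> TFL custom option) and deserialize back
  // (TFL custom option -> StableHLO), as the two passes do in tandem.
  auto text = stringifyComparisonDirection(ComparisonDirection::GE);
  auto back = symbolizeComparisonDirection(text);
  assert(back.has_value() && *back == ComparisonDirection::GE);
  return 0;
}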
fbb->String(key, string_value); continue; @@ -257,12 +249,11 @@ void MhloToTflPass::runOnOperation() { op->erase(); }); } -std::unique_ptr> CreateMhloToTflPass() { - return std::make_unique(); +std::unique_ptr> CreateStablehloToTflPass() { + return std::make_unique(); } -static PassRegistration pass; +static PassRegistration pass; -} // namespace mhlo -} // namespace TFL +} // namespace odml } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/mhlo_tfl_pass.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_tfl_pass.h similarity index 68% rename from tensorflow/compiler/mlir/lite/stablehlo/transforms/mhlo_tfl_pass.h rename to tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_tfl_pass.h index 4517986edfd..9445b770f10 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/mhlo_tfl_pass.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_tfl_pass.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_MHLO_TFL_PASS_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_MHLO_TFL_PASS_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_STABLEHLO_TFL_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_STABLEHLO_TFL_PASS_H_ #include #include @@ -23,14 +23,12 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project namespace mlir { -namespace TFL { -namespace mhlo { +namespace odml { -// Creates a pass which transforms TF Ops to MHLO Ops. -std::unique_ptr> CreateMhloToTflPass(); +// Creates a pass which transforms StableHLO Ops to TFL Ops. +std::unique_ptr> CreateStablehloToTflPass(); -} // namespace mhlo -} // namespace TFL +} // namespace odml } // namespace mlir -#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_MHLO_TFL_PASS_H_ +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_STABLEHLO_TFL_PASS_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.cc new file mode 100644 index 00000000000..4c9929631e1 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.cc @@ -0,0 +1,45 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.h" + +#include +#include +#include + +#include "llvm/ADT/DenseSet.h" + +namespace mlir { +namespace odml { + +std::vector GetAcceptedDialects() { + // It returns the default list of accepted dialects. 
+ std::vector accepted_dialects({"stablehlo", "builtin", "func"}); + return accepted_dialects; +} + +bool IsAcceptedDialect(llvm::StringRef dialect_name, + const std::vector& accepted_dialects) { + return std::find(accepted_dialects.begin(), accepted_dialects.end(), + dialect_name) != accepted_dialects.end(); +} + +bool IsAcceptedOp(llvm::StringRef dialect_name, llvm::StringRef op_name, + const std::vector& accepted_dialects) { + return IsAcceptedDialect(dialect_name, accepted_dialects); +} + +} // namespace odml +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.h new file mode 100644 index 00000000000..bee9031f4e7 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.h @@ -0,0 +1,41 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_STABLEHLO_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_STABLEHLO_UTIL_H_ + +#include +#include + +#include "llvm/ADT/StringRef.h" + +namespace mlir { +namespace odml { + +std::vector GetAcceptedDialects(); + +// Can we find the given `dialect_name` in the `accepted_dialects`? +bool IsAcceptedDialect(llvm::StringRef dialect_name, + const std::vector &accepted_dialects); + +// The consolidated logic to verify if each final op is acceptable or not. +// Also see `PrintOpStatsPass` and `CheckAcceptedOpsPass`. +bool IsAcceptedOp(llvm::StringRef dialect_name, llvm::StringRef op_name, + const std::vector &accepted_dialects); + +} // namespace odml +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_STABLEHLO_UTIL_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_mhlo_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_mhlo_pass.cc deleted file mode 100644 index cd997872cc1..00000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_mhlo_pass.cc +++ /dev/null @@ -1,116 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
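The stablehlo_util helpers added above reduce to a simple allowlist lookup over dialect names. A self-contained sketch of the same check, using only the standard library and independent of MLIR, with a small usage example:

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Mirrors GetAcceptedDialects(): the dialects a converted module may contain.
std::vector<std::string> GetAcceptedDialects() {
  return {"stablehlo", "builtin", "func"};
}

// Mirrors IsAcceptedDialect(): a linear scan over the allowlist.
bool IsAcceptedDialect(const std::string &dialect_name,
                       const std::vector<std::string> &accepted) {
  return std::find(accepted.begin(), accepted.end(), dialect_name) !=
         accepted.end();
}

int main() {
  const auto accepted = GetAcceptedDialects();
  // A "stablehlo" op is accepted; a leftover "tf" op is not.
  std::cout << std::boolalpha
            << IsAcceptedDialect("stablehlo", accepted) << "\n"  // true
            << IsAcceptedDialect("tf", accepted) << "\n";        // false
  return 0;
}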
-==============================================================================*/ - -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_mhlo_pass.h" - -#include - -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/mhlo_util.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" - -namespace mlir { -namespace TFL { -namespace mhlo { - -class TFToMhloPass - : public mlir::PassWrapper> { - public: - explicit TFToMhloPass(bool skip_quantization_ops = false, - bool skip_resize = false) - : PassWrapper() { - skip_quantization_ops_ = skip_quantization_ops; - skip_resize_ = skip_resize; - } - - TFToMhloPass(const TFToMhloPass &pass) { - skip_quantization_ops_ = pass.skip_quantization_ops_; - skip_resize_ = pass.skip_resize_; - } - - private: - void runOnOperation() override; - - void getDependentDialects(DialectRegistry ®istry) const override { - mlir::mhlo::registerAllMhloDialects(registry); - mlir::stablehlo::registerAllDialects(registry); - registry.insert(); - registry.insert(); - } - - public: - StringRef getArgument() const final { return "tf-mhlo"; } - StringRef getDescription() const final { - return "This pass will legalize TF Ops to MHLO Ops.."; - } - - protected: - Option skip_quantization_ops_{ - *this, "skip-quantization-ops", - ::llvm::cl::desc("Skip quantization ops")}; - - Option skip_resize_{ - *this, "skip-resize", - ::llvm::cl::desc("Skip tf.ResizeBilinear and tf.ResizeNearestNeighbor")}; -}; - -void TFToMhloPass::runOnOperation() { - auto func = getOperation(); - MLIRContext *context = func->getContext(); - - RewritePatternSet patterns(context); - // Add TF to MHLO patterns. - PopulateTFToMhloPatterns( - context, /*legalize_chlo=*/true, - /*tf2xla_fallback_device_type=*/llvm::StringRef("XLA_CPU_JIT"), - /*prefer_tf2xla=*/false, &patterns); - - ConversionTarget target(*context); - target.addIllegalDialect(); - target.addLegalDialect(); - target.addLegalDialect(); - target.addLegalDialect(); - target.addLegalDialect(); - target.addLegalDialect(); - target.addLegalOp(); - - if (skip_quantization_ops_) { - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - } - if (skip_resize_) { - target.addLegalOp(); - target.addLegalOp(); - } - - FrozenRewritePatternSet frozen_patterns(std::move(patterns)); - if (failed(applyPartialConversion(func, target, frozen_patterns))) { - return signalPassFailure(); - } -} - -std::unique_ptr> CreateTFToMhloPass( - bool skip_quantization_ops, bool skip_resize) { - return std::make_unique(skip_quantization_ops, skip_resize); -} - -static PassRegistration pass; - -} // namespace mhlo -} // namespace TFL -} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_mhlo_pass.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_mhlo_pass.h deleted file mode 100644 index a5797b3e02e..00000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_mhlo_pass.h +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_TF_MHLO_PASS_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_TF_MHLO_PASS_H_ - -#include -#include - -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project - -namespace mlir { -namespace TFL { -namespace mhlo { - -// Creates a pass which transforms TF Ops to MHLO Ops. -std::unique_ptr> CreateTFToMhloPass( - bool skip_quantization_ops, bool skip_resize_nearest); - -} // namespace mhlo -} // namespace TFL -} // namespace mlir - -#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_TF_MHLO_PASS_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_mhlo_tfl_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_mhlo_tfl_pass.cc deleted file mode 100644 index fa787e63444..00000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_mhlo_tfl_pass.cc +++ /dev/null @@ -1,199 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_mhlo_tfl_pass.h" - -#include -#include -#include - -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/None.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/Casting.h" -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project -#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/Block.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/OperationSupport.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Pass/PassRegistry.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "stablehlo/dialect/ChloOps.h" // from @stablehlo -#include "stablehlo/dialect/Register.h" // from @stablehlo -#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/mhlo_util.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" -#include "tensorflow/compiler/mlir/xla/transforms/passes.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" - -namespace mlir { -namespace TFL { -namespace mhlo { - -namespace { - -// Converts an `mhlo.compare` general Op to specific TFL comparison Ops such -// as `tfl.greater`. Needs to be instantiated with specific target types, and -// comparison directions. -// -// Example: -// patterns.add>( -// patterns.getContext(), -// ::mlir::mhlo::ComparisonDirection::GE); -// -template -struct ConvertMhloCompareOp - : public OpConversionPattern<::mlir::mhlo::CompareOp> { - public: - explicit ConvertMhloCompareOp(MLIRContext *context) - : OpConversionPattern<::mlir::mhlo::CompareOp>(context) {} - - ::mlir::LogicalResult matchAndRewrite( - ::mlir::mhlo::CompareOp op, mlir::mhlo::CompareOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { - auto direction = op.getComparisonDirection(); - - if (direction != Direction) { - return failure(); - } - - rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), - adaptor.getRhs()); - return success(); - } -}; - -// Convert the MHLO Atan2 Op which cannot be mapped to TFL directly to a -// `tfl.custom` Op which can be executed using custom TFL kernels. 
-struct ConvertMhloAtan2ToTflCustomOp - : public OpConversionPattern<::mlir::mhlo::Atan2Op> { - public: - explicit ConvertMhloAtan2ToTflCustomOp(::mlir::MLIRContext *context) - : OpConversionPattern<::mlir::mhlo::Atan2Op>(context) {} - - ConstBytesAttr BuildEmptyConstBytesAttr(Operation *op) const { - OpBuilder builder(op); - - return ConstBytesAttr::get(builder.getContext(), StringRef()); - } - - ::mlir::LogicalResult matchAndRewrite( - ::mlir::mhlo::Atan2Op op, OpAdaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { - auto op_code = op->getName().stripDialect().str(); - auto custom_option = BuildEmptyConstBytesAttr(op); - - rewriter.replaceOpWithNewOp(op, op->getResultTypes(), - adaptor.getOperands(), - op_code, custom_option); - return success(); - } -}; - -} // namespace - -class TFMhloTFLPass - : public mlir::PassWrapper> { - private: - void runOnOperation() override; - - void getDependentDialects(DialectRegistry ®istry) const override { - mlir::mhlo::registerAllMhloDialects(registry); - mlir::stablehlo::registerAllDialects(registry); - registry.insert(); - } - - public: - StringRef getArgument() const final { return "tf-mhlo-tfl"; } - StringRef getDescription() const final { - return "This pass will legalize TF ops to TFL via mHLO."; - } -}; - -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/generated_mhlo_tfl_legalize_patterns.inc" - -// Convert TF to MHLO and then to TFLite. -// This is "all or nothing" based conversion. If a TF op converts to a few MHLO -// ops that cannot fully convert to TFLite, the original TF op is kept. -void TFMhloTFLPass::runOnOperation() { - auto func = getOperation(); - MLIRContext *context = func->getContext(); - - RewritePatternSet patterns(context); - - // Add TF to MHLO patterns. - PopulateTFToMhloPatterns( - context, /*legalize_chlo=*/true, - /*tf2xla_fallback_device_type=*/llvm::StringRef("XLA_CPU_JIT"), - /*prefer_tf2xla=*/false, &patterns); - - // Add MHLO to TFL patterns. - populateWithGenerated(patterns); - patterns.add, - ConvertMhloCompareOp, - ConvertMhloCompareOp, - ConvertMhloCompareOp, - ConvertMhloCompareOp>( - patterns.getContext()); - - ConversionTarget target(*context); - // Intermediate dialects. - target.addIllegalDialect(); - target.addIllegalDialect<::mlir::mhlo::MhloDialect>(); - // Final expected dialects. - target.addLegalDialect(); - target.addLegalDialect(); - target.addLegalDialect(); - - FrozenRewritePatternSet frozen_patterns(std::move(patterns)); - if (failed(applyPartialConversion(func, target, frozen_patterns))) { - return signalPassFailure(); - } -} - -std::unique_ptr> CreateTFMhloTFLPass() { - return std::make_unique(); -} - -static PassRegistration pass([] { - return CreateTFMhloTFLPass(); -}); - -} // namespace mhlo -} // namespace TFL -} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_poly_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_poly_pass.cc deleted file mode 100644 index 9195fcf0577..00000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_poly_pass.cc +++ /dev/null @@ -1,204 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_poly_pass.h" - -#include -#include -#include - -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/None.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/Casting.h" -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project -#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/Block.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/OperationSupport.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Pass/PassRegistry.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "stablehlo/dialect/ChloOps.h" // from @stablehlo -#include "stablehlo/dialect/Register.h" // from @stablehlo -#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/mhlo_util.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_mhlo_pass.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" -#include "tensorflow/compiler/mlir/xla/transforms/passes.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h" - -namespace mlir { -namespace TFL { -namespace mhlo { - -static bool isTFOp(Operation *op) { - return op->getDialect()->getNamespace() == "tf"; -} - -class TFPolyPass - : public mlir::PassWrapper> { - private: - void runOnOperation() override; - - void getDependentDialects(DialectRegistry ®istry) const override { - mlir::mhlo::registerAllMhloDialects(registry); - mlir::stablehlo::registerAllDialects(registry); - registry.insert(); - } - - // Options that control what transfomraitons are applied. - - PolyCallOptions options_; - - // A transformation to be applied to TF ops. The transformation is specified - // by patterns and target. - struct Tranformation { - std::string name; - FrozenRewritePatternSet patterns; // The patterns to be applied. - ConversionTarget target; // The target of the transformation. - }; - - // Creates a list of transformaiton to be applied to TF ops. 
- // The transformations apply on replicated list of TF ops so there is no - // interaction between the transformations. - std::vector LoadTransformations(PolyCallOptions options, - MLIRContext *context) { - std::vector transformations; - // Optionally add TF to MHLO pass. - if (options.enable_tf_mhlo_conversion) { - RewritePatternSet patterns(context); - PopulateTFToMhloPatterns( - context, /*legalize_chlo=*/true, - /*tf2xla_fallback_device_type=*/llvm::StringRef("DEFAULT"), - /*prefer_tf2xla=*/false, &patterns); - - ConversionTarget target(*context); - target.addLegalDialect(); - target.addLegalDialect(); - target.addLegalDialect<::mlir::mhlo::MhloDialect>(); - - FrozenRewritePatternSet frozen_patterns(std::move(patterns)); - Tranformation tf_mhlo_transform = {"tf_mhlo", frozen_patterns, target}; - transformations.push_back(tf_mhlo_transform); - } - return transformations; - } - - // Copies an `op` and put the copy into `region`, and return the copied op. - Operation *CopyTfAndCreateRegion(OpBuilder *builder, Operation *op, - Region *region) { - Block *block = new Block; - region->push_back(block); - builder->setInsertionPointToEnd(®ion->front()); - Operation *tf_op = builder->clone(*op); - Location loc = op->getLoc(); - block->addArguments(op->getOperandTypes(), - SmallVector(op->getNumOperands(), loc)); - for (auto &idx_args : llvm::enumerate(block->getArguments())) { - tf_op->setOperand(idx_args.index(), idx_args.value()); - } - builder->create(loc, tf_op->getResults()); - return tf_op; - } - - public: - StringRef getArgument() const final { return "tf-poly"; } - StringRef getDescription() const final { - return "This pass will legalize TF ops to poly call."; - } - explicit TFPolyPass(PolyCallOptions options) { options_ = options; } -}; - -void TFPolyPass::runOnOperation() { - func::FuncOp fn = getOperation(); - MLIRContext *context = fn->getContext(); - const std::vector transformations = - LoadTransformations(options_, context); - const int num_transformation = transformations.size(); - std::vector> to_transform(num_transformation); - fn.walk([&](Operation *op) { - // Process only TF ops. - if (!isTFOp(op)) return; - - // Create polycall op. Need to call setInsertionPoint to avoid recurrsion. - OpBuilder builder(op->getContext()); - builder.setInsertionPoint(op); - auto poly_op = builder.create( - op->getLoc(), op->getResultTypes(), op->getOperands(), - num_transformation + 1); - poly_op->setAttrs(op->getAttrs()); - - // Create TF region. - Region tf_region; - (void)CopyTfAndCreateRegion(&builder, op, &tf_region); - poly_op.getCalls() - .take_back(num_transformation + 1) - .data() - ->takeBody(tf_region); - - // Create regions according to the transformations. - for (int i = 0; i < num_transformation; i++) { - Region region; - to_transform[i].push_back(CopyTfAndCreateRegion(&builder, op, ®ion)); - poly_op.getCalls().take_back(i + 1).data()->takeBody(region); - } - - // Replace original func with polycall. - op->replaceAllUsesWith(poly_op); - op->erase(); - }); - - // Apply transformations. 
- for (int i = 0; i < num_transformation; i++) { - auto transformation = transformations[i]; - auto op_to_transform = to_transform[i]; - if (failed(applyPartialConversion(op_to_transform, transformation.target, - transformation.patterns))) { - return signalPassFailure(); - } - } -} - -std::unique_ptr> CreateTFPolyPass( - PolyCallOptions options) { - return std::make_unique(options); -} - -static PassRegistration pass([] { - PolyCallOptions options; - options.enable_tf_mhlo_conversion = true; - return CreateTFPolyPass(options); -}); - -} // namespace mhlo -} // namespace TFL -} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_poly_pass.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_poly_pass.h deleted file mode 100644 index 2f2a6f7b655..00000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_poly_pass.h +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_TF_POLY_PASS_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_TF_POLY_PASS_H_ - -#include -#include - -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project - -namespace mlir { -namespace TFL { -namespace mhlo { - -// Creates a pass which transforms TF Ops to multiple representations. -// Only use this to TF ops that cannot convert to tflite fully. -struct PolyCallOptions { - bool enable_tf_mhlo_conversion = false; -}; -std::unique_ptr> CreateTFPolyPass( - PolyCallOptions options); - -} // namespace mhlo -} // namespace TFL -} // namespace mlir - -#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_TF_POLY_PASS_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.cc new file mode 100644 index 00000000000..d3886e9127b --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.cc @@ -0,0 +1,165 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h" + +#include +#include +#include + +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "stablehlo/dialect/ChloOps.h" // from @stablehlo +#include "stablehlo/dialect/Register.h" // from @stablehlo +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" +#include "tensorflow/compiler/mlir/xla/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/register.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/rewriters.h" + +namespace mlir { +namespace odml { + +class TFToMhloPass + : public mlir::PassWrapper> { + public: + explicit TFToMhloPass(bool skip_quantization_ops = false, + bool skip_resize = false) + : PassWrapper() { + skip_quantization_ops_ = skip_quantization_ops; + skip_resize_ = skip_resize; + } + + TFToMhloPass(const TFToMhloPass &pass) { + skip_quantization_ops_ = pass.skip_quantization_ops_; + skip_resize_ = pass.skip_resize_; + } + + private: + void runOnOperation() override; + + void getDependentDialects(DialectRegistry ®istry) const override { + mlir::mhlo::registerAllMhloDialects(registry); + mlir::stablehlo::registerAllDialects(registry); + registry.insert(); + registry.insert(); + } + + public: + StringRef getArgument() const final { return "tf-mhlo"; } + StringRef getDescription() const final { + return "This pass will legalize TF Ops to MHLO Ops."; + } + + protected: + Option skip_quantization_ops_{ + *this, "skip-quantization-ops", + ::llvm::cl::desc("Skip quantization ops")}; + + Option skip_resize_{ + *this, "skip-resize", + ::llvm::cl::desc("Skip tf.ResizeBilinear and tf.ResizeNearestNeighbor")}; +}; + +void TFToMhloPass::runOnOperation() { + auto func = getOperation(); + MLIRContext *context = func->getContext(); + + RewritePatternSet patterns(context); + mhlo::PopulateLegalizeTfPatterns(context, &patterns); + TF::PopulateTFLoweringBeforeHLOPatterns(context, &patterns); + mhlo::Tf2XlaTypeConverter converter; + mhlo::PopulateLegalizeTfWithTf2XlaPatterns( + "XLA_CPU_JIT", patterns, context, converter, /*prefer_tf2xla=*/false); + chlo::populateDecomposeChloPatterns(context, &patterns); + chlo::populateChloBroadcastingPatterns(context, &patterns); + chlo::ConstantLikeOp::getCanonicalizationPatterns(patterns, context); + + ConversionTarget target(*context); + target.addIllegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + 
target.addLegalOp(); + + if (skip_quantization_ops_) { + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + } + if (skip_resize_) { + target.addLegalOp(); + target.addLegalOp(); + } + + FrozenRewritePatternSet frozen_patterns(std::move(patterns)); + if (failed(applyPartialConversion(func, target, frozen_patterns))) { + return signalPassFailure(); + } +} + +struct TFToStablehloOptions : public PassPipelineOptions { + Option skip_quantization_ops{*this, "skip-quantization-ops", + ::llvm::cl::desc("Skip quantization ops")}; + Option skip_resize{ + *this, "skip-resize", + ::llvm::cl::desc("Skip tf.ResizeBilinear and tf.ResizeNearestNeighbor")}; +}; + +void PopulateLegalizeTFToStablehloPipeline( + OpPassManager &pm, const TFToStablehloOptions &options) { + // TODO(burmako): Migrate this pass from producing MHLO to producing StableHLO + // by aligning with the TF/XLA bridge on the corresponding functionality and + // reusing their work, perhaps through `LowerToMlProgramAndHlo`. + pm.addNestedPass(std::make_unique( + options.skip_quantization_ops, options.skip_resize)); + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mhlo::createHloLegalizeToStablehloPass()); +} + +static PassPipelineRegistration + legalize_tf_to_stablehlo_pipeline("tf-stablehlo", + "Legalize TF ops to StableHLO ops", + PopulateLegalizeTFToStablehloPipeline); + +void AddLegalizeTFToStablehloPasses(OpPassManager &pm, + bool skip_quantization_ops, + bool skip_resize) { + TFToStablehloOptions options; + options.skip_quantization_ops = skip_quantization_ops; + options.skip_resize = skip_resize; + PopulateLegalizeTFToStablehloPipeline(pm, options); +} + +} // namespace odml +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_mhlo_tfl_pass.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h similarity index 55% rename from tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_mhlo_tfl_pass.h rename to tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h index 06e44700bb8..0eb199b7e8f 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_mhlo_tfl_pass.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h @@ -13,26 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_TF_MHLO_TFL_PASS_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_TF_MHLO_TFL_PASS_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_TF_STABLEHLO_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_TF_STABLEHLO_PASS_H_ -#include -#include - -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project namespace mlir { -namespace TFL { -namespace mhlo { +namespace odml { -// Creates a pass which transforms TF Ops to TFLite via an intermediate -// conversion to MHLO. -std::unique_ptr> CreateTFMhloTFLPass(); +// Adds passes which transform TF Ops to StableHLO Ops. 
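The tf-stablehlo entry points added above follow a common layering: an options struct, a Populate*Pipeline function that appends passes in a fixed order (TF-to-MHLO legalization, canonicalize, MHLO-to-StableHLO), and a thin Add* wrapper that fills the options for C++ callers. Below is a language-level sketch of that layering only; the Pipeline type and the pass names pushed as strings are hypothetical stand-ins, not the MLIR OpPassManager API.

#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for OpPassManager: just an ordered list of pass names.
using Pipeline = std::vector<std::string>;

struct TFToStablehloOptions {
  bool skip_quantization_ops = false;
  bool skip_resize = false;
};

// Appends passes in the order the real pipeline uses.
void PopulateLegalizeTFToStablehloPipeline(Pipeline &pm,
                                           const TFToStablehloOptions &opts) {
  pm.push_back("tf-mhlo(skip-quantization-ops=" +
               std::to_string(opts.skip_quantization_ops) +
               ",skip-resize=" + std::to_string(opts.skip_resize) + ")");
  pm.push_back("canonicalize");
  pm.push_back("hlo-legalize-to-stablehlo");
}

// Thin wrapper mirroring AddLegalizeTFToStablehloPasses().
void AddLegalizeTFToStablehloPasses(Pipeline &pm, bool skip_quantization_ops,
                                    bool skip_resize) {
  TFToStablehloOptions options;
  options.skip_quantization_ops = skip_quantization_ops;
  options.skip_resize = skip_resize;
  PopulateLegalizeTFToStablehloPipeline(pm, options);
}

int main() {
  Pipeline pm;
  AddLegalizeTFToStablehloPasses(pm, /*skip_quantization_ops=*/false,
                                 /*skip_resize=*/true);
  for (const auto &p : pm) std::cout << p << "\n";
  return 0;
}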
+void AddLegalizeTFToStablehloPasses(OpPassManager& pm, + bool skip_quantization_ops, + bool skip_resize); -} // namespace mhlo -} // namespace TFL +} // namespace odml } // namespace mlir -#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_TF_MHLO_TFL_PASS_H_ +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_TF_STABLEHLO_PASS_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_mhlo_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.cc similarity index 80% rename from tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_mhlo_pass.cc rename to tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.cc index c2370ecd63f..278ce333258 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_mhlo_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_mhlo_pass.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.h" #include #include @@ -35,40 +35,34 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "stablehlo/dialect/ChloOps.h" // from @stablehlo #include "stablehlo/dialect/Register.h" // from @stablehlo +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h" namespace mlir { -namespace TFL { -namespace mhlo { +namespace odml { -class TflToMhloPass - : public mlir::PassWrapper> { public: - explicit TflToMhloPass() : PassWrapper() {} - StringRef getArgument() const final { return "tfl-parse-mhlo-ops"; } + explicit TflToStablehloPass() : PassWrapper() {} + StringRef getArgument() const final { return "tfl-parse-stablehlo-ops"; } StringRef getDescription() const final { - return "This pass will legalize TFLite custom Ops to MHLO ops."; + return "This pass will legalize TFLite custom Ops to StableHLO ops."; } private: void runOnOperation() override; void getDependentDialects(DialectRegistry& registry) const override { - mlir::mhlo::registerAllMhloDialects(registry); mlir::stablehlo::registerAllDialects(registry); - registry.insert<::mlir::mhlo::MhloDialect>(); - registry.insert(); } - inline ConstBytesAttr CustomOption(OpBuilder* builder, - const std::string& content) { - return ConstBytesAttr::get(builder->getContext(), - StringRef(content.data(), content.size())); + inline TFL::ConstBytesAttr CustomOption(OpBuilder* builder, + const std::string& content) { + return TFL::ConstBytesAttr::get(builder->getContext(), + StringRef(content.data(), content.size())); } std::vector FlatbufferVecToMlirVec(const flexbuffers::Vector& vec) { @@ -100,8 +94,15 @@ class TflToMhloPass for (size_t i = 0; i < vector.size(); i++) { vec.push_back(vector[i].AsInt64()); } + std::vector shape; + if (std::string{key} == "padding") { + 
shape.push_back(vec.size() / 2); + shape.push_back(2); + } else { + shape.push_back(vec.size()); + } RankedTensorType ty = tensorflow::GetTypeFromTFTensorShape( - {static_cast(vec.size())}, builder->getIntegerType(64)); + shape, builder->getIntegerType(64)); auto named_attr = builder->getNamedAttr(key, DenseIntElementsAttr::get(ty, vec)); attrs.push_back(named_attr); @@ -113,10 +114,10 @@ class TflToMhloPass if (std::string{key} == "precision_config") { llvm::SmallVector precision_attrs; for (size_t i = 0; i < vector.size(); i++) { - auto conf_attr = mlir::mhlo::PrecisionAttr::get( - builder->getContext(), - mlir::mhlo::symbolizePrecision(vector[i].AsString().str()) - .getValue()); + auto conf_attr = mlir::stablehlo::PrecisionAttr::get( + builder->getContext(), mlir::stablehlo::symbolizePrecision( + vector[i].AsString().str()) + .value()); precision_attrs.push_back(conf_attr); } auto named_attr = builder->getNamedAttr( @@ -144,7 +145,7 @@ class TflToMhloPass auto vec2 = FlatbufferVecToMlirVec(value_vec[5].AsVector()); auto vec3 = FlatbufferVecToMlirVec(value_vec[8].AsVector()); auto conv_dimension_numbers_attr = - mlir::mhlo::ConvDimensionNumbersAttr::get( + mlir::stablehlo::ConvDimensionNumbersAttr::get( builder->getContext(), value_vec[0].AsInt64(), value_vec[1].AsInt64(), llvm::ArrayRef(vec1), value_vec[3].AsInt64(), value_vec[4].AsInt64(), @@ -168,7 +169,7 @@ class TflToMhloPass } }; -void TflToMhloPass::runOnOperation() { +void TflToStablehloPass::runOnOperation() { func::FuncOp fn = getOperation(); OpBuilder builder(fn.getContext()); fn.walk([&](TFL::CustomOp custom_op) { @@ -189,18 +190,17 @@ void TflToMhloPass::runOnOperation() { } op_state.addTypes(output_tys); op_state.addAttributes(attr); - auto mhlo_op = builder.create(op_state); - custom_op.replaceAllUsesWith(mhlo_op); + auto stablehlo_op = builder.create(op_state); + custom_op.replaceAllUsesWith(stablehlo_op); custom_op.erase(); }); } -std::unique_ptr> CreateTflToMhloPass() { - return std::make_unique(); +std::unique_ptr> CreateTflToStablehloPass() { + return std::make_unique(); } -static PassRegistration pass; +static PassRegistration pass; -} // namespace mhlo -} // namespace TFL +} // namespace odml } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_mhlo_pass.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.h similarity index 72% rename from tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_mhlo_pass.h rename to tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.h index 1ad815c1cca..e6e40762891 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_mhlo_pass.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_TFL_MHLO_PASS_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_TFL_MHLO_PASS_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_TFL_STABLEHLO_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_TFL_STABLEHLO_PASS_H_ #include #include @@ -23,14 +23,12 @@ limitations under the License. 
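In the TFL-to-StableHLO direction shown above, the flat integer vector stored under the "padding" key is reinterpreted as an N x 2 tensor (one low/high pair per spatial dimension) rather than a rank-1 vector. A small standalone sketch of that shape computation and regrouping, using only the standard library:

#include <cassert>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

// Given the flat integer vector from the flexbuffer, compute the tensor shape
// the pass attaches: {N/2, 2} for "padding", {N} for every other key.
std::vector<int64_t> ShapeForKey(const std::string &key,
                                 const std::vector<int64_t> &vec) {
  std::vector<int64_t> shape;
  if (key == "padding") {
    shape.push_back(static_cast<int64_t>(vec.size()) / 2);
    shape.push_back(2);
  } else {
    shape.push_back(static_cast<int64_t>(vec.size()));
  }
  return shape;
}

// Regroup the flat vector into (low, high) pairs, matching the N x 2 layout.
std::vector<std::pair<int64_t, int64_t>> AsPairs(
    const std::vector<int64_t> &vec) {
  std::vector<std::pair<int64_t, int64_t>> pairs;
  for (size_t i = 0; i + 1 < vec.size(); i += 2)
    pairs.push_back({vec[i], vec[i + 1]});
  return pairs;
}

int main() {
  // 2-D convolution padding: one (low, high) pair per spatial dimension.
  std::vector<int64_t> flat = {1, 1, 2, 2};
  auto shape = ShapeForKey("padding", flat);
  assert(shape.size() == 2 && shape[0] == 2 && shape[1] == 2);
  auto pairs = AsPairs(flat);
  assert(pairs[1].first == 2 && pairs[1].second == 2);
  return 0;
}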
#include "mlir/Pass/Pass.h" // from @llvm-project namespace mlir { -namespace TFL { -namespace mhlo { +namespace odml { -// Creates a pass which transforms TFLite to MHLO Ops. -std::unique_ptr> CreateTflToMhloPass(); +// Creates a pass which transforms TFLite to StableHLO Ops. +std::unique_ptr> CreateTflToStablehloPass(); -} // namespace mhlo -} // namespace TFL +} // namespace odml } // namespace mlir -#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_TFL_MHLO_PASS_H_ +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_TFL_STABLEHLO_PASS_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc index 7b0d00ac45b..b34c9bc0ffa 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc @@ -18,20 +18,22 @@ limitations under the License. #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/drop_savedmodel_semantics.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/smuggle_disallowed_ops.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_mhlo_pass.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/passes.h" namespace mlir { namespace odml { void AddTFToStablehloPasses(OpPassManager& pm, bool skip_resize, bool smuggle_disallowed_ops) { - pm.addPass(mlir::TFL::mhlo::CreateRenameEntrypointToMainPass()); + pm.addPass(CreateRenameEntrypointToMainPass()); // TODO(b/230572023): Consider improving shape inference for While op instead // of dropping the attribute. This need not be correct for models not trained // on TPU. @@ -56,20 +58,30 @@ void AddTFToStablehloPasses(OpPassManager& pm, bool skip_resize, pm.addPass(mlir::TF::CreateTFShapeInferencePass()); pm.addNestedPass( mlir::quant::CreateConvertTFQuantOpsToMHLOPass()); - pm.addPass(mhlo::createLegalizeTFControlFlowPass()); - pm.addPass(mlir::createCanonicalizerPass()); - pm.addNestedPass(mlir::TFL::mhlo::CreateTFToMhloPass( - /*skip_quantization_ops=*/false, skip_resize)); pm.addPass(mlir::createCanonicalizerPass()); + AddLegalizeTFToStablehloPasses(pm, /*skip_quantization_ops=*/false, + skip_resize); if (smuggle_disallowed_ops) { - pm.addNestedPass( - mlir::TFL::mhlo::CreateSmuggleDisallowedOpsPass()); + pm.addNestedPass(CreateSmuggleDisallowedOpsPass()); pm.addPass(mlir::createCanonicalizerPass()); } - pm.addPass(mlir::TFL::mhlo::CreateDropSavedModelSemanticsPass()); + pm.addPass(CreateDropSavedModelSemanticsPass()); } -void AddStablehloOptimizationPasses(OpPassManager& pm) {} +void AddStablehloOptimizationPasses(OpPassManager& pm) { + // The current plan of record is to avoid doing optimization passes + // on StableHLO, treating StableHLO purely as an input format, and do all + // optimizations via MHLO passes that can be shared with the OpenXLA compiler. 
+ // Therefore, this function inserts a StableHLO <=> MHLO roundtrip to make + // this happen. + pm.addPass(mhlo::createStablehloLegalizeToHloPass()); + pm.addNestedPass(createUnfuseBatchNormPass()); + pm.addNestedPass(createFuseConvolutionPass()); + pm.addNestedPass(createFoldBroadcastPass()); + pm.addNestedPass(createOptimizePass()); + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mhlo::createHloLegalizeToStablehloPass()); +} } // namespace odml } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/unfuse_batch_norm_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/unfuse_batch_norm_pass.cc new file mode 100644 index 00000000000..4a46a083aff --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/unfuse_batch_norm_pass.cc @@ -0,0 +1,198 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir { +namespace odml { + +// Broadcasts the 1D value tensor 'value_1d' to the shape of 'result_type'. If +// 'shape_value' is initialized, creates a dynamic broadcast, otherwise creates +// a static broadcast. +Value broadcastToFeatureDim(Location loc, RankedTensorType result_type, + Value value1d, Value shape_value, + int64_t feature_dim, PatternRewriter &rewriter) { + auto dims_type = + RankedTensorType::get(/*shape=*/{1}, rewriter.getIntegerType(64)); + auto dims = DenseIntElementsAttr::get(dims_type, {feature_dim}); + if (shape_value) { + return rewriter.createOrFold( + loc, result_type, value1d, shape_value, dims); + } + assert(result_type.hasStaticShape()); + return rewriter.create(loc, result_type, value1d, + dims); +} + +// Gets the shape of operand, assuming it is a dynamic shape with static rank. 
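For intuition on broadcastToFeatureDim above: it expands a rank-1 per-channel tensor along the feature dimension of the result type. Numerically, for an NHWC-shaped buffer in the static-shape case, that amounts to the following sketch (plain C++, no MLIR; the function name and flat-buffer layout are assumptions for illustration):

#include <cassert>
#include <vector>

// Broadcast a per-channel vector (length C) across a flat NHWC buffer of size
// N*H*W*C, i.e. value1d[c] is replicated along every non-feature dimension.
std::vector<float> BroadcastToFeatureDim(const std::vector<float> &value1d,
                                         int n, int h, int w) {
  const int c = static_cast<int>(value1d.size());
  std::vector<float> out(static_cast<size_t>(n) * h * w * c);
  for (size_t i = 0; i < out.size(); ++i) out[i] = value1d[i % c];
  return out;
}

int main() {
  // Two channels broadcast over a 1x2x2xC activation.
  auto out = BroadcastToFeatureDim({10.f, 20.f}, /*n=*/1, /*h=*/2, /*w=*/2);
  assert(out.size() == 8);
  assert(out[0] == 10.f && out[1] == 20.f && out[6] == 10.f && out[7] == 20.f);
  return 0;
}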
+Value getShapeValue(Location loc, Value operand, PatternRewriter &rewriter) { + RankedTensorType resultType = operand.getType().dyn_cast(); + return rewriter.create( + loc, + RankedTensorType::get(/*shape=*/{resultType.getRank()}, + rewriter.getIndexType()), + operand); +} + +Value materializeEpsilon(Operation *op, FloatAttr epsilon_attr, + FloatType fp_type, Value broadcast_to, + RankedTensorType broadcast_to_type, + PatternRewriter &rewriter) { + ImplicitLocOpBuilder b(op->getLoc(), rewriter); + if (epsilon_attr.getType() != fp_type) { + // Need to convert. + bool loses_info; + APFloat epsilon_float = epsilon_attr.getValue(); + auto status = epsilon_float.convert( + fp_type.getFloatSemantics(), APFloat::rmNearestTiesToEven, &loses_info); + if ((status & (~APFloat::opInexact)) != APFloat::opOK) { + op->emitWarning() << "Could not convert batch_norm epsilon to target fp " + "type: opStatus = " + << static_cast(status); + return nullptr; + } + if (loses_info) { + op->emitWarning("Conversion of epsilon loses precision"); + } + epsilon_attr = b.getFloatAttr(fp_type, epsilon_float); + } + + auto scalar_type = RankedTensorType::get(/*shape=*/{}, fp_type); + auto epsilon_tensor_attr = + DenseElementsAttr::get(scalar_type, {epsilon_attr.cast()}); + Value epsilon = b.create(epsilon_tensor_attr); + auto dims_type = RankedTensorType::get(/*shape=*/{0}, b.getIntegerType(64)); + auto dims = DenseIntElementsAttr::get(dims_type, SmallVector{}); + if (broadcast_to_type.hasStaticShape()) { + return b.create(broadcast_to_type, epsilon, dims); + } + Value shape_value = getShapeValue(op->getLoc(), broadcast_to, rewriter); + return b.createOrFold( + broadcast_to_type, epsilon, shape_value, dims); +} + +class UnfuseBatchNormInferencePattern + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mhlo::BatchNormInferenceOp bn_op, + PatternRewriter &rewriter) const override { + // Enforce type invariants. + // Note that we deduce the actual element type from the variance, + // which should not be subject to quantization at a higher level. + auto input_type = bn_op.getOperand().getType().dyn_cast(); + auto variance_type = + bn_op.getVariance().getType().dyn_cast(); + if (!input_type || !variance_type) { + return failure(); + } + auto fp_type = variance_type.getElementType().dyn_cast(); + if (!fp_type) { + return failure(); + } + + // result = (x - mean) * scale / sqrt(variance + epsilon) + offset + // Let multiplier = scale / sqrt(variance + epsilon), to compute + // (x - mean) * scale / sqrt(variance + epsilon) + offset, + // is then to compute (x * multiplier) + (offset - mean * multiplier). 
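The rewrite described in the comment above is an algebraic refactoring: precompute multiplier = scale / sqrt(variance + epsilon) and rhs = offset - mean * multiplier, and the whole inference batch norm collapses to x * multiplier + rhs. A quick numeric check of that identity for one element (standalone C++; the sample values are arbitrary):

#include <cassert>
#include <cmath>

int main() {
  // Arbitrary sample values for one feature channel.
  const double x = 3.0, mean = 1.5, variance = 4.0, scale = 0.5,
               offset = 0.25, epsilon = 1e-3;

  // Direct form: (x - mean) * scale / sqrt(variance + epsilon) + offset.
  const double direct =
      (x - mean) * scale / std::sqrt(variance + epsilon) + offset;

  // Unfused form produced by the pass: x * multiplier + (offset - mean * multiplier).
  const double multiplier = scale / std::sqrt(variance + epsilon);
  const double rhs = offset - mean * multiplier;
  const double unfused = x * multiplier + rhs;

  assert(std::fabs(direct - unfused) < 1e-12);
  return 0;
}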
+ + auto epsilon = materializeEpsilon( + bn_op.getOperation(), bn_op.getEpsilonAttr(), fp_type, + bn_op.getVariance(), variance_type, rewriter); + if (!epsilon) { + return failure(); + } + + // Compute multiplier = scale / sqrt(variance + epsilon) + Value multiplier = rewriter.create( + bn_op.getLoc(), bn_op.getVariance(), epsilon); + multiplier = rewriter.create(bn_op.getLoc(), multiplier); + multiplier = rewriter.create(bn_op.getLoc(), multiplier, + bn_op.getScale()); + + // Compute rhs = offset - mean * multiplier + Value rhs = rewriter.create(bn_op.getLoc(), multiplier, + bn_op.getMean()); + rhs = rewriter.create(bn_op.getLoc(), bn_op.getOffset(), + rhs); + + // Broadcast `multiplier` and `rhs` + Value shape_value; + if (!input_type.hasStaticShape()) { + shape_value = getShapeValue(bn_op.getLoc(), bn_op.getOperand(), rewriter); + } + int64_t feature_dim = bn_op.getFeatureIndex(); + auto broadcast_multiplier = + broadcastToFeatureDim(bn_op.getLoc(), input_type, multiplier, + shape_value, feature_dim, rewriter); + auto broadcast_rhs = broadcastToFeatureDim( + bn_op.getLoc(), input_type, rhs, shape_value, feature_dim, rewriter); + + // Computes x * multiplier + rhs + Value lhs = rewriter.create(bn_op.getLoc(), bn_op.getOperand(), + broadcast_multiplier); + rewriter.replaceOpWithNewOp(bn_op, lhs, broadcast_rhs); + + return success(); + } +}; + +class UnfuseMhloBatchNormPass + : public PassWrapper> { + public: + StringRef getArgument() const final { return "unfuse-mhlo-batch-norm-pass"; } + StringRef getDescription() const final { + return "Unfuses MHLO batch norm inference op into arithmetic ops"; + } + void getDependentDialects(::mlir::DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + if (failed(applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + +std::unique_ptr createUnfuseBatchNormPass() { + return std::make_unique(); +} + +static PassRegistration pass; + +} // namespace odml +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/tests/BUILD b/tensorflow/compiler/mlir/lite/tests/BUILD index ae71e967452..f2cfa39ab54 100644 --- a/tensorflow/compiler/mlir/lite/tests/BUILD +++ b/tensorflow/compiler/mlir/lite/tests/BUILD @@ -1,16 +1,27 @@ load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") -package(licenses = ["notice"]) +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) glob_lit_tests( data = [":test_utilities"], driver = "@llvm-project//mlir:run_lit.sh", exclude = ["load-quantization-recipe.mlir"], + size_override = { + "optimize.mlir": "medium", + "prepare-tf.mlir": "medium", + "prepare-tf-fake-quant.mlir": "medium", + "prepare-tf-fake-quant-4bit.mlir": "medium", + "raise-custom-ops.mlir": "medium", + }, tags_override = { "legalize-tf.mlir": ["no_rocm"], "optimize.mlir": ["no_rocm"], "prepare-tf.mlir": ["no_rocm"], + "const-fold.mlir": ["no_mac_arm64"], }, test_file_exts = ["mlir"], ) diff --git a/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir b/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir index 6de8a6b5f9c..8429d9198c4 100644 --- a/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt -pass-pipeline='func.func(canonicalize)' 
-tfl-runtime-verify -split-input-file -verify-diagnostics %s | FileCheck %s +// RUN: tf-opt -canonicalize -tfl-runtime-verify -split-input-file -verify-diagnostics %s | FileCheck %s // CHECK-LABEL: @squeeze_folder func.func @squeeze_folder(%arg0 : tensor) -> tensor { diff --git a/tensorflow/compiler/mlir/lite/tests/debuginfo/BUILD b/tensorflow/compiler/mlir/lite/tests/debuginfo/BUILD index 4ef1e88b73c..9a0b427f294 100644 --- a/tensorflow/compiler/mlir/lite/tests/debuginfo/BUILD +++ b/tensorflow/compiler/mlir/lite/tests/debuginfo/BUILD @@ -1,6 +1,8 @@ load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) + licenses(["notice"]) glob_lit_tests( diff --git a/tensorflow/compiler/mlir/lite/tests/decompose-hybrid-quantization.mlir b/tensorflow/compiler/mlir/lite/tests/decompose-hybrid-quantization.mlir index 10237c446c6..ce2c896ccde 100644 --- a/tensorflow/compiler/mlir/lite/tests/decompose-hybrid-quantization.mlir +++ b/tensorflow/compiler/mlir/lite/tests/decompose-hybrid-quantization.mlir @@ -97,12 +97,12 @@ func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32 // CHECK-DAG: %[[VAL1:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<16x{{.+}}>, value = dense<2> : tensor<16xi32>} // CHECK-DAG: %[[VAL2:.+]] = "tfl.dequantize"(%[[VAL0]]) // CHECK-DAG: %[[VAL3:.+]] = "tfl.dequantize"(%[[VAL1]]) - // CHECK-DAG: %[[VAL4:.+]] = "tfl.transpose_conv"(%[[SHAPE]], %[[VAL2]], %arg0, %[[VAL3]]) {padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} + // CHECK-DAG: %[[VAL4:.+]] = "tfl.transpose_conv"(%[[SHAPE]], %[[VAL2]], %arg0, %[[VAL3]]) {fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} // CHECK: return %[[VAL4]] %0 = "tfl.pseudo_const"() { value = dense<[1, 32, 32, 16]> : tensor<4xi32> } : () -> tensor<4xi32> %1 = "tfl.pseudo_qconst"() {qtype = tensor<16x!quant.uniform>, value = dense<1> : tensor<16xi32>} : () -> tensor<16x1x1x8x!quant.uniform> %2 = "tfl.pseudo_qconst"() {qtype = tensor<16x!quant.uniform>, value = dense<2> : tensor<16xi32>} : () -> tensor<16x!quant.uniform> - %3 = "tfl.transpose_conv"(%0, %1, %arg0, %2) {padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<4xi32>, tensor<16x1x1x8x!quant.uniform>, tensor<1x32x32x8xf32>, tensor<16x!quant.uniform>) -> tensor<1x32x32x16xf32> + %3 = "tfl.transpose_conv"(%0, %1, %arg0, %2) {fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<4xi32>, tensor<16x1x1x8x!quant.uniform>, tensor<1x32x32x8xf32>, tensor<16x!quant.uniform>) -> tensor<1x32x32x16xf32> func.return %3 : tensor<1x32x32x16xf32> } diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/BUILD b/tensorflow/compiler/mlir/lite/tests/end2end/BUILD index 835c126d63b..b162606d135 100644 --- a/tensorflow/compiler/mlir/lite/tests/end2end/BUILD +++ b/tensorflow/compiler/mlir/lite/tests/end2end/BUILD @@ -1,6 +1,8 @@ load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) + licenses(["notice"]) glob_lit_tests( @@ -9,6 +11,9 @@ glob_lit_tests( ":test_utilities", ], driver = "@llvm-project//mlir:run_lit.sh", + size_override = { + "quant_stats.pbtxt": "medium", + }, tags_override = { "add.pbtxt": ["no_rocm"], "conv_2d.pbtxt": ["no_rocm"], diff --git 
a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD index 3d23fb1aeca..f7dbeaf48af 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD @@ -2,6 +2,8 @@ load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") load("//tensorflow:tensorflow.bzl", "tf_native_cc_binary") +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) + licenses(["notice"]) glob_lit_tests( diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/importer_test_min_max.cc b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/importer_test_min_max.cc index 302c0103603..ba0fd474a3a 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/importer_test_min_max.cc +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/importer_test_min_max.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include +#include #include "absl/strings/string_view.h" #include "llvm/Support/CommandLine.h" @@ -26,7 +27,6 @@ limitations under the License. #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_utils.h" -using llvm::Optional; using llvm::cl::opt; // RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s.mlir -o - \ @@ -49,12 +49,12 @@ static opt inputFileName(llvm::cl::Positional, namespace mlir { namespace { -Optional> InjectStatsToFullyConnected( +std::optional> InjectStatsToFullyConnected( llvm::StringRef buffer) { auto model_ptr = tflite::FlatBufferModel::VerifyAndBuildFromBuffer( buffer.data(), buffer.size()); if (nullptr == model_ptr) { - return llvm::None; + return std::nullopt; } std::unique_ptr model(model_ptr->GetModel()->UnPack()); @@ -161,7 +161,7 @@ int main(int argc, char** argv) { } flatbuffers::FlatBufferBuilder builder; flatbuffers::Offset output_model_location = - tflite::Model::Pack(builder, maybe_module.getValue().get()); + tflite::Model::Pack(builder, maybe_module.value().get()); tflite::FinishModelBuffer(builder, output_model_location); std::string output_model_content( reinterpret_cast(builder.GetBufferPointer()), diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.mlir index eb303dbd01d..97e3a647b04 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/lstm.mlir @@ -32,7 +32,7 @@ func.func @testFullyQuantizedLSTM(%arg0: tensor<1x528x!quant.uniform, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor, %arg14: tensor, %arg15: tensor, %arg16: tensor, %arg17: tensor, %arg18: tensor, %arg19: tensor, %arg20: tensor, %arg21: tensor, %arg22: tensor, %arg23: tensor) -> tensor { - // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {asymmetric_quantize_inputs = false, cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, 
input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor + // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {asymmetric_quantize_inputs = false, cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = false, effective_hidden_scale_intermediate = tensor<0x!quant.uniform>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor %0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform>, fused_activation_function = "TANH", input_to_cell_intermediate = tensor<0xf32>, input_to_forget_intermediate = tensor<0xf32>, input_to_input_intermediate = tensor<0xf32>, input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor func.return %0 : tensor } diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/multi_output_op.json b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/multi_output_op.json index fda0a759f85..7d39ccb9d48 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/multi_output_op.json +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/multi_output_op.json @@ -1,8 +1,11 @@ // RUN: json_to_flatbuffer %p/test_schema.fbs %s | flatbuffer_translate --tflite-flatbuffer-to-mlir --mlir-print-debuginfo -o - | FileCheck %s +// CHECK: #[[LOC0:.*]] = loc("":0:0) // CHECK: "tfl.split" // CHECK-SAME: loc(#[[SPLIT_LOC:.*]]) -// CHECK: #[[SPLIT_LOC]] = loc(fused["output0"("":0:0), "output1"("":0:0)]) +// CHECK: #[[LOC1:.*]] = loc("output0"(#[[LOC0]])) +// CHECK: #[[LOC2:.*]] = loc("output1"(#[[LOC0]])) +// CHECK: #[[SPLIT_LOC]] = loc(fused[#[[LOC1]], #[[LOC2]]]) { "version": 3, diff --git a/tensorflow/compiler/mlir/lite/tests/get-arithmetic-count.mlir b/tensorflow/compiler/mlir/lite/tests/get-arithmetic-count.mlir index 0d1503892b2..fd2075cb4f2 100644 --- a/tensorflow/compiler/mlir/lite/tests/get-arithmetic-count.mlir +++ b/tensorflow/compiler/mlir/lite/tests/get-arithmetic-count.mlir @@ -122,7 +122,7 @@ func.func @testAveragePool2D(tensor<1x10x10x3xf32>) -> tensor<1x10x10x3xf32> { func.func @testTransposeConv(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %arg2: tensor<1x32x42x128xf32>) -> tensor<1x64x84x32xf32> { %cst = "tfl.no_value"() {value = unit} : () -> none // CHECK: _arithmetic_count = 176160768 : i64 - %0 = "tfl.transpose_conv"(%arg0, 
%arg1, %arg2, %cst) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<1x64x84x32xf32> + %0 = "tfl.transpose_conv"(%arg0, %arg1, %arg2, %cst) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32, fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<1x64x84x32xf32> func.return %0 : tensor<1x64x84x32xf32> } diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index 660f8a6b7ab..c6562460a5f 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -1626,8 +1626,8 @@ func.func @conv2d_backprop_input(%arg0: tensor<4xi32>, %arg1: tensor<3x3x1x32xf3 // CHECK: %[[CST:.*]] = arith.constant dense<[2, 0, 1, 3]> : tensor<4xi32> // CHECK: %[[ARG0:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32> // CHECK: %[[CST_0:.*]] = "tfl.no_value"() {value} : () -> none - // CHECK: %[[ARG1:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32> - // CHECK: %[[ARG3:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) {padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32> + // CHECK: %[[ARG1:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) {fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32> + // CHECK: %[[ARG3:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) {fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32> // CHECK: %[[RESULT:.*]] = tfl.add %[[ARG1]], %[[ARG3]] {fused_activation_function = "NONE"} : tensor<15x28x28x1xf32> // CHECK: return %[[RESULT]] : tensor<15x28x28x1xf32> } @@ -2504,3 +2504,17 @@ func.func @sigmoidGrad(%arg0: tensor, %arg1: tensor) -> tens // CHECK-NEXT: [[MUL1:%.+]] = tfl.mul %arg1, [[MUL0]] {fused_activation_function = "NONE"} : tensor // CHECK: return [[MUL1]] } + +func.func @batchmatmul2fullyconnected(%arg0: tensor<4x128x2xf32>) -> (tensor<4x128x1xf32>) { + %0 = "tf.Const"() {value = dense<[[1.0], [2.0]]> : tensor<2x1xf32>} : () -> tensor<2x1xf32> + %1 = "tf.BatchMatMulV2"(%arg0, %0) : (tensor<4x128x2xf32>, tensor<2x1xf32>) -> tensor<4x128x1xf32> + func.return %1 : tensor<4x128x1xf32> + + // CHECK-LABEL: batchmatmul2fullyconnected + // CHECK-DAG: %cst_0 = arith.constant dense<[1, 0]> : tensor<2xi32> + // CHECK: %0 = "tfl.transpose"(%cst, %cst_0) : (tensor<2x1xf32>, tensor<2xi32>) -> tensor<1x2xf32> + // CHECK-DAG: %1 = "tfl.no_value"() {value} : () -> none + // CHECK: %2 = "tfl.fully_connected"(%arg0, %0, %1) {fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"} : (tensor<4x128x2xf32>, tensor<1x2xf32>, none) -> tensor<4x128x1xf32> + // CHECK: return %2 : tensor<4x128x1xf32> +} + diff --git a/tensorflow/compiler/mlir/lite/tests/lift_tflite_flex_ops.mlir b/tensorflow/compiler/mlir/lite/tests/lift_tflite_flex_ops.mlir index 
99498a96f84..9b49a3edc96 100644 --- a/tensorflow/compiler/mlir/lite/tests/lift_tflite_flex_ops.mlir +++ b/tensorflow/compiler/mlir/lite/tests/lift_tflite_flex_ops.mlir @@ -57,6 +57,30 @@ func.func @TfParseExample(%arg0: tensor<1x!tf_type.string>) -> (tensor<1x1x!tf_t // CHECK-SAME: operand_segment_sizes = array, result_segment_sizes = array } +// CHECK-LABEL: TfMapDataset +func.func @TfMapDataset(%arg0: tensor) -> (tensor) { + %0 = "tfl.custom"(%arg0) { + custom_code = "FlexMapDataset", + custom_option = #tfl + } : (tensor) -> tensor + + func.return %0 : tensor +// CHECK: "tf.MapDataset"( +// CHECK-SAME: {Targuments = [], f = @{{.*}}, metadata = "", output_shapes = [#tf_type.shape<>], output_types = [!tf_type.string], preserve_cardinality = true, use_inter_op_parallelism = true} +} + +// CHECK-LABEL: TfTakeWhileDataset +func.func @TfTakeWhileDataset(%arg0: tensor, %arg1: tensor) -> (tensor) { + %0 = "tfl.custom"(%arg0, %arg1) { + custom_code = "FlexTakeWhileDataset", + custom_option = #tfl + } : (tensor, tensor) -> tensor + + func.return %0 : tensor +// CHECK: "tf.TakeWhileDataset"( +// CHECK-SAME: {Targuments = [!tf_type.resource, !tf_type.resource, i64, !tf_type.resource, !tf_type.resource, !tf_type.resource, !tf_type.resource, i64], metadata = "", output_shapes = [#tf_type.shape<>], output_types = [!tf_type.string], predicate = @{{.*}}} +} + // CHECK-LABEL: FailureOnInvalidOp func.func @FailureOnInvalidOp(%arg0: tensor<4xf64>, %arg1: tensor<4xf64>) -> tensor<4xf64> { // expected-error@+1 can't find registered TF op for Nop diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2exec/BUILD b/tensorflow/compiler/mlir/lite/tests/mlir2exec/BUILD index cc69cd46b83..930e0f20b05 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2exec/BUILD +++ b/tensorflow/compiler/mlir/lite/tests/mlir2exec/BUILD @@ -9,6 +9,8 @@ load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) + licenses(["notice"]) glob_lit_tests( diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/BUILD b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/BUILD index 62aa45c0dc0..7e748ffe18d 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/BUILD +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/BUILD @@ -1,6 +1,8 @@ load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) + licenses(["notice"]) glob_lit_tests( diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/transpose_conv_optional.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/transpose_conv_optional.mlir index 2bb0aa766ff..975e959b052 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/transpose_conv_optional.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/transpose_conv_optional.mlir @@ -78,6 +78,6 @@ func.func @main(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %arg2: tens // CHECK-NEXT:} %cst = "tfl.no_value"() {value = unit} : () -> none - %0 = "tfl.transpose_conv"(%arg0, %arg1, %arg2, %cst) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<1x64x84x32xf32> + %0 = "tfl.transpose_conv"(%arg0, %arg1, %arg2, %cst) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32, 
fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<1x64x84x32xf32> func.return %0 : tensor<1x64x84x32xf32> } diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir index 0fb6a0595e8..3c3d50243a9 100644 --- a/tensorflow/compiler/mlir/lite/tests/ops.mlir +++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir @@ -1905,6 +1905,13 @@ func.func @testRelu6WithQuantizedTypes(%arg0 : tensor<10x!quant.uniform>) -> tensor<10x!quant.uniform> { + %0 = "tfl.relu_0_to_1"(%arg0) : (tensor<10x!quant.uniform>) -> tensor<10x!quant.uniform> + func.return %0 : tensor<10x!quant.uniform> +} + +// ----- + func.func @testReluWithDifferentScales(%arg0 : tensor<10x!quant.uniform>) -> tensor<10x!quant.uniform> { %0 = "tfl.relu"(%arg0) : (tensor<10x!quant.uniform>) -> tensor<10x!quant.uniform> %1 = "tfl.relu_n1_to_1"(%0) : (tensor<10x!quant.uniform>) -> tensor<10x!quant.uniform> @@ -2504,7 +2511,7 @@ func.func @testFullyConnectedWithBadOutputShape(%arg0: tensor<1x37xf32>, %arg1: func.func @testTransposeConv(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %arg2: tensor<1x32x42x128xf32>) -> tensor<1x64x84x32xf32> { %cst = "tfl.no_value"() {value = unit} : () -> none - %0 = "tfl.transpose_conv"(%arg0, %arg1, %arg2, %cst) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<1x64x84x32xf32> + %0 = "tfl.transpose_conv"(%arg0, %arg1, %arg2, %cst) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32, fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<1x64x84x32xf32> func.return %0 : tensor<1x64x84x32xf32> } @@ -2515,7 +2522,7 @@ func.func @testTransposeConvWithOutputThatHasDynamicSizes(%arg0: tensor<4xi32>, // CHECK: %[[NONE:.*]] = "tfl.no_value"() {value} : () -> none // CHECK: "tfl.transpose_conv"(%arg0, %arg1, %arg2, %[[NONE]]) %cst = "tfl.no_value"() {value = unit} : () -> none - %0 = "tfl.transpose_conv"(%arg0, %arg1, %arg2, %cst) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor + %0 = "tfl.transpose_conv"(%arg0, %arg1, %arg2, %cst) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32, fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor func.return %0 : tensor } @@ -2541,7 +2548,7 @@ func.func @testConvolution2DTransposeNoBias(%arg0: tensor<32x4x4x128xf32>, %arg1 func.func @testTransposeConvBadOutputRank(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %arg2: tensor<1x32x42x128xf32>) -> tensor<64x84x32xf32> { %cst = "tfl.no_value"() {value = unit} : () -> none // expected-error @+1 {{expect output type has rank = 4, got output type tensor<64x84x32xf32>}} - %0 = "tfl.transpose_conv"(%arg0, %arg1, %arg2, %cst) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<64x84x32xf32> + %0 = "tfl.transpose_conv"(%arg0, %arg1, %arg2, %cst) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32, fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<64x84x32xf32> func.return %0 : tensor<64x84x32xf32> } @@ -2551,7 +2558,7 @@ func.func @testTransposeConvBadOutputShape(%arg1: tensor<32x4x4x128xf32>, %arg2: %cst = arith.constant dense<[1, 64, 
84, 32]> : tensor<4xi32> %cst_1 = "tfl.no_value"() {value = unit} : () -> none // expected-error @+1 {{expect output type tensor<1x64x84x32xf32>, got tensor<1x64x84x31xf32>}} - %0 = "tfl.transpose_conv"(%cst, %arg1, %arg2, %cst_1) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<1x64x84x31xf32> + %0 = "tfl.transpose_conv"(%cst, %arg1, %arg2, %cst_1) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32, fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<1x64x84x31xf32> func.return %0 : tensor<1x64x84x31xf32> } diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index 4d4f7e88f47..b225c7355df 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -77,7 +77,7 @@ func.func @fuseAddIntoTransposeConv(%arg0: tensor<1x32x42x128xf32>) -> tensor<1x %cst_1 = arith.constant dense<[1, 64, 84, 32]> : tensor<4xi32> %cst_2 = arith.constant dense<1.0> : tensor<32x4x4x128xf32> %cst_3 = arith.constant dense<[1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0]> : tensor<32xf32> - %0 = "tfl.transpose_conv"(%cst_1, %cst_2, %arg0, %cst_3) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<32xf32>) -> tensor<1x64x84x32xf32> + %0 = "tfl.transpose_conv"(%cst_1, %cst_2, %arg0, %cst_3) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32, fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<32xf32>) -> tensor<1x64x84x32xf32> %1 = "tfl.add"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<1x64x84x32xf32>, tensor<32xf32>) -> tensor<1x64x84x32xf32> func.return %1 : tensor<1x64x84x32xf32> @@ -95,7 +95,7 @@ func.func @fuseSubIntoTransposeConv(%arg0: tensor<1x32x42x128xf32>) -> tensor<1x %cst_1 = arith.constant dense<[1, 64, 84, 32]> : tensor<4xi32> %cst_2 = arith.constant dense<1.0> : tensor<32x4x4x128xf32> %cst_3 = arith.constant dense<[1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0]> : tensor<32xf32> - %0 = "tfl.transpose_conv"(%cst_1, %cst_2, %arg0, %cst_3) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<32xf32>) -> tensor<1x64x84x32xf32> + %0 = "tfl.transpose_conv"(%cst_1, %cst_2, %arg0, %cst_3) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32, fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<32xf32>) -> tensor<1x64x84x32xf32> %1 = "tfl.sub"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<1x64x84x32xf32>, tensor<32xf32>) -> tensor<1x64x84x32xf32> func.return %1 : tensor<1x64x84x32xf32> @@ -113,7 +113,7 @@ func.func @fuseAddIntoTransposeConvNoBias(%arg0: tensor<1x32x42x128xf32>) -> ten %cst_1 = arith.constant dense<[1, 64, 84, 32]> : tensor<4xi32> %cst_2 = arith.constant dense<1.0> : tensor<32x4x4x128xf32> %cst_3 = "tfl.no_value"() {value} : () -> none - %0 = "tfl.transpose_conv"(%cst_1, %cst_2, %arg0, %cst_3) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, 
tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<1x64x84x32xf32> + %0 = "tfl.transpose_conv"(%cst_1, %cst_2, %arg0, %cst_3) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32, fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<1x64x84x32xf32> %1 = "tfl.add"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<1x64x84x32xf32>, tensor<32xf32>) -> tensor<1x64x84x32xf32> func.return %1 : tensor<1x64x84x32xf32> @@ -131,7 +131,7 @@ func.func @fuseMulIntoTransposeConv(%arg0: tensor<1x32x42x128xf32>) -> tensor<1x %cst_1 = arith.constant dense<[1, 64, 84, 32]> : tensor<4xi32> %cst_2 = arith.constant dense<1.0> : tensor<32x4x4x128xf32> %cst_3 = arith.constant dense<[1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0]> : tensor<32xf32> - %0 = "tfl.transpose_conv"(%cst_1, %cst_2, %arg0, %cst_3) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<32xf32>) -> tensor<1x64x84x32xf32> + %0 = "tfl.transpose_conv"(%cst_1, %cst_2, %arg0, %cst_3) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32, fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<32xf32>) -> tensor<1x64x84x32xf32> %1 = "tfl.mul"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<1x64x84x32xf32>, tensor<32xf32>) -> tensor<1x64x84x32xf32> func.return %1 : tensor<1x64x84x32xf32> @@ -149,7 +149,7 @@ func.func @fuseMulIntoTransposeConvNoBias(%arg0: tensor<1x32x42x128xf32>) -> ten %cst_1 = arith.constant dense<[1, 64, 84, 32]> : tensor<4xi32> %cst_2 = arith.constant dense<1.0> : tensor<32x4x4x128xf32> %cst_3 = "tfl.no_value"() {value} : () -> none - %0 = "tfl.transpose_conv"(%cst_1, %cst_2, %arg0, %cst_3) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<1x64x84x32xf32> + %0 = "tfl.transpose_conv"(%cst_1, %cst_2, %arg0, %cst_3) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32, fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<1x64x84x32xf32> %1 = "tfl.mul"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<1x64x84x32xf32>, tensor<32xf32>) -> tensor<1x64x84x32xf32> func.return %1 : tensor<1x64x84x32xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir b/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir index e1e4036a43e..de372f38daf 100644 --- a/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir @@ -161,7 +161,7 @@ func.func @FoldTranspose(%arg0: tensor<1x10x20x3xf32>) -> tensor<1x20x40x16xf32> %1 = "tfl.pseudo_qconst"() {qtype = tensor<3x3x16x3x!quant.uniform:f32, 0.047244094488188976>>, value = 
dense<"0x0303040002010303FFFFFD0304020401FF0000FEFF0003FF01FD0203FF0202FEFE0003010201FD04FE0402030303000202FD0100FDFE0402FEFEFE01020101FD0204FEFDFC03FFFE0101FDFE02040002FDFFFE03FFFE0201FEFDFF00FFFDFEFD030201FD01FC01FF010003FF0401FCFD0101FC0000FE03FEFE010102000002FE02030100FE00FEFDFD0003FD000303000103FE01FF02000002FF0101FDFDFF02FFFF00000203FF0003030302FDFF03FFFF030001020102FD04FE0104FE030401030102FEFCFEFD03FD03FD000102FE02020001020000FE030202030103FFFC01FC000302000304FCFF03FD04FC00010400010100030303FC02FCFEFE01000303000100010003FE000303010301010102FEFC01FD020301FFFDFFFCFDFEFCFE030001FDFCFE000202FE020300FD00FD02FF0001FF0002FF01FD010102FDFE04FCFE0000FD01000101FF0402FF020103FC020301FF03010204FDFFFE0202FF0302FF02FFFF01FF01FF04FD0002FF00FC00FC0101010404FE03040300000301FD0001FE04FF040103FF01FD0301FF0002040403FF03FE04FDFD0103FCFE01FDFCFF03FC010200FDFE020200FF00FFFC03FE"> : tensor<3x3x16x3xi8>} : () -> tensor<3x3x16x3x!quant.uniform:f32, 0.047244094488188976>> %2 = "tfl.quantize"(%arg0) {qtype = tensor<1x10x20x3x!quant.uniform>} : (tensor<1x10x20x3xf32>) -> tensor<1x10x20x3x!quant.uniform> %3 = "tfl.transpose"(%1, %cst_0) : (tensor<3x3x16x3x!quant.uniform:f32, 0.047244094488188976>>, tensor<4xi32>) -> tensor<16x3x3x3x!quant.uniform:f32, 0.047244094488188976>> - %4 = "tfl.transpose_conv"(%cst, %3, %2, %0) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<16x3x3x3x!quant.uniform:f32, 0.047244094488188976>>, tensor<1x10x20x3x!quant.uniform>, tensor<16x!quant.uniform>) -> tensor<1x20x40x16x!quant.uniform> + %4 = "tfl.transpose_conv"(%cst, %3, %2, %0) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32, fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<16x3x3x3x!quant.uniform:f32, 0.047244094488188976>>, tensor<1x10x20x3x!quant.uniform>, tensor<16x!quant.uniform>) -> tensor<1x20x40x16x!quant.uniform> %5 = "tfl.dequantize"(%4) : (tensor<1x20x40x16x!quant.uniform>) -> tensor<1x20x40x16xf32> return %5 : tensor<1x20x40x16xf32> @@ -175,7 +175,7 @@ func.func @FoldReshape(%arg0: tensor<4xi32>, %arg1: tensor<1x48x80x16x!quant.uni %cst = arith.constant dense<[1, 2, 2, 16]> : tensor<4xi32> %0 = "tfl.pseudo_qconst"() {qtype = tensor<2x2x1x16x!quant.uniform:f32, 0.022395913056501255>>, value = dense<[[[[12, -60, -51, -59, -62, 33, 53, 17, -31, 50, 27, 7, -19, -34, -14, -26]], [[47, -84, -32, -36, -102, -8, -8, 35, -33, 59, 95, 40, -25, -30, -55, 25]]], [[[4, -41, -61, 12, -23, 48, 40, 15, -39, 52, 81, -62, -24, 17, -7, -52]], [[40, -70, -45, 32, -43, 2, -30, 34, -35, 58, 77, -28, -30, 37, -47, -5]]]]> : tensor<2x2x1x16xi8>} : () -> tensor<2x2x1x16x!quant.uniform:f32, 0.022395913056501255>> %1 = "tfl.reshape"(%0, %cst) : (tensor<2x2x1x16x!quant.uniform:f32, 0.022395913056501255>>, tensor<4xi32>) -> tensor<1x2x2x16x!quant.uniform:f32, 0.022395913056501255>> - %2 = "tfl.transpose_conv"(%arg0, %1, %arg1, %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x2x2x16x!quant.uniform:f32, 0.022395913056501255>>, tensor<1x48x80x16x!quant.uniform>, tensor<1x!quant.uniform>) -> tensor<1x96x160x1x!quant.uniform> + %2 = "tfl.transpose_conv"(%arg0, %1, %arg1, %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32, fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<1x2x2x16x!quant.uniform:f32, 0.022395913056501255>>, tensor<1x48x80x16x!quant.uniform>, tensor<1x!quant.uniform>) -> tensor<1x96x160x1x!quant.uniform> return %2 : tensor<1x96x160x1x!quant.uniform> // CHECK-NOT: "tfl.reshape" // CHECK{LITERAL}: 
"tfl.pseudo_qconst"() {qtype = tensor<1x2x2x16x!quant.uniform:f32, 0.022395913056501255>>, value = dense<[[[[12, -60, -51, -59, -62, 33, 53, 17, -31, 50, 27, 7, -19, -34, -14, -26], [47, -84, -32, -36, -102, -8, -8, 35, -33, 59, 95, 40, -25, -30, -55, 25]], [[4, -41, -61, 12, -23, 48, 40, 15, -39, 52, 81, -62, -24, 17, -7, -52], [40, -70, -45, 32, -43, 2, -30, 34, -35, 58, 77, -28, -30, 37, -47, -5]]]]> : tensor<1x2x2x16xi8>} : () -> tensor<1x2x2x16x!quant.uniform:f32, 0.022395913056501255>> diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir index 99c6f24481e..7aac8662a83 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir @@ -204,7 +204,7 @@ func.func @inference_standard_lstm_time_major(%arg0: tensor, %arg1: t // CHECK-DAG: [[VAL_17:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor // CHECK: [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) // CHECK: [[VAL_19:%.*]] = "tfl.no_value"() {value} : () -> none -// CHECK: [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor +// CHECK: [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) {cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = false, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor // CHECK-DAG: [[VAL_21:%.*]] = arith.constant dense<[-1, 0, 0]> : tensor<3xi32> // CHECK-DAG: [[VAL_22:%.*]] = arith.constant dense<0> : tensor<3xi32> // CHECK-DAG: [[VAL_23:%.*]] = arith.constant dense<1> : tensor<3xi32> @@ -218,6 +218,61 @@ func.func @inference_standard_lstm_time_major(%arg0: tensor, %arg1: t // ----- +module { +func.func @inference_standard_indy_lstm_time_major(%arg0: tensor<8x8x8xf32>, %arg1: tensor<8x10xf32>, %arg2: tensor<8x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x4xf32>, %arg5: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, 
tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "indy_lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} { + %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<8x8x8xf32>, tensor<8x40xf32>) -> tensor<8x8x40xf32> + %1 = "tf.Add"(%0, %arg5) : (tensor<8x8x40xf32>, tensor<40xf32>) -> tensor<8x8x40xf32> + %2 = "tf.Const"() { value = dense<[1, 0]> : tensor<2xi32> } : () -> tensor<2xi32> + %3 = "tf.Transpose"(%arg4, %2) : (tensor<10x4xf32>, tensor<2xi32>) -> tensor<4x10xf32> + %4 = "tf.MatrixDiag"(%3) : (tensor<4x10xf32>) -> tensor<4x10x10xf32> + %5 = "tf.Const"() { value = dense<0> : tensor } : () -> tensor + %6 = "tf.ConcatV2"(%4, %5) : (tensor<4x10x10xf32>, tensor) -> tensor<40x10xf32> + %7 = "tf.BatchMatMulV2"(%1, %6) {adj_x = false, adj_y = false} : (tensor<8x8x40xf32>, tensor<40x10xf32>) -> tensor<8x8x10xf32> + %8 = "tf.Add"(%7, %arg1) : (tensor<8x8x10xf32>, tensor<8x10xf32>) -> tensor<8x8x10xf32> + %9 = "tf.Add"(%7, %arg2) : (tensor<8x8x10xf32>, tensor<8x10xf32>) -> tensor<8x8x10xf32> + %10 = "tf.Add"(%arg1, %arg2) : (tensor<8x10xf32>, tensor<8x10xf32>) -> tensor<8x10xf32> + %11 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor} : () -> tensor + func.return %10, %9, %10, %10, %11 : tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor +} + +// CHECK: func @inference_standard_indy_lstm_time_major([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor<8x10xf32>, [[VAL_2:%.*]]: tensor<8x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x4xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "indy_lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} { +// CHECK: [[VAL_6:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32> +// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi32>) -> tensor<40x8xf32> +// CHECK: [[VAL_8:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32> +// CHECK: [[VAL_9:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_8]]) : (tensor<10x4xf32>, tensor<2xi32>) -> tensor<4x10xf32> +// CHECK-DAG: [[VAL_10:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32> +// CHECK-DAG: [[VAL_11:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: [[VAL_12:%.*]]:4 = "tf.SplitV"([[VAL_7]], [[VAL_10]], [[VAL_11]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>) +// CHECK-DAG: [[VAL_13:%.*]] = "tf.Const"() {value = dense<1> : tensor<4xi32>} : () -> tensor<4xi32> +// CHECK-DAG: [[VAL_14:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: [[VAL_15:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_13]], [[VAL_14]]) : (tensor<4x10xf32>, tensor<4xi32>, 
tensor) -> (tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>) +// CHECK-DAG: [[VAL_20:%.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> +// CHECK: [[VAL_21:%.*]] = "tf.Reshape"([[VAL_15]]#0, [[VAL_20]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32> +// CHECK-DAG: [[VAL_22:%.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> +// CHECK: [[VAL_23:%.*]] = "tf.Reshape"([[VAL_15]]#1, [[VAL_22]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32> +// CHECK-DAG: [[VAL_24:%.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> +// CHECK: [[VAL_25:%.*]] = "tf.Reshape"([[VAL_15]]#2, [[VAL_24]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32> +// CHECK-DAG: [[VAL_26:%.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> +// CHECK: [[VAL_27:%.*]] = "tf.Reshape"([[VAL_15]]#3, [[VAL_26]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32> +// CHECK-DAG: [[VAL_28:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32> +// CHECK-DAG: [[VAL_29:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: [[VAL_30:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_28]], [[VAL_29]]) : (tensor<40xf32>, tensor<4xi32>, tensor) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) +// CHECK: [[VAL_31:%.*]] = "tfl.no_value"() {value} : () -> none +// CHECK: [[VAL_32:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_21]], [[VAL_23]], [[VAL_25]], [[VAL_27]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_30]]#0, [[VAL_30]]#1, [[VAL_30]]#2, [[VAL_30]]#3, [[VAL_31]], [[VAL_31]], [[VAL_1]], [[VAL_2]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_31]]) {cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = true, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<8x8x10xf32> +// CHECK-DAG: [[VAL_33:%.*]] = arith.constant dense<[-1, 0, 0]> : tensor<3xi32> +// CHECK-DAG: [[VAL_34:%.*]] = arith.constant dense<0> : tensor<3xi32> +// CHECK-DAG: [[VAL_35:%.*]] = arith.constant dense<1> : tensor<3xi32> +// CHECK: [[VAL_36:%.*]] = "tf.StridedSlice"([[VAL_32]], [[VAL_33]], [[VAL_34]], [[VAL_35]]) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<8x8x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32> +// CHECK-DAG: [[VAL_37:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32> +// CHECK-DAG: [[VAL_38:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32> +// CHECK-DAG: [[VAL_39:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor +// CHECK: return [[VAL_36]], [[VAL_32]], [[VAL_37]], [[VAL_38]], [[VAL_39]] : tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor +// CHECK: } + +} + +// ----- + module { func.func @inference_standard_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg1: tensor<8x10xf32>, %arg2: tensor<8x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> 
(tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = false} { %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<8x8x8xf32>, tensor<8x40xf32>) -> tensor<8x8x40xf32> @@ -245,7 +300,7 @@ func.func @inference_standard_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg // CHECK-DAG: [[VAL_17:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor // CHECK: [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) // CHECK: [[VAL_19:%.*]] = "tfl.no_value"() {value} : () -> none -// CHECK: [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<8x8x10xf32> +// CHECK: [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) {cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = false, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<8x8x10xf32> // CHECK-DAG: [[VAL_21:%.*]] = arith.constant dense<[0, -1, 0]> : tensor<3xi32> // CHECK-DAG: [[VAL_22:%.*]] = arith.constant dense<0> : tensor<3xi32> // CHECK-DAG: [[VAL_23:%.*]] = arith.constant dense<1> : tensor<3xi32> @@ -260,6 +315,61 @@ func.func @inference_standard_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg // ----- +module { +func.func @inference_standard_indy_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg1: tensor<8x10xf32>, %arg2: tensor<8x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x4xf32>, %arg5: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 
}", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "indy_lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = false} { + %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<8x8x8xf32>, tensor<8x40xf32>) -> tensor<8x8x40xf32> + %1 = "tf.Add"(%0, %arg5) : (tensor<8x8x40xf32>, tensor<40xf32>) -> tensor<8x8x40xf32> + %2 = "tf.Const"() { value = dense<[1, 0]> : tensor<2xi32> } : () -> tensor<2xi32> + %3 = "tf.Transpose"(%arg4, %2) : (tensor<10x4xf32>, tensor<2xi32>) -> tensor<4x10xf32> + %4 = "tf.MatrixDiag"(%3) : (tensor<4x10xf32>) -> tensor<4x10x10xf32> + %5 = "tf.Const"() { value = dense<0> : tensor } : () -> tensor + %6 = "tf.ConcatV2"(%4, %5) : (tensor<4x10x10xf32>, tensor) -> tensor<40x10xf32> + %7 = "tf.BatchMatMulV2"(%1, %6) {adj_x = false, adj_y = false} : (tensor<8x8x40xf32>, tensor<40x10xf32>) -> tensor<8x8x10xf32> + %8 = "tf.Add"(%7, %arg1) : (tensor<8x8x10xf32>, tensor<8x10xf32>) -> tensor<8x8x10xf32> + %9 = "tf.Add"(%7, %arg2) : (tensor<8x8x10xf32>, tensor<8x10xf32>) -> tensor<8x8x10xf32> + %10 = "tf.Add"(%arg1, %arg2) : (tensor<8x10xf32>, tensor<8x10xf32>) -> tensor<8x10xf32> + %11 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor} : () -> tensor + func.return %10, %9, %10, %10, %11 : tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor +} + +// CHECK: func @inference_standard_indy_lstm_non_time_major([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor<8x10xf32>, [[VAL_2:%.*]]: tensor<8x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x4xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "indy_lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = false} { +// CHECK: [[VAL_6:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32> +// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi32>) -> tensor<40x8xf32> +// CHECK: [[VAL_8:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32> +// CHECK: [[VAL_9:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_8]]) : (tensor<10x4xf32>, tensor<2xi32>) -> tensor<4x10xf32> +// CHECK-DAG: [[VAL_10:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32> +// CHECK-DAG: [[VAL_11:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: [[VAL_12:%.*]]:4 = "tf.SplitV"([[VAL_7]], [[VAL_10]], [[VAL_11]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>) +// CHECK-DAG: [[VAL_13:%.*]] = "tf.Const"() {value = dense<1> : tensor<4xi32>} : () -> tensor<4xi32> +// CHECK-DAG: [[VAL_14:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: [[VAL_15:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_13]], [[VAL_14]]) : (tensor<4x10xf32>, tensor<4xi32>, tensor) -> (tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>) +// CHECK-DAG: [[VAL_20:%.*]] = "tf.Const"() 
{value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> +// CHECK: [[VAL_21:%.*]] = "tf.Reshape"([[VAL_15]]#0, [[VAL_20]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32> +// CHECK-DAG: [[VAL_22:%.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> +// CHECK: [[VAL_23:%.*]] = "tf.Reshape"([[VAL_15]]#1, [[VAL_22]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32> +// CHECK-DAG: [[VAL_24:%.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> +// CHECK: [[VAL_25:%.*]] = "tf.Reshape"([[VAL_15]]#2, [[VAL_24]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32> +// CHECK-DAG: [[VAL_26:%.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> +// CHECK: [[VAL_27:%.*]] = "tf.Reshape"([[VAL_15]]#3, [[VAL_26]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32> +// CHECK-DAG: [[VAL_28:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32> +// CHECK-DAG: [[VAL_29:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: [[VAL_30:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_28]], [[VAL_29]]) : (tensor<40xf32>, tensor<4xi32>, tensor) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) +// CHECK: [[VAL_31:%.*]] = "tfl.no_value"() {value} : () -> none +// CHECK: [[VAL_32:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_21]], [[VAL_23]], [[VAL_25]], [[VAL_27]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_30]]#0, [[VAL_30]]#1, [[VAL_30]]#2, [[VAL_30]]#3, [[VAL_31]], [[VAL_31]], [[VAL_1]], [[VAL_2]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_31]]) {cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = true, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<8x8x10xf32> +// CHECK-DAG: [[VAL_33:%.*]] = arith.constant dense<[0, -1, 0]> : tensor<3xi32> +// CHECK-DAG: [[VAL_34:%.*]] = arith.constant dense<0> : tensor<3xi32> +// CHECK-DAG: [[VAL_35:%.*]] = arith.constant dense<1> : tensor<3xi32> +// CHECK: [[VAL_36:%.*]] = "tf.StridedSlice"([[VAL_32]], [[VAL_33]], [[VAL_34]], [[VAL_35]]) {begin_mask = 5 : i64, ellipsis_mask = 0 : i64, end_mask = 5 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 2 : i64} : (tensor<8x8x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32> +// CHECK-DAG: [[VAL_37:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32> +// CHECK-DAG: [[VAL_38:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32> +// CHECK-DAG: [[VAL_39:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor +// CHECK: return [[VAL_36]], [[VAL_32]], [[VAL_37]], [[VAL_38]], [[VAL_39]] : tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor +// CHECK: } + +} + +// ----- + module { func.func @inference_standard_lstm_time_major_go_backwards(%arg0: tensor, %arg1: tensor<8x10xf32>, %arg2: tensor<8x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<8x10xf32>, tensor, tensor<8x10xf32>, tensor<8x10xf32>, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { 
size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} { %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor, tensor<8x40xf32>) -> tensor @@ -289,7 +399,7 @@ func.func @inference_standard_lstm_time_major_go_backwards(%arg0: tensor : tensor} : () -> tensor // CHECK: [[VAL_20:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_18]], [[VAL_19]]) : (tensor<40xf32>, tensor<4xi32>, tensor) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) // CHECK: [[VAL_21:%.*]] = "tfl.no_value"() {value} : () -> none -// CHECK: [[VAL_22:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_7]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor +// CHECK: [[VAL_22:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_7]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) {cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = false, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor // CHECK-DAG: [[VAL_23:%.*]] = arith.constant dense<[-1, 0, 0]> : tensor<3xi32> // CHECK-DAG: [[VAL_24:%.*]] = arith.constant dense<0> : tensor<3xi32> // CHECK-DAG: [[VAL_25:%.*]] = arith.constant dense<1> : tensor<3xi32> @@ -304,6 +414,63 @@ func.func @inference_standard_lstm_time_major_go_backwards(%arg0: tensor, %arg1: tensor<8x10xf32>, %arg2: tensor<8x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x4xf32>, %arg5: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "indy_lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} { + %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<8x8x8xf32>, 
tensor<8x40xf32>) -> tensor<8x8x40xf32> + %1 = "tf.Add"(%0, %arg5) : (tensor<8x8x40xf32>, tensor<40xf32>) -> tensor<8x8x40xf32> + %2 = "tf.Const"() { value = dense<[1, 0]> : tensor<2xi32> } : () -> tensor<2xi32> + %3 = "tf.Transpose"(%arg4, %2) : (tensor<10x4xf32>, tensor<2xi32>) -> tensor<4x10xf32> + %4 = "tf.MatrixDiag"(%3) : (tensor<4x10xf32>) -> tensor<4x10x10xf32> + %5 = "tf.Const"() { value = dense<0> : tensor } : () -> tensor + %6 = "tf.ConcatV2"(%4, %5) : (tensor<4x10x10xf32>, tensor) -> tensor<40x10xf32> + %7 = "tf.BatchMatMulV2"(%1, %6) {adj_x = false, adj_y = false} : (tensor<8x8x40xf32>, tensor<40x10xf32>) -> tensor<8x8x10xf32> + %8 = "tf.Add"(%7, %arg1) : (tensor<8x8x10xf32>, tensor<8x10xf32>) -> tensor<8x8x10xf32> + %9 = "tf.Add"(%7, %arg2) : (tensor<8x8x10xf32>, tensor<8x10xf32>) -> tensor<8x8x10xf32> + %10 = "tf.Add"(%arg1, %arg2) : (tensor<8x10xf32>, tensor<8x10xf32>) -> tensor<8x10xf32> + %11 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor} : () -> tensor + func.return %10, %9, %10, %10, %11 : tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor +} + +// CHECK: func @inference_standard_indy_lstm_time_major_go_backwards([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor<8x10xf32>, [[VAL_2:%.*]]: tensor<8x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x4xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "indy_lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} { +// CHECK: [[VAL_40:%.*]] = arith.constant dense<0> : tensor<1xi32> +// CHECK: [[VAL_41:%.*]] = "tf.ReverseV2"([[VAL_0]], [[VAL_40]]) : (tensor<8x8x8xf32>, tensor<1xi32>) -> tensor<8x8x8xf32> +// CHECK: [[VAL_6:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32> +// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi32>) -> tensor<40x8xf32> +// CHECK: [[VAL_8:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32> +// CHECK: [[VAL_9:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_8]]) : (tensor<10x4xf32>, tensor<2xi32>) -> tensor<4x10xf32> +// CHECK-DAG: [[VAL_10:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32> +// CHECK-DAG: [[VAL_11:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: [[VAL_12:%.*]]:4 = "tf.SplitV"([[VAL_7]], [[VAL_10]], [[VAL_11]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>) +// CHECK-DAG: [[VAL_13:%.*]] = "tf.Const"() {value = dense<1> : tensor<4xi32>} : () -> tensor<4xi32> +// CHECK-DAG: [[VAL_14:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: [[VAL_15:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_13]], [[VAL_14]]) : (tensor<4x10xf32>, tensor<4xi32>, tensor) -> (tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>) +// CHECK-DAG: [[VAL_20:%.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> +// CHECK: [[VAL_21:%.*]] = "tf.Reshape"([[VAL_15]]#0, [[VAL_20]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32> +// CHECK-DAG: 
[[VAL_22:%.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> +// CHECK: [[VAL_23:%.*]] = "tf.Reshape"([[VAL_15]]#1, [[VAL_22]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32> +// CHECK-DAG: [[VAL_24:%.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> +// CHECK: [[VAL_25:%.*]] = "tf.Reshape"([[VAL_15]]#2, [[VAL_24]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32> +// CHECK-DAG: [[VAL_26:%.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> +// CHECK: [[VAL_27:%.*]] = "tf.Reshape"([[VAL_15]]#3, [[VAL_26]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32> +// CHECK-DAG: [[VAL_28:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32> +// CHECK-DAG: [[VAL_29:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: [[VAL_30:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_28]], [[VAL_29]]) : (tensor<40xf32>, tensor<4xi32>, tensor) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) +// CHECK: [[VAL_31:%.*]] = "tfl.no_value"() {value} : () -> none +// CHECK: [[VAL_32:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_41]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_21]], [[VAL_23]], [[VAL_25]], [[VAL_27]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_30]]#0, [[VAL_30]]#1, [[VAL_30]]#2, [[VAL_30]]#3, [[VAL_31]], [[VAL_31]], [[VAL_1]], [[VAL_2]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_31]]) {cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = true, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<8x8x10xf32> +// CHECK-DAG: [[VAL_33:%.*]] = arith.constant dense<[-1, 0, 0]> : tensor<3xi32> +// CHECK-DAG: [[VAL_34:%.*]] = arith.constant dense<0> : tensor<3xi32> +// CHECK-DAG: [[VAL_35:%.*]] = arith.constant dense<1> : tensor<3xi32> +// CHECK: [[VAL_36:%.*]] = "tf.StridedSlice"([[VAL_32]], [[VAL_33]], [[VAL_34]], [[VAL_35]]) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<8x8x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32> +// CHECK-DAG: [[VAL_37:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32> +// CHECK-DAG: [[VAL_38:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32> +// CHECK-DAG: [[VAL_39:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor +// CHECK: return [[VAL_36]], [[VAL_32]], [[VAL_37]], [[VAL_38]], [[VAL_39]] : tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor +// CHECK: } + +} + +// ----- + module { func.func @inference_standard_lstm_non_time_major_go_backwards(%arg0: tensor<8x8x8xf32>, %arg1: tensor<8x10xf32>, %arg2: tensor<8x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: 
false", "tfshape$unknown_rank: false"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} { %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<8x8x8xf32>, tensor<8x40xf32>) -> tensor<8x8x40xf32> @@ -333,7 +500,7 @@ func.func @inference_standard_lstm_non_time_major_go_backwards(%arg0: tensor<8x8 // CHECK-DAG: [[VAL_19:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor // CHECK: [[VAL_20:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_18]], [[VAL_19]]) : (tensor<40xf32>, tensor<4xi32>, tensor) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) // CHECK: [[VAL_21:%.*]] = "tfl.no_value"() {value} : () -> none -// CHECK: [[VAL_22:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_7]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<8x8x10xf32> +// CHECK: [[VAL_22:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_7]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) {cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = false, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<8x8x10xf32> // CHECK-DAG: [[VAL_23:%.*]] = arith.constant dense<[0, -1, 0]> : tensor<3xi32> // CHECK-DAG: [[VAL_24:%.*]] = arith.constant dense<0> : tensor<3xi32> // CHECK-DAG: [[VAL_25:%.*]] = arith.constant dense<1> : tensor<3xi32> @@ -348,6 +515,63 @@ func.func @inference_standard_lstm_non_time_major_go_backwards(%arg0: tensor<8x8 // ----- +module { +func.func @inference_standard_indy_lstm_non_time_major_go_backwards(%arg0: tensor<8x8x8xf32>, %arg1: tensor<8x10xf32>, %arg2: tensor<8x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x4xf32>, %arg5: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "indy_lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} { + %0 = 
"tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<8x8x8xf32>, tensor<8x40xf32>) -> tensor<8x8x40xf32> + %1 = "tf.Add"(%0, %arg5) : (tensor<8x8x40xf32>, tensor<40xf32>) -> tensor<8x8x40xf32> + %2 = "tf.Const"() { value = dense<[1, 0]> : tensor<2xi32> } : () -> tensor<2xi32> + %3 = "tf.Transpose"(%arg4, %2) : (tensor<10x4xf32>, tensor<2xi32>) -> tensor<4x10xf32> + %4 = "tf.MatrixDiag"(%3) : (tensor<4x10xf32>) -> tensor<4x10x10xf32> + %5 = "tf.Const"() { value = dense<0> : tensor } : () -> tensor + %6 = "tf.ConcatV2"(%4, %5) : (tensor<4x10x10xf32>, tensor) -> tensor<40x10xf32> + %7 = "tf.BatchMatMulV2"(%1, %6) {adj_x = false, adj_y = false} : (tensor<8x8x40xf32>, tensor<40x10xf32>) -> tensor<8x8x10xf32> + %8 = "tf.Add"(%7, %arg1) : (tensor<8x8x10xf32>, tensor<8x10xf32>) -> tensor<8x8x10xf32> + %9 = "tf.Add"(%7, %arg2) : (tensor<8x8x10xf32>, tensor<8x10xf32>) -> tensor<8x8x10xf32> + %10 = "tf.Add"(%arg1, %arg2) : (tensor<8x10xf32>, tensor<8x10xf32>) -> tensor<8x10xf32> + %11 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor} : () -> tensor + func.return %10, %9, %10, %10, %11 : tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor +} + +// CHECK: func @inference_standard_indy_lstm_non_time_major_go_backwards([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor<8x10xf32>, [[VAL_2:%.*]]: tensor<8x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x4xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "indy_lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} { +// CHECK: [[VAL_40:%.*]] = arith.constant dense<1> : tensor<1xi32> +// CHECK: [[VAL_41:%.*]] = "tf.ReverseV2"([[VAL_0]], [[VAL_40]]) : (tensor<8x8x8xf32>, tensor<1xi32>) -> tensor<8x8x8xf32> +// CHECK: [[VAL_6:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32> +// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi32>) -> tensor<40x8xf32> +// CHECK: [[VAL_8:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32> +// CHECK: [[VAL_9:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_8]]) : (tensor<10x4xf32>, tensor<2xi32>) -> tensor<4x10xf32> +// CHECK-DAG: [[VAL_10:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32> +// CHECK-DAG: [[VAL_11:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: [[VAL_12:%.*]]:4 = "tf.SplitV"([[VAL_7]], [[VAL_10]], [[VAL_11]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>) +// CHECK-DAG: [[VAL_13:%.*]] = "tf.Const"() {value = dense<1> : tensor<4xi32>} : () -> tensor<4xi32> +// CHECK-DAG: [[VAL_14:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: [[VAL_15:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_13]], [[VAL_14]]) : (tensor<4x10xf32>, tensor<4xi32>, tensor) -> (tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>) +// CHECK-DAG: [[VAL_20:%.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> +// CHECK: [[VAL_21:%.*]] = 
"tf.Reshape"([[VAL_15]]#0, [[VAL_20]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32> +// CHECK-DAG: [[VAL_22:%.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> +// CHECK: [[VAL_23:%.*]] = "tf.Reshape"([[VAL_15]]#1, [[VAL_22]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32> +// CHECK-DAG: [[VAL_24:%.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> +// CHECK: [[VAL_25:%.*]] = "tf.Reshape"([[VAL_15]]#2, [[VAL_24]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32> +// CHECK-DAG: [[VAL_26:%.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> +// CHECK: [[VAL_27:%.*]] = "tf.Reshape"([[VAL_15]]#3, [[VAL_26]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32> +// CHECK-DAG: [[VAL_28:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32> +// CHECK-DAG: [[VAL_29:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: [[VAL_30:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_28]], [[VAL_29]]) : (tensor<40xf32>, tensor<4xi32>, tensor) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) +// CHECK: [[VAL_31:%.*]] = "tfl.no_value"() {value} : () -> none +// CHECK: [[VAL_32:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_41]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_21]], [[VAL_23]], [[VAL_25]], [[VAL_27]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_30]]#0, [[VAL_30]]#1, [[VAL_30]]#2, [[VAL_30]]#3, [[VAL_31]], [[VAL_31]], [[VAL_1]], [[VAL_2]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_31]]) {cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = true, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<8x8x10xf32> +// CHECK-DAG: [[VAL_33:%.*]] = arith.constant dense<[0, -1, 0]> : tensor<3xi32> +// CHECK-DAG: [[VAL_34:%.*]] = arith.constant dense<0> : tensor<3xi32> +// CHECK-DAG: [[VAL_35:%.*]] = arith.constant dense<1> : tensor<3xi32> +// CHECK: [[VAL_36:%.*]] = "tf.StridedSlice"([[VAL_32]], [[VAL_33]], [[VAL_34]], [[VAL_35]]) {begin_mask = 5 : i64, ellipsis_mask = 0 : i64, end_mask = 5 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 2 : i64} : (tensor<8x8x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32> +// CHECK-DAG: [[VAL_37:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32> +// CHECK-DAG: [[VAL_38:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32> +// CHECK-DAG: [[VAL_39:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor +// CHECK: return [[VAL_36]], [[VAL_32]], [[VAL_37]], [[VAL_38]], [[VAL_39]] : tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor +// CHECK: } + +} + +// ----- + module { func.func @inference_can_fuse(%arg0: tensor, %arg1: tensor<8x10xf32>, %arg2: tensor<8x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) { %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = f32, value = dense<0.000000e+00> : tensor} : () -> tensor @@ -382,7 +606,7 @@ func.func @inference_standard_lstm_time_major_can_fuse(%arg0: tensor, // CHECK-DAG: [[VAL_17:%.*]] = 
"tf.Const"() {value = dense<0> : tensor} : () -> tensor // CHECK: [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) // CHECK: [[VAL_19:%.*]] = "tfl.no_value"() {value} : () -> none -// CHECK: [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor +// CHECK: [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) {cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = false, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor // CHECK-DAG: [[VAL_21:%.*]] = arith.constant dense<[-1, 0, 0]> : tensor<3xi32> // CHECK-DAG: [[VAL_22:%.*]] = arith.constant dense<0> : tensor<3xi32> // CHECK-DAG: [[VAL_23:%.*]] = arith.constant dense<1> : tensor<3xi32> @@ -432,7 +656,7 @@ func.func @inference_standard_lstm_time_major_can_fuse_last_output(%arg0: tensor // CHECK-DAG: [[VAL_17:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor // CHECK: [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) // CHECK: [[VAL_19:%.*]] = "tfl.no_value"() {value} : () -> none -// CHECK: [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor +// CHECK: [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, 
[[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) {cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = false, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor // CHECK-DAG: [[VAL_21:%.*]] = arith.constant dense<[-1, 0, 0]> : tensor<3xi32> // CHECK-DAG: [[VAL_22:%.*]] = arith.constant dense<0> : tensor<3xi32> // CHECK-DAG: [[VAL_23:%.*]] = arith.constant dense<1> : tensor<3xi32> diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-dynamic-range.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-dynamic-range.mlir index 6cb9dc0ac4d..b549b564515 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-dynamic-range.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-dynamic-range.mlir @@ -244,7 +244,7 @@ func.func @QuantizeTransposeConvWeightOnly(%arg0: tensor<32x4x4x128xf32>, %arg1: %0 = "quantfork.stats"(%arg0) {layerStats = dense<[0.000000e+00, 1.000000e+01]> : tensor<2xf32>} : (tensor<32x4x4x128xf32>) -> tensor<32x4x4x128xf32> %w = arith.constant dense<127.0> : tensor<1x32x42x128xf32> %b = arith.constant dense<0.0> : tensor<1x32x42x128xf32> - %tconv = "tfl.transpose_conv"(%arg1, %w, %0, %b) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x32x42x128xf32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>) -> tensor<1x32x42x128xf32> + %tconv = "tfl.transpose_conv"(%arg1, %w, %0, %b) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32, fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<1x32x42x128xf32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>) -> tensor<1x32x42x128xf32> %tconv_s = "quantfork.stats"(%tconv) {layerStats = dense<[0.000000e+00, 1.000000e+01]> : tensor<2xf32>} : (tensor<1x32x42x128xf32>) -> tensor<1x32x42x128xf32> func.return %tconv_s : tensor<1x32x42x128xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-post-training-16bits.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-post-training-16bits.mlir new file mode 100644 index 00000000000..2a4005baa07 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-post-training-16bits.mlir @@ -0,0 +1,213 @@ +// RUN: tf-opt %s -tfl-prepare-quantize="quantize-signed=true post-training-quantize=true activation-number-of-bits=16" -cse | FileCheck %s + +// CHECK-LABEL: QuantizeUnidirectionalLstmFullPerTensor +func.func @QuantizeUnidirectionalLstmFullPerTensor(%arg0: tensor<1x2x3xf32>) -> (tensor<1x2x3xf32>) { + %input = "quantfork.stats"(%arg0) {layerStats = dense<[0.0, 1.0]> : tensor<2xf32>} : (tensor<1x2x3xf32>) -> tensor<1x2x3xf32> + %1 = "tfl.pseudo_const"() {value = dense<[[0.1]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32> + %2 = "tfl.pseudo_const"() {value = dense<[[0.2]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32> + %3 = "tfl.pseudo_const"() {value = dense<[[0.3]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32> + %4 = "tfl.pseudo_const"() {value = dense<[[0.4]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32> + %5 = 
"tfl.pseudo_const"() {value = dense<[[0.5]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32> + %6 = "tfl.pseudo_const"() {value = dense<[[0.6]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32> + %7 = "tfl.pseudo_const"() {value = dense<[[0.7]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32> + %8 = "tfl.pseudo_const"() {value = dense<[[0.8]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32> + %9 = "tfl.no_value"() {value} : () -> none + %10 = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<3xf32>} : () -> tensor<3xf32> + %11 = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<3xf32>} : () -> tensor<3xf32> + %recurrent_input = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<1x3xf32>} : () -> tensor<1x3xf32> + %recurrent_stats = "quantfork.stats"(%recurrent_input) {layerStats = dense<[0.0, 1.0]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> + %cell_input = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<1x3xf32>} : () -> tensor<1x3xf32> + %cell_stats = "quantfork.stats"(%cell_input) {layerStats = dense<[0.0, 1.0]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> + %16 = "tfl.unidirectional_sequence_lstm"( + %input, + %1, %2, %3, %4, + %5, %6, %7, %8, + %9, %9, %9, + %10, %11, + %10, %10, + %9, %9, + %recurrent_stats, %cell_stats, + %9, %9, %9, %9) { + asymmetric_quantize_inputs = false, + cell_clip = 1.000000e+01 : f32, + effective_hidden_scale_intermediate = tensor<0x!quant.calibrated>>, + fused_activation_function = "TANH", + input_to_cell_intermediate = tensor<0xf32>, + input_to_forget_intermediate = tensor<0xf32>, + input_to_input_intermediate = tensor<0xf32>, + input_to_output_intermediate = tensor<0xf32>, + proj_clip = 0.000000e+00 : f32, + time_major = false} : ( + tensor<1x2x3xf32>, + tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>, + tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>, + none, none, none, + tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, + none, none, + tensor<1x3xf32>, tensor<1x3xf32>, + none, none, none, none) -> tensor<1x2x3xf32> + %17 = "quantfork.stats"(%16) {layerStats = dense<[-0.1, 0.1]> : tensor<2xf32>} : (tensor<1x2x3xf32>) -> tensor<1x2x3xf32> + func.return %17 : tensor<1x2x3xf32> + +// CHECK-DAG: %[[input_0:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x2x3x!quant.uniform:f32, 3.0518509475997192E-5>>) -> tensor<1x2x3xf32> +// CHECK-DAG: %[[input_1:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x1x!quant.uniform:f32, 7.8740158653634745E-4>>) -> tensor<1x1xf32> +// CHECK-DAG: %[[input_2:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x1x!quant.uniform:f32, 0.0015748031730726949>>) -> tensor<1x1xf32> +// CHECK-DAG: %[[input_3:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x1x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor<1x1xf32> +// CHECK-DAG: %[[input_4:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x1x!quant.uniform:f32, 0.0031496063461453898>>) -> tensor<1x1xf32> +// CHECK-DAG: %[[input_5:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x1x!quant.uniform:f32, 0.003937007874015748>>) -> tensor<1x1xf32> +// CHECK-DAG: %[[input_6:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x1x!quant.uniform:f32, 0.0047244096365500624>>) -> tensor<1x1xf32> +// CHECK-DAG: %[[input_7:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x1x!quant.uniform:f32, 0.0055118109297564652>>) -> tensor<1x1xf32> +// CHECK-DAG: %[[input_8:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x1x!quant.uniform:f32, 0.0062992126922907796>>) -> tensor<1x1xf32> +// CHECK-DAG: %[[input_9:.*]] = "tfl.no_value"() {value} : () -> none +// CHECK-DAG: 
%[[input_10:.*]] = "tfl.dequantize"({{.*}}) : (tensor<3x!quant.uniform>) -> tensor<3xf32> +// CHECK-DAG: %[[input_11:.*]] = "tfl.dequantize"({{.*}}) : (tensor<3x!quant.uniform>) -> tensor<3xf32> +// CHECK-DAG: %[[input_12:.*]] = "tfl.dequantize"({{.*}}) : (tensor<3x!quant.uniform>) -> tensor<3xf32> +// CHECK-DAG: %[[input_13:.*]] = "tfl.dequantize"({{.*}}) : (tensor<3x!quant.uniform>) -> tensor<3xf32> +// CHECK-DAG: %[[input_14:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> +// CHECK-DAG: %[[input_15:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> +// CHECK: %[[lstm:.*]] = "tfl.unidirectional_sequence_lstm"( +// CHECK-SAME: %[[input_0]], +// CHECK-SAME: %[[input_1]], %[[input_2]], %[[input_3]], %[[input_4]], +// CHECK-SAME: %[[input_5]], %[[input_6]], %[[input_7]], %[[input_8]], +// CHECK-SAME: %[[input_9]], %[[input_9]], %[[input_9]], +// CHECK-SAME: %[[input_10]], %[[input_11]], %[[input_12]], %[[input_13]], +// CHECK-SAME: %[[input_9]], %[[input_9]], +// CHECK-SAME: %[[input_14]], %[[input_15]], +// CHECK-SAME: %[[input_9]], %[[input_9]], %[[input_9]], %[[input_9]]) { +// CHECK-SAME: asymmetric_quantize_inputs = false, +// CHECK-SAME: cell_clip = 1.000000e+01 : f32, +// CHECK-SAME: effective_hidden_scale_intermediate = tensor<0x!quant.uniform>, +// CHECK-SAME: fused_activation_function = "TANH", +// CHECK-SAME: input_to_cell_intermediate = tensor<0xf32>, +// CHECK-SAME: input_to_forget_intermediate = tensor<0xf32>, +// CHECK-SAME: input_to_input_intermediate = tensor<0xf32>, +// CHECK-SAME: input_to_output_intermediate = tensor<0xf32>, +// CHECK-SAME: proj_clip = 0.000000e+00 : f32, +// CHECK-SAME: time_major = false} : ( +// CHECK-SAME: tensor<1x2x3xf32>, +// CHECK-SAME: tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>, +// CHECK-SAME: tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>, +// CHECK-SAME: none, none, none, +// CHECK-SAME: tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, +// CHECK-SAME: none, none, +// CHECK-SAME: tensor<1x3xf32>, tensor<1x3xf32>, +// CHECK-SAME: none, none, none, none) +// CHECK-SAME: -> tensor<1x2x3xf32> +// CHECK: "tfl.quantize"(%[[lstm]]) {qtype = tensor<1x2x3x!quant.uniform>, volatile} : (tensor<1x2x3xf32>) -> tensor<1x2x3x!quant.uniform> + +} + +// CHECK-LABEL: QuantizeUnidirectionalLstmFullPerAxis +func.func @QuantizeUnidirectionalLstmFullPerAxis(%arg0: tensor<1x2x3xf32>) -> (tensor<1x2x3xf32>) { + %input = "quantfork.stats"(%arg0) { + layerStats = dense<[0.0, 1.0]> : tensor<2xf32>, + axisStats = dense<[ + [-1.0, 1.0], + [-8.0, 8.0], + [-0.5, 0.5] + ]> : tensor<3x2xf32>, axis = 2 : i64 + } : (tensor<1x2x3xf32>) -> tensor<1x2x3xf32> + %1 = "tfl.pseudo_const"() {value = dense<[[0.1]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32> + %2 = "tfl.pseudo_const"() {value = dense<[[0.2]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32> + %3 = "tfl.pseudo_const"() {value = dense<[[0.3]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32> + %4 = "tfl.pseudo_const"() {value = dense<[[0.4]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32> + %5 = "tfl.pseudo_const"() {value = dense<[[0.5]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32> + %6 = "tfl.pseudo_const"() {value = dense<[[0.6]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32> + %7 = "tfl.pseudo_const"() {value = dense<[[0.7]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32> + %8 = "tfl.pseudo_const"() {value = dense<[[0.8]]> : tensor<1x1xf32>} : () -> tensor<1x1xf32> + %9 = "tfl.no_value"() {value} : () -> none + %10 = 
"tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<3xf32>} : () -> tensor<3xf32> + %11 = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<3xf32>} : () -> tensor<3xf32> + %recurrent_input = "tfl.pseudo_const"() {value = dense<0.000000e+00> : tensor<1x3xf32>} : () -> tensor<1x3xf32> + %recurrent_stats = "quantfork.stats"(%recurrent_input) {layerStats = dense<[0.0, 1.0]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> + %cell_input = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<1x3xf32>} : () -> tensor<1x3xf32> + %cell_stats = "quantfork.stats"(%cell_input) { + layerStats = dense<[0.0, 1.0]> : tensor<2xf32>, + axisStats = dense<[ + [-1.0, 1.0], + [-8.0, 8.0], + [-0.5, 0.5] + ]> : tensor<3x2xf32>, axis = 1 : i64 + } : (tensor<1x3xf32>) -> tensor<1x3xf32> + %16 = "tfl.unidirectional_sequence_lstm"( + %input, + %1, %2, %3, %4, + %5, %6, %7, %8, + %9, %9, %9, + %10, %11, + %10, %10, + %9, %9, + %recurrent_stats, %cell_stats, + %9, %9, %9, %9) { + asymmetric_quantize_inputs = false, + cell_clip = 1.000000e+01 : f32, + effective_hidden_scale_intermediate = tensor<0x!quant.calibrated>>, + fused_activation_function = "TANH", + input_to_cell_intermediate = tensor<0xf32>, + input_to_forget_intermediate = tensor<0xf32>, + input_to_input_intermediate = tensor<0xf32>, + input_to_output_intermediate = tensor<0xf32>, + proj_clip = 0.000000e+00 : f32, + time_major = false} : ( + tensor<1x2x3xf32>, + tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>, + tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>, + none, none, none, + tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, + none, none, + tensor<1x3xf32>, tensor<1x3xf32>, + none, none, none, none) -> tensor<1x2x3xf32> + %17 = "quantfork.stats"(%16) { + layerStats = dense<[0.0, 1.0]> : tensor<2xf32>, + axisStats = dense<[ + [-1.0, 1.0], + [-8.0, 8.0], + [-0.5, 0.5] + ]> : tensor<3x2xf32>, axis = 2 : i64 + } : (tensor<1x2x3xf32>) -> tensor<1x2x3xf32> + func.return %17 : tensor<1x2x3xf32> + +// CHECK-DAG: %[[input_0:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x2x3x!quant.uniform:f32, {{3.0518509475997192E-5}}>>) -> tensor<1x2x3xf32> +// CHECK-DAG: %[[input_1:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x1x!quant.uniform:f32, 7.8740158653634745E-4>>) -> tensor<1x1xf32> +// CHECK-DAG: %[[input_2:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x1x!quant.uniform:f32, 0.0015748031730726949>>) -> tensor<1x1xf32> +// CHECK-DAG: %[[input_3:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x1x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor<1x1xf32> +// CHECK-DAG: %[[input_4:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x1x!quant.uniform:f32, 0.0031496063461453898>>) -> tensor<1x1xf32> +// CHECK-DAG: %[[input_5:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x1x!quant.uniform:f32, 0.003937007874015748>>) -> tensor<1x1xf32> +// CHECK-DAG: %[[input_6:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x1x!quant.uniform:f32, 0.0047244096365500624>>) -> tensor<1x1xf32> +// CHECK-DAG: %[[input_7:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x1x!quant.uniform:f32, 0.0055118109297564652>>) -> tensor<1x1xf32> +// CHECK-DAG: %[[input_8:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x1x!quant.uniform:f32, 0.0062992126922907796>>) -> tensor<1x1xf32> +// CHECK-DAG: %[[input_9:.*]] = "tfl.no_value"() {value} : () -> none +// CHECK-DAG: %[[input_10:.*]] = "tfl.dequantize"({{.*}}) : (tensor<3x!quant.uniform>) -> tensor<3xf32> +// CHECK-DAG: %[[input_11:.*]] = "tfl.dequantize"({{.*}}) : (tensor<3x!quant.uniform>) -> tensor<3xf32> +// 
CHECK-DAG: %[[input_12:.*]] = "tfl.dequantize"({{.*}}) : (tensor<3x!quant.uniform>) -> tensor<3xf32> +// CHECK-DAG: %[[input_13:.*]] = "tfl.dequantize"({{.*}}) : (tensor<3x!quant.uniform>) -> tensor<3xf32> +// CHECK-DAG: %[[input_14:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> +// CHECK-DAG: %[[input_15:.*]] = "tfl.dequantize"({{.*}}) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> +// CHECK: %31 = "tfl.unidirectional_sequence_lstm"( +// CHECK-SAME: %[[input_0]], +// CHECK-SAME: %[[input_1]], %[[input_2]], %[[input_3]], %[[input_4]], +// CHECK-SAME: %[[input_5]], %[[input_6]], %[[input_7]], %[[input_8]], +// CHECK-SAME: %[[input_9]], %[[input_9]], %[[input_9]], +// CHECK-SAME: %[[input_10]], %[[input_11]], %[[input_12]], %[[input_13]], +// CHECK-SAME: %[[input_9]], %[[input_9]], +// CHECK-SAME: %[[input_14]], %[[input_15]], +// CHECK-SAME: %[[input_9]], %[[input_9]], %[[input_9]], %[[input_9]]) { +// CHECK-SAME: asymmetric_quantize_inputs = false, +// CHECK-SAME: cell_clip = 1.000000e+01 : f32, effective_hidden_scale_intermediate = tensor<0x!quant.uniform>, +// CHECK-SAME: fused_activation_function = "TANH", +// CHECK-SAME: input_to_cell_intermediate = tensor<0xf32>, +// CHECK-SAME: input_to_forget_intermediate = tensor<0xf32>, +// CHECK-SAME: input_to_input_intermediate = tensor<0xf32>, +// CHECK-SAME: input_to_output_intermediate = tensor<0xf32>, proj_clip = 0.000000e+00 : f32, time_major = false} : ( +// CHECK-SAME: tensor<1x2x3xf32>, +// CHECK-SAME: tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>, +// CHECK-SAME: tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>, tensor<1x1xf32>, +// CHECK-SAME: none, none, none, +// CHECK-SAME: tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, +// CHECK-SAME: none, none, +// CHECK-SAME: tensor<1x3xf32>, tensor<1x3xf32>, +// CHECK-SAME: none, none, none, none) +// CHECK-SAME: -> tensor<1x2x3xf32> +// CHECK: %32 = "tfl.quantize"(%31) {qtype = tensor<1x2x3x!quant.uniform>, volatile} : (tensor<1x2x3xf32>) -> tensor<1x2x3x!quant.uniform> + +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir index 838851740cc..6e9ca99e11f 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir @@ -188,7 +188,7 @@ func.func @QuantizeFullyConnected(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x11 func.func @QuantizeTransposeConv(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<4xi32>) -> tensor<1x32x42x128xf32> { %w = arith.constant dense<127.0> : tensor<1x32x42x128xf32> %b = arith.constant dense<0.0> : tensor<1x32x42x128xf32> - %tc = "tfl.transpose_conv"(%arg1, %w, %arg0, %b) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x32x42x128xf32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>) -> tensor<1x32x42x128xf32> + %tc = "tfl.transpose_conv"(%arg1, %w, %arg0, %b) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32, fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<1x32x42x128xf32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>) -> tensor<1x32x42x128xf32> func.return %tc : tensor<1x32x42x128xf32> // CHECK: %[[CST:.*]] = arith.constant dense<1.270000e+02> : tensor<1x32x42x128xf32> @@ -231,7 +231,7 @@ func.func @bias_adjust_perchannel(%arg0: tensor<1x5x5x2xf32>, %arg1: tensor<4xi3 %w = arith.constant dense<[[[[-1.0, 1.0]]], [[[1.0, 2.0]]], [[[-2.0, 
1.0]]]]> : tensor<3x1x1x2xf32> %b = arith.constant dense<[1.0e-2, 2.1473647e1, -2.1473647e2]> : tensor<3xf32> %transpose_conv = "tfl.transpose_conv"(%arg1, %w, %0, %b) { - padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32 + padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32, fused_activation_function = "NONE" } : (tensor<4xi32>, tensor<3x1x1x2xf32>, tensor<1x5x5x2xf32>, tensor<3xf32>) -> tensor<1x5x5x3xf32> func.return %transpose_conv : tensor<1x5x5x3xf32> // CHECK: %[[bias:.*]] = arith.constant dense<[0.00999999977, 21.4736462, -214.736465]> diff --git a/tensorflow/compiler/mlir/lite/tests/quantize-dynamic-range.mlir b/tensorflow/compiler/mlir/lite/tests/quantize-dynamic-range.mlir index 9f08baea0bd..881e29d205f 100644 --- a/tensorflow/compiler/mlir/lite/tests/quantize-dynamic-range.mlir +++ b/tensorflow/compiler/mlir/lite/tests/quantize-dynamic-range.mlir @@ -144,7 +144,7 @@ func.func @QuantizeMatmulWithActConst(%arg0: tensor<1x3x3x512xf32>) -> tensor<1x func.func @QuantizeTransposeConvWeightOnly(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<4xi32>) -> tensor<1x32x42x128xf32> { %w = arith.constant dense<127.0> : tensor<1x32x42x128xf32> %b = arith.constant dense<0.0> : tensor<1x32x42x128xf32> - %tconv = "tfl.transpose_conv"(%arg1, %w, %arg0, %b) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x32x42x128xf32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>) -> tensor<1x32x42x128xf32> + %tconv = "tfl.transpose_conv"(%arg1, %w, %arg0, %b) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32, fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<1x32x42x128xf32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>) -> tensor<1x32x42x128xf32> func.return %tconv : tensor<1x32x42x128xf32> // CHECK: %[[b:.*]] = arith.constant dense<0.000000e+00> : tensor<1x32x42x128xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/quantize-variables.mlir b/tensorflow/compiler/mlir/lite/tests/quantize-variables.mlir new file mode 100644 index 00000000000..4431444c1ba --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/quantize-variables.mlir @@ -0,0 +1,190 @@ +// RUN: tf-opt %s -tfl-quantize-variables | FileCheck %s +// RUN: tf-opt %s -tfl-prepare-quantize -tfl-quantize -tfl-post-quantize -tfl-quantize-variables -tfl-quantize -tfl-post-quantize | FileCheck --check-prefix=WHOLE-PASSES %s + +// CHECK-LABEL: QuantizeReadVariable +func.func @QuantizeReadVariable() -> (tensor<1x2x1x3x!quant.uniform>) { + %1 = "tfl.var_handle"() : () -> tensor + %2 = "tfl.read_variable"(%1) : (tensor) -> tensor<1x2x1x3xf32> + %3 = "tfl.quantize"(%2) {qtype = tensor<1x2x1x3x!quant.uniform>, volatile} : (tensor<1x2x1x3xf32>) -> tensor<1x2x1x3x!quant.uniform> + func.return %3 : tensor<1x2x1x3x!quant.uniform> + +// CHECK-NEXT: %[[vh:.*]] = "tfl.var_handle"() {container = "", shared_name = ""} : () -> tensor<*x!tf_type.resource>>> +// CHECK-NEXT: %[[rv:.*]] = "tfl.read_variable"(%[[vh]]) : (tensor<*x!tf_type.resource>>>) -> tensor<1x2x1x3x!quant.uniform> +// CHECK-NEXT: %[[dq:.*]] = "tfl.dequantize"(%[[rv]]) : (tensor<1x2x1x3x!quant.uniform>) -> tensor<1x2x1x3xf32> +// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%[[dq]]) {qtype = tensor<1x2x1x3x!quant.uniform>, volatile} : (tensor<1x2x1x3xf32>) -> tensor<1x2x1x3x!quant.uniform> +// CHECK-NEXT: return %[[q]] : tensor<1x2x1x3x!quant.uniform> +} + +// CHECK-LABEL: QuantizeAssignVariableWithDequantAndEqualType +func.func @QuantizeAssignVariableWithDequantAndEqualType(%arg0 : tensor<1x2x1x3x!quant.uniform>) -> 
tensor<1x2x1x3x!quant.uniform> { + %0 = "tfl.var_handle"() : () -> tensor + %1 = "tfl.dequantize"(%arg0) : (tensor<1x2x1x3x!quant.uniform>) -> tensor<1x2x1x3xf32> + "tfl.assign_variable"(%0, %1) : (tensor, tensor<1x2x1x3xf32>) -> () + func.return %arg0 : tensor<1x2x1x3x!quant.uniform> + +// CHECK-NEXT: %[[vh:.*]] = "tfl.var_handle"() {container = "", shared_name = ""} : () -> tensor<*x!tf_type.resource>>> +// CHECK-NEXT: "tfl.assign_variable"(%[[vh]], %arg0) : (tensor<*x!tf_type.resource>>>, tensor<1x2x1x3x!quant.uniform>) -> () +// CHECK-NEXT: return %arg0 : tensor<1x2x1x3x!quant.uniform> +} + +// CHECK-LABEL: QuantizeAssignVariableWithDequantAndNotEqualType +func.func @QuantizeAssignVariableWithDequantAndNotEqualType(%arg0 : tensor<1x2x1x3x!quant.uniform>) -> tensor<1x2x1x3x!quant.uniform> { + %1 = "tfl.var_handle"() : () -> tensor + %2 = "tfl.read_variable"(%1) : (tensor) -> tensor<1x2x1x3xf32> + %3 = "tfl.quantize"(%2) {qtype = tensor<1x2x1x3x!quant.uniform>, volatile} : (tensor<1x2x1x3xf32>) -> tensor<1x2x1x3x!quant.uniform> + %5 = "tfl.dequantize"(%arg0) : (tensor<1x2x1x3x!quant.uniform>) -> tensor<1x2x1x3xf32> + "tfl.assign_variable"(%1, %5) : (tensor, tensor<1x2x1x3xf32>) -> () + func.return %arg0 : tensor<1x2x1x3x!quant.uniform> + +// CHECK-NEXT: %[[vh:.*]] = "tfl.var_handle"() {container = "", shared_name = ""} : () -> tensor<*x!tf_type.resource>>> +// CHECK-NEXT: %[[rv:.*]] = "tfl.read_variable"(%[[vh]]) : (tensor<*x!tf_type.resource>>>) -> tensor<1x2x1x3x!quant.uniform> +// CHECK-NEXT: %[[dq:.*]] = "tfl.dequantize"(%[[rv]]) : (tensor<1x2x1x3x!quant.uniform>) -> tensor<1x2x1x3xf32> +// CHECK-NEXT: %[[q1:.*]] = "tfl.quantize"(%[[dq]]) {qtype = tensor<1x2x1x3x!quant.uniform>, volatile} : (tensor<1x2x1x3xf32>) -> tensor<1x2x1x3x!quant.uniform> +// CHECK-NEXT: %[[q2:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x2x1x3x!quant.uniform>} : (tensor<1x2x1x3x!quant.uniform>) -> tensor<1x2x1x3x!quant.uniform> +// CHECK-NEXT: "tfl.assign_variable"(%[[vh]], %[[q2]]) : (tensor<*x!tf_type.resource>>>, tensor<1x2x1x3x!quant.uniform>) -> () +// CHECK-NEXT: return %arg0 : tensor<1x2x1x3x!quant.uniform> +} + +// CHECK-LABEL: QuantizeAssignVariableWithoutDequant +func.func @QuantizeAssignVariableWithoutDequant(%arg0 : tensor<1x2x1x3xf32>) -> tensor<1x2x1x3xf32> { + %0 = "tfl.var_handle"() : () -> tensor + %1 = "tfl.read_variable"(%0) : (tensor) -> tensor<1x2x1x3xf32> + %2 = "tfl.quantize"(%1) {qtype = tensor<1x2x1x3x!quant.uniform>, volatile} : (tensor<1x2x1x3xf32>) -> tensor<1x2x1x3x!quant.uniform> + %3 = "tfl.dequantize"(%2) : (tensor<1x2x1x3x!quant.uniform>) -> tensor<1x2x1x3xf32> + "tfl.assign_variable"(%0, %3) : (tensor, tensor<1x2x1x3xf32>) -> () + func.return %arg0 : tensor<1x2x1x3xf32> + +// CHECK-NEXT: %[[vh:.*]] = "tfl.var_handle"() {container = "", shared_name = ""} : () -> tensor<*x!tf_type.resource>>> +// CHECK-NEXT: %[[rv:.*]] = "tfl.read_variable"(%[[vh]]) : (tensor<*x!tf_type.resource>>>) -> tensor<1x2x1x3x!quant.uniform> +// CHECK-NEXT: %[[dq:.*]] = "tfl.dequantize"(%[[rv]]) : (tensor<1x2x1x3x!quant.uniform>) -> tensor<1x2x1x3xf32> +// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%[[dq]]) {qtype = tensor<1x2x1x3x!quant.uniform>, volatile} : (tensor<1x2x1x3xf32>) -> tensor<1x2x1x3x!quant.uniform> +// CHECK-NEXT: "tfl.assign_variable"(%[[vh]], %[[q]]) : (tensor<*x!tf_type.resource>>>, tensor<1x2x1x3x!quant.uniform>) -> () +// CHECK-NEXT: return %arg0 : tensor<1x2x1x3xf32> +} + +// CHECK-LABEL: VarHandleCase +func.func @VarHandleCase(%arg0 : tensor<1x2x1x3xf32>) -> tensor<1x2x1x3xf32> { + 
%0 = "tfl.var_handle"() : () -> tensor + func.return %arg0 : tensor<1x2x1x3xf32> + +// CHECK-NEXT: %[[vh:.*]] = "tfl.var_handle"() {container = "", shared_name = ""} : () -> tensor<*x!tf_type.resource>>> +// CHECK-NEXT: return %arg0 : tensor<1x2x1x3xf32> +} + +// CHECK-LABEL: QuantizeReadAssign +func.func @QuantizeReadAssign(%arg0: tensor<1x32x1x3xf32>) -> (tensor<1x34x1x3xf32>) { + %0 = "tfl.quantize"(%arg0) {qtype = tensor<1x32x1x3x!quant.uniform>, volatile} : (tensor<1x32x1x3xf32>) -> tensor<1x32x1x3x!quant.uniform> + %1 = "tfl.dequantize"(%0) : (tensor<1x32x1x3x!quant.uniform>) -> tensor<1x32x1x3xf32> + %cst = arith.constant dense<1> : tensor<4xi32> + %cst_0 = arith.constant dense<[0, 0, 0, 3]> : tensor<4xi32> + %cst_1 = arith.constant dense<[0, -2, 0, 0]> : tensor<4xi32> + %2 = "tfl.var_handle"() {container = "", shared_name = "read_assign2/states"} : () -> tensor + %3 = "tfl.read_variable"(%2) : (tensor) -> tensor<1x2x1x3xf32> + %4 = "tfl.concatenation"(%3, %1) {axis = 1 : i32, fused_activation_function = "NONE"} : (tensor<1x2x1x3xf32>, tensor<1x32x1x3xf32>) -> tensor<1x34x1x3xf32> + %5 = "tfl.quantize"(%4) {qtype = tensor<1x34x1x3x!quant.uniform>, volatile} : (tensor<1x34x1x3xf32>) -> tensor<1x34x1x3x!quant.uniform> + %6 = "tfl.dequantize"(%5) : (tensor<1x34x1x3x!quant.uniform>) -> tensor<1x34x1x3xf32> + %7 = "tfl.strided_slice"(%6, %cst_1, %cst_0, %cst) {begin_mask = 13 : i32, ellipsis_mask = 0 : i32, end_mask = 15 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<1x34x1x3xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x1x3xf32> + %8 = "tfl.quantize"(%7) {qtype = tensor<1x2x1x3x!quant.uniform>, volatile} : (tensor<1x2x1x3xf32>) -> tensor<1x2x1x3x!quant.uniform> + %9 = "tfl.dequantize"(%8) : (tensor<1x2x1x3x!quant.uniform>) -> tensor<1x2x1x3xf32> + "tfl.assign_variable"(%2, %9) : (tensor, tensor<1x2x1x3xf32>) -> () + func.return %6 : tensor<1x34x1x3xf32> + +// CHECK-NEXT: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x32x1x3x!quant.uniform>, volatile} : (tensor<1x32x1x3xf32>) -> tensor<1x32x1x3x!quant.uniform> +// CHECK-NEXT: %[[dq1:.*]] = "tfl.dequantize"(%[[q1]]) : (tensor<1x32x1x3x!quant.uniform>) -> tensor<1x32x1x3xf32> +// CHECK-NEXT: %[[cst:.*]] = arith.constant dense<1> : tensor<4xi32> +// CHECK-NEXT: %[[cst_0:.*]] = arith.constant dense<[0, 0, 0, 3]> : tensor<4xi32> +// CHECK-NEXT: %[[cst_1:.*]] = arith.constant dense<[0, -2, 0, 0]> : tensor<4xi32> +// CHECK-NEXT: %[[vh:.*]] = "tfl.var_handle"() {container = "", shared_name = "read_assign2/states"} : () -> tensor<*x!tf_type.resource>>> +// CHECK-NEXT: %[[rv:.*]] = "tfl.read_variable"(%[[vh]]) : (tensor<*x!tf_type.resource>>>) -> tensor<1x2x1x3x!quant.uniform> +// CHECK-NEXT: %[[dq2:.*]] = "tfl.dequantize"(%[[rv]]) : (tensor<1x2x1x3x!quant.uniform>) -> tensor<1x2x1x3xf32> +// CHECK-NEXT: %[[cc:.*]] = "tfl.concatenation"(%[[dq2]], %[[dq1]]) {axis = 1 : i32, fused_activation_function = "NONE"} : (tensor<1x2x1x3xf32>, tensor<1x32x1x3xf32>) -> tensor<1x34x1x3xf32> +// CHECK-NEXT: %[[q2:.*]] = "tfl.quantize"(%[[cc]]) {qtype = tensor<1x34x1x3x!quant.uniform>, volatile} : (tensor<1x34x1x3xf32>) -> tensor<1x34x1x3x!quant.uniform> +// CHECK-NEXT: %[[dq3:.*]] = "tfl.dequantize"(%[[q2]]) : (tensor<1x34x1x3x!quant.uniform>) -> tensor<1x34x1x3xf32> +// CHECK-NEXT: %[[ss:.*]] = "tfl.strided_slice"(%[[dq3]], %[[cst_1]], %[[cst_0]], %[[cst]]) {begin_mask = 13 : i32, ellipsis_mask = 0 : i32, end_mask = 15 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<1x34x1x3xf32>, 
tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x1x3xf32> +// CHECK-NEXT: %[[q3:.*]] = "tfl.quantize"(%[[ss]]) {qtype = tensor<1x2x1x3x!quant.uniform>, volatile} : (tensor<1x2x1x3xf32>) -> tensor<1x2x1x3x!quant.uniform> +// CHECK-NEXT: "tfl.assign_variable"(%[[vh]], %[[q3]]) : (tensor<*x!tf_type.resource>>>, tensor<1x2x1x3x!quant.uniform>) -> () +// CHECK-NEXT: return %[[dq3]] : tensor<1x34x1x3xf32> +} + +// WHOLE-PASSES-LABEL: QuantizeConvVariable +func.func @QuantizeConvVariable(%arg0: tensor<1x3x1x1xf32>) -> (tensor<1x3x1x1xf32>) { + %cst = arith.constant dense<1> : tensor<4xi32> + %cst_0 = arith.constant dense<[0, 3, 0, 1]> : tensor<4xi32> + %cst_1 = arith.constant dense<0> : tensor<4xi32> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<1xf32> + %0 = "tfl.quantize"(%cst_2) {qtype = tensor<1x!quant.uniform>, volatile} : (tensor<1xf32>) -> tensor<1x!quant.uniform> + %1 = "tfl.dequantize"(%0) : (tensor<1x!quant.uniform>) -> tensor<1xf32> + %cst_3 = arith.constant dense<[[[[1.0]], [[1.0]], [[1.0]]]]> : tensor<1x3x1x1xf32> + %2 = "tfl.quantize"(%cst_3) {qtype = tensor<1x3x1x1x!quant.uniform:f32:0, {1.0}>>, volatile} : (tensor<1x3x1x1xf32>) -> tensor<1x3x1x1x!quant.uniform:f32:0, {1.0}>> + %3 = "tfl.dequantize"(%2) : (tensor<1x3x1x1x!quant.uniform:f32:0, {1.0}>>) -> tensor<1x3x1x1xf32> + %4 = "tfl.quantize"(%arg0) {qtype = tensor<1x3x1x1x!quant.uniform>, volatile} : (tensor<1x3x1x1xf32>) -> tensor<1x3x1x1x!quant.uniform> + %5 = "tfl.dequantize"(%4) : (tensor<1x3x1x1x!quant.uniform>) -> tensor<1x3x1x1xf32> + %6 = "tfl.var_handle"() {container = "", shared_name = "conv_variable/state"} : () -> tensor + %7 = "tfl.read_variable"(%6) : (tensor) -> tensor<1x3x1x1xf32> + %8 = "tfl.conv_2d"(%5, %3, %1) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x3x1x1xf32>, tensor<1x3x1x1xf32>, tensor<1xf32>) -> tensor<1x3x1x1xf32> + %9 = "tfl.quantize"(%8) {qtype = tensor<1x3x1x1x!quant.uniform>, volatile} : (tensor<1x3x1x1xf32>) -> tensor<1x3x1x1x!quant.uniform> + %10 = "tfl.dequantize"(%9) : (tensor<1x3x1x1x!quant.uniform>) -> tensor<1x3x1x1xf32> + %11 = "tfl.concatenation"(%7, %10) {axis = 1 : i32, fused_activation_function = "NONE"} : (tensor<1x3x1x1xf32>, tensor<1x3x1x1xf32>) -> tensor<1x6x1x1xf32> + %12 = "tfl.quantize"(%11) {qtype = tensor<1x6x1x1x!quant.uniform>, volatile} : (tensor<1x6x1x1xf32>) -> tensor<1x6x1x1x!quant.uniform> + %13 = "tfl.dequantize"(%12) : (tensor<1x6x1x1x!quant.uniform>) -> tensor<1x6x1x1xf32> + %14 = "tfl.strided_slice"(%13, %cst_1, %cst_0, %cst) {begin_mask = 15 : i32, ellipsis_mask = 0 : i32, end_mask = 13 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<1x6x1x1xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x3x1x1xf32> + %15 = "tfl.quantize"(%14) {qtype = tensor<1x3x1x1x!quant.uniform>, volatile} : (tensor<1x3x1x1xf32>) -> tensor<1x3x1x1x!quant.uniform> + %16 = "tfl.dequantize"(%15) : (tensor<1x3x1x1x!quant.uniform>) -> tensor<1x3x1x1xf32> + "tfl.assign_variable"(%6, %16) : (tensor, tensor<1x3x1x1xf32>) -> () + func.return %10 : tensor<1x3x1x1xf32> + +// WHOLE-PASSES: %[[vh:.*]] = "tfl.var_handle"() {container = "", shared_name = "conv_variable/state"} : () -> tensor<*x!tf_type.resource>>> +// WHOLE-PASSES-NEXT: %[[rv:.*]] = "tfl.read_variable"(%[[vh]]) : (tensor<*x!tf_type.resource>>>) -> tensor<1x3x1x1x!quant.uniform> +// WHOLE-PASSES-DAG: %[[cv:.*]] = "tfl.conv_2d"(%arg0, {{.*}}) {{{.*}}} : 
(tensor<1x3x1x1x!quant.uniform>, tensor<1x3x1x1x!quant.uniform:f32:0, {{.*}}>>, tensor<1x!quant.uniform>) -> tensor<1x3x1x1x!quant.uniform> +// WHOLE-PASSES-NEXT: %[[cc:.*]] = "tfl.concatenation"(%[[rv]], %[[cv]]) {{{.*}}} : (tensor<1x3x1x1x!quant.uniform>, tensor<1x3x1x1x!quant.uniform>) -> tensor<1x6x1x1x!quant.uniform> +// WHOLE-PASSES-NEXT: %[[ss:.*]] = "tfl.strided_slice"(%[[cc]], {{.*}}) {{{.*}}} : (tensor<1x6x1x1x!quant.uniform>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x3x1x1x!quant.uniform> +// WHOLE-PASSES-NEXT: "tfl.assign_variable"(%[[vh]], %[[ss]]) : (tensor<*x!tf_type.resource>>>, tensor<1x3x1x1x!quant.uniform>) -> () +// WHOLE-PASSES-NEXT: return %[[cv]] : tensor<1x3x1x1x!quant.uniform> +} + +// WHOLE-PASSES-LABEL: QuantizeTwoVariable +func.func @QuantizeTwoVariable(%arg0: tensor<1x2x3xf32>) -> (tensor<1x2x3xf32>) { + %0 = "quantfork.stats"(%arg0) {layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>} : (tensor<1x2x3xf32>) -> tensor<1x2x3xf32> + + %1 = "tfl.pseudo_const"() {value = dense<0> : tensor<3xi32>} : () -> tensor<3xi32> + %2 = "tfl.pseudo_const"() {value = dense<[0, 2, 0]> : tensor<3xi32>} : () -> tensor<3xi32> + %3 = "tfl.pseudo_const"() {value = dense<1> : tensor<3xi32>} : () -> tensor<3xi32> + + %4 = "tfl.var_handle"() {container = "", shared_name = "read_assign/states0"} : () -> tensor + %5 = "tfl.var_handle"() {container = "", shared_name = "read_assign/states1"} : () -> tensor + + %40 = "tfl.read_variable"(%4) : (tensor) -> tensor<1x2x3xf32> + %41 = "quantfork.stats"(%40) {layerStats = dense<[0.0, 1.0]> : tensor<2xf32>} : (tensor<1x2x3xf32>) -> tensor<1x2x3xf32> + %42 = "tfl.concatenation"(%41, %0) {axis = 1 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xf32>, tensor<1x2x3xf32>) -> tensor<1x4x3xf32> + %43 = "quantfork.stats"(%42) {layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>} : (tensor<1x4x3xf32>) -> tensor<1x4x3xf32> + %44 = "tfl.strided_slice"(%43, %1, %2, %3) {begin_mask = 7 : i32, ellipsis_mask = 0 : i32, end_mask = 5 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<1x4x3xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x2x3xf32> + %45 = "quantfork.stats"(%44) {layerStats = dense<[0.0, 1.0]> : tensor<2xf32>} : (tensor<1x2x3xf32>) -> tensor<1x2x3xf32> + "tfl.assign_variable"(%4, %45) : (tensor, tensor<1x2x3xf32>) -> () + + %50 = "tfl.read_variable"(%5) : (tensor) -> tensor<1x2x3xf32> + %51 = "quantfork.stats"(%50) {layerStats = dense<[0.0, 1.0]> : tensor<2xf32>} : (tensor<1x2x3xf32>) -> tensor<1x2x3xf32> + %52 = "tfl.concatenation"(%51, %0) {axis = 1 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xf32>, tensor<1x2x3xf32>) -> tensor<1x4x3xf32> + %53 = "quantfork.stats"(%52) {layerStats = dense<[0.0, 1.0]> : tensor<2xf32>} : (tensor<1x4x3xf32>) -> tensor<1x4x3xf32> + %54 = "tfl.strided_slice"(%53, %1, %2, %3) {begin_mask = 7 : i32, ellipsis_mask = 0 : i32, end_mask = 5 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<1x4x3xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x2x3xf32> + %55 = "quantfork.stats"(%54) {layerStats = dense<[0.0, 1.0]> : tensor<2xf32>} : (tensor<1x2x3xf32>) -> tensor<1x2x3xf32> + "tfl.assign_variable"(%5, %55) : (tensor, tensor<1x2x3xf32>) -> () + + func.return %0 : tensor<1x2x3xf32> + +// WHOLE-PASSES: %[[q1:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x2x3x!quant.uniform>} : (tensor<1x2x3x!quant.uniform>) -> tensor<1x2x3x!quant.uniform> +// WHOLE-PASSES-DAG: %[[vh1:.*]] = "tfl.var_handle"() {container = "", shared_name = 
"read_assign/states0"} : () -> tensor<*x!tf_type.resource>>> +// WHOLE-PASSES-DAG: %[[vh2:.*]] = "tfl.var_handle"() {container = "", shared_name = "read_assign/states1"} : () -> tensor<*x!tf_type.resource>>> + +// WHOLE-PASSES-DAG: %[[rv1:.*]] = "tfl.read_variable"({{.*}}) : (tensor<*x!tf_type.resource>>>) -> tensor<1x2x3x!quant.uniform> +// WHOLE-PASSES-NEXT: %[[cc1:.*]] = "tfl.concatenation"(%[[rv1]], {{.*}}) {{.*}} : (tensor<1x2x3x!quant.uniform>, tensor<1x2x3x!quant.uniform>) -> tensor<1x4x3x!quant.uniform> +// WHOLE-PASSES-NEXT: %[[q2:.*]] = "tfl.quantize"(%[[cc1]]) {qtype = tensor<1x4x3x!quant.uniform>} : (tensor<1x4x3x!quant.uniform>) -> tensor<1x4x3x!quant.uniform> +// WHOLE-PASSES-NEXT: %[[ss1:.*]] = "tfl.strided_slice"(%[[q2]], {{.*}}) {{{.*}}} : (tensor<1x4x3x!quant.uniform>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x2x3x!quant.uniform> +// WHOLE-PASSES-NEXT: "tfl.assign_variable"(%[[vh1]], %[[ss1]]) : (tensor<*x!tf_type.resource>>>, tensor<1x2x3x!quant.uniform>) -> () + +// WHOLE-PASSES-DAG: %[[rv2:.*]] = "tfl.read_variable"({{.*}}) : (tensor<*x!tf_type.resource>>>) -> tensor<1x2x3x!quant.uniform> +// WHOLE-PASSES-NEXT: %[[cc2:.*]] = "tfl.concatenation"(%[[rv2]], {{.*}}) {{.*}} : (tensor<1x2x3x!quant.uniform>, tensor<1x2x3x!quant.uniform>) -> tensor<1x4x3x!quant.uniform> +// WHOLE-PASSES-NEXT: %[[ss2:.*]] = "tfl.strided_slice"(%[[cc2]], {{.*}}) {{{.*}}} : (tensor<1x4x3x!quant.uniform>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x2x3x!quant.uniform> +// WHOLE-PASSES-NEXT: "tfl.assign_variable"(%[[vh2]], %[[ss2]]) : (tensor<*x!tf_type.resource>>>, tensor<1x2x3x!quant.uniform>) -> () + +// WHOLE-PASSES-NEXT: return %arg0 : tensor<1x2x3x!quant.uniform> +} diff --git a/tensorflow/compiler/mlir/lite/tests/quantize.mlir b/tensorflow/compiler/mlir/lite/tests/quantize.mlir index 05dd85c0c42..30a5f67bd14 100644 --- a/tensorflow/compiler/mlir/lite/tests/quantize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/quantize.mlir @@ -80,7 +80,24 @@ func.func @QuantizeConv2D(tensor<1x224x224x3x!quant.uniform:f32, 1.000000e-01>>, value = dense<1> : tensor<32x3x3x3xi8>} // CHECK: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[cst1]], %[[cst0]]) // CHECK: return %[[conv]] : tensor<1x112x112x32x!quant.uniform> +} +// CHECK-LABEL: QuantizeConv2D4Bit +func.func @QuantizeConv2D4Bit(tensor<1x224x224x3x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> { +^bb0(%arg0: tensor<1x224x224x3x!quant.uniform>): + %cst = arith.constant dense<-1.23697901> : tensor<32xf32> + %2 = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform>) -> tensor<1x224x224x3xf32> + %w = arith.constant dense<-1.0> : tensor<32x3x3x3xf32> + %3 = "tfl.quantize"(%w) {qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.1>>} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3x!quant.uniform:f32, 0.1>> + %4 = "tfl.dequantize"(%3) : (tensor<32x3x3x3x!quant.uniform:f32, 0.1>>) -> tensor<32x3x3x3xf32> + %5 = "tfl.conv_2d"(%2, %4, %cst) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32> + %6 = "tfl.quantize"(%5) {qtype = tensor<1x112x112x32x!quant.uniform>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform> + func.return %6 : tensor<1x112x112x32x!quant.uniform> + +// CHECK: %[[cst0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform>, value = dense<-1583> : tensor<32xi32>} +// CHECK: %[[cst1:.*]] = 
"tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform:f32, 1.000000e-01>>, value = dense<1> : tensor<32x3x3x3xi4>} +// CHECK: %[[conv:.*]] = "tfl.conv_2d"(%arg0, %[[cst1]], %[[cst0]]) +// CHECK: return %[[conv]] : tensor<1x112x112x32x!quant.uniform> } // CHECK-LABEL: QuantizeDepthwiseConv2D @@ -100,6 +117,23 @@ func.func @QuantizeDepthwiseConv2D(tensor<1x224x224x3x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> { +^bb0(%arg0: tensor<1x224x224x3x!quant.uniform>): + %cst = arith.constant dense<-1.23697901> : tensor<32xf32> + %2 = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform>) -> tensor<1x224x224x3xf32> + %3 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.395998308:8>>, value = dense<-7> : tensor<32x3x3x3xi4>} : () -> tensor<32x3x3x3x!quant.uniform:f32, 0.395998308:8>> + %4 = "tfl.dequantize"(%3) : (tensor<32x3x3x3x!quant.uniform:f32, 0.395998308:8>>) -> tensor<32x3x3x3xf32> + %5 = "tfl.depthwise_conv_2d"(%2, %4, %cst) {depth_multiplier = 4 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32> + %6 = "tfl.quantize"(%5) {qtype = tensor<1x112x112x32x!quant.uniform>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform> + func.return %6 : tensor<1x112x112x32x!quant.uniform> + +// CHECK: %[[cst0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform>, value = dense<-400> : tensor<32xi32>} +// CHECK: %[[cst1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.39599830800000002:8>>, value = dense<-7> : tensor<32x3x3x3xi4>} +// CHECK: %[[conv:.*]] = "tfl.depthwise_conv_2d"(%arg0, %[[cst1]], %[[cst0]]) {depth_multiplier = 4 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} +// CHECK: return %[[conv]] +} + // CHECK-LABEL: QuantizeFullyConnected func.func @QuantizeFullyConnected(tensor<1x224x224x3x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> { ^bb0(%arg0: tensor<1x224x224x3x!quant.uniform>): @@ -125,6 +159,31 @@ func.func @QuantizeFullyConnected(tensor<1x224x224x3x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> { +^bb0(%arg0: tensor<1x224x224x3x!quant.uniform>): + %cst = arith.constant dense<-1.23697901> : tensor<32xf32> + %2 = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform>) -> tensor<1x224x224x3xf32> + %3 = "tfl.pseudo_qconst"() {qtype = tensor<32x12x!quant.uniform:f32, 0.395998308:8>>, value = dense<-7> : tensor<32x12xi4>} : () -> tensor<32x12x!quant.uniform:f32, 0.395998308:8>> + %4 = "tfl.dequantize"(%3) : (tensor<32x12x!quant.uniform:f32, 0.395998308:8>>) -> tensor<32x12xf32> + %5 = "tfl.fully_connected"(%2, %4, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x224x224x3xf32>, tensor<32x12xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32> + %6 = "tfl.quantize"(%5) {qtype = tensor<1x112x112x32x!quant.uniform>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform> + func.return %6 : tensor<1x112x112x32x!quant.uniform> + +// CHECK: %[[cst_0:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform>, value = dense<-400> : tensor<32xi32>} +// CHECK: %[[cst_1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x12x!quant.uniform:f32, 0.39599830800000002:8>>, value = dense<-7> : tensor<32x12xi4>} 
+// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %[[cst_1]], %[[cst_0]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} +// CHECK: return %[[fc]] + +// BLOCK: %[[cst:.*]] = arith.constant dense<-1.23697901> +// BLOCK: %[[dq1:.*]] = "tfl.dequantize"(%arg0) +// BLOCK: %[[cst2:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x12x!quant.uniform:f32, 0.39599830800000002:8>>, value = dense<-7> : tensor<32x12xi4>} +// BLOCK: %[[dq2:.*]] = "tfl.dequantize"(%[[cst2]]) +// BLOCK: %[[fc:.*]] = "tfl.fully_connected"(%[[dq1]], %[[dq2]], %[[cst]]) +// BLOCK: %[[q:.*]] = "tfl.quantize"(%[[fc]]) +// BLOCK: return %[[q]] +} + // CHECK-LABEL: QuantizeNoBiasFullyConnected func.func @QuantizeNoBiasFullyConnected(%arg0: tensor<3x!quant.uniform>, %arg1: tensor<3x3x!quant.uniform:f32, 1.0>>, %arg2: none) -> tensor<3x!quant.uniform> { %0 = "tfl.dequantize"(%arg0) : (tensor<3x!quant.uniform>) -> tensor<3xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/shape-inference.mlir b/tensorflow/compiler/mlir/lite/tests/shape-inference.mlir index d4d30b4c922..34b9a54bc91 100644 --- a/tensorflow/compiler/mlir/lite/tests/shape-inference.mlir +++ b/tensorflow/compiler/mlir/lite/tests/shape-inference.mlir @@ -87,6 +87,17 @@ func.func @testUnidirectionalSequenceLstmShapeInference(%arg0: tensor<600 x 10 x // ----- +module attributes {tf.versions = {producer = 888 : i32}} { +// CHECK-LABEL: testUnidirectionalSequenceLstmShapeInference +func.func @testUnidirectionalSequenceLstmShapeInference(%arg0: tensor<600 x ? x 20 x f32>, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor, %arg14: tensor, %arg15: tensor<40 x f32>, %arg16: tensor, %arg17: tensor, %arg18: tensor<600 x 40 x f32>, %arg19: tensor<600 x 40 x f32>, %arg20: tensor, %arg21: tensor, %arg22: tensor, %arg23: tensor) -> tensor { + // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<600x?x20xf32>, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor<40xf32>, tensor, tensor, tensor<600x40xf32>, tensor<600x40xf32>, tensor, tensor, tensor, tensor) -> tensor<600x?x40xf32 + %0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<600 x ? 
x 20 x f32>, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor<40xf32>, tensor, tensor, tensor<600x40xf32>, tensor<600x40xf32>, tensor, tensor, tensor, tensor) -> tensor + func.return %0 : tensor +} +} + +// ----- + // CHECK-LABEL: testReshapeShapeInference module attributes {tf.versions = {producer = 888 : i32}} { func.func @testReshapeShapeInference(%arg0: tensor<3x4xi32>) -> tensor<*xi32> { diff --git a/tensorflow/compiler/mlir/lite/tests/tfl_while_outline.mlir b/tensorflow/compiler/mlir/lite/tests/tfl_while_outline.mlir index 35c25a92f2a..0fd1482d77d 100644 --- a/tensorflow/compiler/mlir/lite/tests/tfl_while_outline.mlir +++ b/tensorflow/compiler/mlir/lite/tests/tfl_while_outline.mlir @@ -190,33 +190,3 @@ func.func @whileDifferentResultShapes(%arg0: tensor) -> tensor // CHECK: (tensor, tensor<1xf32>, tensor) -> (tensor, tensor, tensor) func.return %0#1 : tensor } - -// ----- - -func.func @unsupportedCast(%arg0: tensor<4x4x3xf32>) -> tensor<*xf32> { - %cst = arith.constant dense<0.000000e+00> : tensor<4x2xf32> - %cst_0 = arith.constant dense<0.000000e+00> : tensor<4x4x3xf64> - %cst_1 = arith.constant dense<[1, 0, 2]> : tensor<3xi32> - %cst_2 = arith.constant dense<0.000000e+00> : tensor<4x4x2xf32> - %cst_3 = arith.constant dense<4> : tensor - %cst_4 = arith.constant dense<0> : tensor - %cst_5 = arith.constant dense<0.000000e+00> : tensor<4x2xf64> - %0 = "tfl.transpose"(%arg0, %cst_1) : (tensor<4x4x3xf32>, tensor<3xi32>) -> tensor<4x4x3xf32> - %1:6 = "tfl.while"(%cst_4, %cst_4, %cst_2, %cst, %cst_5, %cst_0) ({ - ^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor<*xf32>, %arg4: tensor<4x2xf32>, %arg5: tensor<4x2xf64>, %arg6: tensor<*xf64>): - %5 = "tfl.less"(%arg2, %cst_3) : (tensor, tensor) -> tensor - %6 = "tfl.less"(%arg1, %cst_3) : (tensor, tensor) -> tensor - %7 = tfl.logical_and %6, %5 : tensor - "tfl.yield"(%7) : (tensor) -> () - }, { - ^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor<*xf32>, %arg4: tensor<4x2xf32>, %arg5: tensor<4x2xf64>, %arg6: tensor<*xf64>): - "tfl.yield"(%arg1, %arg2, %arg3, %arg4, %arg5, %cst_0) : (tensor, tensor, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf64>, tensor<4x4x3xf64>) -> () - }) {is_stateless = true} : (tensor, tensor, tensor<4x4x2xf32>, tensor<4x2xf32>, tensor<4x2xf64>, tensor<4x4x3xf64>) -> (tensor, tensor, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf64>, tensor<*xf32>) - func.return %1#2 : tensor<*xf32> -} - -// CHECK-LABEL: func @unsupportedCast( - -// CHECK-LABEL: func private @tfl.while_body( -// CHECK-SAME: %arg0: tensor, %arg1: tensor, %arg2: tensor<*xf32>, %arg3: tensor<4x2xf32>, %arg4: tensor<4x2xf64>, %arg5: tensor<*xf64>) -> (tensor, tensor, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf64>, tensor<*xf64>) -// CHECK: [[VAL:%.*]] = "tf.Cast" diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index 6d0b035683b..17b0501cb78 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -35,7 +35,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/passes.h" namespace mlir { /// Create a pass to convert from the TFExecutor to the TF control dialect. @@ -57,8 +57,8 @@ void AddQuantizationPasses(const mlir::quant::QuantizationSpecs& quant_specs, quant_specs.default_ranges.second.has_value()) { pass_manager.addNestedPass( mlir::TFL::CreateDefaultQuantParamsPass( - quant_specs.default_ranges.first.getValueOr(0.0), - quant_specs.default_ranges.second.getValueOr(0.0), + quant_specs.default_ranges.first.value_or(0.0), + quant_specs.default_ranges.second.value_or(0.0), quant_specs.IsSignedInferenceType())); } pass_manager.addNestedPass( @@ -67,6 +67,19 @@ void AddQuantizationPasses(const mlir::quant::QuantizationSpecs& quant_specs, quant_specs.inference_type != quant_specs.inference_input_type; pass_manager.addNestedPass( mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops)); + // TODO(b/265081639): When added PrepareQuantizeVariablesPass before adding + // PrepareQuantizePass, an error occurs in certain model. It could fix it by + // roll-back to run PrepareQuantizeVariablesPass, QuantizePass, + // PostQuantizePass as suggested in cl/479634700. Need to figure out the + // fundamental reason of the error, and (if needed) fix it without this + // rollback. + if (quant_specs.enable_mlir_variable_quantization) { + pass_manager.addPass(mlir::TFL::CreatePrepareQuantizeVariablesPass()); + pass_manager.addNestedPass( + mlir::TFL::CreateQuantizePass(quant_specs)); + pass_manager.addNestedPass( + mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops)); + } pass_manager.addNestedPass( mlir::TFL::CreateOptimizeOpOrderPass()); // Add optimization pass after quantization for additional fusing @@ -87,6 +100,19 @@ void AddDynamicRangeQuantizationPasses( pass_manager.addNestedPass( mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops, quant_specs.custom_map)); + // TODO(b/265081639): When added PrepareQuantizeVariablesPass before adding + // PrepareQuantizePass, an error occurs in certain model. It could fix it by + // roll-back to run PrepareQuantizeVariablesPass, QuantizePass, + // PostQuantizePass as suggested in cl/479634700. Need to figure out the + // fundamental reason of the error, and (if needed) fix it without this + // rollback. + if (quant_specs.enable_mlir_variable_quantization) { + pass_manager.addPass(mlir::TFL::CreatePrepareQuantizeVariablesPass()); + pass_manager.addNestedPass( + mlir::TFL::CreateQuantizePass(quant_specs)); + pass_manager.addNestedPass( + mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops)); + } pass_manager.addNestedPass( mlir::TFL::CreateOptimizeOpOrderPass()); // Add optimization pass after quantization for additional fusing diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc index 74f9d994eaa..c808336e571 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc @@ -15,9 +15,9 @@ limitations under the License. 
#include #include +#include #include "absl/strings/str_split.h" -#include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" @@ -53,17 +53,17 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/compiler/xla/translate/hlo_to_mhlo/translate.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/tsl/platform/statusor.h" using mlir::MLIRContext; using mlir::ModuleOp; using mlir::func::FuncOp; -using stream_executor::port::StatusOr; +using tsl::StatusOr; // Debugging flag to print function mapping in the flatbuffer. // NOLINTNEXTLINE @@ -308,7 +308,7 @@ int main(int argc, char **argv) { }); std::string result; - llvm::Optional session = llvm::None; + llvm::Optional session = std::nullopt; if (bundle) session = bundle->GetSession(); auto status = tensorflow::ConvertTFExecutorToTFLOrFlatbuffer( module.value().get(), output_mlir, toco_flags, pass_config, tags, diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index 0755c06aa0e..f908f608238 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -15,9 +15,12 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" +#include +#include #include #include #include +#include #include "absl/types/span.h" #include "llvm/Support/raw_ostream.h" @@ -34,8 +37,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" #include "tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/mhlo_tfl_pass.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_tfl_pass.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.h" #include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" @@ -48,12 +51,12 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/lite/tools/optimize/quantize_weights.h" #include "tensorflow/lite/tools/optimize/reduced_precision_support.h" +#include "tensorflow/tsl/platform/statusor.h" namespace tensorflow { namespace { @@ -61,7 +64,7 @@ using mlir::MLIRContext; using mlir::ModuleOp; using mlir::Operation; using mlir::OwningOpRef; -using stream_executor::port::StatusOr; +using tsl::StatusOr; bool IsControlFlowV1Op(Operation* op) { return mlir::isa( + mlir::odml::CreateStablehloToTflPass()); + if (failed(pass_manager.run(module))) { + return statusHandler.ConsumeStatus(); + } + + // Write TFLite Custom Op MLIR to Flatbuffer + // TODO(b/260112687): will serialize StableHLO to Flatbuffer directly + tflite::FlatbufferExportOptions options; + options.toco_flags.set_allow_custom_ops(true); + if (!tflite::MlirToFlatBufferTranslateFunction(module, options, result)) { + return statusHandler.ConsumeStatus(); + } + + return OkStatus(); } Status ConvertTFExecutorToTFLOrFlatbuffer( @@ -297,8 +313,8 @@ Status ConvertTFExecutorToTFLOrFlatbuffer( // Freeze variables if a session is provided. if (session.has_value()) { mlir::TFL::ErrorCollectorInstrumentation collector(module.getContext()); - if (failed(mlir::tf_saved_model::FreezeVariables(module, - session.getValue()))) { + if (failed( + mlir::tf_saved_model::FreezeVariables(module, session.value()))) { auto status = statusHandler.ConsumeStatus(); mlir::TFL::ErrorCollector* collector = mlir::TFL::ErrorCollector::GetErrorCollector(); diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h index 0b327054ebf..0561be72547 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h @@ -30,8 +30,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/lite/toco/toco_flags.pb.h" +#include "tensorflow/tsl/platform/statusor.h" namespace tensorflow { @@ -40,8 +40,7 @@ namespace tensorflow { // file; otherwise, load from a GraphDef. // Setting prune_unused_nodes to true, would prune unreachable nodes if // output_arrays is specified. -stream_executor::port::StatusOr> -LoadFromGraphdefOrMlirSource( +tsl::StatusOr> LoadFromGraphdefOrMlirSource( const std::string& input_filename, bool input_mlir, bool use_splatted_constant, const std::vector& extra_tf_opdefs, const GraphImportConfig& specs, absl::string_view debug_info_file, @@ -52,8 +51,7 @@ LoadFromGraphdefOrMlirSource( // Load Saved model (either v1 or v2) into MLIR. // 'saved_model_bundle' will be initialized if V1 model was loaded. 
-stream_executor::port::StatusOr> -ImportSavedModel( +tsl::StatusOr> ImportSavedModel( const std::string& input_filename, const int saved_model_version, const std::unordered_set& tags, absl::Span extra_tf_opdefs, diff --git a/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h b/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h index c4ba044caf7..d8ccdc6bef2 100644 --- a/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h +++ b/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h @@ -99,8 +99,8 @@ LogicalResult ConvertTFDilatedConvOp::matchAndRewrite( return rewriter.notifyMatchFailure(op, "dilations should be all 1"); } - if (!TFL::TFTypeIsFloat32Tensor(op.input()) && - !TFL::TFTypeIsBFloat16OrHalfTensor(op.input())) { + if (!TFL::TFTypeIsFloat32Tensor(op.getInput()) && + !TFL::TFTypeIsBFloat16OrHalfTensor(op.getInput())) { return rewriter.notifyMatchFailure( op, "op's input is not float or half or bfloat16"); } @@ -185,8 +185,8 @@ LogicalResult ConvertTFDilatedConvOp::matchAndRewrite( } // Make sure that the axis in `expand_op` is constant. if (auto const_op = - llvm::dyn_cast(expand_op.dim().getDefiningOp())) { - expand_axis = (*const_op.value() + llvm::dyn_cast(expand_op.getDim().getDefiningOp())) { + expand_axis = (*const_op.getValue() .cast() .getValues() .begin()) @@ -202,7 +202,7 @@ LogicalResult ConvertTFDilatedConvOp::matchAndRewrite( expand_op, "ExpandDimsOp doesn't have a constant axis"); } // Make sure that the `squeeze_dims` is equal to `expand_axis`. - auto squeeze_dims = squeeze_op.squeeze_dims(); + auto squeeze_dims = squeeze_op.getSqueezeDims(); if (squeeze_dims.size() != 1) { return rewriter.notifyMatchFailure( squeeze_op, "squeeze dims should have exactly 1 dimension specified"); @@ -218,7 +218,7 @@ LogicalResult ConvertTFDilatedConvOp::matchAndRewrite( } // Update previous/next op pointer. - Operation* tmp = expand_op.input().getDefiningOp(); + Operation* tmp = expand_op.getInput().getDefiningOp(); if (!tmp || tmp->getNumResults() != 1) { return rewriter.notifyMatchFailure( producer_op, @@ -264,7 +264,7 @@ LogicalResult ConvertTFDilatedConvOp::matchAndRewrite( return maybeConsumer.first; } consumer_op = maybeConsumer.second; - if (!matchPattern(pad_op.paddings(), m_Constant(&pad_attr))) { + if (!matchPattern(pad_op.getPaddings(), m_Constant(&pad_attr))) { // If the padding value isn't constant, we can't determine the padding // scheme for Conv2D below, in this case just reject the pattern. return rewriter.notifyMatchFailure( @@ -311,13 +311,13 @@ LogicalResult ConvertTFDilatedConvOp::matchAndRewrite( } llvm::Optional dilations_attr = ExtractDilationsAttrFromBlockShape( - stb_op.block_shape(), bts_op.block_shape(), expand_axis, rewriter); + stb_op.getBlockShape(), bts_op.getBlockShape(), expand_axis, rewriter); if (!dilations_attr.has_value()) { return rewriter.notifyMatchFailure(op, "failed to extract dilation rate"); } if (expand_op) { - if (stb_op.input().getType().dyn_cast() == nullptr) { + if (stb_op.getInput().getType().dyn_cast() == nullptr) { return rewriter.notifyMatchFailure( stb_op, "SpaceToBatchND op's input should have RankedTensorType"); } @@ -351,8 +351,8 @@ LogicalResult ConvertTFDilatedConvOp::matchAndRewrite( // dilated conv, hence we shouldn't pattern match here. Instead, we need to // check values of `paddings` and `crops` to make sure it really stands for // a dilated conv. 
- auto stb_paddings = stb_op.paddings(); - auto bts_crops = bts_op.crops(); + auto stb_paddings = stb_op.getPaddings(); + auto bts_crops = bts_op.getCrops(); ElementsAttr stb_paddings_attr, bts_crops_attr; if (!matchPattern(stb_paddings, m_Constant(&stb_paddings_attr)) || !matchPattern(bts_crops, m_Constant(&bts_crops_attr))) { @@ -388,7 +388,7 @@ LogicalResult ConvertTFDilatedConvOp::matchAndRewrite( } // Set dilations - op->setAttr("dilations", dilations_attr.getValue()); + op->setAttr("dilations", dilations_attr.value()); if (expand_op) { // If there is `expand_op`, we need to rewire the inputs to bypass the @@ -397,45 +397,46 @@ LogicalResult ConvertTFDilatedConvOp::matchAndRewrite( // BiasAdd' to 'Expand -> Conv2D ->Squeeze -> BiasAdd'. // Connect `expand_op` with the input of `stb_op`. - expand_op.setOperand(0, stb_op.input()); + expand_op.setOperand(0, stb_op.getInput()); // Calculate the shape for expand. - auto input_shape = stb_op.input().getType().cast().getShape(); + auto input_shape = + stb_op.getInput().getType().cast().getShape(); SmallVector expand_shape(input_shape.begin(), input_shape.end()); expand_shape.insert(expand_shape.begin() + expand_axis, 1); auto expand_result_type = RankedTensorType::get( - expand_shape, getElementTypeOrSelf(stb_op.input())); + expand_shape, getElementTypeOrSelf(stb_op.getInput())); expand_op.getResult().setType(expand_result_type); // Update the conv op's output shape. auto bts_output_shape = - bts_op.output().getType().cast().getShape(); + bts_op.getOutput().getType().cast().getShape(); SmallVector conv_result_shape(bts_output_shape.begin(), bts_output_shape.end()); conv_result_shape.insert(conv_result_shape.begin() + expand_axis, 1); auto conv_result_type = RankedTensorType::get( - conv_result_shape, getElementTypeOrSelf(stb_op.input())); + conv_result_shape, getElementTypeOrSelf(stb_op.getInput())); op.getResult().setType(conv_result_type); - squeeze_op.getResult().setType(bts_op.output().getType()); + squeeze_op.getResult().setType(bts_op.getOutput().getType()); // Connect `biasadd_op` with the output of `squeeze_op`. 
if (biasadd_op) { - biasadd_op.setOperand(0, squeeze_op.output()); - biasadd_op.output().setType(squeeze_op.output().getType()); + biasadd_op.setOperand(0, squeeze_op.getOutput()); + biasadd_op.getOutput().setType(squeeze_op.getOutput().getType()); } } else { - if (biasadd_op) biasadd_op.setOperand(0, op.output()); - op.setOperand(0, stb_op.input()); + if (biasadd_op) biasadd_op.setOperand(0, op.getOutput()); + op.setOperand(0, stb_op.getInput()); op.getResult().setType(bts_op.getResult().getType()); } if (final_op_is_bts) { - if (bts_op.input().getDefiningOp()) { - bts_op.getResult().replaceAllUsesWith(pad_op.input()); + if (bts_op.getInput().getDefiningOp()) { + bts_op.getResult().replaceAllUsesWith(pad_op.getInput()); } else { - bts_op.getResult().replaceAllUsesWith(bts_op.input()); + bts_op.getResult().replaceAllUsesWith(bts_op.getInput()); } } diff --git a/tensorflow/compiler/mlir/lite/transforms/insert_call_once_op.cc b/tensorflow/compiler/mlir/lite/transforms/insert_call_once_op.cc index fd279b75f96..4e1fe8e0122 100644 --- a/tensorflow/compiler/mlir/lite/transforms/insert_call_once_op.cc +++ b/tensorflow/compiler/mlir/lite/transforms/insert_call_once_op.cc @@ -41,22 +41,9 @@ class InsertCallOnceOpFromSessionInitializerPass void InsertCallOnceOpFromSessionInitializerPass::runOnOperation() { ModuleOp module = getOperation(); - tf_saved_model::SessionInitializerOp session_init_op = - tf_saved_model::GetSessionInitializerOp(module); - - if (!session_init_op) return; - - SymbolTable symbol_table(module); - - for (auto sym_ref : session_init_op.getInitializers()) { - func::FuncOp init_func_op = symbol_table.lookup( - sym_ref.cast().getValue()); - - if (!init_func_op) { - module.emitError("no session initializer function found"); - return signalPassFailure(); - } + for (func::FuncOp init_func_op : + tf_saved_model::GetInitializerFunctions(module)) { for (auto func : module.getOps()) { auto dict_attr = func->getAttrOfType("tf.entry_function"); diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_hashtables.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_hashtables.cc index e3b201bd82d..59dbbbb9dc0 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_hashtables.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_hashtables.cc @@ -59,9 +59,9 @@ class LegalizeHashTableOpPattern : public OpRewritePattern { // native resource design is based on integer keys to identify the // corresponding resource objects. 
auto table_id = - static_cast(::llvm::hash_value(hashtable_op.shared_name())); - auto key_dtype = hashtable_op.key_dtype(); - auto value_dtype = hashtable_op.value_dtype(); + static_cast(::llvm::hash_value(hashtable_op.getSharedName())); + auto key_dtype = hashtable_op.getKeyDtype(); + auto value_dtype = hashtable_op.getValueDtype(); rewriter.replaceOpWithNewOp( hashtable_op, output_type, table_id, key_dtype, value_dtype); @@ -76,13 +76,13 @@ class LegalizeHashTableFindOpPattern LogicalResult matchAndRewrite(TF::LookupTableFindV2Op find_op, PatternRewriter& rewriter) const override { - auto handle_op = find_op.table_handle().getDefiningOp(); + auto handle_op = find_op.getTableHandle().getDefiningOp(); if (handle_op == nullptr) return failure(); auto hashtable_op = llvm::dyn_cast(handle_op); if (hashtable_op == nullptr) return failure(); rewriter.replaceOpWithNewOp( - find_op, find_op->getResultTypes(), find_op.table_handle(), - find_op.keys(), find_op.default_value()); + find_op, find_op->getResultTypes(), find_op.getTableHandle(), + find_op.getKeys(), find_op.getDefaultValue()); return success(); } }; @@ -94,13 +94,13 @@ class LegalizeHashTableImportOpPattern LogicalResult matchAndRewrite(TF::LookupTableImportV2Op import_op, PatternRewriter& rewriter) const override { - auto handle_op = import_op.table_handle().getDefiningOp(); + auto handle_op = import_op.getTableHandle().getDefiningOp(); if (handle_op == nullptr) return failure(); auto hashtable_op = llvm::dyn_cast(handle_op); if (hashtable_op == nullptr) return failure(); rewriter.replaceOpWithNewOp( - import_op, import_op->getResultTypes(), import_op.table_handle(), - import_op.keys(), import_op.values()); + import_op, import_op->getResultTypes(), import_op.getTableHandle(), + import_op.getKeys(), import_op.getValues()); return success(); } }; @@ -112,12 +112,12 @@ class LegalizeHashTableSizeOpPattern LogicalResult matchAndRewrite(TF::LookupTableSizeV2Op size_op, PatternRewriter& rewriter) const override { - auto handle_op = size_op.table_handle().getDefiningOp(); + auto handle_op = size_op.getTableHandle().getDefiningOp(); if (handle_op == nullptr) return failure(); auto hashtable_op = llvm::dyn_cast(handle_op); if (hashtable_op == nullptr) return failure(); rewriter.replaceOpWithNewOp( - size_op, size_op->getResultTypes(), size_op.table_handle()); + size_op, size_op->getResultTypes(), size_op.getTableHandle()); return success(); } }; @@ -137,8 +137,8 @@ bool checkWhetherGraphHasValidStaticLookupTables(ModuleOp module) { } for (auto hashtable : hashtables) { - auto key_dtype = hashtable.key_dtype(); - auto value_dtype = hashtable.value_dtype(); + auto key_dtype = hashtable.getKeyDtype(); + auto value_dtype = hashtable.getValueDtype(); // Only allow string -> int64 and int64 -> string mappings due to kernel // capability. diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_jax_random.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_jax_random.cc index 1bc15fefe49..db99805c92e 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_jax_random.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_jax_random.cc @@ -50,7 +50,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" namespace mlir { namespace TFL { diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td index 74e1a04b958..a8f0ea36135 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td @@ -491,7 +491,8 @@ def LegalizeConv2DBackpropInput : Pat< /*bias=*/ (CreateNoneValue $input_sizes), /*padding=*/ $padding, /*stride_h=*/ ExtractI32At<1>:$strides, - /*stride_w=*/ ExtractI32At<2>:$strides)>; + /*stride_w=*/ ExtractI32At<2>:$strides, + /*fused_activation_function=*/TFL_AF_None)>; def IsRankZeroAttr : CPred<"$_self.cast().getType().getRank() == 0">; diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index 38d651cba4b..555a879c956 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -21,6 +21,7 @@ limitations under the License. // constant folding opportunities from the extra ops can be exploited by the // constant folding support for the TensorFlow ops. +#include #include #include #include @@ -136,7 +137,7 @@ Value GetShape(Value input, Location loc, PatternRewriter& rewriter) { RankedTensorType::get(static_shape.size(), rewriter.getIntegerType(64)); auto static_shape_attr = mlir::DenseIntElementsAttr::get(static_shape_type, static_shape); - return rewriter.create(loc, static_shape_attr).output(); + return rewriter.create(loc, static_shape_attr).getOutput(); } // If the shape is not static, create a new ShapeOp. @@ -144,7 +145,7 @@ Value GetShape(Value input, Location loc, PatternRewriter& rewriter) { return rewriter .create(loc, input, /*use_32bit=*/false_attr) - .output(); + .getOutput(); } mlir::TFL::MirrorPaddingType GetTFLMirrorPaddingFromString( @@ -169,6 +170,9 @@ mlir::TFL::MirrorPaddingType GetTFLMirrorPaddingFromString( DECL_CONVERT_OP(Assert); DECL_CONVERT_OP(ConcatV2); +DECL_CONVERT_OP(BatchMatMul); +DECL_CONVERT_OP(BatchMatMulV2); +DECL_CONVERT_OP(BatchMatMulV3); DECL_CONVERT_OP(MatMul); DECL_CONVERT_OP(MatrixDiagV2); DECL_CONVERT_OP(MatrixDiagV3); @@ -205,11 +209,12 @@ LogicalResult ConvertTFConcatV2Op::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_concat_op = cast(op); - auto values = tf_concat_op.values(); - auto output_type = tf_concat_op.output().getType(); + auto values = tf_concat_op.getValues(); + auto output_type = tf_concat_op.getOutput().getType(); // Extract axis attribute from constant axis tensor ElementsAttr axis; - if (!matchPattern(tf_concat_op.axis(), m_Constant(&axis))) return failure(); + if (!matchPattern(tf_concat_op.getAxis(), m_Constant(&axis))) + return failure(); IntegerAttr axis_int = ExtractSingleElementAsInteger(axis); // "axis" operand could be a i64 tensor. Resolve it here. @@ -223,6 +228,128 @@ LogicalResult ConvertTFConcatV2Op::matchAndRewrite( return success(); } +template +bool ConvertTFBatchMatMulOp2TFLFullyConnectedOp(Operation* bmm_op, + PatternRewriter& rewriter) { + // If `value` is produced by tf.Dequantize op, then return the Dequantize op's + // input. Otherwise return `value`. 
+ auto get_real_input_value = [](Value value) -> Value { + Operation* defining_op = value.getDefiningOp(); + if (auto dequantize = dyn_cast_or_null(defining_op)) { + return dequantize.getInput(); + } else if (auto dequantize = + dyn_cast_or_null(defining_op)) { + return dequantize.getInput(); + } else { + return value; + } + }; + + // Returns true if the TF::BatchMatMul operation can be converted to + // tfl.fully_connected. + auto can_convert_to_fully_connected = + [&](BatchMatMulOpType& batch_matmul_op) { + Value input_rhs = get_real_input_value(batch_matmul_op.getY()); + + DenseElementsAttr constant; + if (!matchPattern(input_rhs, m_Constant(&constant))) { + return false; + } + + // The rhs matrix must be 2D for the fully connected op. + return (constant.getType().getRank() == 2); + }; + + auto op = cast(bmm_op); + + // Create a tfl.transpose op that performs ZX transpose on `input`. + auto create_z_x_transpose_op = [&](Value input) -> Value { + RankedTensorType input_type = input.getType().cast(); + const int input_rank = input_type.getRank(); + + // Create a 1D I32 tensor for representing the dimension permutation. + auto permutation_tensor_type = + RankedTensorType::get({input_rank}, rewriter.getIntegerType(32)); + llvm::SmallVector permute; + permute.reserve(input_rank); + // First create an identity permutation tensor. + for (int i = 0; i < input_rank; i++) { + permute.push_back(rewriter.getI32IntegerAttr(i)); + } + // Swap the last two dimensions, since they will be mapped to the X and Z + // dimensions. + std::iter_swap(permute.begin() + input_rank - 1, + permute.begin() + input_rank - 2); + auto permutation_tensor_op = rewriter.create( + op->getLoc(), permutation_tensor_type, + DenseElementsAttr::get(permutation_tensor_type, permute)); + + auto input_shape = input_type.getShape(); + llvm::SmallVector permuted_shape(input_shape.begin(), + input_shape.end()); + // Swap the z and x dimensions to get the permuted shape. + std::iter_swap(permuted_shape.begin() + input_rank - 1, + permuted_shape.begin() + input_rank - 2); + return rewriter.create( + op->getLoc(), + RankedTensorType::get(permuted_shape, input_type.getElementType()), + input, permutation_tensor_op.getResult()); + }; + + if (!can_convert_to_fully_connected(op)) { + return false; + } + + Value input_lhs = get_real_input_value(op.getX()); + Value input_rhs = get_real_input_value(op.getY()); + + Value legalized_lhs = + op.getAdjX() ? create_z_x_transpose_op(input_lhs) : input_lhs; + + // The rhs needs to be transposed if adj_y == false AND this matmul will be + // legalized to tfl.fully_connected. + Value legalized_rhs = + !op.getAdjY() ? 
create_z_x_transpose_op(input_rhs) : input_rhs; + + Type output_type = op.getResult().getType(); + auto no_input = rewriter.create( + op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr()); + auto fc_op = rewriter.create( + op->getLoc(), ArrayRef{output_type}, + /*input=*/legalized_lhs, /*filter=*/legalized_rhs, /*bias=*/no_input, + /*fused_activation_function=*/rewriter.getStringAttr("NONE"), + /*weights_format=*/rewriter.getStringAttr("DEFAULT"), + /*keep_num_dims=*/rewriter.getBoolAttr(true), + /*asymmetric_quantize_inputs=*/mlir::BoolAttr()); + rewriter.replaceOp(op, {fc_op.getResult(0)}); + + return true; +} + +LogicalResult ConvertTFBatchMatMulOp::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + if (ConvertTFBatchMatMulOp2TFLFullyConnectedOp(op, + rewriter)) + return success(); + return failure(); +} + +LogicalResult ConvertTFBatchMatMulV2Op::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + if (ConvertTFBatchMatMulOp2TFLFullyConnectedOp(op, + rewriter)) + return success(); + return failure(); +} + +LogicalResult ConvertTFBatchMatMulV3Op::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + if (ConvertTFBatchMatMulOp2TFLFullyConnectedOp(op, + rewriter)) + return success(); + return failure(); +} + LogicalResult ConvertTFMatMulOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_matmul_op = cast(op); @@ -246,12 +373,12 @@ LogicalResult ConvertTFMatMulOp::matchAndRewrite( }; // TODO(jpienaar): Remove once handled via dailect conversion. - if (tf_matmul_op.transpose_a()) { + if (tf_matmul_op.getTransposeA()) { LogicalResult result = success(); std::tie(result, lhs) = transpose(lhs); if (failed(result)) return failure(); } - if (!tf_matmul_op.transpose_b()) { + if (!tf_matmul_op.getTransposeB()) { LogicalResult result = success(); std::tie(result, rhs) = transpose(rhs); if (failed(result)) return failure(); @@ -275,11 +402,11 @@ LogicalResult ConvertTFPackOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_pack_op = cast(op); - SmallVector values(tf_pack_op.values()); - auto output_type = tf_pack_op.output().getType(); - auto values_count = rewriter.getI32IntegerAttr(tf_pack_op.N()); + SmallVector values(tf_pack_op.getValues()); + auto output_type = tf_pack_op.getOutput().getType(); + auto values_count = rewriter.getI32IntegerAttr(tf_pack_op.getN()); // Axis can be negative. - auto axis = rewriter.getI32IntegerAttr(tf_pack_op.axis()); + auto axis = rewriter.getI32IntegerAttr(tf_pack_op.getAxis()); rewriter.replaceOpWithNewOp(op, output_type, values, values_count, axis); @@ -291,11 +418,11 @@ LogicalResult ConvertTFSplitOp::matchAndRewrite( auto tf_split_op = cast(op); // Number of splits cannot be negative. - auto num_split = rewriter.getI32IntegerAttr(tf_split_op.num_split()); + auto num_split = rewriter.getI32IntegerAttr(tf_split_op.getNumSplit()); - rewriter.replaceOpWithNewOp(op, tf_split_op.output().getTypes(), - tf_split_op.split_dim(), - tf_split_op.value(), num_split); + rewriter.replaceOpWithNewOp( + op, tf_split_op.getOutput().getTypes(), tf_split_op.getSplitDim(), + tf_split_op.getValue(), num_split); return success(); } @@ -304,11 +431,11 @@ LogicalResult ConvertTFSplitVOp::matchAndRewrite( auto tf_splitv_op = cast(op); // Number of splits cannot be negative. 
- auto num_split = rewriter.getI32IntegerAttr(tf_splitv_op.num_split()); + auto num_split = rewriter.getI32IntegerAttr(tf_splitv_op.getNumSplit()); rewriter.replaceOpWithNewOp( - op, tf_splitv_op.output().getTypes(), tf_splitv_op.value(), - tf_splitv_op.size_splits(), tf_splitv_op.split_dim(), num_split); + op, tf_splitv_op.getOutput().getTypes(), tf_splitv_op.getValue(), + tf_splitv_op.getSizeSplits(), tf_splitv_op.getSplitDim(), num_split); return success(); } @@ -316,12 +443,12 @@ LogicalResult ConvertTFUnpackOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_unpack_op = cast(op); - auto input = tf_unpack_op.value(); - auto num = rewriter.getI32IntegerAttr(tf_unpack_op.num()); + auto input = tf_unpack_op.getValue(); + auto num = rewriter.getI32IntegerAttr(tf_unpack_op.getNum()); // Axis can be negative. - auto axis = rewriter.getI32IntegerAttr(tf_unpack_op.axis()); + auto axis = rewriter.getI32IntegerAttr(tf_unpack_op.getAxis()); - rewriter.replaceOpWithNewOp(op, tf_unpack_op.output().getTypes(), + rewriter.replaceOpWithNewOp(op, tf_unpack_op.getOutput().getTypes(), input, num, axis); return success(); } @@ -357,7 +484,7 @@ LogicalResult ConvertTFConv3DOp::matchAndRewrite( op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr()); rewriter.replaceOpWithNewOp( - op, tf_op.getType(), tf_op.input(), tf_op.filter(), + op, tf_op.getType(), tf_op.getInput(), tf_op.getFilter(), /*bias=*/none, dilation_depth_factor, dilation_height_factor, dilation_width_factor, /*fused_activation_function=*/rewriter.getStringAttr("NONE"), padding, @@ -397,10 +524,11 @@ LogicalResult ConvertTFConv3DBackpropInputV2Op::matchAndRewrite( op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr()); Value output_shape = - CreateCastToInt32(tf_op.input_sizes(), op->getLoc(), rewriter); + CreateCastToInt32(tf_op.getInputSizes(), op->getLoc(), rewriter); rewriter.replaceOpWithNewOp( - op, tf_op.getType(), output_shape, tf_op.filter(), tf_op.out_backprop(), + op, tf_op.getType(), output_shape, tf_op.getFilter(), + tf_op.getOutBackprop(), /*bias=*/none, dilation_depth_factor, dilation_height_factor, dilation_width_factor, /*fused_activation_function=*/rewriter.getStringAttr("NONE"), padding, @@ -422,31 +550,31 @@ bool ConvertTFMatrixDiagV2orV3(Operation* op, PatternRewriter* rewriter) { if (tf_matrix_diag_v2_or_v3_op.getNumOperands() != 5) return false; - auto input = tf_matrix_diag_v2_or_v3_op.diagonal(); - auto output_type = tf_matrix_diag_v2_or_v3_op.output().getType(); + auto input = tf_matrix_diag_v2_or_v3_op.getDiagonal(); + auto output_type = tf_matrix_diag_v2_or_v3_op.getOutput().getType(); // Extract k constant tensor and check value = 0. ElementsAttr k; - if (!matchPattern(tf_matrix_diag_v2_or_v3_op.k(), m_Constant(&k))) + if (!matchPattern(tf_matrix_diag_v2_or_v3_op.getK(), m_Constant(&k))) return false; if (ExtractSingleElementAsInteger(k).getInt() != 0) return false; // Extract num_rows constant tensor and check value = -1. ElementsAttr num_rows; - if (!matchPattern(tf_matrix_diag_v2_or_v3_op.num_rows(), + if (!matchPattern(tf_matrix_diag_v2_or_v3_op.getNumRows(), m_Constant(&num_rows))) return false; if (ExtractSingleElementAsInteger(num_rows).getInt() != -1) return false; // Extract num_cols constant tensor and check value = -1. 
ElementsAttr num_cols; - if (!matchPattern(tf_matrix_diag_v2_or_v3_op.num_cols(), + if (!matchPattern(tf_matrix_diag_v2_or_v3_op.getNumCols(), m_Constant(&num_cols))) return false; if (ExtractSingleElementAsInteger(num_cols).getInt() != -1) return false; // Verify padding_value is a tensor with all 0s. - mlir::Value padding_value = tf_matrix_diag_v2_or_v3_op.padding_value(); + mlir::Value padding_value = tf_matrix_diag_v2_or_v3_op.getPaddingValue(); mlir::Type element_type = padding_value.getType().cast().getElementType(); if (element_type.isa()) { @@ -659,7 +787,7 @@ class ApplyExplicitBroadcasting : public OpRewritePattern { RankedTensorType::get(symbolic_broadcast_shape.size(), rewriter.getIntegerType(64)), lhs_shape, rhs_shape) - .r0(); + .getR0(); // Broadcasts inputs using BroadcastTo op. auto broadcast_type = RankedTensorType::get( @@ -668,12 +796,12 @@ class ApplyExplicitBroadcasting : public OpRewritePattern { rewriter .create(op->getLoc(), broadcast_type, lhs, broadcast_shape) - .output(); + .getOutput(); auto broadcasted_rhs = rewriter .create(op->getLoc(), broadcast_type, rhs, broadcast_shape) - .output(); + .getOutput(); // Recreate an op with the above BroadcastTo op results. RankedTensorType result_type = RankedTensorType::get( @@ -725,13 +853,13 @@ class ApplyExplicitBroadcasting : public OpRewritePattern { lhs = rewriter .create(op->getLoc(), broadcast_type, lhs, new_shape) - .output(); + .getOutput(); } if (result_type.getShape() != rhs_shape) { rhs = rewriter .create(op->getLoc(), broadcast_type, rhs, new_shape) - .output(); + .getOutput(); } // Recreate an op with the above Broadcast op results. @@ -786,12 +914,12 @@ class ApplyExplicitBroadcasting rewriter .create(op->getLoc(), lhs_shape.getType(), lhs_shape, rhs_shape) - .r0(); + .getR0(); broadcast_shape_value = rewriter .create(op->getLoc(), lhs_shape.getType(), broadcast_shape_value, cond_shape) - .r0(); + .getR0(); // Broadcasting inputs using BroadcastTo op. auto broadcast_type = RankedTensorType::get( @@ -803,17 +931,17 @@ class ApplyExplicitBroadcasting RankedTensorType::get(symbolic_broadcast_shape, rewriter.getIntegerType(1)), cond, broadcast_shape_value) - .output(); + .getOutput(); auto broadcasted_lhs = rewriter .create(op->getLoc(), broadcast_type, lhs, broadcast_shape_value) - .output(); + .getOutput(); auto broadcasted_rhs = rewriter .create(op->getLoc(), broadcast_type, rhs, broadcast_shape_value) - .output(); + .getOutput(); // Recreate an op with the above BroadcastTo op results. rewriter.replaceOpWithNewOp( @@ -873,19 +1001,19 @@ class ApplyExplicitBroadcasting cond = rewriter .create(op->getLoc(), cond_result_type, cond, new_shape) - .output(); + .getOutput(); } if (result_shape != lhs_shape) { lhs = rewriter .create(op->getLoc(), result_type, lhs, new_shape) - .output(); + .getOutput(); } if (result_shape != rhs_shape) { rhs = rewriter .create(op->getLoc(), result_type, rhs, new_shape) - .output(); + .getOutput(); } // Recreate an op with the above Broadcast op results. @@ -902,7 +1030,9 @@ void addPatterns(MLIRContext* context, RewritePatternSet& patterns, // Add the generated patterns to the list. 
populateWithGenerated(patterns); - patterns.add(context); diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc index 849db47e67d..1dcbd4c0384 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc @@ -63,7 +63,7 @@ void RunOnWhile(TF::WhileOp while_op) { // Create new TFL While op that will be used to replace TF While op. auto new_op = OpBuilder(op).create( op->getLoc(), op->getResultTypes(), op->getOperands(), - while_op.is_stateless()); + while_op.getIsStateless()); Location loc = while_op->getLoc(); CreateRegionWithCall(while_op.cond_function(), new_op.getCond(), loc); CreateRegionWithCall(while_op.body_function(), new_op.getBody(), loc); diff --git a/tensorflow/compiler/mlir/lite/transforms/lift_tflite_flex_ops.cc b/tensorflow/compiler/mlir/lite/transforms/lift_tflite_flex_ops.cc index f43e13b82cd..d939d74c5dd 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lift_tflite_flex_ops.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lift_tflite_flex_ops.cc @@ -110,11 +110,25 @@ class LiftFlexCustomOp : public OpRewritePattern { Operation* tf_op = rewriter.create(op_state); rewriter.replaceOp(op, tf_op->getResults()); + if (isa(tf_op)) { + constexpr StringRef kFuncAttrName = "f"; + tf_op->setAttr( + kFuncAttrName, + tf_op->getAttr(kFuncAttrName).cast().getName()); + } + + if (isa(tf_op)) { + constexpr StringRef kFuncAttrName = "predicate"; + tf_op->setAttr( + kFuncAttrName, + tf_op->getAttr(kFuncAttrName).cast().getName()); + } + // Special type fixes for TF Resource Tensors that are casted to // Int32 tensor during MLIR->TFLite flatbuffer conversion. // TODO(b/146131919): correct handling of resource type if (auto tensor_array_v3_op = dyn_cast(tf_op)) { - Value handle = tensor_array_v3_op.handle(); + Value handle = tensor_array_v3_op.getHandle(); auto handle_type = handle.getType().cast(); if (handle_type.getElementType().isInteger(/*width=*/32)) { Type resource_tensor_type = @@ -132,7 +146,7 @@ class LiftFlexCustomOp : public OpRewritePattern { if (auto tensor_array_v3_op = dyn_cast(tf_op)) { // The "flow" in TensorArrayV3 is always a scalar float tensor. // https://www.tensorflow.org/api_docs/python/tf/raw_ops/TensorArrayWriteV3 - Value flow = tensor_array_v3_op.flow(); + Value flow = tensor_array_v3_op.getFlow(); Type scalar_f32_tensor_type = RankedTensorType::get(/*shape=*/{}, rewriter.getF32Type()); flow.setType(scalar_f32_tensor_type); @@ -150,10 +164,10 @@ class LiftFlexCustomOp : public OpRewritePattern { values.reserve(args.size()); for (const auto& arg : args) { auto range = arg_ranges.at(arg.name()); - values.push_back( - range.second - range.first); + values.push_back(range.second - range.first); } - auto attr_value = mlir::DenseI32ArrayAttr::get(tf_op->getContext(), values); + auto attr_value = + mlir::DenseI32ArrayAttr::get(tf_op->getContext(), values); tf_op->setAttr(attr_name, attr_value); }; if (tf_op->hasTrait() || diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc index 13e8dc7dc4f..865da881efe 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc @@ -22,12 +22,12 @@ limitations under the License. 
#include #include +#include #include #include "absl/container/inlined_vector.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallSet.h" @@ -255,7 +255,7 @@ struct ConvertConst : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { // Verify that the tensor proto contains tensor of type variant and scalar // shape. The variant type should hold a TensorList. - auto proto_attr = op.value().dyn_cast(); + auto proto_attr = op.getValue().dyn_cast(); if (!proto_attr) return failure(); tensorflow::Tensor tensor; if (!tensorflow::ConvertToTensor(proto_attr, &tensor).ok()) @@ -470,7 +470,7 @@ struct ConvertTensorListInitOp : public TensorListOpConverterBase { LogicalResult matchAndRewrite( OpT op, typename OpT::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Type dtype = op.element_dtype(); + Type dtype = op.getElementDtype(); if (!(dtype.isF16() || dtype.isF32() || dtype.isF64() || dtype.isInteger(1) || dtype.isInteger(8) || dtype.isInteger(16) || dtype.isInteger(32) || dtype.isInteger(64))) { @@ -499,7 +499,7 @@ struct ConvertTensorListInitOp : public TensorListOpConverterBase { element_shape = rewriter.create( op.getLoc(), tensorflow::GetTypeFromTFTensorShape({-1}, shape_dtype), - set_op.item()); + set_op.getItem()); element_shape_acquired = true; } else if (TF::WhileOp while_op = llvm::dyn_cast(use.getOwner())) { @@ -513,7 +513,7 @@ struct ConvertTensorListInitOp : public TensorListOpConverterBase { llvm::dyn_cast( inside_use.getOwner())) { if (auto shaped_type = - set_op.item().getType().dyn_cast()) { + set_op.getItem().getType().dyn_cast()) { if (shaped_type.hasStaticShape()) { RankedTensorType type = tensorflow::GetTypeFromTFTensorShape( @@ -594,7 +594,7 @@ struct ConvertTensorListInitOp : public TensorListOpConverterBase { } int64_t result_rank = -1; // -1 means unknown result rank. - Type element_dtype = op.element_dtype(); + Type element_dtype = op.getElementDtype(); Type result_type = UnrankedTensorType::get(element_dtype); Value leading_dim = GetNumElements(op, adaptor.getOperands(), &rewriter); if (auto element_type = @@ -650,7 +650,7 @@ struct ConvertTensorListReserve Value GetNumElements(TF::TensorListReserveOp op, ValueRange operands, PatternRewriter *rewriter) const override { Value scalar_zero = CreateI32SplatConst(op.getLoc(), rewriter, {}, 0); - Type shape_dtype = getElementTypeOrSelf(op.element_shape().getType()); + Type shape_dtype = getElementTypeOrSelf(op.getElementShape().getType()); Value num_elements = operands[1]; IntegerAttr attr; if (matchPattern(num_elements, m_Constant(&attr))) { @@ -658,7 +658,7 @@ struct ConvertTensorListReserve } if (auto const_op = num_elements.getDefiningOp()) { return CreateI32SplatConst(op->getLoc(), rewriter, {1}, - (*const_op.value() + (*const_op.getValue() .cast() .getValues() .begin()) @@ -707,7 +707,7 @@ struct ConvertTensorListPushBack loc, expanded_item_type, item, scalar_zero); Type elem_type = getElementTypeOrSelf(item); - auto handle_dtype = getElementTypeOrSelf(op.output_handle().getType()) + auto handle_dtype = getElementTypeOrSelf(op.getOutputHandle().getType()) .cast(); Type result_type = GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter); @@ -749,7 +749,7 @@ struct ConvertTensorListResize // Infer result type of this op based on TF's shape inference result. 
Type elem_type = getElementTypeOrSelf(input_handle); - auto handle_dtype = getElementTypeOrSelf(op.output_handle().getType()) + auto handle_dtype = getElementTypeOrSelf(op.getOutputHandle().getType()) .cast(); Type result_type = GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter); @@ -842,7 +842,7 @@ struct ConvertTensorListResize loc, tensorflow::GetTypeFromTFTensorShape({-1}, shape_dtype), input_shape, slice_start, slice_size); auto extended_part = rewriter->create( - loc, resize_op.output_handle().getType(), elem_shape, size_diff); + loc, resize_op.getOutputHandle().getType(), elem_shape, size_diff); // `ConcatOp` expects non-variant-typed input. Insert a // `TensorListStackOp` here to convert type from variant to non-variant. // Note that we are using the same `result_type` for both the @@ -954,7 +954,7 @@ struct ConvertTensorListStack RankedTensorType shape_type = tensorflow::GetTypeFromTFTensorShape({-1}, rewriter.getIntegerType(32)); auto new_shape = rewriter.create(loc, shape_type, input); - SmallVector output_shape(/*Size=*/1, op.num_elements()); + SmallVector output_shape(/*Size=*/1, op.getNumElements()); for (const auto &dim : dense_elem_attr.getValues()) output_shape.push_back(dim.getSExtValue()); RankedTensorType result_type = tensorflow::GetTypeFromTFTensorShape( @@ -1116,7 +1116,7 @@ bool IsTensorListType(Type type, llvm::Optional value) { if (!value.has_value()) { return false; } - for (const mlir::OpOperand &use : value.getValue().getUses()) { + for (const mlir::OpOperand &use : value.value().getUses()) { mlir::Operation *op = use.getOwner(); if (llvm::isa(op) || llvm::isa(op) || @@ -1150,7 +1150,7 @@ llvm::SmallSet GetTensorListResultsIndex(func::FuncOp func) { for (const auto &result_and_idx : llvm::enumerate(func.getFunctionType().getResults())) { - if (IsTensorListType(result_and_idx.value(), llvm::None)) { + if (IsTensorListType(result_and_idx.value(), std::nullopt)) { set.insert(result_and_idx.index()); } } @@ -1315,10 +1315,10 @@ llvm::DenseMap MapTensorListResultToArgument(func::FuncOp func) { Value parent = value; while (true) { if (auto identity = parent.getDefiningOp()) { - parent = identity.input(); + parent = identity.getInput(); } else if (auto set_item = parent.getDefiningOp()) { - parent = set_item.input_handle(); + parent = set_item.getInputHandle(); } else { break; } diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc index 0a0536d9f53..8c6a2f14e32 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc @@ -23,12 +23,12 @@ limitations under the License. #include #include #include +#include #include #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallSet.h" @@ -309,7 +309,7 @@ DenseElementsAttr GetShape(Value output_val) { RankedTensorType::get( {static_cast(shape.size())}, mlir::IntegerType::get(output_val.getContext(), 32)), - llvm::makeArrayRef(shape)); + llvm::ArrayRef(shape)); } static Type GetShapeStrippedType(TypeAttr type_attr) { @@ -867,11 +867,12 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern { // TF::MulOp is used to fold the constant. 
// TODO(b/139192933): switch to the TFL constant folding auto new_filter = - rewriter.create(mul_op.getLoc(), filter, new_const_val).z(); + rewriter.create(mul_op.getLoc(), filter, new_const_val) + .getZ(); // If bias isn't None, it needs to be multiplied as well. if (!bias.getType().isa()) { - bias = - rewriter.create(mul_op.getLoc(), bias, constant_val).z(); + bias = rewriter.create(mul_op.getLoc(), bias, constant_val) + .getZ(); } auto fc = rewriter.create( @@ -978,7 +979,7 @@ struct FuseAffinOpAndMulWithQDQs : public OpRewritePattern { // Rewrite filter constant. Since the folder of TFL::MulOp couldn't // broadcast the operands, TF::MulOp is used to fold the constant. auto new_filter = - rewriter.create(loc, filter, broadcasted_gamma).z(); + rewriter.create(loc, filter, broadcasted_gamma).getZ(); // Update the scale in the quantize op. auto new_qtype = RescaleQtype(q_op.getQtype(), gamma_cst); if (!new_qtype) return failure(); @@ -1542,16 +1543,16 @@ struct OptimizeTopK : public OpRewritePattern { auto slice_op = llvm::dyn_cast_or_null(value.getUses().begin().getUser()); // We only match for the case where value is used by SliceOp. - if (!slice_op) return llvm::None; + if (!slice_op) return std::nullopt; DenseElementsAttr begin; DenseElementsAttr size; if (!matchPattern(slice_op->getOperand(1), m_Constant(&begin)) || !matchPattern(slice_op->getOperand(2), m_Constant(&size))) - return llvm::None; + return std::nullopt; // Check if "begin" is a zero tensor. for (auto begin_idx : begin.getValues()) - if (begin_idx != 0) return llvm::None; + if (begin_idx != 0) return std::nullopt; // Check if "size" is equal to slice_op.input.shape except // for last dimension. @@ -1559,19 +1560,19 @@ struct OptimizeTopK : public OpRewritePattern { // i.e., num_input/input_last_dim = num_result/k auto input_ty = value.getType().dyn_cast_or_null(); auto result_ty = slice_op.getType().dyn_cast(); - if (!input_ty || !result_ty) return llvm::None; + if (!input_ty || !result_ty) return std::nullopt; if (!input_ty.hasStaticShape() || !result_ty.hasStaticShape()) - return llvm::None; - if (!input_ty.getRank() || !result_ty.getRank()) return llvm::None; + return std::nullopt; + if (!input_ty.getRank() || !result_ty.getRank()) return std::nullopt; int num_input = input_ty.getNumElements(); int input_last_dim = input_ty.getShape().back(); - if (input_last_dim < 1) return llvm::None; + if (input_last_dim < 1) return std::nullopt; int num_result = result_ty.getNumElements(); auto size_last = *(--size.value_end()); int32_t k = size_last.getSExtValue(); - if (num_input / input_last_dim * k != num_result) return llvm::None; + if (num_input / input_last_dim * k != num_result) return std::nullopt; // We don't match sliceOp with last dim size = 0. - if (!k) return llvm::None; + if (!k) return std::nullopt; return k; } @@ -1586,8 +1587,8 @@ struct OptimizeTopK : public OpRewritePattern { auto k_values_or = ComputeSliceK(values); auto k_indices_or = ComputeSliceK(indices); if (!k_values_or.has_value() || !k_indices_or.has_value()) return failure(); - int32_t k_values = k_values_or.getValue(); - int32_t k_indices = k_indices_or.getValue(); + int32_t k_values = k_values_or.value(); + int32_t k_indices = k_indices_or.value(); // We don't match two SliceOp with different sizes. 
if (k_values != k_indices && !values.use_empty() && !indices.use_empty()) return failure(); diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc index ef3c6d66e6a..7d7ab4b5acd 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc @@ -21,7 +21,7 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project @@ -56,7 +56,7 @@ void UpdateFuncType(func::FuncOp func) { auto return_types = llvm::to_vector<4>(terminator->getOperandTypes()); FunctionType func_type = func.getFunctionType(); - if (llvm::makeArrayRef(return_types) == func_type.getResults()) return; + if (llvm::ArrayRef(return_types) == func_type.getResults()) return; auto updated_type = FunctionType::get(func.getContext(), func_type.getInputs(), return_types); @@ -67,7 +67,7 @@ void UpdateFuncType(func::FuncOp func) { bool IsSideEffectFree(func::FuncOp func) { return !func.getBody() .walk([&](Operation* op) { - if (!MemoryEffectOpInterface::hasNoEffect(op) && + if (!isMemoryEffectFree(op) && !op->hasTrait()) return WalkResult::interrupt(); return WalkResult::advance(); @@ -99,7 +99,7 @@ class FoldIfOp : public OpRewritePattern { // remove. // TODO(jpienaar): Remove once recusive side-effects are supported. if (op.use_empty() && - (op.is_stateless() || + (op.getIsStateless() || (IsSideEffectFree(then_func) && IsSideEffectFree(else_func)))) { rewriter.eraseOp(op.getOperation()); return success(); @@ -107,7 +107,7 @@ class FoldIfOp : public OpRewritePattern { // Extract the constant cond value. DenseElementsAttr cond; - if (!matchPattern(op.cond(), m_Constant(&cond))) return failure(); + if (!matchPattern(op.getCond(), m_Constant(&cond))) return failure(); // TODO(hinsu): Handle constants that are not scalar booleans. auto cond_type = cond.getType().dyn_cast(); @@ -124,7 +124,7 @@ class FoldIfOp : public OpRewritePattern { // one blocks are not encountered in practice. 
if (!llvm::hasSingleElement(func)) return failure(); - BlockAndValueMapping mapper; + IRMapping mapper; for (int i = 0, e = func.getNumArguments(); i != e; ++i) mapper.map(func.getArgument(i), op.getOperand(i + 1)); diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index f6b28b35e20..e772f9cb88d 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -129,25 +129,25 @@ multiclass FuseBinaryOpToPrecedingAffine { (HasRank<1> $value), (HasOneUse $output)]>; def FuseBinaryOpWithTransposeConv#binaryOp : Pat< - (binaryOp (TFL_TransposeConvOp:$output $output_shape, $weights, $inputs, + (binaryOp (TFL_TransposeConvOp:$output $output_shape, $weights, $input, (Arith_ConstantOp FloatElementsAttr:$bias), $padding, - $stride_h, $stride_w), - (Arith_ConstantOp FloatElementsAttr:$value), TFL_AF_None), - (TFL_TransposeConvOp $output_shape, $weights, $inputs, + $stride_h, $stride_w, TFL_AF_None), + (Arith_ConstantOp FloatElementsAttr:$value), $act_fn), + (TFL_TransposeConvOp $output_shape, $weights, $input, (binaryOp (Arith_ConstantOp $bias), (Arith_ConstantOp $value), TFL_AF_None), - $padding, $stride_h, $stride_w), + $padding, $stride_h, $stride_w, $act_fn), [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value), (HasOneUse $output)]>; // Fuse for TransposeConv with no bias def FuseBinaryOpWithTransposeConvNoneBias#binaryOp : Pat< - (binaryOp (TFL_TransposeConvOp:$output $output_shape, $weights, $inputs, + (binaryOp (TFL_TransposeConvOp:$output $output_shape, $weights, $input, $bias, $padding, - $stride_h, $stride_w), - (Arith_ConstantOp FloatElementsAttr:$value), TFL_AF_None), - (TFL_TransposeConvOp $output_shape, $weights, $inputs, + $stride_h, $stride_w, TFL_AF_None), + (Arith_ConstantOp FloatElementsAttr:$value), $act_fn), + (TFL_TransposeConvOp $output_shape, $weights, $input, (Arith_ConstantOp $value), - $padding, $stride_h, $stride_w), + $padding, $stride_h, $stride_w, $act_fn), [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value), (IsNoneType $bias), (HasOneUse $output)]>; @@ -209,8 +209,8 @@ multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d { (BinaryOp (TFL_TransposeConvOp:$output $output_shape, (Arith_ConstantOp FloatElementsAttr:$weights), $input, (Arith_ConstantOp FloatElementsAttr:$bias), - $padding, $stride_h, $stride_w), - (Arith_ConstantOp $value), TFL_AF_None), + $padding, $stride_h, $stride_w, TFL_AF_None), + (Arith_ConstantOp $value), $act_fn), (TFL_TransposeConvOp $output_shape, (BinaryOp (Arith_ConstantOp $weights), (Arith_ConstantOp (ExpandTo4DForConv $value)), @@ -219,22 +219,22 @@ multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d { (BinaryOp (Arith_ConstantOp $bias), (Arith_ConstantOp $value), TFL_AF_None), - $padding, $stride_h, $stride_w), + $padding, $stride_h, $stride_w, $act_fn), [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value), (HasOneUse $output)]>; def FuseMulOrDivWithTransposeConvWithNoneBias#BinaryOp : Pat< (BinaryOp (TFL_TransposeConvOp:$output $output_shape, (Arith_ConstantOp FloatElementsAttr:$weights), $input, $bias, - $padding, $stride_h, $stride_w), - (Arith_ConstantOp $value), TFL_AF_None), + $padding, $stride_h, $stride_w, TFL_AF_None), + (Arith_ConstantOp $value), $act_fn), (TFL_TransposeConvOp $output_shape, (BinaryOp (Arith_ConstantOp $weights), (Arith_ConstantOp (ExpandTo4DForConv $value)), TFL_AF_None), $input, $bias, - $padding, $stride_h, $stride_w), + $padding, 
$stride_h, $stride_w, $act_fn), [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value), (IsNoneType $bias), (HasOneUse $output)]>; diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.h b/tensorflow/compiler/mlir/lite/transforms/passes.h index 30fd466963a..7777a1e2ded 100644 --- a/tensorflow/compiler/mlir/lite/transforms/passes.h +++ b/tensorflow/compiler/mlir/lite/transforms/passes.h @@ -49,7 +49,6 @@ class OperationPass; class Type; namespace TFL { -using StringSet = absl::flat_hash_set; // Creates an instance of the TensorFlow Lite dialect LegalizeTF pass. // When the given run_tfl_runtime_verification value is true, it will check each @@ -84,7 +83,8 @@ std::unique_ptr> CreateLowerStaticTensorListPass(); // as they are now structure variables of QuantizationSpecs. std::unique_ptr> CreateQuantizePass( const quant::QuantizationSpecs& quant_specs, - const StringSet& ops_blocklist = {}, const StringSet& nodes_blocklist = {}); + const absl::flat_hash_set& ops_blocklist = {}, + const absl::flat_hash_set& nodes_blocklist = {}); std::unique_ptr> CreateDefaultQuantizePass(); @@ -92,8 +92,9 @@ std::unique_ptr> CreateDefaultQuantizePass(); // the binary size. std::unique_ptr> CreateQuantizePass( bool verify_numeric = false, bool whole_model_verify = false, - bool legacy_float_scale = false, const StringSet& ops_blocklist = {}, - const StringSet& nodes_blocklist = {}); + bool legacy_float_scale = false, + const absl::flat_hash_set& ops_blocklist = {}, + const absl::flat_hash_set& nodes_blocklist = {}); // Creates an instance of the TensorFlow Lite dialect PrepareQuantize pass. std::unique_ptr> CreatePrepareQuantizePass( @@ -115,6 +116,9 @@ std::unique_ptr> CreatePostQuantizePass(); std::unique_ptr> CreatePostQuantizePass( bool emit_quant_adaptor_ops, const quant::CustomOpMap& custom_op_map = {}); +// Creates an instance of the TensorFlow Lite dialect QuantizeVariables pass. +std::unique_ptr> CreatePrepareQuantizeVariablesPass(); + // Creates an instance of the TensorFlow Lite pass that decomposes hybrid // quantization patterns to the same dense operation with tfl dequantization // and quantization patterns. diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.td b/tensorflow/compiler/mlir/lite/transforms/passes.td index 3fa85cd3504..d54d34d2033 100644 --- a/tensorflow/compiler/mlir/lite/transforms/passes.td +++ b/tensorflow/compiler/mlir/lite/transforms/passes.td @@ -288,6 +288,8 @@ def PrepareQuantizePass : Pass<"tfl-prepare-quantize", "mlir::func::FuncOp"> { "comma separated list of allowlisted functions to be quantized. Only used in tests">, Option<"quantize_signed_", "quantize-signed", "bool", "false", "signed inference type. Only used in tests">, + Option<"activation_number_of_bits_", "activation-number-of-bits", "int", "8", + "number of bits for inference type. Only used in tests">, Option<"post_training_quantize_", "post-training-quantize", "bool", "false", "enable post training quantization. 
Only used in tests">, Option<"legacy_float_scale_", "legacy-float-scale", "bool", "false", @@ -384,6 +386,12 @@ def QuantizePass : Pass<"tfl-quantize", "mlir::func::FuncOp"> { ]; } +def QuantizeVariablesPass : Pass<"tfl-quantize-variables", "mlir::ModuleOp"> { + let summary = "Quantize variables"; + let constructor = "CreatePrepareQuantizeVariablesPass()"; + let dependentDialects = ["TFL::TensorFlowLiteDialect"]; +} + def RaiseCustomOpsPass : Pass<"tfl-raise-custom-ops", "mlir::func::FuncOp"> { let summary = "Raise custom ops into tflite dialect."; let constructor = "CreateRaiseCustomOpsPass()"; diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc index 59fee3b765e..3df42acf5e3 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc @@ -399,6 +399,21 @@ void PrepareCompositeFunctionsPass::ConvertTFAPIImplements(func::FuncOp func, if (failed(ConvertKerasLSTMLayer(func, &builder))) return signalPassFailure(); } + + // LSTM `func::FuncOps` with indy behavior always have the `tf.api_implements` + // function attribute prefixed with `"indy_lstm_"`. + // IndyLSTMs have diagonal recurrent weight matrices and can benefit from + // more efficent operations in TFLite with the correct conversion (i.e. when + // the diagonal recurrent weight matrices are provided as vectors). + if (attr.getValue().startswith("indy_lstm_")) { + // Check if the keras lstm can be fused, if not, we just don't do anything. + if (failed(CheckFusableKerasLstm(func, module))) return; + func.eraseBody(); + func.addEntryBlock(); + OpBuilder builder(func.getBody()); + if (failed(ConvertKerasLSTMLayer(func, &builder, true))) + return signalPassFailure(); + } } void PrepareCompositeFunctionsPass::runOnOperation() { diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc index c9a8444b4be..1fa65305a97 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc @@ -190,9 +190,8 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(func::FuncOp func) { if (!min_max.first.has_value() || !min_max.second.has_value()) return; TypeAttr params = quant::GetQuantizedTypeAttr( - builder, input_type, - builder.getF64FloatAttr(min_max.first.getValue()), - builder.getF64FloatAttr(min_max.second.getValue()), + builder, input_type, builder.getF64FloatAttr(min_max.first.value()), + builder.getF64FloatAttr(min_max.second.value()), /*quant_dim=*/-1, num_bits, narrow_range, is_signed); builder.setInsertionPoint(block, insertion_point); auto q_op = builder.create( @@ -327,8 +326,14 @@ void PrepareQuantizePass::runOnOperation() { MLIRContext* ctx = func.getContext(); ScopedTFLQuantOpsToMlirQuantOpsConverter converter(func); if (use_quantization_flags_) { - quant_specs_.inference_type = - this->quantize_signed_ ? 
tensorflow::DT_QINT8 : tensorflow::DT_QUINT8; + quant_specs_.inference_type = GetQuantizedInferenceType( + this->quantize_signed_, this->activation_number_of_bits_); + if (quant_specs_.inference_type == tensorflow::DT_INVALID) { + func.emitError() << "prepare-quantize pass failed: unsupported " + "inference type specification"; + signalPassFailure(); + return; + } quant_specs_.post_training_quantization = post_training_quantize_; quant_specs_.legacy_float_scale = legacy_float_scale_; quant_specs_.disable_set_input_nodes_quantization_params = @@ -378,16 +383,14 @@ void PrepareQuantizePass::runOnOperation() { if (is_signed) { patterns_2.add>( ctx); - // Convert quant stats to int8 quantization parameters. - // Currently, only activation stats are imported, so narrow_range = false. - patterns_2.add(bit_width, false, true, - quant_specs_.legacy_float_scale, ctx); - } else { - // Convert quant stats to uint8 quantization parameters. - // Currently, only activation stats are imported, so narrow_range = false. - patterns_2.add(bit_width, false, false, - quant_specs_.legacy_float_scale, ctx); } + // Convert quant stats to int8, unit8, int16 quantization parameters. + // Currently, only activation stats are imported, so narrow_range = false. + // TODO(b/266524882): Support narrow_range in TFLite converter(ODML + // converter). + patterns_2.add(bit_width, /*narrow_range=*/false, + is_signed, quant_specs_.legacy_float_scale, + ctx); if (quant_specs_.post_training_quantization) { patterns_2.add>(ctx, quant_specs_); diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc index a5d2099fa51..04f7fb84011 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc @@ -419,7 +419,7 @@ void PrepareDynamicRangeQuantizePass::runOnOperation() { if (!enable_custom_op_quantization_.empty()) { ParseCustomOpSpecs(enable_custom_op_quantization_, - quant::CustomOpUpdateOptions::kINputIndices, + quant::CustomOpUpdateOptions::kInputIndices, quant_specs_.custom_map); } diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.cc new file mode 100644 index 00000000000..ad48c5f1407 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.cc @@ -0,0 +1,43 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h" + +#include + +#include "absl/container/flat_hash_set.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace mlir { +namespace TFL { + +double PowerOfTwoBound(double value) { + return std::pow(2, std::ceil(std::log2(value))); +} + +tensorflow::DataType GetQuantizedInferenceType(bool is_signed, + int number_of_bits) { + if (is_signed && number_of_bits == 8) { + return tensorflow::DT_QINT8; + } else if (!is_signed && number_of_bits == 8) { + return tensorflow::DT_QUINT8; + } else if (is_signed && number_of_bits == 16) { + return tensorflow::DT_QINT16; + } else { + return tensorflow::DT_INVALID; + } +} + +} // namespace TFL +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h index 644751d0321..90a48d577ef 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h @@ -62,9 +62,10 @@ constexpr const char* intermediate_attributes[] = { "effective_hidden_scale_intermediate"}; // Calculates the minimum power of two that is not less than the value. -inline double PowerOfTwoBound(double value) { - return std::pow(2, std::ceil(std::log2(value))); -} +double PowerOfTwoBound(double value); + +tensorflow::DataType GetQuantizedInferenceType(bool is_signed, + int activation_number_of_bits); // Returns the element type of LSTM's intermediate tensor designated by the // index. @@ -84,9 +85,10 @@ using Q = quantfork::QuantizeCastOp; using DQ = quantfork::DequantizeCastOp; template -LogicalResult GetLstmProperty( - LstmOp op, operator_property::OpVariant* lstm_variant, - operator_property::OperatorProperty* op_property) { +LogicalResult GetLstmProperty(LstmOp op, + operator_property::OpVariant* lstm_variant, + operator_property::OperatorProperty* op_property, + int activation_number_of_bits = 8) { if (llvm::isa(op.getOperation())) { lstm_variant->op_code = tflite::BuiltinOperator_LSTM; } else if (llvm::isa(op.getOperation())) { @@ -103,7 +105,8 @@ LogicalResult GetLstmProperty( lstm_variant->use_layer_norm = !op.getForgetLayerNormCoefficients().getType().template isa(); - *op_property = operator_property::GetOperatorProperty(*lstm_variant); + *op_property = operator_property::GetOperatorProperty( + *lstm_variant, activation_number_of_bits); // TODO(b/176258587) move this to operator_property.cc if this is needed in // other components, too. 
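The new prepare_quantize_helper.cc above gives the prepare-quantize pass one place that maps the (signed, bit-width) pair to a TensorFlow quantized DataType and rejects unsupported combinations with DT_INVALID. A self-contained sketch of the same mapping, using a stand-in enum so it builds without TensorFlow headers (all names here are illustrative, not the TensorFlow API):

#include <cassert>
#include <cmath>

enum class InferenceType { kQInt8, kQUInt8, kQInt16, kInvalid };

// Mirrors PowerOfTwoBound: the smallest power of two not less than the value.
double PowerOfTwoBoundSketch(double value) {
  return std::pow(2, std::ceil(std::log2(value)));
}

// Mirrors GetQuantizedInferenceType: signed/unsigned 8-bit and signed 16-bit
// activations are accepted; anything else is rejected.
InferenceType GetInferenceTypeSketch(bool is_signed, int number_of_bits) {
  if (is_signed && number_of_bits == 8) return InferenceType::kQInt8;
  if (!is_signed && number_of_bits == 8) return InferenceType::kQUInt8;
  if (is_signed && number_of_bits == 16) return InferenceType::kQInt16;
  return InferenceType::kInvalid;
}

int main() {
  assert(GetInferenceTypeSketch(true, 16) == InferenceType::kQInt16);
  assert(GetInferenceTypeSketch(false, 16) == InferenceType::kInvalid);
  assert(PowerOfTwoBoundSketch(3.0) == 4.0);
  return 0;
}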
@@ -308,12 +311,12 @@ class ConvertOpStatsToQDQs : public OpRewritePattern { rewriter.getIntegerType(16), attr.getType().getElementType(), scale, /*zeroPoint=*/0, llvm::minIntN(10), -llvm::minIntN(10)); } else { - quant_type = - quant::GetUniformQuantizedTypeForWeight( - attr, /*symmetric=*/true, - /*num_bits=*/tensor_property.number_of_bits, /*is_signed=*/true, - /*narrow_range=*/true, quant_specs_.legacy_float_scale) - .template dyn_cast(); + quant_type = quant::GetUniformQuantizedTypeForWeight( + attr, /*symmetric=*/true, + /*num_bits=*/tensor_property.number_of_bits, + /*is_signed=*/true, + /*narrow_range=*/true, quant_specs_.legacy_float_scale) + .template dyn_cast(); } if (!quant_type) { const_op->emitError("Failed to get quantized type"); @@ -405,13 +408,14 @@ class ConvertLstmStatsToQDQs : public ConvertOpStatsToQDQs { public: ConvertLstmStatsToQDQs(MLIRContext* context, const quant::QuantizationSpecs& quant_specs) - - : ConvertOpStatsToQDQs(context, quant_specs) {} + : ConvertOpStatsToQDQs(context, quant_specs), + activation_number_of_bits_(quant_specs.GetQuantizationTypeWidth()) {} LogicalResult matchAndRewrite(SourceOp op, PatternRewriter& rewriter) const override { operator_property::OpVariant lstm_variant; operator_property::OperatorProperty lstm_property; - if (failed(GetLstmProperty(op, &lstm_variant, &lstm_property))) { + if (failed(GetLstmProperty(op, &lstm_variant, &lstm_property, + activation_number_of_bits_))) { return failure(); } @@ -491,6 +495,8 @@ class ConvertLstmStatsToQDQs : public ConvertOpStatsToQDQs { } return success(); } + + int activation_number_of_bits_; }; // Returns a function that returns the quantized type of a bias input. diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index 42a6bf643b0..465827429db 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -64,6 +64,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/lite/utils/constant_utils.h" #include "tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h" +#include "tensorflow/compiler/mlir/lite/utils/size_utils.h" #include "tensorflow/compiler/mlir/lite/utils/validators.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h" @@ -72,7 +73,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" #define DEBUG_TYPE "tf-tfl-legalization" @@ -167,9 +168,9 @@ class ConvertTFConvOp : public RewritePattern { // [1, X, Y, 1] if exists. 
TFConvOpType tf_op = cast(op); - if (!TFTypeIsFloat32Tensor(tf_op.input()) && + if (!TFTypeIsFloat32Tensor(tf_op.getInput()) && !(allow_bf16_and_f16_type_legalization_ && - TFTypeIsBFloat16OrHalfTensor(tf_op.input()))) + TFTypeIsBFloat16OrHalfTensor(tf_op.getInput()))) return failure(); if (!TFDataFormatIsNHWC(op)) return failure(); @@ -196,13 +197,13 @@ class ConvertTFConvOp : public RewritePattern { // Additionally, we require the filter operand to be of 4-D tensor type so // that we can extract info from the shape (e.g., for constructing bias // tensor, for setting depth_multiplier attribute, etc.). - auto filter = tf_op.filter(); + auto filter = tf_op.getFilter(); auto filter_type = filter.getType().template dyn_cast(); if (!filter_type || filter_type.getRank() != 4 || !filter_type.hasStaticShape()) return failure(); - Value input = tf_op.input(); + Value input = tf_op.getInput(); RankedTensorType input_type = input.getType().template dyn_cast(); // Only rank size four input will be only available by the tf.Conv2D @@ -317,7 +318,7 @@ class ConvertTFConv2D : public ConvertTFConvOp { auto perm_type = tensorflow::GetTypeFromTFTensorShape( {static_cast(perm.size())}, rewriter.getIntegerType(32)); auto perm_attr = - DenseElementsAttr::get(perm_type, llvm::makeArrayRef(perm)); + DenseElementsAttr::get(perm_type, llvm::ArrayRef(perm)); auto perm_op = rewriter.create(loc, perm_type, perm_attr); // Create tensor type for the transpose result. @@ -394,8 +395,9 @@ class ConvertTFDepthwiseConv2dNative tensorflow::GetTypeFromTFTensorShape({4}, rewriter.getIntegerType(32)); SmallVector result_shape_data(4); for (int i = 0; i < 4; ++i) { + auto size = result_shape[i]; result_shape_data[i] = - rewriter.getI32IntegerAttr(static_cast(result_shape[i])); + rewriter.getI32IntegerAttr(ConvertToTfliteSize(size)); } auto shape_attr = DenseElementsAttr::get(shape_type, result_shape_data); auto shape = rewriter.create(loc, shape_type, shape_attr); @@ -427,9 +429,9 @@ struct ConvertTFStridedSlice : public RewritePattern { LogicalResult RewriteNewAxisMask(Operation *op, PatternRewriter &rewriter) const { TF::StridedSliceOp strided_slice_op = llvm::cast(op); - uint64_t new_axis_mask = strided_slice_op.new_axis_mask(); + uint64_t new_axis_mask = strided_slice_op.getNewAxisMask(); - if (strided_slice_op.ellipsis_mask() != 0) { + if (strided_slice_op.getEllipsisMask() != 0) { // Ellipsis mask should have been lowered-away prior to invoking this // function. op->emitError() << "encountered a logical error"; @@ -437,7 +439,7 @@ struct ConvertTFStridedSlice : public RewritePattern { } // Insert a new reshape op. - Value original_input = strided_slice_op.input(); + Value original_input = strided_slice_op.getInput(); RankedTensorType original_input_type = original_input.getType().dyn_cast(); if (!original_input_type) { @@ -466,8 +468,9 @@ struct ConvertTFStridedSlice : public RewritePattern { {dim_size}, rewriter.getIntegerType(32)); SmallVector result_shape_data(dim_size); for (int i = 0; i < dim_size; ++i) { + auto size = revised_shape[i]; result_shape_data[i] = - rewriter.getI32IntegerAttr(static_cast(revised_shape[i])); + rewriter.getI32IntegerAttr(ConvertToTfliteSize(size)); } auto shape_attr = DenseElementsAttr::get(shape_type, result_shape_data); @@ -479,25 +482,25 @@ struct ConvertTFStridedSlice : public RewritePattern { loc, revised_output_type, original_input, shape); // Replace the original strided_slice. 
- uint64_t revised_begin_mask = strided_slice_op.begin_mask(); - uint64_t revised_end_mask = strided_slice_op.end_mask(); + uint64_t revised_begin_mask = strided_slice_op.getBeginMask(); + uint64_t revised_end_mask = strided_slice_op.getEndMask(); // Since we expand the dims, we need to apply them to the begin_mask & // end_mask. - revised_begin_mask |= strided_slice_op.new_axis_mask(); - revised_end_mask |= strided_slice_op.new_axis_mask(); + revised_begin_mask |= strided_slice_op.getNewAxisMask(); + revised_end_mask |= strided_slice_op.getNewAxisMask(); // Enforce operator precedence. - uint64_t revised_shrink_axis_mask = - strided_slice_op.shrink_axis_mask() & ~strided_slice_op.new_axis_mask(); + uint64_t revised_shrink_axis_mask = strided_slice_op.getShrinkAxisMask() & + ~strided_slice_op.getNewAxisMask(); auto attribute_type = rewriter.getIntegerType(64); rewriter.replaceOpWithNewOp( - op, strided_slice_op.getType(), reshape, strided_slice_op.begin(), - strided_slice_op.end(), strided_slice_op.strides(), + op, strided_slice_op.getType(), reshape, strided_slice_op.getBegin(), + strided_slice_op.getEnd(), strided_slice_op.getStrides(), rewriter.getIntegerAttr(attribute_type, revised_begin_mask), rewriter.getIntegerAttr(attribute_type, revised_end_mask), rewriter.getIntegerAttr(attribute_type, - strided_slice_op.ellipsis_mask()), + strided_slice_op.getEllipsisMask()), rewriter.getI64IntegerAttr(0), rewriter.getIntegerAttr(attribute_type, revised_shrink_axis_mask)); return success(); @@ -507,16 +510,16 @@ struct ConvertTFStridedSlice : public RewritePattern { PatternRewriter &rewriter) const { TF::StridedSliceOp strided_slice_op = llvm::cast(op); - uint64_t ellipsis_mask = strided_slice_op.ellipsis_mask(); - uint64_t shrink_axis_mask = strided_slice_op.shrink_axis_mask(); - uint64_t new_axis_mask = strided_slice_op.new_axis_mask(); + uint64_t ellipsis_mask = strided_slice_op.getEllipsisMask(); + uint64_t shrink_axis_mask = strided_slice_op.getShrinkAxisMask(); + uint64_t new_axis_mask = strided_slice_op.getNewAxisMask(); // Enforce operator precedence. 
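// The two mask updates below make that precedence concrete: any bit already
// set in ellipsis_mask is cleared from shrink_axis_mask and new_axis_mask, so
// dimensions covered by the ellipsis are expanded exactly once.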
shrink_axis_mask &= ~ellipsis_mask; new_axis_mask &= ~ellipsis_mask; DenseIntElementsAttr begin_dense_elem_attr; - Value begin = strided_slice_op.begin(); + Value begin = strided_slice_op.getBegin(); auto begin_ranked_attr_type = begin.getType().dyn_cast(); if (!begin_ranked_attr_type || !matchPattern(begin, m_Constant(&begin_dense_elem_attr))) { @@ -524,7 +527,7 @@ struct ConvertTFStridedSlice : public RewritePattern { } DenseIntElementsAttr end_dense_elem_attr; - Value end = strided_slice_op.end(); + Value end = strided_slice_op.getEnd(); auto end_ranked_attr_type = end.getType().dyn_cast(); if (!end_ranked_attr_type || !matchPattern(end, m_Constant(&end_dense_elem_attr))) { @@ -532,7 +535,7 @@ struct ConvertTFStridedSlice : public RewritePattern { } DenseIntElementsAttr stride_dense_elem_attr; - Value stride = strided_slice_op.strides(); + Value stride = strided_slice_op.getStrides(); auto stride_ranked_attr_type = stride.getType().dyn_cast(); if (!stride_ranked_attr_type || @@ -540,7 +543,7 @@ struct ConvertTFStridedSlice : public RewritePattern { return failure(); } - Value input = strided_slice_op.input(); + Value input = strided_slice_op.getInput(); RankedTensorType input_type = input.getType().dyn_cast(); if (!input_type) { return failure(); @@ -560,8 +563,8 @@ struct ConvertTFStridedSlice : public RewritePattern { const int ellipsis_filled_dim_size = input_size - begin_shape[0] + 1 + absl::popcount(new_axis_mask); - int64_t begin_mask = strided_slice_op.begin_mask(); - int64_t end_mask = strided_slice_op.end_mask(); + int64_t begin_mask = strided_slice_op.getBeginMask(); + int64_t end_mask = strided_slice_op.getEndMask(); int64_t revised_begin_mask = 0; int64_t revised_end_mask = 0; int64_t revised_shrink_axis_mask = 0; @@ -673,24 +676,24 @@ struct ConvertTFStridedSlice : public RewritePattern { TF::StridedSliceOp strided_slice_op = llvm::cast(op); // Handle ellipsis mask. - if (strided_slice_op.ellipsis_mask() != 0) { + if (strided_slice_op.getEllipsisMask() != 0) { return RewriteEllipsisMask(strided_slice_op, rewriter); } // Handle new axis mask. 
- if (strided_slice_op.new_axis_mask() != 0) { + if (strided_slice_op.getNewAxisMask() != 0) { return RewriteNewAxisMask(strided_slice_op, rewriter); } auto ranked_input_type = - strided_slice_op.input().getType().dyn_cast(); + strided_slice_op.getInput().getType().dyn_cast(); if (!ranked_input_type) { return failure(); } - auto begin_attr = strided_slice_op.begin(); - auto end_attr = strided_slice_op.end(); - auto strides_attr = strided_slice_op.strides(); + auto begin_attr = strided_slice_op.getBegin(); + auto end_attr = strided_slice_op.getEnd(); + auto strides_attr = strided_slice_op.getStrides(); auto begin_attr_type = begin_attr.getType().dyn_cast(); auto end_attr_type = end_attr.getType().dyn_cast(); @@ -722,8 +725,8 @@ struct ConvertTFStridedSlice : public RewritePattern { SmallVector padding_end(input_shape.begin(), input_shape.end()); SmallVector padding_strides(num_input_dims, 1); - int begin_mask = strided_slice_op.begin_mask(); - int end_mask = strided_slice_op.end_mask(); + int begin_mask = strided_slice_op.getBeginMask(); + int end_mask = strided_slice_op.getEndMask(); PadStridedSliceAttributeArray(begin_elem_attr, begin, padded_begin, padding_begin, &begin_mask); @@ -734,8 +737,8 @@ struct ConvertTFStridedSlice : public RewritePattern { if (begin == padded_begin && end == padded_end && strides == padded_strides && - begin_mask == strided_slice_op.begin_mask() && - end_mask == strided_slice_op.end_mask()) { + begin_mask == strided_slice_op.getBeginMask() && + end_mask == strided_slice_op.getEndMask()) { return failure(); } @@ -756,16 +759,16 @@ struct ConvertTFStridedSlice : public RewritePattern { auto attribute_type = rewriter.getIntegerType(64); rewriter.replaceOpWithNewOp( - op, strided_slice_op.output().getType(), strided_slice_op.input(), + op, strided_slice_op.getOutput().getType(), strided_slice_op.getInput(), new_begin_attr, new_end_attr, new_strides_attr, rewriter.getIntegerAttr(attribute_type, begin_mask), rewriter.getIntegerAttr(attribute_type, end_mask), rewriter.getIntegerAttr(attribute_type, - strided_slice_op.ellipsis_mask()), + strided_slice_op.getEllipsisMask()), rewriter.getIntegerAttr(attribute_type, - strided_slice_op.new_axis_mask()), + strided_slice_op.getNewAxisMask()), rewriter.getIntegerAttr(attribute_type, - strided_slice_op.shrink_axis_mask())); + strided_slice_op.getShrinkAxisMask())); return success(); } @@ -778,9 +781,12 @@ struct ConvertTFBroadcastTo : public RewritePattern { LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { auto tf_broadcast_to_op = cast(op); - auto input_type = tf_broadcast_to_op.input().getType().cast(); - auto output_type = tf_broadcast_to_op.output().getType().cast(); - auto shape_type = tf_broadcast_to_op.shape().getType().cast(); + auto input_type = + tf_broadcast_to_op.getInput().getType().cast(); + auto output_type = + tf_broadcast_to_op.getOutput().getType().cast(); + auto shape_type = + tf_broadcast_to_op.getShape().getType().cast(); Type element_type = input_type.getElementType(); // Allow lowering when low dimension inputs are given and its type is F32 or @@ -801,11 +807,11 @@ struct ConvertTFBroadcastTo : public RewritePattern { } auto tf_fill_op = rewriter.create(op->getLoc(), output_type, - tf_broadcast_to_op.shape(), + tf_broadcast_to_op.getShape(), status_or_const_op.value()); auto mul_op = rewriter.create( - op->getLoc(), output_type, tf_broadcast_to_op.input(), tf_fill_op); + op->getLoc(), output_type, tf_broadcast_to_op.getInput(), tf_fill_op); rewriter.replaceOp(op, 
mul_op.getResult()); return success(); } @@ -925,7 +931,7 @@ struct FusedBatchNormV3Pat : public ::mlir::RewritePattern { ::mlir::Value mean_value = (*mean.begin()); ::mlir::Value variance_value = (*variance.begin()); - if (!TFTypeIsFloat32Tensor(fused_batch_norm_op.x())) return failure(); + if (!TFTypeIsFloat32Tensor(fused_batch_norm_op.getX())) return failure(); { epsilon = @@ -994,14 +1000,14 @@ struct FusedBatchNormV3Pat : public ::mlir::RewritePattern { auto odsLoc = rewriter.getFusedLoc({fused_batch_norm->getLoc()}); // We need to make sure input and output shapes are compatible. - int64_t last_dim = ShapedType::kDynamicSize; + int64_t last_dim = ShapedType::kDynamic; { auto is_last_dim_compatible = [](const Value &v, int64_t &last_dim) { auto v_type = v.getType().dyn_cast_or_null(); if (!v_type) return true; int64_t v_last_dim = v_type.getDimSize(v_type.getRank() - 1); - if (v_last_dim == ShapedType::kDynamicSize) return true; - if (last_dim != ShapedType::kDynamicSize && v_last_dim != last_dim) + if (v_last_dim == ShapedType::kDynamic) return true; + if (last_dim != ShapedType::kDynamic && v_last_dim != last_dim) return false; last_dim = v_last_dim; return true; @@ -1041,7 +1047,7 @@ struct FusedBatchNormV3Pat : public ::mlir::RewritePattern { // For training, mean and variance is calculated from input values. if (is_training.getValue()) { - auto input_type = fused_batch_norm_op.x() + auto input_type = fused_batch_norm_op.getX() .getType() .dyn_cast_or_null(); if (!input_type || input_type.getRank() != 4) { @@ -1215,7 +1221,9 @@ LogicalResult ConvertTf2XlaOps(func::FuncOp func, MLIRContext *context) { target.addIllegalOp(); RewritePatternSet patterns(context); - mhlo::PopulateLegalizeTfWithTf2XlaPatterns("XLA_CPU_JIT", patterns, context); + mhlo::Tf2XlaTypeConverter converter; + mhlo::PopulateLegalizeTfWithTf2XlaPatterns("XLA_CPU_JIT", patterns, context, + converter); mhlo::PopulateLegalizeTfPatterns(context, &patterns); TF::PopulateLegalizeHloToTfPatterns(&patterns, context); mhlo::GatherOp::getCanonicalizationPatterns(patterns, context); @@ -1248,10 +1256,10 @@ struct ConvertRfftToRfft2d : public RewritePattern { PatternRewriter &rewriter) const override { auto rfft_op = dyn_cast(op); - auto input = rfft_op.input(); + auto input = rfft_op.getInput(); auto input_type = input.getType().dyn_cast_or_null(); if (!input_type) return failure(); - auto fft_len = rfft_op.fft_length(); + auto fft_len = rfft_op.getFftLength(); auto fft_len_type = fft_len.getType().dyn_cast_or_null(); if (!fft_len_type) return failure(); @@ -1328,8 +1336,8 @@ struct RemoveIdentity : public OpRewritePattern { LogicalResult matchAndRewrite(TF::IdentityOp identity, PatternRewriter &rewriter) const override { // Replace the op with the input if input and result have the same type. - if (identity.input().getType() == identity.getType()) { - rewriter.replaceOp(identity, identity.input()); + if (identity.getInput().getType() == identity.getType()) { + rewriter.replaceOp(identity, identity.getInput()); return success(); } // Replace the op with the input if output is only used by TF ops. 
@@ -1343,7 +1351,7 @@ struct RemoveIdentity : public OpRewritePattern { } } - rewriter.replaceOp(identity, identity.input()); + rewriter.replaceOp(identity, identity.getInput()); return success(); } }; diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize.cc b/tensorflow/compiler/mlir/lite/transforms/quantize.cc index 013325de59f..0bc21e41f68 100644 --- a/tensorflow/compiler/mlir/lite/transforms/quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/quantize.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Casting.h" @@ -50,7 +51,7 @@ namespace TFL { //===----------------------------------------------------------------------===// // The actual Quantize Pass. -// +//===----------------------------------------------------------------------===// namespace { #define GEN_PASS_DEF_QUANTIZEPASS #include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc" @@ -58,16 +59,16 @@ namespace { enum QuantizationTrait { kFullQuantization, kDynamicRangeQuantization }; // Base struct for quantization. -template +template struct TFLQuantizationBase - : public quant::QuantizationPattern { + : public quant::QuantizationPattern { explicit TFLQuantizationBase(MLIRContext* ctx, const quant::QuantPassSpec& quant_params) - : quant::QuantizationPattern(ctx, quant_params) { - } + : quant::QuantizationPattern(ctx, + quant_params) {} static bool IsQuantizableCustomOp(Operation* op, const quant::CustomOpMap& custom_op_map) { @@ -77,7 +78,7 @@ struct TFLQuantizationBase // behaviors. In that case, these ops can be marked in the custom map and // treated separately in this pass. - auto custom_op = llvm::dyn_cast_or_null(op); + auto custom_op = llvm::dyn_cast_or_null(op); if (!custom_op) return false; // Custom op which is marked in the custom op map is quantizable. @@ -89,7 +90,6 @@ struct TFLQuantizationBase Operation* quantized_op, const quant::CustomOpMap& custom_op_map) { // Collect the input if dynamic range quantization is on and the op supports // it. - return quantization_trait == kDynamicRangeQuantization && (dyn_cast_or_null(quantized_op) || IsQuantizableCustomOp(quantized_op, custom_op_map)); @@ -99,15 +99,16 @@ struct TFLQuantizationBase Operation* quantized_op, const quant::CustomOpMap& custom_op_map) { // Collect the output if dynamic range quantization is on and the op // supports it. - return quantization_trait == kDynamicRangeQuantization && (dyn_cast_or_null(quantized_op) || IsQuantizableCustomOp(quantized_op, custom_op_map)); } - static bool IsWeightOnlyOp(Operation* quantized_op, StringSet& ops_blocklist, - bool weight_only_quantization, - const quant::CustomOpMap& custom_op_map) { + static bool IsWeightOnlyOp( + Operation* quantized_op, + const absl::flat_hash_set& ops_blocklist, + const bool weight_only_quantization, + const quant::CustomOpMap& custom_op_map) { // Check whether the quantized_op needs to be quantized in weight-only // manner. 
bool is_blocklisted = false; @@ -234,13 +235,13 @@ void QuantizePass::runOnOperation() { quant_specs.weight_quantization = enable_dynamic_range_quantization_; quant_specs.weight_only_quantization = enable_weight_only_quantization_; if (!ops_blocklist_flag_.empty()) { - quant_specs.ops_blocklist = - StringSet(ops_blocklist_flag_.begin(), ops_blocklist_flag_.end()); + quant_specs.ops_blocklist = absl::flat_hash_set( + ops_blocklist_flag_.begin(), ops_blocklist_flag_.end()); } if (!nodes_blocklist_flag_.empty()) { - quant_specs.nodes_blocklist = - StringSet(nodes_blocklist_flag_.begin(), nodes_blocklist_flag_.end()); + quant_specs.nodes_blocklist = absl::flat_hash_set( + nodes_blocklist_flag_.begin(), nodes_blocklist_flag_.end()); } if (!enable_custom_op_weight_only_.empty()) { @@ -254,7 +255,7 @@ void QuantizePass::runOnOperation() { quant_specs.whole_model_verify, enable_log_if_failed_}, quant_specs}; - TFL::populateWithGenerated(patterns); + populateWithGenerated(patterns); if (quant_specs.weight_quantization || quant_specs.use_fake_quant_num_bits) { patterns.add(ctx, quant_params); @@ -277,8 +278,9 @@ void QuantizePass::runOnOperation() { // Creates an instance of the TensorFlow Lite dialect QuantizeTFL pass. std::unique_ptr> CreateQuantizePass( - const quant::QuantizationSpecs& quant_specs, const StringSet& ops_blocklist, - const StringSet& nodes_blocklist) { + const quant::QuantizationSpecs& quant_specs, + const absl::flat_hash_set& ops_blocklist, + const absl::flat_hash_set& nodes_blocklist) { quant::QuantizationSpecs updated_quant_specs; updated_quant_specs = quant_specs; // If there's new blocklists given, update quant_specs to use the new one. @@ -296,8 +298,10 @@ std::unique_ptr> CreateDefaultQuantizePass() { } std::unique_ptr> CreateQuantizePass( - bool verify_numeric, bool whole_model_verify, bool legacy_float_scale, - const StringSet& ops_blocklist, const StringSet& nodes_blocklist) { + const bool verify_numeric, const bool whole_model_verify, + const bool legacy_float_scale, + const absl::flat_hash_set& ops_blocklist, + const absl::flat_hash_set& nodes_blocklist) { quant::QuantizationSpecs quant_specs; quant_specs.verify_numeric = verify_numeric; quant_specs.whole_model_verify = whole_model_verify; diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize_variables.cc b/tensorflow/compiler/mlir/lite/transforms/quantize_variables.cc new file mode 100644 index 00000000000..33580d1ea95 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/transforms/quantize_variables.cc @@ -0,0 +1,208 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/CommandLine.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" + +namespace mlir { +namespace TFL { +namespace { +#define GEN_PASS_CLASSES +#include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc" + +using ResourceIdMap = + absl::flat_hash_map, int>; + +using ResourceMap = absl::flat_hash_map>; + +Type GetQuantizedTypeFromReadVariableOp(VarHandleOp var_handle_op) { + Type ref_qtype = nullptr; + for (auto *var_handle_user : var_handle_op.getResult().getUsers()) { + auto read_variable_op = dyn_cast_or_null(var_handle_user); + if (!read_variable_op) continue; + for (auto *read_variable_user : read_variable_op.getResult().getUsers()) { + auto q_op = dyn_cast_or_null(read_variable_user); + if (!q_op || ref_qtype) continue; + ref_qtype = q_op.getResult().getType(); + } + } + return ref_qtype; +} + +Type GetDequantizedTypeFromAssigneVariableOp(VarHandleOp var_handle_op) { + Type ref_qtype = nullptr; + for (auto *var_handle_user : var_handle_op.getResult().getUsers()) { + auto assign_variable_op = + dyn_cast_or_null(var_handle_user); + if (!assign_variable_op) continue; + auto value_op = assign_variable_op.getValue().getDefiningOp(); + auto dq_op = dyn_cast_or_null(value_op); + if (!dq_op || ref_qtype) continue; + ref_qtype = dq_op.getInput().getType(); + } + return ref_qtype; +} + +class QuantizeVariablesPass + : public QuantizeVariablesPassBase { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(QuantizeVariablesPass) + explicit QuantizeVariablesPass() = default; + + void runOnOperation() override; + + private: + // Outlines the regions of the WhileOp's cond and body and insert function + // calls instead. + void QuantizeVariable(OpBuilder &builder, + const std::vector &var_handle_op); +}; + +void QuantizeVariablesPass::QuantizeVariable( + OpBuilder &builder, const std::vector &var_handle_ops) { + // TODO(b/261940892): Refactoring quantize_variables.cc + Type ref_qtype = nullptr; + for (VarHandleOp var_handle_op : var_handle_ops) { + if (ref_qtype) break; + ref_qtype = GetQuantizedTypeFromReadVariableOp(var_handle_op); + if (ref_qtype) break; + ref_qtype = GetDequantizedTypeFromAssigneVariableOp(var_handle_op); + } + if (!ref_qtype) return; + + for (VarHandleOp var_handle_op : var_handle_ops) { + for (auto *var_handle_user : + llvm::make_early_inc_range(var_handle_op.getResult().getUsers())) { + auto read_variable_op = dyn_cast_or_null(var_handle_user); + if (!read_variable_op) continue; + // Add dequantize. 
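// The rewrite below keeps every user of the variable handle type-correct:
// each ReadVariableOp is re-created with the quantized reference type and
// followed by a DequantizeCastOp restoring the original element type, while
// each AssignVariableOp either forwards an already-quantized value whose
// parameters match, requantizes it to the reference type, or receives a new
// QuantizeCastOp in front of it.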
+ builder.setInsertionPointAfter(read_variable_op); + auto new_read_variable_op = + builder.create(read_variable_op.getLoc(), ref_qtype, + read_variable_op.getResourceId()); + auto new_dq_op = builder.create( + read_variable_op.getLoc(), read_variable_op.getResult().getType(), + new_read_variable_op.getResult()); + read_variable_op->replaceAllUsesWith(new_dq_op); + read_variable_op.erase(); + } + for (auto *var_handle_user : + llvm::make_early_inc_range(var_handle_op.getResult().getUsers())) { + auto assign_variable_op = + dyn_cast_or_null(var_handle_user); + if (!assign_variable_op) continue; + auto *value_op = assign_variable_op.getValue().getDefiningOp(); + auto dq_op = dyn_cast_or_null(value_op); + if (dq_op) { + Type output_type = dq_op.getInput().getType(); + auto qtype = quant::QuantizedType::getQuantizedElementType(output_type); + if (qtype == quant::QuantizedType::getQuantizedElementType(ref_qtype)) { + // Same quantization parameters, remove it. + builder.setInsertionPoint(assign_variable_op); + auto new_assign_variable_op = builder.create( + assign_variable_op.getLoc(), assign_variable_op.getResourceId(), + dq_op.getInput()); + assign_variable_op->replaceAllUsesWith(new_assign_variable_op); + } else { + // Otherwise, apply re-quantization. + builder.setInsertionPoint(assign_variable_op); + auto new_q_op = builder.create( + assign_variable_op.getLoc(), ref_qtype, dq_op.getInput(), + TypeAttr::get(ref_qtype)); + auto new_assign_variable_op = builder.create( + assign_variable_op.getLoc(), assign_variable_op.getResourceId(), + new_q_op.getResult()); + assign_variable_op->replaceAllUsesWith(new_assign_variable_op); + } + assign_variable_op.erase(); + dq_op.erase(); + } else { + // Add quantize op. + builder.setInsertionPoint(assign_variable_op); + auto new_q_op = builder.create( + assign_variable_op.getLoc(), ref_qtype, + assign_variable_op.getValue(), TypeAttr::get(ref_qtype)); + auto new_assign_variable_op = builder.create( + assign_variable_op.getLoc(), assign_variable_op.getResourceId(), + new_q_op.getResult()); + assign_variable_op->replaceAllUsesWith(new_assign_variable_op); + assign_variable_op.erase(); + } + } + } + // Update resource tensors. + for (VarHandleOp var_handle_op : var_handle_ops) { + builder.setInsertionPoint(var_handle_op); + auto output_type = UnrankedTensorType::get(TF::ResourceType::get( + {ref_qtype.cast()}, builder.getContext())); + auto new_var_handle_op = builder.create( + var_handle_op.getLoc(), output_type, var_handle_op.getContainer(), + var_handle_op.getSharedName()); + var_handle_op->replaceAllUsesWith(new_var_handle_op); + var_handle_op.erase(); + } +} + +void QuantizeVariablesPass::runOnOperation() { + ResourceIdMap resource_id_map; + ResourceMap resource_map; + + // Collect all resource identities. + getOperation().walk([&](TFL::VarHandleOp var_handle_op) { + auto identity = std::make_pair(var_handle_op.getContainer().str(), + var_handle_op.getSharedName().str()); + resource_id_map.insert( + std::make_pair(identity, static_cast(resource_id_map.size()))); + int resource_id = resource_id_map[identity]; + resource_map[resource_id].push_back(var_handle_op); + }); + + OpBuilder builder(getOperation().getContext()); + + for (const auto &[identity, var_handle_op] : resource_map) { + QuantizeVariable(builder, var_handle_op); + } +} +} // namespace + +// Creates an instance of the TensorFlow Lite dialect Quantize Variables pass. 
+std::unique_ptr> CreatePrepareQuantizeVariablesPass() { + return std::make_unique(); +} + +} // namespace TFL +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/transforms/reduce_while_operands.cc b/tensorflow/compiler/mlir/lite/transforms/reduce_while_operands.cc index 0c0b0b11bc0..19109ce3b29 100644 --- a/tensorflow/compiler/mlir/lite/transforms/reduce_while_operands.cc +++ b/tensorflow/compiler/mlir/lite/transforms/reduce_while_operands.cc @@ -47,7 +47,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" namespace mlir { namespace TFL { @@ -142,8 +142,7 @@ bool AllOperationSafe(Block &block) { auto walk_result = block.walk([&](Operation *op) { // op has SideEffect. if (!isa_and_nonnull(op) && - !op->hasTrait() && - !MemoryEffectOpInterface::hasNoEffect(op)) { + !op->hasTrait() && !isMemoryEffectFree(op)) { return WalkResult::interrupt(); } // op has implict arguments not listed in operands. diff --git a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc index c6c019e604d..560bcfd3543 100644 --- a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc +++ b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc @@ -79,8 +79,10 @@ bool IsAlreadyOutlined(WhileOp while_op) { bool IsCompatibleTypeWithTFLCastOp(Type type) { auto elemType = getElementTypeOrSelf(type); - // F32 and BF16 types are allowed. - if (elemType.isBF16() || elemType.isF32()) return true; + // F16, F32, F64, BF16 types are allowed. + if (elemType.isBF16() || elemType.isF16() || elemType.isF32() || + elemType.isF64()) + return true; // I1, I8 I16, I32, I64 types are allowed. if (elemType.isInteger(1) || elemType.isInteger(8) || @@ -180,7 +182,7 @@ void ReplaceRegionWithCall(StringRef name, Region& region, auto block = b.createBlock(®ion); SmallVector new_operands; new_operands.reserve(types.size()); - for (Type t : llvm::makeArrayRef(types).drop_back(extern_values.size())) + for (Type t : llvm::ArrayRef(types).drop_back(extern_values.size())) new_operands.push_back(block->addArgument(t, loc)); for (Value v : extern_values) new_operands.push_back(v); auto call = b.create(loc, func, new_operands); diff --git a/tensorflow/compiler/mlir/lite/utils/constant_utils.cc b/tensorflow/compiler/mlir/lite/utils/constant_utils.cc index b55d0713cb3..8bd8f0ee3d5 100644 --- a/tensorflow/compiler/mlir/lite/utils/constant_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/constant_utils.cc @@ -28,7 +28,7 @@ limitations under the License. namespace mlir { namespace TFL { -stream_executor::port::StatusOr CreateConstOpWithSingleValue( +tsl::StatusOr CreateConstOpWithSingleValue( PatternRewriter* rewriter, Location loc, ShapedType shaped_type, int value) { Type element_type = shaped_type.getElementType(); diff --git a/tensorflow/compiler/mlir/lite/utils/constant_utils.h b/tensorflow/compiler/mlir/lite/utils/constant_utils.h index a8a40723835..9c107211b71 100644 --- a/tensorflow/compiler/mlir/lite/utils/constant_utils.h +++ b/tensorflow/compiler/mlir/lite/utils/constant_utils.h @@ -22,13 +22,13 @@ limitations under the License. 
#include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" +#include "tensorflow/tsl/platform/statusor.h" namespace mlir { namespace TFL { // Returns a Constant op with a single value. -stream_executor::port::StatusOr CreateConstOpWithSingleValue( +tsl::StatusOr CreateConstOpWithSingleValue( PatternRewriter* rewriter, Location loc, ShapedType shaped_type, int value); } // namespace TFL diff --git a/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h b/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h index 5db4c759f9a..093e53c0869 100644 --- a/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h +++ b/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h @@ -36,8 +36,8 @@ struct FetchMinMaxAttrs { using AttrType = FloatAttr; bool operator()(TFFakeQuantOp tf_op, AttrType &min_value, AttrType &max_value) const { - min_value = tf_op.minAttr(); - max_value = tf_op.maxAttr(); + min_value = tf_op.getMinAttr(); + max_value = tf_op.getMaxAttr(); return true; // Successfully matched and fetched. } }; @@ -47,7 +47,7 @@ struct FetchConstantMinMaxInputs { using AttrType = DenseFPElementsAttr; bool operator()(TFFakeQuantOp tf_op, AttrType &min_value, AttrType &max_value) const { - Value min = tf_op.min(), max = tf_op.max(); + Value min = tf_op.getMin(), max = tf_op.getMax(); if (!matchPattern(min, m_Constant(&min_value))) { return false; } @@ -105,7 +105,7 @@ class InsertTFLQuantOpsAfterTFFakeQuantOp { LogicalResult matchAndRewrite(TFFakeQuantOp tf_op, OpBuilder &rewriter) const { // We don't want to insert quantize/dequantize if the quantize op exists. - auto res = tf_op.outputs(); + auto res = tf_op.getOutputs(); if (!res.hasOneUse() || isa(*res.user_begin())) { return failure(); } @@ -127,8 +127,8 @@ class InsertTFLQuantOpsAfterTFFakeQuantOp { // Use the min/max from the operands and the num_bits and narrow_range // attribute to create the quantization parameter for the new quantize op. rewriter.setInsertionPointAfter(tf_op.getOperation()); - IntegerAttr num_bits = rewriter.getI64IntegerAttr(tf_op.num_bits()); - BoolAttr narrow_range = rewriter.getBoolAttr(tf_op.narrow_range()); + IntegerAttr num_bits = rewriter.getI64IntegerAttr(tf_op.getNumBits()); + BoolAttr narrow_range = rewriter.getBoolAttr(tf_op.getNarrowRange()); Type res_type = tf_op.getType(); TypeAttr qtype = quant::GetQuantizedTypeAttr( rewriter, res_type, min_value, max_value, quant_dim, num_bits, @@ -141,7 +141,7 @@ class InsertTFLQuantOpsAfterTFFakeQuantOp { // Finally, use the quantization parameter to create the quantize and // dequantize ops, and insert them between the tf.FakeQuantWithMinMaxVarsOp // and its users. 
- Value value = tf_op.outputs(); + Value value = tf_op.getOutputs(); auto quantize = rewriter.create( tf_op.getLoc(), qtype.getValue(), value, qtype); auto dequantize = rewriter.create( diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc index f9ca9318c1b..50f2a9bd4ca 100644 --- a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc @@ -594,6 +594,15 @@ TF::ConstOp CreateScalarConstantOp(int value, Location loc, return builder->create(loc, builder->getI32IntegerAttr(value)); } +TF::ReshapeOp CreateFlattenOP(const Value& input, Location loc, + OpBuilder* builder) { + auto output_shape = Create1DConstantOp({-1}, loc, builder); + return builder->create( + loc, + /*tensor=*/input, + /*shape=*/output_shape.getResult()); +} + LogicalResult CreateEqualSizeSplitVOp(Value input, int axis, int splits, Location loc, OpBuilder* builder, Operation** result) { @@ -630,9 +639,14 @@ LogicalResult CreateEqualSizeSplitVOp(Value input, int axis, int splits, return success(); } -// TODO(b/147436982): Consider refactor this to be more general. +// TODO(b/147436982): Consider refactoring these to be more general. LogicalResult ConvertKerasLSTMLayer(mlir::func::FuncOp func_op, OpBuilder* builder) { + return ConvertKerasLSTMLayer(func_op, builder, false); +} + +LogicalResult ConvertKerasLSTMLayer(mlir::func::FuncOp func_op, + OpBuilder* builder, bool indy) { // For argument order, please check out standard_lstm under // tensorflow/python/keras/layers/recurrent_v2.py Value input = func_op.getArgument(0); @@ -707,6 +721,34 @@ LogicalResult ConvertKerasLSTMLayer(mlir::func::FuncOp func_op, &recurrent_weights_array))) return failure(); + // Reshape recurrent weights to vectors if indy behaviour is enabled. + // IndyLSTMs are a LSTM variant with diagonal recurrent weight + // matrices. For optimization purposes these are provided as vectors. + Value recurrent_to_input_weights = + indy ? CreateFlattenOP(recurrent_weights_array->getResult(0), + func_op.getLoc(), builder) + .getResult() + .cast() + : recurrent_weights_array->getResult(0); + Value recurrent_to_forget_weights = + indy ? CreateFlattenOP(recurrent_weights_array->getResult(1), + func_op.getLoc(), builder) + .getResult() + .cast() + : recurrent_weights_array->getResult(1); + Value recurrent_to_cell_weights = + indy ? CreateFlattenOP(recurrent_weights_array->getResult(2), + func_op.getLoc(), builder) + .getResult() + .cast() + : recurrent_weights_array->getResult(2); + Value recurrent_to_output_weights = + indy ? 
CreateFlattenOP(recurrent_weights_array->getResult(3), + func_op.getLoc(), builder) + .getResult() + .cast() + : recurrent_weights_array->getResult(3); + // Splits the bias into 4: Operation* bias_array; if (failed(CreateEqualSizeSplitVOp(bias, 0, splits, func_op.getLoc(), builder, @@ -731,10 +773,10 @@ LogicalResult ConvertKerasLSTMLayer(mlir::func::FuncOp func_op, /*input_to_forget_weights=*/weights_array->getResult(1), /*input_to_cell_weights=*/weights_array->getResult(2), /*input_to_output_weights=*/weights_array->getResult(3), - /*recurrent_to_input_weights=*/recurrent_weights_array->getResult(0), - /*recurrent_to_forget_weights=*/recurrent_weights_array->getResult(1), - /*recurrent_to_cell_weights=*/recurrent_weights_array->getResult(2), - /*recurrent_to_output_weights=*/recurrent_weights_array->getResult(3), + /*recurrent_to_input_weights=*/recurrent_to_input_weights, + /*recurrent_to_forget_weights=*/recurrent_to_forget_weights, + /*recurrent_to_cell_weights=*/recurrent_to_cell_weights, + /*recurrent_to_output_weights=*/recurrent_to_output_weights, /*cell_to_input_weights=*/none, /*cell_to_forget_weights=*/none, /*cell_to_output_weights=*/none, @@ -755,6 +797,7 @@ LogicalResult ConvertKerasLSTMLayer(mlir::func::FuncOp func_op, /*proj_clip*/ builder->getF32FloatAttr(0.0), /*time_major*/ builder->getBoolAttr(time_majored), /*asymmetric_quantize_inputs=*/mlir::BoolAttr(), + /*diagonal_recurrent_tensors=*/builder->getBoolAttr(indy), /*input_to_input_intermediate=*/mlir::TypeAttr(), /*input_to_forget_intermediate=*/mlir::TypeAttr(), /*input_to_cell_intermediate=*/mlir::TypeAttr(), diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils.h b/tensorflow/compiler/mlir/lite/utils/lstm_utils.h index 6749f824943..7421fe2faa8 100644 --- a/tensorflow/compiler/mlir/lite/utils/lstm_utils.h +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils.h @@ -211,6 +211,9 @@ class ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM LogicalResult ConvertKerasLSTMLayer(mlir::func::FuncOp func_op, OpBuilder* builder); +LogicalResult ConvertKerasLSTMLayer(mlir::func::FuncOp func_op, + OpBuilder* builder, bool indy); + } // end namespace TFL } // end namespace mlir diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc index d7aca785e2b..342bbb5c7fe 100644 --- a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc @@ -183,7 +183,7 @@ TEST_F(LstmUtilsTest, ConvertLSTMCellSimple) { EXPECT_EQ(fused_lstm_func_.getFunctionType().getNumResults(), 1); auto output_types = fused_lstm_func_.getFunctionType().getResults(); - SmallVector output_shape{1, mlir::ShapedType::kDynamicSize}; + SmallVector output_shape{1, mlir::ShapedType::kDynamic}; EXPECT_EQ(output_types[0].cast().getShape().size(), output_shape.size()); for (int i = 0; i < output_shape.size(); i++) { @@ -254,7 +254,7 @@ TEST_F(LstmUtilsTest, ConvertLayerNormLSTMCellSimpleToFusedLSTM) { EXPECT_EQ(fused_ln_lstm_func_.getFunctionType().getNumResults(), 1); auto output_types = fused_ln_lstm_func_.getFunctionType().getResults(); - SmallVector output_shape{1, mlir::ShapedType::kDynamicSize}; + SmallVector output_shape{1, mlir::ShapedType::kDynamic}; EXPECT_EQ(output_types[0].cast().getShape().size(), output_shape.size()); for (int i = 0; i < output_shape.size(); i++) { diff --git a/tensorflow/compiler/mlir/lite/utils/perception_ops_utils.cc b/tensorflow/compiler/mlir/lite/utils/perception_ops_utils.cc index d84835df803..c7944b67406 100644 
--- a/tensorflow/compiler/mlir/lite/utils/perception_ops_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/perception_ops_utils.cc @@ -183,6 +183,12 @@ LogicalResult ConvertMaxUnpoolingFunc::CreateCustomOptions( pool_params.activation = kTfLiteActNone; pool_params.computed.padding = TfLitePaddingValues{0, 0, 0, 0}; +#if FLATBUFFERS_LITTLEENDIAN == 0 + int32_t* p = reinterpret_cast(&pool_params); + for (size_t i = 0; i < sizeof(TfLitePoolParams) / 4; i++, p++) + *p = flatbuffers::EndianSwap(*p); +#endif + custom_option_buffer.assign(reinterpret_cast(&pool_params), sizeof(TfLitePoolParams)); return success(); diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/tiling_interface_impl.h b/tensorflow/compiler/mlir/lite/utils/size_utils.cc similarity index 67% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/tiling_interface_impl.h rename to tensorflow/compiler/mlir/lite/utils/size_utils.cc index f99cbf2327f..a5ffb64eaf4 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/tiling_interface_impl.h +++ b/tensorflow/compiler/mlir/lite/utils/size_utils.cc @@ -13,18 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_TILING_INTERFACE_IMPL_H -#define MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_TILING_INTERFACE_IMPL_H +#include "tensorflow/compiler/mlir/lite/utils/size_utils.h" -namespace mlir { +#include -class DialectRegistry; +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project -namespace gml_st { +namespace mlir { +namespace TFL { -void registerGmlStTilingInterfaceExternalModels(DialectRegistry ®istry); +int32_t ConvertToTfliteSize(int64_t size) { + return mlir::ShapedType::isDynamic(size) ? -1 : static_cast(size); +} -} // namespace gml_st +} // namespace TFL } // namespace mlir - -#endif // MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_TILING_INTERFACE_IMPL_H diff --git a/tensorflow/compiler/mlir/lite/utils/size_utils.h b/tensorflow/compiler/mlir/lite/utils/size_utils.h new file mode 100644 index 00000000000..52aa50c1440 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/utils/size_utils.h @@ -0,0 +1,32 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_SIZE_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_SIZE_UTILS_H_ + +#include + +namespace mlir { +namespace TFL { + +// Converts a TF size (64-bit) to TFLite (32-bit) and properly converts TF's +// value for dynamic size (`std::numeric_limits::min()`) to the +// TFLite-specific value. 
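+// For example (illustrative only, values chosen for this sketch):
+// ConvertToTfliteSize(16) returns 16, while
+// ConvertToTfliteSize(mlir::ShapedType::kDynamic) returns -1.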
+int32_t ConvertToTfliteSize(int64_t size); + +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_SIZE_UTILS_H_ diff --git a/tensorflow/compiler/mlir/lite/utils/size_utils_test.cc b/tensorflow/compiler/mlir/lite/utils/size_utils_test.cc new file mode 100644 index 00000000000..49c3fc70cd0 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/utils/size_utils_test.cc @@ -0,0 +1,33 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/utils/size_utils.h" + +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "tensorflow/core/platform/test.h" + +namespace mlir { +namespace TFL { +namespace { + +TEST(SizeUtilTest, TestConvertsSize) { + ASSERT_EQ(ConvertToTfliteSize(1), 1); + ASSERT_EQ(ConvertToTfliteSize(-1), -1); + ASSERT_EQ(ConvertToTfliteSize(mlir::ShapedType::kDynamic), -1); +} + +} // namespace +} // namespace TFL +} // namespace mlir diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc index 9902b3db18b..5db4ef7cce7 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc @@ -168,26 +168,12 @@ Status MlirFunctionOptimizationPass::Run( } } - static const char* kTfMlirCategory = "TfMlir"; - tensorflow::metrics::ScopedCounter<2> timings( - tensorflow::metrics::GetGraphOptimizationCounter(), - {kTfMlirCategory, "graph_analysis"}); - - timings.ReportAndStop(); - if (overall_state == MlirOptimizationPassState::Disabled) { if (VLOG_IS_ON(1)) { LOG_FIRST_N(INFO, 1) << "None of the MLIR Optimization Passes are enabled " << "(registered " << registry_->passes().size() << ")"; } - // Capture stats on graph properties analyzed before running the MLIR - // bridge. We set `uses_uninitialized_resource_args` to false here because - // function optimization is not affected by uninitialized resource args. - // TODO(b/241853328): Remove LogGraphFeatures when fixed - LogGraphFeatures(**graph, flib_def, config_proto, - /*uses_uninitialized_resource_args=*/false, - /*is_v1_compat=*/false); return OkStatus(); } @@ -214,7 +200,11 @@ Status MlirFunctionOptimizationPass::Run( // during import is not necessary. 
import_config.enable_shape_inference = false; - timings.Reset({kTfMlirCategory, "convert_graph_to_mlir"}); + static const char* kTfMlirCategory = "TfMlir"; + tensorflow::metrics::ScopedCounter<2> timings( + tensorflow::metrics::GetGraphOptimizationCounter(), + {kTfMlirCategory, "convert_graph_to_mlir"}); + auto module_ref_status = ConvertGraphToMlir(**graph, debug_info, *flib_def, import_config, &context); timings.ReportAndStop(); @@ -233,15 +223,8 @@ Status MlirFunctionOptimizationPass::Run( std::move(module_ref_status.value()); AddDevicesToOp(*module_ref, &device_set); - // Capture stats on graph properties analyzed before running the MLIR - // bridge. We set `uses_uninitialized_resource_args` to false here because - // function optimization is not affected by uninitialized resource args. - // TODO (b/241853328) Remove LogGraphFeatures when fixed - LogGraphFeatures(**graph, flib_def, config_proto, - /*uses_uninitialized_resource_args=*/false, - /*is_v1_compat=*/false); - int per_pass_state_index = 0; + bool is_module_updated = false; for (auto& pass_registration : registry_->passes()) { llvm::StringRef name = pass_registration.pass->name(); @@ -253,13 +236,24 @@ Status MlirFunctionOptimizationPass::Run( auto pass_state = per_pass_state[per_pass_state_index++]; if (pass_state == MlirOptimizationPassState::Enabled) { VLOG(2) << "Run MLIR graph optimization pass: " << StringRefToView(name); + VLOG(2) << "Graph #nodes " << (*graph)->num_nodes() << " #edges " + << (*graph)->num_edges(); timings.Reset({kTfMlirCategory, name.str()}); pass_status = pass_registration.pass->Run(config_proto, *module_ref, **graph, *flib_def); timings.ReportAndStop(); + if (pass_status.ok()) { + VLOG(2) << "Finished MLIR graph optimization pass: " + << StringRefToView(name); + VLOG(2) << "Graph #nodes " << (*graph)->num_nodes() << " #edges " + << (*graph)->num_edges(); + is_module_updated = true; + } } else if (pass_state == MlirOptimizationPassState::FallbackEnabled) { VLOG(2) << "Run MLIR graph optimization pass with fallback: " << StringRefToView(name); + VLOG(2) << "Graph #nodes " << (*graph)->num_nodes() << " #edges " + << (*graph)->num_edges(); // Make sure when the pass is FallbackEnabled, it only modifies the MLIR // module in case of no failures. auto module_ref_clone = module_ref->clone(); @@ -268,10 +262,16 @@ Status MlirFunctionOptimizationPass::Run( **graph, *flib_def); timings.ReportAndStop(); - if (pass_status.ok()) + if (pass_status.ok()) { + VLOG(2) << "Finished MLIR graph optimization pass with fallback: " + << StringRefToView(name); + VLOG(2) << "Graph #nodes " << (*graph)->num_nodes() << " #edges " + << (*graph)->num_edges(); module_ref = module_ref_clone; - else + is_module_updated = true; + } else { module_ref_clone->destroy(); + } } else { VLOG(2) << "MLIR graph optimization pass: " << StringRefToView(name) << " is disabled and will not be run."; @@ -301,6 +301,11 @@ Status MlirFunctionOptimizationPass::Run( } } + if (!is_module_updated) { + VLOG(2) << "MLIR module is not updated. Using the original graph. 
" + << "Do not convert mlir module back to graph"; + return OkStatus(); + } GraphExportConfig export_config; absl::flat_hash_set control_ret_nodes; @@ -344,15 +349,6 @@ Status MlirV1CompatGraphOptimizationPass::Run( pass->GetPassState(options.device_set, options.session_options->config, **options.graph, *options.flib_def); - // If we ever have more than one MlirV1CompatOptimization pass we need to - // ensure the logging only happens once per graph to avoid redundant logging - // (see how it is used in the MLIRFunctionOptimizationPass as an example) - // TODO(b/241853328): Remove LogGraphFeatures when fixed - LogGraphFeatures(**options.graph, options.flib_def, - options.session_options->config, - /*uses_uninitialized_resource_args=*/false, - /*is_v1_compat=*/true); - if (pass_state == MlirOptimizationPassState::Disabled) { LOG_FIRST_N(INFO, 1) << "MLIR V1 optimization pass is not enabled"; return OkStatus(); diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc index ec38a06676c..c4fa5158f54 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass_test.cc @@ -16,8 +16,10 @@ limitations under the License. #include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h" #include +#include #include "mlir/IR/Builders.h" // from @llvm-project +#include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -89,6 +91,25 @@ class ModifyMlirModulePass : public MlirOptimizationPass { Status run_status_; }; +FunctionDef XTimesTwo() { + const Tensor kTwo = test::AsScalar(2); + return FunctionDefHelper::Define( + // Name + "XTimesTwo", + // Args + {"x: T"}, + // Return values + {"y: T"}, + // Attr def + {"T: {float, double, int32, int64}"}, + // Nodes + { + {{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}}, + {{"scale"}, "Cast", {"two"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}}, + {{"y"}, "Mul", {"x", "scale"}, {{"T", "$T"}}}, + }); +} + class MlirGraphOptimizationPassTest : public Test { public: void Init(Status pass_run_result, @@ -171,6 +192,14 @@ TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsDisabledFallback) { {MlirOptimizationPassState::Disabled, MlirOptimizationPassState::FallbackEnabled}); + // We expect the result graph to be exactly the same as the original graph + // so we define the `graph_` by the following `flib` in this test point + // instead of the way we do in the Init method. 
+ FunctionDefLibrary flib; + *flib.add_function() = XTimesTwo(); + FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); + graph_ = std::make_unique(flib_def); + GraphDef original_graph_def; graph_->ToGraphDef(&original_graph_def); AddModuleModificationPass(MlirOptimizationPassState::FallbackEnabled, diff --git a/tensorflow/compiler/mlir/python/BUILD b/tensorflow/compiler/mlir/python/BUILD index 016eecee453..136d5522d0d 100644 --- a/tensorflow/compiler/mlir/python/BUILD +++ b/tensorflow/compiler/mlir/python/BUILD @@ -2,6 +2,7 @@ load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:public"], licenses = ["notice"], ) @@ -17,14 +18,17 @@ cc_library( "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:BytecodeWriter", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", "//tensorflow/c:tf_status", "//tensorflow/c:tf_status_helper", "//tensorflow/c/eager:c_api", "//tensorflow/c/eager:tfe_context_internal", + "//tensorflow/compiler/xla/mlir/framework/transforms:passes", "//tensorflow/compiler/xla/mlir_hlo:all_passes", "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/tensorflow", @@ -42,7 +46,6 @@ cc_library( "//tensorflow/compiler/mlir/tosa:tf_tfl_passes", "//tensorflow/compiler/mlir/tosa:tfl_passes", "//tensorflow/compiler/mlir/xla:xla_legalize_tf", - "//tensorflow/compiler/mlir/xla:xla_passes", "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:framework", "//tensorflow/core:lib_proto_parsing", diff --git a/tensorflow/compiler/mlir/python/mlir.cc b/tensorflow/compiler/mlir/python/mlir.cc index c7ba1896e65..a51bef8376f 100644 --- a/tensorflow/compiler/mlir/python/mlir.cc +++ b/tensorflow/compiler/mlir/python/mlir.cc @@ -25,13 +25,17 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" +#include "llvm/Support/ToolOutputFile.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/Bytecode/BytecodeWriter.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/AsmState.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/InitAllPasses.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/FileUtilities.h" // from @llvm-project #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/tfe_context_internal.h" #include "tensorflow/c/tf_status.h" @@ -52,9 +56,9 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tosa/tfl_passes.h" #include "tensorflow/compiler/mlir/tosa/transforms/passes.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" -#include "tensorflow/compiler/mlir/xla/transforms/xla_passes.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/transforms/passes.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir/framework/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/passes.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/function_body.h" @@ -82,10 +86,9 @@ static void RegisterPasses() { mlir::lmhlo::registerAllLmhloPasses(); // These are in compiler/mlir/xla and not part of the above MHLO // passes. - mlir::mhlo::registerXlaPasses(); + mlir::mhlo::registerXlaFrameworkPasses(); mlir::mhlo::registerTfXlaPasses(); mlir::mhlo::registerLegalizeTFPass(); - mlir::mhlo::registerLegalizeTFControlFlowPass(); mlir::mhlo::registerLegalizeTfTypesPassPass(); mlir::tosa::registerLegalizeTosaPasses(); mlir::tosa::registerTFtoTOSALegalizationPipeline(); @@ -355,4 +358,32 @@ std::string ExperimentalRunPassPipeline(const std::string& mlir_txt, return MlirModuleToString(*module, show_debug_info); } +void ExperimentalWriteBytecode(const std::string& filename, + const std::string& mlir_txt, TF_Status* status) { + mlir::DialectRegistry registry; + mlir::RegisterAllTensorFlowDialects(registry); + mlir::MLIRContext context(registry); + mlir::OwningOpRef module; + { + mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context); + module = mlir::parseSourceString(mlir_txt, &context); + if (!module) { + Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus()); + return; + } + } + mlir::FallbackAsmResourceMap fallback_resource_map; + mlir::BytecodeWriterConfig writer_config(fallback_resource_map); + std::string error; + std::unique_ptr outputFile = + mlir::openOutputFile(filename, &error); + if (!error.empty()) { + TF_SetStatus(status, TF_INVALID_ARGUMENT, + ("Unable to create output file" + error).c_str()); + return; + } + outputFile->keep(); + mlir::writeBytecodeToFile(*module, outputFile->os(), writer_config); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/python/mlir.h b/tensorflow/compiler/mlir/python/mlir.h index 6d1dff63ebd..740971d4fb8 100644 --- a/tensorflow/compiler/mlir/python/mlir.h +++ b/tensorflow/compiler/mlir/python/mlir.h @@ -103,6 +103,10 @@ std::string ExperimentalRunPassPipeline(const std::string &mlir_txt, bool show_debug_info, TF_Status *status); +// Writes the input textual MLIR as bytecode to output file. 
+void ExperimentalWriteBytecode(const std::string &filename, + const std::string &mlir_txt, TF_Status *status); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_H_ diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD b/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD index 55fe818a1a9..807d0f497df 100644 --- a/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/BUILD @@ -1,6 +1,9 @@ load("//tensorflow:tensorflow.default.bzl", "tf_python_pybind_extension") -package(licenses = ["notice"]) +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) tf_python_pybind_extension( name = "mlir_wrapper", diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD new file mode 100644 index 00000000000..0ae09a43dd3 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD @@ -0,0 +1,28 @@ +load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library") +load("//tensorflow/compiler/mlir/quantization/stablehlo:internal_visibility_allowlist.bzl", "internal_visibility_allowlist") + +package_group( + name = "internal_visibility_allowlist_package", + packages = [ + "//tensorflow/compiler/mlir/lite/...", + "//tensorflow/compiler/mlir/quantization/...", + "//third_party/cloud_tpu/inference_converter/...", # TPU Inference Converter V1 + ] + internal_visibility_allowlist(), +) + +tf_proto_library( + name = "quantization_options_proto", + srcs = ["quantization_options.proto"], + cc_api_version = 2, + make_default_target_header_only = True, + visibility = [":internal_visibility_allowlist_package"], +) + +# copybara:uncomment_begin(google-only) +# py_proto_library( +# name = "quantization_options_py_pb2", +# api_version = 2, +# visibility = [":internal_visibility_allowlist_package"], +# deps = [":quantization_options_proto"], +# ) +# copybara:uncomment_end diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/internal_visibility_allowlist.bzl b/tensorflow/compiler/mlir/quantization/stablehlo/internal_visibility_allowlist.bzl new file mode 100644 index 00000000000..0e302a08fd5 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/internal_visibility_allowlist.bzl @@ -0,0 +1,10 @@ +"""Internal visibility rules.""" + +def internal_visibility_allowlist(): + """Returns a list of g3 packages that can depend on internal targets.""" + return [ + "//learning/brain/experimental/mlir/quantization/...", + "//learning/brain/mlir/quantization/tensorflow/...", + "//learning/brain/mobile/programmability/...", + "//learning/brain/experimental/tfq/...", + ] diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/quantization_options.proto b/tensorflow/compiler/mlir/quantization/stablehlo/quantization_options.proto new file mode 100644 index 00000000000..22163b54a6d --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/quantization_options.proto @@ -0,0 +1,107 @@ +syntax = "proto3"; + +package stablehlo.quantization; + +option cc_enable_arenas = true; + +// Defines arious options to specify and control the behavior of the +// StableHLO quantizer. +// NEXT ID: 2 +message QuantizationOptions { + QuantizationMethod quantization_method = 1; +} + +// NEXT ID: 3 +message QuantizationMethod { + // Quantization Method can be either preset or custom. 
+ oneof quantization_method { + PresetQuantizationMethod preset_quantization_method = 1; + CustomQuantizationMethod custom_quantization_method = 2; + } +} + +// Preset model quantization method for optimization. +// +// Common quantization methods are defined as preset methods in this message. +// Note that the quantization specs (ex: bit width) for preset quantization +// methods are fixed. To use a different quantization spec for a given method, +// use CustomQuantizationMethod. +// NEXT ID: 2 +message PresetQuantizationMethod { + // Preset quantization methods that are supported as a stable API. + // NEXT ID: 3 + enum PresetMethod { + // TODO(b/266173150): Update preset methods after redefining quantization + // pattern matching in DarwiNN. + // This should never be used. Using this will generally result in an error. + METHOD_UNSPECIFIED = 0; // go/do-include-enum-unspecified + + // Apply default weight-only quantization. Weights are quantized during + // conversion, then dequantized during inference. Data type is as follows: + // Weight: i8, Bias: f32, Activation: f32, Input/output: f32 + WEIGHT_ONLY = 1; + + // Apply default dynamic range quantization. Quantized tensor value's + // ranges are determined during graph runtime. Data type is as follows: + // Weight: i8, Bias: f32, Activation: f32, Input/output: f32 + DYNAMIC_RANGE = 2; + } + PresetMethod preset_method = 1; +} + +// Custom option for specifying quantization spec details. +// If the selected quantization option is not available, StableHLO quantizer +// will raise an error. +// NEXT ID: 2 +message CustomQuantizationMethod { + // Specify component name, bit width, and other specs for all compoenents + // intended to be quantized. + repeated QuantizationComponentSpec quantization_component_spec = 1; +} + +// Quantization spec per each component designated to be quantized. +// Components whose QuantizationComponentSpec is not specified will not be +// quantized, and remain f32. +// NEXT ID: 7 +message QuantizationComponentSpec { + // NEXT ID: 4 + enum QuantizationComponent { + COMPONENT_UNSPECIFIED = 0; + COMPONENT_ACTIVATION = 1; + COMPONENT_WEIGHT = 2; + COMPONENT_BIAS = 3; + } + + // NEXT ID: 4 + enum BitWidth { + BIT_WIDTH_UNSPECIFIED = 0; + BIT_WIDTH_4 = 1; + BIT_WIDTH_8 = 2; + BIT_WIDTH_16 = 3; + } + + // NEXT ID: 4 + enum BitType { + BIT_TYPE_UNSPECIFIED = 0; + BIT_TYPE_INT = 1; + BIT_TYPE_FLOAT = 2; + BIT_TYPE_BFLOAT = 3; + } + + QuantizationComponent quantization_component = 1; + + // Defines the target bit of the data. + BitWidth bit_width = 2; + + // Defines the type of data of the quantized component. + BitType bit_type = 3; + + // Defines whether quantization is done in narrow range. + bool enable_narrow_range = 4; + + // Defines whether quantiation is done per-channel. + bool enable_per_channel_quantization = 5; + + // Defines whether quantization is done symmetrically. 
+ bool enable_symmetric = 6; +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD index 47269d4b222..04679693985 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD @@ -15,8 +15,10 @@ package_group( ) package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ ":internal_visibility_allowlist_package", + "//tensorflow:__pkg__", ], licenses = ["notice"], ) @@ -85,6 +87,7 @@ td_library( "passes/post_quantize.td", "passes/prepare_lifting.td", "passes/prepare_quantize.td", + "passes/preprocess_op.td", "passes/quantize_composite_functions.td", "passes/replace_cast_hacks_with_tf_xla_ops.td", "passes/tf_quant_ops.td", @@ -218,6 +221,20 @@ gentbl_cc_library( deps = [":quant_td_files"], ) +gentbl_cc_library( + name = "preprocess_op_gen", + compatible_with = get_compatible_with_cloud(), + tbl_outs = [ + ( + ["-gen-rewriters"], + "passes/preprocess_op.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "passes/preprocess_op.td", + deps = [":quant_td_files"], +) + cc_library( name = "tf_quant_ops", srcs = [ @@ -264,6 +281,23 @@ cc_library( ], ) +cc_library( + name = "uniform_op_quant_spec", + srcs = [ + "ops/uniform_op_quant_spec.cc", + ], + hdrs = ["ops/uniform_op_quant_spec.h"], + compatible_with = get_compatible_with_cloud(), + deps = [ + ":tf_quant_ops", + "//tensorflow/compiler/mlir/lite/quantization:quantization_config", + "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", + "//tensorflow/compiler/mlir/tensorflow", + "@com_google_absl//absl/container:flat_hash_set", + "@llvm-project//mlir:IR", + ], +) + gentbl_cc_library( name = "replace_cast_hacks_with_tf_xla_ops_inc_gen", compatible_with = get_compatible_with_cloud(), @@ -284,14 +318,17 @@ cc_library( "passes/convert_custom_aggregation_op_to_quant_stats.cc", "passes/convert_fake_quant_to_qdq.cc", "passes/convert_tf_quant_ops_to_mhlo.cc", + "passes/duplicate_shape_determining_constants.cc", "passes/insert_custom_aggregation_ops.cc", "passes/insert_main_function.cc", "passes/insert_quantized_functions.cc", + "passes/insert_restore_op.cc", "passes/issue_ids_of_custom_aggregation_ops.cc", "passes/lift_quantizable_spots_as_functions.cc", "passes/lift_quantizable_spots_as_functions.inc", "passes/lift_quantizable_spots_as_functions_drq.cc", "passes/lift_quantizable_spots_as_functions_drq.inc", + "passes/mark_functions_noinline.cc", "passes/merge_initializer_function_ops_to_main.cc", "passes/optimize.cc", "passes/optimize.inc", @@ -302,10 +339,13 @@ cc_library( "passes/prepare_quantize.cc", "passes/prepare_quantize.inc", "passes/prepare_quantize_drq.cc", + "passes/preprocess_op.cc", + "passes/preprocess_op.inc", "passes/quantize.cc", "passes/quantize_composite_functions.cc", "passes/quantize_composite_functions.inc", "passes/quantized_function_library.h", + "passes/remove_var_init_by_const.cc", "passes/replace_cast_hacks_with_tf_xla_ops.cc", "passes/replace_cast_hacks_with_tf_xla_ops.inc", "passes/unfreeze_constants.cc", @@ -325,8 +365,10 @@ cc_library( "//tensorflow/compiler/mlir/lite/quantization:quantization_config", "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/tensorflow/cc:const_op_size", "//tensorflow/compiler/mlir/quantization/tensorflow/utils:fake_quant_utils", 
"//tensorflow/compiler/mlir/quantization/tensorflow/utils:lift_as_function_call_utils", + "//tensorflow/compiler/mlir/quantization/tensorflow/utils:tf_to_uniform_attribute_utils", "//tensorflow/compiler/mlir/quantization/tensorflow/utils:tf_to_xla_attribute_utils", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:error_util", @@ -353,6 +395,7 @@ cc_library( "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/random", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@llvm-project//llvm:Support", @@ -381,28 +424,26 @@ cc_library( ], compatible_with = get_compatible_with_cloud(), deps = [ - ":quantization_options_proto_cc", - "//tensorflow/cc/saved_model:loader", + ":passes", + "//tensorflow/compiler/mlir/quantization/tensorflow/debugging:mlir_dump", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:error_util", - "//tensorflow/compiler/mlir/tensorflow:export_graphdef", - "//tensorflow/compiler/mlir/tensorflow:mlir_import_options", - "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes", "//tensorflow/compiler/mlir/tensorflow:tf_saved_model_freeze_variables", "//tensorflow/compiler/mlir/tensorflow:tf_saved_model_passes", "//tensorflow/compiler/mlir/tensorflow:translate_lib", + "//tensorflow/core:core_cpu_base", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:path", - "//tensorflow/core/platform:statusor", - "@llvm-project//mlir:ArithDialect", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", - "@llvm-project//mlir:QuantOps", - "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:Transforms", ], ) @@ -417,9 +458,10 @@ cc_library( ], compatible_with = get_compatible_with_cloud(), deps = [ + ":pass_utils", + ":passes", ":quantization_options_proto_cc", "//tensorflow/cc/saved_model:loader", - "//tensorflow/compiler/mlir/quantization/tensorflow:passes", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/mlir/tensorflow:export_graphdef", @@ -468,6 +510,33 @@ tf_proto_library( # ) # copybara:uncomment_end +# OSS only: This target is header-only. Link `exported_model_proto_impl` only to +# `libtensorflow_framework.so` via `lib_internal_impl`. Do NOT link +# `exported_model_proto_impl` directly unless the target does not link +# `libtensorflow_framework.so`. +tf_proto_library( + name = "exported_model_proto", + srcs = ["exported_model.proto"], + cc_api_version = 2, + make_default_target_header_only = True, + protodeps = [ + "//tensorflow/core:protos_all", + ], + visibility = [ + ":internal_visibility_allowlist_package", + # To be visible from `lib_internal_impl`. 
+ "//tensorflow/core:__pkg__", + ], +) + +# copybara:uncomment_begin(google-only) +# py_proto_library( +# name = "exported_model_py_pb2", +# api_version = 2, +# deps = [":exported_model_proto"], +# ) +# copybara:uncomment_end + tf_cc_binary( name = "tf-quant-opt", srcs = ["passes/tf_quant_opt.cc"], diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD index 73cdea990ec..831bf9980ca 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD @@ -1,4 +1,13 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud", "tf_kernel_library", "tf_py_test") +load( + "//tensorflow/core/platform:build_config_root.bzl", + "if_static", +) +load( + "//tensorflow:tensorflow.default.bzl", + "get_compatible_with_cloud", + "tf_kernel_library", + "tf_py_test", +) load( "//tensorflow:tensorflow.bzl", "tf_cc_test", @@ -6,17 +15,21 @@ load( ) package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ "//tensorflow/compiler/mlir/quantization/tensorflow:internal_visibility_allowlist_package", ], licenses = ["notice"], ) +# Directly linked to `custom_aggregator_op`. In general, one should avoid directly depending on +# this target to avoid the ODR violation. Depend on `calibrator_singleton` instead. cc_library( - name = "calibrator_singleton", + name = "calibrator_singleton_impl", srcs = ["calibrator_singleton.cc"], hdrs = ["calibrator_singleton.h"], compatible_with = get_compatible_with_cloud(), + visibility = ["//visibility:private"], deps = [ "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", @@ -25,12 +38,24 @@ cc_library( ], ) +cc_library( + name = "calibrator_singleton", + hdrs = ["calibrator_singleton.h"], + compatible_with = get_compatible_with_cloud(), + deps = if_static([":calibrator_singleton_impl"]) + [ + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:optional", + ], +) + tf_cc_test( name = "calibrator_singleton_test", size = "small", srcs = ["calibrator_singleton_test.cc"], deps = [ - ":calibrator_singleton", + ":calibrator_singleton_impl", "//tensorflow/core:test", "//tensorflow/core:test_main", ], @@ -40,9 +65,12 @@ tf_kernel_library( name = "custom_aggregator_op", srcs = ["custom_aggregator_op.cc"], compatible_with = get_compatible_with_cloud(), - visibility = ["//tensorflow/compiler/mlir/quantization/tensorflow/python:__pkg__"], + visibility = [ + "//tensorflow:__pkg__", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:__pkg__", + ], deps = [ - ":calibrator_singleton", + ":calibrator_singleton_impl", "//tensorflow/core:framework", ], ) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/integration_test/custom_aggregator_op_test.py b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/integration_test/custom_aggregator_op_test.py index 135ff141d9d..846095b30ac 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/integration_test/custom_aggregator_op_test.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/integration_test/custom_aggregator_op_test.py @@ -40,7 +40,7 @@ def testBypassAndMinMax(self): dtypes.float32) aggregator = custom_aggregator_op_wrapper.custom_aggregator( input_tensor, '1') - 
self.assertAllEqual(aggregator.eval(), [1.0, 2.0, 3.0, 4.0, 5.0]) + self.assertAllEqual(self.evaluate(aggregator), [1.0, 2.0, 3.0, 4.0, 5.0]) min_val = quantize_model_wrapper.get_min_from_calibrator('1') max_val = quantize_model_wrapper.get_max_from_calibrator('1') self.assertAllEqual((min_val, max_val), (1.0, 5.0)) @@ -52,12 +52,13 @@ def testTwoIdentities(self): dtypes.float32) aggregator1 = custom_aggregator_op_wrapper.custom_aggregator( input_tensor1, '2') - self.assertAllEqual(aggregator1.eval(), [1.0, 2.0, 3.0, 4.0, 5.0]) + self.assertAllEqual(self.evaluate(aggregator1), [1.0, 2.0, 3.0, 4.0, 5.0]) input_tensor2 = array_ops.constant([-1.0, -2.0, -3.0, -4.0, -5.0], dtypes.float32) aggregator2 = custom_aggregator_op_wrapper.custom_aggregator( input_tensor2, '3') - self.assertAllEqual(aggregator2.eval(), [-1.0, -2.0, -3.0, -4.0, -5.0]) + self.assertAllEqual( + self.evaluate(aggregator2), [-1.0, -2.0, -3.0, -4.0, -5.0]) min_val = quantize_model_wrapper.get_min_from_calibrator('2') max_val = quantize_model_wrapper.get_max_from_calibrator('2') @@ -73,12 +74,13 @@ def testClearData(self): dtypes.float32) aggregator1 = custom_aggregator_op_wrapper.custom_aggregator( input_tensor1, '4') - self.assertAllEqual(aggregator1.eval(), [1.0, 2.0, 3.0, 4.0, 5.0]) + self.assertAllEqual(self.evaluate(aggregator1), [1.0, 2.0, 3.0, 4.0, 5.0]) input_tensor2 = array_ops.constant([-1.0, -2.0, -3.0, -4.0, -5.0], dtypes.float32) aggregator2 = custom_aggregator_op_wrapper.custom_aggregator( input_tensor2, '5') - self.assertAllEqual(aggregator2.eval(), [-1.0, -2.0, -3.0, -4.0, -5.0]) + self.assertAllEqual( + self.evaluate(aggregator2), [-1.0, -2.0, -3.0, -4.0, -5.0]) min_val = quantize_model_wrapper.get_min_from_calibrator('4') max_val = quantize_model_wrapper.get_max_from_calibrator('4') diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD new file mode 100644 index 00000000000..5976d1f91ce --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD @@ -0,0 +1,93 @@ +load( + "//tensorflow:tensorflow.default.bzl", + "get_compatible_with_cloud", +) +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + # By default, these targets should only be used within the quantization library. 
+ default_visibility = [ + "//tensorflow/compiler/mlir/quantization/tensorflow:__subpackages__", + ], + licenses = ["notice"], +) + +cc_library( + name = "save_variables", + srcs = ["save_variables.cc"], + hdrs = ["save_variables.h"], + compatible_with = get_compatible_with_cloud(), + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/core:framework", + "//tensorflow/core/ir/importexport:convert_tensor", + "//tensorflow/core/util/tensor_bundle", + "//tensorflow/tsl/platform:env", + "//tensorflow/tsl/platform:logging", + "//tensorflow/tsl/platform:status", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + ], +) + +tf_cc_test( + name = "save_variables_test", + srcs = ["save_variables_test.cc"], + deps = [ + ":save_variables", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/framework:tensor_testutil", + "//tensorflow/core/util/tensor_bundle", + "//tensorflow/tsl/platform:status", + "//tensorflow/tsl/platform:status_matchers", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + ], +) + +cc_library( + name = "const_op_size", + srcs = ["const_op_size.cc"], + hdrs = ["const_op_size.h"], + compatible_with = get_compatible_with_cloud(), + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_remaining_ops", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "@llvm-project//mlir:IR", + ], +) + +tf_cc_test( + name = "const_op_size_test", + srcs = ["const_op_size_test.cc"], + deps = [ + ":const_op_size", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Support", + ], +) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/const_op_size.cc b/tensorflow/compiler/mlir/quantization/tensorflow/cc/const_op_size.cc new file mode 100644 index 00000000000..2c1b85ba194 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/const_op_size.cc @@ -0,0 +1,79 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/const_op_size.h" + +#include + +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" + +namespace mlir { +namespace quant { +namespace { + +// For types that have varying sizes or difficult to determine the size of, each +// element is arbitrarily considered to be 4 bytes. +constexpr int64_t kAssumedNumBytesPerElem = 4; + +int64_t GetSizeOfIntOrFloatConst(TF::ConstOp const_op) { + const Type dtype = const_op.getDtype(); + const ElementsAttr const_value = const_op.getValue(); + + const auto bytes_per_elem = + static_cast(dtype.getIntOrFloatBitWidth() / CHAR_BIT); + + return bytes_per_elem * const_value.getNumElements(); +} + +int64_t GetSizeOfStringConst(TF::ConstOp const_op) { + const ElementsAttr const_value = const_op.getValue(); + + // This cast is guaranteed to succeed. See `ConvertToTensorProto` from + // tensorflow/core/ir/importexport/convert_tensor.cc. + const auto str_attr = cast(const_value); + + // Sum the sizes of each string. + return absl::c_accumulate( + str_attr.getRawStringData(), 0, + [](int64_t acc, const StringRef str_value) -> int64_t { + return acc + str_value.size(); + }); +} + +// Arbitrarily calculate the size of const of type whose size is unkown or +// varying. Each element of such a type is considered to have +// `kAssumedNumBytesPerElem` bytes. +int64_t GetSizeOfUnsupportedTypeConst(TF::ConstOp const_op) { + return kAssumedNumBytesPerElem * const_op.getValue().getNumElements(); +} + +} // namespace + +int64_t GetSizeInBytes(TF::ConstOp const_op) { + const Type dtype = const_op.getDtype(); + + if (dtype.isIntOrFloat()) { + return GetSizeOfIntOrFloatConst(const_op); + } else if (isa(dtype)) { + return GetSizeOfStringConst(const_op); + } else { + return GetSizeOfUnsupportedTypeConst(const_op); + } +} + +} // namespace quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/const_op_size.h b/tensorflow/compiler/mlir/quantization/tensorflow/cc/const_op_size.h new file mode 100644 index 00000000000..884ac938f3c --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/const_op_size.h @@ -0,0 +1,32 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CC_CONST_OP_SIZE_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CC_CONST_OP_SIZE_H_ + +#include + +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace quant { + +// Returns the size in bytes of the underlying data of `const_op`. If the +// underlying type's size cannot be determined, it assumes 4 bytes per element. 
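+// For example (illustrative values), a constant of type tensor<2x3xf32> is
+// reported as 2 * 3 * 4 = 24 bytes, whereas a tensor<5x!tf_type.variant>
+// constant falls back to the 4-bytes-per-element estimate, i.e. 20 bytes.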
+int64_t GetSizeInBytes(TF::ConstOp const_op); + +} // namespace quant +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CC_CONST_OP_SIZE_H_ diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/const_op_size_test.cc b/tensorflow/compiler/mlir/quantization/tensorflow/cc/const_op_size_test.cc new file mode 100644 index 00000000000..5206aceec7b --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/const_op_size_test.cc @@ -0,0 +1,151 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/const_op_size.h" + +#include "absl/strings/string_view.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/AsmState.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Parser/Parser.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/core/platform/test.h" + +namespace mlir { +namespace quant { +namespace { + +using ::testing::Eq; + +class GetSizeInBytesTest : public ::testing::Test { + protected: + GetSizeInBytesTest() : ctx_() { ctx_.loadDialect(); } + + MLIRContext ctx_; +}; + +TF::ConstOp ParseConstOp(const absl::string_view const_op_str, Block& block, + MLIRContext& ctx) { + const LogicalResult parse_result = + parseSourceString(const_op_str, &block, ParserConfig(&ctx)); + EXPECT_TRUE(succeeded(parse_result)); + + auto const_op = dyn_cast_or_null(block.front()); + EXPECT_TRUE(const_op); + + return const_op; +} + +TEST_F(GetSizeInBytesTest, Int32ScalarConstOpSizeInBytes) { + constexpr absl::string_view kConstOpExpr = + R"mlir(%cst = "tf.Const"() {value = dense<1> : tensor} : () -> tensor)mlir"; + + Block block{}; + TF::ConstOp int_tensor_const_op = ParseConstOp(kConstOpExpr, block, ctx_); + + const int64_t num_bytes = GetSizeInBytes(int_tensor_const_op); + EXPECT_THAT(num_bytes, Eq(4)); +} + +TEST_F(GetSizeInBytesTest, Int32ConstOpSizeInBytes) { + constexpr absl::string_view kConstOpExpr = + R"mlir(%cst = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> tensor<2xi32>)mlir"; + + Block block{}; + TF::ConstOp int_tensor_const_op = ParseConstOp(kConstOpExpr, block, ctx_); + + const int64_t num_bytes = GetSizeInBytes(int_tensor_const_op); + EXPECT_THAT(num_bytes, Eq(8)); +} + +TEST_F(GetSizeInBytesTest, Int8ConstOpSizeInBytes) { + constexpr absl::string_view kConstOpExpr = + R"mlir(%cst = "tf.Const"() {value = dense<2> : tensor<3xi8>} : () -> tensor<3xi8>)mlir"; + + Block block{}; + TF::ConstOp int_tensor_const_op = ParseConstOp(kConstOpExpr, block, ctx_); + + const int64_t num_bytes 
= GetSizeInBytes(int_tensor_const_op); + EXPECT_THAT(num_bytes, Eq(3)); +} + +TEST_F(GetSizeInBytesTest, Float32ConstOpSizeInBytes) { + constexpr absl::string_view kConstOpExpr = + R"mlir(%cst = "tf.Const"() {value = dense<3.0> : tensor<4xf32>} : () -> tensor<4xf32>)mlir"; + + Block block{}; + TF::ConstOp int_tensor_const_op = ParseConstOp(kConstOpExpr, block, ctx_); + + const int64_t num_bytes = GetSizeInBytes(int_tensor_const_op); + EXPECT_THAT(num_bytes, Eq(16)); +} + +TEST_F(GetSizeInBytesTest, Float64ConstOpSizeInBytes) { + constexpr absl::string_view kConstOpExpr = + R"mlir(%cst = "tf.Const"() {value = dense<3.0> : tensor<2xf64>} : () -> tensor<2xf64>)mlir"; + + Block block{}; + TF::ConstOp int_tensor_const_op = ParseConstOp(kConstOpExpr, block, ctx_); + + const int64_t num_bytes = GetSizeInBytes(int_tensor_const_op); + EXPECT_THAT(num_bytes, Eq(16)); +} + +TEST_F(GetSizeInBytesTest, Bfloat16ConstOpSizeInBytes) { + constexpr absl::string_view kConstOpExpr = R"mlir( + %cst = "tf.Const"() {value = dense<1.0> : tensor<7xbf16>} : () -> tensor<7xbf16> + )mlir"; + + Block block{}; + TF::ConstOp int_tensor_const_op = ParseConstOp(kConstOpExpr, block, ctx_); + + const int64_t num_bytes = GetSizeInBytes(int_tensor_const_op); + EXPECT_THAT(num_bytes, Eq(14)); +} + +TEST_F(GetSizeInBytesTest, TfStringConstOpSizeInBytes) { + constexpr absl::string_view kConstOpExpr = R"mlir( + %cst = "tf.Const"() {value = dense<["Hello World", "Quantization"]> : tensor<2x!tf_type.string>} : () -> tensor<2x!tf_type.string> + )mlir"; + + Block block{}; + TF::ConstOp int_tensor_const_op = ParseConstOp(kConstOpExpr, block, ctx_); + + // Sum of the number of characters in "Hello World" and "Quantization". + const int64_t num_bytes = GetSizeInBytes(int_tensor_const_op); + EXPECT_THAT(num_bytes, Eq(23)); +} + +TEST_F(GetSizeInBytesTest, ConstOpWithUnknownSizeAssumes4BytesPerElement) { + constexpr absl::string_view kConstOpExpr = R"mlir( + %cst = "tf.Const"() {value = #tf_type : tensor} : () -> tensor + )mlir"; + + Block block{}; + TF::ConstOp int_tensor_const_op = ParseConstOp(kConstOpExpr, block, ctx_); + + // For non-fixed size like tf_type.variant, the size of each element is + // assumed to be 4 bytes. + const int64_t num_bytes = GetSizeInBytes(int_tensor_const_op); + EXPECT_THAT(num_bytes, Eq(4)); +} + +} // namespace +} // namespace quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables.cc b/tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables.cc new file mode 100644 index 00000000000..529879df617 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables.cc @@ -0,0 +1,135 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/ir/importexport/convert_tensor.h" +#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h" +#include "tensorflow/tsl/platform/env.h" +#include "tensorflow/tsl/platform/logging.h" +#include "tensorflow/tsl/platform/status.h" + +namespace tensorflow { +namespace quantization { +namespace { + +using ::mlir::func::FuncOp; +using ::mlir::tf_saved_model::GetInitializerFunction; +using ::mlir::tf_saved_model::GetSessionInitializerOp; +using ::mlir::tf_saved_model::kTfSavedModelInitializerRestoreType; +using ::mlir::tf_saved_model::kTfSavedModelInitializerTypeAttr; +using ::mlir::tf_saved_model::SessionInitializerOp; + +// Adds the tensor that initializes the variable through the provided +// `assign_var_op` to the `bundle_writer` for saving to checkpoint. Returns the +// shared name of the variable if a variable is saved successfully. If the +// variable is not saved, returns an empty string. +absl::StatusOr AddTensorToBundleWriter( + mlir::TF::AssignVariableOp assign_var_op, BundleWriter& bundle_writer) { + auto resource_operand = assign_var_op.getOperand(0); + auto var_handle_op = + llvm::dyn_cast(resource_operand.getDefiningOp()); + if (!var_handle_op) { + assign_var_op->emitRemark( + "Operand idx 0 is not a tf.VarHandleOp. The initializing tensor is not " + "saved to checkpoint."); + return ""; + } + + auto assigned_value_operand = assign_var_op.getOperand(1); + auto const_op = + llvm::dyn_cast(assigned_value_operand.getDefiningOp()); + if (!const_op) { + assign_var_op->emitRemark( + "Operand idx 1 is not a tf.ConstOp. The initializing tensor is not " + "saved to checkpoint."); + return ""; + } + + Tensor const_tensor{}; + if (const tsl::Status status = mlir::tfg::ConvertToTensor( + /*attr=*/const_op.getValue(), /*output_tensor=*/&const_tensor); + !status.ok()) { + return tsl::ToAbslStatus(status); + } + + if (!bundle_writer.Add(/*key=*/var_handle_op.getSharedName(), const_tensor) + .ok()) { + return tsl::ToAbslStatus(bundle_writer.status()); + } + + return var_handle_op.getSharedName().str(); +} + +} // namespace + +absl::StatusOr> SaveVariablesToCheckpoint( + const absl::string_view prefix, mlir::ModuleOp module_op) { + // Only the "tf.AssignVariableOp" patterns inside this initializer function + // will be searched. + FuncOp session_init_func_type_restore_op = GetInitializerFunction( + module_op, /*initializer_type=*/kTfSavedModelInitializerRestoreType); + if (!session_init_func_type_restore_op) { + LOG(INFO) << "No session initializer function with type 'restore_op'. 
No " + "variables are saved to checkpoint."; + return std::vector{}; + } + + BundleWriter bundle_writer(Env::Default(), prefix); + if (!bundle_writer.status().ok()) { + return tsl::ToAbslStatus(bundle_writer.status()); + } + + std::vector saved_variable_shared_names; + for (auto assign_variable_op : + session_init_func_type_restore_op.getOps()) { + if (const absl::StatusOr variable_shared_name = + AddTensorToBundleWriter(assign_variable_op, bundle_writer); + !variable_shared_name.ok()) { + return variable_shared_name.status(); + } else if (!variable_shared_name->empty()) { + // Empty string means the variable isn't applicable for saving. + saved_variable_shared_names.emplace_back( + std::move(*variable_shared_name)); + VLOG(1) << "Saved a variable with shared_name: " << *variable_shared_name; + } + } + + // Exit early if no variables are added. + if (saved_variable_shared_names.empty()) { + LOG(INFO) << "No variables are saved to checkpoint"; + return saved_variable_shared_names; + } + + if (!bundle_writer.Finish().ok()) { + return tsl::ToAbslStatus(bundle_writer.status()); + } + + return saved_variable_shared_names; +} + +} // namespace quantization +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables.h b/tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables.h new file mode 100644 index 00000000000..124f2a5bc9d --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables.h @@ -0,0 +1,40 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CC_SAVE_VARIABLES_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CC_SAVE_VARIABLES_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project + +namespace tensorflow { +namespace quantization { + +// Saves variables in `module_op` to the checkpoint file inside `prefix`. +// It finds variables that are initialized with "tf.AssignVariableOp" inside the +// initializer function with type "restore_op". The "tf.Const"s used to +// initialize the variables are saved. This function does not modify the +// `module_op`. Returns a list of saved names of the saved variables. 
+absl::StatusOr> SaveVariablesToCheckpoint( + absl::string_view prefix, mlir::ModuleOp module_op); + +} // namespace quantization +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CC_SAVE_VARIABLES_H_ diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables_test.cc b/tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables_test.cc new file mode 100644 index 00000000000..8967b64b877 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables_test.cc @@ -0,0 +1,385 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables.h" + +#include +#include + +#include "absl/cleanup/cleanup.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/Parser/Parser.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h" +#include "tensorflow/tsl/platform/status.h" +#include "tensorflow/tsl/platform/status_matchers.h" + +namespace tensorflow { +namespace quantization { +namespace { + +using ::tensorflow::test::AsTensor; +using ::tensorflow::test::ExpectEqual; +using ::testing::IsEmpty; +using ::testing::Not; +using ::testing::UnorderedElementsAre; +using ::tsl::testing::IsOk; + +// This fixture simply wraps the Env and MLIRContext. +class SaveVariablesToCheckpointTest : public ::testing::Test { + protected: + SaveVariablesToCheckpointTest() : env_(Env::Default()) { + ctx_.loadDialect(); + } + + absl::StatusOr MakeTempDir() { + std::string tmp_dir{}; + if (!env_->LocalTempFilename(&tmp_dir)) { + return absl::InternalError("Failed to create temp file."); + } + + TF_CHECK_OK(env_->CreateDir(tmp_dir)); + return tmp_dir; + } + + // Parses `module_op_str` to create a `ModuleOp`. Checks whether the created + // module op is valid. 
+ mlir::OwningOpRef ParseModuleOpString( + const absl::string_view module_op_str) { + auto module_op_ref = + mlir::parseSourceString(module_op_str, &ctx_); + EXPECT_TRUE(module_op_ref); + return module_op_ref; + } + + Env* env_{}; + mlir::MLIRContext ctx_{}; +}; + +TEST_F(SaveVariablesToCheckpointTest, VariableSavedToCheckpoint) { + constexpr absl::string_view kModuleCode = R"mlir( + module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_func_restore_op]} : () -> () + + func.func @init_func_restore_op() -> () attributes {tf_saved_model.exported_names = ["restore"], tf_saved_model.initializer_type = "restore_op"} { + %cst = "tf.Const"() {device = "", value = dense<[1.0, 2.0]> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "tf.VarHandleOp"() {container = "", device = "/device:CPU:0", shared_name = "var_0"} : () -> tensor>> + "tf.AssignVariableOp"(%0, %cst) : (tensor>>, tensor<2xf32>) -> () + return + } + } + )mlir"; + + mlir::OwningOpRef module_op_ref = + ParseModuleOpString(kModuleCode); + + const absl::StatusOr checkpoint_prefix = MakeTempDir(); + EXPECT_TRUE(checkpoint_prefix.ok()); + + const absl::Cleanup checkpoint_prefix_cleanup = [this, &checkpoint_prefix]() { + int64_t undeleted_files, undeleted_dirs; + TF_CHECK_OK(env_->DeleteRecursively(*checkpoint_prefix, &undeleted_files, + &undeleted_dirs)); + }; + + const absl::StatusOr> variable_shared_names = + SaveVariablesToCheckpoint(*checkpoint_prefix, *module_op_ref); + EXPECT_TRUE(variable_shared_names.ok()); + EXPECT_THAT(*variable_shared_names, UnorderedElementsAre("var_0")); + + // Verify the saved variable. + BundleReader bundle_reader(env_, *checkpoint_prefix); + + Tensor loaded_tensor{}; + EXPECT_TRUE( + tsl::ToAbslStatus(bundle_reader.Lookup("var_0", &loaded_tensor)).ok()); + + ExpectEqual(loaded_tensor, AsTensor({1.0, 2.0})); +} + +TEST_F(SaveVariablesToCheckpointTest, MultipleVariablesSavedToCheckpoint) { + // Module's session intializer contains two variables. 
+ constexpr absl::string_view kModuleCode = R"mlir( + module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_func_restore_op]} : () -> () + + func.func @init_func_restore_op() -> () attributes {tf_saved_model.exported_names = ["restore"], tf_saved_model.initializer_type = "restore_op"} { + %cst = "tf.Const"() {device = "", value = dense<[1.0, 2.0]> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "tf.VarHandleOp"() {container = "", device = "/device:CPU:0", shared_name = "var_0"} : () -> tensor>> + "tf.AssignVariableOp"(%0, %cst) : (tensor>>, tensor<2xf32>) -> () + + %cst_0 = "tf.Const"() {device = "", value = dense<[3, 4, 5, 6]> : tensor<4xi32>} : () -> tensor<4xi32> + %1 = "tf.VarHandleOp"() {container = "", device = "/device:CPU:0", shared_name = "var_1"} : () -> tensor>> + "tf.AssignVariableOp"(%1, %cst_0) : (tensor>>, tensor<4xi32>) -> () + + return + } + } + )mlir"; + + mlir::OwningOpRef module_op_ref = + ParseModuleOpString(kModuleCode); + + const absl::StatusOr checkpoint_prefix = MakeTempDir(); + EXPECT_TRUE(checkpoint_prefix.ok()); + + const absl::Cleanup checkpoint_prefix_cleanup = [this, &checkpoint_prefix]() { + int64_t undeleted_files, undeleted_dirs; + TF_CHECK_OK(env_->DeleteRecursively(*checkpoint_prefix, &undeleted_files, + &undeleted_dirs)); + }; + + const absl::StatusOr> variable_shared_names = + SaveVariablesToCheckpoint(*checkpoint_prefix, *module_op_ref); + EXPECT_TRUE(variable_shared_names.ok()); + EXPECT_THAT(*variable_shared_names, UnorderedElementsAre("var_0", "var_1")); + + // Verify that both variables are saved correctly. + BundleReader bundle_reader(env_, *checkpoint_prefix); + + Tensor loaded_var_0{}; + EXPECT_TRUE( + tsl::ToAbslStatus(bundle_reader.Lookup("var_0", &loaded_var_0)).ok()); + ExpectEqual(loaded_var_0, AsTensor({1.0, 2.0})); + + Tensor loaded_var_1{}; + EXPECT_TRUE( + tsl::ToAbslStatus(bundle_reader.Lookup("var_1", &loaded_var_1)).ok()); + ExpectEqual(loaded_var_1, AsTensor({3, 4, 5, 6})); +} + +TEST_F(SaveVariablesToCheckpointTest, + NoVariablesSavedWhenNoInitializerFunction) { + constexpr absl::string_view kModuleCode = R"mlir( + module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = []} : () -> () + } + )mlir"; + + mlir::OwningOpRef module_op_ref = + ParseModuleOpString(kModuleCode); + + const absl::StatusOr checkpoint_prefix = MakeTempDir(); + EXPECT_TRUE(checkpoint_prefix.ok()); + + const absl::Cleanup checkpoint_prefix_cleanup = [this, &checkpoint_prefix]() { + int64_t undeleted_files, undeleted_dirs; + TF_CHECK_OK(env_->DeleteRecursively(*checkpoint_prefix, &undeleted_files, + &undeleted_dirs)); + }; + + const absl::StatusOr> variable_shared_names = + SaveVariablesToCheckpoint(*checkpoint_prefix, *module_op_ref); + EXPECT_TRUE(variable_shared_names.ok()); + EXPECT_THAT(*variable_shared_names, IsEmpty()); + + // Verify that the checkpoint doesn't exist. 
+ BundleReader bundle_reader(env_, *checkpoint_prefix); + EXPECT_THAT(bundle_reader.status(), Not(IsOk())); +} + +TEST_F(SaveVariablesToCheckpointTest, + NoVariablesSavedWhenNoSessionInitializerOp) { + constexpr absl::string_view kModuleCode = R"mlir( + module { + func.func @my_func() -> () { + return + } + } + )mlir"; + + mlir::OwningOpRef module_op_ref = + ParseModuleOpString(kModuleCode); + + const absl::StatusOr checkpoint_prefix = MakeTempDir(); + EXPECT_TRUE(checkpoint_prefix.ok()); + + const absl::Cleanup checkpoint_prefix_cleanup = [this, &checkpoint_prefix]() { + int64_t undeleted_files, undeleted_dirs; + TF_CHECK_OK(env_->DeleteRecursively(*checkpoint_prefix, &undeleted_files, + &undeleted_dirs)); + }; + + EXPECT_TRUE( + SaveVariablesToCheckpoint(*checkpoint_prefix, *module_op_ref).ok()); + + // Verify that the checkpoint doesn't exist. + BundleReader bundle_reader(env_, *checkpoint_prefix); + EXPECT_THAT(bundle_reader.status(), Not(IsOk())); +} + +TEST_F(SaveVariablesToCheckpointTest, + NoVariablesSavedWhenNoSessionInitializerOpTypeRestoreOp) { + constexpr absl::string_view kModuleCode = R"mlir( + module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_func_init_op]} : () -> () + + func.func @init_func_init_op() -> () attributes {tf_saved_model.exported_names = ["init"], tf_saved_model.initializer_type = "init_op"} { + %cst = "tf.Const"() {device = "", value = dense<[1.0, 2.0]> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "tf.VarHandleOp"() {container = "", device = "/device:CPU:0", shared_name = "var_0"} : () -> tensor>> + "tf.AssignVariableOp"(%0, %cst) : (tensor>>, tensor<2xf32>) -> () + return + } + } + )mlir"; + + mlir::OwningOpRef module_op_ref = + ParseModuleOpString(kModuleCode); + + const absl::StatusOr checkpoint_prefix = MakeTempDir(); + EXPECT_TRUE(checkpoint_prefix.ok()); + + const absl::Cleanup checkpoint_prefix_cleanup = [this, &checkpoint_prefix]() { + int64_t undeleted_files, undeleted_dirs; + TF_CHECK_OK(env_->DeleteRecursively(*checkpoint_prefix, &undeleted_files, + &undeleted_dirs)); + }; + + const absl::StatusOr> variable_shared_names = + SaveVariablesToCheckpoint(*checkpoint_prefix, *module_op_ref); + EXPECT_TRUE(variable_shared_names.ok()); + EXPECT_THAT(*variable_shared_names, IsEmpty()); + + // Verify that the checkpoint doesn't exist. + BundleReader bundle_reader(env_, *checkpoint_prefix); + EXPECT_THAT(bundle_reader.status(), Not(IsOk())); +} + +TEST_F(SaveVariablesToCheckpointTest, MutableVariablesNotSaved) { + // This function includes an AssignVariableOp that does not initialize the + // variable from a ConstOp. In this case, the variable is not saved to the + // checkpoint. 
+ constexpr absl::string_view kModuleCode = R"mlir( + module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_func_restore_op]} : () -> () + + func.func @init_func_restore_op() -> () attributes {tf_saved_model.exported_names = ["init"], tf_saved_model.initializer_type = "restore_op"} { + %cst = "tf.Const"() {device = "", value = dense<[1.0, 2.0]> : tensor<2xf32>} : () -> tensor<2xf32> + %add = "tf.AddV2"(%cst, %cst) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> + %var_handle = "tf.VarHandleOp"() {container = "", device = "/device:CPU:0", shared_name = "var_0"} : () -> tensor>> + "tf.AssignVariableOp"(%var_handle, %add) : (tensor>>, tensor<2xf32>) -> () + return + } + } + )mlir"; + + mlir::OwningOpRef module_op_ref = + ParseModuleOpString(kModuleCode); + + const absl::StatusOr checkpoint_prefix = MakeTempDir(); + EXPECT_TRUE(checkpoint_prefix.ok()); + + const absl::Cleanup checkpoint_prefix_cleanup = [this, &checkpoint_prefix]() { + int64_t undeleted_files, undeleted_dirs; + TF_CHECK_OK(env_->DeleteRecursively(*checkpoint_prefix, &undeleted_files, + &undeleted_dirs)); + }; + + const absl::StatusOr> variable_shared_names = + SaveVariablesToCheckpoint(*checkpoint_prefix, *module_op_ref); + EXPECT_TRUE(variable_shared_names.ok()); + EXPECT_THAT(*variable_shared_names, IsEmpty()); + + BundleReader bundle_reader(env_, *checkpoint_prefix); + EXPECT_THAT(bundle_reader.status(), Not(IsOk())); +} + +TEST_F(SaveVariablesToCheckpointTest, + VariableNotSavedWhenNonVarHandleOpOperandForAssignVariableOp) { + constexpr absl::string_view kModuleCode = R"mlir( + module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_func_restore_op]} : () -> () + + func.func @init_func_restore_op() -> () attributes {tf_saved_model.exported_names = ["init"], tf_saved_model.initializer_type = "restore_op"} { + %cst = "tf.Const"() {device = "", value = dense<[1.0, 2.0]> : tensor<2xf32>} : () -> tensor<2xf32> + %var_handle = "tf.VarHandleOp"() {container = "", device = "/device:CPU:0", shared_name = "var_0"} : () -> tensor>> + %var_handle_cast = "tf.Cast"(%var_handle) : (tensor>>) -> tensor + "tf.AssignVariableOp"(%var_handle_cast, %cst) : (tensor, tensor<2xf32>) -> () + return + } + } + )mlir"; + + mlir::OwningOpRef module_op_ref = + ParseModuleOpString(kModuleCode); + + const absl::StatusOr checkpoint_prefix = MakeTempDir(); + EXPECT_TRUE(checkpoint_prefix.ok()); + + const absl::Cleanup checkpoint_prefix_cleanup = [this, &checkpoint_prefix]() { + int64_t undeleted_files, undeleted_dirs; + TF_CHECK_OK(env_->DeleteRecursively(*checkpoint_prefix, &undeleted_files, + &undeleted_dirs)); + }; + + const absl::StatusOr> variable_shared_names = + SaveVariablesToCheckpoint(*checkpoint_prefix, *module_op_ref); + EXPECT_TRUE(variable_shared_names.ok()); + EXPECT_THAT(*variable_shared_names, IsEmpty()); + + BundleReader bundle_reader(env_, *checkpoint_prefix); + EXPECT_THAT(bundle_reader.status(), Not(IsOk())); +} + +TEST_F(SaveVariablesToCheckpointTest, FailsWhenDuplicateSharedName) { + // Saving variables fails when there are duplicate shared_names ("var_0"). 
+ constexpr absl::string_view kModuleCode = R"mlir( + module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_func_restore_op]} : () -> () + + func.func @init_func_restore_op() -> () attributes {tf_saved_model.exported_names = ["restore"], tf_saved_model.initializer_type = "restore_op"} { + %cst = "tf.Const"() {device = "", value = dense<[1.0, 2.0]> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "tf.VarHandleOp"() {container = "", device = "/device:CPU:0", shared_name = "var_0"} : () -> tensor>> + "tf.AssignVariableOp"(%0, %cst) : (tensor>>, tensor<2xf32>) -> () + + %cst_0 = "tf.Const"() {device = "", value = dense<[3, 4, 5, 6]> : tensor<4xi32>} : () -> tensor<4xi32> + %1 = "tf.VarHandleOp"() {container = "", device = "/device:CPU:0", shared_name = "var_0"} : () -> tensor>> + "tf.AssignVariableOp"(%1, %cst_0) : (tensor>>, tensor<4xi32>) -> () + + return + } + } + )mlir"; + + mlir::OwningOpRef module_op_ref = + ParseModuleOpString(kModuleCode); + + const absl::StatusOr checkpoint_prefix = MakeTempDir(); + EXPECT_TRUE(checkpoint_prefix.ok()); + + const absl::Cleanup checkpoint_prefix_cleanup = [this, &checkpoint_prefix]() { + int64_t undeleted_files, undeleted_dirs; + TF_CHECK_OK(env_->DeleteRecursively(*checkpoint_prefix, &undeleted_files, + &undeleted_dirs)); + }; + + EXPECT_FALSE( + SaveVariablesToCheckpoint(*checkpoint_prefix, *module_op_ref).ok()); +} + +} // namespace +} // namespace quantization +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD new file mode 100644 index 00000000000..448ac05842f --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD @@ -0,0 +1,46 @@ +load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") + +package( + default_visibility = [ + "//tensorflow/compiler/mlir/quantization:__subpackages__", + ], + licenses = ["notice"], +) + +cc_library( + name = "mlir_dump", + srcs = ["mlir_dump.cc"], + hdrs = ["mlir_dump.h"], + compatible_with = get_compatible_with_cloud(), + deps = [ + "//tensorflow/tsl/platform:env", + "//tensorflow/tsl/platform:path", + "//tensorflow/tsl/platform:status", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + ], +) + +tf_cc_test( + name = "mlir_dump_test", + srcs = ["mlir_dump_test.cc"], + compatible_with = get_compatible_with_cloud(), + deps = [ + ":mlir_dump", + "//tensorflow/tsl/platform:path", + "//tensorflow/tsl/platform:test", + "//tensorflow/tsl/platform:test_main", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/status:statusor", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + ], +) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.cc b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.cc new file mode 100644 index 00000000000..fe1b205ce1d --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.cc @@ -0,0 +1,145 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/tsl/platform/env.h" +#include "tensorflow/tsl/platform/path.h" +#include "tensorflow/tsl/platform/status.h" + +namespace tensorflow { +namespace quantization { +namespace { + +// Retrieve the MLIR dump directory. The directory is read from the environment +// variable `TF_QUANT_MLIR_DUMP_PREFIX`. However, if a special value "sponge" is +// set to `TF_QUANT_MLIR_DUMP_PREFIX`, it uses the directory set in +// `TEST_UNDECLARED_OUTPUT_DIRS`. Returns `absl::FailedPreconditionError` if +// either: +// 1. `TF_QUANT_MLIR_DUMP_PREFIX` is not set (empty), or +// 2. `TEST_UNDECLARED_OUTPUT_DIRS` is not set (empty) when +// `TF_QUANT_MLIR_DUMP_PREFIX = "sponge"`. +absl::StatusOr GetMlirDumpDir() { + auto dump_dir = std::string( + absl::NullSafeStringView(std::getenv("TF_QUANT_MLIR_DUMP_PREFIX"))); + if (dump_dir.empty()) { + return absl::FailedPreconditionError( + "Environment variable not set: TF_QUANT_MLIR_DUMP_PREFIX, " + "IR dump file for TF quantization is not created."); + } + + if (absl::EqualsIgnoreCase(dump_dir, "sponge")) { + if (!tsl::io::GetTestUndeclaredOutputsDir(&dump_dir)) { + return absl::FailedPreconditionError( + "Environment variable TF_QUANT_MLIR_DUMP_PREFIX=sponge but " + "TEST_UNDECLARED_OUTPUT_DIRS not set."); + } + } + + return dump_dir; +} + +// Creates a new file to dump the intermediate MLIRs by prefixing the +// `dump_file_name` with the value of the TF_QUANT_MLIR_DUMP_PREFIX env +// variable. Returns absl::FailedPreconditionError if the env variable is not +// set or set to an empty string. 
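For illustration only (not part of this patch): a sketch of how the dump facility in this file might be driven end to end, assuming POSIX `setenv` and the `MaybeEnableIrPrinting` function defined further below in this file. In practice the environment variable is usually set in the shell or test environment rather than programmatically, and the return type is assumed to match the declaration in mlir_dump.h.

  #include <cstdlib>
  #include <memory>

  #include "absl/status/status.h"
  #include "absl/status/statusor.h"
  #include "llvm/Support/raw_ostream.h"
  #include "mlir/IR/BuiltinOps.h"          // from @llvm-project
  #include "mlir/Pass/PassManager.h"       // from @llvm-project
  #include "mlir/Support/LogicalResult.h"  // from @llvm-project
  #include "tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.h"

  // Hypothetical driver: dumps intermediate IR to /tmp/quant_dumps/pipeline.mlir
  // while running `pm` over `module_op`. Note that MaybeEnableIrPrinting returns
  // nullptr (and dumps nothing) unless the verbosity level is at least 1.
  absl::Status RunPipelineWithIrDump(mlir::PassManager& pm,
                                     mlir::ModuleOp module_op) {
    ::setenv("TF_QUANT_MLIR_DUMP_PREFIX", "/tmp/quant_dumps", /*overwrite=*/1);

    // The returned stream must outlive the pass run.
    absl::StatusOr<std::unique_ptr<llvm::raw_ostream>> dump_file =
        tensorflow::quantization::MaybeEnableIrPrinting(pm, "pipeline");
    if (!dump_file.ok()) return dump_file.status();

    if (mlir::failed(pm.run(module_op))) {
      return absl::InternalError("Quantization pass pipeline failed.");
    }
    return absl::OkStatus();
  }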
+absl::StatusOr> CreateMlirDumpFile( + const absl::string_view dump_file_name) { + const absl::StatusOr dump_dir = GetMlirDumpDir(); + if (!dump_dir.ok()) { + return dump_dir.status(); + } + + auto *env = tsl::Env::Default(); + const tsl::Status status = env->RecursivelyCreateDir(*dump_dir); + if (!status.ok()) { + return tsl::ToAbslStatus(status); + } + + std::error_code ec{}; // NOLINT: Required to create llvm::raw_fd_ostream + const std::string dump_file_path = + tsl::io::JoinPath(*dump_dir, dump_file_name); + auto dump_file = std::make_unique(dump_file_path, ec); + if (ec) { + return absl::InternalError(absl::StrFormat( + "Unable to open file: %s, error: %s", dump_file_path, ec.message())); + } + + LOG(INFO) << "IR dump file created: " << dump_file_path; + return dump_file; +} + +} // namespace + +void EnableIrPrinting(llvm::raw_ostream &out_stream, mlir::PassManager &pm) { + mlir::OpPrintingFlags flag{}; + flag.useLocalScope().elideLargeElementsAttrs().enableDebugInfo(); + + // IR printing requires multithreading disabled. + pm.getContext()->disableMultithreading(); + + // The configuration uses the default parameter values for + // `PassManager::enableIRPrinting`, except for the `printModuleScope` + // parameter, which is true by default. It is set to false to avoid the dump + // file size becoming too large when the passes are running on a large model. + pm.enableIRPrinting( + /*shouldPrintBeforePass=*/[](mlir::Pass *, + mlir::Operation *) { return true; }, + /*shouldPrintAfterPass=*/ + [](mlir::Pass *, mlir::Operation *) { return true; }, + /*printModuleScope=*/false, /*printAfterOnlyOnChange=*/true, + /*printAfterOnlyOnFailure=*/false, out_stream, flag); + + LOG(INFO) << "IR dump for TensorFlow quantization pipeline enabled."; +} + +// TODO(b/259374854): Create tests for MaybeEnableIrPrinting. +absl::StatusOr> MaybeEnableIrPrinting( + mlir::PassManager &pm, const absl::string_view name) { + if (!VLOG_IS_ON(1)) { + LOG(INFO) << "Verbosity level too low to enable IR printing."; + return nullptr; + } + + absl::StatusOr> dump_file = + CreateMlirDumpFile(/*dump_file_name=*/absl::StrCat(name, ".mlir")); + if (absl::IsFailedPrecondition(dump_file.status())) { + // Requirements for enabling IR dump are not met. IR printing will not be + // enabled. + LOG(WARNING) << dump_file.status(); + return nullptr; + } else if (!dump_file.ok()) { + return dump_file.status(); + } + + EnableIrPrinting(**dump_file, pm); + + return dump_file; +} + +} // namespace quantization +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.h b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.h new file mode 100644 index 00000000000..db13cd19f08 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.h @@ -0,0 +1,43 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_DEBUGGING_MLIR_DUMP_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_DEBUGGING_MLIR_DUMP_H_ + +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Pass/PassManager.h" // from @llvm-project + +namespace tensorflow { +namespace quantization { + +// Enables IR printing for `pm`. When the passes are run, the IRs will be dumped +// to `out_stream`. +void EnableIrPrinting(llvm::raw_ostream &out_stream, mlir::PassManager &pm); + +// If verbosity level >= 1, this will dump intermediate IRs of passes to a file. +// The file path is given by prefixing `name`.mlir with the value of the +// TF_QUANT_MLIR_DUMP_PREFIX env variable. Returns `nullptr` iff the verbosity +// level < 1 or TF_QUANT_MLIR_DUMP_PREFIX is not set or set to an empty string. +// The returned ostream instance should live until the pass run is complete. +absl::StatusOr> MaybeEnableIrPrinting( + mlir::PassManager &pm, const absl::string_view name); + +} // namespace quantization +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_DEBUGGING_MLIR_DUMP_H_ diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump_test.cc b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump_test.cc new file mode 100644 index 00000000000..da8b64fc136 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump_test.cc @@ -0,0 +1,80 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.h" + +#include +#include + +#include "absl/cleanup/cleanup.h" +#include "absl/status/statusor.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/tsl/platform/path.h" +#include "tensorflow/tsl/platform/test.h" + +namespace tensorflow { +namespace quantization { +namespace { + +class NoOpPass + : public mlir::PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NoOpPass) + + NoOpPass() = default; + + llvm::StringRef getArgument() const final { return "no-op-pass"; } + + void runOnOperation() override { + // Noop pass does nothing on the operation. 
+ } +}; + +std::unique_ptr> CreateNoOpPass() { + return std::make_unique(); +} + +TEST(EnableIrPrintingTest, PassSuccessfullyRuns) { + mlir::MLIRContext ctx{}; + + mlir::PassManager pm = {&ctx}; + pm.addPass(CreateNoOpPass()); + + std::error_code ec{}; // NOLINT: Required to create llvm::raw_fd_ostream + const std::string tmp_dump_filename = + tsl::io::GetTempFilename(/*extension=*/".mlir"); + llvm::raw_fd_ostream dump_file{tmp_dump_filename, ec}; + + EnableIrPrinting(dump_file, pm); + + mlir::OpBuilder builder(&ctx); + auto module_op = builder.create(builder.getUnknownLoc()); + // Destroy by calling destroy() to avoid memory leak since it is allocated + // with malloc(). + const absl::Cleanup module_op_cleanup = [module_op] { module_op->destroy(); }; + + const mlir::LogicalResult result = pm.run(module_op); + EXPECT_FALSE(failed(result)); +} + +} // namespace +} // namespace quantization +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/exported_model.proto b/tensorflow/compiler/mlir/quantization/tensorflow/exported_model.proto new file mode 100644 index 00000000000..6684f6696ea --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/exported_model.proto @@ -0,0 +1,40 @@ +syntax = "proto3"; + +package tensorflow.quantization; + +import "tensorflow/core/framework/graph.proto"; + +option cc_enable_arenas = true; + +// Represents an exported TensorFlow model. It consists of a GraphDef and extra +// metadata required for building a SavedModel. This message is primarily used +// to "export" the model produced from various quantization passes in c++ to +// Python layer. +// Next ID: 7 +message ExportedModel { + GraphDef graph_def = 1; + + // Name of the initialization node (TF Operation) used for initializing + // resources like hash tables upon loading. + string init_node_name = 2; + + // Name of the restore node. When fetched it runs the `RestoreV2` op that + // restores variables from the checkpoint file specified by `checkpoint_dir`. + string restore_node_name = 3; + + // A set of variable `shared_name`s to restore for the quantized model. + repeated string variable_shared_names = 4; + + // Path to the directory where checkpoint files are saved. This directoy is + // not expected to be persistent (usually a temporary directory). When + // fetching the restore op (see `restore_node_name`), this value is provided + // to the "file_prefix" tensor to identify the checkpoint directory. + string checkpoint_dir = 5; + + // Function name -> function alias mapping. This associates the quantized + // functions to the original functions' aliases. This information will be used + // to populate `MetaInfoDef`s `function_aliases` when the quantized model is + // exported to the saved model. This field is usually only populated for the + // TF2 models. + map function_aliases = 6; +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/gen_quantized_function_library.py b/tensorflow/compiler/mlir/quantization/tensorflow/gen_quantized_function_library.py index f44e2df1e78..8352b974996 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/gen_quantized_function_library.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/gen_quantized_function_library.py @@ -22,6 +22,8 @@ from absl import app from absl import flags +# TODO(b/263048929): Create a test for gen_quantized_function_library. 
+ _OUTPUT_FILE = flags.DEFINE_string('output_file', None, 'output file location') _SRCS = flags.DEFINE_string('src', None, 'source file locations') _NAMESPACE = flags.DEFINE_string('namespace', 'mlir::quant', @@ -51,14 +53,14 @@ def _substitute_for_loop_template(module: str) -> str: for arg_value in arg_values: arg_dict = {arg_name: arg_value} replacement_text += '\\n' - replacement_text += _substitute_function_template( + replacement_text += _substitute_parameterization_template( loop_template.safe_substitute(arg_dict)) module = re.sub(compiled_regex, replacement_text, module, count=1) return module -def _substitute_function_template(module: str) -> str: +def _substitute_parameterization_template(module: str) -> str: """Substitutes all the function templates in the given module.""" compiled_regex = re.compile( r'^\s*parameters(\[.*?\])\n?(^\s*(?:func\.)+func.*?\{.*?(?:func\.)+return.*?\}\n)', @@ -70,18 +72,91 @@ def _substitute_function_template(module: str) -> str: try: value_list = ast.literal_eval(func_match.group(1)) - func_template = string.Template(func_match.group(2)) + # Escapes template $-based substitutions for attributes containing $. + # $$ is replaced with a single $. + func_template = string.Template( + func_match.group(2).replace('tfdtype$DT_', 'tfdtype$$DT_')) except Exception as e: # pylint: disable=broad-except raise ValueError('The function template is in wrong format') from e replacement_text = '' for value_dict in value_list: + for key, value in value_dict.items(): + # Replace single quote to double quote since single quote around a + # string are not valid in the MLIR representation. + value_dict[key] = str(value).replace("'", '"') replacement_text += '\\n' replacement_text += func_template.substitute(value_dict) module = re.sub(compiled_regex, replacement_text, module, count=1) return module +def _format_snake_case_op_name(s): + """Formats the op name to snake case.""" + s = s.replace('2D', '2d').replace('3D', '3d') + snake_case = ''.join(['_' + i.lower() if i.isupper() else i for i in s + ]).lstrip('_') + return snake_case.replace('mat_mul', 'matmul').replace('bias_add', 'bias') + + +def _substitute_impl_function_name_template(module: str) -> str: + """Generates the op-specific implementation function name.""" + compiled_regex = re.compile(r'GenerateImplFunctionName\(([\w\s]+)\)') + while True: + func_match = re.search(compiled_regex, module) + if func_match is None: + break + + text = func_match.group(1) + function_name = 'internal_{}_fn'.format(_format_snake_case_op_name(text)) + module = re.sub(compiled_regex, function_name, module, count=1) + return module + + +def _substitute_quantized_function_name_template(module: str) -> str: + """Generates the quantized function name.""" + compiled_regex = re.compile( + r'GenerateQuantizedFunctionName(\([\w\s\'\"\[\],]+\))') + while True: + func_match = re.search(compiled_regex, module) + if func_match is None: + break + + # Make sure the string ends with ",)" so the parsed value is a tuple. + argument_string = func_match.group(1) + if not argument_string.endswith(',)'): + argument_string = argument_string[:-1] + ',)' + arguments = ast.literal_eval(argument_string) + + if len(arguments) < 1 or len(arguments) > 2: + raise ValueError( + 'Wrong number of arguments to GenerateQuantizedFunctionName') + + quantized_ops = arguments[0] + if not quantized_ops: + raise ValueError('The quantized_ops list must not be empty') + + # Add op names to the function name. 
+ function_name = 'quantized_{}'.format( + _format_snake_case_op_name(quantized_ops[0])) + if len(quantized_ops) > 1: + function_name += '_with_{}'.format( + _format_snake_case_op_name(quantized_ops[1])) + if len(quantized_ops) > 1: + for quantized_op in quantized_ops[2:]: + function_name += '_and_{}'.format( + _format_snake_case_op_name(quantized_op)) + + # Add suffix based on output type. + suffix = '_fn' + if len(arguments) > 1 and arguments[1] == 'f32': + suffix = '_float_output_fn' + function_name += suffix + + module = re.sub(compiled_regex, function_name, module, count=1) + return module + + def main(_: Sequence[str]) -> None: namespaces = _NAMESPACE.value.split('::') src_files = _SRCS.value.split(' ') @@ -107,7 +182,9 @@ def main(_: Sequence[str]) -> None: raise ValueError('The file name must start with {}'.format(file_prefix)) tag = out[1][:-5] # the last five values = ".mlir" module = _substitute_for_loop_template(module) - module = _substitute_function_template(module) + module = _substitute_parameterization_template(module) + module = _substitute_quantized_function_name_template(module) + module = _substitute_impl_function_name_template(module) modules.append((tag, module)) with open(_OUTPUT_FILE.value, 'w') as f: diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc index 3daed86a936..919557f557e 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc @@ -30,7 +30,7 @@ std::unique_ptr GetTFOpQuantSpec(Operation* op) { auto spec = std::make_unique(); if (auto call_op = dyn_cast(op)) { StringRef function_name = - call_op.fAttr().cast().getValue(); + call_op.getFAttr().cast().getValue(); if (!function_name.startswith("composite_")) { return spec; } @@ -58,6 +58,15 @@ std::unique_ptr GetTFOpQuantSpec(Operation* op) { spec->biases_params[2] = {{0, 1}, quant::GetUniformQuantizedTypeForBias}; } + } else if (function_name.contains("batch_matmul")) { + spec->coeff_op_quant_dim[1] = -1; + if (function_name.contains("with_bias")) { + spec->biases_params[2] = {{0, 1}, + quant::GetUniformQuantizedTypeForBias}; + } + } else if (function_name.contains("gather")) { + // Note that gather has axis attribute that specifies channel axis. + spec->coeff_op_quant_dim[0] = -1; } for (auto quantizable_operand : spec->coeff_op_quant_dim) { spec->quantizable_operands.insert(quantizable_operand.first); @@ -72,12 +81,22 @@ std::unique_ptr GetTfQuantScaleSpec(Operation* op) { // clang-format off // go/keep-sorted start TF::AvgPoolOp, + TF::ConcatOp, TF::ConcatV2Op, + TF::ExpandDimsOp, + TF::IdentityNOp, TF::IdentityOp, TF::MaxPoolOp, TF::PadV2Op, + TF::RankOp, TF::ReshapeOp, - TF::SqueezeOp + TF::SelectOp, + TF::SelectV2Op, + TF::ShapeNOp, + TF::ShapeOp, + TF::SizeOp, + TF::SqueezeOp, + TF::TransposeOp // go/keep-sorted end // clang-format on >(op)) { diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.h b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.h index 411bc463456..2b3089d6ddf 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.h @@ -32,7 +32,6 @@ namespace quant { std::unique_ptr GetTFOpQuantSpec(Operation* op); // Returns quantization scale specs (fixed output, same scale) for a TF op. 
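For illustration only (not part of this patch): a small helper showing how a caller might interpret the spec returned by `GetTFOpQuantSpec` (defined earlier in this diff). A value of -1 in `coeff_op_quant_dim` means the operand is quantized per-tensor rather than per-channel, e.g. operand 1 of a "composite_batch_matmul_fn" call per the change above. The helper name is made up, and the spec type and container API are assumed to match the TFLite `quant::OpQuantSpec` used here.

  #include <memory>

  #include "mlir/IR/Operation.h"  // from @llvm-project
  #include "tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.h"

  // Hypothetical helper: returns true if operand `operand_idx` of `op` is a
  // quantizable coefficient (weight) that uses per-tensor quantization.
  bool IsPerTensorQuantizedCoeff(mlir::Operation* op, const int operand_idx) {
    const std::unique_ptr<mlir::quant::OpQuantSpec> spec =
        mlir::quant::GetTFOpQuantSpec(op);
    const auto it = spec->coeff_op_quant_dim.find(operand_idx);
    return it != spec->coeff_op_quant_dim.end() && it->second == -1;
  }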
-// TODO(b/224691264): Implement same scale verification like `VerifySameScales` std::unique_ptr GetTfQuantScaleSpec(Operation* op); } // namespace quant diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/ops/uniform_op_quant_spec.cc b/tensorflow/compiler/mlir/quantization/tensorflow/ops/uniform_op_quant_spec.cc new file mode 100644 index 00000000000..c86968b319c --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/ops/uniform_op_quant_spec.cc @@ -0,0 +1,39 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/tensorflow/ops/uniform_op_quant_spec.h" + +#include + +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir::quant { + +std::unique_ptr GetUniformOpQuantSpec(Operation* op) { + auto spec = std::make_unique(); + if (isa(op) || + isa(op)) { + spec->coeff_op_quant_dim[1] = 3; + } else if (isa(op)) { + spec->coeff_op_quant_dim[1] = -1; + } + + for (auto quantizable_operand : spec->coeff_op_quant_dim) { + spec->quantizable_operands.insert(quantizable_operand.first); + } + return spec; +} + +} // namespace mlir::quant diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/ops/uniform_op_quant_spec.h b/tensorflow/compiler/mlir/quantization/tensorflow/ops/uniform_op_quant_spec.h new file mode 100644 index 00000000000..5ff8929c71c --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/ops/uniform_op_quant_spec.h @@ -0,0 +1,35 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Functions for quantization specifications of Uniform Quantized ops. + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_OPS_UNIFORM_OP_QUANT_SPEC_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_OPS_UNIFORM_OP_QUANT_SPEC_H_ + +#include + +#include "mlir/IR/Operation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" + +namespace mlir { +namespace quant { + +// Returns the spec for the given operation that can be used for both of +// dynamic and static range quantization. 
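For illustration only (not part of this patch): a sketch of reading back the quantization dimension recorded by `GetUniformOpQuantSpec` defined earlier in this diff. Per the implementation above, the filter operand of the uniform quantized convolution ops maps to dimension 3 (the output-channel axis), while `UniformQuantizedDotHybridOp` maps to -1 (per-tensor). The helper name is made up, and the spec type is assumed to be `quant::OpQuantSpec`.

  #include <memory>
  #include <optional>

  #include "mlir/IR/Operation.h"  // from @llvm-project
  #include "tensorflow/compiler/mlir/quantization/tensorflow/ops/uniform_op_quant_spec.h"

  // Hypothetical helper: returns the quantization dimension recorded for
  // operand `operand_idx`, or std::nullopt if that operand is not a
  // quantizable coefficient.
  std::optional<int> GetCoeffQuantDim(mlir::Operation* op,
                                      const int operand_idx) {
    const std::unique_ptr<mlir::quant::OpQuantSpec> spec =
        mlir::quant::GetUniformOpQuantSpec(op);
    const auto it = spec->coeff_op_quant_dim.find(operand_idx);
    if (it == spec->coeff_op_quant_dim.end()) return std::nullopt;
    return it->second;
  }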
+std::unique_ptr GetUniformOpQuantSpec(Operation* op); + +} // namespace quant +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_OPS_UNIFORM_OP_QUANT_SPEC_H_ diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tf_quant_ops_to_mhlo.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tf_quant_ops_to_mhlo.cc index cefc126a515..1ad50bb633c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tf_quant_ops_to_mhlo.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tf_quant_ops_to_mhlo.cc @@ -26,7 +26,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/ir/importexport/mangling.h" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/duplicate_shape_determining_constants.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/duplicate_shape_determining_constants.cc new file mode 100644 index 00000000000..707856389d2 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/duplicate_shape_determining_constants.cc @@ -0,0 +1,369 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include + +#include "absl/algorithm/container.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" + +// Required to use LLVM_DEBUG macro. +#define DEBUG_TYPE "quant-duplicate-shape-determining-constants" + +namespace mlir { +namespace quant { +namespace { + +// This pass duplicates constants that affect or determine the shape of a tensor +// after being used in a computation for some op. Some specific operands of TF +// ops (like the `dim` argument for `TF::ExpandDimsOp`) determine the shape of +// the resulting tensor. If these operands are constants, they are duplicated +// and replace the shape-determining operands. 
Each duplicated constant will +// only be used as the shape-determining operand; it will not replace other +// usages of the original constant. If the operands are not constants (i.e. +// results of some other computation), then the pass recursively traverses the +// call tree upwards and duplicates all constants found in the subtree in a +// similar manner. +// +// This pass may be used to avoid placing shape-determining constants in the CPU +// graph and pass them as arguments to the TPU graph (via `TPUPartitionedCall`). +// If this happens, the XLA compiler cannot recognize such arguments as +// constants and may result in an error. +// +// A set of predefined ops and operand indices is used to determine whether an +// operand is a target for constant duplication. +class DuplicateShapeDeterminingConstantsPass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + DuplicateShapeDeterminingConstantsPass) + + StringRef getArgument() const final { + return "quant-duplicate-shape-determining-constants"; + } + + StringRef getDescription() const final { + return "Duplicates shape-determining constants. A shape-determining " + "constant is a constant that are transitively used to change or " + "determine the shape of a tensor. For example, the second argument " + "'dim' to TF::ExpandDimsOp specifies the dimension index to expand."; + } + + void runOnOperation() override; +}; + +// Returns True iff the otuput value of `op` is either a compile time constant +// or bounded from the XLA compiler's perspective, even if it is not a +// `ConstOp`. +bool IsOutputCompileTimeConstantOrBounded(Operation* op) { + return llvm::isa_and_nonnull(op); +} + +// Recursively duplicate constants for `op_operands` upward. +void RecursivelyDuplicateConstantsForOperands( + llvm::ArrayRef op_operands) { + // Target operands to duplicate if it is a ConstOp. + llvm::SmallVector duplication_targets{op_operands.begin(), + op_operands.end()}; + + int target_idx = 0; + while (target_idx < duplication_targets.size()) { + OpOperand* curr_operand = duplication_targets[target_idx]; + target_idx++; + + Operation* owning_op = curr_operand->getOwner(); + Operation* defining_op = curr_operand->get().getDefiningOp(); + + if (llvm::isa_and_nonnull(defining_op)) { + // No need to clone if this is the only use. + if (defining_op->hasOneUse()) { + LLVM_DEBUG(llvm::dbgs() + << "Not duplicating constant operand since it has only one " + "usage. Op: " + << curr_operand->getOperandNumber() + << ", operand idx: " << curr_operand->getOperandNumber() + << ", loc: " << owning_op->getLoc() << "\n"); + continue; + } + + mlir::OpBuilder builder{owning_op->getContext()}; + builder.setInsertionPointAfter(defining_op); + auto const_op_cloned = builder.clone(*defining_op); + + // Replace the operand with the duplicated op. + owning_op->setOperand(curr_operand->getOperandNumber(), + const_op_cloned->getResult(0)); + + LLVM_DEBUG(llvm::dbgs() + << "Duplicated constant operand from: " + << owning_op->getName().getStringRef() + << ", operand idx: " << curr_operand->getOperandNumber() + << ", loc: " << const_op_cloned->getLoc() << "\n"); + } else if (IsOutputCompileTimeConstantOrBounded(defining_op)) { + // Stop the recursion early when the output of the defining op is + // considered compile-time constant from the XLA compiler's perspective. + continue; + } else if (!defining_op) { + // One example for this case is when `curr_operand` is a function + // argument. 
+ owning_op->emitWarning() + << "Operand idx (zero-based): " << curr_operand->getOperandNumber() + << " does not have a defining op and cannot be duplicated."; + } else { + // If the operand's defining is not a ConstOp, recursively traverse + // "upwards" to find ConstOps that transitively produces the current + // operand and duplicate them. + auto op_operands = defining_op->getOpOperands(); + absl::c_transform( + op_operands, std::back_inserter(duplication_targets), + [](OpOperand& op_operand) -> OpOperand* { return &op_operand; }); + } + } +} + +// Evaluate `operand_idx` w.r.t. `op`'s operands. If `operand_idx` is a positive +// number or a zero, it is returned as it is. If it is a negative number, it +// means it is counting backwards and will return the zero-based operand index +// for `op`. +// +// `operand_idx` should be within the range: [-num_operands, num_operands - 1]. +int EvaluateOperandIdx(const int operand_idx, Operation& op) { + if (operand_idx < 0) { + // Calculate the actual index if a negative value is provided for + // `operand_idx`. + return op.getNumOperands() + operand_idx; + } + return operand_idx; +} + +// Returns the pointers to operands at `operand_indices` of `op`. +llvm::SmallVector GetOperands(Operation& op, + llvm::ArrayRef operand_indices) { + llvm::SmallVector operands{}; + for (const int operand_idx : operand_indices) { + const int evaluated_operand_idx = EvaluateOperandIdx(operand_idx, op); + operands.emplace_back(&op.getOpOperand(evaluated_operand_idx)); + } + + return operands; +} + +// Represents an op type and its operand indices that should be "compile time +// constant" from the XLA compiler's point of view. +template +struct CompileTimeConstantOperand { + static_assert( + sizeof...(OperandIdx) > 0, + "CompileTimeConstantOperand should have at least one operand index."); + + using OpType = OpT; + + // Returns the indices of operands that should be compile time constants. + static constexpr std::array OperandIndices() { + return {OperandIdx...}; + } +}; + +// Finds all op of type `T::OpType` `func_op` and recursively duplicates +// constants used at the op's operands at `T::OperandIndices()`. It sequentially +// does the same thing for `Ts`. +template +void DuplicateShapeDeterminingConstants(func::FuncOp func_op) { + for (auto op : func_op.getOps()) { + RecursivelyDuplicateConstantsForOperands( + GetOperands(*op, T::OperandIndices())); + } + + // Do the same thing for the rest of `Ts`. 
+ if constexpr (sizeof...(Ts) != 0) { + DuplicateShapeDeterminingConstants(func_op); + } +} + +void DuplicateShapeDeterminingConstantsPass::runOnOperation() { + func::FuncOp func_op = getOperation(); + + DuplicateShapeDeterminingConstants< + // go/keep-sorted start + CompileTimeConstantOperand, // $group_assignment + CompileTimeConstantOperand, // $dimension + CompileTimeConstantOperand, // $dimension + // $orig_input_shape + CompileTimeConstantOperand, + // $orig_input_shape + CompileTimeConstantOperand, + // $block_shape, $crops + CompileTimeConstantOperand, + CompileTimeConstantOperand, // $crops + CompileTimeConstantOperand, // $size + CompileTimeConstantOperand, // $s0, $s1 + // $s0, $s1 + CompileTimeConstantOperand, + CompileTimeConstantOperand, // $shape + /// $group_assignment + CompileTimeConstantOperand, + // $source_target_pairs + CompileTimeConstantOperand, + // $group_size, $group_key + CompileTimeConstantOperand, + CompileTimeConstantOperand, // (variadic) $axis + // $filter_sizes + CompileTimeConstantOperand, + CompileTimeConstantOperand, // $input_sizes + // $filter_sizes + CompileTimeConstantOperand, + // $input_sizes + CompileTimeConstantOperand, + // $group_assignment + CompileTimeConstantOperand, + CompileTimeConstantOperand, // $axis + CompileTimeConstantOperand, // $axis + CompileTimeConstantOperand, // $axis + // $filter_sizes + CompileTimeConstantOperand, + // $input_sizes + CompileTimeConstantOperand, + CompileTimeConstantOperand, // $shape + // $element_shape, $max_num_elements + CompileTimeConstantOperand, + CompileTimeConstantOperand, // $dim + CompileTimeConstantOperand, // $dims + CompileTimeConstantOperand, // $axis + CompileTimeConstantOperand, // $fft_length + CompileTimeConstantOperand, // $fft_length + CompileTimeConstantOperand, // $fft_length + CompileTimeConstantOperand, // $k + CompileTimeConstantOperand, // $num + CompileTimeConstantOperand, // $x, $y + // $k, $padding_value + CompileTimeConstantOperand, + // $k, $num_rows, $num_cols, $padding_value + CompileTimeConstantOperand, + // $k, $num_rows, $num_cols, $padding_value + CompileTimeConstantOperand, + CompileTimeConstantOperand, // $k + CompileTimeConstantOperand, // $k + CompileTimeConstantOperand, // $reduction_indices + // $ksize, $strides + CompileTimeConstantOperand, + // $ksize, $strides + CompileTimeConstantOperand, + CompileTimeConstantOperand, // $ksize, $strides + CompileTimeConstantOperand, // $reduction_indices + CompileTimeConstantOperand, // $paddings + CompileTimeConstantOperand, // $paddings + CompileTimeConstantOperand, // $num_samples + // $max_output_size + CompileTimeConstantOperand, + // $max_output_size + CompileTimeConstantOperand, + CompileTimeConstantOperand, // $depth + CompileTimeConstantOperand, // $paddings + CompileTimeConstantOperand, // $paddings + // $shape + CompileTimeConstantOperand, + CompileTimeConstantOperand, // $fft_length + CompileTimeConstantOperand, // $fft_length + CompileTimeConstantOperand, // $fft_length + CompileTimeConstantOperand, // $shape + CompileTimeConstantOperand, // $shape + CompileTimeConstantOperand, // $shape + // $start, $limit, $delta + CompileTimeConstantOperand, + CompileTimeConstantOperand, // $shape + CompileTimeConstantOperand, // $size + CompileTimeConstantOperand, // $size + // $begin, $end, $strides + CompileTimeConstantOperand, + CompileTimeConstantOperand, // $dims + CompileTimeConstantOperand, // $axis + CompileTimeConstantOperand, // $shape + CompileTimeConstantOperand, // $num_segments + CompileTimeConstantOperand, // 
$begin, $size + CompileTimeConstantOperand, // $output_shape + CompileTimeConstantOperand, // $max_size + // $num_samples + CompileTimeConstantOperand, + // $shape, $begin, $end, $strides + CompileTimeConstantOperand, + // $begin, $end, $strides + CompileTimeConstantOperand, + CompileTimeConstantOperand, // $reduction_indices + CompileTimeConstantOperand, // $lengths + CompileTimeConstantOperand, // $size + // $element_shape + CompileTimeConstantOperand, + // $element_shape, $num_elements + CompileTimeConstantOperand, + // $begin, $end, $strides + CompileTimeConstantOperand, + CompileTimeConstantOperand, // $multiples + CompileTimeConstantOperand, // $k + CompileTimeConstantOperand, // $perm + CompileTimeConstantOperand, // $shape + CompileTimeConstantOperand, // $num_segments + CompileTimeConstantOperand, // $num_segments + CompileTimeConstantOperand, // $num_segments + // $broadcast_dims + CompileTimeConstantOperand, + // $window_strides, $padding, $lhs_dilation, $rhs_dilation, + // $feature_group_count + CompileTimeConstantOperand, + // $window_strides, $padding, $lhs_dilation, $rhs_dilation, + // $feature_group_count + CompileTimeConstantOperand, + CompileTimeConstantOperand, // $slice_indices + CompileTimeConstantOperand, // $slice_sizes + // $padding_low, $padding_high, $padding_interior + CompileTimeConstantOperand, + // $window_dimensions, $window_strides, $base_dilations, + // $window_dilations, $padding + CompileTimeConstantOperand, + // $dim_index + CompileTimeConstantOperand, + // $window_dimensions, $window_strides, $padding + CompileTimeConstantOperand, + CompileTimeConstantOperand, // $bound + // $dim_index + CompileTimeConstantOperand + // go/keep-sorted end + >(func_op); +} + +static PassRegistration pass{}; + +} // namespace + +std::unique_ptr> +CreateDuplicateShapeDeterminingConstantsPass() { + return std::make_unique(); +} + +} // namespace quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc index 7441b68a19e..2970cafffc0 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc @@ -33,7 +33,8 @@ namespace mlir { namespace quant { namespace { -constexpr char kCustomAggregatorOpName[] = "tf.CustomAggregator"; +constexpr StringRef kCustomAggregatorOpName = "tf.CustomAggregator"; +constexpr StringRef kQuantTraitAttrName = "_tfl_quant_trait"; class InsertCustomAggregationOpsPass : public PassWrappergetName().getStringRef() == kCustomAggregatorOpName) return failure(); + // Return early if the given op is a non-quantizable op. + auto call_op = dyn_cast_or_null(op); + if (call_op && !op->hasAttr(kQuantTraitAttrName)) { + return failure(); + } + bool mutated = false; for (Value input : op->getOperands()) { Type element_type = getElementTypeOrSelf(input.getType()); @@ -89,7 +96,8 @@ class AddCustomAggregationOp : public RewritePattern { } // Skip calibration when the given operand comes from a constant. 
- if (defining_op != nullptr && detail::isConstantLike(defining_op)) { + if (defining_op != nullptr && + defining_op->hasTrait()) { continue; } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_main_function.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_main_function.cc index e6fa3b3b1e1..0d60c9a2020 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_main_function.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_main_function.cc @@ -12,18 +12,29 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include #include #include +#include +#include +#include +#include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSet.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/core/platform/macros.h" @@ -31,45 +42,41 @@ namespace mlir { namespace quant { namespace { +using ::mlir::tf_saved_model::kTfSavedModelExportedNamesAttr; +using ::mlir::tf_saved_model::kTfSavedModelIndexPathAttr; using ::tensorflow::kImportModelDefaultGraphFuncName; -constexpr char kEntryFunctionAttr[] = "tf.entry_function"; -constexpr char kExportedNameAttr[] = "tf_saved_model.exported_names"; -constexpr char kIndexPathAttr[] = "tf_saved_model.index_path"; +constexpr StringRef kEntryFunctionAttr = "tf.entry_function"; // The ConvertMlirToGraphdef requires the provided input module to have a main // function, which might not exist in case of multi-signature graphs. In that // case, this pass will create a new main function, which calls signature // functions. +// +// An already existing @main function will be renamed by attaching a numeric +// suffix like `@main_0` to avoid conflict with the newly created main function. class InsertMainFunctionPass : public PassWrapper> { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InsertMainFunctionPass) - explicit InsertMainFunctionPass() {} + explicit InsertMainFunctionPass() = default; - StringRef getArgument() const override { return "quant-add-main-function"; } + StringRef getArgument() const override { + return "quant-insert-main-function"; + } StringRef getDescription() const override { - return "Insert the main function to the module if it is missing."; + return "Inserts the main function to the module."; } void runOnOperation() override; }; -// Checks if the module has a main function. 
-bool HasMainFunction(ModuleOp& module) { - StringAttr main_func_id = - StringAttr::get(module.getContext(), kImportModelDefaultGraphFuncName); - for (auto function : module.getOps()) { - if (function.getName() == main_func_id) return true; - } - return false; -} - // Checks if a FuncOp is exported. bool IsExported(func::FuncOp op) { - auto exported_names = op->getAttrOfType(kExportedNameAttr); + auto exported_names = + op->getAttrOfType(kTfSavedModelExportedNamesAttr); return exported_names && !exported_names.empty(); } @@ -86,7 +93,7 @@ bool ShouldIncludeInMainFunction(func::FuncOp func_op) { } // Sets a function to be private so it can be referred internally. -void SetFunctionPrivate(func::FuncOp& func) { +void SetFunctionPrivate(func::FuncOp func) { func.setVisibility(SymbolTable::Visibility::Private); // The `tf_saved_model` attributes can only be appied to public functions. @@ -97,8 +104,9 @@ void SetFunctionPrivate(func::FuncOp& func) { } } + auto iface = cast(func.getOperation()); for (int i = 0; i < func.getNumArguments(); ++i) { - for (auto& attr : func.getArgAttrs(i)) { + for (auto& attr : iface.getArgAttrs(i)) { const StringAttr& attr_name = attr.getName(); if (attr_name.getValue().startswith("tf_saved_model.")) { func.removeArgAttr(i, attr_name); @@ -106,7 +114,7 @@ void SetFunctionPrivate(func::FuncOp& func) { } } for (int i = 0; i < func.getNumResults(); ++i) { - for (auto& attr : func.getResultAttrs(i)) { + for (auto& attr : iface.getResultAttrs(i)) { const StringAttr& attr_name = attr.getName(); if (attr_name.getValue().startswith("tf_saved_model.")) { func.removeResultAttr(i, attr_name); @@ -115,55 +123,121 @@ void SetFunctionPrivate(func::FuncOp& func) { } } -// Creates a main function which calls other exported functions. -bool CreateMainFunction(ModuleOp& module) { - MLIRContext* context = module.getContext(); - OpBuilder builder(context); - - // Collects argument and result types. - llvm::SmallVector arg_locs; - llvm::SmallVector arg_types, result_types; - std::vector input_names, output_names; - for (auto function : module.getOps()) { - if (!ShouldIncludeInMainFunction(function)) continue; - - arg_types.append(function.getArgumentTypes().begin(), - function.getArgumentTypes().end()); - auto& return_op = function.getBody().getBlocks().front().back(); - result_types.append(return_op.getOperandTypes().begin(), - return_op.getOperandTypes().end()); - for (const auto& arg : function.getArguments()) { - arg_locs.push_back(arg.getLoc()); - } +// Information to identify an output in its node and in the model output list. +// Ex: If the model output list is ["add:0", "topk:0": "topk:1"], then the +// output corresponding to "topk:1" will have output_index=2 and tensor_index=1. +struct OutputInfo { + // The index of this output in the model output list. + int32_t output_index; + // The index of this output in its node. + int32_t tensor_index; + // The output value. + Value value; +}; - // Collects input and output node names. These names are prefixed with the - // signature key in SavedModel. They also contain the index suffix. Ex: - // "_:0", where 0 is the index. +// Makes input/output names across entry functions unique if necessary. If a +// dupliated name is found, this function will add signature prefix for all the +// input/output names. 
+void GetUniqueInputOutputNodeNames(ModuleOp module_op, + std::vector& input_name_vec, + std::vector& output_name_vec) { + bool need_prefix_for_input_name = false; + bool need_prefix_for_output_name = false; + std::vector fn_input_name_vec, fn_output_name_vec; + StringSet<> input_name_set, output_name_set; + for (auto func_op : module_op.getOps()) { + if (!ShouldIncludeInMainFunction(func_op)) continue; if (auto tf_attrs = - function->getAttrOfType(kEntryFunctionAttr)) { + func_op->getAttrOfType(kEntryFunctionAttr)) { + StringRef function_name = func_op.getSymName(); + if (auto inputs_attr = tf_attrs.get("inputs")) { - std::string inputs_attr_str = + const std::string inputs_attr_str = inputs_attr.cast().getValue().str(); - std::vector inputs_attr_vec = + std::vector fn_input_names = absl::StrSplit(inputs_attr_str, ',', absl::SkipEmpty()); - input_names.insert(input_names.end(), inputs_attr_vec.begin(), - inputs_attr_vec.end()); + + for (StringRef input_name : fn_input_names) { + if (input_name_set.contains(input_name)) { + // Found a duplicated name, all input names will be prefixed by + // their corresponding function names. + need_prefix_for_input_name = true; + } + input_name_set.insert(input_name); + fn_input_name_vec.push_back(function_name); + } + input_name_vec.insert(input_name_vec.end(), + std::make_move_iterator(fn_input_names.begin()), + std::make_move_iterator(fn_input_names.end())); } + if (auto outputs_attr = tf_attrs.get("outputs")) { - std::string outputs_attr_str = + const std::string outputs_attr_str = outputs_attr.cast().getValue().str(); - std::vector outputs_attr_vec = + std::vector fn_output_names = absl::StrSplit(outputs_attr_str, ',', absl::SkipEmpty()); - output_names.insert(output_names.end(), outputs_attr_vec.begin(), - outputs_attr_vec.end()); + + for (StringRef output_name : fn_output_names) { + if (output_name_set.contains(output_name)) { + // Found a duplicated name, all output names will be prefixed by + // their corresponding function names. + need_prefix_for_output_name = true; + } + output_name_set.insert(output_name); + fn_output_name_vec.push_back(function_name); + } + output_name_vec.insert(output_name_vec.end(), + std::make_move_iterator(fn_output_names.begin()), + std::make_move_iterator(fn_output_names.end())); } } } + if (need_prefix_for_input_name) { + absl::c_transform( + input_name_vec, fn_input_name_vec, input_name_vec.begin(), + [](const std::string& input_name, const StringRef fn_name) { + return absl::StrCat(fn_name.str(), "_", input_name); + }); + } + if (need_prefix_for_output_name) { + absl::c_transform( + output_name_vec, fn_output_name_vec, output_name_vec.begin(), + [](const std::string& output_name, const StringRef fn_name) { + return absl::StrCat(fn_name.str(), "_", output_name); + }); + } +} + +// Creates a main function which calls other exported functions. +bool CreateMainFunction(ModuleOp module_op) { + MLIRContext* context = module_op.getContext(); + OpBuilder builder(context); + + std::vector input_names, output_names; + GetUniqueInputOutputNodeNames(module_op, input_names, output_names); + + // Collects argument and result types. 
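GetUniqueInputOutputNodeNames only rewrites names once a collision is detected across the entry functions. The sketch below shows that core idea in isolation; the helper name and the string-vector element types are assumptions, since the patch text drops template arguments.

#include <string>
#include <vector>
#include "absl/strings/str_cat.h"
#include "llvm/ADT/StringSet.h"

// Prefixes every name with its owning signature function once any duplicate
// is seen, e.g. {"x", "x"} coming from @sig1 and @sig2 becomes
// {"sig1_x", "sig2_x"}.
static void PrefixIfDuplicated(std::vector<std::string>& names,
                               const std::vector<std::string>& owning_funcs) {
  llvm::StringSet<> seen;
  bool has_duplicate = false;
  for (const std::string& name : names) {
    if (!seen.insert(name).second) has_duplicate = true;
  }
  if (!has_duplicate) return;

  for (int i = 0; i < static_cast<int>(names.size()); ++i) {
    names[i] = absl::StrCat(owning_funcs[i], "_", names[i]);
  }
}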
+ llvm::SmallVector arg_locs; + llvm::SmallVector arg_types, result_types; + + for (auto func_op : module_op.getOps()) { + if (!ShouldIncludeInMainFunction(func_op)) continue; + + arg_types.append(func_op.getArgumentTypes().begin(), + func_op.getArgumentTypes().end()); + auto& return_op = func_op.getBody().getBlocks().front().back(); + result_types.append(return_op.getOperandTypes().begin(), + return_op.getOperandTypes().end()); + for (const auto& arg : func_op.getArguments()) { + arg_locs.push_back(arg.getLoc()); + } + } + // Creates a new main function. auto func_type = FunctionType::get(context, arg_types, result_types); auto main_func = builder.create( - module.getLoc(), kImportModelDefaultGraphFuncName, func_type); + module_op.getLoc(), kImportModelDefaultGraphFuncName, func_type); builder.createBlock(&main_func.getBody(), main_func.begin(), arg_types, arg_locs); SmallVector func_attrs; @@ -176,12 +250,12 @@ bool CreateMainFunction(ModuleOp& module) { auto dictAttr = DictionaryAttr::get(context, func_attrs); main_func->setAttr(StringAttr::get(context, kEntryFunctionAttr), dictAttr); main_func->setAttr( - kExportedNameAttr, + kTfSavedModelExportedNamesAttr, builder.getStrArrayAttr({kImportModelDefaultGraphFuncName})); if (input_names.size() != main_func.getNumArguments() || output_names.size() != main_func.getNumResults()) { - module.emitError() + module_op.emitError() << "Number of inputs and outputs in the tf.entry_function attribute " "mismatched. [Input] Expected: " << input_names.size() << ", got: " << main_func.getNumArguments() @@ -193,14 +267,14 @@ bool CreateMainFunction(ModuleOp& module) { const int num_args = main_func.getNumArguments(); for (int i = 0; i < num_args; ++i) { main_func.setArgAttr( - i, kIndexPathAttr, + i, kTfSavedModelIndexPathAttr, ArrayAttr::get(context, {StringAttr::get(context, input_names[i])})); } const int num_results = main_func.getNumResults(); for (int i = 0; i < num_results; ++i) { main_func.setResultAttr( - i, kIndexPathAttr, + i, kTfSavedModelIndexPathAttr, ArrayAttr::get(context, {StringAttr::get(context, output_names[i])})); } @@ -208,42 +282,140 @@ bool CreateMainFunction(ModuleOp& module) { auto guard = OpBuilder::InsertionGuard(builder); int arg_idx = 0; int result_idx = 0; - llvm::SmallVector returning_values; - for (auto function : module.getOps()) { - if (!ShouldIncludeInMainFunction(function)) continue; + llvm::SmallVector call_op_returns; + for (auto func_op : module_op.getOps()) { + if (!ShouldIncludeInMainFunction(func_op)) continue; - llvm::ArrayRef new_args = llvm::makeArrayRef( - main_func.getArguments().begin() + arg_idx, function.getNumArguments()); - arg_idx += function.getNumArguments(); - llvm::ArrayRef new_types = llvm::makeArrayRef( - result_types.begin() + result_idx, function.getNumResults()); - result_idx += function.getNumResults(); + llvm::ArrayRef new_args = llvm::ArrayRef( + main_func.getArguments().begin() + arg_idx, func_op.getNumArguments()); + arg_idx += func_op.getNumArguments(); + llvm::ArrayRef new_types = llvm::ArrayRef( + result_types.begin() + result_idx, func_op.getNumResults()); + result_idx += func_op.getNumResults(); auto call_op = builder.create( - module.getLoc(), new_types, new_args, - SymbolRefAttr::get(context, function.getSymName()), + module_op.getLoc(), new_types, new_args, + SymbolRefAttr::get(context, func_op.getSymName()), /*config=*/builder.getStringAttr(""), /*config_proto=*/builder.getStringAttr(""), /*executor_type=*/builder.getStringAttr("")); - 
returning_values.append(call_op.getResults().begin(), - call_op.getResults().end()); - SetFunctionPrivate(function); + call_op_returns.append(call_op.getResults().begin(), + call_op.getResults().end()); + SetFunctionPrivate(func_op); + } + + // Creates Identity/IdentityN ops for returing values. This allows us to + // restore the same output tensor names in python. + int32_t output_count = 0; + // Map from node name to the list of the OutputInfos of its outputs that are + // used as the model outputs. + llvm::StringMap> node_to_output_map; + for (auto [output_name, call_op_return] : + llvm::zip(output_names, call_op_returns)) { + std::vector name_and_index = + absl::StrSplit(output_name, ':', absl::SkipEmpty()); + llvm::StringRef node_name = name_and_index.front(); + int32_t tensor_index = 0; + if (name_and_index.size() > 1) { + tensor_index = std::stoi(name_and_index.back()); + } + node_to_output_map[node_name].push_back( + {output_count++, tensor_index, call_op_return}); + } + + Value scalar_one = + CreateScalarConstValue(builder, builder.getUnknownLoc(), 1.0); + llvm::SmallVector returning_values(output_count, Value()); + for (const auto& node_name : node_to_output_map.keys()) { + auto node_output_tensors = node_to_output_map[node_name]; + + NameLoc new_loc = NameLoc::get(builder.getStringAttr(node_name)); + int32_t max_tensor_index = 0; + absl::c_for_each(node_output_tensors, + [&max_tensor_index](const OutputInfo& output_info) { + max_tensor_index = + std::max(max_tensor_index, output_info.tensor_index); + }); + + // Create IdentityOp or IdentityNOp based on the number of outputs. + Operation* identity_op; + if (max_tensor_index == 0) { + Value output_value = node_output_tensors.front().value; + identity_op = builder.create( + new_loc, output_value.getType(), output_value); + } else { + llvm::SmallVector input_values(node_output_tensors.size(), + scalar_one); + for (const auto& [output_index, tensor_index, tensor_value] : + node_output_tensors) { + input_values[tensor_index] = tensor_value; + } + identity_op = builder.create( + new_loc, TypeRange(ValueRange(input_values)), input_values); + } + + for (const auto& [output_index, tensor_index, tensor_value] : + node_output_tensors) { + returning_values[output_index] = identity_op->getResult(tensor_index); + } } builder.create(main_func.getBody().getLoc(), returning_values); // Adds the new function to symbol table. - SymbolTable symbol_table(module); + SymbolTable symbol_table(module_op); symbol_table.insert(main_func); return true; } +// Creates a new function name by attaching a number suffix +// (`main_func_name_{i}`) and incrementing it until there are no conflicts. +std::string CreateNewFuncName(const StringRef main_func_name, + SymbolTable& symbol_table) { + int suffix_id = 0; + std::string new_func_name = + absl::StrCat(main_func_name.str(), "_", suffix_id); + while (symbol_table.lookup(new_func_name)) { + suffix_id++; + new_func_name = absl::StrCat(main_func_name.str(), "_", suffix_id); + } + + return new_func_name; +} + +// Renames the existing @main function to avoid conflict with the newly +// created main function. When it is renamed, its usages will also be replaced. +// It will be renamed by attaching a number suffix like `@main_{i}`, until there +// are no conflicts. This function is a no-op when no function called @main +// exists. 
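The output handling above reconstructs the original "node_name:index" tensor names by grouping call results per node and re-emitting them through Identity/IdentityN ops under a NameLoc. A small sketch of the name-parsing step, assuming the same StrSplit/std::stoi behaviour as the patch; the helper name is chosen for illustration.

#include <cstdint>
#include <string>
#include <utility>
#include <vector>
#include "absl/strings/str_split.h"

// Splits a model output name such as "topk:1" into its node name and tensor
// index; a missing index (e.g. "add") defaults to tensor index 0.
static std::pair<std::string, int32_t> ParseOutputName(
    const std::string& output_name) {
  const std::vector<std::string> parts =
      absl::StrSplit(output_name, ':', absl::SkipEmpty());
  const int32_t tensor_index =
      parts.size() > 1 ? std::stoi(parts.back()) : 0;
  return {parts.front(), tensor_index};
}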
+LogicalResult RenameExistingMainFunction(ModuleOp module_op) { + SymbolTable symbol_table(module_op); + + auto main_func_op = + symbol_table.lookup(kImportModelDefaultGraphFuncName); + if (!main_func_op) { + return success(); + } + + const std::string new_func_name = + CreateNewFuncName(main_func_op.getSymName(), symbol_table); + + main_func_op.setSymName(new_func_name); + return symbol_table.replaceAllSymbolUses( + main_func_op, StringAttr::get(module_op.getContext(), new_func_name), + module_op); +} + void InsertMainFunctionPass::runOnOperation() { - ModuleOp module = getOperation(); - if (!HasMainFunction(module)) { - if (!CreateMainFunction(module)) { - signalPassFailure(); - } + ModuleOp module_op = getOperation(); + + if (failed(RenameExistingMainFunction(module_op))) { + module_op->emitError("Failed to rename existing function `@main`."); + signalPassFailure(); + } + + if (!CreateMainFunction(module_op)) { + signalPassFailure(); } } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_quantized_functions.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_quantized_functions.cc index 38f6de3be6c..2535fc5ae62 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_quantized_functions.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_quantized_functions.cc @@ -36,15 +36,18 @@ namespace mlir { namespace quant { namespace { +using QuantMethod = + tensorflow::quantization::QuantizationMethod::ExperimentalMethod; + class InsertQuantizedFunctionsPass : public PassWrapper> { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InsertQuantizedFunctionsPass) - explicit InsertQuantizedFunctionsPass() {} - explicit InsertQuantizedFunctionsPass(QuantizationMethod quantization_method, - const OpSet& op_set) { + explicit InsertQuantizedFunctionsPass() = default; + explicit InsertQuantizedFunctionsPass(QuantMethod quantization_method, + OpSet op_set) { quantization_method_ = quantization_method; op_set_ = op_set; } @@ -73,18 +76,22 @@ class InsertQuantizedFunctionsPass // Returns the function library for the given quantization method and opset // pair. 
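RenameExistingMainFunction keeps the module valid by moving any pre-existing @main out of the way before the new one is created. A sketch of the suffix search it relies on, with the helper name chosen here for illustration:

#include <string>
#include "absl/strings/str_cat.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/SymbolTable.h"

// Returns "main_0", "main_1", ... : the first suffixed name that is not yet
// present in the symbol table.
static std::string FirstFreeName(llvm::StringRef base,
                                 mlir::SymbolTable& symbol_table) {
  int suffix_id = 0;
  std::string candidate = absl::StrCat(base.str(), "_", suffix_id);
  while (symbol_table.lookup(candidate)) {
    ++suffix_id;
    candidate = absl::StrCat(base.str(), "_", suffix_id);
  }
  return candidate;
}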
- llvm::StringRef GetFunctionLibrary(QuantizationMethod quantization_method, + llvm::StringRef GetFunctionLibrary(QuantMethod quantization_method, OpSet op_set); - Option quantization_method_{ + Option quantization_method_{ *this, "quantization-method", - llvm::cl::init(QuantizationMethod::kPostTrainingQuantization), + llvm::cl::init( + tensorflow::quantization::QuantizationMethod::STATIC_RANGE), llvm::cl::desc("Choose quantization method."), llvm::cl::values( - clEnumValN(QuantizationMethod::kPostTrainingQuantization, "ptq", - "Post-training static-range quantization"), - clEnumValN(QuantizationMethod::kDynamicRangeQuantization, "drq", - "Post-training dynamic-range quantizaiton"))}; + clEnumValN(tensorflow::quantization::QuantizationMethod::STATIC_RANGE, + "ptq", "Post-training static-range quantization"), + clEnumValN( + tensorflow::quantization::QuantizationMethod::DYNAMIC_RANGE, + "drq", "Post-training dynamic-range quantizaiton"), + clEnumValN(tensorflow::quantization::QuantizationMethod::WEIGHT_ONLY, + "weight_only", "Post-training weight_only quantizaiton"))}; Option op_set_{ *this, "target-opset", llvm::cl::init(OpSet::TF), @@ -98,13 +105,22 @@ class InsertQuantizedFunctionsPass }; llvm::StringRef InsertQuantizedFunctionsPass::GetFunctionLibrary( - QuantizationMethod quantization_method, OpSet op_set) { + QuantMethod quantization_method, OpSet op_set) { absl::flat_hash_map function_library_map; - if (quantization_method == QuantizationMethod::kDynamicRangeQuantization) { + if (quantization_method == + tensorflow::quantization::QuantizationMethod::DYNAMIC_RANGE) { function_library_map = { + {OpSet::TF, kQuantizedFunctionLibraryInMLIR_TF_DRQ}, {OpSet::UNIFORM_QUANTIZED, kQuantizedFunctionLibraryInMLIR_UNIFORM_QUANTIZED_DRQ}, - {OpSet::TF, kQuantizedFunctionLibraryInMLIR_TF_DRQ}}; + {OpSet::XLA, kQuantizedFunctionLibraryInMLIR_TF_DRQ}}; + } else if (quantization_method == + tensorflow::quantization::QuantizationMethod::WEIGHT_ONLY) { + // Uniform quantized opset is not supported for weight-only as inputs for + // weight quantization are floats. And only dequantize_i8 is used from the + // quantized function library. + function_library_map = {{OpSet::TF, kQuantizedFunctionLibraryInMLIR}, + {OpSet::XLA, kQuantizedFunctionLibraryInMLIR}}; } else { function_library_map = {{OpSet::TF, kQuantizedFunctionLibraryInMLIR}, {OpSet::UNIFORM_QUANTIZED, @@ -170,6 +186,15 @@ void InsertQuantizedFunctionsPass::runOnOperation() { func::FuncOp new_func = func.clone(); new_func.setPrivate(); symbol_table.insert(new_func); + + // For consistency, we require all quantized composite function to have + // the "tf_quant.quantized_ops" attribute. + if (!new_func.getSymName().starts_with("quantized_")) continue; + if (!new_func->hasAttrOfType("tf_quant.quantized_ops")) { + new_func->emitError() << "Missing \"tf_quant.quantized_ops\" " + "attribute in the quantized composite function."; + signalPassFailure(); + } } } @@ -177,9 +202,9 @@ void InsertQuantizedFunctionsPass::runOnOperation() { // Creates an instance of the pass for inserting quantized functions. 
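GetFunctionLibrary selects a different embedded .mlir library per (quantization method, target opset) pair; weight-only reuses the default library because only its dequantize_i8 function is needed. Below is a reduced sketch, assuming the enum and library-constant names shown in the diff and omitting the UNIFORM_QUANTIZED entries for brevity.

#include "absl/container/flat_hash_map.h"
#include "llvm/ADT/StringRef.h"

// Maps the target opset to the embedded MLIR function library for a given
// quantization method; returns an empty StringRef for unsupported pairs.
static llvm::StringRef SelectLibrary(QuantMethod method, OpSet op_set) {
  absl::flat_hash_map<OpSet, llvm::StringRef> library_map;
  if (method == tensorflow::quantization::QuantizationMethod::DYNAMIC_RANGE) {
    library_map = {{OpSet::TF, kQuantizedFunctionLibraryInMLIR_TF_DRQ},
                   {OpSet::XLA, kQuantizedFunctionLibraryInMLIR_TF_DRQ}};
  } else {
    // STATIC_RANGE and WEIGHT_ONLY both fall back to the default library for
    // the TF opset (other opsets omitted in this sketch).
    library_map = {{OpSet::TF, kQuantizedFunctionLibraryInMLIR}};
  }
  const auto it = library_map.find(op_set);
  return it == library_map.end() ? llvm::StringRef() : it->second;
}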
std::unique_ptr> CreateInsertQuantizedFunctionsPass( - QuantizationMethod quantization_method, const OpSet& op_set) { + QuantMethod quantization_method, OpSet target_opset) { return std::make_unique(quantization_method, - op_set); + target_opset); } } // namespace quant diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_restore_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_restore_op.cc new file mode 100644 index 00000000000..31371f7b502 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_restore_op.cc @@ -0,0 +1,213 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" + +namespace mlir { +namespace quant { +namespace { + +using ::mlir::tf_saved_model::GetInitializerFunction; +using ::mlir::tf_saved_model::GetSessionInitializerOp; +using ::mlir::tf_saved_model::kTfSavedModelIndexPathAttr; +using ::mlir::tf_saved_model::kTfSavedModelInitializerRestoreType; +using ::mlir::tf_saved_model::kTfSavedModelInitializerTypeAttr; +using ::mlir::tf_saved_model::SessionInitializerOp; + +// This pass creates a RestoreV2 op in the initializer function with +// type "restore_op" that initializes variables from checkpoint. It finds +// tf.AssignVariableOp(tf.VarHandleOp, tf.Const) patterns in the initializer +// function and replaces tf.Consts with the results of RestoreV2. +class InsertRestoreOpPass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InsertRestoreOpPass) + + explicit InsertRestoreOpPass() = default; + + // The argument used to refer to the pass in the textual format (e.g. on the + // commandline). + StringRef getArgument() const final { return "quant-insert-restore-op"; } + + StringRef getDescription() const final { + return "Creates RestoreV2 op to initialize the variables in the " + "initializer function (`tf_saved_model.initializer_type == " + "'restore_op'`). 
Replaces each occurrence of " + "`tf.AssignVariableOp(tf.VarHandleOp, tf.Const)` patterns with " + "`tf.AssignVariableOp(tf.VarHandleOp, restore_op_output#N)`, where " + "`restore_op_output#N` is the Nth output of the newly created " + "RestoreV2Op."; + } + + void runOnOperation() override; +}; + +// Finds `tf.AssignVariableOp(tf.VarHandleOp, tf.Const)` patterns and returns +// the `tf.VarHandleOp`s that are initialized by these `tf.AssignVariableOp`s. +std::vector CollectVariableOps( + func::FuncOp session_init_func) { + std::vector var_handle_ops{}; + + for (auto assign_variable_op : llvm::make_early_inc_range( + session_init_func.getOps())) { + Value resource_operand = assign_variable_op.getOperand(0); + Value assigned_value_operand = assign_variable_op.getOperand(1); + + if (auto var_handle_op = + dyn_cast(resource_operand.getDefiningOp()); + var_handle_op && + isa(assigned_value_operand.getDefiningOp())) { + var_handle_ops.emplace_back(var_handle_op); + } + } + + return var_handle_ops; +} + +// Creates a `ConstOp` of 1-dimensional TF::StringType out of `str_values`. +TF::ConstOp Create1DStringConst(const ArrayRef str_values, + const Location loc, OpBuilder& builder) { + const auto tensor_type = + RankedTensorType::get(/*shape=*/{static_cast(str_values.size())}, + /*elementType=*/builder.getType()); + + return builder.create( + loc, DenseStringElementsAttr::get( + tensor_type, + SmallVector(str_values.begin(), str_values.end()))); +} + +// Creates a new argument for `func_op` that accepts a string tensor containing +// the checkpoint file's prefix. +BlockArgument InsertFilePrefixArgument(func::FuncOp func_op, + OpBuilder& builder) { + const auto filename_op_type = RankedTensorType::get( + /*shape=*/{}, /*elementType=*/builder.getType()); + const auto file_prefix_attr = builder.getStringAttr("__tf_file_prefix"); + const auto arg_attrs = builder.getDictionaryAttr({builder.getNamedAttr( + kTfSavedModelIndexPathAttr, builder.getArrayAttr({file_prefix_attr}))}); + + const int insert_idx = func_op.getNumArguments(); + + func_op.insertArgument(insert_idx, /*argType=*/filename_op_type, arg_attrs, + NameLoc::get(file_prefix_attr)); + + return func_op.getArgument(insert_idx); +} + +// Creates a 1D string array constant for "tensor_names" input of `RestoreV2` +// op. The `ConstOp` will be created at `builder`'s current insertion point. +TF::ConstOp CreateTensorNamesConst(const ArrayRef tensor_names, + OpBuilder& builder) { + const auto loc = NameLoc::get(builder.getStringAttr("tensor_names")); + return Create1DStringConst(tensor_names, loc, builder); +} + +// Creates a 1D string array constant for "shape_and_slices" input of +// `RestoreV2` op. The `ConstOp` will be created at `builder`'s current +// insertion point. It will be filled with `size` empty strings. +TF::ConstOp CreateShapeAndSlicesConst(const int size, OpBuilder& builder) { + const SmallVector shape_and_slices_values(size, /*Value=*/""); + + const auto loc = NameLoc::get(builder.getStringAttr("shape_and_slices")); + return Create1DStringConst(shape_and_slices_values, loc, builder); +} + +// Creates a `tf.RestoreV2Op` that loads the variable values from the checkpoint +// file. The loaded tensors will be used to initialize `tf.VarHandleOp`s via +// `tf.AssignVariableOp`s. 
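CollectVariableOps above recognises the tf.AssignVariableOp(tf.VarHandleOp, tf.Const) idiom to decide which variables the new RestoreV2 should feed. A sketch of that matching, with the op template arguments that the patch text elides restored as assumptions and the helper name invented for illustration:

#include <vector>
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

// Collects the VarHandleOps whose values are assigned from constants, i.e.
// the resource operands of every `tf.AssignVariableOp(tf.VarHandleOp,
// tf.Const)` in the "restore_op" initializer function.
static std::vector<TF::VarHandleOp> CollectConstInitializedVars(
    func::FuncOp init_func) {
  std::vector<TF::VarHandleOp> var_handle_ops;
  for (auto assign : init_func.getOps<TF::AssignVariableOp>()) {
    auto var_handle = dyn_cast_or_null<TF::VarHandleOp>(
        assign.getOperand(0).getDefiningOp());
    Operation* value_def = assign.getOperand(1).getDefiningOp();
    if (var_handle && value_def && isa<TF::ConstOp>(value_def)) {
      var_handle_ops.push_back(var_handle);
    }
  }
  return var_handle_ops;
}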
+void CreateRestoreV2Op(std::vector& target_var_handle_ops, + func::FuncOp session_init_func) { + SmallVector tensor_types{}; + SmallVector tensor_names{}; + for (auto var_handle_op : target_var_handle_ops) { + tensor_names.emplace_back(var_handle_op.getSharedName().str()); + + // Ex) If VarHandleOp's type is tensor>>, + // then tensor<1xf32> is the subtype. + tensor_types.emplace_back(var_handle_op.resource_subtype()); + } + + auto builder = + OpBuilder::atBlockTerminator(&session_init_func.getBody().front()); + + const BlockArgument filename_arg = + InsertFilePrefixArgument(session_init_func, builder); + + TF::ConstOp tensor_names_const = + CreateTensorNamesConst(tensor_names, builder); + TF::ConstOp shape_and_slices_const = + CreateShapeAndSlicesConst(tensor_names.size(), builder); + + auto restore_op = builder.create( + session_init_func.getLoc(), + /*tensors=*/tensor_types, + /*prefix=*/filename_arg, tensor_names_const, shape_and_slices_const); + + for (auto [idx, restore_result] : llvm::enumerate(restore_op.getResults())) { + builder.create( + restore_op.getLoc(), target_var_handle_ops[idx], restore_result); + } +} + +// TODO(b/261813194): Do not create a new RestoreV2 op when a RestoreV2 op +// already exists. +void InsertRestoreOpPass::runOnOperation() { + ModuleOp module_op = getOperation(); + + func::FuncOp session_init_func = GetInitializerFunction( + module_op, /*initializer_type=*/kTfSavedModelInitializerRestoreType); + if (!session_init_func) { + LOG(INFO) << "No session initializer function with type 'restore_op'. " + "RestoreV2 op will not be created."; + return; + } + + std::vector target_var_handle_ops = + CollectVariableOps(session_init_func); + if (target_var_handle_ops.empty()) { + LOG(INFO) << "There are no VarHandleOps to restore. RestoreV2 op will not " + "be created."; + return; + } + + CreateRestoreV2Op(target_var_handle_ops, session_init_func); +} + +static PassRegistration pass{}; + +} // namespace + +std::unique_ptr> CreateInsertRestoreOpPass() { + return std::make_unique(); +} + +} // namespace quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.cc index d7d2845ea2a..b4f9fd73239 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project @@ -26,7 +27,6 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/TypeRange.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project @@ -37,9 +37,6 @@ limitations under the License. 
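CreateRestoreV2Op wires the new RestoreV2 from three inputs (the file-prefix argument, the tensor_names constant, and the shape_and_slices constant) and assigns each result back into its variable. A condensed sketch of that wiring, again with the elided template arguments assumed and the helper name chosen for illustration:

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Builders.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

// Emits one RestoreV2 with an output per collected variable and assigns each
// restored tensor back into its VarHandleOp. The restored element type is the
// resource subtype, e.g. tensor<1xf32> for
// tensor<!tf_type.resource<tensor<1xf32>>>.
static void WireRestoreV2(llvm::ArrayRef<TF::VarHandleOp> vars,
                          Value filename_arg, TF::ConstOp tensor_names,
                          TF::ConstOp shape_and_slices, OpBuilder& builder,
                          Location loc) {
  llvm::SmallVector<Type> restored_types;
  for (TF::VarHandleOp var : vars)
    restored_types.push_back(var.resource_subtype());

  auto restore_op = builder.create<TF::RestoreV2Op>(
      loc, restored_types, /*prefix=*/filename_arg, tensor_names,
      shape_and_slices);

  for (auto [var, restored] : llvm::zip(vars, restore_op.getResults()))
    builder.create<TF::AssignVariableOp>(loc, var, restored);
}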
#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/utils/lift_as_function_call_utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/tsl/protobuf/error_codes.pb.h" namespace mlir { namespace quant { @@ -52,15 +49,18 @@ class LiftQuantizableSpotsAsFunctionsPass MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( LiftQuantizableSpotsAsFunctionsPass) - LiftQuantizableSpotsAsFunctionsPass() {} + LiftQuantizableSpotsAsFunctionsPass() = default; - explicit LiftQuantizableSpotsAsFunctionsPass(const OpSet& op_set) { + explicit LiftQuantizableSpotsAsFunctionsPass(OpSet op_set, + bool enable_two_input_tensors) { op_set_ = op_set; + enable_two_input_tensors_ = enable_two_input_tensors; } LiftQuantizableSpotsAsFunctionsPass( const LiftQuantizableSpotsAsFunctionsPass& other) { op_set_ = other.op_set_; + enable_two_input_tensors_ = other.enable_two_input_tensors_; } StringRef getArgument() const final { @@ -91,76 +91,105 @@ class LiftQuantizableSpotsAsFunctionsPass clEnumValN(OpSet::XLA, "XLA", "Uses TF XLA ops"), clEnumValN(OpSet::UNIFORM_QUANTIZED, "UNIFORM_QUANTIZED", "Uses TF Uniform Quantized ops"))}; + + bool enable_two_input_tensors_{false}; }; class CheckQuantizableOps : public mlir::OpRewritePattern { public: - explicit CheckQuantizableOps(MLIRContext* context, const OpSet& op_set) - : OpRewritePattern(context), op_set_(op_set) {} + explicit CheckQuantizableOps(MLIRContext* context, OpSet op_set, + bool enable_two_input_tensors) + : OpRewritePattern(context), + op_set_(op_set), + enable_two_input_tensors_(enable_two_input_tensors) {} private: LogicalResult matchAndRewrite(TF::PartitionedCallOp call_op, PatternRewriter& rewriter) const override { StringRef function_name = - call_op.fAttr().cast().getValue(); + call_op.getFAttr().cast().getValue(); if (!function_name.startswith("composite_") || !call_op->hasAttr(kQuantTraitAttrName)) { return failure(); } - tensorflow::Status check_status; - switch (op_set_) { - case OpSet::XLA: - check_status = checkQuantizableOpsForXla(call_op, function_name); - break; - default: - check_status = tensorflow::OkStatus(); - break; + absl::Status check_status; + // Skip quantization for read-only ops as only weight-only is supported. + if (function_name.contains("gather")) { + check_status.Update(absl::InternalError("Weight-only op is skipped.")); + } + + if (op_set_ == OpSet::XLA) { + check_status.Update(checkQuantizableOpsForXla(call_op, function_name, + enable_two_input_tensors_)); } // The OK status means this op is quantizable. Return failure since the // pattern doesn't rewrite anything yet. if (check_status.ok()) return failure(); call_op->removeAttr(kQuantTraitAttrName); - removeAttrMapAttribute(call_op, function_name, - check_status.error_message()); + removeAttrMapAttribute(call_op, function_name, check_status.message()); return success(); } - tensorflow::Status checkQuantizableOpsForXla(TF::PartitionedCallOp call_op, - StringRef function_name) const { + absl::Status checkQuantizableOpsForXla(TF::PartitionedCallOp call_op, + StringRef function_name, + bool enable_two_input_tensors) const { // Disable quantization for the DepthwiseConv since it has no benefits in // the XLA opset. 
if (function_name.contains("depthwise_conv2d")) { - return tensorflow::errors::Unknown( + return absl::InternalError( "DepthwiseConv2D doesn't get any benefit of quantization in XLA."); } else if (function_name.contains("conv2d")) { // For Conv2D, the channel dimension must be static to calculate the // feature group count. if (!HasStaticShapeAtDims(call_op->getOperand(0), /*dims=*/3)) { - return tensorflow::errors::Unknown( + return absl::InternalError( "The channel dimension of Conv2D is required to be static."); } } else if (function_name.contains("conv3d")) { // For Conv3D, the channel dimension must be static to calculate the // feature group count. if (!HasStaticShapeAtDims(call_op->getOperand(0), /*dims=*/4)) { - return tensorflow::errors::Unknown( + return absl::InternalError( "The channel dimension of Conv3D is required to be static."); } + } else if (function_name.contains("batch_matmul")) { + // For BatchMatMul, the input must be ranked. + auto shaped_type = + call_op->getOperand(0).getType().dyn_cast(); + if (!shaped_type || !shaped_type.hasRank()) { + return absl::InternalError("The input of BatchMatMul must have rank."); + } } std::unique_ptr spec = GetTFOpQuantSpec(call_op); for (auto iter : spec->coeff_op_quant_dim) { Operation* preceding_op = call_op.getOperand(iter.first).getDefiningOp(); // The XLA opset only supports constant filter/weight at the moment. - if (!preceding_op || !preceding_op->hasTrait()) { - return tensorflow::errors::Unknown( - "Non-constant weights are not supported at the moment."); + bool is_weight_constant = + preceding_op && preceding_op->hasTrait(); + + // There might be q/dq ops after the filter/weight. + if (auto dq_op = llvm::dyn_cast_or_null( + preceding_op)) { + if (auto q_op = llvm::dyn_cast_or_null( + dq_op.getArg().getDefiningOp())) { + Operation* q_op_input = q_op.getArg().getDefiningOp(); + is_weight_constant = + q_op_input && q_op_input->hasTrait(); + } + } + + if (!is_weight_constant) { + if (!enable_two_input_tensors || !function_name.contains("matmul")) { + return absl::InternalError( + "Non-constant weights are not supported at the moment."); + } } } - return tensorflow::OkStatus(); + return absl::OkStatus(); } void removeAttrMapAttribute(TF::PartitionedCallOp call_op, @@ -189,6 +218,7 @@ class CheckQuantizableOps } OpSet op_set_; + bool enable_two_input_tensors_; }; static PassRegistration pass; @@ -201,7 +231,7 @@ void LiftQuantizableSpotsAsFunctionsPass::runOnOperation() { ModuleOp module = getOperation(); populateWithGenerated(patterns); - patterns.add(ctx, op_set_); + patterns.add(ctx, op_set_, enable_two_input_tensors_); FrozenRewritePatternSet frozen_patterns(std::move(patterns)); for (auto func : module.getOps()) { if (failed(applyPatternsAndFoldGreedily(func, frozen_patterns))) { @@ -214,8 +244,10 @@ void LiftQuantizableSpotsAsFunctionsPass::runOnOperation() { } // namespace std::unique_ptr> -CreateLiftQuantizableSpotsAsFunctionsPass(const OpSet& op_set) { - return std::make_unique(op_set); +CreateLiftQuantizableSpotsAsFunctionsPass(OpSet target_opset, + bool enable_two_input_tensors) { + return std::make_unique( + target_opset, enable_two_input_tensors); } } // namespace quant diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.td index 7ec9739eaac..00fd6f5605d 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.td +++ 
b/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.td @@ -83,6 +83,16 @@ def LiftConv3D : Pat< (NamedAttr<"dilations"> $dilations))), [(IsNotInLiftedFunc $res)], (addBenefit 1)>; +def LiftBatchMatMul : Pat< + (TF_BatchMatMulV2Op:$res $x, $y, $adj_x, $adj_y), + (LiftAsFunctionCall<"composite_batch_matmul_fn"> + (ArgumentList $x, $y), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"adj_x"> $adj_x), + (NamedAttr<"adj_y"> $adj_y))), + [(IsNotInLiftedFunc $res)], (addBenefit 1)>; + //===----------------------------------------------------------------------===// // Pattern rules for lifting ops with bias as functions //===----------------------------------------------------------------------===// @@ -144,6 +154,18 @@ def LiftConv3dWithBias : Pat< (NamedAttr<"dilations"> $dilations))), [(IsNotInLiftedFunc $res)], (addBenefit 5)>; +def LiftBatchMatMulWithBias : Pat< + (TF_BiasAddOp:$res + (TF_BatchMatMulV2Op $x, $y, $adj_x, $adj_y), + $bias, IsDataFormatNHWC:$bias_data_format), + (LiftAsFunctionCall<"composite_batch_matmul_with_bias_fn"> + (ArgumentList $x, $y, $bias), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"adj_x"> $adj_x), + (NamedAttr<"adj_y"> $adj_y))), + [(IsNotInLiftedFunc $res)], (addBenefit 5)>; + //===----------------------------------------------------------------------===// // Pattern rules for lifting ops with bias and activation as functions //===----------------------------------------------------------------------===// @@ -263,6 +285,38 @@ multiclass LiftCompositeOpsWithActivation $dilations))), [(IsNotInLiftedFunc $res)], (addBenefit 10)>; + def LiftBatchMatMulWith#ActivationOp : Pat< + (ActivationOp:$res + (TF_BatchMatMulV2Op $x, $y, $adj_x, $adj_y)), + (LiftAsFunctionCall<"composite_batch_matmul_with_"# ActivationName #"_fn"> + (ArgumentList $x, $y), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"adj_x"> $adj_x), + (NamedAttr<"adj_y"> $adj_y))), + [(IsNotInLiftedFunc $res)], (addBenefit 10)>; + + def LiftBatchMatMulWithBiasAnd#LastFusedOp : Pat< + (ActivationOp:$res + (TF_BiasAddOp + (TF_BatchMatMulV2Op $x, $y, $adj_x, $adj_y), + $bias, IsDataFormatNHWC:$bias_data_format)), + (LiftAsFunctionCall<"composite_batch_matmul_with_bias_and_"# ActivationName #"_fn"> + (ArgumentList $x, $y, $bias), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"adj_x"> $adj_x), + (NamedAttr<"adj_y"> $adj_y))), + [(IsNotInLiftedFunc $res)], (addBenefit 10)>; } defm : LiftCompositeOpsWithActivation; defm : LiftCompositeOpsWithActivation; + +def LiftGather : Pat< + (TF_GatherV2Op:$res $params, $indices, $axis, $batch_dims), + (LiftAsFunctionCall<"composite_gather_fn"> + (ArgumentList $params, $indices, $axis), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"batch_dims"> $batch_dims))), + [(IsNotInLiftedFunc $res), (IsConstTensor $params)], (addBenefit 1)>; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions_drq.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions_drq.cc index 2b135096426..94109bf98f2 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions_drq.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions_drq.cc @@ -27,6 +27,8 @@ limitations under the License. 
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/utils/lift_as_function_call_utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -43,7 +45,7 @@ class LiftQuantizableSpotsAsFunctionsDRQPass LiftQuantizableSpotsAsFunctionsDRQPass) // Constructor used by the PassRegistration. This is only used by test. - explicit LiftQuantizableSpotsAsFunctionsDRQPass() {} + explicit LiftQuantizableSpotsAsFunctionsDRQPass() = default; // Constructor used by manually creating the pass. explicit LiftQuantizableSpotsAsFunctionsDRQPass( @@ -108,6 +110,13 @@ class CheckQuantizableOps call_op->removeAttr(kQuantTraitAttrName); } } + + StringRef function_name = + call_op.getFAttr().cast().getValue(); + if (function_name.contains("gather")) { + call_op->removeAttr(kQuantTraitAttrName); + } + return failure(); } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions_drq.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions_drq.td index b74cf08acba..c4ea778b522 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions_drq.td +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions_drq.td @@ -24,6 +24,33 @@ include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.td" // Pattern rules for lifting ops as functions //===----------------------------------------------------------------------===// +def LiftConv : Pat< + (TF_Conv2DOp:$res $input, $filter, $strides, $use_cudnn_on_gpu, $padding, + $explicit_paddings, IsDataFormatNHWC:$data_format, $dilations), + (LiftAsFunctionCall<"composite_conv2d_fn"> + (ArgumentList $input, $filter), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"strides"> $strides), + (NamedAttr<"use_cudnn_on_gpu"> $use_cudnn_on_gpu), + (NamedAttr<"padding"> $padding), + (NamedAttr<"explicit_paddings"> $explicit_paddings), + (NamedAttr<"dilations"> $dilations))), + [(IsNotInLiftedFunc $res), (IsConstTensor $filter)], (addBenefit 1)>; + +def LiftDepthwiseConv : Pat< + (TF_DepthwiseConv2dNativeOp:$res $input, $filter, $strides, $padding, + $explicit_paddings, IsDataFormatNHWC:$data_format, $dilations), + (LiftAsFunctionCall<"composite_depthwise_conv2d_fn"> + (ArgumentList $input, $filter), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"strides"> $strides), + (NamedAttr<"padding"> $padding), + (NamedAttr<"explicit_paddings"> $explicit_paddings), + (NamedAttr<"dilations"> $dilations))), + [(IsNotInLiftedFunc $res), (IsConstTensor $filter)], (addBenefit 1)>; + def LiftMatMul : Pat< (TF_MatMulOp:$res $a, $b, $transpose_a, $transpose_b), (LiftAsFunctionCall<"composite_matmul_fn"> @@ -33,3 +60,12 @@ def LiftMatMul : Pat< (NamedAttr<"transpose_a"> $transpose_a), (NamedAttr<"transpose_b"> $transpose_b))), [(IsNotInLiftedFunc $res), (IsConstTensor $b)], (addBenefit 1)>; + +def LiftGather : Pat< + (TF_GatherV2Op:$res $params, $indices, $axis, $batch_dims), + (LiftAsFunctionCall<"composite_gather_fn"> + (ArgumentList $params, 
$indices, $axis), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"batch_dims"> $batch_dims))), + [(IsNotInLiftedFunc $res), (IsConstTensor $params)], (addBenefit 1)>; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/mark_functions_noinline.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/mark_functions_noinline.cc new file mode 100644 index 00000000000..11fb2bebfef --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/mark_functions_noinline.cc @@ -0,0 +1,124 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include + +#include "absl/strings/str_cat.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassOptions.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" + +// Required when using LLVM_DEBUG macro. +#define DEBUG_TYPE "mark-functions-noinline" + +namespace mlir { +namespace quant { +namespace { + +// Name of the boolean attribute indicating whether the function can be +// inlined or not. +constexpr StringRef kTfNoinlineAttr = "tf._noinline"; + +// This pass marks functions with the attribute `tf._noinline = true` so that +// they aren't inlined by the `InlinerPass`. The names of the functions to be +// marked noinline should be specified by the `noinline-functions` option. +class MarkFunctionsNoinlinePass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MarkFunctionsNoinlinePass) + + explicit MarkFunctionsNoinlinePass() + : MarkFunctionsNoinlinePass( + /*noinline_functions=*/ArrayRef{}) {} + + // `noinline_functions` is a list of function names to be marked noinline. + explicit MarkFunctionsNoinlinePass( + const ArrayRef noinline_functions) + : noinline_functions_(CreateNoinlineFunctionsOption(noinline_functions)) { + } + + MarkFunctionsNoinlinePass(const MarkFunctionsNoinlinePass& other) + : MarkFunctionsNoinlinePass() { + noinline_functions_ = other.noinline_functions_; + } + + StringRef getArgument() const final { return "mark-functions-noinline"; } + + StringRef getDescription() const final { + return "Marks a function whose name is in `noinline-functions` option with " + "the attribute `tf._noinline = true`. This attributes the function " + "from being inlined by the `InlinerPass`."; + } + + void runOnOperation() override; + + private: + ListOption CreateNoinlineFunctionsOption( + const ArrayRef noinline_functions) { + return {*this, "noinline-functions", + llvm::cl::desc( + "Name of the functions that should be marked " + "tf._noinline = true to prevent inlining. 
The name of the " + "function should exactly match to be marked noinline."), + llvm::cl::list_init(noinline_functions), + llvm::cl::ZeroOrMore}; + } + + // Gets a set of function names from `noinline_functions_`. + StringSet<> GetNoinlineFunctionsSet() { + StringSet<> noinline_functions; + noinline_functions.insert(noinline_functions_.begin(), + noinline_functions_.end()); + return noinline_functions; + } + + // Names of the functions to be marked noinline. + ListOption noinline_functions_; +}; + +void MarkFunctionsNoinlinePass::runOnOperation() { + const StringSet<> noinline_functions = GetNoinlineFunctionsSet(); + + func::FuncOp func_op = getOperation(); + Builder builder(&getContext()); + + // Adds the `tf._noinline = true` attribute to the function if the name + // matches. + if (noinline_functions.contains(func_op.getSymName())) { + func_op->setAttr(kTfNoinlineAttr, builder.getBoolAttr(true)); + LLVM_DEBUG(llvm::dbgs() + << "Marked tf._noinline = true: " << func_op.getSymName()); + } +} + +static PassRegistration pass{}; + +} // namespace + +std::unique_ptr> CreateMarkFunctionsNoinlinePass( + const ArrayRef noinline_functions) { + return std::make_unique(noinline_functions); +} + +} // namespace quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_initializer_function_ops_to_main.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_initializer_function_ops_to_main.cc index cf1a2f0fb64..d5491f6010c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_initializer_function_ops_to_main.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_initializer_function_ops_to_main.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include @@ -22,9 +23,11 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/DialectRegistry.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project @@ -43,6 +46,9 @@ namespace mlir { namespace quant { namespace { +using ::mlir::tf_executor::FetchOp; +using ::mlir::tf_executor::GraphOp; +using ::mlir::tf_executor::IslandOp; using ::mlir::tf_saved_model::GetSessionInitializerOp; using ::mlir::tf_saved_model::kTfSavedModelInitializerInitType; using ::mlir::tf_saved_model::kTfSavedModelInitializerTypeAttr; @@ -61,7 +67,7 @@ class MergeInitializerFunctionOpsToMainPass MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( MergeInitializerFunctionOpsToMainPass) - explicit MergeInitializerFunctionOpsToMainPass() {} + explicit MergeInitializerFunctionOpsToMainPass() = default; StringRef getArgument() const override { return "quant-merge-initializer-function-ops-to-main"; @@ -88,7 +94,7 @@ class MergeInitializerFunctionOpsToMainPass // Gets the "main" function from the module. Returns an empty op iff it doesn't // exist. 
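For the MarkFunctionsNoinlinePass introduced above, the whole transformation reduces to setting one attribute when the symbol name is listed. A minimal sketch of that step, assuming the tf._noinline attribute name from the pass and an illustrative helper name:

#include "llvm/ADT/StringSet.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"

// Sets `tf._noinline = true` on the function when its symbol name is listed,
// so the MLIR inliner keeps the function body intact.
static void MarkNoinlineIfListed(func::FuncOp func_op,
                                 const llvm::StringSet<>& noinline_names,
                                 Builder& builder) {
  if (!noinline_names.contains(func_op.getSymName())) return;
  func_op->setAttr("tf._noinline", builder.getBoolAttr(true));
}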
func::FuncOp GetMainFunction(ModuleOp module_op) { - const StringAttr main_func_id = + const auto main_func_id = StringAttr::get(module_op.getContext(), kImportModelDefaultGraphFuncName); auto func_ops = module_op.getOps(); auto main_func_itr = absl::c_find_if(func_ops, [&main_func_id](auto func_op) { @@ -106,14 +112,14 @@ bool IsFuncOpEmpty(func::FuncOp func_op) { // Gets the GraphOp from the function op. Returns an empty op iff it doesn't // exist. -tf_executor::GraphOp GetGraphOpFromFuncOp(func::FuncOp func_op) { +GraphOp GetGraphOpFromFuncOp(func::FuncOp func_op) { if (IsFuncOpEmpty(func_op)) return {}; auto graph_op_range = func_op.front().without_terminator(); if (llvm::hasSingleElement(graph_op_range)) { // The pass runs on a valid tf_executor dialect, so the op should be the // GraphOp. - return cast(graph_op_range.begin()); + return cast(graph_op_range.begin()); } return {}; @@ -121,34 +127,55 @@ tf_executor::GraphOp GetGraphOpFromFuncOp(func::FuncOp func_op) { // Gets the string representation of the type name. std::string GetTypeName(const Type type) { - // Gets the string representation of the type name. std::string type_name{}; auto os = llvm::raw_string_ostream{type_name}; os << type; return type_name; } +// Retrieves the value of `tf_saved_model.initializer_type` attribute from the +// initializer function. Returns "unknown_initializer_type" iff the attribute is +// not set. +std::string GetInitializerType(func::FuncOp init_func_op) { + const auto initializer_type_attr = + init_func_op->getAttrOfType(kTfSavedModelInitializerTypeAttr); + + if (!initializer_type_attr) { + init_func_op->emitWarning() + << "Initializer func op does not have tf_saved_model.initializer_type " + "attribute. Func op: " + << init_func_op.getSymName(); + return "unknown_initializer_type"; + } + + return initializer_type_attr.str(); +} + // An initializer function should satisfy the follwing conditions: -// 1. The arguments should not be used. +// 1. The arguments should not be used if the type is "init_op" (it assumes +// non-variable resources like tables aren't being initialized by the asset +// files passed as arguments). // 2. Its GraphOp should only have control outputs. LogicalResult ValidateInitFunc(func::FuncOp init_func_op) { - for (BlockArgument arg : init_func_op.getArguments()) { - if (!arg.use_empty()) { - const int arg_idx = arg.getArgNumber(); - const int num_uses = absl::c_distance(arg.getUses()); - init_func_op.emitError(absl::StrFormat( - "Validation failed for the initializer function: %s. " - "The initializer function's arguments should have no " - "usages. Instead, argument index: %d has number of usages: %d.", - init_func_op.getName().str(), arg_idx, num_uses)); - return failure(); + if (GetInitializerType(init_func_op) == kTfSavedModelInitializerInitType) { + for (BlockArgument arg : init_func_op.getArguments()) { + if (!arg.use_empty()) { + const int arg_idx = arg.getArgNumber(); + const int num_uses = absl::c_distance(arg.getUses()); + init_func_op.emitError(absl::StrFormat( + "Validation failed for the initializer function: %s. " + "The initializer function's arguments should have no " + "usages. Instead, argument index: %d has number of usages: %d.", + init_func_op.getName().str(), arg_idx, num_uses)); + return failure(); + } } } - tf_executor::GraphOp graph_op = GetGraphOpFromFuncOp(init_func_op); + GraphOp graph_op = GetGraphOpFromFuncOp(init_func_op); if (!graph_op) return success(); // Consider empty FuncOp valid. 
- tf_executor::FetchOp fetch_op = graph_op.GetFetch(); + FetchOp fetch_op = graph_op.GetFetch(); for (const Value fetch : fetch_op.getFetches()) { if (!fetch.getType().isa()) { fetch_op.emitError(absl::StrFormat( @@ -163,24 +190,6 @@ LogicalResult ValidateInitFunc(func::FuncOp init_func_op) { return success(); } -// Retrieves the value of `tf_saved_model.initializer_type` attribute from the -// initializer function. Returns "unknown_initializer_type" iff the attribute is -// not set. -std::string GetInitializerType(func::FuncOp init_func_op) { - const auto initializer_type_attr = - init_func_op->getAttrOfType(kTfSavedModelInitializerTypeAttr); - - if (!initializer_type_attr) { - init_func_op->emitWarning() - << "Initializer func op does not have tf_saved_model.initializer_type " - "attribute. Func op: " - << init_func_op.getSymName(); - return "unknown_initializer_type"; - } - - return initializer_type_attr.str(); -} - // Returns initializer_type -> init_func_op mapping from the session_init_op's // initializers. The initializer functions are validated for whether it can be // moved to the main function. Returns failure() iff validation fails. @@ -207,13 +216,85 @@ FailureOr> GetInitFuncOps( return init_func_ops; } +// If `main_func_op` has the `tf.entry_function` attribute, adds a new input +// name to the `inputs` field of the attribute. Otherwise, no attribute is +// modified. +void MaybeAddEntryFunctionInput(const StringRef input_name, + func::FuncOp main_func_op) { + auto entry_func_attr = + main_func_op->getAttrOfType("tf.entry_function"); + if (!entry_func_attr) return; + + auto entry_func_attrs = SmallVector(entry_func_attr.begin(), + entry_func_attr.end()); + + MLIRContext* ctx = main_func_op.getContext(); + for (auto& named_attr : entry_func_attrs) { + if (named_attr.getName() != "inputs") continue; + + // Splits the "inputs" field to retrieve individual input names. Ignores + // empty strings. + SmallVector inputs_attrs{}; + cast(named_attr.getValue()) + .strref() + .split(inputs_attrs, /*Separator=*/',', /*MaxSplit=*/-1, + /*KeepEmpty=*/false); + + inputs_attrs.emplace_back(input_name); + + const std::string new_inputs_attr_str = + llvm::join(std::move(inputs_attrs), /*Separator=*/","); + + named_attr.setValue(StringAttr::get(ctx, new_inputs_attr_str)); + } + + main_func_op->setAttr("tf.entry_function", + DictionaryAttr::get(ctx, entry_func_attrs)); +} + +// Creates new arguments to the main function that corresponds to the source +// function's arguments. Returns the `IRMapping` that contains the +// relationship. +IRMapping CloneSrcFuncArgumentsToMainFunc( + func::FuncOp src_func_op, func::FuncOp main_func_op) { + IRMapping mapper{}; + + for (auto [src_arg_idx, src_arg] : + llvm::enumerate(src_func_op.getArguments())) { + // No need to create a mapping when there is no usage - it will not affect + // the cloning. + if (src_arg.use_empty()) continue; + + const unsigned main_arg_idx = main_func_op.getNumArguments(); + + const DictionaryAttr main_arg_attr = + src_func_op.getArgAttrDict(src_arg_idx); + + main_func_op.insertArgument(main_arg_idx, src_arg.getType(), main_arg_attr, + src_arg.getLoc()); + + const std::string new_input_name = + absl::StrCat(GetInitializerType(src_func_op), "_", src_arg_idx, ":0"); + + MaybeAddEntryFunctionInput(new_input_name, main_func_op); + + // During cloning, let it know that the source function's argument + // corresponds to the main function's newly created argument when cloning + // ops from src -> main. 
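MaybeAddEntryFunctionInput above keeps the tf.entry_function metadata in sync when a new argument (such as the checkpoint file-prefix input) is appended to @main. A sketch of the comma-separated "inputs" update, with the helper name chosen for illustration:

#include <string>
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"

// Appends one more name to a comma-separated `inputs` string such as
// "serving_default_x:0,serving_default_y:0".
static std::string AppendEntryFunctionInput(llvm::StringRef existing_inputs,
                                            llvm::StringRef new_input_name) {
  llvm::SmallVector<llvm::StringRef> names;
  existing_inputs.split(names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false);
  names.push_back(new_input_name);
  return llvm::join(names, ",");
}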
+ BlockArgument main_arg = main_func_op.getArgument(main_arg_idx); + mapper.map(src_arg, main_arg); + } + + return mapper; +} + // Copies ops from `src_func_op` to `main_body` except for the FetchOps. Returns // the fetch values in the main GraphOp corresponding to the original fetch // values from `src_func_op`. Returns an empty vector when `src_func_op` is -// empty. -llvm::SmallVector CopyOpsToMainFunction( - func::FuncOp src_func_op, tf_executor::GraphOp main_graph_op) { - tf_executor::GraphOp src_graph_op = GetGraphOpFromFuncOp(src_func_op); +// empty. `main_func_op` must have a GraphOp. +SmallVector CopyOpsToMainFunction(func::FuncOp src_func_op, + func::FuncOp main_func_op) { + GraphOp src_graph_op = GetGraphOpFromFuncOp(src_func_op); if (!src_graph_op) { VLOG(1) << "Function " << src_func_op.getName().str() << " does not have a tf_executor::GraphOp. No ops are copied to " @@ -221,17 +302,20 @@ llvm::SmallVector CopyOpsToMainFunction( return {}; } - tf_executor::FetchOp main_fetch_op = main_graph_op.GetFetch(); + GraphOp main_graph_op = GetGraphOpFromFuncOp(main_func_op); + + FetchOp main_fetch_op = main_graph_op.GetFetch(); const absl::Cleanup erase_main_fetch_op = [main_fetch_op]() mutable { main_fetch_op.erase(); }; - Block& main_body = main_graph_op.GetBody(); + // TODO(b/245473863): Handle when assets are actually used in the body. + IRMapping mapper = + CloneSrcFuncArgumentsToMainFunc(src_func_op, main_func_op); // Clones each op from src to main_body. + Block& main_body = main_graph_op.GetBody(); Block& src_body = src_graph_op.GetBody(); - // TODO(b/245473863): Handle when assets are actually used in the body. - BlockAndValueMapping mapper{}; for (Operation& op : src_body.without_terminator()) { main_body.push_back(op.clone(mapper)); } @@ -241,59 +325,22 @@ llvm::SmallVector CopyOpsToMainFunction( // Clone the source's FetchOp, but do not push to the main function's body. // The clone is only needed to identify the fetch operands. - auto cloned_fetch_op = - cast(src_graph_op.GetFetch()->clone(mapper)); + auto cloned_fetch_op = cast(src_graph_op.GetFetch()->clone(mapper)); const absl::Cleanup erase_cloned_fetch_op = [cloned_fetch_op]() mutable { cloned_fetch_op.erase(); }; - const auto fetch_operands = llvm::to_vector(cloned_fetch_op.getFetches()); - - return fetch_operands; -} - -// An overload where it accepts multiple source FuncOps. Returns all the fetches -// from the source FuncOps. -llvm::SmallVector CopyOpsToMainFunction( - const ArrayRef src_func_ops, - tf_executor::GraphOp main_graph_op) { - llvm::SmallVector fetches{}; - absl::c_for_each(src_func_ops, [main_graph_op, &fetches](auto src_func_op) { - const auto fetch_operands = - CopyOpsToMainFunction(src_func_op, main_graph_op); - fetches.append(fetch_operands); - }); - - return fetches; -} - -// Removes the SymbolRefAttr from session_initializer op's `initializers` -// attribute when its initializer_type corresponds to `init_type_to_erase`. -void EraseInitializerFromInitializersAttr( - absl::flat_hash_map& init_func_ops, - StringRef init_type_to_erase, SessionInitializerOp session_init_op, - MLIRContext* ctx) { - // Resets the `initializers` attribute excluding the symbol ref of the init - // function whose type matches `init_type_to_erase`. 
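CopyOpsToMainFunction relies on an IRMapping so that, while cloning, uses of the initializer's arguments are redirected to the arguments newly added to @main. A sketch of that cloning step, assuming the tf_executor types whose template arguments the patch text elides; the helper name is illustrative.

#include "mlir/IR/IRMapping.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"

// Clones every non-terminator op from the initializer's GraphOp body into the
// main GraphOp body; `mapper` redirects uses of the initializer's arguments
// to the arguments that were just added to @main.
static void CloneBodyIntoMain(tf_executor::GraphOp src_graph,
                              tf_executor::GraphOp main_graph,
                              IRMapping& mapper) {
  Block& main_body = main_graph.GetBody();
  for (Operation& op : src_graph.GetBody().without_terminator()) {
    main_body.push_back(op.clone(mapper));
  }
}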
- llvm::SmallVector init_func_symbols{}; - for (auto& [init_type, init_func_op] : init_func_ops) { - if (init_type == init_type_to_erase) continue; - - init_func_symbols.emplace_back( - SymbolRefAttr::get(ctx, init_func_op.getSymName())); - } - - session_init_op.setInitializersAttr(ArrayAttr::get(ctx, init_func_symbols)); + return llvm::to_vector(cloned_fetch_op.getFetches()); } // Creates a new `IslandOp` that wraps a `TF::NoOp`. The `IslandOp` has control // dependencies to the values provided. -tf_executor::IslandOp CreateNoOpWithControlDependencies( - const Location loc, tf_executor::GraphOp main_graph_op, +IslandOp CreateNoOpWithControlDependencies( + const Location loc, GraphOp main_graph_op, const ArrayRef control_dependencies) { auto builder = OpBuilder::atBlockTerminator(&main_graph_op.GetBody()); - auto wrapper_island_op = builder.create( + auto wrapper_island_op = builder.create( loc, /*outputs=*/TypeRange{}, /*control=*/tf_executor::ControlType::get(builder.getContext()), /*controlInputs=*/control_dependencies); @@ -310,9 +357,8 @@ tf_executor::IslandOp CreateNoOpWithControlDependencies( } // Adds a new fetch operand for the main function's GraphOp. -void AddFetchOperandToMain(tf_executor::GraphOp main_graph_op, - const Value fetch_operand) { - tf_executor::FetchOp old_fetch = main_graph_op.GetFetch(); +void AddFetchOperandToMain(GraphOp main_graph_op, const Value fetch_operand) { + FetchOp old_fetch = main_graph_op.GetFetch(); const absl::Cleanup erase_old_fetch = [old_fetch]() mutable { old_fetch.erase(); }; @@ -321,16 +367,15 @@ void AddFetchOperandToMain(tf_executor::GraphOp main_graph_op, fetches.emplace_back(fetch_operand); auto builder = OpBuilder::atBlockTerminator(&main_graph_op.GetBody()); - builder.create(main_graph_op.getLoc(), - std::move(fetches)); + builder.create(main_graph_op.getLoc(), std::move(fetches)); } -// Creates a new Location for the init op. This creates a loc by attaching a -// prefix `kInitOpNamePrefix` to the initializer function's name so that it is -// identifiable. +// Creates a new Location for the initializer function. This creates a loc by +// attaching a to the initializer function's type so that it is identifiable. Location CreateInitOpLoc(MLIRContext* ctx, func::FuncOp init_func_ops) { + const std::string init_type = GetInitializerType(init_func_ops); const std::string name = - absl::StrCat(kInitOpNamePrefix, "_", init_func_ops.getName().str()); + absl::StrCat(init_type, "_", init_func_ops.getName().str()); return NameLoc::get(StringAttr::get(ctx, name)); } @@ -344,7 +389,7 @@ void MergeInitializerFunctionOpsToMainPass::runOnOperation() { return signalPassFailure(); } - tf_executor::GraphOp main_graph_op = GetGraphOpFromFuncOp(main_func_op); + GraphOp main_graph_op = GetGraphOpFromFuncOp(main_func_op); if (!main_graph_op) return; SessionInitializerOp session_init_op = GetSessionInitializerOp(module_op); @@ -362,41 +407,32 @@ void MergeInitializerFunctionOpsToMainPass::runOnOperation() { return; } - // Find the init function with type "init_op" and clone the ops to @main. - // TODO(b/253614209): Also add the init function corresponding to the - // "restore_op" to @main. 
- const auto init_op_it = init_func_ops->find(kTfSavedModelInitializerInitType); - if (init_op_it == init_func_ops->end()) { - VLOG(1) << "Initializer function with tf_saved_model.initializer_type == " - "'init_op' not found."; - return; - } - - func::FuncOp init_op_func = init_op_it->second; - const llvm::SmallVector init_op_fetches = - CopyOpsToMainFunction(init_op_func, main_graph_op); - if (init_op_fetches.empty()) { - VLOG(1) << "No fetch values exist from initializer functions."; - return; - } + // Find the initializer functions and clone their ops to @main. + for (auto& [init_type, init_op_func] : *init_func_ops) { + const SmallVector init_op_fetches = + CopyOpsToMainFunction(init_op_func, main_func_op); + if (init_op_fetches.empty()) { + VLOG(1) << "No fetch values exist from initializer functions."; + return; + } - // Creates a NoOp that has control dependency to the initializer function - // for non-variables. - const Location init_op_loc = CreateInitOpLoc(ctx, init_op_func); - tf_executor::IslandOp noop_wrapper_island_op = - CreateNoOpWithControlDependencies( - init_op_loc, main_graph_op, - /*control_dependencies=*/init_op_fetches); + // Creates a NoOp that has control dependency to the initializer function + // for non-variables. + const Location init_op_loc = CreateInitOpLoc(ctx, init_op_func); + IslandOp noop_wrapper_island_op = CreateNoOpWithControlDependencies( + init_op_loc, main_graph_op, + /*control_dependencies=*/init_op_fetches); - AddFetchOperandToMain(main_graph_op, - /*fetch_operand=*/noop_wrapper_island_op.getControl()); + AddFetchOperandToMain( + main_graph_op, + /*fetch_operand=*/noop_wrapper_island_op.getControl()); - symbol_table.erase(init_op_func); + symbol_table.erase(init_op_func); + } - EraseInitializerFromInitializersAttr( - *init_func_ops, - /*init_type_to_erase=*/kTfSavedModelInitializerInitType, session_init_op, - ctx); + // Empties the "initializers" attribute from the `SessionInitializerOp` since + // all ops of the initializer ops are cloned into @main. + session_init_op.setInitializersAttr(ArrayAttr::get(ctx, {})); } } // namespace diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/optimize.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/optimize.td index be4e2160ee2..2348ac80b84 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/optimize.td +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/optimize.td @@ -27,8 +27,7 @@ def RemoveRedundantCastOps : Pat< (TF_ClipByValueOp:$clip $input, $min_value, $max_value), ConstBoolAttrFalse:$truncate2), ConstBoolAttrFalse:$truncate1), - (CreateOpWithOutputType<"TF::CastOp"> - (GetValueType $root_cast), $clip, ConstBoolAttrFalse), + (TF_CastOp $clip, ConstBoolAttrFalse), [(TensorOf<[I8]> $i8_cast), (TensorOf<[I32]> $clip), (IsIntSplatValueEqual<"int32_t", "-128"> $min_value), diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h b/tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h index d90ad8ef60c..cebc14ed259 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h @@ -16,9 +16,13 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_PASSES_H_ #define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_PASSES_H_ +#include +#include + #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h" @@ -35,9 +39,9 @@ std::unique_ptr> CreateInsertMainFunctionPass(); std::unique_ptr> CreateConvertFakeQuantToQdqPass(); // Lifts the quantizable spots as composite functions. -// TODO(b/249914162): Pass OpSet by value instead of reference. std::unique_ptr> -CreateLiftQuantizableSpotsAsFunctionsPass(const OpSet& op_set); +CreateLiftQuantizableSpotsAsFunctionsPass(OpSet target_opset, + bool enable_two_input_tensors); // Apply graph optimizations such as fusing and constant folding to prepare // lifting. @@ -58,7 +62,9 @@ CreateIssueIDsOfCustomAggregationOpsPass(); // Inserts quantized function library. std::unique_ptr> CreateInsertQuantizedFunctionsPass( - QuantizationMethod quantization_method, const OpSet& op_set); + tensorflow::quantization::QuantizationMethod::ExperimentalMethod + quantization_method, + OpSet target_opset); // Inserts custom aggregation operators for the calibration procedure. std::unique_ptr> @@ -68,7 +74,10 @@ CreateInsertCustomAggregationOpsPass(); // pass runs, functions in the given graph will be replaced with their quantized // versions. By doing so, the quantization will be applied to the given input. std::unique_ptr> CreateQuantizeCompositeFunctionsPass( - QuantizationMethod quantization_method, OpSet target_opset = OpSet::TF); + tensorflow::quantization::QuantizationMethod::ExperimentalMethod + quantization_method, + OpSet target_opset, bool enable_per_channel_quantization, + int min_num_elements_for_weights); // Converts dequantize-(quantizable) call-quantize pattern to a single call op // that has quantized input and output types. It is expected for this pass to @@ -79,16 +88,24 @@ std::unique_ptr> CreateQuantizePass(); // Overloading of CreateQuantizePass which takes QuantizationSpecs. std::unique_ptr> CreateQuantizePass( - QuantizationSpecs quant_specs); + QuantizationSpecs quant_specs, OpSet target_opset); -// Creates an instance of the PrepareQuantize pass, which will perfrom similar +// Creates an instance of the PrepareQuantize pass, which will perform similar // transformations as TFL::PrepareQuantizePass. std::unique_ptr> CreatePrepareQuantizePass( - QuantizationMethod quantization_method); + const QuantizationSpecs& quant_specs, + tensorflow::quantization::QuantizationMethod::ExperimentalMethod + quantization_method); // Creates an instance of the PrepareQuantizeDRQ pass, which will -// perfrom similar transformations as TFL::PrepareQuantizeDynamicRangePass. -std::unique_ptr> CreatePrepareQuantizeDRQPass(); +// perform similar transformations as TFL::PrepareQuantizeDynamicRangePass. +std::unique_ptr> CreatePrepareQuantizeDRQPass( + const QuantizationSpecs& quant_specs, OpSet op_set); + +// Creates an instance of the PreprocessOp pass, which will perform op +// preprocessing to allow multi-axis quantization, prior to quantization. 
+std::unique_ptr> CreatePreprocessOpPass( + const QuantizationSpecs& quant_specs, OpSet op_set); // Creates an instance of the PostQuantize pass, which will remove unnecessary // ops from the final quantized graph. @@ -123,6 +140,32 @@ CreateMergeInitializerFunctionOpsToMainPass(); // AssignVariableOps. std::unique_ptr> CreateUnfreezeConstantsPass(); +// Creates a pass that duplicates constants that affect the shape of a tensor +// after some computation. +std::unique_ptr> +CreateDuplicateShapeDeterminingConstantsPass(); + +// Creates a pass that creates a RestoreV2 op in the initializer function with +// type "restore_op" that initializes variables from the checkpoint. It finds +// tf.AssignVariableOp(tf.VarHandleOp, tf.Const) patterns in the initializer +// function and replaces tf.Consts with the results of RestoreV2. +std::unique_ptr> CreateInsertRestoreOpPass(); + +// Creates a pass that marks functions with the attribute `tf._noinline = true` +// to avoid being inlined by the `InlinerPass`. `noinline_functions` is the name +// of the functions to mark. +std::unique_ptr> CreateMarkFunctionsNoinlinePass( + ArrayRef noinline_functions); + +// Removes `tf.AssignVariableOp(tf.VarHandleOp, tf.Const)` patterns from the +// initializer function (type = "restore_op"). +// Note: initializing values (`tf.Const`s) will be removed and this may result +// in an information loss and uninitialized variables eventually. Make sure that +// this effect is desired (e.g. there is a `tf.RestoreV2Op` that restores the +// variables instead). +std::unique_ptr> +CreateRemoveVariableInitializationByConstPass(); + } // namespace quant } // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/post_quantize.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/post_quantize.cc index 4a17b32acba..03c6f14baa1 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/post_quantize.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/post_quantize.cc @@ -43,7 +43,7 @@ class PostQuantizePass MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PostQuantizePass) // Constructor used by the PassRegistration. This will remove the adaptor ops. - explicit PostQuantizePass() {} + explicit PostQuantizePass() = default; StringRef getArgument() const final { // This is the argument used to refer to the pass in @@ -104,6 +104,28 @@ struct RemoveVolatileOps } }; +// The StorageCastOp is used to cast from a quantized type to its storage type +// or the opposite. If none of its input and output is quantized, the op has +// no effect and should be removed. 
+class RemoveRedundantScast + : public mlir::OpRewritePattern { + public: + explicit RemoveRedundantScast(MLIRContext* context) + : OpRewritePattern(context) {} + + private: + LogicalResult matchAndRewrite(quantfork::StorageCastOp scast_op, + PatternRewriter& rewriter) const override { + if (QuantizedType::getQuantizedElementType(scast_op.getArg().getType()) || + QuantizedType::getQuantizedElementType(scast_op.getType())) { + return failure(); + } + + scast_op.replaceAllUsesWith(scast_op.getArg()); + return success(); + } +}; + #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/post_quantize.inc" void PostQuantizePass::runOnOperation() { @@ -111,7 +133,7 @@ void PostQuantizePass::runOnOperation() { auto func = getOperation(); auto* ctx = func.getContext(); patterns.add, - RemoveVolatileOps>(ctx); + RemoveVolatileOps, RemoveRedundantScast>(ctx); populateWithGenerated(patterns); (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/post_quantize.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/post_quantize.td index ef4523fbfb5..5d879adea90 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/post_quantize.td +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/post_quantize.td @@ -31,5 +31,5 @@ def ReorderIdentityFollowingQuantizedFunction : Pat< (TF_IdentityOp (quantfork_StorageCastOp $value)))), (TF_IdentityOp - (CreateOpWithOutputType<"quantfork::DequantizeCastOp"> - (GetValueType $output), $value))>; + (quantfork_DequantizeCastOp + $value, (returnType (GetValueType $output))))>; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.cc index 053ac34c6bb..979c0ffca98 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.cc @@ -13,16 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include #include +#include +#include "absl/algorithm/container.h" #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h" @@ -48,33 +55,200 @@ class PrepareLiftingPass "prepare lifting."; } - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); } void runOnOperation() override; }; -bool HasEqualElementSize(Value filter, Attribute val, - mlir::ArrayRef filter_indices, - mlir::ArrayRef val_indices) { - int filter_result = 1; - int val_result = 1; +// Check if given indices in `val1` has same number of elements as given +// indices in `val2`. +bool HasEqualElementSize(Value val1, Value val2, ArrayRef val1_indices, + ArrayRef val2_indices) { + ShapedType val1_shape = val1.getType().cast(); + ShapedType val2_shape = val2.getType().cast(); + if (!val1_shape.hasRank() || !val2_shape.hasRank()) return false; - mlir::ShapedType shaped_filter = filter.getType().cast(); - mlir::ShapedType shaped_val = val.dyn_cast().getType(); + int val1_result = 1; + int val2_result = 1; + for (auto idx : val1_indices) { + if (idx < 0) idx = idx + val1_shape.getRank(); + if (idx >= val1_shape.getRank() || val1_shape.isDynamicDim(idx)) { + return false; + } + val1_result *= val1_shape.getDimSize(idx); + } + + for (auto idx : val2_indices) { + if (idx < 0) idx = idx + val2_shape.getRank(); + if (idx >= val2_shape.getRank() || val2_shape.isDynamicDim(idx)) { + return false; + } + val2_result *= val2_shape.getDimSize(idx); + } - for (auto idx : filter_indices) { - if (idx >= shaped_filter.getRank()) return false; - filter_result *= shaped_filter.getDimSize(idx); + return val1_result == val2_result; +} + +// Matches convolution op with "NHWC" data format or matmul op with false adj_y. +// The list of supported ops in this function is: +// - Conv2DOp +// - Conv3DOp +// - DepthwiseConv2dNativeOp +// - MatMulOp +// - BatchMatMulV2Op +LogicalResult MatchSupportedAffineOp(Operation* op, Value& binding_output, + Value& binding_input, + Value& binding_weight) { + bool is_supported_affine_op = false; + if (llvm::isa(op)) { + if (const auto data_format = op->getAttrOfType("data_format")) { + is_supported_affine_op = data_format.getValue().equals("NHWC") || + data_format.getValue().equals("NDHWC"); + } + } else if (llvm::isa(op)) { + if (const auto adj_y = op->getAttrOfType("adj_y")) { + is_supported_affine_op = !adj_y.getValue(); + } + } + + if (!is_supported_affine_op) return failure(); + + // Bind input, output and weight to the given values. 
+ binding_output = op->getResult(0); + binding_input = op->getOperand(0); + binding_weight = op->getOperand(1); + return success(); +} + +// Makes the 1D value broadcastable with the `rhs_shape`. +Value MakeOneDimValueBroadcastable(OpBuilder& builder, Location loc, + Value value, ShapedType rhs_shape) { + ShapedType value_shape = value.getType().dyn_cast_or_null(); + if (!value_shape || value_shape.getRank() != 1 || + !value_shape.hasStaticShape() || !rhs_shape.hasStaticShape()) { + return {}; + } + + int64_t num_elements = value_shape.getNumElements(); + llvm::SmallVector new_shape; + for (auto idx : llvm::reverse(llvm::seq(0, rhs_shape.getRank()))) { + const int64_t rhs_dim = rhs_shape.getDimSize(idx); + if (num_elements % rhs_dim != 0) { + return {}; + } + new_shape.push_back(rhs_dim); + num_elements = num_elements / rhs_dim; + if (num_elements == 1) break; } + absl::c_reverse(new_shape); - for (auto idx : val_indices) { - if (idx >= shaped_val.getRank()) return false; - val_result *= shaped_val.getDimSize(idx); + auto reshape_op = builder.create( + loc, value, Create1DConstValue(builder, loc, new_shape)); + return ConstantFoldOpIfPossible(reshape_op).front(); +} + +// Checks if a value can be symetrically quantized. +bool CanBeSymmetricallyQuantized(Value weight) { + auto dq_op = weight.getDefiningOp(); + if (!dq_op) return true; + + auto qtype = dq_op.getArg().getType().cast().getElementType(); + if (auto uniform_type = llvm::dyn_cast_or_null(qtype)) { + return uniform_type.getZeroPoint() == 0; + } else if (auto per_axis_type = + llvm::dyn_cast_or_null(qtype)) { + return absl::c_all_of(per_axis_type.getZeroPoints(), + [](int64_t x) { return x == 0; }); + } + return false; +} + +// Multiplies two 1D arrays with broadcasting support. +template +SmallVector MultiplyTwoArrays(ArrayRef a, ArrayRef b) { + auto get_value_at = [](ArrayRef v, size_t i) -> T { + if (v.size() == 1) return v.front(); + return v[i]; + }; + + size_t max_size = std::max(a.size(), b.size()); + SmallVector result(max_size); + for (size_t i : llvm::seq(0, max_size)) { + result[i] = get_value_at(a, i) * get_value_at(b, i); + } + return result; +} + +// Multiplies the value followed by a FakeQuant op and adjusts the quantization +// params. This funtion only supports symetrically quantized values. +Value MultiplyFakeQuantValue(OpBuilder& builder, Location loc, Value value, + Value multiplier) { + auto dq_op = value.getDefiningOp(); + if (!dq_op) { + auto mul_op = builder.create(loc, value, multiplier); + return ConstantFoldOpIfPossible(mul_op).front(); + } + auto q_op = dq_op.getArg().getDefiningOp(); + if (!q_op) return {}; + + Value float_value = q_op.getArg(); + Value new_value = builder.create(loc, float_value, multiplier); + auto new_value_type = new_value.getType().cast(); + + // Get multiplier value in double. + DenseFPElementsAttr multiplier_attr; + if (!matchPattern(multiplier, m_Constant(&multiplier_attr)) || + multiplier_attr.getType().cast().getRank() > 1) { + return {}; + } + std::vector multiplier_values; + absl::c_transform(multiplier_attr, std::back_inserter(multiplier_values), + [](auto v) { return FloatAttr::getValueAsDouble(v); }); + ArrayRef multiplier_array(multiplier_values.data(), + multiplier_values.size()); + + // Multiply the quantization parameters by the multiplier. 
+ QuantizedType new_qtype; + auto element_type = q_op.getType().cast().getElementType(); + if (auto uniform_type = llvm::dyn_cast(element_type)) { + if (multiplier_attr.isSplat()) { + double new_scale = multiplier_array.front() * uniform_type.getScale(); + new_qtype = UniformQuantizedType::get( + uniform_type.getFlags(), uniform_type.getStorageType(), + uniform_type.getExpressedType(), new_scale, + uniform_type.getZeroPoint(), uniform_type.getStorageTypeMin(), + uniform_type.getStorageTypeMax()); + } else { + auto new_scales = + MultiplyTwoArrays(multiplier_array, {uniform_type.getScale()}); + int32_t quantized_dim = new_value_type.getRank() - 1; + auto new_zero_points = + SmallVector(new_scales.size(), uniform_type.getZeroPoint()); + new_qtype = UniformQuantizedPerAxisType::get( + uniform_type.getFlags(), uniform_type.getStorageType(), + uniform_type.getExpressedType(), new_scales, new_zero_points, + quantized_dim, uniform_type.getStorageTypeMin(), + uniform_type.getStorageTypeMax()); + } + } else if (auto per_axis_type = + llvm::dyn_cast_or_null( + element_type)) { + auto new_scales = + MultiplyTwoArrays(multiplier_array, per_axis_type.getScales()); + new_qtype = UniformQuantizedPerAxisType::get( + per_axis_type.getFlags(), per_axis_type.getStorageType(), + per_axis_type.getExpressedType(), new_scales, + per_axis_type.getZeroPoints(), per_axis_type.getQuantizedDimension(), + per_axis_type.getStorageTypeMin(), per_axis_type.getStorageTypeMax()); } - return filter_result == val_result; + auto quantize = builder.create( + q_op.getLoc(), new_value_type.clone(new_qtype), new_value); + auto dequantize = builder.create( + dq_op.getLoc(), new_value_type, quantize.getResult()); + return ConstantFoldOpIfPossible(dequantize).front(); } // Copied from tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc. @@ -85,8 +259,8 @@ struct RemoveIdentity : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TF::IdentityOp identity, - PatternRewriter &rewriter) const override { - for (Operation *user : identity->getUsers()) { + PatternRewriter& rewriter) const override { + for (Operation* user : identity->getUsers()) { // Replace the op with the input if output is only used by TF ops. // Currently this is more on the conservative side since we need to ensure // every consumer op to be a TF op before applying this pattern. 
We can @@ -102,7 +276,7 @@ struct RemoveIdentity : public OpRewritePattern { } } - rewriter.replaceOp(identity, identity.input()); + rewriter.replaceOp(identity, identity.getInput()); return success(); } }; @@ -110,7 +284,7 @@ struct RemoveIdentity : public OpRewritePattern { #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.inc" void PrepareLiftingPass::runOnOperation() { - MLIRContext *ctx = &getContext(); + MLIRContext* ctx = &getContext(); auto func = getOperation(); // The pattern includes decomposing batch normalization ops, fusing add/mul diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td index 04651bea7bf..85edb26daf6 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td @@ -46,13 +46,13 @@ def FoldFusedBatchNormV3: Pattern< $x, $scale, $offset, $mean, $variance, F32Attr:$epsilon, $exponential_avg_factor, $data_format, IsFalseBoolAttr:$is_training), - [(TF_AddOp + [(TF_AddV2Op (TF_MulOp $x, (TF_MulOp:$multiplier $scale, (TF_RsqrtOp - (TF_AddOp $variance, + (TF_AddV2Op $variance, (TF_ConstOp $epsilon))))), (TF_SubOp $offset, (TF_MulOp $mean, $multiplier))), // We already guaranteed that the last five results have no use so it does @@ -66,15 +66,10 @@ def FoldFusedBatchNormV3: Pattern< (HasNoUseOf:$root__3), (HasNoUseOf:$root__4), (HasNoUseOf:$root__5)]>; -class HasRank : Constraint< - CPred<"$0.getType().cast().hasRank() && " - "$0.getType().cast().getRank() == " # n>, - "Checks if the value has rank of 'n'.">; - class HasEqualElementSize shape_1, list shape_2> : Constraint< CPred<"quant::HasEqualElementSize($0, $1," - "llvm::ArrayRef({" # !interleave(shape_1, ", ") # "})," - "llvm::ArrayRef({" # !interleave(shape_2, ", ") # "}))">, + "llvm::ArrayRef({" # !interleave(shape_1, ", ") # "})," + "llvm::ArrayRef({" # !interleave(shape_2, ", ") # "}))">, "Checks if the given dimensions contain the same number of elements.">; def HasEqualShape : Constraint().getShape() == $1.getType().cast().getShape()">, "Checks if the shapes of tensors are same.">; -def Expand1DTo4DForConv2D : NativeCodeCall< - "$0.cast().reshape(" - "RankedTensorType::get({1,1,1,$0.getType().cast().getNumElements()}," - "getElementTypeOrSelf($0.getType())))">; - -def Expand1DTo4DForDepthwiseConv2D : NativeCodeCall< - "$0.cast().reshape(" - "RankedTensorType::get({1,1,$1.getType().cast().getDimSize(2),$1.getType().cast().getDimSize(3)}," - "getElementTypeOrSelf($0.getType())))">; - -def CreateUnrankedTensorTypeWithElementType : NativeCodeCall< - "UnrankedTensorType::get(getElementTypeOrSelf($0.getType()))">; - -// Matching AffineOp followed by an AddOp patterns. 
-def MatchConv2dAndAddPattern : Pat< - (TF_AddOp (TF_Conv2DOp:$conv_out $input, $filter, $strides, $use_cudnn, $padding, - $explicit_padding, IsDataFormatNHWC:$data_format, $dilations), - (TF_ConstOp IsFloatElementsAttr:$value)), - (TF_BiasAddOp (TF_Conv2DOp $input, $filter, $strides, $use_cudnn, $padding, - $explicit_padding, $data_format, $dilations), (TF_ConstOp $value), $data_format), - [(HasOneUse $conv_out), (HasRank<1> $value), (HasRank<4> $filter), - (HasEqualElementSize<[3],[0]> $filter, $value)]>; - -def MatchDepthwiseConv2dAndAddPattern : Pat< - (TF_AddOp (TF_DepthwiseConv2dNativeOp:$conv_out $input, $filter, $strides, $padding, - $explicit_padding, IsDataFormatNHWC:$data_format, $dilations), - (TF_ConstOp IsFloatElementsAttr:$value)), - (TF_BiasAddOp (TF_DepthwiseConv2dNativeOp $input, $filter, $strides, $padding, - $explicit_padding, $data_format, $dilations, - (returnType (CreateUnrankedTensorTypeWithElementType $input))), - (TF_ConstOp $value), $data_format), - [(HasOneUse $conv_out), (HasRank<1> $value), (HasRank<4> $filter), - (HasEqualElementSize<[2,3],[0]> $filter, $value)]>; - -// Fusing AffineOp followed by an MulOp patterns. -def FuseConv2dAndMul : Pat< - (TF_MulOp (TF_Conv2DOp:$conv_out $input, $filter, $strides, $use_cudnn, $padding, - $explicit_padding, IsDataFormatNHWC:$data_format, $dilations), - (TF_ConstOp IsFloatElementsAttr:$value)), - (TF_Conv2DOp $input, (TF_MulOp $filter, - (TF_ConstOp (Expand1DTo4DForConv2D $value))), $strides, $use_cudnn, $padding, - $explicit_padding, $data_format, $dilations), - [(HasOneUse $conv_out), (HasRank<1> $value), (HasRank<4> $filter), - (HasEqualElementSize<[3],[0]> $filter, $value)]>; - -def FuseDepthwiseConv2dAndMul : Pat< - (TF_MulOp (TF_DepthwiseConv2dNativeOp:$conv_out $input, $filter, $strides, $padding, - $explicit_padding, IsDataFormatNHWC:$data_format, $dilations), - (TF_ConstOp IsFloatElementsAttr:$value)), - (TF_DepthwiseConv2dNativeOp $input, (TF_MulOp $filter, - (TF_ConstOp (Expand1DTo4DForDepthwiseConv2D $value, $filter))), $strides, $padding, - $explicit_padding, $data_format, $dilations), - [(HasOneUse $conv_out), (HasRank<1> $value), (HasRank<4> $filter), - (HasEqualElementSize<[2,3],[0]> $filter, $value)]>; - -// Fusing AffineOp followed by an BiasAddOp and an AddOp patterns. -def FuseConv2dWithBiasAndAdd : Pat< - (TF_AddOp - (TF_BiasAddOp:$biasadd_out - (TF_Conv2DOp:$conv_out $input, $filter, $strides, $use_cudnn, $padding, - $explicit_padding, IsDataFormatNHWC:$data_format, $dilations), - $bias, IsDataFormatNHWC:$bias_data_format), - (TF_ConstOp IsFloatElementsAttr:$value)), - (TF_BiasAddOp - (TF_Conv2DOp $input, $filter, $strides, $use_cudnn, $padding, - $explicit_padding, $data_format, $dilations), - (TF_AddOp $bias, (TF_ConstOp $value)), $bias_data_format), - [(HasOneUse $conv_out), (HasOneUse $biasadd_out), (HasRank<1> $value), - (HasEqualShape $value, $bias)]>; - -def FuseDepthwiseConv2dWithBiasAndAdd : Pat< - (TF_AddOp - (TF_BiasAddOp:$biasadd_out - (TF_DepthwiseConv2dNativeOp:$conv_out $input, $filter, $strides, $padding, - $explicit_padding, IsDataFormatNHWC:$data_format, $dilations), - $bias, IsDataFormatNHWC:$bias_data_format), - (TF_ConstOp IsFloatElementsAttr:$value)), +// Make the 1D value $0 broadcastable with the shape of $1. +def MakeOneDimValueBroadcastable : NativeCodeCall< + "MakeOneDimValueBroadcastable($_builder, $_loc, $0, $1.getType().cast())">; + +// Match convolution op with "NHWC" data format or matmul op. 
+def SupportedAffineOpMatcher : NativeCodeCall<
+  "MatchSupportedAffineOp($_self, $0, $1, $2)">;
+
+// Checks if a value can be symmetrically quantized.
+def CanBeSymmetricallyQuantized : Constraint>;
+
+// Multiplies the value followed by a FakeQuant op and adjusts its params.
+def MultiplyFakeQuantValue : NativeCodeCall<
+  "MultiplyFakeQuantValue($_builder, $_loc, $0...)">;
+
+// Convert AddV2Op following an AffineOp to BiasAddOp.
+// For Conv3D, even though the Conv3D op has "NDHWC" data format, the BiasAdd
+// will still have the data format of "NHWC".
+def ConvertAddToBiasAdd : Pat<
+  (TF_AddV2Op
+    (SupportedAffineOpMatcher $conv_out, $input, $weight),
+    (TF_ConstOp:$add_rhs IsFloatElementsAttr:$add_rhs_value)),
+  (TF_BiasAddOp $conv_out, $add_rhs, (CreateStringAttr<"NHWC">)),
+  [(HasRankOf<1> $add_rhs_value),
+   (HasEqualElementSize<[-1], [0]> $conv_out, $add_rhs)]>;
+
+// Fuse consecutive BiasAddOp and an AddV2Op.
+def FuseBiasAndAddV2 : Pat<
+  (TF_AddV2Op
+    (TF_BiasAddOp:$bias_add
+      $conv_out,
+      (TF_ConstOp:$bias IsFloatElementsAttr:$bias_value), $data_format),
+    (TF_ConstOp:$add_rhs IsFloatElementsAttr:$add_rhs_value)),
   (TF_BiasAddOp
-    (TF_DepthwiseConv2dNativeOp $input, $filter, $strides, $padding,
-      $explicit_padding, $data_format, $dilations,
-      (returnType (CreateUnrankedTensorTypeWithElementType $input))),
-    (TF_AddOp $bias, (TF_ConstOp $value)), $bias_data_format),
-  [(HasOneUse $conv_out), (HasOneUse $biasadd_out), (HasRank<1> $value),
-   (HasEqualShape $value, $bias)]>;
-
-// Fusing AffineOp followed by an BiasAddOp and an MulOp patterns.
-def FuseConv2dWithBiasAndMul : Pat<
+    $conv_out, (TF_AddV2Op $bias, $add_rhs), $data_format),
+  [(HasOneUse $bias_add),
+   (HasEqualShape $bias_value, $add_rhs_value)]>;
+
+// Fuse AffineOp followed by a MulOp.
+def FuseAffineOpAndMul : Pat<
   (TF_MulOp
-    (TF_BiasAddOp:$biasadd_out
-      (TF_Conv2DOp:$conv_out $input, $filter, $strides, $use_cudnn, $padding,
-        $explicit_padding, IsDataFormatNHWC:$data_format, $dilations),
-      $bias, IsDataFormatNHWC:$bias_data_format),
-    (TF_ConstOp IsFloatElementsAttr:$value)),
-  (TF_BiasAddOp (TF_Conv2DOp $input, (TF_MulOp $filter,
-    (TF_ConstOp (Expand1DTo4DForConv2D $value))), $strides, $use_cudnn, $padding,
-      $explicit_padding, $data_format, $dilations), (TF_MulOp $bias,
-        (TF_ConstOp $value)), $bias_data_format),
-  [(HasOneUse $conv_out), (HasOneUse $biasadd_out),
-   (HasRank<1> $value), (HasRank<4> $filter),
-   (HasEqualElementSize<[3],[0]> $filter, $value),
-   (HasEqualShape $value, $bias)]>;
-
-def FuseDepthwiseConv2dWithBiasAndMul : Pat<
+    (SupportedAffineOpMatcher $conv_out, $input, $weight),
+    (TF_ConstOp:$mul_rhs IsFloatElementsAttr:$mul_rhs_value)),
+  (CloneOpWithReplacedOperands
+    (GetDefiningOp $conv_out),
+    $input,
+    (MultiplyFakeQuantValue $weight,
+      (MakeOneDimValueBroadcastable $mul_rhs, $weight))),
+  [(HasOneUse $conv_out),
+   (HasRankOf<1> $mul_rhs_value),
+   (HasStaticShapeConstraint $weight),
+   (CanBeSymmetricallyQuantized $weight),
+   (HasEqualElementSize<[-1], [0]> $conv_out, $mul_rhs)]>;
+
+// Fuse AffineOp followed by a BiasAddOp and a MulOp.
+def FuseAffineOpWithBiasAddAndMul : Pat< (TF_MulOp - (TF_BiasAddOp:$biasadd_out - (TF_DepthwiseConv2dNativeOp:$conv_out $input, $filter, $strides, $padding, - $explicit_padding, IsDataFormatNHWC:$data_format, $dilations), - $bias, IsDataFormatNHWC:$bias_data_format), - (TF_ConstOp IsFloatElementsAttr:$value)), - (TF_BiasAddOp (TF_DepthwiseConv2dNativeOp $input, (TF_MulOp $filter, - (TF_ConstOp (Expand1DTo4DForDepthwiseConv2D $value, $filter))), $strides, $padding, - $explicit_padding, $data_format, $dilations, - (returnType (CreateUnrankedTensorTypeWithElementType $input))), - (TF_MulOp $bias, (TF_ConstOp $value)), $bias_data_format), - [(HasOneUse $conv_out), (HasOneUse $biasadd_out), - (HasRank<1> $value), (HasRank<4> $filter), - (HasEqualElementSize<[2,3],[0]> $filter, $value), - (HasEqualShape $value, $bias)]>; + (TF_BiasAddOp:$bias_add + (SupportedAffineOpMatcher $conv_out, $input, $weight), + $bias, $data_format), + (TF_ConstOp:$mul_rhs IsFloatElementsAttr:$mul_rhs_value)), + (TF_BiasAddOp + (CloneOpWithReplacedOperands + (GetDefiningOp $conv_out), + $input, + (MultiplyFakeQuantValue $weight, + (MakeOneDimValueBroadcastable $mul_rhs, $weight))), + (MultiplyFakeQuantValue $bias, $mul_rhs), $data_format), + [(HasOneUse $conv_out), + (HasOneUse $bias_add), + (HasRankOf<1> $mul_rhs_value), + (HasStaticShapeConstraint $weight), + (CanBeSymmetricallyQuantized $weight), + (CanBeSymmetricallyQuantized $bias), + (HasEqualShape $bias, $mul_rhs_value)]>; + diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.cc index 9174c89be6f..72ee6151b05 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.cc @@ -52,6 +52,9 @@ namespace quant { namespace { +using QuantMethod = + tensorflow::quantization::QuantizationMethod::ExperimentalMethod; + // Applies prepare quantization on the model in TF dialect. This pass runs // before the quantization pass and propagate the quantization parameters // across ops. This step is necessary for post-training quantization and also @@ -73,19 +76,23 @@ class PrepareQuantizePass quant_specs_.inference_type = tensorflow::DT_QINT8; } - explicit PrepareQuantizePass(QuantizationMethod quantization_method) { + // Constructor used by manually creating the pass. + explicit PrepareQuantizePass(const QuantizationSpecs& quant_specs, + QuantMethod quantization_method) + : quant_specs_(quant_specs) { quant_specs_.inference_type = tensorflow::DT_QINT8; + enable_per_channel_quantization_ = !quant_specs_.disable_per_channel; enable_post_training_quantize_ = - (quantization_method == QuantizationMethod::kPostTrainingQuantization); + (quantization_method == + tensorflow::quantization::QuantizationMethod::STATIC_RANGE); } PrepareQuantizePass(const PrepareQuantizePass& other) { quant_specs_ = other.quant_specs_; enable_post_training_quantize_ = other.enable_post_training_quantize_; - disable_per_channel_ = other.disable_per_channel_; + enable_per_channel_quantization_ = !quant_specs_.disable_per_channel; } - // Constructor used by manually creating the pass. 
explicit PrepareQuantizePass(const QuantizationSpecs& quant_specs) : quant_specs_(quant_specs) { enable_post_training_quantize_ = quant_specs.post_training_quantization; @@ -151,9 +158,11 @@ class PrepareQuantizePass *this, "post-training-quantize", llvm::cl::init(false), llvm::cl::desc("Enable post training quantization. Only used in tests.")}; - Option disable_per_channel_{ - *this, "disable-per-channel", llvm::cl::init(false), - llvm::cl::desc("Whether disable per-channel quantized weights.")}; + // A local flag is needed for testing conditions in + // prepare_quantize_ptq_per_channel.mlir. + Option enable_per_channel_quantization_{ + *this, "enable-per-channel-quantization", llvm::cl::init(false), + llvm::cl::desc("Whether enable per-channel quantized weights.")}; }; bool PrepareQuantizePass::SetInputNodesQuantizationParams(func::FuncOp func) { @@ -204,9 +213,8 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(func::FuncOp func) { if (!min_max.first.has_value() || !min_max.second.has_value()) return; TypeAttr params = quant::GetQuantizedTypeAttr( - builder, input_type, - builder.getF64FloatAttr(min_max.first.getValue()), - builder.getF64FloatAttr(min_max.second.getValue()), + builder, input_type, builder.getF64FloatAttr(min_max.first.value()), + builder.getF64FloatAttr(min_max.second.value()), /*quant_dim=*/-1, num_bits, narrow_range, is_signed); builder.setInsertionPoint(block, insertion_point); auto q_op = builder.create( @@ -389,7 +397,7 @@ void PrepareQuantizePass::runOnOperation() { // During the legalization, unsigned quantized type is used, so we have to // convert all of them to signed. - RewritePatternSet patterns(&getContext()); + RewritePatternSet patterns(ctx); populateWithGenerated(patterns); patterns.add>(ctx); // Convert quant stats to int8 quantization parameters. @@ -403,9 +411,8 @@ void PrepareQuantizePass::runOnOperation() { // Finally, the quantization parameters can be propagated to the rest of the // values (tensors). ApplyQuantizationParamsPropagation( - func, is_signed, disable_per_channel_ || quant_specs_.disable_per_channel, - GetTFOpQuantSpec, GetTfQuantScaleSpec, infer_tensor_range, - quant_specs_.legacy_float_scale); + func, is_signed, !enable_per_channel_quantization_, GetTFOpQuantSpec, + GetTfQuantScaleSpec, infer_tensor_range, quant_specs_.legacy_float_scale); RewritePatternSet patterns2(ctx); patterns2.add(ctx); @@ -416,8 +423,9 @@ void PrepareQuantizePass::runOnOperation() { // Creates an instance of the TensorFlow dialect PrepareQuantize pass. std::unique_ptr> CreatePrepareQuantizePass( - QuantizationMethod quantization_method) { - return std::make_unique(quantization_method); + const QuantizationSpecs& quant_specs, QuantMethod quantization_method) { + return std::make_unique(quant_specs, + quantization_method); } static PassRegistration pass; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize_drq.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize_drq.cc index 34309788054..2d96d13091c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize_drq.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize_drq.cc @@ -18,7 +18,9 @@ limitations under the License. #include #include +#include #include +#include #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -31,8 +33,8 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" //===----------------------------------------------------------------------===// // The prepare-quantize-drq Pass. @@ -42,12 +44,13 @@ namespace quant { namespace { -using QuantizationUnits = llvm::SetVector>; +using QuantizationUnit = std::pair; +using QuantizationUnits = llvm::SetVector; // Applies prepare quantization on the model in TF dialect for dynamic range // quantization case. class PrepareQuantizeDRQPass - : public PassWrapper> { + : public PassWrapper> { void getDependentDialects(DialectRegistry& registry) const override { registry.insert(); @@ -58,13 +61,22 @@ class PrepareQuantizeDRQPass // Constructor used by the PassRegistration and enforce int8 quantization. // This is only used by test. - explicit PrepareQuantizeDRQPass() { + explicit PrepareQuantizeDRQPass() : op_set_(OpSet::UNIFORM_QUANTIZED) { quant_specs_.inference_type = tensorflow::DT_QINT8; } // Constructor used by manually creating the pass. - explicit PrepareQuantizeDRQPass(const QuantizationSpecs& quant_specs) - : quant_specs_(quant_specs) {} + explicit PrepareQuantizeDRQPass(const QuantizationSpecs& quant_specs, + OpSet op_set) + : quant_specs_(quant_specs), op_set_(op_set) { + enable_per_channel_quantization_ = !quant_specs_.disable_per_channel; + } + + PrepareQuantizeDRQPass(const PrepareQuantizeDRQPass& other) { + quant_specs_ = other.quant_specs_; + op_set_ = other.op_set_; + enable_per_channel_quantization_ = !quant_specs_.disable_per_channel; + } StringRef getArgument() const final { // This is the argument used to refer to the pass in @@ -86,6 +98,11 @@ class PrepareQuantizeDRQPass private: QuantizationSpecs quant_specs_; + OpSet op_set_; + + Option enable_per_channel_quantization_{ + *this, "enable-per-channel-quantization", llvm::cl::init(false), + llvm::cl::desc("Whether enable per-channel quantized weights.")}; }; // If the weight is applicable to dynamic range quantization, insert Quantize @@ -93,9 +110,13 @@ class PrepareQuantizeDRQPass class PrepareDRQQuantizableOp : public OpRewritePattern { public: explicit PrepareDRQQuantizableOp(MLIRContext* context, - const quant::QuantizationSpecs& quant_specs) + const quant::QuantizationSpecs& quant_specs, + OpSet op_set, + bool enable_per_channel_quantization) : OpRewritePattern(context), - quant_specs_(quant_specs) {} + quant_specs_(quant_specs), + op_set_(op_set), + enable_per_channel_quantization_(enable_per_channel_quantization) {} LogicalResult matchAndRewrite(arith::ConstantOp op, PatternRewriter& rewriter) const override { @@ -143,28 +164,49 @@ class PrepareDRQQuantizableOp : public OpRewritePattern { // Apply per-tensor quantization for int8 dynamic range quantization. 
bool quantizeOpAsInt8(PatternRewriter& rewriter, arith::ConstantOp op, - std::pair quant_op) const { - bool is_narrow_range = true; - bool is_legacy_float = quant_specs_.legacy_float_scale; - bool is_signed = quant_specs_.IsSignedInferenceType(); - int bit_width = quant_specs_.GetQuantizationTypeWidth(); - - QuantizedType quant_type = nullptr; + QuantizationUnit quant_op) const { + auto [quantized_op, weight_idx] = quant_op; + const bool is_narrow_range = true; + const bool is_legacy_float = quant_specs_.legacy_float_scale; + const bool is_signed = quant_specs_.IsSignedInferenceType(); + const int bit_width = quant_specs_.GetQuantizationTypeWidth(); + + std::unique_ptr spec = GetTFOpQuantSpec(quantized_op); + const int quant_dim = spec->coeff_op_quant_dim[weight_idx]; + const bool is_per_channel_quantization = + enable_per_channel_quantization_ && quant_dim != -1; + + QuantizedType quant_type; DenseFPElementsAttr attr; if (!matchPattern(op->getResult(0), m_Constant(&attr))) return false; - quant_type = quant::GetUniformQuantizedTypeForWeight( - attr, is_narrow_range && is_signed, bit_width, is_signed, - is_narrow_range, is_legacy_float) - .template dyn_cast(); + if (attr.size() < quant_specs_.minimum_elements_for_weights) { + op->emitRemark("Quantization is skipped for ") + << quantized_op->getName().getStringRef().str() << " because it has " + << attr.dyn_cast().size() + << " elements which is fewer than the threshold(" + << quant_specs_.minimum_elements_for_weights << " elements)."; + return false; + } + if (is_per_channel_quantization) { + quant_type = quant::GetUniformQuantizedPerAxisTypeForWeight( + attr, quant_dim, + /*symmetric=*/true, bit_width, is_signed, + is_narrow_range, is_legacy_float) + .template dyn_cast(); + } else { + quant_type = quant::GetUniformQuantizedTypeForWeight( + attr, is_narrow_range && is_signed, bit_width, is_signed, + is_narrow_range, is_legacy_float) + .template dyn_cast(); + } return insertQDQ(rewriter, op, quant_type, quant_op); } // Insert Quantize and Dequantize ops. bool insertQDQ(PatternRewriter& rewriter, arith::ConstantOp op, - QuantizedType quant_type, - std::pair quant_op) const { + QuantizedType quant_type, QuantizationUnit quant_op) const { if (!quant_type) return false; Operation* quantize_op = quant_op.first; @@ -208,7 +250,9 @@ class PrepareDRQQuantizableOp : public OpRewritePattern { } protected: - quant::QuantizationSpecs quant_specs_; + QuantizationSpecs quant_specs_; + OpSet op_set_; + bool enable_per_channel_quantization_; }; // Remove all the stats ops which are redundant for dynamic range quantizaiton. 
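For context on the per-channel branch added to quantizeOpAsInt8 above: symmetric int8 weight quantization either derives a single scale for the whole tensor (per-tensor) or one scale per slice along the quantization dimension (per-channel), and weights with fewer elements than minimum_elements_for_weights are skipped entirely. The standalone C++ sketch below illustrates that decision outside of MLIR; the helper name QuantizeWeightSymmetric and the assumption that the channel dimension is outermost are made up for illustration and are not part of this patch.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <vector>

// Illustration only; not the pass code. Mirrors the decision in
// quantizeOpAsInt8: skip small weights, then compute symmetric (zero-point 0)
// int8 scales either per-tensor or per-channel.
struct QuantParams {
  std::vector<double> scales;  // One entry per-tensor, or one per channel.
  bool quantized = false;
};

QuantParams QuantizeWeightSymmetric(const std::vector<float>& weight,
                                    int num_channels, bool per_channel,
                                    int min_num_elements) {
  QuantParams params;
  // Mirrors minimum_elements_for_weights: small weights stay in float.
  if (weight.size() < static_cast<size_t>(min_num_elements)) return params;

  constexpr double kInt8Max = 127.0;  // Narrow-range symmetric int8.
  if (!per_channel) {
    double max_abs = 0.0;
    for (float v : weight) {
      max_abs = std::max(max_abs, std::fabs(static_cast<double>(v)));
    }
    params.scales.push_back(max_abs / kInt8Max);
  } else {
    // Assumes the channel dimension is outermost, so each channel is a
    // contiguous slice of the flat buffer.
    const size_t per_channel_size = weight.size() / num_channels;
    for (int c = 0; c < num_channels; ++c) {
      double max_abs = 0.0;
      for (size_t i = 0; i < per_channel_size; ++i) {
        max_abs = std::max(max_abs, std::fabs(static_cast<double>(
                                        weight[c * per_channel_size + i])));
      }
      params.scales.push_back(max_abs / kInt8Max);
    }
  }
  params.quantized = true;
  return params;
}

int main() {
  const std::vector<float> w = {0.5f, -1.0f, 0.25f, 2.0f, -0.125f, 1.5f};
  const QuantParams p = QuantizeWeightSymmetric(w, /*num_channels=*/3,
                                                /*per_channel=*/true,
                                                /*min_num_elements=*/4);
  for (double s : p.scales) std::cout << s << "\n";  // One scale per channel.
  return 0;
}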
@@ -222,23 +266,31 @@ void PrepareQuantizeDRQPass::removeAllStatsOp(func::FuncOp func) { #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.inc" void PrepareQuantizeDRQPass::runOnOperation() { - func::FuncOp func = getOperation(); - MLIRContext* ctx = func.getContext(); - - removeAllStatsOp(func); + MLIRContext* ctx = &getContext(); + RewritePatternSet patterns(ctx); + ModuleOp module_op = getOperation(); - RewritePatternSet patterns(&getContext()); populateWithGenerated(patterns); - patterns.add(ctx, quant_specs_); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + patterns.add(ctx, quant_specs_, op_set_, + enable_per_channel_quantization_); + FrozenRewritePatternSet frozen_patterns(std::move(patterns)); + + for (auto func : module_op.getOps()) { + removeAllStatsOp(func); + if (failed(applyPatternsAndFoldGreedily(func, frozen_patterns))) { + func.emitError() << "quant-prepare-quantize-drq failed."; + signalPassFailure(); + } + } } } // namespace // Creates an instance of the TensorFlow dialect PrepareQuantizeDRQ // pass. -std::unique_ptr> CreatePrepareQuantizeDRQPass() { - return std::make_unique(); +std::unique_ptr> CreatePrepareQuantizeDRQPass( + const QuantizationSpecs& quant_specs, const OpSet op_set) { + return std::make_unique(quant_specs, op_set); } static PassRegistration pass; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.cc new file mode 100644 index 00000000000..18aa58fe60c --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.cc @@ -0,0 +1,209 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This transformation pass applies quantization propagation on TF dialect. + +#include +#include +#include +#include +#include + +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" +#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" + +//===----------------------------------------------------------------------===// +// The preprocess-op Pass. 
+// +namespace mlir { +namespace quant { + +namespace { + +using QuantizationUnit = std::pair; +using QuantizationUnits = llvm::SetVector; + +// Preprocesses ops to allow multi-axis quantization, prior to quantization +// passes. Currently, per-channel quantization only supports 1D results. +class PreprocessOpPass + : public PassWrapper> { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PreprocessOpPass) + + // Constructor used by the PassRegistration and enforce int8 quantization. + // This is only used by test. + explicit PreprocessOpPass() : op_set_(OpSet::UNIFORM_QUANTIZED) { + quant_specs_.inference_type = tensorflow::DT_QINT8; + } + + // Constructor used by manually creating the pass. + explicit PreprocessOpPass(const QuantizationSpecs& quant_specs, OpSet op_set) + : quant_specs_(quant_specs), op_set_(op_set) {} + + PreprocessOpPass(const PreprocessOpPass& other) { + quant_specs_ = other.quant_specs_; + op_set_ = other.op_set_; + } + + StringRef getArgument() const final { + // This is the argument used to refer to the pass in + // the textual format (on the commandline for example). + return "quant-preprocess-op"; + } + StringRef getDescription() const final { + // This is a brief description of the pass. + return "Preprocess TF op prior to quantization"; + } + + void runOnOperation() override; + + private: + QuantizationSpecs quant_specs_; + OpSet op_set_; +}; + +// Apply constant transformations for the op_set. +class PreprocessConstantOp : public OpRewritePattern { + public: + explicit PreprocessConstantOp(MLIRContext* context, OpSet op_set) + : OpRewritePattern(context), op_set_(op_set) {} + + LogicalResult matchAndRewrite(TF::PartitionedCallOp op, + PatternRewriter& rewriter) const override { + const auto f_attr = op.getFAttr().dyn_cast(); + // Non-quantizable op + if (!op->hasAttr(kQuantTraitAttrName)) return failure(); + StringRef function_name = f_attr.getValue(); + if (!function_name.startswith("composite_")) { + return failure(); + } + + std::unique_ptr spec = GetTFOpQuantSpec(op); + const absl::flat_hash_set operands = spec->quantizable_operands; + + if (function_name.contains("depthwise_conv2d")) { + // Uniform Quantized op requires weights of tf.DepthwiseConv2dNative to + // be transformed from [H,W,C,M] to [H,W,1,CxM] where + // H=height,W=width,C=channel,M=multiplier. Therefore, a reshape op is + // inserted between the constant op and the function op so that the + // constant is safely transformed for the multi-use cases as well. Note + // that bias doesn't need transformation as its shape is already in [CxM]. + if (operands.size() != 1) return failure(); + int weight_operand_idx = *operands.begin(); + Operation* weight_op = op.getOperand(weight_operand_idx).getDefiningOp(); + + if (op_set_ == OpSet::UNIFORM_QUANTIZED) { + DenseFPElementsAttr attr; + if (!matchPattern(weight_op->getResult(0), m_Constant(&attr))) { + return failure(); + } + + // Get new shape. + llvm::ArrayRef cur_shape = attr.getType().getShape(); + int cur_rank = cur_shape.size(); + if (cur_rank != 4 || cur_shape[2] == 1) return failure(); + TensorType new_shape = RankedTensorType::get( + {cur_shape[0], cur_shape[1], 1, cur_shape[2] * cur_shape[3]}, + attr.getElementType()); + + // Inserts a reshape op. 
+ auto shape_spec_type = + RankedTensorType::get({cur_rank}, rewriter.getIntegerType(64)); + auto new_shape_const_attr = + DenseElementsAttr::get(shape_spec_type, new_shape.getShape()); + rewriter.setInsertionPointAfter(weight_op); + auto new_shape_const = rewriter.create( + weight_op->getLoc(), shape_spec_type, new_shape_const_attr); + auto reshape_op = rewriter.create( + weight_op->getLoc(), new_shape, weight_op->getResult(0), + new_shape_const); + op->setOperand(weight_operand_idx, reshape_op); + + // Create a new function with preprocessed types. + ModuleOp module = op->getParentOfType(); + SymbolTable symbol_table(module); + func::FuncOp float_func = + dyn_cast(symbol_table.lookup(function_name)); + OperandRange func_args = op.getArgs(); + func::FuncOp new_float_func = float_func.clone(); + + SmallVector new_float_func_args{func_args.begin(), + func_args.end()}; + new_float_func_args[weight_operand_idx] = reshape_op; + new_float_func.getArgument(weight_operand_idx).setType(new_shape); + new_float_func.setType(FunctionType::get( + getContext(), TypeRange{ValueRange{new_float_func_args}}, + new_float_func.getResultTypes())); + symbol_table.insert(new_float_func); + + op->setAttr("f", SymbolRefAttr::get(rewriter.getContext(), + new_float_func.getName())); + return success(); + } + } + return failure(); + } + + private: + const OpSet op_set_; +}; + +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.inc" + +void PreprocessOpPass::runOnOperation() { + MLIRContext* ctx = &getContext(); + RewritePatternSet patterns(ctx); + ModuleOp module_op = getOperation(); + + populateWithGenerated(patterns); + patterns.add(ctx, op_set_); + FrozenRewritePatternSet frozen_patterns(std::move(patterns)); + + for (auto func : module_op.getOps()) { + if (failed(applyPatternsAndFoldGreedily(func, frozen_patterns))) { + func.emitError() << "quant-preprocess-op failed."; + signalPassFailure(); + } + } +} + +} // namespace + +// Creates an instance of the TensorFlow dialect PreprocessOp +// pass. +std::unique_ptr> CreatePreprocessOpPass( + const QuantizationSpecs& quant_specs, const OpSet op_set) { + return std::make_unique(quant_specs, op_set); +} + +static PassRegistration pass; + +} // namespace quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.td new file mode 100644 index 00000000000..66ff5f06752 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.td @@ -0,0 +1,28 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.td" +include "mlir/IR/OpBase.td" +include "mlir/IR/PatternBase.td" +include "mlir/Dialect/Func/IR/FuncOps.td" +include "mlir/Dialect/Arith/IR/ArithOps.td" +include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" + +// Converts tf.Const to arith.constant for statically shaped, non-opaque constants. +// Needed for QuantizationDriver to recognize constants. +def ConvertTfConstToArithConst : Pat< + (TF_ConstOp:$res DenseElementsAttr:$value), + (Arith_ConstantOp $value), + [(AnyStaticShapeTensor $res)], (addBenefit 10)>; \ No newline at end of file diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize.cc index 432d2c9720f..3be531c05d9 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Casting.h" @@ -52,23 +53,23 @@ namespace quant { //===----------------------------------------------------------------------===// // The actual Quantize Pass. -// +//===----------------------------------------------------------------------===// namespace { enum QuantizationTrait { kFullQuantization, kDynamicRangeQuantization }; // Base struct for quantization. -template +template struct TFQuantizationBase - : public QuantizationPattern { + /*VerifierT=*/void, RootOpT> { explicit TFQuantizationBase(MLIRContext* ctx, const QuantPassSpec& quant_params) - : QuantizationPattern(ctx, quant_params) {} + /*VerifierT=*/void, RootOpT>(ctx, quant_params) {} // Custom op quantization is not supported. static bool IsQuantizableCustomOp(Operation* op, @@ -90,11 +91,13 @@ struct TFQuantizationBase return quantization_trait == kDynamicRangeQuantization; } - // Weight-only quantization is not supported. - static bool IsWeightOnlyOp(Operation* quantized_op, StringSet& ops_blocklist, + // All the quantized ops are supported if the quantization method is weight + // only quantization. + static bool IsWeightOnlyOp(Operation* quantized_op, + absl::flat_hash_set& ops_blocklist, bool weight_only_quantization, const CustomMap& custom_op_map) { - return false; + return weight_only_quantization; } }; @@ -132,17 +135,16 @@ struct TFDynamicRangeQuantization // The benefit of this pattern is set to lower value than other patterns, so // that the other patterns can work on quantize/dequantize ops first. 
class RemoveUnusedQdqPattern - : public OpRewritePattern { + : public OpRewritePattern { public: explicit RemoveUnusedQdqPattern(MLIRContext* context) - : OpRewritePattern(context) {} - LogicalResult matchAndRewrite(quantfork::QuantizeCastOp op, + : OpRewritePattern(context) {} + LogicalResult matchAndRewrite(quantfork::DequantizeCastOp dq_op, PatternRewriter& rewriter) const override { - if (!op->hasOneUse() || - !llvm::isa(*op->getUsers().begin())) { - return failure(); - } - op->getUsers().begin()->getResult(0).replaceAllUsesWith(op.getArg()); + auto q_op = dq_op.getArg().getDefiningOp(); + if (!q_op) return failure(); + + dq_op.replaceAllUsesWith(q_op.getArg()); return success(); } }; @@ -151,11 +153,13 @@ class QuantizeSameScaleOpsPattern : public OpRewritePattern { public: explicit QuantizeSameScaleOpsPattern( - MLIRContext* context, OpQuantScaleSpecGetter op_quant_scale_spec_getter) + MLIRContext* context, OpQuantScaleSpecGetter op_quant_scale_spec_getter, + OpSet target_opset) // Set the score to a large number so it is always preferred, after // quantization patterns. : OpRewritePattern(context, /*benefit=*/200), - op_quant_scale_spec_getter_(op_quant_scale_spec_getter) {} + op_quant_scale_spec_getter_(op_quant_scale_spec_getter), + target_opset_(target_opset) {} LogicalResult matchAndRewrite(quantfork::DequantizeCastOp op, PatternRewriter& rewriter) const override { @@ -184,6 +188,16 @@ class QuantizeSameScaleOpsPattern continue; } + if (target_opset_ == OpSet::XLA && + !IsConnectedWithCompsiteFunction(quantizing_op)) { + continue; + } + + // Same scale op is not supported for Uniform Quantized ops. + if (target_opset_ == OpSet::UNIFORM_QUANTIZED) { + continue; + } + // Collect all the quantized inputs and "clone" the matched op by these // inputs. SmallVector inputs; @@ -263,7 +277,7 @@ class QuantizeSameScaleOpsPattern if (quantizing_op->getNumRegions() != 0) { for (const auto& indexed_regions : llvm::enumerate(quantizing_op->getRegions())) { - BlockAndValueMapping mapping; + IRMapping mapping; indexed_regions.value().cloneInto( &quantized_op->getRegion(indexed_regions.index()), mapping); } @@ -282,7 +296,105 @@ class QuantizeSameScaleOpsPattern } private: + // Checks whether the operation is connnected with a composite function. + // If not, the same-scale op will not be quantized. This decision is based + // on the current assumption that the performance gain of the same-scale + // op itself could not beat the overhead of the quantize and dequantize + // routines need to be added around that op. When the assumption changes, + // this policy might change as well. + bool IsConnectedWithCompsiteFunction(Operation* same_scale_op) const { + for (const auto& operand : same_scale_op->getOperands()) { + auto dq_op = dyn_cast_or_null( + operand.getDefiningOp()); + if (!dq_op) continue; + + Operation* preceding_op = dq_op.getArg().getDefiningOp(); + if (!preceding_op) continue; + + // Check whether the preceding op is a quantized composite function. + if (llvm::isa(preceding_op)) { + auto call_op = llvm::cast(preceding_op); + if (!IsCompositeFunction(call_op)) continue; + return true; + } + + // Check if the preceding op is a quantized same-scale op. + if (llvm::isa(preceding_op)) { + auto sc_op = llvm::cast(preceding_op); + auto sc_arg_type = sc_op.getArg().getType().dyn_cast(); + if (sc_arg_type.getElementType().isInteger(8)) { + return true; + } + } + } + + for (const auto& result : same_scale_op->getResults()) { + // If the user is the Quantize op, it must be the only user. 
+ if (!result.hasOneUse() || + !llvm::isa(*result.user_begin())) { + continue; + } + + auto q_op = llvm::cast(*result.user_begin()); + for (auto following_op : q_op->getUsers()) { + // Check whether the preceding op is a quantized composite function. + if (llvm::isa(following_op)) { + auto call_op = llvm::cast(following_op); + if (!IsCompositeFunction(call_op)) continue; + return true; + } + + // Check if the preceding op is a quantized same-scale op. + if (llvm::isa(following_op)) { + auto sc_op = llvm::cast(following_op); + auto sc_arg_type = sc_op.getResult().getType().dyn_cast(); + if (sc_arg_type.getElementType().isInteger(8)) { + return true; + } + } + } + } + + return false; + } + + // Checks if op calls a composite function and all the inputs are quantized. + bool IsCompositeFunction(TF::PartitionedCallOp call_op) const { + if (!call_op->hasAttr(kQuantTraitAttrName)) { + return false; + } + + const auto f_attr = call_op.getFAttr().dyn_cast(); + if (!f_attr || !f_attr.getValue().startswith("composite_")) { + return false; + } + + bool has_quantized_types = false; + for (Value input : call_op.getArgs()) { + if (auto type = input.getType().dyn_cast()) { + if (type.getElementType().isa()) { + return false; + } + if (type.getElementType().isa()) { + has_quantized_types = true; + } + } + } + for (Value output : call_op.getOutput()) { + if (auto type = output.getType().dyn_cast()) { + if (type.getElementType().isa()) { + return false; + } + if (type.getElementType().isa()) { + has_quantized_types = true; + } + } + } + return has_quantized_types; + } + OpQuantScaleSpecGetter op_quant_scale_spec_getter_; + OpSet target_opset_; }; // The AvgPool op is a same-scale op but it doesn't have int8 kernel, so @@ -290,54 +402,50 @@ class QuantizeSameScaleOpsPattern // TODO(b/229183248): Remove this workaround after int8 kernels have been // added to TF and XLA. struct QuantizeAvgPoolOpPattern - : public OpRewritePattern { + : public OpRewritePattern { explicit QuantizeAvgPoolOpPattern(MLIRContext* context) - : OpRewritePattern(context, /*benefit=*/300) {} + : OpRewritePattern(context, /*benefit=*/100) {} - LogicalResult matchAndRewrite(quantfork::QuantizeCastOp q_op, + LogicalResult matchAndRewrite(quantfork::StorageCastOp sc_op, PatternRewriter& rewriter) const override { - auto avg_pool_op = q_op.getArg().getDefiningOp(); + auto avg_pool_op = sc_op.getArg().getDefiningOp(); if (!avg_pool_op) return failure(); - auto dq_op = dyn_cast_or_null( - avg_pool_op.value().getDefiningOp()); - if (!dq_op) return failure(); + auto preceding_sc_op = dyn_cast_or_null( + avg_pool_op.getValue().getDefiningOp()); + if (!preceding_sc_op) return failure(); // Check if the same-scale requirement is met. - auto dq_arg_type = dq_op.getArg().getType().cast(); + auto dq_arg_type = preceding_sc_op.getArg().getType().cast(); auto qtype = dq_arg_type.getElementType().cast(); - auto q_result_type = q_op.getType().cast(); + auto q_result_type = sc_op.getType().cast(); auto out_qtype = q_result_type.getElementType().cast(); if (qtype != out_qtype) { avg_pool_op.emitError( - "The preceding DequantizeCastOp and the following " - "QuantizeCastOp must have the same quantized type"); + "The preceding StorageCastOp and the following " + "StorageCastOp must have the same quantized type"); return failure(); } // Cast to float type before the AvgPool op. 
OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPointAfter(dq_op); - auto scast_op = rewriter.create( - dq_op->getLoc(), dq_arg_type.clone(qtype.getStorageType()), - dq_op.getArg()); + rewriter.setInsertionPointAfter(preceding_sc_op); auto fcast_op = rewriter.create( - dq_op->getLoc(), dq_arg_type.clone(rewriter.getF32Type()), - scast_op.getResult()); - dq_op.getResult().replaceUsesWithIf(fcast_op.y(), [&](OpOperand& operand) { - return operand.getOwner() == avg_pool_op; - }); + preceding_sc_op->getLoc(), dq_arg_type.clone(rewriter.getF32Type()), + preceding_sc_op.getResult()); + + // Create a new AvgPool op with float type. + TF::AvgPoolOp float_avg_pool_op = rewriter.create( + avg_pool_op->getLoc(), + avg_pool_op.getType().clone(rewriter.getF32Type()), + /*operands=*/fcast_op.getResult(), + /*attributes=*/avg_pool_op->getAttrs()); // Cast back to the storage type after AvgPool op. - rewriter.setInsertionPointAfter(avg_pool_op); - auto const_val = CreateScalarConstValue(rewriter, q_op.getLoc(), 0.5f); - auto add_val = rewriter.create( - q_op.getLoc(), avg_pool_op.output(), const_val); - auto floor_val = rewriter.create(q_op.getLoc(), add_val); + auto round_val = rewriter.create( + sc_op.getLoc(), float_avg_pool_op.getOutput()); auto icast_op = rewriter.create( - q_op.getLoc(), q_result_type.clone(qtype.getStorageType()), floor_val); - auto iscast_op = rewriter.create( - q_op.getLoc(), q_op.getType(), icast_op.y()); - q_op.getResult().replaceAllUsesWith(iscast_op.getResult()); + sc_op.getLoc(), q_result_type.clone(qtype.getStorageType()), round_val); + avg_pool_op.getResult().replaceAllUsesWith(icast_op.getResult()); return success(); } }; @@ -354,13 +462,16 @@ class QuantizePass } // Constructor used by manually creating the pass. - explicit QuantizePass(const QuantizationSpecs& quant_specs) + explicit QuantizePass(const QuantizationSpecs& quant_specs, + OpSet target_opset) : quant_specs_(quant_specs) { weight_quantization_ = quant_specs.weight_quantization; + target_opset_ = target_opset; } QuantizePass(const QuantizePass& other) : quant_specs_(other.quant_specs_) { weight_quantization_ = other.weight_quantization_; + target_opset_ = other.target_opset_; } StringRef getArgument() const final { @@ -373,6 +484,10 @@ class QuantizePass return "Apply quantization on models in TensorFlow dialect"; } + // Determine if the unused Q-DQ pairs need to be removed. For weight-only + // quantizable ops, Q-DQ ops need to be preserved. 
+ bool shouldKeepUnusedQdqPattern(); + void runOnOperation() override; private: @@ -381,8 +496,23 @@ class QuantizePass Option weight_quantization_{ *this, "weight-quantization", llvm::cl::init(false), llvm::cl::desc("Whether to enable weight quantization.")}; + Option target_opset_{ + *this, "target-opset", llvm::cl::init(OpSet::TF), + llvm::cl::desc("Choose target opset."), + llvm::cl::values( + clEnumValN(OpSet::TF, "TF", + "Uses TF ops that mimic quantization behavior"), + clEnumValN(OpSet::XLA, "XLA", "Uses TF XLA ops"), + clEnumValN(OpSet::UNIFORM_QUANTIZED, "UNIFORM_QUANTIZED", + "Uses TF Uniform Quantized ops"))}; }; +bool QuantizePass::shouldKeepUnusedQdqPattern() { + return target_opset_ == OpSet::XLA && + (quant_specs_.weight_only_quantization || + quant_specs_.weight_quantization); +} + void QuantizePass::runOnOperation() { RewritePatternSet patterns(&getContext()); auto func = getOperation(); @@ -399,26 +529,29 @@ void QuantizePass::runOnOperation() { } else { patterns.add(ctx, quant_params); - patterns.add(ctx, GetTfQuantScaleSpec); + patterns.add(ctx, GetTfQuantScaleSpec, + target_opset_); patterns.add(ctx); } (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); - RewritePatternSet patterns_2(&getContext()); - patterns_2.add(ctx); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns_2)); + if (!shouldKeepUnusedQdqPattern()) { + RewritePatternSet patterns_2(&getContext()); + patterns_2.add(ctx); + (void)applyPatternsAndFoldGreedily(func, std::move(patterns_2)); + } } } // namespace // Creates an instance of the TensorFlow dialect Quantize pass. std::unique_ptr> CreateQuantizePass() { QuantizationSpecs quant_specs; - return std::make_unique(quant_specs); + return std::make_unique(quant_specs, OpSet::TF); } std::unique_ptr> CreateQuantizePass( - QuantizationSpecs quant_specs) { - return std::make_unique(quant_specs); + QuantizationSpecs quant_specs, OpSet target_opset) { + return std::make_unique(quant_specs, target_opset); } static PassRegistration pass; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc index eabb26b51f7..1e1686efe1c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc @@ -12,13 +12,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include +#include #include #include +#include "absl/algorithm/container.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "llvm/ADT/Sequence.h" +#include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/Casting.h" @@ -36,13 +42,17 @@ limitations under the License. 
#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_uniform_attribute_utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" @@ -52,9 +62,16 @@ namespace mlir { namespace quant { namespace { -constexpr char kQuantizeFuncName[] = "quantize_i8"; -constexpr char kDequantizeFuncName[] = "dequantize_i8"; -constexpr char kAttrMapAttribute[] = "attr_map"; +using QuantMethod = + tensorflow::quantization::QuantizationMethod::ExperimentalMethod; + +constexpr StringRef kQuantizeFuncName = "quantize_i8"; +constexpr StringRef kDequantizeFuncName = "dequantize_i8"; +constexpr StringRef kAttrMapAttribute = "attr_map"; +constexpr StringRef kQuantizedOpsAttribute = "tf_quant.quantized_ops"; +constexpr StringRef kCompositeFuncPrefix = "composite_"; +constexpr StringRef kQuantizedFuncPrefix = "quantized_"; +constexpr StringRef kFloatOutputFuncPrefix = "_float_output_fn"; class QuantizeCompositeFunctionsPass : public mlir::PassWrapper quantization_method_{ + Option quantization_method_{ *this, "quantization-method", - llvm::cl::init(QuantizationMethod::kPostTrainingQuantization), + llvm::cl::init( + tensorflow::quantization::QuantizationMethod::STATIC_RANGE), llvm::cl::desc("Choose quantization method."), llvm::cl::values( - clEnumValN(QuantizationMethod::kPostTrainingQuantization, "ptq", - "Post-training static-range quantization"), - clEnumValN(QuantizationMethod::kDynamicRangeQuantization, "drq", - "Post-training dynamic-range quantizaiton"))}; + clEnumValN(tensorflow::quantization::QuantizationMethod::STATIC_RANGE, + "ptq", "Post-training static-range quantization"), + clEnumValN( + tensorflow::quantization::QuantizationMethod::DYNAMIC_RANGE, + "drq", "Post-training dynamic-range quantizaiton"), + clEnumValN(tensorflow::quantization::QuantizationMethod::WEIGHT_ONLY, + "weight_only", "Post-training weight-only quantizaiton"))}; + Option target_opset_{ *this, "target-opset", llvm::cl::init(OpSet::TF), llvm::cl::desc("Choose target opset."), @@ -113,6 +141,12 @@ class QuantizeCompositeFunctionsPass clEnumValN(OpSet::XLA, "XLA", "Uses TF XLA ops"), clEnumValN(OpSet::UNIFORM_QUANTIZED, "UNIFORM_QUANTIZED", "Uses TF Uniform Quantized ops"))}; + + Option enable_per_channel_quantization_{ + *this, "enable-per-channel-quantization", llvm::cl::init(false), + llvm::cl::desc("Whether enable per-channel quantized weights.")}; + + int min_num_elements_for_weights_; }; LogicalResult CreateUniformQuantizedTypeParams(UniformQuantizedType qtype, @@ -178,14 +212,55 @@ 
LogicalResult CreateQuantizationParams(QuantizedType elem_type, Location loc, return failure(); } +// Converts the element type of the input tensor to the corresponding quantized +// version. Supports only int8 for now and returns nullptr if the input type is +// not supported. +ShapedType ConvertIntToQint(ShapedType input_type, MLIRContext* ctx) { + int bit_width; + bool is_signed; + + Type ele_type = input_type.getElementType(); + if (ele_type.isIntOrFloat()) { + bit_width = ele_type.getIntOrFloatBitWidth(); + is_signed = ele_type.isSignlessIntOrFloat() || ele_type.isSignedInteger(); + } else if (QuantizedType qtype = ele_type.dyn_cast()) { + bit_width = qtype.getStorageTypeIntegralWidth(); + is_signed = qtype.isSigned(); + } else { + return input_type; + } + + Type new_storage_type; + if (is_signed) { + switch (bit_width) { + case 8: + new_storage_type = TF::Qint8Type::get(ctx); + break; + case 32: + new_storage_type = TF::Qint32Type::get(ctx); + break; + default: + return nullptr; // Not yet supported + } + } else { + return nullptr; // Not yet supported + } + + input_type = input_type.clone(new_storage_type); + return input_type; +} + // Replaces quant.qcast op to composite quantize_i8 function. class ReplaceQuantizePattern : public mlir::OpRewritePattern { public: - explicit ReplaceQuantizePattern(MLIRContext* context) - : OpRewritePattern(context) {} + explicit ReplaceQuantizePattern(MLIRContext* context, OpSet target_opset) + : OpRewritePattern(context), + target_opset_(target_opset) {} private: + OpSet target_opset_ = OpSet::TF; + LogicalResult matchAndRewrite(quantfork::QuantizeCastOp q_op, PatternRewriter& rewriter) const override { auto output_type = q_op.getType().cast(); @@ -198,8 +273,21 @@ class ReplaceQuantizePattern return failure(); } - SmallVector output_types = { - output_type.clone(elem_type.getStorageType())}; + SmallVector output_types; + + if (target_opset_ == OpSet::UNIFORM_QUANTIZED) { + ShapedType new_output_type = ConvertIntToQint( + output_type.cast(), rewriter.getContext()); + if (!new_output_type) { + q_op->emitError( + "Failed to convert the type to the corresponding qtype."); + return failure(); + } + output_types = {new_output_type}; + } else { + output_types = {output_type.clone(elem_type.getStorageType())}; + } + SmallVector args = {q_op.getArg(), scale, zero_point}; FlatSymbolRefAttr func_name = FlatSymbolRefAttr::get(rewriter.getStringAttr(kQuantizeFuncName)); @@ -218,10 +306,13 @@ class ReplaceQuantizePattern class ReplaceDequantizePattern : public mlir::OpRewritePattern { public: - explicit ReplaceDequantizePattern(MLIRContext* context) - : OpRewritePattern(context) {} + explicit ReplaceDequantizePattern(MLIRContext* context, OpSet target_opset) + : OpRewritePattern(context), + target_opset_(target_opset) {} private: + OpSet target_opset_ = OpSet::TF; + LogicalResult matchAndRewrite(quantfork::DequantizeCastOp dq_op, PatternRewriter& rewriter) const override { auto input_type = dq_op.getArg().getType().cast(); @@ -235,6 +326,17 @@ class ReplaceDequantizePattern } TensorType output_type = input_type.clone(elem_type.getStorageType()); + if (target_opset_ == OpSet::UNIFORM_QUANTIZED) { + ShapedType new_output_type = ConvertIntToQint( + output_type.cast(), rewriter.getContext()); + if (!new_output_type) { + dq_op->emitError( + "Failed to convert the type to the corresponding qtype."); + return failure(); + } + output_type = new_output_type.cast(); + } + auto scast_op = rewriter.create(loc, output_type, dq_op.getArg()); @@ -254,10 +356,10 @@ class 
ReplaceDequantizePattern // index information for each op. bool IsQuantizedCallforDynamicRange(TF::PartitionedCallOp call_op) { bool has_quantized_types_for_weights = false; - for (int32_t cur_idx = 0; cur_idx < call_op.args().size(); cur_idx++) { + for (int32_t cur_idx = 0; cur_idx < call_op.getArgs().size(); cur_idx++) { // Check if the only the weight index has QuantizeCastOp. auto cur_op = dyn_cast_or_null( - call_op.args()[cur_idx].getDefiningOp()); + call_op.getArgs()[cur_idx].getDefiningOp()); if ((!cur_op && cur_idx == 1) || (cur_op && cur_idx != 1)) { return false; } else if (cur_op) { @@ -270,7 +372,7 @@ bool IsQuantizedCallforDynamicRange(TF::PartitionedCallOp call_op) { has_quantized_types_for_weights = true; } } - for (Value output : call_op.output()) { + for (Value output : call_op.getOutput()) { if (auto type = output.getType().dyn_cast()) { if (type.getElementType().isa()) { return false; @@ -283,7 +385,7 @@ bool IsQuantizedCallforDynamicRange(TF::PartitionedCallOp call_op) { // Checks if all the inputs are quantized. bool IsQuantizedCallforStaticRange(TF::PartitionedCallOp call_op) { bool has_quantized_types = false; - for (Value input : call_op.args()) { + for (Value input : call_op.getArgs()) { if (auto type = input.getType().dyn_cast()) { if (type.getElementType().isa()) { return false; @@ -293,7 +395,7 @@ bool IsQuantizedCallforStaticRange(TF::PartitionedCallOp call_op) { } } } - for (Value output : call_op.output()) { + for (Value output : call_op.getOutput()) { if (auto type = output.getType().dyn_cast()) { if (type.getElementType().isa()) { return false; @@ -306,39 +408,67 @@ bool IsQuantizedCallforStaticRange(TF::PartitionedCallOp call_op) { return has_quantized_types; } -// Converts the element type of the input tensor to the corresponding quantized -// version. Supports only int8 for now and returns nullptr if the input type is -// not supported. -ShapedType ConvertIntToQint(ShapedType input_type, MLIRContext* ctx) { - int bit_width; - bool is_signed; +// Transfers the attributes of the corresponding ops from the float function to +// the quantized function using the attr_map attribute. In the quantized +// function, this map (map1) is in {attr_name_1: attr_identifier} format; and in +// the float function, this map (map2) is in {attr_identifier: attr_name_2} +// format. Where, the attribute identifiers should match between two maps, +// attr_name_1 is the name of the of the attribute needs to be set in the +// quantized function, attr_name_2 is the name of the attribute corresponding to +// the attribute identifier in the float function. +LogicalResult TransferTFAttributesToTFUniformAttributes( + PatternRewriter& rewriter, func::FuncOp float_func, + func::FuncOp quantized_func, QuantMethod quantization_method, + bool enable_per_channel_quantization) { + // A map to find an attribute from its identifier. + llvm::StringMap identifier_to_attr; - Type ele_type = input_type.getElementType(); - if (ele_type.isIntOrFloat()) { - bit_width = ele_type.getIntOrFloatBitWidth(); - is_signed = ele_type.isSignlessIntOrFloat() || ele_type.isSignedInteger(); - } else if (QuantizedType qtype = ele_type.dyn_cast()) { - bit_width = qtype.getStorageTypeIntegralWidth(); - is_signed = qtype.isSigned(); - } else { - return input_type; + for (Operation& inner_op : float_func.getBody().front().getOperations()) { + if (!inner_op.hasAttr(kAttrMapAttribute)) continue; + // Insert quantization related attribute if they exists. 
Quantization + // attributes are generated in the prepare pass so the attr_map doesn't + // contain the attribute names. + // TransferQuantizationAttributes(rewriter, inner_op, attrs); + std::string attr_map_str = + inner_op.getAttrOfType(kAttrMapAttribute).str(); + for (absl::string_view element_str : absl::StrSplit(attr_map_str, ',')) { + std::vector key_and_value_pair = + absl::StrSplit(element_str, ':'); + if (key_and_value_pair.size() != 2) { + float_func.emitError("The attr_map attribute is malformed"); + return failure(); + } + identifier_to_attr.insert( + {llvm::StringRef(std::string(key_and_value_pair[1])), + inner_op.getAttr( + llvm::StringRef(std::string(key_and_value_pair[1])))}); + } } - Type new_storage_type; - if (is_signed) { - switch (bit_width) { - case 8: - new_storage_type = mlir::TF::Qint8Type::get(ctx); - break; - default: - return nullptr; // Not yet supported + // Set the attributes for ops with the attr_map attribute. + for (Operation& inner_op : quantized_func.getBody().front().getOperations()) { + if (auto uniform_op = + llvm::dyn_cast(inner_op)) { + if (failed(FillAttributesForUniformQuantizedConvolutionOp( + rewriter, uniform_op, identifier_to_attr, quantization_method, + enable_per_channel_quantization))) + return failure(); + } else if (auto uniform_op = + llvm::dyn_cast( + inner_op)) { + if (failed(FillAttributesForUniformQuantizedConvolutionOp( + rewriter, uniform_op, identifier_to_attr, quantization_method, + enable_per_channel_quantization))) + return failure(); + } else if (auto uniform_op = + llvm::dyn_cast(inner_op)) { + if (failed(FillAttributesForUniformQuantizedDotOp( + rewriter, uniform_op, identifier_to_attr, quantization_method, + enable_per_channel_quantization))) + return failure(); } - } else { - return nullptr; // Not yet supported } - - input_type = input_type.clone(new_storage_type); - return input_type; + return success(); } // Transfers the attributes of the corresponding ops from the float function to @@ -400,26 +530,41 @@ LogicalResult TransferAttributes(func::FuncOp float_func, return success(); } +// Get the corresponding quantized function name from the given function name. +std::string GetQuantizedFunctionName(StringRef func_name) { + if (func_name.startswith(kQuantizedFuncPrefix)) return func_name.str(); + if (!func_name.startswith(kCompositeFuncPrefix)) return ""; + + return llvm::Twine(kQuantizedFuncPrefix) + .concat(llvm::Twine( + func_name.substr(kCompositeFuncPrefix.size()).rsplit("_fn").first)) + .concat("_fn") + .str(); +} + // Unwraps quantization parameters of PartitionedCall ops with quantized // input/outputs that are created from QuantizePass. 
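The GetQuantizedFunctionName helper introduced in this hunk maps a composite wrapper name to its quantized counterpart by dropping the uniquifying suffix after the last "_fn" and swapping the "composite_" prefix for "quantized_". A plain-C++ sketch of the same mapping, with an illustrative helper name and made-up examples:

#include <string>
#include <string_view>

// e.g. "composite_conv2d_with_bias_fn_1" -> "quantized_conv2d_with_bias_fn".
std::string QuantizedNameFor(std::string_view func_name) {
  constexpr std::string_view kComposite = "composite_";
  constexpr std::string_view kQuantized = "quantized_";
  if (func_name.substr(0, kQuantized.size()) == kQuantized)
    return std::string(func_name);
  if (func_name.substr(0, kComposite.size()) != kComposite) return "";
  std::string_view stem = func_name.substr(kComposite.size());
  // Keep everything up to the last "_fn"; the trailing "_<N>" is dropped.
  size_t pos = stem.rfind("_fn");
  if (pos == std::string_view::npos) pos = stem.size();
  return std::string(kQuantized) + std::string(stem.substr(0, pos)) + "_fn";
}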
class QuantizeFunctionPattern : public mlir::OpRewritePattern { public: explicit QuantizeFunctionPattern(MLIRContext* context, - QuantizationMethod quantization_method, - OpSet target_opset) + QuantMethod quantization_method, + OpSet target_opset, + bool enable_per_channel_quantization) : OpRewritePattern(context), quantization_method_(quantization_method), - target_opset_(target_opset) {} + target_opset_(target_opset), + enable_per_channel_quantization_(enable_per_channel_quantization) {} private: - QuantizationMethod quantization_method_ = - QuantizationMethod::kPostTrainingQuantization; + QuantMethod quantization_method_ = + tensorflow::quantization::QuantizationMethod::STATIC_RANGE; OpSet target_opset_ = OpSet::TF; + bool enable_per_channel_quantization_; LogicalResult matchAndRewrite(TF::PartitionedCallOp call_op, PatternRewriter& rewriter) const override { - const auto f_attr = call_op.fAttr().dyn_cast(); + const auto f_attr = call_op.getFAttr().dyn_cast(); // removeAttr will return nullptr if no attribute was removed. if (!call_op->removeAttr(kQuantTraitAttrName) || !f_attr) { return failure(); @@ -427,9 +572,11 @@ class QuantizeFunctionPattern // Determines if all required float input/outputs are now quantized. bool has_quantized_types = false; - if (quantization_method_ == QuantizationMethod::kDynamicRangeQuantization) { + if (quantization_method_ == + tensorflow::quantization::QuantizationMethod::DYNAMIC_RANGE) { has_quantized_types = IsQuantizedCallforDynamicRange(call_op); - if (f_attr.getValue().startswith("composite_") && !has_quantized_types) { + if (f_attr.getValue().startswith(kCompositeFuncPrefix) && + !has_quantized_types) { call_op->emitError( "Only quantizable ops need to be in composite function for dynamic" "-range PTQ case."); @@ -439,13 +586,14 @@ class QuantizeFunctionPattern has_quantized_types = IsQuantizedCallforStaticRange(call_op); } - if (!f_attr.getValue().startswith("composite_") || !has_quantized_types) { + if (!f_attr.getValue().startswith(kCompositeFuncPrefix) || + !has_quantized_types) { return failure(); } SmallVector args; SmallVector qparam_args; - for (Value arg : call_op.args()) { + for (Value arg : call_op.getArgs()) { if (const auto arg_type = arg.getType().dyn_cast()) { QuantizedType qtype = arg_type.getElementType().dyn_cast(); @@ -491,7 +639,7 @@ class QuantizeFunctionPattern rewriter.setInsertionPoint(call_op); - for (Value arg : call_op.args()) { + for (Value arg : call_op.getArgs()) { TensorType arg_type = arg.getType().dyn_cast(); if (!arg_type) { args.push_back(arg); @@ -504,8 +652,7 @@ class QuantizeFunctionPattern } quantfork::StorageCastOp scast_op; - if (quantization_method_ == - QuantizationMethod::kDynamicRangeQuantization) { + if (target_opset_ == OpSet::UNIFORM_QUANTIZED) { ShapedType new_arg_type = ConvertIntToQint(arg_type.cast(), rewriter.getContext()); if (!new_arg_type) { @@ -548,11 +695,17 @@ class QuantizeFunctionPattern result_types.push_back(result_type); continue; } + + if (target_opset_ == OpSet::UNIFORM_QUANTIZED) { + ShapedType new_result_type = ConvertIntToQint( + result_type.cast(), rewriter.getContext()); + result_types.push_back(new_result_type); + } else { + result_types.push_back(result_type.clone(qtype.getStorageType())); + } auto scast_op = rewriter.create( call_op.getLoc(), result_type, result); replace_map.insert(std::make_pair(result, scast_op)); - - result_types.push_back(result_type.clone(qtype.getStorageType())); } for (auto replace_pair : replace_map) { @@ -569,16 +722,16 @@ class 
QuantizeFunctionPattern dyn_cast(symbol_table.lookup(f_attr.getValue())); rewriter.setInsertionPointAfter(float_func); - // substr(10) == strip the "composite_" prefix. - const llvm::Twine quantized_function_name = llvm::Twine( - "quantized_", f_attr.getValue().substr(10).rsplit('_').first); - const mlir::func::FuncOp quantized_func = dyn_cast( - symbol_table.lookup(quantized_function_name.str())); + const std::string quantized_function_name = + GetQuantizedFunctionName(f_attr.getValue()); + const mlir::func::FuncOp quantized_func = + dyn_cast(symbol_table.lookup(quantized_function_name)); mlir::func::FuncOp new_quantized_func = dyn_cast(quantized_func->clone()); if (new_quantized_func == nullptr) { return failure(); } + new_quantized_func.setType( FunctionType::get(getContext(), TypeRange{ValueRange{args}}, new_quantized_func.getResultTypes())); @@ -588,8 +741,16 @@ class QuantizeFunctionPattern } // Set the attributes for ops with the attr_map attribute. - if (failed(TransferAttributes(float_func, new_quantized_func))) { - return failure(); + if (target_opset_ == OpSet::UNIFORM_QUANTIZED) { + if (failed(TransferTFAttributesToTFUniformAttributes( + rewriter, float_func, new_quantized_func, quantization_method_, + enable_per_channel_quantization_))) { + return failure(); + } + } else { + if (failed(TransferAttributes(float_func, new_quantized_func))) { + return failure(); + } } rewriter.setInsertionPoint(call_op); @@ -639,15 +800,18 @@ class QuantizeFunctionPattern auto module = call_op->getParentOfType(); SymbolTable symbol_table(module); - const auto f_attr = call_op.fAttr().dyn_cast(); + const auto f_attr = call_op.getFAttr().dyn_cast(); const auto float_func = dyn_cast(symbol_table.lookup(f_attr.getValue())); rewriter.setInsertionPointAfter(float_func); - // substr(10) == strip the "composite_" prefix. - const std::string quantized_function_name = - "quantized_" + f_attr.getValue().substr(10).rsplit("_fn_").first.str() + - "_float_output_fn"; + // the length of the "_fn" suffix. + const size_t fn_suffix_length = 3; + std::string quantized_function_name = + GetQuantizedFunctionName(f_attr.getValue()); + quantized_function_name.replace( + quantized_function_name.size() - fn_suffix_length, fn_suffix_length, + kFloatOutputFuncPrefix); const auto quantized_func = dyn_cast(symbol_table.lookup(quantized_function_name)); auto new_quantized_func = dyn_cast(quantized_func->clone()); @@ -663,8 +827,16 @@ class QuantizeFunctionPattern } // Set the attributes for ops with the attr_map attribute. 
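The attr_map attribute mentioned above is a comma-separated list of colon-separated pairs (for example, the BatchMatMulV2 helper added later in this patch carries attr_map = "adj_x:0,adj_y:1"); the transfer helpers split it first on ',' and then on ':'. A standalone sketch of that parsing step, with the sample map taken from that helper:

#include <iostream>
#include <string>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_split.h"

int main() {
  // Pairs of attribute name and identifier, split on ',' and then ':'.
  const std::string attr_map = "adj_x:0,adj_y:1";
  absl::flat_hash_map<std::string, std::string> parsed;
  for (absl::string_view element : absl::StrSplit(attr_map, ',')) {
    const std::vector<std::string> key_and_value = absl::StrSplit(element, ':');
    if (key_and_value.size() != 2) continue;  // malformed entry
    parsed[key_and_value[0]] = key_and_value[1];
  }
  for (const auto& [key, value] : parsed)
    std::cout << key << " -> " << value << "\n";
  return 0;
}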
- if (failed(TransferAttributes(float_func, new_quantized_func))) { - return failure(); + if (target_opset_ == OpSet::UNIFORM_QUANTIZED) { + if (failed(TransferTFAttributesToTFUniformAttributes( + rewriter, float_func, new_quantized_func, quantization_method_, + enable_per_channel_quantization_))) { + return failure(); + } + } else { + if (failed(TransferAttributes(float_func, new_quantized_func))) { + return failure(); + } } rewriter.setInsertionPoint(call_op); @@ -696,14 +868,11 @@ class QuantizeConstPattern : public OpRewritePattern { public: // This pattern should have larger benefit than ReplaceQuantizePattern - explicit QuantizeConstPattern(MLIRContext* context, - QuantizationMethod quantization_method) + explicit QuantizeConstPattern(MLIRContext* context, OpSet target_opset) : OpRewritePattern(context, /*benefit=*/10), - quantization_method_(quantization_method) {} + target_opset_(target_opset) {} private: - QuantizationMethod quantization_method_ = - QuantizationMethod::kPostTrainingQuantization; LogicalResult matchAndRewrite(quantfork::QuantizeCastOp q_op, PatternRewriter& rewriter) const override { DenseFPElementsAttr attr; @@ -721,11 +890,9 @@ class QuantizeConstPattern tensor_qtype.getElementType().cast().getStorageType(); ShapedType new_type = tensor_qtype.clone(storage_type); Location loc = q_op.getArg().getLoc(); - // Convert integer to quantized integer type. Currently only applied for - // dynamic range quantization case. - if (quantization_method_ == QuantizationMethod::kDynamicRangeQuantization) { + + if (target_opset_ == OpSet::UNIFORM_QUANTIZED) { new_type = ConvertIntToQint(new_type, rewriter.getContext()); - tensor_qtype = ConvertIntToQint(tensor_qtype, rewriter.getContext()); // TODO(b/225793355): It adds TensorProtoAttr to the constant as a // workaround. @@ -735,7 +902,12 @@ class QuantizeConstPattern return failure(); } - tensor_proto.set_dtype(tensorflow::DT_QINT8); + const int bit_width = tensor_qtype.getElementType() + .dyn_cast() + .getStorageTypeIntegralWidth(); + + tensor_proto.set_dtype((bit_width == 8) ? tensorflow::DT_QINT8 + : tensorflow::DT_QINT32); tensor_proto_attr = ElementsAttr(TF::TensorProtoAttr::get( new_type, tensorflow::mangling_util::MangleTensor(tensor_proto))); @@ -745,10 +917,161 @@ class QuantizeConstPattern // Add scast op to match quantize -> composition pattern. The added scast // is then removed by canonicalization. ([scast - scast] -> []) auto scast_op = rewriter.create( - loc, tensor_qtype, const_op.output()); + loc, tensor_qtype, const_op.getOutput()); q_op->replaceAllUsesWith(scast_op); return success(); } + + OpSet target_opset_; +}; + +// Prints a summary about the quantization results. 
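QuantizeConstPattern above now derives the mangled TensorProto dtype from the storage bit width of the quantized element type instead of hardcoding DT_QINT8. A small sketch of that selection; the helper name is illustrative:

#include "tensorflow/core/framework/types.pb.h"

// Illustrative helper: map the storage bit width of a quantized element type
// to the TensorFlow quantized dtype used for the mangled tensor proto. The
// pattern above only distinguishes 8-bit from 32-bit storage.
tensorflow::DataType QuantizedDTypeForBitWidth(int bit_width) {
  return bit_width == 8 ? tensorflow::DT_QINT8 : tensorflow::DT_QINT32;
}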
+class QuantizationSummary { + public: + explicit QuantizationSummary(ModuleOp module) + : module_(module), symbol_table_(module) {} + + void Print() { + llvm::StringMap func_count_map; + int32_t total_quantized_func_count = 0, float_output_func_count = 0, + quantize_func_count = 0, dequantize_func_count = 0, + weight_only_count = 0; + + module_.walk([&](Operation* op) { + if (auto call_op = llvm::dyn_cast_or_null(op)) { + const auto f_attr = call_op.getFAttr().dyn_cast(); + if (!f_attr) return; + StringRef func_name = f_attr.getValue(); + if (func_name.startswith(kQuantizedFuncPrefix)) { + auto representative_name = GetRepresentativeName(func_name); + if (failed(representative_name)) return; + + func_count_map[representative_name.value()].num_quant++; + total_quantized_func_count++; + if (func_name.contains(kFloatOutputFuncPrefix)) { + float_output_func_count++; + } + } else if (func_name.startswith(kCompositeFuncPrefix)) { + auto representative_name = GetRepresentativeName(func_name); + if (failed(representative_name)) { + // TODO(b/264507511): Print quantization summary for weight-only. + weight_only_count++; + } else { + func_count_map[representative_name.value()].num_float++; + } + } else if (func_name.startswith("quantize_i")) { + quantize_func_count++; + } else if (func_name.startswith("dequantize_i")) { + dequantize_func_count++; + } + } else if (auto einsum = llvm::isa(op)) { + if (IsInCompsiteFunction(op)) return; + // Leftover Einsum ops are always non-quantized. + auto op_name = op->getName().stripDialect(); + func_count_map[op_name].num_float++; + } + }); + + // Pad string to a certain size to format the table. Space is preferred to + // Tab since it is easier to check the format in the mlir tests. + auto pad_string = [](StringRef s, int32_t width) -> std::string { + return llvm::Twine(s).concat(std::string(width - s.size(), ' ')).str(); + }; + + // Generate a quantization report. + size_t name_col_width = 5; + absl::c_for_each(func_count_map.keys(), [&name_col_width](const auto& key) { + name_col_width = std::max(name_col_width, key.size() + 1); + }); + + std::vector lines; + lines.push_back("-------- Quantization Summary --------"); + lines.push_back("Number of quantized layers in the model"); + lines.push_back("--------------------------------"); + lines.push_back( + absl::StrFormat("%s Count/Total", pad_string("Name", name_col_width))); + lines.push_back("================================"); + for (StringRef op_name : func_count_map.keys()) { + const int32_t quantized_count = func_count_map[op_name].num_quant; + const int32_t total_count = + quantized_count + func_count_map[op_name].num_float; + lines.push_back(absl::StrFormat("%s %d/%d", + pad_string(op_name, name_col_width), + quantized_count, total_count)); + } + lines.push_back(""); + lines.push_back(absl::StrFormat( + "Number of quantized layers with quantized outputs: %d/%d", + total_quantized_func_count - float_output_func_count, + total_quantized_func_count)); + lines.push_back(absl::StrFormat("Number of quantize layers added: %d", + quantize_func_count)); + lines.push_back(absl::StrFormat("Number of dequantize layers added: %d", + dequantize_func_count)); + lines.push_back(""); + + // Make the report visible by default. + const std::string log_message = + absl::StrJoin(lines.begin(), lines.end(), /*separator=*/"\n"); + llvm::errs() << log_message; + + // Create a FuncOp and attach the quantization summary to it. This is a + // a hack to check the summary in mlir tests. 
This function will be + // automatically removed since this pass is always followed by the Symbol + // DCE pass. + OpBuilder builder(module_); + builder.setInsertionPointToEnd(&module_.getBodyRegion().back()); + const auto func_type = + builder.getFunctionType(/*inputs=*/{}, /*results=*/{}); + auto summary_func = builder.create( + builder.getUnknownLoc(), /*sym_name=*/"summary", func_type); + summary_func.setPrivate(); + summary_func->setAttr("quantization_summary", + builder.getStringAttr(log_message)); + } + + private: + // Structs used to count quantized and non-quantized ops. + struct OpCountItem { + int32_t num_quant = 0; + int32_t num_float = 0; + }; + + // Get the representative name attribute value of a composite function. + FailureOr GetRepresentativeName(StringRef func_name) { + std::string quantized_func_name = GetQuantizedFunctionName(func_name); + auto quantized_func = dyn_cast_or_null( + symbol_table_.lookup(quantized_func_name)); + // Quantized function does not exist for weight-only case. + if (!quantized_func || + !quantized_func->hasAttrOfType(kQuantizedOpsAttribute)) { + return failure(); + } + + auto quantized_ops = + quantized_func->getAttrOfType(kQuantizedOpsAttribute) + .getValue(); + if (quantized_ops.empty()) { + quantized_func->emitError() << "At least one op is expected in the " + << kQuantizedOpsAttribute << " attribute."; + return failure(); + } + + // Use the first op as the representative name. + return quantized_ops.front().cast().getValue(); + } + + bool IsInCompsiteFunction(Operation* op) { + func::FuncOp parent = op->getParentOfType(); + if (!parent) return false; + + StringRef sym_name = parent.getSymName(); + return sym_name.startswith(kQuantizedFuncPrefix) || + sym_name.startswith(kCompositeFuncPrefix); + } + + ModuleOp module_; + SymbolTable symbol_table_; }; static PassRegistration pass; @@ -766,47 +1089,71 @@ void QuantizeCompositeFunctionsPass::runOnOperation() { pm.enableVerifier(false); QuantizationSpecs quant_specs; - if (quantization_method_ == QuantizationMethod::kDynamicRangeQuantization) { + pm.addPass(CreatePreprocessOpPass(quant_specs, target_opset_)); + + quant_specs.inference_type = tensorflow::DT_QINT8; + quant_specs.disable_per_channel = !enable_per_channel_quantization_; + // Apply activation-weight quantization. + if (quantization_method_ == + tensorflow::quantization::QuantizationMethod::STATIC_RANGE) { + pm.addNestedPass( + CreatePrepareQuantizePass(quant_specs, quantization_method_)); + pm.addNestedPass( + CreateQuantizePass(quant_specs, target_opset_)); + pm.addNestedPass(CreatePostQuantizePass()); + } + if ((quantization_method_ != + tensorflow::quantization::QuantizationMethod::STATIC_RANGE) || + (target_opset_ == OpSet::XLA)) { + // Apply weight quantization. + // For XLA case, weight quantization will be applied for the remaining f32 + // weights even in SRQ. 
+ quant_specs.minimum_elements_for_weights = min_num_elements_for_weights_; quant_specs.weight_quantization = true; - quant_specs.inference_type = tensorflow::DT_QINT8; - pm.addNestedPass(CreatePrepareQuantizeDRQPass()); - } else { + pm.addPass(CreatePrepareQuantizeDRQPass(quant_specs, target_opset_)); + if (quantization_method_ != + tensorflow::quantization::QuantizationMethod::DYNAMIC_RANGE) { + quant_specs.weight_only_quantization = true; + } pm.addNestedPass( - CreatePrepareQuantizePass(quantization_method_)); + CreateQuantizePass(quant_specs, target_opset_)); + pm.addNestedPass(CreatePostQuantizePass()); } - pm.addNestedPass(CreateQuantizePass(quant_specs)); - - pm.addNestedPass(CreatePostQuantizePass()); if (failed(pm.run(module))) { signalPassFailure(); } RewritePatternSet patterns(ctx); patterns.add(ctx, quantization_method_, - target_opset_); + target_opset_, + enable_per_channel_quantization_); if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) { signalPassFailure(); } // Constant quantization is a lossy transformation, so they are applied only - // after all the other patterns have been aplied. + // after all the other patterns have been applied. RewritePatternSet patterns_2(ctx); populateWithGenerated(patterns_2); - patterns_2.add(ctx); - patterns_2.add(ctx, quantization_method_); + patterns_2.add( + ctx, target_opset_); + patterns_2.add(ctx, target_opset_); if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns_2))) || failed(verify(module))) { signalPassFailure(); } + QuantizationSummary(module).Print(); } } // namespace std::unique_ptr> CreateQuantizeCompositeFunctionsPass( - QuantizationMethod quantization_method, OpSet target_opset) { - return std::make_unique(quantization_method, - target_opset); + QuantMethod quantization_method, OpSet target_opset, + bool enable_per_channel_quantization, int min_num_elements_for_weights) { + return std::make_unique( + quantization_method, target_opset, enable_per_channel_quantization, + min_num_elements_for_weights); } } // namespace quant diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantized_function_library.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantized_function_library.mlir index 17c17383f69..37aaae94971 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantized_function_library.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantized_function_library.mlir @@ -38,18 +38,20 @@ module { %out_scale : tensor<*xf32>, %out_zp : tensor<*xi32>) -> tensor<*xf32> { %scale_prod = "tf.Mul"(%input_scale, %filter_scale) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> %rescale_factor = "tf.Div"(%scale_prod, %out_scale) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - - // Uses tf.floor(x + 0.5) instead of tf.round(x) since tf.round generates - // a very expensive pattern. 
- %round_cst = "tf.Const"() {value = dense<0.5> : tensor} : () -> tensor %float_out_zp = "tf.Cast"(%out_zp) {Truncate = false} : (tensor<*xi32>) -> tensor<*xf32> - %zp_plus_round_cst = "tf.AddV2"(%float_out_zp, %round_cst) : (tensor<*xf32>, tensor) -> tensor<*xf32> %cast = "tf.Cast"(%accumulation) {Truncate = false} : (tensor<*xi32>) -> tensor<*xf32> %mul = "tf.Mul"(%cast, %rescale_factor) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - %add = "tf.AddV2"(%mul, %zp_plus_round_cst) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - %round = "tf.Floor"(%add) : (tensor<*xf32>) -> tensor<*xf32> - func.return %round : tensor<*xf32> + %add = "tf.AddV2"(%mul, %float_out_zp) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + func.return %add : tensor<*xf32> + } + + func.func private @internal_dequantize_i8_fn(%input : tensor<*xi8>, %scale : tensor<*xf32>, %zp : tensor<*xi32>) -> tensor<*xf32> { + %input_i32 = "tf.Cast"(%input) : (tensor<*xi8>) -> tensor<*xi32> + %output = "tf.Sub"(%input_i32, %zp) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> + %cast = "tf.Cast"(%output) : (tensor<*xi32>) -> tensor<*xf32> + %mul = "tf.Mul"(%cast, %scale) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + func.return %mul : tensor<*xf32> } // Requantizes and clips to the range of quantized type if there is no specific activation. @@ -64,9 +66,12 @@ module { tensor<*xf32>, tensor<*xi32>) -> tensor<*xf32> %i8_min = "tf.Const"() {value = dense<-128.0> : tensor} : () -> tensor %i8_max = "tf.Const"() {value = dense<127.0> : tensor} : () -> tensor - %0 = "tf.ClipByValue"(%rescale, %i8_min, %i8_max) : (tensor<*xf32>, tensor, tensor) -> tensor<*xf32> - %1 = "tf.Cast"(%0) {Truncate = false} : (tensor<*xf32>) -> tensor<*xi8> - func.return %1 : tensor<*xi8> + + %clamp_max = "tf.Maximum"(%rescale, %i8_min) : (tensor<*xf32>, tensor) -> tensor<*xf32> + %clamp_min = "tf.Minimum"(%clamp_max, %i8_max) : (tensor<*xf32>, tensor) -> tensor<*xf32> + %round = "tf.Round"(%clamp_min) : (tensor<*xf32>) -> tensor<*xf32> + %cast = "tf.Cast"(%round) {Truncate = false} : (tensor<*xf32>) -> tensor<*xi8> + func.return %cast : tensor<*xi8> } // Requantizes and applies quantized Relu by clipping. @@ -82,10 +87,13 @@ module { %i8_min = "tf.Const"() {value = dense<-128.0> : tensor} : () -> tensor %i8_max = "tf.Const"() {value = dense<127.0> : tensor} : () -> tensor %float_out_zp = "tf.Cast"(%out_zp) {Truncate = false} : (tensor<*xi32>) -> tensor<*xf32> - %clip_min = "tf.Maximum"(%i8_min, %float_out_zp) : (tensor, tensor<*xf32>) -> tensor - %0 = "tf.ClipByValue"(%rescale, %clip_min, %i8_max) : (tensor<*xf32>, tensor, tensor) -> tensor<*xf32> - %1 = "tf.Cast"(%0) {Truncate = false} : (tensor<*xf32>) -> tensor<*xi8> - func.return %1 : tensor<*xi8> + %clip_min = "tf.Maximum"(%i8_min, %float_out_zp) : (tensor, tensor<*xf32>) -> tensor<*xf32> + + %clamp_max = "tf.Maximum"(%rescale, %clip_min) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %clamp_min = "tf.Minimum"(%clamp_max, %i8_max) : (tensor<*xf32>, tensor) -> tensor<*xf32> + %round = "tf.Round"(%clamp_min) : (tensor<*xf32>) -> tensor<*xf32> + %cast = "tf.Cast"(%round) {Truncate = false} : (tensor<*xf32>) -> tensor<*xi8> + func.return %cast : tensor<*xi8> } // Requantizes and applies quantized Relu6 by clipping. 
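For a concrete feel of the requantization path above (internal_rescale_fn followed by internal_requantize_no_activation_fn): the int32 accumulator is rescaled by input_scale * filter_scale / out_scale, shifted by the output zero point, clamped to the int8 range with Maximum/Minimum, rounded, and cast. A worked numeric example with made-up values:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>

int main() {
  const int32_t accumulation = 1234;   // conv/matmul accumulator (int32)
  const float input_scale = 0.02f;
  const float filter_scale = 0.005f;
  const float out_scale = 0.1f;
  const int32_t out_zp = -3;

  // rescale_factor = 0.02 * 0.005 / 0.1 = 0.001
  const float rescale_factor = input_scale * filter_scale / out_scale;
  float x = static_cast<float>(accumulation) * rescale_factor +
            static_cast<float>(out_zp);               // 1.234 - 3 = -1.766
  x = std::min(std::max(x, -128.0f), 127.0f);         // tf.Maximum / tf.Minimum
  // tf.Round rounds half to even; std::round rounds half away from zero, but
  // the two only differ on exact .5 ties, which this example avoids.
  const int8_t quantized = static_cast<int8_t>(std::round(x));  // -2

  std::cout << "requantized value: " << static_cast<int>(quantized) << "\n";
  return 0;
}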
@@ -103,14 +111,17 @@ module { %act_max = "tf.Const"() {value = dense<6.0> : tensor} : () -> tensor %i8_act_max_0 = "tf.PartitionedCall"(%act_max, %out_scale, %out_zp) { config = "", config_proto = "", executor_type = "", f=@quantize_i8 - } : (tensor, tensor<*xf32>, tensor<*xi32>) -> tensor - %i8_act_max_1 = "tf.Cast"(%i8_act_max_0) {Truncate = false} : (tensor) -> tensor + } : (tensor, tensor<*xf32>, tensor<*xi32>) -> tensor<*xi8> + %i8_act_max_1 = "tf.Cast"(%i8_act_max_0) {Truncate = false} : (tensor<*xi8>) -> tensor<*xf32> %float_out_zp = "tf.Cast"(%out_zp) {Truncate = false} : (tensor<*xi32>) -> tensor<*xf32> - %clip_min = "tf.Maximum"(%i8_min, %float_out_zp) : (tensor, tensor<*xf32>) -> tensor - %clip_max = "tf.Minimum"(%i8_max, %i8_act_max_1) : (tensor, tensor) -> tensor - %0 = "tf.ClipByValue"(%rescale, %clip_min, %clip_max) : (tensor<*xf32>, tensor, tensor) -> tensor<*xf32> - %1 = "tf.Cast"(%0) {Truncate = false} : (tensor<*xf32>) -> tensor<*xi8> - func.return %1 : tensor<*xi8> + %clip_min = "tf.Maximum"(%i8_min, %float_out_zp) : (tensor, tensor<*xf32>) -> tensor<*xf32> + %clip_max = "tf.Minimum"(%i8_max, %i8_act_max_1) : (tensor, tensor<*xf32>) -> tensor<*xf32> + + %clamp_max = "tf.Maximum"(%rescale, %clip_min) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %clamp_min = "tf.Minimum"(%clamp_max, %clip_max) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %round = "tf.Round"(%clamp_min) : (tensor<*xf32>) -> tensor<*xf32> + %cast = "tf.Cast"(%round) {Truncate = false} : (tensor<*xf32>) -> tensor<*xi8> + func.return %cast : tensor<*xi8> } // Dequantizes and clips to the range of quantized type if there is no specific activation. @@ -121,7 +132,21 @@ module { %accumulation_scale = "tf.Mul"(%input_scale, %filter_scale) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> %cast = "tf.Cast"(%accumulation) {Truncate = false} : (tensor<*xi32>) -> tensor<*xf32> %dequantize = "tf.Mul"(%cast, %accumulation_scale) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - func.return %dequantize : tensor<*xf32> + + %i8_min = "tf.Const"() {value = dense<-128> : tensor} : () -> tensor + %i8_max = "tf.Const"() {value = dense<127> : tensor} : () -> tensor + + %clip_min = "tf.PartitionedCall"(%i8_min, %out_scale, %out_zp) { + config = "", config_proto = "", executor_type = "", f=@internal_dequantize_i8_fn + } : (tensor, tensor<*xf32>, tensor<*xi32>) -> tensor<*xf32> + %clip_max = "tf.PartitionedCall"(%i8_max, %out_scale, %out_zp) { + config = "", config_proto = "", executor_type = "", f=@internal_dequantize_i8_fn + } : (tensor, tensor<*xf32>, tensor<*xi32>) -> tensor<*xf32> + + %clamp_max = "tf.Maximum"(%dequantize, %clip_min) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %clamp_min = "tf.Minimum"(%clamp_max, %clip_max) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + + func.return %clamp_min : tensor<*xf32> } // Dequantizes and applies quantized Relu by clipping. 
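The float-output path above (internal_dequantize_no_activation_fn) now clamps the dequantized accumulator to the real-valued range representable by the int8 output quantization parameters, i.e. the dequantized values of -128 and 127 obtained through internal_dequantize_i8_fn. A worked example with made-up scales:

#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  const int32_t accumulation = 40000;
  const float input_scale = 0.02f, filter_scale = 0.005f;
  const float out_scale = 0.1f;
  const int32_t out_zp = -3;

  // Dequantize the accumulator: 40000 * 0.0001 = 4.0.
  const float dequantized = accumulation * (input_scale * filter_scale);

  // internal_dequantize_i8_fn(i8_min / i8_max, out_scale, out_zp):
  const float clip_min = (-128 - out_zp) * out_scale;  // -12.5
  const float clip_max = (127 - out_zp) * out_scale;   //  13.0

  const float result = std::min(std::max(dequantized, clip_min), clip_max);
  std::cout << result << "\n";  // 4.0, already inside the representable range
  return 0;
}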
@@ -133,8 +158,22 @@ module { %cast = "tf.Cast"(%accumulation) {Truncate = false} : (tensor<*xi32>) -> tensor<*xf32> %dequantize = "tf.Mul"(%cast, %accumulation_scale) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - %relu = "tf.Relu"(%dequantize) : (tensor<*xf32>) -> tensor<*xf32> - func.return %relu : tensor<*xf32> + %i8_min = "tf.Const"() {value = dense<-128> : tensor} : () -> tensor + %i8_max = "tf.Const"() {value = dense<127> : tensor} : () -> tensor + + %clip_min_0 = "tf.PartitionedCall"(%i8_min, %out_scale, %out_zp) { + config = "", config_proto = "", executor_type = "", f=@internal_dequantize_i8_fn + } : (tensor, tensor<*xf32>, tensor<*xi32>) -> tensor<*xf32> + %clip_max = "tf.PartitionedCall"(%i8_max, %out_scale, %out_zp) { + config = "", config_proto = "", executor_type = "", f=@internal_dequantize_i8_fn + } : (tensor, tensor<*xf32>, tensor<*xi32>) -> tensor<*xf32> + + %clip_min = "tf.Relu"(%clip_min_0) : (tensor<*xf32>) -> tensor<*xf32> + + %clamp_max = "tf.Maximum"(%dequantize, %clip_min) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %clamp_min = "tf.Minimum"(%clamp_max, %clip_max) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + + func.return %clamp_min : tensor<*xf32> } // Dequantizes and applies quantized Relu6 by clipping. @@ -146,8 +185,24 @@ module { %cast = "tf.Cast"(%accumulation) {Truncate = false} : (tensor<*xi32>) -> tensor<*xf32> %dequantize = "tf.Mul"(%cast, %accumulation_scale) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - %relu6 = "tf.Relu6"(%dequantize) : (tensor<*xf32>) -> tensor<*xf32> - func.return %relu6 : tensor<*xf32> + %i8_min = "tf.Const"() {value = dense<-128> : tensor} : () -> tensor + %i8_max = "tf.Const"() {value = dense<127> : tensor} : () -> tensor + %relu6_upper = "tf.Const"() {value = dense<6.0>: tensor} : () -> tensor + + %clip_min_0 = "tf.PartitionedCall"(%i8_min, %out_scale, %out_zp) { + config = "", config_proto = "", executor_type = "", f=@internal_dequantize_i8_fn + } : (tensor, tensor<*xf32>, tensor<*xi32>) -> tensor<*xf32> + %clip_max_0 = "tf.PartitionedCall"(%i8_max, %out_scale, %out_zp) { + config = "", config_proto = "", executor_type = "", f=@internal_dequantize_i8_fn + } : (tensor, tensor<*xf32>, tensor<*xi32>) -> tensor<*xf32> + + %clip_min = "tf.Relu"(%clip_min_0) : (tensor<*xf32>) -> tensor<*xf32> + %clip_max = "tf.Minimum"(%clip_max_0, %relu6_upper) : (tensor<*xf32>, tensor) -> tensor<*xf32> + + %clamp_max = "tf.Maximum"(%dequantize, %clip_min) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %clamp_min = "tf.Minimum"(%clamp_max, %clip_max) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + + func.return %clamp_min : tensor<*xf32> } // Conv2D with int32 accumulation. @@ -237,24 +292,44 @@ module { func.return %6 : tensor<*xi32> } - for main_op in ["conv2d", "depthwise_conv2d", "matmul", "conv3d"] { + // BatchMatMul with int32 accumulation. + func.func private @internal_batch_matmul_fn( + %input : tensor<*xi8>, %weight : tensor<*xi8>, + %input_scale : tensor<*xf32>, %input_zp : tensor<*xi32>, + %weight_scale : tensor<*xf32>, %weight_zp : tensor<*xi32>) -> tensor<*xi32> { + %0 = "tf.Cast"(%input) {Truncate = false} : (tensor<*xi8>) -> tensor<*xi32> + %1 = "tf.Sub"(%0, %input_zp) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> + + // Use identity op to avoid the weight being constant-folded. 
+ %identity = "tf.Identity"(%weight) : (tensor<*xi8>) -> tensor<*xi8> + %2 = "tf.Cast"(%identity) {Truncate = false} : (tensor<*xi8>) -> tensor<*xi32> + %3 = "tf.Sub"(%2, %weight_zp) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> + + %5 = "tf.BatchMatMulV2"(%1, %3) { + attr_map = "adj_x:0,adj_y:1" + } : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> + func.return %5 : tensor<*xi32> + } + + for main_op in ["Conv2D", "DepthwiseConv2D", "MatMul", "Conv3D", "BatchMatMul"] { parameters[ - {"suffix": "with_bias_fn", "act_func": "internal_requantize_no_activation_fn", "output_type": "i8"}, - {"suffix": "with_bias_and_relu_fn", "act_func": "internal_requantize_and_relu_fn", "output_type": "i8"}, - {"suffix": "with_bias_and_relu6_fn", "act_func": "internal_requantize_and_relu6_fn", "output_type": "i8"}, - {"suffix": "with_bias_float_output_fn", "act_func": "internal_dequantize_no_activation_fn", "output_type": "f32"}, - {"suffix": "with_bias_and_relu_float_output_fn", "act_func": "internal_dequantize_and_relu_fn", "output_type": "f32"}, - {"suffix": "with_bias_and_relu6_float_output_fn", "act_func": "internal_dequantize_and_relu6_fn", "output_type": "f32"}, + {"quantized_ops": ["${main_op}", "BiasAdd"], "act_func": "internal_requantize_no_activation_fn", "output_type": "i8"}, + {"quantized_ops": ["${main_op}", "BiasAdd", "Relu"], "act_func": "internal_requantize_and_relu_fn", "output_type": "i8"}, + {"quantized_ops": ["${main_op}", "BiasAdd", "Relu6"], "act_func": "internal_requantize_and_relu6_fn", "output_type": "i8"}, + {"quantized_ops": ["${main_op}", "BiasAdd"], "act_func": "internal_dequantize_no_activation_fn", "output_type": "f32"}, + {"quantized_ops": ["${main_op}", "BiasAdd", "Relu"], "act_func": "internal_dequantize_and_relu_fn", "output_type": "f32"}, + {"quantized_ops": ["${main_op}", "BiasAdd", "Relu6"], "act_func": "internal_dequantize_and_relu6_fn", "output_type": "f32"}, ] - func.func @quantized_${main_op}_${suffix}(%input : tensor<*xi8>, + func.func @GenerateQuantizedFunctionName(${quantized_ops}, "${output_type}")(%input : tensor<*xi8>, %filter : tensor<*xi8>, %bias : tensor<*xi32>, %input_scale : tensor<*xf32>, %input_zp : tensor<*xi32>, %filter_scale : tensor<*xf32>, %filter_zp : tensor<*xi32>, %bias_scale : tensor<*xf32>, %bias_zp : tensor<*xi32>, - %out_scale : tensor<*xf32>, %out_zp : tensor<*xi32>) -> tensor<*x${output_type}> { + %out_scale : tensor<*xf32>, %out_zp : tensor<*xi32>) -> tensor<*x${output_type}> + attributes {tf_quant.quantized_ops = ${quantized_ops}} { %0 = "tf.PartitionedCall"(%input, %filter, %input_scale, %input_zp, %filter_scale, %filter_zp) { - config = "", config_proto = "", executor_type = "", f=@internal_${main_op}_fn + config = "", config_proto = "", executor_type = "", f=@GenerateImplFunctionName(${main_op}) } : (tensor<*xi8>, tensor<*xi8>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>) -> tensor<*xi32> %1 = "tf.AddV2"(%0, %bias) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> @@ -267,21 +342,22 @@ module { } parameters[ - {"suffix": "fn", "act_func": "internal_requantize_no_activation_fn", "output_type": "i8"}, - {"suffix": "with_relu_fn", "act_func": "internal_requantize_and_relu_fn", "output_type": "i8"}, - {"suffix": "with_relu6_fn", "act_func": "internal_requantize_and_relu6_fn", "output_type": "i8"}, - {"suffix": "float_output_fn", "act_func": "internal_dequantize_no_activation_fn", "output_type": "f32"}, - {"suffix": "with_relu_float_output_fn", "act_func": "internal_dequantize_and_relu_fn", "output_type": "f32"}, - 
{"suffix": "with_relu6_float_output_fn", "act_func": "internal_dequantize_and_relu6_fn", "output_type": "f32"}, + {"quantized_ops": ["${main_op}"], "act_func": "internal_requantize_no_activation_fn", "output_type": "i8"}, + {"quantized_ops": ["${main_op}", "Relu"], "act_func": "internal_requantize_and_relu_fn", "output_type": "i8"}, + {"quantized_ops": ["${main_op}", "Relu6"], "act_func": "internal_requantize_and_relu6_fn", "output_type": "i8"}, + {"quantized_ops": ["${main_op}"], "act_func": "internal_dequantize_no_activation_fn", "output_type": "f32"}, + {"quantized_ops": ["${main_op}", "Relu"], "act_func": "internal_dequantize_and_relu_fn", "output_type": "f32"}, + {"quantized_ops": ["${main_op}", "Relu6"], "act_func": "internal_dequantize_and_relu6_fn", "output_type": "f32"}, ] - func.func @quantized_${main_op}_${suffix}( + func.func @GenerateQuantizedFunctionName(${quantized_ops}, "${output_type}")( %input : tensor<*xi8>, %filter : tensor<*xi8>, %input_scale : tensor<*xf32>, %input_zp : tensor<*xi32>, %filter_scale : tensor<*xf32>, %filter_zp : tensor<*xi32>, - %out_scale : tensor<*xf32>, %out_zp : tensor<*xi32>) -> tensor<*x${output_type}> { + %out_scale : tensor<*xf32>, %out_zp : tensor<*xi32>) -> tensor<*x${output_type}> + attributes {tf_quant.quantized_ops = ${quantized_ops}} { %0 = "tf.PartitionedCall"(%input, %filter, %input_scale, %input_zp, %filter_scale, %filter_zp) { - config = "", config_proto = "", executor_type = "", f=@internal_${main_op}_fn + config = "", config_proto = "", executor_type = "", f=@GenerateImplFunctionName(${main_op}) } : (tensor<*xi8>, tensor<*xi8>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>) -> tensor<*xi32> %1 = "tf.PartitionedCall"(%0, %input_scale, %input_zp, %filter_scale, %filter_zp, @@ -294,25 +370,23 @@ module { } // end for func.func @quantize_i8(%input : tensor<*xf32>, %scale : tensor<*xf32>, %zp : tensor<*xi32>) -> tensor<*xi8> { - // Uses tf.floor(x + 0.5) instead of tf.round(x) since tf.round generates - // a very expensive pattern. - %round_cst = "tf.Const"() {value = dense<0.5> : tensor} : () -> tensor %float_zp = "tf.Cast"(%zp) {Truncate = false} : (tensor<*xi32>) -> tensor<*xf32> - %zp_plus_round_cst = "tf.AddV2"(%float_zp, %round_cst) : (tensor<*xf32>, tensor) -> tensor<*xf32> - %div = "tf.Div"(%input, %scale) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - %add = "tf.AddV2"(%div, %zp_plus_round_cst) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - %round = "tf.Floor"(%add) : (tensor<*xf32>) -> tensor<*xf32> + %add = "tf.AddV2"(%div, %float_zp) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> %i8_min = "tf.Const"() {value = dense<-128.0> : tensor} : () -> tensor %i8_max = "tf.Const"() {value = dense<127.0> : tensor} : () -> tensor - %clip = "tf.ClipByValue"(%round, %i8_min, %i8_max) : (tensor<*xf32>, tensor, tensor) -> tensor<*xf32> - %i8 = "tf.Cast"(%clip) : (tensor<*xf32>) -> tensor<*xi8> + %clamp_max = "tf.Maximum"(%add, %i8_min) : (tensor<*xf32>, tensor) -> tensor<*xf32> + %clamp_min = "tf.Minimum"(%clamp_max, %i8_max) : (tensor<*xf32>, tensor) -> tensor<*xf32> + %round = "tf.Round"(%clamp_min) : (tensor<*xf32>) -> tensor<*xf32> + %i8 = "tf.Cast"(%round) : (tensor<*xf32>) -> tensor<*xi8> func.return %i8 : tensor<*xi8> } func.func @dequantize_i8(%input : tensor<*xi8>, %scale : tensor<*xf32>, %zp : tensor<*xi32>) -> tensor<*xf32> { - %input_i32 = "tf.Cast"(%input) : (tensor<*xi8>) -> tensor<*xi32> + // Use identity op to avoid the weight being constant-folded. 
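One way to read the rewritten quantize_i8/dequantize_i8 pair in this file is as the usual affine int8 mapping, now clamping before rounding. A small C++ sketch of the same arithmetic (illustrative, not the generated graph):

#include <algorithm>
#include <cmath>
#include <cstdint>

int8_t QuantizeI8(float x, float scale, int32_t zp) {
  float q = x / scale + static_cast<float>(zp);
  q = std::min(std::max(q, -128.0f), 127.0f);  // clamp to the int8 range first
  return static_cast<int8_t>(std::round(q));   // then round to nearest
}

float DequantizeI8(int8_t q, float scale, int32_t zp) {
  return static_cast<float>(static_cast<int32_t>(q) - zp) * scale;
}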
+ %identity = "tf.Identity"(%input) : (tensor<*xi8>) -> tensor<*xi8> + %input_i32 = "tf.Cast"(%identity) : (tensor<*xi8>) -> tensor<*xi32> %output = "tf.Sub"(%input_i32, %zp) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> %cast = "tf.Cast"(%output) : (tensor<*xi32>) -> tensor<*xf32> %mul = "tf.Mul"(%cast, %scale) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantized_function_library_tf_drq.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantized_function_library_tf_drq.mlir index aa942c18613..7b6ab40d579 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantized_function_library_tf_drq.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantized_function_library_tf_drq.mlir @@ -34,11 +34,20 @@ module { // Note: following functions won't handle per-channel quantization for now. func.func private @internal_quantize_i8(%input : tensor<*xf32>, %scale : tensor<*xf32>, %zp : tensor<*xi32>) -> tensor<*xi8> { + // Uses tf.floor(x + 0.5) instead of tf.round(x) since tf.round generates + // a very expensive pattern. + %round_cst = "tf.Const"() {value = dense<0.5> : tensor} : () -> tensor + %float_zp = "tf.Cast"(%zp) {Truncate = false} : (tensor<*xi32>) -> tensor<*xf32> + %zp_plus_round_cst = "tf.AddV2"(%float_zp, %round_cst) : (tensor<*xf32>, tensor) -> tensor<*xf32> + %div = "tf.Div"(%input, %scale) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - %round = "tf.Round"(%div) : (tensor<*xf32>) -> tensor<*xf32> - %cast = "tf.Cast"(%round) : (tensor<*xf32>) -> tensor<*xi32> - %add = "tf.AddV2"(%cast, %zp) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> - %i8 = "tf.Cast"(%add) : (tensor<*xi32>) -> tensor<*xi8> + %add = "tf.AddV2"(%div, %zp_plus_round_cst) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %round = "tf.Floor"(%add) : (tensor<*xf32>) -> tensor<*xf32> + + %i8_min = "tf.Const"() {value = dense<-128.0> : tensor} : () -> tensor + %i8_max = "tf.Const"() {value = dense<127.0> : tensor} : () -> tensor + %clip = "tf.ClipByValue"(%round, %i8_min, %i8_max) : (tensor<*xf32>, tensor, tensor) -> tensor<*xf32> + %i8 = "tf.Cast"(%clip) : (tensor<*xf32>) -> tensor<*xi8> func.return %i8 : tensor<*xi8> } @@ -46,60 +55,123 @@ module { %input_scale : tensor<*xf32>, %weight_scale : tensor<*xf32>) -> tensor<*xf32> { %scale_prod = "tf.Mul"(%input_scale, %weight_scale) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - %0 = "tf.Cast"(%input) {Truncate = false} : (tensor<*xi32>) -> tensor<*xf32> - %1 = "tf.Mul"(%0, %scale_prod) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - func.return %1 : tensor<*xf32> + + %cast = "tf.Cast"(%input) : (tensor<*xi32>) -> tensor<*xf32> + %mul = "tf.Mul"(%cast, %scale_prod) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + func.return %mul : tensor<*xf32> } - // Note: following function supports per-tensor, symmetric, none narrow_range. + // TODO(b/263199401): Support quantization options for activation quantization for DRQ + // Note: following function supports per-tensor, asymmetric, non_narrow_range. 
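The parameter calculation implemented just below derives a per-tensor, asymmetric, non-narrow-range scale and zero point: the observed range is first extended to include 0 so that 0.0 has an exact int8 representation, and the zero point follows from pinning the real maximum to +127. A hedged C++ sketch of that calculation plus the floor(x + 0.5) quantize step noted above (rounding details are simplified here; the pass itself uses casts):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <utility>
#include <vector>

std::pair<float, int32_t> CalculateQuantParams(const std::vector<float>& v) {
  float r_max = 0.0f, r_min = 0.0f;  // start from 0 to force it into the range
  for (float x : v) {
    r_max = std::max(r_max, x);
    r_min = std::min(r_min, x);
  }
  const double scale = (static_cast<double>(r_max) - r_min) / 255.0;  // 127 - (-128)
  if (scale == 0.0) return {1.0f, 0};  // illustrative guard for an all-zero input
  const int32_t zp = static_cast<int32_t>(127.0 - r_max / scale);
  return {static_cast<float>(scale), zp};
}

// The dynamic-range quantize step then uses floor(x + 0.5) as a cheaper
// round-to-nearest, as the library comment above explains.
int8_t QuantizeI8DynamicRange(float x, float scale, int32_t zp) {
  float q = std::floor(x / scale + static_cast<float>(zp) + 0.5f);
  q = std::min(std::max(q, -128.0f), 127.0f);
  return static_cast<int8_t>(q);
}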
func.func private @internal_calculate_quant_params(%input : tensor<*xf32>) -> (tensor<1xf32>, tensor<1xi32>) { - %zp = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> + %zero = "tf.Const"() {value = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32> %shape = "tf.Const"() {value = dense<[-1]> : tensor<1xi32>} : () -> tensor<1xi32> %dim = "tf.Const"() { value = dense<0> : tensor<1xi64> } : () -> tensor<1xi64> + // Check and include zero in the range so that zero value can be correctly + // represented. %input_1d = "tf.Reshape"(%input, %shape) : (tensor<*xf32>, tensor<1xi32>) -> tensor - %r_max = "tf.Max"(%input_1d, %dim) { keep_dims = true }: (tensor, tensor<1xi64>) -> tensor<1xf32> - %r_min = "tf.Min"(%input_1d, %dim) { keep_dims = true }: (tensor, tensor<1xi64>) -> tensor<1xf32> - %r_max_abs = "tf.Abs"(%r_max) : (tensor<1xf32>) -> tensor<1xf32> - %r_min_abs = "tf.Abs"(%r_min) : (tensor<1xf32>) -> tensor<1xf32> - %r_abs_max = "tf.Maximum"(%r_max_abs, %r_min_abs) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> - %r_abs_max_cast = "tf.Cast"(%r_abs_max) : (tensor<1xf32>) -> tensor<1xf64> + %r_max_without_zero = "tf.Max"(%input_1d, %dim) { keep_dims = true }: (tensor, tensor<1xi64>) -> tensor<1xf32> + %r_max = "tf.Maximum"(%zero, %r_max_without_zero) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + + %r_min_without_zero = "tf.Min"(%input_1d, %dim) { keep_dims = true }: (tensor, tensor<1xi64>) -> tensor<1xf32> + %r_min = "tf.Minimum"(%zero, %r_min_without_zero) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + + %r_max_f64 = "tf.Cast"(%r_max) : (tensor<1xf32>) -> tensor<1xf64> + %r_min_f64 = "tf.Cast"(%r_min) : (tensor<1xf32>) -> tensor<1xf64> + + %i8_min = "tf.Const"() {value = dense<-128.0> : tensor} : () -> tensor + %i8_max = "tf.Const"() {value = dense<127.0> : tensor} : () -> tensor + %i8_min_f64 = "tf.Cast"(%i8_min) : (tensor) -> tensor + %i8_max_f64 = "tf.Cast"(%i8_max) : (tensor) -> tensor - %i8_min = "tf.Const"() {value = dense<-128> : tensor} : () -> tensor - %i8_max = "tf.Const"() {value = dense<127> : tensor} : () -> tensor - %i8_min_cast = "tf.Cast"(%i8_min) : (tensor) -> tensor - %i8_max_cast = "tf.Cast"(%i8_max) : (tensor) -> tensor + %range_nume = "tf.Sub"(%r_max_f64, %r_min_f64) : (tensor<1xf64>, tensor<1xf64>) -> tensor<1xf64> + %range_deno = "tf.Sub"(%i8_max_f64, %i8_min_f64) : (tensor, tensor) -> tensor - %range_nume = "tf.AddV2"(%r_abs_max_cast, %r_abs_max_cast) : (tensor<1xf64>, tensor<1xf64>) -> tensor<1xf64> - %range_deno = "tf.Sub"(%i8_max_cast, %i8_min_cast) : (tensor, tensor) -> tensor + %scale_f64 = "tf.Div"(%range_nume, %range_deno) : (tensor<1xf64>, tensor) -> tensor<1xf64> + %scale = "tf.Cast"(%scale_f64) : (tensor<1xf64>) -> tensor<1xf32> - %scale_double = "tf.Div"(%range_nume, %range_deno) : (tensor<1xf64>, tensor) -> tensor<1xf64> - %scale = "tf.Cast"(%scale_double) : (tensor<1xf64>) -> tensor<1xf32> + // Add comparison with minimum if needed + %intermediate_val = "tf.Div"(%r_max_f64, %scale_f64) : (tensor<1xf64>, tensor<1xf64>) -> tensor<1xf64> + %zp_from_max = "tf.Sub"(%i8_max_f64, %intermediate_val) : (tensor, tensor<1xf64>) -> tensor<1xf64> + %zp_fp32 = "tf.Cast"(%zp_from_max) : (tensor<1xf64>) -> tensor<1xf32> + %zp = "tf.Cast"(%zp_fp32) : (tensor<1xf32>) -> tensor<1xi32> func.return %scale, %zp : tensor<1xf32>, tensor<1xi32> } // Matmul with int32 accumulation func.func private @internal_matmul_fn( - %input : tensor<*xi8>, %weight : tensor<*xi8>, + %input : tensor<*xi8>, %filter : tensor<*xi8>, %input_scale : tensor<*xf32>, 
%input_zp : tensor<*xi32>, %weight_scale : tensor<*xf32>, %weight_zp : tensor<*xi32>) -> tensor<*xi32> { %0 = "tf.Cast"(%input) {Truncate = false} : (tensor<*xi8>) -> tensor<*xi32> %1 = "tf.Sub"(%0, %input_zp) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> - %2 = "tf.Cast"(%weight) {Truncate = false} : (tensor<*xi8>) -> tensor<*xi32> + // Use identity op to avoid the weight being constant-folded. + %identity = "tf.Identity"(%filter) : (tensor<*xi8>) -> tensor<*xi8> + %2 = "tf.Cast"(%identity) {Truncate = false} : (tensor<*xi8>) -> tensor<*xi32> %3 = "tf.Sub"(%2, %weight_zp) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> - // TODO(b/215633216): Optimize this function with the XLA Dot op. %5 = "tf.MatMul"(%1, %3) { attr_map = "transpose_a:0,transpose_b:1" } : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> func.return %5 : tensor<*xi32> } - func.func @quantized_matmul_fn( + // Conv2D with int32 accumulation + func.func private @internal_conv2d_fn( + %input : tensor<*xi8>, %filter : tensor<*xi8>, + %input_scale : tensor<*xf32>, %input_zp : tensor<*xi32>, + %filter_scale : tensor<*xf32>, %filter_zp : tensor<*xi32>) -> tensor<*xi32> { + %0 = "tf.Cast"(%input) {Truncate = false} : (tensor<*xi8>) -> tensor<*xi32> + %1 = "tf.Sub"(%0, %input_zp) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> + + // Use identity op to avoid the weight being constant-folded. + %identity = "tf.Identity"(%filter) : (tensor<*xi8>) -> tensor<*xi8> + %2 = "tf.Cast"(%identity) {Truncate = false} : (tensor<*xi8>) -> tensor<*xi32> + %3 = "tf.Sub"(%2, %filter_zp) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> + + %5 = "tf.Conv2D"(%1, %3) { + padding = "VALID", strides = [1, 1, 1, 1], + attr_map = "strides:0,use_cudnn_on_gpu:1,padding:2,explicit_paddings:3,dilations:4" + } : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> + func.return %5 : tensor<*xi32> + } + + // DepthwiseConv2D with float computation + func.func private @internal_depthwise_conv2d_fn( + %input : tensor<*xi8>, %filter : tensor<*xi8>, + %input_scale : tensor<*xf32>, %input_zp : tensor<*xi32>, + %filter_scale : tensor<*xf32>, %filter_zp : tensor<*xi32>) -> tensor<*xi32> { + %0 = "tf.Cast"(%input) {Truncate = false} : (tensor<*xi8>) -> tensor<*xi32> + %1 = "tf.Sub"(%0, %input_zp) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> + + // Use identity op to avoid the weight being constant-folded. 
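The internal_*_fn bodies in this library all follow the same accumulation recipe: widen both operands to int32, subtract their zero points, accumulate the products in int32, and rescale the accumulator by input_scale * weight_scale. A minimal scalar illustration (not the generated graph):

#include <cstddef>
#include <cstdint>
#include <vector>

int32_t QuantizedDot(const std::vector<int8_t>& input,
                     const std::vector<int8_t>& weight,
                     int32_t input_zp, int32_t weight_zp) {
  int32_t acc = 0;
  for (std::size_t i = 0; i < input.size() && i < weight.size(); ++i) {
    acc += (static_cast<int32_t>(input[i]) - input_zp) *
           (static_cast<int32_t>(weight[i]) - weight_zp);
  }
  return acc;
}

// The int32 accumulator is brought back to float with the product of scales.
float DequantizeAccumulator(int32_t acc, float input_scale, float weight_scale) {
  return static_cast<float>(acc) * (input_scale * weight_scale);
}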
+ %identity = "tf.Identity"(%filter) : (tensor<*xi8>) -> tensor<*xi8> + %2 = "tf.Cast"(%identity) {Truncate = false} : (tensor<*xi8>) -> tensor<*xi32> + %3 = "tf.Sub"(%2, %filter_zp) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> + + %cast_1_f32 = "tf.Cast"(%1) {Truncate = false} : (tensor<*xi32>) -> tensor<*xf32> + %cast_3_f32 = "tf.Cast"(%3) {Truncate = false} : (tensor<*xi32>) -> tensor<*xf32> + + %5 = "tf.DepthwiseConv2dNative"(%cast_1_f32, %cast_3_f32) { + padding = "VALID", strides = [1, 1, 1, 1], + attr_map = "strides:0,padding:1,explicit_paddings:2,dilations:3" + } : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %6 = "tf.Cast"(%5) : (tensor<*xf32>) -> tensor<*xi32> + func.return %6 : tensor<*xi32> + } + + parameters[ + {"quantized_ops": ["MatMul"], "internal_func_name": "internal_matmul_fn"}, + {"quantized_ops": ["Conv2D"], "internal_func_name": "internal_conv2d_fn"}, + {"quantized_ops": ["DepthwiseConv2D"], "internal_func_name": "internal_depthwise_conv2d_fn"} + ] + func.func @GenerateQuantizedFunctionName(${quantized_ops})( %input : tensor<*xf32>, %weight : tensor<*xi8>, - %weight_scale : tensor<*xf32>, %weight_zp : tensor<*xi32>) -> tensor<*xf32> { + %weight_scale : tensor<*xf32>, %weight_zp : tensor<*xi32>) -> tensor<*xf32> + attributes {tf_quant.quantized_ops = ${quantized_ops}} { %input_scale, %input_zp = "tf.PartitionedCall"(%input) { config = "", config_proto = "", executor_type = "", f=@internal_calculate_quant_params @@ -111,7 +183,7 @@ module { %accum_out = "tf.PartitionedCall"(%quantized_input, %weight, %input_scale, %input_zp, %weight_scale, %weight_zp) { - config = "", config_proto = "", executor_type = "", f=@internal_matmul_fn + config = "", config_proto = "", executor_type = "", f=@${internal_func_name} } : (tensor<*xi8>, tensor<*xi8>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>) -> tensor<*xi32> @@ -122,4 +194,14 @@ module { func.return %out : tensor<*xf32> } + // For weight-only + func.func @dequantize_i8(%input : tensor<*xi8>, %scale : tensor<*xf32>, %zp : tensor<*xi32>) -> tensor<*xf32> { + // Use identity op to avoid the weight being constant-folded. + %identity = "tf.Identity"(%input) : (tensor<*xi8>) -> tensor<*xi8> + %input_i32 = "tf.Cast"(%identity) : (tensor<*xi8>) -> tensor<*xi32> + %output = "tf.Sub"(%input_i32, %zp) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> + %cast = "tf.Cast"(%output) : (tensor<*xi32>) -> tensor<*xf32> + %mul = "tf.Mul"(%cast, %scale) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + func.return %mul : tensor<*xf32> + } } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantized_function_library_uniform_quantized.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantized_function_library_uniform_quantized.mlir index dbec1ed5661..2225e588e39 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantized_function_library_uniform_quantized.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantized_function_library_uniform_quantized.mlir @@ -33,19 +33,76 @@ module { - // TODO(b/240931497) Replace with core tf ops once uniform quantization is submitted. 
- // Ref bugs for op: b/230804708, b/230805744 - func.func @quantized_conv2d_with_bias_fn( - %input : tensor<*x!tf_type.qint8>, %filter : tensor<*x!tf_type.qint8>, %bias : tensor<*x!tf_type.qint32>, - %input_scale : tensor<*xf32>, %input_zp : tensor<*xi32>, %filter_scale : tensor<*xf32>, %filter_zp : tensor<*xi32>, %bias_scale : tensor<*xf32>, %bias_zp : tensor<*xi32>, %out_scale : tensor<*xf32>, %out_zp : tensor<*xi32>) -> tensor<*x!tf_type.qint8> { - %conv_out = "tf.ExperimentalUniformQuantizedConvolution"(%input, %filter, + for main_op in ["Conv2D", "DepthwiseConv2D"] { + parameters[ + {"quantized_ops": ["${main_op}", "BiasAdd"], "act_func": "internal_requantize_no_activation_fn", "output_type": "!tf_type.qint8"}, + {"quantized_ops": ["${main_op}", "BiasAdd", "Relu"], "act_func": "internal_requantize_and_relu_fn", "output_type": "!tf_type.qint8"}, + {"quantized_ops": ["${main_op}", "BiasAdd", "Relu6"], "act_func": "internal_requantize_and_relu6_fn", "output_type": "!tf_type.qint8"}, + ] + func.func @GenerateQuantizedFunctionName(${quantized_ops})(%input : tensor<*x!tf_type.qint8>, + %filter : tensor<*x!tf_type.qint8>, %bias : tensor<*x!tf_type.qint32>, + %input_scale : tensor<*xf32>, %input_zp : tensor<*xi32>, + %filter_scale : tensor<*xf32>, %filter_zp : tensor<*xi32>, + %bias_scale : tensor<*xf32>, %bias_zp : tensor<*xi32>, + %out_scale : tensor<*xf32>, %out_zp : tensor<*xi32>) -> tensor<*x${output_type}> + attributes {tf_quant.quantized_ops = ${quantized_ops}} { + // TODO(b/258729559): Revisit scale/zp after e2e path for SRQ on UQ is ready. + %main_out = "tf.PartitionedCall"(%input, %filter, %input_scale, %input_zp, + %filter_scale, %filter_zp, %out_scale, %out_zp) { + config = "", config_proto = "", executor_type = "", f=@GenerateImplFunctionName(${main_op}) + } : (tensor<*x!tf_type.qint8>, tensor<*x!tf_type.qint8>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>) -> tensor<*x!tf_type.qint32> + %add = "tf.UniformQuantizedAdd"(%main_out, %bias, %input_scale, %input_zp, %bias_scale, %bias_zp, %out_scale, %out_zp) { + lhs_quantization_axis = -1, + lhs_quantization_min_val = -128, + lhs_quantization_max_val = 127, + rhs_quantization_axis = -1, + rhs_quantization_min_val = -128, + rhs_quantization_max_val = 127, + output_quantization_axis = -1, + output_quantization_min_val = -128, + output_quantization_max_val = 127, + T = "tfdtype$DT_QINT32", + attr_map = "" + } : (tensor<*x!tf_type.qint32>, tensor<*x!tf_type.qint32>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>) -> tensor<*x!tf_type.qint32> + %act = "tf.PartitionedCall"(%add, %input_scale, %input_zp, %out_scale, %out_zp) { + config = "", config_proto = "", executor_type = "", f=@${act_func} + } : (tensor<*x!tf_type.qint32>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>) -> tensor<*x${output_type}> + func.return %act : tensor<*x${output_type}> + } + + parameters[ + {"quantized_ops": ["${main_op}"], "act_func": "internal_requantize_no_activation_fn", "output_type": "!tf_type.qint8"}, + {"quantized_ops": ["${main_op}", "Relu"], "act_func": "internal_requantize_and_relu_fn", "output_type": "!tf_type.qint8"}, + {"quantized_ops": ["${main_op}", "Relu6"], "act_func": "internal_requantize_and_relu6_fn", "output_type": "!tf_type.qint8"}, + ] + func.func @GenerateQuantizedFunctionName(${quantized_ops})(%input : tensor<*x!tf_type.qint8>, %filter : tensor<*x!tf_type.qint8>, + %input_scale : tensor<*xf32>, %input_zp : tensor<*xi32>, + %filter_scale : 
tensor<*xf32>, %filter_zp : tensor<*xi32>, + %out_scale : tensor<*xf32>, %out_zp : tensor<*xi32>) -> tensor<*x${output_type}> + attributes {tf_quant.quantized_ops = ${quantized_ops}} { + // TODO(b/258729559): Revisit scale/zp after e2e path for SRQ on UQ is ready. + %main_out = "tf.PartitionedCall"(%input, %filter, %input_scale, %input_zp, + %filter_scale, %filter_zp, %out_scale, %out_zp) { + config = "", config_proto = "", executor_type = "", f=@GenerateImplFunctionName(${main_op}) + } : (tensor<*x!tf_type.qint8>, tensor<*x!tf_type.qint8>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>) -> tensor<*x!tf_type.qint32> + %act = "tf.PartitionedCall"(%main_out, %input_scale, %input_zp, %out_scale, %out_zp) { + config = "", config_proto = "", executor_type = "", f=@${act_func} + } : (tensor<*x!tf_type.qint32>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>) -> tensor<*x${output_type}> + func.return %act : tensor<*x${output_type}> + } + } // end for + + // Conv2d Convolution. + func.func private @internal_conv2d_fn( + %input : tensor<*x!tf_type.qint8>, %filter : tensor<*x!tf_type.qint8>, + %input_scale : tensor<*xf32>, %input_zp : tensor<*xi32>, + %filter_scale : tensor<*xf32>, %filter_zp : tensor<*xi32>, %out_scale : tensor<*xf32>, %out_zp : tensor<*xi32>) -> tensor<*x!tf_type.qint32> { + %conv_out = "tf.UniformQuantizedConvolution"(%input, %filter, %input_scale, %input_zp, %filter_scale, %filter_zp, %out_scale, %out_zp) { - // TODO(b/238600711): Populate attributes for quantized_function_library_uniform_quantized - Tlhs = "tfdtype$DT_QINT8", - Trhs = "tfdtype$DT_QINT8", + Tin = "tfdtype$DT_QINT8", Tout = "tfdtype$DT_QINT32", window_strides = [1, 1], - padding = "", + padding = "SAME", explicit_padding = [], lhs_dilation = [], rhs_dilation = [], @@ -61,66 +118,141 @@ module { output_quantization_axis = -1, output_quantization_min_val = -128, output_quantization_max_val = 127, - attr_map = "0:Tlhs,1:Trhs,2:Tout,3:lhs_quantization_min_val,4:lhs_quantization_max_val,5:rhs_quantization_min_val,6:rhs_quantization_max_val,7:output_quantization_min_val,8:output_quantization_max_val" - } : (tensor<*x!tf_type.qint8>, tensor<*x!tf_type.qint8>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>) -> tensor<*x!tf_type.qint32> - %add_bias = "tf.ExperimentalUniformQuantizedAdd"(%conv_out, %bias, %input_scale, %input_zp, %bias_scale, %bias_zp, %out_scale, %out_zp) { - // TODO(b/238600711): Populate attributes for quantized_function_library_uniform_quantized - lhs_quantization_axis = -1, - lhs_quantization_min_val = -128, - lhs_quantization_max_val = 127, - rhs_quantization_axis = -1, - rhs_quantization_min_val = -128, - rhs_quantization_max_val = 127, - output_quantization_axis = -1, - output_quantization_min_val = -128, - output_quantization_max_val = 127, - T = 1, - attr_map = "0:Tlhs,1:Trhs,2:Tout,3:lhs_quantization_min_val,4:lhs_quantization_max_val,5:rhs_quantization_min_val,6:rhs_quantization_max_val,7:output_quantization_min_val,8:output_quantization_max_val" - } : (tensor<*x!tf_type.qint32>, tensor<*x!tf_type.qint32>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>) -> tensor<*x!tf_type.qint32> - %requantized_out = "tf.PartitionedCall"(%add_bias, %input_scale, %input_zp, %out_scale, %out_zp) { - config = "", config_proto = "", executor_type = "", f=@requantize_qi8 - } : (tensor<*x!tf_type.qint32>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>) -> 
tensor<*x!tf_type.qint8> - func.return %requantized_out : tensor<*x!tf_type.qint8> + attr_map = "" + } : (tensor<*x!tf_type.qint8>, tensor<*x!tf_type.qint8>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>) -> tensor<*x!tf_type.qint32> + func.return %conv_out : tensor<*x!tf_type.qint32> } - // Quantize initial input at the start of the graph. - func.func @quantize_qi8(%input : tensor<*xf32>, %input_scale : tensor<*xf32>, %input_zp : tensor<*xi32>, %out_scale : tensor<*xf32>, %out_zp : tensor<*xi32>) -> tensor<*x!tf_type.qint8> { - %quantized_out = "tf.ExperimentalUniformQuantize"(%input, %input_scale, %input_zp) { - // TODO(b/238600711): Populate attributes for quantized_function_library_uniform_quantized + // Depthwise convolution. feature_group_count is set to 3rd dim of input shape. + func.func private @internal_depthwise_conv2d_fn( + %input : tensor<*x!tf_type.qint8>, %filter : tensor<*x!tf_type.qint8>, + %input_scale : tensor<*xf32>, %input_zp : tensor<*xi32>, + %filter_scale : tensor<*xf32>, %filter_zp : tensor<*xi32>, %out_scale : tensor<*xf32>, %out_zp : tensor<*xi32>) -> tensor<*x!tf_type.qint32> { + %conv_out = "tf.UniformQuantizedConvolution"(%input, %filter, + %input_scale, %input_zp, %filter_scale, %filter_zp, %out_scale, %out_zp) { + Tin = "tfdtype$DT_QINT8", + Tout = "tfdtype$DT_QINT32", + window_strides = [1, 1], + padding = "SAME", + explicit_padding = [], + lhs_dilation = [], + rhs_dilation = [], + batch_group_count = 1, + feature_group_count = 1, + dimension_numbers = "", + lhs_quantization_axis = -1, + lhs_quantization_min_val = -128, + lhs_quantization_max_val = 127, + rhs_quantization_axis = -1, + rhs_quantization_min_val = -128, + rhs_quantization_max_val = 127, + output_quantization_axis = -1, + output_quantization_min_val = -128, + output_quantization_max_val = 127, + attr_map = "" + } : (tensor<*x!tf_type.qint8>, tensor<*x!tf_type.qint8>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>) -> tensor<*x!tf_type.qint32> + func.return %conv_out : tensor<*x!tf_type.qint32> + } + + // Quantize initial input at the start of the graph. Output is qint8. + func.func @quantize_i8(%input : tensor<*xf32>, %input_scale : tensor<*xf32>, %input_zp : tensor<*xi32>) -> tensor<*x!tf_type.qint8> { + %quantize = "tf.UniformQuantize"(%input, %input_scale, %input_zp) { + Tin = "tfdtype$DT_FLOAT", + Tout = "tfdtype$DT_QINT8", quantization_axis = -1, quantization_min_val = -128, quantization_max_val = 127, - T = 1, - attr_map = "0:Tin,1:Tout,2:quantization_axis,3:quantization_min_val,4:quantization_max_val" + attr_map = "" } : (tensor<*xf32>, tensor<*xf32>, tensor<*xi32>) -> tensor<*x!tf_type.qint8> - func.return %quantized_out : tensor<*x!tf_type.qint8> + func.return %quantize : tensor<*x!tf_type.qint8> } // Requantize a qint32 tensor to qint8 tensor for the next input. 
- func.func @requantize_qi8(%input : tensor<*x!tf_type.qint32>, %input_scale : tensor<*xf32>, %input_zp : tensor<*xi32>, %out_scale: tensor<*xf32>, %out_zp: tensor<*xi32>) -> tensor<*x!tf_type.qint8> { - %requantized_out = "tf.ExperimentalUniformRequantize"(%input, %input_scale, %input_zp, %out_scale, %out_zp) { - // TODO(b/238600711): Populate attributes for quantized_function_library_uniform_quantized - input_quantization_axis = -1, - input_quantization_min_val = -128, - input_quantization_max_val = 127, - output_quantization_axis = -1, - output_quantization_min_val = -128, - output_quantization_max_val = 127, - attr_map = "0:Tin,1:Tout,2:input_quantization_axis,3:input_quantization_min_val,4:input_quantization_max_val,5:output_quantization_axis,6:output_quantization_min_val,7:output_quantization_min_val" - } : (tensor<*x!tf_type.qint32>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>) -> tensor<*x!tf_type.qint8> - func.return %requantized_out : tensor<*x!tf_type.qint8> + func.func private @internal_requantize_qi8_fn(%input : tensor<*x!tf_type.qint32>, %input_scale : tensor<*xf32>, %input_zp : tensor<*xi32>, %out_scale: tensor<*xf32>, %out_zp: tensor<*xi32>) -> tensor<*x!tf_type.qint8> { + %requantize = "tf.UniformRequantize"(%input, %input_scale, %input_zp, %out_scale, %out_zp) { + Tin = "tfdtype$DT_QINT32", + Tout = "tfdtype$DT_QINT8", + input_quantization_axis = -1, + input_quantization_min_val = -128, + input_quantization_max_val = 127, + output_quantization_axis = -1, + output_quantization_min_val = -128, + output_quantization_max_val = 127, + attr_map = "" + } : (tensor<*x!tf_type.qint32>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>) -> tensor<*x!tf_type.qint8> + func.return %requantize : tensor<*x!tf_type.qint8> } - // Dequantize final graph output back to f32. - func.func @dequantize_qi8(%input : tensor<*x!tf_type.qint8>, %input_scale : tensor<*xf32>, %input_zp : tensor<*xi32>, %out_scale : tensor<*xf32>, %out_zp : tensor<*xi8>) -> tensor<*xf32> { - %dequantized_out = "tf.ExperimentalUniformDequantize"(%input, %input_scale, %input_zp) { - // TODO(b/238600711): Populate attributes for quantized_function_library_uniform_quantized + // Dequantize final graph output back to f32. Input is qint8. + func.func @dequantize_i8(%input : tensor<*x!tf_type.qint8>, %input_scale : tensor<*xf32>, %input_zp : tensor<*xi32>) -> tensor<*xf32> { + %dequantize = "tf.UniformDequantize"(%input, %input_scale, %input_zp) { + Tin = "tfdtype$DT_QINT8", + Tout = "tfdtype$DT_FLOAT", quantization_axis = -1, quantization_min_val = -128, quantization_max_val = 127, - attr_map = "0:Tin,1:Tout,2:quantization_axis,3:quantization_min_val,4:quantization_max_val" + attr_map = "" } : (tensor<*x!tf_type.qint8>, tensor<*xf32>, tensor<*xi32>) -> tensor<*xf32> - func.return %dequantized_out : tensor<*xf32> + func.return %dequantize : tensor<*xf32> + } + + // Requantizes and applies quantized Relu by clipping. 
+ func.func private @internal_requantize_no_activation_fn(%input : tensor<*x!tf_type.qint32>, %input_scale : tensor<*xf32>, %input_zp : tensor<*xi32>, + %out_scale : tensor<*xf32>, %out_zp : tensor<*xi32>) -> tensor<*x!tf_type.qint8> { + %q_out = "tf.PartitionedCall"(%input, %input_scale, %input_zp, %out_scale, %out_zp) { + config = "", config_proto = "", executor_type = "", f=@internal_requantize_qi8_fn + } : (tensor<*x!tf_type.qint32>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>) -> tensor<*x!tf_type.qint8> + func.return %q_out : tensor<*x!tf_type.qint8> + } + + // Requantizes and applies quantized Relu6 by clipping. + func.func private @internal_requantize_and_relu_fn(%input : tensor<*x!tf_type.qint32>, %input_scale : tensor<*xf32>, %input_zp : tensor<*xi32>, + %out_scale : tensor<*xf32>, %out_zp : tensor<*xi32>) -> tensor<*x!tf_type.qint8> { + %i8_min = "tf.Const"() {value = dense<-128.0> : tensor} : () -> tensor + %i8_max = "tf.Const"() {value = dense<127.0> : tensor} : () -> tensor + %float_out_zp = "tf.Cast"(%out_zp) {Truncate = false} : (tensor<*xi32>) -> tensor<*xf32> + %clip_min = "tf.Maximum"(%i8_min, %float_out_zp) : (tensor, tensor<*xf32>) -> tensor + %qclip_min = "tf.Cast"(%i8_min) {Truncate = false} : (tensor) -> tensor + %qi8_max = "tf.Cast"(%i8_max) {Truncate = false} : (tensor) -> tensor + %relu = "tf.UniformQuantizedClipByValue"(%input, %qclip_min, %qi8_max, %out_scale, %out_zp) { + T = "tfdtype$DT_QINT32", + quantization_axis = -1, + quantization_min_val = -128, + quantization_max_val = 127, + attr_map = "" + } : (tensor<*x!tf_type.qint32>, tensor, tensor, tensor<*xf32>, tensor<*xi32>) -> tensor<*x!tf_type.qint32> + %requantize = "tf.PartitionedCall"(%relu, %input_scale, %input_zp, %out_scale, %out_zp) { + config = "", config_proto = "", executor_type = "", f=@internal_requantize_qi8_fn + } : (tensor<*x!tf_type.qint32>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>) -> tensor<*x!tf_type.qint8> + func.return %requantize : tensor<*x!tf_type.qint8> + } + + // Apply requantization and relu6. 
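Conceptually, applying Relu directly on a quantized tensor amounts to clamping against the value that represents real 0.0, i.e. the zero point, and Relu6 additionally clamps against the quantized representation of 6.0; the function below realizes this with UniformQuantizedClipByValue on the qint32 accumulator before requantizing to qint8. A scalar sketch of that idea only (it does not reproduce the exact op sequence):

#include <algorithm>
#include <cmath>
#include <cstdint>

int32_t QuantizedRelu6(int32_t q, float scale, int32_t zp) {
  const int32_t q_zero = zp;  // real 0.0 in the quantized domain
  const int32_t q_six =
      static_cast<int32_t>(std::lround(6.0f / scale)) + zp;  // real 6.0
  // Saturation to the int8 range happens later, at requantization.
  return std::min(std::max(q, q_zero), q_six);
}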
+ func.func private @internal_requantize_and_relu6_fn(%input : tensor<*x!tf_type.qint32>, %input_scale : tensor<*xf32>, %input_zp : tensor<*xi32>, + %out_scale : tensor<*xf32>, %out_zp : tensor<*xi32>) -> tensor<*x!tf_type.qint8> { + %i8_min = "tf.Const"() {value = dense<-128.0> : tensor} : () -> tensor + %i8_max = "tf.Const"() {value = dense<127.0> : tensor} : () -> tensor + %act_max = "tf.Const"() {value = dense<6.0> : tensor} : () -> tensor + %i8_act_max_0 = "tf.PartitionedCall"(%act_max, %input_scale, %input_zp) { + config = "", config_proto = "", executor_type = "", f=@quantize_i8 + } : (tensor, tensor<*xf32>, tensor<*xi32>) -> tensor<*x!tf_type.qint8> + %i8_act_max_1 = "tf.Cast"(%i8_act_max_0) {Truncate = false} : (tensor<*x!tf_type.qint8>) -> tensor + %float_out_zp = "tf.Cast"(%out_zp) {Truncate = false} : (tensor<*xi32>) -> tensor<*xf32> + %clip_min = "tf.Maximum"(%i8_min, %float_out_zp) : (tensor, tensor<*xf32>) -> tensor + %clip_max = "tf.Minimum"(%i8_max, %i8_act_max_1) : (tensor, tensor) -> tensor + %qclip_min = "tf.Cast"(%i8_min) {Truncate = false} : (tensor) -> tensor + %qclip_max = "tf.Cast"(%i8_max) {Truncate = false} : (tensor) -> tensor + %relu = "tf.UniformQuantizedClipByValue"(%input, %qclip_min, %qclip_max, %out_scale, %out_zp) { + T = "tfdtype$DT_QINT32", + quantization_axis = -1, + quantization_min_val = -128, + quantization_max_val = 127, + attr_map = "" + } : (tensor<*x!tf_type.qint32>, tensor, tensor, tensor<*xf32>, tensor<*xi32>) -> tensor<*x!tf_type.qint32> + %requantize = "tf.PartitionedCall"(%relu, %input_scale, %input_zp, %out_scale, %out_zp) { + config = "", config_proto = "", executor_type = "", f=@internal_requantize_qi8_fn + } : (tensor<*x!tf_type.qint32>, tensor<*xf32>, tensor<*xi32>, tensor<*xf32>, tensor<*xi32>) -> tensor<*x!tf_type.qint8> + func.return %requantize : tensor<*x!tf_type.qint8> } } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantized_function_library_uniform_quantized_drq.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantized_function_library_uniform_quantized_drq.mlir index 8588eede7f2..a2db0c25f10 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantized_function_library_uniform_quantized_drq.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantized_function_library_uniform_quantized_drq.mlir @@ -18,25 +18,73 @@ // Internal functions should be marked as private. They will be inlined and // deleted in `InsertQuantizedFunctionsPass`. // -// Function template can generate functions with different parameters. Ex: -// ``` -// parameters[ -// {"key1": "value11", "key2": "value21"}, -// {"key1": "value12", "key2": "value22"}, -// ] -// func.func func_name_${key1}_fn (...) { -// ...${key2}... -// } -// ``` -// The above template with generate two functions by substituting `key1` and -// `key2` with given values. +// For Uniform Quantized op case, attributes are generated during quantize +// composite pass. Therefore, attr_map is set to an empty string. 
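Numerically, the hybrid (weight-only) ops wrapped by this library keep the input in float and dequantize the int8 weight on the fly before a float contraction. An illustrative reading in plain C++ (not the kernel implementation):

#include <cstddef>
#include <cstdint>
#include <vector>

float HybridDot(const std::vector<float>& input,
                const std::vector<int8_t>& weight,
                float weight_scale, int32_t weight_zp) {
  float acc = 0.0f;
  for (std::size_t i = 0; i < input.size() && i < weight.size(); ++i) {
    const float w =
        (static_cast<int32_t>(weight[i]) - weight_zp) * weight_scale;
    acc += input[i] * w;
  }
  return acc;
}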
module { - // TODO(b/238600711): Populate attributes for quantized_function_library_uniform_quantized + // Currently only 4-d case is supported + func.func @quantized_conv2d_fn( + %input : tensor<*xf32>, %weight : tensor<*x!tf_type.qint8>, + %weight_scale : tensor<*xf32>, %weight_zp : tensor<*xi32>) -> tensor<*xf32> + attributes {tf_quant.quantized_ops = ["Conv2D"]} { + + %out = "tf.UniformQuantizedConvolutionHybrid"(%input, %weight, + %weight_scale, %weight_zp) { + Tlhs = "tfdtype$DT_FLOAT", + Trhs = "tfdtype$DT_QINT8", + Tout = "tfdtype$DT_FLOAT", + window_strides = [1, 1], + padding = "", + explicit_padding = [], + lhs_dilation = [], + rhs_dilation = [], + dimension_numbers = "", + batch_group_count = 1, + feature_group_count = 1, + rhs_quantization_axis = -1, + rhs_quantization_min_val = -128, + rhs_quantization_max_val = 127, + attr_map = "" + } : (tensor<*xf32>, tensor<*x!tf_type.qint8>, tensor<*xf32>, tensor<*xi32>) -> tensor<*xf32> + + + func.return %out : tensor<*xf32> + } + + // Currently only 4-d case is supported + func.func @quantized_depthwise_conv2d_fn( + %input : tensor<*xf32>, %weight : tensor<*x!tf_type.qint8>, + %weight_scale : tensor<*xf32>, %weight_zp : tensor<*xi32>) -> tensor<*xf32> + attributes {tf_quant.quantized_ops = ["DepthwiseConv2D"]} { + + %out = "tf.UniformQuantizedConvolutionHybrid"(%input, %weight, + %weight_scale, %weight_zp) { + Tlhs = "tfdtype$DT_FLOAT", + Trhs = "tfdtype$DT_QINT8", + Tout = "tfdtype$DT_FLOAT", + window_strides = [1, 1], + padding = "", + explicit_padding = [], + lhs_dilation = [], + rhs_dilation = [], + dimension_numbers = "", + batch_group_count = 1, + feature_group_count = 1, + rhs_quantization_axis = -1, + rhs_quantization_min_val = -128, + rhs_quantization_max_val = 127, + attr_map = "" + } : (tensor<*xf32>, tensor<*x!tf_type.qint8>, tensor<*xf32>, tensor<*xi32>) -> tensor<*xf32> + + func.return %out : tensor<*xf32> + } + + // Currently only 4-d case is supported func.func @quantized_matmul_fn( %input : tensor<*xf32>, %weight : tensor<*x!tf_type.qint8>, - %weight_scale : tensor<*xf32>, %weight_zp : tensor<*xi32>) -> tensor<*xf32> { + %weight_scale : tensor<*xf32>, %weight_zp : tensor<*xi32>) -> tensor<*xf32> + attributes {tf_quant.quantized_ops = ["MatMul"]} { %out = "tf.UniformQuantizedDotHybrid"(%input, %weight, %weight_scale, %weight_zp) { @@ -46,7 +94,7 @@ module { rhs_quantization_axis = -1, rhs_quantization_min_val = -128, rhs_quantization_max_val = 127, - attr_map = "0:Tlhs,1:Trhs,2:Tout,3:rhs_quantization_axis,4:rhs_quantization_min_val,5:rhs_quantization_max_val" + attr_map = "" } : (tensor<*xf32>, tensor<*x!tf_type.qint8>, tensor<*xf32>, tensor<*xi32>) -> tensor<*xf32> func.return %out : tensor<*xf32> diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/remove_var_init_by_const.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/remove_var_init_by_const.cc new file mode 100644 index 00000000000..b27a4456356 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/remove_var_init_by_const.cc @@ -0,0 +1,121 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" + +namespace mlir { +namespace quant { +namespace { + +using ::mlir::tf_saved_model::GetInitializerFunction; +using ::mlir::tf_saved_model::kTfSavedModelInitializerRestoreType; + +// A pass that removes `tf.AssignVariableOp(tf.VarHandleOp, tf.Const)` patterns +// from the initializer function (type = "restore_op"). +// +// Note: initializing values (`tf.Const`s) will be removed and this may result +// in an information loss and uninitialized variable errors. Make sure that this +// effect is desired (e.g. there is a `tf.RestoreV2Op` restoring the variables +// instead). +class RemoveVariableInitializationByConstPass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + RemoveVariableInitializationByConstPass) + + StringRef getArgument() const final { + return "quant-remove-var-init-by-const"; + } + + StringRef getDescription() const final { + return "Removes `tf.AssignVariableOp(tf.VarHandleOp, tf.Const)` patterns " + "from the initializer function of type 'restore_op'."; + } + + void runOnOperation() override; +}; + +// Finds and removes the `tf.AssignVariableOp(tf.VarHandleOp, tf.Const)` +// pattern. `tf.VarHandleOp` and `tf.Const` are removed unless they are used by +// other ops. +struct RemoveVariableAssignmentByConst + : public OpRewritePattern { + // Inherit the constructors. + using OpRewritePattern::OpRewritePattern; + + LogicalResult match(TF::AssignVariableOp assign_op) const override { + Value resource_operand = assign_op.getOperand(0); + Value assigned_value_operand = assign_op.getOperand(1); + + if (isa(resource_operand.getDefiningOp()) && + isa(assigned_value_operand.getDefiningOp())) { + return success(); + } else { + return failure(); + } + } + + void rewrite(TF::AssignVariableOp assign_op, + PatternRewriter& rewriter) const override { + // `TF::ConstOp` and `TF::VarHandleOp` are not manually erased. + // `applyPatternsAndFoldGreedily` performs dead code elimination and unsed + // ops will be erased during the optimization. 
+ rewriter.eraseOp(assign_op); + } +}; + +void RemoveVariableInitializationByConstPass::runOnOperation() { + MLIRContext& ctx = getContext(); + + RewritePatternSet patterns(&ctx); + patterns.add(&ctx); + + ModuleOp module_op = getOperation(); + func::FuncOp init_func_op = GetInitializerFunction( + module_op, /*initializer_type=*/kTfSavedModelInitializerRestoreType); + if (init_func_op) { + if (failed( + applyPatternsAndFoldGreedily(init_func_op, std::move(patterns)))) { + init_func_op->emitError( + "Failed to remove variable assignment by const patterns."); + signalPassFailure(); + } + } else { + LOG(INFO) << "Initializer function with type 'restore_op' does not exist. " + "'RemoveVariableInitializationByConstPass' is a no-op."; + } +} + +static PassRegistration pass{}; + +} // namespace + +std::unique_ptr> +CreateRemoveVariableInitializationByConstPass() { + return std::make_unique(); +} +} // namespace quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.cc index fed951810f6..b8de48a954f 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.cc @@ -13,13 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include +#include #include #include +#include "absl/algorithm/container.h" #include "absl/strings/str_format.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project @@ -74,32 +80,143 @@ void PrepareXlaConvParams(OpBuilder &builder, Location loc, ArrayAttr strides, CreateScalarConstValue(builder, loc, feature_group_cnt); } -// Calculates zero-point offset by reducing weights and multiply it with zp. -Value CalculateZeroPointOffset(OpBuilder &builder, Location loc, Value filter, - int8_t input_zp, int output_dim) { - auto weight_shape = filter.getType().template cast(); - SmallVector weight_non_output_indices; - for (int64_t i : llvm::seq(0, weight_shape.getRank())) { - if (i != output_dim) weight_non_output_indices.push_back(i); +// Calculates other_tensor_zp * tensor for zero point offset calculation. 
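The helpers below split the zero-point correction of (int8_input - input_zp) * (int8_weight - weight_zp) into per-operand reductions plus a constant term; the identity they rely on can be checked numerically with a small standalone program (arbitrary illustrative values):

#include <cassert>
#include <cstdint>
#include <vector>

// z = input_zp, w = weight_zp, n = contracted size.
int main() {
  const std::vector<int32_t> p = {3, -7, 12, 5};   // one row of the input
  const std::vector<int32_t> q = {-2, 9, 4, -11};  // one column of the weight
  const int32_t z = 5, w = -3;
  const int32_t n = static_cast<int32_t>(p.size());

  int32_t direct = 0, raw_dot = 0, sum_p = 0, sum_q = 0;
  for (int32_t i = 0; i < n; ++i) {
    direct += (p[i] - z) * (q[i] - w);
    raw_dot += p[i] * q[i];
    sum_p += p[i];
    sum_q += q[i];
  }
  const int32_t offset = z * sum_q + w * sum_p - n * z * w;
  assert(direct == raw_dot - offset);  // matmul(P - z, Q - w) == raw dot - offset
  return 0;
}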
+Value CreateZeroPointPartialOffset(OpBuilder &builder, Location loc, + Value tensor, int8_t other_tensor_zp, + const ArrayRef output_dims) { + if (other_tensor_zp == 0) { + return CreateScalarConstValue(builder, loc, 0); + } + + auto shape = tensor.getType().template cast(); + SmallVector non_output_indices; + for (int64_t i : llvm::seq(0, shape.getRank())) { + if (absl::c_count(output_dims, i) == 0) { + non_output_indices.push_back(i); + } } - Value reduction_indices_value = - Create1DConstValue(builder, loc, weight_non_output_indices); - Value zp = CreateScalarConstValue(builder, loc, input_zp); - - TensorType filter_type = filter.getType().dyn_cast(); - Value filter_i32 = builder.create( - loc, filter_type.clone(builder.getIntegerType(32)), filter); - auto zp_mul_output_type = RankedTensorType::get( - {weight_shape.getDimSize(output_dim)}, builder.getIntegerType(32)); - auto reduced = builder.create( - loc, zp_mul_output_type, filter_i32, reduction_indices_value, - /*keep_dims=*/builder.getBoolAttr(false)); - TF::MulOp mul_op = builder.create(loc, zp, reduced); + auto reduction_indices_value = + Create1DConstValue(builder, loc, non_output_indices); + auto zp = CreateScalarConstValue(builder, loc, other_tensor_zp); + + TensorType tensor_type = tensor.getType().dyn_cast(); + Value tensor_i32 = builder.create( + loc, tensor_type.clone(builder.getIntegerType(32)), tensor); + auto reduced = + builder.create(loc, tensor_i32, reduction_indices_value, + /*keep_dims=*/builder.getBoolAttr(true)); + auto mul_op = builder.create(loc, zp, reduced); + llvm::SmallVector folded_results = ConstantFoldOpIfPossible(mul_op); return folded_results.front(); } +// Calculates zero-point offset by reducing the weight and multiply it with zp. +// Originally, we have: +// output = (int8_input - input_zp) * (int8_weight - weight_zp) +// So, offset = input_zp * int8_weight + weight_zp * int8_input +// - input_zp * weight_zp. +// This function calculates the `offset` value mentioned above. Note that the +// `output_dims` is the weight dimensions that are not contracted, so they +// appear in the output shape. +Value CalculateZeroPointOffset(OpBuilder &builder, Location loc, Value input, + Value weight, int8_t input_zp, int8_t weight_zp, + const ArrayRef input_output_dims, + const ArrayRef weight_output_dims) { + Value zp_input_contribution = CreateZeroPointPartialOffset( + builder, loc, input, weight_zp, input_output_dims); + Value zp_weight_contribution = CreateZeroPointPartialOffset( + builder, loc, weight, input_zp, weight_output_dims); + + auto weight_shape = weight.getType().template cast(); + SmallVector weight_non_output_indices; + for (auto i : llvm::seq(0, weight_shape.getRank())) { + if (absl::c_count(weight_output_dims, i) == 0) { + weight_non_output_indices.push_back(i); + } + } + + if (input_zp != 0 && weight_zp != 0) { + // Add two contributions, and a zeropoint modification term + // Consider two quantized matrices P, Q with zero points z, w. Let's say the + // dimensions are l X n, n X m. + // What we want to calculate is: R = matmul(P-z, Q-w). + // Then r_ij = sigma(k) (p_ik - z) * (q_kj - w) + // = sigma(k)(p_ik * q_kj) - w * sigma(k)p_ik - z * sigma(k)q_kj + // + sigma(k)z*w. 
+ // zp_input_contribution = z * sigma(k)q_kj + // zp_weight_contribution = w * sigma(k)p_ik + // In case z != 0 and w != 0, we need to additionally calculate sigma(k)z*w, + // which is: # of reduced dim(n in this case) * input_zp * weight_zp + int32_t static_dim_total = 1; + Value accum_dynamic_dim = nullptr; + llvm::SmallVector weight_non_output_dynamic_indices; + for (const int64_t weight_idx : weight_non_output_indices) { + if (weight_shape.isDynamicDim(weight_idx)) { + weight_non_output_dynamic_indices.push_back(weight_idx); + } else { + static_dim_total *= weight_shape.getDimSize(weight_idx); + } + } + + if (!weight_non_output_dynamic_indices.empty()) { + // Has dynamic shapes. + auto weight_shape_op = builder.create( + loc, weight, /*use32Bit=*/builder.getBoolAttr(false)); + + auto slice_output_type = RankedTensorType::get({1}, builder.getI64Type()); + auto slice_stride = CreateConstValue(builder, loc, {1}, {1}); + for (int64_t weight_idx : weight_non_output_dynamic_indices) { + auto start = CreateConstValue(builder, loc, {1}, {weight_idx}); + auto end = + CreateConstValue(builder, loc, {1}, {weight_idx + 1}); + auto sliced_shape_op = builder.create( + loc, slice_output_type, weight_shape_op, start, end, slice_stride); + if (accum_dynamic_dim == nullptr) { + accum_dynamic_dim = sliced_shape_op->getResults().front(); + } else { + accum_dynamic_dim = + builder + .create(loc, accum_dynamic_dim, sliced_shape_op) + ->getResults() + .front(); + } + } + } + + const int32_t zp_constant_offset = static_cast(input_zp) * + static_cast(weight_zp) * + static_dim_total; + auto zp_offset_value = + CreateScalarConstValue(builder, loc, zp_constant_offset); + if (accum_dynamic_dim != nullptr) { + accum_dynamic_dim = + builder + .create( + loc, mlir::RankedTensorType::get({1}, builder.getI32Type()), + accum_dynamic_dim) + ->getResults() + .front(); + auto mul_op = + builder.create(loc, accum_dynamic_dim, zp_offset_value); + zp_offset_value = mul_op->getResults().front(); + } + + auto offset_sum = builder.create(loc, zp_input_contribution, + zp_weight_contribution); + auto offset_op = + builder.create(loc, offset_sum, zp_offset_value); + + llvm::SmallVector folded_results = + ConstantFoldOpIfPossible(offset_op); + return folded_results.front(); + } + + if (input_zp != 0) return zp_weight_contribution; + return zp_input_contribution; +} + // Helper function to create a XlaConvV2Op for Conv2DOp, DepthwiseConv2DOp and // Conv3DOp. Value CreateXlaConvOp(OpBuilder &builder, Location loc, Value input, @@ -170,16 +287,21 @@ Value CreateXlaConvOp(OpBuilder &builder, Location loc, Value input, rhs_dilation, feature_group_count, builder.getStringAttr(dnums.SerializeAsString()), /*precision_config=*/builder.getStringAttr(precision_config_str)) - .output(); + .getOutput(); + + // Dynamic-range quantization wil always fall into this case. if (input_zp_value == 0) return xla_conv_output; - Value zp_offset = CalculateZeroPointOffset(builder, loc, /*filter=*/filter, - /*input_zp=*/input_zp_value, - /*output_dim=*/num_dims - 1); - return builder.create(loc, xla_conv_output, zp_offset).z(); + Value zp_offset = CalculateZeroPointOffset( + builder, loc, input, filter, input_zp_value, + /*weight_zp=*/0, + /*input_output_dims=*/ArrayRef({0}), + /*weight_output_dims=*/ArrayRef({num_dims - 1})); + return builder.create(loc, xla_conv_output, zp_offset).getZ(); } -// Creates a XlaConvV2Op from TF Conv2DOp and returns its output. +// Creates a XlaConvV2Op from TF Conv2DOp and returns its output. 
The returned +// value will be used as an input of the next op. Value CreateXlaConvOpFromTfConv2dOp(OpBuilder &builder, Location loc, Value input, Value filter, Value input_zp, Value conv_output, ArrayAttr strides, @@ -252,11 +374,18 @@ Value CreateXlaConvOpFromTfConv3dOp(OpBuilder &builder, Location loc, // Helper function to create an XlaDotV2Op. Value CreateXlaDotV2Op(OpBuilder &builder, Location loc, Value input, - Value weight, Value input_zp, Value output, - const xla::DotDimensionNumbers &dnums, + Value weight, Value input_zp, Value weight_zp, + Value output, const xla::DotDimensionNumbers &dnums, bool four_bit = false) { - int32_t input_zp_value; - if (!GetSplatValue(input_zp, input_zp_value)) { + int32_t input_zp_value = 0; + int32_t weight_zp_value = 0; + if (input_zp != nullptr && !GetSplatValue(input_zp, input_zp_value)) { + emitError(loc, + "zero point is expected to be a constant with a single value"); + return {}; + } + + if (weight_zp != nullptr && !GetSplatValue(weight_zp, weight_zp_value)) { emitError(loc, "zero point is expected to be a constant with a single value"); return {}; @@ -271,6 +400,7 @@ Value CreateXlaDotV2Op(OpBuilder &builder, Location loc, Value input, precision_config.add_operand_precision(xla::PrecisionConfig::PACKED_NIBBLE); precision_config_str = precision_config.SerializeAsString(); } + Value dot_result = builder .create( @@ -282,15 +412,26 @@ Value CreateXlaDotV2Op(OpBuilder &builder, Location loc, Value input, /*precision_config=*/builder.getStringAttr(precision_config_str)) .getResult(); - Value zp_offset = - CalculateZeroPointOffset(builder, loc, weight, input_zp_value, - /*output_dim=*/1); + auto input_shape = input.getType().template cast(); + auto weight_shape = weight.getType().template cast(); + SmallVector input_output_dims(input_shape.getRank() - 2); + SmallVector weight_output_dims(weight_shape.getRank() - 2); + absl::c_iota(input_output_dims, 0); + absl::c_iota(weight_output_dims, 0); + input_output_dims.push_back(weight_shape.getRank() - 2); + weight_output_dims.push_back(weight_shape.getRank() - 1); + + Value zp_offset = CalculateZeroPointOffset( + builder, loc, input, weight, input_zp_value, weight_zp_value, + ArrayRef(input_output_dims), + ArrayRef(weight_output_dims)); return builder.create(loc, dot_result, zp_offset); } Value CreateXlaDotV2OpFromTfMatMulOp(OpBuilder &builder, Location loc, Value input, Value weight, Value input_zp, - Value output, BoolAttr transpose_a, + Value weight_zp, Value output, + BoolAttr transpose_a, BoolAttr transpose_b) { // Transpose and constant-fold the weight if needed. if (transpose_b.getValue()) { @@ -307,7 +448,174 @@ Value CreateXlaDotV2OpFromTfMatMulOp(OpBuilder &builder, Location loc, dnums.add_lhs_contracting_dimensions(1); } - return CreateXlaDotV2Op(builder, loc, input, weight, input_zp, output, dnums); + return CreateXlaDotV2Op(builder, loc, input, weight, input_zp, weight_zp, + output, dnums); +} + +// Gets the broadcasted shapes of the input and weight of the BatchMatMul op +// from their types. If there are dynamic dimesions, these shapes couldn't be +// used as the arguments for the BroadcastTo ops. 
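For the static-shape case handled below, the batch dimensions follow the usual right-aligned broadcast rule (a 1 stretches to match the other operand), and the two trailing matmul dimensions are appended unchanged. A plain-integer sketch of that rule, with illustrative names:

#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

std::optional<std::vector<int64_t>> BroadcastDims(std::vector<int64_t> a,
                                                  std::vector<int64_t> b) {
  if (a.size() < b.size()) std::swap(a, b);
  std::vector<int64_t> out = a;
  const std::size_t offset = a.size() - b.size();
  for (std::size_t i = 0; i < b.size(); ++i) {
    const int64_t x = a[offset + i], y = b[i];
    if (x == y || y == 1) {
      out[offset + i] = x;
    } else if (x == 1) {
      out[offset + i] = y;
    } else {
      return std::nullopt;  // incompatible batch dimensions
    }
  }
  return out;  // e.g. BroadcastDims({2, 1, 3}, {5, 3}) -> {2, 5, 3}
}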
+llvm::Optional, SmallVector>> +GetBroadcastShapesForBatchMatmul(ShapedType input_type, + ShapedType weight_type) { + ArrayRef input_shape = input_type.getShape(); + ArrayRef weight_shape = weight_type.getShape(); + + const int64_t num_matmul_dim = 2; + const int64_t num_input_batch_dim = input_type.getRank() - num_matmul_dim; + const int64_t num_weight_batch_dim = weight_type.getRank() - num_matmul_dim; + + ArrayRef input_batch_dims = + input_shape.slice(0, num_input_batch_dim); + ArrayRef weight_batch_dims = + weight_shape.slice(0, num_weight_batch_dim); + ArrayRef input_matmul_dims = + input_shape.slice(num_input_batch_dim, num_matmul_dim); + ArrayRef weight_matmul_dims = + weight_shape.slice(num_weight_batch_dim, num_matmul_dim); + + SmallVector broadcasted_batch_dims; + if (!OpTrait::util::getBroadcastedShape(input_batch_dims, weight_batch_dims, + broadcasted_batch_dims)) { + return std::nullopt; + } + SmallVector broadcasted_input_shape(broadcasted_batch_dims); + broadcasted_input_shape.append(input_matmul_dims.begin(), + input_matmul_dims.end()); + SmallVector broadcasted_weight_shape(broadcasted_batch_dims); + broadcasted_weight_shape.append(weight_matmul_dims.begin(), + weight_matmul_dims.end()); + + return std::make_pair(std::move(broadcasted_input_shape), + std::move(broadcasted_weight_shape)); +} + +// Broadcasts batch dimensions of the input and weight of the BatchMatMul +// op. In XLA, shapes are all constants, so all operations created in this +// function, except BroadcastTo, are expected to be folded. +void BroadcastBatchDimensionsForBatchMatMul(OpBuilder &builder, Location loc, + Value &input, Value &weight) { + ShapedType input_type = input.getType().template cast(); + ShapedType weight_type = weight.getType().template cast(); + const int32_t input_rank = input_type.getRank(); + const int32_t weight_rank = weight_type.getRank(); + const int32_t broadcasted_rank = std::max(input_rank, weight_rank); + + const int32_t num_matmul_dim = 2; + const int32_t num_input_batch_dim = input_rank - num_matmul_dim; + const int32_t num_weight_batch_dim = weight_rank - num_matmul_dim; + if (num_input_batch_dim == 0 && num_weight_batch_dim == 0) return; + + // If the broadcasted shapes can be calculated statically, only add two + // BroadcastTo ops for input and weight. + auto broadcasted_shapes_or = + GetBroadcastShapesForBatchMatmul(input_type, weight_type); + if (!broadcasted_shapes_or.has_value()) return; + const auto broadcasted_input_type = RankedTensorType::get( + broadcasted_shapes_or->first, input_type.getElementType()); + const auto broadcasted_weight_type = RankedTensorType::get( + broadcasted_shapes_or->second, weight_type.getElementType()); + + if (broadcasted_input_type.hasStaticShape() && + broadcasted_weight_type.hasStaticShape()) { + input = builder.create( + loc, broadcasted_input_type, input, + Create1DConstValue(builder, loc, broadcasted_shapes_or->first)); + weight = builder.create( + loc, broadcasted_weight_type, weight, + Create1DConstValue(builder, loc, broadcasted_shapes_or->second)); + return; + } + + const Value zero = Create1DConstValue(builder, loc, {0}); + const Value num_matmul_dim_value = + Create1DConstValue(builder, loc, {num_matmul_dim}); + const Value num_input_batch_dim_value = + Create1DConstValue(builder, loc, {num_input_batch_dim}); + const Value num_weight_batch_dim_value = + Create1DConstValue(builder, loc, {num_weight_batch_dim}); + + // Decompose the input and weight shape into batch and matmul dimensions. 
+ Value input_shape = builder.create( + loc, input, /*use32Bit=*/builder.getBoolAttr(false)); + Value input_batch_dims = builder.create( + loc, RankedTensorType::get({num_input_batch_dim}, builder.getI64Type()), + input_shape, zero, num_input_batch_dim_value); + Value input_matmul_dims = builder.create( + loc, RankedTensorType::get({num_matmul_dim}, builder.getI64Type()), + input_shape, num_input_batch_dim_value, num_matmul_dim_value); + + Value weight_shape = builder.create( + loc, weight, /*use32Bit=*/builder.getBoolAttr(false)); + Value weight_batch_dims = builder.create( + loc, RankedTensorType::get({num_weight_batch_dim}, builder.getI64Type()), + weight_shape, zero, num_weight_batch_dim_value); + Value weight_matmul_dims = builder.create( + loc, RankedTensorType::get({num_matmul_dim}, builder.getI64Type()), + weight_shape, num_weight_batch_dim_value, num_matmul_dim_value); + + // Calculate the broadcasted shapes. + Value broadcasted_batch_dims = builder.create( + loc, + RankedTensorType::get({broadcasted_rank - num_matmul_dim}, + builder.getI64Type()), + input_batch_dims, weight_batch_dims); + Type broadcasted_shape_type = + RankedTensorType::get({broadcasted_rank}, builder.getI64Type()); + + const Value zero_scalar = CreateScalarConstValue(builder, loc, 0); + Value broacasted_input_shape = builder.create( + loc, broadcasted_shape_type, /*concat_dim=*/zero_scalar, + ValueRange{broadcasted_batch_dims, input_matmul_dims}); + Value broacasted_weight_shape = builder.create( + loc, broadcasted_shape_type, /*concat_dim=*/zero_scalar, + ValueRange{broadcasted_batch_dims, weight_matmul_dims}); + + // Broadcast input and weight with the calculated shapes. + input = builder.create(loc, broadcasted_input_type, input, + broacasted_input_shape); + weight = builder.create(loc, broadcasted_weight_type, + weight, broacasted_weight_shape); +} + +Value CreateXlaDotV2OpFromTfBatchMatMulOp(OpBuilder &builder, Location loc, + Value input, Value weight, + Value input_zp, Value weight_zp, + Value output, BoolAttr adj_x, + BoolAttr adj_y) { + // TensorFlow BatchMatMulOp allows the batch dimensions to be broadcastable + // while the XlaDotV2Op doesn't. So we have to broadcast them beforehand. + BroadcastBatchDimensionsForBatchMatMul(builder, loc, input, weight); + + // Both input and weight have the same rank after broadcasting. + ShapedType weight_shape = weight.getType().template cast(); + int num_batch_dim = weight_shape.getRank() - 2; + + // Transpose and constant-fold the weight if needed. 
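The adj_y handling below builds a transpose permutation that keeps the batch dimensions in place and swaps the two trailing matmul dimensions. As a standalone sketch of that permutation (illustrative helper name):

#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

std::vector<int64_t> BatchMatMulTransposePerm(int64_t batch_rank) {
  std::vector<int64_t> perm(batch_rank + 2);
  std::iota(perm.begin(), perm.end(), 0);        // 0, 1, ..., batch_rank + 1
  std::swap(perm[batch_rank], perm[batch_rank + 1]);
  return perm;  // e.g. batch_rank = 2 -> {0, 1, 3, 2}
}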
+ if (adj_y.getValue()) { + SmallVector perm_values(num_batch_dim); + absl::c_iota(perm_values, 0); + perm_values.push_back(num_batch_dim + 1); + perm_values.push_back(num_batch_dim); + Value perm = Create1DConstValue(builder, loc, perm_values); + auto transpose_op = builder.create(loc, weight, perm); + weight = ConstantFoldOpIfPossible(transpose_op).front(); + } + + xla::DotDimensionNumbers dnums; + for (int i : llvm::seq(0, num_batch_dim)) { + dnums.add_lhs_batch_dimensions(i); + dnums.add_rhs_batch_dimensions(i); + } + dnums.add_rhs_contracting_dimensions(num_batch_dim); + if (adj_x.getValue()) { + dnums.add_lhs_contracting_dimensions(num_batch_dim); + } else { + dnums.add_lhs_contracting_dimensions(num_batch_dim + 1); + } + + return CreateXlaDotV2Op(builder, loc, input, weight, input_zp, weight_zp, + output, dnums); } #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.inc" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.td index 49184b95bca..7e679ff4e9c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.td +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.td @@ -31,6 +31,9 @@ def CreateXlaDotV2OpFromTfMatMulOp : NativeCodeCall< def CreateXLAConvOpFromTFConv3DOp : NativeCodeCall< "CreateXlaConvOpFromTfConv3dOp($_builder, $_loc, $0...)">; +def CreateXlaDotV2OpFromTfBatchMatMulOp : NativeCodeCall< + "CreateXlaDotV2OpFromTfBatchMatMulOp($_builder, $_loc, $0...)">; + // Converts inlined Conv2D pattern to TF XlaConvV2 op. This pattern doesn't // support non-constant weights. def ConvertTFConv2DToXLAConvOp : Pat< @@ -43,6 +46,27 @@ def ConvertTFConv2DToXLAConvOp : Pat< $input, $filter, $input_zp, $conv, $strides, $dilations, $padding, $explicit_padding), [(IsInt8ElementType $input), + (IsInt8ElementType $filter), + (IsConstTensor $input_zp), + (IsConstTensor $filter), + (IsInt32ElementType $conv), + (HasStaticShapeConstraint $filter), + (HasStaticShapeAtDimsConstraint<"3"> $input)], + (addBenefit 10)>; + +// Same as ConvertTFConv2DToXLAConvOp but handles the case where the input zero +// point is dynamically calculated and thus not a constant.
+def ConvertTFConv2DToXLAConvOpDynamicRange : Pat< + (TF_Conv2DOp:$conv + (TF_SubOp:$input (TF_CastOp $input_i8, $truncate0), $input_zp), + (TF_CastOp (TF_IdentityOp $filter), $truncate1), + $strides, $use_cudnn, $padding, $explicit_padding, + IsDataFormatNHWC:$data_format, $dilations), + (CreateXLAConvOpFromTFConv2DOp + $input, $filter, /*input_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + $conv, $strides, + $dilations, $padding, $explicit_padding), + [(IsInt32ElementType $input), (IsInt8ElementType $filter), (IsConstTensor $filter), (IsInt32ElementType $conv), @@ -74,7 +98,7 @@ def ConvertTFConv2DWithNoZeroPointToXLAConvOp : Pat< def ConvertTFDepthwiseConv2DToXLAConvOp : Pat< (TF_CastOp:$conv (TF_DepthwiseConv2dNativeOp - (TF_CastOp:$casted_input + (TF_CastOp:$cast_input (TF_SubOp (TF_CastOp $input, $truncate1), $input_zp), $truncate2), (TF_CastOp (TF_CastOp (TF_IdentityOp $filter), $truncate3), $truncate4), @@ -84,7 +108,31 @@ def ConvertTFDepthwiseConv2DToXLAConvOp : Pat< $input, $filter, $input_zp, $conv, $strides, $dilations, $padding, $explicit_padding), [(IsInt8ElementType $input), - (IsF32ElementType $casted_input), + (IsF32ElementType $cast_input), + (IsInt8ElementType $filter), + (IsConstTensor $input_zp), + (IsConstTensor $filter), + (IsInt32ElementType $conv), + (HasStaticShapeConstraint $filter), + (HasStaticShapeAtDimsConstraint<"3"> $input)], + (addBenefit 10)>; + +// Same as ConvertTFDepthwiseConv2DToXLAConvOp but handles the case where the input +// zero point is dynamically calculated and thus not a constant. +def ConvertTFDepthwiseConv2DToXLAConvOpDynamicRange : Pat< + (TF_CastOp:$conv + (TF_DepthwiseConv2dNativeOp + (TF_CastOp + (TF_SubOp:$input (TF_CastOp $input_i8, $truncate0), $input_zp), $truncate1), + (TF_CastOp + (TF_CastOp (TF_IdentityOp $filter), $truncate2), $truncate3), + $strides, $padding, $explicit_padding, + IsDataFormatNHWC:$data_format, $dilations), $truncate4), + (CreateXLAConvOpFromTFDepthwiseConv2DOp + $input, $filter, /*input_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + $conv, $strides, + $dilations, $padding, $explicit_padding), + [(IsInt32ElementType $input), (IsInt8ElementType $filter), (IsConstTensor $filter), (IsInt32ElementType $conv), @@ -97,7 +145,7 @@ def ConvertTFDepthwiseConv2DWithNoZeroPointToXLAConvOp : Pat< (TF_CastOp:$conv (TF_DepthwiseConv2dNativeOp - (TF_CastOp:$casted_input + (TF_CastOp:$cast_input (TF_CastOp $input, $truncate1), $truncate2), (TF_CastOp (TF_CastOp (TF_IdentityOp $filter), $truncate3), $truncate4), @@ -107,7 +155,7 @@ def ConvertTFDepthwiseConv2DWithNoZeroPointToXLAConvOp : Pat< $input, $filter, /*input_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), $conv, $strides, $dilations, $padding, $explicit_padding), [(IsInt8ElementType $input), - (IsF32ElementType $casted_input), + (IsF32ElementType $cast_input), (IsInt8ElementType $filter), (IsConstTensor $filter), (IsInt32ElementType $conv), @@ -124,8 +172,29 @@ def ConvertTFMatMulToXLADotV2Op : Pat< (TF_CastOp (TF_IdentityOp $weight), $truncate1), $transpose_a, $transpose_b), (CreateXlaDotV2OpFromTfMatMulOp - $input, $weight, $input_zp, $matmul, $transpose_a, $transpose_b), + $input, $weight, $input_zp, + /*weight_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), $matmul, + $transpose_a, $transpose_b), [(IsInt8ElementType $input), + (IsInt8ElementType $weight), + (IsConstTensor $input_zp), + (IsConstTensor $weight), + (IsInt32ElementType $matmul), + (HasStaticShapeConstraint $weight)], + (addBenefit 10)>; + +// Same as 
ConvertTFMatMulToXLADotV2Op but handles the case where the input zero +// point is dynamically calculated and thus not a constant. +def ConvertTFMatMulToXLADotV2OpDynamicRange : Pat< + (TF_MatMulOp:$matmul + (TF_SubOp:$input (TF_CastOp $input_i8, $truncate0), $input_zp), + (TF_CastOp (TF_IdentityOp $weight), $truncate1), + $transpose_a, $transpose_b), + (CreateXlaDotV2OpFromTfMatMulOp + $input, $weight, /*input_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + /*weight_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + $matmul, $transpose_a, $transpose_b), + [(IsInt32ElementType $input), (IsInt8ElementType $weight), (IsConstTensor $weight), (IsInt32ElementType $matmul), @@ -141,6 +210,7 @@ def ConvertTFMatMulWithNoZeroPointToXLADotV2Op : Pat< $transpose_a, $transpose_b), (CreateXlaDotV2OpFromTfMatMulOp $input, $weight, /*input_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + /*weight_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), $matmul, $transpose_a, $transpose_b), [(IsInt8ElementType $input), (IsInt8ElementType $weight), @@ -149,12 +219,86 @@ (HasStaticShapeConstraint $weight)], (addBenefit 10)>; +// Converts inlined MatMul pattern to TF XlaDotV2 op. This pattern supports +// non-constant weights. +def ConvertTFMatMulWithTwoInputTensorsToXLADotV2Op : Pat< + (TF_MatMulOp:$matmul + (TF_SubOp (TF_CastOp $input, $truncate1), $input_zp), + (TF_SubOp (TF_CastOp $weight, $truncate2), $weight_zp), + $transpose_a, $transpose_b), + (CreateXlaDotV2OpFromTfMatMulOp + $input, $weight, $input_zp, $weight_zp, $matmul, $transpose_a, $transpose_b), + [(IsInt8ElementType $input), + (IsInt8ElementType $weight), + (HasRank $input), + (HasRank $weight), + (HasRankOf<0> $input_zp), + (HasRankOf<0> $weight_zp), + (IsInt32ElementType $matmul)], + (addBenefit 10)>; + +// Same as ConvertTFMatMulWithTwoInputTensorsToXLADotV2Op but handles the case +// where input zero point is 0 and the Sub op has been folded. +def ConvertTFMatMulWithTwoInputTensorsAndNoInputZeroPointToXLADotV2Op : Pat< + (TF_MatMulOp:$matmul + (TF_CastOp $input, $truncate), + (TF_SubOp (TF_CastOp $weight, $truncate2), $weight_zp), + $transpose_a, $transpose_b), + (CreateXlaDotV2OpFromTfMatMulOp + $input, $weight, /*input_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + $weight_zp, $matmul, $transpose_a, $transpose_b), + [(IsInt8ElementType $input), + (IsInt8ElementType $weight), + (HasRank $input), + (HasRank $weight), + (HasRankOf<0> $weight_zp), + (IsInt32ElementType $matmul)], + (addBenefit 10)>; + +// Same as ConvertTFMatMulWithTwoInputTensorsToXLADotV2Op but handles the case +// where weight zero point is 0 and the Sub op has been folded. +def ConvertTFMatMulWithTwoInputTensorsAndNoWeightZeroPointToXLADotV2Op : Pat< + (TF_MatMulOp:$matmul + (TF_SubOp (TF_CastOp $input, $truncate), $input_zp), + (TF_CastOp $weight, $truncate1), + $transpose_a, $transpose_b), + (CreateXlaDotV2OpFromTfMatMulOp + $input, $weight, $input_zp, + /*weight_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + $matmul, $transpose_a, $transpose_b), + [(IsInt8ElementType $input), + (IsInt8ElementType $weight), + (HasRank $input), + (HasRank $weight), + (HasRankOf<0> $input_zp), + (IsInt32ElementType $matmul)], + (addBenefit 10)>; + +// Same as ConvertTFMatMulWithTwoInputTensorsToXLADotV2Op but handles the case +// where both zero points are 0 and the Sub op has been folded. 
+def ConvertTFMatMulWithTwoInputTensorsAndNoBothZeroPointsToXLADotV2Op : Pat< + (TF_MatMulOp:$matmul + (TF_CastOp $input, $truncate), + (TF_CastOp $weight, $truncate1), + $transpose_a, $transpose_b), + (CreateXlaDotV2OpFromTfMatMulOp + $input, $weight, /*input_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + /*weight_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + $matmul, $transpose_a, $transpose_b), + [(IsInt8ElementType $input), + (IsInt8ElementType $weight), + (HasRank $input), + (HasRank $weight), + (IsInt32ElementType $matmul)], + (addBenefit 10)>; + + // Converts inlined Conv3D pattern to TF XlaConvV2 op. This pattern // doesn't support non-constant weights. def ConvertTFConv3DToXLAConvOp : Pat< (TF_CastOp:$conv (TF_Conv3DOp - (TF_CastOp:$casted_input + (TF_CastOp:$cast_input (TF_SubOp (TF_CastOp $input, $truncate1), $input_zp), $truncate2), (TF_CastOp (TF_CastOp (TF_IdentityOp $filter), $truncate3), $truncate4), @@ -163,7 +307,7 @@ def ConvertTFConv3DToXLAConvOp : Pat< (CreateXLAConvOpFromTFConv3DOp $input, $filter, $input_zp, $conv, $strides, $dilations, $padding), [(IsInt8ElementType $input), - (IsF32ElementType $casted_input), + (IsF32ElementType $cast_input), (IsInt8ElementType $filter), (IsConstTensor $filter), (IsInt32ElementType $conv), @@ -176,7 +320,7 @@ def ConvertTFConv3DToXLAConvOp : Pat< def ConvertTFConv3DWithNoZeroPointToXLAConvOp : Pat< (TF_CastOp:$conv (TF_Conv3DOp - (TF_CastOp:$casted_input + (TF_CastOp:$cast_input (TF_CastOp $input, $truncate1), $truncate2), (TF_CastOp (TF_CastOp (TF_IdentityOp $filter), $truncate3), $truncate4), @@ -186,10 +330,126 @@ def ConvertTFConv3DWithNoZeroPointToXLAConvOp : Pat< $input, $filter, /*input_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), $conv, $strides, $dilations, $padding), [(IsInt8ElementType $input), - (IsF32ElementType $casted_input), + (IsF32ElementType $cast_input), (IsInt8ElementType $filter), (IsConstTensor $filter), (IsInt32ElementType $conv), (HasStaticShapeConstraint $filter), (HasStaticShapeAtDimsConstraint<"4"> $input)], (addBenefit 10)>; + +// Converts inlined BatchMatMul pattern to TF XlaDotV2 op. This pattern doesn't +// support non-constant weights. +def ConvertTFBatchMatMulToXLADotV2Op : Pat< + (TF_BatchMatMulV2Op:$batch_matmul + (TF_SubOp (TF_CastOp $input, $truncate), $input_zp), + (TF_CastOp (TF_IdentityOp $weight), $truncate1), + $adj_x, $adj_y), + (CreateXlaDotV2OpFromTfBatchMatMulOp + $input, $weight, $input_zp, + /*weight_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + $batch_matmul, $adj_x, $adj_y), + [(IsInt8ElementType $input), + (HasRank $input), + (IsInt8ElementType $weight), + (IsConstTensor $weight), + (IsInt32ElementType $batch_matmul), + (HasStaticShapeConstraint $weight)], + (addBenefit 10)>; + +// Same as ConvertTFBatchMatMulToXLADotV2Op but handles the case where input +// zero point is 0 and the Sub op has been folded. 
+def ConvertTFBatchMatMulWithNoZeroPointToXLADotV2Op : Pat< + (TF_BatchMatMulV2Op:$batch_matmul + (TF_CastOp $input, $truncate), + (TF_CastOp (TF_IdentityOp $weight), $truncate1), + $adj_x, $adj_y), + (CreateXlaDotV2OpFromTfBatchMatMulOp + $input, $weight, /*input_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + /*weight_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + $batch_matmul, $adj_x, $adj_y), + [(IsInt8ElementType $input), + (HasRank $input), + (IsInt8ElementType $weight), + (IsConstTensor $weight), + (IsInt32ElementType $batch_matmul), + (HasStaticShapeConstraint $weight)], + (addBenefit 10)>; + +// Converts inlined BatchMatMul pattern to TF XlaDotV2 op. Supports +// non-constant weights. +// TODO(b/263529454): Remove redundant identity of the rule input on the second +// argument. +def ConvertTFBatchMatMulWithTwoInputTensorsToXLADotV2Op : Pat< + (TF_BatchMatMulV2Op:$batch_matmul + (TF_SubOp (TF_CastOp $input, $truncate), $input_zp), + (TF_SubOp (TF_CastOp (TF_IdentityOp $weight), $truncate1), $weight_zp), + $adj_x, $adj_y), + (CreateXlaDotV2OpFromTfBatchMatMulOp + $input, $weight, $input_zp, $weight_zp, $batch_matmul, $adj_x, $adj_y), + [(IsInt8ElementType $input), + (IsInt8ElementType $weight), + (HasRank $input), + (HasRank $weight), + (HasRankOf<0> $input_zp), + (HasRankOf<0> $weight_zp), + (IsInt32ElementType $batch_matmul)], + (addBenefit 10)>; + +// Same as ConvertTFBatchMatMulWithTwoInputTensorsToXLADotV2Op but handles +// the case where input zero point is 0 and the Sub op has been folded. +def ConvertTFBatchMatMulWithTwoInputTensorsAndNoInputZeroPointToXLADotV2Op : Pat< + (TF_BatchMatMulV2Op:$batch_matmul + (TF_CastOp $input, $truncate), + (TF_SubOp (TF_CastOp (TF_IdentityOp $weight), $truncate1), $weight_zp), + $adj_x, $adj_y), + (CreateXlaDotV2OpFromTfBatchMatMulOp + $input, $weight, /*input_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + $weight_zp, $batch_matmul, $adj_x, $adj_y), + [(IsInt8ElementType $input), + (IsInt8ElementType $weight), + (HasRank $input), + (HasRank $weight), + (HasRankOf<0> $weight_zp), + (IsInt32ElementType $batch_matmul)], + (addBenefit 10)>; + +// Same as ConvertTFBatchMatMulWithTwoInputTensorsToXLADotV2Op but handles +// the case where weight zero point is 0 and the Sub op has been folded. +def ConvertTFBatchMatMulWithTwoInputTensorsAndNoWeightZeroPointToXLADotV2Op : Pat< + (TF_BatchMatMulV2Op:$batch_matmul + (TF_SubOp (TF_CastOp $input, $truncate1), $input_zp), + (TF_CastOp $weight, $truncate2), + $adj_x, $adj_y), + (CreateXlaDotV2OpFromTfBatchMatMulOp + $input, $weight, $input_zp, + /*weight_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + $batch_matmul, $adj_x, $adj_y), + [(IsInt8ElementType $input), + (IsInt8ElementType $weight), + (HasRank $input), + (HasRank $weight), + (HasRankOf<0> $input_zp), + (IsInt32ElementType $batch_matmul)], + (addBenefit 10)>; + +// Same as ConvertTFBatchMatMulWithTwoInputTensorsToXLADotV2Op but handles +// the case where both zero points are 0 and the Sub op has been folded. 
+def ConvertTFBatchMatMulWithTwoInputTensorsAndNoBothZeroPointsToXLADotV2Op : Pat< + (TF_BatchMatMulV2Op:$batch_matmul + (TF_CastOp $input, $truncate1), + (TF_CastOp $weight, $truncate2), + $adj_x, $adj_y), + (CreateXlaDotV2OpFromTfBatchMatMulOp + $input, $weight, /*input_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + /*weight_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + $batch_matmul, $adj_x, $adj_y), + [(IsInt8ElementType $input), + (IsInt8ElementType $weight), + (HasRank $input), + (HasRank $weight), + (IsInt32ElementType $batch_matmul)], + (addBenefit 10)>; + + + diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.td index efc863434ba..23991171576 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.td +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.td @@ -19,81 +19,6 @@ include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/IR/OpAsmInterface.td" -def TF_UniformQuantizedAddOp : TF_Op<"UniformQuantizedAdd", [Pure]> { - // TODO(b/230804708): Improve the operator description. - let summary = "Quantized add operator."; - - let arguments = (ins - TensorOf<[TF_Qint32]>:$lhs, - TensorOf<[TF_Qint32]>:$rhs, - TF_Float32Tensor:$lhs_scales, - TF_Int32Tensor:$lhs_zps, - TF_Float32Tensor:$rhs_scales, - TF_Int32Tensor:$rhs_zps, - TF_Float32Tensor:$output_scales, - TF_Int32Tensor:$output_zps, - - DefaultValuedOptionalAttr:$lhs_quantization_axis, - DefaultValuedOptionalAttr:$lhs_quantization_min_val, - DefaultValuedOptionalAttr:$lhs_quantization_max_val, - DefaultValuedOptionalAttr:$rhs_quantization_axis, - DefaultValuedOptionalAttr:$rhs_quantization_min_val, - DefaultValuedOptionalAttr:$rhs_quantization_max_val, - DefaultValuedOptionalAttr:$output_quantization_axis, - DefaultValuedOptionalAttr:$output_quantization_min_val, - DefaultValuedOptionalAttr:$output_quantization_max_val - ); - - let results = (outs - TensorOf<[TF_Qint32]>:$output - ); - - TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; -} - -// TODO(b/230804708): Add hybrid convolution. -def TF_UniformQuantizedConvolutionOp : TF_Op<"UniformQuantizedConvolution", [Pure]> { - // TODO(b/230804708): Improve the operator description. 
- let summary = "Quantized conv2d operator."; - - let arguments = (ins - TensorOf<[TF_Float32, TF_Qint8]>:$lhs, - TensorOf<[TF_Qint8]>:$rhs, - TF_Float32Tensor:$lhs_scales, - TF_Int32Tensor:$lhs_zps, - TF_Float32Tensor:$rhs_scales, - TF_Int32Tensor:$rhs_zps, - TF_Float32Tensor:$output_scales, - TF_Int32Tensor:$output_zps, - TF_Int32Tensor:$window_strides, - TF_Int32Tensor:$padding, - TF_Int32Tensor:$lhs_dilation, - TF_Int32Tensor:$rhs_dilation, - TF_Int32Tensor:$feature_group_count, - - StrAttr:$dimension_numbers, - DefaultValuedOptionalAttr:$batch_group_count, - DefaultValuedOptionalAttr:$lhs_quantization_axis, - DefaultValuedOptionalAttr:$lhs_quantization_min_val, - DefaultValuedOptionalAttr:$lhs_quantization_max_val, - DefaultValuedOptionalAttr:$rhs_quantization_axis, - DefaultValuedOptionalAttr:$rhs_quantization_min_val, - DefaultValuedOptionalAttr:$rhs_quantization_max_val, - DefaultValuedOptionalAttr:$output_quantization_axis, - DefaultValuedOptionalAttr:$output_quantization_min_val, - DefaultValuedOptionalAttr:$output_quantization_max_val - ); - - let results = (outs - TensorOf<[TF_Qint32]>:$output - ); - - TF_DerivedOperandTypeAttr LhsT = TF_DerivedOperandTypeAttr<0>; - TF_DerivedOperandTypeAttr RhsT = TF_DerivedOperandTypeAttr<1>; - TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<8>; - TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>; -} - // TODO(b/230804708): Add hybrid dot general. def TF_UniformQuantizedDotGeneralOp : TF_Op<"UniformQuantizedDotGeneral", [Pure]> { // TODO(b/230804708): Improve the operator description. diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/unfreeze_constants.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/unfreeze_constants.cc index 38b49d4c3c0..95eeabe78eb 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/unfreeze_constants.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/unfreeze_constants.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include @@ -31,6 +32,8 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Pass/PassOptions.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/const_op_size.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" @@ -41,18 +44,41 @@ namespace mlir { namespace quant { namespace { +using ::mlir::tf_saved_model::GetInitializerFunction; +using ::mlir::tf_saved_model::GetSessionInitializerOp; +using ::mlir::tf_saved_model::kTfSavedModelExportedNamesAttr; using ::mlir::tf_saved_model::kTfSavedModelInitializerRestoreType; using ::mlir::tf_saved_model::kTfSavedModelInitializerTypeAttr; using ::mlir::tf_saved_model::SessionInitializerOp; constexpr absl::string_view kDefaultConstName = "const"; +// The default lower threshold for the constant size for unfreezing. +constexpr int64_t kDefaultConstantSizeThresholdInBytes = 64 * 1024; // 64KiB + +// This pass "unfreezes" constants found in the module and converts them to +// `tf.VarHandleOp`s. 
Also, an initialization pattern +// `tf.AssignVariableOp(tf.VarHandleOp, tf.ConstOp)` is inserted into the +// initializer function of type "restore_op" for each of the unfrozen constants. +// +// The constants whose sizes are smaller than `size_threshold_in_bytes_` will +// not be converted to variables. class UnfreezeConstantsPass : public PassWrapper> { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(UnfreezeConstantsPass) - explicit UnfreezeConstantsPass() {} + explicit UnfreezeConstantsPass() + : UnfreezeConstantsPass(kDefaultConstantSizeThresholdInBytes) {} + + explicit UnfreezeConstantsPass(const int64_t size_threshold_in_bytes) + : size_threshold_in_bytes_( + CreateSizeThresholdInBytesOption(size_threshold_in_bytes)) {} + + UnfreezeConstantsPass(const UnfreezeConstantsPass& other) + : UnfreezeConstantsPass{} { + size_threshold_in_bytes_ = other.size_threshold_in_bytes_.getValue(); + } StringRef getArgument() const override { return "quant-unfreeze-constants"; } @@ -63,10 +89,23 @@ class UnfreezeConstantsPass void runOnOperation() override; private: + Option CreateSizeThresholdInBytesOption(const int64_t init_value) { + return Option( + *this, "size_threshold_in_bytes", llvm::cl::init(init_value), + llvm::cl::desc( + "Lower threshold of the constant size for unfreezing. Constants " + "smaller than this value will not be converted to variables.")); + } + void getDependentDialects(DialectRegistry& registry) const override { registry.insert(); } + + // Lower-bound threshold for the size of the constant in bytes. Constants + // smaller than this threshold will not be unfrozen and will remain as + // constants. + Option size_threshold_in_bytes_; }; // Adds the symbol to the "initializers" attribute of the session_initializer @@ -82,13 +121,31 @@ void AddSymbolToInitializersAttr(SessionInitializerOp session_init_op, ArrayAttr::get(session_init_op.getContext(), initializers_attrs)); } -// Create the initializer function right after the session_initializer op. +// Returns the session_initializer op in the module if it exists. Otherwise, +// creates a new session_initializer op and returns it. +SessionInitializerOp GetOrCreateSessionInitializerOp(ModuleOp module_op) { + SessionInitializerOp session_init_op = GetSessionInitializerOp(module_op); + + // Create one if it doesn't exist. + if (!session_init_op) { + OpBuilder builder(&module_op.getBodyRegion()); + + session_init_op = builder.create( + module_op.getLoc(), /*initializers=*/builder.getArrayAttr({})); + } + + return session_init_op; +} + +// Create the initializer function right after the SessionInitializer op. // Returns the newly created initializer function. The initializer function's // initializer_type is set to "restore_op" since it essentially serves as a // variable restoration function. 
-func::FuncOp CreateInitializerFunc(SymbolTable& symbol_table, - SessionInitializerOp session_init_op) { - OpBuilder builder{session_init_op.getContext()}; +func::FuncOp CreateInitializerFunc(ModuleOp module_op) { + SessionInitializerOp session_init_op = + GetOrCreateSessionInitializerOp(module_op); + + OpBuilder builder(module_op.getContext()); builder.setInsertionPointAfter(session_init_op); const Location loc = builder.getUnknownLoc(); @@ -99,7 +156,7 @@ func::FuncOp CreateInitializerFunc(SymbolTable& symbol_table, builder.createBlock(&init_func.getBody(), /*insertPt=*/init_func.begin(), /*arg_types=*/{}, /*arg_locs=*/{}); - init_func->setAttr("tf_saved_model.exported_names", + init_func->setAttr(kTfSavedModelExportedNamesAttr, builder.getStrArrayAttr( {"tf_saved_model.session_initializer_restore_op"})); init_func->setAttr( @@ -109,6 +166,7 @@ func::FuncOp CreateInitializerFunc(SymbolTable& symbol_table, builder.setInsertionPointToStart(&init_func.front()); builder.create(loc, /*operands=*/ValueRange{}); + SymbolTable symbol_table(module_op); symbol_table.insert(init_func); AddSymbolToInitializersAttr( @@ -126,42 +184,19 @@ bool IsInitializerType(func::FuncOp init_func_op, StringRef initializer_type) { } // Returns the initializer function whose tf_saved_model.initializer_type -// matches `initializer_type`. Creates and returns a new initializer function -// iff such FuncOp is not found. The newly created initializer function's symbol -// will be added to the symbol table and session_initializer op's "intializer" -// attribute. -func::FuncOp GetOrCreateSessionInitializerFunc( - SymbolTable& symbol_table, SessionInitializerOp session_init_op, - StringRef initializer_type) { - for (const auto init_sym : - session_init_op.getInitializers().getAsValueRange()) { - auto init_func_op = symbol_table.lookup(init_sym); - if (!init_func_op) continue; - - if (IsInitializerType(init_func_op, kTfSavedModelInitializerRestoreType)) { - return init_func_op; - } +// is "restore_op". Creates and returns a new initializer function iff such +// `FuncOp` is not found. The newly created initializer function's +// initializer_type is "restore_op" and its symbol will be added to the symbol +// table and session_initializer op's "intializer" attribute. +func::FuncOp GetOrCreateInitializerFunc(ModuleOp module_op) { + if (auto init_func_op = GetInitializerFunction( + module_op, /*initializer_type=*/kTfSavedModelInitializerRestoreType); + init_func_op) { + return init_func_op; + } else { + // Create a new initializer function if the init function is not found. + return CreateInitializerFunc(module_op); } - - // Create a new initializer function if the init function is not found. - return CreateInitializerFunc(symbol_table, session_init_op); -} - -// Returns the session_initializer op in the module if exists. Otherwise, -// creates a new session_initializer op and returns it. -SessionInitializerOp GetOrCreateSessionInitializerOp(ModuleOp module_op) { - SessionInitializerOp session_init_op = - tf_saved_model::GetSessionInitializerOp(module_op); - - // Create one if it doesn't exist. - if (!session_init_op) { - OpBuilder builder{&module_op.getBodyRegion()}; - - session_init_op = builder.create( - module_op.getLoc(), /*initializers=*/builder.getArrayAttr({})); - } - - return session_init_op; } // Retrieve the ConstOp's name from its loc. Returns "const" if a name cannot be @@ -176,15 +211,18 @@ std::string GetConstOpName(TF::ConstOp const_op) { } // Collects the ConstOps to unfreeze. 
-std::vector GetTargetConstOps(ModuleOp module_op) { +std::vector GetTargetConstOps(const int64_t size_threshold, + ModuleOp module_op) { std::vector target_const_ops{}; // TODO(b/254636388): Lift the assumption that there are no intializer // functions and avoid converting ConstOps inside initializer functions. for (auto func_op : module_op.getOps()) { - auto const_ops = func_op.getOps(); - target_const_ops.insert(target_const_ops.end(), const_ops.begin(), - const_ops.end()); + absl::c_copy_if(func_op.getOps(), + std::back_inserter(target_const_ops), + [size_threshold](TF::ConstOp const_op) -> bool { + return GetSizeInBytes(const_op) > size_threshold; + }); } return target_const_ops; @@ -255,11 +293,11 @@ void CreateAssignVariableOps( // Assign the ConstOp to each VarHandleOp. These will be used to save the // variable values to the checkpoint. auto const_op_copy = - builder.create(const_op.getLoc(), const_op.value()); + builder.create(const_op.getLoc(), const_op.getValue()); builder.create(const_op.getLoc(), /*resource=*/var_handle_op, - /*value=*/const_op_copy.output()); + /*value=*/const_op_copy.getOutput()); } } @@ -268,18 +306,13 @@ void UnfreezeConstantsPass::runOnOperation() { // Find the ConstOps to "unfreeze" into VarHandleOps. const std::vector target_const_ops = - GetTargetConstOps(module_op); + GetTargetConstOps(size_threshold_in_bytes_.getValue(), module_op); if (target_const_ops.empty()) { VLOG(1) << "No ConstOps found. UnfreezeConstantsPass is a no-op."; return; } - SessionInitializerOp session_init_op = - GetOrCreateSessionInitializerOp(module_op); - - SymbolTable symbol_table{module_op}; - func::FuncOp session_init_func = GetOrCreateSessionInitializerFunc( - symbol_table, session_init_op, kTfSavedModelInitializerRestoreType); + func::FuncOp session_init_func = GetOrCreateInitializerFunc(module_op); // Replace each usage of ConstOp to a VarHandleOp -> ReadVariableOp pattern. llvm::MapVector const_op_name_map = diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.cc index 9d228c4bff1..c65f7ac7906 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.cc @@ -16,6 +16,7 @@ limitations under the License. #include +#include "llvm/ADT/STLExtras.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/eval_util.h" @@ -76,7 +77,7 @@ LogicalResult IsOperationFoldable(Operation* op) { // folded to preserve the original semantics. 
if (op->hasTrait() || op->hasTrait() || op->getNumRegions() != 0 || - !MemoryEffectOpInterface::hasNoEffect(op)) { + !isMemoryEffectFree(op)) { return failure(); } @@ -135,7 +136,7 @@ LogicalResult FoldOperation(TFE_Context* ctx, OpBuilder& builder, Operation* op, for (auto operand : op->getOperands()) { auto preceding_const_op = operand.getDefiningOp(); if (preceding_const_op) { - inputs.push_back(preceding_const_op.value()); + inputs.push_back(preceding_const_op.getValue()); continue; } @@ -153,7 +154,7 @@ LogicalResult FoldOperation(TFE_Context* ctx, OpBuilder& builder, Operation* op, } auto preceding_result = preceding_results[preceding_result_id]; preceding_const_op = preceding_result.getDefiningOp(); - inputs.push_back(preceding_const_op.value()); + inputs.push_back(preceding_const_op.getValue()); } // Avoid overlapping folds with the same context. @@ -242,5 +243,15 @@ llvm::SmallVector ConstantFoldOpIfPossible(Operation* op) { return results; } +llvm::SmallVector CloneOpWithReplacedOperands( + OpBuilder& builder, Operation* op, + const llvm::SmallVector& new_operands) { + IRMapping mapping; + for (const auto& arg : llvm::enumerate(new_operands)) { + mapping.map(op->getOperand(arg.index()), arg.value()); + } + return builder.clone(*op, mapping)->getResults(); +} + } // namespace quant } // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h b/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h index 56a3f686de7..003018f4db7 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h @@ -30,17 +30,15 @@ limitations under the License. namespace mlir { namespace quant { +constexpr char kQuantizeFuncName[] = "quantize_i8"; +constexpr char kDequantizeFuncName[] = "dequantize_i8"; +constexpr char kAttrMapAttribute[] = "attr_map"; + // TODO(b/238829558): Populate quantization config based on the -// QuantizationOptions proto. We might want to clean QuantizationMethod as well -// as this can be inferred from the proto. +// QuantizationOptions proto. +// TODO(b/263449239): Put the OpSet aliases separately within each file using OpSet = tensorflow::quantization::OpSet; -enum class QuantizationMethod { - kQuantizationAwareTraining, - kPostTrainingQuantization, - kDynamicRangeQuantization -}; - // Returns true if the value has static shape. bool HasStaticShape(Value value); @@ -130,8 +128,16 @@ bool AreSplatValuesEqual(Value x, Value y) { return splat_x == splat_y; } -// TODO(b/241488936): Remove this function after adding a new constant folding +// Clones an operation with new operands while keeping attributes. +llvm::SmallVector CloneOpWithReplacedOperands( + OpBuilder &builder, Operation *op, + const llvm::SmallVector &new_operands); + +// TODO(b/241488936): Remove these functions after adding a new constant folding // pass to TensorFlow. +// Checks if an Operation is foldable. +LogicalResult IsOperationFoldable(Operation *op); + // Applies constant folding to the operation if possible and return the folded // results. 
llvm::SmallVector ConstantFoldOpIfPossible(Operation *op); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.td index 1e5cc82be94..eaa065f516e 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.td +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.td @@ -57,6 +57,16 @@ def IsInt32ElementType : Constraint< def IsF32ElementType : Constraint< CPred<"getElementTypeOrSelf($0).isF32()">>; +// Checks if the value has rank. +def HasRank : Constraint< + CPred<"$0.getType().cast().hasRank()">>; + +// Checks if the value has rank of `n`. +class HasRankOf : Constraint< + CPred<"$0.getType().cast().hasRank() && " + "$0.getType().cast().getRank() == " # n>, + "Checks if the value has rank of 'n'.">; + // Checks if the value has static shape. def HasStaticShapeConstraint : Constraint>; @@ -96,7 +106,7 @@ class CreateI64ArrayAttr : NativeCodeCall< // Creates a string attribute with given values. class CreateStringAttr : NativeCodeCall< - "$_builder.getStringAttr("# values #")">; + "$_builder.getStringAttr(\""# values #"\")">; // Creates a new F32 type with the same shape as the given value. def CloneTypeWithF32ElementType : NativeCodeCall< @@ -106,11 +116,6 @@ def CloneTypeWithF32ElementType : NativeCodeCall< def CloneTypeWithI32ElementType : NativeCodeCall< "CloneTypeWithNewElementType($0.getType(), $_builder.getI32Type())">; -// By default, the generated code uses the `create` method without the output -// type field. However, for many ops, the output type field is always required. -class CreateOpWithOutputType : NativeCodeCall< - "$_builder.create<"# op_name #">($_loc, $0...)">; - // Checks if the value is a float constant and its splat value is equal to `x`. class IsSplatValueEqual : Constraint($0, "# x #")">>; @@ -126,3 +131,19 @@ class IsIntSplatValueEqual : Constraint : Constraint($0, $1)">>; + +// Returns defining op of this value. +def GetDefiningOp : NativeCodeCall<"$0.getDefiningOp()">; + +// Checks if an Operation is foldable. +def IsOperationFoldable : Constraint< + CPred<"succeeded(IsOperationFoldable($0.getDefiningOp()))">>; + +// Applies constant-folding to the given op if possible. +def ConstantFoldOpIfPossible : + NativeCodeCall<"ConstantFoldOpIfPossible($0).front()">; + +// Clones an operation with new operands while keeping attributes. 
+def CloneOpWithReplacedOperands : NativeCodeCall< + "CloneOpWithReplacedOperands(" + "$_builder, $0, llvm::SmallVector{$1...}).front()">; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD index 5570fb64d6b..61c7aebff05 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD @@ -1,7 +1,17 @@ load("//tensorflow:pytype.default.bzl", "pytype_library", "pytype_strict_library") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud", "tf_py_test", "tf_python_pybind_extension") +load( + "//tensorflow/core/platform:build_config_root.bzl", + "if_static", +) +load( + "//tensorflow:tensorflow.default.bzl", + "get_compatible_with_cloud", + "tf_py_test", + "tf_python_pybind_extension", +) package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ "//tensorflow/compiler/mlir/quantization/tensorflow:internal_visibility_allowlist_package", "//tensorflow/python:__subpackages__", @@ -9,37 +19,30 @@ package( licenses = ["notice"], ) +# Do NOT directly depend on `quantize_model_cc_impl` unless it is necessary +# (i.e. undefined symbol). See the comments in `quantize_model_cc`. cc_library( - name = "quantize_model_lib", - srcs = [ - "quantize_model.cc", - ], - hdrs = [ - "quantize_model.h", - ], + name = "quantize_model_cc_impl", + srcs = ["quantize_model.cc"], + hdrs = ["quantize_model.h"], compatible_with = get_compatible_with_cloud(), + visibility = [ + # Directly linked to `libtensorflow_cc.so` or + # `_pywrap_tensorflow_internal.so` if static build. + "//tensorflow:__pkg__", + "//tensorflow/python:__pkg__", + ], deps = [ - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:QuantOps", - "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:ShapeDialect", - "@llvm-project//mlir:Transforms", "//tensorflow/cc/saved_model:loader", "//tensorflow/compiler/mlir/quantization/tensorflow:constants", + "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:passes", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:quantize_passes", "//tensorflow/compiler/mlir/quantization/tensorflow:quantize_preprocess", - # Required for CustomAggregator op registration. - "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:custom_aggregator_op", + "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:custom_aggregator_op", # Required for CustomAggregator op registration. 
+ "//tensorflow/compiler/mlir/quantization/tensorflow/cc:save_variables", + "//tensorflow/compiler/mlir/quantization/tensorflow/debugging:mlir_dump", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/mlir/tensorflow:export_graphdef", @@ -53,52 +56,99 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:path", "//tensorflow/core/platform:statusor", + "//tensorflow/tsl/platform:env", "//tensorflow/tsl/platform:path", + "//tensorflow/tsl/platform:status", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:Transforms", ], ) +# OSS: This is a header-only target. The implementation target `quantize_model_cc_impl` is +# directly linked to `lib_pywrap_tensorflow_internal.so`, so in most use cases of python- +# exported symbols depending directly on `quantize_model_cc_impl` should be unnecessary. +# Using the header-only target will help avoid the ODR violation. cc_library( name = "quantize_model_cc", - srcs = [ - "quantize_model_wrapper.cc", - ], - hdrs = [ - "quantize_model_wrapper.h", - ], - copts = ["-fexceptions"], - features = [ - "-use_header_modules", # Required for pybind11 - "-parse_headers", - ], - visibility = [ - "//tensorflow/compiler/mlir/quantization/tensorflow:internal_visibility_allowlist_package", - "//tensorflow/python:__subpackages__", + hdrs = ["quantize_model.h"], + compatible_with = get_compatible_with_cloud(), + deps = if_static([":quantize_model_cc_impl"]) + [ + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", + "//tensorflow/core:protos_all_cc", ], +) + +# Exports python symbols via pybind11. +tf_python_pybind_extension( + name = "pywrap_quantize_model", + srcs = ["pywrap_quantize_model.cc"], + # All deps must be header-only. 
deps = [ - ":quantize_model_lib", + ":quantize_model_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibrator_singleton", "//tensorflow/core:protos_all_cc", "//tensorflow/python/lib/core:pybind11_lib", + "//third_party/python_runtime:headers", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@pybind11", + "@pybind11_abseil//pybind11_abseil:absl_casters", + "@pybind11_abseil//pybind11_abseil:status_casters", ], ) -tf_python_pybind_extension( - name = "pywrap_quantize_model", +tf_py_test( + name = "pywrap_quantize_model_test", srcs = [ - "pywrap_quantize_model.cc", + "pywrap_quantize_model_test.py", ], - hdrs = [ - "quantize_model_wrapper.h", + tags = ["no_pip"], + deps = [ + ":pywrap_quantize_model", + "//tensorflow:tensorflow_py", + "//tensorflow/python/platform", + ], +) + +pytype_strict_library( + name = "save_model", + srcs = [ + "save_model.py", ], deps = [ - "//tensorflow/python/lib/core:pybind11_lib", - "//third_party/python_runtime:headers", - "@com_google_absl//absl/strings", - "@pybind11", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:framework", + "//tensorflow/python:framework_ops", + "//tensorflow/python:variables", + "//tensorflow/python/client:session", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:tensor_shape", + "//tensorflow/python/lib/io:lib", + "//tensorflow/python/saved_model:builder", + "//tensorflow/python/saved_model:constants", + "//tensorflow/python/saved_model:loader", + "//tensorflow/python/saved_model:tag_constants", + "//tensorflow/python/types", + "@absl_py//absl/logging", ], ) @@ -112,9 +162,10 @@ pytype_strict_library( deps = [ ":pywrap_quantize_model", ":representative_dataset", + ":save_model", + "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_py", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_py", "//tensorflow/core:protos_all_py", - "//tensorflow/python:framework", "//tensorflow/python:framework_ops", "//tensorflow/python:pywrap_tensorflow", "//tensorflow/python/client:session", @@ -122,7 +173,6 @@ pytype_strict_library( "//tensorflow/python/eager:wrap_function", "//tensorflow/python/lib/io:lib", "//tensorflow/python/platform", - "//tensorflow/python/saved_model:builder", "//tensorflow/python/saved_model:load", "//tensorflow/python/saved_model:loader", "//tensorflow/python/saved_model:signature_constants", @@ -138,7 +188,7 @@ tf_py_test( name = "quantize_model_test", size = "medium", srcs = ["integration_test/quantize_model_test.py"], - shard_count = 10, # Parallelize the test to avoid timeouts. + shard_count = 50, # Parallelize the test to avoid timeouts. 
tags = ["no_pip"], deps = [ ":quantize_model", @@ -165,16 +215,25 @@ pytype_library( "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:io_ops", + "//tensorflow/python:lookup_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:nn_ops", "//tensorflow/python:random_ops", + "//tensorflow/python:string_ops", "//tensorflow/python:variables", "//tensorflow/python/client:session", "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/lib/io:lib", + "//tensorflow/python/module", + "//tensorflow/python/ops/ragged:ragged_string_ops", "//tensorflow/python/saved_model:builder", + "//tensorflow/python/saved_model:save", "//tensorflow/python/saved_model:signature_def_utils", + "//tensorflow/python/trackable:asset", "//tensorflow/python/trackable:autotrackable", "//tensorflow/python/types", "//third_party/py/numpy", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/concurrency_test.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/concurrency_test.py index 6fb9e220627..f2444a8946b 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/concurrency_test.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/concurrency_test.py @@ -41,14 +41,19 @@ def setUp(self): self.pool = futures.ThreadPoolExecutor(max_workers=4) def _convert_with_calibration(self): - class ModelWithAdd(autotrackable.AutoTrackable): """Basic model with addition.""" - @def_function.function(input_signature=[ - tensor_spec.TensorSpec(shape=[10], dtype=dtypes.float32, name='x'), - tensor_spec.TensorSpec(shape=[10], dtype=dtypes.float32, name='y') - ]) + @def_function.function( + input_signature=[ + tensor_spec.TensorSpec( + shape=[10], dtype=dtypes.float32, name='x' + ), + tensor_spec.TensorSpec( + shape=[10], dtype=dtypes.float32, name='y' + ), + ] + ) def add(self, x, y): res = math_ops.add(x, y) return {'output': res} @@ -56,29 +61,34 @@ def add(self, x, y): def data_gen(): for _ in range(255): yield { - 'x': - ops.convert_to_tensor( - np.random.uniform(size=(10)).astype('f4')), - 'y': - ops.convert_to_tensor( - np.random.uniform(size=(10)).astype('f4')) + 'x': ops.convert_to_tensor( + np.random.uniform(size=(10)).astype('f4') + ), + 'y': ops.convert_to_tensor( + np.random.uniform(size=(10)).astype('f4') + ), } root = ModelWithAdd() temp_path = self.create_tempdir().full_path saved_model_save.save( - root, temp_path, signatures=root.add.get_concrete_function()) + root, temp_path, signatures=root.add.get_concrete_function() + ) quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( - experimental_method=quant_opts_pb2.QuantizationMethod - .ExperimentalMethod.STATIC_RANGE)) + experimental_method=quant_opts_pb2.QuantizationMethod.ExperimentalMethod.STATIC_RANGE + ) + ) model = quantize_model.quantize( - temp_path, ['serving_default'], [tag_constants.SERVING], + temp_path, + ['serving_default'], + [tag_constants.SERVING], quantization_options=quantization_options, - representative_dataset=data_gen()) + representative_dataset=data_gen(), + ) return model @test_util.run_in_graph_and_eager_modes diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py 
b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py index 88ddbd7e243..e6519973c44 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py @@ -13,9 +13,10 @@ # limitations under the License. # ============================================================================== """Tests for quantize_model.""" +# TODO(b/264234648): Refactor and cleanup this file. import itertools import os -from typing import List, Mapping, Optional +from typing import List, Mapping, Optional, Sequence, Tuple, Union from absl.testing import parameterized import numpy as np @@ -24,36 +25,100 @@ from tensorflow.compiler.mlir.quantization.tensorflow import quantization_options_pb2 as quant_opts_pb2 from tensorflow.compiler.mlir.quantization.tensorflow.python import quantize_model from tensorflow.compiler.mlir.quantization.tensorflow.python import representative_dataset as repr_dataset +from tensorflow.compiler.mlir.quantization.tensorflow.python import save_model from tensorflow.compiler.mlir.quantization.tensorflow.python.integration_test import quantize_model_test_base +from tensorflow.core.framework import attr_value_pb2 +from tensorflow.core.framework import graph_pb2 +from tensorflow.core.framework import node_def_pb2 +from tensorflow.core.framework import tensor_shape_pb2 +from tensorflow.core.framework import types_pb2 from tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import def_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util from tensorflow.python.lib.io import file_io from tensorflow.python.module import module from tensorflow.python.ops import array_ops -from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops import special_math_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import builder from tensorflow.python.saved_model import loader_impl as saved_model_loader from tensorflow.python.saved_model import save as saved_model_save +from tensorflow.python.saved_model import save_options from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_def_utils_impl from tensorflow.python.saved_model import tag_constants +from tensorflow.python.training import checkpoint_utils from tensorflow.python.types import core # Type aliases for quantization method protobuf enums. _Method = quant_opts_pb2.QuantizationMethod.Method _ExperimentalMethod = quant_opts_pb2.QuantizationMethod.ExperimentalMethod +_TensorShape = Sequence[Union[int, None]] + +_PER_CHANNEL_QUANTIZED_OPS = ( + 'UniformQuantizedConvolution', + 'UniformQuantizedConvolutionHybrid', + 'UniformQuantizedDotHybrid', +) + +# Lists of ops whose channel dimension should be changed if per_channel +# quantization is enabled. Respectively refers to (scale, zero_point). 
+_SUFFIXES = ('/filter1', '/filter2') +_PER_CHANNEL_OP_NAMES = ( + f'{op}{suffix}' + for op, suffix in itertools.product(_PER_CHANNEL_QUANTIZED_OPS, _SUFFIXES) +) + + +def _is_variable(node_def: node_def_pb2.NodeDef) -> bool: + """Determines whether `node_def` is a variable node. + + Args: + node_def: `NodeDef` to test whether it is a variable or not. + + Returns: + Returns True if it is a variable. + """ + return node_def.op == 'VarHandleOp' + + +def _find_variables( + graph_def: graph_pb2.GraphDef, +) -> Mapping[str, node_def_pb2.NodeDef]: + """Finds all variables within `graph_def`. + + This function makes sense for TF 1 graphs only, as it depends on + `shared_name`. + + Args: + graph_def: `GraphDef` to find variables from. + + Returns: + A mapping of `shared_name` -> `NodeDef` corresponding to a variable op. + """ + variable_nodes = {} + + for var_node in filter(_is_variable, graph_def.node): + shared_name = str(var_node.attr['shared_name'].s, encoding='utf-8') + variable_nodes[shared_name] = var_node + + for func in graph_def.library.function: + for var_node in filter(_is_variable, func.node_def): + variable_nodes[shared_name] = var_node + + return variable_nodes + def parameter_combinations(test_parameters): """Generate all combinations of test parameters.""" @@ -71,9 +136,11 @@ class MultipleSignatureModel(module.Module): Used to test where the quantizer has to handle multiple signatures. """ - @def_function.function(input_signature=[ - tensor_spec.TensorSpec(shape=[1, 4], dtype=dtypes.float32) - ]) + @def_function.function( + input_signature=[ + tensor_spec.TensorSpec(shape=[1, 4], dtype=dtypes.float32) + ] + ) def matmul(self, matmul_input: core.Tensor) -> Mapping[str, core.Tensor]: """Performs a matrix multiplication. @@ -88,9 +155,11 @@ def matmul(self, matmul_input: core.Tensor) -> Mapping[str, core.Tensor]: return {'output': out} - @def_function.function(input_signature=[ - tensor_spec.TensorSpec(shape=(1, 3, 4, 3), dtype=dtypes.float32) - ]) + @def_function.function( + input_signature=[ + tensor_spec.TensorSpec(shape=(1, 3, 4, 3), dtype=dtypes.float32) + ] + ) def conv(self, conv_input: core.Tensor) -> Mapping[str, core.Tensor]: """Performs a 2D convolution operation. @@ -100,22 +169,24 @@ def conv(self, conv_input: core.Tensor) -> Mapping[str, core.Tensor]: Returns: A map of: output key -> output result. """ - filters = np.random.uniform( - low=-10, high=10, size=(2, 3, 3, 2)).astype('f4') + filters = np.random.uniform(low=-10, high=10, size=(2, 3, 3, 2)).astype( + 'f4' + ) out = nn_ops.conv2d( conv_input, filters, strides=[1, 1, 2, 1], dilations=[1, 1, 1, 1], padding='SAME', - data_format='NHWC') + data_format='NHWC', + ) return {'output': out} @test_util.run_all_in_graph_and_eager_modes -class QuantizationMethodTest(quantize_model_test_base.QuantizedModelTest): - """Test cases regarding the use of QuantizationMethod proto. +class QuantizationOptionsTest(quantize_model_test_base.QuantizedModelTest): + """Test cases regarding the use of QuantizationOptions proto. Run all tests cases in both the graph mode (default in TF1) and the eager mode (default in TF2) to ensure support for when TF2 is disabled. 
@@ -123,9 +194,11 @@ class QuantizationMethodTest(quantize_model_test_base.QuantizedModelTest): class SimpleModel(module.Module): - @def_function.function(input_signature=[ - tensor_spec.TensorSpec(shape=[1, 4], dtype=dtypes.float32) - ]) + @def_function.function( + input_signature=[ + tensor_spec.TensorSpec(shape=[1, 4], dtype=dtypes.float32) + ] + ) def __call__(self, input_tensor: core.Tensor) -> Mapping[str, core.Tensor]: """Performs a matrix multiplication. @@ -149,65 +222,228 @@ def _simple_model_data_gen(self) -> repr_dataset.RepresentativeDataset: """ for _ in range(8): yield { - 'input_tensor': - ops.convert_to_tensor( - np.random.uniform(low=0, high=150, size=(1, 4)).astype('f4')), + 'input_tensor': ops.convert_to_tensor( + np.random.uniform(low=0, high=150, size=(1, 4)).astype('f4') + ), } def test_static_range_quantization_by_default(self): model = self.SimpleModel() - input_saved_model_path = self.create_tempdir('input').full_path - saved_model_save.save(model, input_saved_model_path) + saved_model_save.save(model, self._input_saved_model_path) # Use default QuantizationOptions. converted_model = quantize_model.quantize( - input_saved_model_path, - representative_dataset=self._simple_model_data_gen()) + self._input_saved_model_path, + representative_dataset=self._simple_model_data_gen(), + ) self.assertIsNotNone(converted_model) - self.assertCountEqual(converted_model.signatures._signatures.keys(), - {'serving_default'}) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {'serving_default'} + ) # Indirectly prove that it is performing a static-range quantization # by checking that it complains about representative_dataset when it is # not provided. with self.assertRaisesRegex(ValueError, 'representative_dataset'): - quantize_model.quantize(input_saved_model_path) + quantize_model.quantize(self._input_saved_model_path) def test_method_unspecified_raises_value_error(self): model = self.SimpleModel() - input_saved_model_path = self.create_tempdir('input').full_path - saved_model_save.save(model, input_saved_model_path) + saved_model_save.save(model, self._input_saved_model_path) options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( - method=_Method.METHOD_UNSPECIFIED)) + method=_Method.METHOD_UNSPECIFIED + ) + ) with self.assertRaises(ValueError): quantize_model.quantize( - input_saved_model_path, quantization_options=options) + self._input_saved_model_path, quantization_options=options + ) def test_invalid_method_raises_value_error(self): model = self.SimpleModel() - input_saved_model_path = self.create_tempdir('input').full_path - saved_model_save.save(model, input_saved_model_path) + saved_model_save.save(model, self._input_saved_model_path) # Set an invalid value of -1 to QuantizationMethod.method. 
options = quant_opts_pb2.QuantizationOptions( - quantization_method=quant_opts_pb2.QuantizationMethod(method=-1)) + quantization_method=quant_opts_pb2.QuantizationMethod(method=-1) + ) + + with self.assertRaises(ValueError): + quantize_model.quantize( + self._input_saved_model_path, quantization_options=options + ) + + def test_per_channel_for_non_uniform_opset_raises_value_error(self): + model = self.SimpleModel() + + saved_model_save.save(model, self._input_saved_model_path) + + options = quant_opts_pb2.QuantizationOptions( + quantization_method=quant_opts_pb2.QuantizationMethod( + experimental_method=_ExperimentalMethod.DYNAMIC_RANGE + ), + op_set=quant_opts_pb2.TF, + enable_per_channel_quantization=True, + ) with self.assertRaises(ValueError): quantize_model.quantize( - input_saved_model_path, quantization_options=options) + self._input_saved_model_path, quantization_options=options + ) + + +class TensorNamePreservationTest(quantize_model_test_base.QuantizedModelTest): + + def test_preserving_input_output_tensor_names(self): + class MultiSignatureModel(module.Module): + + @def_function.function( + input_signature=[ + tensor_spec.TensorSpec( + name='input', shape=[32], dtype=dtypes.float32 + ), + ] + ) + def multiple_output_ops( + self, input_tensor: core.Tensor + ) -> Mapping[str, core.Tensor]: + k = array_ops.constant(4, dtype=dtypes.int32) + values, indices = nn_ops.top_k(input_tensor, k, name='TopK') + adj_values = values + 2 + return {'indices': indices, 'adj_values': adj_values, 'values': values} + + @def_function.function( + input_signature=[ + tensor_spec.TensorSpec( + name='input', shape=[32], dtype=dtypes.float32 + ), + ] + ) + def duplicate_outputs( + self, input_tensor: core.Tensor + ) -> Mapping[str, core.Tensor]: + q_input = array_ops.fake_quant_with_min_max_args( + input_tensor, min=-0.1, max=0.2, num_bits=8, narrow_range=False + ) + adj_values = q_input + 2 + return {'adj_values_1': adj_values, 'adj_values_2': adj_values} + + @def_function.function( + input_signature=[ + tensor_spec.TensorSpec( + name='input', shape=[32], dtype=dtypes.float32 + ), + ] + ) + def return_higher_index_only( + self, input_tensor: core.Tensor + ) -> Mapping[str, core.Tensor]: + k = array_ops.constant(4, dtype=dtypes.int32) + values, indices = nn_ops.top_k(input_tensor, k, name='TopK') + adj_values = values + 2 + return {'indices': indices, 'adj_values': adj_values} + + model = MultiSignatureModel() + signatures = { + 'multiple_output_ops': model.multiple_output_ops, + 'duplicate_outputs': model.duplicate_outputs, + 'return_higher_index_only': model.return_higher_index_only, + } + saved_model_save.save( + model, self._input_saved_model_path, signatures=signatures + ) + + tags = {tag_constants.SERVING} + original_signature_map = save_model.get_signatures_from_saved_model( + self._input_saved_model_path, + signature_keys=signatures.keys(), + tags=tags, + ) + + quantization_options = quant_opts_pb2.QuantizationOptions( + quantization_method=quant_opts_pb2.QuantizationMethod( + experimental_method=_ExperimentalMethod.STATIC_RANGE + ), + op_set=quant_opts_pb2.TF, + ) + quantize_model.quantize( + self._input_saved_model_path, + signatures.keys(), + tags, + self._output_saved_model_path, + quantization_options, + ) + converted_signature_map = save_model.get_signatures_from_saved_model( + self._output_saved_model_path, + signature_keys=signatures.keys(), + tags=tags, + ) + + # The original and converted model should have the same signature map. 
+ self.assertAllInSet( + list(original_signature_map.keys()), set(signatures.keys()) + ) + self.assertDictEqual(original_signature_map, converted_signature_map) + + def test_duplicated_tensor_name(self): + with session.Session(graph=ops.Graph()) as sess: + input_tensor = array_ops.placeholder( + dtypes.float32, shape=[], name='input' + ) + q_input = array_ops.fake_quant_with_min_max_args( + input_tensor, min=-0.1, max=0.2, num_bits=8, narrow_range=False + ) + sqrt = math_ops.sqrt(q_input, name='sqrt') + identity = array_ops.identity(sqrt, name='output') + + input_map = {'input': input_tensor} + output_map = {'sqrt': identity} + signature = signature_def_utils_impl.predict_signature_def( + inputs=input_map, outputs=output_map + ) + signature_map = {'main': signature} + + tags = {tag_constants.SERVING} + v1_builder = builder.SavedModelBuilder(self._input_saved_model_path) + v1_builder.add_meta_graph_and_variables( + sess, tags, signature_def_map=signature_map + ) + v1_builder.save() + + quantization_options = quant_opts_pb2.QuantizationOptions( + quantization_method=quant_opts_pb2.QuantizationMethod( + experimental_method=_ExperimentalMethod.STATIC_RANGE + ), + op_set=quant_opts_pb2.TF, + ) + quantize_model.quantize( + self._input_saved_model_path, + signature_map.keys(), + tags, + self._output_saved_model_path, + quantization_options, + ) + converted_signature_map = save_model.get_signatures_from_saved_model( + self._output_saved_model_path, + signature_keys=signature_map.keys(), + tags=tags, + ) + # The original and converted model should have the same signature map. + self.assertDictEqual(signature_map, converted_signature_map) class StaticRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): - def _any_warning_contains(self, substring: str, - warnings_list: List['LogRecord']) -> bool: + def _any_warning_contains( + self, substring: str, warnings_list: List['LogRecord'] + ) -> bool: """Returns True if any of the warnings contains a given substring. Args: @@ -220,28 +456,220 @@ def _any_warning_contains(self, substring: str, `warnings_list`. 
""" return any( - map(lambda warning: substring in str(warning.message), warnings_list)) + map(lambda warning: substring in str(warning.message), warnings_list) + ) @parameterized.parameters( parameter_combinations([{ + 'shapes': [ + ([3, 3], [3, 3]), + ([3, None], [None, 3]), + ([None, None], [None, None]), + ([4, 3, 3], [4, 3, 3]), + ([4, 3, None], [4, None, 3]), + ([None, None, None], [None, None, None]), + ], 'activation_fn': [None, nn_ops.relu, nn_ops.relu6], 'has_bias': [True, False], - 'target_opset': [quant_opts_pb2.XLA], - }])) + 'use_kernel': [True, False], + }]) + ) @test_util.run_in_graph_and_eager_modes - def test_qat_conv_model(self, activation_fn: Optional[ops.Operation], - has_bias: bool, target_opset: quant_opts_pb2.OpSet): + def test_qat_matmul_model( + self, + shapes: Sequence[Tuple[_TensorShape, _TensorShape]], + activation_fn: Optional[ops.Operation], + has_bias: bool, + use_kernel: bool, + ): + n = 5 + x_shape = [v if v is not None else n for v in shapes[0]] + y_shape = [v if v is not None else n for v in shapes[1]] + + class MatmulModel(module.Module): + + def __init__(self, bias: Optional[core.Tensor]): + self._bias = bias + self._kernel = np.random.uniform(size=y_shape).astype('f4') + self._min = (-0.8, -0.8, -0.9) + self._max = (0.9, 0.9, 1.0) + + @def_function.function( + input_signature=[ + tensor_spec.TensorSpec( + name='x', shape=shapes[0], dtype=dtypes.float32 + ) + ] + ) + def matmul_with_kernel(self, x: core.Tensor) -> Mapping[str, core.Tensor]: + return self._matmul(x, self._kernel) + + @def_function.function( + input_signature=[ + tensor_spec.TensorSpec( + name='x', shape=shapes[0], dtype=dtypes.float32 + ), + tensor_spec.TensorSpec( + name='y', shape=shapes[1], dtype=dtypes.float32 + ), + ] + ) + def matmul_without_kernel( + self, x: core.Tensor, y: core.Tensor + ) -> Mapping[str, core.Tensor]: + return self._matmul(x, y) + + def _matmul(self, x, y): + x = array_ops.fake_quant_with_min_max_vars( + x, + min=ops.convert_to_tensor(self._min[0]), + max=ops.convert_to_tensor(self._max[0]), + num_bits=8, + narrow_range=False, + ) + y = array_ops.fake_quant_with_min_max_vars( + y, + min=ops.convert_to_tensor(self._min[1]), + max=ops.convert_to_tensor(self._max[1]), + num_bits=8, + narrow_range=False, + ) + + out = math_ops.matmul(x, y) + if self._bias is not None: + out = nn_ops.bias_add(out, self._bias) + if activation_fn is not None: + out = activation_fn(out) + out = array_ops.fake_quant_with_min_max_vars( + out, + min=ops.convert_to_tensor(self._min[2]), + max=ops.convert_to_tensor(self._max[2]), + num_bits=8, + narrow_range=False, + ) + return {'output': out} + + bias = None + if has_bias: + bias_shape = shapes[1][-1] + if bias_shape is not None: + bias = array_ops.constant( + np.random.uniform(size=[shapes[1][-1]]), dtype=dtypes.float32 + ) + model = MatmulModel(bias) + x = array_ops.constant( + np.random.uniform(size=x_shape), dtype=dtypes.float32 + ) + y = array_ops.constant( + np.random.uniform(size=y_shape), dtype=dtypes.float32 + ) + if use_kernel: + model.matmul = model.matmul_with_kernel + model_inputs = {'x': x} + else: + model.matmul = model.matmul_without_kernel + model_inputs = {'x': x, 'y': y} + + saved_model_save.save( + model, self._input_saved_model_path, signatures=model.matmul + ) + + signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + tags = {tag_constants.SERVING} + + # Check the converted model with TF opset as the baseline. 
+ quantization_options = quant_opts_pb2.QuantizationOptions( + quantization_method=quant_opts_pb2.QuantizationMethod( + experimental_method=_ExperimentalMethod.STATIC_RANGE + ), + op_set=quant_opts_pb2.TF, + ) + + converted_model = quantize_model.quantize( + self._input_saved_model_path, + [signature_key], + tags, + self._output_saved_model_path, + quantization_options, + ) + self.assertIsNotNone(converted_model) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {signature_key} + ) + + expected_outputs = model.matmul(**model_inputs) + got_outputs = converted_model.signatures[signature_key](**model_inputs) + self.assertAllClose(expected_outputs, got_outputs, atol=1e-1) + + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def + self.assertTrue(self._contains_quantized_function_call(output_graphdef)) + + # Check the converted model in the XLA opset. + quantization_options = quant_opts_pb2.QuantizationOptions( + quantization_method=quant_opts_pb2.QuantizationMethod( + experimental_method=_ExperimentalMethod.STATIC_RANGE + ), + op_set=quant_opts_pb2.XLA, + enable_two_input_tensors=not use_kernel, + ) + + converted_model = quantize_model.quantize( + self._input_saved_model_path, + [signature_key], + tags, + self._output_saved_model_path_2, + quantization_options, + ) + + self.assertIsNotNone(converted_model) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {signature_key} + ) + loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path_2 + ) + graphdef = loader.get_meta_graph_def_from_tags(tags).graph_def + self.assertTrue(self._contains_op(graphdef, 'XlaDotV2')) + + new_outputs = converted_model.signatures[signature_key](**model_inputs) + + # The difference between TF and XLA path is expected to be small (smaller + # or equal to 1 in the quantized domain). + self.assertAllClose(new_outputs, expected_outputs, atol=1e-1) + @parameterized.parameters( + parameter_combinations([{ + 'activation_fn': [None, nn_ops.relu, nn_ops.relu6], + 'has_bias': [True, False], + 'has_batch_norm': [True, False], + 'target_opset': [quant_opts_pb2.XLA], + }]) + ) + @test_util.run_in_graph_and_eager_modes + def test_qat_conv_model( + self, + activation_fn: Optional[ops.Operation], + has_bias: bool, + has_batch_norm: bool, + target_opset: quant_opts_pb2.OpSet, + ): class ConvModel(module.Module): def __init__(self): self.filter_value = np.random.uniform( - low=-0.5, high=0.5, size=(2, 3, 3, 2)).astype('f4') - - @def_function.function(input_signature=[ - tensor_spec.TensorSpec( - name='input', shape=[1, 3, 4, 3], dtype=dtypes.float32), - ]) + low=-0.5, high=0.5, size=(2, 3, 3, 2) + ).astype('f4') + + @def_function.function( + input_signature=[ + tensor_spec.TensorSpec( + name='input', shape=[1, 3, 4, 3], dtype=dtypes.float32 + ), + ] + ) def conv(self, input_tensor: core.Tensor) -> Mapping[str, core.Tensor]: """Performs a 2D convolution operation. @@ -252,90 +680,383 @@ def conv(self, input_tensor: core.Tensor) -> Mapping[str, core.Tensor]: A map of: output key -> output result. 
""" q_input = array_ops.fake_quant_with_min_max_args( - input_tensor, min=-0.1, max=0.2, num_bits=8, narrow_range=False) + input_tensor, min=-0.1, max=0.2, num_bits=8, narrow_range=False + ) filter_tensor = ops.convert_to_tensor(self.filter_value) + filter_min = array_ops.identity( + array_ops.constant([-0.5, -0.5], dtype=dtypes.float32) + ) + filter_max = array_ops.identity( + array_ops.constant([0.5, 0.5], dtype=dtypes.float32) + ) + q_filter = array_ops.fake_quant_with_min_max_vars_per_channel( + filter_tensor, filter_min, filter_max, num_bits=8, narrow_range=True + ) bias = array_ops.constant([0.1, 0.2], dtype=dtypes.float32) + scale, offset = [1.0] * 2, [0.5] * 2 + mean, variance = scale, offset out = nn_ops.conv2d( q_input, - filter_tensor, + q_filter, strides=[1, 1, 2, 1], dilations=[1, 1, 1, 1], padding='SAME', - data_format='NHWC') + data_format='NHWC', + ) if has_bias: out = nn_ops.bias_add(out, bias, data_format='NHWC') if activation_fn is not None: + # The accuracy is not good when having FusedBatchNorm without + # activation in this test. + if has_batch_norm: + # Fusing is supported for non-training case. + out, _, _, _, _, _ = nn_ops.fused_batch_norm_v3( + out, scale, offset, mean, variance, is_training=False + ) out = activation_fn(out) - q_out = array_ops.fake_quant_with_min_max_args( - out, min=-0.3, max=0.4, num_bits=8, narrow_range=False) + out_min = array_ops.constant([-0.18, -0.32], dtype=dtypes.float32) + out_max = array_ops.constant([0.5, 0.5], dtype=dtypes.float32) + q_out = array_ops.fake_quant_with_min_max_vars_per_channel( + out, min=out_min, max=out_max, num_bits=8, narrow_range=True + ) return {'output': q_out} np.random.seed(1234) model = ConvModel() - input_saved_model_path = self.create_tempdir('input').full_path - saved_model_save.save(model, input_saved_model_path) + saved_model_save.save(model, self._input_saved_model_path) signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY - tags = [tag_constants.SERVING] + tags = {tag_constants.SERVING} # Check the converted model with TF opset as the baseline. - output_directory = self.create_tempdir().full_path quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( - experimental_method=_ExperimentalMethod.STATIC_RANGE), - op_set=quant_opts_pb2.TF) + experimental_method=_ExperimentalMethod.STATIC_RANGE + ), + op_set=quant_opts_pb2.TF, + ) - converted_model = quantize_model.quantize(input_saved_model_path, - [signature_key], tags, - output_directory, - quantization_options) + converted_model = quantize_model.quantize( + self._input_saved_model_path, + [signature_key], + tags, + self._output_saved_model_path, + quantization_options, + ) self.assertIsNotNone(converted_model) - self.assertCountEqual(converted_model.signatures._signatures.keys(), - {signature_key}) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {signature_key} + ) input_data = np.random.uniform( - low=-0.1, high=0.2, size=(1, 3, 4, 3)).astype('f4') + low=-0.1, high=0.2, size=(1, 3, 4, 3) + ).astype('f4') expected_outputs = model.conv(input_data) got_outputs = converted_model.signatures[signature_key]( - input=ops.convert_to_tensor(input_data)) - # TODO(b/215633216): Check if the accuracy is acceptable. 
-    self.assertAllClose(expected_outputs, got_outputs, atol=0.01)
+        input=ops.convert_to_tensor(input_data)
+    )
+    self.assertAllClose(expected_outputs, got_outputs, atol=0.00323)

-    output_loader = saved_model_loader.SavedModelLoader(output_directory)
-    output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags)
-    self.assertTrue(
-        self._contains_quantized_function_call(output_meta_graphdef))
+    output_loader = saved_model_loader.SavedModelLoader(
+        self._output_saved_model_path
+    )
+    output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def
+    self.assertTrue(self._contains_quantized_function_call(output_graphdef))

     # Check the converted model in the target opset.
     quantization_options = quant_opts_pb2.QuantizationOptions(
         quantization_method=quant_opts_pb2.QuantizationMethod(
-            experimental_method=_ExperimentalMethod.STATIC_RANGE),
-        op_set=target_opset)
+            experimental_method=_ExperimentalMethod.STATIC_RANGE
+        ),
+        op_set=target_opset,
+    )

-    output_directory = self.create_tempdir().full_path
-    converted_model = quantize_model.quantize(input_saved_model_path,
-                                              [signature_key], tags,
-                                              output_directory,
-                                              quantization_options)
+    converted_model = quantize_model.quantize(
+        self._input_saved_model_path,
+        [signature_key],
+        tags,
+        self._output_saved_model_path_2,
+        quantization_options,
+    )
     self.assertIsNotNone(converted_model)
-    self.assertCountEqual(converted_model.signatures._signatures.keys(),
-                          {signature_key})
-    loader = saved_model_loader.SavedModelLoader(output_directory)
-    meta_graphdef = loader.get_meta_graph_def_from_tags(tags)
+    self.assertCountEqual(
+        converted_model.signatures._signatures.keys(), {signature_key}
+    )
+    loader = saved_model_loader.SavedModelLoader(
+        self._output_saved_model_path_2
+    )
+    graphdef = loader.get_meta_graph_def_from_tags(tags).graph_def
     if target_opset == quant_opts_pb2.XLA:
-      self.assertTrue(self._contains_op(meta_graphdef, 'XlaConvV2'))
+      self.assertTrue(self._contains_op(graphdef, 'XlaConvV2'))
     new_outputs = converted_model.signatures[signature_key](
-        input=ops.convert_to_tensor(input_data))
+        input=ops.convert_to_tensor(input_data)
+    )
+    # The difference between TF and XLA path is expected to be small (smaller
+    # or equal to 1 in the quantized domain).
+    self.assertAllClose(new_outputs, got_outputs, atol=0.00154)
+
+  # Currently, only some specific forms of equations are supported.
+ @parameterized.parameters( + parameter_combinations([{ + 'equation': ['abc,cd->abd', 'abcd,cde->abe'], + 'shape_unknown': [True, False], + 'activation_fn': [None, nn_ops.relu, nn_ops.relu6], + 'has_bias': [True, False], + 'use_kernel': [True, False], + }]) + ) + @test_util.run_in_graph_and_eager_modes + def test_qat_einsum_model( + self, + equation: str, + shape_unknown: bool, + activation_fn: Optional[ops.Operation], + has_bias: bool, + use_kernel: bool, + ): + comma_pos = equation.find(',') + arrow_pos = equation.find('->') + x_labels = equation[0:comma_pos] + y_labels = equation[comma_pos + 1 : arrow_pos] + + label_to_size = {'a': 2, 'b': 3, 'c': 4, 'd': 5, 'e': 6} + x_shape = [label_to_size.get(x_label) for x_label in x_labels] + y_shape = [label_to_size.get(y_label) for y_label in y_labels] + x_signature = [None for _ in x_labels] if shape_unknown else list(x_shape) + y_signature = [None for _ in y_labels] if shape_unknown else list(y_shape) + + class EinsumModel(module.Module): + + def __init__(self, bias: Optional[core.Tensor]): + self._bias = bias + self._kernel = np.random.uniform(size=y_shape).astype('f4') + self._min = (-0.8, -0.8, -0.9) + self._max = (0.9, 0.9, 1.0) + + @def_function.function( + input_signature=[ + tensor_spec.TensorSpec( + name='x', shape=x_signature, dtype=dtypes.float32 + ) + ] + ) + def einsum_with_kernel(self, x: core.Tensor) -> Mapping[str, core.Tensor]: + return self._einsum(x, self._kernel) + + @def_function.function( + input_signature=[ + tensor_spec.TensorSpec( + name='x', shape=x_signature, dtype=dtypes.float32 + ), + tensor_spec.TensorSpec( + name='y', shape=y_signature, dtype=dtypes.float32 + ), + ] + ) + def einsum_without_kernel( + self, x: core.Tensor, y: core.Tensor + ) -> Mapping[str, core.Tensor]: + return self._einsum(x, y) + + def _einsum(self, x, y): + x = array_ops.fake_quant_with_min_max_vars( + x, + min=ops.convert_to_tensor(self._min[0]), + max=ops.convert_to_tensor(self._max[0]), + num_bits=8, + narrow_range=False, + ) + y = array_ops.fake_quant_with_min_max_vars( + y, + min=ops.convert_to_tensor(self._min[1]), + max=ops.convert_to_tensor(self._max[1]), + num_bits=8, + narrow_range=False, + ) + + out = special_math_ops.einsum(equation, x, y) + if self._bias is not None: + out = nn_ops.bias_add(out, self._bias) + if activation_fn is not None: + out = activation_fn(out) + out = array_ops.fake_quant_with_min_max_vars( + out, + min=ops.convert_to_tensor(self._min[2]), + max=ops.convert_to_tensor(self._max[2]), + num_bits=8, + narrow_range=False, + ) + return {'output': out} + + bias = None + if has_bias: + bias_shape = y_signature[-1] + if bias_shape is not None: + bias = array_ops.constant( + np.random.uniform(size=[y_signature[-1]]), dtype=dtypes.float32 + ) + model = EinsumModel(bias) + x = array_ops.constant( + np.random.uniform(size=x_shape), dtype=dtypes.float32 + ) + y = array_ops.constant( + np.random.uniform(size=y_shape), dtype=dtypes.float32 + ) + if use_kernel: + model.einsum = model.einsum_with_kernel + model_inputs = {'x': x} + else: + model.einsum = model.einsum_without_kernel + model_inputs = {'x': x, 'y': y} + + saved_model_save.save( + model, self._input_saved_model_path, signatures=model.einsum + ) + + signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + tags = {tag_constants.SERVING} + + # Check the converted model with TF opset as the baseline. 
+ quantization_options = quant_opts_pb2.QuantizationOptions( + quantization_method=quant_opts_pb2.QuantizationMethod( + experimental_method=_ExperimentalMethod.STATIC_RANGE + ), + op_set=quant_opts_pb2.TF, + ) + + converted_model = quantize_model.quantize( + self._input_saved_model_path, + [signature_key], + tags, + self._output_saved_model_path, + quantization_options, + ) + self.assertIsNotNone(converted_model) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {signature_key} + ) + + expected_outputs = model.einsum(**model_inputs) + got_outputs = converted_model.signatures[signature_key](**model_inputs) + self.assertAllClose(expected_outputs, got_outputs, atol=1e-1) + + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def + self.assertTrue(self._contains_quantized_function_call(output_graphdef)) + + # Check the converted model in the XLA opset. + quantization_options = quant_opts_pb2.QuantizationOptions( + quantization_method=quant_opts_pb2.QuantizationMethod( + experimental_method=_ExperimentalMethod.STATIC_RANGE + ), + op_set=quant_opts_pb2.XLA, + enable_two_input_tensors=not use_kernel, + ) + + converted_model = quantize_model.quantize( + self._input_saved_model_path, + [signature_key], + tags, + self._output_saved_model_path_2, + quantization_options, + ) + + self.assertIsNotNone(converted_model) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {signature_key} + ) + loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path_2 + ) + graphdef = loader.get_meta_graph_def_from_tags(tags).graph_def + self.assertTrue(self._contains_op(graphdef, 'XlaDotV2')) + + new_outputs = converted_model.signatures[signature_key](**model_inputs) + # The difference between TF and XLA path is expected to be small (smaller # or equal to 1 in the quantized domain). - self.assertAllClose(new_outputs, got_outputs, atol=0.00275) + self.assertAllClose(new_outputs, expected_outputs, atol=1e-1) + + # TODO(b/244276332): Allow table initialization in TF2 eager mode. + @test_util.deprecated_graph_mode_only + def test_qat_vocab_table_lookup_model(self): + tags = {tag_constants.SERVING} + signature_def_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + + # Create and save a simple model that involves a hash table. + inputs, outputs = self._create_and_save_vocab_table_lookup_qat_model_tf1( + self._input_saved_model_path, tags, signature_def_key + ) + + # Make sure that the desired input key and output key is present. + self.assertIn('input_vocabs', inputs.keys()) + self.assertIn('lookup', outputs.keys()) + + # Representative dataset is composed of a set of vocabs for table lookup. + repr_ds = [ + {'input_vocabs': np.array([b'hello', b'model', b'quantization'])} + for _ in range(4) + ] + + quantization_options = quant_opts_pb2.QuantizationOptions( + quantization_method=quant_opts_pb2.QuantizationMethod( + experimental_method=_ExperimentalMethod.STATIC_RANGE + ) + ) + + signature_def_keys = [signature_def_key] + + quantize_model.quantize( + self._input_saved_model_path, + signature_def_keys, + tags, + self._output_saved_model_path, + quantization_options, + representative_dataset=repr_ds, + ) + + # Tests table lookup to make sure the table has been initialized + # successfully. 
+    with session.Session(graph=ops.Graph()) as sess:
+      output_meta_graph_def = saved_model_loader.load(
+          sess, tags=tags, export_dir=self._output_saved_model_path
+      )
+
+      # The graph should contain a quantized function call (it contains a
+      # single f32 matmul node).
+      self.assertTrue(
+          self._contains_quantized_function_call(
+              output_meta_graph_def.graph_def
+          )
+      )
+      self.assertCountEqual(
+          output_meta_graph_def.signature_def.keys(), signature_def_keys
+      )
+
+      signature_def = output_meta_graph_def.signature_def[signature_def_key]
+
+      input_tensor_name = signature_def.inputs['input_vocabs'].name
+      input_tensor = sess.graph.get_tensor_by_name(input_tensor_name)
+
+      lookup_tensor_name = signature_def.outputs['lookup'].name
+      lookup_tensor = sess.graph.get_tensor_by_name(lookup_tensor_name)
+
+      lookup_val = sess.run(
+          lookup_tensor,
+          feed_dict={
+              input_tensor: np.array([b'model', b'quantization', b'hello'])
+          },
+      )
+
+      self.assertAllClose(lookup_val, [1.0, 2.0, 0.0])

   # Run this test only with the eager mode.
   @test_util.run_v2_only
   def test_ptq_model_with_variable(self):
-
     class ConvModelWithVariable(module.Module):
       """A simple model that performs a single convolution to the input tensor.
@@ -346,12 +1067,17 @@ def __init__(self) -> None:
         """Initializes the filter variable."""
         self.filters = variables.Variable(
             random_ops.random_uniform(
-                shape=(2, 3, 3, 2), minval=-1., maxval=1.))
-
-      @def_function.function(input_signature=[
-          tensor_spec.TensorSpec(
-              name='input', shape=(1, 3, 4, 3), dtype=dtypes.float32),
-      ])
+                shape=(2, 3, 3, 2), minval=-1.0, maxval=1.0
+            )
+        )
+
+      @def_function.function(
+          input_signature=[
+              tensor_spec.TensorSpec(
+                  name='input', shape=(1, 3, 4, 3), dtype=dtypes.float32
+              ),
+          ]
+      )
@@ -367,7 +1093,8 @@ def __call__(self, x: core.Tensor) -> Mapping[str, core.Tensor]: strides=[1, 1, 2, 1], dilations=[1, 1, 1, 1], padding='SAME', - data_format='NHWC') + data_format='NHWC', + ) return {'output': out} def gen_data() -> repr_dataset.RepresentativeDataset: @@ -379,572 +1106,1271 @@ def gen_data() -> repr_dataset.RepresentativeDataset: """ for _ in range(8): yield { - 'input': - random_ops.random_uniform( - shape=(1, 3, 4, 3), minval=0, maxval=150) + 'input': random_ops.random_uniform( + shape=(1, 3, 4, 3), minval=0, maxval=150 + ) } model = ConvModelWithVariable() - input_saved_model_path = self.create_tempdir('input').full_path - saved_model_save.save(model, input_saved_model_path) + saved_model_save.save(model, self._input_saved_model_path) signature_keys = [signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] tags = {tag_constants.SERVING} - output_directory = self.create_tempdir().full_path quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( - experimental_method=_ExperimentalMethod.STATIC_RANGE)) + experimental_method=_ExperimentalMethod.STATIC_RANGE + ) + ) converted_model = quantize_model.quantize( - input_saved_model_path, + self._input_saved_model_path, signature_keys, tags, - output_directory, + self._output_saved_model_path, quantization_options, - representative_dataset=gen_data()) + representative_dataset=gen_data(), + ) self.assertIsNotNone(converted_model) - self.assertCountEqual(converted_model.signatures._signatures.keys(), - signature_keys) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), signature_keys + ) - output_loader = saved_model_loader.SavedModelLoader(output_directory) - output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags) - self.assertTrue( - self._contains_quantized_function_call(output_meta_graphdef)) + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def + self.assertTrue(self._contains_quantized_function_call(output_graphdef)) + # TODO(b/263830952): Use dictionaries instead of tuples for parameters. 
@parameterized.named_parameters( - ('none', None, False, False, quant_opts_pb2.TF, False), - ('relu', nn_ops.relu, False, False, quant_opts_pb2.TF, False), - ('relu6', nn_ops.relu6, False, False, quant_opts_pb2.TF, False), - ('bn', None, False, True, quant_opts_pb2.TF, False), - ('bn_and_relu', nn_ops.relu, False, True, quant_opts_pb2.TF, False), - ('with_bias', None, True, False, quant_opts_pb2.TF, False), - ('with_bias_and_bn', None, True, True, quant_opts_pb2.TF, False), - ('with_bias_and_bn_and_relu', nn_ops.relu, True, True, quant_opts_pb2.TF, - False), - ('with_bias_and_relu', nn_ops.relu, True, False, quant_opts_pb2.TF, - False), - ('with_bias_and_relu6', nn_ops.relu6, True, False, quant_opts_pb2.TF, - False), - ('with_bias_and_bn_to_xla', None, True, True, quant_opts_pb2.XLA, False), - ('with_bias_and_relu6_to_xla', nn_ops.relu6, True, False, - quant_opts_pb2.XLA, False), - ('with_bias_and_bn_to_xla_dynamic', None, True, True, quant_opts_pb2.XLA, - True), - ('with_bias_and_relu6_to_xla_dynamic', nn_ops.relu6, True, False, - quant_opts_pb2.XLA, True), + ('none', None, False, False, quant_opts_pb2.TF, False, False), + ('relu', nn_ops.relu, False, False, quant_opts_pb2.TF, False, False), + ('relu6', nn_ops.relu6, False, False, quant_opts_pb2.TF, False, False), + ('bn', None, False, True, quant_opts_pb2.TF, False, False), + ( + 'bn_and_relu', + nn_ops.relu, + False, + True, + quant_opts_pb2.TF, + False, + False, + ), + ('with_bias', None, True, False, quant_opts_pb2.TF, False, False), + ('with_bias_and_bn', None, True, True, quant_opts_pb2.TF, False, False), + ( + 'with_bias_and_bn_and_relu', + nn_ops.relu, + True, + True, + quant_opts_pb2.TF, + False, + False, + ), + ( + 'with_bias_and_relu', + nn_ops.relu, + True, + False, + quant_opts_pb2.TF, + False, + False, + ), + ( + 'with_bias_and_relu6', + nn_ops.relu6, + True, + False, + quant_opts_pb2.TF, + False, + False, + ), + ( + 'with_bias_and_bn_to_xla', + None, + True, + True, + quant_opts_pb2.XLA, + False, + False, + ), + ( + 'with_bias_and_relu6_to_xla', + nn_ops.relu6, + True, + False, + quant_opts_pb2.XLA, + False, + False, + ), + ( + 'with_bias_and_bn_to_xla_dynamic', + None, + True, + True, + quant_opts_pb2.XLA, + True, + False, + ), + ( + 'with_bias_and_relu6_to_xla_dynamic', + nn_ops.relu6, + True, + False, + quant_opts_pb2.XLA, + True, + False, + ), + ( + 'none_to_uq', + None, + False, + False, + quant_opts_pb2.UNIFORM_QUANTIZED, + False, + False, + ), + ( + 'none_to_uq_per_channel', + None, + False, + False, + quant_opts_pb2.UNIFORM_QUANTIZED, + False, + True, + ), + ( + 'relu_to_uq', + nn_ops.relu, + False, + False, + quant_opts_pb2.UNIFORM_QUANTIZED, + False, + False, + ), + ( + 'with_bias_to_uq', + None, + True, + False, + quant_opts_pb2.UNIFORM_QUANTIZED, + False, + False, + ), + ( + 'with_bias_and_relu_to_uq', + nn_ops.relu, + True, + False, + quant_opts_pb2.UNIFORM_QUANTIZED, + False, + False, + ), + ( + 'with_bias_and_relu6_to_uq', + nn_ops.relu6, + True, + False, + quant_opts_pb2.UNIFORM_QUANTIZED, + False, + False, + ), ) @test_util.run_in_graph_and_eager_modes - def test_conv_ptq_model(self, activation_fn: Optional[ops.Operation], - has_bias: bool, has_bn: bool, - target_opset: quant_opts_pb2.OpSet, - input_shape_dynamic: bool): + def test_conv_ptq_model( + self, + activation_fn: Optional[ops.Operation], + has_bias: bool, + has_batch_norm: bool, + target_opset: quant_opts_pb2.OpSet, + input_shape_dynamic: bool, + enable_per_channel_quantization: bool, + ): input_shape = [None, None, None, 3] if input_shape_dynamic 
else [1, 3, 4, 3] - - class ConvModel(module.Module): - - @def_function.function(input_signature=[ - tensor_spec.TensorSpec(shape=input_shape, dtype=dtypes.float32) - ]) - def conv(self, input_tensor: core.Tensor) -> Mapping[str, core.Tensor]: - """Performs a 2D convolution operation. - - Args: - input_tensor: Input tensor to perform convolution on. - - Returns: - A map of: output key -> output result. - """ - filters = np.random.uniform( - low=-10, high=10, size=(2, 3, 3, 2)).astype('f4') - bias = np.random.uniform(low=0, high=10, size=(2)).astype('f4') - scale, offset = [1.0, 1.0], [0.5, 0.5] - mean, variance = scale, offset - out = nn_ops.conv2d( - input_tensor, - filters, - strides=[1, 1, 2, 1], - dilations=[1, 1, 1, 1], - padding='SAME', - data_format='NHWC') - if has_bias: - out = nn_ops.bias_add(out, bias) - if has_bn: - # Fusing is supported for non-training case. - out, _, _, _, _, _ = nn_ops.fused_batch_norm_v3( - out, scale, offset, mean, variance, is_training=False) - if activation_fn is not None: - out = activation_fn(out) - return {'output': out} + filter_shape = [2, 3, 3, 2] np.random.seed(1234) - model = ConvModel() - input_saved_model_path = self.create_tempdir('input').full_path - saved_model_save.save(model, input_saved_model_path) + model = self._create_conv2d_model( + input_shape, filter_shape, has_bias, has_batch_norm, activation_fn + ) + saved_model_save.save(model, self._input_saved_model_path) def data_gen() -> repr_dataset.RepresentativeDataset: for _ in range(8): yield { - 'input_tensor': - ops.convert_to_tensor( - np.random.uniform(low=0, high=150, - size=(1, 3, 4, 3)).astype('f4')), + 'input_tensor': ops.convert_to_tensor( + np.random.uniform(low=0, high=150, size=(1, 3, 4, 3)).astype( + 'f4' + ) + ), } - tags = [tag_constants.SERVING] - output_directory = self.create_tempdir().full_path + tags = {tag_constants.SERVING} quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( - experimental_method=_ExperimentalMethod.STATIC_RANGE), - op_set=target_opset) + experimental_method=_ExperimentalMethod.STATIC_RANGE + ), + op_set=target_opset, + enable_per_channel_quantization=enable_per_channel_quantization, + ) converted_model = quantize_model.quantize( - input_saved_model_path, ['serving_default'], + self._input_saved_model_path, + ['serving_default'], tags, - output_directory, + self._output_saved_model_path, quantization_options, - representative_dataset=data_gen()) + representative_dataset=data_gen(), + ) self.assertIsNotNone(converted_model) - self.assertCountEqual(converted_model.signatures._signatures.keys(), - {'serving_default'}) - - output_loader = saved_model_loader.SavedModelLoader(output_directory) - output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {'serving_default'} + ) + + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def if target_opset == quant_opts_pb2.XLA: - self.assertTrue(self._contains_op(output_meta_graphdef, 'XlaConvV2')) - else: + self.assertTrue(self._contains_op(output_graphdef, 'XlaConvV2')) + elif target_opset == quant_opts_pb2.UNIFORM_QUANTIZED: self.assertTrue( - self._contains_quantized_function_call(output_meta_graphdef)) - self.assertFalse( - self._contains_op(output_meta_graphdef, 'FusedBatchNormV3')) + self._contains_op(output_graphdef, 
'UniformQuantizedConvolution') + ) + if enable_per_channel_quantization: + quantized_axis = 3 + quantized_dim_size_attr = attr_value_pb2.AttrValue( + list=attr_value_pb2.AttrValue.ListValue( + shape=[ + tensor_shape_pb2.TensorShapeProto( + dim=[ + tensor_shape_pb2.TensorShapeProto.Dim( + size=filter_shape[quantized_axis] + ) + ] + ) + ] + ) + ) + else: + quantized_axis = -1 + # Empty dimension. Per-tensor quantization has singular channel. + quantized_dim_size_attr = attr_value_pb2.AttrValue( + list=attr_value_pb2.AttrValue.ListValue( + shape=[tensor_shape_pb2.TensorShapeProto()] + ) + ) + quantized_axis_attr = attr_value_pb2.AttrValue(i=quantized_axis) + self.assertEqual( + self._count_ops( + output_graphdef, + _PER_CHANNEL_QUANTIZED_OPS, + 'rhs_quantization_axis', + quantized_axis_attr, + ), + self._count_ops(output_graphdef, _PER_CHANNEL_QUANTIZED_OPS), + ) + self.assertEqual( + self._count_ops( + output_graphdef, + _PER_CHANNEL_OP_NAMES, + '_output_shapes', + quantized_dim_size_attr, + get_op_name=True, + ), + self._count_ops( + output_graphdef, + _PER_CHANNEL_OP_NAMES, + get_op_name=True, + ), + ) + self.assertFalse(self._contains_op(output_graphdef, 'Conv2D')) + else: + self.assertTrue(self._contains_quantized_function_call(output_graphdef)) + self.assertFalse(self._contains_op(output_graphdef, 'FusedBatchNormV3')) @parameterized.named_parameters( - ('none', None, False, False, quant_opts_pb2.TF, False), - ('relu', nn_ops.relu, False, False, quant_opts_pb2.TF, False), - ('relu6', nn_ops.relu6, False, False, quant_opts_pb2.TF, False), - ('bn', None, False, True, quant_opts_pb2.TF, False), - ('bn_and_relu', nn_ops.relu, False, True, quant_opts_pb2.TF, False), - ('with_bias', None, True, False, quant_opts_pb2.TF, False), - ('with_bias_and_bn', None, True, True, quant_opts_pb2.TF, False), - ('with_bias_and_bn_and_relu', nn_ops.relu, True, True, quant_opts_pb2.TF, - False), - ('with_bias_and_relu', nn_ops.relu, True, False, quant_opts_pb2.TF, - False), - ('with_bias_and_relu6', nn_ops.relu6, True, False, quant_opts_pb2.TF, - False), - ('with_bias_and_bn_to_xla', None, True, True, quant_opts_pb2.XLA, False), - ('with_bias_and_relu6_to_xla', nn_ops.relu6, True, False, - quant_opts_pb2.XLA, False), - ('with_bias_and_bn_to_xla_dynamic', None, True, True, quant_opts_pb2.XLA, - True), - ('with_bias_and_relu6_to_xla_dynamic', nn_ops.relu6, True, False, - quant_opts_pb2.XLA, True), + ('to_tf', quant_opts_pb2.TF), + ('to_xla', quant_opts_pb2.XLA), + ('to_uq', quant_opts_pb2.UNIFORM_QUANTIZED), ) - @test_util.run_in_graph_and_eager_modes - def test_depthwise_conv_ptq_model(self, - activation_fn: Optional[ops.Operation], - has_bias: bool, has_bn: bool, - target_opset: quant_opts_pb2.OpSet, - input_shape_dynamic: bool): - input_shape = [None, None, None, 3] if input_shape_dynamic else [1, 3, 4, 3] + @test_util.run_v2_only + def test_gather_and_conv_model(self, target_opset: quant_opts_pb2.OpSet): + model = self._create_simple_gather_and_conv_model(filter_shape=(2, 3, 3, 2)) + saved_model_save.save(model, self._input_saved_model_path) - class DepthwiseConvModel(module.Module): + data_gen = self._create_data_generator( + input_key='input_tensor', + shape=[6], + minval=0, + maxval=10, + dtype=dtypes.int64, + ) - @def_function.function(input_signature=[ - tensor_spec.TensorSpec(shape=input_shape, dtype=dtypes.float32) - ]) - def conv(self, input_tensor: core.Tensor) -> Mapping[str, core.Tensor]: - """Performs a 2D convolution operation. 
+ tags = {tag_constants.SERVING} - Args: - input_tensor: Input tensor to perform convolution on. + quantization_options = quant_opts_pb2.QuantizationOptions( + quantization_method=quant_opts_pb2.QuantizationMethod( + experimental_method=_ExperimentalMethod.STATIC_RANGE + ), + op_set=target_opset, + ) - Returns: - A map of: output key -> output result. - """ - filters = np.random.uniform( - low=-10, high=10, size=(2, 3, 3, 1)).astype('f4') - bias = np.random.uniform(low=0, high=10, size=(3)).astype('f4') - scale, offset = [1.0, 1.0, 1.0], [0.5, 0.5, 0.5] - mean, variance = scale, offset - out = nn_ops.depthwise_conv2d_native( - input_tensor, - filters, - strides=[1, 2, 2, 1], - dilations=[1, 1, 1, 1], - padding='SAME', - data_format='NHWC') - if has_bias: - out = nn_ops.bias_add(out, bias) - if has_bn: - # Fusing is supported for non-training case. - out, _, _, _, _, _ = nn_ops.fused_batch_norm_v3( - out, scale, offset, mean, variance, is_training=False) - if activation_fn is not None: - out = activation_fn(out) - return {'output': out} + converted_model = quantize_model.quantize( + self._input_saved_model_path, + ['serving_default'], + tags, + self._output_saved_model_path, + quantization_options, + representative_dataset=data_gen, + ) + self.assertIsNotNone(converted_model) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {'serving_default'} + ) + + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def + if target_opset == quant_opts_pb2.XLA: + self.assertTrue(self._contains_op(output_graphdef, 'XlaConvV2')) + self.assertSizeRatioLessThan( + self._output_saved_model_path, self._input_saved_model_path, 1 / 3 + ) + elif target_opset == quant_opts_pb2.UNIFORM_QUANTIZED: + self.assertTrue( + self._contains_op(output_graphdef, 'UniformQuantizedConvolution') + ) + self.assertSizeRatioGreaterThan( + self._output_saved_model_path, self._input_saved_model_path, 0.95 + ) + else: + self.assertTrue(self._contains_quantized_function_call(output_graphdef)) + self.assertSizeRatioGreaterThan( + self._output_saved_model_path, self._input_saved_model_path, 0.95 + ) + # TODO(b/263830952): Use dictionaries instead of tuples for parameters. 
+ @parameterized.named_parameters( + ('none', None, False, False, quant_opts_pb2.TF, False, False), + ('relu', nn_ops.relu, False, False, quant_opts_pb2.TF, False, False), + ('relu6', nn_ops.relu6, False, False, quant_opts_pb2.TF, False, False), + ('bn', None, False, True, quant_opts_pb2.TF, False, False), + ( + 'bn_and_relu', + nn_ops.relu, + False, + True, + quant_opts_pb2.TF, + False, + False, + ), + ('with_bias', None, True, False, quant_opts_pb2.TF, False, False), + ('with_bias_and_bn', None, True, True, quant_opts_pb2.TF, False, False), + ( + 'with_bias_and_bn_and_relu', + nn_ops.relu, + True, + True, + quant_opts_pb2.TF, + False, + False, + ), + ( + 'with_bias_and_relu', + nn_ops.relu, + True, + False, + quant_opts_pb2.TF, + False, + False, + ), + ( + 'with_bias_and_relu6', + nn_ops.relu6, + True, + False, + quant_opts_pb2.TF, + False, + False, + ), + ( + 'with_bias_and_bn_to_xla', + None, + True, + True, + quant_opts_pb2.XLA, + False, + False, + ), + ( + 'with_bias_and_relu6_to_xla', + nn_ops.relu6, + True, + False, + quant_opts_pb2.XLA, + False, + False, + ), + ( + 'with_bias_and_bn_to_xla_dynamic', + None, + True, + True, + quant_opts_pb2.XLA, + True, + False, + ), + ( + 'with_bias_and_relu6_to_xla_dynamic', + nn_ops.relu6, + True, + False, + quant_opts_pb2.XLA, + True, + False, + ), + ( + 'none_to_uq', + None, + False, + False, + quant_opts_pb2.UNIFORM_QUANTIZED, + False, + False, + ), + ( + 'none_to_uq_per_channel', + None, + False, + False, + quant_opts_pb2.UNIFORM_QUANTIZED, + False, + True, + ), + ( + 'relu_to_uq', + nn_ops.relu, + False, + False, + quant_opts_pb2.UNIFORM_QUANTIZED, + False, + False, + ), + ( + 'with_bias_to_uq', + None, + True, + False, + quant_opts_pb2.UNIFORM_QUANTIZED, + False, + False, + ), + ( + 'with_bias_and_relu_to_uq', + nn_ops.relu, + True, + False, + quant_opts_pb2.UNIFORM_QUANTIZED, + False, + False, + ), + ( + 'with_bias_and_relu6_to_uq', + nn_ops.relu6, + True, + False, + quant_opts_pb2.UNIFORM_QUANTIZED, + False, + False, + ), + ) + @test_util.run_in_graph_and_eager_modes + def test_depthwise_conv_ptq_model( + self, + activation_fn: Optional[ops.Operation], + has_bias: bool, + has_batch_norm: bool, + target_opset: quant_opts_pb2.OpSet, + input_shape_dynamic: bool, + enable_per_channel_quantization: bool, + ): + input_shape = [None, None, None, 3] if input_shape_dynamic else [1, 3, 4, 3] + filter_shape = [2, 3, 3, 1] + model = self._create_depthwise_conv2d_model( + input_shape, filter_shape, has_bias, has_batch_norm, activation_fn + ) np.random.seed(1234) - model = DepthwiseConvModel() - input_saved_model_path = self.create_tempdir('input').full_path - saved_model_save.save(model, input_saved_model_path) + saved_model_save.save(model, self._input_saved_model_path) def data_gen() -> repr_dataset.RepresentativeDataset: for _ in range(8): yield { - 'input_tensor': - ops.convert_to_tensor( - np.random.uniform(low=0, high=150, - size=(1, 3, 4, 3)).astype('f4')), + 'input_tensor': ops.convert_to_tensor( + np.random.uniform(low=0, high=150, size=(1, 3, 4, 3)).astype( + 'f4' + ) + ), } - tags = [tag_constants.SERVING] - output_directory = self.create_tempdir().full_path + tags = {tag_constants.SERVING} quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( - experimental_method=_ExperimentalMethod.STATIC_RANGE), - op_set=target_opset) + experimental_method=_ExperimentalMethod.STATIC_RANGE + ), + op_set=target_opset, + enable_per_channel_quantization=enable_per_channel_quantization, + ) 
converted_model = quantize_model.quantize( - input_saved_model_path, ['serving_default'], + self._input_saved_model_path, + ['serving_default'], tags, - output_directory, + self._output_saved_model_path, quantization_options, - representative_dataset=data_gen()) + representative_dataset=data_gen(), + ) self.assertIsNotNone(converted_model) - self.assertCountEqual(converted_model.signatures._signatures.keys(), - {'serving_default'}) - - output_loader = saved_model_loader.SavedModelLoader(output_directory) - output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {'serving_default'} + ) + + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def if target_opset == quant_opts_pb2.XLA: - # Quantization for DepthwiseConv is disabled for XLA opset. self.assertTrue( - self._contains_op(output_meta_graphdef, 'DepthwiseConv2dNative')) - else: + self._contains_op(output_graphdef, 'DepthwiseConv2dNative') + ) + elif target_opset == quant_opts_pb2.UNIFORM_QUANTIZED: self.assertTrue( - self._contains_quantized_function_call(output_meta_graphdef)) - self.assertFalse( - self._contains_op(output_meta_graphdef, 'FusedBatchNormV3')) + self._contains_op(output_graphdef, 'UniformQuantizedConvolution') + ) + if enable_per_channel_quantization: + quantized_axis = 3 + quantized_dim_size_attr = attr_value_pb2.AttrValue( + list=attr_value_pb2.AttrValue.ListValue( + shape=[ + tensor_shape_pb2.TensorShapeProto( + dim=[ + tensor_shape_pb2.TensorShapeProto.Dim( + # Depthwise conv is reshaped to [H,W,1,CxM]. + size=filter_shape[quantized_axis] + * filter_shape[2] + ) + ] + ) + ] + ) + ) + else: + quantized_axis = -1 + # Empty dimension. Per-tensor quantization has singular channel. + quantized_dim_size_attr = attr_value_pb2.AttrValue( + list=attr_value_pb2.AttrValue.ListValue( + shape=[tensor_shape_pb2.TensorShapeProto()] + ) + ) + quantized_axis_attr = attr_value_pb2.AttrValue(i=quantized_axis) + self.assertEqual( + self._count_ops( + output_graphdef, + _PER_CHANNEL_QUANTIZED_OPS, + 'rhs_quantization_axis', + quantized_axis_attr, + ), + self._count_ops(output_graphdef, _PER_CHANNEL_QUANTIZED_OPS), + ) + self.assertEqual( + self._count_ops( + output_graphdef, + _PER_CHANNEL_OP_NAMES, + '_output_shapes', + quantized_dim_size_attr, + get_op_name=True, + ), + self._count_ops( + output_graphdef, + _PER_CHANNEL_OP_NAMES, + get_op_name=True, + ), + ) + self.assertFalse(self._contains_op(output_graphdef, 'Conv2D')) + else: + self.assertTrue(self._contains_quantized_function_call(output_graphdef)) + self.assertFalse(self._contains_op(output_graphdef, 'FusedBatchNormV3')) @parameterized.parameters( - parameter_combinations([{ - 'activation_fn': [None, nn_ops.relu, nn_ops.relu6], - 'has_bias': [True, False], - 'target_opset': [quant_opts_pb2.XLA], - }])) + *parameter_combinations([ + { + 'activation_fn': [None, nn_ops.relu, nn_ops.relu6], + 'has_bias': [True, False], + 'batch_sizes': [([], []), ([2, 3], [2, 3])], + 'target_opset': [quant_opts_pb2.XLA], + }, + # Test broadcastable batch sizes. 
+ { + 'activation_fn': [None], + 'has_bias': [True], + 'batch_sizes': [ + ([2], []), + ([], [2]), + ([1], [2]), + ([None], []), + ], + 'target_opset': [quant_opts_pb2.XLA], + }, + ]) + ) @test_util.run_in_graph_and_eager_modes - def test_matmul_ptq_model(self, activation_fn: Optional[ops.Operation], - has_bias: bool, target_opset: quant_opts_pb2.OpSet): + def test_matmul_ptq_model( + self, + activation_fn: Optional[ops.Operation], + has_bias: bool, + batch_sizes: Sequence[int], + target_opset: quant_opts_pb2.OpSet, + ): np.random.seed(1234) - model = self._create_matmul_model(has_bias, activation_fn) - input_saved_model_path = self.create_tempdir('input').full_path - saved_model_save.save(model, input_saved_model_path) + lhs_batch_size, rhs_batch_size = batch_sizes + input_shape = (*lhs_batch_size, 1, 1024) + filter_shape = (*rhs_batch_size, 1024, 3) + static_input_shape = [dim if dim is not None else 2 for dim in input_shape] + model = self._create_matmul_model( + input_shape, + filter_shape, + self._input_saved_model_path, + has_bias, + activation_fn, + ) def data_gen() -> repr_dataset.RepresentativeDataset: for _ in range(500): yield { - 'input_tensor': - ops.convert_to_tensor( - np.random.uniform(low=0.0, high=1.0, - size=(1, 1024)).astype('f4')), + 'input_tensor': ops.convert_to_tensor( + np.random.uniform( + low=0.0, high=1.0, size=static_input_shape + ).astype('f4') + ), } - tags = [tag_constants.SERVING] - output_directory = self.create_tempdir().full_path + tags = {tag_constants.SERVING} quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( - experimental_method=_ExperimentalMethod.STATIC_RANGE)) + experimental_method=_ExperimentalMethod.STATIC_RANGE + ) + ) converted_model = quantize_model.quantize( - input_saved_model_path, ['serving_default'], + self._input_saved_model_path, + ['serving_default'], tags, - output_directory, + self._output_saved_model_path, quantization_options, - representative_dataset=data_gen()) + representative_dataset=data_gen(), + ) self.assertIsNotNone(converted_model) - self.assertCountEqual(converted_model.signatures._signatures.keys(), - {'serving_default'}) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {'serving_default'} + ) - output_loader = saved_model_loader.SavedModelLoader(output_directory) - output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags) - self.assertTrue( - self._contains_quantized_function_call(output_meta_graphdef)) + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def + self.assertTrue(self._contains_quantized_function_call(output_graphdef)) input_data = ops.convert_to_tensor( - np.random.uniform(low=0.0, high=1.0, size=(1, 1024)).astype('f4')) + np.random.uniform(low=0.0, high=1.0, size=static_input_shape).astype( + 'f4' + ) + ) expected_outputs = model.matmul(input_data) got_outputs = converted_model.signatures['serving_default']( - input_tensor=ops.convert_to_tensor(input_data)) + input_tensor=ops.convert_to_tensor(input_data) + ) self.assertAllClose(expected_outputs, got_outputs, atol=0.1674) # Check the converted model in the target opset. 
quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( - experimental_method=_ExperimentalMethod.STATIC_RANGE), - op_set=target_opset) + experimental_method=_ExperimentalMethod.STATIC_RANGE + ), + op_set=target_opset, + ) - output_directory = self.create_tempdir().full_path converted_model = quantize_model.quantize( - input_saved_model_path, ['serving_default'], + self._input_saved_model_path, + ['serving_default'], tags, - output_directory, + self._output_saved_model_path_2, quantization_options, - representative_dataset=data_gen()) + representative_dataset=data_gen(), + ) self.assertIsNotNone(converted_model) - self.assertCountEqual(converted_model.signatures._signatures.keys(), - {'serving_default'}) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {'serving_default'} + ) + loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path_2 + ) + output_graphdef = loader.get_meta_graph_def_from_tags(tags).graph_def + if target_opset == quant_opts_pb2.XLA: + self.assertTrue(self._contains_op(output_graphdef, 'XlaDotV2')) new_outputs = converted_model.signatures['serving_default']( - input_tensor=ops.convert_to_tensor(input_data)) - # The difference between TF and target path is expected to be small (smaller - # or equal to 1 in the quantized domain). - self.assertAllClose(new_outputs, got_outputs, atol=0.1048) + input_tensor=ops.convert_to_tensor(input_data) + ) + # The difference between TF and target path is expected to be small. + self.assertAllClose(new_outputs, got_outputs, atol=0.1202) self.assertAllClose(new_outputs, expected_outputs, atol=0.1023) + @parameterized.parameters( + ('abc,cde->abde', (2, 2, 64), (64, 3, 3), (3, 3), quant_opts_pb2.XLA), + ('abc,dce->adbe', (2, 2, 64), (3, 64, 3), (2, 3), quant_opts_pb2.XLA), + ) + def test_einsum_ptq_model( + self, + equation: str, + input_shape: Sequence[int], + weight_shape: Sequence[int], + bias_shape: Sequence[int], + target_opset: quant_opts_pb2.OpSet, + ): + model = self._create_einsum_model( + self._input_saved_model_path, + equation, + input_shape, + weight_shape, + bias_shape, + activation_fn=nn_ops.relu, + ) + + def data_gen() -> repr_dataset.RepresentativeDataset: + for _ in range(200): + yield { + 'input_tensor': ops.convert_to_tensor( + np.random.uniform(low=0.0, high=1.0, size=input_shape).astype( + 'f4' + ) + ), + } + + tags = {tag_constants.SERVING} + + quantization_options = quant_opts_pb2.QuantizationOptions( + quantization_method=quant_opts_pb2.QuantizationMethod( + experimental_method=_ExperimentalMethod.STATIC_RANGE + ) + ) + + converted_model = quantize_model.quantize( + self._input_saved_model_path, + ['serving_default'], + tags, + self._output_saved_model_path, + quantization_options, + representative_dataset=data_gen(), + ) + self.assertIsNotNone(converted_model) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {'serving_default'} + ) + + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def + self.assertTrue(self._contains_quantized_function_call(output_graphdef)) + + input_data = ops.convert_to_tensor( + np.random.uniform(low=0.0, high=1.0, size=input_shape).astype('f4') + ) + expected_outputs = model.einsum(input_data) + got_outputs = converted_model.signatures['serving_default']( + input_tensor=ops.convert_to_tensor(input_data) + ) + self.assertAllClose(expected_outputs, 
got_outputs, atol=0.0608) + + # Check the converted model in the target opset. + quantization_options = quant_opts_pb2.QuantizationOptions( + quantization_method=quant_opts_pb2.QuantizationMethod( + experimental_method=_ExperimentalMethod.STATIC_RANGE + ), + op_set=target_opset, + ) + + converted_model = quantize_model.quantize( + self._input_saved_model_path, + ['serving_default'], + tags, + self._output_saved_model_path_2, + quantization_options, + representative_dataset=data_gen(), + ) + + self.assertIsNotNone(converted_model) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {'serving_default'} + ) + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path_2 + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def + if target_opset == quant_opts_pb2.XLA: + self.assertTrue(self._contains_op(output_graphdef, 'XlaDotV2')) + + new_outputs = converted_model.signatures['serving_default']( + input_tensor=ops.convert_to_tensor(input_data) + ) + # The difference between TF and target path is expected to be small. + self.assertAllClose(new_outputs, got_outputs, atol=0.0666) + self.assertAllClose(new_outputs, expected_outputs, atol=0.057) + + @test_util.run_in_graph_and_eager_modes + def test_function_alias_preserved(self): + model = self._create_conv2d_model( + input_shape=(1, 3, 4, 3), filter_shape=(2, 3, 3, 2) + ) + + signatures = { + 'serving_default': model.conv.get_concrete_function(), + } + save_opts = save_options.SaveOptions( + function_aliases={'conv_func': model.conv} + ) + + saved_model_save.save( + model, self._input_saved_model_path, signatures, save_opts + ) + + def data_gen() -> repr_dataset.RepresentativeDataset: + rng = np.random.default_rng(seed=123) + for _ in range(2): + yield { + 'input_tensor': rng.uniform( + low=0, high=150, size=(1, 3, 4, 3) + ).astype(np.float32), + } + + tags = {tag_constants.SERVING} + + quantization_options = quant_opts_pb2.QuantizationOptions( + quantization_method=quant_opts_pb2.QuantizationMethod( + experimental_method=_ExperimentalMethod.STATIC_RANGE + ), + op_set=quant_opts_pb2.OpSet.XLA, + ) + + converted_model = quantize_model.quantize( + self._input_saved_model_path, + ['serving_default'], + tags, + self._output_saved_model_path, + quantization_options, + representative_dataset=data_gen(), + ) + + self.assertIsNotNone(converted_model) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {'serving_default'} + ) + + # Test whether the aliased function exists. + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + + # Confirm that the function alias is preserved. + meta_graph_def = output_loader.get_meta_graph_def_from_tags(tags) + function_aliases = meta_graph_def.meta_info_def.function_aliases + self.assertNotEmpty(function_aliases) + self.assertCountEqual(function_aliases.values(), {'conv_func'}) + + # Test that the aliased function contains a quantized op. + for func_name, alias in function_aliases.items(): + if alias == 'conv_func': + for func in meta_graph_def.graph_def.library.function: + if func.signature.name == func_name: + self._contains_op_with_name_and_attribute( + func.node_def, op_name='XlaConvV2', attr_name='', attr_val=None + ) + + @test_util.deprecated_graph_mode_only + def test_matmul_ptq_model_with_unfreeze_constants(self): + # Uses large weight to exceed the constant size threshold of 64KiB + # (specified by `kDefaultConstantSizeThresholdInBytes`) for unfreezing. 
+ self._create_matmul_model( + input_shape=(1, 20), + weight_shape=(20, 4096), + saved_model_path=self._input_saved_model_path, + ) + + repr_ds = self._create_data_generator( + input_key='input_tensor', shape=(1, 20), num_examples=2 + ) + + tags = {tag_constants.SERVING} + + quantization_options = quant_opts_pb2.QuantizationOptions( + quantization_method=quant_opts_pb2.QuantizationMethod( + experimental_method=_ExperimentalMethod.STATIC_RANGE + ), + freeze_all_variables=quant_opts_pb2.FreezeAllVariables(enabled=False), + ) + + converted_model = quantize_model.quantize( + self._input_saved_model_path, + ['serving_default'], + tags, + self._output_saved_model_path, + quantization_options, + representative_dataset=repr_ds, + ) + + self.assertIsNotNone(converted_model) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {'serving_default'} + ) + + # Confirms that quantization is applied to the model. + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def + self.assertTrue(self._contains_quantized_function_call(output_graphdef)) + + # Tests that there are variables in the model. + variable_node_defs = _find_variables(output_graphdef) + self.assertLen(variable_node_defs, 1) + + # Reads the variables from the checkpoint file and matches with the + # variables found in the graph. + checkpoint_path = os.path.join( + self._output_saved_model_path, 'variables', 'variables' + ) + var_name_and_shapes = checkpoint_utils.list_variables(checkpoint_path) + + # Checks that each variable's name and shape match. + self.assertEqual(len(variable_node_defs), len(var_name_and_shapes)) + for var_name, shape in var_name_and_shapes: + self.assertIn(var_name, variable_node_defs) + self.assertEqual( + shape, + tensor_shape.TensorShape( + variable_node_defs[var_name].attr['shape'].shape + ), + ) + @parameterized.named_parameters( ('use_constant', False), ('use_variable', True), ) @test_util.run_v2_only def test_gather_model(self, use_variable): - model = self._create_gather_model(use_variable) - input_saved_model_path = self.create_tempdir('input').full_path - saved_model_save.save(model, input_saved_model_path) + saved_model_save.save(model, self._input_saved_model_path) - tags = [tag_constants.SERVING] - output_directory = self.create_tempdir().full_path + tags = {tag_constants.SERVING} quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( - experimental_method=_ExperimentalMethod.STATIC_RANGE)) + experimental_method=_ExperimentalMethod.STATIC_RANGE + ) + ) data_gen = self._create_data_generator( input_key='input_tensor', shape=[6], minval=0, maxval=10, - dtype=dtypes.int64) + dtype=dtypes.int64, + ) converted_model = quantize_model.quantize( - input_saved_model_path, ['serving_default'], + self._input_saved_model_path, + ['serving_default'], tags, - output_directory, + self._output_saved_model_path, quantization_options, - representative_dataset=data_gen) + representative_dataset=data_gen, + ) self.assertIsNotNone(converted_model) - self.assertCountEqual(converted_model.signatures._signatures.keys(), - {'serving_default'}) - - output_loader = saved_model_loader.SavedModelLoader(output_directory) - output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {'serving_default'} + ) + + output_loader = saved_model_loader.SavedModelLoader( + 
self._output_saved_model_path + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def # Currently gather is not supported. - self.assertFalse( - self._contains_quantized_function_call(output_meta_graphdef)) + self.assertFalse(self._contains_quantized_function_call(output_graphdef)) @test_util.run_in_graph_and_eager_modes def test_model_ptq_use_representative_samples_list(self): - model = self._create_matmul_model() - input_savedmodel_dir = self.create_tempdir('input').full_path - saved_model_save.save(model, input_savedmodel_dir) + self._create_matmul_model( + input_shape=(1, 1024), + weight_shape=(1024, 3), + saved_model_path=self._input_saved_model_path, + ) quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( - experimental_method=_ExperimentalMethod.STATIC_RANGE)) - output_savedmodel_dir = self.create_tempdir().full_path + experimental_method=_ExperimentalMethod.STATIC_RANGE + ) + ) tags = {tag_constants.SERVING} - representative_dataset: repr_dataset.RepresentativeDataset = [{ - 'input_tensor': random_ops.random_uniform(shape=(1, 1024)), - } for _ in range(8)] + representative_dataset: repr_dataset.RepresentativeDataset = [ + { + 'input_tensor': random_ops.random_uniform(shape=(1, 1024)), + } + for _ in range(8) + ] converted_model = quantize_model.quantize( - input_savedmodel_dir, ['serving_default'], - output_directory=output_savedmodel_dir, + self._input_saved_model_path, + ['serving_default'], + tags, + self._output_saved_model_path, quantization_options=quantization_options, - representative_dataset=representative_dataset) + representative_dataset=representative_dataset, + ) self.assertIsNotNone(converted_model) - self.assertCountEqual(converted_model.signatures._signatures.keys(), - {'serving_default'}) - output_loader = saved_model_loader.SavedModelLoader(output_savedmodel_dir) - output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags) - self.assertTrue( - self._contains_quantized_function_call(output_meta_graphdef)) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {'serving_default'} + ) + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def + self.assertTrue(self._contains_quantized_function_call(output_graphdef)) @test_util.run_in_graph_and_eager_modes def test_model_ptq_use_ndarray_representative_dataset(self): - model = self._create_matmul_model() - input_savedmodel_dir = self.create_tempdir('input').full_path - saved_model_save.save(model, input_savedmodel_dir) + self._create_matmul_model( + input_shape=(1, 1024), + weight_shape=(1024, 3), + saved_model_path=self._input_saved_model_path, + ) quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( - experimental_method=_ExperimentalMethod.STATIC_RANGE)) - output_savedmodel_dir = self.create_tempdir().full_path + experimental_method=_ExperimentalMethod.STATIC_RANGE + ) + ) tags = {tag_constants.SERVING} # Use np.ndarrays instead of tf.Tensors for the representative dataset. 
- representative_dataset = [{ - 'input_tensor': np.random.uniform(size=(1, 1024)).astype(np.float32), - } for _ in range(4)] + representative_dataset = [ + { + 'input_tensor': np.random.uniform(size=(1, 1024)).astype( + np.float32 + ), + } + for _ in range(4) + ] converted_model = quantize_model.quantize( - input_savedmodel_dir, ['serving_default'], - tags=tags, - output_directory=output_savedmodel_dir, + self._input_saved_model_path, + ['serving_default'], + tags, + self._output_saved_model_path, quantization_options=quantization_options, - representative_dataset=representative_dataset) + representative_dataset=representative_dataset, + ) self.assertIsNotNone(converted_model) - self.assertCountEqual(converted_model.signatures._signatures.keys(), - {'serving_default'}) - output_loader = saved_model_loader.SavedModelLoader(output_savedmodel_dir) - output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags) - self.assertTrue( - self._contains_quantized_function_call(output_meta_graphdef)) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {'serving_default'} + ) + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def + self.assertTrue(self._contains_quantized_function_call(output_graphdef)) @test_util.run_in_graph_and_eager_modes def test_model_ptq_use_python_list_representative_dataset(self): - model = self._create_matmul_model() - input_savedmodel_dir = self.create_tempdir('input').full_path - saved_model_save.save(model, input_savedmodel_dir) + self._create_matmul_model( + input_shape=(1, 1024), + weight_shape=(1024, 3), + saved_model_path=self._input_saved_model_path, + ) quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( - experimental_method=_ExperimentalMethod.STATIC_RANGE)) - output_savedmodel_dir = self.create_tempdir().full_path + experimental_method=_ExperimentalMethod.STATIC_RANGE + ) + ) tags = {tag_constants.SERVING} # Use plain python lists as representative samples. 
- representative_dataset = [{ - 'input_tensor': [[i * 0.1 for i in range(1024)]], - } for _ in range(4)] + representative_dataset = [ + { + 'input_tensor': [[i * 0.1 for i in range(1024)]], + } + for _ in range(4) + ] converted_model = quantize_model.quantize( - input_savedmodel_dir, ['serving_default'], - tags=tags, - output_directory=output_savedmodel_dir, + self._input_saved_model_path, + ['serving_default'], + tags, + self._output_saved_model_path, quantization_options=quantization_options, - representative_dataset=representative_dataset) + representative_dataset=representative_dataset, + ) self.assertIsNotNone(converted_model) - self.assertCountEqual(converted_model.signatures._signatures.keys(), - {'serving_default'}) - output_loader = saved_model_loader.SavedModelLoader(output_savedmodel_dir) - output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags) - self.assertTrue( - self._contains_quantized_function_call(output_meta_graphdef)) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {'serving_default'} + ) + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def + self.assertTrue(self._contains_quantized_function_call(output_graphdef)) @test_util.run_in_graph_and_eager_modes def test_model_ptq_call_twice(self): - model = self._create_matmul_model() - input_savedmodel_dir = self.create_tempdir('input').full_path - saved_model_save.save(model, input_savedmodel_dir) + self._create_matmul_model( + input_shape=(1, 1024), + weight_shape=(1024, 3), + saved_model_path=self._input_saved_model_path, + ) quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( - experimental_method=_ExperimentalMethod.STATIC_RANGE)) - output_savedmodel_dir_1 = self.create_tempdir().full_path + experimental_method=_ExperimentalMethod.STATIC_RANGE + ) + ) tags = {tag_constants.SERVING} signature_def_keys = [signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] - representative_dataset: repr_dataset.RepresentativeDataset = [{ - 'input_tensor': random_ops.random_uniform(shape=(1, 1024)), - } for _ in range(8)] + representative_dataset: repr_dataset.RepresentativeDataset = [ + { + 'input_tensor': random_ops.random_uniform(shape=(1, 1024)), + } + for _ in range(8) + ] # Test the first run. converted_model_1 = quantize_model.quantize( - input_savedmodel_dir, + self._input_saved_model_path, signature_def_keys, - output_directory=output_savedmodel_dir_1, + output_directory=self._output_saved_model_path, quantization_options=quantization_options, - representative_dataset=representative_dataset) + representative_dataset=representative_dataset, + ) self.assertIsNotNone(converted_model_1) - self.assertCountEqual(converted_model_1.signatures._signatures.keys(), - signature_def_keys) - output_loader = saved_model_loader.SavedModelLoader(output_savedmodel_dir_1) - output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags) - self.assertTrue( - self._contains_quantized_function_call(output_meta_graphdef)) + self.assertCountEqual( + converted_model_1.signatures._signatures.keys(), signature_def_keys + ) + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def + self.assertTrue(self._contains_quantized_function_call(output_graphdef)) # Test the second run on the same model. 
- output_savedmodel_dir_2 = self.create_tempdir().full_path converted_model_2 = quantize_model.quantize( - input_savedmodel_dir, + self._input_saved_model_path, signature_def_keys, - output_directory=output_savedmodel_dir_2, + output_directory=self._output_saved_model_path_2, quantization_options=quantization_options, - representative_dataset=representative_dataset) + representative_dataset=representative_dataset, + ) self.assertIsNotNone(converted_model_2) - self.assertCountEqual(converted_model_2.signatures._signatures.keys(), - signature_def_keys) - output_loader = saved_model_loader.SavedModelLoader(output_savedmodel_dir_2) - output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags) - self.assertTrue( - self._contains_quantized_function_call(output_meta_graphdef)) + self.assertCountEqual( + converted_model_2.signatures._signatures.keys(), signature_def_keys + ) + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path_2 + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def + self.assertTrue(self._contains_quantized_function_call(output_graphdef)) # tf.data.Dataset is as an Iterable (thus can be used as representative # dataset) only in TF2 (eager mode). @test_util.run_v2_only def test_model_ptq_use_tf_dataset_for_representative_dataset(self): - model = self._create_matmul_model() - input_savedmodel_dir = self.create_tempdir('input').full_path - saved_model_save.save(model, input_savedmodel_dir) + self._create_matmul_model( + input_shape=(1, 1024), + weight_shape=(1024, 3), + saved_model_path=self._input_saved_model_path, + ) quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( - experimental_method=_ExperimentalMethod.STATIC_RANGE)) - output_savedmodel_dir = self.create_tempdir().full_path + experimental_method=_ExperimentalMethod.STATIC_RANGE + ) + ) tags = {tag_constants.SERVING} - representative_samples = [{ - 'input_tensor': random_ops.random_uniform(shape=(1, 1024)), - } for _ in range(8)] + representative_samples = [ + { + 'input_tensor': random_ops.random_uniform(shape=(1, 1024)), + } + for _ in range(8) + ] # Construct a tf.data.Dataset from the representative samples. 
representative_dataset = dataset_ops.DatasetV2.from_generator( lambda: representative_samples, output_signature={ - 'input_tensor': - tensor_spec.TensorSpec(shape=(1, 1024), dtype=dtypes.float32), - }) + 'input_tensor': tensor_spec.TensorSpec( + shape=(1, 1024), dtype=dtypes.float32 + ), + }, + ) converted_model = quantize_model.quantize( - input_savedmodel_dir, ['serving_default'], - output_directory=output_savedmodel_dir, + self._input_saved_model_path, + ['serving_default'], + tags, + self._output_saved_model_path, quantization_options=quantization_options, - representative_dataset=representative_dataset) + representative_dataset=representative_dataset, + ) self.assertIsNotNone(converted_model) - self.assertCountEqual(converted_model.signatures._signatures.keys(), - {'serving_default'}) - output_loader = saved_model_loader.SavedModelLoader(output_savedmodel_dir) - output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags) - self.assertTrue( - self._contains_quantized_function_call(output_meta_graphdef)) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {'serving_default'} + ) + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def + self.assertTrue(self._contains_quantized_function_call(output_graphdef)) @test_util.run_in_graph_and_eager_modes def test_model_ptq_no_representative_sample_shows_warnings(self): - model = self._create_matmul_model() - input_savedmodel_dir = self.create_tempdir('input').full_path - output_savedmodel_dir = self.create_tempdir().full_path - saved_model_save.save(model, input_savedmodel_dir) + self._create_matmul_model( + input_shape=(1, 1024), + weight_shape=(1024, 3), + saved_model_path=self._input_saved_model_path, + ) - tags = [tag_constants.SERVING] + tags = {tag_constants.SERVING} quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( - experimental_method=_ExperimentalMethod.STATIC_RANGE)) + experimental_method=_ExperimentalMethod.STATIC_RANGE + ) + ) with self.assertLogs(level='WARN') as warning_logs: # Save the logger verbosity. @@ -953,14 +2379,15 @@ def test_model_ptq_no_representative_sample_shows_warnings(self): try: converted_model = quantize_model.quantize( - input_savedmodel_dir, + self._input_saved_model_path, ['serving_default'], tags, - output_savedmodel_dir, + self._output_saved_model_path, quantization_options, # Put no sample into the representative dataset to make calibration # impossible. - representative_dataset=[]) + representative_dataset=[], + ) finally: # Restore the logger verbosity. logging.set_verbosity(prev_log_level) @@ -969,29 +2396,35 @@ def test_model_ptq_no_representative_sample_shows_warnings(self): # Warning message should contain the function name. 
self.assertTrue( - self._any_warning_contains('matmul', warning_logs.records)) + self._any_warning_contains('matmul', warning_logs.records) + ) self.assertTrue( - self._any_warning_contains('does not have min or max values', - warning_logs.records)) + self._any_warning_contains( + 'does not have min or max values', warning_logs.records + ) + ) self.assertIsNotNone(converted_model) - self.assertCountEqual(converted_model.signatures._signatures.keys(), - {'serving_default'}) - output_loader = saved_model_loader.SavedModelLoader(output_savedmodel_dir) - output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {'serving_default'} + ) + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def # Model is not quantized because there was no sample data for calibration. - self.assertFalse( - self._contains_quantized_function_call(output_meta_graphdef)) + self.assertFalse(self._contains_quantized_function_call(output_graphdef)) @test_util.run_in_graph_and_eager_modes def test_model_ptq_with_uncalibrated_subgraph(self): - class IfModel(module.Module): """A model that contains a branching op.""" - @def_function.function(input_signature=[ - tensor_spec.TensorSpec(shape=[1, 4], dtype=dtypes.float32) - ]) + @def_function.function( + input_signature=[ + tensor_spec.TensorSpec(shape=[1, 4], dtype=dtypes.float32) + ] + ) def model_fn(self, x: core.Tensor) -> Mapping[str, core.Tensor]: """Runs the input tensor to a branched operations. @@ -1005,39 +2438,40 @@ def model_fn(self, x: core.Tensor) -> Mapping[str, core.Tensor]: A map of: output key -> output result. """ if math_ops.reduce_sum(x) > 10.0: - filters = np.random.uniform( - low=-1.0, high=1.0, size=(4, 3)).astype('f4') + filters = np.random.uniform(low=-1.0, high=1.0, size=(4, 3)).astype( + 'f4' + ) bias = np.random.uniform(low=-1.0, high=1.0, size=(3,)).astype('f4') out = math_ops.matmul(x, filters) out = nn_ops.bias_add(out, bias) return {'output': out} - filters = np.random.uniform( - low=-1.0, high=1.0, size=(4, 3)).astype('f4') + filters = np.random.uniform(low=-1.0, high=1.0, size=(4, 3)).astype( + 'f4' + ) bias = np.random.uniform(low=-1.0, high=1.0, size=(3,)).astype('f4') out = math_ops.matmul(x, filters) out = nn_ops.bias_add(out, bias) return {'output': out} model = IfModel() - input_saved_model_path = self.create_tempdir('input').full_path - saved_model_save.save(model, input_saved_model_path) + saved_model_save.save(model, self._input_saved_model_path) def data_gen() -> repr_dataset.RepresentativeDataset: for _ in range(8): yield { - 'x': - ops.convert_to_tensor( - np.random.uniform(low=0.0, high=1.0, - size=(1, 4)).astype('f4')), + 'x': ops.convert_to_tensor( + np.random.uniform(low=0.0, high=1.0, size=(1, 4)).astype('f4') + ), } - tags = [tag_constants.SERVING] - output_directory = self.create_tempdir().full_path + tags = {tag_constants.SERVING} quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( - experimental_method=_ExperimentalMethod.STATIC_RANGE)) + experimental_method=_ExperimentalMethod.STATIC_RANGE + ) + ) with self.assertLogs(level='WARN') as warning_logs: # Save the logger verbosity. 
@@ -1046,11 +2480,13 @@ def data_gen() -> repr_dataset.RepresentativeDataset: try: converted_model = quantize_model.quantize( - input_saved_model_path, ['serving_default'], + self._input_saved_model_path, + ['serving_default'], tags, - output_directory, + self._output_saved_model_path, quantization_options, - representative_dataset=data_gen()) + representative_dataset=data_gen(), + ) finally: # Restore the logger verbosity. logging.set_verbosity(log_level) @@ -1061,20 +2497,26 @@ def data_gen() -> repr_dataset.RepresentativeDataset: # is when the condition is true, so 'cond_true' function must be part of # the warning message. self.assertTrue( - self._any_warning_contains('cond_true', warning_logs.records)) + self._any_warning_contains('cond_true', warning_logs.records) + ) self.assertFalse( - self._any_warning_contains('cond_false', warning_logs.records)) + self._any_warning_contains('cond_false', warning_logs.records) + ) self.assertTrue( - self._any_warning_contains('does not have min or max values', - warning_logs.records)) + self._any_warning_contains( + 'does not have min or max values', warning_logs.records + ) + ) self.assertIsNotNone(converted_model) - self.assertCountEqual(converted_model.signatures._signatures.keys(), - {'serving_default'}) - output_loader = saved_model_loader.SavedModelLoader(output_directory) - output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags) - self.assertTrue( - self._contains_quantized_function_call(output_meta_graphdef)) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {'serving_default'} + ) + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def + self.assertTrue(self._contains_quantized_function_call(output_graphdef)) # Run this test only with the eager mode. @test_util.run_v2_only @@ -1083,21 +2525,22 @@ def test_ptq_model_with_multiple_signatures(self): model = MultipleSignatureModel() signatures = { - 'sig1': - model.matmul.get_concrete_function( - tensor_spec.TensorSpec(shape=(1, 4), dtype=dtypes.float32)), - 'sig2': - model.conv.get_concrete_function( - tensor_spec.TensorSpec( - shape=(1, 3, 4, 3), dtype=dtypes.float32)), + 'sig1': model.matmul.get_concrete_function( + tensor_spec.TensorSpec(shape=(1, 4), dtype=dtypes.float32) + ), + 'sig2': model.conv.get_concrete_function( + tensor_spec.TensorSpec(shape=(1, 3, 4, 3), dtype=dtypes.float32) + ), } - input_saved_model_path = self.create_tempdir('input').full_path - saved_model_save.save(model, input_saved_model_path, signatures=signatures) + saved_model_save.save( + model, self._input_saved_model_path, signatures=signatures + ) - output_directory = self.create_tempdir().full_path quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( - experimental_method=_ExperimentalMethod.STATIC_RANGE)) + experimental_method=_ExperimentalMethod.STATIC_RANGE + ) + ) def data_gen_sig1() -> repr_dataset.RepresentativeDataset: """Generates tuple-style samples for signature 'sig1'. 
@@ -1125,23 +2568,26 @@ def data_gen_sig2() -> repr_dataset.RepresentativeDataset: tags = {tag_constants.SERVING} converted_model = quantize_model.quantize( - input_saved_model_path, + self._input_saved_model_path, signature_keys=['sig1', 'sig2'], tags=tags, - output_directory=output_directory, + output_directory=self._output_saved_model_path, quantization_options=quantization_options, representative_dataset={ 'sig1': data_gen_sig1(), 'sig2': data_gen_sig2(), - }) + }, + ) self.assertIsNotNone(converted_model) - self.assertCountEqual(converted_model.signatures._signatures.keys(), - {'sig1', 'sig2'}) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {'sig1', 'sig2'} + ) - output_loader = saved_model_loader.SavedModelLoader(output_directory) - output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags) - self.assertTrue( - self._contains_quantized_function_call(output_meta_graphdef)) + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def + self.assertTrue(self._contains_quantized_function_call(output_graphdef)) # Run this test only with the eager mode. @test_util.run_v2_only @@ -1150,77 +2596,85 @@ def test_ptq_multiple_signatures_invalid_dataset_raises_value_error(self): model = MultipleSignatureModel() signatures = { - 'sig1': - model.matmul.get_concrete_function( - tensor_spec.TensorSpec(shape=(1, 4), dtype=dtypes.float32)), - 'sig2': - model.conv.get_concrete_function( - tensor_spec.TensorSpec( - shape=(1, 3, 4, 3), dtype=dtypes.float32)), + 'sig1': model.matmul.get_concrete_function( + tensor_spec.TensorSpec(shape=(1, 4), dtype=dtypes.float32) + ), + 'sig2': model.conv.get_concrete_function( + tensor_spec.TensorSpec(shape=(1, 3, 4, 3), dtype=dtypes.float32) + ), } - input_saved_model_path = self.create_tempdir('input').full_path - saved_model_save.save(model, input_saved_model_path, signatures=signatures) + saved_model_save.save( + model, self._input_saved_model_path, signatures=signatures + ) - output_directory = self.create_tempdir().full_path quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( - experimental_method=_ExperimentalMethod.STATIC_RANGE)) + experimental_method=_ExperimentalMethod.STATIC_RANGE + ) + ) # Use a dict-style samples instead of tuple-style samples. This is invalid # because for a model multiple signatures one must use tuple-style samples. 
- invalid_dataset: repr_dataset.RepresentativeDataset = [{ - 'matmul_input': random_ops.random_uniform(shape=(1, 4)) - } for _ in range(8)] + invalid_dataset: repr_dataset.RepresentativeDataset = [ + {'matmul_input': random_ops.random_uniform(shape=(1, 4))} + for _ in range(8) + ] with self.assertRaisesRegex(ValueError, 'Invalid representative dataset.'): quantize_model.quantize( - input_saved_model_path, + self._input_saved_model_path, signature_keys=['sig1', 'sig2'], tags={tag_constants.SERVING}, - output_directory=output_directory, + output_directory=self._output_saved_model_path, quantization_options=quantization_options, - representative_dataset=invalid_dataset) + representative_dataset=invalid_dataset, + ) @test_util.run_in_graph_and_eager_modes def test_ptq_model_with_tf1_saved_model_with_variable_for_conv2d(self): - input_saved_model_path = self.create_tempdir('input').full_path signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY tags = {tag_constants.SERVING} input_placeholder = self._create_and_save_tf1_conv_model( - input_saved_model_path, + self._input_saved_model_path, signature_key, tags, input_key='x', output_key='output', - use_variable=True) + use_variable=True, + ) signature_keys = [signature_key] - output_directory = self.create_tempdir().full_path quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( - experimental_method=_ExperimentalMethod.STATIC_RANGE)) + experimental_method=_ExperimentalMethod.STATIC_RANGE + ) + ) data_gen = self._create_data_generator( - input_key='x', shape=input_placeholder.shape) + input_key='x', shape=input_placeholder.shape + ) converted_model = quantize_model.quantize( - input_saved_model_path, + self._input_saved_model_path, signature_keys, tags, - output_directory, + self._output_saved_model_path, quantization_options, - representative_dataset=data_gen) + representative_dataset=data_gen, + ) self.assertIsNotNone(converted_model) - self.assertCountEqual(converted_model.signatures._signatures.keys(), - signature_keys) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), signature_keys + ) - output_loader = saved_model_loader.SavedModelLoader(output_directory) - output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags) - self.assertTrue( - self._contains_quantized_function_call(output_meta_graphdef)) + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def + self.assertTrue(self._contains_quantized_function_call(output_graphdef)) @parameterized.named_parameters( ('use_constant', False), @@ -1228,95 +2682,178 @@ def test_ptq_model_with_tf1_saved_model_with_variable_for_conv2d(self): ) @test_util.run_in_graph_and_eager_modes def test_ptq_model_with_tf1_saved_model_with_variable_for_gather( - self, use_variable): - input_saved_model_path = self.create_tempdir('input').full_path + self, use_variable + ): signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY tags = {tag_constants.SERVING} input_placeholder = self._create_and_save_tf1_gather_model( - input_saved_model_path, + self._input_saved_model_path, signature_key, tags, input_key='x', output_key='output', - use_variable=use_variable) + use_variable=use_variable, + ) signature_keys = [signature_key] - output_directory = self.create_tempdir().full_path quantization_options = quant_opts_pb2.QuantizationOptions( 
quantization_method=quant_opts_pb2.QuantizationMethod( - experimental_method=_ExperimentalMethod.STATIC_RANGE)) + experimental_method=_ExperimentalMethod.STATIC_RANGE + ) + ) data_gen = self._create_data_generator( input_key='x', shape=input_placeholder.shape, minval=0, maxval=10, - dtype=dtypes.int64) + dtype=dtypes.int64, + ) converted_model = quantize_model.quantize( - input_saved_model_path, + self._input_saved_model_path, signature_keys, tags, - output_directory, + self._output_saved_model_path, quantization_options, - representative_dataset=data_gen) + representative_dataset=data_gen, + ) self.assertIsNotNone(converted_model) - self.assertCountEqual(converted_model.signatures._signatures.keys(), - signature_keys) - - output_loader = saved_model_loader.SavedModelLoader(output_directory) - output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), signature_keys + ) + + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def # Quantization is not currently supported for gather. - self.assertFalse( - self._contains_quantized_function_call(output_meta_graphdef)) + self.assertFalse(self._contains_quantized_function_call(output_graphdef)) + + @test_util.deprecated_graph_mode_only + def test_ptq_model_with_variable_tf1_saved_model_unfreeze_constants(self): + signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + tags = {tag_constants.SERVING} + + input_placeholder = self._create_and_save_tf1_conv_model( + self._input_saved_model_path, + signature_key, + tags, + input_key='x', + output_key='output', + input_shape=(1, 16, 16, 8), + # Uses large filter to exceed the constant size threshold of 64KiB + # (specified by `kDefaultConstantSizeThresholdInBytes`) for unfreezing. + filter_shape=(256, 8, 8, 16), + use_variable=True, + ) + + signature_keys = [signature_key] + + quantization_options = quant_opts_pb2.QuantizationOptions( + quantization_method=quant_opts_pb2.QuantizationMethod( + experimental_method=_ExperimentalMethod.STATIC_RANGE + ), + freeze_all_variables=quant_opts_pb2.FreezeAllVariables(enabled=False), + ) + + repr_ds = self._create_data_generator( + input_key='x', shape=input_placeholder.shape, num_examples=2 + ) + + converted_model = quantize_model.quantize( + self._input_saved_model_path, + signature_keys, + tags, + self._output_saved_model_path, + quantization_options, + representative_dataset=repr_ds, + ) + self.assertIsNotNone(converted_model) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {'serving_default'} + ) + + # Checks that quantization is applied. + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def + self.assertTrue(self._contains_quantized_function_call(output_graphdef)) + + # Tests that there are variables in the model. + variable_node_defs = _find_variables(output_graphdef) + self.assertLen(variable_node_defs, 1) + + # Reads the variables from the checkpoint file and matches with the + # variables found in the graph. + checkpoint_path = os.path.join( + self._output_saved_model_path, 'variables', 'variables' + ) + var_name_and_shapes = checkpoint_utils.list_variables(checkpoint_path) + + # Checks that each variable's name and shape match. 
+ self.assertEqual(len(variable_node_defs), len(var_name_and_shapes)) + for var_name, shape in var_name_and_shapes: + self.assertIn(var_name, variable_node_defs) + self.assertEqual( + shape, + tensor_shape.TensorShape( + variable_node_defs[var_name].attr['shape'].shape + ), + ) @test_util.run_in_graph_and_eager_modes def test_ptq_model_with_tf1_saved_model(self): - input_saved_model_path = self.create_tempdir('input').full_path tags = {tag_constants.SERVING} signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY input_placeholder = self._create_and_save_tf1_conv_model( - input_saved_model_path, + self._input_saved_model_path, signature_key, tags, input_key='p', output_key='output', - use_variable=False) + use_variable=False, + ) signature_keys = [signature_key] - output_directory = self.create_tempdir().full_path quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( - experimental_method=_ExperimentalMethod.STATIC_RANGE)) + experimental_method=_ExperimentalMethod.STATIC_RANGE + ) + ) data_gen = self._create_data_generator( - input_key='p', shape=input_placeholder.shape) + input_key='p', shape=input_placeholder.shape + ) converted_model = quantize_model.quantize( - input_saved_model_path, + self._input_saved_model_path, signature_keys, tags, - output_directory, + self._output_saved_model_path, quantization_options, - representative_dataset=data_gen) + representative_dataset=data_gen, + ) self.assertIsNotNone(converted_model) - self.assertCountEqual(converted_model.signatures._signatures.keys(), - signature_keys) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), signature_keys + ) - output_loader = saved_model_loader.SavedModelLoader(output_directory) - output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags) - self.assertTrue( - self._contains_quantized_function_call(output_meta_graphdef)) + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def + self.assertTrue(self._contains_quantized_function_call(output_graphdef)) @test_util.run_in_graph_and_eager_modes def test_ptq_model_with_tf1_saved_model_multiple_signatures(self): - input_saved_model_path = self.create_tempdir('input').full_path tags = {tag_constants.SERVING} # Create two models and add them to a same SavedModel under different @@ -1324,25 +2861,31 @@ def test_ptq_model_with_tf1_saved_model_multiple_signatures(self): with ops.Graph().as_default(), session.Session() as sess: in_placeholder_1, output_tensor_1 = self._create_simple_tf1_conv_model() sig_def_1 = signature_def_utils_impl.predict_signature_def( - inputs={'x1': in_placeholder_1}, outputs={'output1': output_tensor_1}) + inputs={'x1': in_placeholder_1}, outputs={'output1': output_tensor_1} + ) in_placeholder_2, output_tensor_2 = self._create_simple_tf1_conv_model() sig_def_2 = signature_def_utils_impl.predict_signature_def( - inputs={'x2': in_placeholder_2}, outputs={'output2': output_tensor_2}) + inputs={'x2': in_placeholder_2}, outputs={'output2': output_tensor_2} + ) - v1_builder = builder.SavedModelBuilder(input_saved_model_path) + v1_builder = builder.SavedModelBuilder(self._input_saved_model_path) v1_builder.add_meta_graph_and_variables( - sess, tags, signature_def_map={ + sess, + tags, + signature_def_map={ 'sig1': sig_def_1, 'sig2': sig_def_2, - }) + }, + ) v1_builder.save() - output_directory = self.create_tempdir().full_path 
quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( - experimental_method=_ExperimentalMethod.STATIC_RANGE)) + experimental_method=_ExperimentalMethod.STATIC_RANGE + ) + ) def data_gen_sig1() -> repr_dataset.RepresentativeDataset: """Generates tuple-style samples. @@ -1369,196 +2912,213 @@ def data_gen_sig2() -> repr_dataset.RepresentativeDataset: yield {'x2': random_ops.random_uniform(shape=in_placeholder_2.shape)} converted_model = quantize_model.quantize( - input_saved_model_path, + self._input_saved_model_path, signature_keys=['sig1', 'sig2'], tags=tags, - output_directory=output_directory, + output_directory=self._output_saved_model_path, quantization_options=quantization_options, representative_dataset={ 'sig1': data_gen_sig1(), 'sig2': data_gen_sig2(), - }) + }, + ) self.assertIsNotNone(converted_model) - self.assertCountEqual(converted_model.signatures._signatures.keys(), - {'sig1', 'sig2'}) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {'sig1', 'sig2'} + ) - output_loader = saved_model_loader.SavedModelLoader(output_directory) - output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags) - self.assertTrue( - self._contains_quantized_function_call(output_meta_graphdef)) + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def + self.assertTrue(self._contains_quantized_function_call(output_graphdef)) @test_util.run_in_graph_and_eager_modes def test_ptq_model_with_tf1_saved_model_invalid_input_key_raises_value_error( - self): - input_saved_model_path = self.create_tempdir('input').full_path + self, + ): signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY tags = {tag_constants.SERVING} input_placeholder = self._create_and_save_tf1_conv_model( - input_saved_model_path, + self._input_saved_model_path, signature_key, tags, input_key='x', output_key='output', - use_variable=False) + use_variable=False, + ) signature_keys = [signature_key] - output_directory = self.create_tempdir().full_path quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( - experimental_method=_ExperimentalMethod.STATIC_RANGE)) + experimental_method=_ExperimentalMethod.STATIC_RANGE + ) + ) # Representative generator function that yields with an invalid input key. invalid_data_gen = self._create_data_generator( - input_key='invalid_input_key', shape=input_placeholder.shape) + input_key='invalid_input_key', shape=input_placeholder.shape + ) with self.assertRaisesRegex( ValueError, - 'Failed to run graph for post-training quantization calibration'): + 'Failed to run graph for post-training quantization calibration', + ): quantize_model.quantize( - input_saved_model_path, + self._input_saved_model_path, signature_keys, tags, - output_directory, + self._output_saved_model_path, quantization_options, - representative_dataset=invalid_data_gen) + representative_dataset=invalid_data_gen, + ) @test_util.run_in_graph_and_eager_modes def test_ptq_model_with_non_default_tags(self): - input_saved_model_path = self.create_tempdir('input').full_path signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY # Use a different set of tags other than {"serve"}. tags = {tag_constants.TRAINING, tag_constants.GPU} # Non-default tags are usually used when saving multiple metagraphs in TF1. 
input_placeholder = self._create_and_save_tf1_conv_model( - input_saved_model_path, + self._input_saved_model_path, signature_key, tags, input_key='input', output_key='output', - use_variable=True) + use_variable=True, + ) signature_keys = [signature_key] - output_directory = self.create_tempdir().full_path quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( - experimental_method=_ExperimentalMethod.STATIC_RANGE)) + experimental_method=_ExperimentalMethod.STATIC_RANGE + ) + ) data_gen = self._create_data_generator( - input_key='input', shape=input_placeholder.shape) + input_key='input', shape=input_placeholder.shape + ) converted_model = quantize_model.quantize( - input_saved_model_path, + self._input_saved_model_path, signature_keys, tags, - output_directory, + self._output_saved_model_path, quantization_options, - representative_dataset=data_gen) + representative_dataset=data_gen, + ) self.assertIsNotNone(converted_model) - self.assertCountEqual(converted_model.signatures._signatures.keys(), - signature_keys) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), signature_keys + ) - output_loader = saved_model_loader.SavedModelLoader(output_directory) - output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags) - self.assertTrue( - self._contains_quantized_function_call(output_meta_graphdef)) + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def + self.assertTrue(self._contains_quantized_function_call(output_graphdef)) @test_util.run_in_graph_and_eager_modes def test_ptq_model_with_wrong_tags_raises_error(self): - input_saved_model_path = self.create_tempdir('input').full_path signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY save_tags = {tag_constants.TRAINING, tag_constants.GPU} input_placeholder = self._create_and_save_tf1_conv_model( - input_saved_model_path, + self._input_saved_model_path, signature_key, save_tags, input_key='input', output_key='output', - use_variable=True) + use_variable=True, + ) signature_keys = [signature_key] - output_directory = self.create_tempdir().full_path quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( - experimental_method=_ExperimentalMethod.STATIC_RANGE)) + experimental_method=_ExperimentalMethod.STATIC_RANGE + ) + ) # Try to use a different set of tags to quantize. tags = {tag_constants.SERVING} data_gen = self._create_data_generator( - input_key='input', shape=input_placeholder.shape) - with self.assertRaisesRegex(RuntimeError, - 'Failed to retrieve MetaGraphDef'): + input_key='input', shape=input_placeholder.shape + ) + with self.assertRaisesRegex( + RuntimeError, + "MetaGraphDef associated with tags {'serve'} could not be found", + ): quantize_model.quantize( - input_saved_model_path, + self._input_saved_model_path, signature_keys, tags, - output_directory, + self._output_saved_model_path, quantization_options, - representative_dataset=data_gen) + representative_dataset=data_gen, + ) # TODO(b/244276332): Allow table initialization in TF2 eager mode. 
@test_util.deprecated_graph_mode_only def test_ptq_vocab_table_lookup_model(self): tags = {tag_constants.SERVING} signature_def_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY - input_model_dir = self.create_tempdir('input').full_path - with session.Session() as sess: - input_vocabs_placeholder, lookup_tensor, output_tensor = ( - self._create_vocab_table_lookup_model_tf1(sess)) + # Create and save a simple model that involves a hash table. + inputs, outputs = self._create_and_save_vocab_table_lookup_model_tf1( + self._input_saved_model_path, tags, signature_def_key + ) - self._save_tf1_model( - sess, - input_model_dir, - signature_def_key, - tags, - inputs={'input_vocabs': input_vocabs_placeholder}, - outputs={ - 'lookup': lookup_tensor, # Table lookup values. - 'output': output_tensor, - }, - init_op=lookup_ops.tables_initializer(), - assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)) + # Make sure that the desired input key and output key is present. + self.assertIn('input_vocabs', inputs.keys()) + self.assertIn('lookup', outputs.keys()) # Representative dataset is composed of a set of vocabs for table lookup. - repr_ds = [{ - 'input_vocabs': np.array([b'hello', b'model', b'quantization']) - } for _ in range(4)] + repr_ds = [ + {'input_vocabs': np.array([b'hello', b'model', b'quantization'])} + for _ in range(4) + ] quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( - experimental_method=_ExperimentalMethod.STATIC_RANGE)) + experimental_method=_ExperimentalMethod.STATIC_RANGE + ) + ) signature_def_keys = [signature_def_key] - output_model_dir = self.create_tempdir('output').full_path quantize_model.quantize( - input_model_dir, + self._input_saved_model_path, signature_def_keys, tags, - output_model_dir, + self._output_saved_model_path, quantization_options, - representative_dataset=repr_ds) + representative_dataset=repr_ds, + ) # Tests table lookup to make sure the table has been initialized # successfully. with session.Session(graph=ops.Graph()) as sess: output_meta_graph_def = saved_model_loader.load( - sess, tags=tags, export_dir=output_model_dir) + sess, tags=tags, export_dir=self._output_saved_model_path + ) # The graph should contain a quantized function call (it contains a # single f32 matmul node). 
self.assertTrue( - self._contains_quantized_function_call(output_meta_graph_def)) - self.assertCountEqual(output_meta_graph_def.signature_def.keys(), - signature_def_keys) + self._contains_quantized_function_call( + output_meta_graph_def.graph_def + ) + ) + self.assertCountEqual( + output_meta_graph_def.signature_def.keys(), signature_def_keys + ) signature_def = output_meta_graph_def.signature_def[signature_def_key] @@ -1572,35 +3132,90 @@ def test_ptq_vocab_table_lookup_model(self): lookup_tensor, feed_dict={ input_tensor: np.array([b'model', b'quantization', b'hello']) - }) + }, + ) - self.assertAllClose(lookup_val, [1., 2., 0.]) + self.assertAllClose(lookup_val, [1.0, 2.0, 0.0]) @parameterized.named_parameters( ('none', None, False, False, quant_opts_pb2.TF, False, 'SAME'), ('relu', nn_ops.relu, False, False, quant_opts_pb2.TF, False, 'SAME'), ('relu6', nn_ops.relu6, False, False, quant_opts_pb2.TF, False, 'SAME'), ('with_bias', None, True, False, quant_opts_pb2.TF, False, 'SAME'), - ('with_bias_and_relu', nn_ops.relu, True, False, quant_opts_pb2.TF, False, - 'SAME'), - ('with_bias_and_relu6', nn_ops.relu6, True, False, quant_opts_pb2.TF, - False, 'SAME'), + ( + 'with_bias_and_relu', + nn_ops.relu, + True, + False, + quant_opts_pb2.TF, + False, + 'SAME', + ), + ( + 'with_bias_and_relu6', + nn_ops.relu6, + True, + False, + quant_opts_pb2.TF, + False, + 'SAME', + ), ('none_to_xla', None, False, False, quant_opts_pb2.XLA, False, 'SAME'), - ('with_bias_and_relu6_to_xla', nn_ops.relu6, True, False, - quant_opts_pb2.XLA, False, 'SAME'), - ('with_bias_to_xla_dynamic', None, True, False, quant_opts_pb2.XLA, True, - 'SAME'), - ('none_to_xla_padding_valid', None, False, False, quant_opts_pb2.XLA, - False, 'VALID'), - ('with_bias_and_relu6_to_xla_padding_valid', nn_ops.relu6, True, False, - quant_opts_pb2.XLA, False, 'VALID'), - ('with_bias_to_xla_dynamic_padding_valid', None, True, False, - quant_opts_pb2.XLA, True, 'VALID'), + ( + 'with_bias_and_relu6_to_xla', + nn_ops.relu6, + True, + False, + quant_opts_pb2.XLA, + False, + 'SAME', + ), + ( + 'with_bias_to_xla_dynamic', + None, + True, + False, + quant_opts_pb2.XLA, + True, + 'SAME', + ), + ( + 'none_to_xla_padding_valid', + None, + False, + False, + quant_opts_pb2.XLA, + False, + 'VALID', + ), + ( + 'with_bias_and_relu6_to_xla_padding_valid', + nn_ops.relu6, + True, + False, + quant_opts_pb2.XLA, + False, + 'VALID', + ), + ( + 'with_bias_to_xla_dynamic_padding_valid', + None, + True, + False, + quant_opts_pb2.XLA, + True, + 'VALID', + ), ) - def test_conv3d_ptq_model(self, activation_fn: Optional[ops.Operation], - has_bias: bool, has_bn: bool, - target_opset: quant_opts_pb2.OpSet, - input_shape_dynamic: bool, padding: str): + def test_conv3d_ptq_model( + self, + activation_fn: Optional[ops.Operation], + has_bias: bool, + has_batch_norm: bool, + target_opset: quant_opts_pb2.OpSet, + input_shape_dynamic: bool, + padding: str, + ): input_shape = [1, 3, 4, 3, 3] if input_shape_dynamic: input_shape = [None, None, None, None, 3] @@ -1609,12 +3224,15 @@ class ConvModel(module.Module): def __init__(self): self.filters = np.random.uniform( - low=-0.5, high=0.5, size=(2, 3, 3, 3, 2)).astype('f4') + low=-0.5, high=0.5, size=(2, 3, 3, 3, 2) + ).astype('f4') self.bias = np.random.uniform(low=0.0, high=0.2, size=(2)).astype('f4') - @def_function.function(input_signature=[ - tensor_spec.TensorSpec(shape=input_shape, dtype=dtypes.float32) - ]) + @def_function.function( + input_signature=[ + tensor_spec.TensorSpec(shape=input_shape, dtype=dtypes.float32) + ] 
+ ) def conv3d(self, input_tensor: core.Tensor) -> Mapping[str, core.Tensor]: """Performs a 3D convolution operation. @@ -1630,7 +3248,8 @@ def conv3d(self, input_tensor: core.Tensor) -> Mapping[str, core.Tensor]: strides=[1, 1, 2, 1, 1], dilations=[1, 1, 1, 1, 1], padding=padding, - data_format='NDHWC') + data_format='NDHWC', + ) if has_bias: out = nn_ops.bias_add(out, self.bias) if activation_fn is not None: @@ -1639,213 +3258,574 @@ def conv3d(self, input_tensor: core.Tensor) -> Mapping[str, core.Tensor]: np.random.seed(1234) model = ConvModel() - input_saved_model_path = self.create_tempdir('input').full_path - saved_model_save.save(model, input_saved_model_path) + saved_model_save.save(model, self._input_saved_model_path) repr_ds = [] for _ in range(500): - repr_ds.append({ - 'input_tensor': - ops.convert_to_tensor( - np.random.uniform(low=-0.1, high=0.2, - size=(1, 3, 4, 3, 3)).astype('f4')), - }) + repr_ds.append( + { + 'input_tensor': ops.convert_to_tensor( + np.random.uniform( + low=-0.1, high=0.2, size=(1, 3, 4, 3, 3) + ).astype('f4') + ), + } + ) signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY - tags = [tag_constants.SERVING] + tags = {tag_constants.SERVING} # Check the converted model with TF opset as the baseline. - output_directory = self.create_tempdir().full_path quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( - experimental_method=_ExperimentalMethod.STATIC_RANGE), - op_set=quant_opts_pb2.TF) + experimental_method=_ExperimentalMethod.STATIC_RANGE + ), + op_set=quant_opts_pb2.TF, + ) converted_model = quantize_model.quantize( - input_saved_model_path, [signature_key], + self._input_saved_model_path, + [signature_key], tags, - output_directory, + self._output_saved_model_path, quantization_options, - representative_dataset=repr_ds) + representative_dataset=repr_ds, + ) self.assertIsNotNone(converted_model) - self.assertCountEqual(converted_model.signatures._signatures.keys(), - {signature_key}) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {signature_key} + ) input_data = np.random.uniform( - low=-0.1, high=0.2, size=(1, 3, 4, 3, 3)).astype('f4') + low=-0.1, high=0.2, size=(1, 3, 4, 3, 3) + ).astype('f4') expected_outputs = model.conv3d(input_data) got_outputs = converted_model.signatures[signature_key]( - input_tensor=ops.convert_to_tensor(input_data)) + input_tensor=ops.convert_to_tensor(input_data) + ) self.assertAllClose(expected_outputs, got_outputs, atol=0.00494) - output_loader = saved_model_loader.SavedModelLoader(output_directory) - output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags) - self.assertTrue( - self._contains_quantized_function_call(output_meta_graphdef)) + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def + self.assertTrue(self._contains_quantized_function_call(output_graphdef)) # Check the converted model in the target opset. 
quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( - experimental_method=_ExperimentalMethod.STATIC_RANGE), - op_set=target_opset) + experimental_method=_ExperimentalMethod.STATIC_RANGE + ), + op_set=target_opset, + ) - output_directory = self.create_tempdir().full_path converted_model = quantize_model.quantize( - input_saved_model_path, [signature_key], + self._input_saved_model_path, + [signature_key], tags, - output_directory, + self._output_saved_model_path_2, quantization_options, - representative_dataset=repr_ds) + representative_dataset=repr_ds, + ) self.assertIsNotNone(converted_model) - self.assertCountEqual(converted_model.signatures._signatures.keys(), - {signature_key}) - loader = saved_model_loader.SavedModelLoader(output_directory) - meta_graphdef = loader.get_meta_graph_def_from_tags(tags) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {signature_key} + ) + loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path_2 + ) + graphdef = loader.get_meta_graph_def_from_tags(tags).graph_def if target_opset == quant_opts_pb2.XLA: - self.assertTrue(self._contains_op(meta_graphdef, 'XlaConvV2')) + self.assertTrue(self._contains_op(graphdef, 'XlaConvV2')) new_outputs = converted_model.signatures[signature_key]( - input_tensor=ops.convert_to_tensor(input_data)) + input_tensor=ops.convert_to_tensor(input_data) + ) # The quantized model in XLA opset is expected to have similar fidelity # compared to the quantized model in TF opset. self.assertAllClose(new_outputs, got_outputs, atol=0.00306) self.assertAllClose(new_outputs, expected_outputs, atol=0.00494) + # Tests the case of having a signature key of `main` because it is a + # special name in the TF quantizer's MLIR pipeline that should be treated + # with care. + @test_util.run_in_graph_and_eager_modes + def test_ptq_model_with_signature_key_main(self): + signature_key = 'main' + tags = {tag_constants.SERVING} + + input_placeholder = self._create_and_save_tf1_conv_model( + self._input_saved_model_path, + signature_key, + tags, + input_key='x', + output_key='output', + use_variable=True, + ) + + signature_keys = [signature_key] + + quantization_options = quant_opts_pb2.QuantizationOptions( + quantization_method=quant_opts_pb2.QuantizationMethod( + experimental_method=_ExperimentalMethod.STATIC_RANGE + ) + ) + + data_gen = self._create_data_generator( + input_key='x', shape=input_placeholder.shape + ) + + converted_model = quantize_model.quantize( + self._input_saved_model_path, + signature_keys, + tags, + self._output_saved_model_path, + quantization_options, + representative_dataset=data_gen, + ) + + self.assertIsNotNone(converted_model) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), signature_keys + ) + + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def + self.assertTrue(self._contains_quantized_function_call(output_graphdef)) + + # Makes sure that the original function identified by the signature key + # `main` is renamed to `main_0` (see `InsertMainFunctionPass` for details). + self.assertTrue( + any( + map( + lambda func: func.signature.name == 'main_0', + output_graphdef.library.function, + ) + ) + ) + class DynamicRangeQuantizationTest(quantize_model_test_base.QuantizedModelTest): """Test cases for dynamic range quantization. 
- Run all tests cases in both the graph mode (default in TF1) and the eager mode - (default in TF2) to ensure support for when TF2 is disabled. + Tries to run all tests cases in both the graph mode (default in TF1) and the + eager mode (default in TF2) to ensure support for when TF2 is disabled. """ + @parameterized.named_parameters( + ('to_tf_per_tensor', quant_opts_pb2.TF, False), + ('to_xla_per_tensor', quant_opts_pb2.XLA, False), + ( + 'to_uniform_quantized_per_tensor', + quant_opts_pb2.UNIFORM_QUANTIZED, + False, + ), + ( + 'to_uniform_quantized_per_channel', + quant_opts_pb2.UNIFORM_QUANTIZED, + True, + ), + ) @test_util.run_in_graph_and_eager_modes - def test_matmul_model(self): - model = self._create_matmul_model() - input_saved_model_path = self.create_tempdir('input').full_path - saved_model_save.save(model, input_saved_model_path) + def test_matmul_model( + self, + target_opset: quant_opts_pb2.OpSet, + enable_per_channel_quantization: bool, + ): + self._create_matmul_model( + input_shape=(1, 1024), + weight_shape=(1024, 3), + saved_model_path=self._input_saved_model_path, + ) - tags = [tag_constants.SERVING] - output_directory = self.create_tempdir().full_path + tags = {tag_constants.SERVING} quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( - experimental_method=_ExperimentalMethod.DYNAMIC_RANGE), - op_set=quant_opts_pb2.OpSet.UNIFORM_QUANTIZED) + experimental_method=_ExperimentalMethod.DYNAMIC_RANGE + ), + op_set=target_opset, + enable_per_channel_quantization=enable_per_channel_quantization, + ) - converted_model = quantize_model.quantize(input_saved_model_path, - ['serving_default'], tags, - output_directory, - quantization_options) + converted_model = quantize_model.quantize( + self._input_saved_model_path, + ['serving_default'], + tags, + self._output_saved_model_path, + quantization_options, + ) self.assertIsNotNone(converted_model) - self.assertCountEqual(converted_model.signatures._signatures.keys(), - {'serving_default'}) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {'serving_default'} + ) - output_loader = saved_model_loader.SavedModelLoader(output_directory) - output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags) - self.assertTrue( - self._contains_quantized_function_call(output_meta_graphdef)) - self.assertTrue( - self._contains_op(output_meta_graphdef, 'UniformQuantizedDotHybrid')) - self.assertFalse(self._contains_op(output_meta_graphdef, 'MatMul')) + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def + + if target_opset == quant_opts_pb2.UNIFORM_QUANTIZED: + self.assertTrue( + self._contains_op(output_graphdef, 'UniformQuantizedDotHybrid') + ) + self.assertFalse(self._contains_op(output_graphdef, 'MatMul')) + if enable_per_channel_quantization: + quantized_axis_attr = attr_value_pb2.AttrValue(i=-1) + self.assertTrue( + self._contains_op( + output_graphdef, + 'UniformQuantizedDotHybrid', + 'rhs_quantization_axis', + quantized_axis_attr, + ) + ) + elif target_opset == quant_opts_pb2.XLA: + self.assertTrue(self._contains_op(output_graphdef, 'XlaDotV2')) + self.assertFalse(self._contains_op(output_graphdef, 'MatMul')) + else: + self.assertTrue(self._contains_quantized_function_call(output_graphdef)) + self.assertTrue(self._contains_op(output_graphdef, 'MatMul')) + @parameterized.named_parameters( + ('to_tf_per_tensor', quant_opts_pb2.TF, 
False), + ('to_xla_per_tensor', quant_opts_pb2.XLA, False), + ( + 'to_uniform_quantized_per_tensor', + quant_opts_pb2.UNIFORM_QUANTIZED, + False, + ), + ( + 'to_uniform_quantized_per_channel', + quant_opts_pb2.UNIFORM_QUANTIZED, + True, + ), + ) @test_util.run_in_graph_and_eager_modes - def test_conv_model(self): - model = self._create_conv2d_model() - input_saved_model_path = self.create_tempdir('input').full_path - saved_model_save.save(model, input_saved_model_path) + def test_conv_model( + self, + target_opset: quant_opts_pb2.OpSet, + enable_per_channel_quantization: bool, + ): + filter_shape = (2, 3, 512, 2) + + model = self._create_conv2d_model( + input_shape=(1, 3, 4, 512), + filter_shape=filter_shape, + has_bias=True, + has_batch_norm=True, + activation_fn=nn_ops.relu6, + ) + + saved_model_save.save(model, self._input_saved_model_path) - tags = [tag_constants.SERVING] - output_directory = self.create_tempdir().full_path + tags = {tag_constants.SERVING} quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( - experimental_method=_ExperimentalMethod.DYNAMIC_RANGE), - op_set=quant_opts_pb2.OpSet.UNIFORM_QUANTIZED) + experimental_method=_ExperimentalMethod.DYNAMIC_RANGE + ), + op_set=target_opset, + enable_per_channel_quantization=enable_per_channel_quantization, + ) - converted_model = quantize_model.quantize(input_saved_model_path, - ['serving_default'], tags, - output_directory, - quantization_options) + converted_model = quantize_model.quantize( + self._input_saved_model_path, + ['serving_default'], + tags, + self._output_saved_model_path, + quantization_options, + ) self.assertIsNotNone(converted_model) - self.assertCountEqual(converted_model.signatures._signatures.keys(), - {'serving_default'}) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {'serving_default'} + ) + + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def + + if enable_per_channel_quantization: + quantized_axis = 3 + quantized_axis_attr = attr_value_pb2.AttrValue(i=quantized_axis) + quantized_dim_size_attr = attr_value_pb2.AttrValue( + list=attr_value_pb2.AttrValue.ListValue( + shape=[ + tensor_shape_pb2.TensorShapeProto( + dim=[ + tensor_shape_pb2.TensorShapeProto.Dim( + size=filter_shape[quantized_axis] + ) + ] + ) + ] + ) + ) + + if target_opset == quant_opts_pb2.UNIFORM_QUANTIZED: + self.assertTrue( + self._contains_op( + output_graphdef, 'UniformQuantizedConvolutionHybrid' + ) + ) + self.assertFalse(self._contains_op(output_graphdef, 'Conv2D')) + if enable_per_channel_quantization: + self.assertTrue( + self._contains_op( + output_graphdef, + 'UniformQuantizedConvolutionHybrid', + 'rhs_quantization_axis', + quantized_axis_attr, + ) + ) + self.assertTrue( + self._contains_op( + output_graphdef, + 'Const', + '_output_shapes', + quantized_dim_size_attr, + ) + ) + elif target_opset == quant_opts_pb2.XLA: + self.assertTrue(self._contains_op(output_graphdef, 'XlaConvV2')) + self.assertFalse(self._contains_op(output_graphdef, 'Conv2D')) + else: + self.assertTrue(self._contains_quantized_function_call(output_graphdef)) + self.assertTrue(self._contains_op(output_graphdef, 'Conv2D')) + + @parameterized.named_parameters( + ('to_tf_per_tensor', quant_opts_pb2.TF, False), + ('to_xla_per_tensor', quant_opts_pb2.XLA, False), + ( + 'to_uniform_quantized_per_tensor', + quant_opts_pb2.UNIFORM_QUANTIZED, + False, + ), + ( + 
'to_uniform_quantized_per_channel', + quant_opts_pb2.UNIFORM_QUANTIZED, + True, + ), + ) + @test_util.run_in_graph_and_eager_modes + def test_depthwise_conv_model( + self, + target_opset: quant_opts_pb2.OpSet, + enable_per_channel_quantization: bool, + ): + filter_shape = (2, 3, 1024, 2) + strides = (1, 2, 2, 1) + + model = self._create_depthwise_conv2d_model( + input_shape=(1, 3, 4, 1024), filter_shape=filter_shape, strides=strides + ) + + saved_model_save.save(model, self._input_saved_model_path) - output_loader = saved_model_loader.SavedModelLoader(output_directory) - output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags) - # Currently conv is not supported. - self.assertFalse( - self._contains_quantized_function_call(output_meta_graphdef)) + tags = [tag_constants.SERVING] + + quantization_options = quant_opts_pb2.QuantizationOptions( + quantization_method=quant_opts_pb2.QuantizationMethod( + experimental_method=_ExperimentalMethod.DYNAMIC_RANGE + ), + op_set=target_opset, + enable_per_channel_quantization=enable_per_channel_quantization, + ) + + converted_model = quantize_model.quantize( + self._input_saved_model_path, + ['serving_default'], + tags, + self._output_saved_model_path, + quantization_options, + ) + + self.assertIsNotNone(converted_model) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {'serving_default'} + ) + + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def + + # Uniform Quantized op takes only the first and the second values for + # strides. + strides_to_check = ( + (strides[1], strides[2]) + if target_opset == quant_opts_pb2.UNIFORM_QUANTIZED + else strides + ) + strides_attr = attr_value_pb2.AttrValue( + list=attr_value_pb2.AttrValue.ListValue(i=strides_to_check) + ) + + if enable_per_channel_quantization: + quantized_axis_attr = attr_value_pb2.AttrValue(i=3) + quantized_dim_size_attr = attr_value_pb2.AttrValue( + list=attr_value_pb2.AttrValue.ListValue( + shape=[ + tensor_shape_pb2.TensorShapeProto( + dim=[ + tensor_shape_pb2.TensorShapeProto.Dim( + size=filter_shape[2] * filter_shape[3] + ) + ] + ) + ] + ) + ) + + if target_opset == quant_opts_pb2.UNIFORM_QUANTIZED: + self.assertTrue( + self._contains_op( + output_graphdef, + 'UniformQuantizedConvolutionHybrid', + 'window_strides', + strides_attr, + ) + ) + self.assertFalse( + self._contains_op(output_graphdef, 'DepthwiseConv2dNative') + ) + if enable_per_channel_quantization: + self.assertTrue( + self._contains_op( + output_graphdef, + 'UniformQuantizedConvolutionHybrid', + 'rhs_quantization_axis', + quantized_axis_attr, + ) + ) + self.assertTrue( + self._contains_op( + output_graphdef, + 'Const', + '_output_shapes', + quantized_dim_size_attr, + ) + ) + elif target_opset == quant_opts_pb2.XLA: + self.assertTrue(self._contains_op(output_graphdef, 'XlaConvV2')) + self.assertFalse( + self._contains_op(output_graphdef, 'DepthwiseConv2dNative') + ) + else: + self.assertTrue(self._contains_quantized_function_call(output_graphdef)) + self.assertTrue( + self._contains_op( + output_graphdef, 'DepthwiseConv2dNative', 'strides', strides_attr + ) + ) @parameterized.named_parameters( - ('use_constant', False), - ('use_variable', True), + ('to_tf_use_constant', quant_opts_pb2.TF, False), + ('to_xla_use_constant', quant_opts_pb2.XLA, False), + ( + 'to_uniform_quantized_use_constant', + quant_opts_pb2.UNIFORM_QUANTIZED, + False, + ), + ('to_tf_use_variable', 
quant_opts_pb2.TF, True), + ('to_xla_use_variable', quant_opts_pb2.XLA, True), + ( + 'to_uniform_quantized_use_variable', + quant_opts_pb2.UNIFORM_QUANTIZED, + True, + ), ) @test_util.run_v2_only - def test_gather_model(self, use_variable): + def test_gather_model( + self, target_opset: quant_opts_pb2.OpSet, use_variable: bool + ): model = self._create_gather_model(use_variable) - input_saved_model_path = self.create_tempdir('input').full_path - saved_model_save.save(model, input_saved_model_path) + saved_model_save.save(model, self._input_saved_model_path) - tags = [tag_constants.SERVING] - output_directory = self.create_tempdir().full_path + tags = {tag_constants.SERVING} quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( - experimental_method=_ExperimentalMethod.DYNAMIC_RANGE), - op_set=quant_opts_pb2.OpSet.UNIFORM_QUANTIZED) + experimental_method=_ExperimentalMethod.DYNAMIC_RANGE + ), + op_set=target_opset, + ) - converted_model = quantize_model.quantize(input_saved_model_path, - ['serving_default'], tags, - output_directory, - quantization_options) + converted_model = quantize_model.quantize( + self._input_saved_model_path, + ['serving_default'], + tags, + self._output_saved_model_path, + quantization_options, + ) self.assertIsNotNone(converted_model) - self.assertCountEqual(converted_model.signatures._signatures.keys(), - {'serving_default'}) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {'serving_default'} + ) - output_loader = saved_model_loader.SavedModelLoader(output_directory) - output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags) - # Currently gather is not supported. - self.assertFalse( - self._contains_quantized_function_call(output_meta_graphdef)) + # Only XLA opset does not apply weight-only quantization + if target_opset == quant_opts_pb2.XLA: + threshold = 0.25 if use_variable else 0.3 + self.assertSizeRatioLessThan( + self._output_saved_model_path, self._input_saved_model_path, threshold + ) + else: + # Double from the XLA threshold + threshold = 0.4 if use_variable else 0.7 + self.assertSizeRatioGreaterThan( + self._output_saved_model_path, self._input_saved_model_path, threshold + ) @test_util.run_in_graph_and_eager_modes def test_conv_model_with_wrong_tags_raises_error(self): - input_saved_model_path = self.create_tempdir('input').full_path signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY save_tags = {tag_constants.TRAINING, tag_constants.GPU} input_placeholder = self._create_and_save_tf1_conv_model( - input_saved_model_path, + self._input_saved_model_path, signature_key, save_tags, input_key='input', output_key='output', - use_variable=True) + use_variable=True, + ) signature_keys = [signature_key] - output_directory = self.create_tempdir().full_path quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( - experimental_method=_ExperimentalMethod.DYNAMIC_RANGE), - op_set=quant_opts_pb2.OpSet.UNIFORM_QUANTIZED) + experimental_method=_ExperimentalMethod.DYNAMIC_RANGE + ), + op_set=quant_opts_pb2.OpSet.UNIFORM_QUANTIZED, + ) # Try to use a different set of tags to quantize. tags = {tag_constants.SERVING} data_gen = self._create_data_generator( - input_key='input', shape=input_placeholder.shape) - with self.assertRaisesRegex(ValueError, 'Failed to import SavedModel'): + input_key='input', shape=input_placeholder.shape + ) + + # StatusNotOk error. 
`Exception` is used here because importing + # `StatusNotOk` may break the open-sourced version of TensorFlow. + with self.assertRaisesRegex( + Exception, 'Failed to import SavedModel' + ) as raises: quantize_model.quantize( - input_saved_model_path, + self._input_saved_model_path, signature_keys, tags, - output_directory, + self._output_saved_model_path, quantization_options, - representative_dataset=data_gen) + representative_dataset=data_gen, + ) + + self.assertEqual(raises.exception.__class__.__name__, 'StatusNotOk') @parameterized.named_parameters( ('quantize', True, 0), @@ -1853,136 +3833,482 @@ def test_conv_model_with_wrong_tags_raises_error(self): ) @test_util.run_in_graph_and_eager_modes def test_minimum_elements_for_weights(self, quantize, num_elements): - model = self._create_matmul_model() - input_saved_model_path = self.create_tempdir('input').full_path - saved_model_save.save(model, input_saved_model_path) + self._create_matmul_model( + input_shape=(1, 1024), + weight_shape=(1024, 3), + saved_model_path=self._input_saved_model_path, + ) - tags = [tag_constants.SERVING] - output_directory = self.create_tempdir().full_path + tags = {tag_constants.SERVING} quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( - experimental_method=_ExperimentalMethod.DYNAMIC_RANGE), - op_set=quant_opts_pb2.OpSet.UNIFORM_QUANTIZED) + experimental_method=_ExperimentalMethod.DYNAMIC_RANGE + ), + op_set=quant_opts_pb2.OpSet.UNIFORM_QUANTIZED, + ) quantization_options.min_num_elements_for_weights = num_elements - converted_model = quantize_model.quantize(input_saved_model_path, - ['serving_default'], tags, - output_directory, - quantization_options) + converted_model = quantize_model.quantize( + self._input_saved_model_path, + ['serving_default'], + tags, + self._output_saved_model_path, + quantization_options, + ) self.assertIsNotNone(converted_model) - self.assertCountEqual(converted_model.signatures._signatures.keys(), - {'serving_default'}) - - output_loader = saved_model_loader.SavedModelLoader(output_directory) - output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {'serving_default'} + ) + + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def + type_attr = attr_value_pb2.AttrValue(type=types_pb2.DT_QINT8) if quantize: self.assertTrue( - self._contains_quantized_function_call(output_meta_graphdef)) + self._contains_op(output_graphdef, 'Const', 'dtype', type_attr) + ) else: self.assertFalse( - self._contains_quantized_function_call(output_meta_graphdef)) + self._contains_op(output_graphdef, 'Const', 'dtype', type_attr) + ) @parameterized.named_parameters( - ('use_constant', False), - ('use_variable', True), + ('to_tf_use_constant', quant_opts_pb2.TF, False), + ('to_xla_use_constant', quant_opts_pb2.XLA, False), + ( + 'to_uniform_quantized_use_constant', + quant_opts_pb2.UNIFORM_QUANTIZED, + False, + ), + ('to_tf_use_variable', quant_opts_pb2.TF, True), + ('to_xla_use_variable', quant_opts_pb2.XLA, True), + ( + 'to_uniform_quantized_use_variable', + quant_opts_pb2.UNIFORM_QUANTIZED, + True, + ), ) @test_util.run_in_graph_and_eager_modes - def test_gather_model_tf1(self, use_variable): - input_saved_model_path = self.create_tempdir('input').full_path + def test_gather_model_tf1( + self, target_opset: quant_opts_pb2.OpSet, 
use_variable: bool + ): signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY tags = {tag_constants.SERVING} _ = self._create_and_save_tf1_gather_model( - input_saved_model_path, + self._input_saved_model_path, signature_key, tags, input_key='x', output_key='output', - use_variable=use_variable) + use_variable=use_variable, + ) signature_keys = [signature_key] - output_directory = self.create_tempdir().full_path quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( - experimental_method=_ExperimentalMethod.DYNAMIC_RANGE), - op_set=quant_opts_pb2.OpSet.UNIFORM_QUANTIZED) + experimental_method=_ExperimentalMethod.DYNAMIC_RANGE + ), + op_set=target_opset, + ) - converted_model = quantize_model.quantize(input_saved_model_path, - signature_keys, tags, - output_directory, - quantization_options) + converted_model = quantize_model.quantize( + self._input_saved_model_path, + signature_keys, + tags, + self._output_saved_model_path, + quantization_options, + ) self.assertIsNotNone(converted_model) - self.assertCountEqual(converted_model.signatures._signatures.keys(), - signature_keys) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), signature_keys + ) - output_loader = saved_model_loader.SavedModelLoader(output_directory) - output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags) - # Quantization is not currently supported for gather. - self.assertFalse( - self._contains_quantized_function_call(output_meta_graphdef)) + # Only XLA opset does not apply weight-only quantization + if target_opset == quant_opts_pb2.XLA: + threshold = 0.17 if use_variable else 0.4 + self.assertSizeRatioLessThan( + self._output_saved_model_path, self._input_saved_model_path, threshold + ) + else: + # Double from the XLA threshold + threshold = 0.3 if use_variable else 0.8 + self.assertSizeRatioGreaterThan( + self._output_saved_model_path, self._input_saved_model_path, threshold + ) @test_util.run_in_graph_and_eager_modes def test_non_empty_directory_raises_file_exists_error(self): - model = self._create_matmul_model() + self._create_matmul_model( + input_shape=(1, 1024), + weight_shape=(1024, 3), + saved_model_path=self._input_saved_model_path, + ) + tags = {tag_constants.SERVING} - input_saved_model_path = self.create_tempdir('input').full_path - saved_model_save.save(model, input_saved_model_path) + # Create a file inside the output directory. + file_io.write_string_to_file( + filename=os.path.join(self._output_saved_model_path, 'dummy_file.txt'), + file_content='Test content', + ) - tags = [tag_constants.SERVING] + quantization_options = quant_opts_pb2.QuantizationOptions( + quantization_method=quant_opts_pb2.QuantizationMethod( + experimental_method=_ExperimentalMethod.DYNAMIC_RANGE + ) + ) + + with self.assertRaisesRegex( + FileExistsError, 'Output directory already exists' + ): + quantize_model.quantize( + self._input_saved_model_path, + ['serving_default'], + tags, + self._output_saved_model_path, + quantization_options, + ) + + @test_util.run_in_graph_and_eager_modes + def test_non_empty_directory_overwritten(self): + self._create_matmul_model( + input_shape=(1, 1024), + weight_shape=(1024, 3), + saved_model_path=self._input_saved_model_path, + ) + tags = {tag_constants.SERVING} # Create a file inside the output directory. 
- output_directory = self.create_tempdir().full_path file_io.write_string_to_file( - filename=os.path.join(output_directory, 'dummy_file.txt'), - file_content='Test content') + filename=os.path.join(self._output_saved_model_path, 'dummy_file.txt'), + file_content='Test content', + ) + + quantization_options = quant_opts_pb2.QuantizationOptions( + quantization_method=quant_opts_pb2.QuantizationMethod( + experimental_method=_ExperimentalMethod.DYNAMIC_RANGE + ) + ) + + converted_model = quantize_model.quantize( + self._input_saved_model_path, + ['serving_default'], + tags, + self._output_saved_model_path, + quantization_options, + overwrite_output_directory=True, + ) + + self.assertIsNotNone(converted_model) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {'serving_default'} + ) + + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def + self.assertTrue(self._contains_quantized_function_call(output_graphdef)) + + # TODO(b/244276332): Allow table initialization in TF2 eager mode. + @test_util.deprecated_graph_mode_only + def test_table_initialized_when_model_has_table_tf1(self): + tags = {tag_constants.SERVING} + signature_def_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + + # Create and save a simple model that involves a hash table. + inputs, outputs = self._create_and_save_vocab_table_lookup_model_tf1( + self._input_saved_model_path, tags, signature_def_key + ) + + # Make sure that the desired input key and output key is present. + self.assertIn('input_vocabs', inputs.keys()) + self.assertIn('lookup', outputs.keys()) quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( - experimental_method=_ExperimentalMethod.DYNAMIC_RANGE)) + experimental_method=_ExperimentalMethod.DYNAMIC_RANGE + ) + ) + + signature_def_keys = [signature_def_key] + + quantize_model.quantize( + self._input_saved_model_path, + signature_def_keys, + tags, + self._output_saved_model_path, + quantization_options, + ) + + # Tests table lookup to make sure the table has been initialized + # successfully. + with session.Session(graph=ops.Graph()) as sess: + output_meta_graph_def = saved_model_loader.load( + sess, tags=tags, export_dir=self._output_saved_model_path + ) - with self.assertRaisesRegex(FileExistsError, - 'Output directory already exists'): - quantize_model.quantize(input_saved_model_path, ['serving_default'], tags, - output_directory, quantization_options) + self.assertCountEqual( + output_meta_graph_def.signature_def.keys(), signature_def_keys + ) + signature_def = output_meta_graph_def.signature_def[signature_def_key] + + input_tensor_name = signature_def.inputs['input_vocabs'].name + input_tensor = sess.graph.get_tensor_by_name(input_tensor_name) + + lookup_tensor_name = signature_def.outputs['lookup'].name + lookup_tensor = sess.graph.get_tensor_by_name(lookup_tensor_name) + + lookup_val = sess.run( + lookup_tensor, + feed_dict={ + input_tensor: np.array([b'model', b'quantization', b'hello']) + }, + ) + + self.assertAllClose(lookup_val, [1.0, 2.0, 0.0]) + + +class WeightOnlyQuantizationTest(quantize_model_test_base.QuantizedModelTest): + """Test cases for weight-only quantization. + + Run all tests cases in both the graph mode (default in TF1) and the eager mode + (default in TF2) to ensure support for when TF2 is disabled. 
+ """ + + @parameterized.named_parameters( + ('to_tf_per_tensor', quant_opts_pb2.TF, False), + ('to_xla_per_tensor', quant_opts_pb2.XLA, False), + ) @test_util.run_in_graph_and_eager_modes - def test_non_empty_directory_overwritten(self): - model = self._create_matmul_model() + def test_matmul_model( + self, + target_opset: quant_opts_pb2.OpSet, + enable_per_channel_quantization: bool, + ): + input_shape = (1, 512) + + self._create_matmul_model( + input_shape=input_shape, + weight_shape=(512, 2), + saved_model_path=self._input_saved_model_path, + ) - input_saved_model_path = self.create_tempdir('input').full_path - saved_model_save.save(model, input_saved_model_path) + tags = {tag_constants.SERVING} + quantization_options = quant_opts_pb2.QuantizationOptions( + quantization_method=quant_opts_pb2.QuantizationMethod( + experimental_method=_ExperimentalMethod.WEIGHT_ONLY + ), + op_set=target_opset, + enable_per_channel_quantization=enable_per_channel_quantization, + ) - tags = [tag_constants.SERVING] + converted_model = quantize_model.quantize( + self._input_saved_model_path, + ['serving_default'], + tags, + self._output_saved_model_path, + quantization_options, + ) + self.assertIsNotNone(converted_model) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {'serving_default'} + ) + + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def + + self.assertTrue(self._contains_op(output_graphdef, 'MatMul')) + # Due to other meta data, the compression is not exactly 1/4. + threshold = 0.9 if quant_opts_pb2.TF else 0.3 + self.assertSizeRatioLessThan( + self._output_saved_model_path, self._input_saved_model_path, threshold + ) - # Create a file inside the output directory. 
- output_directory = self.create_tempdir().full_path - file_io.write_string_to_file( - filename=os.path.join(output_directory, 'dummy_file.txt'), - file_content='Test content') + @parameterized.named_parameters( + ('to_tf_per_tensor', quant_opts_pb2.TF, False), + ('to_xla_per_tensor', quant_opts_pb2.XLA, False), + ) + @test_util.run_in_graph_and_eager_modes + def test_conv_model( + self, + target_opset: quant_opts_pb2.OpSet, + enable_per_channel_quantization: bool, + ): + model = self._create_conv2d_model( + input_shape=(1, 3, 4, 512), + filter_shape=(2, 3, 512, 2), + has_bias=False, + has_batch_norm=False, + activation_fn=nn_ops.relu6, + ) + saved_model_save.save(model, self._input_saved_model_path) + + tags = {tag_constants.SERVING} quantization_options = quant_opts_pb2.QuantizationOptions( quantization_method=quant_opts_pb2.QuantizationMethod( - experimental_method=_ExperimentalMethod.DYNAMIC_RANGE)) + experimental_method=_ExperimentalMethod.WEIGHT_ONLY + ), + op_set=target_opset, + enable_per_channel_quantization=enable_per_channel_quantization, + ) converted_model = quantize_model.quantize( - input_saved_model_path, ['serving_default'], + self._input_saved_model_path, + ['serving_default'], tags, - output_directory, + self._output_saved_model_path, quantization_options, - overwrite_output_directory=True) + ) self.assertIsNotNone(converted_model) - self.assertCountEqual(converted_model.signatures._signatures.keys(), - {'serving_default'}) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {'serving_default'} + ) - output_loader = saved_model_loader.SavedModelLoader(output_directory) - output_meta_graphdef = output_loader.get_meta_graph_def_from_tags(tags) - self.assertTrue( - self._contains_quantized_function_call(output_meta_graphdef)) + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def + + self.assertTrue(self._contains_op(output_graphdef, 'Conv2D')) + # Due to other meta data, the compression is not exactly 1/4. 
+ + threshold = 0.9 if quant_opts_pb2.TF else 0.3 + self.assertSizeRatioLessThan( + self._output_saved_model_path, self._input_saved_model_path, threshold + ) + + @parameterized.named_parameters( + ('to_tf_per_tensor', quant_opts_pb2.TF, False), + ('to_xla_per_tensor', quant_opts_pb2.XLA, False), + ) + @test_util.run_in_graph_and_eager_modes + def test_depthwise_conv_model( + self, + target_opset: quant_opts_pb2.OpSet, + enable_per_channel_quantization: bool, + ): + filter_shape = (2, 3, 512, 2) + strides = (1, 2, 2, 1) + + model = self._create_depthwise_conv2d_model( + input_shape=(1, 3, 4, 512), filter_shape=filter_shape, strides=strides + ) + + saved_model_save.save(model, self._input_saved_model_path) + + tags = {tag_constants.SERVING} + + quantization_options = quant_opts_pb2.QuantizationOptions( + quantization_method=quant_opts_pb2.QuantizationMethod( + experimental_method=_ExperimentalMethod.WEIGHT_ONLY + ), + op_set=target_opset, + enable_per_channel_quantization=enable_per_channel_quantization, + ) + + converted_model = quantize_model.quantize( + self._input_saved_model_path, + ['serving_default'], + tags, + self._output_saved_model_path, + quantization_options, + ) + + self.assertIsNotNone(converted_model) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {'serving_default'} + ) + + output_loader = saved_model_loader.SavedModelLoader( + self._output_saved_model_path + ) + output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def + + self.assertTrue(self._contains_op(output_graphdef, 'DepthwiseConv2dNative')) + # Due to other meta data, the compression is not exactly 1/4. + threshold = 0.9 if quant_opts_pb2.TF else 0.3 + self.assertSizeRatioLessThan( + self._output_saved_model_path, self._input_saved_model_path, threshold + ) + + @parameterized.named_parameters( + ('to_tf_use_constant', quant_opts_pb2.TF, False), + ('to_xla_use_constant', quant_opts_pb2.XLA, False), + ( + 'to_uniform_quantized_use_constant', + quant_opts_pb2.UNIFORM_QUANTIZED, + False, + ), + ('to_tf_use_variable', quant_opts_pb2.TF, True), + ('to_xla_use_variable', quant_opts_pb2.XLA, True), + ( + 'to_uniform_quantized_use_variable', + quant_opts_pb2.UNIFORM_QUANTIZED, + True, + ), + ) + @test_util.run_v2_only + def test_gather_model( + self, target_opset: quant_opts_pb2.OpSet, use_variable: bool + ): + model = self._create_gather_model(use_variable) + input_saved_model_path = self.create_tempdir('input').full_path + saved_model_save.save(model, input_saved_model_path) + + tags = {tag_constants.SERVING} + output_directory = self.create_tempdir().full_path + + quantization_options = quant_opts_pb2.QuantizationOptions( + quantization_method=quant_opts_pb2.QuantizationMethod( + experimental_method=_ExperimentalMethod.WEIGHT_ONLY + ), + op_set=target_opset, + ) + + if target_opset == quant_opts_pb2.UNIFORM_QUANTIZED: + # Uniform quantized opset is not supported for weight-only + with self.assertRaisesRegex( + ValueError, 'Uniform quantized opset does not support weight-only.' 
+ ): + converted_model = quantize_model.quantize( + input_saved_model_path, + ['serving_default'], + tags, + output_directory, + quantization_options, + ) + return + + else: + converted_model = quantize_model.quantize( + input_saved_model_path, + ['serving_default'], + tags, + output_directory, + quantization_options, + ) + + self.assertIsNotNone(converted_model) + self.assertCountEqual( + converted_model.signatures._signatures.keys(), {'serving_default'} + ) + + threshold = 0.3 if quant_opts_pb2.XLA else 0.9 + self.assertSizeRatioLessThan( + self._output_saved_model_path, self._input_saved_model_path, threshold + ) if __name__ == '__main__': diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test_base.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test_base.py index 25ec30f3de5..8794ca87d1f 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test_base.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test_base.py @@ -14,7 +14,7 @@ # ============================================================================== """Base test class for quantize_model Tests.""" import os -from typing import Collection, Iterable, Mapping, Sequence, Tuple, Optional +from typing import Collection, Iterable, Mapping, Sequence, Tuple, Optional, Union, List from absl.testing import parameterized import numpy as np @@ -22,8 +22,8 @@ from tensorflow.compiler.mlir.quantization.tensorflow.python import representative_dataset as repr_dataset from tensorflow.core.framework import function_pb2 +from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import node_def_pb2 -from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.python.client import session from tensorflow.python.eager import def_function from tensorflow.python.framework import dtypes @@ -42,15 +42,81 @@ from tensorflow.python.ops.ragged import ragged_string_ops from tensorflow.python.platform import test from tensorflow.python.saved_model import builder +from tensorflow.python.saved_model import save as saved_model_save from tensorflow.python.saved_model import signature_def_utils_impl from tensorflow.python.trackable import asset from tensorflow.python.trackable import autotrackable from tensorflow.python.types import core +# Type aliases for supported attribute types. +_AttrValType = Union[List[int], bool, str, None] + class QuantizedModelTest(test.TestCase, parameterized.TestCase): """Base test class for TF-quant tests.""" + def setUp(self) -> None: + super().setUp() + + # Many test cases for quantization involve creating and saving the input + # model and saving the output quantized model. These two member + # attributes can be used to specify the paths for such models, + # respectively. These paths will be cleaned up after each test case. + self._input_saved_model_path = self.create_tempdir('input').full_path + self._output_saved_model_path = self.create_tempdir('output').full_path + # Extra output path occasionally used for comparing two different + # quantized models. + self._output_saved_model_path_2 = self.create_tempdir('output2').full_path + + def _get_dir_size(self, path: str = '.'): + """Get the total size of files and sub-directories under the path. + + Args: + path: Path of a directory or a file to calculate the total size. + + Returns: + Total size of the directory or a file. 
+ """ + total = 0 + for root, _, files in os.walk(path): + for filename in files: + total += os.path.getsize(os.path.join(root, filename)) + return total + + def assertSizeRatioGreaterThan( + self, path_a: str, path_b: str, threshold: float + ): + """Check if the size ratio of the given paths is greater than the threshold. + + Args: + path_a: Path of a directory or a file to be the nominator of the ratio. + path_b: Path of a directory or a file to be the denominator of the ratio. + threshold: a number to compare with. + + Returns: + True if the size ratio of path_a / path_b is greater than threshold. + """ + size_a = self._get_dir_size(path_a) + size_b = self._get_dir_size(path_b) + size_ratio = size_a / size_b + return self.assertGreater(size_ratio, threshold) + + def assertSizeRatioLessThan(self, path_a: str, path_b: str, threshold: float): + """Check if the size ratio of the given paths is less than the threshold. + + Args: + path_a: Path of a directory or a file to be the nominator of the ratio. + path_b: Path of a directory or a file to be the denominator of the ratio. + threshold: a number to compare with. + + Returns: + True if the size ratio of path_a / path_b is less than threshold. + """ + size_a = self._get_dir_size(path_a) + size_b = self._get_dir_size(path_b) + size_ratio = size_a / size_b + return self.assertLess(size_ratio, threshold) + def _is_quantized_function(self, func: function_pb2.FunctionDef) -> bool: """Determine whether a FunctionDef is quantized. @@ -73,76 +139,198 @@ def _is_composite_function(self, func: function_pb2.FunctionDef) -> bool: """ return func.signature.name.startswith('composite_') - def _contains_op_with_name(self, nodes: Iterable[node_def_pb2.NodeDef], - op_name: str) -> bool: + def _contains_op_with_name_and_attribute( + self, + nodes: Iterable[node_def_pb2.NodeDef], + op_name: str, + attr_name: str, + attr_val: _AttrValType, + ) -> bool: """Determine whether there is a node whose operation name matches `op_name`. + If `attr_name` is given, additionally check if the `attr_val` matches with + the attribute value of the op. + Args: nodes: Iterable of NodeDefs. op_name: Name of the op to match. + attr_name: Name of the attribute of the op to match. + attr_val: Value of the attr_name to check. Returns: - True iff there exists a node whose name matches `op_name`. + True if there exists a node whose name matches `op_name` and 'attr_val' if + 'attr_name' is given. """ - return any(node.op == op_name for node in nodes) + return any( + node.attr.get(attr_name) == attr_val + for node in nodes + if node.op == op_name + ) def _contains_quantized_function_call( - self, meta_graphdef: meta_graph_pb2.MetaGraphDef) -> bool: + self, graphdef: graph_pb2.GraphDef + ) -> bool: """Determines if the graph def has quantized function call. Args: - meta_graphdef: A MetaGraphDef object. + graphdef: A GraphDef object. Returns: True if and only if the graph def contains a quantized function call. """ - return any( - map(self._is_quantized_function, - meta_graphdef.graph_def.library.function)) + return any(map(self._is_quantized_function, graphdef.library.function)) def _contains_composite_function_call( - self, meta_graphdef: meta_graph_pb2.MetaGraphDef) -> bool: + self, graphdef: graph_pb2.GraphDef + ) -> bool: """Determines if the graph def has composite function call. Args: - meta_graphdef: A MetaGraphDef object. + graphdef: A GraphDef object. Returns: True if and only if the graph def contains a composite function call. 
""" - return any( - map(self._is_composite_function, - meta_graphdef.graph_def.library.function)) + return any(map(self._is_composite_function, graphdef.library.function)) - def _contains_op(self, meta_graphdef: meta_graph_pb2.MetaGraphDef, - op_name: str) -> bool: + def _contains_op( + self, + graphdef: graph_pb2.GraphDef, + op_name: str, + attr_name: str = '', + attr_val: _AttrValType = None, + ) -> bool: """Determines if the graph def contains the given op. Args: - meta_graphdef: A MetaGraphDef object. + graphdef: A GraphDef object. op_name: Name of the operation to find within the graph. + attr_name: Name of the attribute of the op to match. + attr_val: Value of the attr_name to check. Returns: - True if and only if the graph def contains an op named `op_name`. + True if and only if the graph def contains an op named `op_name`. If + `attr_name` is given, check if the `attr_val` matches with the attribute + value of the op. """ # Check the main graph - if self._contains_op_with_name( - nodes=meta_graphdef.graph_def.node, op_name=op_name): + if self._contains_op_with_name_and_attribute( + nodes=graphdef.node, + op_name=op_name, + attr_name=attr_name, + attr_val=attr_val, + ): return True # Check the graph genederated from user defined functions - return any( - self._contains_op_with_name(nodes=func.node_def, op_name=op_name) - for func in meta_graphdef.graph_def.library.function) + for func in graphdef.library.function: + if self._contains_op_with_name_and_attribute( + nodes=func.node_def, + op_name=op_name, + attr_name=attr_name, + attr_val=attr_val, + ): + return True + return False + + def _count_ops( + self, + graphdef: graph_pb2.GraphDef, + op_names: Collection[str], + attr_name: str = '', + attr_val: _AttrValType = None, + get_op_name: bool = False, + ) -> int: + """Returns the number of given ops in a graph def. + + Args: + graphdef: A GraphDef object. + op_names: Names of the operations to find within the graph. + attr_name: Name of the attribute of the ops to match. + attr_val: Value of the attr_name to check. + get_op_name: If set True, checks node.name rather than node.op. + + Returns: + The number of occurrences of the given ops in a graph. The ops will be + counted only if the ops are named 'op_name' and has 'attr_val' if + 'attr_name' is specified. + """ + op_count = 0 + for op_name in op_names: + # Check the main graph + op_count += self._count_op_with_name_and_attribute( + nodes=graphdef.node, + op_name=op_name, + attr_name=attr_name, + attr_val=attr_val, + get_op_name=get_op_name, + ) + + # Check the graph genederated from user defined functions + for func in graphdef.library.function: + op_count += self._count_op_with_name_and_attribute( + nodes=func.node_def, + op_name=op_name, + attr_name=attr_name, + attr_val=attr_val, + get_op_name=get_op_name, + ) + return op_count + + def _count_op_with_name_and_attribute( + self, + nodes: Iterable[node_def_pb2.NodeDef], + op_name: str, + attr_name: str, + attr_val: _AttrValType, + get_op_name: bool = False, + ) -> int: + """Determine the number of nodes whose operation name matches `op_name`. + + If `attr_name` is given, additionally check if the `attr_val` matches with + the attribute value of the op. - def _create_simple_tf1_conv_model(self, - use_variable_for_filter=False - ) -> Tuple[core.Tensor, core.Tensor]: + Args: + nodes: Iterable of NodeDefs. + op_name: Name of the op to match. + attr_name: Name of the attribute of the op to match. + attr_val: Value of the attr_name to check. 
+ get_op_name: If set True, checks node.name rather than node.op. + + Returns: + The number of occurrences of nodes whose name match `op_name` and + 'attr_val' if 'attr_name' is given. + """ + if get_op_name: + return len( + [ + node.attr.get(attr_name) == attr_val + for node in nodes + if node.name == op_name + ] + ) + else: + return len( + [ + node.attr.get(attr_name) == attr_val + for node in nodes + if node.op == op_name + ] + ) + + def _create_simple_tf1_conv_model( + self, + input_shape: Sequence[int] = (1, 3, 4, 3), + filter_shape: Sequence[int] = (2, 3, 3, 2), + use_variable_for_filter=False, + ) -> Tuple[core.Tensor, core.Tensor]: """Creates a basic convolution model. This is intended to be used for TF1 (graph mode) tests. Args: + input_shape: Shape of the input tensor. + filter_shape: Shape of the filter. use_variable_for_filter: Setting this to `True` makes the filter for the conv operation a `tf.Variable`. @@ -150,10 +338,11 @@ def _create_simple_tf1_conv_model(self, in_placeholder: Input tensor placeholder. output_tensor: The resulting tensor of the convolution operation. """ - in_placeholder = array_ops.placeholder(dtypes.float32, shape=[1, 3, 4, 3]) + in_placeholder = array_ops.placeholder(dtypes.float32, shape=input_shape) filters = random_ops.random_uniform( - shape=(2, 3, 3, 2), minval=-1., maxval=1.) + shape=filter_shape, minval=-1.0, maxval=1.0 + ) if use_variable_for_filter: filters = variables.Variable(filters) @@ -163,13 +352,14 @@ def _create_simple_tf1_conv_model(self, strides=[1, 1, 2, 1], dilations=[1, 1, 1, 1], padding='SAME', - data_format='NHWC') + data_format='NHWC', + ) return in_placeholder, output_tensor - def _create_simple_tf1_gather_model(self, - use_variable_for_filter=False - ) -> Tuple[core.Tensor, core.Tensor]: + def _create_simple_tf1_gather_model( + self, use_variable_for_filter=False + ) -> Tuple[core.Tensor, core.Tensor]: """Creates a basic gather model. This is intended to be used for TF1 (graph mode) tests. @@ -184,7 +374,7 @@ def _create_simple_tf1_gather_model(self, """ in_placeholder = array_ops.placeholder(dtypes.int64, shape=(6)) - filters = random_ops.random_uniform(shape=(64, 512), minval=-1., maxval=1.) + filters = np.random.randn(128, 32).astype(np.float32) if use_variable_for_filter: filters = variables.Variable(filters) @@ -192,9 +382,53 @@ def _create_simple_tf1_gather_model(self, return in_placeholder, output_tensor - def _create_vocab_table_lookup_model_tf1( + def _create_and_save_vocab_table_lookup_model_tf1( self, - sess: session.Session) -> Tuple[core.Tensor, core.Tensor, core.Tensor]: + output_path: str, + tags: Collection[str], + signature_def_key: str, + ) -> Tuple[Mapping[str, core.Tensor], Mapping[str, core.Tensor]]: + """Creates and saves a simple model that uses a vocab table. + + Args: + output_path: Path to the directory to save the created model. + tags: Set of strings that identifies the saved meta graph. + signature_def_key: Name of the SignatureDef. Used to identify the + SignatureDef within the meta graph. + + Returns: + inputs: A mapping of input_key -> input_tensor (placeholder). The input + key is "input_vocabs". + outputs: A mapping of output_key -> output_tensor. The output keys are + "lookup" and "output". 
+ """ + with session.Session(graph=ops.Graph()) as sess: + input_vocabs_placeholder, lookup_tensor, output_tensor = ( + self._create_vocab_table_lookup_model_tf1(sess) + ) + + inputs = {'input_vocabs': input_vocabs_placeholder} + outputs = { + 'lookup': lookup_tensor, + 'output': output_tensor, + } + + self._save_tf1_model( + sess, + output_path, + signature_def_key, + tags, + inputs=inputs, + outputs=outputs, + init_op=lookup_ops.tables_initializer(), + assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS), + ) + + return inputs, outputs + + def _create_vocab_table_lookup_model_tf1( + self, sess: session.Session + ) -> Tuple[core.Tensor, core.Tensor, core.Tensor]: """Creates a simple model that initializes and lookups a vocab table. This model creates an asset file at "vocab_file.txt" containing @@ -215,37 +449,165 @@ def _create_vocab_table_lookup_model_tf1( asset_dir = self.create_tempdir('assets').full_path asset_file = os.path.join(asset_dir, 'vocab_file.txt') file_io.write_string_to_file( - filename=asset_file, file_content='hello,model,quantization\n') + filename=asset_file, file_content='hello,model,quantization\n' + ) + + vocab_file = asset.Asset(asset_file) + + raw_vocab = io_ops.read_file(vocab_file) + vocabs = ragged_string_ops.string_split_v2( + string_ops.string_strip(raw_vocab), sep=',' + ) + + # Initialize the vocab table. Each comma-separated word in vocab_file.txt + # corresponds to the numeric identifiers in `values`. + kv_init = lookup_ops.KeyValueTensorInitializer( + keys=vocabs, values=np.array([0, 1, 2]), value_dtype=dtypes.int64 + ) + table = lookup_ops.StaticVocabularyTable(kv_init, num_oov_buckets=5) + + input_vocabs_placeholder = array_ops.placeholder( + dtypes.string, shape=(None,), name='input_vocabs' + ) + + # Introduce a matmul op that takes the lookup values to observe the + # effects of quantization. + lookup_vals = math_ops.cast( + table.lookup(input_vocabs_placeholder), dtypes.float32 + ) + # shape: (2, ?) + matmul_input = array_ops.stack([lookup_vals, lookup_vals]) + + # Create a dummy weight matrix filled with ones. + weight_row = array_ops.ones( + shape=array_ops.shape(input_vocabs_placeholder), dtype=dtypes.float32 + ) + # shape: (?, 2) + weight = array_ops.transpose_v2(array_ops.stack([weight_row, weight_row])) + # shape: (2, 2) + output_tensor = math_ops.matmul(matmul_input, weight) + + return input_vocabs_placeholder, lookup_vals, output_tensor + + def _create_and_save_vocab_table_lookup_qat_model_tf1( + self, + output_path: str, + tags: Collection[str], + signature_def_key: str, + ) -> Tuple[Mapping[str, core.Tensor], Mapping[str, core.Tensor]]: + """Creates and saves a simple QAT model that uses a vocab table. + + Args: + output_path: Path to the directory to save the created model. + tags: Set of strings that identifies the saved meta graph. + signature_def_key: Name of the SignatureDef. Used to identify the + SignatureDef within the meta graph. + + Returns: + inputs: A mapping of input_key -> input_tensor (placeholder). The input + key is "input_vocabs". + outputs: A mapping of output_key -> output_tensor. The output keys are + "lookup" and "output". 
+ """ + with session.Session(graph=ops.Graph()) as sess: + input_vocabs_placeholder, lookup_tensor, output_tensor = ( + self._create_vocab_table_lookup_qat_model_tf1(sess) + ) + + inputs = {'input_vocabs': input_vocabs_placeholder} + outputs = { + 'lookup': lookup_tensor, + 'output': output_tensor, + } + + self._save_tf1_model( + sess, + output_path, + signature_def_key, + tags, + inputs=inputs, + outputs=outputs, + init_op=lookup_ops.tables_initializer(), + assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS), + ) + + return inputs, outputs + + def _create_vocab_table_lookup_qat_model_tf1( + self, sess: session.Session + ) -> Tuple[core.Tensor, core.Tensor, core.Tensor]: + """Creates a simple QAT model that initializes and lookups a vocab table. + + This model creates an asset file at "vocab_file.txt" containing + comma-separated vocabularies. It also initializes a `StaticVocabularyTable` + and performs a lookup with the input vocabs, which is a 1D tensor of + strings. + + Args: + sess: Tensorflow Session to create the model in. + + Returns: + (input_vocabs_placeholder, lookup_vals, output_tensor), where + * input_vocabs_placeholder is a placeholder tensor of 1D strings + * lookup_vals is an output tensor that is a direct result of table lookup + * output_tensor is a float 2x2 matrix + """ + # Creates and populates an asset file. + asset_dir = self.create_tempdir('assets').full_path + asset_file = os.path.join(asset_dir, 'vocab_file.txt') + file_io.write_string_to_file( + filename=asset_file, file_content='hello,model,quantization\n' + ) vocab_file = asset.Asset(asset_file) raw_vocab = io_ops.read_file(vocab_file) vocabs = ragged_string_ops.string_split_v2( - string_ops.string_strip(raw_vocab), sep=',') + string_ops.string_strip(raw_vocab), sep=',' + ) # Initialize the vocab table. Each comma-separated word in vocab_file.txt # corresponds to the numeric identifiers in `values`. kv_init = lookup_ops.KeyValueTensorInitializer( - keys=vocabs, values=np.array([0, 1, 2]), value_dtype=dtypes.int64) + keys=vocabs, values=np.array([0, 1, 2]), value_dtype=dtypes.int64 + ) table = lookup_ops.StaticVocabularyTable(kv_init, num_oov_buckets=5) input_vocabs_placeholder = array_ops.placeholder( - dtypes.string, shape=(None,), name='input_vocabs') + dtypes.string, shape=(None,), name='input_vocabs' + ) # Introduce a matmul op that takes the lookup values to observe the # effects of quantization. lookup_vals = math_ops.cast( - table.lookup(input_vocabs_placeholder), dtypes.float32) + table.lookup(input_vocabs_placeholder), dtypes.float32 + ) + # shape: (2, ?) matmul_input = array_ops.stack([lookup_vals, lookup_vals]) + # Insert fake quant to simulate a QAT model. + matmul_input = array_ops.fake_quant_with_min_max_args( + matmul_input, min=-0.3, max=0.3, num_bits=8, narrow_range=False + ) # Create a dummy weight matrix filled with ones. weight_row = array_ops.ones( - shape=array_ops.shape(input_vocabs_placeholder), dtype=dtypes.float32) + shape=array_ops.shape(input_vocabs_placeholder), dtype=dtypes.float32 + ) + # shape: (?, 2) weight = array_ops.transpose_v2(array_ops.stack([weight_row, weight_row])) + # Insert fake quant to simulate a QAT model. + weight = array_ops.fake_quant_with_min_max_args( + weight, min=-0.1, max=0.2, num_bits=8, narrow_range=False + ) + # shape: (2, 2) output_tensor = math_ops.matmul(matmul_input, weight) + # Insert fake quant to simulate a QAT model. 
+ output_tensor = array_ops.fake_quant_with_min_max_args( + output_tensor, min=-0.2, max=0.2, num_bits=8, narrow_range=False + ) return input_vocabs_placeholder, lookup_vals, output_tensor @@ -253,10 +615,11 @@ def _create_data_generator( self, input_key: str, shape: Sequence[int], - minval: float = -1., - maxval: float = 1., + minval: float = -1.0, + maxval: float = 1.0, dtype: dtypes.DType = dtypes.float32, - num_examples: int = 8) -> repr_dataset.RepresentativeDataset: + num_examples: int = 8, + ) -> repr_dataset.RepresentativeDataset: """Creates a data generator to be used as representative dataset. Supports generating random value input tensors mapped by the `input_key`. @@ -286,7 +649,8 @@ def _save_tf1_model( inputs: Mapping[str, core.Tensor], outputs: Mapping[str, core.Tensor], init_op: Optional[ops.Operation] = None, - assets_collection: Optional[Sequence[ops.Tensor]] = None) -> None: + assets_collection: Optional[Sequence[ops.Tensor]] = None, + ) -> None: """Saves a TF1 model. Args: @@ -303,23 +667,69 @@ def _save_tf1_model( """ v1_builder = builder.SavedModelBuilder(saved_model_path) sig_def = signature_def_utils_impl.predict_signature_def( - inputs=inputs, outputs=outputs) + inputs=inputs, outputs=outputs + ) v1_builder.add_meta_graph_and_variables( sess, tags, signature_def_map={signature_key: sig_def}, main_op=init_op, - assets_collection=assets_collection) + assets_collection=assets_collection, + ) v1_builder.save() - def _create_and_save_tf1_gather_model(self, - saved_model_path: str, - signature_key: str, - tags: Collection[str], - input_key: str, - output_key: str, - use_variable=False) -> core.Tensor: + def _create_simple_gather_and_conv_model(self, filter_shape: Sequence[int]): + class SimpleGatherAndConvModel(module.Module): + """A simple model with a single gather and a conv2d.""" + + def __init__(self): + """Initializes a SimpleGatherAndConvModel.""" + embedding_w_val = np.random.randn(1024, 3, 4, 3).astype('f4') + self.embedding_w = embedding_w_val + + @def_function.function( + input_signature=[ + tensor_spec.TensorSpec( + shape=[1], dtype=dtypes.int64, name='input_tensor' + ) + ] + ) + def model(self, input_tensor: core.Tensor) -> Mapping[str, core.Tensor]: + """Performs a gather and a 2D convolution operation. + + Args: + input_tensor: Input tensor to perform operation on. + + Returns: + A map of: output key -> output result. + """ + conv_filters = np.random.uniform( + low=-10, high=10, size=filter_shape + ).astype('f4') + + out = array_ops.gather_v2(self.embedding_w, input_tensor) + out = nn_ops.conv2d( + out, + conv_filters, + strides=[1, 1, 2, 1], + dilations=[1, 1, 1, 1], + padding='SAME', + data_format='NHWC', + ) + return {'output': out} + + return SimpleGatherAndConvModel() + + def _create_and_save_tf1_gather_model( + self, + saved_model_path: str, + signature_key: str, + tags: Collection[str], + input_key: str, + output_key: str, + use_variable=False, + ) -> core.Tensor: """Creates and saves a simple gather model. This is intended to be used for TF1 (graph mode) tests. 
@@ -339,7 +749,8 @@ def _create_and_save_tf1_gather_model(self, """ with ops.Graph().as_default(), session.Session() as sess: in_placeholder, output_tensor = self._create_simple_tf1_gather_model( - use_variable_for_filter=use_variable) + use_variable_for_filter=use_variable + ) if use_variable: sess.run(variables.global_variables_initializer()) @@ -350,12 +761,12 @@ def _create_and_save_tf1_gather_model(self, signature_key, tags, inputs={input_key: in_placeholder}, - outputs={output_key: output_tensor}) + outputs={output_key: output_tensor}, + ) return in_placeholder def _create_gather_model(self, use_variable): - class GatherModel(autotrackable.AutoTrackable): """A simple model with a single gather.""" @@ -366,33 +777,107 @@ def __init__(self, use_variable): use_variable: If True, creates a variable for weight. """ super(GatherModel, self).__init__() - w_val = np.random.randint( - low=0, high=100, size=(64, 512), dtype=np.int64) + w_val = np.random.randn(128, 32).astype('f4') if use_variable: self.w = variables.Variable(w_val) else: self.w = w_val - @def_function.function(input_signature=[ - tensor_spec.TensorSpec( - shape=[6], dtype=dtypes.int64, name='input_tensor') - ]) - def __call__(self, - input_tensor: core.Tensor) -> Mapping[str, core.Tensor]: + @def_function.function( + input_signature=[ + tensor_spec.TensorSpec( + shape=[6], dtype=dtypes.int64, name='input_tensor' + ) + ] + ) + def __call__( + self, input_tensor: core.Tensor + ) -> Mapping[str, core.Tensor]: """Performs a gather operation.""" out = array_ops.gather_v2(self.w, input_tensor) return {'output': out} return GatherModel(use_variable) - def _create_conv2d_model(self): + def _create_depthwise_conv2d_model( + self, + input_shape: Sequence[int], + filter_shape: Sequence[int], + has_bias: bool = False, + has_batch_norm: bool = False, + activation_fn: Optional[ops.Operation] = None, + strides: Sequence[int] = (1, 2, 2, 1), + dilations: Sequence[int] = (1, 1, 1, 1), + padding: str = 'SAME', + ): + class DepthwiseConvModel(module.Module): + """A simple model with a single depthwise conv2d, bias and relu.""" + + @def_function.function( + input_signature=[ + tensor_spec.TensorSpec(shape=input_shape, dtype=dtypes.float32) + ] + ) + def depthwise_conv( + self, input_tensor: core.Tensor + ) -> Mapping[str, core.Tensor]: + """Performs a 2D depthwise convolution operation. + + Args: + input_tensor: Input tensor to perform convolution on. + + Returns: + A map of: output key -> output result. + """ + filters = np.random.uniform(low=-10, high=10, size=filter_shape).astype( + 'f4' + ) + out_channel_size = filter_shape[2] * filter_shape[3] + bias = np.random.uniform( + low=0, high=10, size=(out_channel_size) + ).astype('f4') + scale, offset = [1.0] * out_channel_size, [0.5] * out_channel_size + mean, variance = scale, offset + out = nn_ops.depthwise_conv2d_native( + input_tensor, + filters, + strides=[1, 2, 2, 1], + dilations=[1, 1, 1, 1], + padding='SAME', + data_format='NHWC', + ) + if has_bias: + out = nn_ops.bias_add(out, bias) + if has_batch_norm: + # Fusing is supported for non-training case. 
+ out, _, _, _, _, _ = nn_ops.fused_batch_norm_v3( + out, scale, offset, mean, variance, is_training=False + ) + if activation_fn is not None: + out = activation_fn(out) + return {'output': out} + + return DepthwiseConvModel() + def _create_conv2d_model( + self, + input_shape: Sequence[int], + filter_shape: Sequence[int], + has_bias: bool = False, + has_batch_norm: bool = False, + activation_fn: Optional[ops.Operation] = None, + strides: Sequence[int] = (1, 2, 2, 1), + dilations: Sequence[int] = (1, 1, 1, 1), + padding: str = 'SAME', + ): class ConvModel(module.Module): """A simple model with a single conv2d, bias and relu.""" - @def_function.function(input_signature=[ - tensor_spec.TensorSpec(shape=[1, 3, 4, 512], dtype=dtypes.float32) - ]) + @def_function.function( + input_signature=[ + tensor_spec.TensorSpec(shape=input_shape, dtype=dtypes.float32) + ] + ) def conv(self, input_tensor: core.Tensor) -> Mapping[str, core.Tensor]: """Performs a 2D convolution operation. @@ -402,51 +887,70 @@ def conv(self, input_tensor: core.Tensor) -> Mapping[str, core.Tensor]: Returns: A map of: output key -> output result. """ - filters = np.random.uniform( - low=-10, high=10, size=(2, 3, 512, 2)).astype('f4') - bias = np.random.uniform(low=0, high=10, size=(2)).astype('f4') + filters = np.random.uniform(low=-10, high=10, size=filter_shape).astype( + 'f4' + ) + out_channel_size = filter_shape[-1] + bias = np.random.uniform( + low=0, high=10, size=(out_channel_size) + ).astype('f4') + scale, offset = [1.0] * out_channel_size, [0.5] * out_channel_size + mean, variance = scale, offset out = nn_ops.conv2d( input_tensor, filters, strides=[1, 1, 2, 1], dilations=[1, 1, 1, 1], padding='SAME', - data_format='NHWC') - out = nn_ops.bias_add(out, bias, data_format='NHWC') - out = nn_ops.relu6(out) + data_format='NHWC', + ) + if has_bias: + out = nn_ops.bias_add(out, bias, data_format='NHWC') + if has_batch_norm: + # Fusing is supported for non-training case. + out, _, _, _, _, _ = nn_ops.fused_batch_norm_v3( + out, scale, offset, mean, variance, is_training=False + ) + if activation_fn is not None: + out = activation_fn(out) return {'output': out} return ConvModel() - def _create_matmul_model(self, - has_bias: bool = False, - activation_fn: Optional[ops.Operation] = None) ->...: - + def _create_matmul_model( + self, + input_shape: Sequence[int], + weight_shape: Sequence[int], + saved_model_path: str, + has_bias: bool = False, + activation_fn: Optional[ops.Operation] = None, + ) -> module.Module: class MatmulModel(module.Module): """A simple model with a single matmul. Bias and activation function are optional. """ - def __init__(self, - has_bias: bool = False, - activation_fn: Optional[ops.Operation] = None) -> None: + def __init__( + self, + weight_shape: Sequence[int], + has_bias: bool = False, + activation_fn: Optional[ops.Operation] = None, + ) -> None: """Initializes a MatmulModel. Args: + weight_shape: Shape of the weight tensor. has_bias: If True, creates and adds a bias term. activation_fn: The activation function to be used. No activation function if None. 
""" self.has_bias = has_bias self.activation_fn = activation_fn - self.filters = np.random.uniform(low=-1.0, high=1.0, size=(1024, 3)) - self.bias = np.random.uniform(low=-1.0, high=1.0, size=(3,)) + self.filters = np.random.uniform(low=-1.0, high=1.0, size=weight_shape) + self.bias = np.random.uniform(low=-1.0, high=1.0, size=weight_shape[-1]) - @def_function.function(input_signature=[ - tensor_spec.TensorSpec( - shape=(1, 1024), dtype=dtypes.float32, name='input_tensor') - ]) + @def_function.function def matmul(self, input_tensor: core.Tensor) -> Mapping[str, core.Tensor]: """Performs a matrix multiplication. @@ -470,15 +974,106 @@ def matmul(self, input_tensor: core.Tensor) -> Mapping[str, core.Tensor]: return {'output': out} - return MatmulModel(has_bias, activation_fn) + model = MatmulModel(weight_shape, has_bias, activation_fn) + saved_model_save.save( + model, + saved_model_path, + signatures=model.matmul.get_concrete_function( + tensor_spec.TensorSpec( + shape=input_shape, dtype=dtypes.float32, name='input_tensor' + ) + ), + ) + return model + + def _create_einsum_model( + self, + saved_model_path: str, + equation: str, + input_shape: Sequence[int], + weight_shape: Sequence[int], + bias_shape: Optional[Sequence[int]] = None, + activation_fn: Optional[ops.Operation] = None, + ) -> module.Module: + class EinsumModel(module.Module): + """A simple model with a single einsum. + + Bias and activation function are optional. + """ + + def __init__( + self, + equation: str, + weight_shape: Sequence[int], + bias_shape: Optional[Sequence[int]] = None, + activation_fn: Optional[ops.Operation] = None, + ) -> None: + """Initializes a EinsumModel. + + Args: + equation: a string describing the contraction. + weight_shape: Shape of the weight tensor. + bias_shape: Shape of the bias. This is not always 1D so Einsum ops + usually use Add op instead of BiasAdd. + activation_fn: The activation function to be used. No activation + function if None. + """ + self.equation = equation + self.activation_fn = activation_fn + self.weight = np.random.uniform(low=-1.0, high=1.0, size=weight_shape) + self.bias = ( + np.random.uniform(low=-1.0, high=1.0, size=bias_shape) + if bias_shape is not None + else None + ) + + @def_function.function + def einsum(self, input_tensor: core.Tensor) -> Mapping[str, core.Tensor]: + """Evaluates the Einstein summation convention. + + Depending on self.has_bias and self.activation_fn, it may add a bias + term or go through the activaction function. + + Args: + input_tensor: Input tensor to einsum with the weight. + + Returns: + A map of: output key -> output result. 
+ """ + out = tensorflow.einsum(self.equation, input_tensor, self.weight) + + if self.bias is not None: + out = out + self.bias + + if self.activation_fn is not None: + out = self.activation_fn(out) + + return {'output': out} - def _create_and_save_tf1_conv_model(self, - saved_model_path: str, - signature_key: str, - tags: Collection[str], - input_key: str, - output_key: str, - use_variable=False) -> core.Tensor: + model = EinsumModel(equation, weight_shape, bias_shape, activation_fn) + saved_model_save.save( + model, + saved_model_path, + signatures=model.einsum.get_concrete_function( + tensor_spec.TensorSpec( + shape=input_shape, dtype=dtypes.float32, name='input_tensor' + ) + ), + ) + return model + + def _create_and_save_tf1_conv_model( + self, + saved_model_path: str, + signature_key: str, + tags: Collection[str], + input_key: str, + output_key: str, + *, + input_shape: Sequence[int] = (1, 3, 4, 3), + filter_shape: Sequence[int] = (2, 3, 3, 2), + use_variable: bool = False, + ) -> core.Tensor: """Creates and saves a simple convolution model. This is intended to be used for TF1 (graph mode) tests. @@ -490,6 +1085,8 @@ def _create_and_save_tf1_conv_model(self, tags: Set of tags associated with the model. input_key: The key to the input tensor. output_key: The key to the output tensor. + input_shape: Shape of the input tensor. + filter_shape: Shape of the filter. use_variable: Setting this to `True` makes the filter for the conv operation a `tf.Variable`. @@ -498,7 +1095,10 @@ def _create_and_save_tf1_conv_model(self, """ with ops.Graph().as_default(), session.Session() as sess: in_placeholder, output_tensor = self._create_simple_tf1_conv_model( - use_variable_for_filter=use_variable) + input_shape=input_shape, + filter_shape=filter_shape, + use_variable_for_filter=use_variable, + ) if use_variable: sess.run(variables.global_variables_initializer()) @@ -509,6 +1109,7 @@ def _create_and_save_tf1_conv_model(self, signature_key, tags, inputs={input_key: in_placeholder}, - outputs={output_key: output_tensor}) + outputs={output_key: output_tensor}, + ) return in_placeholder diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc index 6be9ad9ea18..8c1beece1ef 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.cc @@ -13,110 +13,223 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ #include +#include #include +#include +#include +#include +#include "absl/strings/str_format.h" #include "pybind11/pybind11.h" #include "pybind11/pytypes.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model_wrapper.h" +#include "pybind11/stl.h" +#include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil +#include "pybind11_abseil/status_casters.h" // from @pybind11_abseil +#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/python/lib/core/pybind11_lib.h" +namespace { + +using ::tensorflow::calibrator::CalibratorSingleton; +using ::tensorflow::quantization::ExportedModel; +using ::tensorflow::quantization::QuantizationOptions; +using ::tensorflow::quantization::QuantizePtqDynamicRange; +using ::tensorflow::quantization::QuantizePtqModelPostCalibration; +using ::tensorflow::quantization::QuantizePtqModelPreCalibration; +using ::tensorflow::quantization::QuantizeQatModel; + +// Serializes an ExportedModel. Raises python ValueError if serialization fails. +std::string Serialize(const ExportedModel& exported_model) { + const std::string exported_model_serialized = + exported_model.SerializeAsString(); + + // Empty string means it failed to serialize the protobuf with an error. See + // the docstring for SerializeAsString for details. + if (exported_model_serialized.empty()) { + throw py::value_error("Failed to serialize ExportedModel."); + } + + return exported_model_serialized; +} + +// Retrieves collected min / max values of a `CustomAggregator` node from the +// singleton. `id` is the identifier of the `CustomAggregator`. +std::pair GetCalibratorMinMax(const absl::string_view id) { + std::optional> min_max = + CalibratorSingleton::GetMinMax(id); + if (min_max == std::nullopt) { + throw py::value_error( + absl::StrFormat("Calibrated data does not exist. Cannot find min/max " + "value for id: '%s'", + id)); + } + + return *min_max; +} + +} // namespace + +namespace pybind11 { +namespace detail { + +// Converts `ExportedModel` (c++) to `bytes` (python). The resulting `bytes` +// object is a serialization of `ExportedModel`. +// +// See https://pybind11.readthedocs.io/en/stable/advanced/cast/custom.html for +// further details on how custom type conversions work for pybind11. +template <> +struct type_caster { + public: + PYBIND11_TYPE_CASTER(ExportedModel, const_name("ExportedModel")); + + // Constructs a `bytes` object after serializing `src`. + static handle cast(ExportedModel&& src, return_value_policy policy, + handle parent) { + // release() prevents the reference count from decreasing upon the + // destruction of py::bytes and returns a raw python object handle. + return py::bytes(Serialize(src)).release(); + } +}; + +// Python -> cpp conversion for `QuantizationOptions`. Accepts a serialized +// protobuf string and deserializes into an instance of `QuantizationOptions`. +template <> +struct type_caster { + public: + PYBIND11_TYPE_CASTER(QuantizationOptions, const_name("QuantizationOptions")); + + bool load(handle src, const bool convert) { + auto caster = make_caster(); + // The user should have passed a valid python string. 
+    if (!caster.load(src, convert)) {
+      return false;
+    }
+
+    const absl::string_view quantization_opts_serialized =
+        cast_op(std::move(caster));
+
+    // NOLINTNEXTLINE: Explicit std::string conversion required for OSS.
+    return value.ParseFromString(std::string(quantization_opts_serialized));
+  }
+};
+
+}  // namespace detail
+}  // namespace pybind11
+
 PYBIND11_MODULE(pywrap_quantize_model, m) {
+  // Supports absl::StatusOr type conversions.
+  pybind11::google::ImportStatusModule();
+
+  // Calibrator related functions.
   m.def(
       "clear_calibrator",
-      [] {
-        tensorflow::quantization::ClearCollectedInformationFromCalibrator();
-      },
+      [] { CalibratorSingleton::ClearCollectedInformation(); },
       R"pbdoc(
       Clears the collected metrics from the calibrator.
     )pbdoc");
   m.def(
       "clear_data_from_calibrator",
-      [](const absl::string_view id) {
-        tensorflow::quantization::ClearDataFromCalibrator(id);
-      },
+      [](const absl::string_view id) { CalibratorSingleton::ClearData(id); },
       R"pbdoc(
       Clears the collected data of the given id from calibrator.
     )pbdoc");
   m.def(
-      "get_max_from_calibrator",
-      [](const absl::string_view id) {
-        return tensorflow::quantization::GetMaxFromCalibrator(id);
+      "get_min_from_calibrator",
+      [](const absl::string_view id) -> float {
+        const std::pair<float, float> min_max = GetCalibratorMinMax(id);
+        return min_max.first;
       },
       R"pbdoc(
       Return the tuple with the min value of the given id.
     )pbdoc");
   m.def(
-      "get_min_from_calibrator",
-      [](const absl::string_view id) {
-        return tensorflow::quantization::GetMinFromCalibrator(id);
+      "get_max_from_calibrator",
+      [](const absl::string_view id) -> float {
+        const std::pair<float, float> min_max = GetCalibratorMinMax(id);
+        return min_max.second;
       },
       R"pbdoc(
       Return the tuple with the min value of the given id.
     )pbdoc");
+
+  // Quantization functions.
   m.def(
       "quantize_qat_model",
       [](const absl::string_view saved_model_path,
-         const absl::string_view exported_names_str,
-         const absl::string_view tags,
-         const absl::string_view quant_opts_serialized) {
-        const std::string graph_def_serialized =
-            tensorflow::quantization::QuantizeQatModel(saved_model_path,
-                                                       exported_names_str, tags,
-                                                       quant_opts_serialized)
-                .first;
-
-        return py::bytes(graph_def_serialized);
+         const std::vector<std::string>& signature_keys,
+         const std::unordered_set<std::string>& tags,
+         const QuantizationOptions& quant_opts)
+          -> absl::StatusOr<ExportedModel> {
+        return QuantizeQatModel(saved_model_path, signature_keys, tags,
+                                quant_opts);
       },
       R"pbdoc(
-      Returns serialized GraphDef of a TF model.
+      Returns serialized ExportedModel that contains the quantized model's
+      GraphDef and metadata. The user should pass a serialized
+      `QuantizationOptions` for the `quant_opts` argument.
+
+      Raises `StatusNotOk` exception if the run was unsuccessful.
     )pbdoc");
+
   m.def(
       "quantize_ptq_dynamic_range",
       [](const absl::string_view saved_model_path,
-         const absl::string_view exported_names_str,
-         const absl::string_view tags,
-         const absl::string_view quant_opts_serialized) {
-        const std::string graph_def_serialized =
-            tensorflow::quantization::QuantizePtqDynamicRange(
-                saved_model_path, exported_names_str, tags,
-                quant_opts_serialized)
-                .first;
-
-        return py::bytes(graph_def_serialized);
+         const std::vector<std::string>& signature_keys,
+         const std::unordered_set<std::string>& tags,
+         const QuantizationOptions& quant_opts)
+          -> absl::StatusOr<ExportedModel> {
+        return QuantizePtqDynamicRange(saved_model_path, signature_keys, tags,
+                                       quant_opts);
       },
       R"pbdoc(
-      Returns serialized GraphDef of a TF model.
+      Returns serialized ExportedModel that contains the quantized model's
+      GraphDef and metadata. The user should pass a serialized
+      `QuantizationOptions` for the `quant_opts` argument.
+
+      Raises `StatusNotOk` exception if the run was unsuccessful.
     )pbdoc");
+
   m.def(
       "quantize_ptq_model_pre_calibration",
       [](const absl::string_view saved_model_path,
-         const absl::string_view exported_names_str,
-         const absl::string_view tags,
-         const absl::string_view quant_opts_serialized) {
-        const auto [graph_def_serialized, init_node_name] =
-            tensorflow::quantization::QuantizePtqModelPreCalibration(
-                saved_model_path, exported_names_str, tags,
-                quant_opts_serialized);
-
-        return std::make_pair(py::bytes(graph_def_serialized), init_node_name);
+         const std::vector<std::string>& signature_keys,
+         const std::unordered_set<std::string>& tags,
+         const QuantizationOptions& quant_opts,
+         const absl::flat_hash_map<std::string, std::string>& function_aliases)
+          -> absl::StatusOr<ExportedModel> {
+        return QuantizePtqModelPreCalibration(saved_model_path, signature_keys,
+                                              tags, quant_opts,
+                                              function_aliases);
       },
       R"pbdoc(
-      Returns serialized GraphDef of a TF model.
+      Returns serialized ExportedModel that contains the model's GraphDef and
+      metadata. The GraphDef contains extra ops required for calibration. The
+      user should pass a serialized `QuantizationOptions` for the `quant_opts`
+      argument.
+
+      Raises `StatusNotOk` exception if the run was unsuccessful.
     )pbdoc");
+
   m.def(
       "quantize_ptq_model_post_calibration",
       [](const absl::string_view saved_model_path,
-         const absl::string_view exported_names_str,
-         const absl::string_view tags,
-         const absl::string_view quant_opts_serialized) {
-        const auto [graph_def_serialized, init_node_name] =
-            tensorflow::quantization::QuantizePtqModelPostCalibration(
-                saved_model_path, exported_names_str, tags,
-                quant_opts_serialized);
-
-        return std::make_pair(py::bytes(graph_def_serialized), init_node_name);
+         const std::vector<std::string>& signature_keys,
+         const std::unordered_set<std::string>& tags,
+         const QuantizationOptions& quant_opts,
+         const absl::flat_hash_map<std::string, std::string>& function_aliases)
+          -> absl::StatusOr<ExportedModel> {
+        return QuantizePtqModelPostCalibration(saved_model_path, signature_keys,
+                                               tags, quant_opts,
+                                               function_aliases);
       },
       R"pbdoc(
-      Returns serialized GraphDef of a TF model.
+      Returns serialized ExportedModel that contains the quantized model's
+      GraphDef and metadata. The user should pass a serialized
+      `QuantizationOptions` for the `quant_opts` argument.
+
+      Raises `StatusNotOk` exception if the run was unsuccessful.
     )pbdoc");
 }
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model_test.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model_test.py
new file mode 100644
index 00000000000..ed531218290
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model_test.py
@@ -0,0 +1,51 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test cases for pywrap_quantize_model.
+ +These test cases are mostly for validation checks. Tests for functionalities +are at `quantize_model_test.py`. +""" +from tensorflow.compiler.mlir.quantization.tensorflow.python import pywrap_quantize_model +from tensorflow.python.platform import test + + +class PywrapQuantizeModelTest(test.TestCase): + """Test cases for quantize_model python wrappers.""" + + def test_quantize_model_fails_when_invalid_quant_options_serialization(self): + saved_model_path = self.create_tempdir('saved_model').full_path + signature_def_keys = ['serving_default'] + tags = {'serve'} + quant_opts_serialized = 'invalid protobuf serialization string' + + with self.assertRaisesRegex(TypeError, 'incompatible function arguments'): + pywrap_quantize_model.quantize_ptq_model_pre_calibration( + saved_model_path, signature_def_keys, tags, quant_opts_serialized + ) + + def test_quantize_model_fails_when_invalid_quant_options_type(self): + saved_model_path = self.create_tempdir('saved_model').full_path + signature_def_keys = ['serving_default'] + tags = {'serve'} + invalid_quant_opts_object = ('a', 'b', 'c') + + with self.assertRaisesRegex(TypeError, 'incompatible function arguments'): + pywrap_quantize_model.quantize_ptq_model_pre_calibration( + saved_model_path, signature_def_keys, tags, invalid_quant_opts_object + ) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc index 0517be6937a..ae4c98c05b2 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc @@ -14,16 +14,15 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h" -#include #include #include #include #include #include +#include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "llvm/Support/raw_ostream.h" @@ -40,7 +39,10 @@ limitations under the License. #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/constants.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.h" @@ -57,39 +59,113 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/platform/statusor.h" +#include "tensorflow/tsl/platform/env.h" #include "tensorflow/tsl/platform/path.h" +#include "tensorflow/tsl/platform/status.h" namespace tensorflow { namespace quantization { -namespace internal { namespace { -void AddExportPasses(mlir::PassManager &pm) { +using ::mlir::tf_saved_model::kTfSavedModelInitializerInitType; +using ::mlir::tf_saved_model::kTfSavedModelInitializerRestoreType; + +// Suffix string for the module export step. Used for debugging. +constexpr absl::string_view kExportStepSuffix = "_export"; + +// Options when running passes for exporting an MLIR ModuleOp. +struct ExportOptions { + // If set to `true`, it runs `DuplicateShapeDeterminingConstantsPass` before + // lowering to tf_executor dialect. + bool duplicate_shape_determining_constants = true; + + // If set to `true`, unfreezes constants into variables and saves them to a + // checkpoint file. Setting this to `true` is an experimental feature that has + // no stability guarantees. + bool unfreeze_constants = false; + + // Path to the directory where checkpoint files are saved. + std::string checkpoint_dir = ""; + + // Name used to identify the ModuleOp this is exporting. Only used for + // debugging and does not modify the behavior of the export. + std::string debug_name = "tf_quant"; +}; + +// Add passes for transforming the MLIR module op so that it can be exported +// back to GraphDef. Roughly, this consists of: +// 1) Inserting the @main function, which will become the main Graph. +// 2) Duplicating shape-determining constants. +// 3) Converting TF dialect -> tf_executor dialect. +// 4) Adding initializer function's ops into @main function for correct +// resource initialization when loading the exported model. +// +// Duplicating shape-determining constants is required to place constants that +// affect the shape of a tensor to be placed in the TPU graph instead of in the +// CPU graph, when the graph gets converted for TPU inference. This allows these +// constants to be known at XLA compilation time. +void AddExportPasses(const bool duplicate_shape_determining_constants, + mlir::PassManager &pm) { + if (duplicate_shape_determining_constants) { + pm.addNestedPass( + mlir::quant::CreateDuplicateShapeDeterminingConstantsPass()); + } + pm.addPass(mlir::quant::CreateInsertMainFunctionPass()); pm.addNestedPass( mlir::CreateFunctionalToExecutorDialectConversionPass()); pm.addPass(mlir::CreateBreakUpIslandsPass()); pm.addPass(mlir::quant::CreateMergeInitializerFunctionOpsToMainPass()); + + // Used to clean up the "tf._noinliner" attribute that is previously used to + // prevent certain functions from being inlined (see + // `MarkFunctionsNoinlinePass`). InlinerPass must not come after this pass. + pm.addPass(mlir::TF::CreateStripNoinlineAttributePass()); } -// Returns the name of the initializer node from a set of control return nodes. -// Returns an empty string if no initializer node exists. This assumes that -// there is only one node for initialization. -std::string GetInitNodeName( - const absl::flat_hash_set &control_ret_nodes) { +// Finds and returns the name of the node from a set of control output nodes. +// The name should contain the string `contains`. Returns an empty string if no +// node whose name contains `contains` is found. Assumes there is at most one +// such a node. 
+std::string GetNodeName(const absl::flat_hash_set &control_ret_nodes, + const absl::string_view contains) { for (Node *control_ret_node : control_ret_nodes) { - if (absl::StrContains(control_ret_node->name(), kInitOpNamePrefix)) { - VLOG(1) << "Init node found: " << control_ret_node->name(); + if (absl::StrContains(control_ret_node->name(), contains)) { + VLOG(1) << "Node found: " << control_ret_node->name() + << ", contains: " << contains; return control_ret_node->name(); } } + VLOG(1) << "Could not find node whose name conatins: " << contains; return ""; } +[[nodiscard]] ExportedModel CreateExportedModel( + GraphDef &&graph_def, const absl::string_view init_node_name, + const absl::string_view restore_node_name, + const absl::string_view checkpoint_dir, + const std::vector &variable_shared_names, + const absl::flat_hash_map &function_aliases) { + ExportedModel exported_model{}; + *exported_model.mutable_graph_def() = graph_def; + exported_model.set_init_node_name(std::string(init_node_name)); + exported_model.set_restore_node_name(std::string(restore_node_name)); + exported_model.set_checkpoint_dir(std::string(checkpoint_dir)); + for (auto &shared_name : variable_shared_names) { + *exported_model.mutable_variable_shared_names()->Add() = shared_name; + } + exported_model.mutable_function_aliases()->insert(function_aliases.begin(), + function_aliases.end()); + + return exported_model; +} + // Converts MLIR ModuleOp to ExportedModel. Returns InternalError status -// when the GraphDef conversion fails. +// when the conversion fails. absl::StatusOr ConvertMlirModuleToExportedModel( - const mlir::ModuleOp module_op) { + const mlir::ModuleOp module_op, const absl::string_view checkpoint_dir, + const std::vector &variable_shared_names, + const absl::flat_hash_map &function_aliases) { const GraphExportConfig config{}; FunctionLibraryDefinition flib_def{OpRegistry::Global(), FunctionDefLibrary()}; @@ -102,122 +178,180 @@ absl::StatusOr ConvertMlirModuleToExportedModel( status.error_message()); } - auto graph_def = std::make_unique(); - graph->ToGraphDef(graph_def.get()); + GraphDef graph_def{}; + graph->ToGraphDef(&graph_def); + + const std::string init_node_name = + GetNodeName(control_ret_nodes, kTfSavedModelInitializerInitType); + const std::string restore_node_name = + GetNodeName(control_ret_nodes, kTfSavedModelInitializerRestoreType); - return ExportedModel{*graph_def, GetInitNodeName(control_ret_nodes)}; + return CreateExportedModel(std::move(graph_def), init_node_name, + restore_node_name, checkpoint_dir, + variable_shared_names, function_aliases); } -// Creates a new file to dump the intermediate MLIRs by prefixing the -// `dump_file_name` with the value of the TF_QUANT_MLIR_DUMP_PREFIX env -// variable. Returns absl::FailedPreconditionError if the env variable is not -// set or set to an empty string. -[[nodiscard]] absl::StatusOr> -CreateMlirDumpFile(const absl::string_view dump_file_name) { - const auto prefix = - absl::NullSafeStringView(std::getenv("TF_QUANT_MLIR_DUMP_PREFIX")); - if (prefix.empty()) { - return absl::FailedPreconditionError( - "Environment variable not set: TF_QUANT_MLIR_DUMP_PREFIX, " - "IR dump file for TF quantization is not created."); +// Runs MLIR passes with `module_op`. The passes are added by calling +// `add_passes_func`, which is a callable receiving mlir::PassManager& as its +// only argument. `name` identifies the set of passes added by `add_passes_func` +// and is used for debugging. Changing the `name` does not modify the behavior +// of the passes. 
+// +// It will try to dump intermediate MLIRs if certain conditions are met. See the +// description from `MaybeEnableIrPrinting` for the details about the +// conditions. +// +// Returns a non-OK status when the pass run fails or it fails to create an MLIR +// dump file. +template +absl::Status RunPasses(const absl::string_view name, FuncT add_passes_func, + mlir::MLIRContext &ctx, mlir::ModuleOp module_op) { + mlir::PassManager pm{&ctx}; + add_passes_func(pm); + + mlir::StatusScopedDiagnosticHandler diagnostic_handler{&ctx}; + const absl::StatusOr> out_dump_file = + MaybeEnableIrPrinting(pm, name); + if (!out_dump_file.ok()) { + return absl::InternalError(out_dump_file.status().message()); } - Env *env = Env::Default(); - const Status status = env->RecursivelyCreateDir(std::string(prefix)); - if (!status.ok()) { - return ToAbslStatus(status); + if (failed(pm.run(module_op))) { + return absl::InternalError( + absl::StrFormat("Failed to run pass: %s. %s", name, + diagnostic_handler.ConsumeStatus().error_message())); } - std::error_code ec{}; // NOLINT: Required to create llvm::raw_fd_ostream - const std::string dump_file_path = tsl::io::JoinPath(prefix, dump_file_name); - auto dump_file = std::make_unique(dump_file_path, ec); - if (ec) { - return absl::InternalError(absl::StrFormat( - "Unable to open file: %s, error: %s", dump_file_path, ec.message())); + return absl::OkStatus(); +} + +// Create a unique local temporary filename. It only creates the name, not the +// actual file. +absl::StatusOr GetLocalTempFilename() { + auto *env = Env::Default(); + std::string tmp_fname{}; + if (!env->LocalTempFilename(&tmp_fname)) { + return absl::InternalError("Failed to create a local temp file name."); } - LOG(INFO) << "IR dump file created: " << dump_file_path; - return dump_file; + return tmp_fname; } -// If verbosity level >= 1, this will dump intermediate IRs of passes to a file. -// The file path is given by prefixing `name`.mlir with the value of the -// TF_QUANT_MLIR_DUMP_PREFIX env variable. Returns `nullptr` iff the verbosity -// level < 1 or TF_QUANT_MLIR_DUMP_PREFIX is not set or set to an empty string. -// The returned ostream instance should live until the pass run is complete. -[[nodiscard]] absl::StatusOr> -MaybeEnableIrPrinting(mlir::PassManager &pm, const absl::string_view name) { - if (!VLOG_IS_ON(1)) { - // Verbosity level is too low to enable IR printing. - return nullptr; - } - - absl::StatusOr> dump_file = - CreateMlirDumpFile(/*dump_file_name=*/absl::StrCat(name, ".mlir")); - if (absl::IsFailedPrecondition(dump_file.status())) { - // The env variable TF_QUANT_MLIR_DUMP_PREFIX is not set. IR printing will - // not be enabled. - LOG(WARNING) << dump_file.status(); - return nullptr; - } else if (!dump_file.ok()) { - return dump_file.status(); - } - - mlir::OpPrintingFlags flag{}; - flag.useLocalScope().elideLargeElementsAttrs().enableDebugInfo(); - - // IR printing requires multithreading disabled. - pm.getContext()->disableMultithreading(); - - // The configuration uses the default parameter values for - // `PassManager::enableIRPrinting`, except for the `printModuleScope` - // parameter, which is true by default. It is set to false to avoid the dump - // file size becoming too large when the passes are running on a large model. 
- pm.enableIRPrinting( - /*shouldPrintBeforePass=*/[](mlir::Pass *, - mlir::Operation *) { return true; }, - /*shouldPrintAfterPass=*/ - [](mlir::Pass *, mlir::Operation *) { return true; }, - /*printModuleScope=*/false, /*printAfterOnlyOnChange=*/true, - /*printAfterOnlyOnFailure=*/false, **dump_file, flag); - - LOG(INFO) << "IR dump for TensorFlow quantization pipeline enabled. "; - return dump_file; +// Unfreezes constants into variables and saves them to a checkpoint files under +// `checkpoint_dir`. `checkpoint_dir` will be created within this function. It +// will return a non-OK status if it already exists or permission is denied. +// TODO(b/261652258): Make sure this works for when there are non-frozen +// variables in the model. +// TODO(b/262189534): Move this to a separate file for better testing. +absl::StatusOr> UnfreezeConstantsAndSaveVariables( + const absl::string_view checkpoint_dir, mlir::MLIRContext &ctx, + mlir::ModuleOp module_op) { + if (const absl::Status pass_run_status = + RunPasses(/*name=*/kTfQuantConstantUnfreezingStepName, + /*add_passes_func=*/ + [](mlir::PassManager &pm) { + pm.addPass(mlir::quant::CreateUnfreezeConstantsPass()); + }, + ctx, module_op); + !pass_run_status.ok()) { + return pass_run_status; + } + + if (const tsl::Status create_dir_status = + Env::Default()->CreateDir(std::string(checkpoint_dir)); + !create_dir_status.ok()) { + LOG(ERROR) << "Failed to create checkpoint directory at: " + << checkpoint_dir; + return tsl::ToAbslStatus(create_dir_status); + } + + const absl::StatusOr> variable_save_status = + SaveVariablesToCheckpoint(checkpoint_dir, module_op); + if (!variable_save_status.ok()) { + return variable_save_status.status(); + } + + if (const absl::Status pass_run_status = RunPasses( + /*name=*/kTfQuantInsertRestoreOpStepName, + /*add_passes_func=*/ + [](mlir::PassManager &pm) { + pm.addPass(mlir::quant::CreateInsertRestoreOpPass()); + // Initialization by `tf.ConstOp` is no longer required as there is + // a `tf.RestoreV2Op` now. + pm.addPass( + mlir::quant::CreateRemoveVariableInitializationByConstPass()); + }, + ctx, module_op); + !pass_run_status.ok()) { + return pass_run_status; + } + + return *variable_save_status; } -} // namespace +// Sets up and runs the passes for exporting `module_op`. The behavior of the +// exporting passes is controlled by `export_opts`. +absl::StatusOr> RunExportPasses( + const ExportOptions &export_opts, mlir::MLIRContext &ctx, + mlir::ModuleOp module_op) { + std::vector variable_shared_names; + + if (export_opts.unfreeze_constants) { + const absl::StatusOr> shared_names = + UnfreezeConstantsAndSaveVariables(export_opts.checkpoint_dir, ctx, + module_op); + if (!shared_names.ok()) { + return shared_names.status(); + } -absl::StatusOr QuantizeQatModel( - const absl::string_view saved_model_path, - const absl::string_view exported_names_str, const absl::string_view tags, - const absl::string_view quant_opts_serialized) { - const std::unordered_set tag_set = - absl::StrSplit(tags, ',', absl::SkipEmpty()); - std::vector exported_names = - absl::StrSplit(exported_names_str, ',', absl::SkipEmpty()); - QuantizationOptions quantization_options; - if (!quantization_options.ParseFromString( - // NOLINTNEXTLINE: std::string conversion required. 
- std::string(quant_opts_serialized))) { - return absl::InternalError( - "Failed to parse QuantizationOptions from string."); + LOG(INFO) << "Unfrozen constants and saved variables to checkpoint file: " + << export_opts.checkpoint_dir; + + variable_shared_names = std::move(*shared_names); } - // Convert the SavedModelBundle to an MLIR module. - mlir::DialectRegistry registry; + if (const absl::Status pass_run_status = RunPasses( + /*name=*/export_opts.debug_name, + /*add_passes_func=*/ + [dup_constants = export_opts.duplicate_shape_determining_constants]( + mlir::PassManager &pm) { AddExportPasses(dup_constants, pm); }, + ctx, module_op); + !pass_run_status.ok()) { + return pass_run_status; + } + + return variable_shared_names; +} + +// Creates MLIRContext where the dialects required for quantization are +// registered. +mlir::MLIRContext CreateMlirContextForTfQuantization() { + mlir::DialectRegistry registry{}; registry.insert(); - mlir::MLIRContext context(registry); + return mlir::MLIRContext{registry}; +} + +} // namespace + +absl::StatusOr QuantizeQatModel( + const absl::string_view saved_model_path, + const std::vector &signature_keys, + const std::unordered_set &tags, + const QuantizationOptions &quantization_options) { + // Convert the SavedModelBundle to an MLIR module. + mlir::MLIRContext context = CreateMlirContextForTfQuantization(); MLIRImportOptions import_options; import_options.upgrade_legacy = true; auto bundle = std::make_unique(); // TODO(b/213406917): Add support for the object graph based saved model input + std::vector exported_names = signature_keys; StatusOr> module = - SavedModelSignatureDefsToMlirImport(saved_model_path, tag_set, + SavedModelSignatureDefsToMlirImport(saved_model_path, tags, absl::MakeSpan(exported_names), &context, import_options, /*lift_variables=*/false, &bundle); @@ -229,64 +363,97 @@ absl::StatusOr QuantizeQatModel( mlir::OwningOpRef module_ref = std::move(module).value(); - const Status status = PreprocessAndFreezeGraph( - module_ref.get(), &context, bundle ? bundle->GetSession() : nullptr); - if (!status.ok()) { - return absl::InternalError("Failed to preprocess graph: " + - status.error_message()); + if (const absl::Status preprocess_status = PreprocessAndFreezeGraph( + module_ref.get(), &context, bundle ? 
bundle->GetSession() : nullptr); + !preprocess_status.ok()) { + return preprocess_status; } - mlir::PassManager pm(&context); - const absl::StatusOr> out_dump_file = - MaybeEnableIrPrinting(pm, /*name=*/"tf_quantize_qat"); - if (!out_dump_file.ok()) { - return absl::InternalError(out_dump_file.status().message()); + if (const absl::Status qat_status = + RunPasses(/*name=*/kTfQuantQatStepName, + /*add_passes_func=*/ + [&quantization_options](mlir::PassManager &pm) { + AddQuantizeQatPasses(pm, quantization_options); + }, + context, *module_ref); + !qat_status.ok()) { + return qat_status; } - AddQuantizeQatPasses(pm, quantization_options); - AddExportPasses(pm); + const bool unfreeze_constants = + !quantization_options.freeze_all_variables().enabled(); + const absl::StatusOr checkpoint_dir = GetLocalTempFilename(); + if (!checkpoint_dir.ok()) { + LOG(ERROR) << "Failed to get checkpoint directory name."; + return checkpoint_dir.status(); + } - mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context); - if (failed(pm.run(*module_ref))) { - return absl::InternalError( - "failed to apply the quantization: " + - diagnostic_handler.ConsumeStatus().error_message()); + const auto export_opts = ExportOptions{ + /*duplicate_shape_determining_constants=*/true, unfreeze_constants, + *checkpoint_dir, + /*debug_name=*/absl::StrCat(kTfQuantQatStepName, kExportStepSuffix)}; + + const absl::StatusOr> variable_shared_names = + RunExportPasses(export_opts, context, *module_ref); + if (!variable_shared_names.ok()) { + return variable_shared_names.status(); } - return ConvertMlirModuleToExportedModel(*module_ref); + return ConvertMlirModuleToExportedModel(*module_ref, *checkpoint_dir, + *variable_shared_names, + /*function_aliases=*/{}); +} + +// Returns the updated function aliases. `module_op` may have different function +// names from the original model, so it re-associates the aliases with the new +// function names. Both the input `function_aliases` and the returned value +// are function name -> alias mappings. `function_aliases` is the function alias +// mapping of the original function. +absl::flat_hash_map UpdateFunctionAliases( + const absl::flat_hash_map function_aliases, + mlir::ModuleOp module_op) { + absl::flat_hash_map updated_function_aliases; + + module_op->walk([&](mlir::func::FuncOp func_op) { + // We may retrieve the original function's name from the attribute. + // Functions without this attribute are ignored. + auto original_func_name = + func_op->getAttrOfType("tf._original_func_name"); + if (original_func_name) { + if (auto alias_itr = function_aliases.find(original_func_name.str()); + alias_itr != function_aliases.end()) { + const std::string alias = alias_itr->second; + const std::string new_func_name = func_op.getSymName().str(); + + updated_function_aliases[new_func_name] = alias; + + VLOG(1) << "Updated function alias. 
Alias: " << alias + << ", New function name: " << new_func_name + << ", Old function name: " << original_func_name.str(); + } + } + }); + + return updated_function_aliases; } absl::StatusOr QuantizePtqModelPreCalibration( const absl::string_view saved_model_path, - const absl::string_view exported_names_str, const absl::string_view tags, - const absl::string_view quant_opts_serialized) { - const std::unordered_set tag_set = - absl::StrSplit(tags, ',', absl::SkipEmpty()); - std::vector exported_names = - absl::StrSplit(exported_names_str, ',', absl::SkipEmpty()); - QuantizationOptions quantization_options; - if (!quantization_options.ParseFromString( - // NOLINTNEXTLINE: std::string conversion required. - std::string(quant_opts_serialized))) { - return absl::InternalError( - "Failed to parse QuantizationOptions from string."); - } - + const std::vector &signature_keys, + const std::unordered_set &tags, + const QuantizationOptions &quantization_options, + const absl::flat_hash_map &function_aliases) { // Convert the SavedModelBundle to an MLIR module. - mlir::DialectRegistry registry; - registry.insert(); - mlir::MLIRContext context(registry); + mlir::MLIRContext context = CreateMlirContextForTfQuantization(); MLIRImportOptions import_options; import_options.upgrade_legacy = true; auto bundle = std::make_unique(); // TODO(b/213406917): Add support for the object graph based saved model input + std::vector exported_names = signature_keys; StatusOr> module = - SavedModelSignatureDefsToMlirImport(saved_model_path, tag_set, + SavedModelSignatureDefsToMlirImport(saved_model_path, tags, absl::MakeSpan(exported_names), &context, import_options, /*lift_variables=*/false, &bundle); @@ -297,67 +464,81 @@ absl::StatusOr QuantizePtqModelPreCalibration( } mlir::OwningOpRef module_ref = std::move(module).value(); - const Status status = PreprocessAndFreezeGraph( - module_ref.get(), &context, bundle ? bundle->GetSession() : nullptr); - if (!status.ok()) { - return absl::InternalError("Failed to preprocess graph: " + - status.error_message()); + const absl::flat_hash_map updated_function_aliases = + UpdateFunctionAliases(function_aliases, *module_ref); + + // Collect the names of the functions that have aliases so that they may not + // be inlined. + absl::flat_hash_set aliased_function_names; + absl::c_for_each(updated_function_aliases, [&](const auto &aliases) { + return aliased_function_names.insert(aliases.first); + }); + + if (const absl::Status preprocess_status = PreprocessAndFreezeGraph( + /*mlir_dump_file_prefix=*/kTfQuantPtqPreCalibrationStepName, + /*is_inliner_run=*/true, + /*noinline_functions=*/aliased_function_names, module_ref.get(), + &context, bundle ? 
bundle->GetSession() : nullptr); + !preprocess_status.ok()) { + return preprocess_status; } - mlir::PassManager pm(&context); - const absl::StatusOr> out_dump_file = - MaybeEnableIrPrinting(pm, /*name=*/"tf_quantize_ptq_pre_calibration"); - if (!out_dump_file.ok()) { - return absl::InternalError(out_dump_file.status().message()); + if (const absl::Status pre_calib_pass_status = RunPasses( + /*name=*/kTfQuantPtqPreCalibrationStepName, + /*add_passes_func=*/ + [&quantization_options](mlir::PassManager &pm) { + AddQuantizePtqPreCalibrationPasses(pm, quantization_options); + }, + context, *module_ref); + !pre_calib_pass_status.ok()) { + return pre_calib_pass_status; } - AddQuantizePtqPreCalibrationPasses(pm, quantization_options); - AddExportPasses(pm); - - mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context); - if (failed(pm.run(*module_ref))) { - return absl::InternalError( - "Failed to apply the quantization at the pre-calibration stage: " + - diagnostic_handler.ConsumeStatus().error_message()); + const bool unfreeze_constants = + !quantization_options.freeze_all_variables().enabled(); + const absl::StatusOr checkpoint_dir = GetLocalTempFilename(); + if (!checkpoint_dir.ok()) { + return checkpoint_dir.status(); + } + // `duplicate_shape_determining_constants = false` because the + // resulting graph of this step is not expected to be loaded on TPU. + const auto export_opts = ExportOptions{ + /*duplicate_shape_determining_constants=*/false, unfreeze_constants, + *checkpoint_dir, + /*debug_name=*/ + absl::StrCat(kTfQuantPtqPreCalibrationStepName, kExportStepSuffix)}; + + const absl::StatusOr> variable_shared_names = + RunExportPasses(export_opts, context, *module_ref); + if (!variable_shared_names.ok()) { + return variable_shared_names.status(); } - return ConvertMlirModuleToExportedModel(*module_ref); + return ConvertMlirModuleToExportedModel(*module_ref, *checkpoint_dir, + *variable_shared_names, + updated_function_aliases); } absl::StatusOr QuantizePtqModelPostCalibration( const absl::string_view saved_model_path, - const absl::string_view exported_names_str, const absl::string_view tags, - const absl::string_view quant_opts_serialized) { - const std::unordered_set tag_set = - absl::StrSplit(tags, ',', absl::SkipEmpty()); - std::vector exported_names = - absl::StrSplit(exported_names_str, ',', absl::SkipEmpty()); - QuantizationOptions quantization_options; - if (!quantization_options.ParseFromString( - // NOLINTNEXTLINE: std::string conversion required. - std::string(quant_opts_serialized))) { - return absl::InternalError( - "Failed to parse QuantizationOptions from string."); - } - + const std::vector &signature_keys, + const std::unordered_set &tags, + const QuantizationOptions &quantization_options, + const absl::flat_hash_map &function_aliases) { // Convert the SavedModelBundle to an MLIR module. 
- mlir::DialectRegistry registry; - registry.insert(); - mlir::MLIRContext context(registry); + mlir::MLIRContext context = CreateMlirContextForTfQuantization(); MLIRImportOptions import_options; import_options.upgrade_legacy = true; auto bundle = std::make_unique(); // TODO(b/213406917): Add support for the object graph based saved model input + std::vector exported_names = signature_keys; StatusOr> module = - SavedModelSignatureDefsToMlirImport(saved_model_path, tag_set, + SavedModelSignatureDefsToMlirImport(saved_model_path, tags, absl::MakeSpan(exported_names), &context, import_options, - /*lift_variables=*/true, &bundle); + /*lift_variables=*/false, &bundle); if (!module.status().ok()) { return absl::InternalError("Failed to import SavedModel: " + @@ -366,57 +547,78 @@ absl::StatusOr QuantizePtqModelPostCalibration( mlir::OwningOpRef module_ref = std::move(module).value(); - mlir::PassManager pm(&context); - const absl::StatusOr> out_dump_file = - MaybeEnableIrPrinting(pm, /*name=*/"tf_quantize_ptq_post_calibration"); - if (!out_dump_file.ok()) { - return absl::InternalError(out_dump_file.status().message()); + const absl::flat_hash_map updated_function_aliases = + UpdateFunctionAliases(function_aliases, *module_ref); + + // Collect the names of the functions that have aliases so that they may not + // be inlined. + absl::flat_hash_set aliased_function_names; + absl::c_for_each(updated_function_aliases, [&](const auto &aliases) { + return aliased_function_names.insert(aliases.first); + }); + + // Freezing is required again since variables might have been produced during + // the pre-calibration step. `is_inliner_run = false` to prevent the functions + // lifted for quantization from being inlined. + if (const absl::Status preprocess_status = PreprocessAndFreezeGraph( + /*mlir_dump_file_prefix=*/kTfQuantPtqPostCalibrationStepName, + /*is_inliner_run=*/false, + /*noinline_functions=*/aliased_function_names, module_ref.get(), + &context, bundle ? 
bundle->GetSession() : nullptr); + !preprocess_status.ok()) { + return preprocess_status; } - AddQuantizePtqPostCalibrationPasses(pm, quantization_options); - AddExportPasses(pm); + if (const absl::Status pre_calib_pass_status = RunPasses( + /*name=*/kTfQuantPtqPostCalibrationStepName, + /*add_passes_func=*/ + [&quantization_options](mlir::PassManager &pm) { + AddQuantizePtqPostCalibrationPasses(pm, quantization_options); + }, + context, *module_ref); + !pre_calib_pass_status.ok()) { + return pre_calib_pass_status; + } - mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context); - if (failed(pm.run(*module_ref))) { - return absl::InternalError( - "Failed to apply the quantization at the post-calibation stage: " + - diagnostic_handler.ConsumeStatus().error_message()); + const bool unfreeze_constants = + !quantization_options.freeze_all_variables().enabled(); + const absl::StatusOr checkpoint_dir = GetLocalTempFilename(); + if (!checkpoint_dir.ok()) { + return checkpoint_dir.status(); + } + const auto export_opts = ExportOptions{ + /*duplicate_shape_determining_constants=*/true, unfreeze_constants, + *checkpoint_dir, + /*debug_name=*/ + absl::StrCat(kTfQuantPtqPostCalibrationStepName, kExportStepSuffix)}; + + const absl::StatusOr> variable_shared_names = + RunExportPasses(export_opts, context, *module_ref); + if (!variable_shared_names.ok()) { + return variable_shared_names.status(); } - return ConvertMlirModuleToExportedModel(*module_ref); + return ConvertMlirModuleToExportedModel(*module_ref, *checkpoint_dir, + *variable_shared_names, + updated_function_aliases); } absl::StatusOr QuantizePtqDynamicRange( const absl::string_view saved_model_path, - const absl::string_view exported_names_str, const absl::string_view tags, - const absl::string_view quant_opts_serialized) { - const std::unordered_set tag_set = - absl::StrSplit(tags, ',', absl::SkipEmpty()); - std::vector exported_names = - absl::StrSplit(exported_names_str, ',', absl::SkipEmpty()); - QuantizationOptions quantization_options; - if (!quantization_options.ParseFromString( - // NOLINTNEXTLINE: std::string conversion required. - std::string(quant_opts_serialized))) { - return absl::InternalError( - "Failed to parse QuantizationOptions from string."); - } - + const std::vector &signature_keys, + const std::unordered_set &tags, + const QuantizationOptions &quantization_options) { // Convert the SavedModelBundle to an MLIR module. - mlir::DialectRegistry registry; - registry.insert(); - mlir::MLIRContext context(registry); + mlir::MLIRContext context = CreateMlirContextForTfQuantization(); MLIRImportOptions import_options; import_options.upgrade_legacy = true; auto bundle = std::make_unique(); // TODO(b/213406917): Add support for the object graph based saved model input + std::vector exported_names = signature_keys; StatusOr> module = - SavedModelSignatureDefsToMlirImport(saved_model_path, tag_set, + SavedModelSignatureDefsToMlirImport(saved_model_path, tags, absl::MakeSpan(exported_names), &context, import_options, /*lift_variables=*/false, &bundle); @@ -428,33 +630,44 @@ absl::StatusOr QuantizePtqDynamicRange( mlir::OwningOpRef module_ref = std::move(module).value(); - const Status status = PreprocessAndFreezeGraph( - module_ref.get(), &context, bundle ? bundle->GetSession() : nullptr); - if (!status.ok()) { - return absl::InternalError("Failed to preprocess graph: " + - status.error_message()); + if (const absl::Status preprocess_status = PreprocessAndFreezeGraph( + module_ref.get(), &context, bundle ? 
bundle->GetSession() : nullptr); + !preprocess_status.ok()) { + return preprocess_status; } - mlir::PassManager pm(&context); - const absl::StatusOr> out_dump_file = - MaybeEnableIrPrinting(pm, /*name=*/"tf_quantize_drq"); - if (!out_dump_file.ok()) { - return absl::InternalError(out_dump_file.status().message()); + if (const absl::Status ptq_dynamic_range_status = RunPasses( + /*name=*/kTfQuantPtqDynamicRangeStepName, + /*add_passes_func=*/ + [&quantization_options](mlir::PassManager &pm) { + AddQuantizePtqDynamicRangePasses(pm, quantization_options); + }, + context, *module_ref); + !ptq_dynamic_range_status.ok()) { + return ptq_dynamic_range_status; } - AddQuantizePtqDynamicRangePasses(pm, quantization_options); - AddExportPasses(pm); - - mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context); - if (failed(pm.run(*module_ref))) { - return absl::InternalError( - "Failed to apply the quantization: " + - diagnostic_handler.ConsumeStatus().error_message()); + const bool unfreeze_constants = + !quantization_options.freeze_all_variables().enabled(); + const absl::StatusOr checkpoint_dir = GetLocalTempFilename(); + if (!checkpoint_dir.ok()) { + return checkpoint_dir.status(); + } + const auto export_opts = ExportOptions{ + /*duplicate_shape_determining_constants=*/true, unfreeze_constants, + *checkpoint_dir, + /*debug_name=*/ + absl::StrCat(kTfQuantPtqDynamicRangeStepName, kExportStepSuffix)}; + const absl::StatusOr> variable_shared_names = + RunExportPasses(export_opts, context, *module_ref); + if (!variable_shared_names.ok()) { + return variable_shared_names.status(); } - return ConvertMlirModuleToExportedModel(*module_ref); + return ConvertMlirModuleToExportedModel(*module_ref, *checkpoint_dir, + *variable_shared_names, + /*function_aliases=*/{}); } -} // namespace internal } // namespace quantization } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h index b69b3deaa0e..c3747fee523 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h @@ -16,43 +16,60 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PYTHON_QUANTIZE_MODEL_H_ #include +#include +#include +#include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/core/framework/graph.pb.h" namespace tensorflow { namespace quantization { -namespace internal { -// Represents an exported TensorFlow model. It consists of a GraphDef and extra -// metadata required for building a SavedModel. -struct ExportedModel { - GraphDef graph_def = {}; - - // Name of the initialization node used for initializing resources like - // hash tables upon loading. - std::string init_node_name = ""; -}; +// Names of the TensorFlow Quantization steps. These names are used primarily +// for debugging. 
+inline constexpr absl::string_view kTfQuantPtqPreCalibrationStepName = + "tf_quant_ptq_pre_calibration"; +inline constexpr absl::string_view kTfQuantPtqPostCalibrationStepName = + "tf_quant_ptq_post_calibration"; +inline constexpr absl::string_view kTfQuantQatStepName = "tf_quant_qat"; +inline constexpr absl::string_view kTfQuantPtqDynamicRangeStepName = + "tf_quant_ptq_dynamic_range"; +inline constexpr absl::string_view kTfQuantConstantUnfreezingStepName = + "tf_quant_constant_unfreezing"; +inline constexpr absl::string_view kTfQuantInsertRestoreOpStepName = + "tf_quant_insert_restore_op"; absl::StatusOr QuantizeQatModel( - absl::string_view saved_model_path, absl::string_view exported_names_str, - absl::string_view tags, absl::string_view quant_opts_serialized); + absl::string_view saved_model_path, + const std::vector& signature_keys, + const std::unordered_set& tags, + const QuantizationOptions& quant_opts); // Apply post-training dynamic range quantization to the model. absl::StatusOr QuantizePtqDynamicRange( - absl::string_view saved_model_path, absl::string_view exported_names_str, - absl::string_view tags, absl::string_view quant_opts_serialized); + absl::string_view saved_model_path, + const std::vector& signature_keys, + const std::unordered_set& tags, + const QuantizationOptions& quant_opts); absl::StatusOr QuantizePtqModelPreCalibration( - absl::string_view saved_model_path, absl::string_view exported_names_str, - absl::string_view tags, absl::string_view quant_opts_serialized); + absl::string_view saved_model_path, + const std::vector& exported_names, + const std::unordered_set& tags, + const QuantizationOptions& quant_opts, + const absl::flat_hash_map& function_aliases); absl::StatusOr QuantizePtqModelPostCalibration( - absl::string_view saved_model_path, absl::string_view exported_names_str, - absl::string_view tags, absl::string_view quant_opts_serialized); + absl::string_view saved_model_path, + const std::vector& signature_keys, + const std::unordered_set& tags, + const QuantizationOptions& quant_opts, + const absl::flat_hash_map& function_aliases); -} // namespace internal } // namespace quantization } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py index 89eb0e97f05..5c63834cfab 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py @@ -18,7 +18,6 @@ import tempfile from typing import Callable, Collection, Dict, Mapping, Optional, Sequence, Tuple import uuid -import warnings from absl import logging import numpy as np @@ -26,19 +25,19 @@ # pylint: disable=invalid-import-order,g-bad-import-order from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import -from tensorflow.compiler.mlir.quantization.tensorflow.python import pywrap_quantize_model as quantize_model_wrapper +from tensorflow.compiler.mlir.quantization.tensorflow.python import pywrap_quantize_model from tensorflow.compiler.mlir.quantization.tensorflow.python import representative_dataset as repr_dataset +from tensorflow.compiler.mlir.quantization.tensorflow.python import save_model +from tensorflow.compiler.mlir.quantization.tensorflow import exported_model_pb2 from tensorflow.compiler.mlir.quantization.tensorflow import quantization_options_pb2 as quant_opts_pb2 from tensorflow.core.framework import graph_pb2 from tensorflow.core.protobuf 
import meta_graph_pb2 from tensorflow.python.client import session from tensorflow.python.eager import context from tensorflow.python.eager import wrap_function -from tensorflow.python.framework import importer from tensorflow.python.framework import ops from tensorflow.python.lib.io import file_io from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.saved_model import builder from tensorflow.python.saved_model import loader_impl as saved_model_loader from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import tag_constants @@ -61,17 +60,13 @@ _DYNAMIC_RANGE_DEFAULT_MIN_NUM_ELEMENTS_FOR_WEIGHTS = 1024 -def _legalize_tensor_name(tensor_name: str) -> str: - """Converts tensor name from 'name:index' to 'name__index' format.""" - return tensor_name.replace(':', '__') - - def _is_qat_saved_model(saved_model_path: str): """Checks if the SavedModel is QAT-enabled by looking for 'FakeQuant' ops.""" saved_model_proto = saved_model_loader.parse_saved_model(saved_model_path) for meta_graph in saved_model_proto.meta_graphs: if any( - node.op.startswith('FakeQuant') for node in meta_graph.graph_def.node): + node.op.startswith('FakeQuant') for node in meta_graph.graph_def.node + ): return True for function in meta_graph.graph_def.library.function: if any(node.op.startswith('FakeQuant') for node in function.node_def): @@ -79,88 +74,11 @@ def _is_qat_saved_model(saved_model_path: str): return False -def _get_signatures_from_saved_model(saved_model_path: str, - signature_keys: Sequence[str], - tags: Collection[str]) -> _SignatureDefMap: - """Gets a map from signature keys to their SignatureDef from a saved model.""" - loader = saved_model_loader.SavedModelLoader(saved_model_path) - try: - meta_graphdef = loader.get_meta_graph_def_from_tags(tags) - except RuntimeError as runtime_error: - raise RuntimeError( - f'Failed to retrieve MetaGraphDef with tags {tags}' - f' from a SavedModel in {saved_model_path}.') from runtime_error - - signature_def_map = {} - for key, signature_def in meta_graphdef.signature_def.items(): - if key == _INIT_OP_SIGNATURE_KEY or key not in signature_keys: - continue - - signature_def_map[key] = signature_def - - return signature_def_map - - -def _fix_tensor_names(signature_def_map: _SignatureDefMap, - exported_graph: ops.Graph) -> Optional[_SignatureDefMap]: - """Tries fixing tensor names in the signatures to match the exported graph. - - The output tensor names in the original graph usually become names of the - return nodes in the exported graph. This function tries to fix that and checks - if the input tensor names are found in the exported graph. - - Args: - signature_def_map: the signatures of the original graph. - exported_graph: The PTQ-exported GraphDef. - - Returns: - Fixed signatures or None if it couldn't be fixed. - """ - # The InsertMainFunctionPass populates input and output nodes of the newly - # inserted main function with "tf_saved_model.index_path" attributes. These - # attributes can be used to identify outputs in the exported graph. 
- output_index_path_map = {} - for op in exported_graph.get_operations(): - if (op.type == '_Retval' and - op.get_attr('tf_saved_model.index_path') is not None): - index_path_name = op.get_attr('tf_saved_model.index_path')[0] - index_path_name = index_path_name.decode('utf-8') - output_index_path_map[index_path_name] = op.inputs[0].name - - for signature_def in signature_def_map.values(): - for tensor_info in signature_def.inputs.values(): - try: - exported_graph.get_tensor_by_name(tensor_info.name) - except KeyError: - # If input tensors are not found, the signatures can't be used for the - # exported graph. - warnings.warn('Cannot find the tensor with name %s in the graph.' % - tensor_info.name) - return None - - for tensor_info in signature_def.outputs.values(): - try: - if tensor_info.name in output_index_path_map: - tensor_info.name = output_index_path_map[tensor_info.name] - else: - # Tries to find the return node with the given name and use its input - # as the output tensor name. - return_node = exported_graph.get_operation_by_name( - _legalize_tensor_name(tensor_info.name)) - tensor_info.name = return_node.inputs[0].name - except KeyError: - warnings.warn( - 'Cannot find the tensor or node with name %s in the graph.' % - tensor_info.name) - return None - - return signature_def_map - - def _create_sample_validator( expected_input_keys: Collection[str], -) -> Callable[[repr_dataset.RepresentativeSample], - repr_dataset.RepresentativeSample]: +) -> Callable[ + [repr_dataset.RepresentativeSample], repr_dataset.RepresentativeSample +]: """Creates a validator function for a representative sample. Args: @@ -172,7 +90,7 @@ def _create_sample_validator( """ def validator( - sample: repr_dataset.RepresentativeSample + sample: repr_dataset.RepresentativeSample, ) -> repr_dataset.RepresentativeSample: """Validates a single instance of representative sample. @@ -191,16 +109,19 @@ def validator( the input keys of the function. """ if not isinstance(sample, collections.abc.Mapping): - raise ValueError('Invalid representative sample type. Provide a mapping ' - '(usually a dict) of {input_key: input_value}. ' - f'Got type: {type(sample)} instead.') + raise ValueError( + 'Invalid representative sample type. Provide a mapping ' + '(usually a dict) of {input_key: input_value}. ' + f'Got type: {type(sample)} instead.' + ) if set(sample.keys()) != expected_input_keys: raise KeyError( 'Invalid input keys for representative sample. The function expects ' f'input keys of: {set(expected_input_keys)}. ' f'Got: {set(sample.keys())}. Please provide correct input keys for ' - 'representative samples.') + 'representative samples.' + ) return sample @@ -209,7 +130,8 @@ def validator( def _validate_representative_dataset( representative_dataset: repr_dataset.RepresentativeDatasetOrMapping, - signature_keys: Collection[str]) -> None: + signature_keys: Collection[str], +) -> None: """Validates the representative dataset, based on the signature keys. Representative dataset can be provided in two different forms: a single @@ -239,18 +161,22 @@ def _validate_representative_dataset( raise ValueError( 'The signature keys and the keys of representative dataset map ' f'do not match. Signature keys: {set(signature_keys)}, ' - f'representative dataset map: {set(representative_dataset.keys())}.') + f'representative dataset map: {set(representative_dataset.keys())}.' 
+ ) else: if len(signature_keys) > 1: - raise ValueError('Representative dataset is not a mapping ' - f'(got: {type(representative_dataset)}), ' - 'but there is more than one signature key provided. ' - 'Please provide a map of {signature_key -> dataset} ' - 'with more than one signature key.') + raise ValueError( + 'Representative dataset is not a mapping ' + f'(got: {type(representative_dataset)}), ' + 'but there is more than one signature key provided. ' + 'Please provide a map of {signature_key -> dataset} ' + 'with more than one signature key.' + ) def _convert_values_to_tf_tensors( - sample: repr_dataset.RepresentativeSample) -> Mapping[str, core.Tensor]: + sample: repr_dataset.RepresentativeSample, +) -> Mapping[str, core.Tensor]: """Converts TensorLike values of `sample` to Tensors. Creates a copy of `sample`, where each value is converted to Tensors @@ -278,7 +204,8 @@ def _convert_values_to_tf_tensors( def _create_feed_dict_from_input_data( input_data: repr_dataset.RepresentativeSample, - signature_def: meta_graph_pb2.SignatureDef) -> Dict[str, np.ndarray]: + signature_def: meta_graph_pb2.SignatureDef, +) -> Dict[str, np.ndarray]: """Constructs a feed_dict from input data. Note: This function should only be used in graph mode. @@ -332,7 +259,8 @@ def _log_sample_num_for_calibration( modification. """ num_samples: Optional[int] = repr_dataset.get_num_samples( - representative_dataset) + representative_dataset + ) if num_samples is None: total_num_samples = '?' logging.info('Representative dataset size unknown.') @@ -346,17 +274,26 @@ def _log_sample_num_for_calibration( # Log the sample number for every 5 iterations. logging.log_every_n( - logging.DEBUG, 'Running representative sample for calibration: %d / %s', - 5, sample_num, total_num_samples) + logging.DEBUG, + 'Running representative sample for calibration: %d / %s', + 5, + sample_num, + total_num_samples, + ) yield sample - logging.info('Running representative samples complete: %d / %s', sample_num, - total_num_samples) + logging.info( + 'Running representative samples complete: %d / %s', + sample_num, + total_num_samples, + ) def _run_function_for_calibration_graph_mode( - sess: session.Session, signature_def: meta_graph_pb2.SignatureDef, - representative_dataset: repr_dataset.RepresentativeDataset) -> None: + sess: session.Session, + signature_def: meta_graph_pb2.SignatureDef, + representative_dataset: repr_dataset.RepresentativeDataset, +) -> None: """Runs the representative dataset through a function for calibration. NOTE: This is intended to be run in graph mode (TF1). @@ -376,10 +313,12 @@ def _run_function_for_calibration_graph_mode( ] sample_validator = _create_sample_validator( - expected_input_keys=signature_def.inputs.keys()) + expected_input_keys=signature_def.inputs.keys() + ) - for sample in map(sample_validator, - _log_sample_num_for_calibration(representative_dataset)): + for sample in map( + sample_validator, _log_sample_num_for_calibration(representative_dataset) + ): # Create a mapping from input tensor name to the input tensor value. # ex) "Placeholder:0" -> [0, 1, 2] feed_dict = _create_feed_dict_from_input_data(sample, signature_def) @@ -387,7 +326,8 @@ def _run_function_for_calibration_graph_mode( def _replace_tensors_by_numpy_ndarrays( - repr_ds_map: repr_dataset.RepresentativeDatasetMapping) -> None: + repr_ds_map: repr_dataset.RepresentativeDatasetMapping, +) -> None: """Replaces tf.Tensors by their evaluated numpy arrays. 
This assumes that tf.Tensors in representative samples are created in the @@ -403,7 +343,8 @@ def _replace_tensors_by_numpy_ndarrays( # by their evaluated values. ds = repr_ds_map[signature_def_key] repr_ds_map[signature_def_key] = ( - repr_dataset.replace_tensors_by_numpy_ndarrays(ds, sess)) + repr_dataset.replace_tensors_by_numpy_ndarrays(ds, sess) + ) def _run_graph_for_calibration_graph_mode( @@ -435,23 +376,27 @@ def _run_graph_for_calibration_graph_mode( # happen when the same model is loaded multiple times in the default graph. with ops.Graph().as_default(), session.Session() as sess: meta_graph: meta_graph_pb2.MetaGraphDef = saved_model_loader.load( - sess, tags, export_dir=model_dir) + sess, tags, export_dir=model_dir + ) for signature_key, repr_ds in representative_dataset_map.items(): sig_def = meta_graph.signature_def[signature_key] try: _run_function_for_calibration_graph_mode( - sess, signature_def=sig_def, representative_dataset=repr_ds) + sess, signature_def=sig_def, representative_dataset=repr_ds + ) except Exception as ex: raise ValueError( 'Failed to run representative dataset through the ' - f'function with the signature key: {signature_key}.') from ex + f'function with the signature key: {signature_key}.' + ) from ex def _run_function_for_calibration_eager_mode( func: wrap_function.WrappedFunction, - representative_dataset: repr_dataset.RepresentativeDataset) -> None: + representative_dataset: repr_dataset.RepresentativeDataset, +) -> None: """Runs the representative dataset through a function for calibration. NOTE: This is intended to be run in eager mode (TF2). @@ -464,10 +409,12 @@ def _run_function_for_calibration_eager_mode( """ _, keyword_args = func.structured_input_signature sample_validator = _create_sample_validator( - expected_input_keys=keyword_args.keys()) + expected_input_keys=keyword_args.keys() + ) - for sample in map(sample_validator, - _log_sample_num_for_calibration(representative_dataset)): + for sample in map( + sample_validator, _log_sample_num_for_calibration(representative_dataset) + ): # Convert any non-Tensor values from the sample to Tensors. # This conversion is required because the model saved in `model_dir` is # saved using TF1 SavedModelBuilder, which doesn't save the @@ -503,11 +450,13 @@ def _run_graph_for_calibration_eager_mode( for signature_key, repr_ds in representative_dataset_map.items(): try: _run_function_for_calibration_eager_mode( - func=root.signatures[signature_key], representative_dataset=repr_ds) + func=root.signatures[signature_key], representative_dataset=repr_ds + ) except Exception as ex: raise ValueError( 'Failed to run representative dataset through the ' - f'function with the signature key: {signature_key}.') from ex + f'function with the signature key: {signature_key}.' + ) from ex def _run_graph_for_calibration( @@ -550,11 +499,13 @@ def _run_graph_for_calibration( try: if context.executing_eagerly(): - _run_graph_for_calibration_eager_mode(float_model_dir, tags, - representative_dataset_map) + _run_graph_for_calibration_eager_mode( + float_model_dir, tags, representative_dataset_map + ) else: - _run_graph_for_calibration_graph_mode(float_model_dir, tags, - representative_dataset_map) + _run_graph_for_calibration_graph_mode( + float_model_dir, tags, representative_dataset_map + ) except Exception as ex: raise ValueError( 'Failed to run graph for post-training quantization calibration.' 
@@ -563,29 +514,12 @@ def _run_graph_for_calibration( logging.info('Calibration step complete.') -def _create_empty_output_dir(output_directory: str) -> None: - """Creates the `output_directory`. - - If `output_directory` already exists, it recursively deletes all contents - inside the directory. - - Also creates the parent & intermediate directories. - - Args: - output_directory: Output directory. - """ - if file_io.file_exists_v2(output_directory): - logging.info('Deleting existing directory for quantized model output: %s .', - output_directory) - file_io.delete_recursively_v2(output_directory) - - file_io.recursive_create_dir_v2(output_directory) - - def _run_static_range_qat( - saved_model_path: str, signature_def_keys: Sequence[str], + saved_model_path: str, + signature_def_keys: Sequence[str], tags: Collection[str], - quant_opts: quant_opts_pb2.QuantizationOptions) -> graph_pb2.GraphDef: + quant_opts: quant_opts_pb2.QuantizationOptions, +) -> exported_model_pb2.ExportedModel: """Runs static-range quantization for a Quantization-Aware Trained model. Runs the quantization for a model trained using QAT. @@ -598,16 +532,22 @@ def _run_static_range_qat( quant_opts: Quantization options. Returns: - The static-range quantized graph. + exported_model: Contains the GraphDef and extra metadata required for saving + the quantized graph to SavedModel. """ logging.info('Running static-range quantization for QAT model.') - graph_def_serialized = ( - quantize_model_wrapper.quantize_qat_model(saved_model_path, - ','.join(signature_def_keys), - ','.join(tags), - quant_opts.SerializeToString())) + exported_model_serialized = pywrap_quantize_model.quantize_qat_model( + saved_model_path, + list(signature_def_keys), + set(tags), + quant_opts.SerializeToString(), + ) - return graph_pb2.GraphDef.FromString(graph_def_serialized) + exported_model = exported_model_pb2.ExportedModel.FromString( + exported_model_serialized + ) + + return exported_model def _add_calibration_statistics(graph_def: graph_pb2.GraphDef) -> None: @@ -627,40 +567,20 @@ def _add_calibration_statistics(graph_def: graph_pb2.GraphDef) -> None: node_id = node_def.attr['id'].s try: - min_val = quantize_model_wrapper.get_min_from_calibrator(node_id) - max_val = quantize_model_wrapper.get_max_from_calibrator(node_id) - quantize_model_wrapper.clear_data_from_calibrator(node_id) + min_val = pywrap_quantize_model.get_min_from_calibrator(node_id) + max_val = pywrap_quantize_model.get_max_from_calibrator(node_id) + pywrap_quantize_model.clear_data_from_calibrator(node_id) node_def.attr['min'].f = float(min_val) node_def.attr['max'].f = float(max_val) except ValueError: logging.warn( - 'CustomAggregator id "%s" from FunctionDef "%s" does not have ' - 'min or max values. Parts of this function are not quantized.', - node_id.decode('utf-8'), function_def.signature.name) - - -def _find_op(graph: ops.Graph, - op_name: Optional[str]) -> Optional[ops.Operation]: - """Finds the operation with `op_name`. - - Args: - graph: The graph to find from. - op_name: Name of the node. - - Returns: - The operation that corresponds to `op_name`. Returns None iff op_name is an - empty string or None. - - Raises: - ValueError: `op_name` is malformed. - """ - if not op_name: - return None - - init_op = graph.get_operation_by_name(op_name) - logging.debug('Op found in the graph: %s', op_name) - - return init_op + ( + 'CustomAggregator id "%s" from FunctionDef "%s" does not have ' + 'min or max values. Parts of this function are not quantized.' 
+ ), + node_id.decode('utf-8'), + function_def.signature.name, + ) def _run_static_range_ptq( @@ -670,7 +590,7 @@ def _run_static_range_ptq( quant_opts: quant_opts_pb2.QuantizationOptions, representative_dataset: repr_dataset.RepresentativeDatasetOrMapping, signature_def_map: _SignatureDefMap, -) -> Tuple[graph_pb2.GraphDef, _SignatureDefMap, str]: +) -> Tuple[exported_model_pb2.ExportedModel, _SignatureDefMap]: """Runs static-range Post-Training Quantization. Runs static-range PTQ for the model. Runs the calibration step with @@ -693,122 +613,88 @@ def _run_static_range_ptq( ValueError if the graph doesn't contain a valid signature. Returns: - (graph_def, signature_def_map, init_op_name) where graph_def is the - quantized graph and - the signature_def_map contains the SignatureDefs, possibly modified - according to the quantized graph to match the original signature defs. - init_op_name is the name of the initializer op, which is fetched once to - initialize resources (e.g. hash tables) when a SavedModel is loaded. + exported_model: Contains the GraphDef and extra metadata required for saving + the quantized graph to SavedModel. + signature_def_map: Contains the SignatureDefs, possibly modified + according to the quantized graph to match the original signature defs. """ logging.info('Running post-training quantization pre-calibration step.') - graph_def_serialized, init_node_name = ( - quantize_model_wrapper.quantize_ptq_model_pre_calibration( - saved_model_path, ','.join(signature_def_keys), ','.join(tags), - quant_opts.SerializeToString())) - graph_def = graph_pb2.GraphDef.FromString(graph_def_serialized) + loader = saved_model_loader.SavedModelLoader(saved_model_path) + function_aliases = loader.get_meta_graph_def_from_tags( + tags + ).meta_info_def.function_aliases + + exported_model_serialized = ( + pywrap_quantize_model.quantize_ptq_model_pre_calibration( + saved_model_path, + list(signature_def_keys), + set(tags), + quant_opts.SerializeToString(), + dict(function_aliases), + ) + ) + + exported_model = exported_model_pb2.ExportedModel.FromString( + exported_model_serialized + ) + + graph_def = exported_model.graph_def + for function_def in graph_def.library.function: + for node_def in function_def.node_def: + if node_def.op == 'CustomAggregator': + node_def.attr['id'].s = uuid.uuid4().hex.encode('ascii') float_model_dir = tempfile.mkdtemp() - v1_builder = builder.SavedModelBuilder(float_model_dir) - - with session.Session(graph=ops.Graph()) as sess: - for function_def in graph_def.library.function: - for node_def in function_def.node_def: - if node_def.op == 'CustomAggregator': - node_def.attr['id'].s = uuid.uuid4().hex.encode('ascii') - - importer.import_graph_def(graph_def, name='') - working_graph = ops.get_default_graph() - graph_def = working_graph.as_graph_def() - - signature_def_map = _fix_tensor_names(signature_def_map, working_graph) - if signature_def_map is None: - raise ValueError("The input SavedModel doesn't contain a valid signature") - - v1_builder.add_meta_graph_and_variables( - sess, - tags, - signature_def_map=signature_def_map, - main_op=_find_op(working_graph, init_node_name)) - - v1_builder.save() + save_model.save_model_v1( + graph_def, + float_model_dir, + signature_def_map, + tags, + exported_model.init_node_name, + exported_model.restore_node_name, + exported_model.checkpoint_dir, + exported_model.variable_shared_names, + exported_model.function_aliases, + ) # Uses the representative dataset to collect statistics for calibration. 
# Handles the graph mode execution separately in case TF2 is disabled or # eager execution is disabled. The min & max values are stored separately # in a global CalibratorSingleton instance. - _run_graph_for_calibration(float_model_dir, signature_def_keys, tags, - representative_dataset) + _run_graph_for_calibration( + float_model_dir, signature_def_keys, tags, representative_dataset + ) _add_calibration_statistics(graph_def) calibrated_model_dir = tempfile.mkdtemp() - v1_builder = builder.SavedModelBuilder(calibrated_model_dir) - - with session.Session(graph=ops.Graph()) as sess: - importer.import_graph_def(graph_def, name='') - working_graph = ops.get_default_graph() - graph_def = working_graph.as_graph_def() - - v1_builder.add_meta_graph_and_variables( - sess, - tags, - signature_def_map=signature_def_map, - main_op=_find_op(working_graph, init_node_name)) - - v1_builder.save() - - signature_def_map = _get_signatures_from_saved_model(calibrated_model_dir, - signature_def_keys, tags) + save_model.save_model_v1( + graph_def, + calibrated_model_dir, + signature_def_map, + tags, + exported_model.init_node_name, + exported_model.restore_node_name, + exported_model.checkpoint_dir, + exported_model.variable_shared_names, + ) logging.info('Running post-training quantization post-calibration step.') - graph_def_serialized, init_node_name = ( - quantize_model_wrapper.quantize_ptq_model_post_calibration( - calibrated_model_dir, ','.join(signature_def_keys), ','.join(tags), - quant_opts.SerializeToString())) - - graph_def = graph_pb2.GraphDef.FromString(graph_def_serialized) - - return graph_def, signature_def_map, init_node_name + exported_model_serialized = ( + pywrap_quantize_model.quantize_ptq_model_post_calibration( + calibrated_model_dir, + list(signature_def_keys), + set(tags), + quant_opts.SerializeToString(), + dict(exported_model.function_aliases), + ) + ) + exported_model = exported_model_pb2.ExportedModel.FromString( + exported_model_serialized + ) -def _save_model_v1(graph_def: graph_pb2.GraphDef, - output_dir: str, - signature_def_map: _SignatureDefMap, - tags: Collection[str], - init_op_name: Optional[str] = None) -> None: - """Saves the model. - - Saves the provided graph def as SavedModel. - Uses TF1 SavedModel semantics (i.e. no object graph). - - Args: - graph_def: Graph to save. - output_dir: Output directory for the SavedModel. - signature_def_map: Mapping of signature def key -> SignatureDef. - tags: Tags for the meta graph def. - init_op_name: Name of the node for initialization. - - Raises: - ValueError iff the graph does not contain a valid signature. 
- """ - _create_empty_output_dir(output_dir) - v1_builder = builder.SavedModelBuilder(output_dir) - - with session.Session(graph=ops.Graph()) as sess: - importer.import_graph_def(graph_def, name='') - - signature_def_map = _fix_tensor_names(signature_def_map, - ops.get_default_graph()) - if signature_def_map is None: - raise ValueError("The input SavedModel doesn't contain a valid signature") - - v1_builder.add_meta_graph_and_variables( - sess, - tags, - signature_def_map=signature_def_map, - main_op=_find_op(sess.graph, op_name=init_op_name)) - - v1_builder.save() + return exported_model, signature_def_map def _static_range_quantize( @@ -818,7 +704,8 @@ def _static_range_quantize( output_directory: str, quantization_options: quant_opts_pb2.QuantizationOptions, representative_dataset: Optional[ - repr_dataset.RepresentativeDatasetOrMapping] = None + repr_dataset.RepresentativeDatasetOrMapping + ] = None, ) -> autotrackable.AutoTrackable: """Quantizes the given SavedModel via static range quantization. @@ -851,42 +738,56 @@ def _static_range_quantize( RuntimeError: When a MetaGraphDef could not be found associated with `tags` in the SavedModel. """ - logging.info('Running static range quantization on model: %s', - saved_model_path) + logging.info( + 'Running static range quantization on model: %s', saved_model_path + ) logging.info('Using SignatureDef keys: %s', signature_keys) logging.info('Using tags: %s', tags) logging.info('QuantizationOptions: \n%s', quantization_options) is_qat_saved_model = _is_qat_saved_model(saved_model_path) - signature_def_map = _get_signatures_from_saved_model(saved_model_path, - signature_keys, tags) + signature_def_map = save_model.get_signatures_from_saved_model( + saved_model_path, signature_keys, tags + ) # Checks if the model is from QAT if representative_dataset is None and not is_qat_saved_model: raise ValueError( 'When `representative_dataset` is not provided, the model should be ' - 'trained with quantization-aware training (QAT).') + 'trained with quantization-aware training (QAT).' + ) if quantization_options.min_num_elements_for_weights > 0: logging.warn( 'min_num_elements_for_weights is set but is not supported for the ' 'Post-training static range quantization. ' - 'The flag is ignored.') + 'The flag is ignored.' 
+ ) if is_qat_saved_model: - init_node_name: Optional[str] = None - graph_def = _run_static_range_qat(saved_model_path, signature_keys, tags, - quantization_options) + exported_model = _run_static_range_qat( + saved_model_path, signature_keys, tags, quantization_options + ) else: - graph_def, signature_def_map, init_node_name = _run_static_range_ptq( - saved_model_path, signature_keys, tags, quantization_options, - representative_dataset, signature_def_map) + exported_model, signature_def_map = _run_static_range_ptq( + saved_model_path, + signature_keys, + tags, + quantization_options, + representative_dataset, + signature_def_map, + ) - _save_model_v1( - graph_def, + save_model.save_model_v1( + exported_model.graph_def, output_directory, signature_def_map, tags, - init_op_name=init_node_name) + init_op_name=exported_model.init_node_name, + restore_op_name=exported_model.restore_node_name, + checkpoint_dir=exported_model.checkpoint_dir, + variable_shared_names=exported_model.variable_shared_names, + function_aliases=exported_model.function_aliases, + ) return saved_model_load(output_directory) @@ -900,6 +801,8 @@ def _dynamic_range_quantize( ) -> autotrackable.AutoTrackable: """Quantizes the given SavedModel via post-training dynamic range quantization. + Weight-only quantization also uses this path. + Args: saved_model_path: Path to the saved model. signature_keys: Sequence of keys identifying SignatureDef containing inputs @@ -917,13 +820,22 @@ def _dynamic_range_quantize( Raises: ValueError: when the model is QAT model. """ + if ( + quantization_options.quantization_method.experimental_method + == _ExperimentalMethod.WEIGHT_ONLY + ): + mode_str = 'weight-only quantization' + else: + mode_str = 'dynamic-range quantization' if _is_qat_saved_model(saved_model_path): raise ValueError( 'The models trained with quantization-aware training (QAT) is not ' - 'supported for dynamic range quantization.') + 'supported for %s.' % mode_str + ) - logging.info('Running post-training dynamic-range quantization on model: %s', - saved_model_path) + logging.info( + 'Running post-training %s on model: %s', mode_str, saved_model_path + ) logging.info('Using SignatureDef keys: %s', signature_keys) logging.info('Using tags: %s', tags) logging.info('QuantizationOptions: \n%s', quantization_options) @@ -934,28 +846,39 @@ def _dynamic_range_quantize( # please also update default value in tflite converter: # tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc;l=201 if quantization_options.min_num_elements_for_weights == 0: - (quantization_options.min_num_elements_for_weights - ) = _DYNAMIC_RANGE_DEFAULT_MIN_NUM_ELEMENTS_FOR_WEIGHTS + (quantization_options.min_num_elements_for_weights) = ( + _DYNAMIC_RANGE_DEFAULT_MIN_NUM_ELEMENTS_FOR_WEIGHTS + ) logging.warn( - 'QuantizationOptions.min_num_elements_for_weights is not set (0). ' - 'Setting to the default value: %s.', - _DYNAMIC_RANGE_DEFAULT_MIN_NUM_ELEMENTS_FOR_WEIGHTS) + ( + 'QuantizationOptions.min_num_elements_for_weights is not set (0). ' + 'Setting to the default value: %s.' + ), + _DYNAMIC_RANGE_DEFAULT_MIN_NUM_ELEMENTS_FOR_WEIGHTS, + ) # Apply post-training dynamic range quantization to the model. 
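For illustration, this dynamic-range path is reached from the public `quantize` API roughly as sketched below. The model paths are placeholders, and the enum accessor `quant_opts_pb2.QuantizationMethod.ExperimentalMethod` is an assumption inferred from the `_ExperimentalMethod` alias used in this module; swap `DYNAMIC_RANGE` for `WEIGHT_ONLY` to exercise the weight-only branch of the same path.

from tensorflow.compiler.mlir.quantization.tensorflow import quantization_options_pb2 as quant_opts_pb2
from tensorflow.compiler.mlir.quantization.tensorflow.python import quantize_model

quant_opts = quant_opts_pb2.QuantizationOptions()
quant_opts.quantization_method.experimental_method = (
    quant_opts_pb2.QuantizationMethod.ExperimentalMethod.DYNAMIC_RANGE
)
# Leaving min_num_elements_for_weights at 0 ("unset") makes the code below fall
# back to the default 1024-element threshold; setting it overrides that.
quant_opts.min_num_elements_for_weights = 2048

quantized = quantize_model.quantize(
    saved_model_path='/tmp/float_saved_model',      # placeholder path
    output_directory='/tmp/quantized_saved_model',  # placeholder path
    quantization_options=quant_opts,
)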
- graph_def_serialized = ( - quantize_model_wrapper.quantize_ptq_dynamic_range( - saved_model_path, ','.join(signature_keys), ','.join(tags), - quantization_options.SerializeToString())) - - graph_def = graph_pb2.GraphDef.FromString(graph_def_serialized) - signature_def_map = _get_signatures_from_saved_model(saved_model_path, - signature_keys, tags) - - _save_model_v1( - graph_def, + exported_model_serialized = pywrap_quantize_model.quantize_ptq_dynamic_range( + saved_model_path, + list(signature_keys), + set(tags), + quantization_options.SerializeToString(), + ) + + exported_model = exported_model_pb2.ExportedModel.FromString( + exported_model_serialized + ) + signature_def_map = save_model.get_signatures_from_saved_model( + saved_model_path, signature_keys, tags + ) + + save_model.save_model_v1( + exported_model.graph_def, output_directory, signature_def_map, - tags={tag_constants.SERVING}) + tags=tags, + init_op_name=exported_model.init_node_name, + ) return saved_model_load(output_directory) @@ -976,13 +899,68 @@ def _verify_output_dir(output_dir: Optional[str], overwrite: bool) -> None: FileExistsError: Iff `output_dir` is not empty and `overwrite` is false. """ dir_not_empty = ( - output_dir is not None and file_io.file_exists_v2(output_dir) and - file_io.list_directory_v2(output_dir)) + output_dir is not None + and file_io.file_exists_v2(output_dir) + and file_io.list_directory_v2(output_dir) + ) if dir_not_empty and not overwrite: - raise FileExistsError(f'Output directory already exists: {output_dir} . ' - 'Please set overwrite_output_directory to true to ' - 'overwrite the existing directory.') + raise FileExistsError( + f'Output directory already exists: {output_dir} . ' + 'Please set overwrite_output_directory to true to ' + 'overwrite the existing directory.' + ) + + +def _populate_quantization_options_default_values( + quantization_options: quant_opts_pb2.QuantizationOptions, +) -> None: + """Populates default values for QuantizationOptions. + + Populates unspecified or unset fields of QuantizationOptions with the default + values. + + * If `op_set` is unspecified, it defaults to `OpSet.TF`. + * If `freeze_all_variables` is not set, it defaults to `True`. + * Check if configurations are set correctly: + - Per-channel quantization is supported for Uniform Quantized opset only. + + Args: + quantization_options: An instance of QuantizationOptions. + """ + if quantization_options.op_set == quant_opts_pb2.OpSet.OP_SET_UNSPECIFIED: + quantization_options.op_set = quant_opts_pb2.OpSet.TF + + if not quantization_options.HasField('freeze_all_variables'): + quantization_options.freeze_all_variables.enabled = True + + if quantization_options.enable_per_channel_quantization and ( + quantization_options.op_set != quant_opts_pb2.OpSet.UNIFORM_QUANTIZED + ): + raise ValueError( + 'Currently, per-channel quantization is supported for Uniform ' + 'Quantized opset only.' + ) + + if ( + quantization_options.quantization_method.experimental_method + == _ExperimentalMethod.WEIGHT_ONLY + and quantization_options.op_set == quant_opts_pb2.OpSet.UNIFORM_QUANTIZED + ): + raise ValueError('Uniform quantized opset does not support weight-only.') + + # Converter assumes options are specified. So set SRQ explicitly. + if ( + quantization_options.quantization_method.experimental_method + == _ExperimentalMethod.EXPERIMENTAL_METHOD_UNSPECIFIED + ): + logging.debug( + '"experimental_method" for QuantizationMethod is not specified.' + 'Static range quantization is used by default.' 
+ ) + quantization_options.quantization_method.experimental_method = ( + _ExperimentalMethod.STATIC_RANGE + ) def quantize( @@ -992,7 +970,8 @@ def quantize( output_directory: Optional[str] = None, quantization_options: Optional[quant_opts_pb2.QuantizationOptions] = None, representative_dataset: Optional[ - repr_dataset.RepresentativeDatasetOrMapping] = None, + repr_dataset.RepresentativeDatasetOrMapping + ] = None, *, overwrite_output_directory: bool = False, ) -> autotrackable.AutoTrackable: @@ -1029,39 +1008,52 @@ def quantize( implemented. """ _verify_output_dir(output_directory, overwrite_output_directory) + + # Set default values for None arguments. if output_directory is None: output_directory = tempfile.mkdtemp() - # Set default values for None arguments. if quantization_options is None: quantization_options = quant_opts_pb2.QuantizationOptions() - if quantization_options.op_set == quant_opts_pb2.OpSet.OP_SET_UNSPECIFIED: - quantization_options.op_set = quant_opts_pb2.OpSet.TF + + _populate_quantization_options_default_values(quantization_options) if tags is None: tags = {tag_constants.SERVING} + if signature_keys is None: signature_keys = [signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] - method: quant_opts_pb2.QuantizationMethod = quantization_options.quantization_method + method: quant_opts_pb2.QuantizationMethod = ( + quantization_options.quantization_method + ) if method.HasField('method'): raise ValueError(f'Invalid value for QuantizationMethod: {method.method}.') elif method.HasField('experimental_method'): if method.experimental_method == _ExperimentalMethod.STATIC_RANGE: - return _static_range_quantize(saved_model_path, signature_keys, tags, - output_directory, quantization_options, - representative_dataset) - elif method.experimental_method == _ExperimentalMethod.DYNAMIC_RANGE: - return _dynamic_range_quantize(saved_model_path, signature_keys, tags, - output_directory, quantization_options) + return _static_range_quantize( + saved_model_path, + signature_keys, + tags, + output_directory, + quantization_options, + representative_dataset, + ) + elif ( + method.experimental_method == _ExperimentalMethod.DYNAMIC_RANGE + or method.experimental_method == _ExperimentalMethod.WEIGHT_ONLY + ): + return _dynamic_range_quantize( + saved_model_path, + signature_keys, + tags, + output_directory, + quantization_options, + ) else: raise NotImplementedError( 'Experimental quantization method {method.experimental_method}' - ' is not implemented.') + ' is not implemented.' + ) else: - logging.debug( - 'Neither "method" nor "experimental_method" for QuantizationMethod ' - 'is specified. Static range quantization is used by default.') - return _static_range_quantize(saved_model_path, signature_keys, tags, - output_directory, quantization_options, - representative_dataset) + raise ValueError(f'Invalid value for QuantizationMethod: {method.method}.') diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model_wrapper.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model_wrapper.cc deleted file mode 100644 index 15299f710f1..00000000000 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model_wrapper.cc +++ /dev/null @@ -1,158 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model_wrapper.h" - -#include -#include -#include - -#include "absl/status/statusor.h" -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "pybind11/pybind11.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibrator_singleton.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h" -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/python/lib/core/pybind11_lib.h" - -namespace tensorflow { -namespace quantization { -namespace { - -using ::tensorflow::quantization::internal::ExportedModel; - -// Serializes a GraphDef. Raises python ValueError if serialization fails. -std::string SerializeGraphDef(const GraphDef& graph_def, - const absl::string_view function_name, - const int line_no) { - const std::string graph_def_serialized = graph_def.SerializeAsString(); - - // Empty string means it failed to serialize the protobuf with an error. See - // the docstring for SerializeAsString for details. - if (graph_def_serialized.empty()) { - throw py::value_error(absl::StrFormat( - "Failed to serialize GraphDef result from function %s [%s:%d].", - function_name, __FILE__, line_no)); - } - - return graph_def_serialized; -} - -} // namespace - -std::pair QuantizeQatModel( - const absl::string_view saved_model_path, - const absl::string_view exported_names_str, const absl::string_view tags, - const absl::string_view quant_opts_serialized) { - const absl::StatusOr exported_model = - internal::QuantizeQatModel(saved_model_path, exported_names_str, tags, - quant_opts_serialized); - if (!exported_model.ok()) { - throw py::value_error(absl::StrFormat("Failed to quantize QAT model: %s", - exported_model.status().message())); - } - - return std::make_pair( - SerializeGraphDef(exported_model->graph_def, __func__, __LINE__), - exported_model->init_node_name); -} - -std::pair QuantizePtqDynamicRange( - const absl::string_view saved_model_path, - const absl::string_view exported_names_str, const absl::string_view tags, - const absl::string_view quant_opts_serialized) { - const absl::StatusOr exported_model = - internal::QuantizePtqDynamicRange(saved_model_path, exported_names_str, - tags, quant_opts_serialized); - if (!exported_model.ok()) { - throw py::value_error( - absl::StrFormat("Failed to apply post-training dynamic range " - "quantization to the model: %s", - exported_model.status().message())); - } - - return std::make_pair( - SerializeGraphDef(exported_model->graph_def, __func__, __LINE__), - exported_model->init_node_name); -} - -std::pair QuantizePtqModelPreCalibration( - const absl::string_view saved_model_path, - const absl::string_view exported_names_str, const absl::string_view tags, - const absl::string_view quant_opts_serialized) { - const absl::StatusOr exported_model = - internal::QuantizePtqModelPreCalibration( - saved_model_path, exported_names_str, tags, quant_opts_serialized); - if (!exported_model.ok()) { - throw 
py::value_error(absl::StrFormat( - "Failed to quantize PTQ model at the precalibration stage: %s", - exported_model.status().message())); - } - - return std::make_pair( - SerializeGraphDef(exported_model->graph_def, __func__, __LINE__), - exported_model->init_node_name); -} - -std::pair QuantizePtqModelPostCalibration( - const absl::string_view saved_model_path, - const absl::string_view exported_names_str, const absl::string_view tags, - const absl::string_view quant_opts_serialized) { - const absl::StatusOr exported_model = - internal::QuantizePtqModelPostCalibration( - saved_model_path, exported_names_str, tags, quant_opts_serialized); - if (!exported_model.ok()) { - throw py::value_error(absl::StrFormat( - "Failed to quantize PTQ model at the postcalibration stage: %s", - exported_model.status().message())); - } - - return std::make_pair( - SerializeGraphDef(exported_model->graph_def, __func__, __LINE__), - exported_model->init_node_name); -} - -void ClearCollectedInformationFromCalibrator() { - calibrator::CalibratorSingleton::ClearCollectedInformation(); -} - -void ClearDataFromCalibrator(absl::string_view id) { - calibrator::CalibratorSingleton::ClearData(id); -} - -float GetMinFromCalibrator(absl::string_view id) { - std::optional> min_max = - calibrator::CalibratorSingleton::GetMinMax(id); - if (!min_max.has_value()) { - throw py::value_error(absl::StrFormat( - "No calibrated data; cannot find min value for '%s'", id)); - } - - return min_max->first; -} - -float GetMaxFromCalibrator(absl::string_view id) { - std::optional> min_max = - calibrator::CalibratorSingleton::GetMinMax(id); - if (!min_max.has_value()) { - throw py::value_error(absl::StrFormat( - "No calibrated data; cannot find max value for '%s'", id)); - } - - return min_max->second; -} - -} // namespace quantization -} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model_wrapper.h b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model_wrapper.h deleted file mode 100644 index e18bb01a498..00000000000 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model_wrapper.h +++ /dev/null @@ -1,59 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PYTHON_QUANTIZE_MODEL_WRAPPER_H_ -#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PYTHON_QUANTIZE_MODEL_WRAPPER_H_ - -#include -#include - -#include "absl/strings/string_view.h" - -namespace tensorflow { -namespace quantization { - -// TODO(b/247442990): Devise a better data structure to transfer this data -// structure to python. 
-std::pair QuantizeQatModel( - absl::string_view saved_model_path, absl::string_view exported_names_str, - absl::string_view tags, absl::string_view quant_opts_serialized); - -std::pair QuantizePtqDynamicRange( - absl::string_view saved_model_path, absl::string_view exported_names_str, - absl::string_view tags, absl::string_view quant_opts_serialized); - -// Runs the pre-calibration step of post-training quantization (PTQ). Returns -// (serialized GraphDef, initializer node name). -std::pair QuantizePtqModelPreCalibration( - absl::string_view saved_model_path, absl::string_view exported_names_str, - absl::string_view tags, absl::string_view quant_opts_serialized); - -// Runs the post-calibration step of post-training quantization (PTQ). Returns -// (serialized GraphDef, initializer node name). -std::pair QuantizePtqModelPostCalibration( - absl::string_view saved_model_path, absl::string_view exported_names_str, - absl::string_view tags, absl::string_view quant_opts_serialized); - -void ClearCollectedInformationFromCalibrator(); - -void ClearDataFromCalibrator(absl::string_view id); - -float GetMinFromCalibrator(absl::string_view id); - -float GetMaxFromCalibrator(absl::string_view id); - -} // namespace quantization -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PYTHON_QUANTIZE_MODEL_WRAPPER_H_ diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/representative_dataset.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/representative_dataset.py index 78311571f14..5608605b5c4 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/representative_dataset.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/representative_dataset.py @@ -36,13 +36,55 @@ # A type alias expressing that it can be either a RepresentativeDataset or # a mapping of signature key to RepresentativeDataset. -RepresentativeDatasetOrMapping = Union[RepresentativeDataset, - RepresentativeDatasetMapping] +RepresentativeDatasetOrMapping = Union[ + RepresentativeDataset, RepresentativeDatasetMapping +] + + +class RepresentativeDatasetSaver: + """Representative dataset saver. + + Exposes a single method `save` that saves the provided representative dataset + into files. + + This is useful when you would like to keep a snapshot of your representative + dataset at a file system or when you need to pass the representative dataset + as files. + """ + + def save( + self, representative_dataset: RepresentativeDatasetOrMapping + ) -> None: + """Saves the representative dataset. + + Args: + representative_dataset: RepresentativeDataset or + RepresentativeDatasetMapping which is a signature_def_key -> + representative dataset mapping. RepresentativeDataset should be + considered as: {"serving_default": representative_dataset}. + """ + raise NotImplementedError('Method "save" is not implemented.') + + +class RepresentativeDatasetLoader: + """Representative dataset loader. + + Exposes a single method `load` that loads the representative dataset from + files. + """ + + def load(self) -> RepresentativeDatasetMapping: + """Loads the representative dataset. + + Returns: + A signature def key -> representative dataset mapping. 
+ """ + raise NotImplementedError('Method "load" is not implemented.') def replace_tensors_by_numpy_ndarrays( - repr_ds: RepresentativeDataset, - sess: session.Session) -> RepresentativeDataset: + repr_ds: RepresentativeDataset, sess: session.Session +) -> RepresentativeDataset: """Replaces tf.Tensors in samples by their evaluated numpy arrays. Note: This should be run in graph mode (default in TF1) only. diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/representative_dataset_test.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/representative_dataset_test.py index d48d91c950e..4d2ff7aa776 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/representative_dataset_test.py +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/representative_dataset_test.py @@ -40,9 +40,12 @@ def _contains_tensor(sample: repr_dataset.RepresentativeSample) -> bool: class RepresentativeDatasetTest(test.TestCase): """Tests functions for representative datasets.""" - def _assert_tensorlike_all_close(self, sess: session.Session, - tensorlike_value_1: core.TensorLike, - tensorlike_value_2: core.TensorLike) -> None: + def _assert_tensorlike_all_close( + self, + sess: session.Session, + tensorlike_value_1: core.TensorLike, + tensorlike_value_2: core.TensorLike, + ) -> None: """Asserts that two different TensorLike values are "all close". Args: @@ -59,9 +62,11 @@ def _assert_tensorlike_all_close(self, sess: session.Session, self.assertAllClose(tensorlike_value_1, tensorlike_value_2) def _assert_sample_values_all_close( - self, sess: session.Session, + self, + sess: session.Session, repr_ds_1: repr_dataset.RepresentativeDataset, - repr_ds_2: repr_dataset.RepresentativeDataset) -> None: + repr_ds_2: repr_dataset.RepresentativeDataset, + ) -> None: """Asserts that the sample values are "all close" between the two datasets. This assumes that the order of corresponding samples is preserved and the @@ -76,24 +81,29 @@ def _assert_sample_values_all_close( self.assertCountEqual(sample_1.keys(), sample_2.keys()) for input_key in sample_1: - self._assert_tensorlike_all_close(sess, sample_1[input_key], - sample_2[input_key]) + self._assert_tensorlike_all_close( + sess, sample_1[input_key], sample_2[input_key] + ) @test_util.deprecated_graph_mode_only def test_replace_tensors_by_numpy_ndarrays_with_tensor_list(self): num_samples = 8 samples = [ - np.random.uniform(low=-1., high=1., size=(3, 3)).astype('f4') + np.random.uniform(low=-1.0, high=1.0, size=(3, 3)).astype('f4') for _ in range(num_samples) ] - repr_ds: repr_dataset.RepresentativeDataset = [{ - 'input_tensor': ops.convert_to_tensor(sample), - } for sample in samples] + repr_ds: repr_dataset.RepresentativeDataset = [ + { + 'input_tensor': ops.convert_to_tensor(sample), + } + for sample in samples + ] with self.session() as sess: new_repr_ds = repr_dataset.replace_tensors_by_numpy_ndarrays( - repr_ds, sess) + repr_ds, sess + ) # The resulting dataset should not contain any tf.Tensors. 
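The `RepresentativeDatasetSaver` and `RepresentativeDatasetLoader` base classes above are extension points: concrete implementations override `save` and `load`. A minimal, purely illustrative sketch of an in-memory loader (the class name and stored data are assumptions, not library code):

from tensorflow.compiler.mlir.quantization.tensorflow.python import representative_dataset as repr_dataset


class InMemoryRepresentativeDatasetLoader(repr_dataset.RepresentativeDatasetLoader):
  """Illustrative loader that returns a dataset already held in memory."""

  def __init__(self, dataset_map: repr_dataset.RepresentativeDatasetMapping):
    self._dataset_map = dataset_map

  def load(self) -> repr_dataset.RepresentativeDatasetMapping:
    return self._dataset_map


# load() yields the signature_def_key -> representative dataset mapping.
loader = InMemoryRepresentativeDatasetLoader(
    {'serving_default': [{'input_tensor': [[0.0, 1.0, 2.0]]}]}
)
dataset_map = loader.load()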
self.assertFalse(any(map(_contains_tensor, new_repr_ds))) @@ -103,7 +113,7 @@ def test_replace_tensors_by_numpy_ndarrays_with_tensor_list(self): def test_replace_tensors_by_numpy_ndarrays_with_tensor_generator(self): num_samples = 8 samples = [ - np.random.uniform(low=-1., high=1., size=(1, 4)).astype('f4') + np.random.uniform(low=-1.0, high=1.0, size=(1, 4)).astype('f4') for _ in range(num_samples) ] @@ -113,7 +123,8 @@ def data_gen() -> repr_dataset.RepresentativeDataset: with self.session() as sess: new_repr_ds = repr_dataset.replace_tensors_by_numpy_ndarrays( - data_gen(), sess) + data_gen(), sess + ) # The resulting dataset should not contain any tf.Tensors. self.assertFalse(any(map(_contains_tensor, new_repr_ds))) @@ -122,13 +133,17 @@ def data_gen() -> repr_dataset.RepresentativeDataset: @test_util.deprecated_graph_mode_only def test_replace_tensors_by_numpy_ndarrays_is_noop_when_no_tensor(self): # Fill the representative dataset with np.ndarrays only. - repr_ds: repr_dataset.RepresentativeDataset = [{ - 'input_tensor': np.random.uniform(low=-1., high=1., size=(4, 3)), - } for _ in range(8)] + repr_ds: repr_dataset.RepresentativeDataset = [ + { + 'input_tensor': np.random.uniform(low=-1.0, high=1.0, size=(4, 3)), + } + for _ in range(8) + ] with self.session() as sess: new_repr_ds = repr_dataset.replace_tensors_by_numpy_ndarrays( - repr_ds, sess) + repr_ds, sess + ) # The resulting dataset should not contain any tf.Tensors. self.assertFalse(any(map(_contains_tensor, new_repr_ds))) @@ -138,24 +153,31 @@ def test_replace_tensors_by_numpy_ndarrays_is_noop_when_no_tensor(self): def test_replace_tensors_by_numpy_ndarrays_mixed_tensor_and_ndarray(self): num_tensors = 4 samples = [ - np.random.uniform(low=-1., high=1., size=(3, 3)).astype('f4') + np.random.uniform(low=-1.0, high=1.0, size=(3, 3)).astype('f4') for _ in range(num_tensors) ] - repr_ds: repr_dataset.RepresentativeDataset = [{ - 'tensor_key': ops.convert_to_tensor(sample), - } for sample in samples] + repr_ds: repr_dataset.RepresentativeDataset = [ + { + 'tensor_key': ops.convert_to_tensor(sample), + } + for sample in samples + ] # Extend the representative dataset with np.ndarrays. - repr_ds.extend([{ - 'tensor_key': np.random.uniform(low=-1., high=1., size=(3, 3)) - } for _ in range(4)]) + repr_ds.extend( + [ + {'tensor_key': np.random.uniform(low=-1.0, high=1.0, size=(3, 3))} + for _ in range(4) + ] + ) random.shuffle(repr_ds) with self.session() as sess: new_repr_ds = repr_dataset.replace_tensors_by_numpy_ndarrays( - repr_ds, sess) + repr_ds, sess + ) # The resulting dataset should not contain any tf.Tensors. 
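Representative datasets like the ones exercised in these tests can be supplied in either of the two forms accepted by `_validate_representative_dataset`: a single iterable of samples, or a mapping keyed by signature key. A small illustrative construction (the input key name is a placeholder):

import numpy as np

# Form 1: a single dataset, i.e. an iterable of {input_key: value} samples.
def representative_dataset():
  for _ in range(8):
    yield {'input_tensor': np.random.uniform(low=-1.0, high=1.0, size=(1, 4))}

# Form 2: a signature_def_key -> dataset mapping, required when quantizing
# more than one signature.
representative_dataset_map = {
    'serving_default': representative_dataset(),
}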
self.assertFalse(any(map(_contains_tensor, new_repr_ds))) @@ -163,9 +185,10 @@ def test_replace_tensors_by_numpy_ndarrays_mixed_tensor_and_ndarray(self): def test_get_num_samples_returns_num_samples_when_list(self): num_samples = 8 - repr_ds = [{ - 'input': np.random.uniform(low=-1., high=1., size=(1, 2)) - } for _ in range(num_samples)] + repr_ds = [ + {'input': np.random.uniform(low=-1.0, high=1.0, size=(1, 2))} + for _ in range(num_samples) + ] self.assertEqual(repr_dataset.get_num_samples(repr_ds), num_samples) @@ -174,7 +197,9 @@ def test_get_num_samples_returns_none_for_generator(self): def data_gen() -> repr_dataset.RepresentativeDataset: for _ in range(num_samples): - yield {'input_tensor': np.random.uniform(low=-1., high=1., size=(1, 4))} + yield { + 'input_tensor': np.random.uniform(low=-1.0, high=1.0, size=(1, 4)) + } repr_ds = data_gen() self.assertIsNone(repr_dataset.get_num_samples(repr_ds)) @@ -184,7 +209,6 @@ def data_gen() -> repr_dataset.RepresentativeDataset: self.assertLen(list(repr_ds), num_samples) def test_get_num_samples_returns_none_when_len_raises_error(self): - class LenRaisingError: """A test-only class that raises an error when len() is called. @@ -195,10 +219,36 @@ class LenRaisingError: def __len__(self): raise ValueError( - 'You cannot take the len() of instance of LenRaisingError.') + 'You cannot take the len() of instance of LenRaisingError.' + ) self.assertIsNone(repr_dataset.get_num_samples(LenRaisingError())) +class RepresentativeDatasetSaverTest(test.TestCase): + """Test cases for RepresentativeDatasetSaver.""" + + def test_save_raises_error(self): + saver = repr_dataset.RepresentativeDatasetSaver() + repr_ds = {'serving_default': []} + + with self.assertRaisesRegex( + NotImplementedError, 'Method "save" is not implemented.' + ): + saver.save(repr_ds) + + +class RepresentativeDatasetLoaderTest(test.TestCase): + """Test cases for RepresentativeDatasetLoader.""" + + def test_load_raises_error(self): + loader = repr_dataset.RepresentativeDatasetLoader() + + with self.assertRaisesRegex( + NotImplementedError, 'Method "load" is not implemented.' + ): + loader.load() + + if __name__ == '__main__': test.main() diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/save_model.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/save_model.py new file mode 100644 index 00000000000..03204ed5bc6 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/save_model.py @@ -0,0 +1,413 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Defines utilities involving SavedModel.""" + +from typing import Collection, Dict, Mapping, Optional, Sequence + +from absl import logging + +from tensorflow.core.framework import graph_pb2 +from tensorflow.core.framework import node_def_pb2 +from tensorflow.core.protobuf import meta_graph_pb2 +from tensorflow.python.client import session +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import importer +from tensorflow.python.framework import ops +from tensorflow.python.lib.io import file_io +from tensorflow.python.ops import variables +from tensorflow.python.saved_model import builder +from tensorflow.python.saved_model import constants as saved_model_constants +from tensorflow.python.saved_model import loader_impl as saved_model_loader +from tensorflow.python.saved_model import tag_constants +from tensorflow.python.types import core + +# Mapping of signature def key -> SignatureDef. +_SignatureDefMap = Mapping[str, meta_graph_pb2.SignatureDef] + + +def get_signatures_from_saved_model( + saved_model_path: str, + signature_keys: Optional[Sequence[str]] = None, + tags: Optional[Collection[str]] = None, +) -> Dict[str, meta_graph_pb2.SignatureDef]: + """Gets a map from signature keys to their SignatureDef. + + Args: + saved_model_path: Path to the saved model. + signature_keys: List of keys identifying SignatureDef to retrieve. If None, + retrieve all except the init signature. + tags: Set of tags identifying the MetaGraphDef within the SavedModel. + + Returns: + A map from signature_key to its SignatureDef. + """ + if tags is None: + tags = {tag_constants.SERVING} + + loader = saved_model_loader.SavedModelLoader(saved_model_path) + meta_graphdef = loader.get_meta_graph_def_from_tags(tags) + signatures = {} + for key, signature_def in meta_graphdef.signature_def.items(): + if key == saved_model_constants.INIT_OP_SIGNATURE_KEY: + continue + if signature_keys is not None and key not in signature_keys: + continue + signatures[key] = signature_def + + return signatures + + +def _restore_output_tensor_names( + graph_def: graph_pb2.GraphDef, +) -> graph_pb2.GraphDef: + """Restores the output tensor names of the converted model. + + During the conversion, the output tensor names of the original model are + embedded in the `tf_saved_model.index_path` attribute of the RetVal nodes and + might become the name of Retval nodes as well (with an index suffix if there + are multiple output tensors from one node). Since Retval nodes are not used in + SavedModel, this function removes them and restore the names to the actual + output tensors. + + Args: + graph_def: the converted GraphDef. + + Returns: + The GraphDef with Retval nodes removed and output tensor names restored. + """ + output_renaming_map = {} + with session.Session(graph=ops.Graph()): + importer.import_graph_def(graph_def, name='') + graph = ops.get_default_graph() + for op in graph.get_operations(): + if op.type == '_Retval': + expected_node_name = op.name + if op.get_attr('tf_saved_model.index_path') is not None: + index_path_name = op.get_attr('tf_saved_model.index_path')[0] + index_path_name = index_path_name.decode('utf-8').split(':')[0] + try: + # Only use the index_path name if it points to a Retval node. 
+            index_path_node = graph.get_operation_by_name(index_path_name)
+            if index_path_node.type == '_Retval':
+              expected_node_name = index_path_name
+          except KeyError:
+            pass
+        retval_input_node_name = op.inputs[0].op.name
+        output_renaming_map[retval_input_node_name] = expected_node_name
+
+  for node in reversed(graph_def.node):
+    if node.name in output_renaming_map:
+      node.name = output_renaming_map[node.name]
+    elif node.op == '_Retval':
+      graph_def.node.remove(node)
+    else:
+      # Update the inputs referring to the pre-renaming node.
+      for idx, input_name in enumerate(node.input):
+        if input_name in output_renaming_map:
+          node.input[idx] = output_renaming_map[input_name]
+      # Update the control inputs referring to the pre-renaming node.
+      updating_inputs = []
+      for input_name in reversed(node.input):
+        if input_name.startswith('^') and input_name[1:] in output_renaming_map:
+          updating_inputs.append(input_name[1:])
+          node.input.remove(input_name)
+      for updating_input in updating_inputs:
+        node.input.append('^' + output_renaming_map[updating_input])
+  return graph_def
+
+
+def _create_empty_output_dir(output_directory: str) -> None:
+  """Creates the `output_directory`.
+
+  If `output_directory` already exists, it recursively deletes all contents
+  inside the directory.
+
+  Also creates the parent & intermediate directories.
+
+  Args:
+    output_directory: Output directory.
+  """
+  if file_io.file_exists_v2(output_directory):
+    logging.info(
+        'Deleting existing directory for quantized model output: %s .',
+        output_directory,
+    )
+    file_io.delete_recursively_v2(output_directory)
+
+  file_io.recursive_create_dir_v2(output_directory)
+
+
+def _validate_signatures(
+    signature_def_map: _SignatureDefMap, exported_graph: ops.Graph
+) -> _SignatureDefMap:
+  """Validates if the tensor names in signatures are consistent with the graph.
+
+  This function checks if the input and output tensor names in the signatures
+  exist in the graph. The output tensor names might change during conversion;
+  we try to fix that with `_restore_output_tensor_names`. Besides, if there
+  are duplicated tensor names, they will be prefixed with the signature name.
+  However, if that doesn't work, the signatures can't be used with the
+  converted graph.
+
+  Args:
+    signature_def_map: the signatures to validate.
+    exported_graph: The PTQ-exported GraphDef.
+
+  Returns:
+    The signatures with tensor names prefixed with signature name if necessary.
+
+  Raises:
+    ValueError: Iff the signatures are not consistent with the graph.
+  """
+  for signature_key, signature_def in signature_def_map.items():
+    for tensor_info in signature_def.inputs.values():
+      try:
+        exported_graph.get_tensor_by_name(tensor_info.name)
+      except KeyError as exc:
+        try:
+          prefixed_name = signature_key + '_' + tensor_info.name
+          exported_graph.get_tensor_by_name(prefixed_name)
+          tensor_info.name = prefixed_name
+        except KeyError:
+          raise ValueError(
+              'Cannot find the input tensor with name %s in the graph.'
+              % tensor_info.name
+          ) from exc
+
+    for tensor_info in signature_def.outputs.values():
+      try:
+        exported_graph.get_tensor_by_name(tensor_info.name)
+      except KeyError as exc:
+        try:
+          prefixed_name = signature_key + '_' + tensor_info.name
+          exported_graph.get_tensor_by_name(prefixed_name)
+          tensor_info.name = prefixed_name
+        except KeyError:
+          raise ValueError(
+              'Cannot find the output tensor with name %s in the graph.'
+              % tensor_info.name
+          ) from exc
+
+  return signature_def_map
+
+
+def _find_op(
+    graph: ops.Graph, op_name: Optional[str]
+) -> Optional[ops.Operation]:
+  """Finds the operation with `op_name`.
+
+  Args:
+    graph: The graph to find from.
+    op_name: Name of the node.
+
+  Returns:
+    The operation that corresponds to `op_name`. Returns None iff op_name is an
+    empty string or None.
+
+  Raises:
+    ValueError: `op_name` is malformed.
+  """
+  if not op_name:
+    return None
+
+  init_op = graph.get_operation_by_name(op_name)
+  logging.debug('Op found in the graph: %s', op_name)
+
+  return init_op
+
+
+def _find_file_prefix_tensor(graph: ops.Graph) -> Optional[core.Tensor]:
+  """Finds the "file_prefix" tensor used for identifying the checkpoint path.
+
+  The file prefix tensor can be identified as the output of an `_Arg` node that
+  has the value "__tf_file_prefix" in its "tf_saved_model.index_path"
+  attribute. This attribute should have been set to the file prefix argument by
+  the `InsertRestoreOpPass` when creating the `RestoreV2Op` for the variables.
+
+  Args:
+    graph: The graph to find the file_prefix tensor from.
+
+  Returns:
+    The "file_prefix" tensor if it is found; None otherwise.
+  """
+  for op in graph.get_operations():
+    if op.type == '_Arg' and (
+        b'__tf_file_prefix' in op.get_attr('tf_saved_model.index_path')
+    ):
+      candidate_tensor = op.outputs[0]
+      return candidate_tensor
+
+  return None
+
+
+def _create_empty_variable(
+    node_def: node_def_pb2.NodeDef,
+) -> variables.Variable:
+  """Creates an empty `Variable`.
+
+  A variable with an unknown shape and an empty value is created.
+
+  Args:
+    node_def: Instance of `NodeDef` of the `VarHandleOp`.
+
+  Returns:
+    Empty `Variable` with only `shared_name` and `dtype` populated according to
+    `node_def`.
+  """
+  shared_name = str(node_def.attr['shared_name'].s, encoding='utf-8')
+  dtype: dtypes.DType = dtypes.as_dtype(node_def.attr['dtype'].type)
+
+  return variables.Variable(
+      [], trainable=False, name=shared_name, dtype=dtype, shape=None
+  )
+
+
+def _find_variables(
+    graph_def: graph_pb2.GraphDef,
+) -> Mapping[str, node_def_pb2.NodeDef]:
+  """Finds existing `VarHandleOp`s in the graph.
+
+  Args:
+    graph_def: `GraphDef` to find variables from.
+
+  Returns:
+    A mapping from the `shared_name` of each `VarHandleOp` to its `NodeDef`.
+  """
+  var_mapping = {}
+  for node in graph_def.node:
+    if node.op == 'VarHandleOp':
+      var_mapping[str(node.attr['shared_name'].s, encoding='utf-8')] = node
+
+  for func in graph_def.library.function:
+    for node in func.node_def:
+      if node.op == 'VarHandleOp':
+        var_mapping[str(node.attr['shared_name'].s, encoding='utf-8')] = node
+
+  return var_mapping
+
+
+def _save_function_alias(
+    saved_model_dir: str,
+    tags: Collection[str],
+    function_aliases: Mapping[str, str],
+) -> None:
+  """Saves the function alias to the SavedModel.
+
+  SavedModelBuilder (TF1 saved model saver) does not support saving function
+  aliases, so this function loads the SavedModel proto and adds the
+  `function_aliases` field.
+
+  Args:
+    saved_model_dir: Path to the saved model directory.
+    tags: A collection of tags to specify the meta graph.
+    function_aliases: Function name -> function alias mapping.
+ """ + loader = saved_model_loader.SavedModelLoader(saved_model_dir) + meta_graph_def = loader.get_meta_graph_def_from_tags(tags) + + for function_name, function_alias in function_aliases.items(): + meta_graph_def.meta_info_def.function_aliases[function_name] = ( + function_alias + ) + + saved_model_proto_serialized = loader.saved_model.SerializeToString() + + # TODO(b/266015731): Also update and set the SavedModel fingerprint. + path = file_io.join( + saved_model_dir, saved_model_constants.SAVED_MODEL_FILENAME_PB + ) + file_io.atomic_write_string_to_file(path, saved_model_proto_serialized) + + +def save_model_v1( + graph_def: graph_pb2.GraphDef, + output_dir: str, + signature_def_map: _SignatureDefMap, + tags: Collection[str], + init_op_name: Optional[str] = None, + restore_op_name: Optional[str] = None, + checkpoint_dir: Optional[str] = None, + variable_shared_names: Optional[Sequence[str]] = None, + function_aliases: Optional[Mapping[str, str]] = None, +) -> None: + """Saves the model. + + Saves the provided graph def as SavedModel. + Uses TF1 SavedModel semantics (i.e. no object graph). + + Args: + graph_def: Graph to save. + output_dir: Output directory for the SavedModel. + signature_def_map: Mapping of signature def key -> SignatureDef. + tags: Tags for the meta graph def. + init_op_name: Name of the node for initialization. + restore_op_name: Name of the node for restoration. + checkpoint_dir: Path to checkpoint file where variable values are saved. + variable_shared_names: Shared name of the variables in the model. + function_aliases: Function name -> function alias mapping. + + Raises: + ValueError iff the graph does not contain a valid signature. + """ + _create_empty_output_dir(output_dir) + v1_builder = builder.SavedModelBuilder(output_dir) + + graph_def = _restore_output_tensor_names(graph_def) + with session.Session(graph=ops.Graph()) as sess: + importer.import_graph_def(graph_def, name='') + + signature_def_map = _validate_signatures( + signature_def_map, ops.get_default_graph() + ) + + # `restore_op_name` is non-empty & non-None when variables should be + # restored before saving. + if restore_op_name: + var_mapping = _find_variables(graph_def) + logging.debug( + 'Shared names of the variables to be saved: %s', + str(list(var_mapping.keys())), + ) + + for shared_name in variable_shared_names: + var_node_def = var_mapping[shared_name] + + # Variables with unknown shape and empty value is created. This is + # just there to register a variable with `shared_name` to the resource + # manager and collections, so that the values in checkpoint is + # properly restored via `RestoreV2` op. Once restored, the value, + # dtype and shape will be properly populated. + _create_empty_variable(var_node_def) + + # Restores the variables by running the `RestoreV2` op. + # `v1_builder.save()` saves the restored variables to the variables/ + # directory in `output_dir`. 
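+      # Illustrative sketch only; the values below are hypothetical. With
+      # restore_op_name='restore_op_0' and checkpoint_dir='/tmp/ckpt/variables',
+      # the call below amounts to:
+      #   sess.run(sess.graph.get_operation_by_name('restore_op_0'),
+      #            feed_dict={_find_file_prefix_tensor(sess.graph):
+      #                       '/tmp/ckpt/variables'})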
+ sess.run( + _find_op(sess.graph, op_name=restore_op_name), + feed_dict={_find_file_prefix_tensor(sess.graph): checkpoint_dir}, + ) + + v1_builder.add_meta_graph_and_variables( + sess, + tags, + signature_def_map=signature_def_map, + main_op=_find_op(sess.graph, op_name=init_op_name), + ) + + v1_builder.save() + + if function_aliases: + _save_function_alias(output_dir, tags, function_aliases) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.proto b/tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.proto index 8ed6fec1f9c..ad221cdf3a7 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.proto +++ b/tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.proto @@ -23,6 +23,7 @@ option cc_enable_arenas = true; // Various techniques for model quantization are defined within this message // along with a field that specifies a method to be used for a particular // quantization request. +// NEXT ID: 3 message QuantizationMethod { // Quantization methods that are supported as a stable API. enum Method { @@ -45,6 +46,9 @@ message QuantizationMethod { // determined in the graph executions. The weights are quantized during // conversion. DYNAMIC_RANGE = 2; + + // Weight-only quantization. Only weights are quantized during conversion. + WEIGHT_ONLY = 3; } // Quantization method is either exprimental or non-experimental method. @@ -71,8 +75,10 @@ enum QuantizationPrecision { // Unit (either nodes or ops at this moment) wise quantization method for // mixed bit precision quantization. It contains the name of the unit, // the granularity of the unit, and the quantization method for each unit. +// NEXT ID: 6 message UnitWiseQuantizationPrecision { // Quantization unit granularity. + // NEXT ID: 3 enum UnitType { // This should never be used. Using this will generally result in an error. UNIT_UNSPECIFIED = 0; @@ -98,6 +104,7 @@ message UnitWiseQuantizationPrecision { // List of supported opsets to deploy the quantized model. // The quantized model contains different set of ops depending on the opset. +// NEXT ID: 4 enum OpSet { OP_SET_UNSPECIFIED = 0; // go/do-include-enum-unspecified // Uses TF ops that mimic quantization behavior. Used when the corresponding @@ -109,6 +116,15 @@ enum OpSet { UNIFORM_QUANTIZED = 3; } +// Configurations for variable freezing during quantization passes. +// NEXT ID: 2 +message FreezeAllVariables { + // Setting this to true freezes all variables to constants during + // quantization. Setting this to `false` is an experimental feature and does + // not have stability guarantees. + bool enabled = 1; +} + // Defines various options to specify and control the behavior of the quantizer. // It consists of // 1) Model-wise quantization configuration as a default configuration. If it is @@ -116,6 +132,7 @@ enum OpSet { // 2) A set of supported operations. // 3) Unit wise quantization precision. // 4) Target hardware name. +// NEXT ID: 9 message QuantizationOptions { // The default quantization configuration for the model. If the below // unit-wise configuration does not exist, we use this default quantization @@ -123,9 +140,11 @@ message QuantizationOptions { // exists, this default one will become the quantization configuration for // units that are not specified in unit-wise configurations. QuantizationMethod quantization_method = 1; - OpSet op_set = 2; + + OpSet op_set = 2; // If not specified, it defaults to `TF`. 
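+  // Illustrative textproto sketch (example values only, not a recommendation):
+  //   op_set: XLA
+  //   min_num_elements_for_weights: 1024
+  // `op_set: XLA` selects the XLA lowering path exercised by the passes in
+  // quantize_passes.cc below.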
   QuantizationPrecision quantization_precision = 3;
+
   // Quantization precision for each unit. Units can become either
   // nodes or ops, and the mixture of those different units are allowed.
   // If there are conflicts or ambiguity in this unit-wise precision, our
@@ -136,4 +155,21 @@ message QuantizationOptions {
   // supported for Post-training Dynamic Range Quantization. By default, it is
   // set to 1024. To disable this, set the value to -1 explicitly.
   int64 min_num_elements_for_weights = 5;
+
+  // When set to `true`, freezes all variables in the model into constants.
+  // When set to `false`, the model's large constants are converted to variables.
+  // Setting this to `false` is an experimental feature and quantization may
+  // fail. To quantize models larger than 2 GiB, this should be set to `false`.
+  // If not set, it defaults to `true`.
+  FreezeAllVariables freeze_all_variables = 6;
+
+  // Enables channel-wise quantization. By default, channel-wise quantization is
+  // not applied regardless of the op support. Currently, it is supported for
+  // the Uniform Quantized opset only.
+  bool enable_per_channel_quantization = 7;
+
+  // Enables two inputs of an operation to be both tensors.
+  // Currently supports MatMul and BatchMatMul ops for XLA.
+  // TODO(b/263528090): Check the condition when this feature is beneficial.
+  bool enable_two_input_tensors = 8;
 }
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc
index 0cbd943990b..d40d8632d87 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.cc
@@ -27,6 +27,7 @@ limitations under the License.
 #include "mlir/Transforms/Passes.h"  // from @llvm-project
 #include "tensorflow/cc/saved_model/loader.h"
 #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h"
+#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
@@ -47,63 +48,90 @@ void AddQuantizeQatPasses(mlir::PassManager &pm,
                           const QuantizationOptions &quantization_options) {
   pm.addNestedPass(
       mlir::quant::CreateConvertFakeQuantToQdqPass());
-  pm.addNestedPass(
-      mlir::TF::CreateUnrollBatchMatMulPassPass());
-  // TODO(b/229995333): Add PrepareLiftingPass for QAT. In QAT, AffineOps are
-  // connected to FakeQuantOp instead of the ConstOp so need to add separate
-  // pattern for FakeQuantOp.
- // pm.addNestedPass(mlir::quant::CreatePrepareLiftingPass()); + // TODO(b/260031290): Set unfold_batchmatmul = false for ODML support + if (quantization_options.op_set() == OpSet::UNIFORM_QUANTIZED) { + pm.addNestedPass( + mlir::TF::CreateUnrollBatchMatMulPassPass()); + } pm.addPass(mlir::TF::CreateTFShapeInferencePass()); + pm.addNestedPass(mlir::quant::CreatePrepareLiftingPass()); pm.addPass(mlir::quant::CreateLiftQuantizableSpotsAsFunctionsPass( - quantization_options.op_set())); + quantization_options.op_set(), + quantization_options.enable_two_input_tensors())); pm.addPass(mlir::quant::CreateInsertQuantizedFunctionsPass( - mlir::quant::QuantizationMethod::kQuantizationAwareTraining, + quantization_options.quantization_method().experimental_method(), quantization_options.op_set())); + // TODO(b/260677670): Pass quantization options as pass's inputs where + // applicable pm.addPass(mlir::quant::CreateQuantizeCompositeFunctionsPass( - mlir::quant::QuantizationMethod::kQuantizationAwareTraining, - quantization_options.op_set())); + quantization_options.quantization_method().experimental_method(), + quantization_options.op_set(), + quantization_options.enable_per_channel_quantization(), + quantization_options.min_num_elements_for_weights())); pm.addPass(mlir::createSymbolDCEPass()); pm.addPass(mlir::TF::CreateTFShapeInferencePass()); - // For XLA opset, the graph is inlined to take benefit of constant folding - // and the TF Conv/Matmul ops with cast-hack are converted to XLA ops. - if (quantization_options.op_set() == OpSet::XLA) { + // TODO(b/264637396): Deprecate TF opset + if (quantization_options.op_set() != OpSet::TF) { pm.addPass(mlir::createInlinerPass()); pm.addPass(mlir::TF::CreateTFShapeInferencePass()); pm.addNestedPass(mlir::createCanonicalizerPass()); - pm.addNestedPass( - mlir::quant::CreateReplaceCastHacksWithTFXLAOpsPass()); + if (quantization_options.op_set() == OpSet::XLA) { + pm.addNestedPass( + mlir::quant::CreateReplaceCastHacksWithTFXLAOpsPass()); + } pm.addNestedPass(mlir::createCSEPass()); } - pm.addNestedPass(mlir::quant::CreateOptimizePass()); } void AddQuantizePtqDynamicRangePasses( mlir::PassManager &pm, const QuantizationOptions &quantization_options) { + // TODO(b/260031290): Set unfold_batchmatmul = false for ODML support pm.addNestedPass( mlir::TF::CreateUnrollBatchMatMulPassPass()); + pm.addPass(mlir::TF::CreateTFShapeInferencePass()); pm.addNestedPass(mlir::quant::CreatePrepareLiftingPass()); pm.addPass(mlir::quant::CreateLiftQuantizableSpotsAsFunctionsDRQPass( quantization_options.min_num_elements_for_weights())); pm.addPass(mlir::quant::CreateInsertQuantizedFunctionsPass( - mlir::quant::QuantizationMethod::kDynamicRangeQuantization, + quantization_options.quantization_method().experimental_method(), quantization_options.op_set())); pm.addPass(mlir::quant::CreateQuantizeCompositeFunctionsPass( - mlir::quant::QuantizationMethod::kDynamicRangeQuantization, - quantization_options.op_set())); + quantization_options.quantization_method().experimental_method(), + quantization_options.op_set(), + quantization_options.enable_per_channel_quantization(), + quantization_options.min_num_elements_for_weights())); pm.addPass(mlir::createSymbolDCEPass()); pm.addPass(mlir::TF::CreateTFShapeInferencePass()); + + // TODO(b/264637396): Deprecate TF opset + if (quantization_options.op_set() != OpSet::TF) { + pm.addPass(mlir::createInlinerPass()); + pm.addPass(mlir::TF::CreateTFShapeInferencePass()); + pm.addNestedPass(mlir::createCanonicalizerPass()); + if 
(quantization_options.op_set() == OpSet::XLA) { + pm.addNestedPass( + mlir::quant::CreateReplaceCastHacksWithTFXLAOpsPass()); + } + pm.addNestedPass(mlir::createCSEPass()); + } + + pm.addNestedPass(mlir::quant::CreateOptimizePass()); } void AddQuantizePtqPreCalibrationPasses( mlir::PassManager &pm, const QuantizationOptions &quantization_options) { - pm.addNestedPass( - mlir::TF::CreateUnrollBatchMatMulPassPass()); - pm.addNestedPass(mlir::quant::CreatePrepareLiftingPass()); + // TODO(b/260031290): Set unfold_batchmatmul = false for ODML support + if (quantization_options.op_set() == OpSet::UNIFORM_QUANTIZED) { + pm.addNestedPass( + mlir::TF::CreateUnrollBatchMatMulPassPass()); + } pm.addPass(mlir::TF::CreateTFShapeInferencePass()); + pm.addNestedPass(mlir::quant::CreatePrepareLiftingPass()); pm.addPass(mlir::quant::CreateLiftQuantizableSpotsAsFunctionsPass( - quantization_options.op_set())); + quantization_options.op_set(), + quantization_options.enable_two_input_tensors())); pm.addNestedPass( mlir::quant::CreateInsertCustomAggregationOpsPass()); pm.addPass(mlir::quant::CreateIssueIDsOfCustomAggregationOpsPass()); @@ -112,28 +140,31 @@ void AddQuantizePtqPreCalibrationPasses( void AddQuantizePtqPostCalibrationPasses( mlir::PassManager &pm, const QuantizationOptions &quantization_options) { pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::TF::CreateTFShapeInferencePass()); pm.addNestedPass( mlir::quant::CreateConvertCustomAggregationOpToQuantStatsPass()); pm.addPass(mlir::quant::CreateInsertQuantizedFunctionsPass( - mlir::quant::QuantizationMethod::kPostTrainingQuantization, + quantization_options.quantization_method().experimental_method(), quantization_options.op_set())); pm.addPass(mlir::quant::CreateQuantizeCompositeFunctionsPass( - mlir::quant::QuantizationMethod::kPostTrainingQuantization, - quantization_options.op_set())); + quantization_options.quantization_method().experimental_method(), + quantization_options.op_set(), + quantization_options.enable_per_channel_quantization(), + quantization_options.min_num_elements_for_weights())); pm.addPass(mlir::createSymbolDCEPass()); pm.addPass(mlir::TF::CreateTFShapeInferencePass()); - // For XLA opset, the graph is inlined to take benefit of constant folding - // and the TF Conv/Matmul ops with cast-hack are converted to XLA ops. - if (quantization_options.op_set() == OpSet::XLA) { + // TODO(b/264637396): Deprecate TF opset + if (quantization_options.op_set() != OpSet::TF) { pm.addPass(mlir::createInlinerPass()); pm.addPass(mlir::TF::CreateTFShapeInferencePass()); pm.addNestedPass(mlir::createCanonicalizerPass()); - pm.addNestedPass( - mlir::quant::CreateReplaceCastHacksWithTFXLAOpsPass()); + if (quantization_options.op_set() == OpSet::XLA) { + pm.addNestedPass( + mlir::quant::CreateReplaceCastHacksWithTFXLAOpsPass()); + } pm.addNestedPass(mlir::createCSEPass()); } - pm.addNestedPass(mlir::quant::CreateOptimizePass()); } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.cc b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.cc index 35226cb93b7..e77091aba37 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.cc @@ -14,39 +14,64 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h" -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/Optional.h" +#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project -#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project -#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project -#include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/public/session.h" namespace tensorflow { namespace quantization { +namespace { + +absl::Status RunPassesOnModuleOp(const absl::string_view mlir_dump_file_name, + mlir::PassManager& pass_manager, + mlir::ModuleOp module_op) { + mlir::StatusScopedDiagnosticHandler statusHandler(module_op.getContext(), + /*propagate=*/true); + + const absl::StatusOr> dump_file = + MaybeEnableIrPrinting(pass_manager, mlir_dump_file_name); + if (!dump_file.ok()) { + return dump_file.status(); + } + + if (failed(pass_manager.run(module_op))) { + return tsl::ToAbslStatus(statusHandler.ConsumeStatus()); + } + + return absl::OkStatus(); +} -Status PreprocessAndFreezeGraph(mlir::ModuleOp module, - mlir::MLIRContext* context, - llvm::Optional session) { +} // namespace + +absl::Status PreprocessAndFreezeGraph( + const absl::string_view mlir_dump_file_prefix, const bool is_inliner_run, + const absl::flat_hash_set& noinline_functions, + mlir::ModuleOp module_op, mlir::MLIRContext* context, + llvm::Optional session) { mlir::PassManager pm_before_freezing_variables(context); - mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext(), + mlir::StatusScopedDiagnosticHandler statusHandler(module_op.getContext(), /*propagate=*/true); mlir::TF::StandardPipelineOptions standard_pipeline_options; @@ -61,22 +86,33 @@ Status PreprocessAndFreezeGraph(mlir::ModuleOp module, mlir::PassManager 
pm_after_freezing_variables(context); pm_after_freezing_variables.addPass(mlir::TF::CreateTFShapeInferencePass()); pm_after_freezing_variables.addPass(mlir::createCanonicalizerPass()); - pm_after_freezing_variables.addPass(mlir::createInlinerPass()); - if (failed(pm_before_freezing_variables.run(module))) { - return statusHandler.ConsumeStatus(); + // Makes certain functions immune to the `InlinerPass`. Used to preserve + // aliased functions. + pm_after_freezing_variables.addNestedPass( + mlir::quant::CreateMarkFunctionsNoinlinePass(std::vector( + noinline_functions.begin(), noinline_functions.end()))); + if (is_inliner_run) { + pm_after_freezing_variables.addPass(mlir::createInlinerPass()); } - if (session.has_value() && failed(mlir::tf_saved_model::FreezeVariables( - module, session.getValue()))) { - return statusHandler.ConsumeStatus(); + if (const auto pre_variable_freezing_status = RunPassesOnModuleOp( + /*mlir_dump_file_name=*/absl::StrCat( + mlir_dump_file_prefix, "_preprocess_pre_variable_freezing"), + pm_before_freezing_variables, module_op); + !pre_variable_freezing_status.ok()) { + return pre_variable_freezing_status; } - if (failed(pm_after_freezing_variables.run(module))) { - return statusHandler.ConsumeStatus(); + if (session.has_value() && failed(mlir::tf_saved_model::FreezeVariables( + module_op, session.value()))) { + return tsl::ToAbslStatus(statusHandler.ConsumeStatus()); } - return OkStatus(); + return RunPassesOnModuleOp( + /*mlir_dump_file_name=*/absl::StrCat( + mlir_dump_file_prefix, "_preprocess_post_variable_freezing"), + pm_after_freezing_variables, module_op); } } // namespace quantization diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h index c89278e9abd..6914f4ade18 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h @@ -15,17 +15,46 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_QUANTIZE_PREPROCESS_H_ #define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_QUANTIZE_PREPROCESS_H_ +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/Optional.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" -#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/public/session.h" namespace tensorflow { namespace quantization { -Status PreprocessAndFreezeGraph(mlir::ModuleOp module, - mlir::MLIRContext* context, - llvm::Optional session); +// Default MLIR dump file prefix for TensorFlow quantization passes. +inline constexpr absl::string_view kDefaultTfQuantMlirDumpFilePrefix = + "tf_quant"; + +// Preprocesses the `module_op` for quantization. The preprocess steps include +// freezing the variables in the graph into constants. `is_inliner_run` +// determines whether the `InlinerPass` should be run after unfreezing. +// +// `mlir_dump_file_prefix` is primarily used for debugging and does not affect +// the preprocessing behavior. Instructions for producing MLIR dump files are in +// the comments of `tensorflow::quantization::MaybeEnableIrPrinting` function. 
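+//
+// Illustrative call (a sketch; the module/session setup is assumed, not shown):
+//   mlir::MLIRContext context;
+//   mlir::OwningOpRef<mlir::ModuleOp> module_ref = /* imported SavedModel */;
+//   absl::Status status = PreprocessAndFreezeGraph(
+//       /*mlir_dump_file_prefix=*/kDefaultTfQuantMlirDumpFilePrefix,
+//       /*is_inliner_run=*/true, /*noinline_functions=*/{}, *module_ref,
+//       &context, /*session=*/llvm::None);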
+absl::Status PreprocessAndFreezeGraph( + absl::string_view mlir_dump_file_prefix, bool is_inliner_run, + const absl::flat_hash_set& noinline_functions, + mlir::ModuleOp module_op, mlir::MLIRContext* context, + llvm::Optional session); + +// Overload of `PreprocessAndFreezeGraph` that uses the default MLIR dump file +// prefix. +inline absl::Status PreprocessAndFreezeGraph(mlir::ModuleOp module_op, + mlir::MLIRContext* context, + llvm::Optional session) { + return PreprocessAndFreezeGraph( + /*mlir_dump_file_prefix=*/kDefaultTfQuantMlirDumpFilePrefix, + /*is_inliner_run=*/true, /*noinline_functions=*/{}, module_op, context, + session); +} } // namespace quantization } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/tests/BUILD index 9a11a6fc4d2..2587a70d9cf 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/BUILD @@ -1,6 +1,7 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ "//tensorflow/compiler/mlir/quantization/tensorflow:internal_visibility_allowlist_package", ], @@ -23,6 +24,8 @@ glob_lit_tests( data = [":test_utilities"], driver = "@llvm-project//mlir:run_lit.sh", size_override = { + "insert_quantized_functions.mlir": "medium", + "insert_quantized_functions_drq.mlir": "medium", "replace_cast_hacks_with_tf_xla_ops_large_constants.mlir": "medium", }, tags_override = { diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/duplicate_shape_determining_constants.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/duplicate_shape_determining_constants.mlir new file mode 100644 index 00000000000..f01274b5ff4 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/duplicate_shape_determining_constants.mlir @@ -0,0 +1,223 @@ +// RUN: tf-quant-opt %s -split-input-file -verify-diagnostics \ +// RUN: -quant-duplicate-shape-determining-constants | FileCheck %s + +// CHECK-LABEL: @duplicate_const_for_shape_determining_operand_at_idx_1 +// CHECK-SAME: (%[[ARG_0:.*]]: tensor) +func.func private @duplicate_const_for_shape_determining_operand_at_idx_1(%arg0: tensor) -> tensor { + %cst = "tf.Const"() {device = "", value = dense<2> : tensor} : () -> tensor + // idx 1 should be a compile time constant + %0 = "tf.ExpandDims"(%arg0, %cst) {device = ""} : (tensor, tensor) -> tensor + %1 = "tf.AddV2"(%cst, %cst) {device = ""} : (tensor, tensor) -> tensor + + return %0 : tensor +} +// Check that the constant is cloned with same value. +// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() +// CHECK-SAME: value = dense<2> : tensor +// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() +// CHECK-SAME: value = dense<2> : tensor + +// Check that the constants used for tf.ExpandDims and tf.AddV2 are different. 
+// CHECK: %[[EXPAND_DIMS:.*]] = "tf.ExpandDims"(%[[ARG_0]], %[[CST_1]]) +// CHECK: %[[ADDV2:.*]] = "tf.AddV2"(%[[CST_0]], %[[CST_0]]) + +// ----- + +// CHECK-LABEL: @duplicate_const_for_shape_determining_operand_at_idx_2 +// CHECK-SAME: (%[[ARG_0:.*]]: tensor<16x4xf32>, %[[ARG_1:.*]]: tensor<16xi32>) +func.func private @duplicate_const_for_shape_determining_operand_at_idx_2(%arg0: tensor<16x4xf32>, %arg1: tensor<16xi32>) -> tensor<16xf32> { + %cst = "tf.Const"() {device = "", value = dense<[1]> : tensor<1xi32>} : () -> tensor<1xi32> + // idx 2 should be a compile time constant + %0 = "tf.GatherV2"(%arg0, %arg1, %cst) {batch_dims = 1: i64} : (tensor<16x4xf32>, tensor<16xi32>, tensor<1xi32>) -> tensor<16xf32> + + // Just to introduce an extra use for %cst. + %1 = "tf.AddV2"(%cst, %cst) {device = ""} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + + return %0 : tensor<16xf32> +} +// Check that the constant is cloned with same value. +// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() +// CHECK-SAME: value = dense<1> : tensor<1xi32> +// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() +// CHECK-SAME: value = dense<1> : tensor<1xi32> + +// Check that the constants used for tf.GatherV2 and tf.AddV2 are different. +// CHECK: %[[GATHER_V2:.*]] = "tf.GatherV2"(%[[ARG_0]], %[[ARG_1]], %[[CST_1]]) +// CHECK: %[[ADDV2:.*]] = "tf.AddV2"(%[[CST_0]], %[[CST_0]]) + +// ----- + +// CHECK-LABEL: @duplicate_const_for_shape_determining_operand_with_variadic_operand +// CHECK-SAME: %[[ARG_0:.*]]: tensor<16x1xf32> +func.func private @duplicate_const_for_shape_determining_operand_with_variadic_operand(%arg0: tensor<16x1xf32>) -> tensor<16x4xf32> { + %axis = "tf.Const"() {device = "", value = dense<1> : tensor} : () -> tensor + // tf.ConcatV2 accepts a variadic operand. The last operand should be compile + // time constant. + %0 = "tf.ConcatV2"(%arg0, %arg0, %arg0, %arg0, %axis) : (tensor<16x1xf32>, tensor<16x1xf32>, tensor<16x1xf32>, tensor<16x1xf32>, tensor) -> tensor<16x4xf32> + + // Just to introduce an extra use for %cst. + %1 = "tf.AddV2"(%axis, %axis) {device = ""} : (tensor, tensor) -> tensor + + return %0 : tensor<16x4xf32> +} +// Check that the constant is cloned with same value. +// The duplicated constant is the last index of the ConcatV2 op (which +// accepts a variadic arg). +// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() +// CHECK-SAME: value = dense<1> : tensor +// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() +// CHECK-SAME: value = dense<1> : tensor + +// Check that the constants used for tf.ConcatV2 and tf.AddV2 are different. 
+// CHECK: %[[CONCAT_V2:.*]] = "tf.ConcatV2"(%[[ARG_0]], %[[ARG_0]], %[[ARG_0]], %[[ARG_0]], %[[CST_1]]) +// CHECK: %[[ADDV2:.*]] = "tf.AddV2"(%[[CST_0]], %[[CST_0]]) + +// ----- + +// CHECK-LABEL: @duplicate_const_for_multiple_shape_determining_operands +// CHECK-SAME: %[[ARG_0:.*]]: tensor<8x4x16x16x16xf32> +// CHECK-SAME: %[[ARG_1:.*]]: tensor<4x3x3x16x16xf32> +func.func private @duplicate_const_for_multiple_shape_determining_operands( + %arg0: tensor<8x4x16x16x16xf32>, %arg1: tensor<4x3x3x16x16xf32>) -> tensor<8x4x14x14x16xf32> { + %strides = "tf.Const"() {value = dense<[3, 1, 1]> : tensor<3xi32>} : () -> tensor<3xi32> + %padding = "tf.Const"() {value = dense<0> : tensor<3x2xi32>} : () -> tensor<3x2xi32> + %lhs_dilation = "tf.Const"() {value = dense<[4, 1, 1]> : tensor<3xi32>} : () -> tensor<3xi32> + %rhs_dilation = "tf.Const"() {value = dense<1> : tensor<3xi32>} : () -> tensor<3xi32> + %feature_group_count = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + + // tf.XlaConvV2's 2, 3, 4, 5, 6 indices should be compile-time constants. + %0 = "tf.XlaConvV2"(%arg0, %arg1, %strides, %padding, %lhs_dilation, %rhs_dilation, %feature_group_count) { + batch_group_count = 1 : i64, + dimension_numbers = "\18\03 \042\03\00\01\02@\04P\04Z\03\01\02\03b\03\01\02\03", + precision_config = ""} : (tensor<8x4x16x16x16xf32>, tensor<4x3x3x16x16xf32>, tensor<3xi32>, + tensor<3x2xi32>, tensor<3xi32>, tensor<3xi32>, tensor) -> tensor<8x4x14x14x16xf32> + + // Just to introduce an extra use for %cst. + %1 = "tf.AddV2"(%feature_group_count, %feature_group_count) {device = ""} : (tensor, tensor) -> tensor + %2 = "tf.AddV2"(%lhs_dilation, %lhs_dilation) {device = ""} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> + %3 = "tf.AddV2"(%rhs_dilation, %rhs_dilation) {device = ""} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> + %4 = "tf.AddV2"(%padding, %padding) {device = ""} : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32> + %5 = "tf.AddV2"(%strides, %strides) {device = ""} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> + + return %0 : tensor<8x4x14x14x16xf32> +} + +// Check that the constants that are input to XlaConvV2's 3rd, 4th, 5th, 6th +// and 7th arguments are cloned with same value. +// CHECK-DAG: %[[STRIDES:.*]] = "tf.Const"() +// CHECK-SAME: value = dense<[3, 1, 1]> : tensor<3xi32> +// CHECK-DAG: %[[STRIDES_COPY:.*]] = "tf.Const"() +// CHECK-SAME: value = dense<[3, 1, 1]> : tensor<3xi32> +// CHECK-DAG: %[[PADDING:.*]] = "tf.Const"() +// CHECK-SAME: value = dense<0> : tensor<3x2xi32> +// CHECK-DAG: %[[PADDING_COPY:.*]] = "tf.Const"() +// CHECK-SAME: value = dense<0> : tensor<3x2xi32> +// CHECK-DAG: %[[LHS_DILATION:.*]] = "tf.Const"() +// CHECK-SAME: value = dense<[4, 1, 1]> : tensor<3xi32> +// CHECK-DAG: %[[LHS_DILATION_COPY:.*]] = "tf.Const"() +// CHECK-SAME: value = dense<[4, 1, 1]> : tensor<3xi32> +// CHECK-DAG: %[[RHS_DILATION:.*]] = "tf.Const"() +// CHECK-SAME: value = dense<1> : tensor<3xi32> +// CHECK-DAG: %[[RHS_DILATION_COPY:.*]] = "tf.Const"() +// CHECK-SAME: value = dense<1> : tensor<3xi32> +// CHECK-DAG: %[[FEATURE_GROUP_COUNT:.*]] = "tf.Const"() +// CHECK-SAME: value = dense<1> : tensor +// CHECK-DAG: %[[FEATURE_GROUP_COUNT_COPY:.*]] = "tf.Const"() +// CHECK-SAME: value = dense<1> : tensor + +// Check that the constants that are input to XlaConvV2's 3rd and 4th +// arguments are not duplicated. +// CHECK-NOT: "tf.Const"() + +// Check that the constants used for tf.XlaConvV2 and tf.AddV2s are different. 
+// CHECK: %[[GATHER_V2:.*]] = "tf.XlaConvV2"(%[[ARG_0]], %[[ARG_1]], %[[STRIDES_COPY]], %[[PADDING_COPY]], %[[LHS_DILATION_COPY]], %[[RHS_DILATION_COPY]], %[[FEATURE_GROUP_COUNT_COPY]]) + +// CHECK: %[[ADDV2_2:.*]] = "tf.AddV2"(%[[FEATURE_GROUP_COUNT]], %[[FEATURE_GROUP_COUNT]]) +// CHECK: %[[ADDV2_0:.*]] = "tf.AddV2"(%[[LHS_DILATION]], %[[LHS_DILATION]]) +// CHECK: %[[ADDV2_1:.*]] = "tf.AddV2"(%[[RHS_DILATION]], %[[RHS_DILATION]]) + +// ----- + +// CHECK-LABEL: @stop_recursion_when_arg_is_reached +func.func private @stop_recursion_when_arg_is_reached(%arg0: tensor<1x2x3xf32>, %arg1: tensor) -> tensor { +// The pass wants to duplicate constants for TF::MeanOp's operand idx 1, but +// it can't proceed since it is a function argument. + +// expected-warning @+1 {{Operand idx (zero-based): 1 does not have a defining op and cannot be duplicated}} + %0 = "tf.Mean"(%arg0, %arg1) {device = ""} : (tensor<1x2x3xf32>, tensor) -> tensor + + return %0: tensor +} + +// ----- + +// CHECK-LABEL: @constant_with_single_use_not_duplicated +func.func private @constant_with_single_use_not_duplicated(%arg0: tensor<1x2x3xf32>) -> tensor<1x3xf32> { + %cst = "tf.Const"() {device = "", value = dense<0> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {device = "", value = dense<1> : tensor} : () -> tensor + %0 = "tf.AddV2"(%cst, %cst_0) {device = ""} : (tensor, tensor) -> tensor + %1 = "tf.Max"(%arg0, %0) {device = ""} : (tensor<1x2x3xf32>, tensor) -> tensor<1x3xf32> + + return %1: tensor<1x3xf32> +} +// CHECK-DAG: %[[CST:.*]] = "tf.Const" +// CHECK-SAME: dense<0> +// CHECK-DAG: %[[CST_0:.*]] = "tf.Const" +// CHECK-SAME: dense<1> +// Check that there are no extra "tf.Const"s existing in this function. +// CHECK-NOT: "tf.Const" + +// Check that the usages of %[[CST]] and %[[CST_0]] are untouched. +// CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[CST]], %[[CST_0]]) +// CHECK: "tf.Max"({{.*}}, %[[ADD]]) + +// ----- + +// CHECK-LABEL: @recursively_duplicate_constants +func.func private @recursively_duplicate_constants(%arg0: tensor<1x2x3xf32>) -> tensor<1x3xf32> { + %cst = "tf.Const"() {device = "", value = dense<0> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {device = "", value = dense<1> : tensor} : () -> tensor + %0 = "tf.AddV2"(%cst, %cst_0) {device = ""} : (tensor, tensor) -> tensor + %1 = "tf.Max"(%arg0, %0) {device = ""} : (tensor<1x2x3xf32>, tensor) -> tensor<1x3xf32> + + // Just to introduce extra usages for %cst and %cst_0. + %2 = "tf.Mul"(%cst, %cst_0) {device = ""} : (tensor, tensor) -> tensor + + return %1: tensor<1x3xf32> +} +// Check that both constants are duplicated, which are used to transitively +// determine the shape of the result of `tf.Max`. +// CHECK-DAG: %[[CST:.*]] = "tf.Const" +// CHECK-SAME: dense<0> +// CHECK-DAG: %[[CST_0:.*]] = "tf.Const" +// CHECK-SAME: dense<0> +// CHECK-DAG: %[[CST_1:.*]] = "tf.Const" +// CHECK-SAME: dense<1> +// CHECK-DAG: %[[CST_2:.*]] = "tf.Const" +// CHECK-SAME: dense<1> + +// ----- + +// CHECK-LABEL: @early_stop_at_shape_op +func.func private @early_stop_at_shape_op() -> tensor<1x3xi32> { + %cst = "tf.Const"() {device = "", value = dense<1.0> : tensor<1x3xf32>} : () -> tensor<1x3xf32> + %cst_0 = "tf.Const"() {device = "", value = dense<2> : tensor} : () -> tensor + %1 = "tf.Shape"(%cst) : (tensor<1x3xf32>) -> tensor<2xi32> + // Operand index 0 ($dims) should be a compile-time constant. + %2 = "tf.Fill"(%1, %cst_0) {device = ""} : (tensor<2xi32>, tensor) -> tensor<1x3xi32> + + // Just to introduce extra usages for %cst. 
+ %3 = "tf.Mul"(%cst, %cst) {device = ""} : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + + return %2: tensor<1x3xi32> +} +// The output of tf.Shape is considered a compile-time constant, so the +// constant leading to tf.Shape (which transitively becomes an input to the +// first arg of tf.Fill) is not duplicated. + +// CHECK-DAG: %[[CST:.*]] = "tf.Const" +// CHECK-SAME: dense<1.000000e+00> : tensor<1x3xf32> +// CHECK-DAG: %[[CST_0:.*]] = "tf.Const" +// CHECK-SAME: dense<2> : tensor +// CHECK: %[[SHAPE:.*]] = "tf.Shape"(%[[CST]]) +// CHECK: %[[FILL:.*]] = "tf.Fill"(%[[SHAPE]], %[[CST_0]]) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/fake_quant_e2e_xla.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/fake_quant_e2e_xla.mlir index 325fe2bd355..5d9801e2085 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/fake_quant_e2e_xla.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/fake_quant_e2e_xla.mlir @@ -18,30 +18,35 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // CHECK-LABEL: func @conv_with_multiple_uses // CHECK: %[[div:.*]] = "tf.Div"(%arg0 // CHECK: %[[add:.*]] = "tf.AddV2"(%[[div]] -// CHECK: %[[floor:.*]] = "tf.Floor"(%[[add]] -// CHECK: %[[clip:.*]] = "tf.ClipByValue"(%[[floor]] -// CHECK: %[[quant:.*]] = "tf.Cast"(%[[clip]]) : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xi8> +// CHECK: %[[maximum:.*]] = "tf.Maximum"(%[[add]] +// CHECK: %[[minimum:.*]] = "tf.Minimum"(%[[maximum]] +// CHECK: %[[round:.*]] = "tf.Round"(%[[minimum]] +// CHECK: %[[quant:.*]] = "tf.Cast"(%[[round]]) : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xi8> // CHECK: %[[pad:.*]] = "tf.PadV2"(%[[quant]] // CHECK: %[[xlaconv:.*]] = "tf.XlaConvV2"(%[[pad]] // CHECK: %[[sub:.*]] = "tf.Sub"(%[[xlaconv]] // CHECK: %[[cast:.*]] = "tf.Cast"(%[[sub]]) {Truncate = false} : (tensor<1x3x2x2xi32>) -> tensor<1x3x2x2xf32> // CHECK: %[[dequant1:.*]] = "tf.Mul"(%[[cast]] // CHECK: %[[relu:.*]] = "tf.Relu"(%[[dequant1]] +// CHECK: %[[clamped:.*]] = "tf.Minimum"(%[[relu]] // CHECK: %[[rescale1:.*]] = "tf.Mul"(%[[cast]] // CHECK: %[[add2:.*]] = "tf.AddV2"(%[[rescale1]] -// CHECK: %[[floor2:.*]] = "tf.Floor"(%[[add2]] -// CHECK: %[[clip2:.*]] = "tf.ClipByValue"(%[[floor2]] -// CHECK: %[[quant2:.*]] = "tf.Cast"(%[[clip2]]) {Truncate = false} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xi8> +// CHECK: %[[maximum2:.*]] = "tf.Maximum"(%[[add2]] +// CHECK: %[[minimum2:.*]] = "tf.Minimum"(%[[maximum2]] +// CHECK: %[[round2:.*]] = "tf.Round"(%[[minimum2]] +// CHECK: %[[quant2:.*]] = "tf.Cast"(%[[round2]]) {Truncate = false} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xi8> // CHECK: %[[pad2:.*]] = "tf.PadV2"(%[[quant2]] // CHECK: %[[xlaconv2:.*]] = "tf.XlaConvV2"(%[[pad2]] // CHECK: %[[sub2:.*]] = "tf.Sub"(%[[xlaconv2]] // CHECK: %[[cast2:.*]] = "tf.Cast"(%[[sub2]]) {Truncate = false} : (tensor<1x3x2x2xi32>) -> tensor<1x3x2x2xf32> // CHECK: %[[rescale2:.*]] = "tf.Mul"(%[[cast2]] +// CHECK: %[[rescale2_maxclamped:.*]] = "tf.Maximum"(%[[rescale2]] +// CHECK: %[[rescale2_minclamped:.*]] = "tf.Minimum"(%[[rescale2_maxclamped]] -// CHECK: %[[sum:.*]] = "tf.Sum"(%[[relu]] -// CHECK: return %[[rescale2]], %[[sum]] +// CHECK: %[[sum:.*]] = "tf.Sum"(%[[clamped]] +// CHECK: return %[[rescale2_minclamped]], %[[sum]] } // ----- diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_main_function.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_main_function.mlir index cd43b2e6aa4..e34b5ed0fb0 100644 --- 
a/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_main_function.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_main_function.mlir @@ -1,4 +1,5 @@ -// RUN: tf-quant-opt %s -quant-add-main-function -allow-unregistered-dialect -mlir-disable-threading -split-input-file | FileCheck %s +// RUN: tf-quant-opt %s -quant-insert-main-function -mlir-disable-threading \ +// RUN: -allow-unregistered-dialect -split-input-file | FileCheck %s // CHECK-LABEL: module attributes {tf.versions = {producer = 930 : i32}, tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} { module attributes {tf.versions = {producer = 930 : i32}, tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} { @@ -34,7 +35,9 @@ module attributes {tf.versions = {producer = 930 : i32}, tf_saved_model.semantic // CHECK-NOT: f = @NoOp // CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @mul1} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> // CHECK: %[[PARTITIONEDCALL_1:.*]] = "tf.PartitionedCall"(%arg2, %arg3) {config = "", config_proto = "", executor_type = "", f = @mul2} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> -// CHECK: return %[[PARTITIONEDCALL_0]], %[[PARTITIONEDCALL_1]] : tensor<1xf32>, tensor<1xf32> +// CHECK: %[[IDENTITY_0:.*]] = "tf.Identity"(%[[PARTITIONEDCALL_1]]) +// CHECK: %[[IDENTITY_1:.*]] = "tf.Identity"(%[[PARTITIONEDCALL_0]]) +// CHECK: return %[[IDENTITY_1]], %[[IDENTITY_0]] : tensor<1xf32>, tensor<1xf32> // CHECK: } } @@ -84,5 +87,128 @@ module attributes {tf.versions = {producer = 1132 : i32}, tf_saved_model.semanti // CHECK-SAME: f = @add // CHECK-SAME: } // CHECK-SAME: : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> -// CHECK: return %[[CALL0]] : tensor<1xf32> +// CHECK: %[[IDENTITY:.*]] = "tf.Identity"(%[[CALL0]]) +// CHECK: return %[[IDENTITY]] : tensor<1xf32> +} + +// ----- + +// Test a case where an entry function return multiple values +module attributes {tf.versions = {producer = 930 : i32}, tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} { + "tf_saved_model.session_initializer"() {initializers = [@NoOp]} : () -> () + func.func @NoOp() attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_NoOp"]} { + func.return + } + + func.func @topk(%arg0: tensor<16xf32> {tf_saved_model.index_path = ["input"]}, %arg1: tensor {tf_saved_model.index_path = ["k"]}) -> (tensor {tf_saved_model.index_path = ["values"]}, tensor {tf_saved_model.index_path = ["indices"]}) attributes {tf.entry_function = {inputs = "input:0,k:0", outputs = "TopK:0,TopK:1"}, tf_saved_model.exported_names = ["topk"]} { + %0:2 = "tf.TopKV2"(%arg0, %arg1): (tensor<16xf32>, tensor) -> (tensor, tensor) + func.return %0#0, %0#1: tensor, tensor + } + +// CHECK: func.func private @topk(%arg0: tensor<16xf32>, %arg1: tensor) -> (tensor, tensor) +// CHECK-SAME: attributes {tf.entry_function = {inputs = "input:0,k:0", outputs = "TopK:0,TopK:1"}} + +// CHECK: func.func @main(%arg0: tensor<16xf32> {tf_saved_model.index_path = ["input:0"]}, %arg1: tensor {tf_saved_model.index_path = ["k:0"]}) +// CHECK-SAME: -> (tensor {tf_saved_model.index_path = ["TopK:0"]}, tensor {tf_saved_model.index_path = ["TopK:1"]}) +// CHECK-SAME: attributes {tf.entry_function = {inputs = "input:0,k:0", outputs = "TopK:0,TopK:1"}, tf_saved_model.exported_names = ["main"]} +// CHECK: %[[CALL0:.*]]:2 = 
"tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @topk} +// Expects an IdentityN op to be created. +// CHECK: %[[IDENTITY:.*]]:2 = "tf.IdentityN"(%[[CALL0]]#0, %[[CALL0]]#1) : (tensor, tensor) -> (tensor, tensor) +// CHECK: return %[[IDENTITY]]#0, %[[IDENTITY]]#1 : tensor, tensor +} + +// ----- + +// Test that the signature prefix is added when there are duplicated input names. +module attributes {tf.versions = {producer = 930 : i32}, tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} { + "tf_saved_model.session_initializer"() {initializers = [@NoOp]} : () -> () + func.func @NoOp() attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_NoOp"]} { + func.return + } + + func.func @mul1(%arg0: tensor<1xf32> {tf_saved_model.index_path = ["y"]}, %arg1: tensor<1xf32> {tf_saved_model.index_path = ["x"]}) -> (tensor<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "y:0,x:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["mul1"]} { + %0 = "tf.Mul"(%arg1, %arg0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + func.return %0 : tensor<1xf32> + } + + func.func @mul2(%arg0: tensor<1xf32> {tf_saved_model.index_path = ["y"]}, %arg1: tensor<1xf32> {tf_saved_model.index_path = ["x"]}) -> (tensor<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "y:0,x:0", outputs = "PartitionedCall_1:0"}, tf_saved_model.exported_names = ["mul2"]} { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor} : () -> tensor + %0 = "tf.Mul"(%arg1, %arg0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + %1 = "tf.Mul"(%0, %cst) : (tensor<1xf32>, tensor) -> tensor<1xf32> + func.return %1 : tensor<1xf32> + } + +// CHECK: func @main +// CHECK: (%arg0: tensor<1xf32> {tf_saved_model.index_path = ["mul1_y:0"]}, %arg1: tensor<1xf32> {tf_saved_model.index_path = ["mul1_x:0"]} +// CHECK: %arg2: tensor<1xf32> {tf_saved_model.index_path = ["mul2_y:0"]}, %arg3: tensor<1xf32> {tf_saved_model.index_path = ["mul2_x:0"]}) +// CHECK: -> (tensor<1xf32> {tf_saved_model.index_path = ["PartitionedCall:0"]}, tensor<1xf32> {tf_saved_model.index_path = ["PartitionedCall_1:0"]}) +// CHECK: attributes {tf.entry_function = {inputs = "mul1_y:0,mul1_x:0,mul2_y:0,mul2_x:0", outputs = "PartitionedCall:0,PartitionedCall_1:0"}, tf_saved_model.exported_names = ["main"]} +} + +// ----- + +// Test that the signature prefix is added when there are duplicated output names. 
+module attributes {tf.versions = {producer = 930 : i32}, tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} { + "tf_saved_model.session_initializer"() {initializers = [@NoOp]} : () -> () + func.func @NoOp() attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_NoOp"]} { + func.return + } + + func.func @mul1(%arg0: tensor<1xf32> {tf_saved_model.index_path = ["y"]}, %arg1: tensor<1xf32> {tf_saved_model.index_path = ["x"]}) -> (tensor<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "mul1_y:0,mul1_x:0", outputs = "output:0"}, tf_saved_model.exported_names = ["mul1"]} { + %0 = "tf.Mul"(%arg1, %arg0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + func.return %0 : tensor<1xf32> + } + + func.func @mul2(%arg0: tensor<1xf32> {tf_saved_model.index_path = ["y"]}, %arg1: tensor<1xf32> {tf_saved_model.index_path = ["x"]}) -> (tensor<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "mul2_y:0,mul2_x:0", outputs = "output:0"}, tf_saved_model.exported_names = ["mul2"]} { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor} : () -> tensor + %0 = "tf.Mul"(%arg1, %arg0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + %1 = "tf.Mul"(%0, %cst) : (tensor<1xf32>, tensor) -> tensor<1xf32> + func.return %1 : tensor<1xf32> + } +// CHECK: func @main +// CHECK: (%arg0: tensor<1xf32> {tf_saved_model.index_path = ["mul1_y:0"]}, %arg1: tensor<1xf32> {tf_saved_model.index_path = ["mul1_x:0"]} +// CHECK: %arg2: tensor<1xf32> {tf_saved_model.index_path = ["mul2_y:0"]}, %arg3: tensor<1xf32> {tf_saved_model.index_path = ["mul2_x:0"]}) +// CHECK: -> (tensor<1xf32> {tf_saved_model.index_path = ["mul1_output:0"]}, tensor<1xf32> {tf_saved_model.index_path = ["mul2_output:0"]}) +// CHECK: attributes {tf.entry_function = {inputs = "mul1_y:0,mul1_x:0,mul2_y:0,mul2_x:0", outputs = "mul1_output:0,mul2_output:0"}, tf_saved_model.exported_names = ["main"]} +} + +// ----- + +// Tests when a function called @main already exists, it is renamed to +// `main_{i}` to avoid conflict. +module attributes {tf_saved_model.semantics} { + func.func @main(%arg0: tensor<1xf32> {tf_saved_model.index_path = ["x"]}, %arg1: tensor<1xf32> {tf_saved_model.index_path = ["y"]}) -> (tensor<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "x:0,y:0", outputs = "output:0"}, tf_saved_model.exported_names = ["main"]} { + %0 = "tf.Mul"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + func.return %0 : tensor<1xf32> + } + +// CHECK: func.func private @main_0 +// CHECK: func.func @main +} + +// ----- + +// Tests when a function called @main already exists and @main_{i} also already +// exists, it increments the suffix number until there's no conflict. 
+module attributes {tf_saved_model.semantics} {
+  func.func @main_0(%arg0: tensor<1xf32> {tf_saved_model.index_path = ["z"]}) -> (tensor<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "z:0", outputs = "output:0"}, tf_saved_model.exported_names = ["main_0"]} {
+    %0 = "tf.Identity"(%arg0) : (tensor<1xf32>) -> tensor<1xf32>
+    func.return %0 : tensor<1xf32>
+  }
+
+  func.func @main(%arg0: tensor<1xf32> {tf_saved_model.index_path = ["x"]}, %arg1: tensor<1xf32> {tf_saved_model.index_path = ["y"]}) -> (tensor<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "x:0,y:0", outputs = "output:0"}, tf_saved_model.exported_names = ["main"]} {
+    %0 = "tf.Mul"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+    func.return %0 : tensor<1xf32>
+  }
+// `@main_0` remains untouched.
+// CHECK: func.func private @main_0
+// CHECK-SAME: z:0
+
+// `@main` should be renamed to `@main_1` instead of `@main_0` to avoid
+// conflict.
+// CHECK: func.func private @main_1
+// CHECK-SAME: x:0
+
+// This is the newly created main function.
+// CHECK: func.func @main
 }
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_quantized_functions.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_quantized_functions.mlir
index 5dd67fd14b8..030d5fb946a 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_quantized_functions.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_quantized_functions.mlir
@@ -1,5 +1,6 @@
 // RUN: tf-quant-opt %s -quant-insert-quantized-functions | FileCheck %s
 // RUN: tf-quant-opt %s -quant-insert-quantized-functions='quantization-method=ptq target-opset=UNIFORM_QUANTIZED' | FileCheck --check-prefix=UQ-CHECK %s
+// RUN: tf-quant-opt %s -quant-insert-quantized-functions='quantization-method=weight_only target-opset=XLA' | FileCheck --check-prefix=WEIGHTONLY %s

 // Empty module
 module {
@@ -13,22 +14,50 @@ module {
 // CHECK-NOT: func private @internal_conv2d_fn
 // CHECK-NOT: func private @internal_matmul_fn
 // CHECK: func private @quantized_conv2d_with_bias_fn
+// CHECK-SAME: tf_quant.quantized_ops = ["Conv2D", "BiasAdd"]
 // CHECK: func private @quantized_conv2d_with_bias_and_relu_fn
 // CHECK: func private @quantized_conv2d_with_bias_and_relu6_fn
 // CHECK: func private @quantized_conv2d_fn
 // CHECK: func private @quantized_conv2d_with_relu_fn
 // CHECK: func private @quantized_conv2d_with_relu6_fn
+// CHECK: func private @quantized_depthwise_conv2d_with_bias_and_relu_float_output_fn
+// CHECK-SAME: tf_quant.quantized_ops = ["DepthwiseConv2D", "BiasAdd", "Relu"]
 // CHECK: func private @quantized_matmul_with_bias_fn
 // CHECK: func private @quantized_matmul_with_bias_and_relu_fn
 // CHECK: func private @quantized_matmul_with_bias_and_relu6_fn
 // CHECK: func private @quantized_matmul_fn
+// CHECK-SAME: tf_quant.quantized_ops = ["MatMul"]
 // CHECK: func private @quantized_matmul_with_relu_fn
 // CHECK: func private @quantized_matmul_with_relu6_fn
+// CHECK: func private @quantized_conv3d_with_bias_fn
+// CHECK-SAME: tf_quant.quantized_ops = ["Conv3D", "BiasAdd"]
+// CHECK: func private @quantized_batch_matmul_with_bias_fn
+// CHECK-SAME: tf_quant.quantized_ops = ["BatchMatMul", "BiasAdd"]
 // CHECK: func private @quantize_i8
 // CHECK: func private @dequantize_i8
+
+// UQ-CHECK-NOT: func private @dequantize_i8
+// UQ-CHECK-NOT: func private @internal_conv2d_fn
+// UQ-CHECK-NOT: func private @internal_requantize_qi8_fn
+// UQ-CHECK-NOT: 
func private @internal_requantize_no_activation_fn +// UQ-CHECK-NOT: func private @internal_requantize_and_relu_fn +// UQ-CHECK-NOT: func private @quantize_i8 // UQ-CHECK: func private @quantized_conv2d_with_bias_fn -// UQ-CHECK: func private @quantize_qi8 -// UQ-CHECK: func private @requantize_qi8 -// UQ-CHECK: func private @dequantize_qi8 +// UQ-CHECK-SAME: tf_quant.quantized_ops = ["Conv2D", "BiasAdd"] +// UQ-CHECK: func private @quantized_conv2d_with_bias_and_relu_fn +// UQ-CHECK: func private @quantized_conv2d_with_bias_and_relu6_fn +// UQ-CHECK: func private @quantized_conv2d_with_relu_fn +// UQ-CHECK: func private @quantized_conv2d_with_relu6_fn +// UQ-CHECK: func private @quantized_depthwise_conv2d_with_bias_fn +// UQ-CHECK-SAME: tf_quant.quantized_ops = ["DepthwiseConv2D", "BiasAdd"] +// UQ-CHECK: func private @quantized_depthwise_conv2d_with_bias_and_relu_fn +// UQ-CHECK: func private @quantized_depthwise_conv2d_with_bias_and_relu6_fn +// UQ-CHECK: func private @quantized_depthwise_conv2d_with_relu_fn +// UQ-CHECK: func private @quantized_depthwise_conv2d_with_relu6_fn +// WEIGHTONLY: func private @quantized_conv2d +// WEIGHTONLY: func private @quantized_depthwise_conv2d +// WEIGHTONLY: func private @quantized_matmul +// WEIGHTONLY: func private @quantized_conv3d +// WEIGHTONLY: func private @quantized_batch_matmul +// WEIGHTONLY: func private @dequantize_i8 diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_quantized_functions_drq.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_quantized_functions_drq.mlir index 2b2a014ccfb..a6e060dc74a 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_quantized_functions_drq.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_quantized_functions_drq.mlir @@ -1,4 +1,5 @@ -// RUN: tf-quant-opt %s -quant-insert-quantized-functions='quantization-method=drq target-opset=UNIFORM_QUANTIZED' | FileCheck %s +// RUN: tf-quant-opt %s -quant-insert-quantized-functions='quantization-method=drq' | FileCheck %s +// RUN: tf-quant-opt %s -quant-insert-quantized-functions='quantization-method=drq target-opset=UNIFORM_QUANTIZED' | FileCheck --check-prefix=UQ-CHECK %s // Empty module module { @@ -12,3 +13,12 @@ module { // CHECK-NOT: func private @internal_quantize_i8 // CHECK-NOT: func private @internal_matmul_fn // CHECK: func private @quantized_matmul_fn +// CHECK-SAME: tf_quant.quantized_ops = ["MatMul"] +// CHECK: func private @quantized_conv2d_fn +// CHECK-SAME: tf_quant.quantized_ops = ["Conv2D"] +// CHECK: func private @quantized_depthwise_conv2d_fn +// CHECK-SAME: tf_quant.quantized_ops = ["DepthwiseConv2D"] + +// UQ-CHECK: func private @quantized_conv2d_fn +// UQ-CHECK: func private @quantized_depthwise_conv2d_fn +// UQ-CHECK: func private @quantized_matmul_fn diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_restore_op.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_restore_op.mlir new file mode 100644 index 00000000000..800f8238d30 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/insert_restore_op.mlir @@ -0,0 +1,170 @@ +// RUN: tf-quant-opt %s -split-input-file -verify-diagnostics \ +// RUN: -quant-insert-restore-op | FileCheck %s + +// RestoreV2 op created for a single VarHandleOp. 
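+// (The pass is expected to add a "__tf_file_prefix" argument, collect the
+// variable's shared_name into a tensor<1x!tf_type.string> Const, and feed a
+// single RestoreV2 op whose result is assigned back via AssignVariableOp; the
+// CHECK lines below verify exactly that.)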
+
+module attributes {tf_saved_model.semantics} {
+  "tf_saved_model.session_initializer"() {initializers = [@init_func_restore_op]} : () -> ()
+
+  func.func @init_func_restore_op() -> () attributes {
+      tf_saved_model.initializer_type = "restore_op",
+      tf_saved_model.exported_names = ["tf_saved_model.session_initializer_restore_op"]} {
+    %cst_0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32>
+    %var_0 = "tf.VarHandleOp"() {shared_name = "var_0"} : () -> tensor<!tf_type.resource<tensor<2xf32>>>
+    "tf.AssignVariableOp"(%var_0, %cst_0) : (tensor<!tf_type.resource<tensor<2xf32>>>, tensor<2xf32>) -> ()
+    return
+  }
+
+// CHECK: func.func @init_func_restore_op
+// Check that an argument ("__tf_file_prefix") is created.
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<!tf_type.string> {tf_saved_model.index_path = ["__tf_file_prefix"]}
+
+// Original `AssignVariableOp(VarHandleOp, Const)` pattern persists.
+// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() {{.*value = dense<1.000000e\+00> : tensor<2xf32>.*}}
+// CHECK-DAG: %[[VAR_HANDLE:.*]] = "tf.VarHandleOp"() {{.*shared_name = "var_0".*}} : () -> tensor<!tf_type.resource<tensor<2xf32>>>
+// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[CST_0]]) : (tensor<!tf_type.resource<tensor<2xf32>>>, tensor<2xf32>) -> ()
+
+// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() {{.*value = dense<"var_0"> : tensor<1x!tf_type.string>.*}}
+// CHECK-DAG: %[[CST_2:.*]] = "tf.Const"() {{.*value = dense<""> : tensor<1x!tf_type.string>.*}}
+
+// Test that RestoreV2 op is created with 1 resulting value.
+// CHECK: %[[RESTORE:.*]] = "tf.RestoreV2"(%[[ARG_0]], %[[CST_1]], %[[CST_2]]) : (tensor<!tf_type.string>, tensor<1x!tf_type.string>, tensor<1x!tf_type.string>) -> tensor<2xf32>
+// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[RESTORE]]) {validate_shape = false} : (tensor<!tf_type.resource<tensor<2xf32>>>, tensor<2xf32>) -> ()
+}
+
+// -----
+
+// RestoreV2 op created for multiple VarHandleOps.
+
+module attributes {tf_saved_model.semantics} {
+  "tf_saved_model.session_initializer"() {initializers = [@init_func_restore_op_multiple_variables]} : () -> ()
+
+  func.func @init_func_restore_op_multiple_variables() -> () attributes {
+      tf_saved_model.initializer_type = "restore_op",
+      tf_saved_model.exported_names = ["tf_saved_model.session_initializer_restore_op"]} {
+    %cst_0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32>
+    %var_0 = "tf.VarHandleOp"() {shared_name = "var_0"} : () -> tensor<!tf_type.resource<tensor<2xf32>>>
+    "tf.AssignVariableOp"(%var_0, %cst_0) : (tensor<!tf_type.resource<tensor<2xf32>>>, tensor<2xf32>) -> ()
+
+    %cst_1 = "tf.Const"() {value = dense<2> : tensor<4xi32>} : () -> tensor<4xi32>
+    %var_1 = "tf.VarHandleOp"() {shared_name = "var_1"} : () -> tensor<!tf_type.resource<tensor<4xi32>>>
+    "tf.AssignVariableOp"(%var_1, %cst_1) : (tensor<!tf_type.resource<tensor<4xi32>>>, tensor<4xi32>) -> ()
+    return
+  }
+
+// CHECK: func.func @init_func_restore_op_multiple_variables
+// Check that an argument ("__tf_file_prefix") is created.
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<!tf_type.string> {tf_saved_model.index_path = ["__tf_file_prefix"]}
+
+// CHECK-DAG: %[[VAR_HANDLE_0:.*]] = "tf.VarHandleOp"() {{.*shared_name = "var_0".*}} : () -> tensor<!tf_type.resource<tensor<2xf32>>>
+// CHECK-DAG: %[[VAR_HANDLE_1:.*]] = "tf.VarHandleOp"() {{.*shared_name = "var_1".*}} : () -> tensor<!tf_type.resource<tensor<4xi32>>>
+
+// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() {{{.*value = dense<\["var_0", "var_1"\]> : tensor<2x!tf_type.string>.*}}}
+// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() {{{.*value = dense<""> : tensor<2x!tf_type.string>.*}}}
+
+// Test that RestoreV2 op is created with 2 resulting values.
+// CHECK: %[[RESTORE:.*]]:2 = "tf.RestoreV2"(%[[ARG_0]], %[[CST_0]], %[[CST_1]]) : (tensor<!tf_type.string>, tensor<2x!tf_type.string>, tensor<2x!tf_type.string>) -> (tensor<2xf32>, tensor<4xi32>)
+
+// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE_0]], %[[RESTORE]]#0) {validate_shape = false} : (tensor<!tf_type.resource<tensor<2xf32>>>, tensor<2xf32>) -> ()
+// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE_1]], %[[RESTORE]]#1) {validate_shape = false} : (tensor<!tf_type.resource<tensor<4xi32>>>, tensor<4xi32>) -> ()
+}
+
+// -----
+
+// RestoreV2 op not created for `AssignVariableOp(VarHandleOp, Const)` patterns
+// in the initializer function of "init_op" type.
+
+module attributes {tf_saved_model.semantics} {
+  "tf_saved_model.session_initializer"() {initializers = [@init_func_init_op]} : () -> ()
+
+  func.func @init_func_init_op() -> () attributes {
+      tf_saved_model.initializer_type = "init_op",
+      tf_saved_model.exported_names = ["tf_saved_model.session_initializer_init_op"]} {
+    %cst_0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32>
+    %var_0 = "tf.VarHandleOp"() {shared_name = "var_0"} : () -> tensor<!tf_type.resource<tensor<2xf32>>>
+    "tf.AssignVariableOp"(%var_0, %cst_0) {validate_shape = false} : (tensor<!tf_type.resource<tensor<2xf32>>>, tensor<2xf32>) -> ()
+    return
+  }
+// Check that no function argument is created.
+// CHECK: func.func @init_func_init_op()
+
+// CHECK-DAG: %[[VAR_HANDLE:.*]] = "tf.VarHandleOp"() {{{.*shared_name = "var_0".*}}} : () -> tensor<!tf_type.resource<tensor<2xf32>>>
+// CHECK-DAG: %[[CST:.*]] = "tf.Const"() {{{.*value = dense<1.000000e\+00> : tensor<2xf32>.*}}}
+// Make sure that "tf.RestoreV2" is not created.
+// CHECK-NOT: "tf.RestoreV2"
+// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[CST]]) {validate_shape = false} : (tensor<!tf_type.resource<tensor<2xf32>>>, tensor<2xf32>) -> ()
+}
+
+// -----
+
+// Test that `RestoreV2Op` is created even when the `Const` op is shared across
+// `AssignVariableOp`s.
+
+module attributes {tf_saved_model.semantics} {
+  "tf_saved_model.session_initializer"() {initializers = [@init_func_restore_op_multiple_variables_sharing_const]} : () -> ()
+
+  func.func @init_func_restore_op_multiple_variables_sharing_const() -> () attributes {
+      tf_saved_model.initializer_type = "restore_op",
+      tf_saved_model.exported_names = ["tf_saved_model.session_initializer_restore_op"]} {
+    // This const is shared and initializes two variables.
+    %cst_0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32>
+
+    %var_0 = "tf.VarHandleOp"() {shared_name = "var_0"} : () -> tensor<!tf_type.resource<tensor<2xf32>>>
+    "tf.AssignVariableOp"(%var_0, %cst_0) : (tensor<!tf_type.resource<tensor<2xf32>>>, tensor<2xf32>) -> ()
+
+    %var_1 = "tf.VarHandleOp"() {shared_name = "var_1"} : () -> tensor<!tf_type.resource<tensor<2xf32>>>
+    "tf.AssignVariableOp"(%var_1, %cst_0) : (tensor<!tf_type.resource<tensor<2xf32>>>, tensor<2xf32>) -> ()
+    return
+  }
+
+// CHECK: func.func @init_func_restore_op_multiple_variables_sharing_const
+// Check that an argument ("__tf_file_prefix") is created.
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<!tf_type.string> {tf_saved_model.index_path = ["__tf_file_prefix"]}
+
+// CHECK-DAG: %[[VAR_HANDLE_0:.*]] = "tf.VarHandleOp"() {{.*shared_name = "var_0".*}} : () -> tensor<!tf_type.resource<tensor<2xf32>>>
+// CHECK-DAG: %[[VAR_HANDLE_1:.*]] = "tf.VarHandleOp"() {{.*shared_name = "var_1".*}} : () -> tensor<!tf_type.resource<tensor<2xf32>>>
+
+// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() {{{.*value = dense<\["var_0", "var_1"\]> : tensor<2x!tf_type.string>.*}}}
+// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() {{{.*value = dense<""> : tensor<2x!tf_type.string>.*}}}
+
+// Test that RestoreV2 op is created with 2 resulting values.
+// CHECK: %[[RESTORE:.*]]:2 = "tf.RestoreV2"(%[[ARG_0]], %[[CST_0]], %[[CST_1]]) : (tensor<!tf_type.string>, tensor<2x!tf_type.string>, tensor<2x!tf_type.string>) -> (tensor<2xf32>, tensor<2xf32>)
+
+// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE_0]], %[[RESTORE]]#0) {validate_shape = false} : (tensor<!tf_type.resource<tensor<2xf32>>>, tensor<2xf32>) -> ()
+// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE_1]], %[[RESTORE]]#1) {validate_shape = false} : (tensor<!tf_type.resource<tensor<2xf32>>>, tensor<2xf32>) -> ()
+}
+
+
+// -----
+
+// Test that "tf.RestoreV2" is not created because there are no variables.
+
+module attributes {tf_saved_model.semantics} {
+  "tf_saved_model.session_initializer"() {initializers = [@init_func_restore_op_no_variable]} : () -> ()
+
+  func.func @init_func_restore_op_no_variable() -> () attributes {
+      tf_saved_model.initializer_type = "restore_op",
+      tf_saved_model.exported_names = ["tf_saved_model.session_initializer_restore_op"]} {
+    %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32>
+    return
+  }
+// CHECK: func.func @init_func_restore_op_no_variable()
+// CHECK-NOT: "tf.RestoreV2"
+}
+
+// -----
+
+// Test when there are no initializers.
+
+module attributes {tf_saved_model.semantics} {
+  "tf_saved_model.session_initializer"() {initializers = []} : () -> ()
+// CHECK-NOT: "tf.RestoreV2"
+}
+
+// -----
+
+// Test when there is no SessionInitializerOp.
+
+module attributes {tf_saved_model.semantics} {
+// CHECK-NOT: "tf.RestoreV2"
+}
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions_drq.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions_drq.mlir
index a125da0f5a3..0a7b5108387 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions_drq.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions_drq.mlir
@@ -22,3 +22,118 @@ func.func @lift_float_matmul(%arg0: tensor<1x12x12x512xf32>) -> (tensor<*xf32>,
 // CHECK-NEXT: %[[OUT:.*]] = "tf.MatMul"(%arg0, %arg1)
 // CHECK-NEXT: return %[[OUT]]
 }
+
+// -----
+
+// CHECK-LABEL: lift_float_conv
+func.func @lift_float_conv(%arg0: tensor<1x3x4x3xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
+  %cst = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32>
+  %cst_1 = "tf.Const"() {value = dense<3.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32>
+  %0 = "tf.Conv2D"(%arg0, %cst_1) {
+    data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [],
+    padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true
+  } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32>
+  %1 = "tf.BiasAdd"(%0, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32>
+  %2 = "tf.Relu6"(%1) {device = ""} : (tensor<*xf32>) -> tensor<*xf32>
+
+  %3 = "tf.Conv2D"(%arg0, %cst_1) {
+    data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [],
+    padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true
+  } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32>
+  %4 = "tf.BiasAdd"(%3, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32>
+
+  func.return %2, %4 : tensor<*xf32>, tensor<*xf32>
+
+// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32>
+// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() {value = dense<3.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32>
+// CHECK: 
%[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %[[CONST_1]]) +// CHECK-SAME: {_tfl_quant_trait = "fully_quantizable", +// CHECK-SAME: f = @composite_conv2d_fn_2} +// CHECK: %[[BIASADD_0:.*]] = "tf.BiasAdd"(%[[PARTITIONEDCALL_0]], %[[CONST_0]]) +// CHECK: %[[RELU6_0:.*]] = "tf.Relu6"(%[[BIASADD_0]]) +// CHECK: %[[PARTITIONEDCALL_1:.*]] = "tf.PartitionedCall"(%arg0, %[[CONST_1]]) +// CHECK-SAME: f = @composite_conv2d_fn_1} +// CHECK: %[[BIASADD_1:.*]] = "tf.BiasAdd"(%[[PARTITIONEDCALL_1]], %[[CONST_0]]) +// CHECK: return %[[RELU6_0]], %[[BIASADD_1]] +// CHECK: } + +// CHECK-LABEL: private @composite_conv2d_fn_2 +// CHECK-NEXT: %[[CONV2D_0:.*]] = "tf.Conv2D"(%arg0, %arg1) +// CHECK-SAME: attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations" +// CHECK-SAME: data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true +// CHECK-NEXT: return %[[CONV2D_0]] + +// CHECK-LABEL: private @composite_conv2d_fn_1 +// CHECK-NEXT: %[[CONV2D_0:.*]] = "tf.Conv2D"(%arg0, %arg1) +// CHECK-SAME: attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations" +// CHECK-NEXT: return %[[CONV2D_0]] +} + +// ----- + +// CHECK-LABEL: not_lift_float_conv_with_non_constant_weights +func.func @not_lift_float_conv_with_non_constant_weights(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + %cst = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "tf.Conv2D"(%arg0, %arg1) { + data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], + padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %2 = "tf.Relu6"(%1) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + + %3 = "tf.Conv2D"(%arg0, %arg1) { + data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], + padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + %4 = "tf.BiasAdd"(%3, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + + func.return %2, %4 : tensor<*xf32>, tensor<*xf32> + +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> +// CHECK-NOT: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %arg1) +// CHECK: %[[CONV2D_0:.*]] = "tf.Conv2D"(%arg0, %arg1) +} + +// ----- + +// CHECK-LABEL: lift_float_depthwise_conv +func.func @lift_float_depthwise_conv(%arg0: tensor<1x3x4x3xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + %cst = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_1 = "tf.Const"() {value = dense<3.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32> + %0 = "tf.DepthwiseConv2dNative"(%arg0, %cst_1) { + data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], + padding = "SAME", strides = [1, 1, 2, 1] + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x1xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %2 = "tf.Relu6"(%1) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + + %3 = "tf.DepthwiseConv2dNative"(%arg0, %cst_1) { + data_format = "NHWC", device 
= "", dilations = [1, 1, 1, 1], explicit_paddings = [], + padding = "SAME", strides = [1, 1, 2, 1] + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x1xf32>) -> tensor<*xf32> + %4 = "tf.BiasAdd"(%3, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + func.return %2, %4 : tensor<*xf32>, tensor<*xf32> + +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> +// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() {value = dense<3.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32> +// CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %[[CONST_1]]) +// CHECK-SAME: _tfl_quant_trait = "fully_quantizable", +// CHECK-SAME: f = @composite_depthwise_conv2d_fn_2} +// CHECK: %[[BIASADD_0:.*]] = "tf.BiasAdd"(%[[PARTITIONEDCALL_0]], %[[CONST_0]]) +// CHECK: %[[RELU6_0:.*]] = "tf.Relu6"(%[[BIASADD_0]]) +// CHECK: %[[PARTITIONEDCALL_1:.*]] = "tf.PartitionedCall"(%arg0, %[[CONST_1]]) +// CHECK-SAME: f = @composite_depthwise_conv2d_fn_1} +// CHECK: %[[BIASADD_0:.*]] = "tf.BiasAdd"(%[[PARTITIONEDCALL_1]], %[[CONST_0]]) +// CHECK: return %[[RELU6_0]], %[[BIASADD_0]] +// CHECK: } + +// CHECK-LABEL: private @composite_depthwise_conv2d_fn_2 +// CHECK-NEXT: %[[DEPTHWISECONV2D_0:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %arg1) +// CHECK-SAME: attr_map = "0:strides,1:padding,2:explicit_paddings,3:dilations" +// CHECK-NEXT: return %[[DEPTHWISECONV2D_0:.*]] + +// CHECK-LABEL: private @composite_depthwise_conv2d_fn_1 +// CHECK-NEXT: %[[DEPTHWISECONV2D_0:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %arg1) +// CHECK-SAME: attr_map = "0:strides,1:padding,2:explicit_paddings,3:dilations" +// CHECK-NEXT: return %[[DEPTHWISECONV2D_0:.*]] +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions_drq_min_elements.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions_drq_min_elements.mlir index 73836bc4e2a..b1104490025 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions_drq_min_elements.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions_drq_min_elements.mlir @@ -1,7 +1,7 @@ // RUN: tf-quant-opt %s -split-input-file -quant-lift-quantizable-spots-as-functions-drq="min-num-elements-for-weights=2500000" | FileCheck %s -// CHECK-LABEL: float_matmul -func.func @float_matmul(%arg0: tensor<1x12x12x512xf32>) -> (tensor<*xf32>, tensor<*xf32>) { +// CHECK-LABEL: lift_float_matmul +func.func @lift_float_matmul(%arg0: tensor<1x12x12x512xf32>) -> (tensor<*xf32>, tensor<*xf32>) { %cst = "tf.Const"() {value = dense<0.000000e+00> : tensor<512x512xf32>} : () -> tensor<512x512xf32> %out_1 = "tf.MatMul"(%arg0, %cst) { device = "", transpose_a = false, transpose_b = false @@ -21,3 +21,21 @@ func.func @float_matmul(%arg0: tensor<1x12x12x512xf32>) -> (tensor<*xf32>, tenso // CHECK-LABEL: private @composite_matmul_fn_1 } + +// ----- + +// CHECK-LABEL: not_lift_float_conv +func.func @not_lift_float_conv(%arg0: tensor<1x3x4x512xf32>) -> (tensor<*xf32>) { + %cst = "tf.Const"() {value = dense<3.000000e+00> : tensor<2x3x512x512xf32>} : () -> tensor<2x3x512x512xf32> + %0 = "tf.Conv2D"(%arg0, %cst) { + data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], + padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true + } : (tensor<1x3x4x512xf32>, tensor<2x3x512x512xf32>) -> tensor<*xf32> + func.return %0 : tensor<*xf32> + +// 
CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<3.000000e+00> : tensor<2x3x512x512xf32>} : () -> tensor<2x3x512x512xf32> +// CHECK: %[[PARTITIONEDCALL:.*]] = "tf.PartitionedCall"(%arg0, %[[CONST]]) +// CHECK-NOT: {_tfl_quant_trait = "fully_quantizable", +// CHECK-SAME: {config = "", +// CHECK-SAME: f = @composite_conv2d_fn_1} +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions_xla.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions_xla.mlir index b2aea357eb7..1d80a199d2e 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions_xla.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/lift_quantizable_spots_as_functions_xla.mlir @@ -71,3 +71,33 @@ func.func @conv_with_dynamic_channel_dim(%arg0: tensor<1x3x4x?xf32>) -> tensor<* // Check that the `attr_map` attribute has been removed. // CHECK-NOT: attr_map // CHECK-SAME: data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1] + +// ----- + +func.func @const_filter_with_q_dq(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { + %cst = "tf.Const"() {device = "", value = dense<[[[[-0.308480561, 0.122108772], [-0.0622722618, 0.285358578], [0.279975802, -0.227407396]], [[-0.223535746, 0.301872164], [0.45813936, 0.375932634], [-0.142182723, 9.95125505E-4]], [[0.183462933, 0.212702021], [-0.129749238, 0.0611961856], [0.00308316527, -0.486231536]]], [[[0.272826612, 0.382641196], [-0.135114014, 0.115396179], [-0.424618751, -1.311760e-01]], [[0.433140099, 0.15137814], [-0.102797419, 0.288730145], [-0.183163881, 0.0680986494]], [[0.369127393, -0.0638265759], [0.302147657, -0.35623318], [0.204260975, 0.204581305]]]]> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {device = "", value = dense<[1.000000e-01, 2.000000e-01]> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "quantfork.qcast"(%arg0) : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> + %1 = "quantfork.dcast"(%0) : (tensor<1x3x4x3x!quant.uniform>) -> tensor<1x3x4x3xf32> + %q_w = "quantfork.qcast"(%cst) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform> + %dq_w = "quantfork.dcast"(%q_w) : (tensor<2x3x3x2x!quant.uniform>) -> tensor<2x3x3x2xf32> + %2 = "tf.Conv2D"(%1, %dq_w) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> + %3 = "tf.BiasAdd"(%2, %cst_0) {data_format = "NHWC", device = ""} : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> + %4 = "tf.Relu"(%3) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> + %5 = "quantfork.qcast"(%4) : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2x!quant.uniform> + %6 = "quantfork.dcast"(%5) : (tensor<1x3x2x2x!quant.uniform>) -> tensor<1x3x2x2xf32> + %7 = "tf.Identity"(%6) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> + %8 = "tf.Identity"(%7) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> + return %8 : tensor<1x3x2x2xf32> +} + +// CHECK-LABEL: func @const_filter_with_q_dq +// CHECK-DAG: %[[WEIGHT:.*]] = "tf.Const"() {{.*}} : () -> tensor<2x3x3x2xf32> +// CHECK-DAG: %[[BIAS:.*]] = "tf.Const"() {device = "", value = dense<[1.000000e-01, 2.000000e-01]> : tensor<2xf32>} +// CHECK: %[[Q_W:.*]] = "quantfork.qcast"(%[[WEIGHT]]) +// CHECK: %[[DQ_W:.*]] = "quantfork.dcast"(%[[Q_W]]) +// CHECK: 
%[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"({{.*}}, %[[DQ_W]], %[[BIAS]]) +// CHECK-SAME: _tfl_quant_trait = "fully_quantizable" +// CHECK-SAME: f = @composite_conv2d_with_bias_and_relu_fn_1 + +// CHECK-LABEL: func private @composite_conv2d_with_bias_and_relu_fn_1 \ No newline at end of file diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/mark_functions_noinline.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/mark_functions_noinline.mlir new file mode 100644 index 00000000000..cd8af14e7ea --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/mark_functions_noinline.mlir @@ -0,0 +1,24 @@ +// RUN: tf-quant-opt %s -mark-functions-noinline='noinline-functions=noinline0' \ +// RUN: -allow-unregistered-dialect -mlir-disable-threading \ +// RUN: -split-input-file -verify-diagnostics | FileCheck %s + +// Tests that the function is marked tf._noinline = true. + +// CHECK-LABEL: @noinline0 +// CHECK-SAME: attributes {{{.*tf._noinline = true.*}}} +func.func @noinline0() -> (tensor<0xf32>) { + %cst = "tf.Const"() {value = dense<1.0> : tensor<0xf32>} : () -> tensor<0xf32> + return %cst : tensor<0xf32> +} + +// ----- + +// Tests that the function not listed in the option `noinline-functions` +// is not marked tf._noinline = true. + +// CHECK-LABEL: @inline +// CHECK-NOT: tf._noinline +func.func @inline() -> (tensor<0xf32>) { + %cst = "tf.Const"() {value = dense<1.0> : tensor<0xf32>} : () -> tensor<0xf32> + return %cst : tensor<0xf32> +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/merge_initializer_function_ops_to_main.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/merge_initializer_function_ops_to_main.mlir index 9dfd75a2fc9..117c267deac 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/merge_initializer_function_ops_to_main.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/merge_initializer_function_ops_to_main.mlir @@ -80,13 +80,102 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // Checks that the location for the init op is properly set. // CHECK-LOC-LABEL: func.func @main // CHECK-LOC: tf_executor.island({{.*}}) wraps "tf.NoOp"() -// CHECK-LOC-SAME: loc("init_op__NoOp") +// CHECK-LOC-SAME: loc("init_op_NoOp") +} + +// ----- + +// Tests when the initializer function contains multiple stateful +// initialization ops. They should be transitively connected through +// control dependencies (!tf_executor.control), which is guaranteed by +// the `tf-executor-break-up-islands` pass. + +// CHECK-LABEL: module attributes +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1228 : i32}, tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@NoOp]} : () -> () +// Check that the initializers list is empty. 
+// CHECK: "tf_saved_model.session_initializer"() +// CHECK-SAME: initializers = [] + + func.func @NoOp() + attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_NoOp"], tf_saved_model.initializer_type = "init_op"} { + tf_executor.graph { + %out, %ctl = tf_executor.island wraps "tf.Const"() {device = "", value = dense<["test_1"]> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %out_0, %ctl_1 = tf_executor.island wraps "tf.Const"() {device = "", value = dense<[1]> : tensor<1xi64>} : () -> tensor<1xi64> + %out_1, %ctl_2 = tf_executor.island wraps "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf_type.string, shared_name = "1", use_node_name_sharing = false, value_dtype = i64} : () -> tensor + %ctl_3 = tf_executor.island wraps "tf.LookupTableImportV2"(%out_1, %out, %out_0) {device = ""} : (tensor, tensor<1x!tf_type.string>, tensor<1xi64>) -> () + + %out_2, %ctl_4 = tf_executor.island wraps "tf.Const"() {device = "", value = dense<["test_2"]> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %out_3, %ctl_5 = tf_executor.island wraps "tf.Const"() {device = "", value = dense<[2]> : tensor<1xi64>} : () -> tensor<1xi64> + // Has a control dependency to the previous LookupTableImportV2. + %out_4, %ctl_6 = tf_executor.island(%ctl_3) wraps "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf_type.string, shared_name = "2", use_node_name_sharing = false, value_dtype = i64} : () -> tensor + %ctl_7 = tf_executor.island wraps "tf.LookupTableImportV2"(%out_4, %out_2, %out_3) {device = ""} : (tensor, tensor<1x!tf_type.string>, tensor<1xi64>) -> () + tf_executor.fetch %ctl_7 : !tf_executor.control + } + return + } +// The session initializer function is removed. +// CHECK-NOT: @NoOp() + + func.func private @serving_default(%arg0: tensor) -> tensor<*xi64> attributes {tf.entry_function = {control_outputs = "", inputs = "input:0", outputs = "output:0"}} { + %0 = tf_executor.graph { + %out, %ctl = tf_executor.island wraps "tf.Const"() {device = "", value = dense<-1> : tensor} : () -> tensor + %out_0, %ctl_1 = tf_executor.island wraps "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf_type.string, shared_name = "1", use_node_name_sharing = false, value_dtype = i64} : () -> tensor + %out_1, %ctl_2 = tf_executor.island wraps "tf.LookupTableFindV2"(%out_0, %arg0, %out) {device = ""} : (tensor, tensor, tensor) -> tensor<*xi64> + tf_executor.fetch %out_1 : tensor<*xi64> + } + return %0 : tensor<*xi64> + } + + func.func @main(%arg0: tensor {tf_saved_model.index_path = ["serving_default_input_vocabs:0"]}) -> (tensor<*xi64> {tf_saved_model.index_path = ["StatefulPartitionedCall:0"]}) + attributes {tf.entry_function = {inputs = "serving_default_input_vocabs:0", outputs = "StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["main"]} { + %0 = tf_executor.graph { + %out, %ctl = tf_executor.island wraps "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @serving_default} : (tensor) -> tensor<*xi64> + tf_executor.fetch %out : tensor<*xi64> + } + return %0 : tensor<*xi64> + } +// Sanity check: The main function's signature & attributes have not changed. 
+// CHECK: func.func @main(%[[ARG:.*]]: tensor +// CHECK-SAME: tf_saved_model.index_path = ["serving_default_input_vocabs:0"] +// CHECK-SAME: -> (tensor<*xi64> {tf_saved_model.index_path = ["StatefulPartitionedCall:0"]}) +// CHECK-SAME: tf.entry_function = {inputs = "serving_default_input_vocabs:0", outputs = "StatefulPartitionedCall:0"} +// CHECK-SAME: tf_saved_model.exported_names = ["main"] + +// CHECK: %[[GRAPH_OUT:.*]] = tf_executor.graph +// CHECK-NEXT: %[[OUT:.*]], %[[CTL:.*]] = tf_executor.island wraps "tf.PartitionedCall"(%[[ARG]]) +// CHECK-SAME: f = @serving_default +// Checks that the contents of @NoOp are copied here. +// CHECK-DAG: %[[OUT_0:.*]], %[[CTL_0:.*]] = tf_executor.island wraps "tf.Const"() {{{.*value = dense<"test_1">.*}}} +// CHECK-DAG: %[[OUT_1:.*]], %[[CTL_1:.*]] = tf_executor.island wraps "tf.Const"() {{{.*value = dense<1>.*}}} + +// CHECK-NEXT: %[[OUT_2:.*]], %[[CTL_2:.*]] = tf_executor.island wraps "tf.HashTableV2"() +// CHECK-NEXT: %[[CTL_3:.*]] = tf_executor.island wraps "tf.LookupTableImportV2"(%[[OUT_2]], %[[OUT_0]], %[[OUT_1]]) + +// CHECK-DAG: %[[OUT_3:.*]], %[[CTL_4:.*]] = tf_executor.island wraps "tf.Const"() {{{.*value = dense<"test_2">.*}}} +// CHECK-DAG: %[[OUT_4:.*]], %[[CTL_5:.*]] = tf_executor.island wraps "tf.Const"() {{{.*value = dense<2>.*}}} + +// CHECK-NEXT: %[[OUT_5:.*]], %[[CTL_6:.*]] = tf_executor.island(%[[CTL_3]]) wraps "tf.HashTableV2"() +// CHECK-NEXT: %[[CTL_7:.*]] = tf_executor.island wraps "tf.LookupTableImportV2"(%[[OUT_5]], %[[OUT_3]], %[[OUT_4]]) + +// Checks that the NoOp with control dependency to the control output for the +// initializer function is created & fetched. +// CHECK-NEXT: %[[CTL_8:.*]] = tf_executor.island(%[[CTL_7]]) wraps "tf.NoOp"() +// CHECK-NEXT: tf_executor.fetch %[[OUT]], %[[CTL_8]] : tensor<*xi64>, !tf_executor.control +// CHECK-NEXT: } +// CHECK-NEXT: return %[[GRAPH_OUT]] : tensor<*xi64> + +// Checks that the location for the init op is properly set. +// CHECK-LOC-LABEL: func.func @main +// CHECK-LOC: tf_executor.island({{.*}}) wraps "tf.NoOp"() +// CHECK-LOC-SAME: loc("init_op_NoOp") } // ----- // Test the case where the initializer function accepts an argument but it // is not used within the body. + // CHECK-LABEL: module attributes module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1228 : i32}, tf_saved_model.semantics} { "tf_saved_model.session_initializer"() {initializers = [@NoOp]} : () -> () @@ -137,18 +226,20 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // Checks that the location for the init op is properly set. // CHECK-LOC-LABEL: func.func @main // CHECK-LOC: tf_executor.island({{.*}}) wraps "tf.NoOp"() -// CHECK-LOC-SAME: loc("init_op__NoOp") +// CHECK-LOC-SAME: loc("init_op_NoOp") } // ----- -// Test the case where there are 2 initializer functions. +// Test the case where there are 2 initializer functions ("init_op" and +// "restore_op"). + // CHECK-LABEL: module attributes module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1228 : i32}, tf_saved_model.semantics} { "tf_saved_model.session_initializer"() {initializers = [@NoOp_0, @NoOp_1]} : () -> () // Check that the initializer typed "init_op" is removed from initializers list. 
// CHECK: "tf_saved_model.session_initializer"() -// CHECK-SAME: initializers = [@NoOp_1] +// CHECK-SAME: initializers = [] func.func @NoOp_0() attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_NoOp_0"], tf_saved_model.initializer_type = "init_op"} { @@ -169,8 +260,8 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p } return } -// The session initializer function typed "restore_op" is not removed. -// CHECK: @NoOp_1() +// The session initializer function is removed. +// CHECK-NOT: @NoOp_1() func.func @main() attributes {tf.entry_function = {inputs = "", outputs = ""}, tf_saved_model.exported_names = ["main"]} { tf_executor.graph { @@ -183,27 +274,113 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // CHECK-SAME: tf_saved_model.exported_names = ["main"] // CHECK: tf_executor.graph -// Checks that the contents of @NoOp_0 (type: "init_op") are copied here. -// CHECK-NEXT: %[[OUT_0:.*]], %[[CTL_0:.*]] = tf_executor.island wraps "tf.Const"() {{{.*value = dense<"dummy_op">.*}}} -// Checks that the contents of @NoOp_1 (type: "restore_op") are not copied here. -// CHECK-NOT: tf_executor.island wraps "tf.Const"() {{{.*value = dense<1>.*}}} -// Checks that the NoOp is only dependent on the initializer function with type "init_op". -// This is because the control dependency node is only required for the -// initializer function for resources other than variables. -// CHECK-NEXT: %[[CTL_2:.*]] = tf_executor.island(%[[CTL_0]]) wraps "tf.NoOp"() -// CHECK-NEXT: tf_executor.fetch %[[CTL_2]] : !tf_executor.control +// Checks that the contents of the initializer functions are copied here. +// CHECK-DAG: %[[OUT_0:.*]], %[[CTL_0:.*]] = tf_executor.island wraps "tf.Const"() {{{.*value = dense<"dummy_op"> : tensor<1x!tf_type.string>.*}}} +// CHECK-DAG: %[[OUT_1:.*]], %[[CTL_1:.*]] = tf_executor.island wraps "tf.Const"() {{{.*value = dense<1> : tensor<1xi32>.*}}} + +// Checks that 2 `NoOp`s having control dependencies to each of the initializer +// functions are created. +// CHECK-DAG: %[[CTL_2:.*]] = tf_executor.island(%[[CTL_0]]) wraps "tf.NoOp"() +// CHECK-DAG: %[[CTL_3:.*]] = tf_executor.island(%[[CTL_1]]) wraps "tf.NoOp"() + +// CHECK: tf_executor.fetch +// CHECK-SAME: !tf_executor.control, !tf_executor.control // CHECK-NEXT: } // CHECK-NEXT: return // Checks that the location for the init op is properly set. // CHECK-LOC-LABEL: func.func @main -// CHECK-LOC: tf_executor.island({{.*}}) wraps "tf.NoOp"() -// CHECK-LOC-SAME: loc("init_op__NoOp_0") + +// CHECK-LOC-DAG: tf_executor.island({{.*}}) wraps "tf.NoOp"() {{.*}} loc("init_op_NoOp_0") +// CHECK-LOC-DAG: tf_executor.island({{.*}}) wraps "tf.NoOp"() {{.*}} loc("restore_op_NoOp_1") +} + +// ----- + +// Tests that initializer function for "restore_op" is merged into @main. 
+ +// CHECK-LABEL: module +module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_func_restore_op]} : () -> () +// CHECK: "tf_saved_model.session_initializer"() {initializers = []} + + func.func @init_func_restore_op(%arg: tensor {tf_saved_model.index_path = ["file_prefix"]}) + attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_NoOp"], tf_saved_model.initializer_type = "restore_op"} { + tf_executor.graph { + %out, %ctl = tf_executor.island wraps "tf.Const"() {value = dense<""> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %out_0, %ctl_0 = tf_executor.island wraps "tf.Const"() {value = dense<"var_0"> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %out_1, %ctl_1 = tf_executor.island wraps "tf.VarHandleOp"() {container = "", shared_name = "var_0", device = "/device:CPU:0"} : () -> tensor>> + %out_2, %ctl_2 = tf_executor.island wraps "tf.RestoreV2"(%arg, %out_0, %out) {} : (tensor, tensor<1x!tf_type.string>, tensor<1x!tf_type.string>) -> tensor<2xf32> + %ctl_3 = tf_executor.island wraps "tf.AssignVariableOp"(%out_1, %out_2) : (tensor>>, tensor<2xf32>) -> () + tf_executor.fetch %ctl_3 : !tf_executor.control + } + return + } + + func.func @main() attributes {tf.entry_function = {inputs = "", outputs = ""}, tf_saved_model.exported_names = ["main"]} { + tf_executor.graph { + tf_executor.fetch + } + return + } +// A new argument corresponding to the "file_prefix" should be created. +// CHECK: func.func @main(%[[ARG:.*]]: tensor {tf_saved_model.index_path = ["file_prefix"]}) +// CHECK-SAME: {{{.*tf.entry_function = {inputs = "restore_op_0:0", outputs = ""}.*}}} +// CHECK-NEXT: tf_executor.graph + +// Checks that the ops from @init_func_restore_op are cloned. +// CHECK-DAG: %[[CONST_0:.*]], %[[CTL:.*]] = tf_executor.island wraps "tf.Const"() {{{.*value = dense<""> : tensor<1x!tf_type\.string>.*}}} +// CHECK-DAG: %[[CONST_1:.*]], %[[CTL_0:.*]] = tf_executor.island wraps "tf.Const"() {{{.*value = dense<"var_0"> : tensor<1x!tf_type\.string>.*}}} +// CHECK: %[[VAR_HANDLE:.*]], %[[CTL_1:.*]] = tf_executor.island wraps "tf.VarHandleOp"() {{{.*shared_name = "var_0".*}}} +// CHECK: %[[RESTORE:.*]], %[[CTL_2:.*]] = tf_executor.island wraps "tf.RestoreV2"(%[[ARG]], %[[CONST_1]], %[[CONST_0]]) +// CHECK: %[[CTL_3:.*]] = tf_executor.island wraps "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[RESTORE]]) +// CHECK: %[[CTL_4:.*]] = tf_executor.island(%[[CTL_3]]) wraps "tf.NoOp"() +// CHECK-NEXT: tf_executor.fetch %[[CTL_4]] : !tf_executor.control +// CHECK: return + +// Checks that the Location is properly set for the NoOp. +// CHECK-LOC: tf_executor.island({{.*}}) wraps "tf.NoOp"() {{.*}} loc("restore_op_init_func_restore_op") +} + +// ----- + +// Tests that the input name for the new argument created in @main (for the +// "restore_op" initializer function) is not added when there is no +// tf.entry_function. 
+ +// CHECK-LABEL: module +module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_func_restore_op]} : () -> () +// CHECK: "tf_saved_model.session_initializer"() {initializers = []} + + func.func @init_func_restore_op(%arg: tensor {tf_saved_model.index_path = ["file_prefix"]}) + attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_NoOp"], tf_saved_model.initializer_type = "restore_op"} { + tf_executor.graph { + %out, %ctl = tf_executor.island wraps "tf.Const"() {value = dense<""> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %out_0, %ctl_0 = tf_executor.island wraps "tf.Const"() {value = dense<"var_0"> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %out_1, %ctl_1 = tf_executor.island wraps "tf.VarHandleOp"() {container = "", shared_name = "var_0", device = "/device:CPU:0"} : () -> tensor>> + %out_2, %ctl_2 = tf_executor.island wraps "tf.RestoreV2"(%arg, %out_0, %out) {} : (tensor, tensor<1x!tf_type.string>, tensor<1x!tf_type.string>) -> tensor<2xf32> + %ctl_3 = tf_executor.island wraps "tf.AssignVariableOp"(%out_1, %out_2) : (tensor>>, tensor<2xf32>) -> () + tf_executor.fetch %ctl_3 : !tf_executor.control + } + return + } + + func.func @main() attributes {tf_saved_model.exported_names = ["main"]} { + tf_executor.graph { + tf_executor.fetch + } + return + } +// A new argument corresponding to the "file_prefix" should be created. +// Also checks that tf.entry_function is not created. +// CHECK: func.func @main(%[[ARG:.*]]: tensor {tf_saved_model.index_path = ["file_prefix"]}) attributes {tf_saved_model.exported_names = ["main"]} } // ----- // Tests no change when there's no initializer functions. + // CHECK-LABEL: module attributes module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1228 : i32}, tf_saved_model.semantics} { "tf_saved_model.session_initializer"() {initializers = []} : () -> () @@ -265,17 +442,17 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // Tests when the initializer function is empty. // CHECK-LABEL: module attributes module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1228 : i32}, tf_saved_model.semantics} { - "tf_saved_model.session_initializer"() {initializers = [@NoOp]} : () -> () + "tf_saved_model.session_initializer"() {initializers = [@init_func_when_main_empty]} : () -> () // Check that the initializers attribute is untouched. // CHECK: "tf_saved_model.session_initializer"() -// CHECK-SAME: initializers = [@NoOp] +// CHECK-SAME: initializers = [@init_func_when_main_empty] - func.func @NoOp() + func.func @init_func_when_main_empty() attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_NoOp"], tf_saved_model.initializer_type = "init_op"} { return } // The initializer function is untouched. -// CHECK: func.func @NoOp +// CHECK: func.func @init_func_when_main_empty() func.func @main() attributes {tf_saved_model.exported_names = ["main"]} { tf_executor.graph { @@ -356,12 +533,13 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // Tests that warning is emitted when an initializer function does not have the // tf_saved_model.initializer_type attribute. 
+ // CHECK-LABEL: module attributes module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1228 : i32}, tf_saved_model.semantics} { "tf_saved_model.session_initializer"() {initializers = [@NoOp]} : () -> () // Check that the initializers attribute is untouched. // CHECK: "tf_saved_model.session_initializer"() -// CHECK-SAME: initializers = [@NoOp] +// CHECK-SAME: initializers = [] // expected-warning @+1 {{Initializer func op does not have tf_saved_model.initializer_type attribute. Func op: NoOp}} func.func @NoOp() @@ -381,6 +559,8 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p } // CHECK: func.func @main() // CHECK-NEXT: tf_executor.graph -// CHECK-NEXT: tf_executor.fetch +// CHECK: %[[OUT:.*]], %[[CTL:.*]] = tf_executor.island wraps "tf.Const"() {{{.*value = dense<1> : tensor<1xi64>.*}}} +// CHECK: %[[CTL_0:.*]] = tf_executor.island(%[[CTL]]) wraps "tf.NoOp"() : () -> () +// CHECK: tf_executor.fetch %[[CTL_0]] : !tf_executor.control // CHECK: return } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_lifting.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_lifting.mlir index be4d7356923..7a67389b64e 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_lifting.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_lifting.mlir @@ -10,7 +10,7 @@ func.func @decompose_batch_norm(%arg0: tensor<*xf32>) -> (tensor<*xf32>) { // CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<2.49743462E-5> : tensor<2xf32>} : () -> tensor<2xf32> // CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {value = dense<0.999950051> : tensor<2xf32>} : () -> tensor<2xf32> // CHECK: %[[mul:.*]] = "tf.Mul"(%arg0, %[[CONST_0]]) : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> -// CHECK: %[[add:.*]] = "tf.Add"(%[[mul]], %[[CONST]]) : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> +// CHECK: %[[add:.*]] = "tf.AddV2"(%[[mul]], %[[CONST]]) : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> // CHECK-NEXT: return %[[add]] : tensor<*xf32> func.func @not_decompose_batch_norm(%arg0: tensor<*xf32>) -> (tensor<*xf32>) { @@ -29,7 +29,7 @@ func.func @convert_add_to_biasadd(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2 %cst = "tf.Const"() {value = dense<1.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> %cst_0 = "tf.Const"() {value = dense<0.500000e+00> : tensor<2xf32>} : () -> tensor<2xf32> %0 = "tf.Conv2D"(%arg0, %cst) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> - %1 = "tf.Add"(%0, %cst_0) : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> + %1 = "tf.AddV2"(%0, %cst_0) : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> func.return %1 : tensor<1x3x2x2xf32> } // CHECK: func @convert_add_to_biasadd @@ -43,14 +43,14 @@ func.func @not_convert_add_to_biasadd(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3 %cst = "tf.Const"() {value = dense<1.000000e+00> : tensor<2x3x3x3xf32>} : () -> tensor<2x3x3x3xf32> %cst_0 = "tf.Const"() {value = dense<0.500000e+00> : tensor<1x3x2x3xf32>} : () -> tensor<1x3x2x3xf32> %0 = "tf.Conv2D"(%arg0, %cst) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x3xf32>) -> tensor<1x3x2x3xf32> - %1 = "tf.Add"(%0, %cst_0) : (tensor<1x3x2x3xf32>, 
tensor<1x3x2x3xf32>) -> tensor<1x3x2x3xf32> + %1 = "tf.AddV2"(%0, %cst_0) : (tensor<1x3x2x3xf32>, tensor<1x3x2x3xf32>) -> tensor<1x3x2x3xf32> func.return %1 : tensor<1x3x2x3xf32> } // CHECK: func @not_convert_add_to_biasadd // CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<2x3x3x3xf32>} : () -> tensor<2x3x3x3xf32> // CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {value = dense<5.000000e-01> : tensor<1x3x2x3xf32>} : () -> tensor<1x3x2x3xf32> // CHECK-NEXT: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[CONST]]) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x3xf32>) -> tensor<1x3x2x3xf32> -// CHECK-NEXT: %[[ADD:.*]] = "tf.Add"(%[[CONV2D]], %[[CONST_0]]) : (tensor<1x3x2x3xf32>, tensor<1x3x2x3xf32>) -> tensor<1x3x2x3xf32> +// CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[CONV2D]], %[[CONST_0]]) : (tensor<1x3x2x3xf32>, tensor<1x3x2x3xf32>) -> tensor<1x3x2x3xf32> // CHECK-NEXT: return %[[ADD]] : tensor<1x3x2x3xf32> func.func @fuse_conv2d_and_mul(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { @@ -119,7 +119,7 @@ func.func @fuse_conv2d_with_bias_and_add(%arg0: tensor<1x3x4x3xf32>) -> (tensor< %cst_1 = "tf.Const"() {value = dense<0.500000e+00> : tensor<2xf32>} : () -> tensor<2xf32> %0 = "tf.Conv2D"(%arg0, %cst) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> %1 = "tf.BiasAdd"(%0, %cst_0) {data_format = "NHWC"} : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> - %2 = "tf.Add"(%1, %cst_1) : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> + %2 = "tf.AddV2"(%1, %cst_1) : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> func.return %2 : tensor<1x3x2x2xf32> } // CHECK: func @fuse_conv2d_with_bias_and_add @@ -134,7 +134,7 @@ func.func @not_fuse_conv2d_with_bias_and_add(%arg0: tensor<1x3x4x3xf32>, %arg1: %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<2xf32>} : () -> tensor<2xf32> %0 = "tf.Conv2D"(%arg0, %cst) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> %1 = "tf.BiasAdd"(%0, %cst_0) {data_format = "NHWC"} : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> - %2 = "tf.Add"(%1, %arg1) : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> + %2 = "tf.AddV2"(%1, %arg1) : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> func.return %2 : tensor<1x3x2x2xf32> } // CHECK: func @not_fuse_conv2d_with_bias_and_add @@ -142,66 +142,66 @@ func.func @not_fuse_conv2d_with_bias_and_add(%arg0: tensor<1x3x4x3xf32>, %arg1: // CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {value = dense<4.000000e-01> : tensor<2xf32>} : () -> tensor<2xf32> // CHECK-NEXT: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[CONST]]) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> // CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[CONV2D]], %[[CONST_0]]) {data_format = "NHWC"} : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> -// CHECK-NEXT: %[[ADD:.*]] = "tf.Add"(%[[BIASADD]], %arg1) : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> +// 
CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[BIASADD]], %arg1) : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> // CHECK-NEXT: return %[[ADD]] : tensor<1x3x2x2xf32> func.func @match_depthwise_conv2d_and_add(%arg0: tensor<*xf32>) -> (tensor<*xf32>) { %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32> %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<3xf32>} : () -> tensor<3xf32> - %0 = "tf.DepthwiseConv2dNative"(%arg0, %cst) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor<*xf32> - %1 = "tf.Add"(%0, %cst_0) : (tensor<*xf32>, tensor<3xf32>) -> tensor<*xf32> + %0 = "tf.DepthwiseConv2dNative"(%arg0, %cst) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor + %1 = "tf.AddV2"(%0, %cst_0) : (tensor, tensor<3xf32>) -> tensor<*xf32> func.return %1 : tensor<*xf32> } // CHECK: func @match_depthwise_conv2d_and_add // CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32> // CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {value = dense<4.000000e-01> : tensor<3xf32>} : () -> tensor<3xf32> -// CHECK-NEXT: %[[DEPTHWISE_CONV2D:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[CONST]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor<*xf32> -// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[DEPTHWISE_CONV2D]], %[[CONST_0]]) {data_format = "NHWC"} : (tensor<*xf32>, tensor<3xf32>) -> tensor<*xf32> +// CHECK-NEXT: %[[DEPTHWISE_CONV2D:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[CONST]]) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor +// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[DEPTHWISE_CONV2D]], %[[CONST_0]]) {data_format = "NHWC"} : (tensor, tensor<3xf32>) -> tensor<*xf32> // CHECK-NEXT: return %[[BIASADD]] : tensor<*xf32> -func.func @match_depthwise_conv2d_and_mul(%arg0: tensor<*xf32>) -> (tensor<*xf32>) { +func.func @match_depthwise_conv2d_and_mul(%arg0: tensor<*xf32>) -> (tensor) { %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32> %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<3xf32>} : () -> tensor<3xf32> - %0 = "tf.DepthwiseConv2dNative"(%arg0, %cst) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor<*xf32> - %1 = "tf.Mul"(%0, %cst_0) : (tensor<*xf32>, tensor<3xf32>) -> tensor<*xf32> - func.return %1 : tensor<*xf32> + %0 = "tf.DepthwiseConv2dNative"(%arg0, %cst) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor + %1 = "tf.Mul"(%0, %cst_0) : (tensor, tensor<3xf32>) -> tensor + func.return %1 : tensor } // CHECK: func @match_depthwise_conv2d_and_mul // CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<8.000000e-01> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32> -// CHECK-NEXT: %[[DEPTHWISE_CONV2D:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[CONST]]) {data_format = "NHWC", 
dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor<*xf32> -// CHECK-NEXT: return %[[DEPTHWISE_CONV2D]] : tensor<*xf32> +// CHECK-NEXT: %[[DEPTHWISE_CONV2D:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[CONST]]) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor +// CHECK-NEXT: return %[[DEPTHWISE_CONV2D]] : tensor -func.func @match_depthwise_conv2d_with_bias_and_add(%arg0: tensor<*xf32>) -> (tensor<*xf32>) { +func.func @match_depthwise_conv2d_with_bias_and_add(%arg0: tensor<*xf32>) -> (tensor) { %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32> %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<3xf32>} : () -> tensor<3xf32> %cst_1 = "tf.Const"() {value = dense<0.400000e+00> : tensor<3xf32>} : () -> tensor<3xf32> - %0 = "tf.DepthwiseConv2dNative"(%arg0, %cst) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor<*xf32> - %1 = "tf.BiasAdd"(%0, %cst_0) {data_format = "NHWC"} : (tensor<*xf32>, tensor<3xf32>) -> tensor<*xf32> - %2 = "tf.Add"(%1, %cst_1) : (tensor<*xf32>, tensor<3xf32>) -> tensor<*xf32> - func.return %2 : tensor<*xf32> + %0 = "tf.DepthwiseConv2dNative"(%arg0, %cst) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor + %1 = "tf.BiasAdd"(%0, %cst_0) {data_format = "NHWC"} : (tensor, tensor<3xf32>) -> tensor + %2 = "tf.AddV2"(%1, %cst_1) : (tensor, tensor<3xf32>) -> tensor + func.return %2 : tensor } // CHECK: func @match_depthwise_conv2d_with_bias_and_add // CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32> // CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {value = dense<8.000000e-01> : tensor<3xf32>} : () -> tensor<3xf32> -// CHECK-NEXT: %[[DEPTHWISE_CONV2D:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[CONST]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor<*xf32> -// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[DEPTHWISE_CONV2D]], %[[CONST_0]]) {data_format = "NHWC"} : (tensor<*xf32>, tensor<3xf32>) -> tensor<*xf32> -// CHECK-NEXT: return %[[BIASADD]] : tensor<*xf32> +// CHECK-NEXT: %[[DEPTHWISE_CONV2D:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[CONST]]) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor +// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[DEPTHWISE_CONV2D]], %[[CONST_0]]) {data_format = "NHWC"} : (tensor, tensor<3xf32>) -> tensor +// CHECK-NEXT: return %[[BIASADD]] : tensor -func.func @match_depthwise_conv2d_with_bias_and_mul(%arg0: tensor<*xf32>) -> (tensor<*xf32>) { +func.func @match_depthwise_conv2d_with_bias_and_mul(%arg0: tensor<*xf32>) -> (tensor) { %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32> %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<3xf32>} : () -> tensor<3xf32> %cst_1 = "tf.Const"() {value = dense<0.500000e+00> : tensor<3xf32>} : () -> tensor<3xf32> - %0 = 
"tf.DepthwiseConv2dNative"(%arg0, %cst) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor<*xf32> - %1 = "tf.BiasAdd"(%0, %cst_0) {data_format = "NHWC"} : (tensor<*xf32>, tensor<3xf32>) -> tensor<*xf32> - %2 = "tf.Mul"(%1, %cst_1) : (tensor<*xf32>, tensor<3xf32>) -> tensor<*xf32> - func.return %2 : tensor<*xf32> + %0 = "tf.DepthwiseConv2dNative"(%arg0, %cst) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor + %1 = "tf.BiasAdd"(%0, %cst_0) {data_format = "NHWC"} : (tensor, tensor<3xf32>) -> tensor + %2 = "tf.Mul"(%1, %cst_1) : (tensor, tensor<3xf32>) -> tensor + func.return %2 : tensor } // CHECK: func @match_depthwise_conv2d_with_bias_and_mul // CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32> // CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() {value = dense<2.000000e-01> : tensor<3xf32>} : () -> tensor<3xf32> -// CHECK-NEXT: %[[DEPTHWISE_CONV2D:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[CONST]]) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor<*xf32> -// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[DEPTHWISE_CONV2D]], %[[CONST_0]]) {data_format = "NHWC"} : (tensor<*xf32>, tensor<3xf32>) -> tensor<*xf32> -// CHECK-NEXT: return %[[BIASADD]] : tensor<*xf32> +// CHECK-NEXT: %[[DEPTHWISE_CONV2D:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[CONST]]) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor +// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[DEPTHWISE_CONV2D]], %[[CONST_0]]) {data_format = "NHWC"} : (tensor, tensor<3xf32>) -> tensor +// CHECK-NEXT: return %[[BIASADD]] : tensor func.func @lower_einsum(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> { %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ijk,ikm->ijm"}: (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32> @@ -238,3 +238,32 @@ func.func @not_removing_identity_of_returning_value(%arg0: tensor<*xf32>) -> (te // CHECK: func @not_removing_identity_of_returning_value // CHECK: %[[identity:.*]] = "tf.Identity" // CHECK: return %[[identity]] : tensor<*xf32> + +func.func @batch_norm_with_q_dq(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { + %cst = "tf.Const"() {device = "", value = dense<1.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_0 = "tf.Const"() {device = "", value = dense<5.000000e-01> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_1 = "tf.Const"() {device = "", value = dense<5.000000e-01> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %0 = "quantfork.qcast"(%cst_1) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform:f32:3, {0.003937007874015748,0.003937007874015748}>> + %1 = "quantfork.dcast"(%0) : (tensor<2x3x3x2x!quant.uniform:f32:3, {0.003937007874015748,0.003937007874015748}>>) -> tensor<2x3x3x2xf32> + %2 = "quantfork.qcast"(%arg0) : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> + %3 = "quantfork.dcast"(%2) : (tensor<1x3x4x3x!quant.uniform>) -> tensor<1x3x4x3xf32> + %4 = "tf.Conv2D"(%3, %1) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = 
"SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> + %y, %batch_mean, %batch_variance, %reserve_space_1, %reserve_space_2, %reserve_space_3 = "tf.FusedBatchNormV3"(%4, %cst, %cst_0, %cst, %cst_0) {data_format = "NHWC", device = "", epsilon = 9.99999974E-5 : f32, exponential_avg_factor = 1.000000e+00 : f32, is_training = false} : (tensor<1x3x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> (tensor<1x3x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<*xf32>) + %5 = "tf.Relu6"(%y) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> + %6 = "quantfork.qcast"(%5) : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2x!quant.uniform:f32:3, {0.0026771653824903836:-60,0.0032283464285332388:-28}>> + %7 = "quantfork.dcast"(%6) : (tensor<1x3x2x2x!quant.uniform:f32:3, {0.0026771653824903836:-60,0.0032283464285332388:-28}>>) -> tensor<1x3x2x2xf32> + %8 = "tf.Identity"(%7) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> + %9 = "tf.Identity"(%8) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> + return %9 : tensor<1x3x2x2xf32> +} + +// CHECK: func @batch_norm_with_q_dq +// CHECK-DAG: %[[cst:.*]] = "tf.Const"() {value = dense<0.707036077> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> +// CHECK-DAG: %[[cst_0:.*]] = "tf.Const"() {value = dense<-0.914072155> : tensor<2xf32>} : () -> tensor<2xf32> +// CHECK: %[[q_input:.*]] = "quantfork.qcast"(%arg0) : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK: %[[dq_input:.*]] = "quantfork.dcast"(%[[q_input]]) : (tensor<1x3x4x3x!quant.uniform>) -> tensor<1x3x4x3xf32> +// CHECK: %[[q_weight:.*]] = "quantfork.qcast"(%[[cst]]) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform:f32:3, {0.005567213212411235,0.005567213212411235}>> +// CHECK: %[[dq_weight:.*]] = "quantfork.dcast"(%[[q_weight]]) : (tensor<2x3x3x2x!quant.uniform:f32:3, {0.005567213212411235,0.005567213212411235}>>) -> tensor<2x3x3x2xf32> +// CHECK: %[[conv:.*]] = "tf.Conv2D"(%[[dq_input]], %[[dq_weight]]) +// CHECK: %[[bias:.*]] = "tf.BiasAdd"(%[[conv]], %[[cst_0]]) {data_format = "NHWC"} +// CHECK: %[[relu6:.*]] = "tf.Relu6"(%[[bias]]) \ No newline at end of file diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_quantize.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_quantize.mlir index 784cb959a31..e46fd9ece88 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_quantize.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_quantize.mlir @@ -1,7 +1,5 @@ // RUN: tf-quant-opt %s -split-input-file -quant-prepare-quantize | FileCheck %s -// ----- - module { func.func @same_scale_test(%arg0: tensor<*xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[-1, 144]> : tensor<2xi32> @@ -26,10 +24,8 @@ module { func.func private @composite_matmul_with_bias_fn_1(%a: tensor<*xf32>, %b: tensor<*xf32>, %c: tensor<*xf32>) -> tensor<*xf32> { func.return %a: tensor<*xf32> } -} // CHECK-LABEL: same_scale_test - // CHECK: %[[maxpool:.*]] = "tf.MaxPool" // CHECK: %[[q1:.*]] = "quantfork.qcast"(%[[maxpool]]) // CHECK-SAME: quant.uniform @@ -42,5 +38,5 @@ module { // CHECK-SAME: quant.uniform // CHECK: "tf.PartitionedCall"(%[[dq2]] // CHECK-SAME: f = @composite_matmul_with_bias_fn_1 +} -// ----- diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_quantize_drq.mlir 
b/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_quantize_drq.mlir index f6ba4afb7a0..4da50d4ac91 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_quantize_drq.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_quantize_drq.mlir @@ -1,28 +1,90 @@ -// RUN: tf-quant-opt %s -split-input-file -quant-prepare-quantize-drq | FileCheck %s - -// ----- +// RUN: tf-quant-opt %s -split-input-file -quant-preprocess-op -quant-prepare-quantize-drq | FileCheck %s module { func.func @matmul(%arg0: tensor<1x2x2x3xf32>) -> (tensor<*xf32>) { - %cst_0 = "tf.Const"() {value = dense<0.000000e+00> : tensor<2x3xf32>} : () -> tensor<2x3xf32> - %1 = "tf.PartitionedCall"(%arg0, %cst_0) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn} : (tensor<1x2x2x3xf32>, tensor<2x3xf32>) -> tensor<*xf32> + %cst_0 = "tf.Const"() {value = dense<0.000000e+00> : tensor<2x1024xf32>} : () -> tensor<2x1024xf32> + %1 = "tf.PartitionedCall"(%arg0, %cst_0) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn} : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32> func.return %1: tensor<*xf32> } - func.func private @composite_matmul_fn(%arg0: tensor<1x2x2x3xf32>, %arg1: tensor<2x3xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { - %0 = "tf.MatMul"(%arg0, %arg1) {attr_map = "0:transpose_a,1:transpose_a", device = "", transpose_a = false, transpose_b = false} : (tensor<1x2x2x3xf32>, tensor<2x3xf32>) -> tensor<*xf32> + func.func private @composite_matmul_fn(%arg0: tensor<1x2x2x3xf32>, %arg1: tensor<2x1024xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %0 = "tf.MatMul"(%arg0, %arg1) {attr_map = "0:transpose_a,1:transpose_a", device = "", transpose_a = false, transpose_b = false} : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } // CHECK-LABEL: func @matmul -// CHECK-DAG: %cst = arith.constant dense<0.000000e+00> : tensor<2x3xf32> -// CHECK: %0 = "quantfork.qcast"(%cst) : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32, 3.9370078740157481E-9>> -// CHECK: %1 = "quantfork.dcast"(%0) : (tensor<2x3x!quant.uniform:f32, 3.9370078740157481E-9>>) -> tensor<2x3xf32> -// CHECK: %2 = "tf.PartitionedCall"(%arg0, %1) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn} : (tensor<1x2x2x3xf32>, tensor<2x3xf32>) -> tensor<*xf32> +// CHECK-DAG: %[[CONST:.*]] = arith.constant dense<0.000000e+00> : tensor<2x1024xf32> +// CHECK: %0 = "quantfork.qcast"(%[[CONST]]) : (tensor<2x1024xf32>) -> tensor<2x1024x!quant.uniform:f32, 3.9370078740157481E-9>> +// CHECK: %1 = "quantfork.dcast"(%0) : (tensor<2x1024x!quant.uniform:f32, 3.9370078740157481E-9>>) -> tensor<2x1024xf32> +// CHECK: %2 = "tf.PartitionedCall"(%arg0, %1) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn} : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32> // CHECK: return %2 : tensor<*xf32> // CHECK-LABEL: func private @composite_matmul_fn -// CHECK: %0 = "tf.MatMul"(%arg0, %arg1) {attr_map = "0:transpose_a,1:transpose_a", device = "", transpose_a = false, transpose_b = false} : (tensor<1x2x2x3xf32>, tensor<2x3xf32>) -> tensor<*xf32> +// CHECK: %0 = "tf.MatMul"(%arg0, %arg1) {attr_map = "0:transpose_a,1:transpose_a", device = "", transpose_a = false, transpose_b = false} : 
(tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32> +// CHECK: return %0 : tensor<*xf32> +} + +// ----- + +module { + func.func @conv2d(%arg0: tensor<1x3x4x3xf32>) -> (tensor<*xf32>) { + %cst_0 = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_1 = "tf.Const"() {value = dense<3.000000e+00> : tensor<2x3x3x512xf32>} : () -> tensor<2x3x3x512xf32> + %1 = "tf.PartitionedCall"(%arg0, %cst_1) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_conv2d_fn_1} : (tensor<1x3x4x3xf32>, tensor<2x3x3x512xf32>) -> tensor<*xf32> + %2 = "tf.BiasAdd"(%1, %cst_0) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + func.return %2: tensor<*xf32> + } + func.func private @composite_conv2d_fn_1(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x512xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %0 = "tf.Conv2D"(%arg0, %arg1) {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x512xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> + } + +// CHECK-LABEL: func @conv2d +// CHECK-DAG: %[[CONST_0:.*]] = arith.constant dense<0.000000e+00> : tensor<2xf32> +// CHECK-DAG: %[[CONST_1:.*]] = arith.constant dense<3.000000e+00> : tensor<2x3x3x512xf32> +// CHECK: %0 = "quantfork.qcast"(%[[CONST_1]]) : (tensor<2x3x3x512xf32>) -> tensor<2x3x3x512x!quant.uniform:f32, 0.023622047244094488>> +// CHECK: %1 = "quantfork.dcast"(%0) : (tensor<2x3x3x512x!quant.uniform:f32, 0.023622047244094488>>) -> tensor<2x3x3x512xf32> +// CHECK: %2 = "tf.PartitionedCall"(%arg0, %1) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_conv2d_fn_1} : (tensor<1x3x4x3xf32>, tensor<2x3x3x512xf32>) -> tensor<*xf32> +// CHECK: %3 = "tf.BiasAdd"(%2, %[[CONST_0]]) +// CHECK: return %3 : tensor<*xf32> + +// CHECK-LABEL: func private @composite_conv2d_fn_1 +// CHECK: %0 = "tf.Conv2D"(%arg0, %arg1) // CHECK: return %0 : tensor<*xf32> } // ----- + +module { + func.func @depthwise_conv(%arg0: tensor<1x3x4x512xf32>) -> (tensor<*xf32>) { + %cst_0 = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_1 = "tf.Const"() {value = dense<3.000000e+00> : tensor<2x3x3x512xf32>} : () -> tensor<2x3x3x512xf32> + %0 = "tf.PartitionedCall"(%arg0, %cst_1) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_depthwise_conv2d_fn} : (tensor<1x3x4x512xf32>, tensor<2x3x3x512xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %cst_0) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + func.return %1: tensor<*xf32> + } + func.func private @composite_depthwise_conv2d_fn(%arg0: tensor<1x3x4x512xf32>, %arg1: tensor<2x3x3x512xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) { + attr_map = "0:strides,1:padding,2:explicit_paddings,3:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1] + } : (tensor<1x3x4x512xf32>, tensor<2x3x3x512xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> + } + +// CHECK-LABEL: func @depthwise_conv +// CHECK-DAG: %[[CONST_0:.*]] = arith.constant dense<0.000000e+00> : 
tensor<2xf32> +// CHECK-DAG: %[[CONST_1:.*]] = arith.constant dense<3.000000e+00> : tensor<2x3x1x1536xf32> +// CHECK: %0 = "quantfork.qcast"(%[[CONST_1]]) : (tensor<2x3x1x1536xf32>) -> tensor<2x3x1x1536x!quant.uniform:f32, 0.023622047244094488>> +// CHECK: %1 = "quantfork.dcast"(%0) : (tensor<2x3x1x1536x!quant.uniform:f32, 0.023622047244094488>>) -> tensor<2x3x1x1536xf32> +// CHECK: %2 = "tf.PartitionedCall"(%arg0, %1) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_depthwise_conv2d_fn_0} : (tensor<1x3x4x512xf32>, tensor<2x3x1x1536xf32>) -> tensor<*xf32> +// CHECK: %3 = "tf.BiasAdd"(%2, %[[CONST_0]]) +// CHECK: return %3 : tensor<*xf32> + +// CHECK-LABEL: func private @composite_depthwise_conv2d_fn( +// CHECK-SAME: %arg0: tensor<1x3x4x512xf32>, +// CHECK-SAME: %arg1: tensor<2x3x3x512xf32>) + +// CHECK-LABEL: func private @composite_depthwise_conv2d_fn_0( +// CHECK-SAME: %arg0: tensor<1x3x4x512xf32>, +// CHECK-SAME: %arg1: tensor<2x3x1x1536xf32>) +// CHECK: %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {attr_map = "0:strides,1:padding,2:explicit_paddings,3:dilations", data_format = "NHWC", device = "", +// CHECK: return %0 : tensor<*xf32> +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_quantize_drq_per_channel.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_quantize_drq_per_channel.mlir new file mode 100644 index 00000000000..f2d80c0bf4e --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_quantize_drq_per_channel.mlir @@ -0,0 +1,90 @@ +// RUN: tf-quant-opt %s -split-input-file -quant-preprocess-op -quant-prepare-quantize-drq='enable-per-channel-quantization=true' | FileCheck %s + +module { + func.func @matmul(%arg0: tensor<1x2x2x3xf32>) -> (tensor<*xf32>) { + %cst_0 = "tf.Const"() {value = dense<0.000000e+00> : tensor<2x1024xf32>} : () -> tensor<2x1024xf32> + %1 = "tf.PartitionedCall"(%arg0, %cst_0) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn} : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32> + func.return %1: tensor<*xf32> + } + func.func private @composite_matmul_fn(%arg0: tensor<1x2x2x3xf32>, %arg1: tensor<2x1024xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %0 = "tf.MatMul"(%arg0, %arg1) {attr_map = "0:transpose_a,1:transpose_a", device = "", transpose_a = false, transpose_b = false} : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> + } + +// CHECK-LABEL: func @matmul +// CHECK-DAG: %[[CONST:.*]] = arith.constant dense<0.000000e+00> : tensor<2x1024xf32> +// CHECK: %0 = "quantfork.qcast"(%[[CONST]]) : (tensor<2x1024xf32>) -> tensor<2x1024x!quant.uniform:f32, 3.9370078740157481E-9>> +// CHECK: %1 = "quantfork.dcast"(%0) : (tensor<2x1024x!quant.uniform:f32, 3.9370078740157481E-9>>) -> tensor<2x1024xf32> +// CHECK: %2 = "tf.PartitionedCall"(%arg0, %1) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn} : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32> +// CHECK: return %2 : tensor<*xf32> + +// CHECK-LABEL: func private @composite_matmul_fn +// CHECK: %0 = "tf.MatMul"(%arg0, %arg1) {attr_map = "0:transpose_a,1:transpose_a", device = "", transpose_a = false, transpose_b = false} : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32> +// CHECK: return %0 : tensor<*xf32> +} + +// ----- + +module { + func.func @conv2d(%arg0: 
tensor<1x3x4x512xf32>) -> (tensor<*xf32>) { + %cst_0 = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_1 = "tf.Const"() {value = dense<3.000000e+00> : tensor<2x3x512x2xf32>} : () -> tensor<2x3x512x2xf32> + %1 = "tf.PartitionedCall"(%arg0, %cst_1) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_conv2d_fn_1} : (tensor<1x3x4x512xf32>, tensor<2x3x512x2xf32>) -> tensor<*xf32> + %2 = "tf.BiasAdd"(%1, %cst_0) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + func.return %2: tensor<*xf32> + } + func.func private @composite_conv2d_fn_1(%arg0: tensor<1x3x4x512xf32>, %arg1: tensor<2x3x512x2xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %0 = "tf.Conv2D"(%arg0, %arg1) {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x512xf32>, tensor<2x3x512x2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> + } + +// CHECK-LABEL: func @conv2d +// CHECK-DAG: %[[CONST_0:.*]] = arith.constant dense<0.000000e+00> : tensor<2xf32> +// CHECK-DAG: %[[CONST_1:.*]] = arith.constant dense<3.000000e+00> : tensor<2x3x512x2xf32> +// CHECK: %0 = "quantfork.qcast"(%[[CONST_1]]) : (tensor<2x3x512x2xf32>) -> tensor<2x3x512x2x!quant.uniform:f32:3, {0.023622047244094488,0.023622047244094488}>> +// CHECK: %1 = "quantfork.dcast"(%0) : (tensor<2x3x512x2x!quant.uniform:f32:3, {0.023622047244094488,0.023622047244094488}>>) -> tensor<2x3x512x2xf32> +// CHECK: %2 = "tf.PartitionedCall"(%arg0, %1) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_conv2d_fn_1} : (tensor<1x3x4x512xf32>, tensor<2x3x512x2xf32>) -> tensor<*xf32> +// CHECK: %3 = "tf.BiasAdd"(%2, %[[CONST_0]]) +// CHECK: return %3 : tensor<*xf32> + +// CHECK-LABEL: func private @composite_conv2d_fn_1 +// CHECK: %0 = "tf.Conv2D"(%arg0, %arg1) +// CHECK: return %0 : tensor<*xf32> +} + +// ----- + +module { + func.func @depthwise_conv(%arg0: tensor<1x3x4x512xf32>) -> (tensor<*xf32>) { + %cst_0 = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_1 = "tf.Const"() {value = dense<3.000000e+00> : tensor<2x3x3x512xf32>} : () -> tensor<2x3x3x512xf32> + %0 = "tf.PartitionedCall"(%arg0, %cst_1) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_depthwise_conv2d_fn} : (tensor<1x3x4x512xf32>, tensor<2x3x3x512xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %cst_0) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + func.return %1: tensor<*xf32> + } + func.func private @composite_depthwise_conv2d_fn(%arg0: tensor<1x3x4x512xf32>, %arg1: tensor<2x3x3x512xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) { + attr_map = "0:strides,1:padding,2:explicit_paddings,3:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1] + } : (tensor<1x3x4x512xf32>, tensor<2x3x3x512xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> + } + +// CHECK-LABEL: func @depthwise_conv +// CHECK-DAG: %[[CONST_0:.*]] = arith.constant dense<0.000000e+00> : tensor<2xf32> +// CHECK-DAG: %[[CONST_1:.*]] = arith.constant dense<3.000000e+00> : 
tensor<2x3x1x1536xf32> +// CHECK: %0 = "quantfork.qcast"(%[[CONST_1]]) : (tensor<2x3x1x1536xf32>) -> tensor<2x3x1x1536x!quant.uniform:f32:3, {0.023622047244094488, +// CHECK: %1 = "quantfork.dcast"(%0) : (tensor<2x3x1x1536x!quant.uniform:f32:3, {0.023622047244094488, +// CHECK: %2 = "tf.PartitionedCall"(%arg0, %1) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_depthwise_conv2d_fn_0} : (tensor<1x3x4x512xf32>, tensor<2x3x1x1536xf32>) -> tensor<*xf32> +// CHECK: %3 = "tf.BiasAdd"(%2, %[[CONST_0]]) +// CHECK: return %3 : tensor<*xf32> + +// CHECK-LABEL: func private @composite_depthwise_conv2d_fn( +// CHECK-SAME: %arg0: tensor<1x3x4x512xf32>, +// CHECK-SAME: %arg1: tensor<2x3x3x512xf32>) + +// CHECK-LABEL: func private @composite_depthwise_conv2d_fn_0( +// CHECK-SAME: %arg0: tensor<1x3x4x512xf32>, +// CHECK-SAME: %arg1: tensor<2x3x1x1536xf32>) +// CHECK: %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {attr_map = "0:strides,1:padding,2:explicit_paddings,3:dilations", data_format = "NHWC", device = "", +// CHECK: return %0 : tensor<*xf32> +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_quantize_ptq.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_quantize_ptq.mlir index a5937136000..1ab40ea3f78 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_quantize_ptq.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_quantize_ptq.mlir @@ -95,11 +95,11 @@ module { // CHECK-DAG: %[[cst_1:.*]] = arith.constant dense<{{.*}}> : tensor<2x3x3x2xf32> // CHECK: %[[q0:.*]] = "quantfork.qcast"(%[[cst]]) {volatile} -// CHECK-SAME: tensor<2x!quant.uniform> +// CHECK-SAME: tensor<2x!quant.uniform> // CHECK: %[[dq0:.*]] = "quantfork.dcast"(%[[q0]]) // CHECK: %[[q1:.*]] = "quantfork.qcast"(%[[cst_1]]) {volatile} -// CHECK-SAME: tensor<2x3x3x2x!quant.uniform:f32:3, {0.075176584439014829,0.072960192762960605} +// CHECK-SAME: tensor<2x3x3x2x!quant.uniform:f32, 0.075176584439014829>> // CHECK: %[[dq1:.*]] = "quantfork.dcast"(%[[q1]]) // CHECK: %[[q2:.*]] = "quantfork.qcast"(%arg0) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_quantize_ptq_per_channel.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_quantize_ptq_per_channel.mlir new file mode 100644 index 00000000000..baeb053327b --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/prepare_quantize_ptq_per_channel.mlir @@ -0,0 +1,47 @@ +// RUN: tf-quant-opt %s -split-input-file -quant-prepare-quantize='post-training-quantize=true enable-per-channel-quantization=true' | FileCheck %s + +module { + func.func private @conv_with_bias_and_relu(%arg0: tensor<1x3x4x3xf32>) -> tensor<*xf32> { + %cst = "tf.Const"() {device = "", value = dense<[7.11401462, 7.05456924]> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_0 = "tf.Const"() {device = "", value = dense<[[[[-6.30731344, 5.4962182], [1.80364347, -7.64542675], [-2.11145878, -7.08605719]], [[-9.54062747, -6.14013147], [6.12640238, -4.18223286], [5.05738974, 8.99269962]], [[3.3535192, 0.84816426], [-6.64676809, -7.95477629], [5.81315517, 9.21566581]]], [[[1.38622558, 4.63866329], [9.54742622, -1.43770897], [-7.96835279, 8.99996852]], [[0.989735424, -4.83384752], [-7.27702999, 1.17216611], [9.33735656, 0.728900194]], [[5.1286211, 8.98645591], [1.55008793, -3.85491467], [3.7003777, 9.26594448]]]]> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %0 = "quantfork.stats"(%arg0) {layerStats = 
dense<[1.27501142, 149.824783]> : tensor<2xf32>} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xf32> + %1 = "tf.PartitionedCall"(%0, %cst_0, %cst) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", device = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_10} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>, tensor<2xf32>) -> tensor<*xf32> + %2 = "quantfork.stats"(%1) {layerStats = dense<[0.000000e+00, 6.000000e+00]> : tensor<2xf32>} : (tensor<*xf32>) -> tensor<*xf32> + return %2 : tensor<*xf32> + } + + func.func private @composite_conv2d_with_bias_and_relu6_fn_10(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>, %arg2: tensor<2xf32>) -> tensor<*xf32> attributes {tf.tf_quant.composite_function} { + %0 = "quantfork.stats"(%arg1) {layerStats = dense<[-9.54062747, 9.54742622]> : tensor<2xf32>} : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2xf32> + %1 = "quantfork.stats"(%arg0) {layerStats = dense<[1.27501142, 149.824783]> : tensor<2xf32>} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xf32> + %2 = "tf.Conv2D"(%1, %0) {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + %3 = "quantfork.stats"(%arg2) {layerStats = dense<[7.05456924, 7.11401462]> : tensor<2xf32>} : (tensor<2xf32>) -> tensor<2xf32> + %4 = "quantfork.stats"(%2) {layerStats = dense<[-2795.36523, 4609.57373]> : tensor<2xf32>} : (tensor<*xf32>) -> tensor<*xf32> + %5 = "tf.BiasAdd"(%4, %3) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %6 = "quantfork.stats"(%5) {layerStats = dense<[-2788.31055, 4616.62842]> : tensor<2xf32>} : (tensor<*xf32>) -> tensor<*xf32> + %7 = "tf.Relu6"(%6) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + %8 = "quantfork.stats"(%7) {layerStats = dense<[0.000000e+00, 6.000000e+00]> : tensor<2xf32>} : (tensor<*xf32>) -> tensor<*xf32> + return %8 : tensor<*xf32> + } + +// CHECK-LABEL: conv_with_bias_and_relu +// CHECK-DAG: %[[cst:.*]] = arith.constant dense<[7.11401462, 7.05456924]> : tensor<2xf32> +// CHECK-DAG: %[[cst_1:.*]] = arith.constant dense<{{.*}}> : tensor<2x3x3x2xf32> + +// CHECK: %[[q0:.*]] = "quantfork.qcast"(%[[cst]]) {volatile} +// CHECK-SAME: tensor<2x!quant.uniform> +// CHECK: %[[dq0:.*]] = "quantfork.dcast"(%[[q0]]) + +// CHECK: %[[q1:.*]] = "quantfork.qcast"(%[[cst_1]]) {volatile} +// CHECK-SAME: tensor<2x3x3x2x!quant.uniform:f32:3, {0.075176584439014829,0.072960192762960605}>> +// CHECK: %[[dq1:.*]] = "quantfork.dcast"(%[[q1]]) + +// CHECK: %[[q2:.*]] = "quantfork.qcast"(%arg0) +// CHECK-SAME: tensor<1x3x4x3x!quant.uniform> +// CHECK: %[[dq2:.*]] = "quantfork.dcast"(%[[q2]]) + +// CHECK: %[[call:.*]] = "tf.PartitionedCall"(%[[dq2]], %[[dq1]], %[[dq0]]) +// CHECK-SAME: f = @composite_conv2d_with_bias_and_relu6_fn_10 +// CHECK: %[[q3:.*]] = "quantfork.qcast"(%[[call]]) {volatile} +// CHECK-SAME: tensor<*x!quant.uniform> +// CHECK: %[[dq3:.*]] = "quantfork.dcast"(%[[q3]]) +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/preprocess_op.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/preprocess_op.mlir new file mode 100644 index 00000000000..0ef69f6f6f7 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/preprocess_op.mlir @@ -0,0 +1,39 @@ +// RUN: tf-quant-opt %s -split-input-file -quant-preprocess-op | FileCheck %s + +module { + 
// For UniformQuantized depthwise convolution, the tensor shape should be
+ // transformed from [H,W,C,M] to [H,W,1,CxM].
+ func.func @depthwise_conv(%arg0: tensor<1x3x4x3xf32>) -> (tensor<*xf32>) {
+ %cst_0 = "tf.Const"() {value = dense<0.000000e+00> : tensor<6xf32>} : () -> tensor<6xf32>
+ %cst_1 = "tf.Const"() {value = dense<[[[[3.0, 2.0], [1.0, 0.0],[3.0, 2.0]],[[3.0, 2.0], [1.0, 0.0],[3.0, 2.0]],[[3.0, 2.0], [1.0, 0.0],[3.0, 2.0]]],[[[3.0, 2.0], [1.0, 0.0],[3.0, 2.0]],[[3.0, 2.0], [1.0, 0.0],[3.0, 2.0]],[[3.0, 2.0], [1.0, 0.0],[3.0, 2.0]]]]> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32>
+ %0 = "tf.PartitionedCall"(%arg0, %cst_1) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_depthwise_conv2d_fn} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32>
+ %1 = "tf.BiasAdd"(%0, %cst_0) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<6xf32>) -> tensor<*xf32>
+ func.return %1: tensor<*xf32>
+ }
+ func.func private @composite_depthwise_conv2d_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} {
+ %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {
+ attr_map = "0:strides,1:padding,2:explicit_paddings,3:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1]
+ } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32>
+ return %0 : tensor<*xf32>
+ }
+
+// CHECK-LABEL: func @depthwise_conv
+// CHECK-DAG: %[[CONST_0:.*]] = arith.constant dense<0.000000e+00> : tensor<6xf32>
+// CHECK: %[[CONST_1:.*]] = arith.constant dense
+// CHECK-NOT: tensor<2x3x3x2xf32>
+// CHECK-SAME: tensor<2x3x1x6xf32>
+// CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %[[CONST_1:.*]]) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_depthwise_conv2d_fn_0} : (tensor<1x3x4x3xf32>, tensor<2x3x1x6xf32>) -> tensor<*xf32>
+// CHECK: %[[BIAS_0:.*]] = "tf.BiasAdd"(%[[PARTITIONEDCALL_0]], %[[CONST_0:.*]]) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<6xf32>) -> tensor<*xf32>
+// CHECK: return %[[BIAS_0:.*]] : tensor<*xf32>
+
+// CHECK-LABEL: func private @composite_depthwise_conv2d_fn(
+// CHECK-SAME: %arg0: tensor<1x3x4x3xf32>,
+// CHECK-SAME: %arg1: tensor<2x3x3x2xf32>)
+
+// CHECK-LABEL: func private @composite_depthwise_conv2d_fn_0(
+// CHECK-SAME: %arg0: tensor<1x3x4x3xf32>,
+// CHECK-SAME: %arg1: tensor<2x3x1x6xf32>)
+// CHECK: %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {attr_map = "0:strides,1:padding,2:explicit_paddings,3:dilations", data_format = "NHWC", device = "",
+// CHECK: return %0 : tensor<*xf32>
+}
+
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize.mlir
index 420dcc6ddb3..10bedcff581 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize.mlir
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize.mlir
@@ -67,15 +67,13 @@ func.func @avgpool_test(%arg0: tensor<*xf32>) -> tensor<*xf32> {
 func.return %4 : tensor<*xf32>
 }
-// CHECK-DAG: %[[cst:.*]] = "tf.Const"() {value = dense<5.000000e-01> : tensor} : () -> tensor
 // CHECK: %[[q:.*]] = "quantfork.qcast"(%arg0)
 // CHECK: %[[sc1:.*]] = "quantfork.scast"(%[[q]]) : (tensor<*x!quant.uniform>)
 // CHECK: %[[fcast:.*]] = "tf.Cast"(%[[sc1]]) {Truncate = false} : (tensor<*xi8>) -> tensor<*xf32>
 // CHECK: %[[avgpool_f32:.*]] 
= "tf.AvgPool"(%[[fcast]]) // CHECK-SAME: (tensor<*xf32>) -> tensor<*xf32> -// CHECK: %[[add:.*]] = "tf.AddV2"(%[[avgpool_f32]], %[[cst]]) : (tensor<*xf32>, tensor) -> tensor<*xf32> -// CHECK: %[[floor:.*]] = "tf.Floor"(%[[add]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: %[[icast:.*]] = "tf.Cast"(%[[floor]]) {Truncate = false} : (tensor<*xf32>) -> tensor<*xi8> +// CHECK: %[[round:.*]] = "tf.Round"(%[[avgpool_f32]]) +// CHECK: %[[icast:.*]] = "tf.Cast"(%[[round]]) {Truncate = false} : (tensor<*xf32>) -> tensor<*xi8> // CHECK: %[[sc2:.*]] = "quantfork.scast"(%[[icast]]) // CHECK: %[[dq:.*]] = "quantfork.dcast"(%[[sc2]]) : (tensor<*x!quant.uniform>) // CHECK: return %[[dq]] diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions.mlir index ac4f42d97e4..915aa2eb1c5 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions.mlir @@ -1,4 +1,4 @@ -// RUN: tf-quant-opt %s -split-input-file -quant-insert-quantized-functions -quant-quantize-composite-functions -symbol-dce | FileCheck %s +// RUN: tf-quant-opt %s -split-input-file -quant-insert-quantized-functions -quant-quantize-composite-functions | FileCheck %s module { func.func @conv(%arg0: tensor<1x2x2x3xf32>) -> (tensor<*xf32>, tensor<*xf32>) { @@ -66,6 +66,17 @@ module { // CHECK-SAME: (%arg0: tensor<1x2x2x3xi8>, %arg1: tensor<2x2x3x2xi8>, %arg2: tensor<2xi32>, %arg3: tensor, %arg4: tensor, %arg5: tensor<2xf32>, %arg6: tensor<2xi32>, %arg7: tensor<2xf32>, %arg8: tensor<2xi32>, %arg9: tensor, %arg10: tensor) -> tensor<*xi8> // CHECK: %[[CONV2D_0:.*]] = "tf.Conv2D" // CHECK-SAME: {dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} + +// CHECK: -------- Quantization Summary -------- +// CHECK: Number of quantized layers in the model +// CHECK: -------------------------------- +// CHECK: Name Count/Total +// CHECK: ================================ +// CHECK: Conv2D 1/2 + +// CHECK: Number of quantized layers with quantized outputs: 1/1 +// CHECK: Number of quantize layers added: 1 +// CHECK: Number of dequantize layers added: 1 } // ----- @@ -105,6 +116,17 @@ module { // CHECK-SAME: (%arg0: tensor<1x2x2x3xi8>, %arg1: tensor<2x2x3x2xi8>, %arg2: tensor<2xi32>, %arg3: tensor, %arg4: tensor, %arg5: tensor<2xf32>, %arg6: tensor<2xi32>, %arg7: tensor<2xf32>, %arg8: tensor<2xi32>, %arg9: tensor, %arg10: tensor) -> tensor<*xi8> // CHECK: %[[CONV2D_0:.*]] = "tf.Conv2D" // CHECK-SAME: {dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} + +// CHECK: -------- Quantization Summary -------- +// CHECK: Number of quantized layers in the model +// CHECK: -------------------------------- +// CHECK: Name Count/Total +// CHECK: ================================ +// CHECK: Conv2D 1/1 + +// CHECK: Number of quantized layers with quantized outputs: 1/1 +// CHECK: Number of quantize layers added: 1 +// CHECK: Number of dequantize layers added: 1 } // ----- @@ -131,7 +153,6 @@ module { } // CHECK-LABEL: func @conv_with_avgpool -// CHECK-DAG: %[[cst:.*]] = "tf.Const"() {value = dense<5.000000e-01> : tensor} : () -> tensor // CHECK: %[[quantize:.*]] = "tf.PartitionedCall"(%arg0 // CHECK-SAME: f = @quantize_i8 // CHECK: %[[conv_quant:.*]] = "tf.PartitionedCall"(%[[quantize]] @@ 
-139,11 +160,42 @@ module { // CHECK-SAME: (tensor<1x2x2x3xi8>, tensor<2x2x3x2xi8>, tensor<2xi32>, tensor, tensor, tensor<2xf32>, tensor<2xi32>, tensor<2xf32>, tensor<2xi32>, tensor, tensor) -> tensor<*xi8> // CHECK: %[[cast_1:.*]] = "tf.Cast"(%[[conv_quant]]) {Truncate = false} : (tensor<*xi8>) -> tensor<*xf32> // CHECK: %[[avgpool:.*]] = "tf.AvgPool"(%[[cast_1]]) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: %[[add:.*]] = "tf.AddV2"(%[[avgpool]], %[[cst]]) : (tensor<*xf32>, tensor) -> tensor<*xf32> -// CHECK: %[[floor:.*]] = "tf.Floor"(%[[add]]) : (tensor<*xf32>) -> tensor<*xf32> -// CHECK: %[[cast_2:.*]] = "tf.Cast"(%[[floor]]) {Truncate = false} : (tensor<*xf32>) -> tensor<*xi8> +// CHECK: %[[round:.*]] = "tf.Round"(%[[avgpool]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: %[[cast_2:.*]] = "tf.Cast"(%[[round]]) {Truncate = false} : (tensor<*xf32>) -> tensor<*xi8> // CHECK: %[[dequantize:.*]] = "tf.PartitionedCall"(%[[cast_2]] // CHECK-SAME: f = @dequantize_i8 // CHECK: return %[[dequantize]] +// CHECK: -------- Quantization Summary -------- +// CHECK: Number of quantized layers in the model +// CHECK: -------------------------------- +// CHECK: Name Count/Total +// CHECK: ================================ +// CHECK: Conv2D 1/1 + +// CHECK: Number of quantized layers with quantized outputs: 1/1 +// CHECK: Number of quantize layers added: 1 +// CHECK: Number of dequantize layers added: 1 +} + + +// ----- + +module { + func.func @float_einsum(%arg0: tensor, %arg1: tensor<32x2x16xf32>) -> (tensor) { + %0 = "tf.Einsum"(%arg0, %arg1) {equation = "abc,cde->abde"} : (tensor, tensor<32x2x16xf32>) -> tensor + func.return %0 : tensor + } + +// CHECK-LABEL: func @float_einsum +// CHECK: -------- Quantization Summary -------- +// CHECK: Number of quantized layers in the model +// CHECK: -------------------------------- +// CHECK: Name Count/Total +// CHECK: ================================ +// CHECK: Einsum 0/1 + +// CHECK: Number of quantized layers with quantized outputs: 0/0 +// CHECK: Number of quantize layers added: 0 +// CHECK: Number of dequantize layers added: 0 } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions_drq.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions_drq.mlir index 49bd9ffceab..b78a891e5b2 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions_drq.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions_drq.mlir @@ -1,39 +1,137 @@ -// RUN: tf-quant-opt %s -split-input-file -quant-insert-quantized-functions='quantization-method=drq target-opset=UNIFORM_QUANTIZED' -quant-quantize-composite-functions='quantization-method=drq' -symbol-dce | FileCheck %s +// RUN: tf-quant-opt %s -split-input-file -quant-insert-quantized-functions='quantization-method=drq target-opset=UNIFORM_QUANTIZED' -quant-quantize-composite-functions='quantization-method=drq target-opset=UNIFORM_QUANTIZED' -symbol-dce | FileCheck %s module { - func.func @matmul(%arg0: tensor<2x512xf32>) -> (tensor<*xf32>) { - %cst_0 = "tf.Const"() {value = dense<0.000000e+00> : tensor<512x512xf32>} : () -> tensor<512x512xf32> - %1 = "tf.PartitionedCall"(%arg0, %cst_0) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_1} : (tensor<2x512xf32>, tensor<512x512xf32>) -> tensor<*xf32> + // TODO(b/260020937): 
Support transpose_a, transpose_b for matmul. + func.func @matmul(%arg0: tensor<2x12xf32>) -> (tensor<*xf32>) { + %cst_0 = "tf.Const"() {value = dense<0.000000e+00> : tensor<12x2xf32>} : () -> tensor<12x2xf32> + %1 = "tf.PartitionedCall"(%arg0, %cst_0) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_1} : (tensor<2x12xf32>, tensor<12x2xf32>) -> tensor<*xf32> func.return %1: tensor<*xf32> } - func.func private @composite_matmul_fn_1(%arg0: tensor<2x512xf32>, %arg1: tensor<512x512xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { - %0 = "tf.MatMul"(%arg0, %arg1) {attr_map = "0:transpose_a,1:transpose_b", device = "", transpose_a = false, transpose_b = false} : (tensor<2x512xf32>, tensor<512x512xf32>) -> tensor<*xf32> + func.func private @composite_matmul_fn_1(%arg0: tensor<2x12xf32>, %arg1: tensor<12x2xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %0 = "tf.MatMul"(%arg0, %arg1) {attr_map = "0:transpose_a,1:transpose_b", device = "", transpose_a = false, transpose_b = false} : (tensor<2x12xf32>, tensor<12x2xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } // CHECK-LABEL: func @matmul -// CHECK-DAG: %[[q_w:.*]] = "tf.Const"() -// CHECK-SAME: tensor<512x512x!tf_type.qint8>} : () -> tensor<512x512x!tf_type.qint8> +// CHECK-DAG: %[[q_w:.*]] = "tf.Const"() {value = #tf_type : tensor} : () -> tensor // CHECK-DAG: %[[zp:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor -// CHECK: %0 = "tf.PartitionedCall"(%arg0, %[[q_w]], %[[scale]], %[[zp]]) -// CHECK-SAME: f = @quantized_matmul_fn +// CHECK: %0 = "tf.PartitionedCall"(%arg0, %[[q_w]], %[[scale]], %[[zp]]) {config = "", config_proto = "", executor_type = "", +// CHECK-SAME: f = @quantized_matmul_fn_0} : (tensor<2x12xf32>, tensor<12x2x!tf_type.qint8>, tensor, tensor) -> tensor<*xf32> // CHECK-LABEL: func private @quantized_matmul_fn_0 -//CHECK: %0 = "tf.UniformQuantizedDotHybrid"(%arg0, %arg1, %arg2, %arg3) +// CHECK: %0 = "tf.UniformQuantizedDotHybrid"(%arg0, %arg1, %arg2, %arg3) +// CHECK-SAME: rhs_quantization_axis = -1 : i64 +// CHECK-SAME: rhs_quantization_max_val = 127 : i64 +// CHECK-SAME: rhs_quantization_min_val = -127 : i64 + } // ----- module { - func.func @conv(%arg0: tensor<1x2x2x3xf32>) -> (tensor<*xf32>) { - %weight = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> - %conv = "tf.Conv2D"(%arg0, %weight) {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x2x2x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> - func.return %conv : tensor<*xf32> + func.func @conv(%arg0: tensor<1x2x2x3xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + %weight = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %1 = "tf.PartitionedCall"(%arg0, %weight) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_conv2d_fn_1} : (tensor<1x2x2x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + %2 = "tf.PartitionedCall"(%arg0, %weight) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_conv2d_fn_2} : (tensor<1x2x2x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + func.return %1, %2 : tensor<*xf32>, tensor<*xf32> + } + + func.func private @composite_conv2d_fn_1(%arg0: 
tensor<1x2x2x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %conv = "tf.Conv2D"(%arg0, %arg1) {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", data_format = "NHWC", device = "", dilations = [1, 2, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x2x2x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + return %conv : tensor<*xf32> + } + + func.func private @composite_conv2d_fn_2(%arg0: tensor<1x2x2x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %conv = "tf.Conv2D"(%arg0, %arg1) {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", data_format = "NHWC", device = "", dilations = [1, 2, 2, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x2x2x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + return %conv : tensor<*xf32> } // CHECK-LABEL: func @conv -// CHECK-DAG: %[[w:.*]] = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> -// CHECK: %[[conv:.*]] = "tf.Conv2D"(%arg0, %[[w]]) {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x2x2x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> +// CHECK-DAG: %[[q_w:.*]] = "tf.Const"() {value = #tf_type : tensor} : () -> tensor +// CHECK-DAG: %[[w_zp:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: %[[quantize_1:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w]], %[[w_scale]], %[[w_zp]]) {config = "", config_proto = "", executor_type = "", f = @quantized_conv2d_fn_1} : (tensor<1x2x2x3xf32>, tensor<2x3x3x2x!tf_type.qint8>, tensor, tensor) -> tensor<*xf32> +// CHECK: %[[quantize_2:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w]], %[[w_scale]], %[[w_zp]]) {config = "", config_proto = "", executor_type = "", f = @quantized_conv2d_fn_0} : (tensor<1x2x2x3xf32>, tensor<2x3x3x2x!tf_type.qint8>, tensor, tensor) -> tensor<*xf32> +// CHECK: return %[[quantize_1]], %[[quantize_2]] + +// CHECK-LABEL: func private @quantized_conv2d_fn_0 +// CHECK: %[[CONV2D_0:.*]] = "tf.UniformQuantizedConvolutionHybrid" +// CHECK-SAME: batch_group_count = 1 : i64 +// CHECK-SAME: dimension_numbers = "\10\03\1A\02\01\02 \02(\032\02\00\01@\03J\02\01\02" +// CHECK-SAME: explicit_padding = [] +// CHECK-SAME: feature_group_count = 1 : i64 +// CHECK-SAME: lhs_dilation = [1, 1] +// CHECK-SAME: padding = "VALID" +// CHECK-SAME: rhs_dilation = [2, 2] +// CHECK-SAME: rhs_quantization_axis = -1 : i64 +// CHECK-SAME: rhs_quantization_max_val = 127 : i64 +// CHECK-SAME: rhs_quantization_min_val = -127 : i64 +// CHECK-SAME: window_strides = [1, 2] +// CHECK-SAME: (tensor<1x2x2x3xf32>, tensor<2x3x3x2x!tf_type.qint8>, tensor, tensor) -> tensor<*xf32> + +// CHECK-LABEL: func private @quantized_conv2d_fn_1 +// CHECK: %[[CONV2D_0:.*]] = "tf.UniformQuantizedConvolutionHybrid" +// CHECK-SAME: padding = "SAME" + +} + +// ----- + +module { + func.func @depthwise_conv(%arg0: tensor<1x3x4x3xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + %cst_0 = "tf.Const"() {value = dense<0.000000e+00> : tensor<3xf32>} : () -> tensor<3xf32> + %cst_1 = "tf.Const"() {value = dense<3.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32> + %cst_2 = "tf.Const"() {value = dense<3.000000e+00> : 
tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %0 = "tf.PartitionedCall"(%arg0, %cst_1) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_depthwise_conv2d_fn} : (tensor<1x3x4x3xf32>, tensor<2x3x3x1xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %cst_0) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<3xf32>) -> tensor<*xf32> + %2 = "tf.PartitionedCall"(%arg0, %cst_2) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_depthwise_conv2d_fn_1} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + func.return %1, %2: tensor<*xf32>, tensor<*xf32> + } + func.func private @composite_depthwise_conv2d_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x1xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) { + attr_map = "0:strides,1:padding,2:explicit_paddings,3:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1] + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x1xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> + } + func.func private @composite_depthwise_conv2d_fn_1(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) { + attr_map = "0:strides,1:padding,2:explicit_paddings,3:dilations", data_format = "NHWC", device = "", dilations = [1, 2, 2, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 2, 1] + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> + } + +// CHECK-LABEL: func @depthwise_conv +// CHECK-DAG: %[[bias:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3xf32>} : () -> tensor<3xf32> +// CHECK-DAG: %[[q_w1:.*]] = "tf.Const"() {value = #tf_type tensor<2x3x1x3x!tf_type.qint8> +// CHECK-DAG: %[[q_w2:.*]] = "tf.Const"() {value = #tf_type tensor<2x3x1x6x!tf_type.qint8> +// CHECK-DAG: %[[w_scale:.*]] = "tf.Const"() {value = dense<0.0236220472> : tensor} : () -> tensor +// CHECK-DAG: %[[w_zp:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + +// CHECK: %[[quantize_1:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w1]], %[[w_scale]], %[[w_zp]]) {config = "", config_proto = "", executor_type = "", f = @quantized_depthwise_conv2d_fn_1} : (tensor<1x3x4x3xf32>, tensor<2x3x1x3x!tf_type.qint8>, tensor, tensor) -> tensor<*xf32> +// CHECK: %[[quantize_1_add:.*]] = "tf.BiasAdd"(%[[quantize_1]], %[[bias]]) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<3xf32>) -> tensor<*xf32> +// CHECK: %[[quantize_2:.*]] = "tf.PartitionedCall"(%arg0, %[[q_w2]], %[[w_scale]], %[[w_zp]]) {config = "", config_proto = "", executor_type = "", f = @quantized_depthwise_conv2d_fn_0} : (tensor<1x3x4x3xf32>, tensor<2x3x1x6x!tf_type.qint8>, tensor, tensor) -> tensor<*xf32> +// CHECK: return %[[quantize_1_add]], %[[quantize_2]] + +// CHECK-LABEL: func private @quantized_depthwise_conv2d_fn_0 +// CHECK: %[[CONV2D_0:.*]] = "tf.UniformQuantizedConvolutionHybrid" +// CHECK-SAME: batch_group_count = 1 : i64, +// CHECK-SAME: dimension_numbers = "\10\03\1A\02\01\02 \02(\032\02\00\01@\03J\02\01\02" +// CHECK-SAME: explicit_padding = [], +// CHECK-SAME: feature_group_count = 3 : i64, +// CHECK-SAME: lhs_dilation = [1, 1], +// CHECK-SAME: padding = "VALID", +// CHECK-SAME: rhs_dilation = [2, 2], +// CHECK-SAME: rhs_quantization_axis = -1 : i64, +// CHECK-SAME: 
rhs_quantization_max_val = 127 : i64, +// CHECK-SAME: rhs_quantization_min_val = -127 : i64, +// CHECK-SAME: window_strides = [1, 2] +// CHECK-SAME: (tensor<1x3x4x3xf32>, tensor<2x3x1x6x!tf_type.qint8>, tensor, tensor) -> tensor<*xf32> + +// CHECK-LABEL: func private @quantized_depthwise_conv2d_fn_1 +// CHECK: %[[CONV2D_0:.*]] = "tf.UniformQuantizedConvolutionHybrid" +// CHECK-SAME: padding = "SAME" } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions_weight_only.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions_weight_only.mlir new file mode 100644 index 00000000000..c6a5ff74792 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions_weight_only.mlir @@ -0,0 +1,101 @@ +// RUN: tf-quant-opt %s -split-input-file -quant-insert-quantized-functions='quantization-method=weight_only target-opset=XLA' -quant-quantize-composite-functions='quantization-method=weight_only target-opset=XLA' -symbol-dce | FileCheck %s + +module { + // TODO(b/260020937): Support transpose_a, transpose_b for matmul. + func.func @matmul(%arg0: tensor<2x12xf32>) -> (tensor<*xf32>) { + %cst_0 = "tf.Const"() {value = dense<0.000000e+00> : tensor<12x2xf32>} : () -> tensor<12x2xf32> + %1 = "tf.PartitionedCall"(%arg0, %cst_0) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_1} : (tensor<2x12xf32>, tensor<12x2xf32>) -> tensor<*xf32> + func.return %1: tensor<*xf32> + } + func.func private @composite_matmul_fn_1(%arg0: tensor<2x12xf32>, %arg1: tensor<12x2xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %0 = "tf.MatMul"(%arg0, %arg1) {attr_map = "0:transpose_a,1:transpose_b", device = "", transpose_a = false, transpose_b = false} : (tensor<2x12xf32>, tensor<12x2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> + } +} + +// CHECK-LABEL: func @matmul +// CHECK-DAG: %[[q_w:.*]] = "tf.Const"() {value = dense<0> : tensor<12x2xi8>} : () -> tensor<12x2xi8> +// CHECK-DAG: %[[scale:.*]] = "tf.Const"() {value = dense<3.93700805E-9> : tensor} : () -> tensor +// CHECK-DAG: %[[zp:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: %[[dq_w:.*]] = "tf.PartitionedCall"(%[[q_w]], %[[scale]], %[[zp]]) {config = "", config_proto = "", executor_type = "", +// CHECK-SAME: f = @dequantize_i8} : (tensor<12x2xi8>, tensor, tensor) -> tensor<12x2xf32> +// CHECK: %[[quantize:.*]] = "tf.PartitionedCall"(%arg0, %[[dq_w]]) {config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_1} : +// CHECK-SAME: (tensor<2x12xf32>, tensor<12x2xf32>) -> tensor<*xf32> +// CHECK: return %[[quantize]] + +// ----- + +module { + func.func @conv(%arg0: tensor<1x2x2x3xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + %weight = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %1 = "tf.PartitionedCall"(%arg0, %weight) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_conv2d_fn_1} : (tensor<1x2x2x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + %2 = "tf.PartitionedCall"(%arg0, %weight) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_conv2d_fn_2} : (tensor<1x2x2x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + func.return %1, %2 : tensor<*xf32>, tensor<*xf32> + } + + func.func private @composite_conv2d_fn_1(%arg0: tensor<1x2x2x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> 
tensor<*xf32> attributes {tf_quant.composite_function} { + %conv = "tf.Conv2D"(%arg0, %arg1) {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", data_format = "NHWC", device = "", dilations = [1, 2, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x2x2x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + return %conv : tensor<*xf32> + } + + func.func private @composite_conv2d_fn_2(%arg0: tensor<1x2x2x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %conv = "tf.Conv2D"(%arg0, %arg1) {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", data_format = "NHWC", device = "", dilations = [1, 2, 2, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x2x2x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + return %conv : tensor<*xf32> + } + +// CHECK-LABEL: func @conv +// CHECK-DAG: %[[q_w:.*]] = "tf.Const"() +// CHECK-DAG: %[[scale:.*]] = "tf.Const"() +// CHECK-DAG: %[[zp:.*]] = "tf.Const"() +// CHECK: %[[dq_w:.*]] = "tf.PartitionedCall"(%[[q_w]], %[[scale]], %[[zp]]) {config = "", config_proto = "", executor_type = "", +// CHECK-SAME: f = @dequantize_i8} : (tensor<2x3x3x2xi8>, tensor, tensor) -> tensor<2x3x3x2xf32> +// CHECK: %[[quantize_1:.*]] = "tf.PartitionedCall"(%arg0, %[[dq_w]]) {config = "", config_proto = "", executor_type = "", f = @composite_conv2d_fn_1} : +// CHECK-SAME: (tensor<1x2x2x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> +// CHECK: %[[quantize_2:.*]] = "tf.PartitionedCall"(%arg0, %[[dq_w]]) {config = "", config_proto = "", executor_type = "", f = @composite_conv2d_fn_2} : +// CHECK-SAME: (tensor<1x2x2x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> +// CHECK: return %[[quantize_1]], %[[quantize_2]] + +} + +// ----- + +module { + func.func @depthwise_conv(%arg0: tensor<1x3x4x3xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + %cst_0 = "tf.Const"() {value = dense<0.000000e+00> : tensor<3xf32>} : () -> tensor<3xf32> + %cst_1 = "tf.Const"() {value = dense<3.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32> + %cst_2 = "tf.Const"() {value = dense<3.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %0 = "tf.PartitionedCall"(%arg0, %cst_1) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_depthwise_conv2d_fn} : (tensor<1x3x4x3xf32>, tensor<2x3x3x1xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %cst_0) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<3xf32>) -> tensor<*xf32> + %2 = "tf.PartitionedCall"(%arg0, %cst_2) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_depthwise_conv2d_fn_1} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + func.return %1, %2: tensor<*xf32>, tensor<*xf32> + } + func.func private @composite_depthwise_conv2d_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x1xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) { + attr_map = "0:strides,1:padding,2:explicit_paddings,3:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1] + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x1xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> + } + func.func private @composite_depthwise_conv2d_fn_1(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> 
tensor<*xf32> attributes {tf_quant.composite_function} { + %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) { + attr_map = "0:strides,1:padding,2:explicit_paddings,3:dilations", data_format = "NHWC", device = "", dilations = [1, 2, 2, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 2, 1] + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> + } + +// CHECK-LABEL: func @depthwise_conv +// CHECK-DAG: %[[q_w1:.*]] = "tf.Const"() {value = dense<127> : tensor<2x3x3x1xi8>} +// CHECK-DAG: %[[q_w2:.*]] = "tf.Const"() {value = dense<127> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2xi8> +// CHECK-DAG: %[[scale:.*]] = "tf.Const"() {value = dense<0.0236220472> : tensor} : () -> tensor +// CHECK-DAG: %[[zp:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK-DAG: %[[bias:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<3xf32>} +// CHECK: %[[dq_w1:.*]] = "tf.PartitionedCall"(%[[q_w1]], %[[scale]], %[[zp]]) {config = "", config_proto = "", executor_type = "", +// CHECK-SAME: f = @dequantize_i8} : (tensor<2x3x3x1xi8>, tensor, tensor) -> tensor<2x3x3x1xf32> +// CHECK: %[[dq_w2:.*]] = "tf.PartitionedCall"(%[[q_w2]], %[[scale]], %[[zp]]) {config = "", config_proto = "", executor_type = "", +// CHECK-SAME: f = @dequantize_i8} : (tensor<2x3x3x2xi8>, tensor, tensor) -> tensor<2x3x3x2xf32> +// CHECK: %[[quantize_1:.*]] = "tf.PartitionedCall"(%arg0, %[[dq_w1]]) {config = "", config_proto = "", executor_type = "", f = @composite_depthwise_conv2d_fn} : +// CHECK-SAME: (tensor<1x3x4x3xf32>, tensor<2x3x3x1xf32>) -> tensor<*xf32> +// CHECK: %[[bias_add:.*]] = "tf.BiasAdd"(%[[quantize_1]], %[[bias]]) +// CHECK: %[[quantize_2:.*]] = "tf.PartitionedCall"(%arg0, %[[dq_w2]]) {config = "", config_proto = "", executor_type = "", f = @composite_depthwise_conv2d_fn_1} : +// CHECK-SAME: (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> +// CHECK: return %[[bias_add]], %[[quantize_2]] +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions_xla.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions_xla.mlir index 0a74c393abc..5ba40e0eb1d 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions_xla.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_composite_functions_xla.mlir @@ -1,4 +1,4 @@ -// RUN: tf-quant-opt %s -split-input-file -quant-insert-quantized-functions -quant-quantize-composite-functions='target-opset=XLA' -symbol-dce | FileCheck %s +// RUN: tf-quant-opt %s -split-input-file -quant-insert-quantized-functions -quant-quantize-composite-functions='target-opset=XLA' | FileCheck %s module { func.func @conv_with_single_layer(%arg0: tensor<1x2x2x3xf32>) -> (tensor<*xf32>) { @@ -33,6 +33,17 @@ module { // CHECK-SAME: (%arg0: tensor<1x2x2x3xi8>, %arg1: tensor<2x2x3x2xi8>, %arg2: tensor<2xi32>, %arg3: tensor, %arg4: tensor, %arg5: tensor<2xf32>, %arg6: tensor<2xi32>, %arg7: tensor<2xf32>, %arg8: tensor<2xi32>, %arg9: tensor, %arg10: tensor) -> tensor<*xf32> // CHECK: %[[CONV2D_0:.*]] = "tf.Conv2D" // CHECK-SAME: {dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} + +// CHECK: -------- Quantization Summary -------- +// CHECK: Number of quantized layers in the model +// CHECK: -------------------------------- +// CHECK: Name Count/Total +// CHECK: ================================ +// CHECK: Conv2D 1/1 + +// CHECK: Number of 
quantized layers with quantized outputs: 0/1 +// CHECK: Number of quantize layers added: 1 +// CHECK: Number of dequantize layers added: 0 } // ----- @@ -70,6 +81,17 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // CHECK: %[[conv_quant2:.*]] = "tf.PartitionedCall"(%[[conv_quant]] // CHECK-SAME: f = @quantized_conv2d_float_output_fn_0 // CHECK: return %[[conv_quant2]] + +// CHECK: -------- Quantization Summary -------- +// CHECK: Number of quantized layers in the model +// CHECK: -------------------------------- +// CHECK: Name Count/Total +// CHECK: ================================ +// CHECK: Conv2D 2/2 + +// CHECK: Number of quantized layers with quantized outputs: 1/2 +// CHECK: Number of quantize layers added: 1 +// CHECK: Number of dequantize layers added: 0 } // ----- @@ -105,4 +127,15 @@ module { // CHECK: %[[dequantize:.*]] = "tf.PartitionedCall"(%[[maxpool]] // CHECK-SAME: f = @dequantize_i8 // CHECK: return %[[dequantize]] + +// CHECK: -------- Quantization Summary -------- +// CHECK: Number of quantized layers in the model +// CHECK: -------------------------------- +// CHECK: Name Count/Total +// CHECK: ================================ +// CHECK: Conv2D 1/1 + +// CHECK: Number of quantized layers with quantized outputs: 1/1 +// CHECK: Number of quantize layers added: 1 +// CHECK: Number of dequantize layers added: 1 } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_drq.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_drq.mlir index 2d8670480ae..abe0c997195 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_drq.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_drq.mlir @@ -4,18 +4,18 @@ module { func.func @matmul(%arg0: tensor<1x2x2x3xf32>) -> (tensor<*xf32>) { - %cst_0 = "tf.Const"() {value = dense<0.000000e+00> : tensor<2x3xf32>} : () -> tensor<2x3xf32> - %1 = "tf.PartitionedCall"(%arg0, %cst_0) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn} : (tensor<1x2x2x3xf32>, tensor<2x3xf32>) -> tensor<*xf32> + %cst_0 = "tf.Const"() {value = dense<0.000000e+00> : tensor<2x1024xf32>} : () -> tensor<2x1024xf32> + %1 = "tf.PartitionedCall"(%arg0, %cst_0) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn} : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32> func.return %1: tensor<*xf32> } - func.func private @composite_matmul_fn(%arg0: tensor<1x2x2x3xf32>, %arg1: tensor<2x3xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { - %0 = "tf.MatMul"(%arg0, %arg1) {attr_map = "0:transpose_a,1:transpose_a", device = "", transpose_a = false, transpose_b = false} : (tensor<1x2x2x3xf32>, tensor<2x3xf32>) -> tensor<*xf32> + func.func private @composite_matmul_fn(%arg0: tensor<1x2x2x3xf32>, %arg1: tensor<2x1024xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %0 = "tf.MatMul"(%arg0, %arg1) {attr_map = "0:transpose_a,1:transpose_a", device = "", transpose_a = false, transpose_b = false} : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } -// CHECK: %[[cst:.*]] = "arith.constant"() {value = dense<0.000000e+00> : tensor<2x3xf32>} : () -> tensor<2x3xf32> -// CHECK: %[[q_cst:.*]] = "quantfork.qcast"(%[[cst]]) : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32, 3.9370078740157481E-9>> -// CHECK: %[[out:.*]] = "tf.PartitionedCall"(%arg0, %[[q_cst]]) 
{_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn} : (tensor<1x2x2x3xf32>, tensor<2x3x!quant.uniform:f32, 3.9370078740157481E-9>>) -> tensor<*xf32> +// CHECK: %[[cst:.*]] = "arith.constant"() {value = dense<0.000000e+00> : tensor<2x1024xf32>} : () -> tensor<2x1024xf32> +// CHECK: %[[q_cst:.*]] = "quantfork.qcast"(%[[cst]]) : (tensor<2x1024xf32>) -> tensor<2x1024x!quant.uniform:f32, 3.9370078740157481E-9>> +// CHECK: %[[out:.*]] = "tf.PartitionedCall"(%arg0, %[[q_cst]]) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn} : (tensor<1x2x2x3xf32>, tensor<2x1024x!quant.uniform:f32, 3.9370078740157481E-9>>) -> tensor<*xf32> // CHECK: "func.return"(%[[out]]) : (tensor<*xf32>) -> () } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_xla.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_xla.mlir new file mode 100644 index 00000000000..9123e41967e --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_xla.mlir @@ -0,0 +1,136 @@ +// RUN: tf-quant-opt %s -split-input-file -quant-lift-quantizable-spots-as-functions -quant-quantize='target-opset=XLA' -verify-each=false | FileCheck %s + +func.func private @conv(%input: tensor<1x3x4x3xf32> {tf._user_specified_name = "input_tensor"}) -> tensor<*xf32> attributes {tf._construction_context = "kEagerRuntime", tf._input_shapes = [#tf_type.shape<1x3x4x3>]} { + %weight = arith.constant dense_resource<__elided__> : tensor<2x3x3x2xf32> + %bias = arith.constant dense<[7.11401462, 7.05456924]> : tensor<2xf32> + + %q_input= "quantfork.qcast"(%input) : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> + %dq_input= "quantfork.dcast"(%q_input) : (tensor<1x3x4x3x!quant.uniform>) -> tensor<1x3x4x3xf32> + %q_weight = "quantfork.qcast"(%weight) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform> + %dq_weight = "quantfork.dcast"(%q_weight) : (tensor<2x3x3x2x!quant.uniform>) -> tensor<2x3x3x2xf32> + %q_bias = "quantfork.qcast"(%bias) : (tensor<2xf32>) -> tensor<2x!quant.uniform> + %dq_bias = "quantfork.dcast"(%q_bias) : (tensor<2x!quant.uniform>) -> tensor<2xf32> + %conv = "tf.Conv2D"(%dq_input, %dq_weight) {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + %biasadd = "tf.BiasAdd"(%conv, %dq_bias) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %res = "tf.Relu6"(%biasadd) : (tensor<*xf32>) -> tensor<*xf32> + %q_res = "quantfork.qcast"(%res) : (tensor<*xf32>) -> tensor<*x!quant.uniform> + %dq_res = "quantfork.dcast"(%q_res) : (tensor<*x!quant.uniform>) -> tensor<*xf32> + + func.return %dq_res : tensor<*xf32> +} + +// CHECK-DAG: [[bias:%.+]] = "arith.constant"() {value = dense<[7.11401462, 7.05456924]> : tensor<2xf32>} : () -> tensor<2xf32> +// CHECK-DAG: [[weight:%.+]] = "arith.constant"() {value = dense_resource<__elided__> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2x!quant.uniform> +// CHECK: [[q_input:%.+]] = "quantfork.qcast"(%arg0) : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK-NEXT: [[q_bias:%.+]] = "quantfork.qcast"([[bias]]) : (tensor<2xf32>) -> tensor<2x!quant.uniform> +// CHECK-NEXT: [[conv:%.+]] = "tf.PartitionedCall"([[q_input]], 
[[weight]], [[q_bias]]) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @[[composite_fn:composite_conv2d_with_bias_and_relu6_fn.*]]} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform>, tensor<2x!quant.uniform>) -> tensor<*x!quant.uniform> +// CHECK-NEXT: [[res:%.+]] = "quantfork.dcast"([[conv]]) : (tensor<*x!quant.uniform>) -> tensor<*xf32> +// CHECK-NEXT: "func.return"([[res]]) : (tensor<*xf32>) -> () + + +// ----- + +// CHECK-LABEL: standalone_same_scale_test +func.func @standalone_same_scale_test(%arg0: tensor<*xf32>) -> tensor<*xf32> { + %cst = arith.constant dense<[-1, 144]> : tensor<2xi32> + %0 = "quantfork.qcast"(%arg0) : (tensor<*xf32>) -> tensor<*x!quant.uniform> + %1 = "quantfork.dcast"(%0) : (tensor<*x!quant.uniform>) -> tensor<*xf32> + %2 = "tf.MaxPool"(%1) {data_format = "NHWC", device = "", explicit_paddings = [], ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 2, 2, 1]} : (tensor<*xf32>) -> tensor<*xf32> + %3 = "quantfork.qcast"(%2) {volatile} : (tensor<*xf32>) -> tensor<*x!quant.uniform> + %4 = "quantfork.dcast"(%3) : (tensor<*x!quant.uniform>) -> tensor<*xf32> + %5 = "tf.Reshape"(%4, %cst) {device = ""} : (tensor<*xf32>, tensor<2xi32>) -> tensor<*xf32> + %6 = "quantfork.qcast"(%5) {volatile} : (tensor<*xf32>) -> tensor<*x!quant.uniform> + %7 = "quantfork.dcast"(%6) : (tensor<*x!quant.uniform>) -> tensor<*xf32> + func.return %7 : tensor<*xf32> +} + +// CHECK: %[[maxpool_i8:.*]] = "tf.MaxPool" +// CHECK-SAME: (tensor<*xf32>) -> tensor<*xf32> +// CHECK: %[[reshape_i8:.*]] = "tf.Reshape"(%[[maxpool_i8]] +// CHECK-SAME: (tensor<*xf32>, tensor<2xi32>) -> tensor<*xf32> + +// ----- + +// CHECK-LABEL: standalone_avgpool_test +func.func @standalone_avgpool_test(%arg0: tensor<*xf32>) -> tensor<*xf32> { + %cst = arith.constant dense<[-1, 144]> : tensor<2xi32> + %0 = "quantfork.qcast"(%arg0) : (tensor<*xf32>) -> tensor<*x!quant.uniform> + %1 = "quantfork.dcast"(%0) : (tensor<*x!quant.uniform>) -> tensor<*xf32> + %2 = "tf.AvgPool"(%1) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 2, 2, 1]} : (tensor<*xf32>) -> tensor<*xf32> + %3 = "quantfork.qcast"(%2) {volatile} : (tensor<*xf32>) -> tensor<*x!quant.uniform> + %4 = "quantfork.dcast"(%3) : (tensor<*x!quant.uniform>) -> tensor<*xf32> + func.return %4 : tensor<*xf32> +} + +// CHECK: %[[avgpool_f32:.*]] = "tf.AvgPool" +// CHECK-SAME: (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return %[[avgpool_f32]] + +// ----- + +func.func @same_scale_op_before_matmul(%arg0: tensor<*xf32>) -> tensor<*xf32> { + %cst = arith.constant dense<[-1, 144]> : tensor<2xi32> + %0 = "quantfork.qcast"(%arg0) : (tensor<*xf32>) -> tensor<*x!quant.uniform> + %1 = "quantfork.dcast"(%0) : (tensor<*x!quant.uniform>) -> tensor<*xf32> + %2 = "tf.MaxPool"(%1) {data_format = "NHWC", device = "", explicit_paddings = [], ksize = [1, 2, 2, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>) -> tensor<*xf32> + %3 = "quantfork.qcast"(%2) {volatile} : (tensor<*xf32>) -> tensor<*x!quant.uniform> + %4 = "quantfork.dcast"(%3) : (tensor<*x!quant.uniform>) -> tensor<*xf32> + %5 = "tf.Reshape"(%4, %cst) {device = ""} : (tensor<*xf32>, tensor<2xi32>) -> tensor<*xf32> + %6 = "quantfork.qcast"(%5) {volatile} : (tensor<*xf32>) -> tensor<*x!quant.uniform> + %7 = "quantfork.dcast"(%6) : (tensor<*x!quant.uniform>) -> tensor<*xf32> + %weight = arith.constant dense<1.0> : tensor<144x12xf32> + %q_weight = "quantfork.qcast"(%weight) : (tensor<144x12xf32>) -> tensor<144x12x!quant.uniform> 
+ %dq_weight = "quantfork.dcast"(%q_weight) : (tensor<144x12x!quant.uniform>) -> tensor<144x12xf32> + %9 = "tf.MatMul"(%7, %dq_weight) {transpose_a = false, transpose_b = false} : (tensor<*xf32>, tensor<144x12xf32>) -> tensor<*xf32> + %10 = "quantfork.qcast"(%9) {volatile} : (tensor<*xf32>) -> tensor<*x!quant.uniform> + %11 = "quantfork.dcast"(%10) : (tensor<*x!quant.uniform>) -> tensor<*xf32> + func.return %11 : tensor<*xf32> +} + +// CHECK: %[[maxpool_i8:.*]] = "tf.MaxPool" +// CHECK-SAME: (tensor<*xi8>) -> tensor<*xi8> +// CHECK: %[[reshape_i8:.*]] = "tf.Reshape"(%[[maxpool_i8]] +// CHECK-SAME: (tensor<*xi8>, tensor<2xi32>) -> tensor<*xi8> +// CHECK: %[[scast:.*]] = "quantfork.scast"(%[[reshape_i8]] +// CHECK: %[[matmul:.*]] = "tf.PartitionedCall"(%[[scast]] +// CHECK-SAME: f = @composite_matmul_fn_1 +// CHECK-SAME: (tensor<*x!quant.uniform>, tensor<144x12x!quant.uniform>) -> tensor<*x!quant.uniform> + +// ----- + +func.func private @avgpool_after_conv(%input: tensor<1x3x4x3xf32> {tf._user_specified_name = "input_tensor"}) -> tensor<*xf32> attributes {tf._construction_context = "kEagerRuntime", tf._input_shapes = [#tf_type.shape<1x3x4x3>]} { + %weight = arith.constant dense<1.0> : tensor<2x3x3x2xf32> + %bias = arith.constant dense<[7.11401462, 7.05456924]> : tensor<2xf32> + %cst = arith.constant dense<[-1, 144]> : tensor<2xi32> + + %q_input= "quantfork.qcast"(%input) : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> + %dq_input= "quantfork.dcast"(%q_input) : (tensor<1x3x4x3x!quant.uniform>) -> tensor<1x3x4x3xf32> + %q_weight = "quantfork.qcast"(%weight) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform> + %dq_weight = "quantfork.dcast"(%q_weight) : (tensor<2x3x3x2x!quant.uniform>) -> tensor<2x3x3x2xf32> + %q_bias = "quantfork.qcast"(%bias) : (tensor<2xf32>) -> tensor<2x!quant.uniform> + %dq_bias = "quantfork.dcast"(%q_bias) : (tensor<2x!quant.uniform>) -> tensor<2xf32> + %conv = "tf.Conv2D"(%dq_input, %dq_weight) {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + %biasadd = "tf.BiasAdd"(%conv, %dq_bias) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %res = "tf.Relu6"(%biasadd) : (tensor<*xf32>) -> tensor<*xf32> + %q_res = "quantfork.qcast"(%res) : (tensor<*xf32>) -> tensor<*x!quant.uniform> + %dq_res = "quantfork.dcast"(%q_res) : (tensor<*x!quant.uniform>) -> tensor<*xf32> + %avg_pool = "tf.AvgPool"(%dq_res) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 2, 2, 1]} : (tensor<*xf32>) -> tensor<*xf32> + %q_pool = "quantfork.qcast"(%avg_pool) {volatile} : (tensor<*xf32>) -> tensor<*x!quant.uniform> + %dq_pool = "quantfork.dcast"(%q_pool) : (tensor<*x!quant.uniform>) -> tensor<*xf32> + %reshape = "tf.Reshape"(%dq_pool, %cst) {device = ""} : (tensor<*xf32>, tensor<2xi32>) -> tensor<*xf32> + %q_reshape = "quantfork.qcast"(%reshape) {volatile} : (tensor<*xf32>) -> tensor<*x!quant.uniform> + %dq_reshape = "quantfork.dcast"(%q_reshape) : (tensor<*x!quant.uniform>) -> tensor<*xf32> + func.return %dq_reshape : tensor<*xf32> +} + +// CHECK: %[[conv:.*]] = "tf.PartitionedCall" +// CHECK-SAME: f = @composite_conv2d_with_bias_and_relu6_fn_1 +// CHECK-SAME: (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform>, tensor<2x!quant.uniform>) -> 
tensor<*x!quant.uniform> +// CHECK: %[[scast:.*]] = "quantfork.scast"(%[[conv]] +// CHECK: %[[fcast:.*]] = "tf.Cast"(%[[scast]]) {Truncate = false} : (tensor<*xi8>) -> tensor<*xf32> +// CHECK: %[[avgpool_f32:.*]] = "tf.AvgPool"(%[[fcast]]) +// CHECK-SAME: (tensor<*xf32>) -> tensor<*xf32> +// CHECK: %[[round:.*]] = "tf.Round"(%[[avgpool_f32]]) +// CHECK: %[[icast:.*]] = "tf.Cast"(%[[round]]) {Truncate = false} : (tensor<*xf32>) -> tensor<*xi8> +// CHECK: %[[reshape:.*]] = "tf.Reshape"(%[[icast]] +// CHECK: %[[sc2:.*]] = "quantfork.scast"(%[[reshape]]) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/remove_var_init_by_const.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/remove_var_init_by_const.mlir new file mode 100644 index 00000000000..d5e18209291 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/remove_var_init_by_const.mlir @@ -0,0 +1,150 @@ +// RUN: tf-quant-opt %s -split-input-file -verify-diagnostics \ +// RUN: -quant-remove-var-init-by-const | FileCheck %s + +// Single `tf.AssignVariableOp(tf.VarHandleOp, tf.Const)` pattern removed from +// the initializer function. + +// CHECK-LABEL: module +module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_func_restore_op]} : () -> () + // CHECK: "tf_saved_model.session_initializer"() + // CHECK-SAME: initializers = [@init_func_restore_op] + + func.func @init_func_restore_op() -> () attributes { + tf_saved_model.initializer_type = "restore_op", + tf_saved_model.exported_names = ["tf_saved_model.session_initializer_restore_op"]} { + %cst_0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %var_0 = "tf.VarHandleOp"() {shared_name = "var_0"} : () -> tensor>> + "tf.AssignVariableOp"(%var_0, %cst_0) : (tensor>>, tensor<2xf32>) -> () + return + } + // All three ops should have been removed. + // CHECK: @init_func_restore_op + // CHECK-SAME: tf_saved_model.initializer_type = "restore_op" + // CHECK-NEXT: return +} + +// ----- + +// The `tf.AssignVariableOp(tf.VarHandleOp, tf.Const)` pattern is not removed +// from the initializer function that is not "restore_op" type. + +// CHECK-LABEL: module +module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_func_init_op]} : () -> () + // CHECK: "tf_saved_model.session_initializer"() + // CHECK-SAME: initializers = [@init_func_init_op] + + func.func @init_func_init_op() -> () attributes { + tf_saved_model.initializer_type = "init_op", + tf_saved_model.exported_names = ["tf_saved_model.session_initializer_init_op"]} { + %cst_0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %var_0 = "tf.VarHandleOp"() {shared_name = "var_0"} : () -> tensor>> + "tf.AssignVariableOp"(%var_0, %cst_0) : (tensor>>, tensor<2xf32>) -> () + return + } + // Nothing has been removed. + // CHECK: @init_func_init_op + // CHECK-NEXT: "tf.Const" + // CHECK-NEXT: "tf.VarHandleOp" + // CHECK-NEXT: "tf.AssignVariableOp" + // CHECK-NEXT: return +} + +// ----- + +// If `tf.Const` is not used to initialize the variable, it is not removed. 
+ +// CHECK-LABEL: module +module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_func_restore_op]} : () -> () + // CHECK: "tf_saved_model.session_initializer"() + // CHECK-SAME: initializers = [@init_func_restore_op] + + func.func @init_func_restore_op() -> () attributes { + tf_saved_model.initializer_type = "restore_op", + tf_saved_model.exported_names = ["tf_saved_model.session_initializer_restore_op"]} { + %cst_0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %var_0 = "tf.VarHandleOp"() {shared_name = "var_0"} : () -> tensor<!tf_type.resource<tensor<2xf32>>> + "tf.AssignVariableOp"(%var_0, %cst_0) : (tensor<!tf_type.resource<tensor<2xf32>>>, tensor<2xf32>) -> () + %add_0 = "tf.Identity"(%cst_0) : (tensor<2xf32>) -> tensor<2xf32> + %var_1 = "tf.VarHandleOp"() {shared_name = "var_1"} : () -> tensor<!tf_type.resource<tensor<2xf32>>> + "tf.AssignVariableOp"(%var_1, %add_0) : (tensor<!tf_type.resource<tensor<2xf32>>>, tensor<2xf32>) -> () + return + } + // The second AssignVariableOp, which takes the result of the `tf.Identity` + // op, is not removed. Note that the first AssignVariableOp is removed. + // CHECK: @init_func_restore_op + // CHECK-SAME: tf_saved_model.initializer_type = "restore_op" + // CHECK-NOT: "tf.AssignVariableOp" + // CHECK: %[[CST:.*]] = "tf.Const"() + // CHECK-NEXT: %[[IDENTITY:.*]] = "tf.Identity"(%[[CST]]) + // CHECK-NEXT: %[[VAR:.*]] = "tf.VarHandleOp"() {{{.*shared_name = "var_1".*}}} + // CHECK-NEXT: "tf.AssignVariableOp"(%[[VAR]], %[[IDENTITY]]) +} + +// ----- + +// If something other than `tf.VarHandleOp` is being initialized, it is +// not erased. + +// CHECK-LABEL: module +module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_func_restore_op]} : () -> () + // CHECK: "tf_saved_model.session_initializer"() + // CHECK-SAME: initializers = [@init_func_restore_op] + + func.func @init_func_restore_op() -> () attributes { + tf_saved_model.initializer_type = "restore_op", + tf_saved_model.exported_names = ["tf_saved_model.session_initializer_restore_op"]} { + %cst_0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + // Note: this is a contrived example and is an invalid input. + %var_0 = "tf.HashTableV2"() {key_dtype = i64, value_dtype = !tf_type.string} : () -> tensor<!tf_type.resource> + "tf.AssignVariableOp"(%var_0, %cst_0) : (tensor<!tf_type.resource>, tensor<2xf32>) -> () + return + } + // CHECK: @init_func_restore_op + // CHECK-SAME: tf_saved_model.initializer_type = "restore_op" + // CHECK: %[[CST:.*]] = "tf.Const"() + // CHECK-NEXT: %[[HASH_TABLE:.*]] = "tf.HashTableV2"() + // CHECK-NEXT: "tf.AssignVariableOp"(%[[HASH_TABLE]], %[[CST]]) +} + +// ----- + + +// Nothing happens when there are no `tf_saved_model.session_initializer`. + +// CHECK-LABEL: module +module attributes {tf_saved_model.semantics} { +} + +// ----- + +// Nothing happens when there are no initializer functions. + +// CHECK-LABEL: module +module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = []} : () -> () +} + +// ----- + +// Nothing happens when the initializer function of type = "restore_op" is +// empty. 
+ +// CHECK-LABEL: module +module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_func_restore_op]} : () -> () + // CHECK: "tf_saved_model.session_initializer"() + // CHECK-SAME: initializers = [@init_func_restore_op] + + func.func @init_func_restore_op() -> () attributes { + tf_saved_model.initializer_type = "restore_op", + tf_saved_model.exported_names = ["tf_saved_model.session_initializer_restore_op"]} { + return + } + // CHECK: @init_func_restore_op + // CHECK-SAME: tf_saved_model.initializer_type = "restore_op" + // CHECK-NEXT: return +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/replace_cast_hacks_with_tf_xla_ops.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/replace_cast_hacks_with_tf_xla_ops.mlir index 9f4bf32d40c..c8ed105c3da 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/replace_cast_hacks_with_tf_xla_ops.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/replace_cast_hacks_with_tf_xla_ops.mlir @@ -65,12 +65,13 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // CHECK-DAG-SAME{LITERAL}: value = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> // CHECK-DAG: %[[CONST_5:.*]] = "tf.Const"() {value = dense<-128> : tensor} : () -> tensor // CHECK-DAG: %[[CONST_6:.*]] = "tf.Const"() {value = dense<{{.*}}> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2xi8> -// CHECK-DAG: %[[CONST_7:.*]] = "tf.Const"() {value = dense<[-22016, -23680]> : tensor<2xi32>} : () -> tensor<2xi32> +// CHECK-DAG: %[[CONST_7:.*]] = "tf.Const"() {value = dense<{{.*}}> : tensor<1x1x1x2xi32>} : () -> tensor<1x1x1x2xi32> +// CHECK-DAG-SAME{LITERAL}: value = dense<[[[[-22016, -23680]]]]> // CHECK-DAG: %[[CONST_8:.*]] = "tf.Const"() {value = dense<[162, 160]> : tensor<2xi32>} : () -> tensor<2xi32> // CHECK: %[[PADV2_0:.*]] = "tf.PadV2"({{.*}}, %[[CONST_4]], %[[CONST_5]]) : (tensor<1x3x4x3xi8>, tensor<4x2xi32>, tensor) -> tensor<1x4x5x3xi8> // CHECK: %[[XLACONVV2_0:.*]] = "tf.XlaConvV2"(%[[PADV2_0]], %[[CONST_6]], %[[CONST_0]], %[[CONST_3]], %[[CONST_1]], %[[CONST_1]], %[[CONST_2]]) // CHECK-SAME: (tensor<1x4x5x3xi8>, tensor<2x3x3x2xi8>, tensor<2xi32>, tensor<2x2xi32>, tensor<2xi32>, tensor<2xi32>, tensor) -> tensor<1x3x2x2xi32> -// CHECK: %[[SUB_0:.*]] = "tf.Sub"(%[[XLACONVV2_0]], %[[CONST_7]]) : (tensor<1x3x2x2xi32>, tensor<2xi32>) -> tensor<1x3x2x2xi32> +// CHECK: %[[SUB_0:.*]] = "tf.Sub"(%[[XLACONVV2_0]], %[[CONST_7]]) : (tensor<1x3x2x2xi32>, tensor<1x1x1x2xi32>) -> tensor<1x3x2x2xi32> // CHECK: %[[ADDV2_1:.*]] = "tf.AddV2"(%[[SUB_0]], %[[CONST_8]]) : (tensor<1x3x2x2xi32>, tensor<2xi32>) -> tensor<1x3x2x2xi32> } @@ -150,12 +151,13 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // CHECK-DAG: %[[CONST_4:.*]] = "tf.Const"() {value = dense<0> : tensor<2x2xi32>} : () -> tensor<2x2xi32> // CHECK-DAG: %[[CONST_5:.*]] = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> tensor<2xi32> // CHECK-DAG: %[[CONST_6:.*]] = "tf.Const"() {value = dense<3> : tensor} : () -> tensor -// CHECK-DAG: %[[CONST_7:.*]] = "tf.Const"() {value = dense<[55040, -15104, -21376]> : tensor<3xi32>} : () -> tensor<3xi32> +// CHECK-DAG: %[[CONST_7:.*]] = "tf.Const"() {value = dense<{{.*}}> : tensor<1x1x1x3xi32>} : () -> tensor<1x1x1x3xi32> +// CHECK-DAG-SAME{LITERAL}: value = dense<[[[[55040, -15104, -21376]]]]> // CHECK-DAG: %[[CONST_8:.*]] = "tf.Const"() {value = dense<[129, 166, 221]> : tensor<3xi32>} : () -> tensor<3xi32> // CHECK: %[[PADV2_0:.*]] = 
"tf.PadV2"({{.*}}, %[[CONST_0]], %[[CONST_1]]) : (tensor<1x3x4x3xi8>, tensor<4x2xi32>, tensor) -> tensor<1x4x5x3xi8> // CHECK: %[[XLACONVV2_0:.*]] = "tf.XlaConvV2"(%[[PADV2_0]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]], %[[CONST_5]], %[[CONST_5]], %[[CONST_6]]) // CHECK-SAME: (tensor<1x4x5x3xi8>, tensor<2x3x1x3xi8>, tensor<2xi32>, tensor<2x2xi32>, tensor<2xi32>, tensor<2xi32>, tensor) -> tensor<1x2x2x3xi32> -// CHECK: %[[SUB_0:.*]] = "tf.Sub"(%[[XLACONVV2_0]], %[[CONST_7]]) : (tensor<1x2x2x3xi32>, tensor<3xi32>) -> tensor<1x2x2x3xi32> +// CHECK: %[[SUB_0:.*]] = "tf.Sub"(%[[XLACONVV2_0]], %[[CONST_7]]) : (tensor<1x2x2x3xi32>, tensor<1x1x1x3xi32>) -> tensor<1x2x2x3xi32> // CHECK: %[[ADDV2_1:.*]] = "tf.AddV2"(%[[SUB_0]], %[[CONST_8]]) : (tensor<1x2x2x3xi32>, tensor<3xi32>) -> tensor<1x2x2x3xi32> } @@ -204,7 +206,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // CHECK-DAG: %[[padding_rank_2:.*]] = "tf.Reshape"(%[[padding_rank_1]], {{.*}}) : (tensor<8xi32>, tensor<2xi64>) -> tensor<4x2xi32> // CHECK-DAG: %[[input_padded:.*]] = "tf.PadV2"(%{{.*}}, %[[padding_rank_2]], {{.*}}) : (tensor, tensor<4x2xi32>, tensor) -> tensor // CHECK: %[[conv_output:.*]] = "tf.XlaConvV2"(%[[input_padded]], %[[filter]], {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}) {batch_group_count = 1 : i64, dimension_numbers = "{{.*}}", precision_config = ""} : (tensor, tensor<2x3x3x2xi8>, tensor<2xi32>, tensor<2x2xi32>, tensor<2xi32>, tensor<2xi32>, tensor) -> tensor -// CHECK: %[[conv_output_sub:.*]] = "tf.Sub"(%[[conv_output]], {{.*}}) : (tensor, tensor<2xi32>) -> tensor +// CHECK: %[[conv_output_sub:.*]] = "tf.Sub"(%[[conv_output]], {{.*}}) : (tensor, tensor<1x1x1x2xi32>) -> tensor // CHECK: %[[conv_output_add:.*]] = "tf.AddV2"(%[[conv_output_sub]], {{.*}}) {device = ""} : (tensor, tensor<2xi32>) -> tensor } @@ -262,7 +264,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p } // CHECK-LABEL: func @conv_with_filter_larger_than_1MB -// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<-264192> : tensor<512xi32>} : () -> tensor<512xi32> +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<-264192> : tensor<1x1x1x512xi32>} : () -> tensor<1x1x1x512xi32> // CHECK: %[[PADV2_0:.*]] = "tf.PadV2" // CHECK: %[[XLACONVV2_0:.*]] = "tf.XlaConvV2"(%[[PADV2_0]] // CHECK: %[[SUB_0:.*]] = "tf.Sub"(%[[XLACONVV2_0]], %[[CONST]]) @@ -296,10 +298,154 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p } // CHECK-LABEL: func @matmul_with_relu // CHECK-DAG: %[[WEIGHT:.*]] = "tf.Const"() {device = "", value = dense<1> : tensor<1024x3xi8>} : () -> tensor<1024x3xi8> -// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<-131072> : tensor<3xi32>} : () -> tensor<3xi32> +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<-131072> : tensor<1x3xi32>} : () -> tensor<1x3xi32> // CHECK: %[[MATMUL:.*]] = "tf.XlaDotV2"({{.*}}, %[[WEIGHT]]) // CHECK-SAME: (tensor<1x1024xi8>, tensor<1024x3xi8>) -> tensor<1x3xi32> -// CHECK: %[[SUB:.*]] = "tf.Sub"(%[[MATMUL]], %[[CONST]]) : (tensor<1x3xi32>, tensor<3xi32>) -> tensor<1x3xi32> +// CHECK: %[[SUB:.*]] = "tf.Sub"(%[[MATMUL]], %[[CONST]]) : (tensor<1x3xi32>, tensor<1x3xi32>) -> tensor<1x3xi32> +} + +// ----- + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1213 : i32}} { + func.func @matmul_two_tensors_with_static_shape(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> (tensor<2x2xf32>) { + %cst = "tf.Const"() {value = dense<-5.450000e+01> : tensor} : () -> 
tensor + %cst_0 = "tf.Const"() {value = dense<0.0156862754> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {value = dense<-5.000000e-01> : tensor} : () -> tensor + %cst_2 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %cst_3 = "tf.Const"() {value = dense<0.0274509806> : tensor} : () -> tensor + %cst_4 = "tf.Const"() {value = dense<-55> : tensor} : () -> tensor + %cst_5 = "tf.Const"() {value = dense<-1.280000e+02> : tensor} : () -> tensor + %cst_6 = "tf.Const"() {value = dense<1.270000e+02> : tensor} : () -> tensor + %0 = "tf.Div"(%arg1, %cst_0) : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> + %1 = "tf.AddV2"(%0, %cst_1) : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> + %2 = "tf.Floor"(%1) : (tensor<2x2xf32>) -> tensor<2x2xf32> + %3 = "tf.ClipByValue"(%2, %cst_5, %cst_6) : (tensor<2x2xf32>, tensor, tensor) -> tensor<2x2xf32> + %4 = "tf.Cast"(%3) {Truncate = false} : (tensor<2x2xf32>) -> tensor<2x2xi8> + %5 = "tf.Div"(%arg0, %cst_3) : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> + %6 = "tf.AddV2"(%5, %cst) : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> + %7 = "tf.Floor"(%6) : (tensor<2x2xf32>) -> tensor<2x2xf32> + %8 = "tf.ClipByValue"(%7, %cst_5, %cst_6) : (tensor<2x2xf32>, tensor, tensor) -> tensor<2x2xf32> + %9 = "tf.Cast"(%8) {Truncate = false} : (tensor<2x2xf32>) -> tensor<2x2xi8> + %10 = "tf.Cast"(%9) {Truncate = false} : (tensor<2x2xi8>) -> tensor<2x2xi32> + %11 = "tf.Sub"(%10, %cst_4) : (tensor<2x2xi32>, tensor) -> tensor<2x2xi32> + %12 = "tf.Identity"(%4) : (tensor<2x2xi8>) -> tensor<2x2xi8> + %13 = "tf.Cast"(%12) {Truncate = false} : (tensor<2x2xi8>) -> tensor<2x2xi32> + %14 = "tf.Sub"(%13, %cst_2) : (tensor<2x2xi32>, tensor) -> tensor<2x2xi32> + %15 = "tf.MatMul"(%11, %14) {transpose_a = false, transpose_b = false} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> + %16 = "tf.Cast"(%15) {Truncate = false} : (tensor<2x2xi32>) -> tensor<2x2xf32> + %17 = "tf.Mul"(%16, %cst_0) : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> + %18 = "tf.AddV2"(%17, %cst) : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> + %19 = "tf.Floor"(%18) : (tensor<2x2xf32>) -> tensor<2x2xf32> + %20 = "tf.ClipByValue"(%19, %cst_5, %cst_6) : (tensor<2x2xf32>, tensor, tensor) -> tensor<2x2xf32> + %21 = "tf.Cast"(%20) {Truncate = false} : (tensor<2x2xf32>) -> tensor<2x2xi8> + %22 = "tf.Identity"(%21) {device = ""} : (tensor<2x2xi8>) -> tensor<2x2xi8> + %23 = "tf.Identity"(%22) {device = ""} : (tensor<2x2xi8>) -> tensor<2x2xi8> + %24 = "tf.Cast"(%23) : (tensor<2x2xi8>) -> tensor<2x2xi32> + %25 = "tf.Sub"(%24, %cst_4) : (tensor<2x2xi32>, tensor) -> tensor<2x2xi32> + %26 = "tf.Cast"(%25) : (tensor<2x2xi32>) -> tensor<2x2xf32> + %27 = "tf.Mul"(%26, %cst_3) : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> + return %27 : tensor<2x2xf32> + } + +// CHECK-LABEL: func @matmul_two_tensors_with_static_shape +// CHECK: %[[arg1_div:.*]] = "tf.Div"(%arg1 +// CHECK: %[[arg1_add:.*]] = "tf.AddV2"(%[[arg1_div]] +// CHECK: %[[arg1_floor:.*]] = "tf.Floor"(%[[arg1_add]] +// CHECK: %[[arg1_clip:.*]] = "tf.ClipByValue"(%[[arg1_floor]] +// CHECK: %[[arg1_cast:.*]] = "tf.Cast"(%[[arg1_clip]] + +// CHECK: %[[arg0_div:.*]] = "tf.Div"(%arg0 +// CHECK: %[[arg0_add:.*]] = "tf.AddV2"(%[[arg0_div]] +// CHECK: %[[arg0_floor:.*]] = "tf.Floor"(%[[arg0_add]] +// CHECK: %[[arg0_clip:.*]] = "tf.ClipByValue"(%[[arg0_floor]] +// CHECK: %[[arg0_cast:.*]] = "tf.Cast"(%[[arg0_clip]] + +// CHECK: %[[arg1_identity:.*]] = "tf.Identity"(%[[arg1_cast]] + +// CHECK: %[[matmul:.*]] = "tf.XlaDotV2"(%[[arg0_cast]], %[[arg1_identity]] +// CHECK-SAME: 
(tensor<2x2xi8>, tensor<2x2xi8>) -> tensor<2x2xi32> + +// CHECK: %[[matmul_sub:.*]] = "tf.Sub"(%[[matmul]] +// CHECK: %[[matmul_cast:.*]] = "tf.Cast"(%[[matmul_sub]] +// CHECK: %[[matmul_mul:.*]] = "tf.Mul"(%[[matmul_cast]] +// CHECK: %[[matmul_add:.*]] = "tf.AddV2"(%[[matmul_mul]] +// CHECK: %[[matmul_floor:.*]] = "tf.Floor"(%[[matmul_add]] +// CHECK: %[[matmul_clip:.*]] = "tf.ClipByValue"(%[[matmul_floor]] +} + +// ----- + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1213 : i32}} { + func.func @matmul_two_tensors_with_dynamic_shape(%arg0: tensor, %arg1: tensor) -> (tensor) { + %cst = "tf.Const"() {value = dense<-5.450000e+01> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {value = dense<0.0156862754> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {value = dense<-5.000000e-01> : tensor} : () -> tensor + %cst_2 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %cst_3 = "tf.Const"() {value = dense<0.0274509806> : tensor} : () -> tensor + %cst_4 = "tf.Const"() {value = dense<-55> : tensor} : () -> tensor + %cst_5 = "tf.Const"() {value = dense<-1.280000e+02> : tensor} : () -> tensor + %cst_6 = "tf.Const"() {value = dense<1.270000e+02> : tensor} : () -> tensor + %0 = "tf.Div"(%arg1, %cst_0) : (tensor, tensor) -> tensor + %1 = "tf.AddV2"(%0, %cst_1) : (tensor, tensor) -> tensor + %2 = "tf.Floor"(%1) : (tensor) -> tensor + %3 = "tf.ClipByValue"(%2, %cst_5, %cst_6) : (tensor, tensor, tensor) -> tensor + %4 = "tf.Cast"(%3) {Truncate = false} : (tensor) -> tensor + %5 = "tf.Div"(%arg0, %cst_3) : (tensor, tensor) -> tensor + %6 = "tf.AddV2"(%5, %cst) : (tensor, tensor) -> tensor + %7 = "tf.Floor"(%6) : (tensor) -> tensor + %8 = "tf.ClipByValue"(%7, %cst_5, %cst_6) : (tensor, tensor, tensor) -> tensor + %9 = "tf.Cast"(%8) {Truncate = false} : (tensor) -> tensor + %10 = "tf.Cast"(%4) {Truncate = false} : (tensor) -> tensor + %11 = "tf.Sub"(%10, %cst_2) : (tensor, tensor) -> tensor + %12 = "tf.Identity"(%9) : (tensor) -> tensor + %13 = "tf.Cast"(%12) {Truncate = false} : (tensor) -> tensor + %14 = "tf.Sub"(%13, %cst_4) : (tensor, tensor) -> tensor + %15 = "tf.MatMul"(%11, %14) {transpose_a = false, transpose_b = false} : (tensor, tensor) -> tensor + %16 = "tf.Cast"(%15) {Truncate = false} : (tensor) -> tensor + %17 = "tf.Mul"(%16, %cst_0) : (tensor, tensor) -> tensor + %18 = "tf.AddV2"(%17, %cst) : (tensor, tensor) -> tensor + %19 = "tf.Floor"(%18) : (tensor) -> tensor + %20 = "tf.ClipByValue"(%19, %cst_5, %cst_6) : (tensor, tensor, tensor) -> tensor + %21 = "tf.Cast"(%20) {Truncate = false} : (tensor) -> tensor + %22 = "tf.Identity"(%21) {device = ""} : (tensor) -> tensor + %23 = "tf.Identity"(%22) {device = ""} : (tensor) -> tensor + %24 = "tf.Cast"(%23) : (tensor) -> tensor + %25 = "tf.Sub"(%24, %cst_4) : (tensor, tensor) -> tensor + %26 = "tf.Cast"(%25) : (tensor) -> tensor + %27 = "tf.Mul"(%26, %cst_3) : (tensor, tensor) -> tensor + return %27 : tensor + } + +// CHECK-LABEL: func @matmul_two_tensors_with_dynamic_shape +// CHECK: %[[arg1_div:.*]] = "tf.Div"(%arg1 +// CHECK: %[[arg1_add:.*]] = "tf.AddV2"(%[[arg1_div]] +// CHECK: %[[arg1_floor:.*]] = "tf.Floor"(%[[arg1_add]] +// CHECK: %[[arg1_clip:.*]] = "tf.ClipByValue"(%[[arg1_floor]] +// CHECK: %[[arg1_cast:.*]] = "tf.Cast"(%[[arg1_clip]] + +// CHECK: %[[arg0_div:.*]] = "tf.Div"(%arg0 +// CHECK: %[[arg0_add:.*]] = "tf.AddV2"(%[[arg0_div]] +// CHECK: %[[arg0_floor:.*]] = "tf.Floor"(%[[arg0_add]] +// CHECK: %[[arg0_clip:.*]] = "tf.ClipByValue"(%[[arg0_floor]] +// CHECK: %[[arg0_cast:.*]] 
= "tf.Cast"(%[[arg0_clip]] +// CHECK: %[[arg0_identity:.*]] = "tf.Identity"(%[[arg0_cast]] + +// CHECK: %[[matmul:.*]] = "tf.XlaDotV2"(%[[arg1_cast]], %[[arg0_identity]] +// CHECK-SAME: (tensor, tensor) -> tensor + +// CHECK: %[[arg0_shape:.*]] = "tf.Shape"(%[[arg0_identity]] +// CHECK: %[[shape_zp_contribute:.*]] = "tf.StridedSlice"(%[[arg0_shape]] +// CHECK: %[[shape_zp_contribute_cast:.*]] = "tf.Cast"(%[[shape_zp_contribute]] +// CHECK: %[[shape_zp_contribute_mul:.*]] = "tf.Mul"(%[[shape_zp_contribute_cast]] +// CHECK: %[[zp:.*]] = "tf.Sub"({{.*}}, %[[shape_zp_contribute_mul]]) + +// CHECK: %[[matmul_sub:.*]] = "tf.Sub"(%[[matmul]], %[[zp]] +// CHECK: %[[matmul_cast:.*]] = "tf.Cast"(%[[matmul_sub]] +// CHECK: %[[matmul_mul:.*]] = "tf.Mul"(%[[matmul_cast]] +// CHECK: %[[matmul_add:.*]] = "tf.AddV2"(%[[matmul_mul]] +// CHECK: %[[matmul_floor:.*]] = "tf.Floor"(%[[matmul_add]] +// CHECK: %[[matmul_clip:.*]] = "tf.ClipByValue"(%[[matmul_floor]] + } // ----- @@ -337,7 +483,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {{.*}} : () -> tensor<5x2xi32> // CHECK-DAG-SAME{LITERAL}: value = dense<[[0, 0], [0, 1], [0, 1], [1, 1], [0, 0]]> : tensor<5x2xi32> // CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() {value = dense<-43> : tensor} : () -> tensor -// CHECK-DAG: %[[CONST_2:.*]] = "tf.Const"() {value = dense<-2322> : tensor<2xi32>} : () -> tensor<2xi32> +// CHECK-DAG: %[[CONST_2:.*]] = "tf.Const"() {value = dense<-2322> : tensor<1x1x1x1x2xi32>} : () -> tensor<1x1x1x1x2xi32> // CHECK: %[[PAD:.*]] = "tf.PadV2"({{.*}}, %[[CONST]], %[[CONST_1]]) // CHECK: %[[CONV:.*]] = "tf.XlaConvV2"(%[[PAD]], %[[WEIGHT]] @@ -380,7 +526,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // CHECK-LABEL: func @conv3d_with_dynamic_shape // CHECK-DAG: %[[WEIGHT:.*]] = "tf.Const"() {device = "", value = dense<1> : tensor<2x3x3x3x2xi8>} : () -> tensor<2x3x3x3x2xi8> // CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() {value = dense<-43> : tensor} : () -> tensor -// CHECK-DAG: %[[CONST_2:.*]] = "tf.Const"() {value = dense<-2322> : tensor<2xi32>} : () -> tensor<2xi32> +// CHECK-DAG: %[[CONST_2:.*]] = "tf.Const"() {value = dense<-2322> : tensor<1x1x1x1x2xi32>} : () -> tensor<1x1x1x1x2xi32> // CHECK: %[[CONCAT:.*]] = "tf.Concat"({{.*}}) // CHECK: %[[RESHAPE:.*]] = "tf.Reshape"(%[[CONCAT]], {{.*}}) : (tensor<10xi32>, tensor<2xi64>) -> tensor<5x2xi32> @@ -389,3 +535,333 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // CHECK-SAME: (tensor, tensor<2x3x3x3x2xi8>, tensor<3xi32>, tensor<3x2xi32>, tensor<3xi32>, tensor<3xi32>, tensor) -> tensor // CHECK: %[[SUB:.*]] = "tf.Sub"(%[[CONV]], %[[CONST_2]]) } + +// ----- + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1213 : i32}, tf_saved_model.semantics} { + func.func @batch_matmul(%arg0: tensor<20x30x64x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<20x30x64x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "tf.PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst = "tf.Const"() {device = "", value = dense<3.08784583E-5> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {device = "", value = dense<-1.275000e+02> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {device = "", value = dense<-1.280000e+02> : tensor} : () -> tensor + %cst_2 = "tf.Const"() 
{device = "", value = dense<1> : tensor<20x30x1024x3xi8>} : () -> tensor<20x30x1024x3xi8> + %cst_3 = "tf.Const"() {device = "", value = dense<0.00392156886> : tensor} : () -> tensor + %cst_4 = "tf.Const"() {device = "", value = dense<-128> : tensor} : () -> tensor + %cst_5 = "tf.Const"() {device = "", value = dense<1.270000e+02> : tensor} : () -> tensor + %0 = "tf.Div"(%arg0, %cst_3) {device = ""} : (tensor<20x30x64x1024xf32>, tensor) -> tensor<20x30x64x1024xf32> + %1 = "tf.AddV2"(%0, %cst_0) {device = ""} : (tensor<20x30x64x1024xf32>, tensor) -> tensor<20x30x64x1024xf32> + %2 = "tf.Floor"(%1) {device = ""} : (tensor<20x30x64x1024xf32>) -> tensor<20x30x64x1024xf32> + %3 = "tf.ClipByValue"(%2, %cst_1, %cst_5) {device = ""} : (tensor<20x30x64x1024xf32>, tensor, tensor) -> tensor<20x30x64x1024xf32> + %4 = "tf.Cast"(%3) {Truncate = false, device = ""} : (tensor<20x30x64x1024xf32>) -> tensor<20x30x64x1024xi8> + %5 = "tf.Cast"(%4) {Truncate = false, device = ""} : (tensor<20x30x64x1024xi8>) -> tensor<20x30x64x1024xi32> + %6 = "tf.Sub"(%5, %cst_4) {device = ""} : (tensor<20x30x64x1024xi32>, tensor) -> tensor<20x30x64x1024xi32> + %7 = "tf.Identity"(%cst_2) {device = ""} : (tensor<20x30x1024x3xi8>) -> tensor<20x30x1024x3xi8> + %8 = "tf.Cast"(%7) {Truncate = false, device = ""} : (tensor<20x30x1024x3xi8>) -> tensor<20x30x1024x3xi32> + %9 = "tf.BatchMatMulV2"(%6, %8) {adj_x = false, adj_y = false, device = ""} : (tensor<20x30x64x1024xi32>, tensor<20x30x1024x3xi32>) -> tensor<20x30x64x3xi32> + %10 = "tf.Cast"(%9) {Truncate = false, device = ""} : (tensor<20x30x64x3xi32>) -> tensor<20x30x64x3xf32> + %11 = "tf.Mul"(%10, %cst) {device = ""} : (tensor<20x30x64x3xf32>, tensor) -> tensor<20x30x64x3xf32> + %12 = "tf.Relu"(%11) {device = ""} : (tensor<20x30x64x3xf32>) -> tensor<20x30x64x3xf32> + %13 = "tf.Identity"(%12) {device = ""} : (tensor<20x30x64x3xf32>) -> tensor<20x30x64x3xf32> + return %13 : tensor<20x30x64x3xf32> + } + +// CHECK-LABEL: func @batch_matmul +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<-131072> : tensor<20x30x1x3xi32>} : () -> tensor<20x30x1x3xi32> +// CHECK: %[[CAST:.*]] = "tf.Cast" +// CHECK: %[[XLADOTV2_0:.*]] = "tf.XlaDotV2"(%[[CAST]] +// CHECK: %[[SUB_0:.*]] = "tf.Sub"(%[[XLADOTV2_0]], %[[CONST]]) : (tensor<20x30x64x3xi32>, tensor<20x30x1x3xi32>) -> tensor<20x30x64x3xi32> +} + +// ----- + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1213 : i32}, tf_saved_model.semantics} { + func.func @broadcasting_weight_batch_matmul(%arg0: tensor<2x1x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<2x1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst = "tf.Const"() {device = "", value = dense<3.08762283E-5> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {device = "", value = dense<-1.275000e+02> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {device = "", value = dense<-1.280000e+02> : tensor} : () -> tensor + %cst_2 = "tf.Const"() {device = "", value = dense<[-241, 5894, -3771]> : tensor<3xi32>} : () -> tensor<3xi32> + %cst_3 = "tf.Const"() {device = "", value = dense_resource<__elided__> : tensor<1024x3xi8>} : () -> tensor<1024x3xi8> + %cst_4 = "tf.Const"() {device = "", value = dense<0.00392156513> : tensor} : () -> tensor + %cst_5 = "tf.Const"() {device = "", value = dense<-128> : tensor} : () -> tensor + %cst_6 = 
"tf.Const"() {device = "", value = dense<1.270000e+02> : tensor} : () -> tensor + %0 = "tf.Div"(%arg0, %cst_4) {device = ""} : (tensor<2x1x1024xf32>, tensor) -> tensor<2x1x1024xf32> + %1 = "tf.AddV2"(%0, %cst_0) {device = ""} : (tensor<2x1x1024xf32>, tensor) -> tensor<2x1x1024xf32> + %2 = "tf.Floor"(%1) {device = ""} : (tensor<2x1x1024xf32>) -> tensor<2x1x1024xf32> + %3 = "tf.ClipByValue"(%2, %cst_1, %cst_6) {device = ""} : (tensor<2x1x1024xf32>, tensor, tensor) -> tensor<2x1x1024xf32> + %4 = "tf.Cast"(%3) {Truncate = false, device = ""} : (tensor<2x1x1024xf32>) -> tensor<2x1x1024xi8> + %5 = "tf.Cast"(%4) {Truncate = false, device = ""} : (tensor<2x1x1024xi8>) -> tensor<2x1x1024xi32> + %6 = "tf.Sub"(%5, %cst_5) {device = ""} : (tensor<2x1x1024xi32>, tensor) -> tensor<2x1x1024xi32> + %7 = "tf.Identity"(%cst_3) {device = ""} : (tensor<1024x3xi8>) -> tensor<1024x3xi8> + %8 = "tf.Cast"(%7) {Truncate = false, device = ""} : (tensor<1024x3xi8>) -> tensor<1024x3xi32> + %9 = "tf.BatchMatMulV2"(%6, %8) {adj_x = false, adj_y = false, device = ""} : (tensor<2x1x1024xi32>, tensor<1024x3xi32>) -> tensor<2x1x3xi32> + %10 = "tf.AddV2"(%9, %cst_2) {device = ""} : (tensor<2x1x3xi32>, tensor<3xi32>) -> tensor<2x1x3xi32> + %11 = "tf.Cast"(%10) {Truncate = false, device = ""} : (tensor<2x1x3xi32>) -> tensor<2x1x3xf32> + %12 = "tf.Mul"(%11, %cst) {device = ""} : (tensor<2x1x3xf32>, tensor) -> tensor<2x1x3xf32> + %13 = "tf.Identity"(%12) {device = ""} : (tensor<2x1x3xf32>) -> tensor<2x1x3xf32> + %14 = "tf.Identity"(%13) {device = ""} : (tensor<2x1x3xf32>) -> tensor<2x1x3xf32> + return %14 : tensor<2x1x3xf32> + } + +// CHECK-LABEL: func @broadcasting_weight_batch_matmul +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<[2, 1024, 3]> : tensor<3xi64>} : () -> tensor<3xi64> +// CHECK: %[[CAST:.*]] = "tf.Cast" +// CHECK: %[[BROADCAST_TO:.*]] = "tf.BroadcastTo"({{.*}}, %[[CONST]]) : (tensor<1024x3xi8>, tensor<3xi64>) -> tensor<2x1024x3xi8> +// CHECK: %[[XLADOTV2_0:.*]] = "tf.XlaDotV2"(%[[CAST]], %[[BROADCAST_TO]]) +} + +// ----- + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1213 : i32}, tf_saved_model.semantics} { + func.func @broadcasting_input_batch_matmul(%arg0: tensor<2x1x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<2x2x1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst = "tf.Const"() {device = "", value = dense<3.08762283E-5> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {device = "", value = dense<-1.275000e+02> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {device = "", value = dense<-1.280000e+02> : tensor} : () -> tensor + %cst_2 = "tf.Const"() {device = "", value = dense<[-241, 5894, -3771]> : tensor<3xi32>} : () -> tensor<3xi32> + %cst_3 = "tf.Const"() {device = "", value = dense_resource<__elided__> : tensor<2x2x1024x3xi8>} : () -> tensor<2x2x1024x3xi8> + %cst_4 = "tf.Const"() {device = "", value = dense<0.00392156513> : tensor} : () -> tensor + %cst_5 = "tf.Const"() {device = "", value = dense<-128> : tensor} : () -> tensor + %cst_6 = "tf.Const"() {device = "", value = dense<1.270000e+02> : tensor} : () -> tensor + %0 = "tf.Div"(%arg0, %cst_4) {device = ""} : (tensor<2x1x1024xf32>, tensor) -> tensor<2x1x1024xf32> + %1 = "tf.AddV2"(%0, %cst_0) {device = ""} : (tensor<2x1x1024xf32>, tensor) -> tensor<2x1x1024xf32> + %2 = 
"tf.Floor"(%1) {device = ""} : (tensor<2x1x1024xf32>) -> tensor<2x1x1024xf32> + %3 = "tf.ClipByValue"(%2, %cst_1, %cst_6) {device = ""} : (tensor<2x1x1024xf32>, tensor, tensor) -> tensor<2x1x1024xf32> + %4 = "tf.Cast"(%3) {Truncate = false, device = ""} : (tensor<2x1x1024xf32>) -> tensor<2x1x1024xi8> + %5 = "tf.Cast"(%4) {Truncate = false, device = ""} : (tensor<2x1x1024xi8>) -> tensor<2x1x1024xi32> + %6 = "tf.Sub"(%5, %cst_5) {device = ""} : (tensor<2x1x1024xi32>, tensor) -> tensor<2x1x1024xi32> + %7 = "tf.Identity"(%cst_3) {device = ""} : (tensor<2x2x1024x3xi8>) -> tensor<2x2x1024x3xi8> + %8 = "tf.Cast"(%7) {Truncate = false, device = ""} : (tensor<2x2x1024x3xi8>) -> tensor<2x2x1024x3xi32> + %9 = "tf.BatchMatMulV2"(%6, %8) {adj_x = false, adj_y = false, device = ""} : (tensor<2x1x1024xi32>, tensor<2x2x1024x3xi32>) -> tensor<2x2x1x3xi32> + %10 = "tf.AddV2"(%9, %cst_2) {device = ""} : (tensor<2x2x1x3xi32>, tensor<3xi32>) -> tensor<2x2x1x3xi32> + %11 = "tf.Cast"(%10) {Truncate = false, device = ""} : (tensor<2x2x1x3xi32>) -> tensor<2x2x1x3xf32> + %12 = "tf.Mul"(%11, %cst) {device = ""} : (tensor<2x2x1x3xf32>, tensor) -> tensor<2x2x1x3xf32> + %13 = "tf.Identity"(%12) {device = ""} : (tensor<2x2x1x3xf32>) -> tensor<2x2x1x3xf32> + %14 = "tf.Identity"(%13) {device = ""} : (tensor<2x2x1x3xf32>) -> tensor<2x2x1x3xf32> + return %14 : tensor<2x2x1x3xf32> + } + +// CHECK-LABEL: func @broadcasting_input_batch_matmul +// CHECK-DAG: %[[WEIGHT:.*]] = "tf.Const"() {device = "", value = {{.*}} : tensor<2x2x1024x3xi8>} : () -> tensor<2x2x1024x3xi8> +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<[2, 2, 1, 1024]> : tensor<4xi64>} : () -> tensor<4xi64> +// CHECK: %[[CAST:.*]] = "tf.Cast" +// CHECK: %[[BROADCAST_TO:.*]] = "tf.BroadcastTo"(%[[CAST]], %[[CONST]]) : (tensor<2x1x1024xi8>, tensor<4xi64>) -> tensor<2x2x1x1024xi8> +// CHECK: %[[XLADOTV2_0:.*]] = "tf.XlaDotV2"(%[[BROADCAST_TO]], %[[WEIGHT]]) +} + +// ----- + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1213 : i32}, tf_saved_model.semantics} { + func.func @dynamic_shape_batch_matmul(%arg0: tensor {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst = "tf.Const"() {device = "", value = dense<3.08762283E-5> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {device = "", value = dense<-1.275000e+02> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {device = "", value = dense<-1.280000e+02> : tensor} : () -> tensor + %cst_2 = "tf.Const"() {device = "", value = dense<[-241, 5894, -3771]> : tensor<3xi32>} : () -> tensor<3xi32> + %cst_3 = "tf.Const"() {device = "", value = dense_resource<__elided__> : tensor<1024x3xi8>} : () -> tensor<1024x3xi8> + %cst_4 = "tf.Const"() {device = "", value = dense<0.00392156513> : tensor} : () -> tensor + %cst_5 = "tf.Const"() {device = "", value = dense<-128> : tensor} : () -> tensor + %cst_6 = "tf.Const"() {device = "", value = dense<1.270000e+02> : tensor} : () -> tensor + %0 = "tf.Div"(%arg0, %cst_4) {device = ""} : (tensor, tensor) -> tensor + %1 = "tf.AddV2"(%0, %cst_0) {device = ""} : (tensor, tensor) -> tensor + %2 = "tf.Floor"(%1) {device = ""} : (tensor) -> tensor + %3 = "tf.ClipByValue"(%2, %cst_1, %cst_6) {device = ""} : (tensor, tensor, tensor) -> tensor + %4 = "tf.Cast"(%3) {Truncate = false, device = ""} : (tensor) -> 
tensor + %5 = "tf.Cast"(%4) {Truncate = false, device = ""} : (tensor) -> tensor + %6 = "tf.Sub"(%5, %cst_5) {device = ""} : (tensor, tensor) -> tensor + %7 = "tf.Identity"(%cst_3) {device = ""} : (tensor<1024x3xi8>) -> tensor<1024x3xi8> + %8 = "tf.Cast"(%7) {Truncate = false, device = ""} : (tensor<1024x3xi8>) -> tensor<1024x3xi32> + %9 = "tf.BatchMatMulV2"(%6, %8) {adj_x = false, adj_y = false, device = ""} : (tensor, tensor<1024x3xi32>) -> tensor + %10 = "tf.AddV2"(%9, %cst_2) {device = ""} : (tensor, tensor<3xi32>) -> tensor + %11 = "tf.Cast"(%10) {Truncate = false, device = ""} : (tensor) -> tensor + %12 = "tf.Mul"(%11, %cst) {device = ""} : (tensor, tensor) -> tensor + %13 = "tf.Identity"(%12) {device = ""} : (tensor) -> tensor + %14 = "tf.Identity"(%13) {device = ""} : (tensor) -> tensor + return %14 : tensor + } + +// CHECK-LABEL: func @dynamic_shape_batch_matmul +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> +// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> +// CHECK-DAG: %[[CONST_2:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> +// CHECK-DAG: %[[CONST_3:.*]] = "tf.Const"() {value = dense<[1024, 3]> : tensor<2xi64>} : () -> tensor<2xi64> +// CHECK-DAG: %[[CONST_4:.*]] = "tf.Const"() {value = dense<> : tensor<0xi64>} : () -> tensor<0xi64> +// CHECK-DAG: %[[CONST_5:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor +// CHECK-DAG: %[[WEIGHT:.*]] = "tf.Const"() {device = "", value = {{.*}} : tensor<1024x3xi8>} : () -> tensor<1024x3xi8> +// CHECK: %[[CAST:.*]] = "tf.Cast"({{.*}}) {Truncate = false, device = ""} : (tensor) -> tensor +// CHECK: %[[SHAPE:.*]] = "tf.Shape"(%[[CAST]]) : (tensor) -> tensor<3xi64> +// CHECK: %[[SLICE_1:.*]] = "tf.Slice"(%[[SHAPE]], %[[CONST]], %[[CONST_2]]) : (tensor<3xi64>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> +// CHECK: %[[SLICE_2:.*]] = "tf.Slice"(%[[SHAPE]], %[[CONST_2]], %[[CONST_1]]) : (tensor<3xi64>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi64> +// CHECK: %[[BROADCAST_ARGS:.*]] = "tf.BroadcastArgs"(%[[SLICE_1]], %[[CONST_4]]) : (tensor<1xi64>, tensor<0xi64>) -> tensor<1xi64> +// CHECK: %[[CONCAT_1:.*]] = "tf.Concat"(%[[CONST_5]], %[[BROADCAST_ARGS]], %[[SLICE_2]]) : (tensor, tensor<1xi64>, tensor<2xi64>) -> tensor<3xi64> +// CHECK: %[[CONCAT_2:.*]] = "tf.Concat"(%[[CONST_5]], %[[BROADCAST_ARGS]], %[[CONST_3]]) : (tensor, tensor<1xi64>, tensor<2xi64>) -> tensor<3xi64> +// CHECK: %[[BROADCAST_1:.*]] = "tf.BroadcastTo"(%[[CAST]], %[[CONCAT_1]]) : (tensor, tensor<3xi64>) -> tensor +// CHECK: %[[BROADCAST_2:.*]] = "tf.BroadcastTo"(%[[WEIGHT]], %[[CONCAT_2]]) : (tensor<1024x3xi8>, tensor<3xi64>) -> tensor +// CHECK: %[[DOT:.*]] = "tf.XlaDotV2"(%[[BROADCAST_1]], %[[BROADCAST_2]]) +} + +// ----- + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1213 : i32}} { + func.func @batch_matmul_two_tensors_with_static_shape(%arg0: tensor<2x2x2xf32>, %arg1: tensor<2x2x2xf32>) -> (tensor<2x2x2xf32>) { + %cst = "tf.Const"() {value = dense<-5.450000e+01> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {value = dense<0.0156862754> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {value = dense<-5.000000e-01> : tensor} : () -> tensor + %cst_2 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %cst_3 = "tf.Const"() {value = dense<0.0274509806> : tensor} : () -> tensor + %cst_4 = "tf.Const"() {value = dense<-55> : tensor} : () -> tensor + %cst_5 = "tf.Const"() {value 
= dense<-1.280000e+02> : tensor} : () -> tensor + %cst_6 = "tf.Const"() {value = dense<1.270000e+02> : tensor} : () -> tensor + %0 = "tf.Div"(%arg1, %cst_0) : (tensor<2x2x2xf32>, tensor) -> tensor<2x2x2xf32> + %1 = "tf.AddV2"(%0, %cst_1) : (tensor<2x2x2xf32>, tensor) -> tensor<2x2x2xf32> + %2 = "tf.Floor"(%1) : (tensor<2x2x2xf32>) -> tensor<2x2x2xf32> + %3 = "tf.ClipByValue"(%2, %cst_5, %cst_6) : (tensor<2x2x2xf32>, tensor, tensor) -> tensor<2x2x2xf32> + %4 = "tf.Cast"(%3) {Truncate = false} : (tensor<2x2x2xf32>) -> tensor<2x2x2xi8> + %5 = "tf.Div"(%arg0, %cst_3) : (tensor<2x2x2xf32>, tensor) -> tensor<2x2x2xf32> + %6 = "tf.AddV2"(%5, %cst) : (tensor<2x2x2xf32>, tensor) -> tensor<2x2x2xf32> + %7 = "tf.Floor"(%6) : (tensor<2x2x2xf32>) -> tensor<2x2x2xf32> + %8 = "tf.ClipByValue"(%7, %cst_5, %cst_6) : (tensor<2x2x2xf32>, tensor, tensor) -> tensor<2x2x2xf32> + %9 = "tf.Cast"(%8) {Truncate = false} : (tensor<2x2x2xf32>) -> tensor<2x2x2xi8> + %10 = "tf.Cast"(%4) {Truncate = false} : (tensor<2x2x2xi8>) -> tensor<2x2x2xi32> + %11 = "tf.Sub"(%10, %cst_2) : (tensor<2x2x2xi32>, tensor) -> tensor<2x2x2xi32> + %12 = "tf.Identity"(%9) : (tensor<2x2x2xi8>) -> tensor<2x2x2xi8> + %13 = "tf.Cast"(%12) {Truncate = false} : (tensor<2x2x2xi8>) -> tensor<2x2x2xi32> + %14 = "tf.Sub"(%13, %cst_4) : (tensor<2x2x2xi32>, tensor) -> tensor<2x2x2xi32> + %15 = "tf.BatchMatMulV2"(%11, %14) {adj_x = false, adj_y = false} : (tensor<2x2x2xi32>, tensor<2x2x2xi32>) -> tensor<2x2x2xi32> + %16 = "tf.Cast"(%15) {Truncate = false} : (tensor<2x2x2xi32>) -> tensor<2x2x2xf32> + %17 = "tf.Mul"(%16, %cst_0) : (tensor<2x2x2xf32>, tensor) -> tensor<2x2x2xf32> + %18 = "tf.AddV2"(%17, %cst) : (tensor<2x2x2xf32>, tensor) -> tensor<2x2x2xf32> + %19 = "tf.Floor"(%18) : (tensor<2x2x2xf32>) -> tensor<2x2x2xf32> + %20 = "tf.ClipByValue"(%19, %cst_5, %cst_6) : (tensor<2x2x2xf32>, tensor, tensor) -> tensor<2x2x2xf32> + %21 = "tf.Cast"(%20) {Truncate = false} : (tensor<2x2x2xf32>) -> tensor<2x2x2xi8> + %22 = "tf.Identity"(%21) {device = ""} : (tensor<2x2x2xi8>) -> tensor<2x2x2xi8> + %23 = "tf.Identity"(%22) {device = ""} : (tensor<2x2x2xi8>) -> tensor<2x2x2xi8> + %24 = "tf.Cast"(%23) : (tensor<2x2x2xi8>) -> tensor<2x2x2xi32> + %25 = "tf.Sub"(%24, %cst_4) : (tensor<2x2x2xi32>, tensor) -> tensor<2x2x2xi32> + %26 = "tf.Cast"(%25) : (tensor<2x2x2xi32>) -> tensor<2x2x2xf32> + %27 = "tf.Mul"(%26, %cst_3) : (tensor<2x2x2xf32>, tensor) -> tensor<2x2x2xf32> + return %27 : tensor<2x2x2xf32> + } + +// CHECK-LABEL: func @batch_matmul_two_tensors_with_static_shape +// CHECK: %[[arg1_div:.*]] = "tf.Div"(%arg1 +// CHECK: %[[arg1_add:.*]] = "tf.AddV2"(%[[arg1_div]] +// CHECK: %[[arg1_floor:.*]] = "tf.Floor"(%[[arg1_add]] +// CHECK: %[[arg1_clip:.*]] = "tf.ClipByValue"(%[[arg1_floor]] +// CHECK: %[[arg1_cast:.*]] = "tf.Cast"(%[[arg1_clip]] + +// CHECK: %[[arg0_div:.*]] = "tf.Div"(%arg0 +// CHECK: %[[arg0_add:.*]] = "tf.AddV2"(%[[arg0_div]] +// CHECK: %[[arg0_floor:.*]] = "tf.Floor"(%[[arg0_add]] +// CHECK: %[[arg0_clip:.*]] = "tf.ClipByValue"(%[[arg0_floor]] +// CHECK: %[[arg0_cast:.*]] = "tf.Cast"(%[[arg0_clip]] + +// CHECK: %[[matmul:.*]] = "tf.XlaDotV2"(%[[arg1_cast]], %[[arg0_cast]] +// CHECK-SAME: (tensor<2x2x2xi8>, tensor<2x2x2xi8>) -> tensor<2x2x2xi32> + +// CHECK: %[[matmul_sub:.*]] = "tf.Sub"(%[[matmul]] +// CHECK: %[[matmul_cast:.*]] = "tf.Cast"(%[[matmul_sub]] +// CHECK: %[[matmul_mul:.*]] = "tf.Mul"(%[[matmul_cast]] +// CHECK: %[[matmul_add:.*]] = "tf.AddV2"(%[[matmul_mul]] +// CHECK: %[[matmul_floor:.*]] = "tf.Floor"(%[[matmul_add]] +// CHECK: 
%[[matmul_clip:.*]] = "tf.ClipByValue"(%[[matmul_floor]] +} + +// ----- + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1213 : i32}} { + func.func @batch_matmul_two_tensors_with_dynamic_shape(%arg0: tensor<2x?x?xf32>, %arg1: tensor<2x?x?xf32>) -> (tensor<2x?x?xf32>) { + %cst = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> + %cst_0 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> + %cst_1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %cst_2 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %cst_3 = "tf.Const"() {value = dense<2> : tensor<1xi64>} : () -> tensor<1xi64> + %cst_4 = "tf.Const"() {value = dense<-55> : tensor} : () -> tensor + %cst_5 = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64> + %cst_6 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %cst_7 = "tf.Const"() {value = dense<55> : tensor} : () -> tensor + %cst_8 = "tf.Const"() {value = dense<-5.450000e+01> : tensor} : () -> tensor + %cst_9 = "tf.Const"() {value = dense<0.0156862754> : tensor} : () -> tensor + %cst_10 = "tf.Const"() {value = dense<-5.000000e-01> : tensor} : () -> tensor + %cst_11 = "tf.Const"() {value = dense<0.0274509806> : tensor} : () -> tensor + %cst_12 = "tf.Const"() {value = dense<-1.280000e+02> : tensor} : () -> tensor + %cst_13 = "tf.Const"() {value = dense<1.270000e+02> : tensor} : () -> tensor + %0 = "tf.Div"(%arg1, %cst_9) : (tensor<2x?x?xf32>, tensor) -> tensor<2x?x?xf32> + %1 = "tf.AddV2"(%0, %cst_10) : (tensor<2x?x?xf32>, tensor) -> tensor<2x?x?xf32> + %2 = "tf.Floor"(%1) : (tensor<2x?x?xf32>) -> tensor<2x?x?xf32> + %3 = "tf.ClipByValue"(%2, %cst_12, %cst_13) : (tensor<2x?x?xf32>, tensor, tensor) -> tensor<2x?x?xf32> + %4 = "tf.Cast"(%3) {Truncate = false} : (tensor<2x?x?xf32>) -> tensor<2x?x?xi8> + %5 = "tf.Div"(%arg0, %cst_11) : (tensor<2x?x?xf32>, tensor) -> tensor<2x?x?xf32> + %6 = "tf.AddV2"(%5, %cst_8) : (tensor<2x?x?xf32>, tensor) -> tensor<2x?x?xf32> + %7 = "tf.Floor"(%6) : (tensor<2x?x?xf32>) -> tensor<2x?x?xf32> + %8 = "tf.ClipByValue"(%7, %cst_12, %cst_13) : (tensor<2x?x?xf32>, tensor, tensor) -> tensor<2x?x?xf32> + %9 = "tf.Cast"(%8) {Truncate = false} : (tensor<2x?x?xf32>) -> tensor<2x?x?xi8> + %10 = "tf.Shape"(%4) : (tensor<2x?x?xi8>) -> tensor<3xi64> + %11 = "tf.Slice"(%10, %cst, %cst_1) : (tensor<3xi64>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> + %12 = "tf.Slice"(%10, %cst_1, %cst_0) : (tensor<3xi64>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi64> + %13 = "tf.Shape"(%9) : (tensor<2x?x?xi8>) -> tensor<3xi64> + %14 = "tf.Slice"(%13, %cst, %cst_1) : (tensor<3xi64>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> + %15 = "tf.Slice"(%13, %cst_1, %cst_0) : (tensor<3xi64>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi64> + %16 = "tf.BroadcastArgs"(%11, %14) : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi64> + %17 = "tf.Concat"(%cst_2, %16, %12) : (tensor, tensor<1xi64>, tensor<2xi64>) -> tensor<3xi64> + %18 = "tf.Concat"(%cst_2, %16, %15) : (tensor, tensor<1xi64>, tensor<2xi64>) -> tensor<3xi64> + %19 = "tf.BroadcastTo"(%4, %17) : (tensor<2x?x?xi8>, tensor<3xi64>) -> tensor<2x?x?xi8> + %20 = "tf.BroadcastTo"(%9, %18) : (tensor<2x?x?xi8>, tensor<3xi64>) -> tensor<2x?x?xi8> + %21 = "tf.XlaDotV2"(%19, %20) {dimension_numbers = "\22\01\00\1A\01\00\12\01\01\0A\01\02", precision_config = ""} : (tensor<2x?x?xi8>, tensor<2x?x?xi8>) -> tensor<2x?x?xi32> + %22 = "tf.Cast"(%19) {Truncate = false} : (tensor<2x?x?xi8>) -> tensor<2x?x?xi32> + 
%23 = "tf.Sum"(%22, %cst_3) {keep_dims = true} : (tensor<2x?x?xi32>, tensor<1xi64>) -> tensor<2x?x1xi32> + %24 = "tf.Mul"(%23, %cst_4) : (tensor<2x?x1xi32>, tensor) -> tensor<2x?x1xi32> + %25 = "tf.Cast"(%20) {Truncate = false} : (tensor<2x?x?xi8>) -> tensor<2x?x?xi32> + %26 = "tf.Sum"(%25, %cst_5) {keep_dims = true} : (tensor<2x?x?xi32>, tensor<1xi64>) -> tensor<2x1x?xi32> + %27 = "tf.Mul"(%26, %cst_6) : (tensor<2x1x?xi32>, tensor) -> tensor<2x1x?xi32> + %28 = "tf.Shape"(%20) : (tensor<2x?x?xi8>) -> tensor<3xi64> + %29 = "tf.StridedSlice"(%28, %cst_5, %cst_3, %cst_5) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<3xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<1xi64> + %30 = "tf.Cast"(%29) {Truncate = false} : (tensor<1xi64>) -> tensor<1xi32> + %31 = "tf.Mul"(%30, %cst_7) : (tensor<1xi32>, tensor) -> tensor<1xi32> + %32 = "tf.Add"(%24, %27) : (tensor<2x?x1xi32>, tensor<2x1x?xi32>) -> tensor<2x?x?xi32> + %33 = "tf.Sub"(%32, %31) : (tensor<2x?x?xi32>, tensor<1xi32>) -> tensor<2x?x?xi32> + %34 = "tf.Sub"(%21, %33) : (tensor<2x?x?xi32>, tensor<2x?x?xi32>) -> tensor<2x?x?xi32> + %35 = "tf.Cast"(%34) {Truncate = false} : (tensor<2x?x?xi32>) -> tensor<2x?x?xf32> + %36 = "tf.Mul"(%35, %cst_9) : (tensor<2x?x?xf32>, tensor) -> tensor<2x?x?xf32> + %37 = "tf.AddV2"(%36, %cst_8) : (tensor<2x?x?xf32>, tensor) -> tensor<2x?x?xf32> + %38 = "tf.Floor"(%37) : (tensor<2x?x?xf32>) -> tensor<2x?x?xf32> + %39 = "tf.ClipByValue"(%38, %cst_12, %cst_13) : (tensor<2x?x?xf32>, tensor, tensor) -> tensor<2x?x?xf32> + %40 = "tf.Cast"(%39) {Truncate = false} : (tensor<2x?x?xf32>) -> tensor<2x?x?xi8> + %41 = "tf.Identity"(%40) {device = ""} : (tensor<2x?x?xi8>) -> tensor<2x?x?xi8> + %42 = "tf.Identity"(%41) {device = ""} : (tensor<2x?x?xi8>) -> tensor<2x?x?xi8> + %43 = "tf.Cast"(%42) : (tensor<2x?x?xi8>) -> tensor<2x?x?xi32> + %44 = "tf.Sub"(%43, %cst_4) : (tensor<2x?x?xi32>, tensor) -> tensor<2x?x?xi32> + %45 = "tf.Cast"(%44) : (tensor<2x?x?xi32>) -> tensor<2x?x?xf32> + %46 = "tf.Mul"(%45, %cst_11) : (tensor<2x?x?xf32>, tensor) -> tensor<2x?x?xf32> + return %46 : tensor<2x?x?xf32> + } + +// CHECK-LABEL: func @batch_matmul_two_tensors_with_dynamic_shape +// CHECK: %[[arg1_div:.*]] = "tf.Div"(%arg1 +// CHECK: %[[arg1_add:.*]] = "tf.AddV2"(%[[arg1_div]] +// CHECK: %[[arg1_floor:.*]] = "tf.Floor"(%[[arg1_add]] +// CHECK: %[[arg1_clip:.*]] = "tf.ClipByValue"(%[[arg1_floor]] +// CHECK: %[[arg1_cast:.*]] = "tf.Cast"(%[[arg1_clip]] + +// CHECK: %[[arg0_div:.*]] = "tf.Div"(%arg0 +// CHECK: %[[arg0_add:.*]] = "tf.AddV2"(%[[arg0_div]] +// CHECK: %[[arg0_floor:.*]] = "tf.Floor"(%[[arg0_add]] +// CHECK: %[[arg0_clip:.*]] = "tf.ClipByValue"(%[[arg0_floor]] +// CHECK: %[[arg0_cast:.*]] = "tf.Cast"(%[[arg0_clip]] + +// CHECK: %[[arg1_broad:.*]] = "tf.BroadcastTo"(%[[arg1_cast]] +// CHECK: %[[arg0_broad:.*]] = "tf.BroadcastTo"(%[[arg0_cast]] + +// CHECK: %[[matmul:.*]] = "tf.XlaDotV2"(%[[arg1_broad]], %[[arg0_broad]] +// CHECK-SAME: (tensor<2x?x?xi8>, tensor<2x?x?xi8>) -> tensor<2x?x?xi32> + +// CHECK: %[[arg0_shape:.*]] = "tf.Shape"(%[[arg0_broad]] +// CHECK: %[[shape_zp_contribute:.*]] = "tf.StridedSlice"(%[[arg0_shape]] +// CHECK: %[[shape_zp_contribute_cast:.*]] = "tf.Cast"(%[[shape_zp_contribute]] +// CHECK: %[[shape_zp_contribute_mul:.*]] = "tf.Mul"(%[[shape_zp_contribute_cast]] +// CHECK: %[[zp:.*]] = "tf.Sub"({{.*}}, %[[shape_zp_contribute_mul]]) + +// CHECK: %[[matmul_sub:.*]] = "tf.Sub"(%[[matmul]], %[[zp]] +// CHECK: 
%[[matmul_cast:.*]] = "tf.Cast"(%[[matmul_sub]] +// CHECK: %[[matmul_mul:.*]] = "tf.Mul"(%[[matmul_cast]] +// CHECK: %[[matmul_add:.*]] = "tf.AddV2"(%[[matmul_mul]] +// CHECK: %[[matmul_floor:.*]] = "tf.Floor"(%[[matmul_add]] +// CHECK: %[[matmul_clip:.*]] = "tf.ClipByValue"(%[[matmul_floor]] +} + diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/replace_cast_hacks_with_tf_xla_ops_large_constants.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/replace_cast_hacks_with_tf_xla_ops_large_constants.mlir index 59e9b0ea80a..775ab82e105 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/replace_cast_hacks_with_tf_xla_ops_large_constants.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/replace_cast_hacks_with_tf_xla_ops_large_constants.mlir @@ -56,7 +56,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p } // CHECK-LABEL: func @conv_with_filter_larger_than_1GB -// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<-237772800> : tensor<512xi32>} : () -> tensor<512xi32> +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {value = dense<-237772800> : tensor<1x1x1x512xi32>} : () -> tensor<1x1x1x512xi32> // CHECK: %[[PADV2_0:.*]] = "tf.PadV2" // CHECK: %[[XLACONVV2_0:.*]] = "tf.XlaConvV2"(%[[PADV2_0]] // CHECK: %[[SUB_0:.*]] = "tf.Sub"(%[[XLACONVV2_0]], %[[CONST]]) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/unfreeze_constants.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/unfreeze_constants.mlir index a3faed341bb..600b2926631 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/unfreeze_constants.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/unfreeze_constants.mlir @@ -1,9 +1,9 @@ -// RUN: tf-quant-opt %s -quant-unfreeze-constants -allow-unregistered-dialect \ -// RUN: -mlir-disable-threading -split-input-file -verify-diagnostics | \ -// RUN: FileCheck %s +// RUN: tf-quant-opt %s -quant-unfreeze-constants='size_threshold_in_bytes=16' \ +// RUN: -allow-unregistered-dialect -mlir-disable-threading \ +// RUN: -split-input-file -verify-diagnostics | FileCheck %s // Tests a case with one ConstOp and a tf_saved_model.session_initializer with an empty initializers. -module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1287 : i32}, tf_saved_model.semantics} { +module attributes {tf_saved_model.semantics} { "tf_saved_model.session_initializer"() {initializers = []} : () -> () // Check that the init function is created & added to the initializers attribute. @@ -34,7 +34,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // ----- // Tests the case when there's no tf_saved_model.sesion_initializer. -module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1287 : i32}, tf_saved_model.semantics} { +module attributes {tf_saved_model.semantics} { // Check that a new tf_saved_model.session_initializer is created, along with an initialier function. 
// CHECK: "tf_saved_model.session_initializer"() @@ -44,34 +44,34 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // CHECK-SAME: tf_saved_model.exported_names = ["tf_saved_model.session_initializer_restore_op"] // CHECK-SAME: tf_saved_model.initializer_type = "restore_op" -// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<2xf32>} +// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() {{{.*value = dense<1.000000e\+00> : tensor<8xf32>.*}}} // CHECK-DAG: %[[VAR_HANDLE_0:.*]] = "tf.VarHandleOp"() {{.*shared_name = "const_0".*}} // CHECK-DAG: "tf.AssignVariableOp"(%[[VAR_HANDLE_0]], %[[CST_0]]) -// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() {value = dense<2.000000e+00> : tensor<2xf32>} +// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() {{{.*value = dense<2.000000e\+00> : tensor<8xf32>.*}}} // CHECK-DAG: %[[VAR_HANDLE_1:.*]] = "tf.VarHandleOp"() {{.*shared_name = "const_1".*}} // CHECK-DAG: "tf.AssignVariableOp"(%[[VAR_HANDLE_1]], %[[CST_1]]) - func.func @serving_default() -> (tensor<2xf32> {tf_saved_model.index_path = ["output"]}) + func.func @serving_default() -> (tensor<8xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "output:0"}, tf_saved_model.exported_names = ["serving_default"]} { - %cst_0 = "tf.Const"() {device = "", value = dense<1.0> : tensor<2xf32>} : () -> tensor<2xf32> - %cst_1 = "tf.Const"() {device = "", value = dense<2.0> : tensor<2xf32>} : () -> tensor<2xf32> - %0 = "tf.Add"(%cst_0, %cst_1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> - return %0 : tensor<2xf32> + %cst_0 = "tf.Const"() {device = "", value = dense<1.0> : tensor<8xf32>} : () -> tensor<8xf32> + %cst_1 = "tf.Const"() {device = "", value = dense<2.0> : tensor<8xf32>} : () -> tensor<8xf32> + %0 = "tf.AddV2"(%cst_0, %cst_1) : (tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32> + return %0 : tensor<8xf32> } // CHECK: @serving_default -// CHECK-DAG: %[[VAR_HANDLE_2:.*]] = "tf.VarHandleOp"() {{.*shared_name = "const_0".*}} : () -> tensor>> -// CHECK-DAG: %[[READ_VAR_0:.*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE_2]]) : (tensor>>) -> tensor<2xf32> -// CHECK-DAG: %[[VAR_HANDLE_3:.*]] = "tf.VarHandleOp"() {{.*shared_name = "const_1".*}} : () -> tensor>> -// CHECK-DAG: %[[READ_VAR_1:.*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE_3]]) : (tensor>>) -> tensor<2xf32> -// CHECK-DAG: %[[ADD_0:.*]] = "tf.Add"(%[[READ_VAR_0]], %[[READ_VAR_1]]) -// CHECK: return %[[ADD_0]] : tensor<2xf32> +// CHECK-DAG: %[[VAR_HANDLE_2:.*]] = "tf.VarHandleOp"() {{.*shared_name = "const_0".*}} : () -> tensor>> +// CHECK-DAG: %[[READ_VAR_0:.*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE_2]]) : (tensor>>) -> tensor<8xf32> +// CHECK-DAG: %[[VAR_HANDLE_3:.*]] = "tf.VarHandleOp"() {{.*shared_name = "const_1".*}} : () -> tensor>> +// CHECK-DAG: %[[READ_VAR_1:.*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE_3]]) : (tensor>>) -> tensor<8xf32> +// CHECK-DAG: %[[ADD_0:.*]] = "tf.AddV2"(%[[READ_VAR_0]], %[[READ_VAR_1]]) +// CHECK: return %[[ADD_0]] : tensor<8xf32> } // ----- // Tests the case when there's a tf_saved_model.sesion_initializer and an empty init function. 
-module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1287 : i32}, tf_saved_model.semantics} { +module attributes {tf_saved_model.semantics} { "tf_saved_model.session_initializer"() {initializers = [@init]} : () -> () // CHECK: "tf_saved_model.session_initializer"() @@ -84,34 +84,34 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // CHECK-SAME: tf_saved_model.exported_names = ["tf_saved_model.session_initializer_init"] // CHECK-SAME: tf_saved_model.initializer_type = "restore_op" -// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<2xf32>} +// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<8xf32>} // CHECK-DAG: %[[VAR_HANDLE_0:.*]] = "tf.VarHandleOp"() // CHECK-DAG: "tf.AssignVariableOp"(%[[VAR_HANDLE_0]], %[[CST_0]]) -// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() {value = dense<2.000000e+00> : tensor<2xf32>} +// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() {value = dense<2.000000e+00> : tensor<8xf32>} // CHECK-DAG: %[[VAR_HANDLE_1:.*]] = "tf.VarHandleOp"() // CHECK-DAG: "tf.AssignVariableOp"(%[[VAR_HANDLE_1]], %[[CST_1]]) - func.func @serving_default(%arg0: tensor<2xf32> {tf_saved_model.index_path = ["input"]}) -> (tensor<2xf32> {tf_saved_model.index_path = ["output"]}) + func.func @serving_default(%arg0: tensor<8xf32> {tf_saved_model.index_path = ["input"]}) -> (tensor<8xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { - %cst_0 = "tf.Const"() {device = "", value = dense<1.0> : tensor<2xf32>} : () -> tensor<2xf32> - %cst_1 = "tf.Const"() {device = "", value = dense<2.0> : tensor<2xf32>} : () -> tensor<2xf32> - %0 = "tf.Sub"(%cst_0, %cst_1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> - return %0 : tensor<2xf32> + %cst_0 = "tf.Const"() {device = "", value = dense<1.0> : tensor<8xf32>} : () -> tensor<8xf32> + %cst_1 = "tf.Const"() {device = "", value = dense<2.0> : tensor<8xf32>} : () -> tensor<8xf32> + %0 = "tf.Sub"(%cst_0, %cst_1) : (tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32> + return %0 : tensor<8xf32> } // CHECK: @serving_default -// CHECK-DAG: %[[VAR_HANDLE_2:.*]] = "tf.VarHandleOp"() {{.*shared_name = "const_0".*}} : () -> tensor>> -// CHECK-DAG: %[[READ_VAR_0:.*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE_2]]) : (tensor>>) -> tensor<2xf32> -// CHECK-DAG: %[[VAR_HANDLE_3:.*]] = "tf.VarHandleOp"() {{.*shared_name = "const_1".*}} : () -> tensor>> -// CHECK-DAG: %[[READ_VAR_1:.*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE_3]]) : (tensor>>) -> tensor<2xf32> +// CHECK-DAG: %[[VAR_HANDLE_2:.*]] = "tf.VarHandleOp"() {{.*shared_name = "const_0".*}} : () -> tensor>> +// CHECK-DAG: %[[READ_VAR_0:.*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE_2]]) : (tensor>>) -> tensor<8xf32> +// CHECK-DAG: %[[VAR_HANDLE_3:.*]] = "tf.VarHandleOp"() {{.*shared_name = "const_1".*}} : () -> tensor>> +// CHECK-DAG: %[[READ_VAR_1:.*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE_3]]) : (tensor>>) -> tensor<8xf32> // CHECK-DAG: %[[SUB_0:.*]] = "tf.Sub"(%[[READ_VAR_0]], %[[READ_VAR_1]]) -// CHECK: return %[[SUB_0]] : tensor<2xf32> +// CHECK: return %[[SUB_0]] : tensor<8xf32> } // ----- // Tests the case when there's a tf_saved_model.sesion_initializer and an init function whose type is "init_op". 
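// (Initializer functions carry a tf_saved_model.initializer_type attribute:
// "restore_op" marks the function that restores variables from a checkpoint,
// while "init_op" typically marks other initialization such as table setup.)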
-module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1287 : i32}, tf_saved_model.semantics} { +module attributes {tf_saved_model.semantics} { "tf_saved_model.session_initializer"() {initializers = [@init]} : () -> () // Check that @init_func_restore_op is added to the initializers list. @@ -144,15 +144,63 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p // ----- // Tests the case when there is no ConstOp. -module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1287 : i32}, tf_saved_model.semantics} { +module attributes {tf_saved_model.semantics} { // Check that nothing happens when there's no ConstOp in the graph. // CHECK-NOT: "tf_saved_model.session_initializer"() - func.func @serving_default(%arg_0: tensor<3xf32> {tf_saved_model.index_path = ["input"]}) -> (tensor<3xf32> {tf_saved_model.index_path = ["output"]}) + func.func @serving_default(%arg_0: tensor<5xf32> {tf_saved_model.index_path = ["input"]}) -> (tensor<5xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "inputs:0", outputs = "output:0"}, tf_saved_model.exported_names = ["serving_default"]} { - return %arg_0 : tensor<3xf32> + return %arg_0 : tensor<5xf32> + } +// CHECK: @serving_default(%[[ARG_0:.*]]: tensor<5xf32> {{.*}}) +// CHECK-NEXT: return %[[ARG_0]] : tensor<5xf32> +} + +// ----- + +// Tests that constants that are smaller than "size_threshold_in_bytes" are +// not converted to variables. This test uses the threshold of 16 bytes. + +module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_func_restore_op]} : () -> () + + func.func @init_func_restore_op() attributes {tf_saved_model.exported_names = ["tf_saved_model.session_initializer_init"], + tf_saved_model.initializer_type = "restore_op"} { + return } -// CHECK: @serving_default(%[[ARG_0:.*]]: tensor<3xf32> {{.*}}) -// CHECK-NEXT: return %[[ARG_0]] : tensor<3xf32> + + func.func @serving_default() -> (tensor<12xf32> {tf_saved_model.index_path = ["output"]}) + attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "output:0"}, tf_saved_model.exported_names = ["serving_default"]} { + // Should be unfrozen. + %cst_0 = "tf.Const"() {value = dense<5.0> : tensor<8xf32>} : () -> tensor<8xf32> + // Consts below are smaller than or equal to the threshold so they + // should not be converted to variables. + %cst_1 = "tf.Const"() {value = dense<5.0> : tensor<4xf32>} : () -> tensor<4xf32> + %cst_axis = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %0 = "tf.ConcatV2"(%cst_0, %cst_1, %cst_axis) : (tensor<8xf32>, tensor<4xf32>, tensor) -> tensor<12xf32> + return %0 : tensor<12xf32> + } +// CHECK: func.func @init_func_restore_op() + +// Check that `tf.VarHandleOp` is only created for the constant that is larger +// than the threshold (16 bytes for this test). +// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() {{{.*value = dense<5.000000e\+00> : tensor<8xf32>.*}}} +// CHECK-DAG: %[[VAR_HANDLE_0:.*]] = "tf.VarHandleOp"() {{.*shared_name = "const_0".*}} +// CHECK-DAG: "tf.AssignVariableOp"(%[[VAR_HANDLE_0]], %[[CST_0]]) + +// Make sure that there are no more `tf.VarHandleOp`s and `tf.AssignVariableOp`s +// in this function. +// CHECK-NOT: "tf.VarHandleOp" +// CHECK-NOT: "tf.AssignVariableOp" + +// Only the large constant is replaced with the `tf.VarHandleOp -> +// tf.ReadVariableOp` pattern and others remain as `tf.Const`s. 
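+// With `size_threshold_in_bytes=16` on the RUN line, the sizes work out as:
+// tensor<8xf32> is 8 * 4 = 32 bytes (> 16, so it is unfrozen into a variable),
+// while tensor<4xf32> is 16 bytes and the scalar axis constant is at most
+// 8 bytes (both <= 16, so they stay as `tf.Const`). The same threshold is why
+// the earlier cases in this file now use tensor<8xf32> rather than
+// tensor<2xf32>, which is only 8 bytes.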
+// CHECK: @serving_default +// CHECK-DAG: %[[VAR_HANDLE_2:.*]] = "tf.VarHandleOp"() {{.*shared_name = "const_0".*}} : () -> tensor>> +// CHECK-DAG: %[[READ_VAR_0:.*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE_2]]) : (tensor>>) -> tensor<8xf32> +// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() {{{.*value = dense<5.000000e\+00> : tensor<4xf32>.*}}} +// CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {{{.*value = dense<0> : tensor.*}}} +// CHECK-DAG: %[[CONCAT:.*]] = "tf.ConcatV2"(%[[READ_VAR_0]], %[[CST_1]], %[[AXIS]]) +// CHECK: return %[[CONCAT]] : tensor<12xf32> } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD index d75b2806c2a..f2d89b2df75 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD @@ -3,6 +3,7 @@ load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") load("//tensorflow:tensorflow.bzl", "tf_cc_test") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ "//tensorflow/compiler/mlir/quantization/tensorflow:internal_visibility_allowlist_package", ], @@ -41,6 +42,7 @@ cc_library( compatible_with = get_compatible_with_cloud(), deps = [ "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", + "//tensorflow/compiler/mlir/quantization/tensorflow:pass_utils", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", "@com_google_absl//absl/strings", @@ -50,6 +52,22 @@ cc_library( ], ) +cc_library( + name = "tf_to_uniform_attribute_utils", + srcs = ["tf_to_uniform_attribute_utils.cc"], + hdrs = ["tf_to_uniform_attribute_utils.h"], + compatible_with = get_compatible_with_cloud(), + deps = [ + "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", + "//tensorflow/compiler/mlir/quantization/tensorflow:pass_utils", + "//tensorflow/compiler/mlir/quantization/tensorflow:uniform_op_quant_spec", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:flat_hash_map", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) + cc_library( name = "tf_to_xla_attribute_utils", srcs = ["tf_to_xla_attribute_utils.cc"], diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/fake_quant_utils.h b/tensorflow/compiler/mlir/quantization/tensorflow/utils/fake_quant_utils.h index e36f6a3c632..93c62d4492d 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/utils/fake_quant_utils.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/fake_quant_utils.h @@ -36,8 +36,8 @@ struct FetchMinMaxAttrs { using AttrType = FloatAttr; bool operator()(TFFakeQuantOp tf_op, AttrType &min_value, AttrType &max_value) const { - min_value = tf_op.minAttr(); - max_value = tf_op.maxAttr(); + min_value = tf_op.getMinAttr(); + max_value = tf_op.getMaxAttr(); return true; // Successfully matched and fetched. 
} }; @@ -47,7 +47,14 @@ struct FetchConstantMinMaxInputs { using AttrType = DenseFPElementsAttr; bool operator()(TFFakeQuantOp tf_op, AttrType &min_value, AttrType &max_value) const { - Value min = tf_op.min(), max = tf_op.max(); + Value min = tf_op.getMin(), max = tf_op.getMax(); + if (auto min_id = min.getDefiningOp()) { + min = min_id.getInput(); + } + if (auto max_id = max.getDefiningOp()) { + max = max_id.getInput(); + } + if (!matchPattern(min, m_Constant(&min_value))) { return false; } @@ -91,7 +98,7 @@ class ConvertFakeQuantOpToQuantOps { using FetchAttrType = typename FetchMinMax::AttrType; LogicalResult matchAndRewrite(TFFakeQuantOp tf_op, OpBuilder &rewriter) const { - if (tf_op.num_bits() != 8) { + if (tf_op.getNumBits() != 8) { return failure(); } @@ -103,7 +110,7 @@ class ConvertFakeQuantOpToQuantOps { return failure(); } - Value input = tf_op.inputs(); + Value input = tf_op.getInputs(); int quant_dim = -1; auto input_type = input.getType().template cast(); if (PerAxis) { @@ -117,8 +124,8 @@ class ConvertFakeQuantOpToQuantOps { // Use the min/max from the operands and the num_bits and narrow_range // attribute to create the quantization parameter for the new quantize op. rewriter.setInsertionPointAfter(tf_op.getOperation()); - IntegerAttr num_bits = rewriter.getI64IntegerAttr(tf_op.num_bits()); - BoolAttr narrow_range = rewriter.getBoolAttr(tf_op.narrow_range()); + IntegerAttr num_bits = rewriter.getI64IntegerAttr(tf_op.getNumBits()); + BoolAttr narrow_range = rewriter.getBoolAttr(tf_op.getNarrowRange()); Type res_type = tf_op.getType(); TypeAttr qtype = quant::GetQuantizedTypeAttr( rewriter, input_type, min_value, max_value, quant_dim, num_bits, @@ -135,7 +142,7 @@ class ConvertFakeQuantOpToQuantOps { tf_op.getLoc(), qtype.getValue(), input); auto dequantize = rewriter.create( tf_op.getLoc(), res_type, quantize.getResult()); - tf_op.outputs().replaceAllUsesWith(dequantize); + tf_op.getOutputs().replaceAllUsesWith(dequantize); return success(); } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/lift_as_function_call_utils.cc b/tensorflow/compiler/mlir/quantization/tensorflow/utils/lift_as_function_call_utils.cc index 45311e45f7b..25af606123c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/utils/lift_as_function_call_utils.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/lift_as_function_call_utils.cc @@ -21,6 +21,7 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -62,7 +63,7 @@ ValueRange createFusedFnCall(OpBuilder builder, Location location, builder.getStringAttr(llvm::StringRef( std::string(QuantTraitValues[QuantizationTrait::FullyQuantizable])))); - return call_op.output(); + return call_op.getOutput(); } // Finds ops in the paths from arguments to results. 
The ops is listed in an @@ -201,7 +202,7 @@ llvm::SmallVector LiftAsFunctionCall( builder.createBlock(&wrap_func.getBody(), wrap_func.begin(), arg_types, arg_locs); - BlockAndValueMapping mapping; + IRMapping mapping; for (int32_t i : llvm::seq(0, arguments.size())) { mapping.map(arguments[i], wrap_func.getArgument(i)); } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/lift_as_function_call_utils.h b/tensorflow/compiler/mlir/quantization/tensorflow/utils/lift_as_function_call_utils.h index 46fc7dfa837..7cface2dc98 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/utils/lift_as_function_call_utils.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/lift_as_function_call_utils.h @@ -27,7 +27,6 @@ limitations under the License. namespace mlir { namespace quant { -inline constexpr absl::string_view kAttrMapAttribute = "attr_map"; // This attribute will be set for functions created by this pass. inline constexpr absl::string_view kFusedFunctionAttr = "tf_quant.composite_function"; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_uniform_attribute_utils.cc b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_uniform_attribute_utils.cc new file mode 100644 index 00000000000..e6f74a654aa --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_uniform_attribute_utils.cc @@ -0,0 +1,241 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_uniform_attribute_utils.h" + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "llvm/ADT/StringMap.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/ops/uniform_op_quant_spec.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h" +#include "tensorflow/core/util/quantization/uniform_quant_ops_attr.pb.h" + +namespace mlir::quant { + +using QuantMethod = + tensorflow::quantization::QuantizationMethod::ExperimentalMethod; + +Attribute GetWindowStridesValue( + PatternRewriter& rewriter, llvm::StringMap& identifier_to_attr) { + ArrayAttr stride = identifier_to_attr["strides"].dyn_cast(); + const int stride_h = stride[1].cast().getInt(); + const int stride_w = stride[2].cast().getInt(); + return rewriter.getI64ArrayAttr({stride_h, stride_w}); +} + +Attribute GetLhsDilationValue(PatternRewriter& rewriter, + llvm::StringMap& identifier_to_attr) { + return rewriter.getI64ArrayAttr({1, 1}); +} + +Attribute GetRhsDilationValue(PatternRewriter& rewriter, + llvm::StringMap& identifier_to_attr) { + ArrayAttr dilations = identifier_to_attr["dilations"].dyn_cast(); + const int dilation_h = dilations[1].cast().getInt(); + const int dilation_w = dilations[2].cast().getInt(); + return rewriter.getI64ArrayAttr({dilation_h, dilation_w}); +} + +Attribute GetPaddingValue(PatternRewriter& rewriter, + llvm::StringMap& identifier_to_attr) { + llvm::StringRef padding = + identifier_to_attr["padding"].dyn_cast().getValue(); + return rewriter.getStringAttr(padding); +} + +Attribute GetExplicitPaddingValue( + PatternRewriter& rewriter, llvm::StringMap& identifier_to_attr) { + ArrayAttr explicit_padding = + identifier_to_attr["explicit_paddings"].dyn_cast(); + return explicit_padding; +} + +Attribute GetDimensionNumbersValue( + PatternRewriter& rewriter, llvm::StringMap& identifier_to_attr) { + // Only NHWC is lifted in TF-quant and the corresponding dimension number is + // [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]. + + tensorflow::UniformQuantizedConvolutionDimensionNumbersAttr dimension_numbers; + if (!tensorflow::protobuf::TextFormat::ParseFromString( + R"pb( + input_batch_dimension: 0 + input_feature_dimension: 3 + input_spatial_dimensions: 1 + input_spatial_dimensions: 2 + kernel_output_feature_dimension: 3 + kernel_input_feature_dimension: 2 + kernel_spatial_dimensions: 0 + kernel_spatial_dimensions: 1 + output_batch_dimension: 0 + output_feature_dimension: 3 + output_spatial_dimensions: 1 + output_spatial_dimensions: 2 + )pb", + &dimension_numbers)) { + return rewriter.getStringAttr(""); + } + return rewriter.getStringAttr(dimension_numbers.SerializeAsString()); +} + +Attribute GetBatchGroupCountValue( + PatternRewriter& rewriter, llvm::StringMap& identifier_to_attr) { + // Only 1 case is supported. 
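+  // (The TF convolutions that reach this path, e.g. Conv2D and its depthwise
+  // variant, have no batch-group notion, so the attribute is hard-coded to 1.)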
+ return rewriter.getI64IntegerAttr(1); +} + +void FillQuantizationAttributes(PatternRewriter& rewriter, Operation* op, + NamedAttrList& attrs, + llvm::StringMap& identifier_to_attr, + QuantMethod quantization_method) { + // TODO(b/259374419): Support broader quantization schemes + absl::flat_hash_map min_max_scheme_for_8bit_narrow; + min_max_scheme_for_8bit_narrow = {{"min", -127}, {"max", 127}}; + + std::set quantization_attributes; + if (quantization_method == + tensorflow::quantization::QuantizationMethod::DYNAMIC_RANGE) { + quantization_attributes = { + "rhs_quantization_min_val", + "rhs_quantization_max_val", + }; + } else { + quantization_attributes = { + "lhs_quantization_min_val", "lhs_quantization_max_val", + "rhs_quantization_min_val", "rhs_quantization_max_val", + "output_quantization_min_val", "output_quantization_max_val", + }; + } + + for (const auto& attr : quantization_attributes) { + auto quant_val = absl::StrContains(attr, "min") + ? min_max_scheme_for_8bit_narrow["min"] + : min_max_scheme_for_8bit_narrow["max"]; + auto quant_val_attr = rewriter.getI64IntegerAttr(quant_val); + attrs.push_back(rewriter.getNamedAttr(attr, quant_val_attr)); + } +} + +LogicalResult FillAttributesForUniformQuantizedDotOp( + PatternRewriter& rewriter, Operation* op, + llvm::StringMap& identifier_to_attr, + QuantMethod quantization_method, bool enable_per_channel_quantization) { + NamedAttrList attrs; + + // Fill quantization related attributes. + FillQuantizationAttributes(rewriter, op, attrs, identifier_to_attr, + quantization_method); + + if (!(quantization_method == + tensorflow::quantization::QuantizationMethod::DYNAMIC_RANGE)) { + // Per-channel activation is not supported + attrs.push_back(rewriter.getNamedAttr("lhs_quantization_axis", + rewriter.getI64IntegerAttr(-1))); + attrs.push_back(rewriter.getNamedAttr("output_quantization_axis", + rewriter.getI64IntegerAttr(-1))); + } + + std::unique_ptr spec = GetUniformOpQuantSpec(op); + absl::flat_hash_set operands = spec->quantizable_operands; + int quant_dim = -1; + if (enable_per_channel_quantization && operands.size() == 1) { + quant_dim = spec->coeff_op_quant_dim[*(spec->quantizable_operands.begin())]; + } + attrs.push_back(rewriter.getNamedAttr("rhs_quantization_axis", + rewriter.getI64IntegerAttr(quant_dim))); + + op->setAttrs(rewriter.getDictionaryAttr(attrs)); + + return success(); +} + +LogicalResult FillAttributesForUniformQuantizedConvolutionOp( + PatternRewriter& rewriter, Operation* op, + llvm::StringMap& identifier_to_attr, + QuantMethod quantization_method, bool enable_per_channel_quantization) { + NamedAttrList attrs; + absl::flat_hash_map&)> + attribute_getter_map; + + attribute_getter_map = {{"window_strides", GetWindowStridesValue}, + {"lhs_dilation", GetLhsDilationValue}, + {"rhs_dilation", GetRhsDilationValue}, + {"padding", GetPaddingValue}, + {"explicit_padding", GetExplicitPaddingValue}, + {"dimension_numbers", GetDimensionNumbersValue}, + {"batch_group_count", GetBatchGroupCountValue}}; + + for (auto& attr : op->getAttrs()) { + llvm::StringRef attr_name = attr.getName().getValue(); + if (attribute_getter_map.find(attr_name.str()) != + attribute_getter_map.end()) { + auto attr_val = + (attribute_getter_map[attr_name.str()])(rewriter, identifier_to_attr); + attrs.push_back(rewriter.getNamedAttr(attr_name, attr_val)); + } + } + + auto feature_group_cnt_attr = llvm::StringRef("feature_group_count"); + int feature_group_cnt = 1; + ShapedType input_shape = op->getOperand(0).getType().dyn_cast(); + if (!input_shape) { + 
return op->emitError( + "Only input with known shape is supported for Uniform Quantized " + "opset."); + } + + if (op->getParentOfType().getName().contains("depthwise_")) { + feature_group_cnt = input_shape.getDimSize(3); + } + + attrs.push_back(rewriter.getNamedAttr( + feature_group_cnt_attr, rewriter.getI64IntegerAttr(feature_group_cnt))); + + // Fill quantization related attributes. + FillQuantizationAttributes(rewriter, op, attrs, identifier_to_attr, + quantization_method); + + if (quantization_method != + tensorflow::quantization::QuantizationMethod::DYNAMIC_RANGE) { + // Per-channel activation is not supported + attrs.push_back(rewriter.getNamedAttr("lhs_quantization_axis", + rewriter.getI64IntegerAttr(-1))); + attrs.push_back(rewriter.getNamedAttr("output_quantization_axis", + rewriter.getI64IntegerAttr(-1))); + } + + std::unique_ptr spec = GetUniformOpQuantSpec(op); + absl::flat_hash_set operands = spec->quantizable_operands; + int quant_dim = -1; + if (enable_per_channel_quantization && operands.size() == 1) { + quant_dim = spec->coeff_op_quant_dim[*(spec->quantizable_operands.begin())]; + } + attrs.push_back(rewriter.getNamedAttr("rhs_quantization_axis", + rewriter.getI64IntegerAttr(quant_dim))); + + op->setAttrs(rewriter.getDictionaryAttr(attrs)); + + return success(); +} + +} // namespace mlir::quant diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_uniform_attribute_utils.h b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_uniform_attribute_utils.h new file mode 100644 index 00000000000..547473f3d90 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_uniform_attribute_utils.h @@ -0,0 +1,44 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This header file defines common utils used when transforming TF ops to +// Uniform Quantized ops. 
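+//
+// Illustrative usage (a sketch only; the rewriter, op, and attribute map are
+// assumed to come from the pass that lowers lifted composite functions into
+// Uniform Quantized ops):
+//
+//   llvm::StringMap<Attribute> identifier_to_attr;  // e.g. "strides", "padding"
+//   if (failed(FillAttributesForUniformQuantizedConvolutionOp(
+//           rewriter, op, identifier_to_attr,
+//           tensorflow::quantization::QuantizationMethod::DYNAMIC_RANGE,
+//           /*enable_per_channel_quantization=*/false))) {
+//     return failure();
+//   }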
+ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_UTILS_TF_TO_UNIFORM_ATTRIBUTE_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_UTILS_TF_TO_UNIFORM_ATTRIBUTE_UTILS_H_ + +#include "llvm/ADT/StringMap.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h" + +namespace mlir::quant { + +LogicalResult FillAttributesForUniformQuantizedDotOp( + PatternRewriter& rewriter, Operation* op, + llvm::StringMap& identifier_to_attr, + tensorflow::quantization::QuantizationMethod::ExperimentalMethod + quantization_method, + bool enable_per_channel_quantization); + +LogicalResult FillAttributesForUniformQuantizedConvolutionOp( + PatternRewriter& rewriter, Operation* op, + llvm::StringMap& identifier_to_attr, + tensorflow::quantization::QuantizationMethod::ExperimentalMethod + quantization_method, + bool enable_per_channel_quantization); + +} // namespace mlir::quant + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_UTILS_TF_TO_UNIFORM_ATTRIBUTE_UTILS_H_ diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils.h b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils.h index bf8156bcf83..52dcdcbc780 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This header file defines common utils used when transforming TF ops to -// XLA/Uniform Quantized ops. +// This header file defines common utils used when transforming TF ops to XLA +// ops. 
#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_UTILS_TF_TO_XLA_ATTRIBUTE_UTILS_H_ #define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_UTILS_TF_TO_XLA_ATTRIBUTE_UTILS_H_ diff --git a/tensorflow/compiler/mlir/runlit.cfg.py b/tensorflow/compiler/mlir/runlit.cfg.py index 406ba58f9e4..e29e29cbd95 100644 --- a/tensorflow/compiler/mlir/runlit.cfg.py +++ b/tensorflow/compiler/mlir/runlit.cfg.py @@ -67,15 +67,46 @@ config.mlir_tools_dir, config.llvm_tools_dir ] tool_names = [ - 'mlir-opt', 'mlir-hlo-opt', 'mlir-translate', 'tf-opt', 'tf-reduce', - 'tf_tfl_translate', 'tf_tfjs_translate', 'flatbuffer_to_string', - 'flatbuffer_translate', 'tf-mlir-translate', 'mlir-tflite-runner', - 'tfcompile', 'json_to_flatbuffer', 'xla-cpu-opt', 'xla-gpu-opt', - 'xla-mlir-gpu-opt', 'xla-opt', 'hlo_to_llvm_ir', 'kernel-gen-opt', - 'tf_to_kernel', 'tf_to_gpu_binary', 'tfjs-opt', 'tac-opt-all-backends', - 'tac-translate', 'tfg-opt-no-passes', 'tfg-transforms-opt', 'tfg-translate', - 'tf-tfrt-opt', 'lhlo-tfrt-opt', 'tf-quant-opt', 'mhlo-tosa-opt', - 'xla-runtime-opt', 'tf-mhlo-tfl-opt', 'odml_to_stablehlo', 'xla-translate' + 'dtensor-opt', + 'flatbuffer_to_string', + 'flatbuffer_translate', + 'hlo_to_llvm_ir', + 'json_to_flatbuffer', + 'kernel-gen-opt', + 'lhlo-tfrt-opt', + 'mhlo-tosa-opt', + 'mlir-bisect', + 'mlir-hlo-opt', + 'mlir-interpreter-runner', + 'mlir-opt', + 'mlir-tflite-runner', + 'mlir-translate', + 'odml-to-stablehlo-opt', + 'odml_to_stablehlo', + 'tac-opt-all-backends', + 'tac-translate', + 'tf-mlir-translate', + 'tf-opt', + 'tf-quant-opt', + 'tf-reduce', + 'tf-tfrt-opt', + 'tf_tfjs_translate', + 'tf_tfl_translate', + 'tf_to_gpu_binary', + 'tf_to_kernel', + 'tfcompile', + 'tfg-opt-no-passes', + 'tfg-transforms-opt', + 'tfg-translate', + 'tfjs-opt', + 'xla-cpu-opt', + 'xla-gpu-opt', + 'xla-mlir-gpu-opt', + 'xla-opt', + 'xla-runtime-opt', + 'xla-translate', + 'xla-translate-gpu-opt', + 'xla-translate-opt', ] tools = [ToolSubst(s, unresolved='ignore') for s in tool_names] llvm_config.add_tool_substitutions(tools, tool_dirs) diff --git a/tensorflow/compiler/mlir/runlit.site.cfg.py b/tensorflow/compiler/mlir/runlit.site.cfg.py index eb2ba3d1213..293118cda2b 100644 --- a/tensorflow/compiler/mlir/runlit.site.cfg.py +++ b/tensorflow/compiler/mlir/runlit.site.cfg.py @@ -37,26 +37,30 @@ config.suffixes = ['.td', '.mlir', '.pbtxt'] mlir_tf_tools_dirs = [ - 'tensorflow/core/ir/importexport/', - 'tensorflow/core/ir/tests/', - 'tensorflow/core/transforms/', + 'tensorflow/compiler/aot', 'tensorflow/compiler/mlir', - 'tensorflow/compiler/xla/mlir_hlo', - 'tensorflow/compiler/xla/mlir_hlo/tosa', - 'tensorflow/compiler/xla/translate', 'tensorflow/compiler/mlir/lite', 'tensorflow/compiler/mlir/lite/experimental/tac', + 'tensorflow/compiler/mlir/lite/stablehlo', 'tensorflow/compiler/mlir/quantization/tensorflow', 'tensorflow/compiler/mlir/tensorflow', 'tensorflow/compiler/mlir/tfrt', - 'tensorflow/compiler/mlir/xla', 'tensorflow/compiler/mlir/tools/kernel_gen', - 'tensorflow/compiler/aot', - 'tensorflow/compiler/xla/service/mlir_gpu', - 'tensorflow/compiler/xla/service/gpu/tests', + 'tensorflow/compiler/mlir/xla', + 'tensorflow/compiler/xla/mlir/backends/cpu', + 'tensorflow/compiler/xla/mlir/backends/gpu', 'tensorflow/compiler/xla/mlir/runtime', - 'tensorflow/compiler/xla/mlir/tools', - 'tensorflow/compiler/mlir/lite/stablehlo', + 'tensorflow/compiler/xla/mlir/tools/mlir_bisect', + 'tensorflow/compiler/xla/mlir_hlo', + 'tensorflow/compiler/xla/mlir_hlo/tosa', + 'tensorflow/compiler/xla/service/gpu/tests', + 
'tensorflow/compiler/xla/service/mlir_gpu', + 'tensorflow/compiler/xla/translate', + 'tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla', + 'tensorflow/core/ir/importexport/', + 'tensorflow/core/ir/tests/', + 'tensorflow/core/transforms/', + 'tensorflow/dtensor/mlir/tests', ] config.mlir_tf_tools_dirs = [ os.path.join(real_test_srcdir, os.environ['TEST_WORKSPACE'], s) diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 44102734392..1b7ca19ea13 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -9,6 +9,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_gen_op_wrapper_py") # # copybara:uncomment_end(google-only) package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:public"], licenses = ["notice"], ) @@ -706,6 +707,7 @@ cc_library( "@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h", ], includes = ["include"], + visibility = ["//visibility:public"], deps = [ ":tensorflow_all_ops_inc_gen", ":tensorflow_attributes", @@ -765,6 +767,21 @@ gentbl_cc_library( ], ) +tf_cc_test( + name = "tf_saved_model_test", + srcs = ["ir/tf_saved_model_test.cc"], + deps = [ + ":tensorflow", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/platform:test", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "decompose_resource_ops", srcs = [ @@ -966,6 +983,7 @@ cc_library( "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework_internal", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/algorithm:container", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", @@ -973,6 +991,16 @@ cc_library( ], ) +cc_library( + name = "topological_sort", + srcs = ["utils/topological_sort.cc"], + hdrs = ["utils/topological_sort.h"], + deps = [ + "@com_google_absl//absl/types:span", + "@llvm-project//mlir:IR", + ], +) + cc_library( name = "initialize_variables_in_session_init", srcs = [ @@ -1000,6 +1028,7 @@ cc_library( cc_library( name = "tf_saved_model_passes", srcs = [ + "transforms/convert_session_initializer_to_function.cc", "transforms/deduplicate_bound_input_bindings.cc", "transforms/freeze_global_tensors.cc", "transforms/freeze_saved_model_assets.cc", @@ -1012,6 +1041,7 @@ cc_library( hdrs = [ "transforms/tf_saved_model_passes.h", ], + visibility = ["//visibility:public"], deps = [ ":resource_value_typed_analyzer", ":tensorflow", @@ -1217,6 +1247,7 @@ cc_library( "transforms/remove_unused_while_results.cc", "transforms/replica_id_to_device_ordinal.cc", "transforms/replicate_invariant_op_hoisting.cc", + "transforms/replicate_tensor_list_init_ops_pass.cc", "transforms/replicate_to_island.cc", "transforms/resource_device_inference.cc", "transforms/resource_op_lifting.cc", @@ -1244,6 +1275,7 @@ cc_library( "transforms/tpu_identity_pruning.cc", "transforms/tpu_merge_variables_with_execute.cc", "transforms/tpu_parallel_execute_sink_resource_write.cc", + "transforms/tpu_partitioned_op_conversion.cc", "transforms/tpu_reorder_replicate_and_partitioned_inputs.cc", "transforms/tpu_resource_partitioning.cc", "transforms/tpu_resource_read_for_write.cc", @@ -1264,10 +1296,12 @@ cc_library( ], hdrs = [ "transforms/bridge.h", + "transforms/call_graph_util.h", "transforms/cluster_ops_by_policy.h", "transforms/collection_ops_util.h", 
"transforms/einsum.h", "transforms/passes.h", + "translate/split_into_island_per_op_pass.h", ], includes = ["include"], textual_hdrs = [ @@ -1275,6 +1309,7 @@ cc_library( "transforms/tf_passes.h.inc", "transforms/tf_savedmodel_passes.h.inc", ], + visibility = ["//visibility:public"], deps = [ ":attribute_utils", ":bridge_logger", @@ -1304,7 +1339,9 @@ cc_library( ":tf_pass_inc_gen", ":tf_savedmodel_pass_inc_gen", ":tfe_legalize_tfg", + ":topological_sort", ":tpu_cluster_util", + ":tpu_embedding_ops_registry", ":tpu_rewrite_device_util", ":translate_utils", ":unroll_batch_matmul_pass", @@ -1332,6 +1369,8 @@ cc_library( "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", "//tensorflow/core/protobuf/tpu:dynamic_padding_proto_cc", "//tensorflow/core/tpu:tpu_embedding_optimization_parameters_utils", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -1497,7 +1536,6 @@ cc_library( deps = [ ":attribute_utils", "//tensorflow/compiler/tf2xla:functionalize_control_flow", - "//tensorflow/compiler/xla/stream_executor/lib", "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", "//tensorflow/core/common_runtime:device", @@ -1519,6 +1557,7 @@ cc_library( hdrs = [ "translate/export_graphdef.h", ], + visibility = ["//visibility:public"], deps = [ ":convert_type", ":error_util", @@ -1531,7 +1570,6 @@ cc_library( "//tensorflow/compiler/mlir:name_utils", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla/stream_executor/lib", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", @@ -1587,7 +1625,6 @@ cc_library( "//tensorflow/compiler/jit:shape_inference_helpers", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla/stream_executor/lib", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", @@ -1654,7 +1691,6 @@ cc_library( ":tensorflow_attributes", ":tensorflow_types", "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla/stream_executor/lib", "//tensorflow/core:framework", "//tensorflow/core:graph", "//tensorflow/core:lib", @@ -1700,7 +1736,6 @@ cc_library( "utils/translate_utils.h", ], deps = [ - "//tensorflow/compiler/xla/stream_executor/lib", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "@llvm-project//mlir:FuncDialect", @@ -1723,7 +1758,6 @@ cc_library( ":tensorflow", "//tensorflow/compiler/mlir:string_container_utils", "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla/stream_executor/lib", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -1760,7 +1794,6 @@ cc_library( ":import_model", ":mlir_roundtrip_flags", "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla/stream_executor/lib", "//tensorflow/core:core_cpu_lib", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -1782,6 +1815,7 @@ cc_library( name = "mlir_roundtrip_flags", srcs = ["translate/mlir_roundtrip_flags.cc"], hdrs = ["translate/mlir_roundtrip_flags.h"], + visibility = ["//visibility:public"], deps = [ "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:framework", @@ -1807,9 +1841,9 @@ cc_library( ":convert_tensor", ":convert_type", ":tensorflow_attributes", - 
"//tensorflow/compiler/xla/stream_executor/lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:errors", + "//tensorflow/tsl/platform:statusor", "@llvm-project//mlir:IR", ], ) @@ -1831,7 +1865,6 @@ cc_library( deps = [ ":dynamic_shape_utils", ":tensorflow_types", - "//tensorflow/compiler/xla/stream_executor/lib", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -1849,7 +1882,6 @@ tf_cc_test( deps = [ ":convert_type", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/stream_executor/lib", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -1868,10 +1900,10 @@ cc_library( ":mangling_util", ":tensorflow_attributes", ":tensorflow_types", - "//tensorflow/compiler/xla/stream_executor/lib", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/tsl/platform:float8", "@com_google_absl//absl/base", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", @@ -1889,13 +1921,13 @@ tf_cc_test( ":dynamic_shape_utils", ":tensorflow", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/stream_executor/lib", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "//tensorflow/tsl/platform:float8", "@llvm-project//mlir:IR", ], ) @@ -1938,6 +1970,7 @@ cc_library( hdrs = [ "transforms/constant_fold.h", ], + visibility = ["//visibility:public"], deps = [ ":convert_tensor", ":eval_util", @@ -1946,7 +1979,6 @@ cc_library( ":tensorflow_types", "//tensorflow/c:tf_status", "//tensorflow/compiler/xla/stream_executor", - "//tensorflow/compiler/xla/stream_executor/lib", "//tensorflow/core:framework", "//tensorflow/core:lib", "@llvm-project//llvm:Support", @@ -1975,7 +2007,6 @@ cc_library( ":mlir_roundtrip_flags", ":tensorflow", "//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration", - "//tensorflow/compiler/xla/stream_executor/lib", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:graph", @@ -1994,21 +2025,14 @@ cc_library( hdrs = ["utils/eval_util.h"], deps = [ ":convert_tensor", - ":convert_type", ":export_tf_dialect_op", - ":mangling_util", "//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api_internal", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla/stream_executor/lib", - "//tensorflow/core:core_cpu", - "//tensorflow/core:core_cpu_lib", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", @@ -2019,15 +2043,16 @@ cc_library( cc_library( name = "mlir_import_options", hdrs = ["translate/mlir_import_options.h"], + visibility = ["//visibility:public"], ) cc_library( name = "translate_lib", srcs = ["translate/tf_mlir_translate.cc"], hdrs = ["translate/tf_mlir_translate.h"], + visibility = ["//visibility:public"], deps = [ ":error_util", - ":export_graphdef", ":import_model", ":import_utils", ":mangling_util", @@ -2036,13 +2061,13 @@ cc_library( "//tensorflow/cc/saved_model:bundle_v2", "//tensorflow/cc/saved_model:loader", "//tensorflow/cc/saved_model:reader", - "//tensorflow/compiler/xla/stream_executor/lib", "//tensorflow/core:graph", "//tensorflow/core:lib", 
"//tensorflow/core:lib_proto_parsing", "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler/utils:transitive_fanin", + "//tensorflow/core/util/tensor_bundle:byteswaptensor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -2079,7 +2104,6 @@ cc_library( ":tensorflow", ":translate_cl_options", ":translate_lib", - "//tensorflow/compiler/xla/stream_executor/lib", "//tensorflow/core:protos_all_cc", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", @@ -2162,7 +2186,7 @@ COMPILE_MLIR_UTIL_DEPS = [ "//tensorflow/compiler/xla/translate/mhlo_to_hlo:layout_util", "//tensorflow/compiler/xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo", "//tensorflow/compiler/mlir/xla:tf_xla_passes", - "//tensorflow/compiler/mlir/xla:xla_passes", + "//tensorflow/compiler/xla/mlir/framework/transforms:passes", "//tensorflow/compiler/xla/translate/mhlo_to_hlo:type_to_shape", "//tensorflow/compiler/mlir/xla:adjust_layout", "//tensorflow/compiler/mlir/xla:xla_legalize_tf", @@ -2174,7 +2198,7 @@ COMPILE_MLIR_UTIL_DEPS = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/core/platform:errors", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", @@ -2182,9 +2206,9 @@ COMPILE_MLIR_UTIL_DEPS = [ "//tensorflow/core/platform:error_payloads", "//tensorflow/core/platform:logging", "//tensorflow/core/tpu:tpu_defs", - "//tensorflow/compiler/xla/stream_executor/lib", "//tensorflow/compiler/mlir/tf2xla:mlir_bridge_rollout_policy", "@llvm-project//mlir:TensorDialect", + "//tensorflow/compiler/mlir/xla:xla_legalize_targets", ] cc_library( @@ -2228,7 +2252,7 @@ cc_library( "//tensorflow/compiler/mlir:string_container_utils", "//tensorflow/compiler/tf2xla:xla_argument", "//tensorflow/compiler/tf2xla:xla_helpers", - "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:hlo_proto_cc", "//tensorflow/compiler/xla/translate/mhlo_to_hlo:type_to_shape", @@ -2244,6 +2268,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", "@llvm-project//mlir:TranslateLib", + "@stablehlo//:stablehlo_ops", ], alwayslink = 1, ) @@ -2344,7 +2369,6 @@ cc_library( "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/service:computation_placer", - "//tensorflow/compiler/xla/stream_executor/lib", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/protobuf/tpu:topology_proto_cc", @@ -2428,13 +2452,21 @@ tf_cc_test( srcs = ["utils/dump_mlir_util_test.cc"], deps = [ ":dump_mlir_util", + ":tensorflow", + ":tensorflow_passes", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/platform:test", + "@com_google_googletest//:gtest_main", "@llvm-project//llvm:Support", + "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", ], ) @@ -2532,6 +2564,8 @@ cc_library( hdrs = ["utils/attribute_utils.h"], deps = [ "//tensorflow/compiler/tf2xla:tf2xla_defs", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", 
"@llvm-project//mlir:IR", ], ) @@ -2650,6 +2684,21 @@ cc_library( ], ) +cc_library( + name = "tpu_embedding_ops_registry", + srcs = [ + "ir/tpu_embedding_ops_registry.cc", + ], + hdrs = [ + "ir/tpu_embedding_ops_registry.h", + ], + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "set_tpu_infeed_layout", srcs = ["transforms/set_tpu_infeed_layout.cc"], diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc index f8ac6e264e4..16e160d2243 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include "llvm/ADT/ArrayRef.h" @@ -71,11 +72,12 @@ class BacktrackAnalysisInfo { // Returns the argument index of the region to which the given result number // can backtracked to. Such results will be called "function passthrough". If - // the result cannot be backtracked to a region argument, returns llvm::None. + // the result cannot be backtracked to a region argument, returns + // std::nullopt. llvm::Optional GetArg(int result_index) const { if (auto arg = GetValue(result_index).dyn_cast()) if (arg.getParentBlock() == ®ion_->front()) return arg.getArgNumber(); - return llvm::None; + return std::nullopt; } private: @@ -137,10 +139,10 @@ class BacktrackAnalysis { } // Returns the backtrack analysis for the given region if it exists. - // If the region has not yet been analyzed, returns llvm::None. + // If the region has not yet been analyzed, returns std::nullopt. Optional GetAnalysisIfExists(Region& region) const { auto it = info_map_.find(®ion); - if (it == info_map_.end()) return llvm::None; + if (it == info_map_.end()) return std::nullopt; return &it->second; } @@ -208,9 +210,9 @@ Value BacktrackAnalysis::BacktrackValue(Value value) { // we cannot backtrack the value further. 
Optional callee_info = GetAnalysisIfExists(func); if (!callee_info) break; - Optional passthrough_arg = callee_info.getValue()->GetArg(res_index); + Optional passthrough_arg = callee_info.value()->GetArg(res_index); if (!passthrough_arg) break; - value = call.getArgOperands()[passthrough_arg.getValue()]; + value = call.getArgOperands()[passthrough_arg.value()]; } else if (isa(op)) { value = op->getRegion(0).front().getTerminator()->getOperand(res_index); } else { @@ -384,7 +386,7 @@ ResourceAliasAnalysisInfo::ResourceAliasAnalysisInfo( while_op.body_function())); } else if (auto while_region = dyn_cast(op)) { AnalyzeWhileLoop(while_region, backtrack_analysis.GetAnalysisForRegion( - while_region.body())); + while_region.getBody())); } else if (auto case_op = dyn_cast(op)) { llvm::SmallVector functions; case_op.get_branch_functions(functions); @@ -406,8 +408,8 @@ ResourceAliasAnalysisInfo::ResourceAliasAnalysisInfo( for (auto result : filter_resources(op->getResults())) { auto passthrough_arg = func_info.GetArg(result.getResultNumber()); if (passthrough_arg) { - PropagateInputToOutput( - call.getArgOperands()[passthrough_arg.getValue()], result); + PropagateInputToOutput(call.getArgOperands()[passthrough_arg.value()], + result); } else { AddValueUniqueIDMapping(result, kUnknownResourceId); } @@ -432,7 +434,7 @@ ResourceAliasAnalysisInfo::ResourceAliasAnalysisInfo( mem_interface.getEffectOnValue(value); if (alloc_effect) { TypeID mlir_type_id = - alloc_effect.getValue().getResource()->getResourceID(); + alloc_effect.value().getResource()->getResourceID(); // Update or lookup internal type ID. auto emplace_result = type_id_to_internal_type_id_.try_emplace( mlir_type_id, next_unique_type_id); @@ -502,7 +504,7 @@ void ResourceAliasAnalysisInfo::AnalyzeWhileLoop( int result_index = result.getResultNumber(); passthrough_args[result_index] = body_info.GetArg(result_index); if (passthrough_args[result_index]) { - int passthru_index = passthrough_args[result_index].getValue(); + int passthru_index = passthrough_args[result_index].value(); PropagateInputToOutput(while_op->getOperand(passthru_index), result); need_analysis |= !IsUnknownResource(result) && passthru_index != result_index; @@ -525,7 +527,7 @@ void ResourceAliasAnalysisInfo::AnalyzeWhileLoop( // If this result has a valid passthrough arg, propagate resource IDs // from the result of the passthrough arg int result_index = result.getResultNumber(); - int passthru_index = passthrough_args[result_index].getValue(); + int passthru_index = passthrough_args[result_index].value(); change = PropagateInputToOutput(while_op->getResult(passthru_index), result) || change; @@ -556,7 +558,7 @@ void ResourceAliasAnalysisInfo::AnalyzeFunctionalCaseOrIfOp( }); if (all_passthrough_args_known) { for (const auto& passthrough_arg : passthrough_args) { - Value operand = case_or_if_op.input()[passthrough_arg.getValue()]; + Value operand = case_or_if_op.getInput()[passthrough_arg.value()]; PropagateInputToOutput(operand, result); } } else { diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.cc b/tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.cc index 8413b16b905..c06d09bf04f 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.cc @@ -118,7 +118,7 @@ LogicalResult ResourceAnalyzer::AnalyzeRegion(Region& region) { return; } if (auto assign_variable = dyn_cast(op)) { - 
SetPotentiallyWritten(assign_variable.resource()); + SetPotentiallyWritten(assign_variable.getResource()); return; } if (auto call = dyn_cast(op)) { @@ -131,29 +131,29 @@ LogicalResult ResourceAnalyzer::AnalyzeRegion(Region& region) { if (auto if_op = dyn_cast(op)) { for (auto callee : {if_op.then_function(), if_op.else_function()}) { PropagatePotentiallyWrittenUpFromCallee(callee.getRegion(), - if_op.input()); + if_op.getInput()); } return; } if (auto if_op = dyn_cast(op)) { - PropagatePotentiallyWrittenUpFromCallee(if_op.then_branch(), + PropagatePotentiallyWrittenUpFromCallee(if_op.getThenBranch(), if_op.getODSOperands(1)); - PropagatePotentiallyWrittenUpFromCallee(if_op.else_branch(), + PropagatePotentiallyWrittenUpFromCallee(if_op.getElseBranch(), if_op.getODSOperands(1)); return; } if (auto while_op = dyn_cast(op)) { for (auto callee : {while_op.cond_function(), while_op.body_function()}) { PropagatePotentiallyWrittenUpFromCallee(callee.getRegion(), - while_op.input()); + while_op.getInput()); } return; } if (auto while_op = dyn_cast(op)) { - PropagatePotentiallyWrittenUpFromCallee(while_op.cond(), - while_op.input()); - PropagatePotentiallyWrittenUpFromCallee(while_op.body(), - while_op.input()); + PropagatePotentiallyWrittenUpFromCallee(while_op.getCond(), + while_op.getInput()); + PropagatePotentiallyWrittenUpFromCallee(while_op.getBody(), + while_op.getInput()); return; } // For all other ops, we assume it mutates all resources it uses, so diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc index 771294559b2..14ae242525a 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" #include +#include #include #include "absl/container/node_hash_map.h" @@ -128,7 +129,7 @@ bool MayHaveSideEffect(Operation* op) { if (isa_and_nonnull(op->getDialect())) return TensorFlowDialect::CanHaveSideEffects(op); - if (mlir::MemoryEffectOpInterface::hasNoEffect(op)) return false; + if (mlir::isMemoryEffectFree(op)) return false; // Conservatively assume that there can be side effects. return true; } @@ -241,7 +242,7 @@ class OpSideEffectCollector { } else if (auto while_op = dyn_cast(op)) { AddRegionSideEffectsForOp(while_op.body_function().getBody(), op); } else if (auto while_region_op = dyn_cast(op)) { - AddRegionSideEffectsForOp(while_region_op.body(), op); + AddRegionSideEffectsForOp(while_region_op.getBody(), op); } else if (auto case_op = dyn_cast(op)) { llvm::SmallVector branch_funcs; case_op.get_branch_functions(branch_funcs); @@ -286,15 +287,21 @@ class OpSideEffectCollector { // dead or get pruned, ignore it for side effect analysis. continue; - // Add side effects for op resource ID. - std::string instance_str = ""; + // Add side effects for op resource ID. If `op` does not have + // `GetResourceInstanceInterface`, then all op instances will keep an + // empty `instance_str` which enforces global order. + std::optional instance_str = ""; SideEffects side_effects(GetSideEffectsFromEffectInstance(effect, op)); if (auto resource_instance_op = dyn_cast(op)) { instance_str = resource_instance_op.GetResourceInstanceStr(); } + // No value (`std::nullopt`) instance string signals that we should + // ignore this effect, see comment for `GetResourceInstanceInterface`. 
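+      // A per-instance string lets an op that manages several independent
+      // resource instances (for example, one instance per collective key)
+      // map each instance to its own resource ID below, so only ops that
+      // touch the same instance are ordered against each other.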
+ if (!instance_str.has_value()) continue; + TypeID type_id = effect.getResource()->getResourceID(); - ResourceId resource_id = GetOpResourceId(type_id, instance_str); + ResourceId resource_id = GetOpResourceId(type_id, instance_str.value()); side_effects.SetResourceId(resource_id); UpdateSideEffectsByResourceId(side_effects, side_effects_by_resource_id); @@ -347,6 +354,17 @@ SideEffectsByResourceId CollectSideEffectsByResourceId( SideEffectsByResourceId side_effects_by_resource_id; if (!MayHaveSideEffect(op)) return side_effects_by_resource_id; + // For fetch op, set unknown effect to guarantee that it depends on every + // side-effecting op (directly or indirectly). + if (isa(op)) { + SideEffects unknown_effect; + unknown_effect.SetUnknownEffect(); + unknown_effect.SetResourceId(kUnknownResourceId); + UpdateSideEffectsByResourceId(unknown_effect, + side_effects_by_resource_id); + return side_effects_by_resource_id; + } + if (isa(op)) { // For ops that are side-effecting only if their attached regions are, @@ -716,6 +734,14 @@ bool SideEffectAnalysisInfo::IsUnknownAccessIndirectlyTrackedByResource( return is_tracked; } +const llvm::SmallVector& +SideEffectAnalysisInfo::DirectControlPredecessors( + Operation* op) const { + auto it = sorted_control_predecessors_.find(op); + if (it == sorted_control_predecessors_.end()) return empty_operation_set_; + return it->second; +} + llvm::SmallVector SideEffectAnalysisInfo::DirectControlPredecessors( Operation* op, llvm::function_ref filter) const { @@ -729,6 +755,14 @@ SideEffectAnalysisInfo::DirectControlPredecessors( return result; } +const llvm::SmallVector& +SideEffectAnalysisInfo::DirectControlSuccessors( + Operation* op) const { + auto it = sorted_control_successors_.find(op); + if (it == sorted_control_successors_.end()) return empty_operation_set_; + return it->second; +} + llvm::SmallVector SideEffectAnalysisInfo::DirectControlSuccessors( Operation* op, llvm::function_ref filter) const { diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h index 47e7b99183e..f9fac1c61c3 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h +++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h @@ -83,16 +83,16 @@ class SideEffectAnalysisInfo { // Returns a vector of ops that are direct control predecessors of `op`, // sorted in program order. If `filter` is provided, only predecessors that // pass the filter (returning true) will be included. + const llvm::SmallVector& DirectControlPredecessors( + Operation* op) const; llvm::SmallVector DirectControlPredecessors( - Operation* op, - llvm::function_ref filter = nullptr) const; + Operation* op, llvm::function_ref filter) const; - // Returns a vector of ops that are direct control successors of `op`, - // sorted in program order. If `filter` is provided, only successors that // pass the filter (returning true) will be included. + const llvm::SmallVector& DirectControlSuccessors( + Operation* op) const; llvm::SmallVector DirectControlSuccessors( - Operation* op, - llvm::function_ref filter = nullptr) const; + Operation* op, llvm::function_ref filter) const; // Returns a vector of ops that are control sinks (i.e. side-effecting ops // with no control successors). @@ -163,6 +163,9 @@ class SideEffectAnalysisInfo { op_to_resource_ids_; llvm::SmallVector> empty_resource_ids_; + // For predecessor / successor queries on ops we don't track. 
+ llvm::SmallVector empty_operation_set_; + // Internal per-resource data structure for building the dependencies. struct PerResourceAccessInfo { // Last op that writes to resource before the current op is being analyzed. diff --git a/tensorflow/compiler/mlir/tensorflow/c/BUILD b/tensorflow/compiler/mlir/tensorflow/c/BUILD index b5bdf0c0e0b..058348bbdbf 100644 --- a/tensorflow/compiler/mlir/tensorflow/c/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/c/BUILD @@ -6,6 +6,7 @@ load( ) package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [":friends"], licenses = ["notice"], ) diff --git a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc index 348f5799ec5..f1078898ee6 100644 --- a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc +++ b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include +#include #include "absl/strings/str_cat.h" #include "llvm/ADT/StringRef.h" @@ -247,9 +248,9 @@ class MlirFunctionContext : public TracingContext { RegisterDialects(*context_); // TODO(aminim) figure out the location story here module_ = ModuleOp::create(builder_.getUnknownLoc()); - func_ = - func::FuncOp::create(builder_.getUnknownLoc(), name, - builder_.getFunctionType(llvm::None, llvm::None)); + func_ = func::FuncOp::create( + builder_.getUnknownLoc(), name, + builder_.getFunctionType(std::nullopt, std::nullopt)); module_->push_back(func_); builder_ = OpBuilder::atBlockBegin(func_.addEntryBlock()); } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.h index f58c10d72c0..bc46b0c04ec 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.h @@ -58,8 +58,8 @@ template < OpT, AddV2Op, SubOp, MulOp, DivOp, RealDivOp>::value>::type * = nullptr> OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op, ArrayRef operands) { - auto lhs_type = arithmetic_op.x().getType().template cast(); - auto rhs_type = arithmetic_op.y().getType().template cast(); + auto lhs_type = arithmetic_op.getX().getType().template cast(); + auto rhs_type = arithmetic_op.getY().getType().template cast(); auto result_type = arithmetic_op.getResult().getType().template cast(); @@ -110,7 +110,7 @@ OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op, if (rhs_attr && is_valid_broadcasting(lhs_type, rhs_type, result_type)) { if (rhs_attr.isSplat() && rhs_attr.getSplatValue() == identity_attr) - return arithmetic_op.x(); + return arithmetic_op.getX(); } // Fold: Op(Identity, Operand) -> Operand for commutative operations. 
@@ -118,7 +118,7 @@ OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op, is_valid_broadcasting(rhs_type, lhs_type, result_type)) { if (lhs_attr.isSplat() && lhs_attr.getSplatValue() == identity_attr) - return arithmetic_op.y(); + return arithmetic_op.getY(); } return {}; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc index ba0e7f7bb68..37eeb8b9733 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc @@ -70,13 +70,13 @@ struct TFInlinerInterface : public DialectInlinerInterface { // Returns if its legal to inline 'src' region into the 'dest' region // attached to a TF Device operation. bool isLegalToInline(Region* dest, Region* src, bool wouldBeCloned, - BlockAndValueMapping& valueMapping) const final { + IRMapping& valueMapping) const final { return true; } // Defines the legality of inlining TF Device operations. bool isLegalToInline(Operation*, Region*, bool, - BlockAndValueMapping&) const final { + IRMapping&) const final { // For now, enable inlining all operations. return true; } @@ -236,7 +236,7 @@ ParseResult ParseReplicateOpOperands( llvm::SmallVector packed_region_arg_types; do { OpAsmParser::UnresolvedOperand operand_type; - if (parser->parseOptionalOperand(operand_type).hasValue()) { + if (parser->parseOptionalOperand(operand_type).has_value()) { packed_inputs->emplace_back(operand_type); if (parser->parseKeyword("as", " between packed input and block argument") || @@ -439,7 +439,7 @@ void BuildReplicateOp( DCHECK_GE(n, 2); state->addAttribute("n", builder->getI32IntegerAttr(n)); - if (devices.has_value()) state->addAttribute("devices", devices.getValue()); + if (devices.has_value()) state->addAttribute("devices", devices.value()); Region* region = state->addRegion(); region->push_back(new Block); @@ -479,7 +479,7 @@ LogicalResult ReplicateOp::verify() { // Check number of devices, if set, matches `n`. if (op.getDevices().has_value()) { - for (auto device_attr : op.getDevices().getValue().getValue()) { + for (auto device_attr : op.getDevices().value().getValue()) { auto device_list = device_attr.getValue().dyn_cast_or_null(); if (!device_list) return op.emitError() diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td index 39599547d2f..8f700aeb2c1 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td @@ -38,8 +38,8 @@ def TfDevice_Dialect : Dialect { XlaRun. }]; - let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; let cppNamespace = "::mlir::tf_device"; + let useFoldAPI = kEmitFoldAdaptorFolder; } //===----------------------------------------------------------------------===// @@ -115,7 +115,7 @@ def TfDevice_LaunchFuncOp : TfDevice_Op<"launch_func", []> { let arguments = (ins StrAttr:$device, FlatSymbolRefAttr:$func, - Variadic:$operands); + Variadic); let results = (outs Variadic:$results @@ -349,7 +349,7 @@ This op is used for outlining a cluster. 
let arguments = (ins FlatSymbolRefAttr:$func, - Variadic:$operands + Variadic ); let results = (outs diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc index 2f2df605602..f7c35420c22 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc @@ -69,7 +69,7 @@ struct TensorFlowExecutorInlinerInterface : public DialectInlinerInterface { // Override the inlining hook to determine if 'src' can be inlined into // 'dest'. bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, - BlockAndValueMapping &value_mapping) const final { + IRMapping &value_mapping) const final { // Allow inlining into tf.island regions if the incoming region has a single // block. return llvm::isa(dest->getParentOp()) && @@ -511,7 +511,7 @@ void SwitchNOp::print(OpAsmPrinter &p) { p.printOperands(operands.begin(), std::next(operands.begin(), 2)); p << " of " << (getNumResults() - 1); // print control dependencies if any - if (!llvm::empty(getControlInputs())) { + if (!getControlInputs().empty()) { p << " ("; p.printOperands(getControlInputs()); p << ")"; @@ -1085,7 +1085,7 @@ void ControlTriggerOp::getCanonicalizationPatterns(RewritePatternSet &results, // tf_executor.island //===----------------------------------------------------------------------===// -LogicalResult IslandOp::fold(llvm::ArrayRef operands, +LogicalResult IslandOp::fold(FoldAdaptor, llvm::SmallVectorImpl &results) { // This folds IslandOps with no inner ops, one control operand and no data // results. The single control operand is forwarded to the IslandOp control diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td index 1c16e59b6ed..b5b49f878e4 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td @@ -43,8 +43,8 @@ def TfExecutor_Dialect : Dialect { value). }]; - let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; let cppNamespace = "::mlir::tf_executor"; + let useFoldAPI = kEmitFoldAdaptorFolder; } // Control type. diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 4597ec07cc4..89645498270 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -16,9 +16,12 @@ limitations under the License. // This is the auto-generated operation definition file for TensorFlow. // // PLEASE DO NOT MANUALLY EDIT THIS FILE! +// ONLY EXCEPTION: FIELDS THAT CANNOT BE GENERATED // // If you absolutely need to modify the generated fields of an op, move the op -// definition to `tf_ops.td` and perform the modification there. +// definition to `tf_ops.td` and perform the modification there. 
Generated +// fields and the process to generate them are documented at: +// mlir/tensorflow/dialectgen/README.md // // This file contains TensorFlow ops whose definitions are programmatically // generated from the TF op registration and the api-def-files in the following @@ -73,11 +76,11 @@ Provided an input tensor, the `tf.math.acos` operation returns the inverse cosin }]; let arguments = (ins - TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8]>:$x + TF_FpOrComplexTensor:$x ); let results = (outs - TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8]>:$y + TF_FpOrComplexTensor:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -159,11 +162,11 @@ Inputs must be of same size and shape. }]; let arguments = (ins - Variadic>:$inputs + Variadic>:$inputs ); let results = (outs - TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8, TF_Variant]>:$sum + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Qint32, TF_Qint8, TF_Quint16, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8, TF_Variant]>:$sum ); TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>; @@ -324,7 +327,7 @@ replica 1's output: `[[B], [D]]` }]; let arguments = (ins - Arg, [{The local input to the sum.}]>:$input, + Arg, [{The local input to the sum.}]>:$input, Arg:$group_assignment, @@ -335,7 +338,7 @@ replica ids in the ith subgroup.}]>:$group_assignment, ); let results = (outs - Res, [{The exchanged result.}]>:$output + Res, [{The exchanged result.}]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -356,7 +359,7 @@ For example: ``` # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] -tf.angle(input) ==> [2.0132, 1.056] +tf.math.angle(input) ==> [2.0132, 1.056] ``` @compatibility(numpy) @@ -566,8 +569,8 @@ def TF_ApproximateEqualOp : TF_Op<"ApproximateEqual", [Commutative, Pure]> { let summary = "Returns the truth value of abs(x-y) < tolerance element-wise."; let arguments = (ins - TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x, - TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$y, + TF_NumberTensor:$x, + TF_NumberTensor:$y, DefaultValuedOptionalAttr:$tolerance ); @@ -599,7 +602,7 @@ Usage: }]; let arguments = (ins - TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, + TensorOf<[TF_Bfloat16, TF_Bool, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Qint32, TF_Qint8, TF_Quint16, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, Arg, [{int16, int32 or int64, must be in the range `[-rank(input), rank(input))`. Describes which dimension of the input Tensor to reduce across. 
For vectors, use dimension = 0.}]>:$dimension @@ -634,7 +637,7 @@ Usage: }]; let arguments = (ins - TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, + TensorOf<[TF_Bfloat16, TF_Bool, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Qint32, TF_Qint8, TF_Quint16, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, Arg:$dimension @@ -706,11 +709,11 @@ tf.math.asin(y) # [1.047, 0.785] = x }]; let arguments = (ins - TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8]>:$x + TF_FpOrComplexTensor:$x ); let results = (outs - TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8]>:$y + TF_FpOrComplexTensor:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -877,11 +880,11 @@ tf.math.atan(y) # [1.047, 0.785] = x }]; let arguments = (ins - TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8]>:$x + TF_FpOrComplexTensor:$x ); let results = (outs - TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8]>:$y + TF_FpOrComplexTensor:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -1070,7 +1073,7 @@ is smaller than desired.}]>:$drop_remainder, ); } -def TF_BatchFunctionOp : TF_Op<"BatchFunction", [AttrSizedOperandSegments, Pure]> { +def TF_BatchFunctionOp : TF_Op<"BatchFunction", [AttrSizedOperandSegments, DeclareOpInterfaceMethods, Pure]> { let summary = [{ Batches all the inputs tensors to the computation done by the function. }]; @@ -1213,15 +1216,15 @@ about broadcasting }]; let arguments = (ins - Arg, [{2-D or higher with shape `[..., r_x, c_x]`.}]>:$x, - Arg, [{2-D or higher with shape `[..., r_y, c_y]`.}]>:$y, + Arg, [{2-D or higher with shape `[..., r_x, c_x]`.}]>:$x, + Arg, [{2-D or higher with shape `[..., r_y, c_y]`.}]>:$y, DefaultValuedOptionalAttr:$adj_x, DefaultValuedOptionalAttr:$adj_y ); let results = (outs - Res, [{3-D or higher with shape `[..., r_o, c_o]`}]>:$output + Res, [{3-D or higher with shape `[..., r_o, c_o]`}]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -1562,14 +1565,14 @@ Broadcasting is supported, so `value` may have any number of dimensions. }]; let arguments = (ins - Arg, [{Any number of dimensions.}]>:$value, - Arg, [{1-D with size the last dimension of `value`.}]>:$bias, + Arg:$value, + Arg:$bias, DefaultValuedOptionalAttr:$data_format ); let results = (outs - Res, [{Broadcasted sum of `value` and `bias`.}]>:$output + Res:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -1597,13 +1600,13 @@ the feature dimension is the third-to-last. }]; let arguments = (ins - Arg, [{Any number of dimensions.}]>:$out_backprop, + Arg:$out_backprop, DefaultValuedOptionalAttr:$data_format ); let results = (outs - Res, [{1-D with size the feature dimension of `out_backprop`.}]>:$output + Res:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -1622,12 +1625,12 @@ Broadcasting is supported, so `value` may have any number of dimensions. 
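For illustration (editorial sketch, not part of the patch; assumes TensorFlow 2.x), the broadcasting behaviour described above is visible through the public tf.nn.bias_add API:

```python
import tensorflow as tf

value = tf.zeros([2, 4, 4, 3])        # NHWC activations; `value` may have any rank
bias = tf.constant([1.0, 2.0, 3.0])   # 1-D, sized like the last dimension of `value`
out = tf.nn.bias_add(value, bias)     # bias is broadcast; result keeps shape [2, 4, 4, 3]
```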
}]; let arguments = (ins - Arg, [{Any number of dimensions.}]>:$value, - Arg, [{1-D with size the last dimension of `value`.}]>:$bias + Arg:$value, + Arg:$bias ); let results = (outs - Res, [{Broadcasted sum of `value` and `bias`.}]>:$output + Res:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -2224,6 +2227,30 @@ Mutually accumulates multiple tensors of identical type and shape. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_CollectiveGatherV2Op : TF_Op<"CollectiveGatherV2", [TF_CollectiveReduceOrderingEffect]> { + let summary = [{ +Mutually accumulates multiple tensors of identical type and shape. + }]; + + let arguments = (ins + TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$input, + TF_Int32Tensor:$group_size, + TF_Int32Tensor:$group_key, + TF_Int32Tensor:$instance_key, + Variadic:$ordering_token, + + DefaultValuedOptionalAttr:$communication_hint, + DefaultValuedOptionalAttr:$timeout_seconds + ); + + let results = (outs + TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$data + ); + + TF_DerivedOperandSizeAttr Nordering_token = TF_DerivedOperandSizeAttr<4>; + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_CollectivePermuteOp : TF_Op<"CollectivePermute", []> { let summary = "An Op to permute tensors across replicated TPU instances."; @@ -2236,13 +2263,13 @@ source_target_pairs=`[[0,1],[1,2],[2,3],[3,0]]` gets the outputs: }]; let arguments = (ins - Arg, [{The local input to be permuted. Currently only supports float and + Arg:$input, Arg:$source_target_pairs ); let results = (outs - Res, [{The permuted input.}]>:$output + Res:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -2280,7 +2307,34 @@ Mutually reduces multiple tensors of identical type and shape. }]; } -def TF_CollectiveReduceV2Op : TF_Op<"CollectiveReduceV2", [TF_CollectiveReduceOrderingEffect]> { +def TF_CollectiveReduceScatterV2Op : TF_Op<"CollectiveReduceScatterV2", [DeclareOpInterfaceMethods, TF_CollectiveReduceOrderingEffect]> { + let summary = [{ +Mutually reduces multiple tensors of identical type and shape and scatters the result. + }]; + + let arguments = (ins + TF_FpOrI32OrI64Tensor:$input, + TF_Int32Tensor:$group_size, + TF_Int32Tensor:$group_key, + TF_Int32Tensor:$instance_key, + Variadic:$ordering_token, + + TF_AnyStrAttrOf<["Min", "Max", "Mul", "Add"]>:$merge_op, + TF_AnyStrAttrOf<["Id", "Div"]>:$final_op, + DefaultValuedOptionalAttr:$communication_hint, + DefaultValuedOptionalAttr:$timeout_seconds, + DefaultValuedOptionalAttr:$max_subdivs_per_device + ); + + let results = (outs + TF_FpOrI32OrI64Tensor:$data + ); + + TF_DerivedOperandSizeAttr Nordering_token = TF_DerivedOperandSizeAttr<4>; + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_CollectiveReduceV2Op : TF_Op<"CollectiveReduceV2", [DeclareOpInterfaceMethods, TF_CollectiveReduceOrderingEffect]> { let summary = [{ Mutually reduces multiple tensors of identical type and shape. }]; @@ -2695,6 +2749,35 @@ the `filter` input of the convolution.}]>:$output }]; } +def TF_Conv2DBackpropFilterV2Op : TF_Op<"Conv2DBackpropFilterV2", [Pure]> { + let summary = [{ +Computes the gradients of convolution with respect to the filter. 
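For illustration only, a sketch of the quantity this op computes, expressed through the public API with GradientTape (TF 2.x assumed). Judging from the arguments above, the V2 op takes the filter tensor itself rather than only its shape, but the result is still the gradient of the convolution with respect to the filter:

```python
import tensorflow as tf

x = tf.random.normal([1, 5, 5, 3])               # NHWC input
w = tf.Variable(tf.random.normal([2, 2, 3, 4]))  # HWIO filter
with tf.GradientTape() as tape:
    y = tf.nn.conv2d(x, w, strides=1, padding="VALID")
dw = tape.gradient(y, w)  # gradient w.r.t. the filter, same shape as w: [2, 2, 3, 4]
```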
+ }]; + + let arguments = (ins + Arg:$input, + Arg:$filter, + Arg:$out_backprop, + + I64ArrayAttr:$strides, + DefaultValuedOptionalAttr:$use_cudnn_on_gpu, + TF_AnyStrAttrOf<["SAME", "VALID", "EXPLICIT"]>:$padding, + DefaultValuedOptionalAttr:$explicit_paddings, + DefaultValuedOptionalAttr:$data_format, + DefaultValuedOptionalAttr:$dilations + ); + + let results = (outs + Res:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_Conv2DBackpropInputOp : TF_Op<"Conv2DBackpropInput", [Pure, TF_LayoutSensitiveInterface]> { let summary = [{ Computes the gradients of convolution with respect to the input. @@ -2734,6 +2817,35 @@ w.r.t. the input of the convolution.}]>:$output }]; } +def TF_Conv2DBackpropInputV2Op : TF_Op<"Conv2DBackpropInputV2", [Pure]> { + let summary = [{ +Computes the gradients of convolution with respect to the input. + }]; + + let arguments = (ins + Arg, [{4-D with shape `[batch, in_height, in_width, in_channels]`. +Only shape of tensor is used.}]>:$input, + Arg, [{4-D with shape +`[filter_height, filter_width, in_channels, out_channels]`.}]>:$filter, + Arg, [{4-D with shape `[batch, out_height, out_width, out_channels]`. +Gradients w.r.t. the output of the convolution.}]>:$out_backprop, + + I64ArrayAttr:$strides, + DefaultValuedOptionalAttr:$use_cudnn_on_gpu, + TF_AnyStrAttrOf<["SAME", "VALID", "EXPLICIT"]>:$padding, + DefaultValuedOptionalAttr:$explicit_paddings, + DefaultValuedOptionalAttr:$data_format, + DefaultValuedOptionalAttr:$dilations + ); + + let results = (outs + Res, [{4-D with shape `[batch, in_height, in_width, in_channels]`. Gradient +w.r.t. the input of the convolution.}]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_Conv3DOp : TF_Op<"Conv3D", [InferTensorType, Pure]> { let summary = [{ Computes a 3-D convolution given 5-D `input` and `filter` tensors. @@ -2775,6 +2887,30 @@ out_channels]`. `in_channels` must match between `input` and `filter`.}]>:$filte } +def TF_Conv3DBackpropFilterOp : TF_Op<"Conv3DBackpropFilter", [Pure]> { + let summary = [{ +Computes the gradients of 3-D convolution with respect to the filter. + }]; + + let arguments = (ins + Arg, [{Shape `[batch, depth, rows, cols, in_channels]`.}]>:$input, + Arg, [{Shape `[depth, rows, cols, in_channels, out_channels]`. +`in_channels` must match between `input` and `filter`.}]>:$filter, + Arg, [{Backprop signal of shape `[batch, out_depth, out_rows, out_cols, +out_channels]`.}]>:$out_backprop, + + ConfinedAttr]>:$strides, + TF_AnyStrAttrOf<["SAME", "VALID"]>:$padding, + DefaultValuedOptionalAttr:$dilations + ); + + let results = (outs + TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_Conv3DBackpropFilterV2Op : TF_Op<"Conv3DBackpropFilterV2", [Pure]> { let summary = [{ Computes the gradients of 3-D convolution with respect to the filter. @@ -2802,6 +2938,30 @@ out_channels]`.}]>:$out_backprop, TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_Conv3DBackpropInputOp : TF_Op<"Conv3DBackpropInput", [Pure]> { + let summary = [{ +Computes the gradients of 3-D convolution with respect to the input. + }]; + + let arguments = (ins + Arg, [{Shape `[batch, depth, rows, cols, in_channels]`.}]>:$input, + Arg, [{Shape `[depth, rows, cols, in_channels, out_channels]`. 
+`in_channels` must match between `input` and `filter`.}]>:$filter, + Arg, [{Backprop signal of shape `[batch, out_depth, out_rows, out_cols, +out_channels]`.}]>:$out_backprop, + + ConfinedAttr]>:$strides, + TF_AnyStrAttrOf<["SAME", "VALID"]>:$padding, + DefaultValuedOptionalAttr:$dilations + ); + + let results = (outs + TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_Conv3DBackpropInputV2Op : TF_Op<"Conv3DBackpropInputV2", [Pure]> { let summary = [{ Computes the gradients of 3-D convolution with respect to the input. @@ -2983,7 +3143,7 @@ tf.cumprod([a, b, c], exclusive=True, reverse=True) # => [b * c, c, 1] }]; let arguments = (ins - Arg, [{A `Tensor`. Must be one of the following types: `float32`, `float64`, + Arg:$x, Arg [b * c, c, 1] ); let results = (outs - TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$out + TF_NumberTensor:$out ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -3038,7 +3198,7 @@ tf.cumsum([a, b, c], exclusive=True, reverse=True) # => [b + c, c, 0] }]; let arguments = (ins - Arg, [{A `Tensor`. Must be one of the following types: `float32`, `float64`, + Arg:$x, Arg [b + c, c, 0] ); let results = (outs - TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$out + TF_NumberTensor:$out ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -3085,7 +3245,7 @@ opposite direction. }]; let arguments = (ins - Arg, [{A `Tensor`. Must be one of the following types: `float16`, `float32`, `float64`.}]>:$x, + Arg:$x, Arg:$axis, @@ -3094,7 +3254,7 @@ opposite direction. ); let results = (outs - TensorOf<[TF_Float16, TF_Float32, TF_Float64]>:$out + TF_FloatTensor:$out ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -4083,6 +4243,70 @@ of TPU cores in the task on which the node is placed.}]>:$device_ordinal, TF_DerivedOperandTypeAttr T3 = TF_DerivedOperandTypeAttr<2>; } +def TF_DynamicPartitionOp : TF_Op<"DynamicPartition", [Pure]> { + let summary = [{ +Partitions `data` into `num_partitions` tensors using indices from `partitions`. + }]; + + let description = [{ +For each index tuple `js` of size `partitions.ndim`, the slice `data[js, ...]` +becomes part of `outputs[partitions[js]]`. The slices with `partitions[js] = i` +are placed in `outputs[i]` in lexicographic order of `js`, and the first +dimension of `outputs[i]` is the number of entries in `partitions` equal to `i`. +In detail, + +```python + outputs[i].shape = [sum(partitions == i)] + data.shape[partitions.ndim:] + + outputs[i] = pack([data[js, ...] for js if partitions[js] == i]) +``` + +`data.shape` must start with `partitions.shape`. + +For example: + +```python + # Scalar partitions. + partitions = 1 + num_partitions = 2 + data = [10, 20] + outputs[0] = [] # Empty with shape [0, 2] + outputs[1] = [[10, 20]] + + # Vector partitions. + partitions = [0, 0, 1, 1, 0] + num_partitions = 2 + data = [10, 20, 30, 40, 50] + outputs[0] = [10, 20, 50] + outputs[1] = [30, 40] +``` + +See `dynamic_stitch` for an example on how to merge partitions back. + +
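As a runnable counterpart to the example above (editorial sketch, TF 2.x assumed), the same partitioning semantics are exposed through tf.dynamic_partition:

```python
import tensorflow as tf

data = tf.constant([10, 20, 30, 40, 50])
partitions = tf.constant([0, 0, 1, 1, 0])
outputs = tf.dynamic_partition(data, partitions, num_partitions=2)
# outputs[0] -> [10, 20, 50], outputs[1] -> [30, 40]
```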
+ + +Raises: + * `InvalidArgumentError` in following cases: + - If partitions is not in range `[0, num_partiions)` + - If `partitions.shape` does not match prefix of `data.shape` argument. + }]; + + let arguments = (ins + TF_Tensor:$data, + Arg:$partitions + ); + + let results = (outs + Variadic:$outputs + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedResultSizeAttr num_partitions = TF_DerivedResultSizeAttr<0>; +} + def TF_DynamicStitchOp : TF_Op<"DynamicStitch", [Pure, SameVariadicOperandSize]> { let summary = [{ Interleave the values from the `data` tensors into a single tensor. @@ -4995,10 +5219,14 @@ def TF_FakeParamOp : TF_Op<"FakeParam", [Pure, TF_NoConstantFold]> { def TF_FakeQuantWithMinMaxArgsOp : TF_Op<"FakeQuantWithMinMaxArgs", [Pure, TF_SameOperandsAndResultTypeResolveRef]> { let summary = [{ -Fake-quantize the 'inputs' tensor, type float to 'outputs' tensor of same type. +Fake-quantize the 'inputs' tensor, type float to 'outputs' tensor of same shape and type. }]; let description = [{ +Quantization is called fake since the output is still in floating point. + The API converts inputs into values within the range [min and max] and returns + as output. + Attributes * `[min; max]` define the clamping range for the `inputs` data. @@ -5018,7 +5246,25 @@ the behavior can be unexpected: * If `min <= 0 <= max`: `scale = (max - min) / (2^num_bits - 1) `, `min_adj = scale * round(min / scale)` and `max_adj = max + min_adj - min`. -Quantization is called fake since the output is still in floating point. + +Examples + +```python + +inp = tf.constant ([10.03, -10.23, 3]) +out = tf.quantization.fake_quant_with_min_max_args(inp, min=-5, max=5, + num_bits=16) +print(out) + +# Output: +# tf.Tensor([ 4.9999237 -5.0000763 3.0000763], shape=(3,), dtype=float32) +``` + +Raises: + * InvalidArgumentError: + - If num_bits are outside of range [2, 16]. + - If min >= max. + * ValueError: If `inputs` are of any other type than float32. }]; let arguments = (ins @@ -7860,6 +8106,8 @@ def TF_LogicalAndOp : TF_Op<"LogicalAnd", [Commutative, Pure, ResultsBroadcastab let results = (outs TF_BoolTensor:$z ); + + let hasFolder = 1; } def TF_LogicalNotOp : TF_Op<"LogicalNot", [Pure, TF_Involution, TF_SameOperandsAndResultTypeResolveRef]> { @@ -8216,15 +8464,15 @@ cublas. }]; let arguments = (ins - TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$a, - TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$b, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$a, + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$b, DefaultValuedOptionalAttr:$transpose_a, DefaultValuedOptionalAttr:$transpose_b ); let results = (outs - TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$product + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$product ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -9390,7 +9638,7 @@ retained with length 1. }]; let arguments = (ins - Arg, [{The tensor to reduce.}]>:$input, + Arg:$input, Arg:$reduction_indices, @@ -9398,7 +9646,7 @@ retained with length 1. 
); let results = (outs - Res, [{The reduced tensor.}]>:$output + Res:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -9953,6 +10201,47 @@ the insert operations. It does not support the initialization operation. ); } +def TF_NcclAllReduceOp : TF_Op<"NcclAllReduce", [DeclareOpInterfaceMethods, TF_NcclAllReduceOrderingEffect, TF_SameOperandsAndResultTypeResolveRef]> { + let summary = [{ +Outputs a tensor containing the reduction across all input tensors. + }]; + + let description = [{ +Outputs a tensor containing the reduction across all input tensors passed to ops +within the same `shared_name. + +The graph should be constructed so if one op runs with shared_name value `c`, +then `num_devices` ops will run with shared_name value `c`. Failure to do so +will cause the graph execution to fail to complete. + +input: the input to the reduction +data: the value of the reduction across all `num_devices` devices. +reduction: the reduction operation to perform. +num_devices: The number of devices participating in this reduction. +shared_name: Identifier that shared between ops of the same reduction. + }]; + + let arguments = (ins + TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$input, + + TF_AnyStrAttrOf<["min", "max", "prod", "sum"]>:$reduction, + I64Attr:$num_devices, + StrAttr:$shared_name + ); + + let results = (outs + TensorOf<[TF_Float16, TF_Float32, TF_Float64, TF_Int32, TF_Int64]>:$data + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + + let extraClassDeclaration = [{ + static bool isCompatibleReturnTypes(TypeRange inferred, TypeRange actual) { + return ArraysAreCastCompatible(inferred, actual); + } + }]; +} + def TF_NdtriOp : TF_Op<"Ndtri", [Pure]> { let summary = ""; @@ -10479,6 +10768,18 @@ def TF_OptionalNoneOp : TF_Op<"OptionalNone", [Pure]> { ); } +def TF_OutfeedEnqueueOp : TF_Op<"OutfeedEnqueue", []> { + let summary = "Enqueue a Tensor on the computation outfeed."; + + let arguments = (ins + Arg:$input + ); + + let results = (outs); + + TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<0>; +} + def TF_OutfeedEnqueueTupleOp : TF_Op<"OutfeedEnqueueTuple", []> { let summary = "Enqueue multiple Tensor values on the computation outfeed."; @@ -10808,7 +11109,7 @@ truncated normal values using the parameters for each row.}]>:$output TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<1>; } -def TF_PartitionedCallOp : TF_Op<"PartitionedCall", [CallOpInterface, Pure, SymbolUserOpInterface]> { +def TF_PartitionedCallOp : TF_Op<"PartitionedCall", [CallOpInterface, DeclareOpInterfaceMethods, Pure]> { let summary = [{ returns `f(inputs)`, where `f`'s body is placed and partitioned. }]; @@ -10837,18 +11138,15 @@ underlying graph, and executes each of the partitioned subgraphs as a function. let extraClassDeclaration = [{ // Gets the argument operands to the called function. - operand_range getArgOperands() { return args(); } + operand_range getArgOperands() { return getArgs(); } // Returns the callee of this operation. - CallInterfaceCallable getCallableForCallee() { return fAttr(); } + CallInterfaceCallable getCallableForCallee() { return getFAttr(); } // returns the callee of this operation. func::FuncOp func() { - return SymbolTable::lookupNearestSymbolFrom(*this, f()); + return SymbolTable::lookupNearestSymbolFrom(*this, getF()); } - - // SymbolUserOpInterface verifier. - LogicalResult verifySymbolUses(SymbolTableCollection &symbolTable); }]; } @@ -11045,7 +11343,7 @@ retained with length 1. 
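For illustration (editorial sketch, TF 2.x assumed), the keep_dims behaviour described above, shown with the public tf.reduce_prod API:

```python
import tensorflow as tf

x = tf.constant([[1, 2, 3], [4, 5, 6]])
tf.reduce_prod(x, axis=1)                 # shape [2]    -> [6, 120]
tf.reduce_prod(x, axis=1, keepdims=True)  # shape [2, 1] -> [[6], [120]]
```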
}]; let arguments = (ins - Arg, [{The tensor to reduce.}]>:$input, + Arg:$input, Arg:$reduction_indices, @@ -11053,7 +11351,7 @@ retained with length 1. ); let results = (outs - Res, [{The reduced tensor.}]>:$output + Res:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -12547,12 +12845,12 @@ variable <- variable - learning_rate / (1 - beta1^t) * m_t / (v_t + epsilon) Arg:$var, Arg:$m, Arg:$v, - Arg, [{Must be a scalar.}]>:$beta1_power, - Arg, [{Scaling factor. Must be a scalar.}]>:$lr, - Arg, [{Momentum factor. Must be a scalar.}]>:$beta1, - Arg, [{Momentum factor. Must be a scalar.}]>:$beta2, - Arg, [{Ridge term. Must be a scalar.}]>:$epsilon, - Arg, [{The gradient.}]>:$grad, + Arg:$beta1_power, + Arg:$lr, + Arg:$beta1, + Arg:$beta2, + Arg:$epsilon, + Arg:$grad, DefaultValuedOptionalAttr:$use_locking ); @@ -12576,10 +12874,10 @@ var -= update; Arg:$var, Arg:$accum, Arg:$accum_update, - Arg, [{Scaling factor. Must be a scalar.}]>:$lr, - Arg, [{Decay factor. Must be a scalar.}]>:$rho, - Arg, [{Constant factor. Must be a scalar.}]>:$epsilon, - Arg, [{The gradient.}]>:$grad, + Arg:$lr, + Arg:$rho, + Arg:$epsilon, + Arg:$grad, DefaultValuedOptionalAttr:$use_locking ); @@ -12600,8 +12898,8 @@ var -= lr * grad * (1 / sqrt(accum)) let arguments = (ins Arg:$var, Arg:$accum, - Arg, [{Scaling factor. Must be a scalar.}]>:$lr, - Arg, [{The gradient.}]>:$grad, + Arg:$lr, + Arg:$grad, DefaultValuedOptionalAttr:$use_locking, DefaultValuedOptionalAttr:$update_slots @@ -12619,10 +12917,10 @@ def TF_ResourceApplyAdagradDAOp : TF_Op<"ResourceApplyAdagradDA", []> { Arg:$var, Arg:$gradient_accumulator, Arg:$gradient_squared_accumulator, - Arg, [{The gradient.}]>:$grad, - Arg, [{Scaling factor. Must be a scalar.}]>:$lr, - Arg, [{L1 regularization. Must be a scalar.}]>:$l1, - Arg, [{L2 regularization. Must be a scalar.}]>:$l2, + Arg:$grad, + Arg:$lr, + Arg:$l1, + Arg:$l2, Arg:$global_step, DefaultValuedOptionalAttr:$use_locking @@ -12644,9 +12942,9 @@ var -= lr * grad * (1 / (sqrt(accum) + epsilon)) let arguments = (ins Arg:$var, Arg:$accum, - Arg, [{Scaling factor. Must be a scalar.}]>:$lr, - Arg, [{Constant factor. Must be a scalar.}]>:$epsilon, - Arg, [{The gradient.}]>:$grad, + Arg:$lr, + Arg:$epsilon, + Arg:$grad, DefaultValuedOptionalAttr:$use_locking, DefaultValuedOptionalAttr:$update_slots @@ -12671,13 +12969,13 @@ $$\text{var} := \begin{cases} \text{var} - (m_t \beta_1 + g \cdot (1 - \beta_1)) Arg:$var, Arg:$m, Arg:$v, - Arg, [{Must be a scalar.}]>:$beta1_power, - Arg, [{Must be a scalar.}]>:$beta2_power, - Arg, [{Scaling factor. Must be a scalar.}]>:$lr, - Arg, [{Momentum factor. Must be a scalar.}]>:$beta1, - Arg, [{Momentum factor. Must be a scalar.}]>:$beta2, - Arg, [{Ridge term. Must be a scalar.}]>:$epsilon, - Arg, [{The gradient.}]>:$grad, + Arg:$beta1_power, + Arg:$beta2_power, + Arg:$lr, + Arg:$beta1, + Arg:$beta2, + Arg:$epsilon, + Arg:$grad, DefaultValuedOptionalAttr:$use_locking, DefaultValuedOptionalAttr:$use_nesterov @@ -12700,11 +12998,11 @@ variable <- variable - lr_t * update let arguments = (ins Arg:$var, Arg:$m, - Arg, [{Scaling factor. Must be a scalar.}]>:$lr, - Arg, [{Must be a scalar.}]>:$alpha, - Arg, [{Must be a scalar.}]>:$sign_decay, - Arg, [{Must be a scalar.}]>:$beta, - Arg, [{The gradient.}]>:$grad, + Arg:$lr, + Arg:$alpha, + Arg:$sign_decay, + Arg:$beta, + Arg:$grad, DefaultValuedOptionalAttr:$use_locking ); @@ -12743,11 +13041,11 @@ var <- var - mom Arg:$mg, Arg:$ms, Arg:$mom, - Arg, [{Scaling factor. 
Must be a scalar.}]>:$lr, - Arg, [{Decay rate. Must be a scalar.}]>:$rho, - Arg, [{Momentum Scale. Must be a scalar.}]>:$momentum, - Arg, [{Ridge term. Must be a scalar.}]>:$epsilon, - Arg, [{The gradient.}]>:$grad, + Arg:$lr, + Arg:$rho, + Arg:$momentum, + Arg:$epsilon, + Arg:$grad, DefaultValuedOptionalAttr:$use_locking ); @@ -12772,11 +13070,11 @@ accum = accum_new Arg:$var, Arg:$accum, Arg:$linear, - Arg, [{The gradient.}]>:$grad, - Arg, [{Scaling factor. Must be a scalar.}]>:$lr, - Arg, [{L1 regularization. Must be a scalar.}]>:$l1, - Arg, [{L2 regularization. Must be a scalar.}]>:$l2, - Arg, [{Scaling factor. Must be a scalar.}]>:$lr_power, + Arg:$grad, + Arg:$lr, + Arg:$l1, + Arg:$l2, + Arg:$lr_power, DefaultValuedOptionalAttr:$use_locking, DefaultValuedOptionalAttr:$multiply_linear_by_lr @@ -12804,12 +13102,12 @@ accum = accum_new Arg:$var, Arg:$accum, Arg:$linear, - Arg, [{The gradient.}]>:$grad, - Arg, [{Scaling factor. Must be a scalar.}]>:$lr, - Arg, [{L1 regularization. Must be a scalar.}]>:$l1, - Arg, [{L2 shrinkage regularization. Must be a scalar.}]>:$l2, - TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$l2_shrinkage, - Arg, [{Scaling factor. Must be a scalar.}]>:$lr_power, + Arg:$grad, + Arg:$lr, + Arg:$l1, + Arg:$l2, + TF_NumberTensor:$l2_shrinkage, + Arg:$lr_power, DefaultValuedOptionalAttr:$use_locking, DefaultValuedOptionalAttr:$multiply_linear_by_lr @@ -12825,8 +13123,8 @@ def TF_ResourceApplyGradientDescentOp : TF_Op<"ResourceApplyGradientDescent", [] let arguments = (ins Arg:$var, - Arg, [{Scaling factor. Must be a scalar.}]>:$alpha, - Arg, [{The change.}]>:$delta, + Arg:$alpha, + Arg:$delta, DefaultValuedOptionalAttr:$use_locking ); @@ -12849,9 +13147,9 @@ var += accum let arguments = (ins Arg:$var, Arg:$accum, - Arg, [{Scaling factor. Must be a scalar.}]>:$lr, - Arg, [{The gradient.}]>:$grad, - Arg, [{Momentum. Must be a scalar.}]>:$momentum, + Arg:$lr, + Arg:$grad, + Arg:$momentum, DefaultValuedOptionalAttr:$use_locking, DefaultValuedOptionalAttr:$use_nesterov @@ -12875,9 +13173,9 @@ var -= lr * accum let arguments = (ins Arg:$var, Arg:$accum, - Arg, [{Scaling factor. Must be a scalar.}]>:$lr, - Arg, [{The gradient.}]>:$grad, - Arg, [{Momentum. Must be a scalar.}]>:$momentum, + Arg:$lr, + Arg:$grad, + Arg:$momentum, DefaultValuedOptionalAttr:$use_locking, DefaultValuedOptionalAttr:$use_nesterov @@ -12900,11 +13198,11 @@ variable <- variable - lr_t * update let arguments = (ins Arg:$var, Arg:$m, - Arg, [{Scaling factor. Must be a scalar.}]>:$lr, - Arg, [{Must be a scalar.}]>:$logbase, - Arg, [{Must be a scalar.}]>:$sign_decay, - Arg, [{Must be a scalar.}]>:$beta, - Arg, [{The gradient.}]>:$grad, + Arg:$lr, + Arg:$logbase, + Arg:$sign_decay, + Arg:$beta, + Arg:$grad, DefaultValuedOptionalAttr:$use_locking ); @@ -12928,10 +13226,10 @@ var = sign(prox_v)/(1+lr*l2) * max{|prox_v|-lr*l1,0} let arguments = (ins Arg:$var, Arg:$accum, - Arg, [{Scaling factor. Must be a scalar.}]>:$lr, - Arg, [{L1 regularization. Must be a scalar.}]>:$l1, - Arg, [{L2 regularization. Must be a scalar.}]>:$l2, - Arg, [{The gradient.}]>:$grad, + Arg:$lr, + Arg:$l1, + Arg:$l2, + Arg:$grad, DefaultValuedOptionalAttr:$use_locking ); @@ -12951,10 +13249,10 @@ var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0} let arguments = (ins Arg:$var, - Arg, [{Scaling factor. Must be a scalar.}]>:$alpha, - Arg, [{L1 regularization. 
Must be a scalar.}]>:$l1, - Arg, [{L2 regularization. Must be a scalar.}]>:$l2, - Arg, [{The change.}]>:$delta, + Arg:$alpha, + Arg:$l1, + Arg:$l2, + Arg:$delta, DefaultValuedOptionalAttr:$use_locking ); @@ -12984,11 +13282,11 @@ var <- var - mom Arg:$var, Arg:$ms, Arg:$mom, - Arg, [{Scaling factor. Must be a scalar.}]>:$lr, - Arg, [{Decay rate. Must be a scalar.}]>:$rho, - TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$momentum, - Arg, [{Ridge term. Must be a scalar.}]>:$epsilon, - Arg, [{The gradient.}]>:$grad, + Arg:$lr, + Arg:$rho, + TF_NumberTensor:$momentum, + Arg:$epsilon, + Arg:$grad, DefaultValuedOptionalAttr:$use_locking ); @@ -13063,7 +13361,7 @@ Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []` let arguments = (ins Arg:$resource, Arg:$indices, - Arg, [{A tensor of updated values to add to `ref`.}]>:$updates + Arg:$updates ); let results = (outs); @@ -13102,7 +13400,7 @@ Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []` let arguments = (ins Arg:$resource, Arg:$indices, - Arg, [{A tensor of updated values to add to `ref`.}]>:$updates + Arg:$updates ); let results = (outs); @@ -13141,7 +13439,7 @@ Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []` let arguments = (ins Arg:$resource, Arg:$indices, - Arg, [{A tensor of updated values to add to `ref`.}]>:$updates + Arg:$updates ); let results = (outs); @@ -13180,7 +13478,7 @@ Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []` let arguments = (ins Arg:$resource, Arg:$indices, - Arg, [{A tensor of updated values to add to `ref`.}]>:$updates + Arg:$updates ); let results = (outs); @@ -13219,7 +13517,7 @@ Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []` let arguments = (ins Arg:$resource, Arg:$indices, - Arg, [{A tensor of updated values to add to `ref`.}]>:$updates + Arg:$updates ); let results = (outs); @@ -13431,7 +13729,7 @@ Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []` let arguments = (ins Arg:$resource, Arg:$indices, - Arg, [{A tensor of updated values to add to `ref`.}]>:$updates + Arg:$updates ); let results = (outs); @@ -13484,8 +13782,8 @@ var -= lr * grad * (1 / sqrt(accum)) let arguments = (ins Arg:$var, Arg:$accum, - Arg, [{Learning rate. Must be a scalar.}]>:$lr, - Arg, [{The gradient.}]>:$grad, + Arg:$lr, + Arg:$grad, Arg:$indices, DefaultValuedOptionalAttr:$use_locking, @@ -13512,9 +13810,9 @@ var -= lr * grad * (1 / sqrt(accum)) let arguments = (ins Arg:$var, Arg:$accum, - Arg, [{Learning rate. Must be a scalar.}]>:$lr, - Arg, [{Constant factor. Must be a scalar.}]>:$epsilon, - Arg, [{The gradient.}]>:$grad, + Arg:$lr, + Arg:$epsilon, + Arg:$grad, Arg:$indices, DefaultValuedOptionalAttr:$use_locking, @@ -13545,12 +13843,12 @@ accum = accum_new Arg:$var, Arg:$accum, Arg:$linear, - Arg, [{The gradient.}]>:$grad, + Arg:$grad, Arg:$indices, - Arg, [{Scaling factor. Must be a scalar.}]>:$lr, - Arg, [{L1 regularization. Must be a scalar.}]>:$l1, - Arg, [{L2 regularization. Must be a scalar.}]>:$l2, - Arg, [{Scaling factor. 
Must be a scalar.}]>:$lr_power, + Arg:$lr, + Arg:$l1, + Arg:$l2, + Arg:$lr_power, DefaultValuedOptionalAttr:$use_locking, DefaultValuedOptionalAttr:$multiply_linear_by_lr @@ -14853,7 +15151,7 @@ array([[2.5, 2.5, 2.5, 2.5], }]; let arguments = (ins - TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$data, + TF_NumberTensor:$data, Arg:$segment_ids ); let results = (outs - Res, [{Has same shape as data, except for dimension 0 which + Res:$output ); @@ -14953,7 +15251,7 @@ array([[4, 6, 6, 4], }]; let arguments = (ins - TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$data, + TF_NumberTensor:$data, Arg:$segment_ids ); let results = (outs - Res, [{Has same shape as data, except for dimension 0 which + Res:$output ); @@ -14970,6 +15268,64 @@ has size `k`, the number of segments.}]>:$output TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; } +def TF_SegmentProdV2Op : TF_Op<"SegmentProdV2", [Pure]> { + let summary = "Computes the product along segments of a tensor."; + + let description = [{ +Read +[the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) +for an explanation of segments. + +Computes a tensor such that +\\(output_i = \prod_j data_j\\) where the product is over `j` such +that `segment_ids[j] == i`. + +If the product is empty for a given segment ID `i`, `output[i] = 1`. + +Note: That this op is currently only supported with jit_compile=True. + +The only difference with SegmentProd is the additional input `num_segments`. +This helps in evaluating the output shape in compile time. +`num_segments` should be consistent with segment_ids. +e.g. Max(segment_ids) - 1 should be equal to `num_segments` for a 1-d segment_ids +With inconsistent num_segments, the op still runs. only difference is, +the output takes the size of num_segments irrespective of size of segment_ids and data. +for num_segments less than expected output size, the last elements are ignored +for num_segments more than the expected output size, last elements are assigned 1. + +For example: + +>>> @tf.function(jit_compile=True) +... def test(c): +... 
return tf.raw_ops.SegmentProdV2(data=c, segment_ids=tf.constant([0, 0, 1]), num_segments=2) +>>> c = tf.constant([[1,2,3,4], [4, 3, 2, 1], [5,6,7,8]]) +>>> test(c).numpy() +array([[4, 6, 6, 4], + [5, 6, 7, 8]], dtype=int32) + }]; + + let arguments = (ins + TF_NumberTensor:$data, + Arg:$segment_ids, + TF_I32OrI64Tensor:$num_segments + ); + + let results = (outs + Res:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr Tnumsegments = TF_DerivedOperandTypeAttr<2>; +} + def TF_SegmentSumOp : TF_Op<"SegmentSum", [Pure]> { let summary = "Computes the sum along segments of a tensor."; @@ -15003,7 +15359,7 @@ array([[5, 5, 5, 5], }]; let arguments = (ins - TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$data, + TF_NumberTensor:$data, Arg:$segment_ids ); let results = (outs - Res, [{Has same shape as data, except for dimension 0 which + Res:$output ); @@ -15020,6 +15376,46 @@ has size `k`, the number of segments.}]>:$output TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; } +def TF_SegmentSumV2Op : TF_Op<"SegmentSumV2", [Pure]> { + let summary = "Computes the sum along segments of a tensor."; + + let description = [{ +Read +[the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation) +for an explanation of segments. + +Computes a tensor such that +\\(output_i = \sum_j data_j\\) where sum is over `j` such +that `segment_ids[j] == i`. + +If the sum is empty for a given segment ID `i`, `output[i] = 0`. + +Note that this op is currently only supported with jit_compile=True. + + }]; + + let arguments = (ins + TF_NumberTensor:$data, + Arg:$segment_ids, + TF_I32OrI64Tensor:$num_segments + ); + + let results = (outs + Res:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr Tnumsegments = TF_DerivedOperandTypeAttr<2>; +} + def TF_SelfAdjointEigV2Op : TF_Op<"SelfAdjointEigV2", [Pure]> { let summary = [{ Computes the eigen decomposition of one or more square self-adjoint matrices. @@ -16082,10 +16478,10 @@ In the following shapes, `nnz` is the count after taking `thresh` into account. let arguments = (ins Arg:$a_indices, - Arg, [{1-D. The `values` of the first `SparseTensor`, size `[nnz]` Vector.}]>:$a_values, + Arg:$a_values, Arg:$a_shape, Arg:$b_indices, - Arg, [{1-D. The `values` of the second `SparseTensor`, size `[nnz]` Vector.}]>:$b_values, + Arg:$b_values, Arg:$b_shape, Arg:$thresh @@ -16093,7 +16489,7 @@ pair takes space.}]>:$thresh let results = (outs TF_Int64Tensor:$sum_indices, - TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$sum_values, + TF_NumberTensor:$sum_values, TF_Int64Tensor:$sum_shape ); @@ -16223,7 +16619,7 @@ which are interpreted according to the indexing rules in Python. let arguments = (ins Arg:$input_indices, - Arg, [{1-D. `N` non-empty values corresponding to `input_indices`.}]>:$input_values, + Arg:$input_values, Arg:$input_shape, Arg:$reduction_axes, @@ -16231,7 +16627,7 @@ SparseTensor, possibly not in canonical ordering.}]>:$input_indices, ); let results = (outs - Res, [{`R-K`-D. 
The reduced Tensor.}]>:$output + Res:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; @@ -17771,7 +18167,7 @@ retained with length 1. }]; let arguments = (ins - Arg, [{The tensor to reduce.}]>:$input, + Arg:$input, Arg:$reduction_indices, @@ -17779,7 +18175,7 @@ retained with length 1. ); let results = (outs - Res, [{The reduced tensor.}]>:$output + Res:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -18216,11 +18612,11 @@ Given an input tensor, this function computes tangent of every }]; let arguments = (ins - TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8]>:$x + TF_FpOrComplexTensor:$x ); let results = (outs - TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8]>:$y + TF_FpOrComplexTensor:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -18937,7 +19333,7 @@ Apply a sparse update to a tensor taking the element-wise maximum. Returns a new tensor copied from `tensor` whose values are element-wise maximum between tensor and updates according to the indices. ->>> tensor = [0, 0, 0, 0, 0, 0, 0, 0] +>>> tensor = [0, 0, 0, 0, 0, 0, 0, 0] >>> indices = [[1], [4], [5]] >>> updates = [1, -1, 1] >>> tf.tensor_scatter_nd_max(tensor, indices, updates).numpy() @@ -19236,8 +19632,15 @@ def TF_TimestampOp : TF_Op<"Timestamp", []> { let description = [{ Returns the timestamp as a `float64` for seconds since the Unix epoch. -Note: the timestamp is computed when the op is executed, not when it is added -to the graph. +Common usages include: +* Logging +* Providing a random number seed +* Debugging graph execution +* Generating timing information, mainly through comparison of timestamps + +Note: In graph mode, the timestamp is computed when the op is executed, +not when it is added to the graph. In eager mode, the timestamp is computed +when the op is eagerly executed. }]; let arguments = (ins); @@ -19424,7 +19827,7 @@ left-hand side.}]>:$rhs, def TF_TruncateDivOp : TF_Op<"TruncateDiv", [Pure, ResultsBroadcastableShape]>, WithBroadcastableBinOpBuilder { - let summary = "Returns x / y element-wise for integer types."; + let summary = "Returns x / y element-wise, rounded towards zero."; let description = [{ Truncation designates that negative numbers will round fractional quantities @@ -19542,6 +19945,8 @@ Same shape condition as scales.}]>:$zero_points, TF_DerivedOperandTypeAttr Tin = TF_DerivedOperandTypeAttr<0>; TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>; + + let hasVerifier = 1; } def TF_UniformQuantizeOp : TF_Op<"UniformQuantize", [Pure]> { @@ -19570,6 +19975,65 @@ Same shape condition as scales.}]>:$zero_points, TF_DerivedOperandTypeAttr Tin = TF_DerivedOperandTypeAttr<0>; TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>; + + let hasVerifier = 1; +} + +def TF_UniformQuantizedAddOp : TF_Op<"UniformQuantizedAdd", [Pure]> { + let summary = [{ +Perform quantized add of quantized Tensor `lhs` and quantized Tensor `rhs` to make quantized `output`. + }]; + + let description = [{ +Given quantized `lhs` and quantized `rhs`, performs quantized add on `lhs` and `rhs` to make quantized `output`. + +`UniformQuantizedAdd` follows Numpy broadcasting rules. +The two input array shapes are compared element-wise. +Starting with the trailing dimensions, the two dimensions either have to be equal or one of them needs to be 1. 
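A minimal plain-Python sketch (editorial) of the trailing-dimension broadcast rule stated above:

```python
def shapes_broadcastable(lhs_shape, rhs_shape):
    # Compare dimensions from the right; each pair must be equal or contain a 1.
    for a, b in zip(reversed(lhs_shape), reversed(rhs_shape)):
        if a != b and a != 1 and b != 1:
            return False
    return True

assert shapes_broadcastable([2, 3, 4], [3, 4])
assert shapes_broadcastable([2, 1, 4], [2, 5, 4])
assert not shapes_broadcastable([2, 3], [4, 3, 2])
```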
+ +`lhs` and `rhs` must be quantized Tensor, where data value is quantized using the formula: +``` +quantized_data = clip(original_data / scale + zero_point, quantization_min_val, quantization_max_val) +``` +`output` is also quantized, using the same formula. + +If `lhs` and `output` is both per-axis quantized, the quantization axis must match. +Also, if `rhs` and `output` is both per-axis quantized, the quantization axis must match. +*Match* means the axis must match when adding, regarding the broadcasting. +i.e. For both operands `lhs` and `rhs`, +if `operand.quantization_axis` >= 0 and `output.quantization_axis` >= 0, +`operand.dims` - `operand.quantization_axis` must be equal to `output.dims` - `output.quantization_axis`. + }]; + + let arguments = (ins + Arg:$lhs, + Arg:$rhs, + Arg:$lhs_scales, + Arg:$lhs_zero_points, + Arg:$rhs_scales, + Arg:$rhs_zero_points, + Arg:$output_scales, + Arg:$output_zero_points, + + DefaultValuedOptionalAttr:$lhs_quantization_axis, + I64Attr:$lhs_quantization_min_val, + I64Attr:$lhs_quantization_max_val, + DefaultValuedOptionalAttr:$rhs_quantization_axis, + I64Attr:$rhs_quantization_min_val, + I64Attr:$rhs_quantization_max_val, + DefaultValuedOptionalAttr:$output_quantization_axis, + I64Attr:$output_quantization_min_val, + I64Attr:$output_quantization_max_val + ); + + let results = (outs + Res:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } def TF_UniformQuantizedClipByValueOp : TF_Op<"UniformQuantizedClipByValue", [Pure]> { @@ -19604,6 +20068,138 @@ Same shape condition as scales.}]>:$zero_points, TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_UniformQuantizedConvolutionOp : TF_Op<"UniformQuantizedConvolution", [Pure]> { + let summary = [{ +Perform quantized convolution of quantized Tensor `lhs` and quantized Tensor `rhs`. to make quantized `output`. + }]; + + let description = [{ +Given quantized `lhs` and quantized `rhs`, performs quantized dot on `lhs` and `rhs` to make quantized `output`. + +`lhs` and `rhs` must be Tensors of same rank, and meet following shape conditions. +- `lhs_feature` % `feature_group_count` == 0 +- `lhs_feature` % `rhs_input_feature` == 0 +- `lhs_feature` / `feature_group_count` == `rhs_input_feature` +- `rhs_output_feature` % `feature_group_count` == 0 +- `lhs_batch` % `batch_group_count` == 0 +- `rhs_output_feature` % `batch_group_count` == 0 + +`lhs` and `rhs` must be quantized Tensor, where data value is quantized using the formula: +``` +quantized_data = clip(original_data / scale + zero_point, quantization_min_val, quantization_max_val) +``` +`output` is also quantized, using the same formula. +If `rhs` is per-tensor quantized, `output` must be also per-tensor quantized. 
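A numpy sketch (editorial) of the quantization formula quoted in these descriptions; the rounding step is an assumption here, since the .td text only spells out the scale, zero point, and clip:

```python
import numpy as np

def uniform_quantize(x, scale, zero_point, qmin, qmax):
    q = np.round(x / scale + zero_point)           # rounding is assumed
    return np.clip(q, qmin, qmax).astype(np.int32)

def uniform_dequantize(q, scale, zero_point):
    return (q.astype(np.float32) - zero_point) * scale

x = np.array([-1.0, 0.0, 0.5, 1.0], dtype=np.float32)
q = uniform_quantize(x, scale=0.01, zero_point=0, qmin=-128, qmax=127)
# q -> [-100, 0, 50, 100]; uniform_dequantize(q, 0.01, 0) recovers x closely
```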
+ }]; + + let arguments = (ins + Arg= 3.}]>:$lhs, + Arg:$rhs, + Arg:$lhs_scales, + Arg:$lhs_zero_points, + Arg:$rhs_scales, + Arg:$rhs_zero_points, + Arg:$output_scales, + Arg:$output_zero_points, + + DefaultValuedOptionalAttr:$window_strides, + StrAttr:$padding, + DefaultValuedOptionalAttr:$explicit_padding, + DefaultValuedOptionalAttr:$lhs_dilation, + DefaultValuedOptionalAttr:$rhs_dilation, + DefaultValuedOptionalAttr:$batch_group_count, + DefaultValuedOptionalAttr:$feature_group_count, + DefaultValuedOptionalAttr:$dimension_numbers, + DefaultValuedOptionalAttr:$lhs_quantization_axis, + I64Attr:$lhs_quantization_min_val, + I64Attr:$lhs_quantization_max_val, + DefaultValuedOptionalAttr:$rhs_quantization_axis, + I64Attr:$rhs_quantization_min_val, + I64Attr:$rhs_quantization_max_val, + DefaultValuedOptionalAttr:$output_quantization_axis, + I64Attr:$output_quantization_min_val, + I64Attr:$output_quantization_max_val + ); + + let results = (outs + Res:$output + ); + + TF_DerivedOperandTypeAttr Tin = TF_DerivedOperandTypeAttr<0>; + TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>; +} + +def TF_UniformQuantizedConvolutionHybridOp : TF_Op<"UniformQuantizedConvolutionHybrid", [Pure]> { + let summary = [{ +Perform hybrid quantized convolution of float Tensor `lhs` and quantized Tensor `rhs`. + }]; + + let description = [{ +Given float `lhs` and quantized `rhs`, internally performs quantization on `lhs`, +and then performs quantized convolution on quantized `lhs` and `rhs`. + +The internal quantization on `lhs` is a quantization to `Trhs`, dynamic range, +per-batch (per-axis along axis `dimension_numbers.input_batch_dimension`), asymmetric, +and not narrow range (the range is [Trhs_MIN, Trhs_MAX]). + +`lhs` and `rhs` must be Tensors of same rank, and meet following shape conditions. +- lhs_feature % feature_group_count == 0 +- lhs_feature % rhs_input_feature == 0 +- lhs_feature / feature_group_count == rhs_input_feature +- rhs_output_feature % feature_group_count == 0 +- lhs_batch % batch_group_count == 0 +- rhs_output_feature % batch_group_count == 0 + +`rhs` must be quantized Tensor, where its data value is quantized using the formula: +quantized_data = clip(original_data / scale + zero_point, quantization_min_val, quantization_max_val). + }]; + + let arguments = (ins + Arg= 3.}]>:$lhs, + Arg:$rhs, + Arg:$rhs_scales, + Arg:$rhs_zero_points, + + DefaultValuedOptionalAttr:$window_strides, + StrAttr:$padding, + DefaultValuedOptionalAttr:$explicit_padding, + DefaultValuedOptionalAttr:$lhs_dilation, + DefaultValuedOptionalAttr:$rhs_dilation, + DefaultValuedOptionalAttr:$batch_group_count, + DefaultValuedOptionalAttr:$feature_group_count, + DefaultValuedOptionalAttr:$dimension_numbers, + DefaultValuedOptionalAttr:$rhs_quantization_axis, + I64Attr:$rhs_quantization_min_val, + I64Attr:$rhs_quantization_max_val + ); + + let results = (outs + Res:$output + ); + + TF_DerivedOperandTypeAttr Tlhs = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Trhs = TF_DerivedOperandTypeAttr<1>; + TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>; + + let hasVerifier = 1; +} + def TF_UniformQuantizedDotOp : TF_Op<"UniformQuantizedDot", [Pure]> { let summary = [{ Perform quantized dot of quantized Tensor `lhs` and quantized Tensor `rhs` to make quantized `output`. 
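The shape conditions listed for the two convolution ops above are purely arithmetic, so they can be checked up front. A sketch; the function and the example figures are ours, chosen only to satisfy the listed conditions:
```
def check_conv_shapes(lhs_batch, lhs_feature, rhs_input_feature,
                      rhs_output_feature, feature_group_count, batch_group_count):
    # The six conditions from the UniformQuantizedConvolution(Hybrid) descriptions.
    assert lhs_feature % feature_group_count == 0
    assert lhs_feature % rhs_input_feature == 0
    assert lhs_feature // feature_group_count == rhs_input_feature
    assert rhs_output_feature % feature_group_count == 0
    assert lhs_batch % batch_group_count == 0
    assert rhs_output_feature % batch_group_count == 0

# A grouped convolution with two feature groups passes all six checks.
check_conv_shapes(lhs_batch=8, lhs_feature=16, rhs_input_feature=8,
                  rhs_output_feature=32, feature_group_count=2, batch_group_count=1)
```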
@@ -19653,6 +20249,8 @@ Same shape condition as rhs_scales.}]>:$output_zero_points, TF_DerivedOperandTypeAttr Tin = TF_DerivedOperandTypeAttr<0>; TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>; + + let hasVerifier = 1; } def TF_UniformQuantizedDotHybridOp : TF_Op<"UniformQuantizedDotHybrid", [Pure]> { @@ -19689,6 +20287,8 @@ The output data is the original output data itself (Not quantized).}]>:$output TF_DerivedOperandTypeAttr Tlhs = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr Trhs = TF_DerivedOperandTypeAttr<1>; TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>; + + let hasVerifier = 1; } def TF_UniformRequantizeOp : TF_Op<"UniformRequantize", [Pure]> { @@ -19737,6 +20337,8 @@ Same shape condition as scales.}]>:$output_zero_points, TF_DerivedOperandTypeAttr Tin = TF_DerivedOperandTypeAttr<0>; TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>; + + let hasVerifier = 1; } def TF_UniqueOp : TF_Op<"Unique", [Pure]> { @@ -19975,7 +20577,7 @@ dimension of its shape if `num_segments` is 0. }]; let arguments = (ins - TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$data, + TF_NumberTensor:$data, Arg:$segment_ids, ); let results = (outs - Res, [{Has same shape as data, except for the first `segment_ids.rank` + Res:$output ); @@ -20035,7 +20637,7 @@ array([[5, 5, 5, 5], }]; let arguments = (ins - TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$data, + TF_NumberTensor:$data, Arg:$segment_ids, ); let results = (outs - Res, [{Has same shape as data, except for the first `segment_ids.rank` + Res:$output ); @@ -20252,7 +20854,7 @@ where(input) ==> [[0, 0, 0], }]; let arguments = (ins - TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input + TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Qint32, TF_Qint8, TF_Quint16, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input ); let results = (outs @@ -20267,12 +20869,12 @@ def TF_XdivyOp : TF_Op<"Xdivy", [Pure, ResultsBroadcastableShape, TF_SameOperand let summary = "Returns 0 if x == 0, and x / y otherwise, elementwise."; let arguments = (ins - TensorOf<[TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64]>:$x, - TensorOf<[TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64]>:$y + TF_FpOrComplexTensor:$x, + TF_FpOrComplexTensor:$y ); let results = (outs - TensorOf<[TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64]>:$z + TF_FpOrComplexTensor:$z ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -20312,14 +20914,14 @@ for binary operators. 
}]; let arguments = (ins - Arg, [{the LHS input tensor}]>:$lhs, - Arg, [{the RHS input tensor}]>:$rhs, + Arg:$lhs, + Arg:$rhs, Arg:$broadcast_dims ); let results = (outs - Res, [{the broadcasted LHS tensor}]>:$lhs_output, - Res, [{the broadcasted RHS tensor}]>:$rhs_output + Res:$lhs_output, + Res:$rhs_output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -20362,7 +20964,8 @@ E.g., the specification "2.1" denotes the value args[2].shape[1]. let arguments = (ins Arg, [{A list of `Tensor` with possibly different types to be passed as arguments -to the HLO module.}]>:$args, +to the HLO module. These are all non-dimension arguments. The dimension +arguments are computed at JIT time.}]>:$args, I64Attr:$version, StrAttr:$module, @@ -20403,8 +21006,8 @@ https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution }]; let arguments = (ins - Arg, [{the input tensor}]>:$lhs, - Arg, [{the kernel tensor}]>:$rhs, + Arg:$lhs, + Arg:$rhs, Arg:$window_strides, Arg:$padding, Arg:$lhs_dilation, @@ -20416,7 +21019,7 @@ https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution ); let results = (outs - TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + TF_NumberTensor:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -20434,8 +21037,8 @@ https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution }]; let arguments = (ins - Arg, [{input tensor}]>:$lhs, - Arg, [{kernel tensor}]>:$rhs, + Arg:$lhs, + Arg:$rhs, Arg:$window_strides, Arg:$padding, Arg:$lhs_dilation, @@ -20448,7 +21051,7 @@ https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution ); let results = (outs - TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + TF_NumberTensor:$output ); TF_DerivedOperandTypeAttr LhsT = TF_DerivedOperandTypeAttr<0>; @@ -20497,15 +21100,15 @@ https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral }]; let arguments = (ins - Arg, [{the LHS tensor}]>:$lhs, - Arg, [{the RHS tensor}]>:$rhs, + Arg:$lhs, + Arg:$rhs, StrAttr:$dimension_numbers, StrAttr:$precision_config ); let results = (outs - TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + TF_NumberTensor:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -20520,15 +21123,15 @@ https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral }]; let arguments = (ins - Arg, [{the LHS tensor}]>:$lhs, - Arg, [{the RHS tensor}]>:$rhs, + Arg:$lhs, + Arg:$rhs, StrAttr:$dimension_numbers, StrAttr:$precision_config ); let results = (outs - TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + TF_NumberTensor:$output ); TF_DerivedOperandTypeAttr LhsT = TF_DerivedOperandTypeAttr<0>; @@ -20629,7 +21232,7 @@ https://www.tensorflow.org/xla/operation_semantics#gather }]; let arguments = (ins - Arg, [{The array we're gathering from.}]>:$operand, + Arg, [{The array we're gathering 
from.}]>:$operand, Arg:$start_indices, Arg:$slice_sizes, @@ -20638,7 +21241,7 @@ https://www.tensorflow.org/xla/operation_semantics#gather ); let results = (outs - TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Qint32, TF_Qint8, TF_Quint16, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -20840,15 +21443,15 @@ https://www.tensorflow.org/performance/xla/operation_semantics#reduce . }]; let arguments = (ins - Arg, [{the input tensor}]>:$input, - Arg, [{a scalar representing the initial value for the reduction}]>:$init_value, + Arg, [{the input tensor}]>:$input, + Arg, [{a scalar representing the initial value for the reduction}]>:$init_value, I64ArrayAttr:$dimensions_to_reduce, SymbolRefAttr:$reducer ); let results = (outs - TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Qint32, TF_Qint8, TF_Quint16, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -20907,8 +21510,8 @@ https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow . }]; let arguments = (ins - Arg, [{the input tensor}]>:$input, - Arg, [{a scalar representing the initial value for the reduction}]>:$init_value, + Arg, [{the input tensor}]>:$input, + Arg, [{a scalar representing the initial value for the reduction}]>:$init_value, Arg:$window_dimensions, Arg:$window_strides, TF_I32OrI64Tensor:$base_dilations, @@ -20919,7 +21522,7 @@ https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow . ); let results = (outs - TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Qint32, TF_Qint8, TF_Quint16, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -20994,10 +21597,10 @@ https://www.tensorflow.org/xla/operation_semantics#scatter. 
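To make the scatter semantics above concrete, a deliberately simplified 1-D sketch: one index per update and a user-supplied combiner, which is roughly what the linked operation semantics describe; the helper is ours:
```
import numpy as np

def scatter_1d(operand, scatter_indices, updates, update_computation):
    # Copy the operand, then combine each update into the indexed position.
    result = np.array(operand)
    for idx, upd in zip(scatter_indices, updates):
        result[idx] = update_computation(result[idx], upd)
    return result

print(scatter_1d([0, 0, 0, 0, 0, 0, 0, 0], [1, 4, 5], [1, -1, 1], lambda a, b: a + b))
# [ 0  1  0  0 -1  1  0  0]
```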
}]; let arguments = (ins - Arg, [{Array to be scattered into.}]>:$operand, + Arg, [{Array to be scattered into.}]>:$operand, Arg:$scatter_indices, - Arg, [{Array containing the values that must be used for scattering.}]>:$updates, + Arg, [{Array containing the values that must be used for scattering.}]>:$updates, SymbolRefAttr:$update_computation, StrAttr:$dimension_numbers, @@ -21005,7 +21608,7 @@ be scattered to.}]>:$scatter_indices, ); let results = (outs - TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Qint32, TF_Qint8, TF_Quint16, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -21021,19 +21624,19 @@ https://www.tensorflow.org/performance/xla/operation_semantics#selectandscatter }]; let arguments = (ins - Arg, [{the input tensor}]>:$operand, + Arg:$operand, Arg:$window_dimensions, Arg:$window_strides, Arg:$padding, - Arg, [{a tensor of values to scatter}]>:$source, - Arg, [{a scalar representing the initial value for the output tensor}]>:$init_value, + Arg:$source, + Arg:$init_value, SymbolRefAttr:$select, SymbolRefAttr:$scatter ); let results = (outs - TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + TF_NumberTensor:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -21056,7 +21659,7 @@ i=0...N-1. }]; let arguments = (ins - Arg, [{the input tensor.}]>:$a, + Arg:$a, BoolAttr:$lower, I64Attr:$max_iter, @@ -21064,9 +21667,9 @@ i=0...N-1. ); let results = (outs - Res, [{The eigenvalues in ascending order, each repeated according to its + Res:$w, - Res, [{The column v[..., :, i] is the normalized eigenvector corresponding to the + Res:$v ); @@ -21147,6 +21750,25 @@ key: A unique identifier for this region used to match up host transfers. TF_DerivedOperandTypeAttr Tinput = TF_DerivedOperandTypeAttr<0>; } +def TF_XlaSetBoundOp : TF_Op<"XlaSetBound", [Pure]> { + let summary = [{ +Set a bound for the given input value as a hint to Xla compiler, + }]; + + let description = [{ +returns the same value. + }]; + + let arguments = (ins + TF_Int32Tensor:$input, + TF_Int32Tensor:$bound + ); + + let results = (outs + TF_Int32Tensor:$output + ); +} + def TF_XlaSetDynamicDimensionSizeOp : TF_Op<"XlaSetDynamicDimensionSize", [InferTensorType, Pure, TF_NoConstantFold]> { let summary = "Make a static dimension into a xla bounded dynamic dimension."; @@ -21196,6 +21818,63 @@ Sorts a tensor. Currently only sorts in ascending order are supported. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_XlaSpmdFullToShardShapeOp : TF_Op<"XlaSpmdFullToShardShape", [Pure]> { + let summary = [{ +An op used by XLA SPMD partitioner to switch from automatic partitioning to + }]; + + let description = [{ +manual partitioning. It annotates the input (full-shape, to be automatically +partitioned) with the same sharding used by manual partitioning, and outputs a +shard-shaped tensor to be consumed by later manually-partitioned ops. If the +shape is not evenly partitionable, the padding region will be masked with 0s. 
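Before the subgroup case described next, a rough sketch of the full-shape to shard-shape relationship implied by this paragraph. The ceil-division rule and zero padding of the remainder are our reading of "not evenly partitionable", not something spelled out by the op definition:
```
import math

def shard_shape(full_shape, partitions):
    # Assumed rule: dimension i is split into partitions[i] equal shards,
    # rounding up; the tail of the last shard is the zero-masked padding region.
    return tuple(math.ceil(d / p) for d, p in zip(full_shape, partitions))

print(shard_shape((10, 7), (2, 4)))  # (5, 2); dim 1 pads 8 - 7 = 1 zero column
```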
+The conversion can happen partially in subgroups, by specifying the dim +attribute, where only that dim will be converted. + }]; + + let arguments = (ins + TF_Tensor:$input, + + StrAttr:$manual_sharding, + DefaultValuedOptionalAttr:$dim, + DefaultValuedOptionalAttr:$unspecified_dims + ); + + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_XlaSpmdShardToFullShapeOp : TF_Op<"XlaSpmdShardToFullShape", [Pure]> { + let summary = [{ +An op used by XLA SPMD partitioner to switch from manual partitioning to + }]; + + let description = [{ +automatic partitioning. It converts the shard-shaped, manually partitioned input +into full-shaped tensor to be partitioned automatically with the same sharding +used by manual partitioning. The conversion can happen partially in subgroups, +by specifying the dim attribute, where only that dim will be converted. + }]; + + let arguments = (ins + TF_Tensor:$input, + + StrAttr:$manual_sharding, + TF_ShapeAttr:$full_shape, + DefaultValuedOptionalAttr:$dim, + DefaultValuedOptionalAttr:$unspecified_dims + ); + + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_XlaSvdOp : TF_Op<"XlaSvd", [Pure]> { let summary = [{ Computes the eigen decomposition of a batch of self-adjoint matrices @@ -21209,7 +21888,7 @@ tensor such that tensor[...,:,:] = u[..., :, :] * Diag(s[..., :]) * Transpose(v[ }]; let arguments = (ins - Arg, [{the input tensor.}]>:$a, + Arg:$a, I64Attr:$max_iter, F32Attr:$epsilon, @@ -21217,10 +21896,10 @@ tensor such that tensor[...,:,:] = u[..., :, :] * Diag(s[..., :]) * Transpose(v[ ); let results = (outs - Res, [{Singular values. The values are sorted in reverse order of magnitude, so + Res:$s, - Res, [{Left singular vectors.}]>:$u, - Res, [{Right singular vectors.}]>:$v + Res:$u, + Res:$v ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -21238,15 +21917,15 @@ XlaVariadicReduceV2 is a version that supports heterogeneous operands. 
}]; let arguments = (ins - Arg>, [{the input tensor(s)}]>:$input, - Arg>, [{scalar initial value(s) for the reduction}]>:$init_value, + Arg>, [{the input tensor(s)}]>:$input, + Arg>, [{scalar initial value(s) for the reduction}]>:$init_value, I64ArrayAttr:$dimensions_to_reduce, SymbolRefAttr:$reducer ); let results = (outs - Variadic>:$output + Variadic>:$output ); TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>; @@ -21318,12 +21997,12 @@ def TF_Xlog1pyOp : TF_Op<"Xlog1py", [Pure, TF_SameOperandsAndResultElementTypeRe let summary = "Returns 0 if x == 0, and x * log1p(y) otherwise, elementwise."; let arguments = (ins - TensorOf<[TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64]>:$x, - TensorOf<[TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64]>:$y + TF_FpOrComplexTensor:$x, + TF_FpOrComplexTensor:$y ); let results = (outs - TensorOf<[TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64]>:$z + TF_FpOrComplexTensor:$z ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -21334,12 +22013,12 @@ def TF_XlogyOp : TF_Op<"Xlogy", [Pure, ResultsBroadcastableShape, TF_SameOperand let summary = "Returns 0 if x == 0, and x * log(y) otherwise, elementwise."; let arguments = (ins - TensorOf<[TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64]>:$x, - TensorOf<[TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64]>:$y + TF_FpOrComplexTensor:$x, + TF_FpOrComplexTensor:$y ); let results = (outs - TensorOf<[TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64]>:$z + TF_FpOrComplexTensor:$z ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td index 9e05ae6c42b..cb5f7bfd3e2 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td @@ -48,8 +48,7 @@ TODO: Make invariants more structured so that we can reference them in ops. }]; let cppNamespace = "::mlir::TF"; - - let emitAccessorPrefix = kEmitAccessorPrefix_Raw; + let useFoldAPI = kEmitFoldAdaptorFolder; } //===----------------------------------------------------------------------===// @@ -119,7 +118,7 @@ class TF_OpIsBroadcastableToRes : And<[ class TF_AllTypesMatchPred values> : - CPred<"tf_type::AreCastCompatible(llvm::makeArrayRef({" # + CPred<"tf_type::AreCastCompatible(llvm::ArrayRef({" # !interleave(values, ", ") # "}))">; class TF_AllTypesMatch names> : @@ -190,6 +189,7 @@ def TF_TPUExecuteResource : TF_ResourceBase<"TPUExecute">; def TF_RandomGeneratorResource : TF_ResourceBase<"RandomGenerator">; def TF_XlaHostComputeResource : TF_ResourceBase<"XlaHostCompute">; def TF_CollectiveReduceOrderingResource : TF_ResourceBase<"CollectiveReduceOrdering">; +def TF_NcclAllReduceOrderingResource : TF_ResourceBase<"NcclAllReduceOrdering">; // Fake resource, see `TF_MustExecute` below. def TF_MustExecuteResource : TF_ResourceBase<"MustExecute">; @@ -257,6 +257,9 @@ def TF_RandomGeneratorSideEffect : MemoryEffects<[MemWrite]>; +// Special effect for keeping `NcclAllReduce` ops on the same device in order. +def TF_NcclAllReduceOrderingEffect : MemoryEffects<[MemWrite]>; + // Trait for enforcing that a side-effecting op is executed, even if it would be // considered dead by MLIR (see b/195782952). 
// The trait is implemented as a write effect for a fake resource which is @@ -313,6 +316,8 @@ def TF_Float16Ref : TF_TensorFlowType<"HalfRef", "f16ref">; def TF_Float32Ref : TF_TensorFlowType<"FloatRef", "f32ref">; def TF_Float64Ref : TF_TensorFlowType<"DoubleRef", "f64ref">; def TF_Bfloat16Ref : TF_TensorFlowType<"Bfloat16Ref", "bf16ref">; +def TF_Float8E4M3FNRef : TF_TensorFlowType<"Float8E4M3FNRef", "float8e4m3fnref">; +def TF_Float8E5M2Ref : TF_TensorFlowType<"Float8E5M2Ref", "float8e5m2ref">; // Complex reference types def TF_Complex64Ref : TF_TensorFlowType<"Complex64Ref", "complex64ref">; @@ -421,11 +426,14 @@ def TF_Float16 : AnyTypeOf<[F16, TF_Float16Ref], "16-bit float">; def TF_Float32 : AnyTypeOf<[F32, TF_Float32Ref], "32-bit float">; def TF_Float64 : AnyTypeOf<[F64, TF_Float64Ref], "64-bit float">; def TF_Bfloat16 : AnyTypeOf<[BF16, TF_Bfloat16Ref], "bfloat16">; +def TF_Float8E4M3FN : AnyTypeOf<[F8E4M3FN, TF_Float8E4M3FNRef], "float8e4m3fn">; +def TF_Float8E5M2 : AnyTypeOf<[F8E5M2, TF_Float8E5M2Ref], "float8e5m2">; def TF_F32OrF64 : AnyTypeOf<[TF_Float32, TF_Float64], "32/64-bit float">; def TF_Float : AnyTypeOf< - [TF_Float16, TF_Float32, TF_Float64, TF_Bfloat16], + [TF_Float16, TF_Float32, TF_Float64, TF_Bfloat16, TF_Float8E4M3FN, + TF_Float8E5M2], "floating-point">; // Tensor types @@ -435,6 +443,8 @@ def TF_Float16Tensor : TensorOf<[TF_Float16]>; def TF_Float32Tensor : TensorOf<[TF_Float32]>; def TF_Float64Tensor : TensorOf<[TF_Float64]>; def TF_Bfloat16Tensor : TensorOf<[TF_Bfloat16]>; +def TF_Float8E4M3FNTensor : TensorOf<[TF_Float8E4M3FN]>; +def TF_Float8E5M2Tensor : TensorOf<[TF_Float8E5M2]>; //===----------------------------------------------------------------------===// // Complex types (including corresponding reference types) @@ -556,7 +566,7 @@ class TF_DerivedOperandTypeListAttr : DerivedAttr< "return {mlir::OperandElementTypeIterator(values.begin()), " "mlir::OperandElementTypeIterator(values.end())};", [{ - ArrayAttr::get($_ctxt, + ArrayAttr::get($_ctxt, [&]() { llvm::SmallVector ret; for (auto t : $_self) @@ -612,7 +622,7 @@ class TF_DerivedResultTypeListAttr : DerivedAttr< "return {mlir::ResultElementTypeIterator(values.begin()), " "mlir::ResultElementTypeIterator(values.end())};", [{ - ArrayAttr::get($_ctxt, + ArrayAttr::get($_ctxt, [&]() { llvm::SmallVector ret; for (auto t : $_self) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h index 8d9bdfa3a68..88cbf879d56 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h @@ -125,7 +125,7 @@ ResourceHandleValueAndId GetResourceHandleValueAndIdBase( #define INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Op) \ LogicalResult Op::inferReturnTypeComponents( \ - MLIRContext* context, Optional location, \ + MLIRContext* context, std::optional location, \ ValueShapeRange operands, DictionaryAttr attributes, \ RegionRange regions, \ SmallVectorImpl& inferredReturnShapes) { \ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td index 0f3d48bf577..2968c8eb1c0 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td @@ -40,8 +40,14 @@ def TF_LayoutSensitiveInterface : OpInterface<"LayoutSensitiveInterface"> { let methods = [ InterfaceMethod< [{Returns current operation data format (data layout).}], - 
"StringRef", "data_format", (ins) + "StringRef", "getDataFormat", (ins) >, + InterfaceMethod< + [{Deprecated method that returns current operation data format (data layout).}], + "StringRef", "data_format", (ins), + /*methodBody=*/[{ + return $_op.getDataFormat(); + }]>, InterfaceMethod< [{Returns indices of layout dependent arguments.}], "SmallVector", "GetLayoutDependentArgs", (ins) @@ -138,12 +144,14 @@ def TF_GetResourceInstanceInterface : OpInterface<"GetResourceInstanceInterface" let methods = [ InterfaceMethod< /*desc=*/[{Returns a string corresponding to the resource instance - accessed by this op. The implementation must guarantee that the + accessed by this op, or `std::nullopt` if the resource should + be ignored. The implementation must guarantee that the mapping between resource instances and strings is bijective, i.e., two op instances should return the same string if and only if they access the same resource. The interface should - only be used for ops that access exactly one resource.}], - /*retTy=*/"std::string", + only be used for ops that access exactly one op-based resource + (see `tf_op_base.td` for details).}], + /*retTy=*/"std::optional", /*methodName=*/"GetResourceInstanceStr", /*args=*/(ins) >, diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index 9c72980d39c..da6bdb54cf3 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -159,7 +159,7 @@ struct TFInlinerInterface : public DialectInlinerInterface { // Returns if its legal to inline 'src' region into the 'dest' region // attached to a TF operation. bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, - BlockAndValueMapping &valueMapping) const final { + IRMapping &valueMapping) const final { // Allow inlining in regions attached to region based control flow // operations only if the src region is a single block region return isa(dest->getParentOp()) && @@ -169,7 +169,7 @@ struct TFInlinerInterface : public DialectInlinerInterface { // Returns true if its legal to inline a TF operation `op` into the `dest` // region. bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, - BlockAndValueMapping &) const final { + IRMapping &) const final { // An op is legal to inline if either of the following conditions is true: // (a) Its legal to duplicate the Op. // (b) The Op is inside a single use function. If that function is inlined, @@ -227,7 +227,7 @@ bool TensorFlowDialect::CanDuplicate(Operation *op) { if (op->hasTrait()) return false; // If the op has no memory side effects, it can be duplicated. - if (MemoryEffectOpInterface::hasNoEffect(op)) return true; + if (isMemoryEffectFree(op)) return true; // If the op is marked stateless using the `is_stateless` attribute, that // attribute determines if the op can be duplicated. @@ -272,7 +272,7 @@ void *TensorFlowDialect::getRegisteredInterfaceForOp( // Returns true if the op can have side effects. bool TensorFlowDialect::CanHaveSideEffects(Operation *op) { // If the op has no memory side effects, it has no side effects - if (MemoryEffectOpInterface::hasNoEffect(op)) return false; + if (isMemoryEffectFree(op)) return false; // If the op is marked stateless using the `is_stateless` attribute, then // it has no side effects. 
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index 8f14ca318d1..c184494fcaa 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -57,12 +57,13 @@ class TF_TensorListInitOp : TF_Op { // Returns data type of the result handle. Returned type contains type of // the TensorList element as a subtype. VariantType handle_dtype() { - return getElementTypeOrSelf(handle().getType()).cast(); + return getElementTypeOrSelf(getHandle().getType()).cast(); } }]; } -def TF_CaseOp : TF_Op<"Case", [DeclareOpInterfaceMethods]> { +def TF_CaseOp : TF_Op<"Case", + [DeclareOpInterfaceMethods]> { let summary = [{ An n-way switch statement which calls a single branch function. }]; @@ -111,13 +112,13 @@ An n-way switch statement, implementing the following: let extraClassDeclaration = [{ - int num_branches() { return branches().size(); } + int num_branches() { return getBranches().size(); } // Gets function corresponding branch # `index`. // Prefer passing in SymbolTableCollection to reduce lookup costs by // enabling reusing cached symbol table lookup. func::FuncOp ResolveBranchFunction(::mlir::SymbolTableCollection* table, int index) { - auto flat_sym_ref = branches()[index].cast(); + auto flat_sym_ref = getBranches()[index].cast(); if (table) return table->lookupNearestSymbolFrom(*this, flat_sym_ref); return SymbolTable::lookupNearestSymbolFrom(*this, flat_sym_ref); @@ -282,9 +283,9 @@ else_branch: A function that takes 'inputs' and returns a list of // enabling reusing cached symbol table lookup. func::FuncOp ResolveThenFunction(::mlir::SymbolTableCollection* table) { if (table) - return table->lookupNearestSymbolFrom(*this, then_branchAttr()); + return table->lookupNearestSymbolFrom(*this, getThenBranchAttr()); return SymbolTable::lookupNearestSymbolFrom( - *this, then_branchAttr()); + *this, getThenBranchAttr()); } // TODO(b/204997177): Deprecate and remove. func::FuncOp then_function(::mlir::SymbolTableCollection* table = nullptr) { @@ -296,9 +297,9 @@ else_branch: A function that takes 'inputs' and returns a list of // enabling reusing cached symbol table lookup. func::FuncOp ResolveElseFunction(::mlir::SymbolTableCollection* table) { if (table) - return table->lookupNearestSymbolFrom(*this, else_branchAttr()); + return table->lookupNearestSymbolFrom(*this, getElseBranchAttr()); return SymbolTable::lookupNearestSymbolFrom( - *this, else_branchAttr()); + *this, getElseBranchAttr()); } // TODO(b/204997177): Deprecate and remove. func::FuncOp else_function(::mlir::SymbolTableCollection* table = nullptr) { @@ -320,7 +321,7 @@ def TF_YieldOp : TF_Op<"Yield", region. }]; - let arguments = (ins Variadic:$operands); + let arguments = (ins Variadic); } def TF_IfRegionOp : TF_Op<"IfRegion", @@ -377,7 +378,8 @@ else_branch: A region that computes the outputs of the op if cond = false. } def TF_LegacyCallOp : TF_Op<"LegacyCall", - [CallOpInterface, Pure]> { + [CallOpInterface, + DeclareOpInterfaceMethods, Pure]> { let summary = "returns `f(inputs)`, where `f` is a function."; @@ -403,18 +405,18 @@ def TF_LegacyCallOp : TF_Op<"LegacyCall", let extraClassDeclaration = [{ // Gets the argument operands to the called function. - operand_range getArgOperands() { return args(); } + operand_range getArgOperands() { return getArgs(); } // Returns the callee of this operation. 
- CallInterfaceCallable getCallableForCallee() { return fAttr(); } + CallInterfaceCallable getCallableForCallee() { return getFAttr(); } // Returns the resolved callee function of this operation. // Prefer passing in SymbolTableCollection to reduce lookup costs by // enabling reusing cached symbol table lookup. func::FuncOp ResolveFunc(::mlir::SymbolTableCollection* table) { if (table) - return table->lookupNearestSymbolFrom(*this, fAttr()); - return SymbolTable::lookupNearestSymbolFrom(*this, fAttr()); + return table->lookupNearestSymbolFrom(*this, getFAttr()); + return SymbolTable::lookupNearestSymbolFrom(*this, getFAttr()); } // TODO(b/204997177): Deprecate and remove. func::FuncOp func() { return ResolveFunc(nullptr); } @@ -536,7 +538,7 @@ def TF_PlaceholderWithDefaultOp : TF_Op<"PlaceholderWithDefault", [Pure]> { } def TF_StatefulPartitionedCallOp : TF_Op<"StatefulPartitionedCall", - [CallOpInterface, SymbolUserOpInterface]> { + [CallOpInterface, DeclareOpInterfaceMethods]> { let summary = "returns `f(inputs)`, where `f`'s body is placed and partitioned."; @@ -564,28 +566,26 @@ underlying graph, and executes each of the partitioned subgraphs as a function. let extraClassDeclaration = [{ // Gets the argument operands to the called function. - operand_range getArgOperands() { return args(); } + operand_range getArgOperands() { return getArgs(); } // Returns the callee of this operation. - CallInterfaceCallable getCallableForCallee() { return fAttr(); } + CallInterfaceCallable getCallableForCallee() { return getFAttr(); } // Returns the resolved callee function of this operation. // Prefer passing in SymbolTableCollection to reduce lookup costs by // enabling reusing cached symbol table lookup. func::FuncOp ResolveFunc(::mlir::SymbolTableCollection* table) { if (table) - return table->lookupNearestSymbolFrom(*this, fAttr()); - return SymbolTable::lookupNearestSymbolFrom(*this, fAttr()); + return table->lookupNearestSymbolFrom(*this, getFAttr()); + return SymbolTable::lookupNearestSymbolFrom(*this, getFAttr()); } // TODO(b/204997177): Deprecate and remove. func::FuncOp func() { return ResolveFunc(nullptr); } - - // SymbolUserOpInterface verifier. - LogicalResult verifySymbolUses(SymbolTableCollection &symbolTable); }]; } -def TF_WhileOp : TF_Op<"While", [DeclareOpInterfaceMethods]> { +def TF_WhileOp : TF_Op<"While", + [DeclareOpInterfaceMethods]> { let summary = [{ output = input; While (Cond(output)) { output = Body(output) } }]; @@ -644,8 +644,8 @@ body: A function that takes a list of tensors and returns another // enabling reusing cached symbol table lookup. func::FuncOp ResolveCondFunction(::mlir::SymbolTableCollection* table) { if (table) - return table->lookupNearestSymbolFrom(*this, condAttr()); - return SymbolTable::lookupNearestSymbolFrom(*this, condAttr()); + return table->lookupNearestSymbolFrom(*this, getCondAttr()); + return SymbolTable::lookupNearestSymbolFrom(*this, getCondAttr()); } // TODO(b/204997177): Deprecate and remove. func::FuncOp cond_function() { return ResolveCondFunction(nullptr); } @@ -655,8 +655,8 @@ body: A function that takes a list of tensors and returns another // enabling reusing cached symbol table lookup. 
func::FuncOp ResolveBodyFunction(::mlir::SymbolTableCollection* table) { if (table) - return table->lookupNearestSymbolFrom(*this, bodyAttr()); - return SymbolTable::lookupNearestSymbolFrom(*this, bodyAttr()); + return table->lookupNearestSymbolFrom(*this, getBodyAttr()); + return SymbolTable::lookupNearestSymbolFrom(*this, getBodyAttr()); } // TODO(b/204997177): Deprecate and remove. func::FuncOp body_function() { return ResolveBodyFunction(nullptr); } @@ -779,7 +779,7 @@ Example: TensorType resource_subtype() { return resource_type().getSubtypes()[0]; } ResourceType resource_type() { - return getElementTypeOrSelf(resource()).cast(); + return getElementTypeOrSelf(getResource()).cast(); } }]; @@ -984,7 +984,8 @@ This function is faster and numerically stabler than `bessel_i1(x)`. }]; } -def TF_TPUPartitionedCallOp : TF_Op<"TPUPartitionedCall", [CallOpInterface, SymbolUserOpInterface]> { +def TF_TPUPartitionedCallOp : TF_Op<"TPUPartitionedCall", + [CallOpInterface, DeclareOpInterfaceMethods]> { let summary = "Calls a function placed on a specified TPU device."; let arguments = (ins @@ -1004,24 +1005,21 @@ def TF_TPUPartitionedCallOp : TF_Op<"TPUPartitionedCall", [CallOpInterface, Symb let extraClassDeclaration = [{ // Gets the argument operands to the called function. - operand_range getArgOperands() { return args(); } + operand_range getArgOperands() { return getArgs(); } // Returns the callee of this operation. - CallInterfaceCallable getCallableForCallee() { return fAttr(); } + CallInterfaceCallable getCallableForCallee() { return getFAttr(); } // Returns the resolved callee function of this operation. // Prefer passing in SymbolTableCollection to reduce lookup costs by // enabling reusing cached symbol table lookup. func::FuncOp ResolveFunc(::mlir::SymbolTableCollection* table) { if (table) - return table->lookupNearestSymbolFrom(*this, fAttr()); - return SymbolTable::lookupNearestSymbolFrom(*this, fAttr()); + return table->lookupNearestSymbolFrom(*this, getFAttr()); + return SymbolTable::lookupNearestSymbolFrom(*this, getFAttr()); } // TODO(b/204997177): Deprecate and remove. func::FuncOp func() { return ResolveFunc(nullptr); } - - // SymbolUserOpInterface verifier. - LogicalResult verifySymbolUses(SymbolTableCollection &symbolTable); }]; } @@ -1437,6 +1435,9 @@ operations inside a TPU host. ); } +// We must manually define TPUPartitionedInput, TPUPartitionedInputV2, +// TPUPartitionedOutput, and TPUPartitionedOutputV2 since they have an +// optional attribute, _XlaSharding, unlike the TensorFlow definition. def TF_TPUPartitionedInputOp : TF_Op<"TPUPartitionedInput", [Pure]> { let summary = [{ An op that groups a list of partitioned inputs together. This op @@ -1457,6 +1458,28 @@ An op that groups a list of partitioned inputs together. This op TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>; } +def TF_TPUPartitionedInputV2Op : TF_Op<"TPUPartitionedInputV2", [Pure]> { + let summary = [{ +An op that groups a list of partitioned inputs together. Supports ND sharding. 
+ }]; + + let arguments = (ins + Variadic:$inputs, + I64ArrayAttr:$partition_dims, + DefaultValuedOptionalAttr:$is_packed, + OptionalAttr:$_XlaSharding + ); + + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>; + + let hasVerifier = 1; +} + def TF_TPUPartitionedOutputOp : TF_Op<"TPUPartitionedOutput", [Pure]> { let summary = [{ An op that demultiplexes a tensor to be sharded by XLA to a list of partitioned @@ -1481,6 +1504,30 @@ outputs outside the XLA computation. TF_DerivedResultSizeAttr num_splits = TF_DerivedResultSizeAttr<0>; } +def TF_TPUPartitionedOutputV2Op : TF_Op<"TPUPartitionedOutputV2", [Pure]> { + let summary = [{ +An op that demultiplexes a tensor to be sharded by XLA to a list of partitioned + }]; + + let description = [{ +outputs outside the XLA computation. Supports ND sharding. + }]; + + let arguments = (ins + TF_Tensor:$inputs, + + I64ArrayAttr:$partition_dims, + OptionalAttr:$_XlaSharding + ); + + let results = (outs + Variadic:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedResultSizeAttr num_splits = TF_DerivedResultSizeAttr<0>; +} + // Declares symbol reference attribute `shape_inference_graph` to be optional // unlike the TensorFlow definition. This is required to support ops that use // empty string value for the attribute to signify missing. @@ -1670,16 +1717,16 @@ This op is deprecated. Prefer `tf.nn.batch_normalization`. }]; let arguments = (ins - Arg, [{A 4D input Tensor.}]>:$x, - Arg, [{A 1D mean Tensor with size matching the last dimension of x. + Arg, [{A 4D input Tensor.}]>:$x, + Arg, [{A 1D mean Tensor with size matching the last dimension of x. This is the first output from tf.nn.moments, or a saved moving average thereof.}]>:$m, - Arg, [{A 1D variance Tensor with size matching the last dimension of x. + Arg, [{A 1D variance Tensor with size matching the last dimension of x. This is the second output from tf.nn.moments, or a saved moving average thereof.}]>:$v, - Arg, [{A 1D beta Tensor with size matching the last dimension of x. + Arg, [{A 1D beta Tensor with size matching the last dimension of x. An offset to be added to the normalized tensor.}]>:$beta, - Arg, [{A 1D gamma Tensor with size matching the last dimension of x. + Arg, [{A 1D gamma Tensor with size matching the last dimension of x. If "scale_after_normalization" is true, this tensor will be multiplied with the normalized tensor.}]>:$gamma, @@ -1688,7 +1735,7 @@ with the normalized tensor.}]>:$gamma, ); let results = (outs - TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$result + TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Qint32, TF_Qint8, TF_Quint16, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$result ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -1705,15 +1752,15 @@ greater than `clip_value_max` are set to `clip_value_max`. }]; let arguments = (ins - Arg, [{A `Tensor`.}]>:$x, - Arg, [{A 0-D (scalar) `Tensor`, or a `Tensor` with the same shape + Arg, [{A `Tensor`.}]>:$x, + Arg, [{A 0-D (scalar) `Tensor`, or a `Tensor` with the same shape as `t`. 
The minimum value to clip by.}]>:$clip_value_min, - Arg, [{A 0-D (scalar) `Tensor`, or a `Tensor` with the same shape + Arg, [{A 0-D (scalar) `Tensor`, or a `Tensor` with the same shape as `x`. The maximum value to clip by.}]>:$clip_value_max ); let results = (outs - Res, [{A clipped `Tensor` with the same shape as input 't'.}]>:$output + Res, [{A clipped `Tensor` with the same shape as input 't'.}]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc index e354ecd39d9..3947b17c879 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -129,8 +130,8 @@ INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(MergeSummaryOp); // AddOp //===----------------------------------------------------------------------===// -void AddOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { +void AddOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { results.add(context); } @@ -138,8 +139,9 @@ void AddOp::getCanonicalizationPatterns(RewritePatternSet &results, // AddNOp //===----------------------------------------------------------------------===// -OpFoldResult AddNOp::fold(ArrayRef operands) { - if (operands.size() == 1) return *inputs().begin(); +OpFoldResult AddNOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); + if (operands.size() == 1) return *getInputs().begin(); // Fold if there is only one single non-zero operand or all operands are zero. int non_zero_index = -1; @@ -175,8 +177,8 @@ OpFoldResult AddNOp::fold(ArrayRef operands) { } // Check the non-zero operand's shape matches the result shape. 
- if (result_ty == inputs()[non_zero_index].getType()) - return inputs()[non_zero_index]; + if (result_ty == getInputs()[non_zero_index].getType()) + return getInputs()[non_zero_index]; return {}; } @@ -184,12 +186,13 @@ OpFoldResult AddNOp::fold(ArrayRef operands) { // AddV2Op //===----------------------------------------------------------------------===// -void AddV2Op::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { +void AddV2Op::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { results.add(context); } -OpFoldResult AddV2Op::fold(ArrayRef operands) { +OpFoldResult AddV2Op::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); return IdentityArithmeticOpFolder(*this, operands); } @@ -199,7 +202,7 @@ OpFoldResult AddV2Op::fold(ArrayRef operands) { LogicalResult AllOp::verify() { AllOp op = *this; - return VerifyReductionInputAndDims(op.input(), op.reduction_indices(), + return VerifyReductionInputAndDims(op.getInput(), op.getReductionIndices(), op.getLoc()); } @@ -209,7 +212,7 @@ LogicalResult AllOp::verify() { LogicalResult AnyOp::verify() { AnyOp op = *this; - return VerifyReductionInputAndDims(op.input(), op.reduction_indices(), + return VerifyReductionInputAndDims(op.getInput(), op.getReductionIndices(), op.getLoc()); } @@ -224,9 +227,9 @@ struct AssertWithTrue : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AssertOp op, - PatternRewriter &rewriter) const override { + PatternRewriter& rewriter) const override { ElementsAttr cst; - if (matchPattern(op.condition(), m_Constant(&cst))) { + if (matchPattern(op.getCondition(), m_Constant(&cst))) { if (cst.getValues()[0]) { rewriter.eraseOp(op); return success(); @@ -237,28 +240,46 @@ struct AssertWithTrue : public OpRewritePattern { }; } // namespace -void AssertOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { +void AssertOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { results.add(context); } +//===----------------------------------------------------------------------===// +// BatchFunctionOp +//===----------------------------------------------------------------------===// + +LogicalResult BatchFunctionOp::verifySymbolUses( + SymbolTableCollection& symbolTable) { + StringAttr func_attr = getFAttr().getRootReference(); + func::FuncOp func = + symbolTable.lookupNearestSymbolFrom(*this, func_attr); + + if (!func) { + return emitError("'f' attribute refers to an undefined function: ") + << func_attr.getValue(); + } + + return success(); +} + //===----------------------------------------------------------------------===// // BatchMatMulV2Op & BatchMatMulOp //===----------------------------------------------------------------------===// template ::value>::type * = nullptr> + OpT, BatchMatMulOp, BatchMatMulV2Op>::value>::type* = nullptr> static LogicalResult Verify(OpT op) { - if (!HasRankAtLeast(op.x(), 2)) { + if (!HasRankAtLeast(op.getX(), 2)) { return op.emitOpError("requires lhs operand to have rank at least two"); } - if (!HasRankAtLeast(op.y(), 2)) { + if (!HasRankAtLeast(op.getY(), 2)) { return op.emitOpError("requires rhs operand to have rank at least two"); } - RankedTensorType x_ty = GetRankedTensorTypeForOperand(op.x()); - RankedTensorType y_ty = GetRankedTensorTypeForOperand(op.y()); + RankedTensorType x_ty = GetRankedTensorTypeForOperand(op.getX()); + RankedTensorType y_ty = GetRankedTensorTypeForOperand(op.getY()); if (!x_ty || 
!y_ty) return success(); @@ -276,7 +297,7 @@ static LogicalResult Verify(OpT op) { // The last two dimensions are non-batch dimensions that don't need to // participate in batch dimension compatibility check. if (std::is_same()) { - for (const auto &dim_pairs : llvm::zip(x_batches, y_batches)) { + for (const auto& dim_pairs : llvm::zip(x_batches, y_batches)) { int64_t x_dim = std::get<0>(dim_pairs); int64_t y_dim = std::get<1>(dim_pairs); if (!ShapedType::isDynamic(x_dim) && !ShapedType::isDynamic(y_dim) && @@ -294,7 +315,7 @@ static LogicalResult Verify(OpT op) { << x_ty << " and rhs shape " << y_ty; } - RankedTensorType output_ty = GetRankedTensorTypeForOperand(op.output()); + RankedTensorType output_ty = GetRankedTensorTypeForOperand(op.getOutput()); if (!output_ty) return success(); int64_t expected_output_rank = std::max(x_ty.getRank(), y_ty.getRank()); @@ -306,8 +327,8 @@ static LogicalResult Verify(OpT op) { // Check output batch dim with potential broadcasting. ArrayRef output_shape = output_ty.getShape(); for (int i = 0; i < result_batch_shape.size(); ++i) { - if (output_shape[i] != ShapedType::kDynamicSize && - result_batch_shape[i] != ShapedType::kDynamicSize && + if (output_shape[i] != ShapedType::kDynamic && + result_batch_shape[i] != ShapedType::kDynamic && output_shape[i] != result_batch_shape[i]) return op.emitOpError() << "has mismatching input batch dimension " @@ -324,17 +345,17 @@ static LogicalResult Verify(OpT op) { int64_t out_row_dim = output_shape[output_shape.size() - 2]; int64_t out_col_dim = output_shape[output_shape.size() - 1]; - int64_t expected_out_row_dim = op.adj_x() ? x_col_dim : x_row_dim; - int64_t expected_out_col_dim = op.adj_y() ? y_row_dim : y_col_dim; + int64_t expected_out_row_dim = op.getAdjX() ? x_col_dim : x_row_dim; + int64_t expected_out_col_dim = op.getAdjY() ? y_row_dim : y_col_dim; - if (expected_out_row_dim != ShapedType::kDynamicSize && - out_row_dim != ShapedType::kDynamicSize && + if (expected_out_row_dim != ShapedType::kDynamic && + out_row_dim != ShapedType::kDynamic && out_row_dim != expected_out_row_dim) return op.emitOpError() << "found invalid output dimension on row, expected " << expected_out_row_dim << " but got " << out_row_dim; - if (expected_out_col_dim != ShapedType::kDynamicSize && - out_col_dim != ShapedType::kDynamicSize && + if (expected_out_col_dim != ShapedType::kDynamic && + out_col_dim != ShapedType::kDynamic && out_col_dim != expected_out_col_dim) return op.emitOpError() << "found invalid output dimension on col, expected " @@ -345,13 +366,13 @@ static LogicalResult Verify(OpT op) { LogicalResult BatchMatMulOp::verify() { return Verify(*this); } LogicalResult BatchMatMulV2Op::verify() { return Verify(*this); } -void BatchMatMulOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { +void BatchMatMulOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { results.add(context); } -void BatchMatMulV2Op::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { +void BatchMatMulV2Op::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { results.add(context); } @@ -362,17 +383,17 @@ void BatchMatMulV2Op::getCanonicalizationPatterns(RewritePatternSet &results, LogicalResult BatchToSpaceOp::verify() { BatchToSpaceOp op = *this; // Op already has a constraint that block_size >= 2. 
- int64_t block_size = op.block_size(); + int64_t block_size = op.getBlockSize(); - llvm::SmallVector input_shape(4, ShapedType::kDynamicSize); - auto input_type = op.input().getType().cast(); + llvm::SmallVector input_shape(4, ShapedType::kDynamic); + auto input_type = op.getInput().getType().cast(); if (input_type.hasRank()) { if (input_type.getRank() != 4) return op.emitOpError() << "requires input to be a 4D tensor, but got " << input_type; int64_t input_batch = input_type.getDimSize(0); - if (input_batch != ShapedType::kDynamicSize && + if (input_batch != ShapedType::kDynamic && input_batch % (block_size * block_size) != 0) { return op.emitOpError() << "requires input batch (dimension 0) to be evenly divisible " @@ -384,7 +405,7 @@ LogicalResult BatchToSpaceOp::verify() { input_type.getShape().end()); } - auto crops_type = op.crops().getType().cast(); + auto crops_type = op.getCrops().getType().cast(); if (crops_type.hasRank()) { if (crops_type.getRank() != 2) return op.emitOpError() @@ -403,12 +424,12 @@ LogicalResult BatchToSpaceOp::verify() { // Crops are defined as [[crop_top, crop_bottom], [crop_left, crop_right]], // and flattened as [crop_top, crop_bottom, crop_left, crop_right] llvm::SmallVector crops_values; - if (matchPattern(op.crops(), m_Constant(&crops_attr))) { + if (matchPattern(op.getCrops(), m_Constant(&crops_attr))) { assert(crops_attr.getNumElements() == 4 && "tf.BatchToSpace crops must have 4 elements"); auto crops_range = crops_attr.getValues(); - for (const auto &crops_value : crops_range) { + for (const auto& crops_value : crops_range) { int64_t crops_value_int = crops_value.getSExtValue(); if (crops_value_int < 0) return op.emitOpError() @@ -419,15 +440,14 @@ LogicalResult BatchToSpaceOp::verify() { } } - auto output_type = op.output().getType().cast(); + auto output_type = op.getOutput().getType().cast(); if (output_type.hasRank()) { if (output_type.getRank() != 4) return op.emitOpError() << "requires output to be a 4D tensor, but got " << output_type; auto static_dims = [](int64_t dim_a, int64_t dim_b) { - return dim_a != ShapedType::kDynamicSize && - dim_b != ShapedType::kDynamicSize; + return dim_a != ShapedType::kDynamic && dim_b != ShapedType::kDynamic; }; auto output_shape = output_type.getShape(); @@ -499,8 +519,8 @@ LogicalResult BatchToSpaceOp::verify() { return success(); } -void BatchToSpaceOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { +void BatchToSpaceOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { results.add(context); } @@ -510,8 +530,8 @@ void BatchToSpaceOp::getCanonicalizationPatterns(RewritePatternSet &results, LogicalResult BatchToSpaceNDOp::verify() { BatchToSpaceNDOp op = *this; - auto block_shape_ty = op.block_shape().getType().cast(); - auto crops_ty = op.crops().getType().cast(); + auto block_shape_ty = op.getBlockShape().getType().cast(); + auto crops_ty = op.getCrops().getType().cast(); if (block_shape_ty.hasStaticShape() && crops_ty.hasStaticShape()) { const int block_rank = block_shape_ty.getShape().front(); @@ -537,38 +557,40 @@ LogicalResult BatchToSpaceNDOp::verify() { // LogicalResult BiasAddOp::verify() { BiasAddOp op = *this; - absl::string_view data_format(op.data_format().data(), - op.data_format().size()); + absl::string_view data_format(op.getDataFormat().data(), + op.getDataFormat().size()); tensorflow::TensorFormat format; bool is_valid = FormatFromString(data_format, &format); DCHECK(is_valid) << data_format; if (format == 
tensorflow::TensorFormat::FORMAT_NHWC) { - if (!HasRankAtLeast(op.value(), 2)) + if (!HasRankAtLeast(op.getValue(), 2)) return op.emitOpError( "requires value operand to have rank at least two with `NHWC` data " "format"); } else { // Op definition requires data_format to be either NHWC or NCHW. DCHECK_EQ(format, tensorflow::TensorFormat::FORMAT_NCHW); - if (!HasRankAtLeast(op.value(), 3)) + if (!HasRankAtLeast(op.getValue(), 3)) return op.emitOpError( "requires value operand to have rank at least three with `NCHW` data " "format"); } - if (!IsOfRankOrUnranked(op.bias(), 1)) + if (!IsOfRankOrUnranked(op.getBias(), 1)) return op.emitOpError("requires bias operand to have rank exactly one"); - RankedTensorType value_ty = op.value().getType().dyn_cast(); - RankedTensorType bias_ty = op.bias().getType().dyn_cast(); + RankedTensorType value_ty = + op.getValue().getType().dyn_cast(); + RankedTensorType bias_ty = + op.getBias().getType().dyn_cast(); if (!bias_ty || !value_ty) return success(); int64_t feature_dim_idx = tensorflow::GetTensorFeatureDimIndex(value_ty.getRank(), format); int64_t feature_dim = value_ty.getDimSize(feature_dim_idx); int64_t bias_len = bias_ty.getDimSize(0); - if (feature_dim != ShapedType::kDynamicSize && - bias_len != ShapedType::kDynamicSize && feature_dim != bias_len) { + if (feature_dim != ShapedType::kDynamic && bias_len != ShapedType::kDynamic && + feature_dim != bias_len) { return op.emitOpError() << "requires channel dimension and feature dimension to match; " "found " @@ -581,11 +603,11 @@ LogicalResult BiasAddOp::UpdateDataFormat(StringRef data_format) { return ::mlir::TF::UpdateDataFormat(data_format, this); } -StringRef BiasAddOp::GetOptimalLayout(const RuntimeDevices &devices) { +StringRef BiasAddOp::GetOptimalLayout(const RuntimeDevices& devices) { // Keep current data format if no GPUs are available or if explicit placement // does not allow to use GPU for this operation. if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation())) - return data_format(); + return getDataFormat(); // Prefer NHWC for GPU devices. return "NHWC"; @@ -600,20 +622,20 @@ StringRef BiasAddOp::GetOptimalLayout(const RuntimeDevices &devices) { // LogicalResult BiasAddGradOp::verify() { BiasAddGradOp op = *this; - absl::string_view data_format(op.data_format().data(), - op.data_format().size()); + absl::string_view data_format(op.getDataFormat().data(), + op.getDataFormat().size()); tensorflow::TensorFormat format; bool is_valid = FormatFromString(data_format, &format); DCHECK(is_valid) << data_format; if (format == tensorflow::TensorFormat::FORMAT_NHWC) { - if (!HasRankAtLeast(op.out_backprop(), 2)) + if (!HasRankAtLeast(op.getOutBackprop(), 2)) return op.emitOpError( "requires out_backprop operand to have rank at least two with `NHWC` " "data format"); } else { // Op definition requires data_format to be either NHWC or NCHW. 
DCHECK_EQ(format, tensorflow::TensorFormat::FORMAT_NCHW); - if (!HasRankAtLeast(op.out_backprop(), 3)) + if (!HasRankAtLeast(op.getOutBackprop(), 3)) return op.emitOpError( "requires out_backprop operand to have rank at least three with " "`NCHW` data format"); @@ -626,8 +648,8 @@ LogicalResult BiasAddGradOp::verify() { // BiasAddV1Op //===----------------------------------------------------------------------===// -void BiasAddV1Op::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { +void BiasAddV1Op::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { results.add(context); } @@ -635,8 +657,8 @@ void BiasAddV1Op::getCanonicalizationPatterns(RewritePatternSet &results, // arith::BitcastOp //===----------------------------------------------------------------------===// -void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { +void BitcastOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { results.add(context); } @@ -652,8 +674,8 @@ LogicalResult BroadcastToOp::verify() { return success(); } -OpFoldResult BroadcastToOp::fold(ArrayRef operands) { - Value input = this->input(); +OpFoldResult BroadcastToOp::fold(FoldAdaptor) { + Value input = this->getInput(); // Fold broadcast if operand and result types are the same and all dimensions // are statically known (no-op broadcast). @@ -677,11 +699,11 @@ namespace { // Returns `true` if both s0 & s1 are defined via constant op, and fills // s0_shape & s1_shape. bool ExtractInputConstShape(BroadcastGradientArgsOp op, - DenseIntElementsAttr &s0, DenseIntElementsAttr &s1, - SmallVectorImpl &s0_shape, - SmallVectorImpl &s1_shape) { - if (!matchPattern(op.s0(), m_Constant(&s0))) return false; - if (!matchPattern(op.s1(), m_Constant(&s1))) return false; + DenseIntElementsAttr& s0, DenseIntElementsAttr& s1, + SmallVectorImpl& s0_shape, + SmallVectorImpl& s1_shape) { + if (!matchPattern(op.getS0(), m_Constant(&s0))) return false; + if (!matchPattern(op.getS1(), m_Constant(&s1))) return false; for (auto s : s0.getValues()) s0_shape.push_back(s.getSExtValue()); for (auto s : s1.getValues()) s1_shape.push_back(s.getSExtValue()); @@ -699,8 +721,8 @@ bool ExtractInputConstShape(BroadcastGradientArgsOp op, void GetOutputShapeForBroadcastGradientArgs(ArrayRef bcasted_shape, ArrayRef s0_shape, ArrayRef s1_shape, - SmallVectorImpl &r0, - SmallVectorImpl &r1) { + SmallVectorImpl& r0, + SmallVectorImpl& r1) { r0.clear(); r1.clear(); @@ -759,8 +781,8 @@ LogicalResult BroadcastGradientArgsOp::verify() { // Verify that output types are of rank one and matches the computed result // shape. 
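// Illustrative sketch (not part of this patch): the r0/r1 values computed by
// GetOutputShapeForBroadcastGradientArgs are the axes of the broadcasted shape
// along which s0 and s1 were broadcast, i.e. the axes a gradient has to be
// summed over. A simplified version under right-aligned NumPy-style
// broadcasting (the real helper also short-circuits the all-equal case):
#include <cstdint>
#include <vector>

std::vector<int64_t> ReductionAxes(const std::vector<int64_t>& s,
                                   const std::vector<int64_t>& bcast) {
  std::vector<int64_t> axes;
  const int64_t offset =
      static_cast<int64_t>(bcast.size()) - static_cast<int64_t>(s.size());
  for (int64_t i = 0; i < static_cast<int64_t>(bcast.size()); ++i) {
    // Axes that are missing from `s` or have size 1 were broadcast.
    if (i < offset || s[i - offset] == 1) {
      if (i < offset || bcast[i] != 1) axes.push_back(i);
    }
  }
  return axes;
}

int main() {
  // s0 = [2, 3, 1] and s1 = [3, 4] broadcast to [2, 3, 4]; the gradient w.r.t.
  // s0 is reduced over axis 2 and the gradient w.r.t. s1 over axis 0.
  const std::vector<int64_t> bcast = {2, 3, 4};
  const bool ok = ReductionAxes({2, 3, 1}, bcast) == std::vector<int64_t>{2} &&
                  ReductionAxes({3, 4}, bcast) == std::vector<int64_t>{0};
  return ok ? 0 : 1;
}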
- auto r0_ty = op.r0().getType().dyn_cast(); - auto r1_ty = op.r1().getType().dyn_cast(); + auto r0_ty = op.getR0().getType().dyn_cast(); + auto r1_ty = op.getR1().getType().dyn_cast(); if (r0_ty && r0_ty.hasStaticShape() && r0_ty.getDimSize(0) != r0.size()) return op.emitOpError() << "requires dimension 0 size of 'r0' to be " << r0.size() << " but got " << r0_ty.getShape()[0]; @@ -772,7 +794,7 @@ LogicalResult BroadcastGradientArgsOp::verify() { } LogicalResult BroadcastGradientArgsOp::fold( - ArrayRef operands, SmallVectorImpl &results) { + FoldAdaptor, SmallVectorImpl& results) { SmallVector s0_shape, s1_shape; DenseIntElementsAttr s0, s1; if (!ExtractInputConstShape(*this, s0, s1, s0_shape, s1_shape)) @@ -791,7 +813,7 @@ LogicalResult BroadcastGradientArgsOp::fold( GetOutputShapeForBroadcastGradientArgs(bcasted_shape, s0_shape, s1_shape, r0, r1); - auto build_out_dense_element = [](SmallVectorImpl &shape, + auto build_out_dense_element = [](SmallVectorImpl& shape, Type input_type) { Type element_type = input_type.cast().getElementType(); RankedTensorType type = tensorflow::GetTypeFromTFTensorShape( @@ -807,8 +829,8 @@ LogicalResult BroadcastGradientArgsOp::fold( } }; - results.push_back(build_out_dense_element(r0, this->s0().getType())); - results.push_back(build_out_dense_element(r1, this->s1().getType())); + results.push_back(build_out_dense_element(r0, this->getS0().getType())); + results.push_back(build_out_dense_element(r1, this->getS1().getType())); return success(); } @@ -819,22 +841,22 @@ LogicalResult BroadcastGradientArgsOp::fold( class FoldConstantCaseOp : public OpRewritePattern { public: - explicit FoldConstantCaseOp(MLIRContext *context) + explicit FoldConstantCaseOp(MLIRContext* context) : OpRewritePattern(context) {} LogicalResult matchAndRewrite(TF::CaseOp op, - PatternRewriter &rewriter) const override; + PatternRewriter& rewriter) const override; }; LogicalResult FoldConstantCaseOp::matchAndRewrite( - TF::CaseOp op, PatternRewriter &rewriter) const { + TF::CaseOp op, PatternRewriter& rewriter) const { // Extract the constant cond value. 
DenseIntElementsAttr branch; - if (!matchPattern(op.branch_index(), m_Constant(&branch))) return failure(); + if (!matchPattern(op.getBranchIndex(), m_Constant(&branch))) return failure(); int index = *branch.getValues().begin(); if (index < 0 || index >= op.num_branches()) index = op.num_branches() - 1; - auto func = op.branches()[index].cast(); + auto func = op.getBranches()[index].cast(); auto empty = rewriter.getStringAttr(""); ReplaceTfOpWithNewOp( rewriter, op, op.getResultTypes(), op.getOperands().drop_front(), func, @@ -842,12 +864,12 @@ LogicalResult FoldConstantCaseOp::matchAndRewrite( return success(); } -void CaseOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { +void CaseOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { results.add>(context); } -static LogicalResult VerifyCaseOpBase(Operation *op, Value branch_index) { +static LogicalResult VerifyCaseOpBase(Operation* op, Value branch_index) { if (!IsOfRankOrUnranked(branch_index, 0)) return op->emitOpError() << "expects 'branch_index' to be a scalar, but got " @@ -856,7 +878,7 @@ static LogicalResult VerifyCaseOpBase(Operation *op, Value branch_index) { } static LogicalResult VerifyCaseOrIfOpBranchFunctions( - SymbolTableCollection &symbol_table, Operation *op, + SymbolTableCollection& symbol_table, Operation* op, ArrayRef branches, llvm::function_ref branch_name) { SmallVector branch_types; @@ -901,7 +923,7 @@ static LogicalResult VerifyCaseOrIfOpBranchFunctions( branch_input_i_types.reserve(branches.size()); llvm::transform( branch_types, std::back_inserter(branch_input_i_types), - [i](FunctionType &branch_type) { return branch_type.getInput(i); }); + [i](FunctionType& branch_type) { return branch_type.getInput(i); }); if (!AreCastCompatible(branch_input_i_types)) { std::string input_types_str; llvm::raw_string_ostream os(input_types_str); @@ -917,15 +939,15 @@ static LogicalResult VerifyCaseOrIfOpBranchFunctions( LogicalResult CaseOp::verify() { CaseOp op = *this; - return VerifyCaseOpBase(op, op.branch_index()); + return VerifyCaseOpBase(op, op.getBranchIndex()); } -LogicalResult CaseOp::verifySymbolUses(SymbolTableCollection &symbol_table) { +LogicalResult CaseOp::verifySymbolUses(SymbolTableCollection& symbol_table) { auto branch_name = [](unsigned index) { return llvm::formatv("branch #{0}", index).str(); }; return VerifyCaseOrIfOpBranchFunctions(symbol_table, *this, - branches().getValue(), branch_name); + getBranches().getValue(), branch_name); } //===----------------------------------------------------------------------===// @@ -934,17 +956,17 @@ LogicalResult CaseOp::verifySymbolUses(SymbolTableCollection &symbol_table) { LogicalResult CaseRegionOp::verify() { CaseRegionOp op = *this; - if (op.branches().empty()) + if (op.getBranches().empty()) return op.emitOpError() << "expects to have at least 1 region"; - if (failed(VerifyCaseOpBase(op, op.branch_index()))) return failure(); + if (failed(VerifyCaseOpBase(op, op.getBranchIndex()))) return failure(); TypeRangeWithDesc results{op.getResultTypes(), "result"}; - for (auto region_and_idx : llvm::enumerate(op.branches())) { + for (auto region_and_idx : llvm::enumerate(op.getBranches())) { std::string description = llvm::formatv("branch #{0} result", region_and_idx.index()).str(); - Operation *yield = region_and_idx.value().front().getTerminator(); + Operation* yield = region_and_idx.value().front().getTerminator(); TypeRangeWithDesc branch_results{yield->getOperandTypes(), description}; if 
(failed(VerifyTypeRangesAreCompatible(op, branch_results, results))) return failure(); @@ -961,7 +983,7 @@ class CaseOrIfRegionEliminatePassThrough using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(CaseOrIfRegionOp op, - PatternRewriter &rewriter) const override { + PatternRewriter& rewriter) const override { RegionRange branches = op.getRegions(); SmallVector new_result_types; // Maps pass through results to extern values. @@ -969,8 +991,8 @@ class CaseOrIfRegionEliminatePassThrough for (auto result : op.getResults()) { unsigned index = result.getResultNumber(); - Region *first_branch = *branches.begin(); - Operation *first_terminator = first_branch->front().getTerminator(); + Region* first_branch = *branches.begin(); + Operation* first_terminator = first_branch->front().getTerminator(); Value returned_val = first_terminator->getOperand(index); // Pass through values would be defined outside the branch region. Keep @@ -981,8 +1003,8 @@ class CaseOrIfRegionEliminatePassThrough continue; } // Check if the same extern value is returned in each branch. - for (Region *region : branches.drop_front()) { - Operation *terminator = region->front().getTerminator(); + for (Region* region : branches.drop_front()) { + Operation* terminator = region->front().getTerminator(); if (terminator->getOperand(index) != returned_val) return failure(); } result_to_extern_value[result] = returned_val; @@ -1003,7 +1025,7 @@ class CaseOrIfRegionEliminatePassThrough continue; } result.replaceAllUsesWith(result_to_extern_value[result]); - for (Region *branch : branches) + for (Region* branch : branches) branch->front().getTerminator()->eraseOperand(next_index); } @@ -1017,8 +1039,8 @@ class CaseOrIfRegionEliminatePassThrough }; } // namespace -void CaseRegionOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { +void CaseRegionOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { results.add>(context); } @@ -1026,23 +1048,47 @@ void CaseRegionOp::getCanonicalizationPatterns(RewritePatternSet &results, // CastOp //===----------------------------------------------------------------------===// -OpFoldResult CastOp::fold(ArrayRef operands) { +OpFoldResult CastOp::fold(FoldAdaptor) { // Cast with the same type is a no-op. Value operand = getOperand(); if (getType() == operand.getType()) return operand; return {}; } +//===----------------------------------------------------------------------===// +// CollectiveReduceV2Op +//===----------------------------------------------------------------------===// + +// For `CollectiveReduceV2Op` we have two cases: +// 1) If at least one ordering token is present, then we purely rely on ordering +// tokens for side effect modeling and ignore the op-based effect +// `TF_CollectiveReduceOrderingEffect` for which this function is relevant +// (note that returning `std::nullopt` here signals exactly that). +// 2) If no ordering token is present, then we treat the op conservatively which +// means that different op instances need dependencies. This is realized by +// always returning the same string ("") in this case. In fact, we could +// return any string here, as long as it is the same string for all op +// instances without ordering tokens. +std::optional CollectiveReduceV2Op::GetResourceInstanceStr() { + return getNorderingToken() == 0 ? std::optional("") + : std::nullopt; +} + +std::optional +CollectiveReduceScatterV2Op::GetResourceInstanceStr() { + return getNorderingToken() == 0 ? 
std::optional("") + : std::nullopt; +} + //===----------------------------------------------------------------------===// // ConcatOp and ConcatV2Op //===----------------------------------------------------------------------===// -template ::value>::type * = nullptr> +template ::value>::type* = nullptr> static LogicalResult Verify(OpT op) { // TODO(hinsu): Convert variadic length attributes to derived attributes. - Operation::operand_range values = op.values(); + Operation::operand_range values = op.getValues(); int axis_idx = std::is_same() ? 0 : 1; Value axis = *op.getODSOperands(axis_idx).begin(); @@ -1059,8 +1105,8 @@ static LogicalResult Verify(OpT op) { LogicalResult ConcatOp::verify() { return Verify(*this); } LogicalResult ConcatV2Op::verify() { return Verify(*this); } -void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { +void ConcatOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { results.add(context); } @@ -1080,46 +1126,46 @@ namespace { // %1 = "tf.Log1p"(%0) class HoistCwiseUnaryOutOfConcat : public OpRewritePattern { public: - explicit HoistCwiseUnaryOutOfConcat(MLIRContext *context) + explicit HoistCwiseUnaryOutOfConcat(MLIRContext* context) : OpRewritePattern(context) {} LogicalResult matchAndRewrite(TF::ConcatV2Op op, - PatternRewriter &rewriter) const override; + PatternRewriter& rewriter) const override; }; LogicalResult HoistCwiseUnaryOutOfConcat::matchAndRewrite( - TF::ConcatV2Op op, PatternRewriter &rewriter) const { + TF::ConcatV2Op op, PatternRewriter& rewriter) const { auto loc = op.getLoc(); // All concat operands must be defined by ops. - Operation *first_arg_op = op.values().front().getDefiningOp(); + Operation* first_arg_op = op.getValues().front().getDefiningOp(); if (first_arg_op == nullptr) return failure(); // All concat operands must be produced by the coeff-wise unary operation. if (!first_arg_op->hasTrait()) return failure(); // All concat operands must be defined by the op of same kind. - bool args_same_op = llvm::all_of(op.values(), [&](Value arg) -> bool { - Operation *arg_op = arg.getDefiningOp(); + bool args_same_op = llvm::all_of(op.getValues(), [&](Value arg) -> bool { + Operation* arg_op = arg.getDefiningOp(); return arg_op && arg_op->getName() == first_arg_op->getName(); }); if (!args_same_op) return failure(); // Collect unary operations operands. - auto unary_operands = llvm::map_range(op.values(), [](Value arg) -> Value { + auto unary_operands = llvm::map_range(op.getValues(), [](Value arg) -> Value { return arg.getDefiningOp()->getOperand(0); }); SmallVector unary_ops_args(unary_operands); // Concatenate unary ops operands. - auto concat_unary_operands = - rewriter.create(loc, op.getType(), unary_ops_args, op.axis()); + auto concat_unary_operands = rewriter.create( + loc, op.getType(), unary_ops_args, op.getAxis()); // Replace original concat with an unary op. OperationState new_unary_op_state(loc, first_arg_op->getName().getStringRef(), concat_unary_operands.getResult(), op.getResult().getType(), ArrayRef()); - Operation *new_unary_op = rewriter.create(new_unary_op_state); + Operation* new_unary_op = rewriter.create(new_unary_op_state); rewriter.replaceOp(op, new_unary_op->getResults()); @@ -1153,10 +1199,10 @@ LogicalResult HoistCwiseUnaryOutOfConcat::matchAndRewrite( // produce incorrect concat operations. 
class HoistCwiseBinaryOutOfConcat : public OpRewritePattern { public: - explicit HoistCwiseBinaryOutOfConcat(MLIRContext *context) + explicit HoistCwiseBinaryOutOfConcat(MLIRContext* context) : OpRewritePattern(context) {} LogicalResult matchAndRewrite(TF::ConcatV2Op op, - PatternRewriter &rewriter) const override; + PatternRewriter& rewriter) const override; private: struct HoistParams { @@ -1174,18 +1220,18 @@ class HoistCwiseBinaryOutOfConcat : public OpRewritePattern { // All inputs of `op` should be of the same binary op kind (e.g. tf.Mul), // except from the ones in `exceptions`. In that case, we can synthesize that // binary op kind for the values in `exceptions`. - Optional GetHoistParams( + std::optional GetHoistParams( TF::ConcatV2Op op, int64_t axis, - const llvm::SmallDenseMap &exceptions) const; + const llvm::SmallDenseMap& exceptions) const; }; LogicalResult HoistCwiseBinaryOutOfConcat::matchAndRewrite( - TF::ConcatV2Op op, PatternRewriter &rewriter) const { + TF::ConcatV2Op op, PatternRewriter& rewriter) const { auto loc = op.getLoc(); // Axis must be a constant scalar value. DenseIntElementsAttr axis_attr; - if (!matchPattern(op.axis(), m_Constant(&axis_attr))) return failure(); + if (!matchPattern(op.getAxis(), m_Constant(&axis_attr))) return failure(); if (axis_attr.getNumElements() != 1) return failure(); int64_t axis = axis_attr.getSplatValue().getValue().getSExtValue(); @@ -1198,7 +1244,7 @@ LogicalResult HoistCwiseBinaryOutOfConcat::matchAndRewrite( // (e.g. converting op A to tf.Mul(A, 1.0)) // TODO(hongm): generalize the code here to support cases where the first arg // has no defining op (e.g. might be a block arg). - Operation *first_arg_op = op.values().front().getDefiningOp(); + Operation* first_arg_op = op.getValues().front().getDefiningOp(); if (first_arg_op == nullptr) return failure(); // All concat operands must be produced by the coeff-wise binary operation. @@ -1209,8 +1255,8 @@ LogicalResult HoistCwiseBinaryOutOfConcat::matchAndRewrite( // Map from the operands to operand indices. llvm::SmallDenseMap exceptions; unsigned operand_idx = 0; - for (Value arg : op.values()) { - Operation *arg_op = arg.getDefiningOp(); + for (Value arg : op.getValues()) { + Operation* arg_op = arg.getDefiningOp(); if (arg_op && arg_op->getName() == first_arg_op->getName()) { ++operand_idx; continue; @@ -1225,7 +1271,7 @@ LogicalResult HoistCwiseBinaryOutOfConcat::matchAndRewrite( // out of 2 concat inputs is an exception, we don't apply the hoist. If it's 1 // out of 3, we do. const float exception_pct_threshold = 0.5; - if (static_cast(op.values().size()) * exception_pct_threshold <= + if (static_cast(op.getValues().size()) * exception_pct_threshold <= exceptions.size()) return failure(); @@ -1258,7 +1304,7 @@ LogicalResult HoistCwiseBinaryOutOfConcat::matchAndRewrite( // All checks are passes, and we now prepare for rewrite. 
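// Illustrative sketch (not part of this patch): the exception_pct_threshold
// guard earlier in this rewrite is a simple ratio test -- hoisting only
// proceeds when fewer than half of the concat operands need a synthesized
// identity operand.
#include <cstddef>

bool ShouldHoist(std::size_t num_operands, std::size_t num_exceptions) {
  const float exception_pct_threshold = 0.5f;
  return static_cast<float>(num_operands) * exception_pct_threshold >
         static_cast<float>(num_exceptions);
}

int main() {
  const bool one_of_two = ShouldHoist(2, 1);    // false: 1 exception out of 2
  const bool one_of_three = ShouldHoist(3, 1);  // true: 1 exception out of 3
  return !one_of_two && one_of_three ? 0 : 1;
}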
auto identity_const = rewriter.create(loc, const_attr); - for (const auto &kv : exceptions) { + for (const auto& kv : exceptions) { assert(!hoist_params->lhs_args[kv.second]); assert(!hoist_params->rhs_args[kv.second]); @@ -1313,21 +1359,21 @@ LogicalResult HoistCwiseBinaryOutOfConcat::matchAndRewrite( OperationState new_binary_op_state( loc, first_arg_op->getName().getStringRef(), {lhs_concat, rhs_concat}, op.getResult().getType(), ArrayRef()); - Operation *new_binary_op = rewriter.create(new_binary_op_state); + Operation* new_binary_op = rewriter.create(new_binary_op_state); rewriter.replaceOp(op, new_binary_op->getResults()); return success(); } -Optional +std::optional HoistCwiseBinaryOutOfConcat::GetHoistParams( TF::ConcatV2Op op, int64_t axis, - const llvm::SmallDenseMap &exceptions) const { + const llvm::SmallDenseMap& exceptions) const { assert(axis >= 0); // Collects lhs or rhs arguments of concat op operands. auto args = [&](int operand_idx) -> SmallVector { - auto range = llvm::map_range(op.values(), [&](Value arg) { + auto range = llvm::map_range(op.getValues(), [&](Value arg) { if (exceptions.count(arg)) return Value(); return arg.getDefiningOp()->getOperand(operand_idx); }); @@ -1337,7 +1383,7 @@ HoistCwiseBinaryOutOfConcat::GetHoistParams( // Returns true if all binary ops operands at `operand_idx` index are tensors // of `axis + 1` rank and axis dim has size `1`. auto is_all_tensors = [&](int operand_idx, int axis) -> bool { - return llvm::all_of(op.values(), [&](Value arg) -> bool { + return llvm::all_of(op.getValues(), [&](Value arg) -> bool { mlir::Value operand; if (exceptions.count(arg)) { // For exceptions, since we are going to synthesize a binary op that @@ -1355,7 +1401,7 @@ HoistCwiseBinaryOutOfConcat::GetHoistParams( // Returns true if all binary ops operands at `operand_idx` index are scalars. auto is_all_scalars = [&](int operand_idx) -> bool { - return llvm::all_of(op.values(), [&](Value arg) -> bool { + return llvm::all_of(op.getValues(), [&](Value arg) -> bool { if (exceptions.count(arg)) return true; auto operand = arg.getDefiningOp()->getOperand(operand_idx); auto ranked = operand.getType().dyn_cast(); @@ -1365,7 +1411,7 @@ HoistCwiseBinaryOutOfConcat::GetHoistParams( // Concat result type must be a ranked tensor. auto ranked = op.getType().dyn_cast(); - if (!ranked) return None; + if (!ranked) return std::nullopt; // TODO(ezhulenev): Add support for more valid concat patterns. @@ -1376,7 +1422,8 @@ HoistCwiseBinaryOutOfConcat::GetHoistParams( // Concatenate tensor arguments on the same axis as the original operation, // and concatenate scalars into the vector. 
if (is_all_tensors(0, axis) && is_all_scalars(1)) { - std::array rhs_dims{static_cast(op.values().size())}; + std::array rhs_dims{ + static_cast(op.getValues().size())}; auto rhs_type = tensorflow::GetTypeFromTFTensorShape(rhs_dims, ranked.getElementType()); return HoistParams{args(0), @@ -1387,7 +1434,8 @@ HoistCwiseBinaryOutOfConcat::GetHoistParams( rhs_type, /*scalar_operand_idx=*/1}; } else if (is_all_tensors(1, axis) && is_all_scalars(0)) { - std::array lhs_dims{static_cast(op.values().size())}; + std::array lhs_dims{ + static_cast(op.getValues().size())}; auto lhs_type = tensorflow::GetTypeFromTFTensorShape(lhs_dims, ranked.getElementType()); return HoistParams{args(0), @@ -1398,13 +1446,13 @@ HoistCwiseBinaryOutOfConcat::GetHoistParams( op.getType(), /*scalar_operand_idx=*/0}; } - return None; + return std::nullopt; } } // namespace -void ConcatV2Op::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { +void ConcatV2Op::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { results.add(context); } @@ -1414,15 +1462,15 @@ void ConcatV2Op::getCanonicalizationPatterns(RewritePatternSet &results, template ::value>::type * = + OpT, CumsumOp, CumulativeLogsumexpOp, CumprodOp>::value>::type* = nullptr> static LogicalResult Verify(OpT op) { - if (!IsOfRankOrUnranked(op.axis(), 0)) + if (!IsOfRankOrUnranked(op.getAxis(), 0)) return op.emitOpError("requires scalar axis operand"); DenseIntElementsAttr axis_attr; - if (matchPattern(op.axis(), m_Constant(&axis_attr))) { - auto input_ty = op.x().getType().template dyn_cast(); + if (matchPattern(op.getAxis(), m_Constant(&axis_attr))) { + auto input_ty = op.getX().getType().template dyn_cast(); if (input_ty) { int64_t rank = input_ty.getRank(); assert(axis_attr.getNumElements() == 1 && @@ -1448,15 +1496,15 @@ LogicalResult CumulativeLogsumexpOp::verify() { return Verify(*this); } LogicalResult ConcatOffsetOp::verify() { ConcatOffsetOp op = *this; - if (op.N() < 2) - return op.emitOpError() << "requires N to be at least 2, got " << op.N(); + if (op.getN() < 2) + return op.emitOpError() << "requires N to be at least 2, got " << op.getN(); - if (op.shape().size() != op.offset().size()) + if (op.getShape().size() != op.getOffset().size()) return op.emitOpError() << "requires sizes of shapes and offsets to be the same, got sizes " - << op.shape().size() << " and " << op.offset().size(); + << op.getShape().size() << " and " << op.getOffset().size(); - auto ranked_dim = op.concat_dim().getType().dyn_cast(); + auto ranked_dim = op.getConcatDim().getType().dyn_cast(); if (ranked_dim && ranked_dim.getRank() != 0) return op.emitOpError() << "requires concat_dim to be a scalar, got tensor of rank " @@ -1464,7 +1512,7 @@ LogicalResult ConcatOffsetOp::verify() { int64_t num_dims = -1; for (auto shape_offset_idx : - llvm::enumerate(llvm::zip(op.shape(), op.offset()))) { + llvm::enumerate(llvm::zip(op.getShape(), op.getOffset()))) { Value shape = std::get<0>(shape_offset_idx.value()); Value offset = std::get<1>(shape_offset_idx.value()); const size_t idx = shape_offset_idx.index(); @@ -1496,8 +1544,9 @@ LogicalResult ConcatOffsetOp::verify() { return success(); } -LogicalResult ConcatOffsetOp::fold(ArrayRef operands, - SmallVectorImpl &results) { +LogicalResult ConcatOffsetOp::fold(FoldAdaptor adaptor, + SmallVectorImpl& results) { + auto operands = adaptor.getOperands(); // ConcatOffset must have its first operand be concat_dim and at least two // shape tensors in variadic shapes operand. 
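// Illustrative sketch (not part of this patch): the ConcatOffset fold that
// continues below materializes, for each input, an offset vector that is zero
// everywhere except at concat_dim, where it holds the running sum of the
// preceding inputs' sizes along that dimension. With fully static shapes:
#include <cstdint>
#include <vector>

std::vector<std::vector<int64_t>> ConcatOffsets(
    const std::vector<std::vector<int64_t>>& shapes, int64_t concat_dim) {
  std::vector<std::vector<int64_t>> offsets;
  int64_t running = 0;
  for (const auto& shape : shapes) {
    std::vector<int64_t> offset(shape.size(), 0);
    offset[concat_dim] = running;  // exclusive prefix sum on concat_dim
    running += shape[concat_dim];
    offsets.push_back(offset);
  }
  return offsets;
}

int main() {
  // Shapes [2,3], [2,5], [2,1] concatenated on dim 1 give offsets
  // [0,0], [0,3], [0,8].
  const auto offsets =
      ConcatOffsets({{2, 3}, {2, 5}, {2, 1}}, /*concat_dim=*/1);
  return offsets[2][1] == 8 ? 0 : 1;
}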
if (operands.size() < 3) return failure(); @@ -1564,11 +1613,11 @@ void ConstOp::getAsmResultNames( setNameFn(getResult(), "cst"); } -OpFoldResult ConstOp::fold(ArrayRef operands) { - assert(operands.empty() && "constant has no operands"); +OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { + assert(adaptor.getOperands().empty() && "constant has no operands"); // Return the held attribute value. - return value(); + return getValue(); } // Builds a constant op with the specified attribute `value`. The result @@ -1576,7 +1625,7 @@ OpFoldResult ConstOp::fold(ArrayRef operands) { // wraps it up with a tensor type of empty shape. // TODO(jpienaar): This one differs from the autogenerated one as it takes an // attribute but always creates an ElementsAttr internally. -void ConstOp::build(OpBuilder &builder, OperationState &result, +void ConstOp::build(OpBuilder& builder, OperationState& result, Attribute value) { ShapedType type; if (auto elem_attr = value.dyn_cast()) { @@ -1595,7 +1644,7 @@ void ConstOp::build(OpBuilder &builder, OperationState &result, llvm_unreachable("unsupported attribute type for building tf.Const"); } -void ConstOp::build(OpBuilder &builder, OperationState &result, Type type, +void ConstOp::build(OpBuilder& builder, OperationState& result, Type type, Attribute value) { // Handle the case where the type and value are already tensors. if (type.isa() && value.isa()) { @@ -1610,9 +1659,9 @@ void ConstOp::build(OpBuilder &builder, OperationState &result, Type type, } LogicalResult ConstOp::inferReturnTypes( - MLIRContext *context, Optional location, ValueRange operands, + MLIRContext* context, std::optional location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { + SmallVectorImpl& inferredReturnTypes) { auto value = attributes.get("value"); if (!value) return emitOptionalError(location, "missing attribute 'value'"); if (auto elem_attr = value.dyn_cast()) { @@ -1630,7 +1679,7 @@ LogicalResult ConstOp::inferReturnTypes( static LogicalResult VerifyConvOpAttributes( int num_dims, ArrayRef strides, ArrayRef dilations, - llvm::Optional location) { + std::optional location) { int64_t strides_size = strides.size(); if (strides_size != num_dims) return emitOptionalError( @@ -1655,19 +1704,19 @@ static LogicalResult VerifyConvOpAttributes( // * Number of input channels is divisible by the number of filter input // channels template ::value>::type * = nullptr> + OpT, Conv2DOp, Conv3DOp>::value>::type* = nullptr> static LogicalResult Verify(OpT op) { int num_spatial_dims = std::is_same() ? 
2 : 3; int num_dims = 2 + num_spatial_dims; - StringRef data_format = op.data_format(); + StringRef data_format = op.getDataFormat(); tensorflow::TensorFormat format; auto data_format_is_valid = FormatFromString(data_format.str(), &format); if (!data_format_is_valid) { return emitOptionalError(op.getLoc(), "Invalid data format provided"); } - const StringRef paddings = op.padding(); + const StringRef paddings = op.getPadding(); tensorflow::Padding padding; auto padding_is_valid = GetPaddingFromString(paddings.str(), &padding); if (!padding_is_valid.ok()) { @@ -1679,8 +1728,8 @@ static LogicalResult Verify(OpT op) { // * Length of explicit_paddings attribute is valid and has non negative // elements // * strides and dilations attributes have positive elements - if (!IsOfRankOrUnranked(op.input(), num_dims) || - !IsOfRankOrUnranked(op.filter(), num_dims)) + if (!IsOfRankOrUnranked(op.getInput(), num_dims) || + !IsOfRankOrUnranked(op.getFilter(), num_dims)) return emitOptionalError(op.getLoc(), "requires operands to be ", num_dims, "D tensor"); @@ -1712,17 +1761,17 @@ static LogicalResult Verify(OpT op) { "requires non negative explicit paddings"); } - ArrayRef strides = op.strides().getValue(); - ArrayRef dilations = op.dilations().getValue(); + ArrayRef strides = op.getStrides().getValue(); + ArrayRef dilations = op.getDilations().getValue(); if (failed( VerifyConvOpAttributes(num_dims, strides, dilations, op.getLoc()))) { return failure(); } - int64_t input_channels = ShapedType::kDynamicSize; - if (auto ty = op.input().getType().template dyn_cast()) { - absl::string_view data_format(op.data_format().data(), - op.data_format().size()); + int64_t input_channels = ShapedType::kDynamic; + if (auto ty = op.getInput().getType().template dyn_cast()) { + absl::string_view data_format(op.getDataFormat().data(), + op.getDataFormat().size()); tensorflow::TensorFormat format; auto is_valid = FormatFromString(data_format, &format); DCHECK(is_valid) << data_format; @@ -1730,8 +1779,9 @@ static LogicalResult Verify(OpT op) { input_channels = ty.getDimSize(idx); } - int64_t filter_channels = ShapedType::kDynamicSize; - if (auto ty = op.filter().getType().template dyn_cast()) { + int64_t filter_channels = ShapedType::kDynamic; + if (auto ty = + op.getFilter().getType().template dyn_cast()) { int idx = tensorflow::GetFilterTensorInputChannelsDimIndex( num_dims, tensorflow::FORMAT_HWIO); filter_channels = ty.getDimSize(idx); @@ -1756,17 +1806,17 @@ LogicalResult Conv2DOp::verify() { return Verify(*this); } LogicalResult Conv3DOp::verify() { return Verify(*this); } LogicalResult Conv2DOp::UpdateDataFormat(StringRef data_format) { - auto perm = GetDataFormatPermutation(this->data_format(), data_format); + auto perm = GetDataFormatPermutation(this->getDataFormat(), data_format); if (perm.empty()) return failure(); // Update data_format attribute and result types. if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure(); // Update convolution attributes. 
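// Illustrative sketch (not part of this patch, and assuming the convention
// result[i] = values[perm[i]]): the ShuffleArrayAttr calls below reorder the
// per-dimension attributes with the same permutation that maps the old data
// format to the new one, so NHWC strides [1, sh, sw, 1] become NCHW strides
// [1, 1, sh, sw] under the permutation {0, 3, 1, 2}.
#include <array>
#include <cstdint>

std::array<int64_t, 4> Shuffle(const std::array<int64_t, 4>& values,
                               const std::array<int, 4>& perm) {
  std::array<int64_t, 4> out{};
  for (int i = 0; i < 4; ++i) out[i] = values[perm[i]];
  return out;
}

int main() {
  const std::array<int, 4> nhwc_to_nchw = {0, 3, 1, 2};
  // NHWC strides [1, 2, 2, 1] -> NCHW strides [1, 1, 2, 2].
  const auto nchw = Shuffle({1, 2, 2, 1}, nhwc_to_nchw);
  return nchw == std::array<int64_t, 4>{1, 1, 2, 2} ? 0 : 1;
}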
- (*this)->setAttr("dilations", ShuffleArrayAttr(dilations(), perm)); - (*this)->setAttr("strides", ShuffleArrayAttr(strides(), perm)); + (*this)->setAttr("dilations", ShuffleArrayAttr(getDilations(), perm)); + (*this)->setAttr("strides", ShuffleArrayAttr(getStrides(), perm)); (*this)->setAttr("explicit_paddings", - ShuffleArrayAttr(explicit_paddings(), perm, 2)); + ShuffleArrayAttr(getExplicitPaddings(), perm, 2)); return success(); } @@ -1774,21 +1824,21 @@ LogicalResult Conv2DOp::UpdateDataFormat(StringRef data_format) { // Verifies the inferred return type of the given operation. template ::value>::type * = nullptr> + OpT, Conv2DOpAdaptor, Conv3DOpAdaptor>::value>::type* = nullptr> static LogicalResult inferConvReturnTypeComponents( - llvm::Optional location, OpT op, + std::optional location, OpT op, ArrayRef explicit_padding, - llvm::SmallVectorImpl &inferredReturnShapes) { + llvm::SmallVectorImpl& inferredReturnShapes) { const int64_t num_spatial_dims = std::is_same() ? 2 : 3; const int64_t num_dims = 2 + num_spatial_dims; - const Value input = op.input(); - const Value filter = op.filter(); + const Value input = op.getInput(); + const Value filter = op.getFilter(); const TensorType input_ty = input.getType().template cast(); const TensorType filter_ty = filter.getType().template cast(); - ArrayRef strides = op.strides().getValue(); - StringRef data_format = op.data_format(); - ArrayRef dilations = op.dilations().getValue(); + ArrayRef strides = op.getStrides().getValue(); + StringRef data_format = op.getDataFormat(); + ArrayRef dilations = op.getDilations().getValue(); tensorflow::TensorFormat format; auto data_format_is_valid = FormatFromString(data_format.str(), &format); @@ -1796,7 +1846,7 @@ static LogicalResult inferConvReturnTypeComponents( (void)data_format_is_valid; tensorflow::Padding padding; - const StringRef paddings = op.padding(); + const StringRef paddings = op.getPadding(); auto padding_is_valid = GetPaddingFromString(paddings.str(), &padding); assert(padding_is_valid.ok()); (void)padding_is_valid; @@ -1807,7 +1857,7 @@ static LogicalResult inferConvReturnTypeComponents( // Output always have `num_dims` rank. All dimensions are initialized to // dynamic size and can be partially inferred. - SmallVector return_shape(num_dims, ShapedType::kDynamicSize); + SmallVector return_shape(num_dims, ShapedType::kDynamic); // Output batch and channel dimension can be obtained using utilities from // tensorflow/core/util/tensor_format.h. if (input_ty.hasRank()) { @@ -1852,9 +1902,9 @@ static LogicalResult inferConvReturnTypeComponents( } LogicalResult Conv2DOp::inferReturnTypeComponents( - MLIRContext *context, Optional location, ValueShapeRange operands, - DictionaryAttr attributes, RegionRange regions, - SmallVectorImpl &inferredReturnShapes) { + MLIRContext* context, std::optional location, + ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl& inferredReturnShapes) { Conv2DOpAdaptor op(operands.getValues(), attributes); ArrayRef explicit_padding; ArrayAttr explicit_pad = @@ -1868,15 +1918,15 @@ LogicalResult Conv2DOp::inferReturnTypeComponents( inferredReturnShapes); } -StringRef Conv2DOp::GetOptimalLayout(const RuntimeDevices &devices) { +StringRef Conv2DOp::GetOptimalLayout(const RuntimeDevices& devices) { // Keep current data format if no GPUs are available or if explicit placement // does not allow to use GPU for this operation. 
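// Illustrative sketch (not part of this patch): the spatial entries of the
// inferred convolution return shape follow TensorFlow's usual window
// arithmetic, i.e. out = ceil(in / stride) for SAME padding and
// out = ceil((in - (kernel - 1) * dilation) / stride) for VALID padding.
#include <cstdint>

int64_t ConvOutputSize(int64_t in, int64_t kernel, int64_t stride,
                       int64_t dilation, bool same_padding) {
  const int64_t effective_kernel = (kernel - 1) * dilation + 1;
  const int64_t numerator = same_padding ? in : in - effective_kernel + 1;
  return (numerator + stride - 1) / stride;  // ceiling division
}

int main() {
  // A 28-wide dimension, 3-wide kernel, stride 2: SAME -> 14, VALID -> 13.
  const bool ok = ConvOutputSize(28, 3, 2, 1, /*same_padding=*/true) == 14 &&
                  ConvOutputSize(28, 3, 2, 1, /*same_padding=*/false) == 13;
  return ok ? 0 : 1;
}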
if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation())) - return data_format(); + return getDataFormat(); // Input must be a tensor. - auto input_ty = input().getType().dyn_cast(); - if (!input_ty) return data_format(); + auto input_ty = getInput().getType().dyn_cast(); + if (!input_ty) return getDataFormat(); // For f16 data type on devices with Tensor Cores support NHWC data format // is up to ~2x faster. @@ -1886,11 +1936,11 @@ StringRef Conv2DOp::GetOptimalLayout(const RuntimeDevices &devices) { // For f32/f16 data type decision depends on the filter size in spatial // dimensions, for other data types we keep current data format. if (!input_ty.getElementType().isF32() && !input_ty.getElementType().isF16()) - return data_format(); + return getDataFormat(); // Keep current data format if filter rank is unknown or not equal to 4. - auto filter_ty = filter().getType().dyn_cast(); - if (!filter_ty || filter_ty.getRank() != 4) return data_format(); + auto filter_ty = getFilter().getType().dyn_cast(); + if (!filter_ty || filter_ty.getRank() != 4) return getDataFormat(); const int64_t d0 = filter_ty.getDimSize(0); const int64_t d1 = filter_ty.getDimSize(1); @@ -1905,8 +1955,8 @@ StringRef Conv2DOp::GetOptimalLayout(const RuntimeDevices &devices) { // be computed as a GEMM in NHWC data format, and can be up to ~2x times // faster than convolution in NCHW. const bool one_by_one = d0 == 1 && d1 == 1; - const bool trivial_strides = all_ones(strides()); - const bool trivial_dilations = all_ones(dilations()); + const bool trivial_strides = all_ones(getStrides()); + const bool trivial_dilations = all_ones(getDilations()); // TODO(ezhulenev): This might lead to excessive transposes in the final IR, // if the ratio of 1x1 convolutions to regular convolutions is close to 1:1. @@ -1926,7 +1976,7 @@ StringRef Conv2DOp::GetOptimalLayout(const RuntimeDevices &devices) { //===----------------------------------------------------------------------===// LogicalResult Conv2DBackpropFilterOp::UpdateDataFormat(StringRef data_format) { - StringRef src_data_format = this->data_format(); + StringRef src_data_format = this->getDataFormat(); auto perm = GetDataFormatPermutation(src_data_format, data_format); if (perm.empty()) return failure(); @@ -1935,15 +1985,16 @@ LogicalResult Conv2DBackpropFilterOp::UpdateDataFormat(StringRef data_format) { if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure(); // Update convolution attributes. - (*this)->setAttr("dilations", ShuffleArrayAttr(dilations(), perm)); - (*this)->setAttr("strides", ShuffleArrayAttr(strides(), perm)); + (*this)->setAttr("dilations", ShuffleArrayAttr(getDilations(), perm)); + (*this)->setAttr("strides", ShuffleArrayAttr(getStrides(), perm)); (*this)->setAttr("explicit_paddings", - ShuffleArrayAttr(explicit_paddings(), perm, 2)); + ShuffleArrayAttr(getExplicitPaddings(), perm, 2)); // Permute filter sizes operand. 
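// Illustrative sketch (not part of this patch): the one_by_one &&
// trivial_strides && trivial_dilations case above relies on a 1x1 convolution
// over an NHWC tensor being exactly the GEMM [N*H*W, C_in] x [C_in, C_out].
#include <cstdint>
#include <vector>

std::vector<float> Conv1x1AsGemm(const std::vector<float>& input,   // N*H*W*Cin
                                 const std::vector<float>& filter,  // Cin*Cout
                                 int64_t pixels, int64_t c_in, int64_t c_out) {
  std::vector<float> output(pixels * c_out, 0.f);
  for (int64_t p = 0; p < pixels; ++p)
    for (int64_t o = 0; o < c_out; ++o)
      for (int64_t i = 0; i < c_in; ++i)
        output[p * c_out + o] += input[p * c_in + i] * filter[i * c_out + o];
  return output;
}

int main() {
  // One 2x2 image with 3 input channels mapped to 4 output channels.
  const std::vector<float> in(2 * 2 * 3, 1.f), filt(3 * 4, 0.5f);
  const auto out =
      Conv1x1AsGemm(in, filt, /*pixels=*/4, /*c_in=*/3, /*c_out=*/4);
  return out[0] == 1.5f ? 0 : 1;  // each output element is 3 * 1.0 * 0.5
}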
OpBuilder builder(getOperation()); auto filter_sizes_permuted = builder.create( - getLoc(), filter_sizes(), StringAttr::get(getContext(), src_data_format), + getLoc(), getFilterSizes(), + StringAttr::get(getContext(), src_data_format), StringAttr::get(getContext(), data_format)); setOperand(1, filter_sizes_permuted); @@ -1951,15 +2002,15 @@ LogicalResult Conv2DBackpropFilterOp::UpdateDataFormat(StringRef data_format) { } StringRef Conv2DBackpropFilterOp::GetOptimalLayout( - const RuntimeDevices &devices) { + const RuntimeDevices& devices) { // Keep current data format if no GPUs are available or if explicit placement // does not allow to use GPU for this operation. if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation())) - return data_format(); + return getDataFormat(); // Input must be a tensor. - auto input_ty = input().getType().dyn_cast(); - if (!input_ty) return data_format(); + auto input_ty = getInput().getType().dyn_cast(); + if (!input_ty) return getDataFormat(); // For f16 data type on devices with Tensor Cores support NHWC data format // is up to ~2x faster. @@ -1979,17 +2030,17 @@ LogicalResult Conv2DBackpropInputOp::verify() { int num_spatial_dims = 2; int num_dims = 2 + num_spatial_dims; - if (!IsOfRankOrUnranked(op.out_backprop(), num_dims) || - !IsOfRankOrUnranked(op.filter(), num_dims)) + if (!IsOfRankOrUnranked(op.getOutBackprop(), num_dims) || + !IsOfRankOrUnranked(op.getFilter(), num_dims)) return op.emitOpError() << "requires operands to be " << num_dims << "D tensor"; if (!IsOfRankOrUnranked(op.getResult(), num_dims)) return op.emitOpError() << "requires result to be " << num_dims << "D tensor"; - llvm::Optional location = op.getLoc(); - ArrayRef strides = op.strides().getValue(); - ArrayRef dilations = op.dilations().getValue(); + std::optional location = op.getLoc(); + ArrayRef strides = op.getStrides().getValue(); + ArrayRef dilations = op.getDilations().getValue(); LogicalResult verify_result = VerifyConvOpAttributes(num_dims, strides, dilations, location); if (failed(verify_result)) { @@ -2000,7 +2051,7 @@ LogicalResult Conv2DBackpropInputOp::verify() { } LogicalResult Conv2DBackpropInputOp::UpdateDataFormat(StringRef data_format) { - StringRef src_data_format = this->data_format(); + StringRef src_data_format = this->getDataFormat(); auto perm = GetDataFormatPermutation(src_data_format, data_format); if (perm.empty()) return failure(); @@ -2009,15 +2060,15 @@ LogicalResult Conv2DBackpropInputOp::UpdateDataFormat(StringRef data_format) { if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure(); // Update convolution attributes. - (*this)->setAttr("dilations", ShuffleArrayAttr(dilations(), perm)); - (*this)->setAttr("strides", ShuffleArrayAttr(strides(), perm)); + (*this)->setAttr("dilations", ShuffleArrayAttr(getDilations(), perm)); + (*this)->setAttr("strides", ShuffleArrayAttr(getStrides(), perm)); (*this)->setAttr("explicit_paddings", - ShuffleArrayAttr(explicit_paddings(), perm, 2)); + ShuffleArrayAttr(getExplicitPaddings(), perm, 2)); // Permute input sizes operand. 
OpBuilder builder(getOperation()); auto input_sizes_permuted = builder.create( - getLoc(), input_sizes(), StringAttr::get(getContext(), src_data_format), + getLoc(), getInputSizes(), StringAttr::get(getContext(), src_data_format), StringAttr::get(getContext(), data_format)); setOperand(0, input_sizes_permuted); @@ -2025,15 +2076,15 @@ LogicalResult Conv2DBackpropInputOp::UpdateDataFormat(StringRef data_format) { } StringRef Conv2DBackpropInputOp::GetOptimalLayout( - const RuntimeDevices &devices) { + const RuntimeDevices& devices) { // Keep current data format if no GPUs are available or if explicit placement // does not allow to use GPU for this operation. if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation())) - return data_format(); + return getDataFormat(); // Filter must be a tensor. - auto filter_ty = filter().getType().dyn_cast(); - if (!filter_ty) return data_format(); + auto filter_ty = getFilter().getType().dyn_cast(); + if (!filter_ty) return getDataFormat(); // For f16 data type on devices with Tensor Cores support NHWC data format // is up to ~2x faster. @@ -2049,9 +2100,9 @@ StringRef Conv2DBackpropInputOp::GetOptimalLayout( //===----------------------------------------------------------------------===// LogicalResult Conv3DOp::inferReturnTypeComponents( - MLIRContext *context, Optional location, ValueShapeRange operands, - DictionaryAttr attributes, RegionRange regions, - SmallVectorImpl &inferredReturnShapes) { + MLIRContext* context, std::optional location, + ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl& inferredReturnShapes) { Conv3DOpAdaptor op(operands.getValues(), attributes); ArrayRef explicit_padding; ArrayAttr explicit_pad = @@ -2071,7 +2122,7 @@ LogicalResult Conv3DOp::inferReturnTypeComponents( LogicalResult DataFormatVecPermuteOp::verify() { DataFormatVecPermuteOp op = *this; - auto input_ty = op.x().getType().dyn_cast(); + auto input_ty = op.getX().getType().dyn_cast(); if (!input_ty) return success(); int rank = input_ty.getRank(); @@ -2080,18 +2131,18 @@ LogicalResult DataFormatVecPermuteOp::verify() { if (rank == 1) { int64_t dim0 = input_ty.getDimSize(0); - if (dim0 != ShapedType::kDynamicSize && dim0 != 4 && dim0 != 2) + if (dim0 != ShapedType::kDynamic && dim0 != 4 && dim0 != 2) return op.emitOpError("requires 1D input of size 4 or size 2"); } if (rank == 2) { int64_t dim0 = input_ty.getDimSize(0); - if (dim0 != ShapedType::kDynamicSize && dim0 != 4) + if (dim0 != ShapedType::kDynamic && dim0 != 4) return op.emitOpError( "requires first dimensions of 2D input to be of size 4"); int64_t dim1 = input_ty.getDimSize(1); - if (dim1 != ShapedType::kDynamicSize && dim1 != 2) + if (dim1 != ShapedType::kDynamic && dim1 != 2) return op.emitOpError( "requires second dimensions of 2D input to be of size 2"); } @@ -2125,7 +2176,7 @@ class DivNoNanOrMulNoNanConstantY : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(OpT op, - PatternRewriter &rewriter) const override { + PatternRewriter& rewriter) const override { static_assert( llvm::is_one_of::value, "only canonicalization of tf.DivNoNan and tf.MulNoNan is supported"); @@ -2168,7 +2219,7 @@ class DivNoNanOrMulNoNanConstantY : public OpRewritePattern { // Note that `y` is the divisor if the op is tf.DivNoNan and it is the // multiplier if the op is tf.MulNoNan. 
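// Illustrative sketch (not part of this patch): the canonicalization above is
// sound because tf.DivNoNan (and, with multiplication, tf.MulNoNan) defines
// the result to be 0 whenever `y` is 0, so a splat-zero constant `y` forces
// the whole result to zeros regardless of `x`. Scalar semantics:
#include <cassert>

float DivNoNan(float x, float y) { return y == 0.f ? 0.f : x / y; }

int main() {
  assert(DivNoNan(3.f, 2.f) == 1.5f);
  assert(DivNoNan(3.f, 0.f) == 0.f);  // no Inf/NaN, just 0
  return 0;
}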
- Value y = op.y(); + Value y = op.getY(); // The below if condition is true iff `y.getDefiningOp()` is of the type // TF::ConstOp, i.e., if `y` is defined by an op and it is the tf.Const op. // In that case, `yDefOp` stores this tf.Const op. @@ -2178,7 +2229,7 @@ class DivNoNanOrMulNoNanConstantY : public OpRewritePattern { // `y.getDefiningOp()` will not return null but dyn_cast_or_null will. if (auto yDefOp = dyn_cast_or_null(y.getDefiningOp())) { Type typeOfElementsInY = getElementTypeOrSelf(y.getType()); - ElementsAttr attr = yDefOp.value(); + ElementsAttr attr = yDefOp.getValue(); bool yHasComplexElements = typeOfElementsInY.isa(); // If `y` is a splat constant, then the op will definitely get replaced. @@ -2223,8 +2274,8 @@ class DivNoNanOrMulNoNanConstantY : public OpRewritePattern { }; } // namespace -void DivNoNanOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { +void DivNoNanOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { results.add>(context); } @@ -2232,12 +2283,13 @@ void DivNoNanOp::getCanonicalizationPatterns(RewritePatternSet &results, // DivOp //===----------------------------------------------------------------------===// -void DivOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { +void DivOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { results.add(context); } -OpFoldResult DivOp::fold(ArrayRef operands) { +OpFoldResult DivOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); return IdentityArithmeticOpFolder(*this, operands); } @@ -2247,7 +2299,8 @@ OpFoldResult DivOp::fold(ArrayRef operands) { LogicalResult DynamicStitchOp::verify() { DynamicStitchOp op = *this; - if (op.N() < 1) return op.emitOpError("requires attribute N with value >= 1"); + if (op.getN() < 1) + return op.emitOpError("requires attribute N with value >= 1"); if (RankedTensorType out_ty = op.getType().dyn_cast()) { if (out_ty.getRank() == 0) { @@ -2258,8 +2311,8 @@ LogicalResult DynamicStitchOp::verify() { llvm::SmallDenseSet index_values; bool all_indices_const = true; int32_t max_index = -1; - llvm::Optional> inferred_item_shape; - for (auto it : llvm::zip(op.indices(), op.data())) { + std::optional> inferred_item_shape; + for (auto it : llvm::zip(op.getIndices(), op.getData())) { Value index = std::get<0>(it); DenseIntElementsAttr index_attr; @@ -2299,10 +2352,10 @@ LogicalResult DynamicStitchOp::verify() { if (failed(mlir::verifyCompatibleShape(item_shape, *inferred_item_shape))) return op.emitOpError() << "has inconsistent shaped data and index " "pairs; inferred item shapes [" - << llvm::makeArrayRef(*inferred_item_shape) + << llvm::ArrayRef(*inferred_item_shape) << "] and [" << item_shape << "] don't match"; for (int i = 0, e = item_shape.size(); i < e; ++i) { - int64_t &inferred_dim = (*inferred_item_shape)[i]; + int64_t& inferred_dim = (*inferred_item_shape)[i]; int64_t dim = item_shape[i]; if (ShapedType::isDynamic(inferred_dim)) inferred_dim = dim; } @@ -2347,7 +2400,7 @@ LogicalResult DynamicStitchOp::verify() { // TODO(hinsu): Verify einsum equation attribute. 
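// Illustrative sketch (not part of this patch): the DynamicStitch verifier
// above checks indices and data against the op's merge semantics,
// merged[indices[m][i]] = data[m][i]. A 1-D version (real inputs must cover
// every position up to the maximum index; uncovered slots are shown as 0):
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<float> DynamicStitch1D(
    const std::vector<std::vector<int32_t>>& indices,
    const std::vector<std::vector<float>>& data) {
  int32_t max_index = -1;
  for (const auto& idx : indices)
    for (int32_t i : idx) max_index = std::max(max_index, i);
  std::vector<float> merged(max_index + 1, 0.f);
  for (std::size_t m = 0; m < indices.size(); ++m)
    for (std::size_t i = 0; i < indices[m].size(); ++i)
      merged[indices[m][i]] = data[m][i];
  return merged;
}

int main() {
  // indices {[0, 2], [1]} with data {[10, 30], [20]} merge to [10, 20, 30].
  const auto merged = DynamicStitch1D({{0, 2}, {1}}, {{10.f, 30.f}, {20.f}});
  return merged[1] == 20.f ? 0 : 1;
}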
LogicalResult EinsumOp::verify() { EinsumOp op = *this; - if (op.N() > 2) { + if (op.getN() > 2) { return op.emitOpError("supports at most two operands"); } return success(); @@ -2357,7 +2410,8 @@ LogicalResult EinsumOp::verify() { // EmptyOp //===----------------------------------------------------------------------===// -OpFoldResult EmptyOp::fold(ArrayRef operands) { +OpFoldResult EmptyOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); assert(operands.size() == 1 && "empty op has one operand"); Attribute attr = operands.front(); @@ -2404,12 +2458,12 @@ LogicalResult EmptyTensorListOp::verify() { "must have exactly one subtype in the result variant type"); } - if (!IsOfRankOrUnranked(op.element_shape(), 0) && - !IsOfRankOrUnranked(op.element_shape(), 1)) { + if (!IsOfRankOrUnranked(op.getElementShape(), 0) && + !IsOfRankOrUnranked(op.getElementShape(), 1)) { return op.emitOpError("requires element_shape operand to be 0D/1D tensor"); } - if (!IsOfRankOrUnranked(op.max_num_elements(), 0)) { + if (!IsOfRankOrUnranked(op.getMaxNumElements(), 0)) { return op.emitOpError("requires max_num_elements operand to be 0D tensor"); } return success(); @@ -2425,7 +2479,7 @@ LogicalResult EmptyTensorListOp::verify() { // Helper function to get an absolute device string, combining device and // ordinal attribute values. -std::string GetAbsDeviceStr(Operation *op, uint64_t device_ordinal) { +std::string GetAbsDeviceStr(Operation* op, uint64_t device_ordinal) { std::string device_ordinal_str = std::to_string(device_ordinal); auto device_attr = op->getAttrOfType("device"); if (!device_attr || device_attr.getValue().empty()) return device_ordinal_str; @@ -2437,53 +2491,58 @@ std::string GetAbsDeviceStr(Operation *op, uint64_t device_ordinal) { return absl::StrCat(device_str, ":", device_ordinal_str); } -std::string +std::optional EnqueueTPUEmbeddingArbitraryTensorBatchOp::GetResourceInstanceStr() { - return GetAbsDeviceStr(*this, device_ordinal()); + return GetAbsDeviceStr(*this, getDeviceOrdinal()); } -std::string EnqueueTPUEmbeddingBatchOp::GetResourceInstanceStr() { - return GetAbsDeviceStr(*this, device_ordinal()); +std::optional +EnqueueTPUEmbeddingBatchOp::GetResourceInstanceStr() { + return GetAbsDeviceStr(*this, getDeviceOrdinal()); } -std::string EnqueueTPUEmbeddingIntegerBatchOp::GetResourceInstanceStr() { - return GetAbsDeviceStr(*this, device_ordinal()); +std::optional +EnqueueTPUEmbeddingIntegerBatchOp::GetResourceInstanceStr() { + return GetAbsDeviceStr(*this, getDeviceOrdinal()); } -std::string EnqueueTPUEmbeddingRaggedTensorBatchOp::GetResourceInstanceStr() { - return GetAbsDeviceStr(*this, device_ordinal()); +std::optional +EnqueueTPUEmbeddingRaggedTensorBatchOp::GetResourceInstanceStr() { + return GetAbsDeviceStr(*this, getDeviceOrdinal()); } -std::string EnqueueTPUEmbeddingSparseBatchOp::GetResourceInstanceStr() { - return GetAbsDeviceStr(*this, device_ordinal()); +std::optional +EnqueueTPUEmbeddingSparseBatchOp::GetResourceInstanceStr() { + return GetAbsDeviceStr(*this, getDeviceOrdinal()); } -std::string EnqueueTPUEmbeddingSparseTensorBatchOp::GetResourceInstanceStr() { - return GetAbsDeviceStr(*this, device_ordinal()); +std::optional +EnqueueTPUEmbeddingSparseTensorBatchOp::GetResourceInstanceStr() { + return GetAbsDeviceStr(*this, getDeviceOrdinal()); } //===----------------------------------------------------------------------===// // EnsureShapeOp //===----------------------------------------------------------------------===// -OpFoldResult 
EnsureShapeOp::fold(llvm::ArrayRef) { - ShapedType type = input().getType().dyn_cast(); +OpFoldResult EnsureShapeOp::fold(FoldAdaptor) { + ShapedType type = getInput().getType().dyn_cast(); if (!type || !type.hasRank()) return {}; // If shape attribute equals input operand's type's shape, fold it to input. - llvm::Optional> shape_constraint = shape(); - if (type.getShape() == shape_constraint) return input(); + Optional> shape_constraint = getShape(); + if (type.getShape() == shape_constraint) return getInput(); // If input operand's type's shape always satisfies the shape attribute, fold // it to input. if (shape_constraint.has_value() && shape_constraint->size() == type.getShape().size()) { for (int i = 0; i < shape_constraint->size(); ++i) { - if (!ShapedType::isDynamic(shape_constraint.getValue()[i]) && - type.getDimSize(i) != shape_constraint.getValue()[i]) { + if (!ShapedType::isDynamic(shape_constraint.value()[i]) && + type.getDimSize(i) != shape_constraint.value()[i]) { return {}; } } - return input(); + return getInput(); } // Else retain to enable failing dynamically. return {}; @@ -2496,14 +2555,14 @@ OpFoldResult EnsureShapeOp::fold(llvm::ArrayRef) { LogicalResult EqualOp::verify() { EqualOp op = *this; // If we allow inputs to have incompatible type, then nothing to do. - if (!op.incompatible_shape_error()) return success(); + if (!op.getIncompatibleShapeError()) return success(); // Otherwise, check inputs are broadcastable. return mlir::OpTrait::impl::verifyCompatibleOperandBroadcast( op.getOperation()); } -void EqualOp::build(OpBuilder &builder, OperationState &result, Value x, +void EqualOp::build(OpBuilder& builder, OperationState& result, Value x, Value y, BoolAttr incompatible_shape_error) { auto result_type = DeduceEqualCmpOpType(&builder, result.location, x, y, incompatible_shape_error); @@ -2515,8 +2574,8 @@ namespace { // Flips the incompatible_shape_error attribute to true if the shapes are known // to be compatible. template -static LogicalResult flipComatibleShapeError(Ty op, PatternRewriter &rewriter) { - if (op.incompatible_shape_error()) { +static LogicalResult flipComatibleShapeError(Ty op, PatternRewriter& rewriter) { + if (op.getIncompatibleShapeError()) { return rewriter.notifyMatchFailure(op, "the attribute is already true"); } @@ -2532,27 +2591,27 @@ static LogicalResult flipComatibleShapeError(Ty op, PatternRewriter &rewriter) { // Unless this is a scalar compare, a scalar output indicates that this will // always fail. - auto x_ty = op.x().getType().template dyn_cast(); - auto y_ty = op.y().getType().template dyn_cast(); + auto x_ty = op.getX().getType().template dyn_cast(); + auto y_ty = op.getY().getType().template dyn_cast(); if (ty.getRank() == 0 && (!x_ty || x_ty.getRank() != 0 || !y_ty || y_ty.getRank() != 0)) { return rewriter.notifyMatchFailure(op, "output rank must match input rank"); } // Shapes are known to be compatible. 
- rewriter.template replaceOpWithNewOp(op, op.x(), op.y(), + rewriter.template replaceOpWithNewOp(op, op.getX(), op.getY(), rewriter.getBoolAttr(true)); return success(); } } // namespace -void EqualOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { +void EqualOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { results.add(flipComatibleShapeError); } -void NotEqualOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { +void NotEqualOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { results.add(flipComatibleShapeError); } @@ -2582,7 +2641,7 @@ Type InferExpandDimsOpType(Value input, Value dim) { return tensorflow::GetTypeFromTFTensorShape(shape, element_ty); } -void ExpandDimsOp::build(OpBuilder &builder, OperationState &result, +void ExpandDimsOp::build(OpBuilder& builder, OperationState& result, Value input, Value dim) { return build(builder, result, InferExpandDimsOpType(input, dim), input, dim); } @@ -2593,21 +2652,21 @@ void ExpandDimsOp::build(OpBuilder &builder, OperationState &result, LogicalResult FakeQuantWithMinMaxArgsOp::verify() { FakeQuantWithMinMaxArgsOp op = *this; // TODO(fengliuai): moving the following to an utility method. - const llvm::fltSemantics &semantics = op.min().getSemantics(); + const llvm::fltSemantics& semantics = op.getMin().getSemantics(); float rmin, rmax; if (&semantics == &APFloat::IEEEsingle()) { - rmin = op.min().convertToFloat(); - rmax = op.max().convertToFloat(); + rmin = op.getMin().convertToFloat(); + rmax = op.getMax().convertToFloat(); } else { - rmin = op.min().convertToDouble(); - rmax = op.max().convertToDouble(); + rmin = op.getMin().convertToDouble(); + rmax = op.getMax().convertToDouble(); } // Range boundaries must be valid. 
if (rmin >= rmax) { return op.emitOpError("range is invalid: [" + Twine(std::to_string(rmin)) + "," + Twine(std::to_string(rmax)) + "]"); } - int64_t num_bits = op.num_bits(); + int64_t num_bits = op.getNumBits(); if (num_bits < 2 || num_bits > 16) { return op.emitOpError( "requires num_bits to be between 2 and 16, inclusive"); @@ -2620,15 +2679,15 @@ LogicalResult FakeQuantWithMinMaxArgsOp::verify() { //===----------------------------------------------------------------------===// LogicalResult FakeQuantWithMinMaxVarsOp::verify() { FakeQuantWithMinMaxVarsOp op = *this; - auto min = GetRankedTensorTypeForOperand(op.min()); + auto min = GetRankedTensorTypeForOperand(op.getMin()); if (min && !IsOfRankedFloatTensorType(min, 0)) return op.emitOpError("requires min to be a 0d float tensor"); - auto max = GetRankedTensorTypeForOperand(op.max()); + auto max = GetRankedTensorTypeForOperand(op.getMax()); if (max && !IsOfRankedFloatTensorType(max, 0)) return op.emitOpError("requires max to be a 0d float tensor"); - int64_t num_bits = op.num_bits(); + int64_t num_bits = op.getNumBits(); if (num_bits < 2 || num_bits > 16) { return op.emitOpError( "requires num_bits to be between 2 and 16, inclusive"); @@ -2641,19 +2700,19 @@ LogicalResult FakeQuantWithMinMaxVarsOp::verify() { //===----------------------------------------------------------------------===// LogicalResult FakeQuantWithMinMaxVarsPerChannelOp::verify() { FakeQuantWithMinMaxVarsPerChannelOp op = *this; - auto min = GetRankedTensorTypeForOperand(op.min()); + auto min = GetRankedTensorTypeForOperand(op.getMin()); if (min && !IsOfRankedFloatTensorType(min, 1)) return op.emitOpError("requires min to be a 1d float tensor"); - auto max = GetRankedTensorTypeForOperand(op.max()); + auto max = GetRankedTensorTypeForOperand(op.getMax()); if (max && !IsOfRankedFloatTensorType(max, 1)) return op.emitOpError("requires max to be a 1d float tensor"); - Value inputs = op.inputs(); + Value inputs = op.getInputs(); if (!HasRankAtLeast(inputs, 1)) return op.emitError("requires inputs to be at least 1d float tensor"); - int64_t num_bits = op.num_bits(); + int64_t num_bits = op.getNumBits(); if (num_bits < 2 || num_bits > 16) { return op.emitOpError( "requires num_bits to be between 2 and 16, inclusive"); @@ -2677,9 +2736,9 @@ LogicalResult FakeQuantWithMinMaxVarsPerChannelOp::verify() { LogicalResult FillOp::verify() { FillOp op = *this; - if (!IsOfRankOrUnranked(op.dims(), 1)) + if (!IsOfRankOrUnranked(op.getDims(), 1)) return op.emitOpError() << "requires dims to be a 1D tensor"; - if (!IsOfRankOrUnranked(op.value(), 0)) + if (!IsOfRankOrUnranked(op.getValue(), 0)) return op.emitOpError() << "requires value to be a scalar"; return success(); @@ -2699,7 +2758,7 @@ static ShapedType InferFillOpType(Value dims, Value value) { } if (auto shape_op = dims.getDefiningOp()) { - if (auto t = shape_op.input().getType().dyn_cast()) { + if (auto t = shape_op.getInput().getType().dyn_cast()) { return t; } } @@ -2707,12 +2766,13 @@ static ShapedType InferFillOpType(Value dims, Value value) { return UnrankedTensorType::get(etype); } -void FillOp::build(OpBuilder &builder, OperationState &result, Value dims, +void FillOp::build(OpBuilder& builder, OperationState& result, Value dims, Value value) { FillOp::build(builder, result, InferFillOpType(dims, value), dims, value); } -OpFoldResult FillOp::fold(ArrayRef operands) { +OpFoldResult FillOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); assert(operands.size() == 2 && "fill op has two operand"); auto 
type = getType().cast(); @@ -2753,15 +2813,15 @@ LogicalResult FusedBatchNormGradV3Op::UpdateDataFormat(StringRef data_format) { } StringRef FusedBatchNormGradV3Op::GetOptimalLayout( - const RuntimeDevices &devices) { + const RuntimeDevices& devices) { // Keep current data format if no GPUs are available or if explicit placement // does not allow to use GPU for this operation. if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation())) - return data_format(); + return getDataFormat(); // For f16 data type on devices with Tensor Cores support NHWC data format // is up to ~2x faster. - auto x_ty = x().getType().cast(); + auto x_ty = getX().getType().cast(); const bool is_f16 = x_ty.getElementType().isF16(); if (is_f16 && CanUseTensorCores(devices)) return "NHWC"; @@ -2775,23 +2835,23 @@ StringRef FusedBatchNormGradV3Op::GetOptimalLayout( LogicalResult FusedBatchNormOp::verify() { FusedBatchNormOp op = *this; - auto x = GetRankedTensorTypeForOperand(op.x()); + auto x = GetRankedTensorTypeForOperand(op.getX()); if (x && !IsOfRankedFloatTensorType(x, 4)) return op.emitOpError("requires x to be a 4D float tensor"); - auto scale = GetRankedTensorTypeForOperand(op.scale()); + auto scale = GetRankedTensorTypeForOperand(op.getScale()); if (scale && !IsOfRankedFloatTensorType(scale, 1)) return op.emitOpError("requires scale to be a 1D float tensor"); - auto offset = GetRankedTensorTypeForOperand(op.offset()); + auto offset = GetRankedTensorTypeForOperand(op.getOffset()); if (offset && !IsOfRankedFloatTensorType(offset, 1)) return op.emitOpError("requires offset to be a 1D float tensor"); - auto mean = GetRankedTensorTypeForOperand(op.mean()); + auto mean = GetRankedTensorTypeForOperand(op.getMean()); if (mean && !IsOfRankedFloatTensorType(mean, 1)) return op.emitOpError("requires mean to be a 1D float tensor"); - auto variance = GetRankedTensorTypeForOperand(op.variance()); + auto variance = GetRankedTensorTypeForOperand(op.getVariance()); if (variance && !IsOfRankedFloatTensorType(variance, 1)) return op.emitOpError("requires variance to be a 1D float tensor"); @@ -2806,26 +2866,26 @@ LogicalResult FusedBatchNormOp::verify() { template static LogicalResult InferenceFoldOperandsPermutation( - ArrayRef permutation, Op *op) { + ArrayRef permutation, Op* op) { // FusedBatchNorm in training mode is a layout sentitive operation, and should // have already assigned an optimal data format. - if (op->is_training()) return failure(); + if (op->getIsTraining()) return failure(); return ::mlir::TF::FoldOperandsPermutation(permutation, op); } template -static StringRef GetOptimalLayout(const RuntimeDevices &devices, Op *op) { +static StringRef GetOptimalLayout(const RuntimeDevices& devices, Op* op) { // In inference mode FusedBatchNorm is not sensitive to data layout. - if (!op->is_training()) return op->data_format(); + if (!op->getIsTraining()) return op->getDataFormat(); // Keep current data format if no GPUs are available or if explicit placement // does not allow to use GPU for this operation. if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(op->getOperation())) - return op->data_format(); + return op->getDataFormat(); // For f16 data type on devices with Tensor Cores support NHWC data format // is up to ~2x faster. 
- auto x_ty = op->x().getType().template cast(); + auto x_ty = op->getX().getType().template cast(); const bool is_f16 = x_ty.getElementType().isF16(); if (is_f16 && CanUseTensorCores(devices)) return "NHWC"; @@ -2842,7 +2902,7 @@ LogicalResult FusedBatchNormV2Op::UpdateDataFormat(StringRef data_format) { return ::mlir::TF::UpdateDataFormat(data_format, this); } -StringRef FusedBatchNormV2Op::GetOptimalLayout(const RuntimeDevices &devices) { +StringRef FusedBatchNormV2Op::GetOptimalLayout(const RuntimeDevices& devices) { return ::mlir::TF::GetOptimalLayout(devices, this); } @@ -2855,7 +2915,7 @@ LogicalResult FusedBatchNormV3Op::UpdateDataFormat(StringRef data_format) { return ::mlir::TF::UpdateDataFormat(data_format, this); } -StringRef FusedBatchNormV3Op::GetOptimalLayout(const RuntimeDevices &devices) { +StringRef FusedBatchNormV3Op::GetOptimalLayout(const RuntimeDevices& devices) { return ::mlir::TF::GetOptimalLayout(devices, this); } @@ -2865,8 +2925,8 @@ StringRef FusedBatchNormV3Op::GetOptimalLayout(const RuntimeDevices &devices) { LogicalResult GatherV2Op::verify() { GatherV2Op op = *this; - int64_t batch_dims = op.batch_dims(); - if (auto ty = op.indices().getType().dyn_cast()) { + int64_t batch_dims = op.getBatchDims(); + if (auto ty = op.getIndices().getType().dyn_cast()) { int64_t rank = ty.getRank(); if (batch_dims > rank || batch_dims < -rank) return op.emitOpError() @@ -2875,13 +2935,13 @@ LogicalResult GatherV2Op::verify() { if (batch_dims < 0) batch_dims += rank; } - if (!HasRankAtMost(op.axis(), 1)) + if (!HasRankAtMost(op.getAxis(), 1)) return op.emitOpError("requires axis to have rank at most 1"); DenseIntElementsAttr axis_attr; - if (matchPattern(op.axis(), m_Constant(&axis_attr))) { + if (matchPattern(op.getAxis(), m_Constant(&axis_attr))) { int64_t axis = (*axis_attr.begin()).getSExtValue(); - if (auto ty = op.params().getType().dyn_cast()) { + if (auto ty = op.getParams().getType().dyn_cast()) { int64_t rank = ty.getRank(); if (axis >= rank || axis < -rank) return op.emitOpError() << "axis (" << axis << ") must be in range [" @@ -2898,8 +2958,8 @@ LogicalResult GatherV2Op::verify() { return success(); } -void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { +void GatherOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { results.add(context); } @@ -2907,12 +2967,13 @@ void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results, // IfOp //===----------------------------------------------------------------------===// -LogicalResult IfOp::verifySymbolUses(SymbolTableCollection &symbol_table) { +LogicalResult IfOp::verifySymbolUses(SymbolTableCollection& symbol_table) { auto branch_name = [](unsigned index) -> std::string { return index == 0 ? 
"'then_branch'" : "'else_branch'"; }; return VerifyCaseOrIfOpBranchFunctions( - symbol_table, *this, {then_branchAttr(), else_branchAttr()}, branch_name); + symbol_table, *this, {getThenBranchAttr(), getElseBranchAttr()}, + branch_name); } //===----------------------------------------------------------------------===// @@ -2922,10 +2983,10 @@ LogicalResult IfOp::verifySymbolUses(SymbolTableCollection &symbol_table) { namespace { class FoldConstantIfOp : public OpRewritePattern { public: - explicit FoldConstantIfOp(MLIRContext *context) + explicit FoldConstantIfOp(MLIRContext* context) : OpRewritePattern(context) {} LogicalResult matchAndRewrite(TF::IfOp op, - PatternRewriter &rewriter) const override; + PatternRewriter& rewriter) const override; private: template @@ -2935,27 +2996,28 @@ class FoldConstantIfOp : public OpRewritePattern { }; LogicalResult FoldConstantIfOp::matchAndRewrite( - TF::IfOp op, PatternRewriter &rewriter) const { + TF::IfOp op, PatternRewriter& rewriter) const { // Extract the constant cond value. DenseIntElementsAttr cond_attr; - if (!matchPattern(op.cond(), m_Constant(&cond_attr))) return failure(); + if (!matchPattern(op.getCond(), m_Constant(&cond_attr))) return failure(); // Cond value must be a scalar. if (cond_attr.getNumElements() != 1) return failure(); // Select a branch function. bool cond = cond_attr.getSplatValue().getValue(); - FlatSymbolRefAttr func = cond ? op.then_branchAttr() : op.else_branchAttr(); + FlatSymbolRefAttr func = + cond ? op.getThenBranchAttr() : op.getElseBranchAttr(); // Replace IfOp with PartitionedCallOp or StatefulPartitionedCallOp. auto rewrite = [&](auto op_type) { auto empty = rewriter.getStringAttr(""); ReplaceTfOpWithNewOp( - rewriter, op, op.getResultTypes(), op.input(), func, + rewriter, op, op.getResultTypes(), op.getInput(), func, /*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty); }; - if (op.is_stateless()) + if (op.getIsStateless()) rewrite(CallOpType{}); else rewrite(CallOpType{}); @@ -2964,8 +3026,8 @@ LogicalResult FoldConstantIfOp::matchAndRewrite( } } // anonymous namespace -void IfOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { +void IfOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { results.add>(context); } @@ -2976,9 +3038,9 @@ void IfOp::getCanonicalizationPatterns(RewritePatternSet &results, LogicalResult IfRegionOp::verifyRegions() { IfRegionOp op = *this; TypeRange then_types = - op.then_branch().front().getTerminator()->getOperandTypes(); + op.getThenBranch().front().getTerminator()->getOperandTypes(); TypeRange else_types = - op.else_branch().front().getTerminator()->getOperandTypes(); + op.getElseBranch().front().getTerminator()->getOperandTypes(); TypeRangeWithDesc results{op.getResultTypes(), "result"}; TypeRangeWithDesc then_results{then_types, "then result"}; @@ -2994,21 +3056,21 @@ LogicalResult IfRegionOp::verifyRegions() { namespace { class FoldConstantIfRegionOp : public OpRewritePattern { public: - explicit FoldConstantIfRegionOp(MLIRContext *context) + explicit FoldConstantIfRegionOp(MLIRContext* context) : OpRewritePattern(context) {} LogicalResult matchAndRewrite(TF::IfRegionOp op, - PatternRewriter &rewriter) const override; + PatternRewriter& rewriter) const override; }; LogicalResult FoldConstantIfRegionOp::matchAndRewrite( - TF::IfRegionOp op, PatternRewriter &rewriter) const { + TF::IfRegionOp op, PatternRewriter& rewriter) const { // Extract the constant cond value. 
   DenseIntElementsAttr cond_attr;
-  if (!matchPattern(op.cond(), m_Constant(&cond_attr))) return failure();
+  if (!matchPattern(op.getCond(), m_Constant(&cond_attr))) return failure();
 
   // IfRegion condition should always be a scalar. Select the region to fold to.
   bool cond = cond_attr.getSplatValue().getValue();
-  Region &region = cond ? op.then_branch() : op.else_branch();
+  Region& region = cond ? op.getThenBranch() : op.getElseBranch();
 
   // If the IfRegion is stateless but the region being inlined itself is not
   // stateless, then inlining the region could cause a loss of information.
@@ -3025,7 +3087,7 @@ LogicalResult FoldConstantIfRegionOp::matchAndRewrite(
   // casts.
   rewriter.setInsertionPoint(yield);
   for (auto it : llvm::zip(op.getResultTypes(), updated_results)) {
-    auto &updated_result = std::get<1>(it);
+    auto& updated_result = std::get<1>(it);
     Type result_type = std::get<0>(it);
     if (result_type != updated_result.getType()) {
       updated_result =
@@ -3041,8 +3103,8 @@ LogicalResult FoldConstantIfRegionOp::matchAndRewrite(
 }
 }  // anonymous namespace
 
-void IfRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                             MLIRContext *context) {
+void IfRegionOp::getCanonicalizationPatterns(RewritePatternSet& results,
+                                             MLIRContext* context) {
   results.add>(context);
 }
 
@@ -3054,7 +3116,7 @@ void IfRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
 // Verifies that the input is 1D.
 LogicalResult InvertPermutationOp::verify() {
   InvertPermutationOp op = *this;
-  auto x_type = op.x().getType().cast();
+  auto x_type = op.getX().getType().cast();
   if (!x_type.hasRank()) return success();
   if (x_type.getShape().size() != 1)
     return op.emitOpError() << "requires input x to be 1-dimensional";
@@ -3066,15 +3128,16 @@ LogicalResult InvertPermutationOp::verify() {
 // LeakyReluOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult LeakyReluOp::fold(ArrayRef operands) {
+OpFoldResult LeakyReluOp::fold(FoldAdaptor adaptor) {
+  auto operands = adaptor.getOperands();
   assert(operands.size() == 1 && "leaky relu has one operand");
 
   // leaky_relu(x, alpha: 1) -> x
-  if (alpha().convertToFloat() == 1.0f) return getOperand();
+  if (getAlpha().convertToFloat() == 1.0f) return getOperand();
 
   auto calculate = [&](FloatAttr arg) {
     APFloat val = arg.getValue();
-    if (val.isNegative()) val = alpha() * val;
+    if (val.isNegative()) val = getAlpha() * val;
     return FloatAttr::get(arg.getType(), val);
   };
 
@@ -3087,21 +3150,76 @@ OpFoldResult LeakyReluOp::fold(ArrayRef operands) {
   return {};
 }
 
+//===----------------------------------------------------------------------===//
+// LegacyCallOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult LegacyCallOp::verifySymbolUses(
+    SymbolTableCollection& symbolTable) {
+  StringAttr func_attr = getFAttr().getAttr();
+  StringRef func_name = func_attr.getValue();
+  func::FuncOp func =
+      symbolTable.lookupNearestSymbolFrom(*this, func_attr);
+
+  if (!func) {
+    return emitError("'f' attribute refers to an undefined function: ")
+           << func_name;
+  }
+
+  FunctionType func_ty = func.getFunctionType();
+  int func_arg_count = func_ty.getNumInputs();
+  int arg_count = getArgs().size();
+
+  if (arg_count != func_arg_count) {
+    return emitError() << "argument count mismatch: 'args' has " << arg_count
+                       << " argument(s), but '" << func_name << "' expects "
+                       << func_arg_count;
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // LogOp
 //===----------------------------------------------------------------------===//
 
-void LogOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                        MLIRContext *context) {
+void LogOp::getCanonicalizationPatterns(RewritePatternSet& results,
+                                        MLIRContext* context) {
   results.add(context);
 }
 
+//===----------------------------------------------------------------------===//
+// LogicalAndOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult LogicalAndOp::fold(FoldAdaptor adaptor) {
+  auto operands = adaptor.getOperands();
+  // TODO(b/264429950): Expand this to work for broadcastable shapes and other
+  // conditions (e.g. one operand is always True).
+  auto result_type = getType();
+
+  for (const auto& operand : operands) {
+    auto splat_attr = operand.dyn_cast_or_null();
+    if (!splat_attr) continue;
+
+    if (splat_attr.getType() != result_type) continue;
+
+    // We can only fold away constant Falses.
+    auto splat_value = splat_attr.getSplatValue().getValue();
+    if (splat_value) continue;
+
+    return operand;
+  }
+
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // LogicalNotOp
 //===----------------------------------------------------------------------===//
 
-void LogicalNotOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                               MLIRContext *context) {
+void LogicalNotOp::getCanonicalizationPatterns(RewritePatternSet& results,
+                                               MLIRContext* context) {
   results
       .add(
@@ -3114,20 +3232,20 @@ void LogicalNotOp::getCanonicalizationPatterns(RewritePatternSet &results,
 LogicalResult MatrixBandPartOp::verify() {
   MatrixBandPartOp op = *this;
-  if (!HasRankAtLeast(op.input(), 2)) {
+  if (!HasRankAtLeast(op.getInput(), 2)) {
     return op.emitOpError()
            << "requires `input` to have rank of at least 2, but found "
-           << op.input().getType();
+           << op.getInput().getType();
   }
-  if (!IsOfRankOrUnranked(op.num_lower(), 0)) {
+  if (!IsOfRankOrUnranked(op.getNumLower(), 0)) {
     return op.emitOpError()
            << "requires `num_lower` to have 0 dimensions, but found "
-           << op.num_lower().getType();
+           << op.getNumLower().getType();
   }
-  if (!IsOfRankOrUnranked(op.num_upper(), 0)) {
+  if (!IsOfRankOrUnranked(op.getNumUpper(), 0)) {
     return op.emitOpError()
            << "requires `num_upper` to have 0 dimensions, but found "
-           << op.num_upper().getType();
+           << op.getNumUpper().getType();
   }
   return success();
 }
 
@@ -3136,8 +3254,8 @@ LogicalResult MatrixBandPartOp::verify() {
 // MatrixDiag Ops
 //===----------------------------------------------------------------------===//
 
-void MatrixDiagOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                               MLIRContext *context) {
+void MatrixDiagOp::getCanonicalizationPatterns(RewritePatternSet& results,
+                                               MLIRContext* context) {
   results.add(context);
 }
 
@@ -3145,8 +3263,8 @@ void MatrixDiagOp::getCanonicalizationPatterns(RewritePatternSet &results,
 // MatrixSetDiagOp
 //===----------------------------------------------------------------------===//
 
-void MatrixSetDiagOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                                  MLIRContext *context) {
+void MatrixSetDiagOp::getCanonicalizationPatterns(RewritePatternSet& results,
+                                                  MLIRContext* context) {
   results.add(context);
 }
 
@@ -3154,8 +3272,8 @@ void MatrixSetDiagOp::getCanonicalizationPatterns(RewritePatternSet &results,
 // MatrixSetDiagV2Op
 //===----------------------------------------------------------------------===//
 
-void MatrixSetDiagV2Op::getCanonicalizationPatterns(RewritePatternSet &results,
-                                                    MLIRContext *context) {
+void
MatrixSetDiagV2Op::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { results.add(context); } @@ -3163,7 +3281,7 @@ void MatrixSetDiagV2Op::getCanonicalizationPatterns(RewritePatternSet &results, // MaxOp //===----------------------------------------------------------------------===// -void MaxOp::build(OpBuilder &builder, OperationState &result, Value input, +void MaxOp::build(OpBuilder& builder, OperationState& result, Value input, Value reduction_indices, BoolAttr keep_dims) { Type out_ty = InferReductionOpType(input, reduction_indices, keep_dims); build(builder, result, out_ty, input, reduction_indices, keep_dims); @@ -3173,8 +3291,8 @@ void MaxOp::build(OpBuilder &builder, OperationState &result, Value input, // MaximumOp //===----------------------------------------------------------------------===// -void MaximumOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { +void MaximumOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { results.add(context); } @@ -3185,11 +3303,11 @@ void MaximumOp::getCanonicalizationPatterns(RewritePatternSet &results, LogicalResult MaxPoolOp::FoldOperandsPermutation( ArrayRef permutation) { return ::mlir::TF::FoldOperandsPermutation( - permutation, this, {{"strides", strides()}, {"ksize", ksize()}}); + permutation, this, {{"strides", getStrides()}, {"ksize", getKsize()}}); } LogicalResult MaxPoolOp::UpdateDataFormat(StringRef new_data_format) { - StringRef src_data_format = data_format(); + StringRef src_data_format = getDataFormat(); auto perm = GetDataFormatPermutation(src_data_format, new_data_format); if (perm.empty()) return failure(); @@ -3198,18 +3316,18 @@ LogicalResult MaxPoolOp::UpdateDataFormat(StringRef new_data_format) { if (failed(::mlir::TF::UpdateDataFormat(new_data_format, this))) return failure(); - stridesAttr(ShuffleArrayAttr(strides(), perm)); - explicit_paddingsAttr(ShuffleArrayAttr(explicit_paddings(), perm, 2)); - ksizeAttr(ShuffleArrayAttr(ksize(), perm)); + setStridesAttr(ShuffleArrayAttr(getStrides(), perm)); + setExplicitPaddingsAttr(ShuffleArrayAttr(getExplicitPaddings(), perm, 2)); + setKsizeAttr(ShuffleArrayAttr(getKsize(), perm)); return success(); } -StringRef MaxPoolOp::GetOptimalLayout(const RuntimeDevices &devices) { +StringRef MaxPoolOp::GetOptimalLayout(const RuntimeDevices& devices) { // Keep current data format if no GPUs are available or if explicit placement // does not allow to use GPU for this operation. if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation())) - return data_format(); + return getDataFormat(); // Defaults to NCHW. return "NCHW"; @@ -3221,13 +3339,13 @@ StringRef MaxPoolOp::GetOptimalLayout(const RuntimeDevices &devices) { LogicalResult MaxPoolGradOp::verify() { MaxPoolGradOp op = *this; - if (!IsOfRankOrUnranked(op.orig_input(), 4)) { + if (!IsOfRankOrUnranked(op.getOrigInput(), 4)) { return op.emitOpError() << "requires orig_input to be rank 4"; } - if (!IsOfRankOrUnranked(op.orig_output(), 4)) { + if (!IsOfRankOrUnranked(op.getOrigOutput(), 4)) { return op.emitOpError() << "requires orig_output to be rank 4"; } - if (!IsOfRankOrUnranked(op.grad(), 4)) { + if (!IsOfRankOrUnranked(op.getGrad(), 4)) { return op.emitOpError() << "requires grad to be rank 4"; } return success(); @@ -3240,10 +3358,10 @@ LogicalResult MaxPoolGradOp::verify() { LogicalResult MeanOp::FoldOperandsPermutation(ArrayRef permutation) { // Reduction indices must be defined by a constant operation. 
auto reduction_op = - dyn_cast_or_null(reduction_indices().getDefiningOp()); + dyn_cast_or_null(getReductionIndices().getDefiningOp()); if (!reduction_op) return failure(); - auto reductions_value = reduction_op.value().dyn_cast(); + auto reductions_value = reduction_op.getValue().dyn_cast(); if (!reductions_value) return failure(); // Prepare new reduction indices according to operand permutation. @@ -3270,8 +3388,8 @@ LogicalResult MeanOp::FoldOperandsPermutation(ArrayRef permutation) { // MulNoNanOp //===----------------------------------------------------------------------===// -void MulNoNanOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { +void MulNoNanOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { results.add>(context); } @@ -3279,15 +3397,16 @@ void MulNoNanOp::getCanonicalizationPatterns(RewritePatternSet &results, // MulOp //===----------------------------------------------------------------------===// -OpFoldResult MulOp::fold(ArrayRef operands) { +OpFoldResult MulOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); return IdentityArithmeticOpFolder(*this, operands); } //===----------------------------------------------------------------------===// // HashTableOp //===----------------------------------------------------------------------===// -void HashTableOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { +void HashTableOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { results.add(context); results.add(context); results.add(context); @@ -3299,8 +3418,8 @@ void HashTableOp::getCanonicalizationPatterns(RewritePatternSet &results, LogicalResult BitcastOp::verify() { BitcastOp op = *this; - auto input_type = op.input().getType().cast(); - auto output_type = op.output().getType().cast(); + auto input_type = op.getInput().getType().cast(); + auto output_type = op.getOutput().getType().cast(); auto input_element_type = input_type.getElementType(); auto output_element_type = output_type.getElementType(); diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_layout_helper.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_layout_helper.h index 1732ecc7658..77f87e0f960 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_layout_helper.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_layout_helper.h @@ -52,7 +52,7 @@ bool AreCancellablePermutations(DenseIntElementsAttr perm0, // attributes besides `data_format` string. template LogicalResult UpdateDataFormat(StringRef data_format, Op *op) { - auto perm = GetDataFormatPermutation(op->data_format(), data_format); + auto perm = GetDataFormatPermutation(op->getDataFormat(), data_format); if (perm.empty()) return failure(); // Update data format attribute. @@ -83,9 +83,10 @@ LogicalResult FoldOperandsPermutation( // Operation data format after folding `permutation`. 
StringRef target_data_format = [&]() -> StringRef { - if (op->data_format() == "NHWC" && permutation.equals(kNchwToNhwc)) { + if (op->getDataFormat() == "NHWC" && permutation.equals(kNchwToNhwc)) { return "NCHW"; // cancel NCHW->NHWC operand permutation - } else if (op->data_format() == "NCHW" && permutation.equals(kNhwcToNchw)) { + } else if (op->getDataFormat() == "NCHW" && + permutation.equals(kNhwcToNchw)) { return "NHWC"; // cancel NHWC->NCHW operand permutation } else { return ""; @@ -105,7 +106,7 @@ LogicalResult FoldOperandsPermutation( // To bypass %2 we have to change data format to shuffle data format from NCHW // to NHWC, which is the reverse of operand permutation (function argument). auto reverse_permutation = - GetDataFormatPermutation(op->data_format(), target_data_format); + GetDataFormatPermutation(op->getDataFormat(), target_data_format); if (reverse_permutation.empty()) return failure(); (*op)->setAttr("data_format", StringAttr::get(context, target_data_format)); diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc index 1ebd7f8acc0..2c8b6a1cfd1 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -29,7 +30,6 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/BitVector.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" @@ -68,6 +68,7 @@ limitations under the License. #include "mlir/Transforms/InliningUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_canonicalization_helper.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_device_helper.h" @@ -98,6 +99,7 @@ Value LookThroughIdentity(Value result) { #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc" } // namespace +INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(NcclAllReduceOp); INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(NegOp); INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(OnesLikeOp); INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(PreventGradientOp); @@ -136,6 +138,21 @@ INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(TanhGradOp); INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ZerosLikeOp); INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(_UnaryOpsCompositionOp); +//===----------------------------------------------------------------------===// +// NcclAllReduceOp +//===----------------------------------------------------------------------===// + +// For `NcclAllReduceOp` ops the `device` attribute corresponds to the resource +// instance. +std::optional NcclAllReduceOp::GetResourceInstanceStr() { + auto device_attr = (*this)->getAttrOfType("device"); + // Treat missing device attribute like unspecified (= empty string) attribute. + // Note that different op instances with the same string (including empty + // string) are seen as dependent (same resource instance). 
+ if (!device_attr) return ""; + return device_attr.str(); +} + //===----------------------------------------------------------------------===// // NotEqualOp //===----------------------------------------------------------------------===// @@ -143,7 +160,7 @@ INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(_UnaryOpsCompositionOp); LogicalResult NotEqualOp::verify() { NotEqualOp op = *this; // If we allow inputs to have incompatible type, then nothing to do. - if (!op.incompatible_shape_error()) return success(); + if (!op.getIncompatibleShapeError()) return success(); // Otherwise, check inputs are broadcastable. return mlir::OpTrait::impl::verifyCompatibleOperandBroadcast( @@ -163,9 +180,9 @@ void NotEqualOp::build(OpBuilder &builder, OperationState &result, Value x, LogicalResult OneHotOp::verify() { OneHotOp op = *this; - int64_t axis = op.axis(); + int64_t axis = op.getAxis(); - auto indices_ty = op.indices().getType().dyn_cast(); + auto indices_ty = op.getIndices().getType().dyn_cast(); if (indices_ty && !(axis == -1 || (axis >= 0 && axis <= indices_ty.getShape().size()))) { return op.emitOpError() @@ -178,18 +195,18 @@ LogicalResult OneHotOp::verify() { << ") to be -1 or between [0, rank(indices()))"; } - if (!IsOfRankOrUnranked(op.depth(), 0)) { + if (!IsOfRankOrUnranked(op.getDepth(), 0)) { return op.emitOpError() << "requires depth to be a scalar"; } - if (!IsOfRankOrUnranked(op.on_value(), 0)) { + if (!IsOfRankOrUnranked(op.getOnValue(), 0)) { return op.emitOpError() << "requires on_value to be a scalar"; } - if (!IsOfRankOrUnranked(op.off_value(), 0)) { + if (!IsOfRankOrUnranked(op.getOffValue(), 0)) { return op.emitOpError() << "requires off_value to be a scalar"; } DenseIntElementsAttr depth_attr; - if (matchPattern(op.depth(), m_Constant(&depth_attr))) { + if (matchPattern(op.getDepth(), m_Constant(&depth_attr))) { if (depth_attr.getType().getRank() != 0) return op.emitOpError() << "requires depth to be a scalar"; int64_t depth = depth_attr.getValues()[0].getSExtValue(); @@ -214,7 +231,7 @@ static TensorType InferOneHotOpType(Value indices, Value depth, Value on_value, auto shape = llvm::to_vector<2>(indices_ty.getShape()); if (axis_val == -1) axis_val = shape.size(); - int64_t depth_val = ShapedType::kDynamicSize; + int64_t depth_val = ShapedType::kDynamic; DenseIntElementsAttr depth_attr; if (matchPattern(depth, m_Constant(&depth_attr)) && depth_attr.getNumElements() == 1) @@ -238,7 +255,7 @@ void OneHotOp::build(OpBuilder &builder, OperationState &result, Value indices, LogicalResult PackOp::verify() { PackOp op = *this; // TODO(hinsu): Convert variadic length attributes to derived attributes. - Operation::operand_range values = op.values(); + Operation::operand_range values = op.getValues(); if (failed(VerifyTypesCompatibility(values, /*mask_one_dim=*/false, @@ -262,7 +279,7 @@ LogicalResult PackOp::verify() { // the axis value range is [-(R+1), R+1). int64_t range_begin = -inputs_rank - 1; // Inclusive int64_t range_end = inputs_rank + 1; // Exclusive - int64_t axis = op.axis(); + int64_t axis = op.getAxis(); if (axis < range_begin || axis >= range_end) { return op.emitError() << "attribute 'axis' should be within range [" << range_begin << ", " << range_end @@ -272,7 +289,7 @@ LogicalResult PackOp::verify() { return success(); } -OpFoldResult PackOp::fold(ArrayRef operands) { +OpFoldResult PackOp::fold(FoldAdaptor) { // Fold pack operation if it computes the input tensor shape: // // %shape = tf.Shape(%arg) // [? x ...] 
@@ -284,40 +301,44 @@ OpFoldResult PackOp::fold(ArrayRef operands) { // batch size. // Pack operation should pack at least two values. - if (values().size() < 2) return {}; + if (getValues().size() < 2) return {}; // Dimensions packed along axis = 0 (pack scalars into vector). - if (axis() != 0) return {}; + if (getAxis() != 0) return {}; // First packed value is defined by a strided slice operation. - auto slice_op = dyn_cast_or_null(values()[0].getDefiningOp()); + auto slice_op = + dyn_cast_or_null(getValues()[0].getDefiningOp()); if (!slice_op) return {}; // Input to the slice op is defined by shape operation. - auto shape_op = dyn_cast_or_null(slice_op.input().getDefiningOp()); + auto shape_op = + dyn_cast_or_null(slice_op.getInput().getDefiningOp()); if (!shape_op) return {}; // Input tensor, which shape is reconstructed by the pack operation. - Value tensor = shape_op.input(); + Value tensor = shape_op.getInput(); // All masks are `0` except `shrink_axis_mask` which is equal to `1` (slicing // scalar value from input vector). - if (slice_op.begin_mask() != 0 || slice_op.ellipsis_mask() != 0 || - slice_op.end_mask() != 0 || slice_op.new_axis_mask() != 0 || - slice_op.shrink_axis_mask() != 1) + if (slice_op.getBeginMask() != 0 || slice_op.getEllipsisMask() != 0 || + slice_op.getEndMask() != 0 || slice_op.getNewAxisMask() != 0 || + slice_op.getShrinkAxisMask() != 1) return {}; // Returns a value if the `value` is defined by a ConstOp with a single // integer element in it and has an expected rank. - auto get_const_int = [](Value value, int expected_rank) -> Optional { + auto get_const_int = [](Value value, + int expected_rank) -> std::optional { auto const_op = dyn_cast_or_null(value.getDefiningOp()); - if (!const_op) return None; + if (!const_op) return std::nullopt; - auto value_attr = const_op.value().dyn_cast(); - if (!value_attr || value_attr.getNumElements() != 1) return None; + auto value_attr = const_op.getValue().dyn_cast(); + if (!value_attr || value_attr.getNumElements() != 1) return std::nullopt; auto value_ty = value_attr.getType(); - if (!value_ty.hasRank() || value_ty.getRank() != expected_rank) return None; + if (!value_ty.hasRank() || value_ty.getRank() != expected_rank) + return std::nullopt; auto splat = value_attr.getSplatValue(); return splat.getValue().getSExtValue(); @@ -325,8 +346,8 @@ OpFoldResult PackOp::fold(ArrayRef operands) { // All other packed values are scalar constants. 
SmallVector packed_dims; - packed_dims.reserve(values().size() - 1); - for (Value operand : llvm::drop_begin(values(), 1)) { + packed_dims.reserve(getValues().size() - 1); + for (Value operand : llvm::drop_begin(getValues(), 1)) { if (auto dim = get_const_int(operand, /*expected_rank=*/0)) { packed_dims.push_back(*dim); } else { @@ -336,9 +357,9 @@ OpFoldResult PackOp::fold(ArrayRef operands) { // Slice exactly the first shape dimension: // begin = [0] end = [1], strides = [1] - auto begin = get_const_int(slice_op.begin(), /*expected_rank=*/1); - auto end = get_const_int(slice_op.end(), /*expected_rank=*/1); - auto strides = get_const_int(slice_op.strides(), /*expected_rank=*/1); + auto begin = get_const_int(slice_op.getBegin(), /*expected_rank=*/1); + auto end = get_const_int(slice_op.getEnd(), /*expected_rank=*/1); + auto strides = get_const_int(slice_op.getStrides(), /*expected_rank=*/1); if (!begin.has_value() || !end.has_value() || !strides.has_value() || *begin != 0 || *end != 1 || *strides != 1) return {}; @@ -350,7 +371,7 @@ OpFoldResult PackOp::fold(ArrayRef operands) { return {}; // Argument tensor rank is equal to the number of packed dimensions. - if (arg_ty.getRank() != values().size()) return {}; + if (arg_ty.getRank() != getValues().size()) return {}; // All other dimensions are statically known and equal to packed dims. auto arg_dims = llvm::drop_begin(arg_ty.getShape(), 1); @@ -358,7 +379,7 @@ OpFoldResult PackOp::fold(ArrayRef operands) { return {}; // Replace %pack with %shape. - return slice_op.input(); + return slice_op.getInput(); } // Convert Pack to Reshape when there is only one operand to be packed. @@ -376,13 +397,13 @@ struct ConvertPackToReshape : public OpRewritePattern { LogicalResult matchAndRewrite(PackOp pack_op, PatternRewriter &rewriter) const override { // Check if there is only one operand to be packed. - if (pack_op.N() != 1) { + if (pack_op.getN() != 1) { return failure(); } // Check if input and output are static. auto input_ty = pack_op.getOperand(0).getType().cast(); - auto output_ty = pack_op.output().getType().cast(); + auto output_ty = pack_op.getOutput().getType().cast(); if (!input_ty.hasStaticShape() || !output_ty.hasStaticShape()) { return failure(); } @@ -411,10 +432,11 @@ void PackOp::getCanonicalizationPatterns(RewritePatternSet &results, LogicalResult PadOp::FoldOperandsPermutation(ArrayRef permutation) { // Paddings must be defined by a constant operation. - auto paddings_op = dyn_cast_or_null(paddings().getDefiningOp()); + auto paddings_op = + dyn_cast_or_null(getPaddings().getDefiningOp()); if (!paddings_op) return failure(); - auto paddings_value = paddings_op.value().dyn_cast(); + auto paddings_value = paddings_op.getValue().dyn_cast(); if (!paddings_value || paddings_value.getNumElements() != permutation.size() * 2) return failure(); @@ -440,7 +462,8 @@ LogicalResult PadOp::FoldOperandsPermutation(ArrayRef permutation) { // Change the result type. getResult().setType(ShuffleRankedTensorType(getResult().getType(), - ReversePermutation(permutation))); + ReversePermutation(permutation)) + .cast()); return success(); } @@ -460,9 +483,9 @@ LogicalResult ParseExampleV2Op::verify() { // NOTE(mrry): The Tdense attr is derived from dense_defaults, so we // do not need to validate dense_defaults. 
auto dense_types_count = - std::distance(op.Tdense().begin(), op.Tdense().end()); + std::distance(op.getTdense().begin(), op.getTdense().end()); auto dense_values_count = - std::distance(op.dense_values().begin(), op.dense_values().end()); + std::distance(op.getDenseValues().begin(), op.getDenseValues().end()); if (dense_values_count != dense_types_count) { return op.emitError() << "output 'dense_values' should have same length " << "as attribute 'Tdense'"; @@ -472,25 +495,25 @@ LogicalResult ParseExampleV2Op::verify() { // NOTE(mrry): The sparse_types attr is derived from sparse_values, so we // do not need to validate sparse_values. auto sparse_types_count = - std::distance(op.sparse_types().begin(), op.sparse_types().end()); - if (op.num_sparse() != sparse_types_count) { + std::distance(op.getSparseTypes().begin(), op.getSparseTypes().end()); + if (op.getNumSparse() != sparse_types_count) { return op.emitError() << "attribute 'num_sparse' should be the same as " << "the length of attribute 'sparse_types'"; } - if (op.sparse_indices().size() != sparse_types_count) { + if (op.getSparseIndices().size() != sparse_types_count) { return op.emitError() << "output 'sparse_indices' should have same length " << "as attribute 'sparse_types'"; } - if (op.sparse_shapes().size() != sparse_types_count) { + if (op.getSparseShapes().size() != sparse_types_count) { return op.emitError() << "output 'sparse_shapes' should have same length " << "as attribute 'sparse_types'"; } // Validate ragged variadic output lengths. - auto ragged_value_types_count = std::distance(op.ragged_value_types().begin(), - op.ragged_value_types().end()); - auto ragged_split_types_count = std::distance(op.ragged_split_types().begin(), - op.ragged_split_types().end()); + auto ragged_value_types_count = std::distance( + op.getRaggedValueTypes().begin(), op.getRaggedValueTypes().end()); + auto ragged_split_types_count = std::distance( + op.getRaggedSplitTypes().begin(), op.getRaggedSplitTypes().end()); if (ragged_value_types_count != ragged_split_types_count) { return op.emitError() << "attribute 'ragged_value_types' should have same " << "length as attribute 'ragged_split_types'"; @@ -515,7 +538,7 @@ static LogicalResult VerifyPartitionedCall(CallOpClass op, FunctionType function_ty = function.getFunctionType(); int func_arg_count = function_ty.getNumInputs(); - int arg_count = op.args().size(); + int arg_count = op.getArgs().size(); if (arg_count != func_arg_count) { return op.emitError() << "argument count mismatch: 'args' has " << arg_count @@ -543,7 +566,8 @@ LogicalResult TPUPartitionedCallOp::verifySymbolUses( // PowOp //===----------------------------------------------------------------------===// -OpFoldResult PowOp::fold(ArrayRef operands) { +OpFoldResult PowOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); auto constant_y = operands[1].dyn_cast_or_null(); if (constant_y && constant_y.isSplat()) { APFloat y_value = constant_y.getSplatValue(); @@ -554,7 +578,7 @@ OpFoldResult PowOp::fold(ArrayRef operands) { FloatAttr::get(output_type.getElementType(), /*value=*/1.0)); } if (y_value.isExactlyValue(1.0)) { - return x(); + return getX(); } } return {}; @@ -580,12 +604,12 @@ void QuantizeAndDequantizeV2Op::getCanonicalizationPatterns( // LogicalResult QrOp::verify() { QrOp op = *this; - auto ttype = op.input().getType().cast(); + auto ttype = op.getInput().getType().cast(); if (!ttype.hasRank()) return success(); - if (!HasRankAtLeast(op.input(), 2)) + if (!HasRankAtLeast(op.getInput(), 2)) return 
op.emitOpError( "requires ranked input tensor to be of rank 2 or more"); - if (!HasRankAtMost(op.input(), std::numeric_limits::max())) + if (!HasRankAtMost(op.getInput(), std::numeric_limits::max())) return op.emitOpError( "requires ranked input tensor to be of rank INT32_MAX or less"); @@ -607,7 +631,7 @@ void ReadVariableOp::getCanonicalizationPatterns(RewritePatternSet &results, LogicalResult RandomUniformOp::verify() { RandomUniformOp op = *this; - if (!IsOfRankOrUnranked(op.shape(), 1)) + if (!IsOfRankOrUnranked(op.getShape(), 1)) return op.emitOpError("shape must be 1D tensor"); return success(); } @@ -684,7 +708,8 @@ void RangeOp::build(OpBuilder &builder, OperationState &result, Value start, start, limit, delta); } -OpFoldResult RangeOp::fold(ArrayRef operands) { +OpFoldResult RangeOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); assert(operands.size() == 3); auto start_tensor = operands[0].dyn_cast_or_null(); auto limit_tensor = operands[1].dyn_cast_or_null(); @@ -742,8 +767,8 @@ void RankOp::build(OpBuilder &builder, OperationState &result, Value input) { } // This will create a constant value for RankOp of a ranked tensor. -OpFoldResult RankOp::fold(ArrayRef operands) { - auto type = input().getType(); +OpFoldResult RankOp::fold(FoldAdaptor) { + auto type = getInput().getType(); auto ranked_type = type.dyn_cast(); if (!ranked_type) return {}; @@ -765,7 +790,8 @@ void RealDivOp::getCanonicalizationPatterns(RewritePatternSet &results, results.add(context); } -OpFoldResult RealDivOp::fold(ArrayRef operands) { +OpFoldResult RealDivOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); return IdentityArithmeticOpFolder(*this, operands); } @@ -805,7 +831,7 @@ LogicalResult GetReshapeOutputType(Value tensor, Value shape, // shape. 
if (shape_ty.hasStaticShape()) { llvm::SmallVector dynamic_shape(shape_ty.getDimSize(0), - ShapedType::kDynamicSize); + ShapedType::kDynamic); output_ty = tensorflow::GetTypeFromTFTensorShape(dynamic_shape, element_ty); } @@ -822,7 +848,7 @@ LogicalResult GetReshapeOutputType(Value tensor, Value shape, for (const auto &dim : llvm::enumerate(shape_attr.getValues())) { const int64_t size = dim.value().getSExtValue(); if (size == tensorflow::kTFDynamicSize || // NOLINT - size == ShapedType::kDynamicSize) { // NOLINT + size == ShapedType::kDynamic) { // NOLINT if (unknown_index != -1) return error_handler(llvm::formatv( "requires 'shape' to have at most one dynamic dimension, but got " @@ -885,13 +911,13 @@ LogicalResult ReshapeOp::verify() { return op.emitOpError() << message; }; TensorType expected_ty; - if (failed(GetReshapeOutputType(op.tensor(), op.shape(), error_handler, + if (failed(GetReshapeOutputType(op.getTensor(), op.getShape(), error_handler, expected_ty))) return failure(); auto output_ty = op.getType().dyn_cast(); if (!output_ty) return success(); - auto tensor_ty = op.tensor().getType().cast(); + auto tensor_ty = op.getTensor().getType().cast(); if (output_ty.hasStaticShape() && tensor_ty.hasStaticShape()) { const int64_t output_ty_size = output_ty.getNumElements(); const int64_t tensor_ty_size = tensor_ty.getNumElements(); @@ -929,8 +955,8 @@ void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results, results.add(context); } -OpFoldResult ReshapeOp::fold(ArrayRef operands) { - Value tensor = this->tensor(); +OpFoldResult ReshapeOp::fold(FoldAdaptor) { + Value tensor = this->getTensor(); // Fold reshape if operand and result types are the same and all dimensions // are statically known (no-op reshape). @@ -956,8 +982,8 @@ OpFoldResult ReshapeOp::fold(ArrayRef operands) { // first dimension equal to `cond`. LogicalResult SelectOp::verify() { SelectOp op = *this; - auto then_tensor = op.then_value().getType().cast(); - auto else_tensor = op.else_value().getType().cast(); + auto then_tensor = op.getThenValue().getType().cast(); + auto else_tensor = op.getElseValue().getType().cast(); // Check (1). if (!AreCastCompatible({then_tensor, else_tensor})) return op.emitOpError() << "requires t and e have compatible shapes"; @@ -988,7 +1014,7 @@ LogicalResult SelectOp::verify() { return success(); } - auto cond_tensor = op.condition().getType().dyn_cast(); + auto cond_tensor = op.getCondition().getType().dyn_cast(); if (!cond_tensor) return success(); auto cond_rank = cond_tensor.getRank(); // Check (2a) and (2b). @@ -1001,9 +1027,8 @@ LogicalResult SelectOp::verify() { << "requires that t and e are nonscalar when pred is a vector"; } // We know `data` tensor has a rank of at least 1. - if (data_first_dim != ShapedType::kDynamicSize && - cond_shape != ShapedType::kDynamicSize && - data_first_dim != cond_shape) { + if (data_first_dim != ShapedType::kDynamic && + cond_shape != ShapedType::kDynamic && data_first_dim != cond_shape) { return op.emitOpError() << "requires that, when pred is a vector, the " "shape matches the first dimension of t and e"; } @@ -1099,7 +1124,7 @@ LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type, LogicalResult ShapeOp::verify() { ShapeOp op = *this; - return VerifyShapeOperandAndResult(op, op.input().getType(), op.getType()); + return VerifyShapeOperandAndResult(op, op.getInput().getType(), op.getType()); } // Converts shape of the given type to attribute if it is of ranked tensor type. 
@@ -1121,7 +1146,7 @@ static Attribute ConvertShapeToAttr(Type input_ty, int out_width) { return DenseElementsAttr::get(result_type, dimensions); } -OpFoldResult ShapeOp::fold(ArrayRef operands) { +OpFoldResult ShapeOp::fold(FoldAdaptor) { int width = getType().cast().getElementType().getIntOrFloatBitWidth(); return ConvertShapeToAttr(getOperand().getType(), width); @@ -1144,7 +1169,7 @@ void ShapeOp::build(OpBuilder &builder, OperationState &result, Value input, LogicalResult ShapeNOp::verify() { ShapeNOp op = *this; - const size_t num_tensors = op.N(); + const size_t num_tensors = op.getN(); if (op.getNumOperands() != num_tensors) return op.emitOpError() << "requires " << num_tensors << " operand(s), got " @@ -1243,18 +1268,18 @@ void ShapeNOp::getCanonicalizationPatterns(RewritePatternSet &results, // LogicalResult SizeOp::verify() { SizeOp op = *this; - if (!HasRankAtMost(op.input(), std::numeric_limits::max())) + if (!HasRankAtMost(op.getInput(), std::numeric_limits::max())) return op.emitOpError( "requires ranked input tensor to be of rank INT32_MAX or less"); // Output type needs to be scalar. - if (!IsOfRankOrUnranked(op.output(), /*rank=*/0)) + if (!IsOfRankOrUnranked(op.getOutput(), /*rank=*/0)) return op.emitOpError("requires scalar output"); return success(); } -OpFoldResult SizeOp::fold(ArrayRef operands) { +OpFoldResult SizeOp::fold(FoldAdaptor) { ShapedType output_type = getType().cast(); if (!output_type.hasRank()) return {}; ShapedType input_type = getOperand().getType().cast(); @@ -1284,12 +1309,12 @@ OpFoldResult SizeOp::fold(ArrayRef operands) { // LogicalResult SliceOp::verify() { SliceOp op = *this; - RankedTensorType begin_ty = GetRankedTensorTypeForOperand(op.begin()); + RankedTensorType begin_ty = GetRankedTensorTypeForOperand(op.getBegin()); if (begin_ty && begin_ty.getRank() != 1) { return op.emitOpError() << "requires begin operand to be 1D tensor"; } - RankedTensorType size_ty = GetRankedTensorTypeForOperand(op.size()); + RankedTensorType size_ty = GetRankedTensorTypeForOperand(op.getSize()); if (size_ty && size_ty.getRank() != 1) { return op.emitOpError() << "requires size operand to be 1D tensor"; } @@ -1303,13 +1328,13 @@ LogicalResult SliceOp::verify() { " same number of elements"; } - auto input_ty = op.input().getType().dyn_cast(); + auto input_ty = op.getInput().getType().dyn_cast(); if (input_ty && begin_ty.getNumElements() != input_ty.getRank()) { return op.emitOpError() << "requires number of elements in begin and size " "are equal to input rank"; } - auto output_ty = op.output().getType().dyn_cast(); + auto output_ty = op.getOutput().getType().dyn_cast(); if (output_ty && input_ty && output_ty.getRank() != input_ty.getRank()) { return op.emitOpError() << "requires output to have the same rank as input, but got input " @@ -1318,35 +1343,35 @@ LogicalResult SliceOp::verify() { } DenseIntElementsAttr begin_indices; - if (matchPattern(op.begin(), m_Constant(&begin_indices))) { + if (matchPattern(op.getBegin(), m_Constant(&begin_indices))) { DenseIntElementsAttr slice_sizes; bool constant_slice_sizes = - matchPattern(op.size(), m_Constant(&slice_sizes)); + matchPattern(op.getSize(), m_Constant(&slice_sizes)); int dim = 0; // TODO(jpienaar): Reformulate the shape verification below to not use magic // constants. for (const APInt &raw_begin_index : begin_indices.getValues()) { int64_t begin_index = raw_begin_index.getSExtValue(); int64_t input_size = - input_ty ? input_ty.getShape()[dim] : ShapedType::kDynamicSize; + input_ty ? 
input_ty.getShape()[dim] : ShapedType::kDynamic; int64_t slice_size = constant_slice_sizes ? slice_sizes.getValues()[dim].getSExtValue() : 0; int64_t output_size = - output_ty ? output_ty.getShape()[dim] : ShapedType::kDynamicSize; + output_ty ? output_ty.getShape()[dim] : ShapedType::kDynamic; - if (slice_size == -1 && input_size != ShapedType::kDynamicSize) { + if (slice_size == -1 && input_size != ShapedType::kDynamic) { slice_size = input_size - begin_index; } - if (output_size != ShapedType::kDynamicSize && constant_slice_sizes && + if (output_size != ShapedType::kDynamic && constant_slice_sizes && output_size != slice_size) { return op.emitOpError() << "requires output size to have the same size of slice, got " "slice size " << slice_size << " and output size " << output_size; } - if (begin_index < 0 || (input_size != ShapedType::kDynamicSize && + if (begin_index < 0 || (input_size != ShapedType::kDynamic && begin_index + slice_size > input_size)) { return op.emitOpError() << "requires 0 <= begin[i] <= begin[i] + size[i] <= Di"; @@ -1356,12 +1381,12 @@ LogicalResult SliceOp::verify() { } else if (input_ty) { // If the inputs are ranked, we can do a few more sanity checks. DenseIntElementsAttr slice_sizes; - if (matchPattern(op.size(), m_Constant(&slice_sizes))) { + if (matchPattern(op.getSize(), m_Constant(&slice_sizes))) { auto input_shape = input_ty.getShape(); for (int64_t i = 0; i < input_ty.getRank(); ++i) { int64_t slice_size = slice_sizes.getValues()[i].getSExtValue(); int64_t input_size = input_shape[i]; - if (slice_size != -1 && input_size != ShapedType::kDynamicSize && + if (slice_size != -1 && input_size != ShapedType::kDynamic && slice_size > input_size) { return op.emitOpError() << "requires size[i] <= Di, even if begin[i] " "is unknown at compile time"; @@ -1379,7 +1404,7 @@ LogicalResult SliceOp::verify() { LogicalResult SoftmaxOp::verify() { SoftmaxOp op = *this; - if (!HasRankAtLeast(op.logits(), 1)) { + if (!HasRankAtLeast(op.getLogits(), 1)) { return op.emitOpError("requires operand to have rank at least 1"); } return success(); @@ -1395,9 +1420,10 @@ LogicalResult SoftmaxOp::verify() { // LogicalResult SoftmaxCrossEntropyWithLogitsOp::verify() { SoftmaxCrossEntropyWithLogitsOp op = *this; - auto broadcasted_ty = OpTrait::util::getBroadcastedType( - op.features().getType(), op.labels().getType()) - .dyn_cast_or_null(); + auto broadcasted_ty = + OpTrait::util::getBroadcastedType(op.getFeatures().getType(), + op.getLabels().getType()) + .dyn_cast_or_null(); if (!broadcasted_ty || (broadcasted_ty.hasRank() && broadcasted_ty.getRank() != 2)) return op.emitOpError( @@ -1423,18 +1449,18 @@ int64_t SpaceToBatchNDBlockRank(const TensorType block_shape_type, LogicalResult SpaceToBatchNDOp::verify() { SpaceToBatchNDOp op = *this; - const auto input_type = op.input().getType().cast(); - const auto block_shape_type = op.block_shape().getType().cast(); - const auto paddings_type = op.paddings().getType().cast(); + const auto input_type = op.getInput().getType().cast(); + const auto block_shape_type = op.getBlockShape().getType().cast(); + const auto paddings_type = op.getPaddings().getType().cast(); // Check that block_shape has rank 1. - if (!IsOfRankOrUnranked(op.block_shape(), 1)) { + if (!IsOfRankOrUnranked(op.getBlockShape(), 1)) { return op.emitOpError() << "requires rank of block_shape = 1; got " << block_shape_type.getRank(); } // Check that paddings has rank 2. 
- if (!IsOfRankOrUnranked(op.paddings(), 2)) { + if (!IsOfRankOrUnranked(op.getPaddings(), 2)) { return op.emitOpError() << "requires rank of paddings = 2; got " << paddings_type.getRank(); } @@ -1469,7 +1495,7 @@ LogicalResult SpaceToBatchNDOp::verify() { ElementsAttr paddings_attr = nullptr; // Check that block_shape[*] >= 1. - if (matchPattern(op.block_shape(), m_Constant(&block_shape_attr))) { + if (matchPattern(op.getBlockShape(), m_Constant(&block_shape_attr))) { uint64_t i = 0; for (auto block_len : block_shape_attr.getValues()) { if (block_len.getSExtValue() < 1) { @@ -1483,7 +1509,7 @@ LogicalResult SpaceToBatchNDOp::verify() { } // Check that paddings[*] >= 0. - if (matchPattern(op.paddings(), m_Constant(&paddings_attr))) { + if (matchPattern(op.getPaddings(), m_Constant(&paddings_attr))) { for (uint64_t i = 0; i < block_rank; ++i) { const int64_t pad_start = paddings_attr.getValues()[{i, 0}].getSExtValue(); @@ -1527,14 +1553,14 @@ LogicalResult SpaceToBatchNDOp::verify() { LogicalResult SparseSoftmaxCrossEntropyWithLogitsOp::verify() { SparseSoftmaxCrossEntropyWithLogitsOp op = *this; - if (!IsOfRankOrUnranked(op.features(), 2)) { + if (!IsOfRankOrUnranked(op.getFeatures(), 2)) { return op.emitOpError("requires features operand of rank two"); } - if (!IsOfRankOrUnranked(op.labels(), 1)) { + if (!IsOfRankOrUnranked(op.getLabels(), 1)) { return op.emitOpError("requires labels operand of rank one"); } - auto features_ty = op.features().getType().dyn_cast(); - auto labels_ty = op.labels().getType().dyn_cast(); + auto features_ty = op.getFeatures().getType().dyn_cast(); + auto labels_ty = op.getLabels().getType().dyn_cast(); if (features_ty && labels_ty) { int64_t features_batches = features_ty.getDimSize(0); int64_t labels_batches = labels_ty.getDimSize(0); @@ -1555,10 +1581,11 @@ LogicalResult SparseSoftmaxCrossEntropyWithLogitsOp::verify() { // Writes the split dimension's index (adjusted with input rank) via `dim_index` // if it's a constant. template -LogicalResult VerifySplitInputAndSplitDim(Op op, Optional *dim_index) { - *dim_index = llvm::None; +LogicalResult VerifySplitInputAndSplitDim(Op op, + std::optional *dim_index) { + *dim_index = std::nullopt; - Value split_dim = op.split_dim(); + Value split_dim = op.getSplitDim(); if (auto split_dim_type = split_dim.getType().dyn_cast()) if (split_dim_type.getRank() != 0) return op.emitOpError( @@ -1567,7 +1594,8 @@ LogicalResult VerifySplitInputAndSplitDim(Op op, Optional *dim_index) { // We can perform further verification if the input tensor to be split has // known rank and the split dimension tensor is a constant. 
- auto input_type = op.value().getType().template dyn_cast(); + auto input_type = + op.getValue().getType().template dyn_cast(); if (!input_type) return success(); int64_t input_rank = input_type.getRank(); @@ -1592,12 +1620,12 @@ LogicalResult VerifySplitInputAndSplitDim(Op op, Optional *dim_index) { LogicalResult SplitOp::verify() { SplitOp op = *this; - Optional dim_index; + std::optional dim_index; if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure(); if (!dim_index) return success(); int64_t input_dim_size = - op.value().getType().cast().getDimSize(*dim_index); + op.getValue().getType().cast().getDimSize(*dim_index); if (ShapedType::isDynamic(input_dim_size)) return success(); if (op.getNumResults() == 0) return failure(); @@ -1616,7 +1644,7 @@ LogicalResult SplitOp::verify() { LogicalResult SplitVOp::verify() { SplitVOp op = *this; auto split_sizes_type = - op.size_splits().getType().dyn_cast(); + op.getSizeSplits().getType().dyn_cast(); if (!split_sizes_type) return success(); if (split_sizes_type.getRank() != 1 || @@ -1625,22 +1653,22 @@ LogicalResult SplitVOp::verify() { return op.emitOpError("split sizes should be a 1D tensor of ") << op.getNumResults() << " elements"; - Optional dim_index = 0; + std::optional dim_index = 0; if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure(); if (!dim_index) return success(); int64_t input_dim_size = - op.value().getType().cast().getDimSize(*dim_index); + op.getValue().getType().cast().getDimSize(*dim_index); if (ShapedType::isDynamic(input_dim_size)) return success(); // If split sizes come from a constant, they must sum to the dimension size // along split_dim, and we can have no more than one dynamic dimension. DenseIntElementsAttr split_sizes_attr; - if (!matchPattern(op.size_splits(), m_Constant(&split_sizes_attr))) + if (!matchPattern(op.getSizeSplits(), m_Constant(&split_sizes_attr))) return success(); int64_t total_dim_size = 0; // Total dimension size assigned to splits - llvm::Optional dynamic_dim_index; + std::optional dynamic_dim_index; SmallVector split_sizes; split_sizes.reserve( @@ -1690,13 +1718,13 @@ void SquareOp::getCanonicalizationPatterns(RewritePatternSet &results, LogicalResult SqueezeOp::verify() { SqueezeOp op = *this; - auto input_type = op.input().getType().dyn_cast(); + auto input_type = op.getInput().getType().dyn_cast(); if (!input_type) return success(); // Can't verify squeeze dims. int64_t input_rank = input_type.getRank(); for (const auto &squeeze_dim_apint : - op.squeeze_dims().getAsValueRange()) { + op.getSqueezeDims().getAsValueRange()) { int64_t squeeze_dim = squeeze_dim_apint.getSExtValue(); if (squeeze_dim < -input_rank || squeeze_dim >= input_rank) { return op.emitOpError() @@ -1717,7 +1745,8 @@ void SubOp::getCanonicalizationPatterns(RewritePatternSet &results, results.add(context); } -OpFoldResult SubOp::fold(ArrayRef operands) { +OpFoldResult SubOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); return IdentityArithmeticOpFolder(*this, operands); } @@ -1732,16 +1761,16 @@ void SumOp::build(OpBuilder &builder, OperationState &result, Value input, } // TODO: Templatize this fold for all reduction ops. 
-OpFoldResult SumOp::fold(ArrayRef operands) { - auto input_ty = input().getType().template dyn_cast(); +OpFoldResult SumOp::fold(FoldAdaptor) { + auto input_ty = getInput().getType().template dyn_cast(); if (!input_ty) return {}; auto result_ty = getType().template dyn_cast(); if (!result_ty) return {}; // Bypass this op if the result has the same shape and type. This can happen // if the input tensor has size 0 or size 1. - if (!keep_dims() && input_ty == result_ty) { - return input(); + if (!getKeepDims() && input_ty == result_ty) { + return getInput(); } return {}; } @@ -1769,7 +1798,7 @@ static LogicalResult VerifyStridedSliceBase(OpTy op) { // Expected size for operands begin, end and strides vector operands. int64_t expected_size = -1; - for (Value val : {op.begin(), op.end(), op.strides()}) { + for (Value val : {op.getBegin(), op.getEnd(), op.getStrides()}) { auto operand_ty = val.getType().dyn_cast(); if (!operand_ty || !operand_ty.hasStaticShape()) { // TensorFlow constant ops may have non-static shape because the shape is @@ -1804,14 +1833,14 @@ static LogicalResult VerifyStridedSliceBase(OpTy op) { // If strides are constants, verify that none of the element is zero. DenseIntElementsAttr strides; - if (matchPattern(op.strides(), m_Constant(&strides))) { + if (matchPattern(op.getStrides(), m_Constant(&strides))) { if (llvm::is_contained(strides.getValues(), 0)) return op.emitOpError("requires non-zero strides"); } // Use bit compares to ensure ellipsis_mask is 0 or a power of 2, i.e. there // exists only no more than one ellipsis. - uint32_t ellipsis_mask = op.ellipsis_mask(); + uint32_t ellipsis_mask = op.getEllipsisMask(); if (ellipsis_mask != 0 && !llvm::isPowerOf2_32(ellipsis_mask)) return op.emitOpError("cannot have multiple ellipses"); @@ -2050,12 +2079,12 @@ bool StridedSliceOp::GetSlicedBoundRanges( // TODO(hinsu): Support lowering for ops with dynamic begin and end values // when it is possible to derive indices based on mask attributes. DenseIntElementsAttr sparse_begin_attr, sparse_end_attr, sparse_strides_attr; - if (!matchPattern(begin(), m_Constant(&sparse_begin_attr)) || - !matchPattern(end(), m_Constant(&sparse_end_attr)) || - !matchPattern(strides(), m_Constant(&sparse_strides_attr))) + if (!matchPattern(getBegin(), m_Constant(&sparse_begin_attr)) || + !matchPattern(getEnd(), m_Constant(&sparse_end_attr)) || + !matchPattern(getStrides(), m_Constant(&sparse_strides_attr))) return false; - auto input_ty = this->input().getType().dyn_cast(); + auto input_ty = this->getInput().getType().dyn_cast(); if (!input_ty || !input_ty.hasStaticShape()) return false; auto input_shape = llvm::to_vector<4>(input_ty.getShape()); @@ -2069,13 +2098,13 @@ bool StridedSliceOp::GetSlicedBoundRanges( sparse_strides.push_back(stride.getSExtValue()); CalculateSlicedShapeFromSparseIndices( - input_shape, sparse_begin, sparse_end, sparse_strides, begin_mask(), - end_mask(), ellipsis_mask(), new_axis_mask(), shrink_axis_mask(), + input_shape, sparse_begin, sparse_end, sparse_strides, getBeginMask(), + getEndMask(), getEllipsisMask(), getNewAxisMask(), getShrinkAxisMask(), slice_begin, slice_end, slice_stride); return true; } -OpFoldResult StridedSliceOp::fold(ArrayRef operands) { +OpFoldResult StridedSliceOp::fold(FoldAdaptor) { // Fold StridedSlice operation if it extracts statically known dimensions. // // For example, @@ -2093,7 +2122,7 @@ OpFoldResult StridedSliceOp::fold(ArrayRef operands) { // In this case %spatial_shape can be replaced with a constant [2, 3]. 
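// Aside (illustrative sketch, not from the TensorFlow sources): in the example
// above, a tensor of type tensor<?x2x3x1xf32> has shape [-1, 2, 3, 1] (with -1
// marking the dynamic dimension), so slicing that shape with begin = 1,
// end = 3, strides = 1 yields the constant [2, 3]. A standalone sketch of the
// dimension selection this fold performs once begin/end/strides are constants
// (mask handling omitted here):
#include <cstdint>
#include <vector>

std::vector<int64_t> SliceShapeDims(const std::vector<int64_t>& shape,
                                    int64_t begin, int64_t end,
                                    int64_t stride) {
  std::vector<int64_t> dims;
  for (int64_t i = begin; (stride > 0) ? (i < end) : (i > end); i += stride)
    dims.push_back(shape[i]);
  return dims;  // {-1, 2, 3, 1} sliced over [1, 3) with step 1 -> {2, 3}
}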
// Input to strided slice op is defined by shape operation. - auto shape_op = input().getDefiningOp(); + auto shape_op = getInput().getDefiningOp(); if (!shape_op) { return {}; } @@ -2101,9 +2130,9 @@ OpFoldResult StridedSliceOp::fold(ArrayRef operands) { // `begin`, `end` and `strides` should be constant in order to infer static // dimension. DenseIntElementsAttr begin_attr, end_attr, strides_attr; - if (!matchPattern(begin(), m_Constant(&begin_attr)) || - !matchPattern(end(), m_Constant(&end_attr)) || - !matchPattern(strides(), m_Constant(&strides_attr)) || + if (!matchPattern(getBegin(), m_Constant(&begin_attr)) || + !matchPattern(getEnd(), m_Constant(&end_attr)) || + !matchPattern(getStrides(), m_Constant(&strides_attr)) || begin_attr.getNumElements() != 1 || end_attr.getNumElements() != 1 || strides_attr.getNumElements() != 1) { return {}; @@ -2112,9 +2141,9 @@ OpFoldResult StridedSliceOp::fold(ArrayRef operands) { // Do not fold when `new_axis_mask` is set. It's likely to break the shape // of output. Typically, `new_axis_mask` is not set in this canonicalization // pattern. - if (new_axis_mask() != 0) return {}; + if (getNewAxisMask() != 0) return {}; - auto tensor_ty = shape_op.input().getType().dyn_cast(); + auto tensor_ty = shape_op.getInput().getType().dyn_cast(); // Only ranked tensor can be folded. if (!tensor_ty) return {}; @@ -2129,18 +2158,18 @@ OpFoldResult StridedSliceOp::fold(ArrayRef operands) { // Create `begin` and `end` from `*_mask`. Note that we don't care about // `new_axis_mask` as it can be inferred from `output_ty`. - if (shrink_axis_mask() == 1) { + if (getShrinkAxisMask() == 1) { // When `shrink_axis_mask` is set, output is always a scalar so only // one element is sliced. end_int = begin_int + 1; } - if (begin_mask() == 1) { + if (getBeginMask() == 1) { begin_int = (strides_int > 0) ? 0 : rank - 1; } - if (end_mask() == 1) { + if (getEndMask() == 1) { end_int = (strides_int > 0) ? rank : -1; } - if (ellipsis_mask() == 1) { + if (getEllipsisMask() == 1) { begin_int = 0; end_int = rank; } @@ -2172,10 +2201,11 @@ OpFoldResult StridedSliceOp::fold(ArrayRef operands) { // For unranked or dynamic output, we infer the output type to either a // scalar or a vector based on `shrink_axis_mask` because we have rejected // the case of `new_axis_mask` != 0. 
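// Aside (illustrative sketch, not from the TensorFlow sources): because the
// folded slice above operates on a rank-1 shape vector, only bit 0 of each mask
// can matter, which is why the code compares the masks against 1. The mask
// handling then reduces to simple bound rewrites; a standalone sketch:
#include <cstdint>

void ApplyStridedSliceMasks(int64_t rank, int64_t stride, uint32_t begin_mask,
                            uint32_t end_mask, uint32_t ellipsis_mask,
                            uint32_t shrink_axis_mask, int64_t& begin,
                            int64_t& end) {
  if (shrink_axis_mask == 1) end = begin + 1;  // scalar result: one element
  if (begin_mask == 1) begin = (stride > 0) ? 0 : rank - 1;
  if (end_mask == 1) end = (stride > 0) ? rank : -1;
  if (ellipsis_mask == 1) {  // an ellipsis over a rank-1 input spans everything
    begin = 0;
    end = rank;
  }
}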
- auto output_elt_ty = output().getType().cast().getElementType(); - auto output_ty = output().getType().dyn_cast(); + auto output_elt_ty = + getOutput().getType().cast().getElementType(); + auto output_ty = getOutput().getType().dyn_cast(); if (!output_ty || !output_ty.hasStaticShape()) { - if (shrink_axis_mask() == 1) { + if (getShrinkAxisMask() == 1) { output_ty = tensorflow::GetTypeFromTFTensorShape({}, output_elt_ty); } else { output_ty = tensorflow::GetTypeFromTFTensorShape( @@ -2199,7 +2229,7 @@ OpFoldResult StridedSliceOp::fold(ArrayRef operands) { LogicalResult StridedSliceGradOp::verify() { StridedSliceGradOp op = *this; - auto shape_type = op.shape().getType().dyn_cast(); + auto shape_type = op.getShape().getType().dyn_cast(); if (shape_type && shape_type.getRank() != 1) return op.emitOpError("'shape' operand must be 1D tensor, but got ") << shape_type.getRank() << "D tensor"; @@ -2218,10 +2248,10 @@ bool StridedSliceGradOp::GetSlicedShapeAndBoundRanges( SmallVectorImpl *slice_stride) { DenseIntElementsAttr shape_attr; DenseIntElementsAttr sparse_begin_attr, sparse_end_attr, sparse_strides_attr; - if (!matchPattern(shape(), m_Constant(&shape_attr)) || - !matchPattern(begin(), m_Constant(&sparse_begin_attr)) || - !matchPattern(end(), m_Constant(&sparse_end_attr)) || - !matchPattern(strides(), m_Constant(&sparse_strides_attr))) + if (!matchPattern(getShape(), m_Constant(&shape_attr)) || + !matchPattern(getBegin(), m_Constant(&sparse_begin_attr)) || + !matchPattern(getEnd(), m_Constant(&sparse_end_attr)) || + !matchPattern(getStrides(), m_Constant(&sparse_strides_attr))) return false; int rank = std::distance(shape_attr.begin(), shape_attr.end()); @@ -2241,8 +2271,8 @@ bool StridedSliceGradOp::GetSlicedShapeAndBoundRanges( sparse_strides.push_back(stride.getSExtValue()); CalculateSlicedShapeFromSparseIndices( - *input_shape, sparse_begin, sparse_end, sparse_strides, begin_mask(), - end_mask(), ellipsis_mask(), new_axis_mask(), shrink_axis_mask(), + *input_shape, sparse_begin, sparse_end, sparse_strides, getBeginMask(), + getEndMask(), getEllipsisMask(), getNewAxisMask(), getShrinkAxisMask(), slice_begin, slice_end, slice_stride); return true; } @@ -2256,9 +2286,9 @@ SummaryWriterOp::GetResourceHandleValueAndIdList( llvm::SmallDenseMap &resource_handle_id_map, int64_t &next_id) { llvm::StringRef device = GetDeviceOrEmpty(getOperation()); - return {GetResourceHandleValueAndIdBase(container(), shared_name(), device, - writer(), resource_handle_id_map, - next_id)}; + return {GetResourceHandleValueAndIdBase(getContainer(), getSharedName(), + device, getWriter(), + resource_handle_id_map, next_id)}; } //===----------------------------------------------------------------------===// @@ -2268,11 +2298,11 @@ SummaryWriterOp::GetResourceHandleValueAndIdList( void TPUExecuteOp::getEffects( SmallVectorImpl> &effects) { - effects.reserve(args().size() + 1); + effects.reserve(getArgs().size() + 1); effects.emplace_back(MemoryEffects::Write::get(), ResourceEffects::TPUExecute::get()); - for (Value value : args()) { + for (Value value : getArgs()) { if (value.getType() .cast() .getElementType() @@ -2297,7 +2327,7 @@ void TPUExecuteOp::getEffects( LogicalResult TPUExecuteAndUpdateVariablesOp::verify() { TPUExecuteAndUpdateVariablesOp op = *this; int num_resource_args = 0; - for (Type arg_type : op.args().getTypes()) + for (Type arg_type : op.getArgs().getTypes()) if (arg_type.cast().getElementType().isa()) ++num_resource_args; @@ -2323,19 +2353,19 @@ LogicalResult 
TPUExecuteAndUpdateVariablesOp::verify() { }; return failure( - failed(check_attr(op.device_var_reads_indices(), + failed(check_attr(op.getDeviceVarReadsIndices(), /*name=*/"device_var_reads_indices", /*min=*/0)) || - failed(check_attr(op.device_var_updates_indices(), + failed(check_attr(op.getDeviceVarUpdatesIndices(), /*name=*/"device_var_updates_indices", /*min=*/-1))); } void TPUExecuteAndUpdateVariablesOp::getEffects( SmallVectorImpl> &effects) { - effects.reserve(device_var_reads_indices().size() + 1); + effects.reserve(getDeviceVarReadsIndices().size() + 1); effects.emplace_back(MemoryEffects::Write::get(), ResourceEffects::TPUExecute::get()); - auto resource_handles = llvm::make_filter_range(args(), [](Value value) { + auto resource_handles = llvm::make_filter_range(getArgs(), [](Value value) { return value.getType() .cast() .getElementType() @@ -2346,7 +2376,7 @@ void TPUExecuteAndUpdateVariablesOp::getEffects( Value value = entry.value(); effects.emplace_back(MemoryEffects::Read::get(), value, ResourceEffects::Variable::get()); - if (device_var_updates_indices() + if (getDeviceVarUpdatesIndices() .getValue()[entry.index()] .cast() .getInt() >= 0) @@ -2371,7 +2401,7 @@ class ConvertTensorListGetItemOpOfTensorListFromTensorOpToGather // Checks that the input is created by TensorListFromTensorOp and the input // is only used by TensorListGetItemOp. auto tensor_list_from_tensor_op = dyn_cast_or_null( - op.input_handle().getDefiningOp()); + op.getInputHandle().getDefiningOp()); if (!tensor_list_from_tensor_op || llvm::any_of( tensor_list_from_tensor_op->getUsers(), @@ -2380,7 +2410,8 @@ class ConvertTensorListGetItemOpOfTensorListFromTensorOpToGather } rewriter.replaceOpWithNewOp( - op, op.getType(), tensor_list_from_tensor_op.tensor(), op.index()); + op, op.getType(), tensor_list_from_tensor_op.getTensor(), + op.getIndex()); return success(); } }; @@ -2405,12 +2436,12 @@ LogicalResult TensorListReserveOp::verify() { return emitOpError( "must have exactly one subtype in the result variant type"); } - if (!IsOfRankOrUnranked(op.element_shape(), 0) && - !IsOfRankOrUnranked(op.element_shape(), 1)) { + if (!IsOfRankOrUnranked(op.getElementShape(), 0) && + !IsOfRankOrUnranked(op.getElementShape(), 1)) { return op.emitOpError("requires element_shape operand to be 0D/1D tensor"); } - if (!IsOfRankOrUnranked(op.num_elements(), 0)) { + if (!IsOfRankOrUnranked(op.getNumElements(), 0)) { return op.emitOpError("requires num_elements operand to be 0D tensor"); } return success(); @@ -2420,7 +2451,7 @@ LogicalResult TensorListReserveOp::verify() { // TensorListElementShapeOp //===----------------------------------------------------------------------===// -OpFoldResult TensorListElementShapeOp::fold(ArrayRef operands) { +OpFoldResult TensorListElementShapeOp::fold(FoldAdaptor) { int width = getType().cast().getElementType().getIntOrFloatBitWidth(); auto variant_type = @@ -2435,8 +2466,8 @@ OpFoldResult TensorListElementShapeOp::fold(ArrayRef operands) { LogicalResult TensorListStackOp::verify() { TensorListStackOp op = *this; - if (!IsOfRankOrUnranked(op.element_shape(), 0) && - !IsOfRankOrUnranked(op.element_shape(), 1)) { + if (!IsOfRankOrUnranked(op.getElementShape(), 0) && + !IsOfRankOrUnranked(op.getElementShape(), 1)) { return op.emitOpError("requires element_shape operand to be 0D/1D tensor"); } return success(); @@ -2448,15 +2479,15 @@ LogicalResult TensorListStackOp::verify() { LogicalResult TensorScatterUpdateOp::verify() { TensorScatterUpdateOp op = *this; - if 
(!HasRankAtLeast(op.tensor(), 1)) + if (!HasRankAtLeast(op.getTensor(), 1)) return op.emitOpError( "requires tensor operand to have at least 1 dimension"); - if (!HasRankAtLeast(op.indices(), 1)) + if (!HasRankAtLeast(op.getIndices(), 1)) return op.emitOpError( "requires indices operand to have at least 1 dimension"); - auto tensor_ty = op.tensor().getType().dyn_cast(); - auto indices_ty = op.indices().getType().dyn_cast(); + auto tensor_ty = op.getTensor().getType().dyn_cast(); + auto indices_ty = op.getIndices().getType().dyn_cast(); if (!tensor_ty || !indices_ty) return success(); int64_t num_index_dims = indices_ty.getShape().back(); @@ -2485,9 +2516,10 @@ LogicalResult TensorScatterUpdateOp::verify() { LogicalResult TileOp::verify() { TileOp op = *this; - auto input_type = op.input().getType().dyn_cast(); - auto multiples_type = op.multiples().getType().dyn_cast(); - auto output_type = op.output().getType().dyn_cast(); + auto input_type = op.getInput().getType().dyn_cast(); + auto multiples_type = + op.getMultiples().getType().dyn_cast(); + auto output_type = op.getOutput().getType().dyn_cast(); if (multiples_type && multiples_type.getRank() != 1) { return op.emitOpError() << "expected multiples to be rank 1, got rank = " @@ -2512,7 +2544,7 @@ LogicalResult TileOp::verify() { } DenseIntElementsAttr multiples_attr; - if (matchPattern(op.multiples(), m_Constant(&multiples_attr))) { + if (matchPattern(op.getMultiples(), m_Constant(&multiples_attr))) { for (int32_t i = 0, e = input_type.getRank(); i < e; ++i) { const int64_t input_dim = input_type.getDimSize(i); const int64_t output_dim = output_type.getDimSize(i); @@ -2538,14 +2570,14 @@ LogicalResult TileOp::verify() { return success(); } -OpFoldResult TileOp::fold(ArrayRef operands) { +OpFoldResult TileOp::fold(FoldAdaptor) { DenseIntElementsAttr multiples_attr; - if (matchPattern(multiples(), m_Constant(&multiples_attr))) { + if (matchPattern(getMultiples(), m_Constant(&multiples_attr))) { // Return input directly when multiples are all ones, // regardless what input is. 
if (multiples_attr.isSplat() && multiples_attr.getSplatValue().getSExtValue() == 1) { - return input(); + return getInput(); } } return {}; @@ -2557,11 +2589,11 @@ OpFoldResult TileOp::fold(ArrayRef operands) { LogicalResult TopKV2Op::verify() { TopKV2Op op = *this; - if (!HasRankAtLeast(op.input(), 1)) + if (!HasRankAtLeast(op.getInput(), 1)) return op.emitOpError( "requires input operand to have at least 1 dimension"); - if (!IsOfRankOrUnranked(op.k(), 0)) + if (!IsOfRankOrUnranked(op.getK(), 0)) return op.emitOpError("requires k operand to be 0D tensor"); return success(); @@ -2627,7 +2659,7 @@ void ToBoolOp::getCanonicalizationPatterns(RewritePatternSet &results, } LogicalResult ToBoolOp::inferReturnTypes( - MLIRContext *context, Optional location, ValueRange operands, + MLIRContext *context, std::optional location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { inferredReturnTypes.push_back( @@ -2635,15 +2667,41 @@ LogicalResult ToBoolOp::inferReturnTypes( return success(); } +//===----------------------------------------------------------------------===// +// TPUPartitionedInputV2 +//===----------------------------------------------------------------------===// + +// This method mimics this op's core/TF-level shape inference logic +LogicalResult TPUPartitionedInputV2Op::verify() { + TPUPartitionedInputV2Op op = *this; + + int num_partitions = 1; + const mlir::ArrayAttr partition_dims = op.getPartitionDims(); + for (const mlir::Attribute &dim : partition_dims) { + num_partitions *= dim.cast().getInt(); + } + + const bool is_packed = op.getIsPacked(); + const bool replicated = partition_dims.empty(); + const int num_inputs_expected = is_packed ? 1 : num_partitions; + + if (!((replicated && !is_packed) || (op.getN() == num_inputs_expected))) { + return op.emitOpError() << "expected " << num_inputs_expected + << " inputs, got " << op.getN(); + } + + return success(); +} + //===----------------------------------------------------------------------===// // TransposeOp //===----------------------------------------------------------------------===// LogicalResult TransposeOp::verify() { TransposeOp op = *this; - auto perm_type = op.perm().getType().dyn_cast(); - auto x_type = op.x().getType().dyn_cast(); - auto y_type = op.y().getType().dyn_cast(); + auto perm_type = op.getPerm().getType().dyn_cast(); + auto x_type = op.getX().getType().dyn_cast(); + auto y_type = op.getY().getType().dyn_cast(); if (perm_type && perm_type.getRank() != 1) { return op.emitOpError() @@ -2669,7 +2727,7 @@ LogicalResult TransposeOp::verify() { } DenseIntElementsAttr attr_perm; - if (matchPattern(op.perm(), m_Constant(&attr_perm))) { + if (matchPattern(op.getPerm(), m_Constant(&attr_perm))) { // y.shape[i] should be equal to x.shape[perm[i]] // for i = [0, 1, ..., rank(x) - 1] for (auto e : llvm::enumerate(attr_perm)) { @@ -2728,7 +2786,7 @@ namespace { OpFoldResult FoldIdentityTranspose(TransposeOp op) { DenseIntElementsAttr perm; - if (!matchPattern(op.perm(), m_Constant(&perm))) return {}; + if (!matchPattern(op.getPerm(), m_Constant(&perm))) return {}; const auto elements = perm.getValues(); for (auto it : llvm::enumerate(elements)) { @@ -2736,37 +2794,37 @@ OpFoldResult FoldIdentityTranspose(TransposeOp op) { } // TODO(jpienaar): Remove if/when we handle this more generally. - if (op.getType() != op.x().getType()) { + if (op.getType() != op.getX().getType()) { // If the types don't match then only fold if all the operands are in the TF // dialect. 
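// Aside (illustrative sketch, not from the TensorFlow sources): the
// TPUPartitionedInputV2 verifier added earlier in this hunk expects the input
// count N to equal the product of partition_dims, except that a packed op
// shares a single input and a replicated op (empty partition_dims) accepts any
// unpacked input count. A standalone sketch of that arithmetic:
#include <cstdint>
#include <vector>

int64_t ExpectedPartitionedInputs(const std::vector<int64_t>& partition_dims,
                                  bool is_packed) {
  int64_t num_partitions = 1;
  for (int64_t dim : partition_dims) num_partitions *= dim;
  return is_packed ? 1 : num_partitions;
}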
for (auto user : op.getOperation()->getUsers()) if (user->getDialect() != op->getDialect()) return {}; } - return op.x(); + return op.getX(); } OpFoldResult FoldCancellableTranspose(TransposeOp op) { // Operand is a TransposeOp. - auto transpose = dyn_cast_or_null(op.x().getDefiningOp()); + auto transpose = dyn_cast_or_null(op.getX().getDefiningOp()); if (!transpose) return {}; // Permutations defined by constant operations. DenseIntElementsAttr perm0; DenseIntElementsAttr perm1; - if (!matchPattern(op.perm(), m_Constant(&perm0)) || - !matchPattern(transpose.perm(), m_Constant(&perm1))) + if (!matchPattern(op.getPerm(), m_Constant(&perm0)) || + !matchPattern(transpose.getPerm(), m_Constant(&perm1))) return {}; // With permutation indices that cancel each other if (!AreCancellablePermutations(perm0, perm1)) return {}; - return transpose.x(); + return transpose.getX(); } } // namespace -OpFoldResult TransposeOp::fold(ArrayRef operands) { +OpFoldResult TransposeOp::fold(FoldAdaptor) { if (auto folded = FoldIdentityTranspose(*this)) return folded; if (auto folded = FoldCancellableTranspose(*this)) return folded; return {}; @@ -2804,9 +2862,9 @@ class NMSV3ToNMSV4Op : public OpRewritePattern { new_result_types.push_back(valid_output_type); auto nmsv4 = rewriter.create( - nms_op.getLoc(), new_result_types, nms_op.boxes(), nms_op.scores(), - nms_op.max_output_size(), nms_op.iou_threshold(), - nms_op.score_threshold()); + nms_op.getLoc(), new_result_types, nms_op.getBoxes(), + nms_op.getScores(), nms_op.getMaxOutputSize(), nms_op.getIouThreshold(), + nms_op.getScoreThreshold()); // Cannot replace the NMSv3 Op with NMSv4 since the outputs between the // two are different (v4 expects two output values vs v3 requires only one. nms_op.replaceAllUsesWith(nmsv4.getResult(0)); @@ -2861,11 +2919,11 @@ void FusedBatchNormOp::getCanonicalizationPatterns(RewritePatternSet &results, LogicalResult UnpackOp::verify() { UnpackOp op = *this; - auto value_type = op.value().getType().dyn_cast(); + auto value_type = op.getValue().getType().dyn_cast(); if (!value_type) return success(); int64_t value_rank = value_type.getRank(); - int64_t axis = op.axis(); + int64_t axis = op.getAxis(); if (axis < -value_rank || axis >= value_rank) return op.emitOpError("axis attribute must be in the range of [-") << value_rank << ", " << value_rank << ')'; @@ -2924,7 +2982,7 @@ LogicalResult HoistCwiseUnaryOutOfUnpack::matchAndRewrite( // Unpack results after applying unary operation. auto unpack_unary_op = rewriter.create( - loc, op.getResultTypes(), new_unary_op->getResult(0), op.axis()); + loc, op.getResultTypes(), new_unary_op->getResult(0), op.getAxis()); // Bypass all users of the original unpack operation and use `unpack_unary_op` // results instead. 
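// Aside (illustrative sketch, not from the TensorFlow sources):
// FoldCancellableTranspose above folds tf.Transpose(tf.Transpose(x, perm1),
// perm0) back to x when the outer permutation undoes the inner one, i.e. the
// two compose to the identity. A standalone sketch of that check, assuming both
// permutations have already been verified to have the same length:
#include <cstdint>
#include <vector>

bool PermutationsCancel(const std::vector<int64_t>& perm0,
                        const std::vector<int64_t>& perm1) {
  for (size_t i = 0; i < perm0.size(); ++i)
    if (perm1[perm0[i]] != static_cast<int64_t>(i)) return false;
  return true;  // e.g. perm0 = perm1 = {0, 2, 1} compose to the identity
}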
@@ -2954,12 +3012,12 @@ void UnpackOp::getCanonicalizationPatterns(RewritePatternSet &results, template static LogicalResult VerifyUnsortedSegmentReduction(Op op) { - if (!HasRankAtMost(op.num_segments(), 0)) + if (!HasRankAtMost(op.getNumSegments(), 0)) return op.emitOpError("number of segments should be a 0-D tensor"); - auto data_type = op.data().getType().template dyn_cast(); + auto data_type = op.getData().getType().template dyn_cast(); auto segment_ids_type = - op.segment_ids().getType().template dyn_cast(); + op.getSegmentIds().getType().template dyn_cast(); if (data_type && segment_ids_type) { if (data_type.getRank() < segment_ids_type.getRank()) return op.emitOpError( @@ -2982,7 +3040,7 @@ static LogicalResult VerifyUnsortedSegmentReduction(Op op) { } DenseIntElementsAttr num_segments_attr; - if (matchPattern(op.num_segments(), m_Constant(&num_segments_attr))) { + if (matchPattern(op.getNumSegments(), m_Constant(&num_segments_attr))) { int64_t num_segments = (*num_segments_attr.begin()).getSExtValue(); if (num_segments < 0) return op.emitOpError("num of segments cannot be negative"); @@ -3024,9 +3082,9 @@ VarHandleOp::GetResourceHandleValueAndIdList( llvm::SmallDenseMap &resource_handle_id_map, int64_t &next_id) { llvm::StringRef device = GetDeviceOrEmpty(getOperation()); - return {GetResourceHandleValueAndIdBase(container(), shared_name(), device, - resource(), resource_handle_id_map, - next_id)}; + return {GetResourceHandleValueAndIdBase(getContainer(), getSharedName(), + device, getResource(), + resource_handle_id_map, next_id)}; } //===----------------------------------------------------------------------===// @@ -3070,7 +3128,7 @@ void VariableOp::getCanonicalizationPatterns(RewritePatternSet &results, LogicalResult VariableShapeOp::verify() { VariableShapeOp op = *this; - auto input_type = op.input().getType().cast(); + auto input_type = op.getInput().getType().cast(); if (input_type.hasStaticShape() && input_type.getNumElements() != 1) return op.emitOpError("requires input to have one resource"); @@ -3088,7 +3146,7 @@ LogicalResult VariableShapeOp::verify() { } } -OpFoldResult VariableShapeOp::fold(ArrayRef operands) { +OpFoldResult VariableShapeOp::fold(FoldAdaptor) { int width = getType().cast().getElementType().getIntOrFloatBitWidth(); auto resource_type = @@ -3162,14 +3220,14 @@ static LogicalResult VerifyWhileTypes(Operation *op, TypeRange cond_input, LogicalResult WhileOp::verifySymbolUses(SymbolTableCollection &symbol_table) { auto cond_fn = - symbol_table.lookupNearestSymbolFrom(*this, condAttr()); + symbol_table.lookupNearestSymbolFrom(*this, getCondAttr()); auto body_fn = - symbol_table.lookupNearestSymbolFrom(*this, bodyAttr()); + symbol_table.lookupNearestSymbolFrom(*this, getBodyAttr()); if (!cond_fn) { - return emitOpError("cond refers to an undefined function : ") << cond(); + return emitOpError("cond refers to an undefined function : ") << getCond(); } if (!body_fn) { - return emitOpError("body refers to an undefined function : ") << body(); + return emitOpError("body refers to an undefined function : ") << getBody(); } auto cond_fn_type = cond_fn.getFunctionType(); @@ -3182,7 +3240,7 @@ LogicalResult WhileOp::verifySymbolUses(SymbolTableCollection &symbol_table) { return VerifyWhileTypes(*this, /*cond_input=*/cond_fn_type.getInputs(), /*body_input=*/body_fn_type.getInputs(), /*body_result=*/body_fn_type.getResults(), - shape_invariant()); + getShapeInvariant()); } //===----------------------------------------------------------------------===// @@ -3191,7 
+3249,7 @@ LogicalResult WhileOp::verifySymbolUses(SymbolTableCollection &symbol_table) { LogicalResult WhileRegionOp::verify() { WhileRegionOp op = *this; // Verify that the condition generates a single tensor result. - Operation *cond_yield = op.cond().front().getTerminator(); + Operation *cond_yield = op.getCond().front().getTerminator(); if (cond_yield->getNumOperands() != 1) return op.emitOpError() << "condition should have a single tensor result"; @@ -3203,11 +3261,12 @@ LogicalResult WhileRegionOp::verify() { return op.emitOpError() << "condition should have a single tensor result"; - Operation *body_yield = op.body().front().getTerminator(); - if (failed(VerifyWhileTypes(op, /*cond_input=*/op.cond().getArgumentTypes(), - /*body_input=*/op.body().getArgumentTypes(), + Operation *body_yield = op.getBody().front().getTerminator(); + if (failed(VerifyWhileTypes(op, + /*cond_input=*/op.getCond().getArgumentTypes(), + /*body_input=*/op.getBody().getArgumentTypes(), /*body_result=*/body_yield->getOperandTypes(), - op.shape_invariant()))) + op.getShapeInvariant()))) return failure(); return success(); } @@ -3216,7 +3275,7 @@ LogicalResult WhileRegionOp::verify() { // WhileRegionOp LoopLikeOpInterface //===----------------------------------------------------------------------===// -Region &WhileRegionOp::getLoopBody() { return body(); } +Region &WhileRegionOp::getLoopBody() { return getBody(); } //===----------------------------------------------------------------------===// // WhileRegionOp canonicalization @@ -3232,8 +3291,8 @@ struct WhileRegionExplicitCast : public OpRewritePattern { LogicalResult matchAndRewrite(WhileRegionOp while_op, PatternRewriter &rewriter) const override { - auto &body_block = while_op.body().front(); - auto &cond_block = while_op.cond().front(); + auto &body_block = while_op.getBody().front(); + auto &cond_block = while_op.getCond().front(); bool changed = false; for (int op_idx : llvm::seq(0, while_op.getNumOperands())) { auto body_arg = body_block.getArgument(op_idx); @@ -3267,8 +3326,8 @@ struct WhileRegionEliminatePassThrough // argument can be easily found. int old_num_operands = while_op.getNumOperands(); int new_num_operands = old_num_operands; - auto &body_block = while_op.body().front(); - auto &cond_block = while_op.cond().front(); + auto &body_block = while_op.getBody().front(); + auto &cond_block = while_op.getCond().front(); auto &yield = *body_block.getTerminator(); // Bit mask indicating which operands will be removed. @@ -3326,13 +3385,13 @@ struct WhileRegionEliminatePassThrough while_op->getAttrs()); // Move region bodies to the new while. - rewriter.inlineRegionBefore(while_op.cond(), new_while_op.cond(), - new_while_op.cond().end()); - rewriter.inlineRegionBefore(while_op.body(), new_while_op.body(), - new_while_op.body().end()); + rewriter.inlineRegionBefore(while_op.getCond(), new_while_op.getCond(), + new_while_op.getCond().end()); + rewriter.inlineRegionBefore(while_op.getBody(), new_while_op.getBody(), + new_while_op.getBody().end()); - auto &new_cond_block = new_while_op.cond().front(); - auto &new_body_block = new_while_op.body().front(); + auto &new_cond_block = new_while_op.getCond().front(); + auto &new_body_block = new_while_op.getBody().front(); auto &new_yield = *new_body_block.getTerminator(); // Patch up the region bodies and yield. 
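// Aside (illustrative sketch, not from the TensorFlow sources): the
// pass-through elimination above removes loop-carried values that the body
// returns unchanged; operand i is a candidate when the body terminator yields
// the body block's own i-th argument, so uses of the corresponding while result
// can be rewired to the original operand. A standalone sketch of that
// bookkeeping over plain indices, where yielded_arg[i] is the block-argument
// index yielded at position i (or -1 when the yielded value is not a block
// argument):
#include <vector>

std::vector<bool> PassThroughOperandMask(const std::vector<int>& yielded_arg) {
  std::vector<bool> removable(yielded_arg.size(), false);
  for (int i = 0; i < static_cast<int>(yielded_arg.size()); ++i)
    removable[i] = (yielded_arg[i] == i);
  return removable;
}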
@@ -3375,12 +3434,12 @@ void XdivyOp::getCanonicalizationPatterns(RewritePatternSet &results, //===----------------------------------------------------------------------===// LogicalResult XlaBroadcastHelperOp::inferReturnTypeComponents( - MLIRContext *context, Optional location, ValueShapeRange operands, - DictionaryAttr attributes, RegionRange regions, + MLIRContext *context, std::optional location, + ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { XlaBroadcastHelperOpAdaptor op(operands.getValues(), attributes); - Value lhs = op.lhs(); - Value rhs = op.rhs(); + Value lhs = op.getLhs(); + Value rhs = op.getRhs(); auto set_unranked_results = [&]() { inferredReturnShapes.emplace_back(getElementTypeOrSelf(lhs)); inferredReturnShapes.emplace_back(getElementTypeOrSelf(rhs)); @@ -3395,7 +3454,7 @@ LogicalResult XlaBroadcastHelperOp::inferReturnTypeComponents( int64_t rhs_rank = rhs_ty.getRank(); DenseIntElementsAttr dims; - if (!matchPattern(op.broadcast_dims(), m_Constant(&dims))) { + if (!matchPattern(op.getBroadcastDims(), m_Constant(&dims))) { return set_unranked_results(); } @@ -3459,10 +3518,10 @@ class XlaConvToV2 : public OpRewritePattern { PatternRewriter &rewriter) const override { SmallVector result_types{op.getResult().getType()}; rewriter.replaceOpWithNewOp( - op, op.getResult().getType(), op.lhs(), op.rhs(), op.window_strides(), - op.padding(), op.lhs_dilation(), op.rhs_dilation(), - op.feature_group_count(), op.dimension_numbers(), op.precision_config(), - 1); + op, op.getResult().getType(), op.getLhs(), op.getRhs(), + op.getWindowStrides(), op.getPadding(), op.getLhsDilation(), + op.getRhsDilation(), op.getFeatureGroupCount(), + op.getDimensionNumbers(), op.getPrecisionConfig(), 1); return ::mlir::success(); }; }; @@ -3480,11 +3539,11 @@ LogicalResult XlaConvV2Op::verify() { XlaConvV2Op op = *this; DenseElementsAttr window_strides_attr, padding_attr, lhs_dilation_attr, rhs_dilation_attr, feature_group_count_attr; - if (!(matchPattern(op.window_strides(), m_Constant(&window_strides_attr)) && - matchPattern(op.padding(), m_Constant(&padding_attr)) && - matchPattern(op.lhs_dilation(), m_Constant(&lhs_dilation_attr)) && - matchPattern(op.rhs_dilation(), m_Constant(&rhs_dilation_attr)) && - matchPattern(op.feature_group_count(), + if (!(matchPattern(op.getWindowStrides(), m_Constant(&window_strides_attr)) && + matchPattern(op.getPadding(), m_Constant(&padding_attr)) && + matchPattern(op.getLhsDilation(), m_Constant(&lhs_dilation_attr)) && + matchPattern(op.getRhsDilation(), m_Constant(&rhs_dilation_attr)) && + matchPattern(op.getFeatureGroupCount(), m_Constant(&feature_group_count_attr)))) return success(); @@ -3513,12 +3572,12 @@ LogicalResult XlaConvV2Op::verify() { //===----------------------------------------------------------------------===// LogicalResult XlaSetDynamicDimensionSizeOp::inferReturnTypeComponents( - MLIRContext *context, Optional location, ValueShapeRange operands, - DictionaryAttr attributes, RegionRange regions, + MLIRContext *context, std::optional location, + ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { XlaSetDynamicDimensionSizeOpAdaptor op(operands.getValues(), attributes); - TensorType operand_ty = op.input().getType().cast(); + TensorType operand_ty = op.getInput().getType().cast(); Type element_ty = operand_ty.getElementType(); TensorType result_ty; @@ -3526,7 +3585,7 @@ LogicalResult 
XlaSetDynamicDimensionSizeOp::inferReturnTypeComponents( auto shape = llvm::to_vector<4>(operand_ty.getShape()); DenseIntElementsAttr dim_index_attr; - if (matchPattern(op.dim_index(), m_Constant(&dim_index_attr))) { + if (matchPattern(op.getDimIndex(), m_Constant(&dim_index_attr))) { int64_t dim_index = dim_index_attr.getValues()[0].getSExtValue(); int64_t rank = operand_ty.getRank(); @@ -3534,9 +3593,9 @@ LogicalResult XlaSetDynamicDimensionSizeOp::inferReturnTypeComponents( return emitOptionalError(location, "dim_index (", dim_index, ") is out of range [0, ", rank, ")"); } - shape[dim_index] = ShapedType::kDynamicSize; + shape[dim_index] = ShapedType::kDynamic; } else { - shape.assign(shape.size(), ShapedType::kDynamicSize); + shape.assign(shape.size(), ShapedType::kDynamic); } result_ty = tensorflow::GetTypeFromTFTensorShape(shape, element_ty); } else { @@ -3558,12 +3617,12 @@ class XlaReduceToXlaVariadicReduceV2 LogicalResult matchAndRewrite(TF::XlaReduceOp op, PatternRewriter &rewriter) const override { - SmallVector inputs{op.input()}; - SmallVector init_values{op.init_value()}; + SmallVector inputs{op.getInput()}; + SmallVector init_values{op.getInitValue()}; SmallVector result_types{op.getResult().getType()}; rewriter.replaceOpWithNewOp( - op, result_types, inputs, init_values, op.dimensions_to_reduce(), - op.reducer()); + op, result_types, inputs, init_values, op.getDimensionsToReduce(), + op.getReducer()); return ::mlir::success(); }; }; @@ -3579,7 +3638,7 @@ void XlaReduceOp::getCanonicalizationPatterns(RewritePatternSet &results, LogicalResult XlaReduceWindowOp::verify() { XlaReduceWindowOp op = *this; - const auto &input_ty = op.input().getType().cast(); + const auto &input_ty = op.getInput().getType().cast(); auto check = [&](mlir::Value val, std::string attr_name) -> LogicalResult { ElementsAttr attr; @@ -3602,18 +3661,18 @@ LogicalResult XlaReduceWindowOp::verify() { return success(); }; - if (check(op.window_dimensions(), "window_dimensions").failed()) + if (check(op.getWindowDimensions(), "window_dimensions").failed()) return failure(); - if (check(op.window_strides(), "window_strides").failed()) return failure(); + if (check(op.getWindowStrides(), "window_strides").failed()) return failure(); - if (check(op.base_dilations(), "base_dilations").failed()) return failure(); + if (check(op.getBaseDilations(), "base_dilations").failed()) return failure(); - if (check(op.window_dilations(), "window_dilations").failed()) + if (check(op.getWindowDilations(), "window_dilations").failed()) return failure(); ElementsAttr padding; - if (matchPattern(op.padding(), m_Constant(&padding))) { + if (matchPattern(op.getPadding(), m_Constant(&padding))) { const ShapedType &padding_ty = padding.getType(); if (padding_ty.getRank() != 2 || padding_ty.getDimSize(1) != 2) { return op.emitOpError() @@ -3624,7 +3683,7 @@ LogicalResult XlaReduceWindowOp::verify() { auto module = op->getParentOfType(); auto func = dyn_cast_or_null( - SymbolTable::lookupSymbolIn(module, op.computation())); + SymbolTable::lookupSymbolIn(module, op.getComputation())); if (!func) { return op.emitOpError() << "has no reduction function specified"; } @@ -3647,7 +3706,7 @@ LogicalResult XlaReduceWindowOp::verify() { LogicalResult XlaSelectAndScatterOp::verify() { XlaSelectAndScatterOp op = *this; - auto input_ty = op.operand().getType().cast(); + auto input_ty = op.getOperand().getType().cast(); auto check = [&](mlir::Value val, std::string attr_name) -> LogicalResult { ElementsAttr attr; @@ -3664,13 +3723,13 @@ 
LogicalResult XlaSelectAndScatterOp::verify() { return success(); }; - if (check(op.window_dimensions(), "window_dimensions").failed()) + if (check(op.getWindowDimensions(), "window_dimensions").failed()) return failure(); - if (check(op.window_strides(), "window_strides").failed()) return failure(); + if (check(op.getWindowStrides(), "window_strides").failed()) return failure(); ElementsAttr padding; - if (matchPattern(op.padding(), m_Constant(&padding))) { + if (matchPattern(op.getPadding(), m_Constant(&padding))) { const ShapedType &padding_ty = padding.getType(); if (padding_ty.getRank() != 2 || padding_ty.getDimSize(1) != 2) { return op.emitOpError() @@ -3681,7 +3740,7 @@ LogicalResult XlaSelectAndScatterOp::verify() { auto module = op->getParentOfType(); auto select_func = dyn_cast_or_null( - SymbolTable::lookupSymbolIn(module, op.select())); + SymbolTable::lookupSymbolIn(module, op.getSelect())); if (!select_func) { return op.emitOpError() << "has no select function specified"; } @@ -3698,7 +3757,7 @@ LogicalResult XlaSelectAndScatterOp::verify() { << select_func_type.getResult(0); } auto scatter_func = dyn_cast_or_null( - SymbolTable::lookupSymbolIn(module, op.scatter())); + SymbolTable::lookupSymbolIn(module, op.getScatter())); if (!scatter_func) { return op.emitOpError() << "has no scatter function specified"; } @@ -3719,7 +3778,7 @@ LogicalResult XlaSelectAndScatterOp::verify() { LogicalResult XlaVariadicReduceOp::verify() { XlaVariadicReduceOp op = *this; // We rely on V2 for the majority of the checks. - const auto &input_ty = op.input().getType(); + const auto &input_ty = op.getInput().getType(); if (input_ty.empty()) return op.emitOpError() << "No input"; const auto &dtype = input_ty[0].cast().getElementType(); for (const auto &ty : input_ty) { @@ -3738,8 +3797,8 @@ class XlaVariadicReduceToV2 : public OpRewritePattern { PatternRewriter &rewriter) const override { mlir::TF::XlaVariadicReduceV2Op xla_variadic_reduce_v2_op = rewriter.create<::mlir::TF::XlaVariadicReduceV2Op>( - op.getLoc(), op.getResults().getTypes(), op.input(), - op.init_value(), op.dimensions_to_reduce(), op.reducer()); + op.getLoc(), op.getResults().getTypes(), op.getInput(), + op.getInitValue(), op.getDimensionsToReduce(), op.getReducer()); rewriter.replaceOp(op, xla_variadic_reduce_v2_op.getResults()); return ::mlir::success(); @@ -3757,11 +3816,11 @@ void XlaVariadicReduceOp::getCanonicalizationPatterns( LogicalResult XlaVariadicReduceV2Op::verify() { XlaVariadicReduceV2Op op = *this; - const auto &inputs_ty = op.inputs().getType(); + const auto &inputs_ty = op.getInputs().getType(); int n_inputs = inputs_ty.size(); if (n_inputs < 1) return op.emitOpError() << "No inputs"; - const auto &init_values_ty = op.init_values().getType(); + const auto &init_values_ty = op.getInitValues().getType(); int n_init_values = init_values_ty.size(); if (n_init_values != n_inputs) { return op.emitOpError() << "Number of inputs (" << n_inputs @@ -3782,7 +3841,7 @@ LogicalResult XlaVariadicReduceV2Op::verify() { } } - if (op.dimensions_to_reduce().size() > input_ty_0.getRank()) { + if (op.getDimensionsToReduce().size() > input_ty_0.getRank()) { return op.emitOpError() << "Invalid dimensions_to_reduce argument to XlaVariadicReduceV2"; } @@ -3799,7 +3858,7 @@ LogicalResult XlaVariadicReduceV2Op::verify() { auto module = op->getParentOfType(); auto function = dyn_cast_or_null( - SymbolTable::lookupSymbolIn(module, op.reducer())); + SymbolTable::lookupSymbolIn(module, op.getReducer())); if (!function) return op.emitOpError() 
<< "No reducer"; if (!function.getBody().hasOneBlock()) return op.emitOpError() << "reducer has more than one block"; @@ -3813,7 +3872,7 @@ LogicalResult XlaVariadicReduceV2Op::verify() { LogicalResult XlaVariadicSortOp::verify() { XlaVariadicSortOp op = *this; - const auto &inputs_ty = op.inputs().getType(); + const auto &inputs_ty = op.getInputs().getType(); int n_inputs = inputs_ty.size(); auto input_ty_0 = inputs_ty[0].cast(); if (input_ty_0.hasStaticShape()) { @@ -3830,7 +3889,7 @@ LogicalResult XlaVariadicSortOp::verify() { } ElementsAttr dimension; - if (matchPattern(op.dimension(), m_Constant(&dimension))) { + if (matchPattern(op.getDimension(), m_Constant(&dimension))) { if (dimension.getType().getRank() != 0 || dimension.getType().getNumElements() != 1) return op.emitOpError() << "dimension must be a scalar"; @@ -3838,7 +3897,7 @@ LogicalResult XlaVariadicSortOp::verify() { auto module = op->getParentOfType(); auto function = dyn_cast_or_null( - SymbolTable::lookupSymbolIn(module, op.comparator())); + SymbolTable::lookupSymbolIn(module, op.getComparator())); if (!function) return op.emitOpError() << "No comparator"; if (!function.getBody().hasOneBlock()) return op.emitOpError() << "comparator has more than one block"; @@ -3853,9 +3912,10 @@ LogicalResult XlaVariadicSortOp::verify() { LogicalResult SetStaticDimensionBoundsOp::verify() { SetStaticDimensionBoundsOp op = *this; - mlir::ShapedType input_type = op.input().getType().cast(); + mlir::ShapedType input_type = + op.getInput().getType().cast(); mlir::ShapedType static_shape_type = - op.static_shape().getType().cast(); + op.getStaticShape().getType().cast(); int input_type_rank = input_type.hasRank() ? input_type.getRank() : -1; if (input_type_rank > 2) { return op.emitOpError() << "was used with an input tensor with rank > 2, " @@ -3876,6 +3936,134 @@ LogicalResult SetStaticDimensionBoundsOp::verify() { return success(); } +namespace { + +template +LogicalResult VerifyScalesAndZeroPoints(UniformQuantizedOp op, Value scales, + Value zero_points, + int32_t quantization_axis) { + ShapedType scales_type = scales.getType().cast(); + ShapedType zero_points_type = zero_points.getType().cast(); + + if (quantization_axis == -1) { + if (scales_type.hasRank() && scales_type.getRank() != 0) { + return op.emitOpError( + "quantization_axis is -1, scales must have 0 rank."); + } + if (zero_points_type.hasRank() && zero_points_type.getRank() != 0) { + return op.emitOpError( + "quantization_axis is -1, zero_points must have 0 rank."); + } + } else { + if (scales_type.hasRank() && scales_type.getRank() != 1) { + return op.emitOpError( + "quantization_axis is not -1, scales must have 1 rank."); + } + if (zero_points_type.hasRank() && zero_points_type.getRank() != 1) { + return op.emitOpError( + "quantization_axis is not -1, zero_points must have 1 rank."); + } + if (scales_type.hasStaticShape() && zero_points_type.hasStaticShape() && + scales_type.getNumElements() != zero_points_type.getNumElements()) { + return op.emitOpError( + "scales and zero points must have same number of elements."); + } + } + + return success(); +} + +} // namespace + +//===----------------------------------------------------------------------===// +// UniformQuantizedDotHybridOp +//===----------------------------------------------------------------------===// +// + +LogicalResult UniformQuantizedDotHybridOp::verify() { + UniformQuantizedDotHybridOp op = *this; + return VerifyScalesAndZeroPoints(op, op.getRhsScales(), op.getRhsZeroPoints(), + 
op.getRhsQuantizationAxis()); +} + +//===----------------------------------------------------------------------===// +// UniformQuantizedConvolutionHybridOp +//===----------------------------------------------------------------------===// +// + +LogicalResult UniformQuantizedConvolutionHybridOp::verify() { + UniformQuantizedConvolutionHybridOp op = *this; + return VerifyScalesAndZeroPoints(op, op.getRhsScales(), op.getRhsZeroPoints(), + op.getRhsQuantizationAxis()); +} + +//===----------------------------------------------------------------------===// +// UniformQuantizeOp +//===----------------------------------------------------------------------===// +// + +LogicalResult UniformQuantizeOp::verify() { + UniformQuantizeOp op = *this; + return VerifyScalesAndZeroPoints(op, op.getScales(), op.getZeroPoints(), + op.getQuantizationAxis()); +} + +//===----------------------------------------------------------------------===// +// UniformRequantizeOp +//===----------------------------------------------------------------------===// +// + +LogicalResult UniformRequantizeOp::verify() { + UniformRequantizeOp op = *this; + auto verify_input_params = VerifyScalesAndZeroPoints( + op, op.getInputScales(), op.getInputZeroPoints(), + op.getInputQuantizationAxis()); + if (failed(verify_input_params)) { + return failure(); + } + return VerifyScalesAndZeroPoints(op, op.getOutputScales(), + op.getOutputZeroPoints(), + op.getOutputQuantizationAxis()); +} + +//===----------------------------------------------------------------------===// +// UniformDequantizeOp +//===----------------------------------------------------------------------===// +// + +LogicalResult UniformDequantizeOp::verify() { + UniformDequantizeOp op = *this; + return VerifyScalesAndZeroPoints(op, op.getScales(), op.getZeroPoints(), + op.getQuantizationAxis()); +} + +//===----------------------------------------------------------------------===// +// UniformQuantizedDotOp +//===----------------------------------------------------------------------===// +// + +LogicalResult UniformQuantizedDotOp::verify() { + UniformQuantizedDotOp op = *this; + + auto verify_lhs_params = + VerifyScalesAndZeroPoints(op, op.getLhsScales(), op.getLhsZeroPoints(), + op.getLhsQuantizationAxis()); + if (failed(verify_lhs_params)) { + return failure(); + } + + auto verify_rhs_params = + VerifyScalesAndZeroPoints(op, op.getRhsScales(), op.getRhsZeroPoints(), + op.getRhsQuantizationAxis()); + if (failed(verify_rhs_params)) { + return failure(); + } + + return VerifyScalesAndZeroPoints(op, op.getOutputScales(), + op.getOutputZeroPoints(), + op.getOutputQuantizationAxis()); +} + } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_tensor_helper.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_tensor_helper.cc index b2984078afd..24036d17b58 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_tensor_helper.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_tensor_helper.cc @@ -77,7 +77,7 @@ Type InferReductionOpType(Value input, Value reduction_indices, // Otherwise, output type has same rank as the input. return RankedTensorType::get( - SmallVector(rank, ShapedType::kDynamicSize), element_ty); + SmallVector(rank, ShapedType::kDynamic), element_ty); } int64_t num_reduce_dim = 0; @@ -112,9 +112,9 @@ Type InferReductionOpType(Value input, Value reduction_indices, // rank and match dimension sizes for all but one of the dimensions. 
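// Aside (illustrative sketch, not from the TensorFlow sources):
// VerifyScalesAndZeroPoints above encodes the usual uniform-quantization
// convention: quantization_axis == -1 means per-tensor parameters (rank-0
// scales and zero_points), any other axis means per-axis parameters (rank-1,
// with matching element counts). A standalone sketch of that rule for
// statically shaped parameters:
#include <cstdint>

bool QuantizationParamsValid(int64_t scales_rank, int64_t scales_size,
                             int64_t zero_points_rank,
                             int64_t zero_points_size,
                             int32_t quantization_axis) {
  if (quantization_axis == -1)
    return scales_rank == 0 && zero_points_rank == 0;  // per-tensor
  return scales_rank == 1 && zero_points_rank == 1 &&  // per-axis
         scales_size == zero_points_size;
}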
LogicalResult VerifyTypesCompatibility(Operation::operand_type_range types, bool mask_one_dim, Operation *op) { - int64_t common_rank = ShapedType::kDynamicSize; + int64_t common_rank = ShapedType::kDynamic; llvm::SmallVector common_dims; - int64_t dim_to_mask = ShapedType::kDynamicSize; + int64_t dim_to_mask = ShapedType::kDynamic; // Initialize common_rank with rank of the first ranked type and verify that // following ranked types have the same rank. @@ -128,9 +128,9 @@ LogicalResult VerifyTypesCompatibility(Operation::operand_type_range types, if (!ranked_ty) continue; int64_t rank = ranked_ty.getRank(); - if (common_rank == ShapedType::kDynamicSize) { + if (common_rank == ShapedType::kDynamic) { common_rank = rank; - common_dims.resize(common_rank, ShapedType::kDynamicSize); + common_dims.resize(common_rank, ShapedType::kDynamic); } else if (common_rank != rank) { return op->emitError() << "operand type " << ranked_ty @@ -142,16 +142,16 @@ LogicalResult VerifyTypesCompatibility(Operation::operand_type_range types, if (i == dim_to_mask) continue; int64_t dim = ranked_ty.getDimSize(i); - if (dim == ShapedType::kDynamicSize) continue; + if (dim == ShapedType::kDynamic) continue; int64_t &common_dim = common_dims[i]; - if (common_dim == ShapedType::kDynamicSize) { + if (common_dim == ShapedType::kDynamic) { common_dim = dim; } else if (common_dim != dim) { // If mask_one_dim is true, do not emit an error if this is the only // dimension with mismatches. Note down the dimension to mask it from // the following types. - if (mask_one_dim && dim_to_mask == ShapedType::kDynamicSize) { + if (mask_one_dim && dim_to_mask == ShapedType::kDynamic) { dim_to_mask = i; continue; } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc index 4559294c41c..8177609211b 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -85,7 +86,7 @@ namespace { LogicalResult _XlaHostComputeMlirOp::verify() { _XlaHostComputeMlirOp op = *this; // Extract the module and function. - StringRef host_module = op.host_mlir_module(); + StringRef host_module = op.getHostMlirModule(); if (host_module.empty()) return success(); @@ -122,7 +123,7 @@ LogicalResult _XlaHostComputeMlirOp::verify() { func::FuncOp _XlaHostComputeMlirOp::GetHostFunc( mlir::OwningOpRef* mlir_module) { - if (!tensorflow::DeserializeMlirModule(host_mlir_module().str(), + if (!tensorflow::DeserializeMlirModule(getHostMlirModule().str(), this->getContext(), mlir_module) .ok()) return nullptr; @@ -135,14 +136,20 @@ func::FuncOp _XlaHostComputeMlirOp::GetHostFunc( // For XLA Send/Recv ops the key corresponds to the resource instance. 
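// Aside (illustrative sketch, not from the TensorFlow sources): the hunks below
// change GetResourceInstanceStr from returning std::string to
// std::optional<std::string>, presumably so an op can distinguish "no resource
// instance" (std::nullopt) from an empty key. Wrapping the existing key keeps
// the old behaviour; a minimal sketch with a hypothetical helper:
#include <optional>
#include <string>

std::optional<std::string> ResourceInstanceFromKey(const std::string& key) {
  // The ops below simply wrap their key; an op without a meaningful instance
  // string could return std::nullopt instead.
  return key;
}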
-std::string _XlaRecvAtHostOp::GetResourceInstanceStr() { return key().str(); } +std::optional _XlaRecvAtHostOp::GetResourceInstanceStr() { + return getKey().str(); +} -std::string _XlaRecvAtHostV2Op::GetResourceInstanceStr() { return key().str(); } +std::optional _XlaRecvAtHostV2Op::GetResourceInstanceStr() { + return getKey().str(); +} -std::string _XlaSendFromHostOp::GetResourceInstanceStr() { return key().str(); } +std::optional _XlaSendFromHostOp::GetResourceInstanceStr() { + return getKey().str(); +} -std::string _XlaSendFromHostV2Op::GetResourceInstanceStr() { - return key().str(); +std::optional _XlaSendFromHostV2Op::GetResourceInstanceStr() { + return getKey().str(); } namespace { @@ -155,24 +162,24 @@ std::string GetRendezvousKey(const std::string& send_device, } } // namespace -std::string _HostRecvOp::GetResourceInstanceStr() { - return GetRendezvousKey(send_device().str(), send_device_incarnation(), - recv_device().str(), tensor_name().str()); +std::optional _HostRecvOp::GetResourceInstanceStr() { + return GetRendezvousKey(getSendDevice().str(), getSendDeviceIncarnation(), + getRecvDevice().str(), getTensorName().str()); } -std::string _HostSendOp::GetResourceInstanceStr() { - return GetRendezvousKey(send_device().str(), send_device_incarnation(), - recv_device().str(), tensor_name().str()); +std::optional _HostSendOp::GetResourceInstanceStr() { + return GetRendezvousKey(getSendDevice().str(), getSendDeviceIncarnation(), + getRecvDevice().str(), getTensorName().str()); } -std::string _RecvOp::GetResourceInstanceStr() { - return GetRendezvousKey(send_device().str(), send_device_incarnation(), - recv_device().str(), tensor_name().str()); +std::optional _RecvOp::GetResourceInstanceStr() { + return GetRendezvousKey(getSendDevice().str(), getSendDeviceIncarnation(), + getRecvDevice().str(), getTensorName().str()); } -std::string _SendOp::GetResourceInstanceStr() { - return GetRendezvousKey(send_device().str(), send_device_incarnation(), - recv_device().str(), tensor_name().str()); +std::optional _SendOp::GetResourceInstanceStr() { + return GetRendezvousKey(getSendDevice().str(), getSendDeviceIncarnation(), + getRecvDevice().str(), getTensorName().str()); } } // namespace TF diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc index 7e77edef436..93fa8e26f14 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc @@ -139,8 +139,8 @@ TensorFlowSavedModelDialect::TensorFlowSavedModelDialect(MLIRContext *context) static LogicalResult VerifyIndexPath(Operation *op, NamedAttribute named_attr) { auto attr = named_attr.getValue().dyn_cast(); if (!attr) { - return op->emitError() - << "'tf_saved_model.index_path' attribute should be an ArrayAttr"; + return op->emitError() << "'" << kTfSavedModelIndexPathAttr + << "' attribute should be an ArrayAttr"; } for (auto element : attr) { if (element.isa()) { @@ -151,8 +151,8 @@ static LogicalResult VerifyIndexPath(Operation *op, NamedAttribute named_attr) { continue; } } - return op->emitError() << "'tf_saved_model.index_path' elements should " - "be strings or 64-bit integers"; + return op->emitError() << "'" << kTfSavedModelIndexPathAttr + << "' elements should be strings or 64-bit integers"; } return mlir::success(); } @@ -206,7 +206,7 @@ LogicalResult TensorFlowSavedModelDialect::verifyRegionArgAttribute( auto arg_type = cast(op).getArgument(arg_index).getType(); return 
VerifyBoundInputArgType(op, arg_type, symbol_op); } - if (named_attr.getName() == "tf_saved_model.index_path") { + if (named_attr.getName() == kTfSavedModelIndexPathAttr) { return VerifyIndexPath(op, named_attr); } @@ -217,7 +217,7 @@ LogicalResult TensorFlowSavedModelDialect::verifyRegionArgAttribute( LogicalResult TensorFlowSavedModelDialect::verifyRegionResultAttribute( Operation *op, unsigned region_index, unsigned result_index, NamedAttribute named_attr) { - if (named_attr.getName() == "tf_saved_model.index_path") { + if (named_attr.getName() == kTfSavedModelIndexPathAttr) { return VerifyIndexPath(op, named_attr); } @@ -256,13 +256,13 @@ LogicalResult VerifySessionInitOp(SessionInitializerOp session_init_op, static bool HasAnyTfSavedModelArgAttr(func::FuncOp func) { for (int i = 0, e = func.getNumArguments(); i < e; i++) { - if (func.getArgAttr(i, "tf_saved_model.index_path") || + if (func.getArgAttr(i, kTfSavedModelIndexPathAttr) || func.getArgAttr(i, "tf_saved_model.bound_input")) { return true; } } for (int i = 0, e = func.getNumResults(); i < e; i++) { - if (func.getResultAttr(i, "tf_saved_model.index_path") || + if (func.getResultAttr(i, kTfSavedModelIndexPathAttr) || func.getResultAttr(i, "tf_saved_model.bound_input")) { return true; } @@ -273,7 +273,7 @@ static bool HasAnyTfSavedModelArgAttr(func::FuncOp func) { static LogicalResult VerifySavedModelModule( ModuleOp module, TensorFlowSavedModelDialect *dialect) { auto exported_names_ident = - StringAttr::get(dialect->getContext(), "tf_saved_model.exported_names"); + StringAttr::get(dialect->getContext(), kTfSavedModelExportedNamesAttr); // Check that there are no duplicated exported_names. DenseMap exported_name_to_op; for (auto &op : module) { @@ -377,11 +377,12 @@ LogicalResult VerifyExportedFunc(func::FuncOp func) { reached_bound_inputs = true; continue; } - if (func.getArgAttr(i, "tf_saved_model.index_path")) { + if (func.getArgAttr(i, kTfSavedModelIndexPathAttr)) { if (reached_bound_inputs) { return func.emitError() - << "all 'tf_saved_model.index_path' arg attributes should " - "precede all 'tf_saved_model.bound_input' arg attributes"; + << "all '" << kTfSavedModelIndexPathAttr + << "' arg attributes should precede all " + "'tf_saved_model.bound_input' arg attributes"; } continue; } @@ -391,8 +392,9 @@ LogicalResult VerifyExportedFunc(func::FuncOp func) { "unless it is being under construction"; } return func.emitError() - << "all arguments should have 'tf_saved_model.index_path', " - "'tf_saved_model.bound_input' or 'tf.resource_name' attributes"; + << "all arguments should have '" << kTfSavedModelIndexPathAttr + << "', 'tf_saved_model.bound_input' or 'tf.resource_name' " + "attributes"; } llvm::SmallDenseSet unique_bound_inputs; for (int i = 0, e = func.getNumArguments(); i < e; i++) { @@ -407,9 +409,9 @@ LogicalResult VerifyExportedFunc(func::FuncOp func) { } for (int i = 0, e = func.getNumResults(); i < e; i++) { - if (!func.getResultAttr(i, "tf_saved_model.index_path")) { - return func.emitError() << "all results should have " - "'tf_saved_model.index_path' attributes"; + if (!func.getResultAttr(i, kTfSavedModelIndexPathAttr)) { + return func.emitError() << "all results should have '" + << kTfSavedModelIndexPathAttr << "' attributes"; } } @@ -448,20 +450,20 @@ LogicalResult VerifyInitializerTypeAttr(Operation *op, LogicalResult TensorFlowSavedModelDialect::verifyOperationAttribute( Operation *op, NamedAttribute named_attr) { - if (named_attr.getName() == "tf_saved_model.exported_names") { + if (named_attr.getName() == 
kTfSavedModelExportedNamesAttr) { if (!isa(op)) { - return op->emitError() << "'tf_saved_model.exported_names' must be on a " - "'func' or 'tf_saved_model.global_tensor' op"; + return op->emitError() + << "'" << kTfSavedModelExportedNamesAttr + << "' must be on a 'func' or 'tf_saved_model.global_tensor' op"; } if (!IsStrArrayAttr(named_attr.getValue())) { - return op->emitError() - << "'tf_saved_model.exported_names' must be an array of strings"; + return op->emitError() << "'" << kTfSavedModelExportedNamesAttr + << "' must be an array of strings"; } if (!op->getParentOp()->getAttr("tf_saved_model.semantics")) { - return op->emitError() - << "'tf_saved_model.exported_names' must be on an op " - "whose immediate parent has attribute " - "'tf_saved_model.semantics'"; + return op->emitError() << "'" << kTfSavedModelExportedNamesAttr + << "' must be on an op whose immediate parent has " + "attribute 'tf_saved_model.semantics'"; } if (auto func = dyn_cast(op)) { if (failed(VerifyExportedFunc(func))) { @@ -493,7 +495,7 @@ LogicalResult TensorFlowSavedModelDialect::verifyOperationAttribute( SmallVector GetExportedNames(Operation *op) { SmallVector ret; auto exported_names = - op->getAttrOfType("tf_saved_model.exported_names"); + op->getAttrOfType(kTfSavedModelExportedNamesAttr); if (exported_names) { for (auto name : exported_names) { ret.push_back(name.cast().getValue()); @@ -504,7 +506,7 @@ SmallVector GetExportedNames(Operation *op) { bool IsExported(Operation *op) { auto exported_names = - op->getAttrOfType("tf_saved_model.exported_names"); + op->getAttrOfType(kTfSavedModelExportedNamesAttr); return exported_names && !exported_names.empty(); } @@ -591,5 +593,38 @@ SmallVector GetSessionInitializerExportedName(ModuleOp op) { return results; } +SmallVector GetInitializerFunctions(ModuleOp module_op) { + SessionInitializerOp session_initializer_op = + GetSessionInitializerOp(module_op); + if (!session_initializer_op) return {}; + + SymbolTable symbol_table(module_op); + + SmallVector init_func_ops; + for (auto init_func_sym : session_initializer_op.getInitializers() + .getAsValueRange()) { + auto init_func_op = symbol_table.lookup(init_func_sym); + // `init_func_op` is guaranteed to be not null in a valid module. + init_func_ops.push_back(init_func_op); + } + + return init_func_ops; +} + +func::FuncOp GetInitializerFunction(ModuleOp module_op, + const StringRef initializer_type) { + SmallVector init_func_ops = + GetInitializerFunctions(module_op); + + auto init_func_itr = absl::c_find_if( + init_func_ops, [initializer_type](const func::FuncOp init_func_op) { + const auto init_type_attr = init_func_op->getAttrOfType( + kTfSavedModelInitializerTypeAttr); + return init_type_attr && init_type_attr == initializer_type; + }); + + return init_func_itr == init_func_ops.end() ? nullptr : *init_func_itr; +} + } // namespace tf_saved_model } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h index 6b66d2e841d..53d335f4ff2 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h @@ -24,6 +24,16 @@ limitations under the License. namespace mlir { namespace tf_saved_model { +// The name of the attribute indicating under what name an object is exported. 
+inline constexpr StringRef kTfSavedModelExportedNamesAttr = + "tf_saved_model.exported_names"; + +// The name of the attribute attached to input arguments or results of a +// function to represent the path which one would use to index into a structured +// value to reach a given tensor. +inline constexpr StringRef kTfSavedModelIndexPathAttr = + "tf_saved_model.index_path"; + // Name of the attribute that inidicates the type of initializer. It should be // on a function and the function should exist in the initializers attribute of // the SessionInitializerOp. @@ -70,7 +80,7 @@ SmallVector GetExportedNames(Operation *op); bool IsExported(Operation *op); // Returns true if `module` has tf_saved_model linkage semantics. -bool HasTfSavedModelSemantics(ModuleOp module); +bool HasTfSavedModelSemantics(ModuleOp module_op); // Returns the tf_saved_model.global_tensor op that func's arg_index'th argument // refers to as a bound input, or null. @@ -90,10 +100,19 @@ Type GetBoundInputArgTypeFor(mlir::Operation *op); // Returns the session initializer of this module if it exists. Returns null // otherwise. -SessionInitializerOp GetSessionInitializerOp(mlir::ModuleOp op); +SessionInitializerOp GetSessionInitializerOp(ModuleOp module_op); // Returns the exported name for the session initializer function. -SmallVector GetSessionInitializerExportedName(mlir::ModuleOp op); +SmallVector GetSessionInitializerExportedName(ModuleOp module_op); + +// Returns initializer function ops. These functions' symbols are in the +// "initializers" attribute of the session initializer op. +SmallVector GetInitializerFunctions(ModuleOp module_op); + +// Returns the initializer function whose `tf_saved_model.initializer_type` +// attribute matches `initializer_type`. Returns a null op if it doesn't exist. +func::FuncOp GetInitializerFunction(ModuleOp module_op, + StringRef initializer_type); } // namespace tf_saved_model } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_ops.td index c3588720060..03f70eb9388 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_ops.td @@ -83,8 +83,7 @@ def TfSavedModel_Dialect : Dialect { }]; let cppNamespace = "::mlir::tf_saved_model"; - - let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; + let useFoldAPI = kEmitFoldAdaptorFolder; } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_test.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_test.cc new file mode 100644 index 00000000000..48cfb26d680 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_test.cc @@ -0,0 +1,189 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Parser/Parser.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/core/platform/test.h" + +namespace mlir { +namespace tf_saved_model { +namespace { + +using ::testing::Eq; +using ::testing::IsEmpty; +using ::testing::IsNull; +using ::testing::NotNull; +using ::testing::SizeIs; + +// Fixture for testing TfSavedModel functionalities. Initializes a +// `MLIRContext` by loading the `tf_saved_model` dialect. +class TfSavedModelTest : public ::testing::Test { + protected: + TfSavedModelTest() : ctx_() { + ctx_.loadDialect(); + } + + MLIRContext ctx_; +}; + +// Parses `module_op_str` and returns the resulting `ModuleOp`. +ModuleOp ParseModuleOp(const StringRef module_op_str, Block& block, + MLIRContext& ctx) { + const LogicalResult parse_result = + parseSourceString(module_op_str, &block, ParserConfig(&ctx)); + EXPECT_TRUE(succeeded(parse_result)); + + return cast(block.front()); +} + +TEST_F(TfSavedModelTest, + GetInitializerFunctionReturnsNullWhenNoSessionInitializerOp) { + constexpr StringRef kModuleOpStr = + R"mlir(module attributes {tf_saved_model.semantics} {})mlir"; + + Block block; + ModuleOp module_op = ParseModuleOp(kModuleOpStr, block, ctx_); + + func::FuncOp init_func_op = GetInitializerFunction( + module_op, /*initializer_type=*/kTfSavedModelInitializerInitType); + + EXPECT_THAT(init_func_op, IsNull()); +} + +TEST_F(TfSavedModelTest, + GetInitializerFunctionReturnsNullWhenInitializersEmpty) { + constexpr StringRef kModuleOpStr = R"mlir( + module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = []} : () -> () + } + )mlir"; + + Block block; + ModuleOp module_op = ParseModuleOp(kModuleOpStr, block, ctx_); + + func::FuncOp init_func_op = GetInitializerFunction( + module_op, /*initializer_type=*/kTfSavedModelInitializerInitType); + + EXPECT_THAT(init_func_op, IsNull()); +} + +TEST_F(TfSavedModelTest, + GetInitializerFunctionReturnsFuncOpMatchingInitializerType) { + constexpr StringRef kModuleOpStr = R"mlir( + module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_func]} : () -> () + func.func @init_func() attributes {tf_saved_model.exported_names = ["init_func"], tf_saved_model.initializer_type = "init_op"} { + func.return + } + } + )mlir"; + + Block block; + ModuleOp module_op = ParseModuleOp(kModuleOpStr, block, ctx_); + + func::FuncOp init_func_op = GetInitializerFunction( + module_op, /*initializer_type=*/kTfSavedModelInitializerInitType); + + EXPECT_THAT(init_func_op, NotNull()); + EXPECT_THAT(init_func_op.getSymName(), "init_func"); + EXPECT_THAT( + init_func_op->getAttrOfType(kTfSavedModelInitializerTypeAttr), + kTfSavedModelInitializerInitType); +} + +TEST_F(TfSavedModelTest, GetInitializerFunctionNoMatchingInitializerType) { + constexpr StringRef kModuleOpStr = R"mlir( + module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_func]} : () -> () + func.func @init_func() attributes {tf_saved_model.exported_names = ["init_func"], tf_saved_model.initializer_type = "restore_op"} { + 
func.return + } + } + )mlir"; + + Block block; + ModuleOp module_op = ParseModuleOp(kModuleOpStr, block, ctx_); + + func::FuncOp init_func_op = GetInitializerFunction( + module_op, /*initializer_type=*/kTfSavedModelInitializerInitType); + + // No initializer function matches the initializer type. + EXPECT_THAT(init_func_op, IsNull()); +} + +TEST_F(TfSavedModelTest, GetInitializerFunctionsEmptyWhenNoInitFunctions) { + constexpr StringRef kModuleOpStr = R"mlir( + module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = []} : () -> () + } + )mlir"; + + Block block; + ModuleOp module_op = ParseModuleOp(kModuleOpStr, block, ctx_); + + SmallVector init_func_ops = + GetInitializerFunctions(module_op); + + EXPECT_THAT(init_func_ops, IsEmpty()); +} + +TEST_F(TfSavedModelTest, + GetInitializerFunctionsEmptyWhenNoSessionInitializerOp) { + constexpr StringRef kModuleOpStr = + R"mlir(module attributes {tf_saved_model.semantics} {})mlir"; + + Block block; + ModuleOp module_op = ParseModuleOp(kModuleOpStr, block, ctx_); + + SmallVector init_func_ops = + GetInitializerFunctions(module_op); + + EXPECT_THAT(init_func_ops, IsEmpty()); +} + +TEST_F(TfSavedModelTest, GetInitializerFunctionsReturnsMultipleFuncOps) { + constexpr StringRef kModuleOpStr = R"mlir( + module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_func1, @init_func2]} : () -> () + + func.func @init_func1() attributes {tf_saved_model.exported_names = ["init_func1"], tf_saved_model.initializer_type = "init_op"} { + func.return + } + + func.func @init_func2() attributes {tf_saved_model.exported_names = ["init_func2"], tf_saved_model.initializer_type = "restore_op"} { + func.return + } + } + )mlir"; + + Block block; + ModuleOp module_op = ParseModuleOp(kModuleOpStr, block, ctx_); + + SmallVector init_func_ops = + GetInitializerFunctions(module_op); + + EXPECT_THAT(init_func_ops, SizeIs(2)); + EXPECT_THAT(init_func_ops[0].getSymName(), Eq("init_func1")); + EXPECT_THAT(init_func_ops[1].getSymName(), Eq("init_func2")); +} + +} // namespace +} // namespace tf_saved_model +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h index 9299dad8474..0b456a0de25 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h @@ -99,6 +99,11 @@ struct CollectiveReduceOrdering StringRef getName() final { return "CollectiveReduceOrdering"; } }; +struct NcclAllReduceOrdering + : public ::mlir::SideEffects::Resource::Base { + StringRef getName() final { return "NcclAllReduceOrdering"; } +}; + // Returns true iff resource type with given ID is only self-dependent, i.e., // there are no dependencies to other resource types (including unknown resource // type). diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_structs.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_structs.cc index 9017b062b43..d0131d9f8cb 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_structs.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_structs.cc @@ -15,6 +15,8 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" +#include + namespace mlir { namespace TF { @@ -34,7 +36,7 @@ llvm::Optional RuntimeDevices::GetGpuDeviceMetadata( if (it != gpu_metadata_.end()) { return it->second; } else { - return llvm::None; + return std::nullopt; } } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h index 03ca1a545a1..28e7daa7f52 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h @@ -116,7 +116,7 @@ inline ShapedType MergeType(ShapedType a, ShapedType b) { for (int i = 0, e = rank; i != e; i++) { int64_t dim0 = a.getDimSize(i); int64_t dim1 = b.getDimSize(i); - dims[i] = (dim0 == ShapedType::kDynamicSize) ? dim1 : dim0; + dims[i] = (dim0 == ShapedType::kDynamic) ? dim1 : dim0; } return RankedTensorType::get(dims, a.getElementType()); } @@ -148,7 +148,7 @@ class SameOperandsAndResultTypeResolveRef } static LogicalResult inferReturnTypeComponentsFromOperands( - MLIRContext*, Optional location, ValueShapeRange operands, + MLIRContext*, std::optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { if (operands.empty()) @@ -175,7 +175,7 @@ template class CannotDuplicate : public TraitBase { public: static LogicalResult verifyTrait(Operation* op) { - if (MemoryEffectOpInterface::hasNoEffect(op)) + if (isMemoryEffectFree(op)) return op->emitError( "operations with no side effects cannot have CannotDuplicate trait"); return success(); diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.def b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.def index a097a3cad88..940292e7724 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.def +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.def @@ -64,6 +64,8 @@ HANDLE_TF_REF_TYPE(Complex64Ref, COMPLEX64_REF, "complex64ref") HANDLE_TF_REF_TYPE(Complex128Ref, COMPLEX128_REF, "complex128ref") HANDLE_TF_REF_TYPE(HalfRef, HALF_REF, "halfref") HANDLE_TF_REF_TYPE(ResourceRef, RESOURCE_REF, "resourceref") +HANDLE_TF_REF_TYPE(Float8E4M3FNRef, FLOAT8_E4M3FN_REF, "float8e4m3fnref") +HANDLE_TF_REF_TYPE(Float8E5M2Ref, FLOAT8_E5M2_REF, "float8e5m2ref") #ifndef HANDLE_LAST_TF_TYPE #define HANDLE_LAST_TF_TYPE(class, enumerant, name) \ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc index f0a337a15c1..da2075cdc0c 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tfrt_ops.cc @@ -38,13 +38,13 @@ _TfrtGetResourceOp::GetResourceHandleValueAndIdList( llvm::SmallVector resource_vec; llvm::StringRef device = GetDeviceOrEmpty(getOperation()); - for (auto iter : llvm::enumerate(results())) { + for (auto iter : llvm::enumerate(getResults())) { auto index = iter.index(); if (getElementTypeOrSelf(iter.value().getType()).isa()) { resource_vec.push_back(GetResourceHandleValueAndIdBase( - container()[index].cast().getValue(), - shared_name()[index].cast().getValue(), device, - results()[index], resource_handle_id_map, next_id)); + getContainer()[index].cast().getValue(), + getSharedName()[index].cast().getValue(), device, + getResults()[index], resource_handle_id_map, next_id)); } } return resource_vec; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tpu_embedding_ops_registry.cc b/tensorflow/compiler/mlir/tensorflow/ir/tpu_embedding_ops_registry.cc new file mode 100644 index 
00000000000..5921efa2096 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/ir/tpu_embedding_ops_registry.cc @@ -0,0 +1,34 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/ir/tpu_embedding_ops_registry.h" + +#include + +namespace mlir { +namespace TF { + +const llvm::SmallDenseSet& +TPUEmbeddingOpsRegistry::GetOpsTypeIds() { + return ops_type_ids_; +} + +// static +TPUEmbeddingOpsRegistry& TPUEmbeddingOpsRegistry::Global() { + static TPUEmbeddingOpsRegistry* registry = new TPUEmbeddingOpsRegistry; + return *registry; +} +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tpu_embedding_ops_registry.h b/tensorflow/compiler/mlir/tensorflow/ir/tpu_embedding_ops_registry.h new file mode 100644 index 00000000000..c8160418e8f --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/ir/tpu_embedding_ops_registry.h @@ -0,0 +1,59 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TPU_EMBEDDING_OPS_REGISTRY_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TPU_EMBEDDING_OPS_REGISTRY_H_ + +#include "llvm/ADT/DenseSet.h" +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project + +namespace mlir { +namespace TF { + +// A global ops registry that is used to hold TPU embedding ops. +// +// Example: +// TPUEmbeddingOpsRegistry::Global().Add(); +// for (auto op_type_id : TPUEmbeddingOpsRegistry::Global().GetOpsTypeIds()) +// { +// ... +// } +class TPUEmbeddingOpsRegistry { + public: + // Add the op to the registry. + // + // Adding an op here allows it to use the old bridge legalization from the + // MLIR bridge via the fallback mechanism. Therefore, any op added here must + // have a Python test with the MLIR bridge enabled to verify that the + // fallback works correctly. + template + void Add() { + ops_type_ids_.insert(TypeID::get()); + } + + // Returns the type ids of the ops in the TPUEmbeddingOpsRegistry. + const llvm::SmallDenseSet& GetOpsTypeIds(); + + // Returns the global registry. 
+ static TPUEmbeddingOpsRegistry& Global(); + + private: + llvm::SmallDenseSet ops_type_ids_{}; +}; +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TPU_EMBEDDING_OPS_REGISTRY_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/tests/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/BUILD index be43c9ebc05..6cc7344b083 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/BUILD @@ -1,11 +1,19 @@ load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") -package(licenses = ["notice"]) +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) glob_lit_tests( data = [":test_utilities"], driver = "@llvm-project//mlir:run_lit.sh", + size_override = { + "decompose_resource_ops.mlir": "medium", + "layout_optimization_move_transposes_end.mlir": "medium", + "layout_optimization_to_nhwc.mlir": "medium", + }, tags_override = { "optimize.mlir": ["no_rocm"], "tf_optimize.mlir": ["no_rocm"], diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir index 0b34e68694f..c991ddfcdcb 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -pass-pipeline='func.func(canonicalize)' | FileCheck %s +// RUN: tf-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{test-convergence}))' | FileCheck %s // CHECK-LABEL: func @tfAssertTrue func.func @tfAssertTrue(%arg0: tensor<1x1x6x2xf32>) { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/cluster_formation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/cluster_formation.mlir index 641a76f5397..16acdcec5ed 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/cluster_formation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/cluster_formation.mlir @@ -245,13 +245,13 @@ module { %2 = "tf.A"(%arg0) : (tensor) -> tensor // Note that tf.C is moved before tf_device.launch. - // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[ARG_0]]) : (tensor) -> tensor + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[ARG_0]]) // CHECK: %[[TPU0_OUTPUT:[0-9]*]] = "tf_device.launch" // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[A_OUTPUT]]) : (tensor) -> tensor %3 = "tf.B"(%2) {device = "tpu0"} : (tensor) -> tensor - %4 = "tf.C"(%arg0) : (tensor) -> tensor + %4 = "tf.C"(%arg0) {is_stateless = true} : (tensor) -> tensor // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[A_OUTPUT]], %[[B_OUTPUT]]) : (tensor, tensor) -> tensor %5 = "tf.D"(%2, %3) {device = "tpu0"} : (tensor, tensor) -> tensor @@ -273,6 +273,78 @@ module { // ----- +// Side effecting ops + +module { + // CHECK-LABEL: func @sideeffect + // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) + func.func @sideeffect(%arg0: tensor) -> tensor { + %0 = tf_executor.graph { + %1:2 = tf_executor.island { + + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"(%[[ARG_0]]) + %2 = "tf.A"(%arg0) : (tensor) -> tensor + + // CHECK: %[[TPU0_OUTPUT0:[0-9]*]] = "tf_device.launch" + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[A_OUTPUT]]) : (tensor) -> tensor + // CHECK: tf_device.return %[[B_OUTPUT]] + %3 = "tf.B"(%2) {device = "tpu0"} : (tensor) -> tensor + + // tf.B and tf.D cannot be merged because of tf.C, which is assumed to have a side effect. 
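+ // (Unlike the tests above, this tf.C has no is_stateless attribute, so the pass conservatively treats it as stateful and cannot merge the two launches across it.)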
+ // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[ARG_0]]) : (tensor) -> tensor + + %4 = "tf.C"(%arg0) : (tensor) -> tensor + + // CHECK: %[[TPU0_OUTPUT1:[0-9]*]] = "tf_device.launch" + // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[A_OUTPUT]], %[[TPU0_OUTPUT0]]) : (tensor, tensor) -> tensor + // CHECK: tf_device.return %[[D_OUTPUT]] + %5 = "tf.D"(%2, %3) {device = "tpu0"} : (tensor, tensor) -> tensor + + // CHECK: {device = "tpu0"} : () -> tensor + + // CHECK: %[[E_OUTPUT:[0-9]*]] = "tf.E"(%[[C_OUTPUT]], %[[TPU0_OUTPUT1]]) : (tensor, tensor) -> tensor + %6 = "tf.E"(%4, %5) : (tensor, tensor) -> tensor + + // CHECK: tf_executor.yield %[[E_OUTPUT]] + tf_executor.yield %6 : tensor + } + tf_executor.fetch %1#0 : tensor + } + func.return %0 : tensor + } +} + +// ----- + +// Cluster formation that requires reordering users of the cluster op. + +module { + // CHECK-LABEL: func @dominanceorder + // CHECK-SAME: (%[[ARG0:.*]]: tensor) + func.func @dominanceorder(%arg0: tensor) -> (tensor, tensor) { + %0:2 = tf_executor.graph { + %1:3 = tf_executor.island { + %2 = "tf.A"(%arg0) {device = "tpu0"} : (tensor) -> tensor + %3 = "tf.B"(%2) {is_stateless = true} : (tensor) -> tensor + %4 = "tf.C"(%2) {device = "tpu0"} : (tensor) -> tensor + tf_executor.yield %3, %4 : tensor, tensor + + // CHECK: %[[TPU0_OUTPUT:.*]]:2 = "tf_device.launch" + // CHECK: %[[A:.*]] = "tf.A"(%[[ARG0]]) + // CHECK: %[[C:.*]] = "tf.C"(%[[A]]) + // CHECK: tf_device.return %[[A]], %[[C]] + + // CHECK: %[[B:.*]] = "tf.B"(%[[TPU0_OUTPUT]]#0) + // CHECK: tf_executor.yield %[[B]], %[[TPU0_OUTPUT]]#1 + } + tf_executor.fetch %1#0, %1#1 : tensor, tensor + } + func.return %0#0, %0#1 : tensor, tensor + } +} + +// ----- + // Multiple device clusters with intertwined instructions in original block. module { @@ -286,7 +358,7 @@ module { %2 = "tf.A"(%arg0) : (tensor) -> tensor // CHECK: %[[GPU0_OUTPUT:[0-9]*]] = "tf_device.launch" - // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[ARG_0]]) : (tensor) -> tensor + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[ARG_0]]) // CHECK: tf_device.return %[[C_OUTPUT]] // CHECK: {device = "gpu0"} : () -> tensor @@ -294,7 +366,7 @@ module { // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[A_OUTPUT]]) : (tensor) -> tensor %3 = "tf.B"(%2) {device = "tpu0"} : (tensor) -> tensor - %4 = "tf.C"(%arg0) {device = "gpu0"} : (tensor) -> tensor + %4 = "tf.C"(%arg0) {device = "gpu0", is_stateless = true} : (tensor) -> tensor // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[A_OUTPUT]], %[[B_OUTPUT]]) : (tensor, tensor) -> tensor %5 = "tf.D"(%2, %3) {device = "tpu0"} : (tensor, tensor) -> tensor diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/BUILD index 0a68d0d85c0..9132abf2fe5 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/BUILD @@ -1,6 +1,8 @@ load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) + licenses(["notice"]) glob_lit_tests( diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/replicate-tensor-list-init-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/replicate-tensor-list-init-ops.mlir new file mode 100644 index 00000000000..2ea3bf4d935 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/replicate-tensor-list-init-ops.mlir @@ 
-0,0 +1,26 @@ +// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=: | FileCheck %s + +module attributes {tf.versions = {producer = 179 : i32}} { + func.func @main() -> (tensor<300x?xf32>, tensor<300x?xf32>) { + %elem_shape = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %size = "tf.Const"() {value = dense<300> : tensor} : () -> tensor + %tl = "tf.TensorListReserve"(%elem_shape, %size) : (tensor, tensor) -> tensor>> + + %idx = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %elem_1 = "tf.Const"() {value = dense<10.0> : tensor<8xf32>} : () -> tensor<8xf32> + %tl_set_item = "tf.TensorListSetItem"(%tl, %idx, %elem_1) : (tensor>>, tensor, tensor<8xf32>) -> tensor>> + %elem_shape_2 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> + %tls = "tf.TensorListStack"(%tl_set_item, %elem_shape_2) {num_elements = 300 : i64} : (tensor>>, tensor<1xi32>) -> tensor<300x?xf32> + + %elem_2 = "tf.Const"() {value = dense<10.0> : tensor<9xf32>} : () -> tensor<9xf32> + %tl_set_item_2 = "tf.TensorListSetItem"(%tl, %idx, %elem_2) : (tensor>>, tensor, tensor<9xf32>) -> tensor>> + %elem_shape_3 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> + %tls_2 = "tf.TensorListStack"(%tl_set_item_2, %elem_shape_3) {num_elements = 300 : i64} : (tensor>>, tensor<1xi32>) -> tensor<300x?xf32> + func.return %tls, %tls_2 : tensor<300x?xf32>, tensor<300x?xf32> + } +} + +// CHECK-LABEL: HloModule main +// CHECK: ENTRY %main.{{[0-9]+}} () -> (f32[300,8], f32[300,9]) { +// CHECK: %tuple.{{[0-9]+}} = (f32[300,8]{1,0}, f32[300,9]{1,0}) +// CHECK: } \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/stablehlo_add.mlir b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/stablehlo_add.mlir new file mode 100644 index 00000000000..8727fd39c12 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/compile_mlir_util/stablehlo_add.mlir @@ -0,0 +1,20 @@ +// RUN: tf-mlir-translate -mlir-tf-to-hlo-text %s -tf-input-shapes=: -tf-xla-emit-return-tuple | FileCheck %s + + +// TODO(b/259459405): Remove this test along with the upstream refactoring to +// avoid non TF inputs. +// This is not a supported mode. 
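+// The module below feeds a stablehlo.add op directly through the TF-to-HLO translation; the CHECK lines only expect the corresponding f32[] add in the emitted HLO.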
+module attributes {tf.versions = {producer = 179 : i32}} { + func.func @main(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor + } +} + +// CHECK-LABEL: HloModule main +// CHECK: ENTRY %main.{{[0-9]+}} ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> (f32[]) { +// CHECK-NEXT: %[[ARG0]] = f32[] parameter(0) +// CHECK-NEXT: %[[ARG1]] = f32[] parameter(1) +// CHECK-NEXT: [[ADD:%.*]] = f32[] add(f32[] %[[ARG0]], f32[] %[[ARG1]]) +// CHECK-NEXT: ROOT %tuple.{{[0-9]+}} = (f32[]) tuple(f32[] [[ADD]]) +// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir index 8af56e9d733..5a3f48c1252 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir @@ -608,7 +608,7 @@ func.func @testBroadcastGradientArgsHigherRank() -> (tensor<2xi32>, tensor<2xi32 // CHECK-DAG: %[[R0:.*]] = "tf.Const"() {value = dense<[0, 2]> : tensor<2xi32>} : () -> tensor<2xi32> // CHECK-DAG: %[[R1:.*]] = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> tensor<2xi32> // CHECK-NOT: tf.BroadcastGradientArgs - // CEHCK: return [[R0]], [[R1]] + // CHECK: return %[[R0]], %[[R1]] func.return %r0, %r1 : tensor<2xi32>, tensor<2xi32> } @@ -621,7 +621,7 @@ func.func @testBroadcastGradientArgsScalar() -> (tensor<2xi32>, tensor<0xi32>) { // CHECK-DAG: %[[R0:.*]] = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> tensor<2xi32> // CHECK-DAG: %[[R1:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> // CHECK-NOT: tf.BroadcastGradientArgs - // CEHCK: return [[R0]], [[R1]] + // CHECK: return %[[R0]], %[[R1]] func.return %r0, %r1 : tensor<2xi32>, tensor<0xi32> } @@ -634,7 +634,7 @@ func.func @testBroadcastGradientArgI64() -> (tensor<2xi64>, tensor<0xi64>) { // CHECK-DAG: %[[R0:.*]] = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64> // CHECK-DAG: %[[R1:.*]] = "tf.Const"() {value = dense<> : tensor<0xi64>} : () -> tensor<0xi64> // CHECK-NOT: tf.BroadcastGradientArgs - // CEHCK: return [[R0]], [[R1]] + // CHECK: return %[[R0]], %[[R1]] func.return %r0, %r1 : tensor<2xi64>, tensor<0xi64> } @@ -653,7 +653,7 @@ func.func @testEmptyResults(%arg0: tensor<0x2xf32>) -> tensor<0x2xf32> { // // CHECK-LABEL: func @yieldOp func.func @yieldOp(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tensor) { - // CHECK-2: tf.Yield + // CHECK-COUNT-2: tf.Yield %0 = "tf.IfRegion"(%arg2) ({ "tf.Yield"(%arg0) : (tensor) -> () }, { @@ -697,3 +697,47 @@ func.func @range_float() -> tensor { %0 = "tf.Range"(%cst, %cst_1, %cst_2) : (tensor, tensor, tensor) -> tensor func.return %0 : tensor } + +// CHECK-LABEL: func @testLogicalAndFoldsWithConstantFalse +func.func @testLogicalAndFoldsWithConstantFalse(%arg0: tensor) -> (tensor) { + // CHECK: [[CST:%.+]] = "tf.Const"() {value = dense : tensor} : () -> tensor + %cst = arith.constant dense : tensor + + %0 = "tf.LogicalAnd"(%cst, %arg0) : (tensor, tensor) -> tensor + + // CHECK: return [[CST]] + func.return %0: tensor +} + +// CHECK-LABEL: func @testLogicalAndFoldsWithConstantFalseSecondArg +func.func @testLogicalAndFoldsWithConstantFalseSecondArg(%arg0: tensor) -> (tensor) { + // CHECK: [[CST:%.+]] = "tf.Const"() {value = dense : tensor} : () -> tensor + %cst = arith.constant dense : tensor + + %0 = "tf.LogicalAnd"(%arg0, %cst) : (tensor, tensor) -> tensor + + // CHECK: return [[CST]] + func.return %0: 
tensor +} + +// CHECK-LABEL: func @testLogicalAndNoFoldWithConstTrue +func.func @testLogicalAndNoFoldWithConstTrue(%arg0: tensor) -> (tensor) { + %cst = arith.constant dense : tensor + + // CHECK: %[[LOGICAL_AND:.*]] = "tf.LogicalAnd" + %0 = "tf.LogicalAnd"(%cst, %arg0) : (tensor, tensor) -> tensor + + // CHECK: return %[[LOGICAL_AND]] + func.return %0 : tensor +} + +// CHECK-LABEL: func @testLogicalAndDoesntFoldWithConstantFalseBroadcast +func.func @testLogicalAndDoesntFoldWithConstantFalseBroadcast(%arg0: tensor<2xi1>) -> (tensor<2xi1>) { + %cst = arith.constant dense : tensor + + // CHECK: %[[LOGICAL_AND:.*]] = "tf.LogicalAnd" + %0 = "tf.LogicalAnd"(%cst, %arg0) : (tensor, tensor<2xi1>) -> tensor<2xi1> + + // CHECK: return %[[LOGICAL_AND]] + func.return %0: tensor<2xi1> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/convert_session_initializer_to_function.mlir b/tensorflow/compiler/mlir/tensorflow/tests/convert_session_initializer_to_function.mlir new file mode 100644 index 00000000000..f8c56ce5467 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/convert_session_initializer_to_function.mlir @@ -0,0 +1,51 @@ +// RUN: tf-opt %s --tf-saved-model-convert-session-initializer-to-function --split-input-file | FileCheck %s + +// CHECK-LABEL: simple_initializer +// CHECK-NOT: tf_saved_model.session_initializer +// CHECK: func @session_initializer +// CHECK: call @init1 +module @simple_initializer attributes {tf_saved_model.semantics} { +"tf_saved_model.session_initializer"() {initializers = [@init1]} : () -> () +func.func @init1() attributes {tf_saved_model.exported_names = ["init1"]} { + %0 = "tf.Const"() {value = dense<42> : tensor<1xi64>} : () -> tensor<1xi64> + return +} +} + +// ----- + +// CHECK-LABEL: with_initializer_type +// CHECK-NOT: tf_saved_model.session_initializer +// CHECK: func @session_initializer +// CHECK: call @init1 +module @with_initializer_type attributes {tf_saved_model.semantics} { +"tf_saved_model.session_initializer"() {initializers = [@init1]} : () -> () +func.func @init1() attributes {tf_saved_model.exported_names = ["init1"], tf_saved_model.initializer_type = "init_op"} { + %0 = "tf.Const"() {value = dense<42> : tensor<1xi64>} : () -> tensor<1xi64> + return +} +} + +// ----- + +// CHECK-LABEL: multiple_initializers +// CHECK-NOT: tf_saved_model.session_initializer +// CHECK: func @session_initializer +// CHECK: call @init1 +// CHECK: call @init2 +// CHECK: call @init3 +module @multiple_initializers attributes {tf_saved_model.semantics} { +"tf_saved_model.session_initializer"() {initializers = [@init1, @init2, @init3]} : () -> () +func.func @init1() attributes {tf_saved_model.exported_names = ["init1"]} { + %0 = "tf.Const"() {value = dense<42> : tensor<1xi64>} : () -> tensor<1xi64> + return +} +func.func @init2() attributes {tf_saved_model.exported_names = ["init2"]} { + %0 = "tf.Const"() {value = dense<43> : tensor<1xi64>} : () -> tensor<1xi64> + return +} +func.func @init3() attributes {tf_saved_model.exported_names = ["init3"]} { + %0 = "tf.Const"() {value = dense<44> : tensor<1xi64>} : () -> tensor<1xi64> + return +} +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/device_canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/device_canonicalize.mlir index d000293871d..d03a776289c 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/device_canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/device_canonicalize.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -pass-pipeline='func.func(canonicalize)' | FileCheck %s 
+// RUN: tf-opt %s -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s // Test empty launch with no results is folded away. // CHECK-LABEL: func @empty_launch_no_results diff --git a/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir b/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir index 0f85928a0f1..67a3e54979c 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir @@ -15,6 +15,16 @@ func.func @einsum_matmul(%arg0: tensor<7x9xf32>, %arg1: tensor<9x5xf32>) -> tens // CHECK: return %[[v0]] : tensor<7x5xf32> } +func.func @einsum_matmul_dynamic_size(%arg0: tensor<2x?x?x?xf32>, %arg1: tensor<2x?xf32>) -> tensor<2x?x?x1xf32> { + %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "bxyc,bx->bxyc"} : (tensor<2x?x?x?xf32>, tensor<2x?xf32>) -> tensor<2x?x?x1xf32> + func.return %0 : tensor<2x?x?x1xf32> + // CHECK-LABEL: einsum_matmul_dynamic_size + // CHECK-DAG: %[[cst:.*]] = arith.constant dense<[2, -1, 1, 1]> : tensor<4xi64> + // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg1, %cst) : (tensor<2x?xf32>, tensor<4xi64>) -> tensor<2x?x1x1xf32> + // CHECK: %[[v1:.*]] = "tf.BatchMatMulV2"(%arg0, %0) {adj_x = false, adj_y = false} : (tensor<2x?x?x?xf32>, tensor<2x?x1x1xf32>) -> tensor<2x?x?x1xf32> + // CHECK: return %[[v1]] : tensor<2x?x?x1xf32> +} + func.func @einsum_broadcast(%arg0: tensor<3x4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<3x4x6xf32> { %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ijk,km->ijm"}: (tensor<3x4x5xf32>, tensor<5x6xf32>) -> tensor<3x4x6xf32> func.return %0 : tensor<3x4x6xf32> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/executor_canonicalize.mlir index abba4557eb2..17efa936c81 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/executor_canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/executor_canonicalize.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -pass-pipeline='func.func(canonicalize)' | FileCheck %s +// RUN: tf-opt %s -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s // Test single graph with no outputs and one island is folded away. 
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_island_materialize_const.mlir b/tensorflow/compiler/mlir/tensorflow/tests/executor_island_materialize_const.mlir index 0dfaa3c6490..5b10a6ded47 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/executor_island_materialize_const.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/executor_island_materialize_const.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -pass-pipeline='func.func(canonicalize)' | FileCheck %s +// RUN: tf-opt %s -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s // Test that a constant stays inside an island after canonicalization diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_coarsening/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_coarsening/BUILD index c9b9a22838d..954eca9c0e2 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_coarsening/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_coarsening/BUILD @@ -1,6 +1,8 @@ load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) + licenses(["notice"]) glob_lit_tests( diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_inlining/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_inlining/BUILD index c9b9a22838d..954eca9c0e2 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_inlining/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_inlining/BUILD @@ -1,6 +1,8 @@ load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) + licenses(["notice"]) glob_lit_tests( diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_outline_island/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_outline_island/BUILD index c9b9a22838d..954eca9c0e2 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_outline_island/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_outline_island/BUILD @@ -1,6 +1,8 @@ load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) + licenses(["notice"]) glob_lit_tests( diff --git a/tensorflow/compiler/mlir/tensorflow/tests/freeze_variables.mlir b/tensorflow/compiler/mlir/tensorflow/tests/freeze_variables.mlir index 7b266683958..e435debc0e0 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/freeze_variables.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/freeze_variables.mlir @@ -408,3 +408,47 @@ module { func.return %c0 : tensor<0xf32> } } + +// ----- + +// Tests that entries corresponding to removed arguments in "tf._input_shapes" +// is also removed. 
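+// (Only the #tf_type.shape<0> entry for the surviving tensor argument of @f_callee is expected to remain once the resource argument is frozen away.)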
+ +module { + func.func @f() -> tensor<0xf32> { + // CHECK-NOT: "tf.VarHandleOp" + %handle = "tf.VarHandleOp"() {container="", shared_name="var1", device = "/job:worker/replica:0/task:1/device:CPU:0"} : () -> tensor>> + %cst = "tf.Const"() { value = dense<1.0> : tensor<0xf32> } : () -> tensor<0xf32> + %val = "tf.PartitionedCall"(%cst, %handle) {config = "", config_proto = "", executor_type = "", f = @f_callee} : (tensor<0xf32>, tensor>>) -> (tensor<0xf32>) + func.return %val : tensor<0xf32> + } + + // CHECK: func private @f_callee(%[[ARG0:.*]]: tensor<0xf32>) -> tensor<0xf32> + // CHECK-SAME: tf._input_shapes = [#tf_type.shape<0>] + func.func private @f_callee(%arg0: tensor<0xf32>, %arg1: tensor<*x!tf_type.resource>) -> tensor<0xf32> attributes {tf._input_shapes = [#tf_type.shape<0>, #tf_type.shape<>]} { + %0 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf_type.resource>) -> tensor<0xf32> + %1 = "tf.AddV2"(%arg0, %0) : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32> + func.return %1 : tensor<0xf32> + } +} + +// ----- + +// Tests that an error is emitted when the number of arguments and the size of +// "tf._input_shapes" attribute doesn't match. + +module { + func.func @f() -> tensor<0xf32> { + %handle = "tf.VarHandleOp"() {container="", shared_name="var1", device = "/job:worker/replica:0/task:1/device:CPU:0"} : () -> tensor>> + %cst = "tf.Const"() { value = dense<1.0> : tensor<0xf32> } : () -> tensor<0xf32> + %val = "tf.PartitionedCall"(%cst, %handle) {config = "", config_proto = "", executor_type = "", f = @f_callee} : (tensor<0xf32>, tensor>>) -> (tensor<0xf32>) + func.return %val : tensor<0xf32> + } + + // expected-error@+1 {{Number of arguments and 'tf._input_shapes' attribute size do not match. Num args: 2, tf._input_shapes size: 3}} + func.func private @f_callee(%arg0: tensor<0xf32>, %arg1: tensor<*x!tf_type.resource>) -> tensor<0xf32> attributes {tf._input_shapes = [#tf_type.shape<0>, #tf_type.shape<>, #tf_type.shape<9x9x9>]} { + %0 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf_type.resource>) -> tensor<0xf32> + %1 = "tf.AddV2"(%arg0, %0) : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32> + func.return %1 : tensor<0xf32> + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/BUILD index 4fd67f248d8..71493d0f30a 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/BUILD @@ -1,6 +1,8 @@ load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) + licenses(["notice"]) glob_lit_tests( diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/batch_use_same_function/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/batch_use_same_function/BUILD index 4d34db63405..ab1cc6459a1 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/batch_use_same_function/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/batch_use_same_function/BUILD @@ -1,6 +1,8 @@ load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) + licenses(["notice"]) glob_lit_tests( diff --git a/tensorflow/compiler/mlir/tensorflow/tests/guarantee-all-funcs-one-use.mlir 
b/tensorflow/compiler/mlir/tensorflow/tests/guarantee-all-funcs-one-use.mlir index 8003ed93789..52f29badb19 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/guarantee-all-funcs-one-use.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/guarantee-all-funcs-one-use.mlir @@ -68,3 +68,21 @@ module { func.return } } + +// ----- +// Test stateful and stateless partitioned calls. +// CHECK-LABEL: func @f +func.func @f() { + // CHECK: "tf.PartitionedCall"() {config = "", config_proto = "", executor_type = "", f = @g} : () -> () + "tf.PartitionedCall"() {config = "", config_proto = "", executor_type = "", f = @g} : () -> () + // CHECK: "tf.StatefulPartitionedCall"() {config = "", config_proto = "", executor_type = "", f = @[[NEWG:.+]]} : () -> () + "tf.StatefulPartitionedCall"() {config = "", config_proto = "", executor_type = "", f = @g} : () -> () + func.return +} + +// CHECK: func.func @g() +// CHECK: func.func private @[[NEWG]]() +func.func @g() { + func.return +} + diff --git a/tensorflow/compiler/mlir/tensorflow/tests/launch_to_device_attribute.mlir b/tensorflow/compiler/mlir/tensorflow/tests/launch_to_device_attribute.mlir index b1539bae843..5c73ce94600 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/launch_to_device_attribute.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/launch_to_device_attribute.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-launch-to-device-attribute | FileCheck %s +// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-launch-to-device-attribute=legacy-graph-export=false | FileCheck %s // Tests single TensorFlow op is hoisted out and has the correct device assigned @@ -20,12 +20,13 @@ func.func @single_op_launch() { func.return } -// CHECK: %[[A:.*]] = "tf.opA" -// CHECK: %[[B:.*]]:2 = "tf.opB"(%[[A]]) +// CHECK-NOT: tf_executor.island +// CHECK: %[[A:.*]], {{.*}} = tf_executor.island wraps "tf.opA" +// CHECK: %[[B:.*]]:2, {{.*}} = tf_executor.island wraps "tf.opB"(%[[A]]) // CHECK-SAME: device = "CPU:0" -// CHECK: %[[C:.*]] = "tf.opC" +// CHECK: %[[C:.*]], {{.*}} = tf_executor.island wraps "tf.opC" // CHECK-NOT: "tf_device.launch" -// CHECK: tf_executor.yield %[[A]], %[[B]]#1, %[[B]]#0, %[[C]] +// CHECK-NOT: tf_executor.yield // Tests multiple TensorFlow ops are hoisted out and all have the correct device @@ -48,14 +49,15 @@ func.func @multi_op_launch() { func.return } -// CHECK: %[[A:.*]] = "tf.opA" -// CHECK: %[[B:.*]] = "tf.opB"(%[[A]]) +// CHECK-NOT: tf_executor.island +// CHECK: %[[A:.*]], {{.*}} = tf_executor.island wraps "tf.opA" +// CHECK: %[[B:.*]], {{.*}} = tf_executor.island wraps "tf.opB"(%[[A]]) // CHECK-SAME: device = "CPU:0" -// CHECK: %[[C:.*]] = "tf.opC"(%[[B]]) +// CHECK: %[[C:.*]], {{.*}} = tf_executor.island wraps "tf.opC"(%[[B]]) // CHECK-SAME: device = "CPU:0" -// CHECK: %[[D:.*]] = "tf.opD" +// CHECK: %[[D:.*]], {{.*}} = tf_executor.island wraps "tf.opD" // CHECK-NOT: "tf_device.launch" -// CHECK: tf_executor.yield %[[A]], %[[C]], %[[B]], %[[D]] +// CHECK-NOT: tf_executor.yield %[[A]], %[[C]], %[[B]], %[[D]] // Tests empty device string attributes are overwritten. 
@@ -74,12 +76,41 @@ func.func @empty_device_op() { func.return } -// CHECK: [[A:%.+]]:2 = "tf.opA" +// CHECK-NOT: tf_executor.island +// CHECK: [[A:%.+]]:2, {{.*}} = tf_executor.island wraps "tf.opA" // CHECK-SAME: device = "CPU:0" // CHECK-NOT: tf_device.launch -// CHECK: tf_executor.yield [[A]]#1, [[A]]#0 +// CHECK-NOT: tf_executor.yield [[A]]#1, [[A]]#0 +// Tests annotation `parallel_execution_ids` can be propagated correctly +// CHECK-LABEL: func @propagate_parallel_execution_ids +func.func @propagate_parallel_execution_ids() { + tf_executor.graph { + %0:5 = tf_executor.island { + %a = "tf.opA"() : () -> tensor + %launch:2 = "tf_device.launch"() ({ + %b = "tf.opB"(%a) : (tensor) -> tensor + %c = "tf.opC"(%b) : (tensor) -> tensor + tf_device.return %c, %b : tensor, tensor + }) {device = "CPU:0", _parallel_execution_ids = "r4:5,p0:0"} : () -> (tensor, tensor) + %d = "tf.opD"() : () -> tensor + tf_executor.yield %a, %launch#0, %launch#1, %d : tensor, tensor, tensor, tensor + } + tf_executor.fetch + } + func.return +} + +// CHECK: %[[A:.*]], {{.*}} = tf_executor.island wraps "tf.opA" +// CHECK: %[[B:.*]], {{.*}} = tf_executor.island wraps "tf.opB"(%[[A]]) +// CHECK-SAME: _parallel_execution_ids = "r4:5,p0:0", device = "CPU:0" +// CHECK: %[[C:.*]], {{.*}} = tf_executor.island wraps "tf.opC"(%[[B]]) +// CHECK-SAME: _parallel_execution_ids = "r4:5,p0:0", device = "CPU:0" +// CHECK: %[[D:.*]], {{.*}} = tf_executor.island wraps "tf.opD" +// CHECK-NOT: "tf_device.launch" +// CHECK-NOT: tf_executor.yield %[[A]], %[[C]], %[[B]], %[[D]] + // ----- diff --git a/tensorflow/compiler/mlir/tensorflow/tests/launch_to_device_attribute_legacy.mlir b/tensorflow/compiler/mlir/tensorflow/tests/launch_to_device_attribute_legacy.mlir new file mode 100644 index 00000000000..c520792b807 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/launch_to_device_attribute_legacy.mlir @@ -0,0 +1,121 @@ +// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-launch-to-device-attribute=legacy-graph-export=true | FileCheck %s + + +// Tests single TensorFlow op is hoisted out and has the correct device assigned +// by parent `tf_device.launch`. +// CHECK-LABEL: func @single_op_launch +func.func @single_op_launch() { + tf_executor.graph { + %0:5 = tf_executor.island { + %a = "tf.opA"() : () -> tensor + %launch:2 = "tf_device.launch"() ({ + %b:2 = "tf.opB"(%a) : (tensor) -> (tensor, tensor) + tf_device.return %b#1, %b#0 : tensor, tensor + }) {device = "CPU:0"} : () -> (tensor, tensor) + %c = "tf.opC"() : () -> tensor + tf_executor.yield %a, %launch#0, %launch#1, %c : tensor, tensor, tensor, tensor + } + tf_executor.fetch + } + func.return +} + +// CHECK: %[[A:.*]] = "tf.opA" +// CHECK: %[[B:.*]]:2 = "tf.opB"(%[[A]]) +// CHECK-SAME: device = "CPU:0" +// CHECK: %[[C:.*]] = "tf.opC" +// CHECK-NOT: "tf_device.launch" +// CHECK: tf_executor.yield %[[A]], %[[B]]#1, %[[B]]#0, %[[C]] + + +// Tests multiple TensorFlow ops are hoisted out and all have the correct device +// assigned by parent `tf_device.launch`. 
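+// As in the test above, legacy graph export keeps the hoisted ops in their original island and preserves tf_executor.yield, unlike the legacy-graph-export=false tests earlier in this change, which expect ops to be rewrapped as "tf_executor.island wraps".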
+// CHECK-LABEL: func @multi_op_launch +func.func @multi_op_launch() { + tf_executor.graph { + %0:5 = tf_executor.island { + %a = "tf.opA"() : () -> tensor + %launch:2 = "tf_device.launch"() ({ + %b = "tf.opB"(%a) : (tensor) -> tensor + %c = "tf.opC"(%b) : (tensor) -> tensor + tf_device.return %c, %b : tensor, tensor + }) {device = "CPU:0"} : () -> (tensor, tensor) + %d = "tf.opD"() : () -> tensor + tf_executor.yield %a, %launch#0, %launch#1, %d : tensor, tensor, tensor, tensor + } + tf_executor.fetch + } + func.return +} + +// CHECK: %[[A:.*]] = "tf.opA" +// CHECK: %[[B:.*]] = "tf.opB"(%[[A]]) +// CHECK-SAME: device = "CPU:0" +// CHECK: %[[C:.*]] = "tf.opC"(%[[B]]) +// CHECK-SAME: device = "CPU:0" +// CHECK: %[[D:.*]] = "tf.opD" +// CHECK-NOT: "tf_device.launch" +// CHECK: tf_executor.yield %[[A]], %[[C]], %[[B]], %[[D]] + + +// Tests empty device string attributes are overwritten. +// CHECK-LABEL: func @empty_device_op +func.func @empty_device_op() { + tf_executor.graph { + %0:3 = tf_executor.island { + %launch:2 = "tf_device.launch"() ({ + %a:2 = "tf.opA"() {device = ""} : () -> (tensor, tensor) + tf_device.return %a#1, %a#0 : tensor, tensor + }) {device = "CPU:0"} : () -> (tensor, tensor) + tf_executor.yield %launch#0, %launch#1: tensor, tensor + } + tf_executor.fetch + } + func.return +} + +// CHECK: [[A:%.+]]:2 = "tf.opA" +// CHECK-SAME: device = "CPU:0" +// CHECK-NOT: tf_device.launch +// CHECK: tf_executor.yield [[A]]#1, [[A]]#0 + + +// ----- + + +// Tests TensorFlow op with conflicting `device` attribute compared to parent +// `tf_device.launch`. +func.func @conflicting_device() { + tf_executor.graph { + %0 = tf_executor.island { + // expected-error@+1 {{'tf_device.launch' op inner op has conflicting 'device' attribute, got 'GPU:0' but expected 'CPU:0'}} + "tf_device.launch"() ({ + "tf.opA"() {device = "GPU:0"} : () -> () + tf_device.return + }) {device = "CPU:0"} : () -> () + tf_executor.yield + } + tf_executor.fetch + } + func.return +} + + +// ----- + + +// Tests TensorFlow op with bad `device` attribute already set. 
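+// Here "bad" means the attribute is not a string: the op below sets device = 0 : i32, which the pass reports as an error.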
+func.func @bad_tf_device_attr() { + tf_executor.graph { + %0 = tf_executor.island { + // expected-error@+1 {{'tf_device.launch' op inner op has bad 'device' attribute}} + "tf_device.launch"() ({ + "tf.opA"() {device = 0 : i32} : () -> () + tf_device.return + }) {device = "CPU:0"} : () -> () + tf_executor.yield + } + tf_executor.fetch + } + func.return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir index 8c853de1e69..e9745eb4e08 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir @@ -1576,6 +1576,17 @@ func.func @convert_dot_general(%arg0: tensor<3x2x6x5x1xf32>, %arg1: tensor<3x2x4 func.return %0 : tensor<3x5x1x4xf32> } +// CHECK-LABEL: func @quantized_dot_general_not_converted +// CHECK: "mhlo.dot_general" +func.func @quantized_dot_general_not_converted(%arg0: tensor<1x1x512xf32>, %arg1: tensor<512x512x!quant.uniform>) -> tensor<1x1x512xf32> { + %0 = "mhlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #mhlo.dot< + lhs_contracting_dimensions = [2], + rhs_contracting_dimensions = [0] + >} : (tensor<1x1x512xf32>, tensor<512x512x!quant.uniform>) -> tensor<1x1x512xf32> + func.return %0 : tensor<1x1x512xf32> +} + // CHECK-LABEL: func @convert_dot_general_repeated( // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x1024xf32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<1024x1024xf32>) -> tensor<1x1x1024xf32> { @@ -2214,6 +2225,30 @@ func.func @convert_avgpool_same(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x8x8x8x func.return %4 : tensor<4x8x8x8xf32> } +// CHECK-LABEL: func @convert_avgpool_reshape_broadcast( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> { +// CHECK: %[[VAL_1:.*]] = "tf.AvgPool"(%[[VAL_0]]) {data_format = "NHWC", ksize = [1, 3, 3, 1], padding = "SAME", strides = [1, 2, 2, 1]} : (tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> +// CHECK: return %[[VAL_1]] : tensor<4x8x8x8xf32> +// CHECK: } +func.func @convert_avgpool_reshape_broadcast(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> { + %0 = mhlo.constant dense<1.000000e+00> : tensor<1x16x16x1xf32> + %1 = mhlo.constant dense<0.000000e+00> : tensor + %2 = "mhlo.reduce_window"(%arg0, %1) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %7 = mhlo.add %arg1, %arg2 : tensor + mhlo.return %7 : tensor + }) {base_dilations = dense<1> : tensor<4xi64>, padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor) -> tensor<4x8x8x8xf32> + %3 = "mhlo.reduce_window"(%0, %1) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %7 = mhlo.add %arg1, %arg2 : tensor + mhlo.return %7 : tensor + }) {base_dilations = dense<1> : tensor<4xi64>, padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<1x16x16x1xf32>, tensor) -> tensor<1x8x8x1xf32> + %4 = mhlo.reshape %3 : (tensor<1x8x8x1xf32>) -> tensor<8x8xf32> + %5 = "mhlo.broadcast_in_dim"(%4) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<8x8xf32>) -> tensor<4x8x8x8xf32> + %6 = mhlo.divide %2, %5 : tensor<4x8x8x8xf32> + return %6 : tensor<4x8x8x8xf32> +} + // CHECK-LABEL: func @convert_maxpool_valid( // CHECK-SAME: %[[VAL_0:.*]]: 
tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> { // CHECK: %[[VAL_1:.*]] = "tf.MaxPool"(%[[VAL_0]]) {data_format = "NHWC", explicit_paddings = [], ksize = [1, 3, 3, 1], padding = "VALID", strides = [1, 2, 2, 1]} : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> @@ -2388,6 +2423,23 @@ func.func @convert_floor_mod_int_cst(%arg0: tensor<192x8xi32>) -> tensor<192x8xi func.return %7 : tensor<192x8xi32> } +// CHECK-LABEL: func @convert_floor_mod_bfloat +// CHECK: %[[RESULT:.*]] = "tf.FloorMod"(%arg0, %arg1) : (tensor<10x10xbf16>, tensor<10x10xbf16>) -> tensor<10x10xbf16> +// CHECK: return %[[RESULT]] +// CHECK: } +func.func @convert_floor_mod_bfloat(%arg0: tensor<10x10xbf16>, %arg1: tensor<10x10xbf16>) -> tensor<10x10xbf16> { + %0 = mhlo.constant dense<0.000000e+00> : tensor<10x10xbf16> + %1 = mhlo.remainder %arg0, %arg1 : tensor<10x10xbf16> + %2 = mhlo.compare NE, %1, %0, FLOAT : (tensor<10x10xbf16>, tensor<10x10xbf16>) -> tensor<10x10xi1> + %3 = mhlo.compare LT, %1, %0, FLOAT : (tensor<10x10xbf16>, tensor<10x10xbf16>) -> tensor<10x10xi1> + %4 = mhlo.compare LT, %arg1, %0, FLOAT : (tensor<10x10xbf16>, tensor<10x10xbf16>) -> tensor<10x10xi1> + %5 = mhlo.compare NE, %3, %4, UNSIGNED : (tensor<10x10xi1>, tensor<10x10xi1>) -> tensor<10x10xi1> + %6 = mhlo.and %5, %2 : tensor<10x10xi1> + %7 = mhlo.add %1, %arg1 : tensor<10x10xbf16> + %8 = mhlo.select %6, %7, %1 : tensor<10x10xi1>, tensor<10x10xbf16> + return %8 : tensor<10x10xbf16> +} + // CHECK-LABEL: func @convert_floor_div // CHECK: %[[RESULT:.*]] = "tf.FloorDiv"(%arg0, %arg1) : (tensor<10x10xbf16>, tensor<10x10xbf16>) -> tensor<10x10xbf16> // CHECK: return %[[RESULT]] @@ -3020,6 +3072,94 @@ func.func @convert_not(%arg0: tensor<5x3x1xi1>) -> tensor<5x3x1xi1> { func.return %0 : tensor<5x3x1xi1> } +// CHECK-LABEL: func @convert_not_i8( +// CHECK-SAME: %[[ARG:.*]]: tensor<7x9x11xi8>) -> tensor<7x9x11xi8> { +// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor +// CHECK: %[[RES:.*]] = "tf.BitwiseXor"(%[[ARG]], %[[CST]]) : (tensor<7x9x11xi8>, tensor) -> tensor<7x9x11xi8> +// CHECK: return %[[RES]] : tensor<7x9x11xi8> +// CHECK: } +func.func @convert_not_i8(%arg0: tensor<7x9x11xi8>) -> tensor<7x9x11xi8> { + %0 = "mhlo.not"(%arg0): (tensor<7x9x11xi8>) -> (tensor<7x9x11xi8>) + func.return %0 : tensor<7x9x11xi8> +} + +// CHECK-LABEL: func @convert_not_i16( +// CHECK-SAME: %[[ARG:.*]]: tensor<7x9x11xi16>) -> tensor<7x9x11xi16> { +// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor +// CHECK: %[[RES:.*]] = "tf.BitwiseXor"(%[[ARG]], %[[CST]]) : (tensor<7x9x11xi16>, tensor) -> tensor<7x9x11xi16> +// CHECK: return %[[RES]] : tensor<7x9x11xi16> +// CHECK: } +func.func @convert_not_i16(%arg0: tensor<7x9x11xi16>) -> tensor<7x9x11xi16> { + %0 = "mhlo.not"(%arg0): (tensor<7x9x11xi16>) -> (tensor<7x9x11xi16>) + func.return %0 : tensor<7x9x11xi16> +} + +// CHECK-LABEL: func @convert_not_i32( +// CHECK-SAME: %[[ARG:.*]]: tensor<7x9x11xi32>) -> tensor<7x9x11xi32> { +// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor +// CHECK: %[[RES:.*]] = "tf.BitwiseXor"(%[[ARG]], %[[CST]]) : (tensor<7x9x11xi32>, tensor) -> tensor<7x9x11xi32> +// CHECK: return %[[RES]] : tensor<7x9x11xi32> +// CHECK: } +func.func @convert_not_i32(%arg0: tensor<7x9x11xi32>) -> tensor<7x9x11xi32> { + %0 = "mhlo.not"(%arg0): (tensor<7x9x11xi32>) -> (tensor<7x9x11xi32>) + func.return %0 : tensor<7x9x11xi32> +} + +// CHECK-LABEL: func @convert_not_i64( +// CHECK-SAME: %[[ARG:.*]]: tensor<7x9x11xi64>) -> tensor<7x9x11xi64> 
{ +// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor +// CHECK: %[[RES:.*]] = "tf.BitwiseXor"(%[[ARG]], %[[CST]]) : (tensor<7x9x11xi64>, tensor) -> tensor<7x9x11xi64> +// CHECK: return %[[RES]] : tensor<7x9x11xi64> +// CHECK: } +func.func @convert_not_i64(%arg0: tensor<7x9x11xi64>) -> tensor<7x9x11xi64> { + %0 = "mhlo.not"(%arg0): (tensor<7x9x11xi64>) -> (tensor<7x9x11xi64>) + func.return %0 : tensor<7x9x11xi64> +} + +// CHECK-LABEL: func @convert_not_ui8( +// CHECK-SAME: %[[ARG:.*]]: tensor<7x9x11xui8>) -> tensor<7x9x11xui8> { +// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<255> : tensor} : () -> tensor +// CHECK: %[[RES:.*]] = "tf.BitwiseXor"(%[[ARG]], %[[CST]]) : (tensor<7x9x11xui8>, tensor) -> tensor<7x9x11xui8> +// CHECK: return %[[RES]] : tensor<7x9x11xui8> +// CHECK: } +func.func @convert_not_ui8(%arg0: tensor<7x9x11xui8>) -> tensor<7x9x11xui8> { + %0 = "mhlo.not"(%arg0): (tensor<7x9x11xui8>) -> (tensor<7x9x11xui8>) + func.return %0 : tensor<7x9x11xui8> +} + +// CHECK-LABEL: func @convert_not_ui16( +// CHECK-SAME: %[[ARG:.*]]: tensor<7x9x11xui16>) -> tensor<7x9x11xui16> { +// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<65535> : tensor} : () -> tensor +// CHECK: %[[RES:.*]] = "tf.BitwiseXor"(%[[ARG]], %[[CST]]) : (tensor<7x9x11xui16>, tensor) -> tensor<7x9x11xui16> +// CHECK: return %[[RES]] : tensor<7x9x11xui16> +// CHECK: } +func.func @convert_not_ui16(%arg0: tensor<7x9x11xui16>) -> tensor<7x9x11xui16> { + %0 = "mhlo.not"(%arg0): (tensor<7x9x11xui16>) -> (tensor<7x9x11xui16>) + func.return %0 : tensor<7x9x11xui16> +} + +// CHECK-LABEL: func @convert_not_ui32( +// CHECK-SAME: %[[ARG:.*]]: tensor<7x9x11xui32>) -> tensor<7x9x11xui32> { +// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<4294967295> : tensor} : () -> tensor +// CHECK: %[[RES:.*]] = "tf.BitwiseXor"(%[[ARG]], %[[CST]]) : (tensor<7x9x11xui32>, tensor) -> tensor<7x9x11xui32> +// CHECK: return %[[RES]] : tensor<7x9x11xui32> +// CHECK: } +func.func @convert_not_ui32(%arg0: tensor<7x9x11xui32>) -> tensor<7x9x11xui32> { + %0 = "mhlo.not"(%arg0): (tensor<7x9x11xui32>) -> (tensor<7x9x11xui32>) + func.return %0 : tensor<7x9x11xui32> +} + +// CHECK-LABEL: func @convert_not_ui64( +// CHECK-SAME: %[[ARG:.*]]: tensor<7x9x11xui64>) -> tensor<7x9x11xui64> { +// CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<18446744073709551615> : tensor} : () -> tensor +// CHECK: %[[RES:.*]] = "tf.BitwiseXor"(%[[ARG]], %[[CST]]) : (tensor<7x9x11xui64>, tensor) -> tensor<7x9x11xui64> +// CHECK: return %[[RES]] : tensor<7x9x11xui64> +// CHECK: } +func.func @convert_not_ui64(%arg0: tensor<7x9x11xui64>) -> tensor<7x9x11xui64> { + %0 = "mhlo.not"(%arg0): (tensor<7x9x11xui64>) -> (tensor<7x9x11xui64>) + func.return %0 : tensor<7x9x11xui64> +} + // ----- // CHECK-LABEL: func @while_with_variadic() -> (tensor, tensor, tensor) { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir index f34c6267661..432195aed3b 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir @@ -81,9 +81,9 @@ func.func @div_no_nan(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf3 func.return %0 : tensor<*xf32> } -// CHECK-LABEL: @truncate_div +// CHECK-LABEL: @truncate_div_int // CHECK-SAME: (%[[LHS:.*]]: tensor<*xi32>, %[[RHS:.*]]: tensor<*xi32>) -func.func @truncate_div(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) +func.func @truncate_div_int(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) 
-> tensor<*xi32> { // CHECK: %[[RESULT:.*]] = "tf.Div"(%[[LHS]], %[[RHS]]) // CHECK: return %[[RESULT]] @@ -92,6 +92,23 @@ func.func @truncate_div(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) func.return %0 : tensor<*xi32> } +// CHECK-LABEL: @truncate_div_float +// CHECK-SAME: (%[[LHS:.*]]: tensor<*xf32>, %[[RHS:.*]]: tensor<*xf32>) +func.func @truncate_div_float(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) + -> tensor<*xf32> { + // CHECK: %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor} : () -> tensor + // CHECK: %[[XDIVY:.*]] = "tf.Div"(%[[LHS]], %[[RHS]]) + // CHECK: %[[MASK:.*]] = "tf.Less"(%[[XDIVY]], %[[ZERO]]) + // CHECK: %[[CEIL:.*]] = "tf.Ceil"(%[[XDIVY]]) + // CHECK: %[[FLOOR:.*]] = "tf.Floor"(%[[XDIVY]]) + // CHECK: %[[RESULT:.*]] = "tf.SelectV2"(%[[MASK]], %[[CEIL]], %[[FLOOR]]) + %0 = "tf.TruncateDiv"(%arg0, %arg1) + : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + + // CHECK: return %[[RESULT]] + func.return %0 : tensor<*xf32> +} + // CHECK-LABEL: func @mul_no_nan // CHECK-SAME: (%[[X:.*]]: tensor<2x3xf32>, %[[Y:.*]]: tensor<3xf32>) func.func @mul_no_nan(%arg0: tensor<2x3xf32>, %arg1: tensor<3xf32>) -> tensor<2x3xf32> { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir index 06b24b385d4..a19dd27b73a 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir @@ -558,3 +558,12 @@ func.func @variant_block_arg(tensor>>) -> () { func.return } +// CHECK-LABEL: func @set_bound +func.func @set_bound(%arg0: tensor) -> tensor { + %bound = "tf.Const"() {value = dense<16> : tensor} : () -> tensor + // CHECK: tf.XlaSetBound + // CHECK-NOT: _xla_outside_compilation + %bounded = "tf.XlaSetBound"(%arg0, %bound) : (tensor, tensor) -> tensor + func.return %bounded : tensor +} + diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/BUILD index 9acd2b232d3..432d0ab8733 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/BUILD @@ -1,6 +1,8 @@ load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) + licenses(["notice"]) glob_lit_tests( diff --git a/tensorflow/compiler/mlir/tensorflow/tests/order_by_dialect.mlir b/tensorflow/compiler/mlir/tensorflow/tests/order_by_dialect.mlir index 399b3b6bac7..bb3f74d702b 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/order_by_dialect.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/order_by_dialect.mlir @@ -1,7 +1,7 @@ // RUN: tf-opt %s -allow-unregistered-dialect --tf-order-by-dialect --split-input-file | FileCheck %s // CHECK-LABEL: @interleave -func.func @interleave(%arg0: f32) -> (f32, f32, f32) { +func.func @interleave(%arg0: f32) -> (f32, f32, f32) attributes {ignore_side_effects_for_testing} { %0 = "x.a"(%arg0) : (f32) -> f32 %1 = "y.a"(%arg0) : (f32) -> f32 %2 = "z.a"(%arg0) : (f32) -> f32 @@ -26,7 +26,7 @@ func.func @interleave(%arg0: f32) -> (f32, f32, f32) { // ----- // CHECK-LABEL: @terminator -func.func @terminator(%arg0: f32) -> (f32) { +func.func @terminator(%arg0: f32) -> (f32) attributes {ignore_side_effects_for_testing} { func.call @terminator(%arg0) : 
(f32) -> (f32) "x.a"(%arg0) : (f32) -> () "y.a"(%arg0) : (f32) -> () @@ -43,7 +43,7 @@ func.func @terminator(%arg0: f32) -> (f32) { // ----- // CHECK-LABEL: @fanout -func.func @fanout(%arg0: f32) -> (f32) { +func.func @fanout(%arg0: f32) -> (f32) attributes {ignore_side_effects_for_testing} { %0 = "x.a"(%arg0) : (f32) -> (f32) %1 = "y.a"(%0) : (f32) -> (f32) %2 = "y.b"(%0) : (f32) -> (f32) @@ -62,7 +62,7 @@ func.func @fanout(%arg0: f32) -> (f32) { // ----- // CHECK-LABEL: @constants -func.func @constants() -> f32 { +func.func @constants() -> f32 attributes {ignore_side_effects_for_testing} { %0 = "a.x"() : () -> f32 %1 = "b.x"() : () -> f32 %2 = "c.x"() : () -> f32 @@ -129,7 +129,7 @@ func.func private @mhlo_while() { // ----- // CHECK-LABEL: @nested_regions -func.func @nested_regions(%arg0: f32) { +func.func @nested_regions(%arg0: f32) attributes {ignore_side_effects_for_testing} { %0 = "x.a"(%arg0) : (f32) -> f32 %1 = "y.a"(%arg0) : (f32) -> f32 %2 = "x.b"(%arg0) : (f32) -> f32 @@ -157,3 +157,61 @@ func.func @nested_regions(%arg0: f32) { // CHECK-NEXT: x.d // CHECK-NEXT: y.c // CHECK-NEXT: y.e + +// ----- + +// CHECK-LABEL: interleaved_tf_and_mhlo +func.func private @interleaved_tf_and_mhlo() { + %m0 = mhlo.constant dense<0> : tensor + %t0 = "tf.Const"() { value = dense<0> : tensor<1xi32> } : () -> tensor<1xi32> + %m1 = mhlo.constant dense<1> : tensor + %t1 = "tf.Const"() { value = dense<1> : tensor<1xi32> } : () -> tensor<1xi32> + %m2 = mhlo.constant dense<1> : tensor + %t2 = "tf.Const"() { value = dense<1> : tensor<1xi32> } : () -> tensor<1xi32> + %m3 = mhlo.constant dense<1> : tensor + %t3 = "tf.Const"() { value = dense<1> : tensor<1xi32> } : () -> tensor<1xi32> + // CHECK: mhlo.constant + // CHECK: mhlo.constant + // CHECK: mhlo.constant + // CHECK: mhlo.constant + // CHECK: tf.Const + // CHECK: tf.Const + // CHECK: tf.Const + // CHECK: tf.Const + return +} + +// ----- + +// CHECK-LABEL: variable_ops +func.func private @variable_ops(%arg0: tensor>>) { + %t3 = "tf.Const"() { value = dense<0> : tensor<0xi32> } : () -> tensor<0xi32> + // Without side effect analysis, we would now schedule tf.ReadVariableOp next, + // since all its operands are ready. Check that we don't. 
+ %0 = mhlo.constant dense<0.> : tensor + "tf.AssignVariableOp"(%arg0, %0) : (tensor>>, tensor) -> () + %1 = "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor + // CHECK: tf.Const + // CHECK: mhlo.constant + // CHECK: tf.Assign + // CHECK: tf.Read + return +} + +// ----- + +func.func private @id(%arg0: tensor) -> tensor { + return %arg0 : tensor +} + +// CHECK-LABEL: iterators +func.func private @iterators(%arg0 : tensor) { + %0 = "tf.Iterator"() {container = "", output_shapes = [#tf_type.shape<200x28x28x1>, #tf_type.shape<200x10>], output_types = [f32, f32], shared_name = "_iterator1"} : () -> tensor + %1 = func.call @id(%arg0) : (tensor) -> tensor + "tf.MakeIterator"(%1, %0) {_class = ["loc:@BatchDatasetV2"], device = ""} : (tensor, tensor) -> () + %2:2 = "tf.IteratorGetNext"(%0) {_class = ["loc:@iterator"], device = ""} : (tensor) -> (tensor<200x28x28x1xf32>, tensor<200x10xf32>) + // CHECK: tf.Iterator + // CHECK: tf.MakeIterator + // CHECK: tf.IteratorGetNext + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/parallel_execute_to_islands.mlir b/tensorflow/compiler/mlir/tensorflow/tests/parallel_execute_to_islands.mlir index bb668bd3806..5c468ee37d8 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/parallel_execute_to_islands.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/parallel_execute_to_islands.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -tf-parallel-execute-to-islands | FILECHECK_OPTS="" FileCheck %s +// RUN: tf-opt %s -tf-parallel-execute-to-islands=legacy-graph-export=false | FILECHECK_OPTS="" FileCheck %s // CHECK-LABEL: func @testEmptyRegions func.func @testEmptyRegions() { @@ -17,10 +17,10 @@ func.func @testEmptyRegions() { } // CHECK: [[ISLAND_0_CTRL:%.+]] = tf_executor.island { -// CHECK: tf_executor.yield +// CHECK: tf_executor.yield {_parallel_execution_ids = "p0:0"} // CHECK: [[ISLAND_1_CTRL:%.+]] = tf_executor.island { -// CHECK: tf_executor.yield -// CHECK: tf_executor.fetch [[ISLAND_0_CTRL]], [[ISLAND_1_CTRL]] : +// CHECK: tf_executor.yield {_parallel_execution_ids = "p0:1"} +// CHECK: tf_executor.fetch // CHECK-LABEL: func @testDataOperandsAndResults @@ -50,10 +50,10 @@ func.func @testDataOperandsAndResults(%arg0 : tensor) { // CHECK-NEXT: [[OP_A_OUTPUT:%.+]] = "tf.opA"([[ARG_0]]) // CHECK-NEXT: tf_executor.yield [[OP_A_OUTPUT]] : // CHECK: [[ISLAND_0_OUTPUT:%.+]], {{%.+}} = tf_executor.island { -// CHECK-NEXT: [[OP_B_OUTPUT:%.+]] = "tf.opB"([[INPUT_A]]) +// CHECK-NEXT: [[OP_B_OUTPUT:%.+]] = "tf.opB"([[INPUT_A]]) {_parallel_execution_ids = "p0:0"} // CHECK: tf_executor.yield [[OP_B_OUTPUT]] : // CHECK: [[ISLAND_1_OUTPUT:%.+]], {{%.+}} = tf_executor.island { -// CHECK-NEXT: [[OP_C_OUTPUT:%.+]] = "tf.opC"([[INPUT_A]]) +// CHECK-NEXT: [[OP_C_OUTPUT:%.+]] = "tf.opC"([[INPUT_A]]) {_parallel_execution_ids = "p0:1"} // CHECK: tf_executor.yield [[OP_C_OUTPUT]] : // CHECK: tf_executor.fetch [[ISLAND_0_OUTPUT]], [[ISLAND_1_OUTPUT]] : @@ -62,6 +62,7 @@ func.func @testDataOperandsAndResults(%arg0 : tensor) { func.func @testControlOperands() { %0:2 = tf_executor.graph { %1 = tf_executor.island { + "tf.someOp"() : () -> () tf_executor.yield } %2:3 = tf_executor.island(%1) { @@ -81,10 +82,10 @@ func.func @testControlOperands() { // CHECK: [[INPUT_CTRL:%.+]] = tf_executor.island { // CHECK: [[ISLAND_0_OUTPUT:%.+]], {{%.+}} = tf_executor.island([[INPUT_CTRL]]) { -// CHECK-NEXT: [[OP_A_OUTPUT:%.+]] = "tf.opA"() +// CHECK-NEXT: [[OP_A_OUTPUT:%.+]] = "tf.opA"() {_parallel_execution_ids = "p0:0"} // CHECK: tf_executor.yield [[OP_A_OUTPUT]] : // CHECK: 
[[ISLAND_1_OUTPUT:%.+]], {{%.+}} = tf_executor.island([[INPUT_CTRL]]) { -// CHECK-NEXT: [[OP_B_OUTPUT:%.+]] = "tf.opB"() +// CHECK-NEXT: [[OP_B_OUTPUT:%.+]] = "tf.opB"() {_parallel_execution_ids = "p0:1"} // CHECK: tf_executor.yield [[OP_B_OUTPUT]] : // CHECK: tf_executor.fetch [[ISLAND_0_OUTPUT]], [[ISLAND_1_OUTPUT]] : @@ -103,6 +104,7 @@ func.func @testControlResults() { tf_executor.yield %1#0, %1#1 : tensor, tensor } %3 = tf_executor.island(%0#2) { + "tf.someOp"() : () -> () tf_executor.yield } tf_executor.fetch %3 : !tf_executor.control @@ -111,10 +113,10 @@ func.func @testControlResults() { } // CHECK: {{%.+}}, [[ISLAND_0_CTRL:%.+]] = tf_executor.island { -// CHECK-NEXT: [[OP_A_OUTPUT:%.+]] = "tf.opA"() +// CHECK-NEXT: [[OP_A_OUTPUT:%.+]] = "tf.opA"() {_parallel_execution_ids = "p0:0"} // CHECK: tf_executor.yield [[OP_A_OUTPUT]] : // CHECK: {{%.+}}, [[ISLAND_1_CTRL:%.+]] = tf_executor.island { -// CHECK-NEXT: [[OP_B_OUTPUT:%.+]] = "tf.opB"() +// CHECK-NEXT: [[OP_B_OUTPUT:%.+]] = "tf.opB"() {_parallel_execution_ids = "p0:1"} // CHECK: tf_executor.yield [[OP_B_OUTPUT]] : // CHECK: [[OUTPUT_CTRL:%.+]] = tf_executor.island([[ISLAND_0_CTRL]], [[ISLAND_1_CTRL]]) { // CHECK: [[FETCH_ISLAND:%.+]] = tf_executor.island([[OUTPUT_CTRL]]) { @@ -140,12 +142,38 @@ func.func @testSomeRegionNoUsers() { } // CHECK: [[ISLAND_0_OUTPUT:%.+]], {{%.+}} = tf_executor.island { -// CHECK-NEXT: [[OP_A_OUTPUT:%.+]] = "tf.opA"() +// CHECK-NEXT: [[OP_A_OUTPUT:%.+]] = "tf.opA"() {_parallel_execution_ids = "p0:0"} // CHECK: tf_executor.yield [[OP_A_OUTPUT]] : // CHECK: {{%.+}}, [[ISLAND_1_CTRL:%.+]] = tf_executor.island { -// CHECK-NEXT: [[OP_B_OUTPUT:%.+]] = "tf.opB"() +// CHECK-NEXT: [[OP_B_OUTPUT:%.+]] = "tf.opB"() {_parallel_execution_ids = "p0:1"} // CHECK: tf_executor.yield [[OP_B_OUTPUT]] : -// CHECK: tf_executor.fetch [[ISLAND_0_OUTPUT]], [[ISLAND_1_CTRL]] : +// CHECK: tf_executor.fetch [[ISLAND_0_OUTPUT]] : + +// CHECK-LABEL: @testRegionContainsMultipleOps +func.func @testRegionContainsMultipleOps() { + %0:2 = tf_executor.graph { + %outputs:2, %control = tf_executor.island { + %1:2 = "tf_device.parallel_execute"() ({ + %2 = "tf.opA"() : () -> tensor + %3 = "tf.opB"(%2) : (tensor) -> tensor + tf_device.return %3 : tensor + }, { + %2 = "tf.opC"() : () -> tensor + %3 = "tf.opD"(%2) : (tensor) -> tensor + tf_device.return %3 : tensor + }) : () -> (tensor, tensor) + tf_executor.yield %1#0, %1#1 : tensor, tensor + } + tf_executor.fetch %outputs#0, %outputs#1 : tensor, tensor + } + return +} + +// CHECK: [[OUTPUT_0:%.*]], {{%.*}} = tf_executor.island wraps "tf.opA"() {_parallel_execution_ids = "p0:0"} +// CHECK: [[OUTPUT_1:%.*]], {{%.*}} = tf_executor.island wraps "tf.opB"([[OUTPUT_0:%.*]]) {_parallel_execution_ids = "p0:0"} +// CHECK: [[OUTPUT_2:%.*]], {{%.*}} = tf_executor.island wraps "tf.opC"() {_parallel_execution_ids = "p0:1"} +// CHECK: [[OUTPUT_3:%.*]], {{%.*}} = tf_executor.island wraps "tf.opD"([[OUTPUT_2:%.*]]) {_parallel_execution_ids = "p0:1"} +// CHECK: tf_executor.fetch [[OUTPUT_1:%.*]], [[OUTPUT_3:%.*]] // ----- @@ -175,6 +203,34 @@ func.func @testSingleton(%arg0 : tensor) { // CHECK-NEXT: [[OP_A_OUTPUT:%.+]] = "tf.opA"([[ARG_0]]) // CHECK-NEXT: tf_executor.yield [[OP_A_OUTPUT]] : // CHECK: [[ISLAND_0_OUTPUT:%.+]], {{%.+}} = tf_executor.island { -// CHECK-NEXT: [[OP_B_OUTPUT:%.+]] = "tf.opB"([[INPUT_A]]) +// CHECK-NEXT: [[OP_B_OUTPUT:%.+]] = "tf.opB"([[INPUT_A]]) {_parallel_execution_ids = "p0:0"} // CHECK: tf_executor.yield [[OP_B_OUTPUT]] : // CHECK: tf_executor.fetch [[ISLAND_0_OUTPUT]] : + +// 
----- +// Tests parallel_group attr can merge correctly. +// CHECK-LABEL: func @merge_of_parallel_group_attr +func.func @merge_of_parallel_group_attr() { + %0:2 = tf_executor.graph { + %outputs:2, %control = tf_executor.island { + %1:2 = "tf_device.parallel_execute" () ({ + %2 = "tf.opA"() : () -> tensor + %3 = "tf.opB"(%2) : (tensor) -> tensor + tf_device.return %3 : tensor + }, { + %2 = "tf.opC"() : () -> tensor + %3 = "tf.opD"(%2) : (tensor) -> tensor + tf_device.return %2 : tensor + }) {_parallel_execution_ids = "r4:5"} : () -> (tensor, tensor) + tf_executor.yield %1#0, %1#1 : tensor, tensor + } + tf_executor.fetch %outputs#0, %outputs#1 : tensor, tensor + } + return +} + +// CHECK: [[OUTPUT_0:%.*]], {{%.*}} = tf_executor.island wraps "tf.opA"() {_parallel_execution_ids = "r4:5,p0:0"} +// CHECK: [[OUTPUT_1:%.*]], {{%.*}} = tf_executor.island wraps "tf.opB"([[OUTPUT_0:%.*]]) {_parallel_execution_ids = "r4:5,p0:0"} +// CHECK: [[OUTPUT_2:%.*]], {{%.*}} = tf_executor.island wraps "tf.opC"() {_parallel_execution_ids = "r4:5,p0:1"} +// CHECK: [[OUTPUT_3:%.*]], {{%.*}} = tf_executor.island wraps "tf.opD"([[OUTPUT_2:%.*]]) {_parallel_execution_ids = "r4:5,p0:1"} +// CHECK: tf_executor.fetch [[OUTPUT_1:%.*]], [[OUTPUT_3:%.*]] diff --git a/tensorflow/compiler/mlir/tensorflow/tests/parallel_execute_to_islands_legacy.mlir b/tensorflow/compiler/mlir/tensorflow/tests/parallel_execute_to_islands_legacy.mlir new file mode 100644 index 00000000000..56c589768b8 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/parallel_execute_to_islands_legacy.mlir @@ -0,0 +1,182 @@ +// RUN: tf-opt %s -tf-parallel-execute-to-islands=legacy-graph-export=true | FILECHECK_OPTS="" FileCheck %s + +// CHECK-LABEL: func @testEmptyRegions +func.func @testEmptyRegions() { + tf_executor.graph { + tf_executor.island() { + "tf_device.parallel_execute"() ({ + tf_device.return + }, { + tf_device.return + }) {} : () -> () + tf_executor.yield + } + tf_executor.fetch + } + func.return +} + +// CHECK: [[ISLAND_0_CTRL:%.+]] = tf_executor.island { +// CHECK: tf_executor.yield +// CHECK: [[ISLAND_1_CTRL:%.+]] = tf_executor.island { +// CHECK: tf_executor.yield +// CHECK: tf_executor.fetch [[ISLAND_0_CTRL]], [[ISLAND_1_CTRL]] : + + +// CHECK-LABEL: func @testDataOperandsAndResults +// CHECK-SAME: ([[ARG_0:%.+]]: tensor) +func.func @testDataOperandsAndResults(%arg0 : tensor) { + %0:2 = tf_executor.graph { + %1:2 = tf_executor.island { + %2 = "tf.opA"(%arg0) : (tensor) -> tensor + tf_executor.yield %2 : tensor + } + %3:3 = tf_executor.island() { + %4:2 = "tf_device.parallel_execute"() ({ + %5 = "tf.opB"(%1#0) : (tensor) -> tensor + tf_device.return %5 : tensor + }, { + %5 = "tf.opC"(%1#0) : (tensor) -> tensor + tf_device.return %5 : tensor + }) {} : () -> (tensor, tensor) + tf_executor.yield %4#0, %4#1 : tensor, tensor + } + tf_executor.fetch %3#0, %3#1 : tensor, tensor + } + func.return +} + +// CHECK: [[INPUT_A:%.+]], {{%.+}} = tf_executor.island { +// CHECK-NEXT: [[OP_A_OUTPUT:%.+]] = "tf.opA"([[ARG_0]]) +// CHECK-NEXT: tf_executor.yield [[OP_A_OUTPUT]] : +// CHECK: [[ISLAND_0_OUTPUT:%.+]], {{%.+}} = tf_executor.island { +// CHECK-NEXT: [[OP_B_OUTPUT:%.+]] = "tf.opB"([[INPUT_A]]) +// CHECK: tf_executor.yield [[OP_B_OUTPUT]] : +// CHECK: [[ISLAND_1_OUTPUT:%.+]], {{%.+}} = tf_executor.island { +// CHECK-NEXT: [[OP_C_OUTPUT:%.+]] = "tf.opC"([[INPUT_A]]) +// CHECK: tf_executor.yield [[OP_C_OUTPUT]] : +// CHECK: tf_executor.fetch [[ISLAND_0_OUTPUT]], [[ISLAND_1_OUTPUT]] : + + +// CHECK-LABEL: func @testControlOperands +func.func 
@testControlOperands() { + %0:2 = tf_executor.graph { + %1 = tf_executor.island { + "tf.someOp"() : () -> () + tf_executor.yield + } + %2:3 = tf_executor.island(%1) { + %3:2 = "tf_device.parallel_execute"() ({ + %4 = "tf.opA"() : () -> tensor + tf_device.return %4 : tensor + }, { + %4 = "tf.opB"() : () -> tensor + tf_device.return %4 : tensor + }) {} : () -> (tensor, tensor) + tf_executor.yield %3#0, %3#1 : tensor, tensor + } + tf_executor.fetch %2#0, %2#1 : tensor, tensor + } + func.return +} + +// CHECK: [[INPUT_CTRL:%.+]] = tf_executor.island { +// CHECK: [[ISLAND_0_OUTPUT:%.+]], {{%.+}} = tf_executor.island([[INPUT_CTRL]]) { +// CHECK-NEXT: [[OP_A_OUTPUT:%.+]] = "tf.opA"() +// CHECK: tf_executor.yield [[OP_A_OUTPUT]] : +// CHECK: [[ISLAND_1_OUTPUT:%.+]], {{%.+}} = tf_executor.island([[INPUT_CTRL]]) { +// CHECK-NEXT: [[OP_B_OUTPUT:%.+]] = "tf.opB"() +// CHECK: tf_executor.yield [[OP_B_OUTPUT]] : +// CHECK: tf_executor.fetch [[ISLAND_0_OUTPUT]], [[ISLAND_1_OUTPUT]] : + + +// CHECK-LABEL: func @testControlResults +func.func @testControlResults() { + tf_executor.graph { + %0:3 = tf_executor.island { + %1:2 = "tf_device.parallel_execute"() ({ + %2 = "tf.opA"() : () -> tensor + tf_device.return %2 : tensor + }, { + %2 = "tf.opB"() : () -> tensor + tf_device.return %2 : tensor + }) {} : () -> (tensor, tensor) + tf_executor.yield %1#0, %1#1 : tensor, tensor + } + %3 = tf_executor.island(%0#2) { + "tf.someOp"() : () -> () + tf_executor.yield + } + tf_executor.fetch %3 : !tf_executor.control + } + func.return +} + +// CHECK: {{%.+}}, [[ISLAND_0_CTRL:%.+]] = tf_executor.island { +// CHECK-NEXT: [[OP_A_OUTPUT:%.+]] = "tf.opA"() +// CHECK: tf_executor.yield [[OP_A_OUTPUT]] : +// CHECK: {{%.+}}, [[ISLAND_1_CTRL:%.+]] = tf_executor.island { +// CHECK-NEXT: [[OP_B_OUTPUT:%.+]] = "tf.opB"() +// CHECK: tf_executor.yield [[OP_B_OUTPUT]] : +// CHECK: [[OUTPUT_CTRL:%.+]] = tf_executor.island([[ISLAND_0_CTRL]], [[ISLAND_1_CTRL]]) { +// CHECK: [[FETCH_ISLAND:%.+]] = tf_executor.island([[OUTPUT_CTRL]]) { +// CHECK: tf_executor.fetch [[FETCH_ISLAND]] : !tf_executor.control + + +// CHECK-LABEL: func @testSomeRegionNoUsers +func.func @testSomeRegionNoUsers() { + %0 = tf_executor.graph { + %1:3 = tf_executor.island { + %2:2 = "tf_device.parallel_execute"() ({ + %3 = "tf.opA"() : () -> tensor + tf_device.return %3 : tensor + }, { + %3 = "tf.opB"() : () -> tensor + tf_device.return %3 : tensor + }) {} : () -> (tensor, tensor) + tf_executor.yield %2#0, %2#1 : tensor, tensor + } + tf_executor.fetch %1#0 : tensor + } + func.return +} + +// CHECK: [[ISLAND_0_OUTPUT:%.+]], {{%.+}} = tf_executor.island { +// CHECK-NEXT: [[OP_A_OUTPUT:%.+]] = "tf.opA"() +// CHECK: tf_executor.yield [[OP_A_OUTPUT]] : +// CHECK: {{%.+}}, [[ISLAND_1_CTRL:%.+]] = tf_executor.island { +// CHECK-NEXT: [[OP_B_OUTPUT:%.+]] = "tf.opB"() +// CHECK: tf_executor.yield [[OP_B_OUTPUT]] : +// CHECK: tf_executor.fetch [[ISLAND_0_OUTPUT]], [[ISLAND_1_CTRL]] : + +// ----- + +// Tests a ParallelExecute with a single region. 
+ +// CHECK-LABEL: func @testSingleton +// CHECK-SAME: ([[ARG_0:%.+]]: tensor) +func.func @testSingleton(%arg0 : tensor) { + %0 = tf_executor.graph { + %1:2 = tf_executor.island { + %2 = "tf.opA"(%arg0) : (tensor) -> tensor + tf_executor.yield %2 : tensor + } + %3:2 = tf_executor.island() { + %4 = "tf_device.parallel_execute"() ({ + %5 = "tf.opB"(%1#0) : (tensor) -> tensor + tf_device.return %5 : tensor + }) {} : () -> tensor + tf_executor.yield %4 : tensor + } + tf_executor.fetch %3#0 : tensor + } + func.return +} + +// CHECK: [[INPUT_A:%.+]], {{%.+}} = tf_executor.island { +// CHECK-NEXT: [[OP_A_OUTPUT:%.+]] = "tf.opA"([[ARG_0]]) +// CHECK-NEXT: tf_executor.yield [[OP_A_OUTPUT]] : +// CHECK: [[ISLAND_0_OUTPUT:%.+]], {{%.+}} = tf_executor.island { +// CHECK-NEXT: [[OP_B_OUTPUT:%.+]] = "tf.opB"([[INPUT_A]]) +// CHECK: tf_executor.yield [[OP_B_OUTPUT]] : +// CHECK: tf_executor.fetch [[ISLAND_0_OUTPUT]] : diff --git a/tensorflow/compiler/mlir/tensorflow/tests/remove_unused_arguments.mlir b/tensorflow/compiler/mlir/tensorflow/tests/remove_unused_arguments.mlir index d1c71c6f14b..e4c1941a1c2 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/remove_unused_arguments.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/remove_unused_arguments.mlir @@ -76,6 +76,21 @@ func.func @handles_partitioned_function_calls(%arg0: tensor, %arg1: tensor< // ----- +func.func private @f(%arg0: tensor, %arg1: tensor) -> tensor { + return %arg0 : tensor +} + +// CHECK-LABEL: handles_tpu_partitioned_function_calls_with_device_ordinal +func.func @handles_tpu_partitioned_function_calls_with_device_ordinal(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "tf.TPUOrdinalSelector"() {device = ""} : () -> tensor + // CHECK: [[ordinal:%[0-9]*]] = "tf.TPUOrdinalSelector" + // CHECK: TPUPartitionedCall"([[ordinal]]) + %1 = "tf.TPUPartitionedCall"(%arg0, %arg1, %0) {f = @f} : (tensor, tensor, tensor) -> tensor + return %1 : tensor +} + +// ----- + func.func private @f(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> f32 { %0 = "tf.Add2"(%arg0, %arg2) : (f32, f32) -> f32 return %0 : f32 diff --git a/tensorflow/compiler/mlir/tensorflow/tests/replicate_tensor_list_init_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/replicate_tensor_list_init_ops.mlir new file mode 100644 index 00000000000..10567878eaa --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/replicate_tensor_list_init_ops.mlir @@ -0,0 +1,139 @@ +// RUN: tf-opt %s -tf-replicate-tensor-list-init-ops -verify-diagnostics | FileCheck %s + +// CHECK: while_region_op +func.func @while_region_op() { + %elem_shape = "tf.Const"() {value = dense<[-1, 1]> : tensor<2xi32>} : () -> tensor<2xi32> + %size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + // CHECK: TensorListReserve + // CHECK: TensorListReserve + %tl = "tf.TensorListReserve"(%elem_shape, %size) : (tensor<2xi32>, tensor) -> tensor>> + %while:1 = "tf.WhileRegion"(%tl) ({ + ^bb0(%barg1: tensor>>): // no predeceessors + %cond = "tf.false"():()-> tensor + "tf.Yield"(%cond) : (tensor) -> () + }, { + ^bb0(%barg1: tensor>>): // no predeceessors + "tf.Yield"(%barg1) : (tensor>>) -> () + }) {is_stateless = false} : (tensor>>) -> (tensor>>) + func.return +} + +// CHECK: while_region_op_empty_tensor_list +func.func @while_region_op_empty_tensor_list() { + %elem_shape = "tf.Const"() {value = dense<[-1, 1]> : tensor<2xi32>} : () -> tensor<2xi32> + %size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + // CHECK: EmptyTensorList + // CHECK: EmptyTensorList + %tl = 
"tf.EmptyTensorList"(%elem_shape, %size) : (tensor<2xi32>, tensor) -> tensor>> + %while:1 = "tf.WhileRegion"(%tl) ({ + ^bb0(%barg1: tensor>>): // no predeceessors + %cond = "tf.false"():()-> tensor + "tf.Yield"(%cond) : (tensor) -> () + }, { + ^bb0(%barg1: tensor>>): // no predeceessors + "tf.Yield"(%barg1) : (tensor>>) -> () + }) {is_stateless = false} : (tensor>>) -> (tensor>>) + func.return +} + +// CHECK: while_region_op_twosepargs +func.func @while_region_op_twosepargs() { + %one = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %elem_shape = "tf.Const"() {value = dense<[-1, 1]> : tensor<2xi32>} : () -> tensor<2xi32> + %size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + // CHECK: %[[RESULT0:.*]] = "tf.TensorListReserve" + // CHECK: %[[RESULT1:.*]] = "tf.TensorListReserve" + // CHECK: %[[RESULT2:.*]] = "tf.TensorListReserve" + // CHECK: tf.WhileRegion + // CHECK-SAME: (%[[RESULT1]], %[[RESULT0]]) + %tl = "tf.TensorListReserve"(%elem_shape, %size) : (tensor<2xi32>, tensor) -> tensor>> + %while:2 = "tf.WhileRegion"(%tl, %tl) ({ + ^bb0(%barg1: tensor>>, %barg2: tensor>>): // no predeceessors + %cond = "tf.false"():()-> tensor + "tf.Yield"(%cond) : (tensor) -> () + }, { + ^bb0(%barg1: tensor>>, %barg2: tensor>>): // no predeceessors + "tf.Yield"(%barg1, %barg2) : (tensor>>, tensor>>) -> () + }) {is_stateless = false} : (tensor>>, tensor>>) -> (tensor>>, tensor>>) + func.return +} + +// CHECK: while_region_op_two_sep_args_empty_tensor_list +func.func @while_region_op_two_sep_args_empty_tensor_list() { + %one = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %elem_shape = "tf.Const"() {value = dense<[-1, 1]> : tensor<2xi32>} : () -> tensor<2xi32> + %size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + // CHECK: %[[RESULT0:.*]] = "tf.EmptyTensorList" + // CHECK: %[[RESULT1:.*]] = "tf.EmptyTensorList" + // CHECK: %[[RESULT2:.*]] = "tf.EmptyTensorList" + // CHECK: tf.WhileRegion + // CHECK-SAME: (%[[RESULT1]], %[[RESULT0]]) + %tl = "tf.EmptyTensorList"(%elem_shape, %size) : (tensor<2xi32>, tensor) -> tensor>> + %while:2 = "tf.WhileRegion"(%tl, %tl) ({ + ^bb0(%barg1: tensor>>, %barg2: tensor>>): // no predeceessors + %cond = "tf.false"():()-> tensor + "tf.Yield"(%cond) : (tensor) -> () + }, { + ^bb0(%barg1: tensor>>, %barg2: tensor>>): // no predeceessors + "tf.Yield"(%barg1, %barg2) : (tensor>>, tensor>>) -> () + }) {is_stateless = false} : (tensor>>, tensor>>) -> (tensor>>, tensor>>) + func.return +} + +// CHECK: no_while_region_op +func.func @no_while_region_op() { + %one = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %elem_shape = "tf.Const"() {value = dense<[-1, 1]> : tensor<2xi32>} : () -> tensor<2xi32> + %size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + // CHECK: TensorListReserve + // CHECK: TensorListReserve + %tl = "tf.TensorListReserve"(%elem_shape, %size) : (tensor<2xi32>, tensor) -> tensor>> + %elem_1 = "tf._SomeOtherOp"() : () -> tensor<8x1xf32> + %tl_set_item = "tf.TensorListSetItem"(%tl, %one, %elem_1) : (tensor>>, tensor, tensor<8x1xf32>) -> tensor>> + func.return +} + +// CHECK: no_while_region_op_empty_tensor_list +func.func @no_while_region_op_empty_tensor_list() { + %one = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %elem_shape = "tf.Const"() {value = dense<[-1, 1]> : tensor<2xi32>} : () -> tensor<2xi32> + %size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + // CHECK: EmptyTensorList + // CHECK: EmptyTensorList + %tl = "tf.EmptyTensorList"(%elem_shape, %size) : 
(tensor<2xi32>, tensor) -> tensor>> + %elem_1 = "tf._SomeOtherOp"() : () -> tensor<8x1xf32> + %tl_set_item = "tf.TensorListSetItem"(%tl, %one, %elem_1) : (tensor>>, tensor, tensor<8x1xf32>) -> tensor>> + func.return +} + +// CHECK: use_two_sep_ops +func.func @use_two_sep_ops() { + %one = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %elem_shape = "tf.Const"() {value = dense<[-1, 1]> : tensor<2xi32>} : () -> tensor<2xi32> + %size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + // CHECK: TensorListReserve + // CHECK: TensorListReserve + // CHECK: TensorListReserve + %tl = "tf.TensorListReserve"(%elem_shape, %size) : (tensor<2xi32>, tensor) -> tensor>> + %elem_1 = "tf._FirstOp"() : () -> tensor<8x1xf32> + %tl_set_item = "tf.TensorListSetItem"(%tl, %one, %elem_1) : (tensor>>, tensor, tensor<8x1xf32>) -> tensor>> + %elem_2 = "tf._SecondOp"() : () -> tensor<16x1xf32> + %tl_set_item2 = "tf.TensorListSetItem"(%tl, %one, %elem_2) : (tensor>>, tensor, tensor<16x1xf32>) -> tensor>> + func.return +} + +// CHECK: use_two_sep_ops_empty_tensor_list +func.func @use_two_sep_ops_empty_tensor_list() { + %one = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %elem_shape = "tf.Const"() {value = dense<[-1, 1]> : tensor<2xi32>} : () -> tensor<2xi32> + %size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + // CHECK: EmptyTensorList + // CHECK: EmptyTensorList + // CHECK: EmptyTensorList + %tl = "tf.EmptyTensorList"(%elem_shape, %size) : (tensor<2xi32>, tensor) -> tensor>> + %elem_1 = "tf._FirstOp"() : () -> tensor<8x1xf32> + %tl_set_item = "tf.TensorListSetItem"(%tl, %one, %elem_1) : (tensor>>, tensor, tensor<8x1xf32>) -> tensor>> + %elem_2 = "tf._SecondOp"() : () -> tensor<16x1xf32> + %tl_set_item2 = "tf.TensorListSetItem"(%tl, %one, %elem_2) : (tensor>>, tensor, tensor<16x1xf32>) -> tensor>> + func.return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir b/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir index eed78e0e503..079cdafc484 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt -split-input-file -verify-diagnostics %s -tf-replicate-to-island | FileCheck %s +// RUN: tf-opt -split-input-file -verify-diagnostics %s -tf-replicate-to-island=legacy-graph-export=false | FileCheck %s // Tests per replica island has same control operands as island holding // replicate. @@ -44,9 +44,9 @@ func.func @no_devices() { } // CHECK: "tf.opA" -// CHECK: device = "CORE_0" +// CHECK: _parallel_execution_ids = "r0:0", device = "CORE_0" // CHECK: "tf.opA" -// CHECK: device = "CORE_0" +// CHECK: _parallel_execution_ids = "r0:1", device = "CORE_0" // Tests devices are not remapped if device is not in replicate devices. @@ -69,9 +69,9 @@ func.func @no_override_device() { } // CHECK: "tf.opA" -// CHECK: device = "/TPU:2" +// CHECK: _parallel_execution_ids = "r0:0", device = "/TPU:2" // CHECK: "tf.opA" -// CHECK: device = "/TPU:2" +// CHECK: _parallel_execution_ids = "r0:1", device = "/TPU:2" // Tests devices are remapped if device is in replicate devices. 
@@ -94,9 +94,9 @@ func.func @remap_device() { } // CHECK: "tf.opA" -// CHECK: device = "/CPU:0" +// CHECK: _parallel_execution_ids = "r0:0", device = "/CPU:0" // CHECK: "tf.opA" -// CHECK: device = "/GPU:1" +// CHECK: _parallel_execution_ids = "r0:1", device = "/GPU:1" // Tests replicate with control dependency output has each expanded replica @@ -138,7 +138,7 @@ func.func @unused_replica(%arg0: tensor) { // CHECK: {{%.*}}, [[REPLICA_0_CONTROL:%.*]] = tf_executor.island // CHECK: [[REPLICA_1_OUTPUT:%.*]], {{%.*}} = tf_executor.island -// CHECK: tf_executor.fetch [[REPLICA_1_OUTPUT]], [[REPLICA_0_CONTROL]] +// CHECK: tf_executor.fetch [[REPLICA_1_OUTPUT]] : // Tests replicate results are remapped correctly. @@ -158,10 +158,19 @@ func.func @replicate_result(%arg0: tensor, %arg1: tensor) { func.return } -// CHECK: %[[REPLICA_0:.*]]:2, %{{.*}} = tf_executor.island -// CHECK: %[[REPLICA_1:.*]]:2, %{{.*}} = tf_executor.island -// CHECK: tf_executor.fetch %[[REPLICA_0]]#0, %[[REPLICA_1]]#0, %[[REPLICA_0]]#1, %[[REPLICA_1]]#1 - +// CHECK: %[[REPLICA_OPA_1:.*]], %{{.*}} = tf_executor.island wraps +// CHECK: "tf.opA"(%arg0) +// CHECK: _parallel_execution_ids = "r0:0" +// CHECK: %[[REPLICA_OPB_1:.*]], %{{.*}} = tf_executor.island wraps +// CHECK: "tf.opB"(%arg0) +// CHECK: _parallel_execution_ids = "r0:0" +// CHECK: %[[REPLICA_OPA_2:.*]], %{{.*}} = tf_executor.island wraps +// CHECK: "tf.opA"(%arg1) +// CHECK: _parallel_execution_ids = "r0:1" +// CHECK: %[[REPLICA_OPB_2:.*]], %{{.*}} = tf_executor.island wraps +// CHECK: "tf.opB"(%arg1) +// CHECK: _parallel_execution_ids = "r0:1" +// CHECK: tf_executor.fetch %[[REPLICA_OPA_1]], %[[REPLICA_OPA_2]], %[[REPLICA_OPB_1]], %[[REPLICA_OPB_2]] // Tests replicate results are remapped correctly with packed inputs. // CHECK-LABEL: func @replicate_with_packed_input @@ -181,14 +190,19 @@ func.func @replicate_with_packed_input(%arg0: tensor, %arg1: tensor) { func.return } -// CHECK: %[[REPLICA_0:.*]]:2, %{{.*}} = tf_executor.island +// CHECK: %[[REPLICA_OPA_1:.*]], %{{.*}} = tf_executor.island wraps // CHECK: "tf.opA"(%arg0) +// CHECK: _parallel_execution_ids = "r0:0" +// CHECK: %[[REPLICA_OPB_1:.*]], %{{.*}} = tf_executor.island wraps // CHECK: "tf.opB"(%arg1) -// CHECK: %[[REPLICA_1:.*]]:2, %{{.*}} = tf_executor.island +// CHECK: _parallel_execution_ids = "r0:0" +// CHECK: %[[REPLICA_OPA_2:.*]], %{{.*}} = tf_executor.island wraps // CHECK: "tf.opA"(%arg0) +// CHECK: _parallel_execution_ids = "r0:1" +// CHECK: %[[REPLICA_OPB_2:.*]], %{{.*}} = tf_executor.island wraps // CHECK: "tf.opB"(%arg1) -// CHECK: tf_executor.fetch %[[REPLICA_0]]#0, %[[REPLICA_1]]#0 - +// CHECK: _parallel_execution_ids = "r0:1" +// CHECK: tf_executor.fetch %[[REPLICA_OPA_1]], %[[REPLICA_OPA_2]], %[[REPLICA_OPB_1]], %[[REPLICA_OPB_2]] // Tests replica id is added correctly. 
// CHECK-LABEL: func @replica_id_attr_added @@ -209,24 +223,30 @@ func.func @replica_id_attr_added(%arg0: tensor, %arg1: tensor : tensor +// CHECK-SAME: _parallel_execution_ids = "r0:0", value = dense<1> : tensor // CHECK: tf_executor.yield [[CONST_0]] // CHECK: tf_executor.island // CHECK: [[CONST_1:%.+]] = "tf.Const" -// CHECK-SAME: value = dense<2> : tensor +// CHECK-SAME: _parallel_execution_ids = "r0:1", value = dense<2> : tensor // CHECK: tf_executor.yield [[CONST_1]] +// ----- +// Tests parallel_execute nested inside replicate +// CHECK-LABEL: func @nested_parallel_execute +func.func @nested_parallel_execute(%arg0: tensor, %arg1: tensor) { + %0:4 = tf_executor.graph { + %1:5 = tf_executor.island { + %2:4 = tf_device.replicate([%arg0, %arg1] as %arg2: tensor) {n = 2 : i32} { + %3:2 = "tf_device.parallel_execute"() ({ + %6 = "tf_device.launch"() ({ + %4 = "tf.OpA"(%arg0) : (tensor) -> tensor + tf_device.return %4: tensor + }) {device = "/TPU:1"} : () -> (tensor) + tf_device.return %6: tensor + }, { + %4 = "tf_device.launch"() ({ + %5 = "tf.OpB"(%arg1) : (tensor) -> (tensor) + tf_device.return %5: tensor + }) {device = "/TPU:2"} : () -> (tensor) + tf_device.return %4 : tensor + }) : () -> (tensor, tensor) + tf_device.return %3#0, %3#1 : tensor, tensor + } + tf_executor.yield %2#0, %2#1, %2#2, %2#3 : tensor, tensor, tensor, tensor + } + tf_executor.fetch %1#0, %1#1, %1#2, %1#3 : tensor, tensor, tensor, tensor + } + func.return +} + +// CHECK: tf_executor.island +// CHECK: tf_device.parallel_execute +// CHECK: tf_device.launch +// CHECK: tf.OpA +// CHECK: {device = "/TPU:1"} +// CHECK: tf_device.launch +// CHECK: tf.OpB +// CHECK: {device = "/TPU:2"} +// CHECK: _parallel_execution_ids = "r0:0" +// CHECK: tf_executor.island +// CHECK: tf_device.parallel_execute +// CHECK: tf_device.launch +// CHECK: tf.OpA +// CHECK: {device = "/TPU:1"} +// CHECK: tf_device.launch +// CHECK: tf.OpB +// CHECK: {device = "/TPU:2"} +// CHECK: _parallel_execution_ids = "r0:1" +// CHECK: tf_executor.fetch + +// ----- +// Tests parallel_group attr can merge correctly. +// CHECK-LABEL: func @merge_of_parallel_group_attr +func.func @merge_of_parallel_group_attr() { + tf_executor.graph { + %0 = tf_executor.island { + tf_device.replicate {n = 2 : i32, devices = {CORE_0 = ["/CPU:0", "/GPU:1"]}, _parallel_execution_ids = "r4:5"} { + "tf_device.launch"() ({ + "tf.opA"() : () -> () + tf_device.return + }) {device = "CORE_0"} : () -> () + tf_device.return + } + tf_executor.yield + } + tf_executor.fetch + } + func.return +} + +// CHECK: "tf.opA" +// CHECK: _parallel_execution_ids = "r4:5,r0:0", device = "/CPU:0" +// CHECK: "tf.opA" +// CHECK: _parallel_execution_ids = "r4:5,r0:1", device = "/GPU:1" + // ----- // Tests tf._TPUDeviceOrdinalPlaceholder cannot be updated when device ordinal diff --git a/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island_legacy.mlir b/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island_legacy.mlir new file mode 100644 index 00000000000..91ac4a2e76d --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island_legacy.mlir @@ -0,0 +1,277 @@ +// RUN: tf-opt -split-input-file -verify-diagnostics %s -tf-replicate-to-island=legacy-graph-export=true | FileCheck %s + +// Tests per replica island has same control operands as island holding +// replicate. 
+// CHECK-LABEL: func @controls_per_replica +func.func @controls_per_replica() { + tf_executor.graph { + %1 = tf_executor.ControlTrigger {} + %2 = tf_executor.ControlTrigger {} + %3 = tf_executor.island(%1, %2) { + tf_device.replicate {n = 2 : i32} { + tf_device.return + } + tf_executor.yield + } + tf_executor.fetch + } + func.return +} + +// CHECK: %[[CT_0:.*]] = tf_executor.ControlTrigger +// CHECK: %[[CT_1:.*]] = tf_executor.ControlTrigger +// CHECK: %{{.*}} = tf_executor.island(%[[CT_0]], %[[CT_1]]) +// CHECK: %{{.*}} = tf_executor.island(%[[CT_0]], %[[CT_1]]) + + +// Tests devices are not remapped if no devices were defined in replicate. +// CHECK-LABEL: func @no_devices +func.func @no_devices() { + tf_executor.graph { + %0 = tf_executor.island { + tf_device.replicate {n = 2 : i32} { + "tf_device.launch"() ({ + "tf.opA"() : () -> () + tf_device.return + }) {device = "CORE_0"} : () -> () + tf_device.return + } + tf_executor.yield + } + tf_executor.fetch + } + func.return +} + +// CHECK: "tf.opA" +// CHECK: device = "CORE_0" +// CHECK: "tf.opA" +// CHECK: device = "CORE_0" + + +// Tests devices are not remapped if device is not in replicate devices. +// CHECK-LABEL: func @no_override_device +func.func @no_override_device() { + tf_executor.graph { + %0 = tf_executor.island { + tf_device.replicate {n = 2 : i32, devices = {CORE_0 = ["/CPU:0", "/GPU:1"]}} { + "tf_device.launch"() ({ + "tf.opA"() : () -> () + tf_device.return + }) {device = "/TPU:2"} : () -> () + tf_device.return + } + tf_executor.yield + } + tf_executor.fetch + } + func.return +} + +// CHECK: "tf.opA" +// CHECK: device = "/TPU:2" +// CHECK: "tf.opA" +// CHECK: device = "/TPU:2" + + +// Tests devices are remapped if device is in replicate devices. +// CHECK-LABEL: func @remap_device +func.func @remap_device() { + tf_executor.graph { + %0 = tf_executor.island { + tf_device.replicate {n = 2 : i32, devices = {CORE_0 = ["/CPU:0", "/GPU:1"]}} { + "tf_device.launch"() ({ + "tf.opA"() : () -> () + tf_device.return + }) {device = "CORE_0"} : () -> () + tf_device.return + } + tf_executor.yield + } + tf_executor.fetch + } + func.return +} + +// CHECK: "tf.opA" +// CHECK: device = "/CPU:0" +// CHECK: "tf.opA" +// CHECK: device = "/GPU:1" + + +// Tests replicate with control dependency output has each expanded replica +// control pinned to a sink island. +// CHECK-LABEL: func @replicate_control +func.func @replicate_control() { + tf_executor.graph { + %1 = tf_executor.island { + tf_device.replicate {n = 2 : i32} { + tf_device.return + } + tf_executor.yield + } + tf_executor.fetch %1 : !tf_executor.control + } + func.return +} + +// CHECK: %[[REPLICA_0:.*]] = tf_executor.island +// CHECK: %[[REPLICA_1:.*]] = tf_executor.island +// CHECK: %[[SINK:.*]] = tf_executor.island(%[[REPLICA_0]], %[[REPLICA_1]]) +// CHECK: tf_executor.fetch %[[SINK]] + + +// Tests unused replica are pinned to the graph fetch. +// CHECK-LABEL: func @unused_replica +func.func @unused_replica(%arg0: tensor) { + %0 = tf_executor.graph { + %1:3 = tf_executor.island { + %2:2 = tf_device.replicate([%arg0, %arg0] as %ri0: tensor) {n = 2 : i32} { + tf_device.return %ri0 : tensor + } + tf_executor.yield %2#0, %2#1 : tensor, tensor + } + tf_executor.fetch %1#1 : tensor + } + func.return +} + +// CHECK: {{%.*}}, [[REPLICA_0_CONTROL:%.*]] = tf_executor.island +// CHECK: [[REPLICA_1_OUTPUT:%.*]], {{%.*}} = tf_executor.island +// CHECK: tf_executor.fetch [[REPLICA_1_OUTPUT]], [[REPLICA_0_CONTROL]] + + +// Tests replicate results are remapped correctly. 
+// CHECK-LABEL: func @replicate_result +func.func @replicate_result(%arg0: tensor, %arg1: tensor) { + %0:4 = tf_executor.graph { + %1:5 = tf_executor.island { + %2:4 = tf_device.replicate([%arg0, %arg1] as %arg2: tensor) {n = 2 : i32} { + %3 = "tf.opA"(%arg2) : (tensor) -> tensor + %4 = "tf.opB"(%arg2) : (tensor) -> tensor + tf_device.return %3, %4 : tensor, tensor + } + tf_executor.yield %2#0, %2#1, %2#2, %2#3 : tensor, tensor, tensor, tensor + } + tf_executor.fetch %1#0, %1#1, %1#2, %1#3 : tensor, tensor, tensor, tensor + } + func.return +} + +// CHECK: %[[REPLICA_0:.*]]:2, %{{.*}} = tf_executor.island +// CHECK: %[[REPLICA_1:.*]]:2, %{{.*}} = tf_executor.island +// CHECK: tf_executor.fetch %[[REPLICA_0]]#0, %[[REPLICA_1]]#0, %[[REPLICA_0]]#1, %[[REPLICA_1]]#1 + + +// Tests replicate results are remapped correctly with packed inputs. +// CHECK-LABEL: func @replicate_with_packed_input +func.func @replicate_with_packed_input(%arg0: tensor, %arg1: tensor) { + %0:4 = tf_executor.graph { + %1:5 = tf_executor.island { + %2:4 = tf_device.replicate(%arg0 as %arg2: tensor, %arg1 as %arg3: tensor) + {n = 2 : i32, _packed_input_indices = [0, 1]} { + %3 = "tf.opA"(%arg2) : (tensor) -> tensor + %4 = "tf.opB"(%arg3) : (tensor) -> tensor + tf_device.return %3, %4 : tensor, tensor + } + tf_executor.yield %2#0, %2#1, %2#2, %2#3 : tensor, tensor, tensor, tensor + } + tf_executor.fetch %1#0, %1#1, %1#2, %1#3 : tensor, tensor, tensor, tensor + } + func.return +} + +// CHECK: %[[REPLICA_0:.*]]:2, %{{.*}} = tf_executor.island +// CHECK: "tf.opA"(%arg0) +// CHECK: "tf.opB"(%arg1) +// CHECK: %[[REPLICA_1:.*]]:2, %{{.*}} = tf_executor.island +// CHECK: "tf.opA"(%arg0) +// CHECK: "tf.opB"(%arg1) +// CHECK: tf_executor.fetch %[[REPLICA_0]]#0, %[[REPLICA_1]]#0 + + +// Tests replica id is added correctly. +// CHECK-LABEL: func @replica_id_attr_added +func.func @replica_id_attr_added(%arg0: tensor, %arg1: tensor) { + tf_executor.graph { + %0 = tf_executor.island { + tf_device.replicate([%arg0, %arg1] as %arg2: tensor) {n = 2 : i32} { + "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg2){table_ids = [1, 2]} : (tensor) -> () + "tf.EnqueueTPUEmbeddingRaggedTensorBatch"(%arg2){table_ids = [1, 2]} : (tensor) -> () + "tf.EnqueueTPUEmbeddingArbitraryTensorBatch"(%arg2){table_ids = [1, 2]} : (tensor) -> () + "tf.A"(%arg2) : (tensor) -> () + tf_device.return + } + tf_executor.yield + } + tf_executor.fetch + } + func.return +} + +// CHECK: tf_executor.island +// CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch" +// CHECK-SAME: _xla_replica_id = 0 +// CHECK: "tf.EnqueueTPUEmbeddingRaggedTensorBatch" +// CHECK-SAME: _xla_replica_id = 0 +// CHECK: "tf.EnqueueTPUEmbeddingArbitraryTensorBatch" +// CHECK-SAME: _xla_replica_id = 0 +// CHECK: "tf.A" +// CHECK-NOT: _xla_replica_id +// CHECK: tf_executor.island +// CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch" +// CHECK-SAME: _xla_replica_id = 1 +// CHECK: "tf.EnqueueTPUEmbeddingRaggedTensorBatch" +// CHECK-SAME: _xla_replica_id = 1 +// CHECK: "tf.EnqueueTPUEmbeddingArbitraryTensorBatch" +// CHECK-SAME: _xla_replica_id = 1 +// CHECK: "tf.A" +// CHECK-NOT: _xla_replica_id +// CHECK: tf_executor.fetch + + +// Tests tf._TPUDeviceOrdinalPlaceholder ops are replaced with explicit device +// ordinal constant values based on the first TPU core device id. 
+// CHECK-LABEL: func @device_ordinals +func.func @device_ordinals() { + tf_executor.graph { + %0:3 = tf_executor.island { + %1:2 = tf_device.replicate {n = 2 : i32, devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:1", "/job:worker/replica:0/task:0/device:TPU:2"]}} { + %2 = "tf._TPUDeviceOrdinalPlaceholder"() : () -> tensor + tf_device.return %2 : tensor + } + tf_executor.yield %1#0, %1#1 : tensor, tensor + } + tf_executor.fetch + } + func.return +} + +// CHECK: tf_executor.island +// CHECK: [[CONST_0:%.+]] = "tf.Const" +// CHECK-SAME: value = dense<1> : tensor +// CHECK: tf_executor.yield [[CONST_0]] +// CHECK: tf_executor.island +// CHECK: [[CONST_1:%.+]] = "tf.Const" +// CHECK-SAME: value = dense<2> : tensor +// CHECK: tf_executor.yield [[CONST_1]] + +// ----- + +// Tests tf._TPUDeviceOrdinalPlaceholder cannot be updated when device ordinal +// is missing. + +func.func @missing_device_ordinals() { + tf_executor.graph { + %0:3 = tf_executor.island { + %1:2 = tf_device.replicate {n = 2 : i32, devices = {TPU_REPLICATED_CORE_1 = ["/job:worker/replica:0/task:0/device:TPU:1", "/job:worker/replica:0/task:0/device:TPU:2"]}} { + // expected-error@below {{requires device ordinal from device TPU_REPLICATED_CORE_0 to be present in 'tf.device.replicate' op}} + %2 = "tf._TPUDeviceOrdinalPlaceholder"() : () -> tensor + tf_device.return %2 : tensor + } + tf_executor.yield %1#0, %1#1 : tensor, tensor + } + tf_executor.fetch + } + func.return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir index 2c816bd4145..0f0fea844e7 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir @@ -145,6 +145,51 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr func.return %0 : tensor<*xf32> } + // CHECK-LABEL: func @shape_from_case_to_branch_functions_to_results + // CHECK-SAME: (%arg0: tensor, %arg1: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> + func.func @shape_from_case_to_branch_functions_to_results(%arg0: tensor, %arg1: tensor<1x2x3xf32>) -> tensor<*xf32> { + %0 = "tf.Case"(%arg0, %arg1) {branches = [@case_branch0, @case_branch1], is_stateless = true} : (tensor, tensor<1x2x3xf32>) -> tensor<*xf32> + func.return %0 : tensor<*xf32> + } + + // CHECK-LABEL: func @case_branch0 + // CHECK-SAME: (%arg0: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> + func.func @case_branch0(%arg0: tensor<*xf32>) -> tensor<*xf32> { + // CHECK: return + // CHECK-SAME: tensor<1x2x3xf32> + func.return %arg0 : tensor<*xf32> + } + + // CHECK-LABEL: func @case_branch1 + // CHECK-SAME: (%arg0: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> + func.func @case_branch1(%arg0: tensor<*xf32>) -> tensor<*xf32> { + // CHECK: "tf.Identity"(%arg0) : (tensor<1x2x3xf32>) -> tensor<1x2x3xf32> + %0 = "tf.Identity"(%arg0) : (tensor<*xf32>) -> (tensor<*xf32>) + // CHECK: return + // CHECK-SAME: tensor<1x2x3xf32> + func.return %0 : tensor<*xf32> + } + + // CHECK-LABEL: shape_from_case_to_region_bodies_to_output + // CHECK-SAME: -> tensor<1x2x3xf32> + func.func @shape_from_case_to_region_bodies_to_output(%arg0: tensor, %arg1: tensor<1x2x3xf32>) -> tensor<*xf32> { + %unshaped = "tf.Cast"(%arg1) : (tensor<1x2x3xf32>) -> tensor<*xf32> + %0 = "tf.CaseRegion"(%arg0) ({ + // CHECK: "tf.Add"{{.+}}(tensor<1x2x3xf32>, tensor<1x2x3xf32>) -> tensor<1x2x3xf32> + // CHECK: "tf.Yield"{{.+}}(tensor<1x2x3xf32>) -> () + %1 = "tf.Add"(%unshaped, %unshaped) : 
(tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + "tf.Yield"(%1) : (tensor<*xf32>) -> () + }, { + // CHECK: "tf.Sub"{{.+}}(tensor<1x2x3xf32>, tensor<1x2x3xf32>) -> tensor<1x2x3xf32> + // CHECK: "tf.Yield"{{.+}}(tensor<1x2x3xf32>) -> () + %2 = "tf.Sub"(%unshaped, %unshaped) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + "tf.Yield"(%2) : (tensor<*xf32>) -> () + // CHECK: {is_stateless = true} : (tensor) -> tensor<1x2x3xf32> + }) {is_stateless = true} : (tensor) -> tensor<*xf32> + // CHECK: return {{.*}} : tensor<1x2x3xf32> + func.return %0 : tensor<*xf32> + } + // CHECK-LABEL: func @shape_from_while_to_cond_body_functions func.func @shape_from_while_to_cond_body_functions(%arg0: tensor<4xf32>, %arg1: tensor>>, %arg2: tensor>>) -> tensor<4xf32> { // CHECK: "tf.While" @@ -751,6 +796,34 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr func.return } + // CHECK-LABEL: func @shape_partial_eval + func.func @shape_partial_eval(%arg0: tensor<1x?x7xf32>, %arg1: tensor<3x7xf32>) -> tensor<*xf32> { + %0 = "tf.Shape"(%arg0) : (tensor<1x?x7xf32>) -> tensor<3xi32> + + // CHECK: tf.Reshape + // CHECK-SAME: tensor<1x3x7xf32> + %1 = "tf.Reshape"(%arg1, %0) : (tensor<3x7xf32>, tensor<3xi32>) -> tensor<*xf32> + return %1 : tensor<*xf32> + } + + // CHECK-LABEL: func @gather_concat_reshape + func.func @gather_concat_reshape(%arg0: tensor, %arg1: tensor) -> tensor<*xf32> { + %0 = "tf.Shape"(%arg0) : (tensor) -> tensor<3xi32> + + %indices = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> tensor<2xi32> + %axis = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor + %1 = "tf.GatherV2"(%0, %indices, %axis) {batch_dims = 0 : i64, device = ""} : (tensor<3xi32>, tensor<2xi32>, tensor) -> tensor<2xi32> + + %last_dim = "tf.Const"() {value = dense<7> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.ConcatV2"(%1, %last_dim, %axis) {device = ""} : (tensor<2xi32>, tensor<1xi32>, tensor) -> tensor<3xi32> + + + // CHECK: tf.Reshape + // CHECK-SAME: tensor + %3 = "tf.Reshape"(%arg1, %2) : (tensor, tensor<3xi32>) -> tensor<*xf32> + return %3 : tensor<*xf32> + } + // CHECK-LABEL: const_fold func.func @const_fold() -> () { // CHECK: tf.Const diff --git a/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir b/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir index be22c2f0552..b489dd04e73 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/side-effect-analysis-test.mlir @@ -2201,8 +2201,8 @@ func.func @nontrivial_multi_op_islands( // ----- -// Tests that we create dependencies between `CollectiveReduceV2` ops -// (TF_CollectiveReduceOrderingEffect). +// Tests that we create dependencies between `CollectiveReduceV2` ops if no +// ordering tokens are present (TF_CollectiveReduceOrderingEffect). func.func @collective_reduce_ordering_effect( // expected-remark@above {{ID: 7}} %input: tensor, @@ -2236,6 +2236,114 @@ func.func @collective_reduce_ordering_effect( // ----- +// Tests that we don't create dependencies between `CollectiveReduceV2` ops if +// ordering tokens are present and independent. 
+func.func @collective_reduce_independent_ordering_tokens( + // expected-remark@above {{ID: 7}} + %arg0: tensor<*x!tf_type.resource>> {tf._resource_arg_unique_id = 0 : i64}, + %arg1: tensor<*x!tf_type.resource>> {tf._resource_arg_unique_id = 1 : i64}, + %input: tensor, + %group_key: tensor, + %group_size: tensor, + %instance_key: tensor) { + tf_executor.graph { + // expected-remark@above {{ID: 5}} + %island = tf_executor.island { + // expected-remark@above {{ID: 3}} + // expected-remark@above {{Successors: {4}}} + %0 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key, %arg0) {merge_op = "Add", final_op = "Id"} : (tensor, tensor, tensor, tensor, tensor<*x!tf_type.resource>>) -> tensor + // expected-remark@above {{ID: 0}} + // expected-remark@above {{Successors: {2}}} + %1 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key, %arg1) {merge_op = "Mul", final_op = "Id"} : (tensor, tensor, tensor, tensor, tensor<*x!tf_type.resource>>) -> tensor + // expected-remark@above {{ID: 1}} + // expected-remark@above {{Successors: {2}}} + tf_executor.yield + // expected-remark@above {{ID: 2}} + // expected-remark@above {{Predecessors: {0,1}}} + } + tf_executor.fetch %island : !tf_executor.control + // expected-remark@above {{ID: 4}} + // expected-remark@above {{Predecessors: {3}}} + } + func.return + // expected-remark@above {{ID: 6}} + // expected-remark@above {{Sinks: {5}}} +} + +// ----- + +// Tests that we do create dependencies between `CollectiveReduceV2` ops if +// ordering tokens are present and dependent. +func.func @collective_reduce_dependent_ordering_tokens( + // expected-remark@above {{ID: 7}} + %arg0: tensor<*x!tf_type.resource>> {tf._resource_arg_unique_id = 0 : i64}, + %arg1: tensor<*x!tf_type.resource>> {tf._resource_arg_unique_id = 0 : i64}, + %input: tensor, + %group_key: tensor, + %group_size: tensor, + %instance_key: tensor) { + tf_executor.graph { + // expected-remark@above {{ID: 5}} + %island = tf_executor.island { + // expected-remark@above {{ID: 3}} + // expected-remark@above {{Successors: {4}}} + %0 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key, %arg0) {merge_op = "Add", final_op = "Id"} : (tensor, tensor, tensor, tensor, tensor<*x!tf_type.resource>>) -> tensor + // expected-remark@above {{ID: 0}} + // expected-remark@above {{Successors: {1}}} + %1 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key, %arg1) {merge_op = "Mul", final_op = "Id"} : (tensor, tensor, tensor, tensor, tensor<*x!tf_type.resource>>) -> tensor + // expected-remark@above {{ID: 1}} + // expected-remark@above {{Successors: {2}}} + // expected-remark@above {{Predecessors: {0}}} + tf_executor.yield + // expected-remark@above {{ID: 2}} + // expected-remark@above {{Predecessors: {1}}} + } + tf_executor.fetch %island : !tf_executor.control + // expected-remark@above {{ID: 4}} + // expected-remark@above {{Predecessors: {3}}} + } + func.return + // expected-remark@above {{ID: 6}} + // expected-remark@above {{Sinks: {5}}} +} + +// ----- + +// Tests that we don't create dependencies between `CollectiveReduceV2` ops if +// one has ordering tokens and the other one doesn't. 
+func.func @collective_reduce_dependent_ordering_tokens( + // expected-remark@above {{ID: 7}} + %arg0: tensor<*x!tf_type.resource>> {tf._resource_arg_unique_id = 0 : i64}, + %input: tensor, + %group_key: tensor, + %group_size: tensor, + %instance_key: tensor) { + tf_executor.graph { + // expected-remark@above {{ID: 5}} + %island = tf_executor.island { + // expected-remark@above {{ID: 3}} + // expected-remark@above {{Successors: {4}}} + %0 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key, %arg0) {merge_op = "Add", final_op = "Id"} : (tensor, tensor, tensor, tensor, tensor<*x!tf_type.resource>>) -> tensor + // expected-remark@above {{ID: 0}} + // expected-remark@above {{Successors: {2}}} + %1 = "tf.CollectiveReduceV2"(%input, %group_size, %group_key, %instance_key) {merge_op = "Mul", final_op = "Id"} : (tensor, tensor, tensor, tensor) -> tensor + // expected-remark@above {{ID: 1}} + // expected-remark@above {{Successors: {2}}} + tf_executor.yield + // expected-remark@above {{ID: 2}} + // expected-remark@above {{Predecessors: {0,1}}} + } + tf_executor.fetch %island : !tf_executor.control + // expected-remark@above {{ID: 4}} + // expected-remark@above {{Predecessors: {3}}} + } + func.return + // expected-remark@above {{ID: 6}} + // expected-remark@above {{Sinks: {5}}} +} + +// ----- + // Tests that we don't create dependencies between device launch ops with // multiple stateless ops each. func.func @multi_stateless_op_launches() { @@ -2385,3 +2493,138 @@ func.func @self_dependent_only_feeds_into_fetch( // expected-remark@above {{ID: 8}} // expected-remark@above {{Sinks: {7}}} } + +// ----- + +// Tests that we create dependencies between `NcclAllReduce` ops on same device. +func.func @collective_reduce_ordering_effect( + // expected-remark@above {{ID: 7}} + %input: tensor) { + tf_executor.graph { + // expected-remark@above {{ID: 5}} + %island = tf_executor.island { + // expected-remark@above {{ID: 3}} + // expected-remark@above {{Successors: {4}}} + %0 = "tf.NcclAllReduce"(%input) { reduction = "min", num_devices = 2, shared_name = "name", device = "CPU:0"} : (tensor) -> tensor + // expected-remark@above {{ID: 0}} + // expected-remark@above {{Successors: {1}}} + %1 = "tf.NcclAllReduce"(%input) { reduction = "min", num_devices = 2, shared_name = "name", device = "CPU:0"} : (tensor) -> tensor + // expected-remark@above {{ID: 1}} + // expected-remark@above {{Predecessors: {0}}} + // expected-remark@above {{Successors: {2}}} + tf_executor.yield + // expected-remark@above {{ID: 2}} + // expected-remark@above {{Predecessors: {1}}} + } + tf_executor.fetch %island : !tf_executor.control + // expected-remark@above {{ID: 4}} + // expected-remark@above {{Predecessors: {3}}} + } + func.return + // expected-remark@above {{ID: 6}} + // expected-remark@above {{Sinks: {5}}} +} + +// ----- + +// Tests that we create dependencies between `NcclAllReduce` ops for unspecified +// devices. 
+func.func @collective_reduce_ordering_effect(
+  // expected-remark@above {{ID: 7}}
+  %input: tensor<f32>) {
+  tf_executor.graph {
+    // expected-remark@above {{ID: 5}}
+    %island = tf_executor.island {
+      // expected-remark@above {{ID: 3}}
+      // expected-remark@above {{Successors: {4}}}
+      %0 = "tf.NcclAllReduce"(%input) { reduction = "min", num_devices = 2, shared_name = "name"} : (tensor<f32>) -> tensor<f32>
+      // expected-remark@above {{ID: 0}}
+      // expected-remark@above {{Successors: {1}}}
+      %1 = "tf.NcclAllReduce"(%input) { reduction = "min", num_devices = 2, shared_name = "name"} : (tensor<f32>) -> tensor<f32>
+      // expected-remark@above {{ID: 1}}
+      // expected-remark@above {{Predecessors: {0}}}
+      // expected-remark@above {{Successors: {2}}}
+      tf_executor.yield
+      // expected-remark@above {{ID: 2}}
+      // expected-remark@above {{Predecessors: {1}}}
+    }
+    tf_executor.fetch %island : !tf_executor.control
+    // expected-remark@above {{ID: 4}}
+    // expected-remark@above {{Predecessors: {3}}}
+  }
+  func.return
+  // expected-remark@above {{ID: 6}}
+  // expected-remark@above {{Sinks: {5}}}
+}
+
+// -----
+
+// Tests that we don't create dependencies between `NcclAllReduce` ops on
+// different devices.
+func.func @collective_reduce_ordering_effect(
+  // expected-remark@above {{ID: 7}}
+  %input: tensor<f32>) {
+  tf_executor.graph {
+    // expected-remark@above {{ID: 5}}
+    %island = tf_executor.island {
+      // expected-remark@above {{ID: 3}}
+      // expected-remark@above {{Successors: {4}}}
+      %0 = "tf.NcclAllReduce"(%input) { reduction = "min", num_devices = 2, shared_name = "name", device = "CPU:0"} : (tensor<f32>) -> tensor<f32>
+      // expected-remark@above {{ID: 0}}
+      // expected-remark@above {{Successors: {2}}}
+      %1 = "tf.NcclAllReduce"(%input) { reduction = "min", num_devices = 2, shared_name = "name", device = "CPU:1"} : (tensor<f32>) -> tensor<f32>
+      // expected-remark@above {{ID: 1}}
+      // expected-remark@above {{Successors: {2}}}
+      tf_executor.yield
+      // expected-remark@above {{ID: 2}}
+      // expected-remark@above {{Predecessors: {0,1}}}
+    }
+    tf_executor.fetch %island : !tf_executor.control
+    // expected-remark@above {{ID: 4}}
+    // expected-remark@above {{Predecessors: {3}}}
+  }
+  func.return
+  // expected-remark@above {{ID: 6}}
+  // expected-remark@above {{Sinks: {5}}}
+}
+
+// -----
+
+// Tests that we create dependencies to a fetch op, even if the fetch op has
+// known effects (in this case due to resource operand) and the resources of
+// other side-effecting ops are independent.
+func.func @fetch_with_resource_operand( + // expected-remark@above {{ID: 9}} + %arg0: tensor, + %arg1: tensor<*x!tf_type.resource>>, + %arg2: tensor<*x!tf_type.resource>>) { + tf_executor.graph { + // expected-remark@above {{ID: 7}} + %island1 = tf_executor.island { + // expected-remark@above {{ID: 2}} + // expected-remark@above {{Successors: {6}}} + "tf.EnqueueTPUEmbeddingRaggedTensorBatch"(%arg0) {table_ids = [1, 2], device_ordinal = 1} : (tensor) -> () + // expected-remark@above {{ID: 0}} + // expected-remark@above {{Successors: {1}}} + tf_executor.yield + // expected-remark@above {{ID: 1}} + // expected-remark@above {{Predecessors: {0}}} + } + %island2 = tf_executor.island { + // expected-remark@above {{ID: 5}} + // expected-remark@above {{Successors: {6}}} + %read = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf_type.resource>>) -> tensor<32xf32> + // expected-remark@above {{ID: 3}} + // expected-remark@above {{Successors: {4}}} + tf_executor.yield + // expected-remark@above {{ID: 4}} + // expected-remark@above {{Predecessors: {3}}} + } + tf_executor.fetch %arg2, %island1, %island2 : tensor<*x!tf_type.resource>>, !tf_executor.control, !tf_executor.control + // expected-remark@above {{ID: 6}} + // expected-remark@above {{Predecessors: {2,5}}} + } + func.return + // expected-remark@above {{ID: 8}} + // expected-remark@above {{Sinks: {7}}} +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/split_into_island_per_op.mlir b/tensorflow/compiler/mlir/tensorflow/tests/split_into_island_per_op.mlir index 173ade30390..771a55cb7a9 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/split_into_island_per_op.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/split_into_island_per_op.mlir @@ -121,6 +121,26 @@ func.func @dangling_print(%arg0: tensor<*xi32>, %arg1: tensor) -> (tensor<* // CHECK: } // CHECK: return %[[GRAPH]]#0, %[[GRAPH]]#1 +func.func @drop_fetch_control_dep(%arg0: tensor<*xi32>, %arg1: tensor) -> (tensor<*xi32>, tensor<*xi32>) { + %graph:2 = tf_executor.graph { + %island1:3 = tf_executor.island { + %add1 = "tf.Add"(%arg0, %arg1) : (tensor<*xi32>, tensor) -> tensor<*xi32> + %add2 = "tf.Add"(%add1, %arg1) : (tensor<*xi32>, tensor) -> tensor<*xi32> + tf_executor.yield %add1, %add2 : tensor<*xi32>, tensor<*xi32> + } + tf_executor.fetch %island1#0, %island1#1, %island1#2 : tensor<*xi32>, tensor<*xi32>, !tf_executor.control + } + func.return %graph#0, %graph#1 : tensor<*xi32>, tensor<*xi32> +} + +// CHECK-LABEL: func @drop_fetch_control_dep +// CHECK: %[[GRAPH:.*]]:2 = tf_executor.graph { +// CHECK: %[[ADD1:.*]], %[[ADD1_control:.*]] = tf_executor.island wraps "tf.Add"(%arg0, %arg1) +// CHECK: %[[ADD2:.*]], %[[ADD2_control:.*]] = tf_executor.island wraps "tf.Add"(%[[ADD1]], %arg1) +// CHECK: tf_executor.fetch %[[ADD1]], %[[ADD2]] : +// CHECK: } +// CHECK: return %[[GRAPH]]#0, %[[GRAPH]]#1 + func.func @fetching_arg(%arg0: tensor<*xi32>) { tf_executor.graph { %island:3 = tf_executor.island { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir index 8b23c7e678a..5c6e0efacfb 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir @@ -1639,7 +1639,7 @@ func.func @invalidSelect(%arg0: tensor<1x8xi1>, %arg1: tensor<1x8x8xi32>, %arg2: //===--------------------------------------------------------------------===// // Test valid tf.SelectV2 -// CHfaECK-LABEL: func @selectV2BroadcastThen +// CHECK-LABEL: func @selectV2BroadcastThen func.func 
@selectV2BroadcastThen(%arg0: tensor, %arg1: tensor<8x1xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> { %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor, tensor<8x1xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32> func.return %0: tensor<2x8x8xi32> @@ -2561,6 +2561,26 @@ func.func @testInvalidToBool(%arg0: tensor) -> tensor<1xi1> { // ----- +// Test invalid tf.TPUPartitionedInputV2 with packing +func.func @testPackedTPUPartitionedInputV2(tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<4x4xf32> { +^bb0(%arg0: tensor<2x4xf32>, %arg1: tensor<2x4xf32>): + // expected-error @+1 {{expected 1 inputs, got 2}} + %0 = "tf.TPUPartitionedInputV2"(%arg0, %arg1) {partition_dims = [2, 1], is_packed = true} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<4x4xf32> + func.return %0 : tensor<4x4xf32> +} + +// ----- + +// Test invalid tf.TPUPartitionedInputV2 without packing +func.func @testUnpackedTPUPartitionedInputV2(tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<4x4xf32> { +^bb0(%arg0: tensor<2x4xf32>, %arg1: tensor<2x4xf32>): + // expected-error @+1 {{expected 2 inputs, got 1}} + %0 = "tf.TPUPartitionedInputV2"(%arg0) {partition_dims = [2, 1], is_packed = false} : (tensor<2x4xf32>) -> tensor<4x4xf32> + func.return %0 : tensor<4x4xf32> +} + +// ----- + // Test valid tf.Transpose // CHECK-LABEL: testTranspose func.func @testTranspose(tensor<2x3xf32>) -> tensor<3x2xf32> { @@ -4679,3 +4699,334 @@ func.func @testSetStaticDimensionBounds(%arg0: tensor, %arg1: tensor<4x %dyn_arg0 = "tf.SetStaticDimensionBounds" (%arg0, %arg1) :(tensor, tensor<4xi32>) -> tensor func.return %dyn_arg0 : tensor } + +// ----- + +func.func @testUniformQuantizedDotHybrid(%lhs: tensor<*xf32>, %rhs: tensor<2x2x!tf_type.qint8>, %rhs_scales: tensor<2xf32>, %rhs_zps: tensor) -> tensor<*xf32> { + // expected-error @below {{'tf.UniformQuantizedDotHybrid' op quantization_axis is -1, scales must have 0 rank.}} + %0 = "tf.UniformQuantizedDotHybrid"(%lhs, %rhs, %rhs_scales, %rhs_zps) { + rhs_quantization_axis = -1 : i64, rhs_quantization_min_val = -128 : i64, rhs_quantization_max_val = 127 : i64 + } : (tensor<*xf32>, tensor<2x2x!tf_type.qint8>, tensor<2xf32>, tensor) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +// ----- + +func.func @testUniformQuantizedDotHybrid(%lhs: tensor<*xf32>, %rhs: tensor<2x2x!tf_type.qint8>, %rhs_scales: tensor, %rhs_zps: tensor<2xi32>) -> tensor<*xf32> { + // expected-error @below {{'tf.UniformQuantizedDotHybrid' op quantization_axis is -1, zero_points must have 0 rank.}} + %0 = "tf.UniformQuantizedDotHybrid"(%lhs, %rhs, %rhs_scales, %rhs_zps) { + rhs_quantization_axis = -1 : i64, rhs_quantization_min_val = -128 : i64, rhs_quantization_max_val = 127 : i64 + } : (tensor<*xf32>, tensor<2x2x!tf_type.qint8>, tensor, tensor<2xi32>) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +// ----- + +func.func @testUniformQuantizedDotHybrid(%lhs: tensor<*xf32>, %rhs: tensor<2x2x!tf_type.qint8>, %rhs_scales: tensor<2x2xf32>, %rhs_zps: tensor<2xi32>) -> tensor<*xf32> { + // expected-error @below {{'tf.UniformQuantizedDotHybrid' op quantization_axis is not -1, scales must have 1 rank.}} + %0 = "tf.UniformQuantizedDotHybrid"(%lhs, %rhs, %rhs_scales, %rhs_zps) { + rhs_quantization_axis = 0 : i64, rhs_quantization_min_val = -128 : i64, rhs_quantization_max_val = 127 : i64 + } : (tensor<*xf32>, tensor<2x2x!tf_type.qint8>, tensor<2x2xf32>, tensor<2xi32>) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +// ----- + +func.func @testUniformQuantizedDotHybrid(%lhs: tensor<*xf32>, %rhs: tensor<2x2x!tf_type.qint8>, 
%rhs_scales: tensor<2xf32>, %rhs_zps: tensor<2x2xi32>) -> tensor<*xf32> { + // expected-error @below {{'tf.UniformQuantizedDotHybrid' op quantization_axis is not -1, zero_points must have 1 rank.}} + %0 = "tf.UniformQuantizedDotHybrid"(%lhs, %rhs, %rhs_scales, %rhs_zps) { + rhs_quantization_axis = 0 : i64, rhs_quantization_min_val = -128 : i64, rhs_quantization_max_val = 127 : i64 + } : (tensor<*xf32>, tensor<2x2x!tf_type.qint8>, tensor<2xf32>, tensor<2x2xi32>) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +// ----- + +func.func @testUniformQuantizedDotHybrid(%lhs: tensor<*xf32>, %rhs: tensor<2x2x!tf_type.qint8>, %rhs_scales: tensor<2xf32>, %rhs_zps: tensor<3xi32>) -> tensor<*xf32> { + // expected-error @below {{'tf.UniformQuantizedDotHybrid' op scales and zero points must have same number of elements.}} + %0 = "tf.UniformQuantizedDotHybrid"(%lhs, %rhs, %rhs_scales, %rhs_zps) { + rhs_quantization_axis = 0 : i64, rhs_quantization_min_val = -128 : i64, rhs_quantization_max_val = 127 : i64 + } : (tensor<*xf32>, tensor<2x2x!tf_type.qint8>, tensor<2xf32>, tensor<3xi32>) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +// ----- + +func.func @testUniformQuantizedConvolutionHybrid(%lhs: tensor<*xf32>, %rhs: tensor<2x2x!tf_type.qint8>, %rhs_scales: tensor<2xf32>, %rhs_zps: tensor) -> tensor<*xf32> { + // expected-error @below {{'tf.UniformQuantizedConvolutionHybrid' op quantization_axis is -1, scales must have 0 rank.}} + %0 = "tf.UniformQuantizedConvolutionHybrid"(%lhs, %rhs, %rhs_scales, %rhs_zps) { + window_strides = [1, 2], + padding = "VALID", + explicit_padding = [], + lhs_dilation = [1, 1], + rhs_dilation = [2, 2], + batch_group_count = 1 : i64, + feature_group_count = 1 : i64, + dimension_numbers = "\10\03\1A\02\01\02 \02(\032\02\00\01@\03J\02\01\02", + rhs_quantization_axis = -1 : i64, rhs_quantization_min_val = -128 : i64, rhs_quantization_max_val = 127 : i64 + } : (tensor<*xf32>, tensor<2x2x!tf_type.qint8>, tensor<2xf32>, tensor) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +// ----- + +func.func @testUniformQuantizedConvolutionHybrid(%lhs: tensor<*xf32>, %rhs: tensor<2x2x!tf_type.qint8>, %rhs_scales: tensor, %rhs_zps: tensor<2xi32>) -> tensor<*xf32> { + // expected-error @below {{'tf.UniformQuantizedConvolutionHybrid' op quantization_axis is -1, zero_points must have 0 rank.}} + %0 = "tf.UniformQuantizedConvolutionHybrid"(%lhs, %rhs, %rhs_scales, %rhs_zps) { + window_strides = [1, 2], + padding = "VALID", + explicit_padding = [], + lhs_dilation = [1, 1], + rhs_dilation = [2, 2], + batch_group_count = 1 : i64, + feature_group_count = 1 : i64, + dimension_numbers = "\10\03\1A\02\01\02 \02(\032\02\00\01@\03J\02\01\02", + rhs_quantization_axis = -1 : i64, rhs_quantization_min_val = -128 : i64, rhs_quantization_max_val = 127 : i64 + } : (tensor<*xf32>, tensor<2x2x!tf_type.qint8>, tensor, tensor<2xi32>) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +// ----- + +func.func @testUniformQuantizedConvolutionHybrid(%lhs: tensor<*xf32>, %rhs: tensor<2x2x!tf_type.qint8>, %rhs_scales: tensor<2x2xf32>, %rhs_zps: tensor<2xi32>) -> tensor<*xf32> { + // expected-error @below {{'tf.UniformQuantizedConvolutionHybrid' op quantization_axis is not -1, scales must have 1 rank.}} + %0 = "tf.UniformQuantizedConvolutionHybrid"(%lhs, %rhs, %rhs_scales, %rhs_zps) { + window_strides = [1, 2], + padding = "VALID", + explicit_padding = [], + lhs_dilation = [1, 1], + rhs_dilation = [2, 2], + batch_group_count = 1 : i64, + feature_group_count = 1 : i64, + dimension_numbers = 
"\10\03\1A\02\01\02 \02(\032\02\00\01@\03J\02\01\02", + rhs_quantization_axis = 0 : i64, rhs_quantization_min_val = -128 : i64, rhs_quantization_max_val = 127 : i64 + } : (tensor<*xf32>, tensor<2x2x!tf_type.qint8>, tensor<2x2xf32>, tensor<2xi32>) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +// ----- + +func.func @testUniformQuantizedConvolutionHybrid(%lhs: tensor<*xf32>, %rhs: tensor<2x2x!tf_type.qint8>, %rhs_scales: tensor<2xf32>, %rhs_zps: tensor<2x2xi32>) -> tensor<*xf32> { + // expected-error @below {{'tf.UniformQuantizedConvolutionHybrid' op quantization_axis is not -1, zero_points must have 1 rank.}} + %0 = "tf.UniformQuantizedConvolutionHybrid"(%lhs, %rhs, %rhs_scales, %rhs_zps) { + window_strides = [1, 2], + padding = "VALID", + explicit_padding = [], + lhs_dilation = [1, 1], + rhs_dilation = [2, 2], + batch_group_count = 1 : i64, + feature_group_count = 1 : i64, + dimension_numbers = "\10\03\1A\02\01\02 \02(\032\02\00\01@\03J\02\01\02", + rhs_quantization_axis = 0 : i64, rhs_quantization_min_val = -128 : i64, rhs_quantization_max_val = 127 : i64 + } : (tensor<*xf32>, tensor<2x2x!tf_type.qint8>, tensor<2xf32>, tensor<2x2xi32>) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +// ----- + +func.func @testUniformQuantizedConvolutionHybrid(%lhs: tensor<*xf32>, %rhs: tensor<2x2x!tf_type.qint8>, %rhs_scales: tensor<2xf32>, %rhs_zps: tensor<3xi32>) -> tensor<*xf32> { + // expected-error @below {{'tf.UniformQuantizedConvolutionHybrid' op scales and zero points must have same number of elements.}} + %0 = "tf.UniformQuantizedConvolutionHybrid"(%lhs, %rhs, %rhs_scales, %rhs_zps) { + window_strides = [1, 2], + padding = "VALID", + explicit_padding = [], + lhs_dilation = [1, 1], + rhs_dilation = [2, 2], + batch_group_count = 1 : i64, + feature_group_count = 1 : i64, + dimension_numbers = "\10\03\1A\02\01\02 \02(\032\02\00\01@\03J\02\01\02", + rhs_quantization_axis = 0 : i64, rhs_quantization_min_val = -128 : i64, rhs_quantization_max_val = 127 : i64 + } : (tensor<*xf32>, tensor<2x2x!tf_type.qint8>, tensor<2xf32>, tensor<3xi32>) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +// ----- + +func.func @testUniformQuantize(%arg0: tensor<*xf32>, %scales: tensor<2xf32>, %zps: tensor) -> tensor<*x!tf_type.qint8> { + // expected-error @below {{'tf.UniformQuantize' op quantization_axis is -1, scales must have 0 rank.}} + %0 = "tf.UniformQuantize"(%arg0, %scales, %zps) { + quantization_axis = -1 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64 + } : (tensor<*xf32>, tensor<2xf32>, tensor) -> tensor<*x!tf_type.qint8> + func.return %0 : tensor<*x!tf_type.qint8> +} + +// ----- + +func.func @testUniformRequantize( + %arg0: tensor<*x!tf_type.qint8>, + %scales_0: tensor<2xf32>, %zps_0: tensor, + %scales_1: tensor, %zps_1: tensor) -> tensor<*x!tf_type.qint8> { + // expected-error @below {{'tf.UniformRequantize' op quantization_axis is -1, scales must have 0 rank.}} + %0 = "tf.UniformRequantize"(%arg0, %scales_0, %zps_0, %scales_1, %zps_1) { + input_quantization_axis = -1 : i64, input_quantization_min_val = -2147483648 : i64, input_quantization_max_val = 2147483647 : i64, + output_quantization_axis = -1 : i64, output_quantization_min_val = -128 : i64, output_quantization_max_val = 127 : i64 + } : (tensor<*x!tf_type.qint8>, tensor<2xf32>, tensor, tensor, tensor) -> tensor<*x!tf_type.qint8> + func.return %0 : tensor<*x!tf_type.qint8> +} + +// ----- + +func.func @testUniformRequantize( + %arg0: tensor<*x!tf_type.qint8>, + %scales_0: tensor, %zps_0: tensor, + 
%scales_1: tensor<2xf32>, %zps_1: tensor) -> tensor<*x!tf_type.qint8> { + // expected-error @below {{'tf.UniformRequantize' op quantization_axis is -1, scales must have 0 rank.}} + %0 = "tf.UniformRequantize"(%arg0, %scales_0, %zps_0, %scales_1, %zps_1) { + input_quantization_axis = -1 : i64, input_quantization_min_val = -2147483648 : i64, input_quantization_max_val = 2147483647 : i64, + output_quantization_axis = -1 : i64, output_quantization_min_val = -128 : i64, output_quantization_max_val = 127 : i64 + } : (tensor<*x!tf_type.qint8>, tensor, tensor, tensor<2xf32>, tensor) -> tensor<*x!tf_type.qint8> + func.return %0 : tensor<*x!tf_type.qint8> +} + +// ----- + +func.func @testUniformDequantize(%arg0: tensor<*x!tf_type.qint8>, %scales: tensor<2xf32>, %zps: tensor) -> tensor<*xf32> { + // expected-error @below {{'tf.UniformDequantize' op quantization_axis is -1, scales must have 0 rank.}} + %0 = "tf.UniformDequantize"(%arg0, %scales, %zps) { + quantization_axis = -1 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64 + } : (tensor<*x!tf_type.qint8>, tensor<2xf32>, tensor) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +// ----- + +func.func @testUniformQuantizedDot( + %input: tensor<*x!tf_type.qint8>, %weight: tensor<2x2x!tf_type.qint8>, + %input_scales: tensor<2xf32>, %input_zps: tensor, + %weight_scales: tensor, %weight_zps: tensor, + %output_scales: tensor, %output_zps: tensor) -> () { + // expected-error @below {{'tf.UniformQuantizedDot' op quantization_axis is -1, scales must have 0 rank.}} + %1 = "tf.UniformQuantizedDot"( + %input, %weight, + %input_scales, %input_zps, + %weight_scales, %weight_zps, + %output_scales, %output_zps) { + lhs_quantization_axis = -1 : i64, + lhs_quantization_min_val = -128 : i64, + lhs_quantization_max_val = 127 : i64, + rhs_quantization_axis = -1 : i64, + rhs_quantization_min_val = -128 : i64, + rhs_quantization_max_val = 127 : i64, + output_quantization_axis = -1 : i64, + output_quantization_min_val = -2147483648 : i64, + output_quantization_max_val = 2147483647 : i64} : ( + tensor<*x!tf_type.qint8>, tensor<2x2x!tf_type.qint8>, + tensor<2xf32>, tensor, + tensor, tensor, + tensor, tensor) -> tensor<*x!tf_type.qint32> + func.return +} + +// ----- + +func.func @testUniformQuantizedDot( + %input: tensor<*x!tf_type.qint8>, %weight: tensor<2x2x!tf_type.qint8>, + %input_scales: tensor, %input_zps: tensor, + %weight_scales: tensor<2xf32>, %weight_zps: tensor, + %output_scales: tensor, %output_zps: tensor) -> () { + // expected-error @below {{'tf.UniformQuantizedDot' op quantization_axis is -1, scales must have 0 rank.}} + %1 = "tf.UniformQuantizedDot"( + %input, %weight, + %input_scales, %input_zps, + %weight_scales, %weight_zps, + %output_scales, %output_zps) { + lhs_quantization_axis = -1 : i64, + lhs_quantization_min_val = -128 : i64, + lhs_quantization_max_val = 127 : i64, + rhs_quantization_axis = -1 : i64, + rhs_quantization_min_val = -128 : i64, + rhs_quantization_max_val = 127 : i64, + output_quantization_axis = -1 : i64, + output_quantization_min_val = -2147483648 : i64, + output_quantization_max_val = 2147483647 : i64} : ( + tensor<*x!tf_type.qint8>, tensor<2x2x!tf_type.qint8>, + tensor, tensor, + tensor<2xf32>, tensor, + tensor, tensor) -> tensor<*x!tf_type.qint32> + func.return +} + +// ----- + +func.func @testUniformQuantizedDot( + %input: tensor<*x!tf_type.qint8>, %weight: tensor<2x2x!tf_type.qint8>, + %input_scales: tensor, %input_zps: tensor, + %weight_scales: tensor, %weight_zps: tensor, + %output_scales: 
tensor<2xf32>, %output_zps: tensor) -> () { + // expected-error @below {{'tf.UniformQuantizedDot' op quantization_axis is -1, scales must have 0 rank.}} + %1 = "tf.UniformQuantizedDot"( + %input, %weight, + %input_scales, %input_zps, + %weight_scales, %weight_zps, + %output_scales, %output_zps) { + lhs_quantization_axis = -1 : i64, + lhs_quantization_min_val = -128 : i64, + lhs_quantization_max_val = 127 : i64, + rhs_quantization_axis = -1 : i64, + rhs_quantization_min_val = -128 : i64, + rhs_quantization_max_val = 127 : i64, + output_quantization_axis = -1 : i64, + output_quantization_min_val = -2147483648 : i64, + output_quantization_max_val = 2147483647 : i64} : ( + tensor<*x!tf_type.qint8>, tensor<2x2x!tf_type.qint8>, + tensor, tensor, + tensor, tensor, + tensor<2xf32>, tensor) -> tensor<*x!tf_type.qint32> + func.return +} + +// Following tests are for LegacyCall symbol use verifier. + +// ----- + +// Tests that valid symbol use does not produce any error. +func.func @valid_symbol_use(%arg0: tensor) -> () { + "tf.LegacyCall"(%arg0) {f = @call_func} : (tensor) -> (tensor) + func.return +} + +func.func @call_func(%arg0: tensor) -> tensor { + func.return %arg0 : tensor +} + +// ----- + +// Tests that undefined call function produces error. +func.func @test_undefined_function() -> () { + // expected-error @below {{'f' attribute refers to an undefined function: undefined_func}} + "tf.LegacyCall"() {f = @undefined_func} : () -> () + func.return +} + +// ----- + +// Tests that argument count mismatch produces error. +func.func @test_arg_count_mismatch(%arg0: tensor) -> () { + // expected-error @below {{argument count mismatch: 'args' has 1 argument(s), but 'call_func' expects 2}} + "tf.LegacyCall"(%arg0) {f = @call_func} : (tensor) -> tensor + func.return +} + +func.func @call_func(%arg0: tensor, %arg1: tensor) -> tensor { + func.return %arg0 : tensor +} + +// ----- + +func.func @test_batch_function_with_valid_symbol(%arg0: tensor<1x3xf32>, %arg1: tensor>>) -> () { + "tf.BatchFunction"(%arg0, %arg1) {batch_timeout_micros = 100000 : i64, f = @batched_function, max_batch_size = 6 : i64, max_enqueued_batches = 10 : i64, num_batch_threads = 1 : i64, operand_segment_sizes = array} : (tensor<1x3xf32>, tensor>>) -> tensor<*xf32> + func.return +} + +func.func private @batched_function(%arg0: tensor<1x3xf32>, %arg1: tensor<*x!tf_type.resource>) -> tensor<1x3xf32> { + %0 = "tf.Identity"(%arg0) : (tensor<1x3xf32>) -> tensor<1x3xf32> + func.return %0 : tensor<1x3xf32> +} + +// ----- + +func.func @test_batch_function_with_invalid_symbol(%arg0: tensor<1x3xf32>, %arg1: tensor>>) -> () { + // expected-error @below {{'f' attribute refers to an undefined function: undefined_function}} + "tf.BatchFunction"(%arg0, %arg1) {batch_timeout_micros = 100000 : i64, f = @undefined_function, max_batch_size = 6 : i64, max_enqueued_batches = 10 : i64, num_batch_threads = 1 : i64, operand_segment_sizes = array} : (tensor<1x3xf32>, tensor>>) -> tensor<*xf32> + func.return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/BUILD index 4c0ca54b235..88c31e5057f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/BUILD @@ -2,6 +2,7 @@ load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = 
["notice"], ) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common.py index 254d1699cd9..b3bd3449851 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common.py +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common.py @@ -24,8 +24,8 @@ from absl import flags from absl import logging import tensorflow.compat.v2 as tf - from tensorflow.python import pywrap_mlir # pylint: disable=g-direct-tensorflow-import +from tensorflow.python.lib.io import file_io # Use /tmp to make debugging the tests easier (see README.md) flags.DEFINE_string('save_model_path', '', @@ -88,5 +88,9 @@ def app_main(argv): mlir = pywrap_mlir.experimental_run_pass_pipeline(mlir, 'canonicalize', show_debug_info) print(mlir) + filename = '%s/result.mlirbc' % save_model_path + pywrap_mlir.experimental_write_bytecode(filename, mlir) + if not file_io.file_exists(filename): + raise app.UsageError('Failed to create bytecode output.') app.run(app_main) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/control_flow_upgrade_legacy_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/control_flow_upgrade_legacy_v1.py index 34ca31137da..eefb57b84dd 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/control_flow_upgrade_legacy_v1.py +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/control_flow_upgrade_legacy_v1.py @@ -29,7 +29,9 @@ # CHECK-SAME: then_branch = @"key/[[then:[a-zA-Z_0-9]+]]" # CHECK: func private @"key/[[else]]"( +# CHECK-SAME: tf._original_func_name # CHECK: func private @"key/[[then]]"( +# CHECK-SAME: tf._original_func_name def Test(): diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/debug_info.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/debug_info.py index 364b6953e5c..5a3a9875a00 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/debug_info.py +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/debug_info.py @@ -30,8 +30,10 @@ def some_function(self, x, y): return x + y # Basic check that the debug info file is being correctly saved and loaded. 
# - # CHECK: "tf.AddV2"{{.*}}loc(#[[LOC:.*]]) - # CHECK: #[[LOC]] = loc({{.*}}callsite("{{[^"]*}}/debug_info.py{{.*}}":{{[0-9]+}}:{{[0-9]+}} + # CHECK: "tf.AddV2"{{.*}}loc(#loc{{[0-9]+}}) + # CHECK: "tf.Identity"{{.*}}loc(#loc{{[0-9]+}}) + # CHECK: #loc{{[0-9]+}} = loc("{{.*}}debug_info.py":{{[0-9]+}}:{{[0-9]+}}) + # CHECK: #loc{{[0-9]+}} = loc(callsite(#loc{{[0-9]+}} at #loc{{[0-9]+}})) if __name__ == '__main__': diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_to_hlo_pipeline/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/tf_to_hlo_pipeline/BUILD index c9b9a22838d..954eca9c0e2 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_to_hlo_pipeline/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_to_hlo_pipeline/BUILD @@ -1,6 +1,8 @@ load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) + licenses(["notice"]) glob_lit_tests( diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_bridge_v1/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/tpu_bridge_v1/BUILD index c9b9a22838d..954eca9c0e2 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_bridge_v1/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_bridge_v1/BUILD @@ -1,6 +1,8 @@ load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) + licenses(["notice"]) glob_lit_tests( diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir index 3f575069f6a..355085be8b4 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_cluster_formation.mlir @@ -252,7 +252,7 @@ func.func @replication(%arg0: tensor, %arg1: tensor, %arg2: tensor // Test replication with model parallelism using partitioned resource inputs. // The cluster will be wrapped in a `tf_device.cluster` first and then by a // replicate. -// TPUPartitionedInput nodes would be inside the replicate but outside the +// TPUPartitionedInputV2 nodes would be inside the replicate but outside the // cluster. // TPUReplicatedInput and TPUReplicatedOutput nodes will be replaced by the // replicate operands and results. 
@@ -265,7 +265,7 @@ func.func @replication_with_model_parallelism(%arg0: !rtype, %arg1: !rtype, %arg %1 = "tf.opB"() {is_stateless = true} : () -> tensor %2 = "tf.TPUReplicatedInput"(%arg0, %arg2) : (!rtype, !rtype) -> !rtype %3 = "tf.TPUReplicatedInput"(%arg1, %arg3) : (!rtype, !rtype) -> !rtype - %4 = "tf.TPUPartitionedInput"(%2, %3) {_XlaSharding = "", device = "", partition_dim = -1 : i64} : (!rtype, !rtype) -> !rtype + %4 = "tf.TPUPartitionedInputV2"(%2, %3) {_XlaSharding = "", device = "", partition_dims = []} : (!rtype, !rtype) -> !rtype %5 = "tf.TPUReplicatedInput"(%0, %1) : (tensor, tensor) -> tensor %6 = "tf.opC"(%4) {_xla_compile_device_type = "TPU", _replication_info = "replicate", is_stateless = true} : (!rtype) -> tensor<10x3xf32> %7:2 = "tf.TPUReplicatedOutput"(%6) : (tensor<10x3xf32>) -> (tensor<10x3xf32>, tensor<10x3xf32>) @@ -282,7 +282,7 @@ func.func @replication_with_model_parallelism(%arg0: !rtype, %arg1: !rtype, %arg // CHECK-DAG: [%[[ARG_1]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor>> // CHECK-DAG: [%[[OP_A]], %[[OP_B]]] as %[[RI_2:[a-z0-9]*]]: tensor // CHECK-SAME: n = 2 : i32 -// CHECK: %[[PI:[0-9]*]] = "tf.TPUPartitionedInput"(%[[RI_0]], %[[RI_1]]) +// CHECK: %[[PI:[0-9]*]] = "tf.TPUPartitionedInputV2"(%[[RI_0]], %[[RI_1]]) // CHECK-NEXT: %[[CLUSTER:[0-9]*]]:2 = "tf_device.cluster"() ({ // CHECK: %[[OP_C:[0-9]*]] = "tf.opC"(%[[PI]]) // CHECK: %[[OP_D:[0-9]*]] = "tf.opD"(%[[RI_2]]) @@ -568,8 +568,8 @@ func.func @bad_num_replicas() { func.func @replication_with_model_parallelism(%arg0: !rtype, %arg1: !rtype, %arg2: !rtype, %arg3: !rtype) -> (tensor<10x3xf32>) { %2 = "tf.TPUReplicatedInput"(%arg0, %arg2) : (!rtype, !rtype) -> !rtype %3 = "tf.TPUReplicatedInput"(%arg1, %arg3) : (!rtype, !rtype) -> !rtype - // expected-error@+1 {{'tf.TPUPartitionedInput' op requires 4 operands but found 2}} - %4 = "tf.TPUPartitionedInput"(%2, %3) {_XlaSharding = "", device = "", partition_dim = -1 : i64} : (!rtype, !rtype) -> !rtype + // expected-error@+1 {{'tf.TPUPartitionedInputV2' op requires 4 operands but found 2}} + %4 = "tf.TPUPartitionedInputV2"(%2, %3) {_XlaSharding = "", device = "", partition_dims = []} : (!rtype, !rtype) -> !rtype %6 = "tf.opC"(%4) {_xla_compile_device_type = "TPU", _replication_info = "replicate", is_stateless = true} : (!rtype) -> tensor<10x3xf32> %7:2 = "tf.TPUReplicatedOutput"(%6) : (tensor<10x3xf32>) -> (tensor<10x3xf32>, tensor<10x3xf32>) "tf.TPUReplicateMetadata"() {_xla_compile_device_type = "TPU", _replication_info = "replicate", device = "/device:TPU:0", num_cores_per_replica = 4 : i64, num_replicas = 2 : i64, topology = "topology"} : () -> () diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir index 596b9c7fc8e..9513ef51360 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir @@ -1778,6 +1778,55 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor func.return %1 : tensor<2xi32> } + // Check that a non-XLA value is not routed through the XLA side. 
+ + // CHECK-LABEL: func @nonxla_static + func.func @nonxla_static() -> () { + "tf_device.cluster"() ({ + %0 = "tf.A"() : () -> (tensor) + %1 = "tf.B"(%0) {_xla_outside_compilation = "cluster0"} : (tensor) -> tensor + %2 = "tf.C"(%1) {_xla_outside_compilation = "cluster0"} : (tensor) -> tensor + "tf.D"(%2) : (tensor) -> () + tf_device.return + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> () + func.return + } + + // Check that a non-XLA value with dynamic shape is not routed through the XLA side. + + // CHECK-LABEL: func @nonxla_dynamic + func.func @nonxla_dynamic() -> () { + "tf_device.cluster"() ({ + %0 = "tf.A"() : () -> (tensor) + %1 = "tf.B"(%0) {_xla_outside_compilation = "cluster0"} : (tensor) -> tensor + %2 = "tf.C"(%1) {_xla_outside_compilation = "cluster0"} : (tensor) -> tensor + "tf.D"(%2) : (tensor) -> () + tf_device.return + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> () + func.return + } + + // Reproducer for an operand #x does not dominate this use + + // CHECK-LABEL: func @op_dominate_repro + func.func @op_dominate_repro(%writer: tensor<*x!tf_type.resource> {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}) -> () { + %step = "tf.Const"() {device = "", value = dense<0> : tensor} : () -> tensor + %tag = "tf.Const"() {device = "", value = dense<""> : tensor} : () -> tensor + %wmetadata = "tf.Const"() {device = "", value = dense<""> : tensor} : () -> tensor + %pred = "tf.Const"() {device = "", value = dense<0> : tensor} : () -> tensor + "tf_device.cluster"() ({ + "tf.IfRegion"(%pred) ({ + %wtensor = "tf.Const"() {device = "", value = dense<0.0> : tensor} : () -> tensor + "tf.WriteSummary"(%writer, %step, %wtensor, %tag, %wmetadata) {_xla_outside_compilation = "auto"} : (tensor<*x!tf_type.resource>, tensor, tensor, tensor, tensor) -> () + "tf.WriteSummary"(%writer, %step, %wtensor, %tag, %wmetadata) {_xla_outside_compilation = "auto"} : (tensor<*x!tf_type.resource>, tensor, tensor, tensor, tensor) -> () + "tf.Yield"() : () -> () + }, { + "tf.Yield"() : () -> () + }) {is_stateless = false} : (tensor) -> () + tf_device.return + }) {_replication_info = "cluster__train_single_step", _xla_compile_device_type = "TPU", allow_soft_placement = true, computation_shape = [], device = "", device_assignment = [], host_compute_core = [], num_cores_per_replica = 1 : i64, num_replicas = 1, padding_map = [], step_marker_location = "STEP_MARK_AT_ENTRY", topology = "", tpu_compile_options_proto = "", use_spmd_for_xla_partitioning = false, use_tpu = true} : () -> () + return + } } // ----- @@ -1967,3 +2016,25 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor } } +// ----- + +// Tests that an error is reported when an op with _xla_outside_compilation has +// an ancestor with _xla_outside_compilation. 
+ +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} { + func.func @outside_comp_ancestor() { + "tf_device.cluster"() ({ + "tf.WhileRegion"() ({ + ^bb0(): + // expected-error @+1 {{has an ancestor marked for outside compilation}} + %1 = "tf.A"() {_xla_outside_compilation = "cluster1"} : () -> tensor + "tf.Yield"(%1) : (tensor) -> () + }, { + ^bb0(): + "tf.Yield"() : () -> () + }) {_xla_outside_compilation = "cluster1", is_stateless = true} : () -> () + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", topology = "", device_assignment = []} : () -> () + func.return + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_partitioned_op_conversion.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_partitioned_op_conversion.mlir new file mode 100644 index 00000000000..f2a52c8dc34 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_partitioned_op_conversion.mlir @@ -0,0 +1,106 @@ +// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-tpu-partitioned-op-conversion | FileCheck %s + +// CHECK-LABEL:func @replicated +// CHECK-SAME: ([[ARG0:%.*]]: tensor>>, [[ARG1:%.*]]: tensor>>, [[ARG2:%.*]]: tensor>>, [[ARG3:%.*]]: tensor>>) +func.func @replicated(%arg0: tensor>>, %arg1: tensor>>, %arg2: tensor>>, %arg3: tensor>>) -> tensor>> { + // CHECK: [[PI_0:%.*]] = "tf.TPUPartitionedInputV2"([[ARG0]], [[ARG1]]) + // CHECK-SAME: _XlaSharding = "" + // CHECK-SAME: partition_dims = [] + // CHECK: [[PI_1:%.*]] = "tf.TPUPartitionedInputV2"([[ARG2]], [[ARG3]]) + // CHECK-SAME: _XlaSharding = "" + // CHECK-SAME: partition_dims = [] + // CHECK: [[RI:%.*]] = "tf.TPUReplicatedInput"([[PI_0]], [[PI_1]]) + %pi_0 = "tf.TPUPartitionedInput"(%arg0, %arg1) {_XlaSharding = "", partition_dim = -1 : i64} : (tensor>>, tensor>>) -> tensor>> + %pi_1 = "tf.TPUPartitionedInput"(%arg2, %arg3) {_XlaSharding = "", partition_dim = -1 : i64} : (tensor>>, tensor>>) -> tensor>> + %ri = "tf.TPUReplicatedInput"(%pi_0, %pi_1) : (tensor>>, tensor>>) -> tensor>> + // CHECK: return [[RI]] + func.return %ri : tensor>> +} + +// ----- + +// CHECK-LABEL:func @partitioned_2d +// CHECK-SAME: ([[ARG0:%.*]]: tensor<10x3xf32>, [[ARG1:%.*]]: tensor<10x3xf32>) +func.func @partitioned_2d(%arg0: tensor<10x3xf32>, %arg1: tensor<10x3xf32>) -> tensor<20x3xf32> { + // CHECK: [[PI_0:%.*]] = "tf.TPUPartitionedInputV2"([[ARG0]], [[ARG1]]) + // CHECK-SAME: _XlaSharding = "123" + // CHECK-SAME: partition_dims = [2, 1] + // CHECK: [[RI:%.*]] = "tf.TPUReplicatedInput"([[PI_0]]) + "tf.TPUReplicateMetadata"() {num_cores_per_replica = 2 : i64, num_replicas = 1 : i64} : () -> () + %pi_0 = "tf.TPUPartitionedInput"(%arg0, %arg1) {_XlaSharding = "123", partition_dim = 0 : i64} : (tensor<10x3xf32>, tensor<10x3xf32>) -> tensor<20x3xf32> + %ri = "tf.TPUReplicatedInput"(%pi_0) : (tensor<20x3xf32>) -> tensor<20x3xf32> + // CHECK: return [[RI]] + func.return %ri : tensor<20x3xf32> +} + +// ----- + +// CHECK-LABEL:func @partitioned_2d_resource +// CHECK-SAME: ([[ARG0:%.*]]: tensor>>, [[ARG1:%.*]]: tensor>>, [[ARG2:%.*]]: tensor>>, [[ARG3:%.*]]: tensor>>) +func.func @partitioned_2d_resource(%arg0: tensor>>, %arg1: tensor>>, %arg2: tensor>>, %arg3: tensor>>) -> tensor>> { + // CHECK: [[PI_0:%.*]] = "tf.TPUPartitionedInputV2"([[ARG0]], [[ARG1]]) + // CHECK-SAME: _XlaSharding = "123" + // CHECK-SAME: partition_dims = [2, 1] + // CHECK: [[PI_1:%.*]] = 
"tf.TPUPartitionedInputV2"([[ARG2]], [[ARG3]]) + // CHECK-SAME: _XlaSharding = "123" + // CHECK-SAME: partition_dims = [2, 1] + // CHECK: [[RI:%.*]] = "tf.TPUReplicatedInput"([[PI_0]], [[PI_1]]) + "tf.TPUReplicateMetadata"() {num_cores_per_replica = 2 : i64, num_replicas = 2 : i64} : () -> () + %pi_0 = "tf.TPUPartitionedInput"(%arg0, %arg1) {_XlaSharding = "123", partition_dim = 0 : i64} : (tensor>>, tensor>>) -> tensor>> + %pi_1 = "tf.TPUPartitionedInput"(%arg2, %arg3) {_XlaSharding = "123", partition_dim = 0 : i64} : (tensor>>, tensor>>) -> tensor>> + %ri = "tf.TPUReplicatedInput"(%pi_0, %pi_1) : (tensor>>, tensor>>) -> tensor>> + // CHECK: return [[RI]] + func.return %ri : tensor>> +} + +// ----- + +// CHECK-LABEL:func @partitioned_3d +// CHECK-SAME: ([[ARG0:%.*]]: tensor>>, [[ARG1:%.*]]: tensor>>, [[ARG2:%.*]]: tensor>>, [[ARG3:%.*]]: tensor>>) +func.func @partitioned_3d(%arg0: tensor>>, %arg1: tensor>>, %arg2: tensor>>, %arg3: tensor>>) -> tensor>> { + // CHECK: [[PI_0:%.*]] = "tf.TPUPartitionedInputV2"([[ARG0]], [[ARG1]]) + // CHECK-SAME: _XlaSharding = "123" + // CHECK-SAME: partition_dims = [1, 2, 1] + // CHECK: [[PI_1:%.*]] = "tf.TPUPartitionedInputV2"([[ARG2]], [[ARG3]]) + // CHECK-SAME: _XlaSharding = "123" + // CHECK-SAME: partition_dims = [1, 2, 1] + // CHECK: [[RI:%.*]] = "tf.TPUReplicatedInput"([[PI_0]], [[PI_1]]) + "tf.TPUReplicateMetadata"() {num_cores_per_replica = 2 : i64, num_replicas = 2 : i64} : () -> () + %pi_0 = "tf.TPUPartitionedInput"(%arg0, %arg1) {_XlaSharding = "123", partition_dim = 1 : i64} : (tensor>>, tensor>>) -> tensor>> + %pi_1 = "tf.TPUPartitionedInput"(%arg2, %arg3) {_XlaSharding = "123", partition_dim = 1 : i64} : (tensor>>, tensor>>) -> tensor>> + %ri = "tf.TPUReplicatedInput"(%pi_0, %pi_1) : (tensor>>, tensor>>) -> tensor>> + // CHECK: return [[RI]] + func.return %ri : tensor>> +} + +// ----- + +// CHECK-LABEL:func @partitioned_output_3d +// CHECK-SAME: ([[ARG:%.*]]: tensor>>) +func.func @partitioned_output_3d(%arg: tensor>>) -> tensor>> { + // CHECK: [[PO:%.*]] = "tf.TPUPartitionedOutputV2"([[ARG]]) + // CHECK-SAME: _XlaSharding = "123" + // CHECK-SAME: partition_dims = [1, 2, 1] + "tf.TPUReplicateMetadata"() {num_cores_per_replica = 2 : i64, num_replicas = 2 : i64} : () -> () + %po:2 = "tf.TPUPartitionedOutput"(%arg) {_XlaSharding = "123", partition_dim = 1 : i64} : (tensor>>) -> (tensor>>, tensor>>) + // CHECK: return [[PO:%.*0]] + func.return %po#0 : tensor>> +} + +// ----- + +func.func @out_of_range_dim(%arg: tensor>>) -> tensor>> { + "tf.TPUReplicateMetadata"() {num_cores_per_replica = 2 : i64, num_replicas = 2 : i64} : () -> () + // expected-error @+1 {{cannot partition 'tensor<16x16x16xf32>' (rank = 3) along dimension 3.}} + %po:2 = "tf.TPUPartitionedOutput"(%arg) {_XlaSharding = "123", partition_dim = 3 : i64} : (tensor>>) -> (tensor>>, tensor>>) + func.return %po#0 : tensor>> +} + +// ----- + +func.func @unranked(%arg: tensor>>) -> tensor>> { + "tf.TPUReplicateMetadata"() {num_cores_per_replica = 2 : i64, num_replicas = 2 : i64} : () -> () + // expected-error @+1 {{cannot convert op with unranked or non-tensor input type 'tensor<*xf32>'.}} + %po:2 = "tf.TPUPartitionedOutput"(%arg) {_XlaSharding = "123", partition_dim = 3 : i64} : (tensor>>) -> (tensor>>, tensor>>) + func.return %po#0 : tensor>> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_reorder_replicate_and_partitioned_inputs.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_reorder_replicate_and_partitioned_inputs.mlir index 8de8c8cc915..6daddaba7c1 100644 --- 
a/tensorflow/compiler/mlir/tensorflow/tests/tpu_reorder_replicate_and_partitioned_inputs.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_reorder_replicate_and_partitioned_inputs.mlir @@ -5,37 +5,73 @@ func.func @simple(%arg0: tensor>>, %arg1: tensor>>, %arg2: tensor>>, %arg3: tensor>>) -> tensor>> { // CHECK: [[RI_0:%.*]] = "tf.TPUReplicatedInput"([[ARG0]], [[ARG2]]) // CHECK: [[RI_1:%.*]] = "tf.TPUReplicatedInput"([[ARG1]], [[ARG3]]) - // CHECK: [[PI:%.*]] = "tf.TPUPartitionedInput"([[RI_0]], [[RI_1]]) - %pi_0 = "tf.TPUPartitionedInput"(%arg0, %arg1) {_XlaSharding = "", partition_dim = -1 : i64} : (tensor>>, tensor>>) -> tensor>> - %pi_1 = "tf.TPUPartitionedInput"(%arg2, %arg3) {_XlaSharding = "", partition_dim = -1 : i64} : (tensor>>, tensor>>) -> tensor>> + // CHECK: [[PI:%.*]] = "tf.TPUPartitionedInputV2"([[RI_0]], [[RI_1]]) + %pi_0 = "tf.TPUPartitionedInputV2"(%arg0, %arg1) {_XlaSharding = "", partition_dims = []} : (tensor>>, tensor>>) -> tensor>> + %pi_1 = "tf.TPUPartitionedInputV2"(%arg2, %arg3) {_XlaSharding = "", partition_dims = []} : (tensor>>, tensor>>) -> tensor>> %ri = "tf.TPUReplicatedInput"(%pi_0, %pi_1) : (tensor>>, tensor>>) -> tensor>> // CHECK: return [[PI]] func.return %ri : tensor>> } +// CHECK-LABEL:func @simple_packed +// CHECK-SAME: ([[ARG0:%.*]]: tensor>>) +func.func @simple_packed(%arg0: tensor>>) -> tensor>> { + // CHECK: "tf.TPUReplicateMetadata"() + // CHECK: [[RI_0:%.*]] = "tf.TPUReplicatedInput"([[ARG0]]) + // CHECK-SAME: is_packed = true + // CHECK: [[RI_1:%.*]] = "tf.TPUReplicatedInput"([[ARG0]]) + // CHECK-SAME: is_packed = true + // CHECK: [[PI:%.*]] = "tf.TPUPartitionedInputV2"([[RI_0]], [[RI_1]]) + // CHECK-SAME: is_packed = false + "tf.TPUReplicateMetadata"() {num_cores_per_replica = 2 : i64, num_replicas = 2 : i64} : () -> () + %0 = "tf.TPUPartitionedInputV2"(%arg0) {_XlaSharding = "", partition_dims = [], is_packed = true} : (tensor>>) -> tensor>> + %1 = "tf.TPUPartitionedInputV2"(%arg0) {_XlaSharding = "", partition_dims = [], is_packed = true} : (tensor>>) -> tensor>> + %2 = "tf.TPUReplicatedInput"(%0, %1) {is_packed = false} : (tensor>>, tensor>>) -> tensor>> + // CHECK: return [[PI]] + func.return %2 : tensor>> +} + +// CHECK-LABEL:func @multi_arg_packed +// CHECK-SAME: ([[ARG0:%.*]]: tensor>>, [[ARG1:%.*]]: tensor>>) +func.func @multi_arg_packed(%arg0: tensor>>, %arg1: tensor>>) -> tensor>> { + // CHECK: "tf.TPUReplicateMetadata"() + // CHECK: [[RI_0:%.*]] = "tf.TPUReplicatedInput"([[ARG0]], [[ARG1]]) + // CHECK-SAME: is_packed = false + // CHECK: [[RI_1:%.*]] = "tf.TPUReplicatedInput"([[ARG0]], [[ARG1]]) + // CHECK-SAME: is_packed = false + // CHECK: [[PI:%.*]] = "tf.TPUPartitionedInputV2"([[RI_0]], [[RI_1]]) + // CHECK-SAME: is_packed = false + "tf.TPUReplicateMetadata"() {num_cores_per_replica = 2 : i64, num_replicas = 2 : i64} : () -> () + %0 = "tf.TPUPartitionedInputV2"(%arg0) {_XlaSharding = "", partition_dims = [], is_packed = true} : (tensor>>) -> tensor>> + %1 = "tf.TPUPartitionedInputV2"(%arg1) {_XlaSharding = "", partition_dims = [], is_packed = true} : (tensor>>) -> tensor>> + %2 = "tf.TPUReplicatedInput"(%0, %1) {is_packed = false} : (tensor>>, tensor>>) -> tensor>> + // CHECK: return [[PI]] + func.return %2 : tensor>> +} + // CHECK-LABEL:func @missing_xla_sharding // CHECK-SAME: ([[ARG0:%.*]]: tensor>>, [[ARG1:%.*]]: tensor>>, [[ARG2:%.*]]: tensor>>, [[ARG3:%.*]]: tensor>>) func.func @missing_xla_sharding(%arg0: tensor>>, %arg1: tensor>>, %arg2: tensor>>, %arg3: tensor>>) -> tensor>> { // CHECK: [[RI_0:%.*]] = 
"tf.TPUReplicatedInput"([[ARG0]], [[ARG2]]) // CHECK: [[RI_1:%.*]] = "tf.TPUReplicatedInput"([[ARG1]], [[ARG3]]) - // CHECK: [[PI:%.*]] = "tf.TPUPartitionedInput"([[RI_0]], [[RI_1]]) - %pi_0 = "tf.TPUPartitionedInput"(%arg0, %arg1) {device = "", partition_dim = -1 : i64} : (tensor>>, tensor>>) -> tensor>> - %pi_1 = "tf.TPUPartitionedInput"(%arg2, %arg3) {device = "", partition_dim = -1 : i64} : (tensor>>, tensor>>) -> tensor>> + // CHECK: [[PI:%.*]] = "tf.TPUPartitionedInputV2"([[RI_0]], [[RI_1]]) + %pi_0 = "tf.TPUPartitionedInputV2"(%arg0, %arg1) {device = "", partition_dims = []} : (tensor>>, tensor>>) -> tensor>> + %pi_1 = "tf.TPUPartitionedInputV2"(%arg2, %arg3) {device = "", partition_dims = []} : (tensor>>, tensor>>) -> tensor>> %ri = "tf.TPUReplicatedInput"(%pi_0, %pi_1) : (tensor>>, tensor>>) -> tensor>> // CHECK: return [[PI]] func.return %ri : tensor>> } // Test IR is not modified when none of the operands of tf.TPUReplicaedInput is -// a tf.TPUPartitionedInput op. +// a tf.TPUPartitionedInputV2 op. // CHECK-LABEL:func @no_change_to_dag // CHECK-SAME: ([[ARG0:%.*]]: tensor>>, [[ARG1:%.*]]: tensor>>, [[ARG2:%.*]]: tensor>>, [[ARG3:%.*]]: tensor>>) func.func @no_change_to_dag(%arg0: tensor>>, %arg1: tensor>>, %arg2: tensor>>, %arg3: tensor>>) -> (tensor>>, tensor>>, tensor>>) { - // CHECK: [[PI_0:%.*]] = "tf.TPUPartitionedInput"([[ARG0]], [[ARG1]]) - %pi_0 = "tf.TPUPartitionedInput"(%arg0, %arg1) {device = "", partition_dim = -1 : i64} : (tensor>>, tensor>>) -> tensor>> - // CHECK: [[PI_1:%.*]] = "tf.TPUPartitionedInput"([[ARG2]], [[ARG3]]) - %pi_1 = "tf.TPUPartitionedInput"(%arg2, %arg3) {device = "", partition_dim = -1 : i64} : (tensor>>, tensor>>) -> tensor>> + // CHECK: [[PI_0:%.*]] = "tf.TPUPartitionedInputV2"([[ARG0]], [[ARG1]]) + %pi_0 = "tf.TPUPartitionedInputV2"(%arg0, %arg1) {device = "", partition_dims = []} : (tensor>>, tensor>>) -> tensor>> + // CHECK: [[PI_1:%.*]] = "tf.TPUPartitionedInputV2"([[ARG2]], [[ARG3]]) + %pi_1 = "tf.TPUPartitionedInputV2"(%arg2, %arg3) {device = "", partition_dims = []} : (tensor>>, tensor>>) -> tensor>> // CHECK: [[RI:%.*]] = "tf.TPUReplicatedInput"([[ARG0]], [[ARG1]]) %ri = "tf.TPUReplicatedInput"(%arg0, %arg1) : (tensor>>, tensor>>) -> tensor>> // CHECK: return [[RI]], [[PI_0]], [[PI_1]] @@ -44,10 +80,31 @@ func.func @no_change_to_dag(%arg0: tensor>>, // ----- +func.func @missing_metadata(%arg0: tensor>>) -> tensor>> { + // expected-error@+1 {{num cores per replica unavailable, metadata missing?}} + %0 = "tf.TPUPartitionedInputV2"(%arg0) {_XlaSharding = "", partition_dims = [], is_packed = true} : (tensor>>) -> tensor>> + %1 = "tf.TPUPartitionedInputV2"(%arg0) {_XlaSharding = "", partition_dims = [], is_packed = true} : (tensor>>) -> tensor>> + %2 = "tf.TPUReplicatedInput"(%0, %1) {is_packed = false} : (tensor>>, tensor>>) -> tensor>> + func.return %2 : tensor>> +} + +// ----- + +func.func @inconsistent_packing(%arg0: tensor>>) -> tensor>> { + "tf.TPUReplicateMetadata"() {num_cores_per_replica = 2 : i64, num_replicas = 2 : i64} : () -> () + %0 = "tf.TPUPartitionedInputV2"(%arg0) {_XlaSharding = "", partition_dims = [], is_packed = true} : (tensor>>) -> tensor>> + // expected-error@+1 {{packing should match across ops}} + %1 = "tf.TPUPartitionedInputV2"(%arg0, %arg0) {_XlaSharding = "", partition_dims = [], is_packed = false} : (tensor>>, tensor>>) -> tensor>> + %2 = "tf.TPUReplicatedInput"(%0, %1) {is_packed = false} : (tensor>>, tensor>>) -> tensor>> + func.return %2 : tensor>> +} + +// ----- + func.func @xla_sharding_mismatch(%arg0: 
tensor>>, %arg1: tensor>>, %arg2: tensor>>, %arg3: tensor>>) -> tensor>> { - %pi_0 = "tf.TPUPartitionedInput"(%arg0, %arg1) {_XlaSharding = "", partition_dim = -1 : i64} : (tensor>>, tensor>>) -> tensor>> - %pi_1 = "tf.TPUPartitionedInput"(%arg2, %arg3) {_XlaSharding = "123", partition_dim = -1 : i64} : (tensor>>, tensor>>) -> tensor>> - // expected-error@+1 {{expects all inputs from 'tf.TPUPartitionedInput' ops to have identical XLA sharding}} + %pi_0 = "tf.TPUPartitionedInputV2"(%arg0, %arg1) {_XlaSharding = "", partition_dims = []} : (tensor>>, tensor>>) -> tensor>> + %pi_1 = "tf.TPUPartitionedInputV2"(%arg2, %arg3) {_XlaSharding = "123", partition_dims = []} : (tensor>>, tensor>>) -> tensor>> + // expected-error@+1 {{expects all inputs from 'tf.TPUPartitionedInputV2' ops to have identical XLA sharding}} %ri = "tf.TPUReplicatedInput"(%pi_0, %pi_1) : (tensor>>, tensor>>) -> tensor>> func.return %ri : tensor>> } @@ -55,9 +112,9 @@ func.func @xla_sharding_mismatch(%arg0: tensor>>, %arg1: tensor>>, %arg2: tensor>>, %arg3: tensor>>) -> tensor>> { - %pi_0 = "tf.TPUPartitionedInput"(%arg0, %arg1) {_XlaSharding = "", partition_dim = -1 : i64} : (tensor>>, tensor>>) -> tensor>> - // expected-error@+1 {{expects partition_dim = -1 but found 0}} - %pi_1 = "tf.TPUPartitionedInput"(%arg2, %arg3) {_XlaSharding = "", partition_dim = 0 : i64} : (tensor>>, tensor>>) -> tensor>> + %pi_0 = "tf.TPUPartitionedInputV2"(%arg0, %arg1) {_XlaSharding = "", partition_dims = []} : (tensor>>, tensor>>) -> tensor>> + // expected-error@+1 {{expects partition_dims = [] but found [1, 2]}} + %pi_1 = "tf.TPUPartitionedInputV2"(%arg2, %arg3) {_XlaSharding = "", partition_dims = [1, 2]} : (tensor>>, tensor>>) -> tensor>> %ri = "tf.TPUReplicatedInput"(%pi_0, %pi_1) : (tensor>>, tensor>>) -> tensor>> func.return %ri : tensor>> } @@ -65,9 +122,9 @@ func.func @partition_dim_mismatch(%arg0: tensor>>, %arg1: tensor>>, %arg2: tensor>>, %arg3: tensor>>, %arg4: tensor>>) -> tensor>> { - %pi_0 = "tf.TPUPartitionedInput"(%arg0, %arg1) {_XlaSharding = "", partition_dim = -1 : i64} : (tensor>>, tensor>>) -> tensor>> + %pi_0 = "tf.TPUPartitionedInputV2"(%arg0, %arg1) {_XlaSharding = "", partition_dims = []} : (tensor>>, tensor>>) -> tensor>> // expected-error@+1 {{expects 2 operands but found 3}} - %pi_1 = "tf.TPUPartitionedInput"(%arg2, %arg3, %arg4) {_XlaSharding = "", partition_dim = -1 : i64} : (tensor>>, tensor>>, tensor>>) -> tensor>> + %pi_1 = "tf.TPUPartitionedInputV2"(%arg2, %arg3, %arg4) {_XlaSharding = "", partition_dims = []} : (tensor>>, tensor>>, tensor>>) -> tensor>> %ri = "tf.TPUReplicatedInput"(%pi_0, %pi_1) : (tensor>>, tensor>>) -> tensor>> func.return %ri : tensor>> } @@ -75,8 +132,18 @@ func.func @num_partitioned_inputs_mismatch(%arg0: tensor>>, %arg1: tensor>>, %arg2: tensor>>, %arg3: tensor>>) -> tensor>> { - %pi_0 = "tf.TPUPartitionedInput"(%arg0, %arg1) {_XlaSharding = "", partition_dim = -1 : i64} : (tensor>>, tensor>>) -> tensor>> - // expected-error@+1 {{'tf.TPUReplicatedInput' op expects all inputs from 'tf.TPUPartitionedInput' ops}} + %pi_0 = "tf.TPUPartitionedInputV2"(%arg0, %arg1) {_XlaSharding = "", partition_dims = []} : (tensor>>, tensor>>) -> tensor>> + // expected-error@+1 {{'tf.TPUReplicatedInput' op expects all inputs from 'tf.TPUPartitionedInputV2' ops}} %ri = "tf.TPUReplicatedInput"(%pi_0, %arg2) {index = 1} : (tensor>>, tensor>>) -> tensor>> func.return %ri : tensor>> } + +// ----- + +func.func @num_partitioned_inputs_mismatch_num_cores_per_replica(%arg0: tensor>>, %arg1: tensor>>, %arg2: 
tensor>>) -> tensor>> { + "tf.TPUReplicateMetadata"() {num_cores_per_replica = 2 : i64, num_replicas = 1 : i64} : () -> () + // expected-error@+1 {{expects 2 operands but found 3}} + %pi = "tf.TPUPartitionedInputV2"(%arg0, %arg1, %arg2) {_XlaSharding = "", partition_dims = []} : (tensor>>, tensor>>, tensor>>) -> tensor>> + %ri = "tf.TPUReplicatedInput"(%pi) : (tensor>>) -> tensor>> + func.return %ri : tensor>> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_resource_partitioning.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_resource_partitioning.mlir index 9f390374aaa..aeab33e74d4 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_resource_partitioning.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_resource_partitioning.mlir @@ -7,35 +7,56 @@ func.func private @computation(%arg0: tensor) -> tensor func.func @read_write_resource(%arg0: tensor>>, %arg1: tensor>>) { // CHECK-DAG: [[READ0:%.+]] = "tf.ReadVariableOp"([[ARG0]]) // CHECK-DAG: [[READ1:%.+]] = "tf.ReadVariableOp"([[ARG1]]) - // CHECK: [[INPUT:%.+]] = "tf.TPUPartitionedInput"([[READ0]], [[READ1]]) + // CHECK: [[INPUT:%.+]] = "tf.TPUPartitionedInputV2"([[READ0]], [[READ1]]) // CHECK-SAME: _XlaSharding = "" - // CHECK-SAME: partition_dim = -1 - %0 = "tf.TPUPartitionedInput"(%arg0, %arg1) {N = 2 : i64, _XlaSharding = "", partition_dim = -1 : i64} : (tensor>>, tensor>>) -> tensor>> + // CHECK-SAME: partition_dims = [] + %0 = "tf.TPUPartitionedInputV2"(%arg0, %arg1) {N = 2 : i64, _XlaSharding = "", partition_dims = []} : (tensor>>, tensor>>) -> tensor>> %1 = "tf.ReadVariableOp"(%0) : (tensor>>) -> tensor // CHECK: [[COMPUTATION:%.+]] = "tf_device.cluster_func"([[INPUT]]) %2 = "tf_device.cluster_func"(%1) {func = @computation, use_spmd_for_xla_partitioning = true} : (tensor) -> tensor - // CHECK: [[OUTPUT:%.+]]:2 = "tf.TPUPartitionedOutput"([[COMPUTATION]]) + // CHECK: [[OUTPUT:%.+]]:2 = "tf.TPUPartitionedOutputV2"([[COMPUTATION]]) // CHECK-SAME: _XlaSharding = "" - // CHECK-SAME: partition_dim = -1 + // CHECK-SAME: partition_dims = [] // CHECK-DAG: "tf.AssignVariableOp"([[ARG0]], [[OUTPUT]]#0) // CHECK-DAG: "tf.AssignVariableOp"([[ARG1]], [[OUTPUT]]#1) "tf.AssignVariableOp"(%0, %2) : (tensor>>, tensor) -> () func.return } +// CHECK-LABEL: func @read_write_packed_resource +// CHECK-SAME: ([[ARG0:%.+]]: tensor>>) +func.func @read_write_packed_resource(%arg0: tensor>>) { + // CHECK-DAG: [[READ0:%.+]] = "tf.ReadVariableOp"([[ARG0]]) + // CHECK: [[INPUT:%.+]] = "tf.TPUPartitionedInputV2"([[READ0]]) + // CHECK-SAME: _XlaSharding = "" + // CHECK-SAME: is_packed = true + // CHECK-SAME: partition_dims = [] + %0 = "tf.TPUPartitionedInputV2"(%arg0) {_XlaSharding = "", partition_dims = [], is_packed = true} : (tensor>>) -> tensor>> + %1 = "tf.ReadVariableOp"(%0) : (tensor>>) -> tensor + // CHECK: [[COMPUTATION:%.+]] = "tf_device.cluster_func"([[INPUT]]) + %2 = "tf_device.cluster_func"(%1) {func = @computation, use_spmd_for_xla_partitioning = true, num_cores_per_replica = 2 : i64} : (tensor) -> tensor + // CHECK: [[OUTPUT:%.+]]:2 = "tf.TPUPartitionedOutputV2"([[COMPUTATION]]) + // CHECK-SAME: _XlaSharding = "" + // CHECK-SAME: partition_dims = [] + // CHECK-DAG: "tf.AssignVariableOp"([[ARG0]], [[OUTPUT]]#0) + // CHECK-DAG: "tf.AssignVariableOp"([[ARG0]], [[OUTPUT]]#1) + "tf.AssignVariableOp"(%0, %2) : (tensor>>, tensor) -> () + func.return +} + // CHECK-LABEL: func @read_only_resource // CHECK-SAME: ([[ARG0:%.+]]: tensor>>, [[ARG1:%.+]]: tensor>>) func.func @read_only_resource(%arg0: tensor>>, %arg1: tensor>>) 
-> tensor { // CHECK-DAG: [[READ0:%.+]] = "tf.ReadVariableOp"([[ARG0]]) // CHECK-DAG: [[READ1:%.+]] = "tf.ReadVariableOp"([[ARG1]]) - // CHECK: [[INPUT:%.+]] = "tf.TPUPartitionedInput"([[READ0]], [[READ1]]) + // CHECK: [[INPUT:%.+]] = "tf.TPUPartitionedInputV2"([[READ0]], [[READ1]]) // CHECK-SAME: _XlaSharding = "" - // CHECK-SAME: partition_dim = -1 - %0 = "tf.TPUPartitionedInput"(%arg0, %arg1) {N = 2 : i64, _XlaSharding = "", partition_dim = -1 : i64} : (tensor>>, tensor>>) -> tensor>> + // CHECK-SAME: partition_dims = [] + %0 = "tf.TPUPartitionedInputV2"(%arg0, %arg1) {N = 2 : i64, _XlaSharding = "", partition_dims = []} : (tensor>>, tensor>>) -> tensor>> %1 = "tf.ReadVariableOp"(%0) : (tensor>>) -> tensor // CHECK: "tf_device.cluster_func"([[INPUT]]) %2 = "tf_device.cluster_func"(%1) {func = @computation, use_spmd_for_xla_partitioning = true} : (tensor) -> tensor - // CHECK-NOT: tf.TPUPartitionedOutput + // CHECK-NOT: tf.TPUPartitionedOutputV2 // CHECK-NOT: tf.AssignVariableOp func.return %2 : tensor } @@ -47,11 +68,11 @@ func.func private @computation_two_args(%arg0: tensor, %arg1: tensor) func.func @partitioned_variable_multiple_users(%arg0: tensor>>, %arg1: tensor>>) { // CHECK-DAG: [[READ0:%.+]] = "tf.ReadVariableOp"([[ARG0]]) // CHECK-DAG: [[READ1:%.+]] = "tf.ReadVariableOp"([[ARG1]]) - // CHECK: [[INPUT0:%.+]] = "tf.TPUPartitionedInput"([[READ0]], [[READ1]]) + // CHECK: [[INPUT0:%.+]] = "tf.TPUPartitionedInputV2"([[READ0]], [[READ1]]) // CHECK-DAG: [[READ2:%.+]] = "tf.ReadVariableOp"([[ARG0]]) // CHECK-DAG: [[READ3:%.+]] = "tf.ReadVariableOp"([[ARG1]]) - // CHECK: [[INPUT1:%.+]] = "tf.TPUPartitionedInput"([[READ2]], [[READ3]]) - %0 = "tf.TPUPartitionedInput"(%arg0, %arg1) {N = 2 : i64, _XlaSharding = "", partition_dim = -1 : i64} : (tensor>>, tensor>>) -> tensor>> + // CHECK: [[INPUT1:%.+]] = "tf.TPUPartitionedInputV2"([[READ2]], [[READ3]]) + %0 = "tf.TPUPartitionedInputV2"(%arg0, %arg1) {N = 2 : i64, _XlaSharding = "", partition_dims = []} : (tensor>>, tensor>>) -> tensor>> %1 = "tf.ReadVariableOp"(%0) : (tensor>>) -> tensor %2 = "tf.ReadVariableOp"(%0) : (tensor>>) -> tensor // CHECK: "tf_device.cluster_func"([[INPUT0]], [[INPUT1]]) @@ -64,12 +85,12 @@ func.func @partitioned_variable_multiple_users(%arg0: tensor>>, [[ARG1:%.+]]: tensor>>) func.func @no_spmd(%arg0: tensor>>, %arg1: tensor>>) { - // CHECK: "tf.TPUPartitionedInput"([[ARG0]], [[ARG1]]) - %0 = "tf.TPUPartitionedInput"(%arg0, %arg1) {N = 2 : i64, _XlaSharding = "", partition_dim = -1 : i64} : (tensor>>, tensor>>) -> tensor>> + // CHECK: "tf.TPUPartitionedInputV2"([[ARG0]], [[ARG1]]) + %0 = "tf.TPUPartitionedInputV2"(%arg0, %arg1) {N = 2 : i64, _XlaSharding = "", partition_dims = []} : (tensor>>, tensor>>) -> tensor>> %1 = "tf.ReadVariableOp"(%0) : (tensor>>) -> tensor %2 = "tf_device.cluster_func"(%1) {func = @computation} : (tensor) -> tensor - // CHECK: "tf.TPUPartitionedInput"([[ARG0]], [[ARG1]]) - %3 = "tf.TPUPartitionedInput"(%arg0, %arg1) {N = 2 : i64, _XlaSharding = "", partition_dim = -1 : i64} : (tensor>>, tensor>>) -> tensor>> + // CHECK: "tf.TPUPartitionedInputV2"([[ARG0]], [[ARG1]]) + %3 = "tf.TPUPartitionedInputV2"(%arg0, %arg1) {N = 2 : i64, _XlaSharding = "", partition_dims = []} : (tensor>>, tensor>>) -> tensor>> %4 = "tf.ReadVariableOp"(%3) : (tensor>>) -> tensor %5 = "tf_device.cluster_func"(%4) {func = @computation, use_spmd_for_xla_partitioning = false} : (tensor) -> tensor func.return @@ -77,20 +98,20 @@ func.func @no_spmd(%arg0: tensor>>, %arg1: tensor< // CHECK-LABEL: func 
@read_write_unpartitioned_resource func.func @read_write_unpartitioned_resource(%arg0: tensor>>) { - // CHECK-NOT: tf.TPUPartitionedInput + // CHECK-NOT: tf.TPUPartitionedInputV2 %0 = "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor %1 = "tf_device.cluster_func"(%0) {func = @computation} : (tensor) -> tensor - // CHECK-NOT: tf.TPUPartitionedOutput + // CHECK-NOT: tf.TPUPartitionedOutputV2 "tf.AssignVariableOp"(%arg0, %1) : (tensor>>, tensor) -> () func.return } // CHECK-LABEL: func @read_only_unpartitioned_resource func.func @read_only_unpartitioned_resource(%arg0: tensor>>) { - // CHECK-NOT: tf.TPUPartitionedInput + // CHECK-NOT: tf.TPUPartitionedInputV2 %0 = "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor %1 = "tf_device.cluster_func"(%0) {func = @computation} : (tensor) -> tensor - // CHECK-NOT: tf.TPUPartitionedOutput + // CHECK-NOT: tf.TPUPartitionedOutputV2 // CHECK-NOT: tf.AssignVariableOp func.return } @@ -98,8 +119,8 @@ func.func @read_only_unpartitioned_resource(%arg0: tensor>>, [[ARG1:%.+]]: tensor>>) -> tensor func.func @resource_read_multiple_users(%arg0: tensor>>, %arg1: tensor>>) -> tensor { - // CHECK: "tf.TPUPartitionedInput"([[ARG0]], [[ARG1]]) - %0 = "tf.TPUPartitionedInput"(%arg0, %arg1) {N = 2 : i64, _XlaSharding = "", partition_dim = -1 : i64} : (tensor>>, tensor>>) -> tensor>> + // CHECK: "tf.TPUPartitionedInputV2"([[ARG0]], [[ARG1]]) + %0 = "tf.TPUPartitionedInputV2"(%arg0, %arg1) {N = 2 : i64, _XlaSharding = "", partition_dims = []} : (tensor>>, tensor>>) -> tensor>> %1 = "tf.ReadVariableOp"(%0) : (tensor>>) -> tensor %2 = "tf_device.cluster_func"(%1) {func = @computation} : (tensor) -> tensor func.return %1 : tensor @@ -107,27 +128,49 @@ func.func @resource_read_multiple_users(%arg0: tensor) -> tensor { - // CHECK-NOT: tf.TPUPartitionedInput + // CHECK-NOT: tf.TPUPartitionedInputV2 %0 = "tf_device.cluster_func"(%arg0) {func = @computation} : (tensor) -> tensor - // CHECK-NOT: tf.TPUPartitionedOutput + // CHECK-NOT: tf.TPUPartitionedOutputV2 func.return %0 : tensor } // CHECK-LABEL: func @resource_missing_subtype // CHECK-SAME: ([[ARG0:%.+]]: tensor, [[ARG1:%.+]]: tensor) func.func @resource_missing_subtype(%arg0: tensor, %arg1: tensor) { - // CHECK: "tf.TPUPartitionedInput"([[ARG0]], [[ARG1]]) - %0 = "tf.TPUPartitionedInput"(%arg0, %arg1) {N = 2 : i64, _XlaSharding = "", partition_dim = -1 : i64} : (tensor, tensor) -> tensor + // CHECK: "tf.TPUPartitionedInputV2"([[ARG0]], [[ARG1]]) + %0 = "tf.TPUPartitionedInputV2"(%arg0, %arg1) {N = 2 : i64, _XlaSharding = "", partition_dims = []} : (tensor, tensor) -> tensor %1 = "tf.ReadVariableOp"(%0) : (tensor) -> tensor %2 = "tf_device.cluster_func"(%1) {func = @computation, use_spmd_for_xla_partitioning = true} : (tensor) -> tensor - // CHECK-NOT: tf.TPUPartitionedOutput + // CHECK-NOT: tf.TPUPartitionedOutputV2 "tf.AssignVariableOp"(%0, %2) : (tensor, tensor) -> () func.return } // ----- -// Check outside compiled that uses a TPUPartitionedInput. 
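
(Editorial aside, not part of the patch.) The @read_write_packed_resource test above and the num_cores_per_replica error cases that follow exercise the operand-count rule of the V2 op: a packed input carries a single operand that every core shares, while an unpacked input needs one operand per core, which is why the pass now has to read `num_cores_per_replica` off the cluster. A minimal C++ sketch of that rule as the tests appear to encode it; `ExpectedOperandCount` is an invented name, not an API in this patch.

#include <cstdint>
#include <iostream>

// Packed TPUPartitionedInputV2-style inputs share one operand across cores;
// unpacked inputs expect one operand per core of the replica.
int64_t ExpectedOperandCount(bool is_packed, int64_t num_cores_per_replica) {
  return is_packed ? 1 : num_cores_per_replica;
}

int main() {
  // Mirrors the "expects 2 operands but found 3" diagnostic: 2 cores, unpacked.
  std::cout << ExpectedOperandCount(false, 2) << "\n";  // prints 2
  // Mirrors @read_write_packed_resource: one shared operand regardless of cores.
  std::cout << ExpectedOperandCount(true, 2) << "\n";   // prints 1
  return 0;
}
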
+func.func @missing_num_cores_per_replica(%arg0: tensor>>) { + // expected-error@+1 {{op num cores per replica unavailable}} + %0 = "tf.TPUPartitionedInputV2"(%arg0) {_XlaSharding = "", partition_dims = [], is_packed = true} : (tensor>>) -> tensor>> + %1 = "tf.ReadVariableOp"(%0) : (tensor>>) -> tensor + %2 = "tf_device.cluster_func"(%1) {func = @computation, use_spmd_for_xla_partitioning = true} : (tensor) -> tensor + "tf.AssignVariableOp"(%0, %2) : (tensor>>, tensor) -> () + func.return +} + +// ----- + +func.func @mismatch_num_cores_per_replica(%arg0: tensor>>) { + // expected-error@+1 {{expects 2 operands but found 3}} + %0 = "tf.TPUPartitionedInputV2"(%arg0, %arg0, %arg0) {_XlaSharding = "", partition_dims = []} : (tensor>>, tensor>>, tensor>>) -> tensor>> + %1 = "tf.ReadVariableOp"(%0) : (tensor>>) -> tensor + %2 = "tf_device.cluster_func"(%1) {func = @computation, use_spmd_for_xla_partitioning = true, num_cores_per_replica = 2 : i64} : (tensor) -> tensor + "tf.AssignVariableOp"(%0, %2) : (tensor>>, tensor) -> () + func.return +} + +// ----- + +// Check outside compiled that uses a TPUPartitionedInputV2. func.func private @computation(%arg0: tensor) -> tensor @@ -136,10 +179,10 @@ func.func private @computation(%arg0: tensor) -> tensor func.func @with_host_process(%arg0: tensor>>, %arg1: tensor>>) { // CHECK-DAG: [[READ0:%.+]] = "tf.ReadVariableOp"([[ARG0]]) // CHECK-DAG: [[READ1:%.+]] = "tf.ReadVariableOp"([[ARG1]]) - // CHECK: [[INPUT:%.+]] = "tf.TPUPartitionedInput"([[READ0]], [[READ1]]) + // CHECK: [[INPUT:%.+]] = "tf.TPUPartitionedInputV2"([[READ0]], [[READ1]]) // CHECK-SAME: _XlaSharding = "" - // CHECK-SAME: partition_dim = -1 - %0 = "tf.TPUPartitionedInput"(%arg0, %arg1) {N = 2 : i64, _XlaSharding = "", partition_dim = -1 : i64} : (tensor>>, tensor>>) -> tensor>> + // CHECK-SAME: partition_dims = [] + %0 = "tf.TPUPartitionedInputV2"(%arg0, %arg1) {N = 2 : i64, _XlaSharding = "", partition_dims = []} : (tensor>>, tensor>>) -> tensor>> %1 = "tf.ReadVariableOp"(%0) : (tensor>>) -> tensor // CHECK: [[COMPUTATION:%.+]] = "tf_device.parallel_execute"() // CHECK: "tf.OpA"([[READ0]]) @@ -153,9 +196,9 @@ func.func @with_host_process(%arg0: tensor>>, %arg %3 = "tf_device.cluster_func"(%1) {func = @computation, use_spmd_for_xla_partitioning = true} : (tensor) -> tensor tf_device.return %3 : tensor }) : () -> tensor - // CHECK: [[OUTPUT:%.+]]:2 = "tf.TPUPartitionedOutput"([[COMPUTATION]]) + // CHECK: [[OUTPUT:%.+]]:2 = "tf.TPUPartitionedOutputV2"([[COMPUTATION]]) // CHECK-SAME: _XlaSharding = "" - // CHECK-SAME: partition_dim = -1 + // CHECK-SAME: partition_dims = [] // CHECK-DAG: "tf.AssignVariableOp"([[ARG0]], [[OUTPUT]]#0) // CHECK-DAG: "tf.AssignVariableOp"([[ARG1]], [[OUTPUT]]#1) "tf.AssignVariableOp"(%0, %2) : (tensor>>, tensor) -> () @@ -165,7 +208,7 @@ func.func @with_host_process(%arg0: tensor>>, %arg // ----- // Check for an error that reports the unsupported case of outside compiled -// code that uses a TPUPartitionedInput without REPLICATED sharding. +// code that uses a TPUPartitionedInputV2 without REPLICATED sharding. 
// The TPUParitionedInput has the following OpSharding: // Proto debug string: @@ -182,7 +225,7 @@ func.func private @computation(%arg0: tensor) -> tensor func.func @non_replicated_sharding(%arg0: tensor>>, %arg1: tensor>>) { // expected-error@+1 {{support}} - %0 = "tf.TPUPartitionedInput"(%arg0, %arg1) {N = 2 : i64, _XlaSharding = "\08\03\1A\02\01\02\22\02\00\01", partition_dim = -1 : i64} : (tensor>>, tensor>>) -> tensor>> + %0 = "tf.TPUPartitionedInputV2"(%arg0, %arg1) {N = 2 : i64, _XlaSharding = "\08\03\1A\02\01\02\22\02\00\01", partition_dims = []} : (tensor>>, tensor>>) -> tensor>> %1 = "tf.ReadVariableOp"(%0) : (tensor>>) -> tensor %2 = "tf_device.parallel_execute"() ({ "tf_device.launch"() ({ diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir index 187a7035117..c81a69f791f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir @@ -1185,7 +1185,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- // Test `tf_device.cluster_func` on TPU with pre-split replicate sharded -// input/output using `tf.TPUPartitionedInput` and `tf.TPUPartitionedOutput`. +// input/output using `tf.TPUPartitionedInputV2` and `tf.TPUPartitionedOutputV2`. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} { func.func @cluster(%arg0: tensor>>, %arg1: tensor>>) { @@ -1193,8 +1193,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %read0 = "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor // CHECK: %[[READ_VAR_1:[0-9]*]] = "tf.ReadVariableOp"(%arg1) %read1 = "tf.ReadVariableOp"(%arg1) : (tensor>>) -> tensor - // CHECK-NOT: tf.TPUPartitionedInput - %partitioned_input = "tf.TPUPartitionedInput"(%read0, %read1) {N = 2 : i64, partition_dim = -1 : i64} : (tensor, tensor) -> tensor + // CHECK-NOT: tf.TPUPartitionedInputV2 + %partitioned_input = "tf.TPUPartitionedInputV2"(%read0, %read1) {N = 2 : i64, partition_dims = []} : (tensor, tensor) -> tensor // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:3 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"() // CHECK: "tf_device.launch" @@ -1207,8 +1207,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: "tf.TPUExecute"(%[[READ_VAR_1]], %[[COMPILE_OUTPUT]]#2) // CHECK: device = "/job:worker/replica:0/task:0/device:TPU:1" %computation = "tf_device.cluster_func"(%partitioned_input) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @computation, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1], input_sharding_configuration = [""], output_sharding_configuration = [""], use_spmd_for_xla_partitioning = true} : (tensor) -> tensor - // CHECK-NOT: tf.TPUPartitionedOutput - %partitioned_output:2 = "tf.TPUPartitionedOutput"(%computation) {N = 2 : i64, partition_dim = -1 : i64} : (tensor) -> (tensor, tensor) + // CHECK-NOT: tf.TPUPartitionedOutputV2 + %partitioned_output:2 = "tf.TPUPartitionedOutputV2"(%computation) {N = 2 : i64, partition_dims = []} : (tensor) -> (tensor, tensor) // CHECK: "tf.AssignVariableOp"(%arg0, 
%[[PARALLEL_EXECUTE_OUTPUT]]#0) // CHECK: "tf.AssignVariableOp"(%arg1, %[[PARALLEL_EXECUTE_OUTPUT]]#1) "tf.AssignVariableOp"(%arg0, %partitioned_output#0) : (tensor>>, tensor) -> () @@ -1223,7 +1223,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- // Test `tf_device.cluster_func` on TPU with pre-split tile sharded input/ -// output using `tf.TPUPartitionedInput` and `tf.TPUPartitionedOutput`. +// output using `tf.TPUPartitionedInputV2` and `tf.TPUPartitionedOutputV2`. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} { func.func @cluster(%arg0: tensor>>, %arg1: tensor>>) { @@ -1231,8 +1231,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor %read0 = "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor<3x2xf32> // CHECK: %[[READ_VAR_1:[0-9]*]] = "tf.ReadVariableOp"(%arg1) %read1 = "tf.ReadVariableOp"(%arg1) : (tensor>>) -> tensor<3x2xf32> - // CHECK-NOT: tf.TPUPartitionedInput - %partitioned_input = "tf.TPUPartitionedInput"(%read0, %read1) {_XlaSharding = "\08\03\1A\02\01\02\22\02\00\01", partition_dim = 1 : i64} : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x4xf32> + // CHECK-NOT: tf.TPUPartitionedInputV2 + %partitioned_input = "tf.TPUPartitionedInputV2"(%read0, %read1) {_XlaSharding = "\08\03\1A\02\01\02\22\02\00\01", partition_dims = [1, 2]} : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x4xf32> // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:3 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"() // CHECK: "tf_device.launch" @@ -1245,8 +1245,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NEXT: "tf.TPUExecute"(%[[READ_VAR_1]], %[[COMPILE_OUTPUT]]#2) // CHECK: device = "/job:worker/replica:0/task:0/device:TPU:1" %computation = "tf_device.cluster_func"(%partitioned_input) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @computation, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1], input_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01"], output_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01"], use_spmd_for_xla_partitioning = true} : (tensor<3x4xf32>) -> tensor<3x4xf32> - // CHECK-NOT: tf.TPUPartitionedOutput - %partitioned_output:2 = "tf.TPUPartitionedOutput"(%computation) {_XlaSharding = "\08\03\1A\02\01\02\22\02\00\01", partition_dim = 1 : i64} : (tensor<3x4xf32>) -> (tensor<3x2xf32>, tensor<3x2xf32>) + // CHECK-NOT: tf.TPUPartitionedOutputV2 + %partitioned_output:2 = "tf.TPUPartitionedOutputV2"(%computation) {_XlaSharding = "\08\03\1A\02\01\02\22\02\00\01", partition_dims = [1, 2]} : (tensor<3x4xf32>) -> (tensor<3x2xf32>, tensor<3x2xf32>) // CHECK: "tf.AssignVariableOp"(%arg0, %[[PARALLEL_EXECUTE_OUTPUT]]#0) // CHECK: "tf.AssignVariableOp"(%arg1, %[[PARALLEL_EXECUTE_OUTPUT]]#1) "tf.AssignVariableOp"(%arg0, %partitioned_output#0) : (tensor>>, tensor<3x2xf32>) -> () @@ -1260,18 +1260,18 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Test that unsupported input sharding type of TPUPartitionedInputOp inputs of +// Test that unsupported input sharding type of TPUPartitionedInputV2Op inputs of // ClusterFuncOp result 
in error. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} { func.func @cluster(%arg0: tensor>>, %arg1: tensor>>) { %read0 = "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor %read1 = "tf.ReadVariableOp"(%arg1) : (tensor>>) -> tensor - %partitioned_input = "tf.TPUPartitionedInput"(%read0, %read1) {N = 2 : i64, partition_dim = -1 : i64} : (tensor, tensor) -> tensor + %partitioned_input = "tf.TPUPartitionedInputV2"(%read0, %read1) {N = 2 : i64, partition_dims = []} : (tensor, tensor) -> tensor // expected-error@+1 {{unsupported input sharding type MAXIMAL for 0-th input}} %computation = "tf_device.cluster_func"(%partitioned_input) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @computation, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = [""], use_spmd_for_xla_partitioning = true} : (tensor) -> tensor - %partitioned_output:2 = "tf.TPUPartitionedOutput"(%computation) {N = 2 : i64, partition_dim = -1 : i64} : (tensor) -> (tensor, tensor) + %partitioned_output:2 = "tf.TPUPartitionedOutputV2"(%computation) {N = 2 : i64, partition_dims = []} : (tensor) -> (tensor, tensor) "tf.AssignVariableOp"(%arg0, %partitioned_output#0) : (tensor>>, tensor) -> () "tf.AssignVariableOp"(%arg1, %partitioned_output#1) : (tensor>>, tensor) -> () func.return @@ -1283,18 +1283,18 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Test that unsupported output sharding type of TPUPartitionedOutputOp outputs +// Test that unsupported output sharding type of TPUPartitionedOutputV2Op outputs // of ClusterFuncOp result in error. 
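
(Editorial aside, not part of the patch.) The tpu_rewrite.mlir hunks above show the attribute migration concretely: the replicated case moves from the V1 scalar `partition_dim = -1` to the empty list `partition_dims = []`, and the tile-sharded tensor<3x2> pair split along dimension 1 moves from `partition_dim = 1` to `partition_dims = [1, 2]`, i.e. a per-dimension partition count. A small C++ sketch of that mapping for the single split dimension that V1 supported; `ToPartitionDims` is an invented helper, not code from this patch.

#include <cstdint>
#include <iostream>
#include <vector>

// Derive V2 partition_dims from a V1 scalar partition_dim for a tensor of the
// given rank split into num_shards pieces along that one dimension.
std::vector<int64_t> ToPartitionDims(int64_t partition_dim, int64_t rank,
                                     int64_t num_shards) {
  if (partition_dim < 0) return {};    // V1 "-1" (replicated) becomes "[]"
  std::vector<int64_t> dims(rank, 1);  // every other dimension stays whole
  dims[partition_dim] = num_shards;    // e.g. dim 1, 2 shards -> [1, 2]
  return dims;
}

int main() {
  for (int64_t d : ToPartitionDims(1, 2, 2)) std::cout << d << " ";  // 1 2
  std::cout << "\n" << ToPartitionDims(-1, 2, 2).size() << "\n";     // 0 (empty)
  return 0;
}
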
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} { func.func @cluster(%arg0: tensor>>, %arg1: tensor>>) { %read0 = "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor %read1 = "tf.ReadVariableOp"(%arg1) : (tensor>>) -> tensor - %partitioned_input = "tf.TPUPartitionedInput"(%read0, %read1) {N = 2 : i64, partition_dim = -1 : i64} : (tensor, tensor) -> tensor + %partitioned_input = "tf.TPUPartitionedInputV2"(%read0, %read1) {N = 2 : i64, partition_dims = []} : (tensor, tensor) -> tensor // expected-error@+1 {{unsupported output sharding type MAXIMAL for 0-th output}} %computation = "tf_device.cluster_func"(%partitioned_input) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @computation, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1], input_sharding_configuration = [""], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = true} : (tensor) -> tensor - %partitioned_output:2 = "tf.TPUPartitionedOutput"(%computation) {N = 2 : i64, partition_dim = -1 : i64} : (tensor) -> (tensor, tensor) + %partitioned_output:2 = "tf.TPUPartitionedOutputV2"(%computation) {N = 2 : i64, partition_dims = []} : (tensor) -> (tensor, tensor) "tf.AssignVariableOp"(%arg0, %partitioned_output#0) : (tensor>>, tensor) -> () "tf.AssignVariableOp"(%arg1, %partitioned_output#1) : (tensor>>, tensor) -> () func.return @@ -1307,16 +1307,16 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- // Test that multiple uses of ClusterFuncOp output along with -// TPUPartitionedOutputOp results in error. +// TPUPartitionedOutputV2Op results in error. 
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} { func.func @cluster(%arg0: tensor>>, %arg1: tensor>>) { %read0 = "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor %read1 = "tf.ReadVariableOp"(%arg1) : (tensor>>) -> tensor - %partitioned_input = "tf.TPUPartitionedInput"(%read0, %read1) {N = 2 : i64, partition_dim = -1 : i64} : (tensor, tensor) -> tensor + %partitioned_input = "tf.TPUPartitionedInputV2"(%read0, %read1) {N = 2 : i64, partition_dims = []} : (tensor, tensor) -> tensor %computation = "tf_device.cluster_func"(%partitioned_input) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @computation, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1], input_sharding_configuration = [""], output_sharding_configuration = [""], use_spmd_for_xla_partitioning = true} : (tensor) -> tensor - // expected-error@+1 {{'tf.TPUPartitionedOutput' op must be a unique user of TPU Cluster}} - %partitioned_output:2 = "tf.TPUPartitionedOutput"(%computation) {N = 2 : i64, partition_dim = -1 : i64} : (tensor) -> (tensor, tensor) + // expected-error@+1 {{'tf.TPUPartitionedOutputV2' op must be a unique user of TPU Cluster}} + %partitioned_output:2 = "tf.TPUPartitionedOutputV2"(%computation) {N = 2 : i64, partition_dims = []} : (tensor) -> (tensor, tensor) "tf.AssignVariableOp"(%arg0, %partitioned_output#0) : (tensor>>, tensor) -> () "tf.AssignVariableOp"(%arg1, %partitioned_output#1) : (tensor>>, tensor) -> () "tf._SomeOp"(%computation) : (tensor) -> () @@ -2532,17 +2532,17 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Test `tf.TPUPartitionedInput` has outputs not in `tf_device.cluster_func` +// Test `tf.TPUPartitionedInputV2` has outputs not in `tf_device.cluster_func` module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} { func.func @cluster(%arg0: tensor>>, %arg1: tensor>>) { - // expected-error@+1 {{Output of TPUPartitionedInput must be in tpu computation.}} - %partitioned_input = "tf.TPUPartitionedInput"(%arg0, %arg1) {N = 2 : i64, partition_dim = -1 : i64} : (tensor>>, tensor>>) -> tensor>> + // expected-error@+1 {{Output of TPUPartitionedInputV2 must be in tpu computation.}} + %partitioned_input = "tf.TPUPartitionedInputV2"(%arg0, %arg1) {N = 2 : i64, partition_dims = []} : (tensor>>, tensor>>) -> tensor>> %read = "tf.ReadVariableOp"(%partitioned_input) : (tensor>>) -> tensor %computation = "tf_device.cluster_func"(%read) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @computation, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1], input_sharding_configuration = [""], output_sharding_configuration = [""], use_spmd_for_xla_partitioning = true} : (tensor) -> tensor - %partitioned_output:2 = "tf.TPUPartitionedOutput"(%computation) {N = 2 : i64, partition_dim = -1 : i64} : (tensor) -> 
(tensor, tensor) + %partitioned_output:2 = "tf.TPUPartitionedOutputV2"(%computation) {N = 2 : i64, partition_dims = []} : (tensor) -> (tensor, tensor) "tf.AssignVariableOp"(%arg0, %partitioned_output#0) : (tensor>>, tensor) -> () "tf.AssignVariableOp"(%arg1, %partitioned_output#1) : (tensor>>, tensor) -> () func.return @@ -2554,17 +2554,17 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Test `tf.TPUPartitionedOutput` has inputs not in `tf_device.cluster_func` +// Test `tf.TPUPartitionedOutputV2` has inputs not in `tf_device.cluster_func` module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} { func.func @cluster(%arg0: tensor>>, %arg1: tensor>>) { %read0 = "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor %read1 = "tf.ReadVariableOp"(%arg1) : (tensor>>) -> tensor - %partitioned_input = "tf.TPUPartitionedInput"(%read0, %read1) {N = 2 : i64, partition_dim = -1 : i64} : (tensor, tensor) -> tensor + %partitioned_input = "tf.TPUPartitionedInputV2"(%read0, %read1) {N = 2 : i64, partition_dims = []} : (tensor, tensor) -> tensor %computation = "tf_device.cluster_func"(%partitioned_input) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @computation, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\01\01\02\10\01\18\02\22\08\00\00\00\00\00\00\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1], input_sharding_configuration = [""], output_sharding_configuration = [""], use_spmd_for_xla_partitioning = true} : (tensor) -> tensor %add_result = "tf.Add"(%computation, %computation) : (tensor, tensor) -> tensor - // expected-error@+1 {{Input of TPUPartitionedOutput must be in tpu computation.}} - %partitioned_output:2 = "tf.TPUPartitionedOutput"(%add_result) {N = 2 : i64, partition_dim = -1 : i64} : (tensor) -> (tensor, tensor) + // expected-error@+1 {{Input of TPUPartitionedOutputV2 must be in tpu computation.}} + %partitioned_output:2 = "tf.TPUPartitionedOutputV2"(%add_result) {N = 2 : i64, partition_dims = []} : (tensor) -> (tensor, tensor) "tf.AssignVariableOp"(%arg0, %partitioned_output#0) : (tensor>>, tensor) -> () "tf.AssignVariableOp"(%arg1, %partitioned_output#1) : (tensor>>, tensor) -> () func.return diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir index b673dfd1bc1..921248cf473 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir @@ -274,12 +274,12 @@ func.func @func_body(%arg0: tensor<*xi32>)-> tensor<*xi32> { // CHECK-LABEL: func @partitioned_input_output func.func @partitioned_input_output(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>) { - %0 = "tf.TPUPartitionedInput"(%arg0) {_XlaSharding = "\01\02\03", partition_dim = -1 : i64} : (tensor<*xi32>) -> tensor<*xi32> + %0 = "tf.TPUPartitionedInputV2"(%arg0) {_XlaSharding = "\01\02\03", partition_dims = []} : (tensor<*xi32>) -> tensor<*xi32> // CHECK: tf_device.cluster_func // CHECK-SAME: input_sharding_configuration = ["\01\02\03", ""] // CHECK-SAME: output_sharding_configuration = ["", "\04\05\06"] %1:2 = "tf_device.cluster_func"(%0, %arg1) {func = 
@cluster_func, use_spmd_for_xla_partitioning = true, num_cores_per_replica = 1 : i64} : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>) - %2 = "tf.TPUPartitionedOutput"(%1#1) {_XlaSharding = "\04\05\06", partition_dim = -1 : i64} : (tensor<*xi32>) -> tensor<*xi32> + %2 = "tf.TPUPartitionedOutputV2"(%1#1) {_XlaSharding = "\04\05\06", partition_dims = []} : (tensor<*xi32>) -> tensor<*xi32> func.return %1#0, %2 : tensor<*xi32>, tensor<*xi32> } @@ -296,7 +296,7 @@ func.func @cluster_func(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> (tensor<* // CHECK-LABEL: func @partitioned_variable func.func @partitioned_variable(%arg0: tensor>>) { - %0 = "tf.TPUPartitionedInput"(%arg0) {_XlaSharding = "\01\02\03", partition_dim = -1 : i64} : (tensor>>) -> tensor>> + %0 = "tf.TPUPartitionedInputV2"(%arg0) {_XlaSharding = "\01\02\03", partition_dims = []} : (tensor>>) -> tensor>> %1 = "tf.ReadVariableOp"(%0) : (tensor>>) -> tensor<*xf32> // CHECK: tf_device.cluster_func // CHECK-SAME: input_sharding_configuration = ["\01\02\03"] @@ -369,12 +369,12 @@ func.func @cluster_func(%arg0: tensor<*xf32>) { // CHECK-LABEL: func @partitioned_input_output func.func @partitioned_input_output(%arg0: tensor<*xi32>) -> tensor<*xi32> { - %0 = "tf.TPUPartitionedInput"(%arg0) {partition_dim = -1 : i64} : (tensor<*xi32>) -> tensor<*xi32> + %0 = "tf.TPUPartitionedInputV2"(%arg0) {partition_dims = []} : (tensor<*xi32>) -> tensor<*xi32> // CHECK: tf_device.cluster_func // CHECK-SAME: input_sharding_configuration = [""] // CHECK-SAME: output_sharding_configuration = [""] %1 = "tf_device.cluster_func"(%0) {func = @cluster_func, use_spmd_for_xla_partitioning = true, num_cores_per_replica = 1 : i64} : (tensor<*xi32>) -> tensor<*xi32> - %2 = "tf.TPUPartitionedOutput"(%1) {partition_dim = -1 : i64} : (tensor<*xi32>) -> tensor<*xi32> + %2 = "tf.TPUPartitionedOutputV2"(%1) {partition_dims = []} : (tensor<*xi32>) -> tensor<*xi32> func.return %2 : tensor<*xi32> } @@ -392,7 +392,7 @@ func.func @cluster_func(%arg0: tensor<*xi32>) -> tensor<*xi32> { // CHECK-LABEL: func @partitioned_input_output func.func @partitioned_input_output(%arg0: tensor>>) { - %0 = "tf.TPUPartitionedInput"(%arg0) {_XlaSharding = "\01\02\03", partition_dim = -1 : i64} : (tensor>>) -> tensor>> + %0 = "tf.TPUPartitionedInputV2"(%arg0) {_XlaSharding = "\01\02\03", partition_dims = []} : (tensor>>) -> tensor>> // CHECK: tf_device.cluster_func // CHECK-SAME: input_sharding_configuration = [] // CHECK-SAME: output_sharding_configuration = ["\01\02\03"] @@ -414,13 +414,13 @@ func.func @cluster_func() -> tensor { // CHECK-LABEL: func @partitioned_input_maximal_sharding_revert_mpmd func.func @partitioned_input_maximal_sharding_revert_mpmd(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>) { - %0 = "tf.TPUPartitionedInput"(%arg0) {_XlaSharding = "\08\01\1A\01\01\22\01\00", partition_dim = -1 : i64} : (tensor<*xi32>) -> tensor<*xi32> + %0 = "tf.TPUPartitionedInputV2"(%arg0) {_XlaSharding = "\08\01\1A\01\01\22\01\00", partition_dims = []} : (tensor<*xi32>) -> tensor<*xi32> // CHECK: tf_device.cluster_func // CHECK-SAME: input_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"] // CHECK-SAME: output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\04\05\06"] // CHECK-SAME: use_spmd_for_xla_partitioning = false %1:2 = "tf_device.cluster_func"(%0, %arg1) {func = @cluster_func, use_spmd_for_xla_partitioning = true, num_cores_per_replica = 1 : i64} : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>, 
tensor<*xi32>) - %2 = "tf.TPUPartitionedOutput"(%1#1) {_XlaSharding = "\04\05\06", partition_dim = -1 : i64} : (tensor<*xi32>) -> tensor<*xi32> + %2 = "tf.TPUPartitionedOutputV2"(%1#1) {_XlaSharding = "\04\05\06", partition_dims = []} : (tensor<*xi32>) -> tensor<*xi32> func.return %1#0, %2 : tensor<*xi32>, tensor<*xi32> } @@ -441,9 +441,9 @@ func.func @partitioned_output_maximal_sharding_revert_mpmd(%arg0: tensor<*xi32>, // CHECK-SAME: input_sharding_configuration = ["\04\05\06", "\08\01\1A\01\01\22\01\00"] // CHECK-SAME: output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", "\08\01\1A\01\01\22\01\00"] // CHECK-SAME: use_spmd_for_xla_partitioning = false - %0 = "tf.TPUPartitionedInput"(%arg0) {_XlaSharding = "\04\05\06", partition_dim = -1 : i64} : (tensor<*xi32>) -> tensor<*xi32> + %0 = "tf.TPUPartitionedInputV2"(%arg0) {_XlaSharding = "\04\05\06", partition_dims = []} : (tensor<*xi32>) -> tensor<*xi32> %1:2 = "tf_device.cluster_func"(%0, %arg1) {func = @cluster_func, use_spmd_for_xla_partitioning = true, num_cores_per_replica = 1 : i64} : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>) - %2 = "tf.TPUPartitionedOutput"(%1#1) {_XlaSharding = "\08\01\1A\01\01\22\01\00", partition_dim = -1 : i64} : (tensor<*xi32>) -> tensor<*xi32> + %2 = "tf.TPUPartitionedOutputV2"(%1#1) {_XlaSharding = "\08\01\1A\01\01\22\01\00", partition_dims = []} : (tensor<*xi32>) -> tensor<*xi32> func.return %1#0, %2 : tensor<*xi32>, tensor<*xi32> } @@ -615,8 +615,8 @@ func.func @func(%arg0: tensor<4xf32>) -> tensor<4xf32> { // CHECK-SAME: input_sharding_configuration = ["", ""] // CHECK-SAME: output_sharding_configuration = ["\08\03\1A\02\02\01\22\02\00\01"] func.func @check_propagation_for_output_sharding_from_tf_matmul(%arg0: tensor<2x4xf32>, %arg1: tensor<4x2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) { - %0 = "tf_device.cluster_func"(%arg0, %arg1) {func = @_func, use_spmd_for_xla_partitioning = true, use_tpu = true, num_cores_per_replica = 1 : i64} : (tensor<2x4xf32>, tensor<4x2xf32>) -> tensor<2x2xf32> - %1:2 = "tf.TPUPartitionedOutput"(%0) {device = "", partition_dim = 0 : i64} : (tensor<2x2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) + %0 = "tf_device.cluster_func"(%arg0, %arg1) {func = @_func, use_spmd_for_xla_partitioning = true, use_tpu = true, num_cores_per_replica = 2 : i64} : (tensor<2x4xf32>, tensor<4x2xf32>) -> tensor<2x2xf32> + %1:2 = "tf.TPUPartitionedOutputV2"(%0) {device = "", partition_dims = [2, 1]} : (tensor<2x2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) return %1#0, %1#1 : tensor<1x2xf32>, tensor<1x2xf32> } func.func @_func(%arg0: tensor<2x4xf32>, %arg1: tensor<4x2xf32>) -> tensor<2x2xf32> { @@ -630,8 +630,8 @@ func.func @_func(%arg0: tensor<2x4xf32>, %arg1: tensor<4x2xf32>) -> tensor<2x2xf // CHECK-SAME: input_sharding_configuration = ["", ""] // CHECK-SAME: output_sharding_configuration = ["\08\03\1A\02\02\01\22\02\00\01"] func.func @check_propagation_for_output_sharding_from_tf_matmul_following_by_identity_op(%arg0: tensor<2x4xf32>, %arg1: tensor<4x2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) { - %0 = "tf_device.cluster_func"(%arg0, %arg1) {func = @_func, use_spmd_for_xla_partitioning = true, use_tpu = true, num_cores_per_replica = 1 : i64} : (tensor<2x4xf32>, tensor<4x2xf32>) -> tensor<2x2xf32> - %1:2 = "tf.TPUPartitionedOutput"(%0) {device = "", partition_dim = 0 : i64} : (tensor<2x2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) + %0 = "tf_device.cluster_func"(%arg0, %arg1) {func = @_func, use_spmd_for_xla_partitioning = true, use_tpu = true, num_cores_per_replica = 
2 : i64} : (tensor<2x4xf32>, tensor<4x2xf32>) -> tensor<2x2xf32> + %1:2 = "tf.TPUPartitionedOutputV2"(%0) {device = "", partition_dims = [2, 1]} : (tensor<2x2xf32>) -> (tensor<1x2xf32>, tensor<1x2xf32>) return %1#0, %1#1 : tensor<1x2xf32>, tensor<1x2xf32> } func.func @_func(%arg0: tensor<2x4xf32>, %arg1: tensor<4x2xf32>) -> tensor<2x2xf32> { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir b/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir index 5cbcc21080e..ec2d36d1267 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir @@ -448,6 +448,17 @@ func.func @batchMatMulV2MatrixAdjXY(%arg0: tensor<5x4xf32>, %arg1: tensor<6x5xf3 // CHECK: return %[[MATMUL_1]] : tensor<4x6xf32> } +// ----- + +func.func @batchMatMulV2DynamicSize(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor, tensor) -> tensor + func.return %0 : tensor + + // CHECK-LABEL: batchMatMulV2DynamicSize + // CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor, tensor) -> tensor + // CHECK: return %[[MATMUL_1]] : tensor +} + // ----- // ==== V3 tests ==== diff --git a/tensorflow/compiler/mlir/tensorflow/tests/update_control_dependencies.mlir b/tensorflow/compiler/mlir/tensorflow/tests/update_control_dependencies.mlir index 9a4a23f1cf6..263a6762238 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/update_control_dependencies.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/update_control_dependencies.mlir @@ -205,15 +205,15 @@ func.func @tpu_load_embedding_ops_sink_controls(%arg0: tensor<*x!tf_type.resourc // ----- -// Test that we don't create dependencies between ops on different devices, even -// if both have unknown side effects. +// Tests that we don't create dependencies between ops with same parallel group +// ID but different branch IDs, even if both ops have unknown side effects. // Also test that the fetch op still depends on all side-effecting ops. -func.func @different_devices() { +func.func @same_group_different_branches() { tf_executor.graph { // CHECK: %[[control:.*]] = tf_executor.island wraps "tf.A"() - tf_executor.island wraps "tf.A"() {is_stateless = false, device = "CPU:0"} : () -> () + tf_executor.island wraps "tf.A"() {is_stateless = false, _parallel_execution_ids = "p0:0"} : () -> () // CHECK: %[[control_2:.*]] = tf_executor.island wraps "tf.B"() - tf_executor.island wraps "tf.B"() {is_stateless = false, device = "CPU:1"} : () -> () + tf_executor.island wraps "tf.B"() {is_stateless = false, _parallel_execution_ids = "p0:1"} : () -> () // CHECK: tf_executor.fetch %[[control]], %[[control_2]] : !tf_executor.control, !tf_executor.control tf_executor.fetch } @@ -222,14 +222,14 @@ func.func @different_devices() { // ----- -// Test that we do create dependencies between ops with different but compatible -// device attributes, if both ops have unknown side effects. -func.func @compatible_devices() { +// Tests that we create dependencies between ops with same parallel group ID and +// same branch ID, if both ops have unknown side effects. 
+func.func @same_group_same_branch() { tf_executor.graph { // CHECK: %[[control:.*]] = tf_executor.island wraps "tf.A"() - tf_executor.island wraps "tf.A"() {is_stateless = false, device = "CPU:0"} : () -> () + tf_executor.island wraps "tf.A"() {is_stateless = false, _parallel_execution_ids = "p0:0"} : () -> () // CHECK: %[[control_2:.*]] = tf_executor.island(%[[control]]) wraps "tf.B"() - tf_executor.island wraps "tf.B"() {is_stateless = false, device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> () + tf_executor.island wraps "tf.B"() {is_stateless = false, _parallel_execution_ids = "p0:0"} : () -> () // CHECK: tf_executor.fetch %[[control_2]] : !tf_executor.control tf_executor.fetch } @@ -238,25 +238,148 @@ func.func @compatible_devices() { // ----- -// More complex test with mixed compatible and different devices. In this case, -// side effect analysis should report following dependencies +// Tests one group with multiple branches. In this case, side effect analysis +// should report following dependencies // A -> B -> C -> D -> E -> fetch -// and we expect following dependency chains (one chain per device) +// and we expect following dependency chains after the pass // A -> D -> fetch, B -> E -> fetch, C -> fetch. -func.func @mixed_compatible_and_different_devices() { +func.func @one_group_multiple_branches() { tf_executor.graph { // CHECK: %[[control:.*]] = tf_executor.island wraps "tf.A"() - tf_executor.island wraps "tf.A"() {is_stateless = false, device = "CPU:0"} : () -> () + tf_executor.island wraps "tf.A"() {is_stateless = false, _parallel_execution_ids = "p0:0"} : () -> () // CHECK: %[[control_2:.*]] = tf_executor.island wraps "tf.B"() - tf_executor.island wraps "tf.B"() {is_stateless = false, device = "TPU:0"} : () -> () + tf_executor.island wraps "tf.B"() {is_stateless = false, _parallel_execution_ids = "p0:1"} : () -> () // CHECK: %[[control_3:.*]] = tf_executor.island wraps "tf.C"() - tf_executor.island wraps "tf.C"() {is_stateless = false, device = "CPU:2"} : () -> () + tf_executor.island wraps "tf.C"() {is_stateless = false, _parallel_execution_ids = "p0:2000"} : () -> () // CHECK: %[[control_4:.*]] = tf_executor.island(%[[control]]) wraps "tf.D"() - tf_executor.island wraps "tf.D"() {is_stateless = false, device = "CPU:0"} : () -> () + tf_executor.island wraps "tf.D"() {is_stateless = false, _parallel_execution_ids = "p0:0"} : () -> () // CHECK: %[[control_5:.*]] = tf_executor.island(%[[control_2]]) wraps "tf.E"() - tf_executor.island wraps "tf.E"() {is_stateless = false, device = "TPU:0"} : () -> () + tf_executor.island wraps "tf.E"() {is_stateless = false, _parallel_execution_ids = "p0:1"} : () -> () // CHECK: tf_executor.fetch %[[control_3]], %[[control_4]], %[[control_5]] : !tf_executor.control, !tf_executor.control, !tf_executor.control tf_executor.fetch } func.return -} \ No newline at end of file +} + +// ----- + +// Tests nested replica and parallel execute groups. 
+func.func @nested_replica_and_parallel_execute_groups() { + tf_executor.graph { + // CHECK: %[[control:.*]] = tf_executor.island wraps "tf.A"() + tf_executor.island wraps "tf.A"() : () -> () + // CHECK-NEXT: %[[control_2:.*]] = tf_executor.island(%[[control]]) wraps "tf.B"() + tf_executor.island wraps "tf.B"() {_parallel_execution_ids = "r1:1"} : () -> () + // CHECK-NEXT: %[[control_3:.*]] = tf_executor.island(%[[control_2]]) wraps "tf.C"() + tf_executor.island wraps "tf.C"() {_parallel_execution_ids = "r1:1,p2:1"} : () -> () + // CHECK-NEXT: %[[control_4:.*]] = tf_executor.island(%[[control_2]]) wraps "tf.D"() + tf_executor.island wraps "tf.D"() {_parallel_execution_ids = "r1:1,p2:2"} : () -> () + // CHECK-NEXT: %[[control_5:.*]] = tf_executor.island(%[[control]]) wraps "tf.B"() + tf_executor.island wraps "tf.B"() {_parallel_execution_ids = "r1:2"} : () -> () + // CHECK-NEXT: %[[control_6:.*]] = tf_executor.island(%[[control_5]]) wraps "tf.C"() + tf_executor.island wraps "tf.C"() {_parallel_execution_ids = "r1:2,p3:1"} : () -> () + // CHECK-NEXT: %[[control_7:.*]] = tf_executor.island(%[[control_5]]) wraps "tf.D"() + tf_executor.island wraps "tf.D"() {_parallel_execution_ids = "r1:2,p3:2"} : () -> () + // CHECK-NEXT: tf_executor.fetch %[[control_3]], %[[control_4]], %[[control_6]], %[[control_7]] : !tf_executor.control, !tf_executor.control, !tf_executor.control, !tf_executor.control + tf_executor.fetch + } + func.return +} + +// ----- + +// Tests mixed and nested groups and branches. In this case, side effect +// analysis should report following dependencies +// A -> B -> C -> D -> E -> fetch +// and we expect following dependency chains after the pass +// A -> B -> D -> fetch, C -> fetch, E -> fetch. +func.func @mixed_groups_and_branches_nested() { + tf_executor.graph { + // CHECK: %[[control:.*]] = tf_executor.island wraps "tf.A"() + tf_executor.island wraps "tf.A"() {is_stateless = false, _parallel_execution_ids = "p0:0"} : () -> () + // CHECK-NEXT: %[[control_2:.*]] = tf_executor.island(%[[control]]) wraps "tf.B"() + tf_executor.island wraps "tf.B"() {is_stateless = false, _parallel_execution_ids = "p0:0,r1000:0"} : () -> () + // CHECK-NEXT: %[[control_3:.*]] = tf_executor.island wraps "tf.C"() + tf_executor.island wraps "tf.C"() {is_stateless = false, _parallel_execution_ids = "p0:1,r1000:0"} : () -> () + // CHECK-NEXT: %[[control_4:.*]] = tf_executor.island(%[[control_2]], %[[control_3]]) wraps "tf.D"() + tf_executor.island wraps "tf.D"() {is_stateless = false, _parallel_execution_ids = "r1000:0"} : () -> () + // CHECK-NEXT: %[[control_5:.*]] = tf_executor.island wraps "tf.E"() + tf_executor.island wraps "tf.E"() {is_stateless = false, _parallel_execution_ids = "p0:1,r1000:3000"} : () -> () + // CHECK-NEXT: tf_executor.fetch %[[control_4]], %[[control_5]] : !tf_executor.control, !tf_executor.control + tf_executor.fetch + } + func.return +} + +// ----- + +// Tests that we create dependencies between ops where one op has a parallel +// execution ID and the other has not. 
+func.func @unspecified_parallel_execution_ids() { + tf_executor.graph { + // CHECK: %[[control:.*]] = tf_executor.island wraps "tf.A"() + tf_executor.island wraps "tf.A"() {is_stateless = false} : () -> () + // CHECK-NEXT: %[[control_2:.*]] = tf_executor.island(%[[control]]) wraps "tf.B"() + tf_executor.island wraps "tf.B"() {is_stateless = false, _parallel_execution_ids = "p0:0"} : () -> () + // CHECK-NEXT: %[[control_3:.*]] = tf_executor.island(%[[control]]) wraps "tf.C"() + tf_executor.island wraps "tf.C"() {is_stateless = false, _parallel_execution_ids = "p0:1"} : () -> () + // CHECK-NEXT: tf_executor.fetch %[[control_2]], %[[control_3]] : !tf_executor.control, !tf_executor.control + tf_executor.fetch + } + func.return +} + +// ----- + +func.func @missing_branch_id() { + tf_executor.graph { + // expected-error@+1 {{Malformed _parallel_execution_ids attribute}} + tf_executor.island wraps "tf.A"() {is_stateless = false, _parallel_execution_ids = "p0:"} : () -> () + tf_executor.fetch + } + func.return +} + +// ----- + +func.func @missing_colon() { + tf_executor.graph { + // expected-error@+1 {{Malformed _parallel_execution_ids attribute}} + tf_executor.island wraps "tf.A"() {is_stateless = false, _parallel_execution_ids = "r01"} : () -> () + tf_executor.fetch + } + func.return +} + +// ----- + +func.func @missing_group_id_prefix() { + tf_executor.graph { + // expected-error@+1 {{Malformed _parallel_execution_ids attribute}} + tf_executor.island wraps "tf.A"() {is_stateless = false, _parallel_execution_ids = "0:0"} : () -> () + tf_executor.fetch + } + func.return +} + +// ----- + +func.func @invalid_group_id_prefix() { + tf_executor.graph { + // expected-error@+1 {{Malformed _parallel_execution_ids attribute}} + tf_executor.island wraps "tf.A"() {is_stateless = false, _parallel_execution_ids = "s0:0"} : () -> () + tf_executor.fetch + } + func.return +} + +// ----- + +func.func @extra_colon() { + tf_executor.graph { + // expected-error@+1 {{Malformed _parallel_execution_ids attribute}} + tf_executor.island wraps "tf.A"() {is_stateless = false, _parallel_execution_ids = "r0:0:1"} : () -> () + tf_executor.fetch + } + func.return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/xla_cluster_formation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/xla_cluster_formation.mlir index 658ddf5c767..536d14305e7 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/xla_cluster_formation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/xla_cluster_formation.mlir @@ -1,23 +1,30 @@ -// RUN: tf-opt %s -tf-xla-cluster-formation | FileCheck %s +// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-xla-cluster-formation | FileCheck %s // Check that we outline the partitioned call to a device cluster (since it has // `_xla_compile_device_type`). 
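
(Editorial aside, not part of the patch.) The update_control_dependencies.mlir tests above replace device-string comparison with the `_parallel_execution_ids` attribute: a comma-separated list of `<prefix><group>:<branch>` entries, with prefix `r` for replicate groups and `p` for parallel_execute groups, and two side-effecting ops stay independent exactly when they sit on different branches of a group they share. A self-contained C++ sketch of that convention, including the malformed forms the negative tests reject ("p0:", "r01", "0:0", "s0:0", "r0:0:1"); the helper names are invented and this is not the pass's actual implementation.

#include <cctype>
#include <cstdint>
#include <iostream>
#include <map>
#include <optional>
#include <sstream>
#include <string>

// Parse "r1:2,p3:1" into {group -> branch}; return nullopt on malformed input.
std::optional<std::map<std::string, int64_t>> ParseParallelIds(
    const std::string& attr) {
  std::map<std::string, int64_t> ids;
  std::stringstream ss(attr);
  std::string token;
  while (std::getline(ss, token, ',')) {
    if (token.size() < 3 || (token[0] != 'p' && token[0] != 'r'))
      return std::nullopt;                       // rejects "0:0", "s0:0"
    size_t colon = token.find(':');
    if (colon == std::string::npos ||            // rejects "r01"
        colon == 1 ||                            // no group digits
        colon + 1 == token.size() ||             // rejects "p0:"
        token.find(':', colon + 1) != std::string::npos)  // rejects "r0:0:1"
      return std::nullopt;
    for (size_t i = 1; i < token.size(); ++i)    // group and branch are digits
      if (i != colon && !std::isdigit(static_cast<unsigned char>(token[i])))
        return std::nullopt;
    ids[token.substr(0, colon)] = std::stoll(token.substr(colon + 1));
  }
  return ids;
}

// A control edge is allowed unless the ops are on different branches of a
// parallel group they both belong to.
bool MayDepend(const std::map<std::string, int64_t>& a,
               const std::map<std::string, int64_t>& b) {
  for (const auto& [group, branch] : a) {
    auto it = b.find(group);
    if (it != b.end() && it->second != branch) return false;
  }
  return true;
}

int main() {
  auto a = ParseParallelIds("p0:0"), b = ParseParallelIds("p0:1");
  std::cout << MayDepend(*a, *b) << "\n";                         // 0: branches differ
  std::cout << MayDepend(*a, *ParseParallelIds("p0:0")) << "\n";  // 1: same branch
  std::cout << MayDepend({}, *b) << "\n";                         // 1: one op unannotated
  std::cout << ParseParallelIds("r0:0:1").has_value() << "\n";    // 0: malformed
  return 0;
}
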
+// CHECK-LABEL: func.func @xla_must_compile_true // CHECK: tf_device.cluster // CHECK-NEXT: tf.StatefulPartitionedCall // CHECK-NEXT: tf_device.return // CHECK: tf.Const // CHECK: tf.Add -func.func @xla_must_compile_true(%arg0: tensor) -> tensor { +func.func @xla_must_compile_true(%arg0: tensor) -> tensor attributes {tf.entry_function = {}} { %0 = "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", _xla_compile_device_type = "CPU", f = @stateful_pcall_func} : (tensor) -> (tensor) %1 = "tf.Const"() {value = dense<5> : tensor} : () -> tensor %2 = "tf.Add"(%0, %1) : (tensor, tensor) -> (tensor) func.return %2 : tensor } +func.func @stateful_pcall_func(%arg0: tensor) -> tensor { + func.return %arg0 : tensor +} +// ----- + // Check that we don't outline the partitioned call to a device cluster (since // it does not has `_xla_compile_device_type`). +// CHECK-LABEL: func.func @xla_must_compile_false // CHECK-NOT: tf_device.cluster -func.func @xla_must_compile_false(%arg0: tensor) -> tensor { +func.func @xla_must_compile_false(%arg0: tensor) -> tensor attributes {tf.entry_function = {}} { %0 = "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @stateful_pcall_func} : (tensor) -> (tensor) %1 = "tf.Const"() {value = dense<5> : tensor} : () -> tensor %2 = "tf.Add"(%0, %1) : (tensor, tensor) -> (tensor) @@ -27,3 +34,44 @@ func.func @xla_must_compile_false(%arg0: tensor) -> tensor { func.func @stateful_pcall_func(%arg0: tensor) -> tensor { func.return %arg0 : tensor } + +// ----- + +// CHECK-LABEL: func.func @nested_calls +func.func @nested_calls(%arg0: tensor) -> tensor attributes {tf.entry_function = {}} { + %0 = "tf.While"(%arg0) {cond = @while_cond_func, body = @while_body_func, is_stateless = true} : (tensor) -> (tensor) + func.return %0 : tensor +} + +// CHECK-LABEL: func.func @while_cond_func +func.func @while_cond_func(%arg0: tensor) -> tensor { + %0 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: func.func @while_body_func +func.func @while_body_func(%arg0: tensor) -> (tensor) { + // CHECK-NOT: tf_device.cluster + %0 = "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", device = "/device:CPU:0", executor_type = "", f = @outer_stateful_pcall_func} : (tensor) -> (tensor) + func.return %0 : tensor +} + +// CHECK-LABEL: func.func @outer_stateful_pcall_func +func.func @outer_stateful_pcall_func(%arg0: tensor) -> (tensor) { + // CHECK: tf_device.cluster + // CHECK-NEXT: tf.StatefulPartitionedCall + // CHECK-NEXT: tf_device.return + %0 = "tf.StatefulPartitionedCall"(%arg0) {_xla_compile_device_type = "CPU", config = "", config_proto = "", device = "/device:CPU:0", executor_type = "", f = @inner_stateful_pcall_func} : (tensor) -> (tensor) + func.return %0 : tensor +} + +// CHECK-LABEL: func.func @inner_stateful_pcall_func +func.func @inner_stateful_pcall_func(%arg0: tensor) -> tensor { + // CHECK-NOT: tf_device.cluster + %0 = "tf.StatefulPartitionedCall"(%arg0) {_xla_compile_device_type = "CPU", config = "", config_proto = "", device = "/device:CPU:0", executor_type = "", f = @func} : (tensor) -> (tensor) + func.return %0 : tensor +} + +func.func @func(%arg0: tensor) -> tensor { + func.return %arg0 : tensor +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/xla_rewrite.mlir b/tensorflow/compiler/mlir/tensorflow/tests/xla_rewrite.mlir index 32d66ee9424..973fe031d75 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/xla_rewrite.mlir +++ 
b/tensorflow/compiler/mlir/tensorflow/tests/xla_rewrite.mlir @@ -1,11 +1,12 @@ -// RUN: tf-opt %s -tf-xla-rewrite | FileCheck %s - -// ----- +// RUN: tf-opt %s -split-input-file -tf-xla-rewrite | FileCheck %s // CHECK-LABEL: func.func @convert_partitioned_call func.func @convert_partitioned_call(%arg0: tensor) -> tensor { - // CHECK: "tf.XlaLaunch"(%arg0) {_xla_compile_device_type = "CPU", device = "/device:CPU:0", function = @pcall_func, operand_segment_sizes = array} : (tensor) -> tensor - %0 = "tf.PartitionedCall"(%arg0) {_xla_compile_device_type = "CPU", config = "", config_proto = "", device = "/device:CPU:0", executor_type = "", f = @pcall_func} : (tensor) -> (tensor) + %0 = "tf_device.cluster"() ({ + // CHECK: "tf.XlaLaunch"(%arg0) {_xla_compile_device_type = "CPU", device = "/device:CPU:0", function = @pcall_func, operand_segment_sizes = array} : (tensor) -> tensor + %1 = "tf.PartitionedCall"(%arg0) {_xla_compile_device_type = "CPU", config = "", config_proto = "", device = "/device:CPU:0", executor_type = "", f = @pcall_func} : (tensor) -> (tensor) + tf_device.return %1 : tensor + }) : () -> tensor func.return %0 : tensor } @@ -17,8 +18,12 @@ func.func @pcall_func(%arg0: tensor) -> tensor { // CHECK-LABEL: func.func @convert_stateful_partitioned_call func.func @convert_stateful_partitioned_call(%arg0: tensor) -> tensor { - // CHECK: "tf.XlaLaunch"(%arg0) {_xla_compile_device_type = "CPU", device = "/device:CPU:0", function = @stateful_pcall_func, operand_segment_sizes = array} : (tensor) -> tensor - %0 = "tf.StatefulPartitionedCall"(%arg0) {_xla_compile_device_type = "CPU", config = "", config_proto = "", device = "/device:CPU:0", executor_type = "", f = @stateful_pcall_func} : (tensor) -> (tensor) + %0 = "tf_device.cluster"() ({ + // CHECK: "tf.XlaLaunch"(%arg0) {_xla_compile_device_type = "CPU", device = "/device:CPU:0", function = @stateful_pcall_func, operand_segment_sizes = array} : (tensor) -> tensor + %1 = "tf.StatefulPartitionedCall"(%arg0) {_xla_compile_device_type = "CPU", config = "", config_proto = "", device = "/device:CPU:0", executor_type = "", f = @stateful_pcall_func} : (tensor) -> (tensor) + tf_device.return %1 : tensor + }) : () -> tensor + func.return %0 : tensor } @@ -29,30 +34,39 @@ func.func @stateful_pcall_func(%arg0: tensor) -> tensor { // ----- // CHECK-LABEL: func.func @convert_stateful_partitioned_call_with_resources_in_order -func.func @convert_stateful_partitioned_call_with_resources_in_order(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK: %0 = "tf.XlaLaunch"(%arg1, %arg0) {_xla_compile_device_type = "CPU", device = "/device:CPU:0", function = @stateful_pcall_func_with_resources_in_order, operand_segment_sizes = array} : (tensor, tensor) -> tensor - %0 = "tf.StatefulPartitionedCall"(%arg1, %arg0) {_xla_compile_device_type = "CPU", config = "", config_proto = "", device = "/device:CPU:0", executor_type = "", f = @stateful_pcall_func_with_resources_in_order} : (tensor, tensor) -> (tensor) - func.return %0 : tensor +func.func @convert_stateful_partitioned_call_with_resources_in_order(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "tf_device.cluster"() ({ + // CHECK: "tf.XlaLaunch"(%arg1, %arg0) {_xla_compile_device_type = "CPU", device = "/device:CPU:0", function = @stateful_pcall_func_with_resources_in_order, operand_segment_sizes = array} : (tensor, tensor) -> tensor + %1 = "tf.StatefulPartitionedCall"(%arg1, %arg0) {_xla_compile_device_type = "CPU", config = "", config_proto = "", device = "/device:CPU:0", executor_type = "", f = 
@stateful_pcall_func_with_resources_in_order} : (tensor, tensor) -> (tensor) + tf_device.return %1 : tensor + }) : () -> tensor + func.return %0 : tensor } -func.func @stateful_pcall_func_with_resources_in_order(%arg0 : tensor, %arg1 : tensor) -> tensor { - func.return %arg0 : tensor +func.func @stateful_pcall_func_with_resources_in_order(%arg0 : tensor, %arg1 : tensor) -> tensor { + func.return %arg0 : tensor } // ----- // CHECK-LABEL: func.func @convert_stateful_partitioned_call_with_resources -func.func @convert_stateful_partitioned_call_with_resources(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK: "tf.XlaLaunch"(%arg1, %arg0) {_xla_compile_device_type = "CPU", device = "/device:CPU:0", function = @stateful_pcall_func_with_resources, operand_segment_sizes = array} : (tensor, tensor) -> tensor - %0 = "tf.StatefulPartitionedCall"(%arg0, %arg1) {_xla_compile_device_type = "CPU", config = "", config_proto = "", device = "/device:CPU:0", executor_type = "", f = @stateful_pcall_func_with_resources} : (tensor, tensor) -> (tensor) - // CHECK: "tf.XlaLaunch"(%arg1, %arg0) {_xla_compile_device_type = "CPU", device = "/device:CPU:0", function = @stateful_pcall_func_with_resources, operand_segment_sizes = array} : (tensor, tensor) -> tensor - %1 = "tf.StatefulPartitionedCall"(%arg0, %arg1) {_xla_compile_device_type = "CPU", config = "", config_proto = "", device = "/device:CPU:0", executor_type = "", f = @stateful_pcall_func_with_resources} : (tensor, tensor) -> (tensor) - func.return %0 : tensor +func.func @convert_stateful_partitioned_call_with_resources(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "tf_device.cluster"() ({ + // CHECK: "tf.XlaLaunch"(%arg1, %arg0) {_xla_compile_device_type = "CPU", device = "/device:CPU:0", function = @stateful_pcall_func_with_resources, operand_segment_sizes = array} : (tensor, tensor) -> tensor + %2 = "tf.StatefulPartitionedCall"(%arg0, %arg1) {_xla_compile_device_type = "CPU", config = "", config_proto = "", device = "/device:CPU:0", executor_type = "", f = @stateful_pcall_func_with_resources} : (tensor, tensor) -> tensor + tf_device.return %2 : tensor + }) : () -> tensor + %1 = "tf_device.cluster"() ({ + // CHECK: "tf.XlaLaunch"(%arg1, %arg0) {_xla_compile_device_type = "CPU", device = "/device:CPU:0", function = @stateful_pcall_func_with_resources, operand_segment_sizes = array} : (tensor, tensor) -> tensor + %2 = "tf.StatefulPartitionedCall"(%arg0, %arg1) {_xla_compile_device_type = "CPU", config = "", config_proto = "", device = "/device:CPU:0", executor_type = "", f = @stateful_pcall_func_with_resources} : (tensor, tensor) -> tensor + tf_device.return %2 : tensor + }) : () -> tensor + return %0 : tensor } // CHECK-LABEL: func.func @stateful_pcall_func_with_resources -// CHECK-SAME: (%arg0: tensor, %arg1: tensor) -> tensor -// CHECK: return %arg0 : tensor -func.func @stateful_pcall_func_with_resources(%arg0 : tensor, %arg1: tensor) -> tensor { - func.return %arg1 : tensor +// CHECK-SAME: (%arg0: tensor, %arg1: tensor) -> tensor +// CHECK: return %arg0 : tensor +func.func @stateful_pcall_func_with_resources(%arg0 : tensor, %arg1: tensor) -> tensor { + func.return %arg1 : tensor } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc b/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc index cdd2a0536ba..14ab17be0fd 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc @@ -49,8 +49,8 @@ 
class ConvertTFBatchMatMulToEinsumOp LogicalResult matchAndRewrite(BatchMatMulOpType op, PatternRewriter& rewriter) const override { - Value input_lhs = op.x(); - Value input_rhs = op.y(); + Value input_lhs = op.getX(); + Value input_rhs = op.getY(); // LHS and RHS must be a ranked tensor type auto lhs_type = input_lhs.getType().dyn_cast(); @@ -70,8 +70,8 @@ class ConvertTFBatchMatMulToEinsumOp // einsum equation for batchmatmul std::string equation("...mk,...kn->...mn"); - if (op.adj_x()) std::swap(equation[3], equation[4]); - if (op.adj_y()) std::swap(equation[6 + 3], equation[6 + 4]); + if (op.getAdjX()) std::swap(equation[3], equation[4]); + if (op.getAdjY()) std::swap(equation[6 + 3], equation[6 + 4]); rewriter.replaceOpWithNewOp( op, op.getType(), diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc index 28779f6c29d..43742deb647 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h" #include +#include #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project @@ -59,9 +60,6 @@ tensorflow::Status RunTFXLABridge( // Populate a passmanager with the list of passes that implement the bridge. pipeline_builder(bridge); - // Add set of passes to lower back to graph (from tf_executor). - TF::AddGraphExportLoweringPasses(bridge); - mlir::StatusScopedDiagnosticHandler diag_handler( module.getContext(), /*propagate=*/false, /*filter_stack=*/!VLOG_IS_ON(1)); @@ -97,6 +95,7 @@ void CreateTPUBridgePipelineImpl(OpPassManager &pm) { // Run shape inference so that tf_executor/tf_device ops created later will // likely to inherit more concrete types. pm.addPass(TF::CreateTFShapeInferencePass()); + pm.addNestedPass(CreateTPUPartitionedOpConversionPass()); pm.addNestedPass( CreateTPUReorderReplicateAndPartitionedInputsPass()); pm.addNestedPass(TF::CreateDecomposeReduceDatasetPass()); @@ -179,6 +178,8 @@ void CreateTPUBridgePipelineImpl(OpPassManager &pm) { pm.addNestedPass( CreateTPUResourceReadsWritesPartitioningPass()); pm.addPass(TFDevice::CreateAnnotateParameterReplicationPass()); + pm.addNestedPass( + mlir::TF::CreateRewriteTPUEmbeddingOpsPass()); pm.addPass(CreateTPURewritePass()); pm.addPass(createSymbolDCEPass()); pm.addNestedPass( @@ -226,8 +227,15 @@ void CreateTPUBridgePipelineV1(OpPassManager &pm) { tensorflow::Status TPUBridge(ModuleOp module, bool enable_logging, bool fallback_enabled) { - Status status = - RunTFXLABridge(module, enable_logging, CreateTPUBridgePipeline); + Status status = RunTFXLABridge(module, enable_logging, [](OpPassManager &pm) { + CreateTPUBridgePipeline(pm); + // Add set of passes to lower back to graph (from tf_executor). + // Use graph export pipeline V2 in TPU Bridge. + // TODO(hanxiong): Completely replace AddGraphExportLoweringPasses with + // AddGraphExportLoweringPassesV2 in all the code paths (V1 compat pipeline, + // CPU/GPU bridge, etc.) + TF::AddGraphExportLoweringPassesV2(pm); + }); tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter( "tpu", "v2", fallback_enabled, status.ok() ?
"success" : "failure"); OkOrSetErrorCounterPayload( @@ -237,8 +245,11 @@ tensorflow::Status TPUBridge(ModuleOp module, bool enable_logging, } tensorflow::Status TPUBridgeV1Compat(ModuleOp module, bool enable_logging, bool fallback_enabled) { - Status status = - RunTFXLABridge(module, enable_logging, CreateTPUBridgePipelineV1); + Status status = RunTFXLABridge(module, enable_logging, [](OpPassManager &pm) { + CreateTPUBridgePipelineV1(pm); + // Add set of passes to lower back to graph (from tf_executor). + TF::AddGraphExportLoweringPasses(pm); + }); tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter( "tpu", "v1", fallback_enabled, status.ok() ? "success" : "failure"); return status; @@ -255,10 +266,43 @@ void AddGraphExportLoweringPasses(OpPassManager &pm) { }; add_pass(CreateFunctionalToExecutorDialectConversionPass()); - add_pass(TFDevice::CreateReplicateToIslandPass()); + add_pass(TFDevice::CreateReplicateToIslandPass(/*legacy_graph_export=*/true)); add_pass(TFDevice::CreateReplicaIDToDeviceOrdinalPass()); - add_pass(TFDevice::CreateParallelExecuteToIslandsPass()); - add_pass(TFDevice::CreateLaunchToDeviceAttributePass()); + add_pass(TFDevice::CreateParallelExecuteToIslandsPass( + /*legacy_graph_export=*/true)); + add_pass(TFDevice::CreateLaunchToDeviceAttributePass( + /*legacy_graph_export=*/true)); + pm.addNestedPass(TFTPU::CreateTPUDevicePropagationPass()); + pm.addPass(createSymbolDCEPass()); + if (tensorflow::GetMlirCommonFlags() + ->tf_mlir_enable_convert_control_to_data_outputs_pass) { + pm.addPass(tf_executor::CreateTFExecutorConvertControlToDataOutputsPass()); + } + pm.addPass(CreateVerifySuitableForExportPass()); +} + +void AddGraphExportLoweringPassesV2(OpPassManager &pm) { + // First, we need to convert from functional, to executor dialect. + pm.addNestedPass( + CreateFunctionalToExecutorDialectConversionPass()); + + // Do a single pass to split the graph's single island op into an island per + // op as expected by the following passes. + pm.addNestedPass(CreateSplitIntoIslandPerOpPass()); + + pm.addNestedPass(TFDevice::CreateReplicateToIslandPass( + /*legacy_graph_export=*/false)); + pm.addNestedPass( + TFDevice::CreateReplicaIDToDeviceOrdinalPass()); + pm.addNestedPass(TFDevice::CreateParallelExecuteToIslandsPass( + /*legacy_graph_export=*/false)); + pm.addNestedPass(TFDevice::CreateLaunchToDeviceAttributePass( + /*legacy_graph_export=*/false)); + + // Do a single pass to encode necessary control deps in the IR according to + // the results of side effect analysis. 
+ pm.addPass(tf_executor::CreateTFExecutorUpdateControlDependenciesPass()); + pm.addNestedPass(TFTPU::CreateTPUDevicePropagationPass()); pm.addPass(createSymbolDCEPass()); if (tensorflow::GetMlirCommonFlags() @@ -282,15 +326,13 @@ tensorflow::Status RunBridgeWithStandardPipeline(ModuleOp module, /*filter_stack=*/!VLOG_IS_ON(1)); if (enable_logging || VLOG_IS_ON(1)) { - tensorflow::DumpMlirOpToFile("standard_pipeline_before", module, "", - &bridge); + tensorflow::DumpMlirOpToFile(kStandardPipelineBefore, module, "", &bridge); if (VLOG_IS_ON(2)) EnableDetailedLogging(&bridge); } LogicalResult result = bridge.run(module); (void)result; if (enable_logging || VLOG_IS_ON(1)) - tensorflow::DumpMlirOpToFile("standard_pipeline_after", module, "", - &bridge); + tensorflow::DumpMlirOpToFile(kStandardPipelineAfter, module, "", &bridge); return diag_handler.ConsumeStatus(); } @@ -310,6 +352,7 @@ void CreateTFXLABridgePipeline(OpPassManager &pm) { CreateExecutorDialectToFunctionalConversionPass()); // Guarantee all functions have one use, which enables more exact shape // inference. + pm.addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass()); pm.addPass(TF::CreateTFShapeInferencePass()); // Encapsulate PartitionedCall ops within a cluster so that the composite // resource ops can be decomposed. @@ -327,9 +370,9 @@ void CreateTFXLABridgePipeline(OpPassManager &pm) { pm.addPass(TF::CreateTFShapeInferencePass()); pm.addNestedPass(createCanonicalizerPass()); pm.addPass(TFDevice::CreateResourceOpLiftingPass()); - // Inline the StatefulPartitionedCallOp op based in the parent region. - pm.addPass(TFDevice::CreateXlaInlineDeviceOpsPass()); pm.addPass(TFDevice::CreateXlaRewritePass()); + // Inline the cluster ops. + pm.addPass(TFDevice::CreateXlaInlineDeviceOpsPass()); // Re-run the canonicalizer pass as some cleanup during resource op lifting // pass opens up some opportunities for canonicalization of cluster ops. // Specifically, we want to eliminate pass through results from the cluster @@ -342,8 +385,12 @@ void CreateTFXLABridgePipeline(OpPassManager &pm) { } tensorflow::Status RunTFXLABridge(ModuleOp module, bool enable_logging) { - Status status = mlir::TFTPU::RunTFXLABridge(module, enable_logging, - CreateTFXLABridgePipeline); + Status status = mlir::TFTPU::RunTFXLABridge( + module, enable_logging, [](OpPassManager &pm) { + CreateTFXLABridgePipeline(pm); + // Add set of passes to lower back to graph (from tf_executor). + TF::AddGraphExportLoweringPasses(pm); + }); tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter( /*device type*/ "cpu/gpu", /*bridge version*/ "tfxla", /*fallback_enabled*/ false, diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h index c0dc24416ff..925149dd843 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h @@ -44,6 +44,9 @@ tensorflow::Status TPUBridgeV1Compat(ModuleOp module, bool enable_logging, namespace TF { +inline constexpr char kStandardPipelineBefore[] = "standard_pipeline_before"; +inline constexpr char kStandardPipelineAfter[] = "standard_pipeline_after"; + // Runs all passes involved in transforming or optimizing an MLIR graph without // any target specialization. When enable_logging is true, enables // tensorflow::BridgeLogger. 
When enable_inliner is true, enables the inliner diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/call_graph_util.h b/tensorflow/compiler/mlir/tensorflow/transforms/call_graph_util.h new file mode 100644 index 00000000000..e04d1323352 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/call_graph_util.h @@ -0,0 +1,64 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_CALL_GRAPH_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_CALL_GRAPH_UTIL_H_ + +#include +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project + +namespace mlir { + +// Find the outermost ops with any of specified types starting from the tree +// rooted at `root` parameter. The results are stored in `ops`. Addtional +// filters can be specified by providing `predicate` parameter. +template +LogicalResult GetOutermostOpsOfType( + func::FuncOp root, SymbolTable &symtab, llvm::SmallVector &ops, + const std::function &predicate = {}) { + std::stack worklist; + worklist.push(root); + while (!worklist.empty()) { + func::FuncOp u = worklist.top(); + worklist.pop(); + auto result = u.walk([&](SymbolUserOpInterface op) { + if (llvm::isa(op) && (!predicate || predicate(op))) { + ops.push_back(op); + return WalkResult::advance(); + } + for (auto attr : op->getAttrs()) { + auto sym = attr.getValue().dyn_cast(); + if (!sym) continue; + auto v = symtab.lookup(sym.getRootReference()); + if (!v) { + // This is not expected to happen in practice. + v.emitError() << "Cannot find function " << sym.getRootReference(); + return WalkResult::interrupt(); + } + worklist.push(v); + } + return WalkResult::advance(); + }); + if (result.wasInterrupted()) return failure(); + } + return success(); +} + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_CALL_GRAPH_UTIL_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc index ead70410731..3c8541d2a50 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc @@ -17,16 +17,20 @@ limitations under the License. // assigned to save devices. Clusters are represented as regions. // Note that side-effecting ops are not correctly handled yet. 
+#include +#include + #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" @@ -59,19 +63,39 @@ StringRef GetDevice(Operation* op) { return device_attr ? device_attr.getValue() : ""; } -// An op can be merged into cluster if all of its operands are one of the -// following: -// 1) A block argument -// 2) A value produced by other islands -// 1) Defined before the cluster -// 2) Defined by an operation in the cluster +// An op can be merged into cluster if it satisfies both of the following +// conditions: +// +// * All of its operands are one of the following: +// 1) A block argument +// 2) A value produced by other islands +// 3) Defined before the cluster +// 4) Defined by an operation in the cluster +// * Merging the op into the cluster does not reorder control dependencies. +// // TODO(ycao): This is not optimal as it doesn't consider the situation of // defining_op's operands all meet the requirements above. In that case, the // defining_op can be moved and to_merge op would be legal to absorb. -// TODO(ycao): Take op side-effects into consideration since they can not be -// re-ordered but forming clusters of non-continuous ops is effectively -// re-ordering them.. -bool CanMergeIntoCluster(const Cluster& c, Operation* to_merge) { +bool CanMergeIntoCluster( + const Cluster& c, Operation* to_merge, + const TF::SideEffectAnalysis::Info& side_effect_analysis) { + // If any of the op's control predecessors appears after the last op in the + // cluster, merging the op may cause control dependencies to be reordered. + // Hence, the op cannot be merged to the cluster in such a case. + const bool has_control_predecessors_after_cluster = + !side_effect_analysis + .DirectControlPredecessors( + to_merge, + [&c](Operation* pred) { + Operation* const last_c_op = c.ops.back(); + return last_c_op->getBlock() == pred->getBlock() && + last_c_op->isBeforeInBlock(pred); + }) + .empty(); + if (has_control_predecessors_after_cluster) { + return false; + } + return llvm::all_of(to_merge->getOperands(), [&](Value operand) { // Block arguments. if (operand.isa()) return true; @@ -130,6 +154,56 @@ void GetLiveOuts(Region* region, llvm::SmallVectorImpl* live_outs) { } } +// Reorder all users of the given op's results to after the op. +// +// Since launch ops are inserted after the last op in the region, the region is +// guaranteed to dominate all live-in values. On the other hand, it is still +// possible that live-out values don't dominate the region. For example: +// +// ``` +// %0 = "tf.OpA"() +// %1 = "tf.OpB"(%0) +// %2 = "tf.OpC"(%0) +// ``` +// +// Assuming `tf.OpA` and `tf.OpC` are clustered together, the region will be +// inserted right after `tf.OpC`. The live-out `%0`, however, is used by +// `tf.OpB`, which won't dominate the region. 
This function reorders all users +// of the cluster op to be placed after the cluster op itself so that SSA +// dominance is preserved after cluster op creation. +void ReorderOpResultUses(mlir::Operation* cluster) { + mlir::Block* const cluster_block = cluster->getBlock(); + llvm::SetVector ops_to_reorder; + + llvm::SmallVector worklist; + llvm::append_range(worklist, cluster->getResults()); + + while (!worklist.empty()) { + mlir::Value value = worklist.back(); + worklist.pop_back(); + + for (mlir::Operation* const user : value.getUsers()) { + mlir::Operation* const op = cluster_block->findAncestorOpInBlock(*user); + if (op == nullptr || !op->isBeforeInBlock(cluster)) { + continue; + } + + if (ops_to_reorder.insert(op)) { + llvm::append_range(worklist, op->getResults()); + } + } + } + + std::vector sorted = ops_to_reorder.takeVector(); + llvm::sort(sorted, [](mlir::Operation* lhs, mlir::Operation* rhs) { + return lhs->isBeforeInBlock(rhs); + }); + + for (mlir::Operation* const op : llvm::reverse(sorted)) { + op->moveAfter(cluster); + } +} + // Build a `tf_device.launch` op with a region that contains all the operations // in given cluster. Then all ops in cluster are replaced by `tf_device.launch`. void BuildLaunchForCluster(const Cluster& c, OpBuilder* builder) { @@ -176,9 +250,14 @@ void BuildLaunchForCluster(const Cluster& c, OpBuilder* builder) { // Replace any external uses of live-out values with return values of launch // op. So live-out values no longer escape the region. ReplaceLiveOutExternalUses(live_outs, launch_op); + + // Ensure that users of the launch op's results appear after the launch op + // in order to preserve the dominance property. + ReorderOpResultUses(launch_op); } -void BuildClusters(Block* block, OpBuilder builder) { +void BuildClusters(Block* block, OpBuilder builder, + const TF::SideEffectAnalysis::Info& side_effect_analysis) { // Iteratively find clusters of different devices within an island. // Whenever we see an operation that is assigned to an accelerator device // (ie. device != ""), we try to merge it into the last cluster of same @@ -200,7 +279,7 @@ void BuildClusters(Block* block, OpBuilder builder) { // Check if it is legal to merge op into nearest cluster of same device. // If positive, update cluster and move on to next operation. Cluster& nearest_cluster = it->second; - if (CanMergeIntoCluster(nearest_cluster, &op)) { + if (CanMergeIntoCluster(nearest_cluster, &op, side_effect_analysis)) { nearest_cluster.ops.emplace_back(&op); continue; } @@ -221,21 +300,28 @@ void BuildClusters(Block* block, OpBuilder builder) { } void ClusterFormationPass::runOnOperation() { - auto func = getOperation(); - if (func.isExternal()) return; - OpBuilder builder(func.getContext()); - - // Operates on individual blocks independently of if they are directly in the - // function body or if they are nested in individual `tf_executor.island`. - for (Block& block : func.getBody()) BuildClusters(&block, builder); - func.walk([&](tf_executor::IslandOp island) { - BuildClusters(&island.GetBody(), builder); - }); + auto module = getOperation(); + auto& side_effect_analysis = getAnalysis(); + + for (auto func : module.getOps()) { + if (func.isExternal()) continue; + OpBuilder builder(func.getContext()); + const TF::SideEffectAnalysis::Info& info = + side_effect_analysis.GetAnalysisForFunc(func); + + // Operates on individual blocks independently of if they are directly in + // the function body or if they are nested in individual + // `tf_executor.island`. 
+ for (Block& block : func.getBody()) BuildClusters(&block, builder, info); + func.walk([&](tf_executor::IslandOp island) { + BuildClusters(&island.GetBody(), builder, info); + }); + } } } // namespace -std::unique_ptr> CreateClusterFormationPass() { +std::unique_ptr> CreateClusterFormationPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.cc index e5273988b0b..2e256ae4eff 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.h" +#include + #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" @@ -100,7 +102,7 @@ void ValuesConstraintSet::Walk( Optional ValuesConstraintSet::GetConstraint( Value value) const { auto it = constraints_.find(value); - if (it == constraints_.end()) return None; + if (it == constraints_.end()) return std::nullopt; return it->getSecond(); } @@ -386,10 +388,10 @@ static Optional CanBeClustered( const std::function &filter) { // Check that op has no side effects. This guarantees that we will not // reorder side-effecting ops during cluster formation. - if (!MemoryEffectOpInterface::hasNoEffect(op)) return llvm::None; + if (!isMemoryEffectFree(op)) return std::nullopt; // Operation rejected by the custom filter. - if (filter && !filter(op)) return llvm::None; + if (filter && !filter(op)) return std::nullopt; // Initially we do not have any constraints on the operation results. ValuesConstraintSet result_constraints; @@ -401,7 +403,7 @@ static Optional CanBeClustered( return operands_constraints.Resolve(); } - return llvm::None; + return std::nullopt; } // Compute initial clustering state based on the clustering polocy. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_tf_ops_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_tf_ops_pass.cc index 435a85aec34..abf4abef0ae 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_tf_ops_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_tf_ops_pass.cc @@ -25,12 +25,14 @@ limitations under the License. // does not exist any operation placed on host_B that conumes any result of any // operation placed on host_A. +#include + #include "mlir/IR/Builders.h" #include "mlir/Pass/Pass.h" #include "absl/strings/str_cat.h" #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/core/util/device_name_utils.h" @@ -169,7 +171,7 @@ llvm::Optional> GetFunctionMetadatas( return WalkResult::advance(); }); - if (result.wasInterrupted()) return llvm::None; + if (result.wasInterrupted()) return std::nullopt; return metadatas; } @@ -226,7 +228,7 @@ void CreateFunctions(ModuleOp module_op, // operation should use the arguments of the newly created func_op as // appropriate. 
OpBuilder builder(block, block->end()); - BlockAndValueMapping mapping; + IRMapping mapping; for (int i : llvm::seq(0, metadata.inputs.size())) { Value original_value = metadata.inputs[i]; Value new_value = func_op.getArgument(i); @@ -254,7 +256,7 @@ void CreateFunctions(ModuleOp module_op, // tf_device.remote_run calls. void CreateRemoteRunCalls(MLIRContext *context, const llvm::StringMap &metadatas) { - BlockAndValueMapping mapping; + IRMapping mapping; for (auto &iter : metadatas) { llvm::StringRef host = iter.first(); const FunctionMetadata &metadata = iter.second; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc b/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc index a222d4734bd..4caeb7111c3 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h" +#include + #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/Optional.h" @@ -96,7 +98,7 @@ Value GetElement(Value index, Value buffer, OpBuilder builder, Location loc, loc, ArrayRef{element_type}, ArrayRef{slice, GetR1Const(element_type.getShape(), builder, loc)}); - return reshape.output(); + return reshape.getOutput(); } Value SetElement(Value index, Value buffer, Value element, OpBuilder builder, @@ -119,7 +121,7 @@ Value SetElement(Value index, Value buffer, Value element, OpBuilder builder, loc, ArrayRef{buffer.getType()}, ArrayRef{buffer, update_slice, GetIndicesForElement(index, buffer, builder, loc)}) - .output(); + .getOutput(); } TensorType GetSizeType(OpBuilder builder) { @@ -142,7 +144,8 @@ LogicalResult CreateInitBufferValue(ArrayRef element_shape, auto max_count_const_op = llvm::dyn_cast(max_count_op); if (!max_count_const_op) return op->emitOpError("unknown max element count"); int64_t max_size_const = - (*max_count_const_op.value().getValues().begin()).getSExtValue(); + (*max_count_const_op.getValue().getValues().begin()) + .getSExtValue(); return CreateInitBufferValue(element_shape, max_size_const, op, element_dtype, builder, buffer); } @@ -168,7 +171,7 @@ LogicalResult CreateInitBufferValue(ArrayRef element_shape, auto broadcast = builder.create( op->getLoc(), ArrayRef{buffer_type}, ArrayRef{zero, GetR1Const(buffer_shape, builder, op->getLoc())}); - *buffer = broadcast.output(); + *buffer = broadcast.getOutput(); return success(); } @@ -210,7 +213,7 @@ llvm::Optional GetElementTypeFromAccess( if (elem_type && elem_type.hasStaticShape()) return elem_type; } } - return llvm::None; + return std::nullopt; } // Creates a ReadVariableOp on a local variable. @@ -222,7 +225,7 @@ Value ReadLocalVariable(Value local_var, OpBuilder builder, Location loc) { .cast() .getSubtypes()[0]}, ArrayRef{local_var}) - .value(); + .getValue(); } // Creates an AssignVariableOp on a local variable. 
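A note on the rename that recurs throughout these hunks: generated accessors such as `op.value()`, `op.output()`, and `op.adj_x()` become `op.getValue()`, `op.getOutput()`, and `op.getAdjX()`, following the ODS prefixed-accessor convention in which each snake_case operand, attribute, or result name maps to a `get`-prefixed CamelCase getter. A minimal, self-contained sketch of that naming rule (the helper below is hypothetical and purely illustrative, not code from this patch):

```cpp
// Hypothetical helper, for illustration only: maps an ODS snake_case name to
// its prefixed CamelCase getter, mirroring renames in this patch such as
// adj_x() -> getAdjX() and output() -> getOutput().
#include <cctype>
#include <iostream>
#include <string>

std::string PrefixedGetterName(const std::string& snake_case) {
  std::string name = "get";
  bool upper_next = true;  // Capitalize the first letter of each word.
  for (char c : snake_case) {
    if (c == '_') {
      upper_next = true;
      continue;
    }
    name += upper_next
                ? static_cast<char>(std::toupper(static_cast<unsigned char>(c)))
                : c;
    upper_next = false;
  }
  return name;
}

int main() {
  std::cout << PrefixedGetterName("adj_x") << "\n";        // getAdjX
  std::cout << PrefixedGetterName("output") << "\n";       // getOutput
  std::cout << PrefixedGetterName("then_branch") << "\n";  // getThenBranch
  return 0;
}
```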
@@ -252,7 +255,7 @@ int64_t GetFirstIfIndicesAreContiguous(Value indices) { if (!const_op) return -1; int64_t last_index = -1; int64_t first_index = -1; - for (const auto& ind : const_op.value().getValues()) { + for (const auto& ind : const_op.getValue().getValues()) { if (last_index == -1) { last_index = ind.getSExtValue(); first_index = last_index; @@ -314,12 +317,12 @@ Value ScatterAccumulateElements(Value indices, Value updates, Value buffer, GetElement(index, buffer, builder, loc, /*keep_slice_shape=*/true); starts_in_update[0] = i; auto update_slice_starts = GetR1Const(starts_in_update, builder, loc); - auto slice = + Value slice = builder .create( loc, ArrayRef{old_slice.getType()}, ArrayRef{updates, update_slice_starts, slice_sizes}) - .output(); + .getOutput(); slice = AccumulateBuffers(old_slice, slice, builder, loc); buffer = SetElement(index, buffer, slice, builder, loc); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc index b1c7879394d..55354d3eec5 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc @@ -89,7 +89,7 @@ LogicalResult ConstantFoldFallbackHook( // TensorFlow folding hook. if (inst->getNumResults() == 0 || inst->hasTrait() || - inst->getNumRegions() != 0 || !MemoryEffectOpInterface::hasNoEffect(inst)) + inst->getNumRegions() != 0 || !isMemoryEffectFree(inst)) return failure(); // If any of the result types are variants, don't try to constant fold them. @@ -151,6 +151,10 @@ LogicalResult ConstantFoldFallbackHook( // with all GPU/TPU devices ignored and CPU only set to 1. (*config_proto.mutable_device_count())["CPU"] = 1; config_proto.add_device_filters("/device:CPU:*"); + // Limit the thread pool size. Without this, TF by default creates as many + // threads as the number of CPUs (`port::MaxParallelism()`). This can be + // expensive since this TFE context persists the entire program execution. + config_proto.set_inter_op_parallelism_threads(2); std::unique_ptr config( TF_NewBuffer(), TF_DeleteBuffer); DCHECK(config->data == nullptr); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc b/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc index b96061aedc5..5af3b02a2f3 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc @@ -232,10 +232,9 @@ void AppendFunctionResults(func::FuncOp func, int num_resources, graph_op.erase(); func::ReturnOp return_op = cast(block.getTerminator()); int num_old_arguments = return_op.getNumOperands(); - for (int i = 0; i < num_resources; ++i) { - return_op.operandsMutable().append( - new_graph_op.getResult(num_old_arguments + i)); - } + return_op->insertOperands( + num_old_arguments, + new_graph_op.getResults().slice(num_old_arguments, num_resources)); } // Creates a wrapper island enclosing the `sub_op` dependent on @@ -347,7 +346,7 @@ TF::WhileOp RewriteWhileOp(TF::WhileOp while_op, int num_resource_inputs, // Get the dummy constant. OpBuilder builder(while_wrapper); auto loc = NameLoc::get( - builder.getStringAttr("chain_control_outputs@" + while_op.body())); + builder.getStringAttr("chain_control_outputs@" + while_op.getBody())); IslandOp const_wrapper = GetDummyConstant(builder, const_type, loc); // Get new operand and result types. 
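The AppendFunctionResults change just above swaps a per-element append loop for a single `insertOperands` call over a result slice. A small standalone sketch of why the two forms pick the same operands, using plain C++ containers as stand-ins for MLIR value ranges (illustrative only, not patch code):

```cpp
// Illustration only: slice(start, count) selects `count` consecutive results
// starting at `start`, which is exactly the range the removed loop appended
// one element at a time.
#include <cassert>
#include <cstddef>
#include <vector>

std::vector<int> Slice(const std::vector<int>& values, std::size_t start,
                       std::size_t count) {
  return std::vector<int>(values.begin() + start,
                          values.begin() + start + count);
}

int main() {
  const std::vector<int> results = {10, 11, 12, 13, 14};
  const std::size_t num_old_arguments = 2;
  const std::size_t num_resources = 3;

  // Old form: append results[num_old_arguments + i] for each i.
  std::vector<int> appended;
  for (std::size_t i = 0; i < num_resources; ++i)
    appended.push_back(results[num_old_arguments + i]);

  // New form: take the same range in one call.
  assert(appended == Slice(results, num_old_arguments, num_resources));
  return 0;
}
```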
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/convert_launch_func_to_tf_call.cc b/tensorflow/compiler/mlir/tensorflow/transforms/convert_launch_func_to_tf_call.cc index 136137121da..96dc678dd1c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/convert_launch_func_to_tf_call.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/convert_launch_func_to_tf_call.cc @@ -52,7 +52,7 @@ void ConvertLaunchFuncToTFCallPass::runOnOperation() { module.walk([&](tf_device::LaunchFuncOp launch) { OpBuilder builder(launch); auto call_op = builder.create( - module.getLoc(), launch.getResultTypes(), launch.operands(), + module.getLoc(), launch.getResultTypes(), launch.getOperands(), SymbolRefAttr::get(builder.getContext(), launch.getFunc()), /*config=*/builder.getStringAttr(""), /*config_proto=*/builder.getStringAttr(""), diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/convert_session_initializer_to_function.cc b/tensorflow/compiler/mlir/tensorflow/transforms/convert_session_initializer_to_function.cc new file mode 100644 index 00000000000..491cb4e46a7 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/convert_session_initializer_to_function.cc @@ -0,0 +1,100 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h" + +namespace mlir { +namespace tf_saved_model { + +namespace { + +#define GEN_PASS_DEF_CONVERTSESSIONINITIALIZERTOFUNCTIONPASS +#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_savedmodel_passes.h.inc" + +struct ConvertSessionInitializerToFunctionPass + : public impl::ConvertSessionInitializerToFunctionPassBase< + ConvertSessionInitializerToFunctionPass> { + void runOnOperation() override; +}; + +void ConvertSessionInitializerToFunctionPass::runOnOperation() { + ModuleOp module = getOperation(); + auto session_initializer = tf_saved_model::GetSessionInitializerOp(module); + if (!session_initializer) return; + + OpBuilder builder(session_initializer); + const char *name = "session_initializer"; + + // In the (unlikely) case of there already being a session initializer + // function, bail out. 
+ if (SymbolTable::lookupSymbolIn(module, name)) { + module->emitWarning("session_initializer function already exists"); + session_initializer.erase(); + return; + } + + auto init = builder.create( + module.getLoc(), name, + FunctionType::get(module.getContext(), /*inputs=*/{}, /*results=*/{})); + + // Make savedmodel verification happy. + init->setAttr("tf_saved_model.exported_names", + builder.getStrArrayAttr({name})); + + builder.setInsertionPointToStart(init.addEntryBlock()); + + for (func::FuncOp func : tf_saved_model::GetInitializerFunctions(module)) { + if (func.getNumArguments() != 0) { + session_initializer->emitWarning( + "encountered session initializers with arguments"); + continue; + } + + // Since we're now calling this function, savedmodel verification + // needs it to be private. + func.setVisibility(SymbolTable::Visibility::Private); + func->removeAttr("tf_saved_model.exported_names"); + + ArrayRef args; + builder.create(session_initializer.getLoc(), + func.getFunctionType().getResults(), + func.getSymName(), args); + } + builder.create(session_initializer.getLoc()); + + session_initializer.erase(); +} + +} // namespace + +std::unique_ptr> +CreateConvertSessionInitializerToFunctionPass() { + return std::make_unique(); +} + +} // namespace tf_saved_model +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/convert_tf_control_flow_to_scf.cc b/tensorflow/compiler/mlir/tensorflow/transforms/convert_tf_control_flow_to_scf.cc index a39cf338ac0..a3266f58718 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/convert_tf_control_flow_to_scf.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/convert_tf_control_flow_to_scf.cc @@ -87,7 +87,8 @@ class ConvertIfRegionOp : public OpRewritePattern { // condition of the `tf.IfRegion` op is a 0-D tensor of 1-bit signless // integers. Thus, we use the `tensor.extract` op to compute the condition // of `scf.if` from that of `tf.IfRegion`. - auto scf_if_condition = rewriter.create(loc, op.cond()); + auto scf_if_condition = + rewriter.create(loc, op.getCond()); TypeRange tf_if_region_return_type = op.getResultTypes(); @@ -96,8 +97,8 @@ class ConvertIfRegionOp : public OpRewritePattern { rewriter.create(loc, tf_if_region_return_type, scf_if_condition, /*withElseRegion=*/true); - Region& then_region = op.then_branch(); - Region& else_region = op.else_branch(); + Region& then_region = op.getThenBranch(); + Region& else_region = op.getElseBranch(); // Create the `then` and `else` regions of the `scf.if` op. createScfThenOrElse(then_region, scf_if_op.getThenRegion(), @@ -140,7 +141,7 @@ class ConvertWhileRegionOp : public OpRewritePattern { return cond_or_body_terminator; }; - ValueRange opInput = op.input(); + ValueRange opInput = op.getInput(); TypeRange scf_block_arguments_type = opInput.getType(); // Create the `scf.while` op. @@ -155,7 +156,7 @@ class ConvertWhileRegionOp : public OpRewritePattern { // `tensor.extract` op to compute the input of `scf.condition`. rewriter.createBlock(&scf_while_op.getBefore()); Operation* cond_terminator = - createScfCondOrBody(op.cond(), scf_while_op.getBefore(), + createScfCondOrBody(op.getCond(), scf_while_op.getBefore(), scf_block_arguments_type, rewriter); auto scf_condition_input = rewriter.create( cond_terminator->getLoc(), cond_terminator->getOperand(0)); @@ -167,8 +168,9 @@ class ConvertWhileRegionOp : public OpRewritePattern { // the terminator). Note that the arguments' type of this block is kept as // `opInput`'s type. 
rewriter.createBlock(&scf_while_op.getAfter()); - Operation* body_terminator = createScfCondOrBody( - op.body(), scf_while_op.getAfter(), scf_block_arguments_type, rewriter); + Operation* body_terminator = + createScfCondOrBody(op.getBody(), scf_while_op.getAfter(), + scf_block_arguments_type, rewriter); rewriter.replaceOpWithNewOp(body_terminator, body_terminator->getOperands()); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decode_attributes_hook.cc b/tensorflow/compiler/mlir/tensorflow/transforms/decode_attributes_hook.cc index 0ea6d309be8..8d91ef5b311 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decode_attributes_hook.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decode_attributes_hook.cc @@ -25,9 +25,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/compiler/xla/stream_executor/stream_executor.h" #include "tensorflow/core/framework/logging.h" +#include "tensorflow/tsl/platform/statusor.h" namespace mlir { namespace { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_reduce_dataset.cc b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_reduce_dataset.cc index 200d8b9ef58..8e89f3988dd 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_reduce_dataset.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_reduce_dataset.cc @@ -86,7 +86,7 @@ AnonymousIteratorV3Op CreateIterator(OpBuilder builder, /*output_types=*/builder.getArrayAttr(type_attrs), /*shape_types=*/builder.getArrayAttr(shape_attrs)); builder.create(reduce_dataset.getLoc(), - reduce_dataset.input_dataset(), + reduce_dataset.getInputDataset(), anonymous_iterator.getResult()); return anonymous_iterator; } @@ -127,7 +127,7 @@ WhileRegionOp CreateDatasetWhile(OpBuilder builder, // condition of whether to continue to next iteration. void PopulateDatasetWhileCond(OpBuilder builder, WhileRegionOp dataset_while, Location loc) { - auto& cond_region = dataset_while.cond(); + auto& cond_region = dataset_while.getCond(); Block* cond_block = builder.createBlock(&cond_region); auto while_input_types = dataset_while.getOperandTypes(); cond_block->addArguments( @@ -160,7 +160,7 @@ IfRegionOp CreateOptionalDatasetIf( // parallelization. dataset_if->setAttr("_lower_using_switch_merge", builder.getBoolAttr(true)); // Empty else branch, if there is no more data, do nothing. - auto& else_branch = dataset_if.else_branch(); + auto& else_branch = dataset_if.getElseBranch(); else_branch.push_back(new Block); builder.setInsertionPointToEnd(&else_branch.front()); // Return only the state variables from the body arguments. @@ -172,7 +172,7 @@ IfRegionOp CreateOptionalDatasetIf( /*operands=*/else_returns); // Then branch gets the data and calls the reduce_function. - auto& then_branch = dataset_if.then_branch(); + auto& then_branch = dataset_if.getThenBranch(); then_branch.push_back(new Block); builder.setInsertionPointToEnd(&then_branch.front()); // Add iterator operational data access inside if. 
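For context on the DecomposeReduceDataset hunks above and below: the pass lowers a ReduceDataset op into an explicit iterator loop, i.e. a while region that fetches the next element as an optional, and an if region whose then branch folds the element into the reduction state while the else branch passes the state through; the has-value flag also drives whether the loop continues. A self-contained sketch of that control-flow shape in ordinary C++ (illustrative only, not code from the pass):

```cpp
// Illustration only: the loop structure the decomposed ReduceDataset mirrors --
// get-next-as-optional, test has_value, then-branch reduces, else-branch
// forwards the state and ends the iteration.
#include <cstddef>
#include <iostream>
#include <optional>
#include <vector>

int main() {
  const std::vector<int> dataset = {1, 2, 3, 4};
  std::size_t next = 0;
  // Stand-in for IteratorGetNextAsOptional: empty once the data is exhausted.
  auto get_next = [&]() -> std::optional<int> {
    if (next == dataset.size()) return std::nullopt;
    return dataset[next++];
  };

  int state = 0;           // Reduction state threaded through the while body.
  bool keep_going = true;  // Continuation flag checked by the loop condition.
  while (keep_going) {
    std::optional<int> element = get_next();
    if (element.has_value()) {
      state += *element;   // "then" branch: apply the reduce function.
    } else {
      keep_going = false;  // "else" branch: state passes through unchanged.
    }
  }
  std::cout << "reduced state: " << state << "\n";  // 10
  return 0;
}
```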
@@ -220,14 +220,14 @@ void PopulateDatasetWhileBody(OpBuilder builder, ReduceDatasetOp reduce_dataset, ArrayRef dataset_types) { const Location loc = reduce_dataset.getLoc(); auto while_input_types = dataset_while.getOperandTypes(); - auto& body_region = dataset_while.body(); + auto& body_region = dataset_while.getBody(); Block* body_block = builder.createBlock(&body_region); auto body_arguments = body_block->addArguments( while_input_types, SmallVector(while_input_types.size(), loc)); auto get_next = builder.create( loc, RankedTensorType::get({}, builder.getType()), - anonymous_iterator.getResult(), anonymous_iterator.output_types(), - anonymous_iterator.output_shapes()); + anonymous_iterator.getResult(), anonymous_iterator.getOutputTypes(), + anonymous_iterator.getOutputShapes()); auto optional_has_value = builder.create( loc, RankedTensorType::get({}, builder.getI1Type()), get_next.getResult()); @@ -279,7 +279,7 @@ LogicalResult DecomposeReduceDatasetInFunction(FuncOp function) { // complexity = # ReduceDataset ops x # of functions in module. func::FuncOp reduce_func = function->getParentOfType().lookupSymbol( - reduce_dataset.f()); + reduce_dataset.getF()); // The reduce function arguments consist of three part in this order: // 1. Reduction state inputs. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.cc b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.cc index e5d72fb0174..0a205859957 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.cc @@ -93,7 +93,7 @@ class DecomposeRngReadAndSkipOp : public RewritePattern { auto rng_op = cast(op); DenseIntElementsAttr alg_constant; - if (!matchPattern(rng_op.alg(), m_Constant(&alg_constant))) { + if (!matchPattern(rng_op.getAlg(), m_Constant(&alg_constant))) { return rewriter.notifyMatchFailure( op, "unable to determine algorithm statically"); } @@ -123,7 +123,7 @@ class DecomposeRngReadAndSkipOp : public RewritePattern { return rewriter.notifyMatchFailure(op, "unexpected op type"); } - if (!HasResourceSubtype(rng_op.resource())) { + if (!HasResourceSubtype(rng_op.getResource())) { return rewriter.notifyMatchFailure(op, "missing resource subtype"); } @@ -131,7 +131,7 @@ class DecomposeRngReadAndSkipOp : public RewritePattern { int state_size = counter_size + tensorflow::RNG_KEY_SIZE; RankedTensorType res_type = RankedTensorType::get({state_size}, state_element_type); - if (res_type != GetResourceSubtype(rng_op.resource())) { + if (res_type != GetResourceSubtype(rng_op.getResource())) { return rewriter.notifyMatchFailure(op, "unexpected resource subtype"); } @@ -139,7 +139,7 @@ class DecomposeRngReadAndSkipOp : public RewritePattern { // Read the state value from the resource. Value state = - rewriter.create(loc, res_type, rng_op.resource()); + rewriter.create(loc, res_type, rng_op.getResource()); // Extract the key and counter from the state. RankedTensorType word_type = RankedTensorType::get({}, state_element_type); @@ -157,7 +157,7 @@ class DecomposeRngReadAndSkipOp : public RewritePattern { RankedTensorType u64_scalar = RankedTensorType::get({}, u64); Value step_size = rewriter.create(loc, GetScalarOfType(u64, 256)); Value increment = - rewriter.create(loc, u64_scalar, step_size, rng_op.delta()); + rewriter.create(loc, u64_scalar, step_size, rng_op.getDelta()); // Increment the counter. 
SmallVector pack_args; @@ -178,7 +178,7 @@ class DecomposeRngReadAndSkipOp : public RewritePattern { // Save the new state value to the resource. pack_args.push_back(key); Value new_state = rewriter.create(loc, res_type, pack_args); - rewriter.create(loc, rng_op.resource(), new_state); + rewriter.create(loc, rng_op.getResource(), new_state); // Pad the original state as necessary to fill the output shape. int pad = tensorflow::RNG_MAX_COUNTER_SIZE - counter_size; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops_pass.cc index b44aa1526e1..617e355fbfe 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops_pass.cc @@ -93,10 +93,13 @@ LogicalResult ApplyPatternsLocallyUntilConverged( changed = false; auto walk_result = op_with_regions->walk([&patterns, &changed](Operation* operation) { - bool op_changed; - if (failed(applyOpPatternsAndFold(operation, patterns, &op_changed))) + GreedyRewriteConfig config; + config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps; + bool op_erased; + if (failed(applyOpPatternsAndFold(operation, patterns, config, + &op_erased))) return WalkResult::interrupt(); - changed |= op_changed; + changed |= op_erased; return WalkResult::advance(); }); if (walk_result.wasInterrupted()) return failure(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/device_index_selector.cc b/tensorflow/compiler/mlir/tensorflow/transforms/device_index_selector.cc index 008858ca692..e0467bea424 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/device_index_selector.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/device_index_selector.cc @@ -63,7 +63,7 @@ void DeviceIndexSelector::runOnOperation() { // future. OpBuilder b(op); RankedTensorType type = RankedTensorType::get({}, b.getIntegerType(32)); - int index = op.device_names().size(); + int index = op.getDeviceNames().size(); for (auto use : op.getOperation()->getUsers()) { // Skip if it doesn't feed into case. Alternatively this could always // return the CPU device index if it exists. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc index 9146951f7d5..df3b7fa1d93 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc @@ -19,9 +19,11 @@ limitations under the License. #include #include #include +#include #include #include #include +#include #include "absl/memory/memory.h" #include "llvm/ADT/ArrayRef.h" @@ -47,6 +49,7 @@ limitations under the License. 
#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h" #include "tensorflow/core/util/matmul_bcast.h" @@ -99,7 +102,8 @@ TF::TransposeOp createTransposeOp(Value value, Location loc, TF::ReshapeOp createReshapeOp(Value value, ArrayRef shape, Type element_type, Location loc, PatternRewriter* rewriter) { - auto shape_tensor = createI64ConstantOp(shape, loc, rewriter); + auto shape_tensor = createI64ConstantOp( + tensorflow::ConvertMlirShapeToTF(shape), loc, rewriter); Type resultType = RankedTensorType::get(shape, element_type); return rewriter->create(loc, resultType, /*tensor=*/value, /*shape=*/shape_tensor); @@ -208,11 +212,11 @@ llvm::Optional> EquationToMap( for (int64_t i = 0; i < equation.size(); ++i) { if (!std::isalpha(equation[i])) { // Unsupported character in the equation. - return llvm::None; + return std::nullopt; } if (map.count(equation[i])) { // Duplicate character in the equation. - return llvm::None; + return std::nullopt; } map.try_emplace(equation[i], i); } @@ -243,11 +247,11 @@ llvm::Optional> GetAvailableLabels( labels.remove(label); ++lhs_count; } else if (label == '.') { - if (!is_start_of_ellipsis(lhs, i)) return llvm::None; + if (!is_start_of_ellipsis(lhs, i)) return std::nullopt; i += 2; } else { // Unsupported character in the equation. - return llvm::None; + return std::nullopt; } } *lhs_named_label_count = lhs_count; @@ -260,11 +264,11 @@ llvm::Optional> GetAvailableLabels( labels.remove(label); ++rhs_count; } else if (label == '.') { - if (!is_start_of_ellipsis(rhs, i)) return llvm::None; + if (!is_start_of_ellipsis(rhs, i)) return std::nullopt; i += 2; } else { // Unsupported character in the equation. - return llvm::None; + return std::nullopt; } } @@ -347,38 +351,38 @@ llvm::Optional GetEinsumDimensionNumbers( llvm::StringRef lhs_rhs; llvm::StringRef out; std::tie(lhs_rhs, out) = equation.split("->"); - if (lhs_rhs.empty() || out.empty()) return llvm::None; + if (lhs_rhs.empty() || out.empty()) return std::nullopt; llvm::StringRef lhs; llvm::StringRef rhs; std::tie(lhs, rhs) = lhs_rhs.split(','); - if (lhs.empty() || rhs.empty()) return llvm::None; + if (lhs.empty() || rhs.empty()) return std::nullopt; // Try to flatten the "..." if possible. 
int lhs_named_label, rhs_named_label; auto available_labels = GetAvailableLabels(lhs, rhs, &lhs_named_label, &rhs_named_label); - if (!available_labels.has_value()) return llvm::None; + if (!available_labels.has_value()) return std::nullopt; auto flattended_labels = FlattenEllipsis(lhs, lhs_named_label, rhs, rhs_named_label, out, lhs_ty, - rhs_ty, available_labels.getValue()); + rhs_ty, available_labels.value()); lhs = std::get<0>(flattended_labels); rhs = std::get<1>(flattended_labels); out = std::get<2>(flattended_labels); auto lhs_map_or = EquationToMap(lhs); - if (!lhs_map_or.has_value()) return llvm::None; - auto lhs_map = lhs_map_or.getValue(); + if (!lhs_map_or.has_value()) return std::nullopt; + auto lhs_map = lhs_map_or.value(); auto rhs_map_or = EquationToMap(rhs); - if (!rhs_map_or.has_value()) return llvm::None; - auto rhs_map = rhs_map_or.getValue(); + if (!rhs_map_or.has_value()) return std::nullopt; + auto rhs_map = rhs_map_or.value(); auto out_map_or = EquationToMap(out); - if (!out_map_or.has_value()) return llvm::None; - auto out_map = out_map_or.getValue(); + if (!out_map_or.has_value()) return std::nullopt; + auto out_map = out_map_or.value(); EinsumDimensionNumbers dnums; for (int64_t i = 0, e = lhs.size(); i < e; ++i) { @@ -410,7 +414,7 @@ llvm::Optional GetEinsumDimensionNumbers( auto rhs_index = rhs_map.find(out[i]); if (lhs_index == lhs_map.end() && rhs_index == rhs_map.end()) { // out only isn't supported - return llvm::None; + return std::nullopt; } } return dnums; @@ -487,7 +491,7 @@ inline int64_t ProdShapeWithIndexInTuple( int64_t prod_shape = 1; for (auto index_tuple : index_tuples) { const int64_t shape_i = shape[std::get(index_tuple)]; - if (shape_i == -1) return -1; + if (ShapedType::isDynamic(shape_i)) return ShapedType::kDynamic; prod_shape *= shape_i; } return prod_shape; @@ -629,7 +633,7 @@ LogicalResult rewriteToBatchMatmul(TF::EinsumOp op, PatternRewriter& rewriter) { if (!dnums.lhs.empty() || !dnums.rhs.empty()) return failure(); - auto inputs = op.inputs(); + auto inputs = op.getInputs(); if (inputs.size() != 2) return failure(); Value lhs = inputs.front(); Value rhs = inputs.back(); @@ -712,8 +716,9 @@ LogicalResult ConvertTFEinsumOp::matchAndRewrite( // dynamic dimension is always supported. If there are two or more dynamic // dimensions, it is supported if they only exist in a single component // among: L0,...,Ln R0,...,Rn or C0,...,Cn. 
- if (const auto dnums_or = GetEinsumDimensionNumbers(op.equation(), lhs, rhs)) - return rewriteToBatchMatmul(op, dnums_or.getValue(), rewriter); + if (const auto dnums_or = + GetEinsumDimensionNumbers(op.getEquation(), lhs, rhs)) + return rewriteToBatchMatmul(op, dnums_or.value(), rewriter); return rewriter.notifyMatchFailure(op, "unsupported einsum lowering"); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc index 122c1b8279c..c45fe527d0f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc @@ -57,7 +57,7 @@ void ExecutorTPUV1IslandInliningPass::runOnOperation() { InlinerInterface inliner(&getContext()); auto walk_result = getOperation().walk([&](TF::PartitionedCallOp call_op) { - if (!call_op.f().getRootReference().getValue().startswith(kNestedModule)) + if (!call_op.getF().getRootReference().getValue().startswith(kNestedModule)) return WalkResult::advance(); // This is a call we need to inline! LLVM_DEBUG(llvm::dbgs() diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_island_coarsening.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_island_coarsening.cc index 5f9fdbd5e64..e4224042d2f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_island_coarsening.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_island_coarsening.cc @@ -18,11 +18,11 @@ limitations under the License. #include #include +#include #include #include #include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" @@ -71,7 +71,7 @@ struct TpuV1BridgeExecutorIslandCoarsening }; // Returns name of TPU cluster, if op belongs to a TPU cluster. Otherwise, -// returns `llvm::None`. +// returns `std::nullopt`. llvm::Optional GetTpuClusterName(Operation* op) { if (auto tpu_status = op->getAttrOfType(kTpuStatusAttr)) { // Borrow cluster name from TPU status (for `TPUCompilationResult` op). @@ -80,7 +80,7 @@ llvm::Optional GetTpuClusterName(Operation* op) { auto device_type = op->getAttrOfType(TF::kCompileDeviceTypeAttr); if (!device_type || device_type.getValue() != TF::kTpuDevice) { // Op does not belong to a TPU cluster. - return llvm::None; + return std::nullopt; } // Op belongs to a TPU cluster. if (auto replication_info = @@ -196,7 +196,7 @@ void CollectCandidateIslands( GetTpuClusterName(&candidate_wrapped_op); llvm::StringRef candidate_cluster_name; if (result.has_value()) { - candidate_cluster_name = result.getValue(); + candidate_cluster_name = result.value(); } else if (is_op_calling_func_for_cluster(cluster_name, &candidate_wrapped_op)) { candidate_cluster_name = cluster_name; @@ -298,7 +298,7 @@ LogicalResult MergeIsland( llvm::Optional result = GetTpuClusterName(&wrapped_op); if (!result.has_value()) return success(); - llvm::StringRef cluster_name = result.getValue(); + llvm::StringRef cluster_name = result.value(); // We found a _replication_info, let's build an island for the full cluster! 
LLVM_DEBUG(llvm::dbgs() << "Processing candidate island: " @@ -396,7 +396,7 @@ bool is_valid_special_tpu_op( bool op_has_inconsistent_cluster_name = wrapped_op_cluster_name.has_value() && - !wrapped_op_cluster_name.getValue().equals(cluster_name); + !wrapped_op_cluster_name.value().equals(cluster_name); if (op_has_inconsistent_cluster_name) { return false; @@ -511,7 +511,7 @@ LogicalResult CollectSpecialTpuOps( llvm::Optional result = GetTpuClusterName(&wrapped_op); if (!result.has_value()) return success(); - llvm::StringRef cluster_name = result.getValue(); + llvm::StringRef cluster_name = result.value(); visited_wrapped_ops.insert(&wrapped_op); @@ -547,12 +547,12 @@ bool ExcludeIdentityOp(llvm::SmallDenseSet& tpu_ops, for (IslandOp wrapper : ops) { Operation* wrapped_op = &wrapper.GetBody().front(); auto cluster_name = GetTpuClusterName(wrapped_op); - if (cluster_name.hasValue() && - cluster_name.getValue() != target_cluster_name) { + if (cluster_name.has_value() && + cluster_name.value() != target_cluster_name) { tpu_ops.erase(iter); return true; } - if (!cluster_name.hasValue() && + if (!cluster_name.has_value() && !tpu_ops.count(wrapper.getOperation())) { tpu_ops.erase(iter); return true; @@ -590,10 +590,10 @@ void EraseIdentityWithNoReplicationInfo(Block& graph_body) { if (!island || island.WrapsSingleOp()) continue; for (Operation& op : llvm::make_early_inc_range(island.GetBody())) { llvm::Optional cluster_name = GetTpuClusterName(&op); - if (cluster_name.hasValue()) continue; + if (cluster_name.has_value()) continue; if (auto identity_op = llvm::dyn_cast_or_null(op)) { - auto identity_input = identity_op.input(); - auto output = identity_op.output(); + auto identity_input = identity_op.getInput(); + auto output = identity_op.getOutput(); output.replaceAllUsesWith(identity_input); identity_op.erase(); } @@ -611,7 +611,7 @@ void TpuV1BridgeExecutorIslandCoarsening::runOnOperation() { func_op.walk([&](Operation* op) { llvm::Optional cluster_name_opt = GetTpuClusterName(op); if (cluster_name_opt.has_value()) { - tpu_funcs[cluster_name_opt.getValue()].insert(func_op); + tpu_funcs[cluster_name_opt.value()].insert(func_op); } }); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc b/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc index ea44b748e7b..2f0e7dbdb11 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc @@ -102,14 +102,14 @@ LogicalResult ConvertResultsBroadcastableShapeOp::RewriteBatchMatMulV2Op( } const int x_row = - matmul_op.adj_x() ? shape_x.back() : *(shape_x.rbegin() + 1); + matmul_op.getAdjX() ? shape_x.back() : *(shape_x.rbegin() + 1); const int x_col = - !matmul_op.adj_x() ? shape_x.back() : *(shape_x.rbegin() + 1); + !matmul_op.getAdjX() ? shape_x.back() : *(shape_x.rbegin() + 1); const int y_row = - matmul_op.adj_y() ? shape_y.back() : *(shape_y.rbegin() + 1); + matmul_op.getAdjY() ? shape_y.back() : *(shape_y.rbegin() + 1); const int y_col = - !matmul_op.adj_y() ? shape_y.back() : *(shape_y.rbegin() + 1); + !matmul_op.getAdjY() ? shape_y.back() : *(shape_y.rbegin() + 1); // Checks that matrix multiply can perform a valid contraction. 
if (x_col != y_row) { @@ -129,7 +129,7 @@ template LogicalResult ConvertResultsBroadcastableShapeOp::RewriteEqOp( Operation* op, PatternRewriter& rewriter) const { auto eq_op = llvm::dyn_cast_or_null(op); - if (eq_op && eq_op.incompatible_shape_error()) + if (eq_op && eq_op.getIncompatibleShapeError()) return RewriteOp(op, rewriter, OpTrait::util::getBroadcastedShape); return failure(); } @@ -156,7 +156,7 @@ LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp( // Check that the operand of the broadcast has fully defined shape. auto broadcast_arg_type = - broadcast.input().getType().dyn_cast_or_null(); + broadcast.getInput().getType().dyn_cast_or_null(); if (!broadcast_arg_type || !broadcast_arg_type.hasStaticShape()) continue; // Check that the other argument has fully defined shape. @@ -184,7 +184,7 @@ LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp( // Update the operand of the op to be the operand of the broadcast. rewriter.updateRootInPlace( - op, [&]() { op->getOpOperand(i).set(broadcast.input()); }); + op, [&]() { op->getOpOperand(i).set(broadcast.getInput()); }); changed = true; } return success(changed); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc b/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc index a0b1f0cac59..ad5edd1b2dc 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc @@ -20,6 +20,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" +#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" // from @llvm-project #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" // from @llvm-project #include "mlir/Analysis/DataFlow/SparseAnalysis.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -58,6 +59,7 @@ void FreezeGlobalTensorsPass::runOnOperation() { DataFlowSolver solver; solver.load(); + solver.load(); solver.load(); if (failed(solver.initializeAndRun(module))) return signalPassFailure(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/freeze_saved_model_assets.cc b/tensorflow/compiler/mlir/tensorflow/transforms/freeze_saved_model_assets.cc index 47d40e0651f..daaf9df7400 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/freeze_saved_model_assets.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/freeze_saved_model_assets.cc @@ -48,7 +48,6 @@ struct FreezeAssetsPass : public impl::FreezeAssetsPassBase { void runOnOperation() override; private: - // TODO(team): should be a pass option. std::string saved_model_dir; }; @@ -101,9 +100,9 @@ void FreezeAssetsPass::runOnOperation() { // asset filepath. 
builder.setInsertionPoint(init_op); builder.create( - init_op.getLoc(), init_op.table_handle(), const_op.getResult(), - init_op.key_index(), init_op.value_index(), init_op.vocab_size(), - init_op.delimiter()); + init_op.getLoc(), init_op.getTableHandle(), const_op.getResult(), + init_op.getKeyIndex(), init_op.getValueIndex(), + init_op.getVocabSize(), init_op.getDelimiter()); init_op.erase(); } } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc index 09dbcb103ed..65d9a288d56 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc @@ -47,9 +47,9 @@ struct FunctionalControlFlowToCFG // Lowers a general tensor argument that is used as a condition to a functional // control flow op into an i1 value. static Value LowerCondition(Location loc, Value value, OpBuilder* builder) { - auto zero_d = builder->create(loc, value); - auto scalar = builder->create(loc, zero_d); - return scalar.getResult(); + Value zero_d = builder->create(loc, value); + Value scalar = builder->create(loc, zero_d); + return scalar; } // Calls the function `fn` with arguments provided by the given function and @@ -143,7 +143,7 @@ static LogicalResult LowerIfOp(IfOp op) { OpBuilder builder(op_inst); // Lower the condition to a boolean value (i1). - Value cond_i1 = LowerCondition(loc, op.cond(), &builder); + Value cond_i1 = LowerCondition(loc, op.getCond(), &builder); if (!cond_i1) return failure(); // Split the basic block before the 'if'. The new dest will be our merge diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc index ff9fe343930..bc64b6ed975 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc @@ -98,19 +98,19 @@ Value ConvertConditionToBoolean(Operation* op, Value cond) { // Transform a functional IfOp to a region based IfRegionOp. 
LogicalResult ConvertIfOp(IfOp if_op) { - Value cond = ConvertConditionToBoolean(if_op, if_op.cond()); + Value cond = ConvertConditionToBoolean(if_op, if_op.getCond()); OpBuilder builder(if_op); auto if_region = builder.create( - if_op.getLoc(), if_op.getResultTypes(), cond, if_op.is_stateless(), + if_op.getLoc(), if_op.getResultTypes(), cond, if_op.getIsStateless(), builder.getStringAttr(if_op.then_function().getName()), builder.getStringAttr(if_op.else_function().getName())); CopyDeviceAndUnderscoredAttributes(if_op, if_region); CreateCall(if_op, if_op.then_function(), - /*caller_region=*/if_region.then_branch(), if_op.input(), + /*caller_region=*/if_region.getThenBranch(), if_op.getInput(), /*use_region_args=*/false); CreateCall(if_op, if_op.else_function(), - /*caller_region=*/if_region.else_branch(), if_op.input(), + /*caller_region=*/if_region.getElseBranch(), if_op.getInput(), /*use_region_args=*/false); if_op.replaceAllUsesWith(if_region.getResults()); if_op.erase(); @@ -119,21 +119,21 @@ LogicalResult ConvertIfOp(IfOp if_op) { LogicalResult ConvertWhileOp(WhileOp while_op) { auto while_region = OpBuilder(while_op).create( - while_op.getLoc(), while_op.getResultTypes(), while_op.input(), - while_op.parallel_iterations(), while_op.is_stateless(), - while_op.shape_invariant()); + while_op.getLoc(), while_op.getResultTypes(), while_op.getInput(), + while_op.getParallelIterations(), while_op.getIsStateless(), + while_op.getShapeInvariant()); CopyDeviceAndUnderscoredAttributes(while_op, while_region); YieldOp cond_yield = CreateCall(while_op, while_op.cond_function(), - /*caller_region=*/while_region.cond(), while_op.input(), + /*caller_region=*/while_region.getCond(), while_op.getInput(), /*use_region_args=*/true); Value i1_cond = ConvertConditionToBoolean(cond_yield, cond_yield.getOperand(0)); cond_yield.setOperand(0, i1_cond); CreateCall(while_op, while_op.body_function(), - /*caller_region=*/while_region.body(), while_op.input(), + /*caller_region=*/while_region.getBody(), while_op.getInput(), /*use_region_args=*/true); while_op.replaceAllUsesWith(while_region.getResults()); while_op.erase(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/fused_kernel_matcher.cc b/tensorflow/compiler/mlir/tensorflow/transforms/fused_kernel_matcher.cc index 2db53e7fa6e..185daa50a86 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/fused_kernel_matcher.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/fused_kernel_matcher.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include +#include #include #include "llvm/ADT/StringRef.h" @@ -82,7 +83,7 @@ BiasAddOp GetBiasAdd(Value op) { for (auto &use : op.getUses()) { auto bias_add = dyn_cast_or_null(use.getOwner()); // If it's a BiasAdd, check that the conv op is the first input. - if (bias_add && bias_add.value() == op) return bias_add; + if (bias_add && bias_add.getValue() == op) return bias_add; } // No BiasAddOps found among uses. return BiasAddOp(); @@ -162,7 +163,7 @@ class FuseContractionWithBiasAdd : public OpRewritePattern { // If there is an activation, only fuse it if this is the only op to use the // result of the BiasAdd. - bool fuse_activation = activation && bias_add.output().hasOneUse(); + bool fuse_activation = activation && bias_add.getOutput().hasOneUse(); Type result_type; // Include info about the activation function if applicable. @@ -181,7 +182,7 @@ class FuseContractionWithBiasAdd : public OpRewritePattern { // with `bias` from the BiasAddOp appended. 
SmallVector operands(contraction.operand_begin(), contraction.operand_end()); - operands.push_back(bias_add.bias()); + operands.push_back(bias_add.getBias()); // The fused contraction has the same attributes as the original // contraction, with two additions: the list of ops which have been fused @@ -241,15 +242,15 @@ const char kDeviceGpu[] = "GPU"; llvm::Optional GetDevice(mlir::Operation *op) { mlir::StringAttr device = op->getAttrOfType(kDeviceAttr); if (!device || device.getValue().empty()) { - return llvm::None; + return std::nullopt; } const std::string device_name = device.str(); tensorflow::DeviceNameUtils::ParsedName parsed_name; if (!tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed_name)) { - return llvm::None; + return std::nullopt; } if (!parsed_name.has_type) { - return llvm::None; + return std::nullopt; } return parsed_name.type; } @@ -273,10 +274,11 @@ class FuseConv2DBiasAdd bool AreFuseCompatible(Conv2DOp conv, BiasAddOp bias_add, PatternRewriter &rewriter) const override { // Verify that the data formats match and are valid for fusion. - if (conv.data_format() != bias_add.data_format()) { + if (conv.getDataFormat() != bias_add.getDataFormat()) { (void)rewriter.notifyMatchFailure(conv, [&](Diagnostic &diag) { diag << "data format does not match Conv2D data format (" - << bias_add.data_format() << " vs " << conv.data_format() << ")"; + << bias_add.getDataFormat() << " vs " << conv.getDataFormat() + << ")"; }); return false; } @@ -299,7 +301,7 @@ class FuseConv2DBiasAdd if (IsGpuDevice(conv)) { auto activation = GetActivation(bias_add); if (!activation || activation->getName().stripDialect() != "Relu" || - !bias_add.output().hasOneUse()) { + !bias_add.getOutput().hasOneUse()) { (void)rewriter.notifyMatchFailure(conv, [&](Diagnostic &diag) { diag << "GPU only supports Conv2D+BiasAdd+Relu fusion"; }); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc b/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc index f3c0f96e2ad..98e0f3b3454 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc @@ -67,7 +67,7 @@ struct ReluToFusedBatchNorm : public OpRewritePattern { LogicalResult matchAndRewrite(ReluOp relu_op, PatternRewriter &rewriter) const override { - Operation *relu_input = relu_op.features().getDefiningOp(); + Operation *relu_input = relu_op.getFeatures().getDefiningOp(); if (!relu_input) return failure(); auto batch_norm = dyn_cast_or_null(relu_input); AddV2Op add_op; @@ -79,20 +79,20 @@ struct ReluToFusedBatchNorm : public OpRewritePattern { if (!add_op) return failure(); batch_norm = - dyn_cast_or_null(add_op.x().getDefiningOp()); + dyn_cast_or_null(add_op.getX().getDefiningOp()); if (batch_norm) { - side_input = add_op.y(); + side_input = add_op.getY(); } else { // Didn't get a FusedBatchNorm on the LHS of the AddV2, try the RHS. 
batch_norm = - dyn_cast_or_null(add_op.y().getDefiningOp()); + dyn_cast_or_null(add_op.getY().getDefiningOp()); if (!batch_norm) return failure(); - side_input = add_op.x(); + side_input = add_op.getX(); } } assert(batch_norm); - if (batch_norm.is_training()) return failure(); - if (!batch_norm.y().hasOneUse()) return failure(); + if (batch_norm.getIsTraining()) return failure(); + if (!batch_norm.getY().hasOneUse()) return failure(); // Build the newly fused operation to replace the batch norm OperationState state(batch_norm.getLoc(), @@ -105,7 +105,7 @@ struct ReluToFusedBatchNorm : public OpRewritePattern { rewriter.replaceOp(batch_norm, op->getResults()); // Depending on the case, we may fuse the add, the relu, or both. - if (!add_op || add_op.z().hasOneUse()) { + if (!add_op || add_op.getZ().hasOneUse()) { // We fuse the Relu only if the add has a single use, otherwise we only // fuse the add itself. op->setAttr("activation_mode", rewriter.getStringAttr("Relu")); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/hoist_loop_invariant.cc b/tensorflow/compiler/mlir/tensorflow/transforms/hoist_loop_invariant.cc index 7691bd89235..9c03ce5def2 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/hoist_loop_invariant.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/hoist_loop_invariant.cc @@ -20,11 +20,11 @@ limitations under the License. #include "llvm/ADT/DenseSet.h" #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project -#include "mlir/Transforms/SideEffectUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" namespace mlir { @@ -71,7 +71,7 @@ bool ResourceOpCanBeHoisted( auto read_var_op = llvm::dyn_cast(op); if (!read_var_op) return false; auto var_handle_op = llvm::dyn_cast_or_null( - read_var_op.resource().getDefiningOp()); + read_var_op.getResource().getDefiningOp()); if (!var_handle_op) return false; return read_only_vars.contains(GetResourceHandle(var_handle_op)); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/hoist_replicate_invariant_resource_writes.cc b/tensorflow/compiler/mlir/tensorflow/transforms/hoist_replicate_invariant_resource_writes.cc index d81cbcbbefc..ffc650cb907 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/hoist_replicate_invariant_resource_writes.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/hoist_replicate_invariant_resource_writes.cc @@ -71,9 +71,9 @@ void MoveTailWritesAfterReplicate( // returned. 
auto new_result_types = llvm::to_vector<4>(replicate_op->getResultTypes()); for (auto assign : tail_assign_variable_ops) { - return_op->insertOperands(return_op->getNumOperands(), assign.value()); + return_op->insertOperands(return_op->getNumOperands(), assign.getValue()); new_result_types.insert(new_result_types.end(), num_replicas, - assign.value().getType()); + assign.getValue().getType()); } OpBuilder builder(replicate_op); @@ -114,7 +114,7 @@ SmallVector GetTailWritesToReplicateInvariantResourceVars( if (op_accessed_resources.empty()) continue; if (auto assign = llvm::dyn_cast(op)) { - Value resource_var = assign.resource(); + Value resource_var = assign.getResource(); if (visited_resources.contains(resource_var) || !resource_var.getParentRegion()->isProperAncestor( &replicate_op.getRegion())) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/init_text_file_to_import.cc b/tensorflow/compiler/mlir/tensorflow/transforms/init_text_file_to_import.cc index be2e6575a81..67bf6fa4221 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/init_text_file_to_import.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/init_text_file_to_import.cc @@ -74,14 +74,14 @@ class ConvertInitializeTableFromTextFileV2 // // In the above case, the delimiter will be not used since the key is just a // whole line and value is a line number. - if (op.key_index() != kTextFileIndex_WholeLine || - op.value_index() != kTextFileIndex_LineNumber) { + if (op.getKeyIndex() != kTextFileIndex_WholeLine || + op.getValueIndex() != kTextFileIndex_LineNumber) { return failure(); } // Try to find filename from constant op. DenseStringElementsAttr filename_attr; - if (!matchPattern(op.filename().getDefiningOp(), + if (!matchPattern(op.getFilename().getDefiningOp(), m_Constant(&filename_attr))) { return failure(); } @@ -111,7 +111,7 @@ class ConvertInitializeTableFromTextFileV2 file->getBuffer().split(lines, "\n", -1, false); // The resize method is used since split operator puts tail value in the end // without splitting the leftovers. - if (op.vocab_size() != -1) lines.resize(op.vocab_size()); + if (op.getVocabSize() != -1) lines.resize(op.getVocabSize()); // Map each line to line number, starting from zero. SmallVector line_nums; @@ -130,7 +130,7 @@ class ConvertInitializeTableFromTextFileV2 op.getLoc(), rewriter.getI64TensorAttr(line_nums)); // Replace the given op with LookupTableImportV2Op. 
- rewriter.create(op.getLoc(), op.table_handle(), + rewriter.create(op.getLoc(), op.getTableHandle(), key_constant_tensor, value_constant_tensor); rewriter.eraseOp(op); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/initialize_variables_in_session_init.cc b/tensorflow/compiler/mlir/tensorflow/transforms/initialize_variables_in_session_init.cc index 4624f5107e5..44e178ac76c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/initialize_variables_in_session_init.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/initialize_variables_in_session_init.cc @@ -54,9 +54,6 @@ void InitializeVariable(TF::VarHandleOp var_handle_op, const_op.getResult()}); } -constexpr char kTfSavedModelExportedNameAttr[] = - "tf_saved_model.exported_names"; - func::FuncOp CreateSessionInitFunc(ModuleOp module) { constexpr char kSessionInitFuncName[] = "SessionInitializerFunction"; @@ -65,7 +62,7 @@ func::FuncOp CreateSessionInitFunc(ModuleOp module) { FunctionType::get(module.getContext(), /*inputs=*/{}, /*results=*/{}); auto func = builder.create(module->getLoc(), kSessionInitFuncName, func_type); - func->setAttr(kTfSavedModelExportedNameAttr, + func->setAttr(kTfSavedModelExportedNamesAttr, builder.getStrArrayAttr({kSessionInitFuncName})); func->setAttr(kTfSavedModelInitializerTypeAttr, builder.getStringAttr(kTfSavedModelInitializerRestoreType)); @@ -91,35 +88,23 @@ func::FuncOp GetOrCreateSessionInitFunc(ModuleOp module) { SessionInitializerOp session_init_op = GetSessionInitializerOp(module); if (!session_init_op) return CreateSessionInitFunc(module); - SymbolTable symbol_table(module); - - // Find the init function that has tf_saved_model.initializer_type == - // "restore_op". - for (auto init_sym : - session_init_op.getInitializers().getAsValueRange()) { - auto init_func_op = symbol_table.lookup(init_sym); - - const auto init_type_attr = init_func_op->getAttrOfType( - kTfSavedModelInitializerTypeAttr); - if (init_type_attr && - init_type_attr == kTfSavedModelInitializerRestoreType) { - return init_func_op; - } - } - - // When the init function with type "restore_op" is not found, fall back to - // taking the init function corresponding to the first symbol in the - // initializers list to be backwards-compatible, before - // tf_saved_model.initializer_type attribute was introduced. - if (!session_init_op.getInitializers().empty()) { - auto init_func_op = - symbol_table.lookup(session_init_op.getInitializers()[0] - .cast() - .getValue()); + auto init_func_op = GetInitializerFunction( + module, /*initializer_type=*/kTfSavedModelInitializerRestoreType); + if (init_func_op) { return init_func_op; + } else if (!session_init_op.getInitializers().empty()) { + // When the init function with type "restore_op" is not found, fall back to + // taking the init function corresponding to the first symbol in the + // initializers list to be backwards-compatible, before + // tf_saved_model.initializer_type attribute was introduced. 
+ SymbolTable symbol_table(module); + return symbol_table.lookup( + session_init_op.getInitializers()[0] + .cast() + .getValue()); + } else { + return CreateSessionInitFunc(module); } - - return CreateSessionInitFunc(module); } } // namespace diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc b/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc index 40e22f503bd..aa1efc6837e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc @@ -13,16 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "llvm/Support/Casting.h" +#include +#include + #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/split_into_island_per_op_pass.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" namespace mlir { namespace TFDevice { @@ -35,6 +40,11 @@ constexpr char kDeviceAttr[] = "device"; struct LaunchToDeviceAttributePass : public impl::LaunchToDeviceAttributePassBase< LaunchToDeviceAttributePass> { + public: + explicit LaunchToDeviceAttributePass(bool legacy_graph_export) { + legacy_graph_export_ = legacy_graph_export; + } + void runOnOperation() override; }; @@ -42,9 +52,14 @@ struct LaunchToDeviceAttributePass LogicalResult AssignDevicesInRegion(const Dialect* tf_dialect, tf_device::LaunchOp launch, Region& region) { + auto parallel_group_attr = + launch->getAttrOfType(TF::kParallelExecAnnotation); auto result = region.walk([&](Operation* op) -> WalkResult { if (op->getDialect() != tf_dialect) return WalkResult::advance(); + if (parallel_group_attr) { + op->setAttr(TF::kParallelExecAnnotation, parallel_group_attr); + } auto device_attr = op->getAttr(kDeviceAttr); if (!device_attr) { op->setAttr(kDeviceAttr, launch.getDeviceAttr()); @@ -109,13 +124,25 @@ void LaunchToDeviceAttributePass::runOnOperation() { }); if (result.wasInterrupted()) return signalPassFailure(); + + if (!legacy_graph_export_) { + // Now, split the island into an island per op since we don't want to + // violate the invariant imposed by the GraphExport pipeline that every + // IslandOp perfectly wraps a single op. 
+ auto control_type = + mlir::tf_executor::ControlType::get(tf_dialect->getContext()); + getOperation().walk( + [&control_type](mlir::tf_executor::IslandOp curr_island) { + mlir::TF::SplitIsland(curr_island, control_type); + }); + } } } // anonymous namespace -std::unique_ptr> -CreateLaunchToDeviceAttributePass() { - return std::make_unique(); +std::unique_ptr> CreateLaunchToDeviceAttributePass( + bool legacy_graph_export) { + return std::make_unique(legacy_graph_export); } } // namespace TFDevice diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc b/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc index 31676934544..0fad3c019ea 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc @@ -201,13 +201,13 @@ void MoveTransposeBefore(Operation* op, SmallVector* work_list) { if (!perm) return; // With the same permutation indices. - auto dense_elem_attr = perm.value().dyn_cast(); + auto dense_elem_attr = perm.getValue().dyn_cast(); if (!dense_elem_attr) return; if (!permutation_op) permutation_op = perm; // Check that permutation matches for all result transposes. - if (perm.value() != permutation_op.value()) return; + if (perm.getValue() != permutation_op.getValue()) return; // Add a transpose operation for later reuse. transpose_ops.push_back(transpose); @@ -217,7 +217,7 @@ void MoveTransposeBefore(Operation* op, SmallVector* work_list) { // Nothing to do here. if (!permutation_op || transpose_ops.empty()) return; SmallVector permutation; - auto perm_attr = permutation_op.value().cast(); + auto perm_attr = permutation_op.getValue().cast(); for (const auto& value : perm_attr.getValues()) permutation.push_back(value.getSExtValue()); @@ -247,7 +247,8 @@ void MoveTransposeBefore(Operation* op, SmallVector* work_list) { // Bypass Transpose nodes for all results. for (OpResult result : op->getResults()) { - result.setType(cast(*result.getUsers().begin()).y().getType()); + result.setType( + cast(*result.getUsers().begin()).getY().getType()); for (Operation* transpose : result.getUsers()) { transpose->getResult(0).replaceAllUsesWith(result); } @@ -342,13 +343,13 @@ void MoveTransposeAfter(Operation* op, SmallVector* work_list, if (!perm) return; // With the same permutation indices. - auto dense_elem_attr = perm.value().dyn_cast(); + auto dense_elem_attr = perm.getValue().dyn_cast(); if (!dense_elem_attr) return; if (!permutation_op) permutation_op = perm; // Check that permutation matches for all result transposes. - if (perm.value() != permutation_op.value()) return; + if (perm.getValue() != permutation_op.getValue()) return; // Add a transpose operation for later reuse only if it's used once. 
if (transpose.getResult().hasOneUse()) transpose_ops.push_back(transpose); @@ -364,7 +365,7 @@ void MoveTransposeAfter(Operation* op, SmallVector* work_list, SmallVector permutation; - auto attr = permutation_op.value().cast(); + auto attr = permutation_op.getValue().cast(); for (const auto& value : attr.getValues()) permutation.push_back(value.getSExtValue()); @@ -372,7 +373,7 @@ void MoveTransposeAfter(Operation* op, SmallVector* work_list, if (fold_operands && fold_transpose_in_ops) { SmallVector permutation; - auto attr = permutation_op.value().cast(); + auto attr = permutation_op.getValue().cast(); for (const auto& value : attr.getValues()) permutation.push_back(value.getSExtValue()); @@ -421,7 +422,7 @@ void MoveTransposeAfter(Operation* op, SmallVector* work_list, transpose.getOperation()->moveBefore(op->getNextNode()); transpose.setOperand(0, result); transpose.setOperand(1, permutation_op); - transpose.getResult().setType(original_type[idx]); + transpose.getResult().setType(original_type[idx].cast()); } else { transpose = builder.create(loc, result, permutation_op); } @@ -451,7 +452,7 @@ void MoveTransposesPass::runOnOperation() { } } else { // Try to push transpose after the user operation. - for (Operation* user : transpose.y().getUsers()) { + for (Operation* user : transpose.getY().getUsers()) { if (!llvm::isa(user)) work_list.push_back(user); } } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc index d7f55ea13c2..52522210055 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -28,7 +29,6 @@ limitations under the License. #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/None.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLForwardCompat.h" #include "llvm/ADT/Sequence.h" @@ -39,7 +39,7 @@ limitations under the License. #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project @@ -59,7 +59,7 @@ limitations under the License. #include "stablehlo/dialect/ChloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "tensorflow/core/framework/kernel_shape_util.h" #include "tensorflow/core/lib/math/math_util.h" @@ -70,11 +70,11 @@ namespace { using mhlo::DotDimensionNumbersAttr; // Replaces `region`'s terminator to TF::Yield. 
-void ReplaceReturnOp(Region ®ion, PatternRewriter &rewriter) { +void ReplaceReturnOp(Region& region, PatternRewriter& rewriter) { OpBuilder::InsertionGuard guard(rewriter); - for (auto &block : region.getBlocks()) { - Operation *terminator = block.getTerminator(); + for (auto& block : region.getBlocks()) { + Operation* terminator = block.getTerminator(); auto return_op = llvm::dyn_cast_or_null(terminator); if (return_op == nullptr) continue; @@ -88,7 +88,7 @@ void ReplaceReturnOp(Region ®ion, PatternRewriter &rewriter) { // to the splate constant value. // `SplatValueType` can be `APInt` or `APFloat`. template -LogicalResult GetConstantSplatValue(Value value, SplatValueType &splat_value) { +LogicalResult GetConstantSplatValue(Value value, SplatValueType& splat_value) { DenseElementsAttr attr; if (!matchPattern(value, m_Constant(&attr)) || !attr.isSplat()) { return failure(); @@ -107,7 +107,7 @@ struct PermutationAndShape { // applying the permutation to a given shape through a transpose. PermutationAndShape GetPermutationAndTransposedShape( llvm::ArrayRef permutation_array, ShapedType input_type, - ConversionPatternRewriter &rewriter) { + ConversionPatternRewriter& rewriter) { assert(permutation_array.size() == input_type.getRank()); llvm::SmallVector transposed_shape(permutation_array.size()); for (int64_t i = 0; i < permutation_array.size(); ++i) { @@ -137,7 +137,7 @@ llvm::SmallVector GetInversePermutationArray( // permutation_array. DenseIntElementsAttr GetInversePermutation( llvm::ArrayRef permutation_array, - ConversionPatternRewriter &rewriter) { + ConversionPatternRewriter& rewriter) { SmallVector inverse_permutation_array = GetInversePermutationArray(permutation_array); return DenseIntElementsAttr::get( @@ -150,7 +150,7 @@ DenseIntElementsAttr GetInversePermutation( // applying the inverse permutation to a given shape through a transpose. PermutationAndShape GetInversePermutationAndShape( llvm::ArrayRef permutation_array, ShapedType input_type, - ConversionPatternRewriter &rewriter) { + ConversionPatternRewriter& rewriter) { SmallVector inverse_permutation_array = GetInversePermutationArray(permutation_array); return GetPermutationAndTransposedShape(inverse_permutation_array, input_type, @@ -169,13 +169,13 @@ struct ConvertNdConvOp { // All ones in "lhs_dilation" means this "mhlo.conv" op should be // converted to "tf.Conv2D" or "tf.DepthwiseConv2dNativeOp". if (conv_op.getLhsDilation().has_value()) { - auto lhs_dilation = conv_op.getLhsDilation().getValue(); + auto lhs_dilation = conv_op.getLhsDilation().value(); if (!lhs_dilation.isSplat() || lhs_dilation.getSplatValue() != 1) return false; } if (!conv_op.getWindowStrides().has_value() || conv_op.getWindowStrides() - .getValue() + .value() .getType() .cast() .getRank() != 1) @@ -199,7 +199,7 @@ class Convert1DConvOp : public OpConversionPattern, LogicalResult matchAndRewrite( mhlo::ConvolutionOp conv_op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { + ConversionPatternRewriter& rewriter) const final { // // Check that input is a supported 1d convolution. // @@ -284,7 +284,7 @@ class Convert1DConvOp : public OpConversionPattern, // Padding SmallVector padding_2d_array; - for (const auto v : conv_op.getPadding().getValue().getValues()) { + for (const auto v : conv_op.getPadding().value().getValues()) { padding_2d_array.emplace_back(v); } // The newly added spatial dimension requires zero left and right padding. 
@@ -295,8 +295,7 @@ class Convert1DConvOp : public OpConversionPattern, // LHS dilation SmallVector lhs_dilation_array_2d; - for (const auto v : - conv_op.getLhsDilation().getValue().getValues()) { + for (const auto v : conv_op.getLhsDilation().value().getValues()) { lhs_dilation_array_2d.emplace_back(v); } lhs_dilation_array_2d.push_back(1); @@ -306,8 +305,7 @@ class Convert1DConvOp : public OpConversionPattern, // RHS dilation SmallVector rhs_dilation_array_2d; - for (const auto v : - conv_op.getRhsDilation().getValue().getValues()) { + for (const auto v : conv_op.getRhsDilation().value().getValues()) { rhs_dilation_array_2d.emplace_back(v); } rhs_dilation_array_2d.push_back(1); @@ -396,16 +394,24 @@ class Convert2DConvOp : public OpConversionPattern, LogicalResult matchAndRewrite( mhlo::ConvolutionOp conv_op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { + ConversionPatternRewriter& rewriter) const final { if (!IsSupportedConvOp(conv_op)) { return failure(); } + // tf Convolution doesn't support quantized type. + if (conv_op.getRhs() + .getType() + .getElementType() + .isa()) { + return failure(); + } + // Constructs strides array. // For example, [2, 3] -> [1, 2, 3, 1]. SmallVector strides({1}); for (const auto v : - conv_op.getWindowStrides().getValue().getValues()) { + conv_op.getWindowStrides().value().getValues()) { strides.emplace_back(v); } strides.emplace_back(1); @@ -415,8 +421,8 @@ class Convert2DConvOp : public OpConversionPattern, if (auto rhs_dilation = conv_op.getRhsDilation()) { // For example, [2, 3] -> [1, 2, 3, 1]. dilation.emplace_back(1); - dilation.append(rhs_dilation.getValue().getValues().begin(), - rhs_dilation.getValue().getValues().end()); + dilation.append(rhs_dilation.value().getValues().begin(), + rhs_dilation.value().getValues().end()); dilation.emplace_back(1); } else { // Default value @@ -440,13 +446,12 @@ class Convert2DConvOp : public OpConversionPattern, std::string padding; SmallVector explicit_padding; if (!conv_op.getPadding().has_value() || - (conv_op.getPadding().getValue().isSplat() && + (conv_op.getPadding().value().isSplat() && conv_op.getPadding()->getSplatValue() == 0)) { padding = "VALID"; } else { SmallVector padding_array; - for (const auto v : - conv_op.getPadding().getValue().getValues()) { + for (const auto v : conv_op.getPadding().value().getValues()) { padding_array.emplace_back(v); } @@ -516,7 +521,7 @@ class Convert2DConvOp : public OpConversionPattern, int default_feature_dim, int default_spatial_dim_start, int num_spatial_dims, RankedTensorType type, - ConversionPatternRewriter &rewriter) const { + ConversionPatternRewriter& rewriter) const { auto shape = type.getShape(); llvm::SmallVector permutation_array(num_spatial_dims + 2); permutation_array[default_batch_dim] = batch_dim; @@ -541,7 +546,7 @@ class Convert2DConvOp : public OpConversionPattern, ArrayRef spatial_dimensions, int default_batch_dim, int default_feature_dim, int default_spatial_dim_start, int num_spatial_dims, - ConversionPatternRewriter &rewriter) const { + ConversionPatternRewriter& rewriter) const { auto type = value.getType().cast(); DenseIntElementsAttr permutation; const int spatial_dim_start = spatial_dimensions.front(); @@ -551,7 +556,7 @@ class Convert2DConvOp : public OpConversionPattern, // Transpose is not needed because the current format is "NHWC". 
return value; } - std::pair(type, permutation) = + std::pair(type, permutation) = GetReformatTypeAndPermutation(batch_dim, feature_dim, spatial_dim_start, default_batch_dim, default_feature_dim, default_spatial_dim_start, @@ -563,7 +568,7 @@ class Convert2DConvOp : public OpConversionPattern, // Slices the input `value` if there are negative padding values in // `explicit_padding`. Value SliceNegativePadding(Value value, ArrayRef explicit_padding, - ConversionPatternRewriter &rewriter) const { + ConversionPatternRewriter& rewriter) const { // If no padding is negative return the input as is. if (llvm::all_of(explicit_padding, [](int64_t pad) { return pad >= 0; })) { return value; @@ -607,7 +612,7 @@ class Convert2DConvOp : public OpConversionPattern, StringRef padding, ArrayRef explicit_padding, ArrayRef dilation, bool is_depthwise_conv, int input_channels, int num_spatial_dims, - ConversionPatternRewriter &rewriter) const { + ConversionPatternRewriter& rewriter) const { mhlo::ConvDimensionNumbersAttr dnums = conv_op.getDimensionNumbers(); // Transposes lhs and rhs if their formats are not NHWC. Value lhs = FormatToNHWC( @@ -637,8 +642,8 @@ class Convert2DConvOp : public OpConversionPattern, /*default_batch_dim=*/0, /*default_feature_dim=*/num_spatial_dims + 1, /*default_spatial_dim_start=*/1); if (need_transpose_output) { - std::pair(conv_output_type, - permutation) = + std::pair(conv_output_type, + permutation) = GetReformatTypeAndPermutation( dnums.getOutputBatchDimension(), dnums.getOutputFeatureDimension(), @@ -684,8 +689,8 @@ class Convert2DConvOp : public OpConversionPattern, if (need_transpose_output) { // Converts from "NHWC" format back to the original output format. - std::pair(conv_output_type, - permutation) = + std::pair(conv_output_type, + permutation) = GetReformatTypeAndPermutation( /*batch_dim=*/0, /*feature_dim=*/num_spatial_dims + 1, /*spatial_dim_start=*/1, dnums.getOutputBatchDimension(), @@ -706,7 +711,7 @@ class ConvertNonTrivialConvOp LogicalResult matchAndRewrite( mhlo::ConvolutionOp conv_op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { + ConversionPatternRewriter& rewriter) const final { if (IsSupportedConvOp(conv_op, rewriter).failed()) { return rewriter.notifyMatchFailure( conv_op, @@ -728,8 +733,8 @@ class ConvertNonTrivialConvOp // For example, [2, 3] -> [1, 2, 3, 1]. SmallVector strides({1}); strides.append( - conv_op.getLhsDilation().getValue().getValues().begin(), - conv_op.getLhsDilation().getValue().getValues().end()); + conv_op.getLhsDilation().value().getValues().begin(), + conv_op.getLhsDilation().value().getValues().end()); strides.emplace_back(1); // Constructs dilation array. @@ -737,8 +742,8 @@ class ConvertNonTrivialConvOp if (auto rhs_dilation = conv_op.getRhsDilation()) { // For example, [2, 3] -> [1, 2, 3, 1]. 
dilation.emplace_back(1); - dilation.append(rhs_dilation.getValue().getValues().begin(), - rhs_dilation.getValue().getValues().end()); + dilation.append(rhs_dilation.value().getValues().begin(), + rhs_dilation.value().getValues().end()); dilation.emplace_back(1); } else { // Default value @@ -748,7 +753,7 @@ class ConvertNonTrivialConvOp mhlo::ConvDimensionNumbersAttr dnums = conv_op.getDimensionNumbers(); std::string padding; if (!conv_op.getPadding().has_value() || - (conv_op.getPadding().getValue().isSplat() && + (conv_op.getPadding().value().isSplat() && conv_op.getPadding()->getSplatValue() == 0)) { padding = "VALID"; } else { @@ -807,7 +812,7 @@ class ConvertNonTrivialConvOp } LogicalResult IsSupportedConvOp(mhlo::ConvolutionOp conv_op, - ConversionPatternRewriter &rewriter) const { + ConversionPatternRewriter& rewriter) const { if (!conv_op.getLhs().getType().cast().hasStaticShape() || !conv_op.getRhs().getType().cast().hasStaticShape() || !conv_op.getType().cast().hasStaticShape()) @@ -830,13 +835,13 @@ class ConvertNonTrivialConvOp return rewriter.notifyMatchFailure(conv_op, "requires lhs_dilation attribute"); } - auto lhs_dilation = conv_op.getLhsDilation().getValue(); + auto lhs_dilation = conv_op.getLhsDilation().value(); if (lhs_dilation.isSplat() && lhs_dilation.getSplatValue() == 1) return rewriter.notifyMatchFailure(conv_op, "requires non-trivial lhs_dilation"); if (!conv_op.getWindowStrides().has_value() || conv_op.getWindowStrides() - .getValue() + .value() .getType() .cast() .getRank() != 1) @@ -893,7 +898,7 @@ class ConvertNonTrivialConvOp void CreateResizeBilinearOp(mhlo::ConvolutionOp conv_op, llvm::ArrayRef output_sizes, bool align_corners, - ConversionPatternRewriter &rewriter) const { + ConversionPatternRewriter& rewriter) const { Value output_sizes_attr = rewriter.create( conv_op.getLoc(), DenseIntElementsAttr::get( @@ -911,9 +916,9 @@ class ConvertNonTrivialConvOp rewriter.replaceOp(conv_op, {output}); } - LogicalResult MatchResizeOp(mhlo::ConvolutionOp conv_op, bool &align_corners, - llvm::SmallVector &output_sizes, - ConversionPatternRewriter &rewriter) const { + LogicalResult MatchResizeOp(mhlo::ConvolutionOp conv_op, bool& align_corners, + llvm::SmallVector& output_sizes, + ConversionPatternRewriter& rewriter) const { mhlo::ConvDimensionNumbersAttr dnums = conv_op.getDimensionNumbers(); auto input_spatial_dimensions = dnums.getInputSpatialDimensions(); auto kernel_spatial_dimensions = dnums.getKernelSpatialDimensions(); @@ -934,10 +939,10 @@ class ConvertNonTrivialConvOp return rewriter.notifyMatchFailure( conv_op, "resize op requires rhs_dilation and padding"); - auto lhs_dilation = conv_op.getLhsDilation().getValue(); - auto rhs_dilation = conv_op.getRhsDilation().getValue(); - auto window_strides = conv_op.getWindowStrides().getValue(); - auto padding = conv_op.getPadding().getValue(); + auto lhs_dilation = conv_op.getLhsDilation().value(); + auto rhs_dilation = conv_op.getRhsDilation().value(); + auto window_strides = conv_op.getWindowStrides().value(); + auto padding = conv_op.getPadding().value(); if (lhs_dilation.getNumElements() != 2 || !rhs_dilation.isSplat() || rhs_dilation.getSplatValue() != 1 || window_strides.getNumElements() != 2 || padding.getNumElements() != 4) @@ -1023,7 +1028,7 @@ class ConvertSliceOp : public OpConversionPattern { LogicalResult matchAndRewrite( mhlo::SliceOp slice_op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { + ConversionPatternRewriter& rewriter) const final { auto begin = 
rewriter.create(slice_op.getLoc(), slice_op.getStartIndices()); auto end = @@ -1043,7 +1048,7 @@ class ConvertDynamicSliceOp : public OpConversionPattern { LogicalResult matchAndRewrite( mhlo::DynamicSliceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { + ConversionPatternRewriter& rewriter) const final { ShapedType input_type = op.getOperand().getType().cast(); if (!input_type.hasStaticShape()) return failure(); Type start_indices_element_type = op.getStartIndices() @@ -1101,33 +1106,33 @@ class ConvertDynamicSliceOp : public OpConversionPattern { // Appends all elements in `range` to `values`. template -void Append(llvm::SmallVectorImpl &values, Range &&range) { +void Append(llvm::SmallVectorImpl& values, Range&& range) { values.insert(values.end(), range.begin(), range.end()); } // Appends all elements in `range` to `values`. template -void Append(llvm::SmallVectorImpl &values, Range &&range, - RangeTs &&...ranges) { +void Append(llvm::SmallVectorImpl& values, Range&& range, + RangeTs&&... ranges) { values.insert(values.end(), range.begin(), range.end()); Append(values, ranges...); } // Returns the number of elements in `range`. template -size_t Size(Range &&range) { +size_t Size(Range&& range) { return range.size(); } // Returns the total number of elements in a variadic number of `ranges`. template -size_t Size(Range &&range, RangeTs &&...ranges) { +size_t Size(Range&& range, RangeTs&&... ranges) { return range.size() + Size(std::forward(ranges)...); } // Concats all elements in `ranges` and returns a small vector as a result. template -llvm::SmallVector Concat(RangeTs &&...ranges) { +llvm::SmallVector Concat(RangeTs&&... ranges) { llvm::SmallVector results; results.reserve(Size(std::forward(ranges)...)); Append(results, std::forward(ranges)...); @@ -1144,16 +1149,16 @@ struct DimensionVector { }; // Create a single const integer. -Value BuildIntConstOp(ImplicitLocOpBuilder &builder, - ConversionPatternRewriter &rewriter, int64_t const_value, +Value BuildIntConstOp(ImplicitLocOpBuilder& builder, + ConversionPatternRewriter& rewriter, int64_t const_value, Type type) { Value result_const = builder.create(rewriter.getIntegerAttr(type, const_value)); return result_const; } // Create a const integer vector tensor (1-dim). -Value BuildIntArrayConstOp(ImplicitLocOpBuilder &builder, - ConversionPatternRewriter &rewriter, +Value BuildIntArrayConstOp(ImplicitLocOpBuilder& builder, + ConversionPatternRewriter& rewriter, ArrayRef const_value, Type type) { DenseIntElementsAttr const_value_raw; if (type == rewriter.getI64Type()) { @@ -1171,8 +1176,8 @@ Value BuildIntArrayConstOp(ImplicitLocOpBuilder &builder, } // Create a tensor that is reshaped from input. -Value BuildReshapeOp(ImplicitLocOpBuilder &builder, - ConversionPatternRewriter &rewriter, Value input, +Value BuildReshapeOp(ImplicitLocOpBuilder& builder, + ConversionPatternRewriter& rewriter, Value input, ArrayRef shape, Type idx_type, Type element_type) { Value shape_cst = BuildIntArrayConstOp(builder, rewriter, shape, idx_type); @@ -1182,8 +1187,8 @@ Value BuildReshapeOp(ImplicitLocOpBuilder &builder, } // Create a tensor which is equal to input[begin: begin + size]. 
-Value BuildSliceOp(ImplicitLocOpBuilder &builder, - ConversionPatternRewriter &rewriter, Value input, +Value BuildSliceOp(ImplicitLocOpBuilder& builder, + ConversionPatternRewriter& rewriter, Value input, Value begin, ArrayRef shape, Type idx_type, Type element_type) { Value shape_cst = BuildIntArrayConstOp(builder, rewriter, shape, idx_type); @@ -1199,7 +1204,7 @@ class ConvertDynamicUpdateSliceOp LogicalResult matchAndRewrite( mhlo::DynamicUpdateSliceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { + ConversionPatternRewriter& rewriter) const final { ShapedType operand_type = op.getOperand().getType().cast(); ShapedType update_type = op.getUpdate().getType().dyn_cast_or_null(); @@ -1409,7 +1414,7 @@ class StridedArrayViewBase { if (index[dim] < shape[dim]) return std::move(index); index[dim] = 0; } - return llvm::None; + return std::nullopt; } protected: @@ -1472,7 +1477,7 @@ class StridedArrayView; // Class requires specialization. template <> class StridedArrayView : StridedArrayViewBase { public: - StridedArrayView(const DenseIntElementsAttr &data, ArrayRef shape, + StridedArrayView(const DenseIntElementsAttr& data, ArrayRef shape, ArrayRef index, int64_t axis) : StridedArrayViewBase(shape, index, axis), data_(data) { int64_t element_count = 1; @@ -1490,7 +1495,7 @@ class StridedArrayView : StridedArrayViewBase { } private: - const DenseIntElementsAttr &data_; + const DenseIntElementsAttr& data_; }; // Matches %iota generated from the following mlir codes (rank 2 example): @@ -1542,11 +1547,11 @@ bool MatchIota(DenseIntElementsAttr dimensions, Value iota) { MatchIotaConst(dimensions, iota); } -bool MatchTopKComparator(Region &comparator) { +bool MatchTopKComparator(Region& comparator) { if (!comparator.hasOneBlock()) return false; - Block &comparator_blk = comparator.front(); + Block& comparator_blk = comparator.front(); using OpListType = llvm::iplist; - OpListType &operations = comparator_blk.getOperations(); + OpListType& operations = comparator_blk.getOperations(); if (operations.size() != 2) return false; auto compare_op = dyn_cast_or_null(&operations.front()); auto return_op = dyn_cast_or_null(&operations.back()); @@ -1575,7 +1580,7 @@ class ConvertSortToTfTopk : public OpConversionPattern { LogicalResult matchAndRewrite( mhlo::SortOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { + ConversionPatternRewriter& rewriter) const final { if (op->getOperands().size() != 2) return rewriter.notifyMatchFailure( op, "only match for the case where operands is of size 2"); @@ -1643,13 +1648,13 @@ class DotDimensionsInfo { } } - const DimensionVector &batch_dimensions() const { return batch_dimensions_; } - const DimensionVector &contracting_dimensions() const { + const DimensionVector& batch_dimensions() const { return batch_dimensions_; } + const DimensionVector& contracting_dimensions() const { return contracting_dimensions_; } // Out dimensions are any dimensions that are neither batch nor contracting // dimensions, hence will be propagated to output shape. - const DimensionVector &out_dimensions() const { return out_dimensions_; } + const DimensionVector& out_dimensions() const { return out_dimensions_; } // Returns the total dimension size after flattening all contracting // dimensions. 
@@ -1674,7 +1679,7 @@ class DotDimensionsInfo { DimensionVector out_dimensions_; }; -Value ConvertDot(PatternRewriter &rewriter, Value lhs, Value rhs, +Value ConvertDot(PatternRewriter& rewriter, Value lhs, Value rhs, DotDimensionNumbersAttr dot_dimension_numbers, ShapedType result_type, mlir::Location loc) { auto lhs_type = lhs.getType().cast(); @@ -1767,7 +1772,7 @@ Value ConvertDot(PatternRewriter &rewriter, Value lhs, Value rhs, // Converts mhlo.dot to tf.MatMul. Reshape ops will be inserted when // necessary. -Value ConvertDotOp(PatternRewriter &rewriter, Operation *old_op) { +Value ConvertDotOp(PatternRewriter& rewriter, Operation* old_op) { auto dot_op = cast(old_op); auto lhs_rank = dot_op.getLhs().getType().cast().getRank(); auto dot_dimension_numbers = @@ -1784,7 +1789,7 @@ Value ConvertDotOp(PatternRewriter &rewriter, Operation *old_op) { // Converts mhlo.dot to tf.BatchMatMul. Reshape or Transpose ops will also be // inserted to convert to well-formed matrix multiply. -Value ConvertDotGeneralOp(PatternRewriter &rewriter, Operation *old_op) { +Value ConvertDotGeneralOp(PatternRewriter& rewriter, Operation* old_op) { auto dot_general_op = cast(old_op); return ConvertDot(rewriter, dot_general_op.getLhs(), dot_general_op.getRhs(), dot_general_op.getDotDimensionNumbers(), @@ -1796,8 +1801,8 @@ Value ConvertDotGeneralOp(PatternRewriter &rewriter, Operation *old_op) { // inputs, passes it to an instance of the specifiied reduction op and then // returns the result. template -LogicalResult MatchBinaryReduceFunction(mlir::Region &function) { - Block &body = function.front(); +LogicalResult MatchBinaryReduceFunction(mlir::Region& function) { + Block& body = function.front(); if (body.getNumArguments() != 2) return failure(); mhlo::ReturnOp return_op = dyn_cast(body.back()); @@ -1818,8 +1823,8 @@ LogicalResult MatchBinaryReduceFunction(mlir::Region &function) { // inputs and returns the second input. Functions like this are used by update // scatter like ops. 
template <> -LogicalResult MatchBinaryReduceFunction(mlir::Region &function) { - Block &body = function.front(); +LogicalResult MatchBinaryReduceFunction(mlir::Region& function) { + Block& body = function.front(); if (body.getNumArguments() != 2) return failure(); mhlo::ReturnOp return_op = dyn_cast(body.back()); @@ -1834,7 +1839,7 @@ LogicalResult MatchBinaryReduceFunction(mlir::Region &function) { template LogicalResult rewriteNonMatchInitValue(mhlo::ReduceOp reduce_op, Value input, ConstOp reduction_indices, - ConversionPatternRewriter &rewriter) { + ConversionPatternRewriter& rewriter) { Value reduce_result = rewriter.create( reduce_op.getLoc(), reduce_op.getType(0), input, reduction_indices, /*keep_dim=*/rewriter.getBoolAttr(false)); @@ -1849,14 +1854,14 @@ LogicalResult rewriteNonMatchInitValue(mhlo::ReduceOp reduce_op, Value input, template <> LogicalResult rewriteNonMatchInitValue( mhlo::ReduceOp reduce_op, Value input, ConstOp reduction_indices, - ConversionPatternRewriter &rewriter) { + ConversionPatternRewriter& rewriter) { return failure(); } template <> LogicalResult rewriteNonMatchInitValue( mhlo::ReduceOp reduce_op, Value input, ConstOp reduction_indices, - ConversionPatternRewriter &rewriter) { + ConversionPatternRewriter& rewriter) { return failure(); } @@ -1877,7 +1882,7 @@ class ConvertReduceOpToTfOp : public OpConversionPattern { LogicalResult matchAndRewrite( mhlo::ReduceOp reduce_op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { + ConversionPatternRewriter& rewriter) const final { if (failed(MatchReduceOpOperand(reduce_op))) return failure(); if (failed(MatchBinaryReduceFunction(reduce_op.getBody()))) @@ -1888,7 +1893,7 @@ class ConvertReduceOpToTfOp : public OpConversionPattern { // Get reduction dimension. DenseIntElementsAttr dimension = reduce_op.getDimensions(); SmallVector reduce_dims; - for (const int64_t &dim : dimension.getValues()) { + for (const int64_t& dim : dimension.getValues()) { reduce_dims.emplace_back(dim); } auto dim_type = RankedTensorType::get( @@ -2045,7 +2050,7 @@ class ConvertReduceOpToTfArgMinMax using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( mhlo::ReduceOp reduce_op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { + ConversionPatternRewriter& rewriter) const final { if (reduce_op.getInputs().size() != 2) return failure(); if (reduce_op.getDimensions().getNumElements() != 1) return failure(); @@ -2106,9 +2111,9 @@ class ConvertReduceOpToTfArgMinMax // %8 = select(%7, %lhs_index, %rhs_index) // return %3, %8 // Also note that %1 may be folded if %lhs_value is of integer types. 
- LogicalResult matchReduceComputation(Region &computation, + LogicalResult matchReduceComputation(Region& computation, bool is_float) const { - Block &body = computation.front(); + Block& body = computation.front(); if (body.getNumArguments() != 4) return failure(); mhlo::ReturnOp return_op = dyn_cast(body.back()); @@ -2187,7 +2192,7 @@ class ConvertReduceOpToTfArgMinMax virtual mhlo::ComparisonDirection CompareDirection() const = 0; - virtual bool IsValueInitValue(const DenseElementsAttr &attr) const = 0; + virtual bool IsValueInitValue(const DenseElementsAttr& attr) const = 0; }; class ConvertReduceOpToTfArgmax @@ -2198,7 +2203,7 @@ class ConvertReduceOpToTfArgmax mhlo::ComparisonDirection CompareDirection() const override { return mhlo::ComparisonDirection::GT; } - bool IsValueInitValue(const DenseElementsAttr &attr) const override { + bool IsValueInitValue(const DenseElementsAttr& attr) const override { auto element_type = attr.getType().getElementType(); if (attr.getNumElements() != 1 || !element_type.isIntOrFloat() || element_type.isInteger(1)) @@ -2222,7 +2227,7 @@ class ConvertReduceOpToTfArgmin mhlo::ComparisonDirection CompareDirection() const override { return mhlo::ComparisonDirection::LT; } - bool IsValueInitValue(const DenseElementsAttr &attr) const override { + bool IsValueInitValue(const DenseElementsAttr& attr) const override { auto element_type = attr.getType().getElementType(); if (attr.getNumElements() != 1 || !element_type.isIntOrFloat() || element_type.isInteger(1)) @@ -2244,7 +2249,7 @@ class ConvertIotaOpToTfRange : public OpConversionPattern { LogicalResult matchAndRewrite( mhlo::IotaOp iota_op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { + ConversionPatternRewriter& rewriter) const final { RankedTensorType type = iota_op.getType().dyn_cast_or_null(); // TF::RangeOp doesn't support UI16. @@ -2298,8 +2303,8 @@ class ConvertIotaOpToTfRange : public OpConversionPattern { // true, also outputs the window strides and the TF padding mode ("VALID" or // "SAME"). bool IsSpatialPoolingWithoutDilation( - mhlo::ReduceWindowOp rw, llvm::SmallVectorImpl *window_strides, - std::string *padding_mode) { + mhlo::ReduceWindowOp rw, llvm::SmallVectorImpl* window_strides, + std::string* padding_mode) { // tf.max_pool or tf.avg_pool need at least 3 dimensions (batch, spatial, // channel). const uint64_t rank = rw.getWindowDimensions().size(); @@ -2379,11 +2384,11 @@ class ConvertLoweredCumOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; - virtual bool IsInitValue(const DenseElementsAttr &attr) const = 0; + virtual bool IsInitValue(const DenseElementsAttr& attr) const = 0; LogicalResult matchAndRewrite( mhlo::ReduceWindowOp rw, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { + ConversionPatternRewriter& rewriter) const final { if (rw.getNumResults() != 1 || rw.getInputs().size() != 1 || rw.getInitValues().size() != 1) return failure(); @@ -2406,7 +2411,7 @@ class ConvertLoweredCumOp : public OpConversionPattern { // operand. auto is_splat_int64_ones = [&rewriter, - &operand_type](const ::llvm::Optional &attr) { + &operand_type](const std::optional& attr) { // According to the definition, the default value of these attributes // are all ones when unspecified. 
if (!attr.has_value()) return true; @@ -2430,7 +2435,7 @@ class ConvertLoweredCumOp : public OpConversionPattern { // where N is the same as the size of the corresponding input dimension // and there is a 1-entry for each input dimension not being operated // over. - const auto &window_dimensions = rw.getWindowDimensions(); + const auto& window_dimensions = rw.getWindowDimensions(); if (window_dimensions.size() != operand_type.getRank()) return failure(); int64_t cumulative_axis = -1; for (int64_t i = 0, e = window_dimensions.size(); i < e; ++i) { @@ -2452,7 +2457,7 @@ class ConvertLoweredCumOp : public OpConversionPattern { // dense<[[0, 0], [0, 0], [N-1, 0], [0, 0]]> // where N is the size of the input dimension being operated over. if (!rw.getPadding()) return failure(); - const auto &padding = rw.getPadding()->getValues(); + const auto& padding = rw.getPadding()->getValues(); if (padding.size() != operand_type.getRank() * 2) return failure(); int64_t padding_value = operand_type.getShape()[cumulative_axis] - 1; for (int64_t dim = 0; dim < operand_type.getRank(); ++dim) { @@ -2480,7 +2485,7 @@ class ConvertLoweredCumOp : public OpConversionPattern { class ConvertLoweredCumSumOp : public ConvertLoweredCumOp { using ConvertLoweredCumOp::ConvertLoweredCumOp; - bool IsInitValue(const DenseElementsAttr &attr) const override { + bool IsInitValue(const DenseElementsAttr& attr) const override { auto element_type = attr.getType().getElementType(); if (attr.getNumElements() != 1 || !element_type.isIntOrFloat()) return false; @@ -2496,7 +2501,7 @@ class ConvertLoweredCumSumOp class ConvertLoweredCumProdOp : public ConvertLoweredCumOp { using ConvertLoweredCumOp::ConvertLoweredCumOp; - bool IsInitValue(const DenseElementsAttr &attr) const override { + bool IsInitValue(const DenseElementsAttr& attr) const override { auto element_type = attr.getType().getElementType(); if (attr.getNumElements() != 1 || !element_type.isIntOrFloat()) return false; @@ -2516,12 +2521,12 @@ class ConvertLoweredCumProdOp // * div(reduce_sum_window(x), reduce_sum_window(constant(1))) class ConvertAvgPoolOp : public OpConversionPattern { public: - explicit ConvertAvgPoolOp(MLIRContext *context) + explicit ConvertAvgPoolOp(MLIRContext* context) : OpConversionPattern(context, /*benefit=*/10) {} LogicalResult matchAndRewrite( mhlo::DivOp div_op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { + ConversionPatternRewriter& rewriter) const final { auto rw = dyn_cast_or_null(div_op.getLhs().getDefiningOp()); if (!rw || rw->getNumResults() != 1) return failure(); @@ -2573,8 +2578,9 @@ class ConvertAvgPoolOp : public OpConversionPattern { window_strides, "VALID", rewriter); } + Value actual_divisor = recursivelyWalkUpDivisor(div_op.getRhs()); auto rw_rhs = - dyn_cast_or_null(div_op.getRhs().getDefiningOp()); + dyn_cast_or_null(actual_divisor.getDefiningOp()); if (rw_rhs && rw_rhs.getNumResults() == 1) { // Check that RHS is a sum-reduce-window. 
if (failed(MatchBinaryReduceFunction(rw_rhs.getBody()))) @@ -2618,7 +2624,7 @@ class ConvertAvgPoolOp : public OpConversionPattern { llvm::ArrayRef ksizes, llvm::ArrayRef kstrides, llvm::StringRef padding, - ConversionPatternRewriter &rewriter) const { + ConversionPatternRewriter& rewriter) const { if (ksizes.size() == 4) { rewriter.replaceOpWithNewOp( op, op.getType(), input, rewriter.getI64ArrayAttr(ksizes), @@ -2634,6 +2640,18 @@ class ConvertAvgPoolOp : public OpConversionPattern { } return failure(); } + + // Walks up the divisor and ignore all precedding reshape/broadcast op. + // Returns the first producer op which is neither reshape nor broadcast. + Value recursivelyWalkUpDivisor(Value divisor) const { + while (llvm::isa_and_nonnull( + divisor.getDefiningOp())) { + Operation* producer = divisor.getDefiningOp(); + divisor = producer->getOperand(/*idx=*/0); + } + + return divisor; + } }; class ConvertMaxPoolOp : public OpConversionPattern { @@ -2642,7 +2660,7 @@ class ConvertMaxPoolOp : public OpConversionPattern { LogicalResult matchAndRewrite( mhlo::ReduceWindowOp rw, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { + ConversionPatternRewriter& rewriter) const final { // Check that the reduce-window is a max-reduce-window. if (failed(MatchBinaryReduceFunction(rw.getBody()))) return failure(); @@ -2697,7 +2715,7 @@ class ConvertMaxPoolOp : public OpConversionPattern { llvm::ArrayRef ksizes, llvm::ArrayRef kstrides, llvm::StringRef padding, - ConversionPatternRewriter &rewriter) const { + ConversionPatternRewriter& rewriter) const { if (ksizes.size() == 4) { rewriter.replaceOpWithNewOp( op, op.getType(0), input, rewriter.getI64ArrayAttr(ksizes), @@ -2725,7 +2743,7 @@ class LegalizeHloToTf : public impl::LegalizeHloToTfPassBase { }; // Returns the shape of the given value in a Constant Op. -arith::ConstantOp ShapeToConst(PatternRewriter &rewriter, Value value) { +arith::ConstantOp ShapeToConst(PatternRewriter& rewriter, Value value) { ArrayRef shape = value.getType().cast().getShape(); auto attr_type = RankedTensorType::get({static_cast(shape.size())}, rewriter.getIntegerType(64)); @@ -2794,7 +2812,7 @@ bool ValueGreaterThanZero(ElementsAttr float_or_int) { } // Returns whether the splat constant is the sign of the int or float Tensor. -bool TensorIsSign(PatternRewriter &rewriter, ElementsAttr float_or_int, +bool TensorIsSign(PatternRewriter& rewriter, ElementsAttr float_or_int, ElementsAttr sgn_cst) { auto sgn_splat = llvm::dyn_cast(sgn_cst); if (!sgn_splat) return false; @@ -2893,9 +2911,9 @@ bool IsIotaAttr(ArrayRef arr, int64_t size) { // Note: NormalizeIndexVector is assumed to have run on the indices already so // that the index_vector_dim is the trailing dimension in `indices`. LogicalResult CanonicalizeScatterUpdates( - Operation *scatter_op, llvm::ArrayRef update_window_dims, - const Value &indices, const ShapedType &indices_type, Value &updates, - ShapedType &updates_type, ConversionPatternRewriter &rewriter) { + Operation* scatter_op, llvm::ArrayRef update_window_dims, + const Value& indices, const ShapedType& indices_type, Value& updates, + ShapedType& updates_type, ConversionPatternRewriter& rewriter) { auto canonical_update_window_dims = llvm::to_vector( llvm::seq(indices_type.getRank() - 1, updates_type.getRank())); @@ -2936,10 +2954,10 @@ LogicalResult CanonicalizeScatterUpdates( // If index_vector_dim == indices.rank() then insert the implicit extra // dimension into indices to normalize everything to index_vector_dim == // indices.rank() - 1. 
-LogicalResult NormalizeIndexVector(Operation *parent_op, Value &indices, - ShapedType &indices_type, +LogicalResult NormalizeIndexVector(Operation* parent_op, Value& indices, + ShapedType& indices_type, int64_t index_vector_dim, - ConversionPatternRewriter &rewriter) { + ConversionPatternRewriter& rewriter) { if (index_vector_dim == indices_type.getRank()) { llvm::SmallVector new_start_indices_shape( indices_type.getShape().begin(), indices_type.getShape().end()); @@ -2974,7 +2992,7 @@ class ConvertGatherOp : public OpConversionPattern { LogicalResult matchAndRewrite( mhlo::GatherOp gather_op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { + ConversionPatternRewriter& rewriter) const final { Value operand = gather_op.getOperand(); Value start_indices = gather_op.getStartIndices(); @@ -3036,7 +3054,7 @@ class ConvertGatherOp : public OpConversionPattern { auto offset_dims = gather_op.getDimensionNumbers().getOffsetDims(); SmallVector offset_dims_vector(offset_dims.begin(), offset_dims.end()); - const TransposeParams &transpose_params = + const TransposeParams& transpose_params = CanonicalizeOffset(/*result_type=*/result_type, /*original_offset_dims=*/offset_dims_vector); @@ -3174,7 +3192,7 @@ class ConvertWhileOp : public OpConversionPattern { LogicalResult matchAndRewrite( mhlo::WhileOp while_op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { + ConversionPatternRewriter& rewriter) const final { // HLO WhileOp should have two regions: cond and body. if (while_op->getNumRegions() != 2) return failure(); @@ -3190,10 +3208,10 @@ class ConvertWhileOp : public OpConversionPattern { while_op.getLoc(), while_op->getResultTypes(), while_op->getOperands(), /*parallel_iterations=*/10, /*is_stateless=*/false, /*shape_invariant=*/false); - new_while.cond().takeBody(while_op.getCond()); - new_while.body().takeBody(while_op.getBody()); - ReplaceReturnOp(new_while.cond(), rewriter); - ReplaceReturnOp(new_while.body(), rewriter); + new_while.getCond().takeBody(while_op.getCond()); + new_while.getBody().takeBody(while_op.getBody()); + ReplaceReturnOp(new_while.getCond(), rewriter); + ReplaceReturnOp(new_while.getBody(), rewriter); rewriter.replaceOp(while_op, new_while.getResults()); return success(); } @@ -3205,16 +3223,16 @@ class ConvertIfOp : public OpConversionPattern { LogicalResult matchAndRewrite( mhlo::IfOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { + ConversionPatternRewriter& rewriter) const final { // HLO IfOp currently doesn't support stateless auto new_op = rewriter.create( op.getLoc(), op->getResultTypes(), op.getPred(), /*is_stateless=*/false, /*_then_func_name=*/nullptr, /*_else_func_name=*/nullptr); - new_op.then_branch().takeBody(op.getTrueBranch()); - new_op.else_branch().takeBody(op.getFalseBranch()); - ReplaceReturnOp(new_op.then_branch(), rewriter); - ReplaceReturnOp(new_op.else_branch(), rewriter); + new_op.getThenBranch().takeBody(op.getTrueBranch()); + new_op.getElseBranch().takeBody(op.getFalseBranch()); + ReplaceReturnOp(new_op.getThenBranch(), rewriter); + ReplaceReturnOp(new_op.getElseBranch(), rewriter); rewriter.replaceOp(op, new_op.getResults()); return success(); } @@ -3227,7 +3245,7 @@ class ConvertScatterOp : public OpConversionPattern { LogicalResult matchAndRewrite( mhlo::ScatterOp scatter_op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { + ConversionPatternRewriter& rewriter) const final { OperandRange operands = scatter_op.getInputs(); Value indices = 
scatter_op.getScatterIndices(); OperandRange updates = scatter_op.getUpdates(); @@ -3342,7 +3360,7 @@ using ConvertScatterUpdateOp = ConvertScatterOp; // Converts mhlo.pad to tf.PadV2 -Value ConvertPadOp(PatternRewriter &rewriter, Operation *old_op) { +Value ConvertPadOp(PatternRewriter& rewriter, Operation* old_op) { auto pad_op = cast(old_op); mlir::Location loc = pad_op.getLoc(); @@ -3376,7 +3394,7 @@ bool IsTFStyleBroadcast(DenseIntElementsAttr broadcast_dimensions, // Returns the intermediate shape that input tensor should be reshaped to during // legalization of BroadcastInDimOp. -arith::ConstantOp ExpandedShape(PatternRewriter &rewriter, Value input, +arith::ConstantOp ExpandedShape(PatternRewriter& rewriter, Value input, DenseIntElementsAttr broadcast_dimensions, Value output) { // Initialize expanded shape with output rank and dimensions of 1. @@ -3403,7 +3421,7 @@ arith::ConstantOp ExpandedShape(PatternRewriter &rewriter, Value input, /// Performs the lowering to XLA dialect. void LegalizeHloToTf::runOnOperation() { - MLIRContext &context = getContext(); + MLIRContext& context = getContext(); // Add legalization patterns to the list. RewritePatternSet patterns(&getContext()); @@ -3422,8 +3440,8 @@ void LegalizeHloToTf::runOnOperation() { } // end namespace -void PopulateLegalizeHloToTfPatterns(RewritePatternSet *patterns, - MLIRContext *context) { +void PopulateLegalizeHloToTfPatterns(RewritePatternSet* patterns, + MLIRContext* context) { patterns->add< ConvertAvgPoolOp, Convert2DConvOp, Convert1DConvOp, ConvertNonTrivialConvOp, ConvertDynamicSliceOp, diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td index 4a66528c49d..16493c30286 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td @@ -19,7 +19,7 @@ include "mlir/IR/OpBase.td" include "mlir/Dialect/Func/IR/FuncOps.td" include "stablehlo/dialect/ChloOps.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" -include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" +include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.td" // Check if broadcasting is compatible with TF ops. def IsLegalNumpyRankedBroadcast : @@ -40,7 +40,7 @@ def IsNotTFStyleBroadcast : Constraint>, // Return intermediate shape before broadcasting, wrapped in a constant op. def ExpandedShape : NativeCodeCall<"ExpandedShape($_builder, $0, $1, $2)">; -def : Pat<(HLO_ConstantOp:$output $value), (TF_ConstOp $value), +def : Pat<(MHLO_ConstantOp:$output $value), (TF_ConstOp $value), [(TF_Tensor $output)]>; //===----------------------------------------------------------------------===// @@ -50,58 +50,58 @@ def : Pat<(HLO_ConstantOp:$output $value), (TF_ConstOp $value), // context, getting to these ops may require some raising. 
//===----------------------------------------------------------------------===// -foreach fromToBinPair = [[HLO_AddOp, CHLO_BroadcastAddOp, TF_AddV2Op], - [HLO_DivOp, CHLO_BroadcastDivOp, TF_DivOp], - [HLO_ShiftLeftOp, CHLO_BroadcastShiftLeftOp, TF_LeftShiftOp], - [HLO_MaxOp, CHLO_BroadcastMaxOp, TF_MaximumOp], - [HLO_MinOp, CHLO_BroadcastMinOp, TF_MinimumOp], - [HLO_MulOp, CHLO_BroadcastMulOp, TF_MulOp], - [HLO_PowOp, CHLO_BroadcastPowOp, TF_PowOp], - [HLO_SubtractOp, CHLO_BroadcastSubOp, TF_SubOp], - [HLO_Atan2Op, CHLO_BroadcastAtan2Op, TF_Atan2Op]] in { +foreach fromToBinPair = [[MHLO_AddOp, CHLO_BroadcastAddOp, TF_AddV2Op], + [MHLO_DivOp, CHLO_BroadcastDivOp, TF_DivOp], + [MHLO_ShiftLeftOp, CHLO_BroadcastShiftLeftOp, TF_LeftShiftOp], + [MHLO_MaxOp, CHLO_BroadcastMaxOp, TF_MaximumOp], + [MHLO_MinOp, CHLO_BroadcastMinOp, TF_MinimumOp], + [MHLO_MulOp, CHLO_BroadcastMulOp, TF_MulOp], + [MHLO_PowOp, CHLO_BroadcastPowOp, TF_PowOp], + [MHLO_SubtractOp, CHLO_BroadcastSubOp, TF_SubOp], + [MHLO_Atan2Op, CHLO_BroadcastAtan2Op, TF_Atan2Op]] in { def : Pat<(fromToBinPair[0] $l, $r), (fromToBinPair[2] $l, $r)>; def : Pat<(fromToBinPair[1] $l, $r, $broadcast_dimensions), (fromToBinPair[2] $l, $r), [(IsLegalNumpyRankedBroadcast $l, $r, $broadcast_dimensions)]>; } -foreach pair = [[HLO_AndOp, CHLO_BroadcastAndOp, TF_BitwiseAndOp], - [HLO_OrOp, CHLO_BroadcastOrOp, TF_BitwiseOrOp], - [HLO_XorOp, CHLO_BroadcastXorOp, TF_BitwiseXorOp]] in { +foreach pair = [[MHLO_AndOp, CHLO_BroadcastAndOp, TF_BitwiseAndOp], + [MHLO_OrOp, CHLO_BroadcastOrOp, TF_BitwiseOrOp], + [MHLO_XorOp, CHLO_BroadcastXorOp, TF_BitwiseXorOp]] in { def : Pat<(pair[0] TF_IntTensor:$l, TF_IntTensor:$r), (pair[2] $l, $r)>; def : Pat<(pair[1] TF_IntTensor:$l, TF_IntTensor:$r, $broadcast_dimensions), (pair[2] $l, $r), [(IsLegalNumpyRankedBroadcast $l, $r, $broadcast_dimensions)]>; } -foreach pair = [[HLO_AndOp, CHLO_BroadcastAndOp, TF_LogicalAndOp], - [HLO_OrOp, CHLO_BroadcastOrOp, TF_LogicalOrOp]] in { +foreach pair = [[MHLO_AndOp, CHLO_BroadcastAndOp, TF_LogicalAndOp], + [MHLO_OrOp, CHLO_BroadcastOrOp, TF_LogicalOrOp]] in { def : Pat<(pair[0] I1Tensor:$l, I1Tensor:$r), (pair[2] $l, $r)>; def : Pat<(pair[1] I1Tensor:$l, I1Tensor:$r, $broadcast_dimensions), (pair[2] $l, $r), [(IsLegalNumpyRankedBroadcast $l, $r, $broadcast_dimensions)]>; } -def : Pat<(HLO_ShiftRightArithmeticOp $l, $r), (TF_RightShiftOp $l, $r)>; +def : Pat<(MHLO_ShiftRightArithmeticOp $l, $r), (TF_RightShiftOp $l, $r)>; def : Pat<(CHLO_BroadcastShiftRightArithmeticOp $l, $r, $broadcast_dimensions), (TF_RightShiftOp $l, $r), [(IsLegalNumpyRankedBroadcast $l, $r, $broadcast_dimensions)]>; -def : Pat<(HLO_ShiftRightLogicalOp $l, $r), (TF_RightShiftOp $l, $r)>; +def : Pat<(MHLO_ShiftRightLogicalOp $l, $r), (TF_RightShiftOp $l, $r)>; def : Pat<(CHLO_BroadcastShiftRightLogicalOp $l, $r, $broadcast_dimensions), (TF_RightShiftOp $l, $r), [(IsLegalNumpyRankedBroadcast $l, $r, $broadcast_dimensions)]>; -def : Pat<(HLO_FloorOp (HLO_DivOp $l, $r)), (TF_FloorDivOp $l, $r)>; -def : Pat<(HLO_FloorOp (CHLO_BroadcastDivOp $l, $r, +def : Pat<(MHLO_FloorOp (MHLO_DivOp $l, $r)), (TF_FloorDivOp $l, $r)>; +def : Pat<(MHLO_FloorOp (CHLO_BroadcastDivOp $l, $r, $broadcast_dimensions)), (TF_FloorDivOp $l, $r), [(IsLegalNumpyRankedBroadcast $l, $r, $broadcast_dimensions)]>; -def : Pat<(HLO_ComplexOp $r, $i), (TF_ComplexOp $r, $i)>; +def : Pat<(MHLO_ComplexOp $r, $i), (TF_ComplexOp $r, $i)>; -def : Pat<(HLO_RemOp TF_FpOrI32OrI64Tensor:$l, TF_FpOrI32OrI64Tensor:$r), (TF_ModOp $l, $r)>; +def : 
Pat<(MHLO_RemOp TF_FpOrI32OrI64Tensor:$l, TF_FpOrI32OrI64Tensor:$r), (TF_ModOp $l, $r)>; def : Pat<(CHLO_BroadcastRemOp TF_FpOrI32OrI64Tensor:$l, TF_FpOrI32OrI64Tensor:$r, $broadcast_dimensions), (TF_ModOp $l, $r), [(IsLegalNumpyRankedBroadcast $l, $r, $broadcast_dimensions)]>; @@ -110,64 +110,98 @@ def : Pat<(CHLO_BroadcastRemOp TF_FpOrI32OrI64Tensor:$l, TF_FpOrI32OrI64Tensor:$ // Unary op patterns. //===----------------------------------------------------------------------===// -def : Pat<(HLO_ConvertOp HLO_Tensor:$operand), +def : Pat<(MHLO_ConvertOp MHLO_Tensor:$operand), (TF_CastOp $operand, ConstBoolAttrFalse)>; -foreach Mapping = [[HLO_AbsOp, TF_AbsOp], - [HLO_BitcastConvertOp, TF_BitcastOp], - [HLO_CeilOp, TF_CeilOp], - [HLO_CosineOp, TF_CosOp], - [HLO_ExpOp, TF_ExpOp], - [HLO_Expm1Op, TF_Expm1Op], - [HLO_FloorOp, TF_FloorOp], - [HLO_ImagOp, TF_ImagOp], - [HLO_IsFiniteOp, TF_IsFiniteOp], - [HLO_LogOp, TF_LogOp], - [HLO_Log1pOp, TF_Log1pOp], - [HLO_LogisticOp, TF_SigmoidOp], - [HLO_NegOp, TF_NegOp], - [HLO_RealOp, TF_RealOp], - [HLO_RsqrtOp, TF_RsqrtOp], - [HLO_SineOp, TF_SinOp], - [HLO_SignOp, TF_SignOp], - [HLO_SqrtOp, TF_SqrtOp], - [HLO_TanhOp, TF_TanhOp]] in - def : Pat<(Mapping[0] TF_IntOrFpTensor:$input), (Mapping[1] $input)>; - -def : Pat<(HLO_NotOp TF_BoolTensor:$input), (TF_LogicalNotOp $input)>; -def : Pat<(HLO_AbsOp TF_ComplexTensor:$arg), (TF_ComplexAbsOp $arg)>; - -def : Pat<(HLO_BroadcastOp $arg, $shape), +foreach Mapping = [[MHLO_AbsOp, TF_AbsOp], + [MHLO_BitcastConvertOp, TF_BitcastOp], + [MHLO_CeilOp, TF_CeilOp], + [MHLO_CosineOp, TF_CosOp], + [MHLO_ExpOp, TF_ExpOp], + [MHLO_Expm1Op, TF_Expm1Op], + [MHLO_FloorOp, TF_FloorOp], + [MHLO_ImagOp, TF_ImagOp], + [MHLO_IsFiniteOp, TF_IsFiniteOp], + [MHLO_LogOp, TF_LogOp], + [MHLO_Log1pOp, TF_Log1pOp], + [MHLO_LogisticOp, TF_SigmoidOp], + [MHLO_NegOp, TF_NegOp], + [MHLO_RealOp, TF_RealOp], + [MHLO_RsqrtOp, TF_RsqrtOp], + [MHLO_SineOp, TF_SinOp], + [MHLO_SignOp, TF_SignOp], + [MHLO_SqrtOp, TF_SqrtOp], + [MHLO_TanhOp, TF_TanhOp]] in + def : Pat<(Mapping[0] TF_IntOrFpTensor:$input), (Mapping[1] $input)>; + +class GetRankedScalarAttr : + NativeCodeCall<"DenseElementsAttr::get<" # prefix # "int" # width # "_t>(" + "RankedTensorType::get({}, $_builder.getIntegerType(" + # width # signed # "))," # value # ")">; + +def : Pat<(MHLO_NotOp TF_Int8Tensor:$input), + (TF_BitwiseXorOp $input, + (TF_ConstOp (GetRankedScalarAttr<"", 8, "", "-1">)))>; +def : Pat<(MHLO_NotOp TF_Int16Tensor:$input), + (TF_BitwiseXorOp $input, + (TF_ConstOp (GetRankedScalarAttr<"", 16, "", "-1">)))>; +def : Pat<(MHLO_NotOp TF_Int32Tensor:$input), + (TF_BitwiseXorOp $input, + (TF_ConstOp (GetRankedScalarAttr<"", 32, "", "-1">)))>; +def : Pat<(MHLO_NotOp TF_Int64Tensor:$input), + (TF_BitwiseXorOp $input, + (TF_ConstOp (GetRankedScalarAttr<"", 64, "", "-1">)))>; + +def : Pat<(MHLO_NotOp TF_Uint8Tensor:$input), + (TF_BitwiseXorOp $input, + (TF_ConstOp (GetRankedScalarAttr<"u", 8, ", false", "0xFFU">)))>; +def : Pat<(MHLO_NotOp TF_Uint16Tensor:$input), + (TF_BitwiseXorOp $input, + (TF_ConstOp (GetRankedScalarAttr<"u", 16, ", false", + "0xFFFFU">)))>; +def : Pat<(MHLO_NotOp TF_Uint32Tensor:$input), + (TF_BitwiseXorOp $input, + (TF_ConstOp (GetRankedScalarAttr<"u", 32, ", false", + "0xFFFFFFFFUL">)))>; +def : Pat<(MHLO_NotOp TF_Uint64Tensor:$input), + (TF_BitwiseXorOp $input, + (TF_ConstOp (GetRankedScalarAttr<"u", 64, ", false", + "0xFFFFFFFFFFFFFFFFULL">)))>; + +def : Pat<(MHLO_NotOp TF_BoolTensor:$input), (TF_LogicalNotOp $input)>; +def : Pat<(MHLO_AbsOp 
TF_ComplexTensor:$arg), (TF_ComplexAbsOp $arg)>; + +def : Pat<(MHLO_BroadcastOp $arg, $shape), (TF_BroadcastToOp $arg, (TF_ConstOp $shape))>; -def : Pat<(HLO_BroadcastInDimOp:$output $input, $broadcast_dimensions), +def : Pat<(MHLO_BroadcastInDimOp:$output $input, $broadcast_dimensions), (TF_BroadcastToOp $input, (ShapeToConst $output)), [(IsTFStyleBroadcast $broadcast_dimensions, $output)]>; -def : Pat<(HLO_BroadcastInDimOp:$output $input, $broadcast_dimensions), +def : Pat<(MHLO_BroadcastInDimOp:$output $input, $broadcast_dimensions), (TF_BroadcastToOp (TF_ReshapeOp $input, (ExpandedShape $input, $broadcast_dimensions, $output)), (ShapeToConst $output)), [(IsNotTFStyleBroadcast $broadcast_dimensions, $output)]>; -def : Pat<(HLO_TransposeOp $arg, $permutation), +def : Pat<(MHLO_TransposeOp $arg, $permutation), (TF_TransposeOp $arg, (TF_ConstOp $permutation))>; -def : Pat<(HLO_ReverseOp $op, $dims), (TF_ReverseV2Op $op, (TF_ConstOp $dims))>; -def : Pat<(HLO_ReshapeOp:$output $input), +def : Pat<(MHLO_ReverseOp $op, $dims), (TF_ReverseV2Op $op, (TF_ConstOp $dims))>; +def : Pat<(MHLO_ReshapeOp:$output $input), (TF_ReshapeOp $input, (ShapeToConst $output))>; //===----------------------------------------------------------------------===// // Ternary op patterns. //===----------------------------------------------------------------------===// -def : Pat<(HLO_ClampOp $min, $arg, $max), +def : Pat<(MHLO_ClampOp $min, $arg, $max), (TF_MaximumOp (TF_MinimumOp $arg, $max), $min)>; -def : Pat<(HLO_SelectOp $cond, $t, $e), (TF_SelectOp $cond, $t, $e)>; +def : Pat<(MHLO_SelectOp $cond, $t, $e), (TF_SelectOp $cond, $t, $e)>; //===----------------------------------------------------------------------===// // Variadic op patterns. //===----------------------------------------------------------------------===// -def : Pat<(HLO_ConcatenateOp $inputs, $dim), +def : Pat<(MHLO_ConcatenateOp $inputs, $dim), (TF_ConcatV2Op $inputs, (TF_ConstOp $dim))>; //===----------------------------------------------------------------------===// @@ -204,8 +238,8 @@ def IsMhloTFCompareType : AttrConstraint< HasMhloCompareType<"::mlir::mhlo::ComparisonType::NOTYPE">]>, "compare type supported by TensorFlow">; -class HLO_ComparisonDirectionValue : - ConstantAttr; +class MHLO_ComparisonDirectionValue : + ConstantAttr; foreach p = [[TF_EqualOp, CHLO_ComparisonDirectionValue<"EQ">], [TF_NotEqualOp, CHLO_ComparisonDirectionValue<"NE">]] in { @@ -215,9 +249,9 @@ foreach p = [[TF_EqualOp, CHLO_ComparisonDirectionValue<"EQ">], [(IsLegalNumpyRankedBroadcast $l, $r, $broadcast_dimensions)]>; } -foreach p = [[TF_EqualOp, HLO_ComparisonDirectionValue<"EQ">], - [TF_NotEqualOp, HLO_ComparisonDirectionValue<"NE">]] in { - def : Pat<(HLO_CompareOp $l, $r, p[1], IsMhloTFCompareType:$type), +foreach p = [[TF_EqualOp, MHLO_ComparisonDirectionValue<"EQ">], + [TF_NotEqualOp, MHLO_ComparisonDirectionValue<"NE">]] in { + def : Pat<(MHLO_CompareOp $l, $r, p[1], IsMhloTFCompareType:$type), (p[0] $l, $r, ConstBoolAttrTrue)>; } @@ -231,32 +265,33 @@ foreach p = [[TF_GreaterEqualOp, CHLO_ComparisonDirectionValue<"GE">], [(IsLegalNumpyRankedBroadcast $l, $r, $broadcast_dimensions)]>; } -foreach p = [[TF_GreaterEqualOp, HLO_ComparisonDirectionValue<"GE">], - [TF_GreaterOp, HLO_ComparisonDirectionValue<"GT">], - [TF_LessEqualOp, HLO_ComparisonDirectionValue<"LE">], - [TF_LessOp, HLO_ComparisonDirectionValue<"LT">]] in { - def : Pat<(HLO_CompareOp $l, $r, p[1], IsMhloTFCompareType:$type), +foreach p = [[TF_GreaterEqualOp, MHLO_ComparisonDirectionValue<"GE">], 
+ [TF_GreaterOp, MHLO_ComparisonDirectionValue<"GT">], + [TF_LessEqualOp, MHLO_ComparisonDirectionValue<"LE">], + [TF_LessOp, MHLO_ComparisonDirectionValue<"LT">]] in { + def : Pat<(MHLO_CompareOp $l, $r, p[1], IsMhloTFCompareType:$type), (p[0] $l, $r)>; } def ConvertDotOp : NativeCodeCall<"ConvertDotOp($_builder, " "$0.getDefiningOp())">; -def : Pat<(HLO_DotOp:$old_value StaticShapeTensorOf<[TF_ElementType]>:$lhs, +def : Pat<(MHLO_DotOp:$old_value StaticShapeTensorOf<[TF_ElementType]>:$lhs, StaticShapeTensorOf<[TF_ElementType]>:$rhs, $precision_config), (ConvertDotOp $old_value)>; def ConvertDotGeneralOp : NativeCodeCall<"ConvertDotGeneralOp($_builder, " "$0.getDefiningOp())">; -def : Pat<(HLO_DotGeneralOp:$old_value AnyStaticShapeTensor:$lhs, - AnyStaticShapeTensor:$rhs, $dot_dimension_numbers, - $precision_config), +def : Pat<(MHLO_DotGeneralOp:$old_value + StaticShapeTensorOf<[TF_ElementType]>:$lhs, + StaticShapeTensorOf<[TF_ElementType]>:$rhs, + $dot_dimension_numbers, $precision_config), (ConvertDotGeneralOp $old_value)>; def IsZero : Constraint() == 0">>; def ConvertPadOp : NativeCodeCall< "ConvertPadOp($_builder, $0.getDefiningOp())">; -def : Pat<(HLO_PadOp:$old_value $input, $pad_value, $pad_low, $pad_high, +def : Pat<(MHLO_PadOp:$old_value $input, $pad_value, $pad_low, $pad_high, $pad_interior), (ConvertPadOp $old_value), [(IsZero $pad_interior)]>; @@ -282,30 +317,30 @@ def SameTypeOrDefaultCompare : Constraint, + (MHLO_FloorOp:$floor $input)), + (MHLO_ConstantOp $half), + MHLO_ComparisonDirectionValue<"GT">, $compare_type0), - (HLO_AndOp - (HLO_CompareOp + (MHLO_AndOp + (MHLO_CompareOp $frac1, - (HLO_ConstantOp $half1), - HLO_ComparisonDirectionValue<"EQ">, + (MHLO_ConstantOp $half1), + MHLO_ComparisonDirectionValue<"EQ">, $compare_type1), - (HLO_CompareOp - (HLO_SubtractOp + (MHLO_CompareOp + (MHLO_SubtractOp $floor1, - (HLO_MulOp - (HLO_FloorOp (HLO_MulOp $input, (HLO_ConstantOp $half2))), - (HLO_ConstantOp $two))), - (HLO_ConstantOp $one1), - HLO_ComparisonDirectionValue<"EQ">, + (MHLO_MulOp + (MHLO_FloorOp (MHLO_MulOp $input, (MHLO_ConstantOp $half2))), + (MHLO_ConstantOp $two))), + (MHLO_ConstantOp $one1), + MHLO_ComparisonDirectionValue<"EQ">, $compare_type2))), - (HLO_AddOp $floor2, (HLO_ConstantOp $one)), + (MHLO_AddOp $floor2, (MHLO_ConstantOp $one)), $floor3), (TF_RoundOp $input), [(ValueEquals<"1.0"> $one), @@ -330,19 +365,19 @@ def : Pat<(HLO_SelectOp // if ((rem[i] < 0) != (arg0[i] < 0) && arg0[i] != 0) // rem[i] += arg1[i] // return rem -def : Pat<(HLO_SelectOp - (HLO_AndOp - (HLO_CompareOp - (HLO_CompareOp:$rltz - (HLO_RemOp:$rem $arg, $arg1), - (HLO_ConstantOp $cst), - HLO_ComparisonDirectionValue<"LT">, +def : Pat<(MHLO_SelectOp + (MHLO_AndOp + (MHLO_CompareOp + (MHLO_CompareOp:$rltz + (MHLO_RemOp:$rem $arg, $arg1), + (MHLO_ConstantOp $cst), + MHLO_ComparisonDirectionValue<"LT">, $compare_type), - (HLO_CompareOp:$arg1ltz $arg1, (HLO_ConstantOp $cst1), HLO_ComparisonDirectionValue<"LT">, $compare_type1), - HLO_ComparisonDirectionValue<"NE">, + (MHLO_CompareOp:$arg1ltz $arg1, (MHLO_ConstantOp $cst1), MHLO_ComparisonDirectionValue<"LT">, $compare_type1), + MHLO_ComparisonDirectionValue<"NE">, $compare_type2), - (HLO_CompareOp:$rnz $rem1, (HLO_ConstantOp $cst2), HLO_ComparisonDirectionValue<"NE">, $compare_type3)), - (HLO_AddOp $rem2, $arg1), + (MHLO_CompareOp:$rnz $rem1, (MHLO_ConstantOp $cst2), MHLO_ComparisonDirectionValue<"NE">, $compare_type3)), + (MHLO_AddOp $rem2, $arg1), $rem3), (TF_FloorModOp $arg, $arg1), [(ValueEquals<"0.0"> $cst), @@ -352,8 +387,7 @@ def 
: Pat<(HLO_SelectOp (SameValue $rem, $rem2), (SameValue $rem, $rem3), (SameTypeOrDefaultCompare $compare_type, $cst), - (SameTypeOrDefaultCompare $compare_type1, $cst1), - (SameTypeOrDefaultCompare $compare_type2, $compare_type)]>; + (SameTypeOrDefaultCompare $compare_type1, $cst1)]>; // Converts a dag of HLOs representing floor_mod with a constant to // tf.FloorMod. The pattern matched executes the following computation: @@ -364,15 +398,15 @@ def : Pat<(HLO_SelectOp // if (rem[i] < 0 && rem[i] != 0) // rem[i] += cst // return rem -def : Pat<(HLO_SelectOp - (HLO_AndOp - (HLO_CompareOp:$rltz - (HLO_RemOp:$rem $arg, (HLO_ConstantOp $cst)), - (HLO_ConstantOp $cst1), - HLO_ComparisonDirectionValue<"LT">, +def : Pat<(MHLO_SelectOp + (MHLO_AndOp + (MHLO_CompareOp:$rltz + (MHLO_RemOp:$rem $arg, (MHLO_ConstantOp $cst)), + (MHLO_ConstantOp $cst1), + MHLO_ComparisonDirectionValue<"LT">, $compare_type), - (HLO_CompareOp:$rnz $rem1, (HLO_ConstantOp $cst2), HLO_ComparisonDirectionValue<"NE">, $compare_type3)), - (HLO_AddOp $rem2, (HLO_ConstantOp $cst3)), + (MHLO_CompareOp:$rnz $rem1, (MHLO_ConstantOp $cst2), MHLO_ComparisonDirectionValue<"NE">, $compare_type3)), + (MHLO_AddOp $rem2, (MHLO_ConstantOp $cst3)), $rem3), (TF_FloorModOp $arg, (TF_ConstOp $cst3)), [(ValueGreaterThanZero $cst), @@ -410,24 +444,24 @@ def : Pat<(HLO_SelectOp // the same function in this case the sign function. Named values like 'div' // refer to the same value produced by the same function, in this case division. // Mathematical symbols do not indicate a re-use of the value. -def : Pat<(HLO_RoundOp - (HLO_SelectOp - (HLO_AndOp - (HLO_CompareOp - (HLO_RemOp:$rem $arg0, $arg1), - (HLO_ConstantOp $cst), - HLO_ComparisonDirectionValue<"NE">, +def : Pat<(MHLO_RoundOp + (MHLO_SelectOp + (MHLO_AndOp + (MHLO_CompareOp + (MHLO_RemOp:$rem $arg0, $arg1), + (MHLO_ConstantOp $cst), + MHLO_ComparisonDirectionValue<"NE">, $compare_type), - (HLO_CompareOp - (HLO_SignOp $arg1), - (HLO_SignOp $rem1), - HLO_ComparisonDirectionValue<"NE">, + (MHLO_CompareOp + (MHLO_SignOp $arg1), + (MHLO_SignOp $rem1), + MHLO_ComparisonDirectionValue<"NE">, $compare_type1)), - (HLO_AddOp - (HLO_DivOp:$div - (HLO_SubtractOp $arg0, $rem2), + (MHLO_AddOp + (MHLO_DivOp:$div + (MHLO_SubtractOp $arg0, $rem2), $arg1b), - (HLO_ConstantOp $cst_neg1)), + (MHLO_ConstantOp $cst_neg1)), $div1)), (TF_FloorDivOp $arg0, $arg1), [(ValueEquals<"0.0"> $cst), @@ -466,24 +500,24 @@ def : Pat<(HLO_RoundOp // the same function in this case the sign function. Named values like 'div' // refer to the same value produced by the same function, in this case division. // Mathematical symbols do not indicate a re-use of the value. 
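A scalar sketch of the computation the pattern below (and the related constant-divisor variants that follow) recognizes, written in plain C++ so the compare/select DAG is easier to follow; modeling mhlo.remainder and mhlo.sign with std::fmod/std::copysign is an assumption for illustration only:

    #include <cmath>

    // floor_div(x, y): truncating quotient, adjusted down by one whenever the
    // remainder is non-zero and its sign differs from the divisor's sign.
    float FloorDivSketch(float x, float y) {
      float rem = std::fmod(x, y);        // remainder keeps the sign of x
      float div = (x - rem) / y;          // exact truncating quotient
      bool adjust = rem != 0.0f &&
                    std::copysign(1.0f, y) != std::copysign(1.0f, rem);
      return std::round(adjust ? div - 1.0f : div);  // outer round over the select
    }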
-def : Pat<(HLO_RoundOp - (HLO_SelectOp - (HLO_AndOp - (HLO_CompareOp - (HLO_RemOp:$rem $arg0, (HLO_ConstantOp:$cst $cstv)), - (HLO_ConstantOp $cst_zero), - HLO_ComparisonDirectionValue<"NE">, +def : Pat<(MHLO_RoundOp + (MHLO_SelectOp + (MHLO_AndOp + (MHLO_CompareOp + (MHLO_RemOp:$rem $arg0, (MHLO_ConstantOp:$cst $cstv)), + (MHLO_ConstantOp $cst_zero), + MHLO_ComparisonDirectionValue<"NE">, $compare_type), - (HLO_CompareOp - (HLO_ConstantOp $cst_sgn), - (HLO_SignOp $rem1), - HLO_ComparisonDirectionValue<"NE">, + (MHLO_CompareOp + (MHLO_ConstantOp $cst_sgn), + (MHLO_SignOp $rem1), + MHLO_ComparisonDirectionValue<"NE">, $compare_type1)), - (HLO_AddOp - (HLO_MulOp:$mul - (HLO_SubtractOp $arg0, $rem2), - (HLO_ConstantOp $cst_recip)), - (HLO_ConstantOp $cst_neg1)), + (MHLO_AddOp + (MHLO_MulOp:$mul + (MHLO_SubtractOp $arg0, $rem2), + (MHLO_ConstantOp $cst_recip)), + (MHLO_ConstantOp $cst_neg1)), $mul1)), (TF_FloorDivOp $arg0, $cst), [(ValueEquals<"0.0"> $cst_zero), @@ -524,24 +558,24 @@ def : Pat<(HLO_RoundOp // the same function in this case the sign function. Named values like 'div' // refer to the same value produced by the same function, in this case division. // Mathematical symbols do not indicate a re-use of the value. -def : Pat<(HLO_RoundOp - (HLO_SelectOp - (HLO_AndOp - (HLO_CompareOp - (HLO_RemOp:$rem $arg0, (HLO_ConstantOp:$cst $cstv)), - (HLO_ConstantOp $cst_zero), - HLO_ComparisonDirectionValue<"NE">, +def : Pat<(MHLO_RoundOp + (MHLO_SelectOp + (MHLO_AndOp + (MHLO_CompareOp + (MHLO_RemOp:$rem $arg0, (MHLO_ConstantOp:$cst $cstv)), + (MHLO_ConstantOp $cst_zero), + MHLO_ComparisonDirectionValue<"NE">, $compare_type), - (HLO_CompareOp - (HLO_ConstantOp $cst_sgn), - (HLO_SignOp $rem1), - HLO_ComparisonDirectionValue<"NE">, + (MHLO_CompareOp + (MHLO_ConstantOp $cst_sgn), + (MHLO_SignOp $rem1), + MHLO_ComparisonDirectionValue<"NE">, $compare_type1)), - (HLO_AddOp - (HLO_DivOp:$div - (HLO_SubtractOp $arg0, $rem2), - (HLO_ConstantOp $cstv1)), - (HLO_ConstantOp $cst_neg1)), + (MHLO_AddOp + (MHLO_DivOp:$div + (MHLO_SubtractOp $arg0, $rem2), + (MHLO_ConstantOp $cstv1)), + (MHLO_ConstantOp $cst_neg1)), $div1)), (TF_FloorDivOp $arg0, $cst), [(ValueEquals<"0.0"> $cst_zero), @@ -586,27 +620,27 @@ def : Pat<(HLO_RoundOp // the same function in this case the sign function. Named values like 'div' // refer to the same value produced by the same function, in this case division. // Mathematical symbols do not indicate a re-use of the value. -def : Pat<(HLO_RoundOp - (HLO_SelectOp - (HLO_AndOp - (HLO_CompareOp - (HLO_RemOp:$rem $arg0, - (HLO_BroadcastInDimOp:$bcst - (HLO_ConstantOp $cstv), +def : Pat<(MHLO_RoundOp + (MHLO_SelectOp + (MHLO_AndOp + (MHLO_CompareOp + (MHLO_RemOp:$rem $arg0, + (MHLO_BroadcastInDimOp:$bcst + (MHLO_ConstantOp $cstv), $broadcast_dimension)), - (HLO_ConstantOp $cst_zero), - HLO_ComparisonDirectionValue<"NE">, + (MHLO_ConstantOp $cst_zero), + MHLO_ComparisonDirectionValue<"NE">, $compare_type), - (HLO_CompareOp - (HLO_ConstantOp $cst_sgn), - (HLO_SignOp $rem1), - HLO_ComparisonDirectionValue<"NE">, + (MHLO_CompareOp + (MHLO_ConstantOp $cst_sgn), + (MHLO_SignOp $rem1), + MHLO_ComparisonDirectionValue<"NE">, $compare_type1)), - (HLO_AddOp - (HLO_DivOp:$div - (HLO_SubtractOp $arg0, $rem2), + (MHLO_AddOp + (MHLO_DivOp:$div + (MHLO_SubtractOp $arg0, $rem2), $bcst1), - (HLO_ConstantOp $cst_neg1)), + (MHLO_ConstantOp $cst_neg1)), $div1)), (TF_FloorDivOp $arg0, $bcst), [(ValueEquals<"0.0"> $cst_zero), @@ -623,6 +657,6 @@ def : Pat<(HLO_RoundOp // TorchIndexSelect op patterns. 
//===----------------------------------------------------------------------===// -def : Pat<(HLO_TorchIndexSelectOp $params, $indices, $axis, $batch_dims), +def : Pat<(MHLO_TorchIndexSelectOp $params, $indices, $axis, $batch_dims), (TF_GatherV2Op $params, $indices, (TF_ConstOp $axis), $batch_dims)>; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/localize_var_handles.cc b/tensorflow/compiler/mlir/tensorflow/transforms/localize_var_handles.cc index 54327cefe80..9aab63292e2 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/localize_var_handles.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/localize_var_handles.cc @@ -18,6 +18,7 @@ limitations under the License. #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" +#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project @@ -50,11 +51,11 @@ struct LocalizeVarHandlesPass void MaybeCreateVarHandleForOp(Operation* op, DataFlowSolver& solver) { Value resource; if (auto read = llvm::dyn_cast(op)) { - resource = read.resource(); + resource = read.getResource(); } else if (auto write = llvm::dyn_cast(op)) { - resource = write.resource(); + resource = write.getResource(); } else if (auto next = llvm::dyn_cast(op)) { - resource = next.iterator(); + resource = next.getIterator(); } if (llvm::dyn_cast_or_null(resource.getDefiningOp())) { @@ -79,11 +80,11 @@ void MaybeCreateVarHandleForOp(Operation* op, DataFlowSolver& solver) { container = ""; shared_name = global.getSymName(); } else if (auto handle = llvm::dyn_cast(source)) { - container = handle.container(); - shared_name = handle.shared_name(); + container = handle.getContainer(); + shared_name = handle.getSharedName(); } else if (auto it = llvm::dyn_cast(source)) { - container = it.container(); - shared_name = it.shared_name(); + container = it.getContainer(); + shared_name = it.getSharedName(); } else { // Can't happen, as long as this file and resource_dataflow.cc are in sync. return; @@ -98,7 +99,7 @@ void MaybeCreateVarHandleForOp(Operation* op, DataFlowSolver& solver) { // See core/kernels/data/iterator_ops.cc.) 
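The ConstantPropagationAnalysis.h include added above pairs with an extra solver.load call in this pass's runOnOperation() (the template arguments were lost in the surrounding text); the same change is repeated in lower_variable_ops_to_ml_program.cc later in this patch. A sketch of the intended solver setup, assuming the usual MLIR dataflow analysis names:

    DataFlowSolver solver;
    solver.load<mlir::dataflow::DeadCodeAnalysis>();
    solver.load<mlir::dataflow::SparseConstantPropagation>();  // newly loaded
    solver.load<TF::ResourceDataflowAnalysis>();
    if (failed(solver.initializeAndRun(module))) return signalPassFailure();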
resource_op = builder.create( op->getLoc(), resource.getType(), shared_name, container, - it.output_types(), it.output_shapes()); + it.getOutputTypes(), it.getOutputShapes()); } else { resource_op = builder.create( op->getLoc(), resource.getType(), container, shared_name); @@ -111,6 +112,7 @@ void LocalizeVarHandlesPass::runOnOperation() { DataFlowSolver solver; solver.load(); + solver.load(); solver.load(); if (failed(solver.initializeAndRun(module))) return signalPassFailure(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_globals_to_ml_program.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_globals_to_ml_program.cc index c4ddc1c23a7..75ca6e7ce0d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_globals_to_ml_program.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_globals_to_ml_program.cc @@ -121,12 +121,6 @@ SymbolRefAttr lookupGlobalTensor(func::FuncOp func, Value resource, } static LogicalResult convertTFGlobals(ModuleOp module) { - if (auto sessionInitializer = - tf_saved_model::GetSessionInitializerOp(module)) { - return sessionInitializer.emitError() - << "Session initializer is not supported yet"; - } - OpBuilder globalBuilder(module.getBodyRegion()); DenseMap opToName; for (auto globalTensor : module.getOps()) { @@ -155,22 +149,22 @@ static LogicalResult convertTFGlobals(ModuleOp module) { } bool success = true; func.walk([&](mlir::TF::ReadVariableOp op) { - auto sym = lookupGlobalTensor(func, op.resource(), syms, opToName); + auto sym = lookupGlobalTensor(func, op.getResource(), syms, opToName); success &= !!sym; if (!success) return; OpBuilder builder(op); auto load = builder.create( - op.getLoc(), op.value().getType(), sym); - op.value().replaceAllUsesWith(load.getResult()); + op.getLoc(), op.getValue().getType(), sym); + op.getValue().replaceAllUsesWith(load.getResult()); op.erase(); }); func.walk([&](mlir::TF::AssignVariableOp op) { - auto sym = lookupGlobalTensor(func, op.resource(), syms, opToName); + auto sym = lookupGlobalTensor(func, op.getResource(), syms, opToName); success &= !!sym; if (!success) return; OpBuilder builder(op); builder.create(op.getLoc(), sym, - op.value()); + op.getValue()); op.erase(); }); if (!success) return failure(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc index bb54390816a..1c9b1e03a66 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc @@ -260,8 +260,8 @@ class LowerAddNOp : public RewritePattern { // support variant type so variant types require special handling. if (getElementTypeOrSelf(addn_op.getType()).isa()) return failure(); - llvm::SmallVector operands(addn_op.inputs().begin(), - addn_op.inputs().end()); + llvm::SmallVector operands(addn_op.getInputs().begin(), + addn_op.getInputs().end()); int64_t n = operands.size(); // Keep doing tree-based reduction when there are more than one operand. @@ -331,8 +331,8 @@ class LowerDynamicStitchOp : public RewritePattern { // Extract out all the constant indices' attributes and verify that data // types are static. 
SmallVector indices; - indices.reserve(op.N()); - for (auto it : llvm::zip(op.indices(), op.data())) { + indices.reserve(op.getN()); + for (auto it : llvm::zip(op.getIndices(), op.getData())) { Value index = std::get<0>(it); Value data = std::get<1>(it); @@ -361,7 +361,7 @@ class LowerDynamicStitchOp : public RewritePattern { // Prepare each of the output item by unpacking data and then putting it to // the specified index. SmallVector values(out_ty.getDimSize(0)); - for (auto it : llvm::zip(indices, op.data())) { + for (auto it : llvm::zip(indices, op.getData())) { DenseIntElementsAttr index_attr = std::get<0>(it); Value data = std::get<1>(it); @@ -406,18 +406,18 @@ class ConvertFakeQuantWithMinMaxVarsOp : public RewritePattern { PatternRewriter &rewriter) const override { auto op = cast(src_op); - auto input = op.inputs(); + auto input = op.getInputs(); auto input_ty = input.getType().cast(); auto element_ty = input_ty.getElementType(); auto scalar_ty = tensorflow::GetTypeFromTFTensorShape({}, element_ty); - auto num_bits = op.num_bits(); - auto narrow_range = op.narrow_range(); + auto num_bits = op.getNumBits(); + auto narrow_range = op.getNarrowRange(); const double bits_min = narrow_range ? 1 : 0; const double bits_max = (1 << num_bits) - 1; - auto float_min = op.min(); - auto float_max = op.max(); + auto float_min = op.getMin(); + auto float_max = op.getMax(); auto float_diff = rewriter.create(op.getLoc(), float_max, float_min); @@ -534,7 +534,7 @@ class LowerInvertPermutationOp : public RewritePattern { auto op = cast(src_op); Location loc = op.getLoc(); - auto x_type = op.x().getType().dyn_cast(); + auto x_type = op.getX().getType().dyn_cast(); // x input must have static shape. if (!x_type || !x_type.hasStaticShape()) { return failure(); @@ -555,10 +555,10 @@ class LowerInvertPermutationOp : public RewritePattern { auto shape = rewriter.create( loc, DenseElementsAttr::get( shape_type, {static_cast(x_type.getDimSize(0)), 1})); - auto indices = rewriter.create(loc, op.x(), shape); + auto indices = rewriter.create(loc, op.getX(), shape); - rewriter.replaceOpWithNewOp(op, result_type, op.x(), - indices, updates); + rewriter.replaceOpWithNewOp( + op, result_type, op.getX(), indices, updates); return success(); } }; @@ -616,8 +616,8 @@ class LowerLgammaOp : public RewritePattern { auto op = cast(src_op); Location loc = op.getLoc(); - Value input = op.x(); - TensorType original_tensor_type = op.x().getType().cast(); + Value input = op.getX(); + TensorType original_tensor_type = op.getX().getType().cast(); // The approximation is not precise enough for float16. Do the computation // in float32 for that case. @@ -814,13 +814,13 @@ class LowerPackOp : public RewritePattern { auto axis_value = rewriter.create( loc, DenseElementsAttr::get(tensorflow::GetTypeFromTFTensorShape( {}, rewriter.getIntegerType(64)), - op.axis())); - int64_t axis = op.axis(); + op.getAxis())); + int64_t axis = op.getAxis(); Type prev_input_ty, inferred_ty; SmallVector expanded_inputs; - expanded_inputs.reserve(op.N()); - for (Value input : op.values()) { + expanded_inputs.reserve(op.getN()); + for (Value input : op.getValues()) { // If input type is different than the previous input type, infer the // output type. Otherwise, use the already inferred output type from the // previous iteration. 
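Most of the churn in this file is the same mechanical rename applied throughout the patch: raw ODS accessors are replaced by prefixed getters. A before/after sketch using tf.Pack, purely for illustration:

    // Before the rename (raw ODS accessor names):
    //   int64_t axis = pack_op.axis();
    //   for (Value input : pack_op.values()) { /* ... */ }
    // After the rename ("get" + CamelCase; e.g. block_shape -> getBlockShape()):
    int64_t axis = pack_op.getAxis();
    for (Value input : pack_op.getValues()) { /* body unchanged */ }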
@@ -887,17 +887,17 @@ class LowerSpaceToBatchNDOp : public RewritePattern { auto op = cast(src_op); Location loc = op.getLoc(); - auto input_type = op.input().getType().cast(); + auto input_type = op.getInput().getType().cast(); auto element_type = input_type.getElementType(); if (!input_type.hasStaticShape()) { return failure(); } ArrayRef input_shape = input_type.getShape(); - auto block_shape_type = op.block_shape().getType().cast(); + auto block_shape_type = op.getBlockShape().getType().cast(); if (!block_shape_type.hasStaticShape()) { return failure(); } - auto paddings_type = op.paddings().getType().cast(); + auto paddings_type = op.getPaddings().getType().cast(); if (!paddings_type.hasRank()) { return failure(); } @@ -913,12 +913,12 @@ class LowerSpaceToBatchNDOp : public RewritePattern { auto block_shape_i64_type = tensorflow::GetTypeFromTFTensorShape( block_shape_type.getShape(), rewriter.getIntegerType(64)); auto block_shape_i64 = - rewriter.create(loc, block_shape_i64_type, op.block_shape()); + rewriter.create(loc, block_shape_i64_type, op.getBlockShape()); auto paddings_i64_type = tensorflow::GetTypeFromTFTensorShape( paddings_type.getShape(), rewriter.getIntegerType(64)); auto paddings_i64 = - rewriter.create(loc, paddings_i64_type, op.paddings()); + rewriter.create(loc, paddings_i64_type, op.getPaddings()); auto pad00 = rewriter.create( loc, DenseElementsAttr::get( @@ -942,8 +942,8 @@ class LowerSpaceToBatchNDOp : public RewritePattern { ElementsAttr paddings; llvm::SmallVector block_shape_ints; auto padded_shape = llvm::to_vector<4>(input_shape); - if (matchPattern(op.block_shape(), m_Constant(&block_shape)) && - matchPattern(op.paddings(), m_Constant(&paddings))) { + if (matchPattern(op.getBlockShape(), m_Constant(&block_shape)) && + matchPattern(op.getPaddings(), m_Constant(&paddings))) { for (uint64_t i = 0; i < block_rank; i++) { int64_t paddings_sum = paddings.getValues()[{i, 0}].getSExtValue() + @@ -955,7 +955,7 @@ class LowerSpaceToBatchNDOp : public RewritePattern { } } else { for (int i = 0; i < block_rank; i++) { - padded_shape[i + 1] = ShapedType::kDynamicSize; + padded_shape[i + 1] = ShapedType::kDynamic; } block_shape_ints.resize(block_shape_type.getNumElements(), -1); } @@ -964,7 +964,7 @@ class LowerSpaceToBatchNDOp : public RewritePattern { tensorflow::GetTypeFromTFTensorShape(padded_shape, element_type); // padded = pad(input, full_paddings) auto padded = - rewriter.create(loc, padded_type, op.input(), full_paddings); + rewriter.create(loc, padded_type, op.getInput(), full_paddings); auto paddings_sum_type = tensorflow::GetTypeFromTFTensorShape( {input_rank}, rewriter.getIntegerType(64)); @@ -994,7 +994,7 @@ class LowerSpaceToBatchNDOp : public RewritePattern { rewriter .create(loc, padded_shape_splits_types, zero_i32, padded_shape_tensor) - .output()); + .getOutput()); SmallVector block_shape_splits_types( block_rank, @@ -1003,7 +1003,7 @@ class LowerSpaceToBatchNDOp : public RewritePattern { rewriter .create(loc, block_shape_splits_types, zero_i32, block_shape_i64) - .output()); + .getOutput()); SmallVector outer_shape_ints; SmallVector outer_shape_vals; @@ -1099,7 +1099,7 @@ class LowerBatchToSpaceND : public RewritePattern { LogicalResult matchAndRewrite(Operation *src_op, PatternRewriter &rewriter) const override { auto op = cast(src_op); - auto input = op.input(); + auto input = op.getInput(); auto input_ty = input.getType().cast(); auto element_ty = input_ty.getElementType(); if (!input_ty.hasStaticShape()) { @@ -1111,8 +1111,8 @@ class 
LowerBatchToSpaceND : public RewritePattern { DenseIntElementsAttr block_shape; DenseIntElementsAttr crops; - if (!matchPattern(op.block_shape(), m_Constant(&block_shape)) || - !matchPattern(op.crops(), m_Constant(&crops))) { + if (!matchPattern(op.getBlockShape(), m_Constant(&block_shape)) || + !matchPattern(op.getCrops(), m_Constant(&crops))) { return failure(); } @@ -1279,11 +1279,15 @@ class LowerSparseMatMulOp : public RewritePattern { // Result type must be f32 for applying the pattern (currently this is // required by the op anyway but this might change). - if (!op.product().getType().cast().getElementType().isF32()) { + if (!op.getProduct() + .getType() + .cast() + .getElementType() + .isF32()) { return failure(); } MLIRContext *context = rewriter.getContext(); - llvm::SmallVector operands{op.a(), op.b()}; + llvm::SmallVector operands{op.getA(), op.getB()}; for (Value &operand : operands) { TensorType tensor_type = operand.getType().cast(); Type element_type = tensor_type.getElementType(); @@ -1302,8 +1306,8 @@ class LowerSparseMatMulOp : public RewritePattern { operand = rewriter.create(op.getLoc(), tensor_type_f32, operand); } Value result = rewriter.create( - op.getLoc(), op.product().getType(), operands[0], operands[1], - op.transpose_a(), op.transpose_b()); + op.getLoc(), op.getProduct().getType(), operands[0], operands[1], + op.getTransposeA(), op.getTransposeB()); rewriter.replaceOp(op, {result}); return success(); @@ -1319,8 +1323,8 @@ class Lower_UnaryOpsComposition LogicalResult matchAndRewrite(_UnaryOpsCompositionOp op, PatternRewriter &rewriter) const override { - Value result = op.x(); - for (StringRef op_name : op.op_names().getAsValueRange()) { + Value result = op.getX(); + for (StringRef op_name : op.getOpNames().getAsValueRange()) { std::string full_name = "tf." + op_name.str(); // All ops in the sequences have the same result type as the original // result type. @@ -1372,10 +1376,10 @@ class LowerResizeNearestNeighbor : public RewritePattern { auto loc = op.getLoc(); auto result_ty = op.getType().cast(); - auto input = op.images(); + auto input = op.getImages(); auto input_ty = input.getType().cast(); auto input_element_ty = input_ty.getElementType(); - auto out_size = op.size(); + auto out_size = op.getSize(); auto out_size_ty = out_size.getType().cast(); auto out_size_element_ty = out_size_ty.getElementType(); @@ -1425,7 +1429,7 @@ class LowerResizeNearestNeighbor : public RewritePattern { in_y_cst < 0 || in_x_cst < 0 ? -1 : in_y_cst * in_x_cst; // TODO(suderman): Add support for these optional parameters. 
- if (op.align_corners() == true || op.half_pixel_centers() == true) { + if (op.getAlignCorners() == true || op.getHalfPixelCenters() == true) { return failure(); } @@ -1615,14 +1619,15 @@ struct LowerRollOp : public RewritePattern { PatternRewriter &rewriter) const override { auto tf_roll_op = cast(op); - auto input_ty = tf_roll_op.input().getType().dyn_cast(); + auto input_ty = + tf_roll_op.getInput().getType().dyn_cast(); if (!input_ty || !input_ty.hasStaticShape()) { return rewriter.notifyMatchFailure( op, "require the type of input to have static shapes"); } DenseIntElementsAttr shift_attr; - Value shift = tf_roll_op.shift(); + Value shift = tf_roll_op.getShift(); auto shift_ranked_attr_type = shift.getType().dyn_cast(); if (!shift_ranked_attr_type || !matchPattern(shift, m_Constant(&shift_attr))) { @@ -1630,7 +1635,7 @@ struct LowerRollOp : public RewritePattern { } DenseIntElementsAttr axis_attr; - Value axis = tf_roll_op.axis(); + Value axis = tf_roll_op.getAxis(); auto axis_ranked_attr_type = axis.getType().dyn_cast(); if (!axis_ranked_attr_type || !matchPattern(axis, m_Constant(&axis_attr))) { return failure(); @@ -1681,7 +1686,7 @@ struct LowerRollOp : public RewritePattern { size); }; - auto result = tf_roll_op.input(); + auto result = tf_roll_op.getInput(); auto scalar_type = tensorflow::GetTypeFromTFTensorShape({}, rewriter.getIntegerType(32)); for (int i = 0; i < adjusted_axis.size(); ++i) { @@ -1697,7 +1702,7 @@ struct LowerRollOp : public RewritePattern { rewriter.create(op->getLoc(), scalar_type, dim_attr); auto concat_op = rewriter.create( op->getLoc(), input_ty, - ArrayRef({slice_op_1.output(), slice_op_2.output()}), + ArrayRef({slice_op_1.getOutput(), slice_op_2.getOutput()}), concat_dim); result = concat_op.getResult(); } @@ -1721,7 +1726,7 @@ class LowerSoftmaxOp : public OpRewritePattern { LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { - Value logits = op.logits(); + Value logits = op.getLogits(); auto loc = op.getLoc(); // Note that the TensorFlow Softmax op verifies that the input rank is @@ -1833,7 +1838,8 @@ void PopulateTFLoweringBeforeHLOPatterns(MLIRContext *context, LowerSquaredDifferenceOpOnRealTensors, LowerSquaredDifferenceOpOneComplexTensors, LowerTanhGradOp, - LowerTruncateDivOp, + LowerTruncateDivOpOnIntTensors, + LowerTruncateDivOpOnFloatTensors, LowerXdivyOp, LowerXlog1pyOp, LowerXlogyOp>(context); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td index 8464bc24610..7d0bb5b273a 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td @@ -218,10 +218,23 @@ def LowerMulNoNanOp : BinaryNoNanPat; // TruncateDiv op patterns. //===----------------------------------------------------------------------===// -def LowerTruncateDivOp : Pat< - (TF_TruncateDivOp $lhs, $rhs), +def LowerTruncateDivOpOnIntTensors : Pat< + (TF_TruncateDivOp TF_IntTensor:$lhs, $rhs), (TF_DivOp $lhs, $rhs)>; +// Note: truncation could also be implemented as sign(x) * floor(abs(x)) or +// (-1 & x) || floor(abs(x)), based on performance benchmarks. 
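The note above lists alternative formulations; the pattern added below implements truncating (round-toward-zero) division for floats by selecting ceil for negative quotients and floor otherwise. A scalar sketch in plain C++, for illustration only:

    #include <cmath>

    float TruncateDivSketch(float lhs, float rhs) {
      float div = lhs / rhs;                  // TF_DivOp
      return div < 0.0f ? std::ceil(div)      // negative quotient: round up
                        : std::floor(div);    // otherwise: round down
      // Both branches move the quotient toward zero.
    }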
+def LowerTruncateDivOpOnFloatTensors : Pat< + (TF_TruncateDivOp TF_FloatTensor:$lhs, $rhs), + (TF_SelectV2Op + (TF_LessOp + (TF_DivOp:$div $lhs, $rhs), + (TF_ConstOp:$zero (GetScalarOfFloatType<"0.0"> $lhs)) + ), + (TF_CeilOp $div), + (TF_FloorOp $div) + )>; + //===----------------------------------------------------------------------===// // Fill op patterns. //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_variable_ops_to_ml_program.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_variable_ops_to_ml_program.cc index ea1ef145077..ac8048c6a4f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_variable_ops_to_ml_program.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_variable_ops_to_ml_program.cc @@ -22,6 +22,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" +#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" // from @llvm-project #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" // from @llvm-project #include "mlir/Analysis/DataFlow/SparseAnalysis.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -45,8 +46,8 @@ namespace { std::string GetVariableName(Operation* op) { if (auto handle = dyn_cast(op)) { - std::string container = handle.container().str(); - std::string shared_name = handle.shared_name().str(); + std::string container = handle.getContainer().str(); + std::string shared_name = handle.getSharedName().str(); if (container.empty()) { return absl::StrCat("vars.", shared_name); } else { @@ -61,9 +62,9 @@ std::string GetVariableName(Operation* op) { Operation* GetHandleSource(Operation* op, DataFlowSolver& solver) { Value resource; if (auto read = llvm::dyn_cast(op)) { - resource = read.resource(); + resource = read.getResource(); } else if (auto write = llvm::dyn_cast(op)) { - resource = write.resource(); + resource = write.getResource(); } const TF::ResourceDataflowAnalysis::StateT* state = solver.lookupState(resource); @@ -90,7 +91,7 @@ Type GetGlobalType(Operation* source) { // Resources are represented as tensor>>, so // unwrap until we get to the inner tensor<...>. 
auto tensor = - llvm::dyn_cast(var_handle_op.resource().getType()); + llvm::dyn_cast(var_handle_op.getResource().getType()); if (!tensor) return nullptr; TF::ResourceType resource = llvm::dyn_cast(tensor.getElementType()); @@ -148,6 +149,7 @@ struct LowerVariableOpsToMlProgramPass DataFlowSolver solver; solver.load(); + solver.load(); solver.load(); if (failed(solver.initializeAndRun(module))) return signalPassFailure(); @@ -165,8 +167,8 @@ struct LowerVariableOpsToMlProgramPass Operation* load = builder.create( op.getLoc(), globalOp.getType(), SymbolRefAttr::get(op->getContext(), globalOp.getSymName())); - if (globalOp.getType() != op.value().getType()) { - load = builder.create(op.getLoc(), op.value().getType(), + if (globalOp.getType() != op.getValue().getType()) { + load = builder.create(op.getLoc(), op.getValue().getType(), load->getResult(0)); } op.getResult().replaceAllUsesWith(load->getResult(0)); @@ -182,8 +184,8 @@ struct LowerVariableOpsToMlProgramPass symbol_table.insert(globalOp); OpBuilder builder(op); globalOp.setIsMutableAttr(builder.getUnitAttr()); - Value value_to_store = op.value(); - if (globalOp.getType() != op.value().getType()) { + Value value_to_store = op.getValue(); + if (globalOp.getType() != op.getValue().getType()) { value_to_store = builder.create( op.getLoc(), globalOp.getType(), value_to_store); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/mark_input_output_aliases.cc b/tensorflow/compiler/mlir/tensorflow/transforms/mark_input_output_aliases.cc index 842d7a8fe28..58eb0959eb3 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/mark_input_output_aliases.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/mark_input_output_aliases.cc @@ -62,7 +62,7 @@ LogicalResult BuildAliasingInfo( auto assign_op = llvm::dyn_cast_or_null( result.use_begin()->getOwner()); if (!assign_op) continue; - AliasInfo& alias_info = resource_alias_info_map[assign_op.resource()]; + AliasInfo& alias_info = resource_alias_info_map[assign_op.getResource()]; // TODO(b/184420848): We may not need to skip aliasing for entire function // in case of multiple assigns. if (alias_info.output_index != kUnassigned) { @@ -82,7 +82,7 @@ LogicalResult BuildAliasingInfo( operand.get().getDefiningOp()); if (!read_op) continue; if (!read_op->hasOneUse()) continue; - auto it = resource_alias_info_map.find(read_op.resource()); + auto it = resource_alias_info_map.find(read_op.getResource()); if (it == resource_alias_info_map.end()) continue; AliasInfo& alias_info = it->getSecond(); // TODO(b/184420848): We may not need to skip aliasing for entire function diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc index 0c268c2002c..980325dbb96 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include #include @@ -90,17 +91,22 @@ void AddSupportedOpsUsingFolding(MLIRContext* context, supported_ops->insert(allowlist_ops.begin(), allowlist_ops.end()); } -// Adds the list of ops that are supported through dynamic padder using op by op -// fallback to the TF2XLA bridge. -// TODO(b/168036682): Remove this once ops are supported using dynamic padder -// on MLIR bridge. 
-void AddSupportedOpsUsingDynamicPadder( - MLIRContext* context, llvm::DenseSet* supported_ops) { +// Adds the list of ops that are only supported in the old bridge. +// TODO(b/168036682): Remove bounded dynamism ops now that MLIR bridge supports +// bounded dynamism. +// TODO(b/257574556): Remove the need for this manual list by making use of old +// bridge phase 2 op list. +void AddOldBridgeOnlyOps(MLIRContext* context, + llvm::DenseSet* supported_ops) { llvm::SmallDenseSet allowlist_ops = { + OperationName(TF::DynamicPartitionOp::getOperationName(), context), + OperationName(TF::OutfeedEnqueueOp::getOperationName(), context), OperationName(TF::WhereOp::getOperationName(), context), OperationName(TF::UniqueOp::getOperationName(), context), OperationName(TF::XlaSetDynamicDimensionSizeOp::getOperationName(), context), + OperationName(TF::XlaSpmdFullToShardShapeOp::getOperationName(), context), + OperationName(TF::XlaSpmdShardToFullShapeOp::getOperationName(), context), }; supported_ops->insert(allowlist_ops.begin(), allowlist_ops.end()); @@ -373,8 +379,8 @@ bool ContainsUncompilableOps(const Dialect* tf_dialect, Block* block, // Unmarks outside compilation for any op that has parents already // marked for outside compilation since the child will be extracted // anyways. -void UnmarkChildren(Block* block) { - block->walk([&](Operation* op) { +void UnmarkChildren(ModuleOp module) { + module->walk([&](Operation* op) { if (!op->getAttrOfType(kXlaOutsideCompilationAttr)) return; Operation* iter_op = op; bool remove_attr = false; @@ -409,12 +415,12 @@ void MarkOpsForOutsideCompilation::runOnOperation() { llvm::DenseSet supported_ops; PatternApplicator(std::move(patterns)) .walkAllPatterns([&](const Pattern& pattern) { - Optional root_kind = pattern.getRootKind(); - if (root_kind.has_value()) supported_ops.insert(root_kind.getValue()); + std::optional root_kind = pattern.getRootKind(); + if (root_kind.has_value()) supported_ops.insert(root_kind.value()); }); AddSupportedFunctionalOps(module.getContext(), &supported_ops); AddSupportedOpsUsingFolding(module.getContext(), &supported_ops); - AddSupportedOpsUsingDynamicPadder(module.getContext(), &supported_ops); + AddOldBridgeOnlyOps(module.getContext(), &supported_ops); AddRewrittenEmbeddingOps(module.getContext(), &supported_ops); AddRewrittenCompositeOps(module.getContext(), &supported_ops); @@ -439,16 +445,7 @@ void MarkOpsForOutsideCompilation::runOnOperation() { if (result.wasInterrupted()) return signalPassFailure(); - module.walk([&](tf_device::ClusterOp cluster) { - // Only if `allow_soft_placement` attribute is true should we unmark ops - // for outside compilation. 
- auto soft_placement_attr = - cluster->getAttrOfType(kAllowSoftPlacementAttr); - if (!(soft_placement_attr && soft_placement_attr.getValue())) { - return; - } - UnmarkChildren(&cluster.GetBody()); - }); + UnmarkChildren(module); } } // namespace diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/materialize_mlir_passthrough_op.cc b/tensorflow/compiler/mlir/tensorflow/transforms/materialize_mlir_passthrough_op.cc index 8a57f7193bd..8b1c1fe28c7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/materialize_mlir_passthrough_op.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/materialize_mlir_passthrough_op.cc @@ -48,7 +48,7 @@ class MaterializePassthroughOpPass void MaterializePassthroughOpPass::runOnOperation() { getOperation().walk([](TF::MlirPassthroughOp op) { - std::string module_string(op.mlir_module()); + std::string module_string(op.getMlirModule()); // Parse the module. auto nested_module = parseSourceString(module_string, op.getContext()); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/merge_control_flow.cc b/tensorflow/compiler/mlir/tensorflow/transforms/merge_control_flow.cc index 9bccaaa3112..e47bc7f4771 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/merge_control_flow.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/merge_control_flow.cc @@ -81,10 +81,10 @@ struct MergeControlFlowPass llvm::SmallSetVector GetAllOpsFromIf(TF::IfRegionOp if_op) { llvm::SmallSetVector all_ops; all_ops.insert(if_op); - for (Operation& op : if_op.then_branch().front()) { + for (Operation& op : if_op.getThenBranch().front()) { all_ops.insert(&op); } - for (Operation& op : if_op.else_branch().front()) { + for (Operation& op : if_op.getElseBranch().front()) { all_ops.insert(&op); } return all_ops; @@ -121,14 +121,14 @@ bool SafeToMerge(TF::IfRegionOp first_if, TF::IfRegionOp second_if, dependencies.push_back(successor); } } - for (Operation& op : first_if.then_branch().front()) { + for (Operation& op : first_if.getThenBranch().front()) { for (auto* successor : side_effect_analysis.DirectControlSuccessors(&op)) { if (!downstream_if_ops.contains(successor) && !destination_ops.contains(successor)) dependencies.push_back(successor); } } - for (Operation& op : first_if.else_branch().front()) { + for (Operation& op : first_if.getElseBranch().front()) { for (auto* successor : side_effect_analysis.DirectControlSuccessors(&op)) { if (!downstream_if_ops.contains(successor) && !destination_ops.contains(successor)) @@ -173,14 +173,14 @@ bool SafeToMerge(TF::IfRegionOp first_if, TF::IfRegionOp second_if, // Move the body excluding the terminators of else and then regions from // 'second_if' to 'first_if'. 
void MoveBranches(TF::IfRegionOp first_if, TF::IfRegionOp second_if) { - Block& first_if_then_block = first_if.then_branch().front(); - auto& second_if_then_body = second_if.then_branch().front().getOperations(); + Block& first_if_then_block = first_if.getThenBranch().front(); + auto& second_if_then_body = second_if.getThenBranch().front().getOperations(); first_if_then_block.getOperations().splice( first_if_then_block.without_terminator().end(), second_if_then_body, second_if_then_body.begin(), std::prev(second_if_then_body.end())); - Block& first_if_else_block = first_if.else_branch().front(); - auto& second_if_else_body = second_if.else_branch().front().getOperations(); + Block& first_if_else_block = first_if.getElseBranch().front(); + auto& second_if_else_body = second_if.getElseBranch().front().getOperations(); first_if_else_block.getOperations().splice( first_if_else_block.without_terminator().end(), second_if_else_body, second_if_else_body.begin(), std::prev(second_if_else_body.end())); @@ -396,14 +396,14 @@ void ReplaceInternalUsage(llvm::SmallVector& if_op_segment) { for (OpResult result : it->getResults()) { replaceAllUsesInRegionWith( result, - it->then_branch().front().getTerminator()->getOperand( + it->getThenBranch().front().getTerminator()->getOperand( result.getResultNumber()), - it2->then_branch()); + it2->getThenBranch()); replaceAllUsesInRegionWith( result, - it->else_branch().front().getTerminator()->getOperand( + it->getElseBranch().front().getTerminator()->getOperand( result.getResultNumber()), - it2->else_branch()); + it2->getElseBranch()); } } } @@ -487,12 +487,12 @@ void CreateYieldOps( auto if_op = index_and_value.value(); for (auto i : return_indices[index_and_value.index()]) { merged_then_yield_values.push_back( - if_op.then_branch().front().getTerminator()->getOperand(i)); + if_op.getThenBranch().front().getTerminator()->getOperand(i)); } } - builder.setInsertionPointToEnd(&new_if_op.then_branch().front()); + builder.setInsertionPointToEnd(&new_if_op.getThenBranch().front()); builder.create( - first_if.then_branch().front().getTerminator()->getLoc(), + first_if.getThenBranch().front().getTerminator()->getLoc(), /*operands=*/merged_then_yield_values); llvm::SmallVector merged_else_yield_values; @@ -500,12 +500,12 @@ void CreateYieldOps( auto if_op = index_and_value.value(); for (auto i : return_indices[index_and_value.index()]) { merged_else_yield_values.push_back( - if_op.else_branch().front().getTerminator()->getOperand(i)); + if_op.getElseBranch().front().getTerminator()->getOperand(i)); } } - builder.setInsertionPointToEnd(&new_if_op.else_branch().front()); + builder.setInsertionPointToEnd(&new_if_op.getElseBranch().front()); builder.create( - first_if.else_branch().front().getTerminator()->getLoc(), + first_if.getElseBranch().front().getTerminator()->getLoc(), /*operands=*/merged_else_yield_values); } @@ -541,12 +541,12 @@ void MergeIfPerSegment( builder.setInsertionPoint(if_op_segment.back().getOperation()); auto new_if_op = builder.create( - first_if.getLoc(), merged_return_types, first_if.cond(), + first_if.getLoc(), merged_return_types, first_if.getCond(), llvm::all_of(if_op_segment, - [&](TF::IfRegionOp op) { return op.is_stateless(); }), - first_if._then_func_nameAttr(), first_if._else_func_nameAttr()); - new_if_op.then_branch().push_back(new Block); - new_if_op.else_branch().push_back(new Block); + [&](TF::IfRegionOp op) { return op.getIsStateless(); }), + first_if.get_thenFuncNameAttr(), first_if.get_elseFuncNameAttr()); + 
new_if_op.getThenBranch().push_back(new Block); + new_if_op.getElseBranch().push_back(new Block); // Replace internal usages of merged if ops. ReplaceInternalUsage(if_op_segment); @@ -607,9 +607,9 @@ void OptimizeIfRegions(Block* block, ModuleOp module) { grouped_if_ops; llvm::SmallVector if_cond_order; block->walk([&](TF::IfRegionOp if_op) { - auto it = grouped_if_ops.try_emplace(if_op.cond()); + auto it = grouped_if_ops.try_emplace(if_op.getCond()); if (it.second) { - if_cond_order.push_back(if_op.cond()); + if_cond_order.push_back(if_op.getCond()); } it.first->getSecond().push_back(if_op); }); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/mlprogram.cc b/tensorflow/compiler/mlir/tensorflow/transforms/mlprogram.cc index 7588a7a7b93..90dc43afaee 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/mlprogram.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/mlprogram.cc @@ -26,7 +26,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h" #include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/passes.h" namespace tensorflow { @@ -36,7 +36,10 @@ void PopulateLowerToMlProgramAndHloPipeline(mlir::OpPassManager& pm) { // Remove unused global tensors, or make then immutable if possible. pm.addPass(mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass()); - pm.addPass(mlir::TFDevice::CreateDecomposeResourceOpsPass()); + pm.addPass( + mlir::tf_saved_model::CreateConvertSessionInitializerToFunctionPass()); + pm.addNestedPass( + mlir::TFDevice::CreateDecomposeResourceOpsPass()); pm.addPass(mlir::TF::CreateNameAnonymousIteratorsPass()); // This will add regions to IfOp/WhileOp (turning them into IfRegionOp @@ -50,7 +53,8 @@ void PopulateLowerToMlProgramAndHloPipeline(mlir::OpPassManager& pm) { pm.addPass(mlir::tf_saved_model::CreateStripSavedModuleMetadataPass()); pm.addPass(mlir::TF::CreateRemoveUnusedArgumentsPass()); - pm.addPass(mlir::TF::CreateRemoveUnusedWhileResultsPass()); + pm.addNestedPass( + mlir::TF::CreateRemoveUnusedWhileResultsPass()); pm.addPass(mlir::createInlinerPass()); pm.addPass(mlir::createCanonicalizerPass()); @@ -63,8 +67,6 @@ void PopulateLowerToMlProgramAndHloPipeline(mlir::OpPassManager& pm) { /*allow_partial_conversion=*/true, /*legalize_chlo=*/true, tf2xla_fallback_device_type, /*prefer_tf2xla=*/false)); - pm.addPass(mlir::mhlo::createLegalizeTFControlFlowPass()); - pm.addPass(mlir::TF::CreateStripTfAttributesPass()); pm.addPass(mlir::createCanonicalizerPass()); @@ -73,7 +75,6 @@ void PopulateLowerToMlProgramAndHloPipeline(mlir::OpPassManager& pm) { pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::TF::CreateOrderByDialectPass()); - pm.addPass(mlir::TF::CreateGroupByDialectPass()); pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/name_anonymous_iterators.cc b/tensorflow/compiler/mlir/tensorflow/transforms/name_anonymous_iterators.cc index 7031386c39a..ef3627f95a0 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/name_anonymous_iterators.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/name_anonymous_iterators.cc @@ -52,7 +52,7 @@ int replace(OP op, int count) { auto new_op = builder.create( op->getLoc(), op->getResultTypes()[0], name, /*container=*/"", - 
op.output_types(), op.output_shapes()); + op.getOutputTypes(), op.getOutputShapes()); op->getResults()[0].replaceAllUsesWith(new_op->getResults()[0]); if (op->use_empty()) op->erase(); return count; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc index 7ddd52cd314..cd608bdf269 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc @@ -50,20 +50,20 @@ class SimplifyBroadcastReshape : public OpRewritePattern { LogicalResult matchAndRewrite(BroadcastToOp op, PatternRewriter &rewriter) const override { // Only rewrite if the Broadcast has only one consumer. - if (!op.output().hasOneUse()) return failure(); + if (!op.getOutput().hasOneUse()) return failure(); - Operation *user = *op.output().getUsers().begin(); + Operation *user = *op.getOutput().getUsers().begin(); auto reshape_op = llvm::dyn_cast_or_null(user); if (!reshape_op) return failure(); - auto reshape_type = reshape_op.output().getType().cast(); + auto reshape_type = reshape_op.getOutput().getType().cast(); if (!reshape_type.hasStaticShape()) return failure(); ArrayRef reshape_shape = reshape_type.getShape(); - auto input_type = op.input().getType().cast(); - auto output_type = op.output().getType().cast(); + auto input_type = op.getInput().getType().cast(); + auto output_type = op.getOutput().getType().cast(); if (!input_type.hasRank() || !output_type.hasRank()) return failure(); @@ -120,11 +120,11 @@ class SimplifyBroadcastReshape : public OpRewritePattern { auto new_reshape_type = RankedTensorType::get(new_reshape_dims, el_ty); ReshapeOp new_reshape = rewriter.create(new_reshape_shape.getLoc(), new_reshape_type, - op.input(), new_reshape_shape); + op.getInput(), new_reshape_shape); TF::ConstOp new_broadcast_shape = GetI64ConstantTensor(rewriter, reshape_shape, op.getLoc()); rewriter.replaceOpWithNewOp( - reshape_op, reshape_op.output().getType(), new_reshape, + reshape_op, reshape_op.getOutput().getType(), new_reshape, new_broadcast_shape); return success(); } @@ -166,8 +166,7 @@ void CreateTFStandardPipeline(OpPassManager &pm, func_pm.addPass(tf_executor::CreateTFExecutorGraphPruningPass()); func_pm.addPass(tf_executor::CreateTFExecutorIslandCoarseningPass()); func_pm.addPass(CreateMaterializePassthroughOpPass()); - if (options.form_clusters) - func_pm.addPass(TFDevice::CreateClusterFormationPass()); + if (options.form_clusters) pm.addPass(TFDevice::CreateClusterFormationPass()); // Hopefully there is a single island left, or there wasn't any to begin with. // We now run the optimizer which operates mostly inside islands. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/order_by_dialect.cc b/tensorflow/compiler/mlir/tensorflow/transforms/order_by_dialect.cc index d8f5ccb5325..5a3f91c0d23 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/order_by_dialect.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/order_by_dialect.cc @@ -16,13 +16,16 @@ limitations under the License. 
#include #include #include +#include #include #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/topological_sort.h" namespace mlir { namespace TF { @@ -36,19 +39,42 @@ std::vector groupOperationsByDialect(Block& block); // Reorder operations so that consecutive ops stay in the same dialect, as far // as possible. This is to optimize the op order for the group-by-dialect pass, // which factors consecutive same-dialect ops into functions. -// TODO(kramm): This pass needs to become aware of side-effects between ops -// of different dialects. class OrderByDialectPass : public impl::OrderByDialectPassBase { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OrderByDialectPass) + void runOnOperation() override; +}; + +int DialectOrdering(Operation* predecessor, Operation* op) { + return predecessor && predecessor->getName().getDialectNamespace() == + op->getName().getDialectNamespace(); +} - void runOnOperation() override { - getOperation().walk([](Operation* function) { +void OrderByDialectPass::runOnOperation() { + ModuleOp module = getOperation(); + for (func::FuncOp func : module.getOps()) { + std::vector> side_effect_data; + const detail::SideEffectAnalysisInfo* info = nullptr; + auto extra_dependencies = + [&](Operation* op, + bool incoming) -> llvm::SmallVector const& { + return incoming ? info->DirectControlPredecessors(op) + : info->DirectControlSuccessors(op); + }; + // Some tests have recursive calls and other shenanigans, so allow + // them to skip side effect analysis. + if (!func->hasAttr("ignore_side_effects_for_testing")) { + info = + &getAnalysis().GetAnalysisForFunc(func); + } + func->walk([&](Operation* function) { for (Region& region : function->getRegions()) { for (Block& block : region.getBlocks()) { if (block.empty()) continue; - auto ops = groupOperationsByDialect(block); + auto ops = SortBlockTopologically( + block, DialectOrdering, + info ? extra_dependencies : no_extra_dependencies); // Replace the block with the reordered block. for (Operation* op : ops) { op->remove(); @@ -58,97 +84,6 @@ class OrderByDialectPass } }); } -}; - -// Similar to MLIR's topological sort (lib/Transforms/TopologicalSort.cpp) -// but has an explicit scoring function to determine the next op to emit. -// Note that this doesn't explicitly handle TF side effects. However, -// it typically leaves the order of operations within a given dialect the -// same, and different dialects tend to not access the same resources. 
-std::vector groupOperationsByDialect(Block& block) { - llvm::DenseMap remaining_incoming_edges; - llvm::DenseMap position; - llvm::DenseMap ancestor; - SmallVector ready; - - int i = 0; - for (Operation& op : block.getOperations()) { - int incoming_edges = 0; - op.walk([&](Operation* child) { - ancestor[child] = &op; - for (Value v : child->getOperands()) { - if (v.getParentBlock() == &block) { - incoming_edges++; - } - } - }); - remaining_incoming_edges[&op] = incoming_edges; - if (incoming_edges == 0) { - ready.push_back(&op); - } - position[&op] = i++; - } - - std::queue todo; - for (Value value : block.getArguments()) { - todo.push(value); - } - - StringRef current_dialect = ""; - - std::vector result; - while (!todo.empty() || !ready.empty()) { - while (!todo.empty()) { - Value value = todo.front(); - todo.pop(); - // All operations that have all their inputs available are good to go. - // Uses, not Users, in case getUsers ever dedups. - for (OpOperand& operand : value.getUses()) { - Operation* user = ancestor[operand.getOwner()]; - if (--remaining_incoming_edges[user] == 0) { - ready.push_back(user); - } - } - } - - // Find the "best" operation to emit. We - // (a) stay in the same dialect as far as possible. - // (b) preserve order within the ops of one dialect. - // (c) emit the terminator last. - auto better = [&](Operation* a, Operation* b) { - if (a->hasTrait() != - b->hasTrait()) { - return b->hasTrait(); - } - bool a_current = a->getName().getDialectNamespace() == current_dialect; - bool b_current = b->getName().getDialectNamespace() == current_dialect; - if (a_current != b_current) { - return a_current; - } - return position[a] < position[b]; // preserve order - }; - - Operation* best = nullptr; - for (Operation* op : ready) { - if (best == nullptr || better(op, best)) { - best = op; - } - } - - if (!best) { - assert(ready.empty()); - return result; // happens for unused results for ops in the todo list - } - - // Consider this operation emitted, and make its results available. - ready.erase(std::find(ready.begin(), ready.end(), best)); - current_dialect = best->getName().getDialectNamespace(); - for (Value result : best->getResults()) { - todo.push(result); - } - result.push_back(best); - } - return result; } } // namespace diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/parallel_execute_to_islands.cc b/tensorflow/compiler/mlir/tensorflow/transforms/parallel_execute_to_islands.cc index b4e7630d219..dcb806635ee 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/parallel_execute_to_islands.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/parallel_execute_to_islands.cc @@ -65,15 +65,24 @@ limitations under the License. // then this pass will run following `replicate-to-island` pass and // `tf-executor-break-up-islands` pass. 
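// ---------------------------------------------------------------------------
// Editorial sketch (not part of this patch): one way the two island-lowering
// passes extended below could be composed once the `legacy_graph_export`
// flag exists. The helper name `AddIslandLoweringPassesForNewExport` is
// hypothetical; the factory signatures follow the declarations this change
// adds to transforms/passes.h.
// ---------------------------------------------------------------------------
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"

void AddIslandLoweringPassesForNewExport(mlir::OpPassManager& pm) {
  // Passing `false` selects the new export path: replicate/parallel_execute
  // regions are expanded into islands that each wrap exactly one op, and
  // parallel-group annotations are attached for a later control-dependency
  // update instead of pinning unused islands to the graph fetch.
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::TFDevice::CreateReplicateToIslandPass(
          /*legacy_graph_export=*/false));
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::TFDevice::CreateParallelExecuteToIslandsPass(
          /*legacy_graph_export=*/false));
}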
+#include +#include + +#include "absl/strings/str_cat.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/split_into_island_per_op_pass.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" namespace mlir { namespace TFDevice { @@ -85,6 +94,9 @@ namespace { struct ParallelExecuteToIslandsPass : public impl::ParallelExecuteToIslandsPassBase< ParallelExecuteToIslandsPass> { + explicit ParallelExecuteToIslandsPass(bool legacy_graph_export) { + legacy_graph_export_ = legacy_graph_export; + } void runOnOperation() override; }; @@ -94,7 +106,8 @@ struct ParallelExecuteToIslandsPass void ExpandParallelExecuteToIslands( tf_executor::IslandOp island_op, tf_device::ParallelExecuteOp parallel_execute_op, OpBuilder* builder, - llvm::SmallVectorImpl& executes) { + llvm::SmallVectorImpl& executes, + bool legacy_graph_export, int parallel_group_idx) { const int num_regions = parallel_execute_op.getOperation()->getNumRegions(); executes.reserve(num_regions); @@ -117,19 +130,46 @@ void ExpandParallelExecuteToIslands( // Move over tf_device.parallel_execute body region into newly the created // island. execute_island.getBody().takeBody(*execute_block.getParent()); + + // In new graph export pipeline, we will update control dependencies in the + // end of the pipeline. Mostly, it will rely on side effect analysis by + // considering accessing resource only. However, for branches under parallel + // group, there should not be any control deps between them even side effect + // analysis indicate some control deps. Therefore, we will mark parallel + // group and branch information here so that `UpdateControlDependenciesPass` + // can fetch the related information later. + if (!legacy_graph_export) { + std::string group_annotation = absl::StrCat( + "p", std::to_string(parallel_group_idx), ":", std::to_string(i)); + if (auto parallel_group_attr = + parallel_execute_op->getAttrOfType( + TF::kParallelExecAnnotation)) { + // Extend the existing attribute so that nested parallel execution + // structure is supported. + group_annotation = absl::StrCat(parallel_group_attr.getValue().str(), + ",", group_annotation); + } + for (auto& op : execute_island.GetBody()) { + op.setAttr(TF::kParallelExecAnnotation, + builder->getStringAttr(group_annotation)); + } + } + executes.push_back(execute_island); } } void CreateIslandsFromParallelExecute( tf_executor::IslandOp island_op, - tf_device::ParallelExecuteOp parallel_execute_op) { + tf_device::ParallelExecuteOp parallel_execute_op, bool legacy_graph_export, + int parallel_group_idx) { OpBuilder builder(island_op); // Create islands for each region of the parallel_execute op. 
llvm::SmallVector executes; ExpandParallelExecuteToIslands(island_op, parallel_execute_op, &builder, - executes); + executes, legacy_graph_export, + parallel_group_idx); // Remap all results of parallel_execute op with outputs from newly created // islands. @@ -163,22 +203,34 @@ void CreateIslandsFromParallelExecute( island_op.getControl().replaceAllUsesWith(island_sink.getControl()); } - // Islands with no uses should be pinned to a graph fetch so they still - // execute. - llvm::SmallVector unused_execute_controls; - for (auto& execute : executes) - if (execute.use_empty()) - unused_execute_controls.push_back(execute.getControl()); - - if (!unused_execute_controls.empty()) { - auto graph_op = island_op->getParentOfType(); - tf_executor::FetchOp fetch = graph_op.GetFetch(); - auto fetches = llvm::to_vector<8>(fetch.getOperands()); - fetches.append(unused_execute_controls.begin(), - unused_execute_controls.end()); - builder.setInsertionPoint(fetch); - builder.create(fetch.getLoc(), fetches); - fetch.erase(); + if (legacy_graph_export) { + // Islands with no uses should be pinned to a graph fetch so they still + // execute. + llvm::SmallVector unused_execute_controls; + for (auto& execute : executes) + if (execute.use_empty()) + unused_execute_controls.push_back(execute.getControl()); + + if (!unused_execute_controls.empty()) { + auto graph_op = island_op->getParentOfType(); + tf_executor::FetchOp fetch = graph_op.GetFetch(); + auto fetches = llvm::to_vector<8>(fetch.getOperands()); + fetches.append(unused_execute_controls.begin(), + unused_execute_controls.end()); + builder.setInsertionPoint(fetch); + builder.create(fetch.getLoc(), fetches); + fetch.erase(); + } + } else { + // Now, finally, we need to maintain the invariant expected to be maintained + // throughout the graph export pipeline that all islands always perfectly + // wrap a single op. So we'll split all islands which wrap multiple ops. + auto control_type = tf_executor::ControlType::get(island_op.getContext()); + for (auto& execute : executes) { + if (execute.GetBody().getOperations().size() > 1) { + mlir::TF::SplitIsland(execute, control_type); + } + } } island_op.erase(); @@ -190,24 +242,37 @@ void ParallelExecuteToIslandsPass::runOnOperation() { llvm::SmallVector parallel_execute_op_islands; getOperation().walk([&](tf_executor::GraphOp graph_op) { for (auto island_op : graph_op.getOps()) { - if (!island_op.WrapsSingleOp()) continue; + if (!island_op.WrapsSingleOp()) { + island_op.emitError( + "tf_executor.island must perfectly wrap a single op"); + signalPassFailure(); + } if (isa(&island_op.GetBody().front())) parallel_execute_op_islands.push_back(island_op); } }); + // This number is unique within each function which is sufficient for + // `UpdateControlDependenciesPass` which consumes the related attributes. + // However, this assumes that we don't inline functions between this pass + // and `UpdateControlDependenciesPass`. + // If we need globally unique parallel group IDs in the future, + // we can either make this pass a module pass (using a global counter) + // or use an atomic counter. 
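// ---------------------------------------------------------------------------
// Editorial sketch (not part of this patch): shape of the parallel-group
// annotation strings built above. A tf_device.parallel_execute branch gets a
// "p<group>:<branch>" entry, replicate-to-island (later in this patch) uses
// "r<group>:<replica>", and nesting extends an existing value with a comma.
// The helper below is illustrative only; it mirrors the StrCat logic in the
// hunk above.
// ---------------------------------------------------------------------------
#include <string>
#include "absl/strings/str_cat.h"
#include "llvm/ADT/StringRef.h"

std::string ExtendGroupAnnotation(llvm::StringRef existing, char kind,
                                  int group_idx, int branch_idx) {
  std::string entry =
      absl::StrCat(std::string(1, kind), group_idx, ":", branch_idx);
  if (existing.empty()) return entry;
  return absl::StrCat(existing.str(), ",", entry);
}

// For example, since replicate-to-island runs before this pass when a
// parallel_execute is nested under a replicate, branch 1 of parallel group 0
// inside replica 0 of replicate group 0 would end up annotated "r0:0,p0:1".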
+ int parallel_group_idx = 0; for (tf_executor::IslandOp island_op : parallel_execute_op_islands) { auto parallel_execute_op = cast(island_op.GetBody().front()); - CreateIslandsFromParallelExecute(island_op, parallel_execute_op); + CreateIslandsFromParallelExecute(island_op, parallel_execute_op, + legacy_graph_export_, parallel_group_idx); } } } // anonymous namespace -std::unique_ptr> -CreateParallelExecuteToIslandsPass() { - return std::make_unique(); +std::unique_ptr> CreateParallelExecuteToIslandsPass( + bool legacy_graph_export) { + return std::make_unique(legacy_graph_export); } } // namespace TFDevice diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index e69e1a3e8a5..fd3a34eefea 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -23,6 +23,7 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" namespace mlir { @@ -30,13 +31,6 @@ namespace mlir { // islands, each with a single op. std::unique_ptr> CreateBreakUpIslandsPass(); -// Creates a pass that breaks up an island with multiple ops into multiple -// islands, each with a single op. This pass intentionally does not propagate -// control dependencies across newly created islands, a following pass will -// handle this. -// TODO(b/244596254) Implement followup pass for creating control deps. -std::unique_ptr> CreateSplitIntoIslandPerOpPass(); - // Creates a pass that converts mlir functions consisting of mlir ops into a // tf_executor dialect as a single island. std::unique_ptr> @@ -88,6 +82,11 @@ CreateTFRegionControlFlowToFunctional(); std::unique_ptr> CreateMaterializePassthroughOpPass(); +// Replicates the TensorList init op by undoing some CSE needed for correct +// shape assignment in shape_inference. +std::unique_ptr> +CreateReplicateTensorListInitOpsPass(); + // Performs Shape Inference on the TensorFlow dialect using the global registry. std::unique_ptr> CreateTFShapeInferencePass(); @@ -266,6 +265,12 @@ std::unique_ptr> CreateConstantOpDeviceAssignmentPass(); // Populates the supplied passmanager with the passes required to export // to TensorFlow Graph. +void AddGraphExportLoweringPassesV2(OpPassManager& pm); + +// Populates the supplied passmanager with the passes required to export +// to TensorFlow Graph. +// ***This is the legacy graph export pipeline, prefer +// AddGraphExportLoweringPassesV2***. void AddGraphExportLoweringPasses(OpPassManager& pm); // Returns pass that verifies whether all functions in module are of single @@ -307,6 +312,13 @@ std::unique_ptr> CreateStripTfAttributesPass(); // Converts AnonymousIteratorOps to (named) IteratorOps. std::unique_ptr> CreateNameAnonymousIteratorsPass(); +// Creates a pass that breaks up an island with multiple ops into multiple +// islands, each with a single op. This pass intentionally does not propagate +// control dependencies across newly created islands, a following pass will +// handle this. +// TODO(b/244596254) Implement followup pass for creating control deps. +std::unique_ptr> CreateSplitIntoIslandPerOpPass(); + // Populates the supplied passmanager with the passes required to run the // CPU/GPU bridge. 
void CreateTFXLABridgePipeline(OpPassManager& pm); @@ -356,7 +368,7 @@ CreateTFExecutorUpdateControlDependenciesPass(); namespace TFDevice { // Creates a pass that forms clusters from instructions that are assigned to // same device. -std::unique_ptr> CreateClusterFormationPass(); +std::unique_ptr> CreateClusterFormationPass(); // Sinks `tf.Const` operations in the ClusterOp region using them. This is // performed in order to limit the number of values implicitly captured in this @@ -412,7 +424,8 @@ CreateReplicateInvariantOpHoistingPass(); // Creates a pass that forms replica `tf_executor.island` from a single // `tf_device.replicate` island. -std::unique_ptr> CreateReplicateToIslandPass(); +std::unique_ptr> CreateReplicateToIslandPass( + bool legacy_graph_export = true); // Creates a pass that sets the device ordinal attribute of the required op // using the replica id attribute. @@ -421,8 +434,8 @@ CreateReplicaIDToDeviceOrdinalPass(); // Creates a pass that creates `tf_executor.island` from a single // `tf_device.parallel_execute` island. -std::unique_ptr> -CreateParallelExecuteToIslandsPass(); +std::unique_ptr> CreateParallelExecuteToIslandsPass( + bool legacy_graph_export = true); // Creates a pass that annotates whether a LaunchFuncOp's parameters have the // same data across replicas. @@ -445,8 +458,8 @@ CreateDeviceAttributeToLaunchPass(); // Creates a pass that hoists a `tf_device.launch` body and assigns a `device` // attribute to each TensorFlow dialect op in the body based on the `device` // attribute on the `tf_device.launch`. -std::unique_ptr> -CreateLaunchToDeviceAttributePass(); +std::unique_ptr> CreateLaunchToDeviceAttributePass( + bool legacy_graph_export = true); // Creates a pass that extracts ops in tf_device.launch op with host device // assignment and adds an `_xla_outside_compilation` attribute value. @@ -471,6 +484,10 @@ namespace TFTPU { std::unique_ptr> CreateConvertToLegacyCompileAndReplicateAttributesPass(); +// Creates a pass that converts all TPUPartitionedInput to TPUPartitionedInputV2 +std::unique_ptr> +CreateTPUPartitionedOpConversionPass(); + // Creates a pass that forms clusters from operations of the same // `_replication_info` attribute. std::unique_ptr> CreateTPUClusterFormationPass(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc b/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc index aaa2aef1fa6..457dca838af 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc @@ -78,7 +78,7 @@ class RewriteXlaHostComputeMlir // and use it for `shape_inference_graph` attribute on XlaHostCompute. 
func::FuncOp cloned_func; SymbolTable manager(op->getParentOfType()); - StringRef host_module = op.host_mlir_module(); + StringRef host_module = op.getHostMlirModule(); if (!host_module.empty()) { mlir::OwningOpRef module_for_func; @@ -98,7 +98,7 @@ class RewriteXlaHostComputeMlir auto recv_at_host = rewriter.create( func.getLoc(), op.getOperandTypes(), /*dynamic_key=*/dynamic_key, - op.send_keyAttr(), + op.getSendKeyAttr(), /*device_ordinal=*/rewriter.getI64IntegerAttr(0)); for (auto result : llvm::zip(cloned_func.getArguments(), recv_at_host->getResults())) { @@ -109,19 +109,19 @@ class RewriteXlaHostComputeMlir rewriter.create( func.getLoc(), cloned_func.getBody().front().getTerminator()->getOperands(), - /*dynamic_key=*/dynamic_key, op.recv_keyAttr(), + /*dynamic_key=*/dynamic_key, op.getRecvKeyAttr(), /*device_ordinal=*/rewriter.getI64IntegerAttr(0)); } constexpr int64_t kDefaultCostEstimate = 1000000; rewriter.replaceOpWithNewOp( - op, op.getResultTypes(), op.inputs(), + op, op.getResultTypes(), op.getInputs(), /*ancestors=*/rewriter.getArrayAttr({}), rewriter.getArrayAttr(shape_attrs), /*shape_inference_graph=*/ cloned_func ? SymbolRefAttr::get(cloned_func) : SymbolRefAttr(), - /*key=*/rewriter.getStringAttr(""), op.send_keyAttr(), - op.recv_keyAttr(), + /*key=*/rewriter.getStringAttr(""), op.getSendKeyAttr(), + op.getRecvKeyAttr(), /*cost_estimate_ns=*/rewriter.getI64IntegerAttr(kDefaultCostEstimate), /*tpu_core=*/rewriter.getI64IntegerAttr(0)); return success(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc index ffd6c9537a0..a7226b39ebe 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc @@ -136,16 +136,16 @@ mlir::LogicalResult PromoteVarHandlesToArguments( // then we keep them as VarHandleOps. if (!VariableIsInitialized(var_handle_op)) continue; - llvm::StringRef name = var_handle_op.shared_nameAttr().getValue(); + llvm::StringRef name = var_handle_op.getSharedNameAttr().getValue(); auto it = var_arg_index_by_name.insert({name, func_arg_types.size()}); if (it.second) { var_handle_shared_names->emplace_back(name); - auto resource_type = var_handle_op.resource().getType(); + auto resource_type = var_handle_op.getResource().getType(); func_arg_types.push_back(resource_type); - var_handle_op.resource().replaceAllUsesWith( + var_handle_op.getResource().replaceAllUsesWith( block.addArgument(resource_type, var_handle_op.getLoc())); } else { - var_handle_op.resource().replaceAllUsesWith( + var_handle_op.getResource().replaceAllUsesWith( block.getArgument(it.first->getSecond())); } var_handle_op.erase(); @@ -226,28 +226,28 @@ LogicalResult PromoteResourcesToArguments( // live value. 
for (Operation& op : llvm::make_early_inc_range(block)) { if (auto read_op = llvm::dyn_cast(&op)) { - if (auto func_arg = read_op.resource().dyn_cast()) { + if (auto func_arg = read_op.getResource().dyn_cast()) { if (func_arg.getOwner() != &block) return read_op.emitOpError(kResourceFunctionMsg); ResourceInfo& resource_info = resources[func_arg.getArgNumber()]; resource_info.read = true; - read_op.value().replaceAllUsesWith(resource_info.live_value); + read_op.getValue().replaceAllUsesWith(resource_info.live_value); } else { return read_op.emitOpError(kInvalidResourceMsg); } read_op.erase(); } else if (auto write_op = llvm::dyn_cast(&op)) { - if (auto func_arg = write_op.resource().dyn_cast()) { + if (auto func_arg = write_op.getResource().dyn_cast()) { if (func_arg.getOwner() != &block) return write_op.emitOpError(kResourceFunctionMsg); ResourceInfo& resource_info = resources[func_arg.getArgNumber()]; resource_info.write = true; - resource_info.live_value = write_op.value(); + resource_info.live_value = write_op.getValue(); } else { - return read_op.emitOpError(kInvalidResourceMsg); + return write_op.emitOpError(kInvalidResourceMsg); } write_op.erase(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/readonly_references_to_resources.cc b/tensorflow/compiler/mlir/tensorflow/transforms/readonly_references_to_resources.cc index 7b197b39c1a..6ff39e9d69e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/readonly_references_to_resources.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/readonly_references_to_resources.cc @@ -168,7 +168,8 @@ void ConvertReadonlyReferenceVariablesToResourceVariablesPass:: ArrayRef{}, ArrayRef{ builder.getNamedAttr("device", device_attr), - builder.getNamedAttr("container", variable_v2_op.containerAttr()), + builder.getNamedAttr("container", + variable_v2_op.getContainerAttr()), builder.getNamedAttr("shared_name", builder.getStringAttr(variable_name))}); for (Operation *user : diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc b/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc index 6ea75c86d63..a0335a0857f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc @@ -17,6 +17,8 @@ limitations under the License. // the TensorFlow dialect to their functional counterparts, i.e., // tf.IfRegion -> tf.If and tf.WhileRegion -> tf.While +#include + #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" @@ -185,13 +187,13 @@ void ExtractSingleBlockRegion(Region& region, StringRef name, // does not conform to this pattern. llvm::Optional IsSingleCallRegion(Region& region, bool allow_to_bool = false) { - if (!llvm::hasSingleElement(region)) return llvm::None; + if (!llvm::hasSingleElement(region)) return std::nullopt; Block& block = region.front(); auto it = block.rbegin(); YieldOp yield = dyn_cast(*it++); - if (it == block.rend()) return llvm::None; + if (it == block.rend()) return std::nullopt; // Operation which is expected to consume all the call results. 
Operation* call_consumer = yield; @@ -199,28 +201,28 @@ llvm::Optional IsSingleCallRegion(Region& region, // Allow a single ToBoolOp between the call and the yield (valid only // when the yield has a single operand) if (allow_to_bool && yield.getNumOperands() == 1 && isa(*it)) { - if (it->getResult(0) != yield.getOperand(0)) return llvm::None; + if (it->getResult(0) != yield.getOperand(0)) return std::nullopt; call_consumer = cast(*it); it++; - if (it == block.rend()) return llvm::None; + if (it == block.rend()) return std::nullopt; } // Check if there is a Call before the Yield. func::CallOp call = dyn_cast(*it++); - if (!call) return llvm::None; + if (!call) return std::nullopt; // All call results should feed into expected consumer // All results of the call should feed into the yield. if (call.getNumResults() != call_consumer->getNumOperands()) - return llvm::None; + return std::nullopt; for (auto res_it : llvm::zip(call.getResults(), call_consumer->getOperands())) - if (std::get<0>(res_it) != std::get<1>(res_it)) return llvm::None; + if (std::get<0>(res_it) != std::get<1>(res_it)) return std::nullopt; // There can only be non-truncating cast op's prior to the call. for (; it != block.rend(); ++it) { CastOp cast = dyn_cast(*it); - if (!cast || cast.Truncate()) return llvm::None; + if (!cast || cast.getTruncate()) return std::nullopt; } return call; @@ -246,7 +248,7 @@ bool MatchCallArgs(func::CallOp first, func::CallOp second, // Consider cast compatibility in case // %cast = "tf.Cast"(%0) : (tensor<2xi64>) -> tensor<2xf32> // is skipped. - if (cast_op.SrcT() != cast_op.DstT()) { + if (cast_op.getSrcT() != cast_op.getDstT()) { break; } value = cast_op.getOperand(); @@ -286,13 +288,12 @@ struct TrivialTransformInfo { ArgMatcherFn arg_matcher) { if (!first_call || !second_call) return; - if (!MatchCallArgs(first_call.getValue(), second_call.getValue(), - arg_matcher)) + if (!MatchCallArgs(first_call.value(), second_call.value(), arg_matcher)) return; can_transform = true; - callee_names = {first_call.getValue().getCallee(), - second_call.getValue().getCallee()}; + callee_names = {first_call.value().getCallee(), + second_call.value().getCallee()}; } }; @@ -310,8 +311,8 @@ LogicalResult RegionControlFlowToFunctional::ConvertIfOp(IfRegionOp if_region) { return true; }; - const TrivialTransformInfo tti(IsSingleCallRegion(if_region.then_branch()), - IsSingleCallRegion(if_region.else_branch()), + const TrivialTransformInfo tti(IsSingleCallRegion(if_region.getThenBranch()), + IsSingleCallRegion(if_region.getElseBranch()), if_arg_matcher); std::string then_name, else_name; @@ -322,8 +323,8 @@ LogicalResult RegionControlFlowToFunctional::ConvertIfOp(IfRegionOp if_region) { else_name = tti.callee_names[1].str(); } else { // Collect external values that are used within the else and then bodies. - extern_values = - CollectExternValues(if_region.then_branch(), if_region.else_branch()); + extern_values = CollectExternValues(if_region.getThenBranch(), + if_region.getElseBranch()); // These external values need to be added as inputs to the generated If. The // order is determined by the order of these values the `extern_vales`. @@ -332,30 +333,32 @@ LogicalResult RegionControlFlowToFunctional::ConvertIfOp(IfRegionOp if_region) { // and outline the `then` and `else` regions by moving the bodies of these // regions into these functions. Replace tf.yield with a regular return. 
if (if_region->hasAttrOfType(kThenFuncNameAttr) && - !if_region._then_func_nameAttr().getValue().empty()) { + !if_region.get_thenFuncNameAttr().getValue().empty()) { then_name = - mapper.GetUniqueName(if_region._then_func_nameAttr().getValue()) + mapper.GetUniqueName(if_region.get_thenFuncNameAttr().getValue()) .str(); } else { then_name = GetName(if_region, "_then"); } - ExtractSingleBlockRegion(if_region.then_branch(), then_name, extern_values, - worklist, /*extern_values_passthrough=*/false); + ExtractSingleBlockRegion(if_region.getThenBranch(), then_name, + extern_values, worklist, + /*extern_values_passthrough=*/false); if (if_region->hasAttrOfType(kElseFuncNameAttr) && - !if_region._else_func_nameAttr().getValue().empty()) { + !if_region.get_elseFuncNameAttr().getValue().empty()) { else_name = - mapper.GetUniqueName(if_region._else_func_nameAttr().getValue()) + mapper.GetUniqueName(if_region.get_elseFuncNameAttr().getValue()) .str(); } else { else_name = GetName(if_region, "_else"); } - ExtractSingleBlockRegion(if_region.else_branch(), else_name, extern_values, - worklist, /*extern_values_passthrough=*/false); + ExtractSingleBlockRegion(if_region.getElseBranch(), else_name, + extern_values, worklist, + /*extern_values_passthrough=*/false); } // Look through ToBool operations for the condition. - Value cond = if_region.cond(); + Value cond = if_region.getCond(); auto to_bool = dyn_cast_or_null(cond.getDefiningOp()); if (to_bool) cond = to_bool.getOperand(); @@ -365,7 +368,7 @@ LogicalResult RegionControlFlowToFunctional::ConvertIfOp(IfRegionOp if_region) { OpBuilder builder(if_region); auto if_op = builder.create( if_region.getLoc(), if_region.getResultTypes(), cond, extern_values, - then_name, else_name, if_region.is_stateless()); + then_name, else_name, if_region.getIsStateless()); CopyAndOverrideAttributes(if_region, if_op, &builder); if_region.replaceAllUsesWith(if_op.getResults()); @@ -399,8 +402,8 @@ LogicalResult RegionControlFlowToFunctional::ConvertWhileOp( }; const TrivialTransformInfo tti( - IsSingleCallRegion(while_region.cond(), /*allow_to_bool=*/true), - IsSingleCallRegion(while_region.body()), while_arg_matcher); + IsSingleCallRegion(while_region.getCond(), /*allow_to_bool=*/true), + IsSingleCallRegion(while_region.getBody()), while_arg_matcher); // All existing inputs to while region are inputs to the functional while. auto new_inputs = llvm::to_vector<4>(while_region.getOperands()); @@ -422,16 +425,16 @@ LogicalResult RegionControlFlowToFunctional::ConvertWhileOp( // to the region arguments, all these external references need to be added // as function arguments. llvm::SmallVector extern_values = - CollectExternValues(while_region.cond(), while_region.body()); + CollectExternValues(while_region.getCond(), while_region.getBody()); // Outline the `cond` and `body` regions by moving the bodies of these // regions into new functions. Replace tf.yield with a regular return. 
cond_name = GetName(while_region, "_cond"); - ExtractSingleBlockRegion(while_region.cond(), cond_name, extern_values, + ExtractSingleBlockRegion(while_region.getCond(), cond_name, extern_values, worklist, /*extern_values_passthrough=*/false); body_name = GetName(while_region, "_body"); - ExtractSingleBlockRegion(while_region.body(), body_name, extern_values, + ExtractSingleBlockRegion(while_region.getBody(), body_name, extern_values, worklist, /*extern_values_passthrough=*/true); // All extern values become additional inputs and additional output types @@ -445,8 +448,8 @@ LogicalResult RegionControlFlowToFunctional::ConvertWhileOp( OpBuilder builder(while_region); auto while_op = builder.create( while_region.getLoc(), new_result_types, new_inputs, cond_name, body_name, - while_region.parallel_iterations(), while_region.is_stateless(), - while_region.shape_invariant()); + while_region.getParallelIterations(), while_region.getIsStateless(), + while_region.getShapeInvariant()); CopyAndOverrideAttributes(while_region, while_op, &builder); // Redirect old results to new results. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/remove_unused_arguments.cc b/tensorflow/compiler/mlir/tensorflow/transforms/remove_unused_arguments.cc index a46f585df88..d587df91c1b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/remove_unused_arguments.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/remove_unused_arguments.cc @@ -22,7 +22,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" #include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/FunctionInterfaces.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project @@ -110,7 +110,7 @@ void EraseResults(Operation* op, llvm::BitVector erase) { Operation* new_op = builder.create(state); for (const auto& indexed_regions : llvm::enumerate(op->getRegions())) { Region& region = op->getRegion(indexed_regions.index()); - BlockAndValueMapping mapping; + IRMapping mapping; indexed_regions.value().cloneInto(®ion, mapping); } int new_position = 0; @@ -223,7 +223,14 @@ void RemoveUnusedArgumentsPass::runOnOperation() { op.getOperation()->getResult(from).replaceAllUsesWith( op.getOperation()->getOperand(to)); } - op->eraseOperands(args_to_erase.lookup(func)); + BitVector operands_to_erase(op->getNumOperands()); + int args_start = op->getNumOperands() + ? op.getArgOperands().getBase()->getOperandNumber() + : 0; + operands_to_erase |= args_to_erase.lookup(func); + operands_to_erase <<= args_start; + op->eraseOperands(operands_to_erase); + EraseResults(op, results_to_erase.lookup(func)); }); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/remove_unused_while_results.cc b/tensorflow/compiler/mlir/tensorflow/transforms/remove_unused_while_results.cc index 05336d2818d..b4818592ef6 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/remove_unused_while_results.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/remove_unused_while_results.cc @@ -50,8 +50,8 @@ bool TryPruneResultDefiningOp(TF::WhileRegionOp while_op, OpResult result) { // Don't prune if result is used. 
if (!result.use_empty()) return false; - Block& body_block = while_op.body().front(); - Block& cond_block = while_op.cond().front(); + Block& body_block = while_op.getBody().front(); + Block& cond_block = while_op.getCond().front(); Operation* body_yield_op = body_block.getTerminator(); // The body yield operand, body block argument, condition block argument, and @@ -69,7 +69,7 @@ bool TryPruneResultDefiningOp(TF::WhileRegionOp while_op, OpResult result) { if (TF::TensorFlowDialect::CanHaveSideEffects(candidate_op)) { return false; } - } else if (!MemoryEffectOpInterface::hasNoEffect(candidate_op)) { + } else if (!isMemoryEffectFree(candidate_op)) { return false; } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc index 171f03a9e94..1831ecc68db 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc @@ -47,7 +47,7 @@ struct ReplicateInvariantOpHoistingPass void MakeShapeOpInvariant(tf_device::ReplicateOp replicate_op, int num_replicas, Block* replicate_block, TF::ShapeOp shape_op) { - Value input = shape_op.input(); + Value input = shape_op.getInput(); // If ShapeOp operand is replicate tensor block argument, replace with the // associated first replica operand. if (auto block_arg = input.dyn_cast()) { @@ -72,7 +72,7 @@ void MakeShapeOpInvariant(tf_device::ReplicateOp replicate_op, int num_replicas, // shape has not changed in replicate prior to read. Currently after both // ResourceOpLiftingPass and TPURewritePass, there should not be any updates // to resources prior to their respective ReadVariableOp. - if (auto block_arg = read_var_op.resource().dyn_cast()) { + if (auto block_arg = read_var_op.getResource().dyn_cast()) { if (block_arg.getOwner() != replicate_block) return; OpBuilder builder(shape_op); @@ -86,7 +86,7 @@ void MakeShapeOpInvariant(tf_device::ReplicateOp replicate_op, int num_replicas, } // Check if op uses a device from a list of virtual devices. -bool UsesVirtualDevice(const Optional& virtual_devices, +bool UsesVirtualDevice(const std::optional& virtual_devices, Operation* operation) { if (!virtual_devices.has_value()) return false; @@ -94,7 +94,7 @@ bool UsesVirtualDevice(const Optional& virtual_devices, StringAttr op_device = op->getAttrOfType(kDeviceAttr); if (!op_device) return WalkResult::advance(); - if (virtual_devices.getValue().get(op_device.getValue())) + if (virtual_devices.value().get(op_device.getValue())) return WalkResult::interrupt(); return WalkResult::advance(); }); @@ -132,7 +132,7 @@ void HoistReplicateInvariantOps(tf_device::ReplicateOp replicate_op) { }); Region* replicate_region = &replicate_op.getBody(); - Optional virtual_device_list = replicate_op.getDevices(); + std::optional virtual_device_list = replicate_op.getDevices(); for (Operation& inner_op : llvm::make_early_inc_range(replicate_op.GetBody())) { if (llvm::isa(inner_op)) continue; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_tensor_list_init_ops_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_tensor_list_init_ops_pass.cc new file mode 100644 index 00000000000..a8b6ecc0530 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_tensor_list_init_ops_pass.cc @@ -0,0 +1,79 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" + +namespace mlir { +namespace TF { + +namespace { + +#define GEN_PASS_DEF_REPLICATETENSORLISTINITOPSPASS +#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc" + +// Replicates the TensorList initialization ops for all the uses. +// No need to delete the original TensorList as it might be used elsewhere. +template +void ReplicateTensorListForUses(T tensor_list_op) { + Value tensor_list = tensor_list_op.getResult(); + std::vector uses; + for (auto& use : tensor_list.getUses()) { + uses.emplace_back(&use); + } + OpBuilder builder(tensor_list_op.getOperation()); + for (OpOperand* operand : uses) { + auto new_op = builder.clone(*tensor_list_op.getOperation()); + operand->set(new_op->getResult(0)); + } +} + +// This transformation pass replicates TensorList initialization ops. +class ReplicateTensorListInitOps + : public impl::ReplicateTensorListInitOpsPassBase< + ReplicateTensorListInitOps> { + public: + void runOnOperation() override { + getOperation().walk([](Operation* op) { + if (auto tl_reserve = dyn_cast(op)) { + ReplicateTensorListForUses(tl_reserve); + } + if (auto tl_empty = dyn_cast(op)) { + ReplicateTensorListForUses(tl_empty); + } + }); + } +}; +} // namespace + +std::unique_ptr> +CreateReplicateTensorListInitOpsPass() { + return std::make_unique(); +} + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc index d1b248e9b0b..1b2ea737b50 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc @@ -17,10 +17,11 @@ limitations under the License. // `tf_device.replicate` island. #include +#include +#include #include #include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" @@ -29,7 +30,7 @@ limitations under the License. 
#include "llvm/Support/Casting.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project @@ -39,6 +40,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/split_into_island_per_op_pass.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h" @@ -55,6 +59,9 @@ constexpr char kTPUCore0[] = "TPU_REPLICATED_CORE_0"; struct ReplicateToIslandPass : public impl::ReplicateToIslandPassBase { + explicit ReplicateToIslandPass(bool legacy_graph_export) { + legacy_graph_export_ = legacy_graph_export; + } void runOnOperation() override; }; @@ -68,22 +75,22 @@ bool RequiresReplicaIDAttribute(Operation* op) { // Collects TPU device ordinal for outside compilation communication ops. This // currently assumes outside compilation only uses `TPU_REPLICATED_CORE_0` // aliased device for the device computation. -llvm::Optional GetDeviceOrdinal( - const llvm::Optional& devices, Location loc, +std::optional GetDeviceOrdinal( + const std::optional& devices, Location loc, unsigned replica_id) { int64_t device_ordinal = 0; if (devices.has_value()) { - if (auto tpu_replica_0 = devices.getValue().get(kTPUCore0)) { + if (auto tpu_replica_0 = devices.value().get(kTPUCore0)) { llvm::StringRef tpu_device = tpu_replica_0.cast()[replica_id] .cast() .getValue(); if (succeeded(tensorflow::GetDeviceOrdinalFromDeviceString( loc, tpu_device, &device_ordinal))) { - return llvm::Optional(device_ordinal); + return std::optional(device_ordinal); } } } - return llvm::None; + return std::nullopt; } // Updates replica variant ops in a region based on replica `replica_id`. @@ -94,8 +101,8 @@ llvm::Optional GetDeviceOrdinal( // represents replica id. LogicalResult UpdateRegionReplicateVariantOps( OpBuilder& builder, Location loc, Region& region, int replica_id, - const llvm::Optional& devices) { - llvm::Optional device_ordinal = + const std::optional& devices) { + std::optional device_ordinal = GetDeviceOrdinal(devices, loc, replica_id); auto result = region.walk([&](Operation* op) -> WalkResult { @@ -114,7 +121,7 @@ LogicalResult UpdateRegionReplicateVariantOps( auto const_op = builder.create( op->getLoc(), DenseIntElementsAttr::get( RankedTensorType::get({}, builder.getI64Type()), - {device_ordinal.getValue()})); + {device_ordinal.value()})); op->replaceAllUsesWith(const_op); op->erase(); return WalkResult::advance(); @@ -124,7 +131,7 @@ LogicalResult UpdateRegionReplicateVariantOps( // Map aliased devices to explicit devices based on replica. 
if (auto launch = dyn_cast(op)) - if (auto device_by_replica = devices.getValue().get(launch.getDevice())) + if (auto device_by_replica = devices.value().get(launch.getDevice())) launch->setAttr( kDeviceAttr, device_by_replica.cast()[replica_id].cast()); @@ -142,7 +149,8 @@ LogicalResult UpdateRegionReplicateVariantOps( LogicalResult ExpandReplicateIntoReplicas( const Dialect* tf_dialect, OpBuilder& builder, tf_executor::IslandOp island_op, tf_device::ReplicateOp replicate_op, - int num_replicas, llvm::SmallVectorImpl& replicas) { + int num_replicas, llvm::SmallVectorImpl& replicas, + bool legacy_graph_export, int replica_group_idx) { replicas.reserve(num_replicas); auto devices = replicate_op.getDevices(); @@ -159,7 +167,7 @@ LogicalResult ExpandReplicateIntoReplicas( terminator.erase(); builder.setInsertionPoint(island_op); - BlockAndValueMapping mapping; + IRMapping mapping; for (int i : llvm::seq(0, num_replicas)) { // Create new island for replica. auto replica = builder.create( @@ -179,6 +187,28 @@ LogicalResult ExpandReplicateIntoReplicas( /*replica_id=*/i, devices))) return failure(); + // In new graph export pipeline, we will update control dependencies in the + // end of the pipeline. Mostly, it will rely on side effect analysis by + // considering accessing resource only. However, for branches under parallel + // group, there should not be any control deps between them even side effect + // analysis indicate some control deps. Therefore, we will mark parallel + // group and branch information here so that `UpdateControlDependenciesPass` + // can fetch the related information later. + if (!legacy_graph_export) { + std::string group_annotation = absl::StrCat( + "r", std::to_string(replica_group_idx), ":", std::to_string(i)); + if (auto parallel_group_attr = replicate_op->getAttrOfType( + TF::kParallelExecAnnotation)) { + // Extend the existing attribute so that nested parallel execution + // structure is supported. + group_annotation = absl::StrCat(parallel_group_attr.getValue().str(), + ",", group_annotation); + } + for (auto& op : replica.GetBody()) { + op.setAttr(TF::kParallelExecAnnotation, + builder.getStringAttr(group_annotation)); + } + } replicas.push_back(replica); } @@ -238,14 +268,17 @@ LogicalResult ExpandReplicateIntoReplicas( LogicalResult CreateIslandsFromReplicate(const Dialect* tf_dialect, tf_executor::GraphOp graph_op, tf_executor::IslandOp island_op, - tf_device::ReplicateOp replicate_op) { + tf_device::ReplicateOp replicate_op, + bool legacy_graph_export, + int replica_group_idx) { OpBuilder builder(island_op); const int num_replicas = replicate_op.getN(); // Create islands per replica. llvm::SmallVector replicas; - if (failed(ExpandReplicateIntoReplicas(tf_dialect, builder, island_op, - replicate_op, num_replicas, replicas))) + if (failed(ExpandReplicateIntoReplicas( + tf_dialect, builder, island_op, replicate_op, num_replicas, replicas, + legacy_graph_export, replica_group_idx))) return failure(); // Collect all replica results. @@ -280,21 +313,33 @@ LogicalResult CreateIslandsFromReplicate(const Dialect* tf_dialect, island_op.getControl().replaceAllUsesWith(island_sink.getControl()); } - // Replicas with no uses should be pinned to a graph fetch so they still - // execute. 
- llvm::SmallVector unused_replica_controls; - for (auto& replica : replicas) - if (replica.use_empty()) - unused_replica_controls.push_back(replica.getControl()); - - if (!unused_replica_controls.empty()) { - tf_executor::FetchOp fetch = graph_op.GetFetch(); - auto fetches = llvm::to_vector<8>(fetch.getOperands()); - fetches.append(unused_replica_controls.begin(), - unused_replica_controls.end()); - builder.setInsertionPoint(fetch); - builder.create(fetch.getLoc(), fetches); - fetch.erase(); + if (legacy_graph_export) { + // Replicas with no uses should be pinned to a graph fetch so they still + // execute. + llvm::SmallVector unused_replica_controls; + for (auto& replica : replicas) + if (replica.use_empty()) + unused_replica_controls.push_back(replica.getControl()); + + if (!unused_replica_controls.empty()) { + tf_executor::FetchOp fetch = graph_op.GetFetch(); + auto fetches = llvm::to_vector<8>(fetch.getOperands()); + fetches.append(unused_replica_controls.begin(), + unused_replica_controls.end()); + builder.setInsertionPoint(fetch); + builder.create(fetch.getLoc(), fetches); + fetch.erase(); + } + } else { + // Now, finally, we need to maintain the invariant expected to be maintained + // throughout the graph export pipeline that all islands always perfectly + // wrap a single op. So we'll split all replica islands. + auto control_type = tf_executor::ControlType::get(island_op.getContext()); + for (auto& replica : replicas) { + if (replica.GetBody().getOperations().size() > 1) { + mlir::TF::SplitIsland(replica, control_type); + } + } } island_op.erase(); @@ -320,19 +365,31 @@ void ReplicateToIslandPass::runOnOperation() { } }); + // This number is unique within each function which is sufficient for + // `UpdateControlDependenciesPass` which consumes the related attributes. + // However, this assumes that we don't inline functions between this pass + // and `UpdateControlDependenciesPass`. + // If we need globally unique replica group IDs in the future, + // we can either make this pass a module pass (using a global counter) + // or use an atomic counter. + int replica_group_idx = 0; for (tf_executor::IslandOp island_op : replicate_op_islands) { auto graph_op = island_op->getParentOfType(); auto replicate_op = cast(island_op.GetBody().front()); if (failed(CreateIslandsFromReplicate(tf_dialect, graph_op, island_op, - replicate_op))) + replicate_op, legacy_graph_export_, + replica_group_idx))) { + replica_group_idx++; return signalPassFailure(); + } } } } // anonymous namespace -std::unique_ptr> CreateReplicateToIslandPass() { - return std::make_unique(); +std::unique_ptr> CreateReplicateToIslandPass( + bool legacy_graph_export) { + return std::make_unique(legacy_graph_export); } } // namespace TFDevice diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc index bc3944539c6..020abcfd63e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include +#include #include #include @@ -78,7 +79,7 @@ class PerFunctionResult { // Returns the recorded device assignment for a resource, if any. 
Optional DeviceForResource(Value resource) const { Optional result; - if (alias_analysis_.IsUnknownResource(resource)) return llvm::None; + if (alias_analysis_.IsUnknownResource(resource)) return std::nullopt; for (int64_t id : alias_analysis_.GetResourceUniqueIds(resource)) { auto it = resource_id_to_device_.find(id); if (it == resource_id_to_device_.end()) continue; @@ -87,7 +88,7 @@ class PerFunctionResult { continue; } // Got conflicting assignments - return llvm::None; + return std::nullopt; } return result; } @@ -145,7 +146,8 @@ inline StringRef GetDeviceAttr(Operation* op) { // Print operation with debug info (to get line number info for debugging) void dump(StringRef message, Operation* op) { llvm::dbgs() << message; - op->print(llvm::dbgs(), OpPrintingFlags().enableDebugInfo(true)); + op->print(llvm::dbgs(), OpPrintingFlags().enableDebugInfo( + /*enable=*/true, /*prettyForm=*/true)); llvm::dbgs() << "\n"; } @@ -178,7 +180,7 @@ LogicalResult ComputeResourceDevicesInComputation(func::FuncOp func_op, // Record VarHandleOp's device attribute. StringRef device_attr = GetDeviceAttr(op); if (device_attr.empty()) return WalkResult::advance(); - auto res = AddResourceDeviceAndEmitError(var_handle.resource(), + auto res = AddResourceDeviceAndEmitError(var_handle.getResource(), device_attr, op, result); if (failed(res)) return WalkResult::interrupt(); } else if (auto identity = dyn_cast(op)) { @@ -286,7 +288,7 @@ void ResourceDeviceInference::runOnOperation() { return WalkResult::interrupt(); } else if (auto if_op = dyn_cast(op)) { if (failed(propagate_operands_to_callee_arguments( - if_op, if_op.input(), + if_op, if_op.getInput(), {if_op.then_function(), if_op.else_function()}, func_res))) return WalkResult::interrupt(); } else if (auto call = dyn_cast(op)) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc index c2221db9f2d..aa4941ec5b6 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc @@ -30,7 +30,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project @@ -102,7 +102,7 @@ void SetAllVarIsInitializedToTrue(Block* block) { DenseIntElementsAttr::get( RankedTensorType::get(/*shape=*/{}, builder.getI1Type()), true)); - op.is_initialized().replaceAllUsesWith(const_true); + op.getIsInitialized().replaceAllUsesWith(const_true); op.erase(); } } @@ -124,23 +124,23 @@ void ForwardStoreToLoad(Block* block) { // nested deeper in regions. for (Operation& op : llvm::make_early_inc_range(*block)) { if (auto read_variable_op = dyn_cast(&op)) { - Value resource = read_variable_op.resource(); + Value resource = read_variable_op.getResource(); auto last_store = resource_handle_to_last_store_op[resource]; if (!last_store) continue; // Use stored value in last_store to replace all uses of current resource // load's result, then erase this resource load. Add an intermediate // CastOp if the shape of types doesn't exactly match. 
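// Illustrative sketch (standalone, hypothetical types) of the conflict rule that
// DeviceForResource above follows after the llvm::None -> std::nullopt migration:
// a device is reported only if every aliasing resource id agrees; otherwise the
// optional stays empty.
#include <cstdint>
#include <map>
#include <optional>
#include <string>
#include <vector>

std::optional<std::string> DeviceForAliasIds(
    const std::vector<int64_t>& alias_ids,
    const std::map<int64_t, std::string>& id_to_device) {
  std::optional<std::string> result;
  for (int64_t id : alias_ids) {
    auto it = id_to_device.find(id);
    if (it == id_to_device.end()) continue;
    if (!result) {
      result = it->second;  // first recorded assignment
      continue;
    }
    if (*result != it->second) return std::nullopt;  // conflicting assignments
  }
  return result;
}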
- Type read_type = read_variable_op.value().getType(); - if (read_type != last_store.value().getType()) { + Type read_type = read_variable_op.getValue().getType(); + if (read_type != last_store.getValue().getType()) { OpBuilder builder(last_store); builder.setInsertionPointAfter(last_store); auto cast = builder.create( - last_store.getLoc(), read_type, last_store.value(), + last_store.getLoc(), read_type, last_store.getValue(), /*Truncate=*/builder.getBoolAttr(false)); - read_variable_op.value().replaceAllUsesWith(cast); + read_variable_op.getValue().replaceAllUsesWith(cast); } else { - read_variable_op.value().replaceAllUsesWith(last_store.value()); + read_variable_op.getValue().replaceAllUsesWith(last_store.getValue()); } read_variable_op.erase(); @@ -148,7 +148,7 @@ void ForwardStoreToLoad(Block* block) { } if (auto assign_variable_op = dyn_cast(&op)) { - Value resource = assign_variable_op.resource(); + Value resource = assign_variable_op.getResource(); auto last_store = resource_handle_to_last_store_op[resource]; // Previous store ops to same resource can be erased. if (last_store) last_store.erase(); @@ -331,13 +331,13 @@ LogicalResult RegionResourceHoister::Analyze() { if (read && !info.is_read) { info.is_read = true; - info.RefineType(read.value().getType()); + info.RefineType(read.getValue().getType()); info.read_attrs = user->getAttrDictionary(); } if (write) { info.is_written = true; - info.RefineType(write.value().getType()); + info.RefineType(write.getValue().getType()); info.write_attrs = user->getAttrDictionary(); written_regions.set(user->getParentRegion()->getRegionNumber()); } @@ -397,7 +397,7 @@ void RegionResourceHoister::ReplaceResourceLoads(Region& region, // ops nested deeper in regions. auto all_reads = region.front().getOps(); for (auto read_op : llvm::make_early_inc_range(all_reads)) { - Value resource = read_op.resource(); + Value resource = read_op.getResource(); if (!Contains(resource)) continue; ResourceInfo& info = resources_[resource]; @@ -435,7 +435,7 @@ void RegionResourceHoister::AppendResourceStoreValueToReturn( // regions should have been lifted out. auto assign_ops = front.getOps(); for (auto assign_variable_op : llvm::make_early_inc_range(assign_ops)) { - Value resource = assign_variable_op.resource(); + Value resource = assign_variable_op.getResource(); if (!IsWritten(resource)) continue; // TODO(ycao): Prevent same value from being returned multiple times. @@ -443,7 +443,7 @@ void RegionResourceHoister::AppendResourceStoreValueToReturn( // of cluster. Both of these can be post-resource-op-lifting cleanup // passes. int result_index = resources_[resource].result_index; - new_return_operands[result_index] = assign_variable_op.value(); + new_return_operands[result_index] = assign_variable_op.getValue(); assign_variable_op.erase(); } old_return->setOperands(new_return_operands); @@ -633,7 +633,7 @@ LogicalResult FindResourceArgUseInfo( if (auto assign = llvm::dyn_cast(user)) { read_or_assigned = true; info.updated = true; - info.data_type = assign.value().getType(); + info.data_type = assign.getValue().getType(); continue; } @@ -768,11 +768,11 @@ LogicalResult LiftArgRetResourcesForFunction( // For writes, invoke the callback and then erase the write. 
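// Illustrative sketch (standalone, hypothetical names) of the store-to-load
// forwarding idea in ForwardStoreToLoad above: track the last write per resource,
// satisfy reads from it, and let a newer write shadow the older one.
#include <map>
#include <optional>
#include <string>

struct ToyStoreForwarder {
  std::map<std::string, int> last_store;  // resource name -> last stored value

  void Write(const std::string& resource, int value) {
    last_store[resource] = value;  // the previous store to `resource` is dead
  }
  std::optional<int> Read(const std::string& resource) const {
    auto it = last_store.find(resource);
    if (it == last_store.end()) return std::nullopt;  // no forwarding possible
    return it->second;
  }
};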
auto assign_ops = func_op.front().getOps(); for (auto assign_variable_op : llvm::make_early_inc_range(assign_ops)) { - Value resource = assign_variable_op.resource(); + Value resource = assign_variable_op.getResource(); if (!hoister.Contains(resource)) continue; auto arg = resource.dyn_cast(); - handle_updated_arg_value(arg.getArgNumber(), assign_variable_op.value()); + handle_updated_arg_value(arg.getArgNumber(), assign_variable_op.getValue()); assign_variable_op.erase(); } @@ -965,7 +965,7 @@ LogicalResult HandleCaseOrIfOp(CaseOrIfOp op, ArrayRef branches) { // Now use the filtered original operands, which will be replaced by // AddLoadsStoresOutsideControlFlowOp(). auto new_operands = - FilterRange(op.input(), resource_arg_uses); + FilterRange(op.getInput(), resource_arg_uses); new_operands.insert(new_operands.begin(), op.getOperand(0)); func::FuncOp first_func = branches.front(); auto new_op = builder.create( @@ -1110,8 +1110,8 @@ void UpdatePartitionedCallOpWithNewCallee( OpBuilder builder(call_op); // Now use the filtered original operands, which will be replaced by // AddLoadsStoresOutsideControlFlowOp(). - auto new_operands = - FilterRange(call_op.args(), lifting_info.use_info); + auto new_operands = FilterRange(call_op.getArgs(), + lifting_info.use_info); auto new_call = builder.create( call_op.getLoc(), lifting_info.lifted_callee.getFunctionType().getResults(), new_operands, diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc index f2a62a96db1..14d0b1047fb 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc @@ -36,7 +36,7 @@ bool IsResource(Value value) { bool IsCastOfResource(Operation &op) { auto cast = dyn_cast(op); if (!cast) return false; - return IsResource(cast.x()); + return IsResource(cast.getX()); } // Removes passthrough ops in the block. The device computation does not need @@ -60,7 +60,7 @@ void RemoveDeadLocalVariables(Block &block) { } } for (auto local_var : local_vars) { - auto users = local_var.resource().getUsers(); + auto users = local_var.getResource().getUsers(); if (llvm::all_of(users, [](const Operation *user) { return isa(user); })) { @@ -120,7 +120,7 @@ void EliminateUnusedResults( func::FuncOp CloneFunctionIfNeeded(func::FuncOp func) { ModuleOp module = func->getParentOfType(); auto func_uses = SymbolTable::getSymbolUses(func, &module.getBodyRegion()); - if (func_uses.has_value() && llvm::hasSingleElement(func_uses.getValue())) + if (func_uses.has_value() && llvm::hasSingleElement(func_uses.value())) return func; func::FuncOp cloned = func.clone(); cloned.setPrivate(); @@ -220,8 +220,10 @@ void EliminateUnusedResultsForWhile(TF::WhileOp op) { func::FuncOp cloned_cond = CloneFunctionIfNeeded(cond); func::FuncOp cloned_body = CloneFunctionIfNeeded(body); - op.condAttr(FlatSymbolRefAttr::get(op.getContext(), cloned_cond.getName())); - op.bodyAttr(FlatSymbolRefAttr::get(op.getContext(), cloned_body.getName())); + op.setCondAttr( + FlatSymbolRefAttr::get(op.getContext(), cloned_cond.getName())); + op.setBodyAttr( + FlatSymbolRefAttr::get(op.getContext(), cloned_body.getName())); // Drop cond/body args and return value. WhileOp result will be dropped later // in EliminateUnusedResults. 
Traverse in reverse order so that indices to be @@ -270,21 +272,21 @@ LogicalResult ForwardCommonArgToOutput(Operation *op, } if (!common_arg_index.has_value()) { common_arg_index = block_arg.getArgNumber(); - } else if (common_arg_index.getValue() != block_arg.getArgNumber()) { + } else if (common_arg_index.value() != block_arg.getArgNumber()) { return op->emitError("result #") << result_idx << " is not tied to the same argument across all branches"; } } - if (io_match && result_idx != common_arg_index.getValue()) { + if (io_match && result_idx != common_arg_index.value()) { return op->emitOpError("Result #") << result_idx << " is tied to argument #" - << common_arg_index.getValue(); + << common_arg_index.value(); } // Forward the corresponding input to the output - result.replaceAllUsesWith(branch_args[common_arg_index.getValue()]); + result.replaceAllUsesWith(branch_args[common_arg_index.value()]); } return success(); } @@ -373,8 +375,8 @@ LogicalResult CanonicalizeRegionIfCaseCluster(Operation *op) { // the body, the result is replaced with the operand and all argument/results // and retuns values corresponding to that result are dropped. LogicalResult CanonicalizeWhileRegion(TF::WhileRegionOp op) { - Region &body = op.body(); - Region &cond = op.cond(); + Region &body = op.getBody(); + Region &cond = op.getCond(); llvm::BitVector can_eliminate(op.getNumResults()); // Traverse in reverse order so that indices to be deleted stay unchanged. @@ -423,11 +425,12 @@ LogicalResult CleanupAndCanonicalize(Operation *parent_op) { if (auto if_op = dyn_cast(op)) { result = CanonicalizeFunctionalIfCase( - op, {if_op.then_function(), if_op.else_function()}, if_op.input()); + op, {if_op.then_function(), if_op.else_function()}, if_op.getInput()); } else if (auto case_op = dyn_cast(op)) { SmallVector branches; case_op.get_branch_functions(branches); - result = CanonicalizeFunctionalIfCase(case_op, branches, case_op.input()); + result = + CanonicalizeFunctionalIfCase(case_op, branches, case_op.getInput()); } else if (auto while_op = dyn_cast(op)) { if (while_op.cond_function().walk(check_while_cond).wasInterrupted()) return WalkResult::interrupt(); @@ -436,7 +439,7 @@ LogicalResult CleanupAndCanonicalize(Operation *parent_op) { op)) { result = CanonicalizeRegionIfCaseCluster(op); } else if (auto while_region = dyn_cast(op)) { - if (while_region.cond().walk(check_while_cond).wasInterrupted()) + if (while_region.getCond().walk(check_while_cond).wasInterrupted()) return WalkResult::interrupt(); // For while region, the body input and output arg should match. result = CanonicalizeWhileRegion(while_region); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_tpu_embedding_ops.cc b/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_tpu_embedding_ops.cc index 2084ff3d99b..2ff6c78896f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_tpu_embedding_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_tpu_embedding_ops.cc @@ -72,7 +72,7 @@ LogicalResult RunOnRegion(Region* region) { if (!recv_op && !send_op) return success(); Location loc = recv_op ? recv_op.getLoc() : send_op.getLoc(); - StringRef config = recv_op ? recv_op.config() : send_op.config(); + StringRef config = recv_op ? recv_op.getConfig() : send_op.getConfig(); // Create XlaRecvTPUEmbeddingDeduplicationData op. 
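// Illustrative sketch (standalone, hypothetical representation) of the
// cross-branch check in ForwardCommonArgToOutput above: a result may be replaced
// by an input only if every branch returns the same argument index for it.
#include <optional>
#include <vector>

// branches[b][r] is the argument index branch `b` returns for result `r`,
// or -1 when that branch returns something other than a plain argument.
std::optional<int> CommonForwardedArg(
    const std::vector<std::vector<int>>& branches, int result_idx) {
  std::optional<int> common;
  for (const auto& branch : branches) {
    int arg = branch[result_idx];
    if (arg < 0) return std::nullopt;              // not a passthrough result
    if (!common) common = arg;
    else if (*common != arg) return std::nullopt;  // branches disagree
  }
  return common;
}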
OpBuilder builder(region); @@ -89,8 +89,8 @@ LogicalResult RunOnRegion(Region* region) { // Rewrite SendTPUEmbeddingGradients op to the corresponding internal op and // then update the OperandSegmentSize attribute. if (send_op) { - int32_t operand_sizes[] = {static_cast(send_op.N()), - static_cast(send_op.NN()), 1}; + int32_t operand_sizes[] = {static_cast(send_op.getN()), + static_cast(send_op.getNN()), 1}; auto operand_size_attr = builder.getDenseI32ArrayAttr(operand_sizes); auto new_send_op = AddOperandAndRewriteAs( diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.cc b/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.cc index 4be54c05b90..0d2406e30e8 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h" +#include #include #include "tensorflow/core/util/device_name_utils.h" @@ -30,14 +31,14 @@ const char kDeviceGpu[] = "GPU"; llvm::Optional GetOpDevice(mlir::Operation *op) { mlir::StringAttr device = op->getAttrOfType(kDeviceAttr); if (!device || device.getValue().empty()) { - return llvm::None; + return std::nullopt; } tensorflow::DeviceNameUtils::ParsedName parsed_name; if (!tensorflow::DeviceNameUtils::ParseFullName(device.str(), &parsed_name)) { - return llvm::None; + return std::nullopt; } if (!parsed_name.has_type) { - return llvm::None; + return std::nullopt; } return parsed_name.type; } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/set_tpu_infeed_layout.cc b/tensorflow/compiler/mlir/tensorflow/transforms/set_tpu_infeed_layout.cc index 0eac58abbcd..fed9894db2e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/set_tpu_infeed_layout.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/set_tpu_infeed_layout.cc @@ -45,7 +45,7 @@ bool SetTPUInfeedLayout(mlir::OwningOpRef &mlir_module) { // Do not append a UnitAttr for the "token" operand here to avoid // compilation failure when exporting the "layouts" attribute to a graph // node. Instead, add the UnitAttr during LegalizeTF pass. - op->setAttr("layouts", layout.getValue()); + op->setAttr("layouts", layout.value()); return mlir::WalkResult::advance(); }); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index cd6523ef81a..a77c5bb1119 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -19,12 +19,12 @@ limitations under the License. #include #include #include +#include #include #include #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/Hashing.h" -#include "llvm/ADT/None.h" #include "llvm/ADT/PointerUnion.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" @@ -63,7 +63,6 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" @@ -233,7 +232,7 @@ bool NeedsCastBack(OpOperand& use, Dialect* tf_dialect) { TensorType CreateTensorType(llvm::Optional> shape, Type element_type) { if (shape.has_value()) - return tensorflow::GetTypeFromTFTensorShape(shape.getValue(), element_type); + return tensorflow::GetTypeFromTFTensorShape(shape.value(), element_type); return UnrankedTensorType::get(element_type); } @@ -246,11 +245,11 @@ bool IsTensorListInitOp(Operation* op) { // Returns the `element_shape` operand of the ops that create a TensorList. Value GetElementShapeOperand(Operation* op) { if (auto empty_tl = dyn_cast(op)) - return empty_tl.element_shape(); + return empty_tl.getElementShape(); if (auto tl_reserve = dyn_cast(op)) - return tl_reserve.element_shape(); + return tl_reserve.getElementShape(); if (auto tl_from_tensor = dyn_cast(op)) - return tl_from_tensor.element_shape(); + return tl_from_tensor.getElementShape(); llvm_unreachable("unsupported TensorList op"); } @@ -318,10 +317,10 @@ bool CanInferTensorListElementType(Value tensorlist, for (auto& use : tensorlist.getUses()) { if (auto push = llvm::dyn_cast(use.getOwner())) { auto element_type = - push.tensor().getType().dyn_cast(); + push.getTensor().getType().dyn_cast(); if (!verify_and_update_potential_element_type(element_type)) return false; - worklist.emplace(push.output_handle()); + worklist.emplace(push.getOutputHandle()); continue; } if (auto scatter = llvm::dyn_cast( @@ -329,27 +328,27 @@ bool CanInferTensorListElementType(Value tensorlist, // For scatter op we can get the element shape by dropping the first // dimension of the input tensor. RankedTensorType element_type = - DropFirstDimension(scatter.tensor().getType()); + DropFirstDimension(scatter.getTensor().getType()); if (!verify_and_update_potential_element_type(element_type)) return false; - worklist.emplace(scatter.output_handle()); + worklist.emplace(scatter.getOutputHandle()); continue; } if (auto set_item = llvm::dyn_cast(use.getOwner())) { auto element_type = - set_item.item().getType().dyn_cast(); + set_item.getItem().getType().dyn_cast(); DCOMMENT("\tTensorListSetItemOp " << element_type); if (!verify_and_update_potential_element_type(element_type)) return false; - worklist.emplace(set_item.output_handle()); + worklist.emplace(set_item.getOutputHandle()); continue; } if (auto pop = llvm::dyn_cast(use.getOwner())) { - worklist.emplace(pop.output_handle()); + worklist.emplace(pop.getOutputHandle()); continue; } if (auto resize = llvm::dyn_cast(use.getOwner())) { - worklist.emplace(resize.output_handle()); + worklist.emplace(resize.getOutputHandle()); continue; } // WhileRegionOp can explicitly capture TensorList value to be used inside @@ -449,6 +448,8 @@ struct ValuePort { return producer == other.producer && port == other.port; } + ValuePort() = default; + // Convert output value to ValuePort. 
explicit ValuePort(Value v) { OpResult opr = v.dyn_cast(); @@ -472,6 +473,8 @@ struct ValuePort { os << formatv(" [{0}]", llvm::make_range(port.begin(), port.end())); return os; } + + bool IsValid() const { return !producer.isNull(); } }; struct ValuePortHasher { @@ -487,6 +490,86 @@ using ComputedQueryFn = function_ref; using ValueQueryFn = function_ref; using ValuePortInputs = SmallVectorImpl; +// Note: Following implements the rank 1 pack op case so could be +// generalized. +// +// Maps the specified component in the `port` of the given op's result to one of +// the element in the input. +ValuePort ComputeInputComponentFor(PackOp op, ArrayRef port) { + auto type = op.getType().cast(); + if (!type.hasRank() || type.getRank() != 1) return {}; + if (port.size() != 2) return {}; + assert(port[0] == 0); + return ValuePort(op.getOperand(port[1])); +} + +ValuePort ComputeInputComponentFor(ConcatV2Op op, ArrayRef port) { + if (port.size() != 2) return {}; + assert(port[0] == 0); + + int64_t element_idx = port[1]; + for (Value val : op.getValues()) { + auto val_ty = val.getType().cast(); + if (!val_ty.hasStaticShape() || val_ty.getRank() != 1) return {}; + + int64_t dim_size = val_ty.getNumElements(); + if (element_idx >= dim_size) { + element_idx -= dim_size; + continue; + } + + ValuePort req(val); + req.port.push_back(element_idx); + return req; + } + return {}; +} + +ValuePort ComputeInputComponentFor(GatherV2Op op, ArrayRef port) { + if (port.size() != 2) return {}; + assert(port[0] == 0); + + auto params = op.getParams(); + auto params_ty = params.getType().dyn_cast(); + if (!params_ty || !params_ty.hasStaticShape() || params_ty.getRank() != 1 || + op.getBatchDims() != 0) { + return {}; + } + + DenseIntElementsAttr axis; + if (!matchPattern(op.getAxis(), m_Constant(&axis)) || + axis.getNumElements() != 1 || + !axis.getSplatValue().isZero()) { + return {}; + } + + DenseIntElementsAttr indices; + if (!matchPattern(op.getIndices(), m_Constant(&indices)) || + indices.getType().getRank() != 1 || port[1] >= indices.getNumElements()) { + return {}; + } + + int64_t input_idx = indices.getValues()[port[1]].getInt(); + if (input_idx >= params_ty.getDimSize(0)) return {}; + + ValuePort req(params); + req.port.push_back(input_idx); + return req; +} + +ValuePort ComputeInputComponentFor(Operation* op, ArrayRef port) { + if (auto pack_op = llvm::dyn_cast(op)) { + return ComputeInputComponentFor(pack_op, port); + } + if (auto concat_op = llvm::dyn_cast(op)) { + return ComputeInputComponentFor(concat_op, port); + } + if (auto gather_op = llvm::dyn_cast(op)) { + return ComputeInputComponentFor(gather_op, port); + } + return {}; +} + // TODO(jpienaar): ComputeInputsRequiredForOutput and ComputeOutputComponent are // intended to be switched to op interfaces once more refined. LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port, @@ -496,17 +579,11 @@ LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port, auto& port = value_port.port; if (!op) return failure(); - // No inputs required for constants. - if (matchPattern(op, m_Constant())) return success(); - - // Note: this focusses only on the trivial pack op case and this could be - // generalized. - if (auto pack_op = dyn_cast(op)) { - auto type = pack_op.getType().cast(); - if (!type.hasRank() || type.getRank() != 1) return failure(); - if (port.size() != 2) return failure(); - assert(port[0] == 0); - ValuePort req(pack_op.getOperand(port[1])); + // No inputs required for constants and ShapeOp. 
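// Illustrative sketch (standalone, names invented) of the index arithmetic used
// by the ConcatV2 case of ComputeInputComponentFor above: element `i` of a
// rank-1 concat result lives in the first operand whose running size exceeds
// `i`, at offset `i` minus the sizes of the preceding operands.
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

std::optional<std::pair<int, int64_t>> MapConcatElement(
    const std::vector<int64_t>& operand_sizes, int64_t element_idx) {
  for (int i = 0; i < static_cast<int>(operand_sizes.size()); ++i) {
    if (element_idx < operand_sizes[i]) return std::make_pair(i, element_idx);
    element_idx -= operand_sizes[i];
  }
  return std::nullopt;  // index out of range of the concatenated result
}

// MapConcatElement({2, 3}, 4) == {1, 2}: the fifth element is operand 1's third.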
+ if (matchPattern(op, m_Constant()) || isa(op)) return success(); + + ValuePort req = ComputeInputComponentFor(op, port); + if (req.IsValid()) { if (!has_been_computed(req)) inputs->push_back(req); return success(); } @@ -533,23 +610,49 @@ Attribute ComputeOutputComponent(const ValuePort& value_port, ElementsAttr attr; if (matchPattern(op, m_Constant(&attr))) { if (port.size() == 1 && port[0] == 0) return attr; + if (port.size() == 2) { + assert(port[0] == 0); + DenseIntElementsAttr value; + if (!matchPattern(op, m_Constant(&value)) || + value.getType().getRank() != 1 || port[1] >= value.getNumElements()) { + return nullptr; + } + + auto range = value.getValues(); + auto component_ty = RankedTensorType::get({1}, value.getElementType()); + return DenseElementsAttr::get(component_ty, range[port[1]]); + } return nullptr; } if (auto id = dyn_cast(op)) { if (port.size() == 1 && port[0] == 0) - return ComputeOutputComponent(ValuePort(id.input()), values); + return ComputeOutputComponent(ValuePort(id.getInput()), values); return nullptr; } - // Note: this focusses only on the trivial pack op case and this could be - // generalized. - if (auto pack_op = dyn_cast(op)) { - TensorType type = pack_op.getType().cast(); - if (!type.hasRank() || type.getRank() != 1) return nullptr; - if (port.size() != 2 || port[0] != 0) return nullptr; - ValuePort op_port(op->getOperand(port[1])); - return values(op_port); + if (auto shape_op = dyn_cast(op)) { + // No shape available in an unranked tensor type. + auto operand_ty = + shape_op.getOperand().getType().dyn_cast(); + if (!operand_ty) return nullptr; + + // Shape op has a single output so the first element should always be zero + // and the second element of port points to a particular element in the + // shape result. + if (port.size() != 2 || port[0] != 0 || port[1] >= operand_ty.getRank()) + return nullptr; + + // If the dim is dynamic, the dimension can't be inferred during + // compilation. + int64_t dim = operand_ty.getDimSize(port[1]); + if (dim == ShapedType::kDynamic) return nullptr; + + // Create an elements attribute for the particular dimension. + Type element_ty = getElementTypeOrSelf(shape_op.getType()); + APInt dim_value(element_ty.getIntOrFloatBitWidth(), dim); + auto component_ty = RankedTensorType::get({1}, element_ty); + return DenseElementsAttr::get(component_ty, {dim_value}); } if (auto graph = dyn_cast(op)) { @@ -566,6 +669,9 @@ Attribute ComputeOutputComponent(const ValuePort& value_port, return nullptr; } + ValuePort req = ComputeInputComponentFor(op, port); + if (req.IsValid()) return values(req); + return nullptr; } @@ -759,6 +865,14 @@ class ShapeInference { // yields. bool InferShapeForIfRegion(IfRegionOp op); + // Infers the shape CaseOp outputs based on the shapes of branch function + // result types. + bool InferShapeForCase(CaseOp op); + + // Infers the shape CaseRegion outputs based on the shapes of the branch + // yields. + bool InferShapeForCaseRegion(CaseRegionOp op); + // Infers the shape of _XlaHostComputeMlir based on the host computation // module. Returns true if a return type was changed. 
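// Illustrative sketch (standalone) of the rule the new ShapeOp case of
// ComputeOutputComponent encodes: component `d` of shape(x) is a compile-time
// constant only when dimension `d` of x's type is static. kDynamicDim here is a
// stand-in sentinel for ShapedType::kDynamic.
#include <cstdint>
#include <optional>
#include <vector>

constexpr int64_t kDynamicDim = -1;

std::optional<int64_t> StaticShapeComponent(const std::vector<int64_t>& shape,
                                            int64_t d) {
  if (d < 0 || d >= static_cast<int64_t>(shape.size())) return std::nullopt;
  if (shape[d] == kDynamicDim) return std::nullopt;  // only known at runtime
  return shape[d];
}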
bool InferShapeForXlaHostComputeMlir(_XlaHostComputeMlirOp op); @@ -967,8 +1081,8 @@ bool ShapeInference::InferShapeForIf(IfOp op) { bool ShapeInference::InferShapeForIfRegion(IfRegionOp op) { bool changed = false; - Operation* then_yield = op.then_branch().front().getTerminator(); - Operation* else_yield = op.else_branch().front().getTerminator(); + Operation* then_yield = op.getThenBranch().front().getTerminator(); + Operation* else_yield = op.getElseBranch().front().getTerminator(); for (auto result : zip(op.getResults(), then_yield->getOperandTypes(), else_yield->getOperandTypes())) { // If then and else types do not match, skip refinement for that result. @@ -979,6 +1093,44 @@ bool ShapeInference::InferShapeForIfRegion(IfRegionOp op) { return changed; } +bool ShapeInference::InferShapeForCase(CaseOp op) { + DCOMMENT_OP(op.getOperation(), "Infer shape for case "); + + llvm::SmallVector branch_result_types; + for (int i = 0; i < op.num_branches(); ++i) { + branch_result_types.push_back(op.ResolveBranchFunction(&symbol_table_, i) + .getFunctionType() + .getResults()); + } + + bool changed = false; + for (const auto& result : op.getResults()) { + llvm::DenseSet types; + for (const auto& branch_result_type : branch_result_types) { + types.insert(branch_result_type[result.getResultNumber()]); + } + if (types.size() == 1) { + changed = RefineResultType(op, result, *types.begin()) || changed; + } + } + return changed; +} + +bool ShapeInference::InferShapeForCaseRegion(CaseRegionOp op) { + bool changed = false; + for (const auto& result : op.getResults()) { + llvm::DenseSet types; + for (auto& branch : op.getBranches()) { + Operation* yield = branch.front().getTerminator(); + types.insert(yield->getOperandTypes()[result.getResultNumber()]); + } + if (types.size() == 1) { + changed = RefineResultType(op, result, *types.begin()) || changed; + } + } + return changed; +} + bool ShapeInference::InferShapeForXlaHostComputeMlir( _XlaHostComputeMlirOp host_compute_op) { // Extract the module and function. @@ -1043,7 +1195,7 @@ bool ShapeInference::InferShapeForRestore(Operation* op) { if (!assign_op) { continue; } - auto subtypes = getElementTypeOrSelf(assign_op.resource()) + auto subtypes = getElementTypeOrSelf(assign_op.getResource()) .cast() .getSubtypes(); if (subtypes.empty()) { @@ -1145,7 +1297,7 @@ bool ShapeInference::InferShapeForMapDataset(MapDatasetOp op, // op. The MapDataset op always has N+1 inputs. // TODO(jpienaar): Avoid this lookup. auto module = op->getParentOfType(); - auto f = module.lookupSymbol(op.f()); + auto f = module.lookupSymbol(op.getF()); // Skip if function is not found or more than one caller. if (!f || !llvm::hasSingleElement(GetCallers(f))) return false; return InferShapeForDatasetOpCommon(op, f, max_iterations); @@ -1159,7 +1311,7 @@ bool ShapeInference::InferShapeForTakeWhileDataset(TakeWhileDatasetOp op, // TakeWhileDataset op. The TakeWhileDataset op always has N+1 inputs. // TODO(jpienaar): Avoid this lookup. auto module = op->getParentOfType(); - auto f = module.lookupSymbol(op.predicate()); + auto f = module.lookupSymbol(op.getPredicate()); // Skip if function is not found or more than one caller. if (!f || !llvm::hasSingleElement(GetCallers(f))) return false; return InferShapeForDatasetOpCommon(op, f, max_iterations); @@ -1177,14 +1329,14 @@ bool ShapeInference::InferShapeForReduceDataset(ReduceDatasetOp op, // TODO(jpienaar): Avoid this lookup. 
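// Illustrative sketch (standalone, types modelled as strings) of the agreement
// rule in InferShapeForCase/InferShapeForCaseRegion above: a result type is
// refined only when every branch produces the same type for that result.
#include <set>
#include <string>
#include <vector>

// Returns the unique branch type, or an empty string when the branches disagree.
std::string RefinableResultType(const std::vector<std::string>& branch_types) {
  std::set<std::string> unique(branch_types.begin(), branch_types.end());
  return unique.size() == 1 ? *unique.begin() : std::string();
}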
auto module = op->getParentOfType(); - auto f = module.lookupSymbol(op.f()); + auto f = module.lookupSymbol(op.getF()); // Skip if function is not found or it has more than one caller. if (!f || !llvm::hasSingleElement(GetCallers(f))) return false; - DatasetInput input_elements = GetDatasetInput(op.input_dataset()); + DatasetInput input_elements = GetDatasetInput(op.getInputDataset()); - const int num_states = op.output_shapes().size(); + const int num_states = op.getOutputShapes().size(); const int num_captured_arguments = op.getNumOperands() - 1 - num_states; // If input_elements is undefined, we can still infer the shapes for the @@ -1217,7 +1369,7 @@ bool ShapeInference::InferShapeForReduceDataset(ReduceDatasetOp op, // Set the first num_states arguments shapes & types from the state. for (int i = 0; i < num_states; ++i) { - Type t = GetType(op.output_shapes()[i], op.output_types()[i]); + Type t = GetType(op.getOutputShapes()[i], op.getOutputTypes()[i]); t = TypeMeet(*it, t); changed = changed || (t != *it); *it++ = t; @@ -1263,7 +1415,7 @@ bool ShapeInference::InferShapeForTensorListInitOps(Operation* op) { if (auto tl_from_tensor = dyn_cast(op)) { // For TensorListFromTensor op we can infer element shape by dropping the // first dimension of input tensor. - element_type = DropFirstDimension(tl_from_tensor.tensor().getType()); + element_type = DropFirstDimension(tl_from_tensor.getTensor().getType()); if (!element_type || !element_type.hasStaticShape()) return false; } if (!CanInferTensorListElementType(handle, initial_element_shape, @@ -1284,7 +1436,7 @@ bool ShapeInference::InferShapeForTensorListInitOps(Operation* op) { bool ShapeInference::InferShapeForVarHandleOp(VarHandleOp op) { DCOMMENT_OP(op, "Inferring shape for VarHandleOp"); - Value resource = op.resource(); + Value resource = op.getResource(); if (!CanBeRefined(resource.getType())) return false; // Make sure there are only use cases from the `AssignVariableOp` and @@ -1304,9 +1456,9 @@ bool ShapeInference::InferShapeForVarHandleOp(VarHandleOp op) { Operation* def = use.getOwner(); Value value; if (AssignVariableOp assign_op = dyn_cast(def)) { - value = assign_op.value(); + value = assign_op.getValue(); } else if (ReadVariableOp read_op = dyn_cast(def)) { - value = read_op.value(); + value = read_op.getValue(); } else { llvm_unreachable("unexpected operator type"); } @@ -1327,7 +1479,7 @@ bool ShapeInference::InferShapeForVarHandleOp(VarHandleOp op) { } // Helper function for creating a Window proto from user-supplied data. -// Returns llvm::None if the user-supplied data was invalid. +// Returns std::nullopt if the user-supplied data was invalid. 
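// Illustrative sketch (standalone, simplified model -- not the XLA
// implementation) of the per-dimension arithmetic behind the window-shape
// helpers that follow, assuming the usual dilated/strided bound definitions:
// a dilated extent of n elements spans (n - 1) * dilation + 1 positions, and a
// window of width w strided by s fits (b - w) / s + 1 times over b positions
// when b >= w. Padding, which the real helper folds into the base extent, is
// omitted here.
#include <cstdint>

int64_t DilatedBound(int64_t bound, int64_t dilation) {
  return bound == 0 ? 0 : (bound - 1) * dilation + 1;
}

int64_t WindowOutputDim(int64_t base, int64_t window, int64_t stride,
                        int64_t base_dilation, int64_t window_dilation) {
  const int64_t dilated_base = DilatedBound(base, base_dilation);
  const int64_t dilated_window = DilatedBound(window, window_dilation);
  if (dilated_base < dilated_window) return 0;  // window does not fit at all
  return (dilated_base - dilated_window) / stride + 1;
}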
llvm::Optional InferWindowFromDimensions( llvm::SmallVector window_dimensions, llvm::SmallVector window_strides, @@ -1352,7 +1504,7 @@ llvm::Optional InferWindowFromDimensions( verify_size(padding.size(), "padding entries") && verify_size(lhs_dilation.size(), "lhs dilation factors") && verify_size(rhs_dilation.size(), "rhs dilation factors"))) - return llvm::None; + return std::nullopt; xla::Window window; for (size_t i = 0; i < window_dimensions.size(); i++) { @@ -1392,7 +1544,7 @@ llvm::Optional InferWindowOutputShape( llvm::errs() << "Window has dimension " << window.dimensions_size() << " but base shape has dimension " << base_shape.getRank() << "\n"; - return llvm::None; + return std::nullopt; } std::vector output_dimensions(window.dimensions_size()); @@ -1402,26 +1554,26 @@ llvm::Optional InferWindowOutputShape( if (dim.size() <= 0) { llvm::errs() << "Window " << window.DebugString() << " has a non-positive dimension.\n"; - return llvm::None; + return std::nullopt; } if (dim.stride() <= 0) { llvm::errs() << "Window " << window.DebugString() << " has a non-positive stride.\n"; - return llvm::None; + return std::nullopt; } if (dim.base_dilation() < 1) { llvm::errs() << "Window " << window.DebugString() << " has a non-positive base area dilation factor.\n"; - return llvm::None; + return std::nullopt; } if (dim.window_dilation() < 1) { llvm::errs() << "Window " << window.DebugString() << " has a non-positive window dilation factor.\n"; - return llvm::None; + return std::nullopt; } if (base_shape.isDynamicDim(i)) { - output_dimensions[i] = ShapedType::kDynamicSize; + output_dimensions[i] = ShapedType::kDynamic; } else { const int64_t dilated_base = xla::window_util::DilatedBound( base_shape.getDimSize(i), dim.base_dilation()); @@ -1443,15 +1595,15 @@ bool ShapeInference::InferShapeForXlaReduceWindowOp(XlaReduceWindowOp op) { bool changed = false; - auto input_ty = op.input().getType().cast(); + auto input_ty = op.getInput().getType().cast(); DenseElementsAttr window_dimensions, window_strides, base_dilations, window_dilations, padding; if (input_ty.hasStaticShape() && - matchPattern(op.window_dimensions(), m_Constant(&window_dimensions)) && - matchPattern(op.window_strides(), m_Constant(&window_strides)) && - matchPattern(op.base_dilations(), m_Constant(&base_dilations)) && - matchPattern(op.window_dilations(), m_Constant(&window_dilations)) && - matchPattern(op.padding(), m_Constant(&padding))) { + matchPattern(op.getWindowDimensions(), m_Constant(&window_dimensions)) && + matchPattern(op.getWindowStrides(), m_Constant(&window_strides)) && + matchPattern(op.getBaseDilations(), m_Constant(&base_dilations)) && + matchPattern(op.getWindowDilations(), m_Constant(&window_dilations)) && + matchPattern(op.getPadding(), m_Constant(&padding))) { llvm::SmallVector window_dimensions_vec, window_strides_vec, base_dilations_vec, window_dilations_vec; llvm::SmallVector> padding_pairs( @@ -1489,15 +1641,15 @@ bool ShapeInference::InferShapeForXlaReduceWindowOp(XlaReduceWindowOp op) { op->emitOpError("failed to create window"); } auto output_shape = InferWindowOutputShape( - input_ty, window.getValue(), - op.init_value().getType().cast().getElementType()); + input_ty, window.value(), + op.getInitValue().getType().cast().getElementType()); if (!output_shape) { op->emitOpError("failed to infer output shape"); } changed = RefineResultType(op.getOperation(), op.getResult(), - output_shape.getValue()); + output_shape.value()); } return changed; @@ -1507,13 +1659,13 @@ bool 
ShapeInference::InferShapeForXlaSelectAndScatterOp( XlaSelectAndScatterOp op) { DCOMMENT_OP(op, "Inferring shape for XlaSelectAndScatterOp"); - auto operand_shape = op.operand().getType().cast(); - auto source_shape = op.source().getType().cast(); + auto operand_shape = op.getOperand().getType().cast(); + auto source_shape = op.getSource().getType().cast(); DenseElementsAttr window_dimensions, window_strides, padding; if (operand_shape.hasRank() && source_shape.hasRank() && - matchPattern(op.window_dimensions(), m_Constant(&window_dimensions)) && - matchPattern(op.window_strides(), m_Constant(&window_strides)) && - matchPattern(op.padding(), m_Constant(&padding))) { + matchPattern(op.getWindowDimensions(), m_Constant(&window_dimensions)) && + matchPattern(op.getWindowStrides(), m_Constant(&window_strides)) && + matchPattern(op.getPadding(), m_Constant(&padding))) { llvm::SmallVector window_dimensions_vec, window_strides_vec, base_dilations_vec, window_dilations_vec; llvm::SmallVector> padding_pairs( @@ -1540,36 +1692,36 @@ bool ShapeInference::InferShapeForXlaSelectAndScatterOp( op->emitOpError("failed to create window"); } auto window_result_shape = InferWindowOutputShape( - operand_shape, window.getValue(), operand_shape.getElementType()); + operand_shape, window.value(), operand_shape.getElementType()); if (!window_result_shape) { op->emitOpError("failed to infer window result shape"); } - if (window_result_shape.getValue() != source_shape) { + if (window_result_shape.value() != source_shape) { op->emitOpError( "Source shape does not match the shape of window-reduced operand."); } } return RefineResultType(op.getOperation(), op.getResult(), - op.operand().getType()); + op.getOperand().getType()); } bool ShapeInference::InferShapeForXlaGatherOp(XlaGatherOp op) { - xla::Shape input_shape = xla::TypeToShape(op.operand().getType()); + xla::Shape input_shape = xla::TypeToShape(op.getOperand().getType()); if (input_shape == xla::Shape()) return false; xla::Shape start_indices_shape = - xla::TypeToShape(op.start_indices().getType()); + xla::TypeToShape(op.getStartIndices().getType()); if (start_indices_shape == xla::Shape()) return false; xla::GatherDimensionNumbers gather_dim_numbers; - if (!gather_dim_numbers.ParseFromString(op.dimension_numbers().str())) + if (!gather_dim_numbers.ParseFromString(op.getDimensionNumbers().str())) return false; DenseIntElementsAttr slice_sizes_attr; - if (!matchPattern(op.slice_sizes(), m_Constant(&slice_sizes_attr))) + if (!matchPattern(op.getSliceSizes(), m_Constant(&slice_sizes_attr))) return false; llvm::SmallVector slice_sizes; for (const auto& attr : slice_sizes_attr.getValues()) { @@ -1590,7 +1742,7 @@ bool ShapeInference::InferShapeForXlaGatherOp(XlaGatherOp op) { return false; } - return RefineResultType(op, op.output(), *refined_type); + return RefineResultType(op, op.getOutput(), *refined_type); } llvm::Optional InferXlaConvOutputShape( @@ -1631,11 +1783,11 @@ llvm::Optional InferXlaConvOutputShape( lhs_dilations, rhs_dilations); auto output_shape = - InferWindowOutputShape(base_shape, window.getValue(), element_type); + InferWindowOutputShape(base_shape, window.value(), element_type); for (auto i = 0; i < num_spatial_dims; ++i) { output_dims[dnums.output_spatial_dimensions(i)] = - output_shape.getValue().getShape()[i]; + output_shape.value().getShape()[i]; DCOMMENT("inferrd output spatial dimension " << i << " at dimension numebr " << dnums.output_spatial_dimensions(i) << " is " @@ -1650,14 +1802,14 @@ llvm::Optional InferXlaConvOutputShape( // 
"third_party/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc" is // resolved LogicalResult PrecheckForXlaConvV2Op(XlaConvV2Op op) { - auto input_tensor = op.lhs(); - auto kernel_tensor = op.rhs(); - auto window_strides = op.window_strides(); - auto padding = op.padding(); - auto lhs_dilation = op.lhs_dilation(); - auto rhs_dilation = op.rhs_dilation(); - auto feature_group_count = op.feature_group_count(); - int64_t batch_group_count = op.batch_group_count(); + auto input_tensor = op.getLhs(); + auto kernel_tensor = op.getRhs(); + auto window_strides = op.getWindowStrides(); + auto padding = op.getPadding(); + auto lhs_dilation = op.getLhsDilation(); + auto rhs_dilation = op.getRhsDilation(); + auto feature_group_count = op.getFeatureGroupCount(); + int64_t batch_group_count = op.getBatchGroupCount(); auto input_args_have_static_shape = [&]() -> bool { return input_tensor.getType().cast().hasStaticShape() && @@ -1700,7 +1852,7 @@ LogicalResult PrecheckForXlaConvV2Op(XlaConvV2Op op) { DenseElementsAttr feature_group_count_attr; xla::ConvolutionDimensionNumbers dnums; - dnums.ParseFromString(op.dimension_numbersAttr().getValue().str()); + dnums.ParseFromString(op.getDimensionNumbersAttr().getValue().str()); if (dnums.input_spatial_dimensions_size() != dnums.kernel_spatial_dimensions_size()) { return op.emitOpError() << "Both arguments to convolution must have " @@ -1783,13 +1935,13 @@ bool ShapeInference::InferShapeForXlaConvV2Op(XlaConvV2Op op) { return changed; } - auto input_tensor = op.lhs(); - auto kernel_tensor = op.rhs(); - auto window_strides = op.window_strides(); - auto padding = op.padding(); - auto lhs_dilation = op.lhs_dilation(); - auto rhs_dilation = op.rhs_dilation(); - int64_t batch_group_count = op.batch_group_count(); + auto input_tensor = op.getLhs(); + auto kernel_tensor = op.getRhs(); + auto window_strides = op.getWindowStrides(); + auto padding = op.getPadding(); + auto lhs_dilation = op.getLhsDilation(); + auto rhs_dilation = op.getRhsDilation(); + int64_t batch_group_count = op.getBatchGroupCount(); DenseIntElementsAttr window_strides_attr, padding_attr, lhs_dilation_attr, rhs_dilation_attr; @@ -1802,7 +1954,7 @@ bool ShapeInference::InferShapeForXlaConvV2Op(XlaConvV2Op op) { llvm::SmallVector> padding_pairs( padding_attr.getNumElements() / 2); xla::ConvolutionDimensionNumbers dnums; - dnums.ParseFromString(op.dimension_numbersAttr().getValue().str()); + dnums.ParseFromString(op.getDimensionNumbersAttr().getValue().str()); auto input_tensor_shape = input_tensor.getType().cast(); for (auto i = 0; i < input_tensor_shape.getShape().size(); ++i) { @@ -1847,9 +1999,9 @@ bool ShapeInference::InferShapeForXlaConvV2Op(XlaConvV2Op op) { padding_pairs, lhs_dilations_vec, rhs_dilations_vec, batch_group_count, dnums, element_type); - if (output_shape.getValue()) { + if (output_shape.value()) { changed = RefineResultType(op.getOperation(), op.getResult(), - output_shape.getValue()); + output_shape.value()); return changed; } } @@ -2050,7 +2202,7 @@ bool CanWhileTypeBeRefinedWith(TensorType current_type, int64_t current_dim = std::get<0>(dim); int64_t potential_refined_dim = std::get<1>(dim); if (current_dim != potential_refined_dim && - current_dim != ShapedType::kDynamicSize) + current_dim != ShapedType::kDynamic) return false; } return true; @@ -2059,12 +2211,12 @@ bool CanWhileTypeBeRefinedWith(TensorType current_type, template bool ShapeInference::InferShapeForWhile(WhileOpTy op, TypeRange body_result_types) { - if (!op.shape_invariant()) - return 
RefineTypeForPassThroughOperands(op, op.input(), op.output()); + if (!op.getShapeInvariant()) + return RefineTypeForPassThroughOperands(op, op.getInput(), op.getOutput()); bool changed = false; for (auto entry : - zip(op.input().getTypes(), op.output(), body_result_types)) { + zip(op.getInput().getTypes(), op.getOutput(), body_result_types)) { Value result = std::get<1>(entry); TensorType body_result_type = std::get<2>(entry).template cast(); @@ -2143,6 +2295,11 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op, if (auto if_region = dyn_cast(op)) return InferShapeForIfRegion(if_region); + if (auto case_op = dyn_cast(op)) return InferShapeForCase(case_op); + + if (auto case_region = dyn_cast(op)) + return InferShapeForCaseRegion(case_region); + if (auto while_op = dyn_cast(op)) return InferShapeForWhile( while_op, while_op.body_function().getFunctionType().getResults()); @@ -2150,7 +2307,7 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op, if (auto while_region = dyn_cast(op)) return InferShapeForWhile( while_region, - while_region.body().front().getTerminator()->getOperandTypes()); + while_region.getBody().front().getTerminator()->getOperandTypes()); if (auto host_compute_op = dyn_cast<_XlaHostComputeMlirOp>(op)) { return InferShapeForXlaHostComputeMlir(host_compute_op); @@ -2215,7 +2372,7 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op, llvm::SmallVector inferred_return_shapes; if (failed(InferReturnTypeComponentsForTFOp( - /*location=*/None, op, graph_version_, operand_as_constant_fn, + /*location=*/std::nullopt, op, graph_version_, operand_as_constant_fn, op_result_as_shape_fn, result_element_type_fn, inferred_return_shapes))) return false; @@ -2230,8 +2387,8 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op, ShapedTypeComponents inferred = std::get<1>(result); TensorType inferred_type; if (inferred.hasRank()) { - inferred_type = tensorflow::GetTypeFromTFTensorShape( - inferred.getDims(), inferred.getElementType()); + inferred_type = + RankedTensorType::get(inferred.getDims(), inferred.getElementType()); } else { inferred_type = UnrankedTensorType::get(inferred.getElementType()); @@ -2289,7 +2446,7 @@ FailureOr ShapeInference::PropagateShapeToFunctions( any_failure = true; continue; } - any_nonconvergence = any_nonconvergence || !failure_or_converged.getValue(); + any_nonconvergence = any_nonconvergence || !failure_or_converged.value(); if (failed(InferShapeForFunctionReturnType(func))) any_failure = true; } if (any_failure) return failure(); @@ -2319,7 +2476,7 @@ FailureOr ShapeInference::PropagateShapeToRegions( InferShapeUntilFixPoint(region, max_iterations); if (failed(failure_or_converged)) any_failure = true; - else if (!failure_or_converged.getValue()) + else if (!failure_or_converged.value()) any_nonconvergence = true; } if (any_failure) return failure(); @@ -2399,7 +2556,7 @@ RankedTensorType GetCompatibleRankedTensorType(RankedTensorType lhs, if (lhs_dim == std::get<1>(dim)) { dims.push_back(lhs_dim); } else { - dims.push_back(ShapedType::kDynamicSize); + dims.push_back(ShapedType::kDynamic); } } return tensorflow::GetTypeFromTFTensorShape( @@ -2447,23 +2604,23 @@ FailureOr ShapeInference::PropagateShapeIntoAttachedFunctions( if (auto if_op = dyn_cast(op)) { DCOMMENT("Propagating shapes into If"); return PropagateShapeToFunctions( - module, if_op.input().getTypes(), + module, if_op.getInput().getTypes(), {if_op.ResolveThenFunction(&symbol_table_), if_op.ResolveElseFunction(&symbol_table_)}, max_iterations); 
} else if (auto case_op = dyn_cast(op)) { SmallVector branches; case_op.get_branch_functions(branches); - return PropagateShapeToFunctions(module, case_op.input().getTypes(), + return PropagateShapeToFunctions(module, case_op.getInput().getTypes(), branches, max_iterations); } else if (auto while_op = dyn_cast(op)) { // If `shape_invariant` is set, operand shapes cannot be simply propagated // to result shapes as the op may have different intermediate shapes (such // While ops can have different result shapes from operand shapes). // Compatible shapes must be determined before propagating them. - if (while_op.shape_invariant()) { + if (while_op.getShapeInvariant()) { auto compatible_types = GetWhileCompatibleTypes( - while_op.input().getTypes(), while_op.output().getTypes(), + while_op.getInput().getTypes(), while_op.getOutput().getTypes(), while_op.ResolveBodyFunction(&symbol_table_) .getFunctionType() .getInputs()); @@ -2474,7 +2631,7 @@ FailureOr ShapeInference::PropagateShapeIntoAttachedFunctions( max_iterations); } return PropagateShapeToFunctions( - module, while_op.input().getTypes(), + module, while_op.getInput().getTypes(), {while_op.ResolveCondFunction(&symbol_table_), while_op.ResolveBodyFunction(&symbol_table_)}, max_iterations); @@ -2504,19 +2661,19 @@ FailureOr ShapeInference::PropagateShapeIntoAttachedFunctions( }; if (auto xla_reduce_window_op = dyn_cast(op)) { - return propagate_shape_to(xla_reduce_window_op.computation()); + return propagate_shape_to(xla_reduce_window_op.getComputation()); } if (auto xla_select_and_scatter_op = dyn_cast(op)) { - return propagate_shape_to(xla_select_and_scatter_op.select()) - .getValue() && - propagate_shape_to(xla_select_and_scatter_op.scatter()).getValue(); + return propagate_shape_to(xla_select_and_scatter_op.getSelect()) + .value() && + propagate_shape_to(xla_select_and_scatter_op.getScatter()).value(); } else if (auto xla_variadic_reduce_v2_op = dyn_cast(op)) { - return propagate_shape_to(xla_variadic_reduce_v2_op.reducer()); + return propagate_shape_to(xla_variadic_reduce_v2_op.getReducer()); } else if (auto xla_variadic_sort_op = dyn_cast(op)) { - return propagate_shape_to(xla_variadic_sort_op.comparator()); + return propagate_shape_to(xla_variadic_sort_op.getComparator()); } } @@ -2532,16 +2689,16 @@ FailureOr ShapeInference::PropagateShapeIntoAttachedRegions( // to result shapes as the op may have different intermediate shapes (such // While ops can have different result shapes from operand shapes). // Compatible shapes must be determined before propagating them. 
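// Illustrative sketch (standalone, hypothetical sentinel) of the shape-invariant
// compatibility rule referenced above: when propagating operand types into a
// shape-invariant While, a dimension is kept only where operand and result
// agree; anywhere they differ, the merged type falls back to a dynamic
// dimension, mirroring GetCompatibleRankedTensorType earlier in this file.
#include <cstdint>
#include <vector>

constexpr int64_t kDynamicDim = -1;

std::vector<int64_t> CompatibleShape(const std::vector<int64_t>& lhs,
                                     const std::vector<int64_t>& rhs) {
  std::vector<int64_t> merged;
  merged.reserve(lhs.size());
  for (size_t i = 0; i < lhs.size() && i < rhs.size(); ++i) {
    merged.push_back(lhs[i] == rhs[i] ? lhs[i] : kDynamicDim);
  }
  return merged;
}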
- if (while_op.shape_invariant()) { + if (while_op.getShapeInvariant()) { auto compatible_types = GetWhileCompatibleTypes( - while_op.input().getTypes(), while_op.output().getTypes(), - while_op.body().getArgumentTypes()); + while_op.getInput().getTypes(), while_op.getOutput().getTypes(), + while_op.getBody().getArgumentTypes()); return PropagateShapeToRegions(compatible_types, - {&while_op.cond(), &while_op.body()}, + {&while_op.getCond(), &while_op.getBody()}, max_iterations); } - return PropagateShapeToRegions(while_op.input().getTypes(), - {&while_op.cond(), &while_op.body()}, + return PropagateShapeToRegions(while_op.getInput().getTypes(), + {&while_op.getCond(), &while_op.getBody()}, max_iterations); } return true; @@ -2745,7 +2902,7 @@ static FailureOr InferShapeForFunction(ShapeInference& context, int64_t max_iterations) { FailureOr failure_or_converged = context.InferShapeUntilFixPoint(&func.getBody(), max_iterations); - if (failed(failure_or_converged) || !failure_or_converged.getValue()) + if (failed(failure_or_converged) || !failure_or_converged.value()) return failure_or_converged; // TODO(b/156276510): Verify that it is always fine to refine a function's // return type, as long as we do not change the argument shapes. @@ -2799,7 +2956,7 @@ FailureOr InferShapeForFunction(func::FuncOp func, FailureOr failure_or_converged = context.InferShapeUntilFixPoint(&func.getBody(), max_iterations); - if (failed(failure_or_converged) || !failure_or_converged.getValue()) + if (failed(failure_or_converged) || !failure_or_converged.value()) return failure_or_converged; if (failed(context.InferShapeForFunctionReturnType(func))) return failure(); @@ -2833,7 +2990,7 @@ FailureOr InferModuleShape(ModuleOp module, int64_t max_iterations) { func::FuncOp func = context.front(); FailureOr failure_or_converged = InferShapeForFunction(context, func, max_iterations); - if (failed(failure_or_converged) || !failure_or_converged.getValue()) + if (failed(failure_or_converged) || !failure_or_converged.value()) return failure_or_converged; context.pop_front(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc index 072399f8a8d..8180e2d4084 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_pass.cc @@ -40,7 +40,7 @@ class ShapeInference auto failure_or_converged = InferModuleShape(getOperation(), max_iterations_); if (failed(failure_or_converged)) return signalPassFailure(); - if (!failure_or_converged.getValue()) { + if (!failure_or_converged.value()) { getOperation().emitError() << "shape inference pass did not reach convergence after " << max_iterations_; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc b/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc index fa8682d9d97..1a84be115b3 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc @@ -63,7 +63,7 @@ class ClusterConstantSinkingPass if (!const_op) return; // Filter constants using user provided predicate function. - if (filter && !filter(cluster, const_op.value())) return; + if (filter && !filter(cluster, const_op.getValue())) return; // We found a constant, try to insert it in the map and re-use its // cloned value if any. 
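// Illustrative sketch (standalone, hypothetical callback) of the fixpoint
// convergence contract used by the shape-inference entry points above: iterate
// until one inference sweep reports no further changes, and report
// non-convergence -- rather than failure -- when the iteration budget runs out,
// matching the FailureOr<bool> checks rewritten to .value() in this hunk.
#include <cstdint>
#include <functional>

// Returns true if a fixpoint was reached within `max_iterations`.
bool InferUntilFixPoint(const std::function<bool()>& run_one_iteration,
                        int64_t max_iterations) {
  for (int64_t i = 0; i < max_iterations; ++i) {
    if (!run_one_iteration()) return true;  // no change: converged
  }
  return false;  // budget exhausted without convergence
}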
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc index fc6d37bc7ca..c07e8b6a38a 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" @@ -139,7 +139,7 @@ LogicalResult HandleWhileOp( llvm::SmallDenseMap body_map; auto find_arg_stack_type = [&](int64_t index) -> llvm::Optional { auto it = data_var_to_size_var.find(while_op.getOperand(index)); - if (it == data_var_to_size_var.end()) return llvm::None; + if (it == data_var_to_size_var.end()) return std::nullopt; return it->getFirst().getType(); }; auto add_size_vars_to_return = [&](ArrayRef new_args) { @@ -211,7 +211,7 @@ LogicalResult HandleIfOp( auto find_arg_stack_type = [&](int64_t index) -> llvm::Optional { auto it = data_var_to_size_var.find(if_op.getOperand(index + 1)); - if (it == data_var_to_size_var.end()) return llvm::None; + if (it == data_var_to_size_var.end()) return std::nullopt; return it->getFirst().getType(); }; ModifyFunctionSignature(then_func, &then_map, find_arg_stack_type); @@ -316,7 +316,7 @@ LogicalResult HandlePartitionedCallOp( } auto find_arg_stack_type = [&](int64_t index) -> llvm::Optional { auto it = data_var_to_size_var.find(call.getOperand(index)); - if (it == data_var_to_size_var.end()) return llvm::None; + if (it == data_var_to_size_var.end()) return std::nullopt; return it->getFirst().getType(); }; ModifyFunctionSignature(lowered_callee, &callee_map, find_arg_stack_type); @@ -355,10 +355,10 @@ LogicalResult HandleStackV2Op( llvm::SmallDenseMap* data_var_to_size_var) { // Create a buffer variable and a size variable to replace the stack. 
auto elem_type = cutil::GetElementTypeFromAccess( - stack.handle(), module, [](Operation* user) -> llvm::Optional { + stack.getHandle(), module, [](Operation* user) -> llvm::Optional { auto push = llvm::dyn_cast(user); - if (!push) return llvm::None; - return push.elem().getType(); + if (!push) return std::nullopt; + return push.getElem().getType(); }); if (!elem_type.has_value()) { return stack.emitOpError("cannot infer element shape of stack"); @@ -366,7 +366,7 @@ LogicalResult HandleStackV2Op( OpBuilder builder(stack); Value buffer; if (failed(cutil::CreateInitBufferValue( - elem_type->getShape(), stack.max_size(), stack, + elem_type->getShape(), stack.getMaxSize(), stack, elem_type->getElementType(), builder, &buffer))) { return failure(); } @@ -384,7 +384,7 @@ LogicalResult HandleStackV2Op( cutil::GetR1Const({0LL}, builder, stack.getLoc()), builder, stack.getLoc()); cutil::WriteLocalVariable(local_var, buffer, builder, stack.getLoc()); - stack.handle().replaceAllUsesWith(local_var); + stack.getHandle().replaceAllUsesWith(local_var); (*data_var_to_size_var)[local_var] = local_size_var; stack.erase(); return success(); @@ -393,22 +393,23 @@ LogicalResult HandleStackV2Op( LogicalResult HandleStackPushV2Op( TF::StackPushV2Op push, llvm::SmallDenseMap* data_var_to_size_var) { - auto it = data_var_to_size_var->find(push.handle()); + auto it = data_var_to_size_var->find(push.getHandle()); if (it == data_var_to_size_var->end()) { return push.emitOpError("unknown stack"); } // Push output simply forward the input element. - push.replaceAllUsesWith(push.elem()); + push.replaceAllUsesWith(push.getElem()); OpBuilder builder(push); // Read the current buffer and size. auto stack_val = - cutil::ReadLocalVariable(push.handle(), builder, push.getLoc()); + cutil::ReadLocalVariable(push.getHandle(), builder, push.getLoc()); auto index = cutil::ReadLocalVariable(it->getSecond(), builder, push.getLoc()); - stack_val = - cutil::SetElement(index, stack_val, push.elem(), builder, push.getLoc()); + stack_val = cutil::SetElement(index, stack_val, push.getElem(), builder, + push.getLoc()); // Assign the new buffer and size. - cutil::WriteLocalVariable(push.handle(), stack_val, builder, push.getLoc()); + cutil::WriteLocalVariable(push.getHandle(), stack_val, builder, + push.getLoc()); index = builder.create( push.getLoc(), ArrayRef{index.getType()}, ArrayRef{index, cutil::GetR1Const({1}, builder, push.getLoc())}); @@ -420,14 +421,14 @@ LogicalResult HandleStackPushV2Op( LogicalResult HandleStackPopV2Op( TF::StackPopV2Op pop, llvm::SmallDenseMap* data_var_to_size_var) { - auto it = data_var_to_size_var->find(pop.handle()); + auto it = data_var_to_size_var->find(pop.getHandle()); if (it == data_var_to_size_var->end()) { return pop.emitOpError("unknown stack"); } OpBuilder builder(pop); // Read the current buffer and size. 
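// Illustrative sketch (standalone, hypothetical container) of the decomposition
// performed by the StackV2 handlers in this hunk: a stack resource becomes a
// fixed-capacity buffer variable plus a size variable, and push/pop become
// read-modify-write sequences on those two variables.
#include <cstdint>
#include <optional>
#include <vector>

struct DecomposedStack {
  std::vector<int64_t> buffer;  // local "buffer variable", preallocated
  int64_t size = 0;             // local "size variable"

  explicit DecomposedStack(int64_t max_size) : buffer(max_size, 0) {}

  bool Push(int64_t elem) {
    if (size >= static_cast<int64_t>(buffer.size())) return false;
    buffer[size] = elem;  // write element at the current size
    ++size;               // size = size + 1
    return true;
  }
  std::optional<int64_t> Pop() {
    if (size == 0) return std::nullopt;
    --size;               // size = size - 1
    return buffer[size];  // read element at the new size
  }
};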
auto stack_val = - cutil::ReadLocalVariable(pop.handle(), builder, pop.getLoc()); + cutil::ReadLocalVariable(pop.getHandle(), builder, pop.getLoc()); auto size = cutil::ReadLocalVariable(it->getSecond(), builder, pop.getLoc()); auto new_size = builder.create( pop.getLoc(), ArrayRef{size.getType()}, @@ -501,7 +502,7 @@ LogicalResult DecomposeStackOpsInternal( return failure(); } } else if (auto close = llvm::dyn_cast(&op)) { - data_var_to_size_var->erase(close.handle()); + data_var_to_size_var->erase(close.getHandle()); close.erase(); } else if (auto while_op = llvm::dyn_cast(&op)) { if (failed(HandleWhileOp(while_op, module, *data_var_to_size_var, diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/strip_saved_module_metadata.cc b/tensorflow/compiler/mlir/tensorflow/transforms/strip_saved_module_metadata.cc index 9300109fe66..422b722c9d2 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/strip_saved_module_metadata.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/strip_saved_module_metadata.cc @@ -64,8 +64,10 @@ void StripFunction(func::FuncOp func) { } for (int i = 0; i < func.getNumArguments(); ++i) { + llvm::ArrayRef attrs = + mlir::function_interface_impl::getArgAttrs(func, i); auto stripAttrs = llvm::to_vector<4>(llvm::make_filter_range( - func.getArgAttrs(i), + attrs, [](NamedAttribute namedAttr) { return ShouldStripAttr(namedAttr); })); for (auto namedAttr : stripAttrs) { func.removeArgAttr(i, namedAttr.getName()); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/strip_tf_attributes.cc b/tensorflow/compiler/mlir/tensorflow/transforms/strip_tf_attributes.cc index 0f2d41efbba..26f4bbf8f93 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/strip_tf_attributes.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/strip_tf_attributes.cc @@ -50,8 +50,10 @@ void StripFunction(func::FuncOp func) { } for (int i = 0; i < func.getNumArguments(); ++i) { + llvm::ArrayRef attrs = + mlir::function_interface_impl::getArgAttrs(func, i); auto stripAttrs = llvm::to_vector<4>(llvm::make_filter_range( - func.getArgAttrs(i), + attrs, [](NamedAttribute namedAttr) { return ShouldStripAttr(namedAttr); })); for (auto namedAttr : stripAttrs) { func.removeArgAttr(i, namedAttr.getName()); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc index ce9706d8fd6..d18d7781930 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc @@ -13,8 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include + #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" @@ -80,17 +81,17 @@ LogicalResult GetSplitElementTypeAndCount(TF::TensorArraySplitV3Op split, RankedTensorType* elem_type, int64_t* count) { auto lengths_const = - llvm::dyn_cast_or_null(split.lengths().getDefiningOp()); + llvm::dyn_cast_or_null(split.getLengths().getDefiningOp()); if (!lengths_const) return split.emitOpError("non-constant split lengths"); - *count = lengths_const.value().getNumElements(); + *count = lengths_const.getValue().getNumElements(); if (*count <= 0) return split.emitOpError("non-positive split count"); - auto buffer_type = split.value().getType().dyn_cast(); + auto buffer_type = split.getValue().getType().dyn_cast(); if (!buffer_type || !buffer_type.hasStaticShape() || buffer_type.getRank() < 1) { return split.emitOpError("unknown or invalid split tensor shape"); } int64_t length = buffer_type.getDimSize(0) / *count; - for (const auto& len : lengths_const.value().getValues()) { + for (const auto& len : lengths_const.getValue().getValues()) { if (length == len.getSExtValue()) continue; return split.emitOpError("different split lengths are not supported"); } @@ -106,7 +107,7 @@ LogicalResult GetSplitElementTypeAndCount(TF::TensorArraySplitV3Op split, // Tries to infer the tensor array element shape. llvm::Optional> GetTensorArrayElementShape( TF::TensorArrayV3Op ta, ModuleOp module) { - auto element_shape = ta.element_shapeAttr().cast(); + auto element_shape = ta.getElementShapeAttr().cast(); if (element_shape.hasStaticShape()) { auto shape = element_shape.getShape(); // Convert int64 to int64_t. @@ -116,21 +117,21 @@ llvm::Optional> GetTensorArrayElementShape( bool has_failure = false; auto elem_type = cutil::GetElementTypeFromAccess( - ta.handle(), module, [&](Operation* user) -> llvm::Optional { - if (has_failure) return llvm::None; + ta.getHandle(), module, [&](Operation* user) -> llvm::Optional { + if (has_failure) return std::nullopt; if (auto write = llvm::dyn_cast(user)) { - return write.value().getType(); + return write.getValue().getType(); } else if (auto split = llvm::dyn_cast(user)) { - if (!split.lengths().getDefiningOp() || - !llvm::isa(split.lengths().getDefiningOp())) { - return llvm::None; + if (!split.getLengths().getDefiningOp() || + !llvm::isa(split.getLengths().getDefiningOp())) { + return std::nullopt; } RankedTensorType t; int64_t count; if (failed(GetSplitElementTypeAndCount(split, &t, &count))) { has_failure = true; - return llvm::None; + return std::nullopt; } return t; } else if (auto scatter = @@ -138,28 +139,28 @@ llvm::Optional> GetTensorArrayElementShape( // TensorArrayScatter writes vector of tensors to TensorArray. We can // deduce the shape of TensorArray by dropping the 0th dim of // TensorArrayScatter `value`. - auto t = scatter.value().getType().dyn_cast(); - if (!t || t.getShape().empty()) return llvm::None; + auto t = scatter.getValue().getType().dyn_cast(); + if (!t || t.getShape().empty()) return std::nullopt; return RankedTensorType::get(t.getShape().drop_front(), t.getElementType()); } else if (auto gather = llvm::dyn_cast(user)) { // Try to infer from result type of gather. 
- auto t = gather.value().getType().dyn_cast(); + auto t = gather.getValue().getType().dyn_cast(); if (t && !t.getShape().empty()) return RankedTensorType::get(t.getShape().drop_front(), t.getElementType()); // Try to infer from `element_shape` attribute of gather. - auto element_shape = gather.element_shapeAttr() + auto element_shape = gather.getElementShapeAttr() .dyn_cast_or_null(); if (element_shape && element_shape.hasStaticShape()) { return RankedTensorType::get(element_shape.getShape(), - gather.dtype()); + gather.getDtype()); } } - return llvm::None; + return std::nullopt; }); - if (!elem_type) return llvm::None; + if (!elem_type) return std::nullopt; return llvm::to_vector<8>(elem_type->getShape()); } @@ -196,13 +197,13 @@ LogicalResult HandleTensorArrayV3Op( llvm::SmallDenseMap* stats) { auto elem_shape = GetTensorArrayElementShape(ta, module); if (!elem_shape) return ta.emitOpError("unknown element shape"); - if (ta.dynamic_size()) { + if (ta.getDynamicSize()) { return ta.emitOpError("dynamic tensor array size is unsupported"); } Value buffer; OpBuilder builder(ta); - if (failed(cutil::CreateInitBufferValue(*elem_shape, ta.size(), ta, - ta.dtype(), builder, &buffer))) { + if (failed(cutil::CreateInitBufferValue(*elem_shape, ta.getSize(), ta, + ta.getDtype(), builder, &buffer))) { return failure(); } auto var_type = RankedTensorType::get( @@ -212,7 +213,7 @@ LogicalResult HandleTensorArrayV3Op( auto local_var = builder.create( ta.getLoc(), ArrayRef{var_type}, ArrayRef{}); cutil::WriteLocalVariable(local_var, buffer, builder, ta.getLoc()); - ta.handle().replaceAllUsesWith(local_var); + ta.getHandle().replaceAllUsesWith(local_var); // The flow output is just a way for the front end to enforce ordering among // tensor array ops, but in the MLIR TF dialect they have sequential ordering. // Just create a constant to replace its uses. @@ -220,7 +221,7 @@ LogicalResult HandleTensorArrayV3Op( scalar_tensor.scalar()() = 0.0f; auto flow = builder.create( ta.getLoc(), tensorflow::ConvertTensor(scalar_tensor, &builder).value()); - ta.flow().replaceAllUsesWith(flow); + ta.getFlow().replaceAllUsesWith(flow); ta.erase(); (*stats)[local_var].accumulate_on_write = false; return success(); @@ -229,17 +230,17 @@ LogicalResult HandleTensorArrayV3Op( LogicalResult HandleTensorArrayReadV3Op( TF::TensorArrayReadV3Op read, const llvm::SmallDenseMap& stats) { - auto local_var = read.handle(); + auto local_var = read.getHandle(); if (stats.count(local_var) == 0) { return read.emitOpError("unknown tensor array"); } OpBuilder builder(read); auto buffer = cutil::ReadLocalVariable(local_var, builder, read.getLoc()); auto index_reshape = - cutil::ReshapeScalarToSizeType(builder, read.index(), read.getLoc()); + cutil::ReshapeScalarToSizeType(builder, read.getIndex(), read.getLoc()); auto elem = cutil::GetElement(index_reshape, buffer, builder, read.getLoc()); - ReplaceAllUsesExceptTerminator(read.value(), elem); - ReplaceAllUsesWithCast(read.value(), elem); + ReplaceAllUsesExceptTerminator(read.getValue(), elem); + ReplaceAllUsesWithCast(read.getValue(), elem); read.erase(); // The clear_after_read attribute does not mean setting the tensor to 0 after // read; instead it does not allow a second read before the next write. 
We @@ -250,14 +251,14 @@ LogicalResult HandleTensorArrayReadV3Op( LogicalResult HandleTensorArrayWriteV3Op( TF::TensorArrayWriteV3Op write, const llvm::SmallDenseMap& stats) { - auto local_var = write.handle(); + auto local_var = write.getHandle(); auto stat_it = stats.find(local_var); if (stat_it == stats.end()) return write.emitOpError("unknown tensor array"); OpBuilder builder(write); auto buffer = cutil::ReadLocalVariable(local_var, builder, write.getLoc()); auto index_reshape = - cutil::ReshapeScalarToSizeType(builder, write.index(), write.getLoc()); - auto elem = write.value(); + cutil::ReshapeScalarToSizeType(builder, write.getIndex(), write.getLoc()); + Value elem = write.getValue(); if (stat_it->getSecond().accumulate_on_write) { // Get the old slice, and accumulate with it. We set keep_slice_shape // (keeping the leading size-1 dimension) because it avoids reshape back and @@ -277,7 +278,7 @@ LogicalResult HandleTensorArrayWriteV3Op( buffer = cutil::SetElement(index_reshape, buffer, elem, builder, write.getLoc()); cutil::WriteLocalVariable(local_var, buffer, builder, write.getLoc()); - write.flow_out().replaceAllUsesWith(write.flow_in()); + write.getFlowOut().replaceAllUsesWith(write.getFlowIn()); write.erase(); return success(); } @@ -285,7 +286,7 @@ LogicalResult HandleTensorArrayWriteV3Op( LogicalResult HandleTensorArrayConcatV3Op( TF::TensorArrayConcatV3Op concat, const llvm::SmallDenseMap& stats) { - auto local_var = concat.handle(); + auto local_var = concat.getHandle(); if (stats.count(local_var) == 0) { return concat.emitOpError("unknown tensor array"); } @@ -304,8 +305,8 @@ LogicalResult HandleTensorArrayConcatV3Op( RankedTensorType::get(shape, buffer_type.getElementType())}, ArrayRef{buffer, cutil::GetR1Const(shape, builder, concat.getLoc())}); - ReplaceAllUsesExceptTerminator(concat.value(), buffer); - ReplaceAllUsesWithCast(concat.value(), buffer); + ReplaceAllUsesExceptTerminator(concat.getValue(), buffer); + ReplaceAllUsesWithCast(concat.getValue(), buffer); // Create the lengths as a list of the same value (element size). tensorflow::Tensor lengths_tensor(tensorflow::DT_INT64, @@ -313,7 +314,7 @@ LogicalResult HandleTensorArrayConcatV3Op( for (int64_t i = 0; i < buffer_type.getDimSize(0); ++i) { lengths_tensor.vec()(i) = buffer_type.getDimSize(1); } - concat.lengths().replaceAllUsesWith(builder.create( + concat.getLengths().replaceAllUsesWith(builder.create( concat.getLoc(), tensorflow::ConvertTensor(lengths_tensor, &builder).value())); concat.erase(); @@ -323,7 +324,7 @@ LogicalResult HandleTensorArrayConcatV3Op( LogicalResult HandleTensorArraySplitV3Op( TF::TensorArraySplitV3Op split, const llvm::SmallDenseMap& stats) { - auto local_var = split.handle(); + auto local_var = split.getHandle(); if (stats.count(local_var) == 0) { return split.emitOpError("unknown tensor array"); } @@ -337,22 +338,23 @@ LogicalResult HandleTensorArraySplitV3Op( buffer_shape.push_back(count); for (int64_t dim : elem_type.getShape()) buffer_shape.push_back(dim); // Reshape the input to match the buffer of the tensor array. 
- auto buffer = builder - .create( - split.getLoc(), - ArrayRef{RankedTensorType::get( - buffer_shape, elem_type.getElementType())}, - ArrayRef{split.value(), - cutil::GetR1Const(buffer_shape, builder, - split.getLoc())}) - .output(); + Value buffer = + builder + .create( + split.getLoc(), + ArrayRef{RankedTensorType::get(buffer_shape, + elem_type.getElementType())}, + ArrayRef{ + split.getValue(), + cutil::GetR1Const(buffer_shape, builder, split.getLoc())}) + .getOutput(); // Accumulate with the old buffer. auto old_buffer = cutil::ReadLocalVariable(local_var, builder, split.getLoc()); buffer = cutil::AccumulateBuffers(old_buffer, buffer, builder, split.getLoc()); cutil::WriteLocalVariable(local_var, buffer, builder, split.getLoc()); - split.flow_out().replaceAllUsesWith(split.flow_in()); + split.getFlowOut().replaceAllUsesWith(split.getFlowIn()); split.erase(); return success(); } @@ -360,7 +362,7 @@ LogicalResult HandleTensorArraySplitV3Op( LogicalResult HandleTensorArraySizeV3Op( TF::TensorArraySizeV3Op size, const llvm::SmallDenseMap& stats) { - auto local_var = size.handle(); + auto local_var = size.getHandle(); if (stats.count(local_var) == 0) { return size.emitOpError("unknown tensor array"); } @@ -371,7 +373,7 @@ LogicalResult HandleTensorArraySizeV3Op( OpBuilder builder(size); auto result = cutil::CreateScalarConst(buffer_type.getDimSize(0), builder, size.getLoc()); - size.size().replaceAllUsesWith(result); + size.getSize().replaceAllUsesWith(result); size.erase(); return success(); } @@ -398,13 +400,13 @@ LogicalResult CreateAndInitializeGradVariable(Type local_var_type, LogicalResult HandleTensorArrayGradV3Op( TF::TensorArrayGradV3Op grad, llvm::SmallDenseMap* stats) { - auto local_var = grad.handle(); + auto local_var = grad.getHandle(); OpBuilder builder(grad); Value grad_var; auto sit = stats->find(local_var); if (sit == stats->end()) return grad.emitOpError("unknown tensor array"); auto emplace_res = - sit->getSecond().grads.try_emplace(grad.source().str(), Value()); + sit->getSecond().grads.try_emplace(grad.getSource().str(), Value()); if (!emplace_res.second) { // If the source has been assigned a grad, use it. grad_var = emplace_res.first->second; @@ -417,8 +419,8 @@ LogicalResult HandleTensorArrayGradV3Op( // Write to a grad accumulates with previous writes. 
(*stats)[grad_var].accumulate_on_write = true; } - grad.flow_out().replaceAllUsesWith(grad.flow_in()); - grad.grad_handle().replaceAllUsesWith(grad_var); + grad.getFlowOut().replaceAllUsesWith(grad.getFlowIn()); + grad.getGradHandle().replaceAllUsesWith(grad_var); grad.erase(); return success(); } @@ -426,16 +428,16 @@ LogicalResult HandleTensorArrayGradV3Op( LogicalResult HandleTensorArrayGatherV3Op( TF::TensorArrayGatherV3Op gather, const llvm::SmallDenseMap& stats) { - auto local_var = gather.handle(); + auto local_var = gather.getHandle(); if (stats.count(local_var) == 0) { return gather.emitOpError("unknown tensor array"); } OpBuilder builder(gather); auto buffer = cutil::ReadLocalVariable(local_var, builder, gather.getLoc()); - auto result = - cutil::GatherElements(gather.indices(), buffer, builder, gather.getLoc()); - ReplaceAllUsesExceptTerminator(gather.value(), result); - ReplaceAllUsesWithCast(gather.value(), result); + auto result = cutil::GatherElements(gather.getIndices(), buffer, builder, + gather.getLoc()); + ReplaceAllUsesExceptTerminator(gather.getValue(), result); + ReplaceAllUsesWithCast(gather.getValue(), result); gather.erase(); return success(); } @@ -443,16 +445,17 @@ LogicalResult HandleTensorArrayGatherV3Op( LogicalResult HandleTensorArrayScatterV3Op( TF::TensorArrayScatterV3Op scatter, const llvm::SmallDenseMap& stats) { - auto local_var = scatter.handle(); + auto local_var = scatter.getHandle(); if (stats.count(local_var) == 0) { return scatter.emitOpError("unknown tensor array"); } OpBuilder builder(scatter); auto buffer = cutil::ReadLocalVariable(local_var, builder, scatter.getLoc()); - buffer = cutil::ScatterAccumulateElements(scatter.indices(), scatter.value(), - buffer, builder, scatter.getLoc()); + buffer = + cutil::ScatterAccumulateElements(scatter.getIndices(), scatter.getValue(), + buffer, builder, scatter.getLoc()); cutil::WriteLocalVariable(local_var, buffer, builder, scatter.getLoc()); - scatter.flow_out().replaceAllUsesWith(scatter.flow_in()); + scatter.getFlowOut().replaceAllUsesWith(scatter.getFlowIn()); scatter.erase(); return success(); } @@ -488,7 +491,7 @@ llvm::SmallDenseMap> AccessedGradients( return; } if (auto grad = llvm::dyn_cast(op)) { - insert(grad.handle(), grad.source().str(), func_block); + insert(grad.getHandle(), grad.getSource().str(), func_block); } else if (auto while_op = llvm::dyn_cast(op)) { for (const auto& entry : AccessedGradients( {while_op.body_function(), while_op.cond_function()}, module)) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc index 2b47b332671..92140dfbc98 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc @@ -13,6 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include +#include + #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" @@ -33,6 +38,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.pb.h" @@ -158,7 +164,7 @@ LogicalResult HandleWhileOp( llvm::SmallDenseMap body_map; auto find_arg_tensor_list_type = [&](int64_t index) -> llvm::Optional { auto it = buffer_to_size->find(while_op.getOperand(index)); - if (it == buffer_to_size->end()) return llvm::None; + if (it == buffer_to_size->end()) return std::nullopt; return it->getFirst().getType(); }; auto arg_buffer_size_is_fixed = [&](int64_t index) { @@ -219,7 +225,7 @@ LogicalResult HandleCaseOrIfOp( auto find_arg_buffer_type = [&](int64_t index) -> llvm::Optional { auto it = buffer_to_size->find(op.getOperand(index + 1)); - if (it == buffer_to_size->end()) return llvm::None; + if (it == buffer_to_size->end()) return std::nullopt; return it->getFirst().getType(); }; auto arg_buffer_size_is_fixed = [&](int64_t index) { @@ -288,7 +294,7 @@ LogicalResult HandleWhileRegionOp( }; // Rewrite body. - Region& body_region = while_op.body(); + Region& body_region = while_op.getBody(); modify_region_arguments(body_region); if (failed(DecomposeTensorListOpsInternal( &body_region.front(), module, buffer_to_size, @@ -299,7 +305,7 @@ LogicalResult HandleWhileRegionOp( body_region.front(), *buffer_to_size); // Rewrite cond. - Region& cond_region = while_op.cond(); + Region& cond_region = while_op.getCond(); modify_region_arguments(cond_region); if (failed(DecomposeTensorListOpsInternal( &cond_region.front(), module, buffer_to_size, @@ -319,8 +325,8 @@ LogicalResult HandleWhileRegionOp( auto new_while = builder.create( while_op.getLoc(), body_region.front().getTerminator()->getOperandTypes(), new_while_operands, while_op->getAttrs()); - new_while.body().takeBody(body_region); - new_while.cond().takeBody(cond_region); + new_while.getBody().takeBody(body_region); + new_while.getCond().takeBody(cond_region); for (const auto& entry : output_buffer_to_size) { (*buffer_to_size)[new_while.getResult(std::get<0>(entry))] = { new_while.getResult(std::get<1>(entry)), std::get<2>(entry)}; @@ -337,8 +343,8 @@ LogicalResult HandleIfRegionOp( llvm::StringMap* decomposed_partitioned_call_callees) { // Rewrite the branches. 
- Region& then_branch = if_op.then_branch(); - Region& else_branch = if_op.else_branch(); + Region& then_branch = if_op.getThenBranch(); + Region& else_branch = if_op.getElseBranch(); if (failed(DecomposeTensorListOpsInternal( &then_branch.front(), module, buffer_to_size, decomposed_partitioned_call_callees))) @@ -364,8 +370,8 @@ LogicalResult HandleIfRegionOp( new_op.getResult(std::get<1>(entry)), std::get<2>(entry)}; } - new_op.then_branch().takeBody(if_op.then_branch()); - new_op.else_branch().takeBody(if_op.else_branch()); + new_op.getThenBranch().takeBody(if_op.getThenBranch()); + new_op.getElseBranch().takeBody(if_op.getElseBranch()); if_op.replaceAllUsesWith( new_op.getResults().take_front(if_op.getNumResults())); @@ -474,7 +480,7 @@ LogicalResult HandlePartitionedCallOp( } auto find_arg_buffer_type = [&](int64_t index) -> llvm::Optional { auto it = buffer_to_size->find(call.getOperand(index)); - if (it == buffer_to_size->end()) return llvm::None; + if (it == buffer_to_size->end()) return std::nullopt; return it->getFirst().getType(); }; auto arg_buffer_size_is_fixed = [&](int64_t index) { @@ -530,9 +536,9 @@ LogicalResult GetConstShapeValue(Value shape_value, if (!shape_op) return failure(); auto shape_const_op = llvm::dyn_cast(shape_op); if (!shape_const_op) return failure(); - for (const auto& v : shape_const_op.value().getValues()) { + for (const auto& v : shape_const_op.getValue().getValues()) { int64_t dim_size = v.getSExtValue(); - if (dim_size == ShapedType::kDynamicSize) return failure(); + if (dim_size == tensorflow::kTFDynamicSize) return failure(); shape->push_back(dim_size); } return success(); @@ -562,17 +568,17 @@ LogicalResult HandleEmptyTensorListOp( // shape inference might have successfully inferred the element shape from // write operations on the TensorList. if (failed(GetElementShapeFromResultType(list.getType(), &element_shape))) { - if (failed(GetConstShapeValue(list.element_shape(), &element_shape))) { + if (failed(GetConstShapeValue(list.getElementShape(), &element_shape))) { return list.emitOpError("unknown tensor list element shape"); } } if (failed(cutil::CreateInitBufferValue( - element_shape, list.max_num_elements(), list, list.element_dtype(), + element_shape, list.getMaxNumElements(), list, list.getElementDtype(), builder, &buffer))) { return failure(); } Value size = cutil::GetR1Const({0LL}, builder, list.getLoc()); - list.handle().replaceAllUsesWith(buffer); + list.getHandle().replaceAllUsesWith(buffer); (*buffer_to_size)[buffer] = {size, /*fixed=*/false}; list.erase(); return success(); @@ -589,19 +595,19 @@ LogicalResult HandleTensorListReserveOp( // shape inference might have successfully inferred the element shape from // write operations on the TensorList. 
if (failed(GetElementShapeFromResultType(list.getType(), &element_shape))) { - if (failed(GetConstShapeValue(list.element_shape(), &element_shape))) { + if (failed(GetConstShapeValue(list.getElementShape(), &element_shape))) { return list.emitOpError("unknown tensor list element shape"); } } - if (failed(cutil::CreateInitBufferValue(element_shape, list.num_elements(), - list, list.element_dtype(), builder, + if (failed(cutil::CreateInitBufferValue(element_shape, list.getNumElements(), + list, list.getElementDtype(), builder, &buffer))) { return failure(); } - Value size = cutil::ReshapeScalarToSizeType(builder, list.num_elements(), + Value size = cutil::ReshapeScalarToSizeType(builder, list.getNumElements(), list.getLoc()); (*buffer_to_size)[buffer] = {size, /*fixed=*/true}; - list.handle().replaceAllUsesWith(buffer); + list.getHandle().replaceAllUsesWith(buffer); list.erase(); return success(); } @@ -611,15 +617,15 @@ LogicalResult HandleTensorListFromTensorOp( llvm::SmallDenseMap* buffer_to_size) { OpBuilder builder(list); Value buffer = builder.create( - list.getLoc(), ArrayRef{list.tensor().getType()}, - ArrayRef{list.tensor()}); + list.getLoc(), ArrayRef{list.getTensor().getType()}, + ArrayRef{list.getTensor()}); auto type = buffer.getType().cast(); if (!type.hasStaticShape()) { return list.emitOpError("TensorListFromTensorOp input has unknown shape."); } Value size = cutil::GetR1Const({type.getShape()[0]}, builder, list.getLoc()); (*buffer_to_size)[buffer] = {size, /*fixed=*/true}; - list.output_handle().replaceAllUsesWith(buffer); + list.getOutputHandle().replaceAllUsesWith(buffer); list.erase(); return success(); } @@ -627,7 +633,7 @@ LogicalResult HandleTensorListFromTensorOp( LogicalResult HandleTensorListPushBackOp( TF::TensorListPushBackOp push, llvm::SmallDenseMap* buffer_to_size) { - auto buffer = push.input_handle(); + auto buffer = push.getInputHandle(); auto it = buffer_to_size->find(buffer); if (it == buffer_to_size->end()) { return push.emitOpError( @@ -639,11 +645,11 @@ LogicalResult HandleTensorListPushBackOp( auto size = it->getSecond().size; OpBuilder builder(push); auto new_buffer = - cutil::SetElement(size, buffer, push.tensor(), builder, push.getLoc()); + cutil::SetElement(size, buffer, push.getTensor(), builder, push.getLoc()); auto new_size = builder.create( push.getLoc(), ArrayRef{size.getType()}, ArrayRef{size, cutil::GetR1Const({1LL}, builder, push.getLoc())}); - push.output_handle().replaceAllUsesWith(new_buffer); + push.getOutputHandle().replaceAllUsesWith(new_buffer); (*buffer_to_size)[new_buffer] = {new_size, /*fixed=*/false}; push.erase(); return success(); @@ -652,7 +658,7 @@ LogicalResult HandleTensorListPushBackOp( LogicalResult HandleTensorListPopBackOp( TF::TensorListPopBackOp pop, llvm::SmallDenseMap* buffer_to_size) { - auto buffer = pop.input_handle(); + auto buffer = pop.getInputHandle(); auto it = buffer_to_size->find(buffer); if (it == buffer_to_size->end()) { pop.emitOpError("found tf.TensorListPopBack on unknown TensorList."); @@ -669,8 +675,8 @@ LogicalResult HandleTensorListPopBackOp( pop.getLoc(), ArrayRef{size.getType()}, ArrayRef{size, cutil::GetR1Const({1LL}, builder, pop.getLoc())}); auto element = cutil::GetElement(new_size, new_buffer, builder, pop.getLoc()); - pop.output_handle().replaceAllUsesWith(new_buffer); - pop.tensor().replaceAllUsesWith(element); + pop.getOutputHandle().replaceAllUsesWith(new_buffer); + pop.getTensor().replaceAllUsesWith(element); pop.erase(); (*buffer_to_size)[new_buffer] = {new_size, /*fixed=*/false}; 
return success(); @@ -679,18 +685,18 @@ LogicalResult HandleTensorListPopBackOp( LogicalResult HandleTensorListGetItemOp( TF::TensorListGetItemOp get_item, const llvm::SmallDenseMap& buffer_to_size) { - auto buffer = get_item.input_handle(); + auto buffer = get_item.getInputHandle(); auto it = buffer_to_size.find(buffer); if (it == buffer_to_size.end()) { get_item.emitOpError("found tf.TensorListGetItemOp on unknown TensorList."); return failure(); } OpBuilder builder(get_item); - auto index = cutil::ReshapeScalarToSizeType(builder, get_item.index(), + auto index = cutil::ReshapeScalarToSizeType(builder, get_item.getIndex(), get_item.getLoc()); auto element = cutil::GetElement(index, buffer, OpBuilder(get_item), get_item.getLoc()); - get_item.item().replaceAllUsesWith(element); + get_item.getItem().replaceAllUsesWith(element); get_item.erase(); return success(); } @@ -698,18 +704,18 @@ LogicalResult HandleTensorListGetItemOp( LogicalResult HandleTensorListSetItemOp( TF::TensorListSetItemOp set_item, llvm::SmallDenseMap* buffer_to_size) { - auto buffer = set_item.input_handle(); + auto buffer = set_item.getInputHandle(); auto it = buffer_to_size->find(buffer); if (it == buffer_to_size->end()) { set_item.emitOpError("found tf.TensorListSetItemOp on unknown TensorList."); return failure(); } OpBuilder builder(set_item); - auto index = cutil::ReshapeScalarToSizeType(builder, set_item.index(), + auto index = cutil::ReshapeScalarToSizeType(builder, set_item.getIndex(), set_item.getLoc()); - auto new_buffer = cutil::SetElement(index, buffer, set_item.item(), builder, - set_item.getLoc()); - set_item.output_handle().replaceAllUsesWith(new_buffer); + auto new_buffer = cutil::SetElement(index, buffer, set_item.getItem(), + builder, set_item.getLoc()); + set_item.getOutputHandle().replaceAllUsesWith(new_buffer); auto size = it->getSecond(); (*buffer_to_size)[new_buffer] = size; set_item.erase(); @@ -719,7 +725,7 @@ LogicalResult HandleTensorListSetItemOp( LogicalResult HandleTensorListLengthOp( TF::TensorListLengthOp length, const llvm::SmallDenseMap& buffer_to_size) { - auto it = buffer_to_size.find(length.input_handle()); + auto it = buffer_to_size.find(length.getInputHandle()); if (it == buffer_to_size.end()) { length.emitOpError("found tf.TensorListLength on unknown TensorList."); return failure(); @@ -727,9 +733,10 @@ LogicalResult HandleTensorListLengthOp( OpBuilder builder(length); if (it->getSecond().fixed) { auto dim = cutil::CreateScalarConst( - length.input_handle().getType().cast().getDimSize(0), + length.getInputHandle().getType().cast().getDimSize( + 0), builder, length.getLoc()); - length.length().replaceAllUsesWith(dim); + length.getLength().replaceAllUsesWith(dim); } else { auto current_size = it->getSecond().size; // Reshapes the R1 length to a scalar. 
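The TensorList handlers above all consult the same bookkeeping: each decomposed list buffer is mapped to its current size value plus a flag recording whether that size is fixed (TensorListReserve/FromTensor) or growable (EmptyTensorList/PushBack). A simplified sketch of that structure, with hypothetical names (`SizeInfo`, `BufferToSizeMap`) standing in for the pass's internal types:

```cpp
// Sketch only: the buffer-to-size bookkeeping assumed by the handlers above.
#include "llvm/ADT/DenseMap.h"
#include "mlir/IR/Value.h"

struct SizeInfo {
  mlir::Value size;    // R1 tensor holding the current element count
  bool fixed = false;  // true when the list size can never change
};

using BufferToSizeMap = llvm::SmallDenseMap<mlir::Value, SizeInfo>;
```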
@@ -739,7 +746,7 @@ LogicalResult HandleTensorListLengthOp( {}, getElementTypeOrSelf(current_size.getType()))}, ArrayRef{current_size, cutil::GetR1Const({}, builder, length.getLoc())}); - length.length().replaceAllUsesWith(reshape); + length.getLength().replaceAllUsesWith(reshape); } length.erase(); return success(); @@ -748,15 +755,15 @@ LogicalResult HandleTensorListLengthOp( LogicalResult HandleTensorListElementShapeOp( TF::TensorListElementShapeOp elem_shape, const llvm::SmallDenseMap& buffer_to_size) { - if (buffer_to_size.count(elem_shape.input_handle()) == 0) { + if (buffer_to_size.count(elem_shape.getInputHandle()) == 0) { return elem_shape.emitOpError("unknown tensor list"); } - auto buffer = elem_shape.input_handle(); + auto buffer = elem_shape.getInputHandle(); auto result = cutil::GetR1Const( buffer.getType().cast().getShape().drop_front(), OpBuilder(elem_shape), elem_shape.getLoc(), - elem_shape.shape_type().getIntOrFloatBitWidth()); - elem_shape.element_shape().replaceAllUsesWith(result); + elem_shape.getShapeType().getIntOrFloatBitWidth()); + elem_shape.getElementShape().replaceAllUsesWith(result); elem_shape.erase(); return success(); } @@ -764,14 +771,14 @@ LogicalResult HandleTensorListElementShapeOp( LogicalResult HandleTensorListGatherOp( TF::TensorListGatherOp gather, const llvm::SmallDenseMap& buffer_to_size) { - auto it = buffer_to_size.find(gather.input_handle()); + auto it = buffer_to_size.find(gather.getInputHandle()); if (it == buffer_to_size.end()) { return gather.emitOpError("unknown tensor list"); } - auto buffer = gather.input_handle(); - auto result = cutil::GatherElements(gather.indices(), buffer, + auto buffer = gather.getInputHandle(); + auto result = cutil::GatherElements(gather.getIndices(), buffer, OpBuilder(gather), gather.getLoc()); - gather.values().replaceAllUsesWith(result); + gather.getValues().replaceAllUsesWith(result); gather.erase(); return success(); } @@ -779,24 +786,24 @@ LogicalResult HandleTensorListGatherOp( LogicalResult HandleTensorListScatterIntoExistingListOp( TF::TensorListScatterIntoExistingListOp scatter, llvm::SmallDenseMap* buffer_to_size) { - auto it = buffer_to_size->find(scatter.input_handle()); + auto it = buffer_to_size->find(scatter.getInputHandle()); if (it == buffer_to_size->end()) { return scatter.emitOpError("unknown tensor list"); } - auto buffer = scatter.input_handle(); + auto buffer = scatter.getInputHandle(); OpBuilder builder(scatter); - auto indices_type = scatter.indices().getType().cast(); + auto indices_type = scatter.getIndices().getType().cast(); if (!indices_type) return scatter.emitOpError("unranked indices shape"); auto shape_type = RankedTensorType::get({2}, builder.getIntegerType(32)); auto shape = builder.create( scatter.getLoc(), DenseElementsAttr::get( shape_type, {static_cast(indices_type.getDimSize(0)), 1})); - auto indices = - builder.create(scatter.getLoc(), scatter.indices(), shape); + auto indices = builder.create(scatter.getLoc(), + scatter.getIndices(), shape); Value tensor_scatter_update = builder.create( - scatter.getLoc(), buffer, indices, scatter.tensor()); - scatter.output_handle().replaceAllUsesWith(tensor_scatter_update); + scatter.getLoc(), buffer, indices, scatter.getTensor()); + scatter.getOutputHandle().replaceAllUsesWith(tensor_scatter_update); scatter.erase(); auto size = it->getSecond(); (*buffer_to_size)[tensor_scatter_update] = size; @@ -846,7 +853,7 @@ LogicalResult DecomposeTensorListOpsInternal( return failure(); } } else if (auto stack = llvm::dyn_cast(&op)) { - 
stack.tensor().replaceAllUsesWith(stack.input_handle()); + stack.getTensor().replaceAllUsesWith(stack.getInputHandle()); stack.erase(); } else if (auto elem_shape = llvm::dyn_cast(&op)) { @@ -867,15 +874,15 @@ LogicalResult DecomposeTensorListOpsInternal( } else if (auto addn = llvm::dyn_cast(&op)) { auto it = buffer_to_size->find(addn.getOperand(0)); if (it != buffer_to_size->end()) { - addn.sum().setType(addn.getOperand(0).getType()); + addn.getSum().setType(addn.getOperand(0).getType().cast()); auto size = it->getSecond(); - (*buffer_to_size)[addn.sum()] = size; + (*buffer_to_size)[addn.getSum()] = size; } } else if (auto zeros = llvm::dyn_cast(&op)) { - if (buffer_to_size->count(zeros.x()) > 0) { - zeros.y().setType(zeros.x().getType()); - auto size = (*buffer_to_size)[zeros.x()]; - (*buffer_to_size)[zeros.y()] = size; + if (buffer_to_size->count(zeros.getX()) > 0) { + zeros.getY().setType(zeros.getX().getType()); + auto size = (*buffer_to_size)[zeros.getX()]; + (*buffer_to_size)[zeros.getY()] = size; } } else if (auto while_op = llvm::dyn_cast(&op)) { if (failed(HandleWhileOp(while_op, module, buffer_to_size, diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.cc index ea555789383..72302903b37 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.cc @@ -29,7 +29,7 @@ struct FuseParallelMapAndBatch : public OpRewritePattern { LogicalResult matchAndRewrite(BatchDatasetV2Op op, PatternRewriter &rewriter) const override { - auto batchInputDataset = op.input_dataset(); + auto batchInputDataset = op.getInputDataset(); ParallelMapDatasetOp batchInputOp = dyn_cast_or_null( batchInputDataset.getDefiningOp()); @@ -39,19 +39,19 @@ struct FuseParallelMapAndBatch : public OpRewritePattern { // and MapAndBatchDataset is different (int32 and int64 respectively) auto num_parallel_calls_op = rewriter.create( op.getLoc(), UnrankedTensorType::get(rewriter.getIntegerType(64)), - batchInputOp.num_parallel_calls(), rewriter.getBoolAttr(false)); + batchInputOp.getNumParallelCalls(), rewriter.getBoolAttr(false)); - if (op.metadata() != batchInputOp.metadata()) { + if (op.getMetadata() != batchInputOp.getMetadata()) { return failure(); } auto fused_op = rewriter.create( - op.getLoc(), op.getType(), batchInputOp.input_dataset(), - batchInputOp.other_arguments(), op.batch_size(), - num_parallel_calls_op.y(), op.drop_remainder(), batchInputOp.f(), - op.output_types(), op.output_shapes(), - batchInputOp.preserve_cardinality(), op.metadata()); - rewriter.replaceOp(op, {fused_op.handle()}); + op.getLoc(), op.getType(), batchInputOp.getInputDataset(), + batchInputOp.getOtherArguments(), op.getBatchSize(), + num_parallel_calls_op.getY(), op.getDropRemainder(), + batchInputOp.getF(), op.getOutputTypes(), op.getOutputShapes(), + batchInputOp.getPreserveCardinality(), op.getMetadata()); + rewriter.replaceOp(op, {fused_op.getHandle()}); return failure(); } }; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes.td b/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes.td index 04aac42d728..f46fb14d2ca 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes.td @@ -273,11 +273,61 @@ def LaunchToDeviceAttributePass : Pass<"tf-launch-to-device-attribute", "mlir::f ``` }]; + let options = [ + 
Option<"legacy_graph_export_", "legacy-graph-export", "bool", + /*default=*/"true", + "Determines whether or not this pass should execute logic that is " + "reserved for the legacy graph export pipeline to maintain expected " + "invariants. In the case of this pass, that means manually propagating " + "controls to lifted parallel execute regions to the graph fetch to " + "ensure the ops execute, as well as determining whether or not the " + "islands created by this pass should be split after the replicated " + "ops have been lifted."> + ]; + let constructor = "TFDevice::CreateLaunchToDeviceAttributePass()"; } def XlaClusterFormationPass : Pass<"tf-xla-cluster-formation", "ModuleOp"> { - let summary = "Encapsulate StatefulPartitionedCallOp within a Cluster op"; + let summary = "Encapsulate partitioned calls within a Cluster op"; + let description = [{ + This pass clusters `tf.PartitionedCall` and `tf.StatefulPartitionedCall` + with `_xla_compile_device_type` attribute into a `tf_device.cluster`. + Notice this pass will only rewrite the outermost call if there are nested + calls to avoid nested `tf.XlaLaunch` operations from being created later. + + For example, the following code + + ```mlir + func.func @main() -> tensor { + %0 = "tf.StatefulPartitionedCall"() {_xla_compile_device_type = "CPU", f = @stateful_pcall_func} : () -> (tensor) + func.return %0 : tensor + } + + func.func @stateful_pcall_func() -> tensor { + %0 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + func.return %0 : tensor + } + ``` + + will be transformed into, + + ```mlir + func.func @main() -> tensor { + %0 = "tf_device.cluster"() ({ + %1 = "tf.StatefulPartitionedCall"() {_xla_compile_device_type = "CPU", f = @stateful_pcall_func} : () -> tensor + tf_device.return %1 : tensor + }) : () -> tensor + func.return %0 : tensor + } + + func.func @stateful_pcall_func() -> tensor { + %0 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + func.return %0 : tensor + } + + ``` + }]; let constructor = "TFDevice::CreateXlaClusterFormationPass()"; let dependentDialects = ["tf_device::TensorFlowDeviceDialect"]; } @@ -293,16 +343,23 @@ def XlaRewritePass : Pass<"tf-xla-rewrite", "mlir::ModuleOp"> { let description = [{ This pass rewrites `tf.PartitionedCall` and `tf.StatefulPartitionedCall` - operations with `_xla_compile_device_type` attribute into `tf.XlaLaunch` - operations. This makes the attached function execute with XLA. - `tf.XlaLaunch` requires resource-type arguments come at the end, so this - pass rewrites the called function if necessary. + operations with `_xla_compile_device_type` attribute in a + `tf_device.cluster` into `tf.XlaLaunch` operations. This makes the attached + function execute with XLA. `tf.XlaLaunch` requires resource-type arguments + come at the end, so this pass rewrites the called function if necessary. + This pass assumes there are no nested `tf_device.cluster`s so we don't end + up creating nested XLA launch ops. 
For example, the `tf.PartitionedCall` operation in the following code ```mlir func.func @convert_partitioned_call_with_resources(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "tf.PartitionedCall"(%arg0, %arg1) {_xla_compile_device_type = "CPU", config = "", config_proto = "", device = "/device:CPU:0", executor_type = "", f = @pcall_func_with_resources} : (tensor, tensor) -> (tensor) + %0 = "tf.PartitionedCall"(%arg0, %arg1) {_xla_compile_device_type = "CPU", f = @pcall_func_with_resources} : (tensor, tensor) -> (tensor) + %0 = "tf_device.cluster"() ({ + %1 = "tf.StatefulPartitionedCall"() {_xla_compile_device_type = "CPU", f = @stateful_pcall_func} : () -> tensor + tf_device.return %1 : tensor + }) : () -> tensor + func.return %0 : tensor } @@ -315,9 +372,14 @@ def XlaRewritePass : Pass<"tf-xla-rewrite", "mlir::ModuleOp"> { ```mlir func.func @convert_partitioned_call_with_resources(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "tf.XlaLaunch"(%arg1, %arg0) {_xla_compile_device_type = "CPU", device = "/device:CPU:0", function = @pcall_func_with_resources_0, operand_segment_sizes = array} : (tensor, tensor) -> tensor + %0 = "tf_device.cluster"() ({ + %1 = "tf.XlaLaunch"(%arg1, %arg0) {_xla_compile_device_type = "CPU", function = @pcall_func_with_resources_0, operand_segment_sizes = array} : (tensor, tensor) -> tensor + tf_device.return %1 : tensor + }) : () -> tensor + return %0 : tensor } + func.func @pcall_func_with_resources_0(%arg0: tensor, %arg1: tensor) -> tensor { return %arg0 : tensor } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc index 29a10930567..e84cf959800 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/framework/function.h" @@ -31,6 +30,7 @@ limitations under the License. 
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/core/public/session_options.h" +#include "tensorflow/tsl/platform/statusor.h" #define DEBUG_TYPE "run-tf-graph-optimization" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td b/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td index 70c0b1f0bb3..647cad07aaf 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td @@ -306,7 +306,7 @@ def UnrollBatchMatMulPass : Pass<"tf-unroll-batch-matmul", "mlir::func::FuncOp"> let constructor = "TF::CreateUnrollBatchMatMulPassPass()"; } -def ClusterFormationPass : Pass<"tf-device-cluster-formation", "mlir::func::FuncOp"> { +def ClusterFormationPass : Pass<"tf-device-cluster-formation", "mlir::ModuleOp"> { let summary = "Form clusters from instructions assigned to same device"; let constructor = "TFDevice::CreateClusterFormationPass()"; let dependentDialects = ["tf_device::TensorFlowDeviceDialect"]; @@ -368,6 +368,17 @@ def LegalizeTFGToTFPass : Pass<"tfe-legalize-tfg", "ModuleOp"> { def ReplicateToIslandPass : Pass<"tf-replicate-to-island", "mlir::func::FuncOp"> { let summary = "Lowers device replicate to executor islands"; let constructor = "TFDevice::CreateReplicateToIslandPass()"; + let options = [ + Option<"legacy_graph_export_", "legacy-graph-export", "bool", + /*default=*/"true", + "Determines whether or not this pass should execute logic that is " + "reserved for the legacy graph export pipeline to maintain expected " + "invariants. In the case of this pass, that means manually propagating " + "controls to lifted parallel execute regions to the graph fetch to " + "ensure the ops execute, as well as determining whether or not the " + "islands created by this pass should be split after the replicated " + "ops have been lifted."> + ]; } def ReplicaIDToDeviceOrdinalPass : Pass<"tf-replica-id-to-device-ordinal", "mlir::func::FuncOp"> { @@ -388,6 +399,27 @@ def ConvertReadonlyReferenceVariablesToResourceVariablesPass : let constructor = "TF::CreateConvertReadonlyReferenceVariablesToResourceVariablesPass()"; } +def ReplicateTensorListInitOpsPass : Pass<"tf-replicate-tensor-list-init-ops", "mlir::func::FuncOp"> { + let summary = + "Replicate TensorList init ops for correct shape assignments in shape inference"; + + let description = [{ + If we pass same TensorList to a while op as multiple arguments or just use + the same TensorList at multiple places and assign different + TensorListSetItem to elements of TensorList, the shape inference is then + unable to identify the Shape of these args and thus the input TensorList + shape is unidentifiable. + All of these args are supposed to be independent and not related to original + creation of TensorList. + + This pass will create multiple instances of TensorList for each arg of the + while op and each use and thus there will be not a conflict in resolving the + shape of these different inputs. 
+ }]; + + let constructor = "TF::CreateReplicateTensorListInitOpsPass()"; +} + def TensorFlowShapeInferencePass : Pass<"tf-shape-inference", "ModuleOp"> { let summary = "Shape inference on TF dialect and ops implementing InferTypeOpInterface"; @@ -904,6 +936,12 @@ def ExecutorTPUV1IslandInliningPass : Pass<"tf-executor-tpu-v1-island-inlining", let constructor = "tf_executor::CreateTFExecutorTPUV1IslandInliningPass()"; } +def TPUPartitionedOpConversionPass : Pass<"tf-tpu-partitioned-op-conversion", "mlir::func::FuncOp"> { + let summary = "Rewrite all TPU Partitioned ops into their V2 counterparts."; + + let constructor = "TFTPU::CreateTPUPartitionedOpConversionPass()"; +} + def TPUClusterFormationPass : Pass<"tf-tpu-cluster-formation", "ModuleOp"> { let summary = "Forms clusters from operations assigned to the same TPU computation"; @@ -2181,7 +2219,7 @@ def SplitIntoIslandPerOpPass : Pass<"tf-executor-split-into-island-per-op", "mli } ``` }]; - let constructor = "mlir::CreateSplitIntoIslandPerOpPass()"; + let constructor = "mlir::TF::CreateSplitIntoIslandPerOpPass()"; let dependentDialects = ["mlir::tf_executor::TensorFlowExecutorDialect"]; } @@ -2257,6 +2295,15 @@ def BroadcastFoldPass : Pass<"tf-broadcast-fold", "mlir::func::FuncOp"> { def ParallelExecuteToIslandsPass : Pass<"tf-parallel-execute-to-islands", "mlir::func::FuncOp"> { let summary = "Lowers device parallel_execute to executor islands"; let constructor = "TFDevice::CreateParallelExecuteToIslandsPass()"; + let options = [ + Option<"legacy_graph_export_", "legacy-graph-export", "bool", + /*default=*/"true", + "Determines whether or not this pass should execute logic that is " + "reserved for the legacy graph export pipeline to maintain expected " + "invariants. In the case of this pass, that means manually propagating " + "controls to lifted parallel execute regions to the graph fetch to " + "ensure the ops execute."> + ]; } def ConstantOpDeviceAssignmentPass : Pass<"constant-op-device-assignment", "ModuleOp"> { @@ -2345,7 +2392,7 @@ def DeviceIndexSelectorPass : Pass<"tf-device-index-selector", "mlir::func::Func let constructor = "TF::CreateDeviceIndexSelectorPass()"; } -def OrderByDialectPass : Pass<"tf-order-by-dialect", "mlir::func::FuncOp"> { +def OrderByDialectPass : Pass<"tf-order-by-dialect", "mlir::ModuleOp"> { let summary = "Reorders ops so ops of the same dialect are next to each other."; let description = [{ Performs a reordering of ops so that diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.cc index 829abbbb89e..20326e2ef49 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.cc @@ -15,9 +15,11 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.h" +#include #include #include +#include "absl/algorithm/container.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/BitVector.h" #include "llvm/ADT/DenseSet.h" @@ -46,6 +48,9 @@ namespace mlir { namespace tf_saved_model { namespace { +// Attribute name that specifies the input shapes of a function. +constexpr StringRef kTfInputShapesAttr = "tf._input_shapes"; + // Build and returns ElementsAttr which holds the data in 'tensor'. 
ElementsAttr GetTensorValueAsElementsAttr(const tensorflow::Tensor& tensor, OpBuilder builder) { @@ -119,7 +124,7 @@ void PropagateUsage( } } else if (auto if_op = dyn_cast(user_op)) { (*arguments_to_erase)[if_op].push_back(argument_index); - for (auto callee : {&if_op.then_branch(), &if_op.else_branch()}) { + for (auto callee : {&if_op.getThenBranch(), &if_op.getElseBranch()}) { work_list->push_back(std::make_pair(callee, argument_index)); } } else if (auto while_op = dyn_cast(user_op)) { @@ -130,7 +135,7 @@ void PropagateUsage( } } else if (auto while_op = dyn_cast(user_op)) { (*arguments_to_erase)[while_op].push_back(argument_index); - for (auto callee : {&while_op.cond(), &while_op.body()}) { + for (auto callee : {&while_op.getCond(), &while_op.getBody()}) { work_list->push_back(std::make_pair(callee, argument_index)); } } @@ -176,22 +181,6 @@ void ReplaceVarWithConstant( } } -// Helper that returns the FuncOp that is the SessionInit function which -// will be called to initialize all resources. -// Returns nullptr if no function is found. -func::FuncOp GetSessionInitializerFunc(ModuleOp module) { - auto session_init_op = tf_saved_model::GetSessionInitializerOp(module); - SymbolTable symbol_table(module); - if (session_init_op && !session_init_op.getInitializers().empty()) { - func::FuncOp init_func_op = symbol_table.lookup( - session_init_op.getInitializers()[0] - .cast() - .getValue()); - return init_func_op; - } - return nullptr; -} - // Returns ID for identifying a resource. std::tuple GetResourceKey( Operation* op) { @@ -254,9 +243,9 @@ void RemoveVariablesInitializations( // Updates terminator op arguments of 'func' after removing arguments // specified in 'arguments_to_erase'. template -void UpdateTerminatorArguments( - T& func, const llvm::SmallVector& arguments_to_erase, - llvm::BitVector& erase_indices) { +void UpdateTerminatorArguments(T& func, + const ArrayRef arguments_to_erase, + llvm::BitVector& erase_indices) { auto terminator = func.front().getTerminator(); int num_operands = terminator->getNumOperands(); erase_indices.resize(num_operands); @@ -264,8 +253,7 @@ void UpdateTerminatorArguments( auto argument = func.getArgument(arg_index); for (auto& use : argument.getUses()) { if (llvm::isa(use.getOwner())) { - int operand_index = use.getOperandNumber(); - erase_indices.set(operand_index); + erase_indices.set(use.getOperandNumber()); } } func.getArgument(arg_index).dropAllUses(); @@ -305,9 +293,67 @@ T GetUpdatedWhileOp(T while_op, const U& argument_types, return new_while_op; } +// Erases function arguments indexed at `args_to_erase`. Also applies the +// changes to any relevant function attributes accordingly. +void EraseFuncOpArguments(func::FuncOp func_op, + const ArrayRef args_to_erase) { + BitVector args_to_erase_bit_vector(func_op.getNumArguments()); + for (const unsigned i : args_to_erase) args_to_erase_bit_vector.set(i); + + func_op.eraseArguments(args_to_erase_bit_vector); + + // Erases entries in "tf._input_shapes" attribute of `func_op` that correspond + // to the erased arguments. + if (auto input_shapes_attr = + func_op->getAttrOfType(kTfInputShapesAttr); + input_shapes_attr) { + // Construct a new array of input shapes excluding the input shapes of the + // erased arguments. + SmallVector updated_input_shapes_attr; + for (const unsigned i : args_to_erase_bit_vector.flip().set_bits()) { + updated_input_shapes_attr.emplace_back(input_shapes_attr[i]); + } + + // Replaces the attribute with the updated "#tf_type.shape" array. 
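The new `EraseFuncOpArguments` helper keeps the `tf._input_shapes` entries of the surviving arguments by flipping the erase mask and then iterating its set bits. A standalone sketch of that idiom follows; the `KeepSurviving` helper and its int payload are hypothetical, not part of the patch:

```cpp
// Sketch only: select the entries whose indices were NOT marked for erasure.
#include <vector>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/SmallVector.h"

llvm::SmallVector<int> KeepSurviving(const std::vector<int> &values,
                                     llvm::ArrayRef<unsigned> indices_to_erase) {
  llvm::BitVector erase(values.size());
  for (unsigned i : indices_to_erase) erase.set(i);

  llvm::SmallVector<int> kept;
  // flip() inverts the mask in place, so set_bits() now yields the kept indices.
  for (unsigned i : erase.flip().set_bits()) kept.push_back(values[i]);
  return kept;
}
```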
+ // Builder builder(func_op.getContext()); + func_op->setAttr( + kTfInputShapesAttr, + ArrayAttr::get(func_op.getContext(), updated_input_shapes_attr)); + } +} + +// Validates func ops. Returns `failure` if the function is invalid. +LogicalResult ValidateFuncOp(func::FuncOp func_op) { + auto input_shapes_attr = + func_op->getAttrOfType(kTfInputShapesAttr); + if (!input_shapes_attr) return success(); + + if (input_shapes_attr.size() != func_op.getNumArguments()) { + return func_op->emitError( + "Number of arguments and 'tf._input_shapes' " + "attribute size do not match. ") + << "Num args: " << func_op.getNumArguments() + << ", tf._input_shapes size: " << input_shapes_attr.size(); + } + + return success(); +} + +// Validates ModuleOp. Returns `failure` if the module op is invalid. +LogicalResult ValidateModule(ModuleOp module_op) { + for (auto func_op : module_op.getOps()) { + if (failed(ValidateFuncOp(func_op))) { + return failure(); + } + } + return success(); +} + } // namespace LogicalResult FreezeVariables(ModuleOp module, tensorflow::Session* session) { + if (failed(ValidateModule(module))) return failure(); + const tensorflow::DeviceMgr* mgr = nullptr; auto status = session->LocalDeviceManager(&mgr); if (!status.ok()) { @@ -316,7 +362,10 @@ LogicalResult FreezeVariables(ModuleOp module, tensorflow::Session* session) { return failure(); } - func::FuncOp session_init_func = GetSessionInitializerFunc(module); + SmallVector session_init_funcs = + tf_saved_model::GetInitializerFunctions(module); + func::FuncOp session_init_func = + session_init_funcs.empty() ? nullptr : session_init_funcs[0]; TF::ResourceAnalyzer analyzer(module, /*skip_session_init=*/true); llvm::SmallVector variables; @@ -324,7 +373,7 @@ LogicalResult FreezeVariables(ModuleOp module, tensorflow::Session* session) { for (auto func : module.getOps()) { if (func == session_init_func) continue; for (auto var_handle_op : func.getOps()) { - if (!analyzer.IsPotentiallyWritten(var_handle_op.resource())) { + if (!analyzer.IsPotentiallyWritten(var_handle_op.getResource())) { variables.push_back(var_handle_op); } } @@ -347,40 +396,34 @@ LogicalResult FreezeVariables(ModuleOp module, tensorflow::Session* session) { // Container to hold all update actions on ops. // Key: Operation to update. - // Value: optional list of arguments to delete from this op. + // Value: optional list of argument indices to delete from this op. // Note that we use MapVector because we want to iterate on the same order // of insertion. llvm::MapVector> arguments_to_erase; - for (auto variable_value_pair : + for (auto [var_handle_op, resource_tensor] : llvm::zip(variables, resource_tensors_or.value())) { - auto var_handle_op = std::get<0>(variable_value_pair); builder.setInsertionPointAfterValue(var_handle_op); auto elements_attr = GetTensorValueAsElementsAttr( - var_handle_op, std::get<1>(variable_value_pair), mgr, builder); + var_handle_op, resource_tensor, mgr, builder); ReplaceVarWithConstant(var_handle_op, elements_attr, &arguments_to_erase); } // All updates to different ops are captured in 'arguments_to_erase'. // Now loop on them and based on each item type update accordingly. - for (auto& items : arguments_to_erase) { - auto* user_op = items.first; - auto& args_to_erase = items.second; + for (auto& [user_op, args_to_erase] : arguments_to_erase) { if (auto func = dyn_cast(user_op)) { // To update a function we will need to: // 1) Remove the unused arguments from the function itself. 
+ // 1-2) Remove func attributes corresponding to the removed arguments. // 2) Remove any returns that are not needed from the function terminator - // op in the function. 3) Update function result to match the terminator. + // op in the function. + // 3) Update function result to match the terminator. llvm::BitVector result_indices_to_erase; UpdateTerminatorArguments(func, args_to_erase, result_indices_to_erase); - llvm::BitVector args_to_erase_bit_vector(func.getNumArguments()); - for (auto i : args_to_erase) args_to_erase_bit_vector.set(i); - func.eraseArguments(args_to_erase_bit_vector); - llvm::BitVector indices_to_erase(func.getNumResults()); - const int indices_to_erase_size = result_indices_to_erase.size(); - for (int i = 0; i < indices_to_erase_size; ++i) - if (result_indices_to_erase.test(i)) indices_to_erase.set(i); - func.eraseResults(indices_to_erase); + EraseFuncOpArguments(func, args_to_erase); + + func.eraseResults(result_indices_to_erase); } else if (auto read_var = dyn_cast(user_op)) { // Read variables was already replaced by constant op. Just remove the op. read_var->erase(); @@ -390,20 +433,20 @@ LogicalResult FreezeVariables(ModuleOp module, tensorflow::Session* session) { while_op->erase(); } else if (auto while_op = dyn_cast(user_op)) { auto new_while_op = GetUpdatedWhileOp( - while_op, while_op.cond().getArgumentTypes(), args_to_erase); - new_while_op.cond().takeBody(while_op.cond()); - new_while_op.body().takeBody(while_op.body()); + while_op, while_op.getCond().getArgumentTypes(), args_to_erase); + new_while_op.getCond().takeBody(while_op.getCond()); + new_while_op.getBody().takeBody(while_op.getBody()); llvm::BitVector erase_indices; - UpdateTerminatorArguments(new_while_op.body(), args_to_erase, + UpdateTerminatorArguments(new_while_op.getBody(), args_to_erase, erase_indices); llvm::BitVector body_bit_vector( - new_while_op.body().front().getNumArguments()); + new_while_op.getBody().front().getNumArguments()); for (auto i : args_to_erase) body_bit_vector.set(i); - new_while_op.body().front().eraseArguments(body_bit_vector); + new_while_op.getBody().front().eraseArguments(body_bit_vector); llvm::BitVector cond_bit_vector( - new_while_op.cond().front().getNumArguments()); + new_while_op.getCond().front().getNumArguments()); for (auto i : args_to_erase) cond_bit_vector.set(i); - new_while_op.cond().front().eraseArguments(cond_bit_vector); + new_while_op.getCond().front().eraseArguments(cond_bit_vector); while_op->erase(); } else { llvm::BitVector erase_indices(user_op->getNumOperands()); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h index e2ebd57b3e3..9530262a423 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h @@ -60,6 +60,10 @@ CreateLowerVariableOpsToMlProgramPass(); // Strips saved_model attributes from a module and its functions. std::unique_ptr> CreateStripSavedModuleMetadataPass(); +// Convert the session initializer to a function. 
+std::unique_ptr<OperationPass<ModuleOp>>
+CreateConvertSessionInitializerToFunctionPass();
+
 #define GEN_PASS_REGISTRATION
 #define GEN_PASS_DECL_DEDUPBOUNDINPUTBINDINGPASS
 #define GEN_PASS_DECL_FREEZEASSETSPASS
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_savedmodel_passes.td b/tensorflow/compiler/mlir/tensorflow/transforms/tf_savedmodel_passes.td
index 42a0aa84546..cbf294ff3a9 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_savedmodel_passes.td
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_savedmodel_passes.td
@@ -132,3 +132,19 @@ def StripSavedModuleMetadataPass : Pass<"tf-strip-saved-module-metadata", "Modul
   }];
   let constructor = "::mlir::tf_saved_model::CreateStripSavedModuleMetadataPass()";
 }
+
+def ConvertSessionInitializerToFunctionPass : Pass<"tf-saved-model-convert-session-initializer-to-function", "ModuleOp"> {
+  let summary = "Converts the session initializer to a function.";
+  let description = [{
+    This converts
+    "tf_saved_model.session_initializer"() {initializers = [@a, @b, @c]} : () -> ()
+    to
+    func.func @session_initializer() {
+      call @a() : () -> ()
+      call @b() : () -> ()
+      call @c() : () -> ()
+      return
+    }
+  }];
+  let constructor = "::mlir::tf_saved_model::CreateConvertSessionInitializerToFunctionPass()";
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tfg-to-tfe.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tfg-to-tfe.cc
index 528691755f8..09ea90c8947 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tfg-to-tfe.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tfg-to-tfe.cc
@@ -18,7 +18,7 @@ limitations under the License.
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "mlir/IR/Attributes.h"  // from @llvm-project
-#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
+#include "mlir/IR/IRMapping.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc
index f3dde17483e..c9dd2824ce1 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc
@@ -463,13 +463,14 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas,
     }
     // When model parallelism is used in conjunction with data parallelism
     // for resource inputs, we need to collect the per replica resource
-    // inputs from input to `tf.TPUPartitionedInput` ops.
-    if (auto pi = llvm::dyn_cast_or_null<TF::TPUPartitionedInputOp>(def)) {
+    // inputs from input to `tf.TPUPartitionedInputV2` ops.
+ if (auto pi = + llvm::dyn_cast_or_null(def)) { if (pi->getNumOperands() != num_cores_per_replica) status = pi.emitOpError() << "requires " << num_cores_per_replica << " operands but found " << pi->getNumOperands(); - for (auto operand : pi.inputs()) { + for (auto operand : pi.getInputs()) { if (auto ri = llvm::dyn_cast_or_null( operand.getDefiningOp())) { if (!seen_ops.contains(ri)) { @@ -495,7 +496,7 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas, llvm::SmallVector packed_ops; for (auto& pos_and_input : llvm::enumerate(replicated_input_ops)) { auto input = pos_and_input.value(); - bool is_packed = input.is_packed(); + bool is_packed = input.getIsPacked(); const int num_operands = input->getNumOperands(); int num_inputs = is_packed ? 1 : num_replicas; if (num_operands != num_inputs) @@ -521,7 +522,7 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas, for (const auto& pos_and_input : llvm::enumerate(ordered_tpu_replicate_inputs)) { auto tpu_replicated_input = pos_and_input.value(); - if (tpu_replicated_input.is_mirrored_variable()) { + if (tpu_replicated_input.getIsMirroredVariable()) { mirrored_variable_indices.push_back(pos_and_input.index()); } } @@ -563,7 +564,7 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas, } } - // Collect all `tf.TPUPartitionedInput` ops to be moved inside the + // Collect all `tf.TPUPartitionedInputV2` ops to be moved inside the // `tf_device.replicate` later. llvm::SmallSet partitioned_inputs; for (auto input_and_block_arg : @@ -573,9 +574,9 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas, Value block_arg = std::get<1>(input_and_block_arg); mlir::replaceAllUsesInRegionWith(input->getResult(0), block_arg, cluster.getBody()); - // Update replicated input use in tf.TPUPartitionedInput op. + // Update replicated input use in tf.TPUPartitionedInputV2 op. for (auto& use : input->getUses()) { - auto pi = llvm::dyn_cast(use.getOwner()); + auto pi = llvm::dyn_cast(use.getOwner()); if (pi) { pi.setOperand(use.getOperandNumber(), block_arg); partitioned_inputs.insert(pi.getOperation()); @@ -584,7 +585,7 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas, } // Create terminator for replicate op and move `tf_device.cluster` and - // `tf.TPUPartitionedInput`(s) into replicate body. + // `tf.TPUPartitionedInputV2`(s) into replicate body. builder.setInsertionPointToEnd(&replicate_op.GetBody()); auto return_op = builder.create(replicate_op.getLoc(), cluster.getResults()); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_colocate_composite_resource_ops.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_colocate_composite_resource_ops.cc index 9bcfad6038f..098d34e05cf 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_colocate_composite_resource_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_colocate_composite_resource_ops.cc @@ -82,7 +82,8 @@ llvm::SmallVector GetResourceOpsUsingCompositeArgsInReplicate( // Account for pass-through identity ops. 
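A minimal sketch, not part of the patch, of the operand-count rule the tpu_cluster_formation hunk above enforces once packing is taken into account; the helper name is hypothetical, the accessors are the ones used in the hunk:

    // A packed tf.TPUReplicatedInput carries one operand shared by every
    // replica; an unpacked one carries one operand per replica.
    LogicalResult CheckReplicatedInputArity(TF::TPUReplicatedInputOp input,
                                            int num_replicas) {
      const bool is_packed = input.getIsPacked();
      const int expected = is_packed ? 1 : num_replicas;
      if (input->getNumOperands() != expected)
        return input.emitOpError()
               << "requires " << expected << " operands but found "
               << input->getNumOperands();
      return success();
    }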
if (auto pass_through_identity = llvm::dyn_cast(resource_user)) { - for (auto identity_user : pass_through_identity.output().getUsers()) { + for (auto identity_user : + pass_through_identity.getOutput().getUsers()) { new_resource_users.emplace_back(identity_user); } } @@ -97,8 +98,7 @@ void ColocateCompositeResourceOpsInReplicate( tf_device::ReplicateOp replicate_op, OpBuilder* builder) { auto devices = replicate_op.getDevices(); if (!devices) return; - if (!devices.getValue().get(tensorflow::GetDeviceAliasForLogicalCore(0))) - return; + if (!devices.value().get(tensorflow::GetDeviceAliasForLogicalCore(0))) return; const auto composite_resource_users = GetResourceOpsUsingCompositeArgsInReplicate(replicate_op); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc index 055ae25ee86..c8ad200e328 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc @@ -78,7 +78,7 @@ bool IsSupportedInputOp( TF::IteratorGetNextOp iterator_op = llvm::dyn_cast(op); if (!iterator_op) return false; - Value resource_iterator = iterator_op.iterator(); + Value resource_iterator = iterator_op.getIterator(); if (resource_alias_analysis.IsUnknownResource(resource_iterator)) return false; @@ -127,7 +127,7 @@ TF::TPUGetLayoutOp BuildGetLayout(const int64_t execute_arg_index, OpBuilder* builder) { return builder->create( compile_launch.getLoc(), - llvm::ArrayRef{RankedTensorType::get({ShapedType::kDynamicSize}, + llvm::ArrayRef{RankedTensorType::get({ShapedType::kDynamic}, builder->getIntegerType(64))}, llvm::ArrayRef{compilation_key}, llvm::ArrayRef{ @@ -143,7 +143,7 @@ TF::TPUCopyWithLayoutOp BuildCopyWithLayout(tf_device::LaunchOp execute_launch, Value input, OpBuilder* builder) { return builder->create( execute_launch.getLoc(), llvm::ArrayRef{input.getType()}, - llvm::ArrayRef{input, get_layout.layout()}); + llvm::ArrayRef{input, get_layout.getLayout()}); } // Performs transformation for a non-replicated input. 
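The tpu_dynamic_layout_pass hunk above swaps ShapedType::kDynamicSize for ShapedType::kDynamic. A small sketch, with an assumed helper name, of how the rank-1 dynamically sized i64 tensor type used by TPUGetLayout is built after that rename:

    // Builds tensor<?xi64>, the result type of tf.TPUGetLayoutOp.
    RankedTensorType GetLayoutResultType(Builder& builder) {
      return RankedTensorType::get({ShapedType::kDynamic},
                                   builder.getIntegerType(64));
    }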
@@ -151,7 +151,7 @@ void HandleInput(Value input, const int64_t execute_arg_index, TF::TPUExecuteOp execute, tf_device::LaunchOp execute_launch, tf_device::LaunchOp compile_launch) { OpBuilder builder = CreateBuilderAfterOp(compile_launch); - auto get_layout = BuildGetLayout(execute_arg_index, execute.key(), + auto get_layout = BuildGetLayout(execute_arg_index, execute.getKey(), compile_launch, &builder); builder.setInsertionPoint(execute_launch); auto copy_with_layout = BuildCopyWithLayout(execute_launch, compile_launch, @@ -187,7 +187,7 @@ bool HandleReplicatedInputs( entry.value().get(), &builder); auto device_list = replicate.getDevices() - .getValue() + .value() .get(execute_launch.getDevice()) .cast(); copy_with_layout->setAttr(kDeviceAttr, @@ -207,7 +207,7 @@ void HandleCompileAndExecutes( auto compile = llvm::cast(compile_launch.GetBody().front()); tensorflow::tpu::TPUCompileMetadataProto metadata; - metadata.ParseFromString(compile.metadata().str()); + metadata.ParseFromString(compile.getMetadata().str()); llvm::SmallVector, 4> input_mappings = tensorflow::GetMetadataArgumentMapping(metadata); @@ -222,14 +222,14 @@ void HandleCompileAndExecutes( llvm::cast(execute_launch.GetBody().front()); const auto& input_mapping = std::get<1>(execute_and_input_mapping); - for (auto& input_and_idx : llvm::enumerate(execute.args())) { + for (auto& input_and_idx : llvm::enumerate(execute.getArgs())) { Value input = input_and_idx.value(); const int64_t execute_arg_index = input_and_idx.index(); if (auto block_arg = input.dyn_cast()) { // For a block argument, consider transforms only when it is a // replicated input (defining ops will be outside the replicate node). if (maybe_replicate != block_arg.getParentRegion()->getParentOp() || - !HandleReplicatedInputs(execute_arg_index, execute.key(), + !HandleReplicatedInputs(execute_arg_index, execute.getKey(), execute_launch, compile_launch, block_arg, maybe_replicate, resource_alias_analysis)) { continue; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc index a253687b0d9..110b8e34281 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc @@ -32,7 +32,7 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project @@ -96,7 +96,7 @@ func::FuncOp BuildFunction(llvm::ArrayRef ops, Block* outlined_func_block = outlined_func.addEntryBlock(); // Clone the operations and remap the inputs to use the function arguments. - BlockAndValueMapping mapping; + IRMapping mapping; mapping.map(inputs, outlined_func.getArguments()); builder->setInsertionPoint(outlined_func_block, outlined_func_block->begin()); for (Operation* op : ops) { @@ -214,19 +214,19 @@ TF::IfRegionOp CloneEmptyIfWithPredicate(TF::IfRegionOp if_region, OpBuilder& builder) { // Mark op as stateful due to side-effecting communication ops added later. 
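The HandleCompileAndExecutes hunk above reads the compile metadata through the renamed getMetadata() accessor. A minimal sketch of that step; the helper is hypothetical and the op class is assumed to be TF::_TPUCompileMlirOp as elsewhere in these passes:

    // Recovers the TPUCompileMetadataProto serialized into the compile op.
    tensorflow::tpu::TPUCompileMetadataProto GetCompileMetadata(
        TF::_TPUCompileMlirOp compile) {
      tensorflow::tpu::TPUCompileMetadataProto metadata;
      // ParseFromString returns false on malformed input; the surrounding
      // pass is assumed to tolerate that case.
      metadata.ParseFromString(compile.getMetadata().str());
      return metadata;
    }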
auto host_side_if = builder.create( - if_region.getLoc(), llvm::SmallVector{}, if_region.cond(), - /*is_stateless=*/false, if_region._then_func_nameAttr(), - if_region._else_func_nameAttr()); + if_region.getLoc(), llvm::SmallVector{}, if_region.getCond(), + /*is_stateless=*/false, if_region.get_thenFuncNameAttr(), + if_region.get_elseFuncNameAttr()); // Create empty then branch region. - auto& then_branch = host_side_if.then_branch(); + auto& then_branch = host_side_if.getThenBranch(); then_branch.push_back(new Block); builder.setInsertionPointToEnd(&then_branch.front()); builder.create(if_region.getLoc(), /*operands=*/ArrayRef{}); // Create empty else branch region. - auto& else_branch = host_side_if.else_branch(); + auto& else_branch = host_side_if.getElseBranch(); else_branch.push_back(new Block); builder.setInsertionPointToEnd(&else_branch.front()); builder.create(if_region.getLoc(), @@ -243,7 +243,7 @@ TF::WhileRegionOp CloneEmptyWhile(uint64_t parallel_iterations, Location loc, parallel_iterations, /*is_stateless=*/false, /*shape_invariant=*/false); // Create empty else branch region. - auto& body = host_side_while.body(); + auto& body = host_side_while.getBody(); body.push_back(new Block); builder.setInsertionPointToEnd(&body.front()); builder.create(loc, /*operands=*/ArrayRef{}); @@ -321,6 +321,14 @@ bool HasDynamicExternalValues(Operation* op) { .wasInterrupted(); } +// Checks if `type` is allowed for XLA. String and resources are not XLA types. +// There are other TF types that are not XLA types which will be removed by +// successive passes in TF/XLA bridge phase 2. +bool TypeValidForXLA(const Type& type) { + const Type elem = getElementTypeOrSelf(type); + return !elem.isa() && !elem.isa(); +} + // Returns operands of `cluster_ops` that need to be // communicated from device->host. This is for the case when all operands have a // static shape. @@ -330,15 +338,20 @@ llvm::SmallSetVector GetStaticExternalOperands( llvm::SmallSetVector external_values; for (Operation* op : cluster_ops) { op->walk([&](Operation* walked_op) { - if (llvm::isa( + if (llvm::isa( walked_op)) return WalkResult::advance(); for (Value v : walked_op->getOperands()) { + if (!TypeValidForXLA(v.getType())) continue; if (auto* defining_op = v.getDefiningOp()) { if (!op->isAncestor(defining_op) && tpu_cluster->isAncestor(defining_op) && !HasOutsideCompilationAncestor(defining_op) && - !llvm::isa(defining_op)) { + // Ignore operands that have already been received by a previously + // created cluster. + !llvm::isa( + defining_op)) { external_values.insert(v); } continue; @@ -361,6 +374,7 @@ llvm::SmallSetVector GetAllExternalOperands( for (Operation* op : cluster_ops) { op->walk([&](Operation* walked_op) { for (Value v : walked_op->getOperands()) { + if (!TypeValidForXLA(v.getType())) continue; Operation* defining_op = v.getDefiningOp(); if (!defining_op || !cluster_ops.count(defining_op)) { external_values.insert(v); @@ -406,7 +420,8 @@ void GetExternalOutputs(const llvm::SmallSetVector& cluster_ops, HasDynamicOutputs(user)) { if (!user_set.insert(user).second) continue; for (Value v : user->getOperands()) { - if (v.getDefiningOp() == op && !isa(user)) + if (TypeValidForXLA(v.getType()) && v.getDefiningOp() == op && + !isa(user)) external_outputs.insert(v); if (v.getDefiningOp() == op && isa(user)) tmp_host_outputs.push_back(v); @@ -447,11 +462,13 @@ void MarkOutsideCompiled(Operation* op) { } // Returns whether an outside compilation cluster should be closed. True when: -// 1. 
There is a dynamically shaped output consumed by a non-outside compiled +// 1. There is no non-XLA output. +// 2. There is a dynamically shaped output consumed by a non-outside compiled // op. -// 2. There is no dynamically shaped output. +// 3. There is no dynamically shaped output. bool ShouldCloseCluster(llvm::ArrayRef outputs) { bool has_dynamic_output = false; + bool has_nonxla_output = false; for (Value v : outputs) { if (TF::CanBeRefined(v.getType())) { has_dynamic_output = true; @@ -461,8 +478,12 @@ bool ShouldCloseCluster(llvm::ArrayRef outputs) { return true; } } + if (!TypeValidForXLA(v.getType())) + for (const Operation* user : v.getUsers()) + if (!isa(user)) has_nonxla_output = true; } - return !has_dynamic_output; + + return !has_nonxla_output && !has_dynamic_output; } // Replaces `external_operands` with the results from `recv_at_host`. @@ -715,14 +736,14 @@ LogicalResult DecomposeControlFlow(tf_device::ClusterOp tpu_cluster, if (!HasOutsideCompilationNested(op)) return WalkResult::advance(); OpBuilder builder(if_op); auto host_if = CloneEmptyIfWithPredicate(if_op, builder); - if (failed(MoveOpsToHost(tpu_cluster, &if_op.then_branch().front(), - host_if.then_branch().front().getTerminator(), + if (failed(MoveOpsToHost(tpu_cluster, &if_op.getThenBranch().front(), + host_if.getThenBranch().front().getTerminator(), compilation_key, device_ordinal, default_device_ordignal, communication_key_index))) return WalkResult::interrupt(); - if (failed(MoveOpsToHost(tpu_cluster, &if_op.else_branch().front(), - host_if.else_branch().front().getTerminator(), + if (failed(MoveOpsToHost(tpu_cluster, &if_op.getElseBranch().front(), + host_if.getElseBranch().front().getTerminator(), compilation_key, device_ordinal, default_device_ordignal, communication_key_index))) @@ -734,16 +755,17 @@ LogicalResult DecomposeControlFlow(tf_device::ClusterOp tpu_cluster, if (auto while_op = llvm::dyn_cast(op)) { if (!HasOutsideCompilationNested(op)) return WalkResult::advance(); OpBuilder builder(while_op); - auto host_while = CloneEmptyWhile(while_op.parallel_iterations(), + auto host_while = CloneEmptyWhile(while_op.getParallelIterations(), while_op.getLoc(), builder); const auto condition_send_recv_key = llvm::formatv("while_condition_channel_{0}", communication_key_index++) .str(); - auto& cond = host_while.cond(); + auto& cond = host_while.getCond(); cond.push_back(new Block); - auto condition = while_op.cond().front().getTerminator()->getOperand(0); - builder.setInsertionPoint(while_op.cond().front().getTerminator()); + auto condition = + while_op.getCond().front().getTerminator()->getOperand(0); + builder.setInsertionPoint(while_op.getCond().front().getTerminator()); builder.create(while_op.getLoc(), condition, condition_send_recv_key); builder.setInsertionPointToEnd(&cond.front()); @@ -754,13 +776,13 @@ LogicalResult DecomposeControlFlow(tf_device::ClusterOp tpu_cluster, builder.create(while_op.getLoc(), recv_condition_at_host->getResults()); - if (failed(MoveOpsToHost(tpu_cluster, &while_op.cond().front(), + if (failed(MoveOpsToHost(tpu_cluster, &while_op.getCond().front(), recv_condition_at_host, compilation_key, device_ordinal, default_device_ordignal, communication_key_index))) return WalkResult::interrupt(); - if (failed(MoveOpsToHost(tpu_cluster, &while_op.body().front(), - host_while.body().front().getTerminator(), + if (failed(MoveOpsToHost(tpu_cluster, &while_op.getBody().front(), + host_while.getBody().front().getTerminator(), compilation_key, device_ordinal, default_device_ordignal, 
communication_key_index))) @@ -1047,12 +1069,12 @@ LogicalResult CreateParallelExecuteForOutsideCompilation( builder.setInsertionPoint(tmp_host_launch_op.GetBody().getTerminator()); auto compilation_key_op = CreateCompilationKeyPlaceholder(tpu_cluster.getLoc(), builder); - Value compilation_key = compilation_key_op.program(); + Value compilation_key = compilation_key_op.getProgram(); auto device_ordinal_op = builder.create( tpu_cluster.getLoc(), RankedTensorType::get({}, builder.getI64Type())); Value device_ordinal = nullptr; if (tpu_cluster->getParentOfType()) { - device_ordinal = device_ordinal_op.device_ordinal(); + device_ordinal = device_ordinal_op.getDeviceOrdinal(); } int default_device_ordinal = 0; if (failed(GetDefaultDeviceOrdinal(tpu_cluster, default_device_ordinal))) { @@ -1100,19 +1122,11 @@ LogicalResult CreateParallelExecuteForOutsideCompilation( return success(); } -// Checks if `type` is allowed for data on TPUs. String and resources cannot be -// assigned to TPUs. There are other TF types that are not allowed on TPUs, but -// these will be removed by successive passes in TF/XLA bridge phase 2. -bool TypeValidForTPU(Type type) { - Type elem = getElementTypeOrSelf(type); - return !elem.isa() && !elem.isa(); -} - // Check that cluster results are valid. An result is invalid when it does not // have a valid XLA type. LogicalResult CheckClusterResults(tf_device::ClusterOp cluster) { for (OpResult result : cluster.getResults()) { - if (!TypeValidForTPU(result.getType())) { + if (!TypeValidForXLA(result.getType())) { cluster.emitError() << "The TPUExtractHeadTailOutsideCompilation pass produced a TPU " "cluster with a result with a non-XLA type: " @@ -1123,6 +1137,35 @@ LogicalResult CheckClusterResults(tf_device::ClusterOp cluster) { return success(); } +// Check that op marked for outside compilation has an ancestor also marked for +// outside compilation. +LogicalResult CheckAncestorNotOutsideComp(Operation* op) { + if (!op->getAttrOfType(kXlaOutsideCompilationAttr)) + return success(); + Operation* iter_op = op; + while (auto* parent_op = iter_op->getParentOp()) { + if (parent_op->getAttrOfType(kXlaOutsideCompilationAttr)) { + op->emitOpError() + << "An op marked for outside compilation (having attribute " + << kXlaOutsideCompilationAttr + << ") has an ancestor marked for outside compilation."; + return failure(); + } + iter_op = parent_op; + } + return success(); +} + +// Check the validity of the module, pre-pass. +LogicalResult CheckPreconditions(ModuleOp module) { + auto walk_result = module.walk([&](Operation* op) { + if (failed(CheckAncestorNotOutsideComp(op))) return WalkResult::interrupt(); + return WalkResult::advance(); + }); + if (walk_result.wasInterrupted()) return failure(); + return success(); +} + // Check the validity of the module, post-pass. LogicalResult CheckPostconditions(ModuleOp module) { auto walk_result = module.walk([&](tf_device::ClusterOp cluster) { @@ -1136,6 +1179,8 @@ LogicalResult CheckPostconditions(ModuleOp module) { void TPUExtractOutsideCompilation::runOnOperation() { // Get runtime devices information from the closest parent module. 
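A minimal sketch, not part of the patch, of the TypeValidForXLA helper introduced earlier in this file, with the element types spelled out; the assumption is that only string and resource element types are rejected at this stage:

    // Strings and resources are not XLA types; other non-XLA TF types are
    // expected to be removed by later bridge passes.
    bool IsXlaCompatible(Type type) {
      const Type elem = getElementTypeOrSelf(type);
      return !elem.isa<TF::ResourceType>() && !elem.isa<TF::StringType>();
    }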
auto module = getOperation(); + if (failed(CheckPreconditions(module))) signalPassFailure(); + mlir::TF::RuntimeDevices devices; if (failed(tensorflow::GetDevicesFromOp(module, &devices))) return signalPassFailure(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_merge_variables_with_execute.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_merge_variables_with_execute.cc index 9fbd0ceea4a..24ee61facfb 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_merge_variables_with_execute.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_merge_variables_with_execute.cc @@ -47,7 +47,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" #define DEBUG_TYPE "tf-tpu-merge-variables-with-execute" @@ -144,6 +144,19 @@ bool AddAccessedResourceIds( return false; } +/* Resources may be merged with an execute op when they are on its device or a + * `COMPOSITE`. Note that a `COMPOSITE` represents a set of devices, they + * are typically associated with packed variables. Presently, we assume this + * set spans all the devices. So, a variable on a `COMPOSITE` will have a local + * instance on the execute op's device. + */ +bool IsResourceMergeable(Attribute& resource_attr, Attribute& device_attr) { + return resource_attr && + ((resource_attr == device_attr) || + (resource_attr.cast().getValue().find( + "COMPOSITE") != llvm::StringRef::npos)); +} + // Finds the variable access info for a TPUExecute op. // - `check_device` specifies whether it checks the device assignment of the // variables to match the TPUExecute op. This is optional in some context, @@ -181,13 +194,13 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo( read_op->getParentRegion() != execute_parent->getParentRegion()) continue; - auto resource = read_op.resource(); + auto resource = read_op.getResource(); if (check_device) { // TODO(lyandy): Wrap resource ops in tf_device.launch. if (auto* resource_op = resource.getDefiningOp()) { auto resource_attr = resource_op->getAttr(kDeviceAttr); // Check device matching for the node defining the resource. - if (!resource_attr || resource_attr != device_attr) continue; + if (!IsResourceMergeable(resource_attr, device_attr)) continue; } else { auto resource_arg = resource.dyn_cast(); assert(resource_arg); @@ -195,7 +208,7 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo( // Check device matching for the argument defining the resource. 
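A minimal sketch, not part of the patch, of the rule IsResourceMergeable above encodes: the resource's device must match the execute op's device, or the resource must live on a COMPOSITE device, which is assumed to have a local instance on every device:

    // `resource_device` may be null when the defining op carries no device
    // attribute, in which case the variable is not merged.
    bool DeviceAllowsMerge(StringAttr resource_device, Attribute execute_device) {
      if (!resource_device) return false;
      return resource_device == execute_device ||
             resource_device.getValue().contains("COMPOSITE");
    }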
auto resource_attr = func.getArgAttrOfType( resource_arg.getArgNumber(), kFuncDeviceAttr); - if (!resource_attr || resource_attr != device_attr) continue; + if (!IsResourceMergeable(resource_attr, device_attr)) continue; } } @@ -230,12 +243,13 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo( first_read->getIterator(), execute_parent->getIterator()))) { if (auto read_op = llvm::dyn_cast(&op)) { VLOG(2) << "Processing read op " << debugString(op); - auto info_it = var_access_info.per_resource_info.find(read_op.resource()); + auto info_it = + var_access_info.per_resource_info.find(read_op.getResource()); bool is_merge_candidate = info_it != var_access_info.per_resource_info.end(); if (is_merge_candidate && - !IsResourceSafeForMerge(read_op.resource(), resource_analysis_info, + !IsResourceSafeForMerge(read_op.getResource(), resource_analysis_info, var_access_info, resource_ids, previous_unknown_resource_access)) { VLOG(2) << " removing op from merge candidates"; @@ -272,7 +286,7 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo( auto assign_op = llvm::dyn_cast(*result.user_begin()); if (!assign_op) continue; - auto resource = assign_op.resource(); + auto resource = assign_op.getResource(); auto it = var_access_info.per_resource_info.find(resource); if (it == var_access_info.per_resource_info.end()) continue; auto& info = it->getSecond(); @@ -307,14 +321,15 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo( bool is_merge_candidate = true; if (all_assigns.count(assign_op) == 0) is_merge_candidate = false; auto info_it = - var_access_info.per_resource_info.find(assign_op.resource()); + var_access_info.per_resource_info.find(assign_op.getResource()); if (info_it == var_access_info.per_resource_info.end()) is_merge_candidate = false; if (is_merge_candidate && - !IsResourceSafeForMerge( - assign_op.resource(), resource_analysis_info, var_access_info, - resource_ids, previous_unknown_resource_access)) { + !IsResourceSafeForMerge(assign_op.getResource(), + resource_analysis_info, var_access_info, + resource_ids, + previous_unknown_resource_access)) { VLOG(2) << " removing op from merge candidates"; output_merged[info_it->second.execute_output_index] = false; info_it->second.execute_output_index = -1; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_parallel_execute_sink_resource_write.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_parallel_execute_sink_resource_write.cc index 96c994495b1..9ef8cda3d6f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_parallel_execute_sink_resource_write.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_parallel_execute_sink_resource_write.cc @@ -58,9 +58,9 @@ TF::AssignVariableOp GetSingleUseResourceWrite( auto assign_var = dyn_cast(use.getOwner()); if (!assign_var) return nullptr; - if (use.get() != assign_var.value()) return nullptr; + if (use.get() != assign_var.getValue()) return nullptr; - auto* resource_handle_op = assign_var.resource().getDefiningOp(); + auto* resource_handle_op = assign_var.getResource().getDefiningOp(); if (resource_handle_op == parallel_execute) return nullptr; if (resource_handle_op && @@ -104,7 +104,8 @@ void SinkResourceWritesIntoParallelExecute( // resource variable to be the non forwarded value from within the // parallel_execute region. 
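Most of the churn in these hunks is the switch to MLIR's prefixed, auto-generated accessors. A small sketch, not part of the patch, of the mapping for tf.AssignVariableOp, whose old and new spellings both appear in this diff:

    void UseAssignVariableAccessors(TF::AssignVariableOp assign) {
      Value resource = assign.getResource();  // previously assign.resource()
      Value value = assign.getValue();        // previously assign.value()
      (void)resource;
      (void)value;
    }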
assign_var.getOperation()->moveBefore(terminator); - assign_var.valueMutable().assign(terminator->getOperand(result.index())); + assign_var.getValueMutable().assign( + terminator->getOperand(result.index())); results_to_remove.push_back(result.index()); } @@ -136,7 +137,7 @@ void SinkResourceWritesIntoParallelExecute( for (auto region : llvm::zip(new_parallel_execute.getRegions(), parallel_execute.getRegions())) - std::get<0>(region)->takeBody(*std::get<1>(region)); + std::get<0>(region).takeBody(std::get<1>(region)); for (auto result : llvm::zip(results_to_remap, new_parallel_execute.getResults())) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_partitioned_op_conversion.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_partitioned_op_conversion.cc new file mode 100644 index 00000000000..a2232f9f33b --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_partitioned_op_conversion.cc @@ -0,0 +1,147 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" + +namespace mlir { +namespace TFTPU { +namespace { + +#define GEN_PASS_DEF_TPUPARTITIONEDOPCONVERSIONPASS +#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc" + +struct TPUPartitionedOpConversionPass + : public impl::TPUPartitionedOpConversionPassBase< + TPUPartitionedOpConversionPass> { + void runOnOperation() override; +}; + +template +LogicalResult ReplacePartitionedOp(IntegerAttr num_cores_per_replica, T op) { + constexpr bool is_input = + std::is_same_v, TF::TPUPartitionedInputOp>; + static_assert( + is_input || std::is_same_v, TF::TPUPartitionedOutputOp>, + "operator should either be an input or output"); + + OpBuilder builder(op); + int partition_dim = op.getPartitionDim(); + bool is_replicated = partition_dim == -1; + if (!(is_replicated || num_cores_per_replica)) return failure(); + + Type first_operand_type; + if constexpr (is_input) { + first_operand_type = op.getOperand(0).getType(); + } else { + first_operand_type = op.getOperand().getType(); + } + + auto element_type = getElementTypeOrSelf(first_operand_type); + if (element_type.isa()) { + first_operand_type = + element_type.cast().getSubtypes().front(); + } + + auto tensor_type = first_operand_type.dyn_cast_or_null(); + if (!(tensor_type && tensor_type.hasRank())) { + return op->emitError() + << "cannot convert op with unranked or non-tensor input type " + << tensor_type << "."; + } + + 
int rank = tensor_type.getRank(); + if (rank <= partition_dim) { + return op->emitError() << "cannot partition " << first_operand_type + << " (rank = " << rank << ") along dimension " + << partition_dim << "."; + } + + llvm::SmallVector partition_dims(is_replicated ? 0 : rank, 1); + if (!is_replicated) { + partition_dims[partition_dim] = num_cores_per_replica.getInt(); + } + + if constexpr (is_input) { + auto pi = builder.create( + op.getLoc(), op.getType(), op.getOperands(), + builder.getI64ArrayAttr(partition_dims), builder.getBoolAttr(false), + op.get_XlaShardingAttr()); + op->replaceAllUsesWith(pi); + } else { + auto po = builder.create( + op.getLoc(), op.getResultTypes(), op.getOperand(), + builder.getI64ArrayAttr(partition_dims), op.get_XlaShardingAttr()); + op->replaceAllUsesWith(po); + } + + return success(); +} + +void TPUPartitionedOpConversionPass::runOnOperation() { + llvm::SmallVector metadata; + getOperation()->walk( + [&metadata](TF::TPUReplicateMetadataOp op) { metadata.push_back(op); }); + + IntegerAttr num_cores_per_replica; + if (metadata.size() == 1) { + num_cores_per_replica = metadata.front().getNumCoresPerReplicaAttr(); + } + + auto result = getOperation()->walk([&num_cores_per_replica](Operation* op) { + std::optional status; + if (auto partitioned_input = + llvm::dyn_cast_or_null(op)) { + status = ReplacePartitionedOp(num_cores_per_replica, partitioned_input); + } else if (auto partitioned_output = + llvm::dyn_cast_or_null(op)) { + status = ReplacePartitionedOp(num_cores_per_replica, partitioned_output); + } + + if (status.has_value()) { + if (failed(*status) || !op->use_empty()) return WalkResult::interrupt(); + + op->erase(); + } + + return WalkResult::advance(); + }); + + if (result.wasInterrupted()) { + signalPassFailure(); + return; + } +} + +} // namespace + +std::unique_ptr> +CreateTPUPartitionedOpConversionPass() { + return std::make_unique(); +} + +} // namespace TFTPU +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_reorder_replicate_and_partitioned_inputs.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_reorder_replicate_and_partitioned_inputs.cc index 2ed5bdae886..be4f986bf1f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_reorder_replicate_and_partitioned_inputs.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_reorder_replicate_and_partitioned_inputs.cc @@ -11,6 +11,7 @@ limitations under the License. 
==============================================================================*/ #include +#include #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" @@ -34,80 +35,113 @@ struct TPUReorderReplicateAndPartitionedInputsPass LogicalResult ReorderReplicateAndPartitionedInputs( TF::TPUReplicatedInputOp replicated_input) { - if (!llvm::all_of(replicated_input.inputs(), [](Value input) { - return llvm::isa_and_nonnull( + if (!llvm::all_of(replicated_input.getInputs(), [](Value input) { + return llvm::isa_and_nonnull( input.getDefiningOp()); })) return replicated_input.emitOpError() - << "expects all inputs from 'tf.TPUPartitionedInput' ops"; + << "expects all inputs from 'tf.TPUPartitionedInputV2' ops"; - auto first_partitioned_input = llvm::cast( + const auto metadata_iter = + replicated_input->getBlock()->getOps(); + TF::TPUReplicateMetadataOp metadata; + if (!metadata_iter.empty()) metadata = *(metadata_iter.begin()); + + auto first_partitioned_input = llvm::cast( replicated_input.getOperand(0).getDefiningOp()); - llvm::Optional<::llvm::StringRef> xla_sharding = - first_partitioned_input._XlaSharding(); - int64_t partition_dim = first_partitioned_input.partition_dim(); + auto partition_dims = first_partitioned_input.getPartitionDims(); + const std::optional<::llvm::StringRef> xla_sharding = + first_partitioned_input.get_XlaSharding(); + size_t num_cores_per_replica = first_partitioned_input.getNumOperands(); + if (metadata) { + num_cores_per_replica = metadata.getNumCoresPerReplica(); + } else if (first_partitioned_input.getIsPacked()) { + return first_partitioned_input->emitOpError() + << "num cores per replica unavailable, metadata missing?"; + } - for (auto operand : replicated_input.inputs().drop_front()) { + const bool packed_input = first_partitioned_input.getIsPacked(); + const size_t num_operands_expected = packed_input ? 1 : num_cores_per_replica; + if (metadata && + num_operands_expected != first_partitioned_input.getNumOperands()) { + return first_partitioned_input->emitOpError() + << "expects " << num_operands_expected << " operands but found " + << first_partitioned_input.getNumOperands(); + } + + for (const auto& operand : replicated_input.getInputs().drop_front()) { auto partitioned_input = - llvm::cast(operand.getDefiningOp()); - llvm::Optional<::llvm::StringRef> op_xla_sharding = - partitioned_input._XlaSharding(); - int64_t op_partition_dim = partitioned_input.partition_dim(); - // Abort if TPUPartitionedInput(s) do not have the same attributes. - if (partition_dim != op_partition_dim) + llvm::cast(operand.getDefiningOp()); + const std::optional<::llvm::StringRef> op_xla_sharding = + partitioned_input.get_XlaSharding(); + const auto op_partition_dims = partitioned_input.getPartitionDims(); + // Abort if TPUPartitionedInputV2(s) do not have the same attributes. 
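A minimal sketch, with a hypothetical helper name, of the arity rule the reorder pass applies below: a packed TPUPartitionedInputV2 carries a single operand, otherwise one per core, and the core count comes from TPUReplicateMetadata when it is available:

    FailureOr<size_t> ExpectedPartitionedOperandCount(
        TF::TPUPartitionedInputV2Op op, TF::TPUReplicateMetadataOp metadata) {
      const bool is_packed = op.getIsPacked();
      size_t num_cores_per_replica = op.getNumOperands();
      if (metadata) {
        num_cores_per_replica = metadata.getNumCoresPerReplica();
      } else if (is_packed) {
        // Without metadata a packed input cannot reveal the core count.
        op.emitOpError() << "num cores per replica unavailable, metadata missing?";
        return failure();
      }
      return is_packed ? size_t{1} : num_cores_per_replica;
    }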
+ if (!llvm::equal(partition_dims, op_partition_dims)) { + return partitioned_input->emitOpError() + << "expects partition_dims = " << partition_dims << " but found " + << op_partition_dims; + } else if (partitioned_input.getIsPacked() != + first_partitioned_input.getIsPacked()) { return partitioned_input->emitOpError() - << "expects partition_dim = " << partition_dim << " but found " - << op_partition_dim; - if (partitioned_input.getNumOperands() != num_cores_per_replica) + << "packing should match across ops"; + } else if (partitioned_input.getNumOperands() != num_operands_expected) { return partitioned_input->emitOpError() - << "expects " << num_cores_per_replica << " operands but found " + << "expects " << num_operands_expected << " operands but found " << partitioned_input.getNumOperands(); - if (xla_sharding != op_xla_sharding) + } else if (xla_sharding != op_xla_sharding) { return replicated_input.emitOpError() - << "expects all inputs from 'tf.TPUPartitionedInput' ops to have " - "identical XLA sharding"; + << "expects all inputs from 'tf.TPUPartitionedInputV2' ops to " + "have identical XLA sharding"; + } } // 2D Matrix to store per core per replica operands. The matrix dimensions are // num_cores_per_replica x num_replicas. i-th row holds the operands for i-th // core. j-th column holds the operands for j-th replica. llvm::SmallVector, 4> - operands_per_replica_per_core; - operands_per_replica_per_core.resize(num_cores_per_replica); + operands_per_replica_per_core(num_cores_per_replica); // Collect all operands in the 2D matrix. - for (auto operand : replicated_input.inputs()) { - auto pi = llvm::cast(operand.getDefiningOp()); - for (auto& pi_operand : pi->getOpOperands()) { - unsigned core_id = pi_operand.getOperandNumber(); - operands_per_replica_per_core[core_id].push_back(pi_operand.get()); + for (auto operand : replicated_input.getInputs()) { + Operation* pi = operand.getDefiningOp(); + for (unsigned core_id = 0; core_id < num_cores_per_replica; ++core_id) { + const auto pi_operand = + packed_input ? pi->getOperand(0) : pi->getOperand(core_id); + operands_per_replica_per_core[core_id].push_back(pi_operand); } } // Create new `tf.TPUReplicatedInput` ops feeding into one - // `tf.TPUPartitionedInput` op. + // `tf.TPUPartitionedInputV2` op. 
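The block above regroups operands from per-replica partitioned inputs into per-core replicated inputs. A compact sketch, names hypothetical, of that regrouping as a plain transposition of the operand matrix:

    // Rows of the result hold the operands for one core across all replicas.
    llvm::SmallVector<llvm::SmallVector<Value, 4>, 4> GroupOperandsPerCore(
        llvm::ArrayRef<llvm::SmallVector<Value, 4>> operands_per_replica,
        int num_cores_per_replica) {
      llvm::SmallVector<llvm::SmallVector<Value, 4>, 4> per_core(
          num_cores_per_replica);
      for (const auto& replica_operands : operands_per_replica)
        for (int core = 0; core < num_cores_per_replica; ++core)
          per_core[core].push_back(replica_operands[core]);
      return per_core;
    }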
OpBuilder builder(replicated_input); llvm::SmallVector operands_per_core; - for (const auto& operands_per_replica : operands_per_replica_per_core) { + for (auto& operands_per_replica : operands_per_replica_per_core) { + const bool is_packed = + packed_input && llvm::all_equal(operands_per_replica); + if (is_packed) // reduce the duplicates to one input for packed vars + operands_per_replica.erase(operands_per_replica.begin() + 1, + operands_per_replica.end()); auto replicate_op = builder.create( replicated_input.getLoc(), replicated_input.getType(), operands_per_replica, replicated_input->getAttrs()); + replicate_op.setIsPacked(is_packed); operands_per_core.push_back(replicate_op); } - auto pi = builder.create( + auto pi = builder.create( first_partitioned_input.getLoc(), replicated_input.getType(), operands_per_core, first_partitioned_input->getAttrs()); - replicated_input.replaceAllUsesWith(pi.output()); + pi.setIsPacked(false); // inputs are now ops--not resources + replicated_input.replaceAllUsesWith(pi.getOutput()); return success(); } void TPUReorderReplicateAndPartitionedInputsPass::runOnOperation() { auto result = getOperation()->walk([](TF::TPUReplicatedInputOp replicated_input) { - if (llvm::none_of(replicated_input.inputs(), [](Value input) { - return llvm::isa_and_nonnull( + if (llvm::none_of(replicated_input.getInputs(), [](Value input) { + return llvm::isa_and_nonnull( input.getDefiningOp()); })) return WalkResult::advance(); @@ -124,7 +158,7 @@ void TPUReorderReplicateAndPartitionedInputsPass::runOnOperation() { return; } - getOperation()->walk([](TF::TPUPartitionedInputOp partitioned_input) { + getOperation()->walk([](TF::TPUPartitionedInputV2Op partitioned_input) { if (partitioned_input->use_empty()) partitioned_input->erase(); }); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_partitioning.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_partitioning.cc index 998cc1ef0bd..bb3941df08a 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_partitioning.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_partitioning.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" @@ -38,6 +39,9 @@ namespace { #define GEN_PASS_DEF_TPURESOURCEREADSWRITESPARTITIONINGPASS #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc" +constexpr char kUseSpmdAttr[] = "use_spmd_for_xla_partitioning"; +constexpr char kNumCoresPerReplicaAttr[] = "num_cores_per_replica"; + struct TPUResourceReadsWritesPartitioningPass : public impl::TPUResourceReadsWritesPartitioningPassBase< TPUResourceReadsWritesPartitioningPass> { @@ -71,25 +75,25 @@ Type GetResourceSubtype(Value resource) { // `old_partitioned_input` is the predecessor of `old_read`. `new_reads` // contains the predecessors of `new_partitioned_input`. 
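A minimal sketch, not part of the patch, of the packed-variable collapse performed above: when the input is packed and every replica would receive the same value, the duplicates are reduced to a single operand and the new replicated input is marked packed:

    bool CollapseToPackedOperand(
        llvm::SmallVectorImpl<Value>& operands_per_replica, bool packed_input) {
      const bool is_packed =
          packed_input && llvm::all_equal(operands_per_replica);
      if (is_packed)
        operands_per_replica.erase(operands_per_replica.begin() + 1,
                                   operands_per_replica.end());
      return is_packed;
    }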
LogicalResult UpdateReadUses(TF::ReadVariableOp old_read, - TF::TPUPartitionedInputOp old_partitioned_input, - TF::TPUPartitionedInputOp new_partitioned_input, + TF::TPUPartitionedInputV2Op old_partitioned_input, + TF::TPUPartitionedInputV2Op new_partitioned_input, llvm::SmallVector new_reads) { xla::OpSharding sharding; sharding.ParseFromString( - old_partitioned_input._XlaShardingAttr().getValue().str()); + old_partitioned_input.get_XlaShardingAttr().getValue().str()); for (OpOperand& read_use : - llvm::make_early_inc_range(old_read.value().getUses())) { + llvm::make_early_inc_range(old_read.getValue().getUses())) { if (dyn_cast_or_null(read_use.getOwner())) { // ClusterFunc's use of the Read is replaced with use of the - // TPUPartitionedInput. + // TPUPartitionedInputV2. read_use.set(new_partitioned_input); } else { - // Outside compiled code's use of the Read after TPUPartitionedInput is - // replaced with use of the first Read before the TPUPartitionedInput. + // Outside compiled code's use of the Read after TPUPartitionedInputV2 is + // replaced with use of the first Read before the TPUPartitionedInputV2. if (sharding.type() != xla::OpSharding::REPLICATED) { // TODO(b/243077297): Generalize to any sharding. old_partitioned_input.emitOpError( - "TPUPartitionedInput variable used in outside compiled code is " + "TPUPartitionedInputV2 variable used in outside compiled code is " "only supported with REPLICATED sharding"); return failure(); } @@ -109,12 +113,14 @@ LogicalResult UpdateReadUses(TF::ReadVariableOp old_read, LogicalResult PartitionResourceReadsWrites( tf_device::ClusterFuncOp cluster_func) { bool use_spmd = false; - if (auto use_spmd_attr = cluster_func->getAttrOfType( - "use_spmd_for_xla_partitioning")) + if (auto use_spmd_attr = cluster_func->getAttrOfType(kUseSpmdAttr)) use_spmd = use_spmd_attr.getValue(); if (!use_spmd) return success(); + auto num_cores_per_replica_attr = + cluster_func->getAttrOfType(kNumCoresPerReplicaAttr); + // Wrap the ClusterFunc with a ParallelExecute if it does not already exist. OpBuilder builder(cluster_func); tf_device::ParallelExecuteOp parallel_execute = @@ -122,35 +128,56 @@ LogicalResult PartitionResourceReadsWrites( if (!parallel_execute) parallel_execute = BuildParallelExecuteOp(cluster_func, &builder); - // Rewrite results before rewriting operands as `tf.TPUPartitionedInput` + // Rewrite results before rewriting operands as `tf.TPUPartitionedInputV2` // resource handle results is an indicator for a partitioned resource - // variable. These `tf.TPUPartitionedInput` will be removed when rewriting + // variable. These `tf.TPUPartitionedInputV2` will be removed when rewriting // the operands. 
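A minimal sketch, with a hypothetical helper name, of the sharding check used in UpdateReadUses above: the _XlaSharding attribute carries a serialized xla::OpSharding proto, and outside-compiled uses are only rewritten when it is REPLICATED. The sketch assumes the attribute is present:

    bool HasReplicatedSharding(TF::TPUPartitionedInputV2Op input) {
      xla::OpSharding sharding;
      if (!sharding.ParseFromString(
              input.get_XlaShardingAttr().getValue().str()))
        return false;  // malformed proto: treat as unsupported
      return sharding.type() == xla::OpSharding::REPLICATED;
    }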
for (Value result : parallel_execute.getExecuteOutputs()) { if (!result.hasOneUse()) continue; auto assign_var = llvm::dyn_cast(*result.getUsers().begin()); - if (!assign_var || assign_var.value() != result) continue; - auto partitioned_input = llvm::dyn_cast_or_null( - assign_var.resource().getDefiningOp()); + if (!assign_var || assign_var.getValue() != result) continue; + auto partitioned_input = + llvm::dyn_cast_or_null( + assign_var.getResource().getDefiningOp()); if (!partitioned_input || - !AllResourceTypesHaveSubtypes(partitioned_input.inputs().getTypes())) + !AllResourceTypesHaveSubtypes(partitioned_input.getInputs().getTypes())) continue; + const auto inputs = partitioned_input.getInputs(); + const bool packed_input = partitioned_input.getIsPacked(); + int num_cores_per_replica = partitioned_input.getN(); + if (num_cores_per_replica_attr) { + num_cores_per_replica = num_cores_per_replica_attr.getInt(); + } else if (packed_input) { + return partitioned_input->emitOpError() + << "num cores per replica unavailable"; + } + + const int num_operands_expected = packed_input ? 1 : num_cores_per_replica; + if (num_cores_per_replica_attr && num_operands_expected != inputs.size()) { + return partitioned_input->emitOpError() + << "expects " << num_operands_expected << " operands but found " + << partitioned_input.getNumOperands(); + } + builder.setInsertionPoint(assign_var); llvm::SmallVector partitioned_output_types; - partitioned_output_types.reserve(partitioned_input.N()); - for (Type input_type : partitioned_input.inputs().getTypes()) - partitioned_output_types.push_back(GetResourceSubtype(input_type)); - auto partitioned_output = builder.create( + partitioned_output_types.reserve(num_cores_per_replica); + for (int i = 0; i < num_cores_per_replica; ++i) { + const auto& input = packed_input ? inputs[0] : inputs[i]; + partitioned_output_types.push_back(GetResourceSubtype(input.getType())); + } + + auto partitioned_output = builder.create( cluster_func->getLoc(), partitioned_output_types, result, - partitioned_input.partition_dimAttr(), - partitioned_input._XlaShardingAttr()); - for (auto resource_write : - llvm::zip(partitioned_input.inputs(), partitioned_output.output())) + partitioned_input.getPartitionDimsAttr(), + partitioned_input.get_XlaShardingAttr()); + for (auto [i, value] : llvm::enumerate(partitioned_output.getOutput())) { + const auto& resource = packed_input ? 
inputs[0] : inputs[i]; builder.create( - assign_var->getLoc(), /*resource=*/std::get<0>(resource_write), - /*value=*/std::get<1>(resource_write)); + assign_var->getLoc(), /*resource=*/resource, /*value=*/value); + } assign_var.erase(); } @@ -158,24 +185,37 @@ LogicalResult PartitionResourceReadsWrites( auto read_var = llvm::dyn_cast_or_null( operand.get().getDefiningOp()); if (!read_var) continue; - auto partitioned_input = llvm::dyn_cast_or_null( - read_var.resource().getDefiningOp()); - if (!partitioned_input || - !AllResourceTypesHaveSubtypes(partitioned_input.inputs().getTypes())) { + auto partitioned_input = + llvm::dyn_cast_or_null( + read_var.getResource().getDefiningOp()); + if (!partitioned_input || !AllResourceTypesHaveSubtypes( + partitioned_input.getInputs().getTypes())) { continue; } - builder.setInsertionPoint(partitioned_input); + // we only want to create one read variable op per unique input + // otherwise tpu rewriting will fail to clean up the duplicates + llvm::SmallMapVector read_variable_ops; llvm::SmallVector partitioned_reads; - for (Value input : partitioned_input.inputs()) { - auto partitioned_read = builder.create( - read_var->getLoc(), GetResourceSubtype(input), input); - partitioned_reads.push_back(partitioned_read.value()); + builder.setInsertionPoint(partitioned_input); + + for (Value input : partitioned_input.getInputs()) { + auto search = read_variable_ops.find(input); + // if a read variable op already doesn't exist for this input, create it + if (search == read_variable_ops.end()) { + auto partitioned_read = builder.create( + read_var->getLoc(), GetResourceSubtype(input), input); + search = read_variable_ops.insert({input, partitioned_read.getValue()}) + .first; + } + partitioned_reads.push_back(search->second); } - auto partitioned_read = builder.create( - partitioned_input->getLoc(), read_var.value().getType(), - partitioned_reads, partitioned_input.partition_dimAttr(), - partitioned_input._XlaShardingAttr()); + + auto partitioned_read = builder.create( + partitioned_input->getLoc(), read_var.getValue().getType(), + partitioned_reads, partitioned_input.getPartitionDimsAttr(), + partitioned_input.getIsPackedAttr(), + partitioned_input.get_XlaShardingAttr()); if (failed(UpdateReadUses(read_var, partitioned_input, partitioned_read, partitioned_reads))) return failure(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_read_for_write.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_read_for_write.cc index 45fd630de91..06f7c911a6a 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_read_for_write.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_read_for_write.cc @@ -57,18 +57,18 @@ ResourceValueAndSubtype GetResourceWriteResult( auto assign_var = dyn_cast(result_user); if (!assign_var) return resource; - auto handle = assign_var.resource(); + auto handle = assign_var.getResource(); // Skip result if cluster writes to the same variable via multiple results. 
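A minimal sketch, names hypothetical, of the read-deduplication introduced above: one ReadVariableOp is created per unique resource and reused for later occurrences, so TPU rewriting does not have to clean up duplicates; SmallMapVector preserves insertion order. GetResourceSubtype is the helper already defined in this file:

    llvm::SmallVector<Value, 4> BuildDedupedReads(OpBuilder& builder,
                                                  Location loc,
                                                  ValueRange resources) {
      llvm::SmallMapVector<Value, Value, 4> cached_reads;
      llvm::SmallVector<Value, 4> reads;
      for (Value resource : resources) {
        auto it = cached_reads.find(resource);
        if (it == cached_reads.end()) {
          // First use of this resource: materialize the read once.
          auto read = builder.create<TF::ReadVariableOp>(
              loc, GetResourceSubtype(resource), resource);
          it = cached_reads.insert({resource, read.getValue()}).first;
        }
        reads.push_back(it->second);
      }
      return reads;
    }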
for (Operation* handle_user : handle.getUsers()) { if (handle_user == assign_var) continue; auto assign_var_user = dyn_cast(handle_user); if (!assign_var_user) continue; - if (assign_var_user.value().getDefiningOp() == cluster_func) + if (assign_var_user.getValue().getDefiningOp() == cluster_func) return resource; } - resource.resource = assign_var.resource(); - resource.subtype = assign_var.value().getType(); + resource.resource = assign_var.getResource(); + resource.subtype = assign_var.getValue().getType(); return resource; } @@ -77,7 +77,7 @@ bool ClusterFuncHasResourceRead(tf_device::ClusterFuncOp cluster_func, Value resource) { for (Operation* resource_user : resource.getUsers()) if (auto read = dyn_cast(resource_user)) - for (Operation* read_user : read.value().getUsers()) + for (Operation* read_user : read.getValue().getUsers()) if (read_user == cluster_func) return true; return false; @@ -105,7 +105,7 @@ void TPUResourceReadForWritePass::runOnOperation() { auto new_read = builder.create( resource_and_type.resource.getLoc(), resource_and_type.subtype, resource_and_type.resource); - read_operands.push_back(new_read.value()); + read_operands.push_back(new_read.getValue()); } if (read_operands.empty()) continue; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc index 64cf3da19a0..d22946dcee1 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc @@ -301,7 +301,7 @@ LogicalResult SetMetadataProtoFromClusterFuncOp( if (xla_device_assignment.has_value()) *metadata->mutable_device_assignment() = - std::move(xla_device_assignment.getValue()); + std::move(xla_device_assignment.value()); auto use_spmd_attr = op->getAttrOfType(kUseXlaSpmdAttr); if (!use_spmd_attr) return op.emitOpError(CreateMissingAttributeMsg(kUseXlaSpmdAttr)); @@ -473,7 +473,8 @@ int MovePreservedParallelExecuteChildren( tf_device::ParallelExecuteOp old_parallel_execute, tf_device::ParallelExecuteOp* new_parallel_execute) { // `num_moved_children` is the number of children that will be preserved. - const int num_moved_children = old_parallel_execute.regions().size() - 1; + const size_t num_moved_children = + old_parallel_execute.getRegions().size() - 1; *new_parallel_execute = builder->create( old_parallel_execute->getLoc(), num_moved_children + num_cores_per_replica, concatenated_output_types); @@ -481,8 +482,8 @@ int MovePreservedParallelExecuteChildren( // `cluster_idx` is the index of the child with the `ClusterFuncOp`, which // will be replaced. int cluster_idx = -1; - for (int child_idx = 0; child_idx < old_parallel_execute.regions().size(); - ++child_idx) { + for (size_t child_idx = 0; + child_idx < old_parallel_execute.getRegions().size(); ++child_idx) { auto& block = old_parallel_execute.GetRegionBlockWithIndex(child_idx); if (cluster_func->getBlock() == &block) { assert(cluster_idx == -1); @@ -496,8 +497,8 @@ int MovePreservedParallelExecuteChildren( int old_idx = child_idx >= cluster_idx ? child_idx + 1 : child_idx; int new_idx = child_idx >= cluster_idx ? 
child_idx + num_cores_per_replica : child_idx; - new_parallel_execute->getRegions()[new_idx]->takeBody( - *old_parallel_execute.getRegions()[old_idx]); + new_parallel_execute->getRegions()[new_idx].takeBody( + old_parallel_execute.getRegions()[old_idx]); } return cluster_idx; @@ -528,9 +529,9 @@ LogicalResult AddToParallelExecuteOp( concatenated_output_types.reserve(num_results_pre_cluster + cluster_result_types.size() * num_cores_per_replica); - for (auto* region : old_parallel_execute.getRegions()) { - if (!isa(region->front().front())) { - for (Type t : region->front().front().getResultTypes()) + for (mlir::Region& region : old_parallel_execute.getRegions()) { + if (!isa(region.front().front())) { + for (Type t : region.front().front().getResultTypes()) concatenated_output_types.emplace_back(t); } } @@ -646,11 +647,11 @@ LogicalResult CheckTPUPartitionedInputAndOutputAreValid( for (auto cluster_result : parallel_execute.getExecuteOutputs()) { for (Operation* user : llvm::make_early_inc_range(cluster_result.getUsers())) { - // Check that user has no outputs that are TPUPartitionedOutput + // Check that user has no outputs that are TPUPartitionedOutputV2 for (auto result : user->getResults()) { for (Operation* user : llvm::make_early_inc_range(result.getUsers())) { - if (llvm::isa(user)) { - user->emitError() << "Input of TPUPartitionedOutput must " + if (llvm::isa(user)) { + user->emitError() << "Input of TPUPartitionedOutputV2 must " << "be in tpu computation."; return failure(); } @@ -658,17 +659,17 @@ LogicalResult CheckTPUPartitionedInputAndOutputAreValid( } } } - for (auto cluster_operand : cluster.operands()) { + for (auto cluster_operand : cluster.getOperands()) { Operation* def = cluster_operand.getDefiningOp(); - // This pass assumes that a TPUPartitionedInput is preceeded by + // This pass assumes that a TPUPartitionedInputV2 is preceeded by // ReadVariable ops, and not vice versa. An earlier pass, // TPUResourceReadsWritesPartitioning, should have ensured this // precondition. 
if (!def) continue; for (auto operand : def->getOperands()) { Operation* def_of_read = operand.getDefiningOp(); - if (llvm::isa_and_nonnull(def_of_read)) { - def_of_read->emitError() << "Output of TPUPartitionedInput must " + if (llvm::isa_and_nonnull(def_of_read)) { + def_of_read->emitError() << "Output of TPUPartitionedInputV2 must " << "be in tpu computation."; return failure(); } @@ -682,8 +683,8 @@ LogicalResult CheckParallelExecuteConstainsValidNonClusterProcess( int num_pre_cluster_regions = 0; int num_post_cluster_regions = 0; int num_cluster_regions = 0; - for (auto* region : parallel_execute.getRegions()) { - if (isa(region->front().front())) { + for (mlir::Region& region : parallel_execute.getRegions()) { + if (isa(region.front().front())) { if (num_cluster_regions == 0) { num_pre_cluster_regions++; } else { @@ -704,9 +705,9 @@ LogicalResult CheckParallelExecuteConstainsValidNonClusterProcess( int GetNumResultsPreCluster(tf_device::ParallelExecuteOp parallel_execute) { int num_results_pre_cluster = 0; - for (auto region : parallel_execute.getRegions()) { - if (isa(region->front().front())) { - num_results_pre_cluster = region->front().front().getResultTypes().size(); + for (mlir::Region& region : parallel_execute.getRegions()) { + if (isa(region.front().front())) { + num_results_pre_cluster = region.front().front().getResultTypes().size(); } } return num_results_pre_cluster; @@ -730,7 +731,7 @@ LogicalResult Rewrite( if (!old_parallel_execute) old_parallel_execute = BuildParallelExecuteOp(cluster_func, builder); - // check TPUPartitionedInput and TPUPartitionedOutput are in valid pattern + // check TPUPartitionedInputV2 and TPUPartitionedOutputV2 are in valid pattern if (failed(CheckTPUPartitionedInputAndOutputAreValid(cluster_func, old_parallel_execute))) return failure(); @@ -840,7 +841,7 @@ LogicalResult Rewrite( } else if (compile_device_op) { result_id->setAttr("device", compile_device_op); } - res.output().replaceAllUsesWith(compile_op->getResult(0)); + res.getOutput().replaceAllUsesWith(compile_op->getResult(0)); } BuildTPUCompileSucceededAssertOp( @@ -879,8 +880,8 @@ LogicalResult Rewrite( return RemoveSingletonParallelExecuteOp(new_parallel_execute, builder); } -// Erase rewritten ClusterFuncOp(s). If TPUPartitionedInputOp / -// TPUPartitionedOutputOp are present, they must be removed along with the +// Erase rewritten ClusterFuncOp(s). If TPUPartitionedInputV2Op / +// TPUPartitionedOutputV2Op are present, they must be removed along with the // ClusterFuncOp(s). 
void EraseClusterFuncs( llvm::MutableArrayRef to_be_erased) { @@ -891,17 +892,17 @@ void EraseClusterFuncs( for (auto result : old_parallel_execute.getExecuteOutputs()) { for (Operation* user : llvm::make_early_inc_range(result.getUsers())) { - if (llvm::isa(user)) { + if (llvm::isa(user)) { assert(user->use_empty()); user->erase(); } } } - for (auto operand : cluster.operands()) { + for (auto operand : cluster.getOperands()) { Operation* def = operand.getDefiningOp(); if (operand.hasOneUse() && - llvm::isa_and_nonnull(def)) { + llvm::isa_and_nonnull(def)) { operand.dropAllUses(); def->erase(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc index 9dc8e7bba56..19d27247efe 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc @@ -14,10 +14,10 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" @@ -70,16 +70,16 @@ std::string CreateMissingAttributeMsg(llvm::StringRef attribute) { // `tf_device.cluster_func` operand value. If value is a resource type then // TPUPartitionedInput op will be connected to a ReadVariable op that feeds into // a `tf_device.cluster_func`. -llvm::Optional GetXlaShardingFromOperand(Value value) { +std::optional GetXlaShardingFromOperand(Value value) { Value value_to_visit = value; if (auto read_var = value_to_visit.getDefiningOp()) - value_to_visit = read_var.resource(); + value_to_visit = read_var.getResource(); if (auto partitioned_input = - value_to_visit.getDefiningOp()) - return partitioned_input._XlaSharding(); + value_to_visit.getDefiningOp()) + return partitioned_input.get_XlaSharding(); - return llvm::None; + return std::nullopt; } // Given a `tf_device.cluster_func` operand value return true iff it a device @@ -142,7 +142,7 @@ LogicalResult VerifyShardings( // Assign the logical device if an op has an attribute `TPU_REPLICATED_CORE:n`, // the corresponding input sharding arg will be associated with // logical device `n`. -llvm::Optional AssignLogicalDeviceFromTPUReplicatedCoreAttr( +std::optional AssignLogicalDeviceFromTPUReplicatedCoreAttr( Operation* op, const llvm::SmallVector& logical_device_vec) { if (auto device = op->getAttrOfType("device")) { if (!device.getValue().empty() && !device.getValue().str().empty()) { @@ -155,7 +155,7 @@ llvm::Optional AssignLogicalDeviceFromTPUReplicatedCoreAttr( } } } - return llvm::None; + return std::nullopt; } // Returns XLA sharding from a XlaSharding op connected to an argument value. If @@ -168,7 +168,7 @@ llvm::Optional AssignLogicalDeviceFromTPUReplicatedCoreAttr( // Case, While) ops and Caller return values. // TODO(hongjunchoi): Consider explicitly checking op patterns to detect sharded // inputs. 
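Most of the sharding-identification hunks above and below follow one mechanical migration: llvm::Optional -> std::optional, llvm::None -> std::nullopt, .getValue() -> .value() or operator*. A self-contained sketch of the pattern using only the C++17 standard library (the function name and return value are invented for illustration):

#include <iostream>
#include <optional>
#include <string>

// Before: llvm::Optional<std::string>, returning llvm::None when absent.
std::optional<std::string> GetShardingSketch(bool has_sharding) {
  if (!has_sharding) return std::nullopt;  // was: return llvm::None;
  return std::string("replicated");        // placeholder sharding value
}

int main() {
  if (auto sharding = GetShardingSketch(/*has_sharding=*/true))
    std::cout << *sharding << "\n";  // was: sharding.getValue()
  else
    std::cout << "no sharding\n";
}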
-llvm::Optional GetXlaShardingFromArg( +std::optional GetXlaShardingFromArg( Value value, const llvm::SmallVector& logical_device_vec) { llvm::SmallPtrSet visited_values; llvm::SmallVector values_to_visit{value}; @@ -180,7 +180,7 @@ llvm::Optional GetXlaShardingFromArg( for (auto& use : value_to_visit.getUses()) { Operation* owner = use.getOwner(); if (auto sharding = llvm::dyn_cast(owner)) - return sharding._XlaSharding(); + return sharding.get_XlaSharding(); if (auto logical_device = AssignLogicalDeviceFromTPUReplicatedCoreAttr( owner, logical_device_vec)) { @@ -205,7 +205,7 @@ llvm::Optional GetXlaShardingFromArg( values_to_visit.swap(next_values_to_visit); } - return llvm::None; + return std::nullopt; } // Extracts sharding configurations for all inputs by parsing XlaSharding/ @@ -241,19 +241,19 @@ void IdentifyXlaShardingForComputationInputs( // Sharding configurations are added to the tf_device.ClusterFunc as an // attribute and the function as an argument attribute. for (auto operand_and_arg : - llvm::zip(cluster_func.operands(), function_block.getArguments())) { + llvm::zip(cluster_func.getOperands(), function_block.getArguments())) { Value operand = std::get<0>(operand_and_arg); BlockArgument arg = std::get<1>(operand_and_arg); if (auto operand_sharding = GetXlaShardingFromOperand(operand)) { - sharding_for_args.push_back(operand_sharding.getValue()); + sharding_for_args.push_back(operand_sharding.value()); continue; } if (infer_from_computation) { auto arg_sharding = GetXlaShardingFromArg(arg, logical_device_vec); if (arg_sharding) { - sharding_for_args.push_back(arg_sharding.getValue()); + sharding_for_args.push_back(arg_sharding.value()); continue; } } @@ -275,20 +275,21 @@ void IdentifyXlaShardingForComputationInputs( // Returns XLA sharding from TPUPartitionedOutput or TPUPartitionedInput (via // AssignVariableOp/resource write) op connected to a `tf_device.cluster_func` // result value. -llvm::Optional GetXlaShardingFromResult(Value value) { - if (!value.hasOneUse()) return llvm::None; +std::optional GetXlaShardingFromResult(Value value) { + if (!value.hasOneUse()) return std::nullopt; Operation* user = *value.getUsers().begin(); if (auto partitioned_output = - llvm::dyn_cast(user)) - return partitioned_output._XlaSharding(); + llvm::dyn_cast(user)) + return partitioned_output.get_XlaSharding(); if (auto assign_var = llvm::dyn_cast(user)) if (auto partitioned_input = - assign_var.resource().getDefiningOp()) - return partitioned_input._XlaSharding(); + assign_var.getResource() + .getDefiningOp()) + return partitioned_input.get_XlaSharding(); - return llvm::None; + return std::nullopt; } // Looks up arg->retval aliases for every argument, and builds a reverse map. @@ -305,7 +306,7 @@ void ExtractAliases(func::FuncOp func, llvm::SmallVectorImpl& aliases) { } // Returns XLA sharding from argument connected via tf.aliasing_output. -llvm::Optional GetXlaShardingFromAlias( +std::optional GetXlaShardingFromAlias( Value value, llvm::SmallVectorImpl& aliases, const llvm::SmallVectorImpl& sharding_for_args) { int retval_index = value.cast().getResultNumber(); @@ -315,7 +316,7 @@ llvm::Optional GetXlaShardingFromAlias( return sharding_for_args[arg_index]; } } - return llvm::None; + return std::nullopt; } // Returns XLA sharding from XlaSharding op connected to a result value. @@ -327,7 +328,7 @@ llvm::Optional GetXlaShardingFromAlias( // Case, While) ops and Caller argument values. // TODO(hongjunchoi): Consider explicitly checking op patterns to detect sharded // inputs. 
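ExtractAliases and GetXlaShardingFromAlias above let a cluster result inherit the sharding of the argument it aliases through the per-argument `tf.aliasing_output` attribute. A standalone sketch of that reverse map, with the attribute values modeled as plain integers (the data and the exact attribute handling are hypothetical):

#include <iostream>
#include <vector>

int main() {
  // aliasing_output[i]: result index aliased by argument i (-1 if none).
  std::vector<int> aliasing_output = {2, -1, 0};
  const int num_results = 3;
  // Reverse map: result index -> argument index (-1 if unaliased).
  std::vector<int> aliases(num_results, -1);
  for (int arg = 0; arg < static_cast<int>(aliasing_output.size()); ++arg) {
    int retval = aliasing_output[arg];
    if (retval >= 0 && retval < num_results) aliases[retval] = arg;
  }
  for (int r = 0; r < num_results; ++r)
    std::cout << "result " << r << " aliases arg " << aliases[r] << "\n";
  // A result with a non-negative entry can simply reuse that argument's
  // sharding instead of re-deriving it from the computation.
}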
-llvm::Optional GetXlaShardingFromRetval( +std::optional GetXlaShardingFromRetval( Value value, const llvm::SmallVector& logical_device_vec) { llvm::SmallPtrSet visited_values; llvm::SmallVector values_to_visit; @@ -346,7 +347,7 @@ llvm::Optional GetXlaShardingFromRetval( } if (auto sharding = llvm::dyn_cast_or_null(def)) - return sharding._XlaSharding(); + return sharding.get_XlaSharding(); if (auto sharding = def->getAttrOfType("_XlaSharding")) { return sharding.strref(); @@ -385,7 +386,7 @@ llvm::Optional GetXlaShardingFromRetval( } } - return llvm::None; + return std::nullopt; } // Extracts sharding configurations for all outputs by parsing XlaSharding/ @@ -418,20 +419,20 @@ void IdentifyXlaShardingForComputationOutputs( OpOperand& retval = std::get<1>(result_and_retval); if (auto result_sharding = GetXlaShardingFromResult(result)) { - sharding_for_rets.push_back(result_sharding.getValue()); + sharding_for_rets.push_back(result_sharding.value()); continue; } if (auto from_alias = GetXlaShardingFromAlias(result, aliases, sharding_for_args)) { - sharding_for_rets.push_back(from_alias.getValue()); + sharding_for_rets.push_back(from_alias.value()); continue; } if (infer_from_computation) { if (auto retval_sharding = GetXlaShardingFromRetval(retval.get(), logical_device_vec)) { - sharding_for_rets.push_back(retval_sharding.getValue()); + sharding_for_rets.push_back(retval_sharding.value()); continue; } } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc index b6366083d42..b0dfc09df0e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include +#include #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" @@ -84,7 +85,7 @@ void HandleFuncOp(Operation* op) { // Handles cast op between the first convolution and the block argument. LogicalResult HandleCast(TF::CastOp cast_op, ArrayRef new_shape) { - auto cast_input = cast_op.x(); + auto cast_input = cast_op.getX(); // Update input type. auto transform_result_type = RankedTensorType::get(new_shape, getElementTypeOrSelf(cast_input)); @@ -98,7 +99,7 @@ LogicalResult HandleCast(TF::CastOp cast_op, ArrayRef new_shape) { block_arg = nullptr; cast_op_input = nullptr; } else { - auto cast_input = cast_op_input.x(); + auto cast_input = cast_op_input.getX(); // Update input type. auto transform_result_type = RankedTensorType::get(new_shape, getElementTypeOrSelf(cast_input)); @@ -113,7 +114,7 @@ LogicalResult HandleCast(TF::CastOp cast_op, ArrayRef new_shape) { // Handles padding before convolution for space to depth transform. LogicalResult HandlePad(TF::PadOp op, int32_t kernel_size, int32_t block_size) { - auto ranked_type = op.input().getType().dyn_cast(); + auto ranked_type = op.getInput().getType().dyn_cast(); if (!ranked_type) return failure(); auto pad_input_shape = ranked_type.getShape(); Location loc = op.getLoc(); @@ -162,7 +163,7 @@ void HandleConv2DStride(TF::Conv2DOp conv2d) { // Transforms input shape for the first convolution. 
void HandleConv2DInput(TF::Conv2DOp conv2d, int64_t block_size) { - auto input = conv2d.input(); + auto input = conv2d.getInput(); auto input_shape = input.getType().cast().getShape(); SmallVector transform_shape = { input_shape[0], input_shape[1] / block_size, input_shape[2] / block_size, @@ -223,7 +224,7 @@ void HandleConv2DFilter(TF::Conv2DOp conv2d, int64_t block_size) { // 2. Reshape to [4, 2, 4, 2, 3, 64] // 3. Transpose to [4, 4, 2, 2, 3, 64] // 4. Reshape to [4, 4, 12, 64] - auto filter = conv2d.filter(); + auto filter = conv2d.getFilter(); OpBuilder builder(conv2d); builder.setInsertionPoint(conv2d); // Book keeping filter information. @@ -296,7 +297,7 @@ void HandleConv2DBackPropFilter(TF::Conv2DBackpropFilterOp backprop, OpBuilder builder(backprop); builder.setInsertionPoint(backprop); - auto input = backprop.input(); + auto input = backprop.getInput(); // Get new filter size from new_filter_shape. auto new_filter_sizes = builder.create( backprop.getLoc(), @@ -321,10 +322,10 @@ void HandleConv2DBackPropFilter(TF::Conv2DBackpropFilterOp backprop, // Build new BackPropFilterOp. auto loc = backprop.getLoc(); auto new_backprop = builder.create( - loc, new_result_type, input, new_filter_sizes, backprop.out_backprop(), - strides, backprop.use_cudnn_on_gpu(), backprop.padding(), - backprop.explicit_paddings(), backprop.data_format(), - backprop.dilations()); + loc, new_result_type, input, new_filter_sizes, backprop.getOutBackprop(), + strides, backprop.getUseCudnnOnGpu(), backprop.getPadding(), + backprop.getExplicitPaddings(), backprop.getDataFormat(), + backprop.getDilations()); // For example, if new filter shape is [4, 4, 12, 64], old filter shape // is [7, 7, 3, 64] with block_size 2. @@ -440,7 +441,7 @@ void HandleCluster(tf_device::ClusterFuncOp cluster_func, int32_t block_size, llvm::dyn_cast(cluster_func->getParentOp()); llvm::SmallVector transform_input_indices; - for (auto input : llvm::enumerate(cluster_func.operands())) { + for (const auto& input : llvm::enumerate(cluster_func.getOperands())) { if (auto block_arg = input.value().dyn_cast()) { if (block_arg.getArgNumber() != arg_num) continue; // For a block argument, consider transforms only when it is a replicated @@ -486,13 +487,13 @@ bool Conv2DInputShapeCanTransform(Value input) { // Get block argument id and number of users for the input arg. Optional GetBlockArgNum(Value arg) { if (auto block_arg = arg.dyn_cast()) { - if (!Conv2DInputShapeCanTransform(arg)) return None; + if (!Conv2DInputShapeCanTransform(arg)) return std::nullopt; unsigned num_users = std::distance(block_arg.getUsers().begin(), block_arg.getUsers().end()); BlockArgumentInfo block_arg_info = {block_arg.getArgNumber(), num_users}; return block_arg_info; } - return None; + return std::nullopt; } // Gets input block argument id and number of users for the input recursively. 
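The shape bookkeeping in the hunks above is ordinary space-to-depth arithmetic: with block size b, an NHWC input [N, H, W, C] becomes [N, H/b, W/b, C*b*b], and the example in the filter comment ([7, 7, 3, 64] -> [4, 4, 12, 64] for b = 2) follows the same rule after padding. A minimal standalone sketch of the input-shape computation (the arithmetic only, not the pass):

#include <cstdint>
#include <iostream>
#include <vector>

// Space-to-depth on an NHWC shape: [N, H, W, C] -> [N, H/b, W/b, C*b*b].
std::vector<int64_t> SpaceToDepthInputShape(const std::vector<int64_t>& nhwc,
                                            int64_t b) {
  return {nhwc[0], nhwc[1] / b, nhwc[2] / b, nhwc[3] * b * b};
}

int main() {
  for (int64_t d : SpaceToDepthInputShape({8, 224, 224, 3}, /*b=*/2))
    std::cout << d << " ";  // prints: 8 112 112 12
}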
@@ -508,47 +509,47 @@ Optional GetInputBlockArgNum(Value input) { while (pad_op || cast_op) { if (pad_op) { - auto block_arg_num = GetBlockArgNum(pad_op.input()); + auto block_arg_num = GetBlockArgNum(pad_op.getInput()); if (block_arg_num.has_value()) return block_arg_num; - next_input = pad_op.input(); + next_input = pad_op.getInput(); } else { - auto block_arg_num = GetBlockArgNum(cast_op.x()); + auto block_arg_num = GetBlockArgNum(cast_op.getX()); if (block_arg_num.has_value()) return block_arg_num; - next_input = cast_op.x(); + next_input = cast_op.getX(); } pad_op = dyn_cast_or_null(next_input.getDefiningOp()); cast_op = dyn_cast_or_null(next_input.getDefiningOp()); } - return None; + return std::nullopt; } // Checks if a convoluton can apply SpaceToDepth transform. // Only the first convolution in the graph whose batch size smaller than 8 // and its input feature size smaller than 8 can be transformed. Optional GetConv2DInputArgNum(TF::Conv2DOp conv2d) { - if (conv2d.data_format() != "NHWC" || conv2d.strides().size() != 4) { - return None; + if (conv2d.getDataFormat() != "NHWC" || conv2d.getStrides().size() != 4) { + return std::nullopt; } // Current supported ops between convolution input and the block arguments are // PadOp and CastOp. - return GetInputBlockArgNum(conv2d.input()); + return GetInputBlockArgNum(conv2d.getInput()); } // Applies space to depth transform for the first convolution on TPU device. void HandleFirstConvolution(TF::Conv2DOp conv2d, int64_t block_size) { // Check if input and filter type are RankedTensorType. auto input_tensor_type = - conv2d.input().getType().dyn_cast(); + conv2d.getInput().getType().dyn_cast(); auto filter_tensor_type = - conv2d.filter().getType().dyn_cast(); + conv2d.getFilter().getType().dyn_cast(); if (!input_tensor_type || !filter_tensor_type) return; // Book keeping filter shape for padding and backprop filter rewrite. auto filter_shape = filter_tensor_type.getShape(); SmallVector old_filter_shape(filter_shape.begin(), filter_shape.end()); // Handles input. - auto conv2d_input = conv2d.input(); + auto conv2d_input = conv2d.getInput(); if (auto block_arg = conv2d_input.dyn_cast()) { // Change on device function type/shape. HandleFuncOp(block_arg.getOwner()->getParentOp()); @@ -557,7 +558,7 @@ void HandleFirstConvolution(TF::Conv2DOp conv2d, int64_t block_size) { if (auto pad_op = dyn_cast_or_null(conv2d_input.getDefiningOp())) { // Rewrite pad_op before Convolutioin. if (failed(HandlePad(pad_op, filter_shape[0], block_size))) return; - auto pad_input = pad_op.input(); + auto pad_input = pad_op.getInput(); if (auto block_arg = pad_input.dyn_cast()) { // Change on device function type/shape. HandleFuncOp(block_arg.getOwner()->getParentOp()); @@ -571,7 +572,8 @@ void HandleFirstConvolution(TF::Conv2DOp conv2d, int64_t block_size) { // Book keeping new filter shape for backprop filter rewrite. // Filter shape is defined in HandleConv2DFilter, thus it is RankedTensorType. - filter_shape = conv2d.filter().getType().cast().getShape(); + filter_shape = + conv2d.getFilter().getType().cast().getShape(); SmallVector new_filter_shape(filter_shape.begin(), filter_shape.end()); @@ -591,7 +593,7 @@ void HandleFirstConvolution(TF::Conv2DOp conv2d, int64_t block_size) { int32_t GetConv2DBlockSize(TF::Conv2DOp conv2d) { SmallVector strides(4, 1); for (int i = 0; i < 3; ++i) { - strides[i] = conv2d.strides()[i].cast().getInt(); + strides[i] = conv2d.getStrides()[i].cast().getInt(); } // Space to depth only supports striding at spatial dimension. 
@@ -634,8 +636,8 @@ void TPUSpaceToDepthPass::runOnOperation() { if (arg_num_and_num_users.has_value()) { // Get block size for the first convolution. int64_t block_size = GetConv2DBlockSize(conv2d); - auto arg_num = arg_num_and_num_users.getValue().arg_num; - auto num_users = arg_num_and_num_users.getValue().num_users; + auto arg_num = arg_num_and_num_users.value().arg_num; + auto num_users = arg_num_and_num_users.value().num_users; argnum_and_convolutions[arg_num].emplace_back(conv2d, block_size); argnum_num_users[arg_num] = num_users; return WalkResult::interrupt(); @@ -689,7 +691,7 @@ void TPUSpaceToDepthPass::runOnOperation() { auto conv2d_and_block_sizes = argnum_and_convolution.getSecond(); int64_t block_size = conv2d_and_block_sizes[0].second; // Apply space to depth transform to the input on the host. - HandleCluster(cluster_func.getValue(), block_size, + HandleCluster(cluster_func.value(), block_size, argnum_and_convolution.getFirst()); // Transform the convolution. for (auto conv2d_and_block_size : conv2d_and_block_sizes) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc index 29b43216e1b..797038c7772 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc @@ -101,8 +101,8 @@ AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping( TF::WhileRegionOp while_op, tf_device::ReplicateOp replicate, TF::TPUExecuteAndUpdateVariablesOp execute, tf_device::LaunchOp compile_launch) { - Region& body = while_op.body(); - Region& cond = while_op.cond(); + Region& body = while_op.getBody(); + Region& cond = while_op.getCond(); llvm::SmallVector>, 4> mapping; auto mirrored_variable_indices_attr = @@ -111,7 +111,7 @@ AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping( // Finds the mapping from a replicate argument to an execute operand. llvm::SmallDenseMap replicate_arg_to_execute_arg; - for (auto index_and_arg : llvm::enumerate(execute.args())) { + for (auto index_and_arg : llvm::enumerate(execute.getArgs())) { auto arg = SkipIdentity(index_and_arg.value(), /*allow_other_use=*/false); if (!arg.hasOneUse() || !getElementTypeOrSelf(arg.getType()).isa()) { @@ -203,12 +203,12 @@ AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping( // Sort the mapping according to execute operand order. llvm::sort(mapping, llvm::less_first()); // Populate the `retval_index_for_sharding` field of the argument metadate. 
- for (auto entry : llvm::enumerate(execute.device_var_reads_indices())) { + for (auto entry : llvm::enumerate(execute.getDeviceVarReadsIndices())) { int64_t arg_index = entry.value().cast().getInt(); auto arg_metadata = metadata.mutable_args(arg_index); if (arg_metadata->enable_xla_sharding() == ::tensorflow::tpu::TPUCompileMetadataProto_Arg::ALLOWED) { - int64_t ret_index = execute.device_var_updates_indices() + int64_t ret_index = execute.getDeviceVarUpdatesIndices() .getValue()[entry.index()] .cast() .getInt(); @@ -257,7 +257,7 @@ tf_device::ReplicateOp AddInputsToReplicateOp( } SmallVector new_input_values; new_input_values.reserve(new_inputs.size()); - for (auto var : new_inputs) new_input_values.push_back(var.resource()); + for (auto var : new_inputs) new_input_values.push_back(var.getResource()); new_replicated_inputs.emplace_back(new_input_values, new_input_values.front().getType()); OpBuilder builder(replicate); @@ -362,7 +362,7 @@ bool HandleReplicateOp(TF::WhileRegionOp while_op, }); if (!execute) return false; auto compile = - SkipIdentity(execute.key(), /*allow_other_use=*/true).getDefiningOp(); + SkipIdentity(execute.getKey(), /*allow_other_use=*/true).getDefiningOp(); if (!compile) return false; auto compile_launch = llvm::dyn_cast(compile); if (!compile_launch || !compile_launch.WrapsSingleOp() || @@ -379,7 +379,7 @@ bool HandleReplicateOp(TF::WhileRegionOp while_op, auto devices_attr = replicate.getDevices(); if (!devices_attr) return false; - auto device_map = devices_attr.getValue(); + auto device_map = devices_attr.value(); llvm::SmallDenseMap> devices; devices.reserve(device_map.size()); @@ -408,7 +408,7 @@ bool HandleReplicateOp(TF::WhileRegionOp while_op, // `replicate`. llvm::SmallVector reformat_operands; for (const auto& entry : execute_arg_to_outer_args) { - reformat_operands.push_back(execute.args()[entry.first]); + reformat_operands.push_back(execute.getArgs()[entry.first]); } reformat_operands.push_back(compile_launch.getResult(1)); reformat_operands.push_back(replicate.GetBody().getArgument( @@ -433,7 +433,7 @@ bool HandleReplicateOp(TF::WhileRegionOp while_op, } llvm::SmallVector state_var_vals(state_vars.size()); for (const auto& entry : llvm::enumerate(state_vars)) { - state_var_vals[entry.index()] = entry.value().resource(); + state_var_vals[entry.index()] = entry.value().getResource(); } // Add the replicated state var to the end of the replicate operands. 
unformat_replicate_operands.emplace_back(state_var_vals, @@ -485,7 +485,7 @@ void TPUVariableRuntimeReformattingPass::runOnOperation() { bool reshard_was_inserted = false; module.walk([&](TF::WhileRegionOp while_op) { tf_device::ReplicateOp replicate; - while_op.body().walk([&](tf_device::ReplicateOp replicate_op) { + while_op.getBody().walk([&](tf_device::ReplicateOp replicate_op) { if (replicate == nullptr) { replicate = replicate_op; return WalkResult::advance(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc b/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc index 2e72492db4f..abdd1a83d51 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc @@ -111,7 +111,7 @@ std::vector ConvertTFBatchMatMulOp::sliceInput( // [1, num_rows, num_cols] -> [num_rows, num_cols] auto reshape_op = createReshapeOp(value, {num_rows, num_cols}, element_type, loc, rewriter); - sliced.emplace_back(reshape_op.output()); + sliced.emplace_back(reshape_op.getOutput()); } else { // Reshape to rank-3 tensor with first dimension as the batch size. auto reshape_op = createReshapeOp(value, {batch_size, num_rows, num_cols}, @@ -128,16 +128,17 @@ std::vector ConvertTFBatchMatMulOp::sliceInput( SmallVector slice_size = {1, num_rows, num_cols}; Type slice_result_type = RankedTensorType::get(slice_size, element_type); llvm::SmallVector output_types(batch_size, slice_result_type); - auto split_op = rewriter.create( - loc, output_types, split_dimension_op.output(), reshape_op.output()); + auto split_op = rewriter.create(loc, output_types, + split_dimension_op.getOutput(), + reshape_op.getOutput()); // Squeeze each batch, i.e. reshape // [1, num_rows, num_cols] -> [num_rows, num_cols] - for (const auto& split_value : split_op.output()) { + for (const auto& split_value : split_op.getOutput()) { auto reshape_op = createReshapeOp(split_value, {num_rows, num_cols}, element_type, loc, rewriter); - sliced.emplace_back(reshape_op.output()); + sliced.emplace_back(reshape_op.getOutput()); } } return sliced; @@ -146,8 +147,8 @@ std::vector ConvertTFBatchMatMulOp::sliceInput( template LogicalResult ConvertTFBatchMatMulOp::matchAndRewrite( BatchMatMulOpType op, PatternRewriter& rewriter) const { - Value input_lhs = op.x(); - Value input_rhs = op.y(); + Value input_lhs = op.getX(); + Value input_rhs = op.getY(); if (!input_lhs.getType().isa()) { // LHS must be a ranked tensor type @@ -190,15 +191,15 @@ LogicalResult ConvertTFBatchMatMulOp::matchAndRewrite( // Replace the last 2 dimensions of LHS and RHS if necessary. // The actual transpose is done by MatMulOp. - if (op.adj_x()) { + if (op.getAdjX()) { std::swap(lhs_shape[lhs_dims - 1], lhs_shape[lhs_dims - 2]); } - if (op.adj_y()) { + if (op.getAdjY()) { std::swap(rhs_shape[rhs_dims - 1], rhs_shape[rhs_dims - 2]); } - const int rows = lhs_shape[lhs_dims - 2]; - const int cols = rhs_shape[rhs_dims - 1]; + const int64_t rows = lhs_shape[lhs_dims - 2]; + const int64_t cols = rhs_shape[rhs_dims - 1]; if (lhs_shape[lhs_dims - 1] != rhs_shape[rhs_dims - 2]) { // Input dimensions must be compatible for multiplication. 
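The unroll_batch_matmul changes above are accessor renames (x() -> getX(), adj_x() -> getAdjX(), and so on); the pass itself still slices each batch and emits one 2-D MatMul per slice. A standalone sketch of that unrolling on plain float buffers, assuming row-major layout (illustrative, not the pass implementation):

#include <cstddef>
#include <iostream>
#include <vector>

// One 2-D matmul: C[M x N] = A[M x K] * B[K x N].
std::vector<float> MatMul2D(const std::vector<float>& a,
                            const std::vector<float>& b,
                            size_t m, size_t k, size_t n) {
  std::vector<float> c(m * n, 0.0f);
  for (size_t i = 0; i < m; ++i)
    for (size_t p = 0; p < k; ++p)
      for (size_t j = 0; j < n; ++j)
        c[i * n + j] += a[i * k + p] * b[p * n + j];
  return c;
}

// Unrolled BatchMatMul: slice each batch, multiply, then re-pack the results.
std::vector<float> BatchMatMulUnrolled(const std::vector<float>& a,
                                       const std::vector<float>& b,
                                       size_t batch, size_t m, size_t k,
                                       size_t n) {
  std::vector<float> out;
  out.reserve(batch * m * n);
  for (size_t s = 0; s < batch; ++s) {
    std::vector<float> lhs(a.begin() + s * m * k, a.begin() + (s + 1) * m * k);
    std::vector<float> rhs(b.begin() + s * k * n, b.begin() + (s + 1) * k * n);
    std::vector<float> prod = MatMul2D(lhs, rhs, m, k, n);
    out.insert(out.end(), prod.begin(), prod.end());
  }
  return out;
}

int main() {
  // batch = 2, each slice is a 1x2 times 2x1 product.
  std::vector<float> a = {1, 2, 3, 4}, b = {5, 6, 7, 8};
  for (float v : BatchMatMulUnrolled(a, b, 2, 1, 2, 1))
    std::cout << v << " ";  // prints: 17 53
}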
@@ -212,20 +213,20 @@ LogicalResult ConvertTFBatchMatMulOp::matchAndRewrite( rewriter.replaceOpWithNewOp(op, matmul_type, /*a=*/input_lhs, /*b=*/input_rhs, - /*transpose_a=*/op.adj_x(), - /*transpose_b=*/op.adj_y()); + /*transpose_a=*/op.getAdjX(), + /*transpose_b=*/op.getAdjY()); return success(); } // Input dimensions must be defined. MatMulBCast does not support partial // shapes. for (auto dim : lhs_shape) { - if (dim == -1) { + if (dim == mlir::ShapedType::kDynamic) { return failure(); } } for (auto dim : rhs_shape) { - if (dim == -1) { + if (dim == mlir::ShapedType::kDynamic) { return failure(); } } @@ -260,9 +261,9 @@ LogicalResult ConvertTFBatchMatMulOp::matchAndRewrite( auto matmul = rewriter.create(loc, matmul_type, /*a=*/sliced_lhs[lhs_batch_idx], /*b=*/sliced_rhs[rhs_batch_idx], - /*transpose_a=*/op.adj_x(), - /*transpose_b=*/op.adj_y()); - matmuls.emplace_back(matmul.product()); + /*transpose_a=*/op.getAdjX(), + /*transpose_b=*/op.getAdjY()); + matmuls.emplace_back(matmul.getProduct()); } // Combine the result of each individual MatMul into a rank-3 tensor. @@ -279,9 +280,9 @@ LogicalResult ConvertTFBatchMatMulOp::matchAndRewrite( result_shape.push_back(rows); result_shape.push_back(cols); - auto reshape_op = createReshapeOp(pack_op.output(), result_shape, + auto reshape_op = createReshapeOp(pack_op.getOutput(), result_shape, element_type, loc, rewriter); - rewriter.replaceOp(op, reshape_op.output()); + rewriter.replaceOp(op, reshape_op.getOutput()); return success(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/update_control_dependencies.cc b/tensorflow/compiler/mlir/tensorflow/transforms/update_control_dependencies.cc index 4455152d708..f1403fc3a80 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/update_control_dependencies.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/update_control_dependencies.cc @@ -14,9 +14,12 @@ limitations under the License. ==============================================================================*/ #include -#include +#include #include +#include "absl/container/btree_set.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project @@ -25,7 +28,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h" #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/verify_suitable_for_graph_export.h" @@ -34,13 +36,28 @@ namespace mlir { namespace tf_executor { namespace { +// Comparator for `OpsInReverseProgramOrder`. +struct IsAfterInBlock { + bool operator()(Operation* op, Operation* other_op) const { + // This function has an average complexity of O(1). + return other_op->isBeforeInBlock(op); + } +}; + +// Maps group IDs to branch IDs. +using GroupIdToBranchIdMap = absl::flat_hash_map; +// Maps an op to parallel execution IDs. +using OpToParallelIdsMap = + absl::flat_hash_map; +// Maps an op to a set of ops. +using OpToOpsMap = + absl::flat_hash_map>; +// Represents a set of ops in reverse program order. 
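The IsAfterInBlock comparator introduced above makes a set iterate in reverse program order; the pass later walks the collected predecessors backwards so control inputs end up appended in program order. A standalone illustration of the ordering trick, with an op's block position modeled by an integer and std::set standing in for absl::btree_set:

#include <iostream>
#include <set>

// Stand-in for an MLIR op; its position in the block is just an integer here.
struct FakeOp { int position_in_block; };

// "op" sorts before "other_op" iff it appears *after* it in the block,
// mirroring `return other_op->isBeforeInBlock(op);` in the pass.
struct IsAfterInBlockSketch {
  bool operator()(const FakeOp* op, const FakeOp* other_op) const {
    return other_op->position_in_block < op->position_in_block;
  }
};

int main() {
  FakeOp a{0}, b{1}, c{2};
  std::set<const FakeOp*, IsAfterInBlockSketch> preds = {&b, &a, &c};
  for (const FakeOp* op : preds)
    std::cout << op->position_in_block << " ";  // prints: 2 1 0
}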
+using OpsInReverseProgramOrder = absl::btree_set; + #define GEN_PASS_DEF_EXECUTORUPDATECONTROLDEPENDENCIESPASS #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc" -// Note that `SetVector` provides efficient lookup and deletion as well as -// deterministic iteration order which we need here. -using OpToIslandsMap = llvm::DenseMap>; - class UpdateControlDependenciesPass : public impl::ExecutorUpdateControlDependenciesPassBase< UpdateControlDependenciesPass> { @@ -48,144 +65,222 @@ class UpdateControlDependenciesPass void runOnOperation() override; }; -// Returns true iff the islands are guaranteed to have different devices -// assigned. -bool HaveDifferentDevices(IslandOp first_island, IslandOp second_island) { - Operation& first_op = first_island.GetBody().front(); - Operation& second_op = second_island.GetBody().front(); - llvm::SmallVector parsed_names; - - for (Operation* op : {&first_op, &second_op}) { - auto device_attr = op->getAttrOfType(tensorflow::kDeviceAttr); - // For empty device we can't guarantee that devices are different. - if (!device_attr || device_attr.getValue().empty()) return false; - - tensorflow::DeviceNameUtils::ParsedName parsed_name; - bool success = tensorflow::DeviceNameUtils::ParseFullOrLocalName( - device_attr.getValue(), &parsed_name); - // If parsing was not successful, then we can't guarantee that devices are - // different. - if (!success) return false; - parsed_names.push_back(parsed_name); - } - // If device names are not compatible, then corresponding devices must be - // different. - return !tensorflow::DeviceNameUtils::AreCompatibleDevNames(parsed_names[0], - parsed_names[1]); +const GroupIdToBranchIdMap& EmptyGroupIdToBranchIdMap() { + // clang-format off + static auto* empty_map = new absl::flat_hash_map{}; + return *empty_map; } -// Returns true iff we should ignore a dependency between both islands. -bool ShouldIgnoreDependency(IslandOp first_island, IslandOp second_island) { - return HaveDifferentDevices(first_island, second_island); +// Returns map whose elements are the (group ID,branch ID) pairs for `op`. +const GroupIdToBranchIdMap& GetGroupIdToBranchIdMap( + Operation* op, const OpToParallelIdsMap& op_to_parallel_ids_map) { + auto iter = op_to_parallel_ids_map.find(op); + if (iter == op_to_parallel_ids_map.end()) return EmptyGroupIdToBranchIdMap(); + return iter->second; } -// Collects direct control predecessors per op by querying side effect analysis. -// -// We only collect control predecessor that are islands, others (if any) are -// irrelevant for this pass. -void CollectDirectControlPredecessors( - Operation* op, const TF::SideEffectAnalysis::Info& analysis_for_func, - OpToIslandsMap& control_predecessors_map) { - for (Operation* control_predecessor : - analysis_for_func.DirectControlPredecessors(op)) { - if (auto control_pred_island = - dyn_cast(control_predecessor)) { - control_predecessors_map[op].insert(control_pred_island); +// Returns true iff a control dependency between both ops is considered valid, +// depending on their parallel execution IDs. +// A control dependency is invalid if both ops share a common parallel execution +// group with different branch IDs (in that case, the ops are expected to run in +// parallel). 
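A standalone sketch of the validity rule documented above, with std::unordered_map standing in for absl::flat_hash_map: a control dependency is dropped exactly when the two ops share a parallel-execution group but carry different branch IDs.

#include <cassert>
#include <unordered_map>

// group ID -> branch ID, per op.
using GroupToBranch = std::unordered_map<int, int>;

bool IsValidDependencySketch(const GroupToBranch& op_ids,
                             const GroupToBranch& other_op_ids) {
  for (const auto& [group_id, branch_id] : op_ids) {
    auto it = other_op_ids.find(group_id);
    if (it != other_op_ids.end() && it->second != branch_id) return false;
  }
  return true;
}

int main() {
  assert(IsValidDependencySketch({{0, 0}}, {{1, 5}}));   // disjoint groups
  assert(IsValidDependencySketch({{0, 1}}, {{0, 1}}));   // same group, same branch
  assert(!IsValidDependencySketch({{0, 0}}, {{0, 1}}));  // same group, different branch
}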
+bool IsValidDependency(Operation* op, Operation* other_op, + const OpToParallelIdsMap& op_to_parallel_ids_map) { + const GroupIdToBranchIdMap& parallel_ids_map = + GetGroupIdToBranchIdMap(op, op_to_parallel_ids_map); + const GroupIdToBranchIdMap& other_parallel_ids_map = + GetGroupIdToBranchIdMap(other_op, op_to_parallel_ids_map); + + for (auto [group_id, branch_id] : parallel_ids_map) { + auto iter = other_parallel_ids_map.find(group_id); + // `other_op` has same group as `op`, with different branch ID. + if (iter != other_parallel_ids_map.end() && iter->second != branch_id) { + return false; } } + // The ops don't share a common group with different branch IDs. + return true; } -// Propagates control predecessors for cases where we don't want to create a -// control dependency even though side effect analysis sees a dependency. -// -// Currently, this is the case for ops with different assigned devices: It can -// happen that side effect analysis sees a dependency because the ops may use -// the same resource (which is basically a modeling issue we have to work -// around here). In such a case, we ignore the dependency, but we have to make -// sure that we don't lose any indirect dependencies we want to keep. -// For example, say side effect analysis sees dependencies A -> B -> C, and A -// and C have the same assigned device and B has a different assigned device. -// Then we want to ignore the dependencies A -> B and B -> C but keep the -// transitive dependency A -> C. -// This function updates `control_predecessors_map` such that this is always the -// case. -void PropagateControlPredecessors( - IslandOp island, const TF::SideEffectAnalysis::Info& analysis_for_func, - OpToIslandsMap& control_predecessors_map) { - // Find control predecessors we want to ignore and mark them for propagation. - llvm::SmallVector control_predecessors_to_propagate; - for (IslandOp control_pred_island : control_predecessors_map[island]) { - if (ShouldIgnoreDependency(island, control_pred_island)) { - control_predecessors_to_propagate.push_back(control_pred_island); +void ClearControlInputs(Operation* op, int& num_control_inputs_removed) { + // We only call this function for island or fetch ops. The second pair of + // parentheses is needed for successful compilation. + assert((isa(op))); + if (auto island = dyn_cast(op)) { + num_control_inputs_removed += island.getControlInputs().size(); + island.getControlInputsMutable().clear(); + } else if (auto fetch = dyn_cast(op)) { + GraphOp graph = fetch->getParentOfType(); + int num_control_fetches = fetch.getNumOperands() - graph.getNumResults(); + if (num_control_fetches > 0) { + fetch.getFetchesMutable().erase(graph.getNumResults(), + num_control_fetches); + num_control_inputs_removed += num_control_fetches; } } - // For all control predecessors to propagate, remove them from island's - // control predecessors and add them as control predecessors for all control - // successors of island (this is to make sure we don't lose any transitive - // dependencies). - for (IslandOp control_pred_island : control_predecessors_to_propagate) { - control_predecessors_map[island].remove(control_pred_island); - for (Operation* control_successor : - analysis_for_func.DirectControlSuccessors(island)) { - control_predecessors_map[control_successor].insert(control_pred_island); +} + +void SetControlInputs( + Operation* op, + const llvm::SmallVector& preds_in_reverse_program_order, + int& num_control_inputs_added) { + // We only call this function for island or fetch ops. 
The second pair of + // parentheses is needed for successful compilation. + assert((isa(op))); + mlir::MutableOperandRange mutable_control_inputs = + isa(op) ? cast(op).getControlInputsMutable() + : cast(op).getFetchesMutable(); + // Add control inputs in program order of the defining ops. + for (auto iter = preds_in_reverse_program_order.rbegin(); + iter != preds_in_reverse_program_order.rend(); + ++iter) { + Operation* pred = *iter; + if (auto pred_island = dyn_cast(pred)) { + mutable_control_inputs.append(pred_island.getControl()); } } + num_control_inputs_added += preds_in_reverse_program_order.size(); } -void UpdateAllControlDependencies( - func::FuncOp func, const TF::SideEffectAnalysis::Info& analysis_for_func) { - int control_inputs_added = 0; - llvm::SmallVector new_control_inputs; - llvm::SmallVector fetch_control_predecessors; - - OpToIslandsMap control_predecessors_map; - auto graph_op = cast(func.front().front()); - graph_op.walk([&](Operation* op) { - if (!isa(op)) return WalkResult::advance(); - CollectDirectControlPredecessors(op, analysis_for_func, - control_predecessors_map); - if (auto island = dyn_cast(op)) { - PropagateControlPredecessors(island, analysis_for_func, - control_predecessors_map); +// Fills `op_to_parallel_ids_map` from parallel execution attributes in `graph`. +// Returns `failure` iff any attribute is malformed. +LogicalResult FillOpToParallelIdsMap( + GraphOp graph, OpToParallelIdsMap& op_to_parallel_ids_map) { + for (Operation& op : graph.GetBody()) { + auto island = dyn_cast(&op); + if (!island) continue; + + // We call `VerifyExportSuitable` in the beginning of the pass, so every + // island wraps a single op. + Operation& wrapped_op = island.GetBody().front(); + TF::ParallelExecutionIdPairs id_pairs; + if (failed(TF::ParseParallelExecutionIds(&wrapped_op, id_pairs))) { + wrapped_op.emitError() + << "Malformed " << TF::kParallelExecAnnotation << " attribute"; + return failure(); } - return WalkResult::advance(); - }); - - graph_op.walk([&](IslandOp island) { - // Update control inputs for island. - for (Operation* control_predecessor : control_predecessors_map[island]) { - if (auto control_pred_island = - dyn_cast(control_predecessor)) { - new_control_inputs.push_back(control_pred_island.getControl()); + if (id_pairs.empty()) continue; + + GroupIdToBranchIdMap& ids_map = op_to_parallel_ids_map[island]; + for (auto [group_id, branch_id] : id_pairs) ids_map[group_id] = branch_id; + } + return success(); +} + +// Computes and sets direct control inputs for `op`. Also fills +// `active_transitive_preds` and `inactive_transitive_preds` for `op`. +void +UpdateControlDependenciesForOp( + Operation* op, const TF::SideEffectAnalysis::Info& analysis_for_func, + const OpToParallelIdsMap& op_to_parallel_ids_map, + OpToOpsMap& active_transitive_preds, + OpToOpsMap& inactive_transitive_preds, + int& num_control_inputs_removed, + int& num_control_inputs_added, + int& num_invalid_dependencies) { + OpsInReverseProgramOrder potential_preds; + active_transitive_preds[op].insert(op); + for (Operation* pred : analysis_for_func.DirectControlPredecessors(op)) { + // Propagate inactive transitive dependencies from `pred` to `op`. + inactive_transitive_preds[op].insert( + inactive_transitive_preds[pred].begin(), + inactive_transitive_preds[pred].end()); + // Inactive transitive predecessors of `pred` are potential direct + // predecessors of `op` (they are not tracked by `pred`). 
+ for (Operation* transitive_pred : inactive_transitive_preds[pred]) { + potential_preds.insert(transitive_pred); + } + if (IsValidDependency(pred, op, op_to_parallel_ids_map)) { + // We know that any active transitive predecessors will still be covered + // by (pred, op), so we don't have to add them to `potential_preds`. + potential_preds.insert(pred); + } else { + // Active transitive predecessors will not be covered by (pred, op) + // anymore, so add them all as candidates. + for (Operation* transitive_pred : active_transitive_preds[pred]) { + potential_preds.insert(transitive_pred); } + ++num_invalid_dependencies; } - // None of the originally given control deps are necessary. - island.getControlInputsMutable().clear(); - island.getControlInputsMutable().append(new_control_inputs); - control_inputs_added += new_control_inputs.size(); - new_control_inputs.clear(); - }); - - // Update control inputs for fetch op. - FetchOp fetch_op = graph_op.GetFetch(); - - // None of the originally given control deps are necessary. - int num_control_fetches = - fetch_op.getNumOperands() - graph_op.getNumResults(); - if (num_control_fetches > 0) { - fetch_op.getFetchesMutable().erase(graph_op.getNumResults(), - num_control_fetches); } - for (Operation* control_predecessor : control_predecessors_map[fetch_op]) { - if (auto control_pred_island = - dyn_cast(control_predecessor)) { - new_control_inputs.push_back(control_pred_island.getControl()); + llvm::SmallVector preds_in_reverse_program_order; + for (Operation* potential_pred : potential_preds) { + bool is_valid = + IsValidDependency(potential_pred, op, op_to_parallel_ids_map); + if (!is_valid) { + // We don't keep the (pred, op) dependency, so all active transitive + // dependencies become inactive. + inactive_transitive_preds[op].insert( + active_transitive_preds[potential_pred].begin(), + active_transitive_preds[potential_pred].end()); + } else if (!active_transitive_preds[op].contains(potential_pred)) { + // `potential_pred` is not an active transitive predecessor of `op` yet, + // so we must add it as a direct predecessor. + preds_in_reverse_program_order.push_back(potential_pred); + // We keep the (pred, op) dependency, so all active transitive + // dependencies stay active. + active_transitive_preds[op].insert( + active_transitive_preds[potential_pred].begin(), + active_transitive_preds[potential_pred].end()); } } - control_inputs_added += new_control_inputs.size(); - fetch_op.getFetchesMutable().append(new_control_inputs); + ClearControlInputs(op, num_control_inputs_removed); + SetControlInputs(op, preds_in_reverse_program_order, + num_control_inputs_added); +} - VLOG(2) << "Number of control inputs added: " << control_inputs_added; +// This function updates all control dependencies in `func`, represented as +// control inputs for island and fetch ops of the graph body in `func`. +// Ideally, we would purely rely on side effect analysis here and propagate +// the queried dependencies to the island and fetch ops. However, this is +// currently not in line with execution semantics in case of replication and +// parallel executes: If two ops originated from different branches of a +// `tf_device.replicate` or `tf_device.parallel_execute` op, then there should +// be no control dependency between them irrespective of side effects, even if +// this could cause a race condition (see b/262304795). 
+// Because of this, we need to keep track of the origin of such ops which we do +// via `kParallelExecAnnotation` attributes that are interpreted in this pass. +// +// NOTE: This pass guarantees the minimum number of control inputs. Its runtime +// and space complexity can be quadratic in the number of side-effecting ops per +// function. If that becomes a problem in practice, we could look into speed-ups +// used for `DependencyOptimizer::TransitiveReduction` which solves a similar +// problem and also has worst-case quadratic runtime and space complexity. +// Alternatively, we could allow redundant control inputs (less bookkeeping). +LogicalResult UpdateAllControlDependencies( + func::FuncOp func, const TF::SideEffectAnalysis::Info& analysis_for_func) { + int num_control_inputs_removed = 0; + int num_control_inputs_added = 0; + int num_invalid_dependencies = 0; + + // Maps island ops to parallel IDs of the wrapped ops. + OpToParallelIdsMap op_to_parallel_ids_map; + OpToOpsMap active_transitive_preds, inactive_transitive_preds; + + // We call `VerifyExportSuitable` in the beginning of the pass, so every + // function has a single graph op. + auto graph = cast(func.front().front()); + if (failed(FillOpToParallelIdsMap(graph, op_to_parallel_ids_map))) { + return failure(); + } + for (Operation& op_ref : graph.GetBody()) { + Operation* op = &op_ref; + // We only represent control dependencies between island and fetch ops. + if (!isa(op)) continue; + UpdateControlDependenciesForOp( + op, + analysis_for_func, + op_to_parallel_ids_map, + active_transitive_preds, + inactive_transitive_preds, + num_control_inputs_removed, + num_control_inputs_added, + num_invalid_dependencies); + } + VLOG(2) << "Number of control inputs removed: " << num_control_inputs_removed; + VLOG(2) << "Number of control inputs added: " << num_control_inputs_added; + VLOG(2) << "Number of invalid dependencies: " << num_invalid_dependencies; + return success(); } void UpdateControlDependenciesPass::runOnOperation() { @@ -198,13 +293,14 @@ void UpdateControlDependenciesPass::runOnOperation() { return; } TF::SideEffectAnalysis side_effect_analysis(module); - - // Recompute control dependencies between all islands. for (auto func : module.getOps()) { if (func.isExternal()) continue; const auto& analysis_for_func = side_effect_analysis.GetAnalysisForFunc(func); - UpdateAllControlDependencies(func, analysis_for_func); + if (failed(UpdateAllControlDependencies(func, analysis_for_func))) { + signalPassFailure(); + return; + } } } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/xla_cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/xla_cluster_formation.cc index 2c759132900..46e91eaeb51 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/xla_cluster_formation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/xla_cluster_formation.cc @@ -13,12 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include +#include + #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/call_graph_util.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" +inline constexpr absl::string_view kEntryFunctionAttr = "tf.entry_function"; + namespace mlir { namespace { @@ -32,36 +39,55 @@ struct XlaClusterFormationPass void runOnOperation() override; }; -void EncapsulatePartitionedCall(TF::StatefulPartitionedCallOp call_op) { - mlir::OpBuilder builder(call_op); +void EncapsulatePartitionedCall(Operation *call_op) { + OpBuilder builder(call_op); - auto cluster = builder.create( - call_op.getLoc(), call_op.getResultTypes()); + auto cluster = builder.create( + call_op->getLoc(), call_op->getResultTypes()); - call_op.replaceAllUsesWith(cluster.getResults()); + call_op->replaceAllUsesWith(cluster.getResults()); - cluster.getBody().push_back(new mlir::Block); + cluster.getBody().push_back(new Block); - call_op.getOperation()->moveBefore(&cluster.GetBody(), - cluster.GetBody().end()); + call_op->moveBefore(&cluster.GetBody(), cluster.GetBody().end()); builder.setInsertionPointToEnd(&cluster.GetBody()); - builder.create(call_op.getLoc(), - call_op->getResults()); + builder.create(call_op->getLoc(), call_op->getResults()); } void XlaClusterFormationPass::runOnOperation() { ModuleOp module = getOperation(); - - llvm::SmallVector ops; - module.walk([&](TF::StatefulPartitionedCallOp call_op) { - if (call_op->hasAttr(tensorflow::kCompileDeviceTypeAttr)) { - ops.push_back(call_op); + SymbolTable symtab(module); + + llvm::SmallVector entry_funcs; + // A model may have multiple graphs, with each graph having its own entry. + // When a graph is imported to MLIR, `tf.entry_function` will be added to + // each entry function. The one exception are initializer functions, which + // have `tf_saved_model.initializer_type` instead. + module.walk([&](func::FuncOp func) { + if (func->hasAttr(kEntryFunctionAttr) || + func->hasAttr(tf_saved_model::kTfSavedModelInitializerTypeAttr)) { + entry_funcs.push_back(func); } }); - - for (auto call_op : ops) { - EncapsulatePartitionedCall(call_op); + if (entry_funcs.empty()) { + LOG(WARNING) << "no entry function is found"; + } + auto predicate = [](Operation *op) { + if (op->hasAttr(tensorflow::kCompileDeviceTypeAttr)) return true; + return false; + }; + for (auto &root : entry_funcs) { + llvm::SmallVector outermost_call_ops; + if (failed(GetOutermostOpsOfType( + root, symtab, outermost_call_ops, predicate))) + return signalPassFailure(); + // Cluster outermost partitioned calls with _xla_compile_device_type + // attribute. + for (auto &call_op : outermost_call_ops) { + EncapsulatePartitionedCall(call_op); + } } } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/xla_rewrite.cc b/tensorflow/compiler/mlir/tensorflow/transforms/xla_rewrite.cc index 6ff5363765e..550e5804430 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/xla_rewrite.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/xla_rewrite.cc @@ -13,16 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -// This transformation pass converts stateful partitioned calls with -// _xla_compile_device_type attribute to XLA launch ops. +// This transformation pass converts stateful and stateless paritioned calls +// with _xla_compile_device_type attribute to XLA launch ops. -#include -#include -#include +#include #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/call_graph_util.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" @@ -38,8 +37,8 @@ struct XlaRewritePass : public impl::XlaRewritePassBase { void runOnOperation() override; }; -void MoveResourceArgsToEnd(mlir::func::FuncOp callee) { - llvm::DenseMap mapping; +void MoveResourceArgsToEnd(func::FuncOp callee) { + llvm::DenseMap mapping; unsigned num_params = callee.getNumArguments(); llvm::BitVector removed_params(num_params); // Copy the resource-type parameters to the end. @@ -65,10 +64,10 @@ template ::value>::type * = nullptr> -void Rewrite(OpT pcall_op, SymbolTable &symtab) { +void RewriteCall(OpT call_op, SymbolTable &symtab) { llvm::SmallVector non_resource_args, resource_args; bool has_resources = false, in_order = true; - for (const mlir::Value &arg : pcall_op.args()) { + for (const Value &arg : call_op.getArgs()) { if (!getElementTypeOrSelf(arg.getType()).template isa()) { non_resource_args.push_back(arg); if (has_resources) in_order = false; @@ -79,39 +78,52 @@ void Rewrite(OpT pcall_op, SymbolTable &symtab) { } if (!in_order) { - // Functions do not get reused in practise, so skip the check for if the + // Functions do not get reused in practice, so skip the check for if the // callee has been updated. 
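MoveResourceArgsToEnd and the argument split in RewriteCall above separate non-resource from resource operands and, when they are interleaved, reorder the callee's parameters so resources come last while preserving relative order within each group. A standalone sketch of that reordering on plain strings via a stable partition (the "res" naming convention is invented for the sketch; the pass checks the element type instead):

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// An argument is a "resource" here iff its name starts with "res".
bool IsResource(const std::string& arg) { return arg.rfind("res", 0) == 0; }

int main() {
  std::vector<std::string> args = {"x", "res0", "y", "res1"};
  auto is_non_resource = [](const std::string& a) { return !IsResource(a); };
  // in_order == true means all non-resource args already precede resources.
  bool in_order =
      std::is_partitioned(args.begin(), args.end(), is_non_resource);
  if (!in_order)
    std::stable_partition(args.begin(), args.end(), is_non_resource);
  for (const auto& a : args) std::cout << a << " ";  // prints: x y res0 res1
}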
- StringAttr callee_sym = cast(pcall_op.fAttr()).getAttr(); - MoveResourceArgsToEnd(cast(symtab.lookup(callee_sym))); + StringAttr callee_sym = + cast(call_op.getFAttr()).getRootReference(); + MoveResourceArgsToEnd(symtab.lookup(callee_sym)); } - OpBuilder builder(pcall_op->getContext()); - builder.setInsertionPoint(pcall_op); + OpBuilder builder(call_op->getContext()); + builder.setInsertionPoint(call_op); auto xla_launch_op = builder.create( - pcall_op.getLoc(), pcall_op.getResultTypes(), + call_op.getLoc(), call_op.getResultTypes(), /*constants=*/ValueRange({}), ValueRange(non_resource_args), - ValueRange(resource_args), pcall_op.fAttr()); + ValueRange(resource_args), call_op.getFAttr()); - CopyDeviceAndUnderscoredAttributes(pcall_op, xla_launch_op); - pcall_op.replaceAllUsesWith(xla_launch_op.getResults()); - pcall_op.erase(); + CopyDeviceAndUnderscoredAttributes(call_op, xla_launch_op); + call_op.replaceAllUsesWith(xla_launch_op.getResults()); + call_op.erase(); } void XlaRewritePass::runOnOperation() { - mlir::ModuleOp module = getOperation(); + ModuleOp module = getOperation(); SymbolTable symtab(module); + module.walk([&](tf_device::ClusterOp cluster_op) { + cluster_op.getBody().walk([&](mlir::Operation *op) { + if (auto call_op = llvm::dyn_cast(op)) { + RewriteCall(call_op, symtab); + } else if (auto call_op = llvm::dyn_cast(op)) { + RewriteCall(call_op, symtab); + } + }); + }); - module.walk([&](mlir::Operation *op) { - if (!op->hasAttr(tensorflow::kCompileDeviceTypeAttr)) - return WalkResult::advance(); - if (auto pcall_op = dyn_cast(op)) { - Rewrite(pcall_op, symtab); - } else if (auto stateful_pcall_op = - dyn_cast(op)) { - Rewrite(stateful_pcall_op, symtab); + // Verify that there are no nested XLA launch ops. + module.walk([&](TF::XlaLaunchOp xla_launch_op) { + llvm::SmallVector nested_launch_ops; + func::FuncOp root = symtab.lookup( + xla_launch_op.getFunctionAttr().getRootReference()); + if (failed(GetOutermostOpsOfType(root, symtab, + nested_launch_ops))) + return signalPassFailure(); + if (!nested_launch_ops.empty()) { + xla_launch_op.emitError() << "Nested XLA launch ops detected"; + return signalPassFailure(); } - return WalkResult::advance(); }); } + } // namespace namespace TFDevice { diff --git a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc index 307257ca0c1..8e36d069930 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc @@ -177,7 +177,7 @@ void PopulateEmptyIsland(tf_executor::IslandOp island) { Value operand = yield.getOperand(0); auto identity = builder.create(island.getLoc(), operand.getType(), operand); - yield.setOperand(0, identity.output()); + yield.setOperand(0, identity.getOutput()); } else { auto identity_n = builder.create( island.getLoc(), yield.getOperandTypes(), yield.getOperands()); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc index 48187ad18a1..4aff79c8585 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc @@ -76,7 +76,7 @@ using mlir::Operation; using mlir::SymbolTable; using mlir::Value; using mlir::func::FuncOp; -using stream_executor::port::StatusOr; +using tsl::StatusOr; namespace { @@ -220,7 +220,7 @@ StatusOr> Exporter::GetArgumentNode( *node_def->mutable_device() = 
device_attr.getValue().str(); llvm::ArrayRef func_arg_i_attrs = - func.getArgAttrs(index); + mlir::function_interface_impl::getArgAttrs(func, index); absl::flat_hash_set attrs_to_ignore = {kDeviceAttr, kAliasingAttr}; TF_RETURN_IF_ERROR(ConvertAttributes(func_arg_i_attrs, attrs_to_ignore, @@ -279,7 +279,8 @@ Status Exporter::AddEdgeBetweenNodes(Value src, Node* dst_node, TF_RET_CHECK(node_it != nodes_.end()) << "Use of OpResult encountered before def!"; if (input_result.getType().isa()) { - graph_->AddControlEdge(node_it->second, dst_node); + graph_->AddControlEdge(node_it->second, dst_node, + /*allow_duplicates=*/true); } else { graph_->AddEdge(node_it->second, input_result.getResultNumber(), dst_node, dst_index); @@ -769,7 +770,7 @@ StatusOr> ConvertMlirToGraphdef( return graphdef; } -stream_executor::port::Status ConvertMlirFunctionToFunctionLibraryDef( +tsl::Status ConvertMlirFunctionToFunctionLibraryDef( FuncOp func, const GraphExportConfig& configs, FunctionDef* function_def) { Dialect* tf_dialect = func.getContext()->getLoadedDialect("tf"); FunctionDefLibrary flib; diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h index d0768969254..562226e1c76 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h @@ -23,7 +23,6 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -31,8 +30,8 @@ limitations under the License. namespace tensorflow { // Given an MLIR module, returns a GraphDef. -stream_executor::port::StatusOr> -ConvertMlirToGraphdef(mlir::ModuleOp module, const GraphExportConfig& configs); +tsl::StatusOr> ConvertMlirToGraphdef( + mlir::ModuleOp module, const GraphExportConfig& configs); // Converts an MLIR module to TensorFlow graph and FunctionLibraryDefinition. // The "main" function of the module is stored in the graph and the rest of diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc index f80bedcd602..04120e81c7b 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc @@ -74,7 +74,9 @@ void SetShapeAttribute(absl::string_view name, ContainerT shapes, for (const llvm::Optional>& shape : shapes) { TensorShapeProto& tshape = *shape_list.add_shape(); if (shape.has_value()) { - for (int64_t dim : *shape) tshape.add_dim()->set_size(dim); + for (int64_t dim : *shape) { + tshape.add_dim()->set_size(mlir::ShapedType::isDynamic(dim) ? -1 : dim); + } } else { tshape.set_unknown_rank(true); } diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h index 72e47cd2f60..ae87dad305a 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h @@ -19,7 +19,6 @@ limitations under the License. 
#include "llvm/ADT/StringRef.h" #include "mlir/IR/Operation.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 7eaf8457e81..912404b2bca 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -91,7 +91,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/shape_refiner.h" @@ -141,15 +140,17 @@ namespace tensorflow { constexpr size_t kNumThreadToConvertSignatures = 10; constexpr absl::string_view kOutputShapesAttrName = "_output_shapes"; -using mlir::NamedAttrList; -using mlir::TensorType; -using mlir::tf_saved_model::AssetOp; -using mlir::tf_saved_model::GlobalTensorOp; -using mlir::tf_saved_model::kTfSavedModelInitializerInitType; -using mlir::tf_saved_model::kTfSavedModelInitializerRestoreType; -using mlir::tf_saved_model::kTfSavedModelInitializerTypeAttr; -using mlir::tf_saved_model::SessionInitializerOp; -using stream_executor::port::StatusOr; +using ::mlir::NamedAttrList; +using ::mlir::TensorType; +using ::mlir::tf_saved_model::AssetOp; +using ::mlir::tf_saved_model::GlobalTensorOp; +using ::mlir::tf_saved_model::kTfSavedModelExportedNamesAttr; +using ::mlir::tf_saved_model::kTfSavedModelIndexPathAttr; +using ::mlir::tf_saved_model::kTfSavedModelInitializerInitType; +using ::mlir::tf_saved_model::kTfSavedModelInitializerRestoreType; +using ::mlir::tf_saved_model::kTfSavedModelInitializerTypeAttr; +using ::mlir::tf_saved_model::SessionInitializerOp; +using ::tsl::StatusOr; namespace { @@ -1281,7 +1282,7 @@ StatusOr ImporterBase::ConvertElementTypeAndShape( } return GetTypeFromTFTensorShape( - llvm::makeArrayRef(dimensions.begin(), dimensions.end()), element_type); + llvm::ArrayRef(dimensions.begin(), dimensions.end()), element_type); } StatusOr ImporterBase::ConvertSubtypes( @@ -1359,7 +1360,7 @@ StatusOr ImporterBase::ConvertAttributeValue( if (attr) attrs.push_back(attr); } return builder_.getArrayAttr( - llvm::makeArrayRef(attrs.begin(), attrs.end())); + llvm::ArrayRef(attrs.begin(), attrs.end())); } return ConvertNonFuncAttributeValue(value, &builder_); } @@ -1797,7 +1798,7 @@ mlir::Location ImporterBase::GetLocation(const Node& node) { // If there are more locations then generate a stack trace, otherwise just // return the name loc. - auto callsite_locs = llvm::makeArrayRef(locations).drop_front(); + auto callsite_locs = llvm::ArrayRef(locations).drop_front(); return callsite_locs.empty() ? 
node_name_loc : mlir::CallSiteLoc::get(node_name_loc, callsite_locs); @@ -1987,11 +1988,7 @@ mlir::Operation* ImporterBase::CreateOperation( resource = true; return true; } - if (auto with_subtype = - type.dyn_cast()) { - with_subtype.walkSubTypes( - [&](mlir::Type t) { record_resource(t); }); - } + type.walk([&](mlir::Type t) { record_resource(t); }); return resource; }; @@ -2419,7 +2416,7 @@ StatusOr> GraphDefImporter::Convert( crash_analysis::RemoveReportData(flib_crash_handle); }); - VLOG(1) << "Importing: " + VLOG(2) << "Importing: " << ::tensorflow::DumpGraphToFile("tf_mlir_importer_base", graph, &flib_def); @@ -2522,7 +2519,7 @@ StatusOr> GraphDefImporter::Convert( : mlir::func::FuncOp::Visibility::Private; function.setVisibility(visibility); } - VLOG(1) << "Imported: " + VLOG(2) << "Imported: " << tensorflow::DumpMlirOpToFile("tf_mlir_imported_base", module.get()); return module; @@ -3087,7 +3084,7 @@ StatusOr> StructuredValueLinearizer::GetLeafIndexPaths( llvm::StringRef error_context) const { if (error_message_.empty()) { - return llvm::makeArrayRef(leaf_index_paths_); + return llvm::ArrayRef(leaf_index_paths_); } return errors::InvalidArgument( error_context.str(), error_message_, @@ -3377,7 +3374,7 @@ Status CreateSavedModelIR( call.getResults()); } func->setAttr( - "tf_saved_model.exported_names", + kTfSavedModelExportedNamesAttr, builder.getStrArrayAttr(object_names.GetExportedNames(node_id))); const SavedConcreteFunction& concrete_function = object_graph.concrete_functions().at(function.concrete_functions(0)); @@ -3410,7 +3407,7 @@ Status CreateSavedModelIR( " vs ", bound_input_base, ")"); } for (auto index_path : llvm::enumerate(input_index_paths)) { - func.setArgAttr(index_path.index(), "tf_saved_model.index_path", + func.setArgAttr(index_path.index(), kTfSavedModelIndexPathAttr, index_path.value()); } @@ -3437,7 +3434,7 @@ Status CreateSavedModelIR( " vs ", func.getNumResults(), ")"); } for (auto index_path : llvm::enumerate(output_index_paths)) { - func.setResultAttr(index_path.index(), "tf_saved_model.index_path", + func.setResultAttr(index_path.index(), kTfSavedModelIndexPathAttr, index_path.value()); } } else if (object.kind_case() == SavedObject::kVariable) { @@ -3476,7 +3473,7 @@ Status CreateSavedModelIR( /*type=*/mlir::TypeAttr::get(type), /*is_mutable=*/builder.getUnitAttr()); op->setAttr( - "tf_saved_model.exported_names", + kTfSavedModelExportedNamesAttr, builder.getStrArrayAttr(object_names.GetExportedNames(node_id))); } else if (object.kind_case() == SavedObject::kConstant) { const SavedConstant& constant = object.constant(); @@ -3496,7 +3493,7 @@ Status CreateSavedModelIR( /*type=*/mlir::TypeAttr::get(value_attr.getType()), /*is_mutable=*/nullptr); op->setAttr( - "tf_saved_model.exported_names", + kTfSavedModelExportedNamesAttr, builder.getStrArrayAttr(object_names.GetExportedNames(node_id))); } } @@ -3885,7 +3882,7 @@ Status SavedModelSignatureDefImporterLite::ConvertInitializer( // Set the exported name of init function to an reserved name for // tf_saved_model. init_func_op->setAttr( - "tf_saved_model.exported_names", + kTfSavedModelExportedNamesAttr, builder.getStrArrayAttr({absl::StrCat( "__tf_saved_model_session_initializer_", target_node_name)})); init_func_op->setAttr(kTfSavedModelInitializerTypeAttr, @@ -3958,20 +3955,30 @@ Status SavedModelSignatureDefImporterLite::ConvertSignature( << sig_def_key << "."; // Use unique SignatureDef key as exported name. 
- func_op->setAttr("tf_saved_model.exported_names", + func_op->setAttr(kTfSavedModelExportedNamesAttr, builder.getStrArrayAttr({sig_def_key})); // Transfer input and output parameter names to index_path attributes. for (auto input_and_idx : llvm::enumerate(inputs)) { - func_op.setArgAttr(input_and_idx.index(), "tf_saved_model.index_path", + func_op.setArgAttr(input_and_idx.index(), kTfSavedModelIndexPathAttr, builder.getStrArrayAttr({input_and_idx.value().first})); } for (auto output_and_idx : llvm::enumerate(outputs)) { func_op.setResultAttr( - output_and_idx.index(), "tf_saved_model.index_path", + output_and_idx.index(), kTfSavedModelIndexPathAttr, builder.getStrArrayAttr({output_and_idx.value().first})); } + // Add the original TF function name as a function attribute. + // TODO(b/258817244) Remove this after TFRT exports functions. + for (const auto& [tf_name, mlir_name] : tf_name_to_mlir_name) { + auto func_op = sub_symbol_table.lookup(mlir_name); + TF_RET_CHECK(func_op) + << "Graphdef importer should have created a function named " + << mlir_name << "."; + func_op->setAttr("tf._original_func_name", builder.getStringAttr(tf_name)); + } + // Move the converted functions to top level MLIR module. return MoveConvertedFunctionsToModule(sig_def_key, *sub_module, tf_name_to_mlir_name); @@ -4248,10 +4255,9 @@ StatusOr> ConvertGraphToMlir( tf_name_to_mlir_name); } -stream_executor::port::StatusOr> -ConvertFunctionToMlir(const FunctionBody* fbody, - const FunctionLibraryDefinition& flib_def, - mlir::MLIRContext* context) { +tsl::StatusOr> ConvertFunctionToMlir( + const FunctionBody* fbody, const FunctionLibraryDefinition& flib_def, + mlir::MLIRContext* context) { tensorflow::GraphDebugInfo dummy_debug_info; tensorflow::GraphImportConfig specs; specs.graph_func_name = fbody->fdef.signature().name(); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h index 7bcfaa8685d..27972cbf972 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h @@ -28,7 +28,6 @@ limitations under the License. #include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/graph/graph.h" @@ -40,44 +39,38 @@ inline constexpr absl::string_view kImportModelDefaultGraphFuncName = "main"; // Given a GraphDef, returns a MLIR module containing the graph, expressed with // tf_executor dialect. -stream_executor::port::StatusOr> -ConvertGraphdefToMlir(const GraphDef& graphdef, - const GraphDebugInfo& debug_info, - const GraphImportConfig& specs, - mlir::MLIRContext* context, - bool add_default_attributes = true); +tsl::StatusOr> ConvertGraphdefToMlir( + const GraphDef& graphdef, const GraphDebugInfo& debug_info, + const GraphImportConfig& specs, mlir::MLIRContext* context, + bool add_default_attributes = true); // Given a Graph, returns a MLIR module containing the graph, expressed with // tf_executor dialect. 
-stream_executor::port::StatusOr> -ConvertGraphToMlir(const Graph& graph, const GraphDebugInfo& debug_info, - const FunctionLibraryDefinition& flib_def, - const GraphImportConfig& specs, mlir::MLIRContext* context); +tsl::StatusOr> ConvertGraphToMlir( + const Graph& graph, const GraphDebugInfo& debug_info, + const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs, + mlir::MLIRContext* context); // [Experimental] // Given a Function, returns a MLIR module containing the graph, expressed with // tf_executor dialect. -stream_executor::port::StatusOr> -ConvertFunctionToMlir(const FunctionBody* fbody, - const FunctionLibraryDefinition& flib_def, - mlir::MLIRContext* context); +tsl::StatusOr> ConvertFunctionToMlir( + const FunctionBody* fbody, const FunctionLibraryDefinition& flib_def, + mlir::MLIRContext* context); // Given a SavedModel, returns a MLIR module containing the functions, expressed // with tf_executor dialect. -stream_executor::port::StatusOr> -ConvertSavedModelToMlir(SavedModelV2Bundle* saved_model, - mlir::MLIRContext* context, - absl::Span exported_names, - bool add_default_attributes = true, - bool unconditionally_use_set_output_shapes = false); +tsl::StatusOr> ConvertSavedModelToMlir( + SavedModelV2Bundle* saved_model, mlir::MLIRContext* context, + absl::Span exported_names, bool add_default_attributes = true, + bool unconditionally_use_set_output_shapes = false); // Given a V1 SavedModel, returns a MLIR module containing the functions, // expressed with tf_executor dialect. -stream_executor::port::StatusOr> -ConvertSavedModelV1ToMlir(const SavedModelBundle& saved_model, - absl::Span exported_names, - mlir::MLIRContext* context, MLIRImportOptions options, - bool lift_variables = true); +tsl::StatusOr> ConvertSavedModelV1ToMlir( + const SavedModelBundle& saved_model, absl::Span exported_names, + mlir::MLIRContext* context, MLIRImportOptions options, + bool lift_variables = true); // Given a V1 SavedModel, returns a MLIR module containing the functions, // expressed with tf_executor dialect. It does not require a session to be @@ -89,8 +82,7 @@ ConvertSavedModelV1ToMlir(const SavedModelBundle& saved_model, // ConvertSavedModelV1ToMlir(), and is not related to TFLite. // // TODO(b/179683149): Rename this class to avoid confusion with TFLite. -stream_executor::port::StatusOr> -ConvertSavedModelV1ToMlirLite( +tsl::StatusOr> ConvertSavedModelV1ToMlirLite( const MetaGraphDef& meta_graph_def, const GraphDebugInfo& debug_info, std::optional> exported_names, mlir::MLIRContext* context, MLIRImportOptions options); @@ -123,8 +115,8 @@ class SavedModelMLIRImportInput { // and remain valid for the graph. // `name` is a unique identifier for this subgraph, so the implementation can // use it for eg. debugging or caching compilation results. - virtual stream_executor::port::StatusOr GetSubGraph( - absl::string_view name, GraphImportConfig& specs) = 0; + virtual tsl::StatusOr GetSubGraph(absl::string_view name, + GraphImportConfig& specs) = 0; private: const MetaGraphDef* meta_graph_def_ = nullptr; @@ -142,8 +134,7 @@ class SavedModelMLIRImportInput { // ConvertSavedModelV1ToMlir(), and is not related to TFLite. // // TODO(b/179683149): Rename this class to avoid confusion with TFLite. 
-stream_executor::port::StatusOr> -ConvertSavedModelV1ToMlirLite( +tsl::StatusOr> ConvertSavedModelV1ToMlirLite( SavedModelMLIRImportInput& input, std::optional> exported_names, mlir::MLIRContext* context, diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc index e1f2c68f04d..ce1242d342f 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include #include #include #include @@ -208,7 +209,7 @@ Status ParseInputArrayInfo( info.shape.set_unknown_rank(true); continue; } - for (auto& dim : node_shapes[i].getValue()) { + for (auto& dim : node_shapes[i].value()) { info.shape.add_dim()->set_size(dim); } } @@ -224,7 +225,7 @@ Status ParseNodeShapes( std::vector node_shapes_str = absl::StrSplit(shapes_str, ':'); for (int i = 0; i < node_shapes_str.size(); i++) { if (node_shapes_str[i] == "*") { - shapes_vector.push_back(llvm::None); + shapes_vector.push_back(std::nullopt); continue; } TF_ASSIGN_OR_RETURN(auto shape, ParseShapeStr(node_shapes_str[i])); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc index a9c335ca5c6..65a6dbaa1c5 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/protobuf/graph_debug_info.pb.h" @@ -32,9 +31,9 @@ namespace tensorflow { using mlir::MLIRContext; -static stream_executor::port::StatusOr> -Import(const GraphOptimizationPassOptions& options, const Graph& graph, - MLIRContext* context) { +static tsl::StatusOr> Import( + const GraphOptimizationPassOptions& options, const Graph& graph, + MLIRContext* context) { // TODO(fengliuai): get debug info at runtime. GraphDebugInfo debug_info; GraphImportConfig specs; diff --git a/tensorflow/compiler/mlir/tensorflow/translate/split_into_island_per_op_pass.cc b/tensorflow/compiler/mlir/tensorflow/translate/split_into_island_per_op_pass.cc index 258e62b2fe1..a1dced4bf5e 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/split_into_island_per_op_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/split_into_island_per_op_pass.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "tensorflow/compiler/mlir/tensorflow/translate/split_into_island_per_op_pass.h" + #include #include @@ -45,10 +47,6 @@ class SplitIntoIslandPerOpPass : public impl::SplitIntoIslandPerOpPassBase { public: void runOnOperation() override; - - private: - void SplitIsland(tf_executor::IslandOp island_op, - tf_executor::GraphOp graph_op); }; void SplitIntoIslandPerOpPass::runOnOperation() { @@ -91,7 +89,7 @@ void SplitIntoIslandPerOpPass::runOnOperation() { // Break up all islands by simply creating a new island wrapping each // individual sub op. Do not create any control dependencies between the // newly created islands. - SplitIsland(island_op, graph_op); + SplitIsland(island_op, tf_executor::ControlType::get(&getContext())); // None of the originally given control deps are necessary. tf_executor::FetchOp fetch_op = graph_op.GetFetch(); @@ -103,6 +101,8 @@ void SplitIntoIslandPerOpPass::runOnOperation() { } } +} // namespace + // Populates an empty IslandOp and with a NoOp or Identity/IdentityN depending // on if there are any data results. void PopulateEmptyIsland(tf_executor::IslandOp island) { @@ -114,7 +114,7 @@ void PopulateEmptyIsland(tf_executor::IslandOp island) { Value operand = yield.getOperand(0); auto identity = builder.create(island.getLoc(), operand.getType(), operand); - yield.setOperand(0, identity.output()); + yield.setOperand(0, identity.getOutput()); } else { auto identity_n = builder.create( island.getLoc(), yield.getOperandTypes(), yield.getOperands()); @@ -142,8 +142,8 @@ tf_executor::IslandOp CreateIsland(TypeRange result_types, } // Converts a single island into multiple islands (one for each op). -void SplitIntoIslandPerOpPass::SplitIsland(tf_executor::IslandOp island_op, - tf_executor::GraphOp graph_op) { +void SplitIsland(mlir::tf_executor::IslandOp island_op, + mlir::tf_executor::ControlType control_type) { auto island_body = island_op.GetBody().without_terminator(); // Populate islands that are empty (only yield). if (island_body.empty()) { @@ -154,8 +154,6 @@ void SplitIntoIslandPerOpPass::SplitIsland(tf_executor::IslandOp island_op, // Skip islands that are already only a single op. if (island_op.WrapsSingleOp()) return; - auto control_type = tf_executor::ControlType::get(&getContext()); - // For each operation in the island, construct a new island to wrap the op, // yield all the results, and replace all the usages with the results of the // new island. @@ -171,14 +169,25 @@ void SplitIntoIslandPerOpPass::SplitIsland(tf_executor::IslandOp island_op, for (auto item : llvm::zip(island_op.getOutputs(), island_op.GetYield().getFetches())) std::get<0>(item).replaceAllUsesWith(std::get<1>(item)); + + auto graph_op = island_op->getParentOfType(); + + // Dropping all uses of an island op's control dep using + // `island_op.getControl().dropAllUses();` of a control dep that's only used + // in a graph's fetch, immediately leads to a segfault. Turns out we need to + // drop its uses manually so that we don't leave dangling controls. 
+ for (auto& fetch : llvm::enumerate(graph_op.GetFetch().getFetches())) { + if (fetch.value() == island_op.getControl()) { + graph_op.GetFetch().getFetchesMutable().erase(fetch.index(), 1); + break; + } + } island_op.erase(); } -} // namespace -} // namespace TF - std::unique_ptr> CreateSplitIntoIslandPerOpPass() { return std::make_unique(); } +} // namespace TF } // namespace mlir diff --git a/tensorflow/stream_executor/gpu/asm_compiler.h b/tensorflow/compiler/mlir/tensorflow/translate/split_into_island_per_op_pass.h similarity index 53% rename from tensorflow/stream_executor/gpu/asm_compiler.h rename to tensorflow/compiler/mlir/tensorflow/translate/split_into_island_per_op_pass.h index 1cb3ee052b6..45924dea9f1 100644 --- a/tensorflow/stream_executor/gpu/asm_compiler.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/split_into_island_per_op_pass.h @@ -13,9 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_STREAM_EXECUTOR_GPU_ASM_COMPILER_H_ -#define TENSORFLOW_STREAM_EXECUTOR_GPU_ASM_COMPILER_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_SPLIT_INTO_ISLAND_PER_OP_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_SPLIT_INTO_ISLAND_PER_OP_PASS_H_ -#include "tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" -#endif // TENSORFLOW_STREAM_EXECUTOR_GPU_ASM_COMPILER_H_ +namespace mlir { +namespace TF { + +// Converts a single island into multiple islands (one for each op). +void SplitIsland(mlir::tf_executor::IslandOp island_op, + mlir::tf_executor::ControlType control_type); + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_SPLIT_INTO_ISLAND_PER_OP_PASS_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc index 0d7a8f1242a..2a85214d355 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc @@ -38,6 +38,7 @@ limitations under the License. #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/protobuf/graph_debug_info.pb.h" +#include "tensorflow/core/util/tensor_bundle/byte_swap_tensor.h" namespace tensorflow { @@ -54,6 +55,8 @@ static StatusOr> GraphdefToMlirImport( GraphDef graphdef; TF_RETURN_IF_ERROR( tensorflow::LoadProtoFromBuffer({input.data(), input.size()}, &graphdef)); + if (!port::kLittleEndian) + TF_RETURN_IF_ERROR(ByteSwapTensorContentInGraphDef(&graphdef)); GraphDebugInfo debug_info; if (!debug_info_file.empty()) { diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h index a5911036edd..3e5aabb5de8 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h @@ -27,12 +27,11 @@ limitations under the License. 
#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" namespace tensorflow { -using stream_executor::port::Status; -using stream_executor::port::StatusOr; +using tsl::Status; +using tsl::StatusOr; // TODO(antiagainst): Directly manipulating files in library functions is not // a good idea. We should pass in a string/stream here. diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc index 3486a6c1114..f1c39aba7ad 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc @@ -30,12 +30,11 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/core/framework/graph.pb.h" namespace mlir { -using stream_executor::port::Status; -using stream_executor::port::StatusOr; +using tsl::Status; +using tsl::StatusOr; namespace { inline absl::string_view StringRefToView(llvm::StringRef ref) { diff --git a/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h b/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h index 4d19ff37f15..fb2c22ce385 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h @@ -16,7 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_UPGRADE_GRAPH_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_UPGRADE_GRAPH_H_ -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/graph/graph.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.cc index 2c8ed09b3f7..098c7d19411 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.cc @@ -16,7 +16,11 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #include +#include +#include +#include +#include "absl/strings/str_split.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project @@ -57,5 +61,37 @@ LogicalResult IsValidDeviceTypeOrEmpty(StringAttr device_attr) { return success(); } +LogicalResult ParseParallelExecutionIds(Operation* op, + ParallelExecutionIdPairs& id_pairs) { + auto attr = op->getAttrOfType(kParallelExecAnnotation); + if (!attr) return success(); + + // ID pairs are separated by `,`. + llvm::SmallVector str_list = + absl::StrSplit(attr.getValue().str(), ',', absl::SkipWhitespace()); + id_pairs.reserve(str_list.size()); + for (const std::string& str : str_list) { + // IDs of one pair are separated by `:`. + llvm::SmallVector id_pair = absl::StrSplit(str, ':'); + + // Check for malformed IDs. 
+ if (id_pair.size() != 2) return failure(); + if (id_pair[0].empty() || id_pair[1].empty()) return failure(); + + auto is_digit = [](char c) { return absl::ascii_isdigit(c); }; + const std::string& group_id = id_pair[0]; + if (group_id[0] != 'p' && group_id[0] != 'r') return failure(); + if (!std::all_of(std::next(group_id.begin()), group_id.end(), is_digit)) { + return failure(); + } + const std::string& branch_id = id_pair[1]; + if (!std::all_of(branch_id.begin(), branch_id.end(), is_digit)) { + return failure(); + } + id_pairs.push_back(std::make_pair(id_pair[0], id_pair[1])); + } + return success(); +} + } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h index 09dd3bcd20d..0c6a4733dc9 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h @@ -16,6 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_ATTRIBUTE_UTILS_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_ATTRIBUTE_UTILS_H_ +#include +#include + +#include "llvm/ADT/SmallVector.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "tensorflow/compiler/tf2xla/tf2xla_defs.h" @@ -46,6 +50,60 @@ inline constexpr llvm::StringRef kTpuDevice = "TPU"; inline constexpr llvm::StringRef kSkipIslandOutlining = "_skip_island_outlining"; +// This string attribute encodes parallel execution groups and their associated +// branches. It has the following format: +// `_parallel_execution_ids= group1:branch1,group2:branch2,...` +// For example, if we have IR as follows: +// +// tf_executor.island wraps "tf.OpA" +// tf_executor.island { +// "tf_device.replicate" {n = 2} { +// "tf.OpB" +// "tf_device.parallel_execute"() ({ +// "tf.OpC" +// }, { +// "tf.OpD" +// }) +// } +// +// The above IR will be flattened after `ReplicateToIslandPass` and +// `ParallelExecuteToIslandsPass` as follows: +// +// tf_executor.island wraps "tf.OpA" +// tf_executor.island {_parallel_execution_ids=r0:0} wraps "tf.OpB" +// tf_executor.island {_parallel_execution_ids=r0:0,p0:0} wraps "tf.OpC" +// tf_executor.island {_parallel_execution_ids=r0:0,p0:1} wraps "tf.OpD" +// tf_executor.island {_parallel_execution_ids=r0:1} wraps "tf.OpB" +// tf_executor.island {_parallel_execution_ids=r0:1,p0:0} wraps "tf.OpC" +// tf_executor.island {_parallel_execution_ids=r0:1,p0:1} wraps "tf.OpD" +// +// "tf.OpA" will not have `_parallel_execution_ids` attr, +// means it does not belong to any parallel execution groups. +// First instance of "tf.OpB" after flattening will have +// `_parallel_execution_ids = "r0:0"`, +// which represents the first branch of replicate group 0. +// Second instance of "tf.OpB" after flattening will have +// `_parallel_execution_ids = "r0:1"` +// which represents the second branch of replicate group 0. +// First instance of "tf.OpC" after flattening will have +// `_parallel_execution_ids = "r0:0,p0:0"` +// which represents the first branch of replicate group 0 and +// the first branch of parallel group 0. +// Second instance of "tf.OpC" after flattening will have +// `_parallel_execution_ids = "r0:1,p0:0"` +// which represents the second branch of replicate group 0 and +// the first branch of parallel group 0. 
+// First instance of "tf.OpD" after flattening will have +// `_parallel_execution_ids = "r0:0,p0:1"` +// which represents the first branch of replicate group 0 and +// the second branch of parallel group 0. +// Second instance of "tf.OpD" after flattening will have +// `_parallel_execution_ids = "r0:1,p0:1"` +// which represents the second branch of replicate group 0 and +// the second branch of parallel group 0. +inline constexpr llvm::StringRef kParallelExecAnnotation = + "_parallel_execution_ids"; + // Copies attributes that satisfy the given predicate from `from` to `to`. template void CopyAttributes(Operation *from, Operation *to, Predicate P) { @@ -95,6 +153,14 @@ LogicalResult HasValidCompilationAndReplicationAttributes(Operation &op); // Checks if the device attribute is valid. LogicalResult IsValidDeviceTypeOrEmpty(StringAttr attr); +using ParallelExecutionIdPairs = + llvm::SmallVector, 8>; +// Parses the parallel execution attribute for `op` and fills `id_pairs` with +// the corresponding (group ID,branch ID) pairs. +// Returns `failure` if the attribute is malformed. +LogicalResult ParseParallelExecutionIds(Operation *op, + ParallelExecutionIdPairs &id_pairs); + } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index 7eca89870f1..44b92532142 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -55,14 +55,15 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h" #include "tensorflow/compiler/mlir/xla/transforms/adjust_layout.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" +#include "tensorflow/compiler/mlir/xla/transforms/xla_legalize_targets.h" #include "tensorflow/compiler/tf2xla/layout_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_sharding.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/register.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/passes.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/translate/mhlo_to_hlo/layout_util.h" #include "tensorflow/compiler/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" @@ -173,7 +174,7 @@ mlir::RankedTensorType GetBufferType(mlir::Type ty) { .dyn_cast_or_null(); if (encoding && !encoding.getBounds().empty()) { for (int64_t dim = 0; dim < rank; ++dim) { - if (dims[dim] == mlir::ShapedType::kDynamicSize) { + if (dims[dim] == mlir::ShapedType::kDynamic) { dims[dim] = encoding.getBounds()[dim]; } } @@ -326,67 +327,26 @@ bool CanInlineFunctionsPostLegalization(llvm::StringRef device_type) { } // namespace -Status RefineShapes(llvm::ArrayRef arg_shapes, - mlir::ModuleOp module) { - auto producer_or = GetTfGraphProducerVersion(module); - if (!producer_or.ok()) return producer_or.status(); - int64_t producer_version = producer_or.value(); - - llvm::SmallVector 
shape_backing;
- llvm::SmallVector, 4> arg_shapes_copy;
- {
- // Convert arg_shapes to a mlir friendly format.
- size_t count = 0;
- for (const TensorOrResourceShape& tensor_resource_shape : arg_shapes) {
- if (tensor_resource_shape.is_resource) continue;
- count += tensor_resource_shape.shape.dims();
- }
- shape_backing.resize(count);
- arg_shapes_copy.reserve(arg_shapes.size());
- size_t offset = 0;
- for (const TensorOrResourceShape& tensor_resource_shape : arg_shapes) {
- if (tensor_resource_shape.is_resource) {
- arg_shapes_copy.push_back(llvm::ArrayRef());
- continue;
- }
- size_t start = offset;
- for (tensorflow::TensorShapeDim dim : tensor_resource_shape.shape) {
- shape_backing[offset] = dim.size;
- ++offset;
- }
- if (offset == start) {
- arg_shapes_copy.push_back(llvm::ArrayRef());
- } else {
- arg_shapes_copy.push_back(
- llvm::ArrayRef(&shape_backing[start], offset - start));
- }
- }
- }
-
- auto main_func = module.lookupSymbol("main");
-
- mlir::StatusScopedDiagnosticHandler error_handler(module.getContext());
- mlir::LogicalResult result = mlir::TF::InferShapeForFunction(
- main_func, arg_shapes_copy, producer_version);
-
- if (failed(result)) {
- return error_handler.Combine(
- errors::Internal("MLIR Shape refinement failed"));
- }
- return error_handler.ConsumeStatus();
-}
-
 void CreateConvertMlirToXlaHloPipeline(
 mlir::OpPassManager& pm, llvm::StringRef device_type, bool prefer_tf2xla,
 llvm::MutableArrayRef> custom_legalization_passes,
 bool allow_partial_conversion) {
+ bool legalize_chlo = true;
+
 // Note that the region-based control-flow produced here still contains
 // function call ops which get inlined by the subsequent inliner pass.
 pm.addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions());
 pm.addPass(mlir::createInlinerPass());
 pm.addNestedPass(
 mlir::TF::CreateDropWhileShapeInvariantPass());
+ // Create a replicated TensorList initialization op for each of its uses. This
+ // pass undoes some CSE because shape_inference is not able to correctly
+ // identify the shapes of TensorList initialization ops.
+ // This pass requires CanonicalizerPass before
+ // CreateTensorListOpsDecompositionPass for clean-ups.
+ pm.addNestedPass(
+ mlir::TF::CreateReplicateTensorListInitOpsPass());
 pm.addNestedPass(mlir::createCanonicalizerPass());
 // The SCCP pass performs constant propagation across the IR, which, for
 // example, propagates constant arguments into callee functions.
@@ -413,28 +373,36 @@ void CreateConvertMlirToXlaHloPipeline(
 mlir::TFDevice::CreateDecomposeResourceOpsPass());
 pm.addPass(mlir::TF::CreatePromoteResourcesToArgsPass());
 pm.addPass(mlir::createSymbolDCEPass());
+
+ // Sink constants to regions so that ops requiring constant operands can
+ // access the constant and there is no indirection through control flow region
+ // arguments. Also, note that this pass is in MHLO but it is generic and sinks
+ // constants for all ops with regions.
+ pm.addNestedPass(
+ mlir::mhlo::createSinkConstantsToControlFlowPass());
 pm.addPass(mlir::TF::CreateTFShapeInferencePass());
- // TODO(b/171426148): We cannot completely remove region to functional control
- // flow conversion from this pipeline yet as it causes some unit tests to
- // fail.
- pm.addPass(mlir::TF::CreateTFRegionControlFlowToFunctional());
- // LegalizeTFControlFlow encapsulates arguments for control flow operations
- // with a tuple argument which break the assumption of resource lifting
- // inside PromoteResourcesToArgs.
- pm.addPass(mlir::mhlo::createLegalizeTFControlFlowPass());
+
+ // Legalize any StableHLO ops to MHLO. The bridge still doesn't use StableHLO,
+ // but such ops might be present in the input from upstream, e.g. TFRT
+ // compilation. Later on, this could be merged into the legalization pass when
+ // we migrate the bridge to StableHLO.
+ // TODO(b/259459405): Avoid this peculiar use through some refactoring in
+ // the caller.
+ // This needs to happen before legalization.
+ pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass());
 pm.addNestedPass(mlir::TF::CreateLowerQuantizedPass());
 pm.addPass(mlir::mhlo::CreateLegalizeTfTypesPass());
 pm.addPass(mlir::mhlo::createLegalizeTFModulePass(
 /*tf2xla_fallback_device_type=*/device_type));
+
 pm.addNestedPass(mlir::mhlo::createLegalizeTFPass(
- /*allow_partial_conversion=*/true, /*legalize_chlo=*/true,
+ /*allow_partial_conversion=*/true, legalize_chlo,
 /*tf2xla_fallback_device_type=*/device_type, prefer_tf2xla));
 for (auto& target_pass : custom_legalization_passes) {
 pm.addNestedPass(std::move(target_pass));
 }
 pm.addNestedPass(mlir::mhlo::CreateAdjustLayoutPass());
- pm.addPass(mlir::mhlo::CreateLegalizeTFCommunicationPass());
 pm.addPass(mlir::mhlo::CreateLegalizeTFCollectivePass());
 pm.addNestedPass(mlir::createCanonicalizerPass());
 // Run shape inference pass to propagate shapes through tensor_cast operations
@@ -442,15 +410,25 @@
 // inference was originally missing in a TF op but the corresponding HLO op
 // had static shape after lowering.
 pm.addPass(mlir::TF::CreateTFShapeInferencePass());
+
 // Run LegalizeTFPass again because the previous legalization passes can
 // expose more graph pruning and canonicalization opportunities that are
 // necessary for the second LegalizeTFPass(allow_partial_conversion=false)
 // invocation.
 pm.addNestedPass(mlir::mhlo::createLegalizeTFPass(
- /*allow_partial_conversion=*/allow_partial_conversion,
- /*legalize_chlo=*/true,
+ /*allow_partial_conversion=*/allow_partial_conversion, legalize_chlo,
 /*tf2xla_fallback_device_type=*/device_type, prefer_tf2xla));
+ // This pass operates on MHLO control flow ops, so it should run after
+ // the control flow ops are legalized.
+ pm.addPass(mlir::mhlo::CreateLegalizeTFCommunicationPass());
+
+ // Everything should be MHLO after this.
+ if (!allow_partial_conversion) {
+ pm.addNestedPass(
+ mlir::mhlo::CreateVerifyTFXLALegalizationPass(legalize_chlo));
+ }
+
 if (CanInlineFunctionsPostLegalization(device_type))
 pm.addPass(mlir::createInlinerPass());
@@ -460,6 +438,56 @@
 mlir::mhlo::createSinkConstantsToControlFlowPass());
 }
+Status RefineShapes(llvm::ArrayRef arg_shapes,
+ mlir::ModuleOp module) {
+ auto producer_or = GetTfGraphProducerVersion(module);
+ if (!producer_or.ok()) return producer_or.status();
+ int64_t producer_version = producer_or.value();
+
+ llvm::SmallVector shape_backing;
+ llvm::SmallVector, 4> arg_shapes_copy;
+ {
+ // Convert arg_shapes to a mlir friendly format.
+ size_t count = 0;
+ for (const TensorOrResourceShape& tensor_resource_shape : arg_shapes) {
+ if (tensor_resource_shape.is_resource) continue;
+ count += tensor_resource_shape.shape.dims();
+ }
+ shape_backing.resize(count);
+ arg_shapes_copy.reserve(arg_shapes.size());
+ size_t offset = 0;
+ for (const TensorOrResourceShape& tensor_resource_shape : arg_shapes) {
+ if (tensor_resource_shape.is_resource) {
+ arg_shapes_copy.push_back(llvm::ArrayRef());
+ continue;
+ }
+ size_t start = offset;
+ for (tensorflow::TensorShapeDim dim : tensor_resource_shape.shape) {
+ shape_backing[offset] = dim.size;
+ ++offset;
+ }
+ if (offset == start) {
+ arg_shapes_copy.push_back(llvm::ArrayRef());
+ } else {
+ arg_shapes_copy.push_back(
+ llvm::ArrayRef(&shape_backing[start], offset - start));
+ }
+ }
+ }
+
+ auto main_func = module.lookupSymbol("main");
+
+ mlir::StatusScopedDiagnosticHandler error_handler(module.getContext());
+ mlir::LogicalResult result = mlir::TF::InferShapeForFunction(
+ main_func, arg_shapes_copy, producer_version);
+
+ if (failed(result)) {
+ return error_handler.Combine(
+ errors::Internal("MLIR Shape refinement failed"));
+ }
+ return error_handler.ConsumeStatus();
+}
+
 Status LegalizeToHlo(mlir::ModuleOp module_op, llvm::StringRef device_type,
 bool prefer_tf2xla,
 llvm::MutableArrayRef>
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h
index 7c7d6a46806..0db64e59db8 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h
@@ -28,22 +28,12 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/xla_argument.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
-#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h"
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/protobuf/graph_debug_info.pb.h"

 namespace tensorflow {

-// Populates the supplied passmanager with the passes required to run the
-// TF MLIR to XLA HLO MLIR conversion/legalization. Custom legalization passes
-// can be populated in `custom_legalization_passes`.
-void CreateConvertMlirToXlaHloPipeline(
- mlir::OpPassManager& pm, llvm::StringRef device_type, bool prefer_tf2xla,
- llvm::MutableArrayRef>
- custom_legalization_passes,
- bool allow_partial_conversion = false);
-
 // Lowers MLIR module to XLA HLO inside an XlaComputation. The input module
 // should only contain operations in tf dialect. If the input module contains
 // operation in the tf_executor dialect, for example, returns an error.
@@ -83,6 +73,32 @@ Status ConvertMLIRToXlaComputation(
 llvm::MutableArrayRef>
 custom_legalization_passes = {});

+// Creates an MLIR pipeline that lowers the MLIR module to the MHLO dialect.
+// The input module should only contain operations in tf dialect. For example,
+// if the input module contains operations in the tf_executor dialect, the pass
+// raises an error unless the tf_executor dialect ops are optimized away by
+// canonicalization.
+//
+// The pipeline is used in ConvertMLIRToXlaComputation, and it generally has
+// the following pass structure:
+// - TensorFlow passes
+// - Legalization passes
+// - MHLO passes
+//
+// device_type: XLA JIT device to use for compilation such as "XLA_CPU_JIT",
+// "XLA_GPU_JIT" or "XLA_TPU_JIT".
+// prefer_tf2xla: when this is true, prefer tf2xla fallback kernels over MLIR +// native kernels for legalization to HLO. +// custom_legalization_passes: passes to run before the default TF legalization +// passes for backend-specific ops. +// allow_partial_conversion: when this is true, allow operations that can't be +// legalized. +void CreateConvertMlirToXlaHloPipeline( + mlir::OpPassManager& pm, llvm::StringRef device_type, bool prefer_tf2xla, + llvm::MutableArrayRef> + custom_legalization_passes, + bool allow_partial_conversion = false); + // Helper struct representing argument tensor or resource handle shapes. struct TensorOrResourceShape { TensorShape shape; diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_attr.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_attr.cc index bd203d2742e..fc0ee8b9d20 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_attr.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_attr.cc @@ -76,8 +76,7 @@ StatusOr ConvertNonFuncAttributeValue(const AttrValue& value, return tensorflow::errors::Unimplemented( absl::StrCat("Attribute ", value.DebugString())); } - return builder->getArrayAttr( - llvm::makeArrayRef(attrs.begin(), attrs.end())); + return builder->getArrayAttr(llvm::ArrayRef(attrs.begin(), attrs.end())); } case AttrValue::VALUE_NOT_SET: return builder->getUnitAttr(); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h b/tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h index 7131eda4b6c..3eb21956721 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h @@ -17,12 +17,12 @@ limitations under the License. #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/tsl/platform/statusor.h" namespace tensorflow { -using stream_executor::port::StatusOr; +using tsl::StatusOr; // Converts non func AttrValue proto into an MLIR attribute. Func attribute is // exclused in this function because the function might be renamed when the diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc index a4cf7f9e289..42716d6e9ec 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include "absl/base/casts.h" #include "absl/container/inlined_vector.h" @@ -35,7 +36,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" @@ -46,6 +46,7 @@ limitations under the License. 
#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/tstring.h" +#include "tensorflow/tsl/platform/float8.h" namespace tensorflow { @@ -85,19 +86,13 @@ StatusOr ConvertFlatTensor(const Tensor& input_tensor, ShapedType type) { auto arr = input_tensor.flat(); return ElementsAttr(mlir::DenseElementsAttr::get( - type, llvm::makeArrayRef(arr.data(), arr.size()))); + type, llvm::ArrayRef(arr.data(), arr.size()))); } -ElementsAttr ConvertBf16Tensor(const Tensor& input_tensor, - RankedTensorType type) { - auto buffer = llvm::makeArrayRef(static_cast(input_tensor.data()), - input_tensor.TotalBytes()); - return mlir::DenseElementsAttr::getFromRawBuffer(type, buffer); -} - -ElementsAttr ConvertHalfTensor(const Tensor& tensor, RankedTensorType type) { - auto buffer = llvm::makeArrayRef(static_cast(tensor.data()), - tensor.TotalBytes()); +ElementsAttr ConvertTensorOfCustomFloatType(const Tensor& tensor, + RankedTensorType type) { + auto buffer = + llvm::ArrayRef(static_cast(tensor.data()), tensor.TotalBytes()); return mlir::DenseElementsAttr::getFromRawBuffer(type, buffer); } @@ -145,12 +140,11 @@ StatusOr ConvertTensor(const Tensor& input_tensor, CONVERT_FLAT(DT_COMPLEX64, std::complex) CONVERT_FLAT(DT_COMPLEX128, std::complex) - // BFLOAT16 is a special case that it needs to be cast to double type to - // match its storage type. case DT_BFLOAT16: - return ConvertBf16Tensor(input_tensor, type); case DT_HALF: - return ConvertHalfTensor(input_tensor, type); + case DT_FLOAT8_E5M2: + case DT_FLOAT8_E4M3FN: + return ConvertTensorOfCustomFloatType(input_tensor, type); case DT_STRING: return ConvertStringTensor(input_tensor, type); default: @@ -262,16 +256,11 @@ PartialTensorShape ConvertTypeToTensorShape(const mlir::Type& type) { mlir::TF::ShapeAttr ConvertTypeToTensorShapeAttr(const mlir::Type& type) { if (type.isa()) { - return mlir::TF::ShapeAttr::get(type.getContext(), llvm::None); + return mlir::TF::ShapeAttr::get(type.getContext(), std::nullopt); } if (auto tensor_type = type.dyn_cast()) { - llvm::SmallVector shape; - for (int64_t d : tensor_type.getShape()) { - shape.push_back(ShapedType::isDynamic(d) ? kTFDynamicSize : d); - } - return mlir::TF::ShapeAttr::get(type.getContext(), - llvm::makeArrayRef(shape)); + return mlir::TF::ShapeAttr::get(type.getContext(), tensor_type.getShape()); } // If type is not a RankedTensor or UnrankedTensor, it must be a scalar. @@ -283,15 +272,15 @@ mlir::TF::ShapeAttr ConvertTypeToTensorShapeAttr(const mlir::Type& type) { StatusOr ConvertTensorShapeProto(const TensorShapeProto& shape, mlir::MLIRContext* context) { if (shape.unknown_rank()) - return mlir::TF::ShapeAttr::get(context, llvm::None); + return mlir::TF::ShapeAttr::get(context, std::nullopt); llvm::SmallVector dims; dims.reserve(shape.dim().size()); for (const auto& dim : shape.dim()) { - dims.push_back(dim.size() == kTFDynamicSize ? ShapedType::kDynamicSize + dims.push_back(dim.size() == kTFDynamicSize ? 
ShapedType::kDynamic : dim.size()); } - return mlir::TF::ShapeAttr::get(context, llvm::makeArrayRef(dims)); + return mlir::TF::ShapeAttr::get(context, llvm::ArrayRef(dims)); } // Converts an MLIR dense string elements attribute to a TensorFlow tensor @@ -400,6 +389,20 @@ void ConvertBfloat16ElementsAttr(const mlir::DenseElementsAttr attr, } } +template +void ConvertFloat8ElementsAttr(const mlir::DenseElementsAttr attr, + std::string* output) { + if (attr.isSplat()) { + if (attr.getSplatValue() != T(0)) + output->push_back( + Eigen::numext::bit_cast(attr.getSplatValue())); + } else { + output->reserve(attr.getNumElements()); + for (const T value : attr.getValues()) + output->push_back(Eigen::numext::bit_cast(value)); + } +} + Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) { auto type = attr.getType(); auto shape = type.getShape(); @@ -438,6 +441,14 @@ Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) { ConvertFloatElementsAttr(dense_attr, output->mutable_float_val(), output->mutable_tensor_content()); break; + case DT_FLOAT8_E5M2: + ConvertFloat8ElementsAttr(dense_attr, + output->mutable_float8_val()); + break; + case DT_FLOAT8_E4M3FN: + ConvertFloat8ElementsAttr( + dense_attr, output->mutable_float8_val()); + break; case DT_QUINT8: case DT_INT8: ConvertUIntElementsAttr(dense_attr, diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h index 29c8b470fad..9255667c647 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h @@ -21,14 +21,13 @@ limitations under the License. #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" namespace tensorflow { -using stream_executor::port::StatusOr; +using tsl::StatusOr; // Converts an TensorFlow tensor proto into an MLIR elements attribute. StatusOr ConvertTensorProto(const TensorProto& input_tensor, diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc index 8fef254c74d..115a1cbbfd2 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc @@ -25,13 +25,13 @@ limitations under the License. 
#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/tensor_util.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/tsl/platform/float8.h" namespace tensorflow { namespace { @@ -140,6 +140,12 @@ TEST_F(ConvertTensorTest, Simple) { {1.0, -1.0}, DT_FLOAT, mlir::FloatType::getF32(&context))); ASSERT_NO_FATAL_FAILURE(VerifyConversion( {1.0, -1.0}, DT_DOUBLE, mlir::FloatType::getF64(&context))); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {tsl::float8_e5m2{1.0}, tsl::float8_e5m2{-1.0}}, DT_FLOAT8_E5M2, + mlir::FloatType::getFloat8E5M2(&context))); + ASSERT_NO_FATAL_FAILURE(VerifyConversion( + {tsl::float8_e4m3fn{1.0}, tsl::float8_e4m3fn{-1.0}}, DT_FLOAT8_E4M3FN, + mlir::FloatType::getFloat8E4M3FN(&context))); ASSERT_NO_FATAL_FAILURE(VerifyConversion( {1, -1}, DT_INT8, mlir::IntegerType::get(&context, 8))); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc index efcd7e3830a..2546fa44a05 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc @@ -79,6 +79,12 @@ Status ConvertDataType(DataType dtype, Builder builder, Type* type) { case DT_COMPLEX128: *type = mlir::ComplexType::get(builder.getF64Type()); return OkStatus(); + case tensorflow::DT_FLOAT8_E4M3FN: + *type = builder.getFloat8E4M3FNType(); + return ::tensorflow::OkStatus(); + case tensorflow::DT_FLOAT8_E5M2: + *type = builder.getFloat8E5M2Type(); + return ::tensorflow::OkStatus(); #define HANDLE_TF_TYPE(tftype, enumerant, name) \ case DT_##enumerant: \ *type = builder.getType(); \ @@ -104,6 +110,12 @@ Status ConvertScalarTypeToDataType(Type type, DataType* dtype) { } else if (type.isBF16()) { *dtype = DT_BFLOAT16; return OkStatus(); + } else if (type.isFloat8E4M3FN()) { + *dtype = DT_FLOAT8_E4M3FN; + return OkStatus(); + } else if (type.isFloat8E5M2()) { + *dtype = DT_FLOAT8_E5M2; + return OkStatus(); } else if (auto itype = type.dyn_cast()) { switch (itype.getWidth()) { case 1: @@ -164,8 +176,7 @@ void ConvertToMlirShape(const TensorShape& input_shape, llvm::SmallVectorImpl* shape) { shape->reserve(input_shape.dims()); for (const auto& d : input_shape) { - shape->push_back(d.size == kTFDynamicSize ? ShapedType::kDynamicSize - : d.size); + shape->push_back(d.size == kTFDynamicSize ? ShapedType::kDynamic : d.size); } } @@ -177,7 +188,7 @@ Status ConvertToMlirShape(const TensorShapeProto& input_shape, if (d.size() > std::numeric_limits::max()) { return errors::InvalidArgument("Shape element overflows"); } - shape->push_back(d.size() == kTFDynamicSize ? ShapedType::kDynamicSize + shape->push_back(d.size() == kTFDynamicSize ? ShapedType::kDynamic : d.size()); } return OkStatus(); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.h b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.h index 76e05663be5..35a3d1fb156 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.h @@ -18,14 +18,13 @@ limitations under the License. 
#include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb.h" namespace tensorflow { -using stream_executor::port::StatusOr; +using tsl::StatusOr; // Converts the TensorFlow DataType 'dtype' into an MLIR (scalar) type. Status ConvertDataType(DataType dtype, mlir::Builder builder, mlir::Type* type); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_type_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_type_test.cc index 0670bbdc20a..7bc65919030 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_type_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_type_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status_test_util.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_graph.h b/tensorflow/compiler/mlir/tensorflow/utils/dump_graph.h index 2d6b0f01e9a..2c400925a88 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_graph.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_graph.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DUMP_GRAPH_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DUMP_GRAPH_H_ +#include #include #include "mlir/IR/OperationSupport.h" // from @llvm-project @@ -50,7 +51,7 @@ struct MlirDumpConfig { // debug information is printed in a more readable 'pretty' form but this // pretty form is not parsable (so only for human readability). MlirDumpConfig& emit_location_information(bool pretty_form = false) { - this->op_printing_flags.enableDebugInfo(pretty_form); + this->op_printing_flags.enableDebugInfo(/*enable=*/true, pretty_form); return *this; } @@ -60,7 +61,7 @@ struct MlirDumpConfig { } // Op printing flags. - mlir::OpPrintingFlags op_printing_flags = llvm::None; + mlir::OpPrintingFlags op_printing_flags = std::nullopt; // The target MLIR dialect. 
Dialect dialect = Dialect::kTFG; diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc index a8e2a988159..e6dffa6d217 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc @@ -186,8 +186,9 @@ void PrintPassPipeline(const mlir::PassManager& pass_manager, llvm::interleaveComma( pass_manager.getPasses(), passOS, [&](mlir::Pass& pass) { pass.printAsTextualPipeline(passOS); }); - os << "{-# external_resources: { mlir_reproducer: { pipeline: \"" - << passOS.str() << "\", "; + os << "{-# external_resources: { mlir_reproducer: { pipeline: " + "\"builtin.module(" + << passOS.str() << ")\", "; os << "disable_threading: true, "; os << "verify_each: true } } #-}"; os << "\n\n"; diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc index d8cefd434f7..fc0dd628eca 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc @@ -15,11 +15,25 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" +#include +#include +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/ToolOutputFile.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/InitAllDialects.h" // from @llvm-project +#include "mlir/InitAllPasses.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/FileUtilities.h" // from @llvm-project +#include "mlir/Tools/mlir-opt/MlirOptMain.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/core/framework/device.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" @@ -27,6 +41,8 @@ limitations under the License. 
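The PrintPassPipeline change above anchors the recorded pipeline at builtin.module, which makes the reproducer string directly parseable again. A minimal sketch of that parse step (illustrative; RebuildPipeline is a hypothetical helper, and the named passes must already be registered, e.g. via mlir::registerTransformsPasses()):

#include "llvm/ADT/StringRef.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Passes.h"

// Parses an anchored pipeline string of the form now written into the
// reproducer footer, e.g. "builtin.module(canonicalize)".
mlir::FailureOr<mlir::OpPassManager> RebuildPipeline(llvm::StringRef pipeline) {
  mlir::registerTransformsPasses();  // make pass names resolvable by the parser
  return mlir::parsePassPipeline(pipeline);
}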
namespace tensorflow { namespace { +using ::testing::IsNull; + TEST(DumpMlirModuleTest, NoEnvPrefix) { mlir::MLIRContext context; mlir::OwningOpRef module_ref = @@ -104,7 +120,8 @@ TEST(DumpCrashReproducerTest, Valid) { std::string expected_txt_module; { llvm::raw_string_ostream os(expected_txt_module); - os << "{-# external_resources: { mlir_reproducer: { pipeline: \"\", " + os << "{-# external_resources: { mlir_reproducer: { pipeline: " + "\"builtin.module()\", " "disable_threading: true, verify_each: true } } #-}\n\n"; module_ref->getOperation()->print(os, mlir::OpPrintingFlags().useLocalScope()); @@ -123,6 +140,46 @@ TEST(DumpCrashReproducerTest, Valid) { EXPECT_EQ(file_txt_module, expected_txt_module); } +TEST(DumpCrashReproducerTest, RoundtripDumpAndReadValid) { + mlir::MLIRContext context; + mlir::OwningOpRef module_ref = + mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); + + setenv("TF_DUMP_GRAPH_PREFIX", testing::TmpDir().c_str(), 1); + std::string filepath = + testing::TmpDir() + "/" + mlir::TF::kStandardPipelineBefore + ".mlir"; + + std::string output_dump = testing::TmpDir() + "/" + "output_dump.txt"; + + TF_ASSERT_OK(mlir::TF::RunBridgeWithStandardPipeline( + module_ref.get(), + /*enable_logging=*/true, /*enable_inliner=*/false)); + + std::string errorMessage; + auto input_file = mlir::openInputFile(filepath, &errorMessage); + EXPECT_THAT(input_file, Not(IsNull())); + + auto output_stream = mlir::openOutputFile(output_dump, &errorMessage); + EXPECT_THAT(output_stream, Not(IsNull())); + + mlir::PassPipelineCLParser passPipeline( + /*arg=*/"", /*description=*/"Compiler passes to run", /*alias=*/"p"); + mlir::DialectRegistry registry; + mlir::registerAllDialects(registry); + mlir::RegisterAllTensorFlowDialects(registry); + + mlir::registerAllPasses(); + mlir::registerTensorFlowPasses(); + + EXPECT_TRUE(mlir::MlirOptMain(output_stream->os(), std::move(input_file), + passPipeline, registry, + /*splitInputFile=*/false, + /*verifyDiagnostics=*/false, + /*verifyPasses=*/false, + /*allowUnregisteredDialects=*/false) + .succeeded()); +} + TEST(DumpRawStringToFileTest, Valid) { llvm::StringRef example = "module {\n}"; setenv("TF_DUMP_GRAPH_PREFIX", testing::TmpDir().c_str(), 1); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.cc index 62627642d7f..a82374a23c1 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.cc @@ -21,7 +21,7 @@ namespace tensorflow { llvm::SmallVector ConvertTFShapeToMlir( llvm::ArrayRef shapes) { return llvm::to_vector(llvm::map_range(shapes, [](int64_t shape) { - return shape == kTFDynamicSize ? mlir::ShapedType::kDynamicSize : shape; + return shape == kTFDynamicSize ? mlir::ShapedType::kDynamic : shape; })); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc index d4a5cb22bbc..c03ba4c2f8a 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc @@ -27,7 +27,6 @@ limitations under the License. 
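The dynamic_shape_utils change above is part of the kDynamicSize -> kDynamic rename. A small sketch of the resulting mapping (illustrative; DynamicShapeExample is hypothetical, and it assumes kTFDynamicSize is the sentinel exposed by dynamic_shape_utils.h):

#include <cassert>
#include <cstdint>

#include "mlir/IR/BuiltinTypes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h"

void DynamicShapeExample() {
  // kTFDynamicSize marks an unknown TensorFlow dimension; after this change it
  // maps to ShapedType::kDynamic instead of the removed kDynamicSize constant.
  auto mlir_shape =
      tensorflow::ConvertTFShapeToMlir({tensorflow::kTFDynamicSize, 8});
  assert(mlir_shape[0] == mlir::ShapedType::kDynamic);
  assert(mlir_shape[1] == 8);
  (void)mlir_shape;
}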
#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/errors.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h index 33547c830f9..24152cad81c 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h @@ -27,7 +27,6 @@ limitations under the License. #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -40,7 +39,7 @@ class ShapedType; namespace tensorflow { -using stream_executor::port::StatusOr; +using tsl::StatusOr; // Add custom op prefix for TensorFlow dialects. Status AddTensorFlowOpPrefix(std::string); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/parallel_execute_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/parallel_execute_util.cc index 2f37747f6a0..f88e3e57364 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/parallel_execute_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/parallel_execute_util.cc @@ -39,7 +39,7 @@ tf_device::ParallelExecuteOp BuildParallelExecuteOp( LogicalResult RemoveSingletonParallelExecuteOp( tf_device::ParallelExecuteOp parallel_execute, OpBuilder* builder) { - if (parallel_execute.regions().size() == 1) { + if (parallel_execute.getRegions().size() == 1) { builder->setInsertionPoint(parallel_execute); auto& block = parallel_execute.GetRegionBlockWithIndex(0); llvm::SmallVector ops_move; diff --git a/tensorflow/compiler/mlir/tensorflow/utils/session_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/session_utils.cc index 0846aad45fa..5b89105156d 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/session_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/session_utils.cc @@ -31,7 +31,7 @@ std::string GetVariableName(TF::VarHandleOp var_handle_op) { // then fallback to shared_name attribute. if (auto loc = var_handle_op->getLoc().dyn_cast()) return loc.getName().str(); - return var_handle_op.shared_name().str(); + return var_handle_op.getSharedName().str(); } std::vector GetVariableNames( @@ -50,11 +50,11 @@ tensorflow::Var* GetVariableFromSession(mlir::TF::VarHandleOp var_handle_op, if (!mgr || !mgr->LookupDevice(StringRefToView(device_name), &device).ok()) return nullptr; tensorflow::Var* var_ptr = nullptr; - const auto& container = var_handle_op.container().str(); + const auto& container = var_handle_op.getContainer().str(); auto status = device->resource_manager()->Lookup( (container.empty() ? 
device->resource_manager()->default_container() : container), - var_handle_op.shared_name().str(), &var_ptr); + var_handle_op.getSharedName().str(), &var_ptr); if (!device || !status.ok()) return nullptr; return var_ptr; } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.cc index 64e04628fa1..14c723f8fce 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h" +#include + #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" @@ -24,7 +26,7 @@ namespace mlir { namespace TF { LogicalResult InferReturnTypeComponentsForTFOp( - Optional location, Operation* op, int64_t graph_version, + std::optional location, Operation* op, int64_t graph_version, tfg::OperandAsConstantFn operand_as_constant_fn, tfg::OpResultAsShapeFn op_result_as_shape_fn, tfg::ResultElementTypeFn result_element_type_fn, diff --git a/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h index e81ccce70a7..77fb3f8364a 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h @@ -33,7 +33,7 @@ namespace TF { // and instead is temporary until shape functions are reimplemented/migrated to // being in MLIR instead of the TensorFlow op registry. LogicalResult InferReturnTypeComponentsForTFOp( - Optional location, Operation* op, int64_t graph_version, + std::optional location, Operation* op, int64_t graph_version, tfg::OperandAsConstantFn operand_as_constant_fn, tfg::OpResultAsShapeFn op_result_as_shape_fn, tfg::ResultElementTypeFn result_element_type_fn, diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc index c5758029eee..2799c8eb069 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc @@ -34,6 +34,7 @@ limitations under the License. #include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" @@ -43,8 +44,8 @@ limitations under the License. 
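shape_inference_utils now takes std::optional<Location>, and the translate code below dereferences optionals with * rather than getValue(). The idiom in isolation (illustrative; DimOrDynamic is a made-up name):

#include <cstdint>
#include <optional>

// Old llvm::Optional style:  v.hasValue() ? v.getValue() : fallback
// New std::optional style:   v.has_value() ? *v : fallback
int64_t DimOrDynamic(std::optional<int64_t> dim) {
  return dim.has_value() ? *dim : -1;
}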
#include "tensorflow/compiler/mlir/utils/string_container_utils.h" #include "tensorflow/compiler/tf2xla/xla_argument.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/translate/mhlo_to_hlo/type_to_shape.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -142,7 +143,7 @@ Status ParseArgumentShapes( continue; } TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape( - shape.value().getValue(), &arg_shapes[shape.index()].shape)); + *shape.value(), &arg_shapes[shape.index()].shape)); } return OkStatus(); @@ -232,8 +233,7 @@ Status ParseXlaArguments(absl::string_view input_shapes_str, TensorShape shape; auto input_shapes = std::get<1>(arg_components); if (input_shapes.has_value()) { - TF_RETURN_IF_ERROR( - TensorShapeUtils::MakeShape(input_shapes.getValue(), &shape)); + TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(*input_shapes, &shape)); } else { TF_RETURN_IF_ERROR( TensorShapeUtils::MakeShape(static_cast(nullptr), 0, &shape)); @@ -371,8 +371,10 @@ static mlir::LogicalResult MlirTfGraphToHloTextTranslateFunction( } static void RegisterMlirInputDialects(mlir::DialectRegistry& registry) { - registry.insert(); + // TODO(b/259459405): Remove support for stablehlo as an input. + registry + .insert(); } static void RegisterGraphInputDialects(mlir::DialectRegistry& registry) { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/topological_sort.cc b/tensorflow/compiler/mlir/tensorflow/utils/topological_sort.cc new file mode 100644 index 00000000000..460c4baa49e --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/topological_sort.cc @@ -0,0 +1,157 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/utils/topological_sort.h" + +#include +#include +#include +#include + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project + +namespace mlir { +namespace TF { + +ExtraDependenciesFunction no_extra_dependencies = nullptr; + +std::vector SortBlockTopologically( + Block& block, PriorityFunction priorityFunction, + ExtraDependenciesFunction extraDependencies) { + llvm::DenseMap remaining_incoming_data_edges; + llvm::DenseMap remaining_incoming_ctrl_edges; + llvm::DenseMap position; + llvm::DenseMap ancestor; + SmallVector ready; + + llvm::SmallVector empty_op_set; + auto ctrlPredecessors = + [&](Operation* op) -> llvm::SmallVector const& { + if (extraDependencies) { + return extraDependencies(op, /*incoming=*/true); + } else { + return empty_op_set; + } + }; + auto ctrlSuccessors = + [&](Operation* op) -> llvm::SmallVector const& { + if (extraDependencies) { + return extraDependencies(op, /*incoming=*/false); + } else { + return empty_op_set; + } + }; + + int i = 0; + for (Operation& op : block.getOperations()) { + int incoming_ctrl_edges = 0; + int incoming_data_edges = 0; + op.walk([&](Operation* child) { + ancestor[child] = &op; + for (Operation* predecessor : ctrlPredecessors(child)) { + if (predecessor->getBlock() == &block) { + incoming_ctrl_edges++; + } + } + for (Value v : child->getOperands()) { + if (v.getParentBlock() == &block) { + incoming_data_edges++; + } + } + }); + remaining_incoming_data_edges[&op] = incoming_data_edges; + remaining_incoming_ctrl_edges[&op] = incoming_ctrl_edges; + if (incoming_data_edges == 0 && incoming_ctrl_edges == 0) { + ready.push_back(&op); + } + position[&op] = i++; + } + + std::queue todo; + for (Value value : block.getArguments()) { + todo.push(value); + } + + std::vector result; + Operation* previous_op = nullptr; + while (!todo.empty() || !ready.empty()) { + while (!todo.empty()) { + Value value = todo.front(); + todo.pop(); + // All operations that have all their inputs available are good to go. + // Uses, not Users, in case getUsers ever dedups. + for (OpOperand& operand : value.getUses()) { + Operation* user = ancestor[operand.getOwner()]; + remaining_incoming_data_edges[user]--; + if (remaining_incoming_data_edges[user] == 0 && + remaining_incoming_ctrl_edges[user] == 0) { + ready.push_back(user); + } + } + } + + // Find the "best" operation to emit. We + // (a) emit the terminator last. + // (b) honor the priority function (as far as possible). + // (c) preserve order within the ops of one dialect. + auto better = [&](Operation* a, Operation* b) { + if (a->hasTrait() != + b->hasTrait()) { + return b->hasTrait(); + } + int a_priority = priorityFunction(previous_op, a); + int b_priority = priorityFunction(previous_op, b); + if (a_priority != b_priority) { + return a_priority > b_priority; + } else { + return position[a] < position[b]; // preserve order + } + }; + + Operation* best = nullptr; + for (Operation* op : ready) { + if (best == nullptr || better(op, best)) { + best = op; + } + } + + if (!best) { + assert(ready.empty()); + return result; // happens for unused results for ops in the todo list + } + + // Consider this operation emitted, and make its results available. 
+ ready.erase(std::find(ready.begin(), ready.end(), best)); + previous_op = best; + for (Value result : best->getResults()) { + todo.push(result); + } + for (Operation* successor : ctrlSuccessors(best)) { + if (ancestor.find(successor) != ancestor.end()) { + successor = ancestor[successor]; + remaining_incoming_ctrl_edges[successor]--; + if (remaining_incoming_ctrl_edges[successor] == 0 && + remaining_incoming_data_edges[successor] == 0) { + ready.push_back(successor); + } + } + } + result.push_back(best); + } + return result; +} + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/utils/topological_sort.h b/tensorflow/compiler/mlir/tensorflow/utils/topological_sort.h new file mode 100644 index 00000000000..a62fe17add2 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/utils/topological_sort.h @@ -0,0 +1,70 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TOPOLOGICAL_SORT_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TOPOLOGICAL_SORT_H_ + +#include +#include + +#include "absl/types/span.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project + +namespace mlir { +namespace TF { + +// A function that determines which op to emit next in the case of ties. +// The predecessor (which can be null) is the last op we emitted, +// and op is the candidate we're considering. A larger returned integer +// means the op has a higher chance of being emitted first. +typedef int (*PriorityFunction)(Operation *predecessor, Operation *op); + +// A function that returns extra dependencies for each op. These might +// e.g. be known side-effects (or control dependencies) between ops. +// If "incoming" is true, then the list of (extra) predecessors of the +// op should be returned. If "incoming" is false, the list of successors. +// The algorithm assumes that these are consistent which each other. So +// if (and only if) op1 is in extra_dependencies(op2, true), then op2 +// must also be in extra_dependencies(op1, false). +// This function is called multiple times during the topological sort, +// so the implementation should preferably be constant-time. +typedef llvm::function_ref const &( + Operation *, bool incoming)> + ExtraDependenciesFunction; + +// Convenience function if there are no extra dependencies to declare. +// (Unlike nullptr, this also works inside the ternary operator) +extern ExtraDependenciesFunction no_extra_dependencies; + +// Sort a block topologically, so that for all ops, all operands are +// available at the time of execution. This is similar to MLIR's topological +// sort (lib/Transforms/TopologicalSort.cpp) but also takes a priority +// function to determine the next op to emit in the case of ambiguity. This +// makes it possible to group operations by certain attributes. For example, +// the order_by_dialect pass uses this function to group by dialect. 
+// Only the operations nested directly under the block will be reordered. +// Nested blocks will be left alone. +// Also takes a list of control dependencies (vector of operation pairs, +// from->to) that will be honored when ordering the ops together with the +// data dependencies given through (the ops/results of) the operations +// themselves. +std::vector SortBlockTopologically( + Block &block, PriorityFunction priorityFunction, + ExtraDependenciesFunction extraDependencies = no_extra_dependencies); + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TOPOLOGICAL_SORT_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc index 264d4576b9d..8529a55753f 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc @@ -33,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/utils/string_container_utils.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/service/computation_placer.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/errors.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h index 6c515922416..d51a06feaa5 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h @@ -27,13 +27,13 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/statusor.h" #include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { -using stream_executor::port::StatusOr; +using tsl::StatusOr; inline constexpr absl::string_view kTPUReplicatedHost = "TPU_REPLICATED_HOST"; inline constexpr absl::string_view kNumCoresPerReplicaAttr = diff --git a/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc index f1a301342a7..988950389ed 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc @@ -66,8 +66,7 @@ mlir::LogicalResult ExtractTfVersions(mlir::ModuleOp module, return mlir::success(); } -::stream_executor::port::StatusOr GetTfGraphProducerVersion( - mlir::ModuleOp module) { +::tsl::StatusOr GetTfGraphProducerVersion(mlir::ModuleOp module) { auto versions = module->getAttrOfType<::mlir::DictionaryAttr>("tf.versions"); if (!versions) { return errors::Internal( diff --git a/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h index 4f2a306b72d..feccb1754d5 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h @@ -20,8 +20,8 @@ limitations under the License. 
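For the topological-sort utility added above (topological_sort.h/.cc), a minimal usage sketch; SameDialectFirst and ReorderBlock are hypothetical and only illustrate the PriorityFunction contract that the order_by_dialect pass relies on:

#include <vector>

#include "mlir/IR/Block.h"
#include "mlir/IR/Operation.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/topological_sort.h"

namespace {

// Prefer candidates from the same dialect as the previously emitted op, so
// ops of one dialect are grouped together wherever dependencies allow it.
int SameDialectFirst(mlir::Operation* predecessor, mlir::Operation* op) {
  if (!predecessor) return 0;
  return predecessor->getName().getDialectNamespace() ==
                 op->getName().getDialectNamespace()
             ? 1
             : 0;
}

// Computes the new order and physically moves the ops; the terminator sorts
// last, so it stays at the end of the block.
void ReorderBlock(mlir::Block& block) {
  std::vector<mlir::Operation*> order =
      mlir::TF::SortBlockTopologically(block, SameDialectFirst);
  for (mlir::Operation* op : order) op->moveBefore(&block, block.end());
}

}  // namespace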
#include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/platform/statusor.h" namespace tensorflow { @@ -37,8 +37,7 @@ mlir::LogicalResult ExtractTfVersions(mlir::ModuleOp module, // Returns TensorFlow GraphDef producer version for the given module. Returns an // error if the version information is missing for the module or is not valid. -::stream_executor::port::StatusOr GetTfGraphProducerVersion( - mlir::ModuleOp module); +::tsl::StatusOr GetTfGraphProducerVersion(mlir::ModuleOp module); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/verification_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/verification_utils.cc index be30e296a46..bcc70642c97 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/verification_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/verification_utils.cc @@ -25,7 +25,7 @@ namespace TF { LogicalResult VerifyShapeOfReshapeOp(ArrayRef shape) { bool has_dynamic_dim = false; for (int64_t dim : shape) { - if (dim != ShapedType::kDynamicSize) { + if (dim != ShapedType::kDynamic) { if (dim < 0) return failure(); continue; } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc index 2b607727e20..2a6d94828a3 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc @@ -62,8 +62,7 @@ mlir::LogicalResult CreateSplitOp(const int num_split, auto input_type = src_input.getType().cast(); if (input_type.hasRank()) { - if (input_type.getShape()[split_dimension] == - mlir::ShapedType::kDynamicSize) { + if (input_type.getShape()[split_dimension] == mlir::ShapedType::kDynamic) { output_type = input_type; } else { auto shape = llvm::to_vector<4>(input_type.getShape()); @@ -87,7 +86,7 @@ mlir::LogicalResult CreateSplitOp(const int num_split, // Creates a split op that splits |src_input| along |split_dimension|. 
llvm::SmallVector output_types(num_split, output_type); *split_op = builder->create( - location, output_types, split_dimension_op.output(), src_input); + location, output_types, split_dimension_op.getOutput(), src_input); (*split_op)->setAttr( kNumSplitAttr, builder->getIntegerAttr(builder->getIntegerType(32), num_split)); @@ -115,8 +114,7 @@ mlir::TF::ConcatOp CreateConcatOp(const int concat_dimension, auto input_type = inputs[0].getType().cast(); if (input_type.hasRank()) { - if (input_type.getShape()[concat_dimension] == - mlir::ShapedType::kDynamicSize) { + if (input_type.getShape()[concat_dimension] == mlir::ShapedType::kDynamic) { output_type = input_type; } else { auto shape = llvm::to_vector<4>(input_type.getShape()); @@ -129,7 +127,7 @@ mlir::TF::ConcatOp CreateConcatOp(const int concat_dimension, } return builder->create( - location, output_type, concat_dimension_op.output(), inputs); + location, output_type, concat_dimension_op.getOutput(), inputs); } // For tile sharded inputs to TPU computation, inject split op between the @@ -245,11 +243,11 @@ mlir::LogicalResult ExtractInputsForLogicalDevices( input_index, tiled_input_size, num_cores_per_replica)); }; - // If input is already partitioned using the `tf.TPUPartitionedInput` op, + // If input is already partitioned using the `tf.TPUPartitionedInputV2` op, // only replicated sharding is supported where i-th operand to - // `tf.TPUPartitionedInput` op is input to the i-th logical device. + // `tf.TPUPartitionedInputV2` op is input to the i-th logical device. if (auto partitioned_input = - llvm::dyn_cast_or_null( + llvm::dyn_cast_or_null( input_value.getDefiningOp())) { if (UnsupportedPartitionedShardingType(input_sharding_type)) return cluster_func->emitOpError() @@ -264,14 +262,15 @@ mlir::LogicalResult ExtractInputsForLogicalDevices( } } else { assert(input_sharding_type == xla::OpSharding::OTHER); - if (partitioned_input.inputs().size() != num_cores_per_replica) - return tiled_sharding_mismatched(partitioned_input.inputs().size()); + if (partitioned_input.getInputs().size() != num_cores_per_replica) + return tiled_sharding_mismatched( + partitioned_input.getInputs().size()); for (int i = 0; i < sharding.tile_assignment_devices_size(); ++i) { const int assigned_logical_device = sharding.tile_assignment_devices(i); (*input_list)[assigned_logical_device].emplace_back( - partitioned_input.inputs()[i]); + partitioned_input.getInputs()[i]); } } continue; @@ -479,7 +478,7 @@ mlir::LogicalResult ValidateAndGetTiledExecuteOutputShape( const auto output_splits = dimension_and_output_splits.value(); const auto output_shape = cluster_func_output_type.getShape(); - if (output_shape[dimension_index] == mlir::ShapedType::kDynamicSize) { + if (output_shape[dimension_index] == mlir::ShapedType::kDynamic) { *tiled_logical_computation_type = cluster_func_output_type; break; } @@ -577,15 +576,16 @@ mlir::LogicalResult RemapOutputsFromLogicalDevices( output_sharding_config[tpu_cluster_output_index]; const auto output_sharding_type = output_sharding.type(); - // If output is demultiplexed using the `tf.TPUPartitionedOutput` op, only + // If output is demultiplexed using the `tf.TPUPartitionedOutputV2` op, only // replicated sharding is supported where i-th output of - // `tf.TPUPartitionedOutput` op maps to the output of i-th logical device. - // Also `tf.TPUPartitionedOutput` op must be a unique user of + // `tf.TPUPartitionedOutputV2` op maps to the output of i-th logical device. 
+ // Also `tf.TPUPartitionedOutputV2` op must be a unique user of // TPU Cluster (`tf_device.old_parallel_execute`) output. - mlir::TF::TPUPartitionedOutputOp partitioned_output; + mlir::TF::TPUPartitionedOutputV2Op partitioned_output; for (auto user : old_parallel_execute_output.getUsers()) { if (auto partitioned_output_user = - llvm::dyn_cast_or_null(user)) { + llvm::dyn_cast_or_null( + user)) { partitioned_output = partitioned_output_user; break; } @@ -604,7 +604,7 @@ mlir::LogicalResult RemapOutputsFromLogicalDevices( if (output_sharding_type == xla::OpSharding::REPLICATED) { for (const auto& index_and_output : - llvm::enumerate(partitioned_output.output())) { + llvm::enumerate(partitioned_output.getOutput())) { const auto output_from_logical_device = new_parallel_execute.GetRegionOutputs( cluster_idx + @@ -621,7 +621,7 @@ mlir::LogicalResult RemapOutputsFromLogicalDevices( &tile_sharded_outputs))) return mlir::failure(); for (auto result : - llvm::zip(partitioned_output.output(), tile_sharded_outputs)) + llvm::zip(partitioned_output.getOutput(), tile_sharded_outputs)) std::get<0>(result).replaceAllUsesWith(std::get<1>(result)); } continue; diff --git a/tensorflow/compiler/mlir/tf2xla/BUILD b/tensorflow/compiler/mlir/tf2xla/BUILD index 1727147da34..206b77246c1 100644 --- a/tensorflow/compiler/mlir/tf2xla/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/BUILD @@ -1,6 +1,8 @@ # Description: # TF2XLA Bridge and related components. +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) + cc_library( name = "mlir_bridge_rollout_policy", srcs = ["mlir_bridge_rollout_policy.cc"], diff --git a/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h b/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h index f8fd9f59d56..262ebc0fd2e 100644 --- a/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h +++ b/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h @@ -36,9 +36,6 @@ enum class MlirBridgeRolloutPolicy { // features in the model, the MLIR bridge should be run. If the MLIR Bridge // errors, the fallback path should be used whenever possible. kEnabledAfterGraphAnalysis, - // The bridge was fallback enabled in a safe mode and passed all graph - // analysis checks. - kEnabledAfterGraphAnalysisSafeModeFallback }; // Analyzes the user requested policy as well as the contents of the graph and diff --git a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc index 2a732b54484..7c768abb3fd 100644 --- a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc @@ -13,10 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project -#include "mlir/InitAllDialects.h" // from @llvm-project -#include "mlir/InitAllPasses.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project #include "mlir/Tools/mlir-opt/MlirOptMain.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project #include "stablehlo/dialect/Register.h" // from @stablehlo #include "tensorflow//compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h" #include "tensorflow/compiler/mlir/init_mlir.h" @@ -34,16 +36,15 @@ limitations under the License. #include "tensorflow/compiler/mlir/tosa/tfl_passes.h" #include "tensorflow/compiler/mlir/tosa/transforms/passes.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" -#include "tensorflow/compiler/mlir/xla/transforms/xla_passes.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/transforms/passes.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/compiler/xla/mlir/framework/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/register.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/passes.h" int main(int argc, char **argv) { tensorflow::InitMlir y(&argc, &argv); - mlir::registerAllPasses(); + mlir::registerTransformsPasses(); mlir::registerTensorFlowPasses(); mlir::TFDevice::registerTensorFlowDevicePasses(); mlir::tf_saved_model::registerTensorFlowSavedModelPasses(); @@ -51,11 +52,9 @@ int main(int argc, char **argv) { mlir::mhlo::registerAllMhloPasses(); mlir::lmhlo::registerAllLmhloPasses(); // These are in compiler/mlir/xla and not part of the above MHLO passes. + mlir::mhlo::registerLegalizeTfPasses(); mlir::mhlo::registerTfXlaPasses(); - mlir::mhlo::registerXlaPasses(); - mlir::mhlo::registerLegalizeTFPass(); - mlir::mhlo::registerLegalizeTFControlFlowPass(); - mlir::mhlo::registerLegalizeTfTypesPassPass(); + mlir::mhlo::registerXlaFrameworkPasses(); mlir::tosa::registerLegalizeTosaPasses(); mlir::tosa::registerTFtoTOSALegalizationPipeline(); mlir::tosa::registerTFLtoTOSALegalizationPipeline(); @@ -66,15 +65,16 @@ int main(int argc, char **argv) { tensorflow::RegisterMlProgramPasses(); mlir::DialectRegistry registry; - mlir::registerAllDialects(registry); mlir::RegisterAllTensorFlowDialects(registry); mlir::mhlo::registerAllMhloDialects(registry); mlir::stablehlo::registerAllDialects(registry); - registry.insert(); - registry.insert(); + registry.insert(); registry.insert(); registry.insert(); - registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); return failed( mlir::MlirOptMain(argc, argv, "TensorFlow pass driver\n", registry)); } diff --git a/tensorflow/compiler/mlir/tf_mlir_reduce_main.cc b/tensorflow/compiler/mlir/tf_mlir_reduce_main.cc index a0e3df7f297..00953ce144a 100644 --- a/tensorflow/compiler/mlir/tf_mlir_reduce_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_reduce_main.cc @@ -27,10 +27,10 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" -#include "tensorflow/compiler/mlir/xla/transforms/xla_passes.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/transforms/passes.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir/framework/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/register.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/passes.h" namespace { @@ -60,9 +60,8 @@ int main(int argc, char *argv[]) { mlir::lmhlo::registerAllLmhloPasses(); // These are in compiler/mlir/xla and not part of the above MHLO passes. mlir::mhlo::registerTfXlaPasses(); - mlir::mhlo::registerXlaPasses(); + mlir::mhlo::registerXlaFrameworkPasses(); mlir::mhlo::registerLegalizeTFPass(); - mlir::mhlo::registerLegalizeTFControlFlowPass(); mlir::mhlo::registerLegalizeTfTypesPassPass(); mlir::DialectRegistry registry; diff --git a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc index dc7f6ed1ddc..6758aee3b77 100644 --- a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc @@ -88,7 +88,7 @@ int main(int argc, char** argv) { tensorflow::InitMlir y(&argc, &argv); // Add flags for all the registered translations. - llvm::cl::opt + llvm::cl::opt requested_translation("", llvm::cl::desc("Translation to perform")); mlir::registerAsmPrinterCLOptions(); llvm::cl::ParseCommandLineOptions(argc, argv, "TF MLIR translation driver\n"); @@ -155,10 +155,10 @@ int main(int argc, char** argv) { // Processes the memory buffer with a new MLIRContext. 
auto processBuffer = [&](std::unique_ptr ownedBuffer, llvm::raw_ostream& os) { - llvm::SourceMgr sourceMgr; - sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), llvm::SMLoc()); + auto sourceMgr = std::make_shared(); + sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), llvm::SMLoc()); mlir::MLIRContext context; - mlir::SourceMgrDiagnosticHandler diagnostic_handler(sourceMgr, &context); + mlir::SourceMgrDiagnosticHandler diagnostic_handler(*sourceMgr, &context); return (*requested_translation)(sourceMgr, os, &context); }; diff --git a/tensorflow/compiler/mlir/tfr/BUILD b/tensorflow/compiler/mlir/tfr/BUILD index db66e902602..80fc8ae1c84 100644 --- a/tensorflow/compiler/mlir/tfr/BUILD +++ b/tensorflow/compiler/mlir/tfr/BUILD @@ -15,6 +15,7 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ ":friends", ], @@ -223,7 +224,6 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:export_graphdef", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", - "//tensorflow/compiler/xla/stream_executor/lib", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", @@ -246,14 +246,13 @@ tf_cc_test( deps = [ ":tfr_decompose_ctx", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/stream_executor/lib", "//tensorflow/core:framework", "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/tsl/platform:statusor", "@com_google_absl//absl/types:span", - "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:IR", ], ) @@ -265,7 +264,6 @@ cc_library( deps = [ ":tfr_decompose_ctx", "//tensorflow/compiler/mlir:mlir_graph_optimization_pass", - "//tensorflow/compiler/xla/stream_executor/lib", "//tensorflow/core:lib", "//tensorflow/core/common_runtime:device_set", "@llvm-project//mlir:IR", @@ -297,7 +295,6 @@ cc_library( hdrs = ["integration/node_expansion_pass.h"], deps = [ ":tfr_decompose_ctx", - "//tensorflow/compiler/xla/stream_executor/lib", "//tensorflow/core:lib", "//tensorflow/core/common_runtime/eager:core_no_xla", "//tensorflow/core/common_runtime/eager:eager_op_rewrite_registry", diff --git a/tensorflow/compiler/mlir/tfr/examples/customization/BUILD b/tensorflow/compiler/mlir/tfr/examples/customization/BUILD index 10053c2425a..748a189e25c 100644 --- a/tensorflow/compiler/mlir/tfr/examples/customization/BUILD +++ b/tensorflow/compiler/mlir/tfr/examples/customization/BUILD @@ -2,6 +2,7 @@ load("//tensorflow:tensorflow.default.bzl", "tf_py_test") load("//tensorflow/compiler/mlir/tfr:build_defs.bzl", "gen_op_libraries") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ ":friends", ], diff --git a/tensorflow/compiler/mlir/tfr/examples/customization/test_ops_test.py b/tensorflow/compiler/mlir/tfr/examples/customization/test_ops_test.py index c0fa6aa4a6d..6310782e92e 100644 --- a/tensorflow/compiler/mlir/tfr/examples/customization/test_ops_test.py +++ b/tensorflow/compiler/mlir/tfr/examples/customization/test_ops_test.py @@ -24,7 +24,7 @@ class TestOpsDefsTest(test_utils.OpsDefsTest): def test_test_ops(self): - attr = tf.function(test_ops.test_attr)(T=tf.float32) + attr = tf.function(test_ops.test_attr)(tf.float32) self.assertAllClose(attr.numpy(), 100.0) diff 
--git a/tensorflow/compiler/mlir/tfr/examples/mnist/BUILD b/tensorflow/compiler/mlir/tfr/examples/mnist/BUILD index 4ae874c8fb3..4160d864f2e 100644 --- a/tensorflow/compiler/mlir/tfr/examples/mnist/BUILD +++ b/tensorflow/compiler/mlir/tfr/examples/mnist/BUILD @@ -3,6 +3,7 @@ load("//tensorflow/compiler/mlir/tfr:build_defs.bzl", "gen_op_libraries") load("//tensorflow/core/platform:distribute.bzl", "distribute_py_test") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ ":friends", ], diff --git a/tensorflow/compiler/mlir/tfr/examples/pad/BUILD b/tensorflow/compiler/mlir/tfr/examples/pad/BUILD index 1602facf04c..837ce6e4f22 100644 --- a/tensorflow/compiler/mlir/tfr/examples/pad/BUILD +++ b/tensorflow/compiler/mlir/tfr/examples/pad/BUILD @@ -2,6 +2,7 @@ load("//tensorflow:tensorflow.default.bzl", "tf_py_test") load("//tensorflow/compiler/mlir/tfr:build_defs.bzl", "gen_op_libraries") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ ":friends", ], diff --git a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.cc b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.cc index 5eb23b5bddf..206e5ef13f8 100644 --- a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.cc +++ b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.cc @@ -17,8 +17,8 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h" #include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/core/lib/monitoring/counter.h" +#include "tensorflow/tsl/platform/statusor.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h index 462ca6d7269..e415f5cbea9 100644 --- a/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h +++ b/tensorflow/compiler/mlir/tfr/integration/graph_decompose_pass.h @@ -18,7 +18,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h" #include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" +#include "tensorflow/tsl/platform/statusor.h" namespace tensorflow { namespace tfr { diff --git a/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.cc b/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.cc index 21d61d76b41..f078093e3e2 100644 --- a/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.cc +++ b/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.cc @@ -19,9 +19,9 @@ limitations under the License. 
#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/core/lib/monitoring/counter.h" #include "tensorflow/core/platform/status.h" +#include "tensorflow/tsl/platform/statusor.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.h b/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.h index 319d9f96955..d29f2bdbf32 100644 --- a/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.h +++ b/tensorflow/compiler/mlir/tfr/integration/node_expansion_pass.h @@ -16,8 +16,8 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_NODE_EXPANSION_PASS_H_ #include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h" +#include "tensorflow/tsl/platform/statusor.h" namespace tensorflow { namespace tfr { diff --git a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc index d031330e906..7fddd2526d5 100644 --- a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc +++ b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc @@ -47,10 +47,10 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h" #include "tensorflow/compiler/mlir/tfr/passes/passes.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/util/env_var.h" +#include "tensorflow/tsl/platform/statusor.h" namespace tensorflow { namespace tfr { diff --git a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h index ab97468c525..79c41a3f946 100644 --- a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h +++ b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h @@ -20,15 +20,15 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/tsl/platform/statusor.h" namespace tensorflow { namespace tfr { extern const char* const kTFRLibEnv; -using stream_executor::port::StatusOr; +using tsl::StatusOr; // An wrapper for all the objects used to decompose a module (graph mode) and // node_def (eager mode). Note that this class owns the decomposition library. diff --git a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx_test.cc b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx_test.cc index c4b2dd4c8f0..cf25bfd5020 100644 --- a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx_test.cc +++ b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx_test.cc @@ -22,7 +22,6 @@ limitations under the License. 
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/common_shape_fns.h" @@ -33,6 +32,7 @@ limitations under the License. #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/tsl/platform/statusor.h" using testing::ElementsAreArray; using testing::Test; diff --git a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc index 22253dc91ff..8c33af424b8 100644 --- a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc +++ b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc @@ -74,7 +74,7 @@ class TFRInlinerInterface : public DialectInlinerInterface { // Returns true if the given region 'src' can be inlined into the region // 'dest' that is attached to an operation registered to the current dialect. bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, - BlockAndValueMapping &) const final { + IRMapping &) const final { return true; } @@ -82,7 +82,7 @@ class TFRInlinerInterface : public DialectInlinerInterface { // dialect, can be inlined into the region 'dest' that is attached to an // operation registered to the current dialect. bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, - BlockAndValueMapping &) const final { + IRMapping &) const final { return true; } @@ -93,7 +93,7 @@ class TFRInlinerInterface : public DialectInlinerInterface { auto retValOp = dyn_cast(op); if (!retValOp) return; - for (auto ret_value : llvm::zip(valuesToRepl, retValOp.operands())) { + for (auto ret_value : llvm::zip(valuesToRepl, retValOp.getOperands())) { std::get<0>(ret_value).replaceAllUsesWith(std::get<1>(ret_value)); } } @@ -359,11 +359,15 @@ ParseResult TFRFuncOp::parse(OpAsmParser &parser, OperationState &result) { function_interface_impl::VariadicFlag, std::string &) { return builder.getFunctionType(arg_types, results); }; return function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, build_func_type); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), build_func_type, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void TFRFuncOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); + function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } } // namespace TFR @@ -679,21 +683,21 @@ class RemoveQParamsOp : public OpRewritePattern { auto scales_type = RankedTensorType::get( {static_cast(num_channels)}, rewriter.getF32Type()); auto scales_attr = - DenseElementsAttr::get(scales_type, llvm::makeArrayRef(scales)); + DenseElementsAttr::get(scales_type, llvm::ArrayRef(scales)); scale_op = rewriter.create(loc, scales_attr); auto zps_type = RankedTensorType::get( {static_cast(num_channels)}, rewriter.getI32Type()); - auto zps_attr = DenseElementsAttr::get(zps_type, llvm::makeArrayRef(zps)); + auto zps_attr = DenseElementsAttr::get(zps_type, llvm::ArrayRef(zps)); zp_op = rewriter.create(loc, zps_attr); } if (!scale_op || !zp_op) { return failure(); } auto scale_cast = rewriter.create( - loc, qparams_op.getScale().getType(), 
scale_op.output()); + loc, qparams_op.getScale().getType(), scale_op.getOutput()); auto zp_cast = rewriter.create(loc, qparams_op.getZp().getType(), - zp_op.output()); + zp_op.getOutput()); qparams_op.getScale().replaceAllUsesWith(scale_cast.getOut()); qparams_op.getZp().replaceAllUsesWith(zp_cast.getOut()); @@ -770,10 +774,9 @@ class RemoveScaleFactorOp : public OpRewritePattern { rewriter.setInsertionPoint(scale_factor_op); const Location loc = scale_factor_op->getLoc(); auto result_scale_op = rewriter.create( - loc, - DenseElementsAttr::get(scale_type, llvm::makeArrayRef(scale_factors))); + loc, DenseElementsAttr::get(scale_type, llvm::ArrayRef(scale_factors))); auto result_scale_cast_op = rewriter.create( - loc, scale_factor_op.getType(), result_scale_op.output()); + loc, scale_factor_op.getType(), result_scale_op.getOutput()); scale_factor_op.getScaleFactor().replaceAllUsesWith( result_scale_cast_op.getOut()); return success(); @@ -810,7 +813,7 @@ class RemoveRescaleOp : public OpRewritePattern { auto zp_tensor = rewriter.create( loc, RankedTensorType::get({}, zp.getType()), zp_attr); auto zp_cast = rewriter.create( - loc, rewriter.getType(), zp_tensor.output()); + loc, rewriter.getType(), zp_tensor.getOutput()); rewriter.setInsertionPoint(rescale_op); auto cast_input_to_float_op = rewriter.create( @@ -893,14 +896,17 @@ void TFRQuantScaleFactorOp::getCanonicalizationPatterns( results.add(context); } -OpFoldResult TFR::EqualOp::fold(ArrayRef operands) { +OpFoldResult TFR::EqualOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); assert(operands.size() == 2 && "equal op has two operands"); auto ctx = getContext(); if (operands[0] == operands[1]) return BoolAttr::get(ctx, true); return BoolAttr::get(ctx, false); } -OpFoldResult ConstOp::fold(ArrayRef operands) { +OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); + (void)operands; assert(operands.empty() && "constant has no operands"); // Return the held attribute value. diff --git a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td index 12bac76e735..66ae99ae2ec 100644 --- a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td +++ b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td @@ -41,8 +41,7 @@ def TFR_Dialect : Dialect { }]; let cppNamespace = "::mlir::TFR"; - - let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; + let useFoldAPI = kEmitFoldAdaptorFolder; } //===----------------------------------------------------------------------===// @@ -416,7 +415,9 @@ def TFR_TFRFuncOp : TFR_Op<"func", [HasParent<"ModuleOp">, let arguments = (ins TypeAttrOf:$function_type, - StrAttr:$sym_name + StrAttr:$sym_name, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let results = (outs); @@ -473,9 +474,9 @@ def TFR_TFRReturnOp : TFR_Op<"return", [HasParent<"TFRFuncOp">, Pure, Note that only the tfr.tensor and tfr.tensor_list can be returned. 
}]; - let arguments = (ins Variadic:$operands); + let arguments = (ins Variadic:$arguments); - let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; + let assemblyFormat = "attr-dict ($arguments^ `:` type($arguments))?"; } //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tfr/passes/canonicalize.cc b/tensorflow/compiler/mlir/tfr/passes/canonicalize.cc index df30f3779d3..443781b6b63 100644 --- a/tensorflow/compiler/mlir/tfr/passes/canonicalize.cc +++ b/tensorflow/compiler/mlir/tfr/passes/canonicalize.cc @@ -24,7 +24,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project @@ -66,7 +66,7 @@ class UnrollSCFForOp : public OpRewritePattern { // TODO(fengliuai): use loopUnrollByFactor once the iter_arg is supported Block *single_block = for_op.getBody(); - BlockAndValueMapping mapping; + IRMapping mapping; Value iv = for_op.getInductionVar(); for (auto iter_op : llvm::zip(for_op.getRegionIterArgs(), for_op.getInitArgs())) { diff --git a/tensorflow/compiler/mlir/tfr/resources/BUILD b/tensorflow/compiler/mlir/tfr/resources/BUILD index 07944d6e940..12966f2920f 100644 --- a/tensorflow/compiler/mlir/tfr/resources/BUILD +++ b/tensorflow/compiler/mlir/tfr/resources/BUILD @@ -1,6 +1,7 @@ load("//tensorflow/compiler/mlir/tfr:build_defs.bzl", "gen_op_bindings") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ ":friends", ], diff --git a/tensorflow/compiler/mlir/tfr/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tfr/tests/canonicalize.mlir index 6188141a445..77508b60046 100644 --- a/tensorflow/compiler/mlir/tfr/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tfr/tests/canonicalize.mlir @@ -1,4 +1,4 @@ -// RUN: tfr-opt %s -canonicalize -verify-diagnostics -split-input-file | FileCheck %s +// RUN: tfr-opt %s -canonicalize="test-convergence" -verify-diagnostics -split-input-file | FileCheck %s // Tests for ops with canonicalization patterns. diff --git a/tensorflow/compiler/mlir/tfr/tests/ops.mlir b/tensorflow/compiler/mlir/tfr/tests/ops.mlir index e05f96b83a8..592ab9e9090 100644 --- a/tensorflow/compiler/mlir/tfr/tests/ops.mlir +++ b/tensorflow/compiler/mlir/tfr/tests/ops.mlir @@ -28,7 +28,7 @@ func.func private @tensor_list_type_tuple_like() -> !tfr.tensor_list // ----- -// expected-error@+1 {{unbalanced '>' character in pretty dialect name}} +// expected-error@+1 {{unbalanced '[' character in pretty dialect name}} func.func private @tensor_invalid_1() -> !tfr.tensor<[N, T> // ----- diff --git a/tensorflow/compiler/mlir/tfrt/BUILD b/tensorflow/compiler/mlir/tfrt/BUILD index 296766aa31d..bc3fbfb8cdb 100644 --- a/tensorflow/compiler/mlir/tfrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/BUILD @@ -12,6 +12,7 @@ load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud", "get_co # TF to TFRT kernels conversion. 
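The canonicalize.cc and tfr_ops.cc hunks above pick up the BlockAndValueMapping -> IRMapping rename; usage is otherwise unchanged. A tiny sketch (illustrative; CloneWithMapping is hypothetical):

#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"

// Clones an op while remapping one of its operands, using the renamed
// IRMapping class exactly as BlockAndValueMapping was used before.
mlir::Operation* CloneWithMapping(mlir::OpBuilder& builder, mlir::Operation* op,
                                  mlir::Value old_operand,
                                  mlir::Value new_operand) {
  mlir::IRMapping mapping;
  mapping.map(old_operand, new_operand);
  return builder.clone(*op, mapping);
}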
package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [":friends"], licenses = ["notice"], ) @@ -38,7 +39,6 @@ package_group( "//smartass/brain/ops/...", "//tensorflow_serving/servables/tensorflow/google/...", "//third_party/tf_runtime_google/...", - "//third_party/auroraml/...", ]), ) @@ -113,9 +113,11 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", "//tensorflow/compiler/mlir/tfrt/jit/transforms:tf_jitrt_passes", "//tensorflow/compiler/mlir/xla:xla_legalize_tf", + "//tensorflow/compiler/xla/mlir/backends/cpu/transforms:passes", "//tensorflow/compiler/xla/mlir/runtime/ir:rt", "//tensorflow/compiler/xla/mlir/runtime/transforms:compiler", "//tensorflow/compiler/xla/mlir_hlo:gml_st_passes", + "//tensorflow/compiler/xla/mlir_hlo:gml_st_transforms", "//tensorflow/compiler/xla/mlir_hlo:mhlo_passes", "//tensorflow/compiler/xla/mlir_hlo:transforms_passes", "//tensorflow/compiler/xla/runtime:compiler", @@ -231,7 +233,6 @@ tfrt_cc_library( name = "tf_jitrt_request_context", srcs = ["jit/tf_jitrt_request_context.cc"], hdrs = ["jit/tf_jitrt_request_context.h"], - # copybara:uncomment compatible_with = ["//buildenv/target:gce"], deps = [ "//tensorflow/compiler/xla/runtime:async_values_cache", "//tensorflow/compiler/xla/runtime:jit_executable", @@ -338,7 +339,6 @@ cc_library( "@llvm-project//mlir:Transforms", "@tf_runtime//:basic_kernels_opdefs", "@tf_runtime//:core_runtime_opdefs", - "@tf_runtime//:distributed_kernels_opdefs", ], ) @@ -361,6 +361,22 @@ cc_library( ], ) +cc_library( + name = "transforms/gpu_passes", + srcs = ["transforms/gpu_passes.cc"], + hdrs = ["transforms/gpu_passes.h"], + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tfrt/ir:tfrt_gpu_opdefs", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], +) + cc_library( name = "tf_to_tfrt", srcs = [ @@ -372,10 +388,10 @@ cc_library( "transforms/merge_tf_if_ops.cc", "transforms/optimize.cc", "transforms/optimize_tf_control_flow_side_effect.cc", - "transforms/remote_run_encapsulate.cc", "transforms/remove_device_attribute.cc", "transforms/remove_tf_if_const_args.cc", "transforms/reorder_assert.cc", + "transforms/sink_in_invariant_ops.cc", "transforms/tf_to_tfrt.cc", "transforms/tpu_passes.h", ], @@ -390,6 +406,7 @@ cc_library( ":tensor_array_side_effect_analysis", ":tf_jitrt_opdefs", ":tf_jitrt_pipeline", + ":transforms/gpu_passes", ":transforms/set_shape_invariant_in_while_ops", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", @@ -413,49 +430,20 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core/platform:tstring", "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_async_opdefs", + "//tensorflow/compiler/mlir/tfrt/ir:tfrt_gpu_opdefs", "@tf_runtime//:basic_kernels_opdefs", "@tf_runtime//:core_runtime_opdefs", - "@tf_runtime//:distributed_kernels_opdefs", "@tf_runtime//backends/jitrt:jitrt_opdefs", "@tf_runtime//:stream_analysis", "@tf_runtime//:test_kernels_opdefs", + "//tensorflow/compiler/mlir/tfrt:transform_utils", + "//tensorflow/tsl/platform:status", ] + if_google([ "//learning/brain/tfrt/tpu/compiler/mlir:tf_to_tfrt_tpu", ]), alwayslink = 1, ) -cc_library( - name = "tf_to_tfrt_data", - srcs = [ - "transforms/tf_to_tfrt_data.cc", - ], - hdrs = [ - "transforms/tf_to_tfrt_data.h", - ], - deps = [ - ":tf_to_tfrt", - "//tensorflow/compiler/mlir/tensorflow", - 
"//tensorflow/compiler/mlir/tensorflow:error_util", - "//tensorflow/compiler/mlir/tensorflow:import_model", - "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/platform:errors", - "//tensorflow/core/platform:status", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Transforms", - "@tf_runtime//:basic_kernels_opdefs", - "@tf_runtime//:bef", - "@tf_runtime//:data_opdefs", - "@tf_runtime//:mlirtobef", - ], - alwayslink = 1, -) - cc_library( name = "host_context_util", srcs = ["utils/host_context.cc"], @@ -565,6 +553,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core/common_runtime:function_body", "//tensorflow/core/common_runtime:function_def_utils", + "//tensorflow/core/tfrt/fallback:fallback_state", "//tensorflow/core/platform:status", "@tf_runtime//:bef", "@tf_runtime//:mlirtobef", @@ -644,7 +633,6 @@ cc_library( "//tensorflow/compiler/mlir/tfrt/jit/transforms:tf_jitrt_passes", "//tensorflow/compiler/mlir/tfrt/jit/transforms:tf_jitrt_test_passes", "//tensorflow/compiler/mlir/tfrt:tf_to_tfrt", - "//tensorflow/compiler/mlir/tfrt:tf_to_tfrt_data", ] + if_google([ "//learning/brain/tfrt/tpu/compiler/mlir:tf_to_tfrt_tpu", ]), @@ -678,9 +666,11 @@ tf_cc_binary( "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:bridge_pass_test_pipeline_registration", "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", + "//tensorflow/compiler/mlir/tfrt:transforms/gpu_passes", "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_async_opdefs", "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_opdefs", "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_sync_opdefs", + "//tensorflow/compiler/mlir/tfrt/ir:tfrt_gpu_opdefs", "//tensorflow/compiler/mlir/tfrt/jit/transforms:tf_jitrt_passes", "//tensorflow/compiler/mlir/tfrt/jit/transforms:tf_jitrt_test_passes", "//tensorflow/compiler/xla/mlir_hlo:gml_st", @@ -760,7 +750,6 @@ cc_library( "tfrt_fallback_registration.h", ], visibility = [":friends"] + if_google([ - "//learning/brain/experimental/tfrt/distributed_runtime:__pkg__", "//learning/brain/experimental/tfrt/visualization:__pkg__", # Allow visibility from the mlir language server. 
"//learning/brain/mlir/mlir_lsp_server:__pkg__", @@ -810,3 +799,37 @@ cc_library( "@tf_runtime//:core_runtime_opdefs", ], ) + +cc_library( + name = "transform_utils", + srcs = [ + "transforms/utils.cc", + ], + hdrs = [ + "transforms/utils.h", + ], + visibility = [":friends"], + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "@tf_runtime//:basic_kernels_opdefs", + "@tf_runtime//:core_runtime_opdefs", + "@tf_runtime//:support", + ], +) + +cc_library( + name = "transforms/update_op_cost_in_tfrt_mlir", + srcs = ["transforms/update_op_cost_in_tfrt_mlir.cc"], + hdrs = ["transforms/update_op_cost_in_tfrt_mlir.h"], + deps = [ + "//tensorflow/core/tfrt/fallback:cost_recorder", + "@llvm-project//mlir:IR", + ], +) diff --git a/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.cc b/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.cc index 0de55acca2a..3fedc1deb3f 100644 --- a/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.cc +++ b/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.cc @@ -55,8 +55,8 @@ int64_t InferLookupTableFindV2Cost(const CostContext& context, constexpr int64_t kLookupTableFindCostScale = 8; constexpr int64_t kLookupTableFindStringKeyCostScale = 16; - auto value_type = op.values().getType().cast(); - auto key_type = op.keys().getType().cast(); + auto value_type = op.getValues().getType().cast(); + auto key_type = op.getKeys().getType().cast(); int64_t output_size = InferTensorSize(context, value_type); @@ -71,14 +71,14 @@ int64_t InferLookupTableFindV2Cost(const CostContext& context, // The cost function for tf.GatherV2. int64_t InferGatherV2Cost(const CostContext& context, mlir::TF::GatherV2Op op) { return InferTensorSize(context, - op.output().getType().cast()); + op.getOutput().getType().cast()); } // The cost function for tf.SparseSegmentSumOp. template int64_t InferSparseSegmentOpCost(const CostContext& context, OpType op) { return InferTensorSize( - context, op.output().getType().template cast()); + context, op.getOutput().getType().template cast()); } // CostFunctionRegistry is a map from op names to their cost functions. @@ -128,14 +128,7 @@ void RegisterCostFunction(absl::string_view op_name, std::move(cost_function)); } -int64_t CostAnalysis::GetCost(mlir::Operation* op, int64_t op_key) const { - // Try to use its measured cost. - const auto& measured_cost_map = op_cost_map_proto_.op_cost_map(); - if (const auto op_cost = measured_cost_map.find(op_key); - op_cost != measured_cost_map.end()) { - return op_cost->second; - } - +int64_t CostAnalysis::GetCost(mlir::Operation* op) const { assert(cost_map_.count(op) > 0); return cost_map_.lookup(op); } @@ -201,16 +194,5 @@ void CostAnalysis::EvaluateCost(mlir::Operation* op) { cost_map_[op] = cost; } -Status CostAnalysis::ReadMeasuredCosts() { - const char* env_var = getenv("TF_TFRT_MEASURED_COST_PATH"); - // No need to read because the cost measurement is disabled. 
- if (env_var == nullptr) return OkStatus(); - - tensorflow::Env* env = Env::Default(); - const std::string measured_cost_path(env_var); - TF_RETURN_IF_ERROR(env->FileExists(measured_cost_path)); - return ReadTextProto(env, measured_cost_path, &op_cost_map_proto_); -} - } // namespace tfrt_compiler } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.h b/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.h index 77d2fedf336..9f8d957e066 100644 --- a/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.h +++ b/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.h @@ -38,21 +38,18 @@ class CostAnalysis { public: explicit CostAnalysis(mlir::func::FuncOp func_op) { AnalyzeArguments(func_op); - TF_CHECK_OK(ReadMeasuredCosts()); AnalyzeBlock(&func_op.front()); } - int64_t GetCost(mlir::Operation* op, int64_t op_key) const; + int64_t GetCost(mlir::Operation* op) const; private: void AnalyzeArguments(mlir::func::FuncOp func_op); void AnalyzeBlock(mlir::Block* block); void EvaluateCost(mlir::Operation* op); - Status ReadMeasuredCosts(); int64_t max_arg_size_ = 1; llvm::DenseMap cost_map_; - tfrt_stub::OpCostMapProto op_cost_map_proto_; }; struct CostContext { diff --git a/tensorflow/compiler/mlir/tfrt/analysis/tensor_array_side_effect_analysis.cc b/tensorflow/compiler/mlir/tfrt/analysis/tensor_array_side_effect_analysis.cc index c27960fbe26..53cd5945766 100644 --- a/tensorflow/compiler/mlir/tfrt/analysis/tensor_array_side_effect_analysis.cc +++ b/tensorflow/compiler/mlir/tfrt/analysis/tensor_array_side_effect_analysis.cc @@ -31,9 +31,7 @@ bool IsTensorArrayOp(mlir::Operation* op) { static bool FunctionContainsOnlyNoSideEffectOpOrTensorArrayOp( mlir::func::FuncOp func_op) { for (mlir::Operation& op : func_op.front()) { - if (!mlir::MemoryEffectOpInterface::hasNoEffect(&op) && - !IsTensorArrayOp(&op)) - return false; + if (!mlir::isMemoryEffectFree(&op) && !IsTensorArrayOp(&op)) return false; } return true; diff --git a/tensorflow/compiler/mlir/tfrt/analysis/test_cost_analysis_pass.cc b/tensorflow/compiler/mlir/tfrt/analysis/test_cost_analysis_pass.cc index 84470c64f63..60c4f6e480b 100644 --- a/tensorflow/compiler/mlir/tfrt/analysis/test_cost_analysis_pass.cc +++ b/tensorflow/compiler/mlir/tfrt/analysis/test_cost_analysis_pass.cc @@ -32,9 +32,8 @@ class TestCostAnalysis const auto& cost_analysis = getAnalysis(); auto func_op = getOperation(); - int64_t op_key = 0; for (auto& op : func_op.front()) { - op.emitRemark() << "Cost: " << cost_analysis.GetCost(&op, op_key++); + op.emitRemark() << "Cost: " << cost_analysis.GetCost(&op); } } }; diff --git a/tensorflow/compiler/mlir/tfrt/benchmarks/BUILD b/tensorflow/compiler/mlir/tfrt/benchmarks/BUILD index 707ef6c4abd..00a5639381b 100644 --- a/tensorflow/compiler/mlir/tfrt/benchmarks/BUILD +++ b/tensorflow/compiler/mlir/tfrt/benchmarks/BUILD @@ -5,7 +5,10 @@ load( "tf_cc_test", ) -package(default_visibility = ["//visibility:private"]) +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//visibility:private"], +) licenses(["notice"]) @@ -47,6 +50,7 @@ cc_library( hdrs = ["benchmark_mlir_function.h"], deps = [ ":benchmark", + "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tfrt:host_context_util", "//tensorflow/compiler/mlir/tfrt:runtime_fallback_executor", @@ -88,6 +92,16 @@ tf_cc_test( deps = [":cwise_op_unary_benchmark"], ) +tf_cc_test( + name = "concat_benchmark", + testonly = 1, + srcs = 
["concat_benchmark.cc"], + deps = [ + ":benchmark", + ":benchmark_mlir_function", + ], +) + tf_cc_binary( name = "cwise_op_fusion_benchmark", testonly = 1, @@ -164,7 +178,9 @@ tf_cc_binary( ]), deps = [ ":benchmark", + "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/mlir/tfrt:host_context_util", + "@llvm-project//llvm:Support", ], ) @@ -206,6 +222,54 @@ tf_cc_binary( ], ) +tf_cc_binary( + name = "map_op_benchmark", + testonly = 1, + srcs = ["map_op_benchmark.cc"], + # Args() not supported. Enable when we got rid of tf benchmark and use the + # standard gunit benchmark. + tags = if_oss([ + "no_oss", + "manual", + ]), + deps = [ + ":benchmark", + ":benchmark_mlir_function", + ], +) + +tf_cc_binary( + name = "fused_map_bcast_benchmark", + testonly = 1, + srcs = ["fused_map_bcast_benchmark.cc"], + # Args() not supported. Enable when we got rid of tf benchmark and use the + # standard gunit benchmark. + tags = if_oss([ + "no_oss", + "manual", + ]), + deps = [ + ":benchmark", + ":benchmark_mlir_function", + ], +) + +tf_cc_binary( + name = "scatter_op_benchmark", + testonly = 1, + srcs = ["scatter_op_benchmark.cc"], + # Args() not supported. Enable when we got rid of tf benchmark and use the + # standard gunit benchmark. + tags = if_oss([ + "no_oss", + "manual", + ]), + deps = [ + ":benchmark", + ":benchmark_mlir_function", + ], +) + tf_cc_binary( name = "sum_full_op_benchmark", testonly = 1, @@ -307,3 +371,20 @@ tf_cc_binary( "@llvm-project//llvm:Support", ], ) + +tf_cc_binary( + name = "reverse_op_benchmark", + testonly = 1, + srcs = ["reverse_op_benchmark.cc"], + # Args() not supported. Enable when we got rid of tf benchmark and use the + # standard gunit benchmark. + tags = if_oss([ + "no_oss", + "manual", + ]), + deps = [ + ":benchmark", + ":benchmark_mlir_function", + "@llvm-project//llvm:Support", + ], +) diff --git a/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.cc b/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.cc index a038a91ff8f..58ff8929c2a 100644 --- a/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.cc +++ b/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.cc @@ -22,6 +22,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" #include "mlir/Support/FileUtilities.h" #include "llvm/Support/FormatVariadic.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "tensorflow/compiler/xla/mlir/runtime/transforms/compiler.h" #include "tensorflow/core/platform/logging.h" @@ -128,6 +129,16 @@ MemrefDesc TensorToMemrefDesc(const Tensor& tensor) { }); } +llvm::SmallVector GetTensorTypeShape( + llvm::ArrayRef shape, llvm::ArrayRef dynamic_dims) { + llvm::SmallVector type_shape; + for (int64_t i = 0; i < shape.size(); ++i) { + type_shape.push_back( + dynamic_dims[i] == kDynamicDim ? mlir::ShapedType::kDynamic : shape[i]); + } + return type_shape; +} + std::string PrintTensorType(llvm::ArrayRef shape, llvm::StringRef element_type) { std::string result{"tensor<"}; diff --git a/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.h b/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.h index e8350f2c977..c315d2e9917 100644 --- a/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.h +++ b/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.h @@ -168,7 +168,10 @@ struct ExecuteAssignOp { // Common utilities. 
// -------------------------------------------------------------------------- // -static constexpr int64_t kDynSize = mlir::ShapedType::kDynamicSize; +static constexpr int64_t kDynSize = mlir::ShapedType::kDynamic; + +llvm::SmallVector GetTensorTypeShape( + llvm::ArrayRef shape, llvm::ArrayRef dynamic_dims); // Prints an MLIR tensor type, i.e. for `shape` {1, kDynSize} and `element_type` // "f32" the output is "tensor<1x?xf32>". diff --git a/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark_mlir_function.cc b/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark_mlir_function.cc index 04d2b524d8d..de7bebf43d9 100644 --- a/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark_mlir_function.cc +++ b/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark_mlir_function.cc @@ -21,6 +21,7 @@ limitations under the License. #include "llvm/Support/SourceMgr.h" #include "mlir/Parser/Parser.h" // from @llvm-project +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_executor.h" #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h" #include "tensorflow/compiler/mlir/tfrt/utils/host_context.h" @@ -90,6 +91,9 @@ void RunJitRtBenchmark(::testing::benchmark::State& state, : CreateSingleThreadedHostContext(); TfJitRtPipelineOptions tf_jitrt_opts; + tf_jitrt_opts.enable_xla_cpu_transformations = + tensorflow::GetJitRtFlags().enable_xla_cpu_transformations; + tf_jitrt_opts.lower_to_mmt4d = tensorflow::GetJitRtFlags().pack_matmul; tf_jitrt_opts.vectorize = vectorize; tf_jitrt_opts.codegen_transpose = codegen_transpose; JitExecutable& jit_executable = diff --git a/tensorflow/compiler/mlir/tfrt/benchmarks/concat_benchmark.cc b/tensorflow/compiler/mlir/tfrt/benchmarks/concat_benchmark.cc new file mode 100644 index 00000000000..a020c946077 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/benchmarks/concat_benchmark.cc @@ -0,0 +1,435 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.h" +#include "tensorflow/compiler/mlir/tfrt/benchmarks/benchmark_mlir_function.h" + +namespace tensorflow { +namespace { + +llvm::SmallVector GetOutputTypeShape(llvm::ArrayRef arg_shape, + llvm::ArrayRef dynamic_dims, + int64_t concat_dim, + int64_t num_concats) { + llvm::SmallVector out_ty = + GetTensorTypeShape(arg_shape, dynamic_dims); + if (dynamic_dims[concat_dim] == kStaticDim) out_ty[concat_dim] *= num_concats; + return out_ty; +} + +static const char* kBinaryConcatIR = R"( +func.func @main(%lhs: {0}, %rhs: {0}) -> {1} { + %0 = "tf.Log"(%lhs): ({0}) -> {0} + %1 = "tf.Log"(%rhs): ({0}) -> {0} + %2 = "tf.Const"() {{ value = dense<{2}> : tensor } : () -> tensor + %3 = "tf.ConcatV2"(%0, %1, %2) : ({0}, {0}, tensor) -> {1} + %4 = "tf.Log"(%3): ({1}) -> {1} + func.return %4 : {1} +} +)"; + +std::string GetBinaryConcatIR(llvm::ArrayRef arg_shape, + llvm::ArrayRef dynamic_dims, + int64_t concat_dim) { + llvm::SmallVector in_ty = + GetTensorTypeShape(arg_shape, dynamic_dims); + llvm::SmallVector out_ty = GetOutputTypeShape( + arg_shape, dynamic_dims, concat_dim, /*num_concats=*/2); + return llvm::formatv(kBinaryConcatIR, PrintTensorType(in_ty, "f32"), + PrintTensorType(out_ty, "f32"), concat_dim); +} + +static const char* kTeraryConcatIR = R"( +func.func @main(%arg0: {0}, %arg1: {0}, %arg2: {0}) -> {1} { + %0 = "tf.Log"(%arg0): ({0}) -> {0} + %1 = "tf.Log"(%arg1): ({0}) -> {0} + %2 = "tf.Log"(%arg2): ({0}) -> {0} + %3 = "tf.Const"() {{ value = dense<{2}> : tensor } : () -> tensor + %4 = "tf.ConcatV2"(%0, %1, %2, %3) : ({0}, {0}, {0}, tensor) -> {1} + %5 = "tf.Log"(%4): ({1}) -> {1} + func.return %5 : {1} +} +)"; + +std::string GetTernaryConcatIR(llvm::ArrayRef arg_shape, + llvm::ArrayRef dynamic_dims, + int64_t concat_dim) { + llvm::SmallVector in_ty = + GetTensorTypeShape(arg_shape, dynamic_dims); + llvm::SmallVector out_ty = GetOutputTypeShape( + arg_shape, dynamic_dims, concat_dim, /*num_concats=*/3); + return llvm::formatv(kTeraryConcatIR, PrintTensorType(in_ty, "f32"), + PrintTensorType(out_ty, "f32"), concat_dim); +} + +static const char* kOctonaryConcatIR = R"( +func.func @main(%arg0: {0}, %arg1: {0}, %arg2: {0}, %arg3: {0}, %arg4: {0}, + %arg5: {0}, %arg6: {0}, %arg7: {0}) -> {1} { + %0 = "tf.Log"(%arg0): ({0}) -> {0} + %1 = "tf.Log"(%arg1): ({0}) -> {0} + %2 = "tf.Log"(%arg2): ({0}) -> {0} + %3 = "tf.Log"(%arg3): ({0}) -> {0} + %4 = "tf.Log"(%arg4): ({0}) -> {0} + %5 = "tf.Log"(%arg5): ({0}) -> {0} + %6 = "tf.Log"(%arg6): ({0}) -> {0} + %7 = "tf.Log"(%arg7): ({0}) -> {0} + %8 = "tf.Const"() {{ value = dense<{2}> : tensor } : () -> tensor + %9 = "tf.ConcatV2"(%0, %1, %2, %3, %4, %5, %6, %7, %8) + : ({0}, {0}, {0}, {0}, {0}, {0}, {0}, {0}, tensor) -> {1} + %10 = "tf.Log"(%9): ({1}) -> {1} + func.return %10 : {1} +} +)"; + +std::string GetOctonaryConcatIR(llvm::ArrayRef arg_shape, + llvm::ArrayRef dynamic_dims, + int64_t concat_dim) { + llvm::SmallVector in_ty = + GetTensorTypeShape(arg_shape, dynamic_dims); + llvm::SmallVector out_ty = GetOutputTypeShape( + arg_shape, dynamic_dims, concat_dim, /*num_concats=*/8); + return llvm::formatv(kOctonaryConcatIR, PrintTensorType(in_ty, "f32"), + PrintTensorType(out_ty, "f32"), concat_dim); +} + +template +TensorShape GetOutShape(llvm::ArrayRef inputs) { + auto lhsShape = inputs[0].shape(); + auto rhsShape = inputs[1].shape(); + std::vector shape; + 
shape.reserve(inputs[0].dims()); + for (int64_t i = 0; i < inputs[0].dims(); ++i) + shape.push_back(inputs[0].dim_size(i)); + for (int64_t i = 1; i < inputs.size(); ++i) shape[D] += inputs[i].dim_size(D); + return TensorShape(shape); +} + +template +auto GetEigenBinaryConcatFn() { + return [](llvm::ArrayRef inputs, + llvm::Optional device) { + auto lhs = inputs[0].tensor(); + auto rhs = inputs[1].tensor(); + Tensor output(DT_FLOAT, GetOutShape(inputs)); + auto out = output.tensor(); + if (device.has_value()) { + out.device(*device) = lhs.log().concatenate(rhs.log(), D).log(); + } else { + out = lhs.log().concatenate(rhs.log(), D).log(); + } + }; +} + +template +auto GetEigenTernaryConcatFn() { + return [](llvm::ArrayRef inputs, + llvm::Optional device) { + auto arg0 = inputs[0].tensor(); + auto arg1 = inputs[1].tensor(); + auto arg2 = inputs[2].tensor(); + Tensor output(DT_FLOAT, GetOutShape(inputs)); + auto out = output.tensor(); + if (device.has_value()) { + out.device(*device) = arg0.log() + .concatenate(arg1.log(), D) + .concatenate(arg2.log(), D) + .log(); + } else { + out = arg0.log() + .concatenate(arg1.log(), D) + .concatenate(arg2.log(), D) + .log(); + } + }; +} + +template +auto GetEigenOctonaryConcatFn() { + return [](llvm::ArrayRef inputs, + llvm::Optional device) { + auto arg0 = inputs[0].tensor(); + auto arg1 = inputs[1].tensor(); + auto arg2 = inputs[2].tensor(); + auto arg3 = inputs[3].tensor(); + auto arg4 = inputs[4].tensor(); + auto arg5 = inputs[5].tensor(); + auto arg6 = inputs[6].tensor(); + auto arg7 = inputs[7].tensor(); + Tensor output(DT_FLOAT, GetOutShape(inputs)); + auto out = output.tensor(); + if (device.has_value()) { + out.device(*device) = arg0.log() + .concatenate(arg1.log(), D) + .concatenate(arg2.log(), D) + .concatenate(arg3.log(), D) + .concatenate(arg4.log(), D) + .concatenate(arg5.log(), D) + .concatenate(arg6.log(), D) + .concatenate(arg7.log(), D) + .log(); + } else { + out = arg0.log() + .concatenate(arg1.log(), D) + .concatenate(arg2.log(), D) + .concatenate(arg3.log(), D) + .concatenate(arg4.log(), D) + .concatenate(arg5.log(), D) + .concatenate(arg6.log(), D) + .concatenate(arg7.log(), D) + .log(); + } + }; +} + +#define WRAP(...) 
__VA_ARGS__ + +#define BM_BINARY_CONCAT(NAME, RANK, ARG_SHAPE, DYNAMIC_DIMS, CONCAT_DIM) \ + BM_Jitrt(BinaryConcat_##NAME, \ + GetBinaryConcatIR({ARG_SHAPE}, {DYNAMIC_DIMS}, CONCAT_DIM), "main", \ + llvm::ArrayRef({InputTensorSpec(DT_FLOAT, {ARG_SHAPE}), \ + InputTensorSpec(DT_FLOAT, {ARG_SHAPE})})) \ + ->Arg(0); \ + BM_JitrtV(BinaryConcat_##NAME, \ + GetBinaryConcatIR({ARG_SHAPE}, {DYNAMIC_DIMS}, CONCAT_DIM), \ + "main", \ + llvm::ArrayRef({InputTensorSpec(DT_FLOAT, {ARG_SHAPE}), \ + InputTensorSpec(DT_FLOAT, {ARG_SHAPE})})) \ + ->Arg(0); \ + BM_Eigen(BinaryConcat_##NAME, (GetEigenBinaryConcatFn()), \ + llvm::ArrayRef({InputTensorSpec(DT_FLOAT, {ARG_SHAPE}), \ + InputTensorSpec(DT_FLOAT, {ARG_SHAPE})})) \ + ->Arg(0) + +#define BM_TERNARY_CONCAT(NAME, RANK, ARG_SHAPE, DYNAMIC_DIMS, CONCAT_DIM) \ + BM_Jitrt(TernaryConcat_##NAME, \ + GetTernaryConcatIR({ARG_SHAPE}, {DYNAMIC_DIMS}, CONCAT_DIM), \ + "main", \ + llvm::ArrayRef({InputTensorSpec(DT_FLOAT, {ARG_SHAPE}), \ + InputTensorSpec(DT_FLOAT, {ARG_SHAPE}), \ + InputTensorSpec(DT_FLOAT, {ARG_SHAPE})})) \ + ->Arg(0); \ + BM_JitrtV(TernaryConcat_##NAME, \ + GetTernaryConcatIR({ARG_SHAPE}, {DYNAMIC_DIMS}, CONCAT_DIM), \ + "main", \ + llvm::ArrayRef({InputTensorSpec(DT_FLOAT, {ARG_SHAPE}), \ + InputTensorSpec(DT_FLOAT, {ARG_SHAPE}), \ + InputTensorSpec(DT_FLOAT, {ARG_SHAPE})})) \ + ->Arg(0); \ + BM_Eigen(TernaryConcat_##NAME, \ + (GetEigenTernaryConcatFn()), \ + llvm::ArrayRef({InputTensorSpec(DT_FLOAT, {ARG_SHAPE}), \ + InputTensorSpec(DT_FLOAT, {ARG_SHAPE}), \ + InputTensorSpec(DT_FLOAT, {ARG_SHAPE})})) \ + ->Arg(0) + +#define BM_OCTONARY_CONCAT(NAME, RANK, ARG_SHAPE, DYNAMIC_DIMS, CONCAT_DIM) \ + BM_Jitrt(OcternaryConcat_##NAME, \ + GetOctonaryConcatIR({ARG_SHAPE}, {DYNAMIC_DIMS}, CONCAT_DIM), \ + "main", \ + llvm::ArrayRef({InputTensorSpec(DT_FLOAT, {ARG_SHAPE}), \ + InputTensorSpec(DT_FLOAT, {ARG_SHAPE}), \ + InputTensorSpec(DT_FLOAT, {ARG_SHAPE}), \ + InputTensorSpec(DT_FLOAT, {ARG_SHAPE}), \ + InputTensorSpec(DT_FLOAT, {ARG_SHAPE}), \ + InputTensorSpec(DT_FLOAT, {ARG_SHAPE}), \ + InputTensorSpec(DT_FLOAT, {ARG_SHAPE}), \ + InputTensorSpec(DT_FLOAT, {ARG_SHAPE})})) \ + ->Arg(0); \ + BM_JitrtV(OcternaryConcat_##NAME, \ + GetOctonaryConcatIR({ARG_SHAPE}, {DYNAMIC_DIMS}, CONCAT_DIM), \ + "main", \ + llvm::ArrayRef({InputTensorSpec(DT_FLOAT, {ARG_SHAPE}), \ + InputTensorSpec(DT_FLOAT, {ARG_SHAPE}), \ + InputTensorSpec(DT_FLOAT, {ARG_SHAPE}), \ + InputTensorSpec(DT_FLOAT, {ARG_SHAPE}), \ + InputTensorSpec(DT_FLOAT, {ARG_SHAPE}), \ + InputTensorSpec(DT_FLOAT, {ARG_SHAPE}), \ + InputTensorSpec(DT_FLOAT, {ARG_SHAPE}), \ + InputTensorSpec(DT_FLOAT, {ARG_SHAPE})})) \ + ->Arg(0); \ + BM_Eigen(OcternaryConcat_##NAME, \ + (GetEigenOctonaryConcatFn()), \ + llvm::ArrayRef({InputTensorSpec(DT_FLOAT, {ARG_SHAPE}), \ + InputTensorSpec(DT_FLOAT, {ARG_SHAPE}), \ + InputTensorSpec(DT_FLOAT, {ARG_SHAPE}), \ + InputTensorSpec(DT_FLOAT, {ARG_SHAPE}), \ + InputTensorSpec(DT_FLOAT, {ARG_SHAPE}), \ + InputTensorSpec(DT_FLOAT, {ARG_SHAPE}), \ + InputTensorSpec(DT_FLOAT, {ARG_SHAPE}), \ + InputTensorSpec(DT_FLOAT, {ARG_SHAPE})})) \ + ->Arg(0) + +#define BM_NARY_CONCAT(NAME, RANK, ARG_SHAPE, DYNAMIC_DIMS, CONCAT_DIM) \ + BM_BINARY_CONCAT(NAME, RANK, WRAP(ARG_SHAPE), WRAP(DYNAMIC_DIMS), \ + CONCAT_DIM); \ + BM_TERNARY_CONCAT(NAME, RANK, WRAP(ARG_SHAPE), WRAP(DYNAMIC_DIMS), \ + CONCAT_DIM); \ + BM_OCTONARY_CONCAT(NAME, RANK, WRAP(ARG_SHAPE), WRAP(DYNAMIC_DIMS), \ + CONCAT_DIM) + +// Static Concat 1D +#define BM_NARY_CONCAT_STATIC_1D(N) \ + BM_NARY_CONCAT(Static1D_##N, 1, 
WRAP(N), WRAP(kStaticDim), 0) +BM_NARY_CONCAT_STATIC_1D(1); +BM_NARY_CONCAT_STATIC_1D(8); +BM_NARY_CONCAT_STATIC_1D(1024); +BM_NARY_CONCAT_STATIC_1D(1026); +BM_NARY_CONCAT_STATIC_1D(1048576); +BM_NARY_CONCAT_STATIC_1D(1048578); + +// Dynamic Concat 1D +#define BM_NARY_CONCAT_DYNAMIC_1D(N) \ + BM_NARY_CONCAT(Dynamic1D_##N, 1, WRAP(N), WRAP(kDynamicDim), 0) +BM_NARY_CONCAT_DYNAMIC_1D(1); +BM_NARY_CONCAT_DYNAMIC_1D(8); +BM_NARY_CONCAT_DYNAMIC_1D(1024); +BM_NARY_CONCAT_DYNAMIC_1D(1026); +BM_NARY_CONCAT_DYNAMIC_1D(1048576); +BM_NARY_CONCAT_DYNAMIC_1D(1048578); + +// Static Concat 2D +#define BM_NARY_CONCAT_STATIC_2D(M, N, CONCAT_DIM) \ + BM_NARY_CONCAT(Static2D_##M##x##N##_dim##CONCAT_DIM, 2, WRAP(M, N), \ + WRAP(kStaticDim, kStaticDim), CONCAT_DIM) +// Sqaure operands +BM_NARY_CONCAT_STATIC_2D(512, 512, 0); +BM_NARY_CONCAT_STATIC_2D(512, 512, 1); +BM_NARY_CONCAT_STATIC_2D(514, 514, 0); +BM_NARY_CONCAT_STATIC_2D(514, 514, 1); +BM_NARY_CONCAT_STATIC_2D(1024, 1024, 0); +BM_NARY_CONCAT_STATIC_2D(1024, 1024, 1); +BM_NARY_CONCAT_STATIC_2D(1026, 1026, 0); +BM_NARY_CONCAT_STATIC_2D(1026, 1026, 1); +// Slice operands +BM_NARY_CONCAT_STATIC_2D(1, 1024, 0); +BM_NARY_CONCAT_STATIC_2D(1024, 1, 1); +BM_NARY_CONCAT_STATIC_2D(1, 1026, 0); +BM_NARY_CONCAT_STATIC_2D(1026, 1, 1); +BM_NARY_CONCAT_STATIC_2D(1, 1048576, 0); +BM_NARY_CONCAT_STATIC_2D(1048576, 1, 1); +BM_NARY_CONCAT_STATIC_2D(1, 1048578, 0); +BM_NARY_CONCAT_STATIC_2D(1048578, 1, 1); + +// Concat 2D with static concatenation dimension +#define BM_NARY_CONCAT_W_STATIC_CONCAT_DIM0_2D(M, N) \ + BM_NARY_CONCAT(StaticConcatDim2D_##M##x##N##_dim0, 2, WRAP(M, N), \ + WRAP(kStaticDim, kDynamicDim), 0) +BM_NARY_CONCAT_W_STATIC_CONCAT_DIM0_2D(1, 1024); +BM_NARY_CONCAT_W_STATIC_CONCAT_DIM0_2D(1, 1026); +BM_NARY_CONCAT_W_STATIC_CONCAT_DIM0_2D(1, 1048576); +BM_NARY_CONCAT_W_STATIC_CONCAT_DIM0_2D(1, 1048578); +#define BM_NARY_CONCAT_W_STATIC_CONCAT_DIM1_2D(M, N) \ + BM_NARY_CONCAT(StaticConcatDim2D_##M##x##N##_dim1, 2, WRAP(M, N), \ + WRAP(kDynamicDim, kStaticDim), 1) +BM_NARY_CONCAT_W_STATIC_CONCAT_DIM1_2D(1024, 1); +BM_NARY_CONCAT_W_STATIC_CONCAT_DIM1_2D(1026, 1); +BM_NARY_CONCAT_W_STATIC_CONCAT_DIM1_2D(1048576, 1); +BM_NARY_CONCAT_W_STATIC_CONCAT_DIM1_2D(1048578, 1); + +// Dynamic Concat 2D +#define BM_NARY_CONCAT_DYNAMIC_2D(M, N, CONCAT_DIM) \ + BM_NARY_CONCAT(Dynamic2D_##M##x##N##_dim##CONCAT_DIM, 2, WRAP(M, N), \ + WRAP(kDynamicDim, kDynamicDim), CONCAT_DIM) +// Sqaure operands +BM_NARY_CONCAT_DYNAMIC_2D(512, 512, 0); +BM_NARY_CONCAT_DYNAMIC_2D(512, 512, 1); +BM_NARY_CONCAT_DYNAMIC_2D(514, 514, 0); +BM_NARY_CONCAT_DYNAMIC_2D(514, 514, 1); +BM_NARY_CONCAT_DYNAMIC_2D(1024, 1024, 0); +BM_NARY_CONCAT_DYNAMIC_2D(1024, 1024, 1); +BM_NARY_CONCAT_DYNAMIC_2D(1026, 1026, 0); +BM_NARY_CONCAT_DYNAMIC_2D(1026, 1026, 1); +// Slice operands +BM_NARY_CONCAT_DYNAMIC_2D(1, 1024, 0); +BM_NARY_CONCAT_DYNAMIC_2D(1024, 1, 1); +BM_NARY_CONCAT_DYNAMIC_2D(1, 1026, 0); +BM_NARY_CONCAT_DYNAMIC_2D(1026, 1, 1); +BM_NARY_CONCAT_DYNAMIC_2D(1, 1048576, 0); +BM_NARY_CONCAT_DYNAMIC_2D(1048576, 1, 1); +BM_NARY_CONCAT_DYNAMIC_2D(1, 1048578, 0); +BM_NARY_CONCAT_DYNAMIC_2D(1048578, 1, 1); + +// Static Concat 4D +#define BM_NARY_CONCAT_STATIC_4D(M, N, O, P, CONCAT_DIM) \ + BM_NARY_CONCAT( \ + Static4D_##M##x##N##x##O##x##P##_dim##CONCAT_DIM, 4, WRAP(M, N, O, P), \ + WRAP(kStaticDim, kStaticDim, kStaticDim, kStaticDim), CONCAT_DIM) +// Sqaure operands +BM_NARY_CONCAT_STATIC_4D(32, 32, 32, 32, 0); +BM_NARY_CONCAT_STATIC_4D(32, 32, 32, 32, 1); +BM_NARY_CONCAT_STATIC_4D(34, 34, 34, 34, 0); 
+BM_NARY_CONCAT_STATIC_4D(34, 34, 34, 34, 1); +BM_NARY_CONCAT_STATIC_4D(1024, 1024, 4, 4, 0); +BM_NARY_CONCAT_STATIC_4D(1024, 1024, 4, 4, 1); +BM_NARY_CONCAT_STATIC_4D(1026, 1026, 2, 6, 0); +BM_NARY_CONCAT_STATIC_4D(1026, 1026, 2, 6, 1); +// Slice operands +BM_NARY_CONCAT_STATIC_4D(32, 32, 1, 1024, 2); +BM_NARY_CONCAT_STATIC_4D(32, 32, 1024, 1, 3); +BM_NARY_CONCAT_STATIC_4D(34, 34, 1, 1026, 2); +BM_NARY_CONCAT_STATIC_4D(34, 34, 1026, 1, 3); +BM_NARY_CONCAT_STATIC_4D(4, 4, 1, 1048576, 2); +BM_NARY_CONCAT_STATIC_4D(4, 4, 1048576, 1, 3); +BM_NARY_CONCAT_STATIC_4D(2, 6, 1, 1048578, 2); +BM_NARY_CONCAT_STATIC_4D(2, 6, 1048578, 1, 3); + +// Concat 4D with static concatenation dimension +#define BM_NARY_CONCAT_W_STATIC_CONCAT_DIM2_4D(M, N, O, P) \ + BM_NARY_CONCAT(StaticConcatDim4D_##M##x##N##x##O##x##P##_dim0, 4, \ + WRAP(M, N, O, P), \ + WRAP(kDynamicDim, kDynamicDim, kStaticDim, kDynamicDim), 2) +BM_NARY_CONCAT_W_STATIC_CONCAT_DIM2_4D(32, 32, 1, 1024); +BM_NARY_CONCAT_W_STATIC_CONCAT_DIM2_4D(34, 34, 1, 1026); +BM_NARY_CONCAT_W_STATIC_CONCAT_DIM2_4D(4, 4, 1, 1048576); +BM_NARY_CONCAT_W_STATIC_CONCAT_DIM2_4D(2, 6, 1, 1048578); +#define BM_NARY_CONCAT_W_STATIC_CONCAT_DIM3_4D(M, N, O, P) \ + BM_NARY_CONCAT(StaticConcatDim4D_##M##x##N##x##O##x##P##_dim1, 4, \ + WRAP(M, N, O, P), \ + WRAP(kDynamicDim, kDynamicDim, kDynamicDim, kStaticDim), 3) +BM_NARY_CONCAT_W_STATIC_CONCAT_DIM3_4D(32, 32, 1024, 1); +BM_NARY_CONCAT_W_STATIC_CONCAT_DIM3_4D(34, 34, 1026, 1); +BM_NARY_CONCAT_W_STATIC_CONCAT_DIM3_4D(4, 4, 1048576, 1); +BM_NARY_CONCAT_W_STATIC_CONCAT_DIM3_4D(2, 6, 1048578, 1); + +// Dynamic Concat 4D +#define BM_NARY_CONCAT_DYNAMIC_4D(M, N, O, P, CONCAT_DIM) \ + BM_NARY_CONCAT( \ + Dynamic4D_##M##x##N##x##O##x##P##_dim##CONCAT_DIM, 4, WRAP(M, N, O, P), \ + WRAP(kDynamicDim, kDynamicDim, kDynamicDim, kDynamicDim), CONCAT_DIM) +// Sqaure operands +BM_NARY_CONCAT_DYNAMIC_4D(32, 32, 32, 32, 0); +BM_NARY_CONCAT_DYNAMIC_4D(32, 32, 32, 32, 1); +BM_NARY_CONCAT_DYNAMIC_4D(34, 34, 34, 34, 0); +BM_NARY_CONCAT_DYNAMIC_4D(34, 34, 34, 34, 1); +BM_NARY_CONCAT_DYNAMIC_4D(1024, 1024, 4, 4, 0); +BM_NARY_CONCAT_DYNAMIC_4D(1024, 1024, 4, 4, 1); +BM_NARY_CONCAT_DYNAMIC_4D(1026, 1026, 2, 6, 0); +BM_NARY_CONCAT_DYNAMIC_4D(1026, 1026, 2, 6, 1); +// Slice operands +BM_NARY_CONCAT_DYNAMIC_4D(32, 32, 1, 1024, 2); +BM_NARY_CONCAT_DYNAMIC_4D(32, 32, 1024, 1, 3); +BM_NARY_CONCAT_DYNAMIC_4D(34, 34, 1, 1026, 2); +BM_NARY_CONCAT_DYNAMIC_4D(34, 34, 1026, 1, 3); +BM_NARY_CONCAT_DYNAMIC_4D(4, 4, 1, 1048576, 2); +BM_NARY_CONCAT_DYNAMIC_4D(4, 4, 1048576, 1, 3); +BM_NARY_CONCAT_DYNAMIC_4D(2, 6, 1, 1048578, 2); +BM_NARY_CONCAT_DYNAMIC_4D(2, 6, 1048578, 1, 3); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/benchmarks/fused_map_bcast_benchmark.cc b/tensorflow/compiler/mlir/tfrt/benchmarks/fused_map_bcast_benchmark.cc new file mode 100644 index 00000000000..99f3a84d97d --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/benchmarks/fused_map_bcast_benchmark.cc @@ -0,0 +1,128 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.h" +#include "tensorflow/compiler/mlir/tfrt/benchmarks/benchmark_mlir_function.h" + +namespace tensorflow { +namespace { + +const char* kMapIR = R"( + func.func @main(%arg0: {0}, %arg1: {1}) -> {1} { + %abs = "tf.Abs"(%arg0) + {{device = "/job:localhost/replica:0/task:0/device:CPU:0"} + : ({0}) -> {0} + %exp = "tf.Exp"(%abs) + {{device = "/job:localhost/replica:0/task:0/device:CPU:0"} + : ({0}) -> {0} + %tanh = "tf.Tanh"(%exp) + {{device = "/job:localhost/replica:0/task:0/device:CPU:0"} + : ({0}) -> {0} + %add = "tf.AddV2"(%exp, %arg1) + {{device = "/job:localhost/replica:0/task:0/device:CPU:0"} + : ({0}, {1}) -> {1} + func.return %add : {1} + } +)"; + +std::string Map(llvm::ArrayRef dynamic_dims, + llvm::ArrayRef input_shape, + llvm::ArrayRef bcst_shape) { + llvm::SmallVector mlir_input_shape; + for (int i = 0; i < input_shape.size(); ++i) { + mlir_input_shape.push_back(dynamic_dims[i] ? kDynSize : input_shape[i]); + } + return llvm::formatv(kMapIR, PrintTensorType(mlir_input_shape, "f32"), + PrintTensorType(bcst_shape, "f32")); +} + +auto EigenMap() { + return [](llvm::ArrayRef inputs, + llvm::Optional) { + Tensor output(DT_FLOAT, {inputs[0].dim_size(0), inputs[0].dim_size(1)}); + + auto in0 = inputs[0].tensor(); + auto in1 = inputs[1].tensor(); + auto out = output.tensor(); + out.setZero(); + Eigen::DefaultDevice d; + out.device(d) = in0.abs().exp().tanh() + in1; + }; +} + +llvm::SmallVector Inputs(ssize_t rows, ssize_t cols, + ssize_t cols1) { + return {InputTensorSpec(DT_FLOAT, {rows, cols}), + InputTensorSpec(DT_FLOAT, {rows, cols1})}; +} + +#define BM(FN) BM_##FN->Arg(0); + +#define BM_SUITE(NAME, DYNAMIC_ROW, DYNAMIC_COL, ROWS, COLS, COLS1) \ + BM(JitrtV(NAME, \ + Map({DYNAMIC_ROW, DYNAMIC_COL}, {ROWS, COLS}, {ROWS, COLS1}), \ + "main", Inputs(ROWS, COLS, COLS1))); \ + BM(Eigen(NAME, EigenMap(), Inputs(ROWS, COLS, COLS1))); \ + BM(Tfrt(NAME, Map({DYNAMIC_ROW, DYNAMIC_COL}, {ROWS, COLS}, {ROWS, COLS1}), \ + "main", Inputs(ROWS, COLS, COLS1))) + +#define BM_DYNAMIC_ALL(ROWS, COLS, COLS1) \ + BM_SUITE(FusedMapBcastDynamicAll_##ROWS##_##COLS1, kDynamicDim, kDynamicDim, \ + ROWS, COLS, COLS1) +BM_DYNAMIC_ALL(2, 1, 80); +BM_DYNAMIC_ALL(8, 1, 6); +BM_DYNAMIC_ALL(80, 1, 1); +BM_DYNAMIC_ALL(80, 1, 60); +BM_DYNAMIC_ALL(81, 1, 61); +BM_DYNAMIC_ALL(800, 1, 600); +BM_DYNAMIC_ALL(802, 1, 602); + +#define BM_STATIC_ROW(ROWS, COLS, COLS1) \ + BM_SUITE(FusedMapBcastStaticRow_##ROWS##_##COLS1, kStaticDim, kDynamicDim, \ + ROWS, COLS, COLS1) +BM_STATIC_ROW(2, 1, 80); +BM_STATIC_ROW(8, 1, 6); +BM_STATIC_ROW(80, 1, 1); +BM_STATIC_ROW(80, 1, 60); +BM_STATIC_ROW(81, 1, 61); +BM_STATIC_ROW(800, 1, 600); +BM_STATIC_ROW(802, 1, 602); + +#define BM_STATIC_COL(ROWS, COLS, COLS1) \ + BM_SUITE(FusedMapBcastStaticCol_##ROWS##_##COLS1, kDynamicDim, kStaticDim, \ + ROWS, COLS, COLS1) +BM_STATIC_COL(2, 1, 80); +BM_STATIC_COL(8, 1, 6); +BM_STATIC_COL(80, 1, 1); +BM_STATIC_COL(80, 1, 60); +BM_STATIC_COL(81, 1, 61); +BM_STATIC_COL(800, 1, 600); +BM_STATIC_COL(802, 1, 602); + +#define BM_STATIC_ALL(ROWS, COLS, COLS1) \ + BM_SUITE(FusedMapBcastStaticAll_##ROWS##_##COLS1, kStaticDim, kStaticDim, \ + ROWS, COLS, COLS1) +BM_STATIC_ALL(2, 1, 80); +BM_STATIC_ALL(8, 1, 6); +BM_STATIC_ALL(80, 1, 1); +BM_STATIC_ALL(80, 1, 60); +BM_STATIC_ALL(81, 1, 61); +BM_STATIC_ALL(800, 
1, 600); +BM_STATIC_ALL(802, 1, 602); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/benchmarks/map_op_benchmark.cc b/tensorflow/compiler/mlir/tfrt/benchmarks/map_op_benchmark.cc new file mode 100644 index 00000000000..53453374836 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/benchmarks/map_op_benchmark.cc @@ -0,0 +1,115 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.h" +#include "tensorflow/compiler/mlir/tfrt/benchmarks/benchmark_mlir_function.h" + +namespace tensorflow { +namespace { + +const char* kMapIR = R"( + func.func @main(%arg0: {0}) -> {0} { + %abs = "tf.Abs"(%arg0) + {{device = "/job:localhost/replica:0/task:0/device:CPU:0"} + : ({0}) -> {0} + %exp = "tf.Exp"(%abs) + {{device = "/job:localhost/replica:0/task:0/device:CPU:0"} + : ({0}) -> {0} + %tanh = "tf.Tanh"(%exp) + {{device = "/job:localhost/replica:0/task:0/device:CPU:0"} + : ({0}) -> {0} + func.return %tanh : {0} + } +)"; + +std::string Map(llvm::ArrayRef dynamic_dims, + llvm::ArrayRef input_shape) { + llvm::SmallVector mlir_input_shape; + for (int i = 0; i < input_shape.size(); ++i) { + mlir_input_shape.push_back(dynamic_dims[i] ? 
kDynSize : input_shape[i]); + } + return llvm::formatv(kMapIR, PrintTensorType(mlir_input_shape, "f32")); +} + +auto EigenMap() { + return [](llvm::ArrayRef inputs, + llvm::Optional) { + Tensor output(DT_FLOAT, {inputs[0].dim_size(0), inputs[0].dim_size(1)}); + + auto in = inputs[0].tensor(); + auto out = output.tensor(); + out.setZero(); + Eigen::DefaultDevice d; + out.device(d) = in.abs().exp().tanh(); + }; +} + +llvm::SmallVector Inputs(ssize_t rows, ssize_t cols) { + return {InputTensorSpec(DT_FLOAT, {rows, cols})}; +} + +#define BM(FN) BM_##FN->Arg(0); + +#define BM_SUITE(NAME, DYNAMIC_ROW, DYNAMIC_COL, ROWS, COLS) \ + BM(JitrtV(NAME, Map({DYNAMIC_ROW, DYNAMIC_COL}, {ROWS, COLS}), "main", \ + Inputs(ROWS, COLS))); \ + BM(Eigen(NAME, EigenMap(), Inputs(ROWS, COLS))); \ + BM(Tfrt(NAME, Map({DYNAMIC_ROW, DYNAMIC_COL}, {ROWS, COLS}), "main", \ + Inputs(ROWS, COLS))) + +#define BM_DYNAMIC_ALL(ROWS, COLS) \ + BM_SUITE(MapDynamicAll_##ROWS##_##COLS, kDynamicDim, kDynamicDim, ROWS, COLS) +BM_DYNAMIC_ALL(2, 80); +BM_DYNAMIC_ALL(8, 6); +BM_DYNAMIC_ALL(80, 1); +BM_DYNAMIC_ALL(80, 60); +BM_DYNAMIC_ALL(81, 61); +BM_DYNAMIC_ALL(800, 600); +BM_DYNAMIC_ALL(802, 602); + +#define BM_STATIC_ROW(ROWS, COLS) \ + BM_SUITE(MapStaticRow_##ROWS##_##COLS, kStaticDim, kDynamicDim, ROWS, COLS) +BM_STATIC_ROW(2, 80); +BM_STATIC_ROW(8, 6); +BM_STATIC_ROW(80, 1); +BM_STATIC_ROW(80, 60); +BM_STATIC_ROW(81, 61); +BM_STATIC_ROW(800, 600); +BM_STATIC_ROW(802, 602); + +#define BM_STATIC_COL(ROWS, COLS) \ + BM_SUITE(MapStaticCol_##ROWS##_##COLS, kDynamicDim, kStaticDim, ROWS, COLS) +BM_STATIC_COL(2, 80); +BM_STATIC_COL(8, 6); +BM_STATIC_COL(80, 1); +BM_STATIC_COL(80, 60); +BM_STATIC_COL(81, 61); +BM_STATIC_COL(800, 600); +BM_STATIC_COL(802, 602); + +#define BM_STATIC_ALL(ROWS, COLS) \ + BM_SUITE(MapStaticAll_##ROWS##_##COLS, kStaticDim, kStaticDim, ROWS, COLS) +BM_STATIC_ALL(2, 80); +BM_STATIC_ALL(8, 6); +BM_STATIC_ALL(80, 1); +BM_STATIC_ALL(80, 60); +BM_STATIC_ALL(81, 61); +BM_STATIC_ALL(800, 600); +BM_STATIC_ALL(802, 602); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/benchmarks/matmul_op_benchmark.cc b/tensorflow/compiler/mlir/tfrt/benchmarks/matmul_op_benchmark.cc index e602deb07ce..77198f43cb2 100644 --- a/tensorflow/compiler/mlir/tfrt/benchmarks/matmul_op_benchmark.cc +++ b/tensorflow/compiler/mlir/tfrt/benchmarks/matmul_op_benchmark.cc @@ -15,36 +15,101 @@ limitations under the License. #include "tensorflow/compiler/mlir/tfrt/benchmarks/matmul_op_benchmark.h" +#include + namespace tensorflow { -static const char* mlir_input = R"( -func.func @matmul(%arg0: tensor, - %arg1: tensor) -> tensor { - %0 = "tf.MatMul"(%arg0, %arg1) { +// Use type aliases compatible with MLIR type names. +using f32 = float; + +static const char* matmul_ir_skeleton = R"( +func.func @matmul(%arg0: {0}, %arg1: {1}) -> {2} { + %0 = "tf.MatMul"(%arg0, %arg1) {{ transpose_a = false, transpose_b = false - } : (tensor, tensor) -> tensor - func.return %0 : tensor + } : ({0}, {1}) -> {2} + func.return %0 : {2} } )"; -// Use type aliases compatible with MLIR type names. -using f32 = float; +std::string GetMatmulIR(llvm::ArrayRef lhs_shape, + llvm::ArrayRef lhs_dyn_dims, + llvm::ArrayRef rhs_shape, + llvm::ArrayRef rhs_dyn_dims, + llvm::ArrayRef out_shape, + llvm::ArrayRef out_dyn_dims, + llvm::StringRef element_type) { + llvm::SmallVector mlir_lhs_shape, mlir_rhs_shape, mlir_out_shape; + for (int i = 0; i < lhs_shape.size(); ++i) { + mlir_lhs_shape.push_back(lhs_dyn_dims[i] ? 
kDynSize : lhs_shape[i]); + } + for (int i = 0; i < rhs_shape.size(); ++i) { + mlir_rhs_shape.push_back(rhs_dyn_dims[i] ? kDynSize : rhs_shape[i]); + } + for (int i = 0; i < out_shape.size(); ++i) { + mlir_out_shape.push_back(out_dyn_dims[i] ? kDynSize : out_shape[i]); + } + return llvm::formatv( + matmul_ir_skeleton, + PrintTensorType(mlir_lhs_shape, element_type), // LHS type {0} + PrintTensorType(mlir_rhs_shape, element_type), // RHS type {1} + PrintTensorType(mlir_out_shape, element_type) // Out type {2} + ); +} + +BM_TFMlir_DYNAMIC_ALL(16, 16, 16, 8, 8, 8, "matmul", f32); +BM_TFMlir_STATIC_ALL(16, 16, 16, 8, 8, 8, "matmul", f32); +BM_Eigen_WRAPPER(16, 16, 16, f32); + +BM_TFMlir_DYNAMIC_ALL(64, 64, 64, 8, 8, 8, "matmul", f32); +BM_TFMlir_STATIC_ALL(64, 64, 64, 8, 8, 8, "matmul", f32); +BM_Eigen_WRAPPER(64, 64, 64, f32); + +BM_TFMlir_DYNAMIC_ALL(128, 128, 128, 8, 8, 8, "matmul", f32); +BM_TFMlir_STATIC_ALL(128, 128, 128, 8, 8, 8, "matmul", f32); +BM_Eigen_WRAPPER(128, 128, 128, f32); + +BM_TFMlir_DYNAMIC_ALL(256, 256, 256, 8, 8, 8, "matmul", f32); +BM_TFMlir_STATIC_ALL(256, 256, 256, 8, 8, 8, "matmul", f32); +BM_Eigen_WRAPPER(256, 256, 256, f32); + +BM_TFMlir_DYNAMIC_ALL(512, 512, 512, 8, 8, 8, "matmul", f32); +BM_TFMlir_STATIC_ALL(512, 512, 512, 8, 8, 8, "matmul", f32); +BM_Eigen_WRAPPER(512, 512, 512, f32); + +BM_TFMlir_DYNAMIC_ALL(1024, 1024, 1024, 8, 8, 8, "matmul", f32); +BM_TFMlir_STATIC_ALL(1024, 1024, 1024, 8, 8, 8, "matmul", f32); +BM_Eigen_WRAPPER(1024, 1024, 1024, f32); + +BM_TFMlir_DYNAMIC_ALL(2048, 2048, 2048, 8, 8, 8, "matmul", f32); +BM_TFMlir_STATIC_ALL(2048, 2048, 2048, 8, 8, 8, "matmul", f32); +BM_Eigen_WRAPPER(2048, 2048, 2048, f32); + +BM_TFMlir_DYNAMIC_ALL(100, 100, 100, 8, 8, 8, "matmul", f32); +BM_TFMlir_STATIC_ALL(100, 100, 100, 8, 8, 8, "matmul", f32); +BM_Eigen_WRAPPER(100, 100, 100, f32); + +BM_TFMlir_DYNAMIC_ALL(1, 18, 300, 8, 8, 8, "matmul", f32); +BM_TFMlir_STATIC_ALL(1, 18, 300, 8, 8, 8, "matmul", f32); +BM_Eigen_WRAPPER(1, 18, 300, f32); + +BM_TFMlir_DYNAMIC_ALL(1, 300, 300, 8, 8, 8, "matmul", f32); +BM_TFMlir_STATIC_ALL(1, 300, 300, 8, 8, 8, "matmul", f32); +BM_Eigen_WRAPPER(1, 300, 300, f32); + +BM_TFMlir_DYNAMIC_ALL(1, 300, 1, 8, 8, 8, "matmul", f32); +BM_TFMlir_STATIC_ALL(1, 300, 1, 8, 8, 8, "matmul", f32); +BM_Eigen_WRAPPER(1, 300, 1, f32); -BM_TFMlir(MatMul, mlir_input, "matmul", f32) - ->Args({10, 10, 10}) - ->Args({128, 128, 128}) - ->Args({256, 256, 256}) - ->Args({1, 18, 300}) - ->Args({1, 300, 300}) - ->Args({1, 300, 1}); - -BM_Eigen(MatMul, f32) - ->Args({10, 10, 10}) - ->Args({128, 128, 128}) - ->Args({256, 256, 256}) - ->Args({1, 18, 300}) - ->Args({1, 300, 300}) - ->Args({1, 300, 1}); +BM_TFMlir_DYNAMIC_ALL(10, 10, 10, 8, 8, 8, "matmul", f32); +BM_TFMlir_STATIC_ALL(10, 10, 10, 8, 8, 8, "matmul", f32); +BM_TFMlir_DYNAMIC_ALL(10, 10, 10, 4, 4, 4, "matmul", f32); +BM_TFMlir_STATIC_ALL(10, 10, 10, 4, 4, 4, "matmul", f32); +BM_TFMlir_DYNAMIC_ALL(10, 10, 10, 2, 2, 2, "matmul", f32); +BM_TFMlir_STATIC_ALL(10, 10, 10, 2, 2, 2, "matmul", f32); +BM_TFMlir_STATIC_ALL(10, 10, 10, 2, 2, 8, "matmul", f32); +BM_TFMlir_STATIC_ALL(10, 10, 10, 2, 8, 2, "matmul", f32); +BM_TFMlir_STATIC_ALL(10, 10, 10, 8, 2, 2, "matmul", f32); +BM_Eigen_WRAPPER(10, 10, 10, f32); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/benchmarks/matmul_op_benchmark.h b/tensorflow/compiler/mlir/tfrt/benchmarks/matmul_op_benchmark.h index 3741f894cd0..07d8952a7e4 100644 --- a/tensorflow/compiler/mlir/tfrt/benchmarks/matmul_op_benchmark.h +++ 
b/tensorflow/compiler/mlir/tfrt/benchmarks/matmul_op_benchmark.h @@ -16,8 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TFRT_BENCHMARKS_MATMUL_OP_BENCHMARK_H_ #define TENSORFLOW_COMPILER_MLIR_TFRT_BENCHMARKS_MATMUL_OP_BENCHMARK_H_ +#include #include +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.h" #include "tensorflow/compiler/mlir/tfrt/utils/host_context.h" @@ -27,6 +29,13 @@ namespace tensorflow { // used only to build benchmarks for different functions in this folder, so // it is ok to put convenience using-declarations here. +std::string GetMatmulIR(llvm::ArrayRef lhs_shape, + llvm::ArrayRef lhs_dynamic_dims, + llvm::ArrayRef rhs_shape, + llvm::ArrayRef rhs_dynamic_dims, + llvm::ArrayRef output_shape, + llvm::ArrayRef output_dynamic_dims); + using ::tfrt::AsyncValue; using ::tfrt::AsyncValuePtr; using ::tfrt::HostContext; @@ -46,7 +55,9 @@ using ::xla::runtime::MemrefDesc; template void RunMatMulMlirBenchmark(::testing::benchmark::State& state, - llvm::StringRef mlir_input, + // output_name is actually used on debug mode. + // NOLINTNEXTLINE + std::string output_name, llvm::StringRef mlir_input, llvm::StringRef function_name) { // MatMul: [m, k] x [k, n] ssize_t m = state.range(0); @@ -56,6 +67,13 @@ void RunMatMulMlirBenchmark(::testing::benchmark::State& state, std::unique_ptr host = CreateSingleThreadedHostContext(); TfJitRtPipelineOptions tf_jitrt_opts; + tf_jitrt_opts.vectorize = tensorflow::GetJitRtFlags().vectorize; + tf_jitrt_opts.lower_to_mmt4d = tensorflow::GetJitRtFlags().pack_matmul; + tf_jitrt_opts.enable_xla_cpu_transformations = + tensorflow::GetJitRtFlags().enable_xla_cpu_transformations; + tf_jitrt_opts.matmul_tile_sizes = {state.range(3), state.range(4), + state.range(5)}; + JitExecutable& jit_executable = CreateJitExecutable(*host, mlir_input, function_name, /*lower_from_tensorflow=*/true, tf_jitrt_opts); @@ -99,6 +117,19 @@ void RunMatMulMlirBenchmark(::testing::benchmark::State& state, jit_executable.GetExecutable(operands); if (!executable.ok()) LOG(FATAL) << "Failed to specialize executable"; +#if defined(DEBUG_XLA_RUNTIME_COMPILER) + std::string dump_path = "/tmp/"; + std::unique_ptr obj = (*executable)->obj_file(); + CHECK(obj) << "Failed to get executable obj file"; + std::string object_filename = output_name; + if (tf_jitrt_opts.lower_to_mmt4d) object_filename += "_packed"; + object_filename += ".o"; + std::error_code ec; + llvm::raw_fd_ostream dump_stream(dump_path + object_filename, ec); + CHECK(!ec) << "Failed to dump object file: " << ec.message(); + dump_stream.write(obj->getBufferStart(), obj->getBufferSize()); +#endif + // Wait for the compilation completion. host->Await({executable->CopyRef()}); @@ -143,6 +174,8 @@ void RunMatMulEigenBenchmark(::testing::benchmark::State& state) { using Device = Eigen::DefaultDevice; Device d; + CHECK(d.numThreads() == 1) << "Executing Eigen in multi-threaded"; + Eigen::Tensor dst(m, n); dst.setZero(); @@ -166,16 +199,41 @@ void RunMatMulEigenBenchmark(::testing::benchmark::State& state) { // Macros to dispatch to different MatMul shapes. // -------------------------------------------------------------------------- // -#define BM_TFMlir(NAME, MLIR_INPUT, FN, TYPE) \ - static void BM_mlir_##NAME##_##TYPE(::testing::benchmark::State& state) { \ - RunMatMulMlirBenchmark(state, MLIR_INPUT, FN); \ - } \ +#define INTS(...) __VA_ARGS__ +#define BOOLS(...) 
__VA_ARGS__ + +#define BM_TFMlir(NAME, LHS_SHAPE, LHS_DYN_DIMS, RHS_SHAPE, RHS_DYN_DIMS, \ + OUT_SHAPE, OUT_DYN_DIMS, FN, TYPE) \ + static void BM_mlir_##NAME##_##TYPE(::testing::benchmark::State& state) { \ + RunMatMulMlirBenchmark( \ + state, #NAME, \ + GetMatmulIR({LHS_SHAPE}, {LHS_DYN_DIMS}, {RHS_SHAPE}, {RHS_DYN_DIMS}, \ + {OUT_SHAPE}, {OUT_DYN_DIMS}, #TYPE), \ + FN); \ + } \ BENCHMARK(BM_mlir_##NAME##_##TYPE) +#define BM_TFMlir_DYNAMIC_ALL(M, N, K, T_M, T_N, T_K, FN, TYPE) \ + BM_TFMlir(MatmulDynamicAll_##M##_##K##_##N##_##T_M##_##T_N##_##T_K, \ + INTS(M, K), BOOLS(kDynamicDim, kDynamicDim), INTS(K, N), \ + BOOLS(kDynamicDim, kDynamicDim), INTS(M, N), \ + BOOLS(kDynamicDim, kDynamicDim), FN, TYPE) \ + ->Args({M, K, N, T_M, T_N, T_K}) + +#define BM_TFMlir_STATIC_ALL(M, N, K, T_M, T_N, T_K, FN, TYPE) \ + BM_TFMlir(MatmulStaticAll_##M##_##K##_##N##_##T_M##_##T_N##_##T_K, \ + INTS(M, K), BOOLS(kStaticDim, kStaticDim), INTS(K, N), \ + BOOLS(kStaticDim, kStaticDim), INTS(M, N), \ + BOOLS(kStaticDim, kStaticDim), FN, TYPE) \ + ->Args({M, K, N, T_M, T_N, T_K}) + #define BM_Eigen(NAME, TYPE) \ static void BM_eigen_##NAME##_##TYPE(::testing::benchmark::State& state) { \ RunMatMulEigenBenchmark(state); \ } \ BENCHMARK(BM_eigen_##NAME##_##TYPE) +#define BM_Eigen_WRAPPER(M, N, K, TYPE) \ + BM_Eigen(Matmul_##M##_##K##_##N, TYPE)->Args({M, K, N}) + #endif // TENSORFLOW_COMPILER_MLIR_TFRT_BENCHMARKS_MATMUL_OP_BENCHMARK_H_ diff --git a/tensorflow/compiler/mlir/tfrt/benchmarks/reverse_op_benchmark.cc b/tensorflow/compiler/mlir/tfrt/benchmarks/reverse_op_benchmark.cc new file mode 100644 index 00000000000..895075c468d --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/benchmarks/reverse_op_benchmark.cc @@ -0,0 +1,207 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "llvm/Support/FormatVariadic.h" +#include "tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.h" +#include "tensorflow/compiler/mlir/tfrt/benchmarks/benchmark_mlir_function.h" + +namespace tensorflow { +namespace { + +const char* kReverseIR = R"( + func.func @main(%input: {0}) -> {0} { + %reverse_dims = "tf.Const"() {{ + value = {1} : {2}, + device = "/job:localhost/replica:0/task:0/device:CPU:0" + } : () -> {2} + %result = "tf.ReverseV2"(%input, %reverse_dims) + {{device = "/job:localhost/replica:0/task:0/device:CPU:0"} + : ({0}, {2}) -> {0} + func.return %result : {0} + } +)"; + +std::string Reverse(llvm::ArrayRef input_shape, + llvm::ArrayRef dynamic_dims, + llvm::ArrayRef reverse_dims, + llvm::StringRef element_type) { + llvm::SmallVector mlir_input_shape = + GetTensorTypeShape(input_shape, dynamic_dims); + return llvm::formatv( + kReverseIR, + PrintTensorType(mlir_input_shape, element_type), // Input type {0} + PrintDenseArray(reverse_dims), // Dims to reverse attr {1} + PrintTensorType(static_cast(reverse_dims.size()), + "i32") // Dims to reverse type {2} + ); +} + +template +auto EigenReverse(std::array reverse_dims) { + return [reverse_dims](llvm::ArrayRef inputs, + llvm::Optional device) { + std::array bool_reverse_dims; + bool_reverse_dims.fill(false); + for (auto i : reverse_dims) { + bool_reverse_dims[i] = true; + } + Tensor output(DT_FLOAT, inputs[0].shape()); + auto in = inputs[0].tensor(); + auto out = output.tensor(); + if (device.has_value()) { + out.device(*device) = in.reverse(bool_reverse_dims); + } else { + out = in.reverse(bool_reverse_dims); + } + }; +} + +llvm::SmallVector GetInputSpec( + llvm::ArrayRef input_shape) { + return {InputTensorSpec(DT_FLOAT, input_shape)}; +} + +#define INTS(...) __VA_ARGS__ +#define BOOLS(...) __VA_ARGS__ + +#define BM(KIND, ...) BM_##KIND(__VA_ARGS__)->Arg(0); + +#define BM_SUITE(NAME, INPUT_RANK, INPUT_SHAPE, DYNAMIC_DIMS, N_REVERSE_DIMS, \ + REVERSE_DIMS) \ + BM(JitrtV, NAME, \ + Reverse({INPUT_SHAPE}, {DYNAMIC_DIMS}, {REVERSE_DIMS}, "f32"), "main", \ + GetInputSpec({INPUT_SHAPE})); \ + BM(Eigen, NAME, \ + (EigenReverse( \ + std::array{REVERSE_DIMS})), \ + GetInputSpec({INPUT_SHAPE})); \ + BM(Tfrt, NAME, \ + Reverse({INPUT_SHAPE}, {DYNAMIC_DIMS}, {REVERSE_DIMS}, "f32"), "main", \ + GetInputSpec({INPUT_SHAPE})) + +//////////////////////////////////////////////////////////////////////////////// +// Reverse 1D tensors. +//////////////////////////////////////////////////////////////////////////////// + +#define BM_STATIC_1D(SIZE) \ + BM_SUITE(ReverseStatic_1D_##SIZE, 1, INTS(SIZE), BOOLS(kStaticDim), 1, \ + INTS(0)) +BM_STATIC_1D(3); +BM_STATIC_1D(8); +BM_STATIC_1D(80); +BM_STATIC_1D(800); +BM_STATIC_1D(8000); +BM_STATIC_1D(8131); +BM_STATIC_1D(1000000); +BM_STATIC_1D(1010131); + +#define BM_DYNAMIC_1D(SIZE) \ + BM_SUITE(ReverseDynamic_1D_##SIZE, 1, INTS(SIZE), BOOLS(kDynamicDim), 1, \ + INTS(0)) +BM_DYNAMIC_1D(3); +BM_DYNAMIC_1D(8); +BM_DYNAMIC_1D(80); +BM_DYNAMIC_1D(800); +BM_DYNAMIC_1D(8000); +BM_DYNAMIC_1D(8131); +BM_DYNAMIC_1D(1000000); +BM_DYNAMIC_1D(1010131); + +//////////////////////////////////////////////////////////////////////////////// +// Reverse 2D tensors. 
+//////////////////////////////////////////////////////////////////////////////// + +#define BM_STATIC_2D_ROW(ROWS, COLS) \ + BM_SUITE(ReverseStatic_2D_ROW_##ROWS##_##COLS, 2, INTS(ROWS, COLS), \ + BOOLS(kStaticDim, kStaticDim), 1, INTS(0)) +BM_STATIC_2D_ROW(2, 80); +BM_STATIC_2D_ROW(8, 6); +BM_STATIC_2D_ROW(80, 1); +BM_STATIC_2D_ROW(80, 3); +BM_STATIC_2D_ROW(80, 7); +BM_STATIC_2D_ROW(80, 60); +BM_STATIC_2D_ROW(81, 61); +BM_STATIC_2D_ROW(800, 600); +BM_STATIC_2D_ROW(802, 602); + +#define BM_STATIC_2D_COL(ROWS, COLS) \ + BM_SUITE(ReverseStatic_2D_COL_##ROWS##_##COLS, 2, INTS(ROWS, COLS), \ + BOOLS(kStaticDim, kStaticDim), 1, INTS(1)) +BM_STATIC_2D_COL(2, 80); +BM_STATIC_2D_COL(8, 6); +BM_STATIC_2D_COL(80, 1); +BM_STATIC_2D_COL(80, 3); +BM_STATIC_2D_COL(80, 7); +BM_STATIC_2D_COL(80, 60); +BM_STATIC_2D_COL(81, 61); +BM_STATIC_2D_COL(800, 600); +BM_STATIC_2D_COL(802, 602); + +#define BM_STATIC_2D_ALL(ROWS, COLS) \ + BM_SUITE(ReverseStatic_2D_ALL_##ROWS##_##COLS, 2, INTS(ROWS, COLS), \ + BOOLS(kStaticDim, kStaticDim), 2, INTS(0, 1)) +BM_STATIC_2D_ALL(2, 80); +BM_STATIC_2D_ALL(8, 6); +BM_STATIC_2D_ALL(80, 1); +BM_STATIC_2D_ALL(80, 3); +BM_STATIC_2D_ALL(80, 7); +BM_STATIC_2D_ALL(80, 60); +BM_STATIC_2D_ALL(81, 61); +BM_STATIC_2D_ALL(800, 600); +BM_STATIC_2D_ALL(802, 602); + +#define BM_DYNAMIC_2D_ROW(ROWS, COLS) \ + BM_SUITE(ReverseDynamic_2D_ROW_##ROWS##_##COLS, 2, INTS(ROWS, COLS), \ + BOOLS(kDynamicDim, kStaticDim), 1, INTS(0)) +BM_DYNAMIC_2D_ROW(2, 80); +BM_DYNAMIC_2D_ROW(8, 6); +BM_DYNAMIC_2D_ROW(80, 1); +BM_DYNAMIC_2D_ROW(80, 3); +BM_DYNAMIC_2D_ROW(80, 7); +BM_DYNAMIC_2D_ROW(80, 60); +BM_DYNAMIC_2D_ROW(81, 61); +BM_DYNAMIC_2D_ROW(800, 600); +BM_DYNAMIC_2D_ROW(802, 602); + +#define BM_DYNAMIC_2D_COL(ROWS, COLS) \ + BM_SUITE(ReverseDynamic_2D_COL_##ROWS##_##COLS, 2, INTS(ROWS, COLS), \ + BOOLS(kStaticDim, kDynamicDim), 1, INTS(1)) +BM_DYNAMIC_2D_COL(2, 80); +BM_DYNAMIC_2D_COL(8, 6); +BM_DYNAMIC_2D_COL(80, 1); +BM_DYNAMIC_2D_COL(80, 3); +BM_DYNAMIC_2D_COL(80, 7); +BM_DYNAMIC_2D_COL(80, 60); +BM_DYNAMIC_2D_COL(81, 61); +BM_DYNAMIC_2D_COL(800, 600); +BM_DYNAMIC_2D_COL(802, 602); + +#define BM_DYNAMIC_2D_ALL(ROWS, COLS) \ + BM_SUITE(ReverseDynamic_2D_ALL_##ROWS##_##COLS, 2, INTS(ROWS, COLS), \ + BOOLS(kDynamicDim, kDynamicDim), 2, INTS(0, 1)) +BM_DYNAMIC_2D_ALL(2, 80); +BM_DYNAMIC_2D_ALL(8, 6); +BM_DYNAMIC_2D_ALL(80, 1); +BM_DYNAMIC_2D_ALL(80, 3); +BM_DYNAMIC_2D_ALL(80, 7); +BM_DYNAMIC_2D_ALL(80, 60); +BM_DYNAMIC_2D_ALL(81, 61); +BM_DYNAMIC_2D_ALL(800, 600); +BM_DYNAMIC_2D_ALL(802, 602); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/benchmarks/scatter_op_benchmark.cc b/tensorflow/compiler/mlir/tfrt/benchmarks/scatter_op_benchmark.cc new file mode 100644 index 00000000000..fef65faf7d0 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/benchmarks/scatter_op_benchmark.cc @@ -0,0 +1,130 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.h" +#include "tensorflow/compiler/mlir/tfrt/benchmarks/benchmark_mlir_function.h" + +namespace tensorflow { +namespace { + +// {0} -- updates_shape +// {1} -- output_shape +// {2} -- indices_value +// {3} -- indices_value +const char* kMapScatterIR = R"( + func.func @main(%updates: {0}, %out: {1}) -> {1} { + %indices = "tf.Const"() + {{value = {2} : {3}, + device = "/job:localhost/replica:0/task:0/device:CPU:0"} + : () -> {3} + %updates_exp = "tf.Exp"(%updates) + {{device = "/job:localhost/replica:0/task:0/device:CPU:0"} + : ({0}) -> {0} + %scattered = "tf.TensorScatterAdd"(%out, %indices, %updates) + {{device = "/job:localhost/replica:0/task:0/device:CPU:0"} + : ({1}, {3}, {0}) -> {1} + func.return %scattered : {1} + } +)"; + +std::string GetScatterIndices(llvm::ArrayRef updates_shape, + ssize_t rows) { + std::string result{"dense<["}; + llvm::raw_string_ostream ss(result); + for (size_t i = 0; i < updates_shape[0]; ++i) { + if (i > 0) ss << ','; + ss << "[0, " << (i * 5) % rows << "]"; + } + ss << "]>"; + return result; +} + +std::string MapScatter(llvm::ArrayRef dynamic_dims, + llvm::ArrayRef updates_shape, + llvm::ArrayRef output_shape) { + llvm::SmallVector mlir_output_shape; + for (int i = 0; i < output_shape.size(); ++i) { + mlir_output_shape.push_back(dynamic_dims[i] ? kDynSize : output_shape[i]); + } + llvm::SmallVector indeces_shape = {updates_shape[0], 2}; + return llvm::formatv(kMapScatterIR, PrintTensorType(updates_shape, "f32"), + PrintTensorType(mlir_output_shape, "f32"), + GetScatterIndices(updates_shape, output_shape[1]), + PrintTensorType(indeces_shape, "i32")); +} + +llvm::SmallVector Inputs(ssize_t rows, ssize_t cols, + ssize_t rows_upd) { + return {InputTensorSpec(DT_FLOAT, {rows_upd, cols}), // updates_shape + InputTensorSpec(DT_FLOAT, {1, rows, cols})}; // output_shapes +} + +// This benchmark checks the insertion of full rows (Tfrt requirement). 
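Note on the generated indices: GetScatterIndices emits a dense MLIR attribute whose i-th entry is [0, (i * 5) % rows], so every scatter update targets a whole row of the {1, ROWS, COLS} output, which is the full-row insertion requirement mentioned in the comment above. Below is a minimal standalone sketch of the same string construction; it swaps llvm::raw_string_ostream for std::ostringstream and ssize_t for std::int64_t so it builds without LLVM headers (illustrative only, not part of the benchmark):

#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>

// Mirrors GetScatterIndices above: one "[0, row]" entry per update row,
// with the target row picked as (i * 5) % rows.
std::string ScatterIndicesSketch(std::int64_t num_updates, std::int64_t rows) {
  std::ostringstream ss;
  ss << "dense<[";
  for (std::int64_t i = 0; i < num_updates; ++i) {
    if (i > 0) ss << ',';
    ss << "[0, " << (i * 5) % rows << "]";
  }
  ss << "]>";
  return ss.str();
}

int main() {
  // For ROWS_UPD = 5 and ROWS = 11 this prints:
  //   dense<[[0, 0],[0, 5],[0, 10],[0, 4],[0, 9]]>
  std::cout << ScatterIndicesSketch(/*num_updates=*/5, /*rows=*/11) << "\n";
  return 0;
}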
+#define BM(FN) BM_##FN->Arg(0); + +#define BM_SUITE(NAME, DYNAMIC_ROW, DYNAMIC_COL, ROWS, COLS, ROWS_UPD) \ + BM(JitrtV(NAME, \ + MapScatter({DYNAMIC_ROW, DYNAMIC_COL}, {ROWS_UPD, COLS}, \ + {1, ROWS, COLS}), \ + "main", Inputs(ROWS, COLS, ROWS_UPD))); \ + BM(Tfrt(NAME, \ + MapScatter({DYNAMIC_ROW, DYNAMIC_COL}, {ROWS_UPD, COLS}, \ + {1, ROWS, COLS}), \ + "main", Inputs(ROWS, COLS, ROWS_UPD))) + +#define BM_STATIC_ALL(ROWS, COLS, ROWS_UPD) \ + BM_SUITE(MapScatterStaticAll_##ROWS##_##COLS##_##ROWS_UPD, kStaticDim, \ + kStaticDim, ROWS, COLS, ROWS_UPD) +BM_STATIC_ALL(11, 1, 5); +BM_STATIC_ALL(20, 11, 5); +BM_STATIC_ALL(1, 80, 100); +BM_STATIC_ALL(80, 1, 5); +BM_STATIC_ALL(800, 600, 10); +BM_STATIC_ALL(802, 602, 100); + +#define BM_DYNAMIC_ALL(ROWS, COLS, ROWS_UPD) \ + BM_SUITE(MapScatterDynamicAll_##ROWS##_##COLS##_##ROWS_UPD, kDynamicDim, \ + kDynamicDim, ROWS, COLS, ROWS_UPD) +BM_DYNAMIC_ALL(11, 1, 5); +BM_DYNAMIC_ALL(20, 11, 5); +BM_DYNAMIC_ALL(1, 80, 100); +BM_DYNAMIC_ALL(80, 1, 5); +BM_DYNAMIC_ALL(800, 600, 10); +BM_DYNAMIC_ALL(802, 602, 100); + +#define BM_STATIC_ROW(ROWS, COLS, ROWS_UPD) \ + BM_SUITE(MapScatterStaticRow_##ROWS##_##COLS##_##ROWS_UPD, kStaticDim, \ + kDynamicDim, ROWS, COLS, ROWS_UPD) +BM_STATIC_ROW(11, 1, 5); +BM_STATIC_ROW(20, 11, 5); +BM_STATIC_ROW(1, 80, 100); +BM_STATIC_ROW(80, 1, 5); +BM_STATIC_ROW(800, 600, 10); +BM_STATIC_ROW(802, 602, 100); + +#define BM_STATIC_COL(ROWS, COLS, ROWS_UPD) \ + BM_SUITE(MapScatterStaticCol_##ROWS##_##COLS##_##ROWS_UPD, kDynamicDim, \ + kStaticDim, ROWS, COLS, ROWS_UPD) +BM_STATIC_COL(11, 1, 5); +BM_STATIC_COL(20, 11, 5); +BM_STATIC_COL(1, 80, 100); +BM_STATIC_COL(80, 1, 5); +BM_STATIC_COL(800, 600, 10); +BM_STATIC_COL(802, 602, 100); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/benchmarks/softmax_op_benchmark.cc b/tensorflow/compiler/mlir/tfrt/benchmarks/softmax_op_benchmark.cc index 07f53407e9c..ea68306d8ec 100644 --- a/tensorflow/compiler/mlir/tfrt/benchmarks/softmax_op_benchmark.cc +++ b/tensorflow/compiler/mlir/tfrt/benchmarks/softmax_op_benchmark.cc @@ -51,8 +51,8 @@ static void ComputeSoftmax(const Eigen::DefaultDevice& d, InT logits, const int batch_size = logits.dimension(kBatchDim); const int num_classes = logits.dimension(kClassDim); -// These arrays are used to reduce along the class dimension, and broadcast -// the resulting value to all classes. + // These arrays are used to reduce along the class dimension, and broadcast + // the resulting value to all classes. 
Eigen::IndexList > along_class; Eigen::IndexList > batch_by_one; batch_by_one.set(0, batch_size); @@ -109,8 +109,9 @@ BM_DYNAMIC_ALL(81, 61); BM_DYNAMIC_ALL(800, 600); BM_DYNAMIC_ALL(802, 602); -#define BM_STATIC_ROW(ROWS, COLS) \ - BM_SUITE(SoftmaxStaticRow##ROWS##_##COLS, kStaticDim, kDynamicDim, ROWS, COLS) +#define BM_STATIC_ROW(ROWS, COLS) \ + BM_SUITE(SoftmaxStaticRow_##ROWS##_##COLS, kStaticDim, kDynamicDim, ROWS, \ + COLS) BM_STATIC_ROW(2, 80); BM_STATIC_ROW(8, 6); BM_STATIC_ROW(80, 1); diff --git a/tensorflow/compiler/mlir/tfrt/function/function.cc b/tensorflow/compiler/mlir/tfrt/function/function.cc index 306ec82e374..892fa355bcc 100644 --- a/tensorflow/compiler/mlir/tfrt/function/function.cc +++ b/tensorflow/compiler/mlir/tfrt/function/function.cc @@ -74,8 +74,10 @@ Status CompileTFMLIRToBEF(const TfrtFunctionCompileOptions& options, pass_options.tpu_fuse_ops = options.tpu_fuse_ops; pass_options.tpu_transfer_result_to_host = options.tpu_transfer_result_to_host; - pass_options.enable_native_ops = options.enable_native_ops; - tensorflow::CreateTfExecutorToTfrtPipeline(pm, pass_options); + Status status = tensorflow::CreateTfExecutorToTfrtPipeline(pm, pass_options); + if (!status.ok()) { + return diag_handler.Combine(status); + } if (mlir::failed(pm.run(module))) return diag_handler.Combine(tensorflow::errors::Internal( diff --git a/tensorflow/compiler/mlir/tfrt/function/function.h b/tensorflow/compiler/mlir/tfrt/function/function.h index 936119e5570..1a7d8bd0592 100644 --- a/tensorflow/compiler/mlir/tfrt/function/function.h +++ b/tensorflow/compiler/mlir/tfrt/function/function.h @@ -42,9 +42,6 @@ struct TfrtFunctionCompileOptions : public TfrtCompileOptions { // Currently only SavedModel API inference uses the tpu_fuse_ops option TfrtFunctionCompileOptions() { tpu_fuse_ops = false; - // TF function in eager execution uses CoreRT native ops as fallback states - // are not initialized in that code path. - enable_native_ops = true; // Currently grappler is not correctly applied in the eager execution of TF // functions, as it may sometimes remove arguments and results. 
enable_grappler = false; diff --git a/tensorflow/compiler/mlir/tfrt/ir/BUILD b/tensorflow/compiler/mlir/tfrt/ir/BUILD index 99ec360b617..f63f0c7ff07 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/BUILD +++ b/tensorflow/compiler/mlir/tfrt/ir/BUILD @@ -3,6 +3,7 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) @@ -181,3 +182,63 @@ gentbl_cc_library( ], deps = [":tfrt_fallback_td_files"], ) + +td_library( + name = "gpu_ops_td_file", + srcs = [ + "gpu_ops.td", + ], + includes = ["."], + visibility = [ + "//tensorflow/compiler/mlir/tfrt:__subpackages__", + ], + deps = [ + ":tfrt_fallback_td_files", + "@tf_runtime//:OpBaseTdFiles", + ], +) + +gentbl_cc_library( + name = "tfrt_gpu_opdefs_inc_gen", + tbl_outs = [ + ( + ["-gen-op-decls"], + "gpu_ops.h.inc", + ), + ( + ["-gen-op-defs"], + "gpu_ops.cpp.inc", + ), + ( + [ + "-gen-dialect-decls", + "-dialect=gpurt", + ], + "gpurt_dialect.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "gpu_ops.td", + test = True, + visibility = [ + "//tensorflow/compiler/mlir/tfrt:__subpackages__", + ], + deps = [":gpu_ops_td_file"], +) + +cc_library( + name = "tfrt_gpu_opdefs", + srcs = [ + "gpu_ops.cc", + ], + hdrs = ["gpu_ops.h"], + visibility = [ + "//tensorflow/compiler/mlir/tfrt:__subpackages__", + ], + deps = [ + ":tfrt_fallback_opdefs", + ":tfrt_gpu_opdefs_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) diff --git a/tensorflow/compiler/mlir/tfrt/ir/gpu_ops.cc b/tensorflow/compiler/mlir/tfrt/ir/gpu_ops.cc new file mode 100644 index 00000000000..77e71b834b1 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/ir/gpu_ops.cc @@ -0,0 +1,41 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/tfrt/ir/gpu_ops.h" + +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.h" + +namespace tfrt { +namespace gpu { + +GpuRuntimeDialect::GpuRuntimeDialect(MLIRContext *context) + : Dialect(/*name=*/"gpurt", context, TypeID::get()) { + addOperations< +#define GET_OP_LIST +#include "tensorflow/compiler/mlir/tfrt/ir/gpu_ops.cpp.inc" + >(); +} + +} // namespace gpu +} // namespace tfrt + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tfrt/ir/gpu_ops.cpp.inc" diff --git a/tensorflow/compiler/mlir/tfrt/ir/gpu_ops.h b/tensorflow/compiler/mlir/tfrt/ir/gpu_ops.h new file mode 100644 index 00000000000..a270b8badf1 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/ir/gpu_ops.h @@ -0,0 +1,40 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_IR_GPU_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_IR_GPU_OPS_H_ + +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project + +using namespace mlir; // NOLINT + +namespace tfrt { +namespace gpu { + +// Dialect for TFRT GPU operations. +class GpuRuntimeDialect : public Dialect { + public: + explicit GpuRuntimeDialect(MLIRContext *context); + static StringRef getDialectNamespace() { return "gpurt"; } +}; + +} // namespace gpu +} // namespace tfrt + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tfrt/ir/gpu_ops.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_IR_GPU_OPS_H_ diff --git a/tensorflow/compiler/mlir/tfrt/ir/gpu_ops.td b/tensorflow/compiler/mlir/tfrt/ir/gpu_ops.td new file mode 100644 index 00000000000..88fc36fb6da --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/ir/gpu_ops.td @@ -0,0 +1,91 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifdef TFRT_GPU_OPS +#else +#define TFRT_GPU_OPS + +include "tfrt/tfrt_op_base.td" +include "tfrt/compiler/opdefs/tfrt_op_interfaces.td" +include "tfrt/compiler/opdefs/tfrt_traits.td" +include "tfrt_fallback.td" + +def TFRT_GPU_Dialect : Dialect { + let name = "gpurt"; + + let description = [{ + The TFRT GPU Dialect. + }]; + + let cppNamespace = "::tfrt::gpu"; + let useFoldAPI = kEmitFoldAdaptorFolder; +} + +class Gpu_Op traits = []> : + Op { +} + +// TODO(b/260267885): We may add a device argument when we want to support +// GPU MIG. +def TransferToDeviceOp: Gpu_Op<"transfer_to_device"> { + let summary = "Transfer a CPU tensor to device."; + + let description = [{ + Transfer a CPU tensor to device. + + Example: + %device_tensor = gpurt.transfer_to_device %cpu_tensor + }]; + + let arguments = (ins TFTensorType); + let results = (outs TFTensorType); + let assemblyFormat = "operands attr-dict"; +} + +// TODO(b/260267885): We may add a device argument when we want to support +// GPU MIG. +def TransferFromDeviceOp: Gpu_Op<"transfer_from_device"> { + let summary = "Transfer a tensor from device."; + + let description = [{ + Transfer a tensor from device. + + Example: + %cpu_tensor = gpurt.transfer_from_device %device_tensor + }]; + + let arguments = (ins TFTensorType); + let results = (outs TFTensorType); + let assemblyFormat = "operands attr-dict"; +} + +// TODO(b/260267885): We may add a device argument when we want to support +// GPU MIG. +def MaybeTransferVariableOp: Gpu_Op<"maybe_transfer_variable"> { + let summary = "Transfer a CPU variable tensor to device."; + let description = [{ + Transfer a CPU variable tensor to device if the variable has not been + transferred before. + + Example: + %device_var = gpurt.maybe_transfer_variable %cpu_var + }]; + + let arguments = (ins TFTensorType); + let results = (outs TFTensorType); + let assemblyFormat = "operands attr-dict"; +} + +#endif // TFRT_GPU_OPS diff --git a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.td b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.td index c06a389460c..540d5685929 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.td +++ b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.td @@ -27,8 +27,7 @@ def Fallback_Dialect : Dialect { }]; let cppNamespace = "::tfrt::fallback"; - - let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; + let useFoldAPI = kEmitFoldAdaptorFolder; } // This corresponds to tensorflow::Tensor. 
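The `let useFoldAPI = kEmitFoldAdaptorFolder;` setting (applied here and, below, to the fallback_async, fallback_sync and tf_jitrt dialects) opts the dialect into MLIR's fold-adaptor interface: folders no longer receive the constant operands as a raw ArrayRef<Attribute> parameter but through a tablegen-generated FoldAdaptor, which is what the ConstDenseTensorOp::fold(FoldAdaptor) change in the next file reflects. The following is a schematic sketch of the signature change using a hypothetical MyOp; it is not the real TFRT op and not a standalone build target:

// Previously: mlir::OpFoldResult MyOp::fold(llvm::ArrayRef<mlir::Attribute> operands);
mlir::OpFoldResult MyOp::fold(FoldAdaptor adaptor) {
  // Constant operands (a null Attribute for non-constant ones) now come from
  // the generated adaptor instead of an ArrayRef<Attribute> parameter.
  llvm::ArrayRef<mlir::Attribute> constants = adaptor.getOperands();
  if (constants.empty() || !constants.front()) return {};
  // As an example, fold to the first constant operand.
  return constants.front();
}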
diff --git a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.cc b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.cc index 7247a3affed..9643c041cf6 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.cc +++ b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.cc @@ -46,7 +46,7 @@ struct FallbackInlinerInterface : public mlir::DialectInlinerInterface { using DialectInlinerInterface::DialectInlinerInterface; bool isLegalToInline(Operation *op, Region *dest, bool would_be_cloned, - BlockAndValueMapping &) const final { + IRMapping &) const final { return true; } }; @@ -304,9 +304,7 @@ void ExecuteOp::getOpAttrs( // ConstDenseTensorOp //===----------------------------------------------------------------------===// -OpFoldResult ConstDenseTensorOp::fold(ArrayRef operands) { - return getValue(); -} +OpFoldResult ConstDenseTensorOp::fold(FoldAdaptor) { return getValue(); } //===----------------------------------------------------------------------===// // CoreRTTensorHandleToFallbackTensorOp diff --git a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.td b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.td index 260efe10bc0..881fdc5e35d 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.td +++ b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.td @@ -33,8 +33,7 @@ def FallbackAsync_Dialect : Dialect { }]; let cppNamespace = "::tfrt::fallback_async"; - - let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; + let useFoldAPI = kEmitFoldAdaptorFolder; } class FallbackAsync_Op traits = []> : diff --git a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_sync.cc b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_sync.cc index 83663441028..8083fcac076 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_sync.cc +++ b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_sync.cc @@ -53,36 +53,6 @@ static Type GetTensorType(Builder *builder) { return tfrt::t::TensorType::get(builder->getContext()); } -LogicalResult SyncExecuteOp::verify() { - return fallback_common::VerifyExecuteOpCommon(*this); -} - -ParseResult SyncExecuteOp::parse(OpAsmParser &parser, OperationState &result) { - fallback_common::ParseExecuteOpOptions parse_options; - parse_options.has_chain = false; - parse_options.has_key = false; - parse_options.has_device = false; - parse_options.has_func_attr = false; - parse_options.has_cost = false; - - auto &builder = parser.getBuilder(); - return fallback_common::ParseExecuteOpCommon( - parser, builder, result, GetTensorType(&builder), parse_options); -} - -void SyncExecuteOp::print(OpAsmPrinter &p) { - p << " " << (*this)->getAttr("op_name") << '(' << operands() << ')'; - - fallback_common::PrintExecuteOpCommon(p, *this); - if (!getResults().empty()) p << " : " << getResults().size(); -} - -void SyncExecuteOp::getOpAttrs( - SmallVectorImpl> *op_attrs) { - fallback_common::GetExecuteOpAttrsCommon( - this->getContext(), this->getOpAttrs().getValue(), op_attrs); -} - } // namespace fallback_sync } // namespace tfrt diff --git a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_sync.td b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_sync.td index d5a58b22fa9..0f9a50285d7 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_sync.td +++ b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_sync.td @@ -32,8 +32,7 @@ def FallbackSync_Dialect : Dialect { }]; let cppNamespace = "::tfrt::fallback_sync"; - - let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; + let useFoldAPI = kEmitFoldAdaptorFolder; } class FallbackSync_Op traits = []> : @@ -84,7 
+83,7 @@ def GetResourceOp : FallbackSync_Op<"get_resource", let assemblyFormat = "attr-dict `:` type($results)"; } -def CreateOp: FallbackSync_Op<"createop", [CoreRT_TypedAttributeTrait]> { +def CreateOp: FallbackSync_Op<"createop", []> { let summary = "The Fallback CreateOp"; let description = [{ @@ -102,17 +101,16 @@ def CreateOp: FallbackSync_Op<"createop", [CoreRT_TypedAttributeTrait]> { }]; let arguments = (ins - I64Attr:$num_args, - ArrayAttr:$op_attrs, - I64Attr:$op_key, - StrAttr:$op_name + StrAttr:$node_def, + I32Attr:$op_key ); let results = (outs); + + let assemblyFormat = "attr-dict"; } -def SyncExecuteOp : FallbackSync_Op<"executeop", - [Pure, CoreRT_TypedAttributeTrait]> { +def SyncExecuteOp : FallbackSync_Op<"executeop", [Pure]> { let summary = "The Fallback Sync ExecuteOp"; let description = [{ The ExecuteOp executes an operation on the specified device. @@ -124,28 +122,14 @@ def SyncExecuteOp : FallbackSync_Op<"executeop", }]; let arguments = (ins - Variadic:$operands, - ArrayAttr:$op_attrs, - I64Attr:$op_key, - StrAttr:$op_name + Variadic, + StrAttr:$node_def, + I32Attr:$op_key ); - let results = (outs - Variadic:$results - ); - - let extraClassDeclaration = [{ - void getOpAttrs(SmallVectorImpl>* op_attrs); - }]; - - let builders = [ - OpBuilder<(ins "ArrayRef":$results, "ValueRange":$operands, - "ArrayRef>":$op_attrs, - "StringRef":$op_name)>]; - - let hasVerifier = 1; + let results = (outs Variadic); - let hasCustomAssemblyFormat = 1; + let assemblyFormat = "`(`operands`)` attr-dict `:` functional-type(operands, results)"; } diff --git a/tensorflow/compiler/mlir/tfrt/jit/default/BUILD b/tensorflow/compiler/mlir/tfrt/jit/default/BUILD index 715c66af133..ee0e1e53cbc 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/default/BUILD +++ b/tensorflow/compiler/mlir/tfrt/jit/default/BUILD @@ -1,6 +1,7 @@ load("@tf_runtime//:build_defs.bzl", "tfrt_cc_library") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [":friends"], licenses = ["notice"], ) diff --git a/tensorflow/compiler/mlir/tfrt/jit/opdefs/tf_jitrt_ops.cc b/tensorflow/compiler/mlir/tfrt/jit/opdefs/tf_jitrt_ops.cc index 25a7e036c20..263595c9fcd 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/opdefs/tf_jitrt_ops.cc +++ b/tensorflow/compiler/mlir/tfrt/jit/opdefs/tf_jitrt_ops.cc @@ -48,12 +48,12 @@ struct JitRuntimeInlinerInterface : public DialectInlinerInterface { } bool isLegalToInline(Region*, Region*, bool, - BlockAndValueMapping&) const final { + IRMapping&) const final { return true; } bool isLegalToInline(Operation*, Region*, bool, - BlockAndValueMapping&) const final { + IRMapping&) const final { return true; } }; diff --git a/tensorflow/compiler/mlir/tfrt/jit/opdefs/tf_jitrt_ops.td b/tensorflow/compiler/mlir/tfrt/jit/opdefs/tf_jitrt_ops.td index a426ce1432b..4176d16a394 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/opdefs/tf_jitrt_ops.td +++ b/tensorflow/compiler/mlir/tfrt/jit/opdefs/tf_jitrt_ops.td @@ -40,8 +40,7 @@ def TF_JITRT_Dialect : Dialect { }]; let cppNamespace = "::mlir::tf_jitrt"; - - let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; + let useFoldAPI = kEmitFoldAdaptorFolder; } //===----------------------------------------------------------------------===// @@ -117,17 +116,15 @@ def FallbackExecuteOp : TF_JITRT_Op<"fallback.execute", let arguments = (ins SymbolRefAttr:$kernel, - Variadic:$operands, + Variadic, StrAttr:$device ); - let results = (outs - Variadic:$results - ); + let results = (outs Variadic); let assemblyFormat = [{ - $kernel 
`(` $operands `)` `device` `(` $device `)` attr-dict `:` - functional-type($operands, $results) + $kernel `(` operands `)` `device` `(` $device `)` attr-dict `:` + functional-type(operands, results) }]; } @@ -141,7 +138,7 @@ def FallbackDebugExecuteOp : TF_JITRT_Op<"fallback.debug.execute"> { let arguments = (ins SymbolRefAttr:$kernel, - Variadic:$operands, + Variadic, StrAttr:$device, // Print to standard output whenever compiled kernel specialized for the // operands shapes or values. @@ -154,13 +151,11 @@ def FallbackDebugExecuteOp : TF_JITRT_Op<"fallback.debug.execute"> { BoolAttr:$legalize_i1_tensors ); - let results = (outs - Variadic:$results - ); + let results = (outs Variadic); let assemblyFormat = [{ - $kernel `(` $operands `)` `device` `(` $device `)` attr-dict `:` - functional-type($operands, $results) + $kernel `(` operands `)` `device` `(` $device `)` attr-dict `:` + functional-type(operands, results) }]; } diff --git a/tensorflow/compiler/mlir/tfrt/jit/python_binding/BUILD b/tensorflow/compiler/mlir/tfrt/jit/python_binding/BUILD index b0c7f004ed2..6314fb77c90 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/python_binding/BUILD +++ b/tensorflow/compiler/mlir/tfrt/jit/python_binding/BUILD @@ -4,6 +4,7 @@ load("//tensorflow:strict.default.bzl", "py_strict_test") licenses(["notice"]) package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [":__subpackages__"], ) diff --git a/tensorflow/compiler/mlir/tfrt/jit/python_binding/tf_jitrt_executor.cc b/tensorflow/compiler/mlir/tfrt/jit/python_binding/tf_jitrt_executor.cc index 723e3728e6f..b74ce13e165 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/python_binding/tf_jitrt_executor.cc +++ b/tensorflow/compiler/mlir/tfrt/jit/python_binding/tf_jitrt_executor.cc @@ -76,12 +76,11 @@ TfJitRtExecutor::TfJitRtExecutor() }, CreateMallocAllocator(), CreateMultiThreadedWorkQueue(4, 4)) {} -TfJitRtExecutor::Handle TfJitRtExecutor::Compile(const std::string& mlir_module, - const std::string& entrypoint, - Specialization specialization, - bool vectorize, - bool codegen_transpose, - bool legalize_i1_tensors) { +TfJitRtExecutor::Handle TfJitRtExecutor::Compile( + const std::string& mlir_module, const std::string& entrypoint, + Specialization specialization, bool vectorize, bool codegen_transpose, + bool legalize_i1_tensors, bool peel, bool enable_xla_cpu_transformations, + bool pack_matmul) { // Options for the default JitRt compilation pipeline (lowering to LLVM). CompilationPipelineOptions copts; copts.alignment = EIGEN_MAX_ALIGN_BYTES; @@ -102,6 +101,9 @@ TfJitRtExecutor::Handle TfJitRtExecutor::Compile(const std::string& mlir_module, opts.vectorize = vectorize; opts.codegen_transpose = codegen_transpose; opts.legalize_i1_tensors = legalize_i1_tensors; + opts.peel = peel; + opts.enable_xla_cpu_transformations = enable_xla_cpu_transformations; + opts.lower_to_mmt4d = pack_matmul; tensorflow::CreateTfJitRtPipeline(*passes, opts); CreateDefaultJitRtCompilationPipeline(passes, copts); }; @@ -245,7 +247,8 @@ std::vector TfJitRtExecutor::Execute( PyBindingResultConverter converter(results, results_ctx); converter.AddConversion(ReturnStridedMemref); if (auto st = (*executable)->Execute(memrefs, converter, opts); !st.ok()) - throw std::runtime_error(StrCat("Unsupported argument: ", st.message())); + throw std::runtime_error( + StrCat("Unsupported argument: ", st.status().message())); // Pull Python arrays out of async values. 
std::vector ret_values; @@ -289,7 +292,9 @@ PYBIND11_MODULE(_tf_jitrt_executor, m) { py::arg("specialization") = tensorflow::TfJitRtExecutor::Specialization::kEnabled, py::arg("vectorize") = false, py::arg("codegen_transpose") = false, - py::arg("legalize_i1_tensors") = false) + py::arg("legalize_i1_tensors") = false, py::arg("peel") = true, + py::arg("enable_xla_cpu_transformations") = false, + py::arg("pack_matmul") = false) .def("execute", &tensorflow::TfJitRtExecutor::Execute) .def("built_with", &tensorflow::TfJitRtExecutor::BuiltWith, py::arg("cpu_feature")); diff --git a/tensorflow/compiler/mlir/tfrt/jit/python_binding/tf_jitrt_executor.h b/tensorflow/compiler/mlir/tfrt/jit/python_binding/tf_jitrt_executor.h index e68b5f99896..0a6542c9a30 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/python_binding/tf_jitrt_executor.h +++ b/tensorflow/compiler/mlir/tfrt/jit/python_binding/tf_jitrt_executor.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include "llvm/ADT/DenseMap.h" #include "pybind11/numpy.h" @@ -43,7 +44,8 @@ class TfJitRtExecutor { // execute function. Handle Compile(const std::string& mlir_module, const std::string& entrypoint, Specialization specialization, bool vectorize, - bool codegen_transpose, bool legalize_i1_tensors); + bool codegen_transpose, bool legalize_i1_tensors, bool peel, + bool enable_xla_cpu_transformations, bool pack_matmul); // Executes compiled mlir module with Python array arguments. Converts // returned memrefs into Python arrays. diff --git a/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_kernels.cc b/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_kernels.cc index eac5a04424a..25020671505 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_kernels.cc +++ b/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_kernels.cc @@ -480,6 +480,9 @@ static Expected> CompileImpl( opts.legalize_i1_tensors = tf_jitrt_opts->legalize_i1_tensors; } else { opts.vectorize = GetJitRtFlags().vectorize; + opts.enable_xla_cpu_transformations = + tensorflow::GetJitRtFlags().enable_xla_cpu_transformations; + opts.lower_to_mmt4d = tensorflow::GetJitRtFlags().pack_matmul; } // Lower from Tensorflow to Linalg on buffers. @@ -730,7 +733,7 @@ static void ExecuteImpl(Executable& executable, ArrayRef memrefs, // notify the HostContext to emit the diagnostics for the kernel invocation. auto status = executable.Execute(memrefs, converter, opts); if (LLVM_UNLIKELY(!status.ok())) { - EmitError(exec_ctx, status.message()); + EmitError(exec_ctx, status.status().message()); return; } } diff --git a/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.cc b/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.cc index 91482049bce..be2d8b4107c 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.cc +++ b/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.cc @@ -31,11 +31,13 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir/backends/cpu/transforms/passes.h" #include "tensorflow/compiler/xla/mlir/runtime/ir/rt_dialect.h" #include "tensorflow/compiler/xla/mlir/runtime/transforms/compiler.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/passes.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Transforms/passes.h" +#include "tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/transforms.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir_hlo/transforms/passes.h" // -------------------------------------------------------------------------- // // Custom passes that are missing upstream. @@ -66,38 +68,6 @@ struct AddTensorflowProducerVersion } }; -// Adds Linalg passes to perform fusion, tiling, peeling and vectorization. -void AddLinalgTransformations(OpPassManager& pm, - const TfJitRtPipelineOptions& options) { - pm.addNestedPass(CreateFusionPass()); - - if (!options.vectorize) return; - - pm.addNestedPass(CreateDetensorizeLinalgPass()); - - pm.addNestedPass(CreateTileReductionPass( - options.vector_size, options.reduction_1d_tile_size, - options.reduction_2d_tile_sizes)); - - // TODO(b/248219927): Enable matmul transformations when bufferization works. - // pm.addNestedPass( - // mlir::gml_st::createTransformMatmulForCpuPass(options.matmul_tile_sizes)); - - if (options.vectorize && options.codegen_transpose) - pm.addNestedPass(CreateTileTransposePass()); - pm.addNestedPass(CreateTileCWisePass(options.vector_size)); - if (options.peel) { - pm.addNestedPass(CreatePeelTiledLoopsPass()); - } - pm.addNestedPass(mlir::createCSEPass()); - pm.addPass(mlir::createCanonicalizerPass()); - if (options.fuse_fill) { - pm.addNestedPass(CreateFuseFillIntoTiledReductionPass()); - } - pm.addNestedPass(CreateTileFillPass(options.vector_size)); - pm.addNestedPass(mlir::gml_st::createVectorizeGmlStLoopsPass()); -} - void AddBufferizationPasses(OpPassManager& pm) { // Rewrite tensor.empty ops to bufferization.alloc_tensor ops. pm.addNestedPass( @@ -122,8 +92,12 @@ void CreateTfJitRtPipeline(OpPassManager& pm, pm.addPass(mlir::TF::CreateTFShapeInferencePass()); pm.addPass(mlir::createCanonicalizerPass()); + // This will add regions to IfOp/WhileOp (turning them into IfRegionOp + // and WhileRegionOp), but be aware that those regions will still contain + // calls. + pm.addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions()); + // Transform TF operation to HLO. - pm.addPass(mlir::mhlo::createLegalizeTFControlFlowPass()); pm.addNestedPass(mlir::mhlo::createLegalizeTFPass()); if (options.legalize_i1_tensors) { @@ -132,43 +106,38 @@ void CreateTfJitRtPipeline(OpPassManager& pm, } // Remove redundant shape operations left after legalizing to HLO. + pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCSEPass()); - // Resolve all shape constraints (e.g. broadcast constraints that can be - // proved statically and changed to const witness) early to allow more - // efficient broadcast operations moving. 
- pm.addNestedPass( - CreateSymbolicShapeOptimizationPass(/*constraints_only=*/true)); - - // Analyze shapes and try to simplify the IR as early as possible. - pm.addNestedPass(mlir::createSymbolicShapeOptimizationPass()); + // Analyze shapes and try to simplify the IR early. + pm.addNestedPass(mlir::mhlo::createSymbolicShapeOptimizationPass()); pm.addPass(mlir::createCSEPass()); pm.addPass(mlir::createCanonicalizerPass()); // Move up broadcasting operations to allow for more fusion opportunities. - // Add the broadcast propagation pass first, because it can help to avoid - // exponential complexity from the EarlyBroadcastInDimOp pattern which is used - // in the merge assuming ops pass further down. pm.addNestedPass(mlir::mhlo::createMergeAssumingOpsPass()); pm.addNestedPass(mlir::mhlo::createBroadcastPropagationPass()); pm.addPass(mlir::createCSEPass()); pm.addPass(mlir::createCanonicalizerPass()); - // After all shape constraints removed and broadcasts moved to the top, try - // to resolve broadcasts that can be converted to linalg generic operations. - pm.addNestedPass(CreateSymbolicShapeOptimizationPass()); - - // Group reduction and parallel dimensions of reduction operations and realize - // them through equivalent 1D or 2D reductions, if possible. pm.addNestedPass(mlir::mhlo::createGroupReductionDimensionsPass()); + pm.addNestedPass( + mlir::mhlo::createHloCanonicalizeScatterPass()); // Also, try to simplify reshape operations. - pm.addNestedPass(mlir::createSymbolicShapeOptimizationPass()); + pm.addNestedPass(mlir::mhlo::createSymbolicShapeOptimizationPass()); // Transform HLO operations to Linalg and Standard. pm.addNestedPass(mlir::mhlo::createLegalizeControlFlowPass()); pm.addNestedPass(mlir::mhlo::createLegalizeSortPass()); - pm.addNestedPass(mlir::mhlo::createLegalizeHloToLinalgPass()); + pm.addNestedPass(xla::cpu::createLegalizeCollectiveOpsPass()); + + if (options.vectorize) { + pm.addNestedPass( + mlir::mhlo::createLegalizeMHLOToTHLOPass()); + } + pm.addNestedPass(mlir::mhlo::createLegalizeHloToLinalgPass( + /*enablePrimitiveOps=*/options.vectorize)); pm.addPass(mlir::mhlo::createLegalizeToArithmeticPass()); pm.addNestedPass( mlir::mhlo::createLegalizeHloShapeOpsToStandardPass()); @@ -180,7 +149,7 @@ void CreateTfJitRtPipeline(OpPassManager& pm, // Lower shape dialect to standard to enable linalg canonicalizations (e.g. // use linalg inputs instead of outputs for memref.dim operations). - pm.addNestedPass(mlir::createShapeSimplification()); + pm.addNestedPass(mlir::mhlo::createShapeSimplification()); pm.addNestedPass(mlir::createShapeToShapeLowering()); pm.addPass(mlir::createConvertShapeToStandardPass()); pm.addNestedPass(mlir::createConvertShapeConstraintsPass()); @@ -196,8 +165,20 @@ void CreateTfJitRtPipeline(OpPassManager& pm, // Convert complex types. pm.addPass(mlir::createConvertComplexToStandardPass()); - // Add linalg passes to perform fusion, tiling, peeling and vectorization. - AddLinalgTransformations(pm, options); + // Add passes to perform fusion, tiling, peeling and vectorization. 
+ if (options.vectorize) { + mlir::gml_st::GmlStCPUPipelineOptions gml_st_opts; + gml_st_opts.vectorize = options.vectorize; + gml_st_opts.vectorSize = options.vector_size; + gml_st_opts.reduction1DTileSize = options.reduction_1d_tile_size; + gml_st_opts.reduction2DTileSizes = options.reduction_2d_tile_sizes; + gml_st_opts.matmulTileSizes = options.matmul_tile_sizes; + gml_st_opts.lowerToMmt4d = options.lower_to_mmt4d; + + mlir::gml_st::addCPUTilingPipeline(pm, gml_st_opts); + } else { + pm.addNestedPass(CreateFusionPass()); + } // Inline everything, bufferization doesn't model ownership across calls. pm.addPass(mlir::createInlinerPass()); @@ -211,6 +192,12 @@ void CreateTfJitRtPipeline(OpPassManager& pm, pm.addPass(mlir::createCSEPass()); pm.addPass(mlir::createCanonicalizerPass()); + if (options.vectorize) + pm.addNestedPass(mlir::gml_st::createVectorizeCopyPass()); + + if (options.enable_xla_cpu_transformations) + pm.addNestedPass(mlir::gml_st::createSimplifyDeadCopyPass()); + // Deallocate all temporary buffers. pm.addNestedPass(mlir::bufferization::createBufferDeallocationPass()); @@ -227,14 +214,14 @@ void CreateTfJitRtPipeline(OpPassManager& pm, pm.addPass(mlir::createCSEPass()); pm.addPass(mlir::createCanonicalizerPass()); - if (options.vectorize && options.codegen_transpose) - pm.addNestedPass(CreateLowerVectorTransposePass()); + pm.addNestedPass(mlir::gml_st::createRewriteVectorTransposePass()); mlir::VectorTransferToSCFOptions vec_to_scf_options; vec_to_scf_options.unroll = true; pm.addNestedPass( mlir::createConvertVectorToSCFPass(vec_to_scf_options)); - pm.addNestedPass(createRewriteVectorMultiReductionPass()); + pm.addNestedPass( + mlir::gml_st::createRewriteVectorMultiReductionPass()); pm.addNestedPass(CreateMathApproximationPass({"all"})); } diff --git a/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.h b/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.h index 450c519ad29..6f546dd77a5 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.h +++ b/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.h @@ -29,6 +29,11 @@ struct TfJitRtPipelineOptions llvm::cl::desc("Enable tiling for vectorization."), llvm::cl::init(false)}; + Option enable_xla_cpu_transformations{ + *this, "enable_xla_cpu_transformations", + llvm::cl::desc("Enable tiling/fusion shared with XLA:CPU Next."), + llvm::cl::init(false)}; + Option peel{*this, "peel", llvm::cl::desc("Enable loop peeling."), llvm::cl::init(true)}; @@ -55,6 +60,12 @@ struct TfJitRtPipelineOptions llvm::cl::desc("Tile sizes for `linalg.matmul`."), llvm::cl::list_init({4, 4, 4}), llvm::cl::ZeroOrMore}; + Option lower_to_mmt4d{ + *this, "lower-to-mmt4d", + llvm::cl::desc("Enable the specific code generation (packing) for matmul " + "operations."), + llvm::cl::init(false)}; + Option legalize_i1_tensors{ *this, "legalize-i1-tensors", llvm::cl::desc("Convert i1 tensors to i8 tensors."), diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/BUILD b/tensorflow/compiler/mlir/tfrt/jit/transforms/BUILD index ee6c883beb1..bfc26199451 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/transforms/BUILD +++ b/tensorflow/compiler/mlir/tfrt/jit/transforms/BUILD @@ -4,6 +4,7 @@ load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") # TF to TFRT kernels conversion. 
package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//tensorflow/compiler/mlir/tfrt:friends"], licenses = ["notice"], ) @@ -47,20 +48,11 @@ cc_library( "tf_jitrt_buffer_forwarding.cc", "tf_jitrt_clustering_pass.cc", "tf_jitrt_copy_removal.cc", - "tf_jitrt_detensorize_linalg.cc", "tf_jitrt_fission.cc", - "tf_jitrt_fuse_fill_into_tiled_reduction.cc", "tf_jitrt_fusion.cc", "tf_jitrt_legalize_i1_type.cc", - "tf_jitrt_lower_vector_transpose.cc", "tf_jitrt_math_approximation.cc", "tf_jitrt_passes.cc", - "tf_jitrt_peel_tiled_loops.cc", - "tf_jitrt_rewrite_vector_multi_reduction.cc", - "tf_jitrt_symbolic_shape_optimization.cc", - "tf_jitrt_tile_cwise.cc", - "tf_jitrt_tile_reduction.cc", - "tf_jitrt_tile_transpose.cc", ], hdrs = ["tf_jitrt_passes.h"], compatible_with = get_compatible_with_cloud(), diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.cc b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.cc index 80a4d063bd0..6391560c757 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.cc +++ b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.cc @@ -209,19 +209,19 @@ class BroadcastToOpClusteringPolicy BroadcastToOp op, const ValuesConstraintSet& results, ValuesConstraintSet& operands) const final { // Only ranked inputs are supported. - operands.Insert(op.input(), ValueConstraint::kRank); + operands.Insert(op.getInput(), ValueConstraint::kRank); if (auto result_constraint = results.GetConstraint(op.getResult())) { if (*result_constraint == ValueConstraint::kValue) return failure(); // For a static output shape we need a constant shape operand. if (*result_constraint == ValueConstraint::kShape) { - operands.Insert(op.shape(), ValueConstraint::kValue); + operands.Insert(op.getShape(), ValueConstraint::kValue); return success(); } } // Producing a ranked output requires a known shape for the shape operand. - operands.Insert(op.shape(), ValueConstraint::kShape); + operands.Insert(op.getShape(), ValueConstraint::kShape); return success(); } @@ -398,13 +398,13 @@ class ConcatV2OpClusteringPolicy // Propagate constraint from the result to the input. All inputs always need // a known rank. - for (auto value : op.values()) { + for (auto value : op.getValues()) { operands.Insert(value, - result_constraint.getValueOr(ValueConstraint::kRank)); + result_constraint.value_or(ValueConstraint::kRank)); } // Force axis to be a constant. - operands.Insert(op.axis(), ValueConstraint::kValue); + operands.Insert(op.getAxis(), ValueConstraint::kValue); return success(); } @@ -423,7 +423,7 @@ class ConstOpClusteringPolicy : public TensorflowOpClusteringPolicy { auto result_constraint = results.GetConstraint(op.getResult()); if (!result_constraint.has_value()) return failure(); - return IsCompilableConstant(op.value()); + return IsCompilableConstant(op.getValue()); } }; @@ -439,13 +439,13 @@ class ExpandDimsOpClusteringPolicy // Propagate constraint from the result to the input. if (auto result_constraint = results.GetConstraint(op->getResult(0))) { if (*result_constraint == ValueConstraint::kValue) return failure(); - operands.Insert(op.input(), *result_constraint); + operands.Insert(op.getInput(), *result_constraint); } else { - operands.Insert(op.input(), ValueConstraint::kRank); + operands.Insert(op.getInput(), ValueConstraint::kRank); } // The inserted dimension must be always known at compile time. 
- operands.Insert(op.dim(), ValueConstraint::kValue); + operands.Insert(op.getDim(), ValueConstraint::kValue); return success(); } @@ -466,12 +466,12 @@ class FusedMatMulOpClusteringPolicy return failure(); // Check if we do support a set of fused operations. - size_t n = op.fused_ops().size(); + size_t n = op.getFusedOps().size(); auto fusion = - n > 0 ? op.fused_ops()[0].dyn_cast() : nullptr; + n > 0 ? op.getFusedOps()[0].dyn_cast() : nullptr; auto activation = - n > 1 ? op.fused_ops()[1].dyn_cast() : nullptr; + n > 1 ? op.getFusedOps()[1].dyn_cast() : nullptr; if ((n > 0 && !fusion) || (n > 1 && !activation)) return failure(); @@ -502,11 +502,11 @@ class FillOpClusteringPolicy : public TensorflowOpClusteringPolicy { // To know the result shape we need to know the shape operand value. if (*result_constraint == ValueConstraint::kShape) - operands.Insert(op.dims(), ValueConstraint::kValue); + operands.Insert(op.getDims(), ValueConstraint::kValue); // To know the result rank we need to know the shape operand shape. if (*result_constraint == ValueConstraint::kRank) - operands.Insert(op.dims(), ValueConstraint::kShape); + operands.Insert(op.getDims(), ValueConstraint::kShape); // Value constraint propagation is not supported. if (*result_constraint == ValueConstraint::kValue) return failure(); @@ -534,8 +534,8 @@ class OneHotOpClusteringPolicy : public TensorflowOpClusteringPolicy { if (*constraint == ValueConstraint::kValue) return failure(); // MHLO lowering needs a static shape for the indices and a constant depth. - operands.Insert(op.indices(), ValueConstraint::kShape); - operands.Insert(op.depth(), ValueConstraint::kValue); + operands.Insert(op.getIndices(), ValueConstraint::kShape); + operands.Insert(op.getDepth(), ValueConstraint::kValue); return success(); } @@ -561,7 +561,7 @@ class RangeOpClusteringPolicy : public TensorflowOpClusteringPolicy { // To know the result shape we need the input values. if (*result_constraint == ValueConstraint::kShape) { - operands.Insert({op.start(), op.limit(), op.delta()}, + operands.Insert({op.getStart(), op.getLimit(), op.getDelta()}, ValueConstraint::kValue); } @@ -582,7 +582,7 @@ class ReshapeOpClusteringPolicy ReshapeOp op, const ValuesConstraintSet& results, ValuesConstraintSet& operands) const final { // The runtime only supports ranked tensors. - operands.Insert(op.tensor(), ValueConstraint::kRank); + operands.Insert(op.getTensor(), ValueConstraint::kRank); // Reshape operation does not have any default constraints. auto result_constraint = results.GetConstraint(op.getResult()); @@ -591,13 +591,13 @@ class ReshapeOpClusteringPolicy // To know the result shape we need to know the shape operand value. We also // require a static shape on the input in case there's a -1 in the shape. if (*result_constraint == ValueConstraint::kShape) { - operands.Insert(op.shape(), ValueConstraint::kValue); - operands.Insert(op.tensor(), ValueConstraint::kShape); + operands.Insert(op.getShape(), ValueConstraint::kValue); + operands.Insert(op.getTensor(), ValueConstraint::kShape); } // To know the result rank we need to know the shape operand shape. if (*result_constraint == ValueConstraint::kRank) - operands.Insert(op.shape(), ValueConstraint::kShape); + operands.Insert(op.getShape(), ValueConstraint::kShape); // Value constraint propagation is not supported. 
if (*result_constraint == ValueConstraint::kValue) return failure(); @@ -615,7 +615,7 @@ class ShapeOpClusteringPolicy : public TensorflowOpClusteringPolicy { ShapeOp op, const ValuesConstraintSet& results, ValuesConstraintSet& operands) const final { // Unranked inputs aren't supported by JitRt. - operands.Insert(op.input(), ValueConstraint::kRank); + operands.Insert(op.getInput(), ValueConstraint::kRank); // Check constraint on the result value. auto result_constraint = results.GetConstraint(op.getResult()); @@ -623,11 +623,11 @@ class ShapeOpClusteringPolicy : public TensorflowOpClusteringPolicy { // To know the result shape we need only the rank of the input. if (*result_constraint == ValueConstraint::kShape) - operands.Insert(op.input(), ValueConstraint::kRank); + operands.Insert(op.getInput(), ValueConstraint::kRank); // To know the result value we need to know the shape of the input. if (*result_constraint == ValueConstraint::kValue) - operands.Insert(op.input(), ValueConstraint::kShape); + operands.Insert(op.getInput(), ValueConstraint::kShape); return success(); } @@ -667,9 +667,9 @@ class SqueezeOpClusteringPolicy } // If squeeze_dims is not present we need a static shape. - if (op.squeeze_dims().empty()) input_constraint = ValueConstraint::kShape; + if (op.getSqueezeDims().empty()) input_constraint = ValueConstraint::kShape; - operands.Insert(op.input(), input_constraint); + operands.Insert(op.getInput(), input_constraint); return success(); } }; @@ -692,13 +692,13 @@ class TransposeOpClusteringPolicy ValuesConstraintSet& operands) const final { // Propagate result constraints to the input, at minimum require known rank. if (auto constraint = results.GetConstraint(op.getResult())) { - operands.Insert(op.x(), *constraint); + operands.Insert(op.getX(), *constraint); } else { - operands.Insert(op.x(), ValueConstraint::kRank); + operands.Insert(op.getX(), ValueConstraint::kRank); } // Permutation must be always known at compile time. - operands.Insert(op.perm(), ValueConstraint::kValue); + operands.Insert(op.getPerm(), ValueConstraint::kValue); return success(); } @@ -717,12 +717,12 @@ class SliceOpClusteringPolicy : public TensorflowOpClusteringPolicy { if (*constraint == ValueConstraint::kValue) return failure(); // We must know the shape of the input. - operands.Insert(op.input(), ValueConstraint::kShape); + operands.Insert(op.getInput(), ValueConstraint::kShape); // Force begin and size to be constants. The restriction on begin could be // lifted if we know that there are no `-1` sizes. // TODO(kramerb): Revisit this when mhlo.real_dynamic_slice stabilizes. - operands.Insert({op.begin(), op.size()}, ValueConstraint::kValue); + operands.Insert({op.getBegin(), op.getSize()}, ValueConstraint::kValue); return success(); } @@ -738,10 +738,10 @@ class StridedSliceOpClusteringPolicy StridedSliceOp op, const ValuesConstraintSet& results, ValuesConstraintSet& operands) const final { // We must know the shape of the input. - operands.Insert(op.input(), ValueConstraint::kShape); + operands.Insert(op.getInput(), ValueConstraint::kShape); // And values of operands that control the slice size. - operands.Insert({op.begin(), op.end(), op.strides()}, + operands.Insert({op.getBegin(), op.getEnd(), op.getStrides()}, ValueConstraint::kValue); return success(); @@ -915,7 +915,7 @@ mlir::LogicalResult VerifyCluster(const Cluster& cluster) { // Small constants will be sunk into the compiled function body. 
auto const_op = mlir::dyn_cast(op); - if (!const_op || failed(IsCompilableConstant(const_op.value()))) + if (!const_op || failed(IsCompilableConstant(const_op.getValue()))) return failure(); } diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering_pass.cc b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering_pass.cc index ed61a74842e..ee10d7eedba 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering_pass.cc +++ b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering_pass.cc @@ -80,8 +80,8 @@ struct ClusteringPass : public impl::ClusteringBase { // If the clustering tier is not defined, it means that the opset will later // filter supported operations, so it's ok to use `all` tier. - populateTfJitRtClusteringPolicies( - policies, tier.getValueOr(JitRtClusteringTier::kAll)); + populateTfJitRtClusteringPolicies(policies, + tier.value_or(JitRtClusteringTier::kAll)); // If opset is not empty restrict operations that are enabled for // clustering. diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_detensorize_linalg.cc b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_detensorize_linalg.cc deleted file mode 100644 index 9d76d3c3cec..00000000000 --- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_detensorize_linalg.cc +++ /dev/null @@ -1,130 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Dialect/MemRef/Transforms/Passes.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" -#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h" - -namespace tensorflow { -namespace { - -#define GEN_PASS_DEF_DETENSORIZELINALG -#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h.inc" - -using mlir::AffineMap; -using mlir::ConversionPatternRewriter; -using mlir::failure; -using mlir::LogicalResult; -using mlir::OpConversionPattern; -using mlir::OpRewritePattern; -using mlir::PatternRewriter; -using mlir::RankedTensorType; -using mlir::success; -using mlir::Type; -using mlir::TypeRange; -using mlir::Value; -using mlir::linalg::GenericOp; -using mlir::tensor::ExtractOp; -using mlir::tensor::FromElementsOp; - -bool IsNotZeroRankTensor(RankedTensorType tensor_type) { - return !tensor_type || tensor_type.getRank() > 0; -} - -/// A conversion patttern for detensoring Linalg ops. 
-struct DetensorizeLinalgOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - GenericOp op, OpAdaptor /*adaptor*/, - ConversionPatternRewriter& rewriter) const override { - mlir::Location loc = op.getLoc(); - mlir::SmallVector indexing_maps = op.getIndexingMapsArray(); - - mlir::SmallVector inputs; - bool found_zero_dim_tensor = false; - for (auto& en : llvm::enumerate(op.getDpsInputOperands())) { - auto tensor_type = - en.value()->get().getType().dyn_cast(); - if (IsNotZeroRankTensor(tensor_type)) { - inputs.push_back(en.value()->get()); - continue; - } - found_zero_dim_tensor = true; - indexing_maps[en.index()] = - AffineMap::get(op.getNumLoops(), 0, llvm::None, op.getContext()); - inputs.push_back(rewriter.create(loc, en.value()->get(), - mlir::ValueRange{})); - } - if (!found_zero_dim_tensor) return failure(); - - auto linalg_op = rewriter.create( - loc, op.getResultTypes(), inputs, op.getOutputs(), - rewriter.getAffineMapArrayAttr(indexing_maps), op.getIteratorTypes(), - mlir::StringAttr(), mlir::StringAttr()); - mlir::Region& region = linalg_op.getRegion(); - rewriter.inlineRegionBefore(op.getBodyRegion(), region, region.end()); - rewriter.replaceOp(op, linalg_op.getResults()); - return success(); - } -}; - -struct DetensorizeLinalgPass - : public impl::DetensorizeLinalgBase { - DetensorizeLinalgPass() = default; - - void runOnOperation() override { - auto func = getOperation(); - auto* context = &getContext(); - - mlir::ConversionTarget target(*context); - target.markUnknownOpDynamicallyLegal([](mlir::Operation*) { return true; }); - target.addDynamicallyLegalOp([&](GenericOp op) { - return llvm::all_of(TypeRange{op.getInputs()}, [&](Type type) { - return IsNotZeroRankTensor(type.dyn_cast()); - }); - }); - - // Detensorize. - mlir::RewritePatternSet patterns(context); - patterns.add(context); - if (failed(applyFullConversion(func, target, std::move(patterns)))) - signalPassFailure(); - - // Canonicalize. - mlir::RewritePatternSet canonicalization_patterns(context); - FromElementsOp::getCanonicalizationPatterns(patterns, context); - if (failed(applyPatternsAndFoldGreedily( - func, std::move(canonicalization_patterns)))) - signalPassFailure(); - } -}; - -} // namespace - -std::unique_ptr> -CreateDetensorizeLinalgPass() { - return std::make_unique(); -} - -} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_fission.cc b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_fission.cc index bfb2f031206..ce205d8f06c 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_fission.cc +++ b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_fission.cc @@ -36,13 +36,13 @@ struct FusedMatMulFission auto loc = op.getLoc(); auto type = op.getResult().getType(); - size_t n = op.fused_ops().size(); + size_t n = op.getFusedOps().size(); // Extract fused operations from the operation attributes. mlir::StringAttr fusion0 = - n > 0 ? op.fused_ops()[0].dyn_cast() : nullptr; + n > 0 ? op.getFusedOps()[0].dyn_cast() : nullptr; mlir::StringAttr fusion1 = - n > 1 ? op.fused_ops()[1].dyn_cast() : nullptr; + n > 1 ? 
op.getFusedOps()[1].dyn_cast() : nullptr; // Match to supported operations bool is_bias_add = fusion0 && fusion0.getValue() == "BiasAdd"; @@ -53,7 +53,7 @@ struct FusedMatMulFission auto lhs = op.getOperand(0); auto rhs = op.getOperand(1); return rewriter.create( - loc, type, lhs, rhs, op.transpose_a(), op.transpose_b()); + loc, type, lhs, rhs, op.getTransposeA(), op.getTransposeB()); }; // FusedMatMul[BiasAdd]. diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_fuse_fill_into_tiled_reduction.cc b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_fuse_fill_into_tiled_reduction.cc deleted file mode 100644 index 17e88727a5a..00000000000 --- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_fuse_fill_into_tiled_reduction.cc +++ /dev/null @@ -1,340 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "mlir-hlo/Dialect/gml_st/transforms/transforms.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" -#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h" - -namespace tensorflow { -namespace { - -#define GEN_PASS_DEF_FUSEFILLINTOTILEDREDUCTION -#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h.inc" - -using llvm::makeArrayRef; -using mlir::BlockAndValueMapping; -using mlir::BlockArgument; -using mlir::dyn_cast; -using mlir::failure; -using mlir::Location; -using mlir::LogicalResult; -using mlir::MLIRContext; -using mlir::OpBuilder; -using mlir::Operation; -using mlir::OpFoldResult; -using mlir::OpRewritePattern; -using mlir::PatternRewriter; -using mlir::SmallVector; -using mlir::success; -using mlir::Value; -using mlir::ValueRange; -using mlir::gml_st::LoopOp; -using mlir::linalg::FillOp; -using mlir::linalg::GenericOp; -using mlir::linalg::LinalgOp; -using mlir::linalg::YieldOp; -using mlir::tensor::EmptyOp; -using mlir::tensor::ExtractSliceOp; -using mlir::tensor::InsertSliceOp; - -SmallVector GetParallelDimStep(LoopOp tiled_loop) { - assert(tiled_loop.getNumLoops() == 2 && "Expected a 2D loop"); - Value step = tiled_loop.isParallelDimension(0) ? tiled_loop.getStep().front() - : tiled_loop.getStep().back(); - if (auto constant = step.getDefiningOp()) { - return {constant.getValue()}; - } - return {step}; -} - -// Fuses `linalg.fill` into a loop with a tiled reduction. -// Currently, only 2D case is supported. Fusion into a tiled 1D reduction is -// also possible. 
-struct FuseFillIntoTiledReductionPattern : public OpRewritePattern { - explicit FuseFillIntoTiledReductionPattern(MLIRContext *context, - mlir::PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit) {} - - LogicalResult matchAndRewrite(GenericOp linalg_op, - PatternRewriter &rewriter) const override { - if (linalg_op.getNumDpsInits() != 1) return failure(); - if (linalg_op.getNumLoops() != 2) return failure(); - - // Get immediate parent. - auto tiled_loop_op = - dyn_cast(linalg_op->getParentRegion()->getParentOp()); - if (!tiled_loop_op) return failure(); - if (tiled_loop_op.getNumLoops() != 2) return failure(); - - return RewriteTiledReduction(rewriter, tiled_loop_op, linalg_op); - } - - private: - // Add a new output argument to the `tiled_loop`. It will be produced by - // `empty` op with the same shape of the tiled output argument. - // - // Rewrite - // - // %init = tensor.empty - // %fill = linalg.fill(%cst, %init) - // linalg.tiled_loop outs(%fill) - // - // into - // - // %init = tensor.empty - //** %init_tile = tensor.empty [%stride] - // %fill = linalg.fill(%cst, %init) - //** linalg.tiled_loop outs(%fill, %init_tile) - BlockArgument CloneAndAppendEmptyTensorToTiledLoop(PatternRewriter &rewriter, - FillOp fill, - LoopOp tiled_loop) const { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(fill); - - auto empty = fill.output().getDefiningOp(); - - Value empty_clone = rewriter.create( - empty.getLoc(), GetParallelDimStep(tiled_loop), - empty.getType().cast().getElementType()); - mlir::OpOperand *empty_clone_output_operand; - rewriter.updateRootInPlace(tiled_loop, [&]() { - empty_clone_output_operand = - &tiled_loop.appendOutputOperand(rewriter, empty_clone); - }); - return tiled_loop.getTiedBlockArgument(*empty_clone_output_operand); - } - - // Fuse `fill` operation into the `tiled_loop`, rewire the `linalg.generic` to - // use it as the output for the reduced tile. Also create an additional - // `insert_slice` that updates the new output. 
- // - // Rewrite - // - // %init = tensor.empty - // %init_tile = tensor.empty [%stride] - // %fill = linalg.fill(%cst, %init) - // linalg.tiled_loop outs(%fill, %init_tile) { - // %extract_output_slice = tensor.extract_slice %fill - // %reduce = linalg.generic outs (%extract_output_slice) - // %insert_output_slice = tensor.insert_slice %reduce into %fill - // linalg.yield %insert_output_slice - // } - // - // into - // - // %init = tensor.empty - // %init_tile = tensor.empty - // %fill = linalg.fill(%cst, %init) - // linalg.tiled_loop outs(%fill, %init_tile) { - // %extract_output_slice = tensor.extract_slice %fill - // - //** %slice_of_output_tile = tensor.extract_slice %init - //** %fill_of_output_tile = linalg.fill(%cst, %slice_of_output_tile) - //** %reduce = linalg.generic outs (%fill_of_output_tile) - //** %update_output_tile = tensor.insert_slice %reduce into %init_tile - // - // %insert_output_slice = tensor.insert_slice %reduce into %fill - // linalg.yield %insert_output_slice, %update_output_tile - // } - void FuseFill(PatternRewriter &rewriter, LinalgOp tiled_op, FillOp fill, - BlockArgument loop_output_bb_arg, - BlockArgument output_tile_bb_arg, - ExtractSliceOp extract_output_slice, - InsertSliceOp insert_output_slice) const { - Location loc = tiled_op.getLoc(); - - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(tiled_op); - - SmallVector offset{rewriter.getIndexAttr(0)}; - Value slice_of_output_tile = rewriter.create( - loc, output_tile_bb_arg, offset, extract_output_slice.getMixedSizes(), - extract_output_slice.getMixedStrides()); - - auto fused_fill = - rewriter.create(loc, fill.value(), slice_of_output_tile); - rewriter.updateRootInPlace(tiled_op, [&]() { - tiled_op.getDpsInitOperand(0)->set(fused_fill.result()); - }); - - rewriter.setInsertionPointAfter(tiled_op); - Value cloned_insert = rewriter.create( - loc, fused_fill.getResult(0), output_tile_bb_arg, offset, - extract_output_slice.getMixedSizes(), - extract_output_slice.getMixedStrides()); - - auto yield = tiled_op.getOperation()->getBlock()->getTerminator(); - rewriter.updateRootInPlace( - yield, [&]() { yield->insertOperands(1, cloned_insert); }); - } - - // Add an operation that combines the partial result with the output. 
- // - // Rewrite - // - // %init = tensor.empty - // %init_tile = tensor.empty - // %fill = linalg.fill(%cst, %init) - // linalg.tiled_loop outs(%fill, %init_tile) { - // %extract_output_slice = tensor.extract_slice %fill - // - // %slice_of_output_tile = tensor.extract_slice %init - // %fill_of_output_tile = linalg.fill(%cst, %slice_of_output_tile) - // %reduce = linalg.generic outs (%fill_of_output_tile) - // %update_output_tile = tensor.insert_slice %reduce into %init_tile - // - // %insert_output_slice = tensor.insert_slice %reduce into %fill - // linalg.yield %insert_output_slice, %update_output_tile - // } - // - // into - // - // %init = tensor.empty - // %init_tile = tensor.empty - // %fill = linalg.fill(%cst, %init) - // linalg.tiled_loop outs(%fill, %init_tile) { - // %extract_output_slice = tensor.extract_slice %fill - // - // %slice_of_output_tile = tensor.extract_slice %init - // %fill_of_output_tile = linalg.fill(%cst, %slice_of_output_tile) - // %reduce = linalg.generic outs (%fill_of_output_tile) - // %update_output_tile = tensor.insert_slice %reduce into %init_tile - // - //** %combine = linalg.generic ins (%reduce) outs (%extract_output_slice) - //** %insert_output_slice = tensor.insert_slice %combine into %fill - // - // linalg.yield %insert_output_slice, %update_output_tile - // } - LogicalResult CombineReducedTileWithOutput( - PatternRewriter &rewriter, LinalgOp tiled_op, Value partial_result, - ExtractSliceOp extract_output_slice, - InsertSliceOp insert_output_slice) const { - rewriter.setInsertionPointAfter(tiled_op); - auto num_parallel_loops = tiled_op.getNumParallelLoops(); - SmallVector parallel_iter_types( - num_parallel_loops, mlir::getParallelIteratorTypeName()); - auto id_map = rewriter.getMultiDimIdentityMap(num_parallel_loops); - - auto combiner_or = DetectCombiner(tiled_op); - if (failed(combiner_or)) return failure(); - Operation *combiner = combiner_or.getValue(); - - auto accumulator = rewriter.create( - tiled_op.getLoc(), partial_result.getType(), - makeArrayRef(partial_result), makeArrayRef((Value)extract_output_slice), - makeArrayRef({id_map, id_map}), parallel_iter_types, - [&](OpBuilder &b, Location nested_loc, ValueRange args) { - BlockAndValueMapping bvm; - bvm.map(combiner->getOperands(), args); - Value result_val = b.clone(*combiner, bvm)->getResult(0); - b.create(nested_loc, result_val); - }); - - rewriter.updateRootInPlace(insert_output_slice, [&]() { - insert_output_slice.getSourceMutable().assign(accumulator.getResult(0)); - }); - return success(); - } - - // Unfortunaly, there is no way to modify the results of the loop inplace. So - // we have to replace it with a clone. - LoopOp CreateLoopWithUpdatedResults(PatternRewriter &rewriter, - LoopOp tiled_loop) const { - auto loc = tiled_loop.getLoc(); - rewriter.setInsertionPoint(tiled_loop); - auto new_loop = rewriter.create( - loc, mlir::TypeRange(tiled_loop.getOutputs()), tiled_loop.getOperands(), - tiled_loop->getAttrs()); - rewriter.inlineRegionBefore(tiled_loop.getRegion(), new_loop.getRegion(), - new_loop.getRegion().begin()); - - rewriter.replaceOp(tiled_loop, new_loop.getResult(0)); - return new_loop; - } - - // Fuses FillOp producer of the output argument of the LoopOp and inserts - // an operation that accumulates the partial result, i.e. reduced tile, and - // the current value of the output tile. 
- LogicalResult RewriteTiledReduction(PatternRewriter &rewriter, - LoopOp tiled_loop, - LinalgOp tiled_op) const { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointAfter(tiled_op); - - // Find tiled loop output operand and the corresponding block argument. - mlir::OpOperand *loop_output_operand = - tiled_loop.findOutputOperand(tiled_loop.getOutputs().front()); - BlockArgument loop_output_bb_arg = - tiled_loop.getTiedBlockArgument(*loop_output_operand); - - // Find `linalg.fill` producer of the output. - auto fill = loop_output_operand->get().getDefiningOp(); - if (!fill) return failure(); - - // Find extract_slice/insert_slice pair used to RMW output. - auto extract_output_slice = - tiled_op.getDpsInitOperand(0)->get().getDefiningOp(); - if (!extract_output_slice) return failure(); - - Value tiled_op_result = tiled_op->getResult(0); - auto insert_output_slice = - dyn_cast(*tiled_op_result.getUsers().begin()); - if (!insert_output_slice) return failure(); - - // Fuse the output. - BlockArgument output_tile_bb_arg = - CloneAndAppendEmptyTensorToTiledLoop(rewriter, fill, tiled_loop); - FuseFill(rewriter, tiled_op, fill, loop_output_bb_arg, output_tile_bb_arg, - extract_output_slice, insert_output_slice); - // We have already modified the loop above, so we need to update the - // results. - CreateLoopWithUpdatedResults(rewriter, tiled_loop); - return CombineReducedTileWithOutput(rewriter, tiled_op, tiled_op_result, - extract_output_slice, - insert_output_slice); - } -}; - -struct FuseFillIntoTiledReductionPass - : public impl::FuseFillIntoTiledReductionBase< - FuseFillIntoTiledReductionPass> { - void runOnOperation() override { - auto func = getOperation(); - auto context = func.getContext(); - - mlir::RewritePatternSet patterns(context); - patterns.add(context); - (void)mlir::applyPatternsAndFoldGreedily(func, std::move(patterns)); - } -}; - -} // namespace - -std::unique_ptr> -CreateFuseFillIntoTiledReductionPass() { - return std::make_unique(); -} - -} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_fusion.cc b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_fusion.cc index b0f9525f792..0e3a24ee5c1 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_fusion.cc +++ b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_fusion.cc @@ -149,8 +149,7 @@ struct FusionPass : public impl::FusionBase { // Use TopDownTraversal for compile time reasons. mlir::GreedyRewriteConfig grc; grc.useTopDownTraversal = true; - (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns), - grc); + (void)applyPatternsAndFoldGreedily(op, std::move(patterns), grc); } }; diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_lower_vector_transpose.cc b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_lower_vector_transpose.cc deleted file mode 100644 index 91df9886ef0..00000000000 --- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_lower_vector_transpose.cc +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" -#include "mlir/Dialect/X86Vector/Transforms.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "mlir/Dialect/Utils/StructuredOpsUtils.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h" - -namespace tensorflow { -namespace { - -#define GEN_PASS_DEF_LOWERTRANSPOSE -#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h.inc" - -struct LowerTransposePass - : public impl::LowerTransposeBase { - void runOnOperation() override { - auto avx_lowering_options = - mlir::x86vector::avx2::LoweringOptions().setTransposeOptions( - mlir::x86vector::avx2::TransposeLoweringOptions() - .lower4x8xf32() - .lower8x8xf32()); - - mlir::func::FuncOp funcOp = getOperation(); - mlir::MLIRContext *context = funcOp.getContext(); - mlir::RewritePatternSet patterns(context); - mlir::vector::VectorTransformsOptions vectorTransformOptions; - vectorTransformOptions = vectorTransformOptions.setVectorTransposeLowering( - mlir::vector::VectorTransposeLowering::EltWise); - mlir::vector::populateVectorTransposeLoweringPatterns( - patterns, vectorTransformOptions); - mlir::x86vector::avx2::populateSpecializedTransposeLoweringPatterns( - patterns, avx_lowering_options, /*benefit=*/10); - - if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { - return signalPassFailure(); - } - } -}; - -} // namespace - -std::unique_ptr> -CreateLowerVectorTransposePass() { - return std::make_unique(); -} - -} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_math_approximation.cc b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_math_approximation.cc index dbd1678dff8..e0148059827 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_math_approximation.cc +++ b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_math_approximation.cc @@ -256,7 +256,7 @@ struct EigenExpM1Approximation : public OpRewritePattern { LogicalResult EigenExpM1Approximation::matchAndRewrite( math::ExpM1Op op, PatternRewriter &rewriter) const { auto shape = vectorShape(op.getOperand().getType(), isF32); - if (!shape.hasValue()) + if (!shape.has_value()) return rewriter.notifyMatchFailure(op, "unsupported operand type"); ImplicitLocOpBuilder builder(op->getLoc(), rewriter); diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.cc b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.cc index 9d975a363b8..f81159c9699 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.cc +++ b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.cc @@ -20,7 +20,6 @@ limitations under the License. 
 namespace tensorflow {
 
 using ::mlir::Operation;
-using ::mlir::linalg::LinalgOp;
 
 bool IsContiguousMemref(mlir::Value value) {
   auto memref_type = value.getType().dyn_cast<mlir::MemRefType>();
@@ -29,29 +28,4 @@ bool IsContiguousMemref(mlir::Value value) {
   return canonical_type.getLayout().isIdentity();
 }
 
-mlir::FailureOr<Operation *> DetectCombiner(LinalgOp linalg_op) {
-  mlir::SmallVector<Operation *, 4> combiners;
-  if (!matchReduction(linalg_op.getRegionOutputArgs(), 0, combiners) ||
-      combiners.size() != 1)
-    return mlir::failure();
-  return combiners.front();
-}
-
-constexpr llvm::StringLiteral kTransformMarker =
-    "__internal_transformation_marker__";
-
-void setTransformationAttr(mlir::OpBuilder &b, Operation *op) {
-  op->setAttr(kTransformMarker, b.getBoolAttr(true));
-}
-
-void removeTransformationAttr(Operation *op) {
-  op->removeAttr(kTransformMarker);
-}
-
-bool hasTransformationAttr(Operation *op) {
-  auto marker = op->getAttr(kTransformMarker);
-  if (!marker) return false;
-  return marker && marker.cast<mlir::BoolAttr>().getValue();
-}
-
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h
index 1b4016d5de8..ae7a7b8da17 100644
--- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h
+++ b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h
@@ -29,13 +29,10 @@ limitations under the License.
 #include "mlir/Pass/PassManager.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
-#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h"
-#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "tensorflow/compiler/xla/mlir_hlo/gml_st/IR/gml_st_ops.h"
+#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h"
 
 namespace tensorflow {
-#define GEN_PASS_DECL_TILEREDUCTION
-#define GEN_PASS_DECL_TILEFILL
-#define GEN_PASS_DECL_TILECWISE
 #define GEN_PASS_DECL_MATHAPPROXIMATION
 #define GEN_PASS_DECL_CLUSTERING
 #include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h.inc"
@@ -48,62 +45,18 @@ CreateLinalgTrivialBufferForwardingPass();
 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
 CreateLinalgTrivialCopyRemovalPass();
 
-// Pass to optimize padding in tiled loops by peeling the final loop iteration.
-std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
-CreatePeelTiledLoopsPass();
-
-// Pass to tile and fuse linalg.generic on tensors that models reduction.
-std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
-CreateTileReductionPass();
-std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
-CreateTileReductionPass(int64_t reduction_vector_size,
-                        int64_t reduction_1d_tile_size,
-                        llvm::ArrayRef<int64_t> reduction_2d_tile_sizes);
-
-// Pass to fuse `linalg.fill` into a tiled reduction.
-std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
-CreateFuseFillIntoTiledReductionPass();
-
 // Pass to replace 'i1' tensor types with 'i8' tensor types. This pass is a
 // temporary workaround to avoid the problem of vectorizing 'i1' tensors (see
 // b/205714705).
 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
 CreateJitRtLegalizeI1TypesPass();
 
-// Rewrite `vector.multi_reduction` into a sequence of `vector.reduction` ops.
-std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
-createRewriteVectorMultiReductionPass();
-
-// Code generation passes targeting transpose operations.
-std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
-CreateTileTransposePass();
-std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
-CreateLowerVectorTransposePass();
-
-// Pass to tile elementwise linalg.generic on tensors.
-std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> CreateTileCWisePass();
-std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> CreateTileCWisePass(
-    int64_t cwise_tile_size);
-
-// Pass to tile linalg.fill on tensors.
-std::unique_ptr> CreateTileFillPass(); -std::unique_ptr> CreateTileFillPass( - int64_t cwise_tile_size); - // Pass to split _Fused Tensorflow kernels into primitives. std::unique_ptr> CreateFissionPass(); // Pass to fuse Linalg generic operations on Tensors. std::unique_ptr> CreateFusionPass(); -// Pass to optimize broadcasts based on the symbolic shape constraints. -std::unique_ptr> -CreateSymbolicShapeOptimizationPass(bool constraints_only = false); - -// Pass to replace 0-d tensor inputs to LinalgOp with extracted elements. -std::unique_ptr> -CreateDetensorizeLinalgPass(); - // Creates `tf_device.cluster` operations according to the TF JitRt clustering // policy. std::unique_ptr> @@ -119,20 +72,6 @@ CreateMathApproximationPass(llvm::ArrayRef oplist = {}); // Returns true if the `value` type is a memref that is contiguous in memory. bool IsContiguousMemref(mlir::Value value); -// Detects the combiner in the body of LinalgOp if any. Currently, only -// ops with a single combiner are supported. -mlir::FailureOr DetectCombiner( - mlir::linalg::LinalgOp linalg_op); - -// Sets the attribute to the `op` that indicates that the op was transformed. -void setTransformationAttr(mlir::OpBuilder &b, mlir::Operation *op); - -// Removes the attribute that indicates that it was transformed. -void removeTransformationAttr(mlir::Operation *op); - -// Checks if `op` has the attribute that indicates that it was transformed. -bool hasTransformationAttr(mlir::Operation *op); - } // namespace tensorflow #define GEN_PASS_REGISTRATION diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.td b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.td index 391aab0eab8..5e2578bfb50 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.td +++ b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.td @@ -52,85 +52,16 @@ def LinalgTrivialCopyRemoval }]; } -def TileCWise : Pass<"tf-jitrt-tile-cwise", "mlir::func::FuncOp"> { - let summary = "Tile cwise linalg.generic on tensors."; - let constructor = "tensorflow::CreateTileCWisePass()"; - - let dependentDialects = [ - "mlir::gml_st::GmlStDialect", - "mlir::linalg::LinalgDialect" - ]; - - let options = [ - Option<"cwise_tile_size", "cwise-tile-size", - "int64_t", /*default=*/"8", - "Tile size for the innermost dimension of an elementwise op.">, - ]; -} - -def TileFill : Pass<"tf-jitrt-tile-fill", "mlir::func::FuncOp"> { - let summary = "Tile linalg.fill on tensors"; - let constructor = "tensorflow::CreateTileFillPass()"; - - let dependentDialects = [ - "mlir::gml_st::GmlStDialect", - "mlir::linalg::LinalgDialect" - ]; - - let options = [ - Option<"cwise_tile_size", "cwise-tile-size", - "int64_t", /*default=*/"8", - "Tile size for the innermost dimension of an elementwise op.">, - ]; -} - -def PeelTiledLoops : Pass<"tf-jitrt-peel-tiled-loops", "mlir::func::FuncOp"> { - let summary = "Optimize away padding in tiled loops"; - let constructor = "tensorflow::CreatePeelTiledLoopsPass()"; - let dependentDialects = [ - "mlir::gml_st::GmlStDialect", - "mlir::linalg::LinalgDialect" - ]; -} - -def TileReduction : Pass<"tf-jitrt-tile-reduction", "mlir::func::FuncOp"> { - let summary = "Tile and fuse linalg.generic reduction on tensors."; - let constructor = "tensorflow::CreateTileReductionPass()"; - let dependentDialects = [ - "mlir::gml_st::GmlStDialect", - "mlir::linalg::LinalgDialect", - "mlir::memref::MemRefDialect", - "mlir::scf::SCFDialect" - ]; - +def Fusion : Pass<"tf-jitrt-fusion", "mlir::func::FuncOp"> { + let 
summary = "Fuse Linalg generic operations on Tensors"; + let constructor = "tensorflow::CreateFusionPass()"; let description = [{ - Matches linalg.generic to understand whether it is a reduction or not. - After that performs tiling for vectorization and fusion of producers. + Fuse Linalg generic operations on Tensors using custom heuristics for + producer fusion profitability. }]; - - let options = [ - Option<"reduction_vector_size", "reduction-vector-size", - "int64_t", /*default=*/"8", "Vector size.">, - Option<"reduction_1d_tile_size", "reduction-1d-tile-size", - "int64_t", /*default=*/"32", "Tile size for a 1D reduction.">, - ListOption<"reduction_2d_tile_sizes", "reduction-2d-tile-sizes", "int64_t", - "Tile sizes for a 2D reduction.">, - ]; -} - -def FuseFillIntoTiledReduction - : Pass<"tf-jitrt-fuse-fill-into-tiled-reduction", "mlir::func::FuncOp"> { - let summary = "Fuse `linalg.fill` into `linalg.tiled_loop` with a reduction."; - let constructor = "tensorflow::CreateFuseFillIntoTiledReductionPass()"; let dependentDialects = [ - "mlir::gml_st::GmlStDialect", - "mlir::linalg::LinalgDialect" + "mlir::TF::TensorFlowDialect" ]; - - let description = [{ - Fuses `linalg.fill` producers of output tensor arguments into - `linalg.tiled_loop`. - }]; } def Fission : Pass<"tf-jitrt-fission", "mlir::func::FuncOp"> { @@ -141,18 +72,6 @@ def Fission : Pass<"tf-jitrt-fission", "mlir::func::FuncOp"> { ]; } -def Fusion : Pass<"tf-jitrt-fusion", "mlir::func::FuncOp"> { - let summary = "Fuse Linalg generic operations on Tensors"; - let constructor = "tensorflow::CreateFusionPass()"; - let description = [{ - Fuse Linalg generic operations on Tensors using custom heuristics for - producer fusion profitability. - }]; - let dependentDialects = [ - "mlir::TF::TensorFlowDialect" - ]; -} - def JitRtLegalizeI1Types : Pass<"tf-jitrt-legalize-i1-types", "mlir::ModuleOp"> { let summary = "Legalize 'i1' tensor types"; @@ -166,29 +85,6 @@ def JitRtLegalizeI1Types ]; } -def SymbolicShapeOptimization - : Pass<"tf-jitrt-symbolic-shape-optimization", "mlir::func::FuncOp"> { - let summary = "Optimizes broadcasts based on the symbolic shapes"; - let constructor = "tensorflow::CreateSymbolicShapeOptimizationPass()"; - let description = [{ - A simple pass that replaces shape constraints with const witnesses and - rewrites mhlo.broadcast_in_dim operations with linalg.generic broadcasts - using the symbolic shape attributes defined on the entrypoint function - arguments. 
- }]; - let dependentDialects = [ - "mlir::mhlo::MhloDialect", - "mlir::linalg::LinalgDialect" - ]; - - let options = [ - Option<"optimize_only_constraints", "optimize-only-constraints", - "bool", /*default=*/"false", - "Optimize only shape constraints and do not touch broadcasts.">, - - ]; -} - def Clustering : Pass<"tf-jitrt-clustering", "mlir::func::FuncOp"> { let summary = "Creates `tf_device.cluster` operations according to the TF " "JitRt clustering policy"; @@ -210,38 +106,6 @@ def Clustering : Pass<"tf-jitrt-clustering", "mlir::func::FuncOp"> { ]; } -def TileTranspose : Pass<"tf-jitrt-tile-transpose", "mlir::func::FuncOp"> { - let summary = "Tile transpose operations"; - let constructor = "tensorflow::CreateTileTransposePass()"; - let dependentDialects = [ - "mlir::gml_st::GmlStDialect", - "mlir::linalg::LinalgDialect" - ]; -} - -def LowerTranspose : Pass<"tf-jitrt-lower-vector-transpose", "mlir::func::FuncOp"> { - let summary = "Lower vector transpose operations"; - let constructor = "tensorflow::CreateLowerVectorTransposePass()"; - let dependentDialects = [ - "mlir::vector::VectorDialect", - "mlir::LLVM::LLVMDialect" - ]; -} - -def RewriteVectorMultiReductionPass : - Pass<"tf-jitrt-rewrite-vector-multi-reduction", "mlir::func::FuncOp"> { - let summary = "Convert `vector.multi_reduction` into `vector.reduction` ops."; - let constructor = "tensorflow::createRewriteVectorMultiReductionPass()"; - let dependentDialects = ["mlir::memref::MemRefDialect"]; -} - - -def DetensorizeLinalg : Pass<"tf-jitrt-detensorize-linalg", "mlir::func::FuncOp"> { - let summary = "Replace 0d tensor inputs to LinalgOp with extracted elements."; - let constructor = "tensorflow::CreateDetensorizeLinalgPass()"; - let dependentDialects = ["mlir::linalg::LinalgDialect"]; -} - def MathApproximation : Pass<"tf-jitrt-math-approximation", "mlir::func::FuncOp"> { let summary = "Approximate math operations with an implementation meant to " "match Eigen's results. This is a useful property to have when " diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_peel_tiled_loops.cc b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_peel_tiled_loops.cc deleted file mode 100644 index 99b8f729ff0..00000000000 --- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_peel_tiled_loops.cc +++ /dev/null @@ -1,102 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include - -#include "mlir-hlo/Dialect/gml_st/transforms/transforms.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/ADT/StringRef.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" -#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h" - -namespace tensorflow { -namespace { - -#define GEN_PASS_DEF_PEELTILEDLOOPS -#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h.inc" - -constexpr llvm::StringRef kWasPeeledAttr = "PeelStLoopsPeeledAttr"; - -using mlir::gml_st::ForOp; -using mlir::gml_st::LoopOp; -using mlir::gml_st::ParallelOp; - -template -struct PeelGmlStLoop : public mlir::OpRewritePattern { - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - LoopTy loop, mlir::PatternRewriter &rewriter) const override { - if (loop->hasAttr(kWasPeeledAttr)) return mlir::failure(); - auto true_attr = mlir::BoolAttr::get(rewriter.getContext(), true); - loop->setAttr(kWasPeeledAttr, true_attr); - for (int peeled_idx = loop.getNumLoops() - 1; peeled_idx >= 0; - peeled_idx--) { - LoopTy peel; - // Mark the new loop if one was created - if (mlir::gml_st::peelAndCanonicalizeGmlStLoop(rewriter, loop, peeled_idx, - peel) - .succeeded()) - peel->setAttr(kWasPeeledAttr, true_attr); - } - return mlir::success(); - } -}; - -struct PeelTiledLoopsPass - : public impl::PeelTiledLoopsBase { - void runOnOperation() override { - auto func_op = getOperation(); - - // Apply some canonicalizations before loop splitting confuses the - // situation. - // TODO(tpopp): See if this is still necessary in the integrated version. - mlir::RewritePatternSet canonicalizations(func_op.getContext()); - LoopOp::getCanonicalizationPatterns(canonicalizations, - func_op.getContext()); - ForOp::getCanonicalizationPatterns(canonicalizations, func_op.getContext()); - mlir::linalg::populateLinalgTilingCanonicalizationPatterns( - canonicalizations); - (void)applyPatternsAndFoldGreedily(func_op, std::move(canonicalizations)); - - mlir::RewritePatternSet loop_peeling(func_op.getContext()); - loop_peeling.add, PeelGmlStLoop, - PeelGmlStLoop>(func_op.getContext()); - (void)applyPatternsAndFoldGreedily(func_op, std::move(loop_peeling)); - - func_op->walk([&](LoopOp op) { - if (op->hasAttr(kWasPeeledAttr)) op->removeAttr(kWasPeeledAttr); - }); - func_op->walk([&](ParallelOp op) { - if (op->hasAttr(kWasPeeledAttr)) op->removeAttr(kWasPeeledAttr); - }); - func_op->walk([&](ForOp op) { - if (op->hasAttr(kWasPeeledAttr)) op->removeAttr(kWasPeeledAttr); - }); - } -}; - -} // namespace - -std::unique_ptr> -CreatePeelTiledLoopsPass() { - return std::make_unique(); -} -} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_symbolic_shape_optimization.cc b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_symbolic_shape_optimization.cc deleted file mode 100644 index e7c7eb169ac..00000000000 --- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_symbolic_shape_optimization.cc +++ /dev/null @@ -1,327 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include - -#include -#include - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/OperationSupport.h" -#include "mlir/IR/TypeRange.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/iterator_range.h" -#include "llvm/Support/Casting.h" -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Analysis/shape_component_analysis.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" - -namespace tensorflow { -namespace { - -using llvm::ArrayRef; -using llvm::SmallVector; - -using mlir::AffineExpr; -using mlir::AffineMap; -using mlir::failure; -using mlir::Location; -using mlir::LogicalResult; -using mlir::MLIRContext; -using mlir::OpBuilder; -using mlir::OperationPass; -using mlir::RankedTensorType; -using mlir::ShapeComponentAnalysis; -using mlir::success; -using mlir::TypeRange; -using mlir::Value; -using mlir::ValueRange; -using mlir::arith::ConstantIndexOp; -using mlir::arith::ConstantOp; -using mlir::arith::IndexCastOp; -using mlir::func::FuncOp; - -namespace linalg = mlir::linalg; -namespace mhlo = mlir::mhlo; -namespace shape = mlir::shape; -namespace tensor = mlir::tensor; - -#define GEN_PASS_DEF_SYMBOLICSHAPEOPTIMIZATION -#define GEN_PASS_DECL_SYMBOLICSHAPEOPTIMIZATION -#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h.inc" - -// -------------------------------------------------------------------------- // - - - - - - -// Replace shape.broadcast with a shape if it's statically known. -class BroadcastOpLowering final - : public mlir::OpRewritePattern { - public: - explicit BroadcastOpLowering(MLIRContext* ctx) : OpRewritePattern(ctx) {} - - LogicalResult matchAndRewrite(shape::BroadcastOp op, - mlir::PatternRewriter& rewriter) const override; -}; - -// Returns a shape tensor if the shapes can be broadcasted to a known shape. -// Will either return one of the shapes or a generated mix of the shapes. -llvm::Optional simplifyBroadcast(ShapeComponentAnalysis& analysis, - ValueRange shapes, Location loc, - OpBuilder* builder) { - // First find the input shape with the largest rank. 
- SmallVector> shapes_found; - size_t maxRank = 0; - for (const auto &shape : llvm::enumerate(shapes)) { - auto found_shape = analysis.GetValueInfo(shape.value()); - if (!found_shape) return {}; - shapes_found.push_back(*found_shape); - maxRank = std::max(maxRank, found_shape->size()); - } - if (maxRank == 0) { - return Value(builder->create( - loc, shapes[0].getType(), SmallVector())); - } - - SmallVector joined_dimensions( - maxRank); - SmallVector> shape_and_rank_for_dim(maxRank); - for (const auto &shape : llvm::enumerate(shapes_found)) { - for (const auto &dim : llvm::enumerate(llvm::reverse(shape.value()))) { - // 1 dimensions don't contribute to the final result. - if (dim.value().isConstant(1)) continue; - // If it's not a 1 dimension it will be present in the result. Remember - // where it came from. - auto index = maxRank - dim.index() - 1; - if (!joined_dimensions[index]) { - joined_dimensions[index] = &dim.value(); - shape_and_rank_for_dim[index] = - std::make_pair(shapes[shape.index()], shape.value().size()); - continue; - } - // Bail if the dimensions are neither equal nor 1. - if (*joined_dimensions[index] != dim.value()) return {}; - } - } - // If the output is the same as one of the inputs just return that. - if (llvm::all_equal(shape_and_rank_for_dim) && - shape_and_rank_for_dim[0].first) { - return shape_and_rank_for_dim[0].first; - } - // Otherwise rematerialize the shape from the pieces we have. - SmallVector elements; - for (int i = 0; i != maxRank; ++i) { - // 1 dimensions are filtered above, recreate the constant. - if (!shape_and_rank_for_dim[i].first) { - auto one = builder->getIntegerAttr( - shapes[0].getType().cast().getElementType(), 1); - elements.push_back(builder->create(loc, one)); - continue; - } - // Extract from one of the shapes, accounting for the reverse indexing - // performed by broadcast. - Value index = builder->create( - loc, i - maxRank + shape_and_rank_for_dim[i].second); - elements.push_back(builder->create( - loc, shape_and_rank_for_dim[i].first, index)); - } - return Value(builder->create(loc, elements)); -} - -LogicalResult BroadcastOpLowering::matchAndRewrite( - shape::BroadcastOp op, mlir::PatternRewriter& rewriter) const { - ShapeComponentAnalysis shape_component_analysis; - auto new_broadcast = simplifyBroadcast( - shape_component_analysis, op.getShapes(), op.getLoc(), &rewriter); - if (!new_broadcast) return failure(); - rewriter.replaceOp(op, {*new_broadcast}); - return success(); -} - -// -------------------------------------------------------------------------- // - -// Rewrite mhlo.dynamic_broadcast_in_dim operation into linalg.generic operation -// if can infer the indexing maps for the operand from the symbolic shapes. -class DynamicBroadcastInDimOpLowering - : public mlir::OpRewritePattern { - public: - using Base = OpRewritePattern; - - explicit DynamicBroadcastInDimOpLowering(MLIRContext* ctx); - - LogicalResult matchAndRewrite(mhlo::DynamicBroadcastInDimOp op, - mlir::PatternRewriter& rewriter) const override; -}; - -DynamicBroadcastInDimOpLowering::DynamicBroadcastInDimOpLowering( - MLIRContext* ctx) - : Base(ctx) {} - -// Check if broadcasting `from` to `to_shape` is statically known to only have -// dimensions that never expand or always expand. 
-llvm::Optional isNonExpandingBroadcast( - ShapeComponentAnalysis& analysis, Value from, Value to_shape) { - auto in_shape = analysis.GetShapeInfo(from); - auto out_shape = analysis.GetValueInfo(to_shape); - if (!in_shape || !out_shape) return {}; - - SmallVector input_map_exprs; - size_t rank = out_shape->size(); - MLIRContext* ctx = (*out_shape)[0].expr.getContext(); - size_t d = 0; - auto affine_zero = getAffineConstantExpr(0, ctx); - for (auto zip : - llvm::zip(llvm::reverse(*in_shape), llvm::reverse(*out_shape))) { - const auto& in = std::get<0>(zip); - const auto& out = std::get<1>(zip); - bool extend = in.isConstant(1) && !out.isConstant(1); - input_map_exprs.push_back(extend ? affine_zero - : getAffineDimExpr(rank - d - 1, ctx)); - ++d; - - // Bail if this is neither a known expansion nor a known non-expansion. - if (!extend && in != out) return {}; - } - // Any leading dimensions will be expanded. - input_map_exprs.resize(in_shape->size(), affine_zero); - std::reverse(input_map_exprs.begin(), input_map_exprs.end()); - return AffineMap::get(/*dimCount=*/rank, - /*symbolCount=*/0, input_map_exprs, ctx); -} - -LogicalResult DynamicBroadcastInDimOpLowering::matchAndRewrite( - mhlo::DynamicBroadcastInDimOp op, mlir::PatternRewriter& rewriter) const { - MLIRContext* ctx = getContext(); - - auto in_type = op.getOperand().getType().dyn_cast(); - auto out_type = op.getResult().getType().dyn_cast(); - if (!in_type || !out_type) return failure(); - - // Check that broadcast is right-aligned (numpy style), so that operand - // dimensions broadcasted to match inner-most dimensions of the output. - auto bcast_dims = op.getBroadcastDimensions().getValues(); - auto expected_bcast_dims = llvm::seq( - out_type.getRank() - in_type.getRank(), out_type.getRank()); - if (!llvm::equal(bcast_dims, expected_bcast_dims)) return failure(); - - ShapeComponentAnalysis shape_component_analysis; - auto input_map = isNonExpandingBroadcast( - shape_component_analysis, op.getOperand(), op.getOutputDimensions()); - if (!input_map) return failure(); - - // Resolve dynamic output dimensions for the `tensor.empty` operation. - SmallVector output_dyn_dimensions; - Location loc = op.getLoc(); - int64_t rank = out_type.getRank(); - for (size_t d = 0; d < rank; ++d) { - int64_t output_dim = out_type.getShape()[d]; - - // Skip static output dimensions, they will be resolved from the shape. - if (output_dim >= 0) continue; - - // Resolve the dynamic size of the output dimension. - Value output_dyn_dim = rewriter.create( - loc, op.getOutputDimensions(), - ValueRange{rewriter.create(loc, d)}); - - // Symbolic shape analysis might have given us an i32 or i64. Cast to index. - if (!output_dyn_dim.getType().isIndex()) - output_dyn_dim = rewriter.create( - loc, rewriter.getIndexType(), output_dyn_dim); - - output_dyn_dimensions.push_back(output_dyn_dim); - } - - // Create a tensor.empty operation to initialize output. - Value emptyTensor = rewriter.create( - loc, out_type.getShape(), out_type.getElementType(), - output_dyn_dimensions); - - // Output indexing map is an identity with `rank` number of loops. - AffineMap output_map = AffineMap::getMultiDimIdentityMap(rank, ctx); - - // All iterators are parallel. 
- SmallVector iterator_types(rank, "parallel"); - - rewriter.replaceOpWithNewOp( - op, /*resultTensorTypes=*/TypeRange{emptyTensor.getType()}, - /*inputs=*/ValueRange{op.getOperand()}, - /*outputs=*/ValueRange{emptyTensor}, - /*indexingMaps=*/llvm::makeArrayRef({*input_map, output_map}), - /*iteratorTypes=*/iterator_types, - [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) { - nested_builder.create(nested_loc, args[0]); - }); - - return success(); -} - -// -------------------------------------------------------------------------- // -// Optimize function based on the symbolic shape attributes. -// -------------------------------------------------------------------------- // - -struct SymbolicShapeOptimizationPass - : public impl::SymbolicShapeOptimizationBase< - SymbolicShapeOptimizationPass> { - SymbolicShapeOptimizationPass() = default; - - explicit SymbolicShapeOptimizationPass(bool constraints_only) { - this->optimize_only_constraints = constraints_only; - } - - void runOnOperation() override { - MLIRContext* ctx = &getContext(); - mlir::RewritePatternSet patterns(ctx); - - // Rewrite shape.broadcast based on the symbolic shapes. - patterns.add(ctx); - - // Rewrite broadcasts based on the symbolic shapes if enabled. - if (!optimize_only_constraints) - patterns.add(ctx); - - // Add shape dialect canonicalization patterns to fold shape operations - // after constraints are replaced with constant witness. - for (auto op : ctx->getRegisteredOperations()) { - if (llvm::isa(op.getDialect())) - op.getCanonicalizationPatterns(patterns, ctx); - } - - if (failed(mlir::applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { - return signalPassFailure(); - } - } -}; - -} // namespace - -std::unique_ptr> CreateSymbolicShapeOptimizationPass( - bool constraints_only) { - return std::make_unique(constraints_only); -} - -} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_tile_cwise.cc b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_tile_cwise.cc deleted file mode 100644 index bd99c578362..00000000000 --- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_tile_cwise.cc +++ /dev/null @@ -1,174 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include -#include - -#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h" -#include "mlir-hlo/Dialect/gml_st/transforms/transforms.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h" - -namespace tensorflow { -namespace { - -#define GEN_PASS_DEF_TILEFILL -#define GEN_PASS_DEF_TILECWISE -#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h.inc" - -using mlir::failure; -using mlir::LogicalResult; -using mlir::MLIRContext; -using mlir::OpBuilder; -using mlir::Operation; -using mlir::PatternRewriter; -using mlir::SmallVector; -using mlir::success; -using mlir::Value; -using mlir::arith::ConstantIndexOp; -using mlir::gml_st::LoopOp; -using mlir::linalg::FillOp; -using mlir::linalg::GenericOp; -using mlir::linalg::LinalgOp; -using mlir::linalg::LinalgTilingOptions; - -struct TileCWisePattern : public mlir::OpInterfaceRewritePattern { - TileCWisePattern(LinalgTilingOptions options, MLIRContext *context, - llvm::function_ref match_fn, - mlir::PatternBenefit benefit = 1) - : mlir::OpInterfaceRewritePattern(context, benefit), - match_fn(match_fn), - options(options) {} - - LogicalResult matchAndRewrite(LinalgOp linalg_op, - PatternRewriter &rewriter) const override { - if (hasTransformationAttr(linalg_op)) return failure(); - if (!match_fn(linalg_op)) return failure(); - - auto tiled_linalg_op = - mlir::gml_st::tileLinalgOp(rewriter, linalg_op, options); - if (failed(tiled_linalg_op) || tiled_linalg_op.getValue().loops.empty()) - return failure(); - - LoopOp tiled_loop = - mlir::dyn_cast(*tiled_linalg_op.getValue().loops.front()); - if (!tiled_loop) return failure(); - - tiled_loop->walk( - [&](LinalgOp tiledOp) { setTransformationAttr(rewriter, tiledOp); }); - - rewriter.replaceOp(linalg_op, tiled_loop->getResults()); - return success(); - } - - private: - llvm::function_ref match_fn; - LinalgTilingOptions options; -}; - -// Return true if the generic has only parallel iterations. This disallows -// windowed and reduction iteration. -bool isNonTiledCwiseGeneric(Operation *op) { - if (op->getParentOfType()) return false; - auto linalg_op = mlir::dyn_cast(op); - if (linalg_op) { - if (!linalg_op.hasTensorSemantics()) return false; - return llvm::all_of(linalg_op.getIteratorTypesArray(), - mlir::linalg::isParallelIterator); - } - if (auto fill_op = mlir::dyn_cast(op)) { - return fill_op.hasTensorSemantics(); - } - return false; -} - -// Return true if the generic has only parallel iterations. This disallows -// windowed and reduction iteration. -bool isNonTiledFill(Operation *op) { - if (op->getParentOfType()) return false; - if (auto fill_op = mlir::dyn_cast(op)) { - return fill_op.hasTensorSemantics(); - } - return false; -} - -void Tile(mlir::func::FuncOp func, int64_t tile_size, - llvm::function_ref match_fn) { - LinalgTilingOptions tiling_options; - // Tile the innermost dimension by `tile_size` for vectorization and scalarize - // the other dimensions. 
- tiling_options.setTileSizeComputationFunction( - [&](OpBuilder b, Operation *op) { - auto num_loops = llvm::cast(op).getNumLoops(); - SmallVector tiles(num_loops, - b.create(op->getLoc(), 1)); - if (!tiles.empty()) - tiles.back() = b.create(op->getLoc(), tile_size); - return tiles; - }); - - mlir::RewritePatternSet patterns(func.getContext()); - patterns.add(tiling_options, patterns.getContext(), - match_fn); - (void)mlir::applyPatternsAndFoldGreedily(func, std::move(patterns)); - - // Ensure we drop the marker in the end. - func.walk([](LinalgOp op) { removeTransformationAttr(op); }); -} - -struct TileCWisePass : public impl::TileCWiseBase { - TileCWisePass() = default; - explicit TileCWisePass(int64_t tile_size) { cwise_tile_size = tile_size; } - - void runOnOperation() override { - auto func = getOperation(); - Tile(func, cwise_tile_size, isNonTiledCwiseGeneric); - } -}; - -struct TileFillPass : public impl::TileFillBase { - TileFillPass() = default; - explicit TileFillPass(int64_t tile_size) { cwise_tile_size = tile_size; } - - void runOnOperation() override { - auto func = getOperation(); - Tile(func, cwise_tile_size, isNonTiledFill); - } -}; - -} // namespace - -std::unique_ptr> CreateTileCWisePass() { - return std::make_unique(); -} - -std::unique_ptr> CreateTileCWisePass( - int64_t cwise_tile_size) { - return std::make_unique(cwise_tile_size); -} - -std::unique_ptr> CreateTileFillPass() { - return std::make_unique(); -} - -std::unique_ptr> CreateTileFillPass( - int64_t cwise_tile_size) { - return std::make_unique(cwise_tile_size); -} - -} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_tile_reduction.cc b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_tile_reduction.cc deleted file mode 100644 index b65e8f1f48a..00000000000 --- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_tile_reduction.cc +++ /dev/null @@ -1,420 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include -#include -#include - -#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h" -#include "mlir-hlo/Dialect/gml_st/transforms/transforms.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tensor/Utils/Utils.h" -#include "mlir/Dialect/Utils/StaticValueUtils.h" -#include "mlir/Dialect/Utils/StructuredOpsUtils.h" -#include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" -#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h" - -namespace tensorflow { -namespace { - -#define GEN_PASS_DEF_TILEREDUCTION -#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h.inc" - -using llvm::makeArrayRef; -using mlir::BlockAndValueMapping; -using mlir::dyn_cast; -using mlir::failure; -using mlir::FailureOr; -using mlir::Location; -using mlir::LogicalResult; -using mlir::MLIRContext; -using mlir::OpBuilder; -using mlir::Operation; -using mlir::OpRewritePattern; -using mlir::PatternRewriter; -using mlir::RankedTensorType; -using mlir::ShapedType; -using mlir::SmallVector; -using mlir::success; -using mlir::Value; -using mlir::ValueRange; -using mlir::arith::ConstantIndexOp; -using mlir::gml_st::IteratorTypeAttr; -using mlir::gml_st::LoopOp; -using mlir::linalg::FillOp; -using mlir::linalg::GenericOp; -using mlir::linalg::LinalgOp; -using mlir::linalg::LinalgTilingOptions; -using mlir::tensor::EmptyOp; -using mlir::tensor::ExpandShapeOp; -using mlir::tensor::ExtractSliceOp; -using mlir::utils::IteratorType; - -// Match 1D or 2D reduction. -bool isCanonicalizedReduction(Operation *op) { - auto reduction = mlir::dyn_cast(op); - if (!reduction) return false; - - if (reduction.getNumDpsInits() != 1) return false; - if (reduction.getNumLoops() > 2) return false; - return reduction.getNumReductionLoops() == 1; -} - -// Tiles a GenericOp that models a 2D row or column reduction. -struct RowOrColumnReductionTilingPattern : public OpRewritePattern { - RowOrColumnReductionTilingPattern(const LinalgTilingOptions &options, - MLIRContext *context, - mlir::PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), options(options) {} - - LogicalResult matchAndRewrite(GenericOp linalg_op, - PatternRewriter &rewriter) const override { - if (hasTransformationAttr(linalg_op)) return failure(); - if (!isCanonicalizedReduction(linalg_op)) return failure(); - - if (linalg_op.getNumDpsInits() != 1) return failure(); - if (linalg_op.getNumLoops() != 2) return failure(); - - auto tiled_op = mlir::gml_st::tileLinalgOp(rewriter, linalg_op, options); - if (failed(tiled_op)) return failure(); - - tiled_op->loops.front()->walk( - [&](LinalgOp tOp) { setTransformationAttr(rewriter, tOp); }); - - rewriter.replaceOp(linalg_op, tiled_op->tensorResults); - return success(); - } - - private: - LinalgTilingOptions options; -}; - -// Rewrites a 1D reduction for vectorization. Matches `linalg.generic` that -// combines elements of tensor into tensor and then -// creates a perfectly-tilable loop to reduce tensor -> -// tensor and an additional `linalg.generic` that reduces -// tensor to tensor. 
-// -// Example: -// -// %sum = linalg.generic { -// indexing_maps = [affine_map<(d0) -> (d0)>, -// affine_map<(d0) -> ()>], -// iterator_types = ["reduction"]} -// ins(%input : tensor) -// outs(%fill : tensor) { -// ^bb0(%in: f32, %out: f32): -// %add = arith.addf %in, %out : f32 -// linalg.yield %add : f32 -// } -> tensor -// -// will be rewritten as -// -// %vector_result = gml_st.loop (%i) -// = (%c0) to (%TILABLE_UB) step (%vector_size) -// ins (%input_ = %input: tensor) -// outs (%tmp_result_ = %tmp_result: tensor) -// iterators["reduction"] { -// %tile = tensor.extract_slice %arg2[%i] [%TILE_SIZE] [1] -// : tensor to tensor -// %tile_reshape = tensor.expand_shape %tile [[0, 1]] -// : tensor into tensor<1xVECTOR_SIZExf32> -// %combine = linalg.generic ins(%tile_reshape : tensor<1xVECTOR_SIZExf32>) -// outs(%tmp_result_ : tensor) -> tensor -// linalg.yield %combine : tensor -// } -// %horizontal_reduce = linalg.generic -// ins(%vector_result : tensor) -// outs(%fill : tensor) -> tensor // combiner only -// %result = gml_st.loop (%i) -// = (%TILABLE_UB) to (%INPUT_SIZE) step (%vector_size) -// ins (%input_ = %input: tensor) -// outs (%tmp_result_ = %horizontal_reduce: tensor) -// iterators["reduction"] { -// linalg.generic // reduces the tail -// } -// -// This is necessary to push horizontal reduction to the later stage. -struct OneDimReductionTilingPattern : public OpRewritePattern { - OneDimReductionTilingPattern(int64_t vector_size, int64_t tile_size, - mlir::MLIRContext *context, - mlir::PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), - vector_size(vector_size), - tile_size(tile_size) {} - - LogicalResult matchAndRewrite(GenericOp linalg_op, - PatternRewriter &rewriter) const override { - if (hasTransformationAttr(linalg_op)) return failure(); - if (!isCanonicalizedReduction(linalg_op)) return failure(); - - // Check if all inputs have a 1D identity map. - if (linalg_op.getNumLoops() != 1) return failure(); - auto indexing_maps = linalg_op.getIndexingMapsArray(); - for (auto affine_map : makeArrayRef(indexing_maps).drop_back()) { - if (!affine_map.isIdentity()) return failure(); - } - - Location loc = linalg_op.getLoc(); - Value input = linalg_op.getDpsInputOperand(0)->get(); - // All inputs have the same size because of identity maps for indexing. - SmallVector inputs = linalg_op.getInputs(); - Value input_size = rewriter.create(loc, input, 0); - - auto fill_op = linalg_op.getOutputs().front().getDefiningOp(); - auto empty_op = fill_op.output().getDefiningOp(); - - auto neutral_value = fill_op.value(); - auto element_type = empty_op.getType().getElementType(); - - Value zero = rewriter.create(loc, 0); - Value tile_size_value = rewriter.create(loc, tile_size); - Value new_empty = - rewriter.create(loc, vector_size, element_type); - Value new_fill = - rewriter.create(loc, fill_op.value(), new_empty).result(); - - llvm::Optional tilable_bound_or = - getTilableBound(rewriter, loc, zero, input_size, tile_size_value); - Value tilable_bound = - tilable_bound_or.has_value() ? 
*tilable_bound_or : input_size; - - GenericOp tiled_reduction; - auto perfectly_tiled_loop = rewriter.create( - loc, makeArrayRef(zero), makeArrayRef(tilable_bound), - makeArrayRef(tile_size_value), inputs, makeArrayRef(new_fill), - rewriter.getArrayAttr({IteratorTypeAttr::get(rewriter.getContext(), - IteratorType::reduction)}), - [&](OpBuilder &b, Location nested_loc, ValueRange ivs, - ValueRange inputs, ValueRange outputs) { - SmallVector reshaped_tiled_inputs = - TileAndReshapeInputTensors(b, nested_loc, ivs, inputs, - neutral_value, input_size, - tile_size_value); - // Create `linalg.generic` to combine - // `tensor<(TILE_SIZE/VECTOR_SIZE)xVECTOR_SIZExELEM_TYPE> input with - // the `tensor` output. - SmallVector iter_types{ - mlir::getReductionIteratorTypeName(), - mlir::getParallelIteratorTypeName()}; - SmallVector indexing_maps( - inputs.size(), rewriter.getMultiDimIdentityMap(2)); - indexing_maps.push_back( - mlir::AffineMap::get(2, 0, b.getAffineDimExpr(1))); - tiled_reduction = b.create( - nested_loc, outputs[0].getType(), reshaped_tiled_inputs, - makeArrayRef({outputs[0]}), indexing_maps, iter_types, - /*bodyBuild=*/nullptr); - mlir::Region ®ion = tiled_reduction.getRegion(); - OpBuilder::InsertionGuard g(rewriter); - rewriter.cloneRegionBefore(linalg_op.getRegion(), region, - region.end()); - b.create(nested_loc, - tiled_reduction.getResult(0)); - }); - // Create `linalg.generic` to reduce - // tensor->tensor. - auto horizontal_reduction_or = ReduceVectorIntoOutput( - rewriter, linalg_op, perfectly_tiled_loop.getResult(0)); - if (failed(horizontal_reduction_or)) return failure(); - auto horizontal_reduction = horizontal_reduction_or.getValue(); - Value result = horizontal_reduction->getResult(0); - - // If the loop was not perfectly tiled, then we have to combine - // `horizontal_reduction` with the elements in the `tail`. - if (tilable_bound_or.has_value()) { - auto final_reduction = rewriter.create( - loc, tilable_bound, input_size, tile_size_value, inputs, - makeArrayRef(result), - rewriter.getArrayAttr({IteratorTypeAttr::get( - rewriter.getContext(), IteratorType::reduction)}), - [&](OpBuilder &b, Location nested_loc, ValueRange ivs, - ValueRange inputs, ValueRange outputs) { - BlockAndValueMapping bvm; - mlir::AffineExpr sym0, sym1; - bindSymbols(b.getContext(), sym0, sym1); - auto diff_map = mlir::AffineMap::get(0, 2, {sym1 - sym0}); - - Value one = b.create(nested_loc, 1); - auto size = b.createOrFold( - nested_loc, diff_map, ValueRange{tilable_bound, input_size}); - std::vector sliced_inputs; - sliced_inputs.reserve(inputs.size()); - for (Value input : inputs) { - sliced_inputs.push_back( - b.create(nested_loc, input, ivs, size, one)); - } - bvm.map(linalg_op.getInputs(), sliced_inputs); - bvm.map(linalg_op.getOutputs(), outputs); - auto new_linalg_op = b.clone(*linalg_op.getOperation(), bvm); - setTransformationAttr(b, new_linalg_op); - b.create(nested_loc, - new_linalg_op->getResult(0)); - }); - result = final_reduction.getResult(0); - } - rewriter.replaceOp(linalg_op, result); - - perfectly_tiled_loop->walk( - [&](GenericOp op) { setTransformationAttr(rewriter, op); }); - setTransformationAttr(rewriter, horizontal_reduction); - return success(); - } - - private: - // Computes an upper bound that can be perfectly tiled. Return llvm::None, if - // the loop is already perfectly tiled. 
- mlir::Optional getTilableBound(OpBuilder &b, Location loc, Value lb, - Value ub, Value step) const { - auto lb_int = getConstantIntValue(lb); - auto ub_int = getConstantIntValue(ub); - auto step_int = getConstantIntValue(step); - - // No specialization necessary if step already divides upper bound evenly. - if (lb_int && ub_int && step_int && (*ub_int - *lb_int) % *step_int == 0) - return llvm::None; - // No specialization necessary if step size is 1. - if (mlir::isConstantIntValue(step, 1)) return llvm::None; - mlir::AffineExpr sym0, sym1, sym2; - bindSymbols(b.getContext(), sym0, sym1, sym2); - - // New upper bound: %ub - (%ub - %lb) mod %step - auto mod_map = mlir::AffineMap::get(0, 3, {sym1 - ((sym1 - sym0) % sym2)}); - return {b.createOrFold(loc, mod_map, - ValueRange{lb, ub, step})}; - } - - // Tiles, pads and reshapes every input argument of type tensor - // into tensor<(TILE_SIZE/VECTOR_SIZE)xVECTOR_SIZExELEM_TYPE>. - SmallVector TileAndReshapeInputTensors( - OpBuilder &b, Location nested_loc, ValueRange ivs, ValueRange inputs, - Value neutral_value, Value input_size, Value tile_size_value) const { - SmallVector reshaped_tiled_inputs; - - SmallVector indices = {{0, 1}}; - auto identity_1d_map = b.getMultiDimIdentityMap(1); - auto iv = ivs.front(); - - mlir::OpFoldResult tile_size_fold = tile_size_value; - mlir::OpFoldResult input_size_fold = input_size; - auto tile_sizes = mlir::linalg::computeTileSizes( - b, nested_loc, tile_size_fold, input_size_fold); - for (auto input : inputs) { - // Extract slice of input. - Value slice = mlir::linalg::makeTiledShape( - b, nested_loc, input, tile_size_fold, identity_1d_map, - mlir::OpFoldResult(iv), input_size_fold, tile_sizes, - /*omitPartialTileCheck=*/true); - auto element_type = slice.getType().cast().getElementType(); - - // Reshape input tile to - // tensor<(TILE_SIZE/VECTOR_SIZE)xVECTOR_SIZExELEM_TYPE>. - Value expand_shape = b.create( - nested_loc, - RankedTensorType::get({tile_size / vector_size, vector_size}, - element_type), - slice, indices); - reshaped_tiled_inputs.push_back(expand_shape); - } - return reshaped_tiled_inputs; - } - - // Creates `linalg.generic` to reduce - // tensor->tensor. To perform that we match - // the combiner in the original "untiled" linalg_op. 
- FailureOr ReduceVectorIntoOutput(PatternRewriter &rewriter, - LinalgOp linalg_op, - Value partial_result) const { - SmallVector reduction_iter_type( - 1, mlir::getReductionIteratorTypeName()); - auto map = mlir::AffineMap::get(1, 0, llvm::None, rewriter.getContext()); - - auto combiner_or = DetectCombiner(linalg_op); - if (failed(combiner_or)) return failure(); - Operation *combiner = combiner_or.getValue(); - - auto accumulator = rewriter.create( - linalg_op.getLoc(), linalg_op->getResultTypes(), - makeArrayRef(partial_result), - makeArrayRef(linalg_op.getDpsInitOperand(0)->get()), - makeArrayRef({rewriter.getMultiDimIdentityMap(1), map}), - reduction_iter_type, - [&](OpBuilder &b, Location nested_loc, ValueRange args) { - BlockAndValueMapping bvm; - bvm.map(combiner->getOperands(), args); - Value result_val = b.clone(*combiner, bvm)->getResult(0); - b.create(nested_loc, result_val); - }); - return accumulator; - } - - private: - int64_t vector_size; - int64_t tile_size; -}; - -struct TileReductionPass : public impl::TileReductionBase { - TileReductionPass() = default; - TileReductionPass(int64_t vector_size, int64_t reduction_1d_tile, - llvm::ArrayRef reduction_2d_tiles) { - reduction_vector_size = vector_size; - reduction_1d_tile_size = reduction_1d_tile; - reduction_2d_tile_sizes = reduction_2d_tiles; - } - void runOnOperation() override { - auto func = getOperation(); - auto context = func.getContext(); - - assert(reduction_1d_tile_size % reduction_vector_size == 0 && - "Tile size for 1D reduction should be a multiple of vector size"); - auto patterns = - mlir::linalg::getLinalgTilingCanonicalizationPatterns(context); - patterns.add( - reduction_vector_size, reduction_1d_tile_size, patterns.getContext()); - - assert(reduction_2d_tile_sizes.size() == 2 && - "Tiling sizes for 2D reductions should have two elements"); - patterns.add( - LinalgTilingOptions{}.setTileSizes(reduction_2d_tile_sizes), - patterns.getContext()); - (void)mlir::applyPatternsAndFoldGreedily(func, std::move(patterns)); - - // Ensure we drop the marker in the end. - func.walk([](LinalgOp op) { removeTransformationAttr(op); }); - } -}; - -} // namespace - -std::unique_ptr> -CreateTileReductionPass() { - return std::make_unique(); -} - -std::unique_ptr> -CreateTileReductionPass(int64_t reduction_vector_size, - int64_t reduction_1d_tile_size, - llvm::ArrayRef reduction_2d_tile_sizes) { - return std::make_unique( - reduction_vector_size, reduction_1d_tile_size, reduction_2d_tile_sizes); -} - -} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_tile_transpose.cc b/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_tile_transpose.cc deleted file mode 100644 index 72866dd1170..00000000000 --- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_tile_transpose.cc +++ /dev/null @@ -1,179 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
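For reference, the overall rewrite that the deleted OneDimReductionTilingPattern performed (per its own comment above) can be mimicked in NumPy: reduce the perfectly tilable prefix tile by tile into a vector-wide accumulator, horizontally reduce that accumulator, then fold in the tail. The tile and vector sizes and the use of addition as the combiner are illustrative assumptions.

```python
import numpy as np

def tiled_sum_1d(x, tile_size=16, vector_size=8):
    # Rough NumPy analogue of the deleted 1-D reduction tiling pattern.
    assert tile_size % vector_size == 0
    n = x.shape[0]
    tilable_ub = n - (n % tile_size)            # perfectly tilable bound
    acc = np.zeros(vector_size, dtype=x.dtype)  # stands in for the new fill value
    for i in range(0, tilable_ub, tile_size):
        # Reshape the tile to (TILE_SIZE/VECTOR_SIZE, VECTOR_SIZE) and combine
        # along the reduction dimension, as the inner linalg.generic does.
        tile = x[i:i + tile_size].reshape(-1, vector_size)
        acc += tile.sum(axis=0)
    result = acc.sum()                          # horizontal reduction
    result += x[tilable_ub:].sum()              # tail loop over the remainder
    return result

x = np.arange(37, dtype=np.float32)
assert np.isclose(tiled_sum_1d(x), x.sum())
```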
-==============================================================================*/ - -#include -#include -#include -#include - -#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h" -#include "mlir-hlo/Dialect/gml_st/transforms/transforms.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project -#include "mlir/Dialect/Utils/StructuredOpsUtils.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h" - -namespace tensorflow { -namespace { - -#define GEN_PASS_DEF_TILETRANSPOSE -#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h.inc" - -using llvm::SmallVector; -using mlir::Attribute; -using mlir::dyn_cast; -using mlir::failure; -using mlir::MLIRContext; -using mlir::Operation; -using mlir::PatternRewriter; -using mlir::success; -using mlir::Value; -using mlir::arith::ConstantIndexOp; -using mlir::gml_st::LoopOp; -using mlir::linalg::GenericOp; -using mlir::linalg::LinalgTilingOptions; - -/// Returns true if the operation is a GenericOp implementing a transposition. -// TODO(diegocaballero): Move it to MLIR core? -bool IsTransposeGenericOp(Operation *op) { - // Check that op is a generic op and has at least 2 dimensions. - auto generic_op = dyn_cast(op); - if (!generic_op) return false; - if (generic_op.getNumLoops() < 2) return false; - - // Check whether the body has only one operation (yield op). Transpose ops - // fused with any other operations are not supported for now. - mlir::Block *body = generic_op.getBody(); - if (body->empty() || body->begin() != std::prev(body->end())) return false; - auto yield_op = dyn_cast(body->back()); - if (!yield_op || (yield_op.getNumOperands() != 1)) return false; - - // Check input and output. - if ((generic_op.getNumDpsInputs() != 1) || (generic_op.getNumDpsInits() != 1)) - return false; - - // Check that input is yielded. - if (generic_op.getMatchingBlockArgument(generic_op.getDpsInputOperand(0)) != - yield_op.getOperand(0)) - return false; - - // Check parallel iterators. - auto iterator_types = generic_op.getIteratorTypesArray(); - if (std::any_of(iterator_types.begin(), iterator_types.end(), - [](auto iterator_type) { - return !mlir::linalg::isParallelIterator(iterator_type); - })) - return false; - - // Check that the two indexing maps are a permutation. 
- auto indexing_maps = generic_op.getIndexingMapsArray(); - if (indexing_maps.size() != 2) return false; - return (indexing_maps[0].isIdentity() && indexing_maps[1].isPermutation()) || - (indexing_maps[0].isPermutation() && indexing_maps[1].isIdentity()); -} - -struct TileTransposePattern : public mlir::OpRewritePattern { - TileTransposePattern(LinalgTilingOptions options, MLIRContext *context, - mlir::PatternBenefit benefit = 1) - : mlir::OpRewritePattern(context, benefit), options(options) {} - - mlir::LogicalResult matchAndRewrite( - GenericOp linalg_op, PatternRewriter &rewriter) const override { - if (hasTransformationAttr(linalg_op)) return failure(); - if (!IsTransposeGenericOp(linalg_op)) return failure(); - - auto tiled_linalg_op = - mlir::gml_st::tileLinalgOp(rewriter, linalg_op, options); - if (failed(tiled_linalg_op) || tiled_linalg_op.getValue().loops.empty()) - return failure(); - - auto tiled_loop = - mlir::dyn_cast(*tiled_linalg_op.getValue().loops.front()); - if (!tiled_loop) return failure(); - - tiled_loop->walk( - [&](GenericOp tiledOp) { setTransformationAttr(rewriter, tiledOp); }); - - rewriter.replaceOp(linalg_op, tiled_loop->getResults()); - return success(); - } - - private: - LinalgTilingOptions options; -}; - -struct TileTransposePass : public impl::TileTransposeBase { - void runOnOperation() override { - auto get_tile_size = [&](mlir::OpBuilder b, Operation *op) { - auto generic_op = llvm::cast(op); - unsigned num_loops = generic_op.getNumLoops(); - assert(num_loops >= 2 && "Expect two or more dimension in transpose op"); - - // Compute the tile sizes for the 2-D vectorization of the transpose. We - // pick eight as default vectorization factor for both dimensions since - // it's the most performant AVX2 pattern for now. We pick the contiguous - // dimension of the input as first vector dimension and the contiguous - // dimension of the output as second vector dimension. This will maximize - // contiguous vector loads/stores and minimize insert/extract/gather/ - // scatter operations. - SmallVector tiles(num_loops, - b.create(op->getLoc(), 1)); - auto indexing_maps = generic_op.getIndexingMapsArray(); - unsigned last_dim = num_loops - 1; - unsigned vec_factor0 = 8, vec_factor1 = 8; - unsigned vec_dim0 = indexing_maps[0].getDimPosition(last_dim); - unsigned vec_dim1 = indexing_maps[1].getDimPosition(last_dim); - - // If the contiguous dimensions of both input and output are not - // transposed (i.e, they are the same), we vectorize only that dimension. - // That transpose case doesn't require intra-register transposition but - // just copying a set of contiguous sub-buffers from the input to the - // output tensor. Vectorizing a second dimension would increase too much - // the memory pressure for no reason. - if (vec_dim0 == vec_dim1) { - tiles[vec_dim0] = b.create(op->getLoc(), vec_factor0); - } else { - tiles[vec_dim0] = b.create(op->getLoc(), vec_factor0); - tiles[vec_dim1] = b.create(op->getLoc(), vec_factor1); - } - - return tiles; - }; - - auto func = getOperation(); - auto tiling_options = - LinalgTilingOptions().setTileSizeComputationFunction(get_tile_size); - - mlir::RewritePatternSet patterns(func.getContext()); - patterns.add(tiling_options, patterns.getContext()); - if (failed(mlir::applyPatternsAndFoldGreedily(func, std::move(patterns)))) { - signalPassFailure(); - } - - // Ensure we drop the marker in the end. 
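A compact Python sketch of the two decisions the deleted transpose tiling code made: recognizing a transpose from its pair of indexing maps, and picking tile sizes that vectorize the contiguous dimension of both the input and the output. The factor of 8 matches the AVX2-oriented default named in the comment; the tuple encoding of indexing maps is an assumption for illustration.

```python
def is_transpose_maps(in_map, out_map):
    # One map must be the identity and the other a permutation, matching the
    # check in the deleted IsTransposeGenericOp.
    def is_identity(m):
        return m == tuple(range(len(m)))
    def is_permutation(m):
        return sorted(m) == list(range(len(m)))
    return (is_identity(in_map) and is_permutation(out_map)) or \
           (is_permutation(in_map) and is_identity(out_map))

def transpose_tile_sizes(num_loops, in_map, out_map, vec_factor=8):
    # Vectorize the contiguous (last) dimension of input and output; when they
    # coincide the transpose copies whole rows and one vector dimension suffices.
    tiles = [1] * num_loops
    tiles[in_map[-1]] = vec_factor
    if out_map[-1] != in_map[-1]:
        tiles[out_map[-1]] = vec_factor
    return tiles

assert is_transpose_maps((0, 1), (1, 0))
assert transpose_tile_sizes(2, (0, 1), (1, 0)) == [8, 8]           # 2-D: 8x8 tiles
assert transpose_tile_sizes(3, (0, 1, 2), (1, 0, 2)) == [1, 1, 8]  # "memcopy" case
```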
- func.walk([](GenericOp op) { removeTransformationAttr(op); }); - } -}; - -} // namespace - -std::unique_ptr> -CreateTileTransposePass() { - return std::make_unique(); -} - -} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/lhlo-tfrt-opt.cc b/tensorflow/compiler/mlir/tfrt/lhlo-tfrt-opt.cc index 3c61d83f6b9..7af989f9cd3 100644 --- a/tensorflow/compiler/mlir/tfrt/lhlo-tfrt-opt.cc +++ b/tensorflow/compiler/mlir/tfrt/lhlo-tfrt-opt.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h" -#include "mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h" +#include "lhlo/IR/lhlo_ops.h" +#include "lhlo_gpu/IR/lhlo_gpu_ops.h" #include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/InitAllDialects.h" // from @llvm-project #include "mlir/InitAllPasses.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tfrt/python_tests/BUILD b/tensorflow/compiler/mlir/tfrt/python_tests/BUILD index 365e9e360f1..3413657181d 100644 --- a/tensorflow/compiler/mlir/tfrt/python_tests/BUILD +++ b/tensorflow/compiler/mlir/tfrt/python_tests/BUILD @@ -1,6 +1,8 @@ load("//tensorflow:strict.default.bzl", "py_strict_test") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) + licenses(["notice"]) py_strict_test( @@ -309,6 +311,30 @@ py_strict_test( ], ) +py_strict_test( + name = "tf_reverse_test", + srcs = ["tf_reverse_test.py"], + python_version = "PY3", + tags = ["no_pip"], # TODO(b/201803253): TFRT pybindings not in OSS. + deps = [ + "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt", + "//tensorflow/python:client_testlib", + "//third_party/py/numpy", + ], +) + +py_strict_test( + name = "tf_scatter_test", + srcs = ["tf_scatter_test.py"], + python_version = "PY3", + tags = ["no_pip"], # TODO(b/201803253): TFRT pybindings not in OSS. 
+ deps = [ + "//tensorflow/compiler/mlir/tfrt/jit/python_binding:tf_jitrt", + "//tensorflow/python:client_testlib", + "//third_party/py/numpy", + ], +) + td_library( name = "python_test_attrs_td_files", srcs = ["python_test_attrs.td"], diff --git a/tensorflow/compiler/mlir/tfrt/python_tests/python_test_attrs.td b/tensorflow/compiler/mlir/tfrt/python_tests/python_test_attrs.td index 0b575a07709..4b40d2c7865 100644 --- a/tensorflow/compiler/mlir/tfrt/python_tests/python_test_attrs.td +++ b/tensorflow/compiler/mlir/tfrt/python_tests/python_test_attrs.td @@ -36,8 +36,7 @@ def PythonTestAttrsDialect : Dialect { return (getDialectNamespace() + ".shape_value").str(); } }]; - - let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; + let useFoldAPI = kEmitFoldAdaptorFolder; } #endif // PYTHON_TEST_ATTRS diff --git a/tensorflow/compiler/mlir/tfrt/python_tests/regression_tests/broadcasting_25.mlir b/tensorflow/compiler/mlir/tfrt/python_tests/regression_tests/broadcasting_25.mlir new file mode 100644 index 00000000000..308dfc65d28 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/python_tests/regression_tests/broadcasting_25.mlir @@ -0,0 +1,15 @@ +func.func @test( + %V__0: tensor { python_test_attrs.static_type = tensor<1xf32> }, + %V__1: tensor { python_test_attrs.static_type = tensor<1xf32> }, + %V__2: tensor { python_test_attrs.static_type = tensor<1xf32> } + ) -> tensor { + %1 = "tf.AddV2"(%V__0, %V__0) { device = "/job:localhost/replica:0/task:0/device:CPU:0" } + : (tensor, tensor) -> tensor + %2 = "tf.Rint"(%V__1) { device = "/job:localhost/replica:0/task:0/device:CPU:0" } + : (tensor) -> tensor + %3 = "tf.Exp"(%2) { device = "/job:localhost/replica:0/task:0/device:CPU:0" } + : (tensor) -> tensor + %4 = "tf.AddV2"(%V__2, %3) { device = "/job:localhost/replica:0/task:0/device:CPU:0" } + : (tensor, tensor) -> tensor + func.return %4 : tensor +} diff --git a/tensorflow/compiler/mlir/tfrt/python_tests/tf_matmul_test.py b/tensorflow/compiler/mlir/tfrt/python_tests/tf_matmul_test.py index 81a16bcf45e..f6c684fab5f 100644 --- a/tensorflow/compiler/mlir/tfrt/python_tests/tf_matmul_test.py +++ b/tensorflow/compiler/mlir/tfrt/python_tests/tf_matmul_test.py @@ -47,14 +47,16 @@ class TfMatMulTest(test.TestCase): # Matmul: [1, k] x [k, 1] def test_dot_product(self): - compiled = jitrt.compile(matmul(), "matmul", vectorize=True) + compiled = jitrt.compile( + matmul(), "matmul", vectorize=True, enable_xla_cpu_transformations=True) for _ in range(100): k = np.random.randint(1, 10) verify_matmul(compiled, 1, k, 1) # Matmul: [1, k] x [k, n] def test_vec_mat(self): - compiled = jitrt.compile(matmul(), "matmul", vectorize=True) + compiled = jitrt.compile( + matmul(), "matmul", vectorize=True, enable_xla_cpu_transformations=True) for _ in range(100): k = np.random.randint(1, 10) n = np.random.randint(1, 10) @@ -62,7 +64,8 @@ def test_vec_mat(self): # Matmul: [n, k] x [k, 1] def test_mat_vec(self): - compiled = jitrt.compile(matmul(), "matmul", vectorize=True) + compiled = jitrt.compile( + matmul(), "matmul", vectorize=True, enable_xla_cpu_transformations=True) for _ in range(100): m = np.random.randint(1, 10) k = np.random.randint(1, 10) @@ -70,7 +73,8 @@ def test_mat_vec(self): # Matmul: [m, k] x [k, n] def test_matmul(self): - compiled = jitrt.compile(matmul(), "matmul", vectorize=True) + compiled = jitrt.compile( + matmul(), "matmul", vectorize=True, enable_xla_cpu_transformations=True) for _ in range(100): m = np.random.randint(1, 10) k = np.random.randint(1, 10) diff --git 
a/tensorflow/compiler/mlir/tfrt/python_tests/tf_reduction_test.py b/tensorflow/compiler/mlir/tfrt/python_tests/tf_reduction_test.py index c7f43ed6975..1c8a1bd6b6d 100644 --- a/tensorflow/compiler/mlir/tfrt/python_tests/tf_reduction_test.py +++ b/tensorflow/compiler/mlir/tfrt/python_tests/tf_reduction_test.py @@ -175,7 +175,7 @@ def test_2d_row_any(self): compiled = jitrt.compile( mlir_function, 'test', vectorize=True, legalize_i1_tensors=True) - arg0 = np.random.choice(a=[False, True], size=(8, 10)).astype(np.bool) + arg0 = np.random.choice(a=[False, True], size=(8, 10)).astype(bool) [res] = jitrt.execute(compiled, [arg0]) np.testing.assert_equal(res, np.any(arg0, axis=1)) @@ -193,7 +193,7 @@ def test_2d_row_all(self): compiled = jitrt.compile( mlir_function, 'test', vectorize=True, legalize_i1_tensors=True) - arg0 = np.random.choice(a=[False, True], size=(40, 2)).astype(np.bool) + arg0 = np.random.choice(a=[False, True], size=(40, 2)).astype(bool) [res] = jitrt.execute(compiled, [arg0]) np.testing.assert_equal(res, np.all(arg0, axis=1)) diff --git a/tensorflow/compiler/mlir/tfrt/python_tests/tf_reverse_test.py b/tensorflow/compiler/mlir/tfrt/python_tests/tf_reverse_test.py new file mode 100644 index 00000000000..9990a139cde --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/python_tests/tf_reverse_test.py @@ -0,0 +1,114 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
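The two reduction-test changes above replace `np.bool` with the builtin `bool`: NumPy deprecated the `np.bool` alias in 1.20 and removed it in 1.24, so casting with `bool` (or `np.bool_`) keeps the tests running on current NumPy. A minimal check:

```python
import numpy as np

arg0 = np.random.choice(a=[False, True], size=(8, 10)).astype(bool)
assert arg0.dtype == np.bool_  # same dtype the old np.bool spelling produced
```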
+# ============================================================================== +"""Tests for Tensorflow -> jitrt compilation.""" + +import numpy as np + +from tensorflow.compiler.mlir.tfrt.jit.python_binding import tf_jitrt +from tensorflow.python.platform import test + +jitrt = tf_jitrt.TfJitRtExecutor() + + +class TfReverseTest(test.TestCase): + + def test_1d_static(self): + mlir_function = """ + func.func @test(%input: tensor<10xf32>) -> tensor<10xf32> { + %reverse_dims = "tf.Const"() {value = dense<[0]> : tensor<1xi64>} + : () -> tensor<1xi64> + %0 = "tf.ReverseV2"(%input, %reverse_dims) + : (tensor<10xf32>, tensor<1xi64>) -> tensor<10xf32> + func.return %0 : tensor<10xf32> + }""" + + compiled = jitrt.compile(mlir_function, 'test', vectorize=True) + + arg0 = np.random.uniform(1.0, 10.0, size=(10)).astype(np.float32) + + [res] = jitrt.execute(compiled, [arg0]) + np.testing.assert_allclose(res, np.flip(arg0, axis=0)) + + def test_1d_dynamic(self): + mlir_function = """ + func.func @test(%input: tensor) -> tensor { + %reverse_dims = "tf.Const"() {value = dense<[0]> : tensor<1xi64>} + : () -> tensor<1xi64> + %0 = "tf.ReverseV2"(%input, %reverse_dims) + : (tensor, tensor<1xi64>) -> tensor + func.return %0 : tensor + }""" + + compiled = jitrt.compile(mlir_function, 'test', vectorize=True) + + arg0 = np.random.uniform(1.0, 15.0, size=(15)).astype(np.float32) + + [res] = jitrt.execute(compiled, [arg0]) + np.testing.assert_allclose(res, np.flip(arg0, axis=0)) + + def test_2d_dynamic(self): + mlir_function = """ + func.func @test(%input: tensor) -> tensor { + %reverse_dims = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} + : () -> tensor<1xi64> + %0 = "tf.ReverseV2"(%input, %reverse_dims) + : (tensor, tensor<1xi64>) -> tensor + func.return %0 : tensor + }""" + + compiled = jitrt.compile(mlir_function, 'test', vectorize=True) + + arg0 = np.random.uniform(1.0, 10.0, size=(2, 2)).astype(np.float32) + + [res] = jitrt.execute(compiled, [arg0]) + np.testing.assert_allclose(res, np.flip(arg0, axis=1)) + + def test_3d_dynamic(self): + mlir_function = """ + func.func @test(%input: tensor) -> tensor { + %reverse_dims = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} + : () -> tensor<2xi64> + %0 = "tf.ReverseV2"(%input, %reverse_dims) + : (tensor, tensor<2xi64>) -> tensor + func.return %0 : tensor + }""" + + compiled = jitrt.compile(mlir_function, 'test', vectorize=True) + + arg0 = np.random.uniform(1.0, 30.0, size=(2, 3, 4)).astype(np.float32) + + [res] = jitrt.execute(compiled, [arg0]) + np.testing.assert_allclose(res, np.flip(arg0, axis=(0, 1))) + + def test_3d_dynamic_reverse_last(self): + mlir_function = """ + func.func @test(%input: tensor) -> tensor { + %reverse_dims = "tf.Const"() {value = dense<[0, 2]> : tensor<2xi64>} + : () -> tensor<2xi64> + %0 = "tf.ReverseV2"(%input, %reverse_dims) + : (tensor, tensor<2xi64>) -> tensor + func.return %0 : tensor + }""" + + compiled = jitrt.compile(mlir_function, 'test', vectorize=True) + + arg0 = np.random.uniform(1.0, 30.0, size=(2, 3, 4)).astype(np.float32) + + [res] = jitrt.execute(compiled, [arg0]) + np.testing.assert_allclose(res, np.flip(arg0, axis=(0, 2))) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/compiler/mlir/tfrt/python_tests/tf_scatter_test.py b/tensorflow/compiler/mlir/tfrt/python_tests/tf_scatter_test.py new file mode 100644 index 00000000000..c908a775543 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/python_tests/tf_scatter_test.py @@ -0,0 +1,54 @@ +# Copyright 2023 The TensorFlow Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Tensorflow -> jitrt compilation.""" + +import numpy as np + +from tensorflow.compiler.mlir.tfrt.jit.python_binding import tf_jitrt +from tensorflow.python.platform import test + +jitrt = tf_jitrt.TfJitRtExecutor() + + +class TfScatterTest(test.TestCase): + + def test_scatter(self): + mlir_function = """ + func.func @test(%index: tensor<5x2xi32>, %updates: tensor<5x8xf32>, + %out: tensor<1x11x11xf32>) -> tensor<1x11x11xf32> { + %1 = "tf.TensorScatterAdd"(%out, %index, %updates) + : (tensor<1x11x11xf32>, tensor<5x2xi32>, tensor<5x8xf32>) -> + tensor<1x11x11xf32> + return %1 : tensor<1x11x11xf32> + } + """ + compiled = jitrt.compile(mlir_function, 'test', vectorize=True) + index = np.array([[0, 0], [0, 0], [0, 5], [0, 5], [0, 10]], dtype=np.int32) + updates = np.array( + [[1] * 8, [2] * 8, [3] * 8, [4] * 8, [5] * 8], dtype=np.float32 + ) + out = np.zeros((1, 11, 11), dtype=np.float32) + + exp_res = np.zeros((1, 11, 11), dtype=np.float32) + exp_res[0][0][:8] += 3 + exp_res[0][5][:8] += 7 + exp_res[0][10][:8] += 5 + + [res] = jitrt.execute(compiled, [index, updates, out]) + np.testing.assert_allclose(res, exp_res) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_executor.cc b/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_executor.cc index 5bda1f8c12c..62ec862a393 100644 --- a/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_executor.cc +++ b/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_executor.cc @@ -147,7 +147,7 @@ void RuntimeFallbackExecutor::Prepare(llvm::StringRef mlir_input) { TfrtPipelineOptions pipeline_opts; pipeline_opts.default_device = kDefaultHostDeviceName; pipeline_opts.hoist_invariant_ops = true; - pipeline_opts.enable_native_ops = false; + pipeline_opts.sink_in_invariant_ops = false; pipeline_opts.cost_threshold = 1024; pipeline_opts.upper_cost_threshold = 100000; pipeline_opts.merge_inter_dependent_streams = true; diff --git a/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_executor.h b/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_executor.h index a182df69778..597a2e2bf71 100644 --- a/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_executor.h +++ b/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_executor.h @@ -59,4 +59,4 @@ class RuntimeFallbackExecutor { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_TFRT_RUNTIME_FALLBACK_RUNTIME_FALLBACK_EXECUTOR_H_ +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_RUNTIME_FALLBACK_RUNTIME_FALLBACK_EXECUTOR_H_ diff --git a/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_ops.td b/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_ops.td index 44925511380..cb515614cb6 100644 --- 
a/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_ops.td +++ b/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_ops.td @@ -42,8 +42,7 @@ def RuntimeFallback_Dialect : Dialect { }]; let cppNamespace = "::mlir::tfd"; - - let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; + let useFoldAPI = kEmitFoldAdaptorFolder; } //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc b/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc index 6a4b45ab9ee..47572390666 100644 --- a/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc +++ b/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc @@ -20,6 +20,7 @@ limitations under the License. #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" @@ -37,6 +38,8 @@ limitations under the License. namespace tensorflow { namespace { +using ::mlir::tf_saved_model::kTfSavedModelIndexPathAttr; + llvm::StringRef ProcessIndexPath(mlir::ArrayAttr index_path) { if (index_path.size() == 1 && index_path[0].isa()) { // TODO(chky): Support cases where index_path is not a single string. @@ -86,7 +89,7 @@ Status MapFunctionSignaturesFromTFSavedModelMLIR( llvm::SmallVector bound_inputs; for (unsigned i = 0, e = func.getNumArguments(); i != e; ++i) { if (auto input_index_path = func.getArgAttrOfType( - i, "tf_saved_model.index_path")) { + i, kTfSavedModelIndexPathAttr)) { input_names.push_back(ProcessIndexPath(input_index_path)); auto statusor_spec = ProcessTensorSpec(func_type.getInput(i).cast()); @@ -114,7 +117,7 @@ Status MapFunctionSignaturesFromTFSavedModelMLIR( output_specs; for (unsigned i = 0, e = func.getNumResults(); i != e; ++i) { if (auto output_index_path = func.getResultAttrOfType( - i, "tf_saved_model.index_path")) { + i, kTfSavedModelIndexPathAttr)) { output_names.push_back(ProcessIndexPath(output_index_path)); auto statusor_spec = ProcessTensorSpec(func_type.getResult(i).cast()); diff --git a/tensorflow/compiler/mlir/tfrt/tests/BUILD b/tensorflow/compiler/mlir/tfrt/tests/BUILD index 0c69ce6158e..f7df07f4708 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/BUILD +++ b/tensorflow/compiler/mlir/tfrt/tests/BUILD @@ -1,7 +1,10 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") load("//tensorflow:tensorflow.bzl", "if_oss") -package(licenses = ["notice"]) +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) glob_lit_tests( data = [":test_utilities"], diff --git a/tensorflow/compiler/mlir/tfrt/tests/analysis/BUILD b/tensorflow/compiler/mlir/tfrt/tests/analysis/BUILD index 07185e853d5..91f9e57a9b2 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/analysis/BUILD +++ b/tensorflow/compiler/mlir/tfrt/tests/analysis/BUILD @@ -1,11 +1,15 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") -load("//tensorflow:tensorflow.bzl", "if_oss") +load("//tensorflow:tensorflow.bzl", "if_oss", "tf_cc_test") -package(licenses = ["notice"]) +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) 
glob_lit_tests( data = [":test_utilities"], driver = "//tensorflow/compiler/mlir:run_lit.sh", + exclude = ["testdata/**"], features = if_oss(["--path=org_tensorflow/tensorflow/compiler/mlir/tfrt"]), test_file_exts = ["mlir"], ) @@ -22,3 +26,23 @@ filegroup( "@llvm-project//mlir:run_lit.sh", ], ) + +tf_cc_test( + name = "update_op_cost_in_tfrt_mlir_test", + srcs = ["update_op_cost_in_tfrt_mlir_test.cc"], + data = [ + "testdata/test.mlir", + ], + deps = [ + "//tensorflow/compiler/mlir/tfrt:transforms/update_op_cost_in_tfrt_mlir", + "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_async_opdefs", + "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_sync_opdefs", + "//tensorflow/core:test", + "//tensorflow/core/platform:resource_loader", + "//tensorflow/core/tfrt/fallback:cost_recorder", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:Parser", + "@tf_runtime//:init_tfrt_dialects", + ], +) diff --git a/tensorflow/compiler/mlir/tfrt/tests/analysis/testdata/test.mlir b/tensorflow/compiler/mlir/tfrt/tests/analysis/testdata/test.mlir new file mode 100644 index 00000000000..d4d843cd4ab --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/tests/analysis/testdata/test.mlir @@ -0,0 +1,7 @@ +func.func @test(%ch: !tfrt.chain, %arg0: !corert.tensorhandle, %arg1_th: !corert.tensorhandle) { + %cpu = corert.get_op_handler %ch "cpu" + %0 = corert.executeop(%cpu) "tf.Relu"(%arg0) { T = f32 } : 1 + %arg1 = tfrt_fallback_async.corert_tensorhandle_to_fallback_tensor %arg1_th {_tfrt_cost = 1 : i64, device = "/CPU:0"} : (!corert.tensorhandle) -> (!tfrt_fallback.tf_tensor) + %1 = tfrt_fallback_async.executeop key(0) cost(100) device("/CPU:0") "tf.Relu"(%arg1) { T = f32 } : 1 + tfrt.return +} diff --git a/tensorflow/compiler/mlir/tfrt/tests/analysis/update_op_cost_in_tfrt_mlir_test.cc b/tensorflow/compiler/mlir/tfrt/tests/analysis/update_op_cost_in_tfrt_mlir_test.cc new file mode 100644 index 00000000000..0c50bbbdea8 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/tests/analysis/update_op_cost_in_tfrt_mlir_test.cc @@ -0,0 +1,84 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/tfrt/transforms/update_op_cost_in_tfrt_mlir.h" + +#include +#include +#include + +#include +#include "absl/container/flat_hash_map.h" +#include "mlir/Parser/Parser.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.h" +#include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_sync.h" +#include "tensorflow/core/platform/resource_loader.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/tfrt/fallback/cost_recorder.h" +#include "tfrt/init_tfrt_dialects.h" // from @tf_runtime + +namespace tensorflow { +namespace { + +constexpr char kCostAttrName[] = "_tfrt_cost"; +constexpr char kOpKeyAttrName[] = "op_key"; + +absl::flat_hash_map GetOpCostMap(mlir::ModuleOp op) { + absl::flat_hash_map op_cost_map; + op.walk([&](mlir::Operation* op) { + const auto cost_attr = op->getAttrOfType(kCostAttrName); + if (!cost_attr) return; + const auto op_key_attr = + op->getAttrOfType(kOpKeyAttrName); + if (!op_key_attr) return; + op_cost_map[op_key_attr.getInt()] = cost_attr.getInt(); + }); + return op_cost_map; +} + +TEST(CostUpdateTest, Basic) { + std::string saved_model_mlir_path = tensorflow::GetDataDependencyFilepath( + "tensorflow/compiler/mlir/tfrt/tests/analysis/testdata/test.mlir"); + + mlir::DialectRegistry registry; + tfrt::RegisterTFRTDialects(registry); + registry.insert(); + registry.insert(); + mlir::MLIRContext context(registry); + auto module = + mlir::parseSourceFile(saved_model_mlir_path, &context); + ASSERT_TRUE(module); + + // Create a cost recorder with fake cost records. + auto expected_op_cost_map = GetOpCostMap(module.get()); + EXPECT_EQ(expected_op_cost_map.size(), 1); + unsigned int seed = 23579; + for (auto& [op_key, cost] : expected_op_cost_map) { + cost = rand_r(&seed) % 1000; + } + tensorflow::tfrt_stub::CostRecorder cost_recorder; + for (const auto& [op_key, cost] : expected_op_cost_map) { + cost_recorder.RecordCostNanosecond(op_key, cost); + } + + // Update the TFRT MLIR with the cost recorder. + tfrt_compiler::UpdateOpCostInTfrtMlir(module.get(), cost_recorder); + + // Check the updated costs. 
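As a rough picture of what this new test verifies, here is a toy dictionary model: each op carries an integer `op_key` and a `_tfrt_cost` attribute, fake measured costs are recorded per key, and the update rewrites the attribute so the re-extracted map equals the recorded one. The dict-of-ops representation and the sample cost value are assumptions for illustration only.

```python
ops = [{"op_key": 0, "_tfrt_cost": 100}]  # a single keyed op, as in the new testdata module

def get_op_cost_map(ops):
    # Analogue of GetOpCostMap: collect cost by op_key where both attrs exist.
    return {op["op_key"]: op["_tfrt_cost"] for op in ops if "op_key" in op}

recorded = {key: 417 for key in get_op_cost_map(ops)}  # fake recorded costs

for op in ops:  # analogue of UpdateOpCostInTfrtMlir
    if op.get("op_key") in recorded:
        op["_tfrt_cost"] = recorded[op["op_key"]]

assert get_op_cost_map(ops) == recorded  # what the EXPECT_THAT at the end checks
```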
+ const auto got_op_cost_map = GetOpCostMap(module.get()); + EXPECT_THAT(got_op_cost_map, ::testing::ContainerEq(expected_op_cost_map)); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/tests/batch_function_fallback_resource_variable_as_captured_tensor.mlir b/tensorflow/compiler/mlir/tfrt/tests/batch_function_fallback_resource_variable_as_captured_tensor.mlir index 9e8796d1efc..0791e2d8c9b 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/batch_function_fallback_resource_variable_as_captured_tensor.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/batch_function_fallback_resource_variable_as_captured_tensor.mlir @@ -1,4 +1,4 @@ -// RUN: tf-tfrt-opt -pass-pipeline='tf-executor-to-tfrt-pipeline{target-tpurt=true}' %s | FileCheck %s +// RUN: tf-tfrt-opt -pass-pipeline='builtin.module(tf-executor-to-tfrt-pipeline{target-tpurt=true})' %s | FileCheck %s module attributes {tf_saved_model.semantics} { // CHECK-LABEL: func @main diff --git a/tensorflow/compiler/mlir/tfrt/tests/batch_function_lowering.mlir b/tensorflow/compiler/mlir/tfrt/tests/batch_function_lowering.mlir index 448686d176f..eaeaa638328 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/batch_function_lowering.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/batch_function_lowering.mlir @@ -1,4 +1,4 @@ -// RUN: tf-tfrt-opt -tf-executor-to-tfrt-pipeline="enable-native-ops=false func-use-fallback-tensor=true" %s | FileCheck %s --dump-input=always +// RUN: tf-tfrt-opt -tf-executor-to-tfrt-pipeline="func-use-fallback-tensor=true" %s | FileCheck %s --dump-input=always func.func private @batched_function(%arg0: tensor<1x3xf32> {tf._user_specified_name = "0"}, %arg1: tensor<*x!tf_type.resource>) -> tensor<1x3xf32> attributes {tf._input_shapes = [#tf_type.shape<1x3>, #tf_type.shape<*>], tf.signature.is_stateful} { %0 = "tf.ReadVariableOp"(%arg1) {device = "/device:CPU:0"} : (tensor<*x!tf_type.resource>) -> tensor<1x3xf32> diff --git a/tensorflow/compiler/mlir/tfrt/tests/fuse_tpu_compile_and_execute_ops.mlir b/tensorflow/compiler/mlir/tfrt/tests/fuse_tpu_compile_and_execute_ops.mlir index 984d28af38e..01f1731e83a 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/fuse_tpu_compile_and_execute_ops.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/fuse_tpu_compile_and_execute_ops.mlir @@ -89,3 +89,21 @@ func.func private @test_fuse_dynamic_dimension_ops(%arg0: tensor, %arg1 } +// ----- + +module attributes {tf_saved_model.semantics} { + +// CHECK-LABEL: func private @reorder_execute_arg_defining_ops +// CHECK: tf.VarHandleOp +// CHECK-NEXT: tf.ReadVariableOp +// CHECK-NEXT: tf.TPUCompileMlirAndExecute +func.func private @reorder_execute_arg_defining_ops(%arg0: tensor<1x3xf32> {tf.device = "/CPU:0"}) -> (tensor<1x1xf32> {tf.device = "/TPU:0"}) { + %compilation_status, %program = "tf._TPUCompileMlir"() {device = "/CPU:0", metadata = "metadata", mlir_module = "propgram"} : () -> (tensor, tensor<3x!tf_type.string>) + "tf.TPUCompileSucceededAssert"(%compilation_status) {device = "/CPU:0"} : (tensor) -> () + %0 = "tf.VarHandleOp"() {_xla_inferred_shapes = [#tf_type.shape<>], allowed_devices = [], container = "", device = "/CPU:0", shared_name = "y"} : () -> tensor>> + %1 = "tf.ReadVariableOp"(%0) {device = "/CPU:0"} : (tensor>>) -> tensor<3x1xf32> + %2 = "tf.TPUExecute"(%arg0, %1, %program) {_producer_name = "UNKNOWN", device = "/TPU:0"} : (tensor<1x3xf32>, tensor<3x1xf32>, tensor<3x!tf_type.string>) -> tensor<1x1xf32> + return %2 : tensor<1x1xf32> +} + +} diff --git 
a/tensorflow/compiler/mlir/tfrt/tests/hoist_invariant_ops.mlir b/tensorflow/compiler/mlir/tfrt/tests/hoist_invariant_ops.mlir index 411bbb54cf0..1cafb216743 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/hoist_invariant_ops.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/hoist_invariant_ops.mlir @@ -128,10 +128,21 @@ func.func @hoist_var_read_write() -> (tensor {tf_saved_model.index_path = [ module attributes {tf_saved_model.semantics} { -// Test not hoisting varhandle op that used by control flow ops. +// Test not hoisting read variable op that used by control flow ops if var handle op and read variable op are separated, but still hoists const ops and var handle ops. +// CHECK-LABEL: func @_tfrt_resource_init +// CHECK: [[handle:%.*]] = "tf.VarHandleOp"() {container = "", shared_name = "x"} : () -> tensor>> +// CHECK: "tf._TfrtSetResource"([[handle]]) +// CHECK-SAME: index = [[handle_index:.*]] +// CHECK: [[handle1:%.*]] = "tf.VarHandleOp"() {container = "", shared_name = "x"} : () -> tensor>> +// CHECK: "tf._TfrtSetResource"([[handle1]]) +// CHECK-SAME: index = [[handle1_index:.*]] +// CHECK: [[const:%.*]] = "tf.Const"() {device = "/CPU:0", value = dense : tensor} : () -> tensor +// CHECK: "tf._TfrtSetResource"([[const]]) +// CHECK-SAME: index = [[const_index:.*]] func.func private @some_func( %arg: tensor>>) -> tensor { + // CHECK: tf.ReadVariableOp %0 = "tf.ReadVariableOp"(%arg) {device = "cpu"} : (tensor>>) -> tensor func.return %0 : tensor } @@ -139,6 +150,8 @@ func.func private @some_func( // CHECK-LABEL: func @test_not_hoist_stateful_call func.func @not_hoist_stateful_call(%arg: tensor {tf_saved_model.index_path = ["input"]}) -> (tensor {tf_saved_model.index_path = ["r"]}) attributes {tf_saved_model.exported_names = ["test_not_hoist_stateful_call"]} { + // CHECK-NOT: tf.VarHandleOp + // CHECK: "tf._TfrtGetResource"() %handle = "tf.VarHandleOp"() {container = "", shared_name = "x"} : () -> tensor>> // CHECK: tf.StatefulPartitionedCall %x = "tf.StatefulPartitionedCall"(%handle) {device = "/CPU:0", config = "", config_proto = "", executor_type = "", f = @some_func} : (tensor>>) -> (tensor) @@ -150,6 +163,8 @@ func.func @not_hoist_stateful_call(%arg: tensor {tf_saved_model.index_path func.func @not_hoist_if(%arg: tensor {tf_saved_model.index_path = ["input"]}) -> (tensor {tf_saved_model.index_path = ["r"]}) attributes {tf_saved_model.exported_names = ["test_not_hoist_if"]} { %handle = "tf.VarHandleOp"() {container = "", shared_name = "x"} : () -> tensor>> + // CHECK-NOT: tf.Const + // CHECK: "tf._TfrtGetResource"() %cond = "tf.Const"() {device = "/CPU:0", value = dense : tensor} : () -> tensor // CHECK: tf.If %x = "tf.If"(%cond, %handle) {then_branch = @some_func, else_branch = @some_func, is_stateless = false} : (tensor, tensor>>) -> tensor @@ -163,6 +178,39 @@ func.func @not_hoist_if(%arg: tensor {tf_saved_model.index_path = ["input"] module attributes {tf_saved_model.semantics} { +// Test hoist var handle op and read variable op in the batch function. 
+ +// CHECK-LABEL: func private @batched_function +func.func private @batched_function(%arg0: tensor<1x3xf32>) -> tensor<1x3xf32> + attributes {tf._input_shapes = [#tf_type.shape<1x3>, #tf_type.shape<*>], tf.signature.is_stateful} { + // CHECK-NOT: tf.VarHandleOp + // CHECK-NOT: tf.ReadVariableOp + // CHECK: "tf._TfrtGetResource"() + %0 = "tf.VarHandleOp"() {device = "/device:CPU:0", container = "", shared_name = "variable"} : () -> tensor>> + %1 = "tf.ReadVariableOp"(%0) {device = "/device:CPU:0"} : (tensor>>) -> tensor<1x3xf32> + %2 = "tf.AddV2"(%arg0, %1) {device = "/device:CPU:0"} : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + %3 = "tf.Identity"(%2) {device = "/device:CPU:0"} : (tensor<1x3xf32>) -> tensor<1x3xf32> + func.return %3 : tensor<1x3xf32> +} + +// CHECK-LABEL: func @main +func.func @main(%arg0: tensor<1x3xf32> {tf_saved_model.index_path = ["input"]}) -> (tensor<*xf32> {tf_saved_model.index_path = ["r"]}) + attributes {tf_saved_model.exported_names = ["main"]} { + // CHECK-NOT: tf.VarHandleOp + // CHECK: "tf._TfrtGetResource"() + %0 = "tf.VarHandleOp"() {device = "/device:CPU:0", container = "", shared_name = "variable"} : () -> tensor>> + // CHECK: "tf.BatchFunction"(%arg0, %0) + // CHECK: operand_segment_sizes = array + %1 = "tf.BatchFunction"(%arg0, %0) {allowed_batch_sizes = [6], batch_timeout_micros = 100000 : i64, batching_queue = "", container = "", device = "/device:CPU:0", enable_large_batch_splitting = false, f = @batched_function, max_batch_size = 6 : i64, max_enqueued_batches = 10 : i64, num_batch_threads = 1 : i64, operand_segment_sizes = array, shared_name = "batch/"} : (tensor<1x3xf32>, tensor>>) -> tensor<*xf32> + func.return %1 : tensor<*xf32> +} + +} + +// ----- + +module attributes {tf_saved_model.semantics} { + // Test not hoisting callees in init functions. 
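Conceptually, the hoisting these new checks exercise moves invariant ops (`tf.VarHandleOp`, `tf.ReadVariableOp`, `tf.Const`) into a one-time `_tfrt_resource_init` function that stores results via `tf._TfrtSetResource`, while the original sites only call `tf._TfrtGetResource`. A toy Python analogue; the list-based registry and helper names are assumptions for illustration:

```python
resources = []  # populated once, like _TfrtSetResource

def tfrt_resource_init():
    handle = {"shared_name": "variable"}  # hoisted tf.VarHandleOp
    resources.append(handle)
    resources.append([1.0, 2.0, 3.0])     # hoisted tf.ReadVariableOp result

def batched_function(arg):
    value = resources[1]  # _TfrtGetResource instead of re-reading the variable
    return [a + v for a, v in zip(arg, value)]

tfrt_resource_init()
assert batched_function([1.0, 1.0, 1.0]) == [2.0, 3.0, 4.0]
```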
"tf_saved_model.session_initializer"() {initializers = [@init]} : () -> () diff --git a/tensorflow/compiler/mlir/tfrt/tests/ir/BUILD b/tensorflow/compiler/mlir/tfrt/tests/ir/BUILD index 307f5d867b9..e4662bd66d7 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/ir/BUILD +++ b/tensorflow/compiler/mlir/tfrt/tests/ir/BUILD @@ -1,7 +1,10 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") load("//tensorflow:tensorflow.bzl", "if_oss", "tf_cc_test") -package(licenses = ["notice"]) +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) glob_lit_tests( data = [":test_utilities"], diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/BUILD b/tensorflow/compiler/mlir/tfrt/tests/jit/BUILD index 1d57ed7b6df..d4abbb3bc44 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/jit/BUILD +++ b/tensorflow/compiler/mlir/tfrt/tests/jit/BUILD @@ -2,7 +2,10 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") load("//tensorflow:tensorflow.bzl", "if_oss") load("@tf_runtime//:build_defs.bzl", "tfrt_cc_test") -package(licenses = ["notice"]) +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) glob_lit_tests( data = [":test_utilities"], @@ -27,6 +30,7 @@ tfrt_cc_test( name = "tf_jitrt_benchmark_test", srcs = ["tf_jitrt_benchmark_test.cc"], deps = [ + "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tfrt:tf_jitrt_pipeline", "//tensorflow/compiler/xla/mlir/runtime/transforms:compiler", diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/detensorize_linalg.mlir b/tensorflow/compiler/mlir/tfrt/tests/jit/detensorize_linalg.mlir deleted file mode 100644 index 9df18316667..00000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/jit/detensorize_linalg.mlir +++ /dev/null @@ -1,26 +0,0 @@ -// RUN: tf-tfrt-opt %s -tf-jitrt-detensorize-linalg | FileCheck %s - -#id = affine_map<(d0) -> (d0)> -#empty = affine_map<(d0) -> ()> - -// CHECK-LABEL: func @detensorize -func.func @detensorize(%arg : tensor<100xi32>) -> (tensor<100xi1>) attributes {} { - %c10 = arith.constant 10 : i32 - %tensor = tensor.from_elements %c10 : tensor - %init = tensor.empty() : tensor<100xi1> - %result = linalg.generic { - indexing_maps = [#id, #empty, #id], - iterator_types = ["parallel"]} - ins(%arg, %tensor : tensor<100xi32>, tensor) - outs(%init : tensor<100xi1>) { - ^bb0(%arg0: i32, %arg1: i32, %arg2: i1): - %0 = arith.cmpi slt, %arg0, %arg1 : i32 - linalg.yield %0 : i1 - } -> tensor<100xi1> - func.return %result : tensor<100xi1> -} -// CHECK: %[[C10:.*]] = arith.constant 10 : i32 -// CHECK: linalg.generic { -// CHECK-SAME: indexing_maps = [#{{map[0-9]*}}, #{{[map0-9]*}}, #{{[map0-9]*}}], -// CHECK-SAME: iterator_types = ["parallel"] -// CHECK-SAME: ins(%{{.*}}, %[[C10]] : tensor<100xi32>, i32) diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/symbolic_shape_optimization.mlir b/tensorflow/compiler/mlir/tfrt/tests/jit/symbolic_shape_optimization.mlir deleted file mode 100644 index 1310af681c6..00000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/jit/symbolic_shape_optimization.mlir +++ /dev/null @@ -1,203 +0,0 @@ -// RUN: tf-tfrt-opt %s -split-input-file -tf-jitrt-symbolic-shape-optimization \ -// RUN: | FileCheck %s - -// CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0)> - -// CHECK: @optimize_1dx1d_bcast( -// CHECK-SAME: %[[ARG0:[a-z0-9]+]]: tensor -// CHECK-SAME: %[[ARG1:[a-z0-9]+]]: tensor -func.func @optimize_1dx1d_bcast( - %arg0: tensor 
- {rt.symbolic_shape = dense<[-2]> : tensor<1xi64>}, - %arg1: tensor - {rt.symbolic_shape = dense<[-2]> : tensor<1xi64>} -) -> tensor { - %0 = shape.shape_of %arg0 : tensor -> tensor<1xindex> - %1 = shape.shape_of %arg1 : tensor -> tensor<1xindex> - %2 = shape.broadcast %0, %1 : tensor<1xindex>, tensor<1xindex> - -> tensor<1xindex> - - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] - // CHECK: %[[OUT:.*]] = tensor.empty(%[[D0]]) - // CHECK: %[[RET:.*]] = linalg.generic - // CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]] - // CHECK-SAME: iterator_types = ["parallel"] - // CHECK-SAME: ins(%[[ARG0]] : tensor) - // CHECK-SAME: outs(%[[OUT]] : tensor) - %3 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %2) - {broadcast_dimensions = dense<[0]> : tensor<1xi64>} - : (tensor, tensor<1xindex>) -> tensor - - func.return %3: tensor -} - -// ----- - -// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1)> -// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)> - -// CHECK: @optimize_1dx2d_bcast_const_shape( -// CHECK-SAME: %[[ARG0:[a-z0-9]+]]: tensor<512xf32> -// CHECK-SAME: %[[ARG1:[a-z0-9]+]]: tensor -func.func @optimize_1dx2d_bcast_const_shape( - %arg0: tensor<512xf32>, - %arg1: tensor - {rt.symbolic_shape = dense<[-2, 512]> : tensor<2xi64>} -) -> tensor { - %0 = shape.const_shape [512] : tensor<1xindex> - %1 = shape.shape_of %arg1 : tensor -> tensor<2xindex> - %2 = shape.broadcast %0, %1 : tensor<1xindex>, tensor<2xindex> - -> tensor<2xindex> - - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[D0:.*]] = tensor.dim %[[ARG1]], %[[C0]] - // CHECK: %[[OUT:.*]] = tensor.empty(%[[D0]]) - // CHECK: %[[RET:.*]] = linalg.generic - // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] - // CHECK-SAME: iterator_types = ["parallel", "parallel"] - // CHECK-SAME: ins(%[[ARG0]] : tensor<512xf32>) - // CHECK-SAME: outs(%[[OUT]] : tensor) - %3 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %2) - {broadcast_dimensions = dense<[1]> : tensor<1xi64>} - : (tensor<512xf32>, tensor<2xindex>) -> tensor - - func.return %3: tensor -} - -// ----- - -// CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0)> - -// CHECK: @optimize_1dx1dx1d_bcast( -// CHECK-SAME: %[[ARG0:[a-z0-9]+]]: tensor -// CHECK-SAME: %[[ARG1:[a-z0-9]+]]: tensor -// CHECK-SAME: %[[ARG2:[a-z0-9]+]]: tensor -func.func @optimize_1dx1dx1d_bcast( - %arg0: tensor - {rt.symbolic_shape = dense<[-2]> : tensor<1xi64>}, - %arg1: tensor - {rt.symbolic_shape = dense<[-2]> : tensor<1xi64>}, - %arg2: tensor - {rt.symbolic_shape = dense<[-2]> : tensor<1xi64>} -) -> tensor { - %0 = shape.shape_of %arg0 : tensor -> tensor<1xindex> - %1 = shape.shape_of %arg1 : tensor -> tensor<1xindex> - %2 = shape.shape_of %arg2 : tensor -> tensor<1xindex> - %3 = shape.broadcast %0, %1 : tensor<1xindex>, tensor<1xindex> - -> tensor<1xindex> - %4 = shape.broadcast %3, %2 : tensor<1xindex>, tensor<1xindex> - -> tensor<1xindex> - - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] - // CHECK: %[[OUT:.*]] = tensor.empty(%[[D0]]) - // CHECK: %[[RET:.*]] = linalg.generic - // CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]] - // CHECK-SAME: iterator_types = ["parallel"] - // CHECK-SAME: ins(%[[ARG0]] : tensor) - // CHECK-SAME: outs(%[[OUT]] : tensor) - %5 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %4) - {broadcast_dimensions = dense<[0]> : tensor<1xi64>} - : (tensor, tensor<1xindex>) -> tensor - - func.return %5: tensor -} - -// ----- - -// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) 
-> (d0, d1)> -// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d1)> - -// CHECK: @optimize_2dx1d_bcast( -// CHECK-SAME: %[[ARG0:[a-z0-9]+]]: tensor<10x?xf32> -// CHECK-SAME: %[[ARG1:[a-z0-9]+]]: tensor -func.func @optimize_2dx1d_bcast( - %arg0: tensor<10x?xf32> - {rt.symbolic_shape = dense<[10, -2]> : tensor<2xi64>}, - %arg1: tensor - {rt.symbolic_shape = dense<[-2]> : tensor<1xi64>} -) -> (tensor<10x?xf32>, tensor<10x?xf32>) { - %0 = shape.shape_of %arg0 : tensor<10x?xf32> -> tensor<2xindex> - %1 = shape.shape_of %arg1 : tensor -> tensor<1xindex> - %2 = shape.broadcast %0, %1 : tensor<2xindex>, tensor<1xindex> - -> tensor<2xindex> - - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index - // CHECK: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] - - // CHECK: %[[OUT0:.*]] = tensor.empty(%[[D1]]) - // CHECK: %[[RET0:.*]] = linalg.generic - // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]]] - // CHECK-SAME: iterator_types = ["parallel", "parallel"] - // CHECK-SAME: ins(%[[ARG0]] : tensor<10x?xf32>) - // CHECK-SAME: outs(%[[OUT0]] : tensor<10x?xf32>) - %3 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %2) - {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} - : (tensor<10x?xf32>, tensor<2xindex>) -> tensor<10x?xf32> - - // CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C1]] - // CHECK: %[[OUT1:.*]] = tensor.empty(%[[D0]]) - // CHECK: %[[RET1:.*]] = linalg.generic - // CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP0]]] - // CHECK-SAME: iterator_types = ["parallel", "parallel"] - // CHECK-SAME: ins(%[[ARG1]] : tensor) - // CHECK-SAME: outs(%[[OUT1]] : tensor<10x?xf32>) - %4 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %2) - {broadcast_dimensions = dense<[1]> : tensor<1xi64>} - : (tensor, tensor<2xindex>) -> tensor<10x?xf32> - - // CHECK: return %[[RET0]], %[[RET1]] - func.return %3, %4: tensor<10x?xf32>, tensor<10x?xf32> -} - -// ----- - -// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, 0, d2)> -// CHECK: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (0, d1, 0)> - -// CHECK: @optimize_3dx3d_bcast( -// CHECK-SAME: %[[ARG0:[a-z0-9]+]]: tensor -// CHECK-SAME: %[[ARG1:[a-z0-9]+]]: tensor<1x?x1xf32> -func.func @optimize_3dx3d_bcast( - %arg0: tensor - {rt.symbolic_shape = dense<[-2, 1, -3]> : tensor<3xi64>}, - %arg1: tensor<1x?x1xf32> - {rt.symbolic_shape = dense<[1, -4, 1]> : tensor<3xi64>} -) -> (tensor, tensor) { - %0 = shape.shape_of %arg0 : tensor -> tensor<3xindex> - %1 = shape.shape_of %arg1 : tensor<1x?x1xf32> -> tensor<3xindex> - %2 = shape.broadcast %0, %1 : tensor<3xindex>, tensor<3xindex> - -> tensor<3xindex> - - // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index - // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index - - // CHECK: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] - // CHECK: %[[D1:.*]] = tensor.dim %[[ARG1]], %[[C1]] - // CHECK: %[[D2:.*]] = tensor.dim %[[ARG0]], %[[C2]] - // CHECK: %[[OUT0:.*]] = tensor.empty(%[[D0]], %[[D1]], %[[D2]]) - // CHECK: %[[RET0:.*]] = linalg.generic - // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] - // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] - // CHECK-SAME: ins(%[[ARG0]] : tensor) - // CHECK-SAME: outs(%[[OUT0]] : tensor) - %3 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %2) - {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} - : (tensor, tensor<3xindex>) -> tensor - - // CHECK: %[[OUT1:.*]] = tensor.empty(%[[D0]], %[[D1]], %[[D2]]) - // CHECK: %[[RET1:.*]] = linalg.generic - // CHECK-SAME: 
indexing_maps = [#[[MAP2]], #[[MAP1]]] - // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] - // CHECK-SAME: ins(%[[ARG1]] : tensor<1x?x1xf32>) - // CHECK-SAME: outs(%[[OUT1]] : tensor) - %4 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %2) - {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} - : (tensor<1x?x1xf32>, tensor<3xindex>) -> tensor - - // CHECK: return %[[RET0]], %[[RET1]] - func.return %3, %4: tensor, tensor -} diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_benchmark_test.cc b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_benchmark_test.cc index a08c3c725d8..cabd1c39fad 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_benchmark_test.cc +++ b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_benchmark_test.cc @@ -17,6 +17,7 @@ #include #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.h" #include "tensorflow/compiler/xla/mlir/runtime/transforms/compiler.h" @@ -90,6 +91,9 @@ static void BM_InstantiateExecutable(::testing::benchmark::State& state) { opts.compiler.create_compilation_pipeline = [&](xla::runtime::PassManager& passes) { TfJitRtPipelineOptions opts; + opts.enable_xla_cpu_transformations = + tensorflow::GetJitRtFlags().enable_xla_cpu_transformations; + opts.lower_to_mmt4d = tensorflow::GetJitRtFlags().pack_matmul; // Lower from Tensorflow to Linalg on buffers. CreateTfJitRtPipeline(*passes, opts); diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_codegen_transpose.mlir b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_codegen_transpose.mlir index 5ccf7201a66..50b31cd12df 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_codegen_transpose.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_codegen_transpose.mlir @@ -1,7 +1,5 @@ // RUN: tf-tfrt-opt -tf-jitrt-pipeline="vectorize codegen-transpose" -split-input-file %s | FileCheck %s -// Verify that transpose codegen is working within the pipeline. - func.func @transpose_2d(%arg0: tensor) -> tensor { %0 = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>, @@ -18,41 +16,12 @@ func.func @transpose_2d(%arg0: tensor) -> tensor { // 8x8 tiling. // CHECK: scf.parallel {{.*}} step (%[[C8]], %[[C8]]) { // Vector xfer reads: unrolled second vector dimension. -// CHECK-NEXT: vector.transfer_read -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_read -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_read -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_read -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_read -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_read -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_read -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_read +// CHECK-COUNT-8: vector.transfer_read // AVX2 shuffle/asm sequence. // CHECK-COUNT-12: vector.shuffle // CHECK-COUNT-8: llvm.inline_asm // CHECK-COUNT-8: vector.shuffle // Vector xfer writes: unrolled second vector dimension. 
-// CHECK-NEXT: vector.transfer_write -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_write -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_write -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_write -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_write -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_write -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_write -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_write // ----- @@ -70,41 +39,12 @@ func.func @transpose_3d_021(%arg0: tensor) -> tensor { // 1x8x8 tiling. // CHECK: scf.parallel {{.*}} step (%[[C1]], %[[C8]], %[[C8]]) { // Vector xfer reads: unrolled second vector dimension. -// CHECK-NEXT: vector.transfer_read -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_read -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_read -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_read -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_read -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_read -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_read -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_read +// CHECK-COUNT-8: vector.transfer_read // AVX2 shuffle/asm sequence. // CHECK-COUNT-12: vector.shuffle // CHECK-COUNT-8: llvm.inline_asm // CHECK-COUNT-8: vector.shuffle // Vector xfer writes: unrolled second vector dimension. -// CHECK-NEXT: vector.transfer_write -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_write -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_write -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_write -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_write -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_write -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_write -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_write // ----- @@ -120,43 +60,14 @@ func.func @transpose_3d_201(%arg0: tensor) -> tensor { // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // 8x1x8 tiling. -// CHECK: scf.parallel {{.*}} step (%[[C8]], %[[C1]], %[[C8]]) { +// CHECK: scf.parallel {{.*}} step (%[[C1]], %[[C8]], %[[C8]]) { // Vector xfer reads: unrolled second vector dimension. -// CHECK-NEXT: vector.transfer_read -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_read -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_read -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_read -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_read -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_read -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_read -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_read +// CHECK-COUNT-8: vector.transfer_read // AVX2 shuffle/asm sequence. // CHECK-COUNT-12: vector.shuffle // CHECK-COUNT-8: llvm.inline_asm // CHECK-COUNT-8: vector.shuffle // Vector xfer writes: unrolled second vector dimension. 
-// CHECK-NEXT: vector.transfer_write -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_write -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_write -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_write -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_write -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_write -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_write -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_write // ----- @@ -174,41 +85,12 @@ func.func @transpose_3d_210(%arg0: tensor) -> tensor { // 8x1x8 tiling. // CHECK: scf.parallel {{.*}} step (%[[C8]], %[[C1]], %[[C8]]) { // Vector xfer reads: unrolled second vector dimension. -// CHECK-NEXT: vector.transfer_read -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_read -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_read -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_read -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_read -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_read -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_read -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_read +// CHECK-COUNT-8: vector.transfer_read // AVX2 shuffle/asm sequence. // CHECK-COUNT-12: vector.shuffle // CHECK-COUNT-8: llvm.inline_asm // CHECK-COUNT-8: vector.shuffle // Vector xfer writes: unrolled second vector dimension. -// CHECK-NEXT: vector.transfer_write -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_write -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_write -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_write -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_write -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_write -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_write -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_write // ----- @@ -224,64 +106,11 @@ func.func @transpose_3d_120(%arg0: tensor) -> tensor { // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // 1x8x8 tiling. -// CHECK: scf.parallel {{.*}} step (%[[C1]], %[[C8]], %[[C8]]) { +// CHECK: scf.parallel {{.*}} step (%[[C8]], %[[C1]], %[[C8]]) { // Vector xfer reads: unrolled second vector dimension. -// CHECK-NEXT: vector.transfer_read -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_read -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_read -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_read -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_read -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_read -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_read -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_read +// CHECK-COUNT-8: vector.transfer_read // AVX2 shuffle/asm sequence. // CHECK-COUNT-12: vector.shuffle // CHECK-COUNT-8: llvm.inline_asm // CHECK-COUNT-8: vector.shuffle // Vector xfer writes: unrolled second vector dimension. 
-// CHECK-NEXT: vector.transfer_write -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_write -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_write -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_write -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_write -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_write -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_write -// CHECK-NEXT: affine.apply -// CHECK-NEXT: vector.transfer_write - -// ----- - -func.func @transpose_3d_102(%arg0: tensor) -> tensor { - %0 = "tf.Const"() { value = dense<[1, 0, 2]> : tensor<3xi64> } - : () -> tensor<3xi64> - %1 = "tf.Transpose"(%arg0, %0) - : (tensor, tensor<3xi64>) -> tensor - func.return %1 : tensor -} - -// CHECK-LABEL: func @transpose_3d_102 -// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// 1x1x8 tiling. -// CHECK: scf.parallel {{.*}} step (%[[C1]], %[[C1]], %[[C8]]) { -// Vector xfer read: we only vectorize one dimension for "memcopy" transposes. -// CHECK-NEXT: vector.transfer_read -// No transposition is required here so no AVX2 shuffle/asm should be generated. -// CHECK-NOT: vector.shuffle -// CHECK-NOT: llvm.inline_asm -// Vector xfer write: we only vectorize one dimension for "memcopy" transposes. -// CHECK-NEXT: vector.transfer_write - diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_codegen_transpose_detection.mlir b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_codegen_transpose_detection.mlir deleted file mode 100644 index c32148cf963..00000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_codegen_transpose_detection.mlir +++ /dev/null @@ -1,72 +0,0 @@ -// RUN: tf-tfrt-opt -tf-jitrt-tile-transpose -split-input-file %s | FileCheck %s - -// Make sure that transpose codegen passes only trigger on generic ops -// implementing a transpose operation. 
- -#map0 = affine_map<(d0, d1) -> (d1, d0)> -#map1 = affine_map<(d0, d1) -> (d0, d1)> -func.func @transpose_2d(%arg0: tensor) -> tensor { - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %0 = tensor.dim %arg0, %c0 : tensor - %1 = tensor.dim %arg0, %c1 : tensor - %2 = tensor.empty(%1, %0) : tensor - %3 = linalg.generic {indexing_maps = [#map0, #map1], - iterator_types = ["parallel", "parallel"]} - ins(%arg0 : tensor) outs(%2 : tensor) { - ^bb0(%arg1: f32, %arg2: f32): - linalg.yield %arg1 : f32 - } -> tensor - func.return %3 : tensor -} - -// CHECK-LABEL: func @transpose_2d( -// CHECK: gml_st.loop -// CHECK: linalg.generic -// CHECK: linalg.yield -// CHECK: gml_st.yield - -// ----- - -#map0 = affine_map<(d0, d1) -> (d0, d0)> -#map1 = affine_map<(d0, d1) -> (d0, d1)> -func.func @identity(%arg0: tensor) -> tensor { - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %0 = tensor.dim %arg0, %c0 : tensor - %1 = tensor.dim %arg0, %c1 : tensor - %2 = tensor.empty(%1, %0) : tensor - %3 = linalg.generic {indexing_maps = [#map0, #map1], - iterator_types = ["parallel", "parallel"]} - ins(%arg0 : tensor) outs(%2 : tensor) { - ^bb0(%arg1: f32, %arg2: f32): - linalg.yield %arg1 : f32 - } -> tensor - func.return %3 : tensor -} - -// CHECK-LABEL: func @identity( -// CHECK-NOT: gml_st.loop - -// ----- - -#map0 = affine_map<(d0, d1) -> (d1, d0)> -#map1 = affine_map<(d0, d1) -> (d0, d1)> -func.func @transpose_add(%arg0: tensor) -> tensor{ - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %0 = tensor.dim %arg0, %c0 : tensor - %1 = tensor.dim %arg0, %c1 : tensor - %2 = tensor.empty(%1, %0) : tensor - %3 = linalg.generic {indexing_maps = [#map0, #map1], - iterator_types = ["parallel", "parallel"]} - ins(%arg0 : tensor) outs(%2 : tensor) { - ^bb0(%arg1: f32, %arg2: f32): - %add = arith.addf %arg1, %arg1 : f32 - linalg.yield %add : f32 - } -> tensor - func.return %3 : tensor -} - -// CHECK-LABEL: func @transpose_add( -// CHECK-NOT: gml_st.loop diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_fuse_fill_into_tiled_reduction.mlir b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_fuse_fill_into_tiled_reduction.mlir deleted file mode 100644 index 13ae6536c78..00000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_fuse_fill_into_tiled_reduction.mlir +++ /dev/null @@ -1,230 +0,0 @@ -// RUN: tf-tfrt-opt -tf-jitrt-fuse-fill-into-tiled-reduction %s \ -// RUN: --split-input-file |\ -// RUN: FileCheck %s - -#map0 = affine_map<(d0)[s0] -> (4, -d0 + s0)> -#map1 = affine_map<(d0)[s0] -> (2, -d0 + s0)> -#map2 = affine_map<(d0, d1) -> (d0, d1)> -#map3 = affine_map<(d0, d1) -> (d0)> -func.func @reduce_row_sum_2d(%lhs: tensor, %rhs: tensor) -> tensor { - %cst = arith.constant 0.000000e+00 : f32 - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c2 = arith.constant 2 : index - %c1 = arith.constant 1 : index - %0 = tensor.dim %lhs, %c0 : tensor - %1 = tensor.empty(%0) : tensor - %fill = linalg.fill ins(%cst : f32) outs(%1 : tensor) -> tensor - %3 = tensor.dim %lhs, %c0 : tensor - %4 = tensor.dim %lhs, %c1 : tensor - %5 = gml_st.loop (%i, %j) = (%c0, %c0) to (%3, %4) step (%c4, %c2) - ins (%lhs_ = %lhs: tensor, %rhs_ = %rhs: tensor) - outs (%fill_ = %fill: tensor) - iterators[#gml_st.iterator_type, - #gml_st.iterator_type] { - %6 = affine.min #map0(%i)[%3] - %7 = affine.min #map1(%j)[%4] - %8 = tensor.extract_slice %lhs_[%i, %j] [%6, %7] [1, 1] - : tensor to tensor - %9 = affine.min #map0(%i)[%3] - %10 = affine.min #map1(%j)[%4] - %11 = 
tensor.extract_slice %rhs_[%i, %j] [%9, %10] [1, 1] - : tensor to tensor - %12 = affine.min #map0(%i)[%3] - %13 = tensor.extract_slice %fill_[%i] [%12] [1] - : tensor to tensor - %14 = linalg.generic { - indexing_maps = [#map2, #map2, #map3], - iterator_types = ["parallel", "reduction"]} - ins(%8, %11 : tensor, tensor) - outs(%13 : tensor) { - ^bb0(%arg7: f32, %arg8: f32, %arg9: f32): - %16 = arith.mulf %arg7, %arg8 : f32 - %17 = arith.addf %16, %arg9 : f32 - linalg.yield %17 : f32 - } -> tensor - %15 = tensor.insert_slice %14 into %fill_[%i] [%12] [1] - : tensor into tensor - gml_st.yield %15 : tensor - } - func.return %5 : tensor -} -// CHECK-LABEL: func @reduce_row_sum_2d( -// CHECK-SAME: %[[LHS:.*]]: tensor, -// CHECK-SAME: %[[RHS:.*]]: tensor) -> tensor - -// CHECK-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index - -// CHECK: %[[DIM_0:.*]] = tensor.dim %[[LHS]], %[[C0]] : [[TY_2D:.*]] -// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM_0]]) : [[TY_1D:.*]] -// CHECK: %[[INIT_TILE:.*]] = tensor.empty() : tensor<4xf32> -// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[C0_F32]]{{.*}}outs(%[[INIT]] -// CHECK: %[[DIM_0_:.*]] = tensor.dim %[[LHS]], %[[C0]] : [[TY_2D]] -// CHECK: %[[DIM_1:.*]] = tensor.dim %[[LHS]], %[[C1]] : [[TY_2D]] - -// CHECK: gml_st.loop (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) -// CHECK-SAME: to (%[[DIM_0_]], %[[DIM_1]]) step (%[[C4]], %[[C2]]) -// CHECK-SAME: ins (%[[LHS_:.*]] = %[[LHS]]: [[TY_2D]], -// CHECK-SAME: %[[RHS_:.*]] = %[[RHS]]: [[TY_2D]]) -// CHECK-SAME: outs (%[[OUT_:.*]] = %[[FILL]]: [[TY_1D]], -// CHECK-SAME: %[[INIT_TILE_:.*]] = %[[INIT_TILE]]: tensor<4xf32>) - -// CHECK: %[[LHS_SUB:.*]] = tensor.extract_slice %[[LHS_]][%[[I]], %[[J]]] -// CHECK: %[[RHS_SUB:.*]] = tensor.extract_slice %[[RHS_]][%[[I]], %[[J]]] -// CHECK: %[[OUT_SUB:.*]] = tensor.extract_slice %[[OUT_]][%[[I]]] -// CHECK: %[[INIT_TILE_SUB:.*]] = tensor.extract_slice %[[INIT_TILE_]][0] - -// CHECK: %[[FILL_SUB:.*]] = linalg.fill ins(%[[C0_F32]]{{.*}}outs(%[[INIT_TILE_SUB]] -// CHECK: %[[SUM_OF_PROD_SUB:.*]] = linalg.generic -// CHECK-SAME: ins(%[[LHS_SUB]], %[[RHS_SUB]] : [[TY_2D]], [[TY_2D]]) -// CHECK-SAME: outs(%[[FILL_SUB]] : [[TY_1D]]) -// CHECK: mulf -// CHECK: addf -// CHECK-NEXT: linalg.yield - -// CHECK: %[[ACC:.*]] = linalg.generic -// CHECK-SAME: ins(%[[SUM_OF_PROD_SUB]] : [[TY_1D]]) -// CHECK-SAME: outs(%[[OUT_SUB]] : [[TY_1D]]) { -// CHECK-NOT: mulf -// CHECK: addf -// CHECK-NEXT: linalg.yield - -// CHECK: %[[INIT_TILE_UPDATE:.*]] = tensor.insert_slice %[[SUM_SUB:.*]] into %[[INIT_TILE_]] -// CHECK: %[[UPDATE:.*]] = tensor.insert_slice %[[ACC:.*]] into %[[OUT_]] -// CHECK: gml_st.yield %[[UPDATE]], %[[INIT_TILE_UPDATE]] : [[TY_1D]], tensor<4xf32> - -// ----- - -#map0 = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> (d0)> -module { - func.func @reduce_row_sum_2d_static(%in: tensor<8x16xf32>) -> tensor<8xf32> { - %cst = arith.constant 0.000000e+00 : f32 - %c2 = arith.constant 2 : index - %c4 = arith.constant 4 : index - %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index - %c0 = arith.constant 0 : index - %0 = tensor.empty() : tensor<8xf32> - %fill = linalg.fill ins(%cst : f32) outs(%0 : tensor<8xf32>) -> tensor<8xf32> - %2 = gml_st.loop (%i, %j) = (%c0, %c0) to (%c8, %c16) step (%c4, %c2) - ins (%in_ = %in: tensor<8x16xf32>) - outs 
(%fill_ = %fill: tensor<8xf32>) - iterators[#gml_st.iterator_type, - #gml_st.iterator_type] { - %3 = tensor.extract_slice %in_[%i, %j] [4, 2] [1, 1] - : tensor<8x16xf32> to tensor<4x2xf32> - %4 = tensor.extract_slice %fill_[%i] [4] [1] - : tensor<8xf32> to tensor<4xf32> - %5 = linalg.generic { - indexing_maps = [#map0, #map1], - iterator_types = ["parallel", "reduction"]} - ins(%3 : tensor<4x2xf32>) - outs(%4 : tensor<4xf32>) { - ^bb0(%arg5: f32, %arg6: f32): - %7 = arith.addf %arg5, %arg6 : f32 - linalg.yield %7 : f32 - } -> tensor<4xf32> - %6 = tensor.insert_slice %5 into %fill_[%i] [4] [1] - : tensor<4xf32> into tensor<8xf32> - gml_st.yield %6 : tensor<8xf32> - } - func.return %2 : tensor<8xf32> - } -} -// CHECK-LABEL: func @reduce_row_sum_2d_static -// CHECK: gml_st.loop -// CHECK: tensor.insert_slice - -// ----- - -#map0 = affine_map<(d0)[s0] -> (4, -d0 + s0)> -#map1 = affine_map<(d0, d1) -> (d0, d1)> -#map2 = affine_map<(d0, d1) -> (d1)> -module { - func.func @reduce_column_sum_2d(%in: tensor) -> tensor { - %cst = arith.constant 0.000000e+00 : f32 - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %0 = tensor.dim %in, %c0 : tensor - %1 = tensor.empty(%0) : tensor - %fill = linalg.fill ins(%cst : f32) outs(%1 : tensor) -> tensor - %3 = tensor.dim %in, %c0 : tensor - %4 = tensor.dim %in, %c1 : tensor - %5 = gml_st.loop (%i, %j) = (%c0, %c0) to (%3, %4) step (%c4, %c4) - ins (%in_ = %in: tensor) - outs (%fill_ = %fill: tensor) - iterators[#gml_st.iterator_type, - #gml_st.iterator_type] { - %6 = affine.min #map0(%i)[%3] - %7 = affine.min #map0(%j)[%4] - %8 = tensor.extract_slice %in_[%i, %j] [%6, %7] [1, 1] - : tensor to tensor - %9 = affine.min #map0(%j)[%4] - %10 = tensor.extract_slice %fill_[%j] [%9] [1] - : tensor to tensor - %11 = linalg.generic { - indexing_maps = [#map1, #map2], - iterator_types = ["reduction", "parallel"]} - ins(%8 : tensor) - outs(%10 : tensor) { - ^bb0(%arg5: f32, %arg6: f32): - %13 = arith.addf %arg5, %arg6 : f32 - linalg.yield %13 : f32 - } -> tensor - %12 = tensor.insert_slice %11 into %fill_[%j] [%9] [1] - : tensor into tensor - gml_st.yield %12 : tensor - } - func.return %5 : tensor - } -} -// CHECK-LABEL: func @reduce_column_sum_2d -// CHECK-SAME: %[[INPUT:.*]]: tensor) -> tensor - -// CHECK-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index - -// CHECK: %[[DIM_0:.*]] = tensor.dim %[[INPUT]], %[[C0]] : [[TY_2D:.*]] -// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM_0]]) : [[TY_1D:.*]] -// CHECK: %[[INIT_TILE:.*]] = tensor.empty() : tensor<4xf32> -// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[C0_F32]]{{.*}}outs(%[[INIT]] -// CHECK: %[[DIM_0_:.*]] = tensor.dim %[[INPUT]], %[[C0]] : [[TY_2D]] -// CHECK: %[[DIM_1:.*]] = tensor.dim %[[INPUT]], %[[C1]] : [[TY_2D]] - -// CHECK: gml_st.loop (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) -// CHECK-SAME: to (%[[DIM_0_]], %[[DIM_1]]) step (%[[C4]], %[[C4]]) -// CHECK-SAME: ins (%[[IN_:.*]] = %[[INPUT]]: [[TY_2D]]) -// CHECK-SAME: outs (%[[OUT_:.*]] = %[[FILL]]: [[TY_1D]], -// CHECK-SAME: %[[INIT_TILE_:.*]] = %[[INIT_TILE]]: tensor<4xf32>) - -// CHECK: %[[IN_SUB:.*]] = tensor.extract_slice %[[IN_]][%[[I]], %[[J]]] -// CHECK: %[[OUT_SUB:.*]] = tensor.extract_slice %[[OUT_]][%[[J]]] -// CHECK: %[[INIT_TILE_SUB:.*]] = tensor.extract_slice %[[INIT_TILE_]][0] - -// CHECK: %[[FILL_SUB:.*]] = linalg.fill 
ins(%[[C0_F32]]{{.*}}outs(%[[INIT_TILE_SUB]] -// CHECK: %[[SUM_SUB:.*]] = linalg.generic -// CHECK-SAME: ins(%[[IN_SUB]] : [[TY_2D]]) -// CHECK-SAME: outs(%[[FILL_SUB]] : [[TY_1D]]) -// CHECK: addf -// CHECK-NEXT: linalg.yield - - -// CHECK: %[[ACC:.*]] = linalg.generic -// CHECK-SAME: ins(%[[SUM_SUB]] : [[TY_1D]]) -// CHECK-SAME: outs(%[[OUT_SUB]] : [[TY_1D]]) { -// CHECK: addf -// CHECK-NEXT: linalg.yield - -// CHECK: %[[INIT_TILE_UPDATE:.*]] = tensor.insert_slice -// CHECK-SAME: %[[SUM_SUB:.*]] into %[[INIT_TILE_]] -// CHECK: %[[UPDATE:.*]] = tensor.insert_slice %[[ACC:.*]] into %[[OUT_]] -// CHECK: gml_st.yield %[[UPDATE]], %[[INIT_TILE_UPDATE]] -// CHECK-SAME: [[TY_1D]], tensor<4xf32> diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_peel_tiled_loops.mlir b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_peel_tiled_loops.mlir deleted file mode 100644 index 726802b9f06..00000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_peel_tiled_loops.mlir +++ /dev/null @@ -1,373 +0,0 @@ -// RUN: tf-tfrt-opt %s -allow-unregistered-dialect -split-input-file \ -// RUN: -tf-jitrt-peel-tiled-loops -cse -canonicalize | FileCheck %s - -#map0 = affine_map<(d0) -> (8, -d0 + 102401)> -#map1 = affine_map<(d0)[s0] -> (d0 + s0)> - -func.func @tanh_1d(%arg0: memref<102401xf32>) -> memref<102401xf32> { - %c102401 = arith.constant 102401 : index - %c8 = arith.constant 8 : index - %cst = arith.constant 0.000000e+00 : f32 - %c0 = arith.constant 0 : index - %0 = memref.alloc() : memref<102401xf32> - gml_st.loop (%arg1) = (%c0) to (%c102401) step (%c8) - ins (%arg2 = %arg0: memref<102401xf32>) - outs (%arg3 = %0: memref<102401xf32>) { - %1 = affine.min #map0(%arg1) - %2 = memref.subview %arg2[%arg1] [%1] [1] - : memref<102401xf32> to memref - %3 = memref.subview %arg3[%arg1] [%1] [1] - : memref<102401xf32> to memref - %4 = vector.transfer_read %2[%c0], %cst - : memref, vector<8xf32> - %5 = math.tanh %4 : vector<8xf32> - vector.transfer_write %5, %3[%c0] : vector<8xf32>, memref - memref.copy %3, %3 : memref to memref - gml_st.yield - } - func.return %0 : memref<102401xf32> -} - -// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> - -// CHECK-LABEL: func @tanh_1d - -// CHECK: gml_st.loop -// CHECK: memref.subview -// CHECK-SAME: memref<102401xf32> to memref<8xf32, strided<[1], offset: ?>> -// CHECK: memref.subview -// CHECK-SAME: memref<102401xf32> to memref<8xf32, strided<[1], offset: ?>> - -// CHECK: gml_st.loop -// CHECK: memref.subview -// CHECK-SAME: memref<102401xf32> to memref -// CHECK: memref.subview -// CHECK-SAME: memref<102401xf32> to memref - -// ----- - -func.func @tanh_3d(%d0: index, %d1: index, %d2: index) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c8 = arith.constant 8 : index - gml_st.loop (%arg1 ,%arg2, %arg3) = (%c0, %c0, %c0) - to (%d0, %d1, %d2) step (%c8, %c1, %c8) - ins () outs () { - "prevent.dce"() : () -> () - gml_st.yield - } - func.return -} - -// CHECK-LABEL: func @tanh_3d( -// CHECK-SAME: %[[D0:[a-z0-9]+]]: index, %[[D1:[a-z0-9]+]]: index, -// CHECK-SAME: %[[D2:[a-z0-9]+]]: index) { -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index - -// CHECK-DAG: %[[SPLIT0:.*]] = affine.apply{{.*}}%[[D0]] -// CHECK-DAG: %[[SPLIT2:.*]] = affine.apply{{.*}}%[[D2]] - -// CHECK: gml_st.loop{{.*}}(%[[C0]], %[[C0]], %[[C0]]) -// CHECK-SAME: to (%[[SPLIT0]], %arg1, %[[SPLIT2]]) -// CHECK-SAME: step (%[[C8]], %[[C1]], %[[C8]]) - -// CHECK: 
gml_st.loop{{.*}}(%[[SPLIT0]], %[[C0]], %[[C0]]) -// CHECK-SAME: to (%arg0, %arg1, %[[SPLIT2]]) -// CHECK-SAME: step (%[[C8]], %[[C1]], %[[C8]]) - -// CHECK: gml_st.loop{{.*}}(%[[C0]], %[[C0]], %[[SPLIT2]]) -// CHECK-SAME: to (%arg0, %arg1, %arg2) -// CHECK-SAME: step (%[[C8]], %[[C1]], %[[C8]]) - -// ----- - -func.func @reduce_column_sum_2d_dynamic(%in: tensor) -> tensor { - %cst = arith.constant 0.000000e+00 : f32 - %c1 = arith.constant 1 : index - %c4 = arith.constant 4 : index - %c0 = arith.constant 0 : index - - %dim_X = tensor.dim %in, %c0 : tensor - %dim_Y = tensor.dim %in, %c1 : tensor - - %1 = tensor.empty(%dim_Y) : tensor - %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor) -> tensor - %5 = gml_st.loop (%i, %j) = (%c0, %c0) to (%dim_Y, %dim_X) - step (%c4, %c4) - ins (%in_ = %in: tensor, %cst_ = %cst: f32) - outs (%out_ = %2: tensor) - iterators[#gml_st.iterator_type, - #gml_st.iterator_type] { - %6 = affine.min affine_map<(d0)[s0] -> (4, -d0 + s0)>(%j)[%dim_X] - %9 = affine.min affine_map<(d0)[s0] -> (4, -d0 + s0)>(%i)[%dim_Y] - - %8 = tensor.extract_slice %in_[%j, %i] [%6, %9] [1, 1] - : tensor to tensor - %11 = tensor.extract_slice %out_[%i] [%9] [1] - : tensor to tensor - - %12 = linalg.fill ins(%cst_ : f32) outs(%11 : tensor) -> tensor - %13 = linalg.generic { - indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, - affine_map<(d0, d1) -> (d0)>], - iterator_types = ["parallel", "reduction"]} - ins(%8 : tensor) - outs(%12 : tensor) { - ^bb0(%arg6: f32, %arg7: f32): - %16 = arith.addf %arg6, %arg7 : f32 - linalg.yield %16 : f32 - } -> tensor - %14 = linalg.generic { - indexing_maps = [affine_map<(d0) -> (d0)>, - affine_map<(d0) -> (d0)>], - iterator_types = ["parallel"]} - ins(%13 : tensor) - outs(%11 : tensor) { - ^bb0(%arg6: f32, %arg7: f32): - %16 = arith.addf %arg6, %arg7 : f32 - linalg.yield %16 : f32 - } -> tensor - %15 = tensor.insert_slice %14 into %out_[%i] [%9] [1] - : tensor into tensor - gml_st.yield %15 : tensor - } - func.return %5 : tensor -} - -// CHECK-LABEL: func @reduce_column_sum_2d_dynamic - -// CHECK: linalg.fill -// CHECK: gml_st.loop -// CHECK: tensor.extract_slice -// CHECK-SAME: tensor to tensor<4x4xf32> -// CHECK: tensor.extract_slice -// CHECK-SAME: tensor<4xf32> - -// CHECK: gml_st.loop -// CHECK: tensor.extract_slice -// CHECK-SAME: tensor to tensor<4x?xf32> -// CHECK: tensor.extract_slice -// CHECK-SAME: tensor to tensor - -// CHECK: gml_st.loop -// CHECK: tensor.extract_slice -// CHECK-SAME: tensor to tensor -// CHECK: tensor.extract_slice -// CHECK-SAME: tensor to tensor - -// ----- - -func.func @reduce_row_sum_2d_dynamic(%in: tensor) -> tensor { - %cst = arith.constant 0.000000e+00 : f32 - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - - %dim_X = tensor.dim %in, %c0 : tensor - %dim_Y = tensor.dim %in, %c1 : tensor - - %1 = tensor.empty(%dim_X) : tensor - %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor) -> tensor - %5 = gml_st.loop (%i, %j) = (%c0, %c0) to (%dim_X, %dim_Y) - step (%c4, %c4) - ins (%in_ = %in: tensor, %cst_ = %cst: f32) - outs (%out_ = %2: tensor) - iterators[#gml_st.iterator_type, - #gml_st.iterator_type] { - %6 = affine.min affine_map<(d0)[s0] -> (4, -d0 + s0)>(%i)[%dim_X] - %7 = affine.min affine_map<(d0)[s0] -> (4, -d0 + s0)>(%j)[%dim_Y] - - %8 = tensor.extract_slice %in_[%i, %j] [%6, %7] [1, 1] - : tensor to tensor - %11 = tensor.extract_slice %out_[%i] [%6] [1] - : tensor to tensor - %12 = linalg.fill ins(%cst_ : f32) outs(%11 : tensor) -> tensor - %13 = linalg.generic { 
- indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0)>], - iterator_types = ["parallel", "reduction"]} - ins(%8 : tensor) - outs(%12 : tensor) { - ^bb0(%arg6: f32, %arg7: f32): - %16 = arith.addf %arg6, %arg7 : f32 - linalg.yield %16 : f32 - } -> tensor - %14 = linalg.generic { - indexing_maps = [affine_map<(d0) -> (d0)>, - affine_map<(d0) -> (d0)>], - iterator_types = ["parallel"]} - ins(%13 : tensor) - outs(%11 : tensor) { - ^bb0(%arg6: f32, %arg7: f32): - %16 = arith.addf %arg6, %arg7 : f32 - linalg.yield %16 : f32 - } -> tensor - %15 = tensor.insert_slice %14 into %out_[%i] [%6] [1] - : tensor into tensor - gml_st.yield %15 : tensor - } - func.return %5 : tensor -} - -// CHECK-LABEL: func @reduce_row_sum_2d_dynamic - -// CHECK: linalg.fill -// CHECK: gml_st.loop -// CHECK: tensor.extract_slice -// CHECK-SAME: tensor to tensor<4x4xf32> -// CHECK: tensor.extract_slice -// CHECK-SAME: tensor<4xf32> - -// CHECK: gml_st.loop -// CHECK: tensor.extract_slice -// CHECK-SAME: tensor to tensor -// CHECK: tensor.extract_slice -// CHECK-SAME: tensor to tensor - -// CHECK: gml_st.loop -// CHECK: tensor.extract_slice -// CHECK-SAME: tensor to tensor -// CHECK: tensor.extract_slice -// CHECK-SAME: tensor to tensor - -// ----- - -func.func @matmul(%arg0: tensor, %arg1: tensor) -> tensor { - %c2 = arith.constant 2 : index - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c8 = arith.constant 8 : index - %c4 = arith.constant 4 : index - %cst = arith.constant 0.000000e+00 : f32 - %dim = tensor.dim %arg0, %c0 : tensor - %dim_0 = tensor.dim %arg1, %c1 : tensor - %0 = tensor.empty(%dim, %dim_0) : tensor - %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor) -> tensor - %dim_1 = tensor.dim %arg0, %c0 : tensor - %dim_2 = tensor.dim %arg0, %c1 : tensor - %dim_3 = tensor.dim %arg1, %c1 : tensor - %2 = gml_st.parallel (%arg2, %arg3) = (%c0, %c0) to (%dim_1, %dim_3) step (%c8, %c4) { - %3 = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 8)>(%arg2)[%dim_1] - %4 = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 4)>(%arg3)[%dim_3] - %5 = gml_st.tile [%arg2, 0] [%3, %dim_2] [1, 1] : !gml_st.tile - %6 = gml_st.materialize %arg0[%5] : tensor[!gml_st.tile] to tensor - %7 = gml_st.tile [0, %arg3] [%dim_2, %4] [1, 1] : !gml_st.tile - %8 = gml_st.materialize %arg1[%7] : tensor[!gml_st.tile] to tensor - %9 = gml_st.tile [%arg2, %arg3] [%3, %4] [1, 1] : !gml_st.tile - %10 = gml_st.materialize %1[%9] : tensor[!gml_st.tile] to tensor - %dim_4 = tensor.dim %6, %c0 : tensor - %dim_5 = tensor.dim %6, %c1 : tensor - %dim_6 = tensor.dim %8, %c1 : tensor - %11 = gml_st.for (%arg4) = (%c0) to (%dim_5) step (%c2) outs (%arg5 = %10: tensor) { - %12 = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 2)>(%arg4)[%dim_5] - %13 = gml_st.tile [0, %arg4] [%dim_4, %12] [1, 1] : !gml_st.tile - %14 = gml_st.materialize %6[%13] : tensor[!gml_st.tile] to tensor - %15 = gml_st.tile [%arg4, 0] [%12, %dim_6] [1, 1] : !gml_st.tile - %16 = gml_st.materialize %8[%15] : tensor[!gml_st.tile] to tensor - %17 = gml_st.tile [0, 0] [%dim_4, %dim_6] [1, 1] : !gml_st.tile - %18 = gml_st.materialize %arg5[%17] : tensor[!gml_st.tile] to tensor - %19 = linalg.matmul ins(%14, %16 : tensor, tensor) outs(%18 : tensor) -> tensor - gml_st.set_yield %19 into %arg5[%17] : tensor into tensor[!gml_st.tile] - } : tensor - gml_st.set_yield %11 into %1[%9] : tensor into tensor[!gml_st.tile] - } : tensor - return %2 : tensor -} - -// CHECK-DAG: #[[$MAP_MAIN_PAR_I:.*]] = affine_map<()[s0] -> ((s0 floordiv 8) * 8)> -// CHECK-DAG: 
#[[$MAP_MAIN_PAR_J:.*]] = affine_map<()[s0] -> ((s0 floordiv 4) * 4)> -// CHECK-DAG: #[[$MAP_MAIN_FOR:.*]] = affine_map<()[s0] -> ((s0 floordiv 2) * 2)> -// CHECK-DAG: #[[$MAP_REM_PAR1:.*]] = affine_map<(d0)[s0] -> (-d0 + s0)> -// CHECK-DAG: #[[$MAP_REM_PAR2:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 8)> -// CHECK-DAG: #[[$MAP_REM_FOR:.*]] = affine_map<(d0, d1) -> (-d0 + d1)> - -// CHECK-LABEL: func @matmul( -// CHECK-SAME: %[[LHS:.*]]: tensor, -// CHECK-SAME: %[[RHS:.*]]: tensor) -> tensor - -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[LHS_ROW:.*]] = tensor.dim %[[LHS]], %[[C0]] -// CHECK-DAG: %[[RHS_COL:.*]] = tensor.dim %[[RHS]], %[[C1]] -// CHECK-DAG: %[[FILL:.*]] = linalg.fill -// CHECK-DAG: %[[MAIN_PAR_I_UB:.*]] = affine.apply #[[$MAP_MAIN_PAR_I]]()[%[[LHS_ROW]]] -// CHECK-DAG: %[[MAIN_PAR_J_UB:.*]] = affine.apply #[[$MAP_MAIN_PAR_J]]()[%[[RHS_COL]]] - -// CHECK: %[[MAIN_PAR:.*]] = gml_st.parallel ( -// CHECK-SAME: %[[MAIN_PAR_I:.*]], %[[MAIN_PAR_J:.*]]) = (%[[C0]], %[[C0]]) -// CHECK-SAME: to (%[[MAIN_PAR_I_UB]], %[[MAIN_PAR_J_UB]]) step (%[[C8]], %[[C4]]) -// CHECK: %[[MAIN_PAR_OUT_TILE:.*]] = gml_st.tile [%[[MAIN_PAR_I]], %[[MAIN_PAR_J]]] -// CHECK: %[[MAIN_PAR_OUT_SLICE:.*]] = gml_st.materialize %[[FILL]][%[[MAIN_PAR_OUT_TILE]]] : -// CHECK: %[[MAIN_PAR_FOR_K_UB:.*]] = affine.apply #[[$MAP_MAIN_FOR]]()[%[[MAIN_PAR_LHS_SUB_COL:.*]]] - -// CHECK: %[[MAIN_PAR_MAIN_FOR:.*]] = gml_st.for ( -// CHECK-SAME: %[[MAIN_PAR_MAIN_FOR_K:.*]]) = (%[[C0]]) -// CHECK-SAME: to (%[[MAIN_PAR_FOR_K_UB]]) step (%[[C2]]) -// CHECK-SAME: outs (%[[MAIN_PAR_MAIN_FOR_OUT:.*]] = %[[MAIN_PAR_OUT_SLICE]]: tensor<8x4xf32>) -// CHECK: %[[MAIN_PAR_MAIN_FOR_OUT_SLICE:.*]] = gml_st.materialize %[[MAIN_PAR_MAIN_FOR_OUT]][%[[MAIN_PAR_MAIN_FOR_OUT_TILE:.*]]] : -// CHECK: %[[MAIN_PAR_MAIN_FOR_MATMUL:.*]] = linalg.matmul ins({{.*}}) outs(%[[MAIN_PAR_MAIN_FOR_OUT_SLICE]] -// CHECK-NEXT: gml_st.set_yield %[[MAIN_PAR_MAIN_FOR_MATMUL]] into %[[MAIN_PAR_MAIN_FOR_OUT]][%[[MAIN_PAR_MAIN_FOR_OUT_TILE]]] : tensor<8x4xf32> into tensor<8x4xf32>[!gml_st.tile<8x4>] - -// CHECK: %[[MAIN_PAR_REM_FOR:.*]] = gml_st.for ( -// CHECK-SAME: %[[MAIN_PAR_REM_FOR_K:.*]]) = (%[[MAIN_PAR_FOR_K_UB]]) -// CHECK-SAME: to (%[[MAIN_PAR_LHS_SUB_COL]]) step (%[[C2]]) -// CHECK-SAME: outs (%[[MAIN_PAR_REM_FOR_OUT:.*]] = %[[MAIN_PAR_MAIN_FOR]]: tensor<8x4xf32>) -// CHECK: %[[MAIN_PAR_REM_FOR_OUT_SLICE:.*]] = gml_st.materialize %[[MAIN_PAR_REM_FOR_OUT]][%[[MAIN_PAR_REM_FOR_OUT_TILE:.*]]] : -// CHECK: %[[MAIN_PAR_REM_FOR_MATMUL:.*]] = linalg.matmul ins({{.*}}) outs(%[[MAIN_PAR_REM_FOR_OUT_SLICE]] -// CHECK-NEXT: gml_st.set_yield %[[MAIN_PAR_REM_FOR_MATMUL]] into %[[MAIN_PAR_REM_FOR_OUT]][%[[MAIN_PAR_REM_FOR_OUT_TILE]]] : tensor<8x4xf32> into tensor<8x4xf32>[!gml_st.tile<8x4>] - -// CHECK: gml_st.set_yield %[[MAIN_PAR_REM_FOR]] into %[[FILL]][%[[MAIN_PAR_OUT_TILE]]] - -// CHECK: %[[REM_PAR_LHS_ROW:.*]] = gml_st.parallel ( -// CHECK-SAME: %[[REM_PAR_LHS_ROW_I:.*]], %[[REM_PAR_LHS_ROW_J:.*]]) = (%[[MAIN_PAR_I_UB]], %[[C0]]) -// CHECK-SAME: to (%[[LHS_ROW]], %[[MAIN_PAR_J_UB]]) step (%[[C8]], %[[C4]]) -// CHECK: %[[REM_PAR_LHS_ROW_OUT_TILE:.*]] = gml_st.tile [%[[REM_PAR_LHS_ROW_I]], %[[REM_PAR_LHS_ROW_J]]] -// CHECK: %[[REM_PAR_LHS_ROW_OUT_SLICE:.*]] = gml_st.materialize %[[MAIN_PAR]][%[[REM_PAR_LHS_ROW_OUT_TILE]]] : -// CHECK: 
%[[REM_PAR_LHS_ROW_FOR_K_UB:.*]] = affine.apply #[[$MAP_MAIN_FOR]]()[%[[REM_PAR_LHS_ROW_LHS_SUB_COL:.*]]] - -// CHECK: %[[REM_PAR_LHS_ROW_MAIN_FOR:.*]] = gml_st.for ( -// CHECK-SAME: %[[REM_PAR_LHS_ROW_MAIN_FOR_K:.*]]) = (%[[C0]]) -// CHECK-SAME: to (%[[REM_PAR_LHS_ROW_FOR_K_UB]]) step (%[[C2]]) -// CHECK-SAME: outs (%[[REM_PAR_LHS_ROW_MAIN_FOR_OUT:.*]] = %[[REM_PAR_LHS_ROW_OUT_SLICE]]: -// CHECK: %[[REM_PAR_LHS_ROW_MAIN_FOR_OUT_SLICE:.*]] = gml_st.materialize %[[REM_PAR_LHS_ROW_MAIN_FOR_OUT]][%[[REM_PAR_LHS_ROW_MAIN_FOR_OUT_TILE:.*]]] : -// CHECK: %[[REM_PAR_LHS_ROW_MAIN_FOR_MATMUL:.*]] = linalg.matmul ins({{.*}}) outs(%[[REM_PAR_LHS_ROW_MAIN_FOR_OUT_SLICE]] -// CHECK-NEXT: gml_st.set_yield %[[REM_PAR_LHS_ROW_MAIN_FOR_MATMUL]] into %[[REM_PAR_LHS_ROW_MAIN_FOR_OUT]][%[[REM_PAR_LHS_ROW_MAIN_FOR_OUT_TILE]]] - -// CHECK: %[[REM_PAR_LHS_ROW_REM_FOR:.*]] = gml_st.for ( -// CHECK-SAME: %[[REM_PAR_LHS_ROW_REM_FOR_K:.*]]) = (%[[REM_PAR_LHS_ROW_FOR_K_UB]]) -// CHECK-SAME: to (%[[REM_PAR_LHS_ROW_LHS_SUB_COL]]) step (%[[C2]]) -// CHECK-SAME: outs (%[[REM_PAR_LHS_ROW_REM_FOR_OUT:.*]] = %[[REM_PAR_LHS_ROW_MAIN_FOR]]: -// CHECK: %[[REM_PAR_LHS_ROW_REM_FOR_OUT_SLICE:.*]] = gml_st.materialize %[[REM_PAR_LHS_ROW_REM_FOR_OUT]][%[[REM_PAR_LHS_ROW_REM_FOR_OUT_TILE:.*]]] : -// CHECK: %[[REM_PAR_LHS_ROW_REM_FOR_MATMUL:.*]] = linalg.matmul ins({{.*}}) outs(%[[REM_PAR_LHS_ROW_REM_FOR_OUT_SLICE]] -// CHECK-NEXT: gml_st.set_yield %[[REM_PAR_LHS_ROW_REM_FOR_MATMUL]] into %[[REM_PAR_LHS_ROW_REM_FOR_OUT]][%[[REM_PAR_LHS_ROW_REM_FOR_OUT_TILE]]] - -// CHECK: gml_st.set_yield %[[REM_PAR_LHS_ROW_REM_FOR]] into %[[MAIN_PAR]][%[[REM_PAR_LHS_ROW_OUT_TILE]]] - -// CHECK: %[[REM_PAR_RHS_COL:.*]] = gml_st.parallel ( -// CHECK-SAME: %[[REM_PAR_RHS_COL_I:.*]], %[[REM_PAR_RHS_COL_J:.*]]) = (%[[C0]], %[[MAIN_PAR_J_UB]]) -// CHECK-SAME: to (%[[LHS_ROW]], %[[RHS_COL]]) step (%[[C8]], %[[C4]]) -// CHECK: %[[REM_PAR_RHS_COL_OUT_SLICE:.*]] = gml_st.materialize %[[REM_PAR_LHS_ROW]][%[[REM_PAR_RHS_COL_OUT_TILE:.*]]] : -// CHECK: %[[REM_PAR_RHS_COL_FOR_K_UB:.*]] = affine.apply #[[$MAP_MAIN_FOR]]()[%[[REM_PAR_RHS_COL_LHS_SUB_COL:.*]]] - -// CHECK: %[[REM_PAR_RHS_COL_MAIN_FOR:.*]] = gml_st.for ( -// CHECK-SAME: %[[REM_PAR_RHS_COL_MAIN_FOR_K:.*]]) = (%[[C0]]) -// CHECK-SAME: to (%[[REM_PAR_RHS_COL_FOR_K_UB]]) step (%[[C2]]) -// CHECK-SAME: outs (%[[REM_PAR_RHS_COL_MAIN_FOR_OUT:.*]] = %[[REM_PAR_RHS_COL_OUT_SLICE]]: -// CHECK: %[[REM_PAR_RHS_COL_MAIN_FOR_OUT_SLICE:.*]] = gml_st.materialize %[[REM_PAR_RHS_COL_MAIN_FOR_OUT]][%[[REM_PAR_RHS_COL_MAIN_FOR_OUT_TILE:.*]]] : -// CHECK: %[[REM_PAR_RHS_COL_MAIN_FOR_MATMUL:.*]] = linalg.matmul ins({{.*}}) outs(%[[REM_PAR_RHS_COL_MAIN_FOR_OUT_SLICE]] -// CHECK-NEXT: gml_st.set_yield %[[REM_PAR_RHS_COL_MAIN_FOR_MATMUL]] into %[[REM_PAR_RHS_COL_MAIN_FOR_OUT]][%[[REM_PAR_RHS_COL_MAIN_FOR_OUT_TILE]]] - -// CHECK: %[[REM_PAR_RHS_COL_REM_FOR:.*]] = gml_st.for ( -// CHECK-SAME: %[[REM_PAR_RHS_COL_REM_FOR_K:.*]]) = (%[[REM_PAR_RHS_COL_FOR_K_UB]]) -// CHECK-SAME: to (%[[REM_PAR_RHS_COL_LHS_SUB_COL]]) step (%[[C2]]) -// CHECK-SAME: outs (%[[REM_PAR_RHS_COL_REM_FOR_OUT:.*]] = %[[REM_PAR_RHS_COL_MAIN_FOR]]: -// CHECK: %[[REM_PAR_RHS_COL_REM_FOR_OUT_SLICE:.*]] = gml_st.materialize %[[REM_PAR_RHS_COL_REM_FOR_OUT]][%[[REM_PAR_RHS_COL_REM_FOR_OUT_TILE:.*]]] : -// CHECK: %[[REM_PAR_RHS_COL_REM_FOR_MATMUL:.*]] = linalg.matmul ins({{.*}}) outs(%[[REM_PAR_RHS_COL_REM_FOR_OUT_SLICE]] -// CHECK-NEXT: gml_st.set_yield %[[REM_PAR_RHS_COL_REM_FOR_MATMUL]] into 
%[[REM_PAR_RHS_COL_REM_FOR_OUT]][%[[REM_PAR_RHS_COL_REM_FOR_OUT_TILE]]] - -// CHECK: gml_st.set_yield %[[REM_PAR_RHS_COL_REM_FOR]] into %[[REM_PAR_LHS_ROW]][%[[REM_PAR_RHS_COL_OUT_TILE]]] diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_pipeline.mlir b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_pipeline.mlir index 0a3a512a5a8..a01636a4d9d 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_pipeline.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_pipeline.mlir @@ -38,8 +38,8 @@ func.func @sigmoid_dynamic_dim(%arg0: tensor) -> tensor { // ----- -// CHECK: #map{{[0-9]*}} = affine_map<(d0) -> ()> -// CHECK: #map{{[0-9]*}} = affine_map<(d0) -> (d0)> +// CHECK-DAG: #map{{[0-9]*}} = affine_map<(d0) -> ()> +// CHECK-DAG: #map{{[0-9]*}} = affine_map<(d0) -> (d0)> // CHECK-LABEL: @add_scalar_with_vec func.func @add_scalar_with_vec(%arg0: tensor, @@ -271,8 +271,8 @@ func.func @cast_sub(%arg0: tensor, %arg1: tensor) // ----- -// CHECK: #map{{[0-9]*}} = affine_map<(d0, d1) -> (d1, d0)> -// CHECK: #map{{[0-9]*}} = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-DAG: #map{{[0-9]*}} = affine_map<(d0, d1) -> (d1, d0)> +// CHECK-DAG: #map{{[0-9]*}} = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: @tf_transpose_const_perm func.func @tf_transpose_const_perm(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> { @@ -289,8 +289,8 @@ func.func @tf_transpose_const_perm(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> { // ----- -// CHECK: #map{{[0-9]*}} = affine_map<(d0, d1, d2) -> (d2, d0, d1)> -// CHECK: #map{{[0-9]*}} = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #map{{[0-9]*}} = affine_map<(d0, d1, d2) -> (d2, d0, d1)> +// CHECK-DAG: #map{{[0-9]*}} = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK-LABEL: @tf_transpose_after_transpose func.func @tf_transpose_after_transpose(%arg0: tensor) diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_pipeline_one_shot.mlir b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_pipeline_one_shot.mlir deleted file mode 100644 index 0a3a512a5a8..00000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_pipeline_one_shot.mlir +++ /dev/null @@ -1,440 +0,0 @@ -// RUN: tf-tfrt-opt -split-input-file -tf-jitrt-pipeline %s | FileCheck %s - -// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)> - -// CHECK-LABEL: @tanh_lower_and_fuse -// CHECK-SAME: %[[ARG:.*]]: memref -func.func @tanh_lower_and_fuse(%arg0: tensor) -> tensor { - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[DIM:.*]] = memref.dim %[[ARG]], %[[C0]] - // CHECK: %[[MEMREF:.*]] = memref.alloc(%[[DIM]]) {{.*}} : memref - - // CHECK: linalg.generic - // CHECK-SAME: indexing_maps = [#map, #map] - // CHECK-SAME: iterator_types = ["parallel", "parallel"] - // CHECK-SAME: ins(%[[ARG]] : memref) - // CHECK-SAME: outs(%[[MEMREF]] : memref) - // CHECK: tanh - // CHECK-NEXT: tanh - - // CHECK: return %[[MEMREF]] - %0 = "tf.Tanh"(%arg0): (tensor) -> tensor - %1 = "tf.Tanh"(%0): (tensor) -> tensor - func.return %1 : tensor -} - -// ----- - -// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)> - -// CHECK-LABEL: @sigmoid_dynamic_dim -func.func @sigmoid_dynamic_dim(%arg0: tensor) -> tensor { - // CHECK: linalg.generic - // CHECK-SAME: indexing_maps = [#map, #map] - // CHECK-SAME: iterator_types = ["parallel", "parallel"] - %0 = "tf.Sigmoid"(%arg0) : (tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK: #map{{[0-9]*}} = affine_map<(d0) -> ()> -// CHECK: #map{{[0-9]*}} = affine_map<(d0) -> (d0)> - -// CHECK-LABEL: @add_scalar_with_vec -func.func 
@add_scalar_with_vec(%arg0: tensor, - %arg1: tensor) -> tensor { - // CHECK: linalg.generic - // CHECK-NOT: linalg.generic - %0 = "tf.AddV2"(%arg0, %arg1): (tensor, tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK: #map = affine_map<(d0) -> (d0)> - -// CHECK-LABEL: @add_vec_vec -func.func @add_vec_vec( - %arg0: tensor {rt.symbolic_shape = dense<-2>: tensor<1xi64>}, - %arg1: tensor {rt.symbolic_shape = dense<-2>: tensor<1xi64>} -) -> tensor { - // CHECK-NOT: memref.reinterpret_cast - // CHECK: linalg.generic - // CHECK-NOT: linalg.generic - %0 = "tf.AddV2"(%arg0, %arg1): (tensor, tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK: #map = affine_map<(d0) -> (d0)> - -// CHECK-LABEL: @add_vec_vec_vec -func.func @add_vec_vec_vec( - %arg0: tensor {rt.symbolic_shape = dense<-2>: tensor<1xi64>}, - %arg1: tensor {rt.symbolic_shape = dense<-2>: tensor<1xi64>}, - %arg2: tensor {rt.symbolic_shape = dense<-2>: tensor<1xi64>} -) -> tensor { - // CHECK-NOT: memref.reinterpret_cast - // CHECK: linalg.generic - // CHECK: addf - // CHECK: addf - // CHECK-NOT: linalg.generic - %0 = "tf.AddV2"(%arg0, %arg1): (tensor, tensor) -> tensor - %1 = "tf.AddV2"(%0, %arg2): (tensor, tensor) -> tensor - func.return %1 : tensor -} - -// ----- - -// Verify that symbolic shape optimization can move all the broadcasts up, and -// progressively remove all shape constraints and replace mhlo broadcasts with -// linalg.generic operations that in the end all are fused together. - -// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, 0)> -// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2)> -// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> - -// CHECK: compute_with_bcast -func.func @compute_with_bcast( - %arg0: tensor<1x?x1xf32> - {rt.symbolic_shape = dense<[1, -2, 1]> : tensor<3xi64>}, - %arg1: tensor<512xf32>, - %arg2: tensor<1x?x512xf32> - {rt.symbolic_shape = dense<[1, -2, 512]> : tensor<3xi64>}, - %arg3: tensor<1x?x1xf32> - {rt.symbolic_shape = dense<[1, -2, 1]> : tensor<3xi64>}, - %arg4: tensor<512xf32> -) -> tensor { - // CHECK-NOT: memref.reinterpret_cast - // CHECK: linalg.generic - // CHECK: addf - // CHECK-NEXT: math.rsqrt - // CHECK-NEXT: mulf - // CHECK-NEXT: mulf - // CHECK-NEXT: subf - // CHECK-NEXT: mulf - // CHECK-NEXT: addf - // CHECK-NEXT: linalg.yield - // CHECK-NOT: linalg.generic - %c = "tf.Const"() {value = dense<9.99999996E-13> - : tensor} : () -> tensor - %0 = "tf.AddV2"(%arg0, %c) - : (tensor<1x?x1xf32>, tensor) -> tensor - %1 = "tf.Rsqrt"(%0) - : (tensor) -> tensor - %2 = "tf.Mul"(%1, %arg1) - : (tensor, tensor<512xf32>) -> tensor - %3 = "tf.Mul"(%2, %arg2) - : (tensor, tensor<1x?x512xf32>) -> tensor - %4 = "tf.Mul"(%2, %arg3) - : (tensor, tensor<1x?x1xf32>) -> tensor - %5 = "tf.Sub"(%arg4, %4) - : (tensor<512xf32>, tensor) -> tensor - %6 = "tf.AddV2"(%3, %5) - : (tensor, tensor) -> tensor - func.return %6 : tensor -} - -// ----- - -// CHECK: add_vec_vec_vec_vec -func.func @add_vec_vec_vec_vec( - %arg0: tensor {rt.symbolic_shape = dense<-2>: tensor<1xi64>}, - %arg1: tensor {rt.symbolic_shape = dense<-2>: tensor<1xi64>}, - %arg2: tensor {rt.symbolic_shape = dense<-2>: tensor<1xi64>}, - %arg3: tensor {rt.symbolic_shape = dense<-2>: tensor<1xi64>} -) -> tensor { - // CHECK-NOT: memref.reinterpret_cast - // CHECK: linalg.generic - // CHECK: addf - // CHECK: addf - // CHECK: addf - // CHECK-NOT: linalg.generic - %0 = "tf.AddV2"(%arg0, %arg1): (tensor, tensor) -> tensor - %1 = "tf.AddV2"(%0, %arg2): (tensor, tensor) -> tensor 
- %2 = "tf.AddV2"(%1, %arg3): (tensor, tensor) -> tensor - func.return %2 : tensor -} - -// ----- - -// CHECK: add_vec_tensor_tensor -func.func @add_vec_tensor_tensor( - %arg0: tensor<512xf32>, - %arg1: tensor<1x?x512xf32> - {rt.symbolic_shape = dense<[1, -2, 512]> : tensor<3xi64>}, - %arg2: tensor<1x?x512xf32> - {rt.symbolic_shape = dense<[1, -2, 512]> : tensor<3xi64>} -) -> tensor<1x?x512xf32> { - // CHECK-NOT: memref.reinterpret_cast - // CHECK: linalg.generic - // CHECK: addf - // CHECK: addf - // CHECK-NOT: linalg.generic - %0 = "tf.AddV2"(%arg0, %arg1) - : (tensor<512xf32>, tensor<1x?x512xf32>) -> tensor<1x?x512xf32> - %1 = "tf.AddV2"(%arg2, %0) - : (tensor<1x?x512xf32>, tensor<1x?x512xf32>) -> tensor<1x?x512xf32> - func.return %1 : tensor<1x?x512xf32> -} - -// ----- - -// CHECK-LABEL: @tf_binary_with_bcast -func.func @tf_binary_with_bcast(%arg0: tensor, - %arg1: tensor) -> tensor { - // CHECK-NOT: shape. - // CHECK: %[[LHS:.*]] = memref.reinterpret_cast - // CHECK: %[[RHS:.*]] = memref.reinterpret_cast - // CHECK: linalg.generic {{.*}} ins(%[[LHS]], %[[RHS]] : - // CHECK: mulf - %0 = "tf.Mul"(%arg0, %arg1) - : (tensor, tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: @tf_binary_with_bcast_and_fusion -// CHECK-SAME: %[[ARG0:.*]]: memref, -// CHECK-SAME: %[[ARG1:.*]]: memref<4xf32>, -// CHECK-SAME: %[[ARG2:.*]]: memref<4xf32> -func.func @tf_binary_with_bcast_and_fusion(%arg0: tensor, - %arg1: tensor<4xf32>, - %arg2: tensor<4xf32>) -> tensor { - // CHECK: linalg.generic - // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] : {{.*}}) - // CHECK: math.log1p - // CHECK-NEXT: subf - // CHECK-NEXT: mulf - // CHECK-NEXT: linalg.yield - // CHECK-NOT: linalg.generic - %0 = "tf.Log1p"(%arg0) - : (tensor) -> tensor - %1 = "tf.Sub"(%0, %arg1) - : (tensor, tensor<4xf32>) -> tensor - %2 = "tf.Mul"(%1, %arg2) - : (tensor, tensor<4xf32>) -> tensor - func.return %2 : tensor -} - -// ----- - -// CHECK: #[[MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> - -// CHECK: tf_binary_with_bcast_symbolic_shapes -func.func @tf_binary_with_bcast_symbolic_shapes( - %arg0: tensor {rt.symbolic_shape = dense<[ -3]>: tensor<1xi64>}, - %arg1: tensor {rt.symbolic_shape = dense<[-2,-3]>: tensor<2xi64>}, - %arg2: tensor {rt.symbolic_shape = dense<[-2,-3]>: tensor<2xi64>}, - %arg3: tensor {rt.symbolic_shape = dense<[-2,-3]>: tensor<2xi64>} -) -> tensor { - // CHECK-NOT: memref.reinterpret_cast - // CHECK: linalg.generic - // CHECK: log1p - // CHECK: addf - // CHECK: addf - // CHECK: addf - // CHECK-NOT: linalg.generic - %0 = "tf.Log1p"(%arg0) - : (tensor) -> tensor - %1 = "tf.AddV2"(%0, %arg1) - : (tensor, tensor) -> tensor - %2 = "tf.AddV2"(%1, %arg2) - : (tensor, tensor) -> tensor - %3 = "tf.AddV2"(%2, %arg3) - : (tensor, tensor) -> tensor - func.return %3 : tensor -} - -// ----- - -// CHECK-LABEL: @cast_sub -func.func @cast_sub(%arg0: tensor, %arg1: tensor) - -> tensor { - // CHECK: linalg.generic - // CHECK-SAME: outs(%[[RESULT_BUF:.*]] : memref) - // CHECK-SAME: { - // CHECK: ^bb0(%[[LHS:.*]]: f16, %[[RHS:.*]]: i16, %{{.*}}: f16): - // CHECK: %[[RHS_CASTED:.*]] = arith.sitofp %[[RHS]] : i16 to f16 - // CHECK: %[[RESULT:.*]] = arith.subf %[[LHS]], %[[RHS_CASTED]] : f16 - // CHECK: linalg.yield %[[RESULT]] : f16 - // CHECK: } - // CHECK: return %[[RESULT_BUF]] : memref - %0 = "tf.Cast"(%arg0) : (tensor) -> tensor - %1 = "tf.Sub"(%arg1, %0) : (tensor, tensor) - -> tensor - func.return %1 : tensor -} - -// ----- - -// CHECK: #map{{[0-9]*}} = affine_map<(d0, d1) -> (d1, d0)> -// CHECK: #map{{[0-9]*}} 
= affine_map<(d0, d1) -> (d0, d1)> - -// CHECK-LABEL: @tf_transpose_const_perm -func.func @tf_transpose_const_perm(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> { - // CHECK: %[[OUT:.*]] = memref.alloc() {{.*}} : memref<3x2xf32> - // CHECK: linalg.generic {indexing_maps = [#map{{[0-9]*}}, #map{{[0-9]*}}] - // CHECK-SAME: ins(%arg0 : memref<2x3xf32>) - // CHECK-SAME: outs(%[[OUT]] : memref<3x2xf32>) - %0 = "tf.Const"() { value = dense<[1, 0]> : tensor<2xi32> } - : () -> tensor<2xi32> - %1 = "tf.Transpose"(%arg0, %0) - : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> - func.return %1 : tensor<3x2xf32> -} - -// ----- - -// CHECK: #map{{[0-9]*}} = affine_map<(d0, d1, d2) -> (d2, d0, d1)> -// CHECK: #map{{[0-9]*}} = affine_map<(d0, d1, d2) -> (d0, d1, d2)> - -// CHECK-LABEL: @tf_transpose_after_transpose -func.func @tf_transpose_after_transpose(%arg0: tensor) - -> tensor { - // CHECK: %[[OUT:.*]] = memref.alloc - // CHECK: linalg.generic {indexing_maps = [#map{{[0-9]*}}, #map{{[0-9]*}}] - // CHECK-SAME: ins(%arg0 : memref) - // CHECK-SAME: outs(%[[OUT]] : memref) - // CHECK-NOT: linalg.generic - %0 = "tf.Const"() { value = dense<[0, 2, 1]> : tensor<3xi32> } - : () -> tensor<3xi32> - %1 = "tf.Const"() { value = dense<[2, 1, 0]> : tensor<3xi32> } - : () -> tensor<3xi32> - %2 = "tf.Transpose"(%arg0, %0) - : (tensor, tensor<3xi32>) -> tensor - %3 = "tf.Transpose"(%2, %1) - : (tensor, tensor<3xi32>) -> tensor - func.return %3 : tensor -} - -// ----- - -// CHECK-LABEL: @bias_add_and_relu -// CHECK-SAME: %[[ARG0:.*]]: memref -// CHECK-SAME: %[[ARG1:.*]]: memref<32xf32> -func.func @bias_add_and_relu(%arg0: tensor, - %arg1: tensor<32xf32>) -> tensor { - // CHECK: linalg.generic - // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) - // CHECK: addf - // CHECK: maxf - // CHECK-NEXT: linalg.yield - // CHECK-NOT: linalg.generic - %0 = "tf.BiasAdd"(%arg0, %arg1) - : (tensor, tensor<32xf32>) -> tensor - %1 = "tf.Relu"(%0): (tensor) -> tensor - func.return %1 : tensor -} - -// ----- - -// CHECK-LABEL: @sub_sub -func.func @sub_sub(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { - // CHECK: linalg.generic - // CHECK-SAME: outs(%[[RESULT_BUF:.*]] : memref) - // CHECK: ^bb0(%[[A:.*]]: f16, %[[B:.*]]: f16, %[[C:.*]]: f16, %{{.*}}: f16): - // CHECK: %[[TMP:.*]] = arith.subf %[[B]], %[[C]] - // CHECK: %[[RESULT:.*]] = arith.subf %[[A]], %[[TMP]] - // CHECK: linalg.yield %[[RESULT]] - // CHECK: return %[[RESULT_BUF]] : memref - %0 = "tf.Sub"(%arg0, %arg1) : (tensor, tensor) -> tensor - %1 = "tf.Sub"(%arg2, %0) : (tensor, tensor) -> tensor - func.return %1 : tensor -} - -// ----- - -// CHECK-LABEL: @strided_slice_1d_to_0d -func.func @strided_slice_1d_to_0d(%arg0: tensor<3xi32>) -> tensor { - %cst_0 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - %cst_1 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: %[[SUBVIEW:.*]] = memref.subview %arg0[0] [1] [1] - // CHECK-SAME: : memref<3xi32> to memref<1xi32, strided<[1]>> - // CHECK: %[[RET:.*]] = memref.collapse_shape %[[SUBVIEW]] - // CHECK: return %[[RET]] - %0 = "tf.StridedSlice"(%arg0, %cst_1, %cst_0, %cst_0) - { - begin_mask = 0 : i64, - ellipsis_mask = 0 : i64, - end_mask = 0 : i64, - new_axis_mask = 0 : i64, - shrink_axis_mask = 1 : i64 - } : (tensor<3xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) - -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK: memref.global "private" constant @__constant_2xi32 : memref<2xi32> = dense<[0, 1]> -// CHECK-SAME: {alignment = 64 : i64} -// 
CHECK-LABEL: @constant_folding -func.func @constant_folding() -> tensor<2xi32> { - %0 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - // CHECK: %[[CONST:.*]] = memref.get_global @__constant_2xi32 : memref<2xi32> - // CHECK: return %[[CONST]] - %2 = "tf.Pack"(%0, %1) {axis = 0 : i64} - : (tensor, tensor) -> tensor<2xi32> - func.return %2 : tensor<2xi32> -} - -// ----- - -// CHECK-LABEL: @add_floormod_add -func.func @add_floormod_add(%arg0: tensor) -> tensor { - // CHECK: linalg.generic - // CHECK-NOT: linalg.generic - %0 = "tf.AddV2"(%arg0, %arg0) - : (tensor, tensor) -> tensor - %1 = "tf.FloorMod"(%0, %arg0) - : (tensor, tensor) -> tensor - %2 = "tf.AddV2"(%1, %arg0) - : (tensor, tensor) -> tensor - func.return %2 : tensor -} - -// ----- - -// CHECK-LABEL: @min_clip_by_value -func.func @min_clip_by_value(%V__0: tensor) -> tensor { - %dims0 = "tf.Const"() { value = dense<[1, 2]> : tensor<2xi32> }: () -> tensor<2xi32> - %0 = "tf.Min"(%V__0, %dims0) {keep_dims = true} : (tensor, tensor<2xi32>) -> tensor - %1 = "tf.ClipByValue"(%V__0, %0, %V__0) : (tensor, tensor, tensor) -> tensor - func.return %1 : tensor -} - -// ----- - -// CHECK-LABEL: @rint_sq_sub -func.func @rint_sq_sub(%arg0: tensor) -> tensor { - // CHECK: linalg.generic - // CHECK-NOT: linalg.generic - %0 = "tf.Rint"(%arg0) : (tensor) -> tensor - %1 = "tf.Square"(%arg0) : (tensor) -> tensor - %2 = "tf.Sub"(%0, %1) : (tensor, tensor) -> tensor - func.return %2 : tensor -} - -// ----- - -// CHECK-LABEL: @do_not_fuse_if_multiple_uses -func.func @do_not_fuse_if_multiple_uses(%arg0: tensor) - -> (tensor, tensor) { - // CHECK: linalg.generic - // CHECK: math.rsqrt - // CHECK-NEXT: math.rsqrt - // CHECK-NEXT: linalg.yield - %0 = "tf.Rsqrt"(%arg0) : (tensor) -> tensor - %1 = "tf.Rsqrt"(%0) : (tensor) -> tensor - // CHECK: linalg.generic - // CHECK: math.rsqrt - // CHECK-NEXT: linalg.yield - %2 = "tf.Rsqrt"(%1) : (tensor) -> tensor - // CHECK-NOT: linalg.generic - func.return %1, %2 : tensor, tensor -} diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_pipeline_vectorized.mlir b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_pipeline_vectorized.mlir index f5f7ad855e5..66bfff88f31 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_pipeline_vectorized.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_pipeline_vectorized.mlir @@ -9,14 +9,9 @@ func.func @reduce_row_sum_2d_dynamic(%input: tensor) -> tensor { : (tensor, tensor<1xi32>) -> tensor func.return %0 : tensor } -// CHECK: linalg.fill -// CHECK: scf.parallel -// CHECK: scf.for -// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<4xf32> -// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<4xf32> -// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<4xf32> -// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<4xf32> -// CHECK-NOT: arith.addf %{{.*}}, %{{.*}} : vector<4xf32> +// CHECK: scf.parallel +// CHECK: scf.for +// CHECK-COUNT-4: arith.addf %{{.*}}, %{{.*}} : vector<4xf32> // ----- @@ -28,14 +23,9 @@ func.func @reduce_column_sum_2d_dynamic(%input: tensor) -> tensor, tensor<1xi32>) -> tensor func.return %0 : tensor } -// CHECK: linalg.fill -// CHECK: scf.parallel -// CHECK: scf.for -// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<4xf32> -// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<4xf32> -// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<4xf32> -// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<4xf32> -// CHECK-NOT: arith.addf %{{.*}}, %{{.*}} : vector<4xf32> +// CHECK: scf.parallel +// CHECK: scf.for +// 
CHECK-COUNT-4: arith.addf %{{.*}}, %{{.*}} : vector<4xf32> // ----- @@ -47,17 +37,11 @@ func.func @reduce_row_mean_2d_dynamic(%input: tensor) -> tensor : (tensor, tensor<1xi32>) -> tensor func.return %0 : tensor } -// CHECK: linalg.fill -// CHECK: scf.parallel -// CHECK: scf.for -// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<4xf32> -// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<4xf32> -// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<4xf32> -// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<4xf32> -// CHECK-NOT: arith.addf %{{.*}}, %{{.*}} : vector<4xf32> -// CHECK: scf.parallel -// CHECK: vector.broadcast %{{.*}} : f32 to vector<8xf32> -// CHECK-NEXT: arith.divf %{{.*}}, %{{.*}} : vector<8xf32> +// CHECK: scf.parallel +// CHECK: scf.for +// CHECK-COUNT-4: arith.addf %{{.*}}, %{{.*}} : vector<4xf32> +// CHECK: scf.yield +// CHECK: arith.divf %{{.*}}, %{{.*}} : vector<4xf32> // ----- @@ -85,11 +69,7 @@ func.func @reduction_of_cast(%arg0: tensor) -> tensor { : (tensor, tensor<1xi32>) -> tensor func.return %1 : tensor } +// CHECK: scf.parallel +// CHECK: arith.trunci // CHECK: scf.for -// CHECK: arith.trunci %{{.*}} : vector<4x8xi64> to vector<4x8xi32> -// CHECK: arith.muli %{{.*}}, %{{.*}} : vector<8xi32> -// CHECK: vector.reduction -// CHECK: scf.for -// CHECK: linalg.generic -// CHECK: arith.trunci -// CHECK: arith.muli +// CHECK: arith.muli diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_tile_cwise.mlir b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_tile_cwise.mlir deleted file mode 100644 index 412db03b729..00000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_tile_cwise.mlir +++ /dev/null @@ -1,61 +0,0 @@ -// RUN: tf-tfrt-opt -tf-jitrt-tile-cwise %s | FileCheck %s - -#map = affine_map<(d0, d1) -> (d0, d1)> -func.func @tanh_2d(%input: tensor) -> tensor { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %dim0 = tensor.dim %input, %c0 : tensor - %dim1 = tensor.dim %input, %c1 : tensor - %init = tensor.empty(%dim0, %dim1) : tensor - %1 = linalg.generic - {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} - ins(%input : tensor) - outs(%init : tensor) - { - ^bb0(%arg1: f32, %arg2: f32): - %2 = math.tanh %arg1 : f32 - linalg.yield %2 : f32 - } -> tensor - func.return %1 : tensor -} - -// CHECK-LABEL: func @tanh_2d( -// CHECK-SAME: %[[INPUT:.*]]: tensor) -> tensor { -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[STEP:.*]] = arith.constant 8 : index -// CHECK-NOT: tensor.dim -// CHECK-DAG: %[[DIM0:.*]] = tensor.dim %[[INPUT]], %[[C0]] -// CHECK-DAG: %[[DIM1:.*]] = tensor.dim %[[INPUT]], %[[C1]] -// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM0]], %[[DIM1]]) -// CHECK-DAG: %[[DIM0_OUT:.*]] = tensor.dim %[[INPUT]], %[[C0]] -// CHECK-DAG: %[[DIM1_OUT:.*]] = tensor.dim %[[INPUT]], %[[C1]] -// CHECK: %[[OUTPUT:.*]] = gml_st.loop -// CHECK-SAME: (%[[ARG1:.*]], %[[ARG2:.*]]) = (%[[C0]], %[[C0]]) -// CHECK-SAME: to (%[[DIM0_OUT]], %[[DIM1_OUT]]) -// CHECK-SAME: step (%[[C1]], %[[STEP]]) -// CHECK-SAME: ins (%[[IN_TENS:.*]] = %[[INPUT]]: tensor) -// CHECK-SAME: outs (%[[OUT_TENS:.*]] = %[[INIT]]: tensor) { -// CHECK: %[[IN_SLICE:.*]] = tensor.extract_slice -// CHECK-SAME: %[[IN_TENS]]{{\[}}%[[ARG1]], %[[ARG2]]] -// CHECK-SAME: {{\[}}1, %{{.*}}] [1, 1] -// CHECK: %[[OUT_SLICE:.*]] = tensor.extract_slice -// CHECK-SAME: %[[OUT_TENS]]{{\[}}%[[ARG1]], %[[ARG2]]] -// CHECK-SAME: {{\[}}1, %{{.*}}] [1, 1] -// CHECK: %[[VECTOR_RESULT:.*]] = linalg.generic -// CHECK-SAME: 
{indexing_maps = [#map1, #map1], -// CHECK-SAME: iterator_types = ["parallel", "parallel"]} -// CHECK-SAME: ins(%[[IN_SLICE]] : tensor<1x?xf32>) -// CHECK-SAME: outs(%[[OUT_SLICE]] : tensor<1x?xf32>) { -// CHECK-NEXT: ^bb0(%[[SCALAR_INPUT:.*]]: f32, %[[VAL_20:.*]]: f32): -// CHECK-NEXT: %[[TANH_OUT:.*]] = math.tanh %[[SCALAR_INPUT]] : f32 -// CHECK-NEXT: linalg.yield %[[TANH_OUT]] : f32 -// CHECK-NEXT: } -> tensor<1x?xf32> -// CHECK-NEXT: %[[INSERT_RESULT:.*]] = tensor.insert_slice -// CHECK-SAME: %[[VAL_23:.*]] into %[[OUT_TENS]] -// CHECK-SAME: {{\[}}%[[ARG1]], %[[ARG2]]] [1, -// CHECK-SAME: %{{.*}}] [1, 1] -// CHECK-NEXT: gml_st.yield %[[INSERT_RESULT]] : tensor -// CHECK-NEXT: } -// CHECK-NEXT: return %[[FINAL_OUTPUT:.*]] : tensor -// CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_tile_fill.mlir b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_tile_fill.mlir deleted file mode 100644 index 4385001296f..00000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_tile_fill.mlir +++ /dev/null @@ -1,19 +0,0 @@ -// RUN: tf-tfrt-opt -tf-jitrt-tile-fill %s | FileCheck %s - -func.func @fill(%tensor : tensor<64xf32>, %value : f32) -> tensor<64xf32> { - %0 = linalg.fill ins(%value : f32) outs(%tensor : tensor<64xf32>) -> tensor<64xf32> - func.return %0 : tensor<64xf32> -} -// CHECK-LABEL: func @fill( -// CHECK-SAME: %[[TNSR:.*]]: tensor<64xf32>, %[[VAL:.*]]: f32) -// CHECK-DAG: %[[STEP:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK: gml_st.loop (%[[I:.*]]) = (%[[C0]]) to (%[[C64]]) -// CHECK-SAME: step (%[[STEP]]) -// CHECK-SAME: ins (%[[VAL_:.*]] = %[[VAL]]: f32) -// CHECK-SAME: outs (%[[OUT_:.*]] = %[[TNSR]]: tensor<64xf32>) -// CHECK: %[[SLICE_:.*]] = tensor.extract_slice %[[OUT_]][%[[I]]] [8] [1] -// CHECK: %[[FILLED_:.*]] = linalg.fill ins(%[[VAL_]]{{.*}}outs(%[[SLICE_]] -// CHECK: %[[INSERTED_:.*]] = tensor.insert_slice %[[FILLED_]] into %[[OUT_]][%[[I]]] [8] [1] -// CHECK: gml_st.yield %[[INSERTED_:.*]] diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_tile_matmul.mlir b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_tile_matmul.mlir deleted file mode 100644 index 5c24d7b4dc7..00000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_tile_matmul.mlir +++ /dev/null @@ -1,94 +0,0 @@ -// RUN: tf-tfrt-opt %s -split-input-file \ -// RUN: -xla-cpu-transform-matmul="tile-sizes=0,0,0" \ -// RUN: | FileCheck %s --check-prefix=TILE-EMPTY - -// RUN: tf-tfrt-opt %s -split-input-file \ -// RUN: -xla-cpu-transform-matmul="tile-sizes=8,4,2" \ -// RUN: | FileCheck %s - -func.func @matmul(%arg0: tensor, %arg1: tensor) -> tensor { - %c0 = arith.constant 0 : index - %0 = tensor.dim %arg0, %c0 : tensor - %c1 = arith.constant 1 : index - %1 = tensor.dim %arg1, %c1 : tensor - %2 = tensor.empty(%0, %1) : tensor - %cst = arith.constant 0.000000e+00 : f32 - %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor) -> tensor - %4 = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) - outs(%3 : tensor) -> tensor - return %4 : tensor -} - -// TILE-EMPTY-LABEL: func @matmul( -// TILE-EMPTY-SAME: %[[LHS:.*]]: tensor, -// TILE-EMPTY-SAME: %[[RHS:.*]]: tensor) -> tensor - -// TILE-EMPTY-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32 -// TILE-EMPTY-DAG: %[[C0:.*]] = arith.constant 0 : index -// TILE-EMPTY-DAG: %[[C1:.*]] = arith.constant 1 : index - -// TILE-EMPTY: %[[DIM_0:.*]] = tensor.dim %[[LHS]], %[[C0]] : [[TY_2D:.*]] -// TILE-EMPTY: %[[DIM_1:.*]] = 
tensor.dim %[[RHS]], %[[C1]] : [[TY_2D]] -// TILE-EMPTY: %[[INIT:.*]] = tensor.empty(%[[DIM_0]], %[[DIM_1]]) : [[TY_2D]] -// TILE-EMPTY: %[[FILL:.*]] = linalg.fill ins(%[[C0_F32]]{{.*}}outs(%[[INIT]] - -// TILE-EMPTY: %[[MATMUL:.*]] = linalg.matmul -// TILE-EMPTY-SAME: ins(%[[LHS]], %[[RHS]] : [[TY_2D]], [[TY_2D]]) -// TILE-EMPTY-SAME: outs(%[[FILL]] : [[TY_2D]]) - -// ----- - -// CHECK-LABEL: func @matmul( -// CHECK-SAME: %[[LHS:.*]]: tensor, -// CHECK-SAME: %[[RHS:.*]]: tensor) - -// CHECK-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index - -// CHECK: %[[DIM_0:.*]] = tensor.dim %[[LHS]], %[[C0]] : [[TY_2D:.*]] -// CHECK: %[[DIM_1:.*]] = tensor.dim %[[RHS]], %[[C1]] : [[TY_2D]] -// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM_0]], %[[DIM_1]]) : [[TY_2D]] -// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[C0_F32]]{{.*}}outs(%[[INIT]] -// CHECK: %[[LHS_ROW:.*]] = tensor.dim %[[LHS]], %[[C0]] : [[TY_2D]] -// CHECK: %[[LHS_COL:.*]] = tensor.dim %[[LHS]], %[[C1]] : [[TY_2D]] -// CHECK: %[[RHS_COL:.*]] = tensor.dim %[[RHS]], %[[C1]] : [[TY_2D]] - -// CHECK: gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) -// CHECK-SAME: to (%[[LHS_ROW]], %[[RHS_COL]]) step (%[[C8]], %[[C4]]) - -// CHECK: %[[LHS_TILE:.*]] = gml_st.tile [%[[I]], 0] -// CHECK: %[[LHS_SLICE:.*]] = gml_st.materialize %[[LHS]][%[[LHS_TILE]]] - -// CHECK: %[[RHS_TILE:.*]] = gml_st.tile [0, %[[J]]] -// CHECK: %[[RHS_SLICE:.*]] = gml_st.materialize %[[RHS]][%[[RHS_TILE]]] - -// CHECK: %[[OUT_TILE:.*]] = gml_st.tile [%[[I]], %[[J]]] -// CHECK: %[[OUT_SLICE:.*]] = gml_st.materialize %[[FILL]][%[[OUT_TILE]]] - -// CHECK: %[[LHS_SUB_ROW:.*]] = tensor.dim %[[LHS_SLICE]], %[[C0]] : [[TY_2D]] -// CHECK: %[[LHS_SUB_COL:.*]] = tensor.dim %[[LHS_SLICE]], %[[C1]] : [[TY_2D]] -// CHECK: %[[RHS_SUB_COL:.*]] = tensor.dim %[[RHS_SLICE]], %[[C1]] : [[TY_2D]] -// CHECK: %[[FOR:.*]] = gml_st.for (%[[K:.*]]) = (%[[C0]]) -// CHECK-SAME: to (%[[LHS_SUB_COL]]) step (%[[C2]]) -// CHECK-SAME: outs (%[[OUT_SUB_ARG:.*]] = %[[OUT_SLICE]]: [[TY_2D]]) - -// CHECK: %[[LHS_SUB_TILE:.*]] = gml_st.tile [0, %[[K]]] -// CHECK: %[[LHS_SUB_SLICE:.*]] = gml_st.materialize %[[LHS_SLICE]][%[[LHS_SUB_TILE]]] - -// CHECK: %[[RHS_SUB_TILE:.*]] = gml_st.tile [%[[K]], 0] -// CHECK: %[[RHS_SUB_SLICE:.*]] = gml_st.materialize %[[RHS_SLICE]][%[[RHS_SUB_TILE]]] - -// CHECK: %[[OUT_SUB_TILE:.*]] = gml_st.tile [0, 0] [%[[LHS_SUB_ROW]], %[[RHS_SUB_COL]]] -// CHECK: %[[OUT_SUB_SLICE:.*]] = gml_st.materialize %[[OUT_SUB_ARG]][%[[OUT_SUB_TILE]]] - -// CHECK: %[[MATMUL:.*]] = linalg.matmul -// CHECK-SAME: ins(%[[LHS_SUB_SLICE]], %[[RHS_SUB_SLICE]] : [[TY_2D]], [[TY_2D]]) -// CHECK: outs(%[[OUT_SUB_SLICE]] : [[TY_2D]]) - -// CHECK-NEXT: gml_st.set_yield %[[MATMUL]] into %[[OUT_SUB_ARG]][%[[OUT_SUB_TILE]]] - -// CHECK: gml_st.set_yield %[[FOR]] into %[[FILL]][%[[OUT_TILE]]] diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_tile_reduction.mlir b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_tile_reduction.mlir deleted file mode 100644 index e791a49b5c7..00000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_tile_reduction.mlir +++ /dev/null @@ -1,238 +0,0 @@ -// RUN: tf-tfrt-opt %s -split-input-file \ -// RUN: -tf-jitrt-tile-reduction="reduction-2d-tile-sizes=4,4 reduction-vector-size=8 
reduction-1d-tile-size=16" \ -// RUN: | FileCheck %s - -func.func @reduce_row_sum_2d(%lhs: tensor, - %rhs: tensor) -> tensor { - %cst = arith.constant 0.000000e+00 : f32 - %c0 = arith.constant 0 : index - %0 = tensor.dim %lhs, %c0 : tensor - - %init = tensor.empty(%0) : tensor - %fill = linalg.fill ins(%cst : f32) outs(%init : tensor) -> tensor - %sum_of_prod = linalg.generic { - indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0)>], - iterator_types = ["parallel", "reduction"]} - ins(%lhs, %rhs : tensor, tensor) - outs(%fill : tensor) { - ^bb0(%l: f32, %r: f32, %o: f32): - %prod = arith.mulf %l, %r : f32 - %add = arith.addf %prod, %o : f32 - linalg.yield %add : f32 - } -> tensor - func.return %sum_of_prod : tensor -} -// CHECK-LABEL: func @reduce_row_sum_2d( -// CHECK-SAME: %[[LHS:.*]]: tensor, -// CHECK-SAME: %[[RHS:.*]]: tensor) -> tensor - -// CHECK-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index - -// CHECK: %[[DIM_0:.*]] = tensor.dim %[[LHS]], %[[C0]] : [[TY_2D:.*]] -// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM_0]]) : [[TY_1D:.*]] -// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[C0_F32]]{{.*}}outs(%[[INIT]] -// CHECK: %[[DIM_0_:.*]] = tensor.dim %[[LHS]], %[[C0]] : [[TY_2D]] -// CHECK: %[[DIM_1:.*]] = tensor.dim %[[LHS]], %[[C1]] : [[TY_2D]] - -// CHECK: gml_st.loop (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) -// CHECK-SAME: to (%[[DIM_0_]], %[[DIM_1]]) step (%[[C4]], %[[C4]]) -// CHECK-SAME: ins (%[[LHS_:.*]] = %[[LHS]]: [[TY_2D]], -// CHECK-SAME: %[[RHS_:.*]] = %[[RHS]]: [[TY_2D]]) -// CHECK-SAME: outs (%[[OUT_:.*]] = %[[FILL]]: [[TY_1D]]) - -// CHECK: %[[LHS_SUB:.*]] = tensor.extract_slice %[[LHS_]][%[[I]], %[[J]]] -// CHECK: %[[RHS_SUB:.*]] = tensor.extract_slice %[[RHS_]][%[[I]], %[[J]]] -// CHECK: %[[OUT_SUB:.*]] = tensor.extract_slice %[[OUT_]][%[[I]]] - -// CHECK: %[[SUM_SUB:.*]] = linalg.generic -// CHECK-SAME: ins(%[[LHS_SUB]], %[[RHS_SUB]] : [[TY_2D]], [[TY_2D]]) -// CHECK-SAME: outs(%[[OUT_SUB]] : [[TY_1D]]) -// CHECK: mulf -// CHECK: addf -// CHECK-NEXT: linalg.yield - -// CHECK: %[[UPDATE:.*]] = tensor.insert_slice %[[SUM_SUB]] into %[[OUT_]] -// CHECK-NEXT: gml_st.yield %[[UPDATE]] : [[TY_1D]] - -// ----- - -func.func @reduce_row_sum_2d_static(%input: tensor<8x16xf32>) -> tensor<8xf32> { - %cst = arith.constant 0.000000e+00 : f32 - %c0 = arith.constant 0 : index - %0 = tensor.dim %input, %c0 : tensor<8x16xf32> - - %init = tensor.empty() : tensor<8xf32> - %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<8xf32>) -> tensor<8xf32> - %sum = linalg.generic { - indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0)>], - iterator_types = ["parallel", "reduction"]} - ins(%input : tensor<8x16xf32>) - outs(%fill : tensor<8xf32>) { - ^bb0(%in: f32, %out: f32): - %add = arith.addf %in, %out : f32 - linalg.yield %add : f32 - } -> tensor<8xf32> - func.return %sum : tensor<8xf32> -} -// CHECK-LABEL: func @reduce_row_sum_2d_static -// CHECK: gml_st.loop -// CHECK: tensor.insert_slice - -// ----- - -func.func @reduce_column_sum_2d(%input: tensor) -> tensor { - %cst = arith.constant 0.000000e+00 : f32 - %c0 = arith.constant 0 : index - %0 = tensor.dim %input, %c0 : tensor - - %init = tensor.empty(%0) : tensor - %fill = linalg.fill ins(%cst : f32) outs(%init : tensor) -> tensor - %sum = linalg.generic { - indexing_maps = 
[affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d1)>], - iterator_types = ["reduction", "parallel"]} - ins(%input : tensor) - outs(%fill : tensor) { - ^bb0(%in: f32, %out: f32): - %add = arith.addf %in, %out : f32 - linalg.yield %add : f32 - } -> tensor - func.return %sum : tensor -} -// CHECK-LABEL: func @reduce_column_sum_2d -// CHECK-SAME: %[[INPUT:.*]]: tensor) -> tensor - -// CHECK-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index - -// CHECK: %[[DIM_0:.*]] = tensor.dim %[[INPUT]], %[[C0]] : [[TY_2D:.*]] -// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM_0]]) : [[TY_1D:.*]] -// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[C0_F32]]{{.*}}outs(%[[INIT]] -// CHECK: %[[DIM_0_:.*]] = tensor.dim %[[INPUT]], %[[C0]] : [[TY_2D]] -// CHECK: %[[DIM_1:.*]] = tensor.dim %[[INPUT]], %[[C1]] : [[TY_2D]] - -// CHECK: gml_st.loop (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) -// CHECK-SAME: to (%[[DIM_0_]], %[[DIM_1]]) step (%[[C4]], %[[C4]]) -// CHECK-SAME: ins (%[[IN_:.*]] = %[[INPUT]]: [[TY_2D]]) -// CHECK-SAME: outs (%[[OUT_:.*]] = %[[FILL]]: [[TY_1D]]) - -// CHECK: %[[IN_SUB:.*]] = tensor.extract_slice %[[IN_]][%[[I]], %[[J]]] -// CHECK: %[[OUT_SUB:.*]] = tensor.extract_slice %[[OUT_]][%[[J]]] - -// CHECK: %[[SUM_SUB:.*]] = linalg.generic -// CHECK-SAME: ins(%[[IN_SUB]] : [[TY_2D]]) -// CHECK-SAME: outs(%[[OUT_SUB]] : [[TY_1D]]) -// CHECK: addf -// CHECK-NEXT: linalg.yield - -// CHECK: %[[UPDATE:.*]] = tensor.insert_slice %[[ACC:.*]] into %[[OUT_]] -// CHECK: gml_st.yield %[[UPDATE]] : [[TY_1D]] - -// ----- - -func.func @abs(%input: tensor) -> tensor { - %cst = arith.constant 0.000000e+00 : f32 - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %0 = tensor.dim %input, %c0 : tensor - %1 = tensor.dim %input, %c1 : tensor - - %init = tensor.empty(%0, %1) : tensor - %sum = linalg.generic { - indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"]} - ins(%input : tensor) - outs(%init : tensor) { - ^bb0(%in: f32, %out: f32): - %abs = math.absf %in: f32 - linalg.yield %abs : f32 - } -> tensor - func.return %sum : tensor -} -// CHECK-LABEL: func @abs -// CHECK-NOT: gml_st.loop - -// ----- - -func.func @reduce_sum_1d(%lhs: tensor, %rhs: tensor) -> tensor { - %cst = arith.constant 0.000000e+00 : f32 - %c0 = arith.constant 0 : index - %0 = tensor.dim %lhs, %c0 : tensor - - %init = tensor.empty() : tensor - %fill = linalg.fill ins(%cst : f32) outs(%init : tensor) -> tensor - %sum = linalg.generic { - indexing_maps = [affine_map<(d0) -> (d0)>, - affine_map<(d0) -> (d0)>, - affine_map<(d0) -> ()>], - iterator_types = ["reduction"]} - ins(%lhs, %rhs : tensor, tensor) - outs(%fill : tensor) { - ^bb0(%l: f32, %r: f32, %out: f32): - %prod = arith.mulf %l, %r : f32 - %add = arith.addf %prod, %out : f32 - linalg.yield %add : f32 - } -> tensor - func.return %sum : tensor -} - -// CHECK-LABEL: func @reduce_sum_1d( -// CHECK-SAME: %[[LHS:.*]]: tensor, %[[RHS:.*]]: tensor) - // CHECK-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32 - // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index - // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index - - // CHECK: %[[INIT:.*]] = tensor.empty() : tensor - // CHECK: %[[FILL:.*]] = linalg.fill ins(%[[C0_F32]]{{.*}}outs(%[[INIT]] - // CHECK: %[[INPUT_SIZE:.*]] = tensor.dim %[[LHS]], %[[C0]] - - // CHECK: %[[TMP_INIT:.*]] = 
tensor.empty() : tensor<8xf32> - // CHECK: %[[TMP_FILL:.*]] = linalg.fill ins(%[[C0_F32]]{{.*}}outs(%[[TMP_INIT]] - // CHECK: %[[TILABLE_UB:.*]] = affine.apply -// CHECK-SAME: %[[INPUT_SIZE]] - // CHECK: %[[TMP_SUM:.*]] = gml_st.loop (%[[I:.*]]) = (%[[C0]]) -// CHECK-SAME: to (%[[TILABLE_UB]]) step (%[[C16]]) -// CHECK-SAME: ins (%[[LHS_:.*]] = %[[LHS]]: tensor, -// CHECK-SAME: %[[RHS_:.*]] = %[[RHS]]: tensor) -// CHECK-SAME: outs (%[[TMP_INIT_:.*]] = %[[TMP_FILL]]: tensor<8xf32>) - - // CHECK: %[[LHS_SUB:.*]] = tensor.extract_slice %[[LHS_]][%[[I]]] - // CHECK: %[[LHS_RESHAPE:.*]] = tensor.expand_shape %[[LHS_SUB]] -// CHECK-SAME: {{\[\[}}0, 1]] -// CHECK-SAME: : tensor<16xf32> into tensor<2x8xf32> - - // CHECK: %[[RHS_SUB:.*]] = tensor.extract_slice %[[RHS_]][%[[I]]] - // CHECK: %[[RHS_RESHAPE:.*]] = tensor.expand_shape %[[RHS_SUB]] -// CHECK-SAME: {{\[\[}}0, 1]] -// CHECK-SAME: : tensor<16xf32> into tensor<2x8xf32> - - // CHECK: %[[SUM_OF_PROD:.*]] = linalg.generic -// CHECK-SAME: ins(%[[LHS_RESHAPE]], %[[RHS_RESHAPE]] -// CHECK-SAME: tensor<2x8xf32>, tensor<2x8xf32>) -// CHECK-SAME: outs(%[[TMP_INIT_]] : tensor<8xf32>) { - // CHECK: ^bb0(%[[L:.*]]: f32, %[[R:.*]]: f32, %[[O:.*]]: f32): - // CHECK: %[[MUL:.*]] = arith.mulf %[[L]], %[[R]] : f32 - // CHECK: %[[ADD:.*]] = arith.addf %[[MUL]], %[[O]] : f32 - // CHECK: linalg.yield %[[ADD]] : f32 - // CHECK: } -> tensor<8xf32> - // CHECK: gml_st.yield %[[SUM_OF_PROD]] : tensor<8xf32> - // CHECK: } - // CHECK: %[[HORIZONTAL_REDUCE:.*]] = linalg.generic -// CHECK-SAME: ins(%[[TMP_SUM]] : tensor<8xf32>) outs(%[[FILL]] : tensor) -// CHECK-NOT: mulf -// CHECK: addf - - // CHECK: gml_st.loop (%[[K:.*]]) = (%[[TILABLE_UB]]) -// CHECK-SAME: to (%[[INPUT_SIZE]]) step (%[[C16]]) - // CHECK: linalg.generic -// CHECK: mulf -// CHECK: addf diff --git a/tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt/BUILD b/tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt/BUILD index 9b2f0f1719d..8ddeef89b30 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt/BUILD @@ -1,7 +1,10 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") load("//tensorflow:tensorflow.bzl", "if_oss") -package(licenses = ["notice"]) +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) # copybara:uncomment_begin # diff --git a/tensorflow/compiler/mlir/tfrt/tests/remote_run_encapsulate.mlir b/tensorflow/compiler/mlir/tfrt/tests/remote_run_encapsulate.mlir deleted file mode 100644 index 934336a6c05..00000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/remote_run_encapsulate.mlir +++ /dev/null @@ -1,47 +0,0 @@ -// RUN: tf-tfrt-opt -tfrt-dist-remote-run-encapsulate %s | FileCheck %s - - -func.func private @init(%arg0 : !corert.tensorhandle) -> (!tfrt.chain, !corert.tensorhandle {tfrt.device = "/job:worker1/task:0/device:CPU:0"}) attributes {host = "/job:worker1/task:0"} { - %ch0 = tfrt.new.chain - %cpu = corert.get_op_handler %ch0 "cpu" - %arg1 = corert.executeop(%cpu) "tfrt_test.create_dense_tensor"() - { shape = [1, 65536], values = [1.0 : f32] } : 1 - %result = corert.executeop(%cpu) "tf.AddV2"(%arg0, %arg1) : 1 - tfrt.return %ch0, %result : !tfrt.chain, !corert.tensorhandle -} - -func.func private @print(%chain : !tfrt.chain, %tensor_handle : !corert.tensorhandle) -> (!tfrt.chain) attributes {host = "/job:worker1/task:0"} { - %ch2 = "corert.print_tensorhandle"(%tensor_handle, %chain) : (!corert.tensorhandle, !tfrt.chain) -> !tfrt.chain - 
tfrt.return %ch2 : !tfrt.chain -} - -// CHECK-LABEL: func @remote_execute -func.func @remote_execute(%arg0 : !corert.tensorhandle) -> (!tfrt.chain, !tfrt.chain, !corert.tensorhandle) { - %c0 = tfrt.new.chain - // CHECK: %[[CONFIGS:.*]]:2 = tfrt_dist.test_create_configurations : 2 - %configs:2 = tfrt_dist.test_create_configurations : 2 - // CHECK-NEXT: %[[CLIENT_CTX:.*]] = tfrt_dist.test_create_distributed_context %[[CONFIGS]]#0 - %client_context = tfrt_dist.test_create_distributed_context %configs#0 : (!tfrt_dist.dist_context_configuration) -> !tfrt_dist.dist_context - // CHECK-NEXT: %[[WORKER_TASK:.*]] = tfrt_dist.get_task_handle %[[CLIENT_CTX]] {task_name = "/job:worker/task:1"} - %worker_task = tfrt_dist.get_task_handle %client_context {task_name = "/job:worker/task:1"} - // This is the remote invocation of the @print and @init functions, check that - // we correctly serialize and encapsulate them. - // CHECK-NEXT: %[[REGISTER_CHAIN_0:.*]] = tfrt_dist.register_tfrt_function(%[[IN_CHAIN:.*]], %[[CLIENT_CTX]], %[[WORKER_TASK]]) "init" {{.*}}func @init(%[[ARG_0:.*]]: !corert.tensorhandle loc({{.*}}) -> (!tfrt.chain, !corert.tensorhandle {{.*}} - // CHECK-NEXT: %[[SPEC_0:.*]] = tfrt_dist.create_remote_execute_spec - // CHECK-SAME: {output_devices = ["/job:worker1/task:0/device:CPU:0", "/job:worker1/task:0/device:CPU:0"]} - // CHECK-NEXT: %[[OBJECT_ID_0:.*]] = tfrt_dist.get_remote_object_id_from_th %[[ARG_1:.*]] - // CHECK-NEXT: %[[EXEC_CHAIN_0:.*]], %[[RESULTS_0:.*]]:3 = tfrt_dist.remote_execute_th[%[[REGISTER_CHAIN_0]], %[[CLIENT_CTX]], %[[WORKER_TASK]], %[[SPEC_0]], 1] "init"(%[[OBJECT_ID_0]]) : (!tfrt_dist.remote_object_id) -> (!tfrt_dist.remote_object_id, !tfrt_dist.remote_object_id, !corert.tensorhandle) - %execute_chain, %remote_chain, %remote_tensor = tfrt_dist.remote_execute_func [%c0, %client_context, %worker_task] @init(%arg0) : (!corert.tensorhandle) -> (!tfrt.chain, !corert.tensorhandle) - - // CHECK-NEXT: %[[REGISTER_CHAIN_1:.*]] = tfrt_dist.register_tfrt_function(%[[EXEC_CHAIN_0]], %[[CLIENT_CTX]], %[[WORKER_TASK]]) "print" {{.*}}@print(%[[ARG_2:.*]]: !tfrt.chain loc({{.*}}), %[[ARG_3:.*]]: !corert.tensorhandle loc({{.*}})) -> !tfrt.chain {{.*}} - // CHECK-NEXT: %[[SPEC_1:.*]] = tfrt_dist.create_remote_execute_spec - // CHECK-SAME: {output_devices = ["/job:worker1/task:0/device:CPU:0"]} - // CHECK-NEXT: %[[OBJECT_ID_1:.*]] = tfrt_dist.get_remote_object_id_from_th %[[RESULTS_0]]#2 - // CHECK-NEXT: %[[EXEC_CHAIN_1:.*]], %[[RESULTS_1:.*]] = tfrt_dist.remote_execute_th[%[[REGISTER_CHAIN_1]], %[[CLIENT_CTX]], %[[WORKER_TASK]], %[[SPEC_1]], 0] "print"(%[[RESULTS_0]]#0, %[[OBJECT_ID_1]]) : (!tfrt_dist.remote_object_id, !tfrt_dist.remote_object_id) -> !tfrt_dist.remote_object_id - %execute_chain2, %remote_chain2 = tfrt_dist.remote_execute_func[%execute_chain, %client_context, %worker_task] @print(%remote_chain, %remote_tensor) : (!tfrt.chain, !corert.tensorhandle) -> (!tfrt.chain) - - - // CHECK-NEXT: tfrt.return %[[EXEC_CHAIN_0]], %[[EXEC_CHAIN_1]], %[[RESULTS_0]]#2 : !tfrt.chain, !tfrt.chain, !corert.tensorhandle - tfrt.return %execute_chain, %execute_chain2, %remote_tensor : !tfrt.chain, !tfrt.chain, !corert.tensorhandle -} - diff --git a/tensorflow/compiler/mlir/tfrt/tests/saved_model/BUILD b/tensorflow/compiler/mlir/tfrt/tests/saved_model/BUILD index f7620143d00..f736d94ea71 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/saved_model/BUILD +++ b/tensorflow/compiler/mlir/tfrt/tests/saved_model/BUILD @@ -1,12 +1,17 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test") 
-package(licenses = ["notice"]) +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) tf_cc_test( name = "saved_model_test", srcs = ["saved_model_test.cc"], data = [ "testdata/test.mlir", + "testdata/xla_launch.mlir", + "testdata/xla_launch_xla_reduce_window.mlir", ], tags = ["no_oss"], deps = [ diff --git a/tensorflow/compiler/mlir/tfrt/tests/saved_model/saved_model_test.cc b/tensorflow/compiler/mlir/tfrt/tests/saved_model/saved_model_test.cc index e866979f302..995bb242fe3 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/saved_model/saved_model_test.cc +++ b/tensorflow/compiler/mlir/tfrt/tests/saved_model/saved_model_test.cc @@ -101,6 +101,68 @@ TEST(SavedModelTest, CompileToBEF) { TF_ASSERT_OK(ConvertTfMlirToBef(options, module.get(), &bef_buffer)); } +TEST(SavedModelTest, ConvertTfMlirToBefWithXlaFuncExport) { + std::string saved_model_mlir_path = tensorflow::GetDataDependencyFilepath( + "tensorflow/compiler/mlir/tfrt/tests/saved_model/testdata/" + "xla_launch.mlir"); + + mlir::DialectRegistry registry; + mlir::RegisterAllTensorFlowDialects(registry); + mlir::MLIRContext context(registry); + auto module = + mlir::parseSourceFile(saved_model_mlir_path, &context); + ASSERT_TRUE(module); + + tfrt::BefBuffer bef_buffer; + TfrtCompileOptions options; + options.device_target = TfrtDeviceInfraTarget::kGpu; + options.use_bridge_for_gpu = true; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr fallback_state, + tfrt_stub::FallbackState::Create(SessionOptions(), FunctionDefLibrary())); + TF_ASSERT_OK(ConvertTfMlirToBef(options, module.get(), &bef_buffer, + fallback_state.get())); + + // The module contains an XLA function, as well as a while body and a while + // condition within the XLA function. + EXPECT_EQ(fallback_state->process_function_library_runtime() + .GetFunctionLibraryDefinition() + ->num_functions(), + 3); +} + +TEST(SavedModelTest, ConvertTfMlirToBefExportingXlaReduceWindow) { + std::string saved_model_mlir_path = tensorflow::GetDataDependencyFilepath( + "tensorflow/compiler/mlir/tfrt/tests/saved_model/testdata/" + "xla_launch_xla_reduce_window.mlir"); + + mlir::DialectRegistry registry; + mlir::RegisterAllTensorFlowDialects(registry); + mlir::MLIRContext context(registry); + auto module = + mlir::parseSourceFile(saved_model_mlir_path, &context); + ASSERT_TRUE(module); + + tfrt::BefBuffer bef_buffer; + TfrtCompileOptions options; + options.device_target = TfrtDeviceInfraTarget::kGpu; + options.use_bridge_for_gpu = true; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr fallback_state, + tfrt_stub::FallbackState::Create(SessionOptions(), FunctionDefLibrary())); + TF_ASSERT_OK(ConvertTfMlirToBef(options, module.get(), &bef_buffer, + fallback_state.get())); + + // The module contains an XLA function, as well as a sum_reducer function + // referenced by an XlaReduceWindow op. + EXPECT_EQ(fallback_state->process_function_library_runtime() + .GetFunctionLibraryDefinition() + ->num_functions(), + 2); +} + // TODO(b/162442824): Add a SavedModel test that covers the error pass. 
} // namespace diff --git a/tensorflow/compiler/mlir/tfrt/tests/saved_model/testdata/xla_launch.mlir b/tensorflow/compiler/mlir/tfrt/tests/saved_model/testdata/xla_launch.mlir new file mode 100644 index 00000000000..4553a6fe278 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/tests/saved_model/testdata/xla_launch.mlir @@ -0,0 +1,25 @@ + +func.func @while_cond(%arg0: tensor) -> tensor { + %0 = "tf.Const"() {value = dense<9> : tensor} : () -> tensor + %1 = "tf.Less"(%arg0, %0) {} : (tensor, tensor) -> tensor + func.return %1 : tensor +} + +func.func @while_body(%arg0: tensor) -> tensor { + %1 = "tf.AddV2"(%arg0, %arg0) {} : (tensor, tensor) -> tensor + func.return %1 : tensor +} + +func.func private @xla_func_0(%arg0: tensor<1x3xf32>, %arg1: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {tf._XlaMustCompile = true, tf._noinline = true, tf._original_func_name = "should_not_be_used"} { + %1 = "tf.AddV2"(%arg0, %arg1) : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + %2 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %3 = "tf.While"(%2) { cond = @while_cond, body = @while_body, is_stateless = false, parallel_iterations = 1} : (tensor) -> (tensor) + func.return %1 : tensor<1x3xf32> +} + +func.func @main(%arg0: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "input:0", outputs = "output:0"}} { + %0 = "tf.VarHandleOp"() {device = "/device:CPU:0", container = "", shared_name = "variable"} : () -> tensor>> + %1 = "tf.ReadVariableOp"(%0) {device = "/device:CPU:0"} : (tensor>>) -> tensor<1x3xf32> + %2 = "tf.XlaLaunch"(%arg0, %1) {_noinline = true, _xla_compile_device_type = "GPU", device = "/device:GPU:0", function = @xla_func_0, operand_segment_sizes = array} : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + func.return %2 : tensor<1x3xf32> +} diff --git a/tensorflow/compiler/mlir/tfrt/tests/saved_model/testdata/xla_launch_xla_reduce_window.mlir b/tensorflow/compiler/mlir/tfrt/tests/saved_model/testdata/xla_launch_xla_reduce_window.mlir new file mode 100644 index 00000000000..8541250a925 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/tests/saved_model/testdata/xla_launch_xla_reduce_window.mlir @@ -0,0 +1,22 @@ + +func.func private @sum_reducer(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { + %0 = "tf.AddV2"(%arg0, %arg1) {device = ""} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +func.func @xla_func_0(%arg0: tensor<7xf32>, %arg1: tensor) -> tensor<10xf32> { + %cst = "tf.Const"() {value = dense<0> : tensor<1x2xi32>} : () -> tensor<1x2xi32> + %cst_0 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %cst_1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> + %cst_2 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32> + %cst_3 = "tf.Const"() {value = dense<4> : tensor<1xi32>} : () -> tensor<1xi32> + %0 = "tf.XlaReduceWindow"(%arg0, %arg1, %cst_0, %cst_1, %cst_2, %cst_3, %cst) {computation = @sum_reducer} : (tensor<7xf32>, tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<10xf32> + func.return %0 : tensor<10xf32> +} + +func.func @main(%arg0: tensor<7xf32>) -> tensor<10xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "input:0", outputs = "output:0"}} { + %0 = "tf.VarHandleOp"() {device = "/device:CPU:0", container = "", shared_name = "variable"} : () -> tensor>> + %1 = "tf.ReadVariableOp"(%0) {device = "/device:CPU:0"} : (tensor>>) -> 
tensor + %2 = "tf.XlaLaunch"(%arg0, %1) {_noinline = true, _xla_compile_device_type = "GPU", device = "/device:GPU:0", function = @xla_func_0, operand_segment_sizes = array} : (tensor<7xf32>, tensor) -> tensor<10xf32> + func.return %2 : tensor<10xf32> +} diff --git a/tensorflow/compiler/mlir/tfrt/tests/sink_in_invariant_ops.mlir b/tensorflow/compiler/mlir/tfrt/tests/sink_in_invariant_ops.mlir new file mode 100644 index 00000000000..84cb50239eb --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/tests/sink_in_invariant_ops.mlir @@ -0,0 +1,248 @@ +// RUN: tf-tfrt-opt -split-input-file -tfrt-sink-in-invariant-ops %s | FileCheck %s --dump-input=fail --dump-input-filter=all + +module attributes {tf_saved_model.semantics} { + +// Test sinks in var handle op to batch function. + +// CHECK-LABEL: func private @batched_function +// CHECK: arg1 +func.func private @batched_function(%arg0: tensor<1x3xf32>, %arg1: tensor<*x!tf_type.resource>) -> tensor<1x3xf32> + attributes {tf._input_shapes = [#tf_type.shape<1x3>, #tf_type.shape<*>], tf.signature.is_stateful} { + // CHECK: [[handle:%.*]] = "tf.VarHandleOp"() + // CHECK: "tf.ReadVariableOp"([[handle]]) + %0 = "tf.ReadVariableOp"(%arg1) {device = "/device:CPU:0"} : (tensor<*x!tf_type.resource>) -> tensor<1x3xf32> + %1 = "tf.AddV2"(%arg0, %0) {device = "/device:CPU:0"} : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + %2 = "tf.Identity"(%1) {device = "/device:CPU:0"} : (tensor<1x3xf32>) -> tensor<1x3xf32> + func.return %2 : tensor<1x3xf32> +} + +// CHECK-LABEL: func @main +func.func @main(%arg0: tensor<1x3xf32> {tf_saved_model.index_path = ["input"]}) -> (tensor<*xf32> {tf_saved_model.index_path = ["r"]}) + attributes {tf_saved_model.exported_names = ["main"]} { + // CHECK: tf.VarHandleOp + %0 = "tf.VarHandleOp"() {device = "/device:CPU:0", container = "", shared_name = "variable"} : () -> tensor>> + + // CHECK: "tf.BatchFunction"(%arg0, %0) + // CHECK: operand_segment_sizes = array + %1 = "tf.BatchFunction"(%arg0, %0) {allowed_batch_sizes = [6], batch_timeout_micros = 100000 : i64, batching_queue = "", container = "", device = "/device:CPU:0", enable_large_batch_splitting = false, f = @batched_function, max_batch_size = 6 : i64, max_enqueued_batches = 10 : i64, num_batch_threads = 1 : i64, operand_segment_sizes = array, shared_name = "batch/"} : (tensor<1x3xf32>, tensor>>) -> tensor<*xf32> + func.return %1 : tensor<*xf32> +} + +} + +// ----- + +module attributes {tf_saved_model.semantics} { + +// Test sinks in const op to batch function. 
+ +// CHECK-LABEL: func private @batched_function +// CHECK: arg1 +func.func private @batched_function(%arg0: tensor, %arg1: tensor) -> tensor + attributes {tf._input_shapes = [#tf_type.shape<1x3>, #tf_type.shape<*>], tf.signature.is_stateful} { + // CHECK: tf.Const + %1 = "tf.AddV2"(%arg0, %arg1) {device = "/device:CPU:0"} : (tensor, tensor) -> tensor + %2 = "tf.Identity"(%1) {device = "/device:CPU:0"} : (tensor) -> tensor + func.return %2 : tensor +} + +// CHECK-LABEL: func @main +func.func @main(%arg0: tensor {tf_saved_model.index_path = ["input"]}) -> (tensor {tf_saved_model.index_path = ["r"]}) + attributes {tf_saved_model.exported_names = ["main"]} { + // CHECK: [[handle:%.*]] = "tf.Const"() + %0 = "tf.Const"() {device = "/CPU:0", value = dense<0> : tensor} : () -> tensor + // CHECK: "tf.BatchFunction"(%arg0, [[handle]]) + // CHECK-SAME: operand_segment_sizes = array + %1 = "tf.BatchFunction"(%arg0, %0) {allowed_batch_sizes = [6], batch_timeout_micros = 100000 : i64, batching_queue = "", container = "", device = "/device:CPU:0", enable_large_batch_splitting = false, f = @batched_function, max_batch_size = 6 : i64, max_enqueued_batches = 10 : i64, num_batch_threads = 1 : i64, operand_segment_sizes = array, shared_name = "batch/"} : (tensor, tensor) -> tensor + func.return %1 : tensor +} + +} + +// ----- + +module attributes {tf_saved_model.semantics} { + +// Test sink in multiple invariant ops. + +// CHECK-LABEL: func private @batched_function +func.func private @batched_function(%arg0: tensor>>, %arg1: tensor>>) -> tensor<1x3xf32> + attributes {tf._input_shapes = [#tf_type.shape<1x3>, #tf_type.shape<*>], tf.signature.is_stateful} { + // CHECK: [[handle1:%.*]] = "tf.VarHandleOp"() {{{.*}}, shared_name = "variable1"} + // CHECK: [[handle2:%.*]] = "tf.VarHandleOp"() {{{.*}}, shared_name = "variable2"} + // CHECK: "tf.ReadVariableOp"([[handle1]]) + // CHECK: "tf.ReadVariableOp"([[handle2]]) + %0 = "tf.ReadVariableOp"(%arg0) {device = "/device:CPU:0"} : (tensor>>) -> tensor<1x3xf32> + %1 = "tf.ReadVariableOp"(%arg1) {device = "/device:CPU:0"} : (tensor>>) -> tensor<1x3xf32> + %2 = "tf.AddV2"(%0, %1) {device = "/device:CPU:0"} : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + %3 = "tf.Identity"(%2) {device = "/device:CPU:0"} : (tensor<1x3xf32>) -> tensor<1x3xf32> + func.return %3 : tensor<1x3xf32> +} + +// CHECK-LABEL: func @main +func.func @main(%arg0: tensor<1x3xf32> {tf_saved_model.index_path = ["input"]}) -> (tensor<*xf32> {tf_saved_model.index_path = ["r"]}) + attributes {tf_saved_model.exported_names = ["main"]} { + // CHECK: tf.VarHandleOp + %0 = "tf.VarHandleOp"() {device = "/device:CPU:0", container = "", shared_name = "variable1"} : () -> tensor>> + %1 = "tf.VarHandleOp"() {device = "/device:CPU:0", container = "", shared_name = "variable2"} : () -> tensor>> + // CHECK: "tf.BatchFunction"(%0, %1) + %2 = "tf.BatchFunction"(%0, %1) {allowed_batch_sizes = [6], batch_timeout_micros = 100000 : i64, batching_queue = "", container = "", device = "/device:CPU:0", enable_large_batch_splitting = false, f = @batched_function, max_batch_size = 6 : i64, max_enqueued_batches = 10 : i64, num_batch_threads = 1 : i64, operand_segment_sizes = array, shared_name = "batch/"} : (tensor>>, tensor>>) -> tensor<*xf32> + func.return %2 : tensor<*xf32> +} + +} + +// ----- + +module attributes {tf_saved_model.semantics} { + +// Test sinks in var handle op that used by control flow ops. 
+ +// CHECK-LABEL: func private @some_func +func.func private @some_func( + %arg: tensor>>) -> tensor { + // CHECK: tf.VarHandleOp + // CHECK: tf.ReadVariableOp + %0 = "tf.ReadVariableOp"(%arg) {device = "cpu"} : (tensor>>) -> tensor + + func.return %0 : tensor +} + +// CHECK-LABEL: func private @some_other_func +func.func private @some_other_func( + %arg: tensor>>) -> tensor { + // CHECK: [[handle:%.*]] = "tf.VarHandleOp"() + // CHECK: "tf.ReadVariableOp"([[handle]]) + %0 = "tf.ReadVariableOp"(%arg) {device = "cpu"} : (tensor>>) -> tensor + + func.return %0 : tensor +} + +// CHECK-LABEL: func @sink_in_stateful_call +func.func @sink_in_stateful_call(%arg: tensor {tf_saved_model.index_path = ["input"]}) -> (tensor {tf_saved_model.index_path = ["r"]}) + attributes {tf_saved_model.exported_names = ["test_sink_in_stateful_call"]} { + // CHECK: [[handle:%.*]] = "tf.VarHandleOp"() + %handle = "tf.VarHandleOp"() {container = "", shared_name = "x"} : () -> tensor>> + // CHECK: "tf.StatefulPartitionedCall"([[handle]]) + %x = "tf.StatefulPartitionedCall"(%handle) {device = "/CPU:0", config = "", config_proto = "", executor_type = "", f = @some_func} : (tensor>>) -> (tensor) + %r = "tf.AddV2"(%arg, %x) {device = "/CPU:0"} : (tensor, tensor) -> tensor + func.return %r : tensor +} + +// CHECK-LABEL: func @sink_in_if +func.func @sink_in_if(%arg: tensor {tf_saved_model.index_path = ["input"]}) -> (tensor {tf_saved_model.index_path = ["r"]}) + attributes {tf_saved_model.exported_names = ["test_sink_in_if"]} { + // CHECK: [[handle:%.*]] = "tf.VarHandleOp"() + %handle = "tf.VarHandleOp"() {container = "", shared_name = "x"} : () -> tensor>> + // CHECK: [[cond:%.*]] = "tf.Const"() + %cond = "tf.Const"() {device = "/CPU:0", value = dense : tensor} : () -> tensor + // CHECK: "tf.If"([[cond]], [[handle]]) + %x = "tf.If"(%cond, %handle) {then_branch = @some_other_func, else_branch = @some_other_func, is_stateless = false} : (tensor, tensor>>) -> tensor + %r = "tf.AddV2"(%arg, %x) {device = "/CPU:0"} : (tensor, tensor) -> tensor + func.return %r : tensor +} + +} + +// ----- + +module attributes {tf_saved_model.semantics} { + +// Test doesn't sink in to the callee that invoked by multiple callers. 
+ +// CHECK: func private @some_func([[arg0:.+]]: tensor>>) +func.func private @some_func(%arg0: tensor>>) -> tensor { + // CHECK-NOT: tf.VarHandleOp + // CHECK: tf.ReadVariableOp + %0 = "tf.ReadVariableOp"(%arg0) {device = "cpu"} : (tensor>>) -> tensor + + func.return %0 : tensor +} + +// CHECK-LABEL: func @sink_in_stateful_call +func.func @sink_in_stateful_call(%arg0: tensor {tf_saved_model.index_path = ["input"]}) -> (tensor {tf_saved_model.index_path = ["r"]}) + attributes {tf_saved_model.exported_names = ["test_sink_in_stateful_call"]} { + // CHECK: tf.VarHandleOp + %0 = "tf.VarHandleOp"() {container = "", shared_name = "x"} : () -> tensor>> + // CHECK: "tf.StatefulPartitionedCall"(%0) + %1 = "tf.StatefulPartitionedCall"(%0) {device = "/CPU:0", config = "", config_proto = "", executor_type = "", f = @some_func} : (tensor>>) -> (tensor) + %2 = "tf.AddV2"(%arg0, %1) {device = "/CPU:0"} : (tensor, tensor) -> tensor + func.return %2 : tensor +} + +// CHECK-LABEL: func @sink_in_if +func.func @sink_in_if(%arg0: tensor {tf_saved_model.index_path = ["input"]}) -> (tensor {tf_saved_model.index_path = ["r"]}) + attributes {tf_saved_model.exported_names = ["test_sink_in_if"]} { + // CHECK: tf.VarHandleOp + %0 = "tf.VarHandleOp"() {container = "", shared_name = "x"} : () -> tensor>> + %cst = "tf.Const"() {device = "/CPU:0", value = dense : tensor} : () -> tensor + // CHECK: "tf.If"(%cst, %0) + %1 = "tf.If"(%cst, %0) {then_branch = @some_func, else_branch = @some_func, is_stateless = false} : (tensor, tensor>>) -> tensor + %2 = "tf.AddV2"(%arg0, %1) {device = "/CPU:0"} : (tensor, tensor) -> tensor + func.return %2 : tensor +} + +} + +// ----- + +module attributes {tf_saved_model.semantics} { + +// Test doesn't sink in var handle op + read variable op. Consider implementing this when we see it in production.
+ +// CHECK-LABEL: func private @batched_function +func.func private @batched_function(%arg0: tensor<1x3xf32>, %arg1: tensor<1x3xf32>) -> tensor<1x3xf32> + attributes {tf._input_shapes = [#tf_type.shape<1x3>, #tf_type.shape<*>], tf.signature.is_stateful} { + // CHECK-NOT: tf.VarHandleOp + // CHECK-NOT: tf.ReadVariableOp + %1 = "tf.AddV2"(%arg0, %arg1) {device = "/device:CPU:0"} : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + %2 = "tf.Identity"(%1) {device = "/device:CPU:0"} : (tensor<1x3xf32>) -> tensor<1x3xf32> + func.return %2 : tensor<1x3xf32> +} + +// CHECK-LABEL: func @main +func.func @main(%arg0: tensor<1x3xf32> {tf_saved_model.index_path = ["input"]}) -> (tensor<*xf32> {tf_saved_model.index_path = ["r"]}) + attributes {tf_saved_model.exported_names = ["main"]} { + // CHECK: [[handle:%.*]] = "tf.VarHandleOp"() + // CHECK: "tf.ReadVariableOp"([[handle]]) + %0 = "tf.VarHandleOp"() {device = "/device:CPU:0", container = "", shared_name = "variable"} : () -> tensor>> + %1 = "tf.ReadVariableOp"(%0) {device = "/device:CPU:0"} : (tensor>>) -> tensor<1x3xf32> + // CHECK: "tf.BatchFunction"(%arg0, %1) + %2 = "tf.BatchFunction"(%arg0, %1) {allowed_batch_sizes = [6], batch_timeout_micros = 100000 : i64, batching_queue = "", container = "", device = "/device:CPU:0", enable_large_batch_splitting = false, f = @batched_function, max_batch_size = 6 : i64, max_enqueued_batches = 10 : i64, num_batch_threads = 1 : i64, operand_segment_sizes = array, shared_name = "batch/"} : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<*xf32> + func.return %2 : tensor<*xf32> +} + +} + +// ----- + +module attributes {tf_saved_model.semantics} { + +// Test sinks in var handle op if it's used by one callee, and also by read only ops in the current function. + +// CHECK-LABEL: func private @batched_function +// CHECK: arg1 +func.func private @batched_function(%arg0: tensor<1x3xf32>, %arg1: tensor>>) -> tensor<1x3xf32> + attributes {tf._input_shapes = [#tf_type.shape<1x3>, #tf_type.shape<*>], tf.signature.is_stateful} { + // CHECK: tf.VarHandleOp + // CHECK: tf.ReadVariableOp + %1 = "tf.ReadVariableOp"(%arg1) {device = "/device:CPU:0"} : (tensor>>) -> tensor<1x3xf32> + %2 = "tf.AddV2"(%arg0, %1) {device = "/device:CPU:0"} : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + %3 = "tf.Identity"(%1) {device = "/device:CPU:0"} : (tensor<1x3xf32>) -> tensor<1x3xf32> + func.return %2 : tensor<1x3xf32> +} + +// CHECK-LABEL: func @main +func.func @main(%arg0: tensor<1x3xf32> {tf_saved_model.index_path = ["input"]}) -> (tensor<*xf32> {tf_saved_model.index_path = ["r"]}) + attributes {tf_saved_model.exported_names = ["main"]} { + // CHECK: [[handle:%.*]] = "tf.VarHandleOp"() + %0 = "tf.VarHandleOp"() {device = "/device:CPU:0", container = "", shared_name = "variable"} : () -> tensor>> + // CHECK: "tf.ReadVariableOp"([[handle]]) + %1 = "tf.ReadVariableOp"(%0) {device = "/device:CPU:0"} : (tensor>>) -> tensor<1x3xf32> + // CHECK: "tf.BatchFunction"(%arg0, [[handle]]) + // CHECK-SAME: operand_segment_sizes = array + %2 = "tf.BatchFunction"(%arg0, %0) {allowed_batch_sizes = [6], batch_timeout_micros = 100000 : i64, batching_queue = "", container = "", device = "/device:CPU:0", enable_large_batch_splitting = false, f = @batched_function, max_batch_size = 6 : i64, max_enqueued_batches = 10 : i64, num_batch_threads = 1 : i64, operand_segment_sizes = array, shared_name = "batch/"} : (tensor<1x3xf32>, tensor>>) -> tensor<*xf32> + func.return %2 : tensor<*xf32> +} + +} \ No newline at end of file diff --git 
a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/BUILD b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/BUILD index 30aa6a6ef3d..4badbc11669 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/BUILD +++ b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/BUILD @@ -1,12 +1,18 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") load("//tensorflow:tensorflow.bzl", "if_oss") -package(licenses = ["notice"]) +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) glob_lit_tests( data = [":test_utilities"], driver = "//tensorflow/compiler/mlir:run_lit.sh", features = if_oss(["--path=org_tensorflow/tensorflow/compiler/mlir/tfrt"]), + size_override = { + "fallback.mlir": "medium", + }, test_file_exts = ["mlir"], ) diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/attributes.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/attributes.mlir index f314558dee9..77e795b8bf4 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/attributes.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/attributes.mlir @@ -1,4 +1,4 @@ -// RUN: tf-tfrt-opt -tf-to-tfrt %s | FileCheck %s --dump-input=fail +// RUN: tf-tfrt-opt -tf-to-tfrt=func-use-fallback-tensor=true %s | FileCheck %s --dump-input=fail // _output_shapes and f.* attributes are removed during tf-to-tfrt lowering. // CHECK-LABEL: func @remove_unused_attr @@ -14,8 +14,8 @@ func.func @basic( %arg1: tensor>>) -> (tensor<3x3xf32>) { %1 = "tf.ReadVariableOp"(%arg1) {_output_shapes = ["tfshape$dim { size: 1 } dim { size: 3 }"], device = "/device:CPU:0", dtype = f32} : (tensor>>) -> tensor<1x3xf32> - // CHECK: {{%.*}} = corert.executeop({{%.*}}) "tf.MatMul" - // CHECK-SAME: {T = f32, device = "/device:CPU:0", transpose_a = false, transpose_b = false} + // CHECK: {{%.*}} = tfrt_fallback_async.executeop {{.*}} device("/device:CPU:0") "tf.MatMul" + // CHECK-SAME: {T = f32, transpose_a = false, transpose_b = false} %2 = "tf.MatMul"(%arg0, %1) {T = f32, _output_shapes = ["tfshape$dim { size: 3 } dim { size: 3 }"], device = "/device:CPU:0", transpose_a = false, transpose_b = false} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32> func.return %2 : tensor<3x3xf32> } diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/auto-fusion.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/auto-fusion.mlir index 76d93695f34..b6c63f3f560 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/auto-fusion.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/auto-fusion.mlir @@ -1,4 +1,4 @@ -// RUN: tf-tfrt-opt -tf-executor-to-tfrt-pipeline="enable-native-ops=false auto-fusion-oplist=tf.Rsqrt,tf.Tanh auto-fusion-min-cluster-size=1" -split-input-file %s \ +// RUN: tf-tfrt-opt -tf-executor-to-tfrt-pipeline="auto-fusion-oplist=tf.Rsqrt,tf.Tanh auto-fusion-min-cluster-size=1" -split-input-file %s \ // RUN: | FileCheck %s --dump-input=always // CHECK-LABEL: func @single_op_cluster diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/basic.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/basic.mlir index 22199ed5230..49debdadd36 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/basic.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/basic.mlir @@ -1,4 +1,4 @@ -// RUN: tf-tfrt-opt -pass-pipeline='func.func(tf-tensor-device-copy),tfrt-lower-tf-savedmodel{hoist-invariant-ops=true},tf-to-tfrt{enable-native-ops=false func-use-fallback-tensor=true tfrt-cost-threshold=1024 
tfrt-upper-cost-threshold=65536 tfrt-merge-inter-dependent-streams=true}' %s | FileCheck %s --dump-input-filter=all +// RUN: tf-tfrt-opt -pass-pipeline='builtin.module(func.func(tf-tensor-device-copy),tfrt-lower-tf-savedmodel{hoist-invariant-ops=true},tf-to-tfrt{func-use-fallback-tensor=true tfrt-cost-threshold=1024 tfrt-upper-cost-threshold=65536 tfrt-merge-inter-dependent-streams=true})' %s | FileCheck %s --dump-input-filter=all // CHECK-NOT: tf_saved_model.semantics // CHECK: tfrt.cost_threshold = 1024 diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/device_conversion.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/device_conversion.mlir index a9fdd728a5c..149fee8f244 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/device_conversion.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/device_conversion.mlir @@ -1,23 +1,11 @@ -// RUN: tf-tfrt-opt -tf-to-tfrt %s | FileCheck %s --dump-input=fail +// RUN: tf-tfrt-opt -tf-to-tfrt=func-use-fallback-tensor=true %s | FileCheck %s --dump-input=fail // CHECK-LABEL: func @device_test func.func @device_test( %arg0: tensor<3x1xf32> {tf_saved_model.index_path = [0]}, %arg1: tensor<1x3xf32> {tf_saved_model.index_path = [0]}) -> (tensor<3x3xf32> {tf_saved_model.index_path = []}) { - // CHECK: {{%.*}} = corert.get_op_handler %arg0 "/device:GPU:0" - - %2 = "tf.MatMul"(%arg0, %arg1) {T = f32, _output_shapes = ["tfshape$dim { size: 3 } dim { size: 3 }"], device = "/device:GPU:0", transpose_a = false, transpose_b = false} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32> - func.return %2 : tensor<3x3xf32> -} - -// CHECK-LABEL: func @legacy_device_name -func.func @legacy_device_name( - %arg0: tensor<3x1xf32> {tf_saved_model.index_path = [0]}, - %arg1: tensor<1x3xf32> {tf_saved_model.index_path = [0]}) - -> (tensor<3x3xf32> {tf_saved_model.index_path = []}) { - // CHECK: {{%.*}} = corert.get_op_handler %arg0 "/device:GPU:0" - + // CHECK: device("/device:GPU:0") %2 = "tf.MatMul"(%arg0, %arg1) {T = f32, _output_shapes = ["tfshape$dim { size: 3 } dim { size: 3 }"], device = "/device:GPU:0", transpose_a = false, transpose_b = false} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32> func.return %2 : tensor<3x3xf32> } diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/fallback.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/fallback.mlir index d8e5492614e..0e605ccc6af 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/fallback.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/fallback.mlir @@ -1,5 +1,5 @@ -// RUN: tf-tfrt-opt -tf-to-tfrt=enable-native-ops=false %s | FileCheck %s --dump-input=fail --dump-input-filter=all -// RUN: tf-tfrt-opt -pass-pipeline='tf-to-tfrt{enable-native-ops=false target-tpurt=true tpu-use-core-selector=false}' %s | FileCheck %s --dump-input=fail --dump-input-filter=all +// RUN: tf-tfrt-opt -tf-to-tfrt %s | FileCheck %s --dump-input=fail --dump-input-filter=all +// RUN: tf-tfrt-opt -pass-pipeline='builtin.module(tf-to-tfrt{target-tpurt=true tpu-use-core-selector=false})' %s | FileCheck %s --dump-input=fail --dump-input-filter=all // CHECK-LABEL: func @_tfrt_fallback_init // CHECK-SAME: {{.*}} !tfrt.chain diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/fallback_canonicalization.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/fallback_canonicalization.mlir index 5af5e8dea54..30a643fbb80 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/fallback_canonicalization.mlir +++ 
b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/fallback_canonicalization.mlir @@ -1,4 +1,4 @@ -// RUN: tf-tfrt-opt %s -pass-pipeline='func.func(canonicalize)' | FileCheck %s +// RUN: tf-tfrt-opt %s -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s // CHECK-LABEL: func @test_const_tensor_canonicalization_single_denst_tensor_operand func.func @test_const_tensor_canonicalization_single_denst_tensor_operand() -> !tfrt_fallback.tf_tensor { diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/func_attributes.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/func_attributes.mlir index 92092520cd2..d4a8d0b5f75 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/func_attributes.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/func_attributes.mlir @@ -12,7 +12,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr func.return %4 : tensor } // CHECK-LABEL: __inference_Dataset_flat_map_lambda_190 - func.func private @__inference_Dataset_flat_map_lambda_190(%arg0: tensor {tf._user_specified_name = "args_0"}) -> tensor attributes {tf._tf_data_function = true, tf.signature.is_stateful} { + func.func private @__inference_Dataset_flat_map_lambda_190(%arg0: tensor {tf._user_specified_name = "args_0"}) -> tensor attributes {tf._original_func_name = "__inference_Dataset_flat_map_lambda_19", tf._tf_data_function = true, tf.signature.is_stateful} { %0 = "tf.Const"() {device = "/device:CPU:0", value = dense<0> : tensor} : () -> tensor %1 = "tf.Const"() {device = "/device:CPU:0", value = dense<1> : tensor} : () -> tensor %2 = "tf.Const"() {device = "/device:CPU:0", value = dense<5> : tensor} : () -> tensor diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/func_attributes_multiple_callers.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/func_attributes_multiple_callers.mlir new file mode 100644 index 00000000000..1615363409a --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/func_attributes_multiple_callers.mlir @@ -0,0 +1,40 @@ +// RUN: tf-tfrt-opt -tf-executor-to-tfrt-pipeline %s | FileCheck %s + +// Checks that the ops' function attribute references the original function name +// `funcB` for `funcB_renamed` after the module is lowered to TFRT. Note that +// `funcB_renamed` is called twice, so `CreateGuaranteeAllFuncsOneUsePass` will +// make a replication of `funcB_renamed` with a different name. 
+ +module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 567 : i32}} { + // CHECK-LABEL: @funcA + func.func @funcA() -> (tensor, tensor) attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "flatmapdataset__4_RetVal"}} { + %0 = "tf.Const"() {device = "/device:CPU:0", value = dense<0> : tensor} : () -> tensor + %1 = "tf.Const"() {device = "/device:CPU:0", value = dense<5> : tensor} : () -> tensor + %2 = "tf.Const"() {device = "/device:CPU:0", value = dense<1> : tensor} : () -> tensor + %3 = "tf.RangeDataset"(%0, %1, %2) {device = "/device:CPU:0", output_shapes = [#tf_type.shape<>], output_types = [i64], metadata = ""} : (tensor, tensor, tensor) -> tensor + // CHECK: tfrt_fallback_async.executeop key({{[0-9]+}}) cost({{.*}}) device("/device:CPU:0") "tf.FlatMapDataset"({{.*}}) {Targuments = [], metadata = "", output_shapes = [#corert.shape<>], output_types = [i64]} {f = "funcB"} : 1 + %4 = "tf.FlatMapDataset"(%3) {Targuments = [], device = "/device:CPU:0", f = @funcB_renamed, output_shapes = [#tf_type.shape<>], output_types = [i64], metadata = ""} : (tensor) -> tensor + %5 = "tf.RangeDataset"(%1, %2, %0) {device = "/device:CPU:0", output_shapes = [#tf_type.shape<>], output_types = [i64], metadata = ""} : (tensor, tensor, tensor) -> tensor + // CHECK: tfrt_fallback_async.executeop key({{[0-9]+}}) cost({{.*}}) device("/device:CPU:0") "tf.FlatMapDataset"({{.*}}) {Targuments = [], metadata = "", output_shapes = [#corert.shape<>], output_types = [i64]} {f = "funcB"} : 1 + %6 = "tf.FlatMapDataset"(%5) {Targuments = [], device = "/device:CPU:0", f = @funcB_renamed, output_shapes = [#tf_type.shape<>], output_types = [i64], metadata = ""} : (tensor) -> tensor + func.return %4, %6 : tensor, tensor + } + // CHECK-LABEL: @funcB_renamed + func.func private @funcB_renamed(%arg0: tensor {tf._user_specified_name = "args_0"}) -> tensor attributes {tf._original_func_name = "funcB", tf._tf_data_function = true, tf.signature.is_stateful} { + %0 = "tf.Const"() {device = "/device:CPU:0", value = dense<0> : tensor} : () -> tensor + %1 = "tf.Const"() {device = "/device:CPU:0", value = dense<1> : tensor} : () -> tensor + %2 = "tf.Const"() {device = "/device:CPU:0", value = dense<5> : tensor} : () -> tensor + %3 = "tf.RangeDataset"(%0, %2, %1) {device = "/device:CPU:0", output_shapes = [#tf_type.shape<>], output_types = [i64], metadata = ""} : (tensor, tensor, tensor) -> tensor + // CHECK: tfrt_fallback_async.executeop key({{[0-9]+}}) cost({{.*}}) device("/device:CPU:0") "tf.MapDataset"({{.*}}) {Targuments = [], metadata = "", output_shapes = [#corert.shape<>], output_types = [i64], preserve_cardinality = true, use_inter_op_parallelism = true} {f = "funcC"} : 1 + %4 = "tf.MapDataset"(%3) {device = "/device:CPU:0", f = @funcC_renamed, f._tf_data_function = true, output_shapes = [#tf_type.shape<>], output_types = [i64], preserve_cardinality = true, use_inter_op_parallelism = true, metadata = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = "/device:CPU:0"} : (tensor) -> tensor + func.return %5 : tensor + } + // CHECK-LABEL: @funcC_renamed + func.func private @funcC_renamed(%arg0: tensor {tf._user_specified_name = "args_0"}) -> tensor attributes {tf._tf_data_function = true, tf._original_func_name = "funcC"} { + %0 = "tf.Const"() {device = "/device:CPU:0", value = dense<2> : tensor} : () -> tensor + %1 = "tf.Mul"(%arg0, %0) {device = "/device:CPU:0"} : (tensor, tensor) -> tensor + %2 = "tf.Identity"(%1) {device = "/device:CPU:0"} : (tensor) -> tensor 
+ func.return %2 : tensor + } +} diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/tf_to_corert_pipeline.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/tf_to_corert_pipeline.mlir index 0f47e8bc872..dd57c72674a 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/tf_to_corert_pipeline.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/tf_to_corert_pipeline.mlir @@ -1,4 +1,4 @@ -// RUN: tf-tfrt-opt -tf-executor-to-tfrt-pipeline="enable-native-ops=false enable-optimizer=true tfrt-cost-threshold=1024" %s | FileCheck %s --dump-input=fail +// RUN: tf-tfrt-opt -tf-executor-to-tfrt-pipeline="enable-optimizer=true tfrt-cost-threshold=1024" %s | FileCheck %s --dump-input=fail // CHECK: tfrt.cost_threshold = 1024 : i64 module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 462 : i32}} { @@ -14,14 +14,14 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr // CHECK-NEXT: [[arg5:%.*]] = tfrt_fallback_async.corert_tensorhandle_to_fallback_tensor [[arg5_th]] {device = "/job:localhost/replica:0/task:0/device:CPU:0" // CHECK-NEXT: [[arg2:%.*]] = tfrt_fallback_async.corert_tensorhandle_to_fallback_tensor [[arg2_th]] {device = "/job:localhost/replica:0/task:0/device:CPU:0" // CHECK-NEXT: [[arg3:%.*]] = tfrt_fallback_async.corert_tensorhandle_to_fallback_tensor [[arg3_th]] {device = "/job:localhost/replica:0/task:0/device:CPU:0" -// CHECK: [[o2_chain:%.*]], [[o2:%.*]] = tfrt_fallback_async.executeop.seq([[in_chain]]) key(0) cost({{.*}}) device("/job:localhost/replica:0/task:0/device:CPU:0") "tf.ReadVariableOp"([[arg3]]) -// CHECK-NEXT: [[o3_chain:%.*]], [[o3:%.*]] = tfrt_fallback_async.executeop.seq([[in_chain]]) key(1) cost({{.*}}) device("/job:localhost/replica:0/task:0/device:CPU:0") "tf.ReadVariableOp"([[arg2]]) -// CHECK-NEXT: [[o4_chain:%.*]], [[o4:%.*]] = tfrt_fallback_async.executeop.seq([[in_chain]]) key(2) cost({{.*}}) device("/job:localhost/replica:0/task:0/device:CPU:0") "tf.ReadVariableOp"([[arg5]]) -// CHECK-NEXT: [[o5_chain:%.*]], [[o5:%.*]] = tfrt_fallback_async.executeop.seq([[in_chain]]) key(3) cost({{.*}}) device("/job:localhost/replica:0/task:0/device:CPU:0") "tf.ReadVariableOp"([[arg4]]) -// CHECK-NEXT: [[o6:%.*]] = tfrt_fallback_async.executeop key(4) cost({{.*}}) device("/job:localhost/replica:0/task:0/device:CPU:0") "tf._FusedConv2D"([[arg1]], [[o3]], [[o2]]) -// CHECK-NEXT: [[o7:%.*]] = tfrt_fallback_async.executeop key(5) cost({{.*}}) device("/job:localhost/replica:0/task:0/device:CPU:0") "tf.AvgPool"([[o6]]) -// CHECK-NEXT: [[o8:%.*]] = tfrt_fallback_async.executeop key(6) cost({{.*}}) device("/job:localhost/replica:0/task:0/device:CPU:0") "tf.Reshape"([[o7]], [[o1]]) -// CHECK-NEXT: [[o9:%.*]] = tfrt_fallback_async.executeop key(7) cost({{.*}}) device("/job:localhost/replica:0/task:0/device:CPU:0") "tf._FusedMatMul"([[o8]], [[o5]], [[o4]]) +// CHECK: [[o2_chain:%.*]], [[o2:%.*]] = tfrt_fallback_async.executeop.seq([[in_chain]]) key({{[0-9]+}}) cost({{.*}}) device("/job:localhost/replica:0/task:0/device:CPU:0") "tf.ReadVariableOp"([[arg3]]) +// CHECK-NEXT: [[o3_chain:%.*]], [[o3:%.*]] = tfrt_fallback_async.executeop.seq([[in_chain]]) key({{[0-9]+}}) cost({{.*}}) device("/job:localhost/replica:0/task:0/device:CPU:0") "tf.ReadVariableOp"([[arg2]]) +// CHECK-NEXT: [[o4_chain:%.*]], [[o4:%.*]] = tfrt_fallback_async.executeop.seq([[in_chain]]) key({{[0-9]+}}) cost({{.*}}) device("/job:localhost/replica:0/task:0/device:CPU:0") "tf.ReadVariableOp"([[arg5]]) +// CHECK-NEXT: 
[[o5_chain:%.*]], [[o5:%.*]] = tfrt_fallback_async.executeop.seq([[in_chain]]) key({{[0-9]+}}) cost({{.*}}) device("/job:localhost/replica:0/task:0/device:CPU:0") "tf.ReadVariableOp"([[arg4]]) +// CHECK-NEXT: [[o6:%.*]] = tfrt_fallback_async.executeop key({{[0-9]+}}) cost({{.*}}) device("/job:localhost/replica:0/task:0/device:CPU:0") "tf._FusedConv2D"([[arg1]], [[o3]], [[o2]]) +// CHECK-NEXT: [[o7:%.*]] = tfrt_fallback_async.executeop key({{[0-9]+}}) cost({{.*}}) device("/job:localhost/replica:0/task:0/device:CPU:0") "tf.AvgPool"([[o6]]) +// CHECK-NEXT: [[o8:%.*]] = tfrt_fallback_async.executeop key({{[0-9]+}}) cost({{.*}}) device("/job:localhost/replica:0/task:0/device:CPU:0") "tf.Reshape"([[o7]], [[o1]]) +// CHECK-NEXT: [[o9:%.*]] = tfrt_fallback_async.executeop key({{[0-9]+}}) cost({{.*}}) device("/job:localhost/replica:0/task:0/device:CPU:0") "tf._FusedMatMul"([[o8]], [[o5]], [[o4]]) // CHECK-NEXT: [[out_chain:%.*]] = tfrt.merge.chains [[o2_chain]], [[o3_chain]], [[o4_chain]], [[o5_chain]] // CHECK-NEXT: [[o9_th:%.*]] = tfrt_fallback_async.fallback_tensor_to_corert_tensorhandle [[o9]] // CHECK-NEXT: [[o5_th:%.*]] = tfrt_fallback_async.fallback_tensor_to_corert_tensorhandle [[o5]] @@ -70,15 +70,15 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr // CHECK-DAG: tfrt_fallback_async.const_dense_tensor dense<0> : tensor // CHECK-NEXT: tfrt_fallback_async.executeop key({{.*}}) cost({{.*}}) device("/device:CPU:0") "tf.Less" // CHECK-NEXT: [[pred:%.*]] = tfrt_fallback_async.predicate - // CHECK-NEXT: tfrt.while [[pred]] @"while_body_add2/tfrt_body_1" + // CHECK-NEXT: tfrt.while [[pred]] @"[[while_func_prefix:.*]]/tfrt_body_1" // CHECK-NEXT: tfrt.merge.chains // CHECK-NEXT: tfrt.return %0 = "tf.Const"() {device = "/device:CPU:0", value = dense<0> : tensor} : () -> tensor %1 = "tf.While"(%0) { cond = @while_cond_lt9, body = @while_body_add2, is_stateless = false, parallel_iterations = 1} : (tensor) -> (tensor) func.return %1 : tensor } - // CHECK: func @"while_body_add2/tfrt_body_1" + // CHECK: func @"[[while_func_prefix]]/tfrt_body_1" // CHECK-NOT: tfrt.call - // CHECK: func @"while_cond_lt9/tfrt_predicate" + // CHECK: func @"[[while_cond_prefix:.*]]/tfrt_predicate" } diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/tf_to_corert_pipeline_cpurt.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/tf_to_corert_pipeline_cpurt.mlir index 2ac33275179..43267d6ec22 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/tf_to_corert_pipeline_cpurt.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/tf_to_corert_pipeline_cpurt.mlir @@ -1,7 +1,6 @@ // RUN: tf-tfrt-opt %s \ // RUN: -split-input-file \ // RUN: -tf-executor-to-tfrt-pipeline=" \ -// RUN: enable-native-ops=false \ // RUN: enable-optimizer=true \ // RUN: tfrt-cost-threshold=1024 \ // RUN: auto-fusion-oplist=tf.Relu,tf.Transpose,tf.Const \ diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/tf_to_corert_pipeline_refvar.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/tf_to_corert_pipeline_refvar.mlir index 504357d8ad5..e9518003023 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/tf_to_corert_pipeline_refvar.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/tf_to_corert_pipeline_refvar.mlir @@ -1,4 +1,4 @@ -// RUN: tf-tfrt-opt -tf-executor-to-tfrt-pipeline=enable-native-ops=false %s | FileCheck %s --dump-input=fail +// RUN: tf-tfrt-opt -tf-executor-to-tfrt-pipeline %s | FileCheck %s --dump-input=fail // CHECK-LABEL: func 
@__inference_pruned_131 // CHECK-SAME: ([[in_chain:%.*]]: !tfrt.chain) -> (!tfrt.chain, !corert.tensorhandle) diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/whileop.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/whileop.mlir new file mode 100644 index 00000000000..5858a015061 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/tests/tf_to_corert/whileop.mlir @@ -0,0 +1,33 @@ +// RUN: tf-tfrt-opt -tf-executor-to-tfrt-pipeline="enable-optimizer=true tfrt-cost-threshold=1024" %s | FileCheck %s --dump-input=fail + +// Check that unused While op results and the associated ops are removed. + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 462 : i32}} { + func.func @while_cond_lt9(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "tf.Const"() {device = "/device:CPU:0", value = dense<9> : tensor} : () -> tensor + %1 = "tf.Less"(%arg0, %0) {device = "/device:CPU:0"} : (tensor, tensor) -> tensor + func.return %1 : tensor + } + + func.func @while_body_add2(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + %0 = "tf.Const"() {device = "/device:CPU:0", value = dense<1> : tensor} : () -> tensor + %1 = "tf.AddV2"(%arg0, %0) {device = "/device:CPU:0"} : (tensor, tensor) -> tensor + %2 = "tf.Div"(%arg1, %0) {device = "/device:CPU:0"} : (tensor, tensor) -> tensor + func.return %1, %2 : tensor, tensor + } + + // CHECK-LABEL: func @while_test_remove_unused_results + // CHECK: [[pred:%.*]] = tfrt_fallback_async.predicate + // CHECK-NEXT: tfrt.while [[pred]] @"[[while_func_prefix:.*]]/tfrt_body_1" + // CHECK-SAME: (!tfrt.chain, !corert.tensorhandle) -> (!tfrt.chain, !corert.tensorhandle) + // CHECK-NOT: func.call + func.func @while_test_remove_unused_results(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + %0:2 = "tf.While"(%arg0, %arg1) { cond = @while_cond_lt9, body = @while_body_add2, is_stateless = false, parallel_iterations = 1} : (tensor, tensor) -> (tensor, tensor) + %1:2 = func.call @while_body_add2(%arg0, %arg1) : (tensor, tensor) -> (tensor, tensor) + func.return %0#0, %1#0 : tensor, tensor + } + + // CHECK: func @"[[while_func_prefix]]/tfrt_body_1" + // CHECK: "tf.AddV2" + // CHECK-NOT: "tf.Div" +} diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data/batch.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data/batch.mlir deleted file mode 100644 index 0b94db888ef..00000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data/batch.mlir +++ /dev/null @@ -1,24 +0,0 @@ -// RUN: tf-tfrt-opt -tf-to-tfrt-data %s | FileCheck %s - -module { - -// CHECK-LABEL: func @main() -> !tfrt_data.dataset - func.func @main() -> tensor<*x!tf_type.variant> { - // CHECK-NEXT: %[[START:.*]] = tfrt.constant.i64 0 - // CHECK-NEXT: %[[STEP:.*]] = tfrt.constant.i64 1 - // CHECK-NEXT: %[[STOP:.*]] = tfrt.constant.i64 1000 - // CHECK-NEXT: %[[RANGE:.*]] = tfrt_data.range_dataset %[[START]], %[[STOP]], %[[STEP]] {element_type = i64} - // CHECK-NEXT: %[[BATCH_SIZE:.*]] = tfrt.constant.i64 10 - // CHECK-NEXT: %[[DROP_REMAINDER:.*]] = tfrt.constant.i1 false - // CHECK-NEXT: %[[BATCH:.*]] = tfrt_data.batch_dataset.i64 %[[RANGE]], %[[BATCH_SIZE]] {same_input_metadata = false} - // CHECK-NEXT: tfrt.return %[[BATCH]] : !tfrt_data.dataset - %start = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %step = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - %stop = "tf.Const"() {value = dense<1000> : tensor} : () -> tensor - %range = "tf.RangeDataset"(%start, %stop, %step) {output_shapes = [#tf_type.shape<>], output_types = [i64], 
metadata = ""} : (tensor, tensor, tensor) -> tensor<*x!tf_type.variant> - %batch_size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor - %drop_remainder = "tf.Const"() {value = dense : tensor} : () -> tensor - %batch = "tf.BatchDatasetV2"(%range, %batch_size, %drop_remainder) {output_shapes = [#tf_type.shape<>], output_types = [i64], parallel_copy = false, metadata = ""} : (tensor<*x!tf_type.variant>, tensor, tensor) -> tensor<*x!tf_type.variant> - func.return %batch : tensor<*x!tf_type.variant> - } -} diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data/range.mlir b/tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data/range.mlir deleted file mode 100644 index 69d6a40d521..00000000000 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data/range.mlir +++ /dev/null @@ -1,18 +0,0 @@ -// RUN: tf-tfrt-opt -tf-to-tfrt-data %s | FileCheck %s - -module { - -// CHECK-LABEL: func @main() -> !tfrt_data.dataset - func.func @main() -> tensor<*x!tf_type.variant> { - // CHECK-NEXT: %[[START:.*]] = tfrt.constant.i64 0 - // CHECK-NEXT: %[[STEP:.*]] = tfrt.constant.i64 1 - // CHECK-NEXT: %[[STOP:.*]] = tfrt.constant.i64 1000 - // CHECK-NEXT: %[[RANGE:.*]] = tfrt_data.range_dataset %[[START]], %[[STOP]], %[[STEP]] {element_type = i64} - // CHECK-NEXT: tfrt.return %[[RANGE]] : !tfrt_data.dataset - %1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<1000> : tensor} : () -> tensor - %4 = "tf.RangeDataset"(%1, %3, %2) {device = "", output_shapes = [#tf_type.shape<>], output_types = [i64], metadata = ""} : (tensor, tensor, tensor) -> tensor<*x!tf_type.variant> - func.return %4 : tensor<*x!tf_type.variant> - } -} diff --git a/tensorflow/compiler/mlir/tfrt/tests/xla_launch_fallback.mlir b/tensorflow/compiler/mlir/tfrt/tests/xla_launch_fallback.mlir new file mode 100644 index 00000000000..3905074bfd2 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/tests/xla_launch_fallback.mlir @@ -0,0 +1,82 @@ +// RUN: tf-tfrt-opt -split-input-file -tf-executor-to-tfrt-pipeline="target-gpu=true use-bridge-for-gpu=true func-use-fallback-tensor=true" -tfrt-lower-tf-savedmodel=hoist-invariant-ops=true %s | FileCheck %s --dump-input=fail --dump-input-filter=all + +func.func private @xla_func_0(%arg0: tensor<1x3xf32>, %arg1: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {tf._XlaMustCompile = true, tf._noinline = true, tf._original_func_name = "should_not_be_used"} { + %1 = "tf.AddV2"(%arg0, %arg1) : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + func.return %1 : tensor<1x3xf32> +} + +// CHECK-LABEL: func @main +func.func @main(%arg0: tensor<1x3xf32>) -> tensor<*xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "input:0", outputs = "output:0"}} { + %0 = "tf.VarHandleOp"() {device = "/device:CPU:0", container = "", shared_name = "variable"} : () -> tensor>> + %1 = "tf.ReadVariableOp"(%0) {device = "/device:CPU:0"} : (tensor>>) -> tensor<1x3xf32> + // CHECK: [[INPUT_0:%.*]] = gpurt.transfer_to_device + // CHECK: [[VAR_0:%.*]] = gpurt.maybe_transfer_variable + // CHECK: tfrt_fallback_async.executeop.seq{{.*}}"tf.XlaLaunch"([[INPUT_0]], [[VAR_0]]) + // CHECK-SAME: {function = "xla_func_0"} + // CHECK: gpurt.transfer_from_device + %2 = "tf.XlaLaunch"(%arg0, %1) {_noinline = true, _xla_compile_device_type = "GPU", device = "/device:GPU:0", function = @xla_func_0, operand_segment_sizes = array} : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<*xf32> + func.return %2 : tensor<*xf32> +} + 
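+// As a readability aid, the lowering that the CHECK lines above expect looks
+// roughly like the sketch below. This is illustrative only: the value names
+// are invented here, and the exact operands and attributes are whatever the
+// pipeline actually emits.
+//
+//   %input = gpurt.transfer_to_device %arg0_th
+//   %var = gpurt.maybe_transfer_variable %read_th
+//   %ch, %result = tfrt_fallback_async.executeop.seq(%in_chain)
+//       "tf.XlaLaunch"(%input, %var) {function = "xla_func_0"}
+//   %output_th = gpurt.transfer_from_device %result
+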
+// Check the case when there are multiple XLA clusters. + +func.func private @xla_func_1(%arg0: tensor<1x3xf32>, %arg1: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {tf._XlaMustCompile = true, tf._noinline = true, tf._original_func_name = "should_not_be_used"} { + %1 = "tf.AddV2"(%arg0, %arg1) : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + func.return %1 : tensor<1x3xf32> +} + +func.func private @xla_func_2(%arg0: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {tf._XlaMustCompile = true, tf._noinline = true, tf._original_func_name = "should_not_be_used"} { + %1 = "tf.AddV2"(%arg0, %arg0) : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + func.return %1 : tensor<1x3xf32> +} + +// CHECK-LABEL: func @multi_clusters +func.func @multi_clusters(%arg0: tensor<1x3xf32>) -> tensor<*xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "input:0", outputs = "output:0"}} { + %0 = "tf.VarHandleOp"() {device = "/device:CPU:0", container = "", shared_name = "variable"} : () -> tensor>> + %1 = "tf.ReadVariableOp"(%0) {device = "/device:CPU:0"} : (tensor>>) -> tensor<1x3xf32> + // CHECK: [[INPUT_0:%.*]] = gpurt.transfer_to_device + // CHECK: [[VAR_0:%.*]] = gpurt.maybe_transfer_variable + // CHECK: tfrt_fallback_async.executeop.seq{{.*}}"tf.XlaLaunch"([[INPUT_0]], [[VAR_0]]) + // CHECK-SAME: {function = "xla_func_1"} + // CHECK: [[RESULT_1:%.*]] = gpurt.transfer_from_device + %2 = "tf.XlaLaunch"(%arg0, %1) {_noinline = true, _xla_compile_device_type = "GPU", device = "/device:GPU:0", function = @xla_func_1, operand_segment_sizes = array} : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + + // The output of the above XLA cluster is consumed by the below XLA cluster. + // Currently, the output is first transferred back to CPU and then + // transferred to GPU again, which is unnecessary. + // TODO(b/262280565): Remove unnecessary data transfers when there are + // multiple XLA clusters. + // CHECK: [[INPUT_1:%.*]] = gpurt.transfer_to_device [[RESULT_1]] + // CHECK: tfrt_fallback_async.executeop.seq{{.*}}"tf.XlaLaunch"([[INPUT_1]]) + // CHECK-SAME: {function = "xla_func_2"} + // CHECK: gpurt.transfer_from_device + %3 = "tf.XlaLaunch"(%2) {_noinline = true, _xla_compile_device_type = "GPU", device = "/device:GPU:0", function = @xla_func_2, operand_segment_sizes = array} : (tensor<1x3xf32>) -> tensor<*xf32> + + func.return %3 : tensor<*xf32> +} + + +// Check that unused outputs of the XLA cluster are not transferred. 
+ +func.func private @xla_func_3(%arg0: tensor<1x3xf32>, %arg1: tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor<1x3xf32>) attributes {tf._XlaMustCompile = true, tf._noinline = true, tf._original_func_name = "should_not_be_used"} { + %1 = "tf.AddV2"(%arg0, %arg1) : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + %2 = "tf.DIV"(%arg0, %arg1) : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + func.return %1, %2 : tensor<1x3xf32>, tensor<1x3xf32> +} + +// CHECK-LABEL: func @skip_unused_output +func.func @skip_unused_output(%arg0: tensor<1x3xf32>) -> tensor<*xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "input:0", outputs = "output:0"}} { + %0 = "tf.VarHandleOp"() {device = "/device:CPU:0", container = "", shared_name = "variable"} : () -> tensor>> + %1 = "tf.ReadVariableOp"(%0) {device = "/device:CPU:0"} : (tensor>>) -> tensor<1x3xf32> + // CHECK: [[INPUT_0:%.*]] = gpurt.transfer_to_device + // CHECK: [[VAR_0:%.*]] = gpurt.maybe_transfer_variable + // CHECK: tfrt_fallback_async.executeop.seq{{.*}}"tf.XlaLaunch"([[INPUT_0]], [[VAR_0]]) + // CHECK-SAME: {function = "xla_func_3"} + // Since only one output of the XlaLaunch is used, there is only one data transfer. + // CHECK: gpurt.transfer_from_device + // CHECK-NOT: gpurt.transfer_from_device + %2:2 = "tf.XlaLaunch"(%arg0, %1) {_noinline = true, _xla_compile_device_type = "GPU", device = "/device:GPU:0", function = @xla_func_3, operand_segment_sizes = array} : (tensor<1x3xf32>, tensor<1x3xf32>) -> (tensor<*xf32>, tensor<*xf32>) + func.return %2#0 : tensor<*xf32> +} + + diff --git a/tensorflow/compiler/mlir/tfrt/tf-tfrt-opt.cc b/tensorflow/compiler/mlir/tfrt/tf-tfrt-opt.cc index cb9c5fb9aed..eb5615dc2c6 100644 --- a/tensorflow/compiler/mlir/tfrt/tf-tfrt-opt.cc +++ b/tensorflow/compiler/mlir/tfrt/tf-tfrt-opt.cc @@ -27,9 +27,10 @@ limitations under the License. #include "tensorflow/compiler/mlir/tfrt/jit/opdefs/tf_jitrt_ops.h" #include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h" #include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_test_passes.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/gpu_passes.h" #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir_hlo/gml_st/IR/gml_st_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/passes.h" #include "tensorflow/core/platform/init_main.h" #include "tfrt/init_tfrt_dialects.h" // from @tf_runtime @@ -56,6 +57,7 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); tensorflow::RegisterTPUDialects(®istry); + tensorflow::RegisterGpuDialects(®istry); tfrt::RegisterTFRTDialects(registry); return failed( diff --git a/tensorflow/compiler/mlir/tfrt/transforms/corert_converter.cc b/tensorflow/compiler/mlir/tfrt/transforms/corert_converter.cc index 7990d58a369..41567df0731 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/corert_converter.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/corert_converter.cc @@ -31,7 +31,6 @@ limitations under the License. 
#include "tfrt/core_runtime/opdefs/attributes.h" // from @tf_runtime #include "tfrt/core_runtime/opdefs/core_runtime.h" // from @tf_runtime #include "tfrt/core_runtime/opdefs/types.h" // from @tf_runtime -#include "tfrt/distributed_runtime/opdefs/kernels.h" // from @tf_runtime namespace tensorflow { @@ -41,7 +40,6 @@ CoreRTConverter::CoreRTConverter( : builder_(context), side_effect_analysis_(*side_effect_analysis) { addConversion([](tfrt::compiler::ChainType type) { return type; }); addConversion([](tfrt::corert::OpHandlerType type) { return type; }); - addConversion([](tfrt::dist::DistributedContextType type) { return type; }); addConversion([](tfrt::corert::TensorHandleType type) { return type; }); addConversion([=](mlir::TensorType type) -> llvm::Optional { // Ref types are not supported in both compiler and runtime. @@ -65,8 +63,9 @@ void CoreRTConverter::MaterializeDerivedAttributes(mlir::Operation *op) { } mlir::ArrayAttr CoreRTConverter::CreateOpFuncAttrs( - ArrayRef attrs, - llvm::SmallVector *func_attr_keys) { + const mlir::SymbolTable &symbol_table, ArrayRef attrs, + llvm::SmallVector *func_attr_keys, + bool use_mlir_func_name) { llvm::SmallVector attr_array; for (auto key_and_value : attrs) { auto attr_key = key_and_value.getName(); @@ -74,7 +73,10 @@ mlir::ArrayAttr CoreRTConverter::CreateOpFuncAttrs( if (!IsUnusedTfrtAttribute(attr_key) && attr_value.isa()) { auto func_attr = attr_value.dyn_cast(); - auto converted = ConvertSymbolAttrToStringAttr(func_attr); + auto converted = ConvertSymbolAttrToStringAttr(symbol_table, func_attr, + use_mlir_func_name); + if (!converted) return {}; + mlir::StringAttr key = builder_.getStringAttr(attr_key.strref()); attr_array.push_back(builder_.getArrayAttr({key, converted})); @@ -151,14 +153,7 @@ mlir::Value CoreRTConverter::GetDistributedContext( if (iter != distributed_context_by_func_.end()) { return iter->second; } - ConversionPatternRewriter::InsertionGuard insertion_guard(*rewriter); - rewriter->setInsertionPoint(op); - auto get_dist_ctx_op = rewriter->create( - op->getLoc(), distributed_context_type()); - - mlir::Value result = get_dist_ctx_op.getResult(); - distributed_context_by_func_[func_op.getOperation()] = result; - return result; + return mlir::Value(); } mlir::Value CoreRTConverter::GetRemoteChainManager( @@ -168,18 +163,7 @@ mlir::Value CoreRTConverter::GetRemoteChainManager( if (iter != remote_chain_mgr_by_func_.end()) { return iter->second; } - ConversionPatternRewriter::InsertionGuard insertion_guard(*rewriter); - rewriter->setInsertionPoint(op); - - mlir::Type remote_chain_mgr_type = - builder_.getType<::tfrt::dist::RemoteChainManagerType>(); - mlir::Value dist_ctx = GetDistributedContext(op, rewriter); - auto create_mgr_op = rewriter->create( - op->getLoc(), remote_chain_mgr_type, dist_ctx); - - mlir::Value result = create_mgr_op.getResult(); - remote_chain_mgr_by_func_[func_op.getOperation()] = result; - return result; + return mlir::Value(); } mlir::Value CoreRTConverter::GetLocalSideEffectChain( @@ -229,42 +213,41 @@ mlir::Value CoreRTConverter::GetTaskHandle( return iter->second; } - mlir::Value distributed_context = GetDistributedContext(op, rewriter); - auto task_handle_op = rewriter->create( - op->getLoc(), rewriter->getType(), - distributed_context, task_name); - - task_handle_by_name[task_name] = task_handle_op.getResult(); - return task_handle_op.getResult(); -} - -mlir::Value CoreRTConverter::GetRemoteSideEffectChain( - mlir::Operation *op, StringRef remote_host, - mlir::ConversionPatternRewriter *rewriter) { 
- mlir::Value remote_chain_mgr = GetRemoteChainManager(op, rewriter); - mlir::Value local_chain = GetLocalSideEffectChain(op, rewriter); - mlir::Value task_handle = GetTaskHandle(op, remote_host, rewriter); - mlir::Type remote_obj_id_ty = - rewriter->getType(); - - // Get the remote chain using the tfrt_dist.get_chain_for_task_handle op. - auto get_chain_op = rewriter->create( - op->getLoc(), remote_obj_id_ty, local_chain, remote_chain_mgr, - task_handle); - return get_chain_op.getResult(); + return mlir::Value(); } mlir::StringAttr CoreRTConverter::ConvertSymbolAttrToStringAttr( - mlir::FlatSymbolRefAttr symbol_attr) { + const mlir::SymbolTable &symbol_table, mlir::FlatSymbolRefAttr symbol_attr, + bool use_mlir_func_name) { + if (use_mlir_func_name) { + return mlir::StringAttr::get(builder_.getContext(), + symbol_attr.getValue().str()); + } + // Currently in TF graph to MLIR importing, a "0" is appended to the original - // function name, so we pop it here. The renaming is for TF/XLA v1 bridge - // use cases. Refer to b/142268695, b/141617294 for more context. + // function name. The renaming is for TF/XLA v1 bridge use cases. Refer to + // b/142268695, b/141617294 for more context. // - // In TFRT use cases, in almost every case "0" is the only literal - // appended since TF Graph already guarantee function name uniqueness. - // TODO(b/172092902): Investigate a better way to make the tf_func_name to - // mlir_tf_func_name conversion reversible. - auto func_name = symbol_attr.getValue().drop_back().str(); + // TFRT currently uses the original function library. Hence, we retrieve the + // original function name from the function attributes. Longer term, we + // probably want to export the MLIR functions. + func::FuncOp callee = + symbol_table.lookup(symbol_attr.getValue()); + if (!callee) return mlir::StringAttr(); + + mlir::StringAttr original_func_name = + callee->getAttrOfType("tf._original_func_name"); + std::string func_name; + if (!original_func_name) { + // If there is no function attribute "tf._original_func_name" in the callee, + // we use the workaround to recover the original function name by removing + // the last char of the MLIR function name. + // TODO(b/259138201): Remove this workwaround after we make sure + // "tf._original_func_name" is present in callees in all code paths. + func_name = symbol_attr.getValue().drop_back().str(); + } else { + func_name = original_func_name.str(); + } return mlir::StringAttr::get(builder_.getContext(), func_name); } diff --git a/tensorflow/compiler/mlir/tfrt/transforms/corert_converter.h b/tensorflow/compiler/mlir/tfrt/transforms/corert_converter.h index c7c42c81bbc..c96e1d8a43b 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/corert_converter.h +++ b/tensorflow/compiler/mlir/tfrt/transforms/corert_converter.h @@ -25,7 +25,6 @@ limitations under the License. #include "tfrt/basic_kernels/opdefs/types.h" // from @tf_runtime #include "tfrt/core_runtime/opdefs/core_runtime.h" // from @tf_runtime #include "tfrt/core_runtime/opdefs/types.h" // from @tf_runtime -#include "tfrt/distributed_runtime/opdefs/types.h" // from @tf_runtime namespace tensorflow { @@ -49,9 +48,16 @@ class CoreRTConverter : public mlir::TypeConverter { // named attribute lists, which is an array of pairs, with keys and values // both being string attributes. The values represent function names. // This method also populates a vector of attribute keys to be removed. 
+ // If `use_mlir_func_name` is true, the function name given by MLIR will be + // used, which could be different from the original function name in the graph + // function library. This is used when the original function has been changed + // by lowering passes, and hence it needs to be exported to function library + // for runtime to use. mlir::ArrayAttr CreateOpFuncAttrs( + const mlir::SymbolTable &symbol_table, llvm::ArrayRef attrs, - llvm::SmallVector *func_attr_keys); + llvm::SmallVector *func_attr_keys, + bool use_mlir_func_name = false); // Parse the device name of `op` to TFRT's device name. For example, "/CPU:0" // will be parsed as "cpu". Return None if no device is assigned. @@ -69,24 +75,19 @@ class CoreRTConverter : public mlir::TypeConverter { // Get a DistributedContext value to be used by the given op. The // DistributedContext value should be shared by all operations in the body - // of the same FuncOp. If there does not exist one, insert a - // GetDistributedContext op right before the given op and return the result - // value. + // of the same FuncOp. If there does not exist one, return a null Value. mlir::Value GetDistributedContext(mlir::Operation *op, mlir::ConversionPatternRewriter *rewriter); // Get a RemoteChainManager value to be used by the given op. The // RemoteChainManager value should be shared by all operations in the body - // of the same FuncOp. If there does not exist one, insert a - // tfrt_dist.test_create_remote_chain_manager op right before the given op and - // return the result value. + // of the same FuncOp. If there does not exist one, return a null Value. mlir::Value GetRemoteChainManager(mlir::Operation *op, mlir::ConversionPatternRewriter *rewriter); // Get a TaskHandle value with the given task name. If the TaskHandle value // has already been created for the given task name within the same FuncOp, - // return this TaskHandle value. Otherwise, insert a tfrt_dist.get_task_handle - // op right before the given op and return the result value. + // return this TaskHandle value. Otherwise, return a null Value. mlir::Value GetTaskHandle(mlir::Operation *op, StringRef task_name, mlir::ConversionPatternRewriter *rewriter); @@ -102,11 +103,6 @@ class CoreRTConverter : public mlir::TypeConverter { mlir::Value GetLocalSideEffectChain( mlir::Operation *op, mlir::ConversionPatternRewriter *rewriter); - // Return a remote chain for side effects for `op`. 
- mlir::Value GetRemoteSideEffectChain( - mlir::Operation *op, StringRef remote_host, - mlir::ConversionPatternRewriter *rewriter); - mlir::Type op_handler_type() { return builder_.getType<::tfrt::corert::OpHandlerType>(); } @@ -119,10 +115,6 @@ class CoreRTConverter : public mlir::TypeConverter { return builder_.getType<::tfrt::compiler::ChainType>(); } - mlir::Type distributed_context_type() { - return builder_.getType<::tfrt::dist::DistributedContextType>(); - } - mlir::Builder &builder() { return builder_; } private: @@ -156,7 +148,8 @@ class CoreRTConverter : public mlir::TypeConverter { mlir::TypeAttr ConvertTypeAttribute(mlir::TypeAttr type_attr); mlir::StringAttr ConvertSymbolAttrToStringAttr( - mlir::FlatSymbolRefAttr symbol_attr); + const mlir::SymbolTable &symbol_table, + mlir::FlatSymbolRefAttr symbol_attr, bool use_mlir_func_name = false); mlir::Builder builder_; diff --git a/tensorflow/compiler/mlir/tfrt/transforms/deduplicate_batch_function.cc b/tensorflow/compiler/mlir/tfrt/transforms/deduplicate_batch_function.cc index 970fa7ae959..31f3cc65d7b 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/deduplicate_batch_function.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/deduplicate_batch_function.cc @@ -131,10 +131,10 @@ mlir::LogicalResult DeduplicateFunctionsInovkedByBatchFunction::Run() { // User is not a BatchFunctionOp if (!op) return false; if (shared_name.empty()) { - shared_name = op.shared_name(); + shared_name = op.getSharedName(); return true; } - return shared_name == op.shared_name(); + return shared_name == op.getSharedName(); })) { shared_name_to_func_ops[shared_name].push_back(func); } diff --git a/tensorflow/compiler/mlir/tfrt/transforms/fuse_tpu_compile_and_execute_ops.cc b/tensorflow/compiler/mlir/tfrt/transforms/fuse_tpu_compile_and_execute_ops.cc index 26ad592d6f3..fd1d4cce445 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/fuse_tpu_compile_and_execute_ops.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/fuse_tpu_compile_and_execute_ops.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project @@ -26,6 +27,41 @@ limitations under the License. namespace tensorflow { namespace { +void RecursivelyMoveOp(mlir::TF::_TPUCompileMlirOp compile_op, + mlir::Operation *op_candidate, + llvm::SmallDenseSet *ops_to_move) { + if (!op_candidate || !ops_to_move->contains(op_candidate)) return; + // Move the parent first. + for (const auto &operand : op_candidate->getOperands()) { + RecursivelyMoveOp(compile_op, operand.getDefiningOp(), ops_to_move); + } + op_candidate->moveBefore(compile_op); + // Erase the op to avoid moving the common ancestor. + ops_to_move->erase(op_candidate); +} + +// Move exec_op's args defining ops before the compile op to group compile op +// and execute op. +void GroupCompileOpAndExecuteOp(mlir::func::FuncOp func, + mlir::TF::_TPUCompileMlirOp compile_op, + mlir::TF::TPUExecuteOp exec_op) { + // Collect the ops between compile op and execute op. + llvm::SmallDenseSet ops_to_move; + bool collect = false; + func.walk( + [&ops_to_move, &collect, &compile_op, &exec_op](mlir::Operation *op) { + if (collect) ops_to_move.insert(op); + if (op == compile_op.getOperation()) collect = true; + return op == exec_op.getOperation() ? 
mlir::WalkResult::interrupt() + : mlir::WalkResult::advance(); + }); + // Recursively move the defining op of the execute op argument in front of the + // compile op such that the compile op and execute op are grouped together. + for (const auto &operand : exec_op.getArgs()) { + RecursivelyMoveOp(compile_op, operand.getDefiningOp(), &ops_to_move); + } +} + // This pass rewrites tf._TPUCompileMlirOp and tf.TPUExecuteOp into a single // tf.TPUCompileMlirAndExecuteOp. Also it removes the unnecessary // TPUCompileSucceededAssertOp. @@ -79,7 +115,7 @@ class FuseTpuCompileAndExecutePass mlir::OpBuilder builder(&func.getBody()); for (auto exec_op : tpu_execute_ops) { - auto compile_cache_entry = exec_op.key(); + auto compile_cache_entry = exec_op.getKey(); auto compile_op = ::llvm::dyn_cast( compile_cache_entry.getDefiningOp()); if (!compile_op) { @@ -88,6 +124,8 @@ class FuseTpuCompileAndExecutePass return; } + GroupCompileOpAndExecuteOp(func, compile_op, exec_op); + builder.setInsertionPointAfter(compile_op); llvm::SmallVector output_types; output_types.push_back(mlir::RankedTensorType::get( @@ -97,20 +135,20 @@ class FuseTpuCompileAndExecutePass llvm::SmallVector static_shaped_operand_indices_attr; llvm::SmallVector static_shape_tensors; llvm::SmallVector exec_op_args; - exec_op_args.resize(exec_op.args().size()); + exec_op_args.resize(exec_op.getArgs().size()); auto &static_shaped_operands = exec_to_static_shaped_operands_map[exec_op]; - for (int i = 0; i < exec_op.args().size(); ++i) { + for (int i = 0; i < exec_op.getArgs().size(); ++i) { auto iter = static_shaped_operands.find(i); if (iter != static_shaped_operands.end()) { static_shaped_operand_indices_attr.push_back(iter->first); - static_shape_tensors.push_back(iter->second.static_shape()); - exec_op_args[i] = iter->second.input(); + static_shape_tensors.push_back(iter->second.getStaticShape()); + exec_op_args[i] = iter->second.getInput(); // The first operand is the input tensor, while the second operand is // the static shape tensor, hence the drop_back here. iter->second->replaceAllUsesWith( - mlir::ValueRange({iter->second.input()})); + mlir::ValueRange({iter->second.getInput()})); iter->second->erase(); } else { exec_op_args[i] = exec_op->getOperand(i); @@ -126,12 +164,13 @@ class FuseTpuCompileAndExecutePass exec_op.getLoc(), output_types, exec_op_args, static_shape_tensors, builder.getI32ArrayAttr(static_shaped_operand_indices_attr), - compile_op.mlir_module(), compile_op.metadata(), producer_name); + compile_op.getMlirModule(), compile_op.getMetadata(), + producer_name); - exec_op.replaceAllUsesWith(compile_and_execute_op.results()); - for (auto program_result : compile_op.program()) { + exec_op.replaceAllUsesWith(compile_and_execute_op.getResults()); + for (auto program_result : compile_op.getProgram()) { program_result.replaceAllUsesWith( - compile_and_execute_op.rendezvous_key_base()); + compile_and_execute_op.getRendezvousKeyBase()); } assert(exec_op.use_empty()); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/gpu_passes.cc b/tensorflow/compiler/mlir/tfrt/transforms/gpu_passes.cc new file mode 100644 index 00000000000..732d5fcd07f --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/gpu_passes.cc @@ -0,0 +1,32 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tfrt/transforms/gpu_passes.h" + +#include "tensorflow/compiler/mlir/tfrt/ir/gpu_ops.h" + +namespace tensorflow { + +void RegisterGpuDialects(mlir::DialectRegistry *registry) { + registry->insert(); +} + +void AddGpuTargetDialectAndPatterns(mlir::MLIRContext *context, + mlir::ConversionTarget *target, + mlir::RewritePatternSet *patterns) { + target->addLegalDialect(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/gpu_passes.h b/tensorflow/compiler/mlir/tfrt/transforms/gpu_passes.h new file mode 100644 index 00000000000..2af46e34ce9 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/gpu_passes.h @@ -0,0 +1,37 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_GPU_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_GPU_PASSES_H_ + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassOptions.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace tensorflow { + +// Registers dialects used in TFRT GPU lowering. +void RegisterGpuDialects(mlir::DialectRegistry *registry); + +// Adds a target dialect and rewrite patterns for TFRT GPU lowering. +void AddGpuTargetDialectAndPatterns(mlir::MLIRContext *context, + mlir::ConversionTarget *target, + mlir::RewritePatternSet *patterns); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_GPU_PASSES_H_ diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lower_saved_model.cc b/tensorflow/compiler/mlir/tfrt/transforms/lower_saved_model.cc index 758d76a838a..75f39ac9b9d 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/lower_saved_model.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/lower_saved_model.cc @@ -23,7 +23,7 @@ limitations under the License. #include "llvm/ADT/StringSet.h" #include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project @@ -37,6 +37,9 @@ limitations under the License. 
namespace tensorflow { namespace { +using ::mlir::tf_saved_model::kTfSavedModelExportedNamesAttr; +using ::mlir::tf_saved_model::kTfSavedModelIndexPathAttr; + constexpr char kCpuDeviceName[] = "/job:localhost/replica:0/task:0/device:CPU:0"; @@ -78,7 +81,7 @@ struct HoistInfo { // Mapping from the old values produced by hoisted ops before hoisting to the // new values after hoisting. - mlir::BlockAndValueMapping value_mapping; + mlir::IRMapping value_mapping; // `hoisted_values` is to keep all values that are produced by hoisted ops // but used by non-hoisted ops. These values will be replaced by results of @@ -155,7 +158,7 @@ void ReplaceHoistedValues( builder.getStrArrayAttr(container_arr)); get_resource_op->setAttr("device", builder.getStringAttr(device)); - auto new_values = get_resource_op.results(); + auto new_values = get_resource_op.getResults(); for (auto iter : llvm::zip(old_values, new_values)) { auto old_value = std::get<0>(iter); auto new_value = std::get<1>(iter); @@ -165,10 +168,11 @@ void ReplaceHoistedValues( } } -bool OnlyHasReadEffect(mlir::Operation *op) { +bool OnlyHasReadOrNoEffect(mlir::Operation *op) { auto interface = llvm::dyn_cast(op); if (!interface) return false; - return interface.onlyHasEffect(); + return interface.onlyHasEffect() || + interface.hasNoEffect(); } bool CanHoist(const llvm::DenseSet &read_only_vars, @@ -177,7 +181,7 @@ bool CanHoist(const llvm::DenseSet &read_only_vars, if (op->mightHaveTrait()) return false; // Non-side-effecting ops can be hoisted. - if (mlir::MemoryEffectOpInterface::hasNoEffect(op)) return true; + if (mlir::isMemoryEffectFree(op)) return true; // ResourceHandle ops can be hoisted. if (llvm::isa(op)) @@ -186,7 +190,7 @@ bool CanHoist(const llvm::DenseSet &read_only_vars, // If it is ReadVariableOp and the variable is readonly, it can be hoisted. 
if (auto read_var_op = llvm::dyn_cast(op)) { if (auto var_handle_op = llvm::dyn_cast_or_null( - read_var_op.resource().getDefiningOp())) { + read_var_op.getResource().getDefiningOp())) { if (read_only_vars.count(GetResourceHandle(var_handle_op)) > 0) return true; } @@ -197,7 +201,7 @@ bool CanHoist(const llvm::DenseSet &read_only_vars, if (auto lookup_table_size_op = llvm::dyn_cast(op)) { if (auto hash_table_op = llvm::dyn_cast_or_null( - lookup_table_size_op.table_handle().getDefiningOp())) { + lookup_table_size_op.getTableHandle().getDefiningOp())) { if (read_only_vars.count(GetResourceHandle(hash_table_op)) > 0) return true; } @@ -332,7 +336,7 @@ void HoistInvariantOps(mlir::ModuleOp module) { const auto &vars = iter.second; if (std::all_of(vars.begin(), vars.end(), [](mlir::Operation *op) { for (auto *user : op->getUsers()) { - if (!OnlyHasReadEffect(user)) return false; + if (!OnlyHasReadOrNoEffect(user)) return false; } return true; })) { @@ -458,7 +462,7 @@ class LowerTFSavedModelPass mlir::OpBuilder builder(&getContext()); auto resource_id = builder.getStringAttr("tf.resource_name"); auto bound_id = builder.getStringAttr("tf_saved_model.bound_input"); - auto path_id = builder.getStringAttr("tf_saved_model.index_path"); + auto path_id = builder.getStringAttr(kTfSavedModelIndexPathAttr); module.walk([resource_id, bound_id, path_id, &builder](mlir::Operation *op) mutable { @@ -478,7 +482,7 @@ class LowerTFSavedModelPass func_op.removeResultAttr(i, path_id); } if (auto exported_names = func_op->getAttrOfType( - "tf_saved_model.exported_names")) { + kTfSavedModelExportedNamesAttr)) { bool is_session_initializer = IsSessionInitializer(func_op); // Create a function for each exported name. @@ -486,7 +490,7 @@ class LowerTFSavedModelPass // TODO(b/148477882): TFRT dialect should have similar concepts of // exported names so that a function can be referenced by multiple // exported names. - func_op->removeAttr("tf_saved_model.exported_names"); + func_op->removeAttr(kTfSavedModelExportedNamesAttr); for (auto exported_name : exported_names) { auto exported_func_op = func_op.clone(); exported_func_op.setName(exported_name.cast()); @@ -565,14 +569,14 @@ class ConvertReferenceVariableToResourceVariablePass mlir::LogicalResult ConvertReferenceVariableToResourceVariable( mlir::TF::VariableV2Op var_op) { auto tensor_type = - mlir::TF::DropRefType(var_op.ref().getType()).cast(); + mlir::TF::DropRefType(var_op.getRef().getType()).cast(); llvm::SmallVector identity_ops; llvm::SmallVector assign_ops; llvm::SmallVector, 4> side_effect_free_ops; - for (mlir::OpOperand &use : var_op.ref().getUses()) { + for (mlir::OpOperand &use : var_op.getRef().getUses()) { mlir::Operation *user = use.getOwner(); if (auto identity = llvm::dyn_cast(user)) { @@ -581,11 +585,11 @@ mlir::LogicalResult ConvertReferenceVariableToResourceVariable( } else if (auto assign = llvm::dyn_cast(user)) { // Conservatively we only allow the case that the output of this tf.Assign // is not consumed by any other ops. 
- if (assign.output_ref().use_empty()) { + if (assign.getOutputRef().use_empty()) { assign_ops.push_back(assign); continue; } - } else if (mlir::MemoryEffectOpInterface::hasNoEffect(user)) { + } else if (mlir::isMemoryEffectFree(user)) { side_effect_free_ops.push_back({user, use.getOperandNumber()}); continue; } @@ -603,7 +607,7 @@ mlir::LogicalResult ConvertReferenceVariableToResourceVariable( {}, mlir::TF::ResourceType::get( llvm::ArrayRef{tensor_type}, builder.getContext())), - var_op.container(), var_op.shared_name()); + var_op.getContainer(), var_op.getSharedName()); for (auto op : identity_ops) { // Set insertion point to this identity_op so that the side-effect @@ -611,7 +615,7 @@ mlir::LogicalResult ConvertReferenceVariableToResourceVariable( builder.setInsertionPoint(op); auto read_var_op = builder.create( op.getLoc(), op.getType(), var_handle_op); - op.replaceAllUsesWith(read_var_op.value()); + op.replaceAllUsesWith(read_var_op.getValue()); op.erase(); } @@ -620,7 +624,7 @@ mlir::LogicalResult ConvertReferenceVariableToResourceVariable( // dominating the newly created op. builder.setInsertionPoint(op); builder.create(op.getLoc(), var_handle_op, - op.value()); + op.getValue()); op.erase(); } @@ -633,7 +637,7 @@ mlir::LogicalResult ConvertReferenceVariableToResourceVariable( // Create a new read variable op, so that the side-effects are preserved. auto read_var_op = builder.create( op->getLoc(), tensor_type, var_handle_op); - op->setOperand(idx, read_var_op.value()); + op->setOperand(idx, read_var_op.getValue()); } return mlir::success(); @@ -650,7 +654,7 @@ void ConvertReferenceVariableToResourceVariablePass::runOnOperation() { // First, we collect all variables' corresponding tf.VariableV2 ops. module.walk([&ref_vars](mlir::TF::VariableV2Op op) { - if (op.shared_name().empty()) { + if (op.getSharedName().empty()) { op.emitOpError() << "unable to convert reference variables with empty shared_names."; return mlir::WalkResult::interrupt(); @@ -661,7 +665,7 @@ void ConvertReferenceVariableToResourceVariablePass::runOnOperation() { device = device_attr.getValue(); } - ref_vars[{device, op.container(), op.shared_name()}].push_back(op); + ref_vars[{device, op.getContainer(), op.getSharedName()}].push_back(op); return mlir::WalkResult::advance(); }); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/merge_tf_if_ops.cc b/tensorflow/compiler/mlir/tfrt/transforms/merge_tf_if_ops.cc index ba0a0141acb..527f1c7c996 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/merge_tf_if_ops.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/merge_tf_if_ops.cc @@ -117,7 +117,7 @@ class MergeTfIfOpsPass auto if_op = llvm::dyn_cast(&op); // Skip non tf.If ops and tf.If ops that are side-effecting. - if (!if_op || !if_op.is_stateless()) continue; + if (!if_op || !if_op.getIsStateless()) continue; if_ops_to_merge[if_op].push_back(if_op); } @@ -169,33 +169,34 @@ class MergeTfIfOpsPass } auto branch_function_type = builder.getFunctionType( - if_ops.front().input().getTypes(), new_result_types); + if_ops.front().getInput().getTypes(), new_result_types); // Create new branches for the merged tf.If op. 
auto then_branch_name = CreateBranchFunction( builder, loc, branch_prefix, /*branch_suffix=*/"_then", branch_function_type, if_ops, - [](mlir::TF::IfOp op) { return op.then_branchAttr(); }); + [](mlir::TF::IfOp op) { return op.getThenBranchAttr(); }); auto else_branch_name = CreateBranchFunction( builder, loc, branch_prefix, /*branch_suffix=*/"_else", branch_function_type, if_ops, - [](mlir::TF::IfOp op) { return op.else_branchAttr(); }); + [](mlir::TF::IfOp op) { return op.getElseBranchAttr(); }); mlir::OpBuilder::InsertionGuard guard(builder); builder.setInsertionPoint(if_ops.front()); // Create the merged tf.If op using the new branches. auto new_if_op = builder.create( - loc, new_result_types, if_ops.front().cond(), if_ops.front().input(), - then_branch_name, else_branch_name, /*is_stateless=*/true); + loc, new_result_types, if_ops.front().getCond(), + if_ops.front().getInput(), then_branch_name, else_branch_name, + /*is_stateless=*/true); // Replace the uses of results of the original tf.If ops with the results of // the merged tf.If op. - auto new_result_iter = new_if_op.output().begin(); + auto new_result_iter = new_if_op.getOutput().begin(); for (auto if_op : if_ops) { - for (auto result : if_op.output()) { - assert(new_result_iter != new_if_op.output().end()); + for (auto result : if_op.getOutput()) { + assert(new_result_iter != new_if_op.getOutput().end()); result.replaceAllUsesWith(*new_result_iter); ++new_result_iter; } @@ -233,7 +234,7 @@ class MergeTfIfOpsPass empty_string_attr); // The results are the concatenation of the original branches. - results.append(call_op.output().begin(), call_op.output().end()); + results.append(call_op.getOutput().begin(), call_op.getOutput().end()); } builder.create(loc, results); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/optimize.cc b/tensorflow/compiler/mlir/tfrt/transforms/optimize.cc index 9f9e880685c..b73b84c330e 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/optimize.cc @@ -39,7 +39,7 @@ class FoldDeviceIndex : public mlir::OpRewritePattern { return mlir::failure(); int32_t i = 0; - mlir::ArrayAttr device_names = op.device_names(); + mlir::ArrayAttr device_names = op.getDeviceNames(); for (; i < device_names.size(); ++i) { auto device_name = device_names[i].cast().getValue(); if (device_name == parsed_name.type) break; diff --git a/tensorflow/compiler/mlir/tfrt/transforms/optimize_tf_control_flow_side_effect.cc b/tensorflow/compiler/mlir/tfrt/transforms/optimize_tf_control_flow_side_effect.cc index 6cf747d578c..29e7e744b0c 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/optimize_tf_control_flow_side_effect.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/optimize_tf_control_flow_side_effect.cc @@ -32,7 +32,7 @@ bool FunctionHasSideEffect( auto op_has_side_effect = [&](mlir::Operation* op) { if (auto while_op = llvm::dyn_cast(op)) { - if (while_op.is_stateless()) return false; + if (while_op.getIsStateless()) return false; return FunctionHasSideEffect(while_op.cond_function(), function_side_effect) || @@ -41,7 +41,7 @@ bool FunctionHasSideEffect( } if (auto if_op = llvm::dyn_cast(op)) { - if (if_op.is_stateless()) return false; + if (if_op.getIsStateless()) return false; return FunctionHasSideEffect(if_op.else_function(), function_side_effect) || @@ -53,7 +53,7 @@ bool FunctionHasSideEffect( // ops' callee functions contain them, we treat them as non-side-effecting. 
if (llvm::isa(op)) return false; - return !mlir::MemoryEffectOpInterface::hasNoEffect(op); + return !mlir::isMemoryEffectFree(op); }; // Speculatively setting the function to have no side effect to avoid infinite @@ -96,7 +96,7 @@ class OptimizeTfControlFlowSideEffectPass mlir::Builder builder(module.getContext()); module.walk([&](mlir::Operation* op) { if (auto while_op = llvm::dyn_cast(op)) { - if (while_op.is_stateless()) return; + if (while_op.getIsStateless()) return; if (!FunctionHasSideEffect(while_op.cond_function(), function_side_effect) && @@ -107,7 +107,7 @@ class OptimizeTfControlFlowSideEffectPass } if (auto if_op = llvm::dyn_cast(op)) { - if (if_op.is_stateless()) return; + if (if_op.getIsStateless()) return; if (!FunctionHasSideEffect(if_op.else_function(), function_side_effect) && diff --git a/tensorflow/compiler/mlir/tfrt/transforms/passes.h b/tensorflow/compiler/mlir/tfrt/transforms/passes.h index afebf94a121..e1e7df8fb11 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/passes.h +++ b/tensorflow/compiler/mlir/tfrt/transforms/passes.h @@ -23,6 +23,7 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" #include "tensorflow/compiler/mlir/tfrt/transforms/tpu_passes.h" +#include "tensorflow/tsl/platform/status.h" namespace mlir { class PassManager; @@ -75,6 +76,11 @@ CreateOptimizeTfForTfrtPass(); class CoreRTConverter; +// Create a pass that sink in the var handle op to the callee function when +// proper. +std::unique_ptr> +CreateSinkInInvariantOpsPass(); + // Create a pass that rewrites tf_saved_model dialect's ops according to TFRT's // requirements. std::unique_ptr> @@ -92,12 +98,6 @@ mlir::LogicalResult TFSavedModelToCoreRTConversionPassRun( mlir::ConversionTarget* target, mlir::RewritePatternSet* patterns, CoreRTConverter* corert_converter); -// Create an operation pass that converts each tfrt_dist.remote_execute_func op -// into a combination of tfrt_dist.register_tfrt_function op and -// tfrt_dist.remote_execute op. -std::unique_ptr> -CreateDistRemoteRunEncapsulatePass(); - // Create an operation pass that removes the device attribute from every // corert.executeop. std::unique_ptr> @@ -167,12 +167,17 @@ struct TfrtPipelineOptions llvm::cl::desc("If true, fallback executeops that produce inputs to tpu " "program will use tpu host allocator."), llvm::cl::init(false)}; - Option enable_native_ops{ - *this, "enable-native-ops", - llvm::cl::desc( - "If true, native ops will be used on an opt-in basis instead of " - "fallback ops. If false, no native ops are used."), - llvm::cl::init(true)}; + + Option target_gpu{ + *this, "target-gpu", + llvm::cl::desc("If true, target GPU compiler passes."), + llvm::cl::init(false)}; + + // TODO(b/260915352): Remove the flag and default to using bridge. 
+ Option use_bridge_for_gpu{ + *this, "use-bridge-for-gpu", + llvm::cl::desc("If true, GPU bridge is used."), llvm::cl::init(false)}; + Option func_use_fallback_tensor{ *this, "func-use-fallback-tensor", llvm::cl::desc( @@ -192,6 +197,12 @@ struct TfrtPipelineOptions "out to run during loading."), llvm::cl::init(false)}; + Option sink_in_invariant_ops{ + *this, "sink-in-invariant-ops", + llvm::cl::desc("If true, sink the selected invariant ops in to the " + "nested functions to facilitate invariant ops hoisting."), + llvm::cl::init(false)}; + Option cost_threshold{ *this, "tfrt-cost-threshold", llvm::cl::desc( @@ -243,8 +254,8 @@ void CreateTFExecutorToTFPipeline(mlir::OpPassManager& pm, // Creates a pipeline of passes that lowers MLIR TF dialect from tf.function to // TFRT dialect. SavedModel related conversions are not included. -void CreateTfExecutorToTfrtPipeline(mlir::PassManager& pm, - const TfrtPipelineOptions& options); +tsl::Status CreateTfExecutorToTfrtPipeline(mlir::PassManager& pm, + const TfrtPipelineOptions& options); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/remote_run_encapsulate.cc b/tensorflow/compiler/mlir/tfrt/transforms/remote_run_encapsulate.cc deleted file mode 100644 index 6b869311823..00000000000 --- a/tensorflow/compiler/mlir/tfrt/transforms/remote_run_encapsulate.cc +++ /dev/null @@ -1,245 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This pass converts each tfrt_dist.remote_execute_func op into a combination -// of tfrt_dist.register_tfrt_function op and tfrt_dist.remote_execute op. The -// function to be executed in the remote host will be serialized as a string -// attribute of the tfrt_dist.register_tfrt_function op. 
- -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/Sequence.h" -#include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/SymbolTable.h" // from @llvm-project -#include "mlir/IR/Types.h" // from @llvm-project -#include "mlir/IR/Visitors.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/Transforms/Passes.h" // from @llvm-project -#include "mlir/Transforms/RegionUtils.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" -#include "tensorflow/core/util/device_name_utils.h" -#include "tfrt/basic_kernels/opdefs/types.h" // from @tf_runtime -#include "tfrt/core_runtime/opdefs/types.h" // from @tf_runtime -#include "tfrt/distributed_runtime/opdefs/kernels.h" // from @tf_runtime -#include "tfrt/distributed_runtime/opdefs/types.h" // from @tf_runtime -#include "tfrt/test_kernels/opdefs/test_kernels.h" // from @tf_runtime - -namespace tensorflow { - -namespace { - -constexpr const char* kHost = "host"; -constexpr const char* kTFRTDevice = "tfrt.device"; - -struct DistRemoteRunEncapsulatePass - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DistRemoteRunEncapsulatePass) - - llvm::StringRef getArgument() const final { - return "tfrt-dist-remote-run-encapsulate"; - } - llvm::StringRef getDescription() const final { - return "This pass looks for a remote_run_func and serialize the callee to " - "a string attribute attached to a remote_register operation, " - "followed by a remote_execute invocation."; - } - void runOnOperation() override; - - void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); - } -}; - -LogicalResult EncapsulateFuncAndSerialize(func::FuncOp entry_func, - std::string* serialized_func_module) { - ModuleOp module = entry_func->getParentOfType(); - SymbolTable entry_module_table(module); - SmallVector referenced({entry_func}); - - // Create a new module to hold func and all referenced functions. - OwningOpRef module_for_func = - ModuleOp::create(mlir::UnknownLoc::get(entry_func.getContext())); - SymbolTable symbol_table(module_for_func.get()); - - while (!referenced.empty()) { - func::FuncOp func = referenced.pop_back_val(); - - // Skip functions that have already been cloned into new module. - if (symbol_table.lookup(func.getName())) continue; - - // Find any SymbolRefAttr in func that maps to a FuncOp. We need to clone - // all found FuncOps to new_module to make sure new_module is - // self-contained. - Optional uses = SymbolTable::getSymbolUses(func); - assert(uses && "expected to be able to collect symbol uses"); - for (SymbolTable::SymbolUse use : *uses) { - func::FuncOp referenced_func = entry_module_table.lookup( - use.getSymbolRef().cast().getValue()); - - // Skip Symbols that do not map to a function. 
- if (!referenced_func) continue; - - referenced.emplace_back(referenced_func); - } - - func::FuncOp clone = func.clone(); - if (clone.getName() == entry_func.getName()) { - clone.setPublic(); - } else { - clone.setPrivate(); - } - symbol_table.insert(clone); - } - - *serialized_func_module = - tensorflow::SerializeMlirModule(module_for_func.get()); - return success(); -} - -void DistRemoteRunEncapsulatePass::runOnOperation() { - mlir::TF::RuntimeDevices devices; - ModuleOp module = getOperation(); - SymbolTable symtab(module); - Type chain_type = tfrt::compiler::ChainType::get(&getContext()); - Type remote_object_id_ty = tfrt::dist::RemoteObjectIdType::get(&getContext()); - Type tensor_handle_ty = tfrt::corert::TensorHandleType::get(&getContext()); - module.walk([&](tfrt::dist::RemoteExecuteFuncOp remote_exec_op) { - FlatSymbolRefAttr callee_sym = remote_exec_op.getCalleeAttr(); - func::FuncOp callee = symtab.lookup(callee_sym.getValue()); - if (!callee) { - remote_exec_op.emitOpError("callee function ") - << callee_sym.getValue() << " is not found"; - signalPassFailure(); - return WalkResult::interrupt(); - } - std::string txt_module; - if (failed(EncapsulateFuncAndSerialize(callee, &txt_module))) { - remote_exec_op.emitOpError("failed to serialize the callee function ") - << callee.getName(); - signalPassFailure(); - return WalkResult::interrupt(); - } - Location loc = remote_exec_op.getLoc(); - StringAttr callee_name = - StringAttr::get(&getContext(), callee_sym.getValue()); - OpBuilder builder(remote_exec_op); - auto register_op = builder.create( - loc, chain_type, remote_exec_op.getInOpChain(), - remote_exec_op.getContext(), remote_exec_op.getRemoteTask(), - StringAttr::get(&getContext(), txt_module), callee_name); - - // Build the device assignment for the results - // TODO(tfrt-devs): Define properly MLIR types and operations - SmallVector result_devices; - for (const auto& result : llvm::enumerate(remote_exec_op.getResults())) { - StringAttr device = - callee.getResultAttrOfType(result.index(), kTFRTDevice); - if (!device) { - // The result might not have the device attribute if it is added by - // the tf-to-tfrt pass. Use the first CPU on the remote host as the - // device of this result. - DeviceNameUtils::ParsedName parsed_name; - if (StringAttr host_attr = callee->getAttrOfType(kHost)) { - auto host = host_attr.getValue(); - DeviceNameUtils::ParseFullName({host.data(), host.size()}, - &parsed_name); - } - parsed_name.has_type = true; - parsed_name.type = "CPU"; - parsed_name.has_id = true; - parsed_name.id = 0; - device = StringAttr::get( - &getContext(), DeviceNameUtils::ParsedNameToString(parsed_name)); - } - result_devices.push_back(std::move(device)); - } - // IDEA(donglin): Update the create_remote_execute_spec kernel to use Device - // object instead of Device string. - Type remote_spec_ty = tfrt::dist::RemoteExecuteSpecType::get(&getContext()); - auto result_devices_attr = ArrayAttr::get(&getContext(), result_devices); - auto remote_spec = builder.create( - loc, remote_spec_ty, remote_exec_op.getContext(), result_devices_attr); - // If original argument is already tfrt_dist.remote_object_id, use it - // directly. If it is TensorHandle, insert an op to extract the - // tfrt_dist.remote_object_id from it. Otherwise, emit an error. 
- SmallVector arguments; - for (Value value : remote_exec_op.getCalleeArgs()) { - if (value.getType().isa()) { - arguments.push_back(value); - } else if (value.getType().isa()) { - auto new_op = builder.create( - loc, remote_object_id_ty, value); - arguments.push_back(new_op.getResult()); - } else { - remote_exec_op.emitOpError( - "callee argument type should be either " - "TensorHandle or RemoteObjectId"); - signalPassFailure(); - return WalkResult::interrupt(); - } - } - // Result types are 1 chain, followed by `num_th_results + 1` - // tfrt_dist.remote_object_id results, followed by `num_th_results` - // corert.tensorhandle results. - int32_t num_th_results = remote_exec_op.getResults().size() - 1; - SmallVector result_types; - result_types.push_back(chain_type); - for (int count : llvm::seq(0, num_th_results + 1)) { - (void)count; - result_types.push_back(remote_object_id_ty); - } - for (int count : llvm::seq(0, num_th_results)) { - (void)count; - result_types.push_back(tensor_handle_ty); - } - auto new_remote_exec_th_op = builder.create( - loc, result_types, register_op.getOutOpChain(), - remote_exec_op.getContext(), remote_exec_op.getRemoteTask(), - remote_spec, num_th_results, callee_name.getValue(), - std::move(arguments)); - // The part of the new results to replace the original results are 2 chains, - // followed `num_th_results` corert.tesnorhandle results from the callee - // function. - SmallVector new_results; - new_results.push_back(new_remote_exec_th_op.getResult(0)); - new_results.push_back(new_remote_exec_th_op.getResult(1)); - for (int i : llvm::seq(0, num_th_results)) { - new_results.push_back( - new_remote_exec_th_op.getResult(i + 2 + num_th_results)); - } - remote_exec_op.replaceAllUsesWith(new_results); - remote_exec_op.erase(); - - return WalkResult::advance(); - }); -} - -} // namespace - -std::unique_ptr> CreateDistRemoteRunEncapsulatePass() { - return std::make_unique(); -} - -static PassRegistration pass; - -} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/remove_device_attribute.cc b/tensorflow/compiler/mlir/tfrt/transforms/remove_device_attribute.cc index 7051e8d5656..c765e08742a 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/remove_device_attribute.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/remove_device_attribute.cc @@ -63,7 +63,7 @@ void RemoveDeviceAttributePass::runOnOperation() { OpBuilder builder(execute_op); auto new_execute_op = builder.create( execute_op.getLoc(), execute_op.getResultTypes(), - execute_op.getOpHandler(), execute_op.operands(), new_op_attrs, + execute_op.getOpHandler(), execute_op.getArguments(), new_op_attrs, op_func_attrs, execute_op.getOpName()); execute_op.replaceAllUsesWith(new_execute_op.getResults()); execute_op.erase(); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/remove_tf_if_const_args.cc b/tensorflow/compiler/mlir/tfrt/transforms/remove_tf_if_const_args.cc index 07948452d99..d855fa41344 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/remove_tf_if_const_args.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/remove_tf_if_const_args.cc @@ -61,7 +61,7 @@ class RemoveTfIfConstArgs llvm::SmallVector const_arg_indices; // Record the remaining operands that won't be removed. 
llvm::SmallVector remaining_args; - for (auto iter : llvm::enumerate(if_op.input())) { + for (auto iter : llvm::enumerate(if_op.getInput())) { mlir::Value operand = iter.value(); if (auto const_op = operand.getDefiningOp()) { const_args.push_back(const_op); @@ -95,10 +95,10 @@ class RemoveTfIfConstArgs // Change the if_op's argumetns to the new arguments, branches to new // branches. Note that the outputs are not changed. - if_op.inputMutable().assign(remaining_args); - if_op.then_branchAttr( + if_op.getInputMutable().assign(remaining_args); + if_op.setThenBranchAttr( mlir::SymbolRefAttr::get(builder.getContext(), new_then_function_name)); - if_op.else_branchAttr( + if_op.setElseBranchAttr( mlir::SymbolRefAttr::get(builder.getContext(), new_else_function_name)); } @@ -153,7 +153,8 @@ class RemoveTfIfConstArgs new_branch.getLoc(), new_branch_type.getResults(), call_args, branch.getSymName(), "", "", ""); // Note that the outputs are not changed. - builder.create(new_branch.getLoc(), call_op.output()); + builder.create(new_branch.getLoc(), + call_op.getOutput()); return new_branch.getSymName(); } diff --git a/tensorflow/compiler/mlir/tfrt/transforms/reorder_assert.cc b/tensorflow/compiler/mlir/tfrt/transforms/reorder_assert.cc index b3993afeaf0..eb6eb9dbdae 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/reorder_assert.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/reorder_assert.cc @@ -83,8 +83,7 @@ class ReorderTfAssertPass bool IsFunctionNonSideEffectingOrAssert(mlir::func::FuncOp func_op) { auto& block = func_op.front(); for (mlir::Operation& op : block) { - if (!llvm::isa(&op) && - !mlir::MemoryEffectOpInterface::hasNoEffect(&op)) + if (!llvm::isa(&op) && !mlir::isMemoryEffectFree(&op)) return false; } return true; diff --git a/tensorflow/compiler/mlir/tfrt/transforms/set_shape_invariant_in_while_ops.cc b/tensorflow/compiler/mlir/tfrt/transforms/set_shape_invariant_in_while_ops.cc index 5df5f9264a7..a80d1ba7e18 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/set_shape_invariant_in_while_ops.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/set_shape_invariant_in_while_ops.cc @@ -36,7 +36,7 @@ class SetShapeInvariantInWhileOps func_op.walk([&](mlir::TF::WhileOp op) { // Skip tf.While op on TPU. if (!op->hasAttr("_tpu_replicate")) { - op.shape_invariantAttr(shape_invariant); + op.setShapeInvariantAttr(shape_invariant); } }); } diff --git a/tensorflow/compiler/mlir/tfrt/transforms/sink_in_invariant_ops.cc b/tensorflow/compiler/mlir/tfrt/transforms/sink_in_invariant_ops.cc new file mode 100644 index 00000000000..e775ebd896a --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/sink_in_invariant_ops.cc @@ -0,0 +1,186 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/DialectRegistry.h"  // from @llvm-project
+#include "mlir/IR/Location.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/IR/SymbolTable.h"  // from @llvm-project
+#include "mlir/IR/Visitors.h"  // from @llvm-project
+#include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
+#include "mlir/Pass/PassManager.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
+
+namespace tensorflow {
+namespace {
+
+// Clones each sinkable op associated with a func op into that func op.
+void CloneOpIntoFuncOp(
+    const llvm::DenseMap>
+        &func_op_operands_to_sink) {
+  for (auto &iter : func_op_operands_to_sink) {
+    auto func = iter.first;
+    mlir::OpBuilder builder(func);
+    builder.setInsertionPointToStart(&func.getBody().front());
+
+    for (auto &operand_iter : iter.second) {
+      auto *cloned_op = operand_iter.second->clone();
+      func.getArgument(operand_iter.first)
+          .replaceAllUsesWith(*cloned_op->getResults().begin());
+      builder.insert(cloned_op);
+      builder.setInsertionPointAfter(cloned_op);
+    }
+  }
+}
+
+// TODO(b/262610234): Generalize the sinking conditions.
+// Check if the op qualifies to sink to the callee.
+bool IsSinkCandidate(mlir::Operation *op) {
+  return op && llvm::isa(op);
+}
+
+// Check if ops are allowed to be sunk into this op. We are being conservative
+// here and whitelist only a very limited set of ops.
+bool AllowSinkTo(mlir::Operation *op) {
+  return llvm::isa(op);
+}
+
+// There are the following cases:
+// #1, sink v1
+// @func1 { v1 = VarHandleOp, v2 = CallerOp{ f=@func2 }(v1) }
+// @func2(arg0) { v2 = ReadVariableOp }
+//
+// #2, copy v1 to callee, still keep in func1
+// @func1 { v1 = VarHandleOp, v2 = ReadVariableOp, v3 = CallerOp{ f=@func2 }(v1)
+// }
+// @func2(arg0) { v2 = ReadVariableOp(arg0) }
+//
+// #3, sink v1 and v2
+// @func1 { v1 = VarHandleOp, v2 = ReadVariableOp, v3 = CallerOp{ f=@func2 }(v2)
+// }
+// @func2(arg0) { v2 = OtherOp(arg0) }
+//
+// #4, copy v1 and v2 to func2, keep in func1
+// @func1 { v1 = VarHandleOp, v2 = ReadVariableOp, v3 = OtherOp(v2), v4 =
+// CallerOp{ f=@func2 }(v2) }
+// @func2(arg0) { v2 = OtherOp(arg0) }
+//
+// We only support #1 for now, as that is the most common pattern in production.
+// If we implement #2 and #4 in the future, we should consider deduplication in
+// tfrt_resource_init, because multiple resource handles would be created for
+// the same resource.
+
+void SinkInInvariantOps(mlir::ModuleOp module) {
+  mlir::SymbolTable symbol_table(module);
+  mlir::SymbolTableCollection symbol_table_collection;
+  mlir::SymbolUserMap symbol_users(symbol_table_collection, module);
+
+  // Maps from the callee func op, to the operand index, to the op to be sunk.
+  llvm::DenseMap>
+      func_op_operands_to_sink;
+
+  // TODO(b/263191534): Replace with CallOpInterface to handle callees.
+  // Identify the invariant Op, Caller, Callee FuncOp to update.
+  module.walk([&](mlir::Operation *op) {
+    if (!AllowSinkTo(op)) return;
+
+    auto track_callee = [&](mlir::func::FuncOp &func_op) {
+      auto diff = op->getNumOperands() - func_op.getNumArguments();
+      for (int i = 0; i < func_op.getNumArguments(); ++i) {
+        auto arg_op = op->getOperand(diff + i).getDefiningOp();
+        if (!IsSinkCandidate(arg_op)) continue;
+        func_op_operands_to_sink[func_op][i] = arg_op;
+      }
+    };
+
+    llvm::DenseSet callees;
+    for (const auto &named_attr : op->getAttrs()) {
+      if (auto symbol_attr =
+              named_attr.getValue().dyn_cast()) {
+        auto symbol = symbol_attr.getValue();
+
+        auto callee = symbol_table.lookup(symbol);
+        if (!callee) continue;
+
+        // A callee invoked by multiple callers is skipped for simplicity.
+        // Consider adding support if more such usage is observed in production.
+        if (const llvm::ArrayRef users =
+                symbol_users.getUsers(callee);
+            users.size() > 1)
+          continue;
+
+        // If invoked by the same caller multiple times, only process the first one.
+        if (callees.count(symbol)) continue;
+        track_callee(callee);
+        callees.insert(symbol);
+      }
+    }
+  });
+
+  CloneOpIntoFuncOp(func_op_operands_to_sink);
+}
+
+class SinkInInvariantOpsPass
+    : public mlir::PassWrapper> {
+ public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SinkInInvariantOpsPass)
+
+  llvm::StringRef getArgument() const final {
+    return "tfrt-sink-in-invariant-ops";
+  }
+  llvm::StringRef getDescription() const final {
+    return "Sink in the invariant ops to facilitate invariant ops hoisting.";
+  }
+
+  void runOnOperation() override {
+    auto module = getOperation();
+    SinkInInvariantOps(module);
+  }
+};
+
+} // namespace
+
+std::unique_ptr>
+CreateSinkInInvariantOpsPass() {
+  return std::make_unique();
+}
+
+static mlir::PassRegistration
+    sink_in_invariant_ops_pass;
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc b/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc
index 39c558a66ad..91ec99eb94a 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc
+++ b/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc
@@ -52,6 +52,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
 #include "tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.h"
 #include "tensorflow/compiler/mlir/tfrt/analysis/tensor_array_side_effect_analysis.h"
+#include "tensorflow/compiler/mlir/tfrt/ir/gpu_ops.h"
 #include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.h"
 #include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.h"
 #include "tensorflow/compiler/mlir/tfrt/jit/opdefs/tf_jitrt_ops.h"
@@ -60,8 +61,10 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/tfrt/transforms/attr_lowering_utils.h"
 #include "tensorflow/compiler/mlir/tfrt/transforms/corert_converter.h"
 #include "tensorflow/compiler/mlir/tfrt/transforms/fallback_converter.h"
+#include "tensorflow/compiler/mlir/tfrt/transforms/gpu_passes.h"
 #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
 #include "tensorflow/compiler/mlir/tfrt/transforms/set_shape_invariant_in_while_ops.h"
+#include "tensorflow/compiler/mlir/tfrt/transforms/utils.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/platform/tstring.h"
@@ -72,8 +75,6 @@ limitations under the License.
#include "tfrt/core_runtime/opdefs/attributes.h" // from @tf_runtime #include "tfrt/core_runtime/opdefs/core_runtime.h" // from @tf_runtime #include "tfrt/core_runtime/opdefs/types.h" // from @tf_runtime -#include "tfrt/distributed_runtime/opdefs/kernels.h" // from @tf_runtime -#include "tfrt/distributed_runtime/opdefs/types.h" // from @tf_runtime #include "tfrt/test_kernels/opdefs/test_kernels.h" // from @tf_runtime namespace tensorflow { @@ -93,8 +94,7 @@ constexpr int64_t kDefaultCheapCost = 1; void getDependentConversionDialects(mlir::DialectRegistry ®istry) { registry.insert(); + tfrt::compiler::TFRTDialect, tf_jitrt::JitRuntimeDialect>(); } mlir::Value GetFunctionInputChain(mlir::Operation *op) { @@ -102,6 +102,49 @@ mlir::Value GetFunctionInputChain(mlir::Operation *op) { return func_op.getArgument(0); } +llvm::SmallVector AddGpuVariableAndInputTensorTransferOps( + mlir::Operation *op, llvm::SmallVector operands, + mlir::ConversionPatternRewriter &rewriter) { + llvm::SmallVector new_operands; + assert(op->getOperands().size() == operands.size()); + for (int i = 0; i < op->getOperands().size(); ++i) { + if (IsResultVariable(op->getOperand(i), operands[i])) { + auto transfer_variable_op = + rewriter.create( + op->getLoc(), rewriter.getType(), + mlir::ValueRange{operands[i]}, ArrayRef()); + new_operands.push_back(transfer_variable_op); + } else { + auto transfer_to_device_op = + rewriter.create( + op->getLoc(), rewriter.getType(), + mlir::ValueRange{operands[i]}, ArrayRef()); + new_operands.push_back(transfer_to_device_op); + } + } + return new_operands; +} + +llvm::SmallVector AddGpuTransferFromDeviceOps( + mlir::Operation *op, llvm::SmallVector results, + mlir::ConversionPatternRewriter &rewriter) { + assert(results.size() == op->getNumResults()); + llvm::SmallVector new_results; + for (int idx = 0; idx < results.size(); ++idx) { + if (op->getResult(idx).use_empty()) { + // If the result is not used, it is not transferred. + new_results.push_back(results[idx]); + } else { + auto transfer_from_device_op = + rewriter.create( + op->getLoc(), rewriter.getType(), + results[idx]); + new_results.push_back(transfer_from_device_op); + } + } + return new_results; +} + // Convert TF dialect ops to tfrt_fallback.executeop for non-side-effecting ops // and tfrt_fallback.executeop.seq for side-effecting ops. // @@ -120,15 +163,18 @@ class FallbackExecuteOpConversion : public mlir::ConversionPattern { FallbackExecuteOpConversion( mlir::MLIRContext *context, CoreRTConverter *corert_converter, tfrt_compiler::FallbackConverter *fallback_converter, + const mlir::SymbolTable *symbol_table, const tfrt_compiler::CostAnalysis *cost_analysis, - bool tpu_lower_to_fallback, bool target_tpurt) + bool tpu_lower_to_fallback, bool target_tpurt, bool use_bridge_for_gpu) : mlir::ConversionPattern(mlir::Pattern::MatchAnyOpTypeTag(), kFallbackBenefit, context), corert_converter_(*corert_converter), fallback_converter_(*fallback_converter), + symbol_table_(*symbol_table), cost_analysis_(*cost_analysis), tpu_lower_to_fallback_(tpu_lower_to_fallback), - target_tpurt_(target_tpurt) {} + target_tpurt_(target_tpurt), + use_bridge_for_gpu_(use_bridge_for_gpu) {} LogicalResult matchAndRewrite( mlir::Operation *op, ArrayRef operands, @@ -151,8 +197,21 @@ class FallbackExecuteOpConversion : public mlir::ConversionPattern { // Convert the function (symbol) attributes to an array of string // attributes, which represents the function names. 
llvm::SmallVector func_attr_keys; - mlir::ArrayAttr op_func_attrs = - corert_converter_.CreateOpFuncAttrs(op->getAttrs(), &func_attr_keys); + + // If the op is XlaLaunch on GPU, the function attribute will use the + // function name in MLIR, instead of the original function name in the + // function library, because the function could have been changed by bridge, + // e.g., variable lifting. The new MLIR function will need to be exported to + // the function library for runtime to use. + bool use_mlir_func_name = + parsed_device_name->device_type == DEVICE_GPU && use_bridge_for_gpu_ && + op->getName().getStringRef().str() == "tf.XlaLaunch"; + + mlir::ArrayAttr op_func_attrs = corert_converter_.CreateOpFuncAttrs( + symbol_table_, op->getAttrs(), &func_attr_keys, use_mlir_func_name); + if (!op_func_attrs) { + return op->emitWarning("failed to create func attributes."); + } // Remove the function attributes, which have already been processed. for (const auto &key : func_attr_keys) op->removeAttr(key); @@ -235,9 +294,12 @@ class FallbackExecuteOpConversion : public mlir::ConversionPattern { CoreRTConverter &corert_converter_; tfrt_compiler::FallbackConverter &fallback_converter_; + const mlir::SymbolTable &symbol_table_; const tfrt_compiler::CostAnalysis &cost_analysis_; bool tpu_lower_to_fallback_; bool target_tpurt_; + // TODO(b/260915352): Remove the flag and default to using bridge. + bool use_bridge_for_gpu_; }; mlir::LogicalResult FallbackExecuteOpConversion::ConvertToFallbackExecuteOp( @@ -262,16 +324,27 @@ mlir::LogicalResult FallbackExecuteOpConversion::ConvertToFallbackExecuteOp( IntegerAttr cost; auto parsed_device_name = corert_converter_.ParseDeviceName(device.getValue()); - if (parsed_device_name && parsed_device_name->device_type == DEVICE_GPU) { + bool is_gpu_op = + parsed_device_name && parsed_device_name->device_type == DEVICE_GPU; + if (is_gpu_op) { // For GPU ops, the host only needs to dispatch them to GPUs, which should // be relatively cheap for the host. cost = rewriter.getI64IntegerAttr(kDefaultCheapCost); } else { - cost = rewriter.getI64IntegerAttr( - cost_analysis_.GetCost(op, fallback_key.getInt())); + cost = rewriter.getI64IntegerAttr(cost_analysis_.GetCost(op)); + } + + // For now, we only consider GPU XLA clusters in the form of XlaLaunch for + // simplicity. We could extend to support other GPU ops that cann't be XLAed. + bool is_xla_launch_on_gpu = + is_gpu_op && use_bridge_for_gpu_ && + op->getName().getStringRef().str() == "tf.XlaLaunch"; + if (is_xla_launch_on_gpu) { + new_operands = + AddGpuVariableAndInputTensorTransferOps(op, new_operands, rewriter); } - if (mlir::MemoryEffectOpInterface::hasNoEffect(op)) { + if (mlir::isMemoryEffectFree(op)) { auto new_op = rewriter.create( op->getLoc(), result_types, new_operands, device, op_attrs, op_func_attrs, fallback_key, op_name, cost); @@ -297,8 +370,16 @@ mlir::LogicalResult FallbackExecuteOpConversion::ConvertToFallbackExecuteOp( new_operands, device, op_attrs, op_func_attrs, fallback_key, op_name, cost); fallback_converter.RegisterFallbackOp(new_op); - rewriter.replaceOp(op, new_op.getResults()); out_chain = new_op.getOutOpChain(); + if (is_xla_launch_on_gpu) { + // TODO(b/262280565): Remove unnecessary data transfers when there are + // multiple XLA clusters. + auto results = + AddGpuTransferFromDeviceOps(op, new_op.getResults(), rewriter); + rewriter.replaceOp(op, results); + } else { + rewriter.replaceOp(op, new_op.getResults()); + } } // Register the converted op so that it can be retrieved by successors. 
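Note (illustrative sketch, not part of this change): the "XlaLaunch on GPU behind the bridge" condition in the hunks above is evaluated twice, once when building the op function attributes and once when deciding whether to insert transfer ops. A shared helper could look like the following; the function name and free-function form are assumptions, only the condition itself is taken from this patch.

static bool IsXlaLaunchOnGpu(mlir::Operation *op, bool is_gpu_op,
                             bool use_bridge_for_gpu) {
  // tf.XlaLaunch is currently the only GPU XLA cluster form handled here.
  return is_gpu_op && use_bridge_for_gpu &&
         op->getName().getStringRef() == "tf.XlaLaunch";
}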
@@ -328,7 +409,7 @@ mlir::LogicalResult FallbackExecuteOpConversion::ConvertToCoreRTExecuteOp( corert_converter_.ConvertOpHandler(op, op_handler_name, &rewriter); if (!op_handler) return failure(); - if (mlir::MemoryEffectOpInterface::hasNoEffect(op)) { + if (mlir::isMemoryEffectFree(op)) { auto new_op = rewriter.create( op->getLoc(), result_types, op_handler, new_operands, op_attrs, op_func_attrs, op_name); @@ -360,15 +441,15 @@ class FallbackConstOpConversion mlir::TF::ConstOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { // Some data types are handled separately using a fast path. - if (IsSupportedTfrtNumericDType(op.dtype()) || - op.dtype().isa()) + if (IsSupportedTfrtNumericDType(op.getDtype()) || + op.getDtype().isa()) return failure(); // For other data types that do not have a fast path (eg. quantized types), // we convert it to serialized tensor proto. tensorflow::TensorProto tensor_proto; - auto status = ConvertToTensorProto(op.value(), &tensor_proto); + auto status = ConvertToTensorProto(op.getValue(), &tensor_proto); if (!status.ok()) return op.emitError(status.error_message()); rewriter.replaceOpWithNewOp( @@ -411,7 +492,7 @@ class FallbackSetResourceOp auto new_op = rewriter.create( op.getLoc(), corert_converter_.chain_type(), corert_converter_.GetLocalSideEffectChain(op, &rewriter), - new_operands[0], device.getValue(), op.index()); + new_operands[0], device.getValue(), op.getIndex()); // Register the converted op so that it can be retrieved by successors. corert_converter_.RegisterLocalSideEffectChain(op, new_op.getOutCh()); @@ -448,7 +529,7 @@ class FallbackGetResourceOp auto new_op = rewriter.create( op.getLoc(), corert_converter_.chain_type(), result_types, ready_chain, - device.getValue(), op.indices()); + device.getValue(), op.getIndices()); rewriter.replaceOp(op, new_op.getResults()); @@ -492,16 +573,6 @@ class TFDeviceRemoteRunOpConversion LogicalResult matchAndRewrite( tf_device::RemoteRunOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - mlir::Value distributed_context = - corert_converter_.GetDistributedContext(op.getOperation(), &rewriter); - mlir::Value in_op_chain = - corert_converter_.GetLocalSideEffectChain(op, &rewriter); - mlir::Value task_handle = corert_converter_.GetTaskHandle( - op.getOperation(), op.getHost(), &rewriter); - mlir::Value remote_chain_mgr = - corert_converter_.GetRemoteChainManager(op, &rewriter); - mlir::Type remote_obj_id_ty = - rewriter.getType(); ModuleOp module = op->getParentOfType(); SymbolTable symtab(module); func::FuncOp callee = symtab.lookup(op.getCallee()); @@ -517,11 +588,7 @@ class TFDeviceRemoteRunOpConversion } llvm::SmallVector arguments; - // The first argument of the remote function should be a remote chain which - // is added to the function signature when it is lowered from TF dialect to - // TFRT dialect. - arguments.push_back(corert_converter_.GetRemoteSideEffectChain( - op, host.getValue(), &rewriter)); + for (mlir::Value argument : op.getCalleeArgs()) { arguments.push_back(argument); } @@ -530,23 +597,9 @@ class TFDeviceRemoteRunOpConversion // The first result of the remote function should be a remote chain which // is added to the function signature when it is lowered from TF dialect to // TFRT dialect. 
- result_types.push_back(remote_obj_id_ty); for (mlir::Type type : op.getResultTypes()) { (void)type_converter_.convertType(type, result_types); } - auto remote_execute_func_op = - rewriter.create( - op.getLoc(), corert_converter_.chain_type(), result_types, - in_op_chain, distributed_context, task_handle, op.getCallee(), - arguments); - rewriter.replaceOp(op, remote_execute_func_op.getResults().drop_front(1)); - - auto set_chain_op = rewriter.create( - op.getLoc(), corert_converter_.chain_type(), - remote_execute_func_op.getOutOpChain(), remote_chain_mgr, task_handle, - remote_execute_func_op.getResults().front()); - corert_converter_.RegisterLocalSideEffectChain( - op, set_chain_op.getOutOpChain()); return success(); } @@ -602,7 +655,8 @@ class FallbackBatchFunctionOpConversion op->getNumResults(), rewriter.getType()); auto new_op = rewriter.create( - op.getLoc(), result_types, new_operands, device, op.fAttr(), op_attrs); + op.getLoc(), result_types, new_operands, device, op.getFAttr(), + op_attrs); rewriter.replaceOp(op, new_op.getResults()); return success(); } @@ -624,7 +678,7 @@ class CoreRTConstDenseTensorOpConversion LogicalResult matchAndRewrite( mlir::TF::ConstOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (!IsSupportedTfrtNumericDType(op.dtype())) return failure(); + if (!IsSupportedTfrtNumericDType(op.getDtype())) return failure(); // Only CPU ops can be lowered using this conversion. If there is no device // assignment, this op is treated as a CPU op and can be lowered. @@ -633,7 +687,7 @@ class CoreRTConstDenseTensorOpConversion auto new_op = rewriter.create( op.getLoc(), corert_converter_.tensor_handle_type(), - op.value().cast()); + op.getValue().cast()); rewriter.replaceOp(op, new_op->getResult(0)); return success(); } @@ -746,9 +800,10 @@ class CoreRTConstStringTensorOpConversion LogicalResult matchAndRewrite( mlir::TF::ConstOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // NOLINT - if (!op.dtype().isa()) return failure(); + if (!op.getDtype().isa()) return failure(); - DenseStringElementsAttr attr = op.value().cast(); + DenseStringElementsAttr attr = + op.getValue().cast(); llvm::SmallVector values; values.reserve(attr.getNumElements()); @@ -757,7 +812,7 @@ class CoreRTConstStringTensorOpConversion llvm::StringRef(element.data(), element.size()))); // Create the shape attribute from the tensor shape. - ArrayRef shape = op.value().getType().getShape(); + ArrayRef shape = op.getValue().getType().getShape(); llvm::SmallVector dims; dims.reserve(shape.size()); auto i64_type = rewriter.getIntegerType(64); @@ -777,100 +832,6 @@ class CoreRTConstStringTensorOpConversion CoreRTConverter &corert_converter_; }; -// Convert TF dialect operations with no side effects to CoreRT ExecuteOp. For -// example, -// -// %0 = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : -// (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32> -// -// is converted to -// -// %result = corert.executeop(%device) -// "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : -// (!corert.tensorhandle, !corert.tensorhandle) -> !corert.tensorhandle -// -// Note that it will fail to match if some attributes are not supported. 
-template -class CoreRTExecuteOpConversion : public mlir::OpConversionPattern { - public: - CoreRTExecuteOpConversion(mlir::MLIRContext *context, - CoreRTConverter *corert_converter) - : CoreRTExecuteOpConversion(context, corert_converter, "") {} - - // If device_name is not empty, only ops that are using this device is lowered - // using CoreRTExecuteOpConversion. - CoreRTExecuteOpConversion(mlir::MLIRContext *context, - CoreRTConverter *corert_converter, - llvm::StringRef device_name) - : mlir::OpConversionPattern(context, kCoreRTBenefit), - corert_converter_(*corert_converter), - device_name_(device_name) {} - - LogicalResult matchAndRewrite( - TF_Op op, typename TF_Op::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto parsed_device_name = corert_converter_.ParseDeviceName(op); - // Return failure and emit warning if there is no device assignment. - if (!parsed_device_name) { - return op->emitWarning( - "failed to retrieve valid device when converting to " - "corert.executeop"); - } - - // If device_name is specified, check the device of this op first. - if (!device_name_.empty()) { - // Skip if it does not match the specified device. - if (parsed_device_name->device_name != device_name_) return failure(); - } - - mlir::StringAttr op_name = rewriter.getStringAttr(op.getOperationName()); - - llvm::SmallVector result_types; - for (auto type : op.getOperation()->getResultTypes()) { - if (failed(corert_converter_.convertType(type, result_types))) - return failure(); - } - - corert_converter_.MaterializeDerivedAttributes(op); - - // Convert the function (symbol) attributes to an array of string - // attributes, which represents the function names. - llvm::SmallVector func_attr_keys; - ArrayAttr op_func_attrs = - corert_converter_.CreateOpFuncAttrs(op->getAttrs(), &func_attr_keys); - - // Remove the function attributes, which have already been processed. - for (const auto &key : func_attr_keys) op->removeAttr(key); - - Builder builder(op.getContext()); - ArrayAttr op_attrs = CreateTfrtOpAttrs(op->getAttrs(), builder); - if (!op_attrs) return op.emitError("failed to lower attributes."); - - llvm::SmallVector new_operands; - if (mlir::failed(tfrt_compiler::ConvertCoreRTOperands( - op, adaptor.getOperands(), &new_operands, rewriter))) - return failure(); - - // Get the op handler, or create one if there does not exist one. Note that - // ConvertOpHandler changes internal state so it can only be called if the - // rewrite is guaranteed to succeed afterwards. - auto op_handler = corert_converter_.ConvertOpHandler( - op, parsed_device_name->op_handler_name, &rewriter); - if (!op_handler) return failure(); - - auto new_op = rewriter.create( - op.getLoc(), result_types, op_handler, new_operands, op_attrs, - op_func_attrs, op_name); - - rewriter.replaceOp(op, new_op.getResults()); - return success(); - } - - private: - CoreRTConverter &corert_converter_; - llvm::StringRef device_name_; -}; - LogicalResult ConvertFunctionCallOperands( mlir::Operation *op, ValueRange operands, llvm::SmallVectorImpl *new_operands, @@ -932,7 +893,7 @@ class TFRTCallOpConversion : public mlir::OpConversionPattern { new_operands); rewriter.replaceOp(op, new_op.getResults().drop_front()); - if (!mlir::MemoryEffectOpInterface::hasNoEffect(op)) { + if (!mlir::isMemoryEffectFree(op)) { // Register the converted op so that it can be retrieved by successors. // TODO(chky): Add OpTraits or OpInterface, rather than assume first // result is a chain. 
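Note (illustrative sketch, not part of this change): several hunks in this patch replace mlir::MemoryEffectOpInterface::hasNoEffect(op) with the free function mlir::isMemoryEffectFree(op); the meaning is unchanged, i.e. the op touches no memory and needs no side-effect ordering. A minimal usage sketch, assuming the MLIR side-effect interface header; the helper name is hypothetical.

#include "mlir/Interfaces/SideEffectInterfaces.h"  // from @llvm-project

// Pure ops are ordered by data dependencies alone; all other ops must be
// threaded onto the side-effect chain.
static bool NeedsSideEffectChain(mlir::Operation *op) {
  return !mlir::isMemoryEffectFree(op);
}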
@@ -1011,7 +972,7 @@ class TFRTCaseOpConversion : public mlir::OpConversionPattern { LogicalResult matchAndRewrite( TF::CaseOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - mlir::ArrayAttr branches = op.branches(); + mlir::ArrayAttr branches = op.getBranches(); llvm::SmallVector result_types; result_types.push_back(corert_converter_.chain_type()); @@ -1092,8 +1053,8 @@ class TFRTCondOpConversion : public mlir::OpConversionPattern { mlir::LogicalResult matchAndRewrite( mlir::TF::IfOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { - mlir::FlatSymbolRefAttr then_branch = op.then_branchAttr(); - mlir::FlatSymbolRefAttr else_branch = op.else_branchAttr(); + mlir::FlatSymbolRefAttr then_branch = op.getThenBranchAttr(); + mlir::FlatSymbolRefAttr else_branch = op.getElseBranchAttr(); llvm::SmallVector result_types; result_types.push_back(rewriter.getType()); @@ -1124,7 +1085,7 @@ class TFRTCondOpConversion : public mlir::OpConversionPattern { // The first result is a !tfrt.chain. rewriter.replaceOp(op, new_op.getResults().drop_front(1)); - if (!mlir::MemoryEffectOpInterface::hasNoEffect(op)) { + if (!mlir::isMemoryEffectFree(op)) { // Register the converted op so that it can be retrieved by successors. // TODO(chky): Add OpTraits or OpInterface, rather than assume first // result is a chain. @@ -1195,8 +1156,8 @@ class TFRTWhileOpConversion mlir::LogicalResult matchAndRewrite( mlir::TF::WhileOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { - mlir::FlatSymbolRefAttr cond_fn = op.condAttr(); - mlir::FlatSymbolRefAttr body_fn = op.bodyAttr(); + mlir::FlatSymbolRefAttr cond_fn = op.getCondAttr(); + mlir::FlatSymbolRefAttr body_fn = op.getBodyAttr(); llvm::SmallVector while_arg_result_types; // Insert a chain for side effects as the first argument/result. @@ -1249,7 +1210,7 @@ class TFRTWhileOpConversion while_args[0] = pred_chain; int64_t parallel_iterations = - enable_while_parallel_iterations_ ? op.parallel_iterations() : 1; + enable_while_parallel_iterations_ ? op.getParallelIterations() : 1; auto new_op = rewriter.create( op.getLoc(), while_arg_result_types, first_iteration_bool_cond, @@ -1259,7 +1220,7 @@ class TFRTWhileOpConversion if (!has_at_most_tensor_array_effect) out_chain = new_op.getResult(0); - if (!mlir::MemoryEffectOpInterface::hasNoEffect(op)) { + if (!mlir::isMemoryEffectFree(op)) { // Register the converted op so that it can be retrieved by successors. // TODO(chky): Add OpTraits or OpInterface, rather than assume first // result is a chain. @@ -1370,7 +1331,7 @@ mlir::func::FuncOp TFRTWhileOpConversion::GetWhileBodyFunction( mlir::func::FuncOp pred_fn, mlir::TypeRange arg_types, mlir::ConversionPatternRewriter &rewriter) const { int64_t parallel_iterations = - enable_while_parallel_iterations_ ? op.parallel_iterations() : 1; + enable_while_parallel_iterations_ ? 
op.getParallelIterations() : 1; std::string body_fn_name = original_body_fn.getValue().str() + "/tfrt_body_" + absl::StrCat(parallel_iterations); @@ -1491,7 +1452,6 @@ void SetUpTFToTFRTConversionLegality(mlir::ConversionTarget *target, target->addLegalDialect(); target->addLegalDialect(); target->addLegalDialect(); - target->addLegalDialect(); target->addLegalDialect(); target->addLegalDialect(); target->addIllegalDialect(); @@ -1529,13 +1489,12 @@ void PopulateTFToTFRTConversionPatterns( const tfrt_compiler::CostAnalysis *cost_analysis, const tfrt_compiler::TensorArraySideEffectAnalysis *tensor_array_side_effect_analysis, - bool enable_native_ops, bool func_use_fallback_tensor, - bool enable_while_parallel_iterations, bool tpu_lower_to_fallback, - bool target_tpurt) { + bool func_use_fallback_tensor, bool enable_while_parallel_iterations, + bool tpu_lower_to_fallback, bool target_tpurt, bool use_bridge_for_gpu) { // By default, we lower all TF ops to fallback ops. patterns->add( - context, corert_converter, fallback_converter, cost_analysis, - tpu_lower_to_fallback, target_tpurt); + context, corert_converter, fallback_converter, symbol_table, + cost_analysis, tpu_lower_to_fallback, target_tpurt, use_bridge_for_gpu); patterns->add(context, corert_converter); @@ -1571,47 +1530,6 @@ void PopulateTFToTFRTConversionPatterns( // use ExecuteOp pattern to convert string tensor attribute. patterns->add(context, corert_converter); - - if (enable_native_ops) { - // Below TF operations will be converted to use corert.executeop, which will - // invoke TFRT native op if implemented. - // TODO(b/187942369): Pattern registration for TF operations is not - // sustainable currently. We need to figure out a plan. - patterns->add, - // TODO(chky): Move the ReadVariableOp + Identity pattern - // to optimizer. - // CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion, - CoreRTExecuteOpConversion>(context, - corert_converter); - } } // Lower TF dialect MLIR to TFRT dialect. 
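Note (illustrative sketch, not part of this change): most of the mechanical renames in this file follow MLIR's switch to prefixed auto-generated accessors (getX()/setXAttr() instead of x()/xAttr()). A short example for tf.While, using only accessors that already appear in this patch; the function name is hypothetical.

void AccessorRenameExample(mlir::TF::WhileOp op) {
  mlir::FlatSymbolRefAttr cond_fn = op.getCondAttr();        // was op.condAttr()
  mlir::FlatSymbolRefAttr body_fn = op.getBodyAttr();        // was op.bodyAttr()
  int64_t parallel_iterations = op.getParallelIterations();  // was op.parallel_iterations()
  (void)cond_fn;
  (void)body_fn;
  (void)parallel_iterations;
}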
@@ -1622,6 +1540,7 @@ class TfToTfrtConversionPass getDependentConversionDialects(registry); if (target_tpurt_) RegisterTPUDialects(®istry); + if (target_gpu_) RegisterGpuDialects(®istry); } llvm::StringRef getArgument() const final { return "tf-to-tfrt"; } @@ -1636,7 +1555,6 @@ class TfToTfrtConversionPass TfToTfrtConversionPass() = default; explicit TfToTfrtConversionPass(const TfrtPipelineOptions &options) { target_tpurt_ = options.target_tpurt; - enable_native_ops_ = options.enable_native_ops; tpu_use_core_selector_ = options.tpu_use_core_selector; tpu_use_bundled_transfer_ = options.tpu_use_bundled_transfer; tpu_lower_to_fallback_ = options.tpu_lower_to_fallback; @@ -1649,6 +1567,8 @@ class TfToTfrtConversionPass func_use_fallback_tensor_ = options.func_use_fallback_tensor; enable_while_parallel_iterations_ = options.enable_while_parallel_iterations; + target_gpu_ = options.target_gpu; + use_bridge_for_gpu_ = options.use_bridge_for_gpu; } TfToTfrtConversionPass(const TfToTfrtConversionPass &) {} @@ -1673,6 +1593,10 @@ class TfToTfrtConversionPass tpu_transfer_result_to_host_, use_tpu_host_allocator_for_inputs_}, tpu_lower_to_fallback_); + if (target_gpu_) { + AddGpuTargetDialectAndPatterns(&context, &target, &patterns); + } + mlir::TypeConverter *func_type_converter; if (func_use_fallback_tensor_) { func_type_converter = &fallback_converter; @@ -1686,9 +1610,8 @@ class TfToTfrtConversionPass PopulateTFToTFRTConversionPatterns( &context, &patterns, &corert_converter, &fallback_converter, &symbol_table, &cost_analysis, &tensor_array_side_effect_analysis, - enable_native_ops_, func_use_fallback_tensor_, - enable_while_parallel_iterations_, tpu_lower_to_fallback_, - target_tpurt_); + func_use_fallback_tensor_, enable_while_parallel_iterations_, + tpu_lower_to_fallback_, target_tpurt_, use_bridge_for_gpu_); return mlir::applyPartialConversion(func, target, std::move(patterns)); } @@ -1859,13 +1782,6 @@ class TfToTfrtConversionPass llvm::cl::desc("Target TPURT dialect if true."), llvm::cl::init(false)}; - Option enable_native_ops_{ - *this, "enable-native-ops", - llvm::cl::desc( - "If true, native ops will be used on an opt-in basis " - "instead of fallback ops. If false, no native ops are used."), - llvm::cl::init(true)}; - Option tpu_use_core_selector_{ *this, "tpu-use-core-selector", llvm::cl::desc("If true, use ServingCoreSelector to pick TPU core. " @@ -1899,6 +1815,16 @@ class TfToTfrtConversionPass "program will use tpu host allocator."), llvm::cl::init(false)}; + Option target_gpu_{ + *this, "target-gpu", + llvm::cl::desc("If true, target GPU compiler passes."), + llvm::cl::init(false)}; + + // TODO(b/260915352): Remove the flag and default to using bridge. + Option use_bridge_for_gpu_{ + *this, "use-bridge-for-gpu", + llvm::cl::desc("If true, GPU bridge is used."), llvm::cl::init(false)}; + Option cost_threshold_{ *this, "tfrt-cost-threshold", llvm::cl::desc( @@ -2233,6 +2159,9 @@ void CreateTFExecutorToTFPipeline(mlir::OpPassManager &pm, pm.addNestedPass( tfrt_compiler::CreateOptimizeTfForTfrtPass()); pm.addNestedPass(mlir::createCanonicalizerPass()); + // Guarantee all functions have one use, which enables more exact shape + // inference. 
+ pm.addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass()); pm.addPass(mlir::TF::CreateTFShapeInferencePass()); pm.addPass(mlir::createInlinerPass()); pm.addPass(mlir::createSymbolDCEPass()); @@ -2264,6 +2193,15 @@ void CreateTFExecutorToTFPipeline(mlir::OpPassManager &pm, pm.addPass( tfrt_compiler::CreateDeduplicateFunctionsInovkedByBatchFunctionPass()); + // RemoveUnusedWhileResultsPass operates on the region-based control flow, so + // the functional control flow is first converted to region-based control + // flow, which is converted back after the optimization passes are performed. + pm.addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions()); + pm.addPass(mlir::createInlinerPass()); + pm.addNestedPass( + mlir::TF::CreateRemoveUnusedWhileResultsPass()); + pm.addPass(mlir::TF::CreateTFRegionControlFlowToFunctional()); + // Apply standard optimization after optimizing control flow ops. pm.addPass(mlir::createInlinerPass()); pm.addNestedPass(mlir::createCSEPass()); @@ -2309,28 +2247,26 @@ void CreateTFExecutorToTFPipeline(mlir::OpPassManager &pm, // convert them to functions. We currently support only tfrt fallback tensors // as operands, so we disable these passes if we can have native ops after // lowering. - if (!options.enable_native_ops) { - pm.addNestedPass(CreateTfJitRtClusteringPass( - options.auto_fusion_oplist, options.auto_fusion_min_cluster_size)); - - // Sink small constants into the outlined clusters to reduce the number of - // arguments for each of the execute operations. - auto is_compilable_const = [](mlir::tf_device::ClusterOp cluster, - mlir::ElementsAttr value) -> bool { - // Ensure that cluster was formed for TFRT JIT compilation. - auto policy = cluster->getAttr("policy").dyn_cast_or_null(); - if (!policy || policy.getValue() != "tfrt.auto-fusion") return false; - - // Check that TF->JitRt compiler supports constant compilation. - return mlir::succeeded(IsCompilableConstant(value)); - }; + pm.addNestedPass(CreateTfJitRtClusteringPass( + options.auto_fusion_oplist, options.auto_fusion_min_cluster_size)); - pm.addNestedPass( - mlir::TFDevice::CreateClusterConstantSinkingPass(is_compilable_const)); + // Sink small constants into the outlined clusters to reduce the number of + // arguments for each of the execute operations. + auto is_compilable_const = [](mlir::tf_device::ClusterOp cluster, + mlir::ElementsAttr value) -> bool { + // Ensure that cluster was formed for TFRT JIT compilation. + auto policy = cluster->getAttr("policy").dyn_cast_or_null(); + if (!policy || policy.getValue() != "tfrt.auto-fusion") return false; - // Outline formed JIT compiled device clusters into function. - pm.addPass(CreateOutlineJitRtClustersPass()); - } + // Check that TF->JitRt compiler supports constant compilation. + return mlir::succeeded(IsCompilableConstant(value)); + }; + + pm.addNestedPass( + mlir::TFDevice::CreateClusterConstantSinkingPass(is_compilable_const)); + + // Outline formed JIT compiled device clusters into function. + pm.addPass(CreateOutlineJitRtClustersPass()); // Rewriter operation sequences to device specific fusions. 
DeviceNameUtils::ParsedName parsed_name; @@ -2358,6 +2294,10 @@ void CreateTFExecutorToTFPipeline(mlir::OpPassManager &pm, AddTfDeviceAssignmentPasses(pm, options); + if (options.sink_in_invariant_ops) { + pm.addPass(CreateSinkInInvariantOpsPass()); + } + pm.addPass(CreateLowerTFSavedModelPass(options.hoist_invariant_ops)); } @@ -2379,11 +2319,22 @@ void CreateTfExecutorToTfrtPipelineHelper(mlir::OpPassManager &pm, } } +Status ValidateTfrtPipelineOptions(const TfrtPipelineOptions &options) { + if (options.target_tpurt && + (options.target_gpu || options.use_bridge_for_gpu)) { + return tensorflow::errors::Internal( + "Invalid pipeline options. Targeting both TPU and GPU is not " + "supported."); + } + return OkStatus(); +} + // If verbose logging is on, dump the output of each pass to a file directory, // set via env var TF_DUMP_GRAPH_PREFIX. e.g.: // export TF_DUMP_GRAPH_PREFIX=/tmp/mlir -void CreateTfExecutorToTfrtPipeline(mlir::PassManager &pm, - const TfrtPipelineOptions &options) { +Status CreateTfExecutorToTfrtPipeline(mlir::PassManager &pm, + const TfrtPipelineOptions &options) { + TF_RETURN_IF_ERROR(ValidateTfrtPipelineOptions(options)); if (VLOG_IS_ON(1)) { // Print the whole module after each pass, which requires disabling // multi-threading as well. @@ -2392,6 +2343,7 @@ void CreateTfExecutorToTfrtPipeline(mlir::PassManager &pm, /*print_module_scope=*/true)); } CreateTfExecutorToTfrtPipelineHelper(pm, options); + return OkStatus(); } static mlir::PassRegistration tf_to_tfrt_pass; diff --git a/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt_data.cc b/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt_data.cc deleted file mode 100644 index 8de47733baf..00000000000 --- a/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt_data.cc +++ /dev/null @@ -1,349 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file implements lowering of TF dialect to TFRT data kernels. 
-#include "tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt_data.h" - -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/OperationSupport.h" // from @llvm-project -#include "mlir/IR/Types.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" -#include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/protobuf/graph_debug_info.pb.h" -#include "tfrt/basic_kernels/opdefs/basic_kernels.h" // from @tf_runtime -#include "tfrt/bef_converter/mlir_to_bef.h" // from @tf_runtime -#include "tfrt/data/opdefs/data_ops.h" // from @tf_runtime -#include "tfrt/data/opdefs/types.h" // from @tf_runtime - -#define DEBUG_TYPE "tf-to-tfrt-data" - -namespace tensorflow { -namespace { - -bool isIntScalar(Type t, size_t width) { - if (auto ttype = t.dyn_cast()) { - if (ttype.hasStaticShape() && ttype.getNumElements() == 1 && - ttype.getRank() == 0 && ttype.getElementType().isSignlessInteger(width)) - return true; - } - return false; -} - -// Converts `value_attr` from a TF Const node to the required type attr type `U` -template -T ConstAttrToTypeAttr(ElementsAttr value_attr) { - if (T type_attr = value_attr.dyn_cast()) { - return type_attr; - } else if (auto v = value_attr.dyn_cast()) { - return v.getSplatValue().dyn_cast(); - } - return T(nullptr); -} - -template -LogicalResult ReplaceConst(TF::ConstOp &op, ConversionPatternRewriter &rewriter, - Type type) { - IntegerAttr newAttr = ConstAttrToTypeAttr(op.value()); - - if (!newAttr) { - return failure(); - } - - auto tfrtConst = rewriter.create(op.getLoc(), type, newAttr); - rewriter.replaceOp(op.getOperation(), tfrtConst.getResult()); - return success(); -} - -mlir::Type CreateDatasetType(mlir::Builder *builder) { - return builder->getType(); -} - -// A helper class for converting data-specific types and attributes -class DataConverter : public mlir::TypeConverter { - public: - explicit DataConverter(mlir::MLIRContext *context) { - addConversion([](Type type) { return type; }); - addConversion([context](TensorType type) { - mlir::Builder builder(context); - // tf.data datasets are represented by DT_VARIANT tensors in TF. - // TODO(rachelim): Identify datasets more accurately. - if (type.getElementType().dyn_cast()) { - return CreateDatasetType(&builder); - } - return type.dyn_cast(); - }); - } -}; // namespace - -struct ConstOpConversion : public mlir::OpConversionPattern { - explicit ConstOpConversion(MLIRContext *context) - : OpConversionPattern(context) {} - - LogicalResult matchAndRewrite( - TF::ConstOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (isIntScalar(op.getType(), 64)) { - return ReplaceConst(op, rewriter, - rewriter.getI64Type()); - } - if (isIntScalar(op.getType(), 1)) { - return ReplaceConst(op, rewriter, - rewriter.getI1Type()); - } - // TODO(rachelim): Support converting other const types. 
- return failure(); - } -}; - -struct ReturnOpConversion - : public mlir::OpConversionPattern { - explicit ReturnOpConversion(MLIRContext *context) - : OpConversionPattern(context) {} - - LogicalResult matchAndRewrite( - mlir::func::ReturnOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp( - op, adaptor.getOperands()); - return success(); - } -}; - -class RangeDatasetOpConversion - : public OpConversionPattern { - public: - explicit RangeDatasetOpConversion(MLIRContext *context) - : OpConversionPattern(context), - builder_(context), - dataset_type_(CreateDatasetType(&builder_)) {} - - LogicalResult matchAndRewrite( - TF::RangeDatasetOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (op.output_types().size() != 1) { - // Range dataset should only have one output type. - return failure(); - } - if (auto output_type = op.output_types().begin()->cast()) { - rewriter.replaceOpWithNewOp( - op, dataset_type_, adaptor.start(), adaptor.stop(), adaptor.step(), - output_type); - return success(); - } - return failure(); - } - - private: - mlir::Builder builder_; - mlir::Type dataset_type_; -}; - -class BatchDatasetV2OpConversion - : public OpConversionPattern { - public: - explicit BatchDatasetV2OpConversion(MLIRContext *context) - : OpConversionPattern(context), - builder_(context), - dataset_type_(CreateDatasetType(&builder_)) {} - - LogicalResult matchAndRewrite( - TF::BatchDatasetV2Op op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Since TFRT's BatchDataset doesn't have a drop_remainder=True option, - // we only convert this op if its drop_remainder input is statically known - // to be false. - auto drop_remainder_op = op.drop_remainder().getDefiningOp(); - if (!drop_remainder_op) return failure(); - BoolAttr drop_remainder_val = - ConstAttrToTypeAttr(drop_remainder_op.value()); - if (!drop_remainder_val || drop_remainder_val.getValue()) { - return failure(); - } - - // TODO(b/155892156): Support converting non-unary BatchDataset - if (op.output_types().size() != 1) return failure(); - - // TODO(b/155892156): Support converting BatchDataset with unknown rank - auto output_shape = op.output_shapes()[0].cast(); - if (!output_shape.hasRank()) { - return failure(); - } - - if (output_shape.getRank() >= 2) { // Input is a tensor - rewriter.replaceOpWithNewOp( - op, dataset_type_, adaptor.input_dataset(), adaptor.batch_size(), - /*same_input_metadata=*/rewriter.getBoolAttr(false)); - return success(); - } - - auto output_type = op.output_types()[0].cast().getValue(); - - if (output_type.isInteger(32)) { - rewriter.replaceOpWithNewOp( - op, dataset_type_, adaptor.input_dataset(), adaptor.batch_size(), - /*same_input_metadata=*/rewriter.getBoolAttr(false)); - return success(); - } - if (output_type.isInteger(64)) { - rewriter.replaceOpWithNewOp( - op, dataset_type_, adaptor.input_dataset(), adaptor.batch_size(), - /*same_input_metadata=*/rewriter.getBoolAttr(false)); - return success(); - } - return failure(); - } - - private: - mlir::Builder builder_; - mlir::Type dataset_type_; -}; - -// This rewrite converts a tf.data function that returns a tf.data dataset (in -// the TF dialect) to the equivalent function in the TFRT and Data dialects that -// returns a `!tfrt.dataset`. -// -// For now, this can only lower a RangeDataset op and its inputs. As we add more -// native TFRT datasets, we add the corresponding lowering pattern here. 
-class TFToTFRTDataRewritePass - : public mlir::PassWrapper> { - public: - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TFToTFRTDataRewritePass) - - private: - llvm::StringRef getArgument() const final { return "tf-to-tfrt-data"; } - llvm::StringRef getDescription() const final { - return "Convert Tensorflow dialect to TFRT's data dialect."; - } - - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - void runOnOperation() override { - auto module = getOperation(); - auto *context = &getContext(); - mlir::ConversionTarget target(*context); - DataConverter data_converter(context); - target.addIllegalDialect(); - target.addLegalDialect(); - target.addLegalDialect(); - target.addDynamicallyLegalOp( - [&data_converter](func::FuncOp op) { - return data_converter.isSignatureLegal(op.getFunctionType()); - }); - mlir::RewritePatternSet patterns(&getContext()); - patterns.add(context); - mlir::populateFunctionOpInterfaceTypeConversionPattern( - patterns, data_converter); - - auto result = - mlir::applyPartialConversion(module, target, std::move(patterns)); - if (failed(result)) { - signalPassFailure(); - } - } -}; - -// Creates a pipeline of passes that converts MLIR TF Executor dialect to -// Hex and Data dialect. -void CreateTFExecutorToTFRTDataPipeline(mlir::OpPassManager &pm) { - // Prune unused operations. - pm.addPass(mlir::tf_executor::CreateTFExecutorGraphPruningPass()); - - // Run the TF standard pipeline - mlir::TF::StandardPipelineOptions tf_options; - tf_options.enable_inliner = true; - mlir::TF::CreateTFStandardPipeline(pm, tf_options); - - // After all the standard passes, lower to TFRT Data. - pm.addPass(CreateTFToTFRTDataConversionPass()); -} - -Status TFDataGraphDefToTFDataMLIR( - const GraphDef &graph_def, mlir::MLIRContext *mlir_ctx, - mlir::OwningOpRef *module_ref) { - // Import to TF dialect - string output_node; - for (const auto &node : graph_def.node()) { - if (node.op() == "_Retval") { - output_node = node.input(0); - VLOG(2) << "Output node: " << output_node; - break; - } - } - auto import_config = tensorflow::GraphImportConfig(); - import_config.outputs.push_back(std::move(output_node)); - import_config.prune_unused_nodes = true; - TF_ASSIGN_OR_RETURN(*module_ref, ConvertGraphdefToMlir( - graph_def, tensorflow::GraphDebugInfo(), - std::move(import_config), mlir_ctx)); - - return OkStatus(); -} - -Status CompileTFDataMLIRToBEF(mlir::ModuleOp module, - tfrt::BefBuffer *bef_buffer) { - VLOG(1) << "TF Dialect: " << MlirModuleToString(module); - - mlir::PassManager pm(module.getContext()); - CreateTFExecutorToTFRTDataPipeline(pm); - - mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext()); - if (mlir::failed(pm.run(module))) - return diag_handler.Combine( - errors::Internal("failed to lower TF Dialect to TFRT Data dialect.")); - - VLOG(1) << "TFRT Dialect: " << MlirModuleToString(module); - - *bef_buffer = - tfrt::ConvertMLIRToBEF(module, /*disable_optional_sections=*/false); - if (bef_buffer->empty()) - return diag_handler.Combine( - errors::Internal("failed to convert MLIR to BEF.")); - - return OkStatus(); -} - -} // namespace - -std::unique_ptr CreateTFToTFRTDataConversionPass() { - return std::make_unique(); -} - -Status TFDataGraphDefToHostBEF(const GraphDef &graph_def, - tfrt::BefBuffer *bef) { - mlir::MLIRContext mlir_ctx; - mlir::OwningOpRef module_ref; - TF_RETURN_IF_ERROR( - TFDataGraphDefToTFDataMLIR(graph_def, &mlir_ctx, &module_ref)); - TF_RETURN_IF_ERROR(CompileTFDataMLIRToBEF(module_ref.get(), bef)); - - 
return OkStatus(); -} - -static mlir::PassRegistration pass; - -} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt_data.h b/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt_data.h deleted file mode 100644 index 1058215ce8e..00000000000 --- a/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt_data.h +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_TF_TO_TFRT_DATA_H_ -#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_TF_TO_TFRT_DATA_H_ - -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/platform/status.h" -#include "tfrt/bef/bef_buffer.h" // from @tf_runtime - -namespace tensorflow { - -// Create a pass that converts MLIR TF dialect to MLIR TFRT CoreRT dialect. -std::unique_ptr CreateTFToTFRTDataConversionPass(); - -Status TFDataGraphDefToHostBEF(const GraphDef& graph_def, tfrt::BefBuffer* bef); - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_TF_TO_TFRT_DATA_H_ diff --git a/tensorflow/compiler/mlir/tfrt/transforms/update_op_cost_in_tfrt_mlir.cc b/tensorflow/compiler/mlir/tfrt/transforms/update_op_cost_in_tfrt_mlir.cc new file mode 100644 index 00000000000..6b5a0132a67 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/update_op_cost_in_tfrt_mlir.cc @@ -0,0 +1,44 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/tfrt/transforms/update_op_cost_in_tfrt_mlir.h" + +#include "mlir/IR/Builders.h" // from @llvm-project + +namespace tensorflow { +namespace tfrt_compiler { + +constexpr char kCostAttrName[] = "_tfrt_cost"; +constexpr char kOpKeyAttrName[] = "op_key"; + +void UpdateOpCostInTfrtMlir(mlir::ModuleOp op, + const tfrt_stub::CostRecorder& cost_recorder) { + mlir::Builder builder(op); + op.walk([&](mlir::Operation* op) { + // Only update ops with existing cost attr. + const auto cost_attr = op->getAttrOfType(kCostAttrName); + if (!cost_attr) return; + // Only fallback ops have `op_key`s. + const auto op_key_attr = + op->getAttrOfType(kOpKeyAttrName); + if (!op_key_attr) return; + // Set the cost attr with a new value. 
+ const int64_t op_key = op_key_attr.getInt(); + op->setAttr(kCostAttrName, builder.getI64IntegerAttr( + cost_recorder.GetCostNanosecond(op_key))); + }); +} + +} // namespace tfrt_compiler +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/update_op_cost_in_tfrt_mlir.h b/tensorflow/compiler/mlir/tfrt/transforms/update_op_cost_in_tfrt_mlir.h new file mode 100644 index 00000000000..99b7c192f7a --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/update_op_cost_in_tfrt_mlir.h @@ -0,0 +1,32 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_UPDATE_OP_COST_IN_TFRT_MLIR_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_UPDATE_OP_COST_IN_TFRT_MLIR_H_ + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/core/tfrt/fallback/cost_recorder.h" + +namespace tensorflow { +namespace tfrt_compiler { + +// Updates the existing costs for all the fallback ops with the records in +// `cost_recorder`. +void UpdateOpCostInTfrtMlir(mlir::ModuleOp op, + const tfrt_stub::CostRecorder& cost_recorder); + +} // namespace tfrt_compiler +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_UPDATE_OP_COST_IN_TFRT_MLIR_H_ diff --git a/tensorflow/compiler/mlir/tfrt/transforms/utils.cc b/tensorflow/compiler/mlir/tfrt/transforms/utils.cc new file mode 100644 index 00000000000..dc92e0cf511 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/utils.cc @@ -0,0 +1,59 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/tfrt/transforms/utils.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tfrt/basic_kernels/opdefs/tfrt_base.h" // from @tf_runtime +#include "tfrt/basic_kernels/opdefs/types.h" // from @tf_runtime +#include "tfrt/core_runtime/opdefs/types.h" // from @tf_runtime + +namespace tensorflow { + +bool IsResourceArgument(mlir::Value value) { + auto arg = value.dyn_cast(); + if (!arg) return false; + + auto func = llvm::cast(arg.getOwner()->getParentOp()); + + return func.getArgAttr(arg.getArgNumber(), "tf.resource_name") != nullptr; +} + +bool IsResultVariable(const mlir::Value &original_operand, + const mlir::Value &operand) { + if (original_operand.isa()) { + auto defining_op = original_operand.getDefiningOp(); + + // TODO(b/174753886): When device assignment is properly done, we + // should check that TF::ReadVariableOp is for TPU device here. + if (llvm::isa(defining_op) && + defining_op->getNumOperands() == 1) { + return true; + } else if (llvm::isa(defining_op)) { + return true; + } + return false; + } + return IsResourceArgument(operand); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/utils.h b/tensorflow/compiler/mlir/tfrt/transforms/utils.h new file mode 100644 index 00000000000..4440149df89 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/utils.h @@ -0,0 +1,31 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_UTILS_H_ + +#include "mlir/IR/Value.h" // from @llvm-project + +namespace tensorflow { + +// Checks if the given `value` is a resource argument. +bool IsResourceArgument(mlir::Value value); + +// Checks if an operand is the value of a variable. +bool IsResultVariable(const mlir::Value &original_operand, + const mlir::Value &operand); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_UTILS_H_ diff --git a/tensorflow/compiler/mlir/tfrt/translate/import_model.cc b/tensorflow/compiler/mlir/tfrt/translate/import_model.cc index 0fd4bfbd06e..942096d2736 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tfrt/translate/import_model.cc @@ -15,12 +15,17 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tfrt/translate/import_model.h" +#include +#include #include +#include #include "absl/strings/match.h" #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" @@ -32,6 +37,65 @@ limitations under the License. namespace tensorflow { +namespace { + +// Exports all XLA functions in the form of XlaLaunch, and their nested +// functions. +StatusOr> ExportXlaFunctions(mlir::ModuleOp module) { + // Find all XLA functions. + std::vector xla_functions; + module.walk([&](mlir::TF::XlaLaunchOp xla_launch_op) { + std::string func_name = + xla_launch_op.getFunctionAttr().getRootReference().str(); + xla_functions.push_back(func_name); + }); + + // Convert all XLA functions and their nested functions. + std::deque queue; + for (const std::string& func : xla_functions) { + queue.push_back(func); + } + + const mlir::SymbolTable symbol_table(module); + absl::flat_hash_set visited; + std::vector xla_func_defs; + while (!queue.empty()) { + const std::string func_name = queue.front(); + queue.pop_front(); + + if (visited.contains(func_name)) continue; + + const auto func_op = symbol_table.lookup(func_name); + if (!func_op) { + return tensorflow::errors::Internal( + absl::StrCat("Function ", func_name, " is not found.")); + } + FunctionDef func_def; + TF_RETURN_IF_ERROR(ConvertMlirFunctionToFunctionLibraryDef( + func_op, GraphExportConfig(), &func_def)); + xla_func_defs.push_back(func_def); + + // Visit each op in the function and find out referenced functions from the + // attributes. 
+ func_op->walk([&](mlir::Operation* op) { + for (const mlir::NamedAttribute& attr : op->getAttrs()) { + if (const auto sym = + attr.getValue().dyn_cast()) { + mlir::Operation* func = + mlir::SymbolTable::lookupNearestSymbolFrom(op, sym); + if (func) { + queue.push_back(sym.getValue().str()); + } + } + } + }); + visited.insert(func_name); + } + return xla_func_defs; +} + +} // namespace + Status ConvertFunctionToBef( mlir::StringRef function_name, const tensorflow::FunctionBody* fbody, const FunctionLibraryDefinition& flib_def, @@ -62,10 +126,11 @@ Status ConvertFunctionToBef( } Status ConvertTfMlirToBef(const TfrtCompileOptions& options, - mlir::ModuleOp module, tfrt::BefBuffer* bef_buffer) { + mlir::ModuleOp module, tfrt::BefBuffer* bef_buffer, + tfrt_stub::FallbackState* fallback_state) { mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext()); - if (options.tpu_target == TfrtTpuInfraTarget::kTpurt) { + if (options.device_target == TfrtDeviceInfraTarget::kTpurt) { VLOG(1) << "Running MLIR TPU bridge for tpurt"; if (VLOG_IS_ON(1)) { tensorflow::DumpMlirOpToFile("tpu_bct_conversion_before", module); @@ -90,13 +155,28 @@ Status ConvertTfMlirToBef(const TfrtCompileOptions& options, TF_RETURN_IF_ERROR( mlir::TFTPU::TPUBridge(module, /*enable_logging=*/VLOG_IS_ON(1))); - } else if (options.tpu_target == TfrtTpuInfraTarget::kTfFallback) { + } else if (options.device_target == TfrtDeviceInfraTarget::kTfFallback) { auto tpu_partitioned_call_fallback_compat_result = tensorflow::RunTPUPartitionedCallFallbackCompatConversion(module); if (mlir::failed(tpu_partitioned_call_fallback_compat_result)) { return diag_handler.Combine(tensorflow::errors::Internal( "Failed to process TPUPartitionedCallOp for fallback execution")); } + } else if (options.device_target == TfrtDeviceInfraTarget::kGpu && + options.use_bridge_for_gpu) { + TF_RETURN_IF_ERROR( + mlir::TF::RunTFXLABridge(module, /*enable_logging=*/VLOG_IS_ON(1))); + + // GPU XLA clusters are wrapped in functions, which could be transformed by + // bridge. Hence, the MLIR functions for XLA clusters are exported and added + // to the function library. + if (fallback_state != nullptr) { + TF_ASSIGN_OR_RETURN(const std::vector xla_func_defs, + ExportXlaFunctions(module)); + for (const auto& func_def : xla_func_defs) { + TF_RETURN_IF_ERROR(fallback_state->AddFunctionDef(func_def)); + } + } } if (VLOG_IS_ON(1)) { @@ -118,12 +198,15 @@ Status ConvertTfMlirToBef(const TfrtCompileOptions& options, // ops. 
pass_options.decompose_resource_ops = options.decompose_resource_ops; pass_options.enable_optimizer = options.enable_optimizer; - pass_options.enable_native_ops = options.enable_native_ops; pass_options.target_tpurt = - (options.tpu_target == TfrtTpuInfraTarget::kTpurt); + (options.device_target == TfrtDeviceInfraTarget::kTpurt); + pass_options.target_gpu = + (options.device_target == TfrtDeviceInfraTarget::kGpu); + pass_options.use_bridge_for_gpu = options.use_bridge_for_gpu; pass_options.tpu_fuse_ops = options.tpu_fuse_ops; pass_options.use_tpu_host_allocator_for_inputs = options.use_tpu_host_allocator_for_inputs; + pass_options.sink_in_invariant_ops = options.sink_in_invariant_ops; pass_options.hoist_invariant_ops = options.hoist_invariant_ops; pass_options.func_use_fallback_tensor = true; pass_options.enable_while_parallel_iterations = @@ -135,7 +218,10 @@ Status ConvertTfMlirToBef(const TfrtCompileOptions& options, pass_options.upper_cost_threshold = options.upper_cost_threshold; pass_options.merge_inter_dependent_streams = options.merge_inter_dependent_streams; - tensorflow::CreateTfExecutorToTfrtPipeline(pm, pass_options); + Status status = tensorflow::CreateTfExecutorToTfrtPipeline(pm, pass_options); + if (!status.ok()) { + return diag_handler.Combine(status); + } if (mlir::failed(pm.run(module))) return diag_handler.Combine(tensorflow::errors::Internal( diff --git a/tensorflow/compiler/mlir/tfrt/translate/import_model.h b/tensorflow/compiler/mlir/tfrt/translate/import_model.h index b74712979fc..99d4bcf8a4d 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/import_model.h +++ b/tensorflow/compiler/mlir/tfrt/translate/import_model.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSLATE_IMPORT_MODEL_H_ #define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSLATE_IMPORT_MODEL_H_ +#include + #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project @@ -23,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/platform/status.h" +#include "tensorflow/core/tfrt/fallback/fallback_state.h" #include "tfrt/bef/bef_buffer.h" // from @tf_runtime namespace tensorflow { @@ -41,8 +44,12 @@ Status ConvertFunctionToBef( tfrt::BefBuffer* bef_buffer); // Converts an MLIR `module` in TF dialect to TFRT's Binary Executable Format. +// If `fallback_state` is not null, the MLIR functions for XLA clusters in +// the form of XlaLaunch will be exported and added to the function library when +// needed. The nested functions will also be exported. Status ConvertTfMlirToBef(const TfrtCompileOptions& options, - mlir::ModuleOp module, tfrt::BefBuffer* bef_buffer); + mlir::ModuleOp module, tfrt::BefBuffer* bef_buffer, + tfrt_stub::FallbackState* fallback_state = nullptr); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.cc b/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.cc index 87cd5f3a51e..1e4a81d0d0c 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.cc +++ b/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.cc @@ -23,16 +23,19 @@ limitations under the License. 
namespace tensorflow { -std::ostream& operator<<(std::ostream& os, TfrtTpuInfraTarget tpu_target) { - switch (tpu_target) { - case TfrtTpuInfraTarget::kNoTpu: - return os << "NoTpu"; - case TfrtTpuInfraTarget::kTpurt: +std::ostream& operator<<(std::ostream& os, + TfrtDeviceInfraTarget device_target) { + switch (device_target) { + case TfrtDeviceInfraTarget::kCpu: + return os << "Cpu"; + case TfrtDeviceInfraTarget::kTpurt: return os << "Tpurt"; - case TfrtTpuInfraTarget::kTfFallback: + case TfrtDeviceInfraTarget::kTfFallback: return os << "TfFallback"; - case TfrtTpuInfraTarget::kBridgeFallback: + case TfrtDeviceInfraTarget::kBridgeFallback: return os << "BridgeFallback"; + case TfrtDeviceInfraTarget::kGpu: + return os << "Gpu"; } } @@ -41,10 +44,9 @@ std::ostream& operator<<(std::ostream& os, const TfrtCompileOptions& options) { << "variable_device = " << options.variable_device << ", default_device = " << options.default_device << ", enable_optimizer = " << options.enable_optimizer - << ", enable_native_ops = " << options.enable_native_ops << ", enable_grappler = " << options.enable_grappler << ", force_data_format = " << options.force_data_format - << ", tpu_target = " << options.tpu_target + << ", device_target = " << options.device_target << ", tpu_fuse_ops = " << options.tpu_fuse_ops << ", tpu_move_resource_gather_to_host = " << options.tpu_move_resource_gather_to_host diff --git a/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h b/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h index d8cbfc57bfa..ab4f0a04304 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h +++ b/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h @@ -23,15 +23,17 @@ limitations under the License. namespace tensorflow { -enum class TfrtTpuInfraTarget { - kNoTpu, // No TPU support. +enum class TfrtDeviceInfraTarget { + kCpu, // CPU only, no device support. kTpurt, // Target TPURT dialect and kernels. kTfFallback, // Target TPU kernels in TF Fallback. kBridgeFallback, // TPU support but choose kTpurt or kTfFallback depending on - // whether the graph has unsupported feature in Bridge + // whether the graph has unsupported feature in Bridge. + kGpu, // Target GPU specific compiler passes and runtime + // initializations. }; -std::ostream& operator<<(std::ostream& os, TfrtTpuInfraTarget tpu_target); +std::ostream& operator<<(std::ostream& os, TfrtDeviceInfraTarget device_target); struct TfrtCompileOptions { // TODO(tfrt-devs): Ideally, compiler should make the decision where @@ -42,13 +44,6 @@ struct TfrtCompileOptions { // Enable compiler optimization in TFRT dialect. bool enable_optimizer = true; - // If true, native ops will be used if they are implemented in TFRT. If - // false, all ops are using fallback. - // - // This option is experimental. Native ops are still under development and - // likely to cause performance issue when enabled. - bool enable_native_ops = false; - // If true, run grappler passes before compiling. bool enable_grappler = true; @@ -61,9 +56,9 @@ struct TfrtCompileOptions { // data format should be changed, instead of controlled by users. std::string force_data_format; - // The target TPU infrastructure to use. This will trigger TPU target specific + // The target device infrastructure to use. This will trigger target specific // compiler passes and runtime initialization. 
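+  // Selecting the GPU target additionally runs the TF2XLA bridge during
+  // compilation when `use_bridge_for_gpu` is set.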
- TfrtTpuInfraTarget tpu_target = TfrtTpuInfraTarget::kNoTpu; + TfrtDeviceInfraTarget device_target = TfrtDeviceInfraTarget::kCpu; // If true, use the fused TPU compile_and_execute kernel, which performs all // TPU inference related operations, e.g. core selection, h2d/d2h transfers, @@ -90,6 +85,13 @@ struct TfrtCompileOptions { // supposed to be turned on by default. bool hoist_invariant_ops = false; + // If true, the compiler will try to sink in the invariant ops (e.g. const + // ops, var handle ops, etc.) to the nested function (e.g. batch function) to + // facilitate invariant ops hoisting. + // TODO(tfrt-devs): Set the default value to true after testing as it is + // supposed to be turned on by default. + bool sink_in_invariant_ops = false; + // If true, tf.While's iterations will be parallelized on a best-effort // basis. This is currently experimental. bool enable_while_parallel_iterations = false; @@ -123,13 +125,17 @@ struct TfrtCompileOptions { // If true, streams with inter data depenedencies will be preferred to be // merged for inline execution. - bool merge_inter_dependent_streams = false; + bool merge_inter_dependent_streams = true; // Whether to enable the DecomposeResourceOpsPass. bool decompose_resource_ops = true; // Whether to compile to sync TFRT dialect. bool compile_to_sync_tfrt_dialect = false; + + // Whether to use bridge for GPU. + // TODO(b/260915352): Remove the flag and default to using bridge. + bool use_bridge_for_gpu = false; }; std::ostream& operator<<(std::ostream& os, const TfrtCompileOptions& options); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD index 3bf71aa5ee9..8e866876253 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD @@ -22,6 +22,7 @@ load( ) package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [":friends"], licenses = ["notice"], ) @@ -98,7 +99,6 @@ tf_cc_binary( "//tensorflow/compiler/mlir:init_mlir", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/core:lib", - "//tensorflow/compiler/xla/stream_executor/lib", "@com_google_absl//absl/strings", "@llvm-project//llvm:Analysis", "@llvm-project//llvm:ARMCodeGen", # fixdeps: keep diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD index abddd366b56..f040ca2af3b 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD @@ -3,6 +3,7 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ "//tensorflow/compiler/mlir/tools/kernel_gen:friends", # Allow visibility from the mlir language server. 
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td index 3b2b8463a67..a5a9f12d6a2 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td @@ -40,7 +40,7 @@ def TFFramework_Dialect : Dialect { }]; let useDefaultTypePrinterParser = 1; - let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; + let useFoldAPI = kEmitFoldAdaptorFolder; } def TFFramework_OpKernelContextType : DialectType( mlir::kernel_gen::transforms::CreateTFToJITInvocationPass( tile_sizes, unroll_factors, max_supported_rank, enable_ftz, - index_64bit, jit_i64_indexed_for_large_tensors)); + index_64bit, + /*cpu_codegen=*/false, jit_i64_indexed_for_large_tensors)); pm.addPass(mlir::kernel_gen::tf_framework::CreateEmbedTFFrameworkPass()); pm.addNestedPass( mlir::bufferization::createEmptyTensorToAllocTensorPass()); @@ -150,6 +151,7 @@ Status LowerTFtoLoops(mlir::ModuleOp module, llvm::ArrayRef tile_sizes, mlir::kernel_gen::transforms::CreateTFToJITInvocationPass( tile_sizes, unroll_factors, max_supported_rank, enable_ftz, index_64bit, + /*cpu_codegen=*/false, /*jit_i64_indexed_for_large_tensors=*/true)); } pm.addNestedPass(mlir::mhlo::createLegalizeTFNoFallbackPass( @@ -162,7 +164,7 @@ Status LowerTFtoLoops(mlir::ModuleOp module, llvm::ArrayRef tile_sizes, pm.addNestedPass(mlir::createCanonicalizerPass()); pm.addNestedPass(mlir::createCSEPass()); pm.addNestedPass(mlir::createCanonicalizerPass()); - pm.addNestedPass(mlir::createShapeSimplification()); + pm.addNestedPass(mlir::mhlo::createShapeSimplification()); pm.addNestedPass(mlir::mhlo::createMergeAssumingOpsPass()); pm.addNestedPass(mlir::mhlo::createBroadcastPropagationPass()); pm.addNestedPass(mlir::createCanonicalizerPass()); @@ -234,8 +236,8 @@ Status LowerTFtoLoops(mlir::ModuleOp module, llvm::ArrayRef tile_sizes, return OkStatus(); } -Status LowerLoopsToGPU(mlir::ModuleOp module, bool embed_memref_prints, - bool index_64bit, bool apply_cl_options) { +Status LowerLoopsToGPU(mlir::ModuleOp module, bool index_64bit, + bool apply_cl_options) { mlir::PassManager pm(module.getContext()); if (apply_cl_options) applyTensorflowAndCLOptions(pm); @@ -297,9 +299,6 @@ Status LowerLoopsToGPU(mlir::ModuleOp module, bool embed_memref_prints, pm.addPass(::mlir::createConvertSCFToCFPass()); // Map asserts to the tensorflow framework. 
pm.addPass(mlir::kernel_gen::tf_framework::CreateRewriteTFFrameworkAssert()); - if (embed_memref_prints) { - pm.addPass(mlir::kernel_gen::transforms::CreateEmbedMemRefPrintsPass()); - } if (failed(pm.run(module))) { return tensorflow::errors::Internal("Lowering to GPU kernels failed."); } @@ -426,13 +425,23 @@ StatusOr> GenerateKernelForTfCode( mlir::MLIRContext& context, llvm::StringRef tf_code, llvm::ArrayRef architectures, llvm::ArrayRef tile_sizes, llvm::ArrayRef unroll_factors, - int64_t max_supported_rank, bool embed_memref_prints, bool print_ptx, - bool print_llvmir, bool enable_ftz, bool index_64bit, bool jit_compile, + int64_t max_supported_rank, bool print_ptx, bool print_llvmir, + bool enable_ftz, bool index_64bit, bool jit_compile, bool jit_i64_indexed_for_large_tensors, bool apply_cl_options) { + if (jit_compile && jit_i64_indexed_for_large_tensors) { + return tensorflow::Status( + tensorflow::error::Code::INVALID_ARGUMENT, + "jit compilation for large tensors " + "(`jit_i64_indexed_for_large_tensors`) and unconditioned jit " + "compilation (`jit`) must not be requested together"); + } + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, SetupContextAndParseModule(context, tf_code)); if (jit_compile) { + assert(!jit_i64_indexed_for_large_tensors && + "expect to have reported an error earlier"); TF_RETURN_IF_ERROR(LowerTFToJITInvocation( module.get(), tile_sizes, unroll_factors, max_supported_rank, enable_ftz, index_64bit, @@ -442,8 +451,8 @@ StatusOr> GenerateKernelForTfCode( LowerTFtoLoops(module.get(), tile_sizes, unroll_factors, max_supported_rank, enable_ftz, index_64bit, jit_i64_indexed_for_large_tensors, apply_cl_options)); - TF_RETURN_IF_ERROR(LowerLoopsToGPU(module.get(), embed_memref_prints, - index_64bit, apply_cl_options)); + TF_RETURN_IF_ERROR( + LowerLoopsToGPU(module.get(), index_64bit, apply_cl_options)); TF_RETURN_IF_ERROR( LowerKernelBodiesToLowLevelIr(module.get(), apply_cl_options)); TF_RETURN_IF_ERROR( diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h index 794f0eb61ad..0ac2af80ada 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h @@ -44,8 +44,8 @@ StatusOr> GenerateKernelForTfCode( mlir::MLIRContext& context, llvm::StringRef tf_code, llvm::ArrayRef architectures, llvm::ArrayRef tile_sizes, llvm::ArrayRef unroll_factors, - int64_t max_supported_rank, bool embed_memref_prints, bool print_ptx, - bool print_llvmir, bool enable_ftz, bool index_64bit, bool jit_compile, + int64_t max_supported_rank, bool print_ptx, bool print_llvmir, + bool enable_ftz, bool index_64bit, bool jit_compile, bool jit_i64_indexed_for_large_tensors, bool apply_cl_options); } // namespace kernel_gen diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/tests/BUILD index 35b3e52dfe4..4d4abfd9d90 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/BUILD @@ -1,7 +1,10 @@ load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") -package(licenses = ["notice"]) +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) glob_lit_tests( data = [":test_utilities"], diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/copy_cleanup.mlir 
b/tensorflow/compiler/mlir/tools/kernel_gen/tests/copy_cleanup.mlir index 6c7be698b5a..a8ab6c524d8 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/copy_cleanup.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/copy_cleanup.mlir @@ -23,8 +23,8 @@ builtin.module { } } -// CHECK: #[[$MAP0:.*]] = affine_map<(d0)[s0] -> (d0 * s0)> -// CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (d0)> +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0)[s0] -> (d0 * s0)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0)> // CHECK-LABEL: func @Copy( // CHECK-SAME: %[[LHS:.*]]: memref, %[[RHS:.*]]: memref) { // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index @@ -64,8 +64,8 @@ builtin.module { } } -// CHECK: #[[$MAP0:.*]] = affine_map<(d0)[s0] -> (d0 * s0)> -// CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (d0)> +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0)[s0] -> (d0 * s0)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0)> // CHECK-LABEL: func @CopyWithWrite( // CHECK-SAME: %[[LHS:.*]]: memref, %[[RHS:.*]]: memref) { // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index @@ -109,8 +109,8 @@ builtin.module { } } -// CHECK: #[[$MAP0:.*]] = affine_map<(d0)[s0] -> (d0 * s0)> -// CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (d0)> +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0)[s0] -> (d0 * s0)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0)> // CHECK-LABEL: func @CopyWithMutation( // CHECK-SAME: %[[LHS:.*]]: memref, %[[RHS:.*]]: memref) { // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/jit_i64_indexed_for_large_tensors.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/jit_i64_indexed_for_large_tensors.mlir deleted file mode 100644 index 1aaddc58d99..00000000000 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/jit_i64_indexed_for_large_tensors.mlir +++ /dev/null @@ -1,44 +0,0 @@ -// RUN: kernel-gen-opt %s --split-input-file \ -// RUN: --tf-to-jit-invocation="tile-sizes=1,2,3 \ -// RUN: unroll-factors=3,2,1 max-supported-rank=32 \ -// RUN: enable-ftz=false index_64bit=false cpu-codegen=false \ -// RUN: jit_i64_indexed_for_large_tensors=true" | \ -// RUN: FileCheck %s - -// CHECK-LABEL: @unary_tanh_rint -// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -func.func @unary_tanh_rint(%arg : tensor<*xf32>) -> (tensor<*xf32>) { - // CHECK: %[[MAX_SIZE:.*]] = arith.constant 4294967296 : index - // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 - // CHECK: %[[ELEMENT_COUNT:.*]] = shape.num_elements %[[SHAPE:.*]] : tensor -> index - // CHECK: %[[CONDITION:.*]] = arith.cmpi sgt, %[[ELEMENT_COUNT:.*]], %[[MAX_SIZE:.*]] : index - // CHECK: %[[IF_RES:.*]] = scf.if %[[CONDITION:.*]] -> (tensor<*xf32>) { - // CHECK: %[[CALLABLE:.*]] = tf_framework.jit_compile_from_str - // CHECK-SAME: " - // CHECK-SAME: module { - // CHECK-SAME: func @main(%arg0: tensor<*xf32>) -> tensor<*xf32> - // CHECK-SAME: attributes {tf_entry} - // CHECK-SAME: { - // CHECK-SAME: %0 = \22tf.Tanh\22(%arg0) - // CHECK-SAME: return %0 - // CHECK-SAME: } - // CHECK-SAME: } - // CHECK-SAME: " - // CHECK-SAME: { - // CHECK-SAME: cpuCodegen = false - // CHECK-SAME: enableFtz = false - // CHECK-SAME: index64Bit = true - // CHECK-SAME: maxSupportedRank = 32 - // CHECK-SAME: tileSizes = [1, 2, 3] - // CHECK-SAME: unrollFactors = [3, 2, 1] - // CHECK-SAME: } - // CHECK: %[[RES:.*]] = tf_framework.jit_execute %[[CALLABLE]](%[[ARG]]) - // CHECK: scf.yield %[[RES:.*]] - // CHECK: } else { - // CHECK: %4 = "tf.Tanh"(%arg0) - // CHECK: scf.yield %4 : tensor<*xf32> - // CHECK: } - // CHECK: 
return %[[IF_RES]] - %0 = "tf.Tanh"(%arg) : (tensor<*xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/print_memrefs.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/print_memrefs.mlir deleted file mode 100644 index 3fe2c9af670..00000000000 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/print_memrefs.mlir +++ /dev/null @@ -1,55 +0,0 @@ -// RUN: kernel-gen-opt %s --embed-memref-prints | FileCheck %s - -func.func @print_memrefs( - %ctx: !tf_framework.op_kernel_context, %input: memref<*xf32>) - -> memref<*xf32> attributes {tf_entry} { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %rank = memref.rank %input : memref<*xf32> - %shape = memref.alloca(%rank) : memref - scf.for %i = %c0 to %rank step %c1 { - %dim = memref.dim %input, %i : memref<*xf32> - memref.store %dim, %shape[%i] : memref - } - - %c9000 = arith.constant 9000 : index - %num_elem = memref.alloca() : memref<1xindex> - memref.store %c9000, %num_elem[%c0] : memref<1xindex> - %flat_input = memref.reshape %input(%num_elem) - : (memref<*xf32>, memref<1xindex>) -> memref - - %flat_output = tf_framework.alloc(%ctx, %c9000) : memref - %output = memref.reshape %flat_output(%shape) - : (memref, memref) -> memref<*xf32> - func.return %output : memref<*xf32> -} - -// CHECK-DAG: global internal constant @[[STR0:debug_op_[0-9]+]]({{.*}} @print_memrefs -// CHECK-DAG: global internal constant @[[STR1:debug_op_[0-9]+]]({{.*}} -> memref -// CHECK-DAG: global internal constant @[[STR2:debug_op_[0-9]+]]({{.*}} -> memref<*xf32> -// CHECK-DAG: func private @printMemrefF32(memref<*xf32>) -// CHECK-DAG: llvm.func @printCString(!llvm.ptr) - -// CHECK: func @print_memrefs -// CHECK-SAME: , %[[ARG:.*]]: memref<*xf32>) -// Print debug info for the function arg. -// CHECK: %[[STR0_ADDR:.*]] = llvm.mlir.addressof @[[STR0]] -// CHECK: %[[STR0_PTR:.*]] = llvm.getelementptr %[[STR0_ADDR]] -// CHECK: llvm.call @printCString(%[[STR0_PTR]]) : (!llvm.ptr) -// CHECK: call @printMemrefF32(%[[ARG]]) : (memref<*xf32>) -> () - -// Print debug info for reshape from unranked to ranked. -// CHECK: %[[RESHAPE:.*]] = memref.reshape %[[ARG]] -// CHECK: %[[STR1_ADDR:.*]] = llvm.mlir.addressof @[[STR1]] -// CHECK: %[[STR1_PTR:.*]] = llvm.getelementptr %[[STR1_ADDR]] -// CHECK: llvm.call @printCString(%[[STR1_PTR]]) : (!llvm.ptr) -// CHECK: %[[UNRANKED_BUF:.*]] = memref.cast %[[RESHAPE]] -// CHECK: call @printMemrefF32(%[[UNRANKED_BUF]]) : (memref<*xf32>) - -// Print debug info for reshape from ranked to unranked. 
-// CHECK: %[[ALLOC:.*]] = tf_framework.alloc -// CHECK: %[[RESHAPE_2:.*]] = memref.reshape %[[ALLOC]] -// CHECK: %[[STR2_ADDR:.*]] = llvm.mlir.addressof @[[STR2]] -// CHECK: %[[STR2_PTR:.*]] = llvm.getelementptr %[[STR2_ADDR]] -// CHECK: llvm.call @printCString(%[[STR2_PTR]]) : (!llvm.ptr) -// CHECK: call @printMemrefF32(%[[RESHAPE_2]]) : (memref<*xf32>) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_jit_invocations.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_jit_invocations.mlir index f85658e6cea..79b1ca008b9 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_jit_invocations.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_jit_invocations.mlir @@ -1,60 +1,90 @@ // RUN: kernel-gen-opt %s --split-input-file \ // RUN: --tf-to-jit-invocation="tile-sizes=1,2,3 unroll-factors=3,2,1 \ -// RUN: max-supported-rank=32 enable-ftz=false cpu-codegen=false" | \ +// RUN: max-supported-rank=32 enable-ftz=false cpu-codegen=false" | \ // RUN: FileCheck %s -// CHECK-LABEL: @unary_tanh -// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) +// RUN: kernel-gen-opt %s --split-input-file \ +// RUN: --tf-to-jit-invocation="tile-sizes=1,2,3 unroll-factors=3,2,1 \ +// RUN: max-supported-rank=32 enable-ftz=false cpu-codegen=false \ +// RUN: jit_i64_indexed_for_large_tensors=true" | \ +// RUN: FileCheck %s --check-prefix=CHECK-JFLT + func.func @unary_tanh(%arg : tensor<*xf32>) -> tensor<*xf32> { - // CHECK: %[[CALLABLE:.*]] = tf_framework.jit_compile_from_str - // CHECK-SAME: " - // CHECK-SAME: module { - // CHECK-SAME: func @main(%arg0: tensor<*xf32>) -> tensor<*xf32> - // CHECK-SAME: attributes {tf_entry} - // CHECK-SAME: { - // CHECK-SAME: %0 = \22tf.Tanh\22(%arg0) - // CHECK-SAME: return %0 - // CHECK-SAME: } - // CHECK-SAME: } - // CHECK-SAME: " - // CHECK-SAME: { - // CHECK-SAME: cpuCodegen = false - // CHECK-SAME: enableFtz = false - // CHECK-SAME: maxSupportedRank = 32 : i64 - // CHECK-SAME: tileSizes = [1, 2, 3] - // CHECK-SAME: unrollFactors = [3, 2, 1] - // CHECK-SAME: } - // CHECK: %[[RES:.*]] = tf_framework.jit_execute %[[CALLABLE]](%[[ARG]]) - // CHECK: return %[[RES]] %0 = "tf.Tanh"(%arg) : (tensor<*xf32>) -> tensor<*xf32> func.return %0 : tensor<*xf32> } +// CHECK-LABEL: @unary_tanh +// CHECK-SAME: %[[ARG:.*]]: tensor<*xf32> +// CHECK: %[[CALLABLE:.*]] = tf_framework.jit_compile_from_str +// CHECK-SAME: " +// CHECK-SAME: module { +// CHECK-SAME: func @main(%[[ARG_JIT:.*]]: tensor<*xf32>) -> tensor<*xf32> +// CHECK-SAME: attributes {tf_entry} +// CHECK-SAME: { +// CHECK-SAME: %[[RES_JIT:.*]] = \22tf.Tanh\22(%[[ARG_JIT]]) +// CHECK-SAME: return %[[RES_JIT]] +// CHECK-SAME: } +// CHECK-SAME: } +// CHECK-SAME: " +// CHECK-SAME: { +// CHECK-SAME: cpuCodegen = false +// CHECK-SAME: enableFtz = false +// CHECK-SAME: maxSupportedRank = 32 : i64 +// CHECK-SAME: tileSizes = [1, 2, 3] +// CHECK-SAME: unrollFactors = [3, 2, 1] +// CHECK-SAME: } +// CHECK: %[[RES:.*]] = tf_framework.jit_execute %[[CALLABLE]](%[[ARG]]) +// CHECK: return %[[RES]] + +// CHECK-JFLT-LABEL: @unary_tanh +// CHECK-JFLT-SAME: %[[ARG0:.*]]: tensor<*xf32> +// CHECK-JFLT-DAG: %[[C4294967296:.*]] = arith.constant 4294967296 +// CHECK-JFLT: %[[SHAPE:.*]] = shape.shape_of %[[ARG0]] +// CHECK-JFLT: %[[NUM:.*]] = shape.num_elements %[[SHAPE]] +// CHECK-JFLT: %[[CMPI:.*]] = arith.cmpi sgt, %[[NUM]], %[[C4294967296]] +// CHECK-JFLT: %[[IF:.*]] = scf.if %[[CMPI]] +// CHECK-JFLT: %[[JIT:.*]] = tf_framework.jit_compile_from_str +// CHECK-JFLT-SAME: "module +// CHECK-JFLT-SAME: cpuCodegen = 
false +// CHECK-JFLT-SAME: enableFtz = false +// CHECK-JFLT-SAME: index64Bit = true +// CHECK-JFLT-SAME: maxSupportedRank = 32 +// CHECK-JFLT-SAME: tileSizes = [1, 2, 3] +// CHECK-JFLT-SAME: unrollFactors = [3, 2, 1] +// CHECK-JFLT: %[[JIT_0:.*]] = tf_framework.jit_execute %[[JIT]](%[[ARG0]]) +// CHECK-JFLT: scf.yield %[[JIT_0]] +// CHECK-JFLT: else +// CHECK-JFLT: %[[VAL:.*]] = "tf.Tanh"(%[[ARG0]]) +// CHECK-JFLT: scf.yield %[[VAL]] +// CHECK-JFLT: return %[[IF]] + // ----- -// CHECK-LABEL: @binary_sub -// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>) func.func @binary_sub(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> tensor<*xf32> { - // CHECK: %[[CALLABLE:.*]] = tf_framework.jit_compile_from_str - // CHECK-SAME: " - // CHECK-SAME: module { - // CHECK-SAME: func @main(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> - // CHECK-SAME: attributes {tf_entry} - // CHECK-SAME: { - // CHECK-SAME: %0 = \22tf.Sub\22(%arg0, %arg1) - // CHECK-SAME: return %0 - // CHECK-SAME: } - // CHECK-SAME: } - // CHECK-SAME: " - // CHECK-SAME: { - // CHECK-SAME: cpuCodegen = false - // CHECK-SAME: enableFtz = false - // CHECK-SAME: maxSupportedRank = 32 : i64 - // CHECK-SAME: tileSizes = [1, 2, 3] - // CHECK-SAME: unrollFactors = [3, 2, 1] - // CHECK-SAME: } - // CHECK: %[[RES:.*]] = tf_framework.jit_execute %[[CALLABLE]](%[[ARG0]], %[[ARG1]]) - // CHECK: return %[[RES]] %0 = "tf.Sub"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> func.return %0 : tensor<*xf32> } + +// CHECK-LABEL: @binary_sub +// CHECK-SAME: %[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32> +// CHECK: %[[CALLABLE:.*]] = tf_framework.jit_compile_from_str +// CHECK-SAME: " +// CHECK-SAME: module { +// CHECK-SAME: func @main(%[[ARG0_JIT:.*]]: tensor<*xf32>, %[[ARG1_JIT:.*]]: tensor<*xf32>) -> tensor<*xf32> +// CHECK-SAME: attributes {tf_entry} +// CHECK-SAME: { +// CHECK-SAME: %[[RES_JIT:.*]] = \22tf.Sub\22(%[[ARG0_JIT]], %[[ARG1_JIT]]) +// CHECK-SAME: return %[[RES_JIT]] +// CHECK-SAME: } +// CHECK-SAME: } +// CHECK-SAME: " +// CHECK-SAME: { +// CHECK-SAME: cpuCodegen = false +// CHECK-SAME: enableFtz = false +// CHECK-SAME: maxSupportedRank = 32 : i64 +// CHECK-SAME: tileSizes = [1, 2, 3] +// CHECK-SAME: unrollFactors = [3, 2, 1] +// CHECK-SAME: } +// CHECK: %[[RES:.*]] = tf_framework.jit_execute %[[CALLABLE]](%[[ARG0]], %[[ARG1]]) +// CHECK: return %[[RES]] diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel/BUILD index 482aea37393..38f0b297272 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/tf_to_kernel/BUILD @@ -1,6 +1,9 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") -package(licenses = ["notice"]) +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) glob_lit_tests( data = [":test_utilities"], diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc index 71e862cd096..c893a81ab47 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc @@ -107,7 +107,7 @@ extern "C" void _mlir_ciface_tf_report_error(void* op_kernel_ctx, } auto* ctx = static_cast(op_kernel_ctx); ctx->CtxFailureWithWarning( - 
tensorflow::Status{ConvertAttrToEnumValue(symbol.getValue()), msg}); + tensorflow::Status{ConvertAttrToEnumValue(symbol.value()), msg}); } static void ReportError(void* op_kernel_ctx, ErrorCode error_code, @@ -182,7 +182,7 @@ llvm::Expected> Compile( tensorflow::StatusOr> status_or_module = tensorflow::kernel_gen::GenerateKernelForTfCode( context, code, architectures, tile_sizes, unroll_factors, - max_supported_rank, /*embed_memref_prints=*/false, + max_supported_rank, /*print_ptx=*/false, /*print_llvmir=*/false, enable_ftz, index_64bit, /*jit_compile=*/false, diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc index 46b9ceb10b6..1a26888eb1f 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_kernel.cc @@ -108,22 +108,21 @@ Status Run(llvm::StringRef input_file, llvm::StringRef output_file, llvm::ArrayRef architectures, llvm::ArrayRef tile_sizes, llvm::ArrayRef unroll_factors, int64_t max_supported_rank, - bool embed_memref_prints, bool print_ptx, bool print_llvmir, - bool enable_ftz, bool index_64bit, bool jit_compile, - bool jit_i64_indexed_for_large_tensors) { + bool print_ptx, bool print_llvmir, bool enable_ftz, bool index_64bit, + bool jit_compile, bool jit_i64_indexed_for_large_tensors) { // Read TF code. std::string tf_code; TF_RETURN_IF_ERROR( ReadFileToString(Env::Default(), input_file.str(), &tf_code)); + // Compile. mlir::MLIRContext context; TF_ASSIGN_OR_RETURN( mlir::OwningOpRef module, GenerateKernelForTfCode(context, tf_code, architectures, tile_sizes, - unroll_factors, max_supported_rank, - embed_memref_prints, print_ptx, print_llvmir, - enable_ftz, index_64bit, jit_compile, - jit_i64_indexed_for_large_tensors, + unroll_factors, max_supported_rank, print_ptx, + print_llvmir, enable_ftz, index_64bit, + jit_compile, jit_i64_indexed_for_large_tensors, /*apply_cl_options=*/true)); // Get binary. @@ -149,10 +148,6 @@ int main(int argc, char** argv) { llvm::cl::opt index_64bit("index_64bit", llvm::cl::desc("enable 64 bit indexing"), llvm::cl::init(false)); - llvm::cl::opt embed_memref_prints( - "embed_memref_prints", - llvm::cl::desc("embed memref prints at the end of their lifetime"), - llvm::cl::init(false)); llvm::cl::opt print_ptx( "print-ptx", llvm::cl::desc("print generated PTX code per target architecture."), @@ -198,8 +193,8 @@ int main(int argc, char** argv) { auto status = tensorflow::kernel_gen::Run( input_file, output_file, architectures, tile_sizes, unroll_factors, - max_supported_rank, embed_memref_prints, print_ptx, print_llvmir, - enable_ftz, index_64bit, jit_compile, jit_i64_indexed_for_large_tensors); + max_supported_rank, print_ptx, print_llvmir, enable_ftz, index_64bit, + jit_compile, jit_i64_indexed_for_large_tensors); if (!status.ok()) { LOG(ERROR) << status; return 1; diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tools/kernel-gen-opt/kernel-gen-opt.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tools/kernel-gen-opt/kernel-gen-opt.cc index 441f54a8848..7a862e83073 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tools/kernel-gen-opt/kernel-gen-opt.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tools/kernel-gen-opt/kernel-gen-opt.cc @@ -20,10 +20,10 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/transforms/passes.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir_hlo/gml_st/IR/gml_st_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/register.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/passes.h" int main(int argc, char **argv) { mlir::registerAllPasses(); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD index d9af399327f..79a43aeb240 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD @@ -10,6 +10,7 @@ load( load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//tensorflow/compiler/mlir/tools/kernel_gen:friends"], licenses = ["notice"], ) @@ -176,7 +177,6 @@ cc_library( "buffer_reuse_pass.cc", "bufferize_pass.cc", "copy_cleanup_pass.cc", - "embed_memref_prints.cc", "embed_tf_framework_pass.cc", "fuse_inner_parallel_loops_pass.cc", "parallel_loops_to_sequential.cc", diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/buffer_reuse_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/buffer_reuse_pass.cc index a99b0f38935..a0787525b39 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/buffer_reuse_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/buffer_reuse_pass.cc @@ -31,7 +31,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops.h" constexpr llvm::StringRef mlir::kernel_gen::tf_framework::TFAllocOp::kReuseOutputAttrName; diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc index 24c5e1c403c..91e80b920fa 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc @@ -23,7 +23,7 @@ limitations under the License. 
#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project #include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc index 46c83032570..75bf544a86f 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc @@ -22,7 +22,7 @@ limitations under the License. #include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Transforms/passes.h" +#include "tensorflow/compiler/xla/mlir_hlo/transforms/passes.h" namespace mlir { namespace kernel_gen { diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_memref_prints.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_memref_prints.cc deleted file mode 100644 index 139a9fc7b1f..00000000000 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_memref_prints.cc +++ /dev/null @@ -1,196 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include -#include - -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Linalg/IR/Linalg.h" // from @llvm-project -#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project -#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" -#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/utils.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h" - -namespace mlir { -namespace kernel_gen { -namespace transforms { -namespace { - -constexpr StringRef kPrintStringFuncName = "printCString"; - -#define GEN_PASS_DEF_EMBEDMEMREFPRINTSPASS -#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc" - -Operation* EmitMemRefPrint(Location loc, Type element_type, Value arg, - OpBuilder* b) { - StringRef func_name; - if (element_type.isF32()) { - func_name = "printMemrefF32"; - } - if (element_type.isF64()) { - func_name = "printMemrefF64"; - } - if (element_type.isInteger(32)) { - func_name = "printMemrefI32"; - } - if (element_type.isInteger(64) || element_type.isIndex()) { - func_name = "printMemrefI64"; - } - assert(!func_name.empty() && - "Did not find a print function for the element type"); - - auto caller_func = - b->getInsertionBlock()->getParent()->getParentOfType(); - auto func_name_attr = b->getStringAttr(func_name); - - auto callee_func = SymbolTable::lookupNearestSymbolFrom( - caller_func, func_name_attr); - if (!callee_func) { - OpBuilder::InsertionGuard insertGuard(*b); - - auto module = caller_func->getParentOfType(); - b->setInsertionPointToStart(module.getBody()); - auto func_type = FunctionType::get(b->getContext(), arg.getType(), - /*results=*/llvm::None); - callee_func = - b->create(module.getLoc(), func_name, func_type); - callee_func.setPrivate(); - } - return b->create(loc, callee_func, arg); -} - -bool IsElementTypePrintalble(Type element_type) { - return element_type.isF32() || element_type.isF64() || - element_type.isInteger(32) || element_type.isInteger(64) || - element_type.isIndex(); -} - -void EmitMemRefPrint(Location loc, Value memref, OpBuilder* b) { - auto memref_type = memref.getType(); - if (auto unranked_type = memref_type.dyn_cast()) { - Type element_type = unranked_type.getElementType(); - if (!IsElementTypePrintalble(element_type)) return; - - EmitMemRefPrint(loc, element_type, memref, b); - } - if (auto ranked_type = memref_type.dyn_cast()) { - Type element_type = ranked_type.getElementType(); - if (!IsElementTypePrintalble(element_type)) return; - - if (element_type.isIndex()) { - element_type = b->getI64Type(); - ranked_type = MemRefType::get(ranked_type.getShape(), element_type, - ranked_type.getLayout(), - ranked_type.getMemorySpace()); - memref = b->create(loc, ranked_type, memref); - } - - auto unranked_type = UnrankedMemRefType::get( - element_type, ranked_type.getMemorySpaceAsInt()); - Value unranked_memref = - b->create(loc, unranked_type, memref); - EmitMemRefPrint(loc, element_type, unranked_memref, b); - } -} - -SmallVector ExtractValuesToPrint(Operation* op) { - if (isa(op) || isa(op) || - isa(op) || isa(op)) { - return {op->getResult(0)}; - } - if (auto linalg = dyn_cast(op)) { - return linalg.getDpsInitOperands(); - } - if (auto loop = dyn_cast(op)) { - return loop.getOutputs(); - } - if (auto loop = 
dyn_cast(op)) { - return loop.getIterOperands(); - } - if (auto copy = dyn_cast(op)) { - return {copy.getTarget()}; - } - return {}; -} - -void EmitOperationPrint(Operation* op, OpBuilder* b) { - std::string debug_str = "\n\nPrint memref content after the following op\n"; - llvm::raw_string_ostream output_stream(debug_str); - - mlir::OpPrintingFlags flags; - op->print(output_stream, flags); - output_stream << "\n\n"; - - Location loc = op->getLoc(); - Value message_constant = CreateOrFindGlobalStringConstant( - loc, GetGlobalName("debug_op", debug_str), debug_str, b); - - // Insert function call. - MLIRContext* ctx = op->getContext(); - auto func_type = LLVM::LLVMFunctionType::get( - LLVM::LLVMVoidType::get(op->getContext()), - {LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8))}); - FlatSymbolRefAttr tf_func_ref = - GetOrInsertLLVMFunction(kPrintStringFuncName, func_type, op, b); - b->create(loc, llvm::None, tf_func_ref, - llvm::makeArrayRef({message_constant})); -} - -// The pass inserts printing on every mutation of memrefs. -struct EmbedMemRefPrintsPass - : public impl::EmbedMemRefPrintsPassBase { - void runOnOperation() override { - ModuleOp module = getOperation(); - module.walk([&](func::FuncOp func) { - if (func.isDeclaration()) return; - Block* body = &func.getBody().front(); - - // Print arguments. - OpBuilder b(&getContext()); - b.setInsertionPointToStart(body); - Location loc = func.getLoc(); - auto args = func.getArguments(); - if (!args.empty()) { - EmitOperationPrint(func, &b); - } - for (auto arg : args) { - EmitMemRefPrint(loc, arg, &b); - } - // Print buffers after every change. - for (auto& op : func.getBody().front().getOperations()) { - b.setInsertionPointAfter(&op); - auto memrefs = ExtractValuesToPrint(&op); - if (!memrefs.empty()) { - EmitOperationPrint(&op, &b); - } - for (auto memref : memrefs) { - EmitMemRefPrint(op.getLoc(), memref, &b); - } - } - }); - } -}; - -} // namespace - -std::unique_ptr> CreateEmbedMemRefPrintsPass() { - return std::make_unique(); -} - -} // namespace transforms -} // namespace kernel_gen -} // namespace mlir diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc index 4a807527e92..dcb59b2ae06 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc @@ -19,7 +19,7 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" #include "tensorflow/compiler/xla/debug_options_flags.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/xla/service/gpu/gpu_asm_opts_util.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" #include "tensorflow/compiler/xla/service/gpu/target_constants.h" @@ -143,7 +143,6 @@ class GpuKernelToBlobPass // Compile and collect requested cubin and PTX images. 
std::vector images; - TF_ASSIGN_OR_RETURN(std::string libdevice_dir, GetLibdeviceDir(config)); auto gpu_asm_opts = xla::gpu::PtxOptsFromDebugOptions(config.debug_options()); for (const std::string& arch_str : architectures_) { @@ -165,7 +164,7 @@ class GpuKernelToBlobPass xla::gpu::nvptx::CompileToPtx( llvm_module_copy.get(), tensorflow::se::CudaComputeCapability{cc_major, cc_minor}, config, - libdevice_dir, enable_fusion)); + enable_fusion)); if (print_ptx_) { llvm::dbgs() << "Generated PTX code for module '" << gpu_module.getName() << "' on architecture sm_" << arch @@ -237,21 +236,6 @@ class GpuKernelToBlobPass return std::pair(is_compute_profile, arch); } - tensorflow::StatusOr GetLibdeviceDir( - const xla::HloModuleConfig& hlo_module_config) { - for (const std::string& cuda_root : tsl::CandidateCudaRoots( - hlo_module_config.debug_options().xla_gpu_cuda_data_dir())) { - std::string libdevice_dir = - tensorflow::io::JoinPath(cuda_root, "nvvm", "libdevice"); - VLOG(2) << "Looking for libdevice at " << libdevice_dir; - if (tsl::Env::Default()->IsDirectory(libdevice_dir).ok()) { - VLOG(2) << "Found libdevice dir " << libdevice_dir; - return libdevice_dir; - } - } - return tensorflow::errors::Internal( - "Can't find libdevice directory ${CUDA_DIR}/nvvm/libdevice"); - } bool enable_ftz_; }; diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h index bc775f3877e..167b370e17f 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h @@ -35,7 +35,6 @@ limitations under the License. #define GEN_PASS_DECL_PARALLELLOOPSTOSEQUENTIAL #define GEN_PASS_DECL_PROPAGATETFABIKNOWLEDGETOKERNELS #define GEN_PASS_DECL_PROPAGATESHAPEKNOWLEDGETOKERNELS -#define GEN_PASS_DECL_EMBEDMEMREFPRINTSPASS #define GEN_PASS_DECL_FUSEINNERPARALLELLOOPSPASS #define GEN_PASS_DECL_COPYCLEANUPPASS @@ -95,9 +94,6 @@ CreatePropagateTfAbiKnowledgeToKernels(); std::unique_ptr> CreatePropagateShapeKnowledgeToKernels(); -// Pass to print content of memrefs. -std::unique_ptr> CreateEmbedMemRefPrintsPass(); - /// Greedily maps loops to GPU hardware dimensions. std::unique_ptr> CreateMapParallelLoopsPass(); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td index a12551c7205..3e32ca9da0d 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td @@ -111,12 +111,6 @@ def PropagateShapeKnowledgeToKernels let constructor = "transforms::CreatePropagateShapeKnowledgeToKernels()"; } -def EmbedMemRefPrintsPass : Pass<"embed-memref-prints", "mlir::ModuleOp"> { - let summary = "Pass to print content of memrefs"; - let constructor = "transforms::CreateEmbedMemRefPrintsPass()"; - let dependentDialects = ["LLVM::LLVMDialect"]; -} - def FuseInnerParallelLoopsPass : Pass<"fuse-inner-parallel-loops", "mlir::func::FuncOp"> { let summary = "Limited pass to forward stores to loads."; diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/same_shape_propagation.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/same_shape_propagation.cc index 5d9efdc5f46..246556da0ca 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/same_shape_propagation.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/same_shape_propagation.cc @@ -35,7 +35,7 @@ limitations under the License. 
#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops.h" #define DEBUG_TYPE "kernel-gen-shapes" @@ -120,7 +120,7 @@ struct ShapeValue { ArrayRef scalars() const { assert(!is_vector); - return llvm::makeArrayRef(shape); + return llvm::ArrayRef(shape); } bool isVector() const { return is_vector; } @@ -294,8 +294,7 @@ class ShapeEqualityKnowledge { if (!candidate) candidate = dimOp.getSource(); auto index = dimOp.getConstantIndex(); if (!index.has_value()) return false; - return candidate == dimOp.getSource() && - p.index() == index.getValue(); + return candidate == dimOp.getSource() && p.index() == index.value(); }); if (all_are_dimops && candidate) { equal_shapes_.unionSets(candidate.getAsOpaquePointer(), diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tensorflow_abi_knowledge_propagation.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tensorflow_abi_knowledge_propagation.cc index 4eb1b7f3fad..debac04d583 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tensorflow_abi_knowledge_propagation.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tensorflow_abi_knowledge_propagation.cc @@ -31,7 +31,7 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops.h" namespace mlir { namespace kernel_gen { diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc index d36ddaca580..d888f664830 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc @@ -59,13 +59,13 @@ class ConvertToLLVMCallOpPattern : public ConvertOpToLLVMPattern { std::pair ConvertArrayAttrToStackAllocatedArray( Location loc, Type size_ty, Type element_ty, - llvm::Optional attr, ConversionPatternRewriter *rewriter, + std::optional attr, ConversionPatternRewriter *rewriter, std::function create_element) const { Type element_ptr_ty = LLVM::LLVMPointerType::get(element_ty); // If the attribute is missing or empty, set the element count to 0 and // return NULL. - if (!attr.has_value() || attr.getValue().empty()) { + if (!attr.has_value() || attr.value().empty()) { Value zero = rewriter->create( loc, size_ty, rewriter->getIntegerAttr(size_ty, 0)); Value null_ptr = rewriter->create(loc, element_ptr_ty); @@ -73,7 +73,7 @@ class ConvertToLLVMCallOpPattern : public ConvertOpToLLVMPattern { } // Allocate array to store the elements. 
- auto &array_attr = attr.getValue(); + auto &array_attr = attr.value(); Value array_size = rewriter->create( loc, size_ty, rewriter->getIntegerAttr(size_ty, array_attr.size())); Value array_ptr = rewriter->create( @@ -91,7 +91,7 @@ class ConvertToLLVMCallOpPattern : public ConvertOpToLLVMPattern { std::pair ConvertIntegerArrayAttrToStackAllocatedArray( Location loc, Type size_ty, Type element_ty, - llvm::Optional attr, + std::optional attr, ConversionPatternRewriter *rewriter) const { assert(size_ty.isa() && "expect integer size type"); assert(element_ty.isa() && "expect integer element type"); @@ -134,7 +134,7 @@ class TFAllocOpConverter : public ConvertToLLVMCallOpPattern { Value output_index = rewriter.create( loc, llvmInt32Type, rewriter.getI32IntegerAttr(tf_alloc_op.getOutputIndex().has_value() - ? tf_alloc_op.getOutputIndex().getValue() + ? tf_alloc_op.getOutputIndex().value() : -1)); // Convert `candidate_input_indices`. @@ -150,10 +150,9 @@ class TFAllocOpConverter : public ConvertToLLVMCallOpPattern { rewriter .create( loc, getVoidPtrType(), tf_func_ref, - llvm::makeArrayRef({adaptor.getCtx(), num_elements, - element_size, output_index, - candidates_count_and_ptr.first, - candidates_count_and_ptr.second})) + llvm::ArrayRef({adaptor.getCtx(), num_elements, element_size, + output_index, candidates_count_and_ptr.first, + candidates_count_and_ptr.second})) .getResult(); MemRefDescriptor memRefDescriptor = CreateMemRefDescriptor( @@ -173,7 +172,7 @@ class TFAllocOpConverter : public ConvertToLLVMCallOpPattern { Type llvm_void_ptr_type = getVoidPtrType(); return LLVM::LLVMFunctionType::get( llvm_void_ptr_type, - llvm::makeArrayRef( + llvm::ArrayRef( {/*void* op_kernel_ctx*/ llvm_void_ptr_type, /*size_t num_elements*/ getIndexType(), /*size_t element_size*/ getIndexType(), @@ -239,7 +238,7 @@ class TFDeallocOpConverter : public ConvertToLLVMCallOpPattern { GetOrInsertLLVMFunction(GetFuncName(), GetFuncType(), op, &rewriter); rewriter.replaceOpWithNewOp( op, llvm::None, tf_func_ref, - llvm::makeArrayRef({adaptor.getCtx(), allocated_bytes_ptr})); + llvm::ArrayRef({adaptor.getCtx(), allocated_bytes_ptr})); return success(); } @@ -285,10 +284,10 @@ class JITCompileFromStrOpConverter GetOrInsertLLVMFunction(GetFuncName(), GetFuncType(), op, &rewriter); rewriter.replaceOpWithNewOp( op, getVoidPtrType(), tf_func_ref, - llvm::makeArrayRef({adaptor.getCtx(), jit_module_code, tile_sizes.first, - tile_sizes.second, unroll_factors.first, - unroll_factors.second, max_supported_rank, - enable_ftz, index_64bit, cpu_codegen})); + llvm::ArrayRef({adaptor.getCtx(), jit_module_code, tile_sizes.first, + tile_sizes.second, unroll_factors.first, + unroll_factors.second, max_supported_rank, enable_ftz, + index_64bit, cpu_codegen})); return success(); } @@ -418,7 +417,7 @@ class ReportErrorOpConverter adaptor.getErrorCodeAttr()); rewriter.replaceOpWithNewOp( op, llvm::None, tf_func_ref, - llvm::makeArrayRef({adaptor.getCtx(), error_code, message_constant})); + llvm::ArrayRef({adaptor.getCtx(), error_code, message_constant})); return success(); } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_kernel_to_llvm_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_kernel_to_llvm_pass.cc index 75caa39d1e4..0cc3e5f3a66 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_kernel_to_llvm_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_kernel_to_llvm_pass.cc @@ -270,7 +270,7 @@ class TFKernelToLLVMPass 
arith::populateArithExpandOpsPatterns(patterns); memref::populateExpandOpsPatterns(patterns); arith::populateArithToLLVMConversionPatterns(type_converter, patterns); - populateMemRefToLLVMConversionPatterns(type_converter, patterns); + populateFinalizeMemRefToLLVMConversionPatterns(type_converter, patterns); populateMathToLLVMConversionPatterns(type_converter, patterns); populateFuncToLLVMConversionPatterns(type_converter, patterns); cf::populateControlFlowToLLVMConversionPatterns(type_converter, patterns); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_to_jit_invocations.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_to_jit_invocations.cc index 36147381845..6bc16df2a50 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_to_jit_invocations.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_to_jit_invocations.cc @@ -13,20 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include "mlir/Dialect/SCF/IR/SCF.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/Debug.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project @@ -34,7 +33,6 @@ limitations under the License. #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/TypeRange.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" @@ -49,7 +47,7 @@ namespace kernel_gen { namespace transforms { namespace { -constexpr int64_t i32BitLimit = 4294967296; +constexpr int64_t i32Limit = 4294967296; using shape::ShapeOfOp; bool IsSingleResultTFOperation(Operation *op) { @@ -93,7 +91,7 @@ struct TFToJITInvocationsPattern : public RewritePattern { op->getOperandTypes(), locs); // Map operands. - BlockAndValueMapping bvm; + IRMapping bvm; for (auto it : llvm::zip(op->getOperands(), block->getArguments())) bvm.map(std::get<0>(it), std::get<1>(it)); @@ -122,60 +120,51 @@ struct TFToI64JITInvocationForLargeTensorsPattern : public RewritePattern { return failure(); } - auto results = llvm::to_vector<16>(op->getResults()); - auto operand_types = llvm::to_vector<16>(llvm::map_range( - op->getOperands(), [](Value v) { return v.getType(); })); - auto result_types = llvm::to_vector<16>( - llvm::map_range(results, [](Value v) { return v.getType(); })); - - // Create the JIT compile op. + // Create large argument condition. 
auto loc = op->getLoc(); - Value shape_size_limit = - rewriter.create(loc, i32BitLimit); auto arg = op->getOperands().front(); auto shape = rewriter.create(loc, arg); auto num_elems = rewriter.create(loc, shape); - Value coniditon_check_main = rewriter.create( - loc, arith::CmpIPredicate::sgt, num_elems, shape_size_limit); - - Value conditional_path = - rewriter - .create( - loc, op->getResultTypes(), coniditon_check_main, - [&](OpBuilder &b, Location l) { - auto jit_compile_op = - rewriter.create( - loc, - rewriter.getType(), - llvm::None); - BlockAndValueMapping bvm; - { - OpBuilder::InsertionGuard guard(rewriter); - Block *block = rewriter.createBlock( - &jit_compile_op.getBody(), {}, operand_types, - SmallVector(operand_types.size(), loc)); - for (auto it : - llvm::zip(op->getOperands(), block->getArguments())) - bvm.map(std::get<0>(it), std::get<1>(it)); - rewriter.setInsertionPointToStart(block); - rewriter.clone(*op, bvm); - auto new_op = rewriter.clone(*op, bvm); - rewriter.create( - loc, TypeRange{}, new_op->getResults()); - } - auto jit_execute_op = - rewriter.create( - loc, result_types, Value(), - jit_compile_op.getResult(), op->getOperands()); - b.create(l, jit_execute_op.getResult()); - }, - [&](OpBuilder &b, Location l) { - auto new_op = rewriter.clone(*op); - b.create(l, new_op->getResult(0)); - }) - .getResult(0); - - rewriter.replaceOp(op, conditional_path); + Value cst_i32_limit = + rewriter.create(loc, i32Limit); + Value large_tensor_predicate = rewriter.create( + loc, arith::CmpIPredicate::sgt, num_elems, cst_i32_limit); + + // Create dispatch code. + auto jit_body_builder_fn = [&](OpBuilder &b, Location loc) { + // Create JIT compile op. + auto callable_ty = b.getType(); + auto jit_compile_op = b.create( + loc, callable_ty, /*ctx=*/Value()); + IRMapping bvm; + { + OpBuilder::InsertionGuard g(b); + Block *block = + b.createBlock(&jit_compile_op.getBody(), {}, op->getOperandTypes(), + SmallVector(op->getNumOperands(), loc)); + for (auto it : llvm::zip(op->getOperands(), block->getArguments())) + bvm.map(std::get<0>(it), std::get<1>(it)); + b.setInsertionPointToStart(block); + Operation *cloned_op = b.clone(*op, bvm); + b.create( + loc, cloned_op->getResults().front()); + } + + // Create JIT execute op. + auto jit_execute_op = b.create( + loc, op->getResultTypes().front(), /*ctx=*/Value(), + jit_compile_op.getResult(), arg); + b.create(loc, jit_execute_op.getResult()); + }; + auto aot_body_builder_fn = [&](OpBuilder &b, Location loc) { + Operation *cloned_op = b.clone(*op); + b.create(loc, cloned_op->getResults().front()); + }; + + // Create and replace in two steps to clone the original op. + auto ifOp = rewriter.create( + loc, large_tensor_predicate, jit_body_builder_fn, aot_body_builder_fn); + rewriter.replaceOp(op, ifOp.getResults()); return success(); } }; @@ -188,14 +177,13 @@ struct PackJITCompileOpPattern llvm::ArrayRef tile_sizes, llvm::ArrayRef unroll_factors, int64_t max_supported_rank, bool enable_ftz, - bool index_64bit_if_jit_compiling, - bool cpu_codegen) + bool index_64bit, bool cpu_codegen) : OpRewritePattern(ctx), tile_sizes(tile_sizes), unroll_factors(unroll_factors), max_supported_rank(max_supported_rank), enable_ftz(enable_ftz), - index_64bit_if_jit_compiling(index_64bit_if_jit_compiling), + index_64bit(index_64bit), cpu_codegen(cpu_codegen) {} LogicalResult matchAndRewrite(tf_framework::JITCompileOp op, @@ -207,25 +195,31 @@ struct PackJITCompileOpPattern // Temporarily, build the module that would be JIT-compiled. 
This is only to // obtain the serialized code attribute. auto loc = op->getLoc(); - OpBuilder tmp_module_builder(getContext(), rewriter.getListener()); - auto jit_module = tmp_module_builder.create(loc); - tmp_module_builder.setInsertionPointToStart( - jit_module.SingleBlock::getBody()); - auto jit_function = tmp_module_builder.create( - loc, tf_framework::JITCompileFromStrOp::kJITEntryFunctionName, - tmp_module_builder.getFunctionType(body->getArgumentTypes(), - yield_op->getOperandTypes())); - jit_function->setAttr(tf_framework::TFFrameworkDialect::kTFEntryAttrName, - tmp_module_builder.getUnitAttr()); - jit_function.getBody().takeBody(op.getBodyRegion()); - tmp_module_builder.setInsertionPointToEnd(&jit_function.getBody().front()); - tmp_module_builder.create(loc, yield_op.getResult()); - rewriter.eraseOp(yield_op); + auto jit_module = rewriter.create(loc); + { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPointToStart(jit_module.SingleBlock::getBody()); + auto jit_function = rewriter.create( + loc, tf_framework::JITCompileFromStrOp::kJITEntryFunctionName, + rewriter.getFunctionType(body->getArgumentTypes(), + yield_op->getOperandTypes())); + jit_function->setAttr(tf_framework::TFFrameworkDialect::kTFEntryAttrName, + rewriter.getUnitAttr()); + jit_function.getBody().takeBody(op.getBodyRegion()); + rewriter.setInsertionPointToEnd(&jit_function.getBody().front()); + rewriter.create(loc, yield_op.getResult()); + rewriter.eraseOp(yield_op); + } // Serialize JIT module. std::string code; llvm::raw_string_ostream ss(code); - jit_module.print(ss); + assert(succeeded(jit_module.verify())); + mlir::OpPrintingFlags flags; + jit_module.print(ss, flags.assumeVerified()); + + // Remove temporary module. + rewriter.eraseOp(jit_module); // Finally, create the new JIT compile op. rewriter.replaceOpWithNewOp( @@ -233,8 +227,7 @@ struct PackJITCompileOpPattern rewriter.getI64ArrayAttr(tile_sizes), rewriter.getI64ArrayAttr(unroll_factors), rewriter.getI64IntegerAttr(max_supported_rank), - rewriter.getBoolAttr(enable_ftz), - rewriter.getBoolAttr(index_64bit_if_jit_compiling), + rewriter.getBoolAttr(enable_ftz), rewriter.getBoolAttr(index_64bit), rewriter.getBoolAttr(cpu_codegen)); return success(); @@ -245,7 +238,7 @@ struct PackJITCompileOpPattern llvm::ArrayRef unroll_factors; int64_t max_supported_rank; bool enable_ftz; - bool index_64bit_if_jit_compiling; + bool index_64bit; bool cpu_codegen; }; diff --git a/tensorflow/compiler/mlir/tosa/BUILD b/tensorflow/compiler/mlir/tosa/BUILD index 935d6d08bf8..b942c64f611 100644 --- a/tensorflow/compiler/mlir/tosa/BUILD +++ b/tensorflow/compiler/mlir/tosa/BUILD @@ -8,6 +8,7 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") # TODO: Tighten visibility once targets are at the right granularity. 
package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [":internal"], licenses = ["notice"], ) @@ -193,6 +194,7 @@ cc_library( "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineTransforms", "@llvm-project//mlir:Dialect", diff --git a/tensorflow/compiler/mlir/tosa/tests/BUILD b/tensorflow/compiler/mlir/tosa/tests/BUILD index 23b25f7b706..e7c4a5b9a61 100644 --- a/tensorflow/compiler/mlir/tosa/tests/BUILD +++ b/tensorflow/compiler/mlir/tosa/tests/BUILD @@ -1,7 +1,10 @@ load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") -package(licenses = ["notice"]) +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) glob_lit_tests( data = [":test_utilities"], diff --git a/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir index 1e18821b1e0..c0e5b23e3b8 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir @@ -10,7 +10,7 @@ // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<16xf32>} // CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() {value = dense<[3, 0, 1, 2]> : tensor<4xi32>} // CHECK-DAG: %[[VAR2:.*]] = "tosa.transpose"(%arg1, %[[VAR1]]) -// CHECK: %[[VAR3:.*]] = "tosa.conv2d"(%arg0, %[[VAR2]], %[[VAR0]]) {dilation = [1, 1], pad = [0, 1, 0, 1], stride = [1, 1]} +// CHECK: %[[VAR3:.*]] = "tosa.conv2d"(%arg0, %[[VAR2]], %[[VAR0]]) {dilation = array, pad = array, stride = array} func.func @test_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x16xf32>) -> tensor<1x32x32x16xf32> { %3 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<1x32x32x8xf32>, tensor<2x2x8x16xf32>) -> tensor<1x32x32x16xf32> func.return %3 : tensor<1x32x32x16xf32> @@ -20,7 +20,7 @@ func.func @test_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x16xf32> // CHECK-LABEL: test_depthwise_conv2d // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<16xf32>} -// CHECK: %[[VAR1:.*]] = "tosa.depthwise_conv2d"(%arg0, %arg1, %0) {dilation = [1, 1], pad = [0, 1, 0, 1], stride = [1, 1]} +// CHECK: %[[VAR1:.*]] = "tosa.depthwise_conv2d"(%arg0, %arg1, %0) {dilation = array, pad = array, stride = array} func.func @test_depthwise_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x2xf32>) -> tensor<1x32x32x16xf32> { %5 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x32x32x8xf32>, tensor<2x2x8x2xf32>) -> tensor<1x32x32x16xf32> %6 = "tf.Identity"(%5) : (tensor<1x32x32x16xf32>) -> tensor<1x32x32x16xf32> @@ -29,11 +29,12 @@ func.func @test_depthwise_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2 // ----- -// CHECK-LABEL: test_transpose_conv2d -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<[2, 0, 1, 3]> : tensor<4xi32>} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<16xf32>} -// CHECK-DAG: 
%[[VAR2:.*]] = "tosa.transpose"(%arg1, %[[VAR0]]) -// CHECK: %[[VAR3:.*]] = "tosa.transpose_conv2d"(%arg0, %[[VAR2]], %[[VAR1]]) {out_pad = [0, 0, 0, 0], out_shape = [1, 32, 32, 16], stride = [1, 1]} +// CHECK-LABEL: @test_transpose_conv2d +// CHECK-SAME: %[[ARG0:.*]]: tensor<1x32x32x8xf32>, %[[ARG1:.*]]: tensor<1x1x16x8xf32> +// CHECK: %[[CONST:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<16xf32>} +// CHECK: %[[RESHAPE:.*]] = "tosa.reshape"(%[[ARG1]]) {new_shape = array} +// CHECK: %[[TRANSPOSE:.*]] = "tosa.transpose_conv2d"(%[[ARG0]], %[[RESHAPE]], %[[CONST]]) {out_pad = array, out_shape = array, stride = array} +// CHECK: return %[[TRANSPOSE]] func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<1x1x16x8xf32>) -> tensor<1x32x32x16xf32> { %3 = "tf.Const"() {value = dense<[1, 32, 32, 16]> : tensor<4xi32>} : () -> tensor<4xi32> %4 = "tf.Conv2DBackpropInput"(%3, %arg1, %arg0) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<4xi32>, tensor<1x1x16x8xf32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x16xf32> @@ -48,7 +49,7 @@ func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<1x1 // CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<4xf32>} // CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() {value = dense<[4, 0, 1, 2, 3]> : tensor<5xi32>} // CHECK: %[[VAL_4:.*]] = "tosa.transpose"(%[[VAL_1]], %[[VAL_3]]) -// CHECK: %[[VAL_5:.*]] = "tosa.conv3d"(%[[VAL_0]], %[[VAL_4]], %[[VAL_2]]) {dilation = [1, 1, 1], pad = [0, 1, 0, 1, 0, 1], stride = [1, 2, 2]} +// CHECK: %[[VAL_5:.*]] = "tosa.conv3d"(%[[VAL_0]], %[[VAL_4]], %[[VAL_2]]) {dilation = array, pad = array, stride = array} func.func @test_conv3d(%arg0: tensor<2x4x128x128x8xf32>, %arg1: tensor<2x3x3x2x4xf32>) -> tensor<2x4x64x64x4xf32> { %0 = "tf.Conv3D"(%arg0, %arg1) {data_format = "NDHWC", device = "", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 2, 2, 1]} : (tensor<2x4x128x128x8xf32>, tensor<2x3x3x2x4xf32>) -> tensor<2x4x64x64x4xf32> return %0 : tensor<2x4x64x64x4xf32> @@ -62,7 +63,7 @@ func.func @test_conv3d(%arg0: tensor<2x4x128x128x8xf32>, %arg1: tensor<2x3x3x2x4 // CHECK-SAME: %[[VAL_2:.*]]: tensor<10xf32>) -> tensor<3x32x16x16x10xf32> // CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() {value = dense<[4, 0, 1, 2, 3]> : tensor<5xi32>} // CHECK: %[[VAL_4:.*]] = "tosa.transpose"(%[[VAL_1]], %[[VAL_3]]) -// CHECK: %[[VAL_5:.*]] = "tosa.conv3d"(%[[VAL_0]], %[[VAL_4]], %[[VAL_2]]) {dilation = [1, 1, 1], pad = [0, 1, 1, 1, 1, 1], stride = [1, 1, 1]} +// CHECK: %[[VAL_5:.*]] = "tosa.conv3d"(%[[VAL_0]], %[[VAL_4]], %[[VAL_2]]) {dilation = array, pad = array, stride = array} func.func @test_conv3d_bias(%arg0: tensor<3x32x16x16x5xf32>, %arg1: tensor<2x3x3x5x10xf32>, %bias: tensor<10xf32>) -> tensor<3x32x16x16x10xf32> { %0 = "tf.Conv3D"(%arg0, %arg1) {data_format = "NDHWC", device = "", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<3x32x16x16x5xf32>, tensor<2x3x3x5x10xf32>) -> tensor<3x32x16x16x10xf32> %1 = "tf.BiasAdd"(%0, %bias) {data_format = "NHWC", device = ""} : (tensor<3x32x16x16x10xf32>, tensor<10xf32>) -> tensor<3x32x16x16x10xf32> @@ -241,7 +242,7 @@ func.func @test_logical_not(%arg0: tensor<1x21x3xi1>) -> tensor<1x21x3xi1> { // CHECK-LABEL: test_reduce_any // CHECK-DAG: %[[VAR0:.*]] = "tosa.reduce_any"(%arg0) {axis = 0 : i64} -// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = [21, 3]} +// CHECK: 
%[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = array} func.func @test_reduce_any(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> { %2 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> %3 = "tf.Any"(%arg0, %2) {keep_dims = false} : (tensor<13x21x3xi1>, tensor<1xi32>) -> tensor<21x3xi1> @@ -252,7 +253,7 @@ func.func @test_reduce_any(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> { // CHECK-LABEL: test_reduce_all // CHECK-DAG: %[[VAR0:.*]] = "tosa.reduce_all"(%arg0) {axis = 0 : i64} -// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = [21, 3]} +// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = array} func.func @test_reduce_all(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> { %2 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> %3 = "tf.All"(%arg0, %2) {keep_dims = false} : (tensor<13x21x3xi1>, tensor<1xi32>) -> tensor<21x3xi1> @@ -263,7 +264,7 @@ func.func @test_reduce_all(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> { // CHECK-LABEL: test_reduce_min // CHECK-DAG: %[[VAR0:.*]] = "tosa.reduce_min"(%arg0) {axis = 0 : i64} -// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = [21, 3]} +// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = array} func.func @test_reduce_min(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { %2 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> %3 = "tf.Min"(%arg0, %2) {keep_dims = false} : (tensor<13x21x3xf32>, tensor<1xi32>) -> tensor<21x3xf32> @@ -274,7 +275,7 @@ func.func @test_reduce_min(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { // CHECK-LABEL: test_reduce_max // CHECK-DAG: %[[VAR0:.*]] = "tosa.reduce_max"(%arg0) {axis = 0 : i64} -// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = [21, 3]} +// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = array} func.func @test_reduce_max(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { %2 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> %3 = "tf.Max"(%arg0, %2) {keep_dims = false} : (tensor<13x21x3xf32>, tensor<1xi32>) -> tensor<21x3xf32> @@ -285,7 +286,7 @@ func.func @test_reduce_max(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { // CHECK-LABEL: test_reduce_sum // CHECK-DAG: %[[VAR0:.*]] = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} -// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = [21, 3]} +// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = array} func.func @test_reduce_sum(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { %2 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> %3 = "tf.Sum"(%arg0, %2) {keep_dims = false} : (tensor<13x21x3xf32>, tensor<1xi32>) -> tensor<21x3xf32> @@ -294,10 +295,27 @@ func.func @test_reduce_sum(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { // ----- +// CHECK-LABEL: test_reduce_sum_nonzero_axis +// CHECK-SAME: %[[VAL_0:.*]]: tensor<10x20x30x40x50xf32> +// CHECK: %[[VAL_1:.*]] = "tosa.const"() {value = dense<[0, 1, 2, 4, 3]> : tensor<5xi32>} : () -> tensor<5xi32> +// CHECK: %[[VAL_2:.*]] = "tosa.transpose"(%[[VAL_0]], %[[VAL_1]]) : (tensor<10x20x30x40x50xf32>, tensor<5xi32>) -> tensor<10x20x30x50x40xf32> +// CHECK: %[[VAL_3:.*]] = "tosa.reshape"(%[[VAL_2]]) {new_shape = array} : (tensor<10x20x30x50x40xf32>) -> tensor<300000x40xf32> +// CHECK: %[[VAL_4:.*]] = "tosa.reduce_sum"(%[[VAL_3]]) {axis = 1 : i64} : (tensor<300000x40xf32>) -> tensor<300000x1xf32> +// CHECK: %[[VAL_5:.*]] = "tosa.reshape"(%[[VAL_4]]) {new_shape = array} : (tensor<300000x1xf32>) -> 
tensor<10x20x30x50xf32> +// CHECK: return %[[VAL_5]] : tensor<10x20x30x50xf32> +func.func @test_reduce_sum_nonzero_axis(%arg0: tensor<10x20x30x40x50xf32> {tf._user_specified_name = "inp_list"}) -> tensor<10x20x30x50xf32> { + %cst = "tf.Const"() {device = "", value = dense<3> : tensor} : () -> tensor + %0 = "tf.Sum"(%arg0, %cst) {device = "", keep_dims = false} : (tensor<10x20x30x40x50xf32>, tensor) -> tensor<10x20x30x50xf32> + %1 = "tf.Identity"(%0) {device = ""} : (tensor<10x20x30x50xf32>) -> tensor<10x20x30x50xf32> + func.return %1 : tensor<10x20x30x50xf32> +} + +// ----- + // CHECK-LABEL: test_reduce_mean // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<0.0769230798> : tensor<1x1xf32>} // CHECK-DAG: %[[VAR1:.*]] = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} -// CHECK-DAG: %[[VAR2:.*]] = "tosa.reshape"(%[[VAR1]]) {new_shape = [21, 3]} +// CHECK-DAG: %[[VAR2:.*]] = "tosa.reshape"(%[[VAR1]]) {new_shape = array} // CHECK: %[[VAR4:.*]] = "tosa.mul"(%[[VAR2]], %[[VAR0]]) {shift = 0 : i32} func.func @test_reduce_mean(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { %2 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> @@ -309,7 +327,7 @@ func.func @test_reduce_mean(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { // CHECK-LABEL: test_reduce_product // CHECK-DAG: %[[VAR0:.*]] = "tosa.reduce_prod"(%arg0) {axis = 0 : i64} -// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = [21, 3]} +// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = array} func.func @test_reduce_product(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { %2 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> %3 = "tf.Prod"(%arg0, %2) {keep_dims = false} : (tensor<13x21x3xf32>, tensor<1xi32>) -> tensor<21x3xf32> @@ -399,6 +417,62 @@ func.func @test_rsqrt(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // ----- +// CHECK-LABEL: test_sin +// CHECK-SAME: -> tensor<10xf32> +func.func @test_sin(%arg0: tensor<10xf32>) -> tensor<*xf32> { + // CHECK-DAG: %[[ONE:.+]] = "tosa.const"() {value = dense<1.000000e+00> : tensor<1xf32>} + // CHECK-DAG: %[[TWO:.+]] = "tosa.const"() {value = dense<2.000000e+00> : tensor<1xf32>} + // CHECK-DAG: %[[RESULT_SCALE:.+]] = "tosa.const"() {value = dense<2.38418579E-7> : tensor<1xf32>} + // CHECK-DAG: %[[INT_MAX:.+]] = "tosa.const"() {value = dense<3.276700e+04> : tensor<1xf32>} + // CHECK-DAG: %[[IN_SCALE:.+]] = "tosa.const"() {value = dense<0.159154937> : tensor<1xf32>} + // CHECK-DAG: %[[TBLVAL:.+]] = "tosa.const"() {value = dense<{{.+}}> : tensor<513xi16>} + // CHECK-DAG: %[[IN_SCALED:.+]] = "tosa.mul"(%arg0, %[[IN_SCALE]]) + // CHECK-DAG: %[[FLOOR:.+]] = "tosa.floor"(%[[IN_SCALED]]) + // CHECK-DAG: %[[SUB1:.+]] = "tosa.sub"(%[[IN_SCALED]], %[[FLOOR]]) + // CHECK-DAG: %[[MUL1:.+]] = "tosa.mul"(%[[SUB1]], %[[TWO]]) + // CHECK-DAG: %[[SUB2:.+]] = "tosa.sub"(%[[MUL1]], %[[ONE]]) + // CHECK-DAG: %[[MUL2:.+]] = "tosa.mul"(%[[SUB2]], %[[INT_MAX]]) + // CHECK-DAG: %[[TO_INT:.+]] = "tosa.cast"(%[[MUL2]]) + // CHECK-DAG: %[[TABLE:.+]] = "tosa.table"(%[[TO_INT]], %[[TBLVAL]]) + // CHECK-DAG: %[[TABLE_CAST:.+]] = "tosa.cast"(%[[TABLE]]) + // CHECK-DAG: %[[RESULT:.+]] = "tosa.mul"(%[[TABLE_CAST:.+]], %[[RESULT_SCALE]]) + %0 = "tf.Sin"(%arg0) : (tensor<10xf32>) -> tensor<*xf32> + + // CHECK: return %[[RESULT]] + func.return %0 : tensor<*xf32> +} + +// ----- + +// CHECK-LABEL: test_cos +// CHECK-SAME: -> tensor<10xf32> +func.func @test_cos(%arg0: tensor<10xf32>) -> tensor<*xf32> { + // CHECK-DAG: %[[RESULT_SCALE:.+]] = "tosa.const"() 
{value = dense<2.38418579E-7> : tensor<1xf32>} + // CHECK-DAG: %[[INT_MAX:.+]] = "tosa.const"() {value = dense<3.276700e+04> : tensor<1xf32>} + // CHECK-DAG: %[[ONE:.+]] = "tosa.const"() {value = dense<1.000000e+00> : tensor<1xf32>} + // CHECK-DAG: %[[TWO:.+]] = "tosa.const"() {value = dense<2.000000e+00> : tensor<1xf32>} + // CHECK-DAG: %[[IN_SCALE:.+]] = "tosa.const"() {value = dense<0.159154937> : tensor<1xf32>} + // CHECK-DAG: %[[HALF_PI:.+]] = "tosa.const"() {value = dense<1.57079637> : tensor<1xf32>} + // CHECK-DAG: %[[TBLVAL:.+]] = "tosa.const"() {value = dense<{{.+}}> : tensor<513xi16>} + // CHECK-DAG: %[[IN_TRANSLATE:.+]] = "tosa.add"(%arg0, %[[HALF_PI]]) + // CHECK-DAG: %[[IN_SCALED:.+]] = "tosa.mul"(%[[IN_TRANSLATE]], %[[IN_SCALE]]) + // CHECK-DAG: %[[FLOOR:.+]] = "tosa.floor"(%[[IN_SCALED]]) + // CHECK-DAG: %[[SUB1:.+]] = "tosa.sub"(%[[IN_SCALED]], %[[FLOOR]]) + // CHECK-DAG: %[[MUL1:.+]] = "tosa.mul"(%[[SUB1]], %[[TWO]]) + // CHECK-DAG: %[[SUB2:.+]] = "tosa.sub"(%[[MUL1]], %[[ONE]]) + // CHECK-DAG: %[[MUL2:.+]] = "tosa.mul"(%[[SUB2]], %[[INT_MAX]]) + // CHECK-DAG: %[[TO_INT:.+]] = "tosa.cast"(%[[MUL2]]) + // CHECK-DAG: %[[TABLE:.+]] = "tosa.table"(%[[TO_INT]], %[[TBLVAL]]) + // CHECK-DAG: %[[TABLE_CAST:.+]] = "tosa.cast"(%[[TABLE]]) + // CHECK-DAG: %[[RESULT:.+]] = "tosa.mul"(%[[TABLE_CAST:.+]], %[[RESULT_SCALE]]) + %0 = "tf.Cos"(%arg0) : (tensor<10xf32>) -> tensor<*xf32> + + // CHECK: return %[[RESULT]] + func.return %0 : tensor<*xf32> +} + +// ----- + // CHECK-LABEL: test_sigmoid // CHECK: %[[VAR0:.*]] = "tosa.sigmoid"(%arg0) func.func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { @@ -475,7 +549,7 @@ func.func @test_argmax(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xi32> { // ----- // CHECK-LABEL: test_avg_pool2d -// CHECK: %[[VAR0:.*]] = "tosa.avg_pool2d"(%arg0) {kernel = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} +// CHECK: %[[VAR0:.*]] = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array} func.func @test_avg_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { %2 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> func.return %2 : tensor<1x32x32x8xf32> @@ -484,7 +558,7 @@ func.func @test_avg_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32 // ----- // CHECK-LABEL: test_max_pool2d -// CHECK: %[[VAR0:.*]] = "tosa.max_pool2d"(%arg0) {kernel = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} +// CHECK: %[[VAR0:.*]] = "tosa.max_pool2d"(%arg0) {kernel = array, pad = array, stride = array} func.func @test_max_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { %2 = "tf.MaxPool"(%arg0) {data_format = "NHWC", explicit_paddings = [], ksize = [1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> func.return %2 : tensor<1x32x32x8xf32> @@ -493,7 +567,7 @@ func.func @test_max_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32 // ----- // CHECK-LABEL: test_reshape -// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 819]} +// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = array} func.func @test_reshape(%arg0: tensor<13x21x3xf32>) -> tensor<1x819xf32> { %0 = "tf.Const"() {value = dense<[1, 819]> : tensor<2xi32>} : () -> tensor<2xi32> %3 = "tf.Reshape"(%arg0, %0) : (tensor<13x21x3xf32>, tensor<2xi32>) -> tensor<1x819xf32> @@ -515,7 +589,7 @@ func.func @test_transpose(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21xf32> { // ----- 
// CHECK-LABEL: test_slice -// CHECK: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = [4, 11, 1], start = [6, 8, 0]} +// CHECK: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = array, start = array} func.func @test_slice(%arg0: tensor<13x21x3xf32>) -> tensor<4x11x1xf32> { %2 = "tf.Const"() {value = dense<[6, 8, 0]> : tensor<3xi64>} : () -> tensor<3xi64> %3 = "tf.Const"() {value = dense<[4, 11, 1]> : tensor<3xi64>} : () -> tensor<3xi64> @@ -526,10 +600,10 @@ func.func @test_slice(%arg0: tensor<13x21x3xf32>) -> tensor<4x11x1xf32> { // ----- // CHECK-LABEL: test_strided_slice -// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = [9, 21, 2], start = [4, 0, 1]} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = [9, 1, 7, 3, 2, 1]} -// CHECK-DAG: %[[VAR2:.*]] = "tosa.slice"(%[[VAR1]]) {size = [9, 1, 7, 1, 2, 1], start = [0, 0, 0, 0, 0, 0]} -// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [9, 7, 2]} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = array, start = array} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = array} +// CHECK-DAG: %[[VAR2:.*]] = "tosa.slice"(%[[VAR1]]) {size = array, start = array} +// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = array} func.func @test_strided_slice(%arg0: tensor<13x21x3xf32>) -> tensor<9x7x2xf32> { %2 = "tf.Const"() {value = dense<[4, 0, 1]> : tensor<3xi64>} : () -> tensor<3xi64> %3 = "tf.Const"() {value = dense<[13, 21, 3]> : tensor<3xi64>} : () -> tensor<3xi64> @@ -541,7 +615,7 @@ func.func @test_strided_slice(%arg0: tensor<13x21x3xf32>) -> tensor<9x7x2xf32> { // ----- // CHECK-LABEL: test_select -// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg2) {new_shape = [1, 1, 1]} : (tensor<1xi1>) -> tensor<1x1x1xi1> +// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg2) {new_shape = array} : (tensor<1xi1>) -> tensor<1x1x1xi1> // CHECK: %[[VAR2:.*]] = "tosa.select"(%[[VAR1]], %arg0, %arg1) func.func @test_select(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<1xi1>) -> tensor<13x21x3xf32> { %2 = "tf.SelectV2"(%arg2, %arg0, %arg1) : (tensor<1xi1>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> @@ -573,7 +647,7 @@ func.func @test_concatv2(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, // CHECK-LABEL: test_stack // CHECK-DAG: %[[VAR0:.*]] = "tosa.concat"(%arg0, %arg1, %arg2, %arg3) {axis = 0 : i64} -// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = [4, 13, 21, 3]} +// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = array} func.func @test_stack(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<13x21x3xf32>, %arg3: tensor<13x21x3xf32>) -> tensor<4x13x21x3xf32> { %2 = "tf.Pack"(%arg0, %arg1, %arg2, %arg3) {axis = 0 : i64} : (tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<4x13x21x3xf32> func.return %2 : tensor<4x13x21x3xf32> @@ -582,7 +656,7 @@ func.func @test_stack(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, %a // ----- // CHECK-LABEL: test_unstack -// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg0) {new_shape = [32, 32, 8]} +// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg0) {new_shape = array} func.func @test_unstack(%arg0: tensor<1x32x32x8xf32>) -> tensor<32x32x8xf32> { %2 = "tf.Unpack"(%arg0) {axis = 0 : i64} : (tensor<1x32x32x8xf32>) -> tensor<32x32x8xf32> %3 = "tf.Identity"(%2) : (tensor<32x32x8xf32>) -> tensor<32x32x8xf32> @@ -604,7 +678,7 @@ func.func @test_pad(%arg0: tensor<13x21x3xf32>) -> tensor<15x23x5xf32> { // ----- // CHECK-LABEL: test_expand_dims 
-// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 13, 21, 3]} +// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = array} func.func @test_expand_dims(%arg0: tensor<13x21x3xf32>) -> tensor<1x13x21x3xf32> { %2 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor %3 = "tf.ExpandDims"(%arg0, %2) : (tensor<13x21x3xf32>, tensor) -> tensor<1x13x21x3xf32> @@ -614,7 +688,7 @@ func.func @test_expand_dims(%arg0: tensor<13x21x3xf32>) -> tensor<1x13x21x3xf32> // ----- // CHECK-LABEL: test_expand_dims_negative_index -// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [13, 1, 21, 3]} +// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = array} func.func @test_expand_dims_negative_index(%arg0: tensor<13x21x3xf32>) -> tensor<13x1x21x3xf32> { %2 = "tf.Const"() {value = dense<-2> : tensor<1xi32>} : () -> tensor<1xi32> %3 = "tf.ExpandDims"(%arg0, %2) : (tensor<13x21x3xf32>, tensor<1xi32>) -> tensor<13x1x21x3xf32> @@ -690,10 +764,10 @@ func.func @test_batch_matmul_3d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x3x4 // ----- // CHECK-LABEL: test_batch_matmul_4d -// CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [65, 21, 3]} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [65, 3, 42]} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = array} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = array} // CHECK-DAG: %[[VAR2:.*]] = "tosa.matmul"(%[[VAR0]], %[[VAR1]]) -// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [5, 13, 21, 42]} +// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = array} func.func @test_batch_matmul_4d(%arg0: tensor<5x13x21x3xf32>, %arg1: tensor<5x13x3x42xf32>) -> tensor<5x13x21x42xf32> { %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false, device = ""} : (tensor<5x13x21x3xf32>, tensor<5x13x3x42xf32>) -> tensor<5x13x21x42xf32> func.return %0 : tensor<5x13x21x42xf32> @@ -702,10 +776,10 @@ func.func @test_batch_matmul_4d(%arg0: tensor<5x13x21x3xf32>, %arg1: tensor<5x13 // ----- // CHECK-LABEL: test_matmul -// CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 14, 19]} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 19, 28]} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = array} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = array} // CHECK-DAG: %[[VAR2:.*]] = "tosa.matmul"(%[[VAR0]], %[[VAR1]]) -// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [14, 28]} +// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = array} func.func @test_matmul(%arg0: tensor<14x19xf32>, %arg1: tensor<19x28xf32>) -> tensor<14x28xf32> { %2 = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<14x19xf32>, tensor<19x28xf32>) -> tensor<14x28xf32> func.return %2 : tensor<14x28xf32> @@ -738,9 +812,9 @@ func.func @test_add_1d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) - // ----- // CHECK-LABEL: test_split -// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = [13, 7, 3], start = [0, 0, 0]} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.slice"(%arg0) {size = [13, 7, 3], start = [0, 7, 0]} -// CHECK: %[[VAR2:.*]] = "tosa.slice"(%arg0) {size = [13, 7, 3], start = [0, 14, 0]} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = array, start = array} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.slice"(%arg0) {size = array, start = array} +// CHECK: %[[VAR2:.*]] = "tosa.slice"(%arg0) {size = array, start = array} func.func @test_split(%arg0: tensor<13x21x3xf32>) -> 
(tensor<13x7x3xf32>, tensor<13x7x3xf32>, tensor<13x7x3xf32>) { %6 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor %7:3 = "tf.Split"(%6, %arg0) : (tensor, tensor<13x21x3xf32>) -> (tensor<13x7x3xf32>, tensor<13x7x3xf32>, tensor<13x7x3xf32>) @@ -775,9 +849,9 @@ func.func @test_reverse(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() {value = dense<[2, 0, 1, 3]> : tensor<4xi32>} // CHECK-DAG: %[[PVAL:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor} // CHECK-DAG: %[[VAR2:.*]] = "tosa.pad"(%arg0, %[[VAR0]], %[[PVAL]]) -// CHECK-DAG: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [13, 11, 2, 3]} +// CHECK-DAG: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = array} // CHECK-DAG: %[[VAR4:.*]] = "tosa.transpose"(%[[VAR3]], %[[VAR1]]) -// CHECK: %[[VAR5:.*]] = "tosa.reshape"(%[[VAR4]]) {new_shape = [26, 11, 3]} +// CHECK: %[[VAR5:.*]] = "tosa.reshape"(%[[VAR4]]) {new_shape = array} func.func @test_space_to_batch(%arg0: tensor<13x21x3xf32>) -> tensor<26x11x3xf32> { %2 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> %3 = "tf.Const"() {value = dense<[[0, 1]]> : tensor<1x2xi32>} : () -> tensor<1x2xi32> @@ -791,9 +865,9 @@ func.func @test_space_to_batch(%arg0: tensor<13x21x3xf32>) -> tensor<26x11x3xf32 // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<[3, 1, 2, 0]> : tensor<4xi32>} // CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() {value = dense<[2, 3, 0, 4, 1, 5]> : tensor<6xi32>} // CHECK-DAG: %[[VAR2:.*]] = "tosa.transpose"(%arg0, %[[VAR0]]) -// CHECK-DAG: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [2, 2, 2, 32, 32, 1]} +// CHECK-DAG: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = array} // CHECK-DAG: %[[VAR4:.*]] = "tosa.transpose"(%[[VAR3]], %[[VAR1]]) -// CHECK-DAG: %[[VAR5:.*]] = "tosa.reshape"(%[[VAR4]]) {new_shape = [2, 64, 64, 1]} +// CHECK-DAG: %[[VAR5:.*]] = "tosa.reshape"(%[[VAR4]]) {new_shape = array} // CHECK: return %[[VAR5]] func.func @test_batch_to_space(%arg0: tensor<1x32x32x8xf32>) -> tensor<2x64x64x1xf32> { %2 = "tf.Const"() {value = dense<2> : tensor<2xi32>} : () -> tensor<2xi32> @@ -808,9 +882,9 @@ func.func @test_batch_to_space(%arg0: tensor<1x32x32x8xf32>) -> tensor<2x64x64x1 // CHECK-LABEL: test_space_to_depth // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 16, 2, 16, 2, 8]} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg0) {new_shape = array} // CHECK-DAG: %[[VAR2:.*]] = "tosa.transpose"(%[[VAR1]], %[[VAR0]]) -// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [1, 16, 16, 32]} +// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = array} func.func @test_space_to_depth(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x16x16x32xf32> { %2 = "tf.SpaceToDepth"(%arg0) {block_size = 2 : i64, data_format = "NHWC"} : (tensor<1x32x32x8xf32>) -> tensor<1x16x16x32xf32> func.return %2 : tensor<1x16x16x32xf32> @@ -820,9 +894,9 @@ func.func @test_space_to_depth(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x16x16x3 // CHECK-LABEL: test_depth_to_space // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 32, 32, 2, 2, 2]} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg0) {new_shape = array} // CHECK-DAG: %[[VAR2:.*]] = "tosa.transpose"(%[[VAR1]], %[[VAR0]]) -// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [1, 64, 
64, 2]} +// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = array} func.func @test_depth_to_space(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x2xf32> { %2 = "tf.DepthToSpace"(%arg0) {block_size = 2 : i64, data_format = "NHWC"} : (tensor<1x32x32x8xf32>) -> tensor<1x64x64x2xf32> func.return %2 : tensor<1x64x64x2xf32> @@ -848,17 +922,16 @@ func.func @test_right_shift(%arg0: tensor<4x4xi32>, %arg1: tensor<1x1xi32>) -> t // ----- -// CHECK-LABEL: test_one_hot -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<[0, 2, 1]> : tensor<3xi32>} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1]} -// CHECK-DAG: %[[VAR2:.*]] = "tosa.tile"(%[[VAR1]]) {multiples = [16, 1, 1]} -// CHECK-DAG: %[[VAR3:.*]] = "tosa.reshape"(%arg2) {new_shape = [1, 1, 1]} -// CHECK-DAG: %[[VAR4:.*]] = "tosa.tile"(%[[VAR3]]) {multiples = [16, 2, 1]} -// CHECK-DAG: %[[VAR5:.*]] = "tosa.reshape"(%arg0) {new_shape = [16, 1]} -// CHECK-DAG: %[[VAR6:.*]] = "tosa.scatter"(%[[VAR4]], %[[VAR5]], %[[VAR2]]) -// CHECK-DAG: %[[VAR7:.*]] = "tosa.reshape"(%[[VAR6]]) {new_shape = [16, 1, 2]} -// CHECK-DAG: %[[VAR8:.*]] = "tosa.transpose"(%[[VAR7]], %[[VAR0]]) -// CHECK: %[[VAR9:.*]] = "tosa.reshape"(%[[VAR8]]) {new_shape = [4, 4, 2]} +// CHECK-LABEL: @test_one_hot +// CHECK-SAME: %[[ARG0_0:.*]]: tensor<4x4xi32>, %[[ARG1_0:.*]]: tensor, %[[ARG2:.*]]: tensor +// CHECK: %[[RESHAPE_0:.*]] = "tosa.reshape"(%[[ARG1_0]]) {new_shape = array} +// CHECK: %[[TILE:.*]] = "tosa.tile"(%[[RESHAPE_0]]) {multiples = array} +// CHECK: %[[RESHAPE_1:.*]] = "tosa.reshape"(%[[ARG2]]) {new_shape = array} +// CHECK: %[[TILE_0:.*]] = "tosa.tile"(%[[RESHAPE_1]]) {multiples = array} +// CHECK: %[[RESHAPE_2:.*]] = "tosa.reshape"(%[[ARG0_0]]) {new_shape = array} +// CHECK: %[[SCATTER:.*]] = "tosa.scatter"(%[[TILE_0]], %[[RESHAPE_2]], %[[TILE]]) +// CHECK: %[[RESHAPE_3:.*]] = "tosa.reshape"(%[[SCATTER]]) {new_shape = array} +// CHECK: return %[[RESHAPE_3]] func.func @test_one_hot(%arg0: tensor<4x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<4x4x2xf32> { %0 = "tf.Const"() {value = dense<2> : tensor} : () -> tensor %1 = "tf.OneHot"(%arg0, %0, %arg1, %arg2) {axis = -1 : i64} : (tensor<4x4xi32>, tensor, tensor, tensor) -> tensor<4x4x2xf32> @@ -889,9 +962,9 @@ func.func @test_fakequant_with_min_max_args(%arg0: tensor<13x21x3xf32>) -> tenso // ----- // CHECK-LABEL: test_gather // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<{{.*}}> : tensor<1x49xi32>} -// CHECK-DAG: %[[VAR4:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 13, 63]} +// CHECK-DAG: %[[VAR4:.*]] = "tosa.reshape"(%arg0) {new_shape = array} // CHECK-DAG: %[[VAR6:.*]] = "tosa.gather"(%[[VAR4]], %[[VAR0]]) -// CHECK-DAG: %[[VAR7:.*]] = "tosa.reshape"(%[[VAR6]]) {new_shape = [7, 7, 21, 3]} +// CHECK-DAG: %[[VAR7:.*]] = "tosa.reshape"(%[[VAR6]]) {new_shape = array} // CHECK: return %[[VAR7]] func.func @test_gather(%arg0: tensor<13x21x3xf32>) -> tensor<7x7x21x3xf32> { %0 = "tf.Const"() {device = "", value = dense<0> : tensor} : () -> tensor @@ -904,9 +977,9 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>) -> tensor<7x7x21x3xf32> { // ----- // CHECK-LABEL: test_gather_nd // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<{{.*}}> : tensor<1x42xi32>} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 13, 63]} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg0) {new_shape = array} // CHECK-DAG: %[[VAR2:.*]] = "tosa.gather"(%[[VAR1]], %[[VAR0]]) -// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [6, 7, 21, 3]} +// CHECK: 
%[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = array} func.func @test_gather_nd(%arg0: tensor<13x21x3xf32>) -> tensor<6x7x21x3xf32> { %0 = "tf.Const"() {device = "", value = dense<[[[0], [5], [3], [12], [2], [4], [3]], [[11], [1], [11], [10], [3], [12], [8]], [[5], [3], [1], [11], [3], [10], [0]], [[0], [8], [4], [7], [3], [12], [2]], [[7], [6], [11], [4], [2], [10], [11]], [[11], [1], [11], [1], [1], [11], [8]]]> : tensor<6x7x1xi32>} : () -> tensor<6x7x1xi32> %1 = "tf.GatherNd"(%arg0, %0) {device = ""} : (tensor<13x21x3xf32>, tensor<6x7x1xi32>) -> tensor<6x7x21x3xf32> @@ -920,15 +993,15 @@ func.func @test_gather_nd(%arg0: tensor<13x21x3xf32>) -> tensor<6x7x21x3xf32> { // CHECK-LABEL: test_fused_batch_norm func.func @test_fused_batch_norm(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { // CHECK: %[[ONE:.+]] = "tosa.const"() {value = dense<1.000000e-03> : tensor<1xf32>} - // CHECK: %[[RES0:.+]] = "tosa.reshape"(%arg3) {new_shape = [1, 1, 1, 8]} + // CHECK: %[[RES0:.+]] = "tosa.reshape"(%arg3) {new_shape = array} // CHECK: %[[SUB0:.+]] = "tosa.sub"(%arg0, %[[RES0]]) // CHECK: %[[ADD0:.+]] = "tosa.add"(%arg4, %[[ONE]]) // CHECK: %[[RSQR:.+]] = "tosa.rsqrt"(%[[ADD0]]) - // CHECK: %[[RES1:.+]] = "tosa.reshape"(%[[RSQR]]) {new_shape = [1, 1, 1, 8]} + // CHECK: %[[RES1:.+]] = "tosa.reshape"(%[[RSQR]]) {new_shape = array} // CHECK: %[[MUL0:.+]] = "tosa.mul"(%[[SUB0]], %[[RES1]]) {shift = 0 : i32} - // CHECK: %[[RES1:.+]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 8]} + // CHECK: %[[RES1:.+]] = "tosa.reshape"(%arg1) {new_shape = array} // CHECK: %[[MUL1:.+]] = "tosa.mul"(%[[MUL0]], %[[RES1]]) {shift = 0 : i32} - // CHECK: %[[RES2:.+]] = "tosa.reshape"(%arg2) {new_shape = [1, 1, 1, 8]} + // CHECK: %[[RES2:.+]] = "tosa.reshape"(%arg2) {new_shape = array} // CHECK: %[[ADD1:.+]] = "tosa.add"(%[[MUL1]], %[[RES2]]) %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>) @@ -949,12 +1022,12 @@ func.func @test_fused_batch_norm_training(%arg0: tensor<8x8x8x8xf32>, %arg1: ten // CHECK-LABEL: mirrorpad_symmetric // CHECK-SAME: %[[VAL_0:.*]]: tensor<5x10xf32> -// CHECK: %[[VAL_1:.*]] = "tosa.slice"(%[[VAL_0]]) {size = [1, 10], start = [0, 0]} : (tensor<5x10xf32>) -// CHECK: %[[VAL_2:.*]] = "tosa.slice"(%[[VAL_0]]) {size = [2, 10], start = [3, 0]} : (tensor<5x10xf32>) +// CHECK: %[[VAL_1:.*]] = "tosa.slice"(%[[VAL_0]]) {size = array, start = array} : (tensor<5x10xf32>) +// CHECK: %[[VAL_2:.*]] = "tosa.slice"(%[[VAL_0]]) {size = array, start = array} : (tensor<5x10xf32>) // CHECK: %[[VAL_3:.*]] = "tosa.reverse"(%[[VAL_2]]) {axis = 0 : i64} : (tensor<2x10xf32>) // CHECK: %[[VAL_4:.*]] = "tosa.concat"(%[[VAL_1]], %[[VAL_0]], %[[VAL_3]]) {axis = 0 : i64} : (tensor<1x10xf32>, tensor<5x10xf32>, tensor<2x10xf32>) -// CHECK: %[[VAL_5:.*]] = "tosa.slice"(%[[VAL_4]]) {size = [8, 1], start = [0, 0]} : (tensor<8x10xf32>) -// CHECK: %[[VAL_6:.*]] = "tosa.slice"(%[[VAL_4]]) {size = [8, 2], start = [0, 8]} : (tensor<8x10xf32>) +// CHECK: %[[VAL_5:.*]] = "tosa.slice"(%[[VAL_4]]) {size = array, start = array} : (tensor<8x10xf32>) +// CHECK: %[[VAL_6:.*]] = "tosa.slice"(%[[VAL_4]]) {size = array, start = array} : (tensor<8x10xf32>) // CHECK: 
%[[VAL_7:.*]] = "tosa.reverse"(%[[VAL_6]]) {axis = 1 : i64} : (tensor<8x2xf32>) // CHECK: %[[VAL_8:.*]] = "tosa.concat"(%[[VAL_5]], %[[VAL_4]], %[[VAL_7]]) {axis = 1 : i64} : (tensor<8x1xf32>, tensor<8x10xf32>, tensor<8x2xf32>) func.func @mirrorpad_symmetric(%arg0: tensor<5x10xf32>) -> tensor<8x13xf32> { @@ -968,11 +1041,11 @@ func.func @mirrorpad_symmetric(%arg0: tensor<5x10xf32>) -> tensor<8x13xf32> { // CHECK-LABEL: mirrorpad_reflect // CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x3xf32> -// CHECK: %[[VAL_1:.*]] = "tosa.slice"(%[[VAL_0]]) {size = [1, 21, 3], start = [1, 0, 0]} : (tensor<13x21x3xf32>) +// CHECK: %[[VAL_1:.*]] = "tosa.slice"(%[[VAL_0]]) {size = array, start = array} : (tensor<13x21x3xf32>) // CHECK: %[[VAL_2:.*]] = "tosa.concat"(%[[VAL_1]], %[[VAL_0]]) {axis = 0 : i64} : (tensor<1x21x3xf32>, tensor<13x21x3xf32>) -// CHECK: %[[VAL_3:.*]] = "tosa.slice"(%[[VAL_2]]) {size = [14, 1, 3], start = [0, 1, 0]} : (tensor<14x21x3xf32>) +// CHECK: %[[VAL_3:.*]] = "tosa.slice"(%[[VAL_2]]) {size = array, start = array} : (tensor<14x21x3xf32>) // CHECK: %[[VAL_4:.*]] = "tosa.concat"(%[[VAL_3]], %[[VAL_2]]) {axis = 1 : i64} : (tensor<14x1x3xf32>, tensor<14x21x3xf32>) -// CHECK: %[[VAL_5:.*]] = "tosa.slice"(%[[VAL_4]]) {size = [14, 22, 1], start = [0, 0, 1]} : (tensor<14x22x3xf32>) +// CHECK: %[[VAL_5:.*]] = "tosa.slice"(%[[VAL_4]]) {size = array, start = array} : (tensor<14x22x3xf32>) // CHECK: %[[VAL_6:.*]] = "tosa.concat"(%[[VAL_5]], %[[VAL_4]]) {axis = 2 : i64} : (tensor<14x22x1xf32>, tensor<14x22x3xf32>) func.func @mirrorpad_reflect(%arg0: tensor<13x21x3xf32>) -> tensor<14x22x4xf32> { %cst = "tf.Const"() {device = "", value = dense<[[1, 0], [1, 0], [1, 0]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline-filtered.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline-filtered.mlir index 7633be3acf0..95c8f252767 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline-filtered.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline-filtered.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt --pass-pipeline='func.func(tosa-legalize-tfl{disable-patterns=TFLConv2D,TFLSoftmax, enable-patterns=TFLFullyConnected,TFLTranspose})' %s | FileCheck %s +// RUN: tf-opt --pass-pipeline='builtin.module(func.func(tosa-legalize-tfl{disable-patterns=TFLConv2D,TFLSoftmax, enable-patterns=TFLFullyConnected,TFLTranspose}))' %s | FileCheck %s // ----- diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir index daad78226a4..839839a0926 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir @@ -12,7 +12,7 @@ // CHECK-LABEL: test_conv2d // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<16xf32>} -// CHECK: %[[VAR1:.*]] = "tosa.conv2d"(%arg0, %arg1, %[[VAR0]]) {dilation = [1, 1], pad = [0, 1, 0, 1], stride = [1, 1]} +// CHECK: %[[VAR1:.*]] = "tosa.conv2d"(%arg0, %arg1, %[[VAR0]]) {dilation = array, pad = array, stride = array} func.func @test_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>) -> tensor<*xf32> { %cst = arith.constant dense<0.000000e+00> : tensor<16xf32> %0 = "tfl.conv_2d"(%arg0, %arg1, %cst) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, 
tensor<16xf32>) -> tensor<*xf32> @@ -33,7 +33,7 @@ func.func @test_conv2d_dynamic(%arg0: tensor, %arg1: tensor<16x1x // ----- // CHECK-LABEL: test_conv2d_bias -// CHECK: %[[VAR0:.*]] = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], pad = [0, 1, 0, 1], stride = [1, 1]} +// CHECK: %[[VAR0:.*]] = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} // CHECK-SAME: tensor<1x32x32x16xf32> func.func @test_conv2d_bias(%arg0: tensor<1x32x32x8xf32>, %cst: tensor<16x2x2x8xf32>, %cst_0: tensor<16xf32>) -> tensor<*xf32> { %0 = "tfl.conv_2d"(%arg0, %cst, %cst_0) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<*xf32> @@ -44,11 +44,24 @@ func.func @test_conv2d_bias(%arg0: tensor<1x32x32x8xf32>, %cst: tensor<16x2x2x8x // CHECK-LABEL: test_transpose_conv2d // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<16xf32>} -// CHECK: %[[VAR1:.*]] = "tosa.transpose_conv2d"(%arg0, %arg1, %[[VAR0]]) {out_pad = [0, 0, 0, 0], out_shape = [1, 32, 32, 16], stride = [1, 1]} +// CHECK: %[[VAR1:.*]] = "tosa.transpose_conv2d"(%arg0, %arg1, %[[VAR0]]) {out_pad = array, out_shape = array, stride = array} func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %cst_0: tensor<16x1x1x8xf32>) -> tensor<1x32x32x16xf32> { %cst = arith.constant dense<[1, 32, 32, 16]> : tensor<4xi32> %cst_1 = "tfl.no_value"() {value = unit} : () -> none - %0 = "tfl.transpose_conv"(%cst, %cst_0, %arg0, %cst_1) {padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<4xi32>, tensor<16x1x1x8xf32>, tensor<1x32x32x8xf32>, none) -> tensor<1x32x32x16xf32> + %0 = "tfl.transpose_conv"(%cst, %cst_0, %arg0, %cst_1) {padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32, fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<16x1x1x8xf32>, tensor<1x32x32x8xf32>, none) -> tensor<1x32x32x16xf32> + func.return %0 : tensor<1x32x32x16xf32> +} + +// ----- + +// CHECK-LABEL: test_transpose_conv2d_relu +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<16xf32>} +// CHECK: %[[VAR1:.*]] = "tosa.transpose_conv2d"(%arg0, %arg1, %[[VAR0]]) {out_pad = array, out_shape = array, stride = array} +// CHECK: %[[VAR2:.*]] = "tosa.clamp"(%[[VAR1]]) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} +func.func @test_transpose_conv2d_relu(%arg0: tensor<1x32x32x8xf32>, %cst_0: tensor<16x1x1x8xf32>) -> tensor<1x32x32x16xf32> { + %cst = arith.constant dense<[1, 32, 32, 16]> : tensor<4xi32> + %cst_1 = "tfl.no_value"() {value = unit} : () -> none + %0 = "tfl.transpose_conv"(%cst, %cst_0, %arg0, %cst_1) {padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32, fused_activation_function = "RELU"} : (tensor<4xi32>, tensor<16x1x1x8xf32>, tensor<1x32x32x8xf32>, none) -> tensor<1x32x32x16xf32> func.return %0 : tensor<1x32x32x16xf32> } @@ -57,7 +70,7 @@ func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %cst_0: tensor<16 // CHECK-LABEL: test_conv2d_qi8 // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<{{.*}}> : tensor<16x2x2x8xi8>} // CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() {value = dense<0> : tensor<16xi32>} -// CHECK-DAG: %[[VAR2:.*]] = "tosa.conv2d"(%arg0, %[[VAR0]], %[[VAR1]]) {dilation = [1, 1], pad = [0, 1, 0, 1], quantization_info = #tosa.conv_quant, stride = [1, 1]} +// CHECK-DAG: %[[VAR2:.*]] = 
"tosa.conv2d"(%arg0, %[[VAR0]], %[[VAR1]]) {dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array} // CHECK: %[[VAR3:.*]] = "tosa.rescale"(%[[VAR2]]) func.func @test_conv2d_qi8(%arg0: tensor<1x32x32x8x!quant.uniform>) -> tensor<1x32x32x16x!quant.uniform> { %0 = "tfl.pseudo_qconst"() {qtype = tensor<16x2x2x8x!quant.uniform:f32:0, {0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1}>>, value = dense<42> : tensor<16x2x2x8xi8>} : () -> tensor<16x2x2x8x!quant.uniform:f32:0, {0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1} >> @@ -71,7 +84,7 @@ func.func @test_conv2d_qi8(%arg0: tensor<1x32x32x8x!quant.uniform : tensor<16xi48>} // CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() {value = dense<{{.*}}> : tensor<16x1x1x8xi8>} -// CHECK-DAG: %[[VAR2:.*]] = "tosa.conv2d"(%arg0, %[[VAR1]], %[[VAR0]]) {dilation = [1, 1], pad = [0, 0, 0, 0], quantization_info = #tosa.conv_quant, stride = [1, 1]} +// CHECK-DAG: %[[VAR2:.*]] = "tosa.conv2d"(%arg0, %[[VAR1]], %[[VAR0]]) {dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array} // CHECK: %[[VAR3:.*]] = "tosa.rescale"(%[[VAR2]]) func.func @test_conv2d_qi16(%arg0: tensor<1x32x32x8x!quant.uniform>) -> tensor<1x32x32x16x!quant.uniform> { %0 = "tfl.pseudo_qconst"() {qtype = tensor<16x1x1x8x!quant.uniform>, value = dense<42> : tensor<16x1x1x8xi8>} : () -> tensor<16x1x1x8x!quant.uniform> @@ -82,16 +95,16 @@ func.func @test_conv2d_qi16(%arg0: tensor<1x32x32x8x!quant.uniform // ----- -// CHECK-LABEL: test_depthwise_conv2d_bias_qi8 -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<{{.*}}> : tensor<1x2x2x16xi8>} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() {value = dense<[1, 2, 3, 0]> : tensor<4xi32>} -// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() {value = dense<{{.*}}> : tensor<16xi32>} -// CHECK-DAG: %[[VAR3:.*]] = "tosa.transpose"(%[[VAR0]], %[[VAR1]]) -// CHECK-DAG: %[[VAR4:.*]] = "tosa.reshape"(%[[VAR3]]) {new_shape = [2, 2, 8, 2]} -// CHECK-DAG: %[[VAR5:.*]] = "tosa.depthwise_conv2d"(%arg0, %[[VAR4]], %[[VAR2]]) {dilation = [1, 1], pad = [0, 1, 0, 1], quantization_info = #tosa.conv_quant, stride = [1, 1]} -// CHECK: %[[VAR6:.*]] = "tosa.rescale"(%[[VAR5]]) -// CHECK-SAME: multiplier = [1373724854 : i32, 1373724854 : i32, 1373724854 : i32, 1373724854 : i32, 1803013871 : i32, 1373724854 : i32, 1373724854 : i32, 1373724854 : i32, 1373724854 : i32, 1373724854 : i32, 1373724854 : i32, 1373724854 : i32, 1373724854 : i32, 1373724854 : i32, 1373724854 : i32, 1373724854 : i32] -// CHECK-SAME: shift = [36 : i32, 36 : i32, 36 : i32, 36 : i32, 32 : i32, 36 : i32, 36 : i32, 36 : i32, 36 : i32, 36 : i32, 36 : i32, 36 : i32, 36 : i32, 36 : i32, 36 : i32, 36 : i32] +// CHECK-LABEL: @test_depthwise_conv2d_bias_qi8 +// CHECK-SAME: %[[ARG0:.*]]: tensor<1x32x32x8x!quant.uniform> +// CHECK-DAG: %[[CONST:.*]] = "tosa.const"() {value = dense<{{.*}}> : tensor<16xi32>} +// CHECK-DAG: %[[CONST_0:.*]] = "tosa.const"() {value = dense<{{.*}}> : tensor<1x2x2x16xi8>} +// CHECK-DAG: %[[RESHAPE:.*]] = "tosa.reshape"(%[[CONST_0]]) {new_shape = array} +// CHECK-DAG: %[[DEPTHWISE:.*]] = "tosa.depthwise_conv2d"(%[[ARG0]], %[[RESHAPE]], %[[CONST]]) {dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array} +// CHECK: %[[RESCALE:.*]] = "tosa.rescale"(%[[DEPTHWISE]]) +// CHECK-SAME: multiplier = array +// CHECK-SAME: shift = array +// CHECK: return %[[RESCALE]] func.func @test_depthwise_conv2d_bias_qi8(%arg0: tensor<1x32x32x8x!quant.uniform>) -> tensor<1x32x32x16x!quant.uniform> { %0 = 
"tfl.pseudo_qconst"() {qtype = tensor<1x2x2x16x!quant.uniform:f32:3, {0.1,0.1,0.1,0.1,2.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1}>>, value = dense<[[[[-127, 127, 127, -127, -127, -127, -127, -127, -127, 127, 127, 127, 127, 127, -127, 127], [-127, 127, 127, -127, -127, -127, -127, -127, -127, 127, 127, 127, 127, 127, -127, 127]], [[-127, 127, 127, -127, -127, -127, -127, -127, -127, 127, 127, 127, 127, 127, -127, 127], [-127, 127, 127, -127, -127, -127, -127, -127, -127, 127, 127, 127, 127, 127, -127, 127]]]]> : tensor<1x2x2x16xi8>} : () -> tensor<1x2x2x16x!quant.uniform:f32:3, {0.1,0.1,0.1,0.1,2.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1} >> %1 = "tfl.pseudo_qconst"() {qtype = tensor<16x!quant.uniform>, value = dense<[-2879, 6636, 3531, 23376, -79787, -6142, 5582, -30384, 17330, -4549, -3518, 16215, 2695, -2670, 8399, -12223]> : tensor<16xi32>} : () -> tensor<16x!quant.uniform> @@ -117,7 +130,7 @@ func.func @test_depthwise_conv2d_bias_inferred(%arg0: tensor, %ar // CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<4xf32>} // CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() {value = dense<[4, 0, 1, 2, 3]> : tensor<5xi32>} // CHECK: %[[VAL_4:.*]] = "tosa.transpose"(%[[VAL_1]], %[[VAL_3]]) -// CHECK: %[[VAL_5:.*]] = "tosa.conv3d"(%[[VAL_0]], %[[VAL_4]], %[[VAL_2]]) {dilation = [1, 1, 1], pad = [0, 1, 1, 1, 1, 1], stride = [1, 1, 1]} +// CHECK: %[[VAL_5:.*]] = "tosa.conv3d"(%[[VAL_0]], %[[VAL_4]], %[[VAL_2]]) {dilation = array, pad = array, stride = array} func.func @test_conv3d(%arg0: tensor<2x2x7x7x2xf32>, %arg1: tensor<2x3x3x2x4xf32>) -> tensor<2x2x7x7x4xf32> { %cst = "tfl.no_value"() {value} : () -> none %0 = "tfl.conv_3d"(%arg0, %arg1, %cst) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<2x2x7x7x2xf32>, tensor<2x3x3x2x4xf32>, none) -> tensor<2x2x7x7x4xf32> @@ -132,7 +145,7 @@ func.func @test_conv3d(%arg0: tensor<2x2x7x7x2xf32>, %arg1: tensor<2x3x3x2x4xf32 // CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<16xf32>} // CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() {value = dense<[4, 0, 1, 2, 3]> : tensor<5xi32>} // CHECK: %[[VAL_4:.*]] = "tosa.transpose"(%[[VAL_1]], %[[VAL_3]]) -// CHECK: %[[VAL_5:.*]] = "tosa.conv3d"(%[[VAL_0]], %[[VAL_4]], %[[VAL_2]]) {dilation = [1, 1, 1], pad = [1, 1, 0, 0, 0, 0], stride = [1, 1, 1]} +// CHECK: %[[VAL_5:.*]] = "tosa.conv3d"(%[[VAL_0]], %[[VAL_4]], %[[VAL_2]]) {dilation = array, pad = array, stride = array} func.func @test_conv3d_dynamic(%arg0: tensor, %arg1: tensor<3x1x1x8x16xf32>) -> tensor<*xf32> { %cst = "tfl.no_value"() {value} : () -> none %0 = "tfl.conv_3d"(%arg0, %arg1, %cst) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor, tensor<3x1x1x8x16xf32>, none) -> tensor<*xf32> @@ -147,7 +160,7 @@ func.func @test_conv3d_dynamic(%arg0: tensor, %arg1: tensor<3x // CHECK-SAME: %[[VAL_2:.*]]: tensor<8xf32> // CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() {value = dense<[4, 0, 1, 2, 3]> // CHECK: %[[VAL_4:.*]] = "tosa.transpose"(%[[VAL_1]], %[[VAL_3]]) -// CHECK: %[[VAL_5:.*]] = "tosa.conv3d"(%[[VAL_0]], %[[VAL_4]], %[[VAL_2]]) {dilation = [1, 1, 1], pad = [7, 8, 0, 1, 0, 1], stride = [1, 1, 1]} +// CHECK: %[[VAL_5:.*]] = "tosa.conv3d"(%[[VAL_0]], %[[VAL_4]], %[[VAL_2]]) {dilation = 
array, pad = array, stride = array} func.func @test_conv3d_bias(%arg0: tensor<10x3x64x64x12xf32>, %arg1: tensor<16x2x2x12x8xf32>, %cst: tensor<8xf32>) -> tensor<10x3x64x64x8xf32> { %0 = "tfl.conv_3d"(%arg0, %arg1, %cst) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<10x3x64x64x12xf32>, tensor<16x2x2x12x8xf32>, tensor<8xf32>) -> tensor<10x3x64x64x8xf32> func.return %0 : tensor<10x3x64x64x8xf32> @@ -168,7 +181,7 @@ func.func @test_conv3d_bias(%arg0: tensor<10x3x64x64x12xf32>, %arg1: tensor<16x2 // CHECK: %[[VAL_9:.*]] = "tosa.sub"(%[[VAL_8]], %[[VAL_2]]) // CHECK: %[[VAL_10:.*]] = "tosa.mul"(%[[VAL_9]], %[[VAL_3]]) {shift = 0 : i32} // CHECK: %[[VAL_11:.*]] = "tosa.transpose"(%[[VAL_1]], %[[VAL_7]]) -// CHECK: %[[VAL_12:.*]] = "tosa.conv3d"(%[[VAL_10]], %[[VAL_11]], %[[VAL_6]]) {dilation = [1, 1, 1], pad = [0, 1, 1, 1, 1, 1], stride = [1, 1, 2]} +// CHECK: %[[VAL_12:.*]] = "tosa.conv3d"(%[[VAL_10]], %[[VAL_11]], %[[VAL_6]]) {dilation = array, pad = array, stride = array} // CHECK: %[[VAL_13:.*]] = "tosa.mul"(%[[VAL_12]], %[[VAL_4]]) {shift = 0 : i32} // CHECK: %[[VAL_14:.*]] = "tosa.add"(%[[VAL_13]], %[[VAL_5]]) // CHECK: %[[VAL_15:.*]] = "tosa.cast"(%[[VAL_14]]) @@ -284,6 +297,15 @@ func.func @test_relu1(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // ----- +// CHECK-LABEL: test_relu0To1 +// CHECK: %[[VAL0:.*]] = "tosa.clamp"(%arg0) {max_fp = 1.000000e+00 : f32, max_int = 1 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} +func.func @test_relu0To1(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tfl.relu_0_to_1"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + func.return %0 : tensor<13x21x3xf32> +} + +// ----- + // CHECK-LABEL: test_relu6 // CHECK: %[[VAR0:.*]] = "tosa.clamp"(%arg0) {max_fp = 6.000000e+00 : f32, max_int = 6 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} func.func @test_relu6(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { @@ -318,7 +340,7 @@ func.func @test_leaky_relu(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // CHECK-LABEL: test_prelu // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1x1xf32>} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 2, 3]} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = array} // CHECK-DAG: %[[VAR2:.*]] = "tosa.mul"(%arg0, %[[VAR1]]) {shift = 0 : i32} // CHECK-DAG: %[[VAR3:.*]] = "tosa.greater_equal"(%arg0, %[[VAR0]]) // CHECK: %[[VAR4:.*]] = "tosa.select"(%[[VAR3]], %arg0, %[[VAR2]]) @@ -365,7 +387,7 @@ func.func @test_logical_not(%arg0: tensor<1x21x3xi1>) -> tensor<*xi1> { // CHECK-LABEL: test_reduce_any // CHECK-DAG: %[[VAR0:.*]] = "tosa.reduce_any"(%arg0) {axis = 0 : i64} -// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = [21, 3]} +// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = array} func.func @test_reduce_any(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> { %cst = arith.constant dense<0> : tensor<1xi32> %0 = "tfl.reduce_any"(%arg0, %cst) {keep_dims = false} : (tensor<13x21x3xi1>, tensor<1xi32>) -> tensor<21x3xi1> @@ -376,7 +398,7 @@ func.func @test_reduce_any(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> { // CHECK-LABEL: test_reduce_min // CHECK-DAG: %[[VAR0:.*]] = "tosa.reduce_min"(%arg0) {axis = 0 : i64} -// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = [21, 3]} +// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) 
{new_shape = array<i64: 21, 3>}
 func.func @test_reduce_min(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
   %cst = arith.constant dense<0> : tensor<1xi32>
   %0 = "tfl.reduce_min"(%arg0, %cst) {keep_dims = false} : (tensor<13x21x3xf32>, tensor<1xi32>) -> tensor<21x3xf32>
@@ -387,7 +409,7 @@ func.func @test_reduce_min(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
 
 // CHECK-LABEL: test_reduce_max
 // CHECK-DAG: %[[VAR0:.*]] = "tosa.reduce_max"(%arg0) {axis = 0 : i64}
-// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = [21, 3]}
+// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = array<i64: 21, 3>}
 func.func @test_reduce_max(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
   %cst = arith.constant dense<0> : tensor<1xi32>
   %0 = "tfl.reduce_max"(%arg0, %cst) {keep_dims = false} : (tensor<13x21x3xf32>, tensor<1xi32>) -> tensor<21x3xf32>
@@ -398,13 +420,29 @@ func.func @test_reduce_max(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
 
 // CHECK-LABEL: test_reduce_sum
 // CHECK-DAG: %[[VAR0:.*]] = "tosa.reduce_sum"(%arg0) {axis = 0 : i64}
-// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = [21, 3]}
+// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = array<i64: 21, 3>}
 func.func @test_reduce_sum(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> {
   %cst = arith.constant dense<0> : tensor<1xi32>
   %0 = "tfl.sum"(%arg0, %cst) {keep_dims = false} : (tensor<13x21x3xf32>, tensor<1xi32>) -> tensor<21x3xf32>
   func.return %0 : tensor<21x3xf32>
 }
 
+// -----
+
+// CHECK-LABEL: test_reduce_sum_nonzero_axis
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<10x20x30x40x50xf32>
+// CHECK: %[[VAL_1:.*]] = "tosa.const"() {value = dense<[0, 1, 2, 4, 3]> : tensor<5xi32>} : () -> tensor<5xi32>
+// CHECK: %[[VAL_2:.*]] = "tosa.transpose"(%[[VAL_0]], %[[VAL_1]]) : (tensor<10x20x30x40x50xf32>, tensor<5xi32>) -> tensor<10x20x30x50x40xf32>
+// CHECK: %[[VAL_3:.*]] = "tosa.reshape"(%[[VAL_2]]) {new_shape = array<i64: 300000, 40>} : (tensor<10x20x30x50x40xf32>) -> tensor<300000x40xf32>
+// CHECK: %[[VAL_4:.*]] = "tosa.reduce_sum"(%[[VAL_3]]) {axis = 1 : i64} : (tensor<300000x40xf32>) -> tensor<300000x1xf32>
+// CHECK: %[[VAL_5:.*]] = "tosa.reshape"(%[[VAL_4]]) {new_shape = array<i64: 10, 20, 30, 50>} : (tensor<300000x1xf32>) -> tensor<10x20x30x50xf32>
+// CHECK: return %[[VAL_5]] : tensor<10x20x30x50xf32>
+func.func @test_reduce_sum_nonzero_axis(%arg0: tensor<10x20x30x40x50xf32> {tf._user_specified_name = "inp_list"}) -> tensor<10x20x30x50xf32> {
+  %cst = arith.constant dense<3> : tensor<i32>
+  %0 = "tfl.sum"(%arg0, %cst) {device = "", keep_dims = false} : (tensor<10x20x30x40x50xf32>, tensor<i32>) -> tensor<10x20x30x50xf32>
+  func.return %0 : tensor<10x20x30x50xf32>
+}
+
 // -----
 
 // CHECK-LABEL: test_reduce_sum_5D
@@ -412,9 +450,9 @@ func.func @test_reduce_sum_5D(%arg0: tensor<4x5x6x7x8xf32>) -> tensor<6x8xf32> {
   %cst = arith.constant dense<[0, 1, 3]> : tensor<3xi32>
   // CHECK-DAG: %[[PERM:.+]] = "tosa.const"() {value = dense<[2, 4, 0, 1, 3]> : tensor<5xi32>}
   // CHECK-DAG: %[[TRANSPOSE:.+]] = "tosa.transpose"(%arg0, %[[PERM]])
-  // CHECK-DAG: %[[RESHAPE0:.+]] = "tosa.reshape"(%[[TRANSPOSE:.+]]) {new_shape = [48, 140]}
+  // CHECK-DAG: %[[RESHAPE0:.+]] = "tosa.reshape"(%[[TRANSPOSE:.+]]) {new_shape = array<i64: 48, 140>}
   // CHECK-DAG: %[[REDUCE:.+]] = "tosa.reduce_sum"(%[[RESHAPE0]]) {axis = 1 : i64}
-  // CHECK: %[[RESHAPE1:.+]] = "tosa.reshape"(%[[REDUCE]]) {new_shape = [6, 8]}
+  // CHECK: %[[RESHAPE1:.+]] = "tosa.reshape"(%[[REDUCE]]) {new_shape = array<i64: 6, 8>}
   %0 = "tfl.sum"(%arg0, %cst) {keep_dims = false} : (tensor<4x5x6x7x8xf32>, tensor<3xi32>) -> tensor<6x8xf32>
   func.return %0 : tensor<6x8xf32>
 }
@@ -424,7 +462,7 @@ func.func @test_reduce_sum_5D(%arg0: tensor<4x5x6x7x8xf32>) -> tensor<6x8xf32> { // CHECK-LABEL: test_reduce_mean // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<0.0769230798> : tensor<1x1xf32>} // CHECK-DAG: %[[VAR1:.*]] = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} -// CHECK-DAG: %[[VAR2:.*]] = "tosa.reshape"(%[[VAR1]]) {new_shape = [21, 3]} +// CHECK-DAG: %[[VAR2:.*]] = "tosa.reshape"(%[[VAR1]]) {new_shape = array} // CHECK: %[[VAR4:.*]] = "tosa.mul"(%[[VAR2]], %[[VAR0]]) {shift = 0 : i32} func.func @test_reduce_mean(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { %cst = arith.constant dense<0> : tensor<1xi32> @@ -436,7 +474,7 @@ func.func @test_reduce_mean(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { // CHECK-LABEL: test_reduce_product // CHECK-DAG: %[[VAR0:.*]] = "tosa.reduce_prod"(%arg0) {axis = 0 : i64} -// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = [21, 3]} +// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = array} func.func @test_reduce_product(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { %cst = arith.constant dense<0> : tensor<1xi32> %0 = "tfl.reduce_prod"(%arg0, %cst) {keep_dims = false} : (tensor<13x21x3xf32>, tensor<1xi32>) -> tensor<21x3xf32> @@ -602,6 +640,47 @@ func.func @test_cos(%arg0: tensor<10xf32>) -> tensor<*xf32> { // ----- +// CHECK-LABEL: test_atan2 +// CHECK-SAME: -> tensor<13x21x3xf32> +// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() {value = dense<2.000000e+00> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32> +// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() {value = dense<1.000000e+00> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32> +// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() {value = dense<3.276700e+04> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32> +// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() {value = dense<2.38418579E-7> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32> +// CHECK-DAG: %[[VAL_6:.*]] = "tosa.const"() {value = dense<1.57079637> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32> +// CHECK-DAG: %[[VAL_7:.*]] = "tosa.const"() {value = dense<3.14159274> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32> +// CHECK-DAG: %[[VAL_8:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32> +// CHECK-DAG: %[[VAL_9:.*]] = "tosa.const"() {value = dense<{{.+}}> : tensor<513xi16>} : () -> tensor<513xi16> +// CHECK: %[[VAL_10:.*]] = "tosa.abs"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> +// CHECK: %[[VAL_11:.*]] = "tosa.abs"(%arg1) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> +// CHECK: %[[VAL_12:.*]] = "tosa.minimum"(%[[VAL_10]], %[[VAL_11]]) : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> +// CHECK: %[[VAL_13:.*]] = "tosa.maximum"(%[[VAL_10]], %[[VAL_11]]) : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> +// CHECK: %[[VAL_14:.*]] = "tosa.reciprocal"(%[[VAL_13]]) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> +// CHECK: %[[VAL_15:.*]] = "tosa.mul"(%[[VAL_14]], %[[VAL_12]]) {shift = 0 : i32} : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> +// CHECK: %[[VAL_16:.*]] = "tosa.mul"(%[[VAL_15]], %[[VAL_2]]) {shift = 0 : i32} : (tensor<13x21x3xf32>, tensor<1x1x1xf32>) -> tensor<13x21x3xf32> +// CHECK: %[[VAL_17:.*]] = "tosa.sub"(%[[VAL_16]], %[[VAL_3]]) : (tensor<13x21x3xf32>, tensor<1x1x1xf32>) -> tensor<13x21x3xf32> +// CHECK: %[[VAL_18:.*]] = "tosa.mul"(%[[VAL_17]], %[[VAL_4]]) {shift = 0 : i32} : (tensor<13x21x3xf32>, tensor<1x1x1xf32>) -> tensor<13x21x3xf32> +// CHECK: %[[VAL_19:.*]] = 
"tosa.cast"(%[[VAL_18]]) : (tensor<13x21x3xf32>) -> tensor<13x21x3xi16> +// CHECK: %[[VAL_20:.*]] = "tosa.table"(%[[VAL_19]], %[[VAL_9]]) : (tensor<13x21x3xi16>, tensor<513xi16>) -> tensor<13x21x3xi32> +// CHECK: %[[VAL_21:.*]] = "tosa.cast"(%[[VAL_20]]) : (tensor<13x21x3xi32>) -> tensor<13x21x3xf32> +// CHECK: %[[VAL_22:.*]] = "tosa.mul"(%[[VAL_21]], %[[VAL_5]]) {shift = 0 : i32} : (tensor<13x21x3xf32>, tensor<1x1x1xf32>) -> tensor<13x21x3xf32> +// CHECK: %[[VAL_23:.*]] = "tosa.sub"(%[[VAL_6]], %[[VAL_22]]) : (tensor<1x1x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> +// CHECK: %[[VAL_24:.*]] = "tosa.greater"(%[[VAL_10]], %[[VAL_11]]) : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi1> +// CHECK: %[[VAL_25:.*]] = "tosa.select"(%[[VAL_24]], %[[VAL_23]], %[[VAL_22]]) : (tensor<13x21x3xi1>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> +// CHECK: %[[VAL_26:.*]] = "tosa.sub"(%[[VAL_7]], %[[VAL_25]]) : (tensor<1x1x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> +// CHECK: %[[VAL_27:.*]] = "tosa.greater"(%[[VAL_8]], %arg1) : (tensor<1x1x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi1> +// CHECK: %[[VAL_28:.*]] = "tosa.select"(%[[VAL_27]], %[[VAL_26]], %[[VAL_25]]) : (tensor<13x21x3xi1>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> +// CHECK: %[[VAL_29:.*]] = "tosa.negate"(%[[VAL_28]]) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> +// CHECK: %[[VAL_30:.*]] = "tosa.greater"(%[[VAL_8]], %arg0) : (tensor<1x1x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi1> +// CHECK: %[[VAL_31:.*]] = "tosa.select"(%[[VAL_30]], %[[VAL_29]], %[[VAL_28]]) : (tensor<13x21x3xi1>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> +// CHECK: return %[[VAL_31]] : tensor<13x21x3xf32> +func.func @test_atan2(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<*xf32> { + %0 = "tfl.atan2"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +// ----- + + // CHECK-LABEL: test_sigmoid // CHECK: %[[VAR0:.*]] = "tosa.sigmoid"(%arg0) func.func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { @@ -686,7 +765,7 @@ func.func @test_less_equal_dynamic(%arg0: tensor<13x1x3xf32>, %arg1: tensor<13x? 
// ----- // CHECK-LABEL: test_avg_pool2d -// CHECK: %[[VAR0:.*]] = "tosa.avg_pool2d"(%arg0) {kernel = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} +// CHECK: %[[VAR0:.*]] = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array} func.func @test_avg_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<*xf32> { %0 = "tfl.average_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x8xf32>) -> tensor<*xf32> func.return %0 : tensor<*xf32> @@ -695,7 +774,7 @@ func.func @test_avg_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<*xf32> { // ----- // CHECK-LABEL: test_avg_pool2d_dynamic -// CHECK: %[[VAR0:.*]] = "tosa.avg_pool2d"(%arg0) {kernel = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} +// CHECK: %[[VAR0:.*]] = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array} func.func @test_avg_pool2d_dynamic(%arg0: tensor) -> tensor<*xf32> { %0 = "tfl.average_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor) -> tensor<*xf32> func.return %0 : tensor<*xf32> @@ -704,7 +783,7 @@ func.func @test_avg_pool2d_dynamic(%arg0: tensor) -> tensor<*xf32 // ----- // CHECK-LABEL: test_max_pool2d -// CHECK: %[[VAR0:.*]] = "tosa.max_pool2d"(%arg0) {kernel = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} +// CHECK: %[[VAR0:.*]] = "tosa.max_pool2d"(%arg0) {kernel = array, pad = array, stride = array} func.func @test_max_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<*xf32> { %0 = "tfl.max_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x8xf32>) -> tensor<*xf32> func.return %0 : tensor<*xf32> @@ -713,7 +792,7 @@ func.func @test_max_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<*xf32> { // ----- // CHECK-LABEL: test_max_pool2d_dynamic -// CHECK: %[[VAR0:.*]] = "tosa.max_pool2d"(%arg0) {kernel = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} +// CHECK: %[[VAR0:.*]] = "tosa.max_pool2d"(%arg0) {kernel = array, pad = array, stride = array} func.func @test_max_pool2d_dynamic(%arg0: tensor) -> tensor<*xf32> { %0 = "tfl.max_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor) -> tensor<*xf32> func.return %0 : tensor<*xf32> @@ -722,7 +801,7 @@ func.func @test_max_pool2d_dynamic(%arg0: tensor) -> tensor<*xf32 // ----- // CHECK-LABEL: test_reshape -// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 819]} +// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = array} func.func @test_reshape(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[1, 819]> : tensor<2xi32> %0 = "tfl.reshape"(%arg0, %cst) : (tensor<13x21x3xf32>, tensor<2xi32>) -> tensor<*xf32> @@ -732,7 +811,7 @@ func.func @test_reshape(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // ----- // CHECK-LABEL: test_reshape_unknown -// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [9, -1]} +// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = array} // CHECK-SAME: -> tensor<9x91xf32> func.func @test_reshape_unknown(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[9, -1]> : tensor<2xi32> @@ -743,7 +822,7 @@ func.func @test_reshape_unknown(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // 
----- // CHECK-LABEL: test_reshape_dynamic -// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [3, -1]} +// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = array} // CHECK-SAME: -> tensor<3x?xf32> func.func @test_reshape_dynamic(%arg0: tensor<13x21x?xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[3, -1]> : tensor<2xi32> @@ -776,7 +855,7 @@ func.func @test_transpose(%arg0: tensor<13x?x3xf32>) -> tensor<*xf32> { // ----- // CHECK-LABEL: test_slice -// CHECK: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = [4, 11, 1], start = [6, 8, 0]} +// CHECK: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = array, start = array} func.func @test_slice(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[6, 8, 0]> : tensor<3xi32> %cst_0 = arith.constant dense<[4, 11, 1]> : tensor<3xi32> @@ -787,10 +866,10 @@ func.func @test_slice(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // ----- // CHECK-LABEL: test_strided_slice_simple -// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = [9, 21, 2], start = [4, 0, 1]} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = [9, 1, 7, 3, 2, 1]} -// CHECK-DAG: %[[VAR2:.*]] = "tosa.slice"(%[[VAR1]]) {size = [9, 1, 7, 1, 2, 1], start = [0, 0, 0, 0, 0, 0]} -// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [9, 7, 2]} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = array, start = array} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = array} +// CHECK-DAG: %[[VAR2:.*]] = "tosa.slice"(%[[VAR1]]) {size = array, start = array} +// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = array} func.func @test_strided_slice_simple(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[4, 0, 1]> : tensor<3xi32> %cst_0 = arith.constant dense<[13, 21, 3]> : tensor<3xi32> @@ -802,8 +881,8 @@ func.func @test_strided_slice_simple(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32 // ----- // CHECK-LABEL: test_strided_slice_strideless -// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = [9, 1, 2], start = [4, 0, 1]} -// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = [9, 2]} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = array, start = array} +// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = array} func.func @test_strided_slice_strideless(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[4, 0, 1]> : tensor<3xi32> %cst_0 = arith.constant dense<[13, 21, 3]> : tensor<3xi32> @@ -815,10 +894,10 @@ func.func @test_strided_slice_strideless(%arg0: tensor<13x21x3xf32>) -> tensor<* // ----- // CHECK-LABEL: test_strided_slice_shrink -// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = [1, 21, 1], start = [4, 0, 1]} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = [1, 1, 7, 3, 1, 1]} -// CHECK-DAG: %[[VAR2:.*]] = "tosa.slice"(%[[VAR1]]) {size = [1, 1, 7, 1, 1, 1], start = [0, 0, 0, 0, 0, 0]} -// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [7]} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = array, start = array} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = array} +// CHECK-DAG: %[[VAR2:.*]] = "tosa.slice"(%[[VAR1]]) {size = array, start = array} +// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = array} func.func @test_strided_slice_shrink(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[4, 0, 1]> : tensor<3xi32> %cst_0 = arith.constant dense<[13, 21, 3]> : tensor<3xi32> @@ -830,8 +909,8 @@ 
func.func @test_strided_slice_shrink(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32 // ----- // CHECK-LABEL: test_strided_slice_shrink_ignore_stride -// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = [1, 1, 2], start = [4, 0, 1]} -// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = [2]} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = array, start = array} +// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = array} func.func @test_strided_slice_shrink_ignore_stride(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[4, 0, 1]> : tensor<3xi32> %cst_0 = arith.constant dense<[13, 21, 3]> : tensor<3xi32> @@ -843,8 +922,8 @@ func.func @test_strided_slice_shrink_ignore_stride(%arg0: tensor<13x21x3xf32>) - // ----- // CHECK-LABEL: test_strided_slice_unstrided -// CEHCK-SAME: -> tensor<9x21x2xf32> -// CHECK: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = [9, 21, 2], start = [4, 0, 1]} +// CHECK-SAME: -> tensor<9x21x2xf32> +// CHECK: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = array, start = array} // CHECK: %[[VAR1:.*]] = "tosa.reverse"(%[[VAR0]]) {axis = 2 : i64} // CHECK: return %[[VAR1]] func.func @test_strided_slice_unstrided(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { @@ -859,7 +938,7 @@ func.func @test_strided_slice_unstrided(%arg0: tensor<13x21x3xf32>) -> tensor<*x // CHECK-LABEL: test_strided_slice_unstrided_shorter // CHECK: -> tensor<9x21x3xf32> -// CHECK: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = [9, 21, 3], start = [4, 0, 0]} +// CHECK: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = array, start = array} // CHECK: %[[VAR1:.*]] = "tosa.reverse"(%[[VAR0]]) {axis = 1 : i64} // CHECK: return %[[VAR1]] func.func @test_strided_slice_unstrided_shorter(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { @@ -911,10 +990,10 @@ func.func @test_strided_slice_dynamic_end(%arg0: tensor<10x?x?xf32>) -> tensor<* %end = arith.constant dense<[7, -1, 6]> : tensor<3xi32> %stride = arith.constant dense<[1, 2, -1]> : tensor<3xi32> - // CHECK: %[[SLICE1:.+]] = "tosa.slice"(%arg0) {size = [7, -1, 1], start = [0, 1, 2]} - // CHECK: %[[RESHAPE1:.+]] = "tosa.reshape"(%[[SLICE1]]) {new_shape = [7, 1, -1, 2, 1, 1]} - // CHECK: %[[SLICE2:.+]] = "tosa.slice"(%[[RESHAPE1]]) {size = [7, 1, -1, 1, 1, 1], start = [0, 0, 0, 0, 0, 0]} - // CHECK: %[[RESHAPE2:.+]] = "tosa.reshape"(%[[SLICE2]]) {new_shape = [7, -1]} + // CHECK: %[[SLICE1:.+]] = "tosa.slice"(%arg0) {size = array, start = array} + // CHECK: %[[RESHAPE1:.+]] = "tosa.reshape"(%[[SLICE1]]) {new_shape = array} + // CHECK: %[[SLICE2:.+]] = "tosa.slice"(%[[RESHAPE1]]) {size = array, start = array} + // CHECK: %[[RESHAPE2:.+]] = "tosa.reshape"(%[[SLICE2]]) {new_shape = array} %0 = "tfl.strided_slice"(%arg0, %begin, %end, %stride) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 4 : i32} : (tensor<10x?x?xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<*xf32> // CHECK: return %[[RESHAPE2]] func.return %0 : tensor<*xf32> @@ -923,7 +1002,7 @@ func.func @test_strided_slice_dynamic_end(%arg0: tensor<10x?x?xf32>) -> tensor<* // ----- // CHECK-LABEL: test_select -// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg2) {new_shape = [1, 1, 1]} : (tensor<1xi1>) -> tensor<1x1x1xi1> +// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg2) {new_shape = array} : (tensor<1xi1>) -> tensor<1x1x1xi1> // CHECK: %[[VAR2:.*]] = "tosa.select"(%[[VAR1]], %arg0, %arg1) func.func @test_select(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<1xi1>) -> 
tensor<13x21x3xf32> { %0 = "tfl.select_v2"(%arg2, %arg0, %arg1) : (tensor<1xi1>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> @@ -964,7 +1043,7 @@ func.func @test_concatv2(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, // CHECK-LABEL: test_stack // CHECK-DAG: %[[VAR0:.*]] = "tosa.concat"(%arg0, %arg1, %arg2, %arg3) {axis = 0 : i64} -// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = [4, 13, 21, 3]} +// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%[[VAR0]]) {new_shape = array} func.func @test_stack(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<13x21x3xf32>, %arg3: tensor<13x21x3xf32>) -> tensor<4x13x21x3xf32> { %0 = "tfl.pack"(%arg0, %arg1, %arg2, %arg3) {axis = 0 : i32, values_count = 4 : i32} : (tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<4x13x21x3xf32> func.return %0 : tensor<4x13x21x3xf32> @@ -973,7 +1052,7 @@ func.func @test_stack(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, %a // ----- // CHECK-LABEL: test_unstack -// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg0) {new_shape = [32, 32, 8]} +// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg0) {new_shape = array} func.func @test_unstack(%arg0: tensor<1x32x32x8xf32>) -> tensor<*xf32> { %0 = "tfl.unpack"(%arg0) {axis = 0 : i32, num = 1 : i32} : (tensor<1x32x32x8xf32>) -> tensor<*xf32> func.return %0 : tensor<*xf32> @@ -1013,7 +1092,7 @@ func.func @test_pad_v2(%arg0: tensor<1x256x8x25xf32>) -> (tensor<*xf32>) { // ----- // CHECK-LABEL: test_expand_dims -// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 13, 21, 3]} +// CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = array} func.func @test_expand_dims(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[1, 13, 21, 3]> : tensor<4xi32> %0 = "tfl.reshape"(%arg0, %cst) : (tensor<13x21x3xf32>, tensor<4xi32>) -> tensor<*xf32> @@ -1168,10 +1247,10 @@ func.func @test_batch_matmul(%arg0: tensor<1x16x128xf32>, %arg1: tensor<1x128x32 // CHECK-LABEL: @test_batch_matmul_4d func.func @test_batch_matmul_4d(%arg0: tensor<4x5x16x128xf32>, %arg1: tensor<4x5x128x32xf32>) -> (tensor<4x5x16x32xf32> ) { - // CHECK: %[[R0:.*]] = "tosa.reshape"(%arg0) {new_shape = [20, 16, 128]} - // CHECK: %[[R1:.*]] = "tosa.reshape"(%arg1) {new_shape = [20, 128, 32]} + // CHECK: %[[R0:.*]] = "tosa.reshape"(%arg0) {new_shape = array} + // CHECK: %[[R1:.*]] = "tosa.reshape"(%arg1) {new_shape = array} // CHECK: %[[MM:.*]] = "tosa.matmul"(%[[R0]], %[[R1]]) - // CHECK: "tosa.reshape"(%[[MM]]) {new_shape = [4, 5, 16, 32]} + // CHECK: "tosa.reshape"(%[[MM]]) {new_shape = array} %0 = "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<4x5x16x128xf32>, tensor<4x5x128x32xf32>) -> tensor<4x5x16x32xf32> func.return %0 : tensor<4x5x16x32xf32> } @@ -1265,9 +1344,9 @@ func.func @test_fused_activation_relun1to1_clamp( // ----- // CHECK-LABEL: test_split -// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = [13, 7, 3], start = [0, 0, 0]} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.slice"(%arg0) {size = [13, 7, 3], start = [0, 7, 0]} -// CHECK: %[[VAR2:.*]] = "tosa.slice"(%arg0) {size = [13, 7, 3], start = [0, 14, 0]} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = array, start = array} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.slice"(%arg0) {size = array, start = array} +// CHECK: %[[VAR2:.*]] = "tosa.slice"(%arg0) {size = array, start = array} func.func @test_split(%arg0: tensor<13x21x3xf32>) -> (tensor<13x7x3xf32>, tensor<13x7x3xf32>, tensor<13x7x3xf32>) { %cst_0 = 
arith.constant dense<1> : tensor %0:3 = "tfl.split"(%cst_0, %arg0) {num_splits = 3 : i32} : (tensor, tensor<13x21x3xf32>) -> (tensor<13x7x3xf32>, tensor<13x7x3xf32>, tensor<13x7x3xf32>) @@ -1279,13 +1358,13 @@ func.func @test_split(%arg0: tensor<13x21x3xf32>) -> (tensor<13x7x3xf32>, tensor // CHECK-LABEL: test_split_dynamic func.func @test_split_dynamic(%arg0: tensor<13x?x3xf32>) -> (tensor<13x?x3xf32>, tensor<13x?x3xf32>, tensor<13x?x3xf32>) { %cst_0 = arith.constant dense<1> : tensor - // CHECK-DAG: %[[VAR0:.+]] = "tosa.reshape"(%arg0) {new_shape = [13, 3, -1, 3]} - // CHECK-DAG: %[[VAR1:.+]] = "tosa.slice"(%[[VAR0]]) {size = [13, 1, -1, 3], start = [0, 0, 0, 0]} - // CHECK-DAG: %[[VAR2:.+]] = "tosa.slice"(%[[VAR0]]) {size = [13, 1, -1, 3], start = [0, 1, 0, 0]} - // CHECK-DAG: %[[VAR3:.+]] = "tosa.slice"(%[[VAR0]]) {size = [13, 1, -1, 3], start = [0, 2, 0, 0]} - // CHECK-DAG: %[[VAR4:.+]] = "tosa.reshape"(%[[VAR1]]) {new_shape = [13, -1, 3]} - // CHECK-DAG: %[[VAR5:.+]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [13, -1, 3]} - // CHECK-DAG: %[[VAR6:.+]] = "tosa.reshape"(%[[VAR3]]) {new_shape = [13, -1, 3]} + // CHECK-DAG: %[[VAR0:.+]] = "tosa.reshape"(%arg0) {new_shape = array} + // CHECK-DAG: %[[VAR1:.+]] = "tosa.slice"(%[[VAR0]]) {size = array, start = array} + // CHECK-DAG: %[[VAR2:.+]] = "tosa.slice"(%[[VAR0]]) {size = array, start = array} + // CHECK-DAG: %[[VAR3:.+]] = "tosa.slice"(%[[VAR0]]) {size = array, start = array} + // CHECK-DAG: %[[VAR4:.+]] = "tosa.reshape"(%[[VAR1]]) {new_shape = array} + // CHECK-DAG: %[[VAR5:.+]] = "tosa.reshape"(%[[VAR2]]) {new_shape = array} + // CHECK-DAG: %[[VAR6:.+]] = "tosa.reshape"(%[[VAR3]]) {new_shape = array} // CHECK: return %[[VAR4]], %[[VAR5]], %[[VAR6]] %0:3 = "tfl.split"(%cst_0, %arg0) {num_splits = 3 : i32} : (tensor, tensor<13x?x3xf32>) -> (tensor<13x?x3xf32>, tensor<13x?x3xf32>, tensor<13x?x3xf32>) func.return %0#0, %0#1, %0#2 : tensor<13x?x3xf32>, tensor<13x?x3xf32>, tensor<13x?x3xf32> @@ -1294,9 +1373,9 @@ func.func @test_split_dynamic(%arg0: tensor<13x?x3xf32>) -> (tensor<13x?x3xf32>, // ----- // CHECK-LABEL: test_split_neg -// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = [13, 7, 3], start = [0, 0, 0]} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.slice"(%arg0) {size = [13, 7, 3], start = [0, 7, 0]} -// CHECK: %[[VAR2:.*]] = "tosa.slice"(%arg0) {size = [13, 7, 3], start = [0, 14, 0]} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = array, start = array} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.slice"(%arg0) {size = array, start = array} +// CHECK: %[[VAR2:.*]] = "tosa.slice"(%arg0) {size = array, start = array} func.func @test_split_neg(%arg0: tensor<13x21x3xf32>) -> (tensor<13x7x3xf32>, tensor<13x7x3xf32>, tensor<13x7x3xf32>) { %cst_0 = arith.constant dense<-2> : tensor %0:3 = "tfl.split"(%cst_0, %arg0) {num_splits = 3 : i32} : (tensor, tensor<13x21x3xf32>) -> (tensor<13x7x3xf32>, tensor<13x7x3xf32>, tensor<13x7x3xf32>) @@ -1306,9 +1385,9 @@ func.func @test_split_neg(%arg0: tensor<13x21x3xf32>) -> (tensor<13x7x3xf32>, te // ----- // CHECK-LABEL: test_split_axis_0 -// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = [7, 13, 3], start = [0, 0, 0]} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.slice"(%arg0) {size = [7, 13, 3], start = [7, 0, 0]} -// CHECK: %[[VAR2:.*]] = "tosa.slice"(%arg0) {size = [7, 13, 3], start = [14, 0, 0]} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = array, start = array} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.slice"(%arg0) {size = array, start = array} +// CHECK: %[[VAR2:.*]] = "tosa.slice"(%arg0) 
{size = array, start = array} func.func @test_split_axis_0(%arg0: tensor<21x13x3xf32>) -> (tensor<7x13x3xf32>, tensor<7x13x3xf32>, tensor<7x13x3xf32>) { %cst_0 = arith.constant dense<0> : tensor %0:3 = "tfl.split"(%cst_0, %arg0) {num_splits = 3 : i32} : (tensor, tensor<21x13x3xf32>) -> (tensor<7x13x3xf32>, tensor<7x13x3xf32>, tensor<7x13x3xf32>) @@ -1318,8 +1397,8 @@ func.func @test_split_axis_0(%arg0: tensor<21x13x3xf32>) -> (tensor<7x13x3xf32>, // ----- // CHECK-LABEL: test_split_v_neg_axis -// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = [2, 3, 3, 3], start = [0, 0, 0, 0]} -// CHECK: %[[VAR1:.*]] = "tosa.slice"(%arg0) {size = [2, 3, 3, 5], start = [0, 0, 0, 3]} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.slice"(%arg0) {size = array, start = array} +// CHECK: %[[VAR1:.*]] = "tosa.slice"(%arg0) {size = array, start = array} func.func @test_split_v_neg_axis(%arg0: tensor<2x3x3x8xf32>) -> (tensor<2x3x3x3xf32>, tensor<2x3x3x5xf32>) { %split_size = arith.constant dense<[3, 5]> : tensor<2xi32> %axis = arith.constant dense<-1> : tensor @@ -1344,9 +1423,9 @@ func.func @test_tile(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() {value = dense<[2, 0, 1, 3]> : tensor<4xi32>} // CHECK-DAG: %[[PVAL:.*]] = "tosa.const"() {value = dense<0.000000e+00> : tensor} // CHECK-DAG: %[[VAR2:.*]] = "tosa.pad"(%arg0, %[[VAR0]], %[[PVAL]]) -// CHECK-DAG: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [13, 11, 2, 3]} +// CHECK-DAG: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = array} // CHECK-DAG: %[[VAR4:.*]] = "tosa.transpose"(%[[VAR3]], %[[VAR1]]) -// CHECK: %[[VAR5:.*]] = "tosa.reshape"(%[[VAR4]]) {new_shape = [26, 11, 3]} +// CHECK: %[[VAR5:.*]] = "tosa.reshape"(%[[VAR4]]) {new_shape = array} func.func @test_space_to_batch(%arg0: tensor<13x21x3xf32>) -> tensor<26x11x3xf32> { %cst = arith.constant dense<2> : tensor<1xi32> %cst_0 = arith.constant dense<[[0, 1]]> : tensor<1x2xi32> @@ -1361,9 +1440,9 @@ func.func @test_space_to_batch(%arg0: tensor<13x21x3xf32>) -> tensor<26x11x3xf32 // CHECK-DAG: %[[C1:.+]] = "tosa.const"() {value = dense<{{\[\[}}0, 0], [0, 2], [0, 0], [0, 0]]> : tensor<4x2xi32>} // CHECK-DAG: %[[C2:.+]] = "tosa.const"() {value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>} // CHECK-DAG: %[[PAD:.+]] = "tosa.pad"(%arg0, %[[C1]], %[[C0]]) : (tensor, tensor<4x2xi32>, tensor) -> tensor -// CHECK-DAG: %[[R0:.+]] = "tosa.reshape"(%[[PAD]]) {new_shape = [-1, 81, 3, 1, 1, 80]} +// CHECK-DAG: %[[R0:.+]] = "tosa.reshape"(%[[PAD]]) {new_shape = array} // CHECK-DAG: %[[T:.+]] = "tosa.transpose"(%[[R0]], %[[C2]]) -// CHECK-DAG: %[[R1:.+]] = "tosa.reshape"(%[[T]]) {new_shape = [-1, 81, 1, 80]} +// CHECK-DAG: %[[R1:.+]] = "tosa.reshape"(%[[T]]) {new_shape = array} // CHECK: return %[[R1]] : tensor func.func @test_space_to_batch_dyn(%arg0 : tensor) -> (tensor) { %0 = "tfl.pseudo_const"() {value = dense<[3, 1]> : tensor<2xi32>} : () -> tensor<2xi32> @@ -1378,9 +1457,9 @@ func.func @test_space_to_batch_dyn(%arg0 : tensor) -> (tensor : tensor<4xi32>} // CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() {value = dense<[2, 3, 0, 4, 1, 5]> : tensor<6xi32>} // CHECK-DAG: %[[VAR2:.*]] = "tosa.transpose"(%arg0, %[[VAR0]]) -// CHECK-DAG: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [2, 2, 2, 32, 32, 1]} +// CHECK-DAG: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = array} // CHECK-DAG: %[[VAR4:.*]] = "tosa.transpose"(%[[VAR3]], %[[VAR1]]) -// CHECK-DAG: %[[VAR5:.*]] = "tosa.reshape"(%[[VAR4]]) {new_shape = [2, 64, 64, 1]} +// CHECK-DAG: %[[VAR5:.*]] = 
"tosa.reshape"(%[[VAR4]]) {new_shape = array} // CHECK: return %[[VAR5:.*]] func.func @test_batch_to_space(%arg0: tensor<1x32x32x8xf32>) -> tensor<2x64x64x1xf32> { %cst = arith.constant dense<2> : tensor<2xi32> @@ -1395,10 +1474,10 @@ func.func @test_batch_to_space(%arg0: tensor<1x32x32x8xf32>) -> tensor<2x64x64x1 // CHECK-LABEL: @test_batch_to_space_dyn // CHECK-DAG: %[[C0:.+]] = "tosa.const"() {value = dense<[2, 3, 0, 4, 1, 5]> : tensor<6xi32>} -// CHECK-DAG: %[[R0:.+]] = "tosa.reshape"(%arg0) {new_shape = [3, 1, -1, 79, 1, 80]} +// CHECK-DAG: %[[R0:.+]] = "tosa.reshape"(%arg0) {new_shape = array} // CHECK-DAG: %[[T:.+]] = "tosa.transpose"(%[[R0]], %[[C0]]) -// CHECK-DAG: %[[R1:.+]] = "tosa.reshape"(%[[T]]) {new_shape = [-1, 237, 1, 80]} -// CHECK-DAG: %[[SLICE:.+]] = "tosa.slice"(%[[R1]]) {size = [-1, 235, 1, 80], start = [0, 0, 0, 0]} +// CHECK-DAG: %[[R1:.+]] = "tosa.reshape"(%[[T]]) {new_shape = array} +// CHECK-DAG: %[[SLICE:.+]] = "tosa.slice"(%[[R1]]) {size = array, start = array} // CHECK: return %[[SLICE]] func.func @test_batch_to_space_dyn(%arg0 : tensor) -> (tensor) { %0 = "tfl.pseudo_const"() {value = dense<[3, 1]> : tensor<2xi32>} : () -> tensor<2xi32> @@ -1411,9 +1490,9 @@ func.func @test_batch_to_space_dyn(%arg0 : tensor) -> (tensor : tensor<6xi32>} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 16, 2, 16, 2, 8]} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg0) {new_shape = array} // CHECK-DAG: %[[VAR2:.*]] = "tosa.transpose"(%[[VAR1]], %[[VAR0]]) -// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [1, 16, 16, 32]} +// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = array} func.func @test_space_to_depth(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x16x16x32xf32> { %0 = "tfl.space_to_depth"(%arg0) {block_size = 2 : i32} : (tensor<1x32x32x8xf32>) -> tensor<1x16x16x32xf32> func.return %0 : tensor<1x16x16x32xf32> @@ -1423,9 +1502,9 @@ func.func @test_space_to_depth(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x16x16x3 // CHECK-LABEL: test_depth_to_space // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 32, 32, 2, 2, 2]} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg0) {new_shape = array} // CHECK-DAG: %[[VAR2:.*]] = "tosa.transpose"(%[[VAR1]], %[[VAR0]]) -// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [1, 64, 64, 2]} +// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = array} func.func @test_depth_to_space(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x2xf32> { %0 = "tfl.depth_to_space"(%arg0) {block_size = 2 : i32} : (tensor<1x32x32x8xf32>) -> tensor<1x64x64x2xf32> func.return %0 : tensor<1x64x64x2xf32> @@ -1433,17 +1512,30 @@ func.func @test_depth_to_space(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x2 // ----- -// CHECK-LABEL: test_one_hot -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<[0, 2, 1]> : tensor<3xi32>} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1]} -// CHECK-DAG: %[[VAR2:.*]] = "tosa.tile"(%[[VAR1]]) {multiples = [16, 1, 1]} -// CHECK-DAG: %[[VAR3:.*]] = "tosa.reshape"(%arg2) {new_shape = [1, 1, 1]} -// CHECK-DAG: %[[VAR4:.*]] = "tosa.tile"(%[[VAR3]]) {multiples = [16, 2, 1]} -// CHECK-DAG: %[[VAR5:.*]] = "tosa.reshape"(%arg0) {new_shape = [16, 1]} -// CHECK-DAG: %[[VAR6:.*]] = "tosa.scatter"(%[[VAR4]], %[[VAR5]], %[[VAR2]]) -// CHECK-DAG: %[[VAR7:.*]] = "tosa.reshape"(%[[VAR6]]) {new_shape = [16, 1, 2]} -// CHECK-DAG: %[[VAR8:.*]] 
= "tosa.transpose"(%[[VAR7]], %[[VAR0]]) -// CHECK: %[[VAR9:.*]] = "tosa.reshape"(%[[VAR8]]) {new_shape = [4, 4, 2]} +// CHECK-LABEL: @test_bucketize +// CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() {value = dense<{{\[\[\[}}0.000000e+00, 3.000000e+00, 8.000000e+00, 1.100000e+01]]]> : tensor<1x1x4xf32>} +// CHECK: %[[VAL_1:.*]] = "tosa.reshape"(%arg0) {new_shape = array} +// CHECK: %[[VAL_2:.*]] = "tosa.greater_equal"(%[[VAL_1]], %[[VAL_0]]) +// CHECK: %[[VAL_3:.*]] = "tosa.cast"(%[[VAL_2]]) : (tensor<2x5x4xi1>) -> tensor<2x5x4xi32> +// CHECK: %[[VAL_4:.*]] = "tosa.reduce_sum"(%[[VAL_3]]) {axis = 2 : i64} +// CHECK: %[[VAL_5:.*]] = "tosa.reshape"(%[[VAL_4]]) {new_shape = array} +func.func @test_bucketize(%arg0: tensor<2x5xf32>) -> tensor<2x5xi32> { + %0 = "tfl.bucketize"(%arg0) {boundaries = [0.000000e+00 : f32, 3.000000e+00 : f32, 8.000000e+00 : f32, 1.100000e+01 : f32]} : (tensor<2x5xf32>) -> tensor<2x5xi32> + func.return %0 : tensor<2x5xi32> +} + +// ----- + +// CHECK-LABEL: @test_one_hot +// CHECK-SAME: %[[ARG0:.*]]: tensor<4x4xi32>, %[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor +// CHECK-DAG: %[[RESHAPE:.*]] = "tosa.reshape"(%[[ARG1]]) {new_shape = array} +// CHECK-DAG: %[[TILE:.*]] = "tosa.tile"(%[[RESHAPE]]) {multiples = array} +// CHECK-DAG: %[[RESHAPE_0:.*]] = "tosa.reshape"(%[[ARG2]]) {new_shape = array} +// CHECK-DAG: %[[TILE_0:.*]] = "tosa.tile"(%[[RESHAPE_0]]) {multiples = array} +// CHECK-DAG: %[[RESHAPE_1:.*]] = "tosa.reshape"(%[[ARG0]]) {new_shape = array} +// CHECK-DAG: %[[SCATTER:.*]] = "tosa.scatter"(%[[TILE_0]], %[[RESHAPE_1]], %[[TILE]]) +// CHECK-DAG: %[[RESHAPE_2:.*]] = "tosa.reshape"(%[[SCATTER]]) {new_shape = array} +// CHECK: return %[[RESHAPE_2]] func.func @test_one_hot(%arg0: tensor<4x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<4x4x2xf32> { %0 = arith.constant dense<2> : tensor %1 = "tfl.one_hot"(%arg0, %0, %arg1, %arg2) {axis = -1 : i32} : (tensor<4x4xi32>, tensor, tensor, tensor) -> tensor<4x4x2xf32> @@ -1551,7 +1643,7 @@ func.func @test_mul_qi8(%arg0: tensor<13x21x3x!quant.uniform, stride = [1, 1]} +// CHECK: %[[VAR0:.*]] = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, quantization_info = #tosa.unary_quant, stride = array} // CHECK-SAME: -> tensor<1x32x32x8x!quant.uniform> func.func @test_avg_pool2d_qi8(%arg0: tensor<1x32x32x8x!quant.uniform>) -> tensor<*x!quant.uniform> { %0 = "tfl.average_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x8x!quant.uniform>) -> tensor<*x!quant.uniform> @@ -1561,7 +1653,7 @@ func.func @test_avg_pool2d_qi8(%arg0: tensor<1x32x32x8x!quant.uniform, pad = array, stride = array} // CHECK-SAME: -> tensor<1x32x32x8xi16> func.func @test_avg_pool2d_i16(%arg0: tensor<1x32x32x8xi16>) -> tensor<*xi16> { %0 = "tfl.average_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x8xi16>) -> tensor<*xi16> @@ -1571,7 +1663,7 @@ func.func @test_avg_pool2d_i16(%arg0: tensor<1x32x32x8xi16>) -> tensor<*xi16> { // ----- // CHECK-LABEL: test_max_pool2d_qi8 -// CHECK: %[[VAR0:.*]] = "tosa.max_pool2d"(%arg0) {kernel = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} +// CHECK: %[[VAR0:.*]] = "tosa.max_pool2d"(%arg0) {kernel = array, pad = array, stride = array} func.func @test_max_pool2d_qi8(%arg0: tensor<1x32x32x8x!quant.uniform>) -> tensor<*x!quant.uniform> { %0 = "tfl.max_pool_2d"(%arg0) {filter_height 
= 1 : i32, filter_width = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x8x!quant.uniform>) -> tensor<*x!quant.uniform> func.return %0 : tensor<*x!quant.uniform> @@ -1594,10 +1686,10 @@ func.func @test_max_pool2d_qi8(%arg0: tensor<1x32x32x8x!quant.uniform : tensor<513xi16>} // CHECK-DAG: %[[VAR13:.*]] = "tosa.const"() {value = dense<"0x4{{.*}}"> : tensor<513xi16>} // CHECK-DAG: %[[VAR14:.*]] = "tosa.const"() {value = dense<"0x0{{.*}}"> : tensor<513xi16>} -// CHECK-DAG: %[[VAR15:.*]] = "tosa.rescale"(%arg0) {double_round = false, input_zp = -1 : i32, multiplier = [1073741824 : i32], output_zp = 0 : i32, per_channel = false, scale32 = true, shift = [30 : i32]} +// CHECK-DAG: %[[VAR15:.*]] = "tosa.rescale"(%arg0) {double_round = false, input_zp = -1 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} // CHECK-DAG: %[[VAR16:.*]] = "tosa.reduce_max"(%[[VAR15]]) {axis = 2 : i64} // CHECK-DAG: %[[VAR17:.*]] = "tosa.sub"(%[[VAR15]], %[[VAR16]]) -// CHECK-DAG: %[[VAR18:.*]] = "tosa.rescale"(%[[VAR17]]) {double_round = false, input_zp = 0 : i32, multiplier = [1073741824 : i32], output_zp = 0 : i32, per_channel = false, scale32 = true, shift = [23 : i32]} +// CHECK-DAG: %[[VAR18:.*]] = "tosa.rescale"(%[[VAR17]]) {double_round = false, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} // CHECK-DAG: %[[VAR19:.*]] = "tosa.table"(%[[VAR18]], %[[VAR14]]) // CHECK-DAG: %[[VAR20:.*]] = "tosa.table"(%[[VAR18]], %[[VAR13]]) // CHECK-DAG: %[[VAR21:.*]] = "tosa.table"(%[[VAR18]], %[[VAR12]]) @@ -1634,7 +1726,7 @@ func.func @test_max_pool2d_qi8(%arg0: tensor<1x32x32x8x!quant.uniform, output_zp = -128 : i32, per_channel = false, scale32 = true, shift = array} func.func @test_softmax_qi8(%arg0: tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> { %0 = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> func.return %0 : tensor<13x21x3x!quant.uniform> @@ -1653,10 +1745,10 @@ func.func @test_softmax_qi8(%arg0: tensor<13x21x3x!quant.uniform : tensor<1x1xi32>} // CHECK-DAG: %[[VAR7:.*]] = "tosa.const"() {value = dense<"0xF{{.*}}> // CHECK-DAG: %[[VAR8:.*]] = "tosa.const"() {value = dense<"0x0{{.*}}> : tensor<513xi16>} -// CHECK-DAG: %[[VAR9:.*]] = "tosa.rescale"(%arg0) {double_round = false, input_zp = 0 : i32, multiplier = [1073741824 : i32], output_zp = 0 : i32, per_channel = false, scale32 = true, shift = [30 : i32]} +// CHECK-DAG: %[[VAR9:.*]] = "tosa.rescale"(%arg0) {double_round = false, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} // CHECK-DAG: %[[VAR10:.*]] = "tosa.reduce_max"(%[[VAR9]]) {axis = 1 : i64} // CHECK-DAG: %[[VAR11:.*]] = "tosa.sub"(%[[VAR9]], %[[VAR10]]) -// CHECK-DAG: %[[VAR12:.*]] = "tosa.rescale"(%[[VAR11]]) {double_round = true, input_zp = 0 : i32, multiplier = [1717965619 : i32], output_zp = 0 : i32, per_channel = false, scale32 = true, shift = [32 : i32]} +// CHECK-DAG: %[[VAR12:.*]] = "tosa.rescale"(%[[VAR11]]) {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} // CHECK-DAG: %[[VAR13:.*]] = "tosa.add"(%[[VAR12]], %[[VAR6]]) // CHECK-DAG: %[[VAR14:.*]] = "tosa.cast"(%[[VAR13]]) // CHECK-DAG: %[[VAR15:.*]] = "tosa.table"(%[[VAR14]], %[[VAR8]]) @@ -1674,7 +1766,7 @@ func.func 
@test_softmax_qi8(%arg0: tensor<13x21x3x!quant.uniform, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} func.func @test_softmax_qi16(%arg0: tensor<14x19x!quant.uniform>) -> tensor<14x19x!quant.uniform> { %0 = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<14x19x!quant.uniform>) -> tensor<14x19x!quant.uniform> func.return %0 : tensor<14x19x!quant.uniform> @@ -1712,6 +1804,16 @@ func.func @test_relu_qi8(%arg0: tensor<13x21x3x!quant.uniform>) -> tensor<*x!quant.uniform> { + %0 = "tfl.relu_0_to_1"(%arg0) : (tensor<13x21x3x!quant.uniform>) -> tensor<*x!quant.uniform> + func.return %0 : tensor<*x!quant.uniform> +} + +// ----- + // CHECK-LABEL: test_relu6_qi8 // CHECK-DAG: %[[VAR0:.*]] = "tosa.rescale"(%arg0) // CHECK: %[[VAR1:.*]] = "tosa.clamp"(%0) {max_fp = 6.000000e+00 : f32, max_int = 384 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} @@ -1723,10 +1825,10 @@ func.func @test_relu6_qi8(%arg0: tensor<13x21x3x!quant.uniform, output_zp = -128 : i32, per_channel = false, scale32 = true, shift = array} +// CHECK: %[[RESCALE:.+]] = "tosa.rescale"(%[[CAST]]) {double_round = false, input_zp = -128 : i32, multiplier = array, output_zp = -128 : i32, per_channel = false, scale32 = true, shift = array} // CHECK: %[[CLAMP:.+]] = "tosa.clamp"(%[[RESCALE]]) {max_fp = 6.000000e+00 : f32, max_int = 22 : i64, min_fp = 0.000000e+00 : f32, min_int = -128 : i64} -// CHECK: %[[OUT:.+]] = "tosa.rescale"(%[[CLAMP]]) {double_round = false, input_zp = -128 : i32, multiplier = [1073741824 : i32], output_zp = 0 : i32, per_channel = false, scale32 = true, shift = [30 : i32]} +// CHECK: %[[OUT:.+]] = "tosa.rescale"(%[[CLAMP]]) {double_round = false, input_zp = -128 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} func.func @test_relu6_qu8(%arg0: tensor<13x21x3x!quant.uniform>) -> tensor<*x!quant.uniform> { %0 = "tfl.relu6"(%arg0) : (tensor<13x21x3x!quant.uniform>) -> tensor<*x!quant.uniform> func.return %0 : tensor<*x!quant.uniform> @@ -1749,8 +1851,8 @@ func.func @test_leaky_relu_qi8(%arg0: tensor<14x19x!quant.uniform, mode = "BILINEAR", offset = array, scale = array} +// CHECK: %[[VAR2:.*]] = "tosa.rescale"(%[[VAR1]]) {double_round = false, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} func.func @test_resize_bilinear_qi8(%arg0: tensor<1x80x80x2x!quant.uniform>) -> tensor<1x640x640x2x!quant.uniform> { %0 = "tfl.pseudo_const"() {value = dense<640> : tensor<2xi32>} : () -> tensor<2xi32> %1 = "tfl.resize_bilinear"(%arg0, %0) {align_corners = false, half_pixel_centers = false} : (tensor<1x80x80x2x!quant.uniform>, tensor<2xi32>) -> tensor<1x640x640x2x!quant.uniform> @@ -1760,7 +1862,7 @@ func.func @test_resize_bilinear_qi8(%arg0: tensor<1x80x80x2x!quant.uniform, mode = "BILINEAR", offset = array, scale = array} func.func @test_resize_bilinear_half_qi8(%arg0: tensor<1x80x80x2x!quant.uniform>) -> tensor<1x640x640x2x!quant.uniform> { %0 = "tfl.pseudo_const"() {value = dense<640> : tensor<2xi32>} : () -> tensor<2xi32> %1 = "tfl.resize_bilinear"(%arg0, %0) {align_corners = false, half_pixel_centers = true} : (tensor<1x80x80x2x!quant.uniform>, tensor<2xi32>) -> tensor<1x640x640x2x!quant.uniform> @@ -1770,7 +1872,7 @@ func.func @test_resize_bilinear_half_qi8(%arg0: tensor<1x80x80x2x!quant.uniform< // ----- // CHECK-LABEL: test_resize_bilinear_align_qi8 -// CHECK-DAG: %[[VAR1:.*]] = "tosa.resize"(%arg0) {border = [0, 0], mode = "BILINEAR", offset = [0, 0], scale = [1278, 
158, 1278, 158]} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.resize"(%arg0) {border = array, mode = "BILINEAR", offset = array, scale = array} func.func @test_resize_bilinear_align_qi8(%arg0: tensor<1x80x80x2x!quant.uniform>) -> tensor<1x640x640x2x!quant.uniform> { %0 = "tfl.pseudo_const"() {value = dense<640> : tensor<2xi32>} : () -> tensor<2xi32> %1 = "tfl.resize_bilinear"(%arg0, %0) {align_corners = true, half_pixel_centers = false} : (tensor<1x80x80x2x!quant.uniform>, tensor<2xi32>) -> tensor<1x640x640x2x!quant.uniform> @@ -1780,7 +1882,7 @@ func.func @test_resize_bilinear_align_qi8(%arg0: tensor<1x80x80x2x!quant.uniform // ----- // CHECK-LABEL: test_resize_bilinear_align_half_qi8 -// CHECK-DAG: %[[VAR1:.*]] = "tosa.resize"(%arg0) {border = [-560, -560], mode = "BILINEAR", offset = [-560, -560], scale = [1278, 158, 1278, 158]} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.resize"(%arg0) {border = array, mode = "BILINEAR", offset = array, scale = array} func.func @test_resize_bilinear_align_half_qi8(%arg0: tensor<1x80x80x2x!quant.uniform>) -> tensor<1x640x640x2x!quant.uniform> { %0 = "tfl.pseudo_const"() {value = dense<640> : tensor<2xi32>} : () -> tensor<2xi32> %1 = "tfl.resize_bilinear"(%arg0, %0) {align_corners = true, half_pixel_centers = true} : (tensor<1x80x80x2x!quant.uniform>, tensor<2xi32>) -> tensor<1x640x640x2x!quant.uniform> @@ -1790,7 +1892,7 @@ func.func @test_resize_bilinear_align_half_qi8(%arg0: tensor<1x80x80x2x!quant.un // ----- // CHECK-LABEL: test_resize_nearest_qi8 -// CHECK: %[[VAR1:.*]] = "tosa.resize"(%arg0) {border = [14, 14], mode = "NEAREST_NEIGHBOR", offset = [0, 0], scale = [16, 2, 16, 2]} +// CHECK: %[[VAR1:.*]] = "tosa.resize"(%arg0) {border = array, mode = "NEAREST_NEIGHBOR", offset = array, scale = array} func.func @test_resize_nearest_qi8(%arg0: tensor<1x80x80x2x!quant.uniform>) -> tensor<1x640x640x2x!quant.uniform> { %0 = "tfl.pseudo_const"() {value = dense<640> : tensor<2xi32>} : () -> tensor<2xi32> %1 = "tfl.resize_nearest_neighbor"(%arg0, %0) {align_corners = false, half_pixel_centers = false} : (tensor<1x80x80x2x!quant.uniform>, tensor<2xi32>) -> tensor<1x640x640x2x!quant.uniform> @@ -1801,7 +1903,7 @@ func.func @test_resize_nearest_qi8(%arg0: tensor<1x80x80x2x!quant.uniform, mode = "NEAREST_NEIGHBOR", offset = array, scale = array} func.func @test_resize_nearest_half_qi8(%arg0: tensor<1x80x80x2x!quant.uniform>) -> tensor<1x640x640x2x!quant.uniform> { %0 = "tfl.pseudo_const"() {value = dense<640> : tensor<2xi32>} : () -> tensor<2xi32> %1 = "tfl.resize_nearest_neighbor"(%arg0, %0) {align_corners = false, half_pixel_centers = true} : (tensor<1x80x80x2x!quant.uniform>, tensor<2xi32>) -> tensor<1x640x640x2x!quant.uniform> @@ -1811,7 +1913,7 @@ func.func @test_resize_nearest_half_qi8(%arg0: tensor<1x80x80x2x!quant.uniform, mode = "NEAREST_NEIGHBOR", offset = array, scale = array} func.func @test_resize_nearest_align_qi8(%arg0: tensor<1x80x80x2x!quant.uniform>) -> tensor<1x640x640x2x!quant.uniform> { %0 = "tfl.pseudo_const"() {value = dense<640> : tensor<2xi32>} : () -> tensor<2xi32> %1 = "tfl.resize_nearest_neighbor"(%arg0, %0) {align_corners = true, half_pixel_centers = false} : (tensor<1x80x80x2x!quant.uniform>, tensor<2xi32>) -> tensor<1x640x640x2x!quant.uniform> @@ -1821,7 +1923,7 @@ func.func @test_resize_nearest_align_qi8(%arg0: tensor<1x80x80x2x!quant.uniform< // ----- // CHECK-LABEL: test_resize_nearest_align_half_qi8 -// CHECK: %[[VAR1:.*]] = "tosa.resize"(%arg0) {border = [718, 718], mode = "NEAREST_NEIGHBOR", offset = [718, 718], scale = [1278, 158, 1278, 
158]} +// CHECK: %[[VAR1:.*]] = "tosa.resize"(%arg0) {border = array, mode = "NEAREST_NEIGHBOR", offset = array, scale = array} func.func @test_resize_nearest_align_half_qi8(%arg0: tensor<1x80x80x2x!quant.uniform>) -> tensor<1x640x640x2x!quant.uniform> { %0 = "tfl.pseudo_const"() {value = dense<640> : tensor<2xi32>} : () -> tensor<2xi32> %1 = "tfl.resize_nearest_neighbor"(%arg0, %0) {align_corners = true, half_pixel_centers = true} : (tensor<1x80x80x2x!quant.uniform>, tensor<2xi32>) -> tensor<1x640x640x2x!quant.uniform> @@ -1835,7 +1937,7 @@ func.func @test_resize_nearest_align_half_qi8(%arg0: tensor<1x80x80x2x!quant.uni // CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() {value = dense<0> : tensor<28xi32>} // CHECK-DAG: %[[VAR2:.*]] = "tosa.transpose"(%arg1, %[[VAR0]]) // CHECK-DAG: %[[VAR3:.*]] = "tosa.fully_connected"(%arg0, %[[VAR2]], %[[VAR1]]) {quantization_info = #tosa.conv_quant} -// CHECK: %[[VAR4:.*]] = "tosa.rescale"(%[[VAR3]]) {double_round = true, input_zp = 0 : i32, multiplier = [1353377973 : i32], output_zp = 3 : i32, per_channel = false, scale32 = true, shift = [40 : i32]} +// CHECK: %[[VAR4:.*]] = "tosa.rescale"(%[[VAR3]]) {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 3 : i32, per_channel = false, scale32 = true, shift = array} func.func @test_fullyconnected_qi8(%arg0: tensor<14x19x!quant.uniform>, %arg1: tensor<19x28x!quant.uniform>) -> tensor<14x28x!quant.uniform> { %0 = "tfl.pseudo_const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> %1 = "tfl.transpose"(%arg1, %0) : (tensor<19x28x!quant.uniform>, tensor<2xi32>) -> tensor<28x19x!quant.uniform> @@ -1846,10 +1948,10 @@ func.func @test_fullyconnected_qi8(%arg0: tensor<14x19x!quant.uniform} +// CHECK-DAG: %[[VAR5:.*]] = "tosa.reshape"(%arg1) {new_shape = array} // CHECK-DAG: %[[VAR6:.*]] = "tosa.gather"(%[[VAR4]], %[[VAR5]]) -// CHECK-DAG: %[[VAR7:.*]] = "tosa.reshape"(%[[VAR6]]) {new_shape = [7, 7, 21, 3]} +// CHECK-DAG: %[[VAR7:.*]] = "tosa.reshape"(%[[VAR6]]) {new_shape = array} // CHECK: return %[[VAR7]] func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<7x7xi32>) -> tensor<*xf32> { %2 = "tfl.gather"(%arg0, %arg1) {axis = 0 : i32} : (tensor<13x21x3xf32>, tensor<7x7xi32>) -> tensor<*xf32> @@ -1858,10 +1960,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<7x7xi32>) -> te // ----- // CHECK-LABEL: test_gather_dyn -// CHECK-DAG: %[[VAR4:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, -1, 63]} -// CHECK-DAG: %[[VAR5:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 49]} +// CHECK-DAG: %[[VAR4:.*]] = "tosa.reshape"(%arg0) {new_shape = array} +// CHECK-DAG: %[[VAR5:.*]] = "tosa.reshape"(%arg1) {new_shape = array} // CHECK-DAG: %[[VAR6:.*]] = "tosa.gather"(%[[VAR4]], %[[VAR5]]) -// CHECK-DAG: %[[VAR7:.*]] = "tosa.reshape"(%[[VAR6]]) {new_shape = [7, 7, 21, 3]} +// CHECK-DAG: %[[VAR7:.*]] = "tosa.reshape"(%[[VAR6]]) {new_shape = array} // CHECK: return %[[VAR7]] func.func @test_gather_dyn(%arg0: tensor, %arg1 : tensor<7x7xi32>) -> tensor<*xf32> { %2 = "tfl.gather"(%arg0, %arg1) {axis = 0 : i32} : (tensor, tensor<7x7xi32>) -> tensor<*xf32> @@ -1871,10 +1973,10 @@ func.func @test_gather_dyn(%arg0: tensor, %arg1 : tensor<7x7xi32>) - // ----- // CHECK-LABEL: test_gather_channel_dyn -// CHECK-DAG: %[[VAR4:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 13, -1]} -// CHECK-DAG: %[[VAR5:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 49]} +// CHECK-DAG: %[[VAR4:.*]] = "tosa.reshape"(%arg0) {new_shape = array} +// CHECK-DAG: %[[VAR5:.*]] = "tosa.reshape"(%arg1) {new_shape = 
array} // CHECK-DAG: %[[VAR6:.*]] = "tosa.gather"(%[[VAR4]], %[[VAR5]]) -// CHECK-DAG: %[[VAR7:.*]] = "tosa.reshape"(%[[VAR6]]) {new_shape = [7, 7, 21, -1]} +// CHECK-DAG: %[[VAR7:.*]] = "tosa.reshape"(%[[VAR6]]) {new_shape = array} // CHECK: return %[[VAR7]] func.func @test_gather_channel_dyn(%arg0: tensor<13x21x?xf32>, %arg1: tensor<7x7xi32>) -> tensor<*xf32> { %2 = "tfl.gather"(%arg0, %arg1) {axis = 0 : i32} : (tensor<13x21x?xf32>, tensor<7x7xi32>) -> tensor<*xf32> @@ -1883,10 +1985,10 @@ func.func @test_gather_channel_dyn(%arg0: tensor<13x21x?xf32>, %arg1: tensor<7x7 // ----- // CHECK-LABEL: test_gather_indices_dyn -// CHECK-DAG: %[[VAR4:.*]] = "tosa.reshape"(%arg0) {new_shape = [1, 13, 63]} -// CHECK-DAG: %[[VAR5:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, -1]} +// CHECK-DAG: %[[VAR4:.*]] = "tosa.reshape"(%arg0) {new_shape = array} +// CHECK-DAG: %[[VAR5:.*]] = "tosa.reshape"(%arg1) {new_shape = array} // CHECK-DAG: %[[VAR6:.*]] = "tosa.gather"(%[[VAR4]], %[[VAR5]]) -// CHECK-DAG: %[[VAR7:.*]] = "tosa.reshape"(%[[VAR6]]) {new_shape = [-1, 7, 21, 3]} +// CHECK-DAG: %[[VAR7:.*]] = "tosa.reshape"(%[[VAR6]]) {new_shape = array} // CHECK: return %[[VAR7]] func.func @test_gather_indices_dyn(%arg0: tensor<13x21x3xf32>, %arg1: tensor) -> tensor<*xf32> { %2 = "tfl.gather"(%arg0, %arg1) {axis = 0 : i32} : (tensor<13x21x3xf32>, tensor) -> tensor<*xf32> @@ -1896,9 +1998,9 @@ func.func @test_gather_indices_dyn(%arg0: tensor<13x21x3xf32>, %arg1: tensor} // CHECK-DAG: %[[VAR2:.*]] = "tosa.gather"(%[[VAR1]], %[[VAR0]]) -// CHECK-DAG: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [1, 3, 4, 4]} +// CHECK-DAG: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = array} // CHECK: return %[[VAR3]] func.func @test_gather_batch(%arg0: tensor<1x4x4x4xi32>) -> tensor<1x3x4x4xi32> { %0 = "tfl.pseudo_const"() {value = dense<[[0, 3, 1]]> : tensor<1x3xi32>} : () -> tensor<1x3xi32> @@ -1908,9 +2010,9 @@ func.func @test_gather_batch(%arg0: tensor<1x4x4x4xi32>) -> tensor<1x3x4x4xi32> // ----- // CHECK-LABEL: test_gather_batch_dyn -// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg0) {new_shape = [-1, 4, 16]} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.reshape"(%arg0) {new_shape = array} // CHECK-DAG: %[[VAR2:.*]] = "tosa.gather"(%[[VAR1]], %arg1) -// CHECK-DAG: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [-1, 3, 4, 4]} +// CHECK-DAG: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = array} // CHECK: return %[[VAR3]] func.func @test_gather_batch_dyn(%arg0: tensor, %arg1: tensor) -> tensor { %1 = "tfl.gather"(%arg0, %arg1) {axis = 1 : i32, batch_dims = 1 : i32} : (tensor, tensor) -> tensor @@ -1920,18 +2022,31 @@ func.func @test_gather_batch_dyn(%arg0: tensor, %arg1: tensor} +// CHECK-DAG: %[[VAR3:.*]] = "tosa.reshape"(%arg1) {new_shape = array} // CHECK-DAG: %[[VAR5:.*]] = "tosa.mul"(%[[VAR3]], %[[VAR1]]) {shift = 0 : i32} // CHECK-DAG: %[[VAR6:.*]] = "tosa.reduce_sum"(%[[VAR5]]) {axis = 1 : i64} -// CHECK-DAG: %[[VAR7:.*]] = "tosa.reshape"(%[[VAR6]]) {new_shape = [1, 42]} +// CHECK-DAG: %[[VAR7:.*]] = "tosa.reshape"(%[[VAR6]]) {new_shape = array} // CHECK-DAG: %[[VAR8:.*]] = "tosa.gather"(%[[VAR2]], %[[VAR7]]) -// CHECK: %[[VAR9:.*]] = "tosa.reshape"(%[[VAR8]]) {new_shape = [6, 7, 3]} +// CHECK: %[[VAR9:.*]] = "tosa.reshape"(%[[VAR8]]) {new_shape = array} func.func @test_gather_nd(%arg0: tensor<13x21x3xf32>, %arg1: tensor<6x7x2xi32>) -> tensor<6x7x3xf32> { %1 = "tfl.gather_nd"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<6x7x2xi32>) -> tensor<6x7x3xf32> func.return %1 : tensor<6x7x3xf32> } +// ----- 
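// [Editor's note, not part of the patch.] The test added below covers "tfl.gather" with
// i64 indices: tosa.gather only accepts i32 index tensors, so convertGatherOp (see the
// matching change to legalize_common.cc further down) now casts the indices first. A
// minimal sketch of the expected lowering, reusing the shapes of test_gather above and
// purely illustrative value names:
//   %idx32 = "tosa.cast"(%arg1) : (tensor<7x7xi64>) -> tensor<7x7xi32>
//   %vals  = "tosa.reshape"(%arg0) {new_shape = array<i64: 1, 13, 63>} : (tensor<13x21x3xf32>) -> tensor<1x13x63xf32>
//   %idx   = "tosa.reshape"(%idx32) {new_shape = array<i64: 1, 49>} : (tensor<7x7xi32>) -> tensor<1x49xi32>
//   %gath  = "tosa.gather"(%vals, %idx) : (tensor<1x13x63xf32>, tensor<1x49xi32>) -> tensor<1x49x63xf32>
//   %out   = "tosa.reshape"(%gath) {new_shape = array<i64: 7, 7, 21, 3>} : (tensor<1x49x63xf32>) -> tensor<7x7x21x3xf32>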
+// CHECK-LABEL: test_gather_cast +// CHECK-DAG: %[[VAR1:.*]] = "tosa.cast"(%arg1) +// CHECK-DAG: %[[VAR2:.*]] = "tosa.reshape"(%arg0) {new_shape = array} +// CHECK-DAG: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR1]]) {new_shape = array} +// CHECK-DAG: %[[VAR4:.*]] = "tosa.gather"(%[[VAR2]], %[[VAR3]]) +// CHECK-DAG: %[[VAR5:.*]] = "tosa.reshape"(%[[VAR4]]) {new_shape = array} +// CHECK: return %[[VAR5]] +func.func @test_gather_cast(%arg0: tensor<13x21x3xf32>, %arg1: tensor<7x7xi64>) -> tensor<*xf32> { + %2 = "tfl.gather"(%arg0, %arg1) {axis = 0 : i32} : (tensor<13x21x3xf32>, tensor<7x7xi64>) -> tensor<*xf32> + func.return %2 : tensor<*xf32> +} + // ----- // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<{{\[\[}}48, 1]]> : tensor<1x2xi32>} @@ -1939,10 +2054,10 @@ func.func @test_gather_nd(%arg0: tensor<13x21x3xf32>, %arg1: tensor<6x7x2xi32>) // CHECK-DAG: %[[VAR2:.*]] = "tosa.cast"(%arg0) // CHECK-DAG: %[[VAR4:.*]] = "tosa.mul"(%[[VAR2]], %[[VAR0]]) {shift = 0 : i32} // CHECK-DAG: %[[VAR5:.*]] = "tosa.reduce_sum"(%[[VAR4]]) {axis = 1 : i64} -// CHECK-DAG: %[[VAR6:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, -1, 1]} -// CHECK-DAG: %[[VAR7:.*]] = "tosa.reshape"(%[[VAR5]]) {new_shape = [1, -1]} +// CHECK-DAG: %[[VAR6:.*]] = "tosa.reshape"(%arg1) {new_shape = array} +// CHECK-DAG: %[[VAR7:.*]] = "tosa.reshape"(%[[VAR5]]) {new_shape = array} // CHECK-DAG: %[[VAR8:.*]] = "tosa.scatter"(%[[VAR1]], %[[VAR7]], %[[VAR6]]) -// CHECK-DAG: %[[VAR9:.*]] = "tosa.reshape"(%[[VAR8]]) {new_shape = [1, 48]} +// CHECK-DAG: %[[VAR9:.*]] = "tosa.reshape"(%[[VAR8]]) {new_shape = array} // CHECK: return %[[VAR9]] func.func @sparse_to_dense(%arg0 : tensor, %arg1 : tensor) -> (tensor<1x48xi64>) { %0 = arith.constant dense<[1, 48]> : tensor<2xi64> @@ -1963,6 +2078,16 @@ func.func @test_arg_max(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // ----- +// CHECK-LABEL: @test_arg_max_negative_dim +func.func @test_arg_max_negative_dim(%arg0: tensor<13x21x3xf32>) -> tensor<13x21xf32> { + // CHECK: %[[ARGMAX:.+]] = "tosa.argmax"(%arg0) {axis = 2 : i64} + %0 = "tfl.pseudo_const"() {value = dense<-1> : tensor} : () -> tensor + %1 = "tfl.arg_max"(%arg0, %0) : (tensor<13x21x3xf32>, tensor) -> tensor<13x21xf32> + func.return %1 : tensor<13x21xf32> +} + +// ----- + // CHECK-LABEL: test_fakequant // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() {value = dense<-2.00003052> : tensor<1x1x1xf32>} // CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() {value = dense<1.99996948> : tensor<1x1x1xf32>} @@ -2065,33 +2190,94 @@ func.func @test_gelu_qi8(%arg0: tensor<1x4x4x4x!quant.uniform -// CHECK: %[[VAL_1:.*]] = "tosa.slice"(%[[VAL_0]]) {size = [2, 9], start = [1, 0]} : (tensor<4x9xi32>) -// CHECK: %[[VAL_2:.*]] = "tosa.reverse"(%[[VAL_1]]) {axis = 0 : i64} : (tensor<2x9xi32>) -// CHECK: %[[VAL_3:.*]] = "tosa.slice"(%[[VAL_0]]) {size = [1, 9], start = [2, 0]} : (tensor<4x9xi32>) -// CHECK: %[[VAL_4:.*]] = "tosa.concat"(%[[VAL_2]], %[[VAL_0]], %[[VAL_3]]) {axis = 0 : i64} : (tensor<2x9xi32>, tensor<4x9xi32>, tensor<1x9xi32>) -// CHECK: %[[VAL_5:.*]] = "tosa.slice"(%[[VAL_4]]) {size = [7, 2], start = [0, 1]} : (tensor<7x9xi32>) -// CHECK: %[[VAL_6:.*]] = "tosa.reverse"(%[[VAL_5]]) {axis = 1 : i64} : (tensor<7x2xi32>) -// CHECK: %[[VAL_7:.*]] = "tosa.slice"(%[[VAL_4]]) {size = [7, 1], start = [0, 7]} : (tensor<7x9xi32>) -// CHECK: %[[VAL_8:.*]] = "tosa.concat"(%[[VAL_6]], %[[VAL_4]], %[[VAL_7]]) {axis = 1 : i64} : (tensor<7x2xi32>, tensor<7x9xi32>, tensor<7x1xi32>) -func.func @mirrorpad_reflect(%arg0: tensor<4x9xi32>) -> tensor<7x12xi32> { +// 
CHECK-SAME: %[[VAL_0:.*]]: tensor<4x9x!quant.uniform> +// CHECK: %[[VAL_1:.*]] = "tosa.slice"(%[[VAL_0]]) {size = array, start = array} : (tensor<4x9x!quant.uniform>) +// CHECK: %[[VAL_2:.*]] = "tosa.reverse"(%[[VAL_1]]) {axis = 0 : i64} : (tensor<2x9x!quant.uniform>) +// CHECK: %[[VAL_3:.*]] = "tosa.slice"(%[[VAL_0]]) {size = array, start = array} : (tensor<4x9x!quant.uniform>) +// CHECK: %[[VAL_4:.*]] = "tosa.concat"(%[[VAL_2]], %[[VAL_0]], %[[VAL_3]]) {axis = 0 : i64} : (tensor<2x9x!quant.uniform>, tensor<4x9x!quant.uniform>, tensor<1x9x!quant.uniform>) +// CHECK: %[[VAL_5:.*]] = "tosa.slice"(%[[VAL_4]]) {size = array, start = array} : (tensor<7x9x!quant.uniform>) +// CHECK: %[[VAL_6:.*]] = "tosa.reverse"(%[[VAL_5]]) {axis = 1 : i64} : (tensor<7x2x!quant.uniform>) +// CHECK: %[[VAL_7:.*]] = "tosa.slice"(%[[VAL_4]]) {size = array, start = array} : (tensor<7x9x!quant.uniform>) +// CHECK: %[[VAL_8:.*]] = "tosa.concat"(%[[VAL_6]], %[[VAL_4]], %[[VAL_7]]) {axis = 1 : i64} : (tensor<7x2x!quant.uniform>, tensor<7x9x!quant.uniform>, tensor<7x1x!quant.uniform>) +func.func @mirrorpad_reflect(%arg0: tensor<4x9x!quant.uniform>) -> tensor<7x12x!quant.uniform> { %0 = "tfl.pseudo_const"() {value = dense<[[2, 1], [2, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32> - %1 = "tfl.mirror_pad"(%arg0, %0) {mode = #tfl} : (tensor<4x9xi32>, tensor<2x2xi32>) -> tensor<7x12xi32> - return %1 : tensor<7x12xi32> + %1 = "tfl.mirror_pad"(%arg0, %0) {mode = #tfl} : (tensor<4x9x!quant.uniform>, tensor<2x2xi32>) -> tensor<7x12x!quant.uniform> + return %1 : tensor<7x12x!quant.uniform> } // ----- // CHECK-LABEL: mirrorpad_symmetric // CHECK-SAME: %[[VAL_0:.*]]: tensor<15x23x2xf32> -// CHECK: %[[VAL_1:.*]] = "tosa.slice"(%[[VAL_0]]) {size = [1, 23, 2], start = [0, 0, 0]} : (tensor<15x23x2xf32>) +// CHECK: %[[VAL_1:.*]] = "tosa.slice"(%[[VAL_0]]) {size = array, start = array} : (tensor<15x23x2xf32>) // CHECK: %[[VAL_2:.*]] = "tosa.concat"(%[[VAL_1]], %[[VAL_0]]) {axis = 0 : i64} : (tensor<1x23x2xf32>, tensor<15x23x2xf32>) -// CHECK: %[[VAL_3:.*]] = "tosa.slice"(%[[VAL_2]]) {size = [16, 1, 2], start = [0, 0, 0]} : (tensor<16x23x2xf32>) +// CHECK: %[[VAL_3:.*]] = "tosa.slice"(%[[VAL_2]]) {size = array, start = array} : (tensor<16x23x2xf32>) // CHECK: %[[VAL_4:.*]] = "tosa.concat"(%[[VAL_3]], %[[VAL_2]]) {axis = 1 : i64} : (tensor<16x1x2xf32>, tensor<16x23x2xf32>) -// CHECK: %[[VAL_5:.*]] = "tosa.slice"(%[[VAL_4]]) {size = [16, 24, 1], start = [0, 0, 0]} : (tensor<16x24x2xf32>) +// CHECK: %[[VAL_5:.*]] = "tosa.slice"(%[[VAL_4]]) {size = array, start = array} : (tensor<16x24x2xf32>) // CHECK: %[[VAL_6:.*]] = "tosa.concat"(%[[VAL_5]], %[[VAL_4]]) {axis = 2 : i64} : (tensor<16x24x1xf32>, tensor<16x24x2xf32>) func.func @mirrorpad_symmetric(%arg0: tensor<15x23x2xf32>) -> tensor<16x24x3xf32> { %0 = "tfl.pseudo_const"() {value = dense<[[1, 0], [1, 0], [1, 0]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> %1 = "tfl.mirror_pad"(%arg0, %0) {mode = #tfl} : (tensor<15x23x2xf32>, tensor<3x2xi32>) -> tensor<16x24x3xf32> return %1 : tensor<16x24x3xf32> } + +// ----- + +// CHECK-LABEL: test_tfl_custom +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x64x64x32xf32> +// CHECK: %[[VAL_0:.*]] = "tosa.custom"(%[[ARG_0]]) {config = "TFL", identifier = "MaxPoolingWithArgmax2D", implementation_attrs = "{{.*}}"} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) +func.func @test_tfl_custom(%arg0: tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) { + // custom op for "tfl.max_pooling_with_argmax_2d"(%arg0) 
{filter_h = 2 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) + %0, %1 = "tfl.custom"(%arg0) {custom_option = #tfl, custom_code = "MaxPoolingWithArgmax2D"} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) + func.return %0, %1 : tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32> +} + +// ----- +// CHECK-LABEL: test_tfl_while_loop +// CHECK: %[[VAL_0:.*]]: tensor<1x4x4x4xf32> {tf_saved_model.index_path = ["placeholder_0"]}) -> (tensor<1x4x4x4xf32> {tf_saved_model.index_path = ["output_0"]}) { +// CHECK: %[[VAL_1:.*]] = "tosa.const"() {value = dense<2.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32> +// CHECK: %[[VAL_2:.*]] = "tosa.while_loop"(%[[VAL_0]]) ({ +// CHECK: ^bb0(%[[VAL_3:.*]]: tensor<1x4x4x4xf32>): +// CHECK: %[[VAL_4:.*]] = "tosa.reduce_sum"(%[[VAL_3]]) {axis = 1 : i64} : (tensor<1x4x4x4xf32>) -> tensor<1x1x4x4xf32> +// CHECK: %[[VAL_5:.*]] = "tosa.reduce_sum"(%[[VAL_4]]) {axis = 2 : i64} : (tensor<1x1x4x4xf32>) -> tensor<1x1x1x4xf32> +// CHECK: %[[VAL_6:.*]] = "tosa.reduce_sum"(%[[VAL_5]]) {axis = 3 : i64} : (tensor<1x1x1x4xf32>) -> tensor<1x1x1x1xf32> +// CHECK: %[[VAL_7:.*]] = "tosa.reshape"(%[[VAL_6]]) {new_shape = array} : (tensor<1x1x1x1xf32>) -> tensor<1xf32> +// CHECK: %[[VAL_8:.*]] = "tosa.greater"(%[[VAL_1]], %[[VAL_7]]) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1> +// CHECK: %[[VAL_9:.*]] = "tosa.reshape"(%[[VAL_8]]) {new_shape = array} : (tensor<1xi1>) -> tensor +// CHECK: "tosa.yield"(%[[VAL_9]]) : (tensor) -> () +// CHECK: }, { +// CHECK: ^bb0(%[[VAL_10:.*]]: tensor<1x4x4x4xf32>): +// CHECK: %[[VAL_11:.*]] = "tosa.sigmoid"(%[[VAL_10]]) : (tensor<1x4x4x4xf32>) -> tensor<1x4x4x4xf32> +// CHECK: %[[VAL_12:.*]] = "tosa.add"(%[[VAL_10]], %[[VAL_11]]) : (tensor<1x4x4x4xf32>, tensor<1x4x4x4xf32>) -> tensor<1x4x4x4xf32> +// CHECK: "tosa.yield"(%[[VAL_12]]) : (tensor<1x4x4x4xf32>) -> () +// CHECK: }) : (tensor<1x4x4x4xf32>) -> tensor<1x4x4x4xf32> +// CHECK: return %[[VAL_13:.*]] : tensor<1x4x4x4xf32> +// CHECK: } +func.func @test_tfl_while_loop(%arg0: tensor<1x4x4x4xf32> {tf_saved_model.index_path = ["placeholder_0"]}) -> (tensor<1x4x4x4xf32> {tf_saved_model.index_path = ["output_0"]}) { + %0 = "tfl.while"(%arg0) ({ + ^bb0(%arg1: tensor<1x4x4x4xf32>): + %1 = func.call @result_cond(%arg1) : (tensor<1x4x4x4xf32>) -> tensor + "tfl.yield"(%1) : (tensor) -> () + }, { + ^bb0(%arg1: tensor<1x4x4x4xf32>): + %1 = func.call @result_body(%arg1) : (tensor<1x4x4x4xf32>) -> tensor<1x4x4x4xf32> + "tfl.yield"(%1) : (tensor<1x4x4x4xf32>) -> () + }) : (tensor<1x4x4x4xf32>) -> tensor<1x4x4x4xf32> + func.return %0 : tensor<1x4x4x4xf32> +} +func.func private @result_cond(%arg0: tensor<1x4x4x4xf32>) -> tensor { + %0 = "tfl.pseudo_const"() {value = dense<[0, 1, 2, 3]> : tensor<4xi32>} : () -> tensor<4xi32> + %1 = "tfl.sum"(%arg0, %0) {keep_dims = false} : (tensor<1x4x4x4xf32>, tensor<4xi32>) -> tensor + %2 = "tfl.pseudo_const"() {value = dense<2.000000e+00> : tensor<1xf32>} : () -> tensor<1xf32> + %3 = tfl.less(%1, %2) : (tensor, tensor<1xf32>) -> tensor<1xi1> + %4 = "tfl.pseudo_const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> + %5 = "tfl.reshape"(%3, %4) : (tensor<1xi1>, tensor<0xi32>) -> tensor + func.return %5 : tensor +} +func.func private @result_body(%arg0: tensor<1x4x4x4xf32>) -> tensor<1x4x4x4xf32> { + %0 = "tfl.logistic"(%arg0) : (tensor<1x4x4x4xf32>) -> tensor<1x4x4x4xf32> + %1 = tfl.add %arg0, %0 {fused_activation_function = 
"NONE"} : tensor<1x4x4x4xf32> + func.return %1 : tensor<1x4x4x4xf32> +} + + diff --git a/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc b/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc index 15031a8620d..62058cbe799 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc @@ -195,9 +195,9 @@ LogicalResult convert_graph_uint8_tensor(mlir::MLIRContext &context, function.getLoc(), int8_type, arg, builder.getI32IntegerAttr(uint8_zp), builder.getI32IntegerAttr(int8_zp), - builder.getI32ArrayAttr({1 << 30}), builder.getI32ArrayAttr({30}), - builder.getBoolAttr(true), builder.getBoolAttr(false), - builder.getBoolAttr(false)); + builder.getDenseI32ArrayAttr({1 << 30}), + builder.getDenseI32ArrayAttr({30}), builder.getBoolAttr(true), + builder.getBoolAttr(false), builder.getBoolAttr(false)); Operation *op_rescale_op = static_cast(rescale_op); bb.push_front(op_rescale_op); @@ -313,9 +313,9 @@ LogicalResult convert_graph_uint8_tensor(mlir::MLIRContext &context, function.getLoc(), uint8_output_type, input_val, builder.getI32IntegerAttr(int8_zp), builder.getI32IntegerAttr(uint8_zp), - builder.getI32ArrayAttr({1 << 30}), builder.getI32ArrayAttr({30}), - builder.getBoolAttr(true), builder.getBoolAttr(false), - builder.getBoolAttr(false)); + builder.getDenseI32ArrayAttr({1 << 30}), + builder.getDenseI32ArrayAttr({30}), builder.getBoolAttr(true), + builder.getBoolAttr(false), builder.getBoolAttr(false)); Operation *op_rescale_op = static_cast(rescale_op); bb.push_back(op_rescale_op); diff --git a/tensorflow/compiler/mlir/tosa/transforms/fuse_bias_tf.cc b/tensorflow/compiler/mlir/tosa/transforms/fuse_bias_tf.cc index 5b7261c4cdc..ae3a07d324a 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/fuse_bias_tf.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/fuse_bias_tf.cc @@ -75,8 +75,8 @@ LogicalResult ConvertTFBiasAddOp::matchAndRewrite( return rewriter.notifyMatchFailure(op, "output not a ranked tensor"); } - auto value = tf_biasadd_op.value(); - auto bias = tf_biasadd_op.bias(); + auto value = tf_biasadd_op.getValue(); + auto bias = tf_biasadd_op.getBias(); auto bias_shape = bias.getType().cast().getShape(); if (bias_shape.size() != 1) { return rewriter.notifyMatchFailure(op, "bias tensor must be rank 1"); @@ -86,7 +86,7 @@ LogicalResult ConvertTFBiasAddOp::matchAndRewrite( llvm::dyn_cast_if_present(value.getDefiningOp())) { // Sanity check to confirm rhs() has the expected shape of bias auto filter_shape = - tf_conv2d_op.filter().getType().cast().getShape(); + tf_conv2d_op.getFilter().getType().cast().getShape(); // Assume the filter shape is [H, W, I, O] if (filter_shape.back() != bias_shape.back()) { @@ -95,10 +95,10 @@ LogicalResult ConvertTFBiasAddOp::matchAndRewrite( } auto result = convertTFConv2DCommon( - rewriter, op, output_type, tf_conv2d_op.input(), tf_conv2d_op.filter(), - bias, tf_conv2d_op.strides(), tf_conv2d_op.dilations(), - tf_conv2d_op.explicit_paddings(), tf_conv2d_op.padding(), - tf_conv2d_op.data_format()); + rewriter, op, output_type, tf_conv2d_op.getInput(), + tf_conv2d_op.getFilter(), bias, tf_conv2d_op.getStrides(), + tf_conv2d_op.getDilations(), tf_conv2d_op.getExplicitPaddings(), + tf_conv2d_op.getPadding(), tf_conv2d_op.getDataFormat()); if (!result) return failure(); @@ -111,7 +111,7 @@ LogicalResult ConvertTFBiasAddOp::matchAndRewrite( llvm::dyn_cast_if_present(value.getDefiningOp())) { // Sanity check to confirm rhs() has the expected shape of bias auto 
filter_shape = - tf_conv3d_op.filter().getType().cast().getShape(); + tf_conv3d_op.getFilter().getType().cast().getShape(); // Assume the filter shape is [D, H, W, I, O] if (filter_shape.back() != bias_shape.back()) { @@ -120,9 +120,10 @@ LogicalResult ConvertTFBiasAddOp::matchAndRewrite( } llvm::Optional result = convertTFConv3DCommon( - rewriter, op, output_type, tf_conv3d_op.input(), tf_conv3d_op.filter(), - bias, tf_conv3d_op.strides(), tf_conv3d_op.dilations(), - tf_conv3d_op.padding(), tf_conv3d_op.data_format()); + rewriter, op, output_type, tf_conv3d_op.getInput(), + tf_conv3d_op.getFilter(), bias, tf_conv3d_op.getStrides(), + tf_conv3d_op.getDilations(), tf_conv3d_op.getPadding(), + tf_conv3d_op.getDataFormat()); if (!result) return failure(); diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc index c0bb4ebd6ba..45e9c531f0b 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc @@ -44,8 +44,8 @@ namespace mlir { namespace tosa { static int64_t multiply_dims(int64_t a, int64_t b) { - if (a == ShapedType::kDynamicSize || b == ShapedType::kDynamicSize) { - return ShapedType::kDynamicSize; + if (a == ShapedType::kDynamic || b == ShapedType::kDynamic) { + return ShapedType::kDynamic; } return a * b; } @@ -53,7 +53,7 @@ static int64_t multiply_dims(int64_t a, int64_t b) { static int64_t multiply_dims(llvm::ArrayRef dims, int64_t res = 1) { for (auto dim : dims) { if (ShapedType::isDynamic(dim)) { - return ShapedType::kDynamicSize; + return ShapedType::kDynamic; } res = res * dim; } @@ -186,8 +186,8 @@ llvm::Optional convertPackOp(PatternRewriter& rewriter, Operation* op, RankedTensorType reshape_rank1_size1_type = tensorflow::GetTypeFromTFTensorShape(reshape_rank1_size1_shape, result_type.getElementType()); - ArrayAttr shape_rank1_size1_attr = - rewriter.getI64ArrayAttr(reshape_rank1_size1_shape); + DenseI64ArrayAttr shape_rank1_size1_attr = rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF(reshape_rank1_size1_shape)); for (int i = 0; i < inputs.size(); i++) { auto a0_reshape_op = CreateOpAndInfer( rewriter, op->getLoc(), reshape_rank1_size1_type, inputs[i], @@ -254,7 +254,8 @@ llvm::Optional convertPackOp(PatternRewriter& rewriter, Operation* op, output_shape_vals.end()); } IntegerAttr concat_axis_attr = rewriter.getI64IntegerAttr(concat_axis); - ArrayAttr shape_attr = rewriter.getI64ArrayAttr(reshape_output_shape); + DenseI64ArrayAttr shape_attr = rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF(reshape_output_shape)); // Concat output shape will depend on concat_axis. E.g. 
[N * A, B, C] SmallVector concat_output_shape; @@ -300,7 +301,7 @@ llvm::Optional convertPackOp(PatternRewriter& rewriter, Operation* op, return CreateOpAndInfer( rewriter, op->getLoc(), result_type, a2_reshape_op.getResult(), - a3_transpose_perm.getValue()) + a3_transpose_perm.value()) .getResult(); } @@ -353,7 +354,7 @@ llvm::Optional> convertUnpackOp(PatternRewriter& rewriter, rewriter, op->getLoc(), tensorflow::GetTypeFromTFTensorShape(a1_transpose_shape, input_type.getElementType()), - input_value, a1_transpose_perm.getValue()); + input_value, a1_transpose_perm.value()); transposed_input_value = a1_transpose_op.getResult(); } else { @@ -383,8 +384,9 @@ llvm::Optional> convertUnpackOp(PatternRewriter& rewriter, } } - ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals); - ArrayAttr size = rewriter.getI64ArrayAttr(size_vals); + DenseI64ArrayAttr begin = rewriter.getDenseI64ArrayAttr(begin_vals); + DenseI64ArrayAttr size = rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF(size_vals)); auto a2_slice_op = CreateOpAndInfer( rewriter, op->getLoc(), @@ -396,7 +398,9 @@ llvm::Optional> convertUnpackOp(PatternRewriter& rewriter, rewriter, op->getLoc(), tensorflow::GetTypeFromTFTensorShape( shape_vals, transposed_input_type.getElementType()), - a2_slice_op.getResult(), rewriter.getI64ArrayAttr(shape_vals)); + a2_slice_op.getResult(), + rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF(shape_vals))); results_vec.push_back(a3_reshape_op.getResult()); } @@ -446,7 +450,9 @@ llvm::Optional convertSelectOp(PatternRewriter& rewriter, Operation* op, rewriter, op->getLoc(), tensorflow::GetTypeFromTFTensorShape(new_cond_dims, condition_type.getElementType()), - condition_value, rewriter.getI64ArrayAttr(new_cond_dims)); + condition_value, + rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF(new_cond_dims))); return CreateOpAndInfer(rewriter, op->getLoc(), result_type, reshape_op, x_value, y_value) @@ -600,7 +606,7 @@ llvm::Optional convertRoundOp(PatternRewriter& rewriter, Operation* op, // Lowers ConcatV2 to TOSA Concat. 
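// [Editor's note, not part of the patch: convertConcatV2Op now takes the result as a
//  ShapedType rather than a placeholder Value, and reads the rank off the first operand
//  instead of the result. This lets callers such as convertMirrorPadCommon (changed
//  further down) request an unranked result directly instead of materializing a zero
//  constant of the expected concat shape.]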
llvm::Optional convertConcatV2Op(PatternRewriter& rewriter, - Operation* op, Value result_value, + Operation* op, ShapedType result_type, SmallVectorImpl& values, int32_t axis) { // Check all inputs are RankedTensorType @@ -611,15 +617,6 @@ llvm::Optional convertConcatV2Op(PatternRewriter& rewriter, } } - // Check output is Ranked tensor type - if (!result_value.getType().dyn_cast()) { - (void)rewriter.notifyMatchFailure(op, - "output value type not ranked tensor"); - return llvm::None; - } - - RankedTensorType result_type = - result_value.getType().dyn_cast(); mlir::quant::UniformQuantizedType result_quant_type = result_type.getElementType() .dyn_cast_or_null(); @@ -660,7 +657,7 @@ llvm::Optional convertConcatV2Op(PatternRewriter& rewriter, } } - int32_t tensor_rank = result_type.getShape().size(); + int32_t tensor_rank = values[0].getType().cast().getRank(); if (axis < 0) axis += tensor_rank; if ((axis < 0) || (axis > tensor_rank)) { @@ -669,7 +666,7 @@ llvm::Optional convertConcatV2Op(PatternRewriter& rewriter, } auto concat_op = CreateOpAndInfer( - rewriter, op->getLoc(), result_value.getType(), values_rescaled, + rewriter, op->getLoc(), result_type, values_rescaled, rewriter.getI64IntegerAttr(axis)); return concat_op.getResult(); @@ -772,11 +769,11 @@ llvm::Optional convertSpaceToBatchNDOp(PatternRewriter& rewriter, auto input_shape = input_type.getShape(); int block_rank = block_shape[0]; - int batch_size = input_shape[0]; + int64_t batch_size = input_shape[0]; int input_rank = input_type.getRank(); int remaining_shape_rank = input_rank - block_rank - 1; - int block_num_elems = 1; - int padding_sum = 0; + int64_t block_num_elems = 1; + int64_t padding_sum = 0; ElementsAttr block_shape_elems; ElementsAttr paddings_elems; @@ -837,7 +834,7 @@ llvm::Optional convertSpaceToBatchNDOp(PatternRewriter& rewriter, auto a0_pad_const_op = rewriter.create( op->getLoc(), a0_pad_const_attr_type, DenseElementsAttr::get(a0_pad_const_attr_type, - llvm::makeArrayRef(a0_pad_const))); + llvm::ArrayRef(a0_pad_const))); auto a1_pad_input_op = CreateOpAndInfer( rewriter, op->getLoc(), @@ -860,7 +857,7 @@ llvm::Optional convertSpaceToBatchNDOp(PatternRewriter& rewriter, int32_t block_shape_val = block_shape_elems.getValues()[i].getInt(); a2_shape[1 + i * 2 + 0] = padded_shape[1 + i]; - if (a2_shape[1 + i * 2 + 0] != ShapedType::kDynamicSize) { + if (a2_shape[1 + i * 2 + 0] != ShapedType::kDynamic) { a2_shape[1 + i * 2 + 0] /= block_shape_val; } @@ -875,10 +872,10 @@ llvm::Optional convertSpaceToBatchNDOp(PatternRewriter& rewriter, auto a2_reshape_a1_op = CreateOpAndInfer( rewriter, op->getLoc(), - tensorflow::GetTypeFromTFTensorShape(a2_shape, - result_type.getElementType()), + RankedTensorType::get(a2_shape, result_type.getElementType()), a1_pad_input_op.getResult(), - rewriter.getI64ArrayAttr(tensorflow::ConvertTFShapeToMlir(a2_shape))); + rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF((a2_shape)))); // 3. Transpose dimensions to: // block-shape + @@ -914,7 +911,7 @@ llvm::Optional convertSpaceToBatchNDOp(PatternRewriter& rewriter, rewriter, op->getLoc(), tensorflow::GetTypeFromTFTensorShape(a3_transpose_shape, result_type.getElementType()), - a2_reshape_a1_op.getResult(), a3_transpose_const.getValue()); + a2_reshape_a1_op.getResult(), a3_transpose_const.value()); // 4. 
Reshape the transposed tensor to flatten block_shape // into the batch dimension with the following shape: @@ -933,7 +930,7 @@ llvm::Optional convertSpaceToBatchNDOp(PatternRewriter& rewriter, int32_t block_shape_val = block_shape_elems.getValues()[i].getInt(); a4_reshape_shape[i + 1] = padded_shape[i + 1]; - if (a4_reshape_shape[i + 1] != ShapedType::kDynamicSize) { + if (a4_reshape_shape[i + 1] != ShapedType::kDynamic) { a4_reshape_shape[i + 1] /= block_shape_val; } } @@ -946,7 +943,8 @@ llvm::Optional convertSpaceToBatchNDOp(PatternRewriter& rewriter, return CreateOpAndInfer( rewriter, op->getLoc(), result_type, a3_transpose_a2_op.getResult(), - rewriter.getI64ArrayAttr(a4_reshape_shape)) + rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF(a4_reshape_shape))) .getResult(); } @@ -1031,7 +1029,7 @@ llvm::Optional convertBatchToSpaceNDOp(PatternRewriter& rewriter, // Another 4-step process int block_rank = block_shape_type.getShape()[0]; int input_rank = input_type.getRank(); - int crops_dims = crops_type.getShape()[0]; + int64_t crops_dims = crops_type.getShape()[0]; int remaining_shape_rank = input_rank - block_rank - 1; auto input_shape = input_type.getShape(); @@ -1052,9 +1050,9 @@ llvm::Optional convertBatchToSpaceNDOp(PatternRewriter& rewriter, SmallVector> crops(crops_dims); // Extract values for block_shape and crops now. - int block_num_elems = 1; + int64_t block_num_elems = 1; for (int i = 0; i < block_rank; i++) { - int block_shape_val = + int64_t block_shape_val = rewriter .getI32IntegerAttr( block_shape_elems.getValues()[i].getInt()) @@ -1089,8 +1087,8 @@ llvm::Optional convertBatchToSpaceNDOp(PatternRewriter& rewriter, for (int i = 0; i < block_rank; i++) a1_shape[i] = block_shape[i]; - a1_shape[block_rank] = (input_shape[0] == ShapedType::kDynamicSize) - ? ShapedType::kDynamicSize + a1_shape[block_rank] = (input_shape[0] == ShapedType::kDynamic) + ? ShapedType::kDynamic : input_shape[0] / block_num_elems; for (int i = 0; i < input_rank - 1; i++) @@ -1100,7 +1098,9 @@ llvm::Optional convertBatchToSpaceNDOp(PatternRewriter& rewriter, rewriter, op->getLoc(), tensorflow::GetTypeFromTFTensorShape(a1_shape, result_type.getElementType()), - input_value, rewriter.getI64ArrayAttr(a1_shape)); + input_value, + rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF(a1_shape))); // 2. Permute to shape // [ batch / prod(block_shape) ], @@ -1137,7 +1137,7 @@ llvm::Optional convertBatchToSpaceNDOp(PatternRewriter& rewriter, rewriter, op->getLoc(), tensorflow::GetTypeFromTFTensorShape(a2_transpose_shape, result_type.getElementType()), - a1_reshape_input_op.getResult(), a2_transpose_perm.getValue()); + a1_reshape_input_op.getResult(), a2_transpose_perm.value()); // Step 3. Reshape to: // [ batch / prod(block_shape) ], @@ -1148,7 +1148,7 @@ llvm::Optional convertBatchToSpaceNDOp(PatternRewriter& rewriter, SmallVector a4_shape(input_rank); a4_shape[0] = input_shape[0]; - if (a4_shape[0] != ShapedType::kDynamicSize) { + if (a4_shape[0] != ShapedType::kDynamic) { a4_shape[0] /= block_num_elems; } for (int i = 0; i < block_rank; i++) { @@ -1162,7 +1162,9 @@ llvm::Optional convertBatchToSpaceNDOp(PatternRewriter& rewriter, rewriter, op->getLoc(), tensorflow::GetTypeFromTFTensorShape(a4_shape, result_type.getElementType()), - a2_transpose_a1_op.getResult(), rewriter.getI64ArrayAttr(a4_shape)); + a2_transpose_a1_op.getResult(), + rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF(a4_shape))); // 4. 
Crop the start/end dimensions on 'spatial dimension' according to // crops @@ -1191,8 +1193,10 @@ llvm::Optional convertBatchToSpaceNDOp(PatternRewriter& rewriter, rewriter, op->getLoc(), tensorflow::GetTypeFromTFTensorShape(a4_size_vals, result_type.getElementType()), - a3_reshape_a2.getResult(), rewriter.getI64ArrayAttr(a4_begin_vals), - rewriter.getI64ArrayAttr(a4_size_vals)) + a3_reshape_a2.getResult(), + rewriter.getDenseI64ArrayAttr(a4_begin_vals), + rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF(a4_size_vals))) .getResult(); } @@ -1252,7 +1256,8 @@ llvm::Optional convertExpandDimsOp(PatternRewriter& rewriter, } } - ArrayAttr shape_attr = rewriter.getI64ArrayAttr(reshape_dims); + DenseI64ArrayAttr shape_attr = rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF(reshape_dims)); return CreateOpAndInfer(rewriter, op->getLoc(), output_type, input_value, shape_attr) @@ -1315,7 +1320,8 @@ llvm::Optional convertSqueezeOp(PatternRewriter& rewriter, Operation* op, } } - ArrayAttr shape_attr = rewriter.getI64ArrayAttr(reshape_dims); + DenseI64ArrayAttr shape_attr = rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF(reshape_dims)); return CreateOpAndInfer(rewriter, op->getLoc(), output_type, input_value, shape_attr) @@ -1915,7 +1921,8 @@ llvm::Optional convertSpaceToDepthOp(PatternRewriter& rewriter, a_reshape_dims, output_type.getElementType()); auto a2_reshape_a_op = CreateOpAndInfer( rewriter, op->getLoc(), a_reshape_output_type, input_value, - rewriter.getI64ArrayAttr(a_reshape_dims)); + rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF(a_reshape_dims))); llvm::Optional a3_transpose_perm = getConstTensor( rewriter, op, /*vec=*/{0, 1, 3, 2, 4, 5}, /*shape=*/{6}); @@ -1924,7 +1931,7 @@ llvm::Optional convertSpaceToDepthOp(PatternRewriter& rewriter, auto a3_transpose_a2_op = CreateOpAndInfer( rewriter, op->getLoc(), a_reshape_output_type, - a2_reshape_a_op.getResult(), a3_transpose_perm.getValue()); + a2_reshape_a_op.getResult(), a3_transpose_perm.value()); SmallVector a3_reshape_dims; a3_reshape_dims.push_back(input_shape[0]); @@ -1938,7 +1945,8 @@ llvm::Optional convertSpaceToDepthOp(PatternRewriter& rewriter, return CreateOpAndInfer( rewriter, op->getLoc(), a3_reshape_output_type, a3_transpose_a2_op.getResult(), - rewriter.getI64ArrayAttr(a3_reshape_dims)) + rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF(a3_reshape_dims))) .getResult(); } @@ -2003,7 +2011,8 @@ llvm::Optional convertDepthToSpaceOp(PatternRewriter& rewriter, a_reshape_dims, output_type.getElementType()); auto a2_reshape_a_op = CreateOpAndInfer( rewriter, op->getLoc(), a_reshape_output_type, input_value, - rewriter.getI64ArrayAttr(a_reshape_dims)); + rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF(a_reshape_dims))); llvm::Optional a3_transpose_perm = getConstTensor( rewriter, op, /*vec=*/{0, 1, 3, 2, 4, 5}, /*shape=*/{6}); @@ -2012,7 +2021,7 @@ llvm::Optional convertDepthToSpaceOp(PatternRewriter& rewriter, auto a3_transpose_a2_op = CreateOpAndInfer( rewriter, op->getLoc(), a_reshape_output_type, - a2_reshape_a_op.getResult(), a3_transpose_perm.getValue()); + a2_reshape_a_op.getResult(), a3_transpose_perm.value()); SmallVector a3_reshape_dims; a3_reshape_dims.push_back(input_shape[0]); @@ -2026,7 +2035,8 @@ llvm::Optional convertDepthToSpaceOp(PatternRewriter& rewriter, return CreateOpAndInfer( rewriter, op->getLoc(), a3_reshape_output_type, a3_transpose_a2_op.getResult(), - rewriter.getI64ArrayAttr(a3_reshape_dims)) + 
rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF(a3_reshape_dims))) .getResult(); } @@ -2074,7 +2084,8 @@ llvm::Optional> convertSplitOp( slice_value = CreateOpAndInfer( rewriter, op->getLoc(), tensorflow::GetTypeFromTFTensorShape(new_shape, etype), input_value, - rewriter.getI64ArrayAttr(new_shape)); + rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF(new_shape))); } RankedTensorType slice_type = slice_value.getType().cast(); @@ -2094,8 +2105,9 @@ llvm::Optional> convertSplitOp( SmallVector results_vec; for (int i = 0; i < num_split; i++) { begin_vals[axis] = i * size_vals[axis]; - ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals); - ArrayAttr size = rewriter.getI64ArrayAttr(size_vals); + DenseI64ArrayAttr begin = rewriter.getDenseI64ArrayAttr(begin_vals); + DenseI64ArrayAttr size = rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF(size_vals)); Value result = CreateOpAndInfer( rewriter, op->getLoc(), @@ -2112,7 +2124,9 @@ llvm::Optional> convertSplitOp( CreateOpAndInfer( rewriter, op->getLoc(), tensorflow::GetTypeFromTFTensorShape(out_reshape_shape, etype), - result, rewriter.getI64ArrayAttr(out_reshape_shape)) + result, + rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF(out_reshape_shape))) .getResult(); } @@ -2178,8 +2192,9 @@ llvm::Optional> convertSplitVOp( } } - ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals); - ArrayAttr size = rewriter.getI64ArrayAttr(size_vals); + DenseI64ArrayAttr begin = rewriter.getDenseI64ArrayAttr(begin_vals); + DenseI64ArrayAttr size = rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF(size_vals)); auto slice_op = CreateOpAndInfer( rewriter, op->getLoc(), @@ -2200,15 +2215,13 @@ llvm::Optional> convertSplitVOp( // the only legal negative stride. 
static Value reverseNegativeStride(PatternRewriter& rewriter, Operation* op, Value input, ArrayRef strides) { - Type reverse_ty = UnrankedTensorType::get( - input.getType().cast().getElementType()); for (auto it : llvm::enumerate(strides)) { auto axis = it.index(); auto stride = it.value(); if (stride != -1) continue; input = CreateOpAndInfer(rewriter, op->getLoc(), - reverse_ty, input, + input.getType(), input, rewriter.getI64IntegerAttr(axis)) .getResult(); } @@ -2380,8 +2393,8 @@ llvm::Optional convertStridedSliceOp( rewriter, op->getLoc(), tensorflow::GetTypeFromTFTensorShape(a1_size, input_type.getElementType()), - input_value, rewriter.getI64ArrayAttr(a1_begin), - rewriter.getI64ArrayAttr(a1_size)); + input_value, rewriter.getDenseI64ArrayAttr(a1_begin), + rewriter.getDenseI64ArrayAttr(tensorflow::ConvertMlirShapeToTF(a1_size))); if (all_strides_one) { auto reversed = @@ -2398,7 +2411,8 @@ llvm::Optional convertStridedSliceOp( return CreateOpAndInfer( rewriter, op->getLoc(), result_type, reversed, - rewriter.getI64ArrayAttr(new_shape)) + rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF(new_shape))) .getResult(); } @@ -2413,7 +2427,9 @@ llvm::Optional convertStridedSliceOp( rewriter, op->getLoc(), tensorflow::GetTypeFromTFTensorShape(a2_shape, input_type.getElementType()), - a1_slice_op.getResult(), rewriter.getI64ArrayAttr(a2_shape)); + a1_slice_op.getResult(), + rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF(a2_shape))); // Step 3: take a slice along the strides SmallVector a3_begin(input_rank * 2), a3_size(input_rank * 2); @@ -2434,8 +2450,8 @@ llvm::Optional convertStridedSliceOp( rewriter, op->getLoc(), tensorflow::GetTypeFromTFTensorShape(a3_size, input_type.getElementType()), - a2_reshape_op.getResult(), rewriter.getI64ArrayAttr(a3_begin), - rewriter.getI64ArrayAttr(a3_size)); + a2_reshape_op.getResult(), rewriter.getDenseI64ArrayAttr(a3_begin), + rewriter.getDenseI64ArrayAttr(tensorflow::ConvertMlirShapeToTF(a3_size))); // Step 4: reshape the now-strided tensor SmallVector a4_shape; @@ -2448,9 +2464,10 @@ llvm::Optional convertStridedSliceOp( } auto a4_reshape_op = - CreateOpAndInfer(rewriter, op->getLoc(), result_type, - a3_slice_op.getResult(), - rewriter.getI64ArrayAttr(a4_shape)) + CreateOpAndInfer( + rewriter, op->getLoc(), result_type, a3_slice_op.getResult(), + rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF(a4_shape))) .getResult(); return reverseNegativeStride(rewriter, op, a4_reshape_op, strides); @@ -2667,16 +2684,17 @@ static Value convertGenericReduceOp(PatternRewriter& rewriter, Operation* op, Value perms_value = getConstTensor(rewriter, op, perms, {static_cast(perms.size())}) - .getValue(); + .value(); auto transpose_op = CreateOpAndInfer( - rewriter, loc, UnrankedTensorType::get(rewriter.getI32Type()), input, - perms_value); + rewriter, loc, UnrankedTensorType::get(input_etype), input, perms_value); auto reshape_op = CreateOpAndInfer( rewriter, loc, tensorflow::GetTypeFromTFTensorShape(reshape_shape, input_etype), - transpose_op, rewriter.getI64ArrayAttr(reshape_shape)); + transpose_op, + rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF(reshape_shape))); return CreateOpAndInfer(rewriter, loc, UnrankedTensorType::get(reduce_etype), reshape_op, @@ -2756,7 +2774,8 @@ llvm::Optional convertReduceOpCommon( // Squeeze out the reduced axes. 
auto reshape_op = CreateOpAndInfer( rewriter, op->getLoc(), output_type, val, - rewriter.getI64ArrayAttr(output_shape)); + rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF(output_shape))); return reshape_op.getResult(); } @@ -2967,7 +2986,7 @@ llvm::Optional convertReduceMeanOp(PatternRewriter& rewriter, if (!input_is_qtype) { Value div_const = getTosaConstTensorSingleF32(rewriter, op, div_scale); return CreateOpAndInfer(rewriter, op->getLoc(), output_type, - val.getValue(), div_const, 0) + val.value(), div_const, 0) .getResult(); } @@ -3094,10 +3113,12 @@ llvm::Optional convertResizeOp(PatternRewriter& rewriter, Operation* op, normalize(input_width, output_width, scale_x_n, scale_x_d, offset_x, border_x); - ArrayAttr scale = - rewriter.getI64ArrayAttr({scale_y_n, scale_y_d, scale_x_n, scale_x_d}); - ArrayAttr offset = rewriter.getI64ArrayAttr({offset_y, offset_x}); - ArrayAttr border = rewriter.getI64ArrayAttr({border_y, border_x}); + DenseI64ArrayAttr scale = rewriter.getDenseI64ArrayAttr( + {scale_y_n, scale_y_d, scale_x_n, scale_x_d}); + DenseI64ArrayAttr offset = + rewriter.getDenseI64ArrayAttr({offset_y, offset_x}); + DenseI64ArrayAttr border = + rewriter.getDenseI64ArrayAttr({border_y, border_x}); StringAttr resize_mode = rewriter.getStringAttr(mode); @@ -3288,11 +3309,11 @@ llvm::Optional convertDequantizeOp(PatternRewriter& rewriter, auto op2_sub_op1 = CreateOpAndInfer(rewriter, op->getLoc(), output_type, - op1_cast_in.getResult(), zp_val.getValue()); + op1_cast_in.getResult(), zp_val.value()); return CreateOpAndInfer(rewriter, op->getLoc(), output_type, op2_sub_op1.getResult(), - scale_val.getValue(), 0) + scale_val.value(), 0) .getResult(); } @@ -3449,8 +3470,8 @@ llvm::Optional convertMirrorPadCommon(PatternRewriter& rewriter, rewriter, op->getLoc(), RankedTensorType::get(slice_before_size, output_type.getElementType()), - current_tensor, rewriter.getI64ArrayAttr(slice_before_begin), - rewriter.getI64ArrayAttr(slice_before_size)); + current_tensor, rewriter.getDenseI64ArrayAttr(slice_before_begin), + rewriter.getDenseI64ArrayAttr(slice_before_size)); // Reverse op is superfluous when the padding value is 1. if (pad_before == 1) { @@ -3472,8 +3493,8 @@ llvm::Optional convertMirrorPadCommon(PatternRewriter& rewriter, auto slice_after_op = CreateOpAndInfer( rewriter, op->getLoc(), RankedTensorType::get(slice_after_size, output_type.getElementType()), - current_tensor, rewriter.getI64ArrayAttr(slice_after_begin), - rewriter.getI64ArrayAttr(slice_after_size)); + current_tensor, rewriter.getDenseI64ArrayAttr(slice_after_begin), + rewriter.getDenseI64ArrayAttr(slice_after_size)); if (pad_after == 1) { slices.push_back(slice_after_op); @@ -3491,21 +3512,16 @@ llvm::Optional convertMirrorPadCommon(PatternRewriter& rewriter, pad_before + input_type.getDimSize(axis) + pad_after; // Create the expected output shape and type, and initialize it with zero. - RankedTensorType result_type = - RankedTensorType::get(current_dim_size, output_type.getElementType()); - DenseElementsAttr zero = result_type.getElementType().isa() - ? DenseElementsAttr::get(result_type, {0.f}) - : DenseElementsAttr::get(result_type, {0}); - Value result_value = - rewriter.create(op->getLoc(), result_type, zero); + ShapedType result_type = + UnrankedTensorType::get(output_type.getElementType()); // Concatenate the old tensor with padding areas. 
- result = convertConcatV2Op(rewriter, op, result_value, slices, axis); + result = convertConcatV2Op(rewriter, op, result_type, slices, axis); if (!result) return llvm::None; // Update to the padded tensor - current_tensor = result.getValue(); + current_tensor = result.value(); } return result; @@ -3538,7 +3554,7 @@ llvm::Optional convertTFConv2DCommon( rewriter, op->getLoc(), tensorflow::GetTypeFromTFTensorShape(a1_transpose_dims, filter_type.getElementType()), - filter, a1_filter_transpose_perm.getValue()); + filter, a1_filter_transpose_perm.value()); // Only support NHWC now. if (data_format_ref.str() != "NHWC") { @@ -3546,27 +3562,27 @@ llvm::Optional convertTFConv2DCommon( return llvm::None; } - ArrayAttr stride; - ArrayAttr dilation; - ArrayAttr pad; + DenseI64ArrayAttr stride; + DenseI64ArrayAttr dilation; + DenseI64ArrayAttr pad; { if (!strides_attr) { - stride = rewriter.getI64ArrayAttr({1, 1}); + stride = rewriter.getDenseI64ArrayAttr({1, 1}); } else { // Note: hardcoded to NHWC for now int64_t stride_h = strides_attr[1].cast().getInt(); int64_t stride_w = strides_attr[2].cast().getInt(); - stride = rewriter.getI64ArrayAttr({stride_h, stride_w}); + stride = rewriter.getDenseI64ArrayAttr({stride_h, stride_w}); } } { if (!dilations_attr) { - dilation = rewriter.getI64ArrayAttr({1, 1}); + dilation = rewriter.getDenseI64ArrayAttr({1, 1}); } else { // Note: hardcoded to NHWC for now int64_t dilation_h = dilations_attr[1].cast().getInt(); int64_t dilation_w = dilations_attr[2].cast().getInt(); - dilation = rewriter.getI64ArrayAttr({dilation_h, dilation_w}); + dilation = rewriter.getDenseI64ArrayAttr({dilation_h, dilation_w}); } } { @@ -3628,12 +3644,12 @@ llvm::Optional convertConv3DCommon(PatternRewriter& rewriter, return llvm::None; } - ArrayAttr strides_attr = rewriter.getI64ArrayAttr(strides); - ArrayAttr dilations_attr = rewriter.getI64ArrayAttr(dilations); + DenseI64ArrayAttr strides_attr = rewriter.getDenseI64ArrayAttr(strides); + DenseI64ArrayAttr dilations_attr = rewriter.getDenseI64ArrayAttr(dilations); RankedTensorType input_type = input.getType().cast(); RankedTensorType filter_type = filter.getType().cast(); - ArrayAttr pads_attr; + DenseI64ArrayAttr pads_attr; if (!getPaddingValuesFromPadType(tf_pad, data_format_tf, 0, input_type, filter_type, strides_attr, dilations_attr, rewriter, pads_attr)) { @@ -3783,6 +3799,15 @@ llvm::Optional convertGatherOp(PatternRewriter& rewriter, Operation* op, return llvm::None; } + // tf/tfl allow i64 indices, but tosa does not. + if (indices_type.getElementType().isInteger(64)) { + indices_type = + indices_type.clone(rewriter.getI32Type()).dyn_cast(); + indices_value = CreateOpAndInfer(rewriter, op->getLoc(), + indices_type, indices_value) + .getResult(); + } + // Sizes for each of these fields. 
SmallVector params_batch, params_indices, params_left_channels, params_right_channels; @@ -3914,7 +3939,7 @@ llvm::Optional convertGatherOp(PatternRewriter& rewriter, Operation* op, rewriter, op->getLoc(), tensorflow::GetTypeFromTFTensorShape(params_transpose_shape, params_type.getElementType()), - params_value, params_transpose_perm_val.getValue()); + params_value, params_transpose_perm_val.value()); if (count_dynamic_dims(tosa_values_shape) > 1) { return (void)rewriter.notifyMatchFailure( @@ -3929,7 +3954,8 @@ llvm::Optional convertGatherOp(PatternRewriter& rewriter, Operation* op, tensorflow::GetTypeFromTFTensorShape(tosa_values_shape, params_type.getElementType()), params_transpose_op.getResult(), - rewriter.getI64ArrayAttr(tosa_values_shape)); + rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF(tosa_values_shape))); if (count_dynamic_dims(tosa_indices_shape) > 1) { return (void)rewriter.notifyMatchFailure( @@ -3943,7 +3969,9 @@ llvm::Optional convertGatherOp(PatternRewriter& rewriter, Operation* op, rewriter, op->getLoc(), tensorflow::GetTypeFromTFTensorShape(tosa_indices_shape, indices_type.getElementType()), - indices_value, rewriter.getI64ArrayAttr(tosa_indices_shape)); + indices_value, + rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF(tosa_indices_shape))); auto tosa_gather_op = CreateOpAndInfer( rewriter, op->getLoc(), @@ -3963,12 +3991,13 @@ llvm::Optional convertGatherOp(PatternRewriter& rewriter, Operation* op, tensorflow::GetTypeFromTFTensorShape(result_reshape_shape, params_type.getElementType()), tosa_gather_op.getResult(), - rewriter.getI64ArrayAttr(result_reshape_shape)); + rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF(result_reshape_shape))); - return CreateOpAndInfer( - rewriter, op->getLoc(), result_type, - tosa_result_reshape_op.getResult(), - result_transpose_perm_val.getValue()) + return CreateOpAndInfer(rewriter, op->getLoc(), + result_type, + tosa_result_reshape_op.getResult(), + result_transpose_perm_val.value()) .getResult(); } @@ -4082,14 +4111,18 @@ llvm::Optional convertGatherNdOp(PatternRewriter& rewriter, rewriter, op->getLoc(), tensorflow::GetTypeFromTFTensorShape(tosa_values_shape, params_type.getElementType()), - params_value, rewriter.getI64ArrayAttr(tosa_values_shape)); + params_value, + rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF(tosa_values_shape))); // Flatten the input indices tensor to an [W, ND] matrix. 
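// [Editor's note, not part of the patch: after this reshape the ND index components are
//  linearized row-major over the leading ND params dimensions, i.e. a tuple
//  (i_0, ..., i_{ND-1}) with params dims (D_0, ..., D_{ND-1}) becomes
//  i_0*D_1*...*D_{ND-1} + ... + i_{ND-2}*D_{ND-1} + i_{ND-1}; the coefficient multiply
//  and reduce_sum just below compute exactly this before the single tosa.gather.]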
auto indices_matrix_reshape_op = CreateOpAndInfer( rewriter, op->getLoc(), tensorflow::GetTypeFromTFTensorShape(indices_matrix_shape, indices_type.getElementType()), - indices_value, rewriter.getI64ArrayAttr(indices_matrix_shape)); + indices_value, + rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF(indices_matrix_shape))); SmallVector flattened_coeff_vec; for (int i = 1; i < ND; i++) { @@ -4112,8 +4145,7 @@ llvm::Optional convertGatherNdOp(PatternRewriter& rewriter, rewriter, op->getLoc(), tensorflow::GetTypeFromTFTensorShape(indices_matrix_shape, indices_type.getElementType()), - indices_matrix_reshape_op.getResult(), flattened_coeff_value.getValue(), - 0); + indices_matrix_reshape_op.getResult(), flattened_coeff_value.value(), 0); // Sum up the products of the coefficients and coordinates auto flattened_indices_reduce_op = CreateOpAndInfer( @@ -4128,7 +4160,8 @@ llvm::Optional convertGatherNdOp(PatternRewriter& rewriter, tensorflow::GetTypeFromTFTensorShape(tosa_indices_shape, indices_type.getElementType()), flattened_indices_reduce_op.getResult(), - rewriter.getI64ArrayAttr(tosa_indices_shape)); + rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF(tosa_indices_shape))); // Now the gather op itself auto tosa_gather_op = CreateOpAndInfer( @@ -4141,7 +4174,8 @@ llvm::Optional convertGatherNdOp(PatternRewriter& rewriter, // ParamChannels]. return CreateOpAndInfer( rewriter, op->getLoc(), result_type, tosa_gather_op.getResult(), - rewriter.getI64ArrayAttr(result_type.getShape())) + rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF(result_type.getShape()))) .getResult(); } @@ -4226,35 +4260,38 @@ llvm::Optional convertOneHotOp(PatternRewriter& rewriter, Operation* op, rewriter, op->getLoc(), tensorflow::GetTypeFromTFTensorShape({1, 1, 1}, on_value_type.getElementType()), - on_value, rewriter.getI64ArrayAttr({1, 1, 1})); + on_value, rewriter.getDenseI64ArrayAttr({1, 1, 1})); // And tile to [N, W, C] auto op2_tile_op1 = CreateOpAndInfer( rewriter, op->getLoc(), tensorflow::GetTypeFromTFTensorShape({N, W, C}, on_value_type.getElementType()), - op1_reshape_on_value.getResult(), rewriter.getI64ArrayAttr({N, W, C})); + op1_reshape_on_value.getResult(), + rewriter.getDenseI64ArrayAttr({N, W, C})); // Reshape off_value to [1, 1, 1] auto op3_reshape_off_value = CreateOpAndInfer( rewriter, op->getLoc(), tensorflow::GetTypeFromTFTensorShape({1, 1, 1}, off_value_type.getElementType()), - off_value, rewriter.getI64ArrayAttr({1, 1, 1})); + off_value, rewriter.getDenseI64ArrayAttr({1, 1, 1})); // And tile to [N, K, C] auto op4_tile_op3 = CreateOpAndInfer( rewriter, op->getLoc(), tensorflow::GetTypeFromTFTensorShape({N, K, C}, on_value_type.getElementType()), - op3_reshape_off_value.getResult(), rewriter.getI64ArrayAttr({N, K, C})); + op3_reshape_off_value.getResult(), + rewriter.getDenseI64ArrayAttr({N, K, C})); // Reshape indices to [N, W] auto op5_reshape_indices = CreateOpAndInfer( rewriter, op->getLoc(), tensorflow::GetTypeFromTFTensorShape({N, W}, indices_type.getElementType()), - indices_value, rewriter.getI64ArrayAttr({N, W})); + indices_value, + rewriter.getDenseI64ArrayAttr(tensorflow::ConvertMlirShapeToTF({N, W}))); // Scatter to [N, K, C] auto op6_scatter_op4_op5_op2 = CreateOpAndInfer( @@ -4270,7 +4307,8 @@ llvm::Optional convertOneHotOp(PatternRewriter& rewriter, Operation* op, tensorflow::GetTypeFromTFTensorShape({left_dim, right_dim, K}, result_type.getElementType()), op6_scatter_op4_op5_op2.getResult(), - rewriter.getI64ArrayAttr({left_dim, 
right_dim, K})); + rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF({left_dim, right_dim, K}))); // Transposed to [LeftDims, K, RightDims]. llvm::Optional perm_const = @@ -4282,12 +4320,117 @@ llvm::Optional convertOneHotOp(PatternRewriter& rewriter, Operation* op, rewriter, op->getLoc(), tensorflow::GetTypeFromTFTensorShape({left_dim, K, right_dim}, result_type.getElementType()), - op7_reshape_op6.getResult(), perm_const.getValue()); + op7_reshape_op6.getResult(), perm_const.value()); // Reshaped to result.shape. return CreateOpAndInfer( rewriter, op->getLoc(), result_type, op8_transpose_op7.getResult(), - rewriter.getI64ArrayAttr(result_type.getShape())) + rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF(result_type.getShape()))) + .getResult(); +} + +// Lowers Sin operator to a sequence of TOSA ops. +llvm::Optional convertSinOp(PatternRewriter& rewriter, Operation* op, + Value input, ShapedType output_type) { + RankedTensorType input_type = input.getType().dyn_cast(); + Location loc = op->getLoc(); + + Type input_ety = input_type.getElementType(); + Type output_ety = output_type.getElementType(); + + if (!input) return llvm::None; + + if (input_ety != output_ety) { + (void)rewriter.notifyMatchFailure(op, + "input/output element type must match"); + return llvm::None; + } + + bool input_is_fp = input_ety.isF32(); + bool output_is_fp = output_ety.isF32(); + + if (!input_is_fp || !output_is_fp) { + (void)rewriter.notifyMatchFailure(op, "input/result must be fp32"); + return llvm::None; + } + + // To perform a sin operation we remap the sin domain to be over a single + // period of the function, remapping to the domain of the table function. + // We then remap the range of the table function to map to the range of the + // sin operation. + + // 1. Normalize the period of the domain from [0, 2π) to [0, 1). + auto fp_scalar_ty = RankedTensorType::get({}, rewriter.getF32Type()); + Value fp_scale = rewriter.create( + loc, fp_scalar_ty, + DenseElementsAttr::get(fp_scalar_ty, {static_cast(0.5 / M_PI)})); + + // 2. Remap the periodic behavior of the domain to line up within [0, 1). + Value fp_scaled = + CreateOpAndInfer(rewriter, loc, input_type, input, fp_scale, + rewriter.getI32IntegerAttr(0)); + auto floored = + CreateOpAndInfer(rewriter, loc, input_type, fp_scaled); + auto repeated = CreateOpAndInfer(rewriter, loc, input_type, + fp_scaled, floored); + + // 3. Scale and translate the normalized domain to the table domain. This + // includes a translating and scaling to [-int16_max, int16_max] and casting + // to an i16. + Value one = rewriter.create( + loc, fp_scalar_ty, DenseElementsAttr::get(fp_scalar_ty, {1.0f})); + + Value two = rewriter.create( + loc, fp_scalar_ty, DenseElementsAttr::get(fp_scalar_ty, {2.0f})); + auto scale_up = CreateOpAndInfer( + rewriter, loc, input_type, repeated, two, rewriter.getI32IntegerAttr(0)); + auto translate = + CreateOpAndInfer(rewriter, loc, input_type, scale_up, one); + + Value int_limit = rewriter.create( + loc, fp_scalar_ty, + DenseElementsAttr::get( + fp_scalar_ty, + {static_cast(std::numeric_limits::max())})); + auto int_scaled = + CreateOpAndInfer(rewriter, loc, input_type, translate, + int_limit, rewriter.getI32IntegerAttr(0)); + + auto int16_ty = input_type.clone(rewriter.getIntegerType(16)); + auto casted = + CreateOpAndInfer(rewriter, loc, int16_ty, int_scaled); + + // 4. Compute the lookup table using the range of [-255, 255] for sin. 
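// [Editor's note, not part of the patch: steps 1-3 above produce
//  t = (frac(x / (2*pi)) * 2 - 1) * int16_max, i.e. one period of sin mapped onto the
//  full i16 range. For i16 inputs tosa.table indexes the 513-entry table with the top
//  9 bits of t and linearly interpolates on the low 7 bits, which is why the result is
//  the 23-bit (16.7 fixed-point) value that step 5 rescales by 2^-22.]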
+ llvm::SmallVector values; + const int num_values = 513; + values.resize(num_values, 0); + // First and last values should be 0; + for (int i = 1; i < num_values - 1; ++i) + values[i] = std::numeric_limits::max() * + sin(static_cast(i) * 2.0 * M_PI / (num_values - 1.0)); + + auto table_ty = + RankedTensorType::get({num_values}, rewriter.getIntegerType(16)); + Value table = rewriter.create( + loc, table_ty, DenseElementsAttr::get(table_ty, llvm::ArrayRef(values))); + + auto table_result_ty = input_type.clone(rewriter.getIntegerType(32)); + auto table_result = CreateOpAndInfer( + rewriter, loc, table_result_ty, casted, table); + + // 5. The range of table is a 23-bit two's compliment value. Normalize the + // range by casting to an fp32 and dividing by 2^22. + auto table_result_fp = + CreateOpAndInfer(rewriter, loc, input_type, table_result); + auto output_scale = rewriter.create( + loc, fp_scalar_ty, + DenseElementsAttr::get( + fp_scalar_ty, + {static_cast(1.0 / static_cast(1 << 22))})); + + return CreateOpAndInfer(rewriter, loc, output_type, table_result_fp, + output_scale, rewriter.getI32IntegerAttr(0)) .getResult(); } diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h index 5684353860c..686b536032e 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h @@ -72,7 +72,7 @@ llvm::Optional convertRoundOp(PatternRewriter& rewriter, Operation* op, // Lowers ConcatV2 to TOSA. llvm::Optional convertConcatV2Op(PatternRewriter& rewriter, - Operation* op, Value result_value, + Operation* op, ShapedType result_type, SmallVectorImpl& values, int32_t axis); @@ -298,6 +298,10 @@ llvm::Optional convertOneHotOp(PatternRewriter& rewriter, Operation* op, Value on_value, Value off_value, int32_t depth, int32_t axis); +// Lowers 32-bit floating sin operator to a sequence of TOSA ops. +llvm::Optional convertSinOp(PatternRewriter& rewriter, Operation* op, + Value input, ShapedType output_type); + }; // namespace tosa }; // namespace mlir diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc index 99f6ec88785..350cb4c295b 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc @@ -19,7 +19,10 @@ limitations under the License. 
#include #include #include +#include +#include #include +#include #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project #include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project @@ -45,7 +48,7 @@ namespace { // Performs lowering to TOSA dialect class LegalizeTF : public impl::TosaLegalizeTFPassBase { public: - explicit LegalizeTF() {} + explicit LegalizeTF() = default; void runOnOperation() override; }; @@ -129,6 +132,8 @@ DECL_CONVERT_OP(GatherNd); DECL_CONVERT_OP(SelectV2); DECL_CONVERT_OP(SpaceToDepth); DECL_CONVERT_OP(DepthToSpace); +DECL_CONVERT_OP(Sin); +DECL_CONVERT_OP(Cos); DECL_CONVERT_OP(SpaceToBatchND); DECL_CONVERT_OP(BatchToSpaceND); DECL_CONVERT_OP(ZerosLike); @@ -156,7 +161,7 @@ LogicalResult ConvertTFReluOp::matchAndRewrite( if (!output_type) return failure(); CreateReplaceOpAndInfer( - rewriter, op, output_type, tf_relu_op.features(), + rewriter, op, output_type, tf_relu_op.getFeatures(), rewriter.getI64IntegerAttr(0), rewriter.getI64IntegerAttr(std::numeric_limits::max()), rewriter.getF32FloatAttr(0.0f), @@ -174,7 +179,7 @@ LogicalResult ConvertTFRelu6Op::matchAndRewrite( if (!output_type) return failure(); CreateReplaceOpAndInfer( - rewriter, op, output_type, tf_relu6_op.features(), + rewriter, op, output_type, tf_relu6_op.getFeatures(), rewriter.getI64IntegerAttr(0), rewriter.getI64IntegerAttr(6), rewriter.getF32FloatAttr(0.0f), rewriter.getF32FloatAttr(6.0f)); return success(); @@ -189,8 +194,8 @@ LogicalResult ConvertTFEqualOp::matchAndRewrite( // Not a tensor output if (!output_type) return failure(); - CreateReplaceOpAndInfer(rewriter, op, output_type, - tf_equal_op.x(), tf_equal_op.y()); + CreateReplaceOpAndInfer( + rewriter, op, output_type, tf_equal_op.getX(), tf_equal_op.getY()); return success(); } @@ -203,9 +208,9 @@ LogicalResult ConvertTFNotEqualOp::matchAndRewrite( // Not a tensor output if (!output_type) return failure(); - auto op1_equal_in = - CreateOpAndInfer(rewriter, op->getLoc(), output_type, - tf_not_equal_op.x(), tf_not_equal_op.y()); + auto op1_equal_in = CreateOpAndInfer( + rewriter, op->getLoc(), output_type, tf_not_equal_op.getX(), + tf_not_equal_op.getY()); auto op2_not_op1 = CreateOpAndInfer( rewriter, op->getLoc(), output_type, op1_equal_in.getResult()); @@ -225,7 +230,7 @@ LogicalResult ConvertTFGreaterOp::matchAndRewrite( if (!output_type) return failure(); CreateReplaceOpAndInfer( - rewriter, op, output_type, tf_greater_op.x(), tf_greater_op.y()); + rewriter, op, output_type, tf_greater_op.getX(), tf_greater_op.getY()); return success(); } @@ -239,8 +244,50 @@ LogicalResult ConvertTFGreaterEqualOp::matchAndRewrite( if (!output_type) return failure(); CreateReplaceOpAndInfer(rewriter, op, output_type, - tf_greater_equal_op.x(), - tf_greater_equal_op.y()); + tf_greater_equal_op.getX(), + tf_greater_equal_op.getY()); + return success(); +} + +LogicalResult ConvertTFSinOp::matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const { + auto tf_sin_op = cast(op); + ShapedType output_type = tf_sin_op.getResult().getType().cast(); + + llvm::Optional result = + convertSinOp(rewriter, op, tf_sin_op.getX(), output_type); + if (!result) return failure(); + + rewriter.replaceOp(op, {result.value()}); + return success(); +} + +LogicalResult ConvertTFCosOp::matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const { + auto tf_cos_op = cast(op); + Value input = tf_cos_op.getX(); + RankedTensorType input_ty = input.getType().dyn_cast(); + ShapedType output_ty = tf_cos_op.getResult().getType().dyn_cast(); + + if (!input_ty || 
!output_ty) return failure(); + + bool input_is_fp = input_ty.getElementType().isa(); + bool output_is_fp = output_ty.getElementType().isa(); + + if (!input_is_fp || !output_is_fp) { + return rewriter.notifyMatchFailure( + op, "ConvertTFCosOp: input/result must be fp."); + } + + // Replace with the equivalent sin operation: + // cos(x) = sin(x + π / 2). + auto fp_scalar_ty = RankedTensorType::get({}, rewriter.getF32Type()); + auto pi_2 = rewriter.create( + op->getLoc(), fp_scalar_ty, + DenseElementsAttr::get(fp_scalar_ty, {static_cast(M_PI_2)})); + auto offset = rewriter.create(op->getLoc(), input_ty, input, pi_2); + + CreateReplaceOpAndInfer(rewriter, op, output_ty, offset); return success(); } @@ -253,8 +300,8 @@ LogicalResult ConvertTFAddOp::matchAndRewrite(Operation* op, // Not a tensor output if (!output_type) return failure(); - CreateReplaceOpAndInfer(rewriter, op, output_type, tf_add_op.x(), - tf_add_op.y()); + CreateReplaceOpAndInfer(rewriter, op, output_type, + tf_add_op.getX(), tf_add_op.getY()); return success(); } @@ -268,7 +315,7 @@ LogicalResult ConvertTFAddV2Op::matchAndRewrite( if (!output_type) return failure(); CreateReplaceOpAndInfer(rewriter, op, output_type, - tf_addv2_op.x(), tf_addv2_op.y()); + tf_addv2_op.getX(), tf_addv2_op.getY()); return success(); } @@ -282,7 +329,7 @@ LogicalResult ConvertTFAddNOp::matchAndRewrite( // Not a tensor output if (!output_type) return failure(); - SmallVector inputs(tf_addn_op.inputs()); + SmallVector inputs(tf_addn_op.getInputs()); assert(inputs.size() >= 2); @@ -307,8 +354,8 @@ LogicalResult ConvertTFSubOp::matchAndRewrite(Operation* op, // Not a tensor output if (!output_type) return failure(); - CreateReplaceOpAndInfer(rewriter, op, output_type, tf_sub_op.x(), - tf_sub_op.y()); + CreateReplaceOpAndInfer(rewriter, op, output_type, + tf_sub_op.getX(), tf_sub_op.getY()); return success(); } @@ -317,11 +364,11 @@ LogicalResult ConvertTFMulOp::matchAndRewrite(Operation* op, auto tf_mul_op = cast(op); llvm::Optional result = convertMultiplyOp( - rewriter, op, tf_mul_op.getResult(), tf_mul_op.x(), tf_mul_op.y()); + rewriter, op, tf_mul_op.getResult(), tf_mul_op.getX(), tf_mul_op.getY()); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -331,11 +378,11 @@ LogicalResult ConvertTFSquareOp::matchAndRewrite( llvm::Optional result = convertMultiplyOp(rewriter, op, tf_square_op.getResult(), - tf_square_op.x(), tf_square_op.x()); + tf_square_op.getX(), tf_square_op.getX()); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -345,11 +392,11 @@ LogicalResult ConvertTFSquaredDifferenceOp::matchAndRewrite( llvm::Optional result = convertSquaredDifferenceOp(rewriter, op, tf_squared_op.getResult(), - tf_squared_op.x(), tf_squared_op.y()); + tf_squared_op.getX(), tf_squared_op.getY()); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -357,22 +404,22 @@ LogicalResult ConvertTFRoundOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_round_op = cast(op); - TensorType input_type = tf_round_op.x().getType().dyn_cast(); + TensorType input_type = tf_round_op.getX().getType().dyn_cast(); if (!input_type) { return rewriter.notifyMatchFailure(op, "input not tensor type"); } if (input_type.getElementType().isa()) { - llvm::Optional result = - 
convertRoundOp(rewriter, op, tf_round_op.getResult(), tf_round_op.x()); + llvm::Optional result = convertRoundOp( + rewriter, op, tf_round_op.getResult(), tf_round_op.getX()); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } else { - tf_round_op.replaceAllUsesWith(tf_round_op.x()); + tf_round_op.replaceAllUsesWith(tf_round_op.getX()); return success(); } } @@ -383,11 +430,11 @@ LogicalResult ConvertTFFloorDivOp::matchAndRewrite( llvm::Optional result = convertFloorDivOp(rewriter, op, tf_floordiv_op.getResult(), - tf_floordiv_op.x(), tf_floordiv_op.y()); + tf_floordiv_op.getX(), tf_floordiv_op.getY()); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -398,11 +445,11 @@ LogicalResult ConvertTFFloorModOp::matchAndRewrite( llvm::Optional result = convertFloorModOp(rewriter, op, tf_floormod_op.getResult(), - tf_floormod_op.x(), tf_floormod_op.y()); + tf_floormod_op.getX(), tf_floormod_op.getY()); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -424,7 +471,7 @@ LogicalResult ConvertTFMaximumOp::matchAndRewrite( if (!output_type) return failure(); CreateReplaceOpAndInfer( - rewriter, op, output_type, tf_maximum_op.x(), tf_maximum_op.y()); + rewriter, op, output_type, tf_maximum_op.getX(), tf_maximum_op.getY()); return success(); } @@ -438,7 +485,7 @@ LogicalResult ConvertTFMinimumOp::matchAndRewrite( if (!output_type) return failure(); CreateReplaceOpAndInfer( - rewriter, op, output_type, tf_minimum_op.x(), tf_minimum_op.y()); + rewriter, op, output_type, tf_minimum_op.getX(), tf_minimum_op.getY()); return success(); } @@ -446,7 +493,7 @@ LogicalResult ConvertTFRealDivOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_div_op = cast(op); - TensorType y_type = tf_div_op.y().getType().dyn_cast(); + TensorType y_type = tf_div_op.getY().getType().dyn_cast(); TensorType output_type = tf_div_op.getResult().getType().dyn_cast(); // Not a tensor output @@ -456,15 +503,15 @@ LogicalResult ConvertTFRealDivOp::matchAndRewrite( if (element_type.isa()) { CreateReplaceOpAndInfer(rewriter, op, output_type, - tf_div_op.x(), tf_div_op.y()); + tf_div_op.getX(), tf_div_op.getY()); return success(); } auto reciprocal_op = CreateOpAndInfer( - rewriter, op->getLoc(), tf_div_op.y().getType(), tf_div_op.y()); + rewriter, op->getLoc(), tf_div_op.getY().getType(), tf_div_op.getY()); auto mul_op = CreateOpAndInfer(rewriter, op->getLoc(), - output_type, tf_div_op.x(), + output_type, tf_div_op.getX(), reciprocal_op.getResult(), 0); rewriter.replaceOp(op, {mul_op.getResult()}); @@ -475,14 +522,15 @@ LogicalResult ConvertTFArgMaxOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_argmax_op = cast(op); - TensorType input_type = tf_argmax_op.input().getType().dyn_cast(); + TensorType input_type = + tf_argmax_op.getInput().getType().dyn_cast(); TensorType output_type = tf_argmax_op.getResult().getType().dyn_cast(); // Not a tensor output if (!output_type || !input_type) return failure(); ElementsAttr axis_elems; - if (!matchPattern(tf_argmax_op.dimension(), m_Constant(&axis_elems))) + if (!matchPattern(tf_argmax_op.getDimension(), m_Constant(&axis_elems))) return failure(); int32_t axis = axis_elems.getValues()[0].getInt(); @@ -497,7 +545,7 @@ LogicalResult ConvertTFArgMaxOp::matchAndRewrite( 
IntegerAttr axis_attr = rewriter.getI64IntegerAttr(axis); CreateReplaceOpAndInfer(rewriter, op, output_type, - tf_argmax_op.input(), axis_attr); + tf_argmax_op.getInput(), axis_attr); return success(); } @@ -507,57 +555,57 @@ LogicalResult ConvertTFAvgPoolOp::matchAndRewrite( auto tf_avgpool_op = cast(op); RankedTensorType input_type = - tf_avgpool_op.value().getType().dyn_cast(); + tf_avgpool_op.getValue().getType().dyn_cast(); RankedTensorType output_type = tf_avgpool_op.getResult().getType().dyn_cast(); // Not a ranked tensor output if (!input_type || !output_type) return failure(); - auto tmpAttr = tf_avgpool_op.data_formatAttr(); + auto tmpAttr = tf_avgpool_op.getDataFormatAttr(); if (tmpAttr && tmpAttr.getValue().str() != "NHWC") return failure(); - ArrayAttr pad; - ArrayAttr stride; - ArrayAttr kernel; + DenseI64ArrayAttr pad; + DenseI64ArrayAttr stride; + DenseI64ArrayAttr kernel; { - auto tmpAttr = tf_avgpool_op.strides(); + auto tmpAttr = tf_avgpool_op.getStrides(); if (!tmpAttr) { - stride = rewriter.getI64ArrayAttr({1, 1}); + stride = rewriter.getDenseI64ArrayAttr({1, 1}); } else { // Note: hardcoded to NHWC for now int64_t stride_h = tmpAttr[1].dyn_cast().getInt(); int64_t stride_w = tmpAttr[2].dyn_cast().getInt(); - stride = rewriter.getI64ArrayAttr({stride_h, stride_w}); + stride = rewriter.getDenseI64ArrayAttr({stride_h, stride_w}); } } { - auto tmpAttr = tf_avgpool_op.ksize(); + auto tmpAttr = tf_avgpool_op.getKsize(); if (!tmpAttr) { - kernel = rewriter.getI64ArrayAttr({1, 1}); + kernel = rewriter.getDenseI64ArrayAttr({1, 1}); } else { // Note: hardcoded to NHWC for now int64_t kernel_h = tmpAttr[1].dyn_cast().getInt(); int64_t kernel_w = tmpAttr[2].dyn_cast().getInt(); - kernel = rewriter.getI64ArrayAttr({kernel_h, kernel_w}); + kernel = rewriter.getDenseI64ArrayAttr({kernel_h, kernel_w}); } } { tensorflow::Padding tf_pad; - if (!GetPaddingFromString(tf_avgpool_op.padding().str(), &tf_pad).ok()) + if (!GetPaddingFromString(tf_avgpool_op.getPadding().str(), &tf_pad).ok()) return failure(); - ArrayAttr dilation = - rewriter.getI64ArrayAttr({1, 1}); // Pooling has no non-unit dilation + DenseI64ArrayAttr dilation = rewriter.getDenseI64ArrayAttr( + {1, 1}); // Pooling has no non-unit dilation SmallVector i64array; - for (auto& elem : tf_avgpool_op.ksize()) { + for (auto& elem : tf_avgpool_op.getKsize()) { int64_t value = elem.dyn_cast().getInt(); i64array.emplace_back(value); } RankedTensorType filter_type = tensorflow::GetTypeFromTFTensorShape( - llvm::makeArrayRef(i64array), rewriter.getIntegerType(64)); + llvm::ArrayRef(i64array), rewriter.getIntegerType(64)); if (!getPaddingValuesFromPadType( tf_pad, @@ -568,7 +616,7 @@ LogicalResult ConvertTFAvgPoolOp::matchAndRewrite( } CreateReplaceOpAndInfer( - rewriter, op, output_type, tf_avgpool_op.value(), kernel, stride, pad); + rewriter, op, output_type, tf_avgpool_op.getValue(), kernel, stride, pad); return success(); } @@ -577,57 +625,57 @@ LogicalResult ConvertTFMaxPoolOp::matchAndRewrite( auto tf_maxpool_op = cast(op); RankedTensorType input_type = - tf_maxpool_op.input().getType().dyn_cast(); + tf_maxpool_op.getInput().getType().dyn_cast(); RankedTensorType output_type = tf_maxpool_op.getResult().getType().dyn_cast(); // Not a ranked tensor output if (!input_type || !output_type) return failure(); - auto tmpAttr = tf_maxpool_op.data_formatAttr(); + auto tmpAttr = tf_maxpool_op.getDataFormatAttr(); if (tmpAttr && tmpAttr.getValue().str() != "NHWC") return failure(); - ArrayAttr pad; - ArrayAttr stride; - ArrayAttr kernel; + 
DenseI64ArrayAttr pad; + DenseI64ArrayAttr stride; + DenseI64ArrayAttr kernel; { - auto tmpAttr = tf_maxpool_op.strides(); + auto tmpAttr = tf_maxpool_op.getStrides(); if (!tmpAttr) { - stride = rewriter.getI64ArrayAttr({1, 1}); + stride = rewriter.getDenseI64ArrayAttr({1, 1}); } else { // Note: hardcoded to NHWC for now int64_t stride_h = tmpAttr[1].dyn_cast().getInt(); int64_t stride_w = tmpAttr[2].dyn_cast().getInt(); - stride = rewriter.getI64ArrayAttr({stride_h, stride_w}); + stride = rewriter.getDenseI64ArrayAttr({stride_h, stride_w}); } } { - auto tmpAttr = tf_maxpool_op.ksize(); + auto tmpAttr = tf_maxpool_op.getKsize(); if (!tmpAttr) { - kernel = rewriter.getI64ArrayAttr({1, 1}); + kernel = rewriter.getDenseI64ArrayAttr({1, 1}); } else { // Note: hardcoded to NHWC for now int64_t kernel_h = tmpAttr[1].dyn_cast().getInt(); int64_t kernel_w = tmpAttr[2].dyn_cast().getInt(); - kernel = rewriter.getI64ArrayAttr({kernel_h, kernel_w}); + kernel = rewriter.getDenseI64ArrayAttr({kernel_h, kernel_w}); } } { tensorflow::Padding tf_pad; - if (!GetPaddingFromString(tf_maxpool_op.padding().str(), &tf_pad).ok()) + if (!GetPaddingFromString(tf_maxpool_op.getPadding().str(), &tf_pad).ok()) return failure(); // Pooling has no non-unit dilation - ArrayAttr dilation = rewriter.getI64ArrayAttr({1, 1}); + DenseI64ArrayAttr dilation = rewriter.getDenseI64ArrayAttr({1, 1}); SmallVector i64array; - for (auto& elem : tf_maxpool_op.ksize()) { + for (auto& elem : tf_maxpool_op.getKsize()) { int64_t value = elem.dyn_cast().getInt(); i64array.emplace_back(value); } RankedTensorType filter_type = tensorflow::GetTypeFromTFTensorShape( - llvm::makeArrayRef(i64array), rewriter.getIntegerType(64)); + llvm::ArrayRef(i64array), rewriter.getIntegerType(64)); if (!getPaddingValuesFromPadType( tf_pad, @@ -638,27 +686,28 @@ LogicalResult ConvertTFMaxPoolOp::matchAndRewrite( } CreateReplaceOpAndInfer( - rewriter, op, output_type, tf_maxpool_op.input(), kernel, stride, pad); + rewriter, op, output_type, tf_maxpool_op.getInput(), kernel, stride, pad); return success(); } LogicalResult ConvertTFConcatV2Op::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_concatv2_op = cast(op); - SmallVector values(tf_concatv2_op.values()); + auto result_type = tf_concatv2_op.getResult().getType().cast(); + SmallVector values(tf_concatv2_op.getValues()); ElementsAttr axis_elems; - if (!matchPattern(tf_concatv2_op.axis(), m_Constant(&axis_elems))) + if (!matchPattern(tf_concatv2_op.getAxis(), m_Constant(&axis_elems))) return failure(); int32_t axis = axis_elems.getValues()[0].getInt(); llvm::Optional result = - convertConcatV2Op(rewriter, op, tf_concatv2_op.getResult(), values, axis); + convertConcatV2Op(rewriter, op, result_type, values, axis); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -678,10 +727,10 @@ LogicalResult ConvertTFReshapeOp::matchAndRewrite( for (int i = 0; i < output_type.getShape().size(); i++) { shape_vals.push_back(output_type.getShape()[i]); } - ArrayAttr shape_attr = rewriter.getI64ArrayAttr(shape_vals); + DenseI64ArrayAttr shape_attr = rewriter.getDenseI64ArrayAttr(shape_vals); - CreateReplaceOpAndInfer(rewriter, op, output_type, - tf_reshape_op.tensor(), shape_attr); + CreateReplaceOpAndInfer( + rewriter, op, output_type, tf_reshape_op.getTensor(), shape_attr); return success(); } @@ -690,14 +739,14 @@ LogicalResult ConvertTFRankOp::matchAndRewrite( auto tf_rank_op = cast(op); RankedTensorType 
input_type = - tf_rank_op.input().getType().dyn_cast(); + tf_rank_op.getInput().getType().dyn_cast(); if (!input_type) return failure(); int32_t rank = input_type.getRank(); RankedTensorType rank_type = tensorflow::GetTypeFromTFTensorShape({1}, rewriter.getIntegerType(32)); - auto rank_attr = DenseElementsAttr::get(rank_type, {rank}); + auto rank_attr = DenseI32ArrayAttr::get(rewriter.getContext(), {rank}); auto rank_const = CreateOpAndInfer(rewriter, op->getLoc(), rank_type, rank_attr); @@ -716,7 +765,7 @@ LogicalResult ConvertTFShapeOp::matchAndRewrite( if (!output_type) return failure(); RankedTensorType input_type = - tf_shape_op.input().getType().dyn_cast(); + tf_shape_op.getInput().getType().dyn_cast(); if (!input_type) return failure(); auto input_shape = input_type.getShape(); @@ -729,7 +778,7 @@ LogicalResult ConvertTFShapeOp::matchAndRewrite( RankedTensorType shape_type = tensorflow::GetTypeFromTFTensorShape( {static_cast(shape_arr.size())}, rewriter.getIntegerType(32)); auto shape_attr = - DenseElementsAttr::get(shape_type, llvm::makeArrayRef(shape_arr)); + DenseI32ArrayAttr::get(rewriter.getContext(), llvm::ArrayRef(shape_arr)); auto shape_const = CreateOpAndInfer(rewriter, op->getLoc(), shape_type, shape_attr); @@ -742,13 +791,13 @@ LogicalResult ConvertTFExpandDimsOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_expanddims_op = cast(op); - llvm::Optional result = - convertExpandDimsOp(rewriter, op, tf_expanddims_op.getResult(), - tf_expanddims_op.input(), tf_expanddims_op.dim()); + llvm::Optional result = convertExpandDimsOp( + rewriter, op, tf_expanddims_op.getResult(), tf_expanddims_op.getInput(), + tf_expanddims_op.getDim()); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -758,7 +807,7 @@ LogicalResult ConvertTFSqueezeOp::matchAndRewrite( auto tf_squeeze_op = cast(op); // Copy squeeze_dims into int32_t array - auto squeeze_dims_attr = tf_squeeze_op.squeeze_dimsAttr(); + auto squeeze_dims_attr = tf_squeeze_op.getSqueezeDimsAttr(); SmallVector squeeze_dims; for (auto& squeeze_dim : squeeze_dims_attr) { squeeze_dims.emplace_back(squeeze_dim.dyn_cast().getInt()); @@ -766,11 +815,11 @@ LogicalResult ConvertTFSqueezeOp::matchAndRewrite( llvm::Optional result = convertSqueezeOp(rewriter, op, tf_squeeze_op.getResult(), - tf_squeeze_op.input(), squeeze_dims); + tf_squeeze_op.getInput(), squeeze_dims); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -785,7 +834,7 @@ LogicalResult ConvertTFFillOp::matchAndRewrite( if (!output_type) return failure(); ElementsAttr dims_elems; - if (!matchPattern(tf_fill_op.dims(), m_Constant(&dims_elems))) + if (!matchPattern(tf_fill_op.getDims(), m_Constant(&dims_elems))) return failure(); SmallVector dims_vals; uint32_t total_size = 1; @@ -795,24 +844,26 @@ LogicalResult ConvertTFFillOp::matchAndRewrite( } ElementsAttr value_elem; - if (!matchPattern(tf_fill_op.value(), m_Constant(&value_elem))) + if (!matchPattern(tf_fill_op.getValue(), m_Constant(&value_elem))) return failure(); RankedTensorType fill_type = tensorflow::GetTypeFromTFTensorShape( ArrayRef(dims_vals), value_elem.getType().getElementType()); - DenseElementsAttr fill_attr; + DenseArrayAttr fill_attr; // Convert to a compatible zero type if (value_elem.getType().getElementType().isa()) { SmallVector fill_arr( total_size, 
value_elem.getValues()[0].getValue().convertToFloat()); - fill_attr = DenseElementsAttr::get(fill_type, llvm::makeArrayRef(fill_arr)); + fill_attr = + DenseF32ArrayAttr::get(rewriter.getContext(), llvm::ArrayRef(fill_arr)); } else { SmallVector fill_arr( total_size, value_elem.getValues()[0].getValue().getLimitedValue()); - fill_attr = DenseElementsAttr::get(fill_type, llvm::makeArrayRef(fill_arr)); + fill_attr = + DenseI32ArrayAttr::get(rewriter.getContext(), llvm::ArrayRef(fill_arr)); } auto fill_const_op = CreateOpAndInfer(rewriter, op->getLoc(), fill_type, fill_attr); @@ -826,7 +877,7 @@ LogicalResult ConvertTFConv2DOp::matchAndRewrite( auto tf_conv2d_op = cast(op); RankedTensorType filter_type = - tf_conv2d_op.filter().getType().dyn_cast(); + tf_conv2d_op.getFilter().getType().dyn_cast(); RankedTensorType output_type = tf_conv2d_op.getResult().getType().dyn_cast(); @@ -839,14 +890,14 @@ LogicalResult ConvertTFConv2DOp::matchAndRewrite( bias_attr.cast()); llvm::Optional result = convertTFConv2DCommon( - rewriter, op, output_type, tf_conv2d_op.input(), tf_conv2d_op.filter(), - bias, tf_conv2d_op.strides(), tf_conv2d_op.dilations(), - tf_conv2d_op.explicit_paddings(), tf_conv2d_op.padding(), - tf_conv2d_op.data_format()); + rewriter, op, output_type, tf_conv2d_op.getInput(), + tf_conv2d_op.getFilter(), bias, tf_conv2d_op.getStrides(), + tf_conv2d_op.getDilations(), tf_conv2d_op.getExplicitPaddings(), + tf_conv2d_op.getPadding(), tf_conv2d_op.getDataFormat()); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -856,7 +907,7 @@ LogicalResult ConvertTFConv3DOp::matchAndRewrite( auto tf_conv3d_op = cast(op); RankedTensorType filter_type = - tf_conv3d_op.filter().getType().dyn_cast(); + tf_conv3d_op.getFilter().getType().dyn_cast(); RankedTensorType output_type = tf_conv3d_op.getResult().getType().dyn_cast(); @@ -874,9 +925,10 @@ LogicalResult ConvertTFConv3DOp::matchAndRewrite( bias_attr.cast()); llvm::Optional result = convertTFConv3DCommon( - rewriter, op, output_type, tf_conv3d_op.input(), tf_conv3d_op.filter(), - bias, tf_conv3d_op.strides(), tf_conv3d_op.dilations(), - tf_conv3d_op.padding(), tf_conv3d_op.data_format()); + rewriter, op, output_type, tf_conv3d_op.getInput(), + tf_conv3d_op.getFilter(), bias, tf_conv3d_op.getStrides(), + tf_conv3d_op.getDilations(), tf_conv3d_op.getPadding(), + tf_conv3d_op.getDataFormat()); if (!result) return failure(); @@ -890,9 +942,9 @@ LogicalResult ConvertTFDepthwiseConv2dNativeOp::matchAndRewrite( auto tf_dwconv2d_op = cast(op); RankedTensorType input_type = - tf_dwconv2d_op.input().getType().dyn_cast(); + tf_dwconv2d_op.getInput().getType().dyn_cast(); RankedTensorType filter_type = - tf_dwconv2d_op.filter().getType().dyn_cast(); + tf_dwconv2d_op.getFilter().getType().dyn_cast(); RankedTensorType output_type = tf_dwconv2d_op.getResult().getType().dyn_cast(); // Not a ranked tensor output @@ -904,46 +956,47 @@ LogicalResult ConvertTFDepthwiseConv2dNativeOp::matchAndRewrite( return rewriter.notifyMatchFailure(op, "filter type unranked tensor"); } - auto tmpAttr = tf_dwconv2d_op.data_formatAttr(); + auto tmpAttr = tf_dwconv2d_op.getDataFormatAttr(); if (tmpAttr && tmpAttr.getValue().str() != "NHWC") return failure(); - ArrayAttr stride; - ArrayAttr dilation; - ArrayAttr pad; + DenseI64ArrayAttr stride; + DenseI64ArrayAttr dilation; + DenseI64ArrayAttr pad; { - auto tmpAttr = tf_dwconv2d_op.strides(); + auto tmpAttr = tf_dwconv2d_op.getStrides(); if 
(!tmpAttr) { - stride = rewriter.getI64ArrayAttr({1, 1}); + stride = rewriter.getDenseI64ArrayAttr({1, 1}); } else { // Note: hardcoded to NHWC for now int64_t stride_h = tmpAttr[1].dyn_cast().getInt(); int64_t stride_w = tmpAttr[2].dyn_cast().getInt(); - stride = rewriter.getI64ArrayAttr({stride_h, stride_w}); + stride = rewriter.getDenseI64ArrayAttr({stride_h, stride_w}); } } { - auto tmpAttr = tf_dwconv2d_op.dilations(); + auto tmpAttr = tf_dwconv2d_op.getDilations(); if (!tmpAttr) { - dilation = rewriter.getI64ArrayAttr({1, 1}); + dilation = rewriter.getDenseI64ArrayAttr({1, 1}); } else { // Note: hardcoded to NHWC for now int64_t dilation_h = tmpAttr[1].dyn_cast().getInt(); int64_t dilation_w = tmpAttr[2].dyn_cast().getInt(); - dilation = rewriter.getI64ArrayAttr({dilation_h, dilation_w}); + dilation = rewriter.getDenseI64ArrayAttr({dilation_h, dilation_w}); } } { tensorflow::Padding tf_pad; - if (!GetPaddingFromString(tf_dwconv2d_op.padding().str(), &tf_pad).ok()) + if (!GetPaddingFromString(tf_dwconv2d_op.getPadding().str(), &tf_pad).ok()) return failure(); tensorflow::TensorFormat data_format_tf; - if (!FormatFromString(tf_dwconv2d_op.data_format().str(), &data_format_tf)) + if (!FormatFromString(tf_dwconv2d_op.getDataFormat().str(), + &data_format_tf)) return failure(); if (tf_pad == tensorflow::Padding::EXPLICIT) { pad = getPaddingValuesFromExplicitPadAttr( - tf_dwconv2d_op.explicit_paddings(), data_format_tf, rewriter); + tf_dwconv2d_op.getExplicitPaddings(), data_format_tf, rewriter); } else { if (!getPaddingValuesFromPadType(tf_pad, data_format_tf, 0, // tensorflow::FORMAT_HWIO @@ -962,8 +1015,8 @@ LogicalResult ConvertTFDepthwiseConv2dNativeOp::matchAndRewrite( bias_attr.cast()); CreateReplaceOpAndInfer( - rewriter, op, output_type, tf_dwconv2d_op.input(), - tf_dwconv2d_op.filter(), bias, pad, stride, dilation); + rewriter, op, output_type, tf_dwconv2d_op.getInput(), + tf_dwconv2d_op.getFilter(), bias, pad, stride, dilation); return success(); } @@ -972,9 +1025,9 @@ LogicalResult ConvertTFConv2DBackpropInputOp::matchAndRewrite( auto tf_conv_op = cast(op); RankedTensorType input_type = - tf_conv_op.out_backprop().getType().dyn_cast(); + tf_conv_op.getOutBackprop().getType().dyn_cast(); RankedTensorType filter_type = - tf_conv_op.filter().getType().dyn_cast(); + tf_conv_op.getFilter().getType().dyn_cast(); RankedTensorType output_type = tf_conv_op.getResult().getType().dyn_cast(); // Not a ranked tensor output @@ -998,24 +1051,24 @@ LogicalResult ConvertTFConv2DBackpropInputOp::matchAndRewrite( rewriter, op->getLoc(), tensorflow::GetTypeFromTFTensorShape(ArrayRef(a1_transpose_dims), filter_type.getElementType()), - tf_conv_op.filter(), a1_filter_transpose_perm.getValue()); + tf_conv_op.getFilter(), a1_filter_transpose_perm.value()); - ArrayAttr stride; - ArrayAttr outpad; - ArrayAttr output_shape; + DenseI64ArrayAttr stride; + DenseI64ArrayAttr outpad; + DenseI64ArrayAttr output_shape; { - auto tmpAttr = tf_conv_op.strides(); + auto tmpAttr = tf_conv_op.getStrides(); if (!tmpAttr) { - stride = rewriter.getI64ArrayAttr({1, 1}); + stride = rewriter.getDenseI64ArrayAttr({1, 1}); } else { // Note: hardcoded to NHWC for now int64_t stride_h = tmpAttr[1].dyn_cast().getInt(); int64_t stride_w = tmpAttr[2].dyn_cast().getInt(); - stride = rewriter.getI64ArrayAttr({stride_h, stride_w}); + stride = rewriter.getDenseI64ArrayAttr({stride_h, stride_w}); } } { - auto tmpAttr = tf_conv_op.dilations(); + auto tmpAttr = tf_conv_op.getDilations(); if (tmpAttr) { // Note: hardcoded to NHWC for now 
int64_t dilation_h = tmpAttr[1].dyn_cast().getInt(); @@ -1026,16 +1079,16 @@ LogicalResult ConvertTFConv2DBackpropInputOp::matchAndRewrite( } { tensorflow::Padding tf_pad; - if (!GetPaddingFromString(tf_conv_op.padding().str(), &tf_pad).ok()) + if (!GetPaddingFromString(tf_conv_op.getPadding().str(), &tf_pad).ok()) return failure(); tensorflow::TensorFormat data_format_tf; - if (!FormatFromString(tf_conv_op.data_format().str(), &data_format_tf)) + if (!FormatFromString(tf_conv_op.getDataFormat().str(), &data_format_tf)) return failure(); if (tf_pad == tensorflow::Padding::EXPLICIT) { outpad = getPaddingValuesFromExplicitPadAttr( - tf_conv_op.explicit_paddings(), data_format_tf, rewriter); + tf_conv_op.getExplicitPaddings(), data_format_tf, rewriter); } else { if (!getTransposeConv2dPaddingValues(tf_pad, data_format_tf, 0, // tensorflow::FORMAT_HWIO, @@ -1047,16 +1100,16 @@ LogicalResult ConvertTFConv2DBackpropInputOp::matchAndRewrite( { ElementsAttr output_shape_elems; // Match from input_sizes tensor first. - if (matchPattern(tf_conv_op.input_sizes(), + if (matchPattern(tf_conv_op.getInputSizes(), m_Constant(&output_shape_elems))) { SmallVector shape_vec; for (int i = 0; i < output_shape_elems.getNumElements(); i++) shape_vec.push_back( output_shape_elems.getValues()[i].getInt()); - output_shape = rewriter.getI64ArrayAttr(shape_vec); + output_shape = rewriter.getDenseI64ArrayAttr(shape_vec); } else { // Use output tensor's shape otherwise. - output_shape = rewriter.getI64ArrayAttr(output_type.getShape()); + output_shape = rewriter.getDenseI64ArrayAttr(output_type.getShape()); } } @@ -1068,8 +1121,8 @@ LogicalResult ConvertTFConv2DBackpropInputOp::matchAndRewrite( if (!zero_bias) return failure(); CreateReplaceOpAndInfer( - rewriter, op, output_type, tf_conv_op.out_backprop(), - a1_filter_transpose_op.getResult(), zero_bias.getValue(), outpad, stride, + rewriter, op, output_type, tf_conv_op.getOutBackprop(), + a1_filter_transpose_op.getResult(), zero_bias.value(), outpad, stride, output_shape); return success(); @@ -1084,15 +1137,15 @@ LogicalResult ConvertTFAllOp::matchAndRewrite(Operation* op, if (!output_type) return failure(); ElementsAttr axes_elems; - if (!matchPattern(tf_all_op.reduction_indices(), m_Constant(&axes_elems))) + if (!matchPattern(tf_all_op.getReductionIndices(), m_Constant(&axes_elems))) return failure(); llvm::Optional result = convertReduceAllOp( - rewriter, op, output_type, tf_all_op.input(), axes_elems); + rewriter, op, output_type, tf_all_op.getInput(), axes_elems); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -1106,15 +1159,15 @@ LogicalResult ConvertTFAnyOp::matchAndRewrite(Operation* op, if (!output_type) return failure(); ElementsAttr axes_elems; - if (!matchPattern(tf_any_op.reduction_indices(), m_Constant(&axes_elems))) + if (!matchPattern(tf_any_op.getReductionIndices(), m_Constant(&axes_elems))) return failure(); llvm::Optional result = convertReduceAnyOp( - rewriter, op, output_type, tf_any_op.input(), axes_elems); + rewriter, op, output_type, tf_any_op.getInput(), axes_elems); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -1128,15 +1181,15 @@ LogicalResult ConvertTFMaxOp::matchAndRewrite(Operation* op, if (!output_type) return failure(); ElementsAttr axes_elems; - if (!matchPattern(tf_max_op.reduction_indices(), m_Constant(&axes_elems))) + if 
(!matchPattern(tf_max_op.getReductionIndices(), m_Constant(&axes_elems))) return failure(); llvm::Optional result = convertReduceMaxOp( - rewriter, op, output_type, tf_max_op.input(), axes_elems); + rewriter, op, output_type, tf_max_op.getInput(), axes_elems); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -1150,15 +1203,15 @@ LogicalResult ConvertTFMinOp::matchAndRewrite(Operation* op, if (!output_type) return failure(); ElementsAttr axes_elems; - if (!matchPattern(tf_min_op.reduction_indices(), m_Constant(&axes_elems))) + if (!matchPattern(tf_min_op.getReductionIndices(), m_Constant(&axes_elems))) return failure(); llvm::Optional result = convertReduceMinOp( - rewriter, op, output_type, tf_min_op.input(), axes_elems); + rewriter, op, output_type, tf_min_op.getInput(), axes_elems); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -1172,15 +1225,15 @@ LogicalResult ConvertTFMeanOp::matchAndRewrite( if (!output_type) return failure(); ElementsAttr axes_elems; - if (!matchPattern(tf_mean_op.reduction_indices(), m_Constant(&axes_elems))) + if (!matchPattern(tf_mean_op.getReductionIndices(), m_Constant(&axes_elems))) return failure(); llvm::Optional result = convertReduceMeanOp( - rewriter, op, output_type, tf_mean_op.input(), axes_elems); + rewriter, op, output_type, tf_mean_op.getInput(), axes_elems); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -1194,15 +1247,15 @@ LogicalResult ConvertTFProdOp::matchAndRewrite( if (!output_type) return failure(); ElementsAttr axes_elems; - if (!matchPattern(tf_prod_op.reduction_indices(), m_Constant(&axes_elems))) + if (!matchPattern(tf_prod_op.getReductionIndices(), m_Constant(&axes_elems))) return failure(); llvm::Optional result = convertReduceProdOp( - rewriter, op, output_type, tf_prod_op.input(), axes_elems); + rewriter, op, output_type, tf_prod_op.getInput(), axes_elems); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -1216,15 +1269,15 @@ LogicalResult ConvertTFSumOp::matchAndRewrite(Operation* op, if (!output_type) return failure(); ElementsAttr axes_elems; - if (!matchPattern(tf_sum_op.reduction_indices(), m_Constant(&axes_elems))) + if (!matchPattern(tf_sum_op.getReductionIndices(), m_Constant(&axes_elems))) return failure(); llvm::Optional result = convertReduceSumOp( - rewriter, op, output_type, tf_sum_op.input(), axes_elems); + rewriter, op, output_type, tf_sum_op.getInput(), axes_elems); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -1233,12 +1286,12 @@ LogicalResult ConvertTFEluOp::matchAndRewrite(Operation* op, PatternRewriter& rewriter) const { auto tf_elu_op = cast(op); - llvm::Optional result = - convertEluOp(rewriter, op, tf_elu_op.getResult(), tf_elu_op.features()); + llvm::Optional result = convertEluOp( + rewriter, op, tf_elu_op.getResult(), tf_elu_op.getFeatures()); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -1249,11 +1302,11 @@ LogicalResult ConvertTFSoftmaxOp::matchAndRewrite( llvm::Optional result = convertSoftmaxOp(rewriter, op, 
tf_softmax_op.getResult(), - tf_softmax_op.logits(), /*beta=*/1.0); + tf_softmax_op.getLogits(), /*beta=*/1.0); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -1263,11 +1316,11 @@ LogicalResult ConvertTFLogSoftmaxOp::matchAndRewrite( auto tf_logsoftmax_op = cast(op); llvm::Optional result = convertLogSoftmaxOp( - rewriter, op, tf_logsoftmax_op.getResult(), tf_logsoftmax_op.logits()); + rewriter, op, tf_logsoftmax_op.getResult(), tf_logsoftmax_op.getLogits()); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -1302,9 +1355,9 @@ LogicalResult ConvertTFFusedBatchNormOp::matchAndRewrite( // op6 = add(op5, boffset) RankedTensorType mean_type = - tf_batchnorm_op.mean().getType().dyn_cast(); + tf_batchnorm_op.getMean().getType().dyn_cast(); RankedTensorType variance_type = - tf_batchnorm_op.variance().getType().dyn_cast(); + tf_batchnorm_op.getVariance().getType().dyn_cast(); if (!variance_type || !mean_type) return failure(); Value mean_val, variance_val; @@ -1312,25 +1365,25 @@ LogicalResult ConvertTFFusedBatchNormOp::matchAndRewrite( if (mean_type.getNumElements() == 0) { mean_val = getTosaConstTensorSingleF32(rewriter, tf_batchnorm_op, 0); } else { - mean_val = tf_batchnorm_op.mean(); + mean_val = tf_batchnorm_op.getMean(); } if (variance_type.getNumElements() == 0) { variance_val = getTosaConstTensorSingleF32(rewriter, tf_batchnorm_op, 1.0); } else { - variance_val = tf_batchnorm_op.variance(); + variance_val = tf_batchnorm_op.getVariance(); } RankedTensorType epsilon_type = tensorflow::GetTypeFromTFTensorShape({1}, variance_type.getElementType()); auto epsilon_attr = - DenseFPElementsAttr::get(epsilon_type, {tf_batchnorm_op.epsilon()}); + DenseFPElementsAttr::get(epsilon_type, {tf_batchnorm_op.getEpsilon()}); auto epsilon_const = CreateOpAndInfer( rewriter, op->getLoc(), epsilon_type, epsilon_attr); auto op1_sub_input_mean = CreateOpAndInfer( rewriter, op->getLoc(), tf_batchnorm_op.getResult(0).getType(), - tf_batchnorm_op.x(), mean_val); + tf_batchnorm_op.getX(), mean_val); auto op2_add_var_epsilon = CreateOpAndInfer( rewriter, op->getLoc(), variance_val.getType(), variance_val, @@ -1346,11 +1399,11 @@ LogicalResult ConvertTFFusedBatchNormOp::matchAndRewrite( auto op5_mul_op4_scale = CreateOpAndInfer( rewriter, op->getLoc(), tf_batchnorm_op.getResult(0).getType(), - op4_mul_op1_op3.getResult(), tf_batchnorm_op.scale(), 0); + op4_mul_op1_op3.getResult(), tf_batchnorm_op.getScale(), 0); auto op6_add_op5_offset = CreateOpAndInfer( rewriter, op->getLoc(), tf_batchnorm_op.getResult(0).getType(), - op5_mul_op4_scale.getResult(), tf_batchnorm_op.offset()); + op5_mul_op4_scale.getResult(), tf_batchnorm_op.getOffset()); rewriter.replaceOp(op, {op6_add_op5_offset.getResult()}); return success(); @@ -1360,7 +1413,7 @@ LogicalResult ConvertTFFusedBatchNormV3Op::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_batchnorm_op = cast(op); - if (tf_batchnorm_op.is_training()) + if (tf_batchnorm_op.getIsTraining()) return rewriter.notifyMatchFailure( op, "unable to lower when is_training is set"); @@ -1389,25 +1442,25 @@ LogicalResult ConvertTFFusedBatchNormV3Op::matchAndRewrite( auto op1_sub_input_mean = CreateOpAndInfer( rewriter, op->getLoc(), tf_batchnorm_op.getResult(0).getType(), - tf_batchnorm_op.x(), tf_batchnorm_op.mean()); + tf_batchnorm_op.getX(), tf_batchnorm_op.getMean()); 
RankedTensorType variance_type = - tf_batchnorm_op.variance().getType().dyn_cast(); + tf_batchnorm_op.getVariance().getType().dyn_cast(); if (!variance_type) return failure(); auto epsilon_type = tensorflow::GetTypeFromTFTensorShape({1}, variance_type.getElementType()); auto epsilon_attr = - DenseFPElementsAttr::get(epsilon_type, {tf_batchnorm_op.epsilon()}); + DenseFPElementsAttr::get(epsilon_type, {tf_batchnorm_op.getEpsilon()}); auto epsilon_const = CreateOpAndInfer( rewriter, op->getLoc(), epsilon_type, epsilon_attr); auto op2_add_var_epsilon = CreateOpAndInfer( - rewriter, op->getLoc(), tf_batchnorm_op.variance().getType(), - tf_batchnorm_op.variance(), epsilon_const); + rewriter, op->getLoc(), tf_batchnorm_op.getVariance().getType(), + tf_batchnorm_op.getVariance(), epsilon_const); auto op3_rsqrt_op2 = CreateOpAndInfer( - rewriter, op->getLoc(), tf_batchnorm_op.variance().getType(), + rewriter, op->getLoc(), tf_batchnorm_op.getVariance().getType(), op2_add_var_epsilon.getResult()); auto op4_mul_op1_op3 = CreateOpAndInfer( @@ -1416,18 +1469,18 @@ LogicalResult ConvertTFFusedBatchNormV3Op::matchAndRewrite( auto op5_mul_op4_scale = CreateOpAndInfer( rewriter, op->getLoc(), tf_batchnorm_op.getResult(0).getType(), - op4_mul_op1_op3.getResult(), tf_batchnorm_op.scale(), 0); + op4_mul_op1_op3.getResult(), tf_batchnorm_op.getScale(), 0); auto op6_add_op5_offset = CreateOpAndInfer( rewriter, op->getLoc(), tf_batchnorm_op.getResult(0).getType(), - op5_mul_op4_scale.getResult(), tf_batchnorm_op.offset()); + op5_mul_op4_scale.getResult(), tf_batchnorm_op.getOffset()); llvm::SmallVector replacements = { - op6_add_op5_offset.getResult(), tf_batchnorm_op.mean(), - tf_batchnorm_op.variance(), + op6_add_op5_offset.getResult(), tf_batchnorm_op.getMean(), + tf_batchnorm_op.getVariance(), // The last three are reserved spaces and have no purpose currently. - tf_batchnorm_op.mean(), tf_batchnorm_op.variance(), - tf_batchnorm_op.variance()}; + tf_batchnorm_op.getMean(), tf_batchnorm_op.getVariance(), + tf_batchnorm_op.getVariance()}; rewriter.replaceOp(op, replacements); return success(); } @@ -1442,8 +1495,8 @@ LogicalResult ConvertTFBiasAddOp::matchAndRewrite( if (!output_type) return failure(); auto add_op = CreateOpAndInfer( - rewriter, op->getLoc(), output_type, tf_biasadd_op.value(), - tf_biasadd_op.bias()); + rewriter, op->getLoc(), output_type, tf_biasadd_op.getValue(), + tf_biasadd_op.getBias()); rewriter.replaceOp(op, {add_op.getResult()}); return success(); @@ -1463,7 +1516,7 @@ LogicalResult ConvertTFSliceOp::matchAndRewrite( SmallVector begin_vals, size_vals; // Assuming begin is always compile-time constant - if (!matchPattern(tf_slice_op.begin(), m_Constant(&begin_elems))) { + if (!matchPattern(tf_slice_op.getBegin(), m_Constant(&begin_elems))) { return rewriter.notifyMatchFailure(op, "begin is not constant"); } @@ -1472,7 +1525,7 @@ LogicalResult ConvertTFSliceOp::matchAndRewrite( // Try to match size as compile-time constant first, // if this fails, use the output tensor shape instead. 
- if (matchPattern(tf_slice_op.size(), m_Constant(&size_elems))) { + if (matchPattern(tf_slice_op.getSize(), m_Constant(&size_elems))) { for (int i = 0; i < size_elems.getNumElements(); i++) size_vals.push_back(size_elems.getValues()[i].getInt()); } else { @@ -1480,11 +1533,11 @@ LogicalResult ConvertTFSliceOp::matchAndRewrite( output_type.getShape().end()); } - ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals); - ArrayAttr size = rewriter.getI64ArrayAttr(size_vals); + DenseI64ArrayAttr begin = rewriter.getDenseI64ArrayAttr(begin_vals); + DenseI64ArrayAttr size = rewriter.getDenseI64ArrayAttr(size_vals); CreateReplaceOpAndInfer(rewriter, op, output_type, - tf_slice_op.input(), begin, size); + tf_slice_op.getInput(), begin, size); return success(); } @@ -1498,17 +1551,18 @@ LogicalResult ConvertTFTileOp::matchAndRewrite( if (!output_type) return failure(); ElementsAttr multiples_elems; - if (!matchPattern(tf_tile_op.multiples(), m_Constant(&multiples_elems))) + if (!matchPattern(tf_tile_op.getMultiples(), m_Constant(&multiples_elems))) return failure(); SmallVector multiples_vals; for (int i = 0; i < multiples_elems.getNumElements(); i++) multiples_vals.push_back( multiples_elems.getValues()[i].getInt()); - ArrayAttr multiples_attr = rewriter.getI64ArrayAttr(multiples_vals); + DenseI64ArrayAttr multiples_attr = + rewriter.getDenseI64ArrayAttr(multiples_vals); CreateReplaceOpAndInfer(rewriter, op, output_type, - tf_tile_op.input(), multiples_attr); + tf_tile_op.getInput(), multiples_attr); return success(); } @@ -1524,8 +1578,9 @@ LogicalResult ConvertTFTransposeOp::matchAndRewrite( return failure(); } - CreateReplaceOpAndInfer( - rewriter, op, output_type, tf_transpose_op.x(), tf_transpose_op.perm()); + CreateReplaceOpAndInfer(rewriter, op, output_type, + tf_transpose_op.getX(), + tf_transpose_op.getPerm()); return success(); } @@ -1534,11 +1589,11 @@ LogicalResult ConvertTFPackOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_pack_op = cast(op); - SmallVector inputs(tf_pack_op.values()); + SmallVector inputs(tf_pack_op.getValues()); assert(inputs.size() >= 2); - IntegerAttr axis_attr = tf_pack_op.axisAttr(); + IntegerAttr axis_attr = tf_pack_op.getAxisAttr(); if (!axis_attr) axis_attr = rewriter.getI64IntegerAttr(0); int32_t axis_i32 = axis_attr.getInt(); @@ -1548,7 +1603,7 @@ LogicalResult ConvertTFPackOp::matchAndRewrite( if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -1559,18 +1614,18 @@ LogicalResult ConvertTFUnpackOp::matchAndRewrite( IntegerAttr axis_attr; { - auto tmpAttr = tf_unpack_op.axisAttr(); + auto tmpAttr = tf_unpack_op.getAxisAttr(); if (!tmpAttr) tmpAttr = rewriter.getI64IntegerAttr(0); axis_attr = tmpAttr; } int32_t axis_i32 = axis_attr.getInt(); llvm::Optional> results = - convertUnpackOp(rewriter, op, tf_unpack_op.value(), axis_i32); + convertUnpackOp(rewriter, op, tf_unpack_op.getValue(), axis_i32); if (!results) return failure(); - rewriter.replaceOp(op, results.getValue()); + rewriter.replaceOp(op, results.value()); return success(); } @@ -1589,17 +1644,17 @@ LogicalResult ConvertTFSplitOp::matchAndRewrite( // Get the axis int32_t axis = 0; ElementsAttr axisAttrElems; - if (matchPattern(tf_split_op.split_dim(), m_Constant(&axisAttrElems))) { + if (matchPattern(tf_split_op.getSplitDim(), m_Constant(&axisAttrElems))) { axis = axisAttrElems.getValues()[0].getInt(); } llvm::Optional> results = convertSplitOp(rewriter, op, 
tf_split_op.getResult(0), - tf_split_op.value(), num_split, axis); + tf_split_op.getValue(), num_split, axis); if (!results) return failure(); - rewriter.replaceOp(op, results.getValue()); + rewriter.replaceOp(op, results.value()); return success(); } @@ -1612,7 +1667,7 @@ LogicalResult ConvertTFSplitVOp::matchAndRewrite( // Get the size_splits array SmallVector size_split; ElementsAttr size_split_elems; - if (!matchPattern(tf_splitv_op.size_splits(), + if (!matchPattern(tf_splitv_op.getSizeSplits(), m_Constant(&size_split_elems))) { return failure(); } @@ -1623,7 +1678,7 @@ LogicalResult ConvertTFSplitVOp::matchAndRewrite( // Get the axis ElementsAttr axisAttrElems; - if (!matchPattern(tf_splitv_op.split_dim(), m_Constant(&axisAttrElems))) { + if (!matchPattern(tf_splitv_op.getSplitDim(), m_Constant(&axisAttrElems))) { return rewriter.notifyMatchFailure(op, "cannot read split_dim elems"); } @@ -1631,11 +1686,11 @@ LogicalResult ConvertTFSplitVOp::matchAndRewrite( llvm::Optional> results = convertSplitVOp(rewriter, op, tf_splitv_op.getResult(0), - tf_splitv_op.value(), size_split, axis); + tf_splitv_op.getValue(), size_split, axis); if (!results) return failure(); - rewriter.replaceOp(op, results.getValue()); + rewriter.replaceOp(op, results.value()); return success(); } @@ -1651,7 +1706,8 @@ LogicalResult ConvertTFLessOp::matchAndRewrite( // less(x, y) is not(greater_equal(x, y)) auto greater_equal_op = CreateOpAndInfer( - rewriter, op->getLoc(), output_type, tf_less_op.x(), tf_less_op.y()); + rewriter, op->getLoc(), output_type, tf_less_op.getX(), + tf_less_op.getY()); auto not_op = CreateOpAndInfer( rewriter, op->getLoc(), output_type, greater_equal_op.getResult()); @@ -1671,8 +1727,8 @@ LogicalResult ConvertTFLessEqualOp::matchAndRewrite( // less_equal(x, y) is not(greater(x, y)) auto greater_op = CreateOpAndInfer( - rewriter, op->getLoc(), output_type, tf_less_equal_op.x(), - tf_less_equal_op.y()); + rewriter, op->getLoc(), output_type, tf_less_equal_op.getX(), + tf_less_equal_op.getY()); auto not_op = CreateOpAndInfer( rewriter, op->getLoc(), output_type, greater_op.getResult()); @@ -1689,9 +1745,9 @@ LogicalResult ConvertTFPadOp::matchAndRewrite(Operation* op, // Not a ranked tensor output if (!output_type) return failure(); - auto pad_op = - CreateOpAndInfer(rewriter, op->getLoc(), output_type, - tf_pad_op.input(), tf_pad_op.paddings()); + auto pad_op = CreateOpAndInfer(rewriter, op->getLoc(), + output_type, tf_pad_op.getInput(), + tf_pad_op.getPaddings()); rewriter.replaceOp(op, {pad_op.getResult()}); return success(); @@ -1708,7 +1764,7 @@ LogicalResult ConvertTFMirrorPadOp::matchAndRewrite( } TFTFLMirrorPaddingType mode; - StringRef tf_mode = tf_mirrorpad_op.mode(); + StringRef tf_mode = tf_mirrorpad_op.getMode(); if (tf_mode == "REFLECT") { mode = TFTFLMirrorPaddingType::REFLECT; } else if (tf_mode == "SYMMETRIC") { @@ -1718,11 +1774,11 @@ LogicalResult ConvertTFMirrorPadOp::matchAndRewrite( op, "mode isn't one of REFLECT or SYMMETRIC"); } - llvm::Optional result = - convertMirrorPadCommon(rewriter, op, output_type, tf_mirrorpad_op.input(), - tf_mirrorpad_op.paddings(), mode); + llvm::Optional result = convertMirrorPadCommon( + rewriter, op, output_type, tf_mirrorpad_op.getInput(), + tf_mirrorpad_op.getPaddings(), mode); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -1737,13 +1793,13 @@ LogicalResult ConvertTFResizeBilinearOp::matchAndRewrite( if (!output_type) return failure(); llvm::Optional result = 
convertResizeOp( - rewriter, op, output_type, tf_resize_op.images(), StringRef("BILINEAR"), - tf_resize_op.align_cornersAttr().getValue(), - tf_resize_op.half_pixel_centersAttr().getValue()); + rewriter, op, output_type, tf_resize_op.getImages(), + StringRef("BILINEAR"), tf_resize_op.getAlignCornersAttr().getValue(), + tf_resize_op.getHalfPixelCentersAttr().getValue()); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -1758,14 +1814,14 @@ LogicalResult ConvertTFResizeNearestNeighborOp::matchAndRewrite( if (!output_type) return failure(); llvm::Optional result = - convertResizeOp(rewriter, op, output_type, tf_resize_op.images(), + convertResizeOp(rewriter, op, output_type, tf_resize_op.getImages(), StringRef("NEAREST_NEIGHBOR"), - tf_resize_op.align_cornersAttr().getValue(), - tf_resize_op.half_pixel_centersAttr().getValue()); + tf_resize_op.getAlignCornersAttr().getValue(), + tf_resize_op.getHalfPixelCentersAttr().getValue()); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -1775,9 +1831,9 @@ LogicalResult ConvertTFMatMulOp::matchAndRewrite( auto tf_matmul_op = cast(op); RankedTensorType a_type = - tf_matmul_op.a().getType().dyn_cast(); + tf_matmul_op.getA().getType().dyn_cast(); RankedTensorType b_type = - tf_matmul_op.b().getType().dyn_cast(); + tf_matmul_op.getB().getType().dyn_cast(); RankedTensorType output_type = tf_matmul_op.getResult().getType().dyn_cast(); @@ -1814,12 +1870,12 @@ LogicalResult ConvertTFMatMulOp::matchAndRewrite( // Need to reshape input and output since TOSA matmul only supports // [N, H, C] * [N, C, W] -> [N, H, W]. auto op1_reshape_a = CreateOpAndInfer( - rewriter, op->getLoc(), batch_a_type, tf_matmul_op.a(), - rewriter.getI64ArrayAttr(batch_a_shape)); + rewriter, op->getLoc(), batch_a_type, tf_matmul_op.getA(), + rewriter.getDenseI64ArrayAttr(batch_a_shape)); auto op2_reshape_b = CreateOpAndInfer( - rewriter, op->getLoc(), batch_b_type, tf_matmul_op.b(), - rewriter.getI64ArrayAttr(batch_b_shape)); + rewriter, op->getLoc(), batch_b_type, tf_matmul_op.getB(), + rewriter.getDenseI64ArrayAttr(batch_b_shape)); auto op3_matmul_op1_op2 = CreateOpAndInfer( rewriter, op->getLoc(), batch_output_type, op1_reshape_a.getResult(), @@ -1827,7 +1883,7 @@ LogicalResult ConvertTFMatMulOp::matchAndRewrite( CreateReplaceOpAndInfer( rewriter, op, output_type, op3_matmul_op1_op2.getResult(), - rewriter.getI64ArrayAttr(output_type.getShape())); + rewriter.getDenseI64ArrayAttr(output_type.getShape())); return success(); } @@ -1841,12 +1897,12 @@ LogicalResult ConvertTFGatherOp::matchAndRewrite( int32_t axis = 0; llvm::Optional result = convertGatherOp( - rewriter, op, tf_gather_op.getResult(), tf_gather_op.params(), - tf_gather_op.indices(), batch_dims, axis); + rewriter, op, tf_gather_op.getResult(), tf_gather_op.getParams(), + tf_gather_op.getIndices(), batch_dims, axis); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -1857,20 +1913,20 @@ LogicalResult ConvertTFGatherV2Op::matchAndRewrite( // Axis is a tensor. Pull out the one integer value. 
ElementsAttr axis_elem; - if (!matchPattern(tf_gather_op.axis(), m_Constant(&axis_elem))) + if (!matchPattern(tf_gather_op.getAxis(), m_Constant(&axis_elem))) return failure(); assert(axis_elem.getNumElements() == 1); int32_t axis = axis_elem.getValues()[0].getInt(); - int32_t batch_dims = tf_gather_op.batch_dimsAttr().getInt(); + int32_t batch_dims = tf_gather_op.getBatchDimsAttr().getInt(); llvm::Optional result = convertGatherOp( - rewriter, op, tf_gather_op.getResult(), tf_gather_op.params(), - tf_gather_op.indices(), batch_dims, axis); + rewriter, op, tf_gather_op.getResult(), tf_gather_op.getParams(), + tf_gather_op.getIndices(), batch_dims, axis); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -1879,13 +1935,13 @@ LogicalResult ConvertTFGatherNdOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_gathernd_op = cast(op); - llvm::Optional result = - convertGatherNdOp(rewriter, op, tf_gathernd_op.getResult(), - tf_gathernd_op.params(), tf_gathernd_op.indices()); + llvm::Optional result = convertGatherNdOp( + rewriter, op, tf_gathernd_op.getResult(), tf_gathernd_op.getParams(), + tf_gathernd_op.getIndices()); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -1895,12 +1951,12 @@ LogicalResult ConvertTFSelectV2Op::matchAndRewrite( auto tf_sel_op = cast(op); llvm::Optional result = convertSelectOp( - rewriter, op, tf_sel_op.getResult(), tf_sel_op.condition(), - tf_sel_op.then_value(), tf_sel_op.else_value()); + rewriter, op, tf_sel_op.getResult(), tf_sel_op.getCondition(), + tf_sel_op.getThenValue(), tf_sel_op.getElseValue()); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -1910,12 +1966,12 @@ LogicalResult ConvertTFSpaceToDepthOp::matchAndRewrite( auto tf_s2d_op = cast(op); llvm::Optional result = convertSpaceToDepthOp( - rewriter, op, tf_s2d_op.getResult(), tf_s2d_op.input(), - tf_s2d_op.block_sizeAttr(), tf_s2d_op.data_formatAttr()); + rewriter, op, tf_s2d_op.getResult(), tf_s2d_op.getInput(), + tf_s2d_op.getBlockSizeAttr(), tf_s2d_op.getDataFormatAttr()); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -1925,12 +1981,12 @@ LogicalResult ConvertTFDepthToSpaceOp::matchAndRewrite( auto tf_d2s_op = cast(op); llvm::Optional result = convertDepthToSpaceOp( - rewriter, op, tf_d2s_op.getResult(), tf_d2s_op.input(), - tf_d2s_op.block_sizeAttr(), tf_d2s_op.data_formatAttr()); + rewriter, op, tf_d2s_op.getResult(), tf_d2s_op.getInput(), + tf_d2s_op.getBlockSizeAttr(), tf_d2s_op.getDataFormatAttr()); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -1940,11 +1996,11 @@ LogicalResult ConvertTFSpaceToBatchNDOp::matchAndRewrite( auto tf_s2b_op = cast(op); llvm::Optional result = convertSpaceToBatchNDOp( - rewriter, op, tf_s2b_op.getResult(), tf_s2b_op.input(), - tf_s2b_op.block_shape(), tf_s2b_op.paddings()); + rewriter, op, tf_s2b_op.getResult(), tf_s2b_op.getInput(), + tf_s2b_op.getBlockShape(), tf_s2b_op.getPaddings()); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -1954,12 +2010,12 @@ LogicalResult 
ConvertTFBatchToSpaceNDOp::matchAndRewrite( auto tf_b2s_op = cast(op); llvm::Optional result = convertBatchToSpaceNDOp( - rewriter, op, tf_b2s_op.getResult(), tf_b2s_op.input(), - tf_b2s_op.block_shape(), tf_b2s_op.crops()); + rewriter, op, tf_b2s_op.getResult(), tf_b2s_op.getInput(), + tf_b2s_op.getBlockShape(), tf_b2s_op.getCrops()); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -1969,15 +2025,16 @@ LogicalResult ConvertTFStridedSliceOp::matchAndRewrite( auto tf_ss_op = cast(op); llvm::Optional result = convertStridedSliceOp( - rewriter, op, tf_ss_op.getResult(), tf_ss_op.input(), tf_ss_op.begin(), - tf_ss_op.end(), tf_ss_op.strides(), tf_ss_op.begin_maskAttr().getInt(), - tf_ss_op.end_maskAttr().getInt(), tf_ss_op.ellipsis_maskAttr().getInt(), - tf_ss_op.new_axis_maskAttr().getInt(), - tf_ss_op.shrink_axis_maskAttr().getInt()); + rewriter, op, tf_ss_op.getResult(), tf_ss_op.getInput(), + tf_ss_op.getBegin(), tf_ss_op.getEnd(), tf_ss_op.getStrides(), + tf_ss_op.getBeginMaskAttr().getInt(), tf_ss_op.getEndMaskAttr().getInt(), + tf_ss_op.getEllipsisMaskAttr().getInt(), + tf_ss_op.getNewAxisMaskAttr().getInt(), + tf_ss_op.getShrinkAxisMaskAttr().getInt()); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -1987,11 +2044,11 @@ LogicalResult ConvertTFZerosLikeOp::matchAndRewrite( auto tf_zeroslike_op = cast(op); llvm::Optional result = convertZerosLikeOp( - rewriter, op, tf_zeroslike_op.getResult(), tf_zeroslike_op.x()); + rewriter, op, tf_zeroslike_op.getResult(), tf_zeroslike_op.getX()); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -2004,7 +2061,7 @@ LogicalResult ConvertTFSigmoidOp::matchAndRewrite( if (!output_type) return failure(); CreateReplaceOpAndInfer(rewriter, op, output_type, - tf_sigmoid_op.x()); + tf_sigmoid_op.getX()); return success(); } @@ -2017,7 +2074,7 @@ LogicalResult ConvertTFTanhOp::matchAndRewrite( if (!output_type) return failure(); CreateReplaceOpAndInfer(rewriter, op, output_type, - tf_tanh_op.x()); + tf_tanh_op.getX()); return success(); } @@ -2050,7 +2107,7 @@ LogicalResult ConvertTFLeakyReluOp::matchAndRewrite( return rewriter.notifyMatchFailure(op, "only support F32"); } - FloatAttr tmpAttr = tf_leakyrelu_op.alphaAttr(); + FloatAttr tmpAttr = tf_leakyrelu_op.getAlphaAttr(); // There is disagreement between the MLIR .td defaults and TF // documentation on 0.2 vs 0.3, but 0.2 will be used here. 
double alpha = 0.2; @@ -2062,15 +2119,15 @@ LogicalResult ConvertTFLeakyReluOp::matchAndRewrite( Value const_zero = getTosaConstTensorSingleF32(rewriter, op, 0.0); auto a1_mul = CreateOpAndInfer( - rewriter, op->getLoc(), output_type, tf_leakyrelu_op.features(), + rewriter, op->getLoc(), output_type, tf_leakyrelu_op.getFeatures(), getTosaConstTensorSingleF32(rewriter, op, alpha), 0); auto a2_ge = CreateOpAndInfer( rewriter, op->getLoc(), UnrankedTensorType::get(rewriter.getI1Type()), - tf_leakyrelu_op.features(), const_zero); + tf_leakyrelu_op.getFeatures(), const_zero); auto a3_select = CreateOpAndInfer( - rewriter, op->getLoc(), output_type, a2_ge, tf_leakyrelu_op.features(), + rewriter, op->getLoc(), output_type, a2_ge, tf_leakyrelu_op.getFeatures(), a1_mul.getResult()); rewriter.replaceOp(op, {a3_select.getResult()}); @@ -2086,7 +2143,7 @@ LogicalResult ConvertTFNegOp::matchAndRewrite(Operation* op, if (!output_type) return failure(); CreateReplaceOpAndInfer(rewriter, op, output_type, - tf_neg_op.x()); + tf_neg_op.getX()); return success(); } @@ -2099,7 +2156,7 @@ LogicalResult ConvertTFStopGradientOp::matchAndRewrite( if (!output_type) return failure(); CreateReplaceOpAndInfer(rewriter, op, output_type, - tf_stopgrad_op.input()); + tf_stopgrad_op.getInput()); return success(); } @@ -2108,17 +2165,17 @@ LogicalResult ConvertTFReverseV2Op::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_reverse_op = cast(op); RankedTensorType input_type = - tf_reverse_op.tensor().getType().dyn_cast(); + tf_reverse_op.getTensor().getType().dyn_cast(); TensorType output_type = tf_reverse_op.getResult().getType().dyn_cast(); if (!input_type || !output_type) return failure(); ElementsAttr axis_elems; - if (!matchPattern(tf_reverse_op.axis(), m_Constant(&axis_elems))) + if (!matchPattern(tf_reverse_op.getAxis(), m_Constant(&axis_elems))) return failure(); auto input_rank = input_type.getShape().size(); - Value val = tf_reverse_op.tensor(); + Value val = tf_reverse_op.getTensor(); if (axis_elems.getNumElements() == 0) { auto identity_op = CreateOpAndInfer( rewriter, op->getLoc(), output_type, val); @@ -2150,15 +2207,15 @@ LogicalResult ConvertTFFakeQuantWithMinMaxArgsOp::matchAndRewrite( if (!output_type) return failure(); llvm::Optional result = - convertFakeQuantOp(rewriter, op, output_type, tf_fakequant_op.inputs(), - tf_fakequant_op.minAttr().getValueAsDouble(), - tf_fakequant_op.maxAttr().getValueAsDouble(), - tf_fakequant_op.num_bitsAttr().getInt(), - tf_fakequant_op.narrow_rangeAttr().getValue()); + convertFakeQuantOp(rewriter, op, output_type, tf_fakequant_op.getInputs(), + tf_fakequant_op.getMinAttr().getValueAsDouble(), + tf_fakequant_op.getMaxAttr().getValueAsDouble(), + tf_fakequant_op.getNumBitsAttr().getInt(), + tf_fakequant_op.getNarrowRangeAttr().getValue()); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -2174,10 +2231,10 @@ LogicalResult ConvertTFFakeQuantWithMinMaxVarsOp::matchAndRewrite( // Only support min/max that can be matched at compile time ElementsAttr min_elems, max_elems; - if (!matchPattern(tf_fakequant_op.min(), m_Constant(&min_elems))) + if (!matchPattern(tf_fakequant_op.getMin(), m_Constant(&min_elems))) return failure(); - if (!matchPattern(tf_fakequant_op.max(), m_Constant(&max_elems))) + if (!matchPattern(tf_fakequant_op.getMax(), m_Constant(&max_elems))) return failure(); if (min_elems.getNumElements() != 1 && max_elems.getNumElements() != 1) @@ 
-2187,13 +2244,13 @@ LogicalResult ConvertTFFakeQuantWithMinMaxVarsOp::matchAndRewrite( int64_t max_val = max_elems.getValues()[0].getInt(); llvm::Optional result = convertFakeQuantOp( - rewriter, op, output_type, tf_fakequant_op.inputs(), min_val, max_val, - tf_fakequant_op.num_bitsAttr().getInt(), - tf_fakequant_op.narrow_rangeAttr().getValue()); + rewriter, op, output_type, tf_fakequant_op.getInputs(), min_val, max_val, + tf_fakequant_op.getNumBitsAttr().getInt(), + tf_fakequant_op.getNarrowRangeAttr().getValue()); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -2206,8 +2263,9 @@ LogicalResult ConvertTFLeftShiftOp::matchAndRewrite( tf_left_shift_op.getResult().getType().dyn_cast(); if (!output_type) return failure(); - CreateReplaceOpAndInfer( - rewriter, op, output_type, tf_left_shift_op.x(), tf_left_shift_op.y()); + CreateReplaceOpAndInfer(rewriter, op, output_type, + tf_left_shift_op.getX(), + tf_left_shift_op.getY()); return success(); } @@ -2229,12 +2287,12 @@ LogicalResult ConvertTFRightShiftOp::matchAndRewrite( if (is_signed) { CreateReplaceOpAndInfer( - rewriter, op, output_type, tf_right_shift_op.x(), tf_right_shift_op.y(), - false); + rewriter, op, output_type, tf_right_shift_op.getX(), + tf_right_shift_op.getY(), false); } else { CreateReplaceOpAndInfer( - rewriter, op, output_type, tf_right_shift_op.x(), - tf_right_shift_op.y()); + rewriter, op, output_type, tf_right_shift_op.getX(), + tf_right_shift_op.getY()); } return success(); @@ -2245,20 +2303,20 @@ LogicalResult ConvertTFOneHotOp::matchAndRewrite( auto tf_one_hot_op = cast(op); ElementsAttr depth_elems; - if (!matchPattern(tf_one_hot_op.depth(), m_Constant(&depth_elems))) + if (!matchPattern(tf_one_hot_op.getDepth(), m_Constant(&depth_elems))) return failure(); int32_t depth = depth_elems.getValues()[0].getInt(); - IntegerAttr axisAttr = tf_one_hot_op.axisAttr(); + IntegerAttr axisAttr = tf_one_hot_op.getAxisAttr(); int32_t axis = axisAttr.getInt(); llvm::Optional result = convertOneHotOp( - rewriter, op, tf_one_hot_op.getResult(), tf_one_hot_op.indices(), - tf_one_hot_op.on_value(), tf_one_hot_op.off_value(), depth, axis); + rewriter, op, tf_one_hot_op.getResult(), tf_one_hot_op.getIndices(), + tf_one_hot_op.getOnValue(), tf_one_hot_op.getOffValue(), depth, axis); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -2268,9 +2326,9 @@ LogicalResult ConvertTFBatchMatMulV2Op::matchAndRewrite( auto tf_batch_matmul_op = cast(op); RankedTensorType x_type = - tf_batch_matmul_op.x().getType().dyn_cast(); + tf_batch_matmul_op.getX().getType().dyn_cast(); RankedTensorType y_type = - tf_batch_matmul_op.y().getType().dyn_cast(); + tf_batch_matmul_op.getY().getType().dyn_cast(); RankedTensorType output_type = tf_batch_matmul_op.getResult().getType().dyn_cast(); @@ -2290,8 +2348,8 @@ LogicalResult ConvertTFBatchMatMulV2Op::matchAndRewrite( // Rank 3 batch matmul can be directly mapped to tosa.matmul trivially. if (x_type.getRank() == 3) { CreateReplaceOpAndInfer(rewriter, op, output_type, - tf_batch_matmul_op.x(), - tf_batch_matmul_op.y()); + tf_batch_matmul_op.getX(), + tf_batch_matmul_op.getY()); } else { // 1. Reshape x from: (similar for y) // [a0, a1, ... an, H, C] to [N, H, C]. 
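// Rough sketch of the rank reduction described above, with hypothetical
// shapes:
//   x : [2, 5, 4, 7]  (a0 = 2, a1 = 5, H = 4, C = 7) -> reshape to [10, 4, 7]
//   y : [2, 5, 7, 3]                                  -> reshape to [10, 7, 3]
// The rank-3 matmul then produces [10, 4, 3], which is reshaped back to the
// batched output shape [2, 5, 4, 3].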
@@ -2321,12 +2379,12 @@ LogicalResult ConvertTFBatchMatMulV2Op::matchAndRewrite( rank3_output_shape, output_type.getElementType()); auto op1_reshape_x = CreateOpAndInfer( - rewriter, op->getLoc(), rank3_x_type, tf_batch_matmul_op.x(), - rewriter.getI64ArrayAttr(rank3_x_shape)); + rewriter, op->getLoc(), rank3_x_type, tf_batch_matmul_op.getX(), + rewriter.getDenseI64ArrayAttr(rank3_x_shape)); auto op2_reshape_y = CreateOpAndInfer( - rewriter, op->getLoc(), rank3_y_type, tf_batch_matmul_op.y(), - rewriter.getI64ArrayAttr(rank3_y_shape)); + rewriter, op->getLoc(), rank3_y_type, tf_batch_matmul_op.getY(), + rewriter.getDenseI64ArrayAttr(rank3_y_shape)); auto op3_matmul_op1_op2 = CreateOpAndInfer( rewriter, op->getLoc(), rank3_output_type, op1_reshape_x.getResult(), @@ -2334,7 +2392,7 @@ LogicalResult ConvertTFBatchMatMulV2Op::matchAndRewrite( CreateReplaceOpAndInfer( rewriter, op, output_type, op3_matmul_op1_op2.getResult(), - rewriter.getI64ArrayAttr(output_type.getShape())); + rewriter.getDenseI64ArrayAttr(output_type.getShape())); } return success(); } @@ -2423,6 +2481,8 @@ void populateLegalizeTFPatterns(MLIRContext* ctx, RewritePatternSet& patterns) { patterns.add(ctx); patterns.add(ctx); patterns.add(ctx); + patterns.add(ctx); + patterns.add(ctx); patterns.add(ctx); patterns.add(ctx); patterns.add(ctx); diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc index ca12ed601a7..8c7f0d7ecdc 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc @@ -29,15 +29,16 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project -#include "mlir/Dialect/Traits.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" #include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h" #include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h" #include "tensorflow/compiler/mlir/tosa/transforms/passes.h" @@ -91,6 +92,7 @@ struct ConvertConstantOp : public RewritePattern { DECL_CONVERT_OP(Gelu); DECL_CONVERT_OP(Relu); DECL_CONVERT_OP(Relu1); +DECL_CONVERT_OP(Relu0To1); DECL_CONVERT_OP(Relu6); DECL_CONVERT_OP(Equal); DECL_CONVERT_OP(NotEqual); @@ -157,8 +159,10 @@ DECL_CONVERT_OP(SpaceToBatchNd); DECL_CONVERT_OP(BatchToSpaceNd); DECL_CONVERT_OP(SpaceToDepth); DECL_CONVERT_OP(DepthToSpace); +DECL_CONVERT_OP(Bucketize); DECL_CONVERT_OP(Sin); DECL_CONVERT_OP(Cos); +DECL_CONVERT_OP(Atan2); DECL_CONVERT_OP(Logistic); DECL_CONVERT_OP(Tanh); DECL_CONVERT_OP(PRelu); @@ -177,6 +181,7 @@ DECL_CONVERT_OP(SparseToDense); DECL_CONVERT_OP(OneHot); DECL_CONVERT_OP(ArgMax); DECL_CONVERT_OP(FakeQuant); +DECL_CONVERT_OP(While); #undef DECL_CONVERT_OP @@ -387,6 +392,56 @@ LogicalResult 
ConvertTFLRelu1Op::matchAndRewrite( return success(); } +LogicalResult ConvertTFLRelu0To1Op::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + auto tfl_relu0to1_op = cast(op); + + ShapedType input_type = tfl_relu0to1_op.getX().getType().cast(); + ShapedType output_type = + tfl_relu0to1_op.getResult().getType().cast(); + + bool input_is_qtype = + input_type.getElementType().isa(); + bool output_is_qtype = + output_type.getElementType().isa(); + + if (input_is_qtype != output_is_qtype) { + return rewriter.notifyMatchFailure( + op, + "input/output tensor should be all quantized or all floating-point"); + } + + int64_t clamp_min = 0; + int64_t clamp_max = 1; + Value clamp_in = tfl_relu0to1_op.getX(); + + if (output_is_qtype && input_is_qtype) { + UniformQuantizedType input_qtype = + input_type.getElementType().cast(); + UniformQuantizedType output_qtype = + output_type.getElementType().cast(); + + clamp_min = output_qtype.getZeroPoint(); + + clamp_max = std::llround(1.0f / output_qtype.getScale()) + + output_qtype.getZeroPoint(); + + clamp_in = + buildRescale(rewriter, op, output_type, tfl_relu0to1_op.getX(), + input_qtype.getScale() / output_qtype.getScale(), + input_qtype.getZeroPoint(), output_qtype.getZeroPoint(), + /*double_round=*/false, /*scale32=*/true); + } + + CreateReplaceOpAndInfer(rewriter, op, output_type, clamp_in, + rewriter.getI64IntegerAttr(clamp_min), + rewriter.getI64IntegerAttr(clamp_max), + rewriter.getF32FloatAttr(0.0f), + rewriter.getF32FloatAttr(1.0f)); + + return success(); +} + LogicalResult ConvertTFLRelu6Op::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tfl_relu6_op = cast(op); @@ -671,7 +726,7 @@ static LogicalResult matchAndRewriteAddSub(Operation* op, if (!fused_activation_val) return failure(); - rewriter.replaceOp(op, {fused_activation_val.getValue()}); + rewriter.replaceOp(op, {fused_activation_val.value()}); return success(); } @@ -705,15 +760,15 @@ LogicalResult ConvertTFLMulOp::matchAndRewrite( if (fused_activation_fn) { llvm::Optional fused_activation_val = convertFusedActivation( - rewriter, op, result.getValue(), fused_activation_fn); + rewriter, op, result.value(), fused_activation_fn); if (!fused_activation_val) return failure(); - rewriter.replaceOp(op, {fused_activation_val.getValue()}); + rewriter.replaceOp(op, {fused_activation_val.value()}); return success(); } - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -727,7 +782,7 @@ LogicalResult ConvertTFLSquareOp::matchAndRewrite( if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -741,7 +796,7 @@ LogicalResult ConvertTFLSquaredDifferenceOp::matchAndRewrite( if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -760,7 +815,7 @@ LogicalResult ConvertTFLRoundOp::matchAndRewrite( if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } else { @@ -805,7 +860,7 @@ LogicalResult ConvertTFLDivOp::matchAndRewrite( if (!fused_activation_val) return failure(); - rewriter.replaceOp(op, {fused_activation_val.getValue()}); + rewriter.replaceOp(op, {fused_activation_val.value()}); return success(); } @@ -933,7 +988,7 @@ LogicalResult ConvertTFLFloorDivOp::matchAndRewrite( if (!result) return failure(); - rewriter.replaceOp(op, 
{result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -948,7 +1003,7 @@ LogicalResult ConvertTFLFloorModOp::matchAndRewrite( if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -991,13 +1046,13 @@ LogicalResult ConvertTFLAveragePool2DOp::matchAndRewrite( // Kernels and strides are dimensionally ordered SmallVector i64array({1, 1, 1, 1}); - ArrayAttr kernel_size; - ArrayAttr stride; - ArrayAttr pad; + DenseI64ArrayAttr kernel_size; + DenseI64ArrayAttr stride; + DenseI64ArrayAttr pad; { int64_t kernel_h = tfl_avgpool_op.getFilterHeight(); int64_t kernel_w = tfl_avgpool_op.getFilterWidth(); - kernel_size = rewriter.getI64ArrayAttr({kernel_h, kernel_w}); + kernel_size = rewriter.getDenseI64ArrayAttr({kernel_h, kernel_w}); // i64array is formatted as NHWC now i64array[1] = kernel_h; i64array[2] = kernel_w; @@ -1005,7 +1060,7 @@ LogicalResult ConvertTFLAveragePool2DOp::matchAndRewrite( { int64_t stride_h = tfl_avgpool_op.getStrideH(); int64_t stride_w = tfl_avgpool_op.getStrideW(); - stride = rewriter.getI64ArrayAttr({stride_h, stride_w}); + stride = rewriter.getDenseI64ArrayAttr({stride_h, stride_w}); } { tensorflow::Padding tf_pad; @@ -1013,10 +1068,10 @@ LogicalResult ConvertTFLAveragePool2DOp::matchAndRewrite( return failure(); // Pooling has no non-unit dilation - ArrayAttr dilation = rewriter.getI64ArrayAttr({1, 1}); + DenseI64ArrayAttr dilation = rewriter.getDenseI64ArrayAttr({1, 1}); RankedTensorType filter_type = RankedTensorType::get( - llvm::makeArrayRef(i64array), rewriter.getIntegerType(64)); + llvm::ArrayRef(i64array), rewriter.getIntegerType(64)); // TFLite doesn't support explicit padding if (!getPaddingValuesFromPadType( @@ -1068,13 +1123,13 @@ LogicalResult ConvertTFLMaxPool2DOp::matchAndRewrite( // Kernels and strides are dimensionally ordered SmallVector i64array({1, 1, 1, 1}); - ArrayAttr kernel_size; - ArrayAttr stride; - ArrayAttr pad; + DenseI64ArrayAttr kernel_size; + DenseI64ArrayAttr stride; + DenseI64ArrayAttr pad; { int64_t kernel_h = tfl_maxpool_op.getFilterHeight(); int64_t kernel_w = tfl_maxpool_op.getFilterWidth(); - kernel_size = rewriter.getI64ArrayAttr({kernel_h, kernel_w}); + kernel_size = rewriter.getDenseI64ArrayAttr({kernel_h, kernel_w}); // i64array is formatted as NHWC now i64array[1] = kernel_h; i64array[2] = kernel_w; @@ -1082,7 +1137,7 @@ LogicalResult ConvertTFLMaxPool2DOp::matchAndRewrite( { int64_t stride_h = tfl_maxpool_op.getStrideH(); int64_t stride_w = tfl_maxpool_op.getStrideW(); - stride = rewriter.getI64ArrayAttr({stride_h, stride_w}); + stride = rewriter.getDenseI64ArrayAttr({stride_h, stride_w}); } { tensorflow::Padding tf_pad; @@ -1090,7 +1145,7 @@ LogicalResult ConvertTFLMaxPool2DOp::matchAndRewrite( return failure(); // Pooling has no non-unit dilation - ArrayAttr dilation = rewriter.getI64ArrayAttr({1, 1}); + DenseI64ArrayAttr dilation = rewriter.getDenseI64ArrayAttr({1, 1}); RankedTensorType filter_type = RankedTensorType::get(i64array, rewriter.getIntegerType(64)); @@ -1140,18 +1195,18 @@ LogicalResult ConvertTFLConv2DOp::matchAndRewrite( "be all quantized or all floating-point"); } - ArrayAttr pad; - ArrayAttr stride; - ArrayAttr dilation; + DenseI64ArrayAttr pad; + DenseI64ArrayAttr stride; + DenseI64ArrayAttr dilation; { int64_t stride_h = tfl_conv2d_op.getStrideH(); int64_t stride_w = tfl_conv2d_op.getStrideW(); - stride = rewriter.getI64ArrayAttr({stride_h, stride_w}); + stride = 
rewriter.getDenseI64ArrayAttr({stride_h, stride_w}); } { int64_t dilation_h = tfl_conv2d_op.getDilationHFactor(); int64_t dilation_w = tfl_conv2d_op.getDilationWFactor(); - dilation = rewriter.getI64ArrayAttr({dilation_h, dilation_w}); + dilation = rewriter.getDenseI64ArrayAttr({dilation_h, dilation_w}); } { tensorflow::Padding tf_pad; @@ -1195,7 +1250,7 @@ LogicalResult ConvertTFLConv2DOp::matchAndRewrite( if (!fused_activation_val) return failure(); - rewriter.replaceOp(op, {fused_activation_val.getValue()}); + rewriter.replaceOp(op, {fused_activation_val.value()}); return success(); } @@ -1263,9 +1318,9 @@ LogicalResult ConvertTFLConv3DOp::matchAndRewrite( Value conv3d_output = input_is_qtype - ? buildRescaleOpConvOutput(rewriter, op, a1_conv3d_op.getValue(), + ? buildRescaleOpConvOutput(rewriter, op, a1_conv3d_op.value(), input_type, filter_type, output_type) - : a1_conv3d_op.getValue(); + : a1_conv3d_op.value(); if (auto fused_activation_fn = tfl_conv3d_op.getFusedActivationFunctionAttr()) { @@ -1313,13 +1368,13 @@ LogicalResult ConvertTFLTransposeConvOp::matchAndRewrite( "be all quantized or all floating-point"); } - ArrayAttr stride; - ArrayAttr outpad; - ArrayAttr output_shape; + DenseI64ArrayAttr stride; + DenseI64ArrayAttr outpad; + DenseI64ArrayAttr output_shape; { int64_t stride_h = tfl_conv_op.getStrideH(); int64_t stride_w = tfl_conv_op.getStrideW(); - stride = rewriter.getI64ArrayAttr({stride_h, stride_w}); + stride = rewriter.getDenseI64ArrayAttr({stride_h, stride_w}); } { @@ -1343,10 +1398,10 @@ LogicalResult ConvertTFLTransposeConvOp::matchAndRewrite( for (int i = 0; i < output_shape_elems.getNumElements(); i++) shape_vec.push_back( output_shape_elems.getValues()[i].getSExtValue()); - output_shape = rewriter.getI64ArrayAttr(shape_vec); + output_shape = rewriter.getDenseI64ArrayAttr(shape_vec); } else if (output_type.hasRank()) { // Use output tensor's shape otherwise - output_shape = rewriter.getI64ArrayAttr(output_type.getShape()); + output_shape = rewriter.getDenseI64ArrayAttr(output_type.getShape()); } else { // TODO(suderman): Figure out rankless shape propagation. 
return failure(); @@ -1390,7 +1445,7 @@ LogicalResult ConvertTFLTransposeConvOp::matchAndRewrite( auto a1_conv2d_op = CreateOpAndInfer( rewriter, op->getLoc(), output_type.clone(bias_ety), - tfl_conv_op.getInput(), tfl_conv_op.getWeights(), zero_bias.getValue(), + tfl_conv_op.getInput(), tfl_conv_op.getWeights(), zero_bias.value(), outpad, stride, output_shape); Value conv2d_output; @@ -1402,6 +1457,18 @@ LogicalResult ConvertTFLTransposeConvOp::matchAndRewrite( conv2d_output = a1_conv2d_op.getResult(); } + auto fused_activation_fn = tfl_conv_op.getFusedActivationFunctionAttr(); + + if (fused_activation_fn) { + llvm::Optional fused_activation_val = convertFusedActivation( + rewriter, op, conv2d_output, fused_activation_fn); + + if (!fused_activation_val) return failure(); + + rewriter.replaceOp(op, {fused_activation_val.value()}); + return success(); + } + rewriter.replaceOp(op, {conv2d_output}); return success(); @@ -1452,20 +1519,20 @@ LogicalResult ConvertTFLDepthwiseConv2DOp::matchAndRewrite( // a3_transpose_conv2d = tosa.transpose_conv2d(input, a2_reshape, padding, // stride, dilation) - ArrayAttr pad; - ArrayAttr stride; - ArrayAttr dilation; + DenseI64ArrayAttr pad; + DenseI64ArrayAttr stride; + DenseI64ArrayAttr dilation; auto depth_multiplier = tfl_conv2d_op.getDepthMultiplierAttr(); { int64_t stride_h = tfl_conv2d_op.getStrideH(); int64_t stride_w = tfl_conv2d_op.getStrideW(); - stride = rewriter.getI64ArrayAttr({stride_h, stride_w}); + stride = rewriter.getDenseI64ArrayAttr({stride_h, stride_w}); } { int64_t dilation_h = tfl_conv2d_op.getDilationHFactor(); int64_t dilation_w = tfl_conv2d_op.getDilationWFactor(); - dilation = rewriter.getI64ArrayAttr({dilation_h, dilation_w}); + dilation = rewriter.getDenseI64ArrayAttr({dilation_h, dilation_w}); } { tensorflow::Padding tf_pad; @@ -1501,14 +1568,14 @@ LogicalResult ConvertTFLDepthwiseConv2DOp::matchAndRewrite( rewriter, op->getLoc(), RankedTensorType::get(ArrayRef(a1_transpose_dims), filter_type.getElementType()), - tfl_conv2d_op.getFilter(), a1_filter_transpose_perms.getValue()); + tfl_conv2d_op.getFilter(), a1_filter_transpose_perms.value()); auto a2_filter_reshape_op = CreateOpAndInfer( rewriter, op->getLoc(), RankedTensorType::get(ArrayRef(a2_reshape_dims), filter_type.getElementType()), a1_filter_transpose_op.getResult(), - rewriter.getI64ArrayAttr(a2_reshape_dims)); + rewriter.getDenseI64ArrayAttr(a2_reshape_dims)); Value unquantized_bias = tfl_conv2d_op.getBias(); Type bias_ety = @@ -1538,7 +1605,7 @@ LogicalResult ConvertTFLDepthwiseConv2DOp::matchAndRewrite( if (!fused_activation_val) return failure(); - rewriter.replaceOp(op, {fused_activation_val.getValue()}); + rewriter.replaceOp(op, {fused_activation_val.value()}); return success(); } @@ -1590,11 +1657,11 @@ LogicalResult ConvertTFLBatchMatMulOp::matchAndRewrite( lhs = CreateOpAndInfer( rewriter, op->getLoc(), UnrankedTensorType::get(lhs_ty.getElementType()), lhs, - rewriter.getI64ArrayAttr(new_lhs_shape)); + rewriter.getDenseI64ArrayAttr(new_lhs_shape)); rhs = CreateOpAndInfer( rewriter, op->getLoc(), UnrankedTensorType::get(rhs_ty.getElementType()), rhs, - rewriter.getI64ArrayAttr(new_rhs_shape)); + rewriter.getDenseI64ArrayAttr(new_rhs_shape)); lhs_ty = lhs.getType().cast(); rhs_ty = rhs.getType().cast(); } @@ -1602,7 +1669,7 @@ LogicalResult ConvertTFLBatchMatMulOp::matchAndRewrite( if (transpose_lhs) { Value perms = getConstTensor(rewriter, op, /*vec=*/{0, 2, 1}, /*shape=*/{3}) - .getValue(); + .value(); Type output_type = 
UnrankedTensorType::get(lhs_ty.getElementType()); lhs = CreateOpAndInfer(rewriter, op->getLoc(), output_type, lhs, perms) @@ -1612,7 +1679,7 @@ LogicalResult ConvertTFLBatchMatMulOp::matchAndRewrite( if (transpose_rhs) { Value perms = getConstTensor(rewriter, op, /*vec=*/{0, 2, 1}, /*shape=*/{3}) - .getValue(); + .value(); Type output_type = UnrankedTensorType::get(rhs_ty.getElementType()); rhs = CreateOpAndInfer(rewriter, op->getLoc(), output_type, rhs, perms) @@ -1640,7 +1707,7 @@ LogicalResult ConvertTFLBatchMatMulOp::matchAndRewrite( matmul = CreateOpAndInfer( rewriter, op->getLoc(), UnrankedTensorType::get(matmul_ty.getElementType()), matmul, - rewriter.getI64ArrayAttr(new_shape)); + rewriter.getDenseI64ArrayAttr(new_shape)); } if (lhs_is_qtype) { @@ -1704,7 +1771,7 @@ LogicalResult ConvertTFLFullyConnectedOp::matchAndRewrite( RankedTensorType::get(shape_vals, input_type.getElementType()); auto reshape_op = CreateOpAndInfer( rewriter, op->getLoc(), reshape_type, tfl_fc_op.getInput(), - rewriter.getI64ArrayAttr(shape_vals)); + rewriter.getDenseI64ArrayAttr(shape_vals)); input_val = reshape_op.getResult(); } @@ -1727,7 +1794,7 @@ LogicalResult ConvertTFLFullyConnectedOp::matchAndRewrite( new_bias_type = RankedTensorType::get(bias_shape, input_type.getElementType()); bias_attr = - DenseElementsAttr::get(new_bias_type, llvm::makeArrayRef(bias_arr)); + DenseElementsAttr::get(new_bias_type, llvm::ArrayRef(bias_arr)); } else { SmallVector bias_arr(bias_shape[0]); @@ -1745,7 +1812,7 @@ LogicalResult ConvertTFLFullyConnectedOp::matchAndRewrite( : rewriter.getI32Type(); new_bias_type = RankedTensorType::get(bias_shape, new_bias_ety); bias_attr = - DenseElementsAttr::get(new_bias_type, llvm::makeArrayRef(bias_arr)); + DenseElementsAttr::get(new_bias_type, llvm::ArrayRef(bias_arr)); } auto bias_op = CreateOpAndInfer(rewriter, op->getLoc(), new_bias_type, bias_attr); @@ -1778,7 +1845,7 @@ LogicalResult ConvertTFLFullyConnectedOp::matchAndRewrite( fc_output = CreateOpAndInfer( rewriter, op->getLoc(), UnrankedTensorType::get(fc_type.getElementType()), fc_output, - rewriter.getI64ArrayAttr(output_type.getShape())); + rewriter.getDenseI64ArrayAttr(output_type.getShape())); } auto fused_activation_fn = tfl_fc_op.getFusedActivationFunctionAttr(); @@ -1789,7 +1856,7 @@ LogicalResult ConvertTFLFullyConnectedOp::matchAndRewrite( if (!fused_activation_val) return failure(); - rewriter.replaceOp(op, {fused_activation_val.getValue()}); + rewriter.replaceOp(op, {fused_activation_val.value()}); return success(); } @@ -1801,6 +1868,7 @@ LogicalResult ConvertTFLFullyConnectedOp::matchAndRewrite( LogicalResult ConvertTFLConcatenationOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tfl_concat_op = cast(op); + auto result_type = tfl_concat_op.getResult().getType().dyn_cast(); SmallVector values(tfl_concat_op.getValues()); @@ -1815,11 +1883,11 @@ LogicalResult ConvertTFLConcatenationOp::matchAndRewrite( int32_t axis = axis_attr.getInt(); llvm::Optional result = - convertConcatV2Op(rewriter, op, tfl_concat_op.getResult(), values, axis); + convertConcatV2Op(rewriter, op, result_type, values, axis); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -1831,20 +1899,20 @@ LogicalResult ConvertTFLReshapeOp::matchAndRewrite( ShapedType shape_type = shape.getType().dyn_cast(); ShapedType output_type = tfl_reshape_op.getType().dyn_cast(); - int64_t rank = ShapedType::kDynamicSize; + int64_t rank = 
ShapedType::kDynamic; if (output_type.hasRank()) rank = output_type.getRank(); // Check the inferred rank from the shape tensor matches the output. if (shape_type.hasRank() && !shape_type.isDynamicDim(0)) { int64_t dim = shape_type.getDimSize(0); - if (rank != ShapedType::kDynamicSize && rank != dim) { + if (rank != ShapedType::kDynamic && rank != dim) { return rewriter.notifyMatchFailure(op, "static dim mismatch on tfl.reshape"); } rank = dim; } - if (rank == ShapedType::kDynamicSize) { + if (rank == ShapedType::kDynamic) { return rewriter.notifyMatchFailure(op, "unknown rank for output shape"); } @@ -1856,10 +1924,10 @@ LogicalResult ConvertTFLReshapeOp::matchAndRewrite( auto e_ty = shape_ty.getElementType(); Value dim = rewriter.createOrFold( op->getLoc(), RankedTensorType::get({1}, e_ty), shape, - rewriter.getI64ArrayAttr({i}), rewriter.getI64ArrayAttr({1})); + rewriter.getDenseI64ArrayAttr({i}), rewriter.getDenseI64ArrayAttr({1})); dim = rewriter.createOrFold( op->getLoc(), RankedTensorType::get({}, e_ty), dim, - rewriter.getI64ArrayAttr({})); + rewriter.getDenseI64ArrayAttr({})); shape_vals.push_back(dim); } @@ -1870,7 +1938,7 @@ LogicalResult ConvertTFLReshapeOp::matchAndRewrite( if (!reshape.has_value()) return failure(); - rewriter.replaceOp(op, {reshape.getValue()}); + rewriter.replaceOp(op, {reshape.value()}); return success(); } @@ -1886,7 +1954,7 @@ LogicalResult ConvertTFLRankOp::matchAndRewrite( RankedTensorType rank_type = RankedTensorType::get({1}, rewriter.getIntegerType(32)); - auto rank_attr = DenseElementsAttr::get(rank_type, {rank}); + auto rank_attr = DenseI32ArrayAttr::get(rewriter.getContext(), {rank}); auto rank_const = CreateOpAndInfer(rewriter, op->getLoc(), rank_type, rank_attr); @@ -1919,7 +1987,7 @@ LogicalResult ConvertTFLShapeOp::matchAndRewrite( RankedTensorType shape_type = RankedTensorType::get( {static_cast(shape_arr.size())}, rewriter.getIntegerType(32)); auto shape_attr = - DenseElementsAttr::get(shape_type, llvm::makeArrayRef(shape_arr)); + DenseI32ArrayAttr::get(rewriter.getContext(), llvm::ArrayRef(shape_arr)); auto shape_const = CreateOpAndInfer(rewriter, op->getLoc(), shape_type, shape_attr); @@ -1938,7 +2006,7 @@ LogicalResult ConvertTFLExpandDimsOp::matchAndRewrite( if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -1960,7 +2028,7 @@ LogicalResult ConvertTFLSqueezeOp::matchAndRewrite( if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -1990,17 +2058,19 @@ LogicalResult ConvertTFLFillOp::matchAndRewrite( RankedTensorType fill_type = RankedTensorType::get( ArrayRef(dims_vals), value_elem.getType().getElementType()); - DenseElementsAttr fill_attr; + DenseArrayAttr fill_attr; // Convert to a compatible zero type. 
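// The scalar fill value is splatted into a dense constant that covers the
// whole static output shape; e.g. (hypothetically) tfl.fill with dims [2, 3]
// and value 1.5 becomes a single constant holding six 1.5f elements.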
if (value_elem.getType().getElementType().isa()) { SmallVector fill_arr( total_size, value_elem.getValues()[0].convertToFloat()); - fill_attr = DenseElementsAttr::get(fill_type, llvm::makeArrayRef(fill_arr)); + fill_attr = + DenseF32ArrayAttr::get(rewriter.getContext(), llvm::ArrayRef(fill_arr)); } else { SmallVector fill_arr( total_size, value_elem.getValues()[0].getLimitedValue()); - fill_attr = DenseElementsAttr::get(fill_type, llvm::makeArrayRef(fill_arr)); + fill_attr = + DenseI32ArrayAttr::get(rewriter.getContext(), llvm::ArrayRef(fill_arr)); } auto fill_const_op = CreateOpAndInfer(rewriter, op->getLoc(), fill_type, fill_attr); @@ -2026,7 +2096,7 @@ LogicalResult ConvertTFLReduceAnyOp::matchAndRewrite( if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -2048,7 +2118,7 @@ LogicalResult ConvertTFLReduceMaxOp::matchAndRewrite( if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -2070,7 +2140,7 @@ LogicalResult ConvertTFLReduceMinOp::matchAndRewrite( if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -2092,7 +2162,7 @@ LogicalResult ConvertTFLReduceProdOp::matchAndRewrite( if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -2114,7 +2184,7 @@ LogicalResult ConvertTFLMeanOp::matchAndRewrite( if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -2136,7 +2206,7 @@ LogicalResult ConvertTFLSumOp::matchAndRewrite( if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -2150,7 +2220,7 @@ LogicalResult ConvertTFLEluOp::matchAndRewrite( if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -2165,7 +2235,7 @@ LogicalResult ConvertTFLSoftmaxOp::matchAndRewrite( if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -2201,14 +2271,14 @@ LogicalResult ConvertTFLL2NormalizationOp::matchAndRewrite( rewriter.getI64IntegerAttr(input_ty.getRank() - 1)); SmallVector min(1, sqrt(std::numeric_limits::min())); - Value min_val = getConstTensor(rewriter, op, min, {}).getValue(); + Value min_val = getConstTensor(rewriter, op, min, {}).value(); auto max = CreateOpAndInfer(rewriter, loc, result_ty, sum, min_val); auto rsqrt = CreateOpAndInfer(rewriter, loc, result_ty, max) .getResult(); - auto result = CreateOpAndInfer(rewriter, loc, result_ty, rsqrt, - input, shift) - .getResult(); + Value result = CreateOpAndInfer(rewriter, loc, result_ty, + rsqrt, input, shift) + .getResult(); auto fused_activation_fn = tfl_l2norm_op.getFusedActivationFunctionAttr(); @@ -2216,7 +2286,7 @@ LogicalResult ConvertTFLL2NormalizationOp::matchAndRewrite( llvm::Optional fused_activation_val = convertFusedActivation(rewriter, op, result, fused_activation_fn); if (!fused_activation_val) return failure(); - result = fused_activation_val.getValue(); + result = fused_activation_val.value(); } rewriter.replaceOp(op, result); @@ -2236,7 +2306,7 @@ LogicalResult ConvertTFLLogSoftmaxOp::matchAndRewrite( if (!result) 
return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -2265,8 +2335,8 @@ LogicalResult ConvertTFLSliceOp::matchAndRewrite( for (int i = 0; i < size_elems.getNumElements(); i++) size_vals.push_back(size_elems.getValues()[i].getSExtValue()); - ArrayAttr begin = rewriter.getI64ArrayAttr(begin_vals); - ArrayAttr size = rewriter.getI64ArrayAttr(size_vals); + DenseI64ArrayAttr begin = rewriter.getDenseI64ArrayAttr(begin_vals); + DenseI64ArrayAttr size = rewriter.getDenseI64ArrayAttr(size_vals); CreateReplaceOpAndInfer(rewriter, op, output_type, tfl_slice_op.getInput(), begin, size); @@ -2290,7 +2360,8 @@ LogicalResult ConvertTFLTileOp::matchAndRewrite( multiples_vals.push_back( multiples_elems.getValues()[i].getSExtValue()); - ArrayAttr multiples_attr = rewriter.getI64ArrayAttr(multiples_vals); + DenseI64ArrayAttr multiples_attr = + rewriter.getDenseI64ArrayAttr(multiples_vals); CreateReplaceOpAndInfer(rewriter, op, output_type, tfl_tile_op.getInput(), multiples_attr); @@ -2329,7 +2400,7 @@ LogicalResult ConvertTFLPackOp::matchAndRewrite( if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -2351,7 +2422,7 @@ LogicalResult ConvertTFLUnpackOp::matchAndRewrite( if (!results) return failure(); - rewriter.replaceOp(op, results.getValue()); + rewriter.replaceOp(op, results.value()); return success(); } @@ -2386,7 +2457,7 @@ LogicalResult ConvertTFLSplitOp::matchAndRewrite( if (!results) return failure(); - rewriter.replaceOp(op, results.getValue()); + rewriter.replaceOp(op, results.value()); return success(); } @@ -2424,7 +2495,7 @@ LogicalResult ConvertTFLSplitVOp::matchAndRewrite( if (!results) return failure(); - rewriter.replaceOp(op, results.getValue()); + rewriter.replaceOp(op, results.value()); return success(); } @@ -2473,7 +2544,7 @@ LogicalResult ConvertTFLMirrorPadOp::matchAndRewrite( rewriter, op, output_type, tfl_mirrorpad_op.getInput(), tfl_mirrorpad_op.getPad(), mode); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -2508,7 +2579,7 @@ LogicalResult ConvertTFLResizeBilinearOp::matchAndRewrite( if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -2530,7 +2601,7 @@ LogicalResult ConvertTFLResizeNearestNeighborOp::matchAndRewrite( if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -2544,7 +2615,7 @@ LogicalResult ConvertTFLSelectOp::matchAndRewrite( tfl_sel_op.getX(), tfl_sel_op.getY()); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -2558,7 +2629,7 @@ LogicalResult ConvertTFLSelectV2Op::matchAndRewrite( tfl_sel_op.getX(), tfl_sel_op.getY()); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -2572,7 +2643,7 @@ LogicalResult ConvertTFLSpaceToBatchNdOp::matchAndRewrite( if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -2587,7 +2658,7 @@ LogicalResult ConvertTFLBatchToSpaceNdOp::matchAndRewrite( if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + 
rewriter.replaceOp(op, {result.value()}); return success(); } @@ -2603,7 +2674,7 @@ LogicalResult ConvertTFLSpaceToDepthOp::matchAndRewrite( if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -2619,7 +2690,69 @@ LogicalResult ConvertTFLDepthToSpaceOp::matchAndRewrite( if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); + + return success(); +} + +LogicalResult ConvertTFLBucketizeOp::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + auto tfl_bucketize_op = cast(op); + Location loc = op->getLoc(); + + Value input = tfl_bucketize_op.getInput(); + auto boundaries_attr = tfl_bucketize_op.getBoundaries(); + RankedTensorType input_type = input.getType().dyn_cast(); + if (!input_type) { + return rewriter.notifyMatchFailure(op, "input is not a ranked tensor"); + } + + // The lowering is done by broadcasting the input and boundaries together, and + // using GE comparison for each input against each boundary. Adding the + // results of the comparison for each input generates the bucket it belongs + // to, as the boundaries are sorted. + ShapedType output_type = + tfl_bucketize_op.getResult().getType().dyn_cast(); + + auto input_shape = input_type.getShape(); + + SmallVector boundaries; + for (auto& boundary : boundaries_attr) { + boundaries.emplace_back(boundary.dyn_cast().getValue()); + } + int64_t boundaries_size = boundaries.size(); + + // Add a dim at the end of input shape for broadcasting with the boundaries. + SmallVector broadcast_shape(input_shape.begin(), input_shape.end()); + broadcast_shape.push_back(boundaries_size); + SmallVector new_input_shape(input_shape.begin(), input_shape.end()); + new_input_shape.push_back(1); + + auto boundaries_type = + RankedTensorType::get({boundaries_size}, rewriter.getF32Type()); + + auto boundaries_op = CreateOpAndInfer( + rewriter, loc, boundaries_type, + DenseElementsAttr::get(boundaries_type, boundaries)); + + auto reshaped_input = CreateOpAndInfer( + rewriter, loc, input_type.clone(new_input_shape), input, + rewriter.getDenseI64ArrayAttr(new_input_shape)); + + auto ge = CreateOpAndInfer( + rewriter, loc, UnrankedTensorType::get(rewriter.getIntegerType(1)), + reshaped_input, boundaries_op); + + auto casted = CreateOpAndInfer( + rewriter, loc, UnrankedTensorType::get(rewriter.getIntegerType(32)), ge); + + auto sum = CreateOpAndInfer( + rewriter, loc, output_type, casted, + rewriter.getI64IntegerAttr(input_type.getRank())); + + CreateReplaceOpAndInfer( + rewriter, op, output_type, sum, + rewriter.getDenseI64ArrayAttr(output_type.getShape())); return success(); } @@ -2638,7 +2771,7 @@ LogicalResult ConvertTFLStridedSliceOp::matchAndRewrite( tfl_ss_op.getShrinkAxisMaskAttr().getInt()); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -2652,7 +2785,7 @@ LogicalResult ConvertTFLZerosLikeOp::matchAndRewrite( if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -2738,133 +2871,168 @@ LogicalResult ConvertTFLHardSwishOp::matchAndRewrite( LogicalResult ConvertTFLSinOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tfl_sin_op = cast(op); - Location loc = op->getLoc(); - Value input = tfl_sin_op.getX(); - RankedTensorType input_ty = input.getType().dyn_cast(); - ShapedType 
output_ty = + auto input = tfl_sin_op.getX(); + ShapedType output_type = tfl_sin_op.getResult().getType().dyn_cast(); - Type input_ety = input_ty.getElementType(); - Type output_ety = output_ty.getElementType(); + llvm::Optional result = convertSinOp(rewriter, op, input, output_type); + if (!result) return failure(); - if (!input_ty || !output_ty) return failure(); + rewriter.replaceOp(op, {result.value()}); + return success(); +} - if (input_ety != output_ety) { - return rewriter.notifyMatchFailure(op, - "input/output element type must match"); - } +LogicalResult ConvertTFLCosOp::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + auto tfl_cos_op = cast(op); + Value input = tfl_cos_op.getX(); + RankedTensorType input_ty = input.getType().dyn_cast(); + ShapedType output_ty = + tfl_cos_op.getResult().getType().dyn_cast(); - bool input_is_fp = input_ty.getElementType().isF32(); - bool output_is_fp = output_ty.getElementType().isF32(); + if (!input_ty || !output_ty) return failure(); + + bool input_is_fp = input_ty.getElementType().isa(); + bool output_is_fp = output_ty.getElementType().isa(); if (!input_is_fp || !output_is_fp) { - return rewriter.notifyMatchFailure(op, "input/result must be fp32"); + return rewriter.notifyMatchFailure(op, "input/result must be fp"); } - // To perform a sin operation we remap the sin domain to be over a single - // period of the function, remapping to the domain of the table function. - // We then remap the range of the table function to map to the range of the - // sin operation. - - // 1. Normalize the period of the domain from [0, 2π) to [0, 1). + // Replace with the equivalent sin operation: + // cos(x) = sin(x + π / 2). auto fp_scalar_ty = RankedTensorType::get({}, rewriter.getF32Type()); - Value fp_scale = rewriter.create( - loc, fp_scalar_ty, - DenseElementsAttr::get(fp_scalar_ty, {static_cast(0.5 / M_PI)})); + auto pi_2 = rewriter.create( + op->getLoc(), fp_scalar_ty, + DenseElementsAttr::get(fp_scalar_ty, {static_cast(M_PI_2)})); + auto offset = rewriter.create(op->getLoc(), input_ty, input, pi_2); - // 2. Remap the periodic behavior of the domain to line up within [0, 1). - Value fp_scaled = CreateOpAndInfer( - rewriter, loc, input_ty, input, fp_scale, rewriter.getI32IntegerAttr(0)); - auto floored = - CreateOpAndInfer(rewriter, loc, input_ty, fp_scaled); - auto repeated = CreateOpAndInfer(rewriter, loc, input_ty, - fp_scaled, floored); + CreateReplaceOpAndInfer(rewriter, op, output_ty, offset); + return success(); +} - // 3. Scale and translate the normalized domain to the table domain. This - // includes a translating and scaling to [-int16_max, int16_max] and casting - // to an i16. 
- Value one = rewriter.create( - loc, fp_scalar_ty, DenseElementsAttr::get(fp_scalar_ty, {1.0f})); +LogicalResult ConvertTFLAtan2Op::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + auto tfl_atan2_op = cast(op); + Location loc = op->getLoc(); + Value input_y = tfl_atan2_op.getY(); + Value input_x = tfl_atan2_op.getX(); - Value two = rewriter.create( - loc, fp_scalar_ty, DenseElementsAttr::get(fp_scalar_ty, {2.0f})); - auto scale_up = CreateOpAndInfer( - rewriter, loc, input_ty, repeated, two, rewriter.getI32IntegerAttr(0)); - auto translate = - CreateOpAndInfer(rewriter, loc, input_ty, scale_up, one); + auto input_y_ty = dyn_cast(input_y.getType()); + auto input_x_ty = dyn_cast(input_x.getType()); + auto output_ty = dyn_cast(tfl_atan2_op.getResult().getType()); + + if (!input_y_ty || !input_x_ty || !output_ty) { + return rewriter.notifyMatchFailure(op, "ranked inputs/output required"); + } + if (!input_y_ty.getElementType().isF32()) { + return rewriter.notifyMatchFailure(op, "input must be fp32"); + } + + // To perform an atan2 operation we make use of an atan lookup table, + // then determine the correct quadrant for each output. To restrict the + // input domain of the lookup table from [-inf, inf] to [0, 1], we make + // use of two identities and undo the transformation later on: + // + // acrtan(z) = π/2 - arctan(1/z) (0) + // + // and + // + // arctan(-z) = -arctan(z) (1) + + Value pi = getTosaConstTensorSingleF32(rewriter, op, M_PI); + Value pi_2 = getTosaConstTensorSingleF32(rewriter, op, M_PI_2); + Value zero = getTosaConstTensorSingleF32(rewriter, op, 0.0); + Value one = getTosaConstTensorSingleF32(rewriter, op, 1.0); + Value two = getTosaConstTensorSingleF32(rewriter, op, 2.0); + + // 1. Restrict the input to the atan lookup from [-inf, inf] to [0, 1]. + // By utilizing (0) and (1) we compute: min(|x|, |y|) / max(|x|, |y|). + auto abs_y = + CreateOpAndInfer(rewriter, loc, input_y_ty, input_y); + auto abs_x = + CreateOpAndInfer(rewriter, loc, input_y_ty, input_x); + auto min_xy = CreateOpAndInfer(rewriter, loc, input_y_ty, + abs_y, abs_x); + auto max_xy = CreateOpAndInfer(rewriter, loc, input_y_ty, + abs_y, abs_x); + auto recip = + CreateOpAndInfer(rewriter, loc, input_y_ty, max_xy); + auto atan_input = CreateOpAndInfer( + rewriter, loc, input_y_ty, recip, min_xy, rewriter.getI32IntegerAttr(0)); + + // 2. Scale and translate the normalized domain to the table domain. This + // includes a translating and scaling to [-int16_max, int16_max] and casting + // to an i16 as it is the highest precision the table operation supports. + auto fp_scalar_ty = RankedTensorType::get({}, rewriter.getF32Type()); + auto scale_up = + CreateOpAndInfer(rewriter, loc, input_y_ty, atan_input, two, + rewriter.getI32IntegerAttr(0)); + auto translate = + CreateOpAndInfer(rewriter, loc, input_y_ty, scale_up, one); Value int_limit = rewriter.create( loc, fp_scalar_ty, DenseElementsAttr::get( fp_scalar_ty, {static_cast(std::numeric_limits::max())})); auto int_scaled = - CreateOpAndInfer(rewriter, loc, input_ty, translate, + CreateOpAndInfer(rewriter, loc, input_y_ty, translate, int_limit, rewriter.getI32IntegerAttr(0)); - auto int16_ty = input_ty.clone(rewriter.getIntegerType(16)); + auto int16_ty = input_y_ty.clone(rewriter.getIntegerType(16)); auto casted = CreateOpAndInfer(rewriter, loc, int16_ty, int_scaled); - // 4. Compute the lookup table using the range of [-255, 255] for sin. 
- llvm::SmallVector values; - const int num_values = 513; - values.resize(num_values, 0); - // First and last values should be 0; - for (int i = 1; i < num_values - 1; ++i) - values[i] = std::numeric_limits::max() * - sin(static_cast(i) * 2.0 * M_PI / (num_values - 1.0)); - - auto table_ty = - RankedTensorType::get({num_values}, rewriter.getIntegerType(16)); - Value table = rewriter.create( - loc, table_ty, - DenseElementsAttr::get(table_ty, llvm::makeArrayRef(values))); - - auto table_result_ty = input_ty.clone(rewriter.getIntegerType(32)); + // 3. Compute a lookup table using the domain of [0, 1] for atan. + // Note: the implementation of std::atan2 may be different on + // different machines, so may result in varying numerical results. + auto atan_func = [](double x) -> double { return std::atan(x); }; + Value table_const = getTosaConst16bitTable(rewriter, op, atan_func, 0.0, 1.0); auto table_result = CreateOpAndInfer( - rewriter, loc, table_result_ty, casted, table); + rewriter, loc, output_ty.clone(rewriter.getIntegerType(32)), casted, + table_const); - // 5. The range of table is a 23-bit two's compliment value. Normalize the + // 4. The range of table is a 23-bit two's complement value. Normalize the // range by casting to an fp32 and dividing by 2^22. auto table_result_fp = - CreateOpAndInfer(rewriter, loc, input_ty, table_result); + CreateOpAndInfer(rewriter, loc, output_ty, table_result); auto output_scale = rewriter.create( loc, fp_scalar_ty, DenseElementsAttr::get( fp_scalar_ty, {static_cast(1.0 / static_cast(1 << 22))})); - CreateReplaceOpAndInfer(rewriter, op, output_ty, table_result_fp, - output_scale, rewriter.getI32IntegerAttr(0)); - return success(); -} + auto table_output = CreateOpAndInfer( + rewriter, loc, output_ty, table_result_fp, output_scale, + rewriter.getI32IntegerAttr(0)); -LogicalResult ConvertTFLCosOp::matchAndRewrite( - Operation* op, PatternRewriter& rewriter) const { - auto tfl_cos_op = cast(op); - Value input = tfl_cos_op.getX(); - RankedTensorType input_ty = input.getType().dyn_cast(); - ShapedType output_ty = - tfl_cos_op.getResult().getType().dyn_cast(); + auto bool_ty = output_ty.clone(rewriter.getIntegerType(1)); - if (!input_ty || !output_ty) return failure(); + // 5. If (0) was applied to the atan input, apply π/2 - table_output. + auto sub_pi_2 = CreateOpAndInfer(rewriter, loc, output_ty, pi_2, + table_output); + auto condition = + CreateOpAndInfer(rewriter, loc, bool_ty, abs_y, abs_x); + auto transform_output = CreateOpAndInfer( + rewriter, loc, output_ty, condition, sub_pi_2, table_output); - bool input_is_fp = input_ty.getElementType().isa(); - bool output_is_fp = output_ty.getElementType().isa(); + // 6. Determine the correct atan2 quadrant. + // If x < 0, apply π - transform_output. + auto sub_pi = CreateOpAndInfer(rewriter, loc, output_ty, pi, + transform_output); + auto cond_1 = + CreateOpAndInfer(rewriter, loc, bool_ty, zero, input_x); + auto quadrant_select = CreateOpAndInfer( + rewriter, loc, output_ty, cond_1, sub_pi, transform_output); - if (!input_is_fp || !output_is_fp) { - return rewriter.notifyMatchFailure(op, "input/result must be fp"); - } + // 7. If (1) was applied to the atan input, negate output. + auto neg_r = CreateOpAndInfer(rewriter, loc, output_ty, + quadrant_select); + auto cond_2 = + CreateOpAndInfer(rewriter, loc, bool_ty, zero, input_y); + CreateReplaceOpAndInfer(rewriter, op, output_ty, cond_2, + neg_r, quadrant_select); - // Replace with the equivalent sin operation: - // cos(x) = sin(x + π / 2). 
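// Worked example for the atan2 lowering above, using illustrative values:
//   atan2(y = 1, x = -1): min(|x|, |y|) / max(|x|, |y|) = 1 and the table
//   yields atan(1) = π/4;
//   |y| >= |x|, so identity (0) gives π/2 - π/4 = π/4;
//   x < 0, so the quadrant correction gives π - π/4 = 3π/4;
//   y >= 0, so no negation is applied and the final result is 3π/4.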
- auto fp_scalar_ty = RankedTensorType::get({}, rewriter.getF32Type()); - auto pi_2 = rewriter.create( - op->getLoc(), fp_scalar_ty, - DenseElementsAttr::get(fp_scalar_ty, {static_cast(M_PI_2)})); - auto offset = rewriter.create(op->getLoc(), input_ty, input, pi_2); - - CreateReplaceOpAndInfer(rewriter, op, output_ty, offset); return success(); } @@ -3169,9 +3337,14 @@ LogicalResult ConvertTFLYieldOp::matchAndRewrite( LogicalResult ConvertTFLCustomOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tfl_custom_op = cast(op); - rewriter.replaceOpWithNewOp(op, op->getResultTypes(), - tfl_custom_op.getCustomCode(), - op->getOperands()); + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), tfl_custom_op.getCustomCode(), + rewriter.getStringAttr("TFL"), + tfl_custom_op.getCustomOption() + .cast() + .getValue() + .str(), + op->getOperands()); return success(); } @@ -3257,7 +3430,7 @@ LogicalResult ConvertTFLQuantizeOp::matchAndRewrite( if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -3294,7 +3467,7 @@ LogicalResult ConvertTFLDequantizeOp::matchAndRewrite( if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -3317,7 +3490,7 @@ LogicalResult ConvertTFLDequantizeOp::matchAndRewrite( if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -3422,7 +3595,7 @@ LogicalResult ConvertTFLGatherOp::matchAndRewrite( if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -3436,7 +3609,7 @@ LogicalResult ConvertTFLGatherNdOp::matchAndRewrite( tfl_gathernd_op.getIndices()); if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -3490,7 +3663,7 @@ LogicalResult ConvertTFLSparseToDenseOp::matchAndRewrite( auto multiply_constant_type = RankedTensorType::get({result_rank}, indices_ety); auto multiply_constant_attr = DenseElementsAttr::get( - multiply_constant_type, llvm::makeArrayRef(multiply_constant_ints)); + multiply_constant_type, llvm::ArrayRef(multiply_constant_ints)); Value multiply_constant = CreateOpAndInfer( rewriter, loc, multiply_constant_type, multiply_constant_attr); @@ -3503,12 +3676,13 @@ LogicalResult ConvertTFLSparseToDenseOp::matchAndRewrite( auto values_reshape_op = CreateOpAndInfer( rewriter, loc, UnrankedTensorType::get(result_ety), values, - rewriter.getI64ArrayAttr( - ArrayRef{1, values_ty.getDimSize(0), 1})); + rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF({1, values_ty.getDimSize(0), 1}))); auto index_reshape_op = CreateOpAndInfer( rewriter, loc, UnrankedTensorType::get(indices_ety), reduce_op, - rewriter.getI64ArrayAttr(ArrayRef{1, indices_ty.getDimSize(0)})); + rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF({1, indices_ty.getDimSize(0)}))); auto scatter = CreateOpAndInfer( rewriter, loc, UnrankedTensorType::get(result_ety), default_const, @@ -3516,7 +3690,8 @@ LogicalResult ConvertTFLSparseToDenseOp::matchAndRewrite( CreateReplaceOpAndInfer( rewriter, op, result_ty, scatter, - rewriter.getI64ArrayAttr(result_ty.getShape())); + rewriter.getDenseI64ArrayAttr( + tensorflow::ConvertMlirShapeToTF(result_ty.getShape()))); return success(); } @@ -3539,7 +3714,7 @@ LogicalResult 
ConvertTFLOneHotOp::matchAndRewrite( if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); return success(); } @@ -3553,6 +3728,12 @@ LogicalResult ConvertTFLArgMaxOp::matchAndRewrite( return failure(); int32_t dim = dim_elems.getValues()[0].getSExtValue(); + + if (dim < 0) { + auto input_type = cast(arg_max_op.getInput().getType()); + dim += input_type.getRank(); + } + CreateReplaceOpAndInfer( rewriter, op, arg_max_op.getType(), arg_max_op.getInput(), rewriter.getIntegerAttr(rewriter.getI64Type(), dim)); @@ -3578,7 +3759,39 @@ LogicalResult ConvertTFLFakeQuantOp::matchAndRewrite( if (!result) return failure(); - rewriter.replaceOp(op, {result.getValue()}); + rewriter.replaceOp(op, {result.value()}); + + return success(); +} + +// Clone block, convert yield from TFL to TOSA +static void inlineWhileCase(Region& srcRegion, Region& dstRegion, + PatternRewriter& rewriter) { + rewriter.cloneRegionBefore(srcRegion, &dstRegion.back()); + rewriter.eraseBlock(&dstRegion.back()); + + Block* headBlock = &dstRegion.front(); + + auto yield = cast(headBlock->getTerminator()); + rewriter.setInsertionPoint(yield); + rewriter.create(yield.getLoc(), yield.getOperands()); + rewriter.eraseOp(yield); +} + +LogicalResult ConvertTFLWhileOp::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + auto tfl_while_op = cast(op); + + auto while_op = rewriter.create( + op->getLoc(), op->getResultTypes(), op->getOperands()); + + rewriter.createBlock(&while_op.getCond()); + rewriter.createBlock(&while_op.getBody()); + + inlineWhileCase(tfl_while_op.getCond(), while_op.getCond(), rewriter); + inlineWhileCase(tfl_while_op.getBody(), while_op.getBody(), rewriter); + + rewriter.replaceOp(tfl_while_op, while_op.getResults()); return success(); } @@ -3623,6 +3836,7 @@ void populateLegalizeTFLPatterns(MLIRContext* ctx, DEF_PATTERN_INSERT(TFLGelu); DEF_PATTERN_INSERT(TFLRelu); DEF_PATTERN_INSERT(TFLRelu1); + DEF_PATTERN_INSERT(TFLRelu0To1); DEF_PATTERN_INSERT(TFLRelu6); DEF_PATTERN_INSERT(TFLEqual); DEF_PATTERN_INSERT(TFLNotEqual); @@ -3689,8 +3903,10 @@ void populateLegalizeTFLPatterns(MLIRContext* ctx, DEF_PATTERN_INSERT(TFLBatchToSpaceNd); DEF_PATTERN_INSERT(TFLSpaceToDepth); DEF_PATTERN_INSERT(TFLDepthToSpace); + DEF_PATTERN_INSERT(TFLBucketize); DEF_PATTERN_INSERT(TFLSin); DEF_PATTERN_INSERT(TFLCos); + DEF_PATTERN_INSERT(TFLAtan2); DEF_PATTERN_INSERT(TFLLogistic); DEF_PATTERN_INSERT(TFLTanh); DEF_PATTERN_INSERT(TFLPRelu); @@ -3710,6 +3926,7 @@ void populateLegalizeTFLPatterns(MLIRContext* ctx, DEF_PATTERN_INSERT(TFLOneHot); DEF_PATTERN_INSERT(TFLArgMax); DEF_PATTERN_INSERT(TFLFakeQuant); + DEF_PATTERN_INSERT(TFLWhile); } // Creates an instance of the TensorFlow Lite dialect LegalizeTFL pass. 
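// Note on the while lowering added above: ConvertTFLWhileOp creates an empty
// TOSA while loop with fresh cond/body blocks, clones the original tfl.while
// regions into them via inlineWhileCase, rewrites each terminating TFL yield
// into a TOSA yield, and finally replaces the op with the new loop's results.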
diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc index 81f7f5ceecb..45bbfb12b59 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc @@ -113,7 +113,7 @@ llvm::Optional buildReshapeWithDynamicDims(PatternRewriter& rewriter, return llvm::None; } - ArrayAttr shape_attr = rewriter.getI64ArrayAttr(static_dims); + DenseI64ArrayAttr shape_attr = rewriter.getDenseI64ArrayAttr(static_dims); auto output_ty = tensorflow::GetTypeFromTFTensorShape(static_dims, e_ty); return rewriter .create(op->getLoc(), output_ty, input_value, shape_attr) @@ -136,9 +136,9 @@ Value buildRescale(PatternRewriter& rewriter, Operation* op, rewriter, op->getLoc(), output_type, input_val, rewriter.getI32IntegerAttr(static_cast(input_zp)), rewriter.getI32IntegerAttr(static_cast(output_zp)), - rewriter.getI32ArrayAttr({multiplier}), rewriter.getI32ArrayAttr({shift}), - rewriter.getBoolAttr(scale32), rewriter.getBoolAttr(double_round), - rewriter.getBoolAttr(false)); + rewriter.getDenseI32ArrayAttr({multiplier}), + rewriter.getDenseI32ArrayAttr({shift}), rewriter.getBoolAttr(scale32), + rewriter.getBoolAttr(double_round), rewriter.getBoolAttr(false)); return rescale_op.getResult(); } @@ -206,8 +206,8 @@ Value buildRescaleOpConvOutput(PatternRewriter& rewriter, Operation* op, auto rescale_op = CreateOpAndInfer( rewriter, op->getLoc(), output_type, conv_val, rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(output_zp), - rewriter.getI32ArrayAttr({multiplier}), - rewriter.getI32ArrayAttr({shift}), rewriter.getBoolAttr(scale32), + rewriter.getDenseI32ArrayAttr({multiplier}), + rewriter.getDenseI32ArrayAttr({shift}), rewriter.getBoolAttr(scale32), rewriter.getBoolAttr(double_round), rewriter.getBoolAttr(false)); return rescale_op.getResult(); @@ -242,8 +242,8 @@ Value buildRescaleOpConvOutput(PatternRewriter& rewriter, Operation* op, auto rescale_op = CreateOpAndInfer( rewriter, op->getLoc(), output_type, conv_val, rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(output_zp), - rewriter.getI32ArrayAttr(multiplier_arr), - rewriter.getI32ArrayAttr(shift_arr), rewriter.getBoolAttr(scale32), + rewriter.getDenseI32ArrayAttr(multiplier_arr), + rewriter.getDenseI32ArrayAttr(shift_arr), rewriter.getBoolAttr(scale32), rewriter.getBoolAttr(double_round), rewriter.getBoolAttr(true)); return rescale_op.getResult(); @@ -277,8 +277,7 @@ Value getTosaConst8bitTable(PatternRewriter& rewriter, Operation* op, auto const_type = tensorflow::GetTypeFromTFTensorShape({256}, element_qtype); auto storage_type = tensorflow::GetTypeFromTFTensorShape( {256}, element_qtype.getStorageType()); - auto const_attr = - DenseElementsAttr::get(storage_type, llvm::makeArrayRef(table)); + auto const_attr = DenseElementsAttr::get(storage_type, llvm::ArrayRef(table)); auto const_op = rewriter.create(op->getLoc(), const_type, const_attr); @@ -314,14 +313,9 @@ Value getTosaConst16bitTable(PatternRewriter& rewriter, Operation* op, table.push_back( static_cast(std::min(std::max(max_val, -32768), 32767))); - auto element_qtype = - UniformQuantizedType::get(true, rewriter.getIntegerType(16), - rewriter.getF32Type(), 1.0f, 0, -32768, 32767); - auto const_type = tensorflow::GetTypeFromTFTensorShape({513}, element_qtype); - auto storage_type = tensorflow::GetTypeFromTFTensorShape( - {513}, element_qtype.getStorageType()); - auto const_attr = - DenseElementsAttr::get(storage_type, 
llvm::makeArrayRef(table)); + auto const_type = + tensorflow::GetTypeFromTFTensorShape({513}, rewriter.getIntegerType(16)); + auto const_attr = DenseElementsAttr::get(const_type, llvm::ArrayRef(table)); auto const_op = rewriter.create(op->getLoc(), const_type, const_attr); @@ -372,13 +366,13 @@ void getTosaConst32bitTable(PatternRewriter& rewriter, Operation* op, {513}, element_qtype.getStorageType()); auto first_const_attr = - DenseElementsAttr::get(storage_type, llvm::makeArrayRef(first_table)); + DenseElementsAttr::get(storage_type, llvm::ArrayRef(first_table)); auto second_const_attr = - DenseElementsAttr::get(storage_type, llvm::makeArrayRef(second_table)); + DenseElementsAttr::get(storage_type, llvm::ArrayRef(second_table)); auto third_const_attr = - DenseElementsAttr::get(storage_type, llvm::makeArrayRef(third_table)); + DenseElementsAttr::get(storage_type, llvm::ArrayRef(third_table)); auto fourth_const_attr = - DenseElementsAttr::get(storage_type, llvm::makeArrayRef(fourth_table)); + DenseElementsAttr::get(storage_type, llvm::ArrayRef(fourth_table)); first_const = rewriter.create(op->getLoc(), const_type, first_const_attr) @@ -449,9 +443,10 @@ bool getPaddingValuesFromPadType(tensorflow::Padding tf_pad, tensorflow::TensorFormat data_format_tf, uint32_t first_filter_spatial_dim, ShapedType input_type, ShapedType filter_type, - ArrayAttr strides, ArrayAttr dilations, + DenseI64ArrayAttr strides, + DenseI64ArrayAttr dilations, PatternRewriter& rewriter, - ArrayAttr& explicit_padding) { + DenseI64ArrayAttr& explicit_padding) { assert(tf_pad != tensorflow::Padding::EXPLICIT); if (!input_type.hasRank() || !filter_type.getRank()) return false; // Only support NHWC for now. @@ -479,8 +474,8 @@ bool getPaddingValuesFromPadType(tensorflow::Padding tf_pad, int64_t ifm_dim = i + dim_index_shift; int64_t filter_dim = first_filter_spatial_dim + i; - int64_t dim_dilation = dilations[i].template cast().getInt(); - int64_t dim_stride = strides[i].template cast().getInt(); + int64_t dim_dilation = dilations[i]; + int64_t dim_stride = strides[i]; int64_t ip_size = input_type.getDimSize(ifm_dim); int64_t f_size = filter_type.getDimSize(filter_dim); @@ -499,7 +494,7 @@ bool getPaddingValuesFromPadType(tensorflow::Padding tf_pad, computed_paddings.push_back(pad_after); } - explicit_padding = rewriter.getI64ArrayAttr(computed_paddings); + explicit_padding = rewriter.getDenseI64ArrayAttr(computed_paddings); return true; } @@ -513,7 +508,7 @@ bool getPaddingValuesFromPadType(tensorflow::Padding tf_pad, // The explicit padding array in TF holds 2 pad values for every // dimension, even those that are not the 2 spatial ones. Just extract the // 2x pad values for the XY dims. 
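// For example, with NHWC data and the TF-style flattened layout
//   explicit_pad = {n_before, n_after, h_before, h_after,
//                   w_before, w_after, c_before, c_after}
// the loop below keeps only {h_before, h_after, w_before, w_after}.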
-ArrayAttr getPaddingValuesFromExplicitPadAttr( +DenseI64ArrayAttr getPaddingValuesFromExplicitPadAttr( ArrayAttr explicit_pad, tensorflow::TensorFormat data_format_tf, PatternRewriter& rewriter) { SmallVector computed_paddings; @@ -522,22 +517,21 @@ ArrayAttr getPaddingValuesFromExplicitPadAttr( for (int i = 0; i < 2; i++) { // Two spatial dimensions X&Y int64_t dim = GetTensorSpatialDimIndex(4, data_format_tf, i); // 4D tensor, NHWC/NCHW format - pad_before = explicit_pad[dim * 2].template cast().getInt(); pad_after = explicit_pad[dim * 2 + 1].template cast().getInt(); computed_paddings.push_back(pad_before); computed_paddings.push_back(pad_after); } - return rewriter.getI64ArrayAttr(computed_paddings); + return rewriter.getDenseI64ArrayAttr(computed_paddings); } // Calculates the TOSA padding values for transposeConv2d bool getTransposeConv2dPaddingValues( tensorflow::Padding tf_pad, tensorflow::TensorFormat data_format_tf, uint32_t first_filter_spatial_dim, ShapedType input_type, - ShapedType filter_type, ShapedType output_type, ArrayAttr strides, - PatternRewriter& rewriter, ArrayAttr& explicit_padding) { + ShapedType filter_type, ShapedType output_type, DenseI64ArrayAttr strides, + PatternRewriter& rewriter, DenseI64ArrayAttr& explicit_padding) { assert(tf_pad != tensorflow::Padding::EXPLICIT); if (!input_type.hasRank() || !filter_type.hasRank() || !output_type.hasRank()) return false; @@ -558,7 +552,7 @@ bool getTransposeConv2dPaddingValues( int64_t ifm_size = input_type.getDimSize(ifm_dim); int64_t filter_size = filter_type.getDimSize(filter_dim); int64_t ofm_size = output_type.getDimSize(ofm_dim); - int64_t dim_stride = strides[i].template cast().getInt(); + int64_t dim_stride = strides[i]; // These dimensions need to be static to legalize. if (ShapedType::isDynamic(filter_size) || ShapedType::isDynamic(ifm_size) || @@ -576,7 +570,7 @@ bool getTransposeConv2dPaddingValues( computed_paddings.push_back(pad_after); } - explicit_padding = rewriter.getI64ArrayAttr(computed_paddings); + explicit_padding = rewriter.getDenseI64ArrayAttr(computed_paddings); return true; } @@ -682,7 +676,7 @@ LogicalResult ApplyPatternsWithShapeResolution( // type stripping changing. func.walk([&](tosa::ConstOp op) { auto ety = op.getValue().getType().getElementType(); - auto new_ty = op.getType().cast().clone(ety); + auto new_ty = op.getType().cast().clone(ety); op.getResult().setType(new_ty); }); diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h index 854195d46bb..8a6ac407c27 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h @@ -107,12 +107,13 @@ bool getPaddingValuesFromPadType(tensorflow::Padding tf_pad, tensorflow::TensorFormat data_format_tf, uint32_t first_filter_spatial_dim, ShapedType input_type, ShapedType filter_type, - ArrayAttr strides, ArrayAttr dilations, + DenseI64ArrayAttr strides, + DenseI64ArrayAttr dilations, PatternRewriter& rewriter, - ArrayAttr& explicit_pad); + DenseI64ArrayAttr& explicit_pad); // Calculates the TOSA padding values for explicit-padded TF operators. 
-ArrayAttr getPaddingValuesFromExplicitPadAttr( +DenseI64ArrayAttr getPaddingValuesFromExplicitPadAttr( ArrayAttr explicit_pad, tensorflow::TensorFormat data_format_tf, PatternRewriter& rewriter); @@ -120,8 +121,8 @@ ArrayAttr getPaddingValuesFromExplicitPadAttr( bool getTransposeConv2dPaddingValues( tensorflow::Padding tf_pad, tensorflow::TensorFormat data_format_tf, uint32_t first_filter_spatial_dim, ShapedType input_type, - ShapedType filter_type, ShapedType output_type, ArrayAttr strides, - PatternRewriter& rewriter, ArrayAttr& explicit_pad); + ShapedType filter_type, ShapedType output_type, DenseI64ArrayAttr strides, + PatternRewriter& rewriter, DenseI64ArrayAttr& explicit_pad); // Templated function to create a constant op for given type and shape. // T: storage C type. @@ -183,7 +184,7 @@ TosaOp CreateOpAndInfer(PatternRewriter& rewriter, Location loc, Type result_ty, Type new_ty = newKnowledge.hasRank ? Type{tensorflow::GetTypeFromTFTensorShape( - llvm::makeArrayRef(newKnowledge.sizes), newKnowledge.dtype)} + llvm::ArrayRef(newKnowledge.sizes), newKnowledge.dtype)} : Type{mlir::UnrankedTensorType::get(newKnowledge.dtype)}; result.setType(new_ty); return op; diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index dbaff4a8552..b888cf43c40 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -1,9 +1,10 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") -load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") +load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:public"], licenses = ["notice"], ) @@ -47,30 +48,6 @@ gentbl_cc_library( ], ) -gentbl_cc_library( - name = "xla_passes_inc_gen", - compatible_with = get_compatible_with_cloud(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=Xla", - ], - "transforms/xla_passes.h.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "transforms/xla_passes.td", - deps = [ - "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", - "//tensorflow/compiler/xla/mlir_hlo:hlo_ops_td_files", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncTdFiles", - "@llvm-project//mlir:PassBaseTdFiles", - "@llvm-project//mlir:TensorOpsTdFiles", - ], -) - gentbl_cc_library( name = "tf_xla_passes_inc_gen", compatible_with = get_compatible_with_cloud(), @@ -96,39 +73,6 @@ gentbl_cc_library( ], ) -cc_library( - name = "xla_passes", - srcs = [ - "transforms/outline_with_xla_framework.cc", - "transforms/xla_framework_to_llvm_pass.cc", - ], - hdrs = [ - "transforms/xla_passes.h", - ], - deps = [ - ":xla_framework", - ":xla_passes_inc_gen", - "//tensorflow/compiler/xla/mlir_hlo", - "@llvm-project//llvm:Core", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:ArithToLLVM", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:FuncToLLVM", - "@llvm-project//mlir:FuncTransforms", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:LLVMCommonConversion", - "@llvm-project//mlir:LLVMDialect", - "@llvm-project//mlir:MathToLLVM", - "@llvm-project//mlir:MemRefToLLVM", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:ReconcileUnrealizedCasts", - 
"@llvm-project//mlir:SCFToControlFlow", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TransformUtils", - ], -) - cc_library( name = "tf_xla_passes", srcs = [ @@ -148,6 +92,7 @@ cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:SparseTensorDialect", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", ], ) @@ -200,6 +145,7 @@ cc_library( "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", "@stablehlo//:chlo_ops", ], alwayslink = 1, @@ -214,44 +160,67 @@ cc_library( "transforms/adjust_layout.h", ], deps = [ - ":tf_xla_passes_inc_gen", - "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:xla_data_proto_cc", - "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/client:sharding_builder", - "//tensorflow/compiler/xla/client/lib:conv_grad_size_util", "//tensorflow/compiler/xla/mlir_hlo", - "//tensorflow/compiler/xla/mlir_hlo:convert_op_folder", "//tensorflow/compiler/xla/stream_executor/tpu:c_api_conversions", - "//tensorflow/compiler/xla/stream_executor/tpu:c_api_decl", "//tensorflow/compiler/xla/stream_executor/tpu:tpu_api", - "//tensorflow/compiler/xla/stream_executor/tpu:tpu_executor_c_api_hdrs", "//tensorflow/compiler/xla/translate/mhlo_to_hlo:type_to_shape", - "//tensorflow/core:framework", - "//tensorflow/core/kernels:conv_grad_shape_utils", - "//tensorflow/tsl/platform:bfloat16", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", - "@llvm-project//mlir:Dialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", - "@llvm-project//mlir:ShapeDialect", - "@llvm-project//mlir:SparseTensorDialect", "@llvm-project//mlir:Support", + ], +) + +cc_library( + name = "xla_legalize_targets", + srcs = [ + "transforms/xla_legalize_targets.cc", + ], + hdrs = [ + "transforms/xla_legalize_targets.h", + ], + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/xla/mlir_hlo", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@stablehlo//:chlo_ops", + ], +) + +tf_cc_test( + name = "xla_legalize_targets_test", + srcs = ["transforms/xla_legalize_targets_test.cc"], + deps = [ + ":xla_legalize_targets", + "//tensorflow/compiler/mlir/tensorflow", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@stablehlo//:chlo_ops", ], ) cc_library( name = "xla_legalize_tf", srcs = [ + "transforms/convert_mhlo_quant_to_int.cc", "transforms/legalize_tf_collective.cc", "transforms/legalize_tf_communication.cc", - "transforms/legalize_tf_control_flow.cc", "transforms/legalize_tf_types.cc", "transforms/tf_xla_passes.h.inc", + "transforms/verify_tfxla_legalization.cc", "transforms/xla_legalize_tf.cc", "transforms/xla_legalize_tf_passes.h.inc", ], @@ -261,6 +230,8 @@ cc_library( deps = [ ":legalize_tf", ":legalize_utils", + ":xla_legalize_targets", + ":xla_legalize_tf_no_fallback", ":xla_legalize_tf_passes_inc_gen", ":xla_legalize_tf_with_tf2xla", "//tensorflow/compiler/mlir/tensorflow", @@ -275,9 +246,12 @@ cc_library( "//tensorflow/compiler/xla/mlir_hlo", 
"//tensorflow/compiler/xla/mlir_hlo:chlo_legalize_to_hlo", "//tensorflow/compiler/xla/mlir_hlo:convert_op_folder", + "//tensorflow/compiler/xla/translate/hlo_to_mhlo:attribute_importer", "//tensorflow/compiler/xla/translate/mhlo_to_hlo:type_to_shape", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/util/quantization:uniform_quant_ops_params", + "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", @@ -330,15 +304,19 @@ cc_library( srcs = [ "transforms/legalize_tf_with_tf2xla.cc", ], + hdrs = [ + "transforms/passes.h", + ], deps = [ ":tf_xla_passes_inc_gen", + ":xla_legalize_tf_passes_inc_gen", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:convert_tensor", "//tensorflow/compiler/mlir/tensorflow:convert_type", "//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op", - "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib", "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", + "//tensorflow/compiler/mlir/tensorflow:tpu_embedding_ops_registry", "//tensorflow/compiler/mlir/tensorflow:translate_utils", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_context", @@ -348,7 +326,6 @@ cc_library( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/mlir_hlo", "//tensorflow/compiler/xla/stream_executor:timer", - "//tensorflow/compiler/xla/stream_executor/lib", "//tensorflow/compiler/xla/translate/hlo_to_mhlo:mlir_hlo_builder", "//tensorflow/core:core_cpu_lib", "//tensorflow/core:framework", @@ -370,82 +347,25 @@ cc_library( ], ) -cc_library( - name = "mhlo_to_lhlo_with_xla", - srcs = ["transforms/mhlo_to_lhlo_with_xla.cc"], - hdrs = ["transforms/mhlo_to_lhlo_with_xla.h"], - deps = [ - "//tensorflow/compiler/mlir/tensorflow:error_util", - "//tensorflow/compiler/xla:debug_options_flags", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:window_util", - "//tensorflow/compiler/xla:xla_data_proto_cc", - "//tensorflow/compiler/xla/mlir_hlo", - "//tensorflow/compiler/xla/mlir_hlo:lhlo", - "//tensorflow/compiler/xla/mlir_hlo:lhlo_gpu", - "//tensorflow/compiler/xla/service:backend", - "//tensorflow/compiler/xla/service:buffer_assignment", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/service/gpu:backend_configs_cc", - "//tensorflow/compiler/xla/service/gpu:cublas_cudnn", - "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", - "//tensorflow/compiler/xla/service/gpu:matmul_utils", - "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", - "//tensorflow/compiler/xla/translate/hlo_to_mhlo:attribute_importer", - "//tensorflow/compiler/xla/translate/hlo_to_mhlo:hlo_module_importer", - "//tensorflow/compiler/xla/translate/hlo_to_mhlo:hlo_utils", - "//tensorflow/compiler/xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo", - "//tensorflow/compiler/xla/translate/mhlo_to_hlo:type_to_shape", - "//tensorflow/tsl/platform:status", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/cleanup", - "@com_google_absl//absl/types:optional", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:BufferizationDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:MemRefDialect", 
- "@llvm-project//mlir:Pass", - "@llvm-project//mlir:TranslateLib", - ], -) - -cc_library( - name = "translate_cl_registration", - testonly = True, - srcs = ["xla_mlir_translate_registration.cc"], - deps = [ - "//tensorflow/compiler/jit:xla_cpu_jit", - "//tensorflow/compiler/jit:xla_gpu_jit", - "//tensorflow/compiler/mlir/xla:mhlo_to_lhlo_with_xla", - "@llvm-project//mlir:TranslateLib", - ], - alwayslink = 1, -) - tf_cc_binary( name = "xla-opt", testonly = True, srcs = ["xla_opt_main.cc"], deps = [ ":adjust_layout", # buildcleaner: keep - ":mhlo_to_lhlo_with_xla", # buildcleaner: keep ":tf_xla_passes", # buildcleaner: keep - ":xla_framework", ":xla_legalize_tf", # buildcleaner: keep ":xla_legalize_tf_no_fallback", # buildcleaner: keep - ":xla_passes", # buildcleaner: keep "//tensorflow/compiler/mlir:init_mlir", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes", + "//tensorflow/compiler/xla/mlir/framework/ir:xla_framework", + "//tensorflow/compiler/xla/mlir/framework/transforms:passes", "//tensorflow/compiler/xla/mlir_hlo:all_passes", "//tensorflow/compiler/xla/mlir_hlo:hlo_dialect_registration", "//tensorflow/compiler/xla/service:cpu_plugin", - "//tensorflow/compiler/xla/service/cpu:hlo_xla_runtime_pipeline", # buildcleaner: keep + "//tensorflow/compiler/xla/service/cpu:hlo_xla_runtime_pipeline", + "//tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla", # buildcleaner: keep "//tensorflow/core/ir/types:Dialect", "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:MlirOptLib", @@ -459,86 +379,22 @@ tf_cc_binary( srcs = ["xla_opt_main.cc"], deps = [ ":adjust_layout", # buildcleaner: keep - ":mhlo_to_lhlo_with_xla", # buildcleaner: keep ":tf_xla_passes", # buildcleaner: keep - ":xla_framework", # buildcleaner: keep ":xla_legalize_tf", # buildcleaner: keep ":xla_legalize_tf_no_fallback", # buildcleaner: keep - ":xla_passes", # buildcleaner: keep "//tensorflow/compiler/mlir:init_mlir", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes", + "//tensorflow/compiler/xla/mlir/framework/ir:xla_framework", + "//tensorflow/compiler/xla/mlir/framework/transforms:passes", "//tensorflow/compiler/xla/mlir_hlo:all_passes", "//tensorflow/compiler/xla/mlir_hlo:hlo_dialect_registration", "//tensorflow/compiler/xla/service:gpu_plugin", + "//tensorflow/compiler/xla/service/cpu:hlo_xla_runtime_pipeline", + "//tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla", # buildcleaner: keep "//tensorflow/core/ir/types:Dialect", "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:MlirOptLib", "@stablehlo//:register", ], ) - -td_library( - name = "td_files", - srcs = [ - "ir/xla_framework_ops.td", - ], - compatible_with = get_compatible_with_cloud(), - deps = [ - "@llvm-project//mlir:ControlFlowInterfacesTdFiles", - "@llvm-project//mlir:OpBaseTdFiles", - "@llvm-project//mlir:SideEffectInterfacesTdFiles", - ], -) - -gentbl_cc_library( - name = "xla_framework_inc_gen", - compatible_with = get_compatible_with_cloud(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "ir/xla_framework.h.inc", - ), - ( - ["-gen-op-defs"], - "ir/xla_framework.cc.inc", - ), - ( - ["-gen-dialect-decls"], - "ir/xla_framework_dialect.h.inc", - ), - ( - ["-gen-dialect-defs"], - "ir/xla_framework_dialect.cc.inc", - ), - ( - ["-gen-typedef-decls"], - "ir/xla_framework_types.h.inc", - ), - ( - ["-gen-typedef-defs"], - "ir/xla_framework_types.cc.inc", - ), - ], - tblgen = 
"@llvm-project//mlir:mlir-tblgen", - td_file = "ir/xla_framework_ops.td", - deps = [":td_files"], -) - -cc_library( - name = "xla_framework", - srcs = [ - "ir/xla_framework.cc", - "ir/xla_framework.cc.inc", - "ir/xla_framework.h.inc", - ], - hdrs = ["ir/xla_framework.h"], - deps = [ - ":xla_framework_inc_gen", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:DialectUtils", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:Support", - ], -) diff --git a/tensorflow/compiler/mlir/xla/tests/BUILD b/tensorflow/compiler/mlir/xla/tests/BUILD index 85181486fa4..b43e012b752 100644 --- a/tensorflow/compiler/mlir/xla/tests/BUILD +++ b/tensorflow/compiler/mlir/xla/tests/BUILD @@ -1,11 +1,20 @@ load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") -package(licenses = ["notice"]) +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) glob_lit_tests( data = [":test_utilities"], driver = "@llvm-project//mlir:run_lit.sh", + size_override = { + "legalize-tf-binary-elementwise.mlir": "medium", + "legalize-tf-include-tf2xla-fallback.mlir": "medium", + "legalize-tf-prefer-tf2xla.mlir": "medium", + "legalize-tf.mlir": "medium", + }, test_file_exts = [ "mlir", "hlotxt", diff --git a/tensorflow/compiler/mlir/xla/tests/adjust-layout.mlir b/tensorflow/compiler/mlir/xla/tests/adjust-layout.mlir index a68f469c71e..8d60633ab52 100644 --- a/tensorflow/compiler/mlir/xla/tests/adjust-layout.mlir +++ b/tensorflow/compiler/mlir/xla/tests/adjust-layout.mlir @@ -1,4 +1,4 @@ -// RUN: xla-opt -pass-pipeline='func.func(xla-adjust-layout)' %s | FILECHECK_OPTS="" FileCheck %s +// RUN: xla-opt -pass-pipeline='builtin.module(func.func(xla-adjust-layout))' %s | FILECHECK_OPTS="" FileCheck %s func.func @infeed_dequeue_tuple() -> (tensor<1x8x4x4xi32>, tensor<1x100x1xf32>) { // CHECK: [[TOKEN:%.*]] = mhlo.create_token : !mhlo.token diff --git a/tensorflow/compiler/mlir/xla/tests/convert-mhlo-quant-to-int.mlir b/tensorflow/compiler/mlir/xla/tests/convert-mhlo-quant-to-int.mlir new file mode 100644 index 00000000000..a2fae364156 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/convert-mhlo-quant-to-int.mlir @@ -0,0 +1,51 @@ +// RUN: xla-opt -convert-mhlo-quant-to-int -split-input-file %s | FileCheck %s + +// CHECK-LABEL: func @uniform_quantize_and_dequantize +func.func @uniform_quantize_and_dequantize(%arg0: tensor) -> tensor { + // CHECK-DAG: %[[SCALES:.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK-DAG: %[[ZPS:.*]] = mhlo.constant dense<3> : tensor + // CHECK-DAG: %[[HALF:.*]] = mhlo.constant dense<5.000000e-01> : tensor + // CHECK-DAG: %[[QUANT_MIN:.*]] = mhlo.constant dense<-128> : tensor + // CHECK-DAG: %[[QUANT_MAX:.*]] = mhlo.constant dense<127> : tensor + // CHECK: %[[VAL0:.*]] = chlo.broadcast_divide %arg0, %[[SCALES]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor) -> tensor + // CHECK: %[[VAL1:.*]] = chlo.broadcast_add %[[VAL0]], %[[HALF]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor) -> tensor + // CHECK: %[[VAL2:.*]] = mhlo.floor %[[VAL1]] : tensor + // CHECK: %[[VAL3:.*]] = mhlo.convert %[[VAL2]] : (tensor) -> tensor + // CHECK: %[[VAL4:.*]] = chlo.broadcast_add %[[VAL3]], %[[ZPS]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor) -> tensor + // CHECK: %[[VAL5:.*]] = chlo.broadcast_maximum %[[VAL4]], %[[QUANT_MIN]] {broadcast_dimensions = dense<> : tensor<0xi64>} : 
(tensor, tensor) -> tensor + // CHECK: %[[VAL6:.*]] = chlo.broadcast_minimum %[[VAL5]], %[[QUANT_MAX]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor) -> tensor + // CHECK: %[[VAL7:.*]] = mhlo.convert %[[VAL6]] : (tensor) -> tensor + + // CHECK-DAG: %[[SCALES_DQ:.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK-DAG: %[[ZPS_DQ:.*]] = mhlo.constant dense<3> : tensor + // CHECK: %[[VAL8:.*]] = mhlo.convert %[[VAL7]] : (tensor) -> tensor + // CHECK: %[[VAL9:.*]] = chlo.broadcast_subtract %[[VAL8]], %[[ZPS_DQ]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor) -> tensor + // CHECK: %[[VAL10:.*]] = mhlo.convert %[[VAL9]] : (tensor) -> tensor + // CHECK: %[[VAL11:.*]] = chlo.broadcast_multiply %[[VAL10]], %[[SCALES_DQ]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor) -> tensor + // CHECK: return %[[VAL11]] : tensor + %0 = mhlo.uniform_quantize %arg0 : (tensor) -> tensor> + %1 = mhlo.uniform_dequantize %0 : (tensor>) -> tensor + return %1 : tensor +} + +// ----- + +// CHECK-LABEL: func @uniform_quantize_and_dequantize_type_exensions +func.func @uniform_quantize_and_dequantize_type_exensions(%arg0: tensor>) -> () { + // CHECK: %[[QUANTIZED:.*]] = mhlo.convert %[[VAL0:.*]] : (tensor>) -> tensor> + // CHECK: %[[DEQUANTIZED:.*]] = chlo.broadcast_multiply %[[VAL1:.*]], %[[CONST_SCALE:.*]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor>, tensor) -> tensor> + %0 = mhlo.uniform_quantize %arg0 : (tensor>) -> tensor, #mhlo.type_extensions> + %1 = mhlo.uniform_dequantize %0 : (tensor, #mhlo.type_extensions>) -> tensor> + return +} + +// ----- + +// CHECK-LABEL: func @uniform_quantize_and_dequantize_sparse_tensor_encoding +func.func @uniform_quantize_and_dequantize_sparse_tensor_encoding(%arg0: tensor>) -> () { + // CHECK: %[[QUANTIZED:.*]] = mhlo.convert %[[VAL0:.*]] : (tensor>) -> tensor> + // CHECK: %[[DEQUANTIZED:.*]] = chlo.broadcast_multiply %[[VAL1:.*]], %[[CONST_SCALE:.*]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor>, tensor) -> tensor> + %0 = mhlo.uniform_quantize %arg0 : (tensor>) -> tensor, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>> + %1 = mhlo.uniform_dequantize %0 : (tensor, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>>) -> tensor> + return +} diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_xla_runtime_pipeline.mlir b/tensorflow/compiler/mlir/xla/tests/hlo_xla_runtime_pipeline.mlir index b7014db13d1..64225279774 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo_xla_runtime_pipeline.mlir +++ b/tensorflow/compiler/mlir/xla/tests/hlo_xla_runtime_pipeline.mlir @@ -9,64 +9,44 @@ func.func @simple_add(%arg0: tensor) -> tensor { // ----- -// TODO(ecg): bring back the Sparse tests once BufferResultsToOutParams is -// restricted to the main entry point. 
- -//#CSR = #sparse_tensor.encoding<{dimLevelType = [ "dense", "compressed" ]}> - -// NOCHECK-LABEL: func.func @csr_abs_eltwise( -//func.func @csr_abs_eltwise(%arg0: tensor<10x20xf32, #CSR>) -// -> tensor<10x20xf32, #CSR> { - // NOCHECK-DAG: %[[C0:.*]] = arith.constant 0 : index - // NOCHECK-DAG: %[[C1:.*]] = arith.constant 1 : index - // NOCHECK-DAG: %[[C10:.*]] = arith.constant 10 : index - // NOCHECK-DAG: %[[PTR:.*]] = call @sparsePointers0 - // NOCHECK-DAG: %[[IDX:.*]] = call @sparseIndices0 - // NOCHECK-DAG: %[[VAL:.*]] = call @sparseValuesF32 - // NOCHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C10]] step %[[C1]] { - // NOCHECK: %[[L:.*]] = memref.load %[[PTR]][%[[I]]] : memref - // NOCHECK: %[[A:.*]] = arith.addi %[[I]], %[[C1]] : index - // NOCHECK: %[[U:.*]] = memref.load %[[PTR]][%[[A]]] : memref - // NOCHECK: scf.for %[[JJ:.*]] = %[[L]] to %[[U]] step %[[C1]] { - // NOCHECK: %[[J:.*]] = memref.load %[[IDX]][%[[JJ]]] : memref - // NOCHECK: %[[V:.*]] = memref.load %[[VAL]][%[[JJ]]] : memref - // NOCHECK: math.absf %[[V]] : f32 - // NOCHECK: } - // NOCHECK: } -// %0 = mhlo.abs %arg0 : tensor<10x20xf32, #CSR> -// func.return %0 : tensor<10x20xf32, #CSR> -//} - -// ----- - -//#CSR = #sparse_tensor.encoding<{dimLevelType = [ "dense", "compressed" ]}> - -// NOCHECK-LABEL: func.func @csr_gendot( -//func.func @csr_gendot(%arg0: tensor<32x64xf64, #CSR>, -// %arg1: tensor<64x32xf64>) -> tensor<32x32xf64> { - // NOCHECK-DAG: %[[C0:.*]] = arith.constant 0 : index - // NOCHECK-DAG: %[[C1:.*]] = arith.constant 1 : index - // NOCHECK-DAG: %[[C32:.*]] = arith.constant 32 : index - // NOCHECK-DAG: %[[PTR:.*]] = call @sparsePointers0 - // NOCHECK-DAG: %[[IDX:.*]] = call @sparseIndices0 - // NOCHECK-DAG: %[[VAL:.*]] = call @sparseValuesF64 - // NOCHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C32]] step %[[C1]] { - // NOCHECK: %[[L:.*]] = memref.load %[[PTR]][%[[I]]] : memref - // NOCHECK: %[[A:.*]] = arith.addi %[[I]], %[[C1]] : index - // NOCHECK: %[[U:.*]] = memref.load %[[PTR]][%[[A]]] : memref - // NOCHECK: scf.for %[[JJ:.*]] = %[[L]] to %[[U]] step %[[C1]] { - // NOCHECK: %[[J:.*]] = memref.load %[[IDX]][%[[JJ]]] : memref - // NOCHECK: %[[V:.*]] = memref.load %[[VAL]][%[[JJ]]] : memref - // NOCHECK: scf.for %[[K:.*]] = %[[C0]] to %[[C32]] step %[[C1]] { - // NOCHECK: } - // NOCHECK: } - // NOCHECK: } -// %0 = "mhlo.dot_general"(%arg0, %arg1) { -// dot_dimension_numbers = #mhlo.dot, -// precision_config = [#mhlo, -// #mhlo]} -// : (tensor<32x64xf64, #CSR>, -// tensor<64x32xf64>) -> tensor<32x32xf64> -// return %0 : tensor<32x32xf64> -//} +#CSR = #sparse_tensor.encoding<{dimLevelType = [ "dense", "compressed" ]}> + +// CHECK-LABEL: func.func @csr_gendot( +// CHECK-SAME: %[[PTR:.*0]]: memref, +// CHECK-SAME: %[[IDX:.*1]]: memref, +// CHECK-SAME: %[[VAL:.*2]]: memref, +// CHECK-SAME: %[[SPEC:.*3]]: !llvm.struct<(array<2 x i64>, array<3 x i64>)> +// CHECK-SAME: %[[DENSE:.*4]]: memref<64x32xf64>) -> memref<32x32xf64> { +func.func @csr_gendot(%arg0: tensor<32x64xf64, #CSR>, + %arg1: tensor<64x32xf64>) -> tensor<32x32xf64> { + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index + // CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() + // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C32]] step %[[C1]] { + // CHECK: %[[L:.*]] = memref.load %[[PTR]][%[[I]]] : memref + // CHECK: %[[A:.*]] = arith.addi %[[I]], %[[C1]] : index + // CHECK: %[[U:.*]] = memref.load %[[PTR]][%[[A]]] : memref + // CHECK: scf.for %[[JJ:.*]] = 
%[[L]] to %[[U]] step %[[C1]] { + // CHECK: %[[J:.*]] = memref.load %[[IDX]][%[[JJ]]] : memref + // CHECK: %[[V:.*]] = memref.load %[[VAL]][%[[JJ]]] : memref + // CHECK: scf.for %[[K:.*]] = %[[C0]] to %[[C32]] step %[[C1]] { + // CHECK: %[[T1:.*]] = memref.load %[[ALLOC]][%[[I]], %[[K]]] : memref<32x32xf64> + // CHECK: %[[T2:.*]] = memref.load %[[DENSE]][%[[J]], %[[K]]] : memref<64x32xf64> + // CHECK: %[[T3:.*]] = arith.mulf %[[V]], %[[T2]] : f64 + // CHECK: %[[T4:.*]] = arith.addf %[[T1]], %[[T3]] : f64 + // CHECK: memref.store %[[T4]], %[[ALLOC]][%[[I]], %[[K]]] : memref<32x32xf64> + // CHECK: } + // CHECK: } + // CHECK: } + // CHECK: return %[[ALLOC]] : memref<32x32xf64> + // CHECK: } + %0 = "mhlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #mhlo.dot, + precision_config = [#mhlo, + #mhlo]} + : (tensor<32x64xf64, #CSR>, + tensor<64x32xf64>) -> tensor<32x32xf64> + return %0 : tensor<32x32xf64> +} diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_xla_sparsification.mlir b/tensorflow/compiler/mlir/xla/tests/hlo_xla_sparsification.mlir new file mode 100644 index 00000000000..712944d4157 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/hlo_xla_sparsification.mlir @@ -0,0 +1,33 @@ +// RUN: xla-opt -hlo-legalize-to-linalg -hlo-xla-runtime-sparsification %s | FileCheck %s + +#SparseVector = #sparse_tensor.encoding<{ dimLevelType = ["compressed"] }> + +// CHECK-LABEL: func.func @mult_sparse_dense( +// CHECK-SAME: %[[PTR:.*0]]: memref, +// CHECK-SAME: %[[IDX:.*1]]: memref, +// CHECK-SAME: %[[VAL:.*2]]: memref, +// CHECK-SAME: %[[SPEC:.*3]]: !llvm.struct<(array<1 x i64>, array<3 x i64>)> +// CHECK-SAME: %[[DENSE:.*4]]: memref<10xf64>) -> memref<10xf64> { +// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64 +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[A:.*]] = memref.alloc() {alignment = 64 : i64} : memref<10xf64> +// CHECK: linalg.fill ins(%[[F0]] : f64) outs(%[[A]] : memref<10xf64>) +// CHECK: %[[LO:.*]] = memref.load %[[PTR]][%[[C0]]] : memref +// CHECK: %[[HI:.*]] = memref.load %[[PTR]][%[[C1]]] : memref +// CHECK: scf.for %[[II:.*]] = %[[LO]] to %[[HI]] step %[[C1]] { +// CHECK: %[[I:.*]] = memref.load %[[IDX]][%[[II]]] : memref +// CHECK: %[[T0:.*]] = memref.load %[[VAL]][%[[II]]] : memref +// CHECK: %[[T1:.*]] = memref.load %[[DENSE]][%[[I]]] : memref<10xf64> +// CHECK: %[[T3:.*]] = arith.mulf %[[T0]], %[[T1]] : f64 +// CHECK: memref.store %[[T3]], %[[A]][%[[I]]] : memref<10xf64> +// CHECK: } +// CHECK: return %[[A]] : memref<10xf64> +// CHECK: } +func.func @mult_sparse_dense(%arg0: tensor<10xf64, #SparseVector>, + %arg1: tensor<10xf64>) + -> tensor<10xf64> { + %0 = mhlo.multiply %arg0, %arg1 : (tensor<10xf64, #SparseVector>, + tensor<10xf64>) -> tensor<10xf64> + return %0 : tensor<10xf64> +} diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir deleted file mode 100644 index ae491195515..00000000000 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir +++ /dev/null @@ -1,223 +0,0 @@ -// RUN: xla-opt -split-input-file -xla-legalize-tf-control-flow %s | FileCheck %s - -// CHECK-LABEL: @if -// CHECK-SAME: ([[ARG0:%.+]]: tensor, [[ARG1:%.+]]: tensor) -func.func @if(%arg0: tensor, %arg1: tensor) -> (tensor) { - // CHECK: [[VAL0:%.+]] = mhlo.compare GT, [[ARG0]], [[ARG1]] : (tensor, tensor) -> tensor - %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo} : (tensor, tensor) -> 
tensor - // CHECK: [[VAL2:%.+]] = "mhlo.if"([[VAL0]]) ({ - // CHECK: [[VAL3:%.+]] = func.call @cond_true([[ARG0]], [[ARG1]]) - // CHECK: mhlo.return [[VAL3]] : tensor - // CHECK: }, { - // CHECK: [[VAL4:%.+]] = func.call @cond_false([[ARG0]], [[ARG1]]) - // CHECK: mhlo.return [[VAL4]] : tensor - // CHECK: }) - %1 = "tf.If"(%0, %arg0, %arg1) {else_branch = @cond_false, is_stateless = true, then_branch = @cond_true} : (tensor, tensor, tensor) -> tensor - - // CHECK: return [[VAL2]] - func.return %1 : tensor -} - -func.func @cond_false(%arg0: tensor, %arg1: tensor) -> tensor -attributes {tf._input_shapes = ["tfshape$", "tfshape$"]} { - %0 = mhlo.exponential %arg1 : (tensor) -> tensor - func.return %0 : tensor -} - -func.func @cond_true(%arg0: tensor, %arg1: tensor) -> tensor -attributes {tf._input_shapes = ["tfshape$", "tfshape$"]} { - %0 = mhlo.log %arg0 : (tensor) -> tensor - func.return %0 : tensor -} - -// CHECK-LABEL: @ifRegion -// CHECK-SAME: ([[ARG0:%.+]]: tensor, [[ARG1:%.+]]: tensor) -func.func @ifRegion(%arg0: tensor, %arg1: tensor) -> (tensor) { - // CHECK: [[VAL0:%.+]] = mhlo.compare GT, [[ARG0]], [[ARG1]] - %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - // CHECK: [[VAL1:%.+]] = "mhlo.if"([[VAL0]]) ({ - %1 = "tf.IfRegion"(%0) ({ - // CHECK: [[VAL2:%.+]] = mhlo.log [[ARG0]] - %2 = mhlo.log %arg0 : (tensor) -> tensor - // CHECK: mhlo.return [[VAL2]] - "tf.Yield"(%2) : (tensor) -> () - }, { - // CHECK: [[VAL3:%.+]] = mhlo.exponential [[ARG1]] - %2 = mhlo.exponential %arg1 : (tensor) -> tensor - // CHECK: mhlo.return [[VAL3]] - "tf.Yield"(%2) : (tensor) -> () - // CHECK: }) : (tensor) -> tensor - }) {is_stateless = true} : (tensor) -> tensor - // CHECK: return [[VAL1]] - func.return %1 : tensor -} - - -// CHECK-LABEL: func @case -// CHECK-SAME: %[[BRANCH_INDEX:.*]]: tensor, %[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) -> (tensor, tensor) -func.func @case(%index: tensor, %arg0: tensor, %arg1: tensor) -> (tensor, tensor) { - %0:2 = "tf.Case"(%index, %arg0, %arg1) {branches = [@exponential, @log, @floor], is_stateless = true} : (tensor, tensor, tensor) -> (tensor, tensor) - // CHECK: %[[CASE:.*]]:2 = "mhlo.case"(%[[BRANCH_INDEX]]) ({ - // CHECK: %[[CALL_EXP:.*]]:2 = func.call @exponential(%[[ARG0]], %[[ARG1]]) : (tensor, tensor) -> (tensor, tensor) - // CHECK: mhlo.return %[[CALL_EXP]]#0, %[[CALL_EXP]]#1 : tensor, tensor - // CHECK: }, { - // CHECK: %[[CALL_LOG:.*]]:2 = func.call @log(%[[ARG0]], %[[ARG1]]) : (tensor, tensor) -> (tensor, tensor) - // CHECK: mhlo.return %[[CALL_LOG]]#0, %[[CALL_LOG]]#1 : tensor, tensor - // CHECK: }, { - // CHECK: %[[CALL_FLOOR:.*]]:2 = func.call @floor(%[[ARG0]], %[[ARG1]]) : (tensor, tensor) -> (tensor, tensor) - // CHECK: mhlo.return %[[CALL_FLOOR]]#0, %[[CALL_FLOOR]]#1 : tensor, tensor - // CHECK: }) : (tensor) -> (tensor, tensor) - func.return %0#0, %0#1 : tensor, tensor -// CHECK: return %[[CASE]]#0, %[[CASE]]#1 : tensor, tensor -} - -func.func @exponential(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { - %0 = mhlo.exponential %arg1 : (tensor) -> tensor - func.return %0, %arg1 : tensor, tensor -} - -func.func @log(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { - %0 = mhlo.log %arg0 : (tensor) -> tensor - func.return %0, %arg1 : tensor, tensor -} - -func.func @floor(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { - %0 = "mhlo.floor"(%arg0) : (tensor) -> tensor - func.return %0, %arg1 : tensor, tensor -} - - -// CHECK-LABEL: func @caseRegion -// CHECK-SAME: ([[BRANCH_INDEX:%.+]]: tensor, 
[[ARG0:.+]]: tensor, [[ARG1:%.+]]: tensor) -func.func @caseRegion(%index: tensor, %arg0: tensor, %arg1: tensor) -> (tensor, tensor) { - // CHECK: [[VAL1:%.+]]:2 = "mhlo.case"([[BRANCH_INDEX]]) ({ - %0:2 = "tf.CaseRegion"(%index) ({ - // CHECK: [[VAL2:%.+]] = mhlo.exponential [[ARG1]] - %1 = mhlo.exponential %arg1 : (tensor) -> tensor - // CHECK: mhlo.return [[VAL2]], [[ARG1]] - "tf.Yield"(%1, %arg1) : (tensor, tensor) -> () - }, { - // CHECK: [[VAL3:%.+]] = mhlo.log [[ARG0]] - %1 = mhlo.log %arg0 : (tensor) -> tensor - // CHECK: mhlo.return [[VAL3]], [[ARG1]] - "tf.Yield"(%1, %arg1) : (tensor, tensor) -> () - }, { - // CHECK: [[VAL4:%.+]] = mhlo.floor [[ARG0]] - %1 = "mhlo.floor"(%arg0) : (tensor) -> tensor - // CHECK: mhlo.return [[VAL4]], [[ARG1]] - "tf.Yield"(%1, %arg1) : (tensor, tensor) -> () - // CHECK: }) : (tensor) -> (tensor, tensor) - }) {is_stateless = true} : (tensor) -> (tensor, tensor) - // CHECK: return [[VAL1]]#0, [[VAL1]]#1 : tensor, tensor - func.return %0#0, %0#1 : tensor, tensor -} - -// ----- - -// This test case also ensures the mhlo dialect is loaded as a dependency by the -// pass and hence the split here. - -// CHECK-LABEL: func @while -// CHECK-SAME: %[[VAL0:.*]]: tensor, %[[VAL1:.*]]: tensor -func.func @while(%in0: tensor, %in1: tensor) -> tensor { - // CHECK: [[VAL2:%.+]]:3 = mhlo.while([[ITER_ARG0:.*]] = %[[VAL0]], [[ITER_ARG1:.*]] = %[[VAL1]], [[ITER_ARG2:.*]] = %[[VAL0]]) - // CHECK: [[VAL3:%.+]] = func.call @while_cond([[ITER_ARG0]], [[ITER_ARG1]], [[ITER_ARG2]]) - // CHECK: mhlo.return [[VAL3]] - // CHECK: } do { - // CHECK: [[VAL3:%.+]]:3 = func.call @while_body([[ITER_ARG0]], [[ITER_ARG1]], [[ITER_ARG2]]) - // CHECK: mhlo.return [[VAL3]]#0, [[VAL3]]#1, [[VAL3]]#2 - // CHECK: return [[VAL2]]#2 - %2:3 = "tf.While"(%in0, %in1, %in0) {body = @while_body, cond = @while_cond, is_stateless = true, parallel_iterations = 10 : i64} : (tensor, tensor, tensor) -> (tensor, tensor, tensor) - func.return %2#2 : tensor -} -func.func @while_cond(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { - %0 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - func.return %0 : tensor -} -func.func @while_body(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tensor, tensor, tensor) { - %0 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - func.return %0, %0, %0 : tensor, tensor, tensor -} - -// ----- - -// CHECK-LABEL: func @whileRegion -func.func @whileRegion() -> tensor { - // CHECK: [[VAL0:%.+]] = mhlo.constant dense<0> - %0 = mhlo.constant dense<0> : tensor - // CHECK: [[VAL1:%.+]] = mhlo.constant dense<-1> - %1 = mhlo.constant dense<-1> : tensor - // CHECK: [[VAL2:%.+]]:3 = mhlo.while([[ITER_ARG0:.*]] = [[VAL0]], [[ITER_ARG1:.*]] = [[VAL1]], [[ITER_ARG2:.*]] = [[VAL0]]) - %2:3 = "tf.WhileRegion"(%0, %1, %0) ({ - ^cond(%carg0: tensor, %carg1: tensor, %carg2: tensor): - // CHECK: [[VAL3:%.+]] = mhlo.constant dense<10> - %3 = mhlo.constant dense<10> : tensor - // CHECK: [[VAL4:%.+]] = mhlo.compare LT, [[ITER_ARG2]], [[VAL3]] - %4 = "mhlo.compare"(%carg2, %3) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - // CHECK: mhlo.return [[VAL4]] - "tf.Yield"(%4) : (tensor) -> () - }, { - ^body(%barg0: tensor, %barg1: tensor, %barg2: tensor): - // CHECK: [[VAL5:%.+]] = mhlo.constant dense<1> - %5 = mhlo.constant dense<1> : tensor - // CHECK: [[VAL6:%.+]] = mhlo.add [[ITER_ARG2]], [[VAL5]] - %6 = mhlo.add %barg2, %5 : tensor - // CHECK: [[VAL7:%.+]] = mhlo.add [[ITER_ARG0]], [[VAL5]] - %7 = mhlo.add %barg0, %5 : tensor - // CHECK: mhlo.return [[VAL7]], 
[[ITER_ARG1]], [[VAL6]] - "tf.Yield"(%7, %barg1, %6) : (tensor, tensor, tensor) -> () - }) {is_stateless = true, parallel_iterations = 10 : i64} : (tensor, tensor, tensor) -> (tensor, tensor, tensor) - // CHECK: return [[VAL2]]#2 - func.return %2#2 : tensor -} - - -// CHECK-LABEL: func @whileRegionImplicitInputs -// CHECK-SAME: ([[ARG0:%.+]]: tensor) -func.func @whileRegionImplicitInputs(%arg0: tensor) -> tensor { - // CHECK: [[VAL0:%.+]] = mhlo.constant dense<0> - %0 = mhlo.constant dense<0> : tensor - // CHECK: [[VAL1:%.+]] = mhlo.constant dense<-1> - %1 = mhlo.constant dense<-1> : tensor - // CHECK: [[VAL2:%.+]]:3 = mhlo.while([[ITER_ARG0:.*]] = [[ARG0]], [[ITER_ARG1:.*]] = [[VAL0]], [[ITER_ARG2:.*]] = [[VAL1]]) - %2 = "tf.WhileRegion"(%arg0) ({ - ^cond(%carg0: tensor): - // CHECK: [[VAL3:%.+]] = mhlo.compare LT, [[ITER_ARG0]], [[ITER_ARG1]] - %3 = mhlo.compare LT, %carg0, %0 : (tensor, tensor) -> tensor - // CHECK: mhlo.return [[VAL3]] - "tf.Yield"(%3) : (tensor) -> () - }, { - ^body(%barg0: tensor): - // CHECK: [[VAL3:%.+]] = mhlo.add [[ITER_ARG0]], [[ITER_ARG2]] - %3 = mhlo.add %barg0, %1 : tensor - // CHECK: [[VAL4:%.+]] = mhlo.add [[ITER_ARG0]], [[VAL3]] - %4 = mhlo.add %barg0, %3 : tensor - // CHECK: mhlo.return [[VAL4]], [[ITER_ARG1]], [[ITER_ARG2]] - "tf.Yield"(%4) : (tensor) -> () - }) {is_stateless = true, parallel_iterations = 10 : i64} : (tensor) -> tensor - // CHECK: return [[VAL2]]#0 - func.return %2 : tensor -} - - -// CHECK-LABEL: func @whileRegionMultipleImplicitInputs -func.func @whileRegionMultipleImplicitInputs() { - // CHECK: [[VAL0:%.+]] = mhlo.constant dense<0> - %0 = mhlo.constant dense<0> : tensor - // CHECK: [[VAL1:%.+]] = mhlo.constant dense<-1> - %1 = mhlo.constant dense<-1> : tensor - // CHECK: [[VAL2:%.+]]:2 = mhlo.while([[ITER_ARG0:.*]] = [[VAL0]], [[ITER_ARG1:.*]] = [[VAL1]]) - "tf.WhileRegion"() ({ - // CHECK: [[VAL3:%.+]] = mhlo.compare LT, [[ITER_ARG0]], [[ITER_ARG1]] - %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - // CHECK: mhlo.return [[VAL3]] - "tf.Yield"(%2) : (tensor) -> () - }, { - // CHECK: [[VAL3:%.+]] = mhlo.add [[ITER_ARG0]], [[ITER_ARG1]] - %2 = mhlo.add %0, %1 : tensor - // CHECK: mhlo.return [[ITER_ARG0]], [[ITER_ARG1]] - "tf.Yield"() : () -> () - }) {is_stateless = true, parallel_iterations = 10 : i64} : () -> () - // CHECK: return - func.return -} diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-no-tf2xla-fallback.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-no-tf2xla-fallback.mlir index 000876ed4c2..d1301bb070b 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-no-tf2xla-fallback.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-no-tf2xla-fallback.mlir @@ -2351,7 +2351,7 @@ func.func @acos_dynamic(%arg0: tensor<*xf32>) -> tensor<*xf32> { // CHECK-LABEL: @tan // CHECK-SAME: (%[[ARG:.*]]: tensor<2xf32>) -> tensor<2xf32> func.func @tan(%arg : tensor<2xf32>) -> tensor<2xf32> { - // CHECK: chlo.tan %[[ARG]] : tensor<2xf32> + // CHECK: mhlo.tan %[[ARG]] : tensor<2xf32> %result = "tf.Tan"(%arg) : (tensor<2xf32>) -> tensor<2xf32> func.return %result : tensor<2xf32> } @@ -2361,7 +2361,7 @@ func.func @tan(%arg : tensor<2xf32>) -> tensor<2xf32> { // CHECK-LABEL: @tan_unranked // CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> tensor<*xf32> func.func @tan_unranked(%arg : tensor<*xf32>) -> tensor<*xf32> { - // CHECK: chlo.tan %[[ARG]] : tensor<*xf32> + // CHECK: mhlo.tan %[[ARG]] : tensor<*xf32> %result = "tf.Tan"(%arg) : (tensor<*xf32>) -> tensor<*xf32> func.return %result 
: tensor<*xf32> } @@ -4363,8 +4363,8 @@ func.func @conv_dynamic(%arg0: tensor, %arg1: tensor<3x3x3x16xf32 // ----- -// CHECK-LABEL: @split_not_match_non_const_split_dim -func.func @split_not_match_non_const_split_dim(%input: tensor<4x4xf32>, %split_dim: tensor) -> (tensor<*xf32>, tensor<*xf32>) { +// CHECK-LABEL: @split_not_match_dynamic_split_dim_input +func.func @split_not_match_dynamic_split_dim_input(%input: tensor<4x4xf32>, %split_dim: tensor) -> (tensor<*xf32>, tensor<*xf32>) { // CHECK: tf.Split %0:2 = "tf.Split"(%split_dim, %input) : (tensor, tensor<4x4xf32>) -> (tensor<*xf32>, tensor<*xf32>) func.return %0#0, %0#1 : tensor<*xf32>, tensor<*xf32> @@ -4372,8 +4372,8 @@ func.func @split_not_match_non_const_split_dim(%input: tensor<4x4xf32>, %split_d // ----- -// CHECK-LABEL: @split_not_match_unknown_input_dim -func.func @split_not_match_unknown_input_dim(%input: tensor<4x?x4xf32>) -> (tensor<4x?x4xf32>, tensor<4x?x4xf32>) { +// CHECK-LABEL: @split_not_match_dynamic_input_shape +func.func @split_not_match_dynamic_input_shape(%input: tensor<4x?x4xf32>) -> (tensor<4x?x4xf32>, tensor<4x?x4xf32>) { %cst = "tf.Const"() {value = dense<1> : tensor} : () -> tensor // CHECK: tensor.dim {{.*}} : tensor<4x?x4xf32> // CHECK: arith.divsi {{.*}} : index @@ -4391,6 +4391,25 @@ func.func @split_not_match_unknown_input_dim(%input: tensor<4x?x4xf32>) -> (tens // ----- +// CHECK-LABEL: @split_not_match_static_split_dim_size +func.func @split_not_match_static_split_dim_size(%input: tensor<4x?x4xf32>) -> (tensor<2x?x4xf32>, tensor<2x?x4xf32>) { + %cst = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + // CHECK: tensor.dim {{.*}} : tensor<4x?x4xf32> + // CHECK: arith.divsi {{.*}} : index + // CHECK: tensor.from_elements {{.*}} : tensor<3xindex> + // CHECK: mhlo.real_dynamic_slice {{.*}} : (tensor<4x?x4xf32>, tensor<3xindex>, tensor<3xindex>, tensor<3xindex>) -> tensor<2x?x4xf32> + // CHECK: muli {{.*}} : index + // CHECK: muli {{.*}} : index + // CHECK: tensor.from_elements {{.*}} : tensor<3xindex> + // CHECK: tensor.from_elements {{.*}} : tensor<3xindex> + // CHECK: tensor.from_elements {{.*}} : tensor<3xindex> + // CHECK: mhlo.real_dynamic_slice {{.*}} : (tensor<4x?x4xf32>, tensor<3xindex>, tensor<3xindex>, tensor<3xindex>) -> tensor<2x?x4xf32> + %0:2 = "tf.Split"(%cst, %input) : (tensor, tensor<4x?x4xf32>) -> (tensor<2x?x4xf32>, tensor<2x?x4xf32>) + func.return %0#0, %0#1 : tensor<2x?x4xf32>, tensor<2x?x4xf32> +} + +// ----- + // CHECK-LABEL: @split_match_and_split_into_two func.func @split_match_and_split_into_two(%input: tensor<4x6xf32>) -> (tensor<2x6xf32>, tensor<2x6xf32>) { %cst = "tf.Const"() {value = dense<0> : tensor} : () -> tensor @@ -4403,18 +4422,6 @@ func.func @split_match_and_split_into_two(%input: tensor<4x6xf32>) -> (tensor<2x // ----- -// CHECK-LABEL: @split_match_and_split_into_two_dynamic -func.func @split_match_and_split_into_two_dynamic(%input: tensor<4x?xf32>) -> (tensor<2x?xf32>, tensor<2x?xf32>) { - %cst = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - // CHECK: %[[ONE:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[2, -1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x?xf32>) -> tensor<2x?xf32> - // CHECK: %[[TWO:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, -1]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x?xf32>) -> tensor<2x?xf32> - %0:2 = "tf.Split"(%cst, %input) : (tensor, tensor<4x?xf32>) -> (tensor<2x?xf32>, 
tensor<2x?xf32>) - // CHECK: return %[[ONE]], %[[TWO]] - func.return %0#0, %0#1 : tensor<2x?xf32>, tensor<2x?xf32> -} - -// ----- - // CHECK-LABEL: @split_match_and_split_into_three // CHECK-SAME: (%[[ARG:.*]]: tensor<4x6xf32>) func.func @split_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>) { @@ -4483,19 +4490,6 @@ func.func @splitv_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor // ----- -// CHECK-LABEL: @splitv_match_and_split_into_three_dynamic -func.func @splitv_match_and_split_into_three_dynamic(%input: tensor) -> (tensor, tensor, tensor) { - %split_sizes = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32> - %split_dim = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - // CHECK: "mhlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor) -> tensor - // CHECK: "mhlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor) -> tensor - // CHECK: "mhlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor) -> tensor - %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor, tensor<3xi32>, tensor) -> (tensor, tensor, tensor) - func.return %0#0, %0#1, %0#2 : tensor, tensor, tensor -} - -// ----- - // CHECK-LABEL: @splitv_dynamic_dim_in_split_sizes func.func @splitv_dynamic_dim_in_split_sizes(%input: tensor<4x6xf32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) { %split_sizes = "tf.Const"() {value = dense<[1, -1, 3]> : tensor<3xi32>} : () -> tensor<3xi32> @@ -4507,6 +4501,17 @@ func.func @splitv_dynamic_dim_in_split_sizes(%input: tensor<4x6xf32>) -> (tensor func.return %0#0, %0#1, %0#2 : tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32> } +// ----- + +// CHECK-LABEL: @splitv_dynamic +func.func @splitv_dynamic(%input: tensor) -> (tensor, tensor, tensor) { + %split_sizes = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32> + %split_dim = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + // CHECK: tf.SplitV + %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor, tensor<3xi32>, tensor) -> (tensor, tensor, tensor) + func.return %0#0, %0#1, %0#2 : tensor, tensor, tensor +} + //===----------------------------------------------------------------------===// // tf.Assert legalization //===----------------------------------------------------------------------===// @@ -5095,17 +5100,23 @@ func.func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> { // CHECK: [[NEW_IV:%.*]] = chlo.broadcast_add [[ITER_ARG]], [[ONE]] // CHECK: mhlo.return [[NEW_IV]], [[ITER_ARG1]], [[INDICES2]] - // CHECK: [[GATHER:%.*]] = "mhlo.gather"([[INPUT]], [[WHILE_OUT]]#2) + // CHECK: [[CONSTANT1:%.*]] = mhlo.constant dense<1> : tensor<1xi64> + // CHECK: [[ARITH_CONSTANT:%.*]] = arith.constant 1 : index + // CHECK: [[SHAPE_DIM:%.*]] = shape.dim %arg0, [[ARITH_CONSTANT]] : tensor<4x?x16xf32>, index -> index + // CHECK: [[INDEX_CAST:%.*]] = arith.index_cast [[SHAPE_DIM]] : index to i64 + // CHECK: [[FROM_ELEMENTS:%.*]] = tensor.from_elements [[INDEX_CAST]] : tensor<1xi64> + // CHECK: [[CONSTANT2:%.*]] = mhlo.constant dense<16> : tensor<1xi64> + // CHECK: [[CONCATENATE:%.*]] = "mhlo.concatenate"([[CONSTANT1]], [[FROM_ELEMENTS]], [[CONSTANT2]]) 
{dimension = 0 : i64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi64> + // CHECK: [[DYNAMIC_GATHER:%.*]] = "mhlo.dynamic_gather"([[INPUT]], [[WHILE_OUT]]#2, [[CONCATENATE]]) // CHECK-SAME: dimension_numbers = // CHECK-SAME: offset_dims = [1, 2] // CHECK-SAME: collapsed_slice_dims = [0] // CHECK-SAME: start_index_map = [0] // CHECK-SAME: index_vector_dim = 1 // CHECK-SAME: indices_are_sorted = false - // CHECK-SAME: slice_sizes = dense<[1, -1, 16]> - // CHECK: (tensor<4x?x16xf32>, tensor<4xi32>) -> tensor<4x?x16xf32> + // CHECK-SAME:: (tensor<4x?x16xf32>, tensor<4xi32>, tensor<3xi64>) -> tensor<4x?x16xf32> - // CHECK: return [[GATHER]] + // CHECK: return [[DYNAMIC_GATHER]] %0 = "tf.RandomShuffle"(%input) : (tensor<4x?x16xf32>) -> (tensor<4x?x16xf32>) func.return %0: tensor<4x?x16xf32> @@ -5555,7 +5566,7 @@ func.func @avgpool_grad_bf16(%grad: tensor<10x12x16x64xbf16>) -> tensor<10x24x32 // CHECK-LABEL: xla_sharding func.func @xla_sharding(%arg0: tensor<4x16xf32>) -> tensor<4x16xf32> { - // CHECK-NEXT: "mhlo.custom_call"(%arg0) {call_target_name = "Sharding", mhlo.sharding = ""} + // CHECK-NEXT: mhlo.custom_call @Sharding(%arg0) {mhlo.sharding = ""} %0 = "tf.XlaSharding"(%arg0) {_XlaSharding = "", sharding = ""} : (tensor<4x16xf32>) -> tensor<4x16xf32> func.return %0 : tensor<4x16xf32> } @@ -5768,21 +5779,6 @@ func.func @cumprod(%arg0: tensor<4xf32>) -> tensor<4xf32> { func.return %1 : tensor<4xf32> } -//===----------------------------------------------------------------------===// -// Qr op legalization -//===----------------------------------------------------------------------===// - -// CHECK: func @qr([[VAL_0:%.*]]: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>) -func.func @qr(%arg0: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>) { - // The tf.Qr lowering is a full algorithm that is not effective to verify with - // FileCheck. Just verify that it converted. - // TODO(laurenzo): Move this out of the mainline tf2xla conversion as it is - // really only applicable to certain legacy uses. 
- // CHECK-NOT: "tf.Qr" - %0:2 = "tf.Qr"(%arg0) {full_matrices = false} : (tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>) - func.return %0#0, %0#1 : tensor<500x100x75xf32>, tensor<500x75x75xf32> -} - //===----------------------------------------------------------------------===// // tf.Softplus legalization //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-prefer-tf2xla.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-prefer-tf2xla.mlir index a997f933e74..911f2f6575a 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-prefer-tf2xla.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-prefer-tf2xla.mlir @@ -59,10 +59,10 @@ func.func @random_uniform_simple(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> { // CHECK-LABEL: func @random_uniform_with_seeds func.func @random_uniform_with_seeds(%arg0: tensor<4xi32>) -> tensor<32x12x12x64xf32> { // CHECK: %0 = mhlo.constant dense<[32, 12, 12, 64]> : tensor<4xi32> - // CHECK-NEXT : %1 = mhlo.constant dense<0.000000e+00> : tensor - // CHECK-NEXT : %2 = mhlo.constant dense<1.000000e+00> : tensor - // CHECK-NEXT : %3 = mhlo.convert %0 : (tensor<4xi32>) -> tensor<4xi64> - // CHECK-NEXT : %4 = "mhlo.rng"(%1, %2, %3) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<4xi64>) -> tensor<32x12x12x64xf32> + // CHECK-NEXT: %1 = mhlo.constant dense<0.000000e+00> : tensor + // CHECK-NEXT: %2 = mhlo.constant dense<1.000000e+00> : tensor + // CHECK-NEXT: %3 = mhlo.convert %0 : (tensor<4xi32>) -> tensor<4xi64> + // CHECK-NEXT: %4 = "mhlo.rng"(%1, %2, %3) {rng_distribution = #mhlo.rng_distribution} : (tensor, tensor, tensor<4xi64>) -> tensor<32x12x12x64xf32> %cst = "tf.Const"() {value = dense<[32, 12, 12, 64]> : tensor<4xi32>} : () -> tensor<4xi32> %0 = "tf.RandomUniform"(%cst) {seed = 87654321 : i64, seed2 = 0 : i64} : (tensor<4xi32>) -> tensor<32x12x12x64xf32> // CHECK: return %4 : tensor<32x12x12x64xf32> diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir index 871a4cce648..603982022d5 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir @@ -1,4 +1,4 @@ -// RUN: xla-opt "-xla-legalize-tf-with-tf2xla=device-type=XLA_CPU_JIT legalize-test-only-ops" %s -verify-diagnostics | FileCheck %s +// RUN: xla-opt "-xla-legalize-tf=device-type=XLA_CPU_JIT allow-partial-conversion=true prefer-tf2xla=true use-tf2xla-fallback=true" %s -verify-diagnostics | FileCheck %s module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { @@ -29,7 +29,7 @@ func.func @not_allowlisted_op(%arg0: tensor<3xi32>, %arg1: tensor, %arg2: t // CHECK-LABEL: unranked_operand func.func @unranked_operand(%arg0: tensor<*xf32>) -> tensor<*xf32> { // CHECK: tf.Atan2 - // expected-remark@+1 {{lowering requires static shaped tensor operands}} + // expected-remark@+1 {{lowering requires bounded tensor operands}} %0 = "tf.Atan2"(%arg0, %arg0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> func.return %0 : tensor<*xf32> @@ -38,7 +38,7 @@ func.func @unranked_operand(%arg0: tensor<*xf32>) -> tensor<*xf32> { // CHECK-LABEL: dynamic_operand func.func @dynamic_operand(%arg0: tensor) -> tensor { // CHECK: tf.Atan2 - // expected-remark@+1 {{lowering requires static shaped tensor operands}} + // expected-remark@+1 {{lowering requires bounded 
tensor operands}} %0 = "tf.Atan2"(%arg0, %arg0) : (tensor, tensor) -> tensor func.return %0 : tensor @@ -102,13 +102,13 @@ func.func @convert(%arg0: tensor<2xi32>) -> tensor<2xf32> { } // CHECK-LABEL: func @constant -func.func @constant(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2xf32> - // CHECK: %[[RESULT:.*]] = mhlo.divide %[[ONE]], %arg0 : tensor<2xf32> +func.func @constant(%arg0: tensor) -> tensor { + // CHECK: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK: %[[RESULT:.*]] = mhlo.divide %[[ONE]], %arg0 : tensor // CHECK: return %[[RESULT]] - %0 = "tf.Inv"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - func.return %0 : tensor<2xf32> + %0 = "tf.Inv"(%arg0) : (tensor) -> tensor + func.return %0 : tensor } // CHECK-LABEL: func @greater @@ -357,6 +357,16 @@ func.func @atan2_with_symbol_ref(%arg0: tensor<2xf32>) -> tensor<2xf32> { func.return %0 : tensor<2xf32> } +func.func private @branch0(tensor<2xf32>) -> tensor<2xf32> +func.func private @branch1(tensor<2xf32>) -> tensor<2xf32> + +func.func @case_with_symbol_ref(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: tf.Case + // expected-remark@+1 {{ops with symbol references are not supported}} + %0 = "tf.Case"(%arg0, %arg1) {branches = [@branch0, @branch1], is_stateless = false} : (tensor, tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + // CHECK-LABEL: const func.func @const() -> tensor<2xf32> { // CHECK: mhlo.const @@ -364,6 +374,53 @@ func.func @const() -> tensor<2xf32> { func.return %cst : tensor<2xf32> } +// CHECK-LABEL: @bounds_propagation +func.func @bounds_propagation(%input: tensor<4xf32>, %size: tensor) -> tensor { + %dimension = "tf.Const"() { value = dense<0> : tensor } : () -> tensor + // CHECK: %[[BOUNDED:.*]] = "mhlo.set_dimension_size" + // CHECK-SAME: {dimension = 0 : i64} : (tensor<4xf32>, tensor) -> tensor> + %0 = "tf.XlaSetDynamicDimensionSize"(%input, %dimension, %size) : (tensor<4xf32>, tensor, tensor) -> tensor + + %axis = "tf.Const"() { value = dense<0> : tensor<1xi32> } : () -> tensor<1xi32> + // CHECK: %[[REVERSED:.*]] = "mhlo.reverse"(%[[BOUNDED]]) + // CHECK-SAME: {dimensions = dense<0> : tensor<1xi64>} + // CHECK-SAME: (tensor>) -> tensor> + %1 = "tf.ReverseV2"(%0, %axis) : (tensor, tensor<1xi32>) -> tensor + + // CHECK: %[[RESULT:.*]] = tensor.cast %[[REVERSED]] : tensor> to tensor + // CHECK: return %[[RESULT]] : tensor + func.return %1 : tensor +} + +// CHECK-LABEL: @bounds_propagation_skip_symbol_ref_ops +func.func @bounds_propagation_skip_symbol_ref_ops(%input: tensor<4xf32>, %size: tensor) -> tensor { + %dimension = "tf.Const"() { value = dense<0> : tensor } : () -> tensor + // CHECK: %[[BOUNDED:.*]] = "mhlo.set_dimension_size" + // CHECK-SAME: {dimension = 0 : i64} : (tensor<4xf32>, tensor) -> tensor> + %0 = "tf.XlaSetDynamicDimensionSize"(%input, %dimension, %size) : (tensor<4xf32>, tensor, tensor) -> tensor + + // CHECK: %[[ORIGINAL:.*]] = tensor.cast %[[BOUNDED]] : tensor> to tensor + + %axis = "tf.Const"() { value = dense<0> : tensor<1xi32> } : () -> tensor<1xi32> + // CHECK: tf.ReverseV2 + // CHECK-SAME: (tensor, tensor<1xi32>) -> tensor + // expected-remark@+1 {{lowering requires bounded tensor operands}} + %1 = "tf.ReverseV2"(%0, %axis) {_body = @identity} : (tensor, tensor<1xi32>) -> tensor + + func.return %1 : tensor +} + +// CHECK-LABEL: func @set_bound +func.func @set_bound(%arg0: tensor) -> tensor { + %bound = "tf.Const"() {value = dense<16> : tensor} : () -> tensor + + 
// CHECK: %[[RESULT:.*]] = mhlo.custom_call @SetBound(%arg0) {backend_config = "", mhlo.literal = dense<16> : tensor} + %bounded = "tf.XlaSetBound"(%arg0, %bound) : (tensor, tensor) -> tensor + + // CHECK: return %[[RESULT]] + func.return %bounded : tensor +} + // TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is // available but doesn't support this instance. } diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index 02b69842254..256257b6ad1 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -1701,6 +1701,21 @@ func.func @unhandled_partitioned_call_2(%arg0: tensor, %arg1: tensor<*xi32> // ----- +// CHECK-LABEL: func @no_args_and_results +func.func @no_args_and_results() { + // CHECK: call @callee() : () -> () + // CHECK: call @callee() : () -> () + // CHECK: call @callee() : () -> () + "tf.PartitionedCall"() {config = "", config_proto = "", executor_type = "", f = @callee} : () -> () + "tf.StatefulPartitionedCall"() {config = "", config_proto = "", executor_type = "", f = @callee} : () -> () + "tf.LegacyCall"() {config = "", config_proto = "", executor_type = "", f = @callee} : () -> () + func.return +} + +func.func @callee() { + func.return +} + //===----------------------------------------------------------------------===// // ReverseV2 op legalization. //===----------------------------------------------------------------------===// @@ -2406,36 +2421,6 @@ func.func @acos_dynamic(%arg0: tensor<*xf32>) -> tensor<*xf32> { // ----- -// CHECK-LABEL: @tan -// CHECK-SAME: (%[[ARG:.*]]: tensor<2xf32>) -> tensor<2xf32> -// CHLO-LABEL: @tan -// CHLO-SAME: (%[[ARG:.*]]: tensor<2xf32>) -> tensor<2xf32> -func.func @tan(%arg : tensor<2xf32>) -> tensor<2xf32> { - // CHECK: chlo.tan %[[ARG]] : tensor<2xf32> - // CHLO: %[[SINE:.*]] = mhlo.sine %[[ARG]] - // CHLO %[[COSINE:.*]] = mhlo.cosine %[[ARG]] - // CHLO %[[RESULT:.*]] = "mhlo.divide"(%[[SINE]], %[[COSINE]]) - %result = "tf.Tan"(%arg) : (tensor<2xf32>) -> tensor<2xf32> - func.return %result : tensor<2xf32> -} - -// ----- - -// CHECK-LABEL: @tan_unranked -// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> tensor<*xf32> -// CHLO-LABEL: @tan_unranked -// CHLO-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> tensor<*xf32> -func.func @tan_unranked(%arg : tensor<*xf32>) -> tensor<*xf32> { - // CHECK: chlo.tan %[[ARG]] : tensor<*xf32> - // CHLO: %[[SINE:.*]] = mhlo.sine %[[ARG]] - // CHLO %[[COSINE:.*]] = mhlo.cosine %[[ARG]] - // CHLO %[[RESULT:.*]] = "mhlo.divide"(%[[SINE]], %[[COSINE]]) - %result = "tf.Tan"(%arg) : (tensor<*xf32>) -> tensor<*xf32> - func.return %result : tensor<*xf32> -} - -// ----- - // CHECK-LABEL: func @cast_dynamic_i2f func.func @cast_dynamic_i2f(%arg0: tensor) -> tensor { // CHECK: mhlo.convert %arg0 : (tensor) -> tensor @@ -2508,6 +2493,15 @@ func.func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> { // ----- +// CHECK-LABEL: @tan +func.func @tan(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: mhlo.tan %arg0 : tensor<2xf32> + %0 = "tf.Tan"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + // CHECK-LABEL: func @cos_dynamic func.func @cos_dynamic(%arg0: tensor) -> tensor { // CHECK: mhlo.cosine %arg0 : tensor @@ -4488,8 +4482,8 @@ func.func @conv_dynamic(%arg0: tensor, %arg1: tensor<3x3x3x16xf32 // ----- -// CHECK-LABEL: @split_not_match_non_const_split_dim -func.func @split_not_match_non_const_split_dim(%input: tensor<4x4xf32>, %split_dim: 
tensor) -> (tensor<*xf32>, tensor<*xf32>) { +// CHECK-LABEL: @split_not_match_dynamic_split_dim_input +func.func @split_not_match_dynamic_split_dim_input(%input: tensor<4x4xf32>, %split_dim: tensor) -> (tensor<*xf32>, tensor<*xf32>) { // CHECK: tf.Split %0:2 = "tf.Split"(%split_dim, %input) : (tensor, tensor<4x4xf32>) -> (tensor<*xf32>, tensor<*xf32>) func.return %0#0, %0#1 : tensor<*xf32>, tensor<*xf32> @@ -4497,8 +4491,8 @@ func.func @split_not_match_non_const_split_dim(%input: tensor<4x4xf32>, %split_d // ----- -// CHECK-LABEL: @split_not_match_unknown_input_dim -func.func @split_not_match_unknown_input_dim(%input: tensor<4x?x4xf32>) -> (tensor<4x?x4xf32>, tensor<4x?x4xf32>) { +// CHECK-LABEL: @split_not_match_dynamic_input_shape +func.func @split_not_match_dynamic_input_shape(%input: tensor<4x?x4xf32>) -> (tensor<4x?x4xf32>, tensor<4x?x4xf32>) { %cst = "tf.Const"() {value = dense<1> : tensor} : () -> tensor // CHECK: tensor.dim {{.*}} : tensor<4x?x4xf32> // CHECK: arith.divsi {{.*}} : index @@ -4516,6 +4510,25 @@ func.func @split_not_match_unknown_input_dim(%input: tensor<4x?x4xf32>) -> (tens // ----- +// CHECK-LABEL: @split_not_match_static_split_dim_size +func.func @split_not_match_static_split_dim_size(%input: tensor<4x?x4xf32>) -> (tensor<2x?x4xf32>, tensor<2x?x4xf32>) { + %cst = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + // CHECK: tensor.dim {{.*}} : tensor<4x?x4xf32> + // CHECK: arith.divsi {{.*}} : index + // CHECK: tensor.from_elements {{.*}} : tensor<3xindex> + // CHECK: mhlo.real_dynamic_slice {{.*}} : (tensor<4x?x4xf32>, tensor<3xindex>, tensor<3xindex>, tensor<3xindex>) -> tensor<2x?x4xf32> + // CHECK: muli {{.*}} : index + // CHECK: muli {{.*}} : index + // CHECK: tensor.from_elements {{.*}} : tensor<3xindex> + // CHECK: tensor.from_elements {{.*}} : tensor<3xindex> + // CHECK: tensor.from_elements {{.*}} : tensor<3xindex> + // CHECK: mhlo.real_dynamic_slice {{.*}} : (tensor<4x?x4xf32>, tensor<3xindex>, tensor<3xindex>, tensor<3xindex>) -> tensor<2x?x4xf32> + %0:2 = "tf.Split"(%cst, %input) : (tensor, tensor<4x?x4xf32>) -> (tensor<2x?x4xf32>, tensor<2x?x4xf32>) + func.return %0#0, %0#1 : tensor<2x?x4xf32>, tensor<2x?x4xf32> +} + +// ----- + // CHECK-LABEL: @split_match_and_split_into_two func.func @split_match_and_split_into_two(%input: tensor<4x6xf32>) -> (tensor<2x6xf32>, tensor<2x6xf32>) { %cst = "tf.Const"() {value = dense<0> : tensor} : () -> tensor @@ -4528,18 +4541,6 @@ func.func @split_match_and_split_into_two(%input: tensor<4x6xf32>) -> (tensor<2x // ----- -// CHECK-LABEL: @split_match_and_split_into_two_dynamic -func.func @split_match_and_split_into_two_dynamic(%input: tensor<4x?xf32>) -> (tensor<2x?xf32>, tensor<2x?xf32>) { - %cst = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - // CHECK: %[[ONE:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[2, -1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x?xf32>) -> tensor<2x?xf32> - // CHECK: %[[TWO:.*]] = "mhlo.slice"(%{{.*}}) {limit_indices = dense<[4, -1]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x?xf32>) -> tensor<2x?xf32> - %0:2 = "tf.Split"(%cst, %input) : (tensor, tensor<4x?xf32>) -> (tensor<2x?xf32>, tensor<2x?xf32>) - // CHECK: return %[[ONE]], %[[TWO]] - func.return %0#0, %0#1 : tensor<2x?xf32>, tensor<2x?xf32> -} - -// ----- - // CHECK-LABEL: @split_match_and_split_into_three // CHECK-SAME: (%[[ARG:.*]]: tensor<4x6xf32>) func.func 
@split_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>) { @@ -4608,19 +4609,6 @@ func.func @splitv_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor // ----- -// CHECK-LABEL: @splitv_match_and_split_into_three_dynamic -func.func @splitv_match_and_split_into_three_dynamic(%input: tensor) -> (tensor, tensor, tensor) { - %split_sizes = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32> - %split_dim = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - // CHECK: "mhlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor) -> tensor - // CHECK: "mhlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor) -> tensor - // CHECK: "mhlo.slice"(%{{.*}}) {limit_indices = dense<[-1, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor) -> tensor - %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor, tensor<3xi32>, tensor) -> (tensor, tensor, tensor) - func.return %0#0, %0#1, %0#2 : tensor, tensor, tensor -} - -// ----- - // CHECK-LABEL: @splitv_dynamic_dim_in_split_sizes func.func @splitv_dynamic_dim_in_split_sizes(%input: tensor<4x6xf32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) { %split_sizes = "tf.Const"() {value = dense<[1, -1, 3]> : tensor<3xi32>} : () -> tensor<3xi32> @@ -4632,6 +4620,17 @@ func.func @splitv_dynamic_dim_in_split_sizes(%input: tensor<4x6xf32>) -> (tensor func.return %0#0, %0#1, %0#2 : tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32> } +// ----- + +// CHECK-LABEL: @splitv_dynamic +func.func @splitv_dynamic(%input: tensor) -> (tensor, tensor, tensor) { + %split_sizes = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32> + %split_dim = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + // CHECK: tf.SplitV + %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor, tensor<3xi32>, tensor) -> (tensor, tensor, tensor) + func.return %0#0, %0#1, %0#2 : tensor, tensor, tensor +} + //===----------------------------------------------------------------------===// // tf.Assert legalization //===----------------------------------------------------------------------===// @@ -5175,6 +5174,17 @@ func.func @tensor_scatter_max(%tensor: tensor, %indices: tensor tensor { + // CHECK: [[INPUT:%.*]] = mhlo.constant dense<1.000000e+20> : tensor + // CHECK-NEXT: return [[INPUT]] + %cst = "tf.Const"() {value = dense<1.000000e+20> : tensor} : () -> tensor + %0 = "tf.RandomShuffle"(%cst) {device = "", seed = -4294967297 : i64, seed2 = -2147483649 : i64} : (tensor) -> tensor + return %0 : tensor +} + +// ----- + // CHECK-LABEL: @random_shuffle_first_dim_1 // CHECK-SAME: [[INPUT:%.*]]: tensor<1x?xf32> func.func @random_shuffle_first_dim_1(%input: tensor<1x?xf32>) -> tensor<1x?xf32> { @@ -5243,17 +5253,23 @@ func.func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> { // CHECK: mhlo.return [[NEW_IV]], [[ITER_ARG1]], [[INDICES2]] // CHECK: } - // CHECK: [[GATHER:%.*]] = "mhlo.gather"([[INPUT]], [[WHILE_OUT]]#2) + // CHECK: [[CONSTANT1:%.*]] = mhlo.constant dense<1> : tensor<1xi64> + // CHECK: [[ARITH_CONSTANT:%.*]] = arith.constant 1 : index + // CHECK: [[SHAPE_DIM:%.*]] = shape.dim %arg0, [[ARITH_CONSTANT]] : tensor<4x?x16xf32>, index -> index + // CHECK: 
[[INDEX_CAST:%.*]] = arith.index_cast [[SHAPE_DIM]] : index to i64 + // CHECK: [[FROM_ELEMENTS:%.*]] = tensor.from_elements [[INDEX_CAST]] : tensor<1xi64> + // CHECK: [[CONSTANT2:%.*]] = mhlo.constant dense<16> : tensor<1xi64> + // CHECK: [[CONCATENATE:%.*]] = "mhlo.concatenate"([[CONSTANT1]], [[FROM_ELEMENTS]], [[CONSTANT2]]) {dimension = 0 : i64} : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi64> + // CHECK: [[DYNAMIC_GATHER:%.*]] = "mhlo.dynamic_gather"([[INPUT]], [[WHILE_OUT]]#2, [[CONCATENATE]]) // CHECK-SAME: dimension_numbers = // CHECK-SAME: offset_dims = [1, 2] // CHECK-SAME: collapsed_slice_dims = [0] // CHECK-SAME: start_index_map = [0] // CHECK-SAME: index_vector_dim = 1 // CHECK-SAME: indices_are_sorted = false - // CHECK-SAME: slice_sizes = dense<[1, -1, 16]> - // CHECK: (tensor<4x?x16xf32>, tensor<4xi32>) -> tensor<4x?x16xf32> + // CHECK-SAME:: (tensor<4x?x16xf32>, tensor<4xi32>, tensor<3xi64>) -> tensor<4x?x16xf32> - // CHECK: return [[GATHER]] + // CHECK: return [[DYNAMIC_GATHER]] %0 = "tf.RandomShuffle"(%input) : (tensor<4x?x16xf32>) -> (tensor<4x?x16xf32>) func.return %0: tensor<4x?x16xf32> @@ -5703,7 +5719,7 @@ func.func @avgpool_grad_bf16(%grad: tensor<10x12x16x64xbf16>) -> tensor<10x24x32 // CHECK-LABEL: xla_sharding func.func @xla_sharding(%arg0: tensor<4x16xf32>) -> tensor<4x16xf32> { - // CHECK-NEXT: "mhlo.custom_call"(%arg0) {call_target_name = "Sharding", mhlo.sharding = ""} + // CHECK-NEXT: mhlo.custom_call @Sharding(%arg0) {mhlo.sharding = ""} %0 = "tf.XlaSharding"(%arg0) {_XlaSharding = "", sharding = ""} : (tensor<4x16xf32>) -> tensor<4x16xf32> func.return %0 : tensor<4x16xf32> } @@ -5917,21 +5933,6 @@ func.func @cumprod(%arg0: tensor<4xf32>) -> tensor<4xf32> { func.return %1 : tensor<4xf32> } -//===----------------------------------------------------------------------===// -// Qr op legalization -//===----------------------------------------------------------------------===// - -// CHECK: func @qr([[VAL_0:%.*]]: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>) -func.func @qr(%arg0: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>) { - // The tf.Qr lowering is a full algorithm that is not effective to verify with - // FileCheck. Just verify that it converted. - // TODO(laurenzo): Move this out of the mainline tf2xla conversion as it is - // really only applicable to certain legacy uses. 
- // CHECK-NOT: "tf.Qr" - %0:2 = "tf.Qr"(%arg0) {full_matrices = false} : (tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>) - func.return %0#0, %0#1 : tensor<500x100x75xf32>, tensor<500x75x75xf32> -} - //===----------------------------------------------------------------------===// // tf.UniformQuantizedDotHybrid legalization //===----------------------------------------------------------------------===// @@ -5942,17 +5943,268 @@ func.func @quantized_matmul_fn(%input: tensor<*xf32>) -> tensor<*xf32> { %weight_scales = "tf.Const"() { value = dense<1.0> : tensor } : () -> tensor %weight_zps = "tf.Const"() { value = dense<3> : tensor } : () -> tensor - // CHECK: %[[CONST:.*]] = mhlo.constant - // CHECK-SAME{LITERAL}: dense<[[1, 2], [3, 4]]> : tensor<2x2xi8> + // CHECK: %[[CONST:.*]] = mhlo.constant() + // CHECK-SAME{LITERAL}: value = dense<[[1, 2], [3, 4]]> : tensor<2x2xi8> // CHECK-SAME: tensor<2x2x!quant.uniform> - // CHECK: %[[ADD:.*]] = chlo.broadcast_add - // CHECK: "mhlo.dot"(%[[ADD]], %[[CONST]]) : (tensor<*xf32>, tensor<2x2x!quant.uniform>) -> tensor<*xf32> + // CHECK: "mhlo.dot"(%arg0, %[[CONST]]) : (tensor<*xf32>, tensor<2x2x!quant.uniform>) -> tensor<*xf32> + + %0 = "tf.UniformQuantizedDotHybrid"(%input, %weight, %weight_scales, %weight_zps) {rhs_quantization_axis = -1 : i64, rhs_quantization_min_val = -128 : i64, rhs_quantization_max_val = 127 : i64} : (tensor<*xf32>, tensor<2x2x!tf_type.qint8>, tensor, tensor) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +//===----------------------------------------------------------------------===// +// tf.UniformQuantizedConvolutionHybrid legalization +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @uniform_quantized_convolution_hybrid +func.func @uniform_quantized_convolution_hybrid(%input: tensor<1x2x2x3xf32>) -> tensor<*xf32> { + %weight = "tf.Const"() {value = #tf_type : tensor<2x3x3x2x!tf_type.qint8>} : () -> tensor<2x3x3x2x!tf_type.qint8> + %weight_scales = "tf.Const"() { value = dense<1.0> : tensor } : () -> tensor + %weight_zps = "tf.Const"() { value = dense<3> : tensor } : () -> tensor + + // CHECK: %[[CONST:.*]] = mhlo.constant() + // CHECK-SAME{LITERAL} value = dense<127> : tensor<2x3x3x2xi8> + // CHECK-SAME: tensor<2x3x3x2x!quant.uniform> + // CHECK: mhlo.convolution(%arg0, %[[CONST]]) + // CHECK-SAME{LITERAL}: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] + // CHECK-SAME{LITERAL}: window = {stride = [1, 2], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [2, 2]} + // CHECK-SAME{LITERAL}: batch_group_count = 1 : i64, feature_group_count = 1 : i64 + // CHECK-SAME: (tensor<1x2x2x3xf32>, tensor<2x3x3x2x!quant.uniform>) -> tensor<*xf32> + + %0 = "tf.UniformQuantizedConvolutionHybrid"(%input, %weight, %weight_scales, %weight_zps) { + window_strides = [1, 2], + padding = "VALID", + explicit_padding = [], + lhs_dilation = [1, 1], + rhs_dilation = [2, 2], + batch_group_count = 1 : i64, + feature_group_count = 1 : i64, + dimension_numbers = "\10\03\1A\02\01\02 \02(\032\02\00\01@\03J\02\01\02", + rhs_quantization_axis = -1 : i64, + rhs_quantization_min_val = -128 : i64, + rhs_quantization_max_val = 127 : i64 + } : (tensor<1x2x2x3xf32>, tensor<2x3x3x2x!tf_type.qint8>, tensor, tensor) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} - %0 = "tf.AddV2"(%input, %input) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> - %1 = "tf.UniformQuantizedDotHybrid"(%0, %weight, %weight_scales, %weight_zps) {rhs_quantization_axis = 
-1 : i64, rhs_quantization_min_val = -128 : i64, rhs_quantization_max_val = 127 : i64} : (tensor<*xf32>, tensor<2x2x!tf_type.qint8>, tensor, tensor) -> tensor<*xf32> +// ----- + +// : func @uniform_quantized_convolution_hybrid_same +func.func @uniform_quantized_convolution_hybrid_same(%input: tensor<1x2x2x3xf32>) -> tensor<*xf32> { + %weight = "tf.Const"() {value = #tf_type : tensor<2x3x3x2x!tf_type.qint8>} : () -> tensor<2x3x3x2x!tf_type.qint8> + %weight_scales = "tf.Const"() { value = dense<1.0> : tensor } : () -> tensor + %weight_zps = "tf.Const"() { value = dense<3> : tensor } : () -> tensor + + // CHECK: %[[CONST:.*]] = mhlo.constant() + // CHECK-SAME{LITERAL} value = dense<127> : tensor<2x3x3x2xi8> + // CHECK-SAME: tensor<2x3x3x2x!quant.uniform> + // CHECK: mhlo.convolution(%arg0, %[[CONST]]) + // CHECK-SAME{LITERAL}: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] + // CHECK-SAME{LITERAL}: window = {stride = [1, 2], pad = [[1, 1], [2, 1]], lhs_dilate = [1, 1], rhs_dilate = [2, 2]} + // CHECK-SAME{LITERAL}: batch_group_count = 1 : i64, feature_group_count = 1 : i64 + // CHECK-SAME: (tensor<1x2x2x3xf32>, tensor<2x3x3x2x!quant.uniform>) -> tensor<*xf32> + + %0 = "tf.UniformQuantizedConvolutionHybrid"(%input, %weight, %weight_scales, %weight_zps) { + window_strides = [1, 2], + padding = "SAME", + explicit_padding = [], + lhs_dilation = [1, 1], + rhs_dilation = [2, 2], + batch_group_count = 1 : i64, + feature_group_count = 1 : i64, + dimension_numbers = "\10\03\1A\02\01\02 \02(\032\02\00\01@\03J\02\01\02", + rhs_quantization_axis = -1 : i64, + rhs_quantization_min_val = -128 : i64, + rhs_quantization_max_val = 127 : i64 + } : (tensor<1x2x2x3xf32>, tensor<2x3x3x2x!tf_type.qint8>, tensor, tensor) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +// ----- + +// : func @uniform_quantized_convolution_hybrid_explicit +func.func @uniform_quantized_convolution_hybrid_explicit(%input: tensor<1x2x2x3xf32>) -> tensor<*xf32> { + %weight = "tf.Const"() {value = #tf_type : tensor<2x3x3x2x!tf_type.qint8>} : () -> tensor<2x3x3x2x!tf_type.qint8> + %weight_scales = "tf.Const"() { value = dense<1.0> : tensor } : () -> tensor + %weight_zps = "tf.Const"() { value = dense<3> : tensor } : () -> tensor + + // CHECK: %[[CONST:.*]] = mhlo.constant() + // CHECK-SAME{LITERAL} value = dense<127> : tensor<2x3x3x2xi8> + // CHECK-SAME: tensor<2x3x3x2x!quant.uniform> + // CHECK: mhlo.convolution(%arg0, %[[CONST]]) + // CHECK-SAME{LITERAL}: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] + // CHECK-SAME{LITERAL}: window = {stride = [1, 2], pad = [[1, 2], [3, 4]], lhs_dilate = [1, 1], rhs_dilate = [2, 2]} + // CHECK-SAME{LITERAL}: batch_group_count = 1 : i64, feature_group_count = 1 : i64 + // CHECK-SAME: (tensor<1x2x2x3xf32>, tensor<2x3x3x2x!quant.uniform>) -> tensor<*xf32> + + %0 = "tf.UniformQuantizedConvolutionHybrid"(%input, %weight, %weight_scales, %weight_zps) { + window_strides = [1, 2], + padding = "EXPLICIT", + explicit_padding = [1, 2, 3, 4], + lhs_dilation = [1, 1], + rhs_dilation = [2, 2], + batch_group_count = 1 : i64, + feature_group_count = 1 : i64, + dimension_numbers = "\10\03\1A\02\01\02 \02(\032\02\00\01@\03J\02\01\02", + rhs_quantization_axis = -1 : i64, + rhs_quantization_min_val = -128 : i64, + rhs_quantization_max_val = 127 : i64 + } : (tensor<1x2x2x3xf32>, tensor<2x3x3x2x!tf_type.qint8>, tensor, tensor) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +//===----------------------------------------------------------------------===// +// tf.UniformQuantize and 
tf.UniformDequantize legalization +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @uniform_quantize_and_dequantize +func.func @uniform_quantize_and_dequantize(%arg0 : tensor<*xf32>) -> tensor<*xf32> { + %scales = "tf.Const"() { value = dense<1.0> : tensor } : () -> tensor + %zps = "tf.Const"() { value = dense<3> : tensor } : () -> tensor + + // CHECK: %[[QUANTIZE:.*]] = mhlo.uniform_quantize %arg0 : (tensor<*xf32>) -> tensor<*x!quant.uniform> + // CHECK: %[[DEQUANTIZE:.*]] = mhlo.uniform_dequantize %[[QUANTIZE]] : (tensor<*x!quant.uniform>) -> tensor<*xf32> + // CHECK: return %[[DEQUANTIZE]] : tensor<*xf32> + + %0 = "tf.UniformQuantize"(%arg0, %scales, %zps) { + quantization_axis = -1 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64 + } : (tensor<*xf32>, tensor, tensor) -> tensor<*x!tf_type.qint8> + %1 = "tf.UniformDequantize"(%0, %scales, %zps) { + quantization_axis = -1 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64 + } : (tensor<*x!tf_type.qint8>, tensor, tensor) -> tensor<*xf32> func.return %1 : tensor<*xf32> } +// ----- + +// CHECK-LABEL: func @uniform_quantize_and_dequantize_per_axis +func.func @uniform_quantize_and_dequantize_per_axis(%arg0 : tensor<2x2xf32>) -> tensor<2x2xf32> { + %scales = "tf.Const"() { value = dense<[1.0, 2.0]> : tensor<2xf32> } : () -> tensor<2xf32> + %zps = "tf.Const"() { value = dense<[3, 4]> : tensor<2xi32> } : () -> tensor<2xi32> + + // CHECK: %[[QUANTIZE:.*]] = mhlo.uniform_quantize %arg0 : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> + // CHECK: %[[DEQUANTIZE:.*]] = mhlo.uniform_dequantize %[[QUANTIZE]] : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> + // CHECK: return %[[DEQUANTIZE]] : tensor<2x2xf32> + + %0 = "tf.UniformQuantize"(%arg0, %scales, %zps) { + quantization_axis = 0 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64 + } : (tensor<2x2xf32>, tensor<2xf32>, tensor<2xi32>) -> tensor<2x2x!tf_type.qint8> + %1 = "tf.UniformDequantize"(%0, %scales, %zps) { + quantization_axis = 0 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64 + } : (tensor<2x2x!tf_type.qint8>, tensor<2xf32>, tensor<2xi32>) -> tensor<2x2xf32> + func.return %1 : tensor<2x2xf32> +} + +//===----------------------------------------------------------------------===// +// tf.UniformRequantize legalization +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @uniform_quantize_requantize_and_dequantize +func.func @uniform_quantize_requantize_and_dequantize(%arg0 : tensor<*xf32>) -> tensor<*xf32> { + %scales_0 = "tf.Const"() { value = dense<1.0> : tensor } : () -> tensor + %zps_0 = "tf.Const"() { value = dense<3> : tensor } : () -> tensor + %scales_1 = "tf.Const"() { value = dense<2.0> : tensor } : () -> tensor + %zps_1 = "tf.Const"() { value = dense<5> : tensor } : () -> tensor + + // CHECK: %[[QUANTIZE:.*]] = mhlo.uniform_quantize %arg0 : (tensor<*xf32>) -> tensor<*x!quant.uniform> + // CHECK: %[[REQUANTIZE:.*]] = mhlo.uniform_quantize %[[QUANTIZE]] : (tensor<*x!quant.uniform>) -> tensor<*x!quant.uniform> + // CHECK: %[[DEQUANTIZE:.*]] = mhlo.uniform_dequantize %[[REQUANTIZE]] : (tensor<*x!quant.uniform>) -> tensor<*xf32> + // CHECK: return %[[DEQUANTIZE]] : tensor<*xf32> + + %0 = "tf.UniformQuantize"(%arg0, %scales_0, %zps_0) { + quantization_axis = -1 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64 + } : (tensor<*xf32>, 
tensor, tensor) -> tensor<*x!tf_type.qint8> + %1 = "tf.UniformRequantize"(%0, %scales_0, %zps_0, %scales_1, %zps_1) { + input_quantization_axis = -1 : i64, input_quantization_min_val = -128 : i64, input_quantization_max_val = 127 : i64, + output_quantization_axis = -1 : i64, output_quantization_min_val = -128 : i64, output_quantization_max_val = 127 : i64 + } : (tensor<*x!tf_type.qint8>, tensor, tensor, tensor, tensor) -> tensor<*x!tf_type.qint8> + %2 = "tf.UniformDequantize"(%1, %scales_1, %zps_1) { + quantization_axis = -1 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64 + } : (tensor<*x!tf_type.qint8>, tensor, tensor) -> tensor<*xf32> + func.return %2 : tensor<*xf32> +} + +// ----- + +// CHECK-LABEL: func @uniform_quantize_requantize_and_dequantize_per_axis +func.func @uniform_quantize_requantize_and_dequantize_per_axis(%arg0 : tensor<2x2xf32>) -> tensor<2x2xf32> { + %scales_0 = "tf.Const"() { value = dense<[1.0, 2.0]> : tensor<2xf32> } : () -> tensor<2xf32> + %zps_0 = "tf.Const"() { value = dense<[3, 4]> : tensor<2xi32> } : () -> tensor<2xi32> + %scales_1 = "tf.Const"() { value = dense<[3.0, 4.0]> : tensor<2xf32> } : () -> tensor<2xf32> + %zps_1 = "tf.Const"() { value = dense<[5, 6]> : tensor<2xi32> } : () -> tensor<2xi32> + + // CHECK: %[[QUANTIZE:.*]] = mhlo.uniform_quantize %arg0 : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> + // CHECK: %[[REQUANTIZE:.*]] = mhlo.uniform_quantize %[[QUANTIZE]] : (tensor<2x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> + // CHECK: %[[DEQUANTIZE:.*]] = mhlo.uniform_dequantize %[[REQUANTIZE]] : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> + // CHECK: return %[[DEQUANTIZE]] : tensor<2x2xf32> + + %0 = "tf.UniformQuantize"(%arg0, %scales_0, %zps_0) { + quantization_axis = 0 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64 + } : (tensor<2x2xf32>, tensor<2xf32>, tensor<2xi32>) -> tensor<2x2x!tf_type.qint8> + %1 = "tf.UniformRequantize"(%0, %scales_0, %zps_0, %scales_1, %zps_1) { + input_quantization_axis = 0 : i64, input_quantization_min_val = -128 : i64, input_quantization_max_val = 127 : i64, + output_quantization_axis = 0 : i64, output_quantization_min_val = -128 : i64, output_quantization_max_val = 127 : i64 + } : (tensor<2x2x!tf_type.qint8>, tensor<2xf32>, tensor<2xi32>, tensor<2xf32>, tensor<2xi32>) -> tensor<2x2x!tf_type.qint8> + %2 = "tf.UniformDequantize"(%1, %scales_1, %zps_1) { + quantization_axis = 0 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64 + } : (tensor<2x2x!tf_type.qint8>, tensor<2xf32>, tensor<2xi32>) -> tensor<2x2xf32> + func.return %2 : tensor<2x2xf32> +} + +//===----------------------------------------------------------------------===// +// tf.UniformQuantizedDot legalization +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @uniform_quantized_dot +func.func @uniform_quantized_dot(%input: tensor<*xf32>) -> () { + %input_scales = "tf.Const"() { value = dense<2.0> : tensor } : () -> tensor + %input_zps = "tf.Const"() { value = dense<4> : tensor } : () -> tensor + + %weight = "tf.Const"() { value = #tf_type : tensor<2x2x!tf_type.qint8> } : () -> tensor<2x2x!tf_type.qint8> + %weight_scales = "tf.Const"() { value = dense<1.0> : tensor } : () -> tensor + %weight_zps = "tf.Const"() { value = dense<3> : tensor } : () -> tensor + + %output_scales = "tf.Const"() { value = dense<3.0> : tensor } : () -> tensor + %output_zps = "tf.Const"() { value = dense<5> : tensor } : () -> tensor + + // 
CHECK-DAG: %[[LHS:.*]] = mhlo.uniform_quantize %arg0 : (tensor<*xf32>) -> tensor<*x!quant.uniform> + // CHECK-DAG: %[[RHS:.*]] = mhlo.constant() + // CHECK-SAME{LITERAL}: {value = dense<[[1, 2], [3, 4]]> : tensor<2x2xi8>} : () -> tensor<2x2x!quant.uniform> + // CHECK: "mhlo.dot"(%[[LHS]], %[[RHS]]) : (tensor<*x!quant.uniform>, tensor<2x2x!quant.uniform>) + // CHECK-SAME: -> tensor<*x!quant.uniform> + + %0 = "tf.UniformQuantize"(%input, %input_scales, %input_zps) { + quantization_axis = -1 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64 + } : (tensor<*xf32>, tensor, tensor) -> tensor<*x!tf_type.qint8> + %1 = "tf.UniformQuantizedDot"( + %0, %weight, + %input_scales, %input_zps, + %weight_scales, %weight_zps, + %output_scales, %output_zps) { + lhs_quantization_axis = -1 : i64, + lhs_quantization_min_val = -128 : i64, + lhs_quantization_max_val = 127 : i64, + rhs_quantization_axis = -1 : i64, + rhs_quantization_min_val = -128 : i64, + rhs_quantization_max_val = 127 : i64, + output_quantization_axis = -1 : i64, + output_quantization_min_val = -2147483648 : i64, + output_quantization_max_val = 2147483647 : i64} : ( + tensor<*x!tf_type.qint8>, tensor<2x2x!tf_type.qint8>, + tensor, tensor, + tensor, tensor, + tensor, tensor) -> tensor<*x!tf_type.qint32> + func.return +} + //===----------------------------------------------------------------------===// // tf.Softplus legalization //===----------------------------------------------------------------------===// @@ -6533,3 +6785,231 @@ func.func @test_xla_optimization_barrier(%arg0: tensor<4x4xf32>, %arg1: tensor<3 %0, %1 = "tf.XlaOptimizationBarrier"(%arg0, %arg1) : (tensor<4x4xf32>, tensor<3x4xi32>) -> (tensor<4x4xf32>, tensor<3x4xi32>) func.return %0, %1 : tensor<4x4xf32>, tensor<3x4xi32> } + +// CHECK-LABEL: @ifRegion +// CHECK-SAME: ([[ARG0:%.+]]: tensor, [[ARG1:%.+]]: tensor) +func.func @ifRegion(%arg0: tensor, %arg1: tensor) -> (tensor) { + // CHECK: [[VAL0:%.+]] = mhlo.compare GT, [[ARG0]], [[ARG1]] + %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + // CHECK: [[VAL1:%.+]] = "mhlo.if"([[VAL0]]) ({ + %1 = "tf.IfRegion"(%0) ({ + // CHECK: [[VAL2:%.+]] = mhlo.log [[ARG0]] + %2 = "tf.Log"(%arg0) : (tensor) -> tensor + // CHECK: mhlo.return [[VAL2]] + "tf.Yield"(%2) : (tensor) -> () + }, { + // CHECK: [[VAL3:%.+]] = mhlo.exponential [[ARG1]] + %2 = "tf.Exp"(%arg1) : (tensor) -> tensor + // CHECK: mhlo.return [[VAL3]] + "tf.Yield"(%2) : (tensor) -> () + // CHECK: }) : (tensor) -> tensor + }) {is_stateless = true} : (tensor) -> tensor + // CHECK: return [[VAL1]] + func.return %1 : tensor +} + +// CHECK-LABEL: func @caseRegion +// CHECK-SAME: ([[BRANCH_INDEX:%.+]]: tensor, [[ARG0:.+]]: tensor, [[ARG1:%.+]]: tensor) +func.func @caseRegion(%index: tensor, %arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + // CHECK: [[VAL1:%.+]]:2 = "mhlo.case"([[BRANCH_INDEX]]) ({ + %0:2 = "tf.CaseRegion"(%index) ({ + // CHECK: [[VAL2:%.+]] = mhlo.exponential [[ARG1]] + %1 = mhlo.exponential %arg1 : (tensor) -> tensor + // CHECK: mhlo.return [[VAL2]], [[ARG1]] + "tf.Yield"(%1, %arg1) : (tensor, tensor) -> () + }, { + // CHECK: [[VAL3:%.+]] = mhlo.log [[ARG0]] + %1 = mhlo.log %arg0 : (tensor) -> tensor + // CHECK: mhlo.return [[VAL3]], [[ARG1]] + "tf.Yield"(%1, %arg1) : (tensor, tensor) -> () + }, { + // CHECK: [[VAL4:%.+]] = mhlo.floor [[ARG0]] + %1 = "mhlo.floor"(%arg0) : (tensor) -> tensor + // CHECK: mhlo.return [[VAL4]], [[ARG1]] + "tf.Yield"(%1, %arg1) : (tensor, tensor) -> () + // CHECK: 
}) : (tensor) -> (tensor, tensor) + }) {is_stateless = true} : (tensor) -> (tensor, tensor) + // CHECK: return [[VAL1]]#0, [[VAL1]]#1 : tensor, tensor + func.return %0#0, %0#1 : tensor, tensor +} + +// ----- + +// This test case also ensures the mhlo dialect is loaded as a dependency by the +// pass and hence the split here. + +// CHECK-LABEL: func @whileRegion +func.func @whileRegion() -> tensor { + %0 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %2:3 = "tf.WhileRegion"(%0, %1, %0) ({ + ^cond(%carg0: tensor, %carg1: tensor, %carg2: tensor): + %3 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + "tf.Yield"(%3) : (tensor) -> () + }, { + ^body(%barg0: tensor, %barg1: tensor, %barg2: tensor): + %4 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + "tf.Yield"(%4, %4, %4) : (tensor, tensor, tensor) -> () + }) {is_stateless = true, parallel_iterations = 10 : i64} : (tensor, tensor, tensor) -> (tensor, tensor, tensor) + func.return %2#2 : tensor +} + +// ----- + +// CHECK-LABEL: func @whileRegion +func.func @whileRegion() -> tensor { + // CHECK: [[VAL0:%.+]] = mhlo.constant + %0 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + // CHECK: [[VAL1:%.+]] = mhlo.constant + %1 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + // CHECK: [[VAL2:%.+]]:3 = mhlo.while([[ITER_ARG0:.*]] = [[VAL0]], [[ITER_ARG1:.*]] = [[VAL1]], [[ITER_ARG2:.*]] = [[VAL0]]) + %2:3 = "tf.WhileRegion"(%0, %1, %0) ({ + ^cond(%carg0: tensor, %carg1: tensor, %carg2: tensor): + // CHECK: [[VAL3:%.+]] = mhlo.constant + %3 = "tf.Const"() {value = dense<10> : tensor} : () -> tensor + // CHECK: [[VAL4:%.+]] = mhlo.compare LT, [[ITER_ARG2]], [[VAL3]] + %4 = "mhlo.compare"(%carg2, %3) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + // CHECK: mhlo.return [[VAL4]] + "tf.Yield"(%4) : (tensor) -> () + }, { + ^body(%barg0: tensor, %barg1: tensor, %barg2: tensor): + // CHECK: [[VAL5:%.+]] = mhlo.constant + %5 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + // CHECK: [[VAL6:%.+]] = mhlo.add [[ITER_ARG2]], [[VAL5]] + %6 = mhlo.add %barg2, %5 : tensor + // CHECK: [[VAL7:%.+]] = mhlo.add [[ITER_ARG0]], [[VAL5]] + %7 = mhlo.add %barg0, %5 : tensor + // CHECK: mhlo.return [[VAL7]], [[ITER_ARG1]], [[VAL6]] + "tf.Yield"(%7, %barg1, %6) : (tensor, tensor, tensor) -> () + }) {is_stateless = true, parallel_iterations = 10 : i64} : (tensor, tensor, tensor) -> (tensor, tensor, tensor) + // CHECK: return [[VAL2]]#2 + func.return %2#2 : tensor +} + +// ----- + +// CHECK-LABEL: func @whileRegionImplicitInputs +// CHECK-SAME: ([[ARG0:%.+]]: tensor) +func.func @whileRegionImplicitInputs(%arg0: tensor) -> tensor { + // CHECK: [[VAL0:%.+]] = mhlo.constant dense<0> + %0 = mhlo.constant dense<0> : tensor + // CHECK: [[VAL1:%.+]] = mhlo.constant dense<-1> + %1 = mhlo.constant dense<-1> : tensor + // CHECK: [[VAL2:%.+]] = mhlo.while([[ITER_ARG0:.*]] = [[ARG0]]) + %2 = "tf.WhileRegion"(%arg0) ({ + ^cond(%carg0: tensor): + // CHECK: [[VAL3:%.+]] = mhlo.compare LT, [[ITER_ARG0]], [[VAL0]] + %3 = mhlo.compare LT, %carg0, %0 : (tensor, tensor) -> tensor + // CHECK: mhlo.return [[VAL3]] + "tf.Yield"(%3) : (tensor) -> () + }, { + ^body(%barg0: tensor): + // CHECK: [[VAL3:%.+]] = mhlo.add [[ITER_ARG0]], [[VAL1]] + %3 = mhlo.add %barg0, %1 : tensor + // CHECK: [[VAL4:%.+]] = mhlo.add [[ITER_ARG0]], [[VAL3]] + %4 = mhlo.add %barg0, %3 : tensor + // CHECK: mhlo.return [[VAL4]] + "tf.Yield"(%4) : (tensor) -> () + }) {is_stateless = true, 
parallel_iterations = 10 : i64} : (tensor) -> tensor + // CHECK: return [[VAL2]] + func.return %2 : tensor +} + +// CHECK-LABEL: func @whileRegionMultipleImplicitInputs +func.func @whileRegionMultipleImplicitInputs() { + // CHECK: [[VAL0:%.+]] = mhlo.constant dense<0> + %0 = mhlo.constant dense<0> : tensor + // CHECK: [[VAL1:%.+]] = mhlo.constant dense<-1> + %1 = mhlo.constant dense<-1> : tensor + // CHECK: mhlo.while() + "tf.WhileRegion"() ({ + // CHECK: [[VAL3:%.+]] = mhlo.compare LT, [[VAL0]], [[VAL1]] + %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + // CHECK: mhlo.return [[VAL3]] + "tf.Yield"(%2) : (tensor) -> () + }, { + // CHECK: [[VAL3:%.+]] = mhlo.add [[VAL0]], [[VAL1]] + %2 = mhlo.add %0, %1 : tensor + // CHECK: mhlo.return + "tf.Yield"() : () -> () + }) {is_stateless = true, parallel_iterations = 10 : i64} : () -> () + // CHECK: return + func.return +} + +//===----------------------------------------------------------------------===// +// quant.uniform type handling with control flow ops +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @while_region_with_quant +func.func @while_region_with_quant(%arg0: tensor<*xf32>) -> tensor<*xf32> { + %scales = "tf.Const"() { value = dense<1.0> : tensor } : () -> tensor + %zps = "tf.Const"() { value = dense<3> : tensor } : () -> tensor + + // CHECK: %[[QUANT0:.*]] = mhlo.uniform_quantize %[[ARG:.*]] : (tensor<*xf32>) -> tensor<*x!quant.uniform> + // CHECK: %[[QUANT1:.*]] = mhlo.while(%[[ITER_ARG:.*]] = %[[QUANT0]]) : tensor<*x!quant.uniform> + // CHECK: mhlo.return %[[ITER_ARG]] : tensor<*x!quant.uniform> + // CHECK: %[[RET:.*]] = mhlo.uniform_dequantize %[[QUANT1]] : (tensor<*x!quant.uniform>) -> tensor<*xf32> + // CHECK: return %[[RET]] : tensor<*xf32> + + %0 = "tf.UniformQuantize"(%arg0, %scales, %zps) { + quantization_axis = -1 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64 + } : (tensor<*xf32>, tensor, tensor) -> tensor<*x!tf_type.qint8> + %1 = "tf.WhileRegion"(%0) ({ + ^bb0(%carg0: tensor<*x!tf_type.qint8>): + %cst = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + "tf.Yield"(%cst) : (tensor) -> () + }, { + ^bb0(%barg0: tensor<*x!tf_type.qint8>): + %id = "tf.Identity"(%barg0) : (tensor<*x!tf_type.qint8>) -> tensor<*x!tf_type.qint8> + "tf.Yield"(%id) : (tensor<*x!tf_type.qint8>) -> () + }) {is_stateless = false} : (tensor<*x!tf_type.qint8>) -> tensor<*x!tf_type.qint8> + %2 = "tf.UniformDequantize"(%1, %scales, %zps) {quantization_axis = -1 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64} : (tensor<*x!tf_type.qint8>, tensor, tensor) -> tensor<*xf32> + func.return %2 : tensor<*xf32> +} + +// CHECK-LABEL: func @while_region_with_quant_two_args +func.func @while_region_with_quant_two_args(%arg0: tensor<2x2xf32>) -> (tensor<2x?xf32>, tensor) { + %scales = "tf.Const"() { value = dense<1.0> : tensor } : () -> tensor + %zps2 = "tf.Const"() { value = dense<2> : tensor } : () -> tensor + %zps4 = "tf.Const"() { value = dense<4> : tensor } : () -> tensor + + + // CHECK: %[[QUANT0:.*]] = mhlo.uniform_quantize %[[ARG:.*]] : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> + %0 = "tf.UniformQuantize"(%arg0, %scales, %zps2) { + quantization_axis = -1 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64 + } : (tensor<2x2xf32>, tensor, tensor) -> tensor<2x2x!tf_type.qint8> + // CHECK: %[[QUANT1:.*]] = mhlo.uniform_quantize %[[ARG:.*]] : (tensor<2x2xf32>) -> 
tensor<2x2x!quant.uniform> + %1 = "tf.UniformQuantize"(%arg0, %scales, %zps4) { + quantization_axis = -1 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64 + } : (tensor<2x2xf32>, tensor, tensor) -> tensor<2x2x!tf_type.qint8> + + // CHECK: %[[WHILE_RESULT:.*]]:2 = mhlo.while(%[[ARG0:.*]] = %[[QUANT0]], %[[ARG1:.*]] = %[[QUANT1]]) + // CHECK-SAME: tensor<2x2x!quant.uniform>, tensor<2x2x!quant.uniform> + + // CHECK: cond + + // CHECK: do + // CHECK: mhlo.return %[[ARG0]], %[[ARG1]] : tensor<2x?x!quant.uniform>, tensor> + + %2:2 = "tf.WhileRegion"(%0, %1) ({ + ^bb0(%carg0: tensor<2x?x!tf_type.qint8>, %carg1: tensor): + %cst = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + "tf.Yield"(%cst) : (tensor) -> () + }, { + ^bb0(%barg0: tensor<2x?x!tf_type.qint8>, %barg1: tensor): + %id = "tf.Identity"(%barg0) : (tensor<2x?x!tf_type.qint8>) -> tensor<2x?x!tf_type.qint8> + "tf.Yield"(%id, %barg1) : (tensor<2x?x!tf_type.qint8>, tensor) -> () + }) {is_stateless = false} : (tensor<2x2x!tf_type.qint8>, tensor<2x2x!tf_type.qint8>) -> (tensor<2x?x!tf_type.qint8>, tensor) + + // %[[RESULT0:.*]] = mhlo.uniform_dequantize %[[WHILE_RESULT]]#0 + %3 = "tf.UniformDequantize"(%2#0, %scales, %zps2) {quantization_axis = -1 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64} : (tensor<2x?x!tf_type.qint8>, tensor, tensor) -> tensor<2x?xf32> + + // %[[RESULT1:.*]] = mhlo.uniform_dequantize %[[WHILE_RESULT]]#0 + %4 = "tf.UniformDequantize"(%2#1, %scales, %zps4) {quantization_axis = -1 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64} : (tensor, tensor, tensor) -> tensor + + // return %[[RESULT0]], %[[RESULT1]] + func.return %3, %4 : tensor<2x?xf32>, tensor +} diff --git a/tensorflow/compiler/mlir/xla/tests/verify-tfxla-legalization-no-chlo.mlir b/tensorflow/compiler/mlir/xla/tests/verify-tfxla-legalization-no-chlo.mlir new file mode 100644 index 00000000000..b4f83b39f4a --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/verify-tfxla-legalization-no-chlo.mlir @@ -0,0 +1,11 @@ +// RUN: xla-opt "-tfxla-verify-legalization=legalize-chlo=false" -verify-diagnostics -split-input-file %s | FileCheck %s --dump-input=fail +// Tests the VerifyTFXLALegalization Pass, that just ensures we don't have +// any illegal ops at the end of the pipeline. This runs with +// legalize-chlo=false since errors can't be mixed with the legalize-chlo=True +// version. + +// CHECK-LABEL: allows_chlo +func.func @allows_chlo(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { + %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> + func.return %0 : tensor<1x32x10x32xi32> +} diff --git a/tensorflow/compiler/mlir/xla/tests/verify-tfxla-legalization.mlir b/tensorflow/compiler/mlir/xla/tests/verify-tfxla-legalization.mlir new file mode 100644 index 00000000000..2e86dd5ea06 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/verify-tfxla-legalization.mlir @@ -0,0 +1,39 @@ +// RUN: xla-opt "-tfxla-verify-legalization=legalize-chlo=true" -verify-diagnostics -split-input-file %s | FileCheck -dump-input=fail %s +// Tests the VerifyTFXLALegalization Pass, that just ensures we don't have +// any illegal ops at the end of the pipeline. 
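+// The first function is the positive case: a module that is already pure
+// MHLO is expected to pass verification. The split inputs that follow keep
+// a tf.* or chlo.* op around and expect a "Could not legalize op"
+// diagnostic.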
+ +// CHECK-LABEL: allowsMHLO +func.func @allowsMHLO() -> (tensor<8x64x32x4xcomplex> {mhlo.sharding = ""}) { + %0 = mhlo.constant dense<(1.000000e+00,-1.000000e+00)> : tensor<128x32x4xcomplex> + %1 = mhlo.constant dense<(1.000000e+00,1.000000e+00)> : tensor<8x64x128xcomplex> + %2 = "mhlo.einsum"(%1, %0) {einsum_config = "abc,cde->abde"} : (tensor<8x64x128xcomplex>, tensor<128x32x4xcomplex>) -> tensor<8x64x32x4xcomplex> + return %2 : tensor<8x64x32x4xcomplex> +} + +// ----- + +func.func @invalid_non_mhlo() -> (tensor<8x64x32x4xcomplex> {mhlo.sharding = ""}) { + // expected-error @+1 {{Could not legalize op: tf.Const}} + %cst = "tf.Const"() {value = dense<(1.000000e+00,-1.000000e+00)> : tensor<128x32x4xcomplex>} : () -> tensor<128x32x4xcomplex> + %cst_0 = "tf.Const"() {value = dense<(1.000000e+00,1.000000e+00)> : tensor<8x64x128xcomplex>} : () -> tensor<8x64x128xcomplex> + %0 = "tf.XlaEinsum"(%cst_0, %cst) {equation = "abc,cde->abde"} : (tensor<8x64x128xcomplex>, tensor<128x32x4xcomplex>) -> tensor<8x64x32x4xcomplex> + return %0 : tensor<8x64x32x4xcomplex> +} + +// ----- + +func.func @invalid_mixed_mhlo() -> (tensor<8x64x32x4xcomplex> {mhlo.sharding = ""}) { + %0 = mhlo.constant dense<(1.000000e+00,-1.000000e+00)> : tensor<128x32x4xcomplex> + // expected-error @+1 {{Could not legalize op: tf.Const}} + %cst_0 = "tf.Const"() {value = dense<(1.000000e+00,1.000000e+00)> : tensor<8x64x128xcomplex>} : () -> tensor<8x64x128xcomplex> + %1 = "tf.XlaEinsum"(%cst_0, %0) {equation = "abc,cde->abde"} : (tensor<8x64x128xcomplex>, tensor<128x32x4xcomplex>) -> tensor<8x64x32x4xcomplex> + return %1 : tensor<8x64x32x4xcomplex> +} + +// ----- + +func.func @fails_chlo(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { + // expected-error @+1 {{Could not legalize op: chlo.broadcast_add}} + %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> + func.return %0 : tensor<1x32x10x32xi32> +} diff --git a/tensorflow/compiler/mlir/xla/transforms/adjust_layout.cc b/tensorflow/compiler/mlir/xla/transforms/adjust_layout.cc index 5be2e9dbaa2..ae9974146ed 100644 --- a/tensorflow/compiler/mlir/xla/transforms/adjust_layout.cc +++ b/tensorflow/compiler/mlir/xla/transforms/adjust_layout.cc @@ -32,7 +32,7 @@ limitations under the License. 
#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/xla/layout.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions.h" #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_api.h" @@ -48,8 +48,8 @@ static FailureOr> GetTPUInfeedLayoutFromAPI( xla::Shape old_shape = xla::TypeToShape(t); XLA_Shape old_shape_c = {}; XLA_Shape new_shape_c = {}; - TfTpu_ExecutorApiFn *executor = tensorflow::tpu::ExecutorApiFn(); - if (!tensorflow::tpu::IsInitialized(executor)) { + TfTpu_ExecutorApiFn *executor = stream_executor::tpu::ExecutorApiFn(); + if (!stream_executor::tpu::IsInitialized(executor)) { return failure(); } ApiConverter::ToC(old_shape, &old_shape_c); @@ -72,7 +72,7 @@ FailureOr GetTPUInfeedLayout(const ArrayRef types, if (t.isa()) continue; auto layout = GetTPUInfeedLayout({t}, rewriter); if (failed(layout)) return failure(); - v.push_back(layout.getValue()); + v.push_back(layout.value()); } ArrayRef shape(v); return rewriter.getArrayAttr(shape); @@ -85,7 +85,7 @@ FailureOr GetTPUInfeedLayout(const ArrayRef types, if (t.isa()) continue; auto layout = GetTPUInfeedLayout({t}, rewriter); if (failed(layout)) return failure(); - v.push_back(layout.getValue()); + v.push_back(layout.value()); } ArrayRef shape(v); return rewriter.getArrayAttr(shape); @@ -94,7 +94,7 @@ FailureOr GetTPUInfeedLayout(const ArrayRef types, auto layout = GetTPUInfeedLayoutFromAPI(t); std::vector minor_to_major; if (succeeded(layout)) { - minor_to_major = layout.getValue(); + minor_to_major = layout.value(); } else { /* If we're not running on a TPU node, we might not be able to * actually call the part of the TPU API that gives us layout. @@ -151,7 +151,7 @@ class AdjustLayout auto layout = GetTPUInfeedLayout(result_types, builder); if (failed(layout)) return; - op->setAttr("layout", layout.getValue()); + op->setAttr("layout", layout.value()); } } diff --git a/tensorflow/compiler/mlir/xla/transforms/convert_mhlo_quant_to_int.cc b/tensorflow/compiler/mlir/xla/transforms/convert_mhlo_quant_to_int.cc new file mode 100644 index 00000000000..0dd72dc3ad6 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/transforms/convert_mhlo_quant_to_int.cc @@ -0,0 +1,240 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "stablehlo/dialect/ChloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/xla/transforms/passes.h" +#include "tensorflow/compiler/mlir/xla/transforms/utils.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/rewriters.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace mlir { +namespace mhlo { +namespace { + +#define GEN_PASS_DEF_CONVERTMHLOQUANTTOINT +#include "tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf_passes.h.inc" + +FailureOr GetSameShapeTensorType(Operation *op, + TensorType tensor_type, + Type element_type, + PatternRewriter &rewriter) { + if (auto ranked_ty = tensor_type.dyn_cast_or_null()) { + Attribute encoding = ranked_ty.getEncoding(); + if (!(!encoding || encoding.isa() || + encoding.isa())) { + return rewriter.notifyMatchFailure( + op, + "Ranked tensor encoding must be either null, TypeExtensionsAttr, or " + "SparseTensorEncodingAttr."); + } + return RankedTensorType::get(ranked_ty.getShape(), element_type, encoding); + } + if (auto unranked_ty = tensor_type.dyn_cast_or_null()) { + return UnrankedTensorType::get(element_type); + } + llvm_unreachable("unhandled type"); +} + +class ConvertMHLOQuantToInt + : public impl::ConvertMHLOQuantToIntBase { + public: + // Performs conversion of MHLO quant ops to primitive ops. + void runOnOperation() override; +}; + +class ConvertUniformQuantizeOp + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::UniformQuantizeOp op, UniformQuantizeOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto element_type = getElementTypeOrSelf(op.getResult().getType()) + .dyn_cast(); + // Currently for activation, PTQ supports per-tensor quantization only, and + // UniformQuantize op is only for activation. 
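+ // The ops emitted below compute, roughly:
+ //   q = clamp(floor(x / scale + 0.5) + zero_point, qmin, qmax)
+ // followed by a convert to the i8 storage type (a sketch inferred from the
+ // op sequence, not a normative spec of the lowering).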
+ if (!element_type) { + return rewriter.notifyMatchFailure( + op, "Legalization supports only per-tensor quantization."); + } + Value scale = rewriter.create( + op->getLoc(), rewriter.getF32FloatAttr(element_type.getScale())); + Value zero_point = rewriter.create( + op->getLoc(), rewriter.getI32IntegerAttr( + static_cast(element_type.getZeroPoint()))); + Value half = rewriter.create( + op->getLoc(), rewriter.getF32FloatAttr(0.5f)); + Value quantization_min = rewriter.create( + op->getLoc(), rewriter.getI32IntegerAttr(static_cast( + element_type.getStorageTypeMin()))); + Value quantization_max = rewriter.create( + op->getLoc(), rewriter.getI32IntegerAttr(static_cast( + element_type.getStorageTypeMax()))); + + auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter); + auto res_float_tensor_type_or = + GetSameShapeTensorType(op, op.getOperand().getType().cast(), + rewriter.getF32Type(), rewriter); + if (failed(res_float_tensor_type_or)) { + return failure(); + } + Value res_float = rewriter.create( + op->getLoc(), *res_float_tensor_type_or, adaptor.getOperand(), scale, + scalar_broadcast_dims); + res_float = rewriter.create( + op->getLoc(), *res_float_tensor_type_or, res_float, half, + scalar_broadcast_dims); + res_float = rewriter.create(op->getLoc(), res_float); + auto res_int32_tensor_type_or = + GetSameShapeTensorType(op, res_float.getType().cast(), + rewriter.getI32Type(), rewriter); + if (failed(res_int32_tensor_type_or)) { + return failure(); + } + Value res_int32 = rewriter.create( + op->getLoc(), *res_int32_tensor_type_or, res_float); + res_int32 = rewriter.create( + op->getLoc(), *res_int32_tensor_type_or, res_int32, zero_point, + scalar_broadcast_dims); + res_int32 = rewriter.create( + op->getLoc(), *res_int32_tensor_type_or, res_int32, quantization_min, + scalar_broadcast_dims); + res_int32 = rewriter.create( + op->getLoc(), *res_int32_tensor_type_or, res_int32, quantization_max, + scalar_broadcast_dims); + auto res_final_tensor_type_or = + GetSameShapeTensorType(op, res_int32.getType().cast(), + rewriter.getI8Type(), rewriter); + rewriter.replaceOpWithNewOp(op, *res_final_tensor_type_or, + res_int32); + return success(); + } +}; + +class ConvertUniformDequantizeOp + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::UniformDequantizeOp op, UniformDequantizeOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto element_type = getElementTypeOrSelf(op.getOperand().getType()) + .dyn_cast(); + // Currently for activation, PTQ supports per-tensor quantization only, and + // UniformQuantize op is only for activation. 
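+ // The ops emitted below compute, roughly, the inverse:
+ //   x ~= (q - zero_point) * scale
+ // i.e. widen the stored integer to i32, subtract the zero point, convert
+ // to f32 and rescale (a sketch inferred from the op sequence).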
+ if (!element_type) { + return rewriter.notifyMatchFailure( + op, "Legalization supports only per-tensor quantization."); + } + Value scale = rewriter.create( + op->getLoc(), rewriter.getF32FloatAttr(element_type.getScale())); + Value zero_point = rewriter.create( + op->getLoc(), rewriter.getI32IntegerAttr( + static_cast(element_type.getZeroPoint()))); + + Value input = adaptor.getOperand(); + auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter); + auto res_int32_tensor_type_or = + GetSameShapeTensorType(op, input.getType().cast(), + rewriter.getI32Type(), rewriter); + if (failed(res_int32_tensor_type_or)) { + return failure(); + } + Value res_int32 = rewriter.create( + op->getLoc(), *res_int32_tensor_type_or, input); + res_int32 = rewriter.create( + op->getLoc(), *res_int32_tensor_type_or, res_int32, zero_point, + scalar_broadcast_dims); + auto res_float_tensor_type_or = + GetSameShapeTensorType(op, res_int32.getType().cast(), + rewriter.getF32Type(), rewriter); + if (failed(res_float_tensor_type_or)) { + return failure(); + } + Value res_float = rewriter.create( + op->getLoc(), *res_float_tensor_type_or, res_int32); + res_float = rewriter.replaceOpWithNewOp( + op, *res_float_tensor_type_or, res_float, scale, scalar_broadcast_dims); + return success(); + } +}; + +// Performs conversion of MHLO quant ops to primitive ops. +void ConvertMHLOQuantToInt::runOnOperation() { + Operation *op = getOperation(); + MLIRContext *context = op->getContext(); + RewritePatternSet patterns(context); + + // Populate MHLO quant ops conversion patterns. + patterns.add(context); + + ConversionTarget target(*op->getContext()); + auto is_legal = [](Operation *op) { + auto is_not_quant = [](Type type) { + return !getElementTypeOrSelf(type).isa(); + }; + return llvm::all_of(op->getOperandTypes(), is_not_quant) && + llvm::all_of(op->getResultTypes(), is_not_quant); + }; + target.addDynamicallyLegalDialect(is_legal); + target.addDynamicallyLegalDialect(is_legal); + + LogicalResult result = + applyPartialConversion(op, target, std::move(patterns)); + if (failed(result)) { + signalPassFailure(); + } +} + +} // end namespace + +std::unique_ptr> createConvertMHLOQuantToIntPass() { + return std::make_unique(); +} + +} // end namespace mhlo +} // end namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index f7aaa1bae3a..e1de21cf038 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -56,9 +56,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/lib/conv_grad_size_util.h" #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/client/sharding_builder.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/utils/convert_op_folder.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/utils/hlo_utils.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/utils/convert_op_folder.h" +#include "tensorflow/compiler/xla/mlir_hlo/utils/hlo_utils.h" #include "tensorflow/compiler/xla/translate/hlo_to_mhlo/attribute_importer.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/kernel_shape_util.h" @@ -220,87 +220,6 @@ static ConstantOp GetScalarLimitConstOfType(Type ty, Location loc, return builder->create(loc, hlo::getScalarLimitOfType(ty, limit)); } -// Creates an mhlo::SliceOp where the major dimensions have full size, and -// the minor dimensions have the provided offsets and sizes. -static Value SliceInMinorDims(Location loc, Value v, - ArrayRef minor_starts, - ArrayRef minor_limits, - OpBuilder *builder) { - auto type = v.getType().cast(); - llvm::SmallVector slice_starts(type.getRank(), 0); - int64_t major_dims = type.getRank() - minor_starts.size(); - std::copy(minor_starts.begin(), minor_starts.end(), - slice_starts.begin() + major_dims); - auto slice_limits = llvm::to_vector<4>(type.getShape()); - std::copy(minor_limits.begin(), minor_limits.end(), - slice_limits.begin() + major_dims); - llvm::SmallVector slice_strides(type.getRank(), 1); - return builder->create(loc, v, - GetI64ElementsAttr(slice_starts, builder), - GetI64ElementsAttr(slice_limits, builder), - GetI64ElementsAttr(slice_strides, builder)); -} - -// Creates a vector of index values: -// [0, 0, ..., minor_indices[0], minor_indices[1], ... minor_indices[-1]] -// with length `rank`. -static llvm::SmallVector CreateFullIndexVectorFromMinorIndices( - Location loc, ArrayRef minor_indices, int64_t rank, - OpBuilder *builder) { - auto zero = - GetScalarConstOfType(getElementTypeOrSelf(minor_indices[0].getType()), - loc, 0, builder) - .getOutput(); - llvm::SmallVector indices(rank, zero); - std::copy(minor_indices.begin(), minor_indices.end(), - indices.begin() + (rank - minor_indices.size())); - return indices; -} - -// Creates an mhlo::DynamicSliceOp where the major dimensions have full size, -// and the minor dimensions have the provided offsets and sizes. -static Value DynamicSliceInMinorDims(Location loc, Value v, - ArrayRef minor_starts, - ArrayRef minor_sizes, - OpBuilder *builder) { - if (minor_starts.empty()) return v; - auto type = v.getType().cast(); - auto slice_starts = CreateFullIndexVectorFromMinorIndices( - loc, minor_starts, type.getRank(), builder); - int64_t major_dims = type.getRank() - minor_starts.size(); - auto slice_sizes = llvm::to_vector<4>(type.getShape()); - std::copy(minor_sizes.begin(), minor_sizes.end(), - slice_sizes.begin() + major_dims); - return builder->create( - loc, v, slice_starts, GetI64ElementsAttr(slice_sizes, builder)); -} - -// Creates an mhlo::DynamicUpdateSliceOp where the major dimensions have zero -// offsets, and the minor dimensions have the provided offsets. 
-static Value DynamicUpdateSliceInMinorDims(Location loc, Value v, Value update, - ArrayRef minor_starts, - OpBuilder *builder) { - if (minor_starts.empty()) return v; - auto type = v.getType().cast(); - auto dus_starts = CreateFullIndexVectorFromMinorIndices( - loc, minor_starts, type.getRank(), builder); - return builder->create(loc, type, v, update, - llvm::makeArrayRef(dus_starts)); -} - -// Creates an mhlo::DynamicUpdateSliceOp where the major dimensions have zero -// offsets, and the minor dimensions have the provided static offsets. -static Value UpdateSliceInMinorDims(Location loc, Value v, Value update, - ArrayRef minor_starts, - OpBuilder *builder) { - llvm::SmallVector dus_starts(minor_starts.size()); - for (uint64_t i = 0; i < minor_starts.size(); ++i) { - dus_starts[i] = GetScalarConstOfType(builder->getIntegerType(32), loc, - minor_starts[i], builder); - } - return DynamicUpdateSliceInMinorDims(loc, v, update, dus_starts, builder); -} - // Deprecated: This is maintained to aid in porting old code that is not yet // dynamic shape aware and uses broadcasting modes that CHLO does not support. // Gets the resulting type from a broadcast between two types for statically @@ -442,35 +361,6 @@ static Value BroadcastToShapeOf(Location loc, Value input, Value broadcast_to, loc, to_type, input, result_extents, broadcast_dims); } -// Creates a batch dot using mhlo::DotGeneralOp. -Value BatchDot(Location loc, Value lhs, bool transpose_lhs, Value rhs, - bool transpose_rhs, int64_t num_batch_dims, - ArrayAttr precision_config, OpBuilder *builder) { - auto batch_dimensions = - llvm::to_vector<4>(llvm::seq(0, num_batch_dims)); - auto lhs_contracting_dimensions = llvm::to_vector<1>(llvm::makeArrayRef( - {transpose_lhs ? num_batch_dims : num_batch_dims + 1})); - auto rhs_contracting_dimensions = llvm::to_vector<1>(llvm::makeArrayRef( - {transpose_rhs ? num_batch_dims + 1 : num_batch_dims})); - auto dimension_numbers = DotDimensionNumbersAttr::get( - builder->getContext(), - /*lhs_batching_dimensions=*/batch_dimensions, - /*rhs_batching_dimensions=*/batch_dimensions, - /*lhs_contracting_dimensions=*/lhs_contracting_dimensions, - /*rhs_contracting_dimensions=*/rhs_contracting_dimensions); - auto lhs_shape = lhs.getType().cast().getShape(); - auto rhs_shape = rhs.getType().cast().getShape(); - auto shape = llvm::to_vector<4>(lhs_shape); - shape[shape.size() - 2] = - transpose_lhs ? lhs_shape.back() : lhs_shape[lhs_shape.size() - 2]; - shape[shape.size() - 1] = - transpose_rhs ? rhs_shape[rhs_shape.size() - 2] : rhs_shape.back(); - Type element_type = getElementTypeOrSelf(lhs.getType()); - return builder->create( - loc, tensorflow::GetTypeFromTFTensorShape(shape, element_type), lhs, rhs, - dimension_numbers, precision_config); -} - // Builds a set of operations for applying reduction on the input value. A // tf.sum op is created and will be legalized to tfl ops automatically. 
static Value ApplyReduction(Location loc, Value input, @@ -992,15 +882,15 @@ class ConvertBiasAddOp : public OpRewritePattern { PatternRewriter &rewriter) const override { Location loc = op.getLoc(); tensorflow::TensorFormat data_format; - if (!FormatFromString(op.data_format().str(), &data_format)) + if (!FormatFromString(op.getDataFormat().str(), &data_format)) return op.emitOpError("invalid data format"); - auto value_type = op.value().getType().dyn_cast(); + auto value_type = op.getValue().getType().dyn_cast(); if (!value_type) return failure(); auto feature_dim = GetFeatureDimension(data_format, value_type); - auto bias_broadcast = Broadcast1DToFeatureDim(loc, op.value(), op.bias(), - feature_dim, rewriter); - Value add = rewriter.create(loc, op.value(), bias_broadcast); + auto bias_broadcast = Broadcast1DToFeatureDim( + loc, op.getValue(), op.getBias(), feature_dim, rewriter); + Value add = rewriter.create(loc, op.getValue(), bias_broadcast); if (add.getType() != op.getType()) { add = rewriter.create(loc, op.getType(), add); } @@ -1017,10 +907,10 @@ class ConvertGatherV2OpDynamic : public OpRewritePattern { LogicalResult matchAndRewrite(TF::GatherV2Op op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - Value params = op.params(); + Value params = op.getParams(); // params and indices of GatherNdOp must be ranked auto params_ty = params.getType().dyn_cast(); - Value indices = op.indices(); + Value indices = op.getIndices(); auto indices_ty = indices.getType().dyn_cast(); if (!params_ty || !indices_ty) return failure(); @@ -1035,7 +925,7 @@ class ConvertGatherV2OpDynamic : public OpRewritePattern { // axis DenseIntElementsAttr axis_attr; // axis must be const for GatherOp - if (!matchPattern(op.axis(), m_Constant(&axis_attr))) return failure(); + if (!matchPattern(op.getAxis(), m_Constant(&axis_attr))) return failure(); int64_t axis = (*axis_attr.begin()).getSExtValue(); if (axis < 0) axis += params_rank; @@ -1059,7 +949,7 @@ class ConvertGatherV2OpDynamic : public OpRewritePattern { loc, rewriter.getIntegerAttr(indices_ty.getElementType(), 1))); } else { int64_t dim_size = params_ty.getDimSize(dim_idx); - if (dim_size != ShapedType::kDynamicSize) { + if (dim_size != ShapedType::kDynamic) { slice_sizes_vals.push_back(rewriter.create( loc, rewriter.getIntegerAttr(indices_ty.getElementType(), dim_size))); @@ -1094,7 +984,7 @@ class ConvertGatherV2OpDynamic : public OpRewritePattern { /*index_vector_dim=*/index_vector_dim); rewriter.replaceOpWithNewOp( - op, op.getType(), op.params(), op.indices(), slice_sizes_value, + op, op.getType(), op.getParams(), op.getIndices(), slice_sizes_value, dims_attr); return success(); } @@ -1181,16 +1071,17 @@ class ConvertConvDynamic : public OpRewritePattern { LogicalResult matchAndRewriteDynamicConv(OpT op, PatternRewriter &rewriter) const { tensorflow::TensorFormat data_format; - if (!FormatFromString(op.data_format().str(), &data_format)) + if (!FormatFromString(op.getDataFormat().str(), &data_format)) return op.emitOpError("invalid data format"); tensorflow::Padding padding; - if (!GetPaddingFromString(op.padding().str(), &padding).ok()) + if (!GetPaddingFromString(op.getPadding().str(), &padding).ok()) return failure(); - auto input_ty = op.input().getType().template dyn_cast(); + auto input_ty = + op.getInput().getType().template dyn_cast(); auto filter_ty = - op.filter().getType().template dyn_cast(); + op.getFilter().getType().template dyn_cast(); auto result_ty = op.getType().template dyn_cast(); if (!input_ty || 
!filter_ty || !result_ty) return failure(); // TODO(disc): Remove this constraint once fold and canonicalization @@ -1198,8 +1089,8 @@ class ConvertConvDynamic : public OpRewritePattern { if (input_ty.hasStaticShape() && filter_ty.hasStaticShape()) return failure(); - ArrayRef dilations = op.dilations().getValue(); - ArrayRef strides = op.strides().getValue(); + ArrayRef dilations = op.getDilations().getValue(); + ArrayRef strides = op.getStrides().getValue(); ArrayRef explicit_paddings; if (padding == tensorflow::Padding::EXPLICIT) { // EXPLICIT padding mode and the associated attribute is attached to @@ -1246,8 +1137,8 @@ class ConvertConvDynamic : public OpRewritePattern { pad_low = get_const(get_int(explicit_paddings[2 * dim])); pad_high = get_const(get_int(explicit_paddings[2 * dim + 1])); } else { - auto input_size = get_dim_value(op.input(), dim); - auto filter_size = get_dim_value(op.filter(), i); + auto input_size = get_dim_value(op.getInput(), dim); + auto filter_size = get_dim_value(op.getFilter(), i); if (!GetPaddingValues(op, rewriter, input_size, filter_size, dilation, stride, padding, shape_scalar_type, &pad_low, &pad_high)) { @@ -1298,9 +1189,9 @@ class ConvertConvDynamic : public OpRewritePattern { llvm::SmallVector new_shape( filter_shape.begin(), filter_shape.begin() + num_spatial_dims); new_shape.push_back(1); - if (filter_shape[num_spatial_dims] == ShapedType::kDynamicSize || - filter_shape[num_spatial_dims + 1] == ShapedType::kDynamicSize) { - new_shape.push_back(ShapedType::kDynamicSize); + if (filter_shape[num_spatial_dims] == ShapedType::kDynamic || + filter_shape[num_spatial_dims + 1] == ShapedType::kDynamic) { + new_shape.push_back(ShapedType::kDynamic); } else { new_shape.push_back(filter_shape[num_spatial_dims] * filter_shape[num_spatial_dims + 1]); @@ -1310,7 +1201,7 @@ class ConvertConvDynamic : public OpRewritePattern { SmallVector filter_dim_sizes; for (int i = 0; i < rank; ++i) { filter_dim_sizes.push_back( - rewriter.create(loc, op.filter(), i)); + rewriter.create(loc, op.getFilter(), i)); } filter_dim_sizes[rank-1] = rewriter.create( loc, filter_dim_sizes[rank-1], filter_dim_sizes[rank-2]); @@ -1326,7 +1217,7 @@ class ConvertConvDynamic : public OpRewritePattern { dimension_numbers_attr, feature_group_count_attr, batch_group_count_attr}; rewriter.replaceOpWithNewOp(op, op.getType(), operands, - llvm::makeArrayRef(attrs)); + llvm::ArrayRef(attrs)); return success(); } @@ -1366,16 +1257,17 @@ class ConvertConvOp : public OpRewritePattern { LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { tensorflow::TensorFormat data_format; - if (!FormatFromString(op.data_format().str(), &data_format)) + if (!FormatFromString(op.getDataFormat().str(), &data_format)) return op.emitOpError("invalid data format"); tensorflow::Padding padding; - if (!GetPaddingFromString(op.padding().str(), &padding).ok()) + if (!GetPaddingFromString(op.getPadding().str(), &padding).ok()) return failure(); - auto input_ty = op.input().getType().template dyn_cast(); + auto input_ty = + op.getInput().getType().template dyn_cast(); auto filter_ty = - op.filter().getType().template dyn_cast(); + op.getFilter().getType().template dyn_cast(); // With the exception of input's batch dimension, input and filter need to // have static shape for calculation of HLO paddings and feature group count @@ -1383,8 +1275,8 @@ class ConvertConvOp : public OpRewritePattern { if (!input_ty || !filter_ty || !filter_ty.hasStaticShape()) return failure(); - ArrayRef dilations = 
op.dilations().getValue(); - ArrayRef strides = op.strides().getValue(); + ArrayRef dilations = op.getDilations().getValue(); + ArrayRef strides = op.getStrides().getValue(); ArrayRef explicit_paddings; if (padding == tensorflow::Padding::EXPLICIT) { // EXPLICIT padding mode and the associated attribute is limited to @@ -1422,7 +1314,7 @@ class ConvertConvOp : public OpRewritePattern { int64_t pad_low_int64; int64_t pad_high_int64; int64_t input_size = input_ty.getDimSize(dim); - if (input_size == ShapedType::kDynamicSize) return failure(); + if (input_size == ShapedType::kDynamic) return failure(); tsl::Status status = tensorflow::GetWindowedOutputSizeVerboseV2( input_size, filter_ty.getDimSize(i), dilation, stride, padding, &output_size, &pad_low_int64, &pad_high_int64); @@ -1445,7 +1337,7 @@ class ConvertConvOp : public OpRewritePattern { const int64_t input_channels = GetDimSize(input_ty, GetTensorFeatureDimIndex(num_dims, data_format)); - if (input_channels == ShapedType::kDynamicSize) return failure(); + if (input_channels == ShapedType::kDynamic) return failure(); // Filters data_format is always HWIO so input channels dimension is after // all spatial dimensions. const int64_t filter_channels = GetDimSize(filter_ty, num_spatial_dims); @@ -1486,7 +1378,7 @@ class ConvertConvOp : public OpRewritePattern { dimension_numbers_attr, feature_group_count_attr, batch_group_count_attr, paddings_attr}; rewriter.replaceOpWithNewOp(op, op.getType(), operands, - llvm::makeArrayRef(attrs)); + llvm::ArrayRef(attrs)); return success(); } }; @@ -1506,9 +1398,9 @@ class ConvertPadOpDynamic : public OpRewritePattern { LogicalResult matchAndRewrite(TF::PadV2Op op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - auto input = op.input(); - auto paddings = op.paddings(); - auto constant_values = op.constant_values(); + auto input = op.getInput(); + auto paddings = op.getPaddings(); + auto constant_values = op.getConstantValues(); auto input_type = input.getType().dyn_cast(); auto paddings_type = paddings.getType().dyn_cast(); if (!input_type || !paddings_type || !paddings_type.hasStaticShape()) @@ -1571,16 +1463,16 @@ class ConvertGatherNdOpDynamic : public OpRewritePattern { LogicalResult matchAndRewrite(TF::GatherNdOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - auto params = op.params(); + auto params = op.getParams(); auto params_ty = params.getType().dyn_cast(); - auto indices = op.indices(); + auto indices = op.getIndices(); auto indices_ty = indices.getType().dyn_cast(); auto params_rank = params_ty.getRank(); auto indices_rank = indices_ty.getRank(); int64_t num_index_dims = indices_ty.getDimSize(indices_rank - 1); if (!params_ty || !indices_ty) return failure(); // the last dim of indices of GatherNdOp must be fixed shaped - if (num_index_dims == ShapedType::kDynamicSize) return failure(); + if (num_index_dims == ShapedType::kDynamic) return failure(); SmallVector slice_sizes; slice_sizes.reserve(params_rank); @@ -1601,7 +1493,7 @@ class ConvertGatherNdOpDynamic : public OpRewritePattern { loc, rewriter.getIntegerAttr(indices_ty.getElementType(), 1))); } else { int64_t dim_size = params_ty.getDimSize(i); - if (dim_size != ShapedType::kDynamicSize) { + if (dim_size != ShapedType::kDynamic) { slice_sizes_vals.push_back(rewriter.create( loc, rewriter.getIntegerAttr(indices_ty.getElementType(), dim_size))); @@ -1643,11 +1535,11 @@ class ConvertGatherNdOpDynamic : public OpRewritePattern { // implemented. 
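// The convolution patterns above defer per-dimension padding to
// tensorflow::GetWindowedOutputSizeVerboseV2. A sketch of the conventional
// TensorFlow SAME-padding arithmetic that call is expected to yield, written
// as plain C++ for reference (assumed convention, not copied from the patch):
#include <algorithm>
#include <cstdint>

void SamePaddingForDim(int64_t input_size, int64_t filter_size,
                       int64_t dilation, int64_t stride, int64_t *output_size,
                       int64_t *pad_low, int64_t *pad_high) {
  const int64_t effective_filter = (filter_size - 1) * dilation + 1;
  *output_size = (input_size + stride - 1) / stride;  // ceil(input / stride)
  const int64_t padding_needed = std::max<int64_t>(
      0, (*output_size - 1) * stride + effective_filter - input_size);
  *pad_low = padding_needed / 2;  // the smaller half goes on the low side
  *pad_high = padding_needed - *pad_low;
}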
if (params_ty.hasStaticShape() && indices_ty.hasStaticShape()) { rewriter.replaceOpWithNewOp( - op, op.getType(), op.params(), op.indices(), dims_attr, + op, op.getType(), op.getParams(), op.getIndices(), dims_attr, GetI64ElementsAttr(slice_sizes, &rewriter)); } else { rewriter.replaceOpWithNewOp( - op, op.getType(), op.params(), op.indices(), slice_sizes_value, + op, op.getType(), op.getParams(), op.getIndices(), slice_sizes_value, dims_attr); } return success(); @@ -1672,12 +1564,12 @@ class ConvertBF16FloorDivOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::FloorDivOp op, PatternRewriter &rewriter) const override { - auto l = op.x(); - auto r = op.y(); + auto l = op.getX(); + auto r = op.getY(); auto element_type = getElementTypeOrSelf(l.getType()); if (!element_type.isBF16()) return failure(); - auto out_type = op.z().getType().cast(); + auto out_type = op.getZ().getType().cast(); l = rewriter.create(op.getLoc(), l, rewriter.getF32Type()); r = rewriter.create(op.getLoc(), r, rewriter.getF32Type()); @@ -1700,8 +1592,8 @@ class ConvertBroadcastToOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::BroadcastToOp op, PatternRewriter &rewriter) const override { - auto input_type = op.input().getType().dyn_cast(); - auto output_type = op.output().getType(); + auto input_type = op.getInput().getType().dyn_cast(); + auto output_type = op.getOutput().getType(); if (!input_type) { return rewriter.notifyMatchFailure(op, "requires ranked input shape"); } @@ -1712,13 +1604,13 @@ class ConvertBroadcastToOp : public OpRewritePattern { return rewriter.notifyMatchFailure(op, "requires ranked output shape"); } auto rank_diff = ranked_output_type.getRank() - input_type.getRank(); - // The tf.BroadcastTo op performs "right-aligned" numpy-style + // The tf.BroadcastTo op.getPerforms "right-aligned" numpy-style // broadcasting. broadcast_dimensions = llvm::to_vector<4>( llvm::seq(rank_diff, ranked_output_type.getRank())); } rewriter.replaceOpWithNewOp( - op, output_type, op.input(), op.shape(), + op, output_type, op.getInput(), op.getShape(), rewriter.getI64TensorAttr(broadcast_dimensions)); return success(); } @@ -1731,19 +1623,19 @@ class ConvertRollOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TF::RollOp op, PatternRewriter &rewriter) const override { - auto shift_ty = op.shift().getType().dyn_cast(); + auto shift_ty = op.getShift().getType().dyn_cast(); if (!shift_ty || shift_ty.getRank() != 0) { return rewriter.notifyMatchFailure( op, "require the type of shift to be 0D tensor"); } APInt val; - if (!matchPattern(op.axis(), m_ConstantInt(&val))) { + if (!matchPattern(op.getAxis(), m_ConstantInt(&val))) { return rewriter.notifyMatchFailure(op, "require axis to be constant"); } int axis = val.getSExtValue(); - auto input_ty = op.input().getType().dyn_cast(); + auto input_ty = op.getInput().getType().dyn_cast(); if (!input_ty || !input_ty.hasStaticShape()) { return rewriter.notifyMatchFailure( op, "require the type of input to have static shapes"); @@ -1756,7 +1648,7 @@ class ConvertRollOp : public OpRewritePattern { // offsets positive. // offset = ((offset % axis_size) + axis_size) % axis_size ImplicitLocOpBuilder b(op.getLoc(), rewriter); - Value offset = op.shift(); + Value offset = op.getShift(); auto axis_size = b.create(b.getIntegerAttr( getElementTypeOrSelf(offset.getType()), input_shape[axis])); offset = b.create( @@ -1767,8 +1659,8 @@ class ConvertRollOp : public OpRewritePattern { // offset. 
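// ConvertBroadcastToOp above maps a rank-R input onto the trailing R
// dimensions of the output ("right-aligned" numpy-style broadcasting). A small
// sketch of the broadcast_dimensions attribute it builds; illustrative only.
#include <cstdint>
#include <numeric>
#include <vector>

std::vector<int64_t> RightAlignedBroadcastDims(int64_t input_rank,
                                               int64_t output_rank) {
  // E.g. input rank 2, output rank 4 -> {2, 3}.
  std::vector<int64_t> dims(input_rank);
  std::iota(dims.begin(), dims.end(), output_rank - input_rank);
  return dims;
}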
This also works if shift is not constant. // DynamicSliceOp requires the sizes being integer, and we can get the // information from input shape. - auto concat = b.create(ValueRange{op.input(), op.input()}, - b.getI64IntegerAttr(axis)); + auto concat = b.create( + ValueRange{op.getInput(), op.getInput()}, b.getI64IntegerAttr(axis)); Value zero = b.create( b.getIntegerAttr(getElementTypeOrSelf(offset.getType()), 0)); SmallVector slice_begin_indices(input_rank, zero); @@ -1788,11 +1680,11 @@ class ConvertLeakyReluOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::LeakyReluOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - Value features = op.features(); + Value features = op.getFeatures(); // Use ConstantLike for `alpha` to match the shape of feature. auto alphaVal = chlo::getConstantLike( - rewriter, loc, op.alpha().convertToFloat(), features); + rewriter, loc, op.getAlpha().convertToFloat(), features); Value zeroVal = chlo::getConstantLike(rewriter, loc, 0.0, features); Value leakyActivationVal = @@ -1816,13 +1708,13 @@ class ConvertLeakyReluGradOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::LeakyReluGradOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - Value gradients = op.gradients(); - Value features = op.features(); + Value gradients = op.getGradients(); + Value features = op.getFeatures(); auto featureType = features.getType(); // Use ConstantLike for `alpha` to match the shape of feature. auto alphaVal = chlo::getConstantLike( - rewriter, loc, op.alpha().convertToFloat(), features); + rewriter, loc, op.getAlpha().convertToFloat(), features); Value zeroVal = chlo::getConstantLike(rewriter, loc, 0.0, features); Value leakyGradientVal = @@ -1859,7 +1751,7 @@ class ConvertDiagPartOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::DiagPartOp op, PatternRewriter &rewriter) const override { - auto input_type = op.input().getType().dyn_cast(); + auto input_type = op.getInput().getType().dyn_cast(); if (!input_type || !input_type.hasStaticShape()) return failure(); int64_t num_dims = input_type.getRank(); if (num_dims < 2 || num_dims % 2 != 0) return failure(); @@ -1877,7 +1769,7 @@ class ConvertDiagPartOp : public OpRewritePattern { op.getLoc(), tensorflow::GetTypeFromTFTensorShape({new_size, new_size}, input_type.getElementType()), - op.input()); + op.getInput()); auto iota_type = tensorflow::GetTypeFromTFTensorShape( {new_size, new_size}, rewriter.getIntegerType(32)); auto iota0 = rewriter.create(op.getLoc(), iota_type, @@ -1919,7 +1811,7 @@ class ConvertMatrixDiagPartV3Op // tuple of two values (starting and ending diagonal, for a band). LogicalResult ExtractK(TF::MatrixDiagPartV3Op op, int64_t (*k)[2]) const { DenseIntElementsAttr kattr; - if (!matchPattern(op.k(), m_Constant(&kattr))) { + if (!matchPattern(op.getK(), m_Constant(&kattr))) { return failure(); } DenseIntElementsAttr::iterator it = kattr.begin(); @@ -1955,7 +1847,7 @@ class ConvertMatrixDiagPartV3Op LogicalResult matchAndRewrite(TF::MatrixDiagPartV3Op op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - ShapedType input_type = op.input().getType().dyn_cast(); + ShapedType input_type = op.getInput().getType().dyn_cast(); // Align is a string specifying how superdiagonals and subdiagonals should // be aligned/padded for diagonals that are shorter than max_diag_len. 
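// A 1-D sketch of the tf.Roll lowering above: normalize the shift into
// [0, size), concatenate the input with itself along the axis, and take a
// window of the original length. The window start of (size - offset) is an
// assumption about the part of the pattern not shown here; it reproduces
// tf.roll semantics (output[j] == input[(j - shift) mod size]).
#include <cstdint>
#include <vector>

std::vector<int> Roll1D(const std::vector<int> &input, int64_t shift) {
  const int64_t size = static_cast<int64_t>(input.size());
  // offset = ((shift % size) + size) % size, as in the pattern above.
  const int64_t offset = ((shift % size) + size) % size;
  std::vector<int> doubled(input);
  doubled.insert(doubled.end(), input.begin(), input.end());
  const int64_t begin = size - offset;
  return std::vector<int>(doubled.begin() + begin,
                          doubled.begin() + begin + size);
}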
The @@ -2142,7 +2034,7 @@ class ConvertMatrixDiagPartV3Op /*collapsed_slice_dims=*/collapsed_dims, start_index_map, /*index_vector_dim=*/0); Value gather = rewriter.create( - loc, op.input(), start_indices, dims_attr, + loc, op.getInput(), start_indices, dims_attr, GetI64ElementsAttr(slice_sizes, &rewriter)); // We now need to broadcast the "in_bounds" boolean expression, as well as @@ -2157,7 +2049,7 @@ class ConvertMatrixDiagPartV3Op rewriter.getIntegerType(1)), in_bounds, GetI64ElementsAttr(broadcast_bounds, &rewriter)); Value b_padding = rewriter.create( - loc, op.padding_value(), GetI64ElementsAttr(output_shape, &rewriter)); + loc, op.getPaddingValue(), GetI64ElementsAttr(output_shape, &rewriter)); // Replace all out-of-bounds values in the result with padding_value. Value result = @@ -2183,11 +2075,11 @@ class ConvertEinsumOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::EinsumOp op, PatternRewriter &rewriter) const override { StringAttr equation = op->getAttrOfType("equation"); - if (op.N() == 1) { + if (op.getN() == 1) { rewriter.replaceOpWithNewOp( - op, op.getType(), *op.inputs().begin(), equation); - } else if (op.N() == 2) { - ValueRange inputs = op.inputs(); + op, op.getType(), *op.getInputs().begin(), equation); + } else if (op.getN() == 2) { + ValueRange inputs = op.getInputs(); rewriter.replaceOpWithNewOp(op, op.getType(), inputs[0], inputs[1], equation); } else { @@ -2216,13 +2108,13 @@ class ConvertFFTOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { - auto input_ty = op.input().getType().template cast(); + auto input_ty = op.getInput().getType().template cast(); if (!input_ty.hasRank()) { return failure(); } auto input_shape = input_ty.getShape(); DenseIntElementsAttr fft_length_attr; - if (!matchPattern(op.fft_length(), m_Constant(&fft_length_attr))) { + if (!matchPattern(op.getFftLength(), m_Constant(&fft_length_attr))) { return failure(); } int64_t fft_length; @@ -2249,7 +2141,7 @@ class ConvertFFTOp : public OpRewritePattern { expected_shape.push_back(expected_dim); // Zero pad or truncate the last axis - Value reshaped = op.input(); + Value reshaped = op.getInput(); SmallVector begin_indices(input_shape.size(), 0); SmallVector strides(input_shape.size(), 1); @@ -2259,7 +2151,7 @@ class ConvertFFTOp : public OpRewritePattern { op.getLoc(), tensorflow::GetTypeFromTFTensorShape(expected_shape, input_ty.getElementType()), - op.input(), GetI64ElementsAttr(begin_indices, &rewriter), + op.getInput(), GetI64ElementsAttr(begin_indices, &rewriter), GetI64ElementsAttr(expected_shape, &rewriter), GetI64ElementsAttr(strides, &rewriter)); @@ -2274,7 +2166,7 @@ class ConvertFFTOp : public OpRewritePattern { loc, tensorflow::GetTypeFromTFTensorShape(expected_shape, input_ty.getElementType()), - op.input(), zero, GetI64ElementsAttr(no_padding, &rewriter), + op.getInput(), zero, GetI64ElementsAttr(no_padding, &rewriter), GetI64ElementsAttr(padding, &rewriter), GetI64ElementsAttr(no_padding, &rewriter)); } @@ -2282,7 +2174,7 @@ class ConvertFFTOp : public OpRewritePattern { rewriter.replaceOpWithNewOp( op, op.getType(), reshaped, FftTypeAttr::get(rewriter.getContext(), - symbolizeFftType(fft_string).getValue()), + symbolizeFftType(fft_string).value()), rewriter.getI64TensorAttr(fft_length)); return success(); } @@ -2303,11 +2195,11 @@ class ConvertFusedBatchNormGradBase LogicalResult matchAndRewrite(FusedBatchNormGradOpT op, PatternRewriter &rewriter) const 
override { Location loc = op.getLoc(); - Value grad = op.y_backprop(); - Value act = op.x(); - Value scale = op.scale(); - Value mean = op.reserve_space_1(); - Value var = op.reserve_space_2(); + Value grad = op.getYBackprop(); + Value act = op.getX(); + Value scale = op.getScale(); + Value mean = op.getReserveSpace_1(); + Value var = op.getReserveSpace_2(); // TODO(b/141785544): Update this to not require static shapes. // activation shape needs to be static to convert negative indices in @@ -2324,7 +2216,7 @@ class ConvertFusedBatchNormGradBase act = rewriter.create(loc, act, kernel_type); tensorflow::TensorFormat data_format; - if (!FormatFromString(op.data_format().str(), &data_format)) + if (!FormatFromString(op.getDataFormat().str(), &data_format)) return op.emitOpError("invalid data format"); auto feature_dim_attr = getFeatureDimensionAttr(rewriter, data_format, act); @@ -2332,7 +2224,7 @@ class ConvertFusedBatchNormGradBase // Gets the result values. Value x_backprop, scale_backprop, offset_backprop; - if (op.is_training()) { // training + if (op.getIsTraining()) { // training // TODO(b/145536565): handle GPU logic separately. // Infers the output type with the converted `act`. Type feature_type = tensorflow::GetTypeFromTFTensorShape( @@ -2341,7 +2233,7 @@ class ConvertFusedBatchNormGradBase SmallVector operand_types = {act.getType(), feature_type, feature_type}; auto training_op = rewriter.create( - loc, operand_types, act, scale, mean, var, grad, op.epsilon(), + loc, operand_types, act, scale, mean, var, grad, op.getEpsilon(), feature_dim); x_backprop = training_op.getResult(0); @@ -2362,7 +2254,7 @@ class ConvertFusedBatchNormGradBase RankedTensorType scalar_float = tensorflow::GetTypeFromTFTensorShape({}, kernel_type); auto epsilon = rewriter.create( - loc, DenseFPElementsAttr::get(scalar_float, {op.epsilon()})); + loc, DenseFPElementsAttr::get(scalar_float, {op.getEpsilon()})); auto add_op = rewriter.create( loc, var, epsilon.getResult(), scalar_broadcast_dims); @@ -2378,7 +2270,7 @@ class ConvertFusedBatchNormGradBase // x_backprop = y_backprop * (scale * scratch1) auto scaled_grad = - rewriter.create(loc, op.scale(), scratch1); + rewriter.create(loc, op.getScale(), scratch1); x_backprop = rewriter.create( loc, grad, Broadcast1DToFeatureDim(loc, act, scaled_grad, feature_dim, @@ -2395,7 +2287,7 @@ class ConvertFusedBatchNormGradBase Value last_val[2]; if (op.getResult(3).use_empty() && op.getResult(4).use_empty()) { // It doesn't matter what values we provide for the last 2 results. 
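// The inference-mode (is_training == false) branch above computes, per
// feature, x_backprop = y_backprop * scale / sqrt(variance + epsilon). A
// flattened per-feature sketch of that arithmetic; shapes and broadcasting are
// elided, so this is illustrative only.
#include <cmath>
#include <vector>

std::vector<float> BatchNormGradXBackpropInference(
    const std::vector<float> &y_backprop, const std::vector<float> &scale,
    const std::vector<float> &variance, float epsilon) {
  std::vector<float> x_backprop(y_backprop.size());
  for (size_t f = 0; f < y_backprop.size(); ++f) {
    const float scratch1 = 1.0f / std::sqrt(variance[f] + epsilon);  // rsqrt
    x_backprop[f] = y_backprop[f] * (scale[f] * scratch1);
  }
  return x_backprop;
}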
- last_val[0] = last_val[1] = op.x(); + last_val[0] = last_val[1] = op.getX(); } else { auto const_val = rewriter.create( op.getLoc(), DenseElementsAttr::get( @@ -2435,34 +2327,36 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { LogicalResult matchAndRewrite(FusedBatchNormOpT op, PatternRewriter &rewriter) const override { tensorflow::TensorFormat data_format; - if (!FormatFromString(op.data_format().str(), &data_format)) + if (!FormatFromString(op.getDataFormat().str(), &data_format)) return op.emitOpError("invalid data format"); - auto feature_dim = getFeatureDimensionAttr(rewriter, data_format, op.x()); + auto feature_dim = + getFeatureDimensionAttr(rewriter, data_format, op.getX()); - auto input_type_tensor = op.x().getType().template cast(); + auto input_type_tensor = op.getX().getType().template cast(); auto input_element_type = input_type_tensor.getElementType(); - auto scale_type_tensor = op.scale().getType().template cast(); + auto scale_type_tensor = + op.getScale().getType().template cast(); auto scale_element_type = scale_type_tensor.getElementType(); - auto mean_type_tensor = op.mean().getType().template cast(); + auto mean_type_tensor = op.getMean().getType().template cast(); auto mean_element_type = mean_type_tensor.getElementType(); // In the training case, dimensions of input tensors must be static. - if (op.is_training() && (!input_type_tensor.hasStaticShape() || - !scale_type_tensor.hasStaticShape() || - !mean_type_tensor.hasStaticShape())) + if (op.getIsTraining() && (!input_type_tensor.hasStaticShape() || + !scale_type_tensor.hasStaticShape() || + !mean_type_tensor.hasStaticShape())) return failure(); // TODO(b/69928690): Support mixed precision in the XLA batch // normalization operators. As a workaround, create a new x with the same // element type as scale (which may be more precise than the input type). - Value bn_train_input = rewriter.create(op.getLoc(), op.x(), - scale_element_type); + Value bn_train_input = rewriter.create( + op.getLoc(), op.getX(), scale_element_type); TensorType bn_train_input_type_tensor = bn_train_input.getType().template cast(); - if (op.is_training()) { + if (op.getIsTraining()) { // Training case. auto operand_shape = bn_train_input_type_tensor.getShape(); // The mean and variance are each 1 dimensional arrays the size of the @@ -2476,8 +2370,8 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { SmallVector operand_types = {bn_train_input_type_tensor, mean_var_type, mean_var_type}; auto bn_train_op = rewriter.create( - op.getLoc(), operand_types, bn_train_input, op.scale(), op.offset(), - op.epsilon(), feature_dim.getInt()); + op.getLoc(), operand_types, bn_train_input, op.getScale(), + op.getOffset(), op.getEpsilon(), feature_dim.getInt()); // HLO op outputs a tuple of tensors. Extract those results. Value y_out = bn_train_op.getResult(0); Value batch_mean = bn_train_op.getResult(1); @@ -2504,7 +2398,7 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { input_element_type); float exponential_avg_factor = - op.exponential_avg_factor().convertToFloat(); + op.getExponentialAvgFactor().convertToFloat(); if (exponential_avg_factor != 1.0f) { auto alpha = rewriter.create( op.getLoc(), rewriter.getFloatAttr(mean_element_type, @@ -2515,7 +2409,7 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { // new_running_mean = alpha * old_mean + beta * batch_mean. 
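// The training branch above folds the batch statistics into the running
// statistics with an exponential moving average when
// exponential_avg_factor != 1.0. Scalar sketch of that update, assuming
// alpha = 1 - factor and beta = factor as the surrounding code suggests:
float UpdateRunningStat(float old_stat, float batch_stat,
                        float exponential_avg_factor) {
  const float alpha = 1.0f - exponential_avg_factor;
  const float beta = exponential_avg_factor;
  // new_running_stat = alpha * old_stat + beta * batch_stat
  return alpha * old_stat + beta * batch_stat;
}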
auto alpha_mul_old_mean = rewriter.create( - op.getLoc(), op.mean().getType(), alpha, op.mean(), + op.getLoc(), op.getMean().getType(), alpha, op.getMean(), /*broadcast_dimensions=*/DenseIntElementsAttr()); auto beta_mul_batch_mean = rewriter.create( op.getLoc(), batch_mean.getType(), beta, batch_mean, @@ -2526,7 +2420,7 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { // new_running_variance = alpha * old_variance + beta * batch_variance. auto alpha_mul_old_variance = rewriter.create( - op.getLoc(), op.variance().getType(), alpha, op.variance(), + op.getLoc(), op.getVariance().getType(), alpha, op.getVariance(), /*broadcast_dimensions=*/DenseIntElementsAttr()); auto beta_mul_batch_variance = rewriter.create( op.getLoc(), corrected_variance.getType(), beta, corrected_variance, @@ -2571,8 +2465,8 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { auto bn_train_op = rewriter.create( op.getLoc(), /*result_type=*/bn_train_input_type_tensor, bn_train_input, - op.scale(), op.offset(), op.mean(), op.variance(), op.epsilon(), - feature_dim.getInt()); + op.getScale(), op.getOffset(), op.getMean(), op.getVariance(), + op.getEpsilon(), feature_dim.getInt()); // Convert back to input type to stay aligned with expected output type // for TF op. @@ -2586,10 +2480,10 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { // reserved_space_2. if (std::is_same::value) { rewriter.replaceOp(op, {/*y=*/y_out, - /*batch_mean=*/op.mean(), - /*batch_variance=*/op.variance(), - /*reserve_space_1=*/op.mean(), - /*reserve_space_2=*/op.variance()}); + /*batch_mean=*/op.getMean(), + /*batch_variance=*/op.getVariance(), + /*reserve_space_1=*/op.getMean(), + /*reserve_space_2=*/op.getVariance()}); } else { // For FusedBatchNormV3Op, also create a constant tensor to forward to // last reserve_space_3 output. @@ -2606,10 +2500,10 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { dummy_const = rewriter.create( op.getLoc(), reserve_space_3_type, dummy_const); rewriter.replaceOp(op, {/*y=*/y_out, - /*batch_mean=*/op.mean(), - /*batch_variance=*/op.variance(), - /*reserve_space_1=*/op.mean(), - /*reserve_space_2=*/op.variance(), + /*batch_mean=*/op.getMean(), + /*batch_variance=*/op.getVariance(), + /*reserve_space_1=*/op.getMean(), + /*reserve_space_2=*/op.getVariance(), /*reserve_space_3=*/dummy_const}); } } @@ -2694,7 +2588,7 @@ Operation *AvgPoolDivideByCount( RankedTensorType orig_input_type = tensorflow::GetTypeFromTFTensorShape(input_shape, element_type); - if (op.padding() == "VALID") { + if (op.getPadding() == "VALID") { // All window counts are equal here because we don't have padding // (each entry of `pooled` corresponds to a window that consists of // original input entries only). @@ -2707,7 +2601,7 @@ Operation *AvgPoolDivideByCount( result = rewriter.create( loc, pooled_type, pooled, divisor, scalar_broadcast_dims); } else { - assert(op.padding() == "SAME"); + assert(op.getPadding() == "SAME"); // For SAME padding, only original entries that contributed to a window // are counted for the average of this window, not padded entries. @@ -2717,8 +2611,9 @@ Operation *AvgPoolDivideByCount( // Get padding for the input. 
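// For SAME padding, AvgPoolDivideByCount divides each output element by the
// number of original (non-padded) inputs its window covers, which is what the
// ReduceWindowOp over an all-ones tensor computes; for VALID padding the
// divisor is simply the product of ksize. A 1-D sketch of the SAME-padding
// window counts; illustrative only.
#include <algorithm>
#include <cstdint>
#include <vector>

std::vector<int64_t> WindowCounts1D(int64_t input_size, int64_t ksize,
                                    int64_t stride, int64_t pad_low,
                                    int64_t pad_high) {
  const int64_t padded = input_size + pad_low + pad_high;
  std::vector<int64_t> counts;
  for (int64_t start = 0; start + ksize <= padded; start += stride) {
    // Intersect the window with the original input range
    // [pad_low, pad_low + input_size).
    const int64_t lo = std::max(start, pad_low);
    const int64_t hi = std::min(start + ksize, pad_low + input_size);
    counts.push_back(std::max<int64_t>(0, hi - lo));
  }
  return counts;
}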
DenseIntElementsAttr input_padding_attr = - GetReduceWindowPaddingAsAttr( - input_shape, op.ksize(), op.strides(), op.padding(), &rewriter); + GetReduceWindowPaddingAsAttr(input_shape, op.getKsize(), + op.getStrides(), op.getPadding(), + &rewriter); // Count the 1's in each window, using the same padding as for the input, // which gives us the window counts by which `pooled` needs to be divided. @@ -2726,8 +2621,8 @@ Operation *AvgPoolDivideByCount( loc, pooled_type, /*operand=*/all_ones_tensor, /*init_value=*/zero, - /*window_dimensions=*/GetI64ElementsAttr(op.ksize()), - /*window_strides=*/GetI64ElementsAttr(op.strides()), + /*window_dimensions=*/GetI64ElementsAttr(op.getKsize()), + /*window_strides=*/GetI64ElementsAttr(op.getStrides()), /*base_dilations=*/DenseIntElementsAttr(), /*window_dilations=*/DenseIntElementsAttr(), /*padding=*/input_padding_attr); @@ -2740,8 +2635,8 @@ Operation *AvgPoolDivideByCount( return result; } -Value GetAvgPoolInput(TF::AvgPoolOp op) { return op.value(); } -Value GetAvgPoolInput(TF::AvgPool3DOp op) { return op.input(); } +Value GetAvgPoolInput(TF::AvgPoolOp op) { return op.getValue(); } +Value GetAvgPoolInput(TF::AvgPool3DOp op) { return op.getInput(); } // Converts AvgPool op to HLO ReduceWindow op by setting appropriate window // dimensions with add as the reduction function. The reduction result is @@ -2779,11 +2674,11 @@ class ConvertAvgPoolOp : public OpRewritePattern { Value init = GetScalarConstOfType(sum_element_type, op.getLoc(), 0, &rewriter); DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr( - input_type.getShape(), op.ksize(), op.strides(), op.padding(), + input_type.getShape(), op.getKsize(), op.getStrides(), op.getPadding(), &rewriter); auto reduce = rewriter.create( op.getLoc(), result_type, input_value, init, - GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()), + GetI64ElementsAttr(op.getKsize()), GetI64ElementsAttr(op.getStrides()), /*base_dilations=*/DenseIntElementsAttr(), /*window_dilations=*/DenseIntElementsAttr(), paddings_attr); BuildReduceBody(sum_element_type, &reduce.getBody(), &rewriter); @@ -2793,8 +2688,8 @@ class ConvertAvgPoolOp : public OpRewritePattern { SmallVector input_shape( llvm::to_vector(input_type.getShape())); SmallVector ksize, strides; - GetI64ArrayAttrValues(op.ksize(), &ksize); - GetI64ArrayAttrValues(op.strides(), &strides); + GetI64ArrayAttrValues(op.getKsize(), &ksize); + GetI64ArrayAttrValues(op.getStrides(), &strides); Operation *result_op = AvgPoolDivideByCount( reduce.getResult(0), input_shape, ksize, strides, op, init, rewriter); @@ -2866,12 +2761,12 @@ class ConvertAvgPoolGradOp : public OpRewritePattern { PatternRewriter &rewriter) const override { Location loc = op.getLoc(); tensorflow::TensorFormat data_format; - if (!FormatFromString(op.data_format().str(), &data_format)) { + if (!FormatFromString(op.getDataFormat().str(), &data_format)) { return op.emitOpError("invalid data format"); } // `out_grad` is the gradient that was propagated via backpropagation from // the output layer. 
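// A 1-D, VALID-padding sketch of the AvgPoolGrad recipe used above: each
// output gradient is divided by the window element count and added back to
// every input position its window covered. The real pattern expresses the
// scatter as padding plus a summing ReduceWindowOp; this loop form is only an
// illustration of the arithmetic.
#include <cstdint>
#include <vector>

std::vector<float> AvgPoolGrad1D(const std::vector<float> &out_grad,
                                 int64_t input_size, int64_t ksize,
                                 int64_t stride) {
  std::vector<float> in_grad(input_size, 0.0f);
  for (int64_t o = 0; o < static_cast<int64_t>(out_grad.size()); ++o) {
    const float share = out_grad[o] / static_cast<float>(ksize);
    for (int64_t k = 0; k < ksize; ++k) {
      const int64_t i = o * stride + k;
      if (i < input_size) in_grad[i] += share;
    }
  }
  return in_grad;
}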
- Value out_grad = op.grad(); + Value out_grad = op.getGrad(); auto out_grad_type = out_grad.getType().template dyn_cast(); if (!out_grad_type) { @@ -2879,7 +2774,7 @@ class ConvertAvgPoolGradOp : public OpRewritePattern { } Type element_type = out_grad_type.getElementType(); DenseIntElementsAttr orig_input_shape_attr; - if (!matchPattern(op.orig_input_shape(), + if (!matchPattern(op.getOrigInputShape(), m_Constant(&orig_input_shape_attr))) { return failure(); } @@ -2887,8 +2782,8 @@ class ConvertAvgPoolGradOp : public OpRewritePattern { DimVector orig_input_shape(orig_input_shape_values.begin(), orig_input_shape_values.end()); DimVector ksize, strides; - GetI64ArrayAttrValues(op.ksize(), &ksize); - GetI64ArrayAttrValues(op.strides(), &strides); + GetI64ArrayAttrValues(op.getKsize(), &ksize); + GetI64ArrayAttrValues(op.getStrides(), &strides); Value zero = GetScalarConstOfType(element_type, loc, 0, &rewriter); auto out_grad_divided = AvgPoolDivideByCount( @@ -2896,7 +2791,8 @@ class ConvertAvgPoolGradOp : public OpRewritePattern { // Get same padding as for original input. PaddingArray orig_padding = GetReduceWindowPaddingAsArray( - orig_input_shape, op.ksize(), op.strides(), op.padding(), &rewriter); + orig_input_shape, op.getKsize(), op.getStrides(), op.getPadding(), + &rewriter); // Add padding around `out_grad_divided` values in such a way that the // subsequent `ReduceWindowOp` produces the gradient. @@ -2969,7 +2865,7 @@ class ConvertAvgPoolGradOp : public OpRewritePattern { sum_element_type), /*operand=*/reduce_window_input, /*init_value=*/zero, - /*window_dimensions=*/GetI64ElementsAttr(op.ksize()), + /*window_dimensions=*/GetI64ElementsAttr(op.getKsize()), /*window_strides=*/ones, /*base_dilations=*/DenseIntElementsAttr(), /*window_dilations=*/DenseIntElementsAttr(), @@ -3009,10 +2905,10 @@ class ConvertMaxPoolOp : public OpRewritePattern { LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { Type element_type = - op.input().getType().template cast().getElementType(); + op.getInput().getType().template cast().getElementType(); if (!element_type.isSignlessIntOrFloat()) return failure(); tensorflow::Padding padding; - if (!GetPaddingFromString(op.padding().str(), &padding).ok()) + if (!GetPaddingFromString(op.getPadding().str(), &padding).ok()) return failure(); if (padding == tensorflow::Padding::EXPLICIT) { return failure(); @@ -3021,13 +2917,15 @@ class ConvertMaxPoolOp : public OpRewritePattern { ConstantOp init = GetScalarLimitConstOfType( element_type, loc, hlo::kInfinityLowest, &rewriter); - auto input_ty = op.input().getType().template dyn_cast(); + auto input_ty = + op.getInput().getType().template dyn_cast(); if (!input_ty) return failure(); DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr( - input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter); + input_ty.getShape(), op.getKsize(), op.getStrides(), op.getPadding(), + &rewriter); auto reduce = rewriter.create( - loc, op.getType(), op.input(), init, GetI64ElementsAttr(op.ksize()), - GetI64ElementsAttr(op.strides()), + loc, op.getType(), op.getInput(), init, + GetI64ElementsAttr(op.getKsize()), GetI64ElementsAttr(op.getStrides()), /*base_dilations=*/DenseIntElementsAttr(), /*window_dilations=*/DenseIntElementsAttr(), paddings_attr); BuildReduceBody(element_type, &reduce.getBody(), &rewriter); @@ -3049,17 +2947,17 @@ class ConvertSelectOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::SelectOp op, PatternRewriter &rewriter) const override { // 
This lowering only works on ranked types. - auto cond_type = op.condition().getType().dyn_cast(); - auto then_type = op.then_value().getType().dyn_cast(); - auto else_type = op.else_value().getType().dyn_cast(); + auto cond_type = op.getCondition().getType().dyn_cast(); + auto then_type = op.getThenValue().getType().dyn_cast(); + auto else_type = op.getElseValue().getType().dyn_cast(); if (!cond_type || !then_type || !else_type) { return failure(); } ImplicitLocOpBuilder b(op.getLoc(), rewriter); - Value cond_shape = b.createOrFold(op.condition()); - Value then_shape = b.createOrFold(op.then_value()); - Value else_shape = b.createOrFold(op.else_value()); + Value cond_shape = b.createOrFold(op.getCondition()); + Value then_shape = b.createOrFold(op.getThenValue()); + Value else_shape = b.createOrFold(op.getElseValue()); // First check that the `then` and `else` shapes are the equal. Value assumption = @@ -3095,7 +2993,7 @@ class ConvertSelectOp : public OpRewritePattern { b.createBlock(&assuming_op.getDoRegion()); // Broadcast the cond if necessary. - Value cond = op.condition(); + Value cond = op.getCondition(); if (needs_broadcast) { Value result_extents = b.create( GetExtentsTensorTypeFor(result_type), then_shape); @@ -3105,8 +3003,8 @@ class ConvertSelectOp : public OpRewritePattern { cond, result_extents, GetI64ElementsAttrForSeq(0, cond_type.getRank(), &b)); } - Value select = b.create(result_type, cond, op.then_value(), - op.else_value()); + Value select = b.create( + result_type, cond, op.getThenValue(), op.getElseValue()); b.create(select); rewriter.replaceOp(op, {assuming_op.getResult(0)}); return success(); @@ -3179,9 +3077,9 @@ class ConvertSliceOpDynamic : public OpRewritePattern { LogicalResult matchAndRewrite(TF::SliceOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - Value input = op.input(); - Value begin_indices = op.begin(); - Value sizes = op.size(); + Value input = op.getInput(); + Value begin_indices = op.getBegin(); + Value sizes = op.getSize(); auto input_ty = input.getType().dyn_cast(); auto begin_type = begin_indices.getType().dyn_cast(); @@ -3196,7 +3094,7 @@ class ConvertSliceOpDynamic : public OpRewritePattern { // TODO(disc): remove static shape check once folding/canonicalization func // added DenseIntElementsAttr size_attr; - if (matchPattern(op.size(), m_Constant(&size_attr)) && input_ty.hasStaticShape() + if (matchPattern(op.getSize(), m_Constant(&size_attr)) && input_ty.hasStaticShape() && result_ty.hasStaticShape()) { return failure(); } @@ -3340,15 +3238,15 @@ class ConvertBatchMatMulV2Op : public OpRewritePattern { LogicalResult matchAndRewrite(TF::BatchMatMulV2Op op, PatternRewriter &rewriter) const override { - Value lhs = op.x(); - Value rhs = op.y(); + Value lhs = op.getX(); + Value rhs = op.getY(); auto lhs_type = lhs.getType().dyn_cast(); auto rhs_type = rhs.getType().dyn_cast(); if (!lhs_type || !rhs_type) return failure(); - if (lhs_type.getElementType().isa() && op.adj_x()) { + if (lhs_type.getElementType().isa() && op.getAdjX()) { lhs = rewriter.create(op.getLoc(), lhs_type, lhs); } - if (rhs_type.getElementType().isa() && op.adj_y()) { + if (rhs_type.getElementType().isa() && op.getAdjY()) { rhs = rewriter.create(op.getLoc(), rhs_type, rhs); } @@ -3361,9 +3259,9 @@ class ConvertBatchMatMulV2Op : public OpRewritePattern { int64_t rank = lhs_type.getRank(); auto batch_dimensions = llvm::to_vector<4>(llvm::seq(0, rank - 2)); auto lhs_contracting_dimensions = llvm::to_vector<4>( - llvm::makeArrayRef({op.adj_x() ? 
rank - 2 : rank - 1})); + llvm::ArrayRef({op.getAdjX() ? rank - 2 : rank - 1})); auto rhs_contracting_dimensions = llvm::to_vector<4>( - llvm::makeArrayRef({op.adj_y() ? rank - 1 : rank - 2})); + llvm::ArrayRef({op.getAdjY() ? rank - 1 : rank - 2})); auto dimension_numbers = DotDimensionNumbersAttr::get( rewriter.getContext(), /*lhs_batching_dimensions=*/batch_dimensions, @@ -3385,15 +3283,15 @@ class ConvertBatchMatMulOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::BatchMatMulOp op, PatternRewriter &rewriter) const override { - Value lhs = op.x(); - Value rhs = op.y(); + Value lhs = op.getX(); + Value rhs = op.getY(); auto lhs_type = lhs.getType().dyn_cast(); auto rhs_type = rhs.getType().dyn_cast(); if (!lhs_type || !rhs_type) return failure(); - if (lhs_type.getElementType().isa() && op.adj_x()) { + if (lhs_type.getElementType().isa() && op.getAdjX()) { lhs = rewriter.create(op.getLoc(), lhs_type, lhs); } - if (rhs_type.getElementType().isa() && op.adj_y()) { + if (rhs_type.getElementType().isa() && op.getAdjY()) { rhs = rewriter.create(op.getLoc(), rhs_type, rhs); } @@ -3406,8 +3304,8 @@ class ConvertBatchMatMulOp : public OpRewritePattern { rewriter.getContext(), /*lhs_batching_dimensions=*/batch_dimensions, /*rhs_batching_dimensions=*/batch_dimensions, - /*lhs_contracting_dimensions=*/{op.adj_x() ? rank - 2 : rank - 1}, - /*rhs_contracting_dimensions=*/{op.adj_y() ? rank - 1 : rank - 2}); + /*lhs_contracting_dimensions=*/{op.getAdjX() ? rank - 2 : rank - 1}, + /*rhs_contracting_dimensions=*/{op.getAdjY() ? rank - 1 : rank - 2}); // TODO(silvasean): Emit shape checks for contracting dimensions. // (The batch dimensions are checked by the broadcasting logic) rewriter.replaceOpWithNewOp(op, op.getType(), lhs, rhs, @@ -3457,12 +3355,12 @@ class ConvertSplitOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::SplitOp op, PatternRewriter &rewriter) const override { // We can only split along static dimensions. - auto input_type = op.value().getType().dyn_cast(); + auto input_type = op.getValue().getType().dyn_cast(); if (!input_type || !input_type.hasStaticShape()) return failure(); // We can only match when the split dimension is a constant scalar. DenseIntElementsAttr split_dim_attr; - if (!matchPattern(op.split_dim(), m_Constant(&split_dim_attr))) + if (!matchPattern(op.getSplitDim(), m_Constant(&split_dim_attr))) return failure(); // Get the dimension we are splitting at. Offset properly if it's negative. @@ -3472,9 +3370,6 @@ class ConvertSplitOp : public OpRewritePattern { // Calculate the dimension size for each slice along the split dimension. int64_t input_dim_size = input_type.getDimSize(dim_index); - // If we are splitting along the dynamic dimension then we cannot compute - // the static dimension length. - if (ShapedType::isDynamic(input_dim_size)) return failure(); int64_t num_splits = op.getNumResults(); int64_t slice_size = input_dim_size / num_splits; @@ -3487,7 +3382,7 @@ class ConvertSplitOp : public OpRewritePattern { // Parameters for constructing each slice. SmallVector begin_indices(input_rank, 0); - auto end_indices = tensorflow::ConvertMlirShapeToTF(input_type.getShape()); + auto end_indices = llvm::to_vector<4>(input_type.getShape()); SmallVector strides(input_rank, 1); // All HLO slice results used to replace the original tf.Split op. 
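// ConvertSplitOp above slices the split dimension into num_splits equal
// chunks; result i covers [i * slice_size, (i + 1) * slice_size). A small
// sketch of those bounds (assumes dim_size is divisible by num_splits, as the
// op requires):
#include <cstdint>
#include <utility>
#include <vector>

std::vector<std::pair<int64_t, int64_t>> SplitBounds(int64_t dim_size,
                                                     int64_t num_splits) {
  const int64_t slice_size = dim_size / num_splits;
  std::vector<std::pair<int64_t, int64_t>> bounds;
  for (int64_t i = 0; i < num_splits; ++i) {
    bounds.emplace_back(i * slice_size, (i + 1) * slice_size);
  }
  return bounds;
}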
@@ -3498,7 +3393,7 @@ class ConvertSplitOp : public OpRewritePattern { begin_indices[dim_index] = i * slice_size; end_indices[dim_index] = (i + 1) * slice_size; slices.push_back( - rewriter.create(op.getLoc(), slice_type, op.value(), + rewriter.create(op.getLoc(), slice_type, op.getValue(), GetI64ElementsAttr(begin_indices, &rewriter), GetI64ElementsAttr(end_indices, &rewriter), GetI64ElementsAttr(strides, &rewriter))); @@ -3520,12 +3415,19 @@ class ConvertSplitOpDynamic : public OpRewritePattern { LogicalResult matchAndRewrite(TF::SplitOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - Value input = op.value(); + Value input = op.getValue(); auto input_type = input.getType().dyn_cast(); if (!input_type) return failure(); + + // TODO(disc): remove static shape check once folding/canonicalization func + // added and ConvertSplitOp deleted. Calculate the dimension size for each + // slice along the split dimension. We are splitting along the dynamic + // dimension, or using static pattern transform + if (input_type.hasStaticShape()) return failure(); + // We can only match when the split dimension is a constant scalar. DenseIntElementsAttr split_dim_attr; - if (!matchPattern(op.split_dim(), m_Constant(&split_dim_attr))) + if (!matchPattern(op.getSplitDim(), m_Constant(&split_dim_attr))) return failure(); // Get the dimension we are splitting at. Offset properly if it's negative. @@ -3634,19 +3536,19 @@ class ConvertSplitVOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::SplitVOp op, PatternRewriter &rewriter) const override { - // We can only split along static dimensions. + // We can only split inputs that have fully static shape. // TODO(b/145731001): enhance to support dynamic-shaped inputs. - auto input_type = op.value().getType().dyn_cast(); - if (!input_type) return failure(); + auto input_type = op.getValue().getType().dyn_cast(); + if (!input_type || !input_type.hasStaticShape()) return failure(); // We can only match when the split dimension is a constant scalar. DenseIntElementsAttr split_dim_attr; - if (!matchPattern(op.split_dim(), m_Constant(&split_dim_attr))) + if (!matchPattern(op.getSplitDim(), m_Constant(&split_dim_attr))) return failure(); // We can only match when the split sizes is a constant int vector. DenseIntElementsAttr split_sizes_attr; - if (!matchPattern(op.size_splits(), m_Constant(&split_sizes_attr))) + if (!matchPattern(op.getSizeSplits(), m_Constant(&split_sizes_attr))) return failure(); // Get each chunck's size along the dimension to split. It may contain @@ -3674,8 +3576,6 @@ class ConvertSplitVOp : public OpRewritePattern { if (dim_index < 0) dim_index += input_rank; int64_t input_dim_size = input_type.getDimSize(dim_index); - if (ShapedType::isDynamic(input_dim_size)) return failure(); - assert(((dynamic_dim_index && total_dim_size <= input_dim_size) || (!dynamic_dim_index && total_dim_size == input_dim_size)) && "invalid split sizes"); @@ -3686,7 +3586,7 @@ class ConvertSplitVOp : public OpRewritePattern { // Parameters for constructing each slice. SmallVector begin_indices(input_rank, 0); - auto end_indices = tensorflow::ConvertMlirShapeToTF(input_type.getShape()); + auto end_indices = llvm::to_vector<4>(input_type.getShape()); SmallVector strides(input_rank, 1); // All HLO slice results used to replace the original tf.Split op. 
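// tf.SplitV allows at most one size_splits entry to be -1; per the assert
// above, that entry is resolved to whatever remains of the split dimension
// after the known chunk sizes are subtracted. Sketch of that resolution
// (assumed to match the elided code; illustrative only):
#include <cstdint>
#include <vector>

std::vector<int64_t> ResolveSplitSizes(std::vector<int64_t> split_sizes,
                                       int64_t dim_size) {
  int64_t known_total = 0;
  int64_t dynamic_index = -1;
  for (int64_t i = 0; i < static_cast<int64_t>(split_sizes.size()); ++i) {
    if (split_sizes[i] == -1) {
      dynamic_index = i;
    } else {
      known_total += split_sizes[i];
    }
  }
  if (dynamic_index >= 0) split_sizes[dynamic_index] = dim_size - known_total;
  return split_sizes;
}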
@@ -3696,7 +3596,8 @@ class ConvertSplitVOp : public OpRewritePattern { for (int i = 0, end = op.getNumResults(); i < end; ++i) { end_indices[dim_index] = begin_indices[dim_index] + split_sizes[i]; slices.push_back(rewriter.create( - op.getLoc(), op.value(), GetI64ElementsAttr(begin_indices, &rewriter), + op.getLoc(), op.getValue(), + GetI64ElementsAttr(begin_indices, &rewriter), GetI64ElementsAttr(end_indices, &rewriter), GetI64ElementsAttr(strides, &rewriter))); // Prepare the begin indice for the next slice. @@ -3777,10 +3678,10 @@ class ConvertStridedSliceOp : public OpRewritePattern { } Location loc = op.getLoc(); - Value input = op.input(); + Value input = op.getInput(); if (!dims_to_reverse.empty()) input = rewriter.create( - loc, input_ty, op.input(), + loc, input_ty, op.getInput(), GetI64ElementsAttr(dims_to_reverse, &rewriter)); auto sliced = rewriter.create( loc, input, GetI64ElementsAttr(hlo_begin_indices, &rewriter), @@ -3800,7 +3701,7 @@ class ConvertStridedSliceOp : public OpRewritePattern { // If begin and end values are dynamic, we can only support this lowering // if strides are a known value of 1. DenseIntElementsAttr sparse_strides_attr; - if (!matchPattern(op.strides(), m_Constant(&sparse_strides_attr))) { + if (!matchPattern(op.getStrides(), m_Constant(&sparse_strides_attr))) { return rewriter.notifyMatchFailure( op, "requires that strides are known when begin/end values are dynamic"); @@ -3821,7 +3722,7 @@ class ConvertStridedSliceOp : public OpRewritePattern { // When begin/end values are dynamic, the ellipsis mask, if set, must refer // to the last dimension. - int ellipsis_mask = op.ellipsis_mask(); + int ellipsis_mask = op.getEllipsisMask(); if (!(ellipsis_mask == 0 || ellipsis_mask == (1 << last_dim))) return rewriter.notifyMatchFailure( op, @@ -3838,9 +3739,9 @@ class ConvertStridedSliceOp : public OpRewritePattern { // Begin must be a ranked, 1-dimensional tensor: This is checked by the // verifier. int64_t slicing_dim_size = - op.begin().getType().cast().getDimSize(0); - uint64_t begin_mask = op.begin_mask(); - uint64_t end_mask = op.end_mask(); + op.getBegin().getType().cast().getDimSize(0); + uint64_t begin_mask = op.getBeginMask(); + uint64_t end_mask = op.getEndMask(); const int input_rank = input_shape.size(); for (int d = 0; d < input_rank; ++d) { // Each dimension is either sliced fully or has size of one. @@ -3860,7 +3761,7 @@ class ConvertStridedSliceOp : public OpRewritePattern { // For the dimensions that are to be sliced, all have slice sizes of 1. SmallVector slice_sizes; auto begin_element_ty = - op.begin().getType().cast().getElementType(); + op.getBegin().getType().cast().getElementType(); // Scalar tensor type. TensorType type = tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, begin_element_ty); @@ -3875,7 +3776,7 @@ class ConvertStridedSliceOp : public OpRewritePattern { } auto index = rewriter.create( - loc, op.begin(), GetI64ElementsAttr({d}, &rewriter), + loc, op.getBegin(), GetI64ElementsAttr({d}, &rewriter), GetI64ElementsAttr({d + 1}, &rewriter), GetI64ElementsAttr({1}, &rewriter)); // Convert index to scalar. @@ -3899,7 +3800,7 @@ class ConvertStridedSliceOp : public OpRewritePattern { // This must be an xla DynamicSlice op due to the inputs that aren't // constant. 
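// The dynamic-begin StridedSlice path above extracts one scalar begin index
// per dimension. A sketch of the usual strided-slice begin-index semantics it
// relies on for stride 1: a set begin_mask bit means "start at 0", negative
// indices count from the end, and the index is clamped into [0, dim_size].
// This is an assumed summary of the semantics, not code from the patch.
#include <algorithm>
#include <cstdint>

int64_t ResolveBeginIndex(int64_t begin, bool begin_mask_set,
                          int64_t dim_size) {
  if (begin_mask_set) return 0;
  if (begin < 0) begin += dim_size;
  return std::clamp<int64_t>(begin, 0, dim_size);
}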
auto sliced = rewriter.create( - loc, sliced_type, op.input(), slice_begin_indices, slice_sizes_attr); + loc, sliced_type, op.getInput(), slice_begin_indices, slice_sizes_attr); // Reshape slice result so that the shape is updated depending on // 'new_axis_mask' or 'shrink_axis_mask' attributes. @@ -3914,7 +3815,7 @@ class ConvertStridedSliceOp : public OpRewritePattern { // // TODO(hinsu): Relax this constraint for ops without negative indices and // strides. - auto input_ty = op.input().getType().dyn_cast(); + auto input_ty = op.getInput().getType().dyn_cast(); if (!input_ty || !input_ty.hasStaticShape()) return failure(); // Output shape needs to be static to apply 'new_axis_mask' or @@ -3925,8 +3826,8 @@ class ConvertStridedSliceOp : public OpRewritePattern { if (!result_ty || !result_ty.hasStaticShape()) return failure(); DenseIntElementsAttr sparse_begin_attr, sparse_end_attr; - if (!matchPattern(op.begin(), m_Constant(&sparse_begin_attr)) || - !matchPattern(op.end(), m_Constant(&sparse_end_attr))) { + if (!matchPattern(op.getBegin(), m_Constant(&sparse_begin_attr)) || + !matchPattern(op.getEnd(), m_Constant(&sparse_end_attr))) { // Disable this path and rely on ConvertStridedSliceOpDynamic to do the // conversion. // return rewriteWithUnknownBegin(op, input_ty, result_ty, rewriter); @@ -3957,7 +3858,7 @@ class ConvertStridedSliceGradOp PatternRewriter &rewriter) const override { // We need constant input shape to perform padding calculations later. DenseIntElementsAttr input_shape_attr; - if (!matchPattern(op.shape(), m_Constant(&input_shape_attr))) + if (!matchPattern(op.getShape(), m_Constant(&input_shape_attr))) return failure(); // We also need constant begin/end indices and strides to perform padding @@ -3970,7 +3871,7 @@ class ConvertStridedSliceGradOp &strides)) return failure(); - Value grad = op.dy(); + Value grad = op.getDy(); Type element_type = grad.getType().cast().getElementType(); // Perform reshape to undo any new/shrink axes done by strided slice. 
@@ -4345,14 +4246,17 @@ bool GetSlicedBoundRanges( SmallVectorImpl& input_shape_vec, SmallVectorImpl& output_shape_vec) { Location loc = op.getLoc(); - auto input_ty = op.input().getType().cast(); + auto input_ty = op.getInput().getType().cast(); int64_t rank = input_ty.getRank(); - Value in_shape = rewriter.create(loc, op.input()); + Value in_shape = rewriter.create(loc, op.getInput()); - int64_t sparse_rank = op.begin().getType().cast().getDimSize(0); - if (op.end().getType().cast().getDimSize(0) != sparse_rank) + int64_t sparse_rank = + op.getBegin().getType().cast().getDimSize(0); + if (op.getEnd().getType().cast().getDimSize(0) != + sparse_rank) return false; - if (op.strides().getType().cast().getDimSize(0) != sparse_rank) + if (op.getStrides().getType().cast().getDimSize(0) != + sparse_rank) return false; auto to_shape_scalar_type = [&](Value v) { @@ -4377,15 +4281,15 @@ bool GetSlicedBoundRanges( for (int64_t i = 0; i < sparse_rank; ++i) { Value idx = rewriter.create(loc, i); sparse_begin.push_back( - rewriter.create(loc, op.begin(), idx)); + rewriter.create(loc, op.getBegin(), idx)); sparse_end.push_back( - rewriter.create(loc, op.end(), idx)); + rewriter.create(loc, op.getEnd(), idx)); } CalculateSlicedShapeFromSparseIndices( - &rewriter, loc, shape_scalar_type, input_shape_vec, sparse_begin, sparse_end, - sparse_strides, op.begin_mask(), op.end_mask(), op.ellipsis_mask(), - op.new_axis_mask(), op.shrink_axis_mask(), + &rewriter, loc, shape_scalar_type, input_shape_vec, sparse_begin, + sparse_end, sparse_strides, op.getBeginMask(), op.getEndMask(), + op.getEllipsisMask(), op.getNewAxisMask(), op.getShrinkAxisMask(), &slice_begin, &slice_end, &slice_stride, &output_shape_vec, /*calc_final_shape*/ true); @@ -4418,7 +4322,7 @@ class ConvertStridedSliceOpDynamic : public OpRewritePattern loc, 1, elem_ty.getIntOrFloatBitWidth()); Value zero = rewriter.create( loc, 0, elem_ty.getIntOrFloatBitWidth()); - auto in_shape = rewriter.create(loc, op.input()); + auto in_shape = rewriter.create(loc, op.getInput()); for (int64_t i = 0; i < indices_elements; ++i) { Value idx = rewriter.create(loc, i); @@ -4468,16 +4372,16 @@ class ConvertStridedSliceOpDynamic : public OpRewritePattern LogicalResult matchAndRewrite(TF::StridedSliceOp op, PatternRewriter &rewriter) const override { // Only static rank case is supported a.t.m. - auto input_ty = op.input().getType().dyn_cast(); + auto input_ty = op.getInput().getType().dyn_cast(); if (!input_ty) return failure(); auto result_ty = op.getType().dyn_cast(); if (!result_ty) return failure(); // Only static shape begin/end/strides is supported a.t.m. - auto begin_ty = op.begin().getType().dyn_cast(); - auto end_ty = op.end().getType().dyn_cast(); - auto strides_ty = op.strides().getType().dyn_cast(); + auto begin_ty = op.getBegin().getType().dyn_cast(); + auto end_ty = op.getEnd().getType().dyn_cast(); + auto strides_ty = op.getStrides().getType().dyn_cast(); if (!begin_ty || !begin_ty.hasStaticShape() || !end_ty || !end_ty.hasStaticShape() || !strides_ty || !strides_ty.hasStaticShape()) @@ -4488,7 +4392,7 @@ class ConvertStridedSliceOpDynamic : public OpRewritePattern // TODO(disc): support negative indices. 
// TODO(disc): support dynamic stride DenseIntElementsAttr sparse_strides_attr; - if (!matchPattern(op.strides(), m_Constant(&sparse_strides_attr))) { + if (!matchPattern(op.getStrides(), m_Constant(&sparse_strides_attr))) { return rewriter.notifyMatchFailure(op, "requires that strides are constants"); } @@ -4519,12 +4423,12 @@ class ConvertStridedSliceOpDynamic : public OpRewritePattern "failed to calculate reverse dims"); } - Value input = op.input(); + Value input = op.getInput(); Location loc = op.getLoc(); if (!dims_to_reverse.empty()) { input = rewriter.create( - loc, input_ty, op.input(), - GetI64ElementsAttr(dims_to_reverse, &rewriter)); + loc, input_ty, op.getInput(), + GetI64ElementsAttr(dims_to_reverse, &rewriter)); } Value begin_vec = @@ -4533,8 +4437,8 @@ class ConvertStridedSliceOpDynamic : public OpRewritePattern rewriter.create(loc, hlo_end_indices); Value strides_vec = rewriter.create(loc, hlo_strides); - SmallVector slice_result_shape( - begin_indices.size(), ShapedType::kDynamicSize); + SmallVector slice_result_shape(begin_indices.size(), + ShapedType::kDynamic); RankedTensorType slice_result_type = RankedTensorType::get(slice_result_shape, input_ty.getElementType()); Value sliced = rewriter.create( @@ -4588,11 +4492,11 @@ class ConvertRangeOp : public OpRewritePattern { auto iota = rewriter.create(op.getLoc(), result_type, rewriter.getI64IntegerAttr(0)); auto scaled = rewriter.create( - op.getLoc(), result_type, iota, op.delta(), - hlo::getBroadcastDimensionsAttr(&rewriter, iota, op.delta())); + op.getLoc(), result_type, iota, op.getDelta(), + hlo::getBroadcastDimensionsAttr(&rewriter, iota, op.getDelta())); rewriter.replaceOpWithNewOp( - op, result_type, scaled, op.start(), - hlo::getBroadcastDimensionsAttr(&rewriter, scaled, op.start())); + op, result_type, scaled, op.getStart(), + hlo::getBroadcastDimensionsAttr(&rewriter, scaled, op.getStart())); return success(); } }; @@ -4619,9 +4523,9 @@ class ConvertDynamicRangeOp : public OpRewritePattern { return failure(); } - Value start = op.start(); - Value delta = op.delta(); - Value limit = op.limit(); + Value start = op.getStart(); + Value delta = op.getDelta(); + Value limit = op.getLimit(); // To compute the length we need to use floating point calculations so that // ceil can be computed for the number of steps. @@ -4710,7 +4614,7 @@ class ConvertLinSpaceOp : public OpRewritePattern { } DenseIntElementsAttr num_attr; - if (!matchPattern(op.num(), m_Constant(&num_attr))) { + if (!matchPattern(op.getNum(), m_Constant(&num_attr))) { return rewriter.notifyMatchFailure(op, "Num must be a constant scalar"); } @@ -4721,10 +4625,11 @@ class ConvertLinSpaceOp : public OpRewritePattern { // Calculate the scaling that needs to be applied to the iota. 
auto step_numerator = rewriter.create( - op.getLoc(), op.start().getType(), op.stop(), op.start(), - hlo::getBroadcastDimensionsAttr(&rewriter, op.stop(), op.start())); + op.getLoc(), op.getStart().getType(), op.getStop(), op.getStart(), + hlo::getBroadcastDimensionsAttr(&rewriter, op.getStop(), + op.getStart())); Value step_denominator = rewriter.create( - op.getLoc(), op.num(), result_type.getElementType()); + op.getLoc(), op.getNum(), result_type.getElementType()); if (num > 1) { Value one = GetScalarConstOfType(result_type.getElementType(), op.getLoc(), 1, &rewriter); @@ -4744,8 +4649,8 @@ class ConvertLinSpaceOp : public OpRewritePattern { op.getLoc(), result_type, iota, step, hlo::getBroadcastDimensionsAttr(&rewriter, iota, step)); rewriter.replaceOpWithNewOp( - op, result_type, scaled, op.start(), - hlo::getBroadcastDimensionsAttr(&rewriter, scaled, op.start())); + op, result_type, scaled, op.getStart(), + hlo::getBroadcastDimensionsAttr(&rewriter, scaled, op.getStart())); return success(); } }; @@ -4772,12 +4677,13 @@ class GenericConvertReductionOp : public OpRewritePattern { // TODO(b/141785544): Update this to not require ranked shapes. // Input shape needs to be ranked to convert negative indices in TensorFlow // to absolute indices required by HLO. - auto input_ty = op.input().getType().template dyn_cast(); + auto input_ty = + op.getInput().getType().template dyn_cast(); if (!input_ty) return failure(); ArrayRef input_shape = input_ty.getShape(); DenseIntElementsAttr dimensions; - if (!matchPattern(op.reduction_indices(), m_Constant(&dimensions))) + if (!matchPattern(op.getReductionIndices(), m_Constant(&dimensions))) return failure(); // Build the final shape from input_shape and dimensions using a bitmap @@ -4808,7 +4714,7 @@ class GenericConvertReductionOp : public OpRewritePattern { Type reduce_element_type = is_accumulation ? GetAccumulationType(element_type) : element_type; auto casted_input = - rewriter.create(loc, op.input(), reduce_element_type); + rewriter.create(loc, op.getInput(), reduce_element_type); // Each reduction op can have a different initial value. Value init = Derived::GetInitialValue(reduce_element_type, loc, &rewriter); @@ -4822,7 +4728,7 @@ class GenericConvertReductionOp : public OpRewritePattern { // The mean op needs to divide by the product of the reduced dimensions. if (std::is_same::value) { - Value in_shape = rewriter.create(loc, op.input()); + Value in_shape = rewriter.create(loc, op.getInput()); Value divisor_count = rewriter.create(loc, 1); for (size_t i = 0; i < input_shape.size(); ++i) { if (reduced_dimensions_bitmap[i]) { @@ -4855,7 +4761,7 @@ class GenericConvertReductionOp : public OpRewritePattern { // reshape. Various code generation techniques benefit from the knowledge // that this is a restricted form of shape manipulation that is just adding // unit dims. 
- if (op.keep_dims()) { + if (op.getKeepDims()) { for (auto &dim_is_reduced : llvm::enumerate(reduced_dimensions_bitmap)) { if (dim_is_reduced.value()) { auto index_attr = GetI32ElementsAttr( @@ -4999,7 +4905,7 @@ class ConvertArgMinMaxOp : public OpRewritePattern { LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { RankedTensorType input_type = - op.input().getType().template dyn_cast(); + op.getInput().getType().template dyn_cast(); if (!input_type) { return failure(); } @@ -5014,7 +4920,7 @@ class ConvertArgMinMaxOp : public OpRewritePattern { Derived::GetInitialValue(input_element_type, loc, rewriter); RankedTensorType output_type = - op.output().getType().template dyn_cast(); + op.getOutput().getType().template dyn_cast(); if (!output_type) { return rewriter.notifyMatchFailure(op, "requires known rank"); } @@ -5027,18 +4933,18 @@ class ConvertArgMinMaxOp : public OpRewritePattern { input_type.getShape(), index_element_type); llvm::Optional optional_axis = - GetIntegerHLOAxisFromTFAxis(op.dimension(), input_type.getRank()); + GetIntegerHLOAxisFromTFAxis(op.getDimension(), input_type.getRank()); if (!optional_axis.has_value()) return rewriter.notifyMatchFailure(op, "required axis"); - int64_t axis = optional_axis.getValue(); + int64_t axis = optional_axis.value(); IntegerAttr iota_dimension = IntegerAttr::get(rewriter.getIntegerType(64), axis); - Value input_shape = rewriter.create(loc, op.input()); + Value input_shape = rewriter.create(loc, op.getInput()); Value index_values = rewriter.create( loc, index_type, input_shape, iota_dimension); - Value operands[] = {op.input(), index_values}; + Value operands[] = {op.getInput(), index_values}; Value init_values[] = {init_value, index_init_value}; DenseIntElementsAttr reduction_dimensions = GetI64ElementsAttr({axis}, &rewriter); @@ -5111,11 +5017,11 @@ class ConvertTensorScatterOp : public OpRewritePattern { LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { auto tensor_ty = - op.tensor().getType().template dyn_cast(); + op.getTensor().getType().template dyn_cast(); auto indices_ty = - op.indices().getType().template dyn_cast(); + op.getIndices().getType().template dyn_cast(); auto updates_ty = - op.updates().getType().template dyn_cast(); + op.getUpdates().getType().template dyn_cast(); if (!tensor_ty || !indices_ty || !updates_ty) return failure(); // Last dimension of the indices needs to known at compile time for @@ -5124,7 +5030,7 @@ class ConvertTensorScatterOp : public OpRewritePattern { int64_t num_index_dims = indices_ty.getShape().back(); if (ShapedType::isDynamic(num_index_dims)) return failure(); - auto updates = op.updates(); + auto updates = op.getUpdates(); // Broadcast scalar `updates` in into expected shape as following shape: // updates.shape == indices.shape[:-1] + tensor.shape[indices.shape[-1]:] @@ -5161,11 +5067,11 @@ class ConvertTensorScatterOp : public OpRewritePattern { rewriter.create(op->getLoc(), const_type, const_attr); auto broadcast_to_type = tensorflow::GetTypeFromTFTensorShape( - llvm::makeArrayRef(expected_update_shape), + llvm::ArrayRef(expected_update_shape), updates_ty.getElementType()); updates = rewriter.create( - op->getLoc(), broadcast_to_type, op.updates(), const_op); + op->getLoc(), broadcast_to_type, op.getUpdates(), const_op); updates_ty = updates.getType().template dyn_cast(); } @@ -5185,9 +5091,9 @@ class ConvertTensorScatterOp : public OpRewritePattern { indices_rank - 1); Location loc = op.getLoc(); - auto scatter = 
rewriter.create(loc, op.getType(), - ValueRange(Value(op.tensor())), - op.indices(), updates, dims_attr); + auto scatter = rewriter.create( + loc, op.getType(), ValueRange(Value(op.getTensor())), op.getIndices(), + updates, dims_attr); Derived::BuildScatterBody(tensor_ty.getElementType(), &scatter.getUpdateComputation(), loc, rewriter); @@ -5304,13 +5210,13 @@ class ConvertTileOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::TileOp op, PatternRewriter &rewriter) const override { - auto input_ty = op.input().getType().dyn_cast(); + auto input_ty = op.getInput().getType().dyn_cast(); if (!input_ty || !input_ty.hasStaticShape()) return failure(); ArrayRef input_shape = input_ty.getShape(); Type element_type = input_ty.getElementType(); DenseIntElementsAttr multiples; - if (!matchPattern(op.multiples(), m_Constant(&multiples)) || + if (!matchPattern(op.getMultiples(), m_Constant(&multiples)) || multiples.getType().getRank() != 1) return failure(); @@ -5350,7 +5256,7 @@ class ConvertTileOp : public OpRewritePattern { Type output_type = op.getType(); Value result = rewriter.create( - loc, broadcasted_type, op.input(), + loc, broadcasted_type, op.getInput(), GetI64ElementsAttr(broadcast_dimensions, &rewriter)); if (output_type != broadcasted_type) { @@ -5383,8 +5289,8 @@ class ConvertTileOpDynamic : public OpRewritePattern { LogicalResult matchAndRewrite(TF::TileOp op, PatternRewriter &rewriter) const final { Location loc = op.getLoc(); - Value input = op.input(); - Value multiples = op.multiples(); + Value input = op.getInput(); + Value multiples = op.getMultiples(); auto input_ty = input.getType().dyn_cast(); if (!input_ty) return failure(); auto result_ty = op.getType().dyn_cast(); @@ -5398,7 +5304,7 @@ class ConvertTileOpDynamic : public OpRewritePattern { SmallVector input_shape_values; for (int64_t i = 0; i < input_rank; ++i) { auto dim_size = input_ty.getDimSize(i); - if (dim_size == ShapedType::kDynamicSize) { + if (dim_size == ShapedType::kDynamic) { input_shape_values.push_back( rewriter.create(loc, input, i)); } else { @@ -5444,7 +5350,7 @@ class ConvertTileOpDynamic : public OpRewritePattern { {static_cast(out_dim_size.size())}, index_ty), out_dim_size); SmallVector broadcast_shape(input_rank * 2, - ShapedType::kDynamicSize); + ShapedType::kDynamic); RankedTensorType broadcast_type = tensorflow::GetTypeFromTFTensorShape(broadcast_shape, element_type); Value broadcast = rewriter.create( @@ -5476,22 +5382,25 @@ class ConvertMaxPoolGradOp : public OpRewritePattern { PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - Type element_type = - op.orig_input().getType().template cast().getElementType(); + Type element_type = op.getOrigInput() + .getType() + .template cast() + .getElementType(); // Compute paddings using the original input and kernel shape and strides. // Here, ReduceWindow op as used as the MaxPool op is lowered to the // ReduceWindow op. 
auto input_ty = - op.orig_input().getType().template dyn_cast(); + op.getOrigInput().getType().template dyn_cast(); if (!input_ty) return failure(); DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr( - input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter); + input_ty.getShape(), op.getKsize(), op.getStrides(), op.getPadding(), + &rewriter); auto result = rewriter.create( - loc, op.getType(), op.orig_input(), op.grad(), + loc, op.getType(), op.getOrigInput(), op.getGrad(), GetScalarConstOfType(element_type, loc, 0, &rewriter), - GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()), + GetI64ElementsAttr(op.getKsize()), GetI64ElementsAttr(op.getStrides()), paddings_attr); BuildReduceBody(element_type, &result.getScatter(), &rewriter); @@ -5533,19 +5442,19 @@ class ConvertConvBackpropInputOp : public OpRewritePattern { PatternRewriter &rewriter) const override { // Unpack all of the attributes. tensorflow::TensorFormat data_format; - if (!FormatFromString(op.data_format().str(), &data_format)) + if (!FormatFromString(op.getDataFormat().str(), &data_format)) return op.emitOpError("invalid data format"); constexpr int num_dims = num_spatial_dims + 2; int batch_dim = GetTensorBatchDimIndex(num_dims, data_format); tensorflow::Padding padding; - if (!GetPaddingFromString(op.padding().str(), &padding).ok()) + if (!GetPaddingFromString(op.getPadding().str(), &padding).ok()) return failure(); auto out_backprop_ty = - op.out_backprop().getType().template dyn_cast(); + op.getOutBackprop().getType().template dyn_cast(); auto filter_ty = - op.filter().getType().template dyn_cast(); + op.getFilter().getType().template dyn_cast(); // With the exception of out_backprop's batch dimension, out_backprop and // filter need to have static shape. Filter is validated here, out_backprop @@ -5561,24 +5470,24 @@ class ConvertConvBackpropInputOp : public OpRewritePattern { // "tf.Pack"(%142, %cst_301, %cst_301, %cst_300) {axis = 0 : i64, ...} std::vector input_shape; DenseIntElementsAttr input_shape_attr; - if (matchPattern(op.input_sizes(), m_Constant(&input_shape_attr)) && + if (matchPattern(op.getInputSizes(), m_Constant(&input_shape_attr)) && input_shape_attr.getType().getRank() == 1) { input_shape.insert(input_shape.end(), input_shape_attr.getValues().begin(), input_shape_attr.getValues().end()); } else { - auto pack = op.input_sizes().template getDefiningOp(); - if (!pack || pack.axis() != 0) return failure(); + auto pack = op.getInputSizes().template getDefiningOp(); + if (!pack || pack.getAxis() != 0) return failure(); auto pack_ty = pack.getType().template dyn_cast(); if (!pack_ty || pack_ty.getRank() != 1) return failure(); for (auto i = 0; i < pack_ty.getDimSize(0); ++i) { if (i == batch_dim) { // We don't use the batch dimension below, so we don't care about // its size. Might as well populate it with -1. 
- input_shape.push_back(ShapedType::kDynamicSize); + input_shape.push_back(ShapedType::kDynamic); } else { DenseIntElementsAttr input_dims_attr; - if (matchPattern(pack.values()[i], m_Constant(&input_dims_attr)) && + if (matchPattern(pack.getValues()[i], m_Constant(&input_dims_attr)) && input_dims_attr.getType().getRank() == 0) { input_shape.push_back(input_dims_attr.getSplatValue()); } else { @@ -5588,11 +5497,11 @@ class ConvertConvBackpropInputOp : public OpRewritePattern { } } - auto dilations_attr = GetI64ElementsAttr(op.dilations()); + auto dilations_attr = GetI64ElementsAttr(op.getDilations()); std::vector dilations{ dilations_attr.template getValues().begin(), dilations_attr.template getValues().end()}; - auto strides_attr = GetI64ElementsAttr(op.strides()); + auto strides_attr = GetI64ElementsAttr(op.getStrides()); std::vector strides{ strides_attr.template getValues().begin(), strides_attr.template getValues().end()}; @@ -5626,9 +5535,9 @@ class ConvertConvBackpropInputOp : public OpRewritePattern { // Prepare metadata indexed by spatial_dim for computing pad_before // and pad_after. int64_t input_size = input_shape[spatial_dim]; - if (input_size == ShapedType::kDynamicSize) return failure(); + if (input_size == ShapedType::kDynamic) return failure(); int64_t output_size = out_backprop_ty.getDimSize(spatial_dim); - if (output_size == ShapedType::kDynamicSize) return failure(); + if (output_size == ShapedType::kDynamic) return failure(); int64_t filter_size = filter_ty.getDimSize(i); int64_t stride = strides[spatial_dim]; int64_t dilation = dilations[spatial_dim]; @@ -5665,12 +5574,12 @@ class ConvertConvBackpropInputOp : public OpRewritePattern { {num_spatial_dims, 2}, rewriter.getIntegerType(64)); auto paddings_attr = DenseIntElementsAttr::get(paddings_ty, paddings); - Value filter = op.filter(); + Value filter = op.getFilter(); const int feature_dim = tensorflow::GetTensorFeatureDimIndex(num_dims, data_format); const int64_t in_depth = *(input_shape.begin() + feature_dim); - if (in_depth == ShapedType::kDynamicSize) return failure(); + if (in_depth == ShapedType::kDynamic) return failure(); const int64_t filter_in_depth = filter_shape[num_spatial_dims]; const int64_t feature_group_count = in_depth / filter_in_depth; @@ -5715,7 +5624,7 @@ class ConvertConvBackpropInputOp : public OpRewritePattern { // activation gradients // = gradients (with padding and dilation) mirrored_weights Value result = rewriter.create( - op.getLoc(), op.getType(), op.out_backprop(), filter, + op.getLoc(), op.getType(), op.getOutBackprop(), filter, /*window_strides=*/ GetI64ElementsAttrForValue(/*size=*/num_spatial_dims, /*val=*/1, &rewriter), @@ -5802,7 +5711,7 @@ class ConvertConvBackpropInputDynamic : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - tensorflow::Status ConvBackpropExtractAndVerifyDimensionDyn( + tensorflow::Status ConvBackpopExtractAndVerifyDimensionDyn( OpTy& op, PatternRewriter& rewriter, std::vector dilations, const std::vector& strides, tensorflow::Padding padding, Value padding_before, Value padding_after, int spatial_dim, @@ -5840,9 +5749,9 @@ class ConvertConvBackpropInputDynamic : public OpRewritePattern { }; Value idx = rewriter.create(loc, spatial_dim); dim->input_size = - rewriter.create(loc, op.input_sizes(), idx); - dim->filter_size = get_dim_value(op.filter(), filter_spatial_dim); - dim->output_size = get_dim_value(op.out_backprop(), spatial_dim); + rewriter.create(loc, op.getInputSizes(), idx); + dim->filter_size = 
get_dim_value(op.getFilter(), filter_spatial_dim); + dim->output_size = get_dim_value(op.getOutBackprop(), spatial_dim); dim->stride = strides[spatial_dim]; dim->dilation = dilations[spatial_dim]; @@ -5862,7 +5771,7 @@ class ConvertConvBackpropInputDynamic : public OpRewritePattern { dim->pad_after = sub_vals( sub_vals(padded_out_size, dim->expanded_output_size), dim->pad_before); - return tensorflow::Status::OK(); + return tensorflow::OkStatus(); } tensorflow::Status ConvBackpropComputeDimensionsV2Dyn( @@ -5896,11 +5805,11 @@ class ConvertConvBackpropInputDynamic : public OpRewritePattern { num_dims, "-dimensional"); } int batch_dim = GetTensorBatchDimIndex(num_dims, data_format); - dims->batch_size = get_dim_value(op.out_backprop(), batch_dim); + dims->batch_size = get_dim_value(op.getOutBackprop(), batch_dim); // TODO(feiwen) : check in_depth and out_depth - dims->in_depth = get_dim_value(op.filter(), num_dims - 2); - dims->out_depth = get_dim_value(op.filter(), num_dims - 1); + dims->in_depth = get_dim_value(op.getFilter(), num_dims - 2); + dims->out_depth = get_dim_value(op.getFilter(), num_dims - 1); // TODO(feiwen): to support grouped conv dims->spatial_dims.resize(num_spatial_dims); @@ -5912,11 +5821,11 @@ class ConvertConvBackpropInputDynamic : public OpRewritePattern { padding_before = get_const(explicit_paddings[2 * image_dim]); padding_after = get_const(explicit_paddings[2 * image_dim + 1]); } - TF_RETURN_IF_ERROR(ConvBackpropExtractAndVerifyDimensionDyn( + TF_RETURN_IF_ERROR(ConvBackpopExtractAndVerifyDimensionDyn( op, rewriter, dilations, strides, padding, padding_before, padding_after, image_dim, i, &dims->spatial_dims[i])); } - return tensorflow::Status::OK(); + return tensorflow::OkStatus(); } bool GetPaddingValues(OpTy& op, PatternRewriter& rewriter, Value input_size, @@ -6003,37 +5912,37 @@ class ConvertConvBackpropInputDynamic : public OpRewritePattern { // Unpack all of the attributes. 
tensorflow::TensorFormat data_format; - if (!FormatFromString(op.data_format().str(), &data_format)) + if (!FormatFromString(op.getDataFormat().str(), &data_format)) return op.emitOpError("invalid data format"); tensorflow::Padding padding; - if (!GetPaddingFromString(op.padding().str(), &padding).ok()) + if (!GetPaddingFromString(op.getPadding().str(), &padding).ok()) return failure(); auto input_sizes_ty = - op.input_sizes().getType().template dyn_cast(); + op.getInputSizes().getType().template dyn_cast(); constexpr int num_dims = num_spatial_dims + 2; if (!input_sizes_ty || input_sizes_ty.getRank() != 1 || input_sizes_ty.getShape()[0] != num_dims) return failure(); auto filter_ty = - op.filter().getType().template dyn_cast(); + op.getFilter().getType().template dyn_cast(); auto out_backprop_ty = - op.out_backprop().getType().template dyn_cast(); + op.getOutBackprop().getType().template dyn_cast(); DenseIntElementsAttr input_shape_attr; if (!out_backprop_ty || !filter_ty || (out_backprop_ty.hasStaticShape() && filter_ty.hasStaticShape() && - matchPattern(op.input_sizes(), m_Constant(&input_shape_attr)))) { + matchPattern(op.getInputSizes(), m_Constant(&input_shape_attr)))) { return failure(); } - auto dilations_attr = GetI64ElementsAttr(op.dilations()); + auto dilations_attr = GetI64ElementsAttr(op.getDilations()); std::vector dilations{ dilations_attr.template getValues().begin(), dilations_attr.template getValues().end()}; - auto strides_attr = GetI64ElementsAttr(op.strides()); + auto strides_attr = GetI64ElementsAttr(op.getStrides()); std::vector strides{ strides_attr.template getValues().begin(), strides_attr.template getValues().end()}; @@ -6077,7 +5986,7 @@ class ConvertConvBackpropInputDynamic : public OpRewritePattern { paddings.push_back(spatial_dim_i.pad_after); } - Value filter = op.filter(); + Value filter = op.getFilter(); // TODO(feiwen): support group conv const int feature_dim = @@ -6095,7 +6004,7 @@ class ConvertConvBackpropInputDynamic : public OpRewritePattern { tensorflow::GetTensorBatchDimIndex(num_dims, data_format); SmallVector operands; - operands.push_back(op.out_backprop()); + operands.push_back(op.getOutBackprop()); operands.push_back(filter); Value paddings_op = rewriter.create(op.getLoc(), paddings); @@ -6158,16 +6067,17 @@ class ConvertConvBackpropFilterOp : public OpRewritePattern { PatternRewriter &rewriter) const override { // Unpack all of the attributes. 
tensorflow::TensorFormat data_format; - if (!FormatFromString(op.data_format().str(), &data_format)) + if (!FormatFromString(op.getDataFormat().str(), &data_format)) return op.emitOpError("invalid data format"); tensorflow::Padding padding; - if (!GetPaddingFromString(op.padding().str(), &padding).ok()) + if (!GetPaddingFromString(op.getPadding().str(), &padding).ok()) return failure(); auto out_backprop_ty = - op.out_backprop().getType().template dyn_cast(); - auto input_ty = op.input().getType().template dyn_cast(); + op.getOutBackprop().getType().template dyn_cast(); + auto input_ty = + op.getInput().getType().template dyn_cast(); for (RankedTensorType ty : {out_backprop_ty, input_ty}) if (!ty || !ty.hasStaticShape()) return failure(); @@ -6176,15 +6086,15 @@ class ConvertConvBackpropFilterOp : public OpRewritePattern { ArrayRef input_shape = input_ty.getShape(); DenseIntElementsAttr filter_shape_attr; - if (!matchPattern(op.filter_sizes(), m_Constant(&filter_shape_attr)) || + if (!matchPattern(op.getFilterSizes(), m_Constant(&filter_shape_attr)) || filter_shape_attr.getType().getRank() != 1) return failure(); - auto dilations_attr = GetI64ElementsAttr(op.dilations()); + auto dilations_attr = GetI64ElementsAttr(op.getDilations()); std::vector dilations{ dilations_attr.template getValues().begin(), dilations_attr.template getValues().end()}; - auto strides_attr = GetI64ElementsAttr(op.strides()); + auto strides_attr = GetI64ElementsAttr(op.getStrides()); std::vector strides{ strides_attr.template getValues().begin(), strides_attr.template getValues().end()}; @@ -6313,7 +6223,7 @@ class ConvertConvBackpropFilterOp : public OpRewritePattern { tensorflow::GetTensorBatchDimIndex(num_dims, data_format); Value result = rewriter.create( - op.getLoc(), op.getType(), op.input(), op.out_backprop(), + op.getLoc(), op.getType(), op.getInput(), op.getOutBackprop(), /*window_strides=*/GetI64ElementsAttr(window_strides, &rewriter), /*padding=*/paddings_attr, /*lhs_dilation=*/ GetI64ElementsAttrForValue(/*size=*/num_spatial_dims, /*val=*/1, @@ -6430,7 +6340,7 @@ class ConvertConvBackpropFilterDynamic : public OpRewritePattern { return true; } - tensorflow::Status ConvBackpropExtractAndVerifyDimensionDyn( + tensorflow::Status ConvBackpopExtractAndVerifyDimensionDyn( OpTy& op, PatternRewriter& rewriter, const std::vector dilations, const std::vector& strides, tensorflow::Padding padding, Value padding_before, Value padding_after, int spatial_dim, @@ -6466,12 +6376,12 @@ class ConvertConvBackpropFilterDynamic : public OpRewritePattern { auto get_int = [](Attribute attr) { return attr.template cast().getInt(); }; - dim->input_size = get_dim_value(op.input(), spatial_dim); + dim->input_size = get_dim_value(op.getInput(), spatial_dim); Value idx = rewriter.create(loc, filter_spatial_dim); dim->filter_size = - rewriter.create(loc, op.filter_sizes(), idx); - dim->output_size = get_dim_value(op.out_backprop(), spatial_dim); + rewriter.create(loc, op.getFilterSizes(), idx); + dim->output_size = get_dim_value(op.getOutBackprop(), spatial_dim); dim->stride = strides[spatial_dim]; dim->dilation = dilations[spatial_dim]; int64_t out_size = 0; @@ -6490,7 +6400,7 @@ class ConvertConvBackpropFilterDynamic : public OpRewritePattern { dim->pad_before = sub_one(sub_vals(effective_filter_size, padding_before)); dim->pad_after = sub_vals( sub_vals(padded_out_size, dim->expanded_output_size), dim->pad_before); - return tensorflow::Status::OK(); + return tensorflow::OkStatus(); } tensorflow::Status 
ConvBackpropComputeDimensionsV2Dyn( @@ -6522,14 +6432,14 @@ class ConvertConvBackpropFilterDynamic : public OpRewritePattern { num_dims, "-dimensional"); } int batch_dim = GetTensorBatchDimIndex(num_dims, data_format); - dims->batch_size = get_dim_value(op.out_backprop(), batch_dim); + dims->batch_size = get_dim_value(op.getOutBackprop(), batch_dim); int feature_dim = GetTensorFeatureDimIndex(num_dims, data_format); // TODO(feiwen) : check in_depth and out_depth - dims->in_depth = get_dim_value(op.input(), feature_dim); + dims->in_depth = get_dim_value(op.getInput(), feature_dim); // The input and output feature dimensions are the second last and last // dimensions of the filter Tensor. - dims->out_depth = get_dim_value(op.out_backprop(), feature_dim); + dims->out_depth = get_dim_value(op.getOutBackprop(), feature_dim); dims->spatial_dims.resize(num_spatial_dims); for (int i = 0; i < num_spatial_dims; ++i) { @@ -6539,11 +6449,11 @@ class ConvertConvBackpropFilterDynamic : public OpRewritePattern { padding_before = get_const(explicit_paddings[2 * image_dim]); padding_after = get_const(explicit_paddings[2 * image_dim + 1]); } - TF_RETURN_IF_ERROR(ConvBackpropExtractAndVerifyDimensionDyn( + TF_RETURN_IF_ERROR(ConvBackpopExtractAndVerifyDimensionDyn( op, rewriter, dilations, strides, padding, padding_before, padding_after, image_dim, i, &dims->spatial_dims[i])); } - return tensorflow::Status::OK(); + return tensorflow::OkStatus(); } LogicalResult matchAndRewrite(OpTy op, @@ -6580,32 +6490,33 @@ class ConvertConvBackpropFilterDynamic : public OpRewritePattern { }; // Unpack all of the attributes. tensorflow::TensorFormat data_format; - if (!FormatFromString(op.data_format().str(), &data_format)) + if (!FormatFromString(op.getDataFormat().str(), &data_format)) return op.emitOpError("invalid data format"); tensorflow::Padding padding; - if (!GetPaddingFromString(op.padding().str(), &padding).ok()) + if (!GetPaddingFromString(op.getPadding().str(), &padding).ok()) return failure(); auto out_backprop_ty = - op.out_backprop().getType().template dyn_cast(); - auto input_ty = op.input().getType().template dyn_cast(); + op.getOutBackprop().getType().template dyn_cast(); + auto input_ty = + op.getInput().getType().template dyn_cast(); DenseIntElementsAttr filter_shape_attr; if (!out_backprop_ty || !input_ty || (out_backprop_ty.hasStaticShape() && input_ty.hasStaticShape() && - matchPattern(op.filter_sizes(), m_Constant(&filter_shape_attr)))) { + matchPattern(op.getFilterSizes(), m_Constant(&filter_shape_attr)))) { return failure(); } ArrayRef out_backprop_shape = out_backprop_ty.getShape(); ArrayRef input_shape = input_ty.getShape(); - auto dilations_attr = GetI64ElementsAttr(op.dilations()); + auto dilations_attr = GetI64ElementsAttr(op.getDilations()); std::vector dilations{ dilations_attr.template getValues().begin(), dilations_attr.template getValues().end()}; - auto strides_attr = GetI64ElementsAttr(op.strides()); + auto strides_attr = GetI64ElementsAttr(op.getStrides()); std::vector strides{ strides_attr.template getValues().begin(), strides_attr.template getValues().end()}; @@ -6727,8 +6638,8 @@ class ConvertConvBackpropFilterDynamic : public OpRewritePattern { tensorflow::GetTensorBatchDimIndex(num_dims, data_format); SmallVector operands; - operands.push_back(op.input()); - operands.push_back(op.out_backprop()); + operands.push_back(op.getInput()); + operands.push_back(op.getOutBackprop()); Value paddings_op = rewriter.create(op.getLoc(), paddings); operands.push_back(paddings_op); @@ -6786,18 
+6697,18 @@ class ConvertOneHotOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::OneHotOp op, PatternRewriter &rewriter) const override { - auto indices_ty = op.indices().getType().dyn_cast(); + auto indices_ty = op.getIndices().getType().dyn_cast(); if (!indices_ty || !indices_ty.hasStaticShape()) return failure(); ArrayRef indices_shape = indices_ty.getShape(); Type element_type = indices_ty.getElementType(); DenseIntElementsAttr depth_attr; - if (!matchPattern(op.depth(), m_Constant(&depth_attr))) { + if (!matchPattern(op.getDepth(), m_Constant(&depth_attr))) { return failure(); } int64_t depth = depth_attr.getValues()[0].getSExtValue(); - int64_t axis = op.axis(); + int64_t axis = op.getAxis(); if (axis == -1) axis = indices_shape.size(); llvm::SmallVector broadcast_dims(indices_shape.size()); @@ -6819,16 +6730,16 @@ class ConvertOneHotOp : public OpRewritePattern { auto iota = rewriter.create( loc, index_type, IntegerAttr::get(rewriter.getIntegerType(64), axis)); auto broadcast_indices = rewriter.create( - loc, index_type, op.indices(), + loc, index_type, op.getIndices(), GetI64ElementsAttr(broadcast_dims, &rewriter)); Value compare = rewriter.create( loc, broadcast_indices, iota, ComparisonDirection::EQ); Value on_value = rewriter.create( - loc, op.getType(), op.on_value(), + loc, op.getType(), op.getOnValue(), GetI64ElementsAttr(output_dims, &rewriter)); Value off_value = rewriter.create( - loc, op.getType(), op.off_value(), + loc, op.getType(), op.getOffValue(), GetI64ElementsAttr(output_dims, &rewriter)); Value result = rewriter.create(loc, op.getType(), compare, on_value, off_value); @@ -6867,8 +6778,8 @@ class ConvertInfeedDequeueTupleOp LogicalResult matchAndRewrite(TF::InfeedDequeueTupleOp op, PatternRewriter &rewriter) const override { SmallVector result_types; - result_types.reserve(op.outputs().size() + 1); - for (const auto &output : op.outputs()) { + result_types.reserve(op.getOutputs().size() + 1); + for (const auto &output : op.getOutputs()) { Type ty = output.getType(); if (auto tensor_ty = ty.dyn_cast()) { if (!tensor_ty.hasStaticShape()) return failure(); @@ -6891,11 +6802,11 @@ class ConvertInfeedDequeueTupleOp result_types.pop_back(); // remove the token type. - if (op._XlaSharding().has_value()) { + if (op.get_XlaSharding().has_value()) { // _XlaSharding attribute in TF is a serialized string of the OpSharding // proto, so convert to a text form here. 
::xla::OpSharding sharding_proto; - if (!sharding_proto.ParseFromString(op._XlaSharding().getValue().str())) + if (!sharding_proto.ParseFromString(op.get_XlaSharding().value().str())) return failure(); // Token is a control signal and not a real data, so arbitrarily assign @@ -6907,7 +6818,7 @@ class ConvertInfeedDequeueTupleOp kShardingAttr, rewriter.getStringAttr(sharding_proto.SerializeAsString())); } else { - data_and_token->setAttr(kShardingAttr, op._XlaShardingAttr()); + data_and_token->setAttr(kShardingAttr, op.get_XlaShardingAttr()); } } @@ -6954,7 +6865,7 @@ class ConvertOutfeedEnqueueTupleOp auto token_type = mhlo::TokenType::get(rewriter.getContext()); auto token = rewriter.create(op.getLoc(), token_type); - rewriter.create(op.getLoc(), token_type, op.inputs(), token, + rewriter.create(op.getLoc(), token_type, op.getInputs(), token, /*outfeed_config=*/rewriter.getStringAttr("")); rewriter.eraseOp(op); return success(); @@ -6970,17 +6881,17 @@ class ConvertTopKV2Op : public OpRewritePattern { PatternRewriter &rewriter) const override { // We can only match when the `k` operand is a constant scalar. DenseIntElementsAttr k_attr; - if (!matchPattern(op.k(), m_Constant(&k_attr))) return failure(); + if (!matchPattern(op.getK(), m_Constant(&k_attr))) return failure(); int64_t k = (*k_attr.begin()).getSExtValue(); - TensorType input_type = op.input().getType().cast(); + TensorType input_type = op.getInput().getType().cast(); if (!input_type.hasRank()) return failure(); int64_t input_rank = input_type.getRank(); int64_t last_dim_index = input_rank - 1; int64_t last_dim_size = input_type.getDimSize(last_dim_index); - if (last_dim_size == ShapedType::kDynamicSize) return failure(); + if (last_dim_size == ShapedType::kDynamic) return failure(); - rewriter.replaceOpWithNewOp(op, op.input(), k); + rewriter.replaceOpWithNewOp(op, op.getInput(), k); return success(); } }; @@ -6997,11 +6908,11 @@ class ConvertUnpackOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::UnpackOp op, PatternRewriter &rewriter) const override { - auto value_type = op.value().getType().dyn_cast(); + auto value_type = op.getValue().getType().dyn_cast(); if (!value_type || !value_type.hasStaticShape()) return failure(); int64_t value_rank = value_type.getRank(); - int64_t axis = op.axis(); + int64_t axis = op.getAxis(); if (axis < 0) axis += value_rank; // Parameters for constructing each slice. @@ -7018,13 +6929,14 @@ class ConvertUnpackOp : public OpRewritePattern { end_indices[axis] = i + 1; auto slice_op = rewriter.create( - op.getLoc(), op.value(), GetI64ElementsAttr(begin_indices, &rewriter), + op.getLoc(), op.getValue(), + GetI64ElementsAttr(begin_indices, &rewriter), GetI64ElementsAttr(end_indices, &rewriter), GetI64ElementsAttr(strides, &rewriter)); // Reshape to drop the axis dimension. - auto result = - rewriter.create(op.getLoc(), op.getType(i), slice_op, - rewriter.getI64ArrayAttr(op.axis())); + auto result = rewriter.create( + op.getLoc(), op.getType(i), slice_op, + rewriter.getI64ArrayAttr(op.getAxis())); results.push_back(result); } @@ -7042,14 +6954,14 @@ class ConvertUnpackOpDynamic : public OpRewritePattern { LogicalResult matchAndRewrite(TF::UnpackOp op, PatternRewriter &rewriter) const override { - auto value_type = op.value().getType().dyn_cast(); + auto value_type = op.getValue().getType().dyn_cast(); if (!value_type) return failure(); // TODO(disc): Remove this constraint once fold and canonicalization // implemented. 
if (value_type.hasStaticShape()) return failure(); int64_t value_rank = value_type.getRank(); - int64_t axis = op.axis(); + int64_t axis = op.getAxis(); if (axis < 0) axis += value_rank; Location loc = op.getLoc(); @@ -7063,10 +6975,10 @@ class ConvertUnpackOpDynamic : public OpRewritePattern { SmallVector shape_values; shape_values.reserve(value_rank - 1); // slice shape before reshape, should be like{?, 1, ?, ?} if axis = 1 - SmallVector slice_shape(value_rank, ShapedType::kDynamicSize); + SmallVector slice_shape(value_rank, ShapedType::kDynamic); for (int64_t dim_idx = 0; dim_idx < value_rank; ++dim_idx) { int64_t dim_size = value_type.getDimSize(dim_idx); - if (dim_size == ShapedType::kDynamicSize) { + if (dim_size == ShapedType::kDynamic) { Value dim_i = rewriter.create( loc, shape_scalar_type, rewriter.create(loc, op.getOperand(), dim_idx)); @@ -7101,7 +7013,7 @@ class ConvertUnpackOpDynamic : public OpRewritePattern { loc, tensorflow::GetTypeFromTFTensorShape(slice_shape, value_type.getElementType()), - op.value(), + op.getValue(), rewriter.create( loc, tensorflow::GetTypeFromTFTensorShape( @@ -7143,8 +7055,8 @@ class ConvertSigmoidGradOpDynamic : public OpRewritePattern { LogicalResult matchAndRewrite(TF::SigmoidGradOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - Value y = op.y(); - Value dy = op.dy(); + Value y = op.getY(); + Value dy = op.getDy(); auto tp_y = y.getType().dyn_cast(); auto tp_dy = dy.getType().dyn_cast(); if (!tp_y || !tp_dy) return failure(); @@ -7179,7 +7091,7 @@ class ConvertSigmoidGradOpDynamic : public OpRewritePattern { // Converts TF unsorted segment reduction ops to XLA HLO scatter op. // -// TF unsorted segment reduction op peforms the following calculation: +// TF unsorted segment reduction op performs the following calculation: // // Assume segment ids' shape is [SI0, SI1, ..., SIm] and data's shape is // [D0, D1, ..., Dn].
Note that segment ids' shape must be a prefix of data's @@ -7200,17 +7112,18 @@ class GenericConvertUnsortedSegmentReductionOp : public OpRewritePattern { LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { - auto data_type = op.data().getType().template dyn_cast(); + auto data_type = + op.getData().getType().template dyn_cast(); if (!data_type) return failure(); int64_t data_rank = data_type.getRank(); auto segment_ids_type = - op.segment_ids().getType().template dyn_cast(); + op.getSegmentIds().getType().template dyn_cast(); if (!segment_ids_type) return failure(); int64_t segment_ids_rank = segment_ids_type.getRank(); DenseIntElementsAttr num_segments_attr; - if (!matchPattern(op.num_segments(), m_Constant(&num_segments_attr))) + if (!matchPattern(op.getNumSegments(), m_Constant(&num_segments_attr))) return failure(); // The final shape for TF unsorted segment reduction op is [num_segments] + @@ -7243,7 +7156,7 @@ class GenericConvertUnsortedSegmentReductionOp : public OpRewritePattern { auto scatter = rewriter.create( op.getLoc(), op.getType(), ValueRange(Value(broadcasted_init)), - op.segment_ids(), op.data(), dims_attr); + op.getSegmentIds(), op.getData(), dims_attr); BuildReduceBody(data_type.getElementType(), &scatter.getUpdateComputation(), &rewriter); @@ -7320,19 +7233,24 @@ class ConvertRandomShuffleOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::RandomShuffleOp op, PatternRewriter &rewriter) const override { - auto input_type = op.value().getType().dyn_cast(); + auto no_op = [&]() { + rewriter.replaceOp(op, op.getValue()); + return success(); + }; + + auto input_type = op.getValue().getType().dyn_cast(); if (!input_type) return failure(); + if (input_type.hasStaticShape() && input_type.getNumElements() <= 1) + // No shuffling is required, so copy input directly to output. + return no_op(); int64_t input_rank = input_type.getRank(); int64_t first_dim_size = input_type.getDimSize(0); if (ShapedType::isDynamic(first_dim_size)) return failure(); - // We are shuffling along the first dimension. If its size is <= 1, then - // shuffling is a no-op. - if (first_dim_size <= 1) { - rewriter.replaceOp(op, op.value()); - return success(); - } + if (first_dim_size <= 1) + // No shuffling is required, so copy input directly to output. + return no_op(); // For vectors, shuffle values by sorting instead of the obvious // Fisher-Yates algorithm. Fisher-Yates is simple to implement and correct, @@ -7377,7 +7295,7 @@ class ConvertRandomShuffleOp : public OpRewritePattern { int rounds = std::ceil(exponent * std::log(num_elements) / std::log(u32_max)); - Value current = op.value(); + Value current = op.getValue(); for (int i = 0; i < rounds; ++i) { auto keys = CreateRngUniform32(op.getLoc(), num_elements, /*lower_limit=*/0, @@ -7433,11 +7351,11 @@ class ConvertRandomShuffleOp : public OpRewritePattern { // Then perform the swap. // indices[i] <- indices[swaps[i]] indices = builder->create( - loc, indices.getType(), indices, target_index, llvm::makeArrayRef(i)); + loc, indices.getType(), indices, target_index, llvm::ArrayRef(i)); // indices[swaps[i]] <- indices[i] indices = builder->create( loc, indices.getType(), indices, source_index, - llvm::makeArrayRef(swap_index)); + llvm::ArrayRef(swap_index)); // Update new values. 
new_values->assign({swaps, indices}); @@ -7458,9 +7376,32 @@ class ConvertRandomShuffleOp : public OpRewritePattern { /*collapsed_slice_dims=*/{0}, /*start_index_map=*/{0}, /*index_vector_dim=*/1); - rewriter.replaceOpWithNewOp( - op, op.getType(), op.value(), swaped_indices, dims_attr, - GetI64ElementsAttr(slice_sizes, &rewriter)); + + SmallVector slice_sizes_values; + for (auto i = 0; i < slice_sizes.size(); ++i) { + if (slice_sizes[i] == tensorflow::kTFDynamicSize) { + Value i_const = rewriter.create( + op.getLoc(), rewriter.getIndexAttr(i)); + Value slice_size_index = + rewriter.create(op.getLoc(), op.getValue(), i_const); + Value index_to_i64 = rewriter.create( + op.getLoc(), rewriter.getI64Type(), slice_size_index); + Value i64_to_tensor = rewriter.create( + op.getLoc(), + tensorflow::GetTypeFromTFTensorShape({1}, rewriter.getI64Type()), + index_to_i64); + slice_sizes_values.push_back(i64_to_tensor); + } else { + slice_sizes_values.push_back(rewriter.create( + op.getLoc(), GetI64ElementsAttr({slice_sizes[i]}, &rewriter))); + } + } + + auto slice_sizes_concat = rewriter.create( + op.getLoc(), slice_sizes_values, rewriter.getI64IntegerAttr(0)); + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getValue(), swaped_indices, slice_sizes_concat, + dims_attr); return success(); } @@ -7475,15 +7416,15 @@ class ConvertXlaShardingOp : public OpRewritePattern { PatternRewriter &rewriter) const override { // TODO(b/148313088): define sharding attribute struct in MLIR intead of // using a string. - if (!op._XlaSharding().has_value()) return failure(); + if (!op.get_XlaSharding().has_value()) return failure(); NamedAttribute call_target_name = rewriter.getNamedAttr( "call_target_name", rewriter.getStringAttr("Sharding")); auto custom_call = rewriter.create( - op.getLoc(), op.getType(), op.input(), + op.getLoc(), op.getType(), op.getInput(), ArrayRef{call_target_name}); - custom_call->setAttr(kShardingAttr, op._XlaShardingAttr()); + custom_call->setAttr(kShardingAttr, op.get_XlaShardingAttr()); rewriter.replaceOp(op, custom_call.getResult(0)); return success(); @@ -7497,9 +7438,9 @@ class ConvertInplaceUpdateOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::InplaceUpdateOp op, PatternRewriter &rewriter) const override { - auto input = op.x(); - auto indices = op.i(); - auto updates = op.v(); + auto input = op.getX(); + auto indices = op.getI(); + auto updates = op.getV(); // Slice each row of `i` and `v` to perform a separate dynamic-update-slice // on the contents of `x`. 
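Concretely, TF::InplaceUpdateOp overwrites whole rows of x: for each k, row i[k] of x is replaced by row k of v, which is why the lowering above can decompose it into one dynamic-update-slice per row. A standalone scalar sketch of that semantics (plain C++, assumed shapes, not code from this patch):

#include <cstddef>
#include <vector>

// Hedged sketch of the per-row update the lowering decomposes into
// dynamic-update-slices: x[i[k]][*] = v[k][*] for every k.
void InplaceUpdateRows(std::vector<std::vector<float>> &x,
                       const std::vector<int> &i,
                       const std::vector<std::vector<float>> &v) {
  for (std::size_t k = 0; k < i.size(); ++k) {
    // One row replacement corresponds to one dynamic-update-slice.
    x[static_cast<std::size_t>(i[k])] = v[k];
  }
}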
@@ -7544,7 +7485,7 @@ class ConvertInplaceUpdateOp : public OpRewritePattern { input_indices.resize(input_type.getRank(), cst); for (auto pair : - llvm::zip(unpacked_indices.output(), split_updates.output())) { + llvm::zip(unpacked_indices.getOutput(), split_updates.getOutput())) { input_indices.front() = std::get<0>(pair); input = rewriter.create( op.getLoc(), op.getType(), input, std::get<1>(pair), input_indices); @@ -7563,7 +7504,7 @@ class ConvertXlaDynamicUpdateSliceOp LogicalResult matchAndRewrite(TF::XlaDynamicUpdateSliceOp op, PatternRewriter &rewriter) const override { - auto indices_type = op.indices().getType().dyn_cast(); + auto indices_type = op.getIndices().getType().dyn_cast(); if (!indices_type || !indices_type.hasStaticShape() || indices_type.getShape().size() != 1) return failure(); @@ -7572,10 +7513,11 @@ class ConvertXlaDynamicUpdateSliceOp indices_type.getDimSize(0), tensorflow::GetTypeFromTFTensorShape( {}, indices_type.getElementType())); auto unpacked_indices = rewriter.create( - op.getLoc(), unpacked_indices_type, op.indices(), + op.getLoc(), unpacked_indices_type, op.getIndices(), IntegerAttr::get(rewriter.getIntegerType(64), 0)); rewriter.replaceOpWithNewOp( - op, op.getType(), op.input(), op.update(), unpacked_indices.output()); + op, op.getType(), op.getInput(), op.getUpdate(), + unpacked_indices.getOutput()); return success(); } }; @@ -7588,7 +7530,7 @@ class ConvertXlaReduceScatterOp LogicalResult matchAndRewrite(TF::XlaReduceScatterOp op, PatternRewriter &rewriter) const override { DenseIntElementsAttr group_assignment; - if (!matchPattern(op.group_assignment(), m_Constant(&group_assignment))) + if (!matchPattern(op.getGroupAssignment(), m_Constant(&group_assignment))) return failure(); auto replica_groups = hlo::convertElementsAttr(group_assignment, rewriter.getIntegerType(64)) @@ -7596,19 +7538,19 @@ class ConvertXlaReduceScatterOp if (replica_groups.getType().getRank() != 2) return failure(); APInt scatter_dimension; - if (!matchPattern(op.scatter_dimension(), + if (!matchPattern(op.getScatterDimension(), m_ConstantInt(&scatter_dimension))) return failure(); Location loc = op.getLoc(); - Type element_type = getElementTypeOrSelf(op.input().getType()); + Type element_type = getElementTypeOrSelf(op.getInput().getType()); auto reduce_scatter = rewriter.create( - loc, op.getType(), op.input(), + loc, op.getType(), op.getInput(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), scatter_dimension.getSExtValue()), replica_groups, ChannelHandleAttr()); - StringRef reduce_op = op.reduce_op(); + StringRef reduce_op = op.getReduceOp(); if (reduce_op == "Add") { BuildReduceBody(element_type, &reduce_scatter.getComputation(), &rewriter); @@ -7655,12 +7597,13 @@ class ConvertXlaReduceWindowOp PatternRewriter &rewriter) const override { DenseElementsAttr window_dimensions, window_strides, base_dilations, window_dilations, padding; - if (!(matchPattern(op.window_dimensions(), + if (!(matchPattern(op.getWindowDimensions(), m_Constant(&window_dimensions)) && - matchPattern(op.window_strides(), m_Constant(&window_strides)) && - matchPattern(op.base_dilations(), m_Constant(&base_dilations)) && - matchPattern(op.window_dilations(), m_Constant(&window_dilations)) && - matchPattern(op.padding(), m_Constant(&padding)))) + matchPattern(op.getWindowStrides(), m_Constant(&window_strides)) && + matchPattern(op.getBaseDilations(), m_Constant(&base_dilations)) && + matchPattern(op.getWindowDilations(), + m_Constant(&window_dilations)) && + matchPattern(op.getPadding(), 
m_Constant(&padding)))) return failure(); Location loc = op.getLoc(); @@ -7668,7 +7611,7 @@ class ConvertXlaReduceWindowOp SmallVector result_types{op.getResult().getType()}; // Create the mhlo.SelectAndScatter op. auto reduce_window_op = rewriter.create( - loc, result_types, op.input(), op.init_value(), + loc, result_types, op.getInput(), op.getInitValue(), hlo::convertElementsAttr(window_dimensions, rewriter.getIntegerType(64)) .cast(), hlo::convertElementsAttr(window_strides, rewriter.getIntegerType(64)) @@ -7680,7 +7623,7 @@ class ConvertXlaReduceWindowOp hlo::convertElementsAttr(padding, rewriter.getIntegerType(64)) .cast()); // Insert a call to the reducer in the region of the mhlo op. - mlir::SymbolRefAttr func = op.computation(); + mlir::SymbolRefAttr func = op.getComputation(); auto func_op = cast(SymbolTable::lookupSymbolIn( op->getParentOfType(), func)); auto func_ty = func_op.getFunctionType(); @@ -7701,9 +7644,9 @@ class ConvertClipByValueOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::ClipByValueOp op, PatternRewriter &rewriter) const override { - Value input = op.x(); - Value min = op.clip_value_min(); - Value max = op.clip_value_max(); + Value input = op.getX(); + Value min = op.getClipValueMin(); + Value max = op.getClipValueMax(); auto input_ty = input.getType().cast(); auto min_ty = min.getType().cast(); @@ -7748,7 +7691,7 @@ class ConvertConstOp : public OpRewritePattern { return failure(); Location loc = op.getLoc(); - Value result = rewriter.create(loc, op.value()); + Value result = rewriter.create(loc, op.getValue()); if (result.getType() != op.getType()) result = rewriter.create(loc, op.getType(), result); rewriter.replaceOp(op, result); @@ -7767,7 +7710,7 @@ class ConvertCumOp : public OpRewritePattern { LogicalResult matchAndRewrite(OpT op, PatternRewriter &rewriter) const override { - auto input = op.x(); + auto input = op.getX(); auto input_type = input.getType().template dyn_cast(); if (!input_type || !input_type.hasStaticShape()) { return failure(); @@ -7778,7 +7721,7 @@ class ConvertCumOp : public OpRewritePattern { // We can only match when the axis is a constant scalar. DenseIntElementsAttr axis_attr; - if (!matchPattern(op.axis(), m_Constant(&axis_attr))) { + if (!matchPattern(op.getAxis(), m_Constant(&axis_attr))) { return failure(); } @@ -7791,7 +7734,7 @@ class ConvertCumOp : public OpRewritePattern { // If we're supposed to sum things up in the reverse direction, we reverse // the input and then later reverse the output. - if (op.reverse()) { + if (op.getReverse()) { llvm::SmallVector dims_to_reverse({axis}); input = rewriter.create( op.getLoc(), input, GetI64ElementsAttr(dims_to_reverse, &rewriter)); @@ -7833,7 +7776,7 @@ class ConvertCumOp : public OpRewritePattern { &rewriter); Value result = reduce.getResult(0); - if (op.exclusive()) { + if (op.getExclusive()) { // In "exclusive" operation, the output will start with the "init" (0) // values. 
There is no way to express that as a ReduceWindowOp, so run the // normal operation, and then use a PadOp to add the 0 "column" on the @@ -7853,7 +7796,7 @@ class ConvertCumOp : public OpRewritePattern { result = rewriter.create(op.getLoc(), result, input_element_type); - if (op.reverse()) { + if (op.getReverse()) { llvm::SmallVector dims_to_reverse({axis}); result = rewriter.create( op.getLoc(), result, GetI64ElementsAttr(dims_to_reverse, &rewriter)); @@ -7877,7 +7820,7 @@ class ConvertShapeOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::ShapeOp op, PatternRewriter &rewriter) const override { - Value input = op.input(); + Value input = op.getInput(); auto result_ty = op.getResult().getType().dyn_cast(); if (!result_ty) { @@ -7899,7 +7842,7 @@ class ConvertDynamicExpandDimsOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::ExpandDimsOp op, PatternRewriter &rewriter) const override { - auto input = op.input(); + auto input = op.getInput(); auto input_ty = input.getType().cast(); auto result_ty = op.getType().cast(); if (!result_ty.hasRank() || !input_ty.hasRank() || @@ -7908,7 +7851,7 @@ class ConvertDynamicExpandDimsOp : public OpRewritePattern { } DenseIntElementsAttr expand_dims_attr; - if (!matchPattern(op.dim(), m_Constant(&expand_dims_attr))) { + if (!matchPattern(op.getDim(), m_Constant(&expand_dims_attr))) { return failure(); } @@ -7957,7 +7900,7 @@ class ConvertDynamicSqueezeOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::SqueezeOp op, PatternRewriter &rewriter) const override { - auto input = op.input(); + auto input = op.getInput(); auto input_ty = input.getType().cast(); auto result_ty = op.getType().cast(); if (!result_ty.hasRank() || !input_ty.hasRank() || @@ -7966,14 +7909,14 @@ class ConvertDynamicSqueezeOp : public OpRewritePattern { } // The fully dynamic case is unsupported. - if (op.squeeze_dims().empty()) { + if (op.getSqueezeDims().empty()) { return failure(); } SmallVector squeeze_dims; int64_t input_rank = input_ty.getRank(); for (const auto &squeeze_dim_apint : - op.squeeze_dims().getAsValueRange()) { + op.getSqueezeDims().getAsValueRange()) { int64_t squeeze_dim = squeeze_dim_apint.getSExtValue(); // Handle negative inputs. 
if (squeeze_dim < 0) squeeze_dim += input_rank; @@ -8007,7 +7950,7 @@ class ConvertDynamicStitchOpDynamic : public OpRewritePattern(); if (!tensor_ty) return ty; - SmallVector shape(rank, ShapedType::kDynamicSize); + SmallVector shape(rank, ShapedType::kDynamic); return RankedTensorType::get(shape, tensor_ty.getElementType()); } @@ -8023,8 +7966,8 @@ class ConvertDynamicStitchOpDynamic : public OpRewritePattern indices; - indices.reserve(op.N()); - for (auto it : llvm::zip(op.indices(), op.data())) { + indices.reserve(op.getN()); + for (auto it : llvm::zip(op.getIndices(), op.getData())) { Value index = std::get<0>(it); Value data = std::get<1>(it); @@ -8041,7 +7984,7 @@ class ConvertDynamicStitchOpDynamic : public OpRewritePattern values(out_ty.getDimSize(0)); - for (auto it : llvm::zip(indices, op.data())) { + for (auto it : llvm::zip(indices, op.getData())) { DenseIntElementsAttr index_attr = std::get<0>(it); SmallVector shapes; Value data = std::get<1>(it); @@ -8070,8 +8013,10 @@ class ConvertDynamicStitchOpDynamic : public OpRewritePattern().getElementType(); auto unpacked_ty = SmallVector( - index_attr.size(), RankedTensorType::get( - SmallVector(extracted_ranks, ShapedType::kDynamicSize), reshaped_data_ty)); + index_attr.size(), + RankedTensorType::get( + SmallVector(extracted_ranks, ShapedType::kDynamic), + reshaped_data_ty)); auto items = rewriter.create( loc, unpacked_ty, reshaped_data, /*axis*/0); @@ -8087,509 +8032,6 @@ class ConvertDynamicStitchOpDynamic : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::QrOp op, - PatternRewriter &rewriter) const override { - // Block Householder QR Factorization. Algorithm 5.2.2 of Golub and van - // Loan. def qr_blocked(a, block_size): - // m = a.shape[0] - // n = a.shape[1] - // q = np.eye(m) - // for i in xrange(0, min(m, n), block_size): - // k = min(block_size, min(m, n) - s) - // (a, vs, taus) = qr(a[i:, i:i+k]) - // y = vs - // w = ComputeWYRepresentation(vs, taus, m-i, k) - // a[i:, i+r:] += np.dot(y, np.dot(w.T, a[i:, i+k:])) - // q[:, i:] += np.dot(q[:, i:], np.dot(w, y.T)) - // return (q, a) - auto type = op.input().getType().dyn_cast(); - if (!type || !type.hasStaticShape()) return failure(); - // The block size is chosen to match old bridge lowering. 
- constexpr int64_t kBlockSize = 128; - Value a = op.input(); - int64_t m = type.getDimSize(type.getRank() - 2); - int64_t n = type.getDimSize(type.getRank() - 1); - int64_t p = std::min(m, n); - auto batch_dims = type.getShape().drop_back(2); - auto iota_type = tensorflow::GetTypeFromTFTensorShape( - {m, m}, rewriter.getIntegerType(32)); - auto iota0 = rewriter.create(op.getLoc(), iota_type, - rewriter.getI64IntegerAttr(0)); - auto iota1 = rewriter.create(op.getLoc(), iota_type, - rewriter.getI64IntegerAttr(1)); - Value compare = rewriter.create(op.getLoc(), iota0, iota1, - ComparisonDirection::EQ); - Value identity_matrix = - rewriter.create(op.getLoc(), compare, type.getElementType()); - auto q_shape = llvm::to_vector<4>(type.getShape()); - q_shape.back() = m; - Value q = - rewriter.create(op.getLoc(), identity_matrix, - GetI64ElementsAttr(batch_dims, &rewriter)); - auto precision_config = rewriter.getArrayAttr( - {PrecisionAttr::get(rewriter.getContext(), Precision::HIGHEST), - PrecisionAttr::get(rewriter.getContext(), Precision::HIGHEST)}); - for (int64_t i = 0; i < p; i += kBlockSize) { - int64_t k = std::min(kBlockSize, p - i); - auto a_block = - SliceInMinorDims(op.getLoc(), a, {i, i}, {m, i + k}, &rewriter); - Value r_block; - Value taus; - Value vs; - QRBlock(op.getLoc(), a_block, &r_block, &taus, &vs, &rewriter); - a = UpdateSliceInMinorDims(op.getLoc(), a, r_block, {i, i}, &rewriter); - - // Compute the I-WY block representation of a product of Householder - // matrices. - Value w = - ComputeWYRepresentation(op.getLoc(), type.getElementType(), - batch_dims, vs, taus, m - i, k, &rewriter); - auto y = vs; - - // a[i:, i+k:] += np.dot(Y, np.dot(W.T, a[i:, i+k:])) - Value a_panel = - SliceInMinorDims(op.getLoc(), a, {i, i + k}, {m, n}, &rewriter); - auto a_update = BatchDot(op.getLoc(), w, true, a_panel, false, - batch_dims.size(), precision_config, &rewriter); - a_update = BatchDot(op.getLoc(), y, false, a_update, false, - batch_dims.size(), precision_config, &rewriter); - a_panel = rewriter.create(op.getLoc(), a_panel, a_update); - a = UpdateSliceInMinorDims(op.getLoc(), a, a_panel, {i, i + k}, - &rewriter); - - // q[:, i:] += np.dot(np.dot(q[:, i:], W), Y.T)) - Value q_panel = - SliceInMinorDims(op.getLoc(), q, {0, i}, {m, m}, &rewriter); - Value q_update = BatchDot(op.getLoc(), q_panel, false, w, false, - batch_dims.size(), precision_config, &rewriter); - q_update = BatchDot(op.getLoc(), q_update, false, y, true, - batch_dims.size(), precision_config, &rewriter); - q_panel = rewriter.create(op.getLoc(), q_panel, q_update); - q = UpdateSliceInMinorDims(op.getLoc(), q, q_panel, {i}, &rewriter); - } - // full_matrices is false when only a partial result in needed. Slice to the - // needed dimensions here. - if (!op.full_matrices()) { - q = SliceInMinorDims(op.getLoc(), q, {0, 0}, {m, p}, &rewriter); - a = SliceInMinorDims(op.getLoc(), a, {0, 0}, {p, n}, &rewriter); - } - rewriter.replaceOp(op, {q, a}); - return success(); - } - - private: - // Computes a Householder reflection of the form: - // H = I - tau v v.T. - // such that - // H . ( x1 ) = ( x1 ) - // ( x2 ) = ( x2 ) - // ( ... ) = ( ... ) - // ( xk ) = ( beta ) - // ( ... ) ( 0 ) - // ( ... ) ( 0 ) - // Unlike the usual formulation, we allow the caller to supply 'k' rather than - // only providing the relevant part of 'x' to maintain XLA's static shape - // invariant. In addition, the implementation supports batching. 
- // Pseudo-code, without batching: - // alpha = x[k] - // x_copy = np.copy(x) - // x_copy[:k+1] = 0 - // xnorm = norm2(x_copy) - // if xnorm == 0: - // beta = alpha - // tau = 0 - // v = np.zeros_like(x) - // else: - // beta = - np.sign(alpha) * dlapy2(alpha, xnorm) - // tau = (beta - alpha) / beta - // v = x / (alpha - beta) - // v[k] = 1 - // return (v, tau, beta) - void House(Location loc, Value x, Value k, ArrayRef batch_dims, - const int64_t m, OpBuilder *builder, Value *v, Value *tau, - Value *beta) const { - auto x_type = x.getType().cast(); - - llvm::SmallVector batch_dim_ids(batch_dims.size()); - std::iota(batch_dim_ids.begin(), batch_dim_ids.end(), 0); - const int64_t minor_dim = batch_dims.size(); - - Value zero = GetScalarConstOfType(x_type.getElementType(), loc, 0, builder); - Value one = GetScalarConstOfType(x_type.getElementType(), loc, 1, builder); - - // alpha = x[k] - Value alpha = DynamicSliceInMinorDims(loc, x, {k}, {1}, builder); - alpha = builder->create(loc, - tensorflow::GetTypeFromTFTensorShape( - batch_dims, x_type.getElementType()), - alpha); - - // Compute x[k+1:] (padded with zeros in elements 0..k) - Value iota = builder->create( - loc, - tensorflow::GetTypeFromTFTensorShape({m}, builder->getIntegerType(32)), - builder->getI64IntegerAttr(0)); - Value gtk = builder->create( - loc, iota, k, GetI64ElementsAttr({}, builder), - chlo::ComparisonDirection::GT); - gtk = builder->create(loc, gtk, x_type.getElementType()); - Value x_after_k = builder->create( - loc, x, gtk, GetI64ElementsAttr({minor_dim}, builder)); - Value x_after_k_sq = builder->create(loc, x_after_k, x_after_k); - // sigma = np.dot(x[k+1:], x[k+1:]) - auto sigma = builder->create( - loc, x_after_k_sq, zero, GetI64ElementsAttr({minor_dim}, builder)); - BuildReduceBody(x_type.getElementType(), &sigma.getBody(), builder); - // mu = np.sqrt(x[k]*x[k] + sigma) - Value alpha_sq = builder->create(loc, alpha, alpha); - Value mu = builder->create( - loc, builder->create(loc, alpha_sq, sigma.getResult(0))); - - Value sigma_is_zero = builder->create( - loc, sigma.getResult(0), zero, GetI64ElementsAttr({}, builder), - chlo::ComparisonDirection::EQ); - Value alpha_is_negative = builder->create( - loc, alpha, zero, GetI64ElementsAttr({}, builder), - chlo::ComparisonDirection::LT); - auto batch_size_one = builder->create( - loc, one, GetI64ElementsAttr(batch_dims, builder)); - Value signed_mu = builder->create( - loc, - builder->create(loc, alpha_is_negative, batch_size_one, - builder->create(loc, batch_size_one)), - mu, GetI64ElementsAttr({}, builder)); - *beta = builder->create(loc, sigma_is_zero, alpha, signed_mu); - *tau = builder->create( - loc, builder->create(loc, *beta, alpha), *beta); - Value zero_tau = builder->create( - loc, zero, GetI64ElementsAttr(batch_dims, builder)); - *tau = builder->create(loc, sigma_is_zero, zero_tau, *tau); - Value divisor = builder->create(loc, alpha, *beta); - divisor = - builder->create(loc, sigma_is_zero, batch_size_one, divisor); - - Value eqk = builder->create( - loc, iota, k, GetI64ElementsAttr({}, builder), - chlo::ComparisonDirection::EQ); - eqk = builder->create(loc, eqk, x_type.getElementType()); - llvm::SmallVector e_k_shape(batch_dims.size(), 1); - e_k_shape.push_back(m); - auto e_k = builder->create( - loc, eqk, - GetI64ElementsAttr(llvm::SmallVector(batch_dims.size(), 1), - builder)); - - // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor - // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor. - // Note that the add performs a degenerate broadcast. 
- *v = builder->create( - loc, e_k, - StaticBinaryBroadcast(loc, x_after_k, divisor, - GetI64ElementsAttr(batch_dim_ids, builder), - *builder), - /*broadcast_dimensions=*/nullptr); - } - - // Householder QR decomposition. Algorithm 5.2.1 from Golub and Van - // Loan "Matrix Computations", 4th Edition. This is an unblocked - // implementation used as an inner routine of the blocked implementation. - // Algorithm is adapted slightly so the shapes inside the loop are static, at - // the cost of some redundant computation. Since this is used as an inner - // block kernel, accumulates the Householder transformations (vs, taus) rather - // than the matrix q. Equivalent Python code, without batching: def qr(a): - // m = a.shape[0] - // n = a.shape[1] - // vs = np.zeros([m, n]) - // taus = np.zeros([n]) - // for j in xrange(min(m, n)): - // v, tau, beta = house(a[:, j], j) - // # Unusually, we apply the Householder transformation to the entirety of - // # a, wasting FLOPs to maintain the static shape invariant that XLA - // # requires. For columns that precede j this has no effect. - // a[:, :] -= tau * np.dot(v[:, np.newaxis], - // np.dot(v[np.newaxis, :], a[:, :])) - // # Form column j explicitly rather than relying on the precision of the - // # Householder update. - // a[j, j] = beta - // a[j+1:, j] = np.zeros([m - j - 1], dtype=a.dtype) - // vs[:, j] = v - // taus[j] = tau - // return (q, vs, taus) - void QRBlock(Location loc, Value a, Value *r, Value *taus, Value *vs, - PatternRewriter *rewriter) const { - auto a_type = a.getType().cast(); - const int num_dims = a_type.getRank(); - assert(num_dims >= 2 && "Argument to QR must have rank >= 2"); - - const int64_t m = a_type.getDimSize(a_type.getRank() - 2); - const int64_t n = a_type.getDimSize(a_type.getRank() - 1); - - const int64_t num_batch_dims = num_dims - 2; - auto batch_dims = a_type.getShape().take_front(num_batch_dims); - llvm::SmallVector batch_dim_indices(batch_dims.size()); - std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); - - auto qr_body_fn = [&](Location loc, Value j, ArrayRef old_values, - SmallVectorImpl *new_values, - OpBuilder *builder) { - auto a = old_values[0]; - auto vs = old_values[1]; - auto taus = old_values[2]; - - // v, beta = house(a[:, j], j) - auto x = DynamicSliceInMinorDims(loc, a, {j}, {1}, builder); - auto x_collapsed_shape = llvm::to_vector<4>(batch_dims); - x_collapsed_shape.push_back(m); - auto x_collapsed = builder->create( - loc, - tensorflow::GetTypeFromTFTensorShape( - x_collapsed_shape, getElementTypeOrSelf(x.getType())), - x); - Value v, tau, beta; - House(loc, x_collapsed, j, batch_dims, m, builder, &v, &tau, &beta); - - auto shape = llvm::to_vector<4>(batch_dims); - shape.append({1, m}); - auto v_broadcast = builder->create( - loc, - tensorflow::GetTypeFromTFTensorShape( - shape, getElementTypeOrSelf(v.getType())), - v); - // a[:, :] -= tau * np.dot(v[:, np.newaxis], - // np.dot(v[np.newaxis, :], a[:, :])) - auto precision = builder->getArrayAttr( - {PrecisionAttr::get(builder->getContext(), Precision::HIGHEST), - PrecisionAttr::get(builder->getContext(), Precision::HIGHEST)}); - auto vva = BatchDot(loc, v_broadcast, false, a, false, num_batch_dims, - precision, builder); - vva = BatchDot(loc, v_broadcast, true, vva, false, num_batch_dims, - precision, builder); - auto tau_x_vva = StaticBinaryBroadcast( - loc, tau, vva, GetI64ElementsAttr(batch_dim_indices, builder), - *builder); - a = builder->create(loc, a, tau_x_vva); - - // It is more precise to populate column 'k' 
explicitly, rather than - // computing it implicitly by applying the Householder transformation. - // a[k,k] = beta - // a[k+1:,k] = np.zeros([m-k-1], dtype=a.dtype) - auto iota = - builder->create(loc, - tensorflow::GetTypeFromTFTensorShape( - {m, 1}, builder->getIntegerType(32)), - builder->getI64IntegerAttr(0)); - Value predecessor_mask = builder->create( - loc, iota, j, GetI64ElementsAttr({}, builder), - chlo::ComparisonDirection::LT); - predecessor_mask = builder->create(loc, predecessor_mask, - a_type.getElementType()); - Value mask = builder->create( - loc, iota, j, GetI64ElementsAttr({}, builder), - chlo::ComparisonDirection::EQ); - mask = builder->create(loc, mask, a_type.getElementType()); - mask = builder->create( - loc, mask, - GetI64ElementsAttr(llvm::SmallVector(num_batch_dims, 1), - builder)); - Value predecessor_masked_x = StaticBinaryBroadcast( - loc, x, predecessor_mask, - GetI64ElementsAttr({num_dims - 2, num_dims - 1}, builder), *builder); - Value masked_beta = StaticBinaryBroadcast( - loc, beta, mask, GetI64ElementsAttr(batch_dim_indices, builder), - *builder); - Value new_x = - builder->create(loc, predecessor_masked_x, masked_beta); - // Update a[:,j] - llvm::SmallVector dim_ids(num_dims); - std::iota(dim_ids.begin(), dim_ids.end(), 0); - new_x = builder->create( - loc, a_type, new_x, GetI64ElementsAttr(dim_ids, builder)); - const int64_t minor_dim = num_batch_dims; - auto iota_mn = builder->create( - loc, - tensorflow::GetTypeFromTFTensorShape(a_type.getShape(), - builder->getIntegerType(32)), - builder->getI64IntegerAttr(minor_dim + 1)); - Value xa_mask = builder->create( - loc, iota_mn, j, GetI64ElementsAttr({}, builder), - chlo::ComparisonDirection::EQ); - a = builder->create(loc, xa_mask, new_x, a); - - // vs[:, j] = v - llvm::SmallVector vs_broadcast_dims(num_batch_dims + 1); - std::iota(vs_broadcast_dims.begin(), vs_broadcast_dims.end(), 0); - Value vs_zeros = - GetScalarConstOfType(a_type.getElementType(), loc, 0, builder); - vs_zeros = builder->create( - loc, vs_zeros, - GetI64ElementsAttr(vs.getType().cast().getShape(), - builder)); - auto vs_update = builder->create( - loc, xa_mask, - StaticBinaryBroadcast( - loc, vs_zeros, v, GetI64ElementsAttr(vs_broadcast_dims, builder), - *builder), - vs_zeros); - vs = builder->create(loc, vs, vs_update); - - // taus[j] = tau - llvm::SmallVector tau_broadcast_dims(batch_dims.size()); - std::iota(tau_broadcast_dims.begin(), tau_broadcast_dims.end(), 0); - - auto iota_shape = llvm::to_vector<4>(batch_dims); - iota_shape.push_back(n); - auto iota_n = - builder->create(loc, - tensorflow::GetTypeFromTFTensorShape( - iota_shape, builder->getIntegerType(32)), - builder->getI64IntegerAttr(minor_dim)); - Value taus_zeros = - GetScalarConstOfType(a_type.getElementType(), loc, 0, builder); - taus_zeros = builder->create( - loc, taus_zeros, - GetI64ElementsAttr(taus.getType().cast().getShape(), - builder)); - Value taus_mask = builder->create( - loc, iota_n, j, GetI64ElementsAttr({}, builder), - chlo::ComparisonDirection::EQ); - auto taus_update = builder->create( - loc, taus_mask, - StaticBinaryBroadcast( - loc, taus_zeros, tau, - GetI64ElementsAttr(tau_broadcast_dims, builder), *builder), - taus_zeros); - taus = builder->create(loc, taus, taus_update); - new_values->assign({a, vs, taus}); - }; - - Value zero = - GetScalarConstOfType(a_type.getElementType(), loc, 0, rewriter); - *vs = rewriter->create( - loc, zero, GetI64ElementsAttr(a_type.getShape(), rewriter)); - auto taus_shape = llvm::to_vector<4>(batch_dims); - 
taus_shape.push_back(n); - *taus = rewriter->create( - loc, zero, GetI64ElementsAttr(taus_shape, rewriter)); - - SmallVector while_output; - CreateWhile32(loc, std::min(m, n), qr_body_fn, {a, *vs, *taus}, - &while_output, rewriter); - *r = while_output[0]; - *vs = while_output[1]; - *taus = while_output[2]; - } - - // Computes W and Y such that I-WY is equivalent to the sequence of - // Householder - // transformations given by vs and taus. - // Golub and van Loan, "Matrix Computations", algorithm 5.1.2. - // Y = np.zeros([m, n]) - // W = np.zeros([m, n]) - // Y[:, 0] = vs[:, 0] - // W[:, 0] = -taus[0] * vs[:, 0] - // for j in xrange(1, n): - // v = vs[:, j] - // z = -taus[j] * v - taus[j] * np.dot(W, np.dot(Y.T, v)) - // W[:, j] = z - // Y[:, j] = v - // return W - // There is no need to return Y since at termination of the loop it is equal - // to vs. - Value ComputeWYRepresentation(Location loc, Type data_type, - ArrayRef batch_dims, Value vs, - Value taus, int64_t m, int64_t n, - PatternRewriter *rewriter) const { - int64_t n_index = batch_dims.size() + 1; - llvm::SmallVector batch_dim_indices(batch_dims.size()); - std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); - - auto body_fn = [&](Location loc, Value j, ArrayRef old_values, - SmallVectorImpl *new_values, OpBuilder *builder) { - // w has shape [..., m, n] - auto w = old_values[0]; - const auto vs = old_values[1]; - const auto taus = old_values[2]; - - // Want j values in range [1, ... n). - j = builder->create( - loc, j, - GetScalarConstOfType(getElementTypeOrSelf(j.getType()), loc, 1, - builder)); - // vs has shape [..., m, 1] - auto v = DynamicSliceInMinorDims(loc, vs, {j}, {1}, builder); - // beta has shape [..., 1] - auto beta = DynamicSliceInMinorDims(loc, taus, {j}, {1}, builder); - - auto iota_shape = llvm::to_vector<4>(batch_dims); - iota_shape.append({m, n}); - auto iota_mn = - builder->create(loc, - tensorflow::GetTypeFromTFTensorShape( - iota_shape, builder->getIntegerType(32)), - builder->getI64IntegerAttr(n_index)); - - // y has shape [..., m, n] - Value zero = GetScalarConstOfType(getElementTypeOrSelf(vs.getType()), loc, - 0, builder); - zero = builder->create( - loc, zero, - GetI64ElementsAttr(vs.getType().cast().getShape(), - builder)); - auto compare = builder->create( - loc, iota_mn, j, GetI64ElementsAttr({}, builder), - chlo::ComparisonDirection::GE); - auto y = builder->create(loc, compare, zero, vs); - - // yv has shape [..., n, 1] - auto precision = builder->getArrayAttr( - {PrecisionAttr::get(builder->getContext(), Precision::HIGHEST), - PrecisionAttr::get(builder->getContext(), Precision::HIGHEST)}); - auto yv = BatchDot(loc, y, true, v, false, batch_dims.size(), precision, - builder); - // wyv has shape [..., m, 1] - auto wyv = BatchDot(loc, w, false, yv, false, batch_dims.size(), - precision, builder); - - // z = -beta * (v + wyv) - auto neg_beta = builder->create(loc, beta); - auto v_wyv = builder->create(loc, v, wyv); - auto beta_broadcast_dims = llvm::to_vector<4>(batch_dim_indices); - beta_broadcast_dims.push_back(n_index); - auto z = StaticBinaryBroadcast( - loc, neg_beta, v_wyv, - GetI64ElementsAttr(beta_broadcast_dims, builder), *rewriter); - - w = DynamicUpdateSliceInMinorDims(loc, w, z, {j}, builder); - new_values->assign({w, vs, taus}); - }; - - Value w = - GetScalarConstOfType(getElementTypeOrSelf(data_type), loc, 0, rewriter); - auto w_shape = llvm::to_vector<4>(batch_dims); - w_shape.append({m, n}); - w = rewriter->create(loc, w, - GetI64ElementsAttr(w_shape, rewriter)); - 
auto v = SliceInMinorDims(loc, vs, {0}, {1}, rewriter); - auto beta = SliceInMinorDims(loc, taus, {0}, {1}, rewriter); - auto neg_beta = rewriter->create(loc, beta); - auto beta_broadcast_dims = llvm::to_vector<4>(batch_dim_indices); - beta_broadcast_dims.push_back(n_index); - auto bv = StaticBinaryBroadcast( - loc, neg_beta, v, GetI64ElementsAttr(beta_broadcast_dims, rewriter), - *rewriter); - w = UpdateSliceInMinorDims(loc, w, bv, {0}, rewriter); - - SmallVector while_output; - CreateWhile32(loc, n - 1, body_fn, {w, vs, taus}, &while_output, rewriter); - return while_output[0]; - } -}; - // Converts tf.XlaConvV2 to mhlo.Conv class ConvertXlaConvV2Op : public OpRewritePattern { public: @@ -8599,11 +8041,12 @@ class ConvertXlaConvV2Op : public OpRewritePattern { PatternRewriter &rewriter) const override { DenseElementsAttr window_strides_attr, padding_attr, lhs_dilation_attr, rhs_dilation_attr, feature_group_count_attr; - if (!(matchPattern(op.window_strides(), m_Constant(&window_strides_attr)) && - matchPattern(op.padding(), m_Constant(&padding_attr)) && - matchPattern(op.lhs_dilation(), m_Constant(&lhs_dilation_attr)) && - matchPattern(op.rhs_dilation(), m_Constant(&rhs_dilation_attr)) && - matchPattern(op.feature_group_count(), + if (!(matchPattern(op.getWindowStrides(), + m_Constant(&window_strides_attr)) && + matchPattern(op.getPadding(), m_Constant(&padding_attr)) && + matchPattern(op.getLhsDilation(), m_Constant(&lhs_dilation_attr)) && + matchPattern(op.getRhsDilation(), m_Constant(&rhs_dilation_attr)) && + matchPattern(op.getFeatureGroupCount(), m_Constant(&feature_group_count_attr)))) return failure(); @@ -8634,29 +8077,29 @@ class ConvertXlaConvV2Op : public OpRewritePattern { rewriter.getI64IntegerAttr(feature_group_count_val)); auto batch_group_count_named_attr = - rewriter.getNamedAttr("batch_group_count", op.batch_group_countAttr()); + rewriter.getNamedAttr("batch_group_count", op.getBatchGroupCountAttr()); xla::ConvolutionDimensionNumbers dnums; - dnums.ParseFromString(op.dimension_numbersAttr().getValue().str()); + dnums.ParseFromString(op.getDimensionNumbersAttr().getValue().str()); auto dimension_numbers_named_attr = rewriter.getNamedAttr( "dimension_numbers", xla::ConvertConvDimensionNumbers(dnums, &rewriter)); xla::PrecisionConfig precision_config; precision_config.ParseFromString( - op.precision_configAttr().getValue().str()); + op.getPrecisionConfigAttr().getValue().str()); auto precision_config_named_attr = rewriter.getNamedAttr( "precision_config", xla::ConvertPrecisionConfig(&precision_config, &rewriter)); - SmallVector operands{op.lhs(), op.rhs()}; + SmallVector operands{op.getLhs(), op.getRhs()}; NamedAttribute attrs[] = { window_strides_named_attr, padding_named_attr, lhs_dilation_named_attr, rhs_dilation_named_attr, feature_group_count_named_attr, batch_group_count_named_attr, dimension_numbers_named_attr, precision_config_named_attr}; rewriter.replaceOpWithNewOp(op, op.getType(), operands, - llvm::makeArrayRef(attrs)); + llvm::ArrayRef(attrs)); return success(); } }; @@ -8670,10 +8113,10 @@ class ConvertXlaSelectAndScatterOp LogicalResult matchAndRewrite(TF::XlaSelectAndScatterOp op, PatternRewriter &rewriter) const override { ElementsAttr window_dimensions, window_strides, padding; - if (!(matchPattern(op.window_dimensions(), + if (!(matchPattern(op.getWindowDimensions(), m_Constant(&window_dimensions)) && - matchPattern(op.window_strides(), m_Constant(&window_strides)) && - matchPattern(op.padding(), m_Constant(&padding)))) + 
matchPattern(op.getWindowStrides(), m_Constant(&window_strides)) && + matchPattern(op.getPadding(), m_Constant(&padding)))) return failure(); Location loc = op.getLoc(); @@ -8681,7 +8124,7 @@ class ConvertXlaSelectAndScatterOp SmallVector result_types{op.getResult().getType()}; // Create the mhlo.SelectAndScatter op. auto select_and_scatter_op = rewriter.create( - loc, result_types, op.operand(), op.source(), op.init_value(), + loc, result_types, op.getOperand(), op.getSource(), op.getInitValue(), hlo::convertElementsAttr(window_dimensions, rewriter.getIntegerType(64)) .cast(), hlo::convertElementsAttr(window_strides, rewriter.getIntegerType(64)) @@ -8697,10 +8140,10 @@ class ConvertXlaSelectAndScatterOp }; // Insert a call to the select function in the select region of the mhlo op. - insert_call_to(op.select(), &select_and_scatter_op.getSelect()); + insert_call_to(op.getSelect(), &select_and_scatter_op.getSelect()); // Insert a call to the scatter function in the scatter region of the mhlo // op. - insert_call_to(op.scatter(), &select_and_scatter_op.getScatter()); + insert_call_to(op.getScatter(), &select_and_scatter_op.getScatter()); rewriter.replaceOp(op, select_and_scatter_op.getResult()); @@ -8716,9 +8159,9 @@ class ConvertXlaSortOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::XlaSortOp op, PatternRewriter &rewriter) const override { // Create the sort op. - Type element_type = getElementTypeOrSelf(op.input().getType()); + Type element_type = getElementTypeOrSelf(op.getInput().getType()); auto sort_op = - createSortOp(&rewriter, op.getLoc(), {op.input()}, {element_type}, + createSortOp(&rewriter, op.getLoc(), {op.getInput()}, {element_type}, /*dimension=*/-1, /*is_stable=*/false, /*direction=*/ComparisonDirection::LT); rewriter.replaceOp(op, sort_op.getResult(0)); @@ -8748,7 +8191,7 @@ class ConvertXlaRngBitGeneratorOp PatternRewriter &rewriter) const override { Location loc = op.getLoc(); DenseElementsAttr algorithm; - if (!(matchPattern(op.algorithm(), m_Constant(&algorithm))) || + if (!(matchPattern(op.getAlgorithm(), m_Constant(&algorithm))) || algorithm.getType().getRank()) { return op.emitOpError() << "algorithm must be a constant scalar"; } @@ -8761,9 +8204,9 @@ class ConvertXlaRngBitGeneratorOp auto algorithm_attr = mlir::mhlo::RngAlgorithmAttr::get( rewriter.getContext(), - *mlir::mhlo::symbolizeRngAlgorithm(xla_alg.getValue())); + *mlir::mhlo::symbolizeRngAlgorithm(xla_alg.value())); auto rng_bit_generator_op = rewriter.create( - loc, op.getResultTypes(), algorithm_attr, op.initial_state()); + loc, op.getResultTypes(), algorithm_attr, op.getInitialState()); rewriter.replaceOp(op, rng_bit_generator_op.getResults()); @@ -8783,9 +8226,9 @@ class ConvertXlaVariadicReduceV2Op // Create the mhlo.reduce op. auto reduce_op = rewriter.create( - loc, op.inputs(), op.init_values(), - GetI64ElementsAttr(op.dimensions_to_reduce())); - mlir::SymbolRefAttr func = op.reducer(); + loc, op.getInputs(), op.getInitValues(), + GetI64ElementsAttr(op.getDimensionsToReduce())); + mlir::SymbolRefAttr func = op.getReducer(); auto func_op = cast(SymbolTable::lookupSymbolIn( op->getParentOfType(), func)); auto func_ty = func_op.getFunctionType(); @@ -8808,12 +8251,12 @@ class ConvertXlaVariadicSortOp PatternRewriter &rewriter) const override { Location loc = op.getLoc(); ElementsAttr dimension; - matchPattern(op.dimension(), m_Constant(&dimension)); + matchPattern(op.getDimension(), m_Constant(&dimension)); // Create the mhlo.sort op. 
auto sort_op = rewriter.create( - loc, op.inputs(), dimension.getValues()[0].getInt(), - op.is_stable()); - mlir::SymbolRefAttr func = op.comparator(); + loc, op.getInputs(), dimension.getValues()[0].getInt(), + op.getIsStable()); + mlir::SymbolRefAttr func = op.getComparator(); auto func_op = cast(SymbolTable::lookupSymbolIn( op->getParentOfType(), func)); auto func_ty = func_op.getFunctionType(); @@ -8834,25 +8277,118 @@ class ConvertXlaReducePrecisionOp LogicalResult matchAndRewrite(TF::XlaReducePrecisionOp op, PatternRewriter &rewriter) const override { IntegerType int32_type = rewriter.getIntegerType(32); - APInt exponent_bits = op.exponent_bitsAttr().getValue(); + APInt exponent_bits = op.getExponentBitsAttr().getValue(); // Truncating to 32-bits is safe, since pasing any number above the dtype // size (which is at most 64, for float64) is equivalent to passing the // dtype size. IntegerAttr new_exponent_attr = IntegerAttr::get(int32_type, exponent_bits.truncSSat(32)); - APInt mantissa_bits = op.mantissa_bitsAttr().getValue(); + APInt mantissa_bits = op.getMantissaBitsAttr().getValue(); IntegerAttr new_mantissa_attr = IntegerAttr::get(int32_type, mantissa_bits.truncSSat(32)); rewriter.replaceOpWithNewOp( - op, op.getType(), op.operand(), new_exponent_attr, new_mantissa_attr); + op, op.getType(), op.getOperand(), new_exponent_attr, + new_mantissa_attr); + return success(); + } +}; + +class LowerYieldOp : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + TF::YieldOp op, TF::YieldOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); return success(); } }; +// Returns a new tensor type from the given type with element type updated to +// the given type. +TensorType UpdateElementTypeTo(Type ty, Type element_ty) { + auto ranked_ty = ty.dyn_cast(); + if (!ranked_ty) { + return UnrankedTensorType::get(element_ty); + } + return RankedTensorType::get(ranked_ty.getShape(), element_ty, + ranked_ty.getEncoding()); +} + +template +class LowerControlFlowOp : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + SrcOpT op, typename SrcOpT::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + DstOpT mhlo_op; + Location loc = op.getLoc(); + + // To handle quant type conversions, use the converted operands' element + // types and original source op's shapes and encoding to get converted op's + // result types. This is only done for the While op for now. + llvm::SmallVector element_types; + int64_t num_results = op.getNumResults(); + if constexpr (std::is_same::value) { + element_types.reserve(num_results); + for (Value value : adaptor.getOperands()) { + element_types.push_back(getElementTypeOrSelf(value.getType())); + } + } + + if constexpr (std::is_same::value) { + // Explicitly handle the Case op because it has variadic regions and takes + // the number of regions as an input along with the operands. 
+ mhlo_op = rewriter.create(loc, op.getResultTypes(), + adaptor.getBranchIndex(), + op.getBranches().size()); + } else if constexpr (std::is_same::value) { + llvm::SmallVector while_result_types; + while_result_types.reserve(num_results); + for (int64_t idx = 0; idx < num_results; ++idx) { + auto ty = UpdateElementTypeTo(op.getType(idx), element_types[idx]); + while_result_types.push_back(ty); + } + + mhlo_op = rewriter.create(loc, TypeRange(while_result_types), + adaptor.getOperands()); + } else { + mhlo_op = rewriter.create(loc, op.getResultTypes(), + adaptor.getOperands()); + } + + // Replace all uses of `op` results with the newly created op. + rewriter.replaceOp(op, mhlo_op.getResults()); + + int64_t num_regions = op.getNumRegions(); + for (int64_t idx = 0; idx < num_regions; ++idx) { + Region ®ion = mhlo_op.getBodyRegion(idx); + rewriter.inlineRegionBefore(op.getBodyRegion(idx), region, region.end()); + + // Update region's entry blocks argument types to handle quantized element + // types. + if constexpr (std::is_same::value) { + TypeConverter::SignatureConversion signature(num_results); + Block &block = region.front(); + for (auto &[block_idx, original_ty] : + llvm::enumerate(block.getArgumentTypes())) { + TensorType updated_ty = + UpdateElementTypeTo(original_ty, element_types[block_idx]); + signature.addInputs(block_idx, {updated_ty}); + } + rewriter.applySignatureConversion(®ion, signature); + } + } + return success(); + } +}; } // end namespace #include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc" - +// LINT.IfChange void PopulateLegalizeTfPatterns(MLIRContext *context, RewritePatternSet *patterns) { populateWithGenerated(*patterns); @@ -8906,7 +8442,6 @@ void PopulateLegalizeTfPatterns(MLIRContext *context, ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp, ConvertProdOp, - ConvertQrOp, ConvertDynamicRangeOp, ConvertMatrixDiagPartV3Op, ConvertRangeOp, @@ -8959,9 +8494,13 @@ void PopulateLegalizeTfPatterns(MLIRContext *context, ConvertStridedSliceOpDynamic, ConvertConv2DBackpropInputDynamic, ConvertConv2DBackpropFilterDynamic, - ConvertDynamicStitchOpDynamic>(context); + ConvertDynamicStitchOpDynamic, + LowerControlFlowOp, + LowerControlFlowOp, + LowerControlFlowOp, + LowerYieldOp>(context); // clang-format on } - +// LINT.ThenChange(:MlirPreferredOps) } // end namespace mhlo } // end namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_collective.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_collective.cc index a64cb31e5fd..2115fd69d0a 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_collective.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_collective.cc @@ -36,9 +36,9 @@ limitations under the License. 
#include "stablehlo/dialect/ChloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/xla/transforms/utils.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/utils/convert_op_folder.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/utils/hlo_utils.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/utils/convert_op_folder.h" +#include "tensorflow/compiler/xla/mlir_hlo/utils/hlo_utils.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace mlir { @@ -206,7 +206,7 @@ class ConvertXlaAllReduce LogicalResult matchAndRewrite(TF::XlaAllReduceOp all_reduce, PatternRewriter& rewriter) const override { DenseIntElementsAttr replica_groups; - if (failed(ConvertReplicaGroups(rewriter, all_reduce.group_assignment(), + if (failed(ConvertReplicaGroups(rewriter, all_reduce.getGroupAssignment(), replica_groups, all_reduce))) { return failure(); } @@ -217,7 +217,7 @@ class ConvertXlaAllReduce return failure(); } - StringRef reduce_op = all_reduce.reduce_op(); + StringRef reduce_op = all_reduce.getReduceOp(); StringRef merge_op, final_op; if (reduce_op == "Add") { @@ -243,8 +243,9 @@ class ConvertXlaAllReduce int64_t channel_id = channel_id_++; return ConvertAllReduce(rewriter, channel_id, all_reduce.getType(), - replica_groups, all_reduce.mode(), - all_reduce.input(), merge_op, final_op, all_reduce); + replica_groups, all_reduce.getMode(), + all_reduce.getInput(), merge_op, final_op, + all_reduce); } }; @@ -258,13 +259,14 @@ class ConvertCollectiveReduceV2 LogicalResult matchAndRewrite(TF::CollectiveReduceV2Op all_reduce, PatternRewriter& rewriter) const override { TF::CollectiveAssignGroupV2Op assign_group = - all_reduce.group_size().getDefiningOp(); + all_reduce.getGroupSize() + .getDefiningOp(); if (assign_group) { // Found a group assignment. Use replica_groups to represent group // assignment. - if (assign_group != all_reduce.group_key() + if (assign_group != all_reduce.getGroupKey() .getDefiningOp()) { return all_reduce->emitOpError() << "group_size and group_key are not from the " @@ -272,7 +274,8 @@ class ConvertCollectiveReduceV2 } DenseIntElementsAttr replica_groups; - if (failed(ConvertReplicaGroups(rewriter, assign_group.group_assignment(), + if (failed(ConvertReplicaGroups(rewriter, + assign_group.getGroupAssignment(), replica_groups, all_reduce))) { return failure(); } @@ -293,13 +296,14 @@ class ConvertCollectiveReduceV2 // ops are used. return ConvertAllReduce(rewriter, channel_id, all_reduce.getType(), replica_groups, /* mode=*/"CrossReplica", - all_reduce.input(), all_reduce.merge_op(), - all_reduce.final_op(), all_reduce); + all_reduce.getInput(), all_reduce.getMergeOp(), + all_reduce.getFinalOp(), all_reduce); } // No group assignment, use separate channels per group_key. DenseIntElementsAttr group_size_attr; - if (!matchPattern(all_reduce.group_size(), m_Constant(&group_size_attr))) { + if (!matchPattern(all_reduce.getGroupSize(), + m_Constant(&group_size_attr))) { return all_reduce.emitOpError() << "group_size must be a compile time constant"; } @@ -322,7 +326,8 @@ class ConvertCollectiveReduceV2 // TODO(b/226201111): Stop emitting CollectiveInfo when it is no longer // needed. 
DenseIntElementsAttr group_key_attr; - if (!matchPattern(all_reduce.group_key(), m_Constant(&group_key_attr))) { + if (!matchPattern(all_reduce.getGroupKey(), + m_Constant(&group_key_attr))) { return all_reduce.emitOpError() << "group_key must be a compile time constant"; } @@ -342,8 +347,8 @@ class ConvertCollectiveReduceV2 int64_t channel_id = channel_id_++; return ConvertAllReduce( rewriter, channel_id, all_reduce.getType(), replica_groups, - /* mode= */ "CrossReplicaAndPartition", all_reduce.input(), - all_reduce.merge_op(), all_reduce.final_op(), all_reduce); + /* mode= */ "CrossReplicaAndPartition", all_reduce.getInput(), + all_reduce.getMergeOp(), all_reduce.getFinalOp(), all_reduce); } }; diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_communication.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_communication.cc index aa8143a76cd..ff72dce3e7d 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_communication.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_communication.cc @@ -39,7 +39,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/xla/client/sharding_builder.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/side_effect_util.h" #include "tensorflow/compiler/xla/translate/mhlo_to_hlo/type_to_shape.h" @@ -330,18 +330,18 @@ Value RewriteHostComputeOp(OpBuilder& builder, int64_t& channel_id, Location loc = host_compute.getLoc(); SmallVector send_tokens; - for (auto operand : llvm::enumerate(host_compute.inputs())) { + for (auto operand : llvm::enumerate(host_compute.getInputs())) { auto send_token = CreateSendOp( - builder, channel_id, loc, operand.value(), host_compute.send_key(), + builder, channel_id, loc, operand.value(), host_compute.getSendKey(), operand.index(), token, xla::kXlaHostTransferTfRendezvousHandlerName); send_tokens.push_back(send_token); } token = CreateSinkToken(builder, loc, send_tokens, token); SmallVector recv_tokens; - for (auto result : llvm::enumerate(host_compute.outputs())) { + for (auto result : llvm::enumerate(host_compute.getOutputs())) { auto recv_token = CreateRecvOp( - builder, channel_id, loc, result.value(), host_compute.recv_key(), + builder, channel_id, loc, result.value(), host_compute.getRecvKey(), result.index(), token, xla::kXlaHostTransferTfRendezvousHandlerName); recv_tokens.push_back(recv_token); } @@ -356,7 +356,7 @@ Value RewriteSendToHostOp(OpBuilder& builder, int64_t& channel_id, TF::XlaSendToHostOp send_to_host, Value token) { builder.setInsertionPoint(send_to_host); token = CreateSendOp(builder, channel_id, send_to_host.getLoc(), - send_to_host.input(), send_to_host.key(), + send_to_host.getInput(), send_to_host.getKey(), /*index=*/0, token, xla::kXlaHostTransferTfRendezvousHandlerName); @@ -369,7 +369,7 @@ Value RewriteRecvFromHostOp(OpBuilder& builder, int64_t& channel_id, TF::XlaRecvFromHostOp recv_from_host, Value token) { builder.setInsertionPoint(recv_from_host); token = CreateRecvOp(builder, channel_id, recv_from_host.getLoc(), - recv_from_host.output(), recv_from_host.key(), + recv_from_host.getOutput(), recv_from_host.getKey(), /*index=*/0, token, xla::kXlaHostTransferTfRendezvousHandlerName); diff --git 
a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc deleted file mode 100644 index 78190e1dd57..00000000000 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc +++ /dev/null @@ -1,452 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file implements logic for lowering TensorFlow dialect's control flow to -// the XLA dialect. - -#include -#include -#include -#include -#include - -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SetVector.h" -#include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/OperationSupport.h" // from @llvm-project -#include "mlir/IR/Types.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Pass/PassRegistry.h" // from @llvm-project -#include "mlir/Transforms/RegionUtils.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" - -using mlir::PassRegistration; - -namespace mlir { -namespace mhlo { -namespace { - -#define GEN_PASS_DEF_LEGALIZETFCONTROLFLOW -#include "tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf_passes.h.inc" - -class LegalizeTFControlFlow - : public impl::LegalizeTFControlFlowBase { - public: - void runOnOperation() override; -}; -} // namespace - -std::unique_ptr> -createLegalizeTFControlFlowPass() { - return std::make_unique(); -} - -namespace { - -void Detuple(Value tuple, ValueRange replace, OpBuilder* builder) { - // De-tuple the results of the xla hlo if result. - for (auto result_it : llvm::enumerate(replace)) { - auto get_tuple_value = builder->create( - result_it.value().getLoc(), tuple, result_it.index()); - result_it.value().replaceAllUsesWith(get_tuple_value); - } -} - -// For mlir::IfOp or mlir::CaseOp, replace the uses of their region's block -// arguments with 'implicit_operands'. Here | 'implicit_operands' | == Number of -// arguments in any of the regions in IfOp or CaseOp. 
-void ReplaceBlockArgumentsWithImplicitOperands( - mlir::Operation* op, llvm::ArrayRef implicit_operands) { - assert((mlir::dyn_cast(*op) || - mlir::dyn_cast(*op)) && - "Unexpected mlir op in ReplaceBlockArgumentsWithImplicitOperands!"); - - for (auto& region : op->getRegions()) { - int implicit_operand_index = 0; - for (auto arg : region.getArguments()) { - assert(implicit_operand_index < implicit_operands.size()); - arg.replaceAllUsesWith(implicit_operands[implicit_operand_index++]); - } - - region.front().eraseArguments(0, region.getNumArguments()); - } -} - -// Imports the source region into the destination region. MHLO supports -// multiple arguments per branch and multiple returns which are individually -// tupled together during export to XLA. This tupling is needed as XLA if/while -// operation only supports one argument per branch and a single return value. -// `tuple_arg` allows any branch that requires additional arguments to have -// their values be tupled together. Similarly, `tuple_return` allows the results -// of the if/while operation to be tupled together. -void ImportXlaRegion(mlir::func::FuncOp func, Region* dest_region, Location loc, - bool tuple_return = true, bool tuple_arg = true) { - OpBuilder builder(dest_region); - - auto entry_block = builder.createBlock(dest_region); - func::CallOp result; - if (!tuple_arg) { - auto inputs = func.getFunctionType().getInputs(); - auto args = entry_block->addArguments( - inputs, SmallVector(inputs.size(), loc)); - ArrayRef callop_args(args.begin(), args.end()); - result = builder.create(loc, func, callop_args); - } else { - auto tuple_arg = entry_block->addArgument( - builder.getTupleType(func.getFunctionType().getInputs()), loc); - llvm::SmallVector detupled_args; - detupled_args.reserve(func.getNumArguments()); - - for (int64_t i = 0, s = func.getNumArguments(); i < s; i++) { - auto extract = builder.create(loc, tuple_arg, i); - detupled_args.push_back(extract); - } - - result = builder.create(loc, func, detupled_args); - } - - if (!tuple_return) { - builder.create(loc, result.getResults()); - } else { - auto tuple_op = builder.create(loc, result.getResults()); - builder.create(loc, tuple_op.getResult()); - } -} - -void LowerIf(TF::IfOp op) { - Location loc = op.getLoc(); - OpBuilder builder(op); - - SmallVector inputs(op.input()); - - // Create the new `mhlo.if` op. - auto if_op = builder.create(loc, op.getResultTypes(), op.cond()); - - // Import the regions for both the true and false cases. These regions - // must be updated to tuple the return results together and use the xla hlo - // return op. - ImportXlaRegion(op.then_function(), &if_op.getTrueBranch(), loc, - /*tuple_return=*/false, /*tuple_arg=*/false); - ImportXlaRegion(op.else_function(), &if_op.getFalseBranch(), loc, - /*tuple_return=*/false, /*tuple_arg=*/false); - - // Replace the uses of block-arguments of the IfOp with the - // implicit_operands. - ReplaceBlockArgumentsWithImplicitOperands(if_op.getOperation(), inputs); - - op->replaceAllUsesWith(if_op); - op.erase(); -} - -void LowerCase(TF::CaseOp op) { - Location loc = op.getLoc(); - OpBuilder builder(op); - - SmallVector inputs(op.input()); - - // Create the new `mhlo.case` op. - auto case_op = builder.create( - loc, op.getResultTypes(), op.branch_index(), op.branches().size()); - - // Import the regions for all branches. 
- for (unsigned i = 0; i < op.num_branches(); ++i) { - mlir::func::FuncOp branch_func = op.branch_function(i); - ImportXlaRegion(branch_func, &case_op.getBranches()[i], loc, - /*tuple_return=*/false, /*tuple_arg=*/false); - } - - // Replace the uses of block-arguments of the IfOp with the - // implicit_operands. - ReplaceBlockArgumentsWithImplicitOperands(case_op.getOperation(), inputs); - - op.replaceAllUsesWith(case_op); - op.erase(); -} - -void LowerWhile(TF::WhileOp op) { - Location loc = op.getLoc(); - OpBuilder builder(op); - - // XLA prefers tuple arguments for control flow due to XLA not supporting - // multiple return values. - SmallVector inputs(op.input()); - builder.setInsertionPoint(op); - - // Create the new `mhlo.while` op with inputs. - auto while_op = - builder.create(loc, op.getResultTypes(), inputs); - - // Import the regions for both the cond and body. - ImportXlaRegion(op.body_function(), &while_op.getBody(), loc, - /*tuple_return=*/false, /*tuple_arg=*/false); - ImportXlaRegion(op.cond_function(), &while_op.getCond(), loc, - /*tuple_return=*/false, /*tuple_arg=*/false); - - op->replaceAllUsesWith(while_op); - op.erase(); -} - -// Replaces all block arguments of a block with a single block arg of Tuple -// type `tuple_type`. Single block arguments are removed and remapped to -// get_tuple_element(tuple_arg, index). -void ReplaceBlockArgs(Block* block, Type tuple_type, OpBuilder* builder) { - auto tuple_arg = block->addArgument(tuple_type, block->getParent()->getLoc()); - Detuple(tuple_arg, block->getArguments().drop_back(1), builder); - for (int i = block->getNumArguments() - 2; i >= 0; --i) - block->eraseArgument(i); -} - -// Replaces implicitly captured value uses with block arguments. -llvm::SmallVector ReplaceImplicitInputs( - Block* block, int offset, ArrayRef implicit_inputs) { - llvm::SmallVector implicit_input_elements; - implicit_input_elements.reserve(implicit_inputs.size()); - - Region* region = block->getParent(); - - for (auto& implicit_input : llvm::enumerate(implicit_inputs)) { - Value implicit_input_value = implicit_input.value(); - BlockArgument arg = block->getArgument(implicit_input.index() + offset); - implicit_input_elements.emplace_back(arg); - for (auto& use : - llvm::make_early_inc_range(implicit_input_value.getUses())) { - if (!region->isAncestor(use.getOwner()->getParentRegion())) continue; - use.set(arg); - } - } - - return implicit_input_elements; -} - -// Replaces implicitly captured value uses with tuple block argument. -// get_tuple_element's are created to extract specific values. Values from -// get_tuple_element's are returned in the order of `implicit_inputs`. 
-llvm::SmallVector ReplaceImplicitInputsWithTupleElements( - Block* block, int offset, ArrayRef implicit_inputs, - OpBuilder* builder) { - llvm::SmallVector implicit_input_elements; - implicit_input_elements.reserve(implicit_inputs.size()); - - Region* region = block->getParent(); - assert(block->getNumArguments() == 1); - - BlockArgument tuple_arg = block->getArgument(0); - for (auto& implicit_input : llvm::enumerate(implicit_inputs)) { - Value implicit_input_value = implicit_input.value(); - auto get_tuple_element = builder->create( - implicit_input_value.getLoc(), tuple_arg, - implicit_input.index() + offset); - implicit_input_elements.emplace_back(get_tuple_element.getResult()); - for (auto& use : - llvm::make_early_inc_range(implicit_input_value.getUses())) { - if (!region->isAncestor(use.getOwner()->getParentRegion())) continue; - use.set(get_tuple_element.getResult()); - } - } - - return implicit_input_elements; -} - -// Finds and replaces implicitly captured value uses with tuple block argument. -// A tuple of implicitly captured values is also created and returned, for use -// as an operand to the associated mhlo control flow op. -Value TupleImplicitInputs(Region& region, Location loc, OpBuilder* builder) { - llvm::SetVector implicit_inputs; - getUsedValuesDefinedAbove(region, region, implicit_inputs); - llvm::ArrayRef implicit_inputs_ref = implicit_inputs.getArrayRef(); - Value tuple_input = builder->create(loc, implicit_inputs_ref); - Block& block = region.front(); - // `tf.CaseRegion`/`tf.IfRegion` are expected to have no block arguments and - // instead all inputs used by their branch regions are implicitly captured - // from above. - assert(block.getNumArguments() == 0); - block.addArgument(tuple_input.getType(), loc); - builder->setInsertionPointToStart(&block); - ReplaceImplicitInputsWithTupleElements(&block, /*offset=*/0, - implicit_inputs_ref, builder); - return tuple_input; -} - -// Replaces block terminator (tf.Yield) with `mhlo.return`. Additional results -// can be returned if `extra_results` is not empty. If `tuple_return` is -// set, a tuple of the return values will be set as the terminator operand. -void ReplaceTerminator(Block* block, ArrayRef extra_results, - OpBuilder* builder, bool tuple_return = true) { - Operation* terminator = block->getTerminator(); - assert(isa(terminator)); - Location loc = terminator->getLoc(); - - builder->setInsertionPoint(terminator); - auto results = llvm::to_vector<4>(terminator->getOperands()); - results.append(extra_results.begin(), extra_results.end()); - if (tuple_return) { - auto tuple_results = builder->create(loc, results); - builder->create(loc, tuple_results.getResult()); - } else { - builder->create(loc, results); - } - - terminator->erase(); -} - -void LowerIfRegion(TF::IfRegionOp op) { - Location loc = op.getLoc(); - OpBuilder builder(op); - - builder.setInsertionPoint(op); - ReplaceTerminator(&op.then_branch().front(), /*extra_results=*/{}, &builder, - /*tuple_return=*/false); - - builder.setInsertionPoint(op); - ReplaceTerminator(&op.else_branch().front(), /*extra_results=*/{}, &builder, - /*tuple_return=*/false); - - // Create the new `mhlo.if` op and take ownership of regions from - // `tf.IfRegion` op. - builder.setInsertionPoint(op); - auto if_op = builder.create(loc, op.getResultTypes(), op.cond()); - if_op.getTrueBranch().takeBody(op.then_branch()); - if_op.getFalseBranch().takeBody(op.else_branch()); - - // Replace all uses of `op` results with that of `mhlo.IfOp`. 
- op->replaceAllUsesWith(if_op); - - op.erase(); -} - -void LowerCaseRegion(TF::CaseRegionOp op) { - Location loc = op.getLoc(); - OpBuilder builder(op); - - for (Region& region : op.branches()) { - builder.setInsertionPoint(op); - ReplaceTerminator(®ion.front(), /*extra_results=*/{}, &builder, - /*tuple_return=*/false); - } - - // Create the new `mhlo.case` op and take ownership of regions from - // `tf.CaseRegion` op. - builder.setInsertionPoint(op); - auto case_op = builder.create( - loc, op.getResultTypes(), op.branch_index(), op.branches().size()); - for (auto region : llvm::zip(case_op.getBranches(), op.branches())) - std::get<0>(region).takeBody(std::get<1>(region)); - - // Replace all uses of `op` results with that of `mhlo.CaseOp`. - op.replaceAllUsesWith(case_op); - op.erase(); -} - -void LowerWhileRegion(TF::WhileRegionOp op) { - Location loc = op.getLoc(); - OpBuilder builder(op); - - SmallVector inputs(op.input()); - const int inputs_size = inputs.size(); - llvm::SetVector implicit_inputs; - getUsedValuesDefinedAbove(op.getOperation()->getRegions(), implicit_inputs); - inputs.append(implicit_inputs.begin(), implicit_inputs.end()); - - builder.setInsertionPoint(op); - - // Create the new `mhlo.while` op with 'inputs'. Implicit inputs are also - // returned. - auto while_result_types = llvm::to_vector<4>(op.getResultTypes()); - while_result_types.reserve(while_result_types.size() + - implicit_inputs.size()); - for (const auto& implicit_input : implicit_inputs) - while_result_types.emplace_back(implicit_input.getType()); - auto while_op = - builder.create(loc, while_result_types, inputs); - - // Rewrite cond and associated block arguments and terminator. Ownership of - // cond region is transfered over from `tf.WhileRegion` to `mhlo.while`. - Region& cond = while_op.getCond(); - cond.takeBody(op.cond()); - Block& cond_block = cond.front(); - builder.setInsertionPointToStart(&cond_block); - - // Add args corresponding to 'implicit_inputs'. - for (const auto& implicit_input : implicit_inputs) - cond_block.addArgument(implicit_input.getType(), loc); - ReplaceImplicitInputs(&cond_block, inputs_size, - implicit_inputs.getArrayRef()); - // Cond always returns a single result of bool type. - ReplaceTerminator(&cond_block, /*extra_results=*/{}, &builder, - /*tuple_return=*/false); - - // Rewrite body and associated block arguments and terminator. Ownership of - // body region is transfered over from `tf.WhileRegion` to `mhlo.while`. - Region& body = while_op.getBody(); - body.takeBody(op.body()); - Block& body_block = body.front(); - builder.setInsertionPointToStart(&body_block); - // Add args corresponding to 'implicit_inputs'. - for (const auto& implicit_input : implicit_inputs) - body_block.addArgument(implicit_input.getType(), loc); - auto implicit_input_elements = ReplaceImplicitInputs( - &body_block, inputs_size, implicit_inputs.getArrayRef()); - ReplaceTerminator(&body_block, implicit_input_elements, &builder, false); - - // Replace all uses of `op` results with that of `mhlo.while`. 
- builder.setInsertionPoint(op); - if (while_op.getNumResults() > 1) { - for (const auto& result_it : llvm::enumerate(op.getResults())) - result_it.value().replaceAllUsesWith( - while_op.getResult(result_it.index())); - } else { - op->replaceAllUsesWith(while_op); - } - op.erase(); -} -} // namespace - -void LegalizeTFControlFlow::runOnOperation() { - getOperation().walk([&](Operation* op) { - if (auto while_op = dyn_cast(op)) { - LowerWhile(while_op); - return; - } - if (auto while_region_op = dyn_cast(op)) { - LowerWhileRegion(while_region_op); - return; - } - if (auto if_op = dyn_cast(op)) { - LowerIf(if_op); - return; - } - if (auto if_region_op = dyn_cast(op)) { - LowerIfRegion(if_region_op); - return; - } - if (auto case_op = dyn_cast(op)) { - LowerCase(case_op); - return; - } - if (auto case_region_op = dyn_cast(op)) { - LowerCaseRegion(case_region_op); - return; - } - }); -} -} // namespace mhlo -} // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index 8804fb09af6..bf92d8cc8e3 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -21,7 +21,7 @@ include "mlir/Dialect/Func/IR/FuncOps.td" include "mlir/Dialect/Tensor/IR/TensorOps.td" include "stablehlo/dialect/ChloOps.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" -include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" +include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.td" def SignedIntTensor : TensorOf<[I1, I8, I16, I32, I64]>; def UnsignedIntTensor : TensorOf<[UI8, UI16, UI32, UI64]>; @@ -64,8 +64,8 @@ def CastElementsToI64Elements : NativeCodeCall< // ApproximateEqual op pattern. //===----------------------------------------------------------------------===// -class HLO_ComparisonDirectionValue : - ConstantAttr; +class MHLO_ComparisonDirectionValue : + ConstantAttr; class CHLO_ComparisonDirectionValue : ConstantAttr; @@ -73,11 +73,11 @@ class CHLO_ComparisonDirectionValue : // TODO(b/228291745): Assert that $x and $y have the same shape. def : Pat<(TF_ApproximateEqualOp:$result $x, $y, $tolerance), (CHLO_BroadcastCompareOp - (HLO_AbsOp:$abs (HLO_SubtractOp $x, $y)), - (CastValueToElementType $result, (HLO_ConstantOp $tolerance), $abs), + (MHLO_AbsOp:$abs (MHLO_SubtractOp $x, $y)), + (CastValueToElementType $result, (MHLO_ConstantOp $tolerance), $abs), (NullDenseIntElementsAttr), CHLO_ComparisonDirectionValue<"LT">, - (HLO_DEFAULT_COMPARISON_TYPE))>; + (CHLO_DEFAULT_COMPARISON_TYPE))>; //===----------------------------------------------------------------------===// // Assert op pattern. @@ -131,7 +131,7 @@ def LowerRightShiftUnsigned : // // return floor(div(x, y)) def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r), - (HLO_FloorOp + (MHLO_FloorOp (CHLO_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r))), [(IEEEFloatTensor $l)]>; @@ -146,7 +146,7 @@ def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r), // dimensions. This computes the broadcast of 'l' to broadcast('l', 'r') // without returning the broadcast of 'r' to broadcast('l', 'r'). 
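[Editor's note, not part of the patch] The integer FloorDiv pattern that follows, and the FloorMod pattern a little further down, build floor semantics on top of the truncating mhlo.divide/mhlo.remainder: the truncated quotient is decremented when the division is inexact and the operand signs differ, and the truncated remainder gets the divisor added when it is nonzero and its sign differs from the divisor's. A minimal Python model of that arithmetic:

def floor_div(x, y):
    q = int(x / y)   # truncating quotient, as mhlo.divide gives for signed ints
    # select(mul(q, y) != x  and  (x < 0) != (y < 0),  q - 1,  q)
    if q * y != x and (x < 0) != (y < 0):
        q -= 1
    return q

def floor_mod(x, y):
    r = x - int(x / y) * y   # truncating remainder, as mhlo.remainder gives
    # trunc_mod != 0 && (y < 0 != trunc_mod < 0) ? trunc_mod + y : trunc_mod
    if r != 0 and (y < 0) != (r < 0):
        r += y
    return r

assert floor_div(-7, 2) == -4 and (-7) // 2 == -4
assert floor_mod(-7, 2) == 1 and (-7) % 2 == 1
assert floor_div(7, -2) == -4 and floor_mod(7, -2) == -1
assert floor_div(6, 3) == 2 and floor_mod(6, 3) == 0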
def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r), - (HLO_SelectOp + (MHLO_SelectOp (CHLO_BroadcastAndOp (CHLO_BroadcastCompareOp (CHLO_BroadcastMulOp:$mul @@ -157,18 +157,18 @@ def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r), (CHLO_DEFAULT_COMPARISON_TYPE)), (CHLO_BroadcastCompareOp (CHLO_BroadcastCompareOp:$l_cmp $l, - (HLO_ConstantOp:$l_zeros (GetScalarOfType<0> $l)), + (MHLO_ConstantOp:$l_zeros (GetScalarOfType<0> $l)), (NullDenseIntElementsAttr), CHLO_ComparisonDirectionValue<"LT">, (CHLO_DEFAULT_COMPARISON_TYPE)), (CHLO_BroadcastCompareOp:$r_cmp $r, - (HLO_ConstantOp:$r_zeros (GetScalarOfType<0> $r)), + (MHLO_ConstantOp:$r_zeros (GetScalarOfType<0> $r)), (NullDenseIntElementsAttr), CHLO_ComparisonDirectionValue<"LT">, (CHLO_DEFAULT_COMPARISON_TYPE)), (BinBroadcastDimensions $l_cmp, $r_cmp), CHLO_ComparisonDirectionValue<"NE">, (CHLO_DEFAULT_COMPARISON_TYPE)), (NullDenseIntElementsAttr)), (CHLO_BroadcastSubOp $div, - (HLO_ConstantOp:$ones (GetScalarOfType<1> $div)), + (MHLO_ConstantOp:$ones (GetScalarOfType<1> $div)), (NullDenseIntElementsAttr)), $div), [(SignedIntTensor $l)]>; @@ -184,16 +184,16 @@ def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r), // return trunc_mod != 0 && (y < 0 != trunc_mod < 0) ? trunc_mod + y // : trunc_mod def : Pat<(TF_FloorModOp AnyTensor:$l, AnyTensor:$r), - (HLO_SelectOp + (MHLO_SelectOp (CHLO_BroadcastAndOp (CHLO_BroadcastCompareOp (CHLO_BroadcastRemOp:$rem $l, $r, (BinBroadcastDimensions $l, $r)), - (HLO_ConstantOp:$l_zeros (GetScalarOfType<0> $l)), + (MHLO_ConstantOp:$l_zeros (GetScalarOfType<0> $l)), (NullDenseIntElementsAttr), CHLO_ComparisonDirectionValue<"NE">, (CHLO_DEFAULT_COMPARISON_TYPE)), (CHLO_BroadcastCompareOp (CHLO_BroadcastCompareOp:$r_cmp $r, - (HLO_ConstantOp:$r_zeros (GetScalarOfType<0> $r)), + (MHLO_ConstantOp:$r_zeros (GetScalarOfType<0> $r)), (NullDenseIntElementsAttr), CHLO_ComparisonDirectionValue<"LT">, (CHLO_DEFAULT_COMPARISON_TYPE)), (CHLO_BroadcastCompareOp:$rem_cmp $rem, $r_zeros, @@ -214,10 +214,10 @@ def : Pat<(TF_FloorModOp AnyTensor:$l, AnyTensor:$r), def Get2DTransposePerm: NativeCodeCall< "Get2DTransposePerm($0, &$_builder)">; -def : Pat<(TF_RiscAddOp $l, $r), (HLO_AddOp $l, $r)>; +def : Pat<(TF_RiscAddOp $l, $r), (MHLO_AddOp $l, $r)>; def : Pat<(TF_RiscDotOp $a, $b, $transpose_a, $transpose_b), - (HLO_DotOp + (MHLO_DotOp (TF_TransposeOp $a, (TF_ConstOp (Get2DTransposePerm $transpose_a))), (TF_TransposeOp $b, (TF_ConstOp (Get2DTransposePerm $transpose_b))), /*precision_config=*/(NullArrayAttr))>; @@ -259,7 +259,7 @@ class EqualityPat (CHLO_BroadcastCompareOp $l, $r, (BinBroadcastDimensions $l, $r), direction, (CHLO_DEFAULT_COMPARISON_TYPE)), - [(HLO_Tensor $l)]>; + [(MHLO_Tensor $l)]>; def : EqualityPat>; def : EqualityPat>; @@ -290,7 +290,7 @@ def IsShapedTensor // if HLO constant op is introduced as an replacement for the TensorFlow // Constant op. 
def : Pat<(TF_ConcatV2Op $inputs, (ConstantLikeMatcher OneElementAttr:$axis)), - (HLO_ConcatenateOp $inputs, + (MHLO_ConcatenateOp $inputs, (GetHLOAxisFromTFAxisVariadic $axis, $inputs)), [(HasRankedFirstOperand $inputs)]>; @@ -299,7 +299,7 @@ def : Pat<(TF_ConcatV2Op $inputs, (ConstantLikeMatcher OneElementAttr:$axis)), //===----------------------------------------------------------------------===// def : Pat<(TF_CollectivePermuteOp $input, (ConstantLikeMatcher ElementsAttr:$source_target_pairs)), - (HLO_CollectivePermuteOp $input, + (MHLO_CollectivePermuteOp $input, (CastElementsToI64Elements $source_target_pairs), (NullChannelHandleAttr))>; @@ -308,22 +308,23 @@ def : Pat<(TF_CollectivePermuteOp $input, (ConstantLikeMatcher ElementsAttr:$sou //===----------------------------------------------------------------------===// def : Pat<(TF_CrossReplicaSumOp $input, (ConstantLikeMatcher ElementsAttr:$group_assignment)), - (HLO_CrossReplicaSumOp $input, + (MHLO_CrossReplicaSumOp $input, (CastElementsToI64Elements $group_assignment))>; //===----------------------------------------------------------------------===// // All2All op patterns. //===----------------------------------------------------------------------===// +def ValueToVariadic: NativeCodeCall<"SmallVector{$0}">; def : Pat<(TF_AllToAllOp AnyRankedTensor:$input, (ConstantLikeMatcher ElementsAttr:$group_assignment), I64Attr:$concat_dimension, $split_dimension, $split_count), - (HLO_AllToAllOp $input, $split_dimension, $concat_dimension, $split_count, (CastElementsToI64Elements $group_assignment))>; + (MHLO_AllToAllOp (ValueToVariadic $input), $split_dimension, $concat_dimension, $split_count, (CastElementsToI64Elements $group_assignment), (NullChannelHandleAttr))>; //===----------------------------------------------------------------------===// // FFT op patterns. //===----------------------------------------------------------------------===// -class HLO_FftTypeValue : - ConstantAttr; +class MHLO_FftTypeValue : + ConstantAttr; def GetInnerDimFromValue : NativeCodeCall< "GetInnerDimFromValue($0.getType().cast(), &$_builder)">; @@ -332,11 +333,11 @@ def CheckInnerDimStatic : Constraint(), &$_builder)">>; def : Pat<(TF_FFTOp:$res $input), - (HLO_FftOp $input, HLO_FftTypeValue<"FFT">, (GetInnerDimFromValue $res)), + (MHLO_FftOp $input, MHLO_FftTypeValue<"FFT">, (GetInnerDimFromValue $res)), [(CheckInnerDimStatic $input)]>; def : Pat<(TF_IFFTOp:$res $input), - (HLO_FftOp $input, HLO_FftTypeValue<"IFFT">, (GetInnerDimFromValue $res)), + (MHLO_FftOp $input, MHLO_FftTypeValue<"IFFT">, (GetInnerDimFromValue $res)), [(CheckInnerDimStatic $input)]>; //===----------------------------------------------------------------------===// @@ -352,7 +353,7 @@ def : Pat<(TF_IFFTOp:$res $input), // def LegalizeGatherV2 : // Pat<(TF_GatherV2Op AnyRankedTensor:$params, AnyRankedTensor:$indices, // (ConstantLikeMatcher ElementsAttr:$axis), $batch_dims), -// (HLO_TorchIndexSelectOp $params, $indices, +// (MHLO_TorchIndexSelectOp $params, $indices, // (GetHLOAxisFromTFAxis $axis, $params), // (GetHLOAxisFromTFAxis $batch_dims, $indices))>; @@ -373,7 +374,7 @@ def GetInteriorPadding : NativeCodeCall < // TODO: commented by DISC due to mhlo.pad is not supported in bufferization // and all the consequent passes. 
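[Editor's note, not part of the patch] Context for the commented-out PadV2 pattern just below: TF paddings arrive as an [rank, 2] constant, while mhlo.pad takes separate low/high/interior vectors; column 0 maps to low edge padding, column 1 to high edge padding, and PadV2 itself has no interior padding. A rough NumPy model under that assumption (GetInteriorPadding is presumed here to yield all zeros for PadV2):

import numpy as np

def pad_v2_as_hlo_pad(x, paddings, pad_value):
    paddings = np.asarray(paddings)      # shape [rank, 2]
    edge_low = paddings[:, 0]            # SliceDenseIntElementsAttrColumn2D<"0">
    edge_high = paddings[:, 1]           # SliceDenseIntElementsAttrColumn2D<"1">
    interior = np.zeros_like(edge_low)   # assumed GetInteriorPadding result
    assert (interior == 0).all()         # PadV2 has no interior padding
    return np.pad(x, list(zip(edge_low, edge_high)),
                  mode="constant", constant_values=pad_value)

x = np.arange(6).reshape(2, 3)
y = pad_v2_as_hlo_pad(x, [[1, 0], [0, 2]], -1)
assert y.shape == (3, 5)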
// def : Pat<(TF_PadV2Op $input, (ConstantLikeMatcher ElementsAttr:$padding), $c), -// (HLO_PadOp $input, $c, +// (MHLO_PadOp $input, $c, // (SliceDenseIntElementsAttrColumn2D<"0"> $padding), // (SliceDenseIntElementsAttrColumn2D<"1"> $padding), // (GetInteriorPadding $padding))>; @@ -394,7 +395,7 @@ foreach src = [TF_PreventGradientOp, TF_CheckNumericsOp] in //===----------------------------------------------------------------------===// def : Pat<(TF_MatMulOp $a, $b, $transpose_a, $transpose_b), - (HLO_DotOp + (MHLO_DotOp (TF_TransposeOp $a, (TF_ConstOp (Get2DTransposePerm $transpose_a))), (TF_TransposeOp $b, (TF_ConstOp (Get2DTransposePerm $transpose_b))), /*precision_config=*/(NullArrayAttr))>; @@ -404,41 +405,41 @@ def : Pat<(TF_MatMulOp $a, $b, $transpose_a, $transpose_b), //===----------------------------------------------------------------------===// def : Pat<(TF_ZerosLikeOp AnyTensor:$arg), - (HLO_ConstantLike<"0"> $arg)>; + (MHLO_ConstantLike<"0"> $arg)>; //===----------------------------------------------------------------------===// // Lower `tf.OnesLike` //===----------------------------------------------------------------------===// def : Pat<(TF_OnesLikeOp AnyTensor:$arg), - (HLO_ConstantLike<"1"> $arg)>; + (MHLO_ConstantLike<"1"> $arg)>; //===----------------------------------------------------------------------===// // Elu op patterns. //===----------------------------------------------------------------------===// def : Pat<(TF_EluOp AnyTensor:$features), - (HLO_SelectOp - (HLO_CompareOp + (MHLO_SelectOp + (MHLO_CompareOp $features, - (HLO_ConstantLike<"0">:$zero $features), - HLO_ComparisonDirectionValue<"GT">, (HLO_DEFAULT_COMPARISON_TYPE)), + (MHLO_ConstantLike<"0">:$zero $features), + MHLO_ComparisonDirectionValue<"GT">, (MHLO_DEFAULT_COMPARISON_TYPE)), $features, - (HLO_Expm1Op $features))>; + (MHLO_Expm1Op $features))>; def : Pat<(TF_EluGradOp AnyStaticShapeTensor:$gradients, AnyRankedTensor:$features), - (HLO_SelectOp + (MHLO_SelectOp (CHLO_BroadcastCompareOp $features, - (HLO_ConstantOp:$zero (GetScalarOfType<0> $features)), + (MHLO_ConstantOp:$zero (GetScalarOfType<0> $features)), (BinBroadcastDimensions $zero, $features), CHLO_ComparisonDirectionValue<"GT">, (CHLO_DEFAULT_COMPARISON_TYPE)), $gradients, - (HLO_MulOp + (MHLO_MulOp $gradients, (CHLO_BroadcastAddOp $features, - (HLO_ConstantOp:$one (GetScalarOfType<1> $features)), + (MHLO_ConstantOp:$one (GetScalarOfType<1> $features)), (BinBroadcastDimensions $one, $features))))>; //===----------------------------------------------------------------------===// @@ -451,24 +452,24 @@ def : Pat<(TF_EluGradOp AnyStaticShapeTensor:$gradients, AnyRankedTensor:$featur // TODO(hinsu): Lower quantized types after supporting them in GetScalarOfType. def : Pat<(TF_ReluOp AnyTensor:$input), (CHLO_BroadcastMaxOp - (HLO_ConstantOp:$zero (GetScalarOfType<0> $input)), $input, + (MHLO_ConstantOp:$zero (GetScalarOfType<0> $input)), $input, (BinBroadcastDimensions $zero, $input)), [(TF_IntOrFpTensor $input)]>; // TODO(hinsu): Lower quantized types after supporting them in GetScalarOfType. 
def : Pat<(TF_Relu6Op AnyRankedTensor:$input), - (HLO_ClampOp (HLO_ConstantOp (GetScalarOfType<0> $input)), $input, - (HLO_ConstantOp (GetScalarOfType<6> $input))), + (MHLO_ClampOp (MHLO_ConstantOp (GetScalarOfType<0> $input)), $input, + (MHLO_ConstantOp (GetScalarOfType<6> $input))), [(TF_IntOrFpTensor $input)]>; // ReluGrad(gradients, features) = gradients * (features > 0) // The condition that $gradients and $features need to have the same shape is // implicitly enforced: $zero is created to have the same shape as $features, -// HLO_SelectOp enforces that $gradients and $zero have the same shape. +// MHLO_SelectOp enforces that $gradients and $zero have the same shape. def : Pat<(TF_ReluGradOp AnyTensor:$gradients, AnyTensor:$features), - (HLO_SelectOp - (HLO_CompareOp $features, (HLO_ConstantLike<"0">:$zero $features), - HLO_ComparisonDirectionValue<"GT">, (HLO_DEFAULT_COMPARISON_TYPE)), + (MHLO_SelectOp + (MHLO_CompareOp $features, (MHLO_ConstantLike<"0">:$zero $features), + MHLO_ComparisonDirectionValue<"GT">, (MHLO_DEFAULT_COMPARISON_TYPE)), $gradients, $zero)>; //===----------------------------------------------------------------------===// @@ -478,9 +479,9 @@ def : Pat<(TF_ReluGradOp AnyTensor:$gradients, AnyTensor:$features), /// Converts a TF::SoftsignOp to HLO. /// Softsign(features) = features / (1 + abs(features)) def : Pat<(TF_SoftsignOp AnyTensor:$input), - (HLO_DivOp + (MHLO_DivOp $input, - (HLO_AddOp (HLO_ConstantLike<"1"> $input), (HLO_AbsOp $input)) + (MHLO_AddOp (MHLO_ConstantLike<"1"> $input), (MHLO_AbsOp $input)) ) >; @@ -489,12 +490,12 @@ def : Pat<(TF_SoftsignOp AnyTensor:$input), def : Pattern< (TF_SoftsignGradOp AnyRankedTensor:$gradients, AnyRankedTensor:$features), [(CHLO_BroadcastAddOp:$add - (HLO_ConstantOp:$one (GetScalarOfType<1> $features)), (HLO_AbsOp $features), + (MHLO_ConstantOp:$one (GetScalarOfType<1> $features)), (MHLO_AbsOp $features), (BinBroadcastDimensions $one, $features) ), (CHLO_BroadcastDivOp $gradients, - (HLO_MulOp $add, $add), + (MHLO_MulOp $add, $add), (BinBroadcastDimensions $gradients, $add) ) ]>; @@ -504,7 +505,7 @@ def : Pattern< //===----------------------------------------------------------------------===// def UnpackStartingIndices: NativeCodeCall< - "UnpackTensorAlongZeroDim($0.getLoc(), $1, &$_builder).output()">; + "UnpackTensorAlongZeroDim($0.getLoc(), $1, &$_builder).getOutput()">; def CanBeTranslatedToDynamicSlice : Constraint())">>; @@ -513,9 +514,9 @@ def TFSliceSizes2HLOSliceSizes : NativeCodeCall< "TFSliceSizes2HLOSliceSizes($0, $1, $2.cast()," "&$_builder)">; -def : Pat<(TF_SliceOp:$op HLO_Tensor:$input, HLO_Tensor:$starting_indices, +def : Pat<(TF_SliceOp:$op MHLO_Tensor:$input, MHLO_Tensor:$starting_indices, (ConstantLikeMatcher AnyAttr:$slice_sizes)), - (HLO_DynamicSliceOp $input, + (MHLO_DynamicSliceOp $input, (UnpackStartingIndices $op, $starting_indices), (TFSliceSizes2HLOSliceSizes $input, $starting_indices, $slice_sizes)), [(CanBeTranslatedToDynamicSlice $input, $starting_indices, @@ -525,8 +526,8 @@ def : Pat<(TF_SliceOp:$op HLO_Tensor:$input, HLO_Tensor:$starting_indices, // Select op patterns. 
//===----------------------------------------------------------------------===// - def : Pat<(TF_SelectV2Op HLO_Tensor:$pred, HLO_Tensor:$on_true, - HLO_Tensor:$on_false), + def : Pat<(TF_SelectV2Op MHLO_Tensor:$pred, MHLO_Tensor:$on_true, + MHLO_Tensor:$on_false), (CHLO_BroadcastSelectOp $pred, $on_true, $on_false)>; //===----------------------------------------------------------------------===// @@ -534,7 +535,9 @@ def : Pat<(TF_SliceOp:$op HLO_Tensor:$input, HLO_Tensor:$starting_indices, //===----------------------------------------------------------------------===// def ArgTypesMatchCallee : Constraint< - CPred<"ArgTypesMatchCallee($0[0].getOwner(), $1, $2)">>; + // $0 is a resultset (possibly empty), and $_op isn't assigned. So retrieve + // the op using the builder. + CPred<"ArgTypesMatchCallee(&*$_builder.getInsertionPoint(), $1, $2)">>; foreach callOp = [TF_PartitionedCallOp, TF_StatefulPartitionedCallOp] in { def : Pat<(callOp:$op $args, FlatSymbolRefAttr:$f, @@ -557,75 +560,75 @@ def : Pat<(TF_LegacyCallOp:$op $args, FlatSymbolRefAttr:$f, $attr), def ConvertAxisAttr : NativeCodeCall<"ConvertAxisAttr($0, $1.cast(), &$_builder)">; def : Pat<(TF_ReverseV2Op AnyRankedTensor:$values, (ConstantLikeMatcher ElementsAttr:$axis)), - (HLO_ReverseOp $values, (ConvertAxisAttr $values, $axis))>; + (MHLO_ReverseOp $values, (ConvertAxisAttr $values, $axis))>; //===----------------------------------------------------------------------===// // Unary op patterns. //===----------------------------------------------------------------------===// foreach Mapping = [ - [TF_AbsOp, HLO_AbsOp], + [TF_AbsOp, MHLO_AbsOp], [TF_AcosOp, CHLO_AcosOp], [TF_AcoshOp, CHLO_AcoshOp], [TF_AsinOp, CHLO_AsinOp], [TF_AsinhOp, CHLO_AsinhOp], [TF_AtanOp, CHLO_AtanOp], [TF_AtanhOp, CHLO_AtanhOp], - [TF_CeilOp, HLO_CeilOp], + [TF_CeilOp, MHLO_CeilOp], [TF_CoshOp, CHLO_CoshOp], - [TF_ComplexAbsOp, HLO_AbsOp], + [TF_ComplexAbsOp, MHLO_AbsOp], [TF_ConjOp, CHLO_ConjOp], - [TF_CosOp, HLO_CosineOp], + [TF_CosOp, MHLO_CosineOp], [TF_DigammaOp, CHLO_DigammaOp], - [TF_ExpOp, HLO_ExpOp], - [TF_Expm1Op, HLO_Expm1Op], + [TF_ExpOp, MHLO_ExpOp], + [TF_Expm1Op, MHLO_Expm1Op], [TF_ErfOp, CHLO_ErfOp], [TF_ErfcOp, CHLO_ErfcOp], - [TF_FloorOp, HLO_FloorOp], - [TF_ImagOp, HLO_ImagOp], - [TF_InvertOp, HLO_NotOp], - [TF_IsFiniteOp, HLO_IsFiniteOp], + [TF_FloorOp, MHLO_FloorOp], + [TF_ImagOp, MHLO_ImagOp], + [TF_InvertOp, MHLO_NotOp], + [TF_IsFiniteOp, MHLO_IsFiniteOp], [TF_IsInfOp, CHLO_IsInfOp], [TF_LgammaOp, CHLO_LgammaOp], - [TF_LogOp, HLO_LogOp], - [TF_Log1pOp, HLO_Log1pOp], - [TF_LogicalNotOp, HLO_NotOp], - [TF_NegOp, HLO_NegOp], - [TF_RealOp, HLO_RealOp], - [TF_RsqrtOp, HLO_RsqrtOp], - [TF_SigmoidOp, HLO_LogisticOp], + [TF_LogOp, MHLO_LogOp], + [TF_Log1pOp, MHLO_Log1pOp], + [TF_LogicalNotOp, MHLO_NotOp], + [TF_NegOp, MHLO_NegOp], + [TF_RealOp, MHLO_RealOp], + [TF_RsqrtOp, MHLO_RsqrtOp], + [TF_SigmoidOp, MHLO_LogisticOp], [TF_SinhOp, CHLO_SinhOp], - [TF_SinOp, HLO_SineOp], - [TF_SqrtOp, HLO_SqrtOp], - [TF_TanhOp, HLO_TanhOp], - [TF_TanOp, CHLO_TanOp] + [TF_SinOp, MHLO_SineOp], + [TF_SqrtOp, MHLO_SqrtOp], + [TF_TanhOp, MHLO_TanhOp], + [TF_TanOp, MHLO_TanOp] ] in { - def : Pat<(Mapping[0] HLO_Tensor:$input), + def : Pat<(Mapping[0] MHLO_Tensor:$input), (Mapping[1] $input)>; } -def : Pat<(TF_AngleOp $x), (HLO_Atan2Op (HLO_ImagOp $x), (HLO_RealOp $x))>; +def : Pat<(TF_AngleOp $x), (MHLO_Atan2Op (MHLO_ImagOp $x), (MHLO_RealOp $x))>; // TODO(bixia): Lower with Truncate=True for floating point value conversions. 
-def : Pat<(TF_CastOp $arg, ConstBoolAttrFalse), (HLO_ConvertOp $arg)>; +def : Pat<(TF_CastOp $arg, ConstBoolAttrFalse), (MHLO_ConvertOp $arg)>; def : Pat<(TF_TransposeOp:$res $arg, (ConstantLikeMatcher ElementsAttr:$permutation)), - (HLO_TransposeOp $arg, (CastElementsToI64Elements $permutation))>; + (MHLO_TransposeOp $arg, (CastElementsToI64Elements $permutation))>; // Lowering these ops with static shape to mhlo.reshape foreach TfOp = [TF_ExpandDimsOp, TF_ReshapeOp, TF_SqueezeOp, ] in { - def : Pat<(TfOp:$res HLO_Tensor:$arg, $ignored), - (HLO_ReshapeOp $arg), [(AnyStaticShapeTensor $res)], + def : Pat<(TfOp:$res MHLO_Tensor:$arg, $ignored), + (MHLO_ReshapeOp $arg), [(AnyStaticShapeTensor $res)], (addBenefit 2)>; } // Lowering tf.Reshape with dynamic shape -def : Pat<(TF_ReshapeOp:$res HLO_Tensor:$arg, $shape), +def : Pat<(TF_ReshapeOp:$res MHLO_Tensor:$arg, $shape), (CHLO_DynamicReshapeOp $arg, $shape)>; // Returns NaN if x is NaN, 0 if x is 0, -1 if x < 0 and 1 if x > 0. -def : Pat<(TF_SignOp $x), (HLO_SignOp $x)>; +def : Pat<(TF_SignOp $x), (MHLO_SignOp $x)>; def BothElementTypesSameWidthIntOrFloat : Constraint; // TODO(jpienaar): Lower constant like to constant to broadcast if dynamic @@ -647,39 +650,39 @@ def : Pat<(TF_BitcastOp:$res HLO_Tensor:$arg), //===----------------------------------------------------------------------===// // TODO(b/148269299): handle random number generator seeds/states correctly. -class HLO_RngDistributionValue : - ConstantAttr; +class MHLO_RngDistributionValue : + ConstantAttr; def : Pat<(TF_RandomUniformOp:$old $shape, $seed, $seed2), - (HLO_RngOp - (HLO_ConstantOp - (NativeCodeCall<"$_builder.getFloatAttr(old.dtype(), 0.0)">)), - (HLO_ConstantOp - (NativeCodeCall<"$_builder.getFloatAttr(old.dtype(), 1.0)">)), + (MHLO_RngOp + (MHLO_ConstantOp + (NativeCodeCall<"$_builder.getFloatAttr(old.getDtype(), 0.0)">)), + (MHLO_ConstantOp + (NativeCodeCall<"$_builder.getFloatAttr(old.getDtype(), 1.0)">)), (CastValueToI64 $old, $shape), - HLO_RngDistributionValue<"UNIFORM">), + MHLO_RngDistributionValue<"UNIFORM">), [(IsShapedTensor $shape)]>; def : Pat<(TF_RandomStandardNormalOp:$old $shape, $seed, $seed2), - (HLO_RngOp - (HLO_ConstantOp - (NativeCodeCall<"$_builder.getFloatAttr(old.dtype(), 0.0)">)), - (HLO_ConstantOp - (NativeCodeCall<"$_builder.getFloatAttr(old.dtype(), 1.0)">)), + (MHLO_RngOp + (MHLO_ConstantOp + (NativeCodeCall<"$_builder.getFloatAttr(old.getDtype(), 0.0)">)), + (MHLO_ConstantOp + (NativeCodeCall<"$_builder.getFloatAttr(old.getDtype(), 1.0)">)), (CastValueToI64 $old, $shape), - HLO_RngDistributionValue<"NORMAL">), + MHLO_RngDistributionValue<"NORMAL">), [(IsShapedTensor $shape)]>; //===----------------------------------------------------------------------===// // Sigmoid grad op. //===----------------------------------------------------------------------===// - -// TODO(hinsu): Handle unranked inputs by broadcasting constant one to the -// shape of $l instead of having it as a constant. -def : Pat<(TF_SigmoidGradOp AnyStaticShapeTensor:$l, AnyStaticShapeTensor:$r), - (HLO_MulOp - (HLO_MulOp $r, $l), - (HLO_SubtractOp (HLO_ConstantOp (ConstantSplat<"1"> $l)), $l))>; +// Disc disable: use ConvertSigmoidGradOpDynamic +// // TODO(hinsu): Handle unranked inputs by broadcasting constant one to the +// // shape of $l instead of having it as a constant. 
+// def : Pat<(TF_SigmoidGradOp AnyRankedTensor:$l, AnyRankedTensor:$r), +// (MHLO_MulOp +// (MHLO_MulOp $r, $l), +// (MHLO_SubtractOp (MHLO_ConstantOp (ConstantSplat<"1"> $l)), $l))>; //===----------------------------------------------------------------------===// // Softplus op. @@ -689,22 +692,22 @@ def EpsilonValue : NativeCodeCall<"GetEpsilonValue($0.getType())">; def : Pattern<(TF_SoftplusOp AnyTensor:$features), [ - (HLO_ExpOp:$features_exp $features), + (MHLO_ExpOp:$features_exp $features), (CHLO_BroadcastAddOp:$threshold - (HLO_LogOp (HLO_ConstantOp (EpsilonValue $features))), - (HLO_ConstantOp (GetScalarOfType<2> $features)), + (MHLO_LogOp (MHLO_ConstantOp (EpsilonValue $features))), + (MHLO_ConstantOp (GetScalarOfType<2> $features)), (NullDenseIntElementsAttr) ), - (HLO_SelectOp:$output + (MHLO_SelectOp:$output (CHLO_BroadcastCompareOp $features, - (HLO_NegOp $threshold), + (MHLO_NegOp $threshold), (NullDenseIntElementsAttr), CHLO_ComparisonDirectionValue<"GT">, (CHLO_DEFAULT_COMPARISON_TYPE) ), $features, - (HLO_SelectOp + (MHLO_SelectOp (CHLO_BroadcastCompareOp $features, $threshold, @@ -713,7 +716,7 @@ def : Pattern<(TF_SoftplusOp AnyTensor:$features), (CHLO_DEFAULT_COMPARISON_TYPE) ), $features_exp, - (HLO_Log1pOp $features_exp) + (MHLO_Log1pOp $features_exp) ) ), (replaceWithValue $output) @@ -724,7 +727,7 @@ def : Pattern<(TF_SoftplusOp AnyTensor:$features), //===----------------------------------------------------------------------===// def : Pat<(TF_XlaReplicaIdOp), - (TF_CastOp (HLO_ReplicaIdOp), /*truncate=*/ConstBoolAttrFalse)>; + (TF_CastOp (MHLO_ReplicaIdOp), /*truncate=*/ConstBoolAttrFalse)>; //===----------------------------------------------------------------------===// // XlaGather op. @@ -736,7 +739,7 @@ def HasValidGatherDims : Constraint>; def : Pat<(TF_XlaGatherOp $operand, $start_indices, (ConstantLikeMatcher ElementsAttr:$slice_sizes), $dimension_numbers, $indices_are_sorted), - (HLO_GatherOp $operand, $start_indices, + (MHLO_GatherOp $operand, $start_indices, (ToGatherDimNumsAttr $dimension_numbers), (CastElementsToI64Elements $slice_sizes), $indices_are_sorted), @@ -755,7 +758,7 @@ def HasValidDotDims : Constraint>; def HasValidPrecisionConfig : Constraint>; def : Pat<(TF_XlaDotOp $lhs, $rhs, $dimension_numbers, $precision_config), - (HLO_DotGeneralOp $lhs, $rhs, + (MHLO_DotGeneralOp $lhs, $rhs, (ToDotDimNumsAttr $dimension_numbers), (ToPrecisionConfigsAttr $precision_config)), [(HasValidDotDims $dimension_numbers), (HasValidPrecisionConfig $precision_config)]>; @@ -765,7 +768,7 @@ def : Pat<(TF_XlaDotOp $lhs, $rhs, $dimension_numbers, $precision_config), //===----------------------------------------------------------------------===// def : Pat<(TF_XlaDotV2Op $lhs, $rhs, $dimension_numbers, $precision_config), - (HLO_DotGeneralOp $lhs, $rhs, + (MHLO_DotGeneralOp $lhs, $rhs, (ToDotDimNumsAttr $dimension_numbers), (ToPrecisionConfigsAttr $precision_config)), [(HasValidDotDims $dimension_numbers), (HasValidPrecisionConfig $precision_config)]>; @@ -774,9 +777,9 @@ def : Pat<(TF_XlaDotV2Op $lhs, $rhs, $dimension_numbers, $precision_config), // XlaDynamicSlice op. 
//===----------------------------------------------------------------------===// -def : Pat<(TF_XlaDynamicSliceOp:$op HLO_Tensor:$input, HLO_Tensor:$starting_indices, +def : Pat<(TF_XlaDynamicSliceOp:$op MHLO_Tensor:$input, MHLO_Tensor:$starting_indices, (ConstantLikeMatcher AnyAttr:$slice_sizes)), - (HLO_DynamicSliceOp $input, + (MHLO_DynamicSliceOp $input, (UnpackStartingIndices $op, $starting_indices), (TFSliceSizes2HLOSliceSizes $input, $starting_indices, $slice_sizes))>; @@ -785,11 +788,11 @@ def : Pat<(TF_XlaDynamicSliceOp:$op HLO_Tensor:$input, HLO_Tensor:$starting_indi //===----------------------------------------------------------------------===// def : Pat<(TF_XlaEinsumOp $lhs, $rhs, $equation), - (HLO_EinsumOp $lhs, $rhs, $equation)>; + (MHLO_EinsumOp $lhs, $rhs, $equation)>; //===----------------------------------------------------------------------===// // XlaOptimizationBarrierOp op. //===----------------------------------------------------------------------===// def : Pat<(TF_XlaOptimizationBarrierOp $args), - (HLO_OptimizationBarrierOp $args)>; + (MHLO_OptimizationBarrierOp $args)>; diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_types.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_types.cc index 0ca3614c3d5..9ec9287d811 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_types.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_types.cc @@ -141,7 +141,8 @@ class TfTypePattern : public ConversionPattern { if (newType == type) continue; tensorflow::Tensor out; - if (tensorflow::ConvertToTensor(elemsAttr, &out) != tensorflow::Status::OK()) + if (tensorflow::ConvertToTensor(elemsAttr, &out) != + tensorflow::OkStatus()) return failure(); ArrayRef data(static_cast(out.data()), out.TotalBytes()); auto newAttr = DenseElementsAttr::getFromRawBuffer(newType, data); diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc index 07a35b6bb06..78ec0ad35a3 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc @@ -32,6 +32,7 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project @@ -39,21 +40,23 @@ limitations under the License. 
#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tpu_embedding_ops_registry.h" #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h" +#include "tensorflow/compiler/mlir/xla/transforms/passes.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_expression.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/xla/stream_executor/stream_executor.h" #include "tensorflow/compiler/xla/translate/hlo_to_mhlo/mlir_hlo_builder.h" #include "tensorflow/core/common_runtime/device.h" @@ -74,6 +77,7 @@ limitations under the License. #include "tensorflow/core/public/session_options.h" #include "tensorflow/tsl/platform/env.h" #include "tensorflow/tsl/platform/status.h" +#include "tensorflow/tsl/platform/statusor.h" namespace mlir { namespace mhlo { @@ -86,197 +90,213 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { // all tf2xla kernels. // Use a pointer for the static set, so the set is not destructed upon thread // end, which would not be thread safe. 
- // clang-format off - static auto* ops = - new llvm::SmallDenseSet{ - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - // TODO(hinsu): Canonicalize QuantizeAndDequantize and - // QuantizeAndDequantizeV2 to QuantizeAndDequantizeV3 by converting - // attributes to operands. 
- TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - }; - // clang-format on + static auto* ops = [] { + llvm::SmallDenseSet* ops_set = + new llvm::SmallDenseSet{ + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + // CaseOp isn't actually supported but is enabled for testing to + // make sure ops with symbol ref attributes are filtered out. + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + // TODO(hinsu): Canonicalize QuantizeAndDequantize and + // QuantizeAndDequantizeV2 to QuantizeAndDequantizeV3 by converting + // attributes to operands. 
+ TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + }; + + // Add the ops from the TPUEmbeddingOpsRegistry. + for (auto op_type_id : + TF::TPUEmbeddingOpsRegistry::Global().GetOpsTypeIds()) { + ops_set->insert(op_type_id); + } + return ops_set; + }(); auto abstractOp = op->getRegisteredInfo(); if (!abstractOp) return false; @@ -304,6 +324,7 @@ bool IsOpAllowedTf2XlaPreferred(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -315,6 +336,7 @@ bool IsOpAllowedTf2XlaPreferred(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -329,6 +351,7 @@ bool IsOpAllowedTf2XlaPreferred(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -336,7 +359,9 @@ bool IsOpAllowedTf2XlaPreferred(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -372,12 +397,15 @@ bool IsOpAllowedTf2XlaPreferred(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -390,27 +418,17 @@ bool IsOpAllowedTf2XlaPreferred(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), + TypeID::get(), }; // clang-format on - auto abstractOp = op->getRegisteredInfo(); - if (!abstractOp) return false; - return ops->count(abstractOp->getTypeID()); -} -// LINT.ThenChange() -bool IsOpAllowedForTesting(Operation* op) { - // clang-format off - static auto* ops = - new llvm::SmallDenseSet{ - // Op used to verify handling of XlaExpression of kind constant. - TypeID::get(), - }; - // clang-format on auto abstractOp = op->getRegisteredInfo(); if (!abstractOp) return false; return ops->count(abstractOp->getTypeID()); } +// LINT.ThenChange() // List of ops that require falling back to XlaOpKernel legalizations and also // require the ability to create functions. 
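Both allowlist checks above follow the same idiom: a function-local static llvm::SmallDenseSet of mlir::TypeIDs, built exactly once by an immediately-invoked lambda (which is what allows the TPUEmbeddingOpsRegistry entries to be appended during initialization), followed by a constant-time membership test against the queried op's registered TypeID. A minimal sketch of that shape is below; IsAllowedOp, OpA, and OpB are illustrative placeholders, not names from this patch.

#include "llvm/ADT/DenseSet.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/TypeID.h"

// Sketch only: OpA and OpB stand in for concrete registered op classes.
template <typename OpA, typename OpB>
bool IsAllowedOp(mlir::Operation* op) {
  // Heap-allocated and never freed so the set is not destructed at thread or
  // process teardown; lookups stay safe for the lifetime of the process.
  static auto* allowed_ops = [] {
    auto* set = new llvm::SmallDenseSet<mlir::TypeID, 8>{
        mlir::TypeID::get<OpA>(),
        mlir::TypeID::get<OpB>(),
    };
    // Entries pulled from an external registry could be inserted here.
    return set;
  }();

  auto registered = op->getRegisteredInfo();
  if (!registered) return false;  // Unregistered ops never match.
  return allowed_ops->count(registered->getTypeID());
}

Allocating with new and never deleting is deliberate: it avoids running the destructor from an arbitrary thread at shutdown, which is the thread-safety concern the comment above the set calls out.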
@@ -547,21 +565,53 @@ LogicalResult Tf2XlaRewriter::PrepareParams() { return success(); } +// Returns true if the given type is a ranked tensor type with static or bounded +// dimensions. +bool IsBounded(Type ty) { + auto ranked_ty = ty.dyn_cast(); + if (!ranked_ty) return false; + + if (ranked_ty.hasStaticShape()) return true; + + auto encoding = + ranked_ty.getEncoding().dyn_cast_or_null(); + if (!encoding) return false; + + for (int i = 0; i < ranked_ty.getRank(); ++i) { + if (ranked_ty.isDynamicDim(i) && + encoding.getBounds()[i] == ShapedType::kDynamic) { + return false; + } + } + return true; +} + +bool HasSymbolRefAttr(Operation* op) { + for (const auto& attr : op->getAttrs()) { + Attribute attr_value = attr.getValue(); + if (attr_value.isa()) { + return true; + } else if (auto array_attr = attr_value.dyn_cast()) { + if (!array_attr.empty() && array_attr.begin()->isa()) { + return true; + } + } + } + return false; +} + LogicalResult Tf2XlaRewriter::LegalizeOp() { - // Only static shaped operands are supported in XLA builders for now. for (Type ty : op_->getOperandTypes()) { auto ranked_ty = ty.dyn_cast(); - if (!ranked_ty || !ranked_ty.hasStaticShape()) { + // Only bounded operands are supported in the XLA builders. + if (!IsBounded(ranked_ty)) { return op_->emitRemark() - << "lowering requires static shaped tensor operands"; + << "lowering requires bounded tensor operands " << ranked_ty; } } - for (const auto& attr : op_->getAttrs()) { - if (attr.getValue().isa()) { - return op_->emitRemark() - << "ops with symbol references are not supported"; - } + if (HasSymbolRefAttr(op_)) { + return op_->emitRemark() << "ops with symbol references are not supported"; } auto nodedef_or = tensorflow::ConvertTFDialectOpToNodeDef( @@ -679,11 +729,6 @@ LogicalResult Tf2XlaRewriter::LegalizeOp() { "output"); } mlir::Value value = hlo_builder_.GetValue(expr->AsXlaOp(&hlo_builder_)); - mlir::OpResult old_result = op_->getResult(i); - if (value.getType() != old_result.getType()) { - value = hlo_builder_.create(old_result.getType(), - value); - } values.push_back(value); } rewriter_.replaceOp(op_, values); @@ -723,20 +768,27 @@ tensorflow::XlaExpression Tf2XlaRewriter::GetExprForOperand(Value operand, return tensorflow::XlaExpression::XlaOp(xla_op, dtype); } -class Tf2XlaRewritePattern : public RewritePattern { +class Tf2XlaRewritePattern : public ConversionPattern { public: - explicit Tf2XlaRewritePattern(MLIRContext* ctx, + explicit Tf2XlaRewritePattern(MLIRContext* ctx, TypeConverter& converter, const std::string& device_type, - bool prefer_tf2xla, bool legalize_test_only_ops, - bool is_module_pass) - : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx), + bool prefer_tf2xla, bool is_module_pass) + : ConversionPattern(converter, MatchAnyOpTypeTag(), /*benefit=*/1, ctx), device_type_(device_type), prefer_tf2xla_(prefer_tf2xla), - legalize_test_only_ops_(legalize_test_only_ops), is_module_pass_(is_module_pass) {} - LogicalResult matchAndRewrite(Operation* op, - PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite( + Operation* op, ArrayRef operands, + ConversionPatternRewriter& rewriter) const override { + // This pattern is a conversion pattern because we want to specify a type + // converter. However, this pattern still uses the original op's operands + // while creating the ops so make sure there aren't any type changes between + // the original op operands and the operands during the conversion. 
+ for (auto&& [old_val, new_val] : llvm::zip(op->getOperands(), operands)) { + if (old_val.getType() != new_val.getType()) return failure(); + } + if (is_module_pass_) { // Module passes should only ever legalize ops that have been specifically // whitelisted for legalization within a module pass. They will never @@ -745,8 +797,7 @@ class Tf2XlaRewritePattern : public RewritePattern { return failure(); } } else if (!(IsOpAllowedTf2XlaFallback(op) || - (prefer_tf2xla_ && IsOpAllowedTf2XlaPreferred(op)) || - (legalize_test_only_ops_ && IsOpAllowedForTesting(op)))) { + (prefer_tf2xla_ && IsOpAllowedTf2XlaPreferred(op)))) { return failure(); } return Tf2XlaRewriter::RewriteOp(op, rewriter, device_type_, @@ -756,52 +807,101 @@ class Tf2XlaRewritePattern : public RewritePattern { private: std::string device_type_; bool prefer_tf2xla_; - bool legalize_test_only_ops_; bool is_module_pass_; }; -// Include declaration for LegalizeTFWithTF2XLAOptions -#define GEN_PASS_DECL_LEGALIZETFWITHTF2XLA -#define GEN_PASS_DEF_LEGALIZETFWITHTF2XLA -#include "tensorflow/compiler/mlir/xla/transforms/tf_xla_passes.h.inc" +bool ShouldRefineTypeTo(Type original_ty, Type updated_ty) { + auto updated = updated_ty.dyn_cast(); + auto original = original_ty.dyn_cast(); + + // Both types must be shaped types. + if (!original || !updated) return false; + + // Element types must match. + if (original.getElementType() != updated.getElementType()) return false; + + // If the updated type doesn't have a rank, then it can't be a more refined + // type. + if (!updated.hasRank()) return false; -class LegalizeTF : public impl::LegalizeTFWithTF2XLABase { + // If the original type doesn't have a rank, then refine as the updated type + // has a rank. + if (!original.hasRank()) return true; + + // Both types must have the same rank. + if (original.getRank() != updated.getRank()) return false; + + // Refine if the updated type is bounded. + return IsBounded(updated); +} + +// Propagates more refined type by cloning op using the new operands. This +// allows all rewrite patterns that requires refined types to work without +// requiring a rewrite to the conversion pattern. Declarative rewrite pattern +// (DRR) doesn't even support conversion patterns with TableGen. +class TypePropagator : public ConversionPattern { public: - LegalizeTF() = default; - explicit LegalizeTF(llvm::StringRef device_type, bool prefer_tf2xla) { - device_type_ = device_type.str(); - prefer_tf2xla_ = prefer_tf2xla; - } + explicit TypePropagator(MLIRContext* ctx) + : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {} + + LogicalResult matchAndRewrite( + Operation* op, ArrayRef operands, + ConversionPatternRewriter& rewriter) const override { + // This could be generalized to other ops as needs arise. We could even + // remove this restriction altogether except for the terminators that + // require function signature change and shouldn't be + if (op->getName().getDialectNamespace() != + TF::TensorFlowDialect::getDialectNamespace()) + return failure(); - LegalizeTF(const LegalizeTF&) {} + // Refining types may have implications to the attached regions or symbol + // references so do not update such ops. 
+ if (!op->getRegions().empty() || HasSymbolRefAttr(op)) return failure(); - void runOnOperation() override { - RewritePatternSet patterns(&getContext()); - patterns.add(&getContext(), device_type_, - prefer_tf2xla_, legalize_test_only_ops_, - /*is_module_pass=*/false); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) - signalPassFailure(); - } + IRMapping mapper; + bool has_type_change = false; + for (auto [original, updated] : llvm::zip(op->getOperands(), operands)) { + Type original_ty = original.getType(); + Type updated_ty = updated.getType(); + if (original_ty != updated_ty) has_type_change = true; - private: + if (!ShouldRefineTypeTo(original_ty, updated_ty)) return failure(); + mapper.map(original, updated); + } + if (!has_type_change) return failure(); + + Operation* cloned_op = rewriter.clone(*op, mapper); + rewriter.replaceOp(op, cloned_op->getResults()); + return success(); + } }; } // end namespace -void PopulateLegalizeTfWithTf2XlaPatterns(llvm::StringRef device_type, - RewritePatternSet& patterns, - MLIRContext* ctx, bool prefer_tf2xla, - bool is_module_pass) { - patterns.add(ctx, device_type.str(), prefer_tf2xla, - /*legalize_test_only_ops=*/false, - is_module_pass); +Tf2XlaTypeConverter::Tf2XlaTypeConverter() { + // Currently, we don't do any type conversions. Any TensorFlow op with a type + // that is not supported in MHLO will fail conversion. Quantized types are + // going to handled separately so we don't need to handle those. + addConversion([](Type ty) { return ty; }); + + // This materialization is helpful in cases where we have more refined types + // after conversion to mhlo compared to the original type in TF. For example, + // a TF op with result type tensor<*xf32> will have a bounded type after + // fallback legalization. + auto cast_value = [&](OpBuilder& builder, Type result_type, ValueRange inputs, + Location loc) -> Value { + return builder.create(loc, result_type, + inputs.front()); + }; + addSourceMaterialization(cast_value); } -std::unique_ptr> createLegalizeTfWithTf2XlaPass( - llvm::StringRef device_type, bool prefer_tf2xla) { - return std::make_unique(device_type, prefer_tf2xla); +void PopulateLegalizeTfWithTf2XlaPatterns( + llvm::StringRef device_type, RewritePatternSet& patterns, MLIRContext* ctx, + Tf2XlaTypeConverter& converter, bool prefer_tf2xla, bool is_module_pass) { + patterns.add(ctx); + patterns.add(ctx, converter, device_type.str(), + prefer_tf2xla, is_module_pass); } } // end namespace mhlo diff --git a/tensorflow/compiler/mlir/xla/transforms/passes.h b/tensorflow/compiler/mlir/xla/transforms/passes.h index 48717085e94..280f5cb05a6 100644 --- a/tensorflow/compiler/mlir/xla/transforms/passes.h +++ b/tensorflow/compiler/mlir/xla/transforms/passes.h @@ -24,6 +24,7 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project namespace mlir { @@ -54,20 +55,25 @@ std::unique_ptr> createLegalizeTFPass( std::unique_ptr> createLegalizeTFModulePass( StringRef tf2xla_fallback_device_type = ""); +// Legalizes from MHLO quantized ops with MHLO quant types to MHLO primitive ops +// like int ops. +std::unique_ptr> createConvertMHLOQuantToIntPass(); + /// Lowers from TF dialect to HLO dialect. When allow_partial_conversion is /// false, emits an error if there is any operation that can't be legalized. 
std::unique_ptr> createLegalizeTFNoFallbackPass( bool allow_partial_conversion = false); -/// Lowers from TF dialect to HLO dialect using tf2xla op kernels for the -/// specified device type. -std::unique_ptr> createLegalizeTfWithTf2XlaPass( - llvm::StringRef device_type = "", bool prefer_tf2xla = false); - /// Replaces types that do not exist in MHLO with equivalent types that do /// exist. std::unique_ptr> CreateLegalizeTfTypesPass(); +/// Converter to be used along with the fallback Tf2Xla patterns below. +class Tf2XlaTypeConverter : public TypeConverter { + public: + Tf2XlaTypeConverter(); +}; + /// Adds the TF to XLA via TF2XLA rewrite patterns to the pattern list. /// `prefer_tf2xla` means an op will be included iff it is not in /// `MlirLegalizedUnderPreferTf2XlaSet`. `!prefer_tf2xla` mean an op will be @@ -75,6 +81,7 @@ std::unique_ptr> CreateLegalizeTfTypesPass(); void PopulateLegalizeTfWithTf2XlaPatterns(llvm::StringRef device_type, RewritePatternSet& patterns, MLIRContext* ctx, + Tf2XlaTypeConverter& converter, bool prefer_tf2xla = false, bool is_module_pass = false); @@ -93,9 +100,6 @@ void PopulateLegalizeTfQuantizationPatterns(MLIRContext* context, /// Checks whether the op is supported by the Tf2Xla fallback for legalization. bool HasTf2XlaFallback(Operation* op); -/// Lowers from TF dialect's control flow to HLO dialect's control flow. -std::unique_ptr> createLegalizeTFControlFlowPass(); - /// Converts the provided Operation as well as all nested operations into HLO /// dialect using the conversion patterns registered by the HLO dialect. When /// allow_partial_conversion is false, emits an error if there is any operation @@ -118,13 +122,18 @@ std::unique_ptr> CreateLegalizeTFCommunicationPass(); // ops. std::unique_ptr> CreateLegalizeTFCollectivePass(); +// Verifies that the TF/XLA ops have all been lowered to MHLO. +std::unique_ptr> CreateVerifyTFXLALegalizationPass( + bool legalize_chlo = true); + #define GEN_PASS_REGISTRATION #define GEN_PASS_DECL_LEGALIZETF #define GEN_PASS_DECL_LEGALIZETFCOLLECTIVE -#define GEN_PASS_DECL_LEGALIZETFCONTROLFLOW #define GEN_PASS_DECL_LEGALIZETFMODULEPASS #define GEN_PASS_DECL_LEGALIZETFNOFALLBACK #define GEN_PASS_DECL_LEGALIZETFTYPESPASS +#define GEN_PASS_DECL_VERIFYTFXLALEGALIZATION +#define GEN_PASS_DECL_CONVERTMHLOQUANTTOINT #include "tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf_passes.h.inc" #define GEN_PASS_REGISTRATION diff --git a/tensorflow/compiler/mlir/xla/transforms/tf_xla_passes.td b/tensorflow/compiler/mlir/xla/transforms/tf_xla_passes.td index f89db021381..d42f91118f1 100644 --- a/tensorflow/compiler/mlir/xla/transforms/tf_xla_passes.td +++ b/tensorflow/compiler/mlir/xla/transforms/tf_xla_passes.td @@ -15,25 +15,6 @@ limitations under the License. include "mlir/Pass/PassBase.td" -def LegalizeTFWithTF2XLA - : Pass<"xla-legalize-tf-with-tf2xla", "mlir::func::FuncOp"> { - let summary = "Legalize from TensorFlow to the HLO dialect using tf2xla kernels"; - let dependentDialects = ["mhlo::MhloDialect", "sparse_tensor::SparseTensorDialect"]; - let constructor = "::mlir::mhlo::createLegalizeTfWithTf2XlaPass()"; - let options = [ - // TODO(hinsu): Support finer grained device type assignment instead of a - // global device type for all TensorFlow ops. 
- Option<"device_type_", "device-type", "std::string", "", - "XLA device type for execution of TensorFlow ops.">, - Option<"prefer_tf2xla_", "prefer-tf2xla", "bool", "", - "Enable legalization when it is not in the list of " - "MLIR-legalized ops.">, - Option<"legalize_test_only_ops_", "legalize-test-only-ops", "bool", "", - "Enable tf2xla legalizations for some ops that are " - "enabled only for testing."> - ]; -} - def LegalizeTFCommunicationPass : Pass<"xla-legalize-tf-communication", "ModuleOp"> { let summary = "Legalize TF/XLA communication ops (TensorFlow dialect) to the HLO " "dialect"; diff --git a/tensorflow/compiler/mlir/xla/transforms/utils.cc b/tensorflow/compiler/mlir/xla/transforms/utils.cc index 974e7236185..88243c6435f 100644 --- a/tensorflow/compiler/mlir/xla/transforms/utils.cc +++ b/tensorflow/compiler/mlir/xla/transforms/utils.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/xla/transforms/utils.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/utils/hlo_utils.h" +#include "tensorflow/compiler/xla/mlir_hlo/utils/hlo_utils.h" namespace mlir { namespace mhlo { diff --git a/tensorflow/compiler/mlir/xla/transforms/utils.h b/tensorflow/compiler/mlir/xla/transforms/utils.h index a5aa6a1418d..f1f25c842b3 100644 --- a/tensorflow/compiler/mlir/xla/transforms/utils.h +++ b/tensorflow/compiler/mlir/xla/transforms/utils.h @@ -22,7 +22,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" namespace mlir { namespace mhlo { diff --git a/tensorflow/compiler/mlir/xla/transforms/verify_tfxla_legalization.cc b/tensorflow/compiler/mlir/xla/transforms/verify_tfxla_legalization.cc new file mode 100644 index 00000000000..6501b539efc --- /dev/null +++ b/tensorflow/compiler/mlir/xla/transforms/verify_tfxla_legalization.cc @@ -0,0 +1,76 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include "mlir/IR/BuiltinOps.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/IR/Visitors.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
+#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
+#include "tensorflow/compiler/mlir/xla/transforms/xla_legalize_targets.h"
+
+namespace mlir {
+namespace mhlo {
+
+namespace {
+
+#define GEN_PASS_DEF_VERIFYTFXLALEGALIZATION
+#include "tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf_passes.h.inc"
+
+class VerifyTFXLALegalization
+    : public impl::VerifyTFXLALegalizationBase {
+ public:
+  explicit VerifyTFXLALegalization(bool legalize_chlo) {
+    legalize_chlo_ = legalize_chlo;
+  }
+
+  void runOnOperation() override;
+};
+
+void VerifyTFXLALegalization::runOnOperation() {
+  Operation* func_op = getOperation();
+  ConversionTarget default_conversion_target =
+      GetDefaultLegalConversionTargets(getContext(), legalize_chlo_);
+
+  auto walk_result = func_op->walk([&](Operation* op) {
+    if (default_conversion_target.isLegal(op)) {
+      return WalkResult::advance();
+    }
+
+    emitError(op->getLoc()) << "Could not legalize op: " << op->getName();
+
+    return WalkResult::interrupt();
+  });
+
+  if (walk_result.wasInterrupted()) signalPassFailure();
+}
+
+}  // namespace
+
+std::unique_ptr>
+CreateVerifyTFXLALegalizationPass(bool legalize_chlo) {
+  return std::make_unique(legalize_chlo);
+}
+
+}  // namespace mhlo
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_targets.cc b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_targets.cc
new file mode 100644
index 00000000000..c39026eb7e0
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_targets.cc
@@ -0,0 +1,56 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/xla/transforms/xla_legalize_targets.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "stablehlo/dialect/ChloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir { +namespace mhlo { + +ConversionTarget GetDefaultLegalConversionTargets(MLIRContext& mlir_context, + bool legalize_chlo) { + ConversionTarget target(mlir_context); + + if (legalize_chlo) { + target.addIllegalDialect(); + } else { + target.addLegalDialect(); + } + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalOp(); + + // These ops are legalized in LegalizeTFCommunication after this and that pass + // only operates on MHLO control flow ops. + target.addLegalOp(); + + return target; +} + +} // namespace mhlo +} // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_targets.h b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_targets.h new file mode 100644 index 00000000000..62483edca80 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_targets.h @@ -0,0 +1,34 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_XLA_LEGALIZE_TARGETS_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_XLA_LEGALIZE_TARGETS_H_ + +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace mlir { +namespace mhlo { + +// Returns a ConversionTarget that includes default legalized MLIR dialects +// for conversion to XLA. +// If legalize_chlo is true, the resulting conversion target cannot have CHLO. +mlir::ConversionTarget GetDefaultLegalConversionTargets( + MLIRContext& mlir_context, bool legalize_chlo); + +} // namespace mhlo +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_XLA_LEGALIZE_TARGETS_H_ diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_targets_test.cc b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_targets_test.cc new file mode 100644 index 00000000000..846e6358b6f --- /dev/null +++ b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_targets_test.cc @@ -0,0 +1,96 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/xla/transforms/xla_legalize_targets.h" + +#include +#include +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "stablehlo/dialect/ChloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace mhlo { +namespace { + +mlir::DialectRegistry GetDefaultDialectRegistry() { + mlir::DialectRegistry registry; + + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + + return registry; +} + +class XlaLegalizeTargetsTest : public testing::Test { + public: + XlaLegalizeTargetsTest() + : context_(GetDefaultDialectRegistry()), + module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(&context_))), + builder_(&module_->getBodyRegion()) { + context_.loadAllAvailableDialects(); + } + + protected: + mlir::MLIRContext context_; + mlir::OwningOpRef module_; + mlir::OpBuilder builder_; +}; + +TEST_F(XlaLegalizeTargetsTest, CreatesConversionTargets) { + auto const_int = builder_.create( + builder_.getUnknownLoc(), /*value=*/10, builder_.getI32Type()); + + ConversionTarget target = + GetDefaultLegalConversionTargets(context_, /*legalize_chlo=*/false); + EXPECT_TRUE(target.isLegal(const_int)); +} + +TEST_F(XlaLegalizeTargetsTest, AllowsCHLODialect) { + auto const_int = builder_.create( + builder_.getUnknownLoc(), builder_.getI32TensorAttr({42})); + + ConversionTarget target = + GetDefaultLegalConversionTargets(context_, /*legalize_chlo=*/true); + + EXPECT_TRUE(target.isIllegal(const_int)); +} + +TEST_F(XlaLegalizeTargetsTest, DontAllowCHLODialect) { + auto const_int = builder_.create( + builder_.getUnknownLoc(), builder_.getI32TensorAttr({42})); + + ConversionTarget target = + GetDefaultLegalConversionTargets(context_, /*legalize_chlo=*/false); + EXPECT_TRUE(target.isLegal(const_int)); +} + +} // namespace +} // namespace mhlo +} // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf.cc index e3d18b37f17..c8b1fa65968 100644 --- a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf.cc @@ -13,17 +13,23 @@ See the License for the specific language governing permissions 
and limitations under the License. ==============================================================================*/ +#include +#include #include #include +#include #include #include #include +#include "absl/strings/string_view.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -31,19 +37,31 @@ limitations under the License. #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" // from @llvm-project #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "stablehlo/dialect/ChloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "tensorflow/compiler/mlir/xla/transforms/xla_legalize_targets.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/rewriters.h" +#include "tensorflow/compiler/xla/translate/hlo_to_mhlo/attribute_importer.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/util/quantization/uniform_quant_ops_attr.pb.h" +#include "tensorflow/core/util/quantization/uniform_quant_ops_params.h" namespace mlir { namespace mhlo { @@ -62,7 +80,7 @@ class LegalizeTF : public impl::LegalizeTFBase { prefer_tf2xla_ = prefer_tf2xla; use_tf2xla_fallback_ = tf2xla_fallback_device_type.has_value(); if (tf2xla_fallback_device_type.has_value()) { - device_type_ = tf2xla_fallback_device_type.getValue().str(); + device_type_ = tf2xla_fallback_device_type.value().str(); } } /// Performs the lowering to XLA dialect. 
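The one-line change in the hunk above tracks LLVM's Optional converging on the std::optional interface (getValue() giving way to value()); the constructor otherwise just derives a boolean fallback flag and a device-type string from a single optional argument. A condensed, illustrative sketch of that pattern follows; FallbackOptions and its fields are hypothetical names, not the pass's real members.

#include <optional>
#include <string>

#include "llvm/ADT/StringRef.h"

// Illustrative only: one optional device type populates both a flag and a
// string option, mirroring the pass constructor in the hunk above.
struct FallbackOptions {
  std::string device_type;
  bool use_tf2xla_fallback = false;

  explicit FallbackOptions(std::optional<llvm::StringRef> fallback_device) {
    use_tf2xla_fallback = fallback_device.has_value();
    if (fallback_device.has_value()) {
      // value() requires an engaged optional, hence the has_value() guard.
      device_type = fallback_device.value().str();
    }
  }
};

The same shape works with llvm::Optional, since both types expose has_value() and value(); that equivalence is what makes the mechanical getValue() to value() rename safe here.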
@@ -83,6 +101,19 @@ class LegalizeTFModulePass void runOnOperation() override; }; +FailureOr GetStorageType(Operation *op, + Type original_output_element_type, + PatternRewriter &rewriter) { + if (original_output_element_type.isa()) { + return rewriter.getIntegerType(8); + } else if (original_output_element_type.isa()) { + return rewriter.getIntegerType(32); + } else { + return rewriter.notifyMatchFailure( + op, "Quantized type must be qint8 or qint32."); + } +} + TensorType GetSameShapeTensorType(TensorType tensor_type, Type element_type) { if (auto ranked_tensor_ty = tensor_type.dyn_cast_or_null()) { @@ -95,6 +126,180 @@ TensorType GetSameShapeTensorType(TensorType tensor_type, Type element_type) { llvm_unreachable("unhandled type"); } +template +FailureOr GetUniformQuantizedType( + UniformQuantizedOp op, Type original_type, + TypedValue scales_value, + TypedValue zero_points_value, FloatType expressed_type, + int64_t storage_type_min, int64_t storage_type_max, + int64_t quantized_dimension, PatternRewriter &rewriter) { + // Check whether the scales operand has constant op. + DenseFPElementsAttr scales; + if (!matchPattern(scales_value, m_Constant(&scales))) { + return rewriter.notifyMatchFailure(op, "scales must be constant"); + } + + // Check whether the zero_points operand has constant op. + DenseIntElementsAttr zero_points; + if (!matchPattern(zero_points_value, m_Constant(&zero_points))) { + return rewriter.notifyMatchFailure(op, "zero_points must be constant"); + } + + auto storage_type_or = + GetStorageType(op, getElementTypeOrSelf(original_type), rewriter); + if (failed(storage_type_or)) { + return failure(); + } + + const unsigned flags = quant::QuantizationFlags::Signed; + Type elem_ty; + if (quantized_dimension == -1) { + elem_ty = quant::UniformQuantizedType::get( + flags, *storage_type_or, expressed_type, scales.getValues()[0], + zero_points.getValues()[0], storage_type_min, + storage_type_max); + } else { + SmallVector scales_vec; + SmallVector zero_points_vec; + for (auto elem : scales.getValues()) scales_vec.push_back(elem); + for (auto elem : zero_points.getValues()) + zero_points_vec.push_back(elem); + elem_ty = quant::UniformQuantizedPerAxisType::get( + flags, *storage_type_or, expressed_type, scales_vec, zero_points_vec, + quantized_dimension, storage_type_min, storage_type_max); + } + + return GetSameShapeTensorType(original_type.cast(), elem_ty); +} + +template +FailureOr CreateConstantOpForQint8Rhs( + UniformQuantizedOp op, TensorType new_rhs_type, PatternRewriter &rewriter) { + // Check whether the rhs operand has constant op. + TF::TensorProtoAttr tensor_proto_attr; + if (!matchPattern(op.getRhs(), m_Constant(&tensor_proto_attr))) { + return rewriter.notifyMatchFailure(op, "rhs must be constant."); + } + + llvm::StringRef mangled_tensor = tensor_proto_attr.getValue(); + absl::string_view tensor_view(mangled_tensor.data(), mangled_tensor.size()); + // TODO(hinsu): Instead of getting the weight from TensorProto, use MLIR + // constant attribute to avoid depending on the Tensor proto. 
+ tensorflow::TensorProto tensor_proto; + tensorflow::Status status = + tensorflow::mangling_util::DemangleTensor(tensor_view, &tensor_proto); + if (!status.ok()) { + return rewriter.notifyMatchFailure(op, status.error_message()); + } + + tensorflow::Tensor t; + if (!t.FromProto(tensor_proto)) { + return op.emitError("Failed to convert tensor proto to Tensor."); + } + + auto arr = t.flat(); + auto dense_attr = mlir::DenseElementsAttr::get( + GetSameShapeTensorType(new_rhs_type, rewriter.getIntegerType(8)), + llvm::ArrayRef(arr.data(), arr.size())); + return rewriter.create(op.getLoc(), new_rhs_type, + dense_attr); +} + +xla::ConvolutionDimensionNumbers ConvertConvolutionDimensionNumbers( + const tensorflow::UniformQuantizedConvolutionDimensionNumbersAttr + &dnums_input) { + xla::ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(dnums_input.input_batch_dimension()); + dnums.set_input_feature_dimension(dnums_input.input_feature_dimension()); + for (auto value : dnums_input.input_spatial_dimensions()) { + dnums.add_input_spatial_dimensions(value); + } + dnums.set_kernel_input_feature_dimension( + dnums_input.kernel_input_feature_dimension()); + dnums.set_kernel_output_feature_dimension( + dnums_input.kernel_output_feature_dimension()); + for (auto value : dnums_input.kernel_spatial_dimensions()) { + dnums.add_kernel_spatial_dimensions(value); + } + dnums.set_output_batch_dimension(dnums_input.output_batch_dimension()); + dnums.set_output_feature_dimension(dnums_input.output_feature_dimension()); + for (auto value : dnums_input.output_spatial_dimensions()) { + dnums.add_output_spatial_dimensions(value); + } + return dnums; +} + +DenseIntElementsAttr ConvertToDenseElementsAttr(ArrayAttr array_attr, + PatternRewriter &rewriter) { + SmallVector array; + array.reserve(array_attr.size()); + for (auto elem : array_attr.getAsRange()) { + array.push_back(elem.getInt()); + } + return DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(array_attr.size())}, + rewriter.getIntegerType(64)), + array); +} + +FailureOr ConvertPaddingAttr( + TF::UniformQuantizedConvolutionHybridOp op, + const xla::ConvolutionDimensionNumbers &dnums, PatternRewriter &rewriter) { + StringAttr conv_padding = op.getPaddingAttr(); + SmallVector padding_nums; + ShapedType lhs_shape = op.getLhs().getType().cast(); + ShapedType rhs_shape = op.getRhs().getType().cast(); + + // Handle only static shape cases. + // TODO(b/260284866): Handle dynamic shape cases. 
+ if (!lhs_shape.hasStaticShape()) { + return op.emitError("lhs must have static shape."); + } + if (!rhs_shape.hasStaticShape()) { + return op.emitError("rhs must have static shape."); + } + + const int64_t padding_nums_size = 2 * (rhs_shape.getRank() - 2); + padding_nums.reserve(padding_nums_size); + if (conv_padding.strref().equals("EXPLICIT")) { + for (auto padding_elem : + op.getExplicitPaddingAttr().getAsRange()) { + padding_nums.push_back(padding_elem.getInt()); + } + } else if (conv_padding.strref().equals("VALID")) { + padding_nums.resize(padding_nums_size, 0); + } else { + padding_nums.resize(padding_nums_size); + for (int i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + const int64_t stride = + op.getWindowStridesAttr()[i].cast().getInt(); + const int64_t lhs_size_dilated = + tensorflow::UniformQuantizedConvolutionParams::DilatedSize( + lhs_shape.getDimSize(dnums.input_spatial_dimensions(i)), + op.getLhsDilationAttr()[i].cast().getInt()); + const int64_t rhs_size_dilated = + tensorflow::UniformQuantizedConvolutionParams::DilatedSize( + rhs_shape.getDimSize(dnums.kernel_spatial_dimensions(i)), + op.getRhsDilationAttr()[i].cast().getInt()); + + const int64_t output_size = (lhs_size_dilated + stride - 1) / stride; + const int64_t total_padding = std::max( + (output_size - 1) * stride + rhs_size_dilated - lhs_size_dilated, + static_cast(0)); + const int64_t padding_end = total_padding / 2; + const int64_t padding_begin = total_padding - padding_end; + padding_nums[2 * i] = padding_begin; + padding_nums[2 * i + 1] = padding_end; + } + } + + ElementsAttr padding_attr = DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(padding_nums.size() / 2), 2}, + rewriter.getIntegerType(64)), + padding_nums); + return padding_attr; +} + // TODO(hinsu): Move this pattern to legalize_tf after resolving the dependency // on the tensor proto. class ConvertUniformQuantizedDotHybridOp @@ -104,73 +309,204 @@ class ConvertUniformQuantizedDotHybridOp LogicalResult matchAndRewrite(TF::UniformQuantizedDotHybridOp op, PatternRewriter &rewriter) const override { - // Check whether the rhs operand has constant op. - TF::TensorProtoAttr tensor_proto_attr; - if (!matchPattern(op.rhs(), m_Constant(&tensor_proto_attr))) + // Uniform Quantized type for the rhs. + int64_t rhs_quantized_dimension = op.getRhsQuantizationAxis(); + // Currently for dot, PTQ supports per-tensor quantization. + if (rhs_quantized_dimension != -1) { + return rewriter.notifyMatchFailure( + op, "Legalization supports only rhs_quantization_axis -1."); + } + auto rhs_type = GetUniformQuantizedType( + op, op.getRhs().getType(), op.getRhsScales(), op.getRhsZeroPoints(), + /*expressed_type=*/rewriter.getF32Type(), op.getRhsQuantizationMinVal(), + op.getRhsQuantizationMaxVal(), rhs_quantized_dimension, rewriter); + if (failed(rhs_type)) { return failure(); + } - // Check whether the rhs_scales operand has constant op. - DenseFPElementsAttr rhs_scales; - if (!matchPattern(op.rhs_scales(), m_Constant(&rhs_scales))) + auto rhs = CreateConstantOpForQint8Rhs(op, *rhs_type, rewriter); + if (failed(rhs)) { return failure(); + } - // Check whether the rhs_zero_points operand has constant op. - DenseIntElementsAttr rhs_zero_points; - if (!matchPattern(op.rhs_zero_points(), m_Constant(&rhs_zero_points))) - return failure(); + rewriter.replaceOpWithNewOp(op, op.getType(), op.getLhs(), + *rhs, + /*precision_config=*/nullptr); + return success(); + } +}; - // Invalid quantization parameter. 
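[Editor's reference, not part of the patch] For the non-EXPLICIT, non-VALID branch of the ConvertPaddingAttr hunk above, each spatial dimension gets ceil-mode output sizing and the total padding is split so the end takes the smaller half. A small Python sketch mirroring that per-dimension arithmetic (the DilatedSize formula (size - 1) * dilation + 1 is assumed):

def same_padding_1d(in_size, filter_size, stride, lhs_dilation=1, rhs_dilation=1):
    # Mirrors the per-dimension padding computed in ConvertPaddingAttr.
    dilated = lambda size, d: 0 if size == 0 else (size - 1) * d + 1  # assumed DilatedSize
    in_dilated = dilated(in_size, lhs_dilation)
    filter_dilated = dilated(filter_size, rhs_dilation)
    out_size = (in_dilated + stride - 1) // stride                    # ceil division
    total = max((out_size - 1) * stride + filter_dilated - in_dilated, 0)
    pad_end = total // 2
    pad_begin = total - pad_end
    return pad_begin, pad_end

# e.g. same_padding_1d(in_size=5, filter_size=3, stride=2) -> (1, 1)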
- if (rhs_scales.empty()) return failure(); - if (rhs_scales.size() != rhs_zero_points.size()) return failure(); +class ConvertUniformQuantizedConvolutionHybridOp + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TF::UniformQuantizedConvolutionHybridOp op, + PatternRewriter &rewriter) const override { // Uniform Quantized type for the rhs. - IntegerType storage_type = rewriter.getIntegerType(8); - FloatType expressed_type = rewriter.getF32Type(); - int64_t storage_type_min = op.rhs_quantization_min_val(); - int64_t storage_type_max = op.rhs_quantization_max_val(); - int32_t quantized_dimension = op.rhs_quantization_axis(); - const unsigned flags = mlir::quant::QuantizationFlags::Signed; - - // Currently, PTQ supports per-tensor quantization, for now. - if (quantized_dimension != -1) return failure(); - - Type rhs_elem_ty; - rhs_elem_ty = quant::UniformQuantizedType::get( - flags, storage_type, expressed_type, rhs_scales.getValues()[0], - rhs_zero_points.getValues()[0], storage_type_min, - storage_type_max); + auto rhs_type = GetUniformQuantizedType( + op, op.getRhs().getType(), op.getRhsScales(), op.getRhsZeroPoints(), + /*expressed_type=*/rewriter.getF32Type(), op.getRhsQuantizationMinVal(), + op.getRhsQuantizationMaxVal(), op.getRhsQuantizationAxis(), rewriter); + if (failed(rhs_type)) { + return failure(); + } - Type rhs_type = GetSameShapeTensorType( - op.rhs().getType().cast(), rhs_elem_ty); - - llvm::StringRef mangled_tensor = tensor_proto_attr.getValue(); - absl::string_view tensor_view(mangled_tensor.data(), mangled_tensor.size()); - // TODO(hinsu): Instead of getting the weight from TensorProto, use MLIR - // constant attribute to avoid depending on the Tensor proto. - tensorflow::TensorProto tensor_proto; - tensorflow::Status status = - tensorflow::mangling_util::DemangleTensor(tensor_view, &tensor_proto); - if (!status.ok()) { + auto rhs = CreateConstantOpForQint8Rhs(op, *rhs_type, rewriter); + if (failed(rhs)) { return failure(); } - tensorflow::Tensor t; - if (!t.FromProto(tensor_proto)) { + // TODO(b/261005147): Update the lowering logic after migration to mhlo + // ConvolutionDimensionNumbers. 
+ tensorflow::UniformQuantizedConvolutionDimensionNumbersAttr dnums_input; + if (!dnums_input.ParseFromString(std::string(op.getDimensionNumbers()))) { + return op->emitError("Parse dimension_numbers failed."); + } + xla::ConvolutionDimensionNumbers dnums = + ConvertConvolutionDimensionNumbers(dnums_input); + + SmallVector converted_attrs; + for (auto attr : op->getAttrs()) { + if (attr.getName() == op.getFeatureGroupCountAttrName() || + attr.getName() == op.getBatchGroupCountAttrName()) { + converted_attrs.push_back(attr); + } else if (attr.getName() == op.getDimensionNumbersAttrName()) { + attr.setValue(xla::ConvertConvDimensionNumbers(dnums, &rewriter)); + converted_attrs.push_back(attr); + } else if (attr.getName() == op.getPaddingAttrName()) { + auto value_or = ConvertPaddingAttr(op, dnums, rewriter); + if (failed(value_or)) { + return failure(); + } + attr.setValue(*value_or); + converted_attrs.push_back(attr); + } else if (attr.getName() == op.getWindowStridesAttrName() || + attr.getName() == op.getLhsDilationAttrName() || + attr.getName() == op.getRhsDilationAttrName()) { + attr.setValue(ConvertToDenseElementsAttr( + attr.getValue().cast(), rewriter)); + converted_attrs.push_back(attr); + } + } + + SmallVector operands{op.getLhs(), *rhs}; + rewriter.replaceOpWithNewOp(op, op.getType(), operands, + converted_attrs); + return success(); + } +}; + +class ConvertUniformQuantizeOp + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::UniformQuantizeOp op, + PatternRewriter &rewriter) const override { + auto output_type = GetUniformQuantizedType( + op, op.getOutput().getType(), op.getScales(), op.getZeroPoints(), + /*expressed_type=*/rewriter.getF32Type(), op.getQuantizationMinVal(), + op.getQuantizationMaxVal(), op.getQuantizationAxis(), rewriter); + if (failed(output_type)) { return failure(); } - auto arr = t.flat(); - auto dense_attr = ElementsAttr(mlir::DenseElementsAttr::get( - GetSameShapeTensorType(rhs_type.cast(), storage_type), - llvm::makeArrayRef(arr.data(), arr.size()))); + rewriter.replaceOpWithNewOp(op, *output_type, + op.getInput()); + return success(); + } +}; + +// UniformDequantizeOp takes TF quantized types as input which would have been +// converted to the mhlo quantized types. Use OpConversionPattern in order to +// retrieve the operand type *after* conversion, using OpAdaptor operand +// accessor. +// Same for other Uniform Quant Ops that take TF quantized types as input. 
+class ConvertUniformDequantizeOp + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + TF::UniformDequantizeOp op, TF::UniformDequantizeOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value input = adaptor.getInput(); + + rewriter.replaceOpWithNewOp( + op, op.getOutput().getType(), input); + return success(); + } +}; - Value lhs = op.lhs(); - rewriter.setInsertionPointAfterValue(op.rhs()); - Value rhs = rewriter.create(rewriter.getUnknownLoc(), - rhs_type, dense_attr); +class ConvertUniformRequantizeOp + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + TF::UniformRequantizeOp op, TF::UniformRequantizeOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value input = adaptor.getInput(); + + auto output_type = GetUniformQuantizedType( + op, op.getOutput().getType(), op.getOutputScales(), + op.getOutputZeroPoints(), + /*expressed_type=*/rewriter.getF32Type(), + op.getOutputQuantizationMinVal(), op.getOutputQuantizationMaxVal(), + op.getOutputQuantizationAxis(), rewriter); + if (failed(output_type)) { + return failure(); + } - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp(op, lhs, rhs, + rewriter.replaceOpWithNewOp(op, *output_type, + input); + return success(); + } +}; + +class ConvertUniformQuantizedDotOp + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + TF::UniformQuantizedDotOp op, TF::UniformQuantizedDotOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value lhs = adaptor.getLhs(); + + // Uniform Quantized type for the rhs. + int64_t rhs_quantized_dimension = op.getRhsQuantizationAxis(); + // Currently for dot, PTQ supports per-tensor quantization. + if (rhs_quantized_dimension != -1) { + return rewriter.notifyMatchFailure( + op, "Legalization supports only rhs_quantization_axis -1."); + } + auto rhs_type = GetUniformQuantizedType( + op, adaptor.getRhs().getType(), op.getRhsScales(), + op.getRhsZeroPoints(), + /*expressed_type=*/rewriter.getF32Type(), op.getRhsQuantizationMinVal(), + op.getRhsQuantizationMaxVal(), rhs_quantized_dimension, rewriter); + if (failed(rhs_type)) { + return failure(); + } + + auto rhs_or = CreateConstantOpForQint8Rhs(op, *rhs_type, rewriter); + if (failed(rhs_or)) { + return failure(); + } + + auto output_type = GetUniformQuantizedType( + op, op.getOutput().getType(), op.getOutputScales(), + op.getOutputZeroPoints(), + /*expressed_type=*/rewriter.getF32Type(), + op.getOutputQuantizationMinVal(), op.getOutputQuantizationMaxVal(), + op.getOutputQuantizationAxis(), rewriter); + if (failed(output_type)) { + return failure(); + } + + rewriter.replaceOpWithNewOp(op, *output_type, lhs, *rhs_or, /*precision_config=*/nullptr); return success(); } @@ -224,6 +560,7 @@ void EmitLegalizationErrors(Operation *op, /// Returns ops that should use MLIR legalization only in the case of /// prefer_tf2xla. All other ops not in this list should use XlaOpKernel /// legalization only or not be legalized by the new bridge. +// LINT.IfChange const llvm::DenseSet &MlirPreferredOps() { // The static variable is a pointer in order to avoid destruction upon thread // termination. 
@@ -238,25 +575,21 @@ const llvm::DenseSet &MlirPreferredOps() { TypeID::get(), TypeID::get(), TypeID::get(), - TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), - TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), - TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), - TypeID::get(), TypeID::get(), TypeID::get(), - TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -273,10 +606,7 @@ const llvm::DenseSet &MlirPreferredOps() { TypeID::get(), TypeID::get(), TypeID::get(), - TypeID::get(), - TypeID::get(), TypeID::get(), - TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -289,8 +619,6 @@ const llvm::DenseSet &MlirPreferredOps() { TypeID::get(), TypeID::get(), TypeID::get(), - TypeID::get(), - TypeID::get(), // Ops that have no XlaOpKernel. TypeID::get(), @@ -320,10 +648,17 @@ const llvm::DenseSet &MlirPreferredOps() { TypeID::get(), TypeID::get(), TypeID::get(), + + // Conditional ops + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), }; // clang-format on return *ops; } +// LINT.ThenChange(:PopulateLegalizeTfPatterns) // Patterns whose root op is in the set `include_ops` are moved from the set // `from` to the returned set. This is used to partition patterns by op so they @@ -333,7 +668,7 @@ RewritePatternSet PatternsIncludeOps( RewritePatternSet to(from.getContext()); // Filter NativePatterns. for (auto &pattern : from.getNativePatterns()) { - Optional pat_op_name = pattern->getRootKind(); + std::optional pat_op_name = pattern->getRootKind(); // If the pattern does not have a specific operation, always include it, // If the pattern is in include_ops then include it. bool include = @@ -351,18 +686,8 @@ RewritePatternSet PatternsIncludeOps( mlir::LogicalResult ApplyPatterns(Operation *op, RewritePatternSet &patterns, bool legalize_chlo, bool allow_partial_conversion) { - ConversionTarget target(*op->getContext()); - if (legalize_chlo) { - target.addIllegalDialect(); - } else { - target.addLegalDialect(); - } - target.addLegalDialect(); - target.addLegalDialect(); - target.addLegalDialect(); - target.addLegalDialect(); - target.addLegalDialect(); - target.addLegalOp(); + ConversionTarget target = + GetDefaultLegalConversionTargets(*op->getContext(), legalize_chlo); if (!allow_partial_conversion) { // Fully qualify ReturnOp here as mhlo dialect also defines a ReturnOp. @@ -427,10 +752,12 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion, ? PatternsIncludeOps(legalize_lower_patterns, MlirPreferredOps()) : std::move(legalize_lower_patterns); + Tf2XlaTypeConverter converter; if (tf2xla_fallback_device_type) { // Add TF->HLO legalization patterns via TF2XLA fallback. 
- PopulateLegalizeTfWithTf2XlaPatterns(tf2xla_fallback_device_type.getValue(), - patterns, context, prefer_tf2xla); + PopulateLegalizeTfWithTf2XlaPatterns(tf2xla_fallback_device_type.value(), + patterns, context, converter, + prefer_tf2xla); } // Populate with CHLO->HLO lowerings to account for TF ops legalized to @@ -498,8 +825,9 @@ void LegalizeTFModulePass::runOnOperation() { Operation *op = getOperation(); MLIRContext *context = op->getContext(); RewritePatternSet patterns(context); + Tf2XlaTypeConverter converter; PopulateLegalizeTfWithTf2XlaPatterns(device_type_, patterns, context, - /*prefer_tf2xla=*/false, + converter, /*prefer_tf2xla=*/false, /*is_module_pass=*/true); if (failed(ApplyPatterns(op, patterns, @@ -513,7 +841,11 @@ void LegalizeTFModulePass::runOnOperation() { void PopulateLegalizeTfQuantizationPatterns(MLIRContext *context, RewritePatternSet *patterns) { - patterns->add(context); + patterns->add( + context); } std::unique_ptr> createLegalizeTFPass( diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf_no_fallback.cc b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf_no_fallback.cc index ffc9e7fc7d4..f2159a3ec4a 100644 --- a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf_no_fallback.cc +++ b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf_no_fallback.cc @@ -29,7 +29,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" namespace mlir { namespace mhlo { diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf_passes.td b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf_passes.td index a84207c1dc0..3e72481ede8 100644 --- a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf_passes.td +++ b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf_passes.td @@ -76,6 +76,18 @@ def LegalizeTFModulePass : Pass<"xla-fallback-legalize-tf-module-pass", "ModuleO "shape::ShapeDialect", "func::FuncDialect", "sparse_tensor::SparseTensorDialect"]; } +def ConvertMHLOQuantToInt : Pass<"convert-mhlo-quant-to-int", "mlir::func::FuncOp"> { + let summary = "Convert from MHLO quantized ops to MHLO primitive ops."; + + let description = [{ + Convert from MHLO quantized ops with MHLO quant types to MHLO primitive ops + like int ops. 
+ }]; + + let constructor = "mlir::mhlo::createConvertMHLOQuantToIntPass()"; + let dependentDialects = ["chlo::ChloDialect", "mhlo::MhloDialect"]; +} + def LegalizeTFNoFallback : Pass<"xla-legalize-tf-no-fallback", "mlir::func::FuncOp"> { let summary = "Legalize from TF dialect's or HLO dialect's control flow."; @@ -95,13 +107,6 @@ def LegalizeTFNoFallback : Pass<"xla-legalize-tf-no-fallback", "mlir::func::Func "shape::ShapeDialect", "func::FuncDialect", "sparse_tensor::SparseTensorDialect"]; } -def LegalizeTFControlFlow : Pass<"xla-legalize-tf-control-flow", "ModuleOp"> { - let summary = "Legalize from TF dialect's to HLO dialect's control flow."; - - let constructor = "mlir::mhlo::createLegalizeTFControlFlowPass()"; - let dependentDialects = ["mhlo::MhloDialect", "sparse_tensor::SparseTensorDialect"]; -} - def LegalizeTfTypesPass : Pass<"xla-legalize-tf-types"> { let summary = "Replace TensorFlow types with types that are legal in the MHLO dialect"; @@ -126,3 +131,19 @@ def LegalizeTFCollective : Pass<"xla-legalize-tf-collective", "ModuleOp"> { let constructor = "mlir::mhlo::CreateLegalizeTFCollectivePass()"; let dependentDialects = ["mhlo::MhloDialect", "sparse_tensor::SparseTensorDialect"]; } + +def VerifyTFXLALegalization : Pass<"tfxla-verify-legalization", "mlir::func::FuncOp"> { + let summary = "Verifies that all TF ops have been legalized to XLA."; + + let description = [{"Ensures that all Tensorflow ops have been legalized to " + "XLA and reports an error about which op has not been" + "legalized. This pass does not transform any ops and is just" + " a verification pass to ensure invariants are true."}]; + + let options = [ + Option<"legalize_chlo_", "legalize-chlo", "bool", /*default=*/"true", + "Legalizes intermediate chlo ops to hlo"> + ]; + + let constructor = "mlir::mhlo::CreateVerifyTFXLALegalizationPass()"; +} diff --git a/tensorflow/compiler/mlir/xla/xla_opt_main.cc b/tensorflow/compiler/mlir/xla/xla_opt_main.cc index 47dc6b3a0ee..77316ac2ef3 100644 --- a/tensorflow/compiler/mlir/xla/xla_opt_main.cc +++ b/tensorflow/compiler/mlir/xla/xla_opt_main.cc @@ -19,14 +19,15 @@ limitations under the License. 
#include "stablehlo/dialect/Register.h" // from @stablehlo #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" -#include "tensorflow/compiler/mlir/xla/ir/xla_framework.h" #include "tensorflow/compiler/mlir/xla/transforms/adjust_layout.h" -#include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" -#include "tensorflow/compiler/mlir/xla/transforms/xla_passes.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/transforms/passes.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir/framework/ir/xla_framework.h" +#include "tensorflow/compiler/xla/mlir/framework/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/register.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/passes.h" +#include "tensorflow/compiler/xla/service/cpu/hlo_xla_runtime_pipeline.h" +#include "tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.h" #include "tensorflow/core/ir/types/dialect.h" int main(int argc, char **argv) { @@ -37,7 +38,7 @@ int main(int argc, char **argv) { mlir::lmhlo::registerAllLmhloPasses(); // These are in compiler/mlir/xla and not part of the above MHLO passes. mlir::mhlo::registerTfXlaPasses(); - mlir::mhlo::registerXlaPasses(); + mlir::mhlo::registerXlaFrameworkPasses(); mlir::mhlo::RegisterAdjustLayoutPass(); mlir::mhlo::registerLegalizeTfPasses(); mlir::RegisterMhloToLhloWithXlaPass(); @@ -45,6 +46,7 @@ int main(int argc, char **argv) { mlir::registerAllDialects(registry); mlir::mhlo::registerAllMhloDialects(registry); mlir::stablehlo::registerAllDialects(registry); + xla::cpu::RegisterHloXlaRuntimePipelineDialects(registry); registry.insert(); return failed( diff --git a/tensorflow/compiler/plugin/BUILD b/tensorflow/compiler/plugin/BUILD index dc1c2391e94..e582e196099 100644 --- a/tensorflow/compiler/plugin/BUILD +++ b/tensorflow/compiler/plugin/BUILD @@ -31,6 +31,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") """ package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:public"], licenses = ["notice"], ) diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 3bd9498aa9c..28c0667fd51 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -13,6 +13,7 @@ load( load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_test") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [":internal"], licenses = ["notice"], ) @@ -728,7 +729,7 @@ tf_xla_py_test( tf_xla_py_test( name = "slice_ops_test", - size = "small", + size = "medium", srcs = ["slice_ops_test.py"], enable_mlir_bridge = True, python_version = "PY3", @@ -1063,7 +1064,7 @@ tf_xla_py_test( "cpu", "cpu_ondemand", ], - enable_mlir_bridge = False, + enable_mlir_bridge = True, python_version = "PY3", shard_count = 5, tags = [ @@ -1119,6 +1120,7 @@ tf_xla_py_test( ], deps = [ ":xla_test", + "//tensorflow:tensorflow_py", "//tensorflow/python:array_ops", "//tensorflow/python:framework", "//tensorflow/python:math_ops", @@ -1203,6 +1205,30 @@ tf_xla_py_test( ], ) +# copybara:uncomment_begin(google-only) +# 
tf_xla_py_test( +# name = "reverse_sequence_op_args_test", +# size = "medium", +# srcs = ["reverse_sequence_op_args_test.py"], +# enable_mlir_bridge = False, +# main = "reverse_sequence_op_args_test.py", +# python_version = "PY3", +# tags = [ +# "no_pip", +# "optonly", +# ], +# deps = [ +# ":xla_test", +# "//tensorflow/compiler/jit:xla_cpu_jit", # DisableOnExport +# "//tensorflow/python:array_ops", +# "//tensorflow/python:framework", +# "//tensorflow/python:platform_test", +# "//tensorflow/python/compat:v2_compat", +# "//tensorflow/python/eager:function", +# ], +# ) +# copybara:uncomment_end + tf_xla_py_test( name = "rmsprop_test", size = "small", @@ -1445,6 +1471,7 @@ tf_xla_py_test( srcs = ["unary_ops_test.py"], enable_mlir_bridge = True, python_version = "PY3", + shard_count = 4, tags = [ "no_cuda_asan", # times out "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -2207,6 +2234,37 @@ tf_xla_py_test( size = "small", srcs = ["where_op_test.py"], enable_mlir_bridge = False, + enabled_backends = [ + "cpu", + "gpu", + ], + tags = [ + "no_pip", + "optonly", + ], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:errors", + "//tensorflow/python:framework", + "//tensorflow/python/compiler/xla:compiler_py", + ], +) + +tf_xla_py_test( + name = "where_op_tpu_test", + size = "small", + srcs = ["where_op_test.py"], + args = ["--tpu_use_tfrt=true"], + disabled_backends = [ + "cpu", + "cpu_ondemand", + "gpu", + ], + enable_mlir_bridge = False, + main = "where_op_test.py", tags = [ "no_pip", "optonly", @@ -2259,10 +2317,24 @@ tf_xla_py_test( ], ) +cuda_py_test( + name = "const_test", + size = "small", + srcs = ["const_test.py"], + python_version = "PY3", + xla_enable_strict_auto_jit = False, + xla_enabled = True, + deps = [ + ":xla_test", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework", + ], +) + tpu_py_test( - name = "const_op_test", + name = "giant_const_op_test", srcs = [ - "const_op_test.py", + "giant_const_op_test.py", ], disable_experimental = True, # TODO(b/188995810): Add an optimization in MLIR importer to not @@ -2345,3 +2417,41 @@ tf_xla_py_test( "//tensorflow/python:training", ], ) + +tf_xla_py_test( + name = "bincount_op_test", + size = "small", + srcs = ["bincount_op_test.py"], + enable_mlir_bridge = False, + python_version = "PY3", + shard_count = 10, + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], + deps = [ + ":xla_test", + "//tensorflow/python:platform_test", + ], +) + +tf_xla_py_test( + name = "unique_ops_test", + size = "small", + srcs = ["unique_ops_test.py"], + enable_mlir_bridge = False, + enabled_backends = [ + "cpu", + "gpu", + ], + python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + ], +) diff --git a/tensorflow/compiler/tests/approx_topk_test.py b/tensorflow/compiler/tests/approx_topk_test.py index e2672e0d399..53e5ce05429 100644 --- a/tensorflow/compiler/tests/approx_topk_test.py +++ b/tensorflow/compiler/tests/approx_topk_test.py @@ -146,19 +146,20 @@ def ann(qy, db, k): def test_l2ann(self, k, db_size, qy_size, feature_dim): qy = self._rng.random([qy_size, feature_dim], dtype=np.float32) db = 
self._rng.random([db_size, feature_dim], dtype=np.float32) - db_half_norm = np.linalg.norm(db, axis=1) / 2 + db_half_norm_sq = np.linalg.norm(db, axis=1)**2 / 2 @function(jit_compile=True) - def ann(qy, db, db_half_norm, k): - scores = db_half_norm - math_ops.matmul(qy, db, transpose_b=True) + def ann(qy, db, db_half_norm_sq, k): + scores = db_half_norm_sq - math_ops.matmul(qy, db, transpose_b=True) return nn_ops.approx_min_k(scores, k) with ops.device('/device:TPU:0'): qy_op = variables.Variable(qy) db_op = variables.Variable(db) - db_half_norm_op = variables.Variable(db_half_norm) - result = ann(qy_op, db_op, db_half_norm_op, k)[1] - scores = db_half_norm_op - math_ops.matmul(qy_op, db_op, transpose_b=True) + db_half_norm_sq_op = variables.Variable(db_half_norm_sq) + result = ann(qy_op, db_op, db_half_norm_sq_op, k)[1] + scores = db_half_norm_sq_op - math_ops.matmul( + qy_op, db_op, transpose_b=True) gt = np.argsort(scores.numpy())[:, :k] ann_recall = self.compute_recall(result.numpy(), gt) @@ -218,6 +219,36 @@ def ann_with_grads(db, out_grads): self.assertAllClose(expected_in_grads, result_in_grads) + # Tests that multiple ops are supported and the comparison functions are + # renamed properly to avoid conflict while using the MLIR bridge. + def test_multiple_ops(self): + k = 1 + + row_size = 100 + num_rows = 10 + + row = np.arange(row_size, dtype=np.float32) + db1 = np.stack(list(self._rng.permutation(row) for _ in range(num_rows))) + db2 = np.stack(list(self._rng.permutation(row) for _ in range(num_rows))) + + @function(jit_compile=True) + def ann(db1, db2): + result1 = nn_ops.approx_max_k(db1, k, aggregate_to_topk=True) + result2 = nn_ops.approx_max_k(db2, k, aggregate_to_topk=True) + return (result1, result2) + + with ops.device('/device:TPU:0'): + db1_op = variables.Variable(db1) + db2_op = variables.Variable(db2) + result1, result2 = ann(db1_op, db2_op) + + gt = np.argsort(-db1)[:, :k] + ann_recall = self.compute_recall(result1[1].numpy(), gt) + self.assertGreaterEqual(ann_recall, 0.95) + + gt = np.argsort(-db2)[:, :k] + ann_recall = self.compute_recall(result2[1].numpy(), gt) + self.assertGreaterEqual(ann_recall, 0.95) if __name__ == '__main__': test.main() diff --git a/tensorflow/compiler/tests/async_comp_test.py b/tensorflow/compiler/tests/async_comp_test.py index 06511451850..ce9b58329c9 100644 --- a/tensorflow/compiler/tests/async_comp_test.py +++ b/tensorflow/compiler/tests/async_comp_test.py @@ -15,6 +15,7 @@ """Tests for asynchronous compilation on the CPU and GPU devices.""" import os +import unittest from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session as session_lib @@ -51,6 +52,7 @@ class AsyncCompilationTest(test.TestCase): # Asynchrobnous compilation uses the existing fallback path and existing # compiler. This test only tests that asynchronus compilation is performed. + @unittest.skip("b/263146341 - flaky Kokoro build.") def testAsyncCompilationJit(self): @function.Defun(compiled=True) diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 575e17ada1d..a312df10e1f 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -694,12 +694,20 @@ def _testDivision(self, dtype): rtol=7e-15 if dtype == np.float64 else None, atol=3.9e-15 if dtype == np.float64 else None) - if dtype not in self.complex_types: # floordiv unsupported for complex. + # floordiv/truncatediv unsupported for complex. 
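[Editor's note, not part of the patch] On the truncate_div case added below: floor division rounds toward negative infinity while truncated division rounds toward zero, which is why the expected values differ for negative quotients (3 / -2 is -2 under floor_div but -1 under truncate_div). A quick numpy check (illustrative only):

import numpy as np

a = np.array([3, 3, -1, -9, -8], dtype=np.float32)
b = np.array([2, -2, 7, 2, -4], dtype=np.float32)

print(np.floor_divide(a, b))  # [ 1. -2. -1. -5.  2.]  -> floor_div expectations
print(np.trunc(a / b))        # [ 1. -1. -0. -4.  2.]  -> truncate_div expectations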
+ if dtype not in self.complex_types: self._testBinary( gen_math_ops.floor_div, np.array([3, 3, -1, -9, -8], dtype=dtype), np.array([2, -2, 7, 2, -4], dtype=dtype), expected=np.array([1, -2, -1, -5, 2], dtype=dtype)) + + self._testBinary( + gen_math_ops.truncate_div, + np.array([3, 3, -1, -9, -8.1], dtype=dtype), + np.array([2, -2, 7, 2, -4], dtype=dtype), + expected=np.array([1, -1, 0, -4, 2], dtype=dtype)) + if dtype in self.signed_int_types: # Overflow cases. int_min = np.iinfo(dtype).min diff --git a/tensorflow/compiler/tests/bincount_op_test.py b/tensorflow/compiler/tests/bincount_op_test.py new file mode 100644 index 00000000000..79e8a7e91b8 --- /dev/null +++ b/tensorflow/compiler/tests/bincount_op_test.py @@ -0,0 +1,40 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for bincount using the XLA JIT.""" +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import errors +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.platform import googletest + + +class BincountTest(xla_test.XLATestCase): + + def testInputRank0(self): + with self.session(): + with self.test_scope(): + bincount = gen_math_ops.bincount(arr=6, size=804, weights=[52, 351]) + + with self.assertRaisesRegex( + errors.InvalidArgumentError, + ( + "`weights` must be the same shape as `arr` or a length-0" + " `Tensor`, in which case it acts as all weights equal to 1." + ), + ): + self.evaluate(bincount) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl index 0f58bd5a951..bd10ecad818 100644 --- a/tensorflow/compiler/tests/build_defs.bzl +++ b/tensorflow/compiler/tests/build_defs.bzl @@ -108,6 +108,7 @@ def tf_xla_py_test( for mlir_option in enable_mlir_bridge_options: extra_dep = [] + extra_tag = [] updated_name = test_name mlir_bridge_dep = "//tensorflow/python:is_mlir_bridge_test_true" @@ -117,6 +118,12 @@ def tf_xla_py_test( updated_name = updated_name[:-5] updated_name += "_mlir_bridge_test" extra_dep = [] if has_mlir_dep else [mlir_bridge_dep] + + # Mark gpu mlir_bridge tests as ondemand + # + # This is for testing book keeping because the bridge does not have any gpu specific + # logic at this time, so CPU testing is good enough and cheaper. 
+ extra_tag = ["ondemand"] if backend == "gpu" else [] elif has_mlir_dep: # Some tests run only with mlir_bridge by explicitly adding the MLIR # bridge dep so if the dep is already present skip non MLIR @@ -131,7 +138,7 @@ def tf_xla_py_test( main = "{}.py".format(name) if main == None else main, data = data + backend_data, deps = deps + backend_deps + extra_dep, - tags = test_tags, + tags = test_tags + extra_tag, exec_properties = tf_exec_properties({"tags": test_tags}), **kwargs ) diff --git a/tensorflow/compiler/tests/const_test.py b/tensorflow/compiler/tests/const_test.py new file mode 100644 index 00000000000..4e11a436e85 --- /dev/null +++ b/tensorflow/compiler/tests/const_test.py @@ -0,0 +1,60 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for const op compilation.""" + +import numpy as np + +from tensorflow.python.eager import def_function +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test + + +# This test doesn't use XLATestCase like the other tests in this directory. +# The Const op xla op kernel is compilation only and therefore is not executed +# with XLA in the on demand compilation mode. 
Instead we use +# tf.function(jit_compile=True) +class ConstOpTest(test_util.TensorFlowTestCase): + + # Verifies that the Const op works + # @test_util.run_v2_only + def testConst(self): + types = { + dtypes.bool, dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, + dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64, + dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64, + dtypes.float8_e5m2, dtypes.float8_e4m3fn, + } + for dtype in types: + with self.subTest(dtype=dtype): + if dtype == dtypes.bool: + values = [True, False] + else: + values = [0., 1., -1., dtype.min, dtype.max] + if dtype.is_floating: + values.extend([float("Inf"), -float("Inf"), float("NaN")]) + values = np.array(values, dtype=dtype.as_numpy_dtype) + + @def_function.function(jit_compile=True) + def f(): + return constant_op.constant(values, dtype) # pylint: disable=cell-var-from-loop + + result = f() + self.assertAllEqual(self.evaluate(result), values) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/fft_test.py b/tensorflow/compiler/tests/fft_test.py index 767f62fbffc..50f7df3a975 100644 --- a/tensorflow/compiler/tests/fft_test.py +++ b/tensorflow/compiler/tests/fft_test.py @@ -27,8 +27,10 @@ from tensorflow.python.platform import googletest BATCH_DIMS = (3, 5) -RTOL = 0.02 # Eigen/cuFFT differ widely from np, especially for FFT3D -ATOL = 1e-3 +RTOL = 0.009 # Eigen/cuFFT differ widely from np, especially for FFT3D +ATOL = 1e-4 +RTOL_3D = 0.07 +ATOL_3D = 4e-4 def pick_10(x): @@ -55,8 +57,13 @@ def to_32bit(x): class FFTTest(xla_test.XLATestCase): - def _VerifyFftMethod(self, inner_dims, complex_to_input, input_to_expected, - tf_method): + def _VerifyFftMethod(self, + inner_dims, + complex_to_input, + input_to_expected, + tf_method, + atol=ATOL, + rtol=RTOL): for indims in inner_dims: print("nfft =", indims) shape = BATCH_DIMS + indims @@ -72,7 +79,7 @@ def _VerifyFftMethod(self, inner_dims, complex_to_input, input_to_expected, dtypes.as_dtype(data.dtype), shape=data.shape) out = tf_method(ph) value = sess.run(out, {ph: data}) - self.assertAllClose(expected, value, rtol=RTOL, atol=ATOL) + self.assertAllClose(expected, value, rtol=rtol, atol=atol) def testContribSignalSTFT(self): ws = 512 @@ -111,7 +118,7 @@ def testFFT2D(self): def testFFT3D(self): self._VerifyFftMethod(INNER_DIMS_3D, lambda x: x, lambda x: np.fft.fftn(x, axes=(-3, -2, -1)), - signal.fft3d) + signal.fft3d, ATOL_3D, RTOL_3D) def testIFFT(self): self._VerifyFftMethod(INNER_DIMS_1D, lambda x: x, np.fft.ifft, @@ -124,7 +131,7 @@ def testIFFT2D(self): def testIFFT3D(self): self._VerifyFftMethod(INNER_DIMS_3D, lambda x: x, lambda x: np.fft.ifftn(x, axes=(-3, -2, -1)), - signal.ifft3d) + signal.ifft3d, ATOL_3D, RTOL_3D) def testRFFT(self): @@ -155,7 +162,8 @@ def _tf_fn(x): return signal.rfft3d( x, fft_length=[x.shape[-3], x.shape[-2], x.shape[-1]]) - self._VerifyFftMethod(INNER_DIMS_3D, np.real, _to_expected, _tf_fn) + self._VerifyFftMethod(INNER_DIMS_3D, np.real, _to_expected, _tf_fn, ATOL_3D, + RTOL_3D) def testRFFT3DMismatchedSize(self): @@ -209,7 +217,8 @@ def _tf_fn(x): return signal.irfft3d( x, fft_length=[x.shape[-3], x.shape[-2], 2 * (x.shape[-1] - 1)]) - self._VerifyFftMethod(INNER_DIMS_3D, _to_input, _to_expected, _tf_fn) + self._VerifyFftMethod(INNER_DIMS_3D, _to_input, _to_expected, _tf_fn, + ATOL_3D, RTOL_3D) def testIRFFT3DMismatchedSize(self): @@ -229,7 +238,8 @@ def _tf_fn(x): return signal.irfft3d( x, fft_length=[x.shape[-3] // 2, x.shape[-2], x.shape[-1] * 2]) - 
self._VerifyFftMethod(INNER_DIMS_3D, _to_input, _to_expected, _tf_fn) + self._VerifyFftMethod(INNER_DIMS_3D, _to_input, _to_expected, _tf_fn, + ATOL_3D, RTOL_3D) diff --git a/tensorflow/compiler/tests/fifo_queue_test.py b/tensorflow/compiler/tests/fifo_queue_test.py index 8a7cbccd117..14a26570dc5 100644 --- a/tensorflow/compiler/tests/fifo_queue_test.py +++ b/tensorflow/compiler/tests/fifo_queue_test.py @@ -37,7 +37,7 @@ def testEnqueueWithShape(self): enqueue_correct_op.run() with self.assertRaises(ValueError): q.enqueue(([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],)) - self.assertEqual(1, q.size().eval()) + self.assertEqual(1, self.evaluate(q.size())) def testMultipleDequeues(self): with self.session(), self.test_scope(): @@ -85,7 +85,7 @@ def enqueue(enqueue_op): # Dequeue every element using a single thread. results = [] for _ in range(len(elems)): - results.append(dequeued_t.eval()) + results.append(self.evaluate(dequeued_t)) self.assertItemsEqual(elems, results) def testParallelDequeue(self): @@ -175,7 +175,7 @@ def testMultiEnqueueAndDequeue(self): def testQueueSizeEmpty(self): with self.session(), self.test_scope(): q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) - self.assertEqual([0], q.size().eval()) + self.assertEqual([0], self.evaluate(q.size())) def testQueueSizeAfterEnqueueAndDequeue(self): with self.session(), self.test_scope(): diff --git a/tensorflow/compiler/tests/const_op_test.py b/tensorflow/compiler/tests/giant_const_op_test.py similarity index 97% rename from tensorflow/compiler/tests/const_op_test.py rename to tensorflow/compiler/tests/giant_const_op_test.py index cab962073e9..c0f4b47be01 100644 --- a/tensorflow/compiler/tests/const_op_test.py +++ b/tensorflow/compiler/tests/giant_const_op_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for const op compilation.""" +"""Tests for giant const op compilation.""" import os import numpy as np @@ -54,10 +54,10 @@ def get_tpu_strategy(): # with XLA in the on demand compilation mode. Also, here we want to feed the # full program to XLA to verify handling of programs with giant constant # tensors. -class ConstOp(test.TestCase): +class GiantConstOp(test.TestCase): def setUp(self): - super(ConstOp, self).setUp() + super(GiantConstOp, self).setUp() # Make sure TF_XLA_FLAGS is not already set to avoid dropping the existing # value silently. assert "TF_XLA_FLAGS" not in os.environ diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index 45297ecbe18..d4a264953cf 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -78,7 +78,7 @@ def testRGBToHSVRoundTrip(self): with self.test_scope(): hsv = image_ops.rgb_to_hsv(placeholder) rgb = image_ops.hsv_to_rgb(hsv) - rgb_tf = rgb.eval(feed_dict={placeholder: rgb_np}) + rgb_tf = rgb.eval(feed_dict={placeholder: rgb_np}) self.assertAllCloseAccordingToType(rgb_tf, rgb_np, bfloat16_atol=0.03) def testRGBToHSVNumpy(self): @@ -520,9 +520,6 @@ def testBFloat16(self): dtype=np.float32)) def testAlignCorners3x3To12x12_uint8(self): - # TODO(b/72099414): enable the test for TPU when the issue is fixed. 
- if (self.device not in ["XLA_GPU", "XLA_CPU"]): - return # Ensure that resize with convolution works on XLA/GPU for integer types self._assertForwardOpMatchesExpected( np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.uint8), [12, 12], diff --git a/tensorflow/compiler/tests/pooling_ops_test.py b/tensorflow/compiler/tests/pooling_ops_test.py index 3d2695b15e9..3a7e22c02e5 100644 --- a/tensorflow/compiler/tests/pooling_ops_test.py +++ b/tensorflow/compiler/tests/pooling_ops_test.py @@ -18,7 +18,9 @@ from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import nn_ops @@ -560,6 +562,34 @@ def AvgPoolGrad(inputs, outputs, output_gradients, ksize, strides, padding, self._TestPooling(nn_ops.avg_pool, AvgPoolGrad) + @test_util.disable_mlir_bridge( + "TODO(b/266613412): investigate FPE in AvgPoolGrad for TPU" + ) + def testAvgPoolGradSamePaddingZeroStrideZeroSize(self): + output_gradient_vals = np.array([0.39117979], dtype=np.float32) + output_gradient_vals = output_gradient_vals.reshape([1, 1, 1, 1]) + with self.session() as sess: + with self.test_scope(): + output_gradients = array_ops.placeholder( + dtypes.float32, shape=output_gradient_vals.shape + ) + t = gen_nn_ops.avg_pool_grad( + orig_input_shape=[1, 0, 0, 0], + grad=output_gradients, + ksize=[1, 0, 0, 0], + strides=[1, 0, 0, 0], + padding="SAME", + data_format="NCHW", + ) + with self.assertRaisesRegex( + errors.InvalidArgumentError, + ( + "Sliding window ksize field for dimension 1 must be positive but" + " is 0" + ), + ): + sess.run(t, {output_gradients: output_gradient_vals}) + # The CPU implementation of AvgPoolGrad doesn't accept kernels smaller than # the stride size, so we only run the following tests on MaxPoolGrad. diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index 03e919dee91..80be5b5e836 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -279,6 +279,12 @@ def testShuffle2d(self): self.assertAllEqual(len(result.flatten()), len(expected)) self.assertAllEqual(set(result.flatten()), set(expected)) + def testRandomShuffleInputRank0(self): + with self.session(): + with self.test_scope(): + shuffle = random_ops.random_shuffle(value=1e20) + self.evaluate(shuffle) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/compiler/tests/reverse_sequence_op_args_test.py b/tensorflow/compiler/tests/reverse_sequence_op_args_test.py new file mode 100644 index 00000000000..3ccb9b1df27 --- /dev/null +++ b/tensorflow/compiler/tests/reverse_sequence_op_args_test.py @@ -0,0 +1,52 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for tensorflow.ops.reverse_sequence_op.""" + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.compat import v2_compat +from tensorflow.python.eager import def_function +from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class ReverseSequenceArgsTest(xla_test.XLATestCase): + """Tests argument verification of array_ops.reverse_sequence.""" + + def testInvalidArguments(self): + # seq_axis negative + with self.assertRaisesRegex( + (errors.InvalidArgumentError, ValueError), "seq_dim must be >=0" + ): + + @def_function.function(jit_compile=True) + def f(x): + return array_ops.reverse_sequence(x, [2, 2], seq_axis=-1) + + f([[1, 2], [3, 4]]) + + # batch_axis negative + with self.assertRaisesRegex(ValueError, "batch_dim must be >=0"): + + @def_function.function(jit_compile=True) + def g(x): + return array_ops.reverse_sequence(x, [2, 2], seq_axis=1, batch_axis=-1) + + g([[1, 2], [3, 4]]) + + +if __name__ == "__main__": + v2_compat.enable_v2_behavior() + test.main() diff --git a/tensorflow/compiler/tests/segment_reduction_ops_test.py b/tensorflow/compiler/tests/segment_reduction_ops_test.py index 7ef6d6607c1..2eac4222ae3 100644 --- a/tensorflow/compiler/tests/segment_reduction_ops_test.py +++ b/tensorflow/compiler/tests/segment_reduction_ops_test.py @@ -41,6 +41,14 @@ def _unsortedSegmentSum(self, data, indices, num_segments): return self._segmentReduction(math_ops.unsorted_segment_sum, data, indices, num_segments) + def _segmentSumV2(self, data, indices, num_segments): + return self._segmentReduction(math_ops.segment_sum_v2, data, indices, + num_segments) + + def _segmentProdV2(self, data, indices, num_segments): + return self._segmentReduction(math_ops.segment_prod_v2, data, indices, + num_segments) + def _unsortedSegmentProd(self, data, indices, num_segments): return self._segmentReduction(math_ops.unsorted_segment_prod, data, indices, num_segments) @@ -53,6 +61,38 @@ def _unsortedSegmentMax(self, data, indices, num_segments): return self._segmentReduction(math_ops.unsorted_segment_max, data, indices, num_segments) + def testSegmentSum(self): + for dtype in self.numeric_types: + self.assertAllClose( + np.array([1, 0, 2, 12], dtype=dtype), + self._segmentSumV2( + np.array([0, 1, 2, 3, 4, 5], dtype=dtype), + np.array([0, 0, 2, 3, 3, 3], dtype=np.int32), 4)) + + def testSegmentProd(self): + for dtype in self.numeric_types: + self.assertAllClose( + np.array([0, 1, 2, 60], dtype=dtype), + self._segmentProdV2( + np.array([0, 1, 2, 3, 4, 5], dtype=dtype), + np.array([0, 0, 2, 3, 3, 3], dtype=np.int32), 4)) + + def testSegmentProdNumSegmentsLess(self): + for dtype in self.numeric_types: + self.assertAllClose( + np.array([0, 1, 2], dtype=dtype), + self._segmentProdV2( + np.array([0, 1, 2, 3, 4, 5], dtype=dtype), + np.array([0, 0, 2, 3, 3, 3], dtype=np.int32), 3)) + + def testSegmentProdNumSegmentsMore(self): + for dtype in self.numeric_types: + self.assertAllClose( + np.array([0, 1, 2, 60, 1], dtype=dtype), + self._segmentProdV2( + np.array([0, 1, 2, 3, 4, 5], dtype=dtype), + np.array([0, 0, 2, 3, 3, 3], dtype=np.int32), 5)) + def testUnsortedSegmentSum0DIndices1DData(self): for dtype in self.numeric_types: self.assertAllClose( diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index b6a1f08e0a1..012fe158e1c 100644 --- 
a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -15,10 +15,12 @@ """Tests for stateless random-number generation ops.""" import functools +import os from absl.testing import parameterized import numpy as np from tensorflow.compiler.tests import xla_test +from tensorflow.python.client import device_lib from tensorflow.python.eager import def_function from tensorflow.python.framework import config from tensorflow.python.framework import dtypes @@ -34,14 +36,36 @@ from tensorflow.python.platform import test +def xla_device(): + devices = device_lib.list_local_devices() + + def find_type(device_type): + for d in devices: + if d.device_type == device_type: + return d + return None + + d = find_type('TPU') or find_type('XLA_GPU') or find_type('XLA_CPU') + if d is None: + raise ValueError('Cannot find any XLA device. Available devices:\n%s' % + devices) + return d + + +def _allowed_types(include_int=False): + allowed_types = { + dtypes.float64, dtypes.float32, dtypes.float16, dtypes.bfloat16 + } + if include_int: + allowed_types.update({dtypes.int32, dtypes.int64}) + return allowed_types + + class StatelessRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase): """Test cases for stateless random-number generator operators.""" def _random_types(self, include_int=False): - allowed_types = {dtypes.float64, dtypes.float32, dtypes.bfloat16} - if include_int: - allowed_types.update({dtypes.int32, dtypes.int64}) - return self.all_tf_types & allowed_types + return self.all_tf_types & _allowed_types(include_int) @test_util.run_v2_only def testForcedCompile(self): @@ -148,28 +172,30 @@ def testLargeNormal(self): y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]}) self.assertAllEqual([1024, 32000], y.shape) - def testDeterminism(self): + @parameterized.named_parameters( + (f'_{op_name}_{shape}_{dtype.name}', stateless_op, shape, dtype) # pylint: disable=g-complex-comprehension + for dtype in _allowed_types() for shape in ((), (3,), (2, 5)) + for op_name, stateless_op in ( + ('uniform', stateless.stateless_random_uniform), + ('normal', stateless.stateless_random_normal), + )) + def testDeterminism(self, stateless_op, shape, dtype): # Stateless values should be equal iff the seeds are equal (roughly) + seeds = [(x, y) for x in range(-2, 3) for y in range(-2, 3)] * 3 # pylint: disable=g-complex-comprehension with self.session(), self.test_scope(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) - seeds = [(x, y) for x in range(-2, 3) for y in range(-2, 3)] * 3 # pylint: disable=g-complex-comprehension - for stateless_op in [ - stateless.stateless_random_uniform, stateless.stateless_random_normal - ]: - for shape in (), (3,), (2, 5): - for dtype in self._random_types(): - # Skip bfloat16. The result of bfloat16 is truncated from 32-bit - # result. With different seeds, the 32-bit results are different, - # but the truncated 16-bit results might be the same. 
- if dtype == dtypes.bfloat16: - continue - pure = stateless_op(shape, seed=seed_t, dtype=dtype) - values = [(seed, pure.eval(feed_dict={ - seed_t: seed - })) for seed in seeds] - for s0, v0 in values: - for s1, v1 in values: - self.assertEqual(s0 == s1, np.all(v0 == v1)) + pure = stateless_op(shape, seed=seed_t, dtype=dtype) + values = [(seed, pure.eval(feed_dict={seed_t: seed})) for seed in seeds] + for s0, v0 in values: + for s1, v1 in values: + if s0 == s1: + self.assertAllEqual(v0, v1) + else: + # The resolutions of float16 and bfloat16 are too low, so + # in some cases (e.g. scalar shape) different seeds may + # lead to the same output. So we skip those dtypes. + if not (dtype in (dtypes.bfloat16, dtypes.float16) and shape == ()): # pylint: disable=g-explicit-bool-comparison + self.assertNotAllEqual(v0, v1) def testRandomUniformIsInRange(self): with self.session() as sess, self.test_scope(): @@ -184,26 +210,59 @@ def testRandomUniformIsInRange(self): self.assertTrue(np.all(y >= 0)) self.assertTrue(np.all(y < maxval)) - def testDistributionOfStatelessRandomUniform(self): + @parameterized.named_parameters( + (f'_{alg.name}_{dtype.name}_{seed}', alg, dtype, seed) # pylint: disable=g-complex-comprehension + for seed in ([1, 2], [12, 23], [123, 456], [565656, 121212]) + for dtype in _allowed_types(include_int=True) + for alg in list(stateless.Algorithm)) + def testDistributionOfStatelessRandomUniform(self, alg, dtype, seed): """Use Pearson's Chi-squared test to test for uniformity.""" + philox = stateless.Algorithm.PHILOX + auto_select = stateless.Algorithm.AUTO_SELECT + device = xla_device() + if 'CPU' in device.device_type: + device_type = 'CPU' + elif 'GPU' in device.device_type: + device_type = 'GPU' + elif device.device_type == 'TPU': + device_type = 'TPU' + else: + device_type = None + bad_combos1 = [ + (dtypes.int32, [123, 456]), + (dtypes.int64, [123, 456]), + (dtypes.float16, [565656, 121212]), + (dtypes.bfloat16, [1, 2]), + ] + bad_combos2 = [ + (dtypes.int32, [1, 2]), + (dtypes.int32, [12, 23]), + ] + # TODO(b/244649364): Investigate why these combinations fail. + if (device_type in ('CPU', 'GPU') and alg in (philox, auto_select) and + (dtype, seed) in bad_combos1 or device_type == 'TPU' and + (alg == philox and + (dtype, seed) in bad_combos1 or alg == auto_select and + (dtype, seed) in bad_combos2)): + self.skipTest( + 'This (device, alg, dtype, seed) combination fails (b/244649364).') with self.session() as sess, self.test_scope(): - for dtype in self._random_types(include_int=True): - seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) - n = 1000 - maxval = 1 - if dtype.is_integer: - maxval = 100 - x = stateless.stateless_random_uniform( - shape=[n], seed=seed_t, maxval=maxval, dtype=dtype) - y = sess.run(x, {seed_t: [565656, 121212]}) - # Convert y to float and normalize its value to range [0, 1) when - # maxval != 1. - y = y.astype(float) / maxval - # Tests that the values are distributed amongst 10 bins with equal - # probability. 16.92 is the Chi^2 value for 9 degrees of freedom with - # p=0.05. This test is probabilistic and would be flaky if the random - # seed were not fixed. 
- self.assertLess(random_test_util.chi_squared(y, 10), 16.92) + seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) + n = 1000 + maxval = 1 + if dtype.is_integer: + maxval = 100 + x = stateless.stateless_random_uniform( + shape=[n], seed=seed_t, maxval=maxval, dtype=dtype, alg=alg) + y = sess.run(x, {seed_t: seed}) + # Convert y to float and normalize its value to range [0, 1) when + # maxval != 1. + y = y.astype(float) / maxval + # Tests that the values are distributed amongst 10 bins with equal + # probability. 16.92 is the Chi^2 value for 9 degrees of freedom with + # p=0.05. This test is probabilistic and would be flaky if the random + # seed were not fixed. + self.assertLess(random_test_util.chi_squared(y, 10), 16.92) def testRandomNormalIsFinite(self): with self.session() as sess, self.test_scope(): @@ -214,32 +273,60 @@ def testRandomNormalIsFinite(self): y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]}) self.assertTrue(np.all(np.isfinite(y))) - def testDistributionOfStatelessRandomNormal(self): + @parameterized.named_parameters( + (f'_{dtype.name}_{seed}', dtype, seed) # pylint: disable=g-complex-comprehension + for seed in ([1, 2], [12, 23], [123, 456], [25252, 314159]) + for dtype in _allowed_types()) + def testDistributionOfStatelessRandomNormal(self, dtype, seed): """Use Anderson-Darling test to test distribution appears normal.""" with self.session() as sess, self.test_scope(): - for dtype in self._random_types(): - seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) - n = 1000 - x = stateless.stateless_random_normal( - shape=[n], seed=seed_t, dtype=dtype) - y = sess.run(x, {seed_t: [25252, 314159]}) - # The constant 2.492 is the 5% critical value for the Anderson-Darling - # test where the mean and variance are known. This test is probabilistic - # so to avoid flakiness the seed is fixed. - self.assertLess( - random_test_util.anderson_darling(y.astype(float)), 2.492) - - def testTruncatedNormal(self): - for dtype in self._random_types(): - with self.session() as sess, self.test_scope(): - seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) - n = 10000000 - x = stateless.stateless_truncated_normal( - shape=[n], seed=seed_t, dtype=dtype) - y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]}) - random_test_util.test_truncated_normal( - self.assertEqual, self.assertAllClose, n, y, - variance_rtol=6e-3 if dtype == dtypes.bfloat16 else 1e-3) + seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) + n = 1000 + x = stateless.stateless_random_normal(shape=[n], seed=seed_t, dtype=dtype) + y = sess.run(x, {seed_t: seed}) + # The constant 2.492 is the 5% critical value for the Anderson-Darling + # test where the mean and variance are known. This test is probabilistic + # so to avoid flakiness the seed is fixed. 
+ self.assertLess(random_test_util.anderson_darling(y.astype(float)), 2.492) + + @parameterized.named_parameters( + (f'_{dtype.name}', dtype) for dtype in _allowed_types()) + def testTruncatedNormal(self, dtype): + with self.session() as sess, self.test_scope(): + seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) + n = 10000000 + x = stateless.stateless_truncated_normal( + shape=[n], seed=seed_t, dtype=dtype) + y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]}) + is_megacore = 'megacore' in os.environ.get('TEST_TARGET', '').lower() + if dtype == dtypes.float16: + if is_megacore: + mean_atol = 2e-3 + else: + mean_atol = 7e-4 + else: + mean_atol = 5e-4 + + if dtype == dtypes.float16 and is_megacore: + median_atol = 2e-3 + else: + median_atol = 8e-4 + + if dtype == dtypes.bfloat16: + variance_rtol = 6e-3 + elif dtype == dtypes.float16: + variance_rtol = 3e-3 + else: + variance_rtol = 1e-3 + + random_test_util.test_truncated_normal( + self.assertEqual, + self.assertAllClose, + n, + y, + mean_atol=mean_atol, + median_atol=median_atol, + variance_rtol=variance_rtol) def _testParameterizedTruncatedNormal(self, means, @@ -329,6 +416,10 @@ def builder_fn(): xla_test.Benchmark(self, builder_fn, use_xla_jit=use_xla_jit, device='cpu') + def benchmarkUniformF16(self): + self._benchmarkUniform( + 'uniform_f16', dtype=dtypes.float16, use_xla_jit=False) + def benchmarkUniformF32(self): self._benchmarkUniform( 'uniform_f32', dtype=dtypes.float32, use_xla_jit=False) @@ -337,6 +428,10 @@ def benchmarkUniformF64(self): self._benchmarkUniform( 'uniform_f64', dtype=dtypes.float64, use_xla_jit=False) + def benchmarkUniformF16XLA(self): + self._benchmarkUniform( + 'uniform_f16', dtype=dtypes.float16, use_xla_jit=True) + def benchmarkUniformF32XLA(self): self._benchmarkUniform( 'uniform_f32', dtype=dtypes.float32, use_xla_jit=True) diff --git a/tensorflow/compiler/tests/tensor_list_ops_test.py b/tensorflow/compiler/tests/tensor_list_ops_test.py index 659d9f41e8d..3c9b29b1835 100644 --- a/tensorflow/compiler/tests/tensor_list_ops_test.py +++ b/tensorflow/compiler/tests/tensor_list_ops_test.py @@ -236,6 +236,17 @@ def testZerosLikeForTensorList(self): self.assertAllEqual(z.shape.as_list(), [None]) self.assertAllEqual(z, [0.0, 0.0]) + def testInvalidSplitLength(self): + with self.session(), self.test_scope(): + tensor_list_split = list_ops.tensor_list_split( + tensor=[1], element_shape=[-1], lengths=[0] + ) + with self.assertRaisesRegex( + errors.UnimplementedError, "All lengths must be positive" + ): + self.evaluate(tensor_list_split) + + if __name__ == "__main__": os.environ["TF_XLA_FLAGS"] = ("--tf_xla_min_cluster_size=2 " + os.environ.get("TF_XLA_FLAGS", "")) diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 7334d316a55..04ee8dbd615 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -864,52 +864,70 @@ def testBiasAddGrad(self): expected=np.array([14., 22.], dtype=np.float32)) def testCast(self): - shapes = [[], [4], [2, 3], [2, 0, 4]] types = { dtypes.bool, dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.int32, dtypes.int64, dtypes.uint32, dtypes.uint64 } for src_type in types: for dst_type in types: - src_np_dtype = src_type.as_numpy_dtype - dst_np_dtype = dst_type.as_numpy_dtype - - for shape in shapes: - src = np.arange(np.prod(shape)).astype(src_np_dtype) - - if src_type in self.complex_tf_types: - src += (np.arange(np.prod(shape)) * 2j).astype(src_np_dtype) - src = 
src.reshape(shape) - dst = src.astype(dst_np_dtype) - self._assertOpOutputMatchesExpected( - lambda x, dst_type=dst_type: math_ops.cast(x, dst_type), - src, - expected=dst) - - # Check special values. - if src_type.is_integer: - imin = np.iinfo(src_np_dtype).min - imax = np.iinfo(src_np_dtype).max - src = np.array([imin, imax, 0, 1, -1], dtype=src_np_dtype) - elif src_type in self.float_tf_types: - if dst_type.is_integer: - imin = np.iinfo(dst_np_dtype).min - imax = np.iinfo(dst_np_dtype).max // 2 - src = np.array([imin, imax, 0, 1], dtype=src_np_dtype) - elif dst_type in self.float_tf_types: - fmin = np.finfo(dst_np_dtype).min - fmax = np.finfo(dst_np_dtype).max - tiny = np.finfo(dst_np_dtype).tiny - eps = np.finfo(dst_np_dtype).eps - src = np.array( - [fmin, fmax, np.nan, eps, -eps, tiny, -tiny, np.inf, -np.inf], - dtype=src_np_dtype) + self._testCast(src_type, dst_type) + + def testCastFp8(self): + fp8_types = {dtypes.float8_e5m2, dtypes.float8_e4m3fn} + # TODO(b/259609697): Test casting to bool. Casting from float8 to bool is + # currently not supported since the cast is lowered to an Ne (not-equal) op, + # and FP8 is currently not supported with Ne. + other_types = { + dtypes.float32, dtypes.float64, dtypes.complex64, + dtypes.int32, dtypes.int64, dtypes.uint32, dtypes.uint64 + } + for fp8_type in fp8_types: + for other_type in other_types | fp8_types: + self._testCast(fp8_type, other_type) + self._testCast(other_type, fp8_type) + + def _testCast(self, src_type, dst_type): + with self.subTest(src_type=src_type, dst_type=dst_type): + shapes = [[], [4], [2, 3], [2, 0, 4]] + src_np_dtype = src_type.as_numpy_dtype + dst_np_dtype = dst_type.as_numpy_dtype + + for shape in shapes: + src = np.arange(np.prod(shape)).astype(src_np_dtype) + + if src_type in self.complex_tf_types: + src += (np.arange(np.prod(shape)) * 2j).astype(src_np_dtype) + src = src.reshape(shape) dst = src.astype(dst_np_dtype) self._assertOpOutputMatchesExpected( lambda x, dst_type=dst_type: math_ops.cast(x, dst_type), src, expected=dst) + # Check special values. + if src_type.is_integer: + imin = np.iinfo(src_np_dtype).min + imax = np.iinfo(src_np_dtype).max + src = np.array([imin, imax, 0, 1, -1], dtype=src_np_dtype) + elif src_type in self.float_tf_types: + if dst_type.is_integer: + imin = np.iinfo(dst_np_dtype).min + imax = np.iinfo(dst_np_dtype).max // 2 + src = np.array([imin, imax, 0, 1], dtype=src_np_dtype) + elif dst_type in self.float_tf_types: + fmin = np.finfo(dst_np_dtype).min + fmax = np.finfo(dst_np_dtype).max + tiny = np.finfo(dst_np_dtype).tiny + eps = np.finfo(dst_np_dtype).eps + src = np.array( + [fmin, fmax, np.nan, eps, -eps, tiny, -tiny, np.inf, -np.inf], + dtype=src_np_dtype) + dst = src.astype(dst_np_dtype) + self._assertOpOutputMatchesExpected( + lambda x, dst_type=dst_type: math_ops.cast(x, dst_type), + src, + expected=dst) + def testBitcast(self): self._assertOpOutputMatchesExpected( lambda x: array_ops.bitcast(x, dtypes.int32), diff --git a/tensorflow/compiler/tests/unique_ops_test.py b/tensorflow/compiler/tests/unique_ops_test.py new file mode 100644 index 00000000000..0938bfa430d --- /dev/null +++ b/tensorflow/compiler/tests/unique_ops_test.py @@ -0,0 +1,46 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for unique ops.""" + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.platform import googletest + + +class UniqueTest(xla_test.XLATestCase): + + def testNegativeAxis(self): + """Verifies that an axis with negative index is converted to positive.""" + with self.session() as session: + with self.test_scope(): + px = array_ops.placeholder(dtypes.float32, [2, 1, 1], name="x") + axis = constant_op.constant([-1], dtype=dtypes.int32) + output = gen_array_ops.unique_v2(px, axis) + result = session.run( + output, {px: np.array([[[-2.0]], [[10.0]]], dtype=np.float32)} + ) + self.assertAllEqual( + result.y, np.array([[[-2.0]], [[10.0]]], dtype=np.float32) + ) + self.assertAllEqual(result.idx, np.array([0], dtype=np.int32)) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/compiler/tests/xla_call_module_test.py b/tensorflow/compiler/tests/xla_call_module_test.py index afdd847d29e..ae18b95d301 100644 --- a/tensorflow/compiler/tests/xla_call_module_test.py +++ b/tensorflow/compiler/tests/xla_call_module_test.py @@ -160,13 +160,13 @@ def f(x): # x: f32[2, b] module = """ module @jit_f.0 { func.func public @main(%arg0: tensor, %arg1: tensor<2x?xf32>) -> (tensor<2x?xf32>, tensor) { - %0 = mhlo.sine %arg1 : tensor<2x?xf32> + %0 = stablehlo.sine %arg1 : tensor<2x?xf32> return %0, %arg0 : tensor<2x?xf32>, tensor } } """ return xla.call_module([x], - version=1, + version=2, module=module, Tout=[x.dtype, np.int32], Sout=[(None, 3), ()], @@ -174,6 +174,29 @@ def f(x): # x: f32[2, b] self._assertOpOutputMatchesExpected(f, (x,), (np.sin(x), x.shape[1])) + def test_dim_var_basic_dim_arg_i64(self): + x = np.arange(6, dtype=np.float32).reshape((2, 3)) + + def f(x): # x: f32[2, b] + # Module takes another argument which is the value of b + # (sin(x), x.shape[1]) + module = """ +module @jit_f.0 { + func.func public @main(%arg0: tensor, %arg1: tensor<2x?xf32>) -> (tensor<2x?xf32>, tensor) { + %0 = stablehlo.sine %arg1 : tensor<2x?xf32> + return %0, %arg0 : tensor<2x?xf32>, tensor + } +} +""" + return xla.call_module([x], + version=2, + module=module, + Tout=[x.dtype, np.int64], + Sout=[(None, 3), ()], + dim_args_spec=['0.1']) + + self._assertOpOutputMatchesExpected(f, (x,), (np.sin(x), x.shape[1])) + def test_dim_var_basic_wrapped(self): """Like dim_arg_var_basic, but with the wrapper already added.""" x = np.arange(6, dtype=np.float32).reshape((2, 3)) @@ -184,18 +207,18 @@ def f(x): # x: f32[2, b] module = """ module @jit_f.0 { func.func public @main(%arg0: tensor, %arg1: tensor<2x?xf32>) -> (tensor<2x?xf32>, tensor) { - %arg0_new = "mhlo.get_dimension_size"(%arg1) {dimension = 1 : i64} : (tensor<2x?xf32>) -> tensor + %arg0_new = "stablehlo.get_dimension_size"(%arg1) {dimension = 1 : i64} : (tensor<2x?xf32>) -> tensor %arg1_new = tensor.cast %arg1 : 
tensor<2x?xf32> to tensor<2x?xf32> %0, %1 = call @dyn_main(%arg0_new, %arg1_new) : (tensor, tensor<2x?xf32>) -> (tensor<2x?xf32>, tensor) return %0, %1 : tensor<2x?xf32>, tensor } func.func private @dyn_main(%arg0: tensor, %arg1: tensor<2x?xf32>) -> (tensor<2x?xf32>, tensor) { - %0 = mhlo.sine %arg1 : tensor<2x?xf32> + %0 = stablehlo.sine %arg1 : tensor<2x?xf32> return %0, %arg0 : tensor<2x?xf32>, tensor } } """ - return xla.call_module([x], version=1, + return xla.call_module([x], version=2, module=module, Tout=[x.dtype, np.int32], Sout=[(None, 3), ()], @@ -203,6 +226,88 @@ def f(x): # x: f32[2, b] self._assertOpOutputMatchesExpected(f, (x,), (np.sin(x), x.shape[1])) + def test_dim_args_spec_errors(self): + # x, y: f32[2, b, c] + x = np.arange(24, dtype=np.float32).reshape((2, 3, 4)) + y = x + + # Module takes two prefix arguments with the values of b and c + # return (sin(x + y), x.shape[1]) + module = """ +module @jit_f.0 { + func.func public @main(%arg0: tensor, %arg1: tensor, %arg2: tensor<2x?x?xf32>, %arg3: tensor<2x?x?xf32>) -> (tensor<2x?x?xf32>, tensor) { + %0 = stablehlo.add %arg2, %arg3 : tensor<2x?x?xf32> + %1 = stablehlo.sine %0 : tensor<2x?x?xf32> + return %1, %arg0 : tensor<2x?x?xf32>, tensor + } +} +""" + + dim_args_spec = ['0.1', '0.2'] + def f(x, y): + return xla.call_module([x, y], + version=2, + module=module, + Tout=[x.dtype, np.int32], + Sout=[(None, 3), ()], + dim_args_spec=dim_args_spec) + self._assertOpOutputMatchesExpected(f, (x, y), (np.sin(x + y), x.shape[1])) + + dim_args_spec = ['0.0', '0.0', '0.0', '0.0'] # Too many dim_args_spec + with self.assertRaisesRegex( + errors.InvalidArgumentError, + 'The module should have 4 dimension arguments, ' + 'but it has only 4 total arguments'): + self._assertOpOutputMatchesExpected(f, (x, y), + (np.sin(x + y), x.shape[1])) + + dim_args_spec = ['0.0', '0.0', '0.0'] # dim_args_spec refers to non-scalar + with self.assertRaisesRegex( + errors.InvalidArgumentError, + 'Module argument at index 2 should be a 0-dimensional integer-tensor ' + 'dimension argument but has type'): + self._assertOpOutputMatchesExpected(f, (x, y), + (np.sin(x + y), x.shape[1])) + + dim_args_spec = [] # No dim_args_spec + with self.assertRaisesRegex( + errors.InvalidArgumentError, + 'Module main has dynamic shapes but no dim_args_spec was given'): + self._assertOpOutputMatchesExpected(f, (x, y), + (np.sin(x + y), x.shape[1])) + + dim_args_spec = ['1.0'] # Too few dim_args_spec + with self.assertRaisesRegex( + errors.InvalidArgumentError, + 'Incorrect number of arguments for XlaCallModule: 2. 
' + 'The module has 4 of which 1 were declared to be dimension arguments.'): + self._assertOpOutputMatchesExpected(f, (x, y), + (np.sin(x + y), x.shape[1])) + + dim_args_spec = ['0.b', '0.1'] # axis_idx not a number + with self.assertRaisesRegex( + errors.InvalidArgumentError, + "Syntax error in dim_args_spec '0.b'"): + self._assertOpOutputMatchesExpected(f, (x, y), + (np.sin(x + y), x.shape[1])) + + dim_args_spec = ['2.0', '0.1'] # arg_idx too large + with self.assertRaisesRegex( + errors.InvalidArgumentError, + 'Invalid argument index 2 when the number of non-dimension arguments ' + "is 2 in dim_arg_spec '2.0'"): + self._assertOpOutputMatchesExpected(f, (x, y), + (np.sin(x + y), x.shape[1])) + + dim_args_spec = ['0.3', '0.1'] # axis_idx too large + with self.assertRaisesRegex( + errors.InvalidArgumentError, + 'Invalid axis index 3 when the rank of non-dimension argument 0 ' + "is 3 in dim_arg_spec '0.3'"): + self._assertOpOutputMatchesExpected(f, (x, y), + (np.sin(x + y), x.shape[1])) + + @unittest.skip('TODO(burmako): Re-enable this after shape refinement is done') def test_dynamic_iota(self): x = np.ones((3, 5), dtype=np.int32) res = np.arange(x.shape[0], dtype=np.int32) @@ -212,13 +317,38 @@ def f(x): # x: f32[b, 5] module = """ module @jit_fun.1 { func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = mhlo.reshape %arg0 : (tensor) -> tensor<1xi32> - %1 = "mhlo.dynamic_iota"(%0) {iota_dimension = 0 : i64} : (tensor<1xi32>) -> tensor + %0 = stablehlo.reshape %arg0 : (tensor) -> tensor<1xi32> + %1 = "stablehlo.dynamic_iota"(%0) {iota_dimension = 0 : i64} : (tensor<1xi32>) -> tensor return %1 : tensor } } """ - return xla.call_module([x,], version=1, + return xla.call_module([x,], version=2, + module=module, + Tout=[res.dtype], + Sout=[(None,)], + dim_args_spec=['0.0']) + + self._assertOpOutputMatchesExpected(f, (x,), (res,)) + + @unittest.skip('TODO(burmako): Shape inference leaves dynamic_reshape') + def test_dynamic_reshape(self): + x = np.ones((4, 3), dtype=np.float32) + res = x.reshape((-1,)) + + def f(x): # x: f32[b, 3] + module = """ +module @jit_fun_flat_jax { + func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = stablehlo.constant dense<3> : tensor + %1 = stablehlo.multiply %arg0, %0 : tensor + %2 = stablehlo.reshape %1 : (tensor) -> tensor<1xi32> + %3 = stablehlo.dynamic_reshape %arg1, %2 : (tensor, tensor<1xi32>) -> tensor + return %3 : tensor + } +} +""" + return xla.call_module([x], module=module, Tout=[res.dtype], Sout=[(None,)], @@ -226,6 +356,121 @@ def f(x): # x: f32[b, 5] self._assertOpOutputMatchesExpected(f, (x,), (res,)) + @unittest.skip('TODO(burmako): Shape inference adds tf.Cast') + def test_dynamic_reshape_cast(self): + x = np.ones((4, 2, 3), dtype=np.float32) + res = np.sin(x).reshape((4, -1)) + + def f(x): # x: f32[b, 2, 3] + module = """ +module @jit_fun_flat_jax { + func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = stablehlo.sine %arg1 : tensor + %1 = stablehlo.reshape %arg0 : (tensor) -> tensor<1xi32> + %2 = stablehlo.constant dense<6> : tensor<1xi32> + %3 = stablehlo.concatenate %1, %2, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %4 = stablehlo.dynamic_reshape %0, %3 : (tensor, tensor<2xi32>) -> tensor + return %4 : tensor + } +} +""" + return xla.call_module([x], + module=module, + Tout=[res.dtype], + Sout=[(None, 6)], + dim_args_spec=['0.0']) + + self._assertOpOutputMatchesExpected(f, (x,), (res,)) + + @unittest.skip('TODO(burmako): Crash in simplifyDynamicGatherToGather()') + def 
test_dynamic_gather(self): + x = np.ones((3, 4), dtype=np.float32) + idx = np.array([2, 2], np.int32) + res = x[idx] + + def f(x): # x: f32[b, 4] + module = """ +module @jit_fun_flat_jax { + func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = stablehlo.constant dense<0> : tensor + %1 = stablehlo.constant dense<0> : tensor<1xi64> + %2 = stablehlo.reshape %arg0 : (tensor) -> tensor<1xi32> + %3 = stablehlo.constant dense<2> : tensor<1xi32> + %4 = stablehlo.concatenate %2, %3, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %5 = "stablehlo.dynamic_gather"(%arg1, %1, %4) {dimension_numbers = #stablehlo.gather, indices_are_sorted = true} : (tensor, tensor<1xi64>, tensor<2xi32>) -> tensor + return %5 : tensor + } +} +""" + return xla.call_module([x], + module=module, + Tout=[res.dtype], + Sout=[(None, 2)], + dim_args_spec=['0.0']) + + self._assertOpOutputMatchesExpected(f, (x,), (res,)) + + @unittest.skip('TODO(burmako): Shape inference leaves real_dynamic_slice') + def test_real_dynamic_slice(self): + x = np.ones((3, 4), dtype=np.float32) + res = x[-1, :] # TODO(necula): adjust this, if not the right result + + def f(x): # x: f32[b, 4] + module = """ +module @jit_fun_flat_jax { + func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor<4xf32> { + %0 = stablehlo.constant dense<-1> : tensor + %1 = stablehlo.add %arg0, %0 : tensor + %2 = stablehlo.reshape %1 : (tensor) -> tensor<1xi32> + %3 = stablehlo.constant dense<0> : tensor<1xi32> + %4 = stablehlo.concatenate %2, %3, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %5 = stablehlo.reshape %arg0 : (tensor) -> tensor<1xi32> + %6 = stablehlo.constant dense<4> : tensor<1xi32> + %7 = stablehlo.concatenate %5, %6, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %10 = stablehlo.constant dense<1> : tensor<2xi32> + %11 = stablehlo.real_dynamic_slice %arg1, %4, %7, %10 : (tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x4xf32> + %12 = stablehlo.reshape %11 : (tensor<1x4xf32>) -> tensor<4xf32> + return %12 : tensor<4xf32> + } +} +""" + return xla.call_module([x], + module=module, + Tout=[x.dtype], + Sout=[(4,)], + dim_args_spec=['0.0']) + + self._assertOpOutputMatchesExpected(f, (x,), (res,)) + + @unittest.skip('TODO(burmako): Module verification with dynamic_update_slice') + def test_dynamic_update_slice(self): + x = np.ones((3, 4), dtype=np.float32) + idx = np.int32(-2) + res = x # The update should be a nop + + def f(x, idx): # x: f32[b, 4] idx: i32 + module = """ +module @jit_fun_flat_jax { + func.func public @main(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = stablehlo.constant dense<0> : tensor + %1 = stablehlo.compare LT, %arg2, %0, SIGNED : (tensor, tensor) -> tensor + %2 = stablehlo.add %arg2, %arg0 : tensor + %3 = stablehlo.select %1, %2, %arg2 : tensor, tensor + %4 = stablehlo.constant dense<0> : tensor + %5 = stablehlo.dynamic_update_slice %arg1, %arg1, %3, %4 : (tensor, tensor, tensor, tensor) -> tensor + return %5 : tensor + } +} +""" + return xla.call_module([x, idx], + module=module, + Tout=[res.dtype], + Sout=[(None, 4)], + dim_args_spec=['0.0']) + + self._assertOpOutputMatchesExpected(f, (x, idx), (res,)) + + @unittest.skip('TODO(burmako): Re-enable this after shape refinement is done') def test_dynamic_broadcast_in_dim(self): x = np.ones((3, 4), dtype=np.float32) y = np.ones((2, 3, 4), dtype=np.float32) @@ -236,17 +481,17 @@ def f(x, y): # x: f32[b, 4] y: f32[2, b, 4] module = """ module @jit_fun.0 { func.func public @main(%arg0: tensor, 
%arg1: tensor, %arg2: tensor<2x?x4xf32>) -> (tensor<2x?x4xf32>, tensor<2x?x4xf32>) { - %0 = mhlo.constant dense<2> : tensor<1xi32> - %2 = mhlo.reshape %arg0 : (tensor) -> tensor<1xi32> - %3 = mhlo.constant dense<4> : tensor<1xi32> - %4 = "mhlo.concatenate"(%0, %2, %3) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> - %5 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %4) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor, tensor<3xi32>) -> tensor<2x?x4xf32> - %6 = mhlo.add %5, %arg2 : (tensor<2x?x4xf32>, tensor<2x?x4xf32>) -> tensor<2x?x4xf32> + %0 = stablehlo.constant dense<2> : tensor<1xi32> + %2 = stablehlo.reshape %arg0 : (tensor) -> tensor<1xi32> + %3 = stablehlo.constant dense<4> : tensor<1xi32> + %4 = "stablehlo.concatenate"(%0, %2, %3) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> + %5 = "stablehlo.dynamic_broadcast_in_dim"(%arg1, %4) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor, tensor<3xi32>) -> tensor<2x?x4xf32> + %6 = stablehlo.add %5, %arg2 : (tensor<2x?x4xf32>, tensor<2x?x4xf32>) -> tensor<2x?x4xf32> return %5, %6 : tensor<2x?x4xf32>, tensor<2x?x4xf32> } } """ - return xla.call_module([x, y], version=1, + return xla.call_module([x, y], version=2, module=module, Tout=[res[0].dtype, res[1].dtype], Sout=[(2, None, 4), (2, None, 4)], @@ -274,7 +519,7 @@ def f(x): # x: i32[b] } } """ - return xla.call_module([x], version=2, + return xla.call_module([x], version=1, module=module, Tout=[res.dtype], Sout=[res.shape], @@ -282,6 +527,38 @@ def f(x): # x: i32[b] self._assertOpOutputMatchesExpected(f, (x,), (res,)) + @unittest.skip('TODO(burmako): tf.Cast added after reduce') + def test_reduce_broadcast(self): + x = np.broadcast_to(np.arange(3, dtype=np.float32).reshape(3, 1), (3, 5)) + res = np.any(x, axis=1) # TODO(necula): not sure this should be the result + + def f(x): # x: f32[b, 5] + module = """ +module @jit_fun_flat_jax { + func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = stablehlo.constant dense<0.000000e+00> : tensor + %1 = stablehlo.reduce(%arg1 init: %0) across dimensions = [1] : (tensor, tensor) -> tensor + reducer(%arg2: tensor, %arg3: tensor) { + %6 = stablehlo.add %arg2, %arg3 : tensor + stablehlo.return %6 : tensor + } + %2 = stablehlo.reshape %arg0 : (tensor) -> tensor<1xi32> + %3 = stablehlo.constant dense<1> : tensor<1xi32> + %4 = stablehlo.concatenate %2, %3, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %5 = stablehlo.dynamic_broadcast_in_dim %1, %4, dims = [0] : (tensor, tensor<2xi32>) -> tensor + return %5 : tensor + } +} +""" + return xla.call_module([x,], + module=module, + Tout=[res.dtype], + Sout=[(None, 1)], + dim_args_spec=['0.0']) + + self._assertOpOutputMatchesExpected(f, (x,), (res,)) + + @unittest.skip('TODO(burmako): Re-enable this after shape refinement is done') def test_call(self): """A chain of calls.""" x = np.ones((5,), dtype=np.float32) @@ -295,13 +572,13 @@ def f(x): # x: f32[b] return %0 : tensor } func.func private @f(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = mhlo.reshape %arg0 : (tensor) -> tensor<1xi32> - %1 = "mhlo.dynamic_iota"(%0) {iota_dimension = 0 : i64} : (tensor<1xi32>) -> tensor + %0 = stablehlo.reshape %arg0 : (tensor) -> tensor<1xi32> + %1 = "stablehlo.dynamic_iota"(%0) {iota_dimension = 0 : i64} : (tensor<1xi32>) -> tensor return %1 : tensor } } """ - return xla.call_module([x,], version=1, + return xla.call_module([x,], version=2, module=module, Tout=[res.dtype], Sout=[()], @@ 
-309,6 +586,67 @@ def f(x): # x: f32[b] self._assertOpOutputMatchesExpected(f, (x,), (res,)) + def test_identity(self): + x = np.ones((5,), dtype=np.float32) + res = x + + def f(x): # x: f32[b] + module = """ +module @jit_fun_3 { + func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { + return %arg1 : tensor + } +} +""" + return xla.call_module([x], + version=2, + module=module, + Tout=[res.dtype], + Sout=[()], + dim_args_spec=['0.0']) + + self._assertOpOutputMatchesExpected(f, (x,), (res,)) + + @unittest.skip('TODO(burmako): Shape inference failure for while') + def test_while(self): + """A while loop with carryied dynamic shapes.""" + x = np.ones((5,), dtype=np.float32) + # Compute the result in Pyton first + res0 = x + for i in range(5): + res0 += np.arange(x.shape[0], dtype=np.float32) + res1 = np.int64(i) + + def f(x): # x: f32[b] + module = """ +module @jit_fun_flat_jax { + func.func public @main(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + %0 = stablehlo.constant dense<0> : tensor + %1:2 = stablehlo.while(%iterArg = %arg1, %iterArg_0 = %0) : tensor, tensor + cond { + %2 = stablehlo.constant dense<5> : tensor + %3 = stablehlo.compare LT, %iterArg_0, %2, SIGNED : (tensor, tensor) -> tensor + stablehlo.return %3 : tensor + } do { + %2 = stablehlo.reshape %arg0 : (tensor) -> tensor<1xi32> + %3 = stablehlo.dynamic_iota %2, dim = 0 : (tensor<1xi32>) -> tensor + %4 = stablehlo.add %iterArg, %3 : tensor + %5 = stablehlo.constant dense<1> : tensor + %6 = stablehlo.add %iterArg_0, %5 : tensor + stablehlo.return %4, %6 : tensor, tensor + } + return %1#0, %1#1 : tensor, tensor + } +} +""" + return xla.call_module([x,], version=2, + module=module, + Tout=[res0.dtype, res1.dtype], + Sout=[(None,), res1.shape], + dim_args_spec=['0.0']) + + self._assertOpOutputMatchesExpected(f, (x,), (res0, res1)) + if __name__ == '__main__': # This test is using Tensorflow sessions which are not compatible with eager diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py index bb075fddc36..17055f4070d 100644 --- a/tensorflow/compiler/tests/xla_ops_test.py +++ b/tensorflow/compiler/tests/xla_ops_test.py @@ -636,7 +636,7 @@ def testDynamicSliceWithIncorrectStartIndicesShape(self): session.run(output) self.assertRegex( invalid_arg_error.exception.message, - (r'op has mismatched number of slice sizes \(3\) and number of start' + (r'has mismatched number of slice sizes \(3\) and number of start' r' indices \(2\)')) def testDynamicSliceWithIncorrectSizeIndicesShape(self): @@ -649,7 +649,7 @@ def testDynamicSliceWithIncorrectSizeIndicesShape(self): session.run(output) self.assertRegex( invalid_arg_error.exception.message, - (r'op has mismatched number of slice sizes \(2\) and number of start' + (r'has mismatched number of slice sizes \(2\) and number of start' r' indices \(3\)')) def test_optimization_barrier(self): diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index a83709e6b98..25dc03b4ee5 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -28,6 +28,7 @@ load( ) package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:public"], features = [ "-layering_check", @@ -99,7 +100,7 @@ tf_cuda_cc_test( deps = [ ":trt_logging", ":utils", - "//tensorflow/core/common_runtime/gpu:gpu_init", + "//tensorflow/compiler/xla/stream_executor/gpu:gpu_init", "//tensorflow/core:lib", 
"//tensorflow/core/platform:stream_executor", "//tensorflow/core:test", @@ -618,6 +619,11 @@ tf_cuda_library( ":utils", ":op_converter", "//tensorflow/core:lib", + "//tensorflow/core/platform:env", + "//tensorflow/core/platform:logging", + "//tensorflow/core/platform:status", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", ] + if_static([":op_converter_registry_impl"]), ) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc index 608d88bb942..79a60d2b1de 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc @@ -785,13 +785,17 @@ Status ConvertGraph(const TRTOptimizationPass::ConversionParams& params, params.use_calibration, params.use_implicit_batch, params.use_explicit_precision); TF_RETURN_IF_ERROR(segment::SegmentGraph( - &graph, &static_graph_properties, + /*tf_graph=*/&graph, + /*graph_properties=*/&static_graph_properties, + /*candidate_fn=*/ std::bind(&TrtNodeValidator::IsTensorRTCandidate, &validator, std::placeholders::_1), // Input validation is already done by TrtNodeValidator, so we don't // need to check the input edges. - [](const Edge* edge) { return true; }, OutputEdgeValidator(), - segment_options, &initial_segments)); + /*input_candidate_fn=*/[](const Edge* edge) { return true; }, + /*output_candidate_fn=*/OutputEdgeValidator(), + /*options=*/segment_options, + /*segments=*/&initial_segments)); LOG(INFO) << "Number of TensorRT candidate segments: " << initial_segments.size(); diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index 2ea9b47b7ef..584b9c867c5 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -2098,7 +2098,11 @@ Status ConvertConv2DHelper(const OpConverterParams* params, int group, if (params->use_explicit_precision) { TRT_ENSURE(inputs.at(1).is_tensor()); - conv_layer->setInput(1, *inputs.at(1).tensor()->trt_tensor()); + nvinfer1::IShuffleLayer* layer = params->converter->network()->addShuffle( + *inputs.at(1).tensor()->trt_tensor()); + layer->setFirstTranspose({3, 2, 0, 1}); + layer->setReshapeDimensions({4, {0, 0, 0, 0}}); + conv_layer->setInput(1, *layer->getOutput(0)); } params->converter->SetLayerName(conv_layer, node_def, "conv"); @@ -3649,6 +3653,24 @@ Status ConvertIdentity(const OpConverterParams* params) { return OkStatus(); } +// This converter is a debug-only feature designed to allow graph segmentation +// experiments. Its use is being controled by +// `TF_TRT_OP_FAKELIST=OpName1,OpName2,...`. +// See `op_converter_registry.cc` for further details. +// +// This converter is designed as followed: +// - always succeed at graph segmentation time. +// - always fail at TRT Engine build time. +Status ConvertFake(const OpConverterParams* params) { + if (params->validation_only) return OkStatus(); + + return errors::Unimplemented( + "This converter is not valid after graph " + "segmentation. 
Building an engine using this " + "converter will trigger a native segment " + "fallback."); +} + Status ConvertSquare(const OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; @@ -5718,7 +5740,9 @@ Status ConvertAddN(const OpConverterParams* params) { tensor_inputs.push_back(input.tensor()); } else { auto dims = input.weights().Shape(); - TF_RETURN_IF_ERROR(dims.RemoveBatchDimension()); + if (params->use_implicit_batch) { + TF_RETURN_IF_ERROR(dims.RemoveBatchDimension()); + } tensor_inputs.push_back(params->converter->CreateConstantLayer( input.weights(), dims.AsTrtDims())); } @@ -5795,6 +5819,8 @@ REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertIdentity, "StopGradient", "_CopyFromHostToGpu"}); REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertBatchMatMul, {"BatchMatMul", "BatchMatMulV2"}); +// Debug converter only accessible via `TF_TRT_OP_FAKELIST=OpName1,OpName2,...` +REGISTER_DEFAULT_TRT_OP_CONVERTER(ConvertFake, "FakeOp"); Status ConvertGraphDefToEngine( const GraphDef& gdef, OpKernelContext* ctx, TrtPrecisionMode precision_mode, diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h index db67e6e358b..cf3c6ad22d4 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h @@ -31,7 +31,6 @@ limitations under the License. #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_tensor_proxy.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/grappler/costs/graph_properties.h" @@ -44,7 +43,7 @@ namespace tensorflow { namespace tensorrt { namespace convert { -using ::stream_executor::port::StatusOr; +using ::tsl::StatusOr; struct EngineConnection { // Constructs a non-control edge. diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index 170a707f88b..a52ceae693c 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -5034,8 +5034,7 @@ TEST_P(OpConverter_FP32_FP16_INT32_Test, ConvertStridedSlice) { }; // Same input is used for all tests. - const std::vector ok_input = {1, 2, 3, 4, 5, 6}; - + const std::vector ok_input = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; Status modified_batch_dim_status = (trt_mode_ == TrtTestMode::kImplicitBatch) ? 
errors::Unimplemented( @@ -5712,6 +5711,48 @@ TEST_P(OpConverter_FP32_FP16_INT32_Test, ConvertStridedSlice) { "new_axis_mask is not supported for StridedSlice"), /*runtime_status=*/OkStatus(), /*partial_input_dims=*/{1, 6}}, + // Test all axes dynamic inputs with shrink_axis_mask + TestParams{/*input_dims=*/{1, 3, 2}, + /*begin=*/{0, 0, 0}, + /*end=*/{0, 0, 3}, + /*strides=*/{1, 1, 1}, + /*begin_mask=*/get_mask({0, 1, 1}), + /*end_mask=*/get_mask({0, 1, 1}), + /*ellipsis_mask=*/0, + /*new_axis_mask=*/0, + /*shrink_axis_mask=*/1, + /*expected_output_dims=*/{3, 2}, + /*expected_output=*/{1, 2, 3, 4, 5, 6}, + /*conversion_status=*/modified_batch_dim_status, OkStatus(), + /*partial_input_dims=*/{-1, -1, -1}}, + // Test dynamic input with shrink_axis_mask along axis=0 + TestParams{/*input_dims=*/{2, 3, 2}, + /*begin=*/{0, 0, 0}, + /*end=*/{0, 0, 3}, + /*strides=*/{1, 1, 1}, + /*begin_mask=*/get_mask({0, 1, 1}), + /*end_mask=*/get_mask({0, 1, 1}), + /*ellipsis_mask=*/0, + /*new_axis_mask=*/0, + /*shrink_axis_mask=*/1, + /*expected_output_dims=*/{3, 2}, + /*expected_output=*/{1, 2, 3, 4, 5, 6}, + /*conversion_status=*/modified_batch_dim_status, OkStatus(), + /*partial_input_dims=*/{-1, -1, 2}}, + // Test dynamic input sizes with multiple axes shrinking + TestParams{/*input_dims=*/{2, 3, 2}, + /*begin=*/{0, 0, 0}, + /*end=*/{0, 0, 3}, + /*strides=*/{1, 1, 1}, + /*begin_mask=*/get_mask({0, 1, 1}), + /*end_mask=*/get_mask({0, 1, 1}), + /*ellipsis_mask=*/0, + /*new_axis_mask=*/0, + /*shrink_axis_mask=*/3, + /*expected_output_dims=*/{2}, + /*expected_output=*/{1, 2}, + /*conversion_status=*/modified_batch_dim_status, OkStatus(), + /*partial_input_dims=*/{-1, -1, 2}}, }; int i = 0; @@ -5737,7 +5778,6 @@ TEST_P(OpConverter_FP32_FP16_INT32_Test, ConvertStridedSlice) { if (p.partial_input_dims.size() > 0) { AddTestTensor("input", p.input_dims, tf_type_, ok_input, p.partial_input_dims); - } else { AddTestTensor("input", p.input_dims, tf_type_, ok_input, p.input_dims); diff --git a/tensorflow/compiler/tf2tensorrt/convert/op_converter_registry.cc b/tensorflow/compiler/tf2tensorrt/convert/op_converter_registry.cc index b1c47282845..b33bd3f6a33 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/op_converter_registry.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/op_converter_registry.cc @@ -15,9 +15,15 @@ limitations under the License. #include "tensorflow/compiler/tf2tensorrt/convert/op_converter_registry.h" #include +#include #include +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/util/env_var.h" #if GOOGLE_CUDA && GOOGLE_TENSORRT @@ -57,7 +63,37 @@ class OpConverterRegistry::Impl { return {}; } - StatusOr LookUp(const string& name) { + StatusOr LookUp(string name) { + // Fetch the user-provide TF operations denylisted for conversion by TF-TRT. 
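+    // The variable holds a comma-separated list of TF op names. Purely as an
+    // illustration, `TF_TRT_OP_FAKELIST=Conv2D,BiasAdd` would route those two
+    // (example) op names to the fake converter below.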
+ static const absl::flat_hash_set<string> tftrt_op_fakelist = [] { + string tftrt_op_fakelist_str; + TF_CHECK_OK(ReadStringFromEnvVar("TF_TRT_OP_FAKELIST", + /*default_value=*/"", + &tftrt_op_fakelist_str)); + absl::flat_hash_set<string> tftrt_op_fakelist{}; + for (const auto& x : str_util::Split(tftrt_op_fakelist_str, ",")) { + tftrt_op_fakelist.insert(x); + } + // Force a rehash of the flat hash set + tftrt_op_fakelist.rehash(0); + return tftrt_op_fakelist; + }(); + + // In case the TensorFlow OP `name` matches any of the names passed to the + // TF_TRT_OP_FAKELIST environment variable, force ::LookUp to resolve to + // the ConvertFake OP converter. + if (tftrt_op_fakelist.contains(name)) { + LOG_FIRST_N(INFO, 2) << "Emulating OP Converter: `" << name << "`. It " + << "will cause TRT engine building to fail. This " + << "feature is only intended to be used for " + << "TF-TRT graph segmentation experiments. This " + << "feature is controlled using: " + << "`TF_TRT_OP_FAKELIST=OpName1,OpName2`."; + // Forces ::LookUp to resolve to `ConvertFake` registered to `FakeOp`. + mutex_lock lock(mu_); + return registry_.find("FakeOp")->second.converter; + } + mutex_lock lock(mu_); auto found = registry_.find(name); if (found != registry_.end()) { diff --git a/tensorflow/compiler/tf2tensorrt/convert/ops/fill_ops.cc b/tensorflow/compiler/tf2tensorrt/convert/ops/fill_ops.cc index adf0796840b..fc5fc589211 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/ops/fill_ops.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/ops/fill_ops.cc @@ -288,7 +288,7 @@ class ConvertRange : public ConvertFillBase { }; std::string convert_range_error_msg(float start, float limit, float delta) { - const char* format_string = + constexpr const char* format_string = "For parameters (start, limit) = (%.2f, %.2f) " "of the Range operation delta cannot be %s, got %.2f"; return absl::StrFormat(format_string, start, limit, diff --git a/tensorflow/compiler/tf2tensorrt/convert/ops/slice_ops.cc b/tensorflow/compiler/tf2tensorrt/convert/ops/slice_ops.cc index 2e3fd920b16..19305195016 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/ops/slice_ops.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/ops/slice_ops.cc @@ -72,7 +72,11 @@ Status ConvertStridedSliceHelper( size_dims.d[i] = (std::abs(end_dims->dim(i) - begin_dims->dim(i)) + std::abs(stride_dims->dim(i)) - 1) / std::abs(stride_dims->dim(i)); - + // When begin tensor has negative values, currently range can't be computed. + if (begin_dims->dim(i) < 0) { + return errors::Unimplemented( + "Negative values in begin weight tensor are unsupported"); + } if (input_dims.dim_size(i) < 0) { // end_dims and begin_dims do not have valid information yet. dynamic_input_size_indices.push_back(i); @@ -103,14 +107,6 @@ Status ConvertStridedSliceHelper( params->converter->network(), params->weight_store); TRT_ENSURE_OK(builder); - // VLOG(2) << "strided slice helper:" - // << " begin:" << DebugString(begin_dims) - // << "\n stride: " << DebugString(stride_dims) - // << "\n end: " << DebugString(end_dims) - // << "\n size: " << DebugString(size_dims) - // << "\n Dynamic indices: " << - // DebugString(dynamic_input_size_indices) - // << "\n Static indices: " << DebugString(static_input_size_indices); // Create the slice operation. For dynamic dims, the inputs of the operations // may be reassigned later. StatusOr slice = @@ -130,11 +126,33 @@ Status ConvertStridedSliceHelper( op_instance); ITensorProxyPtr tensor = (*slice)->getOutput(0); - // Reshape for shrink_axis. 
+ // Reshape for shrink axis, ellipsis masks based on the shape computed by + // ValidateStridedSliceOp or HandleDynamicStridedSliceInput. + nvinfer1::Dims dims = tensor->trt_tensor()->getDimensions(); + std::vector slice_input_dims(dims.d, dims.d + dims.nbDims); + StridedSliceShapeSpec empty_spec; + empty_spec.shrink_axis_dense_mask = 0; + auto shrink_axis_mask = + strided_slice_spec.value_or(empty_spec).shrink_axis_dense_mask; if (final_shape) { - TF_RETURN_IF_ERROR(PrepareTensorForShape( - params->converter, TRT_TensorOrWeights(tensor), *final_shape, - /*validation_only=*/false, &tensor, node_def, op_instance)); + if (shrink_axis_mask) { + int shrink_idx = params->use_implicit_batch ? 1 : 0; + const auto bShrink_axis_mask = std::bitset<32>(shrink_axis_mask); + for (int idx = 0; idx < slice_input_dims.size(); ++idx, ++shrink_idx) { + const bool shrink_axis = bShrink_axis_mask[shrink_idx]; + if (shrink_axis) { + slice_input_dims[idx] = 0; + } + } + TF_RETURN_IF_ERROR(params->converter->SqueezeTensor( + tensor, &slice_input_dims, params, &tensor, op_instance)); + } else { + /* To do: pmajety: + Remove the else condition when shrink_axis_mask is always defined */ + TF_RETURN_IF_ERROR(PrepareTensorForShape( + params->converter, TRT_TensorOrWeights(tensor), *final_shape, + /*validation_only=*/false, &tensor, node_def, op_instance)); + } } params->outputs->push_back(TRT_TensorOrWeights(tensor)); return OkStatus(); @@ -152,6 +170,39 @@ Status HandleDynamicStridedSliceInput( nvinfer1::ITensor* input_tensor = slice_layer->getInput(0); TRT_ENSURE(input_tensor); + // When begin_mask or end_mask are set, we have to disregard the begin_tensor + // and end_tensor values. In static indices cases, ValidateStridedSliceOp + // returns the correct begin_tensor and end_tensor values, however with + // dynamic indices the correct shape has to be computed. + + VLOG(3) << "begin_dims before: " << DebugString(begin_dims); + VLOG(3) << "end_dims before: " << DebugString(end_dims); + const auto begin_mask = std::bitset<32>(strided_slice_spec.begin_dense_mask); + const auto end_mask = std::bitset<32>(strided_slice_spec.end_dense_mask); + const auto shrink_axis_mask = + std::bitset<32>(strided_slice_spec.shrink_axis_dense_mask); + nvinfer1::Dims dims = input_tensor->getDimensions(); + + for (int idx = 0; idx < dims.nbDims; ++idx) { + VLOG(3) << "begin_mask[" << idx << "]: " << begin_mask[idx]; + VLOG(3) << "end_mask[" << idx << "]: " << end_mask[idx]; + VLOG(3) << "shrink_mask[" << idx << "]: " << shrink_axis_mask[idx]; + if (begin_mask[idx]) { + begin_dims.d[idx] = 0; + } + if (end_mask[idx]) { + end_dims.d[idx] = dims.d[idx]; + } + if (shrink_axis_mask[idx]) { + end_dims.d[idx] = begin_dims.d[idx] + 1; + } + } + + VLOG(2) << "begin_dims after shrink_axis_mask correction: " + << DebugString(begin_dims); + VLOG(2) << "end_dims after shrink_axis_mask correction: " + << DebugString(end_dims); + // For each dynamic input dimension of the input, do some preprocessing based // on whether this dimension is set in "begin_mask" or "end_mask" and the sign // of the dimension's stride value. @@ -167,8 +218,7 @@ Status HandleDynamicStridedSliceInput( // dynamic size of dimension "dynamic_idx". 
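 // As a concrete illustration: when shrink_axis_mask is set on an axis, the
 // correction loop above has already pinned end = begin + 1 for it, so the
 // end_mask branch below deliberately skips that axis
 // (end_mask[dynamic_idx] && !shrink_axis_mask[dynamic_idx]) and only the
 // remaining dynamic axes take their end bound from the runtime shape.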
absl::InlinedVector dynamic_begin_indices; absl::InlinedVector dynamic_end_indices; - const auto begin_mask = std::bitset<32>(strided_slice_spec.begin_dense_mask); - const auto end_mask = std::bitset<32>(strided_slice_spec.end_dense_mask); + for (int i = 0; i < dynamic_input_size_indices.size(); i++) { auto dynamic_idx = dynamic_input_size_indices[i]; if (begin_mask[dynamic_idx]) { @@ -177,7 +227,7 @@ Status HandleDynamicStridedSliceInput( dynamic_begin_indices.push_back(dynamic_idx); } } - if (end_mask[dynamic_idx]) { + if (end_mask[dynamic_idx] && !shrink_axis_mask[dynamic_idx]) { end_dims.d[dynamic_idx] = stride_dims.d[dynamic_idx] > 0 ? 0 : -1; if (stride_dims.d[dynamic_idx] > 0) { dynamic_end_indices.push_back(dynamic_idx); @@ -185,8 +235,8 @@ Status HandleDynamicStridedSliceInput( } } - // VLOG(2) << " Dynamic begin indices: " << DebugString(dynamic_begin_indices) - // << " Dynamic end indices: " << DebugString(dynamic_end_indices); + VLOG(2) << " Dynamic begin indices: " << DebugString(dynamic_begin_indices) + << " Dynamic end indices: " << DebugString(dynamic_end_indices); // Create ITensors for each of the begin/stride/end constants. StatusOr begin_const = builder->Constant( diff --git a/tensorflow/compiler/tf2tensorrt/tensorrt_test.cc b/tensorflow/compiler/tf2tensorrt/tensorrt_test.cc index 3ef38388edd..6e5366f6921 100644 --- a/tensorflow/compiler/tf2tensorrt/tensorrt_test.cc +++ b/tensorflow/compiler/tf2tensorrt/tensorrt_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2tensorrt/common/utils.h" #include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" -#include "tensorflow/core/common_runtime/gpu/gpu_init.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_init.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/platform/test.h" @@ -241,7 +241,7 @@ TEST(TensorrtTest, BasicFunctions) { #endif // Handle the case where the test is run on machine with no gpu available. - if (CHECK_NOTNULL(GPUMachineManager())->VisibleDeviceCount() <= 0) { + if (CHECK_NOTNULL(se::GPUMachineManager())->VisibleDeviceCount() <= 0) { LOG(WARNING) << "No gpu device available, probably not being run on a gpu " "machine. Skipping..."; return; diff --git a/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc b/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc index 1481856e066..109339d0af9 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc @@ -35,10 +35,7 @@ bool IsGoogleTensorRTEnabled() { #else // TF_USE_TENSORRT_STATIC auto handle_or = se::internal::DsoLoader::TryDlopenTensorRTLibraries(); if (!handle_or.ok()) { - LOG_WARNING_WITH_PREFIX - << "Cannot dlopen some TensorRT libraries. If you would like " - "to use Nvidia GPU with TensorRT, please make sure the " - "missing libraries mentioned above are installed properly."; + LOG_WARNING_WITH_PREFIX << "Could not find TensorRT"; } return handle_or.ok(); #endif // TF_USE_TENSORRT_STATIC diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.h b/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.h index e51f5e75102..b0935afb5b2 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.h +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.h @@ -22,7 +22,6 @@ limitations under the License. 
#include "tensorflow/compiler/tf2tensorrt/common/datavec.h" #include "tensorflow/compiler/tf2tensorrt/common/utils.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -33,7 +32,7 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -using ::stream_executor::port::StatusOr; +using ::tsl::StatusOr; // Creates a TensorRT execution context. ExecutionContext CreateExecutionContext(nvinfer1::ICudaEngine* cuda_engine); diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index eefde262851..968882b88c4 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -20,6 +20,7 @@ load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") load("//tensorflow/compiler/xla/service/cpu:build_defs.bzl", "runtime_copts") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [":internal"], licenses = ["notice"], ) @@ -287,11 +288,16 @@ cc_library( "//third_party/eigen3", "//tensorflow/core/framework:numeric_types", "//tensorflow/core/platform:bfloat16", + "//tensorflow/core/platform:float8", "//tensorflow/core/platform:stringpiece", # Extra dependencies required for multithreaded runtime objects. "//tensorflow/core/platform:blocking_counter", "//tensorflow/core/platform:logging", "//tensorflow/core/platform:mutex", + "//tensorflow/tsl/platform:blocking_counter", + "//tensorflow/tsl/platform:logging", + "//tensorflow/tsl/platform:tstring", + "//tensorflow/tsl/platform:types", ] + tf_additional_tensor_coding_deps(), alwayslink = 1, ) @@ -313,6 +319,7 @@ cc_library( # "//tensorflow/tsl/platform:byte_order", # "//tensorflow/tsl/platform:cord", # "//tensorflow/tsl/platform:env_time", +# "//tensorflow/tsl/platform:float8", # "//tensorflow/tsl/platform:logging", # "//tensorflow/tsl/platform:macros", # "//tensorflow/tsl/platform:mutex", @@ -337,6 +344,7 @@ cc_library( # "//tensorflow/core/platform:byte_order", # "//tensorflow/core/platform:cord", # "//tensorflow/core/platform:env_time", +# "//tensorflow/core/platform:float8", # "//tensorflow/core/platform:logging", # "//tensorflow/core/platform:macros", # "//tensorflow/core/platform:mutex", @@ -457,6 +465,7 @@ cc_library( "//tensorflow/compiler/jit:common", "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/jit:shape_inference", + "//tensorflow/compiler/jit:xla_compile_util", "//tensorflow/compiler/mlir:array_container_utils", "//tensorflow/compiler/xla/translate/mhlo_to_hlo:layout_util", "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes", @@ -472,8 +481,9 @@ cc_library( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/core/util:overflow", + "//tensorflow/core/tpu:tpu_defs", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -640,8 +650,8 @@ cc_library( "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/hlo/ir:hlo", 
"//tensorflow/compiler/xla/service:computation_placer_hdr", - "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options", "//tensorflow/compiler/xla/stream_executor:stream_executor_headers", "//tensorflow/compiler/xla/translate/mhlo_to_hlo:layout_util", @@ -667,7 +677,7 @@ cc_library( ":host_compute_metadata_proto_cc", ":xla_resource", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/core:framework", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", @@ -907,7 +917,6 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", - "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -1073,6 +1082,7 @@ cc_library( name = "mlir_bridge_pass", srcs = ["mlir_bridge_pass.cc"], hdrs = ["mlir_bridge_pass.h"], + visibility = [":internal"], deps = [ ":tf2xla_defs", "//tensorflow/compiler/jit:flags", @@ -1325,7 +1335,7 @@ cc_library( hdrs = ["mlir_xla_op_kernel.h"], deps = [ ":xla_compiler", - "//tensorflow/compiler/jit:xla_compilation_cache", + "//tensorflow/compiler/jit:xla_compile_util", "//tensorflow/compiler/mlir:array_container_utils", "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes", "@llvm-project//mlir:IR", diff --git a/tensorflow/compiler/tf2xla/cc/BUILD b/tensorflow/compiler/tf2xla/cc/BUILD index fc42992e4b7..0f6754e005b 100644 --- a/tensorflow/compiler/tf2xla/cc/BUILD +++ b/tensorflow/compiler/tf2xla/cc/BUILD @@ -2,6 +2,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.default.bzl", "tf_gen_op_wrapper_cc") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//tensorflow/compiler/tf2xla:friends"], licenses = ["notice"], ) @@ -48,5 +49,6 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/tsl/protobuf:protos_all_cc", ], ) diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 6cebed16049..96787b05baf 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -12,6 +12,7 @@ load( ) package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ # copybara:uncomment "//learning/infra/mira:__subpackages__", "//third_party/cloud_tpu/inference_converter:__subpackages__", diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index 62f3209bcea..a970c873695 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -212,7 +212,24 @@ XLA_MAKE_BINARY( xla::Div(xla::Mul(rhs, XlaHelpers::FloatLiteral(b, input_type(0), 0.5)), lhs, extend_dimensions)); -XLA_MAKE_BINARY(TruncateDiv, xla::Div(lhs, rhs, extend_dimensions)); +// Implementation of TruncateDiv. +// +// For floating-point values, returns trunc(x / y). For integers, simply +// returns x / y. 
+static xla::XlaOp TruncateDivImpl(xla::XlaBuilder* b, DataType dtype, + xla::XlaOp x, xla::XlaOp y, + const BCast& broadcast_helper) { + std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper); + if (!DataTypeIsFloating(dtype)) { + return xla::Div(x, y); + } + auto zero = XlaHelpers::Zero(b, dtype); + auto x_div_y = xla::Div(x, y); + auto round_up = xla::Lt(x_div_y, zero); + return xla::Select(round_up, xla::Ceil(x_div_y), xla::Floor(x_div_y)); +} +XLA_MAKE_BINARY(TruncateDiv, + TruncateDivImpl(b, input_type(0), lhs, rhs, broadcast_helper)); XLA_MAKE_BINARY(TruncateMod, xla::Rem(lhs, rhs, extend_dimensions)); // Comparison ops diff --git a/tensorflow/compiler/tf2xla/kernels/bincount_op.cc b/tensorflow/compiler/tf2xla/kernels/bincount_op.cc index ba95c77f612..5cac4d98c50 100644 --- a/tensorflow/compiler/tf2xla/kernels/bincount_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/bincount_op.cc @@ -62,21 +62,15 @@ class DenseBincountOp : public XlaOpKernel { StatusOr input_shape_or = ctx->builder()->GetShape(input); OP_REQUIRES_OK(ctx, input_shape_or.status()); auto input_shape = input_shape_or.value(); - auto size = input_shape.dimensions(0); - if (!size) { - output = xla::Broadcast(zero, {output_size}); - ctx->SetOutput(0, output); - return; - } auto rank = input_shape.rank(); OP_REQUIRES(ctx, rank <= 2, errors::InvalidArgument( "Shape must be at most rank 2 but is rank ", rank)); - xla::XlaOp weights = ctx->Input(2); StatusOr weights_shape_or = ctx->builder()->GetShape(weights); + OP_REQUIRES_OK(ctx, weights_shape_or.status()); auto weights_shape = weights_shape_or.value(); @@ -91,11 +85,20 @@ class DenseBincountOp : public XlaOpKernel { "1. Received ", weights_shape.DebugString())); + auto size = input_shape.dimensions(0); + + if (!size) { + output = xla::Broadcast(zero, {output_size}); + ctx->SetOutput(0, output); + return; + } + auto weights_size = weights_shape.dimensions(0); bool has_weights = false; if (weights_size) { has_weights = true; } + xla::Shape output_shape = xla::ShapeUtil::MakeShape(dtype, {output_size}); xla::ScatterDimensionNumbers scatter_dnums; scatter_dnums.set_index_vector_dim(1); diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc index 248ab5b5323..20934423141 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -153,7 +153,11 @@ Status ConvBackpropComputeDimensionsV2XlaShapes( } // anonymous namespace -std::vector GetXlaConvTypes() { +std::vector GetXlaConvTypesForNonGpu() { + return {DT_FLOAT, DT_BFLOAT16, DT_HALF, DT_DOUBLE, DT_INT32}; +} + +std::vector GetXlaConvTypesForGpu() { return {DT_FLOAT, DT_BFLOAT16, DT_HALF, DT_DOUBLE}; } diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h index b7643999f81..7922c6ba821 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h @@ -35,9 +35,10 @@ limitations under the License. namespace tensorflow { -// We don't support integers for convolutions, so we list the supported types -// here. -std::vector GetXlaConvTypes(); +// We don't support integers for convolutions for GPU, so we list the supported +// types for non-gpu and gpu here. +std::vector GetXlaConvTypesForNonGpu(); +std::vector GetXlaConvTypesForGpu(); // ConvOpAttrs contains all of the metadata necessary to specify a TF or XLA // convolution. 
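[Editor's note, not part of the patch: the conv_op_helpers.h hunk above splits the old GetXlaConvTypes() into GetXlaConvTypesForNonGpu() (which adds DT_INT32) and GetXlaConvTypesForGpu(), and the next file switches the kernels to a REGISTER_XLA_CONV_OP macro whose definition is not included in this patch section. As a rough, hypothetical sketch only -- assuming the standard REGISTER_XLA_OP builder methods (TypeConstraint, Device) and the DEVICE_*_XLA_JIT constants from xla_op_registry.h -- such a macro could register each conv op twice with device-specific type lists:

// Hypothetical expansion; the real definition is not shown in this patch.
#define REGISTER_XLA_CONV_OP(NAME, OP)                                    \
  /* CPU/TPU backends: integer convolutions allowed (DT_INT32 included) */ \
  REGISTER_XLA_OP(NAME.TypeConstraint("T", GetXlaConvTypesForNonGpu())    \
                      .Device(DEVICE_CPU_XLA_JIT)                         \
                      .Device(DEVICE_TPU_XLA_JIT),                        \
                  OP);                                                     \
  /* GPU backend: floating-point types only */                            \
  REGISTER_XLA_OP(NAME.TypeConstraint("T", GetXlaConvTypesForGpu())       \
                      .Device(DEVICE_GPU_XLA_JIT),                        \
                  OP);

Under that reading, REGISTER_XLA_CONV_OP(Name("Conv2D"), Conv2DOp) in the following hunks presumably expands to two registrations, so integer convolutions become available everywhere except the XLA GPU backend, which matches the updated comment about GPU not supporting integer convolutions.]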
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index 5a31901142e..1d94cf4969f 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -70,25 +70,21 @@ class Conv2DOp : public ConvOp { explicit Conv2DOp(OpKernelConstruction* ctx) : ConvOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) {} }; -REGISTER_XLA_OP(Name("Conv2D").TypeConstraint("T", GetXlaConvTypes()), - Conv2DOp); +REGISTER_XLA_CONV_OP(Name("Conv2D"), Conv2DOp); class Conv3DOp : public ConvOp { public: explicit Conv3DOp(OpKernelConstruction* ctx) : ConvOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) {} }; -REGISTER_XLA_OP(Name("Conv3D").TypeConstraint("T", GetXlaConvTypes()), - Conv3DOp); +REGISTER_XLA_CONV_OP(Name("Conv3D"), Conv3DOp); class DepthwiseConv2DOp : public ConvOp { public: explicit DepthwiseConv2DOp(OpKernelConstruction* ctx) : ConvOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {} }; -REGISTER_XLA_OP( - Name("DepthwiseConv2dNative").TypeConstraint("T", GetXlaConvTypes()), - DepthwiseConv2DOp); +REGISTER_XLA_CONV_OP(Name("DepthwiseConv2dNative"), DepthwiseConv2DOp); // Backprop for input. class ConvBackpropInputOp : public XlaOpKernel { @@ -134,30 +130,27 @@ class Conv2DBackpropInputOp : public ConvBackpropInputOp { explicit Conv2DBackpropInputOp(OpKernelConstruction* ctx) : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) {} }; -REGISTER_XLA_OP(Name("Conv2DBackpropInput") - .CompileTimeConstantInput("input_sizes") - .TypeConstraint("T", GetXlaConvTypes()), - Conv2DBackpropInputOp); +REGISTER_XLA_CONV_OP( + Name("Conv2DBackpropInput").CompileTimeConstantInput("input_sizes"), + Conv2DBackpropInputOp); class Conv3DBackpropInputOp : public ConvBackpropInputOp { public: explicit Conv3DBackpropInputOp(OpKernelConstruction* ctx) : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) {} }; -REGISTER_XLA_OP(Name("Conv3DBackpropInputV2") - .CompileTimeConstantInput("input_sizes") - .TypeConstraint("T", GetXlaConvTypes()), - Conv3DBackpropInputOp); +REGISTER_XLA_CONV_OP( + Name("Conv3DBackpropInputV2").CompileTimeConstantInput("input_sizes"), + Conv3DBackpropInputOp); class DepthwiseConv2DBackpropInputOp : public ConvBackpropInputOp { public: explicit DepthwiseConv2DBackpropInputOp(OpKernelConstruction* ctx) : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {} }; -REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropInput") - .CompileTimeConstantInput("input_sizes") - .TypeConstraint("T", GetXlaConvTypes()), - DepthwiseConv2DBackpropInputOp); +REGISTER_XLA_CONV_OP(Name("DepthwiseConv2dNativeBackpropInput") + .CompileTimeConstantInput("input_sizes"), + DepthwiseConv2DBackpropInputOp); class ConvBackpropFilterOp : public XlaOpKernel { public: @@ -198,10 +191,9 @@ class Conv2DBackpropFilterOp : public ConvBackpropFilterOp { : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) { } }; -REGISTER_XLA_OP(Name("Conv2DBackpropFilter") - .CompileTimeConstantInput("filter_sizes") - .TypeConstraint("T", GetXlaConvTypes()), - Conv2DBackpropFilterOp); +REGISTER_XLA_CONV_OP( + Name("Conv2DBackpropFilter").CompileTimeConstantInput("filter_sizes"), + Conv2DBackpropFilterOp); class Conv3DBackpropFilterOp : public ConvBackpropFilterOp { public: @@ -209,20 +201,18 @@ class Conv3DBackpropFilterOp : public ConvBackpropFilterOp { : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) { } }; 
-REGISTER_XLA_OP(Name("Conv3DBackpropFilterV2") - .CompileTimeConstantInput("filter_sizes") - .TypeConstraint("T", GetXlaConvTypes()), - Conv3DBackpropFilterOp); +REGISTER_XLA_CONV_OP( + Name("Conv3DBackpropFilterV2").CompileTimeConstantInput("filter_sizes"), + Conv3DBackpropFilterOp); class DepthwiseConv2DBackpropFilterOp : public ConvBackpropFilterOp { public: explicit DepthwiseConv2DBackpropFilterOp(OpKernelConstruction* ctx) : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {} }; -REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropFilter") - .CompileTimeConstantInput("filter_sizes") - .TypeConstraint("T", GetXlaConvTypes()), - DepthwiseConv2DBackpropFilterOp); +REGISTER_XLA_CONV_OP(Name("DepthwiseConv2dNativeBackpropFilter") + .CompileTimeConstantInput("filter_sizes"), + DepthwiseConv2DBackpropFilterOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index 7566cc79742..cd03b617158 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -146,6 +146,10 @@ class DynamicStitchOp : public XlaOpKernel { for (int input_num = 0; input_num < indices.size(); input_num++) { for (int i = 0; i < indices[input_num].shape().dimensions(0); ++i) { int index = indices[input_num].Get({i}); + OP_REQUIRES( + ctx, index >= 0, + errors::InvalidArgument("indices[", index, "] is out of range")); + src_input_vector[index] = input_num; src_slice_vector[index] = i; if (!src_index_used[index]) { diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc index 7b46d848e1d..55bce65bd8e 100644 --- a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc @@ -178,11 +178,9 @@ class ExtractImagePatchesOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(ExtractImagePatchesOp); }; -// We don't support integers for the convolution used in the implementation of -// this op, so we limit the supported types. -REGISTER_XLA_OP( - Name("ExtractImagePatches").TypeConstraint("T", GetXlaConvTypes()), - ExtractImagePatchesOp); +// We don't support integers for the convolution for GPU used in the +// implementation of this op, so we limit the supported types. 
+REGISTER_XLA_CONV_OP(Name("ExtractImagePatches"), ExtractImagePatchesOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/if_while_utils.cc b/tensorflow/compiler/tf2xla/kernels/if_while_utils.cc index d60d07b5caa..15314b0434e 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_while_utils.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_while_utils.cc @@ -49,8 +49,7 @@ absl::InlinedVector ConvertCompileTimeConstArgumentsToConst( xla::ValueInferenceMode::kUpperBound); if ((maybe_constant.ok() && maybe_constant->has_value()) || (bounds.ok() && bounds->has_value())) { - StatusOr values_are_dynamic = - expression.ResolveDynamism(ctx->compiler()->client()); + StatusOr values_are_dynamic = expression.ResolveDynamism(); bool all_values_are_static = false; if (values_are_dynamic.ok()) { xla::Literal literal = diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index da15cf27e7c..4abfb149792 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -386,6 +386,13 @@ struct SuppressBodyFn { num_outputs_so_far); // Slice out the row_idx. auto row_iou = xla::DynamicSlice(iou_mask, {row_idx, zero}, {1, num_boxes}); + + TF_ASSIGN_OR_RETURN(auto iou_shape, builder->GetShape(iou_mask)); + auto boxes_runtime_size = xla::GetDimensionSize(row_iou, 1); + if (iou_shape.is_dynamic_dimension(1)) { + row_iou = xla::SetDimensionSize(row_iou, boxes_runtime_size, 1); + } + // Remove the diagonal from consideration. An elem cannot suppress // itself. row_iou = xla::DynamicUpdateSlice( @@ -395,8 +402,12 @@ struct SuppressBodyFn { row_iou = xla::Reshape(row_iou, {num_boxes}); auto supp_mask = xla::Not(row_iou); // Update mask iff current elem is not suppressed. - included_iou = xla::Select(xla::Broadcast(active_elem, {num_boxes}), - xla::And(included_iou, supp_mask), included_iou); + auto cond = xla::Broadcast(active_elem, {num_boxes}); + if (iou_shape.is_dynamic_dimension(1)) { + cond = xla::SetDimensionSize(cond, boxes_runtime_size, 0); + } + included_iou = + xla::Select(cond, xla::And(included_iou, supp_mask), included_iou); row_idx = row_idx + xla::ConstantR0(builder, 1); return std::vector{row_idx, num_outputs_so_far, iou_mask, included_iou}; @@ -485,7 +496,7 @@ class NonMaxSuppressionOp : public XlaOpKernel { const xla::XlaOp indices_sorted = xla::GetTupleElement(indices_sort, 1); const xla::XlaOp scores = xla::GetTupleElement(indices_sort, 0); - // Shapes are henceforth [1, num_boxes]. 'c_y0' denotes 'coordinate' y0. + // Shapes are henceforth [1, <=num_boxes]. 'c_y0' denotes 'coordinate' y0. const xla::XlaOp c_y0 = xla::Reshape(xla::SliceInDim(boxes_sorted, /*start_index=*/0, /*limit_index=*/1, @@ -517,14 +528,14 @@ class NonMaxSuppressionOp : public XlaOpKernel { xla::XlaOp x2 = xla::Select(xla::Le(c_x0, c_x1), c_x1, c_x0); xla::XlaOp area = (y2 - y1) * (x2 - x1); - // Shapes are henceforth [1, num_boxes]. + // Shapes are henceforth [1, <=num_boxes]. y1 = xla::Broadcast(y1, {1}); y2 = xla::Broadcast(y2, {1}); x1 = xla::Broadcast(x1, {1}); x2 = xla::Broadcast(x2, {1}); area = xla::Broadcast(area, {1}); - // Shapes are henceforth [num_boxes, num_boxes]. + // Shapes are henceforth [<=num_boxes, <=num_boxes]. 
xla::XlaOp i_xmin = xla::Max(x1, xla::Transpose(x1, {1, 0})); xla::XlaOp i_ymin = xla::Max(y1, xla::Transpose(y1, {1, 0})); xla::XlaOp i_xmax = xla::Min(x2, xla::Transpose(x2, {1, 0})); @@ -540,6 +551,13 @@ class NonMaxSuppressionOp : public XlaOpKernel { xla::XlaOp included_iou = xla::Broadcast(xla::ConstantR0(builder, true), {num_boxes}); + auto iou_shape_or = builder->GetShape(iou_thresh_mask); + OP_REQUIRES_OK(context, iou_shape_or.status()); + auto boxes_runtime_size = xla::GetDimensionSize(iou_thresh_mask, 1); + if (iou_shape_or.value().is_dynamic_dimension(1)) { + included_iou = xla::SetDimensionSize(included_iou, boxes_runtime_size, 0); + } + std::vector init_values; init_values.reserve(4); init_values.push_back(xla::ConstantR0(builder, 0)); // col_idx diff --git a/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc b/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc index 8e2357a9499..f169d86e8b1 100644 --- a/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc +++ b/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include #include +#include #include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/kernels/callback.pb.h" @@ -60,7 +61,7 @@ const char* const kTfCallbackCustomCall = "GenericTfCallbackGPU"; static StatusOr TensorFromProto(const TensorProto& proto) { Tensor out; if (!out.FromProto(proto)) { - return se::port::InternalError("Failed deserializing a TensorProto"); + return tsl::errors::Internal("Failed deserializing a TensorProto"); } return out; } @@ -108,7 +109,7 @@ Status LightOutsideCompilationOp::CompileToCustomCallCallingTfKernel( if (absl::c_any_of(xla_shape.dynamic_dimensions(), [](const bool is_dynamic) { return is_dynamic; })) { // TODO(cheshire): Support input dynamic dimensions. 
- return se::port::InternalError( + return tsl::errors::Internal( "Input dynamic dimensions are not supported for light outside " "compilation"); } @@ -157,8 +158,7 @@ Status LightOutsideCompilationOp::CompileToCustomCallCallingTfKernel( TensorShapeProto output_tensor_shape_proto = ic.ShapeHandleToProto(ic.output(i)); if (output_tensor_shape_proto.unknown_rank()) { - return se::port::InternalError( - absl::StrCat("Output ", i, " has unknown rank")); + return tsl::errors::Internal("Output ", i, " has unknown rank"); } int rank = output_tensor_shape_proto.dim_size(); @@ -172,8 +172,8 @@ Status LightOutsideCompilationOp::CompileToCustomCallCallingTfKernel( if (dim->size() < 0) { if (it == dimension_bounds.end()) { - return se::port::InternalError(absl::StrCat( - "Bound for unknown dimension not found for dimension ", d)); + return tsl::errors::Internal( + "Bound for unknown dimension not found for dimension ", d); } dim->set_size(it->second); dynamic_dimensions[d] = true; @@ -291,8 +291,12 @@ class TfCallbackDevice : public DeviceBase { const TfCallbackData& callback_data) : DeviceBase(Env::Default()), stream_(stream), +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM gpu_allocator_(GPUProcessState::singleton()->GetGPUAllocator( - tsl::TfDeviceId{stream_->parent()->device_ordinal()})), + *BaseGPUDevice::FindTfDeviceId(stream))), +#else + gpu_allocator_(nullptr), +#endif cpu_allocator_( ProcessState::singleton()->GetCPUAllocator(/*numa_node=*/0)) { for (int i = 0; i < callback_data.outputs_size(); ++i) { @@ -329,8 +333,7 @@ class TfCallbackDevice : public DeviceBase { context, gpu_stream, /*platform_device_id=*/ tsl::PlatformDeviceId(stream_->parent()->device_ordinal()), allocator, - // TODO(cheshire): Pass meaningful scratch - // buffer. + // TODO(cheshire): Pass meaningful scratch buffer. /*scratch=*/nullptr); return OkStatus(); #else @@ -347,7 +350,12 @@ class TfCallbackDevice : public DeviceBase { if (attr.on_host()) { if (attr.gpu_compatible()) { GPUProcessState* ps = GPUProcessState::singleton(); - return ps->GetGpuHostAllocator(0); + // TODO(jlebar): The very first call to GetGpuHostAllocator sets its + // memory limits. So passing {} for the options here means that if + // nobody gets this allocator before us, we will not respect any limits + // the user might have set on host memory allocation. Our call to + // GetGPUAllocator in the constructor has the same problem. + return ps->GetGpuHostAllocator(/*options=*/{}, 0); } else { return cpu_allocator_; } diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc index 8fdcd5d3199..81754a0a767 100644 --- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc @@ -18,7 +18,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/core/util/mirror_pad_mode.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index 7f0d712f13d..64351c6a741 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -33,15 +33,41 @@ limitations under the License. 
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/util/determinism.h" #include "tensorflow/core/util/tensor_format.h" +#include "tensorflow/tsl/platform/errors.h" namespace tensorflow { namespace { +template +static Status ValidateKernelSizes(const T& ksizes) { + for (size_t i = 0; i < ksizes.size(); ++i) { + if (ksizes[i] <= 0) { + return errors::InvalidArgument( + "Sliding window ksize field for dimension ", i, + " must be positive but is ", ksizes[i]); + } + } + return OkStatus(); +} + +template +static Status ValidateStrides(const T& strides) { + for (size_t i = 0; i < strides.size(); ++i) { + if (strides[i] <= 0) { + return errors::InvalidArgument( + "Sliding window stride field for dimension ", i, + " must be positive but is ", strides[i]); + } + } + return OkStatus(); +} + // Superclass of pooling ops. class PoolingOp : public XlaOpKernel { public: @@ -83,50 +109,54 @@ class PoolingOp : public XlaOpKernel { protected: StatusOr> GetKernelSize(XlaOpKernelContext* ctx) { - if (ctx->num_inputs() == 1) { - return ksize_; - } - const TensorShape ksize_shape = ctx->InputShape(1); - // Validate input sizes. - if (!TensorShapeUtils::IsVector(ksize_shape)) { - return errors::InvalidArgument("ksize must be a vector, not shape ", - ksize_shape.DebugString()); - } - if (ksize_shape.num_elements() != num_dims()) { - return errors::InvalidArgument( - "Sliding window ksize field must " - "specify ", - num_dims(), " dimensions"); - } std::vector ksize; - auto status = ctx->ConstantInputAsIntVector(1, &ksize); - if (!status.ok()) { - return status; + if (ctx->num_inputs() == 1) { + ksize = ksize_; + } else { + const TensorShape ksize_shape = ctx->InputShape(1); + // Validate input sizes. + if (!TensorShapeUtils::IsVector(ksize_shape)) { + return errors::InvalidArgument("ksize must be a vector, not shape ", + ksize_shape.DebugString()); + } + if (ksize_shape.num_elements() != num_dims()) { + return errors::InvalidArgument( + "Sliding window ksize field must " + "specify ", + num_dims(), " dimensions"); + } + auto status = ctx->ConstantInputAsIntVector(1, &ksize); + if (!status.ok()) { + return status; + } } + TF_RETURN_IF_ERROR(ValidateKernelSizes(ksize)); return ksize; } StatusOr> GetStride(XlaOpKernelContext* ctx) { - if (ctx->num_inputs() == 1) { - return stride_; - } - const TensorShape stride_shape = ctx->InputShape(2); - // Validate input sizes. - if (!TensorShapeUtils::IsVector(stride_shape)) { - return errors::InvalidArgument("stride must be a vector, not shape ", - stride_shape.DebugString()); - } - if (stride_shape.num_elements() != num_dims()) { - return errors::InvalidArgument( - "Sliding window stride field must " - "specify ", - num_dims(), " dimensions"); - } std::vector stride; - auto status = ctx->ConstantInputAsIntVector(2, &stride); - if (!status.ok()) { - return status; + if (ctx->num_inputs() == 1) { + stride = stride_; + } else { + const TensorShape stride_shape = ctx->InputShape(2); + // Validate input sizes. 
+ if (!TensorShapeUtils::IsVector(stride_shape)) { + return errors::InvalidArgument("stride must be a vector, not shape ", + stride_shape.DebugString()); + } + if (stride_shape.num_elements() != num_dims()) { + return errors::InvalidArgument( + "Sliding window stride field must " + "specify ", + num_dims(), " dimensions"); + } + auto status = ctx->ConstantInputAsIntVector(2, &stride); + if (!status.ok()) { + return status; + } } + TF_RETURN_IF_ERROR(ValidateStrides(stride)); return stride; } @@ -355,10 +385,12 @@ class MaxPoolGradOp : public XlaOpKernel { errors::InvalidArgument("Sliding window ksize field must " "specify ", num_dims(), " dimensions")); + OP_REQUIRES_OK(ctx, ValidateKernelSizes(ksize_)); OP_REQUIRES(ctx, stride_.size() == num_dims(), errors::InvalidArgument("Sliding window strides field must " "specify ", num_dims(), " dimensions")); + OP_REQUIRES_OK(ctx, ValidateStrides(stride_)); const TensorShape tensor_in_shape = ctx->InputShape(0); const TensorShape tensor_out_shape = ctx->InputShape(1); @@ -446,11 +478,13 @@ class AvgPoolGradOp : public XlaOpKernel { errors::InvalidArgument("Sliding window ksize field must " "specify ", num_dims(), " dimensions")); + OP_REQUIRES_OK(ctx, ValidateKernelSizes(ksize_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_)); OP_REQUIRES(ctx, stride_.size() == num_dims(), errors::InvalidArgument("Sliding window strides field must " "specify ", num_dims(), " dimensions")); + OP_REQUIRES_OK(ctx, ValidateStrides(stride_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); OP_REQUIRES(ctx, padding_ != EXPLICIT, errors::Unimplemented( @@ -579,10 +613,12 @@ class MaxPoolGradGradOp : public XlaOpKernel { errors::InvalidArgument("Sliding window ksize field must " "specify ", num_dims(), " dimensions")); + OP_REQUIRES_OK(ctx, ValidateKernelSizes(ksize_)); OP_REQUIRES(ctx, stride_.size() == num_dims(), errors::InvalidArgument("Sliding window strides field must " "specify ", num_dims(), " dimensions")); + OP_REQUIRES_OK(ctx, ValidateStrides(stride_)); const TensorShape tensor_in_shape = ctx->InputShape(0); const TensorShape tensor_out_shape = ctx->InputShape(1); diff --git a/tensorflow/compiler/tf2xla/kernels/resampler_ops.h b/tensorflow/compiler/tf2xla/kernels/resampler_ops.h index a8e78e4b5db..195f41dce76 100644 --- a/tensorflow/compiler/tf2xla/kernels/resampler_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/resampler_ops.h @@ -37,4 +37,4 @@ class ResamplerGradOp : public XlaOpKernel { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_TF2XLA_KERNELS_RESAMPLER_OPS_H_ +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_RESAMPLER_OPS_H_ diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc index a9a67bc3b17..08a8545ec68 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc @@ -42,10 +42,15 @@ class ReverseSequenceOp : public XlaOpKernel { seq_lens_shape.dims())); OP_REQUIRES(context, batch_dim_ != seq_dim_, errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim_)); + OP_REQUIRES(context, seq_dim_ >= 0, + errors::InvalidArgument("seq_dim must be >=0, got ", seq_dim_)); OP_REQUIRES( context, seq_dim_ < input_shape.dims(), errors::InvalidArgument("seq_dim must be < input rank", " ( ", seq_dim_, " vs. 
", input_shape.dims(), ")")); + OP_REQUIRES( + context, batch_dim_ >= 0, + errors::InvalidArgument("batch_dim must be >=0, got ", batch_dim_)); OP_REQUIRES( context, batch_dim_ < input_shape.dims(), errors::InvalidArgument("batch_dim must be < input rank", " ( ", diff --git a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc index c9cb757c913..1812ddab2b6 100644 --- a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc @@ -119,9 +119,10 @@ class ScatterNdOp : public XlaOpKernel { auto updates = context->Input(1); auto combine = context->input_xla_type(1) == xla::PRED ? CombineBool : CombineNum; - auto result = - XlaScatter(buffer, updates, indices, - /*indices_are_vectors=*/true, /*combiner=*/combine, builder); + auto result = XlaScatter(buffer, updates, indices, + /*indices_are_vectors=*/true, + /*indices_are_sorted=*/false, + /*combiner=*/combine, builder); OP_REQUIRES_OK(context, result.status()); context->SetOutput(0, result.value()); } @@ -173,7 +174,8 @@ void CompileTensorScatter( auto indices = context->Input(1); auto updates = context->Input(2); auto result = XlaScatter(buffer, updates, indices, - /*indices_are_vectors=*/true, combiner, builder); + /*indices_are_vectors=*/true, + /*indices_are_sorted=*/false, combiner, builder); OP_REQUIRES_OK(context, result.status()); context->SetOutput(0, result.value()); } diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc index 78a1a67c5a8..afceb14044e 100644 --- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc @@ -25,9 +25,10 @@ limitations under the License. namespace tensorflow { namespace { -class UnsortedSegmentReduce : public XlaOpKernel { +class SegmentReduce : public XlaOpKernel { public: - explicit UnsortedSegmentReduce(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + explicit SegmentReduce(OpKernelConstruction* ctx, bool indices_are_sorted) + : XlaOpKernel(ctx), indices_are_sorted_(indices_are_sorted) { DataType dtype; OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype)); OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype, &type_)); @@ -46,7 +47,9 @@ class UnsortedSegmentReduce : public XlaOpKernel { // output[i] == 0 if i does not appear in indices // // Contrast with segment_sum(), which assumes indices are sorted and that - // max(indices)+1 is the desired size of the output. + // max(indices)+1 is the desired size of the output. Note that + // segment_sum_v2 also takes num_segments as an input and can be supported + // similarly. 
// // The returned output tensor has the same type as data, and the same shape // as data with the first indices.rank dimensions are replaced @@ -118,19 +121,22 @@ class UnsortedSegmentReduce : public XlaOpKernel { xla::XlaBuilder* builder) { return Combine(a, b); }; auto result = XlaScatter(buffer, /*updates=*/data, indices, - /*indices_are_vectors=*/false, combiner, builder); + /*indices_are_vectors=*/false, indices_are_sorted_, + combiner, builder); OP_REQUIRES_OK(ctx, result.status()); ctx->SetOutput(0, result.value()); } protected: xla::PrimitiveType type_; + bool indices_are_sorted_; }; -class UnsortedSegmentSum : public UnsortedSegmentReduce { +template +class SegmentSum : public SegmentReduce { public: - explicit UnsortedSegmentSum(OpKernelConstruction* ctx) - : UnsortedSegmentReduce(ctx) {} + explicit SegmentSum(OpKernelConstruction* ctx) + : SegmentReduce(ctx, indices_are_sorted) {} xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { return xla::Zero(builder, type_); @@ -138,14 +144,17 @@ class UnsortedSegmentSum : public UnsortedSegmentReduce { xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { return a + b; }; }; +REGISTER_XLA_OP(Name("SegmentSumV2").CompileTimeConstantInput("num_segments"), + SegmentSum); REGISTER_XLA_OP( Name("UnsortedSegmentSum").CompileTimeConstantInput("num_segments"), - UnsortedSegmentSum); + SegmentSum); -class UnsortedSegmentProd : public UnsortedSegmentReduce { +template +class SegmentProd : public SegmentReduce { public: - explicit UnsortedSegmentProd(OpKernelConstruction* ctx) - : UnsortedSegmentReduce(ctx) {} + explicit SegmentProd(OpKernelConstruction* ctx) + : SegmentReduce(ctx, indices_are_sorted) {} xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { return xla::One(builder, type_); @@ -155,12 +164,14 @@ class UnsortedSegmentProd : public UnsortedSegmentReduce { REGISTER_XLA_OP( Name("UnsortedSegmentProd").CompileTimeConstantInput("num_segments"), - UnsortedSegmentProd); + SegmentProd); +REGISTER_XLA_OP(Name("SegmentProdV2").CompileTimeConstantInput("num_segments"), + SegmentProd); -class UnsortedSegmentMin : public UnsortedSegmentReduce { +class UnsortedSegmentMin : public SegmentReduce { public: explicit UnsortedSegmentMin(OpKernelConstruction* ctx) - : UnsortedSegmentReduce(ctx) {} + : SegmentReduce(ctx, /*indices_are_sorted=*/false) {} xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { return xla::MaxFiniteValue(builder, type_); @@ -174,10 +185,10 @@ REGISTER_XLA_OP( Name("UnsortedSegmentMin").CompileTimeConstantInput("num_segments"), UnsortedSegmentMin); -class UnsortedSegmentMax : public UnsortedSegmentReduce { +class UnsortedSegmentMax : public SegmentReduce { public: explicit UnsortedSegmentMax(OpKernelConstruction* ctx) - : UnsortedSegmentReduce(ctx) {} + : SegmentReduce(ctx, /*indices_are_sorted=*/false) {} xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { return xla::MinFiniteValue(builder, type_); diff --git a/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc b/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc index bb5dfa5426b..4f3c7b79861 100644 --- a/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc @@ -90,6 +90,7 @@ class SparseToDenseOp : public XlaOpKernel { } auto result = XlaScatter(buffer, sparse_values, indices, /*indices_are_vectors=*/indices_shape.dims() > 1, + /*indices_are_sorted=*/false, /*combiner=*/{}, builder); context->SetOutput(0, builder->ReportErrorOrReturn(result)); } diff --git 
a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index b8c033d3539..ae5150f14f9 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -59,8 +59,19 @@ class SplitOp : public XlaOpKernel { errors::InvalidArgument( "Number of ways to split should be > 0, but got ", num_split)); + xla::XlaBuilder* builder = ctx->builder(); + xla::XlaOp input = ctx->Input(1); + auto shape_or = builder->GetShape(input); + OP_REQUIRES_OK(ctx, shape_or.status()); + + xla::Shape xla_shape = shape_or.value(); + OP_REQUIRES( + ctx, !xla_shape.is_dynamic_dimension(split_dim), + errors::InvalidArgument( + "Split op doesn't support split for the dynamic dimension")); + OP_REQUIRES( - ctx, input_shape.dim_size(split_dim) % num_split == 0, + ctx, xla_shape.dimensions(split_dim) % num_split == 0, errors::InvalidArgument( "Number of ways to split should evenly divide the split " "dimension, but got split_dim ", @@ -83,8 +94,6 @@ class SplitOp : public XlaOpKernel { limits[i] = dim; } - auto input = ctx->Input(1); - // Create each of the outputs. for (int i = 0; i < num_split; ++i) { // Slice out the ith split from the split dimension. @@ -164,12 +173,23 @@ class SplitVOp : public XlaOpKernel { } } + xla::XlaBuilder* builder = ctx->builder(); + auto shape_or = builder->GetShape(input); + OP_REQUIRES_OK(ctx, shape_or.status()); + + // TODO(b/265880112): Support this using the SetDimensionSize op. + xla::Shape xla_shape = shape_or.value(); + OP_REQUIRES( + ctx, !xla_shape.is_dynamic_dimension(split_dim), + errors::Unimplemented("SplitV op doesn't yet support dynamic split " + "dimension.")); + OP_REQUIRES( ctx, (neg_one_dim == -1 && - total_split_size == input_shape.dim_size(split_dim)) || + total_split_size == xla_shape.dimensions(split_dim)) || (neg_one_dim >= 0 && - total_split_size <= input_shape.dim_size(split_dim)), + total_split_size <= xla_shape.dimensions(split_dim)), errors::InvalidArgument("Determined shape must either match " "input shape along split_dim exactly if " "fully specified, or be less than the size of " diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index 5f5ae70a89d..82547ae61f4 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -68,7 +68,7 @@ Status MaybeInitializeStack(xla::XlaBuilder* builder, XlaResource* resource, } TensorShape stack_shape; - stack_shape.AddDim(resource->max_array_size()); + TF_RETURN_IF_ERROR(stack_shape.AddDimWithStatus(resource->max_array_size())); stack_shape.AppendShape(elem_shape); if (!resource->initialized()) { diff --git a/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc index b6aac5f53f1..bdb91c33509 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc @@ -230,6 +230,15 @@ Status CompileImpl( return OkStatus(); } +DataType MaybeConvertBF16ToF32(DataType const& dtype) { + if (dtype == DT_BFLOAT16) { + // We'll go through F32 to generate BF16. + // TODO(b/256243456): Generate BF16 directly from U16. 
+ return DT_FLOAT; + } + return dtype; +} + class StatefulUniformOp : public XlaOpKernel { public: explicit StatefulUniformOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { @@ -241,8 +250,8 @@ class StatefulUniformOp : public XlaOpKernel { auto sampler = [builder, this](xla::RandomAlgorithm alg, xla::XlaOp state, xla::XlaOp key, TensorShape shape) -> SamplerReturnType { + auto rng_dtype = MaybeConvertBF16ToF32(dtype_); xla::Shape xla_shape; - DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT; TF_RETURN_IF_ERROR(TensorShapeToXLAShape(rng_dtype, shape, &xla_shape)); xla::PrimitiveType rng_primitive_type = xla_shape.element_type(); xla::RngOutput uniform_state = StatefulRngUniform( @@ -269,8 +278,8 @@ class StatefulUniformOp : public XlaOpKernel { REGISTER_XLA_OP(Name("StatefulUniform") .CompileTimeConstantInput("algorithm") .CompileTimeConstantInput("shape") - .TypeConstraint("dtype", - {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}), + .TypeConstraint("dtype", {DT_DOUBLE, DT_FLOAT, DT_HALF, + DT_BFLOAT16}), StatefulUniformOp); class StatefulStandardNormalOp : public XlaOpKernel { @@ -285,8 +294,8 @@ class StatefulStandardNormalOp : public XlaOpKernel { // Needs explicit lambda return type because it fails to be inferred. [this](xla::RandomAlgorithm alg, xla::XlaOp state, xla::XlaOp key, TensorShape shape) -> SamplerReturnType { + auto rng_dtype = MaybeConvertBF16ToF32(dtype_); xla::Shape xla_shape; - DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT; TF_RETURN_IF_ERROR(TensorShapeToXLAShape(rng_dtype, shape, &xla_shape)); xla::RngOutput value_state = xla::NormalFloatingPointDistribution( key, state, BitGen(alg), xla_shape); @@ -308,8 +317,8 @@ class StatefulStandardNormalOp : public XlaOpKernel { REGISTER_XLA_OP(Name("StatefulStandardNormalV2") .CompileTimeConstantInput("algorithm") .CompileTimeConstantInput("shape") - .TypeConstraint("dtype", - {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}), + .TypeConstraint("dtype", {DT_DOUBLE, DT_FLOAT, DT_HALF, + DT_BFLOAT16}), StatefulStandardNormalOp); class StatefulTruncatedNormalOp : public XlaOpKernel { @@ -326,8 +335,8 @@ class StatefulTruncatedNormalOp : public XlaOpKernel { [builder, this](xla::RandomAlgorithm alg, xla::XlaOp state, xla::XlaOp key, TensorShape shape) -> SamplerReturnType { + auto rng_dtype = MaybeConvertBF16ToF32(dtype_); xla::Shape xla_shape; - DataType rng_dtype = dtype_ == DT_DOUBLE ? 
DT_DOUBLE : DT_FLOAT; TF_RETURN_IF_ERROR(TensorShapeToXLAShape(rng_dtype, shape, &xla_shape)); xla::RngOutput uniform_result = StatefulRngUniform( @@ -355,8 +364,8 @@ class StatefulTruncatedNormalOp : public XlaOpKernel { REGISTER_XLA_OP(Name("StatefulTruncatedNormal") .CompileTimeConstantInput("algorithm") .CompileTimeConstantInput("shape") - .TypeConstraint("dtype", - {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}), + .TypeConstraint("dtype", {DT_DOUBLE, DT_FLOAT, DT_HALF, + DT_BFLOAT16}), StatefulTruncatedNormalOp); class StatefulUniformIntOp : public XlaOpKernel { diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index df49a2bf794..ad33157ed78 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -67,6 +67,10 @@ xla::BitGeneratorTy GetBitGeneratorForDevice( xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) { if (dtype == DT_BFLOAT16) { xla::XlaBuilder* builder = input.builder(); + // TODO(b/256243456): Instead of doing + // `ConvertElementType(BitcastConvertType(u32, F32), BF16)` we should do + // `BitcastConvertType(ConvertElementType(u32, U16), BF16)`, to avoid the + // unclear `ConvertElementType(f32, BF16)` behavior. xla::XlaOp output = xla::BitcastConvertType(input, xla::U32) & xla::ConstantR0(builder, 0xFFFF0000); return xla::ConvertElementType(xla::BitcastConvertType(output, xla::F32), @@ -87,6 +91,7 @@ xla::XlaOp StatelessRngUniform(absl::string_view device_type_string, xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0); xla::PrimitiveType type = shape.element_type(); switch (type) { + case xla::F16: case xla::F32: case xla::F64: return xla::UniformFloatingPointDistribution( @@ -94,7 +99,7 @@ xla::XlaOp StatelessRngUniform(absl::string_view device_type_string, GetBitGeneratorForDevice(device_type_string), minval, maxval, shape) .value; - case xla::S32: // fall through + case xla::S32: case xla::S64: return UniformIntDistribution( key, initial_state, @@ -104,7 +109,7 @@ xla::XlaOp StatelessRngUniform(absl::string_view device_type_string, break; default: return builder->ReportError(xla::Unimplemented( - "Types other than F32, S32 and S64 are not implemented by " + "Types other than F16, F32, S32 and S64 are not implemented by " "StatelessRngUniform; got %s", xla::primitive_util::LowercasePrimitiveTypeName(type))); } @@ -139,6 +144,15 @@ xla::XlaOp StatelessRngUniformFullInt(absl::string_view device_type_string, } } +DataType MaybeConvertBF16ToF32(DataType const& dtype) { + if (dtype == DT_BFLOAT16) { + // We'll go through F32 to generate BF16. + // TODO(b/256243456): Generate BF16 directly from U16. + return DT_FLOAT; + } + return dtype; +} + class StatelessRandomUniformOp : public XlaOpKernel { public: explicit StatelessRandomUniformOp(OpKernelConstruction* ctx) @@ -159,7 +173,7 @@ class StatelessRandomUniformOp : public XlaOpKernel { seed_shape.DebugString())); xla::XlaOp seed = ctx->Input(1); - DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT; + auto rng_dtype = MaybeConvertBF16ToF32(dtype_); xla::Shape xla_shape; OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape)); xla::PrimitiveType rng_primitive_type = xla_shape.element_type(); @@ -182,7 +196,8 @@ class StatelessRandomUniformOp : public XlaOpKernel { // TODO(phawkins): generalize to non-float, non-int32 seed types. 
REGISTER_XLA_OP(Name("StatelessRandomUniform") .CompileTimeConstantInput("shape") - .TypeConstraint("dtype", {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}) + .TypeConstraint("dtype", + {DT_DOUBLE, DT_FLOAT, DT_HALF, DT_BFLOAT16}) .TypeConstraint("Tseed", DT_INT32), StatelessRandomUniformOp); @@ -295,9 +310,8 @@ class StatelessRandomNormalOp : public XlaOpKernel { errors::InvalidArgument("seed must have shape [2], not ", seed_shape.DebugString())); xla::XlaOp seed = ctx->Input(1); + auto rng_dtype = MaybeConvertBF16ToF32(dtype_); xla::Shape xla_shape; - - DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT; OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape)); xla::XlaBuilder* builder = seed.builder(); @@ -325,7 +339,8 @@ class StatelessRandomNormalOp : public XlaOpKernel { // TODO(phawkins): generalize to non-float, non-int32 seed types. REGISTER_XLA_OP(Name("StatelessRandomNormal") .CompileTimeConstantInput("shape") - .TypeConstraint("dtype", {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}) + .TypeConstraint("dtype", + {DT_DOUBLE, DT_FLOAT, DT_HALF, DT_BFLOAT16}) .TypeConstraint("Tseed", DT_INT32), StatelessRandomNormalOp); @@ -348,7 +363,7 @@ class StatelessTruncatedNormalOp : public XlaOpKernel { xla::XlaOp seed = ctx->Input(1); xla::XlaBuilder* builder = ctx->builder(); - DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT; + auto rng_dtype = MaybeConvertBF16ToF32(dtype_); xla::Shape xla_shape; OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape)); xla::XlaOp uniform = StatelessRngUniform( @@ -369,7 +384,8 @@ class StatelessTruncatedNormalOp : public XlaOpKernel { REGISTER_XLA_OP(Name("StatelessTruncatedNormal") .CompileTimeConstantInput("shape") - .TypeConstraint("dtype", {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}) + .TypeConstraint("dtype", + {DT_DOUBLE, DT_FLOAT, DT_HALF, DT_BFLOAT16}) .TypeConstraint("Tseed", DT_INT32), StatelessTruncatedNormalOp); @@ -392,8 +408,7 @@ class StatelessParameterizedTruncatedNormalOp : public XlaOpKernel { xla::XlaOp seed = ctx->Input(1); xla::XlaBuilder* builder = ctx->builder(); - DataType rng_dtype = dtype_ == DT_DOUBLE ? 
DT_DOUBLE : DT_FLOAT; - + auto rng_dtype = MaybeConvertBF16ToF32(dtype_); xla::Shape xla_shape; OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape)); @@ -432,7 +447,8 @@ class StatelessParameterizedTruncatedNormalOp : public XlaOpKernel { REGISTER_XLA_OP(Name("StatelessParameterizedTruncatedNormal") .CompileTimeConstantInput("shape") - .TypeConstraint("dtype", {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}) + .TypeConstraint("dtype", + {DT_DOUBLE, DT_FLOAT, DT_HALF, DT_BFLOAT16}) .TypeConstraint("Tseed", DT_INT32), StatelessParameterizedTruncatedNormalOp); diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc index 6d3cce554e0..255474a62ca 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops_v2.cc @@ -99,6 +99,7 @@ xla::RngOutput StatelessRngUniformV2(xla::RandomAlgorithm const& alg, using std::placeholders::_3; auto generator = std::bind(BitGenerator, alg, _1, _2, _3); switch (type) { + case xla::F16: case xla::F32: case xla::F64: return xla::UniformFloatingPointDistribution(key, counter, generator, @@ -112,7 +113,7 @@ xla::RngOutput StatelessRngUniformV2(xla::RandomAlgorithm const& alg, break; default: return {builder->ReportError(xla::Unimplemented( - "Types other than F32, S32, S64, U32 and U64 are not " + "Types other than F16, F32, S32, S64, U32 and U64 are not " "implemented by " "StatelessRngUniformV2; got %s", xla::primitive_util::LowercasePrimitiveTypeName(type))), @@ -179,6 +180,15 @@ xla::XlaOp MaybeSliceCounter(xla::RandomAlgorithm const& alg, return counter; } +DataType MaybeConvertBF16ToF32(DataType const& dtype) { + if (dtype == DT_BFLOAT16) { + // We'll go through F32 to generate BF16. + // TODO(b/256243456): Generate BF16 directly from U16. + return DT_FLOAT; + } + return dtype; +} + class StatelessRandomUniformOp : public XlaOpKernel { public: explicit StatelessRandomUniformOp(OpKernelConstruction* ctx) @@ -209,7 +219,7 @@ class StatelessRandomUniformOp : public XlaOpKernel { ctx->InputShape(key_input_idx), counter_shape)); - DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT; + auto rng_dtype = MaybeConvertBF16ToF32(dtype_); xla::Shape xla_shape; OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape)); xla::PrimitiveType rng_primitive_type = xla_shape.element_type(); @@ -247,8 +257,8 @@ class StatelessRandomUniformOp : public XlaOpKernel { REGISTER_XLA_OP(Name("StatelessRandomUniformV2") .CompileTimeConstantInput("shape") .CompileTimeConstantInput("alg") - .TypeConstraint("dtype", - {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}), + .TypeConstraint("dtype", {DT_DOUBLE, DT_FLOAT, DT_HALF, + DT_BFLOAT16}), StatelessRandomUniformOp); class StatelessRandomUniformIntOp : public XlaOpKernel { @@ -392,8 +402,7 @@ class StatelessRandomNormalOp : public XlaOpKernel { ctx->InputShape(key_input_idx), counter_shape)); - DataType rng_dtype = dtype_ == DT_DOUBLE ? 
DT_DOUBLE : DT_FLOAT; - + auto rng_dtype = MaybeConvertBF16ToF32(dtype_); xla::Shape xla_shape; OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape)); @@ -431,8 +440,8 @@ class StatelessRandomNormalOp : public XlaOpKernel { REGISTER_XLA_OP(Name("StatelessRandomNormalV2") .CompileTimeConstantInput("shape") .CompileTimeConstantInput("alg") - .TypeConstraint("dtype", - {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}), + .TypeConstraint("dtype", {DT_DOUBLE, DT_FLOAT, DT_HALF, + DT_BFLOAT16}), StatelessRandomNormalOp); class StatelessTruncatedNormalOp : public XlaOpKernel { @@ -464,7 +473,7 @@ class StatelessTruncatedNormalOp : public XlaOpKernel { xla::XlaBuilder* builder = ctx->builder(); - DataType rng_dtype = dtype_ == DT_DOUBLE ? DT_DOUBLE : DT_FLOAT; + auto rng_dtype = MaybeConvertBF16ToF32(dtype_); xla::Shape xla_shape; OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(rng_dtype, shape, &xla_shape)); @@ -488,8 +497,8 @@ class StatelessTruncatedNormalOp : public XlaOpKernel { REGISTER_XLA_OP(Name("StatelessTruncatedNormalV2") .CompileTimeConstantInput("shape") .CompileTimeConstantInput("alg") - .TypeConstraint("dtype", - {DT_DOUBLE, DT_FLOAT, DT_BFLOAT16}), + .TypeConstraint("dtype", {DT_DOUBLE, DT_FLOAT, DT_HALF, + DT_BFLOAT16}), StatelessTruncatedNormalOp); class GetKeyCounterOp : public XlaOpKernel { diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 4074a3fb3d3..a0f8f62cd57 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -76,7 +76,7 @@ Status MaybeInitializeTensorArray(xla::XlaBuilder* builder, TF_RETURN_IF_ERROR(XLAShapeToTensorShape(shape_or_status.value(), &shape)); TensorShape ta_shape; - ta_shape.AddDim(resource->max_array_size()); + TF_RETURN_IF_ERROR(ta_shape.AddDimWithStatus(resource->max_array_size())); ta_shape.AppendShape(elem_shape); if (ta_shape != shape) { return errors::InvalidArgument( diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc index 980ca07e117..5d299fde600 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc @@ -553,6 +553,8 @@ class TensorListSplitOp : public XlaOpKernel { OP_REQUIRES(ctx, len == length, errors::Unimplemented("All lengths have to be the same")); } + OP_REQUIRES(ctx, length, + errors::Unimplemented("All lengths must be positive")); OP_REQUIRES( ctx, element_dims[0] % length == 0, errors::Unimplemented("Buffer size has to be a multiple of length")); diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc index 84f8981a979..d873396a828 100644 --- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc @@ -184,8 +184,9 @@ class InvertPermutationOp : public XlaOpKernel { xla::Iota(ctx->builder(), xla::primitive_util::NativeToPrimitiveType(), size); auto result = XlaScatter(iota, iota, indices, - /*indices_are_vectors=*/false, /*combiner=*/{}, - ctx->builder()); + /*indices_are_vectors=*/false, + /*indices_are_sorted=*/false, + /*combiner=*/{}, ctx->builder()); OP_REQUIRES_OK(ctx, result.status()); ctx->SetOutput(0, result.value()); } diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index bf209504748..555670f69c3 100644 --- 
a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -64,6 +64,7 @@ REGISTER_XLA_OP(Name("Ceil"), MlirXlaOpKernel); REGISTER_XLA_OP(Name("Cos"), MlirXlaOpKernel); XLAJIT_MAKE_UNARY(Cosh, xla::Cosh(x)); XLAJIT_MAKE_UNARY(Sin, xla::Sin(x)); +XLAJIT_MAKE_UNARY(Tan, xla::Tan(x)); REGISTER_XLA_OP(Name("Exp"), MlirXlaOpKernel); REGISTER_XLA_OP(Name("Expm1"), MlirXlaOpKernel); REGISTER_XLA_OP(Name("Floor"), MlirXlaOpKernel); @@ -117,7 +118,6 @@ XLAJIT_MAKE_UNARY(Softplus, Softplus(b, x)); XLAJIT_MAKE_UNARY(Softsign, x / (xla::Abs(x) + xla::ScalarLike(x, 1.0))); REGISTER_XLA_OP(Name("Sqrt"), MlirXlaOpKernel); XLAJIT_MAKE_UNARY(Square, x* x); -XLAJIT_MAKE_UNARY(Tan, xla::Tan(x)); REGISTER_XLA_OP(Name("Tanh"), MlirXlaOpKernel); REGISTER_XLA_OP(Name("Real"), MlirXlaOpKernel); REGISTER_XLA_OP(Name("Imag"), MlirXlaOpKernel); diff --git a/tensorflow/compiler/tf2xla/kernels/unique_op.cc b/tensorflow/compiler/tf2xla/kernels/unique_op.cc index be31b285b6f..c4389baf4a8 100644 --- a/tensorflow/compiler/tf2xla/kernels/unique_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/unique_op.cc @@ -147,6 +147,10 @@ class UniqueOpBase : public XlaOpKernel { StatusOr input_shape_or = ctx->builder()->GetShape(input); OP_REQUIRES_OK(ctx, input_shape_or.status()); auto input_shape = input_shape_or.value(); + axis = axis < 0 ? axis + input_shape.rank() : axis; + OP_REQUIRES(ctx, 0 <= axis && axis < input_shape.rank(), + errors::InvalidArgument("axis has to be between [0, ", + input_shape.rank(), ")")); auto aux = MoveAxis(input, axis, 0, input_shape); auto aux_shape = ctx->builder()->GetShape(aux).value(); int64_t leading_size = aux_shape.dimensions(0); diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index aaddba9b8f0..6fe7a44f32a 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -190,7 +190,7 @@ class ResourceScatterOp : public XlaOpKernel { const xla::XlaOp updates = context->Input(2); auto result = XlaScatter(var_value, updates, indices, indices_are_vectors_, - combiner_, builder); + /*indices_are_sorted=*/false, combiner_, builder); OP_REQUIRES_OK(context, result.status()); OP_REQUIRES_OK(context, context->AssignVariable(0, dtype, result.value())); } diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc index 81c926b458c..7cf6d47bb28 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Verifier.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project @@ -33,7 +34,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/xla/pjrt/mlir_to_hlo.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -62,10 +63,11 @@ StatusOr ComputeDimensionValue(int version, string dim_arg_spec, return errors::InvalidArgument("Syntax error in dim_args_spec '", dim_arg_spec, "'"); } - if (arg_idx < 0 || arg_idx > arguments.size()) { + if (arg_idx < 0 || arg_idx >= arguments.size()) { return errors::InvalidArgument( - "Invalid argument index ", arg_idx, " when number of arguments is ", - arguments.size(), " in dim_arg_spec '", dim_arg_spec, "'"); + "Invalid argument index ", arg_idx, + " when the number of non-dimension arguments is ", arguments.size(), + " in dim_arg_spec '", dim_arg_spec, "'"); } mlir::RankedTensorType arg_type = arguments[arg_idx].getType().dyn_cast(); @@ -75,19 +77,30 @@ StatusOr ComputeDimensionValue(int version, string dim_arg_spec, "' does not have a RankedTensorType"); } if (arg_axis_idx < 0 || arg_axis_idx >= arg_type.getShape().size()) { - return errors::InvalidArgument( - "Invalid axis index ", arg_axis_idx, " when rank of input is ", - arg_type.getShape().size(), " in dim_arg_spec '", dim_arg_spec, "'"); + return errors::InvalidArgument("Invalid axis index ", arg_axis_idx, + " when the rank of non-dimension argument ", + arg_idx, " is ", arg_type.getShape().size(), + " in dim_arg_spec '", dim_arg_spec, "'"); } mlir::Value val; + mlir::Type get_dim_type = + mlir::RankedTensorType::get({}, op_builder.getI32Type()); if (version >= VERSION_START_STABLE_HLO) { val = op_builder.create( - arguments[arg_idx].getLoc(), dim_arg_type, arguments[arg_idx], + arguments[arg_idx].getLoc(), get_dim_type, arguments[arg_idx], op_builder.getI64IntegerAttr(arg_axis_idx)); + if (dim_arg_type != get_dim_type) { + val = op_builder.create( + arguments[arg_idx].getLoc(), dim_arg_type, val); + } } else { val = op_builder.create( - arguments[arg_idx].getLoc(), dim_arg_type, arguments[arg_idx], + arguments[arg_idx].getLoc(), get_dim_type, arguments[arg_idx], op_builder.getI64IntegerAttr(arg_axis_idx)); + if (dim_arg_type != get_dim_type) { + val = op_builder.create( + arguments[arg_idx].getLoc(), dim_arg_type, val); + } } return val; } @@ -100,11 +113,12 @@ StatusOr ComputeDimensionValue(int version, string dim_arg_spec, // // where %arg0 and %arg1 are dimension arguments, always first among the // arguments, and whose values are computed based on the static shapes of the -// array arguments (%arg2 and following). +// non-dimension arguments (%arg2 and following). // In the above example, the dim_args_spec array would have two elements, one // for %arg0 and one for %arg1. E.g., ['0.0', '0.1'] specifies that %arg0 // should be set to the size of axis 0 or array argument 0 (%arg2), while // %arg1 should be set to the size of axis 1. +// The dimension arguments must be 0-dimensional tensors of integer type. 
// // We create a new "main" function as follows: // func public main(%arg2: f32[?, ?, 8]) { @@ -132,9 +146,9 @@ Status AddMainWrapper(int version, mlir::ModuleOp module, return errors::InvalidArgument("Cannot find 'main' in module"); } if (orig_main.getNumArguments() <= nr_dim_args) { - return errors::InvalidArgument("'main' has ", orig_main.getNumArguments(), - " arguments, but it must have at least ", - nr_dim_args, " dimension arguments"); + return errors::InvalidArgument( + "The module should have ", nr_dim_args, " dimension arguments, but it ", + "has only ", orig_main.getNumArguments(), " total arguments"); } mlir::Block &orig_main_body = orig_main.front(); @@ -161,6 +175,17 @@ Status AddMainWrapper(int version, mlir::ModuleOp module, std::vector call_args(orig_main_body.getNumArguments()); for (int i = 0; i < orig_main_body.getNumArguments(); ++i) { if (i < nr_dim_args) { + mlir::Type arg_type = orig_main.getArgument(i).getType(); + mlir::RankedTensorType arg_ranked_type = + arg_type.dyn_cast(); + if (!arg_ranked_type || + !arg_ranked_type.getElementType().dyn_cast() || + !arg_ranked_type.getShape().empty()) { + return errors::InvalidArgument( + "Module argument at index ", i, + " should be a 0-dimensional integer-tensor dimension argument", + " but has type ", debugString(arg_type)); + } TF_ASSIGN_OR_RETURN(call_args[i], ComputeDimensionValue( version, dim_args_spec[i], block_args, op_builder, @@ -194,7 +219,8 @@ Status AddMainWrapper(int version, mlir::ModuleOp module, // inference to refine all dynamic shapes, and to rewrite the dynamic ops, // e.g., to replace dynamic_broadcast_in_dim with broadcast_in_dim. Status RefineDynamicShapes(XlaOpKernelContext *ctx, - mlir::OwningOpRef *module) { + mlir::OwningOpRef *module, + int nr_dim_args) { // Locate the (wrapped) 'main' function. // This is the convention used by MlirToXlaComputation. mlir::func::FuncOp main = (*module)->lookupSymbol("main"); @@ -202,17 +228,18 @@ Status RefineDynamicShapes(XlaOpKernelContext *ctx, return errors::InvalidArgument("Cannot find 'main' in module"); } mlir::Block &main_body = main.front(); - int nr_array_arguments = ctx->num_inputs(); - if (nr_array_arguments != main_body.getNumArguments()) { + int non_dimension_arguments = ctx->num_inputs(); + if (non_dimension_arguments != main_body.getNumArguments()) { return errors::InvalidArgument( - "Incorrect number of arguments for XlaCallModule. ", - "The wrapped module expects ", main_body.getNumArguments(), - " arguments, but there are ", nr_array_arguments, " arguments"); + "Incorrect number of arguments for XlaCallModule: ", + non_dimension_arguments, ". The module has ", + main_body.getNumArguments() + nr_dim_args, " of which ", nr_dim_args, + " were declared to be dimension arguments."); } mlir::Builder builder((*module)->getContext()); - std::vector static_array_input_types(nr_array_arguments); - for (int i = 0, end = nr_array_arguments; i < end; ++i) { + std::vector static_array_input_types(non_dimension_arguments); + for (int i = 0, end = non_dimension_arguments; i < end; ++i) { TF_ASSIGN_OR_RETURN(xla::Shape xla_shape, ctx->InputXlaShape(i)); std::vector xla_dimensions(xla_shape.dimensions().begin(), xla_shape.dimensions().end()); @@ -232,11 +259,24 @@ Status RefineDynamicShapes(XlaOpKernelContext *ctx, // This will only change the argument types and will not propagate the // additional type information further. For that, we'll need to run // shape inference as explained below. 
- main.setType( - builder.getFunctionType(static_array_input_types, main.getResultTypes())); + auto static_array_output_types = llvm::to_vector(main.getResultTypes()); for (auto i = 0; i < main_body.getNumArguments(); ++i) { - main_body.getArgument(i).setType(static_array_input_types[i]); + auto arg = main_body.getArgument(i); + arg.setType(static_array_input_types[i]); + // If the argument is used by `func.return`, then we also need to + // update function result types. It's not great that we need this hack, + // but in the future when we have stablehlo.func, stablehlo.return, etc, + // this will not be needed. + // TODO(burmako): Once https://github.com/openxla/stablehlo/issues/425 is + // fixed, clean this up. + for (mlir::OpOperand &use : arg.getUses()) { + if (auto ret = llvm::dyn_cast(use.getOwner())) { + static_array_output_types[use.getOperandNumber()] = arg.getType(); + } + } } + main.setType(builder.getFunctionType(static_array_input_types, + static_array_output_types)); // --tf-shape-inference, despite its TF-specific name, seems to be general // enough to also work on MHLO. (Although it fails if it doesn't see a // tf.versions attribute on the module, which we hackily attach). @@ -244,6 +284,14 @@ Status RefineDynamicShapes(XlaOpKernelContext *ctx, builder.getNamedAttr("producer", builder.getI32IntegerAttr(0)); (**module)->setAttr("tf.versions", builder.getDictionaryAttr({tf_producer})); + // Verify the module before running passes on it. + // If the module doesn't pass verification, all sorts of weirdness might + // happen if we run the pass manager. + if (failed(verify(**module))) { + VLOG(3) << "XlaCallModule module with verification failed: " + << debugString(**module); + return errors::InvalidArgument("Module verification failed"); + } mlir::PassManager pm((*module)->getContext()); if (VLOG_IS_ON(3)) { auto print_before = [](mlir::Pass *, mlir::Operation *) { return true; }; @@ -260,6 +308,35 @@ Status RefineDynamicShapes(XlaOpKernelContext *ctx, if (!mlir::succeeded(pm.run(**module))) { return errors::InvalidArgument("Module shape inference failed"); } + + // Finally, make sure that no dynamic shapes are left, otherwise all sorts of + // weirdness might happen in the HLO exporter. + bool moduleHasDynamicShapes = false; + auto hasDynamicShape = [](mlir::Value value) { + auto shaped_type = value.getType().dyn_cast(); + return shaped_type ? !shaped_type.hasStaticShape() : false; + }; + (*module)->walk([&](mlir::Operation *op) { + // It's sufficient to only check results because operands either come from + // results or from block arguments which are checked below. 
+ bool opHasDynamicShapes = false; + opHasDynamicShapes |= llvm::any_of(op->getResults(), hasDynamicShape); + for (mlir::Region ®ion : op->getRegions()) { + opHasDynamicShapes |= + llvm::any_of(region.getArguments(), hasDynamicShape); + } + moduleHasDynamicShapes |= opHasDynamicShapes; + if (opHasDynamicShapes) { + std::string opStr; + llvm::raw_string_ostream os(opStr); + op->print(os); + VLOG(3) << "Operation still has dynamic shapes: " << opStr; + } + }); + if (moduleHasDynamicShapes) { + return errors::InvalidArgument("Module still has dynamic shapes"); + } + VLOG(3) << "XlaCallModule module with inferred types: " << debugString(**module); return OkStatus(); @@ -284,7 +361,9 @@ Status LoadAndPreprocessModule(int version, if (!*module) { return errors::InvalidArgument("Cannot deserialize computation"); } - VLOG(3) << "Parsed serialized module (version" << version << ")\n" + VLOG(3) << "Parsed serialized module (version " << version + << ", dim_args_spec = [" << absl::StrJoin(dim_args_spec, ", ") + << "])\n" << debugString(**module); if (failed((*module)->verifyInvariants())) { @@ -310,8 +389,12 @@ Status LoadAndPreprocessModule(int version, } } + if (*has_dynamic_shapes && dim_args_spec.empty()) { + return errors::InvalidArgument( + "Module main has dynamic shapes but no dim_args_spec was given"); + } if (!dim_args_spec.empty()) { - if (!has_dynamic_shapes) { + if (!*has_dynamic_shapes) { return errors::InvalidArgument( "Module main has dim_args_spec but does not have dynamic shapes"); } @@ -347,7 +430,8 @@ class XlaCallModuleOp : public XlaOpKernel { void Compile(XlaOpKernelContext *ctx) override { if (has_dynamic_shapes_) { - OP_REQUIRES_OK(ctx, RefineDynamicShapes(ctx, &module_)); + OP_REQUIRES_OK(ctx, + RefineDynamicShapes(ctx, &module_, dim_args_spec_.size())); } std::vector inputs(ctx->num_inputs()); @@ -395,8 +479,8 @@ class XlaCallModuleOp : public XlaOpKernel { int nr_outputs_; std::vector dim_args_spec_; bool has_dynamic_shapes_; - mlir::OwningOpRef module_; mlir::MLIRContext context_{mlir::MLIRContext::Threading::DISABLED}; + mlir::OwningOpRef module_; }; REGISTER_XLA_OP(Name("XlaCallModule"), XlaCallModuleOp); diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index a4e312a134f..14cc77651c8 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -4,6 +4,7 @@ load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//tensorflow/compiler/tf2xla:friends"], licenses = ["notice"], ) diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc index 0cb8bf31778..c5cb60bc48c 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.cc +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -34,6 +34,7 @@ namespace tensorflow { StatusOr XlaScatter( const xla::XlaOp& buffer, const xla::XlaOp& updates, const xla::XlaOp& indices, bool indices_are_vectors, + bool indices_are_sorted, const std::function& combiner, xla::XlaBuilder* builder) { @@ -200,7 +201,7 @@ StatusOr XlaScatter( << "]"; return xla::Scatter(buffer, indices, new_updates, combiner_computation, - dim_numbers); + dim_numbers, indices_are_sorted); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/scatter.h b/tensorflow/compiler/tf2xla/lib/scatter.h index 96881ed20ab..ef9c738d9b4 100644 --- 
a/tensorflow/compiler/tf2xla/lib/scatter.h +++ b/tensorflow/compiler/tf2xla/lib/scatter.h @@ -46,6 +46,7 @@ namespace tensorflow { StatusOr XlaScatter( const xla::XlaOp& buffer, const xla::XlaOp& updates, const xla::XlaOp& indices, bool indices_are_vectors, + bool indices_are_sorted, const std::function& combiner, xla::XlaBuilder* builder); diff --git a/tensorflow/compiler/tf2xla/light_outside_compilation_kernels_for_test.cc b/tensorflow/compiler/tf2xla/light_outside_compilation_kernels_for_test.cc index 850f93065ca..e13333ca28b 100644 --- a/tensorflow/compiler/tf2xla/light_outside_compilation_kernels_for_test.cc +++ b/tensorflow/compiler/tf2xla/light_outside_compilation_kernels_for_test.cc @@ -191,7 +191,7 @@ class DynamicMultidimOp : public OpKernel { TensorShape output_shape; auto vec = ctx->input(0).flat(); for (int i = 0; i < vec.size(); i++) { - output_shape.AddDim(vec(i)); + OP_REQUIRES_OK(ctx, output_shape.AddDimWithStatus(vec(i))); } Tensor* out_tensor = nullptr; OP_REQUIRES_OK(ctx, diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc index 9786ab582a7..5ffe2a06f34 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc @@ -171,11 +171,6 @@ MlirOptimizationPassState MlirBridgePass::GetPassState( // We set `uses_uninitialized_resource_args` to false here because the first // phase of the bridge is not affected by uninitialized resource args. - // Note we are recording the stats using LogGraphFeatures in the pass - // that calls this one to avoid duplicate logging due to - // GetMlirBridgeRolloutPolicy being called multiple times for the same graph. - // TODO(b/241853328): Add caching of pass state and call logging/metrics - // related to graph analysis from here. MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy( graph, &function_library, config_proto, /*uses_uninitialized_resource_args=*/false, @@ -185,8 +180,6 @@ MlirOptimizationPassState MlirBridgePass::GetPassState( return MlirOptimizationPassState::Enabled; case MlirBridgeRolloutPolicy::kEnabledAfterGraphAnalysis: return MlirOptimizationPassState::FallbackEnabled; - case MlirBridgeRolloutPolicy::kEnabledAfterGraphAnalysisSafeModeFallback: - return MlirOptimizationPassState::FallbackEnabled; case MlirBridgeRolloutPolicy::kDisabledByUser: VLOG(1) << "Skipping MLIR TPU Bridge, MLIR TPU bridge disabled by user. " "Old bridge will evaluate."; @@ -198,6 +191,14 @@ MlirOptimizationPassState MlirBridgePass::GetPassState( "graph has unsupported features. Old bridge will evaluate."; metrics::UpdateTfMlirBridgeFirstPhaseCounter("tpu", "v2", true, "invalid_graph"); + // We set `uses_uninitialized_resource_args` to false here because the + // first phase of the bridge is not affected by uninitialized resource + // args. + // For Invalid Graph Analysis we need to log here because Run will not be + // called. + LogGraphFeatures(graph, &function_library, config_proto, + /*uses_uninitialized_resource_args=*/false, + /*is_v1_compat=*/false); return MlirOptimizationPassState::Disabled; } } @@ -245,8 +246,17 @@ Status MlirBridgePass::Run(const ConfigProto& config_proto, if (is_qualified_for_tpu_bridge) { bool fallback_enabled = false; - if (pass_state == MlirOptimizationPassState::FallbackEnabled) + if (pass_state == MlirOptimizationPassState::FallbackEnabled) { + // We set `uses_uninitialized_resource_args` to false here because the + // first phase of the bridge is not affected by uninitialized resource + // args. 
+ // TODO (b/241853328) Consider moving logging if caching for graph + // analysis or GetPassState is added + LogGraphFeatures(graph, &function_library, config_proto, + /*uses_uninitialized_resource_args=*/false, + /*is_v1_compat=*/false); fallback_enabled = true; + } VLOG(1) << "Running MLIR TPU Bridge"; mlir_bridge_gauge_v2->GetCell()->Set(true); return mlir::TFTPU::TPUBridge(module, /*enable_logging=*/VLOG_IS_ON(1), @@ -263,15 +273,8 @@ MlirOptimizationPassState MlirBridgeV1CompatPass::GetPassState( // Skip MLIR TPU Bridge if no TPU devices found. if (device_set && !HasTPUDevice(*device_set)) return MlirOptimizationPassState::Disabled; - - // Do not run the bridge if it's enabled by the graph analysis, - // only run if it's enabled by the user explicitly. // We set `uses_uninitialized_resource_args` to false here because the first // phase of the bridge is not affected by uninitialized resource args. - // Note we are recording the stats using LogGraphFeatures in the pass - // that calls this one. - // TODO(b/241853328): Add caching of pass state and call logging/metrics - // related to graph analysis from here. MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy( graph, /*function_library=*/&function_library, config_proto, /*uses_uninitialized_resource_args=*/false, /*is_v1_compat=*/true, @@ -279,8 +282,6 @@ MlirOptimizationPassState MlirBridgeV1CompatPass::GetPassState( switch (policy) { case MlirBridgeRolloutPolicy::kEnabledByUser: return MlirOptimizationPassState::Enabled; - case MlirBridgeRolloutPolicy::kEnabledAfterGraphAnalysisSafeModeFallback: - return MlirOptimizationPassState::FallbackEnabled; case MlirBridgeRolloutPolicy::kEnabledAfterGraphAnalysis: return MlirOptimizationPassState::FallbackEnabled; case MlirBridgeRolloutPolicy::kDisabledByUser: @@ -295,6 +296,14 @@ MlirOptimizationPassState MlirBridgeV1CompatPass::GetPassState( "evaluate."; metrics::UpdateTfMlirBridgeFirstPhaseCounter("tpu", "v1", true, "invalid_graph"); + // We set `uses_uninitialized_resource_args` to false here because the + // first phase of the bridge is not affected by uninitialized resource + // args. + // For Invalid Graph Analysis we need to log here because Run will not be + // called. + LogGraphFeatures(graph, &function_library, config_proto, + /*uses_uninitialized_resource_args=*/false, + /*is_v1_compat=*/true); return MlirOptimizationPassState::Disabled; } } @@ -332,9 +341,18 @@ Status MlirBridgeV1CompatPass::Run(const GraphOptimizationPassOptions& options, VLOG(1) << "Running MLIR TPU Bridge V1 Compat"; - bool fallback_enabled = true; - if (pass_state == MlirOptimizationPassState::Enabled) - fallback_enabled = false; + bool fallback_enabled = false; + if (pass_state == MlirOptimizationPassState::FallbackEnabled) { + // We set `uses_uninitialized_resource_args` to false here because the first + // phase of the bridge is not affected by uninitialized resource args. + // TODO (b/241853328) Consider moving logging if caching for graph analysis + // or GetPassState is added + LogGraphFeatures(**options.graph, options.flib_def, + options.session_options->config, + /*uses_uninitialized_resource_args=*/false, + /*is_v1_compat=*/true); + fallback_enabled = true; + } mlir_bridge_gauge_v1->GetCell()->Set(true); diff --git a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc index cc3e52d4ba1..16718e9026b 100644 --- a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc +++ b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc @@ -36,7 +36,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/tf2xla.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc b/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc index 8f9a88530ba..2df09ff8f86 100644 --- a/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h" -#include "tensorflow/compiler/jit/xla_compilation_cache.h" +#include "tensorflow/compiler/jit/xla_compile_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" #include "tensorflow/compiler/mlir/utils/array_container_utils.h" @@ -96,7 +96,8 @@ Status MlirXlaOpKernel::ConstructXlaOp(XlaOpKernelContext* ctx) { } // Create a graph that wraps the kernel. - TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(def(), xla_args, result_dtypes)); + TF_ASSIGN_OR_RETURN(auto graph, + CreateSingleOpGraph(def(), xla_args, result_dtypes)); // Compile the graph to HLO. GraphDebugInfo debug_info; diff --git a/tensorflow/compiler/tf2xla/ops/BUILD b/tensorflow/compiler/tf2xla/ops/BUILD index 5f073b3f305..ce8016bc509 100644 --- a/tensorflow/compiler/tf2xla/ops/BUILD +++ b/tensorflow/compiler/tf2xla/ops/BUILD @@ -6,6 +6,7 @@ load( ) package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//tensorflow:internal"], licenses = ["notice"], ) diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index 5ff70d64bb5..9978030a252 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -1369,7 +1369,8 @@ specification of how to compute its value, as a string, in the form E.g., the specification "2.1" denotes the value args[2].shape[1]. args: A list of `Tensor` with possibly different types to be passed as arguments - to the HLO module. + to the HLO module. These are all non-dimension arguments. The dimension + arguments are computed at JIT time. version: Changes when we change the semantics of the op, to support backwards compatibility. Version 1 carries an MHLO text or bytecode `module`. From version 2, the op carries a StableHLO text or bytecode `module`. 
diff --git a/tensorflow/compiler/tf2xla/python/BUILD b/tensorflow/compiler/tf2xla/python/BUILD index f5c9672e7bb..815fc42b44a 100644 --- a/tensorflow/compiler/tf2xla/python/BUILD +++ b/tensorflow/compiler/tf2xla/python/BUILD @@ -5,6 +5,7 @@ load( load("//tensorflow:tensorflow.default.bzl", "tf_custom_op_py_library") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ "//visibility:public", ], diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index bd582951bec..f763289bf57 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -102,6 +102,7 @@ def unary_op_wrapper(x, name=None): round = _unary_op(math_ops.round) sin = _unary_op(math_ops.sin) sign = _unary_op(math_ops.sign) +tan = _unary_op(math_ops.tan) tanh = _unary_op(math_ops.tanh) # Bessel @@ -603,7 +604,7 @@ def custom_call_v2( ) -def call_module(args, *, version=1, module, Tout, Sout, dim_args_spec=()): +def call_module(args, *, version=2, module, Tout, Sout, dim_args_spec=()): # See documentation for the XlaCallModule op. return gen_xla_ops.xla_call_module( args, version=version, module=module, dim_args_spec=dim_args_spec, diff --git a/tensorflow/compiler/tf2xla/resource_util.cc b/tensorflow/compiler/tf2xla/resource_util.cc index 1bfe364ea6d..80fa72d84a5 100644 --- a/tensorflow/compiler/tf2xla/resource_util.cc +++ b/tensorflow/compiler/tf2xla/resource_util.cc @@ -22,7 +22,6 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/errors.h" @@ -32,7 +31,7 @@ limitations under the License. namespace tensorflow { namespace { -using stream_executor::port::StatusOr; +using tsl::StatusOr; const char kIdentityNOp[] = "IdentityN"; const char kIfOp[] = "If"; diff --git a/tensorflow/compiler/tf2xla/resource_util.h b/tensorflow/compiler/tf2xla/resource_util.h index cbc2b9cf91b..4aac73638d6 100644 --- a/tensorflow/compiler/tf2xla/resource_util.h +++ b/tensorflow/compiler/tf2xla/resource_util.h @@ -22,7 +22,6 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/hash/hash.h" #include "absl/strings/str_cat.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/errors.h" diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc index 4fb232f8b89..e40df038bbb 100644 --- a/tensorflow/compiler/tf2xla/shape_util.cc +++ b/tensorflow/compiler/tf2xla/shape_util.cc @@ -94,7 +94,7 @@ Status XLAShapeToTensorShape(const xla::Shape& shape, } *tensor_shape = TensorShape(); for (int i = 0; i < shape.rank(); ++i) { - tensor_shape->AddDim(shape.dimensions(i)); + TF_RETURN_IF_ERROR(tensor_shape->AddDimWithStatus(shape.dimensions(i))); } return OkStatus(); } diff --git a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc index 2ce48b8a72b..464413633a6 100644 --- a/tensorflow/compiler/tf2xla/type_util.cc +++ b/tensorflow/compiler/tf2xla/type_util.cc @@ -16,7 +16,9 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/type_util.h" #include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" namespace tensorflow { @@ -55,6 +57,12 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) { case tensorflow::DT_UINT64: *type = xla::U64; return OkStatus(); + case tensorflow::DT_FLOAT8_E5M2: + *type = xla::F8E5M2; + return OkStatus(); + case tensorflow::DT_FLOAT8_E4M3FN: + *type = xla::F8E4M3FN; + return OkStatus(); case tensorflow::DT_BFLOAT16: *type = xla::BF16; return OkStatus(); @@ -84,6 +92,8 @@ StatusOr EncodePrimitiveTypeAsDataType(xla::PrimitiveType type) { static const absl::flat_hash_map& data_type_map = *new absl::flat_hash_map({ {xla::PRED, DT_BOOL}, + {xla::F8E5M2, DT_FLOAT8_E5M2}, + {xla::F8E4M3FN, DT_FLOAT8_E4M3FN}, {xla::BF16, DT_BFLOAT16}, {xla::F16, DT_HALF}, {xla::F32, DT_FLOAT}, diff --git a/tensorflow/compiler/tf2xla/xla_argument.h b/tensorflow/compiler/tf2xla/xla_argument.h index 6401c05544f..7497153e0bc 100644 --- a/tensorflow/compiler/tf2xla/xla_argument.h +++ b/tensorflow/compiler/tf2xla/xla_argument.h @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" #include "tensorflow/compiler/tf2xla/xla_resource.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_sharding.h" #include "tensorflow/core/framework/tensor.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index b6f15be0891..25e6ab60846 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include #include #include @@ -25,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/shape_inference.h" +#include "tensorflow/compiler/jit/xla_compile_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" #include "tensorflow/compiler/mlir/utils/array_container_utils.h" #include "tensorflow/compiler/tf2xla/graph_compiler.h" @@ -60,6 +62,7 @@ limitations under the License. 
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/protobuf/error_codes.pb.h" #include "tensorflow/core/protobuf/graph_debug_info.pb.h" +#include "tensorflow/core/tpu/tpu_defs.h" #include "tensorflow/core/util/dump_graph.h" namespace tensorflow { @@ -726,6 +729,62 @@ std::vector GetValidControlRets( return valid_control_rets; } +Status XlaCompiler::CompileSingleOp( + const XlaCompiler::CompileOptions& compile_options, + const XlaCompiler::SingleOpCompileArgument& single_op_compile_argument, + absl::Span args, XlaCompiler::CompilationResult* result) { + const std::vector& result_dtypes = + single_op_compile_argument.output_dtypes; + const NodeDef& node_def = single_op_compile_argument.node_def; + TF_ASSIGN_OR_RETURN( + auto graph, + CreateSingleOpGraph(node_def, args, + single_op_compile_argument.output_dtypes)); + + auto compile_with_old_bridge = [&]() { + *result = {}; + return CompileGraph(compile_options, node_def.name(), std::move(graph), + args, result); + }; + + const ConfigProto* config = &(single_op_compile_argument.config_proto); + auto bridge_rollout = GetMlirBridgeRolloutState( + config ? std::optional(*config) : std::nullopt); + if (bridge_rollout == + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED || + node_def.op() == "VarIsInitializedOp" || + (bridge_rollout != + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED && + options_.device_type.type_string() != DEVICE_TPU_XLA_JIT)) { + return compile_with_old_bridge(); + } + + GraphDebugInfo debug_info; + std::vector control_rets; + if (result_dtypes.empty()) { + control_rets.push_back(node_def.name()); + } + + bool mlir_enabled = (bridge_rollout == + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED); + VLOG(1) << "Attempting MLIR bridge." + << (mlir_enabled ? " MLIR is explicitly enabled." : ""); + auto mlir_result = CompileGraphToXlaHlo( + *graph, mlir::SpanToArrayRef(args), control_rets, + options_.device_type.type_string(), compile_options.use_tuple_arg, + /*analyse_graph=*/!mlir_enabled, *options_.flib_def, debug_info, + options_.shape_determination_fns, result); + + if (mlir_result.ok() || mlir_enabled) { + return mlir_result; + } + + VLOG(2) << "Failed second phase of the MLIR bridge. Will " + "retry with the old bridge. 
MLIR bridge compilation status: " + << mlir_result; + return compile_with_old_bridge(); +} + Status XlaCompiler::CompileFunction( const XlaCompiler::CompileOptions& options, const NameAttrList& fn_name_attrs, @@ -916,7 +975,7 @@ Status XlaCompiler::XLAShapeForArgument( } TF_RET_CHECK(absl::holds_alternative(arg.shape)); TensorShape shape; - shape.AddDim(arg.max_array_size); + TF_RETURN_IF_ERROR(shape.AddDimWithStatus(arg.max_array_size)); shape.AppendShape(std::get(arg.shape)); TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, xla_shape)); @@ -934,7 +993,7 @@ Status XlaCompiler::XLAShapeForArgument( } TF_RET_CHECK(absl::holds_alternative(arg.shape)); TensorShape shape; - shape.AddDim(arg.max_array_size); + TF_RETURN_IF_ERROR(shape.AddDimWithStatus(arg.max_array_size)); shape.AppendShape(std::get(arg.shape)); xla::Shape buffer_shape; TF_RETURN_IF_ERROR( @@ -974,6 +1033,20 @@ void XlaCompiler::PopulateArgumentFromResource(const XlaResource& resource, arg->name = resource.name(); } +XlaCompiler::SingleOpCompileArgument::SingleOpCompileArgument( + const OpKernelContext& ctx) { + std::vector output_dtypes(ctx.num_outputs()); + for (int i = 0; i < output_dtypes.size(); ++i) { + output_dtypes[i] = ctx.expected_output_dtype(i); + } + this->output_dtypes = output_dtypes; + this->node_def = ctx.op_kernel().def(); + auto* config_proto = ctx.function_library()->config_proto(); + if (config_proto != nullptr) { + this->config_proto = *config_proto; + } +} + // Builds XLA computations for each of the arguments to the computation. // `args` are the arguments to the computation. Status XlaCompiler::BuildArguments( @@ -1429,6 +1502,13 @@ Status XlaCompiler::CompileGraph( &result->resource_updates, &result->xla_output_shape, result->input_mapping)); + for (const auto& [key, send] : host_compute_sends_) { + *result->host_compute_metadata.add_device_to_host() = send; + } + for (const auto& [key, recv] : host_compute_recvs_) { + *result->host_compute_metadata.add_host_to_device() = recv; + } + VLOG(2) << "Outputs: total: " << context->retvals().size() << " nonconstant: " << num_nonconst_outputs; VLOG(2) << "XLA output shape: " diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 1e424256957..d027326239e 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -101,6 +101,8 @@ class XlaContext; // `tensor_array_gradients` ordered set. class XlaCompiler { public: + // TODO(b/255826209): Remove this alias. Depending on XlaCompiler just to use + // XlaArgument seeems weird and can cause circular dependencies. using Argument = ::tensorflow::XlaArgument; // Options pertaining to an individual call to CompileGraph() or @@ -212,6 +214,10 @@ class XlaCompiler { // This is currently only used to obtain MLIR TPU bridge rollout state. // Can be removed once full rollout is complete. ConfigProto config_proto; + + SingleOpCompileArgument() = default; + + explicit SingleOpCompileArgument(const OpKernelContext& ctx); }; explicit XlaCompiler(Options options); @@ -227,6 +233,11 @@ class XlaCompiler { absl::Span args, CompilationResult* result); + Status CompileSingleOp( + const CompileOptions& options, + const SingleOpCompileArgument& single_op_compile_argument, + absl::Span args, CompilationResult* result); + // Compiles a tensorflow::Graph into an xla::XlaComputation. // Similar to CompileFunction, but takes a Graph as input rather than a // function. 
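As a companion to the new CompileSingleOp entry point declared above, a condensed sketch of the intended call pattern follows; it mirrors the CompileSingleOp test added to xla_compiler_test.cc further down in this patch. The names `ctx` (the OpKernelContext of the op being compiled), `input` (its constant input tensor), and `flib_def` are assumptions of the sketch, which is meant to live in a Status-returning helper inside a full TensorFlow build.

// Capture the node def, expected output dtypes, and config proto of the op.
XlaCompiler::SingleOpCompileArgument single_op_arg(*ctx);

XlaCompiler::Options options;
options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
options.client = xla::ClientLibrary::LocalClientOrDie();
options.flib_def = flib_def;  // FunctionLibraryDefinition owned by the caller
XlaCompiler compiler(options);

// Describe the op's inputs; here a single constant argument.
std::vector<XlaCompiler::Argument> args(1);
args[0].kind = XlaCompiler::Argument::kConstant;
args[0].type = DT_FLOAT;
args[0].shape = TensorShape({1, 2});
args[0].constant_value = input;
args[0].initialized = true;

// CompileSingleOp prefers the MLIR bridge when it is explicitly enabled (or
// when targeting DEVICE_TPU_XLA_JIT) and otherwise falls back to CompileGraph.
XlaCompiler::CompilationResult result;
TF_RETURN_IF_ERROR(compiler.CompileSingleOp(XlaCompiler::CompileOptions(),
                                            single_op_arg, args, &result));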
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 6abdec5b6dd..5231d9e4246 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -41,17 +41,21 @@ limitations under the License. #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/graph_to_functiondef.h" +#include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/kernels/ops_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/version.h" @@ -1920,7 +1924,7 @@ TEST_F(XlaCompilerTest, AliasResourceUpdates) { EXPECT_EQ(alias.entries(0).parameter_number(), 0); } -// Tests that passing in an exact duplicate input to SetDeviceToHostMeatadata +// Tests that passing in an exact duplicate input to SetDeviceToHostMetadata // is not an error. TEST_F(XlaCompilerTest, SetDeviceToHostMetadataExactDuplicate) { XlaCompiler compiler(DefaultOptions()); @@ -1949,7 +1953,7 @@ TEST_F(XlaCompilerTest, SetDeviceToHostMetadataMismatchedDuplicate) { EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT); } -// Tests that passing in an exact duplicate input to SetHostToDeviceMeatadata +// Tests that passing in an exact duplicate input to SetHostToDeviceMetadata // is not an error. 
TEST_F(XlaCompilerTest, SetHostToDeviceMetadataExactDuplicate) { XlaCompiler compiler(DefaultOptions()); @@ -1978,5 +1982,67 @@ TEST_F(XlaCompilerTest, SetHostToDeviceMetadataMismatchedDuplicate) { EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT); } +TEST_F(OpsTestBase, BuildSingleOpCompileArgument) { + TF_EXPECT_OK(NodeDefBuilder("identity_op", "Identity") + .Input(FakeInput(DT_FLOAT)) + .Attr("T", DT_FLOAT) + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + AddInputFromArray(TensorShape({1, 2}), {0, 1}); + TF_EXPECT_OK(RunOpKernel()); + + XlaCompiler::SingleOpCompileArgument arg(*context_); + + EXPECT_THAT(arg.output_dtypes, ::testing::ElementsAreArray({DT_FLOAT})); + EXPECT_EQ(arg.node_def.SerializeAsString(), + context_->op_kernel().def().SerializeAsString()); + EXPECT_EQ(arg.config_proto.ByteSizeLong(), 0); +} + +TEST_F(OpsTestBase, CompileSingleOp) { + TF_EXPECT_OK(NodeDefBuilder("identity_op", "Identity") + .Input(FakeInput(DT_FLOAT)) + .Attr("T", DT_FLOAT) + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + AddInputFromArray(TensorShape({1, 2}), {6.9, 4.2}); + TF_EXPECT_OK(RunOpKernel()); + + XlaCompiler::SingleOpCompileArgument single_op_arg(*context_); + + xla::Client* client = xla::ClientLibrary::LocalClientOrDie(); + XlaOpRegistry::RegisterCompilationKernels(); + FunctionDefLibrary flib; + std::unique_ptr flib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), flib)); + + XlaCompiler::Options options; + options.device_type = DeviceType(DEVICE_CPU_XLA_JIT); + options.client = client; + options.flib_def = flib_def.get(); + + XlaCompiler compiler(options); + + std::vector args(1); + args[0].kind = XlaCompiler::Argument::kConstant; + args[0].type = DT_FLOAT; + args[0].shape = TensorShape({1, 2}); + args[0].constant_value = GetInput(0); + args[0].initialized = true; + + XlaCompiler::CompilationResult result; + TF_EXPECT_OK(compiler.CompileSingleOp(XlaCompiler::CompileOptions(), + single_op_arg, args, &result)); + + // Tests that the generated computation works. + std::unique_ptr actual = + client->Execute(*result.computation, {}).value(); + xla::Literal actual_literal = client->Transfer(*actual).value(); + + xla::Literal expected0 = xla::LiteralUtil::CreateR2({{6.9, 4.2}}); + xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_expression.cc b/tensorflow/compiler/tf2xla/xla_expression.cc index 4725193b7a9..494660d48f0 100644 --- a/tensorflow/compiler/tf2xla/xla_expression.cc +++ b/tensorflow/compiler/tf2xla/xla_expression.cc @@ -112,7 +112,7 @@ xla::XlaOp XlaExpression::AsXlaOp(xla::XlaBuilder* builder) const { }); } -StatusOr XlaExpression::ResolveDynamism(xla::Client* client) const { +StatusOr XlaExpression::ResolveDynamism() const { switch (kind()) { case Kind::kConstant: { // Constant values are considered static. 
@@ -133,9 +133,6 @@ StatusOr XlaExpression::ResolveDynamism(xla::Client* client) const { HumanString()); } - if (!client) - return errors::InvalidArgument("client is required to resolve constant"); - TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape()); // The XLA layout is specified minor to major, and TensorFlow uses a major to diff --git a/tensorflow/compiler/tf2xla/xla_expression.h b/tensorflow/compiler/tf2xla/xla_expression.h index 73173ac8c55..9eb11023ce9 100644 --- a/tensorflow/compiler/tf2xla/xla_expression.h +++ b/tensorflow/compiler/tf2xla/xla_expression.h @@ -131,7 +131,7 @@ class XlaExpression { // ResolveDynamism computes where a value inside this op is dynamic or can be // inferred at compile time. - StatusOr ResolveDynamism(xla::Client* client) const; + StatusOr ResolveDynamism() const; // Returns the shape of the tensor. // The shape of a resource is the shape of a resource handle (i.e., a scalar), diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index 06d73fde192..0e621995cbc 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -23,8 +23,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_sharding.h" #include "tensorflow/compiler/xla/service/computation_placer.h" -#include "tensorflow/compiler/xla/service/hlo_sharding.h" #include "tensorflow/compiler/xla/translate/mhlo_to_hlo/layout_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc index ce765935885..6f45dcf1726 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc @@ -292,27 +292,26 @@ TEST(XlaJitCompiledCpuFunction, CanCompileWithAdditionalPlatform) { const string& Name() const override { return name_; } - se::port::StatusOr> - DescriptionForDevice(int ordinal) const override { + tsl::StatusOr> DescriptionForDevice( + int ordinal) const override { return std::unique_ptr(nullptr); } - se::port::StatusOr ExecutorForDevice( - int ordinal) override { + tsl::StatusOr ExecutorForDevice(int ordinal) override { return nullptr; } - se::port::StatusOr ExecutorForDeviceWithPluginConfig( + tsl::StatusOr ExecutorForDeviceWithPluginConfig( int ordinal, const se::PluginConfig& config) override { return nullptr; } - se::port::StatusOr GetExecutor( + tsl::StatusOr GetExecutor( const se::StreamExecutorConfig& config) override { return nullptr; } - se::port::StatusOr> GetUncachedExecutor( + tsl::StatusOr> GetUncachedExecutor( const se::StreamExecutorConfig& config) override { return std::unique_ptr(nullptr); } diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index cee1f62de6a..0f7373659bd 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -244,6 +244,13 @@ Status XlaOpKernelContext::ConstantInputAsIntScalar( return ConstantInputAsIntScalar(index, out, mode); } +StatusOr XlaOpKernelContext::ConstantInputAsIntScalar( + absl::string_view name, xla::ValueInferenceMode mode) { + int64_t out; + TF_RETURN_IF_ERROR(ConstantInputAsIntScalar(name, 
&out, mode)); + return out; +} + Status XlaOpKernelContext::ConstantInputAsFloatScalar( int index, double* out, xla::ValueInferenceMode mode) { xla::Literal literal; @@ -270,8 +277,7 @@ static Status LiteralToPredVector(const xla::LiteralSlice& literal, Status XlaOpKernelContext::ResolveInputDynamismIntoPred(int index, bool* out) { xla::Literal literal; XlaExpression e = InputExpression(index); - auto* client = compiler() ? compiler()->client() : nullptr; - StatusOr dynamism_or_status = e.ResolveDynamism(client); + StatusOr dynamism_or_status = e.ResolveDynamism(); if (!dynamism_or_status.ok()) { // When failed to resolve dynamism, conservatively consider the value // dynamic. This could happen if the input depends on some ops like @@ -313,8 +319,7 @@ Status XlaOpKernelContext::ResolveInputDynamismReshaped( int index, absl::Span new_dims, xla::Literal* dynamism_literal) { XlaExpression e = InputExpression(index); - auto* client = compiler() ? compiler()->client() : nullptr; - StatusOr dynamism_or_status = e.ResolveDynamism(client); + StatusOr dynamism_or_status = e.ResolveDynamism(); if (!dynamism_or_status.ok()) { xla::Literal true_literal = xla::LiteralUtil::CreateR0(true); // When failed to resolve dynamism, conservatively consider the value diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index b200af35a8d..d6aaa993eb7 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -162,6 +162,10 @@ class XlaOpKernelContext { absl::string_view name, int64_t* out, xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); + StatusOr ConstantInputAsIntScalar( + absl::string_view name, + xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); + // Converts a constant scalar float32 or float64 tensor into a float64. Status ConstantInputAsFloatScalar( int index, double* out, diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index bb76b5d22b9..7f1b5dbd1b9 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -87,7 +87,7 @@ XlaOpRegistry::~XlaOpRegistry() = default; } if (!x.has_device_allowlist && !y.has_device_allowlist) { LOG(WARNING) << "Duplicate registrations of " << x.name - << "with no device allowlists."; + << " with no device allowlists."; return false; } if (x.has_device_allowlist && y.has_device_allowlist) { diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index a4dd68f7db7..68656bb2e74 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -45,6 +45,8 @@ extern const char* const DEVICE_GPU_XLA_JIT; // "GPU_XLA_JIT" extern const char* const DEVICE_XLA_CPU; extern const char* const DEVICE_XLA_GPU; +// Do not include DT_FLOAT8_* as float or numeric types since they are only +// supported in a very limited set of ops. 
constexpr std::array kFloatTypes = { {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}}; constexpr std::array kFloatAndComplexTypes = { @@ -54,15 +56,17 @@ constexpr std::array kNumericTypes = { DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BFLOAT16}}; -constexpr std::array kCpuAllTypes = { - {DT_UINT8, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, - DT_INT16, DT_INT32, DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, - DT_COMPLEX64, DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}}; +constexpr std::array kCpuAllTypes = { + {DT_UINT8, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, + DT_INT8, DT_QINT8, DT_INT16, DT_INT32, DT_QINT32, + DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, + DT_COMPLEX128, DT_BOOL, DT_BFLOAT16, DT_FLOAT8_E5M2, DT_FLOAT8_E4M3FN}}; -constexpr std::array kGpuAllTypes = { - {DT_UINT8, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, - DT_INT16, DT_INT32, DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, - DT_COMPLEX64, DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}}; +constexpr std::array kGpuAllTypes = { + {DT_UINT8, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, + DT_INT8, DT_QINT8, DT_INT16, DT_INT32, DT_QINT32, + DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, + DT_COMPLEX128, DT_BOOL, DT_BFLOAT16, DT_FLOAT8_E5M2, DT_FLOAT8_E4M3FN}}; // Class that manages registrations of operators and devices for the XLA JIT. // Not thread-safe. @@ -330,6 +334,12 @@ class XlaOpRegistry { #define REGISTER_XLA_OP(NAME, OP) \ REGISTER_XLA_OP_UNIQ_HELPER(__COUNTER__, NAME, OP) +#define REGISTER_XLA_CONV_OP(BUILDER, OP) \ + REGISTER_XLA_OP(BUILDER.TypeConstraint("T", GetXlaConvTypesForNonGpu()), OP) \ + REGISTER_XLA_OP(BUILDER.TypeConstraint("T", GetXlaConvTypesForGpu()) \ + .Device(DEVICE_GPU_XLA_JIT), \ + OP) + class XlaOpRegistrationBuilder { public: // Starts an operator registration chain. diff --git a/tensorflow/compiler/tf2xla/xla_op_registry_test.cc b/tensorflow/compiler/tf2xla/xla_op_registry_test.cc index 7b3b15b1af7..4d8e1bc31f8 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry_test.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry_test.cc @@ -82,6 +82,25 @@ TEST(XlaOpRegistryTest, XlaOpRegistrationWithOverride) { } } +TEST(XlaOpRegistryTest, XlaOpRegistrationDeviceKernels) { + XlaOpRegistry::RegisterCompilationKernels(); + auto registered_devices = XlaOpRegistry::BackendNames(); + for (const auto& registered_device : registered_devices) { + auto kernels = XlaOpRegistry::DeviceKernels(registered_device, true); + for (const auto& kernel : kernels) { + if (kernel->op() == "DummyDuplicateOp") { + if (registered_device == DEVICE_CPU_XLA_JIT) { + EXPECT_EQ(kernel->constraint(0).allowed_values().list().type(0), + DT_INT32); + } else { + EXPECT_EQ(kernel->constraint(0).allowed_values().list().type(0), + DT_FLOAT); + } + } + } + } +} + // A dummy generic OpKernel for all backends.
class DummyInfeasibleTypeConstraintOp : public XlaOpKernel { public: diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc index 68f574557d6..bf7e2ecc551 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.cc +++ b/tensorflow/compiler/tf2xla/xla_resource.cc @@ -138,7 +138,7 @@ Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) { } case kTensorArray: { TensorShape ta_shape; - ta_shape.AddDim(max_array_size_); + TF_RETURN_IF_ERROR(ta_shape.AddDimWithStatus(max_array_size_)); ta_shape.AppendShape(shape_); value_ = xla::Broadcast(XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes()); @@ -146,7 +146,7 @@ Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) { } case kStack: { TensorShape ta_shape; - ta_shape.AddDim(max_array_size_); + TF_RETURN_IF_ERROR(ta_shape.AddDimWithStatus(max_array_size_)); ta_shape.AppendShape(shape_); value_ = xla::Tuple(builder, {xla::Broadcast(XlaHelpers::Zero(builder, type_), @@ -171,7 +171,7 @@ Status XlaResource::GetOrCreateTensorArrayGradient(const string& source, std::unique_ptr& gradient = tensor_array_gradients_[source]; if (!gradient) { TensorShape ta_shape; - ta_shape.AddDim(max_array_size_); + TF_RETURN_IF_ERROR(ta_shape.AddDimWithStatus(max_array_size_)); ta_shape.AppendShape(shape_); xla::XlaOp gradient_value = xla::Broadcast(XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes()); diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 877ae56dc88..65adb039a4f 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -1,12 +1,15 @@ -load("//tensorflow:tensorflow.default.bzl", "cc_header_only_library") -load("//tensorflow:tensorflow.default.bzl", "filegroup", "get_compatible_with_portable") +load("//tensorflow/tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") -load("//tensorflow:tensorflow.bzl", "tf_cc_test") load( "//tensorflow/tsl/platform:build_config.bzl", "tf_proto_library", ) load("//third_party/compute_library:build_defs.bzl", "if_enable_acl") +load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_test", "xla_py_proto_library") +load( + "//tensorflow/tsl/platform/default:cuda_build_defs.bzl", + "if_cuda_is_configured", +) package( default_visibility = ["//visibility:public"], @@ -15,10 +18,9 @@ package( package_group( name = "friends", - includes = ["//tensorflow:internal"], + includes = ["//tensorflow/compiler/xla:internal"], packages = [ # copybara:uncomment "//learning/infra/mira/...", - "//third_party/auroraml/...", "//third_party/australis/...", "//third_party/iree/...", "//third_party/mira/...", @@ -33,6 +35,7 @@ package_group( package_group( name = "internal", + includes = ["//tensorflow:internal"], packages = [ "//tensorflow/compiler/xla/...", ], @@ -46,6 +49,8 @@ package_group( ], ) +exports_files(["run_lit.sh"]) + # Filegroup used to collect source files for dependency checking. 
filegroup( name = "c_srcs", @@ -108,7 +113,7 @@ cc_library( ], ) -tf_cc_test( +xla_cc_test( name = "bit_cast_test", srcs = ["bit_cast_test.cc"], deps = [ @@ -136,13 +141,14 @@ cc_library( ":types", ":util", ":xla_data_proto_cc", + "//tensorflow/tsl/platform:float8", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", ], ) -tf_cc_test( +xla_cc_test( name = "comparison_util_test", srcs = ["comparison_util_test.cc"], deps = [ @@ -169,6 +175,18 @@ cc_library( ], ) +cc_library( + name = "frontend_attributes", + srcs = [ + "frontend_attributes.cc", + ], + hdrs = [ + "frontend_attributes.h", + ], + visibility = [":friends"], + deps = ["//tensorflow/compiler/xla/hlo/ir:hlo"], +) + cc_library( name = "test", testonly = 1, @@ -218,7 +236,7 @@ cc_library( ], ) -tf_cc_test( +xla_cc_test( name = "status_macros_test", size = "small", srcs = ["status_macros_test.cc"], @@ -248,13 +266,13 @@ cc_library( "statusor.h", ], linkopts = select({ - "//tensorflow:freebsd": ["-lexecinfo"], + "//tensorflow/tsl:freebsd": ["-lexecinfo"], "//conditions:default": [], }), visibility = ["//visibility:public"], deps = [ ":status", - "//tensorflow/compiler/xla/stream_executor/lib", + "//tensorflow/tsl/platform:statusor", ], ) @@ -294,13 +312,14 @@ cc_library( ], ) -tf_cc_test( +xla_cc_test( name = "util_test", srcs = ["util_test.cc"], deps = [ ":test", ":types", ":util", + "//tensorflow/tsl/platform:float8", "//tensorflow/tsl/platform:logging", "//tensorflow/tsl/platform:test_main", ], @@ -318,7 +337,7 @@ cc_library( ], ) -tf_cc_test( +xla_cc_test( name = "permutation_util_test", srcs = ["permutation_util_test.cc"], deps = [ @@ -348,7 +367,7 @@ cc_library( ], ) -tf_cc_test( +xla_cc_test( name = "iterator_util_test", srcs = ["iterator_util_test.cc"], deps = [ @@ -379,6 +398,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":permutation_util", + ":printer", ":status", ":status_macros", ":statusor", @@ -386,6 +406,7 @@ cc_library( ":util", ":xla_data_proto_cc", "//tensorflow/tsl/platform:env", + "//tensorflow/tsl/platform:float8", "//tensorflow/tsl/platform:logging", "//tensorflow/tsl/platform:platform_port", "@com_google_absl//absl/algorithm:container", @@ -394,6 +415,7 @@ cc_library( "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", ], ) @@ -412,7 +434,7 @@ cc_library( ], ) -tf_cc_test( +xla_cc_test( name = "shape_test", srcs = ["shape_test.cc"], deps = [ @@ -430,7 +452,7 @@ tf_cc_test( ], ) -tf_cc_test( +xla_cc_test( name = "shape_util_test", srcs = ["shape_util_test.cc"], deps = [ @@ -442,6 +464,7 @@ tf_cc_test( ":types", ":util", ":xla_data_proto_cc", + "//tensorflow/tsl/platform:env", "//tensorflow/tsl/platform:protobuf", "//tensorflow/tsl/platform:test_benchmark", "//tensorflow/tsl/platform:test_main", @@ -449,7 +472,7 @@ tf_cc_test( ], ) -tf_cc_test( +xla_cc_test( name = "primitive_util_test", srcs = ["primitive_util_test.cc"], deps = [ @@ -465,7 +488,7 @@ tf_cc_test( ], ) -tf_cc_test( +xla_cc_test( name = "layout_util_test", srcs = ["layout_util_test.cc"], deps = [ @@ -476,7 +499,7 @@ tf_cc_test( ], ) -tf_cc_test( +xla_cc_test( name = "layout_test", srcs = ["layout_test.cc"], deps = [ @@ -488,7 +511,7 @@ tf_cc_test( ], ) -tf_cc_test( +xla_cc_test( name = "index_util_test", srcs = ["index_util_test.cc"], deps = [ @@ -509,6 +532,7 @@ cc_library( ":array3d", 
":array4d", ":permutation_util", + ":printer", ":shape_util", ":status_macros", ":types", @@ -516,6 +540,7 @@ cc_library( ":xla_data_proto_cc", "//tensorflow/tsl/lib/core:bitmap", "//tensorflow/tsl/platform:errors", + "//tensorflow/tsl/platform:float8", "//tensorflow/tsl/platform:logging", "//tensorflow/tsl/platform:platform_port", "//tensorflow/tsl/platform:protobuf", @@ -523,14 +548,13 @@ cc_library( "//tensorflow/tsl/util:byte_swap_array", "@com_google_absl//absl/base", "@com_google_absl//absl/functional:function_ref", - "@com_google_absl//absl/hash", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", ], ) -tf_cc_test( +xla_cc_test( name = "literal_test", srcs = ["literal_test.cc"], deps = [ @@ -543,6 +567,7 @@ tf_cc_test( ":test", ":types", "//tensorflow/tsl/lib/core:status_test_util", + "//tensorflow/tsl/platform:float8", "//tensorflow/tsl/platform:test", "//tensorflow/tsl/platform:test_main", "@com_google_absl//absl/base", @@ -591,8 +616,10 @@ cc_library( ":error_spec", ":literal", ":literal_util", + ":shape_util", ":util", "//tensorflow/tsl/platform:env", + "//tensorflow/tsl/platform:float8", "//tensorflow/tsl/platform:status", "@com_google_absl//absl/base", "@com_google_absl//absl/strings", @@ -640,7 +667,7 @@ cc_library( ], ) -tf_cc_test( +xla_cc_test( name = "array_test", srcs = ["array_test.cc"], deps = [ @@ -663,7 +690,7 @@ cc_library( ], ) -tf_cc_test( +xla_cc_test( name = "array2d_test", srcs = ["array2d_test.cc"], deps = [ @@ -684,7 +711,7 @@ cc_library( ], ) -tf_cc_test( +xla_cc_test( name = "array3d_test", srcs = ["array3d_test.cc"], deps = [ @@ -709,7 +736,7 @@ cc_library( ], ) -tf_cc_test( +xla_cc_test( name = "array4d_test", srcs = ["array4d_test.cc"], deps = [ @@ -785,7 +812,7 @@ cc_library( ], ) -tf_cc_test( +xla_cc_test( name = "text_literal_reader_test", srcs = ["text_literal_reader_test.cc"], deps = [ @@ -818,7 +845,7 @@ cc_library( ], ) -tf_cc_test( +xla_cc_test( name = "text_literal_writer_test", srcs = ["text_literal_writer_test.cc"], deps = [ @@ -854,7 +881,7 @@ cc_library( ], ) -tf_cc_test( +xla_cc_test( name = "shape_tree_test", srcs = ["shape_tree_test.cc"], deps = [ @@ -875,6 +902,7 @@ cc_library( hdrs = ["shape_layout.h"], visibility = ["//visibility:public"], deps = [ + ":printer", ":shape_util", ":types", ":util", @@ -900,7 +928,7 @@ cc_library( ], ) -tf_cc_test( +xla_cc_test( name = "window_util_test", srcs = ["window_util_test.cc"], deps = [ @@ -926,7 +954,7 @@ cc_library( "//tensorflow/compiler/xla/client:padding", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/hlo/evaluator:hlo_evaluator", - "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/service:shape_inference", "//tensorflow/tsl/lib/math:math_util", "//tensorflow/tsl/platform:logging", @@ -936,7 +964,7 @@ cc_library( ], ) -tf_cc_test( +xla_cc_test( name = "reference_util_test", srcs = ["reference_util_test.cc"], deps = [ @@ -971,14 +999,13 @@ cc_library( ], ) -tf_cc_test( +xla_cc_test( name = "parse_flags_from_env_test", srcs = ["parse_flags_from_env_test.cc"], deps = [ ":parse_flags_from_env", "//tensorflow/tsl/platform:env", - "//tensorflow/tsl/platform:env_impl", "//tensorflow/tsl/platform:logging", "//tensorflow/tsl/platform:subprocess", "//tensorflow/tsl/platform:test", @@ -1021,7 +1048,7 @@ cc_library( ], ) -tf_cc_test( +xla_cc_test( name = "debug_options_parsers_test", size = "small", srcs = [ @@ -1031,7 +1058,7 @@ tf_cc_test( deps = [ 
":xla_proto_cc", - "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/tsl/platform:logging", "//tensorflow/tsl/platform:test", "@com_google_absl//absl/container:flat_hash_map", @@ -1052,7 +1079,7 @@ cc_library( ], ) -tf_cc_test( +xla_cc_test( name = "refcounting_hash_map_test", srcs = ["refcounting_hash_map_test.cc"], deps = [ @@ -1073,6 +1100,68 @@ cc_library( hdrs = ["side_effect_util.h"], ) +cc_library( + name = "lazy", + hdrs = ["lazy.h"], + deps = ["@com_google_absl//absl/functional:any_invocable"], +) + +tf_proto_library( + name = "autotune_results_proto", + srcs = ["autotune_results.proto"], + cc_api_version = 2, + protodeps = [ + "//tensorflow/tsl/protobuf:autotuning_proto", + ], + visibility = ["//visibility:public"], +) + +xla_py_proto_library( + name = "autotune_results_py_pb2", + api_version = 2, + visibility = ["//visibility:public"], + deps = [ + ":autotune_results_proto", + ], +) + +cc_library( + name = "autotune_serialize", + srcs = if_cuda_is_configured(["autotune_serialize.cc"]), + hdrs = if_cuda_is_configured(["autotune_serialize.h"]), + # TODO(aminim): There appears to be an in-progress refactoring in TF/XLA to + # mark rules as compatible_with GCE. The rules that this depends on are + # not yet marked as compatible, so this one can't be either (yet). + compatible_with = [], + visibility = ["//visibility:public"], + deps = if_cuda_is_configured([ + ":autotune_results_proto_cc", + ":statusor", + "//tensorflow/compiler/xla/service/gpu:gemm_algorithm_picker", + "//tensorflow/compiler/xla/service/gpu:gpu_conv_algorithm_picker", + ]), +) + +cc_library( + name = "printer", + srcs = ["printer.cc"], + hdrs = ["printer.h"], + visibility = ["//visibility:public"], + deps = [ + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + ], +) + +filegroup( + name = "litfiles", + srcs = [ + "runlit.cfg.py", + "runlit.site.cfg.py", + ], + visibility = ["//tensorflow/compiler/xla:__subpackages__"], +) + # ----------------------------------------------------------------------------- # copybara:uncomment_begin(google-only) @@ -1092,17 +1181,3 @@ cc_library( # deps = [":xla_proto"], # ) # copybara:uncomment_end - -# This is a headers target that extra XLA devices can use to prevent circular dependencies. Devices that are compiled as separate shared objects can also use it to prevent linking of library code. -cc_header_only_library( - name = "xla_headers_lib", - visibility = ["//visibility:public"], - deps = [ - ":xla_data_proto_cc", - ":xla_proto_cc", - "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:stream_executor_headers_lib", - ], -) diff --git a/tensorflow/compiler/xla/README.md b/tensorflow/compiler/xla/README.md index 029a2e0081f..9799312625c 100644 --- a/tensorflow/compiler/xla/README.md +++ b/tensorflow/compiler/xla/README.md @@ -5,3 +5,91 @@ XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear algebra that optimizes TensorFlow computations. See the [documentation](./g3doc/index.md). + +This directory is currently migrating to [OpenXLA](https://github.com/openxla/) +and will be the root of the [openxla/xla](https://github.com/openxla/xla) +repository. + +== Directory Structure == + +We're currently re-organizing the directory structure, the end result should be +that no sources are directly present at the top-level. 
Here is the current plan +for the directory layout: + +* backends/ (created from directories under xla/service) + * cpu/ + * gpu/ + * interpreter/ + * ... +* hlo/ (created from xla/service/ mostly, no sources expected directly here) + * client/ (created from xla/client) + * evaluator/ (created from the relevant files in xla/service) + * experimental/ (created from xla/experimental) + * ir/ (created from the relevant files in xla/service) + * python/ (created from xla/python) + * tests/ (created from xla/tests) + * transforms/ (created from the relevant files in xla/service) + * utils/ (created from the relevant files in xla/service) +* mlir/ (also exported as the root of https://github.com/tensorflow/mlir-hlo + and building with CMake) + * CMakeLists.txt (just like now for mlir-hlo repo). + * backends/ (same as xla/backends/ but for the MLIR specific bits: this is + a short-term solution pending more convergence / XLA Next) + * cpu + * gpu (populated from /compiler/xla/mlir/transforms/gpu/passes.td, + will contain all the glue for e2e GPU compilation) + * bindings/ + * c/ (bootstrapped from mlir/hlo/{include,lib}/mlir-hlo-c) + * python/ (bootstrapped from mlir/hlo/python, should talk about some + low-level LAX?) + * integration_tests/ (to be defined / refined) + * tools/ (xla-opt, fuzzer, ir-reducer, interpreter/evaluator) + * transforms/ (generic / cross dialect transforms) + * utils/ +* // below are dialects and transforms folders + * framework/ (moved from compiler/mlir/xla/ir/xla_framework_ops.td) + * gml_st + * gmlst-opt.cc + * gmlst-runner.cc (runner tool that can execute IR at ~gmlst level) + * ir/ + * integration_test (tests that run things: Tensor(s) in -> Tensor(s) + out) + * test (IR -> IR tests for passes interaction) + * transforms/ + * bufferize_tiled_loop/ + * bufferize_tiled_loop.cc + * bufferize_tiled_loop.h + * ... + * lhlo_gpu/ + * mhlo/ + * mhlo-opt.cc + * analysis/ + * dataflow/ + * dataflow.h + * dataflow.cc + * test_pass.cc // test_only target, linked into opt tool for + testing only. + * integration_test (tests that run things: Tensor(s) in -> Tensor(s) + out) + * ir/ (dialect definition) + * test (IR -> IR tests for passes interaction) + * transforms/ + * materialize_broadcasts/ + * materialize_broadcasts.cc + * materialize_broadcasts.h // headers stays with the source + * broadcast_analysis.{cc, h} // private analysis/utils needed + for this pass + * test/ (.mlir unit-tests are collocated with the pass + itself). + * … + * passes.td // enables group registration for all passes. + * utils/ + * thlo/ + * runtime/ +* pjrt/ (created from xla/pjrt) +* rpc/ (created from xla/rpc) +* runtime/ +* stream_executor/ (moved from TensorFlow) +* third_party/ (vendoring of TSL base library) +* tools/ (created from mlir/hlo/tools and xla/tools) +* translate/ (StableHLO to MHLO, MHLO to HLO, HLO to MHLO, MHLO to TOSA) diff --git a/tensorflow/compiler/xla/array.h b/tensorflow/compiler/xla/array.h index 6e20cda083b..bdfc8d687e6 100644 --- a/tensorflow/compiler/xla/array.h +++ b/tensorflow/compiler/xla/array.h @@ -40,29 +40,10 @@ namespace xla { namespace array_impl { -// conjunction -// -// Performs a compile-time logical AND operation on the passed types (which -// must have `::value` members convertible to `bool`. Short-circuits if it -// encounters any `false` members (and does not compare the `::value` members -// of any remaining arguments). -// -// This metafunction is designed to be a drop-in replacement for the C++17 -// `std::conjunction` metafunction. 
-template -struct conjunction; - -template -struct conjunction - : std::conditional, T>::type {}; - -template <> -struct conjunction<> : std::true_type {}; - // A type trait that is valid when all elements in a parameter pack are of // integral type. Not using an alias template to work around MSVC 14.00 bug. template -struct pack_is_integral : conjunction...> {}; +struct pack_is_integral : std::conjunction...> {}; // Compares three same-sized vectors elementwise. For each item in `values`, // returns false if any of values[i] is outside the half-open range [starts[i], @@ -139,10 +120,12 @@ class Array { CHECK(idx == num_elements()); } - // Creates a 2D array of a floating-point type (half, bfloat16, float, + // Creates a 2D array of a floating-point type (float8, half, bfloat16, float, // or double) from an initializer list of float values. template ::value || + (std::is_same::value || + std::is_same::value || + std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value) && diff --git a/tensorflow/compiler/xla/array2d.h b/tensorflow/compiler/xla/array2d.h index 38f63603728..77e3c9c94e8 100644 --- a/tensorflow/compiler/xla/array2d.h +++ b/tensorflow/compiler/xla/array2d.h @@ -50,10 +50,12 @@ class Array2D : public Array { Array2D(std::initializer_list> values) : Array(values) {} - // Creates an array of a floating-point type (half, bfloat16, float, + // Creates an array of a floating-point type (float8, half, bfloat16, float, // or double) from the given nested initializer list of float values. template ::value || + (std::is_same::value || + std::is_same::value || + std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value) && diff --git a/tensorflow/compiler/xla/autotune_results.proto b/tensorflow/compiler/xla/autotune_results.proto new file mode 100644 index 00000000000..125b28cee79 --- /dev/null +++ b/tensorflow/compiler/xla/autotune_results.proto @@ -0,0 +1,51 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla; + +import "tensorflow/tsl/protobuf/autotuning.proto"; + +// A collection of algorithms for particular dot/convs. Usually this is "the +// best" algorithm for the particular dot/conv, although that's not strictly +// required. +// +// Users don't interact with this proto directly. It's used internally to +// facilitate ahead-of-time autotuning -- The string used by +// xla::{Serialize,Load}AutotuneResults is, internally, a serialization of this +// proto. +// +// LINT.IfChange +message AutotuneResults { + message Entry { + string device = 1; + string hlo = 2; + + // nb: These results are always tied to a particular version of + // cublas/cudnn, but this is *especially* true for cublasLt results. For + // cublasLt gemms, the result is an index into the list of candidate + // algorithms returned by cublasLt. 
Different version of cublasLt -> + // different list of algos -> different interpretation of results! + tensorflow.AutotuneResult result = 3; + } + + int32 version = 1; + repeated Entry dots = 2; + repeated Entry convs = 3; +} +// LINT.ThenChange( +// "autotune_serialize.cc:version" +// ) diff --git a/tensorflow/compiler/xla/autotune_serialize.cc b/tensorflow/compiler/xla/autotune_serialize.cc new file mode 100644 index 00000000000..149fdf9f24f --- /dev/null +++ b/tensorflow/compiler/xla/autotune_serialize.cc @@ -0,0 +1,63 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/autotune_serialize.h" + +#include + +#include "tensorflow/compiler/xla/autotune_results.pb.h" +#include "tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h" + +namespace xla { +namespace { + +// Bump this version whenever you change the structure of the results. +// LINT.IfChange(version) +constexpr int kVersion = 1; +// LINT.ThenChange() + +} // anonymous namespace + +Status LoadAutotuneResults(absl::string_view data) { + AutotuneResults results; + // The cast here is necessary for MacOS builds. + if (!results.ParseFromString(std::string(data))) { // NOLINT + return tsl::errors::InvalidArgument( + "Failed to parse autotune results string."); + } + if (results.version() != kVersion) { + return tsl::errors::InvalidArgument(absl::StrFormat( + "Version mismatch in autotune results. Expected %d but was %d", + kVersion, results.version())); + } + + TF_RETURN_IF_ERROR(gpu::GpuConvAlgorithmPicker::LoadAutotuneResults(results)); + TF_RETURN_IF_ERROR(gpu::GemmAlgorithmPicker::LoadAutotuneResults(results)); + return OkStatus(); +} + +StatusOr SerializeAutotuneResults() { + AutotuneResults results; + results.set_version(kVersion); + + TF_RETURN_IF_ERROR( + gpu::GpuConvAlgorithmPicker::WriteAutotuneResults(&results)); + TF_RETURN_IF_ERROR(gpu::GemmAlgorithmPicker::WriteAutotuneResults(&results)); + + return results.SerializeAsString(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/autotune_serialize.h b/tensorflow/compiler/xla/autotune_serialize.h new file mode 100644 index 00000000000..8d555b90575 --- /dev/null +++ b/tensorflow/compiler/xla/autotune_serialize.h @@ -0,0 +1,70 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_AUTOTUNE_SERIALIZE_H_ +#define TENSORFLOW_COMPILER_XLA_AUTOTUNE_SERIALIZE_H_ + +#include <string> + +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// Functions to save/load XLA's autotuning results. +// +// This is used for ahead-of-time autotuning. Specifically: +// +// When XLA calls cublas (for matmuls, aka "gemm" or "dot") or cudnn (for +// convolutions), it usually has to choose an "algorithm" for the particular +// dot/conv. XLA queries cublas/cudnn for a list of candidate algorithms. Then +// it runs all of them and picks the fastest one. This is what we call +// "autotuning". It happens in GemmAlgorithmPicker and GpuConvAlgorithmPicker. +// +// Autotuning is necessary to get good performance for dot/conv. But it also +// has some disadvantages. +// +// - Because it relies on timing data, it is fundamentally nondeterministic. +// But even if two algorithms have similar runtimes, our choice of algorithm +// may be visible to the user: Different algorithms can have different +// numerics, and sometimes they can even have different bugs! +// +// - Trying all the candidate algorithms can be slow, especially when some +// of the candidates are "very bad" and run especially slowly compared to the +// optimal candidate. This slows down compilation. +// +// To address the disadvantages above, we allow users to save/restore the +// autotuning choices that XLA has made, using the functions below. +// +// Loading autotuning results does not erase existing autotuning choices, but in +// the event of a disagreement between the existing data and the new data, the +// new algorithm is chosen. +// +// Note that even if you call LoadAutotuneResults(), if XLA encounters a +// dot/conv that is *not* covered by the loaded data, it will go ahead and +// autotune it like normal. In other words, the behavior of XLA should be +// identical with or without ahead-of-time autotuning, modulo nondeterminism. +// +// This is important if you want to be able to use the same autotuning file with +// different versions of XLA, because as XLA changes, exactly which dots/convs +// it wants to run can also change. For example, XLA might change the conv +// padding heuristics it uses, and we don't want that to mean that all users of +// ahead-of-time autotuning are broken.
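For illustration, a minimal sketch of the save/restore flow described above, using the two functions declared just below; the helper names and the file I/O are illustrative assumptions, not XLA APIs.

#include <fstream>
#include <string>

#include "tensorflow/compiler/xla/autotune_serialize.h"
#include "tensorflow/tsl/platform/statusor.h"

// Hypothetical helper: dump the autotune choices made while compiling
// representative modules ahead of time.
xla::Status SaveAutotuneCache(const std::string& path) {
  TF_ASSIGN_OR_RETURN(std::string blob, xla::SerializeAutotuneResults());
  std::ofstream(path, std::ios::binary) << blob;
  return tsl::OkStatus();
}

// Hypothetical helper: load the choices before compiling, so
// GemmAlgorithmPicker and GpuConvAlgorithmPicker can reuse them instead of
// re-timing every candidate algorithm.
xla::Status RestoreAutotuneCache(const std::string& serialized) {
  return xla::LoadAutotuneResults(serialized);
}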
+// +StatusOr SerializeAutotuneResults(); +Status LoadAutotuneResults(absl::string_view data); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_AUTOTUNE_SERIALIZE_H_ diff --git a/tensorflow/compiler/xla/backends/interpreter/BUILD b/tensorflow/compiler/xla/backends/interpreter/BUILD index b3f82448e81..f431901a4e8 100644 --- a/tensorflow/compiler/xla/backends/interpreter/BUILD +++ b/tensorflow/compiler/xla/backends/interpreter/BUILD @@ -5,6 +5,7 @@ load( ) package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:public"], licenses = ["notice"], ) @@ -32,6 +33,7 @@ cc_library( "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/service:algebraic_simplifier", "//tensorflow/compiler/xla/service:batchnorm_expander", "//tensorflow/compiler/xla/service:cholesky_expander", @@ -43,7 +45,6 @@ cc_library( "//tensorflow/compiler/xla/service:eigh_expander", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", - "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_constant_folding", "//tensorflow/compiler/xla/service:hlo_cost_analysis", "//tensorflow/compiler/xla/service:hlo_cse", @@ -84,16 +85,16 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_proto_cc", + "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/service:dynamic_dimension_inference", "//tensorflow/compiler/xla/service:executable", - "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_execution_profile", "//tensorflow/compiler/xla/service:maybe_owning_device_memory", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/compiler/xla/stream_executor", "//tensorflow/compiler/xla/stream_executor:event", - "//tensorflow/compiler/xla/stream_executor/lib", + "//tensorflow/tsl/platform:statusor", ], ) @@ -111,8 +112,8 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/hlo/evaluator:hlo_evaluator", + "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/service:executable", - "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_cost_analysis", "//tensorflow/compiler/xla/service:hlo_execution_profile", "//tensorflow/compiler/xla/service:hlo_module_config", @@ -137,7 +138,8 @@ cc_library( ":executor", ":platform_id", "//tensorflow/compiler/xla/stream_executor", - "//tensorflow/compiler/xla/stream_executor/lib", + "//tensorflow/compiler/xla/stream_executor/platform", + "//tensorflow/tsl/platform:status", "@com_google_absl//absl/strings:str_format", ], alwayslink = True, # Registers itself with the MultiPlatformManager. 
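A minimal sketch of how a statically registered platform like this one is later looked up through MultiPlatformManager; it assumes the interpreter platform's registered name is "Interpreter", and the helper name is illustrative.

#include "tensorflow/compiler/xla/stream_executor/multi_platform_manager.h"
#include "tensorflow/tsl/platform/statusor.h"

// Illustrative lookup of the interpreter platform after static registration.
tsl::StatusOr<stream_executor::StreamExecutor*> GetInterpreterExecutor() {
  TF_ASSIGN_OR_RETURN(
      stream_executor::Platform* platform,
      stream_executor::MultiPlatformManager::PlatformWithName("Interpreter"));
  // Ordinal 0: the interpreter reports a single device.
  return platform->ExecutorForDevice(0);
}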
@@ -155,6 +157,7 @@ cc_library( "//tensorflow/compiler/xla/stream_executor/host:host_stream", "//tensorflow/compiler/xla/stream_executor/host:host_timer", "//tensorflow/tsl/platform:logging", + "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/xla/backends/interpreter/compiler.h b/tensorflow/compiler/xla/backends/interpreter/compiler.h index 7f7b2717285..73423a0964a 100644 --- a/tensorflow/compiler/xla/backends/interpreter/compiler.h +++ b/tensorflow/compiler/xla/backends/interpreter/compiler.h @@ -20,10 +20,10 @@ limitations under the License. #include #include "tensorflow/compiler/xla/backends/interpreter/platform_id.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/backends/interpreter/executable.cc b/tensorflow/compiler/xla/backends/interpreter/executable.cc index e47c19f282c..d2d5a3063b5 100644 --- a/tensorflow/compiler/xla/backends/interpreter/executable.cc +++ b/tensorflow/compiler/xla/backends/interpreter/executable.cc @@ -24,9 +24,9 @@ limitations under the License. #include "tensorflow/compiler/xla/backends/interpreter/executable_base.h" #include "tensorflow/compiler/xla/backends/interpreter/executor.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_util.h" diff --git a/tensorflow/compiler/xla/backends/interpreter/executable.h b/tensorflow/compiler/xla/backends/interpreter/executable.h index 96ff3465ea2..9690159e68c 100644 --- a/tensorflow/compiler/xla/backends/interpreter/executable.h +++ b/tensorflow/compiler/xla/backends/interpreter/executable.h @@ -21,10 +21,10 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/xla/backends/interpreter/executable_base.h" #include "tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/service_executable_run_options.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" diff --git a/tensorflow/compiler/xla/backends/interpreter/executable_base.cc b/tensorflow/compiler/xla/backends/interpreter/executable_base.cc index 63af800b4c4..7abcd9aa1f3 100644 --- a/tensorflow/compiler/xla/backends/interpreter/executable_base.cc +++ b/tensorflow/compiler/xla/backends/interpreter/executable_base.cc @@ -18,17 +18,17 @@ limitations under the License. 
#include #include +#include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/compiler/xla/stream_executor/platform.h" #include "tensorflow/compiler/xla/stream_executor/stream.h" #include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h" +#include "tensorflow/tsl/platform/statusor.h" namespace xla { namespace interpreter { diff --git a/tensorflow/compiler/xla/backends/interpreter/executable_base.h b/tensorflow/compiler/xla/backends/interpreter/executable_base.h index ad6c5d7bb30..04c979a9fd9 100644 --- a/tensorflow/compiler/xla/backends/interpreter/executable_base.h +++ b/tensorflow/compiler/xla/backends/interpreter/executable_base.h @@ -19,11 +19,11 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/service_executable_run_options.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/backends/interpreter/executor.cc b/tensorflow/compiler/xla/backends/interpreter/executor.cc index a2dbfc8d966..9845e192127 100644 --- a/tensorflow/compiler/xla/backends/interpreter/executor.cc +++ b/tensorflow/compiler/xla/backends/interpreter/executor.cc @@ -16,7 +16,9 @@ limitations under the License. #include "tensorflow/compiler/xla/backends/interpreter/executor.h" #include +#include +#include "absl/functional/any_invocable.h" #include "tensorflow/compiler/xla/status_macros.h" namespace stream_executor { @@ -53,15 +55,15 @@ bool XlaInterpreterExecutor::Memcpy(Stream *stream, void *host_dst, uint64_t size) { AsExecutorStream(stream)->EnqueueTask([this, host_dst, dev_src, size]() { // Ignore errors. - port::Status ok = SynchronousMemcpy(host_dst, dev_src, size); + tsl::Status ok = SynchronousMemcpy(host_dst, dev_src, size); }); - port::Status status = AsExecutorStream(stream)->BlockUntilDone(); + tsl::Status status = AsExecutorStream(stream)->BlockUntilDone(); if (status.ok()) { return true; } - // TODO(b/199316985): Return 'Status' instead of 'bool', so we don't need to - // throw away error information here. + // TODO(b/199316985): Return 'tsl::Status' instead of 'bool', so we don't need + // to throw away error information here. LOG(WARNING) << "Memcpy: error on stream: " << status; return false; } @@ -70,48 +72,49 @@ bool XlaInterpreterExecutor::Memcpy(Stream *stream, DeviceMemoryBase *dev_dst, const void *host_src, uint64_t size) { AsExecutorStream(stream)->EnqueueTask([this, dev_dst, host_src, size]() { // Ignore errors. 
- port::Status ok = SynchronousMemcpy(dev_dst, host_src, size); + tsl::Status ok = SynchronousMemcpy(dev_dst, host_src, size); }); - port::Status status = AsExecutorStream(stream)->BlockUntilDone(); + tsl::Status status = AsExecutorStream(stream)->BlockUntilDone(); if (status.ok()) { return true; } - // TODO(b/199316985): Return 'Status' instead of 'bool', so we don't need to - // throw away error information here. + // TODO(b/199316985): Return 'tsl::Status' instead of 'bool', so we don't need + // to throw away error information here. LOG(WARNING) << "Memcpy: error on stream: " << status; return false; } -port::Status XlaInterpreterExecutor::SynchronousMemcpy( - DeviceMemoryBase *dev_dst, const void *host_src, uint64_t size) { +tsl::Status XlaInterpreterExecutor::SynchronousMemcpy(DeviceMemoryBase *dev_dst, + const void *host_src, + uint64_t size) { memcpy(dev_dst->opaque(), host_src, size); return ::tsl::OkStatus(); } -port::Status XlaInterpreterExecutor::SynchronousMemcpy( +tsl::Status XlaInterpreterExecutor::SynchronousMemcpy( void *host_dst, const DeviceMemoryBase &dev_src, uint64_t size) { memcpy(host_dst, dev_src.opaque(), size); return ::tsl::OkStatus(); } bool XlaInterpreterExecutor::HostCallback( - Stream *stream, std::function callback) { - AsExecutorStream(stream)->EnqueueTaskWithStatus(callback); + Stream *stream, absl::AnyInvocable callback) { + AsExecutorStream(stream)->EnqueueTaskWithStatus(std::move(callback)); return true; } bool XlaInterpreterExecutor::CreateStreamDependency(Stream *dependent, Stream *other) { - AsExecutorStream(dependent)->EnqueueTask( + AsExecutorStream(dependent)->EnqueueTaskWithStatus( [other]() { return other->BlockHostUntilDone(); }); - port::Status status = AsExecutorStream(dependent)->BlockUntilDone(); + tsl::Status status = AsExecutorStream(dependent)->BlockUntilDone(); if (status.ok()) { return true; } - // TODO(b/199316985): Return 'Status' instead of 'bool', so we don't need to - // throw away error information here. + // TODO(b/199316985): Return 'tsl::Status' instead of 'bool', so we don't need + // to throw away error information here. LOG(WARNING) << "CreateStreamDependency: error on stream: " << status; return false; } @@ -126,11 +129,11 @@ bool XlaInterpreterExecutor::StopTimer(Stream *stream, Timer *timer) { return true; } -port::Status XlaInterpreterExecutor::BlockHostUntilDone(Stream *stream) { +tsl::Status XlaInterpreterExecutor::BlockHostUntilDone(Stream *stream) { return AsExecutorStream(stream)->BlockUntilDone(); } -port::StatusOr> +tsl::StatusOr> XlaInterpreterExecutor::CreateDeviceDescription(int device_ordinal) { internal::DeviceDescriptionBuilder builder; diff --git a/tensorflow/compiler/xla/backends/interpreter/executor.h b/tensorflow/compiler/xla/backends/interpreter/executor.h index 1d78220728f..290685579ba 100644 --- a/tensorflow/compiler/xla/backends/interpreter/executor.h +++ b/tensorflow/compiler/xla/backends/interpreter/executor.h @@ -19,9 +19,9 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_BACKENDS_INTERPRETER_EXECUTOR_H_ #define TENSORFLOW_COMPILER_XLA_BACKENDS_INTERPRETER_EXECUTOR_H_ -#include #include +#include "absl/functional/any_invocable.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/stream_executor/blas.h" @@ -52,18 +52,18 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { explicit XlaInterpreterExecutor(const PluginConfig &plugin_config); ~XlaInterpreterExecutor() override; - port::Status Init(int device_ordinal, DeviceOptions device_options) override { + tsl::Status Init(int device_ordinal, DeviceOptions device_options) override { return ::tsl::OkStatus(); } - port::Status GetKernel(const MultiKernelLoaderSpec &spec, - KernelBase *kernel) override { - return port::UnimplementedError("Not Implemented"); + tsl::Status GetKernel(const MultiKernelLoaderSpec &spec, + KernelBase *kernel) override { + return tsl::errors::Unimplemented("Not Implemented"); } - port::Status Launch(Stream *stream, const ThreadDim &thread_dims, - const BlockDim &block_dims, const KernelBase &kernel, - const KernelArgsArrayBase &args) override { - return port::UnimplementedError("Not Implemented"); + tsl::Status Launch(Stream *stream, const ThreadDim &thread_dims, + const BlockDim &block_dims, const KernelBase &kernel, + const KernelArgsArrayBase &args) override { + return tsl::errors::Unimplemented("Not Implemented"); } DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space) override; @@ -88,59 +88,56 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { return false; } - port::Status MemZero(Stream *stream, DeviceMemoryBase *location, - uint64_t size) override { - return port::InternalError("Interpreter can not memzero"); + tsl::Status MemZero(Stream *stream, DeviceMemoryBase *location, + uint64_t size) override { + return tsl::errors::Internal("Interpreter can not memzero"); } - port::Status Memset(Stream *stream, DeviceMemoryBase *location, - uint8_t pattern, uint64_t size) override { - return port::InternalError("Interpreter can not memset"); + tsl::Status Memset(Stream *stream, DeviceMemoryBase *location, + uint8_t pattern, uint64_t size) override { + return tsl::errors::Internal("Interpreter can not memset"); } - port::Status Memset32(Stream *stream, DeviceMemoryBase *location, - uint32_t pattern, uint64_t size) override { - return port::InternalError("Interpreter can not memset"); + tsl::Status Memset32(Stream *stream, DeviceMemoryBase *location, + uint32_t pattern, uint64_t size) override { + return tsl::errors::Internal("Interpreter can not memset"); } // No "synchronize all activity" implemented for this platform at the moment. 
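As a self-contained sketch of what the std::function to absl::AnyInvocable switch in these callback signatures buys: the callback can own move-only state and is invoked at most once. The helper below is illustrative and not part of the interpreter executor.

#include <memory>
#include <utility>

#include "absl/functional/any_invocable.h"
#include "tensorflow/tsl/platform/errors.h"
#include "tensorflow/tsl/platform/status.h"

tsl::Status RunOwnedCallbackOnce() {
  auto payload = std::make_unique<int>(42);
  // Capturing a unique_ptr makes the lambda move-only; std::function could
  // not store it, absl::AnyInvocable can. The "&&" qualifier means it is
  // invoked on an rvalue, i.e. consumed after a single call.
  absl::AnyInvocable<tsl::Status() &&> callback =
      [p = std::move(payload)]() -> tsl::Status {
        return *p == 42 ? tsl::OkStatus()
                        : tsl::errors::Internal("unexpected payload");
      };
  return std::move(callback)();
}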
bool SynchronizeAllActivity() override { return true; } - port::Status SynchronousMemZero(DeviceMemoryBase *location, - uint64_t size) override { - return port::InternalError("Interpreter can not memzero"); + tsl::Status SynchronousMemZero(DeviceMemoryBase *location, + uint64_t size) override { + return tsl::errors::Internal("Interpreter can not memzero"); } - port::Status SynchronousMemSet(DeviceMemoryBase *location, int value, - uint64_t size) override { - return port::InternalError("Interpreter can not memset"); + tsl::Status SynchronousMemSet(DeviceMemoryBase *location, int value, + uint64_t size) override { + return tsl::errors::Internal("Interpreter can not memset"); } - port::Status SynchronousMemcpy(DeviceMemoryBase *dev_dst, - const void *host_src, uint64_t size) override; - port::Status SynchronousMemcpy(void *host_dst, - const DeviceMemoryBase &dev_src, - uint64_t size) override; - port::Status SynchronousMemcpyDeviceToDevice(DeviceMemoryBase *pop_dst, - const DeviceMemoryBase &pop_src, - uint64_t size) override { - return port::Status{port::error::UNIMPLEMENTED, ""}; + tsl::Status SynchronousMemcpy(DeviceMemoryBase *dev_dst, const void *host_src, + uint64_t size) override; + tsl::Status SynchronousMemcpy(void *host_dst, const DeviceMemoryBase &dev_src, + uint64_t size) override; + tsl::Status SynchronousMemcpyDeviceToDevice(DeviceMemoryBase *pop_dst, + const DeviceMemoryBase &pop_src, + uint64_t size) override { + return tsl::Status{tsl::error::UNIMPLEMENTED, ""}; } bool HostCallback(Stream *stream, - std::function callback) override; + absl::AnyInvocable callback) override; - port::Status AllocateEvent(Event *event) override { - return ::tsl::OkStatus(); - } + tsl::Status AllocateEvent(Event *event) override { return ::tsl::OkStatus(); } - port::Status DeallocateEvent(Event *event) override { + tsl::Status DeallocateEvent(Event *event) override { return ::tsl::OkStatus(); } - port::Status RecordEvent(Stream *stream, Event *event) override { - return port::Status{port::error::UNIMPLEMENTED, "RecordEvent"}; + tsl::Status RecordEvent(Stream *stream, Event *event) override { + return tsl::Status{tsl::error::UNIMPLEMENTED, "RecordEvent"}; } - port::Status WaitForEvent(Stream *stream, Event *event) override { - return port::Status{port::error::UNIMPLEMENTED, "WaitForEvent"}; + tsl::Status WaitForEvent(Stream *stream, Event *event) override { + return tsl::Status{tsl::error::UNIMPLEMENTED, "WaitForEvent"}; } Event::Status PollForEventStatus(Event *event) override { @@ -156,7 +153,7 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { bool StartTimer(Stream *stream, Timer *timer) override; bool StopTimer(Stream *stream, Timer *timer) override; - port::Status BlockHostUntilDone(Stream *stream) override; + tsl::Status BlockHostUntilDone(Stream *stream) override; int PlatformDeviceCount() override { return 1; } @@ -164,15 +161,15 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { return false; } - port::StatusOr> CreateDeviceDescription() + tsl::StatusOr> CreateDeviceDescription() const override { return CreateDeviceDescription(0); } - static port::StatusOr> + static tsl::StatusOr> CreateDeviceDescription(int device_ordinal); - port::Status EnablePeerAccessTo(StreamExecutorInterface *other) override { + tsl::Status EnablePeerAccessTo(StreamExecutorInterface *other) override { return ::tsl::OkStatus(); } @@ -203,8 +200,7 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { private: DeviceMemoryBase 
AllocateSingleOutput(const xla::Shape &shape); - port::StatusOr AllocateOutputBuffer( - const xla::Shape &shape); + tsl::StatusOr AllocateOutputBuffer(const xla::Shape &shape); const PluginConfig plugin_config_; }; diff --git a/tensorflow/compiler/xla/backends/interpreter/interpreter_transfer_manager.h b/tensorflow/compiler/xla/backends/interpreter/interpreter_transfer_manager.h index fa4c001e653..35701597742 100644 --- a/tensorflow/compiler/xla/backends/interpreter/interpreter_transfer_manager.h +++ b/tensorflow/compiler/xla/backends/interpreter/interpreter_transfer_manager.h @@ -26,6 +26,18 @@ class InterpreterTransferManager : public GenericTransferManager { InterpreterTransferManager(); ~InterpreterTransferManager() override = default; + bool CanShapedBufferBeAccessedNow( + se::StreamExecutor* executor, + const ShapedBuffer& device_buffer) const override { + return true; + } + + bool CanBufferBeAccessedNow( + se::StreamExecutor* executor, + const se::DeviceMemoryBase& device_buffer) const override { + return true; + } + private: InterpreterTransferManager(const InterpreterTransferManager&) = delete; InterpreterTransferManager& operator=(const InterpreterTransferManager&) = diff --git a/tensorflow/compiler/xla/backends/interpreter/platform.cc b/tensorflow/compiler/xla/backends/interpreter/platform.cc index 9c3309fcda6..0259baf8221 100644 --- a/tensorflow/compiler/xla/backends/interpreter/platform.cc +++ b/tensorflow/compiler/xla/backends/interpreter/platform.cc @@ -21,10 +21,10 @@ limitations under the License. #include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/backends/interpreter/executor.h" #include "tensorflow/compiler/xla/stream_executor/device_options.h" -#include "tensorflow/compiler/xla/stream_executor/lib/initialize.h" -#include "tensorflow/compiler/xla/stream_executor/lib/status.h" #include "tensorflow/compiler/xla/stream_executor/multi_platform_manager.h" #include "tensorflow/compiler/xla/stream_executor/platform.h" +#include "tensorflow/compiler/xla/stream_executor/platform/initialize.h" +#include "tensorflow/tsl/platform/status.h" namespace stream_executor { namespace interpreter { @@ -41,12 +41,12 @@ int XlaInterpreterPlatform::VisibleDeviceCount() const { return 1; } const std::string& XlaInterpreterPlatform::Name() const { return name_; } -port::StatusOr> +tsl::StatusOr> XlaInterpreterPlatform::DescriptionForDevice(int ordinal) const { return XlaInterpreterExecutor::CreateDeviceDescription(ordinal); } -port::StatusOr XlaInterpreterPlatform::ExecutorForDevice( +tsl::StatusOr XlaInterpreterPlatform::ExecutorForDevice( int ordinal) { StreamExecutorConfig config; config.ordinal = ordinal; @@ -55,7 +55,7 @@ port::StatusOr XlaInterpreterPlatform::ExecutorForDevice( return GetExecutor(config); } -port::StatusOr +tsl::StatusOr XlaInterpreterPlatform::ExecutorForDeviceWithPluginConfig( int device_ordinal, const PluginConfig& plugin_config) { StreamExecutorConfig config; @@ -65,13 +65,13 @@ XlaInterpreterPlatform::ExecutorForDeviceWithPluginConfig( return GetExecutor(config); } -port::StatusOr XlaInterpreterPlatform::GetExecutor( +tsl::StatusOr XlaInterpreterPlatform::GetExecutor( const StreamExecutorConfig& config) { return executor_cache_.GetOrCreate( config, [&]() { return GetUncachedExecutor(config); }); } -port::StatusOr> +tsl::StatusOr> XlaInterpreterPlatform::GetUncachedExecutor( const StreamExecutorConfig& config) { auto executor = std::make_unique( @@ -79,8 +79,8 @@ XlaInterpreterPlatform::GetUncachedExecutor( config.ordinal); auto init_status = 
executor->Init(config.device_options); if (!init_status.ok()) { - return port::Status{ - port::error::INTERNAL, + return tsl::Status{ + tsl::error::INTERNAL, absl::StrFormat( "failed initializing StreamExecutor for device ordinal %d: %s", config.ordinal, init_status.ToString())}; @@ -100,7 +100,7 @@ void XlaInterpreterPlatform::UnregisterTraceListener(TraceListener* listener) { static void InitializeXlaInterpreterPlatform() { std::unique_ptr platform(new XlaInterpreterPlatform); - SE_CHECK_OK(MultiPlatformManager::RegisterPlatform(std::move(platform))); + TF_CHECK_OK(MultiPlatformManager::RegisterPlatform(std::move(platform))); } } // namespace interpreter diff --git a/tensorflow/compiler/xla/backends/interpreter/platform.h b/tensorflow/compiler/xla/backends/interpreter/platform.h index ba638071e33..d48d460ed78 100644 --- a/tensorflow/compiler/xla/backends/interpreter/platform.h +++ b/tensorflow/compiler/xla/backends/interpreter/platform.h @@ -40,18 +40,18 @@ class XlaInterpreterPlatform : public Platform { const std::string& Name() const override; - port::StatusOr> DescriptionForDevice( + tsl::StatusOr> DescriptionForDevice( int ordinal) const override; - port::StatusOr ExecutorForDevice(int ordinal) override; + tsl::StatusOr ExecutorForDevice(int ordinal) override; - port::StatusOr ExecutorForDeviceWithPluginConfig( + tsl::StatusOr ExecutorForDeviceWithPluginConfig( int ordinal, const PluginConfig& config) override; - port::StatusOr GetExecutor( + tsl::StatusOr GetExecutor( const StreamExecutorConfig& config) override; - port::StatusOr> GetUncachedExecutor( + tsl::StatusOr> GetUncachedExecutor( const StreamExecutorConfig& config) override; void RegisterTraceListener(std::unique_ptr listener) override; diff --git a/tensorflow/compiler/xla/backends/profiler/BUILD b/tensorflow/compiler/xla/backends/profiler/BUILD index 14c9a946112..c02ea2c3d0c 100644 --- a/tensorflow/compiler/xla/backends/profiler/BUILD +++ b/tensorflow/compiler/xla/backends/profiler/BUILD @@ -1,14 +1,26 @@ -load("//tensorflow:tensorflow.bzl", "if_libtpu", "tf_cuda_library") +load("//tensorflow/tsl:tsl.bzl", "if_libtpu", "tsl_gpu_library") + +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) package_group( name = "friends", - packages = ["//tensorflow/compiler/xla/backends/profiler/tpu"], + packages = [ + "//tensorflow/compiler/xla/backends/profiler/cpu", + "//tensorflow/compiler/xla/backends/profiler/gpu", + "//tensorflow/compiler/xla/backends/profiler/tpu", + ], ) -tf_cuda_library( +tsl_gpu_library( name = "profiler_backends", - visibility = ["//tensorflow:internal"], - deps = if_libtpu([ + cuda_deps = [ + "//tensorflow/compiler/xla/backends/profiler/gpu:device_tracer", + ], + visibility = ["//tensorflow/compiler/xla:internal"], + deps = [ + "//tensorflow/compiler/xla/backends/profiler/cpu:host_tracer", + "//tensorflow/compiler/xla/backends/profiler/cpu:metadata_collector", + ] + if_libtpu([ "//tensorflow/compiler/xla/backends/profiler/tpu:tpu_tracer", ]), alwayslink = True, diff --git a/tensorflow/compiler/xla/backends/profiler/cpu/BUILD b/tensorflow/compiler/xla/backends/profiler/cpu/BUILD new file mode 100644 index 00000000000..3543ed7a558 --- /dev/null +++ b/tensorflow/compiler/xla/backends/profiler/cpu/BUILD @@ -0,0 +1,116 @@ +load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") +load("//tensorflow/tsl/profiler/builds:build_config.bzl", "tf_profiler_copts") + +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) + +cc_library( + name = 
"host_tracer", + srcs = ["host_tracer_factory.cc"], + visibility = [ + "//tensorflow/compiler/xla/backends/profiler:__pkg__", + "//tensorflow/core/profiler:internal", + "//third_party/car/onboard/gpu:__subpackages__", + ], + deps = [ + ":host_tracer_impl", + "//tensorflow/tsl/profiler/lib:profiler_factory", + "//tensorflow/tsl/profiler/protobuf:profiler_options_proto_cc", + ], + alwayslink = True, +) + +cc_library( + name = "host_tracer_impl", + srcs = ["host_tracer.cc"], + hdrs = ["host_tracer.h"], + copts = tf_profiler_copts(), + visibility = [ + "//tensorflow/core/profiler:internal", + ], + deps = [ + "//tensorflow/tsl/platform:errors", + "//tensorflow/tsl/platform:status", + "//tensorflow/tsl/platform:types", + "//tensorflow/tsl/profiler/backends/cpu:host_tracer_utils", + "//tensorflow/tsl/profiler/backends/cpu:traceme_recorder", + "//tensorflow/tsl/profiler/lib:profiler_interface", + "//tensorflow/tsl/profiler/protobuf:xplane_proto_cc", + "//tensorflow/tsl/profiler/utils:time_utils", + "//tensorflow/tsl/profiler/utils:xplane_schema", + "//tensorflow/tsl/profiler/utils:xplane_utils", + ], +) + +cc_library( + name = "python_tracer", + srcs = ["python_tracer_factory.cc"], + visibility = [ + "//tensorflow/compiler/xla/python:__pkg__", + "//tensorflow/core/profiler:internal", + ], + deps = [ + ":python_tracer_impl", + "//tensorflow/tsl/profiler/lib:profiler_factory", + "//tensorflow/tsl/profiler/protobuf:profiler_options_proto_cc", + ], + alwayslink = True, +) + +cc_library( + name = "python_tracer_impl", + srcs = ["python_tracer.cc"], + hdrs = ["python_tracer.h"], + copts = tf_profiler_copts() + ["-fexceptions"], + features = ["-use_header_modules"], + visibility = [ + "//tensorflow/core/profiler:internal", + ], + deps = [ + "//tensorflow/compiler/xla/python/profiler/internal:python_hooks", + "//tensorflow/tsl/platform:errors", + "//tensorflow/tsl/platform:logging", + "//tensorflow/tsl/platform:macros", + "//tensorflow/tsl/platform:status", + "//tensorflow/tsl/profiler/lib:profiler_interface", + "//tensorflow/tsl/profiler/protobuf:xplane_proto_cc", + ], +) + +cc_library( + name = "metadata_collector", + srcs = ["metadata_collector.cc"], + copts = tf_profiler_copts(), + visibility = [ + "//tensorflow/compiler/xla/backends/profiler:__pkg__", + "//tensorflow/core/profiler:internal", + ], + deps = [ + ":metadata_utils", + "//tensorflow/compiler/xla/service:hlo_proto_cc", + "//tensorflow/compiler/xla/service:xla_debug_info_manager", + "//tensorflow/tsl/platform:macros", + "//tensorflow/tsl/platform:status", + "//tensorflow/tsl/profiler/lib:profiler_factory", + "//tensorflow/tsl/profiler/lib:profiler_interface", + "//tensorflow/tsl/profiler/protobuf:profiler_options_proto_cc", + "//tensorflow/tsl/profiler/protobuf:xplane_proto_cc", + "//tensorflow/tsl/profiler/utils:xplane_schema", + "//tensorflow/tsl/profiler/utils:xplane_utils", + ], + alwayslink = True, +) + +cc_library( + name = "metadata_utils", + hdrs = ["metadata_utils.h"], + visibility = [ + "//tensorflow/core/profiler:internal", + ], + deps = [ + "//tensorflow/compiler/xla/service:hlo_proto_cc", + "//tensorflow/tsl/profiler/convert:xla_op_utils", + "//tensorflow/tsl/profiler/protobuf:xplane_proto_cc", + "//tensorflow/tsl/profiler/utils:xplane_builder", + "//tensorflow/tsl/profiler/utils:xplane_schema", + ], +) diff --git a/tensorflow/compiler/xla/backends/profiler/cpu/host_tracer.cc b/tensorflow/compiler/xla/backends/profiler/cpu/host_tracer.cc new file mode 100644 index 00000000000..1b52addbebd --- /dev/null +++ 
b/tensorflow/compiler/xla/backends/profiler/cpu/host_tracer.cc @@ -0,0 +1,122 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/xla/backends/profiler/cpu/host_tracer.h" + +#include +#include +#include +#include + +#include "tensorflow/tsl/platform/errors.h" +#include "tensorflow/tsl/platform/status.h" +#include "tensorflow/tsl/platform/types.h" +#include "tensorflow/tsl/profiler/backends/cpu/host_tracer_utils.h" +#include "tensorflow/tsl/profiler/backends/cpu/traceme_recorder.h" +#include "tensorflow/tsl/profiler/lib/profiler_interface.h" +#include "tensorflow/tsl/profiler/protobuf/xplane.pb.h" +#include "tensorflow/tsl/profiler/utils/time_utils.h" +#include "tensorflow/tsl/profiler/utils/xplane_schema.h" +#include "tensorflow/tsl/profiler/utils/xplane_utils.h" + +namespace xla { +namespace profiler { +namespace { + +// Controls TraceMeRecorder and converts TraceMeRecorder::Events into XEvents. +// +// Thread-safety: This class is go/thread-compatible. +class HostTracer : public tsl::profiler::ProfilerInterface { + public: + explicit HostTracer(int host_trace_level); + ~HostTracer() override; + + tsl::Status Start() override; // TENSORFLOW_STATUS_OK + + tsl::Status Stop() override; // TENSORFLOW_STATUS_OK + + tsl::Status CollectData( // TENSORFLOW_STATUS_OK + tensorflow::profiler::XSpace* space) override; + + private: + // Level of host tracing. + const int host_trace_level_; + + // True if currently recording. + bool recording_ = false; + + // Timestamp at the start of tracing. + uint64_t start_timestamp_ns_ = 0; + + // Container of all traced events. + tsl::profiler::TraceMeRecorder::Events events_; +}; + +HostTracer::HostTracer(int host_trace_level) + : host_trace_level_(host_trace_level) {} + +HostTracer::~HostTracer() { Stop().IgnoreError(); } // NOLINT + +tsl::Status HostTracer::Start() { // TENSORFLOW_STATUS_OK + if (recording_) { + return tsl::errors::Internal("TraceMeRecorder already started"); + } + + // All TraceMe captured should have a timestamp greater or equal to + // start_timestamp_ns_ to prevent timestamp underflow in XPlane. + // Therefore this have to be done before TraceMeRecorder::Start. 
+ start_timestamp_ns_ = tsl::profiler::GetCurrentTimeNanos(); + recording_ = tsl::profiler::TraceMeRecorder::Start(host_trace_level_); + if (!recording_) { + return tsl::errors::Internal("Failed to start TraceMeRecorder"); + } + return tsl::OkStatus(); +} + +tsl::Status HostTracer::Stop() { // TENSORFLOW_STATUS_OK + if (!recording_) { + return tsl::errors::Internal("TraceMeRecorder not started"); + } + events_ = tsl::profiler::TraceMeRecorder::Stop(); + recording_ = false; + return tsl::OkStatus(); +} + +tsl::Status HostTracer::CollectData( // TENSORFLOW_STATUS_OK + tensorflow::profiler::XSpace* space) { + VLOG(2) << "Collecting data to XSpace from HostTracer."; + if (recording_) { + return tsl::errors::Internal("TraceMeRecorder not stopped"); + } + if (events_.empty()) { + return tsl::OkStatus(); + } + tensorflow::profiler::XPlane* plane = + tsl::profiler::FindOrAddMutablePlaneWithName( + space, tsl::profiler::kHostThreadsPlaneName); + ConvertCompleteEventsToXPlane(start_timestamp_ns_, std::exchange(events_, {}), + plane); + return tsl::OkStatus(); +} + +} // namespace + +std::unique_ptr CreateHostTracer( + const HostTracerOptions& options) { + if (options.trace_level == 0) return nullptr; + return std::make_unique(options.trace_level); +} + +} // namespace profiler +} // namespace xla diff --git a/tensorflow/compiler/xla/backends/profiler/cpu/host_tracer.h b/tensorflow/compiler/xla/backends/profiler/cpu/host_tracer.h new file mode 100644 index 00000000000..79d0fb6f2e7 --- /dev/null +++ b/tensorflow/compiler/xla/backends/profiler/cpu/host_tracer.h @@ -0,0 +1,43 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_XLA_BACKENDS_PROFILER_CPU_HOST_TRACER_H_ +#define TENSORFLOW_COMPILER_XLA_BACKENDS_PROFILER_CPU_HOST_TRACER_H_ + +#include + +#include "tensorflow/tsl/profiler/lib/profiler_interface.h" + +namespace xla { +namespace profiler { + +struct HostTracerOptions { + // Levels of host tracing: + // - Level 0 is used to disable host traces. + // - Level 1 enables tracing of only user instrumented (or default) TraceMe. + // - Level 2 enables tracing of all level 1 TraceMe(s) and instrumented high + // level program execution details (expensive TF ops, XLA ops, etc). + // This is the default. + // - Level 3 enables tracing of all level 2 TraceMe(s) and more verbose + // (low-level) program execution details (cheap TF ops, etc). 
+ int trace_level = 2; +}; + +std::unique_ptr CreateHostTracer( + const HostTracerOptions& options); + +} // namespace profiler +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_BACKENDS_PROFILER_CPU_HOST_TRACER_H_ diff --git a/tensorflow/core/profiler/backends/cpu/host_tracer_factory.cc b/tensorflow/compiler/xla/backends/profiler/cpu/host_tracer_factory.cc similarity index 72% rename from tensorflow/core/profiler/backends/cpu/host_tracer_factory.cc rename to tensorflow/compiler/xla/backends/profiler/cpu/host_tracer_factory.cc index c4313446cd5..5ade9e67320 100644 --- a/tensorflow/core/profiler/backends/cpu/host_tracer_factory.cc +++ b/tensorflow/compiler/xla/backends/profiler/cpu/host_tracer_factory.cc @@ -12,16 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/profiler/backends/cpu/host_tracer.h" -#include "tensorflow/core/profiler/lib/profiler_factory.h" -#include "tensorflow/core/profiler/profiler_options.pb.h" +#include -namespace tensorflow { +#include "tensorflow/compiler/xla/backends/profiler/cpu/host_tracer.h" +#include "tensorflow/tsl/profiler/lib/profiler_factory.h" +#include "tensorflow/tsl/profiler/protobuf/profiler_options.pb.h" + +namespace xla { namespace profiler { namespace { -std::unique_ptr CreateHostTracer( - const ProfileOptions& profile_options) { +std::unique_ptr CreateHostTracer( + const tensorflow::ProfileOptions& profile_options) { HostTracerOptions options; options.trace_level = profile_options.host_tracer_level(); return CreateHostTracer(options); @@ -34,4 +36,4 @@ auto register_host_tracer_factory = [] { } // namespace } // namespace profiler -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/core/profiler/backends/cpu/metadata_collector.cc b/tensorflow/compiler/xla/backends/profiler/cpu/metadata_collector.cc similarity index 69% rename from tensorflow/core/profiler/backends/cpu/metadata_collector.cc rename to tensorflow/compiler/xla/backends/profiler/cpu/metadata_collector.cc index 47029fbe028..9f3a23b34d8 100644 --- a/tensorflow/core/profiler/backends/cpu/metadata_collector.cc +++ b/tensorflow/compiler/xla/backends/profiler/cpu/metadata_collector.cc @@ -18,19 +18,19 @@ limitations under the License. 
#include #include +#include "tensorflow/compiler/xla/backends/profiler/cpu/metadata_utils.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/xla_debug_info_manager.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/status.h" -#include "tensorflow/core/profiler/backends/cpu/metadata_utils.h" -#include "tensorflow/core/profiler/lib/profiler_factory.h" -#include "tensorflow/core/profiler/lib/profiler_interface.h" -#include "tensorflow/core/profiler/profiler_options.pb.h" -#include "tensorflow/core/profiler/protobuf/xplane.pb.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_utils.h" - -namespace tensorflow { +#include "tensorflow/tsl/platform/macros.h" +#include "tensorflow/tsl/platform/status.h" +#include "tensorflow/tsl/profiler/lib/profiler_factory.h" +#include "tensorflow/tsl/profiler/lib/profiler_interface.h" +#include "tensorflow/tsl/profiler/protobuf/profiler_options.pb.h" +#include "tensorflow/tsl/profiler/protobuf/xplane.pb.h" +#include "tensorflow/tsl/profiler/utils/xplane_schema.h" +#include "tensorflow/tsl/profiler/utils/xplane_utils.h" + +namespace xla { namespace profiler { namespace { @@ -38,7 +38,7 @@ namespace { // from XLA runtime etc. // // Thread-safety: This class is go/thread-compatible. -class MetadataCollector : public ProfilerInterface { +class MetadataCollector : public tsl::profiler::ProfilerInterface { public: MetadataCollector() = default; @@ -58,9 +58,11 @@ class MetadataCollector : public ProfilerInterface { return OkStatus(); } - Status CollectData(XSpace* space) override { + Status CollectData(tsl::profiler::XSpace* space) override { if (!debug_info_.empty()) { - XPlane* plane = FindOrAddMutablePlaneWithName(space, kMetadataPlaneName); + tsl::profiler::XPlane* plane = + tsl::profiler::FindOrAddMutablePlaneWithName( + space, tsl::profiler::kMetadataPlaneName); MetadataXPlaneBuilder metadata_plane(plane); for (auto& hlo_proto : debug_info_) { metadata_plane.AddHloProto(hlo_proto->hlo_module().id(), *hlo_proto); @@ -78,8 +80,8 @@ class MetadataCollector : public ProfilerInterface { TF_DISALLOW_COPY_AND_ASSIGN(MetadataCollector); }; -std::unique_ptr CreatMetadataCollector( - const ProfileOptions& options) { +std::unique_ptr CreatMetadataCollector( + const tensorflow::ProfileOptions& options) { return options.enable_hlo_proto() ? std::make_unique() : nullptr; } @@ -92,4 +94,4 @@ auto register_metadata_collector_factory = [] { }(); } // namespace profiler -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/xla/backends/profiler/cpu/metadata_utils.h b/tensorflow/compiler/xla/backends/profiler/cpu/metadata_utils.h new file mode 100644 index 00000000000..349ddb977d0 --- /dev/null +++ b/tensorflow/compiler/xla/backends/profiler/cpu/metadata_utils.h @@ -0,0 +1,55 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_BACKENDS_PROFILER_CPU_METADATA_UTILS_H_ +#define TENSORFLOW_COMPILER_XLA_BACKENDS_PROFILER_CPU_METADATA_UTILS_H_ + +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/tsl/profiler/convert/xla_op_utils.h" +#include "tensorflow/tsl/profiler/protobuf/xplane.pb.h" +#include "tensorflow/tsl/profiler/utils/xplane_builder.h" +#include "tensorflow/tsl/profiler/utils/xplane_schema.h" + +namespace xla { +namespace profiler { + +class MetadataXPlaneBuilder { + public: + explicit MetadataXPlaneBuilder(tsl::profiler::XPlane* raw_plane) + : plane_(raw_plane), + hlo_proto_stat_(plane_.GetOrCreateStatMetadata( + GetStatTypeStr(tsl::profiler::StatType::kHloProto))) {} + + void AddHloProto(uint64_t program_id, const xla::HloProto& hlo_proto) { + tsl::profiler::XEventMetadata* event_metadata = + plane_.GetOrCreateEventMetadata(program_id); + if (event_metadata->name().empty()) { + event_metadata->set_name(tsl::profiler::HloModuleNameWithProgramId( + hlo_proto.hlo_module().name(), program_id)); + tsl::profiler::XStatsBuilder event_stats( + event_metadata, &plane_); + event_stats.AddStatValue(*hlo_proto_stat_, hlo_proto); + } + } + + private: + tsl::profiler::XPlaneBuilder plane_; + const tsl::profiler::XStatMetadata* hlo_proto_stat_ = nullptr; +}; + +} // namespace profiler +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_BACKENDS_PROFILER_CPU_METADATA_UTILS_H_ diff --git a/tensorflow/core/profiler/backends/cpu/python_tracer.cc b/tensorflow/compiler/xla/backends/profiler/cpu/python_tracer.cc similarity index 56% rename from tensorflow/core/profiler/backends/cpu/python_tracer.cc rename to tensorflow/compiler/xla/backends/profiler/cpu/python_tracer.cc index 8fa870ec2de..bd381905f23 100644 --- a/tensorflow/core/profiler/backends/cpu/python_tracer.cc +++ b/tensorflow/compiler/xla/backends/profiler/cpu/python_tracer.cc @@ -12,79 +12,79 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/profiler/backends/cpu/python_tracer.h" +#include "tensorflow/compiler/xla/backends/profiler/cpu/python_tracer.h" #include -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/status.h" -#include "tensorflow/core/profiler/lib/profiler_interface.h" -#include "tensorflow/core/profiler/protobuf/xplane.pb.h" -#include "tensorflow/python/profiler/internal/python_hooks.h" +#include "tensorflow/compiler/xla/python/profiler/internal/python_hooks.h" +#include "tensorflow/tsl/platform/errors.h" +#include "tensorflow/tsl/platform/logging.h" +#include "tensorflow/tsl/platform/macros.h" +#include "tensorflow/tsl/platform/status.h" +#include "tensorflow/tsl/profiler/lib/profiler_interface.h" +#include "tensorflow/tsl/profiler/protobuf/xplane.pb.h" -namespace tensorflow { +namespace xla { namespace profiler { namespace { // This profiler interface enables Python function call tracing. 
-class PythonTracer : public ProfilerInterface { +class PythonTracer : public tsl::profiler::ProfilerInterface { public: explicit PythonTracer(const PythonHooksOptions& options) : options_(options) {} ~PythonTracer() override; - Status Start() override; + tsl::Status Start() override; // TENSORFLOW_STATUS_OK - Status Stop() override; + tsl::Status Stop() override; // TENSORFLOW_STATUS_OK - Status CollectData(XSpace* space) override; + tsl::Status CollectData( // TENSORFLOW_STATUS_OK + tensorflow::profiler::XSpace* space) override; private: bool recording_ = false; const PythonHooksOptions options_; - std::unique_ptr context_; + std::unique_ptr context_; TF_DISALLOW_COPY_AND_ASSIGN(PythonTracer); }; -PythonTracer::~PythonTracer() { - Stop().IgnoreError(); -} +PythonTracer::~PythonTracer() { Stop().IgnoreError(); } // NOLINT -Status PythonTracer::Start() { +tsl::Status PythonTracer::Start() { // TENSORFLOW_STATUS_OK if (recording_) { - return errors::Internal("PythonTracer already started"); + return tsl::errors::Internal("PythonTracer already started"); } VLOG(1) << __FUNCTION__; recording_ = true; PythonHooks::GetSingleton()->Start(options_); - return OkStatus(); + return tsl::OkStatus(); } -Status PythonTracer::Stop() { +tsl::Status PythonTracer::Stop() { // TENSORFLOW_STATUS_OK if (!recording_) { - return errors::Internal("PythonTracer not started"); + return tsl::errors::Internal("PythonTracer not started"); } VLOG(1) << __FUNCTION__; context_ = PythonHooks::GetSingleton()->Stop(); recording_ = false; - return OkStatus(); + return tsl::OkStatus(); } -Status PythonTracer::CollectData(XSpace* space) { +tsl::Status PythonTracer::CollectData( // TENSORFLOW_STATUS_OK + tensorflow::profiler::XSpace* space) { VLOG(2) << "Collecting data to XSpace from PythonTracer."; if (context_) { context_->Finalize(space); context_.reset(); } - return OkStatus(); + return tsl::OkStatus(); } } // namespace -std::unique_ptr CreatePythonTracer( +std::unique_ptr CreatePythonTracer( const PythonTracerOptions& options) { if (!options.enable_trace_python_function && !options.enable_python_traceme) { return nullptr; @@ -94,8 +94,8 @@ std::unique_ptr CreatePythonTracer( options.enable_trace_python_function; pyhooks_options.enable_python_traceme = options.enable_python_traceme; pyhooks_options.end_to_end_mode = options.end_to_end_mode; - return absl::make_unique(pyhooks_options); + return std::make_unique(pyhooks_options); } } // namespace profiler -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/xla/backends/profiler/cpu/python_tracer.h b/tensorflow/compiler/xla/backends/profiler/cpu/python_tracer.h new file mode 100644 index 00000000000..4413bc72440 --- /dev/null +++ b/tensorflow/compiler/xla/backends/profiler/cpu/python_tracer.h @@ -0,0 +1,43 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_XLA_BACKENDS_PROFILER_CPU_PYTHON_TRACER_H_ +#define TENSORFLOW_COMPILER_XLA_BACKENDS_PROFILER_CPU_PYTHON_TRACER_H_ + +#include + +#include "tensorflow/tsl/profiler/lib/profiler_interface.h" + +namespace xla { +namespace profiler { + +struct PythonTracerOptions { + // Whether to enable python function calls tracing. + // NOTE: Runtime overhead ensues if enabled. + bool enable_trace_python_function = false; + + // Whether to enable python TraceMe instrumentation. + bool enable_python_traceme = true; + + // Whether profiling stops within an atexit handler. + bool end_to_end_mode = false; +}; + +std::unique_ptr CreatePythonTracer( + const PythonTracerOptions& options); + +} // namespace profiler +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_BACKENDS_PROFILER_CPU_PYTHON_TRACER_H_ diff --git a/tensorflow/core/profiler/backends/cpu/python_tracer_factory.cc b/tensorflow/compiler/xla/backends/profiler/cpu/python_tracer_factory.cc similarity index 74% rename from tensorflow/core/profiler/backends/cpu/python_tracer_factory.cc rename to tensorflow/compiler/xla/backends/profiler/cpu/python_tracer_factory.cc index e9b8b8c2b85..71818f00c87 100644 --- a/tensorflow/core/profiler/backends/cpu/python_tracer_factory.cc +++ b/tensorflow/compiler/xla/backends/profiler/cpu/python_tracer_factory.cc @@ -12,16 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/profiler/backends/cpu/python_tracer.h" -#include "tensorflow/core/profiler/lib/profiler_factory.h" -#include "tensorflow/core/profiler/profiler_options.pb.h" +#include -namespace tensorflow { +#include "tensorflow/compiler/xla/backends/profiler/cpu/python_tracer.h" +#include "tensorflow/tsl/profiler/lib/profiler_factory.h" +#include "tensorflow/tsl/profiler/protobuf/profiler_options.pb.h" + +namespace xla { namespace profiler { namespace { -std::unique_ptr CreatePythonTracer( - const ProfileOptions& profile_options) { +std::unique_ptr CreatePythonTracer( + const tensorflow::ProfileOptions& profile_options) { PythonTracerOptions options; options.enable_trace_python_function = profile_options.python_tracer_level(); options.enable_python_traceme = profile_options.host_tracer_level(); @@ -35,4 +37,4 @@ auto register_python_tracer_factory = [] { } // namespace } // namespace profiler -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/xla/backends/profiler/gpu/BUILD b/tensorflow/compiler/xla/backends/profiler/gpu/BUILD new file mode 100644 index 00000000000..12a4f1a9926 --- /dev/null +++ b/tensorflow/compiler/xla/backends/profiler/gpu/BUILD @@ -0,0 +1,290 @@ +load("@local_config_cuda//cuda:build_defs.bzl", "cuda_library", "if_cuda") +load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") +load( + "//tensorflow/tsl:tsl.bzl", + "tsl_copts", + "tsl_gpu_library", +) +load("//tensorflow/tsl:tsl.default.bzl", "tsl_gpu_cc_test") +load( + "//tensorflow/tsl/platform:build_config.bzl", + "tf_additional_device_tracer_srcs", +) +load( + "//tensorflow/tsl/platform:build_config_root.bzl", + "tf_cuda_tests_tags", +) +load( + "//tensorflow/compiler/xla/stream_executor:build_defs.bzl", + "tf_additional_cupti_deps", +) +load("//tensorflow/tsl/profiler/builds:build_config.bzl", 
"tf_profiler_copts") +load( + "//tensorflow/tsl/platform/default:cuda_build_defs.bzl", + "if_cuda_is_configured", +) + +package( + default_visibility = ["//tensorflow/compiler/xla:internal"], + features = [ + "-layering_check", + ], + licenses = ["notice"], +) + +tsl_gpu_library( + name = "device_tracer", + srcs = tf_additional_device_tracer_srcs(), + copts = tf_profiler_copts() + tsl_copts(), + cuda_deps = [ + ":cupti_tracer", + ":cupti_wrapper", + ":rocm_tracer", + ], + deps = [ + ":cupti_utils", + "//tensorflow/tsl/platform:abi", + "//tensorflow/tsl/platform:env_time", + "//tensorflow/tsl/platform:errors", + "//tensorflow/tsl/platform:macros", + "//tensorflow/tsl/platform:mutex", + "//tensorflow/tsl/platform:thread_annotations", + "//tensorflow/tsl/profiler/lib:profiler_factory", + "//tensorflow/tsl/profiler/lib:profiler_interface", + "//tensorflow/tsl/profiler/protobuf:xplane_proto_cc", + "//tensorflow/tsl/profiler/utils:time_utils", + "//tensorflow/tsl/util:env_var", + "@com_google_absl//absl/container:fixed_array", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + ], + alwayslink = 1, +) + +tsl_gpu_library( + name = "cupti_interface", + hdrs = if_cuda(["cupti_interface.h"]), + copts = tf_profiler_copts() + tsl_copts(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/tsl/platform:macros", + "//tensorflow/tsl/platform:types", + ] + tf_additional_cupti_deps(), +) + +tsl_gpu_library( + name = "mock_cupti", + testonly = 1, + hdrs = if_cuda(["mock_cupti.h"]), + copts = tf_profiler_copts() + tsl_copts(), + cuda_deps = [ + ":cupti_interface", + ], + deps = [ + "//tensorflow/tsl/platform:test", + ], +) + +tsl_gpu_library( + name = "cupti_error_manager", + srcs = if_cuda(["cupti_error_manager.cc"]), + hdrs = if_cuda(["cupti_error_manager.h"]), + copts = tf_profiler_copts() + tsl_copts(), + cuda_deps = [ + ":cupti_interface", + ":cupti_wrapper", + ], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/tsl/platform:logging", + "//tensorflow/tsl/platform:mutex", + "//tensorflow/tsl/platform:thread_annotations", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/debugging:leak_check", + "@com_google_absl//absl/synchronization", + ], +) + +tsl_gpu_cc_test( + name = "cupti_error_manager_test", + size = "small", + srcs = ["cupti_error_manager_test.cc"], + tags = tf_cuda_tests_tags() + [ + "nomac", + "gpu_cupti", + ], + deps = [ + "//tensorflow/tsl/platform:test_main", + ] + if_cuda_is_configured([ + ":cuda_test", + ":cupti_error_manager", + ":cupti_tracer", + ":cupti_utils", + ":cupti_wrapper", + ":mock_cupti", + "@com_google_absl//absl/memory", + "//tensorflow/tsl/platform:env_impl", + "//tensorflow/tsl/profiler/utils:time_utils", + "//tensorflow/tsl/profiler/backends/cpu:annotation_stack_impl", + ]), +) + +cuda_library( + name = "cuda_test", + testonly = 1, + srcs = ["cuda_test.cu.cc"], + hdrs = ["cuda_test.h"], + copts = select({ + "@local_config_cuda//cuda:using_nvcc": [ + "-nvcc_options", + "ptxas-options=-v", + ], + "//conditions:default": [], + }), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/tsl/platform:test", + "@local_config_cuda//cuda:cuda_headers", + "@local_config_cuda//cuda:cudart", + ], +) + +# Rationale for linkstatic: The symbols in libcupti_static.a have hidden +# visibility. The wrapper will fail to find them if it's ever built as a +# shared library. 
This is the same issue as b/11094727. Always linking +# the wrapper statically works around the issue. An alternative would be +# to patch libcupti_static, but it's not worth the trouble considering +# that the wrapper is about the only direct user. +tsl_gpu_library( + name = "cupti_wrapper", + srcs = if_cuda(["cupti_wrapper.cc"]), + hdrs = if_cuda(["cupti_wrapper.h"]), + copts = tf_profiler_copts() + tsl_copts(), + linkstatic = 1, + visibility = ["//visibility:public"], + deps = [ + ":cupti_interface", + ] + tf_additional_cupti_deps(), +) + +tsl_gpu_library( + name = "cupti_tracer", + srcs = if_cuda(["cupti_tracer.cc"]), + hdrs = if_cuda(["cupti_tracer.h"]), + copts = tf_profiler_copts() + tsl_copts(), + visibility = ["//visibility:public"], + deps = [ + ":cupti_collector", + ":cupti_interface", + ":cupti_utils", + ":nvtx_utils", + "//tensorflow/tsl/platform:env", + "//tensorflow/tsl/platform:errors", + "//tensorflow/tsl/platform:logging", + "//tensorflow/tsl/platform:macros", + "//tensorflow/tsl/platform:platform_port", + "//tensorflow/tsl/platform:status", + "//tensorflow/tsl/platform:types", + "//tensorflow/tsl/profiler/backends/cpu:annotation_stack", + "//tensorflow/tsl/profiler/lib:scoped_annotation", + "//tensorflow/tsl/profiler/utils:buffer_pool", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/container:node_hash_set", + "@com_google_absl//absl/types:optional", + ], +) + +tsl_gpu_library( + name = "rocm_tracer", + srcs = if_rocm(["rocm_tracer.cc"]), + hdrs = if_rocm(["rocm_tracer.h"]), + copts = tf_profiler_copts() + tsl_copts(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla/stream_executor/rocm:roctracer_wrapper", + "//tensorflow/tsl/platform:env", + "//tensorflow/tsl/platform:errors", + "//tensorflow/tsl/platform:logging", + "//tensorflow/tsl/platform:macros", + "//tensorflow/tsl/platform:platform_port", + "//tensorflow/tsl/profiler/backends/cpu:annotation_stack", + "//tensorflow/tsl/profiler/utils:time_utils", + "@com_google_absl//absl/container:fixed_array", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/container:node_hash_set", + "@com_google_absl//absl/types:optional", + ], +) + +tsl_gpu_library( + name = "nvtx_utils", + srcs = if_cuda(["nvtx_utils.cc"]), + hdrs = if_cuda(["nvtx_utils.h"]), + copts = tf_profiler_copts() + tsl_copts(), + deps = [ + "//tensorflow/tsl/platform", + "//tensorflow/tsl/platform:macros", + "//tensorflow/tsl/platform:mutex", + ], +) + +tsl_gpu_library( + name = "cupti_collector", + srcs = if_cuda(["cupti_collector.cc"]), + hdrs = if_cuda(["cupti_collector.h"]), + copts = tf_profiler_copts() + tsl_copts(), + visibility = ["//visibility:public"], + deps = [ + "@com_google_absl//absl/container:fixed_array", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:node_hash_set", + "@com_google_absl//absl/strings", + "//tensorflow/tsl/platform:abi", + "//tensorflow/tsl/platform:platform_port", + "//tensorflow/tsl/platform:mutex", + "//tensorflow/tsl/platform:macros", + "//tensorflow/tsl/platform:status", + "//tensorflow/tsl/platform:types", + "//tensorflow/tsl/profiler/protobuf:xplane_proto_cc", + "//tensorflow/tsl/profiler/utils:parse_annotation", + "//tensorflow/tsl/profiler/utils:xplane_builder", + 
"//tensorflow/tsl/profiler/utils:xplane_schema", + "//tensorflow/tsl/profiler/utils:xplane_utils", + "//tensorflow/tsl/profiler/utils:trace_utils", + ] + tf_additional_cupti_deps(), +) + +cc_library( + name = "cupti_collector_header", + hdrs = ["cupti_collector.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/tsl/platform:macros", + "//tensorflow/tsl/platform:status", + "//tensorflow/tsl/platform:types", + "//tensorflow/tsl/profiler/protobuf:xplane_proto_cc", + "@com_google_absl//absl/container:fixed_array", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:node_hash_set", + "@com_google_absl//absl/strings", + ], +) + +tsl_gpu_library( + name = "cupti_utils", + srcs = if_cuda(["cupti_utils.cc"]), + copts = tf_profiler_copts() + tsl_copts(), + cuda_deps = [ + ":cupti_error_manager", + ":cupti_interface", + ":cupti_wrapper", + ], + visibility = ["//visibility:public"], + alwayslink = 1, +) diff --git a/tensorflow/core/profiler/backends/gpu/cuda_test.cu.cc b/tensorflow/compiler/xla/backends/profiler/gpu/cuda_test.cu.cc similarity index 96% rename from tensorflow/core/profiler/backends/gpu/cuda_test.cu.cc rename to tensorflow/compiler/xla/backends/profiler/gpu/cuda_test.cu.cc index 24e4656d1a9..4ff692341a8 100644 --- a/tensorflow/core/profiler/backends/gpu/cuda_test.cu.cc +++ b/tensorflow/compiler/xla/backends/profiler/gpu/cuda_test.cu.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ // Creates some GPU activity to test functionalities of gpuperfcounter/gputrace. -#include "tensorflow/core/profiler/backends/gpu/cuda_test.h" +#include "tensorflow/compiler/xla/backends/profiler/gpu/cuda_test.h" #if GOOGLE_CUDA #include @@ -23,9 +23,9 @@ limitations under the License. #include "third_party/gpus/cuda/include/driver_types.h" #endif -#include "tensorflow/core/platform/test.h" +#include "tensorflow/tsl/platform/test.h" -namespace tensorflow { +namespace xla { namespace profiler { namespace test { @@ -185,4 +185,4 @@ void MemCopyP2PExplicit() { } // namespace test } // namespace profiler -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/xla/backends/profiler/gpu/cuda_test.h b/tensorflow/compiler/xla/backends/profiler/gpu/cuda_test.h new file mode 100644 index 00000000000..ec583d5f4e1 --- /dev/null +++ b/tensorflow/compiler/xla/backends/profiler/gpu/cuda_test.h @@ -0,0 +1,55 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_BACKENDS_PROFILER_GPU_CUDA_TEST_H_ +#define TENSORFLOW_COMPILER_XLA_BACKENDS_PROFILER_GPU_CUDA_TEST_H_ + +namespace xla { +namespace profiler { +namespace test { +// Calls a function on the device to print a string as many times as indicated +// by iters. 
+void PrintfKernel(int iters = 1);
+
+// Calls an empty kernel (named "empty") on the device as many times as
+// indicated by iters.
+void EmptyKernel(int iters = 1);
+
+// Waits for device activity to complete.
+void Synchronize();
+
+// Copies a few bytes of memory from host to device.
+void MemCopyH2D();
+
+// Copies a few bytes of memory from host to device, asynchronously.
+void MemCopyH2D_Async();
+
+// Copies a few bytes of memory from device to host.
+void MemCopyD2H();
+
+// Returns true if it is possible to copy bytes from device 0 to device 1.
+bool MemCopyP2PAvailable();
+
+// Copies a few bytes of memory from device 0 to device 1.
+void MemCopyP2PImplicit();
+
+// Copies a few bytes of memory from device 0 to device 1.
+void MemCopyP2PExplicit();
+
+}  // namespace test
+}  // namespace profiler
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_BACKENDS_PROFILER_GPU_CUDA_TEST_H_
diff --git a/tensorflow/core/profiler/backends/gpu/cupti_collector.cc b/tensorflow/compiler/xla/backends/profiler/gpu/cupti_collector.cc
similarity index 89%
rename from tensorflow/core/profiler/backends/gpu/cupti_collector.cc
rename to tensorflow/compiler/xla/backends/profiler/gpu/cupti_collector.cc
index 3018da4bd14..17a391be6f1 100644
--- a/tensorflow/core/profiler/backends/gpu/cupti_collector.cc
+++ b/tensorflow/compiler/xla/backends/profiler/gpu/cupti_collector.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/core/profiler/backends/gpu/cupti_collector.h"
+#include "tensorflow/compiler/xla/backends/profiler/gpu/cupti_collector.h"
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
@@ -23,20 +23,36 @@ limitations under the License.
#include "third_party/gpus/cuda/extras/CUPTI/include/cupti_activity.h" #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cuda_occupancy.h" -#include "tensorflow/core/platform/abi.h" -#include "tensorflow/core/platform/host_info.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/profiler/utils/parse_annotation.h" -#include "tensorflow/core/profiler/utils/trace_utils.h" -#include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_utils.h" - -namespace tensorflow { +#include "tensorflow/tsl/platform/abi.h" +#include "tensorflow/tsl/platform/host_info.h" +#include "tensorflow/tsl/platform/mutex.h" +#include "tensorflow/tsl/profiler/utils/parse_annotation.h" +#include "tensorflow/tsl/profiler/utils/trace_utils.h" +#include "tensorflow/tsl/profiler/utils/xplane_builder.h" +#include "tensorflow/tsl/profiler/utils/xplane_schema.h" +#include "tensorflow/tsl/profiler/utils/xplane_utils.h" + +namespace xla { namespace profiler { namespace { +using tensorflow::profiler::XEventMetadata; +using tensorflow::profiler::XSpace; +using tsl::mutex; +using tsl::mutex_lock; +using tsl::profiler::Annotation; +using tsl::profiler::FindOrAddMutablePlaneWithName; +using tsl::profiler::GpuPlaneName; +using tsl::profiler::kCuptiDriverApiPlaneName; +using tsl::profiler::kDeviceVendorNvidia; +using tsl::profiler::kThreadIdOverhead; +using tsl::profiler::ParseAnnotationStack; +using tsl::profiler::StatType; +using tsl::profiler::XEventBuilder; +using tsl::profiler::XLineBuilder; +using tsl::profiler::XPlaneBuilder; + bool IsHostEvent(const CuptiTracerEvent& event, int64_t* line_id) { // DriverCallback(i.e. kernel launching) events are host events. 
if (event.source == CuptiTracerEventSource::DriverCallback) { @@ -124,7 +140,7 @@ class PerDeviceCollector { } void CreateXEvent(const CuptiTracerEvent& event, XPlaneBuilder* plane, - uint64 start_gpu_ns, uint64 end_gpu_ns, + tsl::uint64 start_gpu_ns, tsl::uint64 end_gpu_ns, XLineBuilder* line) { if (event.start_time_ns < start_gpu_ns || event.end_time_ns > end_gpu_ns || event.start_time_ns > event.end_time_ns) { @@ -133,7 +149,7 @@ class PerDeviceCollector { << " end time(ns): " << event.end_time_ns; return; } - std::string kernel_name = port::MaybeAbiDemangle(event.name.c_str()); + std::string kernel_name = tsl::port::MaybeAbiDemangle(event.name.c_str()); if (kernel_name.empty()) { kernel_name = GetTraceEventTypeName(event.type); } @@ -161,7 +177,7 @@ class PerDeviceCollector { if (event.context_id != CuptiTracerEvent::kInvalidContextId) { xevent.AddStatValue( *plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kContextId)), - absl::StrCat("$$", static_cast(event.context_id))); + absl::StrCat("$$", static_cast(event.context_id))); } if (event.type == CuptiTracerEventType::Kernel && @@ -190,10 +206,11 @@ class PerDeviceCollector { occ_stats.occupancy_pct); xevent.AddStatValue(*plane->GetOrCreateStatMetadata( GetStatTypeStr(StatType::kOccupancyMinGridSize)), - static_cast(occ_stats.min_grid_size)); - xevent.AddStatValue(*plane->GetOrCreateStatMetadata(GetStatTypeStr( - StatType::kOccupancySuggestedBlockSize)), - static_cast(occ_stats.suggested_block_size)); + static_cast(occ_stats.min_grid_size)); + xevent.AddStatValue( + *plane->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kOccupancySuggestedBlockSize)), + static_cast(occ_stats.suggested_block_size)); xevent.AddStatValue(*plane->GetOrCreateStatMetadata( GetStatTypeStr(StatType::kKernelDetails)), *plane->GetOrCreateStatMetadata(ToXStat( @@ -272,11 +289,11 @@ class PerDeviceCollector { } } - absl::optional GetDeviceAttribute(CUdevice device, - CUdevice_attribute attrib) { + std::optional GetDeviceAttribute(CUdevice device, + CUdevice_attribute attrib) { int ret_val; CUresult err = cuDeviceGetAttribute(&ret_val, attrib, device); - if (err != CUDA_SUCCESS) return absl::nullopt; + if (err != CUDA_SUCCESS) return std::nullopt; return ret_val; } @@ -303,7 +320,7 @@ class PerDeviceCollector { events_.emplace_back(std::move(event)); } - size_t Flush(uint64 start_gpu_ns, uint64 end_gpu_ns, + size_t Flush(tsl::uint64 start_gpu_ns, tsl::uint64 end_gpu_ns, XPlaneBuilder* device_plane, XPlaneBuilder* host_plane) { mutex_lock l(m_); // Tracking event types per line. @@ -375,7 +392,7 @@ class PerDeviceCollector { // Times 2 because HBM is DDR memory; it gets two data bits per each // data lane. 
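As a sanity check with illustrative numbers (roughly what a V100-class part reports: a 877000 kHz memory clock and a 4096-bit bus), the expression below works out to 2 * 877000 * 1000 * 4096 / 8 ≈ 8.98e11 bytes/s, i.e. about 900 GB/s.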
auto memory_bandwidth = - uint64{2} * (*mem_clock_khz) * 1000 * (*mem_bus_width_bits) / 8; + tsl::uint64{2} * (*mem_clock_khz) * 1000 * (*mem_bus_width_bits) / 8; device_plane->AddStatValue( *device_plane->GetOrCreateStatMetadata( GetStatTypeStr(StatType::kDevCapMemoryBandwidth)), @@ -387,7 +404,7 @@ class PerDeviceCollector { device_plane->AddStatValue( *device_plane->GetOrCreateStatMetadata( GetStatTypeStr(StatType::kDevCapMemorySize)), - static_cast(total_memory)); + static_cast(total_memory)); } auto compute_capability_major = GetDeviceAttribute( @@ -452,7 +469,7 @@ class PerDeviceCollector { } // namespace -void AnnotationMap::Add(uint32 device_id, uint32 correlation_id, +void AnnotationMap::Add(tsl::uint32 device_id, tsl::uint32 correlation_id, const absl::string_view annotation, const absl::string_view nvtx_range) { if (annotation.empty() && nvtx_range.empty()) return; @@ -471,8 +488,8 @@ void AnnotationMap::Add(uint32 device_id, uint32 correlation_id, } } -AnnotationMap::AnnotationInfo AnnotationMap::LookUp(uint32 device_id, - uint32 correlation_id) { +AnnotationMap::AnnotationInfo AnnotationMap::LookUp( + tsl::uint32 device_id, tsl::uint32 correlation_id) { if (device_id >= per_device_map_.size()) return AnnotationInfo(); auto& per_device_map = per_device_map_[device_id]; absl::MutexLock lock(&per_device_map.mutex); @@ -486,7 +503,8 @@ AnnotationMap::AnnotationInfo AnnotationMap::LookUp(uint32 device_id, class CuptiTraceCollectorImpl : public CuptiTraceCollector { public: CuptiTraceCollectorImpl(const CuptiTracerCollectorOptions& option, - uint64 start_walltime_ns, uint64 start_gpu_ns) + tsl::uint64 start_walltime_ns, + tsl::uint64 start_gpu_ns) : CuptiTraceCollector(option), num_callback_events_(0), num_activity_events_(0), @@ -512,13 +530,14 @@ class CuptiTraceCollectorImpl : public CuptiTraceCollector { } per_device_collector_[event.device_id].AddEvent(std::move(event)); } - void OnEventsDropped(const std::string& reason, uint32 num_events) override { + void OnEventsDropped(const std::string& reason, + tsl::uint32 num_events) override { absl::MutexLock lock(&mutex_); dropped_events_[reason] += num_events; } void Flush() override {} // Returns true if some GPU events are captured. - bool Export(XSpace* space, uint64 end_gpu_ns) override { + bool Export(XSpace* space, tsl::uint64 end_gpu_ns) override { LOG(INFO) << " GpuTracer has collected " << num_callback_events_ << " callback api events and " << num_activity_events_ << " activity events. 
" << ReportDroppedEvents(); @@ -546,7 +565,7 @@ class CuptiTraceCollectorImpl : public CuptiTraceCollector { std::string ReportDroppedEvents() { absl::MutexLock lock(&mutex_); - string result; + std::string result; for (const auto& dropped : dropped_events_) { absl::StrAppend(&result, " ", dropped.second, " events dropped because ", dropped.first, ";"); @@ -557,8 +576,8 @@ class CuptiTraceCollectorImpl : public CuptiTraceCollector { std::string ReportNumEventsIfDropped() override { std::string events_dropped = ReportDroppedEvents(); if (events_dropped.empty()) return ""; - return absl::StrCat("Detected GPU events dropped on ", port::Hostname(), - ": Profiler has collected ", + return absl::StrCat("Detected GPU events dropped on ", + tsl::port::Hostname(), ": Profiler has collected ", num_callback_events_.load(), " driver events and ", num_activity_events_.load(), " device events.", events_dropped); @@ -568,10 +587,10 @@ class CuptiTraceCollectorImpl : public CuptiTraceCollector { std::atomic num_callback_events_; std::atomic num_activity_events_; absl::Mutex mutex_; - absl::flat_hash_map dropped_events_ + absl::flat_hash_map dropped_events_ ABSL_GUARDED_BY(mutex_); - uint64 start_walltime_ns_; - uint64 start_gpu_ns_; + tsl::uint64 start_walltime_ns_; + tsl::uint64 start_gpu_ns_; int num_gpus_; // Set the all XLines of specified XPlane to starting walltime. @@ -580,7 +599,7 @@ class CuptiTraceCollectorImpl : public CuptiTraceCollector { // this fact. Eventually we change line start time to corresponding // start_walltime_ns to normalize with CPU wall time. static void NormalizeTimeStamps(XPlaneBuilder* plane, - uint64 start_walltime_ns) { + tsl::uint64 start_walltime_ns) { plane->ForEachLine( [&](XLineBuilder line) { line.SetTimestampNs(start_walltime_ns); }); } @@ -591,8 +610,8 @@ class CuptiTraceCollectorImpl : public CuptiTraceCollector { }; std::unique_ptr CreateCuptiCollector( - const CuptiTracerCollectorOptions& options, const uint64 start_walltime_ns, - const uint64 start_gputime_ns) { + const CuptiTracerCollectorOptions& options, + const tsl::uint64 start_walltime_ns, const tsl::uint64 start_gputime_ns) { return std::make_unique(options, start_walltime_ns, start_gputime_ns); } @@ -621,4 +640,4 @@ absl::string_view GetMemoryKindName(int8_t memory_kind) { } } // namespace profiler -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/xla/backends/profiler/gpu/cupti_collector.h b/tensorflow/compiler/xla/backends/profiler/gpu/cupti_collector.h new file mode 100644 index 00000000000..654a906834e --- /dev/null +++ b/tensorflow/compiler/xla/backends/profiler/gpu/cupti_collector.h @@ -0,0 +1,277 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_BACKENDS_PROFILER_GPU_CUPTI_COLLECTOR_H_ +#define TENSORFLOW_COMPILER_XLA_BACKENDS_PROFILER_GPU_CUPTI_COLLECTOR_H_ + +#include + +#include "absl/container/fixed_array.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/node_hash_set.h" +#include "absl/strings/string_view.h" +#include "tensorflow/tsl/platform/macros.h" +#include "tensorflow/tsl/platform/status.h" +#include "tensorflow/tsl/platform/types.h" +#include "tensorflow/tsl/profiler/protobuf/xplane.pb.h" + +namespace xla { +namespace profiler { + +struct MemcpyDetails { + // The amount of data copied for memcpy events. + size_t num_bytes; + // The destination device for peer-2-peer communication (memcpy). The source + // device is implicit: it's the current device. + tsl::uint32 destination; + // Whether or not the memcpy is asynchronous. + bool async; + // This contains CUpti_ActivityMemcpyKind for activity event (on device). + // For events from other CuptiTracerEventSource, it is always 0. + tsl::int8 copy_kind; + // CUpti_ActivityMemoryKind of source. + tsl::int8 src_mem_kind; + // CUpti_ActivityMemoryKind of destination. + tsl::int8 dst_mem_kind; + + // ID of the hardware channel on which this operation ran. + uint32_t channel_id = -1; + // CUpti_ChannelType of the channel above. + int8_t channel_type = 0; // CUPTI_CHANNEL_TYPE_INVALID +}; + +struct MemAllocDetails { + // Size of memory to be written over in bytes. + size_t num_bytes; + // The CUpti_ActivityMemoryKind value for this activity event. + tsl::int8 mem_kind; + // The virtual address of allocation. 0 if it is a free operation. + tsl::uint64 address; +}; + +using MemFreeDetails = MemAllocDetails; + +// Memory residency contains details read from CUpti_ActivityMemory type. This +// is populated in the CUPTI tracer encounters a CUPTI_ACTIVITY_KIND_MEMORY +// event. The start of this even corresponse to a cudaMalloc, and the end +// corresponds to a cudaFree. +using MemoryResidencyDetails = MemAllocDetails; + +struct MemsetDetails { + // Size of memory to be written over in bytes. + size_t num_bytes; + // The CUpti_ActivityMemoryKind value for this activity event. + tsl::int8 mem_kind; + // Whether or not the memset is asynchronous. + bool async; + + // ID of the hardware channel on which this operation ran. + uint32_t channel_id = -1; + // CUpti_ChannelType of the channel above. + int8_t channel_type = 0; // CUPTI_CHANNEL_TYPE_INVALID +}; + +struct KernelDetails { + // The number of registers used in this kernel. + tsl::uint32 registers_per_thread; + // The amount of shared memory space used by a thread block. + tsl::uint32 static_shared_memory_usage; + // The amount of dynamic memory space used by a thread block. + tsl::uint32 dynamic_shared_memory_usage; + // X-dimension of a thread block. + tsl::uint32 block_x; + // Y-dimension of a thread block. + tsl::uint32 block_y; + // Z-dimension of a thread block. + tsl::uint32 block_z; + // X-dimension of a grid. + tsl::uint32 grid_x; + // Y-dimension of a grid. + tsl::uint32 grid_y; + // Z-dimension of a grid. + tsl::uint32 grid_z; + + // ID of the hardware channel on which this operation ran. + uint32_t channel_id = -1; + // CUpti_ChannelType of the channel above. 
+ int8_t channel_type = 0; // CUPTI_CHANNEL_TYPE_INVALID +}; + +inline std::string ToXStat(const KernelDetails& kernel_info, + double occupancy_pct) { + return absl::StrCat( + "regs:", kernel_info.registers_per_thread, + " static_shared:", kernel_info.static_shared_memory_usage, + " dynamic_shared:", kernel_info.dynamic_shared_memory_usage, + " grid:", kernel_info.grid_x, ",", kernel_info.grid_y, ",", + kernel_info.grid_z, " block:", kernel_info.block_x, ",", + kernel_info.block_y, ",", kernel_info.block_z, + " occ_pct:", occupancy_pct); +} + +// Gets the name of the CUpti_ActivityMemoryKind value. +absl::string_view GetMemoryKindName(int8_t memory_kind); + +enum class CuptiTracerEventType { + Unsupported = 0, + Kernel = 1, + MemcpyH2D = 2, + MemcpyD2H = 3, + MemcpyD2D = 4, + MemcpyP2P = 5, + MemcpyOther = 6, + MemoryAlloc = 7, + Overhead = 8, + UnifiedMemory = 9, + MemoryFree = 10, + Memset = 11, + MemoryResidency = 12, + Generic = 100, +}; + +const char* GetTraceEventTypeName(const CuptiTracerEventType& type); + +enum class CuptiTracerEventSource { + Invalid = 0, + DriverCallback = 1, + Activity = 2, + // Maybe consider adding runtime callback and metric api in the future. +}; + +struct CuptiTracerEvent { + static constexpr tsl::uint32 kInvalidThreadId = + std::numeric_limits::max(); + static constexpr tsl::uint32 kInvalidCorrelationId = + std::numeric_limits::max(); + static constexpr tsl::uint64 kInvalidContextId = + std::numeric_limits::max(); + static constexpr tsl::uint64 kInvalidStreamId = + std::numeric_limits::max(); + CuptiTracerEventType type = CuptiTracerEventType::Unsupported; + CuptiTracerEventSource source = CuptiTracerEventSource::Invalid; + // Although CUpti_CallbackData::functionName is persistent, however + // CUpti_ActivityKernel4::name is not persistent, therefore we need a copy of + // it. + std::string name; + // This points to strings in AnnotationMap, which should outlive the point + // where serialization happens. + absl::string_view annotation; + absl::string_view nvtx_range; + tsl::uint64 start_time_ns = 0; + tsl::uint64 end_time_ns = 0; + tsl::uint32 device_id = 0; + tsl::uint32 correlation_id = kInvalidCorrelationId; + tsl::uint32 thread_id = kInvalidThreadId; + int64_t context_id = kInvalidContextId; + int64_t stream_id = kInvalidStreamId; + union { + // For Memcpy API and activities. `type` must be Memcpy*. + MemcpyDetails memcpy_info; + // Used for MemAlloc API. `type` must be MemoryAlloc. + MemAllocDetails memalloc_info; + // Used for kernel activities. `type` must be Kernel. + KernelDetails kernel_info; + // Used for MemFree activities. `type` must be MemoryFree. + MemFreeDetails memfree_info; + // Used for Memset API and activities. `type` must be Memset. + MemsetDetails memset_info; + // Used for Memory residency activities. `type` must be MemoryResidency. + MemoryResidencyDetails memory_residency_info; + }; +}; + +struct CuptiTracerCollectorOptions { + // Maximum number of events to collect from callback API; if -1, no limit. + // if 0, the callback API is enabled to build a correlation map, but no + // events are collected. + tsl::uint64 max_callback_api_events = 2 * 1024 * 1024; + // Maximum number of events to collect from activity API; if -1, no limit. + tsl::uint64 max_activity_api_events = 2 * 1024 * 1024; + // Maximum number of annotation strings that we can accommodate. + tsl::uint64 max_annotation_strings = 1024 * 1024; + // Number of GPUs involved. 
+ tsl::uint32 num_gpus; +}; + +class AnnotationMap { + public: + struct AnnotationInfo { + absl::string_view annotation; + absl::string_view nvtx_range; + }; + + explicit AnnotationMap(tsl::uint64 max_size, tsl::uint32 num_gpus) + : max_size_(max_size), per_device_map_(num_gpus) {} + void Add(tsl::uint32 device_id, tsl::uint32 correlation_id, + const absl::string_view annotation, + const absl::string_view nvtx_range); + AnnotationInfo LookUp(tsl::uint32 device_id, tsl::uint32 correlation_id); + + private: + struct PerDeviceAnnotationMap { + // The population/consumption of annotations might happen from multiple + // callback/activity api related threads. + absl::Mutex mutex; + // Annotation tends to be repetitive, use a hash_set to store the strings, + // an use the reference to the string in the map. + absl::node_hash_set annotations; + absl::node_hash_set nvtx_ranges; + absl::flat_hash_map correlation_map; + }; + const tsl::uint64 max_size_; + absl::FixedArray per_device_map_; + + TF_DISALLOW_COPY_AND_ASSIGN(AnnotationMap); +}; + +class CuptiTraceCollector { + public: + explicit CuptiTraceCollector(const CuptiTracerCollectorOptions& options) + : options_(options), + annotation_map_(options.max_annotation_strings, options.num_gpus) {} + virtual ~CuptiTraceCollector() {} + + // Producer side functions (i.e. called by CuptiTracer). + virtual void AddEvent(CuptiTracerEvent&& event) = 0; + virtual void OnEventsDropped(const std::string& reason, + tsl::uint32 num_events) = 0; + virtual void Flush() = 0; + + // Consumer side functions (i.e. called by GPU tracer); + virtual bool Export(tensorflow::profiler::XSpace* space, + tsl::uint64 end_gpu_ns) { + return true; + } + virtual std::string ReportNumEventsIfDropped() { return ""; } + + AnnotationMap* annotation_map() { return &annotation_map_; } + + protected: + CuptiTracerCollectorOptions options_; + + private: + AnnotationMap annotation_map_; + + TF_DISALLOW_COPY_AND_ASSIGN(CuptiTraceCollector); +}; + +std::unique_ptr CreateCuptiCollector( + const CuptiTracerCollectorOptions& options, + const tsl::uint64 start_walltime_ns, const tsl::uint64 start_gputime_ns); + +} // namespace profiler +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_BACKENDS_PROFILER_GPU_CUPTI_COLLECTOR_H_ diff --git a/tensorflow/core/profiler/backends/gpu/cupti_error_manager.cc b/tensorflow/compiler/xla/backends/profiler/gpu/cupti_error_manager.cc similarity index 99% rename from tensorflow/core/profiler/backends/gpu/cupti_error_manager.cc rename to tensorflow/compiler/xla/backends/profiler/gpu/cupti_error_manager.cc index a639af78d77..9b469261099 100644 --- a/tensorflow/core/profiler/backends/gpu/cupti_error_manager.cc +++ b/tensorflow/compiler/xla/backends/profiler/gpu/cupti_error_manager.cc @@ -13,16 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/core/profiler/backends/gpu/cupti_error_manager.h" +#include "tensorflow/compiler/xla/backends/profiler/gpu/cupti_error_manager.h" #include #include "absl/debugging/leak_check.h" -#include "tensorflow/core/platform/logging.h" +#include "tensorflow/tsl/platform/logging.h" -namespace tensorflow { +namespace xla { namespace profiler { +using tsl::mutex_lock; + CuptiErrorManager::CuptiErrorManager(std::unique_ptr interface) : interface_(std::move(interface)), disabled_(0), undo_disabled_(false) {} @@ -500,4 +502,4 @@ std::string CuptiErrorManager::ResultString(CUptiResult error) const { } } // namespace profiler -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/xla/backends/profiler/gpu/cupti_error_manager.h b/tensorflow/compiler/xla/backends/profiler/gpu/cupti_error_manager.h new file mode 100644 index 00000000000..c197889507e --- /dev/null +++ b/tensorflow/compiler/xla/backends/profiler/gpu/cupti_error_manager.h @@ -0,0 +1,277 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_BACKENDS_PROFILER_GPU_CUPTI_ERROR_MANAGER_H_ +#define TENSORFLOW_COMPILER_XLA_BACKENDS_PROFILER_GPU_CUPTI_ERROR_MANAGER_H_ + +#include +#include + +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/backends/profiler/gpu/cupti_interface.h" +#include "tensorflow/tsl/platform/mutex.h" +#include "tensorflow/tsl/platform/thread_annotations.h" + +namespace xla { +namespace profiler { + +class CuptiErrorManager : public xla::profiler::CuptiInterface { + public: + explicit CuptiErrorManager(std::unique_ptr interface); + + // Returns whether CUPTI is disabled. + bool Disabled() const override { return disabled_.load(); } + + // CUPTI activity API: all thread-safe + // Disables activity monitoring. + CUptiResult ActivityDisable(CUpti_ActivityKind kind) override; + + // Enables activity monitoring. If this is successfully executed, we add + // ActivityDisable to the undo log. + CUptiResult ActivityEnable(CUpti_ActivityKind kind) override; + + // Flushes all outstanding activities. + CUptiResult ActivityFlushAll(uint32_t flag) override; + + // Gets a next activity record from a pool of already collected activity + // records. + CUptiResult ActivityGetNextRecord(uint8_t* buffer, + size_t valid_buffer_size_bytes, + CUpti_Activity** record) override; + + // Reports the number of dropped activity records. + CUptiResult ActivityGetNumDroppedRecords(CUcontext context, + uint32_t stream_id, + size_t* dropped) override; + + CUptiResult ActivityConfigureUnifiedMemoryCounter( + CUpti_ActivityUnifiedMemoryCounterConfig* config, + uint32_t count) override; + + // Registers callback functions handling activity. 
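A rough sketch of the activity-API call sequence a tracer typically drives through this interface during one collection pass; the callback bodies, buffer size, and every name below are illustrative assumptions, not part of this patch:

#include <cstddef>
#include <cstdint>

#include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h"
#include "tensorflow/compiler/xla/backends/profiler/gpu/cupti_interface.h"

// CUPTI asks the client for a scratch buffer before emitting activity records.
void CUPTIAPI OnBufferRequested(uint8_t** buffer, size_t* size,
                                size_t* max_num_records) {
  *size = 32 * 1024;
  *buffer = new uint8_t[*size];
  *max_num_records = 0;  // 0 lets CUPTI pack as many records as fit.
}

// CUPTI hands the buffer back once its records are ready to be drained.
void CUPTIAPI OnBufferCompleted(CUcontext context, uint32_t stream_id,
                                uint8_t* buffer, size_t size,
                                size_t valid_size) {
  // A real collector would walk the buffer with ActivityGetNextRecord here.
  delete[] buffer;
}

void CollectKernelActivity(xla::profiler::CuptiInterface* cupti) {
  cupti->ActivityRegisterCallbacks(OnBufferRequested, OnBufferCompleted);
  cupti->ActivityEnable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL);
  // ... the traced workload runs here ...
  cupti->ActivityFlushAll(0);
  cupti->ActivityDisable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL);
}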
+ CUptiResult ActivityRegisterCallbacks( + CUpti_BuffersCallbackRequestFunc func_buffer_requested, + CUpti_BuffersCallbackCompleteFunc func_buffer_completed) override; + + // Returns device ID for a given context. + CUptiResult GetDeviceId(CUcontext context, uint32_t* device_id) override; + + // Returns CUPTI timestamp. + CUptiResult GetTimestamp(uint64_t* timestamp) override; + + // Explicitly destroys and cleans up all resources associated with CUPTI in + // the current process. + CUptiResult Finalize() override; + + // CUPTI callback API + // Enables or disables callback. If we successfully enables callback, we add + // EnableCallback to disable callback to the undo log. + CUptiResult EnableCallback(uint32_t enable, CUpti_SubscriberHandle subscriber, + CUpti_CallbackDomain domain, + CUpti_CallbackId callback_id) override; + + // Enables or disables callback domain. If we successfully enables a domain, + // we add EnableDomain to disable the domain to the undo log. + CUptiResult EnableDomain(uint32_t enable, CUpti_SubscriberHandle subscriber, + CUpti_CallbackDomain domain) override; + + // Subscribes callbacks. If we successfully subscribes the callback, we add + // Unsubscribe to the undo log. + CUptiResult Subscribe(CUpti_SubscriberHandle* subscriber, + CUpti_CallbackFunc callback, void* userdata) override; + + // Unsubscribes callbacks. + CUptiResult Unsubscribe(CUpti_SubscriberHandle subscriber) override; + + // CUPTI event API + // Returns a list of event domains. + CUptiResult DeviceEnumEventDomains( + CUdevice device, size_t* array_size_bytes, + CUpti_EventDomainID* domain_array) override; + + // Returns domain attributes. + CUptiResult DeviceGetEventDomainAttribute(CUdevice device, + CUpti_EventDomainID event_domain, + CUpti_EventDomainAttribute attrib, + size_t* value_size, + void* value) override; + + // Disables kernel replay mode. + CUptiResult DisableKernelReplayMode(CUcontext context) override; + + // Enables kernel replay mode. If we successfully enable kernel replay mode, + // we add DisableKernelReplayMode to the undo log. + CUptiResult EnableKernelReplayMode(CUcontext context) override; + + // Returns the number of event domains. + CUptiResult DeviceGetNumEventDomains(CUdevice device, + uint32_t* num_domains) override; + + // Returns a list of events. + CUptiResult EventDomainEnumEvents(CUpti_EventDomainID event_domain, + size_t* array_size_bytes, + CUpti_EventID* event_array) override; + + // Returns the number of events. + CUptiResult EventDomainGetNumEvents(CUpti_EventDomainID event_domain, + uint32_t* num_events) override; + + // Returns an event attribute. + CUptiResult EventGetAttribute(CUpti_EventID event, + CUpti_EventAttribute attrib, size_t* value_size, + void* value) override; + + // Convverts event ID from event name. + CUptiResult EventGetIdFromName(CUdevice device, const char* event_name, + CUpti_EventID* event) override; + + // Disables event group. + CUptiResult EventGroupDisable(CUpti_EventGroup event_group) override; + + // Enables event group. If we successfully enable an event group, we add + // EventGroupDisable to the undo log. + CUptiResult EventGroupEnable(CUpti_EventGroup event_group) override; + + // Returns an event group attribute. + CUptiResult EventGroupGetAttribute(CUpti_EventGroup event_group, + CUpti_EventGroupAttribute attrib, + size_t* value_size, void* value) override; + + // Returns a performance counter value. 
+ CUptiResult EventGroupReadEvent(CUpti_EventGroup event_group, + CUpti_ReadEventFlags flags, + CUpti_EventID event, + size_t* event_value_buffer_size_bytes, + uint64_t* event_value_buffer) override; + + // Returns an event group set attribute. + CUptiResult EventGroupSetAttribute(CUpti_EventGroup event_group, + CUpti_EventGroupAttribute attrib, + size_t value_size, void* value) override; + + // Creates an event group set. If we successfully creates an event group set, + // we add EventGroupSetsDestroy to the undo log. + CUptiResult EventGroupSetsCreate( + CUcontext context, size_t event_id_array_size_bytes, + CUpti_EventID* event_id_array, + CUpti_EventGroupSets** event_group_passes) override; + + // Destroys an event group set. + CUptiResult EventGroupSetsDestroy( + CUpti_EventGroupSets* event_group_sets) override; + + // CUPTI metric API: all thread-safe + // Enumerates metrics. + CUptiResult DeviceEnumMetrics(CUdevice device, size_t* arraySizeBytes, + CUpti_MetricID* metricArray) override; + + // Returns the number of metrics. + CUptiResult DeviceGetNumMetrics(CUdevice device, + uint32_t* num_metrics) override; + + // Converts a metric ID to a metric name. + CUptiResult MetricGetIdFromName(CUdevice device, const char* metric_name, + CUpti_MetricID* metric) override; + + // Returns the number of events required to calculate a particular metric. + CUptiResult MetricGetNumEvents(CUpti_MetricID metric, + uint32_t* num_events) override; + + // Returns a list of events required to calculate a particular metric. + CUptiResult MetricEnumEvents(CUpti_MetricID metric, + size_t* event_id_array_size_bytes, + CUpti_EventID* event_id_array) override; + + // Returns a metric attribute. + CUptiResult MetricGetAttribute(CUpti_MetricID metric, + CUpti_MetricAttribute attrib, + size_t* value_size, void* value) override; + + // Returns a metric value. + CUptiResult MetricGetValue(CUdevice device, CUpti_MetricID metric, + size_t event_id_array_size_bytes, + CUpti_EventID* event_id_array, + size_t event_value_array_size_bytes, + uint64_t* event_value_array, + uint64_t time_duration, + CUpti_MetricValue* metric_value) override; + + CUptiResult GetResultString(CUptiResult result, const char** str) override; + + CUptiResult GetContextId(CUcontext context, uint32_t* context_id) override; + + CUptiResult GetStreamIdEx(CUcontext context, CUstream stream, + uint8_t per_thread_stream, + uint32_t* stream_id) override; + + // Clears Undo stack. We are maintaining undo stack for each profiling phase. + // Once the profiling is done, we need to clear the undo stack. + void CleanUp() override; + + private: + typedef std::function UndoFunction; + + // Register undo function. + void RegisterUndoFunction(const UndoFunction& func); + + // Resets profiling status by calling some undo functions registered, + // and then disables profiling. + void UndoAndDisable(); + + // Returns a descriptive string for a CUptiResult. + std::string ResultString(CUptiResult result) const; + + // Contains a pointer to a cupti interface instance. Normally, this will point + // to a real CUPTI interface that interacts with underlying hardware, but for + // testing, we often replace this with a CUPTI mock object to mock hardware + // behavior. This will be set when CuptiBase singleton was created and an + // object that this variable points to will die when CuptiBase singleton dies, + // i.e., at the end of program execution. + std::unique_ptr interface_; + + // A vector of functions that needs to be called by Undo upon an error + // detected. 
This vector is managed like a stack through push_back and
+  // pop_back. Whenever an API function is successfully executed, its
+  // corresponding undo function will be pushed into this stack and Undo will
+  // pop and execute the undo function upon detecting an error.
+  std::vector undo_stack_ TF_GUARDED_BY(undo_stack_mu_);
+
+  // A mutex to guarantee atomicity for undo_stack_. Given that threads that
+  // can update undo_stack_ are a profiling control thread such as a webserver
+  // thread or a thread that executes a kernel during performance counter
+  // profiling, which is already serialized, the contention for this lock will
+  // be extremely low. In other words, it will be contended only when the
+  // profiling is being enabled or disabled, and we will have at most two
+  // threads that will contend for this mutex.
+  tsl::mutex undo_stack_mu_;
+
+  // Once an error is detected, we will ignore any CUPTI API call.
+  std::atomic disabled_;
+
+  // Prevent recursive undo if an UndoFunction fails.
+  bool undo_disabled_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(CuptiErrorManager);
+};
+
+}  // namespace profiler
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_BACKENDS_PROFILER_GPU_CUPTI_ERROR_MANAGER_H_
diff --git a/tensorflow/core/profiler/backends/gpu/cupti_error_manager_test.cc b/tensorflow/compiler/xla/backends/profiler/gpu/cupti_error_manager_test.cc
similarity index 88%
rename from tensorflow/core/profiler/backends/gpu/cupti_error_manager_test.cc
rename to tensorflow/compiler/xla/backends/profiler/gpu/cupti_error_manager_test.cc
index 9d08b094857..27909b31cee 100644
--- a/tensorflow/core/profiler/backends/gpu/cupti_error_manager_test.cc
+++ b/tensorflow/compiler/xla/backends/profiler/gpu/cupti_error_manager_test.cc
@@ -15,30 +15,30 @@ limitations under the License.
#if GOOGLE_CUDA -#include "tensorflow/core/profiler/backends/gpu/cupti_error_manager.h" +#include "tensorflow/compiler/xla/backends/profiler/gpu/cupti_error_manager.h" #include #include #include #include "absl/memory/memory.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/profiler/backends/gpu/cuda_test.h" -#include "tensorflow/core/profiler/backends/gpu/cupti_interface.h" -#include "tensorflow/core/profiler/backends/gpu/cupti_tracer.h" -#include "tensorflow/core/profiler/backends/gpu/cupti_wrapper.h" -#include "tensorflow/core/profiler/backends/gpu/mock_cupti.h" -#include "tensorflow/core/profiler/utils/time_utils.h" - -namespace tensorflow { +#include "tensorflow/compiler/xla/backends/profiler/gpu/cuda_test.h" +#include "tensorflow/compiler/xla/backends/profiler/gpu/cupti_interface.h" +#include "tensorflow/compiler/xla/backends/profiler/gpu/cupti_tracer.h" +#include "tensorflow/compiler/xla/backends/profiler/gpu/cupti_wrapper.h" +#include "tensorflow/compiler/xla/backends/profiler/gpu/mock_cupti.h" +#include "tensorflow/tsl/platform/test.h" +#include "tensorflow/tsl/profiler/utils/time_utils.h" + +namespace xla { namespace profiler { namespace test { -using tensorflow::profiler::CuptiInterface; -using tensorflow::profiler::CuptiTracer; -using tensorflow::profiler::CuptiTracerCollectorOptions; -using tensorflow::profiler::CuptiTracerOptions; -using tensorflow::profiler::CuptiWrapper; +using xla::profiler::CuptiInterface; +using xla::profiler::CuptiTracer; +using xla::profiler::CuptiTracerCollectorOptions; +using xla::profiler::CuptiTracerOptions; +using xla::profiler::CuptiWrapper; using ::testing::_; using ::testing::Invoke; @@ -74,7 +74,7 @@ class CuptiErrorManagerTest : public ::testing::Test { CuptiTracerCollectorOptions collector_options; collector_options.num_gpus = CuptiTracer::NumGpus(); uint64_t start_gputime_ns = CuptiTracer::GetTimestamp(); - uint64_t start_walltime_ns = tensorflow::profiler::GetCurrentTimeNanos(); + uint64_t start_walltime_ns = tsl::profiler::GetCurrentTimeNanos(); cupti_collector_ = CreateCuptiCollector( collector_options, start_walltime_ns, start_gputime_ns); } @@ -107,7 +107,7 @@ class CuptiErrorManagerTest : public ::testing::Test { // CuptiWrapper instance to which mock_ calls are delegated. std::unique_ptr cupti_wrapper_; - std::unique_ptr cupti_collector_; + std::unique_ptr cupti_collector_; }; // Verifies that failed EnableProfiling() does not kill an application. @@ -208,6 +208,6 @@ TEST_F(CuptiErrorManagerTest, GpuTraceAutoEnableTest) { } // namespace test } // namespace profiler -} // namespace tensorflow +} // namespace xla #endif // GOOGLE_CUDA diff --git a/tensorflow/compiler/xla/backends/profiler/gpu/cupti_interface.h b/tensorflow/compiler/xla/backends/profiler/gpu/cupti_interface.h new file mode 100644 index 00000000000..58fe6e78e49 --- /dev/null +++ b/tensorflow/compiler/xla/backends/profiler/gpu/cupti_interface.h @@ -0,0 +1,204 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_BACKENDS_PROFILER_GPU_CUPTI_INTERFACE_H_ +#define TENSORFLOW_COMPILER_XLA_BACKENDS_PROFILER_GPU_CUPTI_INTERFACE_H_ + +#include +#include + +#include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h" +#include "third_party/gpus/cuda/include/cuda.h" +#include "tensorflow/tsl/platform/macros.h" +#include "tensorflow/tsl/platform/types.h" + +namespace xla { +namespace profiler { + +// Provides a wrapper interface to every single CUPTI API function. This class +// is needed to create an easy mock object for CUPTI API calls. All member +// functions are defined in the following order: activity related APIs, callback +// related APIs, Event APIs, and metric APIs. Within each category, we follow +// the order in the original CUPTI documentation. +class CuptiInterface { + public: + CuptiInterface() {} + + virtual ~CuptiInterface() {} + + // CUPTI activity API + virtual CUptiResult ActivityDisable(CUpti_ActivityKind kind) = 0; + + virtual CUptiResult ActivityEnable(CUpti_ActivityKind kind) = 0; + + virtual CUptiResult ActivityFlushAll(uint32_t flag) = 0; + + virtual CUptiResult ActivityGetNextRecord(uint8_t* buffer, + size_t valid_buffer_size_bytes, + CUpti_Activity** record) = 0; + + virtual CUptiResult ActivityGetNumDroppedRecords(CUcontext context, + uint32_t stream_id, + size_t* dropped) = 0; + + virtual CUptiResult ActivityConfigureUnifiedMemoryCounter( + CUpti_ActivityUnifiedMemoryCounterConfig* config, uint32_t count) = 0; + + virtual CUptiResult ActivityRegisterCallbacks( + CUpti_BuffersCallbackRequestFunc func_buffer_requested, + CUpti_BuffersCallbackCompleteFunc func_buffer_completed) = 0; + + virtual CUptiResult GetDeviceId(CUcontext context, tsl::uint32* deviceId) = 0; + + virtual CUptiResult GetTimestamp(uint64_t* timestamp) = 0; + + virtual CUptiResult Finalize() = 0; + + // CUPTI callback API + virtual CUptiResult EnableCallback(uint32_t enable, + CUpti_SubscriberHandle subscriber, + CUpti_CallbackDomain domain, + CUpti_CallbackId cbid) = 0; + + virtual CUptiResult EnableDomain(uint32_t enable, + CUpti_SubscriberHandle subscriber, + CUpti_CallbackDomain domain) = 0; + + virtual CUptiResult Subscribe(CUpti_SubscriberHandle* subscriber, + CUpti_CallbackFunc callback, + void* userdata) = 0; + + virtual CUptiResult Unsubscribe(CUpti_SubscriberHandle subscriber) = 0; + + // CUPTI event API + virtual CUptiResult DeviceEnumEventDomains( + CUdevice device, size_t* array_size_bytes, + CUpti_EventDomainID* domain_array) = 0; + + virtual CUptiResult DeviceGetEventDomainAttribute( + CUdevice device, CUpti_EventDomainID event_domain, + CUpti_EventDomainAttribute attrib, size_t* value_size, void* value) = 0; + + virtual CUptiResult DisableKernelReplayMode(CUcontext context) = 0; + + virtual CUptiResult EnableKernelReplayMode(CUcontext context) = 0; + + virtual CUptiResult DeviceGetNumEventDomains(CUdevice device, + uint32_t* num_domains) = 0; + + virtual CUptiResult EventDomainEnumEvents(CUpti_EventDomainID event_domain, + size_t* array_size_bytes, + CUpti_EventID* event_array) = 0; + + virtual CUptiResult EventDomainGetNumEvents(CUpti_EventDomainID event_domain, + uint32_t* num_events) = 0; + + virtual CUptiResult EventGetAttribute(CUpti_EventID event, + CUpti_EventAttribute attrib, + size_t* value_size, void* value) = 0; + + virtual CUptiResult EventGetIdFromName(CUdevice device, + const char* event_name, + CUpti_EventID* event) = 0; + + virtual CUptiResult 
EventGroupDisable(CUpti_EventGroup event_group) = 0; + + virtual CUptiResult EventGroupEnable(CUpti_EventGroup event_group) = 0; + + virtual CUptiResult EventGroupGetAttribute(CUpti_EventGroup event_group, + CUpti_EventGroupAttribute attrib, + size_t* value_size, + void* value) = 0; + + virtual CUptiResult EventGroupReadEvent(CUpti_EventGroup event_group, + CUpti_ReadEventFlags flags, + CUpti_EventID event, + size_t* event_value_buffer_size_bytes, + uint64_t* eventValueBuffer) = 0; + + virtual CUptiResult EventGroupSetAttribute(CUpti_EventGroup event_group, + CUpti_EventGroupAttribute attrib, + size_t value_size, + void* value) = 0; + + virtual CUptiResult EventGroupSetsCreate( + CUcontext context, size_t event_id_array_size_bytes, + CUpti_EventID* event_id_array, + CUpti_EventGroupSets** event_group_passes) = 0; + + virtual CUptiResult EventGroupSetsDestroy( + CUpti_EventGroupSets* event_group_sets) = 0; + + // CUPTI metric API + virtual CUptiResult DeviceEnumMetrics(CUdevice device, size_t* arraySizeBytes, + CUpti_MetricID* metricArray) = 0; + + virtual CUptiResult DeviceGetNumMetrics(CUdevice device, + uint32_t* num_metrics) = 0; + + virtual CUptiResult MetricGetIdFromName(CUdevice device, + const char* metric_name, + CUpti_MetricID* metric) = 0; + + virtual CUptiResult MetricGetNumEvents(CUpti_MetricID metric, + uint32_t* num_events) = 0; + + virtual CUptiResult MetricEnumEvents(CUpti_MetricID metric, + size_t* event_id_array_size_bytes, + CUpti_EventID* event_id_array) = 0; + + virtual CUptiResult MetricGetAttribute(CUpti_MetricID metric, + CUpti_MetricAttribute attrib, + size_t* value_size, void* value) = 0; + + virtual CUptiResult MetricGetValue(CUdevice device, CUpti_MetricID metric, + size_t event_id_array_size_bytes, + CUpti_EventID* event_id_array, + size_t event_value_array_size_bytes, + uint64_t* event_value_array, + uint64_t time_duration, + CUpti_MetricValue* metric_value) = 0; + + virtual CUptiResult GetResultString(CUptiResult result, const char** str) = 0; + + virtual CUptiResult GetContextId(CUcontext context, uint32_t* context_id) = 0; + + virtual CUptiResult GetStreamIdEx(CUcontext context, CUstream stream, + uint8_t per_thread_stream, + uint32_t* stream_id) = 0; + + // Interface maintenance functions. Not directly related to CUPTI, but + // required for implementing an error resilient layer over CUPTI API. + + // Performance any clean up work that is required each time profile session + // is done. Therefore this can be called multiple times during process life + // time. + virtual void CleanUp() = 0; + + // Whether CUPTI API is currently disabled due to unrecoverable errors. + // All subsequent calls will fail immediately without forwarding calls to + // CUPTI library. 
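That disabled state is what the CuptiErrorManager above latches when a CUPTI call fails; a minimal sketch of how the pieces are meant to compose, assuming (as the error-manager test does) that CuptiWrapper is default-constructible — the factory and its name are illustrative only:

#include <memory>
#include <utility>

#include "tensorflow/compiler/xla/backends/profiler/gpu/cupti_error_manager.h"
#include "tensorflow/compiler/xla/backends/profiler/gpu/cupti_wrapper.h"

// Illustrative factory: wrap the real CUPTI bindings in the error manager so
// that a single failing CUPTI call latches Disabled() and later calls become
// cheap no-ops instead of crashing the traced process.
std::unique_ptr<xla::profiler::CuptiInterface> MakeResilientCupti() {
  auto wrapper = std::make_unique<xla::profiler::CuptiWrapper>();
  return std::make_unique<xla::profiler::CuptiErrorManager>(std::move(wrapper));
}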
+ virtual bool Disabled() const = 0; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(CuptiInterface); +}; + +CuptiInterface* GetCuptiInterface(); + +} // namespace profiler +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_BACKENDS_PROFILER_GPU_CUPTI_INTERFACE_H_ diff --git a/tensorflow/core/profiler/backends/gpu/cupti_tracer.cc b/tensorflow/compiler/xla/backends/profiler/gpu/cupti_tracer.cc similarity index 93% rename from tensorflow/core/profiler/backends/gpu/cupti_tracer.cc rename to tensorflow/compiler/xla/backends/profiler/gpu/cupti_tracer.cc index 8ea1b21da62..bf37f0d7cb3 100644 --- a/tensorflow/core/profiler/backends/gpu/cupti_tracer.cc +++ b/tensorflow/compiler/xla/backends/profiler/gpu/cupti_tracer.cc @@ -13,29 +13,34 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/profiler/backends/gpu/cupti_tracer.h" +#include "tensorflow/compiler/xla/backends/profiler/gpu/cupti_tracer.h" +#include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/node_hash_map.h" #include "absl/container/node_hash_set.h" #include "third_party/gpus/cuda/extras/CUPTI/include/cupti_activity.h" #include "third_party/gpus/cuda/extras/CUPTI/include/generated_nvtx_meta.h" -#include "tensorflow/core/lib/gtl/cleanup.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/host_info.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/profiler/backends/cpu/annotation_stack.h" -#include "tensorflow/core/profiler/backends/gpu/cupti_collector.h" -#include "tensorflow/core/profiler/backends/gpu/nvtx_utils.h" - -namespace tensorflow { +#include "tensorflow/compiler/xla/backends/profiler/gpu/cupti_collector.h" +#include "tensorflow/compiler/xla/backends/profiler/gpu/nvtx_utils.h" +#include "tensorflow/tsl/platform/env.h" +#include "tensorflow/tsl/platform/errors.h" +#include "tensorflow/tsl/platform/host_info.h" +#include "tensorflow/tsl/platform/logging.h" +#include "tensorflow/tsl/platform/macros.h" +#include "tensorflow/tsl/profiler/backends/cpu/annotation_stack.h" + +namespace xla { namespace profiler { namespace { +using tsl::Env; +using tsl::OkStatus; +using tsl::Status; +using tsl::profiler::AnnotationStack; + // CUPTI from CUDA 11.6 adds information about the hardware channel that ops // run on; this makes its way into the channel_id and channel_type fields in the // structs we export. @@ -71,7 +76,7 @@ Status ToStatus(CUptiResult result) { } const char *str = nullptr; cuptiGetResultString(result, &str); - return errors::Unavailable("CUPTI error: ", str ? str : ""); + return tsl::errors::Unavailable("CUPTI error: ", str ? str : ""); } Status ToStatus(CUresult result) { @@ -80,7 +85,7 @@ Status ToStatus(CUresult result) { } const char *str = nullptr; cuGetErrorName(result, &str); - return errors::Unavailable("CUDA error: ", str ? str : ""); + return tsl::errors::Unavailable("CUDA error: ", str ? 
str : ""); } inline void LogIfError(const Status &status) { @@ -143,9 +148,9 @@ const char *getActivityUnifiedMemoryKindString( cupti_interface_->GetResultString(status, &errstr); \ LOG(ERROR) << "function " << #expr << "failed with error " << errstr; \ if (status == CUPTI_ERROR_INSUFFICIENT_PRIVILEGES) { \ - return errors::PermissionDenied("CUPTI need root access!"); \ + return tsl::errors::PermissionDenied("CUPTI need root access!"); \ } else { \ - return errors::Internal("CUPTI call error", errstr); \ + return tsl::errors::Internal("CUPTI call error", errstr); \ } \ } \ } while (false) @@ -367,9 +372,10 @@ void CUPTIAPI ProcessCuptiActivityBuffer(CUcontext context, uint32_t stream_id, } } -void AddKernelEventUponApiExit(CuptiTraceCollector *collector, uint32 device_id, +void AddKernelEventUponApiExit(CuptiTraceCollector *collector, + tsl::uint32 device_id, const CUpti_CallbackData *cbdata, - uint64 start_time, uint64 end_time) { + tsl::uint64 start_time, tsl::uint64 end_time) { CuptiTracerEvent event{}; event.type = CuptiTracerEventType::Kernel; event.source = CuptiTracerEventSource::DriverCallback; @@ -387,8 +393,8 @@ void AddKernelEventUponApiExit(CuptiTraceCollector *collector, uint32 device_id, // Performs the actual callback for both normal and P2P memcpy operations. CuptiTracerEvent PopulateMemcpyCallbackEvent( CuptiTracerEventType type, const CUpti_CallbackData *cbdata, - size_t num_bytes, uint32 src_device, uint32 dst_device, bool async, - uint64 start_time, uint64 end_time) { + size_t num_bytes, tsl::uint32 src_device, tsl::uint32 dst_device, + bool async, tsl::uint64 start_time, tsl::uint64 end_time) { CuptiTracerEvent event{}; event.type = type; event.source = CuptiTracerEventSource::DriverCallback; @@ -409,9 +415,11 @@ CuptiTracerEvent PopulateMemcpyCallbackEvent( } void AddNormalMemcpyEventUponApiExit(CuptiTraceCollector *collector, - uint32 device_id, CUpti_CallbackId cbid, + tsl::uint32 device_id, + CUpti_CallbackId cbid, const CUpti_CallbackData *cbdata, - uint64 start_time, uint64 end_time) { + tsl::uint64 start_time, + tsl::uint64 end_time) { size_t num_bytes; CuptiTracerEventType type; bool async; @@ -426,9 +434,9 @@ void AddNormalMemcpyEventUponApiExit(CuptiTraceCollector *collector, } void AddCuMemsetEventUponApiExit(CuptiTraceCollector *collector, - uint32 device_id, CUpti_CallbackId cbid, + tsl::uint32 device_id, CUpti_CallbackId cbid, const CUpti_CallbackData *cbdata, - uint64 start_time, uint64 end_time) { + tsl::uint64 start_time, tsl::uint64 end_time) { // We are casting all variants of cuMemset to cuMemsetD8 for accessing the // first member attribute, a CUdeviceptr. 
const auto *params = @@ -459,16 +467,17 @@ void AddCuMemsetEventUponApiExit(CuptiTraceCollector *collector, void AddP2PMemcpyEventUponApiExit(CuptiTraceCollector *collector, CuptiInterface *cupti_interface, - uint32 device_id, CUpti_CallbackId cbid, + tsl::uint32 device_id, CUpti_CallbackId cbid, const CUpti_CallbackData *cbdata, - uint64 start_time, uint64 end_time) { + tsl::uint64 start_time, + tsl::uint64 end_time) { size_t num_bytes; CuptiTracerEventType type; bool async; std::tie(num_bytes, type, async) = DecodeDriverMemcpy(cbid, cbdata->functionParams); - uint32 dst_device = -1, src_device = -1; + tsl::uint32 dst_device = -1, src_device = -1; const auto *p2p_params = static_cast(cbdata->functionParams); cupti_interface->GetDeviceId(p2p_params->srcContext, &src_device); @@ -482,9 +491,10 @@ void AddP2PMemcpyEventUponApiExit(CuptiTraceCollector *collector, } void AddCuMemAllocEventUponApiExit(CuptiTraceCollector *collector, - uint32 device_id, CUpti_CallbackId cbid, + tsl::uint32 device_id, CUpti_CallbackId cbid, const CUpti_CallbackData *cbdata, - uint64 start_time, uint64 end_time) { + tsl::uint64 start_time, + tsl::uint64 end_time) { const auto *params = static_cast(cbdata->functionParams); CuptiTracerEvent event{}; @@ -505,9 +515,11 @@ void AddCuMemAllocEventUponApiExit(CuptiTraceCollector *collector, } void AddCuMemAllocPitchEventUponApiExit(CuptiTraceCollector *collector, - uint32 device_id, CUpti_CallbackId cbid, + tsl::uint32 device_id, + CUpti_CallbackId cbid, const CUpti_CallbackData *cbdata, - uint64 start_time, uint64 end_time) { + tsl::uint64 start_time, + tsl::uint64 end_time) { const auto *params = static_cast(cbdata->functionParams); CuptiTracerEvent event{}; @@ -529,9 +541,10 @@ void AddCuMemAllocPitchEventUponApiExit(CuptiTraceCollector *collector, } void AddCuMemFreeEventUponApiExit(CuptiTraceCollector *collector, - uint32 device_id, CUpti_CallbackId cbid, + tsl::uint32 device_id, CUpti_CallbackId cbid, const CUpti_CallbackData *cbdata, - uint64 start_time, uint64 end_time) { + tsl::uint64 start_time, + tsl::uint64 end_time) { const auto *params = static_cast(cbdata->functionParams); CuptiTracerEvent event{}; @@ -550,9 +563,9 @@ void AddCuMemFreeEventUponApiExit(CuptiTraceCollector *collector, } void AddGenericEventUponApiExit(CuptiTraceCollector *collector, - uint32 device_id, CUpti_CallbackId cbid, + tsl::uint32 device_id, CUpti_CallbackId cbid, const CUpti_CallbackData *cbdata, - uint64 start_time, uint64 end_time) { + tsl::uint64 start_time, tsl::uint64 end_time) { CuptiTracerEvent event{}; event.type = CuptiTracerEventType::Generic; event.source = CuptiTracerEventSource::DriverCallback; @@ -862,8 +875,8 @@ class CuptiDriverApiHookWithActivityApi : public CuptiDriverApiHook { } // Grab timestamp for API exit. API entry timestamp saved in cbdata. - uint64 end_tsc = CuptiTracer::GetTimestamp(); - uint64 start_tsc = *cbdata->correlationData; + uint64_t end_tsc = CuptiTracer::GetTimestamp(); + uint64_t start_tsc = *cbdata->correlationData; TrackContext(cbid, cbdata->context); return AddDriverApiCallbackEvent(collector_, cupti_interface_, device_id, start_tsc, end_tsc, domain, cbid, cbdata); @@ -910,11 +923,11 @@ struct KernelRecord { // record the stream and infer the context during collection. 
CUcontext context; CUstream stream; - uint32 correlation_id; + tsl::uint32 correlation_id; CUevent start_event; CUevent stop_event; KernelDetails details; - uint64 start_timestamp; + tsl::uint64 start_timestamp; }; struct MemcpyRecord { @@ -922,11 +935,11 @@ struct MemcpyRecord { size_t size_bytes; CUcontext context; CUstream stream; - uint32 correlation_id; + tsl::uint32 correlation_id; bool async; CUevent start_event; CUevent stop_event; - uint64 start_timestamp; + tsl::uint64 start_timestamp; }; Status CreateAndRecordEvent(CUevent *event, CUstream stream) { @@ -945,7 +958,7 @@ class ScopedCudaContext { CUcontext context; if (cuStreamGetCtx(stream, &context) != CUDA_SUCCESS) return; context_ = context; - uint32 device_ordinal; + tsl::uint32 device_ordinal; if (cuptiGetDeviceId(context, &device_ordinal) != CUPTI_SUCCESS) return; device_ordinal_ = device_ordinal; context_pushed_ = cuCtxPushCurrent(context) == CUDA_SUCCESS; @@ -957,17 +970,17 @@ class ScopedCudaContext { } // If successful, return the device ordinal of the relevant cuda stream. - // Otherwise absl::nullopt; - absl::optional GetDeviceOrdinal() { return device_ordinal_; } + // Otherwise std::nullopt; non-std ok + std::optional GetDeviceOrdinal() { return device_ordinal_; } // If successful, return the cuda context of the relevant cuda stream. - // Otherwise absl::nullopt; - absl::optional GetContext() { return context_; } + // Otherwise std::nullopt; + std::optional GetContext() { return context_; } private: CUstream stream_; - absl::optional context_; - absl::optional device_ordinal_; + std::optional context_; + std::optional device_ordinal_; bool context_pushed_ = false; }; #endif @@ -994,7 +1007,7 @@ class CudaEventRecorder { // to StopKernel() after the kernel launch has completed. template size_t StartKernel(const char *kernel_name, CUcontext context, - uint32 correlation_id, const T *params) { + tsl::uint32 correlation_id, const T *params) { CUstream stream = params->hStream; KernelRecord record = {kernel_name, context, stream, correlation_id}; record.details.registers_per_thread = 0; // unknown. @@ -1013,7 +1026,7 @@ class CudaEventRecorder { kernel_records_.push_back(record); return kernel_records_.size() - 1; } - uint64 StopKernel(size_t index) { + tsl::uint64 StopKernel(size_t index) { absl::MutexLock lock(&mutex_); if (index >= kernel_records_.size()) return 0; auto &record = kernel_records_[index]; @@ -1024,8 +1037,8 @@ class CudaEventRecorder { // Registers the start of a copy operation. The returned index should be // passed to StopMemcpy() after the memcpy has completed. 
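// --- Editorial aside, not part of this patch: a hedged sketch of the
// start/stop-index hand-off that the recorder comments above describe. The
// helper name OnMemcpyEnter and the MemcpyH2D enumerator are illustrative; the
// driver-API hook in this file stores the index in cbdata->correlationData in
// the same way and passes it to StopMemcpy() on the matching API-exit callback.
void OnMemcpyEnter(CudaEventRecorder* recorder, CUpti_CallbackData* cbdata,
                   CUcontext context, CUstream stream, size_t size_bytes) {
  size_t index =
      recorder->StartMemcpy(CuptiTracerEventType::MemcpyH2D, size_bytes,
                            context, stream, cbdata->correlationId,
                            /*async=*/true);
  *cbdata->correlationData = index;  // recovered later to call StopMemcpy(index)
}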
size_t StartMemcpy(CuptiTracerEventType type, size_t size_bytes, - CUcontext context, CUstream stream, uint32 correlation_id, - bool async) { + CUcontext context, CUstream stream, + tsl::uint32 correlation_id, bool async) { MemcpyRecord record = {type, size_bytes, context, stream, correlation_id, async}; record.start_timestamp = CuptiTracer::GetTimestamp(); @@ -1035,7 +1048,7 @@ class CudaEventRecorder { memcpy_records_.push_back(record); return memcpy_records_.size() - 1; } - uint64 StopMemcpy(size_t index) { + tsl::uint64 StopMemcpy(size_t index) { absl::MutexLock lock(&mutex_); if (index >= memcpy_records_.size()) return 0; auto &record = memcpy_records_[index]; @@ -1100,13 +1113,13 @@ class CudaEventRecorder { private: struct ContextInfo { - uint32 context_id = 0; + tsl::uint32 context_id = 0; int num_streams = 0; CUevent end_event; }; struct StreamInfo { - uint32 stream_id = 0; + tsl::uint32 stream_id = 0; std::string name; int index; // 0 is reserved for null stream. const ContextInfo *ctx_info; @@ -1127,7 +1140,7 @@ class CudaEventRecorder { auto it = context_infos_.find(context); if (it == context_infos_.end()) { - uint32 context_id = 0; + tsl::uint32 context_id = 0; RETURN_IF_CUPTI_ERROR( cupti_interface_->GetContextId(context, &context_id)); ContextInfo ctx_info = {context_id}; @@ -1154,7 +1167,7 @@ class CudaEventRecorder { ContextInfo *ctx_info; TF_RETURN_IF_ERROR(GetContextInfo(context, &ctx_info)); int index = stream ? ++ctx_info->num_streams : 0; - uint32 stream_id = 0; + tsl::uint32 stream_id = 0; #if defined(CUDA_API_PER_THREAD_DEFAULT_STREAM) RETURN_IF_CUPTI_ERROR( cupti_interface_->GetStreamIdEx(context, stream, 1, &stream_id)); @@ -1174,7 +1187,7 @@ class CudaEventRecorder { CuptiApiTracingDisabler disabler; float elapsed_ms = 0.0f; LogIfError(ToStatus(cuEventElapsedTime(&elapsed_ms, start, stop))); - return static_cast( + return static_cast( std::llroundf(1000 * std::max(elapsed_ms, 0.0f))); } @@ -1245,13 +1258,14 @@ class CudaEventRecorder { absl::Mutex mutex_; bool stopped_ TF_GUARDED_BY(mutex_) = false; std::vector kernel_records_ TF_GUARDED_BY(mutex_); - std::vector memcpy_records_ TF_GUARDED_BY(mutex_); + std::vector memcpy_records_ + TF_GUARDED_BY(mutex_); // non std ok CuptiInterface *cupti_interface_; CuptiTraceCollector *collector_; const int ordinal_; std::string device_name_; - uint64 end_walltime_us_; + tsl::uint64 end_walltime_us_; // Include context in key to distinguish null streams. using StreamKey = std::pair; @@ -1308,7 +1322,7 @@ class CuptiDriverApiHookWithCudaEvent : public CuptiDriverApiHook { const auto *params = static_cast( cbdata->functionParams); - std::vector record_indices; + std::vector record_indices; record_indices.reserve(params->numDevices); *cbdata->correlationData = -1; // Invalid value. const auto &annotation = AnnotationStack::Get(); @@ -1317,7 +1331,7 @@ class CuptiDriverApiHookWithCudaEvent : public CuptiDriverApiHook { ScopedCudaContext scoped_cuda_context(stream); auto dev_id = scoped_cuda_context.GetDeviceOrdinal(); auto context = scoped_cuda_context.GetContext(); - if (!dev_id) return errors::Internal("Invalid CUDA stream"); + if (!dev_id) return tsl::errors::Internal("Invalid CUDA stream"); // Because annotation are per device, therefore we need to populate // annotation for each device involved. 
collector_->annotation_map()->Add(*dev_id, cbdata->correlationId, @@ -1330,7 +1344,8 @@ class CuptiDriverApiHookWithCudaEvent : public CuptiDriverApiHook { auto *callback_context = new CuptiApiCallbackContext(std::move(record_indices)); callback_contexts_.insert(callback_context); - *cbdata->correlationData = reinterpret_cast(callback_context); + *cbdata->correlationData = + reinterpret_cast(callback_context); #else VLOG(1) << "Unhandled cuLaunchCooperativeKernelMultiDevice."; #endif @@ -1385,7 +1400,7 @@ class CuptiDriverApiHookWithCudaEvent : public CuptiDriverApiHook { const CUpti_CallbackData *cbdata) override { auto *recorder = cuda_event_recorders_[device_id].get(); if (*cbdata->correlationData == static_cast(-1)) return OkStatus(); - uint64 start_tsc = 0; + tsl::uint64 start_tsc = 0; switch (cbid) { case CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel: case CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernel: @@ -1402,12 +1417,12 @@ class CuptiDriverApiHookWithCudaEvent : public CuptiDriverApiHook { static_cast( cbdata->functionParams); if (record_indices.size() != params->numDevices) - return errors::Internal("Invalid correlation data"); + return tsl::errors::Internal("Invalid correlation data"); for (int i = 0; i < params->numDevices; ++i) { CUstream stream = params->launchParamsList[i].hStream; ScopedCudaContext scoped_cuda_context(stream); auto dev_id = scoped_cuda_context.GetDeviceOrdinal(); - if (!dev_id) return errors::Internal("Invalid CUDA stream"); + if (!dev_id) return tsl::errors::Internal("Invalid CUDA stream"); start_tsc = cuda_event_recorders_[*dev_id]->StopKernel(record_indices[i]); } @@ -1434,7 +1449,7 @@ class CuptiDriverApiHookWithCudaEvent : public CuptiDriverApiHook { } // Grab timestamp for API exit. API entry timestamp saved in cbdata. - uint64 end_tsc = CuptiTracer::GetTimestamp(); + tsl::uint64 end_tsc = CuptiTracer::GetTimestamp(); return AddDriverApiCallbackEvent(collector_, cupti_interface_, device_id, start_tsc, end_tsc, domain, cbid, cbdata); } @@ -1505,9 +1520,9 @@ class CuptiDriverApiHookWithCudaEvent : public CuptiDriverApiHook { // However there is no guarantee that we receive such callbacks in pairs, we // maintain a on-going API calls to make sure no memory leaks. 
struct CuptiApiCallbackContext { - explicit CuptiApiCallbackContext(std::vector &&r) + explicit CuptiApiCallbackContext(std::vector &&r) : record_indices(std::move(r)) {} - std::vector record_indices; + std::vector record_indices; }; const CuptiTracerOptions option_; @@ -1519,14 +1534,14 @@ class CuptiDriverApiHookWithCudaEvent : public CuptiDriverApiHook { }; /*static*/ std::string ErrorWithHostname(absl::string_view error_message) { - return absl::StrCat(port::Hostname(), ": ", error_message); + return absl::StrCat(tsl::port::Hostname(), ": ", error_message); } } // namespace /*static*/ Status CuptiDriverApiHook::AddDriverApiCallbackEvent( CuptiTraceCollector *collector, CuptiInterface *cupti_interface, - int device_id, uint64 start_tsc, uint64 end_tsc, + int device_id, tsl::uint64 start_tsc, tsl::uint64 end_tsc, CUpti_CallbackDomain domain, CUpti_CallbackId cbid, const CUpti_CallbackData *cbdata) { switch (cbid) { @@ -1681,13 +1696,13 @@ void CuptiTracer::Enable(const CuptiTracerOptions &option, } Status status = EnableApiTracing(); - need_root_access_ |= status.code() == error::PERMISSION_DENIED; + need_root_access_ |= status.code() == tsl::error::PERMISSION_DENIED; if (!status.ok()) return; if (option_->enable_activity_api) { EnableActivityTracing().IgnoreError(); } - tensorflow::profiler::AnnotationStack::Enable(true); + tsl::profiler::AnnotationStack::Enable(true); } void CuptiTracer::Disable() { @@ -1702,7 +1717,7 @@ void CuptiTracer::Disable() { collector_ = nullptr; option_.reset(); cupti_driver_api_hook_.reset(); - tensorflow::profiler::AnnotationStack::Enable(false); + tsl::profiler::AnnotationStack::Enable(false); } Status CuptiTracer::EnableApiTracing() { @@ -1809,7 +1824,7 @@ Status CuptiTracer::Finalize() { return OkStatus(); } -/*static*/ uint64 CuptiTracer::GetTimestamp() { +/*static*/ tsl::uint64 CuptiTracer::GetTimestamp() { uint64_t tsc; CuptiInterface *cupti_interface = GetCuptiInterface(); if (cupti_interface && cupti_interface->GetTimestamp(&tsc) == CUPTI_SUCCESS) { @@ -1852,15 +1867,15 @@ Status CuptiTracer::HandleCallback(CUpti_CallbackDomain domain, // API callback is called before any CUDA context is created. // This is expected to be rare, and we ignore this case. VLOG(3) << "API callback received before creation of CUDA context\n"; - return errors::Internal("cutpi callback without context"); + return tsl::errors::Internal("cutpi callback without context"); } // Grab a correct device ID. 
- uint32 device_id = -1; + tsl::uint32 device_id = -1; RETURN_IF_CUPTI_ERROR( cupti_interface_->GetDeviceId(cbdata->context, &device_id)); if (device_id >= num_gpus_) { - return errors::Internal("Invalid device id:", device_id); + return tsl::errors::Internal("Invalid device id:", device_id); } if (cbdata->callbackSite == CUPTI_API_ENTER) { @@ -1937,8 +1952,7 @@ void CuptiTracer::RequestActivityBuffer(uint8_t **buffer, size_t *size) { Status CuptiTracer::ProcessActivityBuffer(CUcontext context, uint32_t stream_id, uint8_t *buffer, size_t size) { - auto buffer_cleanup = - gtl::MakeCleanup([&]() { buffer_pool_.ReclaimBuffer(buffer); }); + absl::Cleanup buffer_cleanup = [&]() { buffer_pool_.ReclaimBuffer(buffer); }; if (size == 0) { return OkStatus(); } @@ -1946,7 +1960,7 @@ Status CuptiTracer::ProcessActivityBuffer(CUcontext context, uint32_t stream_id, LOG(WARNING) << "CUPTI activity buffer is reclaimed after flush."; return OkStatus(); } - if (cupti_interface_->Disabled()) return errors::Internal("Disabled."); + if (cupti_interface_->Disabled()) return tsl::errors::Internal("Disabled."); CUpti_Activity *record = nullptr; while (true) { @@ -1996,7 +2010,7 @@ Status CuptiTracer::ProcessActivityBuffer(CUcontext context, uint32_t stream_id, } else if (status == CUPTI_ERROR_MAX_LIMIT_REACHED) { break; } else { - return errors::Internal("Parse cupti activity buffer error."); + return tsl::errors::Internal("Parse cupti activity buffer error."); } } @@ -2005,7 +2019,7 @@ Status CuptiTracer::ProcessActivityBuffer(CUcontext context, uint32_t stream_id, RETURN_IF_CUPTI_ERROR(cupti_interface_->ActivityGetNumDroppedRecords( context, stream_id, &dropped)); if (dropped != 0) { - uint32 device_id = -1; + tsl::uint32 device_id = -1; RETURN_IF_CUPTI_ERROR(cupti_interface_->GetDeviceId(context, &device_id)); collector_->OnEventsDropped("cupti activity buffer full", dropped); } @@ -2026,4 +2040,4 @@ Status CuptiTracer::ProcessActivityBuffer(CUcontext context, uint32_t stream_id, } } // namespace profiler -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/xla/backends/profiler/gpu/cupti_tracer.h b/tensorflow/compiler/xla/backends/profiler/gpu/cupti_tracer.h new file mode 100644 index 00000000000..96e32ca164d --- /dev/null +++ b/tensorflow/compiler/xla/backends/profiler/gpu/cupti_tracer.h @@ -0,0 +1,156 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_BACKENDS_PROFILER_GPU_CUPTI_TRACER_H_ +#define TENSORFLOW_COMPILER_XLA_BACKENDS_PROFILER_GPU_CUPTI_TRACER_H_ + +#include "absl/types/optional.h" +#include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h" +#include "third_party/gpus/cuda/include/nvtx3/nvToolsExt.h" +#include "tensorflow/compiler/xla/backends/profiler/gpu/cupti_collector.h" +#include "tensorflow/compiler/xla/backends/profiler/gpu/cupti_interface.h" +#include "tensorflow/tsl/platform/errors.h" +#include "tensorflow/tsl/platform/status.h" +#include "tensorflow/tsl/platform/types.h" +#include "tensorflow/tsl/profiler/utils/buffer_pool.h" + +namespace xla { +namespace profiler { + +struct CuptiTracerOptions { + bool enable_activity_api = true; + + // Use CUDA events to enclose the kernel/memcpy to measure device activity. + // enable_event_based_activity, if true, will override the enable_activity_api + // setting. + bool enable_event_based_activity = false; + + bool required_callback_api_events = true; + // The callback ids that will be enabled and monitored; if empty, all + // callback ids are enabled using the Callback API. + // We only care about the CUPTI_CB_DOMAIN_DRIVER_API domain for now. It is kind of + // redundant to have both CUPTI_CB_DOMAIN_DRIVER_API and + // CUPTI_CB_DOMAIN_RUNTIME_API. + std::vector cbids_selected; + // Activity kinds to be collected using the Activity API. If empty, the Activity + // API is disabled. + std::vector activities_selected; + // Whether to call cuptiFinalize. + bool cupti_finalize = false; + // Whether to call cuCtxSynchronize for each device before Stop(). + bool sync_devices_before_stop = false; + // Whether to enable NVTX tracking; we need this for TensorRT tracking. + bool enable_nvtx_tracking = false; +}; + +class CuptiDriverApiHook { + public: + virtual ~CuptiDriverApiHook() {} + + virtual tsl::Status OnDriverApiEnter( + int device_id, CUpti_CallbackDomain domain, CUpti_CallbackId cbid, + const CUpti_CallbackData* callback_info) = 0; + virtual tsl::Status OnDriverApiExit( + int device_id, CUpti_CallbackDomain domain, CUpti_CallbackId cbid, + const CUpti_CallbackData* callback_info) = 0; + virtual tsl::Status SyncAndFlush() = 0; + + protected: + static tsl::Status AddDriverApiCallbackEvent( + CuptiTraceCollector* collector, CuptiInterface* cupti_interface, + int device_id, tsl::uint64 start_tsc, tsl::uint64 end_tsc, + CUpti_CallbackDomain domain, CUpti_CallbackId cbid, + const CUpti_CallbackData* callback_info); +}; + +// The class used to enable the CUPTI callback/activity API and forward the +// collected trace events to CuptiTraceCollector. There should be only one +// CuptiTracer per process. +class CuptiTracer { + public: + // Not copyable or movable + CuptiTracer(const CuptiTracer&) = delete; + CuptiTracer& operator=(const CuptiTracer&) = delete; + + // Returns a pointer to the singleton CuptiTracer. + static CuptiTracer* GetCuptiTracerSingleton(); + + // Only one profile session can be live at the same time. + bool IsAvailable() const; + bool NeedRootAccess() const { return need_root_access_; } + + void Enable(const CuptiTracerOptions& option, CuptiTraceCollector* collector); + void Disable(); + + tsl::Status HandleCallback(CUpti_CallbackDomain domain, CUpti_CallbackId cbid, + const CUpti_CallbackData* callback_info); + + // Returns a buffer and its size for CUPTI to store activities. This buffer + // will be reclaimed when CUPTI makes a callback to ProcessActivityBuffer.
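// --- Editorial aside, not part of this patch: RequestActivityBuffer and
// ProcessActivityBuffer (declared just below) are driven by the two buffer
// callbacks that CUPTI invokes after cuptiActivityRegisterCallbacks. A hedged
// sketch of that glue; the free-function names are hypothetical stand-ins for
// the static callbacks in cupti_tracer.cc.
void CUPTIAPI OnBufferRequested(uint8_t** buffer, size_t* size,
                                size_t* max_num_records) {
  *max_num_records = 0;  // 0 lets CUPTI pack as many records as fit
  CuptiTracer::GetCuptiTracerSingleton()->RequestActivityBuffer(buffer, size);
}
void CUPTIAPI OnBufferCompleted(CUcontext context, uint32_t stream_id,
                                uint8_t* buffer, size_t size,
                                size_t valid_size) {
  CuptiTracer::GetCuptiTracerSingleton()
      ->ProcessActivityBuffer(context, stream_id, buffer, valid_size)
      .IgnoreError();
}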
+ void RequestActivityBuffer(uint8_t** buffer, size_t* size); + + // Parses CUPTI activity events from the activity buffer and emits events for + // CuptiTraceCollector. This function is public because it is called from a + // registered callback. + tsl::Status ProcessActivityBuffer(CUcontext context, uint32_t stream_id, + uint8_t* buffer, size_t size); + + static uint64_t GetTimestamp(); + static int NumGpus(); + // Returns the error (if any) when using libcupti. + static std::string ErrorIfAny(); + + protected: + // Protected constructor for injecting a mock CUPTI interface for testing. + explicit CuptiTracer(CuptiInterface* cupti_interface); + + private: + // Buffer size and alignment, 32K and 8 as in CUPTI samples. + static constexpr size_t kBufferSizeInBytes = 32 * 1024; + + tsl::Status EnableApiTracing(); + tsl::Status EnableActivityTracing(); + tsl::Status DisableApiTracing(); + tsl::Status DisableActivityTracing(); + tsl::Status Finalize(); + void ConfigureActivityUnifiedMemoryCounter(bool enable); + tsl::Status HandleNVTXCallback(CUpti_CallbackId cbid, + const CUpti_CallbackData* cbdata); + + int num_gpus_; + std::optional option_; + CuptiInterface* cupti_interface_ = nullptr; + CuptiTraceCollector* collector_ = nullptr; + + // CUPTI 10.1 and higher need root access to profile. + bool need_root_access_ = false; + + bool api_tracing_enabled_ = false; + // CUPTI handle for driver or runtime API callbacks. CUPTI permits a single + // subscriber to be active at any time, which can be used to trace CUDA runtime + // and driver calls for all contexts and devices. + CUpti_SubscriberHandle subscriber_; // valid when api_tracing_enabled_. + + bool activity_tracing_enabled_ = false; + + std::unique_ptr cupti_driver_api_hook_; + + tsl::profiler::BufferPool buffer_pool_; +}; + +} // namespace profiler +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_BACKENDS_PROFILER_GPU_CUPTI_TRACER_H_ diff --git a/tensorflow/core/profiler/backends/gpu/cupti_utils.cc b/tensorflow/compiler/xla/backends/profiler/gpu/cupti_utils.cc similarity index 77% rename from tensorflow/core/profiler/backends/gpu/cupti_utils.cc rename to tensorflow/compiler/xla/backends/profiler/gpu/cupti_utils.cc index 83bd6e86165..c7f1bb3fcfb 100644 --- a/tensorflow/core/profiler/backends/gpu/cupti_utils.cc +++ b/tensorflow/compiler/xla/backends/profiler/gpu/cupti_utils.cc @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License.
==============================================================================*/ #include "absl/memory/memory.h" -#include "tensorflow/core/profiler/backends/gpu/cupti_error_manager.h" -#include "tensorflow/core/profiler/backends/gpu/cupti_interface.h" -#include "tensorflow/core/profiler/backends/gpu/cupti_wrapper.h" +#include "tensorflow/compiler/xla/backends/profiler/gpu/cupti_error_manager.h" +#include "tensorflow/compiler/xla/backends/profiler/gpu/cupti_interface.h" +#include "tensorflow/compiler/xla/backends/profiler/gpu/cupti_wrapper.h" -namespace tensorflow { +namespace xla { namespace profiler { CuptiInterface* GetCuptiInterface() { @@ -27,4 +27,4 @@ CuptiInterface* GetCuptiInterface() { } } // namespace profiler -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/core/profiler/backends/gpu/cupti_wrapper.cc b/tensorflow/compiler/xla/backends/profiler/gpu/cupti_wrapper.cc similarity index 97% rename from tensorflow/core/profiler/backends/gpu/cupti_wrapper.cc rename to tensorflow/compiler/xla/backends/profiler/gpu/cupti_wrapper.cc index c125e3d810e..99f9966c9ac 100644 --- a/tensorflow/core/profiler/backends/gpu/cupti_wrapper.cc +++ b/tensorflow/compiler/xla/backends/profiler/gpu/cupti_wrapper.cc @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/profiler/backends/gpu/cupti_wrapper.h" +#include "tensorflow/compiler/xla/backends/profiler/gpu/cupti_wrapper.h" #include -namespace tensorflow { +namespace xla { namespace profiler { CUptiResult CuptiWrapper::ActivityDisable(CUpti_ActivityKind kind) { @@ -56,7 +56,8 @@ CUptiResult CuptiWrapper::ActivityRegisterCallbacks( func_buffer_completed); } -CUptiResult CuptiWrapper::GetDeviceId(CUcontext context, uint32* deviceId) { +CUptiResult CuptiWrapper::GetDeviceId(CUcontext context, + tsl::uint32* deviceId) { return cuptiGetDeviceId(context, deviceId); } @@ -245,4 +246,4 @@ CUptiResult CuptiWrapper::GetStreamIdEx(CUcontext context, CUstream stream, } } // namespace profiler -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/xla/backends/profiler/gpu/cupti_wrapper.h b/tensorflow/compiler/xla/backends/profiler/gpu/cupti_wrapper.h new file mode 100644 index 00000000000..c1fb2999f31 --- /dev/null +++ b/tensorflow/compiler/xla/backends/profiler/gpu/cupti_wrapper.h @@ -0,0 +1,185 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_BACKENDS_PROFILER_GPU_CUPTI_WRAPPER_H_ +#define TENSORFLOW_COMPILER_XLA_BACKENDS_PROFILER_GPU_CUPTI_WRAPPER_H_ + +#include +#include + +#include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h" +#include "third_party/gpus/cuda/include/cuda.h" +#include "tensorflow/compiler/xla/backends/profiler/gpu/cupti_interface.h" + +namespace xla { +namespace profiler { + +class CuptiWrapper : public xla::profiler::CuptiInterface { + public: + CuptiWrapper() {} + + ~CuptiWrapper() override {} + + // CUPTI activity API + CUptiResult ActivityDisable(CUpti_ActivityKind kind) override; + + CUptiResult ActivityEnable(CUpti_ActivityKind kind) override; + + CUptiResult ActivityFlushAll(uint32_t flag) override; + + CUptiResult ActivityGetNextRecord(uint8_t* buffer, + size_t valid_buffer_size_bytes, + CUpti_Activity** record) override; + + CUptiResult ActivityGetNumDroppedRecords(CUcontext context, + uint32_t stream_id, + size_t* dropped) override; + + CUptiResult ActivityConfigureUnifiedMemoryCounter( + CUpti_ActivityUnifiedMemoryCounterConfig* config, + uint32_t count) override; + + CUptiResult ActivityRegisterCallbacks( + CUpti_BuffersCallbackRequestFunc func_buffer_requested, + CUpti_BuffersCallbackCompleteFunc func_buffer_completed) override; + + CUptiResult GetDeviceId(CUcontext context, tsl::uint32* deviceId) override; + + CUptiResult GetTimestamp(uint64_t* timestamp) override; + + // cuptiFinalize is only defined in CUDA8 and above. + // To enable it in CUDA8, the environment variable CUPTI_ENABLE_FINALIZE must + // be set to 1. + CUptiResult Finalize() override; + + // CUPTI callback API + CUptiResult EnableCallback(uint32_t enable, CUpti_SubscriberHandle subscriber, + CUpti_CallbackDomain domain, + CUpti_CallbackId cbid) override; + + CUptiResult EnableDomain(uint32_t enable, CUpti_SubscriberHandle subscriber, + CUpti_CallbackDomain domain) override; + + CUptiResult Subscribe(CUpti_SubscriberHandle* subscriber, + CUpti_CallbackFunc callback, void* userdata) override; + + CUptiResult Unsubscribe(CUpti_SubscriberHandle subscriber) override; + + // CUPTI event API + CUptiResult DeviceEnumEventDomains( + CUdevice device, size_t* array_size_bytes, + CUpti_EventDomainID* domain_array) override; + + CUptiResult DeviceGetEventDomainAttribute(CUdevice device, + CUpti_EventDomainID event_domain, + CUpti_EventDomainAttribute attrib, + size_t* value_size, + void* value) override; + + CUptiResult DisableKernelReplayMode(CUcontext context) override; + + CUptiResult EnableKernelReplayMode(CUcontext context) override; + + CUptiResult DeviceGetNumEventDomains(CUdevice device, + uint32_t* num_domains) override; + + CUptiResult EventDomainEnumEvents(CUpti_EventDomainID event_domain, + size_t* array_size_bytes, + CUpti_EventID* event_array) override; + + CUptiResult EventDomainGetNumEvents(CUpti_EventDomainID event_domain, + uint32_t* num_events) override; + + CUptiResult EventGetAttribute(CUpti_EventID event, + CUpti_EventAttribute attrib, size_t* value_size, + void* value) override; + + CUptiResult EventGetIdFromName(CUdevice device, const char* event_name, + CUpti_EventID* event) override; + + CUptiResult EventGroupDisable(CUpti_EventGroup event_group) override; + + CUptiResult EventGroupEnable(CUpti_EventGroup event_group) override; + + CUptiResult EventGroupGetAttribute(CUpti_EventGroup event_group, + CUpti_EventGroupAttribute attrib, + size_t* value_size, void* value) override; + + CUptiResult 
EventGroupReadEvent(CUpti_EventGroup event_group, + CUpti_ReadEventFlags flags, + CUpti_EventID event, + size_t* event_value_buffer_size_bytes, + uint64_t* event_value_buffer) override; + + CUptiResult EventGroupSetAttribute(CUpti_EventGroup event_group, + CUpti_EventGroupAttribute attrib, + size_t value_size, void* value) override; + + CUptiResult EventGroupSetsCreate( + CUcontext context, size_t event_id_array_size_bytes, + CUpti_EventID* event_id_array, + CUpti_EventGroupSets** event_group_passes) override; + + CUptiResult EventGroupSetsDestroy( + CUpti_EventGroupSets* event_group_sets) override; + + // CUPTI metric API + CUptiResult DeviceEnumMetrics(CUdevice device, size_t* arraySizeBytes, + CUpti_MetricID* metricArray) override; + + CUptiResult DeviceGetNumMetrics(CUdevice device, + uint32_t* num_metrics) override; + + CUptiResult MetricGetIdFromName(CUdevice device, const char* metric_name, + CUpti_MetricID* metric) override; + + CUptiResult MetricGetNumEvents(CUpti_MetricID metric, + uint32_t* num_events) override; + + CUptiResult MetricEnumEvents(CUpti_MetricID metric, + size_t* event_id_array_size_bytes, + CUpti_EventID* event_id_array) override; + + CUptiResult MetricGetAttribute(CUpti_MetricID metric, + CUpti_MetricAttribute attrib, + size_t* value_size, void* value) override; + + CUptiResult MetricGetValue(CUdevice device, CUpti_MetricID metric, + size_t event_id_array_size_bytes, + CUpti_EventID* event_id_array, + size_t event_value_array_size_bytes, + uint64_t* event_value_array, + uint64_t time_duration, + CUpti_MetricValue* metric_value) override; + + CUptiResult GetResultString(CUptiResult result, const char** str) override; + + CUptiResult GetContextId(CUcontext context, uint32_t* context_id) override; + + CUptiResult GetStreamIdEx(CUcontext context, CUstream stream, + uint8_t per_thread_stream, + uint32_t* stream_id) override; + + void CleanUp() override {} + bool Disabled() const override { return false; } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(CuptiWrapper); +}; + +} // namespace profiler +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_BACKENDS_PROFILER_GPU_CUPTI_WRAPPER_H_ diff --git a/tensorflow/core/profiler/backends/gpu/device_tracer_cuda.cc b/tensorflow/compiler/xla/backends/profiler/gpu/device_tracer_cuda.cc similarity index 83% rename from tensorflow/core/profiler/backends/gpu/device_tracer_cuda.cc rename to tensorflow/compiler/xla/backends/profiler/gpu/device_tracer_cuda.cc index 68de0d32463..3e8980178b3 100644 --- a/tensorflow/core/profiler/backends/gpu/device_tracer_cuda.cc +++ b/tensorflow/compiler/xla/backends/profiler/gpu/device_tracer_cuda.cc @@ -23,24 +23,29 @@ limitations under the License. 
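// --- Editorial aside, not part of this patch: the CUDA device tracer whose
// diff follows is essentially a thin adapter over the CuptiTracer singleton.
// A hedged sketch of the enable/collect/disable sequence it implements; the
// wrapper function ProfileOnce and the particular callback/activity selections
// are illustrative, while the CuptiTracer, CuptiTracerOptions and collector
// Export() calls are taken from the headers introduced in this patch.
void ProfileOnce(CuptiTraceCollector* collector,
                 tensorflow::profiler::XSpace* space) {
  CuptiTracerOptions options;
  options.cbids_selected = {CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel};
  options.activities_selected = {CUPTI_ACTIVITY_KIND_KERNEL,
                                 CUPTI_ACTIVITY_KIND_MEMCPY};
  CuptiTracer* tracer = CuptiTracer::GetCuptiTracerSingleton();
  if (!tracer->IsAvailable()) return;  // another profiling session is live
  tracer->Enable(options, collector);
  // ... run the workload being profiled ...
  tracer->Disable();
  collector->Export(space, CuptiTracer::GetTimestamp());
}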
#include "absl/container/fixed_array.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "tensorflow/core/framework/step_stats.pb.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/profiler/backends/gpu/cupti_collector.h" -#include "tensorflow/core/profiler/backends/gpu/cupti_tracer.h" -#include "tensorflow/core/profiler/backends/gpu/cupti_wrapper.h" -#include "tensorflow/core/profiler/lib/profiler_factory.h" -#include "tensorflow/core/profiler/lib/profiler_interface.h" -#include "tensorflow/core/profiler/protobuf/xplane.pb.h" -#include "tensorflow/core/profiler/utils/time_utils.h" -#include "tensorflow/core/util/env_var.h" - -namespace tensorflow { +#include "tensorflow/compiler/xla/backends/profiler/gpu/cupti_collector.h" +#include "tensorflow/compiler/xla/backends/profiler/gpu/cupti_tracer.h" +#include "tensorflow/compiler/xla/backends/profiler/gpu/cupti_wrapper.h" +#include "tensorflow/tsl/platform/errors.h" +#include "tensorflow/tsl/platform/macros.h" +#include "tensorflow/tsl/platform/thread_annotations.h" +#include "tensorflow/tsl/profiler/lib/profiler_factory.h" +#include "tensorflow/tsl/profiler/lib/profiler_interface.h" +#include "tensorflow/tsl/profiler/protobuf/xplane.pb.h" +#include "tensorflow/tsl/profiler/utils/time_utils.h" +#include "tensorflow/tsl/util/env_var.h" + +namespace xla { namespace profiler { +using tensorflow::ProfileOptions; +using tensorflow::profiler::XSpace; +using tsl::OkStatus; +using tsl::ReadBoolFromEnvVar; +using tsl::Status; + // GpuTracer for GPU. -class GpuTracer : public profiler::ProfilerInterface { +class GpuTracer : public tsl::profiler::ProfilerInterface { public: GpuTracer(CuptiTracer* cupti_tracer, CuptiInterface* cupti_interface) : cupti_tracer_(cupti_tracer) { @@ -73,7 +78,7 @@ class GpuTracer : public profiler::ProfilerInterface { Status GpuTracer::DoStart() { if (!cupti_tracer_->IsAvailable()) { - return errors::Unavailable("Another profile session running."); + return tsl::errors::Unavailable("Another profile session running."); } options_.cbids_selected = { @@ -149,8 +154,8 @@ Status GpuTracer::DoStart() { CuptiTracerCollectorOptions collector_options; collector_options.num_gpus = cupti_tracer_->NumGpus(); - uint64 start_gputime_ns = CuptiTracer::GetTimestamp(); - uint64 start_walltime_ns = GetCurrentTimeNanos(); + tsl::uint64 start_gputime_ns = CuptiTracer::GetTimestamp(); + tsl::uint64 start_walltime_ns = tsl::profiler::GetCurrentTimeNanos(); cupti_collector_ = CreateCuptiCollector(collector_options, start_walltime_ns, start_gputime_ns); @@ -189,7 +194,8 @@ Status GpuTracer::CollectData(XSpace* space) { VLOG(1) << "No trace data collected, session wasn't started"; return OkStatus(); case State::kStartedOk: - return errors::FailedPrecondition("Cannot collect trace before stopping"); + return tsl::errors::FailedPrecondition( + "Cannot collect trace before stopping"); case State::kStartedError: LOG(ERROR) << "Cannot collect, profiler failed to start"; return OkStatus(); @@ -206,17 +212,17 @@ Status GpuTracer::CollectData(XSpace* space) { space->add_warnings(std::move(events_dropped)); } if (cupti_collector_) { - uint64 end_gpu_ns = CuptiTracer::GetTimestamp(); + tsl::uint64 end_gpu_ns = CuptiTracer::GetTimestamp(); cupti_collector_->Export(space, end_gpu_ns); } return OkStatus(); } } - return errors::Internal("Invalid profiling state: ", profiling_state_); + return 
tsl::errors::Internal("Invalid profiling state: ", profiling_state_); } // Not in anonymous namespace for testing purposes. -std::unique_ptr CreateGpuTracer( +std::unique_ptr CreateGpuTracer( const ProfileOptions& options) { if (options.device_tracer_level() == 0) return nullptr; if (options.device_type() != ProfileOptions::GPU && @@ -237,6 +243,6 @@ auto register_gpu_tracer_factory = [] { }(); } // namespace profiler -} // namespace tensorflow +} // namespace xla #endif // GOOGLE_CUDA diff --git a/tensorflow/core/profiler/backends/gpu/device_tracer_rocm.cc b/tensorflow/compiler/xla/backends/profiler/gpu/device_tracer_rocm.cc similarity index 93% rename from tensorflow/core/profiler/backends/gpu/device_tracer_rocm.cc rename to tensorflow/compiler/xla/backends/profiler/gpu/device_tracer_rocm.cc index 2560c1b5d48..ecf9fee8d02 100644 --- a/tensorflow/core/profiler/backends/gpu/device_tracer_rocm.cc +++ b/tensorflow/compiler/xla/backends/profiler/gpu/device_tracer_rocm.cc @@ -24,25 +24,46 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/platform/abi.h" -#include "tensorflow/core/platform/env_time.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/profiler/backends/cpu/annotation_stack.h" -#include "tensorflow/core/profiler/backends/gpu/rocm_tracer.h" -#include "tensorflow/core/profiler/lib/profiler_factory.h" -#include "tensorflow/core/profiler/lib/profiler_interface.h" -#include "tensorflow/core/profiler/utils/parse_annotation.h" -#include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tensorflow/core/profiler/utils/xplane_utils.h" -#include "tensorflow/core/util/env_var.h" - -namespace tensorflow { +#include "tensorflow/compiler/xla/backends/profiler/gpu/rocm_tracer.h" +#include "tensorflow/tsl/platform/abi.h" +#include "tensorflow/tsl/platform/env_time.h" +#include "tensorflow/tsl/platform/errors.h" +#include "tensorflow/tsl/platform/macros.h" +#include "tensorflow/tsl/platform/mutex.h" +#include "tensorflow/tsl/platform/thread_annotations.h" +#include "tensorflow/tsl/profiler/backends/cpu/annotation_stack.h" +#include "tensorflow/tsl/profiler/lib/profiler_factory.h" +#include "tensorflow/tsl/profiler/lib/profiler_interface.h" +#include "tensorflow/tsl/profiler/utils/parse_annotation.h" +#include "tensorflow/tsl/profiler/utils/xplane_builder.h" +#include "tensorflow/tsl/profiler/utils/xplane_schema.h" +#include "tensorflow/tsl/profiler/utils/xplane_utils.h" +#include "tensorflow/tsl/util/env_var.h" + +namespace xla { namespace profiler { + +using tensorflow::ProfileOptions; +using tsl::mutex; +using tsl::mutex_lock; +using tsl::OkStatus; +using tsl::Status; +using tsl::profiler::Annotation; +using tsl::profiler::AnnotationStack; +using tsl::profiler::FindOrAddMutablePlaneWithName; +using tsl::profiler::GetStatTypeStr; +using tsl::profiler::GpuPlaneName; +using tsl::profiler::kDeviceVendorAMD; using tsl::profiler::kThreadIdOverhead; +using tsl::profiler::ParseAnnotationStack; +using tsl::profiler::ProfilerInterface; +using tsl::profiler::RegisterProfilerFactory; +using tsl::profiler::StatType; +using tsl::profiler::XEventBuilder; +using tsl::profiler::XEventMetadata; +using tsl::profiler::XLineBuilder; +using 
tsl::profiler::XPlaneBuilder; +using tsl::profiler::XSpace; namespace { // Set the all XLines of specified XPlane to starting walltime. @@ -52,7 +73,7 @@ namespace { // start_walltime_ns to normalize with CPU wall time. static void NormalizeTimeStamps(XPlaneBuilder* plane, uint64_t start_walltime_ns) { - plane->ForEachLine([&](tensorflow::profiler::XLineBuilder line) { + plane->ForEachLine([&](tsl::profiler::XLineBuilder line) { line.SetTimestampNs(start_walltime_ns); }); } @@ -170,8 +191,8 @@ class RocmTraceCollectorImpl : public profiler::RocmTraceCollector { void Export(XSpace* space) { uint64_t end_gputime_ns = RocmTracer::GetTimestamp(); - XPlaneBuilder host_plane( - FindOrAddMutablePlaneWithName(space, kRoctracerApiPlaneName)); + XPlaneBuilder host_plane(FindOrAddMutablePlaneWithName( + space, tsl::profiler::kRoctracerApiPlaneName)); for (int i = 0; i < options_.num_gpus; ++i) { std::string name = GpuPlaneName(i); XPlaneBuilder device_plane(FindOrAddMutablePlaneWithName(space, name)); @@ -195,20 +216,20 @@ class RocmTraceCollectorImpl : public profiler::RocmTraceCollector { uint64_t start_gputime_ns_; mutex event_maps_mutex_; - absl::flat_hash_map api_events_map_ + absl::flat_hash_map api_events_map_ TF_GUARDED_BY(event_maps_mutex_); - absl::flat_hash_map activity_api_events_map_ + absl::flat_hash_map activity_api_events_map_ TF_GUARDED_BY(event_maps_mutex_); /* Some apis such as MEMSETD32 (based on an observation with ResNet50), trigger multiple HIP ops domain activities. We keep them in a vector and merge them with api activities at flush time. */ - absl::flat_hash_map> + absl::flat_hash_map> activity_ops_events_map_ TF_GUARDED_BY(event_maps_mutex_); // This is for the APIs that we track because we need some information from // them to populate the corresponding activity that we actually track. - absl::flat_hash_map auxiliary_api_events_map_ + absl::flat_hash_map auxiliary_api_events_map_ TF_GUARDED_BY(event_maps_mutex_); const std::vector ApiActivityInfoExchange() { @@ -563,7 +584,7 @@ class RocmTraceCollectorImpl : public profiler::RocmTraceCollector { // Times 2 because HBM is DDR memory; it gets two data bits per each // data lane. auto memory_bandwidth = - uint64{2} * (mem_clock_khz)*1000 * (mem_bus_width_bits) / 8; + tsl::uint64{2} * (mem_clock_khz)*1000 * (mem_bus_width_bits) / 8; device_plane->AddStatValue( *device_plane->GetOrCreateStatMetadata( GetStatTypeStr(StatType::kDevCapMemoryBandwidth)), @@ -575,7 +596,7 @@ class RocmTraceCollectorImpl : public profiler::RocmTraceCollector { device_plane->AddStatValue( *device_plane->GetOrCreateStatMetadata( GetStatTypeStr(StatType::kDevCapMemorySize)), - static_cast(total_memory)); + static_cast(total_memory)); } auto compute_capability_major = device_properties_.major; @@ -671,7 +692,7 @@ class RocmTraceCollectorImpl : public profiler::RocmTraceCollector { << " corr. 
id:" << event.correlation_id; return; } - std::string kernel_name = port::MaybeAbiDemangle(event.name.c_str()); + std::string kernel_name = tsl::port::MaybeAbiDemangle(event.name.c_str()); if (kernel_name.empty()) { kernel_name = GetRocmTracerEventTypeName(event.type); } @@ -700,7 +721,7 @@ class RocmTraceCollectorImpl : public profiler::RocmTraceCollector { // xevent.AddStatValue( // *plane->GetOrCreateStatMetadata( // GetStatTypeStr(StatType::kContextId)), - // absl::StrCat("$$", static_cast(event.context_id))); + // absl::StrCat("$$", static_cast(event.context_id))); // } if (event.type == RocmTracerEventType::Kernel && @@ -731,10 +752,11 @@ class RocmTraceCollectorImpl : public profiler::RocmTraceCollector { occ_stats.occupancy_pct); xevent.AddStatValue(*plane->GetOrCreateStatMetadata(GetStatTypeStr( StatType::kOccupancyMinGridSize)), - static_cast(occ_stats.min_grid_size)); - xevent.AddStatValue(*plane->GetOrCreateStatMetadata(GetStatTypeStr( - StatType::kOccupancySuggestedBlockSize)), - static_cast(occ_stats.suggested_block_size)); + static_cast(occ_stats.min_grid_size)); + xevent.AddStatValue( + *plane->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kOccupancySuggestedBlockSize)), + static_cast(occ_stats.suggested_block_size)); xevent.AddStatValue(*plane->GetOrCreateStatMetadata( GetStatTypeStr(StatType::kKernelDetails)), *plane->GetOrCreateStatMetadata(ToXStat( @@ -821,7 +843,7 @@ class RocmTraceCollectorImpl : public profiler::RocmTraceCollector { } } } - bool IsHostEvent(const RocmTracerEvent& event, int64* line_id) { + bool IsHostEvent(const RocmTracerEvent& event, tsl::int64* line_id) { // DriverCallback(i.e. kernel launching) events are host events. if (event.source == RocmTracerEventSource::ApiCallback) { *line_id = event.thread_id; @@ -859,7 +881,7 @@ class RocmTraceCollectorImpl : public profiler::RocmTraceCollector { int host_ev_cnt = 0, dev_ev_cnt = 0; mutex_lock l(events_mutex); // Tracking event types per line. 
- absl::flat_hash_map> + absl::flat_hash_map> events_types_per_line; for (const RocmTracerEvent& event : events) { int64_t line_id = RocmTracerEvent::kInvalidThreadId; @@ -899,7 +921,7 @@ class RocmTraceCollectorImpl : public profiler::RocmTraceCollector { mutex events_mutex; std::vector events TF_GUARDED_BY(events_mutex); - absl::flat_hash_map correlation_info_ + absl::flat_hash_map correlation_info_ TF_GUARDED_BY(events_mutex); absl::flat_hash_map occupancy_cache_; @@ -925,7 +947,6 @@ class GpuTracer : public profiler::ProfilerInterface { private: Status DoStart(); Status DoStop(); - Status DoCollectData(XSpace* space); RocmTracerOptions GetRocmTracerOptions(); @@ -1036,7 +1057,7 @@ RocmTraceCollectorOptions GpuTracer::GetRocmTraceCollectorOptions( Status GpuTracer::DoStart() { if (!rocm_tracer_->IsAvailable()) { - return errors::Unavailable("Another profile session running."); + return tsl::errors::Unavailable("Another profile session running."); } AnnotationStack::Enable(true); @@ -1044,7 +1065,7 @@ Status GpuTracer::DoStart() { RocmTraceCollectorOptions trace_collector_options = GetRocmTraceCollectorOptions(rocm_tracer_->NumGpus()); uint64_t start_gputime_ns = RocmTracer::GetTimestamp(); - uint64_t start_walltime_ns = tensorflow::EnvTime::NowNanos(); + uint64_t start_walltime_ns = tsl::EnvTime::NowNanos(); rocm_trace_collector_ = std::make_unique( trace_collector_options, start_walltime_ns, start_gputime_ns); @@ -1079,18 +1100,14 @@ Status GpuTracer::Stop() { return OkStatus(); } -Status GpuTracer::DoCollectData(XSpace* space) { - if (rocm_trace_collector_) rocm_trace_collector_->Export(space); - return OkStatus(); -} - Status GpuTracer::CollectData(XSpace* space) { switch (profiling_state_) { case State::kNotStarted: VLOG(3) << "No trace data collected, session wasn't started"; return OkStatus(); case State::kStartedOk: - return errors::FailedPrecondition("Cannot collect trace before stopping"); + return tsl::errors::FailedPrecondition( + "Cannot collect trace before stopping"); case State::kStartedError: LOG(ERROR) << "Cannot collect, roctracer failed to start"; return OkStatus(); @@ -1098,11 +1115,11 @@ Status GpuTracer::CollectData(XSpace* space) { VLOG(3) << "No trace data collected"; return OkStatus(); case State::kStoppedOk: { - DoCollectData(space); + if (rocm_trace_collector_) rocm_trace_collector_->Export(space); return OkStatus(); } } - return errors::Internal("Invalid profiling state: ", profiling_state_); + return tsl::errors::Internal("Invalid profiling state: ", profiling_state_); } // Not in anonymous namespace for testing purposes. @@ -1125,6 +1142,6 @@ auto register_rocm_gpu_tracer_factory = [] { }(); } // namespace profiler -} // namespace tensorflow +} // namespace xla #endif // TENSORFLOW_USE_ROCM diff --git a/tensorflow/compiler/xla/backends/profiler/gpu/mock_cupti.h b/tensorflow/compiler/xla/backends/profiler/gpu/mock_cupti.h new file mode 100644 index 00000000000..3a7b4b3eb7e --- /dev/null +++ b/tensorflow/compiler/xla/backends/profiler/gpu/mock_cupti.h @@ -0,0 +1,168 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_BACKENDS_PROFILER_GPU_MOCK_CUPTI_H_ +#define TENSORFLOW_COMPILER_XLA_BACKENDS_PROFILER_GPU_MOCK_CUPTI_H_ + +#include +#include + +#include + +#include "tensorflow/compiler/xla/backends/profiler/gpu/cupti_interface.h" +#include "tensorflow/tsl/platform/test.h" + +namespace xla { +namespace profiler { + +// A mock object automatically generated by gmock_gen.py. +class MockCupti : public xla::profiler::CuptiInterface { + public: + MOCK_METHOD(CUptiResult, ActivityDisable, (CUpti_ActivityKind kind), + (override)); + MOCK_METHOD(CUptiResult, ActivityEnable, (CUpti_ActivityKind kind), + (override)); + MOCK_METHOD(CUptiResult, ActivityFlushAll, (uint32_t flag), (override)); + MOCK_METHOD(CUptiResult, ActivityGetNextRecord, + (uint8_t * buffer, size_t valid_buffer_size_bytes, + CUpti_Activity** record), + (override)); + MOCK_METHOD(CUptiResult, ActivityGetNumDroppedRecords, + (CUcontext context, uint32_t stream_id, size_t* dropped), + (override)); + MOCK_METHOD(CUptiResult, ActivityConfigureUnifiedMemoryCounter, + (CUpti_ActivityUnifiedMemoryCounterConfig * config, + uint32_t count), + (override)); + MOCK_METHOD(CUptiResult, ActivityRegisterCallbacks, + (CUpti_BuffersCallbackRequestFunc func_buffer_requested, + CUpti_BuffersCallbackCompleteFunc func_buffer_completed), + (override)); + MOCK_METHOD(CUptiResult, GetDeviceId, (CUcontext context, uint32_t* deviceId), + (override)); + MOCK_METHOD(CUptiResult, GetTimestamp, (uint64_t * timestamp), (override)); + MOCK_METHOD(CUptiResult, Finalize, (), (override)); + MOCK_METHOD(CUptiResult, EnableCallback, + (uint32_t enable, CUpti_SubscriberHandle subscriber, + CUpti_CallbackDomain domain, CUpti_CallbackId cbid), + (override)); + MOCK_METHOD(CUptiResult, EnableDomain, + (uint32_t enable, CUpti_SubscriberHandle subscriber, + CUpti_CallbackDomain domain), + (override)); + MOCK_METHOD(CUptiResult, Subscribe, + (CUpti_SubscriberHandle * subscriber, CUpti_CallbackFunc callback, + void* userdata), + (override)); + MOCK_METHOD(CUptiResult, Unsubscribe, (CUpti_SubscriberHandle subscriber), + (override)); + MOCK_METHOD(CUptiResult, DeviceEnumEventDomains, + (CUdevice device, size_t* array_size_bytes, + CUpti_EventDomainID* domain_array), + (override)); + MOCK_METHOD(CUptiResult, DeviceGetEventDomainAttribute, + (CUdevice device, CUpti_EventDomainID event_domain, + CUpti_EventDomainAttribute attrib, size_t* value_size, + void* value), + (override)); + MOCK_METHOD(CUptiResult, DisableKernelReplayMode, (CUcontext context), + (override)); + MOCK_METHOD(CUptiResult, EnableKernelReplayMode, (CUcontext context), + (override)); + MOCK_METHOD(CUptiResult, DeviceGetNumEventDomains, + (CUdevice device, uint32_t* num_domains), (override)); + MOCK_METHOD(CUptiResult, EventDomainEnumEvents, + (CUpti_EventDomainID event_domain, size_t* array_size_bytes, + CUpti_EventID* event_array), + (override)); + MOCK_METHOD(CUptiResult, EventDomainGetNumEvents, + (CUpti_EventDomainID event_domain, uint32_t* num_events), + (override)); + MOCK_METHOD(CUptiResult, EventGetAttribute, + (CUpti_EventID event, CUpti_EventAttribute attrib, + size_t* value_size, void* value), + (override)); + MOCK_METHOD(CUptiResult, EventGetIdFromName, + (CUdevice device, const char* event_name, CUpti_EventID* event), + (override)); + MOCK_METHOD(CUptiResult, EventGroupDisable, (CUpti_EventGroup 
event_group), + (override)); + MOCK_METHOD(CUptiResult, EventGroupEnable, (CUpti_EventGroup event_group), + (override)); + MOCK_METHOD(CUptiResult, EventGroupGetAttribute, + (CUpti_EventGroup event_group, CUpti_EventGroupAttribute attrib, + size_t* value_size, void* value), + (override)); + MOCK_METHOD(CUptiResult, EventGroupReadEvent, + (CUpti_EventGroup event_group, CUpti_ReadEventFlags flags, + CUpti_EventID event, size_t* event_value_buffer_size_bytes, + uint64_t* eventValueBuffer), + (override)); + MOCK_METHOD(CUptiResult, EventGroupSetAttribute, + (CUpti_EventGroup event_group, CUpti_EventGroupAttribute attrib, + size_t value_size, void* value), + (override)); + MOCK_METHOD(CUptiResult, EventGroupSetsCreate, + (CUcontext context, size_t event_id_array_size_bytes, + CUpti_EventID* event_id_array, + CUpti_EventGroupSets** event_group_passes), + (override)); + MOCK_METHOD(CUptiResult, EventGroupSetsDestroy, + (CUpti_EventGroupSets * event_group_sets), (override)); + MOCK_METHOD(CUptiResult, DeviceEnumMetrics, + (CUdevice device, size_t* arraySizeBytes, + CUpti_MetricID* metricArray), + (override)); + MOCK_METHOD(CUptiResult, DeviceGetNumMetrics, + (CUdevice device, uint32_t* num_metrics), (override)); + MOCK_METHOD(CUptiResult, MetricGetIdFromName, + (CUdevice device, const char* metric_name, + CUpti_MetricID* metric), + (override)); + MOCK_METHOD(CUptiResult, MetricGetNumEvents, + (CUpti_MetricID metric, uint32_t* num_events), (override)); + MOCK_METHOD(CUptiResult, MetricEnumEvents, + (CUpti_MetricID metric, size_t* event_id_array_size_bytes, + CUpti_EventID* event_id_array), + (override)); + MOCK_METHOD(CUptiResult, MetricGetAttribute, + (CUpti_MetricID metric, CUpti_MetricAttribute attrib, + size_t* value_size, void* value), + (override)); + MOCK_METHOD(CUptiResult, MetricGetValue, + (CUdevice device, CUpti_MetricID metric, + size_t event_id_array_size_bytes, CUpti_EventID* event_id_array, + size_t event_value_array_size_bytes, uint64_t* event_value_array, + uint64_t time_duration, CUpti_MetricValue* metric_value), + (override)); + MOCK_METHOD(CUptiResult, GetResultString, + (CUptiResult result, const char** str), (override)); + + MOCK_METHOD(CUptiResult, GetContextId, + (CUcontext context, uint32_t* context_id), (override)); + + MOCK_METHOD(CUptiResult, GetStreamIdEx, + (CUcontext context, CUstream stream, uint8_t per_thread_stream, + uint32_t* stream_id), + (override)); + + MOCK_METHOD(void, CleanUp, (), (override)); + MOCK_METHOD(bool, Disabled, (), (const, override)); +}; + +} // namespace profiler +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_BACKENDS_PROFILER_GPU_MOCK_CUPTI_H_ diff --git a/tensorflow/core/profiler/backends/gpu/nvtx_utils.cc b/tensorflow/compiler/xla/backends/profiler/gpu/nvtx_utils.cc similarity index 85% rename from tensorflow/core/profiler/backends/gpu/nvtx_utils.cc rename to tensorflow/compiler/xla/backends/profiler/gpu/nvtx_utils.cc index 13d6807a845..2e54e047ab9 100644 --- a/tensorflow/core/profiler/backends/gpu/nvtx_utils.cc +++ b/tensorflow/compiler/xla/backends/profiler/gpu/nvtx_utils.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/core/profiler/backends/gpu/nvtx_utils.h" +#include "tensorflow/compiler/xla/backends/profiler/gpu/nvtx_utils.h" #include "third_party/gpus/cuda/include/nvtx3/nvToolsExt.h" -#include "tensorflow/core/platform/platform.h" +#include "tensorflow/tsl/platform/platform.h" -namespace tensorflow { +namespace xla { namespace profiler { /*static*/ std::stack &NVTXRangeTracker::GetRangeStack() { @@ -27,4 +27,4 @@ namespace profiler { } } // namespace profiler -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/xla/backends/profiler/gpu/nvtx_utils.h b/tensorflow/compiler/xla/backends/profiler/gpu/nvtx_utils.h new file mode 100644 index 00000000000..f6fb1f27f12 --- /dev/null +++ b/tensorflow/compiler/xla/backends/profiler/gpu/nvtx_utils.h @@ -0,0 +1,58 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_BACKENDS_PROFILER_GPU_NVTX_UTILS_H_ +#define TENSORFLOW_COMPILER_XLA_BACKENDS_PROFILER_GPU_NVTX_UTILS_H_ + +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/tsl/platform/macros.h" + +namespace xla { +namespace profiler { + +/*** + * We have no intention to use NVTX in tensorflow right now, we use this class + * to track NVTX instrumentation inside NVIDIA libraries (such as TensorRT). + * This bears a lot of resemblance to ScopedAnnotation for now. In the future, + * we will use TraceMe to keep track trace context within a thread. + */ +class NVTXRangeTracker { + public: + static void EnterRange(const std::string& range) { + auto& range_stack = GetRangeStack(); + range_stack.push(range); + } + static void ExitRange() { + auto& range_stack = GetRangeStack(); + if (!range_stack.empty()) range_stack.pop(); + } + static const absl::string_view CurrentRange() { + auto& range_stack = GetRangeStack(); + if (!range_stack.empty()) return range_stack.top(); + return ""; + } + + private: + static std::stack& GetRangeStack(); + + TF_DISALLOW_COPY_AND_ASSIGN(NVTXRangeTracker); +}; + +} // namespace profiler +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_BACKENDS_PROFILER_GPU_NVTX_UTILS_H_ diff --git a/tensorflow/core/profiler/backends/gpu/rocm_tracer.cc b/tensorflow/compiler/xla/backends/profiler/gpu/rocm_tracer.cc similarity index 94% rename from tensorflow/core/profiler/backends/gpu/rocm_tracer.cc rename to tensorflow/compiler/xla/backends/profiler/gpu/rocm_tracer.cc index 811168993dc..fddd2f9032c 100644 --- a/tensorflow/core/profiler/backends/gpu/rocm_tracer.cc +++ b/tensorflow/compiler/xla/backends/profiler/gpu/rocm_tracer.cc @@ -13,34 +13,38 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/core/profiler/backends/gpu/rocm_tracer.h" +#include "tensorflow/compiler/xla/backends/profiler/gpu/rocm_tracer.h" #include "absl/container/flat_hash_map.h" #include "absl/container/node_hash_map.h" #include "rocm/rocm_config.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/cleanup.h" -#include "tensorflow/core/lib/hash/hash.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mem.h" -#include "tensorflow/core/profiler/backends/cpu/annotation_stack.h" -#include "tensorflow/core/profiler/utils/time_utils.h" - -namespace tensorflow { +#include "tensorflow/tsl/platform/env.h" +#include "tensorflow/tsl/platform/errors.h" +#include "tensorflow/tsl/platform/logging.h" +#include "tensorflow/tsl/platform/macros.h" +#include "tensorflow/tsl/platform/mem.h" +#include "tensorflow/tsl/profiler/backends/cpu/annotation_stack.h" +#include "tensorflow/tsl/profiler/utils/time_utils.h" + +namespace xla { namespace profiler { +namespace se = ::stream_executor; +using tsl::mutex; +using tsl::mutex_lock; +using tsl::profiler::AnnotationStack; + constexpr uint32_t RocmTracerEvent::kInvalidDeviceId; -#define RETURN_IF_ROCTRACER_ERROR(expr) \ - do { \ - roctracer_status_t status = expr; \ - if (status != ROCTRACER_STATUS_SUCCESS) { \ - const char* errstr = se::wrap::roctracer_error_string(); \ - LOG(ERROR) << "function " << #expr << "failed with error " << errstr; \ - return errors::Internal(absl::StrCat("roctracer call error", errstr)); \ - } \ +#define RETURN_IF_ROCTRACER_ERROR(expr) \ + do { \ + roctracer_status_t status = expr; \ + if (status != ROCTRACER_STATUS_SUCCESS) { \ + const char* errstr = se::wrap::roctracer_error_string(); \ + LOG(ERROR) << "function " << #expr << "failed with error " << errstr; \ + return tsl::errors::Internal( \ + absl::StrCat("roctracer call error", errstr)); \ + } \ } while (false) namespace { @@ -50,7 +54,7 @@ namespace { // it can take roughly 98ns, while it takes roughly 1ns with this caching. int32_t GetCachedTID() { static thread_local int32_t current_thread_id = - Env::Default()->GetCurrentThreadId(); + tsl::Env::Default()->GetCurrentThreadId(); return current_thread_id; } @@ -77,7 +81,7 @@ const char* GetActivityDomainName(uint32_t domain) { return ""; } -string GetActivityDomainOpName(uint32_t domain, uint32_t op) { +std::string GetActivityDomainOpName(uint32_t domain, uint32_t op) { std::ostringstream oss; oss << GetActivityDomainName(domain) << " - "; switch (domain) { @@ -267,7 +271,7 @@ const char* GetRocmTracerEventDomainName(const RocmTracerEventDomain& domain) { void DumpRocmTracerEvent(const RocmTracerEvent& event, uint64_t start_walltime_ns, uint64_t start_gputime_ns, - const string& message) { + const std::string& message) { std::ostringstream oss; oss << "correlation_id=" << event.correlation_id; oss << ",type=" << GetRocmTracerEventTypeName(event.type); @@ -308,8 +312,8 @@ void DumpRocmTracerEvent(const RocmTracerEvent& event, VLOG(3) << oss.str(); } -Status RocmApiCallbackImpl::operator()(uint32_t domain, uint32_t cbid, - const void* cbdata) { +tsl::Status RocmApiCallbackImpl::operator()(uint32_t domain, uint32_t cbid, + const void* cbdata) { /* Some APIs such as hipMalloc, implicitly work on th devices set by the user using APIs such as hipSetDevice. 
API callbacks and activity records for functions like hipMalloc does not return the device id (CUDA does). To @@ -321,7 +325,7 @@ Status RocmApiCallbackImpl::operator()(uint32_t domain, uint32_t cbid, // DumpApiCallbackData(domain, cbid, cbdata); - if (domain != ACTIVITY_DOMAIN_HIP_API) return OkStatus(); + if (domain != ACTIVITY_DOMAIN_HIP_API) return tsl::OkStatus(); const hip_api_data_t* data = reinterpret_cast(cbdata); @@ -349,7 +353,7 @@ Status RocmApiCallbackImpl::operator()(uint32_t domain, uint32_t cbid, } else { LOG(WARNING) << "An API exit callback received without API enter " "with same correlation id. Event droped!"; - return OkStatus(); // This API does not belong to us. + return tsl::OkStatus(); // This API does not belong to us. } exit_time = RocmTracer::GetTimestamp(); } @@ -430,7 +434,7 @@ Status RocmApiCallbackImpl::operator()(uint32_t domain, uint32_t cbid, break; } } - return OkStatus(); + return tsl::OkStatus(); } void RocmApiCallbackImpl::AddKernelEventUponApiExit(uint32_t cbid, @@ -869,8 +873,8 @@ void RocmApiCallbackImpl::AddSynchronizeEventUponApiExit( collector_->AddEvent(std::move(event), is_auxiliary); } -Status RocmActivityCallbackImpl::operator()(const char* begin, - const char* end) { +tsl::Status RocmActivityCallbackImpl::operator()(const char* begin, + const char* end) { // we do not dump activities in this set in logger static std::set dump_excluded_activities = { @@ -947,7 +951,7 @@ Status RocmActivityCallbackImpl::operator()(const char* begin, default: if (dump_excluded_activities.find(record->op) == dump_excluded_activities.end()) { - string drop_message( + std::string drop_message( "\nNot in the API tracked activities. Dropped!"); DumpActivityRecord(record, drop_message); } @@ -988,30 +992,36 @@ Status RocmActivityCallbackImpl::operator()(const char* begin, // markers are with 0ns duration. break; default: - string drop_message( + std::string drop_message( "\nNot in the HIP-OPS-COPY tracked activities. Dropeed!"); DumpActivityRecord(record, drop_message); break; } // switch (record->kind) break; default: - string drop_message( + std::string drop_message( "\nNot in the HIP-OPS tracked activities. Dropped!"); DumpActivityRecord(record, drop_message); break; } // switch (record->op). break; default: - string drop_message("\nNot in the tracked domain activities. Dropped!"); + std::string drop_message( + "\nNot in the tracked domain activities. 
Dropped!"); DumpActivityRecord(record, drop_message); break; } RETURN_IF_ROCTRACER_ERROR(static_cast( - roctracer_next_record(record, &record))); +#if TF_ROCM_VERSION >= 50300 + se::wrap::roctracer_next_record(record, &record) +#else + roctracer_next_record(record, &record) +#endif + )); } - return OkStatus(); + return tsl::OkStatus(); } void RocmActivityCallbackImpl::AddHipKernelActivityEvent( @@ -1401,16 +1411,18 @@ void RocmTracer::Disable() { void ApiCallback(uint32_t domain, uint32_t cbid, const void* cbdata, void* user_data) { RocmTracer* tracer = reinterpret_cast(user_data); - tracer->ApiCallbackHandler(domain, cbid, cbdata); + tracer->ApiCallbackHandler(domain, cbid, cbdata).IgnoreError(); } -void RocmTracer::ApiCallbackHandler(uint32_t domain, uint32_t cbid, - const void* cbdata) { - if (api_tracing_enabled_) (*api_cb_impl_)(domain, cbid, cbdata); +tsl::Status RocmTracer::ApiCallbackHandler(uint32_t domain, uint32_t cbid, + const void* cbdata) { + if (api_tracing_enabled_) + TF_RETURN_IF_ERROR((*api_cb_impl_)(domain, cbid, cbdata)); + return tsl::OkStatus(); } -Status RocmTracer::EnableApiTracing() { - if (api_tracing_enabled_) return OkStatus(); +tsl::Status RocmTracer::EnableApiTracing() { + if (api_tracing_enabled_) return tsl::OkStatus(); api_tracing_enabled_ = true; for (auto& iter : options_->api_callbacks) { @@ -1432,11 +1444,11 @@ Status RocmTracer::EnableApiTracing() { } } } - return OkStatus(); + return tsl::OkStatus(); } -Status RocmTracer::DisableApiTracing() { - if (!api_tracing_enabled_) return OkStatus(); +tsl::Status RocmTracer::DisableApiTracing() { + if (!api_tracing_enabled_) return tsl::OkStatus(); api_tracing_enabled_ = false; for (auto& iter : options_->api_callbacks) { @@ -1458,17 +1470,18 @@ Status RocmTracer::DisableApiTracing() { } } } - return OkStatus(); + return tsl::OkStatus(); } void ActivityCallback(const char* begin, const char* end, void* user_data) { RocmTracer* tracer = reinterpret_cast(user_data); - tracer->ActivityCallbackHandler(begin, end); + tracer->ActivityCallbackHandler(begin, end).IgnoreError(); } -void RocmTracer::ActivityCallbackHandler(const char* begin, const char* end) { +tsl::Status RocmTracer::ActivityCallbackHandler(const char* begin, + const char* end) { if (activity_tracing_enabled_) { - (*activity_cb_impl_)(begin, end); + TF_RETURN_IF_ERROR((*activity_cb_impl_)(begin, end)); } else { LOG(WARNING) << "ActivityCallbackHandler called when " "activity_tracing_enabled_ is false"; @@ -1481,14 +1494,21 @@ void RocmTracer::ActivityCallbackHandler(const char* begin, const char* end) { while (record < end_record) { DumpActivityRecord(record, "activity_tracing_enabled_ is false. 
Dropped!"); - roctracer_next_record(record, &record); +#if TF_ROCM_VERSION >= 50300 + RETURN_IF_ROCTRACER_ERROR(static_cast( + se::wrap::roctracer_next_record(record, &record))); +#else + RETURN_IF_ROCTRACER_ERROR(static_cast( + roctracer_next_record(record, &record))); +#endif } VLOG(3) << "Dropped Activity Records End"; } + return tsl::OkStatus(); } -Status RocmTracer::EnableActivityTracing() { - if (activity_tracing_enabled_) return OkStatus(); +tsl::Status RocmTracer::EnableActivityTracing() { + if (activity_tracing_enabled_) return tsl::OkStatus(); activity_tracing_enabled_ = true; if (!options_->activity_tracing.empty()) { @@ -1525,11 +1545,11 @@ Status RocmTracer::EnableActivityTracing() { } } - return OkStatus(); + return tsl::OkStatus(); } -Status RocmTracer::DisableActivityTracing() { - if (!activity_tracing_enabled_) return OkStatus(); +tsl::Status RocmTracer::DisableActivityTracing() { + if (!activity_tracing_enabled_) return tsl::OkStatus(); for (auto& iter : options_->activity_tracing) { activity_domain_t domain = iter.first; @@ -1574,13 +1594,13 @@ Status RocmTracer::DisableActivityTracing() { << ", Threshold = " << threshold; VLOG(3) << "Wait for pending activity records : sleep for " << duration_ms << " ms"; - tensorflow::profiler::SleepForMillis(duration_ms); + tsl::profiler::SleepForMillis(duration_ms); } ClearPendingActivityRecordsCount(); activity_tracing_enabled_ = false; - return OkStatus(); + return tsl::OkStatus(); } /*static*/ uint64_t RocmTracer::GetTimestamp() { @@ -1596,4 +1616,4 @@ Status RocmTracer::DisableActivityTracing() { } } // namespace profiler -} // namespace tensorflow +} // namespace xla diff --git a/tensorflow/compiler/xla/backends/profiler/gpu/rocm_tracer.h b/tensorflow/compiler/xla/backends/profiler/gpu/rocm_tracer.h new file mode 100644 index 00000000000..38f583727ad --- /dev/null +++ b/tensorflow/compiler/xla/backends/profiler/gpu/rocm_tracer.h @@ -0,0 +1,395 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_BACKENDS_PROFILER_GPU_ROCM_TRACER_H_ +#define TENSORFLOW_COMPILER_XLA_BACKENDS_PROFILER_GPU_ROCM_TRACER_H_ + +#include "absl/container/fixed_array.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/node_hash_set.h" +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/stream_executor/rocm/roctracer_wrapper.h" +#include "tensorflow/tsl/platform/errors.h" +#include "tensorflow/tsl/platform/macros.h" +#include "tensorflow/tsl/platform/status.h" +#include "tensorflow/tsl/platform/types.h" + +namespace xla { +namespace profiler { + +struct MemcpyDetails { + // The amount of data copied for memcpy events. + size_t num_bytes; + // The destination device for peer-2-peer communication (memcpy). The source + // device is implicit: it's the current device. 
+  uint32_t destination;
+  // Whether or not the memcpy is asynchronous.
+  bool async;
+};
+
+struct MemsetDetails {
+  // The number of memory elements getting set.
+  size_t num_bytes;
+  // Whether or not the memset is asynchronous.
+  bool async;
+};
+
+struct MemAllocDetails {
+  // The amount of data requested for memory allocation events.
+  uint64_t num_bytes;
+};
+
+struct KernelDetails {
+  // The number of registers used in this kernel.
+  uint32_t registers_per_thread;
+  // The amount of shared memory space used by a thread block.
+  uint32_t static_shared_memory_usage;
+  // The amount of dynamic memory space used by a thread block.
+  uint32_t dynamic_shared_memory_usage;
+  // X-dimension of a thread block.
+  uint32_t block_x;
+  // Y-dimension of a thread block.
+  uint32_t block_y;
+  // Z-dimension of a thread block.
+  uint32_t block_z;
+  // X-dimension of a grid.
+  uint32_t grid_x;
+  // Y-dimension of a grid.
+  uint32_t grid_y;
+  // Z-dimension of a grid.
+  uint32_t grid_z;
+
+  // kernel address. Used for calculating core occupancy
+  void* func_ptr;
+};
+
+// RocmTracerSyncTypes forward declaration
+enum class RocmTracerSyncTypes;
+struct SynchronizationDetails {
+  RocmTracerSyncTypes sync_type;
+};
+
+enum class RocmTracerEventType {
+  Unsupported = 0,
+  Kernel,
+  MemcpyH2D,
+  MemcpyD2H,
+  MemcpyD2D,
+  MemcpyP2P,
+  MemcpyOther,
+  MemoryAlloc,
+  MemoryFree,
+  Memset,
+  Synchronization,
+  Generic,
+};
+
+const char* GetRocmTracerEventTypeName(const RocmTracerEventType& type);
+
+enum class RocmTracerEventSource {
+  Invalid = 0,
+  ApiCallback,
+  Activity,
+};
+
+const char* GetRocmTracerEventSourceName(const RocmTracerEventSource& source);
+
+enum class RocmTracerEventDomain {
+  InvalidDomain = 0,
+  HIP_API,
+  HCC_OPS,  // TODO(rocm-profiler): rename this to HIP_OPS
+};
+enum class RocmTracerSyncTypes {
+  InvalidSync = 0,
+  StreamSynchronize,  // caller thread waits for the stream to become empty
+  EventSynchronize,   // caller thread will block until event happens
+  StreamWait          // compute stream will wait for event to happen
+};
+
+const char* GetRocmTracerEventDomainName(const RocmTracerEventDomain& domain);
+
+struct RocmTracerEvent {
+  static constexpr uint32_t kInvalidDeviceId =
+      std::numeric_limits::max();
+  static constexpr uint32_t kInvalidThreadId =
+      std::numeric_limits::max();
+  static constexpr uint32_t kInvalidCorrelationId =
+      std::numeric_limits::max();
+  static constexpr uint64_t kInvalidStreamId =
+      std::numeric_limits::max();
+  RocmTracerEventType type;
+  RocmTracerEventSource source = RocmTracerEventSource::Invalid;
+  RocmTracerEventDomain domain;
+  std::string name;
+  // This points to strings in AnnotationMap, which should outlive the point
+  // where serialization happens.
+  absl::string_view annotation;
+  absl::string_view roctx_range;
+  uint64_t start_time_ns = 0;
+  uint64_t end_time_ns = 0;
+  uint32_t device_id = kInvalidDeviceId;
+  uint32_t correlation_id = kInvalidCorrelationId;
+  uint32_t thread_id = kInvalidThreadId;
+  int64_t stream_id = kInvalidStreamId;
+  union {
+    MemcpyDetails memcpy_info;                   // If type == Memcpy*
+    MemsetDetails memset_info;                   // If type == Memset*
+    MemAllocDetails memalloc_info;               // If type == MemoryAlloc
+    KernelDetails kernel_info;                   // If type == Kernel
+    SynchronizationDetails synchronization_info; // If type == Synchronization
+  };
+};
+
+void DumpRocmTracerEvent(const RocmTracerEvent& event,
+                         uint64_t start_walltime_ns, uint64_t start_gputime_ns,
+                         const std::string& message);
+
+struct RocmTracerOptions {
+  std::set api_tracking_set;  // actual api set we want to profile
+
+  // map of domain --> ops for which we need to enable the API callbacks
+  // If the ops vector is empty, then enable API callbacks for entire domain
+  absl::flat_hash_map > api_callbacks;
+
+  // map of domain --> ops for which we need to enable the Activity records
+  // If the ops vector is empty, then enable Activity records for entire domain
+  absl::flat_hash_map >
+      activity_tracing;
+};
+
+struct RocmTraceCollectorOptions {
+  // Maximum number of events to collect from callback API; if -1, no limit.
+  // if 0, the callback API is enabled to build a correlation map, but no
+  // events are collected.
+  uint64_t max_callback_api_events;
+  // Maximum number of events to collect from activity API; if -1, no limit.
+  uint64_t max_activity_api_events;
+  // Maximum number of annotation strings that we can accommodate.
+  uint64_t max_annotation_strings;
+  // Number of GPUs involved.
+  uint32_t num_gpus;
+};
+
+class AnnotationMap {
+ public:
+  explicit AnnotationMap(uint64_t max_size) : max_size_(max_size) {}
+  void Add(uint32_t correlation_id, const std::string& annotation);
+  absl::string_view LookUp(uint32_t correlation_id);
+
+ private:
+  struct AnnotationMapImpl {
+    // The population/consumption of annotations might happen from multiple
+    // callback/activity api related threads.
+    absl::Mutex mutex;
+    // Annotation tends to be repetitive, use a hash_set to store the strings,
+    // and use the reference to the string in the map.
+    absl::node_hash_set annotations;
+    absl::flat_hash_map correlation_map;
+  };
+  const uint64_t max_size_;
+  AnnotationMapImpl map_;
+
+ public:
+  // Disable copy and move.
+  AnnotationMap(const AnnotationMap&) = delete;
+  AnnotationMap& operator=(const AnnotationMap&) = delete;
+};
+
+class RocmTraceCollector {
+ public:
+  explicit RocmTraceCollector(const RocmTraceCollectorOptions& options)
+      : options_(options), annotation_map_(options.max_annotation_strings) {}
+  virtual ~RocmTraceCollector() {}
+
+  virtual void AddEvent(RocmTracerEvent&& event, bool is_auxiliary) = 0;
+  virtual void OnEventsDropped(const std::string& reason,
+                               uint32_t num_events) = 0;
+  virtual void Flush() = 0;
+
+  AnnotationMap* annotation_map() { return &annotation_map_; }
+
+ protected:
+  RocmTraceCollectorOptions options_;
+
+ private:
+  AnnotationMap annotation_map_;
+
+ public:
+  // Disable copy and move.
+  RocmTraceCollector(const RocmTraceCollector&) = delete;
+  RocmTraceCollector& operator=(const RocmTraceCollector&) = delete;
+};
+
+class RocmTracer;
+
+class RocmApiCallbackImpl {
+ public:
+  RocmApiCallbackImpl(const RocmTracerOptions& options, RocmTracer* tracer,
+                      RocmTraceCollector* collector)
+      : options_(options), tracer_(tracer), collector_(collector) {}
+
+  tsl::Status operator()(uint32_t domain, uint32_t cbid, const void* cbdata);
+
+ private:
+  void AddKernelEventUponApiExit(uint32_t cbid, const hip_api_data_t* data,
+                                 uint64_t enter_time, uint64_t exit_time);
+  void AddNormalMemcpyEventUponApiExit(uint32_t cbid,
+                                       const hip_api_data_t* data,
+                                       uint64_t enter_time, uint64_t exit_time);
+  void AddMemcpyPeerEventUponApiExit(uint32_t cbid, const hip_api_data_t* data,
+                                     uint64_t enter_time, uint64_t exit_time);
+  void AddMemsetEventUponApiExit(uint32_t cbid, const hip_api_data_t* data,
+                                 uint64_t enter_time, uint64_t exit_time);
+  void AddMallocFreeEventUponApiExit(uint32_t cbid, const hip_api_data_t* data,
+                                     uint32_t device_id, uint64_t enter_time,
+                                     uint64_t exit_time);
+  void AddStreamSynchronizeEventUponApiExit(uint32_t cbid,
+                                            const hip_api_data_t* data,
+                                            uint64_t enter_time,
+                                            uint64_t exit_time);
+  void AddSynchronizeEventUponApiExit(uint32_t cbid, const hip_api_data_t* data,
+                                      uint64_t enter_time, uint64_t exit_time);
+
+  RocmTracerOptions options_;
+  RocmTracer* tracer_ = nullptr;
+  RocmTraceCollector* collector_ = nullptr;
+  tsl::mutex api_call_start_mutex_;
+  // TODO(rocm-profiler): replace this with absl hashmap
+  // keep a map from the corr. id to enter time for API callbacks.
+  std::map api_call_start_time_
+      TF_GUARDED_BY(api_call_start_mutex_);
+};
+
+class RocmActivityCallbackImpl {
+ public:
+  RocmActivityCallbackImpl(const RocmTracerOptions& options, RocmTracer* tracer,
+                           RocmTraceCollector* collector)
+      : options_(options), tracer_(tracer), collector_(collector) {}
+
+  tsl::Status operator()(const char* begin, const char* end);
+
+ private:
+  void AddHipKernelActivityEvent(const roctracer_record_t* record);
+  void AddNormalHipMemcpyActivityEvent(const roctracer_record_t* record);
+  void AddHipMemsetActivityEvent(const roctracer_record_t* record);
+  void AddHipMallocActivityEvent(const roctracer_record_t* record);
+  void AddHipStreamSynchronizeActivityEvent(const roctracer_record_t* record);
+  void AddHccKernelActivityEvent(const roctracer_record_t* record);
+  void AddNormalHipOpsMemcpyActivityEvent(const roctracer_record_t* record);
+  void AddHipOpsMemsetActivityEvent(const roctracer_record_t* record);
+  RocmTracerOptions options_;
+  RocmTracer* tracer_ = nullptr;
+  RocmTraceCollector* collector_ = nullptr;
+};
+
+// The class used to enable the roctracer callback/activity APIs and forward
+// the collected trace events to RocmTraceCollector. There should be only one
+// RocmTracer per process.
+class RocmTracer {
+ public:
+  // Returns a pointer to singleton RocmTracer.
+  static RocmTracer* GetRocmTracerSingleton();
+
+  // Only one profile session can be live at a time.
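+  //
+  // A minimal usage sketch (illustrative only; it assumes a concrete
+  // RocmTraceCollector implementation `collector` and a populated
+  // RocmTracerOptions instance `options` already exist):
+  //
+  //   RocmTracer* tracer = RocmTracer::GetRocmTracerSingleton();
+  //   if (tracer->IsAvailable()) {
+  //     tracer->Enable(options, &collector);  // start API/activity tracing
+  //     // ... run the GPU work to be profiled ...
+  //     tracer->Disable();                    // stop tracing and flush
+  //   }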
+ bool IsAvailable() const; + + void Enable(const RocmTracerOptions& options, RocmTraceCollector* collector); + void Disable(); + + tsl::Status ApiCallbackHandler(uint32_t domain, uint32_t cbid, + const void* cbdata); + tsl::Status ActivityCallbackHandler(const char* begin, const char* end); + + static uint64_t GetTimestamp(); + static int NumGpus(); + + void AddToPendingActivityRecords(uint32_t correlation_id) { + pending_activity_records_.Add(correlation_id); + } + + void RemoveFromPendingActivityRecords(uint32_t correlation_id) { + pending_activity_records_.Remove(correlation_id); + } + + void ClearPendingActivityRecordsCount() { pending_activity_records_.Clear(); } + + size_t GetPendingActivityRecordsCount() { + return pending_activity_records_.Count(); + } + + protected: + // protected constructor for injecting mock cupti interface for testing. + explicit RocmTracer() : num_gpus_(NumGpus()) {} + + private: + tsl::Status EnableApiTracing(); + tsl::Status DisableApiTracing(); + + tsl::Status EnableActivityTracing(); + tsl::Status DisableActivityTracing(); + + int num_gpus_; + std::optional options_; + RocmTraceCollector* collector_ = nullptr; + + bool api_tracing_enabled_ = false; + bool activity_tracing_enabled_ = false; + + RocmApiCallbackImpl* api_cb_impl_; + RocmActivityCallbackImpl* activity_cb_impl_; + + class PendingActivityRecords { + public: + // add a correlation id to the pending set + void Add(uint32_t correlation_id) { + absl::MutexLock lock(&mutex); + pending_set.insert(correlation_id); + } + // remove a correlation id from the pending set + void Remove(uint32_t correlation_id) { + absl::MutexLock lock(&mutex); + pending_set.erase(correlation_id); + } + // clear the pending set + void Clear() { + absl::MutexLock lock(&mutex); + pending_set.clear(); + } + // count the number of correlation ids in the pending set + size_t Count() { + absl::MutexLock lock(&mutex); + return pending_set.size(); + } + + private: + // set of co-relation ids for which the hcc activity record is pending + absl::flat_hash_set pending_set; + // the callback which processes the activity records (and consequently + // removes items from the pending set) is called in a separate thread + // from the one that adds item to the list. + absl::Mutex mutex; + }; + PendingActivityRecords pending_activity_records_; + + public: + // Disable copy and move. 
+ RocmTracer(const RocmTracer&) = delete; + RocmTracer& operator=(const RocmTracer&) = delete; +}; + +} // namespace profiler +} // namespace xla +#endif // TENSORFLOW_COMPILER_XLA_BACKENDS_PROFILER_GPU_ROCM_TRACER_H_ diff --git a/tensorflow/compiler/xla/backends/profiler/tpu/BUILD b/tensorflow/compiler/xla/backends/profiler/tpu/BUILD index a6518b3edd6..67261b24e7c 100644 --- a/tensorflow/compiler/xla/backends/profiler/tpu/BUILD +++ b/tensorflow/compiler/xla/backends/profiler/tpu/BUILD @@ -1,9 +1,9 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") load("//tensorflow/tsl/profiler/builds:build_config.bzl", "tf_profiler_copts") -load("//tensorflow:tensorflow.bzl", "if_with_tpu_support") +load("//tensorflow/tsl:tsl.bzl", "if_with_tpu_support") package( - default_visibility = ["//tensorflow:internal"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) @@ -11,17 +11,18 @@ cc_library( name = "tpu_tracer", srcs = if_with_tpu_support(["tpu_tracer.cc"]), copts = tf_profiler_copts(), + visibility = ["//tensorflow/compiler/xla:internal"], deps = [ "//tensorflow/compiler/xla/stream_executor/tpu:status_helper", "//tensorflow/compiler/xla/stream_executor/tpu:tpu_api", "//tensorflow/compiler/xla/stream_executor/tpu:tpu_ops_c_api_hdrs", - "//tensorflow/core/profiler:profiler_options_proto_cc", - "//tensorflow/core/profiler/protobuf:xplane_proto_cc", "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:status", "//tensorflow/tsl/platform:types", "//tensorflow/tsl/profiler/lib:profiler_factory", "//tensorflow/tsl/profiler/lib:profiler_interface", + "//tensorflow/tsl/profiler/protobuf:profiler_options_proto_cc", + "//tensorflow/tsl/profiler/protobuf:xplane_proto_cc", "//tensorflow/tsl/profiler/utils:xplane_schema", "@com_google_absl//absl/strings", ], diff --git a/tensorflow/compiler/xla/backends/profiler/tpu/tpu_tracer.cc b/tensorflow/compiler/xla/backends/profiler/tpu/tpu_tracer.cc index 2091e7f4b5b..48f9ec1bead 100644 --- a/tensorflow/compiler/xla/backends/profiler/tpu/tpu_tracer.cc +++ b/tensorflow/compiler/xla/backends/profiler/tpu/tpu_tracer.cc @@ -22,13 +22,13 @@ limitations under the License. 
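The tpu_tracer.cc hunks below only change the namespace through which the TPU C API function table is reached (tensorflow::tpu becomes stream_executor::tpu); the call pattern itself is unchanged. A minimal sketch of the new spelling, assuming the TPU C API has been initialized (StatusHelper, tpu_profiler_, and the OpsApiFn entries are the names used in the surrounding code):

  StatusHelper status;
  auto* ops_api = stream_executor::tpu::OpsApiFn();
  if (ops_api->TpuProfiler_CreateFn != nullptr) {
    ops_api->TpuProfiler_CreateFn(&tpu_profiler_, status.c_status);
  }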
#include "tensorflow/compiler/xla/stream_executor/tpu/status_helper.h" #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_api.h" #include "tensorflow/compiler/xla/stream_executor/tpu/tpu_ops_c_api.h" -#include "tensorflow/core/profiler/profiler_options.pb.h" -#include "tensorflow/core/profiler/protobuf/xplane.pb.h" #include "tensorflow/tsl/platform/errors.h" #include "tensorflow/tsl/platform/status.h" #include "tensorflow/tsl/platform/types.h" #include "tensorflow/tsl/profiler/lib/profiler_factory.h" #include "tensorflow/tsl/profiler/lib/profiler_interface.h" +#include "tensorflow/tsl/profiler/protobuf/profiler_options.pb.h" +#include "tensorflow/tsl/profiler/protobuf/xplane.pb.h" #include "tensorflow/tsl/profiler/utils/xplane_schema.h" namespace xla { @@ -62,21 +62,21 @@ class TpuTracer : public ProfilerInterface { TpuTracer::TpuTracer() { StatusHelper status; - tensorflow::tpu::OpsApiFn()->TpuProfiler_CreateFn(&tpu_profiler_, - status.c_status); + stream_executor::tpu::OpsApiFn()->TpuProfiler_CreateFn(&tpu_profiler_, + status.c_status); if (!status.ok()) { LOG(ERROR) << status.status().error_message(); } } TpuTracer::~TpuTracer() { - tensorflow::tpu::OpsApiFn()->TpuProfiler_DestroyFn(tpu_profiler_); + stream_executor::tpu::OpsApiFn()->TpuProfiler_DestroyFn(tpu_profiler_); } Status TpuTracer::Start() { StatusHelper status; - tensorflow::tpu::OpsApiFn()->TpuProfiler_StartFn(tpu_profiler_, - status.c_status); + stream_executor::tpu::OpsApiFn()->TpuProfiler_StartFn(tpu_profiler_, + status.c_status); if (!status.ok()) { LOG(ERROR) << "TPU tracer failed to start."; return status.status(); @@ -86,8 +86,8 @@ Status TpuTracer::Start() { Status TpuTracer::Stop() { StatusHelper status; - tensorflow::tpu::OpsApiFn()->TpuProfiler_StopFn(tpu_profiler_, - status.c_status); + stream_executor::tpu::OpsApiFn()->TpuProfiler_StopFn(tpu_profiler_, + status.c_status); if (!status.ok()) { LOG(ERROR) << "TPU tracer failed to stop."; return status.status(); @@ -99,13 +99,13 @@ Status TpuTracer::CollectData(XSpace* space) { StatusHelper status; // Get size of buffer required for TPU driver to serialize XSpace into. size_t size_in_bytes; - tensorflow::tpu::OpsApiFn()->TpuProfiler_CollectDataFn( + stream_executor::tpu::OpsApiFn()->TpuProfiler_CollectDataFn( tpu_profiler_, status.c_status, /*buffer=*/nullptr, &size_in_bytes); // Prepare an appropriately sized buffer. if (size_in_bytes > 0) { std::vector buffer(size_in_bytes); - tensorflow::tpu::OpsApiFn()->TpuProfiler_CollectDataFn( + stream_executor::tpu::OpsApiFn()->TpuProfiler_CollectDataFn( tpu_profiler_, status.c_status, buffer.data(), &size_in_bytes); // Deserialize XSpace from the buffer and return it. XSpace tpu_space; @@ -132,7 +132,7 @@ std::unique_ptr CreateTpuTracer( return nullptr; } // Don't attempt to create a TpuTracer if the TPU C API isn't initialized. 
- if (tensorflow::tpu::OpsApiFn()->TpuProfiler_CreateFn == nullptr) { + if (stream_executor::tpu::OpsApiFn()->TpuProfiler_CreateFn == nullptr) { return nullptr; } return std::make_unique(); diff --git a/tensorflow/compiler/xla/c/BUILD b/tensorflow/compiler/xla/c/BUILD index 21f7844ad7b..ad39c5cbfb4 100644 --- a/tensorflow/compiler/xla/c/BUILD +++ b/tensorflow/compiler/xla/c/BUILD @@ -1,6 +1,7 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ "//learning/brain/tfrt/tpu_plugin:__subpackages__", "//tensorflow/core/common_runtime/next_pluggable_device:__subpackages__", diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 2ec2f6a1021..2568100dac4 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -1,11 +1,12 @@ # Description: # XLA client libraries. -load("//tensorflow:tensorflow.default.bzl", "filegroup") +load("//tensorflow/tsl:tsl.default.bzl", "filegroup") load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") -load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_test") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:public"], licenses = ["notice"], ) @@ -56,7 +57,7 @@ cc_library( ], ) -tf_cc_test( +xla_cc_test( name = "padding_test", srcs = ["padding_test.cc"], deps = [ @@ -100,13 +101,15 @@ cc_library( "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla:xla_proto_cc", "//tensorflow/compiler/xla/pjrt:compile_options_proto_cc", + "//tensorflow/compiler/xla/service:compilation_environments", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/tsl/platform:env", - "//tensorflow/tsl/platform:logging", + "//tensorflow/tsl/platform:statusor", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", ], @@ -232,10 +235,10 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/hlo/evaluator:hlo_evaluator", - "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/service:hlo_proto_cc", - "//tensorflow/compiler/xla/stream_executor/lib", "//tensorflow/tsl/platform:errors", + "//tensorflow/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -263,7 +266,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto_cc", - "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/service:hlo_proto_cc", "//tensorflow/compiler/xla/service:shape_inference", "//tensorflow/tsl/lib/core:bitmap", @@ -278,10 +281,11 @@ cc_library( ], ) -tf_cc_test( +xla_cc_test( name = "xla_builder_test", srcs = ["xla_builder_test.cc"], deps = [ + ":sharding_builder", ":value_inference", ":xla_builder", ":xla_computation", @@ -293,7 +297,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", 
"//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", - "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/tsl/platform:test", diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc index c1d389fa839..30259c323a2 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.cc +++ b/tensorflow/compiler/xla/client/executable_build_options.cc @@ -15,10 +15,15 @@ limitations under the License. #include "tensorflow/compiler/xla/client/executable_build_options.h" +#include +#include + #include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla.pb.h" +#include "tensorflow/tsl/platform/statusor.h" namespace xla { @@ -55,6 +60,13 @@ ExecutableBuildOptions& ExecutableBuildOptions::set_result_layout( return *this; } +CompilationEnvironments* ExecutableBuildOptions::mutable_comp_envs() { + if (!has_comp_envs()) { + comp_envs_.emplace(); + } + return &*comp_envs_; +} + const Shape* ExecutableBuildOptions::result_layout() const { return result_layout_set_ ? &result_layout_ : nullptr; } @@ -109,6 +121,9 @@ StatusOr ExecutableBuildOptions::ToProto() const { if (result_layout()) { *output.mutable_result_layout() = result_layout()->ToProto(); } + if (has_comp_envs()) { + *output.mutable_comp_envs() = comp_envs().ToProto(); + } if (has_debug_options()) { *output.mutable_debug_options() = debug_options(); } @@ -128,8 +143,12 @@ StatusOr ExecutableBuildOptions::ToProto() const { } output.set_alias_passthrough_params(alias_passthrough_params()); output.set_run_backend_only(run_backend_only()); - output.set_allow_spmd_sharding_propagation_to_output( - allow_spmd_sharding_propagation_to_output()); + if (!allow_spmd_sharding_propagation_to_output().empty()) { + output.mutable_allow_spmd_sharding_propagation_to_output()->Clear(); + for (bool v : allow_spmd_sharding_propagation_to_output()) { + output.mutable_allow_spmd_sharding_propagation_to_output()->Add(v); + } + } return output; } @@ -143,6 +162,12 @@ StatusOr ExecutableBuildOptionsFromProto( if (input.has_result_layout()) { output.set_result_layout(xla::Shape(input.result_layout())); } + if (input.has_comp_envs()) { + TF_ASSIGN_OR_RETURN( + auto comp_envs, + xla::CompilationEnvironments::CreateFromProto(input.comp_envs())); + *output.mutable_comp_envs() = std::move(*comp_envs); + } if (input.has_debug_options()) { *output.mutable_debug_options() = input.debug_options(); } @@ -210,8 +235,14 @@ ExecutionOptions CreateExecutionOptions( execution_options.mutable_auto_spmd_partitioning_mesh_ids()->Add(t); } execution_options.set_deduplicate_hlo(build_options.deduplicate_hlo()); - execution_options.set_allow_spmd_sharding_propagation_to_output( - build_options.allow_spmd_sharding_propagation_to_output()); + if (!build_options.allow_spmd_sharding_propagation_to_output().empty()) { + execution_options.mutable_allow_spmd_sharding_propagation_to_output() + ->Clear(); + for (bool v : build_options.allow_spmd_sharding_propagation_to_output()) { + execution_options.mutable_allow_spmd_sharding_propagation_to_output() + ->Add(v); + } + } if (build_options.has_device_assignment()) { 
TF_CHECK_OK(build_options.device_assignment().Serialize( execution_options.mutable_device_assignment())); diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h index a6f4e246772..ed4554e9c77 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.h +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -21,8 +21,11 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/pjrt/compile_options.pb.h" +#include "tensorflow/compiler/xla/service/compilation_environments.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/xla.pb.h" @@ -59,8 +62,16 @@ class ExecutableBuildOptions { ExecutableBuildOptions& set_result_layout(const Shape& shape_with_layout); const Shape* result_layout() const; + // Expose access to the XLA compilation environments, which will be passed to + // the compilation process. `comp_envs()` must not be called if + // `has_comp_envs()` returns false. + bool has_comp_envs() const { return comp_envs_.has_value(); } + const CompilationEnvironments& comp_envs() const { return *comp_envs_; } + CompilationEnvironments* mutable_comp_envs(); + // Expose access to the XLA debug options which will be passed to the - // compilation process. + // compilation process. `debug_options()` must not be called if + // `has_debug_options()` returns false. bool has_debug_options() const { return debug_options_.has_value(); } const DebugOptions& debug_options() const { return *debug_options_; } DebugOptions* mutable_debug_options(); @@ -146,9 +157,13 @@ class ExecutableBuildOptions { return *this; } - bool allow_spmd_sharding_propagation_to_output() const { + absl::Span allow_spmd_sharding_propagation_to_output() const { return allow_spmd_sharding_propagation_to_output_; } + bool any_allow_spmd_sharding_propagation_to_output() const { + return absl::c_any_of(allow_spmd_sharding_propagation_to_output_, + [](bool b) { return b; }); + } // Allows sharding propagation to propagate to the outputs. This changes the // output shape of the computation (which is undesirable), but it can be used // to allow to run partial compilation to determine what would be the output @@ -157,9 +172,10 @@ class ExecutableBuildOptions { // sharding of operations when multiple computation would be chained and // merged together. 
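  // A caller-side sketch of the span-based setter introduced just below
  // (assuming a default-constructed ExecutableBuildOptions named
  // build_options; a one-element span mirrors the previous single-bool
  // call sites):
  //
  //   build_options.set_allow_spmd_sharding_propagation_to_output({true});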
ExecutableBuildOptions& set_allow_spmd_sharding_propagation_to_output( - bool allow_spmd_sharding_propagation_to_output) { - allow_spmd_sharding_propagation_to_output_ = - allow_spmd_sharding_propagation_to_output; + absl::Span allow_spmd_sharding_propagation_to_output) { + allow_spmd_sharding_propagation_to_output_.assign( + allow_spmd_sharding_propagation_to_output.begin(), + allow_spmd_sharding_propagation_to_output.end()); return *this; } @@ -190,6 +206,7 @@ class ExecutableBuildOptions { int device_ordinal_ = -1; Shape result_layout_; bool result_layout_set_ = false; + std::optional comp_envs_; std::optional debug_options_; se::DeviceMemoryAllocator* device_allocator_ = nullptr; int num_replicas_ = 1; @@ -203,7 +220,8 @@ class ExecutableBuildOptions { std::optional device_assignment_; bool alias_passthrough_params_ = false; bool run_backend_only_ = false; - bool allow_spmd_sharding_propagation_to_output_ = false; + absl::InlinedVector allow_spmd_sharding_propagation_to_output_ = { + false}; tsl::thread::ThreadPool* compile_thread_pool_ = nullptr; LayoutCanonicalizationCallback layout_canonicalization_callback_; }; diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 8e503c1cd08..69531d0d307 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -1,10 +1,11 @@ # Common computation builders for XLA. -load("//tensorflow:tensorflow.default.bzl", "filegroup") +load("//tensorflow/tsl:tsl.default.bzl", "filegroup") load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites", "xla_test") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//tensorflow/compiler/xla/client:friends"], licenses = ["notice"], ) diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index d18395dbfed..1972b97c299 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -1163,10 +1163,6 @@ XlaOp Asin(XlaOp x) { XlaOp Atan(XlaOp x) { return Atan2(x, ScalarLike(x, 1.0)); } -XlaOp Tan(XlaOp x) { - return DoWithUpcastToF32(x, {F16}, [](XlaOp x) { return Sin(x) / Cos(x); }); -} - // Hyperbolic trigonometric functions. // acosh(x) = log(x + sqrt(x^2 - 1)) if x >= -1 diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h index e6b5ac992cc..b922f85b4a5 100644 --- a/tensorflow/compiler/xla/client/lib/math.h +++ b/tensorflow/compiler/xla/client/lib/math.h @@ -93,9 +93,6 @@ XlaOp Asin(XlaOp x); // Computes the arc tangent of 'x'. XlaOp Atan(XlaOp x); -// Computes the tangent of 'x'. -XlaOp Tan(XlaOp x); - // Hyperbolic trigonometric functions // Computes the inverse hyperbolic cosine of 'x'. 
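The two hunks above remove the client-library Tan helper, which was implemented as Sin(x) / Cos(x) wrapped in an F16 upcast. Callers that still want that exact formulation can rebuild it locally from existing builder ops; a sketch follows (the helper name is hypothetical and the upcast is omitted):

  // tan(x) via the same identity the removed helper used.
  xla::XlaOp TanViaSinCos(xla::XlaOp x) { return xla::Sin(x) / xla::Cos(x); }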
diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc index b2399e62081..d31036e6035 100644 --- a/tensorflow/compiler/xla/client/lib/math_test.cc +++ b/tensorflow/compiler/xla/client/lib/math_test.cc @@ -199,6 +199,10 @@ XLA_TEST_F(MathTest, RealFpOnlyOps) { } else { continue; } + if (ty == F8E5M2 || ty == F8E4M3FN) { + // TODO(b/259609697): Add FP8 support to math ops + continue; + } for (const auto& test : std::vector, std::string>>({ diff --git a/tensorflow/compiler/xla/client/lib/prng.cc b/tensorflow/compiler/xla/client/lib/prng.cc index 1f224c7af13..8466b3d51a8 100644 --- a/tensorflow/compiler/xla/client/lib/prng.cc +++ b/tensorflow/compiler/xla/client/lib/prng.cc @@ -190,7 +190,9 @@ SplitShapePair SplitShapeIntoHalves(const Shape& shape) { } } } - CHECK_GE(pair.split_dim, 0); + if (pair.split_dim < 0) { + LOG(ERROR) << "This point shouldn't have been reached."; + } std::vector half_shape_dims; std::vector concat_shape_dims; const auto rank = shape.rank(); @@ -251,6 +253,18 @@ RngOutput ThreeFryRngBit32(XlaOp key, XlaOp initial_state, const Shape& shape) { return {result, inputs_state.second}; } +// Generates random 16bits with the given shape using the Three Fry +// implementation. Returns the random bits and the new state. +RngOutput ThreeFryRngBit16(XlaOp op_key, XlaOp initial_state, + const Shape& shape) { + // TODO(b/256713018): Use a better approach to not waste the upper 16 bits. + auto new_shape = shape; + new_shape.set_element_type(U32); + auto output = ThreeFryRngBit32(op_key, initial_state, new_shape); + output.value = ConvertElementType(output.value, U16); + return output; +} + // Generates random 64bits with the given shape using the Three Fry // implementation. Returns the random bits and the new state. RngOutput ThreeFryRngBit64(XlaOp key, XlaOp initial_state, const Shape& shape) { @@ -429,6 +443,21 @@ RngOutput PhiloxRngBit32(XlaOp op_key, XlaOp initial_state, return {Reshape(numbers, shape.dimensions()), new_state}; } +// Generates an array of primitive type U16 with the given shape containing +// random bits generated by the Philox algorithm. Returns the array and the new +// state of the random number generator. +RngOutput PhiloxRngBit16(XlaOp op_key, XlaOp initial_state, + const Shape& shape) { + // We use PhiloxRngBit32 and throw away the upper 16 bits here, to align with + // the non-XLA kernels. + // TODO(b/256713018): Use a better approach to not waste the upper 16 bits. + auto new_shape = shape; + new_shape.set_element_type(U32); + auto output = PhiloxRngBit32(op_key, initial_state, new_shape); + output.value = ConvertElementType(output.value, U16); + return output; +} + // Generates an array of primitive type U64 with the given shape containing // random bits generated by the Philox algorithm. Returns the array and the new // state of the random number generator. @@ -471,29 +500,59 @@ XlaOp ConvertRandomBitsToUniformFloatingPoint(XlaOp bits, XlaOp minval, TF_ASSIGN_OR_RETURN(const Shape* bits_shape, builder->GetShapePtr(bits)); PrimitiveType value_type = minval_shape->element_type(); PrimitiveType bit_type = bits_shape->element_type(); - CHECK((value_type == F32 && bit_type == U32) || - (value_type == F64 && bit_type == U64)); - - // Form random mantissa bits for float/double, with a leading 1 bit. - int num_float_bits = primitive_util::BitWidth(value_type); - // Subtract one as SignificandWidth includes the leading 1 bit. 
- int num_mantissa_bits = primitive_util::SignificandWidth(value_type) - 1; - - // Ignore the exponent bits and convert the mantissa bits to the floating - // point type. - bits = ShiftRightLogical( - bits, ScalarLike(bits, num_float_bits - num_mantissa_bits)); - - // We have an integer-valued floating point number in the range - // [0, 2**{num_mantissa_bits}). - XlaOp values = ConvertElementType(bits, value_type); - - // Divide by 2**{-num_mantissa_bits} to get a number in the range - // [0.0, 1.0). - values = values * ScalarLike(values, std::ldexp(1., -num_mantissa_bits)); + auto is_f32_or_f64 = (value_type == F32 && bit_type == U32) || + (value_type == F64 && bit_type == U64); + auto is_f16 = value_type == F16 && bit_type == U16; + if (!(is_f32_or_f64 || is_f16)) { + return InvalidArgument( + "In ConvertRandomBitsToUniformFloatingPoint, value_type and bit_type " + "can only be one of those combinations: (float16, uint16), (float32, " + "uint32) and (float64, uint64). Got combination: (%s, %s).", + primitive_util::LowercasePrimitiveTypeName(value_type), + primitive_util::LowercasePrimitiveTypeName(bit_type)); + } - // Multiply and add to shift to the range [minval, maxval). - return values * (maxval - minval) + minval; + if (is_f32_or_f64) { + // TODO(b/256715195): Consider using the approach in the F16 case. + + // Form random mantissa bits for float/double, with a leading 1 bit. + int num_float_bits = primitive_util::BitWidth(value_type); + // Subtract one as SignificandWidth includes the leading 1 bit. + int num_mantissa_bits = primitive_util::SignificandWidth(value_type) - 1; + + // Ignore the exponent bits and convert the mantissa bits to the floating + // point type. + bits = ShiftRightLogical( + bits, ScalarLike(bits, num_float_bits - num_mantissa_bits)); + + // We have an integer-valued floating point number in the range + // [0, 2**{num_mantissa_bits}). + XlaOp values = ConvertElementType(bits, value_type); + + // Divide by 2**{-num_mantissa_bits} to get a number in the range + // [0.0, 1.0). + values = values * ScalarLike(values, std::ldexp(1., -num_mantissa_bits)); + + // Multiply and add to shift to the range [minval, maxval). + return values * (maxval - minval) + minval; + } else if (is_f16) { + // This path follows the approach of the non-XLA kernels (see + // `tsl::random::Uint16ToHalf`). 
IEEE754 halfs are formatted as follows + // (MSB first): + // sign(1) exponent(5) mantissa(10) + // Conceptually construct the following: + // sign == 0 + // exponent == 15 -- an excess 15 representation of a zero exponent + // mantissa == 10 random bits + + auto mantissa = bits & ScalarLike(bits, 0x3ffu); // 10 bit mantissa + auto exponent = ScalarLike(bits, static_cast(15) << 10); + auto u16_result = exponent | mantissa; + auto result = BitcastConvertType(u16_result, F16); + return result - ScalarLike(result, 1.0); + } else { + return InternalError("This point shouldn't have been reached."); + } }); } @@ -534,6 +593,10 @@ RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state, const Shape& shape) { PrimitiveType type = shape.element_type(); switch (type) { + case F16: + case U16: + case S16: + return ThreeFryRngBit16(key, initial_state, shape); case F32: case U32: case S32: @@ -543,11 +606,12 @@ RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state, case S64: return ThreeFryRngBit64(key, initial_state, shape); default: - return {key.builder()->ReportError(Unimplemented( - "Types other than F32, F64, U32, S32, U64 and S64 " - "are not implemented by ThreeFryBitGenerator; got %s", - primitive_util::LowercasePrimitiveTypeName(type))), - initial_state}; + return { + key.builder()->ReportError(Unimplemented( + "Types other than F16, F32, F64, U16, S16, U32, S32, U64 and S64 " + "are not implemented by ThreeFryBitGenerator; got %s", + primitive_util::LowercasePrimitiveTypeName(type))), + initial_state}; } } @@ -555,6 +619,10 @@ RngOutput PhiloxBitGenerator(XlaOp key, XlaOp initial_state, const Shape& shape) { PrimitiveType type = shape.element_type(); switch (type) { + case F16: + case U16: + case S16: + return PhiloxRngBit16(key, initial_state, shape); case F32: case U32: case S32: @@ -564,11 +632,12 @@ RngOutput PhiloxBitGenerator(XlaOp key, XlaOp initial_state, case S64: return PhiloxRngBit64(key, initial_state, shape); default: - return {key.builder()->ReportError(Unimplemented( - "Types other than F32, F64, U32, S32, U64 and S64 " - "are not implemented by PhiloxFryBitGenerator; got %s", - primitive_util::LowercasePrimitiveTypeName(type))), - initial_state}; + return { + key.builder()->ReportError(Unimplemented( + "Types other than F16, F32, F64, U16, S16, U32, S32, U64 and S64 " + "are not implemented by PhiloxBitGenerator; got %s", + primitive_util::LowercasePrimitiveTypeName(type))), + initial_state}; } } @@ -600,9 +669,14 @@ RngOutput UniformIntDistribution(XlaOp key, XlaOp initial_state, PrimitiveType unsigned_type; if (type == U32 || type == S32) { unsigned_type = U32; - } else { - DCHECK(type == U64 || type == S64); + } else if (type == U64 || type == S64) { unsigned_type = U64; + } else { + return {key.builder()->ReportError(Unimplemented( + "Types other than U32, S32, U64 and S64 " + "are not implemented by UniformIntDistribution; got %s", + primitive_util::LowercasePrimitiveTypeName(type))), + initial_state}; } return { ConvertRandomBitsToUniformInt(bits, minval, maxval, type, unsigned_type), @@ -612,10 +686,18 @@ RngOutput UniformIntDistribution(XlaOp key, XlaOp initial_state, RngOutput NormalFloatingPointDistribution(XlaOp key, XlaOp initial_state, BitGeneratorTy bit_generator, const Shape& shape) { + XlaBuilder* builder = key.builder(); PrimitiveType primitive_type = shape.element_type(); - DCHECK(primitive_type == F32 || primitive_type == F64); + if (!(primitive_type == F16 || primitive_type == F32 || + primitive_type == F64)) { + return { + 
builder->ReportError(Unimplemented( + "Types other than F16, F32 and F64 " + "are not implemented by NormalFloatingPointDistribution; got %s", + primitive_util::LowercasePrimitiveTypeName(primitive_type))), + initial_state}; + } - XlaBuilder* builder = key.builder(); auto shape_pair = SplitShapeIntoHalves(shape); RngOutput bits_state = UniformFloatingPointDistribution( key, initial_state, bit_generator, diff --git a/tensorflow/compiler/xla/client/padding.cc b/tensorflow/compiler/xla/client/padding.cc index 1fe8d556d1d..7fec04e2ac5 100644 --- a/tensorflow/compiler/xla/client/padding.cc +++ b/tensorflow/compiler/xla/client/padding.cc @@ -35,6 +35,16 @@ Status ValidatePaddingValues(absl::Span input_dimensions, input_dimensions.size(), window_dimensions.size(), window_strides.size()); } + for (size_t i = 0; i < input_dimensions.size(); ++i) { + if (window_dimensions[i] <= 0) { + return InvalidArgument("Window dimension %u has non-positive size %d", i, + window_dimensions[i]); + } + if (window_strides[i] <= 0) { + return InvalidArgument("Window dimension %u has non-positive stride %d", + i, window_strides[i]); + } + } return OkStatus(); } diff --git a/tensorflow/compiler/xla/client/value_inference.cc b/tensorflow/compiler/xla/client/value_inference.cc index db873636140..c14bb8dd043 100644 --- a/tensorflow/compiler/xla/client/value_inference.cc +++ b/tensorflow/compiler/xla/client/value_inference.cc @@ -24,20 +24,20 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/comparison_util.h" +#include "tensorflow/compiler/xla/hlo/ir/dfs_hlo_visitor.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" -#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/tsl/platform/errors.h" +#include "tensorflow/tsl/platform/statusor.h" namespace xla { namespace { @@ -510,6 +510,7 @@ StatusOr PostorderDFSVisitor::AnalyzeConstantValueFallback( case HloOpcode::kSubtract: case HloOpcode::kCos: case HloOpcode::kSin: + case HloOpcode::kTan: case HloOpcode::kNegate: case HloOpcode::kAbs: case HloOpcode::kDivide: @@ -1125,6 +1126,7 @@ StatusOr PostorderDFSVisitor::AnalyzeIsDynamic( case HloOpcode::kConvert: case HloOpcode::kSqrt: case HloOpcode::kCbrt: + case HloOpcode::kTan: case HloOpcode::kTanh: { // Forward operand as they don't change if a value is dynamic or static. return result.AddVisit([](Literal operand) { return operand; }); diff --git a/tensorflow/compiler/xla/client/value_inference.h b/tensorflow/compiler/xla/client/value_inference.h index aedc214c8c2..2579f65059f 100644 --- a/tensorflow/compiler/xla/client/value_inference.h +++ b/tensorflow/compiler/xla/client/value_inference.h @@ -20,11 +20,11 @@ limitations under the License. 
#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator.h" +#include "tensorflow/compiler/xla/hlo/ir/dfs_hlo_visitor.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index bde30ba28c2..576c3ce68f4 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -29,7 +30,6 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/match.h" -#include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" @@ -37,14 +37,13 @@ limitations under the License. #include "tensorflow/compiler/xla/client/sharding_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/comparison_util.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_input_output_alias_config.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_sharding.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/permutation_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" -#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_instructions.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/sharding_op_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -109,13 +108,26 @@ XlaOp XlaBuilderFriend::BuildAddDependency(XlaBuilder* builder, XlaOp operand, }); } -XlaOp XlaBuilderFriend::BuildFusion(XlaBuilder* builder, - absl::Span operands, - absl::string_view fusion_kind, - const XlaComputation& fused_computation) { +XlaOp XlaBuilderFriend::BuildFusion( + XlaBuilder* builder, absl::Span operands, + absl::string_view fusion_kind, const XlaComputation& fused_computation, + absl::Span>> + output_operand_aliasing) { return builder->ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; instr.set_fusion_kind(std::string(fusion_kind)); + if (!output_operand_aliasing.empty()) { + for (const auto& pair : output_operand_aliasing) { + auto aliasing = instr.add_output_operand_aliasing(); + aliasing->set_operand_index(pair.second.first); + for (int64_t index : pair.second.second) { + aliasing->add_operand_shape_index(index); + } + for (int64_t index : pair.first) { + aliasing->add_output_shape_index(index); + } + } + } std::vector operand_shape_ptrs; TF_ASSIGN_OR_RETURN(auto program_shape, fused_computation.GetProgramShape()); @@ -126,7 +138,7 @@ XlaOp 
XlaBuilderFriend::BuildFusion(XlaBuilder* builder, }); } -XlaOp XlaBuilderFriend::BuildAsyncStart( +std::pair XlaBuilderFriend::BuildAsyncStart( XlaBuilder* builder, absl::Span operands, std::string execution_thread, const XlaComputation& called_computation, const Shape& shape) { @@ -134,38 +146,42 @@ XlaOp XlaBuilderFriend::BuildAsyncStart( called_computation, shape); } -XlaOp XlaBuilderFriend::BuildAsyncStart( +std::pair XlaBuilderFriend::BuildAsyncStart( XlaBuilder* builder, absl::Span operands, std::string execution_thread, int64_t group_id, const XlaComputation& called_computation, const Shape& shape) { - return builder->ReportErrorOrReturn([&]() -> StatusOr { + int64_t called_computation_id; + auto start_op = builder->ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; *instr.mutable_shape() = shape.ToProto(); instr.set_async_execution_thread(execution_thread); instr.set_async_group_id(group_id); builder->AddCalledComputation(called_computation, &instr); + called_computation_id = instr.called_computation_ids()[0]; return builder->AddInstruction(std::move(instr), HloOpcode::kAsyncStart, operands); }); + return {start_op, called_computation_id}; } -XlaOp XlaBuilderFriend::BuildAsyncUpdate( - XlaBuilder* builder, const XlaOp operand, std::string execution_thread, - const XlaComputation& called_computation, const Shape& shape) { +XlaOp XlaBuilderFriend::BuildAsyncUpdate(XlaBuilder* builder, + const XlaOp operand, + std::string execution_thread, + int64_t called_computation, + const Shape& shape) { return BuildAsyncUpdate(builder, operand, execution_thread, /*group_id=*/-1, called_computation, shape); } XlaOp XlaBuilderFriend::BuildAsyncUpdate( XlaBuilder* builder, const XlaOp operand, std::string execution_thread, - int64_t group_id, const XlaComputation& called_computation, - const Shape& shape) { + int64_t group_id, int64_t called_computation, const Shape& shape) { return builder->ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; *instr.mutable_shape() = shape.ToProto(); instr.set_async_execution_thread(execution_thread); instr.set_async_group_id(group_id); - builder->AddCalledComputation(called_computation, &instr); + instr.add_called_computation_ids(called_computation); return builder->AddInstruction(std::move(instr), HloOpcode::kAsyncUpdate, {operand}); }); @@ -173,7 +189,7 @@ XlaOp XlaBuilderFriend::BuildAsyncUpdate( XlaOp XlaBuilderFriend::BuildAsyncDone(XlaBuilder* builder, const XlaOp operand, std::string execution_thread, - const XlaComputation& called_computation, + int64_t called_computation, const Shape& shape) { return BuildAsyncDone(builder, operand, execution_thread, /*group_id=*/-1, called_computation, shape); @@ -182,14 +198,14 @@ XlaOp XlaBuilderFriend::BuildAsyncDone(XlaBuilder* builder, const XlaOp operand, XlaOp XlaBuilderFriend::BuildAsyncDone(XlaBuilder* builder, const XlaOp operand, std::string execution_thread, int64_t group_id, - const XlaComputation& called_computation, + int64_t called_computation, const Shape& shape) { return builder->ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; *instr.mutable_shape() = shape.ToProto(); instr.set_async_execution_thread(execution_thread); instr.set_async_group_id(group_id); - builder->AddCalledComputation(called_computation, &instr); + instr.add_called_computation_ids(called_computation); return builder->AddInstruction(std::move(instr), HloOpcode::kAsyncDone, {operand}); }); @@ -239,6 +255,56 @@ XlaOp XlaBuilderFriend::BuildAllReduceDone(XlaBuilder* builder, }); } +XlaOp 
XlaBuilderFriend::BuildCopyStart( + XlaBuilder* builder, const XlaOp operand, + std::optional cross_program_prefetch_index) { + return builder->ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + if (cross_program_prefetch_index) { + instr.set_cross_program_prefetch_index(*cross_program_prefetch_index); + } + + TF_ASSIGN_OR_RETURN(const Shape* operand_shape, + builder->GetShapePtr(operand)); + Shape u32 = ShapeUtil::MakeScalarShape(PrimitiveType::U32); + Shape shape = + ShapeUtil::MakeTupleShapeWithPtrs({operand_shape, operand_shape, &u32}); + *instr.mutable_shape() = shape.ToProto(); + + return builder->AddInstruction(std::move(instr), HloOpcode::kCopyStart, + {operand}); + }); +} + +XlaOp XlaBuilderFriend::BuildCopyDone(XlaBuilder* builder, const XlaOp operand, + const Shape& shape) { + return builder->ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + return builder->AddInstruction(std::move(instr), HloOpcode::kCopyDone, + {operand}); + }); +} + +XlaOp XlaBuilderFriend::BuildCollectivePermuteStart( + XlaBuilder* builder, XlaOp operand, + const std::vector>& source_target_pairs, + const std::optional& channel_id) { + return builder->CollectivePermuteImpl(operand, source_target_pairs, + channel_id, /*async=*/true); +} + +XlaOp XlaBuilderFriend::BuildCollectivePermuteDone(XlaBuilder* builder, + const XlaOp operand, + const Shape& shape) { + return builder->ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + return builder->AddInstruction( + std::move(instr), HloOpcode::kCollectivePermuteDone, {operand}); + }); +} + XlaOp XlaBuilderFriend::BuildBitcast(XlaBuilder* builder, XlaOp operand, const Shape& shape) { return builder->ReportErrorOrReturn([&]() -> StatusOr { @@ -271,6 +337,73 @@ XlaOp XlaBuilderFriend::BuildPartitionId(XlaBuilder* builder, }); } +XlaOp XlaBuilderFriend::BuildSend(XlaBuilder* builder, XlaOp operand, + XlaOp token, const ChannelHandle& handle, + bool is_host_transfer) { + return builder->ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto send_instr; + TF_ASSIGN_OR_RETURN(const Shape* shape, builder->GetShapePtr(operand)); + // Send instruction produces a tuple of {aliased operand, U32 context, + // token}. + *send_instr.mutable_shape() = + ShapeUtil::MakeTupleShape({*shape, ShapeUtil::MakeShape(U32, {}), + ShapeUtil::MakeTokenShape()}) + .ToProto(); + send_instr.set_channel_id(handle.handle()); + send_instr.set_is_host_transfer(is_host_transfer); + return builder->AddInstruction(std::move(send_instr), HloOpcode::kSend, + {operand, token}); + }); +} + +XlaOp XlaBuilderFriend::BuildSendDone(XlaBuilder* builder, XlaOp operand, + const ChannelHandle& handle, + bool is_host_transfer) { + return builder->ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto send_done_instr; + *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); + send_done_instr.set_channel_id(handle.handle()); + send_done_instr.set_is_host_transfer(is_host_transfer); + return builder->AddInstruction(std::move(send_done_instr), + HloOpcode::kSendDone, {operand}); + }); +} + +XlaOp XlaBuilderFriend::BuildRecv(XlaBuilder* builder, XlaOp token, + const Shape& shape, + const ChannelHandle& handle, + bool is_host_transfer) { + return builder->ReportErrorOrReturn([&]() -> StatusOr { + // Recv instruction produces a tuple of {receive buffer, U32 context, + // token}. 
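The new async plumbing above returns the called-computation id from BuildAsyncStart and expects that same id on the matching BuildAsyncUpdate/BuildAsyncDone calls, instead of re-passing the XlaComputation. A rough usage sketch; the builder, the async computation and the shapes are placeholders, so treat it as illustrative rather than as an exact call site from this patch:

// Thread the id produced by the start op through its update/done ops.
std::pair<xla::XlaOp, int64_t> start =
    xla::internal::XlaBuilderFriend::BuildAsyncStart(
        builder, /*operands=*/{operand}, /*execution_thread=*/"main",
        async_computation, start_shape);
xla::XlaOp update = xla::internal::XlaBuilderFriend::BuildAsyncUpdate(
    builder, start.first, /*execution_thread=*/"main",
    /*called_computation=*/start.second, start_shape);
xla::XlaOp done = xla::internal::XlaBuilderFriend::BuildAsyncDone(
    builder, update, /*execution_thread=*/"main",
    /*called_computation=*/start.second, done_shape);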
+ HloInstructionProto recv_instr; + *recv_instr.mutable_shape() = + ShapeUtil::MakeTupleShape( + {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}) + .ToProto(); + recv_instr.set_channel_id(handle.handle()); + recv_instr.set_is_host_transfer(is_host_transfer); + return builder->AddInstruction(std::move(recv_instr), HloOpcode::kRecv, + {token}); + }); +} + +XlaOp XlaBuilderFriend::BuildRecvDone(XlaBuilder* builder, XlaOp token, + const Shape& shape, + const ChannelHandle& handle, + bool is_host_transfer) { + return builder->ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto recv_done_instr; + *recv_done_instr.mutable_shape() = + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}) + .ToProto(); + recv_done_instr.set_channel_id(handle.handle()); + recv_done_instr.set_is_host_transfer(is_host_transfer); + return builder->AddInstruction(std::move(recv_done_instr), + HloOpcode::kRecvDone, {token}); + }); +} + XlaOp XlaBuilderFriend::BuildRngGetAndUpdateState(XlaBuilder* builder, int64_t delta, @@ -393,7 +526,7 @@ void XlaBuilder::ToStringHelper(std::string* out, int ident, XlaBuilder::XlaBuilder(const std::string& computation_name) : name_(computation_name) {} -XlaBuilder::~XlaBuilder() {} +XlaBuilder::~XlaBuilder() = default; XlaOp XlaBuilder::ReportError(const Status& error) { CHECK(!error.ok()); @@ -1988,15 +2121,35 @@ void XlaBuilder::Outfeed(XlaOp operand, const Shape& shape_with_layout, // Outfeed takes a token as its second operand. Generate the token to pass // to the outfeed. - HloInstructionProto token_instr; - *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); - TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr), - HloOpcode::kAfterAll, {})); - - TF_RETURN_IF_ERROR( - AddInstruction(std::move(instr), HloOpcode::kOutfeed, {operand, token}) - .status()); - + XlaOp token; + auto make_token = [&]() { + HloInstructionProto token_instr; + *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); + return AddInstruction(std::move(token_instr), HloOpcode::kAfterAll, {}); + }; + auto make_outfeed = [&](XlaOp token) { + return AddInstruction(std::move(instr), HloOpcode::kOutfeed, + {operand, token}); + }; + if (sharding()) { + XlaScopedShardingAssignment scoped_sharding( + this, sharding_builder::AssignDevice(0)); + TF_ASSIGN_OR_RETURN(token, make_token()); + } else { + TF_ASSIGN_OR_RETURN(token, make_token()); + } + if (sharding()) { + OpSharding tuple_sharding = *sharding(); + if (tuple_sharding.type() != OpSharding::TUPLE) { + tuple_sharding = sharding_builder::Tuple({}); + *tuple_sharding.add_tuple_shardings() = *sharding(); + } + *tuple_sharding.add_tuple_shardings() = sharding_builder::AssignDevice(0); + XlaScopedShardingAssignment scoped_sharding(this, tuple_sharding); + TF_RETURN_IF_ERROR(make_outfeed(token).status()); + } else { + TF_RETURN_IF_ERROR(make_outfeed(token).status()); + } // The outfeed instruction produces a token. However, existing users expect // a nil shape (empty tuple). This should only be relevant if the outfeed is // the root of a computation. @@ -2007,7 +2160,7 @@ void XlaBuilder::Outfeed(XlaOp operand, const Shape& shape_with_layout, // The dummy tuple should have no sharding. 
{ - XlaScopedShardingAssignment scoped_sharding(this, OpSharding()); + XlaScopedShardingAssignment scoped_sharding(this, std::nullopt); TF_ASSIGN_OR_RETURN( XlaOp empty_tuple, AddInstruction(std::move(tuple_instr), HloOpcode::kTuple, {})); @@ -2163,7 +2316,7 @@ StatusOr XlaBuilder::CustomCallInternal( AddCalledComputation(*computation, &instr); } for (const auto& pair : output_operand_aliasing) { - auto aliasing = instr.add_custom_call_output_operand_aliasing(); + auto aliasing = instr.add_output_operand_aliasing(); aliasing->set_operand_index(pair.second.first); for (int64_t index : pair.second.second) { aliasing->add_operand_shape_index(index); @@ -3232,20 +3385,22 @@ XlaOp XlaBuilder::ReduceScatter( XlaOp XlaBuilder::AllToAll(XlaOp operand, int64_t split_dimension, int64_t concat_dimension, int64_t split_count, absl::Span replica_groups, - const std::optional& layout) { + const std::optional& layout, + const std::optional& channel_id) { // Array all_to_all may need to violate layout constraint to be legal so use // the tuple version. if (layout.has_value()) { return AllToAllTuple(operand, split_dimension, concat_dimension, - split_count, replica_groups, layout); + split_count, replica_groups, layout, channel_id); } return AllToAllArray(operand, split_dimension, concat_dimension, split_count, - replica_groups); + replica_groups, channel_id); } -XlaOp XlaBuilder::AllToAllArray(XlaOp operand, int64_t split_dimension, - int64_t concat_dimension, int64_t split_count, - absl::Span replica_groups) { +XlaOp XlaBuilder::AllToAllArray( + XlaOp operand, int64_t split_dimension, int64_t concat_dimension, + int64_t split_count, absl::Span replica_groups, + const std::optional& channel_id) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN( @@ -3265,6 +3420,9 @@ XlaOp XlaBuilder::AllToAllArray(XlaOp operand, int64_t split_dimension, } } instr.add_dimensions(split_dimension); + if (channel_id.has_value()) { + instr.set_channel_id(channel_id->handle()); + } TF_ASSIGN_OR_RETURN( XlaOp all_to_all, AddInstruction(std::move(instr), HloOpcode::kAllToAll, {operand})); @@ -3297,9 +3455,11 @@ XlaOp XlaBuilder::AllToAllArray(XlaOp operand, int64_t split_dimension, }); } -XlaOp XlaBuilder::AllToAllTuple(absl::Span operands, - absl::Span replica_groups, - const std::optional& layout) { +XlaOp XlaBuilder::AllToAllTuple( + absl::Span operands, + absl::Span replica_groups, + const std::optional& layout, + const std::optional& channel_id) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(auto operand_shapes, this->GetOperandShapes(operands)); @@ -3330,15 +3490,19 @@ XlaOp XlaBuilder::AllToAllTuple(absl::Span operands, for (const ReplicaGroup& group : replica_groups) { *instr.add_replica_groups() = group; } + if (channel_id.has_value()) { + instr.set_channel_id(channel_id->handle()); + } return AddInstruction(std::move(instr), HloOpcode::kAllToAll, operands); }); } -XlaOp XlaBuilder::AllToAllTuple(XlaOp operand, int64_t split_dimension, - int64_t concat_dimension, int64_t split_count, - absl::Span replica_groups, - const std::optional& layout) { +XlaOp XlaBuilder::AllToAllTuple( + XlaOp operand, int64_t split_dimension, int64_t concat_dimension, + int64_t split_count, absl::Span replica_groups, + const std::optional& layout, + const std::optional& channel_id) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); @@ 
-3366,7 +3530,8 @@ XlaOp XlaBuilder::AllToAllTuple(XlaOp operand, int64_t split_dimension, } // Handle data communication. - XlaOp alltoall = this->AllToAllTuple(slices, replica_groups, layout); + XlaOp alltoall = + this->AllToAllTuple(slices, replica_groups, layout, channel_id); // Concat the N received parts. std::vector received; @@ -3382,6 +3547,14 @@ XlaOp XlaBuilder::CollectivePermute( XlaOp operand, const std::vector>& source_target_pairs, const std::optional& channel_id) { + return CollectivePermuteImpl(operand, source_target_pairs, channel_id, + /*async=*/false); +} + +XlaOp XlaBuilder::CollectivePermuteImpl( + XlaOp operand, + const std::vector>& source_target_pairs, + const std::optional& channel_id, bool async) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); HloInstructionProto instr; @@ -3399,7 +3572,9 @@ XlaOp XlaBuilder::CollectivePermute( instr.set_channel_id(channel_id->handle()); } - return AddInstruction(std::move(instr), HloOpcode::kCollectivePermute, + return AddInstruction(std::move(instr), + async ? HloOpcode::kCollectivePermuteStart + : HloOpcode::kCollectivePermute, {operand}); }); } @@ -3550,24 +3725,10 @@ XlaOp XlaBuilder::SendWithToken(XlaOp operand, XlaOp token, return InvalidArgument("Send must use a device-to-device channel"); } - // Send instruction produces a tuple of {aliased operand, U32 context, - // token}. - HloInstructionProto send_instr; - TF_ASSIGN_OR_RETURN(const Shape* shape, GetShapePtr(operand)); - *send_instr.mutable_shape() = - ShapeUtil::MakeTupleShape({*shape, ShapeUtil::MakeShape(U32, {}), - ShapeUtil::MakeTokenShape()}) - .ToProto(); - send_instr.set_channel_id(handle.handle()); - TF_ASSIGN_OR_RETURN(XlaOp send, - AddInstruction(std::move(send_instr), HloOpcode::kSend, - {operand, token})); - - HloInstructionProto send_done_instr; - *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); - send_done_instr.set_channel_id(handle.handle()); - return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone, - {send}); + XlaOp send_op = internal::XlaBuilderFriend::BuildSend(this, operand, token, + handle, false); + return internal::XlaBuilderFriend::BuildSendDone(this, send_op, handle, + false); }); } @@ -3603,24 +3764,10 @@ XlaOp XlaBuilder::RecvWithToken(XlaOp token, const Shape& shape, return InvalidArgument("Recv must use a device-to-device channel"); } - // Recv instruction produces a tuple of {receive buffer, U32 context, - // token}. 
- HloInstructionProto recv_instr; - *recv_instr.mutable_shape() = - ShapeUtil::MakeTupleShape( - {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}) - .ToProto(); - recv_instr.set_channel_id(handle.handle()); - TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr), - HloOpcode::kRecv, {token})); - - HloInstructionProto recv_done_instr; - *recv_done_instr.mutable_shape() = - ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}) - .ToProto(); - recv_done_instr.set_channel_id(handle.handle()); - return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone, - {recv}); + XlaOp recv_op = internal::XlaBuilderFriend::BuildRecv(this, token, shape, + handle, false); + return internal::XlaBuilderFriend::BuildRecvDone(this, recv_op, shape, + handle, false); }); } @@ -4095,7 +4242,13 @@ StatusOr XlaBuilder::AddInstruction(HloInstructionProto&& instr, *instr.mutable_metadata() = metadata_; } if (sharding_) { - *instr.mutable_sharding() = *sharding_; + // Normalize tuple sharding and fail the call if the sharding is not valid. + Shape shape(instr.shape()); + TF_ASSIGN_OR_RETURN(HloSharding sharding, + HloSharding::FromProto(*sharding_)); + sharding = sharding.NormalizeTupleSharding(shape); + TF_RETURN_IF_ERROR(sharding.Validate(shape)); + *instr.mutable_sharding() = sharding.ToProto(); } *instr.mutable_frontend_attributes() = frontend_attributes_; @@ -4648,10 +4801,10 @@ XlaOp OptimizationBarrier(XlaOp operand) { return operand.builder()->OptimizationBarrier(operand); } -XlaOp Complex(const XlaOp lhs, const XlaOp rhs, +XlaOp Complex(const XlaOp real, const XlaOp imag, absl::Span broadcast_dimensions) { - return lhs.builder()->BinaryOp(HloOpcode::kComplex, lhs, rhs, - broadcast_dimensions); + return real.builder()->BinaryOp(HloOpcode::kComplex, real, imag, + broadcast_dimensions); } XlaOp Conj(const XlaOp operand) { @@ -4854,25 +5007,30 @@ XlaOp ReduceScatter(const XlaOp operand, const XlaComputation& computation, XlaOp AllToAll(const XlaOp operand, int64_t split_dimension, int64_t concat_dimension, int64_t split_count, absl::Span replica_groups, - const std::optional& layout) { + const std::optional& layout, + const std::optional& channel_id) { return operand.builder()->AllToAll(operand, split_dimension, concat_dimension, - split_count, replica_groups, layout); + split_count, replica_groups, layout, + channel_id); } XlaOp AllToAllTuple(absl::Span operands, absl::Span replica_groups, - const std::optional& layout) { + const std::optional& layout, + const std::optional& channel_id) { CHECK(!operands.empty()); - return operands[0].builder()->AllToAllTuple(operands, replica_groups, layout); + return operands[0].builder()->AllToAllTuple(operands, replica_groups, layout, + channel_id); } XlaOp AllToAllTuple(const XlaOp operand, int64_t split_dimension, int64_t concat_dimension, int64_t split_count, absl::Span replica_groups, - const std::optional& layout) { + const std::optional& layout, + const std::optional& channel_id) { return operand.builder()->AllToAllTuple(operand, split_dimension, concat_dimension, split_count, - replica_groups, layout); + replica_groups, layout, channel_id); } XlaOp CollectivePermute( @@ -4910,10 +5068,9 @@ XlaOp Abs(const XlaOp operand) { return operand.builder()->UnaryOp(HloOpcode::kAbs, operand); } -XlaOp Atan2(const XlaOp lhs, const XlaOp rhs, +XlaOp Atan2(const XlaOp y, const XlaOp x, absl::Span broadcast_dimensions) { - return lhs.builder()->BinaryOp(HloOpcode::kAtan2, lhs, rhs, - broadcast_dimensions); + return 
y.builder()->BinaryOp(HloOpcode::kAtan2, y, x, broadcast_dimensions); } XlaOp Exp(const XlaOp operand) { @@ -4955,6 +5112,9 @@ XlaOp Cos(const XlaOp operand) { XlaOp Sin(const XlaOp operand) { return operand.builder()->UnaryOp(HloOpcode::kSin, operand); } +XlaOp Tan(const XlaOp operand) { + return operand.builder()->UnaryOp(HloOpcode::kTan, operand); +} XlaOp Tanh(const XlaOp operand) { return operand.builder()->UnaryOp(HloOpcode::kTanh, operand); } diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index d5683432877..6a8654ec61d 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -36,16 +36,14 @@ limitations under the License. #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/comparison_util.h" +#include "tensorflow/compiler/xla/hlo/ir/dynamic_parameter_binding.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_input_output_alias_config.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" -#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/tsl/platform/stacktrace.h" @@ -61,32 +59,26 @@ struct XlaBuilderFriend { static XlaOp BuildAddDependency(XlaBuilder* builder, XlaOp operand, XlaOp token, const Shape& shape); - static XlaOp BuildAsyncStart(XlaBuilder* builder, - absl::Span operands, - std::string execution_thread, int64_t group_id, - const XlaComputation& called_computation, - const Shape& shape); - static XlaOp BuildAsyncStart(XlaBuilder* builder, - absl::Span operands, - std::string execution_thread, - const XlaComputation& called_computation, - const Shape& shape); + static std::pair BuildAsyncStart( + XlaBuilder* builder, absl::Span operands, + std::string execution_thread, int64_t group_id, + const XlaComputation& called_computation, const Shape& shape); + static std::pair BuildAsyncStart( + XlaBuilder* builder, absl::Span operands, + std::string execution_thread, const XlaComputation& called_computation, + const Shape& shape); static XlaOp BuildAsyncUpdate(XlaBuilder* builder, const XlaOp operands, std::string execution_thread, int64_t group_id, - const XlaComputation& called_computation, - const Shape& shape); + int64_t called_computation, const Shape& shape); static XlaOp BuildAsyncUpdate(XlaBuilder* builder, const XlaOp operands, std::string execution_thread, - const XlaComputation& called_computation, - const Shape& shape); + int64_t called_computation, const Shape& shape); static XlaOp BuildAsyncDone(XlaBuilder* builder, const XlaOp operands, std::string execution_thread, int64_t group_id, - const XlaComputation& called_computation, - const Shape& shape); + int64_t called_computation, const Shape& shape); static XlaOp BuildAsyncDone(XlaBuilder* builder, const XlaOp operands, std::string execution_thread, - const XlaComputation& called_computation, - const Shape& shape); + int64_t called_computation, const Shape& 
shape); static XlaOp BuildAllGatherStart( XlaBuilder* builder, XlaOp operand, int64_t all_gather_dimension, @@ -106,16 +98,43 @@ struct XlaBuilderFriend { static XlaOp BuildAllReduceDone(XlaBuilder* builder, const XlaOp operands, const Shape& shape); - static XlaOp BuildFusion(XlaBuilder* builder, - absl::Span operands, - absl::string_view fusion_kind, - const XlaComputation& fused_computation); + static XlaOp BuildCollectivePermuteStart( + XlaBuilder* builder, XlaOp operand, + const std::vector>& source_target_pairs, + const std::optional& channel_id = std::nullopt); + static XlaOp BuildCollectivePermuteDone(XlaBuilder* builder, + const XlaOp operands, + const Shape& shape); + + static XlaOp BuildCopyStart( + XlaBuilder* builder, XlaOp operand, + std::optional cross_program_prefetch_index = std::nullopt); + static XlaOp BuildCopyDone(XlaBuilder* builder, const XlaOp operand, + const Shape& shape); + + static XlaOp BuildFusion( + XlaBuilder* builder, absl::Span operands, + absl::string_view fusion_kind, const XlaComputation& fused_computation, + absl::Span>> + output_operand_aliasing = {}); static XlaOp BuildBitcast(XlaBuilder* builder, XlaOp operand, const Shape& shape); static XlaOp BuildPartitionId(XlaBuilder* builder, const Shape& shape); + static XlaOp BuildSend(XlaBuilder* builder, XlaOp operand, XlaOp token, + const ChannelHandle& handle, bool is_host_transfer); + static XlaOp BuildSendDone(XlaBuilder* builder, XlaOp operand, + const ChannelHandle& handle, + bool is_host_transfer); + + static XlaOp BuildRecv(XlaBuilder* builder, XlaOp token, const Shape& shape, + const ChannelHandle& handle, bool is_host_transfer); + static XlaOp BuildRecvDone(XlaBuilder* builder, XlaOp token, + const Shape& shape, const ChannelHandle& handle, + bool is_host_transfer); + static XlaOp BuildDomain(XlaBuilder* builder, XlaOp operand, const OpSharding entry, const OpSharding exit, const Shape& shape); @@ -266,7 +285,8 @@ class XlaBuilder { // As a result they are set on the computation builder and all the // instructions generated via the computation builder will have the same // frontend attributes attached to them. - void SetFrontendAttributes(const FrontendAttributes& frontend_attributes) { + virtual void SetFrontendAttributes( + const FrontendAttributes& frontend_attributes) { frontend_attributes_ = frontend_attributes; } @@ -361,7 +381,7 @@ class XlaBuilder { // // This will copy the needed ops/computations to the subgraph. StatusOr BuildConstantSubGraph( - XlaOp root_op, bool dynamic_dimension_is_uint_max = false); + XlaOp root_op, bool dynamic_dimension_is_minus_one = false); // Returns the first error that was encountered while building the // computation. 
When an error is encountered, by default we return a vacuous @@ -815,16 +835,20 @@ class XlaBuilder { XlaOp AllToAll(XlaOp operand, int64_t split_dimension, int64_t concat_dimension, int64_t split_count, absl::Span replica_groups, - const std::optional& layout = std::nullopt); + const std::optional& layout = std::nullopt, + const std::optional& channel_id = std::nullopt); - XlaOp AllToAllTuple(absl::Span operands, - absl::Span replica_groups, - const std::optional& layout); + XlaOp AllToAllTuple( + absl::Span operands, + absl::Span replica_groups, + const std::optional& layout, + const std::optional& channel_id = std::nullopt); - XlaOp AllToAllTuple(XlaOp operand, int64_t split_dimension, - int64_t concat_dimension, int64_t split_count, - absl::Span replica_groups, - const std::optional& layout); + XlaOp AllToAllTuple( + XlaOp operand, int64_t split_dimension, int64_t concat_dimension, + int64_t split_count, absl::Span replica_groups, + const std::optional& layout, + const std::optional& channel_id = std::nullopt); XlaOp CollectivePermute( XlaOp operand, @@ -1431,14 +1455,17 @@ class XlaBuilder { friend XlaOp AllToAll(XlaOp operand, int64_t split_dimension, int64_t concat_dimension, int64_t split_count, absl::Span replica_groups, - const std::optional& layout); + const std::optional& layout, + const std::optional& channel_id); friend XlaOp AllToAllTuple(absl::Span operands, absl::Span replica_groups, - const std::optional& layout); + const std::optional& layout, + const std::optional& channel_id); friend XlaOp AllToAllTuple(XlaOp operand, int64_t split_dimension, int64_t concat_dimension, int64_t split_count, absl::Span replica_groups, - const std::optional& layout); + const std::optional& layout, + const std::optional& channel_id); friend XlaOp CollectivePermute( XlaOp operand, const std::vector>& source_target_pairs, @@ -1471,6 +1498,7 @@ class XlaBuilder { friend XlaOp Clz(XlaOp operand); friend XlaOp Cos(XlaOp operand); friend XlaOp Sin(XlaOp operand); + friend XlaOp Tan(XlaOp operand); friend XlaOp Tanh(XlaOp operand); friend XlaOp Real(XlaOp operand); friend XlaOp Imag(XlaOp operand); @@ -1586,14 +1614,20 @@ class XlaBuilder { const std::optional use_global_device_ids, bool async); + XlaOp CollectivePermuteImpl( + XlaOp operand, + const std::vector>& source_target_pairs, + const std::optional& channel_id, bool async); + XlaOp ConditionalImpl( XlaOp branch_index, absl::Span branch_computations, absl::Span branch_operands); - XlaOp AllToAllArray(XlaOp operand, int64_t split_dimension, - int64_t concat_dimension, int64_t split_count, - absl::Span replica_groups); + XlaOp AllToAllArray( + XlaOp operand, int64_t split_dimension, int64_t concat_dimension, + int64_t split_count, absl::Span replica_groups, + const std::optional& channel_id = std::nullopt); // Creates an op with the given opcode and the output shape. virtual StatusOr AddOpWithShape(HloOpcode opcode, const Shape& shape, @@ -1624,7 +1658,7 @@ class XlaBuilder { // // TODO(hinsu): Return const pointer within StatusOr and use // absl::implicit_cast at callsites. This requires implicit_cast support in - // stream_executor::port::StatusOr similar to absl::StatusOr. + // xla::StatusOr similar to absl::StatusOr. 
template StatusOr LookUpInstructionInternal(XlaOp op) const { TF_RETURN_IF_ERROR(CheckOpBuilder(op)); @@ -2215,7 +2249,7 @@ XlaOp CustomCallWithComputation( absl::Span>> output_operand_aliasing = {}, const Literal* literal = nullptr, - CustomCallSchedule = CustomCallSchedule::SCHEDULE_NONE, + CustomCallSchedule schedule = CustomCallSchedule::SCHEDULE_NONE, CustomCallApiVersion api_version = API_VERSION_ORIGINAL); // Overload which constructs a custom call with fixed layouts. The operands will @@ -2435,16 +2469,20 @@ XlaOp ReduceScatter( XlaOp AllToAll(XlaOp operand, int64_t split_dimension, int64_t concat_dimension, int64_t split_count, absl::Span replica_groups = {}, - const std::optional& layout = std::nullopt); + const std::optional& layout = std::nullopt, + const std::optional& channel_id = std::nullopt); -XlaOp AllToAllTuple(absl::Span operand, - absl::Span replica_groups = {}, - const std::optional& layout = std::nullopt); +XlaOp AllToAllTuple( + absl::Span operand, + absl::Span replica_groups = {}, + const std::optional& layout = std::nullopt, + const std::optional& channel_id = std::nullopt); -XlaOp AllToAllTuple(XlaOp operand, int64_t split_dimension, - int64_t concat_dimension, int64_t split_count, - absl::Span replica_groups = {}, - const std::optional& layout = std::nullopt); +XlaOp AllToAllTuple( + XlaOp operand, int64_t split_dimension, int64_t concat_dimension, + int64_t split_count, absl::Span replica_groups = {}, + const std::optional& layout = std::nullopt, + const std::optional& channel_id = std::nullopt); // Enqueues an collective operation that sends and receives data cross replicas. // @@ -2526,6 +2564,9 @@ XlaOp Cos(XlaOp operand); // Enqueues a sine instruction onto the computation. XlaOp Sin(XlaOp operand); +// Enqueues a tan instruction onto the computation. +XlaOp Tan(XlaOp operand); + // Enqueues a tanh instruction onto the computation. XlaOp Tanh(XlaOp operand); diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc index 881af35e02a..d01b57e289d 100644 --- a/tensorflow/compiler/xla/client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -15,17 +15,22 @@ limitations under the License. 
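Since AllToAll and AllToAllTuple now accept an optional ChannelHandle, a cross-module all-to-all can be requested the same way other collectives take a channel. A hedged sketch; the operand x, the split/concat dimensions and the handle value are placeholders:

// Illustrative only: pass a device-to-device channel to the new parameter.
xla::ChannelHandle channel;
channel.set_handle(1);  // placeholder channel id
channel.set_type(xla::ChannelHandle::DEVICE_TO_DEVICE);

xla::XlaOp result = xla::AllToAll(x,
                                  /*split_dimension=*/0,
                                  /*concat_dimension=*/0,
                                  /*split_count=*/4,
                                  /*replica_groups=*/{},
                                  /*layout=*/std::nullopt,
                                  /*channel_id=*/channel);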
#include "tensorflow/compiler/xla/client/xla_builder.h" +#include +#include +#include +#include #include +#include +#include "tensorflow/compiler/xla/client/sharding_builder.h" #include "tensorflow/compiler/xla/client/value_inference.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/debug_options_flags.h" -#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" -#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/util.h" @@ -1446,5 +1451,65 @@ TEST_F(XlaBuilderTest, ComplexAbsConstant) { PrimitiveType::F32); } +TEST_F(XlaBuilderTest, OutfeedDummyTupleSharding) { + XlaBuilder b(TestName()); + XlaOp value = ConstantR1(&b, {0}); + Shape shape = ShapeUtil::MakeShapeWithDenseLayout(S32, /* dimensions= */ {1}, + /* minor_to_major= */ {0}); + Outfeed(value, shape, ""); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + EXPECT_FALSE(module->entry_computation()->root_instruction()->has_sharding()); +} + +TEST_F(XlaBuilderTest, OutfeedTokenSharding) { + XlaBuilder b(TestName()); + XlaOp value = ConstantR1(&b, {0}); + Shape shape = ShapeUtil::MakeShapeWithDenseLayout(S32, /* dimensions= */ {1}, + /* minor_to_major= */ {0}); + b.SetSharding(sharding_builder::Replicate()); + Outfeed(value, shape, ""); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto it = std::find_if(module->entry_computation()->instructions().begin(), + module->entry_computation()->instructions().end(), + [](const HloInstruction* i) { + return i->opcode() == HloOpcode::kOutfeed; + }); + EXPECT_NE(it, module->entry_computation()->instructions().end()); + auto* outfeed = *it; + EXPECT_TRUE(outfeed->has_sharding()); + EXPECT_TRUE(outfeed->sharding().IsTuple()); + EXPECT_EQ(outfeed->sharding().tuple_elements().size(), 2); + EXPECT_TRUE(outfeed->operand(1)->has_sharding()); + EXPECT_EQ(outfeed->sharding().tuple_elements().back(), + HloSharding::FromProto(sharding_builder::AssignDevice(0)).value()); + EXPECT_EQ(outfeed->operand(1)->sharding(), + HloSharding::FromProto(sharding_builder::AssignDevice(0)).value()); +} + +TEST_F(XlaBuilderTest, NormalizeTupleSharding) { + XlaBuilder b(TestName()); + Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {5}), ShapeUtil::MakeShape(F32, {6})}); + b.SetSharding(sharding_builder::Replicate()); + Parameter(&b, 0, tuple_param_shape, "p0"); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_TRUE(root->has_sharding()); + EXPECT_TRUE(root->sharding().IsTuple()); + EXPECT_EQ(root->sharding().tuple_elements().size(), 2); +} + +TEST_F(XlaBuilderTest, InvalidSharding) { + XlaBuilder b(TestName()); + Shape shape2d = ShapeUtil::MakeShape(F32, {6, 8}); + Shape shape1d = ShapeUtil::MakeShape(F32, {5}); + b.SetSharding(sharding_builder::Tile1D(shape1d, 4)); + Parameter(&b, 0, shape2d, "p0"); + auto statusor = b.Build(); + EXPECT_FALSE(statusor.ok()); + 
EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Number of tile assignment dimensions (excluding " + "subgroups) is different than the input rank")); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/comparison_util.cc b/tensorflow/compiler/xla/comparison_util.cc index 734e47a58af..91998790d1e 100644 --- a/tensorflow/compiler/xla/comparison_util.cc +++ b/tensorflow/compiler/xla/comparison_util.cc @@ -55,6 +55,11 @@ bool IsValidComparison(xla::PrimitiveType type, Comparison::Order order) { case PRIMITIVE_TYPE_INVALID: case PrimitiveType_INT_MAX_SENTINEL_DO_NOT_USE_: case PrimitiveType_INT_MIN_SENTINEL_DO_NOT_USE_: + // TODO(b/259609697): Add support for comparing F8 values. F8 values are + // comparable like any other floating-point type, but comparisons are not + // yet implemented by any backend. + case F8E5M2: + case F8E4M3FN: return false; } } @@ -97,6 +102,8 @@ Comparison::Order DefaultOrdering(PrimitiveType type) { case U32: case U64: return Comparison::Order::kTotal; + case F8E5M2: + case F8E4M3FN: case BF16: case F16: case F32: @@ -250,6 +257,8 @@ Comparison::Type Comparison::DefaultComparisonType(PrimitiveType type) { case U32: case U64: return Type::kUnsigned; + case F8E5M2: + case F8E4M3FN: case F16: case F32: case BF16: @@ -317,6 +326,8 @@ std::optional Comparison::Inverse() const { case TUPLE: case OPAQUE_TYPE: case TOKEN: + case F8E5M2: + case F8E4M3FN: case PRIMITIVE_TYPE_INVALID: case PrimitiveType_INT_MAX_SENTINEL_DO_NOT_USE_: case PrimitiveType_INT_MIN_SENTINEL_DO_NOT_USE_: diff --git a/tensorflow/compiler/xla/comparison_util.h b/tensorflow/compiler/xla/comparison_util.h index 649788333ba..a9023e66891 100644 --- a/tensorflow/compiler/xla/comparison_util.h +++ b/tensorflow/compiler/xla/comparison_util.h @@ -164,7 +164,7 @@ class Comparison { // Returns a comparison operator: (T, T) -> bool for this Comparison's // Direction. template - std::function GetComparator() const { + inline std::function GetComparator() const { switch (GetDirection()) { case Direction::kEq: return std::equal_to(); @@ -184,8 +184,8 @@ class Comparison { // Applies the comparison from this Comparison's direction and ordering for // integral types. template ::value, int> = 0> - bool Compare(const T a, const T b) const { - CHECK(primitive_util::IsCanonicalRepresentation(primitive_type_)); + inline bool Compare(const T a, const T b) const { + DCHECK(primitive_util::IsCanonicalRepresentation(primitive_type_)); return GetComparator()(a, b); } @@ -195,13 +195,13 @@ class Comparison { absl::enable_if_t::value || std::is_same::value, int> = 0> - bool Compare(const T a, const T b) const { - CHECK(primitive_util::IsCanonicalRepresentation(primitive_type_)); + inline bool Compare(const T a, const T b) const { + DCHECK(primitive_util::IsCanonicalRepresentation(primitive_type_)); if (IsTotalOrder()) { // -NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN // Reference: // https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations - using R = typename SignedIntegerTypeForSize::type; + using R = SignedIntegerTypeForSizeType; return GetComparator()(ToSignMagnitude(a), ToSignMagnitude(b)); } return GetComparator()(a, b); diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index d74da0f8b61..876e0b9596b 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -27,6 +27,7 @@ limitations under the License. 
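The total-order branch above compares sign-magnitude integer keys rather than the raw floating-point values, which is what produces the -NaN < -Inf < -finite < -0 < +0 < +finite < +Inf < +NaN ordering quoted in the comment. One common way to build such a key for F32, shown purely as an illustration of that ordering and not necessarily the exact helper used here:

#include <bit>      // std::bit_cast, C++20
#include <cstdint>

// Signed key whose natural integer ordering matches the total order above.
int32_t TotalOrderKey(float value) {
  const int32_t bits = std::bit_cast<int32_t>(value);
  const int32_t magnitude = bits & 0x7fffffff;
  // For negative values, flip the magnitude so larger magnitudes sort lower
  // and every negative key stays below every non-negative key.
  return bits < 0 ? ~magnitude : magnitude;
}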
#include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/debug_options_parsers.h" #include "tensorflow/compiler/xla/parse_flags_from_env.h" +#include "tensorflow/compiler/xla/xla.pb.h" namespace xla { @@ -41,6 +42,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_cpu_multi_thread_eigen(true); opts.set_xla_gpu_cuda_data_dir("./cuda_sdk_lib"); opts.set_xla_gpu_asm_extra_flags(""); + opts.set_xla_gpu_use_runtime_fusion(true); opts.set_xla_eliminate_hlo_implicit_broadcast(true); opts.set_xla_dump_hlo_as_html(false); opts.set_xla_dump_fusion_visualization(false); @@ -48,6 +50,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_dump_max_hlo_modules(-1); opts.set_xla_dump_module_metadata(false); opts.set_xla_dump_hlo_as_long_text(false); + opts.set_xla_dump_enable_mlir_pretty_form(true); #ifdef ENABLE_MKL opts.set_xla_cpu_use_mkl_dnn(true); #endif // ENABLE_MKL @@ -55,7 +58,6 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_cpu_use_acl(true); #endif opts.set_xla_cpu_use_xla_runtime(false); - opts.set_xla_gpu_max_kernel_unroll_factor(4); opts.set_xla_cpu_enable_fast_math(false); // Disable forms of fast math that have caused users problems in the past. @@ -71,6 +73,9 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_enable_cublaslt(false); + // TODO(b/258036887): Remove this flag once CUDA Graphs are fully supported. + opts.set_xla_gpu_enable_cuda_graphs(false); + // Despite the name, fast min/max on GPUs does not seem to be any faster, and // adds very counter-intuitive "NaN-swallowing" behavior. opts.set_xla_gpu_enable_fast_min_max(false); @@ -85,18 +90,24 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_multiheap_size_constraint_per_heap(-1); opts.set_xla_detailed_logging_and_dumping(true); - opts.set_xla_gpu_enable_xla_runtime_executable(false); + opts.set_xla_gpu_enable_xla_runtime_executable(true); opts.set_xla_gpu_nccl_termination_timeout_seconds(-1); opts.set_xla_gpu_enable_shared_constants(true); // Set 4GB space limit for redzone scratch allocator. opts.set_xla_gpu_redzone_scratch_max_megabytes(1LL << 12); opts.set_xla_gpu_shape_checks(DebugOptions::RUNTIME); - opts.set_xla_cpu_enable_mlir_lowering(false); opts.set_xla_gpu_enable_mlir_lowering(true); - opts.set_xla_gpu_enable_softmax_fusion(false); - opts.set_xla_gpu_normalize_layouts(false); + opts.set_xla_gpu_enable_softmax_fusion(true); + opts.set_xla_gpu_normalize_layouts(true); opts.set_xla_gpu_simplify_all_fp_conversions(true); + opts.set_xla_dump_latency_hiding_schedule(false); + opts.set_xla_gpu_enable_latency_hiding_scheduler(false); + + opts.set_xla_cpu_enable_mlir_tiling_and_fusion(false); + + opts.set_xla_partitioning_algorithm( + DebugOptions::PARTITIONING_ALGORITHM_NOOP); return opts; } @@ -137,86 +148,134 @@ static void WarnIfFuelWasNeverConsumed() { } } -// Allocates flag_values and flag_objects; this function must not be called more -// than once - its call done via call_once. -static void AllocateFlags() { - flag_values = new DebugOptions(DefaultDebugOptionsIgnoringFlags()); - - // Returns a lambda that calls "member_setter" on "flag_values" with the +void MakeDebugOptionsFlags(std::vector* flag_list, + DebugOptions* debug_options) { + // Returns a lambda that calls "member_setter" on "debug_options" with the // argument passed in to the lambda. 
- auto bool_setter_for = [](void (DebugOptions::*member_setter)(bool)) { - return [member_setter](bool value) { - (flag_values->*member_setter)(value); - return true; - }; - }; + auto bool_setter_for = + [debug_options](void (DebugOptions::*member_setter)(bool)) { + return [debug_options, member_setter](bool value) { + (debug_options->*member_setter)(value); + return true; + }; + }; - // Returns a lambda that calls "member_setter" on "flag_values" with the + // Returns a lambda that calls "member_setter" on "debug_options" with the // argument passed in to the lambda. - auto int32_setter_for = [](void (DebugOptions::*member_setter)(int32_t)) { - return [member_setter](int32_t value) { - (flag_values->*member_setter)(value); - return true; - }; - }; + auto int32_setter_for = + [debug_options](void (DebugOptions::*member_setter)(int32_t)) { + return [debug_options, member_setter](int32_t value) { + (debug_options->*member_setter)(value); + return true; + }; + }; - auto int64_setter_for = [](void (DebugOptions::*member_setter)(int64_t)) { - return [member_setter](int64_t value) { - (flag_values->*member_setter)(value); + auto int64_setter_for = + [debug_options](void (DebugOptions::*member_setter)(int64_t)) { + return [debug_options, member_setter](int64_t value) { + (debug_options->*member_setter)(value); + return true; + }; + }; + + auto string_setter_for = [debug_options](void (DebugOptions::*member_setter)( + const std::string& value)) { + return [debug_options, member_setter](const std::string& value) { + (debug_options->*member_setter)(value); return true; }; }; - auto string_setter_for = - [](void (DebugOptions::*member_setter)(const std::string& value)) { - return [member_setter](const std::string& value) { - (flag_values->*member_setter)(value); - return true; - }; + // Custom "sub-parser" lambda for xla_gpu_shape_checks. + auto setter_for_xla_gpu_shape_checks = + [debug_options](const std::string& value) { + DebugOptions::ShapeChecks shape_checks; + if (!DebugOptions::ShapeChecks_Parse(value, &shape_checks)) { + return false; + } + debug_options->set_xla_gpu_shape_checks(shape_checks); + return true; }; // Custom "sub-parser" lambda for xla_disable_hlo_passes. auto setter_for_xla_disable_hlo_passes = - [](std::string comma_separated_values) { + [debug_options](std::string comma_separated_values) { for (const auto& passname : std::vector( absl::StrSplit(comma_separated_values, ','))) { - flag_values->add_xla_disable_hlo_passes(passname); + debug_options->add_xla_disable_hlo_passes(passname); } return true; }; // Custom "sub-parser" lambda for xla_enable_hlo_passes_only. auto setter_for_xla_enable_hlo_passes_only = - [](std::string comma_separated_values) { + [debug_options](std::string comma_separated_values) { for (const auto& passname : std::vector( absl::StrSplit(comma_separated_values, ','))) { - flag_values->add_xla_enable_hlo_passes_only(passname); + debug_options->add_xla_enable_hlo_passes_only(passname); } return true; }; // Custom "sub-parser" lambda for xla_gpu_ptx_file. - auto setter_for_xla_gpu_ptx_file = [](std::string value) { - flag_values->add_xla_gpu_ptx_file(value); + auto setter_for_xla_gpu_ptx_file = [debug_options](std::string value) { + debug_options->add_xla_gpu_ptx_file(value); return true; }; // Custom "sub-parser" lambda for xla_gpu_llvm_ir_file. 
- auto setter_for_xla_gpu_llvm_ir_file = [](const std::string& value) { - flag_values->add_xla_gpu_llvm_ir_file(value); - return true; - }; + auto setter_for_xla_gpu_llvm_ir_file = + [debug_options](const std::string& value) { + debug_options->add_xla_gpu_llvm_ir_file(value); + return true; + }; // Custom "sub-parser" lambda for xla_backend_extra_options. auto setter_for_xla_backend_extra_options = - [](std::string comma_separated_values) { + [debug_options](std::string comma_separated_values) { auto* extra_options_map = - flag_values->mutable_xla_backend_extra_options(); + debug_options->mutable_xla_backend_extra_options(); parse_xla_backend_extra_options(extra_options_map, comma_separated_values); return true; }; + auto setter_for_xla_gpu_enable_softmax_fusion = [debug_options](bool value) { + // It is only possible to enable softmax fusion if + // xla_gpu_enable_mlir_lowering is also enabled. + if (value && !debug_options->xla_gpu_enable_mlir_lowering()) { + LOG(ERROR) << "xla_gpu_enable_softmax_fusion can only be enabled if " + "xla_gpu_enable_mlir_lowering is enabled as well"; + return false; + } + debug_options->set_xla_gpu_enable_softmax_fusion(value); + return true; + }; + + auto setter_for_xla_gpu_enable_mlir_lowering = [debug_options](bool value) { + // It is only possible to disable mlir lowering if + // xla_gpu_enable_softmax_fusion is also disabled. + if (!value && debug_options->xla_gpu_enable_softmax_fusion()) { + LOG(ERROR) << "xla_gpu_enable_mlir_lowering can only be disabled if " + "xla_gpu_enable_softmax_fusion is disabled as well"; + return false; + } + debug_options->set_xla_gpu_enable_mlir_lowering(value); + return true; + }; + + // Custom "sub-parser" lambda for xla_partitioning_algorithm. + auto setter_for_xla_partitioning_algorithm = + [debug_options](const std::string& value) { + DebugOptions::PartitioningAlgorithm partitioning_algorithm; + if (!DebugOptions::PartitioningAlgorithm_Parse( + value, &partitioning_algorithm)) { + return false; + } + debug_options->set_xla_partitioning_algorithm(partitioning_algorithm); + return true; + }; + // Custom "sub-parser" for xla_fuel. Note that ConsumeFuel does not do any // locking on the fuel global variables. This means that it's // illegal/undefined behavior to modify this flag value while the compiler is @@ -266,96 +325,95 @@ static void AllocateFlags() { return true; }; - flag_objects = new std::vector(); // Don't use an initializer list for initializing the vector; this would // create a temporary copy, and exceeds the stack space when compiling with // certain configurations. - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_cpu_enable_fast_math", bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_math), - flag_values->xla_cpu_enable_fast_math(), + debug_options->xla_cpu_enable_fast_math(), "Enable unsafe fast-math optimizations in the CPU compiler; this may " "produce faster code at the expense of some accuracy.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_cpu_fast_math_honor_nans", bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_nans), - flag_values->xla_cpu_fast_math_honor_nans(), + debug_options->xla_cpu_fast_math_honor_nans(), "When xla_cpu_enable_fast_math is true then this controls whether we " "allow operations to produce NaNs. 
Ignored when " "xla_cpu_enable_fast_math is false.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_cpu_fast_math_honor_infs", bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_infs), - flag_values->xla_cpu_fast_math_honor_infs(), + debug_options->xla_cpu_fast_math_honor_infs(), "When xla_cpu_enable_fast_math is true then this controls whether we " "allow operations to produce infinites. Ignored when " "xla_cpu_enable_fast_math is false.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_cpu_fast_math_honor_division", bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_division), - flag_values->xla_cpu_fast_math_honor_division(), + debug_options->xla_cpu_fast_math_honor_division(), "When xla_cpu_enable_fast_math is true then this controls whether we " "forbid to use multiplication by the reciprocal instead of division. " "Ignored when xla_cpu_enable_fast_math is false.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_cpu_fast_math_honor_functions", bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_functions), - flag_values->xla_cpu_fast_math_honor_functions(), + debug_options->xla_cpu_fast_math_honor_functions(), "When xla_cpu_enable_fast_math is true then this controls whether we " "forbid to approximate calculations for functions. Ignored when " "xla_cpu_enable_fast_math is false.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_cpu_enable_fast_min_max", bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_min_max), - flag_values->xla_cpu_enable_fast_min_max(), + debug_options->xla_cpu_enable_fast_min_max(), "Enable fast floating point min/max lowering that always propagates " "NaNs.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_gpu_enable_fast_min_max", bool_setter_for(&DebugOptions::set_xla_gpu_enable_fast_min_max), - flag_values->xla_gpu_enable_fast_min_max(), + debug_options->xla_gpu_enable_fast_min_max(), "Enable fast floating point min/max lowering that does not propagate " "NaNs.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_llvm_enable_alias_scope_metadata", bool_setter_for(&DebugOptions::set_xla_llvm_enable_alias_scope_metadata), - flag_values->xla_llvm_enable_alias_scope_metadata(), + debug_options->xla_llvm_enable_alias_scope_metadata(), "In LLVM-based backends, enable the emission of !alias.scope metadata in " "the generated IR.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_llvm_enable_noalias_metadata", bool_setter_for(&DebugOptions::set_xla_llvm_enable_noalias_metadata), - flag_values->xla_llvm_enable_noalias_metadata(), + debug_options->xla_llvm_enable_noalias_metadata(), "In LLVM-based backends, enable the emission of !noalias metadata in the " "generated IR.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_llvm_enable_invariant_load_metadata", bool_setter_for( &DebugOptions::set_xla_llvm_enable_invariant_load_metadata), - flag_values->xla_llvm_enable_invariant_load_metadata(), + debug_options->xla_llvm_enable_invariant_load_metadata(), "In LLVM-based backends, enable the emission of !invariant.load metadata " "in the generated IR.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_llvm_disable_expensive_passes", bool_setter_for(&DebugOptions::set_xla_llvm_disable_expensive_passes), - flag_values->xla_llvm_disable_expensive_passes(), + 
debug_options->xla_llvm_disable_expensive_passes(), "In LLVM-based backends, disable a custom set of expensive optimization " "passes.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_backend_optimization_level", int32_setter_for(&DebugOptions::set_xla_backend_optimization_level), - flag_values->xla_backend_optimization_level(), + debug_options->xla_backend_optimization_level(), "Numerical optimization level for the XLA compiler backend.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_disable_hlo_passes", setter_for_xla_disable_hlo_passes, "", "Comma-separated list of hlo passes to be disabled. These names must " "exactly match the passes' names; no whitespace around commas.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_enable_hlo_passes_only", setter_for_xla_enable_hlo_passes_only, "", "Comma-separated list of hlo passes to be enabled. These names must " "exactly match the passes' names; no whitespace around commas. The " "unspecified passes are all disabled.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_disable_all_hlo_passes", bool_setter_for(&DebugOptions::set_xla_disable_all_hlo_passes), false, "Disables all HLO passes. Notes that some passes are necessary for " @@ -364,44 +422,39 @@ static void AllocateFlags() { "over time. The only 'guarantee', such as it is, is that if you compile " "XLA and dump the optimized HLO for some graph, you should be able to " "run it again on the same device with the same build of XLA.")); - flag_objects->push_back( + flag_list->push_back( tsl::Flag("xla_embed_ir_in_executable", bool_setter_for(&DebugOptions::set_xla_embed_ir_in_executable), - flag_values->xla_embed_ir_in_executable(), + debug_options->xla_embed_ir_in_executable(), "Embed the compiler IR as a string in the executable.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_eliminate_hlo_implicit_broadcast", bool_setter_for(&DebugOptions::set_xla_eliminate_hlo_implicit_broadcast), - flag_values->xla_eliminate_hlo_implicit_broadcast(), + debug_options->xla_eliminate_hlo_implicit_broadcast(), "Eliminate implicit broadcasts when lowering user computations to HLO " "instructions; use explicit broadcast instead.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_cpu_multi_thread_eigen", bool_setter_for(&DebugOptions::set_xla_cpu_multi_thread_eigen), - flag_values->xla_cpu_multi_thread_eigen(), + debug_options->xla_cpu_multi_thread_eigen(), "When generating calls to Eigen in the CPU backend, use multi-threaded " "Eigen mode.")); - flag_objects->push_back(tsl::Flag( - "xla_gpu_cuda_data_dir", flag_values->mutable_xla_gpu_cuda_data_dir(), + flag_list->push_back(tsl::Flag( + "xla_gpu_cuda_data_dir", debug_options->mutable_xla_gpu_cuda_data_dir(), "If non-empty, specifies a local directory containing ptxas and nvvm " "libdevice files; otherwise we use those from runfile directories.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_gpu_ftz", bool_setter_for(&DebugOptions::set_xla_gpu_ftz), - flag_values->xla_gpu_ftz(), + debug_options->xla_gpu_ftz(), "If true, flush-to-zero semantics are enabled in the code generated for " "GPUs.")); - flag_objects->push_back(tsl::Flag( - "xla_gpu_max_kernel_unroll_factor", - int32_setter_for(&DebugOptions::set_xla_gpu_max_kernel_unroll_factor), - flag_values->xla_gpu_max_kernel_unroll_factor(), - "Specify the maximum kernel unroll factor for the GPU 
backend.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_gpu_ptx_file", setter_for_xla_gpu_ptx_file, "", "If non-empty, specifies a file containing ptx to use. The filename " "prefix must have the same pattern as PTX dumped by XLA. This allows to " "match one specific module. General workflow. Get the generated module " "ptx from XLA, modify it, then pass it back via this option.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_gpu_llvm_ir_file", setter_for_xla_gpu_llvm_ir_file, "", "If non-empty, specifies a file containing textual LLVM IR to use. The " "filename prefix must have the same pattern as LLVM dumped by XLA " @@ -409,390 +462,446 @@ static void AllocateFlags() { "allows to match one specific module. General workflow. Get the not " "optimized LLVM IR from XLA, modify it, then pass it back via this " "option.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_test_all_output_layouts", bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts), - flag_values->xla_test_all_output_layouts(), + debug_options->xla_test_all_output_layouts(), "Let ClientLibraryTestBase::ComputeAndCompare* test all permutations of " "output layouts. For example, with a 3D shape, all permutations of the " "set {0, 1, 2} are tried.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_test_all_input_layouts", bool_setter_for(&DebugOptions::set_xla_test_all_input_layouts), - flag_values->xla_test_all_input_layouts(), + debug_options->xla_test_all_input_layouts(), "Let ClientLibraryTestBase::ComputeAndCompare* test all permutations of " "*input* layouts. For example, for 2 input arguments with 2D shape and " "4D shape, the computation will run 2! * 4! 
times for every possible " "layouts")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_hlo_profile", bool_setter_for(&DebugOptions::set_xla_hlo_profile), - flag_values->xla_hlo_profile(), + debug_options->xla_hlo_profile(), "Instrument the computation to collect per-HLO cycle counts")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_backend_extra_options", setter_for_xla_backend_extra_options, "", "Extra options to pass to a backend; comma-separated list of 'key=val' " "strings (=val may be omitted); no whitespace around commas.")); - flag_objects->push_back( + flag_list->push_back( tsl::Flag("xla_cpu_use_mkl_dnn", bool_setter_for(&DebugOptions::set_xla_cpu_use_mkl_dnn), - flag_values->xla_cpu_use_mkl_dnn(), + debug_options->xla_cpu_use_mkl_dnn(), "Generate calls to MKL-DNN in the CPU backend.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_cpu_use_acl", bool_setter_for(&DebugOptions::set_xla_cpu_use_acl), - flag_values->xla_cpu_use_acl(), + debug_options->xla_cpu_use_acl(), "Generate calls to ACL (Arm Compute Library) in the CPU backend.")); - flag_objects->push_back( + flag_list->push_back( tsl::Flag("xla_cpu_use_xla_runtime", bool_setter_for(&DebugOptions::set_xla_cpu_use_xla_runtime), - flag_values->xla_cpu_use_xla_runtime(), + debug_options->xla_cpu_use_xla_runtime(), "Enable XLA Runtime in the CPU backend.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_gpu_crash_on_verification_failures", bool_setter_for( &DebugOptions::set_xla_gpu_crash_on_verification_failures), - flag_values->xla_gpu_crash_on_verification_failures(), + debug_options->xla_gpu_crash_on_verification_failures(), "Crashes the program on extra verification failures, e.g. cuDNN cross " "checking failures")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_gpu_strict_conv_algorithm_picker", bool_setter_for(&DebugOptions::set_xla_gpu_strict_conv_algorithm_picker), - flag_values->xla_gpu_strict_conv_algorithm_picker(), + debug_options->xla_gpu_strict_conv_algorithm_picker(), "Upgrades warnings to failures when all algorithms fail conv " "autotuning.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_gpu_autotune_level", int32_setter_for(&DebugOptions::set_xla_gpu_autotune_level), - flag_values->xla_gpu_autotune_level(), + debug_options->xla_gpu_autotune_level(), "Set GEMM and Convolution auto-tuning level. 0 = off; 1 = on; 2 = " "on+init; 3 = on+init+reinit; 4 = on+init+reinit+check.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_force_host_platform_device_count", int32_setter_for(&DebugOptions::set_xla_force_host_platform_device_count), - flag_values->xla_force_host_platform_device_count(), + debug_options->xla_force_host_platform_device_count(), "Force the host platform to pretend that there are these many host " "\"devices\". All of these host devices are backed by the same " "threadpool. 
Setting this to anything other than 1 can increase overhead " "from context switching but we let the user override this behavior to " "help run tests on the host that run models in parallel across multiple " "devices.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_gpu_disable_gpuasm_optimizations", bool_setter_for(&DebugOptions::set_xla_gpu_disable_gpuasm_optimizations), - flag_values->xla_gpu_disable_gpuasm_optimizations(), + debug_options->xla_gpu_disable_gpuasm_optimizations(), "In XLA:GPU run ptxas in -O0 (default is -O3).")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_gpu_asm_extra_flags", string_setter_for(&DebugOptions::set_xla_gpu_asm_extra_flags), "", "Pass extra parameters to the GPU assembler tool (i.e., ptxas for CUDA). " "If multiple parameters, separate them by comma.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_fuel", setter_for_xla_fuel, /*default_value_for_display=*/"", "Sets compiler fuel, useful for bisecting bugs in passes. Format " "--xla_fuel=PASS1=NUM1,PASS2=NUM2,...")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_dump_to", string_setter_for(&DebugOptions::set_xla_dump_to), - flag_values->xla_dump_to(), + debug_options->xla_dump_to(), "Directory into which debugging data is written. If not specified but " "another dumping flag is passed, data will be written to stdout. To " "explicitly write to stdout, set this to \"-\". The values \"sponge\" " "and \"test_undeclared_outputs_dir\" have a special meaning: They cause " "us to dump into the directory specified by the environment variable " "TEST_UNDECLARED_OUTPUTS_DIR.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_dump_hlo_as_text", bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_text), - flag_values->xla_dump_hlo_as_text(), - "Dumps HLO modules as text before and after optimizations. Results are " + debug_options->xla_dump_hlo_as_text(), + "Dumps HLO modules as text before and after optimizations. debug_options " + "are " "written to the --xla_dump_to dir, or, if no dir is specified, to " "stdout.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_dump_hlo_as_long_text", bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_long_text), - flag_values->xla_dump_hlo_as_long_text(), - "Dumps HLO modules as long text before and after optimizations. Results " + debug_options->xla_dump_hlo_as_long_text(), + "Dumps HLO modules as long text before and after optimizations. " + "debug_options " "are written to the --xla_dump_to dir, or, if no dir is specified, to " "stdout. 
Ignored unless xla_dump_hlo_as_text is true.")); - flag_objects->push_back( + flag_list->push_back( tsl::Flag("xla_dump_hlo_as_proto", bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_proto), - flag_values->xla_dump_hlo_as_proto(), + debug_options->xla_dump_hlo_as_proto(), "Dumps HLO modules as HloProtos to the directory specified by " "--xla_dump_to.")); - flag_objects->push_back( + flag_list->push_back( tsl::Flag("xla_dump_hlo_as_dot", bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_dot), - flag_values->xla_dump_hlo_as_dot(), + debug_options->xla_dump_hlo_as_dot(), "Dumps HLO modules rendered as dot files to the " "directory specified by --xla_dump_to.")); - flag_objects->push_back( + flag_list->push_back( tsl::Flag("xla_dump_hlo_as_html", bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_html), - flag_values->xla_dump_hlo_as_html(), + debug_options->xla_dump_hlo_as_html(), "Dumps HLO modules rendered as HTML files to the " "directory specified by --xla_dump_to.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_dump_hlo_as_url", bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_url), - flag_values->xla_dump_hlo_as_url(), + debug_options->xla_dump_hlo_as_url(), "Tries to dump HLO modules rendered as URLs to stdout (and also to the " "directory specified by --xla_dump_to). This is not implemented by " "default; you need to add a plugin which calls " "RegisterGraphToURLRenderer().")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_dump_fusion_visualization", bool_setter_for(&DebugOptions::set_xla_dump_fusion_visualization), - flag_values->xla_dump_fusion_visualization(), + debug_options->xla_dump_fusion_visualization(), "Tries to generate HLO fusion visualization as an HTML page to the " "directory specified by --xla_dump_to). This is not implemented by " "default; you need to add a plugin which calls " "RegisterGraphToURLRenderer(). Generates a file per computation. " "Currently only implemented for the GPU backend.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_dump_hlo_snapshots", bool_setter_for(&DebugOptions::set_xla_dump_hlo_snapshots), - flag_values->xla_dump_hlo_snapshots(), + debug_options->xla_dump_hlo_snapshots(), "Every time an HLO module is run, dumps an HloSnapshot to the directory " "specified by --xla_dump_to.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_dump_hlo_module_re", string_setter_for(&DebugOptions::set_xla_dump_hlo_module_re), - flag_values->xla_dump_hlo_module_re(), + debug_options->xla_dump_hlo_module_re(), "Limits dumping only to modules which match this regular expression. 
" "Default is to dump all modules.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_dump_hlo_pass_re", string_setter_for(&DebugOptions::set_xla_dump_hlo_pass_re), - flag_values->xla_dump_hlo_pass_re(), + debug_options->xla_dump_hlo_pass_re(), "If specified, dumps HLO before and after optimization passes which " "match this regular expression, in addition to dumping at the very " "beginning and end of compilation.")); - flag_objects->push_back( + flag_list->push_back( tsl::Flag("xla_dump_include_timestamp", bool_setter_for(&DebugOptions::set_xla_dump_include_timestamp), - flag_values->xla_dump_include_timestamp(), + debug_options->xla_dump_include_timestamp(), "If specified, includes a timestamp in the dumped filenames.")); - flag_objects->push_back( + flag_list->push_back( tsl::Flag("xla_dump_max_hlo_modules", int32_setter_for(&DebugOptions::set_xla_dump_max_hlo_modules), - flag_values->xla_dump_max_hlo_modules(), + debug_options->xla_dump_max_hlo_modules(), "Max number of hlo module dumps in a directory. Set to < 0 for " "unbounded.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_dump_module_metadata", bool_setter_for(&DebugOptions::set_xla_dump_module_metadata), - flag_values->xla_dump_module_metadata(), + debug_options->xla_dump_module_metadata(), "Dumps HloModuleMetadata as text protos to the directory specified " "by --xla_dump_to.")); - flag_objects->push_back( + flag_list->push_back( tsl::Flag("xla_dump_compress_protos", bool_setter_for(&DebugOptions::set_xla_dump_compress_protos), - flag_values->xla_dump_compress_protos(), + debug_options->xla_dump_compress_protos(), "Gzip-compress protos dumped by --xla_dump_hlo_as_proto.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_hlo_graph_addresses", bool_setter_for(&DebugOptions::set_xla_hlo_graph_addresses), - flag_values->xla_hlo_graph_addresses(), + debug_options->xla_hlo_graph_addresses(), "When rendering graphs (--xla_dump_hlo_as_{dot,html,url}), displays " "the address in memory of each HloInstruction object.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_hlo_graph_sharding_color", bool_setter_for(&DebugOptions::set_xla_hlo_graph_sharding_color), - flag_values->xla_hlo_graph_sharding_color(), + debug_options->xla_hlo_graph_sharding_color(), "Assign colors based on sharding assignments when generating the HLO " "graphs.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_allow_excess_precision", bool_setter_for(&DebugOptions::set_xla_allow_excess_precision), - flag_values->xla_allow_excess_precision(), + debug_options->xla_allow_excess_precision(), "Allow xla to increase the output precision of an instruction.")); - flag_objects->push_back( + flag_list->push_back( tsl::Flag("xla_gpu_force_conv_nchw", bool_setter_for(&DebugOptions::set_xla_gpu_force_conv_nchw), - flag_values->xla_gpu_force_conv_nchw(), + debug_options->xla_gpu_force_conv_nchw(), "For cuDNN convolutions, always use NCHW layouts.")); - flag_objects->push_back( + flag_list->push_back( tsl::Flag("xla_gpu_force_conv_nhwc", bool_setter_for(&DebugOptions::set_xla_gpu_force_conv_nhwc), - flag_values->xla_gpu_force_conv_nhwc(), + debug_options->xla_gpu_force_conv_nhwc(), "For cuDNN convolutions, always use NHWC layouts.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_gpu_algorithm_denylist_path", string_setter_for(&DebugOptions::set_xla_gpu_algorithm_denylist_path), - 
flag_values->xla_gpu_algorithm_denylist_path(), + debug_options->xla_gpu_algorithm_denylist_path(), "An AlgorithmDenylist text proto file as a denylist of convolutions to " "avoid to use.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back( + tsl::Flag("xla_gpu_use_runtime_fusion", + bool_setter_for(&DebugOptions::set_xla_gpu_use_runtime_fusion), + debug_options->xla_gpu_use_runtime_fusion(), + "For using cuDNN runtime compiled fusion kernels.")); + flag_list->push_back(tsl::Flag( "xla_tpu_detect_nan", bool_setter_for(&DebugOptions::set_xla_tpu_detect_nan), - flag_values->xla_tpu_detect_nan(), + debug_options->xla_tpu_detect_nan(), "Trigger error on execution on TPU if a NAN value is detected")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_tpu_detect_inf", bool_setter_for(&DebugOptions::set_xla_tpu_detect_inf), - flag_values->xla_tpu_detect_inf(), + debug_options->xla_tpu_detect_inf(), "Trigger error on execution on TPU if a INF value is detected")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_cpu_enable_xprof_traceme", bool_setter_for(&DebugOptions::set_xla_cpu_enable_xprof_traceme), - flag_values->xla_cpu_enable_xprof_traceme(), + debug_options->xla_cpu_enable_xprof_traceme(), "If true, XLA CPU generates code to call " "TraceMe::Activity{Start|End} around HLO operations.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found", bool_setter_for( &DebugOptions:: set_xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found), - flag_values->xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found(), + debug_options->xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found(), "If true, XLA GPU falls back to the driver if ptxas is not found. Note " "that falling back to the driver can have drawbacks like using more " "memory and/or other bugs during compilation, so we recommend setting " "this flag to false.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_multiheap_size_constraint_per_heap", int32_setter_for( &DebugOptions::set_xla_multiheap_size_constraint_per_heap), - flag_values->xla_multiheap_size_constraint_per_heap(), + debug_options->xla_multiheap_size_constraint_per_heap(), "Generates multiple heaps (i.e., temp buffers) with a size " "constraint on each heap to avoid Out-of-Memory due to memory " "fragmentation. The constraint is soft, so it works with tensors " "larger than the given constraint size. -1 corresponds to no " "constraints.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_gpu_force_compilation_parallelism", int32_setter_for( &DebugOptions::set_xla_gpu_force_compilation_parallelism), - flag_values->xla_gpu_force_compilation_parallelism(), - "Overrides normal multi-threaded compilation settting to use this many " + debug_options->xla_gpu_force_compilation_parallelism(), + "Overrides normal multi-threaded compilation setting to use this many " "threads. 
Setting to 0 (the default value) means no enforcement.")); - flag_objects->push_back( + flag_list->push_back( tsl::Flag("xla_gpu_deterministic_ops", bool_setter_for(&DebugOptions::set_xla_gpu_deterministic_ops), - flag_values->xla_gpu_deterministic_ops(), + debug_options->xla_gpu_deterministic_ops(), "Guarantees run-to-run determinism on GPU.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_gpu_enable_async_all_reduce", bool_setter_for(&DebugOptions::set_xla_gpu_enable_async_all_reduce), - flag_values->xla_gpu_enable_async_all_reduce(), + debug_options->xla_gpu_enable_async_all_reduce(), "Converts synchronous all-reduce ops into asynchronous.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_async_collective_permute", + bool_setter_for( + &DebugOptions::set_xla_gpu_enable_async_collective_permute), + debug_options->xla_gpu_enable_async_collective_permute(), + "Converts synchronous collective-permute ops into asynchronous.")); + flag_list->push_back(tsl::Flag( "xla_gpu_all_reduce_combine_threshold_bytes", int64_setter_for( &DebugOptions::set_xla_gpu_all_reduce_combine_threshold_bytes), - flag_values->xla_gpu_all_reduce_combine_threshold_bytes(), + debug_options->xla_gpu_all_reduce_combine_threshold_bytes(), "Size threshold (in bytes) for the GPU all-reduce combiner.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_gpu_all_reduce_contiguous", bool_setter_for(&DebugOptions::set_xla_gpu_all_reduce_contiguous), - flag_values->xla_gpu_all_reduce_contiguous(), + debug_options->xla_gpu_all_reduce_contiguous(), "Combine all-reduces into a single operation over a contiguous buffer.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_gpu_all_reduce_blueconnect_num_devices_per_host", int32_setter_for( &DebugOptions:: set_xla_gpu_all_reduce_blueconnect_num_devices_per_host), - flag_values->xla_gpu_all_reduce_blueconnect_num_devices_per_host(), + debug_options->xla_gpu_all_reduce_blueconnect_num_devices_per_host(), "Number of devices per host for first stage of BlueConnect decomposition " "pass. The pass will attempt to decompose all-reduces ops into a " "ReduceScatter-AllReduce-AllGather sequence, with the initial " "ReduceScatter being performed over all of the devices in the same host. 
" "Set to < 1 to disable all-reduce decomposition.")); - flag_objects->push_back( + flag_list->push_back( tsl::Flag("xla_gpu_dump_llvmir", bool_setter_for(&DebugOptions::set_xla_gpu_dump_llvmir), - flag_values->xla_gpu_dump_llvmir(), "Dump LLVM IR.")); - flag_objects->push_back(tsl::Flag( + debug_options->xla_gpu_dump_llvmir(), "Dump LLVM IR.")); + flag_list->push_back(tsl::Flag( "xla_gpu_enable_cudnn_frontend", bool_setter_for(&DebugOptions::set_xla_gpu_enable_cudnn_frontend), - flag_values->xla_gpu_enable_cudnn_frontend(), + debug_options->xla_gpu_enable_cudnn_frontend(), "Use the cuDNN frontend API for convolutions when possible.")); - flag_objects->push_back( + flag_list->push_back( tsl::Flag("xla_gpu_enable_cublaslt", bool_setter_for(&DebugOptions::set_xla_gpu_enable_cublaslt), - flag_values->xla_gpu_enable_cublaslt(), + debug_options->xla_gpu_enable_cublaslt(), "Use cuBLASLt for GEMMs when possible.")); - flag_objects->push_back( + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_cuda_graphs", + bool_setter_for(&DebugOptions::set_xla_gpu_enable_cuda_graphs), + debug_options->xla_gpu_enable_cuda_graphs(), + "Use CUDA graphs to execute XLA GPU executables when possible.")); + flag_list->push_back( tsl::Flag("xla_dump_disable_metadata", bool_setter_for(&DebugOptions::set_xla_dump_disable_metadata), - flag_values->xla_dump_disable_metadata(), + debug_options->xla_dump_disable_metadata(), "Disable dumping HLO metadata in HLO dumps.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_dump_hlo_pipeline_re", string_setter_for(&DebugOptions::set_xla_dump_hlo_pipeline_re), - flag_values->xla_dump_hlo_pipeline_re(), + debug_options->xla_dump_hlo_pipeline_re(), "If specified, dumps HLO before and after optimization passes in the " "pass pipelines that match this regular expression.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( + "xla_dump_enable_mlir_pretty_form", + bool_setter_for(&DebugOptions::set_xla_dump_enable_mlir_pretty_form), + debug_options->xla_dump_enable_mlir_pretty_form(), + "Enable dumping MLIR using pretty print form. If set to false, the " + "dumped " + "MLIR will be in the llvm-parsable format and can be processed by " + "mlir-opt tools. 
" + "Pretty print form is not legal MLIR.")); + flag_list->push_back(tsl::Flag( "xla_gpu_enable_xla_runtime_executable", bool_setter_for(&DebugOptions::set_xla_gpu_enable_xla_runtime_executable), - flag_values->xla_gpu_enable_xla_runtime_executable(), + debug_options->xla_gpu_enable_xla_runtime_executable(), "Whether to enable XLA runtime for XLA:GPU backend")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_gpu_nccl_termination_timeout_seconds", int64_setter_for( &DebugOptions::set_xla_gpu_nccl_termination_timeout_seconds), - flag_values->xla_gpu_nccl_termination_timeout_seconds(), + debug_options->xla_gpu_nccl_termination_timeout_seconds(), "Timeout in seconds before terminating jobs stuck in NCCL Rendezvous.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_gpu_enable_shared_constants", bool_setter_for(&DebugOptions::set_xla_gpu_enable_shared_constants), - flag_values->xla_gpu_enable_shared_constants(), + debug_options->xla_gpu_enable_shared_constants(), "Enable constant sharing between GPU executables")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_gpu_redzone_scratch_max_megabytes", int64_setter_for( &DebugOptions::set_xla_gpu_redzone_scratch_max_megabytes), - flag_values->xla_gpu_redzone_scratch_max_megabytes(), + debug_options->xla_gpu_redzone_scratch_max_megabytes(), "Max size (in megabytes) for the GPU redzone scratch allocator.")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( "xla_gpu_simplify_all_fp_conversions", bool_setter_for(&DebugOptions::set_xla_gpu_simplify_all_fp_conversions), - flag_values->xla_gpu_simplify_all_fp_conversions(), + debug_options->xla_gpu_simplify_all_fp_conversions(), "Allows any chain of floating-point conversions to be simplified.")); - flag_objects->push_back(tsl::Flag( - "xla_cpu_enable_mlir_lowering", - bool_setter_for(&DebugOptions::set_xla_cpu_enable_mlir_lowering), - flag_values->xla_cpu_enable_mlir_lowering(), - "Enable MLIR-based lowering in XLA:CPU instead of LLVM emitters.")); - flag_objects->push_back(tsl::Flag( - "xla_gpu_enable_mlir_lowering", - bool_setter_for(&DebugOptions::set_xla_gpu_enable_mlir_lowering), - flag_values->xla_gpu_enable_mlir_lowering(), + flag_list->push_back(tsl::Flag( + "xla_gpu_shape_checks", setter_for_xla_gpu_shape_checks, + DebugOptions::ShapeChecks_Name(debug_options->xla_gpu_shape_checks()), + "When to perform shape checks in XLA:GPU.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_mlir_lowering", setter_for_xla_gpu_enable_mlir_lowering, + debug_options->xla_gpu_enable_mlir_lowering(), "Enable MLIR-based lowering in XLA:GPU instead of LLVM emitters.")); - flag_objects->push_back(tsl::Flag( - "xla_gpu_enable_softmax_fusion", - bool_setter_for(&DebugOptions::set_xla_gpu_enable_softmax_fusion), - flag_values->xla_gpu_enable_mlir_lowering(), - "Enable MLIR-based softmax fusion.")); - flag_objects->push_back( + flag_list->push_back(tsl::Flag("xla_gpu_enable_softmax_fusion", + setter_for_xla_gpu_enable_softmax_fusion, + debug_options->xla_gpu_enable_softmax_fusion(), + "Enable MLIR-based softmax fusion.")); + flag_list->push_back( tsl::Flag("xla_gpu_normalize_layouts", bool_setter_for(&DebugOptions::set_xla_gpu_normalize_layouts), - flag_values->xla_gpu_normalize_layouts(), + debug_options->xla_gpu_normalize_layouts(), "An experimental option to force all layouts present in the " "after-optimizations HLO to be descending")); - flag_objects->push_back(tsl::Flag( + flag_list->push_back(tsl::Flag( 
"xla_cpu_strict_dot_conv_math", bool_setter_for(&DebugOptions::set_xla_cpu_strict_dot_conv_math), - flag_values->xla_cpu_strict_dot_conv_math(), + debug_options->xla_cpu_strict_dot_conv_math(), "By default, XLA:CPU will run fp16 dot/conv as fp32, as this is " "generally (much) faster on our hardware. Set this flag to true to " "disable this behavior.")); + flag_list->push_back(tsl::Flag( + "xla_dump_latency_hiding_schedule", + bool_setter_for(&DebugOptions::set_xla_dump_latency_hiding_schedule), + debug_options->xla_dump_latency_hiding_schedule(), + "Dump the schedule from the latency-hiding scheduler.")); + flag_list->push_back(tsl::Flag( + "xla_cpu_enable_mlir_tiling_and_fusion", + bool_setter_for(&DebugOptions::set_xla_cpu_enable_mlir_tiling_and_fusion), + debug_options->xla_cpu_enable_mlir_tiling_and_fusion(), + "Enable MLIR tiling and fusion.")); + flag_list->push_back( + tsl::Flag("xla_gpu_enable_latency_hiding_scheduler", + bool_setter_for( + &DebugOptions::set_xla_gpu_enable_latency_hiding_scheduler), + debug_options->xla_gpu_enable_latency_hiding_scheduler(), + "Enable latency-hiding scheduler for XLA:GPU")); + flag_list->push_back(tsl::Flag( + "xla_partitioning_algorithm", setter_for_xla_partitioning_algorithm, + DebugOptions::PartitioningAlgorithm_Name( + debug_options->xla_partitioning_algorithm()), + "The partitioning algorithm to be used in the PartitionAssignment pass")); +} // NOLINT(readability/fn_size) +// Allocates flag_values and flag_objects; this function must not be called more +// than once - its call done via call_once. +static void AllocateFlags(DebugOptions* defaults) { + if (defaults == nullptr) { + defaults = new DebugOptions(DefaultDebugOptionsIgnoringFlags()); + } + flag_values = defaults; + flag_objects = new std::vector(); + MakeDebugOptionsFlags(flag_objects, flag_values); ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", *flag_objects); -} // NOLINT(readability/fn_size) +} -void AppendDebugOptionsFlags(std::vector* flag_list) { - absl::call_once(flags_init, &AllocateFlags); +void AppendDebugOptionsFlags(std::vector* flag_list, + DebugOptions* debug_options) { + absl::call_once(flags_init, &AllocateFlags, debug_options); flag_list->insert(flag_list->end(), flag_objects->begin(), flag_objects->end()); } xla::DebugOptions GetDebugOptionsFromFlags() { - absl::call_once(flags_init, &AllocateFlags); + absl::call_once(flags_init, &AllocateFlags, nullptr); return *flag_values; } void ResetThreadLocalFuel() { - absl::call_once(flags_init, &AllocateFlags); + absl::call_once(flags_init, &AllocateFlags, nullptr); thread_fuel = std::make_unique< absl::node_hash_map>>(); @@ -803,7 +912,7 @@ void ResetThreadLocalFuel() { } bool ConsumeFuel(absl::string_view pass, bool* just_ran_out) { - absl::call_once(flags_init, &AllocateFlags); + absl::call_once(flags_init, &AllocateFlags, nullptr); if (just_ran_out != nullptr) { *just_ran_out = false; } diff --git a/tensorflow/compiler/xla/debug_options_flags.h b/tensorflow/compiler/xla/debug_options_flags.h index 2669418d482..1b2ad1c6d65 100644 --- a/tensorflow/compiler/xla/debug_options_flags.h +++ b/tensorflow/compiler/xla/debug_options_flags.h @@ -25,8 +25,17 @@ limitations under the License. namespace xla { -// Appends flag definitions for debug options to flag_list. -void AppendDebugOptionsFlags(std::vector* flag_list); +// Construct flags which write to the debug_options proto when parsed. Existing +// contents of debug_options is used as the default. Can be called multiple +// times. 
+void MakeDebugOptionsFlags(std::vector* flag_list, + DebugOptions* debug_options); + +// Appends flag definitions for debug options to flag_list. Existing +// contents of debug_options is used as the default. If debug_options is null, +// uses global defaults. Modifies global state on first call. +void AppendDebugOptionsFlags(std::vector* flag_list, + DebugOptions* debug_options = nullptr); // Fetches a DebugOptions proto message from flags provided to the program. // Flags must be registered with the flags parser using AppendDebugOptionsFlags diff --git a/tensorflow/compiler/xla/examples/axpy/BUILD b/tensorflow/compiler/xla/examples/axpy/BUILD new file mode 100644 index 00000000000..a2e266481dd --- /dev/null +++ b/tensorflow/compiler/xla/examples/axpy/BUILD @@ -0,0 +1,29 @@ +load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_test") + +xla_cc_test( + name = "stablehlo_compile_test", + srcs = ["stablehlo_compile_test.cc"], + data = ["stablehlo_axpy.mlir"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/pjrt:local_device_state", + "//tensorflow/compiler/xla/pjrt:mlir_to_hlo", + "//tensorflow/compiler/xla/pjrt:pjrt_stream_executor_client", + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/tsl/lib/core:status_test_util", + "//tensorflow/tsl/platform:env", + "//tensorflow/tsl/platform:path", + "//tensorflow/tsl/platform:statusor", + "//tensorflow/tsl/platform:test", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@stablehlo//:register", + ], +) diff --git a/tensorflow/compiler/xla/examples/axpy/README.md b/tensorflow/compiler/xla/examples/axpy/README.md new file mode 100644 index 00000000000..397dd21c8fb --- /dev/null +++ b/tensorflow/compiler/xla/examples/axpy/README.md @@ -0,0 +1,221 @@ +# Compile a StableHLO program with XLA + +This tutorial and the code in this directory shows how to write a simple +StableHLO program and then compile it with XLA. The purpose is simply to +show how XLA can injest a StableHLO program and produce an executable +that's compatible with the local device. As such, the program is very +simple: $\alpha x+y$ ("axpy"). + +The process includes just a few steps: + +1. Construct a StableHLO program using the StableHLO dialect. +2. Tell XLA to create a "computation" based on this program. In this example, + we will use PjRt (Pretty much just another Runtime) to achieve that. +3. Run the compiled executable with some inputs to compute results. + +All the code is already provided in this directory, which you can build and +run using the steps at the end of this page. + +## 1. Create the StableHLO program + +We'll define the computation axpy as a StableHLO program, using an +[MLIR](https://mlir.llvm.org/) file in the +[StableHLO](https://github.com/openxla/stablehlo) dialect. + +It can be helpful to consider the computation as a graph, where each node is an +operation (an "op" or "HLO" which means "high-level operation") and the graph +edges are the data flow between operations. 
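+For example, with alpha = 3.14, x = [1, 2, 3, 4], and
+y = [10.5, 20.5, 30.5, 40.5] (the same inputs the test in this directory
+feeds in), axpy evaluates to alpha * x + y = [13.64, 26.78, 39.92, 53.06].
+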
So the graph for axpy looks like +this: + +```mermaid +graph TD + p0(alpha f32) --> mul(Multiply 4xf32) + p1(x 4xf32) --> mul --> add(Add 4xf32) + p2(y 4xf32) --> add +``` + +And here's how we define the program using MLIR (in the StableHLO dialect): + +```mlir +func.func @main( + %alpha: tensor, %x: tensor<4xf32>, %y: tensor<4xf32> +) -> tensor<4xf32> { + %0 = stablehlo.broadcast_in_dim %alpha, dims = [] + : (tensor) -> tensor<4xf32> + %1 = stablehlo.multiply %0, %x : tensor<4xf32> + %2 = stablehlo.add %1, %y : tensor<4xf32> + func.return %2: tensor<4xf32> +} +``` + +This code is in [`stablehlo_axpy.mlir`](stablehlo_axpy.mlir). + +**Note:** StableHLO expresses broadcasting explicitly, so we use +`"stablehlo.broadcast_in_dim"` to broadcast our scalar to a rank-1 tensor. + +## 2. Compile the StableHLO program + +Our program for this tutorial is set up as a test in +[`stablehlo_compile_test.cc`](stablehlo_compile_test.cc). In this file, +you'll see that we first set up a `PjRtStreamExecutorClient` that +allows us to compile our StableHLO program: + +```c++ +// Setup client +LocalClient* local_client = xla::ClientLibrary::LocalClientOrDie(); + +// Retrieve the "platform" we intend to execute the computation on. The +// concept of "platform" in XLA abstracts entirely everything need to +// interact with some hardware (compiler, runtime, etc.). New HW vendor +// plugs into XLA by registering a new platform with a different string +// key. For example for an Nvidia GPU change the following to: +// PlatformUtil::GetPlatform("CUDA")); +TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform, + PlatformUtil::GetPlatform("cpu")); +se::StreamExecutorConfig config; +config.ordinal = 0; +TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, + platform->GetExecutor(config)); + +// LocalDeviceState and PjRtStreamExecutorDevice describes the state of a +// device which can do computation or transfer buffers. Could represent a GPU +// or accelerator, but we'll use the CPU for this example. +auto device_state = std::make_unique( + executor, local_client, LocalDeviceState::kSynchronous, + /*max_inflight_computations=*/32, + /*allow_event_reuse=*/false, /*use_callback_stream=*/false); +auto device = std::make_unique( + 0, std::move(device_state), "cpu"); +std::vector> devices; +devices.emplace_back(std::move(device)); + +// The PjRtStreamExecutorClient will allow us to compile and execute +// computations on the device we just configured. +auto pjrt_se_client = PjRtStreamExecutorClient( + "cpu", local_client, std::move(devices), /*process_index=*/0, + /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr, + /*should_stage_host_to_device_transfers=*/false, + /*gpu_run_options=*/nullptr); +``` + +Then we read the StableHLO program from our MLIR file into a string: + +```c++ +// Read StableHLO program to string +std::string program_path = tsl::io::JoinPath( + tsl::testing::XlaSrcRoot(), "examples", "axpy", "stablehlo_axpy.mlir"); +std::string program_string; + +TF_ASSERT_OK( + tsl::ReadFileToString(tsl::Env::Default(), program_path, &program_string)); +``` + +In order to parse the StableHLO program, we must first register the appropriate +MLIR dialects: + +```c++ +// Register MLIR dialects necessary to parse our program. In our case this is +// just the Func dialect and StableHLO. +mlir::DialectRegistry dialects; +dialects.insert(); +mlir::stablehlo::registerAllDialects(dialects); + +// Parse StableHLO program. 
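+// (If parsing fails, parseSourceString typically returns a null OwningOpRef,
+// so code outside of a test would check the result before compiling.)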
+auto ctx = std::make_unique(dialects); +mlir::OwningOpRef program = + mlir::parseSourceString(program_string, ctx.get()); +``` + +Now that we've set up our client and parsed the StableHLO program we can +compile it to an executable: + +```c++ +// Use our client to compile our StableHLO program to an executable. +TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr executable, + pjrt_se_client.Compile(*program, CompileOptions{})); +``` + +## 3. Execute the computation + +Finally, in [`stablehlo_compile_test.cc`](stablehlo_compile_test.cc), +we can feed the executable some inputs for the three arguments and +compute the results: + +```c++ +// Create inputs to our computation. +auto alpha_literal = xla::LiteralUtil::CreateR0(3.14f); +auto x_literal = xla::LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f, 4.0f}); +auto y_literal = + xla::LiteralUtil::CreateR1({10.5f, 20.5f, 30.5f, 40.5f}); + +// Get the host device. +PjRtDevice* cpu = pjrt_se_client.devices()[0]; + +// Transfer our literals to buffers. If we were using a GPU, these buffers +// would correspond to device memory. +TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr alpha, + pjrt_se_client.BufferFromHostLiteral(alpha_literal, cpu)); +TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr x, + pjrt_se_client.BufferFromHostLiteral(x_literal, cpu)); +TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr y, + pjrt_se_client.BufferFromHostLiteral(y_literal, cpu)); + +// Do our computation. +TF_ASSERT_OK_AND_ASSIGN( + std::vector>> axpy_result, + executable->Execute({{alpha.get(), x.get(), y.get()}}, /*options=*/{})); + +// Convert result buffer back to literal. +TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr axpy_result_literal, + axpy_result[0][0]->ToLiteralSync()); + +// Check to make sure that our results match what we expect. +xla::LiteralTestUtil::ExpectR1Near({13.64f, 26.78f, 39.92f, 53.06f}, + *axpy_result_literal, + xla::ErrorSpec(0.01f)); +``` + +## 4. Build and run the code + +You can build and run this example as follows using +[Bazelisk](https://github.com/bazelbuild/bazelisk#readme) or +[Bazel](https://bazel.build/) (run from within `xla/examples/axpy/`): + +```sh +bazelisk test :stablehlo_compile_test --test_output=all --nocheck_visibility +``` + +Sample output from the test should look like this: + +```sh +==================== Test output for //xla/examples/axpy:stablehlo_compile_test: +[==========] Running 1 test from 1 test suite. +[----------] Global test environment set-up. +[----------] 1 test from StableHloAxpyTest +[ RUN ] StableHloAxpyTest.LoadAndRunCpuExecutable +Loaded StableHLO program from xla/examples/axpy/stablehlo_axpy.mlir: +func.func @main( + %alpha: tensor, %x: tensor<4xf32>, %y: tensor<4xf32> +) -> tensor<4xf32> { + %0 = stablehlo.broadcast_in_dim %alpha, dims = [] + : (tensor) -> tensor<4xf32> + %1 = stablehlo.multiply %0, %x : tensor<4xf32> + %2 = stablehlo.add %1, %y : tensor<4xf32> + func.return %2: tensor<4xf32> +} + +Computation inputs: + alpha:f32[] 3.14 + x:f32[4] {1, 2, 3, 4} + y:f32[4] {10.5, 20.5, 30.5, 40.5} +Computation output: f32[4] {13.64, 26.78, 39.920002, 53.06} +[ OK ] StableHloAxpyTest.LoadAndRunCpuExecutable (264 ms) +[----------] 1 test from StableHloAxpyTest (264 ms total) + +[----------] Global test environment tear-down +[==========] 1 test from 1 test suite ran. (264 ms total) +[ PASSED ] 1 test. 
+``` diff --git a/tensorflow/compiler/xla/examples/axpy/stablehlo_axpy.mlir b/tensorflow/compiler/xla/examples/axpy/stablehlo_axpy.mlir new file mode 100644 index 00000000000..7f4205999f8 --- /dev/null +++ b/tensorflow/compiler/xla/examples/axpy/stablehlo_axpy.mlir @@ -0,0 +1,9 @@ +func.func @main( + %alpha: tensor, %x: tensor<4xf32>, %y: tensor<4xf32> +) -> tensor<4xf32> { + %0 = stablehlo.broadcast_in_dim %alpha, dims = [] + : (tensor) -> tensor<4xf32> + %1 = stablehlo.multiply %0, %x : tensor<4xf32> + %2 = stablehlo.add %1, %y : tensor<4xf32> + func.return %2: tensor<4xf32> +} diff --git a/tensorflow/compiler/xla/examples/axpy/stablehlo_compile_test.cc b/tensorflow/compiler/xla/examples/axpy/stablehlo_compile_test.cc new file mode 100644 index 00000000000..da0a9dba19b --- /dev/null +++ b/tensorflow/compiler/xla/examples/axpy/stablehlo_compile_test.cc @@ -0,0 +1,148 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/Parser/Parser.h" // from @llvm-project +#include "stablehlo/dialect/Register.h" // from @stablehlo +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/pjrt/local_device_state.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/tsl/lib/core/status_test_util.h" +#include "tensorflow/tsl/platform/env.h" +#include "tensorflow/tsl/platform/path.h" +#include "tensorflow/tsl/platform/statusor.h" +#include "tensorflow/tsl/platform/test.h" + +namespace xla { +namespace { + +TEST(StableHloAxpyTest, LoadAndRunCpuExecutable) { + // Setup client + LocalClient* local_client = xla::ClientLibrary::LocalClientOrDie(); + + // Retrieve the "platform" we intend to execute the computation on. The + // concept of "platform" in XLA abstracts entirely everything needed to + // interact with some hardware (compiler, runtime, etc.). New HW vendor + // plugs into XLA by registering a new platform with a different string + // key. For example for an Nvidia GPU change the following to: + // PlatformUtil::GetPlatform("CUDA")); + TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform, + PlatformUtil::GetPlatform("cpu")); + se::StreamExecutorConfig config; + config.ordinal = 0; + TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, + platform->GetExecutor(config)); + + // LocalDeviceState and PjRtStreamExecutorDevice describes the state of a + // device which can do computation or transfer buffers. 
This could represent a + // GPU or accelerator, but we'll use the CPU for this example. + auto device_state = std::make_unique( + executor, local_client, LocalDeviceState::kSynchronous, + /*max_inflight_computations=*/32, + /*allow_event_reuse=*/false, /*use_callback_stream=*/false); + auto device = std::make_unique( + 0, std::move(device_state), "cpu"); + std::vector> devices; + devices.emplace_back(std::move(device)); + + // The PjRtStreamExecutorClient will allow us to compile and execute + // computations on the device we just configured. + auto pjrt_se_client = PjRtStreamExecutorClient( + "cpu", local_client, std::move(devices), /*process_index=*/0, + /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr, + /*should_stage_host_to_device_transfers=*/false, + /*gpu_run_options=*/nullptr); + + // Read StableHLO program to string. + std::string program_path = tsl::io::JoinPath( + tsl::testing::XlaSrcRoot(), "examples", "axpy", "stablehlo_axpy.mlir"); + std::string program_string; + + TF_ASSERT_OK(tsl::ReadFileToString(tsl::Env::Default(), program_path, + &program_string)); + + std::cerr << "Loaded StableHLO program from " << program_path << ":\n" + << program_string << std::endl; + + // Register MLIR dialects necessary to parse our program. In our case this is + // just the Func dialect and StableHLO. + mlir::DialectRegistry dialects; + dialects.insert(); + mlir::stablehlo::registerAllDialects(dialects); + + // Parse StableHLO program. + auto ctx = std::make_unique(dialects); + mlir::OwningOpRef program = + mlir::parseSourceString(program_string, ctx.get()); + + // Use our client to compile our StableHLO program to an executable. + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr executable, + pjrt_se_client.Compile(*program, CompileOptions{})); + + // Create inputs to our computation. + auto alpha_literal = xla::LiteralUtil::CreateR0(3.14f); + auto x_literal = xla::LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f, 4.0f}); + auto y_literal = + xla::LiteralUtil::CreateR1({10.5f, 20.5f, 30.5f, 40.5f}); + + std::cerr << "Computation inputs:" << std::endl; + std::cerr << "\talpha:" << alpha_literal << std::endl; + std::cerr << "\tx:" << x_literal << std::endl; + std::cerr << "\ty:" << y_literal << std::endl; + + // Get the host device. + PjRtDevice* cpu = pjrt_se_client.devices()[0]; + + // Transfer our literals to buffers. If we were using a GPU, these buffers + // would correspond to device memory. + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr alpha, + pjrt_se_client.BufferFromHostLiteral(alpha_literal, cpu)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr x, + pjrt_se_client.BufferFromHostLiteral(x_literal, cpu)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr y, + pjrt_se_client.BufferFromHostLiteral(y_literal, cpu)); + + // Do our computation. + TF_ASSERT_OK_AND_ASSIGN( + std::vector>> axpy_result, + executable->Execute({{alpha.get(), x.get(), y.get()}}, /*options=*/{})); + + // Convert result buffer back to literal. + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr axpy_result_literal, + axpy_result[0][0]->ToLiteralSync()); + + // Check to make sure that our results match what we expect. 
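+  // Elementwise, alpha * x + y is 3.14 * 1 + 10.5 = 13.64,
+  // 3.14 * 2 + 20.5 = 26.78, 3.14 * 3 + 30.5 = 39.92, and
+  // 3.14 * 4 + 40.5 = 53.06, which ExpectR1Near verifies below within the
+  // 0.01 tolerance given by ErrorSpec.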
+ xla::LiteralTestUtil::ExpectR1Near({13.64f, 26.78f, 39.92f, 53.06f}, + *axpy_result_literal, + xla::ErrorSpec(0.01f)); + std::cerr << "Computation output: " << *axpy_result_literal << std::endl; +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc index 2827c08ef4e..21264536631 100644 --- a/tensorflow/compiler/xla/executable_run_options.cc +++ b/tensorflow/compiler/xla/executable_run_options.cc @@ -72,6 +72,16 @@ stream_executor::Stream* ExecutableRunOptions::host_to_device_stream() const { return host_to_device_stream_; } +ExecutableRunOptions& ExecutableRunOptions::set_device_to_host_stream( + stream_executor::Stream* stream) { + device_to_host_stream_ = stream; + return *this; +} + +stream_executor::Stream* ExecutableRunOptions::device_to_host_stream() const { + return device_to_host_stream_; +} + ExecutableRunOptions& ExecutableRunOptions::set_intra_op_thread_pool( const Eigen::ThreadPoolDevice* intra_op_thread_pool) { intra_op_thread_pool_ = intra_op_thread_pool; diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h index 0b0a4c760c1..8a0aa19dc06 100644 --- a/tensorflow/compiler/xla/executable_run_options.h +++ b/tensorflow/compiler/xla/executable_run_options.h @@ -16,7 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_ #define TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_ +#include #include +#include #include // These classes are forward declared so that ExecutableRunOptions can be linked @@ -25,18 +27,33 @@ limitations under the License. // need to be linked). namespace stream_executor { class Stream; +class Event; class Platform; class DeviceMemoryAllocator; +class DeviceMemoryBase; } // namespace stream_executor namespace Eigen { struct ThreadPoolDevice; } // namespace Eigen +namespace tsl { +class Status; +template +class StatusOr; +template +class AsyncValueRef; +} // namespace tsl + namespace xla { +using ::tsl::Status; // TENSORFLOW_STATUS_OK +using ::tsl::StatusOr; // TENSORFLOW_STATUS_OK + class DeviceAssignment; class ExecutionProfile; +class Shape; + namespace gpu { class GpuExecutableRunOptions; } // namespace gpu @@ -77,6 +94,22 @@ class RunId { using ThenExecuteFunction = std::function)>; +// Callback for sending device buffer to a channel. Returned event will be +// recorded on a `stream` once the send operation is completed and data was +// copied from the `src` memory. +using SendDeviceMemoryFunction = + std::function>( + int64_t channel_id, stream_executor::Stream* stream, const Shape& shape, + const stream_executor::DeviceMemoryBase& src)>; + +// Callback for receiving device buffer from a channel. Returned event will be +// recorded on a `stream` once the recv operation is completed and data was +// copied into the `dst` memory. +using RecvDeviceMemoryFunction = + std::function>( + int64_t channel_id, stream_executor::Stream* stream, const Shape& shape, + stream_executor::DeviceMemoryBase* dst)>; + // Class containing options for running a LocalExecutable. class ExecutableRunOptions { public: @@ -99,13 +132,21 @@ class ExecutableRunOptions { ExecutableRunOptions& set_stream(stream_executor::Stream* stream); stream_executor::Stream* stream() const; - // If set, this is the stream to perform any pre-computation transfers on. - // The platform of the stream must match the platform the executable was - // built for. 
A value of nullptr indicates the option has not been set. + // If set, this is the stream to perform host to device transfers on (e.g. any + // pre-computation transfers). The platform of the stream must match the + // platform the executable was built for. A value of nullptr indicates the + // option has not been set. ExecutableRunOptions& set_host_to_device_stream( stream_executor::Stream* stream); stream_executor::Stream* host_to_device_stream() const; + // If set, this is the stream to perform device to host transfers on. + // The platform of the stream must match the platform the executable was + // built for. A value of nullptr indicates the option has not been set. + ExecutableRunOptions& set_device_to_host_stream( + stream_executor::Stream* stream); + stream_executor::Stream* device_to_host_stream() const; + // Sets the thread pool device on which to run Eigen subcomputations. // // This field must be set for XLA:CPU models that call Eigen routines, but may @@ -148,6 +189,26 @@ class ExecutableRunOptions { return then_execute_function_; } + // See documentation on SendDeviceMemoryFunction. + ExecutableRunOptions& set_send_device_memory_function( + SendDeviceMemoryFunction* f) { + send_device_memory_function_ = f; + return *this; + } + SendDeviceMemoryFunction* send_device_memory_function() const { + return send_device_memory_function_; + } + + // See documentation on RecvDeviceMemoryFunction. + ExecutableRunOptions& set_recv_device_memory_function( + RecvDeviceMemoryFunction* f) { + recv_device_memory_function_ = f; + return *this; + } + RecvDeviceMemoryFunction* recv_device_memory_function() const { + return recv_device_memory_function_; + } + // GPU-backend specific options. These are kept out-of-line to avoid bloating // the size of this dependency for CPU-only AOT builds. ExecutableRunOptions& set_gpu_executable_run_options( @@ -163,8 +224,11 @@ class ExecutableRunOptions { ExecutionProfile* execution_profile_ = nullptr; int rng_seed_ = 0; int32_t launch_id_ = 0; + stream_executor::Stream* device_to_host_stream_ = nullptr; stream_executor::Stream* host_to_device_stream_ = nullptr; ThenExecuteFunction* then_execute_function_ = nullptr; + SendDeviceMemoryFunction* send_device_memory_function_ = nullptr; + RecvDeviceMemoryFunction* recv_device_memory_function_ = nullptr; RunId run_id_; const gpu::GpuExecutableRunOptions* gpu_executable_run_options_ = nullptr; }; diff --git a/tensorflow/compiler/mlir/xla/experimental/conv_emitter/BUILD b/tensorflow/compiler/xla/experimental/conv_emitter/BUILD similarity index 91% rename from tensorflow/compiler/mlir/xla/experimental/conv_emitter/BUILD rename to tensorflow/compiler/xla/experimental/conv_emitter/BUILD index 30920e1754e..a786c5c89b6 100644 --- a/tensorflow/compiler/mlir/xla/experimental/conv_emitter/BUILD +++ b/tensorflow/compiler/xla/experimental/conv_emitter/BUILD @@ -1,11 +1,11 @@ # Description: # MLIR-GPU-specific convolution in XLA service implementation. 
-load("//tensorflow:tensorflow.default.bzl", "filegroup") -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") -load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") +load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_test") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [":friends"], licenses = ["notice"], ) @@ -32,7 +32,7 @@ cc_library( ":conv_emitter_transforms", "//tensorflow/compiler/xla:permutation_util", "//tensorflow/compiler/xla:window_util", - "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/service/llvm_ir:llvm_type_conversion_util", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", @@ -64,7 +64,7 @@ cc_library( ], ) -tf_cc_test( +xla_cc_test( name = "conv_emitter_test", srcs = ["conv_emitter_test.cc"], deps = [ diff --git a/tensorflow/compiler/mlir/xla/experimental/conv_emitter/conv_emitter.cc b/tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter.cc similarity index 99% rename from tensorflow/compiler/mlir/xla/experimental/conv_emitter/conv_emitter.cc rename to tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter.cc index fe9157feb8d..c5af2884e0a 100644 --- a/tensorflow/compiler/mlir/xla/experimental/conv_emitter/conv_emitter.cc +++ b/tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter.cc @@ -25,7 +25,7 @@ limitations under the License. // * Use milr::AffineExpr to analyze all accesses. It aims to algorithmically // find memory access strategies for given input layouts and tiling configs. -#include "tensorflow/compiler/mlir/xla/experimental/conv_emitter/conv_emitter.h" +#include "tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter.h" #include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" @@ -40,7 +40,7 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project -#include "tensorflow/compiler/mlir/xla/experimental/conv_emitter/conv_emitter_transforms.h" +#include "tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter_transforms.h" #include "tensorflow/compiler/xla/permutation_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_type_conversion_util.h" #include "tensorflow/compiler/xla/window_util.h" diff --git a/tensorflow/compiler/mlir/xla/experimental/conv_emitter/conv_emitter.h b/tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter.h similarity index 85% rename from tensorflow/compiler/mlir/xla/experimental/conv_emitter/conv_emitter.h rename to tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter.h index 12270e9da03..a380800b2f7 100644 --- a/tensorflow/compiler/mlir/xla/experimental/conv_emitter/conv_emitter.h +++ b/tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_XLA_EXPERIMENTAL_CONV_EMITTER_CONV_EMITTER_H_ -#define TENSORFLOW_COMPILER_MLIR_XLA_EXPERIMENTAL_CONV_EMITTER_CONV_EMITTER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_EXPERIMENTAL_CONV_EMITTER_CONV_EMITTER_H_ +#define TENSORFLOW_COMPILER_XLA_EXPERIMENTAL_CONV_EMITTER_CONV_EMITTER_H_ #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" namespace xla { namespace experimental { @@ -46,4 +46,4 @@ Status ConvIsImplemented(const HloInstruction* conv); } // namespace experimental } // namespace xla -#endif // TENSORFLOW_COMPILER_MLIR_XLA_EXPERIMENTAL_CONV_EMITTER_CONV_EMITTER_H_ +#endif // TENSORFLOW_COMPILER_XLA_EXPERIMENTAL_CONV_EMITTER_CONV_EMITTER_H_ diff --git a/tensorflow/compiler/mlir/xla/experimental/conv_emitter/conv_emitter_test.cc b/tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter_test.cc similarity index 97% rename from tensorflow/compiler/mlir/xla/experimental/conv_emitter/conv_emitter_test.cc rename to tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter_test.cc index bbae2513ff5..8c66e6d5b82 100644 --- a/tensorflow/compiler/mlir/xla/experimental/conv_emitter/conv_emitter_test.cc +++ b/tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/mlir/xla/experimental/conv_emitter/conv_emitter.h" +#include "tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter.h" #include @@ -74,7 +74,7 @@ std::string CompileHloConvAndGetMlir(absl::string_view hlo_text) { mlir::PassManager pm(mlir_module->getContext()); pm.addPass(mlir::createLowerAffinePass()); pm.addPass(mlir::createConvertSCFToCFPass()); - pm.addPass(mlir::createMemRefToLLVMConversionPass()); + pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass()); pm.addPass(mlir::createConvertFuncToLLVMPass()); CHECK(mlir::succeeded(pm.run(*mlir_module))); } diff --git a/tensorflow/compiler/mlir/xla/experimental/conv_emitter/conv_emitter_transforms.cc b/tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter_transforms.cc similarity index 98% rename from tensorflow/compiler/mlir/xla/experimental/conv_emitter/conv_emitter_transforms.cc rename to tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter_transforms.cc index 400a97d73d2..91268062959 100644 --- a/tensorflow/compiler/mlir/xla/experimental/conv_emitter/conv_emitter_transforms.cc +++ b/tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter_transforms.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/mlir/xla/experimental/conv_emitter/conv_emitter_transforms.h" +#include "tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter_transforms.h" #include diff --git a/tensorflow/compiler/mlir/xla/experimental/conv_emitter/conv_emitter_transforms.h b/tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter_transforms.h similarity index 93% rename from tensorflow/compiler/mlir/xla/experimental/conv_emitter/conv_emitter_transforms.h rename to tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter_transforms.h index 5ae7d7473b2..97c44daa52f 100644 --- a/tensorflow/compiler/mlir/xla/experimental/conv_emitter/conv_emitter_transforms.h +++ b/tensorflow/compiler/xla/experimental/conv_emitter/conv_emitter_transforms.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_XLA_EXPERIMENTAL_CONV_EMITTER_CONV_EMITTER_TRANSFORMS_H_ -#define TENSORFLOW_COMPILER_MLIR_XLA_EXPERIMENTAL_CONV_EMITTER_CONV_EMITTER_TRANSFORMS_H_ +#ifndef TENSORFLOW_COMPILER_XLA_EXPERIMENTAL_CONV_EMITTER_CONV_EMITTER_TRANSFORMS_H_ +#define TENSORFLOW_COMPILER_XLA_EXPERIMENTAL_CONV_EMITTER_CONV_EMITTER_TRANSFORMS_H_ #include "absl/types/span.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project @@ -99,4 +99,4 @@ void SinkPerfectlyNestedLoops(llvm::MutableArrayRef loops, } // namespace experimental } // namespace xla -#endif // TENSORFLOW_COMPILER_MLIR_XLA_EXPERIMENTAL_CONV_EMITTER_CONV_EMITTER_TRANSFORMS_H_ +#endif // TENSORFLOW_COMPILER_XLA_EXPERIMENTAL_CONV_EMITTER_CONV_EMITTER_TRANSFORMS_H_ diff --git a/tensorflow/compiler/mlir/xla/experimental/conv_emitter/g3doc/conv_emitter.md b/tensorflow/compiler/xla/experimental/conv_emitter/g3doc/conv_emitter.md similarity index 100% rename from tensorflow/compiler/mlir/xla/experimental/conv_emitter/g3doc/conv_emitter.md rename to tensorflow/compiler/xla/experimental/conv_emitter/g3doc/conv_emitter.md diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/frontend_attributes.cc similarity index 58% rename from tensorflow/compiler/xla/service/hlo_opcode.h rename to tensorflow/compiler/xla/frontend_attributes.cc index ed040c833bf..a1f6ad58c27 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/frontend_attributes.cc @@ -12,18 +12,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "tensorflow/compiler/xla/frontend_attributes.h" -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_OPCODE_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_OPCODE_H_ +namespace xla { -#include -#include -#include +void SetDisjointReadWriteRegionsAttr(HloInstruction* instruction) { + FrontendAttributes attrs; + (*attrs.mutable_map())[xla::kXlaDisjointReadWriteRegions] = "true"; + instruction->add_frontend_attributes(attrs); +} -#include "tensorflow/compiler/xla/comparison_util.h" -#include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" +bool HasDisjointReadWriteRegionsAttr(HloInstruction* instruction) { + return instruction->frontend_attributes().map().contains( + xla::kXlaDisjointReadWriteRegions); +} -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_OPCODE_H_ +} // namespace xla diff --git a/tensorflow/compiler/xla/frontend_attributes.h b/tensorflow/compiler/xla/frontend_attributes.h new file mode 100644 index 00000000000..32ba357e533 --- /dev/null +++ b/tensorflow/compiler/xla/frontend_attributes.h @@ -0,0 +1,38 @@ +#ifndef TENSORFLOW_COMPILER_XLA_FRONTEND_ATTRIBUTES_H_ +#define TENSORFLOW_COMPILER_XLA_FRONTEND_ATTRIBUTES_H_ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" + +namespace xla { + +// Attribute which indicates that an in-place instruction has disjoint read +// and write regions w.r.t aliased input/output buffers. +inline constexpr char kXlaDisjointReadWriteRegions[] = + "_xla_disjoint_read_write_regions"; + +// Set frontend attribute on 'instruction' which indices that in-place +// 'instruction' has disjoint read/write buffer regions. +void SetDisjointReadWriteRegionsAttr(HloInstruction* instruction); + +// Returns 'true' if in-place 'instruction' has the kXlaDisjointReadWriteRegions +// frontend attribute set (returns false otherwise). 
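+// Illustrative use of these helpers, assuming a pass has already proven that
+// an in-place instruction `instr` reads and writes disjoint regions:
+//
+//   if (!HasDisjointReadWriteRegionsAttr(instr)) {
+//     SetDisjointReadWriteRegionsAttr(instr);
+//   }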
+bool HasDisjointReadWriteRegionsAttr(HloInstruction* instruction); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_FRONTEND_ATTRIBUTES_H_ diff --git a/tensorflow/compiler/xla/g3doc/images/batch_group_counts.svg b/tensorflow/compiler/xla/g3doc/images/batch_group_counts.svg new file mode 100644 index 00000000000..799e8f895a2 --- /dev/null +++ b/tensorflow/compiler/xla/g3doc/images/batch_group_counts.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md index 26221082a06..bf9fe3c99ee 100644 --- a/tensorflow/compiler/xla/g3doc/operation_semantics.md +++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md @@ -228,8 +228,11 @@ size `m` and spatial sizes `w` and `h`): \frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \left( \nabla y_{ijkl} \frac{x_{ijkl} - \mu_l}{\sigma^2_l+\epsilon} \right) \\\\ +d_l&= +\frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \nabla y_{ijkl} +\\\\ \nabla x_{ijkl} &= \frac{\gamma_{l}}{\sqrt{\sigma^2_{l}+\epsilon}} -\left( \nabla y_{ijkl} - \mathrm{mean}(\nabla y) - c_l (x_{ijkl} - \mu_{l}) +\left( \nabla y_{ijkl} - d_l - c_l (x_{ijkl} - \mu_{l}) \right) \\\\ \nabla \gamma_l &= \sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \left( \nabla y_{ijkl} @@ -861,6 +864,16 @@ The output shape has these dimensions, in this order: * `spatial_dims`: One value for each valid placement of the convolutional window. +
+ +
+ +The figure above shows how `batch_group_count` field works. Effectively, we +slice each lhs batch into `batch_group_count` groups, and do the same for the +output features. Then, for each of these groups we do pairwise convolutions and +concatenate the output along the output feature dimension. The operational +semantics of all the other dimensions (feature and spatial) remain the same. + The valid placements of the convolutional window are determined by the strides and the size of the base area after padding. @@ -1388,6 +1401,8 @@ using the comparison operator of the element type of `operand`. `Cbrt(operand)` Element-wise cubic root operation `x -> cbrt(x)`. +`Tan(operand)` Element-wise tangent `x -> tan(x)`. + `Tanh(operand)` Element-wise hyperbolic tangent `x -> tanh(x)`. `Round(operand)` Element-wise rounding, ties away from zero. @@ -1929,7 +1944,7 @@ XlaOp for the received data. The client API of `Recv` operation represents synchronous communication. However, the instruction is internally decomposed into 2 HLO instructions (`Recv` and `RecvDone`) to enable asynchronous data transfers. See also -[`HloInstruction::CreateRecv` and `HloInstruction::CreateRecvDone`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/hlo_instruction.h). +[`HloInstruction::CreateRecv` and `HloInstruction::CreateRecvDone`](https://www.tensorflow.org/code/tensorflow/compiler/xla/hlo/ir/hlo_instruction.h). `Recv(const Shape& shape, int64 channel_id)` @@ -2785,7 +2800,7 @@ that shares the same channel handle. Does not return any data. Similar to the `Recv` operation, the client API of `Send` operation represents synchronous communication, and is internally decomposed into 2 HLO instructions (`Send` and `SendDone`) to enable asynchronous data transfers. See also -[`HloInstruction::CreateSend` and `HloInstruction::CreateSendDone`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/hlo_instruction.h). +[`HloInstruction::CreateSend` and `HloInstruction::CreateSendDone`](https://www.tensorflow.org/code/tensorflow/compiler/xla/hlo/ir/hlo_instruction.h). `Send(HloInstruction operand, int64 channel_id)` diff --git a/tensorflow/compiler/xla/g3doc/tutorials/autoclustering_xla.ipynb b/tensorflow/compiler/xla/g3doc/tutorials/autoclustering_xla.ipynb index b1722f0d8da..88f94c2bbc3 100644 --- a/tensorflow/compiler/xla/g3doc/tutorials/autoclustering_xla.ipynb +++ b/tensorflow/compiler/xla/g3doc/tutorials/autoclustering_xla.ipynb @@ -70,7 +70,7 @@ "source": [ "This tutorial trains a TensorFlow model to classify the [CIFAR-10](https://en.wikipedia.org/wiki/CIFAR-10) dataset, and we compile it using XLA.\n", "\n", - "Load and normalize the dataset using the [TensorFlow Datasets](https://tensorflow.org/datasets) API:" + "You will load and normalize the dataset using the [TensorFlow Datasets (TFDS)](https://tensorflow.org/datasets) API. 
First, install/upgrade TensorFlow and TFDS:" ] }, { @@ -81,7 +81,7 @@ }, "outputs": [], "source": [ - "!pip install tensorflow_datasets" + "!pip install -U -q tensorflow tensorflow_datasets" ] }, { @@ -190,7 +190,7 @@ "outputs": [], "source": [ "def compile_model(model):\n", - " opt = tf.keras.optimizers.RMSprop(learning_rate=0.0001, decay=1e-6)\n", + " opt = tf.keras.optimizers.RMSprop(learning_rate=0.0001)\n", " model.compile(loss='categorical_crossentropy',\n", " optimizer=opt,\n", " metrics=['accuracy'])\n", diff --git a/tensorflow/compiler/xla/g3doc/tutorials/jit_compile.ipynb b/tensorflow/compiler/xla/g3doc/tutorials/jit_compile.ipynb index fb8a2ba0e12..b9967f4e94f 100644 --- a/tensorflow/compiler/xla/g3doc/tutorials/jit_compile.ipynb +++ b/tensorflow/compiler/xla/g3doc/tutorials/jit_compile.ipynb @@ -81,8 +81,7 @@ }, "outputs": [], "source": [ - "import tensorflow as tf\n", - "tf.compat.v1.enable_eager_execution()" + "import tensorflow as tf\n" ] }, { diff --git a/tensorflow/compiler/xla/glob_lit_test.bzl b/tensorflow/compiler/xla/glob_lit_test.bzl new file mode 100644 index 00000000000..8863805fff7 --- /dev/null +++ b/tensorflow/compiler/xla/glob_lit_test.bzl @@ -0,0 +1,117 @@ +# Test definitions for Lit, the LLVM test runner. +# +# This is reusing the LLVM Lit test runner in the interim until the new build +# rules are upstreamed. +# TODO(b/136126535): remove this custom rule. +"""Lit runner globbing test +""" + +load("@bazel_skylib//lib:paths.bzl", "paths") + +# Default values used by the test runner. +_default_test_file_exts = ["mlir", ".pbtxt", ".td"] +_default_driver = "@llvm-project//mlir:run_lit.sh" +_default_size = "small" +_default_tags = [] + +# These are patterns which we should never match, for tests, subdirectories, or +# test input data files. +_ALWAYS_EXCLUDE = [ + "**/LICENSE.txt", + "**/README.txt", + "**/lit.local.cfg", + # Exclude input files that have spaces in their names, since bazel + # cannot cope with such "targets" in the srcs list. + "**/* *", + "**/* */**", +] + +def _run_lit_test(name, data, size, tags, driver, features, exec_properties): + """Runs lit on all tests it can find in `data` under xla/. + + Note that, due to Bazel's hermetic builds, lit only sees the tests that + are included in the `data` parameter, regardless of what other tests might + exist in the directory searched. + + Args: + name: str, the name of the test, including extension. + data: [str], the data input to the test. + size: str, the size of the test. + tags: [str], tags to attach to the test. + driver: str, label of the driver shell script. + Note: use of a custom driver is not currently supported + and specifying a default driver will abort the tests. + features: [str], list of extra features to enable. + exec_properties: may enable things like remote execution. + """ + + # Disable tests on windows for now, to enable testing rest of all xla and mlir. 
+ native.py_test( + name = name, + srcs = ["@llvm-project//llvm:lit"], + tags = tags + ["no_pip", "no_windows"], + args = [ + "xla/" + paths.basename(data[-1]) + " --config-prefix=runlit -v", + ] + features, + data = data + [ + "//tensorflow/compiler/xla:litfiles", + "@llvm-project//llvm:FileCheck", + "@llvm-project//llvm:count", + "@llvm-project//llvm:not", + ], + size = size, + main = "lit.py", + exec_properties = exec_properties, + ) + +def glob_lit_tests( + exclude = [], + test_file_exts = _default_test_file_exts, + default_size = _default_size, + size_override = {}, + data = [], + per_test_extra_data = {}, + default_tags = _default_tags, + tags_override = {}, + driver = _default_driver, + features = [], + exec_properties = {}): + """Creates all plausible Lit tests (and their inputs) under this directory. + + Args: + exclude: [str], paths to exclude (for tests and inputs). + test_file_exts: [str], extensions for files that are tests. + default_size: str, the test size for targets not in "size_override". + size_override: {str: str}, sizes to use for specific tests. + data: [str], additional input data to the test. + per_test_extra_data: {str: [str]}, extra data to attach to a given file. + default_tags: [str], additional tags to attach to the test. + tags_override: {str: str}, tags to add to specific tests. + driver: str, label of the driver shell script. + Note: use of a custom driver is not currently supported + and specifying a default driver will abort the tests. + features: [str], list of extra features to enable. + exec_properties: a dictionary of properties to pass on. + """ + + # Ignore some patterns by default for tests and input data. + exclude = _ALWAYS_EXCLUDE + exclude + + tests = native.glob( + ["*." + ext for ext in test_file_exts], + exclude = exclude, + ) + + # Run tests individually such that errors can be attributed to a specific + # failure. + for curr_test in tests: + # Instantiate this test with updated parameters. + _run_lit_test( + name = curr_test + ".test", + data = data + [curr_test] + per_test_extra_data.get(curr_test, []), + size = size_override.get(curr_test, default_size), + tags = default_tags + tags_override.get(curr_test, []), + driver = driver, + features = features, + exec_properties = exec_properties, + ) diff --git a/tensorflow/compiler/xla/hlo/evaluator/BUILD b/tensorflow/compiler/xla/hlo/evaluator/BUILD index 4d6ccf87e53..f99a903ecdf 100644 --- a/tensorflow/compiler/xla/hlo/evaluator/BUILD +++ b/tensorflow/compiler/xla/hlo/evaluator/BUILD @@ -2,9 +2,10 @@ # XLA evaluator implementation. 
load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") -load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_test") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [":friends"], licenses = ["notice"], ) @@ -27,6 +28,7 @@ cc_library( "hlo_evaluator_typed_visitor_complex64.cc", "hlo_evaluator_typed_visitor_double.cc", "hlo_evaluator_typed_visitor_float.cc", + "hlo_evaluator_typed_visitor_float8.cc", "hlo_evaluator_typed_visitor_half.cc", "hlo_evaluator_typed_visitor_int16.cc", "hlo_evaluator_typed_visitor_int32.cc", @@ -49,13 +51,15 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/hlo/ir:hlo", + "//tensorflow/compiler/xla/service:call_graph", + "//tensorflow/compiler/xla/service:compilation_environments", "//tensorflow/compiler/xla/service:dynamic_dimension_inference", - "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_query", "//tensorflow/compiler/xla/service:pattern_matcher", "//tensorflow/compiler/xla/service:shape_inference", + "//tensorflow/compiler/xla/service:tuple_points_to_analysis", "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", - "//tensorflow/compiler/xla/stream_executor/lib", "//tensorflow/tsl/lib/core:bitmap", "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:logging", @@ -75,7 +79,7 @@ cc_library( ], ) -tf_cc_test( +xla_cc_test( name = "hlo_evaluator_test", srcs = ["hlo_evaluator_test.cc"], deps = [ @@ -92,7 +96,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/service:hlo_element_type_converter", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", diff --git a/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator.cc b/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator.cc index 358a83d33b7..5911a721fdc 100644 --- a/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator.cc @@ -26,6 +26,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/algorithm/container.h" @@ -33,26 +34,29 @@ limitations under the License. 
#include "absl/cleanup/cleanup.h" #include "absl/container/inlined_vector.h" #include "absl/strings/match.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/compilation_environments.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" -#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/stream_executor/lib/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" @@ -179,22 +183,137 @@ StatusOr Compare(const Shape& shape, return std::move(result); } +std::optional GetInstructionStaticValueAsBool( + HloInstruction* instruction) { + HloEvaluator evaluator; + StatusOr static_value = evaluator.Evaluate( + instruction, /*recursively_evaluate_nonconstant_operands=*/true); + if (static_value.ok()) { + return static_value->GetFirstElement(); + } + return std::nullopt; +} + +constexpr absl::string_view kEvalErrorDetailUrl = "EvalErrorDetailUrl"; + +// Use this class to represent the precise details of the error to enable +// special treatment. +enum class EvalErrorDetail : uint32_t { + // The evaluation result depends on dynamic values such as parameters and + // infeed. Therefore, the HLO's value cannot be statically evaluated. + kDynamicValueDependence = 0, +}; + +std::optional ParseEvalErrorDetail(const Status& error) { + auto error_detail = error.GetPayload(kEvalErrorDetailUrl); + if (!error_detail.has_value() && error_detail->empty()) { + return std::nullopt; + } + return static_cast( + absl::little_endian::Load32(error_detail->Flatten().data())); +} + +Status MakeEvalErrorDueToParamOrInfeed(const HloInstruction& eval_instruction) { + Status error = tsl::errors::FailedPrecondition( + "Failed to evaluate instruction (", eval_instruction.name(), + ") since it depends on infeed or parameters to its parent computation (", + eval_instruction.parent()->name(), ")."); + std::string error_payload; + error_payload.resize(sizeof(EvalErrorDetail)); + absl::little_endian::Store32( + const_cast(error_payload.data()), + static_cast(EvalErrorDetail::kDynamicValueDependence)); + error.SetPayload(kEvalErrorDetailUrl, absl::Cord(error_payload)); + return error; +} + +// Repesents a value that might or might not be determined statically. 
+struct DynamicOrStaticInteger { + std::optional static_value; + bool is_dynamic() const { return !static_value.has_value(); } + + std::string ToString() const { + return is_dynamic() ? std::string("DYNAMIC") : absl::StrCat(*static_value); + } +}; + +std::optional GetInstructionValueAsInteger( + HloInstruction* instruction) { + HloEvaluator evaluator; + StatusOr static_value = evaluator.Evaluate( + instruction, /*recursively_evaluate_nonconstant_operands=*/true); + if (static_value.ok()) { + if (instruction->shape().element_type() == PrimitiveType::PRED) { + return DynamicOrStaticInteger{ + static_cast(static_value->GetFirstElement())}; + } else { + return DynamicOrStaticInteger{static_value->GetFirstInteger()}; + } + } + + std::optional eval_error_detail = + ParseEvalErrorDetail(static_value.status()); + if (eval_error_detail.has_value() && + *eval_error_detail == EvalErrorDetail::kDynamicValueDependence) { + return DynamicOrStaticInteger{std::nullopt}; + } + return std::nullopt; +} + // Represents an index into the while argument tuple and / or a value. // At least one of param_index and value has a value; both of them could have // a value. struct ParamIndexAndValue { std::optional param_index; - std::optional value; + std::optional value; bool IsValid() const { return param_index.has_value() || value.has_value(); } + std::string ToString() const { + return absl::StrCat( + "param_index:", + !param_index.has_value() ? std::string("UNKNOWN") + : absl::StrCat(*param_index), + ",", "value:", + !value.has_value() ? std::string("UNKONWN") : value->ToString()); + } }; +std::optional TryParsingInstructionAsParameterAndInteger( + HloInstruction* instruction) { + // Skip copies. + if (instruction->opcode() == HloOpcode::kCopy) { + return TryParsingInstructionAsParameterAndInteger( + instruction->mutable_operand(0)); + } + if (instruction->opcode() == HloOpcode::kCopyDone) { + return TryParsingInstructionAsParameterAndInteger( + instruction->mutable_operand(0)->mutable_operand(1)); + } + ParamIndexAndValue result; + if (Match(instruction, match::GetTupleElement().WithOperand( + 0, match::Parameter().WithParameterNum(0)))) { + result.param_index = instruction->tuple_index(); + } + std::optional integer_value = + GetInstructionValueAsInteger(instruction); + result.value = std::move(integer_value); + if (!result.IsValid()) { + return std::nullopt; + } + return std::optional(std::move(result)); +} + // Represents the while loop condition comparison. // We assume comparison is of the form: lhs comp rhs. struct WhileCondComparison { - ComparisonDirection comparson_direction; + ComparisonDirection comparison_direction; ParamIndexAndValue lhs; ParamIndexAndValue rhs; + + std::string ToString() const { + return absl::StrCat("WhileCondComparison{", "LHS:{", lhs.ToString(), + "},RHS:{", rhs.ToString(), "}}"); + } }; // Represents the parsed while loop condition. 
The loop induction variable may @@ -204,57 +323,55 @@ struct WhileCondComparison { using WhileCondComparisonOrNoOp = std::variant; +std::optional ParseComparisonOperand( + HloInstruction* operand) { + if (operand->opcode() == HloOpcode::kCopy || + operand->opcode() == HloOpcode::kCopyStart || + operand->opcode() == HloOpcode::kCopyDone) { + return ParseComparisonOperand(operand->mutable_operand(0)); + } + std::optional param_index; + if (Match(operand, match::GetTupleElement().WithOperand( + 0, match::Parameter().WithParameterNum(0)))) { + param_index = operand->tuple_index(); + } + std::optional operand_value = + GetInstructionValueAsInteger(operand); + if (!param_index.has_value() && !operand_value.has_value()) { + return std::nullopt; + } + return ParamIndexAndValue{param_index, operand_value}; +} + +std::optional PatternMatchLoopCondComparison( + HloInstruction* comparison) { + CHECK_EQ(comparison->opcode(), HloOpcode::kCompare); + std::optional lhs = + ParseComparisonOperand(comparison->mutable_operand(0)); + std::optional rhs = + ParseComparisonOperand(comparison->mutable_operand(1)); + if (!lhs.has_value() || !rhs.has_value()) { + return std::nullopt; + } + return WhileCondComparison{comparison->comparison_direction(), + *std::move(lhs), *std::move(rhs)}; +} // Finds the while loop condition comparison by matching the loop condition root // with known patterns. -std::optional PatternMatchLoopCondComparison( +std::optional PatternMatchLoopCondRoot( HloInstruction* loop_cond_root) { - // Base pattern #1: gte-0 comp gte-1 - if (Match(loop_cond_root, - match::Compare() - .WithOperand(0, match::GetTupleElement().WithOperand( - 0, match::Parameter().WithParameterNum(0))) - .WithOperand(1, - match::GetTupleElement().WithOperand( - 0, match::Parameter().WithParameterNum(0))))) { - return WhileCondComparison{ - loop_cond_root->comparison_direction(), - {/*param_index=*/loop_cond_root->operand(0)->tuple_index()}, - {/*param_index=*/loop_cond_root->operand(1)->tuple_index()}}; - } - // Base pattern #2: constant comp gte - if (Match(loop_cond_root, - match::Compare() - .WithOperand(0, match::Constant()) - .WithOperand(1, - match::GetTupleElement().WithOperand( - 0, match::Parameter().WithParameterNum(0))))) { - std::optional lhs_value = - loop_cond_root->operand(0)->literal().GetFirstInteger(); - if (!lhs_value.has_value()) { - return std::nullopt; - } - return WhileCondComparison{ - loop_cond_root->comparison_direction(), - {/*param_index=*/std::nullopt, /*value=*/*lhs_value}, - {/*param_index=*/loop_cond_root->operand(1)->tuple_index()}}; + if (loop_cond_root->opcode() == HloOpcode::kCopy) { + return PatternMatchLoopCondRoot(loop_cond_root->mutable_operand(0)); } - // Base pattern #3: gte comp constant - if (Match(loop_cond_root, - match::Compare() - .WithOperand(0, match::GetTupleElement().WithOperand( - 0, match::Parameter().WithParameterNum(0))) - .WithOperand(1, match::Constant()))) { - std::optional rhs_value = - loop_cond_root->operand(1)->literal().GetFirstInteger(); - if (!rhs_value.has_value()) { - return std::nullopt; - } - return WhileCondComparison{ - loop_cond_root->comparison_direction(), - {/*param_index=*/loop_cond_root->operand(0)->tuple_index(), - /*value=*/std::nullopt}, - {/*param_index=*/std::nullopt, /*value=*/*rhs_value}, - }; + if (loop_cond_root->opcode() == HloOpcode::kCopyDone) { + return PatternMatchLoopCondRoot( + loop_cond_root->mutable_operand(0)->mutable_operand(1)); + } + if (loop_cond_root->opcode() == HloOpcode::kCompare) { + // Base pattern #1: gte-0 comp 
gte-1 + // Base pattern #2: constant comp gte + // Base pattern #3: gte comp constant + return PatternMatchLoopCondComparison(loop_cond_root); } // Base pattern #4: gte is a boolean scalar and it was return immediately. if (Match(loop_cond_root, match::GetTupleElement().WithOperand( @@ -279,7 +396,7 @@ std::optional PatternMatchLoopCondComparison( HloComputation* to_apply = call_instruction->to_apply(); HloInstruction* to_apply_root = to_apply->root_instruction(); if (Match(to_apply_root, match::Tuple())) { - return PatternMatchLoopCondComparison( + return PatternMatchLoopCondRoot( to_apply_root->mutable_operand(loop_cond_root->tuple_index())); } } @@ -291,417 +408,384 @@ std::optional PatternMatchLoopCondComparison( HloInstruction* new_cond_root = loop_cond_root->mutable_operand(0)->mutable_operand( loop_cond_root->tuple_index()); - return PatternMatchLoopCondComparison(new_cond_root); + return PatternMatchLoopCondRoot(new_cond_root); } return std::nullopt; } -// Tries to parse the loop body to find how the induction variable is updated -// using pattern matching. -std::optional PatternMatchInductionVarUpdate( - HloInstruction* loop_body_root, int64_t tuple_index) { - // Pattern #1: induc_var = induc_var + constant - if (Match(loop_body_root, - match::Tuple().WithOperand( - tuple_index, - match::Add() - .WithOperand(0, match::GetTupleElement() - .WithTupleIndex(tuple_index) - .WithOperand(0, match::Parameter())) - .WithOperand(1, match::Constant())))) { - std::optional step_size = loop_body_root->operand(tuple_index) - ->operand(1) - ->literal() - .GetFirstInteger(); - if (!step_size.has_value()) { - return std::nullopt; - } - return *step_size; - } - // Pattern #2: induc_var = constant + induc_var - if (Match( - loop_body_root, - match::Tuple().WithOperand( - tuple_index, - match::Add() - .WithOperand(0, match::Constant()) - .WithOperand(1, match::GetTupleElement() - .WithTupleIndex(tuple_index) - .WithOperand(0, match::Parameter()))))) { - std::optional step_size = loop_body_root->operand(tuple_index) - ->operand(0) - ->literal() - .GetFirstInteger(); - if (!step_size.has_value()) { - return std::nullopt; +std::optional PatternMatchInductionVarUpdate( + HloInstruction* induction_var_update, int64_t tuple_index) { + if (induction_var_update->opcode() == HloOpcode::kCopy) { + return PatternMatchInductionVarUpdate( + induction_var_update->mutable_operand(0), tuple_index); + } + if (induction_var_update->opcode() == HloOpcode::kCopyDone) { + return PatternMatchInductionVarUpdate( + induction_var_update->mutable_operand(0)->mutable_operand(1), + tuple_index); + } + std::optional update_param_index_and_value = + TryParsingInstructionAsParameterAndInteger(induction_var_update); + + if (update_param_index_and_value.has_value()) { + if (update_param_index_and_value->param_index.has_value()) { + if (*update_param_index_and_value->param_index == tuple_index) { + // Pattern: the induc_var is directly returned from the loop body with + // no changes. + VLOG(3) << "PatternMatchInductionVarUpdate, pattern: [induc_var]."; + return DynamicOrStaticInteger{/*static_value=*/0}; + } else { + VLOG(3) + << "PatternMatchInductionVarUpdate, induction variable is set to " + "another parameter value. 
Parsed update: " + << update_param_index_and_value->ToString(); + return std::nullopt; + } } - return *step_size; - } - - // Pattern #3: induc_var = induc_var - constant - if (Match(loop_body_root, - match::Tuple().WithOperand( - tuple_index, - match::Subtract() - .WithOperand(0, match::GetTupleElement() - .WithTupleIndex(tuple_index) - .WithOperand(0, match::Parameter())) - .WithOperand(1, match::Constant())))) { - std::optional step_size = loop_body_root->operand(tuple_index) - ->operand(1) - ->literal() - .GetFirstInteger(); - if (!step_size.has_value()) { + if (update_param_index_and_value->value.has_value() && + !update_param_index_and_value->value->is_dynamic()) { + VLOG(3) << "PatternMatchInductionVarUpdate, induction variable is set to " + "a constant. Parsed update: " + << update_param_index_and_value->ToString(); return std::nullopt; } - return -*step_size; } - // Pattern #4: the induc_var is directly returned from the loop body with - // no changes. - if (Match(loop_body_root, - match::Tuple().WithOperand( - tuple_index, - match::GetTupleElement() - .WithOperand(0, match::Parameter().WithParameterNum(0)) - .WithTupleIndex(tuple_index)))) { - return 0; + if (induction_var_update->opcode() != HloOpcode::kAdd && + induction_var_update->opcode() != HloOpcode::kSubtract) { + return std::nullopt; + } + bool negate_update = induction_var_update->opcode() == HloOpcode::kSubtract; + HloInstruction* update_lhs = induction_var_update->mutable_operand(0); + VLOG(3) << "PatternMatchInductionVarUpdate, LHS: " << update_lhs->ToString(); + std::optional update_lhs_param_index_and_value = + TryParsingInstructionAsParameterAndInteger(update_lhs); + + HloInstruction* update_rhs = induction_var_update->mutable_operand(1); + VLOG(3) << "PatternMatchInductionVarUpdate, RHS: " << update_rhs->ToString(); + std::optional update_rhs_param_index_and_value = + TryParsingInstructionAsParameterAndInteger(update_rhs); + + if (!update_lhs_param_index_and_value.has_value() || + !update_lhs_param_index_and_value->value.has_value() || + !update_rhs_param_index_and_value.has_value() || + !update_rhs_param_index_and_value->value.has_value()) { + VLOG(3) << "PatternMatchInductionVarUpdate, failed to parse operands. " + "Induction var update instruction: " + << induction_var_update->ToString(); + return std::nullopt; } - return std::nullopt; -} -std::optional PatternMatchLoopCondVarOverride( - HloInstruction* loop_body_root, int64_t tuple_index) { - if (Match(loop_body_root, match::Tuple()) && - loop_body_root->operand_count() > tuple_index) { - HloInstruction* cond_var_override = - loop_body_root->mutable_operand(tuple_index); - HloEvaluator evaluator; - StatusOr new_cond_var = evaluator.Evaluate( - cond_var_override, /*recursively_evaluate_nonconstant_operands=*/true); - if (new_cond_var.ok()) { - return new_cond_var->GetFirstElement(); + VLOG(3) << "update_lhs: " << update_lhs->ToString(); + VLOG(3) << "update_rhs: " << update_rhs->ToString(); + + if (update_lhs_param_index_and_value->param_index.has_value() && + *update_lhs_param_index_and_value->param_index == tuple_index && + update_lhs_param_index_and_value->value->is_dynamic()) { + if (update_rhs_param_index_and_value->value->is_dynamic()) { + return update_rhs_param_index_and_value->value; } + int64_t update_value = + *update_rhs_param_index_and_value->value->static_value; + return negate_update + ? 
DynamicOrStaticInteger{/*static_value=*/-update_value} + : DynamicOrStaticInteger{/*static_value=*/update_value}; + } + + if (update_rhs_param_index_and_value->param_index.has_value() && + *update_rhs_param_index_and_value->param_index == tuple_index && + update_rhs_param_index_and_value->value->is_dynamic() && !negate_update) { + return update_lhs_param_index_and_value->value; } + VLOG(3) << "Failed to pattern match induction variable update."; return std::nullopt; } -// Repesents a value that might or might not be determined statically. -struct DynamicOrStaticValue { - std::optional static_value; - bool is_dynamic() const { return !static_value.has_value(); } -}; - -constexpr absl::string_view kEvalErrorDetailUrl = "EvalErrorDetailUrl"; - -// Use this class to represent the precise details of the error to enable -// special treatment. -enum class EvalErrorDetail : uint32_t { - // The evaluation result depends on dynamic values such as parameters and - // infeed. Therefore, the HLO's value cannot be statically evaluated. - kDynamicValueDependence = 0, -}; - -Status MakeEvalErrorDueToParamOrInfeed(const HloInstruction& eval_instruction) { - Status error = tsl::errors::FailedPrecondition( - "Failed to evaluate instruction (", eval_instruction.name(), - ") since it depends on infeed or parameters to its parent computation (", - eval_instruction.parent()->name(), ")."); - std::string error_payload; - error_payload.resize(sizeof(EvalErrorDetail)); - absl::little_endian::Store32( - const_cast(error_payload.data()), - static_cast(EvalErrorDetail::kDynamicValueDependence)); - error.SetPayload(kEvalErrorDetailUrl, error_payload); - return error; +// Tries to parse the loop body to find how the induction variable is updated +// using pattern matching. +std::optional +PatternMatchInductionVarUpdateFromLoopBodyRoot(HloInstruction* loop_body_root, + int64_t tuple_index) { + if (loop_body_root->opcode() != HloOpcode::kTuple || + loop_body_root->operand_count() <= tuple_index) { + return std::nullopt; + } + HloInstruction* induction_var_update = + loop_body_root->mutable_operand(tuple_index); + return PatternMatchInductionVarUpdate(induction_var_update, tuple_index); } -std::optional ParseEvalErrorDetail(const Status& error) { - auto error_detail = error.GetPayload(kEvalErrorDetailUrl); - if (!error_detail.has_value() && error_detail->empty()) { +std::optional PatternMatchLoopCondVarOverride( + HloInstruction* loop_body_root, int64_t tuple_index) { + if (!Match(loop_body_root, match::Tuple()) || + loop_body_root->operand_count() <= tuple_index) { return std::nullopt; } - return static_cast( - absl::little_endian::Load32(error_detail->Flatten().data())); + HloInstruction* cond_var_override = + loop_body_root->mutable_operand(tuple_index); + return GetInstructionStaticValueAsBool(cond_var_override); } // A convenience wrapper to compute the while loop's argument's init value at // the given tuple_index. If the init value depends on parameters to the // while loop's parent computation or infeed, we consider the init value // dynamic. 
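// For example, with a while init of tuple(constant(0), some-parameter),
// tuple_index 0 parses to a static 0, while tuple_index 1 is reported as
// dynamic.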
-std::optional EvaluateWhileLoopParamInitValue( +std::optional EvaluateWhileLoopParamInitValue( HloInstruction* param_instruction, int64_t tuple_index) { if (param_instruction->opcode() != HloOpcode::kTuple) { return std::nullopt; } HloInstruction* element_instruction = param_instruction->mutable_operand(tuple_index); - HloEvaluator evaluator; - StatusOr value = evaluator.Evaluate( - element_instruction, /*recursively_evaluate_nonconstant_operands=*/true); - if (value.ok()) { - if (element_instruction->shape().element_type() == PrimitiveType::PRED) { - return DynamicOrStaticValue{ - static_cast(value->GetFirstElement())}; - } else { - return DynamicOrStaticValue{value->GetFirstInteger()}; + return GetInstructionValueAsInteger(element_instruction); +} + +} // namespace + +std::optional HandleNoopLoopCondition( + const ParamIndexAndValue& parameter_index_and_value, + HloInstruction* while_operand, HloComputation* while_body) { + CHECK(parameter_index_and_value.param_index.has_value()); + int64_t loop_cond_var_index = *parameter_index_and_value.param_index; + std::optional noop_value = + EvaluateWhileLoopParamInitValue(while_operand, loop_cond_var_index); + + if (noop_value.has_value()) { + if (noop_value->is_dynamic()) { + return kParsedDynamicWhileLoop; + } else if (*noop_value->static_value == 0) { + return ParsedWhileLoop{ + ParsedStaticWhileLoop{/*trip_count=*/0, + /*induction_var_index=*/loop_cond_var_index, + /*induction_var_init_value=*/0, + /*step_size=*/0, + /*loop_bound=*/0}}; } - } else { - std::optional eval_error_detail = - ParseEvalErrorDetail(value.status()); - if (eval_error_detail.has_value() && - *eval_error_detail == EvalErrorDetail::kDynamicValueDependence) { - return DynamicOrStaticValue{std::nullopt}; + std::optional updated_loop_cond_var = PatternMatchLoopCondVarOverride( + while_body->root_instruction(), loop_cond_var_index); + if (updated_loop_cond_var.has_value()) { + if (!*updated_loop_cond_var) { + return ParsedWhileLoop{ + ParsedStaticWhileLoop{/*trip_count=*/1, + /*induction_var_index=*/loop_cond_var_index, + /*induction_var_init_value=*/0, + /*step_size=*/1, + /*loop_bound=*/1}}; + } else { + // This is an infinite loop and we set trip_count to -1. + return ParsedWhileLoop{ + ParsedStaticWhileLoop{/*trip_count=*/-1, + /*induction_var_index=*/loop_cond_var_index, + /*induction_var_init_value=*/0, + /*step_size=*/0, + /*loop_bound=*/1}}; + } } } return std::nullopt; } -} // namespace +int64_t ComputeTripCountFromComparison(int64_t init, int64_t bound, + int64_t update, + bool comparison_with_equal) { + if (comparison_with_equal && init > bound) { + return 0; + } + if (!comparison_with_equal && init >= bound) { + return 0; + } + int64_t distance = bound - init; + int64_t trip_count = (distance + update - 1) / update; + CHECK_GE(trip_count, 0); + // Additional logic to deal with equal comparison. 
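+  // Worked example: init = 0, bound = 10, update = 3 under `<` gives
+  // distance = 10 and trip_count = ceil(10 / 3) = 4 (i = 0, 3, 6, 9). Under
+  // `<=` with init = 0, bound = 9, update = 3 the division yields 3
+  // (i = 0, 3, 6), and the branch below adds one more iteration for i = 9
+  // because (bound - init) is an exact multiple of update.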
+ if (comparison_with_equal && (bound - init) % update == 0) { + trip_count += 1; + } + return trip_count; +} + +std::optional HandleStaticLoopComparison( + int64_t lhs, int64_t rhs, Comparison::Direction comparison_direction) { + if ((comparison_direction == Comparison::Direction::kLt && lhs < rhs) || + (comparison_direction == Comparison::Direction::kLe && lhs <= rhs) || + (comparison_direction == Comparison::Direction::kGt && lhs > rhs) || + (comparison_direction == Comparison::Direction::kGe && lhs >= rhs) || + (comparison_direction == Comparison::Direction::kEq && lhs == rhs) || + (comparison_direction == Comparison::Direction::kNe && lhs != rhs)) { + // This is an infinite loop and we set trip_count to -1. + // There is no induction variable. + return ParsedWhileLoop{ParsedStaticWhileLoop{/*trip_count=*/-1, + /*induction_var_index=*/-1, + /*induction_var_init_value=*/0, + /*step_size=*/0, + /*loop_bound=*/1}}; + } + return ParsedWhileLoop{ParsedStaticWhileLoop{/*trip_count=*/0, + /*induction_var_index=*/-1, + /*induction_var_init_value=*/0, + /*step_size=*/0, + /*loop_bound=*/0}}; +} std::optional PatternMatchParseWhileLoop( HloInstruction* while_op) { + VLOG(3) << "PatternMatchParseWhileLoop, while_op: " << while_op->name(); HloComputation* while_cond = while_op->while_condition(); HloComputation* while_body = while_op->while_body(); HloInstruction* while_operand = while_op->mutable_operand(0); // Try to parse the loop condition comparison. std::optional loop_comparison_or_noop = - PatternMatchLoopCondComparison(while_cond->root_instruction()); + PatternMatchLoopCondRoot(while_cond->root_instruction()); if (!loop_comparison_or_noop.has_value()) { return std::nullopt; } if (loop_comparison_or_noop->index() == 1) { - ParamIndexAndValue& parameter_index_and_value = - std::get(*loop_comparison_or_noop); - CHECK(parameter_index_and_value.param_index.has_value()); - int64_t loop_cond_var_index = *parameter_index_and_value.param_index; - std::optional noop_value = - EvaluateWhileLoopParamInitValue(while_operand, loop_cond_var_index); - - if (noop_value.has_value()) { - if (noop_value->is_dynamic()) { - return kParsedDynamicWhileLoop; - } else if (*noop_value->static_value == 0) { - return ParsedWhileLoop{ - ParsedStaticWhileLoop{/*trip_count=*/0, - /*induction_var_index=*/loop_cond_var_index, - /*induction_var_init_value=*/0, - /*step_size=*/0, - /*loop_bound=*/0}}; - } - std::optional updated_loop_cond_var = - PatternMatchLoopCondVarOverride(while_body->root_instruction(), - loop_cond_var_index); - if (updated_loop_cond_var.has_value()) { - if (!*updated_loop_cond_var) { - return ParsedWhileLoop{ - ParsedStaticWhileLoop{/*trip_count=*/1, - /*induction_var_index=*/loop_cond_var_index, - /*induction_var_init_value=*/0, - /*step_size=*/1, - /*loop_bound=*/1}}; - } else { - // This is an infinite loop and we set trip_count to -1. 
- return ParsedWhileLoop{ - ParsedStaticWhileLoop{/*trip_count=*/-1, - /*induction_var_index=*/loop_cond_var_index, - /*induction_var_init_value=*/0, - /*step_size=*/0, - /*loop_bound=*/1}}; - } - } - } - return std::nullopt; + return HandleNoopLoopCondition( + std::get(*loop_comparison_or_noop), while_operand, + while_body); } CHECK_EQ(loop_comparison_or_noop->index(), 0); WhileCondComparison loop_comparison = std::get(*loop_comparison_or_noop); CHECK(loop_comparison.lhs.IsValid() && loop_comparison.rhs.IsValid()); - // If the while loop condition comparison's both sides take an init value - // from the while loop's parent computation's parameter, the loop is dynamic. - if (while_operand->opcode() == HloOpcode::kParameter) { - if (loop_comparison.lhs.param_index.has_value() || - loop_comparison.rhs.param_index.has_value()) { - return kParsedDynamicWhileLoop; - } - } - // We can't handle the case when the while loop argument is not a Tuple // instruction. if (while_operand->opcode() != HloOpcode::kTuple) { return std::nullopt; } - // If loop cond comparison LHS does not have a value defined inside the loop - // cond computation, try to evaluate its init value inside the while loop's - // parent computation. - if (!loop_comparison.lhs.value.has_value()) { - std::optional lhs_init_value = - EvaluateWhileLoopParamInitValue(while_operand, - *loop_comparison.lhs.param_index); - if (lhs_init_value.has_value()) { - if (lhs_init_value->is_dynamic()) { - return kParsedDynamicWhileLoop; - } else { - loop_comparison.lhs.value = *(lhs_init_value->static_value); - } - } else { - return std::nullopt; - } - } - - // If loop cond comparison RHS does not have a value defined inside the loop - // cond computation, try to evaluate its init value inside the while loop's - // parent computation. - if (!loop_comparison.rhs.value.has_value()) { - std::optional rhs_init_value = - EvaluateWhileLoopParamInitValue(while_operand, - *loop_comparison.rhs.param_index); - if (rhs_init_value.has_value()) { - if (rhs_init_value->is_dynamic()) { - return kParsedDynamicWhileLoop; - } else { - loop_comparison.rhs.value = *(rhs_init_value->static_value); - } - } else { - return std::nullopt; - } + if (!loop_comparison.lhs.value.has_value() || + !loop_comparison.rhs.value.has_value()) { + return std::nullopt; } - // We have either successfully evaluated the init value for both LHS and RHS - // or have returned as dynamic loop or failure. + // We have either successfully parsed the init value for both LHS and RHS + // or have returned as failure. CHECK(loop_comparison.lhs.value.has_value()); CHECK(loop_comparison.rhs.value.has_value()); - if (loop_comparison.lhs.param_index.has_value()) { - VLOG(3) << __func__ << " lhs index: " << *loop_comparison.lhs.param_index; - } - - VLOG(3) << __func__ << " lhs bound: " << *loop_comparison.lhs.value; + VLOG(3) << loop_comparison.ToString(); - if (loop_comparison.rhs.param_index.has_value()) { - VLOG(3) << __func__ << " rhs index: " << *loop_comparison.rhs.param_index; + // If both operands of the loop condition comparison have dynamic value, the + // trip count might be dynamic or static. This is a case that our existing + // patterns could not yet handle, so we return std::nullopt. 
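+  // This covers loops such as `while (gte(p, 0) < gte(p, 1))` where both
+  // tuple elements are initialized from parameters of the enclosing
+  // computation.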
+ if (loop_comparison.lhs.value->is_dynamic() && + loop_comparison.rhs.value->is_dynamic()) { + VLOG(3) << "Both operands of the loop condition comparison are dynamic."; + return std::nullopt; } - - VLOG(3) << __func__ << " rhs bound: " << *loop_comparison.rhs.value; - - // Check whether LHS is the loop induction var. - std::optional lhs_induction_var_update; - if (loop_comparison.lhs.param_index.has_value()) { - lhs_induction_var_update = PatternMatchInductionVarUpdate( - while_body->root_instruction(), *loop_comparison.lhs.param_index); + // We would have returned if both operands are dynamic. So there is at most + // one dynamic operand, which is potentially the loop induction variable. + CHECK(!loop_comparison.lhs.value->is_dynamic() || + !loop_comparison.rhs.value->is_dynamic()); + + if (!loop_comparison.lhs.value->is_dynamic() && + !loop_comparison.rhs.value->is_dynamic()) { + int64_t lhs_value = *loop_comparison.lhs.value->static_value; + int64_t rhs_value = *loop_comparison.rhs.value->static_value; + Comparison::Direction comparison_direction = + loop_comparison.comparison_direction; + return HandleStaticLoopComparison(lhs_value, rhs_value, + comparison_direction); + } + std::optional induction_var_init; + std::optional induction_var_update; + bool lhs_is_induction_var = true; + if (loop_comparison.lhs.value->is_dynamic()) { + if (loop_comparison.lhs.param_index.has_value()) { + VLOG(3) << "Comparison LHS is induction variable."; + induction_var_init = EvaluateWhileLoopParamInitValue( + while_operand, *loop_comparison.lhs.param_index); + induction_var_update = PatternMatchInductionVarUpdateFromLoopBodyRoot( + while_body->root_instruction(), *loop_comparison.lhs.param_index); + lhs_is_induction_var = true; + } + } else { + CHECK(loop_comparison.rhs.value->is_dynamic()); + if (loop_comparison.rhs.param_index.has_value()) { + VLOG(3) << "Comparison RHS is induction variable."; + induction_var_init = EvaluateWhileLoopParamInitValue( + while_operand, *loop_comparison.rhs.param_index); + induction_var_update = PatternMatchInductionVarUpdateFromLoopBodyRoot( + while_body->root_instruction(), *loop_comparison.rhs.param_index); + lhs_is_induction_var = false; + } } - // Check whether LHS is the loop induction var. - std::optional rhs_induction_var_update; - if (loop_comparison.rhs.param_index.has_value()) { - rhs_induction_var_update = PatternMatchInductionVarUpdate( - while_body->root_instruction(), *loop_comparison.rhs.param_index); + if (!induction_var_init.has_value() || !induction_var_update.has_value()) { + return std::nullopt; } - + VLOG(3) << "induction_var_init: " << induction_var_init->ToString(); + VLOG(3) << "induction_var_update: " << induction_var_update->ToString(); + if (induction_var_init->is_dynamic() || induction_var_update->is_dynamic()) { + return kParsedDynamicWhileLoop; + } + + int64_t init_value = *induction_var_init->static_value; + int64_t update_value = *induction_var_update->static_value; + Comparison::Direction comparison_direction = + loop_comparison.comparison_direction; + ParsedWhileLoop parsed_static_while_loop = ParsedWhileLoop{ + ParsedStaticWhileLoop{/*trip_count=*/0, + // Unassigned. + /*induction_var_index=*/-1, + /*induction_var_init_value=*/init_value, + /*step_size=*/update_value, + // Unassigned. + /*loop_bound=*/-1}}; // Lhs is the induction variable. - if (lhs_induction_var_update.has_value()) { - // We cannot handle the case when both LHS and RHS are updated inside - // the loop body. 
- if (rhs_induction_var_update.has_value() && - *rhs_induction_var_update != 0) { - return std::nullopt; + if (lhs_is_induction_var) { + CHECK(loop_comparison.rhs.value.has_value() && + !loop_comparison.rhs.value->is_dynamic()); + int64_t bound = *loop_comparison.rhs.value->static_value; + parsed_static_while_loop.static_while_loop->induction_var_index = + *loop_comparison.lhs.param_index; + parsed_static_while_loop.static_while_loop->loop_bound = bound; + if (update_value > 0 && + (comparison_direction == Comparison::Direction::kLt || + comparison_direction == Comparison::Direction::kLe)) { + int64_t trip_count = ComputeTripCountFromComparison( + init_value, bound, update_value, + comparison_direction == Comparison::Direction::kLe); + parsed_static_while_loop.static_while_loop->trip_count = trip_count; + return parsed_static_while_loop; } - if (*lhs_induction_var_update > 0 && - (loop_comparison.comparson_direction == Comparison::Direction::kLt || - loop_comparison.comparson_direction == Comparison::Direction::kLe)) { - int64_t trip_count = - (*loop_comparison.rhs.value - *loop_comparison.lhs.value - 1) / - *lhs_induction_var_update + - 1; - // Additional logic to deal with Equal comparison. - if (loop_comparison.comparson_direction == Comparison::Direction::kLe && - (*loop_comparison.rhs.value - *loop_comparison.lhs.value) % - *lhs_induction_var_update == - 0) { - trip_count += 1; - } - return ParsedWhileLoop{ParsedStaticWhileLoop{ - /*trip_count=*/trip_count, - /*induction_var_index=*/*loop_comparison.lhs.param_index, - /*induction_var_init_value=*/*loop_comparison.lhs.value, - /*step_size=*/*lhs_induction_var_update, - /*loop_bound=*/*loop_comparison.rhs.value}}; - } else if (*lhs_induction_var_update < 0 && - (loop_comparison.comparson_direction == - Comparison::Direction::kGt || - loop_comparison.comparson_direction == - Comparison::Direction::kGe)) { - int trip_count = - (*loop_comparison.lhs.value - *loop_comparison.rhs.value - 1) / - *lhs_induction_var_update + - 1; - if (loop_comparison.comparson_direction == Comparison::Direction::kGe && - (*loop_comparison.lhs.value - *loop_comparison.rhs.value) % - *lhs_induction_var_update == - 0) { - trip_count += 1; - } - return ParsedWhileLoop{ParsedStaticWhileLoop{ - /*trip_count=*/trip_count, - /*induction_var_index=*/*(loop_comparison.lhs.param_index), - /*induction_var_init_value=*/*(loop_comparison.lhs.value), - /*step_size=*/-*lhs_induction_var_update, - /*loop_bound=*/*(loop_comparison.rhs.value)}}; + if (update_value < 0 && + (comparison_direction == Comparison::Direction::kGt || + comparison_direction == Comparison::Direction::kGe)) { + int64_t trip_count = ComputeTripCountFromComparison( + bound, init_value, -update_value, + comparison_direction == Comparison::Direction::kGe); + parsed_static_while_loop.static_while_loop->trip_count = trip_count; + return parsed_static_while_loop; } return std::nullopt; } // Rhs is the induction variable. - if (rhs_induction_var_update.has_value()) { - // We cannot handle the case when both LHS and RHS are updated inside - // the loop body. 
- if (lhs_induction_var_update.has_value() && - *lhs_induction_var_update == 0) { - return std::nullopt; - } - if (*rhs_induction_var_update > 0 && - (loop_comparison.comparson_direction == Comparison::Direction::kGt || - loop_comparison.comparson_direction == Comparison::Direction::kGe)) { - int trip_count = - (*loop_comparison.lhs.value - *loop_comparison.rhs.value - 1) / - *rhs_induction_var_update + - 1; - if (loop_comparison.comparson_direction == Comparison::Direction::kGe && - (*loop_comparison.lhs.value - *loop_comparison.rhs.value) % - *rhs_induction_var_update == - 0) { - trip_count += 1; - } - return ParsedWhileLoop{ParsedStaticWhileLoop{ - /*trip_count=*/trip_count, - /*induction_var_index=*/*(loop_comparison.rhs.param_index), - /*induction_var_init_value=*/*(loop_comparison.rhs.value), - /*step_size=*/*rhs_induction_var_update, - /*loop_bound=*/*(loop_comparison.lhs.value)}}; - } else if (*rhs_induction_var_update < 0 && - (loop_comparison.comparson_direction == - Comparison::Direction::kLt || - loop_comparison.comparson_direction == - Comparison::Direction::kLe)) { - int trip_count = - (*loop_comparison.rhs.value - *loop_comparison.lhs.value - 1) / - *rhs_induction_var_update + - 1; - if (loop_comparison.comparson_direction == Comparison::Direction::kLe && - (*loop_comparison.rhs.value - *loop_comparison.lhs.value) % - *rhs_induction_var_update == - 0) { - trip_count += 1; - } - return ParsedWhileLoop{ParsedStaticWhileLoop{ - /*trip_count=*/trip_count, - /*induction_var_index=*/*(loop_comparison.rhs.param_index), - /*induction_var_init_value=*/*(loop_comparison.rhs.value), - /*step_size=*/-*rhs_induction_var_update, - /*loop_bound=*/*(loop_comparison.lhs.value)}}; - } - return std::nullopt; + CHECK(loop_comparison.lhs.value.has_value() && + !loop_comparison.lhs.value->is_dynamic()); + int64_t bound = *loop_comparison.lhs.value->static_value; + parsed_static_while_loop.static_while_loop->induction_var_index = + *loop_comparison.rhs.param_index; + parsed_static_while_loop.static_while_loop->loop_bound = bound; + if (update_value > 0 && + (comparison_direction == Comparison::Direction::kGt || + comparison_direction == Comparison::Direction::kGe)) { + int64_t trip_count = ComputeTripCountFromComparison( + init_value, bound, update_value, + comparison_direction == Comparison::Direction::kGe); + parsed_static_while_loop.static_while_loop->trip_count = trip_count; + return parsed_static_while_loop; + } + if (update_value < 0 && + (comparison_direction == Comparison::Direction::kLt || + comparison_direction == Comparison::Direction::kLe)) { + int64_t trip_count = ComputeTripCountFromComparison( + bound, init_value, -update_value, + comparison_direction == Comparison::Direction::kLe); + parsed_static_while_loop.static_while_loop->trip_count = trip_count; + return parsed_static_while_loop; } return std::nullopt; } @@ -716,19 +800,19 @@ HloEvaluator::HloEvaluator(int64_t max_loop_iterations) typed_visitors_[PRED] = std::make_unique>(this); typed_visitors_[U8] = - std::make_unique>(this); + std::make_unique>(this); typed_visitors_[U16] = - std::make_unique>(this); + std::make_unique>(this); typed_visitors_[U32] = - std::make_unique>(this); + std::make_unique>(this); typed_visitors_[U64] = std::make_unique>(this); typed_visitors_[S8] = - std::make_unique>(this); + std::make_unique>(this); typed_visitors_[S16] = - std::make_unique>(this); + std::make_unique>(this); typed_visitors_[S32] = - std::make_unique>(this); + std::make_unique>(this); typed_visitors_[S64] = std::make_unique>(this); 
typed_visitors_[F16] = @@ -742,12 +826,17 @@ HloEvaluator::HloEvaluator(int64_t max_loop_iterations) typed_visitors_[C128] = std::make_unique>(this); - // Most of the evaluator computations we use don't support BF16 (e.g., - // std::ceil, std::tanh). To make evaluator work with BF16, we set all - // elementwise computations to be done in F32 and do BF16<->F32 conversion - // around the input and the output of the computations. + // Most of the evaluator computations we use don't support BF16 and F8 (e.g., + // std::ceil, std::tanh). To make evaluator work with these dtypes, we set all + // elementwise computations to be done in F32 and do BF16<->F32 or F8<->F32 + // conversion around the input and the output of the computations. typed_visitors_[BF16] = std::make_unique>(this); + typed_visitors_[F8E5M2] = + std::make_unique>(this); + typed_visitors_[F8E4M3FN] = + std::make_unique>( + this); typed_visitors_[TUPLE] = std::make_unique([](HloInstruction*) { @@ -772,6 +861,7 @@ StatusOr HloEvaluator::Evaluate( CHECK(computation.parent() != nullptr); XLA_VLOG_LINES( 2, "HloEvaluator::Evaluate computation:\n" + computation.ToString()); + OnEvaluateComputation(computation); if (arg_literals.size() != computation.num_parameters()) { return InvalidArgument( @@ -794,6 +884,8 @@ StatusOr HloEvaluator::Evaluate( evaluated_.clear(); arg_literals_.clear(); + call_graph_cache_.reset(); + tuple_points_to_analysis_cache_.reset(); for (const auto& literal_ptr : arg_literals) { arg_literals_.push_back(&*literal_ptr); } @@ -829,6 +921,8 @@ StatusOr HloEvaluator::Evaluate( bool recursively_evaluate_nonconstant_operands) { arg_literals_.clear(); evaluated_.clear(); + call_graph_cache_.reset(); + tuple_points_to_analysis_cache_.reset(); auto enable_partial_evaluation_cleanup = absl::MakeCleanup([this] { enable_partial_evaluation_ = false; }); enable_partial_evaluation_ = recursively_evaluate_nonconstant_operands; @@ -967,6 +1061,61 @@ StatusOr HloEvaluator::EvaluateDotOp( return Evaluate(cloned_instruction.get()); } +Status HloEvaluator::EvaluateParameterFromCallerArgument( + HloInstruction* parameter, const ShapeIndex& shape_index) { + CHECK(!evaluated_.contains(parameter)); + const HloComputation* parent_computation = parameter->parent(); + std::vector computation_callers = + call_graph_cache_->GetComputationCallers(parent_computation); + // If the parent computation has multiple callers, we cannot determine from + // which caller the arguments are passed. 
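// [Sketch] The comment above describes the convert-compute-convert strategy
// for BF16 and the new F8 visitors: each element is widened to F32, the
// elementwise op runs in F32, and the result is narrowed back. A minimal
// illustration of that pattern follows; EvalElementViaF32 is illustrative
// only and not an evaluator API.
template <typename NarrowT, typename UnaryOp>
NarrowT EvalElementViaF32(NarrowT x, UnaryOp op) {
  const float wide = static_cast<float>(x);  // BF16/F8 -> F32
  const float y = op(wide);                  // elementwise math done in F32
  return static_cast<NarrowT>(y);            // F32 -> BF16/F8 (rounds/narrows)
}
// e.g. EvalElementViaF32(bf16_value, [](float v) { return std::tanh(v); });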
+ if (computation_callers.size() != 1) { + return tsl::errors::FailedPrecondition( + "The computation ", parent_computation->name(), " is called by ", + computation_callers.size(), + " callers and thus its argument value " + "cannot be determined statically."); + } + HloInstruction* computation_caller = computation_callers[0]; + HloInstruction* caller_operand = computation_caller->mutable_operand(0); + if (computation_caller->opcode() != HloOpcode::kWhile && + computation_caller->opcode() != HloOpcode::kCall) { + return tsl::errors::FailedPrecondition( + "The computation ", parent_computation->name(), " is called by ", + "instruction ", computation_caller->name(), + ", which is not yet supported."); + } + if (computation_caller->opcode() == HloOpcode::kWhile) { + HloComputation* while_body = computation_caller->while_body(); + TF_ASSIGN_OR_RETURN( + const LogicalBuffer* logical_buffer, + tuple_points_to_analysis_cache_->GetBufferDefinedAt( + while_body->parameter_instruction(parameter->parameter_number()), + shape_index)); + const TuplePointsToAnalysis::BufferAliasVector& buffer_aliases = + tuple_points_to_analysis_cache_->GetBufferAliases(*logical_buffer); + bool unchanged_in_return = false; + for (const BufferAlias& buffer_alias : buffer_aliases) { + if (buffer_alias.instruction() == while_body->root_instruction() && + buffer_alias.index() == shape_index) { + unchanged_in_return = true; + } + } + if (!unchanged_in_return) { + return MakeEvalErrorDueToParamOrInfeed(*parameter); + } + } + TF_RETURN_IF_ERROR(EvaluateInternal(caller_operand, shape_index, true)); + const Literal& caller_operand_literal = + GetEvaluatedLiteralFor(caller_operand); + evaluated_[parameter] = + Literal::CreateFromShapeWithUnknownLeafArrays(parameter->shape()); + TF_RETURN_IF_ERROR(evaluated_[parameter].CopyFrom( + caller_operand_literal, /*dest_shape_index=*/shape_index, + /*src_shape_index=*/shape_index)); + return OkStatus(); +} + Status HloEvaluator::EvaluateInternal( HloInstruction* instruction, const ShapeIndex& shape_index, bool recursively_evaluate_nonconstant_operands) { @@ -995,6 +1144,32 @@ Status HloEvaluator::EvaluateInternal( TF_RETURN_IF_ERROR(EvaluateInternal( instruction->mutable_operand(tuple_index), new_shape_index, /*recursively_evaluate_nonconstant_operands=*/true)); + } else if (instruction->opcode() == HloOpcode::kParameter) { + if (!call_graph_cache_) { + HloModule* module = instruction->GetModule(); + call_graph_cache_ = CallGraph::Build(module); + } + if (!tuple_points_to_analysis_cache_) { + HloModule* module = instruction->GetModule(); + StatusOr> + tuple_points_to_analysis = TuplePointsToAnalysis::Run(module); + if (tuple_points_to_analysis.ok()) { + tuple_points_to_analysis_cache_ = + *std::move(tuple_points_to_analysis); + } + } + if (call_graph_cache_ && tuple_points_to_analysis_cache_) { + Status argument_eval_status = + EvaluateParameterFromCallerArgument(instruction, shape_index); + if (!argument_eval_status.ok()) { + VLOG(4) << "Failed to evaluate parameter " << instruction->name() + << " from caller. 
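// [Sketch] Hypothetical summary helper, not part of the patch: it condenses
// the early exits of EvaluateParameterFromCallerArgument above. A parameter
// with no literal argument can still be evaluated from its caller when the
// computation has exactly one caller, that caller is a kCall or kWhile, and,
// for kWhile, the requested tuple element is passed through the loop body
// unchanged (so the while init operand's value is still valid).
bool CanUseCallerArgumentSketch(int num_callers, HloOpcode caller_opcode,
                                bool element_unchanged_by_loop_body) {
  if (num_callers != 1) return false;                  // ambiguous caller
  if (caller_opcode == HloOpcode::kCall) return true;  // use the call operand
  return caller_opcode == HloOpcode::kWhile &&
         element_unchanged_by_loop_body;               // loop-invariant element
}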
Reason: " + << argument_eval_status.error_message(); + } else { + VLOG(4) << "Successfully evaluated parameter: " + << instruction->name(); + } + } } else { for (HloInstruction* operand : instruction->operands()) { TF_RETURN_IF_ERROR(EvaluateInternal( @@ -1080,7 +1255,7 @@ Status HloEvaluator::HandleSetDimensionSize( } Status HloEvaluator::HandleParameter(HloInstruction* parameter) { - if (arg_literals_.empty()) { + if (!IsAlreadyEvaluated(parameter, visitor_shape_index_)) { if (!enable_partial_evaluation_) { return tsl::errors::FailedPrecondition( "Failed to evaluate instruction since its operands are unknown " @@ -1091,10 +1266,10 @@ Status HloEvaluator::HandleParameter(HloInstruction* parameter) { return OkStatus(); } - // Nothing to do other than sanity checks. Parameters' values are stored in - // arg_literals_. - CHECK_LT(parameter->parameter_number(), arg_literals_.size()); - + if (!arg_literals_.empty()) { + // Nothing to do other than sanity checks. Parameters' values are stored in + // arg_literals_. + CHECK_LT(parameter->parameter_number(), arg_literals_.size()); #ifndef NDEBUG const Literal* input_literal = arg_literals_[parameter->parameter_number()]; VLOG(2) << "Parameter evaluated to: " << input_literal->ToString(); @@ -1105,6 +1280,7 @@ Status HloEvaluator::HandleParameter(HloInstruction* parameter) { << ", but input literal shape is: " << ShapeUtil::HumanStringWithLayout(input_literal->shape()); #endif + } return OkStatus(); } @@ -1203,6 +1379,9 @@ Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) { "expected element type in shape to be floating point, but " "got: %s", PrimitiveType_Name(elem_ty)); + case F8E5M2: + case F8E4M3FN: + return InvalidArgument("F8 is unsupported in IsFinite"); case F16: { auto result_or = ElementWiseUnaryOpImpl( @@ -3148,7 +3327,9 @@ Status HloEvaluator::HandleFusion(HloInstruction* fusion) { HloModuleConfig config; // Attach cloned computation to an empty HLO module so the existing ones are // not modified. - HloModule empty_hlo_module("EmptyModuleForFusion", config); + HloModule empty_hlo_module("EmptyModuleForFusion", config, + std::make_unique( + fusion->GetModule()->comp_envs())); HloCloneContext context(&empty_hlo_module); auto cloned_fused_computation = fusion->fused_instructions_computation()->Clone( @@ -3773,19 +3954,25 @@ Status HloEvaluator::HandleReduce(HloInstruction* instr) { } } - std::unique_ptr embedded_evaluator = - CreateEmbedded(max_loop_iterations_); + const int num_threads = ShapeUtil::GetForEachIndexParallelThreadCount() + 1; + std::vector> embedded_evaluators; + embedded_evaluators.reserve(num_threads); + for (int i = 0; i < num_threads; ++i) { + embedded_evaluators.push_back(CreateEmbedded(max_loop_iterations_)); + } + absl::InlinedVector results(num_args); for (int64_t i = 0; i < num_args; ++i) { results[i] = Literal(is_tuple ? 
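// [Sketch] Why HandleReduce above now keeps one embedded evaluator per thread:
// HloEvaluator caches evaluated literals internally, so sharing a single
// instance across the parallel ForEachIndex workers would race. The pool is
// sized GetForEachIndexParallelThreadCount() + 1; the +1 and the
// `thread_id + 1` lookup suggest the calling thread reports thread_id -1,
// but that detail is an inference, not stated in this hunk.
const int num_threads = ShapeUtil::GetForEachIndexParallelThreadCount() + 1;
std::vector<std::unique_ptr<HloEvaluator>> evaluators(num_threads);
for (auto& evaluator : evaluators) {
  evaluator = CreateEmbedded(max_loop_iterations_);
}
// In the worker body: evaluators[thread_id + 1]->... (never shared between
// threads).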
out_shape.tuple_shapes(i) : out_shape); } - TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( - output_shape, [&](absl::Span output_index) { + TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexParallelWithStatus( + output_shape, [&](absl::Span output_index, int thread_id) { return GenerateReduceOutputElement( is_tuple, output_index, init_values, input_args, - absl::Span(results), function, embedded_evaluator.get(), - arg_dim_steps, arg_dim_counts, result_to_arg_index); + absl::Span(results), function, + embedded_evaluators[thread_id + 1].get(), arg_dim_steps, + arg_dim_counts, result_to_arg_index); })); if (is_tuple) { diff --git a/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator.h b/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator.h index 6f7b962f51b..5636a0ff7fc 100644 --- a/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator.h +++ b/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator.h @@ -26,14 +26,16 @@ limitations under the License. #include "absl/container/node_hash_map.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" @@ -90,6 +92,10 @@ class HloEvaluator : public DfsHloVisitorWithDefault { return std::make_unique(max_loop_iterations); } + // Enables subclasses to be notified when a new computation is being + // evaluated. + virtual void OnEvaluateComputation(const HloComputation& computation) {} + // Evaluates an HLO module and an array of pointers to literals. Returns the // evaluated result as a literal if successful. // @@ -242,6 +248,10 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status EvaluateInternal( HloInstruction* instruction, const ShapeIndex& shape_index = {}, bool recursively_evaluate_nonconstant_operands = false); + + Status EvaluateParameterFromCallerArgument(HloInstruction* parameter, + const ShapeIndex& shape_index); + // Make HloEvaluatorTypedVisitor a friend because it is logically part of this // class. // @@ -425,6 +435,9 @@ class HloEvaluator : public DfsHloVisitorWithDefault { ShapeIndex visitor_shape_index_; bool enable_partial_evaluation_ = false; + std::unique_ptr call_graph_cache_; + std::unique_ptr tuple_points_to_analysis_cache_; + // Use fast path that uses eigen in the evaluator. 
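// [Sketch] The header change above switches the elementwise unary helper from
// Literal::Populate to Literal::PopulateParallel; the parallel generator also
// receives a thread id, ignored here because the lambda is stateless. Type
// parameters below are reconstructed (angle brackets were dropped from this
// copy of the patch), so treat the exact names as approximate.
Literal result(shape);
TF_RETURN_IF_ERROR(result.PopulateParallel<ReturnT>(
    [&](absl::Span<const int64_t> multi_index, int /*thread_id*/) {
      return unary_op(operand_literal.Get<NativeT>(multi_index));
    }));
return std::move(result);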
bool use_fast_path_ = false; @@ -439,8 +452,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault { TF_RET_CHECK(ShapeUtil::SameDimensions(shape, operand->shape())); Literal result(shape); - TF_RETURN_IF_ERROR( - result.Populate([&](absl::Span multi_index) { + TF_RETURN_IF_ERROR(result.PopulateParallel( + [&](absl::Span multi_index, int) { return unary_op(operand_literal.Get(multi_index)); })); return std::move(result); diff --git a/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_test.cc b/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_test.cc index c36ac720859..d1aaa6a42ee 100644 --- a/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -23,12 +24,12 @@ limitations under the License. #include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/permutation_util.h" #include "tensorflow/compiler/xla/reference_util.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -460,6 +461,12 @@ TEST_P(HloEvaluatorBf16Test, DoesSinR2) { TestUnaryOp(HloOpcode::kSin, std::move(expected), std::move(operand), use_bfloat16_ ? 0.031250 : 9.5367431640625E-7); } +TEST_P(HloEvaluatorBf16Test, DoesTanR2) { + auto operand = LiteralUtil::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); + auto expected = LiteralUtil::CreateR2({{0, 0}, {0, 0}}); + TestUnaryOp(HloOpcode::kTan, std::move(expected), std::move(operand), + use_bfloat16_ ? 
0.031250 : 9.5367431640625E-7); +} TEST_F(HloEvaluatorTest, DoesNotR2) { auto operand = LiteralUtil::CreateR2({{0, std::numeric_limits::min()}, @@ -5462,5 +5469,135 @@ TEST_F(PatternMatchParseWhileLoopTest, BooleanCond) { EXPECT_EQ(parsed_while_loop->static_while_loop->loop_bound, 1); } +TEST_F(PatternMatchParseWhileLoopTest, NestedLoop) { + constexpr absl::string_view kHloModule = R"( + HloModule accumulated_all_reduce + + %nested_while_condition { + %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0) + %gte.0 = s32[] get-tuple-element(%param), index=0 + %gte.1 = s32[] get-tuple-element(%param), index=1 + ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT + } + + %nested_while_body { + %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0) + %gte.0 = s32[] get-tuple-element(%param), index=0 + %gte.1 = s32[] get-tuple-element(%param), index=1 + %gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2 + %gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3 + %accumulation = f32[1024, 1024] add(f32[1024, 1024] %gte.2, f32[1024, 1024] %gte.3) + %constant = s32[] constant(1) + %increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant) + ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation) + } + + %while_condition { + %param = (s32[], s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0) + %gte.0 = s32[] get-tuple-element(%param), index=0 + %gte.1 = s32[] get-tuple-element(%param), index=1 + ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT + } + + %while_body { + %param = (s32[], s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0) + %gte.0 = s32[] get-tuple-element(%param), index=0 + %gte.1 = s32[] get-tuple-element(%param), index=1 + %gte.2 = s32[] get-tuple-element(%param), index=2 + %gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3 + %gte.4 = f32[1024, 1024] get-tuple-element(%param), index=4 + %constant.4 = s32[] constant(0) + %nested_while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.4, s32[] %gte.2, f32[1024, 1024] %gte.3, f32[1024, 1024] %gte.4) + %nested_while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%nested_while_init), condition=%nested_while_condition, body=%nested_while_body + %nested_while_result = f32[1024, 1024] get-tuple-element((s32[], s32[], f32[1024, 1024], f32[1024, 1024]) %nested_while), index=3 + %constant = s32[] constant(1) + %increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant) + ROOT %loop_result = (s32[], s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %gte.3, %nested_while_result) + } + + ENTRY accumulated_all_reduce { + %param.1 = f32[1024, 1024] parameter(0) + %param.2 = s32[] parameter(1) + %constant.0 = s32[] constant(0) + %constant.2 = s32[] constant(4) + %loop_bound = s32[] multiply(s32[] %param.2, s32[] %constant.2) + %constant.3 = s32[] constant(5) + %nested_loop_bound = s32[] multiply(s32[] %constant.3, s32[] %constant.2) + %accumulation_buffer_init = f32[] constant(0) + %accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={} + %while_init = (s32[], s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %loop_bound, s32[] %nested_loop_bound, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer) + %while = (s32[], s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, 
body=%while_body + ROOT %result = f32[1024, 1024] get-tuple-element((s32[], s32[], s32[], f32[1024, 1024], f32[1024, 1024]) %while), index=4 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(kHloModule)); + HloInstruction* while_op = + hlo_module->entry_computation()->root_instruction()->mutable_operand(0); + CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); + HloComputation* while_body = while_op->while_body(); + HloInstruction* nested_while = + while_body->root_instruction()->mutable_operand(4)->mutable_operand(0); + CHECK_EQ(nested_while->opcode(), HloOpcode::kWhile); + std::optional parsed_while_loop = + PatternMatchParseWhileLoop(nested_while); + ASSERT_TRUE(parsed_while_loop.has_value()); + EXPECT_FALSE(parsed_while_loop->is_dynamic()); + EXPECT_EQ(parsed_while_loop->static_while_loop->trip_count, 20); + EXPECT_EQ(parsed_while_loop->static_while_loop->induction_var_index, 0); + EXPECT_EQ(parsed_while_loop->static_while_loop->induction_var_init_value, 0); + EXPECT_EQ(parsed_while_loop->static_while_loop->step_size, 1); + EXPECT_EQ(parsed_while_loop->static_while_loop->loop_bound, 20); +} + +TEST_F(PatternMatchParseWhileLoopTest, CopiedLoopCond) { + constexpr absl::string_view kHloModule = R"( + HloModule accumulated_all_reduce + + %while_condition { + %param = (s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0) + %gte.0 = s32[] get-tuple-element(%param), index=0 + %copy.0 = s32[] copy(s32[] %gte.0) + %loop_bound = s32[] constant(5) + %result = pred[] compare(%gte.0, %loop_bound), direction=LT + ROOT %copy.1 = pred[] copy(pred[] %result) + } + + %while_body { + %param = (s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0) + %gte.0 = s32[] get-tuple-element(%param), index=0 + %gte.1 = f32[1024, 1024] get-tuple-element(%param), index=1 + %gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2 + %accumulation = f32[1024, 1024] add(f32[1024, 1024] %gte.1, f32[1024, 1024] %gte.2) + %constant = s32[] constant(1) + %increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant) + ROOT %loop_result = (s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %accumulation) + } + + ENTRY accumulated_all_reduce { + %param.1 = f32[1024, 1024] parameter(0) + %constant.0 = s32[] constant(0) + %accumulation_buffer_init = f32[] constant(0) + %accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={} + %while_init = (s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer) + %while = (s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body + ROOT %result = f32[1024, 1024] get-tuple-element((s32[], f32[1024, 1024], f32[1024, 1024]) %while), index=2 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(kHloModule)); + HloInstruction* while_op = + hlo_module->entry_computation()->root_instruction()->mutable_operand(0); + std::optional parsed_while_loop = + PatternMatchParseWhileLoop(while_op); + ASSERT_TRUE(parsed_while_loop.has_value()); + EXPECT_FALSE(parsed_while_loop->is_dynamic()); + EXPECT_EQ(parsed_while_loop->static_while_loop->trip_count, 5); + EXPECT_EQ(parsed_while_loop->static_while_loop->induction_var_index, 0); + EXPECT_EQ(parsed_while_loop->static_while_loop->induction_var_init_value, 0); + EXPECT_EQ(parsed_while_loop->static_while_loop->step_size, 1); + EXPECT_EQ(parsed_while_loop->static_while_loop->loop_bound, 5); +} 
+ } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h index 6542272c307..4bc37760338 100644 --- a/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h @@ -29,6 +29,7 @@ limitations under the License. #include #include #include +#include #include "absl/algorithm/container.h" #include "absl/base/casts.h" @@ -36,13 +37,15 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_instructions.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" -#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" -#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/tsl/platform/statusor.h" namespace xla { @@ -69,11 +72,10 @@ T Nibble1(T t) { // // Anyway this is relatively safe as-is because hlo_evaluator_typed_visitor.h is // a "private" header that's not exposed outside of hlo_evaluator.cc. -// -// Not using an alias template to work around MSVC 14.00 bug. template -struct is_complex_t : std::disjunction, - std::is_same> {}; +struct is_complex_t : std::false_type {}; +template +struct is_complex_t> : std::true_type {}; template inline constexpr bool is_complex_v = is_complex_t::value; @@ -101,33 +103,6 @@ auto ToArithmeticSafeType(T t) { } } -// UintWithSize gets an unsigned integer with the given size in bytes. -template -struct UintWithSize {}; - -template <> -struct UintWithSize<1> { - using type = uint8_t; -}; - -template <> -struct UintWithSize<2> { - using type = uint16_t; -}; - -template <> -struct UintWithSize<4> { - using type = uint32_t; -}; - -template <> -struct UintWithSize<8> { - using type = uint64_t; -}; - -template -using UintWithSizeType = typename UintWithSize::type; - // Templated DfsHloVisitor for use by HloEvaluator. // // Typically ReturnT here indicates the resulting literal type of each evaluated @@ -277,16 +252,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template >* = nullptr> Status HandleRoundNearestEven(HloInstruction* round) { - // Saves current rounding direction. - int curr_direction = fegetround(); - fesetround(FE_TONEAREST); + // Verify the current rounding direction. + TF_RET_CHECK(fegetround() == FE_TONEAREST); TF_ASSIGN_OR_RETURN( parent_->evaluated_[round], ElementWiseUnaryOp(round, [](ElementwiseT elem_operand) { return std::nearbyint(elem_operand); })); - // Restores default rounding direction. - fesetround(curr_direction); return OkStatus(); } @@ -508,9 +480,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return (ElementwiseT(0) < elem_operand) - (elem_operand < ElementwiseT(0)); } - if constexpr (std::is_same_v || - std::is_same_v || - std::is_floating_point_v) { + if constexpr (std::is_floating_point_v) { return std::isnan(elem_operand) ? 
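// [Sketch] Reconstruction of the simplified is_complex_t trait introduced in
// the hunk below (its angle brackets were lost in this copy of the patch): a
// primary template plus one partial specialization covers any std::complex<T>,
// instead of spelling out complex64/complex128 explicitly.
#include <complex>
#include <type_traits>

template <typename T>
struct is_complex_t : std::false_type {};
template <typename T>
struct is_complex_t<std::complex<T>> : std::true_type {};
template <typename T>
inline constexpr bool is_complex_v = is_complex_t<T>::value;

static_assert(is_complex_v<std::complex<float>>);
static_assert(!is_complex_v<double>);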
elem_operand : std::copysign(elem_operand != ElementwiseT(0), @@ -686,34 +656,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return OkStatus(); } - template >* = nullptr> - Status HandleCbrt(HloInstruction* cbrt) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[cbrt], - ElementWiseUnaryOp(cbrt, [](ElementwiseT elem_operand) -> ElementwiseT { - return std::pow(elem_operand, static_cast(1.0 / 3.0)); - return elem_operand.real() < 0 - ? -std::pow(-elem_operand, - static_cast(1.0 / 3.0)) - : std::pow(elem_operand, - static_cast(1.0 / 3.0)); - })); - return OkStatus(); - } - - template >* = nullptr> - Status HandleCbrt(HloInstruction* cbrt) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[cbrt], - ElementWiseUnaryOp(cbrt, [](ElementwiseT elem_operand) { - return std::cbrt(elem_operand); - })); - return OkStatus(); - } - Status HandleCbrt(HloInstruction* cbrt) override { - return HandleCbrt(cbrt); + if constexpr (!is_complex_v) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[cbrt], + ElementWiseUnaryOp(cbrt, [](ElementwiseT elem_operand) { + return std::cbrt(elem_operand); + })); + return OkStatus(); + } + return UnsupportedTypeError(cbrt); } Status HandleRsqrt(HloInstruction* rsqrt) override { @@ -809,13 +761,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { Status HandleShiftRightArithmetic(HloInstruction* shr) override { if constexpr (std::is_integral_v && !std::is_same_v) { - using SignedT = std::make_signed_t; + using SignedT = std::make_signed_t; TF_ASSIGN_OR_RETURN( parent_->evaluated_[shr], ElementWiseBinaryOp( shr, [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { SignedT lhs_signed = static_cast(lhs_elem); - if (IsShiftOutOfBounds(rhs_elem)) { + if (IsShiftOutOfBounds(rhs_elem)) { return lhs_signed < 0 ? static_cast(-1) : 0; } else { return lhs_signed >> rhs_elem; @@ -829,13 +781,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { Status HandleShiftRightLogical(HloInstruction* shr) override { if constexpr (std::is_integral_v && !std::is_same_v) { - using UnsignedT = std::make_unsigned_t; + using UnsignedT = std::make_unsigned_t; TF_ASSIGN_OR_RETURN(parent_->evaluated_[shr], ElementWiseBinaryOp(shr, [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { // If shift amount is greater than the number of // bits, then return 0. - if (IsShiftOutOfBounds(rhs_elem)) { + if (IsShiftOutOfBounds(rhs_elem)) { return static_cast(0); } return static_cast( @@ -847,8 +799,18 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } Status HandleStochasticConvert(HloInstruction* stochastic_convert) override { - // TODO(b/232442915): Add support for stochastic convert. 
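// [Sketch] The hunks around here repeatedly collapse pairs of enable_if-guarded
// overloads into a single handler that branches on the element type with
// `if constexpr`; the discarded branch is never instantiated, so e.g.
// std::cbrt is not required to accept complex arguments. Reconstruction of the
// consolidated cbrt handler, with type names restored approximately:
Status HandleCbrt(HloInstruction* cbrt) override {
  if constexpr (!is_complex_v<ElementwiseT>) {
    TF_ASSIGN_OR_RETURN(parent_->evaluated_[cbrt],
                        ElementWiseUnaryOp(cbrt, [](ElementwiseT elem) {
                          return std::cbrt(elem);
                        }));
    return OkStatus();
  }
  return UnsupportedTypeError(cbrt);
}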
- return UnsupportedTypeError(stochastic_convert); + const HloInstruction* operand = stochastic_convert->operand(0); + const HloInstruction* random = stochastic_convert->operand(1); + const Shape& result_shape = stochastic_convert->shape(); + TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), random->shape())); + TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), result_shape)); + + const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); + const Literal& random_literal = parent_->GetEvaluatedLiteralFor(random); + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[stochastic_convert], + StochasticConvertOp(operand_literal, random_literal, result_shape)); + return OkStatus(); } Status HandleClamp(HloInstruction* clamp) override { @@ -909,8 +871,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); Literal result(result_shape); - TF_RETURN_IF_ERROR( - result.Populate([&](absl::Span out_index) { + TF_RETURN_IF_ERROR(result.PopulateParallel( + [&](absl::Span out_index, int) { std::vector from_index(out_index.begin(), out_index.end()); for (const int64_t dim : reverse_dimensions) { from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim]; @@ -1006,8 +968,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const int64_t feature_group_index = out_index[output_z_dim] / output_feature_group_size; - const int64_t depthwise_multiplier = - batch_group_count > 1 ? output_z_size / input_batch_size : 1; + const int64_t depthwise_multiplier = output_z_size / batch_group_count; const int64_t batch_group_index = out_index[output_z_dim] / depthwise_multiplier; @@ -1080,9 +1041,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // This approach works out automatically for 'groups' in batches // with group_size > 1, because we already descend down the batch // dimension for the 'output_batch_dim' above. 
- lhs_linear_index += - ((batch_group_index * batch_group_size) % input_batch_size) * - lhs_dim_multipliers[input_batch_dim]; + lhs_linear_index += (batch_group_index * batch_group_size) * + lhs_dim_multipliers[input_batch_dim]; lhs_linear_index += iz * lhs_dim_multipliers[input_z_dim]; int64_t rhs_linear_index = rhs_linear_spatial_index; @@ -1108,6 +1068,11 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } while (IndexUtil::BumpIndices(window_shape, absl::MakeSpan(rhs_spatial_index))); + if constexpr (std::is_integral_v) { + auto l = static_cast(std::numeric_limits::min()); + auto h = static_cast(std::numeric_limits::max()); + result_val = std::max(l, std::min(h, result_val)); + } return static_cast(result_val); }; @@ -1341,7 +1306,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto rhs_n0 = ToArithmeticSafeType(Nibble0(rhs)); auto rhs_n1 = ToArithmeticSafeType(Nibble1(rhs)); result_val += (lhs_n0 * rhs_n0) + (lhs_n1 * rhs_n1); - } else { result_val += ToArithmeticSafeType(lhs) * ToArithmeticSafeType(rhs); @@ -1421,8 +1385,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { ReturnT scalar = parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get({}); Literal result(pad->shape()); - TF_RETURN_IF_ERROR(result.Populate( - [&scalar](absl::Span multi_index) { return scalar; })); + TF_RETURN_IF_ERROR(result.PopulateParallel( + [&scalar](absl::Span multi_index, int) { + return scalar; + })); const Literal& evaluated_operand = parent_->GetEvaluatedLiteralFor(pad->operand(0)); @@ -1662,6 +1628,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); break; } + case F8E5M2: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], + MapImpl(map)); + break; + } + case F8E4M3FN: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], + MapImpl(map)); + break; + } case F16: { TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); @@ -1848,7 +1824,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Shape window_shape = ShapeUtil::MakeShape( input_arrays[0]->shape().element_type(), window_dimension_sizes); - const int num_threads = tsl::port::MaxParallelism() + 1; + const int num_threads = ShapeUtil::GetForEachIndexParallelThreadCount() + 1; std::vector> embedded_evaluators; embedded_evaluators.reserve(num_threads); for (int i = 0; i < num_threads; ++i) { @@ -1968,7 +1944,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const int64_t rank = operand->shape().rank(); const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - auto func = [&](absl::Span out_index) { + auto func = [&](absl::Span out_index, int) { DimensionVector operand_index(rank); for (int64_t i = 0; i < rank; ++i) { operand_index[i] = @@ -1978,12 +1954,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { }; Literal result(shape); - TF_RETURN_IF_ERROR(result.Populate(func)); + TF_RETURN_IF_ERROR(result.PopulateParallel(func)); parent_->evaluated_[slice] = std::move(result); return OkStatus(); } - // Enable CLZ only for int32_t, uint32_t, int64_t and uint64_t. + // Enable CLZ only for integer types. 
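// [Sketch] Reconstruction of the clamp added to the convolution inner loop
// above (angle brackets were dropped from this copy of the patch): when the
// output element type is integral, the accumulated value is clamped to the
// representable range before the final cast, so saturation replaces overflow.
// ElementwiseT/ReturnT follow the surrounding visitor; shown out of context.
if constexpr (std::is_integral_v<ReturnT>) {
  auto lo = static_cast<ElementwiseT>(std::numeric_limits<ReturnT>::min());
  auto hi = static_cast<ElementwiseT>(std::numeric_limits<ReturnT>::max());
  result_val = std::max(lo, std::min(hi, result_val));
}
return static_cast<ReturnT>(result_val);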
template || std::is_same_v>* = nullptr> @@ -1995,13 +1971,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::is_integral_v && !std::is_same_v>* = nullptr> Status HandleClz(HloInstruction* clz) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[clz], - ElementWiseUnaryOp(clz, [](ElementwiseT elem_operand) { - using UnsignedElementwiseT = std::make_unsigned_t; - return (std::numeric_limits::digits - 1) - - Log2Floor(elem_operand); - })); + TF_ASSIGN_OR_RETURN(parent_->evaluated_[clz], + ElementWiseUnaryOp(clz, [](ElementwiseT elem_operand) { + using UnsignedT = std::make_unsigned_t; + return (std::numeric_limits::digits - 1) - + Log2Floor(elem_operand); + })); return OkStatus(); } @@ -2023,8 +1998,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN( parent_->evaluated_[popcnt], ElementWiseUnaryOp(popcnt, [](ElementwiseT elem_operand) { - return std::bitset(elem_operand) - .count(); + return std::bitset(elem_operand).count(); })); return OkStatus(); } @@ -2033,33 +2007,38 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandlePopulationCount(popcnt); } - template >* = nullptr> - Status HandleSin(HloInstruction* sin) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[sin], - ElementWiseUnaryOp(sin, [](ElementwiseT elem_operand) { - return std::sin(elem_operand); - })); - return OkStatus(); - } - - template || - is_complex_v>* = nullptr> - Status HandleSin(HloInstruction* sin) { + Status HandleSin(HloInstruction* sin) override { + if constexpr (std::is_floating_point_v || + is_complex_v) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[sin], + ElementWiseUnaryOp(sin, [](ElementwiseT elem_operand) { + return std::sin(elem_operand); + })); + return OkStatus(); + } return UnsupportedTypeError(sin); } - Status HandleSin(HloInstruction* sin) override { - return HandleSin(sin); + Status HandleCos(HloInstruction* cos) override { + if constexpr (std::is_floating_point_v || + is_complex_v) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[cos], + ElementWiseUnaryOp(cos, [](ElementwiseT elem_operand) { + return std::cos(elem_operand); + })); + return OkStatus(); + } + return UnsupportedTypeError(cos); } template >* = nullptr> - Status HandleCos(HloInstruction* cos) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[cos], - ElementWiseUnaryOp(cos, [](ElementwiseT elem_operand) { - return std::cos(elem_operand); + Status HandleTan(HloInstruction* tan) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[tan], + ElementWiseUnaryOp(tan, [](ElementwiseT elem_operand) { + return std::tan(elem_operand); })); return OkStatus(); } @@ -2067,17 +2046,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { template || is_complex_v>* = nullptr> - Status HandleCos(HloInstruction* cos) { - return UnsupportedTypeError(cos); + Status HandleTan(HloInstruction* tan) { + return UnsupportedTypeError(tan); } - Status HandleCos(HloInstruction* cos) override { - return HandleCos(cos); + Status HandleTan(HloInstruction* tan) override { + return HandleTan(tan); } template || - std::is_same_v>* = nullptr> + std::is_floating_point_v>* = nullptr> Status HandleReducePrecision(HloInstruction* reduce_precision) { TF_ASSIGN_OR_RETURN( parent_->evaluated_[reduce_precision], @@ -2089,7 +2067,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const uint32_t dest_mantissa_bits = reduce_precision->mantissa_bits(); const uint32_t dest_exponent_bits = reduce_precision->exponent_bits(); - using Uint = UintWithSizeType; + using Uint = 
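// [Sketch] The count-leading-zeros handler below computes
// clz(x) = (digits - 1) - floor(log2(x)) on the unsigned view of the element.
// A self-contained illustration of that identity (Log2FloorSketch stands in
// for the library's Log2Floor):
#include <cstdint>
#include <limits>

inline int Log2FloorSketch(uint64_t x) {  // index of the highest set bit, -1 for 0
  int r = -1;
  while (x != 0) {
    x >>= 1;
    ++r;
  }
  return r;
}

template <typename UnsignedT>
int CountLeadingZerosSketch(UnsignedT x) {
  return (std::numeric_limits<UnsignedT>::digits - 1) - Log2FloorSketch(x);
}
// Example: CountLeadingZerosSketch<uint8_t>(0x10) == (8 - 1) - 4 == 3.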
UnsignedIntegerTypeForSizeType; Uint value_as_int = absl::bit_cast(elem); // Code is based on the CPU/GPU implementation in LLVM-emitting code. @@ -2186,143 +2164,118 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleReducePrecision(reduce_precision); } - template || - std::is_same_v || - std::is_integral_v || is_complex_v || - std::is_floating_point_v>* = nullptr> - Status HandleIota(HloInstruction* instruction) { + Status HandleIota(HloInstruction* instruction) override { auto* iota = Cast(instruction); - - Literal result(iota->shape()); - ShapeUtil::ForEachIndex(iota->shape(), [&](absl::Span idx) { - result.Set(idx, static_cast(idx[iota->iota_dimension()])); - return true; - }); - parent_->evaluated_[iota] = std::move(result); - return OkStatus(); - } - template || - std::is_same_v || - std::is_integral_v || is_complex_v || - std::is_floating_point_v)>* = nullptr> - Status HandleIota(HloInstruction* iota) { + if constexpr (std::is_integral_v || + is_complex_v || + std::is_floating_point_v) { + Literal result(iota->shape()); + ShapeUtil::ForEachIndex( + iota->shape(), [&](absl::Span idx) { + result.Set(idx, static_cast(idx[iota->iota_dimension()])); + return true; + }); + parent_->evaluated_[iota] = std::move(result); + return OkStatus(); + } return UnsupportedTypeError(iota); } - Status HandleIota(HloInstruction* iota) override { - return HandleIota(iota); - } - template || - std::is_floating_point_v)>* = - nullptr> - Status HandleRng(HloInstruction* random) { - return UnsupportedTypeError(random); - } - template < - typename NativeT, - typename std::enable_if_t<(std::is_floating_point_v)>* = nullptr> - Status HandleRng(HloInstruction* random) { + Status HandleRng(HloInstruction* random) override { RandomDistribution distribution = random->random_distribution(); const auto result_shape = random->shape(); Literal result(result_shape); - switch (distribution) { - case RNG_UNIFORM: { - const Literal& low = - parent_->GetEvaluatedLiteralFor(random->operand(0)); - const Literal& high = - parent_->GetEvaluatedLiteralFor(random->operand(1)); - - // std::uniform_real_distribution(a, b) can sometimes return a value - // equal to b. Unclear if this is a spec bug or an implementation bug - // or WAI [0] [1] [2]. Anyway for our purposes we want a half-open - // interval, so we have to re-sample if we get `b` out. - // - // [0] https://gcc.gnu.org/bugzilla/show_bug.cgi?id=63176 - // [1] https://bugs.llvm.org/show_bug.cgi?id=18767 - // [2] http://open-std.org/JTC1/SC22/WG21/docs/lwg-active.html#2524 - auto low_val = low.Get({}); - auto high_val = high.Get({}); - std::uniform_real_distribution generator(low_val, high_val); - TF_RETURN_IF_ERROR(result.Populate( - [&](absl::Span /*indexes*/) { - while (true) { - NativeT v = generator(parent_->engine_); - if (v != high_val) { - return v; + if constexpr (std::is_floating_point_v) { + switch (distribution) { + case RNG_UNIFORM: { + const Literal& low = + parent_->GetEvaluatedLiteralFor(random->operand(0)); + const Literal& high = + parent_->GetEvaluatedLiteralFor(random->operand(1)); + + // std::uniform_real_distribution(a, b) can sometimes return a value + // equal to b. Unclear if this is a spec bug or an implementation bug + // or WAI [0] [1] [2]. Anyway for our purposes we want a half-open + // interval, so we have to re-sample if we get `b` out. 
+ // + // [0] https://gcc.gnu.org/bugzilla/show_bug.cgi?id=63176 + // [1] https://bugs.llvm.org/show_bug.cgi?id=18767 + // [2] http://open-std.org/JTC1/SC22/WG21/docs/lwg-active.html#2524 + const ReturnT low_val = low.Get({}); + const ReturnT high_val = high.Get({}); + std::uniform_real_distribution generator( + static_cast(low_val), + static_cast(high_val)); + TF_RETURN_IF_ERROR(result.Populate( + [&](absl::Span /*indexes*/) { + while (true) { + const ReturnT v = + static_cast(generator(parent_->engine_)); + if (v >= low_val && v < high_val) { + return v; + } } - } - })); - break; - } - case RNG_NORMAL: { - const Literal& mean = - parent_->GetEvaluatedLiteralFor(random->operand(0)); - const Literal& stddev = - parent_->GetEvaluatedLiteralFor(random->operand(1)); - - std::normal_distribution generator(mean.Get({}), - stddev.Get({})); - - TF_RETURN_IF_ERROR(result.Populate( - [&](absl::Span /*indexes*/) { - return generator(parent_->engine_); - })); - break; + })); + break; + } + case RNG_NORMAL: { + const Literal& mean = + parent_->GetEvaluatedLiteralFor(random->operand(0)); + const Literal& stddev = + parent_->GetEvaluatedLiteralFor(random->operand(1)); + + std::normal_distribution generator( + static_cast(mean.Get({})), + static_cast(stddev.Get({}))); + + TF_RETURN_IF_ERROR(result.Populate( + [&](absl::Span /*indexes*/) { + return static_cast(generator(parent_->engine_)); + })); + break; + } + default: + return UnimplementedStrCat("The distribution ", + RandomDistribution_Name(distribution), + " is not implemented."); } - default: - return UnimplementedStrCat("The distribution ", - RandomDistribution_Name(distribution), - " is not implemented."); + parent_->evaluated_[random] = std::move(result); + return OkStatus(); } - parent_->evaluated_[random] = std::move(result); - return OkStatus(); - } - template )>* = nullptr> - Status HandleRng(HloInstruction* random) { - RandomDistribution distribution = random->random_distribution(); - const auto result_shape = random->shape(); - Literal result(result_shape); - - switch (distribution) { - case RNG_UNIFORM: { - const Literal& low = - parent_->GetEvaluatedLiteralFor(random->operand(0)); - const Literal& high = - parent_->GetEvaluatedLiteralFor(random->operand(1)); - - // Note std::uniform_int_distribution assumes interval is closed, i.e., - // [low, high], but we want [low, high) instead. Hence high-1 is used as - // the upper range. - std::uniform_int_distribution generator( - low.Get({}), high.Get({}) - 1); - - TF_RETURN_IF_ERROR(result.Populate( - [&](absl::Span /*indexes*/) { - return static_cast(generator(parent_->engine_)); - })); - break; - } - case RNG_NORMAL: { - return Unimplemented( - "Normal distribution is not supported for integral types."); + if constexpr (std::is_integral_v) { + switch (distribution) { + case RNG_UNIFORM: { + const Literal& low = + parent_->GetEvaluatedLiteralFor(random->operand(0)); + const Literal& high = + parent_->GetEvaluatedLiteralFor(random->operand(1)); + + // Note std::uniform_int_distribution assumes interval is closed, + // i.e., [low, high], but we want [low, high) instead. Hence high-1 is + // used as the upper range. 
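// [Sketch] Two half-open-interval tricks are used in the consolidated RNG
// handler around here: std::uniform_real_distribution(a, b) may occasionally
// return b, so out-of-range draws are re-sampled, while the closed
// std::uniform_int_distribution is simply given high - 1 as its upper bound.
// Standalone illustration with the standard <random> API:
#include <random>

void HalfOpenSamplingSketch() {
  std::mt19937_64 engine{42};
  const float low = 0.0f, high = 1.0f;
  std::uniform_real_distribution<float> real_dist(low, high);
  float v = real_dist(engine);
  while (!(v >= low && v < high)) v = real_dist(engine);   // redraw if v == high
  std::uniform_int_distribution<int> int_dist(0, 10 - 1);  // closed -> [0, 10)
  int iv = int_dist(engine);
  (void)v;
  (void)iv;
}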
+ std::uniform_int_distribution generator( + low.Get({}), high.Get({}) - 1); + + TF_RETURN_IF_ERROR(result.Populate( + [&](absl::Span /*indexes*/) { + return static_cast(generator(parent_->engine_)); + })); + break; + } + case RNG_NORMAL: { + return Unimplemented( + "Normal distribution is not supported for integral types."); + } + default: + return UnimplementedStrCat("The distribution ", + RandomDistribution_Name(distribution), + " is not implemented."); } - default: - return UnimplementedStrCat("The distribution ", - RandomDistribution_Name(distribution), - " is not implemented."); + parent_->evaluated_[random] = std::move(result); + return OkStatus(); } - parent_->evaluated_[random] = std::move(result); - return OkStatus(); - } - Status HandleRng(HloInstruction* random) override { - return HandleRng(random); + return UnsupportedTypeError(random); } private: @@ -2494,8 +2447,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { Literal result(shape); - TF_RETURN_IF_ERROR( - result.Populate([&](absl::Span multi_index) { + TF_RETURN_IF_ERROR(result.PopulateParallel( + [&](absl::Span multi_index, int) { return ConvertBinaryFunction(binary_op)( lhs_literal.Get(multi_index), rhs_literal.Get(multi_index)); @@ -2521,8 +2474,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { Literal result(shape); - TF_RETURN_IF_ERROR( - result.Populate([&](absl::Span multi_index) { + TF_RETURN_IF_ERROR(result.PopulateParallel( + [&](absl::Span multi_index, int) { return ternary_op(lhs_literal.Get(multi_index), rhs_literal.Get(multi_index), ehs_literal.Get(multi_index)); @@ -2531,6 +2484,131 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return std::move(result); } + template + StatusOr StochasticConvertOp(const Literal& operand_literal, + const Literal& random_literal, + const Shape& result_shape) { + std::function stochastic_convert_op = + [](Fp operand, Uint random) -> ResultT { + bool is_negative = ToSignMagnitude(operand) < 0; + if (Eigen::numext::isinf(operand)) { + return is_negative ? std::numeric_limits::min() + : std::numeric_limits::max(); + } + if (Eigen::numext::isnan(operand)) { + return static_cast(0); + } + if (operand >= static_cast(std::numeric_limits::max())) { + return std::numeric_limits::max(); + } + if (operand <= static_cast(std::numeric_limits::min())) { + return std::numeric_limits::min(); + } + + operand = Eigen::numext::abs(operand); + + // Gets the integral piece of the floating point input. + auto truncated = static_cast(operand); + + // Removes the integral piece to obtain the fractional piece. + Fp fractional = operand - static_cast(truncated); + if (fractional == Fp{0}) { + // No rounding necessary. + return is_negative ? -truncated : truncated; + } + + // Compares fractional values against unsigned random values by + // normalizing random values into [0, 1): fractional vs. (random / + // random_max). This equals to comparing (fractional * random_max) vs. + // random. + auto fixed_fractional = static_cast(std::ldexp( + static_cast(fractional), std::numeric_limits::digits)); + + // Rounds the integer output up if the fractional pieces is larger than + // the input random number. + if (random < fixed_fractional) { + // This only happens when the operand is in the (min, -max) range and + // should be rounded to min. + if (truncated == std::numeric_limits::max()) { + return std::numeric_limits::min(); + } + truncated++; + } + return is_negative ? 
-truncated : truncated; + }; + + Literal result(result_shape); + TF_RETURN_IF_ERROR( + result.Populate([&](absl::Span multi_index) { + return stochastic_convert_op(operand_literal.Get(multi_index), + random_literal.Get(multi_index)); + })); + return std::move(result); + } + + // Converts from primitive types to native types. + template + StatusOr StochasticConvertOp(const Literal& operand_literal, + const Literal& random_literal, + const Shape& result_shape) { + return StochasticConvertOp< + typename primitive_util::PrimitiveTypeToNative::type, + typename primitive_util::PrimitiveTypeToNative::type, + typename primitive_util::PrimitiveTypeToNative::type>( + operand_literal, random_literal, result_shape); + } + + // Evaluates all possible paths of converting to different integers. + template + StatusOr StochasticConvertOp(const Literal& operand_literal, + const Literal& random_literal, + const Shape& result_shape) { + switch (result_shape.element_type()) { +#define CONVERT_IF_RESULT_TYPES_MATCH(type) \ + case (type): \ + return StochasticConvertOp( \ + operand_literal, random_literal, result_shape); + CONVERT_IF_RESULT_TYPES_MATCH(S32) + CONVERT_IF_RESULT_TYPES_MATCH(S16) + CONVERT_IF_RESULT_TYPES_MATCH(S8) +#undef CONVERT_IF_RESULT_TYPES_MATCH + default: + break; + } + // TODO(b/232442915): Enable converting big floats to small floats. + return Unimplemented( + "Stochastically converting from type %s to type %s is not implemented.", + PrimitiveType_Name(operand_literal.shape().element_type()), + PrimitiveType_Name(result_shape.element_type())); + } + + StatusOr StochasticConvertOp(const Literal& operand_literal, + const Literal& random_literal, + const Shape& result_shape) { + switch (operand_literal.shape().element_type()) { + case F16: + return StochasticConvertOp(operand_literal, random_literal, + result_shape); + case BF16: + return StochasticConvertOp(operand_literal, random_literal, + result_shape); + case F32: + return StochasticConvertOp(operand_literal, random_literal, + result_shape); + case F64: + return StochasticConvertOp(operand_literal, random_literal, + result_shape); + default: + break; + } + // TODO(b/232442915): Enable converting big floats to small floats. + return Unimplemented( + "Stochastically converting from type %s to type %s is not implemented.", + PrimitiveType_Name(operand_literal.shape().element_type()), + PrimitiveType_Name(result_shape.element_type())); + } + template static bool IsShiftOutOfBounds(NativeT rhs) { using UnsignedT = std::make_unsigned_t; @@ -2546,11 +2624,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // instantiating it. We explicitly instantiate this class in the various // hlo_evaluator_typed_visitor*.cc files. 
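// [Sketch] Hedged illustration of the rounding decision inside
// StochasticConvertOp above: the fractional part f in [0, 1) is scaled to the
// full range of the random integer (f * 2^digits via ldexp) and the result is
// rounded up exactly when the random draw falls below that threshold, so the
// probability of rounding up equals f. Sign, NaN, infinity, and saturation
// handling from the real implementation are omitted here.
#include <cmath>
#include <cstdint>
#include <limits>

template <typename Uint>
int64_t StochasticRoundSketch(double operand, Uint random_bits) {
  auto truncated = static_cast<int64_t>(operand);  // integral part
  const double fractional = operand - static_cast<double>(truncated);
  const auto threshold = static_cast<Uint>(
      std::ldexp(fractional, std::numeric_limits<Uint>::digits));
  return random_bits < threshold ? truncated + 1 : truncated;
}
// Example: operand = 2.25 -> threshold = 0.25 * 2^digits, so roughly a quarter
// of uniformly drawn random_bits values round up to 3; the rest keep 2.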
extern template class HloEvaluatorTypedVisitor; -extern template class HloEvaluatorTypedVisitor; -extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; -extern template class HloEvaluatorTypedVisitor; -extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; @@ -2558,6 +2638,8 @@ extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; } // namespace xla diff --git a/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc b/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc new file mode 100644 index 00000000000..3c878b900a6 --- /dev/null +++ b/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc @@ -0,0 +1,22 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator.h" +#include "tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int16.cc b/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int16.cc index 25a7cba94ad..775c491f4c4 100644 --- a/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int16.cc +++ b/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int16.cc @@ -17,5 +17,5 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h" namespace xla { -template class HloEvaluatorTypedVisitor; +template class HloEvaluatorTypedVisitor; } // namespace xla diff --git a/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int32.cc b/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int32.cc index 993db017845..a22c31d516f 100644 --- a/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int32.cc +++ b/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int32.cc @@ -17,5 +17,5 @@ limitations under the License. 
#include "tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h" namespace xla { -template class HloEvaluatorTypedVisitor; +template class HloEvaluatorTypedVisitor; } // namespace xla diff --git a/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int8.cc b/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int8.cc index b8a4b954c12..72c38b0db6c 100644 --- a/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int8.cc +++ b/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor_int8.cc @@ -17,5 +17,5 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h" namespace xla { -template class HloEvaluatorTypedVisitor; +template class HloEvaluatorTypedVisitor; } // namespace xla diff --git a/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor_uint16.cc b/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor_uint16.cc index 4cd243cde64..7c0b14e82fb 100644 --- a/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor_uint16.cc +++ b/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor_uint16.cc @@ -17,5 +17,5 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h" namespace xla { -template class HloEvaluatorTypedVisitor; +template class HloEvaluatorTypedVisitor; } // namespace xla diff --git a/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor_uint32.cc b/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor_uint32.cc index 5740ea49bfe..56559d8fd21 100644 --- a/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor_uint32.cc +++ b/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor_uint32.cc @@ -17,5 +17,5 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h" namespace xla { -template class HloEvaluatorTypedVisitor; +template class HloEvaluatorTypedVisitor; } // namespace xla diff --git a/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor_uint8.cc b/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor_uint8.cc index bee43f3feb5..4e94729ce86 100644 --- a/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor_uint8.cc +++ b/tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor_uint8.cc @@ -17,5 +17,5 @@ limitations under the License. 
#include "tensorflow/compiler/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h" namespace xla { -template class HloEvaluatorTypedVisitor; +template class HloEvaluatorTypedVisitor; } // namespace xla diff --git a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/BUILD b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/BUILD index c5d822d5abe..858bc589280 100644 --- a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/BUILD +++ b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/BUILD @@ -1,8 +1,11 @@ # Automatic sharding annotation -load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_binary") -package(default_visibility = [":friends"]) +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], +) package_group( name = "friends", @@ -16,21 +19,21 @@ cc_library( srcs = [ "auto_sharding.cc", "auto_sharding_dot_handler.cc", - "auto_sharding_util.cc", ], hdrs = [ "auto_sharding.h", - "auto_sharding_cost_graph.h", - "auto_sharding_strategy.h", - "auto_sharding_util.h", ], deps = [ - "//tensorflow/compiler/xla:array", - "//tensorflow/compiler/xla:shape_util", + ":auto_sharding_cost_graph", + ":auto_sharding_solver_option", + ":auto_sharding_strategy", + ":auto_sharding_util", + ":cluster_environment", + ":matrix", + ":metrics", + "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/service:dump", "//tensorflow/compiler/xla/service:heap_simulator", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_live_range", "//tensorflow/compiler/xla/service:hlo_memory_scheduler", "//tensorflow/compiler/xla/service:hlo_ordering", "//tensorflow/compiler/xla/service:hlo_pass", @@ -38,27 +41,110 @@ cc_library( "//tensorflow/compiler/xla/service:sharding_propagation", "//tensorflow/tsl/platform:errors", "//tensorflow/tsl/platform:status", - "//tensorflow/tsl/protobuf:error_codes_proto_impl_cc", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", "@com_google_ortools//ortools/linear_solver", "@com_google_ortools//ortools/linear_solver:linear_solver_cc_proto", ], ) -tf_cc_binary( +cc_library( + name = "auto_sharding_strategy", + hdrs = [ + "auto_sharding_strategy.h", + ], + deps = [ + "//tensorflow/compiler/xla/hlo/ir:hlo", + "//tensorflow/compiler/xla/service:hlo_value", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "auto_sharding_cost_graph", + hdrs = [ + "auto_sharding_cost_graph.h", + ], + deps = [ + ":auto_sharding_strategy", + ":matrix", + ], +) + +cc_library( + name = "matrix", + hdrs = [ + "matrix.h", + ], + deps = [ + "//tensorflow/tsl/platform:logging", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "cluster_environment", + srcs = ["cluster_environment.cc"], + hdrs = ["cluster_environment.h"], + deps = [ + ":auto_sharding_solver_option", + ":auto_sharding_util", + ":profiling_result", + ], +) + +cc_library( + name = "profiling_result", + hdrs = ["profiling_result.h"], + deps = [":auto_sharding_strategy"], +) + +cc_library( + name = "auto_sharding_solver_option", + hdrs = ["auto_sharding_solver_option.h"], +) + +cc_library( + name = "auto_sharding_util", + srcs = [ + "auto_sharding_util.cc", + ], + hdrs = [ + "auto_sharding_util.h", + ], + deps = [ 
+ ":auto_sharding_strategy", + "//tensorflow/compiler/xla:array", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/hlo/ir:hlo", + "//tensorflow/compiler/xla/service:hlo_sharding_util", + "//tensorflow/tsl/platform:errors", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "metrics", + srcs = ["metrics.cc"], + hdrs = ["metrics.h"], + deps = ["//tensorflow/tsl/lib/monitoring:counter"], +) + +xla_cc_binary( name = "auto_sharding_runner", srcs = ["auto_sharding_runner.cc"], deps = [ ":auto_sharding", "//tensorflow/compiler/xla:status", - "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/hlo/ir:hlo", "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/stream_executor:dnn", "//tensorflow/compiler/xla/tools:hlo_module_loader", "//tensorflow/tsl/platform:platform_port", ], diff --git a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.cc index c8a4da0b3dc..7c44139bf74 100644 --- a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -36,20 +36,27 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h" #include "tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" #include "tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_util.h" +#include "tensorflow/compiler/xla/hlo/experimental/auto_sharding/cluster_environment.h" +#include "tensorflow/compiler/xla/hlo/experimental/auto_sharding/matrix.h" +#include "tensorflow/compiler/xla/hlo/experimental/auto_sharding/metrics.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_sharding.h" #include "tensorflow/compiler/xla/service/dump.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include "tensorflow/compiler/xla/service/hlo_sharding.h" #include "tensorflow/compiler/xla/service/hlo_sharding_util.h" #include "tensorflow/compiler/xla/service/sharding_propagation.h" #include "tensorflow/tsl/platform/errors.h" #include "tensorflow/tsl/platform/status.h" #include "ortools/linear_solver/linear_solver.h" #include "ortools/linear_solver/linear_solver.pb.h" +#ifdef PLATFORM_GOOGLE +#include "file/base/helpers.h" +#include "util/task/status.pb.h" +#endif using MPConstraint = operations_research::MPConstraint; using MPSolver = operations_research::MPSolver; @@ -410,15 +417,9 @@ void EnumerateAll1DPartition(const HloInstruction* ins, const Shape& shape, const CallGraph& call_graph) { for (int64_t i = 0; i < shape.rank(); ++i) { for (int64_t j = 0; j < device_mesh.num_dimensions(); ++j) { - // Split one dim only when the tensor shape is divisable by device mesh. - // TODO(b/220942808) Shard non-divisible dimensions. 
- if (device_mesh.dim(j) == 1 || - !IsDivisible(shape.dimensions(i), device_mesh.dim(j))) { - continue; - } - - if (only_allow_divisible && - shape.dimensions(i) % device_mesh.dim(j) != 0) { + if (device_mesh.dim(j) == 1 || shape.dimensions(i) < device_mesh.dim(j) || + (only_allow_divisible && + !IsDivisible(shape.dimensions(i), device_mesh.dim(j)))) { continue; } @@ -435,6 +436,11 @@ void EnumerateAll1DPartition(const HloInstruction* ins, const Shape& shape, resharding_costs = ReshardingCostsForTupleOperand( ins->operand(0), strategy_map.at(ins->operand(0)).get()); LOG(INFO) << absl::StrJoin(resharding_costs.back(), ","); + } else if (ins->opcode() == HloOpcode::kRngBitGenerator && + ins->operand(0)->shape().IsArray()) { + resharding_costs = GenerateReshardingCostsForAllOperands( + ins, output_spec, strategy_map, cluster_env, call_graph, + {HloSharding::Replicate()}); } else { resharding_costs = GenerateReshardingCostsForAllOperands( ins, output_spec, strategy_map, cluster_env, call_graph); @@ -479,9 +485,10 @@ void EnumerateAll2DPartition(const HloInstruction* ins, const Shape& shape, } if (only_allow_divisible && - (shape.dimensions(i) % device_mesh.dim(shardable_mesh_dims[0]) != 0 || - shape.dimensions(j) % device_mesh.dim(shardable_mesh_dims[1]) != - 0)) { + (!IsDivisible(shape.dimensions(i), + device_mesh.dim(shardable_mesh_dims[0])) || + !IsDivisible(shape.dimensions(j), + device_mesh.dim(shardable_mesh_dims[1])))) { continue; } @@ -522,14 +529,15 @@ void EnumerateAll1DPartitionReshape(const HloInstruction* ins, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, std::unique_ptr& strategies, + bool only_allow_divisible, const std::string& suffix) { const HloInstruction* operand = ins->operand(0); for (int64_t i = 0; i < ins->shape().rank(); ++i) { for (int64_t j = 0; j < device_mesh.num_dimensions(); ++j) { - // TODO(b/220942808) Shard non-divisible dimensions. 
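For context, the divisibility gating used throughout these hunks reduces to a plain modulus test. A minimal standalone sketch, assuming IsDivisible(n, d) simply means n % d == 0 (IsDivisibleSketch and SkipCandidateSketch below are illustrative names, not part of the patch):

#include <cstdint>

// A tensor dimension can be tiled over a mesh dimension without padding only
// when its length divides evenly.
inline bool IsDivisibleSketch(int64_t tensor_dim, int64_t mesh_dim) {
  return mesh_dim != 0 && tensor_dim % mesh_dim == 0;
}

// Mirrors the candidate filter above: skip trivial mesh dimensions, and skip
// non-divisible combinations when only divisible shardings are allowed.
inline bool SkipCandidateSketch(int64_t tensor_dim, int64_t mesh_dim,
                                bool only_allow_divisible) {
  return mesh_dim == 1 ||
         (only_allow_divisible && !IsDivisibleSketch(tensor_dim, mesh_dim));
}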
if (device_mesh.dim(j) == 1 || - !IsDivisible(ins->shape().dimensions(i), device_mesh.dim(j))) { + (only_allow_divisible && + !IsDivisible(ins->shape().dimensions(i), device_mesh.dim(j)))) { continue; } HloSharding output_spec = Tile(ins->shape(), {i}, {j}, device_mesh); @@ -572,7 +580,8 @@ void Enumerate2DPartitionReshape(const HloInstruction* ins, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, const InstructionBatchDimMap& batch_dim_map, - std::unique_ptr& strategies) { + std::unique_ptr& strategies, + bool only_allow_divisible) { std::vector shardable_mesh_dims = VectorGreaterThanOneElementIndices(device_mesh.dimensions()); auto iter = batch_dim_map.find(GetBatchDimMapKey(ins)); @@ -595,6 +604,13 @@ void Enumerate2DPartitionReshape(const HloInstruction* ins, device_mesh.dim(shardable_mesh_dims[1])) { continue; } + if (only_allow_divisible && + (!IsDivisible(ins->shape().dimensions(i), + device_mesh.dim(shardable_mesh_dims[0])) || + !IsDivisible(ins->shape().dimensions(j), + device_mesh.dim(shardable_mesh_dims[1])))) { + continue; + } HloSharding output_spec = Tile(ins->shape(), {i, j}, @@ -742,7 +758,7 @@ void DisableIncompatibleMixedMeshShapeAndForceBatchDim( .dimensions(iter.second)); } - if (batch_size % num_devices != 0) { + if (IsDivisible(batch_size, num_devices)) { if (solver_option.allow_mixed_mesh_shape) { solver_option.allow_mixed_mesh_shape = false; LOG(WARNING) @@ -760,7 +776,8 @@ StatusOr> CreateParameterStrategyVector( LeafStrategies& leaf_strategies, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, const AutoShardingSolverOption& solver_option, double replicated_penalty, - const InstructionBatchDimMap& batch_dim_map, const CallGraph& call_graph) { + const InstructionBatchDimMap& batch_dim_map, const CallGraph& call_graph, + bool only_allow_divisible) { std::unique_ptr strategies; if (shape.IsTuple()) { strategies = CreateTupleStrategyVector(instruction_id); @@ -770,14 +787,15 @@ StatusOr> CreateParameterStrategyVector( CreateParameterStrategyVector( ins, shape.tuple_shapes().at(i), instruction_id, leaf_strategies, cluster_env, strategy_map, solver_option, replicated_penalty, - batch_dim_map, call_graph) + batch_dim_map, call_graph, only_allow_divisible) .value()); } } else if (shape.IsArray()) { strategies = CreateLeafStrategyVector(instruction_id, ins, strategy_map, leaf_strategies); EnumerateAll1DPartition(ins, shape, cluster_env.device_mesh_, cluster_env, - strategy_map, strategies, true, "", call_graph); + strategy_map, strategies, only_allow_divisible, "", + call_graph); // Split 2 dims if (cluster_env.IsDeviceMesh2D()) { // NOTE(zhuohan): In full alpa, we only include 2D partition strategy @@ -785,8 +803,8 @@ StatusOr> CreateParameterStrategyVector( // this logic here since this pass might be used for // more general cases. 
EnumerateAll2DPartition(ins, shape, cluster_env.device_mesh_, cluster_env, - strategy_map, strategies, batch_dim_map, true, - call_graph); + strategy_map, strategies, batch_dim_map, + only_allow_divisible, call_graph); } if (solver_option.allow_mixed_mesh_shape && cluster_env.IsDeviceMesh2D()) { @@ -797,8 +815,8 @@ StatusOr> CreateParameterStrategyVector( // Split 1 dim, but for 1d mesh EnumerateAll1DPartition(ins, shape, cluster_env.device_mesh_1d_, - cluster_env, strategy_map, strategies, true, - " 1d", call_graph); + cluster_env, strategy_map, strategies, + only_allow_divisible, " 1d", call_graph); } if (solver_option.allow_replicated_parameters || strategies->leaf_vector.empty()) { @@ -829,14 +847,19 @@ bool ShardingIsComplete(const HloSharding& sharding, size_t total_num_devices) { // Two shardings shard the same dimension of a given tensor. bool ShardingIsConsistent(const HloSharding& partial_sharding, - const HloSharding& complete_sharding) { + const HloSharding& complete_sharding, bool strict) { if (partial_sharding.tile_assignment().num_dimensions() > complete_sharding.tile_assignment().num_dimensions()) { return false; } for (size_t i = 0; i < partial_sharding.tile_assignment().num_dimensions(); ++i) { - if (partial_sharding.tile_assignment().dim(i) > 1 && + if (strict && partial_sharding.tile_assignment().dim(i) > 1 && + partial_sharding.tile_assignment().dim(i) == + complete_sharding.tile_assignment().dim(i)) { + return true; + } + if (!strict && partial_sharding.tile_assignment().dim(i) > 1 && complete_sharding.tile_assignment().dim(i) > 1) { return true; } @@ -861,13 +884,13 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding( const std::vector instructions, const HloSharding& existing_sharding, const ClusterEnvironment& cluster_env, StableHashMap>& trimmed_strategy_map, - const CallGraph& call_graph) { + const CallGraph& call_graph, bool strict) { if (strategies->is_tuple) { for (size_t i = 0; i < strategies->childs.size(); ++i) { TrimOrGenerateStrategiesBasedOnExistingSharding( output_shape.tuple_shapes(i), strategies->childs.at(i).get(), strategy_map, instructions, existing_sharding.tuple_elements().at(i), - cluster_env, trimmed_strategy_map, call_graph); + cluster_env, trimmed_strategy_map, call_graph, strict); } } else { if (ShardingIsComplete(existing_sharding, @@ -923,13 +946,17 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding( ShardingStrategy({name, existing_sharding, 0, 0, memory_cost, resharding_costs, input_shardings})); } - } else { + } else if (!strategies->following) { // If existing sharding is a partial sharding from previous iteration, // find the strategies that are 1D&&complete or align with user // sharding. + // It is IMPORTANT that we do this only for instructions that do no follow + // others, to keep the number of ILP variable small. std::vector new_vector; for (const auto& strategy : strategies->leaf_vector) { - if (ShardingIsConsistent(existing_sharding, strategy.output_sharding) || + if (strategy.output_sharding.IsReplicated() || + ShardingIsConsistent(existing_sharding, strategy.output_sharding, + strict) || (VectorGreaterThanOneElementCount( strategy.output_sharding.tile_assignment().dimensions()) == 1 && @@ -942,7 +969,9 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding( // If no sharding strategy left, just keep the original set, because we do // not have to strictly keep those shardings and the only purpose is to // reduce problem size for the last iteration. 
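To make the new strict/relaxed modes of ShardingIsConsistent concrete, here is a small sketch over raw tile-assignment dimensions (illustration only; ConsistentSketch is a hypothetical helper, the real check operates on HloSharding):

#include <cstdint>
#include <vector>

// Returns true when the partial and candidate shardings shard a common
// dimension. In strict mode the tile counts must match exactly on that
// dimension; in relaxed mode both merely need to be greater than one.
bool ConsistentSketch(const std::vector<int64_t>& partial,
                      const std::vector<int64_t>& candidate, bool strict) {
  if (partial.size() > candidate.size()) return false;
  for (size_t i = 0; i < partial.size(); ++i) {
    if (strict && partial[i] > 1 && partial[i] == candidate[i]) return true;
    if (!strict && partial[i] > 1 && candidate[i] > 1) return true;
  }
  return false;
}

For example, with partial = {2, 1}, a candidate {4, 1} passes only the relaxed test, while {2, 2} passes both; this is why the strict search space keeps fewer candidate strategies per instruction.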
-      if (!new_vector.empty()) {
+      if (!new_vector.empty() &&
+          new_vector.size() != strategies->leaf_vector.size()) {
+        strategies->following = nullptr;
         strategies->leaf_vector = std::move(new_vector);
       }
     }
@@ -1062,6 +1091,7 @@ bool LeafVectorsAreConsistent(const std::vector& one,
 // Build possible sharding strategies and their costs for all instructions.
 StatusOr>
 BuildStrategyAndCost(const HloInstructionSequence& sequence,
+                     const HloModule* module,
                      const InstructionDepthMap& depth_map,
                      const InstructionBatchDimMap& batch_dim_map,
                      const AliasMap& alias_map,
@@ -1114,15 +1144,31 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence,
     std::unique_ptr strategies;
     HloOpcode opcode = ins->opcode();
+
+    bool only_allow_divisible;
+    if (IsEntryComputationInputOrOutput(module, ins)) {
+      // With IsEntryComputationInputOrOutput(module, ins) == true, entry
+      // computation's root instruction may still be unevenly sharded because it
+      // usually "follows" other instruction's sharding. If the instruction it
+      // follows is an intermediate instruction, it may be able to choose
+      // unevenly sharded strategies. Usually if we constrain input's sharding
+      // strategies, outputs would be constrained as well, but if outputs are
+      // still unevenly sharded in some cases, we need to fix the implementation
+      // in auto sharding.
+      only_allow_divisible = solver_option.only_allow_divisible_input_output;
+    } else {
+      only_allow_divisible = solver_option.only_allow_divisible_intermediate;
+    }
     switch (opcode) {
       case HloOpcode::kParameter:
       case HloOpcode::kRngBitGenerator:
       case HloOpcode::kRng: {
-        strategies = CreateParameterStrategyVector(
-                         ins, ins->shape(), instruction_id, leaf_strategies,
-                         cluster_env, strategy_map, solver_option,
-                         replicated_penalty, batch_dim_map, call_graph)
-                         .value();
+        strategies =
+            CreateParameterStrategyVector(
+                ins, ins->shape(), instruction_id, leaf_strategies, cluster_env,
+                strategy_map, solver_option, replicated_penalty, batch_dim_map,
+                call_graph, only_allow_divisible)
+                .value();
         break;
       }
       case HloOpcode::kConstant: {
@@ -1157,7 +1203,9 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence,
         // mesh.
         // TODO(b/220942808) Shard non-divisible dimensions.
if (device_mesh.dim(j) == 1 || - !IsDivisible(shape.dimensions(index_dim), device_mesh.dim(j))) { + (only_allow_divisible && + !IsDivisible(shape.dimensions(index_dim), + device_mesh.dim(j)))) { continue; } std::string name = absl::StrCat("S", std::to_string(index_dim), @@ -1206,16 +1254,18 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, CHECK(!operand_strategies->is_tuple); if (ins->shape().rank() == 1 || cluster_env.IsDeviceMesh1D()) { EnumerateAll1DPartition(ins, ins->shape(), cluster_env.device_mesh_, - cluster_env, strategy_map, strategies, true, - "", call_graph); + cluster_env, strategy_map, strategies, + only_allow_divisible, "", call_graph); } else { EnumerateAll2DPartition(ins, ins->shape(), cluster_env.device_mesh_, cluster_env, strategy_map, strategies, - batch_dim_map, true, call_graph); + batch_dim_map, only_allow_divisible, + call_graph); if (solver_option.allow_mixed_mesh_shape) { - EnumerateAll1DPartition( - ins, ins->shape(), cluster_env.device_mesh_1d_, cluster_env, - strategy_map, strategies, true, "1d", call_graph); + EnumerateAll1DPartition(ins, ins->shape(), + cluster_env.device_mesh_1d_, cluster_env, + strategy_map, strategies, + only_allow_divisible, "1d", call_graph); } } AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, @@ -1278,19 +1328,21 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, // Split 1 dim if (cluster_env.IsDeviceMesh1D()) { EnumerateAll1DPartitionReshape(ins, device_mesh, cluster_env, - strategy_map, strategies, ""); + strategy_map, strategies, + only_allow_divisible, ""); } if (solver_option.allow_mixed_mesh_shape && cluster_env.IsDeviceMesh2D()) { // Split 1 dim, but for 1d mesh EnumerateAll1DPartitionReshape(ins, device_mesh_1d, cluster_env, - strategy_map, strategies, " 1d"); + strategy_map, strategies, + only_allow_divisible, " 1d"); } if (cluster_env.IsDeviceMesh2D()) { // Split 2 dim, one is always the batch dim Enumerate2DPartitionReshape(ins, device_mesh, cluster_env, - strategy_map, batch_dim_map, - strategies); + strategy_map, batch_dim_map, strategies, + only_allow_divisible); } // Replicate @@ -1474,6 +1526,7 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, case HloOpcode::kSin: case HloOpcode::kSqrt: case HloOpcode::kCbrt: + case HloOpcode::kTan: case HloOpcode::kTanh: // Binary elementwise operations case HloOpcode::kAdd: @@ -1600,14 +1653,14 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, leaf_strategies); if (cluster_env.IsDeviceMesh1D()) { EnumerateAll1DPartition(ins, ins->shape(), device_mesh, cluster_env, - strategy_map, strategies, false, "", - call_graph); + strategy_map, strategies, + only_allow_divisible, "", call_graph); } if (cluster_env.IsDeviceMesh2D()) { // Split 2 dims EnumerateAll2DPartition(ins, ins->shape(), device_mesh, cluster_env, strategy_map, strategies, batch_dim_map, - false, call_graph); + only_allow_divisible, call_graph); } if (cluster_env.IsDeviceMesh2D() && solver_option.allow_mixed_mesh_shape) { @@ -1615,8 +1668,8 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, // For example, when the mesh shape is (2, 4), we add strategies for // mesh shape (1, 8) here in addition. 
EnumerateAll1DPartition(ins, ins->shape(), device_mesh_1d, - cluster_env, strategy_map, strategies, false, - " 1d", call_graph); + cluster_env, strategy_map, strategies, + only_allow_divisible, " 1d", call_graph); } if (strategies->leaf_vector.empty() || IsFollowedByBroadcast(ins)) { @@ -1661,8 +1714,12 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, /* have_memory_cost= */ true, leaf_strategies, cluster_env, trimmed_strategy_map); } else if (ins->has_sharding()) { - strategies = CreateLeafStrategyVector(instruction_id, ins, - strategy_map, leaf_strategies); + if (ins->shape().IsTuple()) { + strategies = CreateTupleStrategyVector(instruction_id); + } else { + strategies = CreateLeafStrategyVector( + instruction_id, ins, strategy_map, leaf_strategies); + } } else if (OutputInputSameShapes(ins)) { auto* partitioner = GetCustomCallPartitioner(ins->custom_call_target()); @@ -1678,10 +1735,25 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, trimmed_strategy_map); } } else { - strategies = CreateLeafStrategyVector(instruction_id, ins, - strategy_map, leaf_strategies); - AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, - strategies, replicated_penalty); + // TODO (b/258723035) Handle CustomCall ops for GPUs in a better way. + if (ins->shape().IsTuple()) { + strategies = CreateTupleStrategyVector(instruction_id); + strategies->childs.reserve(ins->shape().tuple_shapes_size()); + for (size_t i = 0; i < ins->shape().tuple_shapes_size(); ++i) { + std::unique_ptr child_strategies = + CreateLeafStrategyVector(instruction_id, ins, strategy_map, + leaf_strategies); + AddReplicatedStrategy(ins, ins->shape().tuple_shapes(i), + cluster_env, strategy_map, child_strategies, + replicated_penalty); + strategies->childs.push_back(std::move(child_strategies)); + } + } else { + strategies = CreateLeafStrategyVector( + instruction_id, ins, strategy_map, leaf_strategies); + AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, + strategies, replicated_penalty); + } } break; } @@ -1709,10 +1781,10 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, // Do not merge nodes if this one instruction has annotations. // TODO(b/208668853) If needed, we can make auto sharding faster by using // this sharding spec when merging node using strategies->following. - strategies->following = nullptr; TrimOrGenerateStrategiesBasedOnExistingSharding( ins->shape(), strategies.get(), strategy_map, instructions, - ins->sharding(), cluster_env, trimmed_strategy_map, call_graph); + ins->sharding(), cluster_env, trimmed_strategy_map, call_graph, + solver_option.nd_sharding_iteratively_strict_search_space); } if (!strategies->is_tuple && strategies->following) { if (!LeafVectorsAreConsistent( @@ -1736,6 +1808,7 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, } } } + RemoveInvalidShardingsWithShapes(ins->shape(), strategies.get()); XLA_VLOG_LINES(2, absl::StrCat("strategies:\n", strategies->ToString())); // Debug options: forcibly set the strategy of some instructions. @@ -1763,7 +1836,6 @@ BuildStrategyAndCost(const HloInstructionSequence& sequence, // Checks the shape of resharding_costs is valid. It will check fail if the // shape is not as expected. 
CheckReshardingCostsShape(strategies.get()); - RemoveInvalidShardingsWithShapes(ins->shape(), strategies.get()); CheckMemoryCosts(strategies.get(), ins->shape()); strategy_map[ins] = std::move(strategies); } // end of for loop @@ -1907,31 +1979,38 @@ CallORToolsSolver(int64_t N, int64_t M, const std::vector& s_len, const std::vector& instruction_names) { size_t num_edges = E.size(); - std::unique_ptr solver( - std::make_unique("", MPSolver::GLPK_MIXED_INTEGER_PROGRAMMING)); + int32_t num_workers = 32; + // SAT or SCIP + std::unique_ptr solver(std::make_unique("", MPSolver::GLPK_MIXED_INTEGER_PROGRAMMING)); CHECK(solver); solver->MutableObjective()->SetMinimization(); - + std::string solver_parameter_str; +#ifdef PLATFORM_GOOGLE + if (solver->ProblemType() == + operations_research::MPSolver::SAT_INTEGER_PROGRAMMING) { + // Set random_seed, interleave_search and share_binary_clauses for + // determinism, and num_workers for parallelism. + solver_parameter_str = absl::StrCat( + "share_binary_clauses:false,random_seed:1,interleave_" + "search:true,num_workers:", + num_workers); + solver->SetSolverSpecificParametersAsString(solver_parameter_str); + } +#endif // Create variables std::vector> s(N); std::vector> e(num_edges); size_t var_vector_cnt = 0; - size_t var_cnt = 0; for (size_t i = 0; i < N; ++i) { if (s_follow[i] < 0) { var_vector_cnt += 1; - var_cnt += s_len[i]; // Creates variables for instructions that do not follow others. solver->MakeBoolVarArray( s_len[i], absl::StrCat("s[", std::to_string(i), "]"), &s[i]); } } - VLOG(1) << "Total variables for ILP: " << var_cnt - << ", total vector of variables: " << var_vector_cnt - << ", total instructions: " << N; - for (size_t i = 0; i < N; ++i) { if (s_follow[i] >= 0) { // Copies the variable of followed instruction to the following @@ -2046,6 +2125,7 @@ CallORToolsSolver(int64_t N, int64_t M, const std::vector& s_len, } } } + // d. specified via "BoolVarArray" // e. for (size_t i = 0; i < num_edges; ++i) { @@ -2098,21 +2178,50 @@ CallORToolsSolver(int64_t N, int64_t M, const std::vector& s_len, } } } - // Solve - VLOG(1) << "Total number of ILP constraints: " << solver->NumConstraints(); + +#ifdef PLATFORM_GOOGLE + // Exports the model for debugging. + bool dump_model = false; + if (dump_model) { + operations_research::MPModelProto model_proto; + solver->ExportModelToProto(&model_proto); + auto write_status = file::SetTextProto( + // Modify this file path if needed. + absl::StrCat("/tmp/model_", solver->NumVariables(), ".proto"), + model_proto, file::Defaults()); + if (!write_status.ok()) { + LOG(ERROR) << write_status.message(); + } + } +#endif + solver->set_time_limit(3600 * 1000); // in ms + VLOG(0) << "Starting solver " << solver->ProblemType() << "\n" + << "Solver parameter string: " << solver_parameter_str << "\n" + << "Number of workers: " << num_workers << "\n" + << "Number of threads: " << solver->GetNumThreads() << "\n" + << "Time limit: " << solver->time_limit() << "\n" + << "Number variables for ILP: " << solver->NumVariables() << "\n" + << "Total vector of variables: " << var_vector_cnt << "\n" + << "Total instructions: " << N << "\n" + << "Memory budget: " << M / (1024 * 1024 * 1024) << "GB\n" + << "Number of ILP constraints: " << solver->NumConstraints(); auto status = solver->Solve(); if (status == operations_research::MPSolver::INFEASIBLE) { LOG(ERROR) << "MPSolver could not find any feasible solution."; - /* - // TODO (zhuohan): Move this part of code to a non-open sourced position. 
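The hunk above adds CP-SAT-specific parameters and an explicit time limit to the MPSolver-based ILP. Below is a self-contained sketch of that MPSolver usage pattern, assuming OR-Tools is linked and built with the CP-SAT backend; the tiny two-variable model is illustrative only and is not the auto-sharding ILP:

#include <iostream>
#include <vector>

#include "ortools/linear_solver/linear_solver.h"

int main() {
  using operations_research::MPConstraint;
  using operations_research::MPSolver;
  using operations_research::MPVariable;

  MPSolver solver("sketch", MPSolver::SAT_INTEGER_PROGRAMMING);
  // Same style of parameter string as built above; values are examples only.
  solver.SetSolverSpecificParametersAsString(
      "share_binary_clauses:false,random_seed:1,interleave_search:true,"
      "num_workers:8");
  solver.set_time_limit(60 * 1000);  // in ms

  // Pick exactly one of two boolean "strategies" and minimize its cost.
  std::vector<MPVariable*> s;
  solver.MakeBoolVarArray(2, "s", &s);
  MPConstraint* pick_one = solver.MakeRowConstraint(1.0, 1.0);
  pick_one->SetCoefficient(s[0], 1.0);
  pick_one->SetCoefficient(s[1], 1.0);
  solver.MutableObjective()->SetCoefficient(s[0], 2.0);
  solver.MutableObjective()->SetCoefficient(s[1], 3.0);
  solver.MutableObjective()->SetMinimization();

  if (solver.Solve() == MPSolver::OPTIMAL) {
    std::cout << "s[0]=" << s[0]->solution_value()
              << " s[1]=" << s[1]->solution_value() << "\n";
  }
  return 0;
}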
- // Need to include "util/task/status.pb.h" +#ifdef PLATFORM_GOOGLE operations_research::MPModelRequest model_request; solver->ExportModelToProto(model_request.mutable_model()); - model_request.set_solver_type( - operations_research::MPModelRequest::SCIP_MIXED_INTEGER_PROGRAMMING); + if (solver->ProblemType() == + operations_research::MPSolver::SAT_INTEGER_PROGRAMMING) { + model_request.set_solver_type( + operations_research::MPModelRequest::SAT_INTEGER_PROGRAMMING); + } else if (solver->ProblemType() == + operations_research::MPSolver::SCIP_MIXED_INTEGER_PROGRAMMING) { + model_request.set_solver_type( + operations_research::MPModelRequest::SCIP_MIXED_INTEGER_PROGRAMMING); + } model_request.set_solver_time_limit_seconds(100); auto iis = MPSolver::ComputeIrreducibleInfeasibleSubset(model_request); - LOG(INFO) << iis.status().DebugString(); LOG(INFO) << "Infeasible constraints: "; for (int index : iis.constraint_index()) { @@ -2123,7 +2232,7 @@ CallORToolsSolver(int64_t N, int64_t M, const std::vector& s_len, << " - " << model_request.model().general_constraint(index).DebugString(); } - */ +#endif return tsl::errors::Internal( "MPSolver could not find any feasible solution."); @@ -2328,10 +2437,9 @@ void CheckHloSharding(const HloInstructionSequence& sequence, size > 1) { LOG(INFO) << "Instruction is not fully sharded: (" << size << " GB) " << ins->ToString(); + } else if (!ins->has_sharding()) { + LOG(INFO) << "Instruction does not have sharding: " << ins->name(); } - } else if (!ins->has_sharding()) { - LOG(INFO) << "Instruction does not have sharding: " << ins->name(); - } for (const auto& op : ins->operands()) { if (op->has_sharding()) { if (op->sharding().IsReplicated() || ins->sharding().IsReplicated()) { @@ -2373,6 +2481,7 @@ void CheckHloSharding(const HloInstructionSequence& sequence, } } } + } struct { bool operator()(const std::pair& a, const std::pair& b) const { @@ -2712,9 +2821,12 @@ void SaveShardingForInstruction( absl::flat_hash_map>& preserve_shardings, HloInstruction* inst) { - if (inst->has_sharding() && !inst->sharding().IsTuple()) { + if (!inst->has_sharding()) { + return; + } + if (!inst->sharding().IsTuple()) { preserve_shardings[inst->name()] = {inst->sharding()}; - } else if (inst->has_sharding() && inst->sharding().IsTuple()) { + } else { preserve_shardings[inst->name()] = inst->sharding().tuple_elements(); } } @@ -2779,22 +2891,22 @@ void CheckUserShardingPreservation( if (preserve_shardings.find(inst->name()) == preserve_shardings.end()) { continue; } - if (!inst->has_sharding()) { - LOG(FATAL) << "User sharding is not preserved! Instruction with name " - << inst->name() << " should be: " - << preserve_shardings.at(inst->name())[0].ToString() - << "\nbut it's empty."; - } else if (!inst->sharding().IsTuple() && - preserve_shardings.at(inst->name())[0].ToString() != - inst->sharding().ToString()) { - LOG(FATAL) << "User sharding is not preserved! Instruction with name " - << inst->name() << " should be: " - << preserve_shardings.at(inst->name())[0].ToString() - << "\nbut it's: " << inst->sharding().ToString(); - } else if (inst->sharding().IsTuple()) { - const std::vector* preserve_shardings_tuple = - &preserve_shardings.at(inst->name()); - for (size_t i = 0; i < inst->shape().tuple_shapes_size(); i++) { + if (!inst->has_sharding()) { + LOG(FATAL) << "User sharding is not preserved! 
Instruction with name " + << inst->name() << " should be: " + << preserve_shardings.at(inst->name())[0].ToString() + << "\nbut it's empty."; + } else if (!inst->sharding().IsTuple() && + preserve_shardings.at(inst->name())[0].ToString() != + inst->sharding().ToString()) { + LOG(FATAL) << "User sharding is not preserved! Instruction with name " + << inst->name() << " should be: " + << preserve_shardings.at(inst->name())[0].ToString() + << "\nbut it's: " << inst->sharding().ToString(); + } else if (inst->sharding().IsTuple()) { + const std::vector* preserve_shardings_tuple = + &preserve_shardings.at(inst->name()); + for (size_t i = 0; i < inst->shape().tuple_shapes_size(); i++) { if (preserve_shardings_tuple->at(i).ToString() != inst->sharding().tuple_elements().at(i).ToString()) { LOG(FATAL) << "Tuple sharding is not preserved! Instruction " @@ -2805,8 +2917,8 @@ void CheckUserShardingPreservation( << "\nbut it's: " << inst->sharding().tuple_elements().at(i).ToString(); } - } } + } } } } @@ -2820,19 +2932,19 @@ int64_t MemoryBudgetLowerBound(const HloModule& module, for (const HloValue* value : liveness_set[t]) { size_t tmp; if (value->instruction()->shape().IsTuple() && value->index().empty()) { - continue; + continue; } Shape shape = ShapeUtil::GetSubshape(value->instruction()->shape(), value->index()); if (value->instruction()->has_sharding()) { - tmp = GetShardedInstructionSize( - shape, num_devices, - !value->index().empty() - ? value->instruction()->sharding().GetSubSharding( - value->instruction()->shape(), value->index()) - : value->instruction()->sharding()); + tmp = GetShardedInstructionSize( + shape, num_devices, + !value->index().empty() + ? value->instruction()->sharding().GetSubSharding( + value->instruction()->shape(), value->index()) + : value->instruction()->sharding()); } else { - tmp = GetShardedInstructionSize(shape, num_devices); + tmp = GetShardedInstructionSize(shape, num_devices); } memory_usage += tmp; } @@ -2862,6 +2974,669 @@ void RecoverShardingsFromPartialMesh( } } } +// DFS to find the replicated set starting from cur instruction. +void FindReplicateSet( + HloInstruction* cur, const AliasMap& alias_map, const CostGraph& cost_graph, + absl::Span s_val, const StrategyMap& strategy_map, + const ShardingStrategy& strategy, const HloInstruction* output, + bool do_all_gather_after_backward, HloInstruction*& transpose_inst, + StableHashSet& replicated_set, + StableHashSet& boundary_set, + StableHashSet& consumer_set, + StableHashSet& visited) { + visited.insert(cur); + + // Check whether the node is a boundary node. + StableHashSet users = UsersWithAlias(cur, alias_map, output); + for (HloInstruction* consumer : users) { + const HloInstruction* shape_inst = cur; + + // Allow at most one transpose + if (consumer->opcode() == HloOpcode::kTranspose && + (transpose_inst == nullptr || + DimensionsEqual(transpose_inst->shape(), consumer->shape()))) { + shape_inst = consumer; + transpose_inst = consumer; + // TODO(zhuohan): fix output_sharding comparison. + } + + if (consumer->opcode() == HloOpcode::kTuple || + (do_all_gather_after_backward && IsParameterConvert(consumer)) || + GetShardingStrategy(consumer, strategy_map, cost_graph, s_val) + .output_sharding != strategy.output_sharding || + !DimensionsEqual(consumer->shape(), shape_inst->shape())) { + boundary_set.insert(cur); + return; + } + } + + // If this node is not a boundary node, propagate from this node. 
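As a plain-language summary of the MemoryBudgetLowerBound hunk earlier in this file (the hunk itself only re-indents): per the surrounding code, the bound sums the sharded sizes of all values live at each program point, and the lower bound is the peak of those sums. A simplified sketch, ignoring tuples and using a uniform num_devices split when no sharding is annotated (MemoryLowerBoundSketch is an illustrative name):

#include <algorithm>
#include <cstdint>
#include <vector>

// liveness[t] holds the unsharded byte sizes of the values live at time t.
int64_t MemoryLowerBoundSketch(
    const std::vector<std::vector<int64_t>>& liveness, int64_t num_devices) {
  int64_t peak = 0;
  for (const auto& live_values : liveness) {
    int64_t usage = 0;
    for (int64_t bytes : live_values) usage += bytes / num_devices;
    peak = std::max(peak, usage);
  }
  return peak;
}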
+ replicated_set.insert(cur); + for (HloInstruction* consumer : users) { + if (!visited.contains(consumer)) { + consumer_set.insert(consumer); + FindReplicateSet(consumer, alias_map, cost_graph, s_val, strategy_map, + strategy, output, do_all_gather_after_backward, + transpose_inst, replicated_set, boundary_set, + consumer_set, visited); + } + } + + for (size_t i = 0; i < cur->operand_count(); ++i) { + HloInstruction* operand = cur->mutable_operand(i); + operand = PassThroughCustomCallMarkerOperand(operand, cur); + + if (!visited.contains(operand) && !IsAlwaysReplicated(operand) && + GetShardingStrategy(operand, strategy_map, cost_graph, s_val) + .output_sharding == strategy.output_sharding && + DimensionsEqual(operand->shape(), cur->shape())) { + FindReplicateSet(operand, alias_map, cost_graph, s_val, strategy_map, + strategy, output, do_all_gather_after_backward, + transpose_inst, replicated_set, boundary_set, + consumer_set, visited); + } + } +} + +// Substitute all-reduce strategies with their reduce-scatter variants. +void GenerateReduceScatter(const HloInstructionSequence& sequence, + const AliasMap& alias_map, + const InstructionDepthMap& depth_map, + const StrategyMap& strategy_map, + const CostGraph& cost_graph, + absl::Span s_val, + const ClusterEnvironment& cluster_env, + const AutoShardingSolverOption& solver_option) { + const std::vector& instructions = sequence.instructions(); + + // Propagation ends at output + const HloInstruction* output = instructions.back(); + if (IsCustomCallMarker(output)) { + output = output->operand(0); + } + + // A debug option: whether to do all-gather after backward pass. + // This controls the location of all-gather. + // If true, all-gather happens after backward pass, which is desired for + // gradient accumulation. If false, all-gather happens before forward pass, + // which can partitions more tensors. + bool do_all_gather_after_backward = true; + + // If true, do not actually generate reduce-scatter + all-gather, + // but generate all-reduce + all-gather instead. + // This saves less memory but is more friendly to gradient accumulation. + // This is a temporary workaround due to implementation difficulty. + // Ideally, we should be able to generate a gradient-accumulation-friendly + // reduce-scatter + all-gather, but for now it is not easy to implement this + // in our current system. So we generate a gradient-accumulation-friendly + // all-reduce + all-gather, which has the same memory consumption but with 50% + // communication overhead. + bool use_all_reduce_for_grad_acc = + solver_option.reduce_scatter_grad_acc_friendly; + + std::vector insert_all_gather; + StableHashSet modified; + + for (HloInstruction* inst : instructions) { + if (!HasReduceScatterOpportunity(inst, strategy_map, cost_graph, s_val, + modified)) { + continue; + } + const ShardingStrategy& strategy = + GetShardingStrategy(inst, strategy_map, cost_graph, s_val); + if (!absl::StrContains(strategy.name, "allreduce")) { + continue; + } + + StableHashSet replicated_set; + StableHashSet boundary_set; + StableHashSet consumer_set; + StableHashSet visited; + + // We allow at most one transpose in the path of replication analysis. + HloInstruction* transpose_inst = nullptr; + + // Find the replicated set starting from the all-reduce instruction. 
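The substitution performed by GenerateReduceScatter relies on the identity all_reduce(x) == all_gather(reduce_scatter(x)); keeping only the local shard resident until the (possibly deferred) all-gather is what saves memory. A toy, single-process illustration of that identity (not XLA code; names and the divisibility assumption are for illustration):

#include <cstdint>
#include <vector>

using Tensor = std::vector<int64_t>;

// Element-wise sum across devices: the result every device holds after an
// all-reduce.
Tensor AllReduce(const std::vector<Tensor>& per_device) {
  Tensor out(per_device[0].size(), 0);
  for (const Tensor& t : per_device)
    for (size_t i = 0; i < t.size(); ++i) out[i] += t[i];
  return out;
}

// Reduce-scatter leaves shard d of the sum on device d; the later all-gather
// is just concatenation of the shards, so (when the length divides evenly by
// the device count) the final contents match AllReduce exactly.
Tensor ReduceScatterThenAllGather(const std::vector<Tensor>& per_device) {
  const size_t n = per_device.size();
  const size_t shard = per_device[0].size() / n;  // assumes divisibility
  Tensor out(per_device[0].size(), 0);
  for (size_t d = 0; d < n; ++d)
    for (size_t i = d * shard; i < (d + 1) * shard; ++i)
      for (const Tensor& t : per_device) out[i] += t[i];
  return out;
}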
+ visited.insert(output); + FindReplicateSet(inst, alias_map, cost_graph, s_val, strategy_map, strategy, + output, do_all_gather_after_backward, transpose_inst, + replicated_set, boundary_set, consumer_set, visited); + + // Try to reduce the boundary set to its common ancestor + TryReduceWithCommonAncestor(replicated_set, boundary_set, consumer_set, + alias_map); + + // Analyze the instructions after which all-gather should be inserted. + std::vector need_all_gather; + for (HloInstruction* node : boundary_set) { + if (consumer_set.contains(node)) { + if (AllUsersAreReduce(node)) { + // If users are reduce, the all-gather cost after this instruction + // should be small, so we ignore all-gather cost of these + // instructions. + replicated_set.insert(node); + } else { + need_all_gather.push_back(node); + } + } + } + + // If we do all-gather on some parameters, move this all-gather after + // backward. + if (do_all_gather_after_backward && need_all_gather.size() == 1) { + HloInstruction* point = need_all_gather.front(); + std::vector path; + HloInstruction* root = point; + while (true) { + path.push_back(root); + if (root->opcode() == HloOpcode::kGetTupleElement) { + root = PassThroughCustomCallMarkerOperand(root->mutable_operand(0), + root); + } else { + break; + } + } + + if (root->opcode() == HloOpcode::kParameter) { + for (auto x : path) { + replicated_set.erase(x); + boundary_set.erase(x); + } + need_all_gather.clear(); + for (auto x : replicated_set) { + auto iter = alias_map.find(x); + if (iter != alias_map.end() && iter->second == root) { + boundary_set.insert(x); + need_all_gather.push_back(x); + break; + } + } + } + } + + // Analyze how many parameters can be partitioned if we do this + // transformation. + int num_replicated_parameters = 0; + for (const HloInstruction* node : replicated_set) { + if (node->opcode() == HloOpcode::kParameter) { + num_replicated_parameters++; + } + } + for (const HloInstruction* to_split : need_all_gather) { + if (to_split->users().size() == 1 && + to_split->users().front() == output && alias_map.contains(to_split)) { + // Move the all-gather to its alias parameter. + num_replicated_parameters++; + } + } + + // Print replicated set and boundary set for debugging. + VLOG(10) << inst->ToString(HloPrintOptions::ShortParsable()) << "\n"; + VLOG(10) << "replicated set (#parameter: " << num_replicated_parameters + << "):\n"; + for (auto x : replicated_set) { + VLOG(10) << " " << x->ToString(HloPrintOptions::ShortParsable()) << "\n"; + } + VLOG(10) << "boundary set (#incompatible: " << need_all_gather.size() + << "):\n"; + for (auto x : boundary_set) { + VLOG(10) << " " << x->ToString(HloPrintOptions::ShortParsable()) << " " + << absl::c_linear_search(need_all_gather, x) << "\n"; + } + + // If applicable, replace all-reduce with reduce-scatter by + // setting instructions' sharding. + if (num_replicated_parameters >= 1 && need_all_gather.size() <= 1 && + replicated_set.size() >= 5) { + HloSharding output_spec = + GetReduceScatterOutput(inst, strategy, cluster_env); + if (IsUndefined(output_spec)) { + continue; + } + + VLOG(10) << "SET: " << output_spec.ToString(); + + if (absl::StartsWith(strategy.name, "RR = RS x SR")) { + // If set the sharding for this dot instruction, the SPMD + // partitioner will generate bad fallback code. 
+ replicated_set.erase(inst); + } + + if (use_all_reduce_for_grad_acc) { + UseAllReduceForGradAcc(replicated_set, inst); + } + + for (HloInstruction* to_split : replicated_set) { + SetSharding(to_split, output_spec, inst, transpose_inst, modified); + } + + if (!solver_option.reduce_scatter_aggressive_partition) { + // The normal case + for (HloInstruction* to_split : need_all_gather) { + SetSharding(to_split, output_spec, inst, transpose_inst, modified); + + if (!do_all_gather_after_backward && to_split->users().size() == 1 && + to_split->users().front() == output && + alias_map.contains(to_split)) { + // Move the all-gather to its alias parameter. + // This partitions more tensors but introduces communication + // in the forward pass, which is not desired in gradient + // accumulation. + SetSharding(alias_map.at(to_split), output_spec, inst, + transpose_inst, modified); + insert_all_gather.push_back(alias_map.at(to_split)); + } else { + insert_all_gather.push_back(to_split); + + if (to_split->opcode() == HloOpcode::kGetTupleElement && + IsCustomCallMarker(to_split->operand(0)) && + to_split->users().size() == 1 && + to_split->users().front() == output) { + insert_all_gather.push_back(PassThroughCustomCallMarkerOperand( + to_split->mutable_operand(0), to_split)); + } + } + } + } else { + // Aggressively partition more parameter tensors. + // This can result in a strategy similar to ZeRO stage 3. + // NOTE: The combination of this branch with pipeline parallel is not + // tested. + for (HloInstruction* to_split : need_all_gather) { + SetSharding(to_split, output_spec, inst, transpose_inst, modified); + + if (to_split->users().size() == 1 && + to_split->users().front() == output && + alias_map.contains(to_split)) { + // Move the all-gather to its alias parameter. + HloInstruction* param = alias_map.at(to_split); + + // Find the branching point (i.e., skip elementwise ops like + // convert) + HloInstruction* cur = param; + while (cur->users().size() == 1) { + // TODO(zhuohan): handle tuple. + CHECK(cur->shape().IsArray()); + SetSharding(cur, output_spec, inst, transpose_inst, modified); + cur = cur->users().front(); + } + SetSharding(cur, output_spec, inst, transpose_inst, modified); + + CHECK(!cur->users().empty()); + + // Find the first user + HloInstruction* first_user = nullptr; + int64_t min_depth = ((int64_t)1) << 50; + for (const auto& x : cur->users()) { + auto iter = depth_map.find(x); + if (iter == depth_map.end()) { + LOG(FATAL) << "ERROR: " << x->ToString(); + } + if (x->opcode() != HloOpcode::kConvolution && + x->opcode() != HloOpcode::kDot) { + // Only apply this aggressive optimization for dot and conv + continue; + } + if (iter->second < min_depth) { + first_user = x; + min_depth = iter->second; + } + } + + if (first_user != nullptr) { + // Insert an identity to prevent CSE of all-gather + HloInstruction* identity = inst->parent()->AddInstruction( + HloInstruction::CreateCustomCall(cur->shape(), {cur}, + kIdentityMarker)); + SetSharding(identity, output_spec, inst, transpose_inst, + modified); + ReplaceOperand(first_user, cur, identity); + } + } + } + } + } + + VLOG(10) << "-----------------------done\n"; + } + + // Insert all-gather on the output of boundary nodes by setting + // their shardings. This also works as CSE of all-gather. 
+ for (HloInstruction* inst : insert_all_gather) { + HloInstruction* replace_with = inst->parent()->AddInstruction( + HloInstruction::CreateReshape(inst->shape(), inst)); + replace_with->set_sharding( + GetShardingStrategy(inst, strategy_map, cost_graph, s_val) + .output_sharding); + TF_CHECK_OK(inst->ReplaceAllUsesWith(replace_with)); + } +} + +void AnnotateShardingWithSimpleHeuristic( + HloModule* module, const std::string& heuristic, const AliasMap& alias_map, + const ClusterEnvironment& cluster_env) { + const Array& device_mesh = cluster_env.device_mesh_; + const Array& device_mesh_1d = cluster_env.device_mesh_1d_; + int64_t num_devices = device_mesh.num_elements(); + + // Count the non-one mesh dimension. + size_t mesh_nn_dims = 0; + for (int dim : device_mesh.dimensions()) { + if (dim > 1) { + mesh_nn_dims++; + } + } + + // Shard instructions + HloComputation* entry_computation = module->entry_computation(); + for (HloInstruction* inst : entry_computation->instructions()) { + if (inst->opcode() == HloOpcode::kParameter) { + HloSharding output_spec = HloSharding::Replicate(); + inst->set_sharding(output_spec); + + if (heuristic == "shard-largest") { + std::vector lengths; + for (int64_t i = 0; i < inst->shape().rank(); ++i) { + lengths.push_back(inst->shape().dimensions(i)); + } + + std::vector indices = Argsort(lengths); + int common_dims = std::min(mesh_nn_dims, indices.size()); + + if (common_dims < 1) { + continue; + } + + if (common_dims == 1) { + int dim = indices[0]; + int length = lengths[dim]; + if (length % num_devices == 0) { + output_spec = Tile(inst->shape(), {dim}, {0}, device_mesh_1d); + } + } else { + int dim1 = indices[0]; + int length1 = lengths[dim1]; + int dim0 = indices[1]; + int length0 = lengths[dim0]; + + if (length0 % device_mesh.dim(0) == 0 && + length1 % device_mesh.dim(1) == 0) { + output_spec = + Tile(inst->shape(), {dim0, dim1}, {0, 1}, device_mesh); + } + } + } else if (heuristic == "shard-first") { + if (inst->shape().rank() > 0 && + inst->shape().dimensions(0) % num_devices == 0) { + output_spec = Tile(inst->shape(), {0}, {0}, device_mesh_1d); + } + } else if (heuristic == "shard-last") { + int64_t last_dim = inst->shape().rank() - 1; + if (inst->shape().rank() > 0 && + inst->shape().dimensions(last_dim) % num_devices == 0) { + output_spec = Tile(inst->shape(), {last_dim}, {0}, device_mesh_1d); + } + } else { + LOG(FATAL) << "Invalid heuristic: " << heuristic; + } + + inst->set_sharding(output_spec); + // std::cerr << "ins: " << inst->ToString() << ", spec: " << + // output_spec.ToString() << std::endl; + } else if (inst->opcode() == HloOpcode::kDot) { + const HloInstruction* lhs = inst->operand(0); + const HloInstruction* rhs = inst->operand(1); + const DotDimensionNumbers& dot_dnums = inst->dot_dimension_numbers(); + // const auto& lhs_con_dims = dot_dnums.lhs_contracting_dimensions(); + // const auto& rhs_con_dims = dot_dnums.rhs_contracting_dimensions(); + std::vector lhs_space_dims, rhs_space_dims; + std::tie(lhs_space_dims, rhs_space_dims) = + GetSpaceDims(lhs->shape(), rhs->shape(), dot_dnums); + } + } + + // Meet the alias requirement for the output tuple. 
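A rough standalone sketch of the "shard-largest" idea used in AnnotateShardingWithSimpleHeuristic above; the real code goes through Argsort and Tile and checks only the largest dimensions, so the helper below is a simplified illustration with a made-up name:

#include <cstdint>
#include <vector>

// Pick the longest tensor dimension that divides evenly by the device count;
// -1 means no dimension qualifies and the parameter stays replicated.
int64_t PickLargestShardableDim(const std::vector<int64_t>& dims,
                                int64_t num_devices) {
  int64_t best = -1;
  for (int64_t i = 0; i < static_cast<int64_t>(dims.size()); ++i) {
    if (dims[i] % num_devices == 0 && (best < 0 || dims[i] > dims[best])) {
      best = i;
    }
  }
  return best;
}
// Example: dims {512, 1024, 8} with 8 devices selects dimension 1 (1024).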
+ HloInstruction* output = entry_computation->root_instruction(); + const Shape& out_shape = output->shape(); + ShapeTree tuple_sharding(out_shape, HloSharding::Replicate()); + std::vector flattened_shardings; + + std::function get_flattened_shardings; + get_flattened_shardings = [&](HloInstruction* cur) { + for (int64_t i = 0; i < cur->operand_count(); ++i) { + HloInstruction* operand = cur->mutable_operand(i); + + if (operand->shape().IsTuple()) { + get_flattened_shardings(operand); + } else { + if (alias_map.contains(operand)) { + operand = alias_map.at(operand); + } + if (!operand->has_sharding()) { + operand->set_sharding(HloSharding::Replicate()); + } + CHECK(operand->has_sharding()); + flattened_shardings.push_back(operand->sharding()); + } + } + }; + get_flattened_shardings(output); + int i = 0; + for (auto& leaf : tuple_sharding.leaves()) { + leaf.second = flattened_shardings[i++]; + } + CHECK_EQ(i, flattened_shardings.size()); + output->set_sharding(HloSharding::Tuple(tuple_sharding)); +} + +// Filter strategies according to the solver_option.force_batch_dim_to_mesh_dim. +// This can be used to forcibly generate data-parallel strategies. +Status FilterStrategy(const HloInstruction* ins, const Shape& shape, + std::unique_ptr& strategies, + const ClusterEnvironment& cluster_env, + const InstructionBatchDimMap& batch_map, + const AutoShardingSolverOption& solver_option) { + int mesh_dim = solver_option.force_batch_dim_to_mesh_dim; + int batch_dim = batch_map.at(GetBatchDimMapKey(ins)); + const Array& device_mesh = cluster_env.device_mesh_; + + if (shape.dimensions(batch_dim) % device_mesh.dim(mesh_dim) != 0) { + return tsl::errors::InvalidArgument( + "The length of batch dimension is " + "not divisible by the number of devices"); + } + + std::vector new_leaf_vector; + for (auto& stra : strategies->leaf_vector) { + std::vector tensor_dim_to_mesh_dim = + cluster_env.GetTensorDimToMeshDimWrapper(shape, stra.output_sharding); + + if (device_mesh.dim(mesh_dim) > 1) { + // If the mesh dim is not one, the output tensor must be + // tiled along the mesh dim. + if (tensor_dim_to_mesh_dim[batch_dim] == mesh_dim) { + new_leaf_vector.push_back(std::move(stra)); + } + } else { + // If the mesh dim is one, the output tensor must be replicated + // on the mesh dim. + if (tensor_dim_to_mesh_dim[batch_dim] == -1) { + new_leaf_vector.push_back(std::move(stra)); + } + } + } + CHECK(!new_leaf_vector.empty()) + << ins->ToString() << " does not have any valid strategies"; + strategies->leaf_vector = std::move(new_leaf_vector); + + return OkStatus(); +} + +// Return the output sharding of the reduce-scatter variant of a given strategy. +HloSharding GetReduceScatterOutput(const HloInstruction* ins, + const ShardingStrategy& strategy, + const ClusterEnvironment& cluster_env) { + const Array& device_mesh = cluster_env.device_mesh_; + const Array& device_mesh_1d = cluster_env.device_mesh_1d_; + + if (ins->opcode() == HloOpcode::kDot) { + const DotDimensionNumbers& dot_dnums = ins->dot_dimension_numbers(); + int64_t space_base_dim = dot_dnums.lhs_batch_dimensions_size(); + + if (absl::StartsWith(strategy.name, "SR = SS x SR") || + absl::StartsWith(strategy.name, "RS = RS x SS")) { + int mesh_dim0, mesh_dim1; + std::tie(mesh_dim0, mesh_dim1) = ParseMeshDims(strategy.name); + + if (!IsDivisible(ins, device_mesh, {space_base_dim, space_base_dim + 1}, + {mesh_dim0, mesh_dim1})) { + // XLA supports uneven partitioning by adding padding. 
+ // However, the ShardingSpec in Jax does not support uneven + // partitioning. + return Undefined(); + } + + return Tile(ins->shape(), {space_base_dim, space_base_dim + 1}, + {mesh_dim0, mesh_dim1}, device_mesh); + } + if (absl::StartsWith(strategy.name, "SbR = SbSk x SbSk")) { + int mesh_dim0, mesh_dim1; + std::tie(mesh_dim0, mesh_dim1) = ParseMeshDims(strategy.name); + + if (!IsDivisible(ins, device_mesh, {0, space_base_dim}, + {mesh_dim0, mesh_dim1})) { + // XLA supports uneven partitioning by adding padding. + // However, the ShardingSpec in Jax does not support uneven + // partitioning. + return Undefined(); + } + + return Tile(ins->shape(), {0, space_base_dim}, {mesh_dim0, mesh_dim1}, + device_mesh); + } + if (absl::StartsWith(strategy.name, "RR = RS x SR")) { + int mesh_dim = absl::StrContains(strategy.name, "{0}") ? 0 : 1; + + if (!IsDivisible(ins, device_mesh, {space_base_dim}, {mesh_dim})) { + return Undefined(); + } + + return Tile(ins->shape(), {space_base_dim}, {mesh_dim}, device_mesh); + } + if (absl::StartsWith(strategy.name, "R = Sk x Sk")) { + int mesh_dim = 0; + + if (!IsDivisible(ins, device_mesh_1d, {space_base_dim}, {mesh_dim})) { + return Undefined(); + } + + return Tile(ins->shape(), {space_base_dim}, {mesh_dim}, device_mesh_1d); + } + } else if (ins->opcode() == HloOpcode::kConvolution) { + const ConvolutionDimensionNumbers& conv_dnums = + ins->convolution_dimension_numbers(); + int out_batch_dim = conv_dnums.output_batch_dimension(); + int out_out_channel_dim = conv_dnums.output_feature_dimension(); + + if (absl::StartsWith(strategy.name, "SR = SS x SR") || + absl::StartsWith(strategy.name, "RS = RS x SS")) { + int mesh_dim0, mesh_dim1; + std::tie(mesh_dim0, mesh_dim1) = ParseMeshDims(strategy.name); + + if (!IsDivisible(ins, device_mesh, {out_batch_dim, out_out_channel_dim}, + {mesh_dim0, mesh_dim1})) { + return Undefined(); + } + + return Tile(ins->shape(), {out_batch_dim, out_out_channel_dim}, + {mesh_dim0, mesh_dim1}, device_mesh); + } + if (absl::StartsWith(strategy.name, "R = Sk x Sk")) { + int mesh_dim = 0; + + if (!IsDivisible(ins, device_mesh_1d, {out_batch_dim}, {mesh_dim})) { + return Undefined(); + } + + return Tile(ins->shape(), {out_batch_dim}, {mesh_dim}, device_mesh_1d); + } + } else if (ins->opcode() == HloOpcode::kReduce) { + // TODO(zhuohan): support more cases. + CHECK_EQ(ins->shape().rank(), 1); + + int mesh_dim; + if (absl::StrContains(strategy.name, "allreduce @ [0]")) { + mesh_dim = 0; + } else { + mesh_dim = 1; + } + + if (strategy.output_sharding.IsReplicated()) { + if (absl::StrContains(strategy.name, "1d")) { + if (!IsDivisible(ins, device_mesh_1d, {0}, {mesh_dim})) { + return Undefined(); + } + + return Tile(ins->shape(), {0}, {mesh_dim}, device_mesh_1d); + } + if (!IsDivisible(ins, device_mesh, {0}, {mesh_dim})) { + return Undefined(); + } + + return Tile(ins->shape(), {0}, {mesh_dim}, device_mesh); + } + if (!IsDivisible(ins, device_mesh_1d, {0}, {0})) { + return Undefined(); + } + + Array tile_assignment = strategy.output_sharding.tile_assignment(); + tile_assignment.Reshape({cluster_env.total_devices_}); + return HloSharding::Tile(std::move(tile_assignment)); + + } else { + LOG(FATAL) << "Invalid instruction: " << ins->ToString(); + } + + return Undefined(); +} + +// Return whether an instruction has the opportunity to generate reduce-scatter. 
+bool HasReduceScatterOpportunity( + const HloInstruction* inst, const StrategyMap& strategy_map, + const CostGraph& cost_graph, absl::Span s_val, + const StableHashSet& modified) { + // If the operand is already modified by other ops, skip this instruction to + // avoid conflicts. + for (const HloInstruction* operand : inst->operands()) { + if (modified.contains(operand)) { + return false; + } + } + if (modified.contains(inst)) { + return false; + } + + if (inst->opcode() == HloOpcode::kReduce && inst->shape().rank() == 1) { + return true; + } + if (inst->opcode() == HloOpcode::kDot) { + if (GetShardingStrategy(inst->operand(0), strategy_map, cost_graph, s_val) + .output_sharding.IsReplicated() && + GetShardingStrategy(inst->operand(1), strategy_map, cost_graph, s_val) + .output_sharding.IsReplicated()) { + // This dot is replicated on all devices. Do not split it. + // TODO(zhuohan): improve this condition. + return false; + } + + return true; + } + if (inst->opcode() == HloOpcode::kConvolution) { + return true; + } + + return false; +} } // namespace spmd @@ -2939,6 +3714,12 @@ StatusOr AutoSharding::Run( bool module_is_changed = false; VLOG(1) << "Start auto sharding pass"; +#if !defined(__APPLE__) + // Streamz metrics. + absl::Time start_time = absl::Now(); + metrics::RecordAutoShardingInvocations(); +#endif + bool set_to_memory_lower_bound = (option_.memory_budget_per_device == 0); TF_RETURN_IF_ERROR(option_.CheckAndSetup()); VLOG(1) << "AutoShardingOptions:\n" << option_.ToString(); @@ -2977,6 +3758,9 @@ StatusOr AutoSharding::Run( solver_option.force_strategy_inst_indices = option_.force_strategy_inst_indices; solver_option.force_strategy_stra_names = option_.force_strategy_stra_names; + solver_option.only_allow_divisible_input_output = true; + solver_option.only_allow_divisible_intermediate = false; + solver_option.nd_sharding_iteratively_strict_search_space = false; // Remove CustomCalls with custom_call_target="Sharding" and move their // shardings to their input ops. @@ -2989,7 +3773,7 @@ StatusOr AutoSharding::Run( // sharding propagation pass after that before spmd partitioner. auto status_or_changed = ProcessShardingInstruction( module, execution_threads, /*replace_sharding_with_copy=*/true, - &unspecified_dims); + &unspecified_dims, /*saved_root_shardings=*/nullptr); if (!status_or_changed.ok()) { return status_or_changed; } @@ -3052,6 +3836,7 @@ StatusOr AutoSharding::Run( } VLOG(10) << hlo_live_range->ToString(); VLOG(10) << spmd::PrintLivenessSet(liveness_set); + XLA_VLOG_LINES(10, spmd::PrintLivenessSet(liveness_set)); const HloInstructionSequence& sequence = hlo_live_range->flattened_instruction_sequence(); @@ -3112,22 +3897,17 @@ StatusOr AutoSharding::Run( << " GB."; if (set_to_memory_lower_bound) { LOG(INFO) - << "--xla_tpu_auto_spmd_partitioning_memory_budget_gb is 0, setting " - "option.memory_budget_per_device to be the estimated memory " - "consumption lower bound of this module to maximize sharding. " - "Note " - "that the memory consumption estimation does not take into " - "account " - "alias pairs or while op inputs. So if the model " - "is very small such that the alias pairs and while op inputs " - "consist significant memory usage percentage, this lower bound " - "will " - "cause solver being unable to find feasible solutison. 
Please set " - "xla_tpu_auto_spmd_partitioning_memory_budget_gb to be greater " - "than " - << memory_lower_bound_gb << " if this behavior is undesired."; - option_.memory_budget_per_device = - memory_lower_bound_gb * (1024 * 1024 * 1024); + << "--xla_tpu_auto_spmd_partitioning_memory_budget_gb is 0, and " + "--xla_tpu_auto_spmd_partitioning_memory_budget_ratio is " + << option_.memory_budget_ratio + << ", so setting " + "option.memory_budget_per_device to " + << memory_lower_bound_gb << " x " << option_.memory_budget_ratio + << " = " << memory_lower_bound_gb * option_.memory_budget_ratio + << " GB"; + option_.memory_budget_per_device = memory_lower_bound_gb * + (1024 * 1024 * 1024) * + option_.memory_budget_ratio; } else if (option_.memory_budget_per_device > 0) { option_.memory_budget_per_device = original_memory_budget * original_device_mesh.num_elements() / @@ -3158,8 +3938,9 @@ StatusOr AutoSharding::Run( TF_ASSIGN_OR_RETURN( std::tie(strategy_map, leaf_strategies, associative_dot_pairs), - BuildStrategyAndCost(sequence, ins_depth_map, batch_dim_map, alias_map, - cluster_env, solver_option, *call_graph)); + BuildStrategyAndCost(sequence, module, ins_depth_map, batch_dim_map, + alias_map, cluster_env, solver_option, + *call_graph)); spmd::AliasSet alias_set = spmd::BuildAliasSet(module, strategy_map); CheckAliasSetCompatibility(alias_set, leaf_strategies, sequence); XLA_VLOG_LINES(8, PrintStrategyMap(strategy_map, sequence)); @@ -3216,6 +3997,13 @@ StatusOr AutoSharding::Run( TF_RETURN_IF_ERROR(CanonicalizeLayouts(module)); XLA_VLOG_LINES(6, absl::StrCat("After auto sharding:\n", module->ToString())); DumpHloModuleIfEnabled(*module, "after_auto_spmd_sharding"); + +#if !defined(__APPLE__) + absl::Time end_time = absl::Now(); + auto duration = end_time - start_time; + metrics::RecordAutoShardingCompilationTime( + absl::ToInt64Microseconds(duration)); +#endif return module_is_changed; } diff --git a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.h b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.h index 80b036add45..db7fe6bfc41 100644 --- a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.h +++ b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.h @@ -17,17 +17,20 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_H_ #include +#include #include #include #include #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h" +#include "tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_solver_option.h" +#include "tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" +#include "tensorflow/compiler/xla/hlo/experimental/auto_sharding/cluster_environment.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/tsl/platform/errors.h" -#include "tensorflow/tsl/protobuf/error_codes.pb.h" - namespace xla { class DummyAutoSharding : public HloModulePass { @@ -73,8 +76,14 @@ struct AutoShardingOption { bool simplify_graph = true; // Memory budget (bytes) per device. Default value -1 means no memory budget. + // Value 0 means setting it to the memory lower bound estimation. 
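Worked example of the ratio-based budget computed in AutoSharding::Run above, using the memory_budget_ratio field declared just below: with memory_budget_per_device set to 0 and an estimated memory lower bound of 10 GB, the effective budget becomes 10 x 1.1 = 11 GB, i.e. 10 x (1024 x 1024 x 1024) x 1.1, which is roughly 1.18e10 bytes per device.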
int64_t memory_budget_per_device = -1; + // Memory budget = + // memory_budget_ratio * (memory lower bound estimation). + // Enabled when memory_budget_per_device == 0; + float memory_budget_ratio = 1.1; + // Overwrite the all gather cost with the input all reduce cost. bool force_all_gather_cost = false; double all_gather_cost; @@ -329,6 +338,91 @@ class AutoSharding : public HloModulePass { AutoShardingOption option_; }; +namespace spmd { +// Function declarations +// Their comments can be found in their definitions in *.cc files. +HloSharding Tile(const Shape& shape, absl::Span tensor_dims, + absl::Span mesh_dims, + const Array& device_mesh); + +std::vector ReshardingCostVector(const StrategyVector* strategies, + const Shape& shape, + const HloSharding& required_sharding, + const ClusterEnvironment& cluster_env); + +std::vector FollowInsCostVector(int64_t source_len, int64_t index); + +std::unique_ptr CreateLeafStrategyVector( + size_t instruction_id, const HloInstruction* ins, + const StrategyMap& strategy_map, LeafStrategies& leaf_strategies); + +void SetInNodesWithInstruction(std::unique_ptr& strategies, + const HloInstruction* ins, + const StrategyMap& strategy_map); + +void RemoveDuplicatedStrategy(std::unique_ptr& strategies); + +Status FilterStrategy(const HloInstruction* ins, const Shape& shape, + std::unique_ptr& strategies, + const ClusterEnvironment& cluster_env, + const InstructionBatchDimMap& batch_map, + const AutoShardingSolverOption& solver_option); + +Status HandleDot(std::unique_ptr& strategies, + LeafStrategies& leaf_strategies, StrategyMap& strategy_map, + const HloInstruction* ins, size_t instruction_id, + const ClusterEnvironment& cluster_env, + const InstructionBatchDimMap& batch_map, + const AutoShardingSolverOption& solver_option); + +Status HandleConv(std::unique_ptr& strategies, + LeafStrategies& leaf_strategies, StrategyMap& strategy_map, + const HloInstruction* ins, size_t instruction_id, + const ClusterEnvironment& cluster_env, + const InstructionBatchDimMap& batch_map, + const AutoShardingSolverOption& solver_option); + +void AnnotateShardingWithSimpleHeuristic(HloModule* module, + const std::string& heuristic, + const AliasMap& alias_map, + const ClusterEnvironment& cluster_env); + +// Handle alias: alias pairs must have the same HloSharding. +// To deal with alias, we do special process both before and after +// BuildStrategyAndCost. Because it is easier to handle elementwise +// instructions before BuildStrategyAndCost and it is easier to handle +// dot/conv instructions after BuildStrategyAndCost. Before +// BuildStrategyAndCost, we build an AliasMap to guide the generation of +// strategies. After BuildStrategyAndCost, we use AliasSet to add alias +// constraints in the ILP problem. 
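The alias comment above can be made concrete with a small sketch. This is a simplified, hypothetical illustration (the types below are invented; the real pass operates on StrategyVector leaf strategies and solver index variables), showing only the shape of the constraints that an alias pair contributes to the ILP:

#include <string>
#include <utility>
#include <vector>

struct Strategy { std::string output_sharding; };

// For an alias pair (a, b), any combination of strategies whose output shardings
// differ must not be chosen together; each returned (i, j) corresponds to a
// constraint "not (s_a == i and s_b == j)" handed to the solver.
std::vector<std::pair<int, int>> ForbiddenStrategyPairs(
    const std::vector<Strategy>& strategies_a,
    const std::vector<Strategy>& strategies_b) {
  std::vector<std::pair<int, int>> forbidden;
  for (int i = 0; i < static_cast<int>(strategies_a.size()); ++i) {
    for (int j = 0; j < static_cast<int>(strategies_b.size()); ++j) {
      if (strategies_a[i].output_sharding != strategies_b[j].output_sharding) {
        forbidden.emplace_back(i, j);
      }
    }
  }
  return forbidden;
}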
+AliasMap BuildAliasMap(const HloModule* module); + +AliasSet BuildAliasSet(const HloModule* module, + const StrategyMap& strategy_map); + +void CheckAliasSetCompatibility(const AliasSet& alias_set, + const LeafStrategies& leaf_strategies, + const HloInstructionSequence& sequence); + +void GenerateReduceScatter(const HloInstructionSequence& sequence, + const AliasMap& alias_map, + const InstructionDepthMap& depth_map, + const StrategyMap& strategy_map, + const CostGraph& cost_graph, + absl::Span s_val, + const ClusterEnvironment& cluster_env, + const AutoShardingSolverOption& solver_option); + +bool HasReduceScatterOpportunity( + const HloInstruction* inst, const StrategyMap& strategy_map, + const CostGraph& cost_graph, absl::Span s_val, + const StableHashSet& modified); + +HloSharding GetReduceScatterOutput(const HloInstruction* ins, + const ShardingStrategy& strategy, + const ClusterEnvironment& cluster_env); + +} // namespace spmd } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_H_ diff --git a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h index 0169d0a80de..718d00aea1f 100644 --- a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h +++ b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h @@ -23,8 +23,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" -#include "tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_util.h" - +#include "tensorflow/compiler/xla/hlo/experimental/auto_sharding/matrix.h" namespace xla { namespace spmd { @@ -344,15 +343,6 @@ inline const ShardingStrategy& GetShardingStrategy( return strategies->leaf_vector[stra_idx]; } -void GenerateReduceScatter(const HloInstructionSequence& sequence, - const AliasMap& alias_map, - const InstructionDepthMap& depth_map, - const StrategyMap& strategy_map, - const CostGraph& cost_graph, - absl::Span s_val, - const ClusterEnvironment& cluster_env, - const AutoShardingSolverOption& solver_option); - } // namespace spmd } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_COST_GRAPH_H_ diff --git a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc index 9334c366a6e..70b121d1a16 100644 --- a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc +++ b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc @@ -20,8 +20,11 @@ limitations under the License. 
#include "absl/algorithm/container.h" #include "absl/strings/str_format.h" +#include "tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.h" +#include "tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_solver_option.h" #include "tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" #include "tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_util.h" +#include "tensorflow/compiler/xla/hlo/experimental/auto_sharding/cluster_environment.h" #include "tensorflow/tsl/platform/errors.h" namespace xla { @@ -88,11 +91,17 @@ class DotHandler { void SplitLhsSpaceRhsSpace(int mesh_dim0, int mesh_dim1) { for (int64_t i = 0; i < lhs_space_dims_.size(); ++i) { for (int64_t j = 0; j < rhs_space_dims_.size(); ++j) { - // TODO(b/220942808) Shard non-dividible op dimensions. - if (!IsDivisible(lhs_->shape().dimensions().at(lhs_space_dims_.at(i)), - device_mesh_.dim(mesh_dim0)) || - !IsDivisible(rhs_->shape().dimensions().at(rhs_space_dims_.at(j)), - device_mesh_.dim(mesh_dim1))) { + if (lhs_->shape().dimensions().at(lhs_space_dims_.at(i)) < + device_mesh_.dim(mesh_dim0) || + rhs_->shape().dimensions().at(rhs_space_dims_.at(j)) < + device_mesh_.dim(mesh_dim1)) { + continue; + } + if (solver_option_.only_allow_divisible_intermediate && + (!IsDivisible(lhs_->shape().dimensions().at(lhs_space_dims_.at(i)), + device_mesh_.dim(mesh_dim0)) || + !IsDivisible(rhs_->shape().dimensions().at(rhs_space_dims_.at(j)), + device_mesh_.dim(mesh_dim1)))) { continue; } std::string name = @@ -117,10 +126,17 @@ class DotHandler { void SplitLhsSpaceOnly(int mesh_dim0, int mesh_dim1) { for (int64_t i = 0; i < lhs_space_dims_.size(); ++i) { for (int64_t j = i + 1; j < lhs_space_dims_.size(); ++j) { - if (!IsDivisible(lhs_->shape().dimensions().at(lhs_space_dims_.at(i)), - device_mesh_.dim(mesh_dim0)) || - !IsDivisible(lhs_->shape().dimensions().at(lhs_space_dims_.at(j)), - device_mesh_.dim(mesh_dim1))) { + if (lhs_->shape().dimensions().at(lhs_space_dims_.at(i)) < + device_mesh_.dim(mesh_dim0) || + lhs_->shape().dimensions().at(lhs_space_dims_.at(j)) < + device_mesh_.dim(mesh_dim1)) { + continue; + } + if (solver_option_.only_allow_divisible_intermediate && + (!IsDivisible(lhs_->shape().dimensions().at(lhs_space_dims_.at(i)), + device_mesh_.dim(mesh_dim0)) || + !IsDivisible(lhs_->shape().dimensions().at(lhs_space_dims_.at(j)), + device_mesh_.dim(mesh_dim1)))) { continue; } std::string name = @@ -142,10 +158,17 @@ class DotHandler { void SplitRhsSpaceOnly(int mesh_dim0, int mesh_dim1) { for (int64_t i = 0; i < rhs_space_dims_.size(); ++i) { for (int64_t j = i + 1; j < rhs_space_dims_.size(); ++j) { - if (!IsDivisible(rhs_->shape().dimensions().at(rhs_space_dims_.at(i)), - device_mesh_.dim(mesh_dim0)) || - !IsDivisible(rhs_->shape().dimensions().at(rhs_space_dims_.at(j)), - device_mesh_.dim(mesh_dim1))) { + if (rhs_->shape().dimensions().at(rhs_space_dims_.at(i)) < + device_mesh_.dim(mesh_dim0) || + rhs_->shape().dimensions().at(rhs_space_dims_.at(j)) < + device_mesh_.dim(mesh_dim1)) { + continue; + } + if (solver_option_.only_allow_divisible_intermediate && + (!IsDivisible(rhs_->shape().dimensions().at(rhs_space_dims_.at(i)), + device_mesh_.dim(mesh_dim0)) || + !IsDivisible(rhs_->shape().dimensions().at(rhs_space_dims_.at(j)), + device_mesh_.dim(mesh_dim1)))) { continue; } std::string name = @@ -174,10 +197,18 @@ class DotHandler { mesh_dim1, mesh_dim1); for (int64_t i = 0; i < lhs_space_dims_.size(); ++i) { for (int64_t j = 0; j < lhs_con_dims_.size(); 
++j) { - if (!IsDivisible(lhs_->shape().dimensions().at(lhs_space_dims_.at(i)), - device_mesh_.dim(mesh_dim0)) || - !IsDivisible(lhs_->shape().dimensions().at(lhs_con_dims_.at(j)), - device_mesh_.dim(mesh_dim1))) { + if (lhs_->shape().dimensions().at(lhs_space_dims_.at(i)) < + device_mesh_.dim(mesh_dim0) || + lhs_->shape().dimensions().at(lhs_con_dims_.at(j)) < + device_mesh_.dim(mesh_dim1)) { + continue; + } + if (solver_option_.only_allow_divisible_intermediate && + (!IsDivisible( + lhs_->shape().dimensions().at(lhs_space_dims_.at(i)), + device_mesh_.dim(mesh_dim0)) || + !IsDivisible(lhs_->shape().dimensions().at(lhs_con_dims_.at(j)), + device_mesh_.dim(mesh_dim1)))) { continue; } @@ -207,10 +238,18 @@ class DotHandler { mesh_dim1, mesh_dim0); for (int64_t i = 0; i < rhs_space_dims_.size(); ++i) { for (int64_t j = 0; j < lhs_con_dims_.size(); ++j) { - if (!IsDivisible(rhs_->shape().dimensions().at(rhs_space_dims_.at(i)), - device_mesh_.dim(mesh_dim1)) || - !IsDivisible(lhs_->shape().dimensions().at(lhs_con_dims_.at(j)), - device_mesh_.dim(mesh_dim0))) { + if (rhs_->shape().dimensions().at(rhs_space_dims_.at(i)) < + device_mesh_.dim(mesh_dim1) || + lhs_->shape().dimensions().at(lhs_con_dims_.at(j)) < + device_mesh_.dim(mesh_dim0)) { + continue; + } + if (solver_option_.only_allow_divisible_intermediate && + (!IsDivisible( + rhs_->shape().dimensions().at(rhs_space_dims_.at(i)), + device_mesh_.dim(mesh_dim1)) || + !IsDivisible(lhs_->shape().dimensions().at(lhs_con_dims_.at(j)), + device_mesh_.dim(mesh_dim0)))) { continue; } HloSharding output_spec = @@ -240,7 +279,12 @@ class DotHandler { [](int64_t size) { return size > 1; }) == 1) { for (int64_t i = 0; i < lhs_batch_dims_.size(); ++i) { for (int64_t j = 0; j < device_mesh_.num_dimensions(); ++j) { - if (!IsDivisible(lhs_->shape().dimensions().at(lhs_batch_dims_.at(i)), + if (lhs_->shape().dimensions().at(lhs_batch_dims_.at(i)) < + device_mesh_.dim(j)) { + continue; + } + if (solver_option_.only_allow_divisible_intermediate && + !IsDivisible(lhs_->shape().dimensions().at(lhs_batch_dims_.at(i)), device_mesh_.dim(j))) { continue; } @@ -260,11 +304,20 @@ class DotHandler { void SplitTwoBatchDims(int mesh_dim0, int mesh_dim1) { if (lhs_batch_dims_.size() == 2 && device_mesh_.dim(mesh_dim0) > 1 && - device_mesh_.dim(mesh_dim1) > 1 && - IsDivisible(lhs_->shape().dimensions().at(lhs_batch_dims_.at(0)), - device_mesh_.dim(mesh_dim0)) && - IsDivisible(lhs_->shape().dimensions().at(lhs_batch_dims_.at(1)), - device_mesh_.dim(mesh_dim1))) { + device_mesh_.dim(mesh_dim1) > 1) { + if (lhs_->shape().dimensions().at(lhs_batch_dims_.at(0)) < + device_mesh_.dim(mesh_dim0) || + lhs_->shape().dimensions().at(lhs_batch_dims_.at(1)) < + device_mesh_.dim(mesh_dim1)) { + return; + } + if (solver_option_.only_allow_divisible_intermediate && + (!IsDivisible(lhs_->shape().dimensions().at(lhs_batch_dims_.at(0)), + device_mesh_.dim(mesh_dim0)) || + !IsDivisible(lhs_->shape().dimensions().at(lhs_batch_dims_.at(1)), + device_mesh_.dim(mesh_dim1)))) { + return; + } std::string name = absl::StrFormat("Sb = Sb x Sb @ {%d,%d}", mesh_dim0, mesh_dim1); HloSharding output_spec = @@ -287,10 +340,19 @@ class DotHandler { absl::StrFormat("SbSi = SbSi x SbR @ {%d,%d}", mesh_dim0, mesh_dim1); for (int64_t i = 0; i < lhs_space_dims_.size(); ++i) { for (int64_t j = 0; j < lhs_batch_dims_.size(); ++j) { - if (!IsDivisible(lhs_->shape().dimensions().at(lhs_space_dims_.at(i)), - device_mesh_.dim(mesh_dim0)) || - !IsDivisible(lhs_->shape().dimensions().at(lhs_batch_dims_.at(j)), - 
device_mesh_.dim(mesh_dim1))) { + if (lhs_->shape().dimensions().at(lhs_space_dims_.at(i)) < + device_mesh_.dim(mesh_dim0) || + lhs_->shape().dimensions().at(lhs_batch_dims_.at(j)) < + device_mesh_.dim(mesh_dim1)) { + continue; + } + if (solver_option_.only_allow_divisible_intermediate && + (!IsDivisible( + lhs_->shape().dimensions().at(lhs_space_dims_.at(i)), + device_mesh_.dim(mesh_dim0)) || + !IsDivisible( + lhs_->shape().dimensions().at(lhs_batch_dims_.at(j)), + device_mesh_.dim(mesh_dim1)))) { continue; } HloSharding output_spec = @@ -316,10 +378,19 @@ class DotHandler { absl::StrFormat("SbSj = SbR x SbSj @ {%d,%d}", mesh_dim0, mesh_dim1); for (int64_t i = 0; i < rhs_space_dims_.size(); ++i) { for (int64_t j = 0; j < lhs_batch_dims_.size(); ++j) { - if (!IsDivisible(rhs_->shape().dimensions().at(rhs_space_dims_.at(i)), - device_mesh_.dim(mesh_dim1)) || - !IsDivisible(lhs_->shape().dimensions().at(lhs_batch_dims_.at(j)), - device_mesh_.dim(mesh_dim0))) { + if (rhs_->shape().dimensions().at(rhs_space_dims_.at(i)) < + device_mesh_.dim(mesh_dim1) || + lhs_->shape().dimensions().at(lhs_batch_dims_.at(j)) < + device_mesh_.dim(mesh_dim0)) { + continue; + } + if (solver_option_.only_allow_divisible_intermediate && + (!IsDivisible( + rhs_->shape().dimensions().at(rhs_space_dims_.at(i)), + device_mesh_.dim(mesh_dim1)) || + !IsDivisible( + lhs_->shape().dimensions().at(lhs_batch_dims_.at(j)), + device_mesh_.dim(mesh_dim0)))) { continue; } HloSharding output_spec = @@ -348,10 +419,18 @@ class DotHandler { mesh_dim0, mesh_dim1, mesh_dim1); for (int64_t i = 0; i < lhs_con_dims_.size(); ++i) { for (int64_t j = 0; j < lhs_batch_dims_.size(); ++j) { - if (!IsDivisible(lhs_->shape().dimensions().at(lhs_con_dims_.at(i)), - device_mesh_.dim(mesh_dim1)) || - !IsDivisible(lhs_->shape().dimensions().at(lhs_batch_dims_.at(j)), - device_mesh_.dim(mesh_dim0))) { + if (lhs_->shape().dimensions().at(lhs_con_dims_.at(i)) < + device_mesh_.dim(mesh_dim1) || + lhs_->shape().dimensions().at(lhs_batch_dims_.at(j)) < + device_mesh_.dim(mesh_dim0)) { + continue; + } + if (solver_option_.only_allow_divisible_intermediate && + (!IsDivisible(lhs_->shape().dimensions().at(lhs_con_dims_.at(i)), + device_mesh_.dim(mesh_dim1)) || + !IsDivisible( + lhs_->shape().dimensions().at(lhs_batch_dims_.at(j)), + device_mesh_.dim(mesh_dim0)))) { continue; } HloSharding output_spec = @@ -383,14 +462,25 @@ class DotHandler { mesh_dim0, mesh_dim1, mesh_dim0, mesh_dim1); for (int64_t i = 0; i < lhs_con_dims_.size(); ++i) { for (int64_t j = i + 1; j < lhs_con_dims_.size(); ++j) { - if (!IsDivisible(lhs_->shape().dimensions().at(lhs_con_dims_.at(i)), - device_mesh_.dim(mesh_dim0)) || - !IsDivisible(lhs_->shape().dimensions().at(lhs_con_dims_.at(j)), - device_mesh_.dim(mesh_dim1)) || - !IsDivisible(rhs_->shape().dimensions().at(rhs_con_dims_.at(i)), - device_mesh_.dim(mesh_dim0)) || - !IsDivisible(rhs_->shape().dimensions().at(rhs_con_dims_.at(j)), - device_mesh_.dim(mesh_dim1))) { + if (lhs_->shape().dimensions().at(lhs_con_dims_.at(i)) < + device_mesh_.dim(mesh_dim0) || + lhs_->shape().dimensions().at(lhs_con_dims_.at(j)) < + device_mesh_.dim(mesh_dim1) || + rhs_->shape().dimensions().at(rhs_con_dims_.at(i)) < + device_mesh_.dim(mesh_dim0) || + rhs_->shape().dimensions().at(rhs_con_dims_.at(j)) < + device_mesh_.dim(mesh_dim1)) { + continue; + } + if (solver_option_.only_allow_divisible_intermediate && + (!IsDivisible(lhs_->shape().dimensions().at(lhs_con_dims_.at(i)), + device_mesh_.dim(mesh_dim0)) || + 
!IsDivisible(lhs_->shape().dimensions().at(lhs_con_dims_.at(j)), + device_mesh_.dim(mesh_dim1)) || + !IsDivisible(rhs_->shape().dimensions().at(rhs_con_dims_.at(i)), + device_mesh_.dim(mesh_dim0)) || + !IsDivisible(rhs_->shape().dimensions().at(rhs_con_dims_.at(j)), + device_mesh_.dim(mesh_dim1)))) { continue; } HloSharding output_spec = HloSharding::Replicate(); @@ -416,7 +506,12 @@ class DotHandler { std::string name = absl::StrFormat("RR = RS x SR @ {%d} (allreduce @ %d)", mesh_dim0, mesh_dim0); for (int64_t i = 0; i < lhs_con_dims_.size(); ++i) { - if (!IsDivisible(lhs_->shape().dimensions().at(lhs_con_dims_.at(i)), + if (lhs_->shape().dimensions().at(lhs_con_dims_.at(i)) < + device_mesh_.dim(mesh_dim0)) { + continue; + } + if (solver_option_.only_allow_divisible_intermediate && + !IsDivisible(lhs_->shape().dimensions().at(lhs_con_dims_.at(i)), device_mesh_.dim(mesh_dim0))) { continue; } @@ -447,8 +542,14 @@ class DotHandler { // Si = Si x R @ 0 for (int64_t i = 0; i < lhs_space_dims_.size(); ++i) { - if (IsDivisible(lhs_->shape().dimensions(lhs_space_dims_[i]), - num_devices)) { + if (lhs_->shape().dimensions(lhs_space_dims_[i]) < num_devices) { + continue; + } + if (solver_option_.only_allow_divisible_intermediate && + !IsDivisible(lhs_->shape().dimensions(lhs_space_dims_[i]), + num_devices)) { + continue; + } std::string name = absl::StrFormat("Si = Si x R @ %d", mesh_dim); HloSharding output_spec = Tile(ins_->shape(), {space_base_dim_ + i}, {mesh_dim}, device_mesh_1d_); @@ -457,13 +558,18 @@ class DotHandler { HloSharding rhs_spec = HloSharding::Replicate(); AppendNewStrategy(ins_, name, output_spec, {lhs_spec, rhs_spec}, 0, 0, cluster_env_, strategy_map_, strategies_); - } } // R = Sk x Sk @ (allreduce @ 0) for (int64_t i = 0; i < lhs_con_dims_.size(); ++i) { - if (IsDivisible(lhs_->shape().dimensions(lhs_con_dims_[i]), - num_devices)) { + if (lhs_->shape().dimensions(lhs_con_dims_[i]) < num_devices) { + continue; + } + if (solver_option_.only_allow_divisible_intermediate && + !IsDivisible(lhs_->shape().dimensions(lhs_con_dims_[i]), + num_devices)) { + continue; + } std::string name = absl::StrFormat( "R = Sk x Sk @ %d (allreduce @ %d)", mesh_dim, mesh_dim); HloSharding output_spec = HloSharding::Replicate(); @@ -479,7 +585,6 @@ class DotHandler { communication_cost, cluster_env_, strategy_map_, strategies_); } - } } } @@ -489,10 +594,15 @@ class DotHandler { [](int64_t size) { return size > 1; }) > 1) { int mesh_dim = 0; for (int64_t i = 0; i < lhs_batch_dims_.size(); ++i) { - if (!IsDivisible(rhs_->shape().dimensions().at(lhs_batch_dims_.at(i)), - device_mesh_.dim(mesh_dim))) { + if (rhs_->shape().dimensions().at(lhs_batch_dims_.at(i)) < + device_mesh_.dim(mesh_dim)) { continue; - } + } + if (solver_option_.only_allow_divisible_intermediate && + !IsDivisible(rhs_->shape().dimensions().at(lhs_batch_dims_.at(i)), + device_mesh_.dim(mesh_dim))) { + continue; + } std::string name = absl::StrFormat("Sb_%d = Sb x Sb @ {%d} 1d", i, mesh_dim); HloSharding output_spec = diff --git a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_runner.cc b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_runner.cc index 0a87624df87..e12ecef29f0 100644 --- a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_runner.cc +++ b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_runner.cc @@ -18,7 +18,7 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/tools/hlo_module_loader.h" diff --git a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_solver_option.h b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_solver_option.h new file mode 100644 index 00000000000..18d5495ae12 --- /dev/null +++ b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_solver_option.h @@ -0,0 +1,107 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_SOLVER_OPTION_H_ +#define TENSORFLOW_COMPILER_XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_SOLVER_OPTION_H_ + +#include +#include +#include +#include + +namespace xla { +namespace spmd { +// Options for the auto-sharding solver. +struct AutoShardingSolverOption { + // Forcibly split the batch dimension and map it to a mesh dimension. + // This can force the auto-sharding pass to generate the data parallel + // strategy. + int force_batch_dim_to_mesh_dim; + + // If true, override the cost of all-gather with the given value. + bool override_all_gather_cost; + double all_gather_cost; + + // If true, override the cost of all-reduce with the given value. + bool override_all_reduce_cost; + double all_reduce_cost; + + // If true, override the cost of reduce-scatter with the given value. + bool override_reduce_scatter_cost; + double reduce_scatter_cost; + + // If true, override the cost of all-to-all with the given value. + bool override_all_to_all_cost; + double all_to_all_cost; + + // If true, allow replicated parameters. + bool allow_replicated_parameters; + + // If true, prefer reduce-scatter + all-gather over all-reduce. + // A post process will be applied to replace all-reduce with reduce-scatter + + // all-gather if no communication overhead is introduced. + bool prefer_reduce_scatter; + + // If true, generate a gradient-accumulation friendly variant of + // reduce-scatter + bool reduce_scatter_grad_acc_friendly; + + // If true, aggressively partition more tensors when generating + // reduce-scatter, even if it introduces more communication. + bool reduce_scatter_aggressive_partition; + + // If true, the batch matmul will always be parallelized on the batch dim in + // 2d mesh case. + bool batch_matmul_always_split_batch; + + // If true, allow strategies that recompute heavy operators (e.g., dot) + // to reduce communication. + bool allow_recompute_heavy_op; + + // If true, allow adding 1d strategies in 2d logical mesh. + bool allow_mixed_mesh_shape; + + // The number of micro batches if gradient accumulation is used. 
+ // If this is not 1, the cost of all-reduce for gradient synchronization + // is divided by this number. + int grad_acc_num_micro_batches; + + // If true, load solution vector from PassContext + bool load_solution_vector; + + // If it is not empty, forcibly use simple heuristic strategies + // instead of the ILP solver. This is used for ablation study. + std::string force_simple_heuristic; + + // If true, forcibly set the strategy of some instructions + bool force_strategy; + std::vector force_strategy_inst_indices; + std::vector force_strategy_stra_names; + + bool only_allow_divisible_input_output; + + bool only_allow_divisible_intermediate; + + // If true, strictly limit the following iterations to use the same number of + // shards for sharded tensor dimensions; if false, the following iterations + // can choose different number of shards for sharded tensor dimensions. + // Enabling it can hurt the performance of dot ops, but can make the search + // space more scalable. Therefore leaving it as an option. + bool nd_sharding_iteratively_strict_search_space; +}; +} // namespace spmd +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_SOLVER_OPTION_H_ diff --git a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h index d23d897e53d..49ef931c3ce 100644 --- a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h +++ b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h @@ -31,82 +31,32 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_util.h" -#include "tensorflow/compiler/xla/service/hlo_live_range.h" - +#include "tensorflow/compiler/xla/hlo/ir/hlo_sharding.h" +#include "tensorflow/compiler/xla/service/hlo_value.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" namespace xla { namespace spmd { // A constant to represent infinity cost. constexpr double kInfinityCost = 1e13; -// Options for the auto-sharding solver. -struct AutoShardingSolverOption { - // Forcibly split the batch dimension and map it to a mesh dimension. - // This can force the auto-sharding pass to generate the data parallel - // strategy. - int force_batch_dim_to_mesh_dim; - - // If true, override the cost of all-gather with the given value. - bool override_all_gather_cost; - double all_gather_cost; - - // If true, override the cost of all-reduce with the given value. - bool override_all_reduce_cost; - double all_reduce_cost; - - // If true, override the cost of reduce-scatter with the given value. - bool override_reduce_scatter_cost; - double reduce_scatter_cost; - - // If true, override the cost of all-to-all with the given value. - bool override_all_to_all_cost; - double all_to_all_cost; - - // If true, allow replicated parameters. - bool allow_replicated_parameters; - - // If true, prefer reduce-scatter + all-gather over all-reduce. - // A post process will be applied to replace all-reduce with reduce-scater + - // all-gather if no communication overhead is introduced. 
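The only_allow_divisible_* flags added in auto_sharding_solver_option.h above, together with the dimension checks introduced in auto_sharding_dot_handler.cc earlier in this patch, amount to a two-level gate per tensor dimension. A small illustrative helper (hypothetical, not part of the patch) capturing that pattern:

#include <cstdint>

// A dimension can never be sharded across more devices than it has elements.
// Exact divisibility is additionally required only when the strict flag is set;
// otherwise uneven sharding (handled later by padding) stays in the search space.
bool CanShardDim(int64_t tensor_dim_size, int64_t mesh_dim_size,
                 bool only_allow_divisible) {
  if (tensor_dim_size < mesh_dim_size) {
    return false;
  }
  if (only_allow_divisible && tensor_dim_size % mesh_dim_size != 0) {
    return false;
  }
  return true;
}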
- bool prefer_reduce_scatter; - - // If True, generate a gradient-accumulation friendly variant of - // reduce-scatter - bool reduce_scatter_grad_acc_friendly; - - // If true, aggressively partition more tensors when generating - // reduce-scatter, even if it introduces more communication. - bool reduce_scatter_aggressive_partition; - - // If true, the batch matmul will always be parallelized on the batch dim in - // 2d mesh case. - bool batch_matmul_always_split_batch; - - // If true, allow strategies that recompute heavy operators (e.g., dot) - // to reduce communication. - bool allow_recompute_heavy_op; - - // If true, allow adding 1d strategies in 2d logical mesh. - bool allow_mixed_mesh_shape; - - // The number of micro batches if gradient accumulation is used. - // If this is not 1, the cost of all-reduce for gradient synchronization - // is divided by this number. - int grad_acc_num_micro_batches; - - // If true, load solution vector from PassContext - bool load_solution_vector; - - // If it is not empty, forcibly use simple heuristic strategies - // instead of the ILP solver. This is used for ablation study. - std::string force_simple_heuristic; - - // If true, forcibly set the strategy of some instructions - bool force_strategy; - std::vector force_strategy_inst_indices; - std::vector force_strategy_stra_names; -}; +// Type alias +template +using StableHashMap = ::absl::flat_hash_map; +template +using StableHashSet = ::absl::flat_hash_set; + +// Map an instruction to its depth. +using InstructionDepthMap = StableHashMap; +// Map an instruction to its batch dimension. +using InstructionBatchDimMap = StableHashMap; +// Map an instruction to its alias source parameter. +using AliasMap = StableHashMap; +// Map an instruction to its resharding cache. +using ReshardingCache = + StableHashMap>>; // One sharding strategy struct ShardingStrategy { @@ -233,524 +183,7 @@ using AssociativeDotPairs = // The set of all alias pairs using AliasSet = StableHashSet>; -// Store the profiling results of communication and computation. -class ProfilingResult { - public: - // TODO (zhuohan): loading the profiling result. - ProfilingResult() { - if (all_reduce_cost_dict_.empty()) { - enabled_ = false; - } else { - enabled_ = true; - } - } - - bool Enabled() const { return enabled_; } - - double EstimateAllGatherCost( - const std::vector>& replica_groups, int64_t size, - const std::string& dtype) const { - if (all_gather_cost_dict_.empty()) { - // Use all-reduce to approximate all-gather. - return EstimateAllReduceCost(replica_groups, size, dtype) / 2; - } - - return EstimateInternal(replica_groups, size, dtype, - all_gather_cost_dict_) - - EstimateInternal(replica_groups, 0, dtype, all_gather_cost_dict_); - } - - double EstimateAllReduceCost( - const std::vector>& replica_groups, int64_t size, - const std::string& dtype) const { - return EstimateInternal(replica_groups, size, dtype, - all_reduce_cost_dict_) - - EstimateInternal(replica_groups, 0, dtype, all_reduce_cost_dict_); - } - - double EstimateReduceScatterCost( - const std::vector>& replica_groups, int64_t size, - const std::string& dtype) const { - if (reduce_scatter_cost_dict_.empty()) { - // Use all-reduce to approximate reduce-scatter. 
- return EstimateAllReduceCost(replica_groups, size, dtype) / 2; - } - - return EstimateInternal(replica_groups, size, dtype, - reduce_scatter_cost_dict_) - - EstimateInternal(replica_groups, 0, dtype, - reduce_scatter_cost_dict_); - } - - double EstimateAllToAllCost( - const std::vector>& replica_groups, int64_t size, - const std::string& dtype) const { - // A penalty factor to make the theoretical cost match the - // empirical cost on v100 + nvlink. - int64_t num_devices = replica_groups.front().size(); - double penalty_factor = static_cast(num_devices) / 2.0; - // Use all-gather to approximate all-to-all. - return EstimateAllGatherCost(replica_groups, size / num_devices, dtype) * - penalty_factor; - } - - std::string ToString() { - std::string str; - for (const auto& item : all_reduce_cost_dict_) { - absl::StrAppend(&str, item.first.first, " ", item.first.second, "\n"); - } - return str; - } - - private: - // pair - using Key = std::pair; - // vector> - using Value = std::vector>; - - // Estimate the cost by linear interpolation between the two closest points. - double EstimateInternal( - const std::vector>& replica_groups, int64_t size, - const std::string& dtype, - const StableHashMap& cost_dict) const { - Key key(Group2Str(replica_groups), dtype); - Value cost_list = cost_dict.at(key); - - CHECK(!cost_list.empty()); - - size_t i; - if (size > cost_list.back().first) { - i = cost_list.size() - 2; - } else if (size < cost_list.front().first) { - i = 0; - } else { - for (i = 0; i < cost_list.size() - 1; ++i) { - if (cost_list[i].first <= size && size <= cost_list[i + 1].first) { - break; - } - } - } - - int64_t left_size = cost_list[i].first; - double left_cost = cost_list[i].second; - int64_t right_size = cost_list[i + 1].first; - double right_cost = cost_list[i + 1].second; - - return 1.0 * (size - left_size) / (right_size - left_size) * - (right_cost - left_cost) + - left_cost; - } - - // Make a string key of a replica_groups. - std::string Group2Str( - const std::vector>& replica_groups) const { - std::string str; - absl::StrAppend(&str, "("); - for (const auto& group : replica_groups) { - absl::StrAppend(&str, "(", absl::StrJoin(group, ","), ")"); - } - absl::StrAppend(&str, ")"); - - return str; - } - - bool enabled_; - StableHashMap all_reduce_cost_dict_; - StableHashMap all_gather_cost_dict_; - StableHashMap reduce_scatter_cost_dict_; -}; - -// The cluster has a multi-dimensional device mesh topology. -// Each mesh dimension has its own latency and bandwidth. -// We use alpha-beta model to model the communication cost. -// If profiling result is provided, we always prefer to use -// the real profiling result. -class ClusterEnvironment { - public: - ClusterEnvironment(const Array& original_device_mesh, - const Array& device_mesh, - absl::Span mesh_alpha, - absl::Span mesh_beta, - const ProfilingResult& prof_result, - const AutoShardingSolverOption& solver_option) - : original_device_mesh_(original_device_mesh), - device_mesh_(device_mesh), - mesh_alpha_(mesh_alpha.begin(), mesh_alpha.end()), - mesh_beta_(mesh_beta.begin(), mesh_beta.end()), - prof_result_(prof_result), - total_devices_(device_mesh.num_elements()), - device_mesh_1d_(original_device_mesh), - solver_option_(solver_option) { - // Build replica group for each dimension. - non_zero_mesh_dims_ = - VectorGreaterThanOneElementIndices(device_mesh.dimensions()); - GenerateCachedReplicaGroups(); - // TODO(yuemmawang) Find the largest dimension in original_device_mesh and - // create 1d mesh on that dimension. 
- device_mesh_1d_.Reshape({original_device_mesh.num_elements(), 1}); - } - - size_t NumDevices() const { return total_devices_; } - - bool IsDeviceMesh3D() const { - return VectorGreaterThanOneElementCount(device_mesh_.dimensions()) == 3; - } - - bool IsDeviceMesh2D() const { - return VectorGreaterThanOneElementCount(device_mesh_.dimensions()) == 2; - } - - bool IsDeviceMesh1D() const { - return VectorGreaterThanOneElementCount(device_mesh_.dimensions()) == 1; - } - - bool IsOriginalDeviceMesh2D() const { - return VectorGreaterThanOneElementCount( - original_device_mesh_.dimensions()) == 2; - } - - double AllGatherCost(double num_bytes, int mesh_dim) const { - if (solver_option_.override_all_gather_cost) { - return solver_option_.all_gather_cost; - } - - if (prof_result_.Enabled()) { - return prof_result_.EstimateAllGatherCost( - cached_replica_groups_[mesh_dim], num_bytes / 4, "float32"); - } - - if (solver_option_.force_batch_dim_to_mesh_dim == mesh_dim) { - // if data-parallel is forced on this dim, we only allow all-reduce - // in this dimension. - return kInfinityCost; - } - - int64_t num_devices = device_mesh_.dim(mesh_dim); - return (round(mesh_alpha_[mesh_dim] + mesh_beta_[mesh_dim] * - (num_devices - 1) / num_devices * - num_bytes) + - 0.1); - } - - // TODO(zhuohan): distinguish dtype and reduce_op. - double AllReduceCost(double num_bytes, int32_t mesh_dim, - int32_t mesh_dim_another = -1) const { - if (solver_option_.override_all_reduce_cost) { - return solver_option_.all_reduce_cost; - } - - if (prof_result_.Enabled()) { - return prof_result_.EstimateAllReduceCost( - cached_replica_groups_[mesh_dim], num_bytes / 4, "float32"); - } - double alpha, beta; - int64_t num_devices; - if (mesh_dim_another == -1) { - // Only communicating on one mesh dimension. - alpha = mesh_alpha_[mesh_dim]; - beta = mesh_beta_[mesh_dim]; - num_devices = device_mesh_.dim(mesh_dim); - } else { - // Communicating through both mesh dimensions. - alpha = std::max(mesh_alpha_[mesh_dim], mesh_alpha_[mesh_dim_another]); - beta = std::max(mesh_beta_[mesh_dim], mesh_beta_[mesh_dim_another]); - num_devices = device_mesh_.num_elements(); - } - return ( - round(alpha + beta * 2 * (num_devices - 1) / num_devices * num_bytes) + - 0.01); - } - - double ReduceScatterCost(double num_bytes, int mesh_dim) const { - if (solver_option_.override_reduce_scatter_cost) { - return solver_option_.reduce_scatter_cost; - } - - if (prof_result_.Enabled()) { - return prof_result_.EstimateReduceScatterCost( - cached_replica_groups_[mesh_dim], num_bytes / 4, "float32"); - } - - int64_t num_devices = device_mesh_.dim(mesh_dim); - return (round(mesh_alpha_[mesh_dim] + mesh_beta_[mesh_dim] * - (num_devices - 1) / num_devices * - num_bytes) + - 0.001); - } - - double AllToAllCost(double num_bytes, int mesh_dim) const { - if (solver_option_.override_all_to_all_cost) { - return solver_option_.all_to_all_cost; - } - - if (prof_result_.Enabled()) { - return prof_result_.EstimateAllToAllCost(cached_replica_groups_[mesh_dim], - num_bytes / 4, "float32"); - } - - if (solver_option_.force_batch_dim_to_mesh_dim == mesh_dim) { - // if data-parallel is forced on this dim, we only allow all-reduce - // in this dimension. 
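The collective cost estimates being removed above (and relocated by this patch) all follow the same alpha-beta form: a latency term plus a bandwidth term scaled by (num_devices - 1) / num_devices. A self-contained example evaluating the all-gather and all-reduce formulas as written above, with invented constants purely for illustration:

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  double alpha = 1.0, beta = 0.01;  // hypothetical per-mesh-dimension parameters
  double num_bytes = 1 << 20;       // 1 MiB buffer
  int64_t num_devices = 8;
  double all_gather =
      std::round(alpha + beta * (num_devices - 1) / num_devices * num_bytes) + 0.1;
  double all_reduce =
      std::round(alpha + beta * 2 * (num_devices - 1) / num_devices * num_bytes) + 0.01;
  std::printf("all-gather ~ %.2f, all-reduce ~ %.2f\n", all_gather, all_reduce);
  return 0;
}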
- return kInfinityCost; - } - - int64_t num_devices = device_mesh_.dim(mesh_dim); - return AllToAllCostUtil(num_bytes, mesh_dim, num_devices, mesh_alpha_, - mesh_beta_); - } - - double DotCost(const Shape& lhs_shape, const Shape& rhs_shape, - const DotDimensionNumbers& dot_dnums) const { - if (!solver_option_.allow_recompute_heavy_op) { - return kInfinityCost; - } - - // TODO(zhuohan): When profiling data is not available, it is not easy to - // align the scale of compute cost and communication cost. Here we just use - // a simple heuristic to compute the compute cost with communication cost. - double num_bytes = GetBytes(lhs_shape) + GetBytes(rhs_shape); - return AllReduceCost(num_bytes, 0) + AllReduceCost(num_bytes, 1); - } - - // Get the corresponding mesh dimension for every tensor dimension. - // -1 means replicated on that dimension - std::vector GetTensorDimToMeshDimWrapper( - const Shape& shape, const HloSharding& spec) const { - int64_t n_dim = NumTileDimensions(spec); - std::vector tensor_dim_to_mesh_dim = - GetTensorDimToMeshDim(shape.rank(), spec, device_mesh_); - AdjustTensorMeshDimMapping(tensor_dim_to_mesh_dim, n_dim); - return tensor_dim_to_mesh_dim; - } - - // The communication cost of resharding a tensor from src to dst - // TODO(b/238210866) Do not use kInfinityCost. - double ReshardingCost(const Shape& shape, const HloSharding& src_spec, - const HloSharding& dst_spec) const { - // TODO(zhuohan): This function can be wrong and needs more tests. - if (src_spec == dst_spec || IsUndefined(src_spec)) { - return 0.0; - } - CHECK(!IsUndefined(dst_spec)); - int64_t src_n_dim = NumTileDimensions(src_spec); - int64_t dst_n_dim = NumTileDimensions(dst_spec); - // When src_spec and dst_spec are for arrays with different number of - // dimensions, which could happen when an instruction follows the sharding - // of an operand with a different shape, we need to use their - // TiledDataRank(). - size_t src_rank = shape.rank(); - if (src_spec.IsTiled()) { - src_rank = src_spec.TiledDataRank(); - } - size_t dst_rank = shape.rank(); - if (dst_spec.IsTiled()) { - dst_rank = dst_spec.TiledDataRank(); - } - std::vector src_tensor_dim_to_mesh_dim; - if (VectorGreaterThanOneElementCount( - src_spec.tile_assignment().dimensions()) == 1 && - VectorGreaterThanOneElementCount(device_mesh_.dimensions()) > 1) { - // src spec is 1D and device_mesh is 2D or 3D - src_tensor_dim_to_mesh_dim = - GetTensorDimToMeshDim(src_rank, src_spec, device_mesh_1d_); - } else { - src_tensor_dim_to_mesh_dim = - GetTensorDimToMeshDim(src_rank, src_spec, device_mesh_); - } - std::vector dst_tensor_dim_to_mesh_dim; - if (VectorGreaterThanOneElementCount( - dst_spec.tile_assignment().dimensions()) == 1 && - VectorGreaterThanOneElementCount(device_mesh_.dimensions()) > 1) { - // src spec is 1D and device_mesh is 2D or 3D - dst_tensor_dim_to_mesh_dim = - GetTensorDimToMeshDim(dst_rank, dst_spec, device_mesh_1d_); - } else { - dst_tensor_dim_to_mesh_dim = - GetTensorDimToMeshDim(dst_rank, dst_spec, device_mesh_); - } - if (src_n_dim != dst_n_dim && src_n_dim != -1 && dst_n_dim != -1) { - return ReshardingCostMixedMeshShape( - shape, src_tensor_dim_to_mesh_dim, dst_tensor_dim_to_mesh_dim, - device_mesh_.num_elements(), mesh_alpha_, mesh_beta_); - } - - AdjustTensorMeshDimMapping(src_tensor_dim_to_mesh_dim, src_n_dim); - AdjustTensorMeshDimMapping(dst_tensor_dim_to_mesh_dim, dst_n_dim); - - // Analyze the dims that need to dynamic-sliced or all-gather. 
- std::vector slice_dims; - std::vector all_gather_dims; - for (int64_t i = 0; i < std::min(src_rank, dst_rank); ++i) { - int src_mesh_dim = src_tensor_dim_to_mesh_dim[i]; - int dst_mesh_dim = dst_tensor_dim_to_mesh_dim[i]; - if (src_mesh_dim == dst_mesh_dim) { - continue; - } - if (src_mesh_dim == -1) { - slice_dims.push_back(src_mesh_dim); - continue; - } - if (dst_mesh_dim == -1) { - all_gather_dims.push_back(src_mesh_dim); - continue; - } - // Do not allow other re-sharding patterns. (e.g., collective-permute) - return kInfinityCost; - } - - // Case 1: no communication is required. Only needs dynamic-slice. - if (all_gather_dims.empty()) { - return 0; - } - - // Do not allow some strange re-sharding patterns. - if (slice_dims.size() > 1 && all_gather_dims.size() > 1) { - return kInfinityCost; - } - - // Case 2: all-to-all - if (slice_dims.size() == 1 && all_gather_dims.size() == 1) { - if (device_mesh_.dim(0) > 1 && device_mesh_.dim(1) > 1) { - return kInfinityCost; - } - - double bytes = GetBytes(shape); - return AllToAllCost(bytes, all_gather_dims.front()); - } - - // Case 3: all-gather - double bytes = GetBytes(shape) / src_spec.NumTiles(); - double cost = 0.0; - for (int dim : all_gather_dims) { - if (dim >= device_mesh_.num_dimensions()) { - return kInfinityCost; - } - bytes *= device_mesh_.dim(dim); - cost += AllGatherCost(bytes, dim); - } - return cost; - } - - // Print the information of this device mesh. - std::string ToString() { - std::string str; - absl::StrAppend(&str, "device_mesh: ", device_mesh_.ToString(), "\n"); - absl::StrAppend(&str, "mesh_alpha: ", absl::StrJoin(mesh_alpha_, " "), - "\n"); - absl::StrAppend(&str, "mesh_beta: ", absl::StrJoin(mesh_beta_, " "), "\n"); - return str; - } - - // The original, complete device mesh shape that describes the hardware. - const Array original_device_mesh_; - // When solve_nd_sharding_iteratively is true, it is a partial mesh shape from - // the original_device_mesh_. When solve_nd_sharding_iteratively is false, it - // is the same as original_device_mesh_. - const Array device_mesh_; - // Bandwidth of the device mesh - const std::vector mesh_alpha_; - const std::vector mesh_beta_; - const ProfilingResult& prof_result_; - std::vector non_zero_mesh_dims_; - const int total_devices_; - - // Cache a flatten 1d version of the device mesh. - // Used for mixed mesh shape strategies. - Array device_mesh_1d_; - - // The solver option may override the cost of communication primitives - const AutoShardingSolverOption& solver_option_; - - // Cached replica groups. Shape: [mesh_dim, group_id, ids in this group]. - std::vector>> cached_replica_groups_; - - private: - void GenerateCachedReplicaGroups() { - // One vector per device_mesh_ dimension. - cached_replica_groups_.reserve(device_mesh_.num_dimensions()); - for (size_t i = 0; i < device_mesh_.num_dimensions(); i++) { - cached_replica_groups_.push_back( - GetReplicaGroupsAlongOneDimension(device_mesh_, i)); - } - } - - void AdjustTensorMeshDimMapping(std::vector& mapping, - int64_t n_dim) const { - // Shift the non-zero dim for 1d mesh - if (n_dim == 1 && non_zero_mesh_dims_.size() == 1) { - for (size_t i = 0; i < mapping.size(); ++i) { - if (mapping[i] == 0) { - mapping[i] = non_zero_mesh_dims_.front(); - } - } - } - } -}; - -// Function declarations -// Their comments can be found in their definitions in *.cc files. 
-HloSharding Tile(const Shape& shape, absl::Span tensor_dims, - absl::Span mesh_dims, - const Array& device_mesh); - -std::vector ReshardingCostVector(const StrategyVector* strategies, - const Shape& shape, - const HloSharding& required_sharding, - const ClusterEnvironment& cluster_env); - -std::vector FollowInsCostVector(int64_t source_len, int64_t index); - -std::unique_ptr CreateLeafStrategyVector( - size_t instruction_id, const HloInstruction* ins, - const StrategyMap& strategy_map, LeafStrategies& leaf_strategies); - -void SetInNodesWithInstruction(std::unique_ptr& strategies, - const HloInstruction* ins, - const StrategyMap& strategy_map); - -void RemoveDuplicatedStrategy(std::unique_ptr& strategies); - -Status FilterStrategy(const HloInstruction* ins, const Shape& shape, - std::unique_ptr& strategies, - const ClusterEnvironment& cluster_env, - const InstructionBatchDimMap& batch_map, - const AutoShardingSolverOption& solver_option); - -Status HandleDot(std::unique_ptr& strategies, - LeafStrategies& leaf_strategies, StrategyMap& strategy_map, - const HloInstruction* ins, size_t instruction_id, - const ClusterEnvironment& cluster_env, - const InstructionBatchDimMap& batch_map, - const AutoShardingSolverOption& solver_option); - -Status HandleConv(std::unique_ptr& strategies, - LeafStrategies& leaf_strategies, StrategyMap& strategy_map, - const HloInstruction* ins, size_t instruction_id, - const ClusterEnvironment& cluster_env, - const InstructionBatchDimMap& batch_map, - const AutoShardingSolverOption& solver_option); - -void AnnotateShardingWithSimpleHeuristic(HloModule* module, - const std::string& heuristic, - const AliasMap& alias_map, - const ClusterEnvironment& cluster_env); - -// Handle alias: alias pairs must have the same HloSharding. -// To deal with alias, we do special process both before and after -// BuildStrategyAndCost. Because it is easier to handle elementwise -// instructions before BuildStrategyAndCost and it is easier to handle -// dot/conv instructions after BuildStrategyAndCost. Before -// BuildStrategyAndCost, we build an AliasMap to guide the generation of -// strategies. After BuildStrategyAndCost, we use AliasSet to add alias -// constraints in the ILP problem. -AliasMap BuildAliasMap(const HloModule* module); - -AliasSet BuildAliasSet(const HloModule* module, - const StrategyMap& strategy_map); -void CheckAliasSetCompatibility(const AliasSet& alias_set, - const LeafStrategies& leaf_strategies, - const HloInstructionSequence& sequence); } // namespace spmd } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_STRATEGY_H_ diff --git a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc index e9bdbbb731a..bbb5c831f2d 100644 --- a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc +++ b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc @@ -36,8 +36,7 @@ limitations under the License. 
#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/array.h" -#include "tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h" -#include "tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/service/hlo_sharding_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -231,21 +230,6 @@ std::optional PropagateReduceWindowSharding( return input_spec; } -// Pass through the custom call marker and get the source instruction -inline const HloInstruction* PassThroughCustomCallMarkerGetSource( - const HloInstruction* ins) { - while (ins->opcode() == HloOpcode::kGetTupleElement && - IsCustomCallMarker(ins->operand(0))) { - const HloInstruction* custom_call = ins->operand(0); - const HloInstruction* tuple = custom_call->operand(0); - while (IsCustomCallMarker(tuple)) { - tuple = tuple->operand(0); - } - ins = tuple->operand(ins->tuple_index()); - } - return ins; -} - // Depth analysis (breadth first search). // We also assign a much larger distance to heavy operators (e.g., dot, // convolution). @@ -446,6 +430,7 @@ void BatchDimMapForward(const std::vector& instructions, case HloOpcode::kSin: case HloOpcode::kSqrt: case HloOpcode::kCbrt: + case HloOpcode::kTan: case HloOpcode::kTanh: // Binary elementwise operations case HloOpcode::kAdd: @@ -704,6 +689,7 @@ void BatchDimMapBackward(const std::vector& instructions, case HloOpcode::kSin: case HloOpcode::kSqrt: case HloOpcode::kCbrt: + case HloOpcode::kTan: case HloOpcode::kTanh: // Binary elementwise operations case HloOpcode::kAdd: @@ -968,58 +954,7 @@ void RemoveDuplicatedStrategy(std::unique_ptr& strategies) { } } -// Filter strategies according to the solver_option.force_batch_dim_to_mesh_dim. -// This can be used to forcibly generate data-parallel strategies. -Status FilterStrategy(const HloInstruction* ins, const Shape& shape, - std::unique_ptr& strategies, - const ClusterEnvironment& cluster_env, - const InstructionBatchDimMap& batch_map, - const AutoShardingSolverOption& solver_option) { - int mesh_dim = solver_option.force_batch_dim_to_mesh_dim; - int batch_dim = batch_map.at(GetBatchDimMapKey(ins)); - const Array& device_mesh = cluster_env.device_mesh_; - - if (shape.dimensions(batch_dim) % device_mesh.dim(mesh_dim) != 0) { - return tsl::errors::InvalidArgument( - "The length of batch dimension is " - "not divisible by the number of devices"); - } - - std::vector new_leaf_vector; - for (auto& stra : strategies->leaf_vector) { - std::vector tensor_dim_to_mesh_dim = - cluster_env.GetTensorDimToMeshDimWrapper(shape, stra.output_sharding); - - if (device_mesh.dim(mesh_dim) > 1) { - // If the mesh dim is not one, the output tensor must be - // tiled along the mesh dim. - if (tensor_dim_to_mesh_dim[batch_dim] == mesh_dim) { - new_leaf_vector.push_back(std::move(stra)); - } - } else { - // If the mesh dim is one, the output tensor must be replicated - // on the mesh dim. 
- if (tensor_dim_to_mesh_dim[batch_dim] == -1) { - new_leaf_vector.push_back(std::move(stra)); - } - } - } - CHECK(!new_leaf_vector.empty()) - << ins->ToString() << " does not have any valid strategies"; - strategies->leaf_vector = std::move(new_leaf_vector); - - return OkStatus(); -} - -inline std::pair ParseMeshDims(const std::string& strategy_name) { - if (absl::StrContains(strategy_name, "{0,1}")) { - return {0, 1}; - } - return {1, 0}; -} -// Return whether the tensor shape is divisible by -// the number of devices along multiple dimensions. bool IsDivisible(const HloInstruction* ins, const Array& device_mesh, absl::Span tensor_dims, absl::Span mesh_dims) { @@ -1034,181 +969,6 @@ bool IsDivisible(const HloInstruction* ins, const Array& device_mesh, return true; } -// Return the output sharding of the reduce-scatter variant of a given strategy. -HloSharding GetReduceScatterOutput(const HloInstruction* ins, - const ShardingStrategy& strategy, - const ClusterEnvironment& cluster_env) { - const Array& device_mesh = cluster_env.device_mesh_; - const Array& device_mesh_1d = cluster_env.device_mesh_1d_; - - if (ins->opcode() == HloOpcode::kDot) { - const DotDimensionNumbers& dot_dnums = ins->dot_dimension_numbers(); - int64_t space_base_dim = dot_dnums.lhs_batch_dimensions_size(); - - if (absl::StartsWith(strategy.name, "SR = SS x SR") || - absl::StartsWith(strategy.name, "RS = RS x SS")) { - int mesh_dim0, mesh_dim1; - std::tie(mesh_dim0, mesh_dim1) = ParseMeshDims(strategy.name); - - if (!IsDivisible(ins, device_mesh, {space_base_dim, space_base_dim + 1}, - {mesh_dim0, mesh_dim1})) { - // XLA supports uneven partitioning by adding padding. - // However, the ShardingSpec in Jax does not support uneven - // partitioning. - return Undefined(); - } - - return Tile(ins->shape(), {space_base_dim, space_base_dim + 1}, - {mesh_dim0, mesh_dim1}, device_mesh); - } - if (absl::StartsWith(strategy.name, "SbR = SbSk x SbSk")) { - int mesh_dim0, mesh_dim1; - std::tie(mesh_dim0, mesh_dim1) = ParseMeshDims(strategy.name); - - if (!IsDivisible(ins, device_mesh, {0, space_base_dim}, - {mesh_dim0, mesh_dim1})) { - // XLA supports uneven partitioning by adding padding. - // However, the ShardingSpec in Jax does not support uneven - // partitioning. - return Undefined(); - } - - return Tile(ins->shape(), {0, space_base_dim}, {mesh_dim0, mesh_dim1}, - device_mesh); - } - if (absl::StartsWith(strategy.name, "RR = RS x SR")) { - int mesh_dim = absl::StrContains(strategy.name, "{0}") ? 
0 : 1; - - if (!IsDivisible(ins, device_mesh, {space_base_dim}, {mesh_dim})) { - return Undefined(); - } - - return Tile(ins->shape(), {space_base_dim}, {mesh_dim}, device_mesh); - } - if (absl::StartsWith(strategy.name, "R = Sk x Sk")) { - int mesh_dim = 0; - - if (!IsDivisible(ins, device_mesh_1d, {space_base_dim}, {mesh_dim})) { - return Undefined(); - } - - return Tile(ins->shape(), {space_base_dim}, {mesh_dim}, device_mesh_1d); - } - } else if (ins->opcode() == HloOpcode::kConvolution) { - const ConvolutionDimensionNumbers& conv_dnums = - ins->convolution_dimension_numbers(); - int out_batch_dim = conv_dnums.output_batch_dimension(); - int out_out_channel_dim = conv_dnums.output_feature_dimension(); - - if (absl::StartsWith(strategy.name, "SR = SS x SR") || - absl::StartsWith(strategy.name, "RS = RS x SS")) { - int mesh_dim0, mesh_dim1; - std::tie(mesh_dim0, mesh_dim1) = ParseMeshDims(strategy.name); - - if (!IsDivisible(ins, device_mesh, {out_batch_dim, out_out_channel_dim}, - {mesh_dim0, mesh_dim1})) { - return Undefined(); - } - - return Tile(ins->shape(), {out_batch_dim, out_out_channel_dim}, - {mesh_dim0, mesh_dim1}, device_mesh); - } - if (absl::StartsWith(strategy.name, "R = Sk x Sk")) { - int mesh_dim = 0; - - if (!IsDivisible(ins, device_mesh_1d, {out_batch_dim}, {mesh_dim})) { - return Undefined(); - } - - return Tile(ins->shape(), {out_batch_dim}, {mesh_dim}, device_mesh_1d); - } - } else if (ins->opcode() == HloOpcode::kReduce) { - // TODO(zhuohan): support more cases. - CHECK_EQ(ins->shape().rank(), 1); - - int mesh_dim; - if (absl::StrContains(strategy.name, "allreduce @ [0]")) { - mesh_dim = 0; - } else { - mesh_dim = 1; - } - - if (strategy.output_sharding.IsReplicated()) { - if (absl::StrContains(strategy.name, "1d")) { - if (!IsDivisible(ins, device_mesh_1d, {0}, {mesh_dim})) { - return Undefined(); - } - - return Tile(ins->shape(), {0}, {mesh_dim}, device_mesh_1d); - } - if (!IsDivisible(ins, device_mesh, {0}, {mesh_dim})) { - return Undefined(); - } - - return Tile(ins->shape(), {0}, {mesh_dim}, device_mesh); - } - if (!IsDivisible(ins, device_mesh_1d, {0}, {0})) { - return Undefined(); - } - - Array tile_assignment = strategy.output_sharding.tile_assignment(); - tile_assignment.Reshape({cluster_env.total_devices_}); - return HloSharding::Tile(std::move(tile_assignment)); - - } else { - LOG(FATAL) << "Invalid instruction: " << ins->ToString(); - } - - return Undefined(); -} - -// Return whether an instruction has the opportunity to generate reduce-scatter. -bool HasReduceScatterOpportunity( - const HloInstruction* inst, const StrategyMap& strategy_map, - const CostGraph& cost_graph, absl::Span s_val, - const StableHashSet& modified) { - // If the operand is already modified by other ops, skip this instruction to - // avoid conflicts. - for (const HloInstruction* operand : inst->operands()) { - if (modified.contains(operand)) { - return false; - } - } - if (modified.contains(inst)) { - return false; - } - - if (inst->opcode() == HloOpcode::kReduce && inst->shape().rank() == 1) { - return true; - } - if (inst->opcode() == HloOpcode::kDot) { - if (GetShardingStrategy(inst->operand(0), strategy_map, cost_graph, s_val) - .output_sharding.IsReplicated() && - GetShardingStrategy(inst->operand(1), strategy_map, cost_graph, s_val) - .output_sharding.IsReplicated()) { - // This dot is replicated on all devices. Do not split it. - // TODO(zhuohan): improve this condition. 
- return false; - } - - return true; - } - if (inst->opcode() == HloOpcode::kConvolution) { - return true; - } - - return false; -} - -// Return whether all users of an instruction is reduce. -bool AllUsersAreReduce(const HloInstruction* inst) { - for (const HloInstruction* user : inst->users()) { - if (user->opcode() != HloOpcode::kReduce) { - return false; - } - } - return true; -} // Set sharding, and apply transpose if necessary. void SetSharding(HloInstruction* to_split, const HloSharding& output_spec, @@ -1226,8 +986,6 @@ void SetSharding(HloInstruction* to_split, const HloSharding& output_spec, } } -// Return whether the instruction is always replicated. -// (e.g., constant, broadcasted constant, scalar) bool IsAlwaysReplicated(const HloInstruction* inst) { if (inst->opcode() == HloOpcode::kConstant) { return true; @@ -1241,149 +999,7 @@ bool IsAlwaysReplicated(const HloInstruction* inst) { return false; } -// Return whether this instruction is a convert on a parameter. -bool IsParameterConvert(const HloInstruction* inst) { - if (inst->opcode() == HloOpcode::kConvert && - inst->operand(0)->opcode() == HloOpcode::kParameter) { - return true; - } - return false; -} - -// Pass through the custom call marker and get the acutal operand. -inline HloInstruction* PassThroughCustomCallMarkerOperand( - HloInstruction* raw_operand, const HloInstruction* inst) { - if (!IsCustomCallMarker(raw_operand)) { - return raw_operand; - } - - CHECK_EQ(inst->opcode(), HloOpcode::kGetTupleElement); - - int index = inst->tuple_index(); - return raw_operand->mutable_operand(0)->mutable_operand(index); -} - -// Return whether the tuple is only used by a custom call marker. -inline bool IsCustomCallMarkerTuple(const HloInstruction* inst) { - return inst->opcode() == HloOpcode::kTuple && inst->users().size() == 1 && - IsCustomCallMarker(inst->users().front()); -} - -// Pass through the custom call marker and get the actual user. -inline HloInstruction* PassThroughCustomCallMarkerUser( - HloInstruction* raw_user, const HloInstruction* inst) { - if (!IsCustomCallMarkerTuple(raw_user)) { - return raw_user; - } - - const HloInstruction* custom_call = raw_user->users().front(); - - int index = -1; - for (int i = 0; i < raw_user->operand_count(); i++) { - if (raw_user->operand(i) == inst) { - index = i; - break; - } - } - CHECK_NE(index, -1); - - HloInstruction* ret = nullptr; - for (HloInstruction* user : custom_call->users()) { - CHECK_EQ(user->opcode(), HloOpcode::kGetTupleElement); - if (user->tuple_index() == index) { - CHECK_EQ(ret, nullptr); - ret = user; - } - } - - return ret == nullptr ? raw_user : ret; -} - -// Return the users of an instruction and its alias, -// excluding the final output tuple. -inline StableHashSet UsersWithAlias( - const HloInstruction* inst, const AliasMap& alias_map, - const HloInstruction* output) { - StableHashSet users; - for (HloInstruction* user : inst->users()) { - users.insert(PassThroughCustomCallMarkerUser(user, inst)); - } - - auto iter = alias_map.find(inst); - if (iter != alias_map.end()) { - for (HloInstruction* user : iter->second->users()) { - users.insert(PassThroughCustomCallMarkerUser(user, iter->second)); - } - } - - users.erase(output); - return users; -} - -// DFS to find the replicated set starting from cur instruction. 
-void FindReplicateSet( - HloInstruction* cur, const AliasMap& alias_map, const CostGraph& cost_graph, - absl::Span s_val, const StrategyMap& strategy_map, - const ShardingStrategy& strategy, const HloInstruction* output, - bool do_all_gather_after_backward, HloInstruction*& transpose_inst, - StableHashSet& replicated_set, - StableHashSet& boundary_set, - StableHashSet& consumer_set, - StableHashSet& visited) { - visited.insert(cur); - - // Check whether the node is a boundary node. - StableHashSet users = UsersWithAlias(cur, alias_map, output); - for (HloInstruction* consumer : users) { - const HloInstruction* shape_inst = cur; - - // Allow at most one transpose - if (consumer->opcode() == HloOpcode::kTranspose && - (transpose_inst == nullptr || - DimensionsEqual(transpose_inst->shape(), consumer->shape()))) { - shape_inst = consumer; - transpose_inst = consumer; - // TODO(zhuohan): fix output_sharding comparison. - } - - if (consumer->opcode() == HloOpcode::kTuple || - (do_all_gather_after_backward && IsParameterConvert(consumer)) || - GetShardingStrategy(consumer, strategy_map, cost_graph, s_val) - .output_sharding != strategy.output_sharding || - !DimensionsEqual(consumer->shape(), shape_inst->shape())) { - boundary_set.insert(cur); - return; - } - } - - // If this node is not a boundary node, propagate from this node. - replicated_set.insert(cur); - for (HloInstruction* consumer : users) { - if (!visited.contains(consumer)) { - consumer_set.insert(consumer); - FindReplicateSet(consumer, alias_map, cost_graph, s_val, strategy_map, - strategy, output, do_all_gather_after_backward, - transpose_inst, replicated_set, boundary_set, - consumer_set, visited); - } - } - - for (size_t i = 0; i < cur->operand_count(); ++i) { - HloInstruction* operand = cur->mutable_operand(i); - operand = PassThroughCustomCallMarkerOperand(operand, cur); - - if (!visited.contains(operand) && !IsAlwaysReplicated(operand) && - GetShardingStrategy(operand, strategy_map, cost_graph, s_val) - .output_sharding == strategy.output_sharding && - DimensionsEqual(operand->shape(), cur->shape())) { - FindReplicateSet(operand, alias_map, cost_graph, s_val, strategy_map, - strategy, output, do_all_gather_after_backward, - transpose_inst, replicated_set, boundary_set, - consumer_set, visited); - } - } -} // Try to reduce the boundary set to its common ancestor void TryReduceWithCommonAncestor(StableHashSet& replicated_set, @@ -1476,281 +1092,6 @@ void UseAllReduceForGradAcc(StableHashSet& replicated_set, } } -// Substitute all-reduce strategies with their reduce-scatter variants. -void GenerateReduceScatter(const HloInstructionSequence& sequence, - const AliasMap& alias_map, - const InstructionDepthMap& depth_map, - const StrategyMap& strategy_map, - const CostGraph& cost_graph, - absl::Span s_val, - const ClusterEnvironment& cluster_env, - const AutoShardingSolverOption& solver_option) { - const std::vector& instructions = sequence.instructions(); - - // Propagation ends at output - const HloInstruction* output = instructions.back(); - if (IsCustomCallMarker(output)) { - output = output->operand(0); - } - - // A debug option: whether to do all-gather after backward pass. - // This controls the location of all-gather. - // If true, all-gather happens after backward pass, which is desired for - // gradient accumulation. If false, all-gather happens before forward pass, - // which can partitions more tensors. 
- bool do_all_gather_after_backward = true; - - // If true, do not actually generate reduce-scatter + all-gather, - // but generate all-reduce + all-gather instead. - // This saves less memory but is more friendly to gradient accumulation. - // This is a temporary workaround due to implementation difficulty. - // Ideally, we should be able to generate a gradient-accumulation-friendly - // reduce-scatter + all-gather, but for now it is not easy to implement this - // in our current system. So we generate a gradient-accumulation-friendly - // all-reduce + all-gather, which has the same memory consumption but with 50% - // communication overhead. - bool use_all_reduce_for_grad_acc = - solver_option.reduce_scatter_grad_acc_friendly; - - std::vector insert_all_gather; - StableHashSet modified; - - for (HloInstruction* inst : instructions) { - if (!HasReduceScatterOpportunity(inst, strategy_map, cost_graph, s_val, - modified)) { - continue; - } - const ShardingStrategy& strategy = - GetShardingStrategy(inst, strategy_map, cost_graph, s_val); - if (!absl::StrContains(strategy.name, "allreduce")) { - continue; - } - - StableHashSet replicated_set; - StableHashSet boundary_set; - StableHashSet consumer_set; - StableHashSet visited; - - // We allow at most one transpose in the path of replication analysis. - HloInstruction* transpose_inst = nullptr; - - // Find the replicated set starting from the all-reduce instruction. - visited.insert(output); - FindReplicateSet(inst, alias_map, cost_graph, s_val, strategy_map, strategy, - output, do_all_gather_after_backward, transpose_inst, - replicated_set, boundary_set, consumer_set, visited); - - // Try to reduce the boundary set to its common ancestor - TryReduceWithCommonAncestor(replicated_set, boundary_set, consumer_set, - alias_map); - - // Analyze the instructions after which all-gather should be inserted. - std::vector need_all_gather; - for (HloInstruction* node : boundary_set) { - if (consumer_set.contains(node)) { - if (AllUsersAreReduce(node)) { - // If users are reduce, the all-gather cost after this instruction - // should be small, so we ignore all-gather cost of these - // instructions. - replicated_set.insert(node); - } else { - need_all_gather.push_back(node); - } - } - } - - // If we do all-gather on some parameters, move this all-gather after - // backward. - if (do_all_gather_after_backward && need_all_gather.size() == 1) { - HloInstruction* point = need_all_gather.front(); - std::vector path; - - HloInstruction* root = point; - while (true) { - path.push_back(root); - if (root->opcode() == HloOpcode::kGetTupleElement) { - root = PassThroughCustomCallMarkerOperand(root->mutable_operand(0), - root); - } else { - break; - } - } - - if (root->opcode() == HloOpcode::kParameter) { - for (auto x : path) { - replicated_set.erase(x); - boundary_set.erase(x); - } - need_all_gather.clear(); - for (auto x : replicated_set) { - auto iter = alias_map.find(x); - if (iter != alias_map.end() && iter->second == root) { - boundary_set.insert(x); - need_all_gather.push_back(x); - break; - } - } - } - } - - // Analyze how many parameters can be partitioned if we do this - // transformation. 
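The use_all_reduce_for_grad_acc comment earlier in this hunk quotes a 50% communication overhead for the all-reduce + all-gather fallback. A minimal standalone sketch of where that figure comes from, using the usual ring-collective byte counts; the group size and tensor size below are invented for illustration and are not part of this patch:

#include <cstdio>

// Approximate per-device bytes sent by ring collectives on an n-device group
// for a tensor of `bytes` bytes.
double RingAllReduceBytes(int n, double bytes) { return 2.0 * (n - 1) / n * bytes; }
double RingReduceScatterBytes(int n, double bytes) { return 1.0 * (n - 1) / n * bytes; }
double RingAllGatherBytes(int n, double bytes) { return 1.0 * (n - 1) / n * bytes; }

int main() {
  const int n = 8;            // hypothetical group size
  const double bytes = 1e9;   // hypothetical 1 GB of gradients
  // Option A: reduce-scatter now, all-gather later (weights stay sharded).
  double a = RingReduceScatterBytes(n, bytes) + RingAllGatherBytes(n, bytes);
  // Option B: all-reduce now, all-gather later (the grad-acc-friendly fallback).
  double b = RingAllReduceBytes(n, bytes) + RingAllGatherBytes(n, bytes);
  std::printf("A = %.2e, B = %.2e, extra traffic = %.0f%%\n", a, b,
              (b / a - 1.0) * 100.0);  // ~50% more traffic for option B
  return 0;
}

Both options keep the same tensors materialized, which is why the comment describes the memory consumption as unchanged while only the traffic grows.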
- int num_replicated_parameters = 0; - for (const HloInstruction* node : replicated_set) { - if (node->opcode() == HloOpcode::kParameter) { - num_replicated_parameters++; - } - } - for (const HloInstruction* to_split : need_all_gather) { - if (to_split->users().size() == 1 && - to_split->users().front() == output && alias_map.contains(to_split)) { - // Move the all-gather to its alias parameter. - num_replicated_parameters++; - } - } - - // Print replicated set and boundary set for debugging. - VLOG(10) << inst->ToString(HloPrintOptions::ShortParsable()) << "\n"; - VLOG(10) << "replicated set (#parameter: " << num_replicated_parameters - << "):\n"; - for (auto x : replicated_set) { - VLOG(10) << " " << x->ToString(HloPrintOptions::ShortParsable()) << "\n"; - } - VLOG(10) << "boundary set (#incompatible: " << need_all_gather.size() - << "):\n"; - for (auto x : boundary_set) { - VLOG(10) << " " << x->ToString(HloPrintOptions::ShortParsable()) << " " - << absl::c_linear_search(need_all_gather, x) << "\n"; - } - - // If applicable, replace all-reduce with reduce-scatter by - // setting instructions' sharding. - if (num_replicated_parameters >= 1 && need_all_gather.size() <= 1 && - replicated_set.size() >= 5) { - HloSharding output_spec = - GetReduceScatterOutput(inst, strategy, cluster_env); - if (IsUndefined(output_spec)) { - continue; - } - - VLOG(10) << "SET: " << output_spec.ToString(); - - if (absl::StartsWith(strategy.name, "RR = RS x SR")) { - // If set the sharding for this dot instruction, the SPMD - // partitioner will generate bad fallback code. - replicated_set.erase(inst); - } - - if (use_all_reduce_for_grad_acc) { - UseAllReduceForGradAcc(replicated_set, inst); - } - - for (HloInstruction* to_split : replicated_set) { - SetSharding(to_split, output_spec, inst, transpose_inst, modified); - } - - if (!solver_option.reduce_scatter_aggressive_partition) { - // The normal case - for (HloInstruction* to_split : need_all_gather) { - SetSharding(to_split, output_spec, inst, transpose_inst, modified); - - if (!do_all_gather_after_backward && to_split->users().size() == 1 && - to_split->users().front() == output && - alias_map.contains(to_split)) { - // Move the all-gather to its alias parameter. - // This partitions more tensors but introduces communication - // in the forward pass, which is not desired in gradient - // accumulation. - SetSharding(alias_map.at(to_split), output_spec, inst, - transpose_inst, modified); - insert_all_gather.push_back(alias_map.at(to_split)); - } else { - insert_all_gather.push_back(to_split); - - if (to_split->opcode() == HloOpcode::kGetTupleElement && - IsCustomCallMarker(to_split->operand(0)) && - to_split->users().size() == 1 && - to_split->users().front() == output) { - insert_all_gather.push_back(PassThroughCustomCallMarkerOperand( - to_split->mutable_operand(0), to_split)); - } - } - } - } else { - // Aggressively partition more parameter tensors. - // This can result in a strategy similar to ZeRO stage 3. - // NOTE: The combination of this branch with pipeline parallel is not - // tested. - for (HloInstruction* to_split : need_all_gather) { - SetSharding(to_split, output_spec, inst, transpose_inst, modified); - - if (to_split->users().size() == 1 && - to_split->users().front() == output && - alias_map.contains(to_split)) { - // Move the all-gather to its alias parameter. 
- HloInstruction* param = alias_map.at(to_split); - - // Find the branching point (i.e., skip elementwise ops like - // convert) - HloInstruction* cur = param; - while (cur->users().size() == 1) { - // TODO(zhuohan): handle tuple. - CHECK(cur->shape().IsArray()); - SetSharding(cur, output_spec, inst, transpose_inst, modified); - cur = cur->users().front(); - } - SetSharding(cur, output_spec, inst, transpose_inst, modified); - - CHECK(!cur->users().empty()); - - // Find the first user - HloInstruction* first_user = nullptr; - int64_t min_depth = ((int64_t)1) << 50; - for (const auto& x : cur->users()) { - auto iter = depth_map.find(x); - if (iter == depth_map.end()) { - LOG(FATAL) << "ERROR: " << x->ToString(); - } - if (x->opcode() != HloOpcode::kConvolution && - x->opcode() != HloOpcode::kDot) { - // Only apply this aggressive optimization for dot and conv - continue; - } - if (iter->second < min_depth) { - first_user = x; - min_depth = iter->second; - } - } - - if (first_user != nullptr) { - // Insert an identity to prevent CSE of all-gather - HloInstruction* identity = inst->parent()->AddInstruction( - HloInstruction::CreateCustomCall(cur->shape(), {cur}, - kIdentityMarker)); - SetSharding(identity, output_spec, inst, transpose_inst, - modified); - ReplaceOperand(first_user, cur, identity); - } - } - } - } - } - - VLOG(10) << "-----------------------done\n"; - } - - // Insert all-gather on the output of boundary nodes by setting - // their shardings. This also works as CSE of all-gather. - for (HloInstruction* inst : insert_all_gather) { - HloInstruction* replace_with = inst->parent()->AddInstruction( - HloInstruction::CreateReshape(inst->shape(), inst)); - replace_with->set_sharding( - GetShardingStrategy(inst, strategy_map, cost_graph, s_val) - .output_sharding); - TF_CHECK_OK(inst->ReplaceAllUsesWith(replace_with)); - } -} - void RemoveCustomCallMarker(HloModule* module) { HloComputation* entry_computation = module->entry_computation(); @@ -2029,133 +1370,21 @@ void FixMixedMeshShapeResharding(HloInstruction* inst, int operand_num, TF_CHECK_OK(inst->ReplaceOperandWith(operand_num, replace_with)); } -template -inline std::vector Argsort(const std::vector& scores) { - std::vector index; - index.reserve(scores.size()); - for (size_t i = 0; i < scores.size(); ++i) { - index.push_back(i); +bool IsParameterConvert(const HloInstruction* inst) { + if (inst->opcode() == HloOpcode::kConvert && + inst->operand(0)->opcode() == HloOpcode::kParameter) { + return true; } - auto cmp = [&scores](int l, int r) { return scores[l] > scores[r]; }; - std::sort(index.begin(), index.end(), cmp); - return index; + return false; } -void AnnotateShardingWithSimpleHeuristic( - HloModule* module, const std::string& heuristic, const AliasMap& alias_map, - const ClusterEnvironment& cluster_env) { - const Array& device_mesh = cluster_env.device_mesh_; - const Array& device_mesh_1d = cluster_env.device_mesh_1d_; - int64_t num_devices = device_mesh.num_elements(); - - // Count the non-one mesh dimension. 
- size_t mesh_nn_dims = 0; - for (int dim : device_mesh.dimensions()) { - if (dim > 1) { - mesh_nn_dims++; - } - } - - // Shard instructions - HloComputation* entry_computation = module->entry_computation(); - for (HloInstruction* inst : entry_computation->instructions()) { - if (inst->opcode() == HloOpcode::kParameter) { - HloSharding output_spec = HloSharding::Replicate(); - inst->set_sharding(output_spec); - - if (heuristic == "shard-largest") { - std::vector lengths; - for (int64_t i = 0; i < inst->shape().rank(); ++i) { - lengths.push_back(inst->shape().dimensions(i)); - } - - std::vector indices = Argsort(lengths); - int common_dims = std::min(mesh_nn_dims, indices.size()); - - if (common_dims < 1) { - continue; - } - - if (common_dims == 1) { - int dim = indices[0]; - int length = lengths[dim]; - if (length % num_devices == 0) { - output_spec = Tile(inst->shape(), {dim}, {0}, device_mesh_1d); - } - } else { - int dim1 = indices[0]; - int length1 = lengths[dim1]; - int dim0 = indices[1]; - int length0 = lengths[dim0]; - - if (length0 % device_mesh.dim(0) == 0 && - length1 % device_mesh.dim(1) == 0) { - output_spec = - Tile(inst->shape(), {dim0, dim1}, {0, 1}, device_mesh); - } - } - } else if (heuristic == "shard-first") { - if (inst->shape().rank() > 0 && - inst->shape().dimensions(0) % num_devices == 0) { - output_spec = Tile(inst->shape(), {0}, {0}, device_mesh_1d); - } - } else if (heuristic == "shard-last") { - int64_t last_dim = inst->shape().rank() - 1; - if (inst->shape().rank() > 0 && - inst->shape().dimensions(last_dim) % num_devices == 0) { - output_spec = Tile(inst->shape(), {last_dim}, {0}, device_mesh_1d); - } - } else { - LOG(FATAL) << "Invalid heuristic: " << heuristic; - } - - inst->set_sharding(output_spec); - // std::cerr << "ins: " << inst->ToString() << ", spec: " << - // output_spec.ToString() << std::endl; - } else if (inst->opcode() == HloOpcode::kDot) { - const HloInstruction* lhs = inst->operand(0); - const HloInstruction* rhs = inst->operand(1); - const DotDimensionNumbers& dot_dnums = inst->dot_dimension_numbers(); - // const auto& lhs_con_dims = dot_dnums.lhs_contracting_dimensions(); - // const auto& rhs_con_dims = dot_dnums.rhs_contracting_dimensions(); - std::vector lhs_space_dims, rhs_space_dims; - std::tie(lhs_space_dims, rhs_space_dims) = - GetSpaceDims(lhs->shape(), rhs->shape(), dot_dnums); - } - } - - // Meet the alias requirement for the output tuple. 
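The "shard-largest" heuristic deleted above sorts dimension lengths with Argsort and tiles the largest dimensions that divide evenly by the mesh. A self-contained sketch of just that selection logic, with an invented tensor shape and a hypothetical 4x2 mesh (the real code then builds an HloSharding via Tile, which is omitted here):

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Indices of `scores` sorted by decreasing value, mirroring the Argsort helper.
std::vector<int> Argsort(const std::vector<int64_t>& scores) {
  std::vector<int> index(scores.size());
  for (size_t i = 0; i < index.size(); ++i) index[i] = static_cast<int>(i);
  std::sort(index.begin(), index.end(),
            [&](int l, int r) { return scores[l] > scores[r]; });
  return index;
}

int main() {
  std::vector<int64_t> dims = {12, 256, 1024};  // hypothetical tensor shape
  std::vector<int64_t> mesh = {4, 2};           // hypothetical 2D device mesh
  std::vector<int> order = Argsort(dims);       // {2, 1, 0}
  int64_t second = dims[order[1]], largest = dims[order[0]];
  if (second % mesh[0] == 0 && largest % mesh[1] == 0) {
    std::printf("tile dim %d on mesh dim 0, dim %d on mesh dim 1\n",
                order[1], order[0]);
  } else {
    std::printf("keep the parameter replicated\n");
  }
  return 0;
}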
- HloInstruction* output = entry_computation->root_instruction(); - const Shape& out_shape = output->shape(); - ShapeTree tuple_sharding(out_shape, HloSharding::Replicate()); - std::vector flattened_shardings; - - std::function get_flattened_shardings; - get_flattened_shardings = [&](HloInstruction* cur) { - for (int64_t i = 0; i < cur->operand_count(); ++i) { - HloInstruction* operand = cur->mutable_operand(i); - - if (operand->shape().IsTuple()) { - get_flattened_shardings(operand); - } else { - if (alias_map.contains(operand)) { - operand = alias_map.at(operand); - } - if (!operand->has_sharding()) { - operand->set_sharding(HloSharding::Replicate()); - } - CHECK(operand->has_sharding()); - flattened_shardings.push_back(operand->sharding()); - } +bool AllUsersAreReduce(const HloInstruction* inst) { + for (const HloInstruction* user : inst->users()) { + if (user->opcode() != HloOpcode::kReduce) { + return false; } - }; - get_flattened_shardings(output); - int i = 0; - for (auto& leaf : tuple_sharding.leaves()) { - leaf.second = flattened_shardings[i++]; } - CHECK_EQ(i, flattened_shardings.size()); - output->set_sharding(HloSharding::Tuple(tuple_sharding)); + return true; } std::vector GetDimensionMapping( @@ -2633,7 +1862,6 @@ bool AdjustShardingsWithPartialMeshShape( if (!inst->has_sharding()) { continue; } - LOG(INFO) << inst->ToString(); if (inst->shape().IsTuple()) { ShapeTree output_tuple_sharding(inst->shape(), Undefined()); std::vector output_flattened_shardings; @@ -2703,5 +1931,18 @@ bool OutputInputSameShapes(const HloInstruction* ins) { return true; } +bool IsEntryComputationInputOrOutput(const HloModule* module, + const HloInstruction* ins) { + for (const auto param : + module->entry_computation()->parameter_instructions()) { + if (param->name() == ins->name()) { + return true; + } + } + if (module->entry_computation()->root_instruction() == ins) { + return true; + } + return false; +} } // namespace spmd } // namespace xla diff --git a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_util.h b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_util.h index a97c07165ab..dfba0d6a459 100644 --- a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_util.h +++ b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_util.h @@ -32,38 +32,33 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/array.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_schedule.h" -#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_schedule.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_sharding.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" namespace xla { namespace spmd { -// Type alias - -template -using StableHashMap = ::absl::flat_hash_map; -template -using StableHashSet = ::absl::flat_hash_set; - -// Map an instruction to its depth. -using InstructionDepthMap = StableHashMap; -// Map an instruction to its batch dimension. -using InstructionBatchDimMap = StableHashMap; -// Map an instruction to its alias source parameter. -using AliasMap = StableHashMap; -// Map an instruction to its resharding cache. 
-using ReshardingCache = - StableHashMap>>; inline constexpr absl::string_view kPipelineMarker = "xla_pipeline_marker"; inline constexpr absl::string_view kIdentityMarker = "identity"; inline constexpr absl::string_view kPipelineMarkerStartType = "start"; inline constexpr absl::string_view kPipelineMarkerEndType = "end"; +inline std::pair ParseMeshDims(const std::string& strategy_name) { + if (absl::StrContains(strategy_name, "{0,1}")) { + return {0, 1}; + } + return {1, 0}; +} + +// Return whether the tensor shape is divisible by +// the number of devices along multiple dimensions. +bool IsDivisible(const HloInstruction* ins, const Array& device_mesh, + absl::Span tensor_dims, + absl::Span mesh_dims); + // Array/Vector/Matrix Utility // Append elements of `array` to `result`. The `indices` is a generalized @@ -131,86 +126,6 @@ std::string ToString(absl::Span span) { return absl::StrCat("[", absl::StrJoin(span, ", "), "]"); } -// A simple matrix class to store and manipulate the cost matrices on edges. -// It can create a view for matrix transpose without copying the memory. -// TODO (zhuohan): Inherit from Array2D and add Transpose and operator+ (See -// tensorflow/compiler/xla/array2d.h;l=39) -class Matrix { - public: - Matrix() : n_(0), m_(0), transpose_(false), data_(nullptr) {} - - Matrix(size_t n, size_t m) { - this->n_ = n; - this->m_ = m; - transpose_ = false; - data_ = std::make_shared>(n * m, 0.0); - } - - Matrix(size_t n, size_t m, bool transpose, - std::shared_ptr> data) { - this->n_ = n; - this->m_ = m; - this->transpose_ = transpose; - this->data_ = data; - } - - Matrix Transpose() { return Matrix(m_, n_, !transpose_, data_); } - - double operator()(size_t i, size_t j) const { - size_t idx; - if (transpose_) { - idx = j * n_ + i; - } else { - idx = i * m_ + j; - } - CHECK(data_ != nullptr) << n_ << " , " << m_; - CHECK(idx < n_ * m_) << idx << " , " << n_ << " , " << m_; - return (*data_)[idx]; - } - - double& operator()(size_t i, size_t j) { - size_t idx; - if (transpose_) { - idx = j * n_ + i; - } else { - idx = i * m_ + j; - } - CHECK(data_ != nullptr) << n_ << " , " << m_; - CHECK(idx < n_ * m_) << idx << " , " << n_ << " , " << m_; - return (*data_)[idx]; - } - - Matrix operator+(const Matrix& other) { - CHECK_EQ(n_, other.n_); - CHECK_EQ(m_, other.m_); - Matrix ret = Matrix(n_, m_); - for (size_t i = 0; i < n_; ++i) { - for (size_t j = 0; j < m_; ++j) { - ret(i, j) = operator()(i, j) + other(i, j); - } - } - return ret; - } - - std::string ToString() const { - std::string str; - - for (size_t i = 0; i < n_; ++i) { - for (size_t j = 0; j < m_; ++j) { - absl::StrAppend(&str, operator()(i, j), " "); - } - absl::StrAppend(&str, "\n"); - } - - return str; - } - - size_t n_; - size_t m_; - bool transpose_; - std::shared_ptr> data_; -}; - // Shape Utility // Get the bytes of an array shape without checking its layout. 
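ParseMeshDims above recovers the mesh-dimension order encoded in a strategy name, and IsDivisible checks that the chosen tensor dimensions split evenly across the chosen mesh dimensions. A simplified, self-contained mock-up (plain vectors stand in for the HLO shape and the device mesh, and the strategy name and sizes are invented):

#include <cstdint>
#include <cstdio>
#include <string>
#include <utility>
#include <vector>

std::pair<int64_t, int64_t> ParseMeshDimsSketch(const std::string& name) {
  if (name.find("{0,1}") != std::string::npos) return {0, 1};
  return {1, 0};
}

// True iff every selected tensor dimension is a multiple of the paired mesh
// dimension, the same condition the real IsDivisible enforces.
bool IsDivisibleSketch(const std::vector<int64_t>& tensor_shape,
                       const std::vector<int64_t>& mesh_shape,
                       const std::vector<int>& tensor_dims,
                       const std::vector<int>& mesh_dims) {
  for (size_t i = 0; i < tensor_dims.size(); ++i) {
    if (tensor_shape[tensor_dims[i]] % mesh_shape[mesh_dims[i]] != 0) return false;
  }
  return true;
}

int main() {
  auto dims = ParseMeshDimsSketch("SR = SS x SR @ {1,0}");  // -> {1, 0}
  std::vector<int64_t> shape = {128, 768};
  std::vector<int64_t> mesh = {4, 2};
  bool ok = IsDivisibleSketch(
      shape, mesh, {0, 1},
      {static_cast<int>(dims.first), static_cast<int>(dims.second)});
  std::printf("mesh dims = {%d,%d}, divisible = %s\n",
              static_cast<int>(dims.first), static_cast<int>(dims.second),
              ok ? "yes" : "no");
  return 0;
}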
@@ -286,6 +201,128 @@ inline bool IsCustomCallMarker(const HloInstruction* inst) { return inst->IsCustomCall({kPipelineMarker, kIdentityMarker}); } +// Pass through the custom call marker and get the source instruction +inline const HloInstruction* PassThroughCustomCallMarkerGetSource( + const HloInstruction* ins) { + while (ins->opcode() == HloOpcode::kGetTupleElement && + IsCustomCallMarker(ins->operand(0))) { + const HloInstruction* custom_call = ins->operand(0); + const HloInstruction* tuple = custom_call->operand(0); + while (IsCustomCallMarker(tuple)) { + tuple = tuple->operand(0); + } + ins = tuple->operand(ins->tuple_index()); + } + return ins; +} + +// Pass through the custom call marker and get the acutal operand. +inline HloInstruction* PassThroughCustomCallMarkerOperand( + HloInstruction* raw_operand, const HloInstruction* inst) { + if (!IsCustomCallMarker(raw_operand)) { + return raw_operand; + } + + CHECK_EQ(inst->opcode(), HloOpcode::kGetTupleElement); + + int index = inst->tuple_index(); + return raw_operand->mutable_operand(0)->mutable_operand(index); +} + +// Return whether the tuple is only used by a custom call marker. +inline bool IsCustomCallMarkerTuple(const HloInstruction* inst) { + return inst->opcode() == HloOpcode::kTuple && inst->users().size() == 1 && + IsCustomCallMarker(inst->users().front()); +} + +// Pass through the custom call marker and get the actual user. +inline HloInstruction* PassThroughCustomCallMarkerUser( + HloInstruction* raw_user, const HloInstruction* inst) { + if (!IsCustomCallMarkerTuple(raw_user)) { + return raw_user; + } + + const HloInstruction* custom_call = raw_user->users().front(); + + int index = -1; + for (int i = 0; i < raw_user->operand_count(); i++) { + if (raw_user->operand(i) == inst) { + index = i; + break; + } + } + CHECK_NE(index, -1); + + HloInstruction* ret = nullptr; + for (HloInstruction* user : custom_call->users()) { + CHECK_EQ(user->opcode(), HloOpcode::kGetTupleElement); + if (user->tuple_index() == index) { + CHECK_EQ(ret, nullptr); + ret = user; + } + } + + return ret == nullptr ? raw_user : ret; +} + +// Return the users of an instruction and its alias, +// excluding the final output tuple. +inline StableHashSet UsersWithAlias( + const HloInstruction* inst, const AliasMap& alias_map, + const HloInstruction* output) { + StableHashSet users; + + for (HloInstruction* user : inst->users()) { + users.insert(PassThroughCustomCallMarkerUser(user, inst)); + } + + auto iter = alias_map.find(inst); + if (iter != alias_map.end()) { + for (HloInstruction* user : iter->second->users()) { + users.insert(PassThroughCustomCallMarkerUser(user, iter->second)); + } + } + + users.erase(output); + return users; +} + +// Return whether this instruction is a convert on a parameter. +bool IsParameterConvert(const HloInstruction* inst); + +// Return whether the instruction is always replicated. +// (e.g., constant, broadcasted constant, scalar) +bool IsAlwaysReplicated(const HloInstruction* inst); + +// Try to reduce the boundary set to its common ancestor +void TryReduceWithCommonAncestor(StableHashSet& replicated_set, + StableHashSet& boundary_set, + StableHashSet& consumer_set, + const AliasMap& alias_map); + +// Return whether all users of an instruction is reduce. 
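The pass-through helpers above treat a custom-call marker as transparent: a get-tuple-element that reads from the marker really consumes the matching element of the tuple the marker wraps. A toy rendition of that index bookkeeping with an invented Inst struct (not the real HloInstruction API):

#include <cassert>
#include <cstdio>
#include <string>
#include <vector>

struct Inst {
  std::string opcode;
  std::vector<Inst*> operands;
  int tuple_index = -1;
};

bool IsMarker(const Inst* i) { return i->opcode == "custom-call"; }

// Mirrors the idea of PassThroughCustomCallMarkerOperand: unwrap
// marker(tuple(...)) and return the element selected by the get-tuple-element.
Inst* PassThroughOperand(Inst* raw_operand, const Inst* user) {
  if (!IsMarker(raw_operand)) return raw_operand;
  assert(user->opcode == "get-tuple-element");
  return raw_operand->operands[0]->operands[user->tuple_index];
}

int main() {
  Inst a{"parameter"}, b{"parameter"};
  Inst tuple{"tuple", {&a, &b}};
  Inst marker{"custom-call", {&tuple}};
  Inst gte{"get-tuple-element", {&marker}, 1};
  Inst* real = PassThroughOperand(gte.operands[0], &gte);
  std::printf("gte really reads %s (is &b: %s)\n", real->opcode.c_str(),
              real == &b ? "yes" : "no");
  return 0;
}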
+bool AllUsersAreReduce(const HloInstruction* inst); + +void UseAllReduceForGradAcc(StableHashSet& replicated_set, + const HloInstruction* inst); + +void SetSharding(HloInstruction* to_split, const HloSharding& output_spec, + const HloInstruction* ref_inst, + const HloInstruction* shape_inst, + StableHashSet& modified); + +template +inline std::vector Argsort(const std::vector& scores) { + std::vector index; + index.reserve(scores.size()); + for (size_t i = 0; i < scores.size(); ++i) { + index.push_back(i); + } + auto cmp = [&scores](int l, int r) { return scores[l] > scores[r]; }; + std::sort(index.begin(), index.end(), cmp); + return index; +} + // Return whether the reshape is a special reshape that switches the batch dim // of a dot. bool IsBatchDimSwitchReshape(const HloInstruction* inst); @@ -475,6 +512,11 @@ HloSharding Tile(const Shape& tensor_shape, absl::Span mesh_dims, const Array& device_mesh); +AliasMap BuildAliasMap(const HloModule* module); + +AliasSet BuildAliasSet(const HloModule* module, + const StrategyMap& strategy_map); + // Transpose an array of any number of dimensions given any axes order. // Similar to numpy.transpose(array, axes=()) function. template @@ -536,6 +578,9 @@ std::vector> DecomposeMeshShapes( std::vector mesh_shape); bool OutputInputSameShapes(const HloInstruction* ins); + +bool IsEntryComputationInputOrOutput(const HloModule* module, + const HloInstruction* ins); } // namespace spmd } // namespace xla diff --git a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/cluster_environment.cc b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/cluster_environment.cc new file mode 100644 index 00000000000..9bc1eca4ed4 --- /dev/null +++ b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/cluster_environment.cc @@ -0,0 +1,248 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/hlo/experimental/auto_sharding/cluster_environment.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace xla { +namespace spmd { + +double ClusterEnvironment::AllGatherCost(double num_bytes, int mesh_dim) const { + if (solver_option_.override_all_gather_cost) { + return solver_option_.all_gather_cost; + } + + if (prof_result_.Enabled()) { + return prof_result_.EstimateAllGatherCost(cached_replica_groups_[mesh_dim], + num_bytes / 4, "float32"); + } + + if (solver_option_.force_batch_dim_to_mesh_dim == mesh_dim) { + // if data-parallel is forced on this dim, we only allow all-reduce + // in this dimension. + return kInfinityCost; + } + + int64_t num_devices = device_mesh_.dim(mesh_dim); + return (round(mesh_alpha_[mesh_dim] + mesh_beta_[mesh_dim] * + (num_devices - 1) / num_devices * + num_bytes) + + 0.1); +} + +// TODO(zhuohan): distinguish dtype and reduce_op. 
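AllGatherCost above prices the collective with an alpha-beta model: a fixed latency term plus a per-byte term scaled by (num_devices - 1) / num_devices. A minimal sketch that simply evaluates that expression with made-up mesh parameters:

#include <cmath>
#include <cstdio>

int main() {
  const double alpha = 1.0;         // hypothetical latency term for this mesh dim
  const double beta = 1.0e-7;       // hypothetical per-byte term
  const int num_devices = 8;        // devices along the chosen mesh dimension
  const double num_bytes = 64.0e6;  // hypothetical 64 MB tensor
  // Same expression as AllGatherCost above, including the trailing 0.1.
  double cost =
      std::round(alpha + beta * (num_devices - 1) / num_devices * num_bytes) + 0.1;
  std::printf("all-gather cost ~= %.1f\n", cost);  // prints 7.1 for these numbers
  return 0;
}

The small trailing constant appears to act only as a tie-breaker between collective types; the ranking between strategies is dominated by the alpha and beta terms.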
+double ClusterEnvironment::AllReduceCost(double num_bytes, int32_t mesh_dim, + int32_t mesh_dim_another) const { + if (solver_option_.override_all_reduce_cost) { + return solver_option_.all_reduce_cost; + } + + if (prof_result_.Enabled()) { + return prof_result_.EstimateAllReduceCost(cached_replica_groups_[mesh_dim], + num_bytes / 4, "float32"); + } + double alpha, beta; + int64_t num_devices; + if (mesh_dim_another == -1) { + // Only communicating on one mesh dimension. + alpha = mesh_alpha_[mesh_dim]; + beta = mesh_beta_[mesh_dim]; + num_devices = device_mesh_.dim(mesh_dim); + } else { + // Communicating through both mesh dimensions. + alpha = std::max(mesh_alpha_[mesh_dim], mesh_alpha_[mesh_dim_another]); + beta = std::max(mesh_beta_[mesh_dim], mesh_beta_[mesh_dim_another]); + num_devices = device_mesh_.num_elements(); + } + return ( + round(alpha + beta * 2 * (num_devices - 1) / num_devices * num_bytes) + + 0.01); +} + +double ClusterEnvironment::ReduceScatterCost(double num_bytes, + int mesh_dim) const { + if (solver_option_.override_reduce_scatter_cost) { + return solver_option_.reduce_scatter_cost; + } + + if (prof_result_.Enabled()) { + return prof_result_.EstimateReduceScatterCost( + cached_replica_groups_[mesh_dim], num_bytes / 4, "float32"); + } + + int64_t num_devices = device_mesh_.dim(mesh_dim); + return (round(mesh_alpha_[mesh_dim] + mesh_beta_[mesh_dim] * + (num_devices - 1) / num_devices * + num_bytes) + + 0.001); +} + +double ClusterEnvironment::AllToAllCost(double num_bytes, int mesh_dim) const { + if (solver_option_.override_all_to_all_cost) { + return solver_option_.all_to_all_cost; + } + + if (prof_result_.Enabled()) { + return prof_result_.EstimateAllToAllCost(cached_replica_groups_[mesh_dim], + num_bytes / 4, "float32"); + } + + if (solver_option_.force_batch_dim_to_mesh_dim == mesh_dim) { + // if data-parallel is forced on this dim, we only allow all-reduce + // in this dimension. + return kInfinityCost; + } + + int64_t num_devices = device_mesh_.dim(mesh_dim); + return AllToAllCostUtil(num_bytes, mesh_dim, num_devices, mesh_alpha_, + mesh_beta_); +} + +double ClusterEnvironment::DotCost(const Shape& lhs_shape, + const Shape& rhs_shape, + const DotDimensionNumbers& dot_dnums) const { + if (!solver_option_.allow_recompute_heavy_op) { + return kInfinityCost; + } + + // TODO(zhuohan): When profiling data is not available, it is not easy to + // align the scale of compute cost and communication cost. Here we just use + // a simple heuristic to compute the compute cost with communication cost. + double num_bytes = GetBytes(lhs_shape) + GetBytes(rhs_shape); + return AllReduceCost(num_bytes, 0) + AllReduceCost(num_bytes, 1); +} + +// The communication cost of resharding a tensor from src to dst +// TODO(b/238210866) Do not use kInfinityCost. +double ClusterEnvironment::ReshardingCost(const Shape& shape, + const HloSharding& src_spec, + const HloSharding& dst_spec) const { + // TODO(zhuohan): This function can be wrong and needs more tests. + if (src_spec == dst_spec || IsUndefined(src_spec)) { + return 0.0; + } + CHECK(!IsUndefined(dst_spec)); + int64_t src_n_dim = NumTileDimensions(src_spec); + int64_t dst_n_dim = NumTileDimensions(dst_spec); + // When src_spec and dst_spec are for arrays with different number of + // dimensions, which could happen when an instruction follows the sharding + // of an operand with a different shape, we need to use their + // TiledDataRank(). 
+ size_t src_rank = shape.rank(); + if (src_spec.IsTiled()) { + src_rank = src_spec.TiledDataRank(); + } + size_t dst_rank = shape.rank(); + if (dst_spec.IsTiled()) { + dst_rank = dst_spec.TiledDataRank(); + } + std::vector src_tensor_dim_to_mesh_dim; + if (VectorGreaterThanOneElementCount( + src_spec.tile_assignment().dimensions()) == 1 && + VectorGreaterThanOneElementCount(device_mesh_.dimensions()) > 1) { + // src spec is 1D and device_mesh is 2D or 3D + src_tensor_dim_to_mesh_dim = + GetTensorDimToMeshDim(src_rank, src_spec, device_mesh_1d_); + } else { + src_tensor_dim_to_mesh_dim = + GetTensorDimToMeshDim(src_rank, src_spec, device_mesh_); + } + std::vector dst_tensor_dim_to_mesh_dim; + if (VectorGreaterThanOneElementCount( + dst_spec.tile_assignment().dimensions()) == 1 && + VectorGreaterThanOneElementCount(device_mesh_.dimensions()) > 1) { + // src spec is 1D and device_mesh is 2D or 3D + dst_tensor_dim_to_mesh_dim = + GetTensorDimToMeshDim(dst_rank, dst_spec, device_mesh_1d_); + } else { + dst_tensor_dim_to_mesh_dim = + GetTensorDimToMeshDim(dst_rank, dst_spec, device_mesh_); + } + if (src_n_dim != dst_n_dim && src_n_dim != -1 && dst_n_dim != -1) { + return ReshardingCostMixedMeshShape( + shape, src_tensor_dim_to_mesh_dim, dst_tensor_dim_to_mesh_dim, + device_mesh_.num_elements(), mesh_alpha_, mesh_beta_); + } + + AdjustTensorMeshDimMapping(src_tensor_dim_to_mesh_dim, src_n_dim); + AdjustTensorMeshDimMapping(dst_tensor_dim_to_mesh_dim, dst_n_dim); + + // Analyze the dims that need to dynamic-sliced or all-gather. + std::vector slice_dims; + std::vector all_gather_dims; + for (int64_t i = 0; i < std::min(src_rank, dst_rank); ++i) { + int src_mesh_dim = src_tensor_dim_to_mesh_dim[i]; + int dst_mesh_dim = dst_tensor_dim_to_mesh_dim[i]; + if (src_mesh_dim == dst_mesh_dim) { + continue; + } + if (src_mesh_dim == -1) { + slice_dims.push_back(src_mesh_dim); + continue; + } + if (dst_mesh_dim == -1) { + all_gather_dims.push_back(src_mesh_dim); + continue; + } + // Do not allow other re-sharding patterns. (e.g., collective-permute) + return kInfinityCost; + } + + // Case 1: no communication is required. Only needs dynamic-slice. + if (all_gather_dims.empty()) { + return 0; + } + + // Do not allow some strange re-sharding patterns. + if (slice_dims.size() > 1 && all_gather_dims.size() > 1) { + return kInfinityCost; + } + + // Case 2: all-to-all + if (slice_dims.size() == 1 && all_gather_dims.size() == 1) { + if (device_mesh_.dim(0) > 1 && device_mesh_.dim(1) > 1) { + return kInfinityCost; + } + + double bytes = GetBytes(shape); + return AllToAllCost(bytes, all_gather_dims.front()); + } + + // Case 3: all-gather + double bytes = GetBytes(shape) / src_spec.NumTiles(); + double cost = 0.0; + for (int dim : all_gather_dims) { + if (dim >= device_mesh_.num_dimensions()) { + return kInfinityCost; + } + bytes *= device_mesh_.dim(dim); + cost += AllGatherCost(bytes, dim); + } + return cost; +} +} // namespace spmd +} // namespace xla diff --git a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/cluster_environment.h b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/cluster_environment.h new file mode 100644 index 00000000000..b6c217f35d5 --- /dev/null +++ b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/cluster_environment.h @@ -0,0 +1,170 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_CLUSTER_ENVIRONMENT_H_ +#define TENSORFLOW_COMPILER_XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_CLUSTER_ENVIRONMENT_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_solver_option.h" +#include "tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_util.h" +#include "tensorflow/compiler/xla/hlo/experimental/auto_sharding/profiling_result.h" + +namespace xla { +namespace spmd { + +// The cluster has a multi-dimensional device mesh topology. +// Each mesh dimension has its own latency and bandwidth. +// We use alpha-beta model to model the communication cost. +// If profiling result is provided, we always prefer to use +// the real profiling result. +class ClusterEnvironment { + public: + ClusterEnvironment(const Array& original_device_mesh, + const Array& device_mesh, + absl::Span mesh_alpha, + absl::Span mesh_beta, + const ProfilingResult& prof_result, + const AutoShardingSolverOption& solver_option) + : original_device_mesh_(original_device_mesh), + device_mesh_(device_mesh), + mesh_alpha_(mesh_alpha.begin(), mesh_alpha.end()), + mesh_beta_(mesh_beta.begin(), mesh_beta.end()), + prof_result_(prof_result), + total_devices_(device_mesh.num_elements()), + device_mesh_1d_(original_device_mesh), + solver_option_(solver_option) { + // Build replica group for each dimension. + non_zero_mesh_dims_ = + VectorGreaterThanOneElementIndices(device_mesh.dimensions()); + GenerateCachedReplicaGroups(); + // TODO(yuemmawang) Find the largest dimension in original_device_mesh and + // create 1d mesh on that dimension. + device_mesh_1d_.Reshape({original_device_mesh.num_elements(), 1}); + } + + size_t NumDevices() const { return total_devices_; } + + bool IsDeviceMesh3D() const { + return VectorGreaterThanOneElementCount(device_mesh_.dimensions()) == 3; + } + + bool IsDeviceMesh2D() const { + return VectorGreaterThanOneElementCount(device_mesh_.dimensions()) == 2; + } + + bool IsDeviceMesh1D() const { + return VectorGreaterThanOneElementCount(device_mesh_.dimensions()) == 1; + } + + bool IsOriginalDeviceMesh2D() const { + return VectorGreaterThanOneElementCount( + original_device_mesh_.dimensions()) == 2; + } + + // Get the corresponding mesh dimension for every tensor dimension. 
+ // -1 means replicated on that dimension + std::vector GetTensorDimToMeshDimWrapper( + const Shape& shape, const HloSharding& spec) const { + int64_t n_dim = NumTileDimensions(spec); + std::vector tensor_dim_to_mesh_dim = + GetTensorDimToMeshDim(shape.rank(), spec, device_mesh_); + AdjustTensorMeshDimMapping(tensor_dim_to_mesh_dim, n_dim); + return tensor_dim_to_mesh_dim; + } + + double AllGatherCost(double num_bytes, int mesh_dim) const; + + double AllReduceCost(double num_bytes, int32_t mesh_dim, + int32_t mesh_dim_another = -1) const; + + double ReduceScatterCost(double num_bytes, int mesh_dim) const; + + double AllToAllCost(double num_bytes, int mesh_dim) const; + + double DotCost(const Shape& lhs_shape, const Shape& rhs_shape, + const DotDimensionNumbers& dot_dnums) const; + + double ReshardingCost(const Shape& shape, const HloSharding& src_spec, + const HloSharding& dst_spec) const; + + // Print the information of this device mesh. + std::string ToString() { + std::string str; + absl::StrAppend(&str, "device_mesh: ", device_mesh_.ToString(), "\n"); + absl::StrAppend(&str, "mesh_alpha: ", absl::StrJoin(mesh_alpha_, " "), + "\n"); + absl::StrAppend(&str, "mesh_beta: ", absl::StrJoin(mesh_beta_, " "), "\n"); + return str; + } + + // The original, complete device mesh shape that describes the hardware. + const Array original_device_mesh_; + // When solve_nd_sharding_iteratively is true, it is a partial mesh shape from + // the original_device_mesh_. When solve_nd_sharding_iteratively is false, it + // is the same as original_device_mesh_. + const Array device_mesh_; + // Bandwidth of the device mesh + const std::vector mesh_alpha_; + const std::vector mesh_beta_; + const ProfilingResult& prof_result_; + std::vector non_zero_mesh_dims_; + const int total_devices_; + + // Cache a flatten 1d version of the device mesh. + // Used for mixed mesh shape strategies. + Array device_mesh_1d_; + + // The solver option may override the cost of communication primitives + const AutoShardingSolverOption& solver_option_; + + // Cached replica groups. Shape: [mesh_dim, group_id, ids in this group]. + std::vector>> cached_replica_groups_; + + private: + void GenerateCachedReplicaGroups() { + // One vector per device_mesh_ dimension. + cached_replica_groups_.reserve(device_mesh_.num_dimensions()); + for (size_t i = 0; i < device_mesh_.num_dimensions(); i++) { + cached_replica_groups_.push_back( + GetReplicaGroupsAlongOneDimension(device_mesh_, i)); + } + } + + void AdjustTensorMeshDimMapping(std::vector& mapping, + int64_t n_dim) const { + // Shift the non-zero dim for 1d mesh + if (n_dim == 1 && non_zero_mesh_dims_.size() == 1) { + for (size_t i = 0; i < mapping.size(); ++i) { + if (mapping[i] == 0) { + mapping[i] = non_zero_mesh_dims_.front(); + } + } + } + } +}; +} // namespace spmd +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_CLUSTER_ENVIRONMENT_H_ diff --git a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/matrix.h b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/matrix.h new file mode 100644 index 00000000000..f18bc36efea --- /dev/null +++ b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/matrix.h @@ -0,0 +1,116 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_MATRIX_H_ +#define TENSORFLOW_COMPILER_XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_MATRIX_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "tensorflow/tsl/platform/logging.h" + +namespace xla { +namespace spmd { +// A simple matrix class to store and manipulate the cost matrices on edges. +// It can create a view for matrix transpose without copying the memory. +// TODO (zhuohan): Inherit from Array2D and add Transpose and operator+ (See +// tensorflow/compiler/xla/array2d.h;l=39) +class Matrix { + public: + Matrix() : n_(0), m_(0), transpose_(false), data_(nullptr) {} + + Matrix(size_t n, size_t m) { + this->n_ = n; + this->m_ = m; + transpose_ = false; + data_ = std::make_shared>(n * m, 0.0); + } + + Matrix(size_t n, size_t m, bool transpose, + std::shared_ptr> data) { + this->n_ = n; + this->m_ = m; + this->transpose_ = transpose; + this->data_ = data; + } + + Matrix Transpose() { return Matrix(m_, n_, !transpose_, data_); } + + double operator()(size_t i, size_t j) const { + size_t idx; + if (transpose_) { + idx = j * n_ + i; + } else { + idx = i * m_ + j; + } + CHECK(data_ != nullptr) << n_ << " , " << m_; + CHECK(idx < n_ * m_) << idx << " , " << n_ << " , " << m_; + return (*data_)[idx]; + } + + double& operator()(size_t i, size_t j) { + size_t idx; + if (transpose_) { + idx = j * n_ + i; + } else { + idx = i * m_ + j; + } + CHECK(data_ != nullptr) << n_ << " , " << m_; + CHECK(idx < n_ * m_) << idx << " , " << n_ << " , " << m_; + return (*data_)[idx]; + } + + Matrix operator+(const Matrix& other) { + CHECK_EQ(n_, other.n_); + CHECK_EQ(m_, other.m_); + Matrix ret = Matrix(n_, m_); + for (size_t i = 0; i < n_; ++i) { + for (size_t j = 0; j < m_; ++j) { + ret(i, j) = operator()(i, j) + other(i, j); + } + } + return ret; + } + + std::string ToString() const { + std::string str; + + for (size_t i = 0; i < n_; ++i) { + for (size_t j = 0; j < m_; ++j) { + absl::StrAppend(&str, operator()(i, j), " "); + } + absl::StrAppend(&str, "\n"); + } + + return str; + } + + size_t n_; + size_t m_; + bool transpose_; + std::shared_ptr> data_; +}; +} // namespace spmd +} // namespace xla +#endif // TENSORFLOW_COMPILER_XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_MATRIX_H_ diff --git a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/metrics.cc b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/metrics.cc new file mode 100644 index 00000000000..4431c9ac81b --- /dev/null +++ b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/metrics.cc @@ -0,0 +1,48 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/hlo/experimental/auto_sharding/metrics.h"
+
+#include 
+#include 
+
+#include "tensorflow/tsl/lib/monitoring/counter.h"
+
+namespace xla {
+namespace metrics {
+namespace {
+
+auto* xla_auto_sharding_invocations = tsl::monitoring::Counter<0>::New(
+    "/tensorflow/compiler/xla/hlo/xla_auto_sharding_invocations",
+    "The number of XLA auto sharding invocations used to collect "
+    "/tensorflow/compiler/xla/hlo/xla_compilation_time_in_auto_sharding_usecs");
+
+auto* auto_sharding_compilation_time_usecs = tsl::monitoring::Counter<0>::New(
+    "/tensorflow/compiler/xla/hlo/xla_compilation_time_in_auto_sharding_usecs",
+    "The total time spent on compiling XLA graphs in the auto sharding pass, "
+    "in microseconds.");
+
+}  // namespace
+
+void RecordAutoShardingInvocations() {
+  xla_auto_sharding_invocations->GetCell()->IncrementBy(1);
+}
+
+void RecordAutoShardingCompilationTime(const uint64_t time_usecs) {
+  auto_sharding_compilation_time_usecs->GetCell()->IncrementBy(time_usecs);
+}
+
+}  // namespace metrics
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/metrics.h b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/metrics.h
new file mode 100644
index 00000000000..4555f6ff6af
--- /dev/null
+++ b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/metrics.h
@@ -0,0 +1,31 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_METRICS_H_
+#define TENSORFLOW_COMPILER_XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_METRICS_H_
+
+#include 
+
+namespace xla {
+namespace metrics {
+
+void RecordAutoShardingInvocations();
+
+void RecordAutoShardingCompilationTime(uint64_t time_usecs);
+
+}  // namespace metrics
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_METRICS_H_
diff --git a/tensorflow/compiler/xla/hlo/experimental/auto_sharding/profiling_result.h b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/profiling_result.h
new file mode 100644
index 00000000000..92a3a417649
--- /dev/null
+++ b/tensorflow/compiler/xla/hlo/experimental/auto_sharding/profiling_result.h
@@ -0,0 +1,159 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_PROFILING_RESULT_H_ +#define TENSORFLOW_COMPILER_XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_PROFILING_RESULT_H_ + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" + +namespace xla { +namespace spmd { + +// Store the profiling results of communication and computation. +class ProfilingResult { + public: + // TODO (zhuohan): loading the profiling result. + ProfilingResult() { + if (all_reduce_cost_dict_.empty()) { + enabled_ = false; + } else { + enabled_ = true; + } + } + + bool Enabled() const { return enabled_; } + + double EstimateAllGatherCost( + const std::vector>& replica_groups, int64_t size, + const std::string& dtype) const { + if (all_gather_cost_dict_.empty()) { + // Use all-reduce to approximate all-gather. + return EstimateAllReduceCost(replica_groups, size, dtype) / 2; + } + + return EstimateInternal(replica_groups, size, dtype, + all_gather_cost_dict_) - + EstimateInternal(replica_groups, 0, dtype, all_gather_cost_dict_); + } + + double EstimateAllReduceCost( + const std::vector>& replica_groups, int64_t size, + const std::string& dtype) const { + return EstimateInternal(replica_groups, size, dtype, + all_reduce_cost_dict_) - + EstimateInternal(replica_groups, 0, dtype, all_reduce_cost_dict_); + } + + double EstimateReduceScatterCost( + const std::vector>& replica_groups, int64_t size, + const std::string& dtype) const { + if (reduce_scatter_cost_dict_.empty()) { + // Use all-reduce to approximate reduce-scatter. + return EstimateAllReduceCost(replica_groups, size, dtype) / 2; + } + + return EstimateInternal(replica_groups, size, dtype, + reduce_scatter_cost_dict_) - + EstimateInternal(replica_groups, 0, dtype, + reduce_scatter_cost_dict_); + } + + double EstimateAllToAllCost( + const std::vector>& replica_groups, int64_t size, + const std::string& dtype) const { + // A penalty factor to make the theoretical cost match the + // empirical cost on v100 + nvlink. + int64_t num_devices = replica_groups.front().size(); + double penalty_factor = static_cast(num_devices) / 2.0; + // Use all-gather to approximate all-to-all. + return EstimateAllGatherCost(replica_groups, size / num_devices, dtype) * + penalty_factor; + } + + std::string ToString() { + std::string str; + for (const auto& item : all_reduce_cost_dict_) { + absl::StrAppend(&str, item.first.first, " ", item.first.second, "\n"); + } + return str; + } + + private: + // pair + using Key = std::pair; + // vector> + using Value = std::vector>; + + // Estimate the cost by linear interpolation between the two closest points. 
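Every estimator above bottoms out in EstimateInternal (below), which interpolates linearly between the two profiled points that bracket the requested size. A standalone sketch of that interpolation with invented (bytes, milliseconds) measurements:

#include <cstdio>
#include <utility>
#include <vector>

// Interpolate `size` against sorted (size, time) samples; outside the sampled
// range the nearest segment is extrapolated, as in EstimateInternal.
double Interpolate(const std::vector<std::pair<long long, double>>& pts,
                   long long size) {
  size_t i = 0;
  if (size > pts.back().first) {
    i = pts.size() - 2;
  } else if (size >= pts.front().first) {
    while (!(pts[i].first <= size && size <= pts[i + 1].first)) ++i;
  }
  double x0 = pts[i].first, y0 = pts[i].second;
  double x1 = pts[i + 1].first, y1 = pts[i + 1].second;
  return (size - x0) / (x1 - x0) * (y1 - y0) + y0;
}

int main() {
  std::vector<std::pair<long long, double>> profiled = {
      {1 << 20, 0.8}, {1 << 22, 2.9}, {1 << 24, 11.5}};  // invented measurements
  std::printf("estimated time = %.2f ms\n",
              Interpolate(profiled, 3 << 21));  // ~4.33 for these samples
  return 0;
}

The callers above additionally subtract the size-zero estimate, so that only the size-dependent part of the measured time is charged to the collective.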
+ double EstimateInternal( + const std::vector>& replica_groups, int64_t size, + const std::string& dtype, + const StableHashMap& cost_dict) const { + Key key(Group2Str(replica_groups), dtype); + Value cost_list = cost_dict.at(key); + + CHECK(!cost_list.empty()); + + size_t i; + if (size > cost_list.back().first) { + i = cost_list.size() - 2; + } else if (size < cost_list.front().first) { + i = 0; + } else { + for (i = 0; i < cost_list.size() - 1; ++i) { + if (cost_list[i].first <= size && size <= cost_list[i + 1].first) { + break; + } + } + } + + int64_t left_size = cost_list[i].first; + double left_cost = cost_list[i].second; + int64_t right_size = cost_list[i + 1].first; + double right_cost = cost_list[i + 1].second; + + return 1.0 * (size - left_size) / (right_size - left_size) * + (right_cost - left_cost) + + left_cost; + } + + // Make a string key of a replica_groups. + std::string Group2Str( + const std::vector>& replica_groups) const { + std::string str; + absl::StrAppend(&str, "("); + for (const auto& group : replica_groups) { + absl::StrAppend(&str, "(", absl::StrJoin(group, ","), ")"); + } + absl::StrAppend(&str, ")"); + + return str; + } + + bool enabled_; + StableHashMap all_reduce_cost_dict_; + StableHashMap all_gather_cost_dict_; + StableHashMap reduce_scatter_cost_dict_; +}; + +} // namespace spmd +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_PROFILING_RESULT_H_ diff --git a/tensorflow/compiler/xla/hlo/ir/BUILD b/tensorflow/compiler/xla/hlo/ir/BUILD index b4a9a21ac37..b70aec6060b 100644 --- a/tensorflow/compiler/xla/hlo/ir/BUILD +++ b/tensorflow/compiler/xla/hlo/ir/BUILD @@ -4,7 +4,8 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") package( - default_visibility = [":friends"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//visibility:public"], licenses = ["notice"], ) @@ -56,6 +57,7 @@ cc_library( "//tensorflow/compiler/xla:comparison_util", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:printer", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", @@ -66,6 +68,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla:xla_proto_cc", "//tensorflow/compiler/xla/service:compilation_environments", "//tensorflow/compiler/xla/service:computation_placer_hdr", "//tensorflow/compiler/xla/service:hlo_module_config", @@ -95,3 +98,15 @@ cc_library( "@com_google_absl//absl/types:span", ], ) + +cc_library( + name = "hlo_module_group", + srcs = ["hlo_module_group.cc"], + hdrs = ["hlo_module_group.h"], + deps = [ + "//tensorflow/compiler/xla/hlo/ir:hlo", + "//tensorflow/compiler/xla/service:hlo_proto_cc", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) diff --git a/tensorflow/compiler/xla/hlo/ir/dfs_hlo_visitor.h b/tensorflow/compiler/xla/hlo/ir/dfs_hlo_visitor.h index 6fde433ed95..fd68cc08df9 100644 --- a/tensorflow/compiler/xla/hlo/ir/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/hlo/ir/dfs_hlo_visitor.h @@ -63,8 +63,8 @@ class DfsHloVisitorBase { "HloInstruction*"); public: - DfsHloVisitorBase() {} - virtual ~DfsHloVisitorBase() {} + DfsHloVisitorBase() = default; + virtual ~DfsHloVisitorBase() = default; // These routines are self-descriptive, see class comment for usage // information. 
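The reordering above keeps the visitor's dispatch pattern intact: specific elementwise ops such as HandleTan default to HandleElementwiseUnary unless a subclass overrides them, so a visitor only has to implement the generic hook. A miniature version of that pattern; OpVisitor and CountingVisitor are invented names, not XLA classes:

#include <cstdio>

class OpVisitor {
 public:
  virtual ~OpVisitor() = default;
  virtual void HandleElementwiseUnary(const char* op) {}
  // Specific unary ops fall back to the generic handler unless overridden.
  virtual void HandleTan(const char* op) { HandleElementwiseUnary(op); }
  virtual void HandleTanh(const char* op) { HandleElementwiseUnary(op); }
};

class CountingVisitor : public OpVisitor {
 public:
  void HandleElementwiseUnary(const char* op) override {
    ++unary_count_;
    std::printf("visited %s\n", op);
  }
  int unary_count_ = 0;
};

int main() {
  CountingVisitor v;
  v.HandleTan("tan");   // routed through HandleElementwiseUnary
  v.HandleTanh("tanh");
  std::printf("unary ops handled: %d\n", v.unary_count_);  // 2
  return 0;
}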
@@ -112,26 +112,33 @@ class DfsHloVisitorBase { virtual Status HandleCbrt(HloInstructionPtr hlo) { return HandleElementwiseUnary(hlo); } - virtual Status HandleConvolution(HloInstructionPtr hlo) = 0; - virtual Status HandleFft(HloInstructionPtr fft) = 0; - virtual Status HandleTriangularSolve(HloInstructionPtr hlo) = 0; - virtual Status HandleCholesky(HloInstructionPtr hlo) = 0; - virtual Status HandleOptimizationBarrier(HloInstructionPtr hlo) = 0; + /* go/keep-sorted start */ virtual Status HandleAllGather(HloInstructionPtr hlo) = 0; - virtual Status HandleAllGatherStart(HloInstructionPtr hlo) = 0; virtual Status HandleAllGatherDone(HloInstructionPtr hlo) = 0; + virtual Status HandleAllGatherStart(HloInstructionPtr hlo) = 0; virtual Status HandleAllReduce(HloInstructionPtr hlo) = 0; - virtual Status HandleReduceScatter(HloInstructionPtr hlo) = 0; - virtual Status HandleAllReduceStart(HloInstructionPtr hlo) = 0; virtual Status HandleAllReduceDone(HloInstructionPtr hlo) = 0; + virtual Status HandleAllReduceStart(HloInstructionPtr hlo) = 0; virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0; virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0; - virtual Status HandleCollectivePermuteStart(HloInstructionPtr hlo) = 0; virtual Status HandleCollectivePermuteDone(HloInstructionPtr hlo) = 0; - virtual Status HandleReplicaId(HloInstructionPtr hlo) = 0; + virtual Status HandleCollectivePermuteStart(HloInstructionPtr hlo) = 0; + virtual Status HandleConvolution(HloInstructionPtr hlo) = 0; + virtual Status HandleOptimizationBarrier(HloInstructionPtr hlo) = 0; virtual Status HandlePartitionId(HloInstructionPtr hlo) = 0; + virtual Status HandleReduceScatter(HloInstructionPtr hlo) = 0; + virtual Status HandleReplicaId(HloInstructionPtr hlo) = 0; + /* go/keep-sorted end */ + + /* go/keep-sorted start */ + virtual Status HandleCholesky(HloInstructionPtr hlo) = 0; + virtual Status HandleFft(HloInstructionPtr fft) = 0; + virtual Status HandleTriangularSolve(HloInstructionPtr hlo) = 0; + /* go/keep-sorted end */ + virtual Status HandleGetDimensionSize(HloInstructionPtr hlo) = 0; virtual Status HandleSetDimensionSize(HloInstructionPtr hlo) = 0; + virtual Status HandleCompare(HloInstructionPtr hlo) { return HandleElementwiseBinary(hlo); } @@ -195,6 +202,9 @@ class DfsHloVisitorBase { virtual Status HandleSin(HloInstructionPtr hlo) { return HandleElementwiseUnary(hlo); } + virtual Status HandleTan(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } virtual Status HandleTanh(HloInstructionPtr hlo) { return HandleElementwiseUnary(hlo); } @@ -240,37 +250,42 @@ class DfsHloVisitorBase { return HandleElementwiseUnary(hlo); } + /* go/keep-sorted start */ virtual Status HandleInfeed(HloInstructionPtr hlo) = 0; virtual Status HandleOutfeed(HloInstructionPtr hlo) = 0; - virtual Status HandleRng(HloInstructionPtr hlo) = 0; - virtual Status HandleRngBitGenerator(HloInstructionPtr hlo) = 0; - virtual Status HandleRngGetAndUpdateState(HloInstructionPtr hlo) = 0; - virtual Status HandleReverse(HloInstructionPtr hlo) = 0; - virtual Status HandleSort(HloInstructionPtr hlo) = 0; - virtual Status HandleConstant(HloInstructionPtr hlo) = 0; - virtual Status HandleIota(HloInstructionPtr hlo) = 0; - virtual Status HandleGetTupleElement(HloInstructionPtr hlo) = 0; - virtual Status HandleReduce(HloInstructionPtr hlo) = 0; + /* go/keep-sorted end */ + + /* go/keep-sorted start */ virtual Status HandleBitcast(HloInstructionPtr hlo) = 0; virtual Status HandleBroadcast(HloInstructionPtr hlo) = 0; - 
virtual Status HandleReshape(HloInstructionPtr hlo) = 0; - virtual Status HandleDynamicReshape(HloInstructionPtr hlo) = 0; - virtual Status HandleTranspose(HloInstructionPtr hlo) = 0; - virtual Status HandleParameter(HloInstructionPtr hlo) = 0; - virtual Status HandleFusion(HloInstructionPtr hlo) = 0; virtual Status HandleCall(HloInstructionPtr hlo) = 0; + virtual Status HandleConditional(HloInstructionPtr hlo) = 0; + virtual Status HandleConstant(HloInstructionPtr hlo) = 0; virtual Status HandleCustomCall(HloInstructionPtr hlo) = 0; - virtual Status HandleSlice(HloInstructionPtr hlo) = 0; + virtual Status HandleDynamicReshape(HloInstructionPtr hlo) = 0; virtual Status HandleDynamicSlice(HloInstructionPtr hlo) = 0; virtual Status HandleDynamicUpdateSlice(HloInstructionPtr hlo) = 0; - virtual Status HandleTuple(HloInstructionPtr hlo) = 0; + virtual Status HandleFusion(HloInstructionPtr hlo) = 0; + virtual Status HandleGather(HloInstructionPtr hlo) = 0; + virtual Status HandleGetTupleElement(HloInstructionPtr hlo) = 0; + virtual Status HandleIota(HloInstructionPtr hlo) = 0; virtual Status HandleMap(HloInstructionPtr hlo) = 0; + virtual Status HandleParameter(HloInstructionPtr hlo) = 0; + virtual Status HandleReduce(HloInstructionPtr hlo) = 0; virtual Status HandleReduceWindow(HloInstructionPtr hlo) = 0; + virtual Status HandleReshape(HloInstructionPtr hlo) = 0; + virtual Status HandleReverse(HloInstructionPtr hlo) = 0; + virtual Status HandleRng(HloInstructionPtr hlo) = 0; + virtual Status HandleRngBitGenerator(HloInstructionPtr hlo) = 0; + virtual Status HandleRngGetAndUpdateState(HloInstructionPtr hlo) = 0; + virtual Status HandleScatter(HloInstructionPtr hlo) = 0; virtual Status HandleSelectAndScatter(HloInstructionPtr hlo) = 0; + virtual Status HandleSlice(HloInstructionPtr hlo) = 0; + virtual Status HandleSort(HloInstructionPtr hlo) = 0; + virtual Status HandleTranspose(HloInstructionPtr hlo) = 0; + virtual Status HandleTuple(HloInstructionPtr hlo) = 0; virtual Status HandleWhile(HloInstructionPtr hlo) = 0; - virtual Status HandleConditional(HloInstructionPtr hlo) = 0; - virtual Status HandleGather(HloInstructionPtr hlo) = 0; - virtual Status HandleScatter(HloInstructionPtr hlo) = 0; + /* go/keep-sorted end */ virtual Status HandlePad(HloInstructionPtr hlo) = 0; diff --git a/tensorflow/compiler/xla/hlo/ir/hlo_computation.cc b/tensorflow/compiler/xla/hlo/ir/hlo_computation.cc index 1d2d4911dc3..f696ec182b2 100644 --- a/tensorflow/compiler/xla/hlo/ir/hlo_computation.cc +++ b/tensorflow/compiler/xla/hlo/ir/hlo_computation.cc @@ -43,6 +43,7 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/printer.h" #include "tensorflow/compiler/xla/service/mapped_ptr_container_sorter.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -424,43 +425,25 @@ void ComputeComputationPostOrder(HloComputation* computation, } } -std::optional GetChannelId(const HloInstruction& inst) { - // Note that we only include Send and RecvDone, as we want to create a - // dependency between those, but not SendDone and Recv. 
- switch (inst.opcode()) { - case HloOpcode::kSend: - case HloOpcode::kRecvDone: - case HloOpcode::kAllReduce: - case HloOpcode::kAllGather: - case HloOpcode::kAllToAll: - case HloOpcode::kCollectivePermute: - case HloOpcode::kReduceScatter: - return inst.channel_id(); - default: - return std::nullopt; - } -} - } // namespace void HloComputation::ComputeInstructionPostOrder( - HloInstruction* root, - HloComputation::ChannelDependencyGroup& channel_dependencies, + HloInstruction* root, const ChannelDependencies& channel_dependencies, absl::flat_hash_map& visited, std::vector& post_order) const { std::vector dfs_stack = {root}; while (!dfs_stack.empty()) { HloInstruction& current = *dfs_stack.back(); - auto result = visited.insert({¤t, kVisiting}); - if (!result.second) { // We've already seen this instruction. + auto [it, was_inserted] = visited.insert({¤t, kVisiting}); + if (!was_inserted) { // We've already seen this instruction. dfs_stack.pop_back(); - if (result.first->second != kVisited) { + if (it->second != kVisited) { DCHECK_EQ(current.parent(), this) << "Instruction " << current.name() << " is not in the current computation (" << name() << ")."; post_order.push_back(¤t); - result.first->second = kVisited; + it->second = kVisited; } continue; } @@ -470,15 +453,10 @@ void HloComputation::ComputeInstructionPostOrder( // Collectives with the same channel ID must be performed together, as these // represent MPMD-partitioned that will later be split into separate modules // and the order must be preserved. - std::optional channel_id = - ((¤t != root) && (current.opcode() != HloOpcode::kSend)) - ? GetChannelId(current) - : std::nullopt; - if (channel_id) { - auto it = channel_dependencies.find(*channel_id); + if (¤t != root) { + auto it = channel_dependencies.find(¤t); if (it != channel_dependencies.end()) { dfs_stack.insert(dfs_stack.end(), it->second.begin(), it->second.end()); - channel_dependencies.erase(it); } } @@ -494,25 +472,68 @@ void HloComputation::ComputeInstructionPostOrder( } } -HloComputation::ChannelDependencyGroup -HloComputation::ComputeChannelDependencies() const { +HloComputation::ChannelDependencies HloComputation::ComputeChannelDependencies() + const { if (parent() && parent()->config().has_static_device_assignment() && (parent()->config().static_device_assignment().computation_count() == 1 || parent()->config().use_spmd_partitioning())) { return {}; } - ChannelDependencyGroup channel_dependencies; + using Instructions = absl::InlinedVector; + absl::flat_hash_map channel_groups; + + // Create dependencies RecvDone -> Send, and between partitioned collectives. 
+ ChannelDependencies dependencies; for (const auto& instruction : instructions_) { - std::optional channel_id = GetChannelId(*instruction); - if (channel_id) - channel_dependencies[*channel_id].push_back(instruction.get()); + switch (instruction->opcode()) { + case HloOpcode::kSend: { + Instructions& group = channel_groups[*instruction->channel_id()]; + if (group.empty()) { + group.push_back(instruction.get()); + } else { + dependencies[group[0]] = {instruction.get()}; + } + break; + } + case HloOpcode::kRecvDone: { + Instructions& group = channel_groups[*instruction->channel_id()]; + if (group.empty()) { + group.push_back(instruction.get()); + } else { + dependencies[instruction.get()] = {group[0]}; + } + break; + } + case HloOpcode::kAllReduce: + case HloOpcode::kAllGather: + case HloOpcode::kAllToAll: + case HloOpcode::kCollectivePermute: + case HloOpcode::kReduceScatter: { + std::optional channel_id = instruction->channel_id(); + if (channel_id) { + Instructions& group = channel_groups[*channel_id]; + for (const HloInstruction* group_inst : group) { + dependencies[group_inst].push_back(instruction.get()); + } + dependencies[instruction.get()] = group; + group.push_back(instruction.get()); + } + break; + } + default: + break; + } } - return channel_dependencies; + return dependencies; } std::vector HloComputation::MakeInstructionPostOrder() const { - ChannelDependencyGroup channel_dependencies = ComputeChannelDependencies(); + return MakeInstructionPostOrder(ComputeChannelDependencies()); +} + +std::vector HloComputation::MakeInstructionPostOrder( + const ChannelDependencies& channel_dependencies) const { std::vector post_order; post_order.reserve(instruction_count()); absl::flat_hash_map visited; @@ -551,36 +572,37 @@ std::vector HloComputation::MakeEmbeddedComputationsList() return post_order; } -std::string HloComputation::ToString(const HloPrintOptions& options) const { - return ToString(options, MakeInstructionPostOrder()); +void HloComputation::Print(Printer* printer, + const HloPrintOptions& options) const { + Print(printer, options, MakeInstructionPostOrder()); } -std::string HloComputation::ToString( - const HloPrintOptions& options, +void HloComputation::Print( + Printer* printer, const HloPrintOptions& options, absl::Span instruction_order) const { CHECK_EQ(instruction_order.size(), instruction_count()); const std::string tab(2 * options.indent_amount(), ' '); - std::string result; - absl::StrAppend(&result, tab); + printer->Append(tab); if (!options.is_in_nested_computation()) { if (options.print_percent()) { - absl::StrAppend(&result, "%"); + printer->Append("%"); } if (options.print_ids()) { // When print_ids() is false, exclude entry computation's name because it // includes and leads to non-deterministic fingerprint. - absl::StrAppend(&result, name(), " "); + printer->Append(name()); + printer->Append(" "); } } if (options.print_program_shape()) { - absl::StrAppend( - &result, - ShapeUtil::HumanString(ComputeProgramShape(options.print_ids())), " "); + ShapeUtil::PrintHumanString(printer, + ComputeProgramShape(options.print_ids())); + printer->Append(" "); } - absl::StrAppend(&result, "{\n"); + printer->Append("{\n"); { // Print the instructions in this computation. 
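The ComputeChannelDependencies rewrite earlier in this hunk replaces the old channel-id keyed grouping: a RecvDone now maps to the Send that shares its channel id, and each cross-partition collective maps to every other collective already recorded for its channel. A self-contained sketch of that grouping loop, using strings in place of HloInstruction pointers (the instruction names and the channel id are invented for illustration):

#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  // One running group per channel id, plus the resulting dependency map.
  std::map<int64_t, std::vector<std::string>> channel_groups;
  std::map<std::string, std::vector<std::string>> dependencies;

  // Three all-reduces sharing channel_id=5, visited in program order.
  for (std::string inst : {"all-reduce.0", "all-reduce.1", "all-reduce.2"}) {
    std::vector<std::string>& group = channel_groups[5];
    for (const std::string& peer : group) {
      dependencies[peer].push_back(inst);  // record the newcomer against each earlier peer
    }
    dependencies[inst] = group;  // and all earlier peers against the newcomer
    group.push_back(inst);
  }

  // Prints e.g. "all-reduce.1 -> {all-reduce.0, all-reduce.2}": every member of
  // the channel group maps to every other member (never to itself), so the
  // post-order visits the whole group together.
  for (const auto& [inst, deps] : dependencies) {
    std::cout << inst << " -> {";
    for (size_t i = 0; i < deps.size(); ++i) {
      std::cout << (i ? ", " : "") << deps[i];
    }
    std::cout << "}\n";
  }
}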
@@ -593,24 +615,37 @@ std::string HloComputation::ToString( for (const HloInstruction* const instruction : instruction_order) { DCHECK_EQ(this, instruction->parent()); // 2 more spaces than just 'tab' due to indent_amount()+1 above - absl::StrAppend(&result, tab, " "); + printer->Append(tab); + printer->Append(" "); if (instruction == root_instruction_) { - absl::StrAppend(&result, "ROOT "); + printer->Append("ROOT "); } - absl::StrAppend( - &result, - instruction->ToStringWithCanonicalNameMap(new_options, &name_map), - "\n"); + instruction->PrintWithCanonicalNameMap(printer, new_options, &name_map); + printer->Append("\n"); } } - absl::StrAppend(&result, tab, "}"); + printer->Append(tab); + printer->Append("}"); if (options.print_ids() && !IsMainThread()) { // When print_ids() is false, exclude entry computation's thread name // because it includes and leads to non-deterministic fingerprint. - absl::StrAppend(&result, ", execution_thread=\"", execution_thread(), "\""); + printer->Append(", execution_thread=\""); + printer->Append(execution_thread()); + printer->Append("\""); } - return result; +} + +std::string HloComputation::ToString(const HloPrintOptions& options) const { + return ToString(options, MakeInstructionPostOrder()); +} + +std::string HloComputation::ToString( + const HloPrintOptions& options, + absl::Span instruction_order) const { + StringPrinter printer; + Print(&printer, options, instruction_order); + return std::move(printer).ToString(); } absl::Cord HloComputation::ToCord(const HloPrintOptions& options) const { @@ -620,7 +655,9 @@ absl::Cord HloComputation::ToCord(const HloPrintOptions& options) const { absl::Cord HloComputation::ToCord( const HloPrintOptions& options, absl::Span instruction_order) const { - return absl::Cord(ToString(options, instruction_order)); + CordPrinter printer; + Print(&printer, options, instruction_order); + return std::move(printer).ToCord(); } HloComputationProto HloComputation::ToProto() const { @@ -874,7 +911,7 @@ ProgramShape HloComputation::ComputeProgramShape(bool include_ids) const { for (auto* param_instruction : param_instructions_) { *program_shape.add_parameters() = param_instruction->shape(); *program_shape.add_parameter_names() = - PrintName(param_instruction->name(), include_ids); + std::string(PrintName(param_instruction->name(), include_ids)); } *program_shape.mutable_result() = root_instruction_->shape(); diff --git a/tensorflow/compiler/xla/hlo/ir/hlo_computation.h b/tensorflow/compiler/xla/hlo/ir/hlo_computation.h index 63e193dabea..cac2f47b59a 100644 --- a/tensorflow/compiler/xla/hlo/ir/hlo_computation.h +++ b/tensorflow/compiler/xla/hlo/ir/hlo_computation.h @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/printer.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/shape_tree.h" @@ -239,6 +240,17 @@ class HloComputation { // on the computation's existing name. void UniquifyName(NameUniquer* name_uniquer); + // Prints a string representation of the computation. + // + // (We express the default options using an overload rather than a default + // param because gdb ignores default params, but does resolve overloads.) 
+ void Print(Printer* printer) const { + return Print(printer, HloPrintOptions()); + } + void Print(Printer* printer, const HloPrintOptions& options) const; + void Print(Printer* printer, const HloPrintOptions& options, + absl::Span instruction_order) const; + // Return a string representation of the computation. // // (We express the default options using an overload rather than a default @@ -312,9 +324,15 @@ class HloComputation { MakeUnwrappingIterator(instructions_.end())}; } + using ChannelDependencies = + absl::flat_hash_map>; + // Compute and return a post-order of the instructions in the computation. In // this order, definitions of values always appear before their uses. std::vector MakeInstructionPostOrder() const; + std::vector MakeInstructionPostOrder( + const ChannelDependencies& channel_dependencies) const; int64_t instruction_count() const { return instruction_iterators_.size(); } @@ -557,13 +575,13 @@ class HloComputation { // make each channel complete). bool IsSafelyRemovable(const HloInstruction* instruction); - // Returns a map from channel-id to the group of instructions associated with - // the channel. These instructions will be considered as a single node for - // dependency purposes. Send and RecvDone are in the group, and AllReduces - // with the same channel id are in the group. - using ChannelDependencyGroup = - absl::flat_hash_map>; - ChannelDependencyGroup ComputeChannelDependencies() const; + // Returns a map from an instruction to the group of instructions associated + // with the same channel. These instructions will be considered as a single + // node for dependency purposes. + // RecvDone ops will map to the corresponding Send op. + // Cross-partition collectives will map to every other instruction with the + // same channel ID (it doesn't map to itself). + ChannelDependencies ComputeChannelDependencies() const; // Returns true if this computation has a side effect. A computation has a // side effect if it contains one or more instructions with a side effect. @@ -711,8 +729,7 @@ class HloComputation { enum VisitState { kVisiting, kVisited }; void ComputeInstructionPostOrder( - HloInstruction* root, - HloComputation::ChannelDependencyGroup& channel_dependencies, + HloInstruction* root, const ChannelDependencies& channel_dependencies, absl::flat_hash_map& visited, std::vector& post_order) const; diff --git a/tensorflow/compiler/xla/hlo/ir/hlo_instruction.cc b/tensorflow/compiler/xla/hlo/ir/hlo_instruction.cc index 3cbab1a0e3c..9366c8908fd 100644 --- a/tensorflow/compiler/xla/hlo/ir/hlo_instruction.cc +++ b/tensorflow/compiler/xla/hlo/ir/hlo_instruction.cc @@ -46,6 +46,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/hlo/ir/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/printer.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/mapped_ptr_container_sorter.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" @@ -234,8 +235,19 @@ StatusOr> HloInstruction::CreateFromProto( break; } case HloOpcode::kCopyStart: { - instruction = CreateCopyStart(shape, operands(0), - proto.is_cross_program_prefetch()); + std::optional cross_program_prefetch_index; + if (proto.optional_cross_program_prefetch_index_case() == + HloInstructionProto::kCrossProgramPrefetchIndex) { + cross_program_prefetch_index = + std::make_optional(proto.cross_program_prefetch_index()); + + // Silently upgrade HLO protos using the old field. + } else if (proto.is_cross_program_prefetch()) { + cross_program_prefetch_index = 0; + } + + instruction = + CreateCopyStart(shape, operands(0), cross_program_prefetch_index); break; } case HloOpcode::kCompare: { @@ -424,6 +436,19 @@ StatusOr> HloInstruction::CreateFromProto( << "No fusion computation with id " << fusion_id; instruction = CreateFusion(shape, fusion_kind, all_operands(), fused_computation); + std::vector>> + output_to_operand_aliasing; + for (const auto& aliasing : proto.output_operand_aliasing()) { + output_to_operand_aliasing.emplace_back( + ShapeIndex(aliasing.output_shape_index().begin(), + aliasing.output_shape_index().end()), + std::make_pair(aliasing.operand_index(), + ShapeIndex(aliasing.operand_shape_index().begin(), + aliasing.operand_shape_index().end()))); + } + auto fusion_instr = DynCast(instruction.get()); + fusion_instr->set_output_to_operand_aliasing( + std::move(output_to_operand_aliasing)); break; } case HloOpcode::kRng: @@ -782,14 +807,13 @@ StatusOr> HloInstruction::CreateFromProto( *custom_call_instr->mutable_precision_config() = precision_config; std::vector>> output_to_operand_aliasing; - for (const auto& aliasing : proto.custom_call_output_operand_aliasing()) { + for (const auto& aliasing : proto.output_operand_aliasing()) { output_to_operand_aliasing.emplace_back( ShapeIndex(aliasing.output_shape_index().begin(), aliasing.output_shape_index().end()), - std::pair{ - aliasing.operand_index(), - ShapeIndex(aliasing.operand_shape_index().begin(), - aliasing.operand_shape_index().end())}); + std::make_pair(aliasing.operand_index(), + ShapeIndex(aliasing.operand_shape_index().begin(), + aliasing.operand_shape_index().end()))); } custom_call_instr->set_output_to_operand_aliasing( std::move(output_to_operand_aliasing)); @@ -998,8 +1022,11 @@ StatusOr> HloInstruction::CreateFromProto( instruction->unique_id_ = proto.id(); if (proto.has_sharding()) { - TF_ASSIGN_OR_RETURN(const auto& sharding, + TF_ASSIGN_OR_RETURN(HloSharding sharding, HloSharding::FromProto(proto.sharding())); + // To allow for existing Hlo protos to not fail verification, apply tuple + // sharding normalization. 
+ sharding = sharding.NormalizeTupleSharding(instruction->shape()); instruction->set_sharding(sharding); } @@ -1106,6 +1133,7 @@ HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state, case HloOpcode::kSqrt: case HloOpcode::kCbrt: case HloOpcode::kTanh: + case HloOpcode::kTan: break; default: LOG(FATAL) << "Invalid unary instruction opcode " @@ -1220,9 +1248,9 @@ HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state, /* static */ std::unique_ptr HloInstruction::CreateCopyStart( const Shape& shape, HloInstruction* operand, - bool is_cross_program_prefetch) { + std::optional cross_program_prefetch) { return std::make_unique(shape, operand, - is_cross_program_prefetch); + cross_program_prefetch); } /* static */ std::unique_ptr HloInstruction::CreateCompare( @@ -1557,6 +1585,17 @@ HloInstruction::CreateBitcastConvert(const Shape& shape, return instruction; } +/* static */ std::unique_ptr +HloInstruction::CreateStochasticConvert(const Shape& shape, + HloInstruction* operand, + HloInstruction* random) { + auto instruction = absl::WrapUnique( + new HloInstruction(HloOpcode::kStochasticConvert, shape)); + instruction->AppendOperand(operand); + instruction->AppendOperand(random); + return instruction; +} + /* static */ std::unique_ptr HloInstruction::CreateBitcast( const Shape& shape, HloInstruction* operand) { auto instruction = @@ -2096,6 +2135,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kSin: case HloOpcode::kSqrt: case HloOpcode::kCbrt: + case HloOpcode::kTan: case HloOpcode::kTanh: CHECK_EQ(new_operands.size(), 1); clone = CreateUnary(shape, opcode_, new_operands[0]); @@ -2117,7 +2157,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: - case HloOpcode::kStochasticConvert: CHECK_EQ(new_operands.size(), 2); clone = CreateBinary(shape, opcode_, new_operands[0], new_operands[1]); break; @@ -2140,6 +2179,10 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( CHECK_EQ(new_operands.size(), 1); clone = CreateBitcastConvert(shape, new_operands[0]); break; + case HloOpcode::kStochasticConvert: + CHECK_EQ(new_operands.size(), 2); + clone = CreateStochasticConvert(shape, new_operands[0], new_operands[1]); + break; case HloOpcode::kDynamicUpdateSlice: clone = CreateDynamicUpdateSlice(shape, new_operands[0], new_operands[1], new_operands.subspan(2)); @@ -2550,6 +2593,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kStochasticConvert: case HloOpcode::kCbrt: case HloOpcode::kSubtract: + case HloOpcode::kTan: case HloOpcode::kTanh: case HloOpcode::kTuple: return true; @@ -2928,7 +2972,7 @@ std::string HloInstruction::SignatureString() const { return StrCat("(", operands, ") -> ", ShapeUtil::HumanString(shape())); } -std::string PrintName(const std::string& name, bool print_ids) { +absl::string_view PrintName(absl::string_view name, bool print_ids) { if (print_ids) { return name; } else { @@ -2941,10 +2985,19 @@ namespace { using DFSStack = absl::InlinedVector, 16>; +void PrintNameInternal(Printer* printer, absl::string_view name, + const HloPrintOptions& options) { + if (options.print_percent()) { + printer->Append("%"); + } + printer->Append(PrintName(name, options.print_ids())); +} + std::string PrintNameInternal(const std::string& name, const HloPrintOptions& options) { - return StrCat(options.print_percent() ? 
"%" : "", - PrintName(name, options.print_ids())); + StringPrinter printer; + PrintNameInternal(&printer, name, options); + return std::move(printer).ToString(); } void PrintCycle(const HloInstruction* child, DFSStack* dfs_stack) { @@ -2988,9 +3041,16 @@ void PrintCycle(const HloInstruction* child, DFSStack* dfs_stack) { } // namespace -std::string HloInstruction::ToString(const HloPrintOptions& options) const { +void HloInstruction::Print(Printer* printer, + const HloPrintOptions& options) const { CanonicalNameMap new_map; - return ToStringWithCanonicalNameMap(options, &new_map); + PrintWithCanonicalNameMap(printer, options, &new_map); +} + +std::string HloInstruction::ToString(const HloPrintOptions& options) const { + StringPrinter printer; + Print(&printer, options); + return std::move(printer).ToString(); } bool HloInstruction::IsOpElementwise(HloOpcode opcode) { @@ -3023,6 +3083,7 @@ bool HloInstruction::IsOpElementwise(HloOpcode opcode) { case HloOpcode::kSin: case HloOpcode::kSqrt: case HloOpcode::kCbrt: + case HloOpcode::kTan: case HloOpcode::kTanh: return true; @@ -3044,6 +3105,7 @@ bool HloInstruction::IsOpElementwise(HloOpcode opcode) { case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: + case HloOpcode::kStochasticConvert: return true; // Ternary elementwise operations. @@ -3077,30 +3139,31 @@ bool HloInstruction::IsCrossReplicaAllReduce() const { return opcode() == HloOpcode::kAllReduce && !channel_id(); } -std::string HloInstruction::ToStringWithCanonicalNameMap( - const HloPrintOptions& options, +void HloInstruction::PrintWithCanonicalNameMap( + Printer* printer, const HloPrintOptions& options, CanonicalNameMap* canonical_name_map) const { - std::string result = ""; - // Logic to print the instruction name (e.g. "%foo = "). if (options.canonicalize_instruction_names()) { if (options.is_in_nested_computation()) { // If we are canonicalizing instruction names and this is a top-level // HloInstruction::ToString() call, don't print an instruction name. DCHECK(!options.print_percent()); // no need to call PrintNameInternal - StrAppend(&result, canonical_name_map->LookupOrInsert(name()), " = "); + printer->Append(canonical_name_map->LookupOrInsert(name())); + printer->Append(" = "); } } else { - StrAppend(&result, PrintNameInternal(name(), options), " = "); + PrintNameInternal(printer, name(), options); + printer->Append(" = "); } if (options.print_result_shape()) { // Print shape. if (options.include_layout_in_shapes()) { - StrAppend(&result, ShapeUtil::HumanStringWithLayout(shape()), " "); + ShapeUtil::PrintHumanStringWithLayout(printer, shape()); } else { - StrAppend(&result, ShapeUtil::HumanString(shape()), " "); + ShapeUtil::PrintHumanString(printer, shape()); } + printer->Append(" "); } // Print opcode, operand(s). @@ -3117,42 +3180,45 @@ std::string HloInstruction::ToStringWithCanonicalNameMap( return "-done"; } }(); - StrAppend(&result, HloOpcodeString(async_wrapped_opcode()), suffix); + printer->Append(HloOpcodeString(async_wrapped_opcode())); + printer->Append(suffix); } else { - StrAppend(&result, HloOpcodeString(opcode())); + printer->Append(HloOpcodeString(opcode())); } - StrAppend(&result, "(", - OperandsToStringWithCanonicalNameMap(options, canonical_name_map), - ")"); + printer->Append("("); + PrintOperandsWithCanonicalNameMap(printer, options, canonical_name_map); + printer->Append(")"); // Print additional attributes. If an instruction contains a subcomputation, // the subcomputation is also printed here. 
for (const std::string& extra : ExtraAttributesToString(options)) { - StrAppend(&result, ", ", extra); + printer->Append(", "); + printer->Append(extra); } if (options.print_metadata() && (!metadata_.op_type().empty() || !metadata_.op_name().empty() || !metadata_.source_file().empty())) { - StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}"); + printer->Append(", metadata={"); + printer->Append(xla::OpMetadataToString(metadata_)); + printer->Append("}"); } if (options.print_backend_config() && !backend_config_.empty()) { - StrAppend(&result, ", backend_config=\"", - CEscape(backend_config_.GetRawString()), "\""); + printer->Append(", backend_config=\""); + printer->Append(CEscape(backend_config_.GetRawString())); + printer->Append("\""); } - return result; } -std::string HloInstruction::OperandsToString( - const HloPrintOptions& options) const { +void HloInstruction::PrintOperands(Printer* printer, + const HloPrintOptions& options) const { CanonicalNameMap new_map; - return OperandsToStringWithCanonicalNameMap(options, &new_map); + PrintOperandsWithCanonicalNameMap(printer, options, &new_map); } -std::string HloInstruction::OperandsToStringWithCanonicalNameMap( - const HloPrintOptions& options, +void HloInstruction::PrintOperandsWithCanonicalNameMap( + Printer* printer, const HloPrintOptions& options, CanonicalNameMap* canonical_name_map) const { - std::string operands; absl::Span slice(operands_); const int64_t kMaxOperandsToShowIfCompact = 4; if (options.compact_operands() && @@ -3162,42 +3228,45 @@ std::string HloInstruction::OperandsToStringWithCanonicalNameMap( for (int64_t i = 0; i < slice.size(); ++i) { HloInstruction* operand = slice[i]; if (i != 0) { - StrAppend(&operands, ", "); + printer->Append(", "); if (options.print_operand_index_annotation_interval() != 0 && i % options.print_operand_index_annotation_interval() == 0) { - StrAppend(&operands, absl::StrFormat("/*index=%lld*/", i)); + printer->Append(absl::StrFormat("/*index=%lld*/", i)); } } // If operand is already been deleted, put `null` to the string output. if (operand == nullptr) { - StrAppend(&operands, "null "); + printer->Append("null "); continue; } - std::vector str; + bool add_space = false; if (options.print_operand_shape()) { if (options.include_layout_in_shapes()) { - str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape())); + ShapeUtil::PrintHumanStringWithLayout(printer, operand->shape()); } else { - str.push_back(ShapeUtil::HumanString(operand->shape())); + ShapeUtil::PrintHumanString(printer, operand->shape()); } + add_space = true; } if (options.canonicalize_instruction_names()) { if (options.is_in_nested_computation()) { // In a top-level HloInstruction::ToString() call, the operand name is // not part of the canonical string. 
DCHECK(!options.print_percent()); // no need to call PrintNameInternal - str.push_back(canonical_name_map->LookupOrInsert(operand->name())); + if (add_space) printer->Append(" "); + printer->Append(canonical_name_map->LookupOrInsert(operand->name())); } } else if (options.print_operand_names()) { - str.push_back(PrintNameInternal(operand->name(), options)); + if (add_space) printer->Append(" "); + PrintNameInternal(printer, operand->name(), options); } - StrAppend(&operands, StrJoin(str, " ")); } const int64_t remaining = operands_.size() - slice.size(); if (slice.size() != operands_.size()) { - StrAppend(&operands, ", ...(+", remaining, ")"); + printer->Append(", ...(+"); + printer->Append(absl::AlphaNum(remaining).Piece()); + printer->Append(")"); } - return operands; } namespace { @@ -3621,6 +3690,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleLog(this); case HloOpcode::kLog1p: return visitor->HandleLog1p(this); + case HloOpcode::kTan: + return visitor->HandleTan(this); case HloOpcode::kTanh: return visitor->HandleTanh(this); case HloOpcode::kCos: @@ -4901,8 +4972,8 @@ void HloInstruction::set_called_computations_execution_thread( async_execution_thread, skip_async_execution_thread_overwrite); } -bool HloInstruction::is_cross_program_prefetch() const { - return Cast(this)->is_cross_program_prefetch(); +std::optional HloInstruction::cross_program_prefetch_index() const { + return Cast(this)->cross_program_prefetch_index(); } ComparisonDirection HloInstruction::comparison_direction() const { @@ -4922,8 +4993,8 @@ const CholeskyOptions& HloInstruction::cholesky_options() const { } const std::vector>>& -HloInstruction::custom_call_output_operand_aliasing() const { - return Cast(this)->output_to_operand_aliasing(); +HloInstruction::output_operand_aliasing() const { + return Cast(this)->output_to_operand_aliasing(); } } // namespace xla diff --git a/tensorflow/compiler/xla/hlo/ir/hlo_instruction.h b/tensorflow/compiler/xla/hlo/ir/hlo_instruction.h index d48e326eadb..3a9dc645179 100644 --- a/tensorflow/compiler/xla/hlo/ir/hlo_instruction.h +++ b/tensorflow/compiler/xla/hlo/ir/hlo_instruction.h @@ -49,6 +49,7 @@ limitations under the License. #include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/printer.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/mapped_ptr_container_sorter.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" @@ -65,7 +66,7 @@ namespace xla { class HloComputation; class HloModule; -std::string PrintName(const std::string& name, bool print_ids); +absl::string_view PrintName(absl::string_view name, bool print_ids); // A bunch of switches that control how the hlo text should be printed. class HloPrintOptions { @@ -646,7 +647,7 @@ class HloInstruction { // prefetch or not. static std::unique_ptr CreateCopyStart( const Shape& shape, HloInstruction* operand, - bool is_cross_program_prefetch = false); + std::optional cross_program_prefetch_index = std::nullopt); // Creates a compare op, performing the comparison specified in direction. 
static std::unique_ptr CreateCompare( @@ -833,6 +834,12 @@ class HloInstruction { static std::unique_ptr CreateBitcastConvert( const Shape& shape, HloInstruction* operand); + // Creates a stochastic conversion instruction, where operand is the data to + // convert, random is a given random input to determine the rounding direction + // and shape is the target shape for the conversion. + static std::unique_ptr CreateStochasticConvert( + const Shape& shape, HloInstruction* operand, HloInstruction* random); + // Creates an infeed instruction, which reads data of the given shape from the // Infeed interface of the device. infeed_shape is the shape of the data // received from the infeed *not* the shape of the infeed instruction which @@ -1498,6 +1505,12 @@ class HloInstruction { // function, e.g. the signature of an F32 add is (F32, F32) -> F32. std::string SignatureString() const; + // Prints a debugging string that represents this instruction. + void Print(Printer* printer) const { + return Print(printer, HloPrintOptions()); + } + void Print(Printer* printer, const HloPrintOptions& options) const; + // Returns a debugging string that represents this instruction. // // (We express the default options using an overload rather than a default @@ -1509,10 +1522,10 @@ class HloInstruction { std::string ToString() const { return ToString(HloPrintOptions()); } std::string ToString(const HloPrintOptions& options) const; - // Components of the ToString() representation: + // Components of the Print() and ToString() representation: - // Returns a string representation of the operand list. - std::string OperandsToString(const HloPrintOptions& options) const; + // Prints a string representation of the operand list. + void PrintOperands(Printer* printer, const HloPrintOptions& options) const; // Returns string representation of op-specific attributes. std::vector ExtraAttributesToString( @@ -1526,9 +1539,9 @@ class HloInstruction { // The canonical string representation needs to name operands and instruction // names in a consistent way. This is implemented through the // canonical_name_map. - std::string ToStringWithCanonicalNameMap( - const HloPrintOptions& options, - CanonicalNameMap* canonical_name_map) const; + void PrintWithCanonicalNameMap(Printer* printer, + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const; // Returns a serialized representation of this instruction. virtual HloInstructionProto ToProto() const; @@ -1568,14 +1581,14 @@ class HloInstruction { // Returns the sharding unique device, if any. std::optional sharding_unique_device() const { if (sharding_ == nullptr) { - return std::optional(); + return std::nullopt; } return sharding_->UniqueDevice(); } // Sets the sharding of this operator. Should only be called by HloModule or // HloComputation methods. void set_sharding(const HloSharding& sharding) { - sharding_ = std::make_shared(sharding); + set_sharding(std::make_shared(sharding)); } void set_sharding(std::shared_ptr sharding) { sharding_ = std::move(sharding); @@ -2143,8 +2156,8 @@ class HloInstruction { absl::string_view async_execution_thread, bool skip_async_execution_thread_overwrite); - // Delegates to HloCopyStartInstruction::is_cross_program_prefetch(). - bool is_cross_program_prefetch() const; + // Delegates to HloCopyStartInstruction::is_cross_program_prefetch_index(). + std::optional cross_program_prefetch_index() const; // Delegates to HloCompareInstruction::direction(). 
ComparisonDirection comparison_direction() const; @@ -2157,9 +2170,9 @@ class HloInstruction { // Delegates to HloCholeskyInstruction::cholesky_options(). const CholeskyOptions& cholesky_options() const; - // Delegates to HloCustomCallInstruction::output_to_operand_aliasing(). + // Delegates to HloCallableInstruction::output_to_operand_aliasing(). const std::vector>>& - custom_call_output_operand_aliasing() const; + output_operand_aliasing() const; // Appends operand to the list of operands and adds this instruction as a user // of the operand. @@ -2275,8 +2288,8 @@ class HloInstruction { const std::optional& operand_idx) const; // Prints an operand to a string. Accessed by friend class HloInstruction. - virtual std::string OperandsToStringWithCanonicalNameMap( - const HloPrintOptions& options, + virtual void PrintOperandsWithCanonicalNameMap( + Printer* printer, const HloPrintOptions& options, CanonicalNameMap* canonical_name_map) const; // See comments on Identical(). diff --git a/tensorflow/compiler/xla/hlo/ir/hlo_instructions.cc b/tensorflow/compiler/xla/hlo/ir/hlo_instructions.cc index a674391edf0..d437f7db846 100644 --- a/tensorflow/compiler/xla/hlo/ir/hlo_instructions.cc +++ b/tensorflow/compiler/xla/hlo/ir/hlo_instructions.cc @@ -42,6 +42,7 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/ir/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/printer.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -369,25 +370,28 @@ HloInstructionProto HloAsyncInstruction::ToProto() const { return proto; } -HloCopyStartInstruction::HloCopyStartInstruction(const Shape& shape, - HloInstruction* operand, - bool is_cross_program_prefetch) +HloCopyStartInstruction::HloCopyStartInstruction( + const Shape& shape, HloInstruction* operand, + std::optional cross_program_prefetch_index) : HloInstruction(HloOpcode::kCopyStart, shape), - is_cross_program_prefetch_(is_cross_program_prefetch) { + cross_program_prefetch_index_(cross_program_prefetch_index) { AppendOperand(operand); } HloInstructionProto HloCopyStartInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); - proto.set_is_cross_program_prefetch(is_cross_program_prefetch_); + if (cross_program_prefetch_index_.has_value()) { + proto.set_cross_program_prefetch_index(*cross_program_prefetch_index_); + } return proto; } std::vector HloCopyStartInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { std::vector result; - if (is_cross_program_prefetch()) { - result.push_back("is_cross_program_prefetch=true"); + if (cross_program_prefetch_index_.has_value()) { + result.push_back("cross_program_prefetch_index=" + + std::to_string(*cross_program_prefetch_index_)); } return result; } @@ -397,8 +401,8 @@ bool HloCopyStartInstruction::IdenticalSlowPath( absl::FunctionRef eq_computations) const { const auto& casted_other = static_cast(other); - return is_cross_program_prefetch() == - casted_other.is_cross_program_prefetch(); + return cross_program_prefetch_index() == + casted_other.cross_program_prefetch_index(); } std::unique_ptr @@ -406,8 +410,8 @@ HloCopyStartInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return std::make_unique(shape, new_operands[0], - 
is_cross_program_prefetch()); + return std::make_unique( + shape, new_operands[0], cross_program_prefetch_index()); } HloCompareInstruction::HloCompareInstruction( @@ -460,8 +464,8 @@ std::unique_ptr HloCompareInstruction::CloneWithNewOperandsImpl( namespace { -// Converts a protocol buffer message (e.g., TriangularSolveOptions) to a vector -// of "key=value" attribute strings generically, using protocol buffer +// Converts a protocol buffer message (e.g., TriangularSolveOptions) to a +// vector of "key=value" attribute strings generically, using protocol buffer // reflection. // // Currently implements a small subset of cases; feel free to add more as @@ -1509,23 +1513,36 @@ HloConstantInstruction::CloneWithNewOperandsImpl( this->shape()); } -std::string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( - const HloPrintOptions& options, +void HloConstantInstruction::PrintOperandsWithCanonicalNameMap( + Printer* printer, const HloPrintOptions& options, CanonicalNameMap* canonical_name_map) const { if (options.print_only_essential_constants()) { if (!literal_.has_value()) { - return "{...}"; + printer->Append("{...}"); + return; } if (literal().IsAll(0)) { - return "0"; + printer->Append("0"); + return; } if (literal().IsAll(1)) { - return "1"; + printer->Append("1"); + return; } if (shape().IsInteger()) { - return literal_->ToStringWithoutShapeOneline(); + // The following prevents high compilation latencies caused by serializing + // large constant tensors; for example: b/265669625. The limit of 500k was + // chosen empirically to make sure that serialization of the `literal_` is + // less than a second. + if (auto num_constants = + absl::c_accumulate(shape().dimensions(), 1, std::multiplies<>()); + num_constants <= 500'000) { + literal_->PrintWithoutShapeOneline(printer); + return; + } } - return "{...}"; + printer->Append("{...}"); + return; } // For constants, show the actual value in place of an empty operand list. @@ -1534,10 +1551,10 @@ std::string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( options.print_large_constants())) { // Literal::ToString emits multidimensional arrays over multiple // lines. Compact this into one line by stripping out white space. - return literal_->ToStringWithoutShapeOneline(); + literal_->PrintWithoutShapeOneline(printer); } else { // Do not show large constants or tuples. - return "{...}"; + printer->Append("{...}"); } } @@ -1641,16 +1658,16 @@ HloCallableInstruction::CloneAndAppendInstructionIntoCalledComputation( CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build())); clone = called_computation_root(); } else { - // When add_output is false, instruction_to_append is necessarily an operand - // of the callable instruction. After appending this will no longer be the - // case. Remove the operand from the operand list and remove its - // corresponding called computation parameter instruction. + // When add_output is false, instruction_to_append is necessarily an + // operand of the callable instruction. After appending this will no + // longer be the case. Remove the operand from the operand list and remove + // its corresponding called computation parameter instruction. bool in_operand_list = absl::c_linear_search(operands(), instruction_to_append); CHECK(add_output || in_operand_list); if (do_not_clone) { - // We assume all uses of a kTuple operation are GTE ops. In this case, we - // don't need to clone 'instruction_to_append'. + // We assume all uses of a kTuple operation are GTE ops. 
In this case, + // we don't need to clone 'instruction_to_append'. CHECK(!in_operand_list); clone = instruction_to_append; } else { @@ -1662,8 +1679,8 @@ HloCallableInstruction::CloneAndAppendInstructionIntoCalledComputation( for (int64_t operand_num = 0; operand_num < operand_count(); ++operand_num) { if (instruction_to_append == operand(operand_num)) { - // Replace the called computation parameter instruction's uses with the - // clone. + // Replace the called computation parameter instruction's uses with + // the clone. HloInstruction* called_computation_parameter = called_computation_parameters[operand_num]; TF_CHECK_OK(called_computation_parameter->ReplaceAllUsesWith(clone)); @@ -1679,8 +1696,8 @@ HloCallableInstruction::CloneAndAppendInstructionIntoCalledComputation( // this callable instruction is no longer a use of instruction_to_append. if (in_operand_list) { DetachFrom(instruction_to_append); - // When the instruction_to_append does not have other users, we don't need - // to generate a multioutput instruction. + // When the instruction_to_append does not have other users, we don't + // need to generate a multioutput instruction. if (instruction_to_append->user_count() == 0) { add_output = false; } @@ -1691,8 +1708,8 @@ HloCallableInstruction::CloneAndAppendInstructionIntoCalledComputation( const std::vector& called_computation_parameters = called_computation()->parameter_instructions(); - // Add each operand of the clone as an operand of the callable instruction. A - // complication is that some clone operands may already be operands of the + // Add each operand of the clone as an operand of the callable instruction. + // A complication is that some clone operands may already be operands of the // callable instruction. for (int64_t operand_num = 0; operand_num < clone->operand_count(); ++operand_num) { @@ -1709,9 +1726,9 @@ HloCallableInstruction::CloneAndAppendInstructionIntoCalledComputation( } if (called_computation_parameter == nullptr) { - // Clone's operand was not already an operand of the callable instruction. - // Add it as an operand and add a corresponding called computation - // parameter instruction. + // Clone's operand was not already an operand of the callable + // instruction. Add it as an operand and add a corresponding called + // computation parameter instruction. called_computation_parameter = AddCallOperand(operand); } TF_CHECK_OK( @@ -1720,7 +1737,8 @@ HloCallableInstruction::CloneAndAppendInstructionIntoCalledComputation( if (add_output) { CHECK_GT(instruction_to_append->user_count(), 0); - // If this is already a multioutput instruction, expand the root tuple by 1. + // If this is already a multioutput instruction, expand the root tuple + // by 1. HloInstruction* root = called_computation_root(); HloInstruction::InstructionVector tuple_elements; bool newly_created_tuple_instr = false; @@ -1742,6 +1760,9 @@ HloCallableInstruction::CloneAndAppendInstructionIntoCalledComputation( called_computation()->set_root_instruction(new_root, /*accept_different_shape=*/true); *mutable_shape() = new_root->shape(); + // The instruction might have an existing sharding, which will no longer + // be valid after we change the shape. So clear the sharding. 
+ clear_sharding(); if (root->opcode() == HloOpcode::kTuple) { TF_CHECK_OK(called_computation()->RemoveInstruction(root)); } @@ -1845,7 +1866,8 @@ void HloFusionInstruction::ClearFusionComputationInstruction() { // Each fusion calls a single computation, but we use called_computations() // instead of fused_instructions_computation(), because the order in which // things get destructed can vary; the fusion computation's back-pointer may - // already be null, which violates a check in fused_instructions_computation. + // already be null, which violates a check in + // fused_instructions_computation. for (HloComputation* computation : called_computations()) { // Some passes that rewrite fusions may reassign a fusion computation to a // different fusion instruction as this instruction gets destructed. @@ -1876,6 +1898,16 @@ std::string HloFusionInstruction::ToCategory() const { HloInstructionProto HloFusionInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); proto.set_fusion_kind(xla::ToString(fusion_kind())); + for (const auto& pair : output_to_operand_aliasing()) { + auto aliasing = proto.add_output_operand_aliasing(); + aliasing->set_operand_index(pair.second.first); + for (int64_t index : pair.first) { + aliasing->add_output_shape_index(index); + } + for (int64_t index : pair.second.second) { + aliasing->add_operand_shape_index(index); + } + } proto.add_called_computation_ids( fused_instructions_computation()->unique_id()); return proto; @@ -1963,8 +1995,8 @@ void HloFusionInstruction::MergeFusionInstruction( // Replace instruction_to_merge use of 'this' with unfused_root. TF_CHECK_OK(instruction_to_merge->ReplaceUseWith(this, unfused_root)); - // Build a dummy root for the cloned fusion as we may remove the original root - // in the fusion process. + // Build a dummy root for the cloned fusion as we may remove the original + // root in the fusion process. 
if (!unfused_instructions.empty()) { HloComputation* computation = unfused_root->parent(); auto* dummy_root = computation->AddInstruction( @@ -2100,7 +2132,20 @@ int64_t HloFusionInstruction::fused_instruction_count() const { std::vector HloFusionInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("kind=", xla::ToString(fusion_kind()))}; + std::vector extra = { + StrCat("kind=", xla::ToString(fusion_kind()))}; + if (!output_to_operand_aliasing().empty()) { + std::vector pair_strings; + pair_strings.reserve(output_to_operand_aliasing().size()); + for (const auto& pair : output_to_operand_aliasing()) { + pair_strings.push_back(StrCat(pair.first.ToString(), ": (", + pair.second.first, ", ", + pair.second.second.ToString(), ")")); + } + extra.push_back(StrCat("output_to_operand_aliasing={", + StrJoin(pair_strings, ", "), "}")); + } + return extra; } bool HloFusionInstruction::IdenticalSlowPath( @@ -2108,6 +2153,9 @@ bool HloFusionInstruction::IdenticalSlowPath( absl::FunctionRef eq_computations) const { return fusion_kind() == other.fusion_kind() && + output_to_operand_aliasing() == + static_cast(other) + .output_to_operand_aliasing() && eq_computations(fused_instructions_computation(), other.fused_instructions_computation()); } @@ -2241,10 +2289,10 @@ std::vector HloParameterInstruction::ExtraAttributesToStringImpl( return result; } -std::string HloParameterInstruction::OperandsToStringWithCanonicalNameMap( - const HloPrintOptions& options, +void HloParameterInstruction::PrintOperandsWithCanonicalNameMap( + Printer* printer, const HloPrintOptions& options, CanonicalNameMap* canonical_name_map) const { - return StrCat(parameter_number_); + printer->Append(absl::AlphaNum(parameter_number_).Piece()); } bool HloParameterInstruction::IdenticalSlowPath( @@ -2739,8 +2787,8 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const { if (literal_.has_value()) { *proto.mutable_literal() = literal_->ToProto(); } - for (const auto& pair : output_to_operand_aliasing_) { - auto aliasing = proto.add_custom_call_output_operand_aliasing(); + for (const auto& pair : output_to_operand_aliasing()) { + auto aliasing = proto.add_output_operand_aliasing(); aliasing->set_operand_index(pair.second.first); for (int64_t index : pair.first) { aliasing->add_output_shape_index(index); @@ -2780,8 +2828,8 @@ std::vector HloCustomCallInstruction::ExtraAttributesToStringImpl( extra.push_back(StrCat("padding_type=", PaddingType_Name(padding_type()))); } // By contract, we print the custom call target even if - // options.print_subcomputation_mode() == kOff, because the call target is not - // an HloComputation. + // options.print_subcomputation_mode() == kOff, because the call target is + // not an HloComputation. 
extra.push_back( StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\"")); @@ -2800,10 +2848,10 @@ std::vector HloCustomCallInstruction::ExtraAttributesToStringImpl( if (literal_.has_value()) { extra.push_back(StrCat("literal=", literal_->ToStringWithLayoutOneline())); } - if (!output_to_operand_aliasing_.empty()) { + if (!output_to_operand_aliasing().empty()) { std::vector pair_strings; - pair_strings.reserve(output_to_operand_aliasing_.size()); - for (const auto& pair : output_to_operand_aliasing_) { + pair_strings.reserve(output_to_operand_aliasing().size()); + for (const auto& pair : output_to_operand_aliasing()) { pair_strings.push_back(StrCat(pair.first.ToString(), ": (", pair.second.first, ", ", pair.second.second.ToString(), ")")); @@ -2867,7 +2915,7 @@ bool HloCustomCallInstruction::IdenticalSlowPath( casted_other.custom_call_has_side_effect()) { return false; } - if (output_to_operand_aliasing_ != + if (output_to_operand_aliasing() != casted_other.output_to_operand_aliasing()) { return false; } @@ -2928,7 +2976,7 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl( cloned->set_feature_group_count(feature_group_count_); cloned->set_batch_group_count(batch_group_count_); cloned->set_custom_call_has_side_effect(custom_call_has_side_effect_); - cloned->set_output_to_operand_aliasing(output_to_operand_aliasing_); + cloned->set_output_to_operand_aliasing(output_to_operand_aliasing()); cloned->set_padding_type(padding_type_); *cloned->mutable_precision_config() = precision_config(); cloned->set_custom_call_schedule(custom_call_schedule_); diff --git a/tensorflow/compiler/xla/hlo/ir/hlo_instructions.h b/tensorflow/compiler/xla/hlo/ir/hlo_instructions.h index 5ad24d6759d..a611b092d45 100644 --- a/tensorflow/compiler/xla/hlo/ir/hlo_instructions.h +++ b/tensorflow/compiler/xla/hlo/ir/hlo_instructions.h @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/ir/hlo_computation.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_opcode.h" +#include "tensorflow/compiler/xla/printer.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -274,10 +275,13 @@ class HloAsyncInstruction : public HloInstruction { class HloCopyStartInstruction : public HloInstruction { public: - explicit HloCopyStartInstruction(const Shape& shape, HloInstruction* operand, - bool is_cross_program_prefetch); + explicit HloCopyStartInstruction( + const Shape& shape, HloInstruction* operand, + std::optional cross_program_prefetch_index); - bool is_cross_program_prefetch() const { return is_cross_program_prefetch_; } + std::optional cross_program_prefetch_index() const { + return cross_program_prefetch_index_; + } HloInstructionProto ToProto() const override; static bool ClassOf(const HloInstruction* hlo) { @@ -295,7 +299,14 @@ class HloCopyStartInstruction : public HloInstruction { const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; - bool is_cross_program_prefetch_; + // Each cross program prefetched buffer has a unique index. The indices are + // assigned contiguously starting from zero in + // AlternateMemoryBestFitHeap::AllocateCrossProgramPrefetchBuffer. This value + // is used during codegen to determine which buffer is being speculated at + // runtime. One possible implementation is to initialize an array with boolean + // values indicating whether the cross program prefetch succeeds or fails for + // each buffer. 
+ std::optional cross_program_prefetch_index_; }; class HloCompareInstruction : public HloInstruction { @@ -1128,8 +1139,8 @@ class HloConstantInstruction : public HloInstruction { const HloInstruction& other, absl::FunctionRef eq_computations) const override; - std::string OperandsToStringWithCanonicalNameMap( - const HloPrintOptions& options, + void PrintOperandsWithCanonicalNameMap( + Printer* printer, const HloPrintOptions& options, CanonicalNameMap* canonical_name_map) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( @@ -1200,9 +1211,30 @@ class HloCallableInstruction : public HloInstruction { hlo->opcode() == HloOpcode::kCustomCall; } + // Gets a list of output/operand buffer pairs that alias each other, where the + // output buffer is represented as a ShapeIndex, and the operand buffer is + // represented as the operand index and the ShapeIndex. By default this list + // is empty. + const std::vector>>& + output_to_operand_aliasing() const { + return output_to_operand_aliasing_; + } + // Sets the list of output/operand buffer pairs that alias each other. + void set_output_to_operand_aliasing( + std::vector>> + aliasing) { + output_to_operand_aliasing_ = std::move(aliasing); + } + protected: // Returns the default called computation name. virtual std::string default_called_computation_name() const = 0; + + private: + // A list of output/operand buffer pairs that alias each other. See comment of + // output_to_operand_aliasing(). + std::vector>> + output_to_operand_aliasing_; }; class HloFusionInstruction : public HloCallableInstruction { @@ -1428,8 +1460,8 @@ class HloParameterInstruction : public HloInstruction { const HloInstruction& other, absl::FunctionRef eq_computations) const override; - std::string OperandsToStringWithCanonicalNameMap( - const HloPrintOptions& options, + void PrintOperandsWithCanonicalNameMap( + Printer* printer, const HloPrintOptions& options, CanonicalNameMap* canonical_name_map) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( @@ -1892,20 +1924,6 @@ class HloCustomCallInstruction : public HloCallableInstruction { CHECK(layout_constrained()); return operand_shapes_with_layout_; } - // Gets a list of output/operand buffer pairs that alias each other, where the - // output buffer is represented as a ShapeIndex, and the operand buffer is - // represented as the operand index and the ShapeIndex. By default this list - // is empty. - const std::vector>>& - output_to_operand_aliasing() const { - return output_to_operand_aliasing_; - } - // Sets the list of output/operand buffer pairs that alias each other. - void set_output_to_operand_aliasing( - std::vector>> - aliasing) { - output_to_operand_aliasing_ = std::move(aliasing); - } void set_custom_call_schedule(CustomCallSchedule custom_call_schedule) { custom_call_schedule_ = custom_call_schedule; } @@ -1958,10 +1976,6 @@ class HloCustomCallInstruction : public HloCallableInstruction { std::vector operand_shapes_with_layout_; // Whether this custom call has a side-effect. bool custom_call_has_side_effect_; - // A list of output/operand buffer pairs that alias each other. See comment of - // output_to_operand_aliasing(). - std::vector>> - output_to_operand_aliasing_; std::optional literal_; // A custom-call schedule hint. 
CustomCallSchedule custom_call_schedule_; diff --git a/tensorflow/compiler/xla/hlo/ir/hlo_module.cc b/tensorflow/compiler/xla/hlo/ir/hlo_module.cc index e6b95f16b39..2d33b8be202 100644 --- a/tensorflow/compiler/xla/hlo/ir/hlo_module.cc +++ b/tensorflow/compiler/xla/hlo/ir/hlo_module.cc @@ -36,13 +36,14 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/ir/hlo_instruction.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_schedule.h" #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/printer.h" #include "tensorflow/compiler/xla/service/compilation_environments.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/mapped_ptr_container_sorter.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/tsl/lib/gtl/map_util.h" #include "tensorflow/tsl/platform/errors.h" @@ -56,6 +57,16 @@ namespace xla { HloModule::HloModule(const std::string& name, HloModuleConfig config) : HloModule(name, config, std::make_unique()) {} +HloModule::HloModule(const std::string& name, HloModuleConfig config, + std::unique_ptr comp_envs) + : name_(NameUniquer::GetSanitizedName(name)), + config_(std::move(config)), + unique_id_(next_unique_module_id_++), + metadata_(tsl::Env::Default()), + comp_envs_(std::move(comp_envs)) { + metadata_.set_canonical_module_id(unique_id_); +} + Status HloModule::set_schedule(HloSchedule schedule) { TF_RET_CHECK(schedule.module() == this); TF_RETURN_IF_ERROR(schedule.Verify()); @@ -233,40 +244,45 @@ void HloModule::ReplaceComputations( computations_ = std::move(new_computations); } -std::string HloModule::ToString(const HloPrintOptions& options) const { - return std::string(ToCord(options)); -} - -absl::Cord HloModule::ToCord(const HloPrintOptions& options) const { - absl::Cord result; - result.Append("HloModule "); +void HloModule::Print(Printer* printer, const HloPrintOptions& options) const { + printer->Append("HloModule "); if (options.print_ids()) { // When print_ids() is false, exclude module's name because it includes and // leads to non-deterministic fingerprint. 
- result.Append(name()); + printer->Append(name()); } if (has_schedule()) { TF_CHECK_OK(schedule().Verify()); - result.Append(", is_scheduled=true"); + printer->Append(", is_scheduled=true"); } std::string serialized_aliasing = input_output_alias_config().ToShortString(); if (!serialized_aliasing.empty()) { - result.Append(", input_output_alias={ "); - result.Append(std::move(serialized_aliasing)); - result.Append(" }"); + printer->Append(", input_output_alias={ "); + printer->Append(std::move(serialized_aliasing)); + printer->Append(" }"); } if (config_.alias_passthrough_params()) { - result.Append(", alias_passthrough_params=true"); + printer->Append(", alias_passthrough_params=true"); } if (config_.has_entry_computation_layout()) { - result.Append(", entry_computation_layout={"); - result.Append(entry_computation_layout().ToString()); - result.Append("}"); + printer->Append(", entry_computation_layout={"); + entry_computation_layout().Print(printer); + printer->Append("}"); } - if (config_.allow_spmd_sharding_propagation_to_output()) { - result.Append(", allow_spmd_sharding_propagation_to_output=true"); + if (config_.allow_spmd_sharding_propagation_to_output().size() != 1 || + config_.allow_spmd_sharding_propagation_to_output().back()) { + struct BoolFormatter { + void operator()(std::string* out, bool i) const { + out->append(i ? "true" : "false"); + } + }; + printer->Append(absl::StrCat( + ", allow_spmd_sharding_propagation_to_output={", + absl::StrJoin(config_.allow_spmd_sharding_propagation_to_output(), ",", + BoolFormatter()), + "}")); } - result.Append("\n\n"); + printer->Append("\n\n"); const auto& computations = options.canonicalize_computations() ? MakeComputationSorted() : MakeComputationPostOrder(); @@ -277,17 +293,28 @@ absl::Cord HloModule::ToCord(const HloPrintOptions& options) const { continue; } if (computation == entry_computation()) { - result.Append("ENTRY "); + printer->Append("ENTRY "); } if (has_schedule() && schedule().is_computation_scheduled(computation)) { - result.Append(computation->ToCord( - options, schedule().sequence(computation).instructions())); + computation->Print(printer, options, + schedule().sequence(computation).instructions()); } else { - result.Append(computation->ToCord(options)); + computation->Print(printer, options); } - result.Append("\n\n"); + printer->Append("\n\n"); } - return result; +} + +std::string HloModule::ToString(const HloPrintOptions& options) const { + StringPrinter printer; + Print(&printer, options); + return std::move(printer).ToString(); +} + +absl::Cord HloModule::ToCord(const HloPrintOptions& options) const { + CordPrinter printer; + Print(&printer, options); + return std::move(printer).ToCord(); } HloModuleProto HloModule::ToProto() const { @@ -310,14 +337,16 @@ HloModuleProto HloModule::ToProto() const { *proto.mutable_input_output_alias() = input_output_alias_config().ToProto(); *proto.mutable_dynamic_parameter_binding() = dynamic_parameter_binding().ToProto(); - for (const auto& parameter_indices : CrossProgramPrefetches()) { - const auto& parameter = parameter_indices.first; - const auto& indices = parameter_indices.second; + for (const auto& [parameter, indices, alt_memory_offset] : + CrossProgramPrefetches()) { auto* prefetch = proto.mutable_cross_program_prefetches()->Add(); prefetch->set_parameter(parameter); for (auto index : indices) { prefetch->add_index(index); } + if (alt_memory_offset) { + prefetch->set_offset(*alt_memory_offset); + } } proto.set_is_dynamic(is_dynamic_); if (has_spmd_output_sharding()) { 
@@ -349,6 +378,13 @@ HloModuleProto HloModule::ToProto() const { return proto; } +StatusOr HloModule::ToProtoWithConfig() const { + HloModuleProtoWithConfig result; + TF_ASSIGN_OR_RETURN(*result.mutable_config(), config_.ToProto()); + *result.mutable_hlo_module() = ToProto(); + return result; +} + Status HloModule::CheckUniqueNamesAndIdsForComputationsAndInstructions() const { absl::flat_hash_set computation_names; absl::flat_hash_set computation_ids; @@ -475,7 +511,8 @@ StatusOr> HloModule::CreateFromProto( for (const auto& prefetch : proto.cross_program_prefetches()) { module->AddCrossProgramPrefetch( prefetch.parameter(), - ShapeIndex(prefetch.index().begin(), prefetch.index().end())); + ShapeIndex(prefetch.index().begin(), prefetch.index().end()), + prefetch.offset()); } module->set_is_dynamic(proto.is_dynamic()); @@ -540,8 +577,11 @@ StatusOr HloModule::CreateModuleConfigFromShape( } module_config.set_auto_spmd_partitioning_mesh_ids(mesh_ids); module_config.set_deduplicate_hlo(execution_options->deduplicate_hlo()); - module_config.set_allow_spmd_sharding_propagation_to_output( - execution_options->allow_spmd_sharding_propagation_to_output()); + if (!execution_options->allow_spmd_sharding_propagation_to_output() + .empty()) { + module_config.set_allow_spmd_sharding_propagation_to_output( + execution_options->allow_spmd_sharding_propagation_to_output()); + } if (execution_options->has_device_assignment()) { TF_ASSIGN_OR_RETURN(std::unique_ptr device_assignment, DeviceAssignment::Deserialize( @@ -556,6 +596,15 @@ StatusOr HloModule::CreateModuleConfigFromShape( module_config.num_partitions()); } } + std::vector param_requires_broadcast_via_collectives( + execution_options->param_requires_broadcast_via_collectives().begin(), + execution_options->param_requires_broadcast_via_collectives().end()); + module_config.set_param_requires_broadcast_via_collectives( + param_requires_broadcast_via_collectives); + module_config.set_allow_separate_sharding_programs( + execution_options->allow_separate_sharding_programs()); + HloModuleConfig::AssignStructShardableValueUpdatePairs( + module_config, execution_options->shardable_value_update_pairs()); } // The module config is constructed with default layouts regardless of what is @@ -596,6 +645,15 @@ StatusOr HloModule::CreateModuleConfigFromProto( return config; } +StatusOr> HloModule::CreateFromProtoWithConfig( + const HloModuleProtoWithConfig& proto, bool prohibit_empty_literal) { + auto hlo_module_proto = proto.hlo_module(); + TF_ASSIGN_OR_RETURN(std::unique_ptr config_ptr, + HloModuleConfig::CreateFromProto(proto.config())); + return HloModule::CreateFromProto(hlo_module_proto, *config_ptr, + prohibit_empty_literal); +} + namespace { // Returns whether `hlo` is used outside the given subcomputation. 
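// Illustrative aside: a sketch of round-tripping a module through the
// proto-with-config serialization added above. Assumes namespace xla; error
// propagation and the calling context are assumptions, not part of the patch.
StatusOr<std::unique_ptr<HloModule>> RoundTripWithConfig(
    const HloModule& module) {
  TF_ASSIGN_OR_RETURN(HloModuleProtoWithConfig proto,
                      module.ToProtoWithConfig());
  // prohibit_empty_literal keeps its default of true, as with CreateFromProto.
  return HloModule::CreateFromProtoWithConfig(proto);
}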
// `instructions_in_subcomputation` is the instruction set of the given @@ -893,10 +951,8 @@ std::unique_ptr HloModule::Clone(const HloModuleConfig& config, } TF_CHECK_OK(module->set_schedule(std::move(clone_schedule))); } - for (const auto& parameter_indices : CrossProgramPrefetches()) { - const auto& parameter = parameter_indices.first; - const auto& indices = parameter_indices.second; - module->AddCrossProgramPrefetch(parameter, indices); + for (const auto& [parameter, indices, offset] : CrossProgramPrefetches()) { + module->AddCrossProgramPrefetch(parameter, indices, offset); } // To make clone behavior match uncloned behavior, we reorder @@ -919,7 +975,8 @@ std::unique_ptr HloModule::Clone(const HloModuleConfig& config, Status HloModule::RemoveUnusedComputations() { std::string suffix = "tmp"; auto module = std::make_unique( - absl::StrCat(name_, suffix.empty() ? "" : "-", suffix), config()); + absl::StrCat(name_, suffix.empty() ? "" : "-", suffix), config(), + std::make_unique(*comp_envs_)); HloCloneContext context(module.get(), suffix); entry_computation_->Clone(suffix, &context); std::vector to_remove; @@ -963,16 +1020,6 @@ HloComputation* HloModule::GetComputationWithName(absl::string_view name) { return it == computations_in_module.end() ? nullptr : *it; } -HloModule::HloModule(const std::string& name, HloModuleConfig config, - std::unique_ptr comp_envs) - : name_(NameUniquer::GetSanitizedName(name)), - config_(std::move(config)), - unique_id_(next_unique_module_id_++), - metadata_(tsl::Env::Default()), - comp_envs_(std::move(comp_envs)) { - metadata_.set_canonical_module_id(unique_id_); -} - /* static */ std::atomic HloModule::next_unique_module_id_(0); } // namespace xla diff --git a/tensorflow/compiler/xla/hlo/ir/hlo_module.h b/tensorflow/compiler/xla/hlo/ir/hlo_module.h index 2dfc8140e03..3e35fb8e1ef 100644 --- a/tensorflow/compiler/xla/hlo/ir/hlo_module.h +++ b/tensorflow/compiler/xla/hlo/ir/hlo_module.h @@ -37,11 +37,12 @@ limitations under the License. #include "tensorflow/compiler/xla/hlo/ir/hlo_module_metadata.h" #include "tensorflow/compiler/xla/hlo/ir/hlo_schedule.h" #include "tensorflow/compiler/xla/iterator_util.h" +#include "tensorflow/compiler/xla/printer.h" #include "tensorflow/compiler/xla/service/compilation_environments.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" -#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/tsl/lib/gtl/iterator_range.h" #include "tensorflow/tsl/platform/logging.h" @@ -69,7 +70,11 @@ class HloModule { public: // Constructor. HloModule(const std::string& name, HloModuleConfig config); - virtual ~HloModule() {} + // REQUIRED: + // - comp_envs must not be null. + HloModule(const std::string& name, HloModuleConfig config, + std::unique_ptr comp_envs); + virtual ~HloModule() = default; // Adds an entry computation to the module. A module can only have one entry // computation. Returns a pointer to the newly added computation. @@ -319,6 +324,15 @@ class HloModule { bool is_dynamic() const { return is_dynamic_; } void set_is_dynamic(bool is_dynamic) { is_dynamic_ = is_dynamic; } + // Prints a string representation of the module. + // + // (We express the default options using an overload rather than a default + // param because gdb ignores default params, but does resolve overloads.) 
+ void Print(Printer* printer) const { + return Print(printer, HloPrintOptions()); + } + void Print(Printer* printer, const HloPrintOptions& options) const; + // Return a string representation of the module. // // (We express the default options using an overload rather than a default @@ -339,6 +353,12 @@ class HloModule { const HloModuleProto& proto, const HloModuleConfig& module_config, bool prohibit_empty_literal = true); + // Convert an HloModule to or from a proto that includes module configuration + StatusOr ToProtoWithConfig() const; + static StatusOr> CreateFromProtoWithConfig( + const HloModuleProtoWithConfig& proto, + bool prohibit_empty_literal = true); + // Creates and returns an HloModuleConfig with an appropriate program shape // for the HLO module in the given proto. static StatusOr CreateModuleConfigFromProto( @@ -461,14 +481,36 @@ class HloModule { spmd_output_sharding_ = sharding; } + // Describes a buffer to be used for cross program prefetching. + struct CrossProgramPrefetchInfo { + // The parameter to prefetch. + int64_t parameter; + // Index of the buffer within a tuple-typed parameter. + ShapeIndex index; + // Offset into alt memory where the cross program pretched buffer will be + // stored. + std::optional alt_memory_offset; + }; + // Add a program argument to be prefetched across programs. - void AddCrossProgramPrefetch(int64_t parameter, const ShapeIndex& index) { - cross_program_prefetches_.emplace_back(parameter, index); + void AddCrossProgramPrefetch( + int64_t parameter, const ShapeIndex& index, + std::optional alt_memory_offset = std::nullopt) { + cross_program_prefetches_.emplace_back( + CrossProgramPrefetchInfo{parameter, index, alt_memory_offset}); + } + + Status SetCrossProgramPrefetchOffset(int64_t prefetch_index, int64_t offset) { + TF_RET_CHECK(prefetch_index < cross_program_prefetches_.size()); + auto& [parameter, index, optional_offset] = + cross_program_prefetches_[prefetch_index]; + TF_RET_CHECK(!optional_offset.has_value()); + optional_offset = offset; + return OkStatus(); } // Get the list of program arguments to be prefetch across programs. - const absl::Span> - CrossProgramPrefetches() const { + absl::Span CrossProgramPrefetches() const { return cross_program_prefetches_; } @@ -501,6 +543,23 @@ class HloModule { return profile_info_list_; } + void add_autofdo_pre_pass_fingerprint(absl::string_view fingerprint) { + autofdo_pre_pass_fingerprints_.push_back(std::string(fingerprint)); + } + + void set_autofdo_pre_pass_fingerprints( + const std::vector& fingerprints) { + autofdo_pre_pass_fingerprints_ = fingerprints; + } + + const std::vector& autofdo_pre_pass_fingerprints() const { + return autofdo_pre_pass_fingerprints_; + } + + bool has_module_autofdo_profiles() const { + return !autofdo_pre_pass_fingerprints_.empty(); + } + void set_relative_speedup(double relative_speedup) { relative_speedup_ = relative_speedup; } @@ -516,11 +575,6 @@ class HloModule { CompilationEnvironments& comp_envs() const { return *comp_envs_; } private: - // This constructor is used in Clone() to copy the CompilationEnvironments. - // comp_envs may be null, in which case a clean one will be created. - HloModule(const std::string& name, HloModuleConfig config, - std::unique_ptr comp_envs); - HloComputation* AddComputationInternal( std::unique_ptr computation, bool is_entry, bool uniquify_identifiers, bool preserve_entry_layouts); @@ -569,7 +623,7 @@ class HloModule { std::optional spmd_output_sharding_; // Arguments to be prefetched across programs. 
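// Illustrative aside: how the cross-program-prefetch API declared above might
// be driven. Assumes namespace xla; the parameter number, shape index, and
// offset are made-up values.
Status RecordPrefetch(HloModule* module) {
  // Request that parameter 0, tuple index {1}, be prefetched across programs;
  // the alt-memory offset is left unset at this point.
  module->AddCrossProgramPrefetch(/*parameter=*/0, ShapeIndex({1}));
  // Later, once memory assignment has placed the buffer, record the offset for
  // prefetch entry 0. The setter fails if an offset was already recorded.
  TF_RETURN_IF_ERROR(module->SetCrossProgramPrefetchOffset(
      /*prefetch_index=*/0, /*offset=*/4096));
  return OkStatus();
}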
- std::vector> cross_program_prefetches_; + std::vector cross_program_prefetches_; // Metadata for this module, such as its canonical id and the HLO passes run. HloModuleMetadata metadata_; @@ -590,6 +644,10 @@ class HloModule { // The unoptimized module fingerprint. std::string autofdo_fingerprint_; + // The pre-pass module fingerprints used to retrieve the optimization profiles + // this module contains. + std::vector autofdo_pre_pass_fingerprints_; + bool use_auto_spmd_partitioning_ = false; // Layout canonicalization callback, used only when diff --git a/tensorflow/compiler/xla/service/hlo_module_group.cc b/tensorflow/compiler/xla/hlo/ir/hlo_module_group.cc similarity index 98% rename from tensorflow/compiler/xla/service/hlo_module_group.cc rename to tensorflow/compiler/xla/hlo/ir/hlo_module_group.cc index 789b77e92f3..b24e1ef931e 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group.cc +++ b/tensorflow/compiler/xla/hlo/ir/hlo_module_group.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/hlo_module_group.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_module_group.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/hlo_module_group.h b/tensorflow/compiler/xla/hlo/ir/hlo_module_group.h similarity index 94% rename from tensorflow/compiler/xla/service/hlo_module_group.h rename to tensorflow/compiler/xla/hlo/ir/hlo_module_group.h index dc4a9b72bf4..3a73fd8296e 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group.h +++ b/tensorflow/compiler/xla/hlo/ir/hlo_module_group.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_ +#ifndef TENSORFLOW_COMPILER_XLA_HLO_IR_HLO_MODULE_GROUP_H_ +#define TENSORFLOW_COMPILER_XLA_HLO_IR_HLO_MODULE_GROUP_H_ #include #include @@ -22,8 +22,8 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "tensorflow/compiler/xla/hlo/ir/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" namespace xla { @@ -113,4 +113,4 @@ std::ostream& operator<<(std::ostream& out, const HloModuleGroup& group); } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_ +#endif // TENSORFLOW_COMPILER_XLA_HLO_IR_HLO_MODULE_GROUP_H_ diff --git a/tensorflow/compiler/xla/hlo/ir/hlo_opcode.h b/tensorflow/compiler/xla/hlo/ir/hlo_opcode.h index 114f1f5e735..2f591956a76 100644 --- a/tensorflow/compiler/xla/hlo/ir/hlo_opcode.h +++ b/tensorflow/compiler/xla/hlo/ir/hlo_opcode.h @@ -48,20 +48,22 @@ namespace xla { // the MHLO opset to keep both opsets synchronized. 
// LINT.IfChange #define HLO_OPCODE_LIST(V) \ + /* go/keep-sorted start */ \ V(kAbs, "abs", 1) \ V(kAdd, "add", 2) \ V(kAddDependency, "add-dependency", 2) \ V(kAfterAll, "after-all", kHloOpcodeIsVariadic) \ V(kAllGather, "all-gather", kHloOpcodeIsVariadic) \ - V(kAllGatherStart, "all-gather-start", kHloOpcodeIsVariadic) \ V(kAllGatherDone, "all-gather-done", 1) \ + V(kAllGatherStart, "all-gather-start", kHloOpcodeIsVariadic) \ V(kAllReduce, "all-reduce", kHloOpcodeIsVariadic) \ - V(kAllReduceStart, "all-reduce-start", kHloOpcodeIsVariadic) \ V(kAllReduceDone, "all-reduce-done", 1) \ + V(kAllReduceStart, "all-reduce-start", kHloOpcodeIsVariadic) \ V(kAllToAll, "all-to-all", kHloOpcodeIsVariadic) \ + V(kAnd, "and", 2) \ + V(kAsyncDone, "async-done", 1) \ V(kAsyncStart, "async-start", kHloOpcodeIsVariadic) \ V(kAsyncUpdate, "async-update", 1) \ - V(kAsyncDone, "async-done", 1) \ V(kAtan2, "atan2", 2) \ V(kBatchNormGrad, "batch-norm-grad", 5) \ V(kBatchNormInference, "batch-norm-inference", 5) \ @@ -70,13 +72,14 @@ namespace xla { V(kBitcastConvert, "bitcast-convert", 1) \ V(kBroadcast, "broadcast", 1) \ V(kCall, "call", kHloOpcodeIsVariadic) \ + V(kCbrt, "cbrt", 1) \ V(kCeil, "ceil", 1) \ V(kCholesky, "cholesky", 1) \ V(kClamp, "clamp", 3) \ + V(kClz, "count-leading-zeros", 1) \ V(kCollectivePermute, "collective-permute", kHloOpcodeIsVariadic) \ - V(kCollectivePermuteStart, "collective-permute-start", kHloOpcodeIsVariadic) \ V(kCollectivePermuteDone, "collective-permute-done", 1) \ - V(kClz, "count-leading-zeros", 1) \ + V(kCollectivePermuteStart, "collective-permute-start", kHloOpcodeIsVariadic) \ V(kCompare, "compare", 2) \ V(kComplex, "complex", 2) \ V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \ @@ -92,6 +95,7 @@ namespace xla { V(kDivide, "divide", 2) \ V(kDomain, "domain", 1) \ V(kDot, "dot", 2) \ + V(kDynamicReshape, "dynamic-reshape", kHloOpcodeIsVariadic) \ V(kDynamicSlice, "dynamic-slice", kHloOpcodeIsVariadic) \ V(kDynamicUpdateSlice, "dynamic-update-slice", kHloOpcodeIsVariadic) \ V(kExp, "exponential", 1) \ @@ -101,7 +105,6 @@ namespace xla { V(kFusion, "fusion", kHloOpcodeIsVariadic) \ V(kGather, "gather", 2) \ V(kGetDimensionSize, "get-dimension-size", 1) \ - V(kSetDimensionSize, "set-dimension-size", 2) \ V(kGetTupleElement, "get-tuple-element", 1) \ V(kImag, "imag", 1) \ V(kInfeed, "infeed", 1) \ @@ -110,16 +113,14 @@ namespace xla { V(kLog, "log", 1) \ V(kLog1p, "log-plus-one", 1) \ V(kLogistic, "logistic", 1) \ - V(kAnd, "and", 2) \ - V(kNot, "not", 1) \ - V(kOptimizationBarrier, "opt-barrier", 1) \ - V(kOr, "or", 2) \ - V(kXor, "xor", 2) \ V(kMap, "map", kHloOpcodeIsVariadic) \ V(kMaximum, "maximum", 2) \ V(kMinimum, "minimum", 2) \ V(kMultiply, "multiply", 2) \ V(kNegate, "negate", 1) \ + V(kNot, "not", 1) \ + V(kOptimizationBarrier, "opt-barrier", 1) \ + V(kOr, "or", 2) \ V(kOutfeed, "outfeed", 2) \ V(kPad, "pad", 2) \ V(kParameter, "parameter", 0) \ @@ -136,11 +137,10 @@ namespace xla { V(kRemainder, "remainder", 2) \ V(kReplicaId, "replica-id", 0) \ V(kReshape, "reshape", 1) \ - V(kDynamicReshape, "dynamic-reshape", kHloOpcodeIsVariadic) \ V(kReverse, "reverse", 1) \ V(kRng, "rng", kHloOpcodeIsVariadic) \ - V(kRngGetAndUpdateState, "rng-get-and-update-state", 0) \ V(kRngBitGenerator, "rng-bit-generator", 1) \ + V(kRngGetAndUpdateState, "rng-get-and-update-state", 0) \ V(kRoundNearestAfz, "round-nearest-afz", 1) \ V(kRoundNearestEven, "round-nearest-even", 1) \ V(kRsqrt, "rsqrt", 1) \ @@ -149,6 +149,7 @@ namespace xla { V(kSelectAndScatter, 
"select-and-scatter", 3) \ V(kSend, "send", 2) \ V(kSendDone, "send-done", 1) \ + V(kSetDimensionSize, "set-dimension-size", 2) \ V(kShiftLeft, "shift-left", 2) \ V(kShiftRightArithmetic, "shift-right-arithmetic", 2) \ V(kShiftRightLogical, "shift-right-logical", 2) \ @@ -158,14 +159,16 @@ namespace xla { V(kSort, "sort", kHloOpcodeIsVariadic) \ V(kSqrt, "sqrt", 1) \ V(kStochasticConvert, "stochastic-convert", 2) \ - V(kCbrt, "cbrt", 1) \ V(kSubtract, "subtract", 2) \ + V(kTan, "tan", 1) \ V(kTanh, "tanh", 1) \ V(kTranspose, "transpose", 1) \ V(kTriangularSolve, "triangular-solve", 2) \ V(kTuple, "tuple", kHloOpcodeIsVariadic) \ - V(kWhile, "while", 1) -// LINT.ThenChange(../../mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td) + V(kWhile, "while", 1) \ + V(kXor, "xor", 2) \ + /* go/keep-sorted end */ +// LINT.ThenChange(../../mlir_hlo/mhlo/IR/hlo_ops.td) enum class HloOpcode { #define DECLARE_ENUM(enum_name, opcode_name, ...) enum_name, @@ -219,7 +222,7 @@ inline bool HloOpcodeIsBinaryCommutative(HloOpcode opcode) { } // Returns the number of HloOpcode values. -inline const uint32_t HloOpcodeCount() { +inline constexpr uint32_t HloOpcodeCount() { #define HLO_COUNT_ONE(...) +1 #define HLO_XLIST_LENGTH(list) list(HLO_COUNT_ONE) return HLO_XLIST_LENGTH(HLO_OPCODE_LIST); diff --git a/tensorflow/compiler/xla/hlo/ir/hlo_schedule.cc b/tensorflow/compiler/xla/hlo/ir/hlo_schedule.cc index 7be9b10bdf7..ea231dcd9b5 100644 --- a/tensorflow/compiler/xla/hlo/ir/hlo_schedule.cc +++ b/tensorflow/compiler/xla/hlo/ir/hlo_schedule.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/str_format.h" @@ -365,23 +366,27 @@ std::string HloSchedule::ToString() const { std::vector pieces; pieces.push_back("HloSchedule"); + std::vector sorted_ids; for (const auto& id_sequence : sequences_) { - const HloComputation* computation = - IdToComputation(module_, id_sequence.first); + sorted_ids.push_back(id_sequence.first); + } + absl::c_sort(sorted_ids); + + for (const int64_t id : sorted_ids) { + const HloComputation* computation = IdToComputation(module_, id); + const HloInstructionSequence& sequence = sequences_.at(id); if (computation == nullptr) { // The computation is not in the module and may have been deleted so it is // not safe to dereference any HLO pointers. Just use the HLO unique ids // stored in this object. 
- pieces.push_back( - absl::StrFormat("computation with id %d (no longer in HLO module):", - id_sequence.first)); - for (int id : id_sequence.second.ids()) { + pieces.push_back(absl::StrFormat( + "computation with id %d (no longer in HLO module):", id)); + for (int id : sequence.ids()) { pieces.push_back(absl::StrCat(" ", id)); } } else { pieces.push_back(absl::StrFormat("computation %s:", computation->name())); - for (const HloInstruction* instruction : - id_sequence.second.instructions()) { + for (const HloInstruction* instruction : sequence.instructions()) { pieces.push_back(absl::StrCat(" ", instruction->name())); } } @@ -390,8 +395,7 @@ std::string HloSchedule::ToString() const { } std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule) { - out << schedule.ToString(); - return out; + return out << schedule.ToString(); } } // namespace xla diff --git a/tensorflow/compiler/xla/hlo/ir/hlo_sharding.cc b/tensorflow/compiler/xla/hlo/ir/hlo_sharding.cc index 430697b519c..75519576c55 100644 --- a/tensorflow/compiler/xla/hlo/ir/hlo_sharding.cc +++ b/tensorflow/compiler/xla/hlo/ir/hlo_sharding.cc @@ -36,6 +36,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/tsl/platform/errors.h" namespace xla { @@ -95,9 +96,13 @@ HloSharding HloSharding::PartialTile( return HloSharding(fully_tiled, /*replicate_on_last_tile_dim=*/false, metadata); } - std::vector> sorted_groups( - tile_assignment_last_dim_replicate.num_elements() / - tile_assignment_last_dim_replicate.dimensions().back()); + std::vector sorted_groups( + tile_assignment_last_dim_replicate.num_elements()); + const int64_t group_size = + tile_assignment_last_dim_replicate.dimensions().back(); + const int64_t num_groups = + tile_assignment_last_dim_replicate.num_elements() / group_size; + std::vector current_group_idx(num_groups, 0); auto get_group_id = [&](absl::Span indices) { int64_t group_id = 0; for (int64_t i = 0; i < indices.size() - 1; ++i) { @@ -108,14 +113,20 @@ HloSharding HloSharding::PartialTile( }; tile_assignment_last_dim_replicate.Each( [&](absl::Span indices, const int64_t device) { - sorted_groups[get_group_id(indices)].insert(device); + const int64_t group_id = get_group_id(indices); + sorted_groups[group_id * group_size + current_group_idx[group_id]++] = + device; }); + for (int i = 0; i < num_groups; ++i) { + std::sort(sorted_groups.begin() + i * group_size, + sorted_groups.begin() + (i + 1) * group_size); + } + absl::c_fill(current_group_idx, 0); Array sorted_tile(tile_assignment_last_dim_replicate.dimensions()); sorted_tile.Each([&](absl::Span indices, int64_t* device) { const int64_t group_id = get_group_id(indices); - auto begin = sorted_groups[group_id].begin(); - *device = *begin; - sorted_groups[group_id].erase(begin); + *device = + sorted_groups[group_id * group_size + current_group_idx[group_id]++]; }); return HloSharding(sorted_tile, /*replicate_on_last_tile_dim=*/true, metadata); @@ -514,7 +525,14 @@ StatusOr HloSharding::GetTupleSharding(const Shape& shape) const { TF_RETURN_IF_ERROR(CheckLeafCount(shape)); return *this; } - return Tuple(ShapeTree(shape, *this)); + return SingleTuple(shape, *this); +} + +HloSharding HloSharding::NormalizeTupleSharding(const Shape& shape) const { + if (shape.IsTuple() && !IsTuple()) { + return HloSharding::SingleTuple(shape, *this); + } + return *this; } std::optional HloSharding::UniqueDevice() const { @@ -545,7 
+563,7 @@ int64_t HloSharding::GetUniqueDevice() const { } Status HloSharding::ValidateTuple(const Shape& shape, - int64_t num_devices) const { + std::optional num_devices) const { if (!shape.IsTuple()) { return tsl::errors::InvalidArgument( StrCat("Sharding is tuple-shaped but validation shape is not.")); @@ -573,7 +591,8 @@ Status HloSharding::ValidateTuple(const Shape& shape, return OkStatus(); } -Status HloSharding::Validate(const Shape& shape, int64_t num_devices) const { +Status HloSharding::Validate(const Shape& shape, + std::optional num_devices) const { if (shape.IsToken()) { return OkStatus(); } @@ -588,7 +607,7 @@ Status HloSharding::Validate(const Shape& shape, int64_t num_devices) const { } Status HloSharding::ValidateNonTuple(const Shape& shape, - int64_t num_devices) const { + std::optional num_devices) const { if (shape.IsTuple()) { return tsl::errors::InvalidArgument( StrCat("Validation shape is a tuple but sharding is not.")); @@ -597,42 +616,44 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, return OkStatus(); } - // All tile assignments must be less than the number of available cores and + // All tile assignments must be less than the number of available devices and // unique. - Status status = OkStatus(); - absl::flat_hash_set seen_cores; - tile_assignment_.Each([&](absl::Span indices, int32_t core) { - // Don't overwrite a bad status, so we report the first error. - if (status.ok()) { - if (core >= num_devices) { - status = tsl::errors::InvalidArgument( - StrCat("core ", core, " > ", num_devices, " in tile assignment")); - } else if (seen_cores.contains(core)) { - status = tsl::errors::InvalidArgument( - StrCat("core ", core, " is not unique in tile assignment")); - } - seen_cores.insert(core); - } - }); - if (!status.ok()) { - return status; - } + absl::flat_hash_set seen_devices; + Status status = tile_assignment_.EachStatus( + [&num_devices, &seen_devices](absl::Span /*indices*/, + int32_t device) { + if (num_devices.has_value() && device >= *num_devices) { + return tsl::errors::InvalidArgument( + StrCat("device ", device, " > num_devices (", *num_devices, + ") in tile assignment")); + } else if (seen_devices.contains(device)) { + return tsl::errors::InvalidArgument( + StrCat("device ", device, " is not unique in tile assignment")); + } + seen_devices.insert(device); + return OkStatus(); + }); + TF_RETURN_IF_ERROR(status); if (IsTileMaximal() || IsManual()) { return OkStatus(); } - // The tile assignment tensor must have the same rank as the input, or input - // rank + 1 for replicate_on_last_tile_dim_. - if (shape.rank() + (replicate_on_last_tile_dim_ ? 1 : 0) + - subgroup_types_.size() != - tile_assignment_.num_dimensions()) { + // The tile assignment tensor must have the same rank as the tiled data rank. + if (shape.rank() != TiledDataRank()) { return tsl::errors::InvalidArgument( - "Number of tile assignment dimensions is different to the input rank. " + "Number of tile assignment dimensions (excluding subgroups) is " + "different than the input rank. " "sharding=", ToString(), ", input_shape=", ShapeUtil::HumanString(shape)); } + // All devices should be seen in the tile assignment. + if (num_devices.has_value() && seen_devices.size() != *num_devices) { + return tsl::errors::InvalidArgument("tile_assignment should have ", + *num_devices, " devices"); + } + // The correct constructor has to be used to create tile maximal shardings. 
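// Illustrative aside (standalone toy, not from the patch): the PartialTile
// change above replaces one std::set per replication group with a single flat
// vector of num_groups * group_size device ids and sorts each group's span in
// place. The same idea in isolation:
#include <algorithm>
#include <cstdint>
#include <vector>

void SortEachGroup(std::vector<int64_t>& devices, int64_t num_groups,
                   int64_t group_size) {
  for (int64_t g = 0; g < num_groups; ++g) {
    std::sort(devices.begin() + g * group_size,
              devices.begin() + (g + 1) * group_size);
  }
}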
if (tile_assignment_.num_elements() == 1) { return tsl::errors::InvalidArgument( diff --git a/tensorflow/compiler/xla/hlo/ir/hlo_sharding.h b/tensorflow/compiler/xla/hlo/ir/hlo_sharding.h index 8ea57c9d9c6..83775be7079 100644 --- a/tensorflow/compiler/xla/hlo/ir/hlo_sharding.h +++ b/tensorflow/compiler/xla/hlo/ir/hlo_sharding.h @@ -121,7 +121,8 @@ class HloSharding { std::string ToString(bool include_metadata = false) const; // Validate that this sharding can be applied to a tensor with shape `shape`. - Status Validate(const Shape& shape, int64_t num_devices) const; + Status Validate(const Shape& shape, + std::optional num_devices = {}) const; // Returns true if the sharding has tuple type. bool IsTuple() const { return tuple_; } @@ -248,6 +249,12 @@ class HloSharding { // having this object sharding. StatusOr GetTupleSharding(const Shape& shape) const; + // If the shape is tuple and the current sharding is not a tuple, attempt to + // construct a sharding that is compatible with the shape by replicating the + // current sharding across all tuple elements. Note that the returned + // sharding is not guaranteed to be compatible with the input shape. + HloSharding NormalizeTupleSharding(const Shape& shape) const; + // Extracts the sharding that is common within the current sharding. // If the current sharding is not a tuple sharding, the current sharding will // be returned. If it is a tuple, and all the tuple elements are common, the @@ -363,13 +370,13 @@ class HloSharding { private: explicit HloSharding(bool manual, bool replicated, absl::Span metadata) - : replicated_(replicated), + : tile_assignment_({0}), + metadata_(metadata.begin(), metadata.end()), + replicated_(replicated), maximal_(replicated), tuple_(false), manual_(manual), - tile_assignment_({0}), - replicate_on_last_tile_dim_(false), - metadata_(metadata.begin(), metadata.end()) {} + replicate_on_last_tile_dim_(false) {} // device_id values: // -2: magic number to mean unassigned device, used by spatial partitioning // -1: the id of the host @@ -377,41 +384,41 @@ class HloSharding { // NOTE(dimvar): -1 is needed for outside compilation. It can be removed once // we have fully switched to the side-effect tokens. 
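// Illustrative aside: with num_devices now optional in Validate() above, a
// sharding can be checked against a shape before the target device count is
// known. Assumes namespace xla; the device count of 8 is an arbitrary example.
Status CheckSharding(const HloSharding& sharding, const Shape& shape) {
  // Structural checks only (tile-assignment rank, device uniqueness).
  TF_RETURN_IF_ERROR(sharding.Validate(shape));
  // Full check once the topology is known.
  return sharding.Validate(shape, /*num_devices=*/8);
}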
explicit HloSharding(int64_t device_id, absl::Span metadata) - : replicated_(false), + : tile_assignment_({1}, device_id), + metadata_(metadata.begin(), metadata.end()), + replicated_(false), maximal_(true), tuple_(false), manual_(false), - tile_assignment_({1}, device_id), - replicate_on_last_tile_dim_(false), - metadata_(metadata.begin(), metadata.end()) {} + replicate_on_last_tile_dim_(false) {} explicit HloSharding(const Array& tile_assignment, bool replicate_on_last_tile_dim, absl::Span metadata = {}) - : replicated_(false), + : tile_assignment_(tile_assignment), + metadata_(metadata.begin(), metadata.end()), + replicated_(false), maximal_(false), tuple_(false), manual_(false), - tile_assignment_(tile_assignment), - replicate_on_last_tile_dim_(replicate_on_last_tile_dim), - metadata_(metadata.begin(), metadata.end()) {} + replicate_on_last_tile_dim_(replicate_on_last_tile_dim) {} explicit HloSharding(const Array& tile_assignment, absl::Span subgroup_types, absl::Span metadata = {}) - : replicated_(false), + : tile_assignment_(tile_assignment), + metadata_(metadata.begin(), metadata.end()), + subgroup_types_(subgroup_types.begin(), subgroup_types.end()), + replicated_(false), maximal_(false), tuple_(false), manual_(false), - tile_assignment_(tile_assignment), - replicate_on_last_tile_dim_(false), - metadata_(metadata.begin(), metadata.end()), - subgroup_types_(subgroup_types.begin(), subgroup_types.end()) {} + replicate_on_last_tile_dim_(false) {} explicit HloSharding(const std::vector& tuple_shardings) - : replicated_(false), + : tile_assignment_({0}), + tuple_elements_(tuple_shardings), + replicated_(false), maximal_(false), tuple_(true), manual_(false), - tile_assignment_({0}), - tuple_elements_(tuple_shardings), replicate_on_last_tile_dim_(false) {} // Checks that the number of elements in tuple_elements_ is consistent with @@ -419,15 +426,13 @@ class HloSharding { Status CheckLeafCount(const Shape& shape) const; // Internal helper to validate a tuple sharding. - Status ValidateTuple(const Shape& shape, int64_t num_devices) const; + Status ValidateTuple(const Shape& shape, + std::optional num_devices) const; // Internal helper to validate a non-tuple (leaf) sharding. - Status ValidateNonTuple(const Shape& shape, int64_t num_devices) const; + Status ValidateNonTuple(const Shape& shape, + std::optional num_devices) const; - bool replicated_; - bool maximal_; - bool tuple_; - bool manual_; // This field is only used if replicated_ is false. If maximal_ is true, then // the field contains a rank 1 array with a single element, which is the // device the HLO is assigned to. If maximal_ is false, the field contains an @@ -450,7 +455,6 @@ class HloSharding { // true, tile_assignment_ will have an extra dimension in addition to the data // shape rank, and the added last dimension represents the subgroups of // replications, i.e., elements in slice [..., :] will be replicated. - bool replicate_on_last_tile_dim_; // This field is used to track the source of this sharding, usually derived // from instructions. Multiple metadata may be populated if sharding is // combined with other shardings. Metadata are to not be populated when @@ -464,6 +468,11 @@ class HloSharding { // When creating HloSharding, subgroup dims of the same type will be merged, // so that there is at most one dim with a given type. 
std::vector subgroup_types_; + bool replicated_; + bool maximal_; + bool tuple_; + bool manual_; + bool replicate_on_last_tile_dim_; }; std::ostream& operator<<(std::ostream& out, const HloSharding& sharding); diff --git a/tensorflow/compiler/xla/hlo/ir/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/hlo/ir/hlo_sharding_metadata.cc index 37c3ad0e541..24912eb98e2 100644 --- a/tensorflow/compiler/xla/hlo/ir/hlo_sharding_metadata.cc +++ b/tensorflow/compiler/xla/hlo/ir/hlo_sharding_metadata.cc @@ -128,7 +128,7 @@ Status FixupPassThroughDomainLinks(const DomainMetadata::Domain& domain, HloInstruction* gte = pass_through.operand->parent()->AddInstruction( HloInstruction::CreateGetTupleElement(pass_through.operand->shape(), tuple, 0)); - gte->set_sharding(sharding); + gte->set_sharding(sharding.NormalizeTupleSharding(gte->shape())); if (pass_through.user != nullptr) { TF_RETURN_IF_ERROR( pass_through.operand->ReplaceUseWith(pass_through.user, gte)); @@ -139,8 +139,8 @@ Status FixupPassThroughDomainLinks(const DomainMetadata::Domain& domain, return OkStatus(); } -// For tuple shardings if every element have the same sharsing then we want to -// treat them as single element sharsings to insert less domain separation as a +// For tuple shardings if every element have the same sharding then we want to +// treat them as single element shardings to insert less domain separation as a // domain can prevent some optimizations and we want to minimize that from // happening. std::shared_ptr CloneShardingForDomain( diff --git a/tensorflow/compiler/xla/hlo/transforms/BUILD b/tensorflow/compiler/xla/hlo/transforms/BUILD new file mode 100644 index 00000000000..695cc44e295 --- /dev/null +++ b/tensorflow/compiler/xla/hlo/transforms/BUILD @@ -0,0 +1,39 @@ +# Description: +# Implementation of XLA’s HLO transformations. + +load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") +load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_test") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], + licenses = ["notice"], +) + +package_group( + name = "friends", + includes = [ + "//tensorflow/compiler/xla:friends", + ], +) + +cc_library( + name = "hlo_constant_splitter", + srcs = ["hlo_constant_splitter.cc"], + hdrs = ["hlo_constant_splitter.h"], + deps = ["//tensorflow/compiler/xla/service:hlo_pass"], +) + +xla_cc_test( + name = "hlo_constant_splitter_test", + srcs = ["hlo_constant_splitter_test.cc"], + deps = [ + ":hlo_constant_splitter", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/tsl/lib/core:status_test_util", + ], +) diff --git a/tensorflow/compiler/xla/service/hlo_constant_splitter.cc b/tensorflow/compiler/xla/hlo/transforms/hlo_constant_splitter.cc similarity index 98% rename from tensorflow/compiler/xla/service/hlo_constant_splitter.cc rename to tensorflow/compiler/xla/hlo/transforms/hlo_constant_splitter.cc index b44a4dcc24a..a1ae36ebcfa 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_splitter.cc +++ b/tensorflow/compiler/xla/hlo/transforms/hlo_constant_splitter.cc @@ -12,7 +12,7 @@ See the License for the specific language governing permissions and limitations under the License. 
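// Illustrative aside: after the move above, the constant splitter is included
// from its new hlo/transforms location. The pipeline registration below is a
// hypothetical usage sketch, not code from this patch.
#include "tensorflow/compiler/xla/hlo/transforms/hlo_constant_splitter.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"

void AddConstantSplitter(xla::HloPassPipeline* pipeline) {
  // Splits constants shared by several users so later passes can handle each
  // use independently.
  pipeline->AddPass<xla::HloConstantSplitter>();
}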
==============================================================================*/ -#include "tensorflow/compiler/xla/service/hlo_constant_splitter.h" +#include "tensorflow/compiler/xla/hlo/transforms/hlo_constant_splitter.h" #include #include diff --git a/tensorflow/compiler/xla/service/hlo_constant_splitter.h b/tensorflow/compiler/xla/hlo/transforms/hlo_constant_splitter.h similarity index 88% rename from tensorflow/compiler/xla/service/hlo_constant_splitter.h rename to tensorflow/compiler/xla/hlo/transforms/hlo_constant_splitter.h index 56f222a4608..f4d3117068c 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_splitter.h +++ b/tensorflow/compiler/xla/hlo/transforms/hlo_constant_splitter.h @@ -12,8 +12,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CONSTANT_SPLITTER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CONSTANT_SPLITTER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_HLO_TRANSFORMS_HLO_CONSTANT_SPLITTER_H_ +#define TENSORFLOW_COMPILER_XLA_HLO_TRANSFORMS_HLO_CONSTANT_SPLITTER_H_ #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -43,4 +43,4 @@ class HloConstantSplitter : public HloModulePass { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CONSTANT_SPLITTER_H_ +#endif // TENSORFLOW_COMPILER_XLA_HLO_TRANSFORMS_HLO_CONSTANT_SPLITTER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_constant_splitter_test.cc b/tensorflow/compiler/xla/hlo/transforms/hlo_constant_splitter_test.cc similarity index 98% rename from tensorflow/compiler/xla/service/hlo_constant_splitter_test.cc rename to tensorflow/compiler/xla/hlo/transforms/hlo_constant_splitter_test.cc index 8a75714f2ff..e0210d2f8c2 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_splitter_test.cc +++ b/tensorflow/compiler/xla/hlo/transforms/hlo_constant_splitter_test.cc @@ -12,7 +12,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/hlo_constant_splitter.h" +#include "tensorflow/compiler/xla/hlo/transforms/hlo_constant_splitter.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" diff --git a/tensorflow/compiler/xla/index_util.cc b/tensorflow/compiler/xla/index_util.cc index 3687dbc3e76..ac35fd0567b 100644 --- a/tensorflow/compiler/xla/index_util.cc +++ b/tensorflow/compiler/xla/index_util.cc @@ -25,70 +25,6 @@ limitations under the License. namespace xla { -/* static */ int64_t IndexUtil::MultidimensionalIndexToLinearIndex( - const Shape& shape, absl::Span multi_index) { - DCHECK_EQ(shape.dimensions_size(), multi_index.size()); - - for (size_t i = 0; i < multi_index.size(); ++i) { - DCHECK_GE(multi_index[i], 0); - DCHECK_LT(multi_index[i], shape.dimensions(i)) - << "indexing beyond extent in dimension " << i << ":" - << "\n\tindex: " << absl::StrJoin(multi_index, ",") - << "\n\tshape: " << ShapeUtil::HumanString(shape); - } - - // Let the array be sized like so for dimensions i from 0 to n-1: - // - // [D{n-1} x D{n-2} x .. x D{0}] - // - // Let the order of the dimensions in the minor_to_major field in - // Layout be: - // - // L(0), L(1), ... , L(n-1) - // - // where L(0) is the most-minor dimension and L(n-1) the most-major. The - // multidimensional index: - // - // [I{0}, I{1}, ... 
, I{n-1}] - // - // then corresponds to the following linear index: - // - // linear_index = - // ((( ... + I{L(2)}) * D{L(1)} + I{L(1)}) * D{L(0)} + I{L(0)} - // - // or equivalently: - // - // linear_index = - // I{L(n-1)} * (D{L(n-2)} * D{L(n-3)} * D{L(n-4)} * .... D{L(0)}) + - // I{L(n-2)} * (D{L(n-3)} * D{L(n-4)} * .... D{L(0)}) + - // I{L(n-3)} * (D{L(n-4)} * .... D{L(0)}) + - // ... + - // I{L(2)} * (D{L(1)} * D{L(0)}) + - // I{L(1)} * D{L(0)} + - // I{L(0)} - // - // We compute the linear index value by accumulating the terms above from - // I{L(0)} up to I{L(n-1)}. Scale accumulates the product term D{L(0}} * - // D{L(1)} * ... - - // Scale factor holding the growing product of D{L(i)} terms. - int64_t scale = 1; - int64_t linear_index = 0; - bool first = true; - for (auto dimension : LayoutUtil::MinorToMajor(shape)) { - if (first) { - // Avoid two multiplies on the first loop iteration - linear_index = multi_index[dimension]; - scale = shape.dimensions(dimension); - first = false; - } else { - linear_index += scale * multi_index[dimension]; - scale *= shape.dimensions(dimension); - } - } - return linear_index; -} - /* static */ std::vector IndexUtil::LinearIndexToMultidimensionalIndex( const Shape& shape, int64_t linear_index) { DCHECK_GE(linear_index, 0); diff --git a/tensorflow/compiler/xla/index_util.h b/tensorflow/compiler/xla/index_util.h index b3a7cfb5d03..be01a96f0d0 100644 --- a/tensorflow/compiler/xla/index_util.h +++ b/tensorflow/compiler/xla/index_util.h @@ -22,6 +22,7 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -34,8 +35,62 @@ class IndexUtil { // Converts a multidimensional index (eg {x, y, z}) into a linear index based // on the shape and its layout. The first index in the multi_index is // dimension 0. - static int64_t MultidimensionalIndexToLinearIndex( - const Shape& shape, absl::Span multi_index); + static inline int64_t MultidimensionalIndexToLinearIndex( + const Shape& shape, absl::Span multi_index) { + // Let the array be sized like so for dimensions i from 0 to n-1: + // + // [D{n-1} x D{n-2} x .. x D{0}] + // + // Let the order of the dimensions in the minor_to_major field in + // Layout be: + // + // L(0), L(1), ... , L(n-1) + // + // where L(0) is the most-minor dimension and L(n-1) the most-major. The + // multidimensional index: + // + // [I{0}, I{1}, ... , I{n-1}] + // + // then corresponds to the following linear index: + // + // linear_index = + // ((( ... + I{L(2)}) * D{L(1)} + I{L(1)}) * D{L(0)} + I{L(0)} + // + // or equivalently: + // + // linear_index = + // I{L(n-1)} * (D{L(n-2)} * D{L(n-3)} * D{L(n-4)} * .... D{L(0)}) + + // I{L(n-2)} * (D{L(n-3)} * D{L(n-4)} * .... D{L(0)}) + + // I{L(n-3)} * (D{L(n-4)} * .... D{L(0)}) + + // ... + + // I{L(2)} * (D{L(1)} * D{L(0)}) + + // I{L(1)} * D{L(0)} + + // I{L(0)} + // + // We compute the linear index value by accumulating the terms above from + // I{L(0)} up to I{L(n-1)}. Scale accumulates the product term D{L(0}} * + // D{L(1)} * ... + + // Scale factor holding the growing product of D{L(i)} terms. 
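    // Worked example (illustrative values, not from the original comment): for
    // a shape with dimensions [2, 3, 4] and minor_to_major = {2, 1, 0}
    // (row-major), the multi-index {1, 2, 3} accumulates
    //   linear_index = 3,              scale = 4
    //   linear_index = 3 + 4 * 2 = 11, scale = 12
    //   linear_index = 11 + 12 * 1 = 23
    // which matches the closed form 1 * (3 * 4) + 2 * 4 + 3 = 23.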
+ for (size_t i = 0; i < multi_index.size(); ++i) { + DCHECK_GE(multi_index[i], 0); + DCHECK_LT(multi_index[i], shape.dimensions(i)) + << "indexing beyond extent in dimension " << i << ":" + << "\n\tindex: " << absl::StrJoin(multi_index, ",") + << "\n\tshape: " << ShapeUtil::HumanString(shape); + } + auto effective_shape = LayoutUtil::MinorToMajor(shape); + if (effective_shape.empty()) { + return 0; + } + int64_t linear_index = multi_index[effective_shape[0]]; + int64_t scale = 1; + for (int i = 1; i < effective_shape.size(); ++i) { + scale *= shape.dimensions(effective_shape[i - 1]); + linear_index += scale * multi_index[effective_shape[i]]; + } + return linear_index; + } // Converts a linear index into multidimensional index (eg {x, y, z}) based on // the shape and its layout. The first index in the returned multidimensional diff --git a/tensorflow/compiler/xla/layout.cc b/tensorflow/compiler/xla/layout.cc index 629336d9442..ec842216d86 100644 --- a/tensorflow/compiler/xla/layout.cc +++ b/tensorflow/compiler/xla/layout.cc @@ -25,6 +25,8 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/printer.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -38,22 +40,30 @@ TileProto Tile::ToProto() const { return tile_proto; } -std::string Tile::ToString() const { - std::vector elements; +void Tile::Print(Printer* printer) const { + printer->Append("("); const auto& dims = dimensions(); - elements.reserve(dims.size()); - for (auto dim : dims) { + for (int i = 0; i < dims.size(); ++i) { + const auto dim = dims[i]; + if (i != 0) printer->Append(","); if (dim >= 0) { - elements.push_back(std::to_string(dim)); + printer->Append(absl::AlphaNum(dim).Piece()); } else { if (dim == kCombineDimension) { - elements.push_back("*"); + printer->Append("*"); } else { - elements.push_back(absl::StrCat("Invalid value ", dim)); + printer->Append("Invalid value "); + printer->Append(absl::AlphaNum(dim).Piece()); } } } - return absl::StrCat("(", absl::StrJoin(elements, ","), ")"); + printer->Append(")"); +} + +std::string Tile::ToString() const { + StringPrinter printer; + Print(&printer); + return std::move(printer).ToString(); } Layout::Layout() = default; @@ -67,7 +77,8 @@ Layout::Layout(absl::Span minor_to_major, absl::Span dim_ordered, absl::Span tiles, PrimitiveType index_primitive_type, PrimitiveType pointer_primitive_type, int64_t memory_space, - std::unique_ptr physical_shape) + std::unique_ptr physical_shape, + int64_t dynamic_shape_metadata_prefix_bytes) : dim_level_types_(dim_level_types.begin(), dim_level_types.end()), dim_unique_(dim_unique.begin(), dim_unique.end()), dim_ordered_(dim_ordered.begin(), dim_ordered.end()), @@ -76,7 +87,9 @@ Layout::Layout(absl::Span minor_to_major, index_primitive_type_(index_primitive_type), pointer_primitive_type_(pointer_primitive_type), memory_space_(memory_space), - physical_shape_(std::move(physical_shape)) {} + physical_shape_(std::move(physical_shape)), + dynamic_shape_metadata_prefix_bytes_( + dynamic_shape_metadata_prefix_bytes) {} Layout::Layout(const Layout& other) : dim_level_types_(other.dim_level_types_), @@ -89,7 +102,9 @@ Layout::Layout(const Layout& other) memory_space_(other.memory_space_), physical_shape_(other.physical_shape_ != nullptr ? 
std::make_unique(*other.physical_shape_) - : nullptr) {} + : nullptr), + dynamic_shape_metadata_prefix_bytes_( + other.dynamic_shape_metadata_prefix_bytes_) {} Layout::Layout(Layout&& other) = default; @@ -110,6 +125,8 @@ Layout& Layout::operator=(const Layout& other) { } else { physical_shape_ = nullptr; } + dynamic_shape_metadata_prefix_bytes_ = + other.dynamic_shape_metadata_prefix_bytes_; } return *this; } @@ -140,6 +157,8 @@ Layout& Layout::operator=(Layout&& other) = default; if (proto.has_physical_shape()) { *layout.mutable_physical_shape() = Shape(proto.physical_shape()); } + layout.set_dynamic_shape_metadata_prefix_bytes( + proto.dynamic_shape_metadata_prefix_bytes()); return layout; } @@ -167,6 +186,8 @@ LayoutProto Layout::ToProto() const { if (has_physical_shape()) { *proto.mutable_physical_shape() = physical_shape_->ToProto(); } + proto.set_dynamic_shape_metadata_prefix_bytes( + dynamic_shape_metadata_prefix_bytes_); return proto; } @@ -185,58 +206,96 @@ absl::string_view DimLevelTypeAbbrev(DimLevelType dim_level_type) { } } // namespace -std::string Layout::ToString() const { - std::string colon_string; +void Layout::Print(Printer* printer) const { + printer->Append("{"); + printer->Append(absl::StrJoin(minor_to_major(), ",")); + + bool colon_printed = false; + auto print_colon = [&]() { + if (colon_printed) return; + printer->Append(":"); + colon_printed = true; + }; if (!dim_level_types().empty()) { - absl::StrAppend(&colon_string, "D("); + print_colon(); + printer->Append("D("); for (int i = 0; i < dim_level_types().size(); ++i) { if (i != 0) { - absl::StrAppend(&colon_string, ","); + printer->Append(","); } - absl::StrAppend(&colon_string, DimLevelTypeAbbrev(dim_level_type(i))); + printer->Append(DimLevelTypeAbbrev(dim_level_type(i))); if (!dim_unique().empty() && !dim_unique(i)) { - absl::StrAppend(&colon_string, "+"); + printer->Append("+"); } if (!dim_ordered().empty() && !dim_ordered(i)) { - absl::StrAppend(&colon_string, "~"); + printer->Append("~"); } } - absl::StrAppend(&colon_string, ")"); + printer->Append(")"); } if (!tiles().empty()) { - absl::StrAppend(&colon_string, "T"); + print_colon(); + printer->Append("T"); for (const Tile& tile : tiles()) { - absl::StrAppend(&colon_string, tile.ToString()); + tile.Print(printer); } } if (index_primitive_type() != PRIMITIVE_TYPE_INVALID) { - absl::StrAppend( - &colon_string, "#(", - primitive_util::LowercasePrimitiveTypeName(index_primitive_type()), - ")"); + print_colon(); + if (primitive_util::IsIntegralType(index_primitive_type())) { + printer->Append("#("); + printer->Append( + primitive_util::LowercasePrimitiveTypeName(index_primitive_type())); + printer->Append(")"); + } else { + printer->Append("#(invalid)"); + } } if (pointer_primitive_type() != PRIMITIVE_TYPE_INVALID) { - absl::StrAppend( - &colon_string, "*(", - primitive_util::LowercasePrimitiveTypeName(pointer_primitive_type()), - ")"); + print_colon(); + if (primitive_util::IsIntegralType(pointer_primitive_type())) { + printer->Append("*("); + printer->Append( + primitive_util::LowercasePrimitiveTypeName(pointer_primitive_type())); + printer->Append(")"); + } else { + printer->Append("*(invalid)"); + } } if (memory_space() != 0) { - absl::StrAppend(&colon_string, "S(", memory_space(), ")"); + print_colon(); + printer->Append("S("); + printer->Append(absl::AlphaNum(memory_space()).Piece()); + printer->Append(")"); } if (has_physical_shape()) { - absl::StrAppend(&colon_string, "P(", - physical_shape_->ToString(/*print_layout=*/true), ")"); + print_colon(); + 
printer->Append("P("); + physical_shape_->Print(printer, /*print_layout=*/true); + printer->Append(")"); + } + + if (dynamic_shape_metadata_prefix_bytes_ > 0) { + print_colon(); + printer->Append("M("); + printer->Append( + absl::AlphaNum(dynamic_shape_metadata_prefix_bytes()).Piece()); + printer->Append(")"); } - return absl::StrCat("{", absl::StrJoin(minor_to_major(), ","), - colon_string.empty() ? "" : ":", colon_string, "}"); + printer->Append("}"); +} + +std::string Layout::ToString() const { + StringPrinter printer; + Print(&printer); + return std::move(printer).ToString(); } bool Layout::Equal::operator()(const Layout& lhs, const Layout& rhs) { diff --git a/tensorflow/compiler/xla/layout.h b/tensorflow/compiler/xla/layout.h index d1e4d892763..7e266e46038 100644 --- a/tensorflow/compiler/xla/layout.h +++ b/tensorflow/compiler/xla/layout.h @@ -23,6 +23,7 @@ limitations under the License. #include #include "absl/container/inlined_vector.h" +#include "tensorflow/compiler/xla/printer.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -50,6 +51,8 @@ class Tile { } bool operator!=(const Tile& other) const { return !(*this == other); } + void Print(Printer* printer) const; + std::string ToString() const; // Returns the bound of the tile in the given dimension index. @@ -103,7 +106,8 @@ class Layout { PrimitiveType index_primitive_type = PRIMITIVE_TYPE_INVALID, PrimitiveType element_primitive_type = PRIMITIVE_TYPE_INVALID, int64_t memory_space = 0, - std::unique_ptr physical_shape = nullptr); + std::unique_ptr physical_shape = nullptr, + int64_t dynamic_shape_metadata_prefix_bytes = 0); Layout& operator=(const Layout& other); Layout& operator=(Layout&& other); @@ -114,6 +118,9 @@ class Layout { // Returns a LayoutProto representation of the Layout. LayoutProto ToProto() const; + // Prints a human-readable string that represents this layout. + void Print(Printer* printer) const; + // Returns a human-readable string that represents this layout. std::string ToString() const; @@ -306,6 +313,13 @@ class Layout { Shape* mutable_physical_shape(); void clear_physical_shape(); + int64_t dynamic_shape_metadata_prefix_bytes() const { + return dynamic_shape_metadata_prefix_bytes_; + } + void set_dynamic_shape_metadata_prefix_bytes(int64_t bytes) { + dynamic_shape_metadata_prefix_bytes_ = bytes; + } + void Swap(Layout* other) { using std::swap; swap(*this, *other); @@ -355,6 +369,11 @@ class Layout { // The physical on-device shape used to represent a sparse array. std::unique_ptr physical_shape_; + + // The dynamic shape metadata size in bytes in front of the shape data. The + // field may be non-zero for a static shape whose associated buffer is for a + // dynamic shape, e.g. a result of SliceToDynamic. + int64_t dynamic_shape_metadata_prefix_bytes_ = 0; }; std::ostream& operator<<(std::ostream& out, const Tile& Tile); diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index 1586f735a01..ff03706ef97 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -28,7 +28,9 @@ limitations under the License. 
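// Illustrative aside: what the new annotations printed above look like for a
// small layout. The exact string is inferred from the Print() code and is an
// assumption, not a quoted test expectation.
std::string DescribeLayout() {
  xla::Layout layout = xla::LayoutUtil::MakeLayout({1, 0});
  layout.set_memory_space(1);
  layout.set_dynamic_shape_metadata_prefix_bytes(64);
  // Expected to read "{1,0:S(1)M(64)}": minor-to-major order first, then a
  // single colon, then the memory-space and dynamic-metadata annotations.
  return layout.ToString();
}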
#include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/printer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/tsl/platform/logging.h" @@ -49,6 +51,8 @@ void SetDefaultLayoutToContainer(T* minor_to_major) { } } +absl::string_view BoolToString(bool b) { return b ? "true" : "false"; } + } // namespace /* static */ Layout LayoutUtil::MakeLayout( @@ -57,7 +61,8 @@ void SetDefaultLayoutToContainer(T* minor_to_major) { absl::Span dim_unique, absl::Span dim_ordered, absl::Span tiles, PrimitiveType index_primitive_type, PrimitiveType pointer_primitive_type, int64_t memory_space, - std::optional physical_shape) { + std::optional physical_shape, + int64_t dynamic_shape_metadata_prefix_bytes) { Layout layout; for (int64_t dimension_number : minor_to_major) { layout.add_minor_to_major(dimension_number); @@ -88,6 +93,8 @@ void SetDefaultLayoutToContainer(T* minor_to_major) { if (physical_shape != std::nullopt) { *layout.mutable_physical_shape() = *std::move(physical_shape); } + layout.set_dynamic_shape_metadata_prefix_bytes( + dynamic_shape_metadata_prefix_bytes); return layout; } @@ -271,6 +278,34 @@ Layout CreateDefaultLayoutForRank(int64_t rank) { } } + if (!layout.dim_unique().empty()) { + if (layout.dim_unique().size() != shape.rank()) { + return InvalidArgument( + "layout dim_unique field contains %d elements, but shape is " + "rank %d: {%s}; shape: %s", + layout.dim_unique_size(), shape.rank(), + absl::StrJoin(layout.dim_unique(), ", ", + [](std::string* out, bool dim_unique) { + absl::StrAppend(out, BoolToString(dim_unique)); + }), + shape.ShortDebugString()); + } + } + + if (!layout.dim_ordered().empty()) { + if (layout.dim_ordered().size() != shape.rank()) { + return InvalidArgument( + "layout dim_unique field contains %d elements, but shape is " + "rank %d: {%s}; shape: %s", + layout.dim_ordered_size(), shape.rank(), + absl::StrJoin(layout.dim_unique(), ", ", + [](std::string* out, bool dim_unique) { + absl::StrAppend(out, BoolToString(dim_unique)); + }), + shape.ShortDebugString()); + } + } + if (LayoutUtil::IsSparse(layout)) { if (layout.tiles_size() > 0) { return InvalidArgument( @@ -323,6 +358,26 @@ Layout CreateDefaultLayoutForRank(int64_t rank) { "layout has a physical_shape, but is not a sparse array: %s", shape.ShortDebugString()); } + for (const auto& tile : layout.tiles()) { + if (tile.dimensions().empty() || + absl::c_any_of(tile.dimensions(), + [](int64_t dim) { return dim == 0; })) { + return InvalidArgument("layout has invalid tiles: %s", + shape.ShortDebugString()); + } + } + } + + for (int64_t dim = 0; dim < shape.rank(); ++dim) { + DimLevelType dim_level_type = GetDimLevelType(layout, dim); + bool dim_unique = DimUnique(layout, dim); + bool dim_ordered = DimOrdered(layout, dim); + if (!ValidateDimLevel(dim_level_type, dim_unique, dim_ordered)) { + return InvalidArgument( + "layout dimension %d has invalid level encoding %s%s%s: %s", dim, + DimLevelType_Name(dim_level_type), dim_unique ? "" : ", non-unique", + dim_ordered ? 
"" : ", non-ordered", shape.ShortDebugString()); + } } return OkStatus(); @@ -441,32 +496,6 @@ Layout CreateDefaultLayoutForRank(int64_t rank) { return lhs == rhs; } -/* static */ absl::Span LayoutUtil::MinorToMajor( - const Shape& shape) { - CHECK(shape.IsArray()); - return shape.layout().minor_to_major(); -} - -/* static */ absl::Span LayoutUtil::MinorToMajor( - const Layout& layout) { - return layout.minor_to_major(); -} - -/* static */ int64_t LayoutUtil::Major(const Layout& layout, - int64_t physical_dimension_number) { - CHECK_LE(0, physical_dimension_number); - CHECK_LT(physical_dimension_number, layout.minor_to_major_size()); - return Minor(layout, - layout.minor_to_major_size() - 1 - physical_dimension_number); -} - -/* static */ int64_t LayoutUtil::Minor(const Layout& layout, - int64_t physical_dimension_number) { - CHECK_LE(0, physical_dimension_number); - CHECK_LT(physical_dimension_number, layout.minor_to_major_size()); - return layout.minor_to_major(physical_dimension_number); -} - /* static */ std::vector LayoutUtil::MakeLogicalToPhysical( const Layout& layout) { std::vector logical_to_physical(layout.minor_to_major_size()); @@ -478,6 +507,11 @@ Layout CreateDefaultLayoutForRank(int64_t rank) { return logical_to_physical; } +/* static */ void LayoutUtil::PrintHumanString(Printer* printer, + const Layout& layout) { + layout.Print(printer); +} + /* static */ std::string LayoutUtil::HumanString(const Layout& layout) { return layout.ToString(); } @@ -635,4 +669,42 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { : Layout::kDefaultMemorySpace; } +/*static*/ DimLevelType LayoutUtil::GetDimLevelType(const Layout& layout, + int64_t dim) { + if (layout.dim_level_types_size() == 0) { + return DIM_DENSE; + } + CHECK_LT(dim, layout.dim_level_types_size()); + return layout.dim_level_type(dim); +} + +/*static*/ bool LayoutUtil::DimUnique(const Layout& layout, int64_t dim) { + if (layout.dim_unique_size() == 0) { + return true; + } + CHECK_LT(dim, layout.dim_unique_size()); + return layout.dim_unique(dim); +} + +/*static*/ bool LayoutUtil::DimOrdered(const Layout& layout, int64_t dim) { + if (layout.dim_ordered_size() == 0) { + return true; + } + CHECK_LT(dim, layout.dim_ordered_size()); + return layout.dim_ordered(dim); +} + +bool LayoutUtil::ValidateDimLevel(DimLevelType dim_level_type, bool dim_unique, + bool dim_ordered) { + switch (dim_level_type) { + case DIM_DENSE: + return dim_unique && dim_ordered; + case DIM_COMPRESSED: + case DIM_SINGLETON: + return true; + default: + return false; + } +} + } // namespace xla diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index 702c5fb63f1..28d461f808b 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -24,6 +24,7 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/xla/layout.h" +#include "tensorflow/compiler/xla/printer.h" #include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -44,7 +45,8 @@ class LayoutUtil { PrimitiveType index_primitive_type = PRIMITIVE_TYPE_INVALID, PrimitiveType pointer_primitive_type = PRIMITIVE_TYPE_INVALID, int64_t memory_space = 0, - std::optional physical_shape = std::nullopt); + std::optional physical_shape = std::nullopt, + int64_t dynamic_shape_metadata_prefix_bytes = 0); // Similar to MakeLayout, but take indices in reverse order. 
static Layout MakeLayoutFromMajorToMinor( @@ -163,8 +165,14 @@ class LayoutUtil { // Returns the minor_to_major array for the given Shape. Requires that the // shape is an array. - static absl::Span MinorToMajor(const Shape& shape); - static absl::Span MinorToMajor(const Layout& layout); + static inline absl::Span MinorToMajor(const Shape& shape) { + DCHECK(shape.IsArray()); + return shape.layout().minor_to_major(); + } + + static inline absl::Span MinorToMajor(const Layout& layout) { + return layout.minor_to_major(); + } // Major(0) is the most major logical dimension number, Major(1) is the // second-most-major logical dimension number and so on. @@ -180,11 +188,22 @@ class LayoutUtil { // the most major. Then Major(0) is the most major logical dimension, so Major // maps the physical dimension number 0 to the most major logical dimension // number Major(0). - static int64_t Major(const Layout& layout, int64_t physical_dimension_number); + static int64_t Major(const Layout& layout, + int64_t physical_dimension_number) { + DCHECK_LE(0, physical_dimension_number); + DCHECK_LT(physical_dimension_number, layout.minor_to_major_size()); + return Minor(layout, + layout.minor_to_major_size() - 1 - physical_dimension_number); + } // Minor(0) is the most minor logical dimension number, minor(1) is the // second-most-minor logical dimension number and so on. - static int64_t Minor(const Layout& layout, int64_t physical_dimension_number); + static inline int64_t Minor(const Layout& layout, + int64_t physical_dimension_number) { + DCHECK_LE(0, physical_dimension_number); + DCHECK_LT(physical_dimension_number, layout.minor_to_major_size()); + return layout.minor_to_major(physical_dimension_number); + } // Returns the inverse mapping of the Major() function. More precisely, return // a vector v such that if l == Major(p), then v[l] == p. @@ -201,6 +220,9 @@ class LayoutUtil { // the most minor physical dimension. static std::vector MakeLogicalToPhysical(const Layout& layout); + // Prints a human-readable string that represents the given layout. + static void PrintHumanString(Printer* printer, const Layout& layout); + // Returns a human-readable string that represents the given layout. static std::string HumanString(const Layout& layout); @@ -239,6 +261,15 @@ class LayoutUtil { // returns Layout::kDefaultMemorySpace. static int64_t MemorySpace(const Shape& shape); + static xla::DimLevelType GetDimLevelType(const Layout& layout, int64_t dim); + static bool DimUnique(const Layout& layout, int64_t dim); + static bool DimOrdered(const Layout& layout, int64_t dim); + + // Return true iff the given DimLevelType and dim_unique/dim_ordered values + // represent a valid encoding. 
+ static bool ValidateDimLevel(xla::DimLevelType dim_level_type, + bool dim_unique, bool dim_ordered); + private: LayoutUtil(const LayoutUtil&) = delete; LayoutUtil& operator=(const LayoutUtil&) = delete; diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc index 6603b8a7178..3f412774d2e 100644 --- a/tensorflow/compiler/xla/layout_util_test.cc +++ b/tensorflow/compiler/xla/layout_util_test.cc @@ -474,6 +474,13 @@ TEST_F(LayoutUtilTest, ValidateLayout_Sparse) { tsl::error::INVALID_ARGUMENT, ::testing::HasSubstr( "layout has a physical_shape, but is not a sparse array"))); + *shape.mutable_layout() = + LayoutUtil::MakeLayout({1, 0}, {DIM_DENSE, DIM_DENSE}, {true, false}); + EXPECT_THAT(LayoutUtil::ValidateLayoutInShape(shape), + tsl::testing::StatusIs( + tsl::error::INVALID_ARGUMENT, + ::testing::HasSubstr("layout dimension 1 has invalid level " + "encoding DIM_DENSE, non-unique"))); } TEST_F(LayoutUtilTest, ValidateLayout_TupleSubshapesWithMissingLayouts) { diff --git a/tensorflow/compiler/xla/lazy.h b/tensorflow/compiler/xla/lazy.h new file mode 100644 index 00000000000..3ef9c093b02 --- /dev/null +++ b/tensorflow/compiler/xla/lazy.h @@ -0,0 +1,45 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_LAZY_H_ +#define TENSORFLOW_COMPILER_XLA_LAZY_H_ + +#include + +#include "absl/functional/any_invocable.h" + +namespace xla { + +template +class Lazy { + public: + explicit Lazy(absl::AnyInvocable func) + : maybe_value_(std::move(func)) {} + + const T& get() const { + if (!std::holds_alternative(maybe_value_)) { + maybe_value_ = + std::move(std::get>(maybe_value_))(); + } + return std::get(maybe_value_); + } + + private: + mutable std::variant, T> maybe_value_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_LAZY_H_ diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 9e948438a9d..1a246789105 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -45,6 +45,7 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/tsl/platform/errors.h" +#include "tensorflow/tsl/platform/float8.h" #include "tensorflow/tsl/platform/logging.h" #include "tensorflow/tsl/platform/mem.h" #include "tensorflow/tsl/util/byte_swap_array.h" @@ -77,32 +78,6 @@ void ConvertEndianShort(char* bytes, int64_t size) { } } -std::string CompactOneline(const std::string& input) { - std::string result; - std::vector v = absl::StrSplit(input, absl::ByAnyChar("\n ")); - bool first = true; - // Concatenate elements in "v" with spaces separating them, but ignoring - // empty entries. - for (const auto& s : v) { - if (s.empty()) { - continue; - } - absl::StrAppend(&result, (first ? 
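A minimal usage sketch for the Lazy<T> helper added above in lazy.h (the wrapped callable and values here are illustrative, not from the patch):

#include <string>

#include "tensorflow/compiler/xla/lazy.h"

void LazyUsageSketch() {
  // The callable is stored until the first get(); the computed value is then
  // cached and returned by const reference on every later call.
  xla::Lazy<std::string> message([] { return std::string("computed once"); });
  const std::string& first = message.get();  // invokes the lambda
  const std::string& again = message.get();  // returns the cached value
  (void)first;
  (void)again;
}
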
"" : " "), s); - first = false; - } - return result; -} - -// Since Eigen::half doesn't satisfy the absl::bit_cast contract, we need to be -// able to transparently access the raw 16-bit value contained within. -template -T GetRawValue(T val) { - return val; -} -uint16_t GetRawValue(Eigen::half val) { - return Eigen::numext::bit_cast(val); -} - bool LiteralProtoHasValues(const LiteralProto& proto) { return proto.preds_size() || !proto.s8s().empty() || !proto.u8s().empty() || proto.s32s_size() || proto.s64s_size() || proto.u32s_size() || @@ -110,7 +85,8 @@ bool LiteralProtoHasValues(const LiteralProto& proto) { proto.c64s_size() || proto.c128s_size() || proto.tuple_literals_size() || !proto.f16s().empty() || !proto.bf16s().empty() || !proto.u16s().empty() || - !proto.s16s().empty(); + !proto.s16s().empty() || !proto.f8e5m2s().empty() || + !proto.f8e4m3fns().empty(); } // Lazy getter for the interned scalar shape in static storage. We reuse this @@ -147,6 +123,10 @@ const Shape& ScalarShape(PrimitiveType type) { return ScalarShapeImpl(); case S64: return ScalarShapeImpl(); + case F8E5M2: + return ScalarShapeImpl(); + case F8E4M3FN: + return ScalarShapeImpl(); case F16: return ScalarShapeImpl(); case BF16: @@ -675,6 +655,8 @@ Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src, COPY_ELEMENTS(S16, int16_t); COPY_ELEMENTS(S32, int32_t); COPY_ELEMENTS(S64, int64_t); + COPY_ELEMENTS(F8E5M2, tsl::float8_e5m2); + COPY_ELEMENTS(F8E4M3FN, tsl::float8_e4m3fn); COPY_ELEMENTS(F16, half); COPY_ELEMENTS(BF16, bfloat16); COPY_ELEMENTS(F32, float); @@ -835,6 +817,12 @@ Status MutableLiteralBase::CopySliceFrom(const LiteralSlice& src_literal, case S64: return CopySliceFromInternal(src_literal, src_base, dest_base, copy_size); + case F8E5M2: + return CopySliceFromInternal(src_literal, src_base, + dest_base, copy_size); + case F8E4M3FN: + return CopySliceFromInternal(src_literal, src_base, + dest_base, copy_size); case F16: return CopySliceFromInternal(src_literal, src_base, dest_base, copy_size); @@ -933,6 +921,9 @@ Literal LiteralBase::ToStatic() const { return; } for (int64_t i = 0; i < subshape->rank(); ++i) { + // GetDynamicSize has a 32-bit return type and may truncate static + // dimensions, so make sure to skip. 
+ if (!subshape->is_dynamic_dimension(i)) continue; subshape->set_dynamic_dimension(i, false); subshape->set_dimensions(i, GetDynamicSize(i, index)); } @@ -1129,6 +1120,10 @@ Literal LiteralBase::Slice(absl::Span start_indices, return SliceInternal(result_shape, start_indices); case S64: return SliceInternal(result_shape, start_indices); + case F8E5M2: + return SliceInternal(result_shape, start_indices); + case F8E4M3FN: + return SliceInternal(result_shape, start_indices); case F16: return SliceInternal(result_shape, start_indices); case BF16: @@ -1196,6 +1191,12 @@ std::string LiteralBase::GetAsString(absl::Span multi_index, return RoundTripFpToString(Get(multi_index, shape_index)); case BF16: return RoundTripFpToString(Get(multi_index, shape_index)); + case F8E5M2: + return RoundTripFpToString( + Get(multi_index, shape_index)); + case F8E4M3FN: + return RoundTripFpToString( + Get(multi_index, shape_index)); case F64: return RoundTripFpToString(Get(multi_index, shape_index)); case C64: { @@ -1244,6 +1245,10 @@ std::optional LiteralBase::GetAsDouble( absl::Span multi_index) const { CHECK(LayoutUtil::IsDenseArray(shape())); switch (shape().element_type()) { + case F8E5M2: + return static_cast(Get(multi_index)); + case F8E4M3FN: + return static_cast(Get(multi_index)); case F16: return static_cast(Get(multi_index)); case F32: @@ -1260,6 +1265,10 @@ std::optional LiteralBase::GetAsDouble( std::optional LiteralBase::GetAsComplex128( absl::Span multi_index) const { switch (shape().element_type()) { + case F8E5M2: + return {{static_cast(Get(multi_index)), 0}}; + case F8E4M3FN: + return {{static_cast(Get(multi_index)), 0}}; case BF16: return {{static_cast(Get(multi_index)), 0}}; case F16: @@ -1324,6 +1333,13 @@ Status MutableLiteralBase::SetFromDouble(absl::Span multi_index, case BF16: Set(multi_index, static_cast(value)); break; + case F8E5M2: + Set(multi_index, static_cast(value)); + break; + case F8E4M3FN: + Set(multi_index, + static_cast(value)); + break; default: return FailedPrecondition("Array element type is not floating: %s", PrimitiveType_Name(shape().element_type())); @@ -1333,46 +1349,44 @@ Status MutableLiteralBase::SetFromDouble(absl::Span multi_index, namespace { -std::string ShapeToString(bool print_layout, const Shape& shape) { - return print_layout ? ShapeUtil::HumanStringWithLayout(shape) - : ShapeUtil::HumanString(shape); +void PrintShape(bool print_layout, const Shape& shape, Printer* printer) { + if (print_layout) { + ShapeUtil::PrintHumanStringWithLayout(printer, shape); + } else { + ShapeUtil::PrintHumanString(printer, shape); + } } -void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, - bool print_shape, bool print_layout, - std::vector* pieces); +void PrintHelper(const LiteralBase& literal, const ShapeIndex& shape_index, + bool print_shape, bool print_layout, bool oneline, + Printer* printer); -void TupleToStringHelper(const LiteralBase& literal, - const ShapeIndex& shape_index, bool print_shape, - bool print_layout, std::vector* pieces) { +void TuplePrintHelper(const LiteralBase& literal, const ShapeIndex& shape_index, + bool print_shape, bool print_layout, bool oneline, + Printer* printer) { const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); - pieces->push_back("(\n"); - std::vector tuple_pieces; - const auto tuple_element_count = ShapeUtil::TupleElementCount(subshape); - tuple_pieces.reserve(tuple_element_count); + printer->Append(oneline ? 
"( " : "(\n"); for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) { ShapeIndex element_index = shape_index; element_index.push_back(i); - std::vector element_pieces; - ToStringHelper(literal, element_index, print_shape, print_layout, - &element_pieces); - tuple_pieces.push_back(absl::StrJoin(element_pieces, "")); + if (i > 0) printer->Append(oneline ? ", " : ",\n"); + PrintHelper(literal, element_index, print_shape, print_layout, oneline, + printer); } - pieces->push_back(absl::StrJoin(tuple_pieces, ",\n")); - pieces->push_back("\n)"); + printer->Append(oneline ? " )" : "\n)"); } -void DenseArrayToStringHelper(const LiteralBase& literal, - const ShapeIndex& shape_index, bool print_shape, - bool print_layout, - std::vector* pieces) { +void DenseArrayPrintHelper(const LiteralBase& literal, + const ShapeIndex& shape_index, bool print_shape, + bool print_layout, bool oneline, Printer* printer) { const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); int64_t rank = subshape.rank(); + const absl::string_view linebreak = oneline ? " " : "\n"; std::function dimensions, std::vector*)> - to_string_recursive = [&](absl::Span dimensions, - std::vector* accum_indices) { + print_recursive = [&](absl::Span dimensions, + std::vector* accum_indices) { // dimensions.size() decreases by 1 at each recursive call, // and accum_indices->size() increases by 1. // Their sum is equal to the rank of the tensor. @@ -1385,7 +1399,8 @@ void DenseArrayToStringHelper(const LiteralBase& literal, } // Handle the innermost tensor of a 2D+ tensor. if (dimensions.size() == 1 && brace == "{") { - return StrCat(" ", brace, dimensions[0] <= 1 ? "" : " "); + return StrCat(oneline ? "" : " ", brace, + dimensions[0] <= 1 ? "" : " "); } if (dimensions.size() == 1 && brace == "}") { return StrCat(dimensions[0] <= 1 ? "" : " ", brace); @@ -1397,11 +1412,13 @@ void DenseArrayToStringHelper(const LiteralBase& literal, accum_indices_size < rank) { int index = accum_indices->size() - 1; int value = accum_indices->back(); - return StrCat(brace, " /*i", index, "=", value, "*/\n"); + int size = dimensions.front(); + return StrCat(brace, " /*i", index, "=", value, "*/", + size > 0 ? linebreak : ""); } - return StrCat(brace, "\n"); + return StrCat(brace, linebreak); } - return StrCat("\n", brace); + return StrCat(linebreak, brace); }; if (dimensions.empty()) { @@ -1412,35 +1429,36 @@ void DenseArrayToStringHelper(const LiteralBase& literal, } else { elem = literal.GetAsString(*accum_indices, shape_index); } - pieces->push_back(elem); + printer->Append(elem); } else { - pieces->push_back(brace_to_string("{")); + printer->Append(brace_to_string("{")); for (int i = 0; i < dimensions[0]; ++i) { accum_indices->push_back(i); - to_string_recursive(dimensions.subspan(1), accum_indices); + print_recursive(dimensions.subspan(1), accum_indices); accum_indices->pop_back(); if (i < dimensions[0] - 1) { - pieces->push_back(","); - pieces->push_back(dimensions.size() > 1 ? "\n" : " "); + printer->Append(","); + printer->Append(dimensions.size() > 1 ? 
linebreak : " "); } } - pieces->push_back(brace_to_string("}")); + printer->Append(brace_to_string("}")); } }; if (print_shape) { - pieces->push_back(ShapeToString(print_layout, subshape)); + PrintShape(print_layout, subshape, printer); if (subshape.is_dynamic()) { - pieces->push_back("("); + printer->Append("("); for (int64_t i = 0; i < subshape.dimensions_size(); ++i) { - pieces->push_back(StrCat(literal.GetDynamicSize(i, shape_index))); + printer->Append( + absl::AlphaNum(literal.GetDynamicSize(i, shape_index)).Piece()); if (i < subshape.dimensions_size() - 1) { - pieces->push_back(","); + printer->Append(","); } } - pieces->push_back(")"); + printer->Append(")"); } - pieces->push_back(" "); + printer->Append(" "); } std::vector indices = {}; std::vector dimensions; @@ -1448,73 +1466,108 @@ void DenseArrayToStringHelper(const LiteralBase& literal, for (int64_t i = 0; i < subshape.rank(); ++i) { dimensions.push_back(literal.GetDynamicSize(i, shape_index)); } - to_string_recursive(dimensions, &indices); + print_recursive(dimensions, &indices); } -void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, - bool print_shape, bool print_layout, - std::vector* pieces) { +void PrintHelper(const LiteralBase& literal, const ShapeIndex& shape_index, + bool print_shape, bool print_layout, bool oneline, + Printer* printer) { const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); CHECK(LayoutUtil::HasLayout(literal.shape())); CHECK(LayoutUtil::HasLayout(subshape)); if (subshape.IsTuple()) { - TupleToStringHelper(literal, shape_index, print_shape, print_layout, - pieces); + TuplePrintHelper(literal, shape_index, print_shape, print_layout, oneline, + printer); } else if (subshape.IsToken()) { - pieces->push_back("token"); + printer->Append("token"); } else { CHECK(LayoutUtil::IsDenseArray(subshape)); if (literal.IsKnown(shape_index)) { - DenseArrayToStringHelper(literal, shape_index, print_shape, print_layout, - pieces); + DenseArrayPrintHelper(literal, shape_index, print_shape, print_layout, + oneline, printer); } else { - pieces->push_back(ShapeToString(print_layout, subshape)); - pieces->push_back(" "); + PrintShape(print_layout, subshape, printer); + printer->Append(" "); if (literal.IsDetermined(shape_index)) { - pieces->push_back("unknown"); + printer->Append("unknown"); } else { - pieces->push_back("undetermined"); + printer->Append("undetermined"); } } } } - } // namespace -std::string LiteralBase::ToString() const { - std::vector pieces; +void LiteralBase::Print(Printer* printer) const { + CHECK(LayoutUtil::HasLayout(this->shape())); + PrintHelper(*this, {}, /*print_shape=*/true, /*print_layout=*/false, + /*oneline=*/false, printer); +} + +void LiteralBase::PrintOneline(Printer* printer) const { CHECK(LayoutUtil::HasLayout(this->shape())); - ToStringHelper(*this, {}, /*print_shape=*/true, - /*print_layout=*/false, &pieces); - return absl::StrJoin(pieces, ""); + PrintHelper(*this, {}, /*print_shape=*/true, /*print_layout=*/false, + /*oneline=*/true, printer); +} + +void LiteralBase::PrintWithoutShape(Printer* printer) const { + CHECK(LayoutUtil::HasLayout(this->shape())); + PrintHelper(*this, {}, /*print_shape=*/false, /*print_layout=*/false, + /*oneline=*/false, printer); +} + +void LiteralBase::PrintWithoutShapeOneline(Printer* printer) const { + CHECK(LayoutUtil::HasLayout(this->shape())); + PrintHelper(*this, {}, /*print_shape=*/false, /*print_layout=*/false, + /*oneline=*/true, printer); +} + +void LiteralBase::PrintWithLayout(Printer* printer) 
const { + CHECK(LayoutUtil::HasLayout(this->shape())); + PrintHelper(*this, {}, /*print_shape=*/true, /*print_layout=*/true, + /*oneline=*/false, printer); +} + +void LiteralBase::PrintWithLayoutOneline(Printer* printer) const { + CHECK(LayoutUtil::HasLayout(this->shape())); + PrintHelper(*this, {}, /*print_shape=*/true, /*print_layout=*/true, + /*oneline=*/true, printer); +} + +std::string LiteralBase::ToString() const { + StringPrinter printer; + Print(&printer); + return std::move(printer).ToString(); } std::string LiteralBase::ToStringOneline() const { - return CompactOneline(ToString()); + StringPrinter printer; + PrintOneline(&printer); + return std::move(printer).ToString(); } std::string LiteralBase::ToStringWithoutShape() const { - std::vector pieces; - CHECK(LayoutUtil::HasLayout(this->shape())); - ToStringHelper(*this, {}, /*print_shape=*/false, - /*print_layout=*/false, &pieces); - return absl::StrJoin(pieces, ""); + StringPrinter printer; + PrintWithoutShape(&printer); + return std::move(printer).ToString(); } std::string LiteralBase::ToStringWithoutShapeOneline() const { - return CompactOneline(ToStringWithoutShape()); + StringPrinter printer; + PrintWithoutShapeOneline(&printer); + return std::move(printer).ToString(); } std::string LiteralBase::ToStringWithLayout() const { - std::vector pieces; - CHECK(LayoutUtil::HasLayout(this->shape())); - ToStringHelper(*this, {}, /*print_shape=*/true, - /*print_layout=*/true, &pieces); - return absl::StrJoin(pieces, ""); + StringPrinter printer; + PrintWithLayout(&printer); + return std::move(printer).ToString(); } std::string LiteralBase::ToStringWithLayoutOneline() const { - return CompactOneline(ToStringWithLayout()); + StringPrinter printer; + PrintWithLayoutOneline(&printer); + return std::move(printer).ToString(); } void LiteralBase::EachCellAsString( @@ -1550,24 +1603,11 @@ Literal ConvertBetweenNativeTypesWithConverter(const LiteralBase& src_literal, } template -typename std::enable_if::value && - (std::is_same::value || - std::is_same::value), - Literal>::type -ConvertBetweenNativeTypes(const LiteralBase& src_literal) { - auto converter = [](NativeSrcT src) { - return NativeDestT(static_cast(src)); - }; - return ConvertBetweenNativeTypesWithConverter( - src_literal, converter); -} - -template -typename std::enable_if::value && - std::is_integral::value, - Literal>::type -ConvertBetweenNativeTypes(const LiteralBase& src_literal) { +Literal ConvertBetweenNativeTypes(const LiteralBase& src_literal) { auto converter = [](NativeSrcT src) { + if constexpr (std::is_same_v) { + return src; + } // C++ [conv.bool]p1: // A prvalue of arithmetic [...] type can be converted to a prvalue of // type bool. A zero value [...] is converted to false; any other value is @@ -1580,14 +1620,18 @@ ConvertBetweenNativeTypes(const LiteralBase& src_literal) { // may be undefined if the value's magnitude is too large or it is a NaN. // Let's choose saturating arithmetic as it captures the spirit of infinity // and arbitrarily map NaN to zero. 
- if (!std::is_same::value) { + if constexpr (!std::is_same_v && + !std::numeric_limits::is_integer && + std::numeric_limits::is_integer) { if (src != src) { return NativeDestT{0}; } - if (src >= std::numeric_limits::max()) { + if (src >= + static_cast(std::numeric_limits::max())) { return std::numeric_limits::max(); } - if (src <= std::numeric_limits::lowest()) { + if (src <= + static_cast(std::numeric_limits::lowest())) { return std::numeric_limits::lowest(); } } @@ -1598,53 +1642,18 @@ ConvertBetweenNativeTypes(const LiteralBase& src_literal) { } template -typename std::enable_if::value && - std::is_integral::value) && - !(std::is_same::value && - (std::is_same::value || - std::is_same::value)), - Literal>::type -ConvertBetweenNativeTypes(const LiteralBase& src_literal) { - auto converter = [](NativeSrcT src) { return static_cast(src); }; - return ConvertBetweenNativeTypesWithConverter( - src_literal, converter); -} - -template -typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT) && - !std::is_same::value), - Literal>::type -BitcastBetweenNativeTypes(const LiteralBase& src_literal) { - auto converter = [](NativeSrcT src) { - return absl::bit_cast(GetRawValue(src)); - }; - return ConvertBetweenNativeTypesWithConverter( - src_literal, converter); -} - -template -typename std::enable_if<(sizeof(NativeSrcT) == sizeof(Eigen::half) && - std::is_same::value), - Literal>::type -BitcastBetweenNativeTypes(const LiteralBase& src_literal) { - // Eigen::half doesn't satisfy the absl::bit_cast contract, so explicitly - // cast to unsigned short first. - auto converter = [](NativeSrcT src) { - return Eigen::numext::bit_cast( - absl::bit_cast(GetRawValue(src))); - }; - return ConvertBetweenNativeTypesWithConverter( - src_literal, converter); -} - -// This template specialization is here to make the compiler happy. bit_cast has -// a static check that the types are the same size. This specialization should -// never be used because the source and destination types are checked for -// identical sizes higher up. -template -typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)), - Literal>::type -BitcastBetweenNativeTypes(const LiteralBase& src_literal) { +Literal BitcastBetweenNativeTypes(const LiteralBase& src_literal) { + if constexpr (sizeof(NativeSrcT) == sizeof(NativeDestT)) { + auto converter = [](NativeSrcT src) { + return Eigen::numext::bit_cast(src); + }; + return ConvertBetweenNativeTypesWithConverter( + src_literal, converter); + } + // This template specialization is here to make the compiler happy. bit_cast + // has a static check that the types are the same size. This specialization + // should never be used because the source and destination types are checked + // for identical sizes higher up. LOG(FATAL) << "Invalid bitcast between types of different sizes."; } @@ -1688,6 +1697,8 @@ StatusOr ConvertIfDestTypeMatches(const LiteralBase& src_literal, CONVERT_IF_TYPES_MATCH(F32) CONVERT_IF_TYPES_MATCH(F64) CONVERT_IF_TYPES_MATCH(BF16) + CONVERT_IF_TYPES_MATCH(F8E5M2) + CONVERT_IF_TYPES_MATCH(F8E4M3FN) #undef CONVERT_IF_TYPES_MATCH case C64: if (bitcast) { @@ -1733,6 +1744,8 @@ StatusOr ConvertSwitch(const LiteralBase& literal, CONVERT_IF_DEST_TYPE_MATCHES(F32) CONVERT_IF_DEST_TYPE_MATCHES(F64) CONVERT_IF_DEST_TYPE_MATCHES(BF16) + CONVERT_IF_DEST_TYPE_MATCHES(F8E5M2) + CONVERT_IF_DEST_TYPE_MATCHES(F8E4M3FN) #undef CONVERT_IF_DEST_TYPE_MATCHES // Other types are not yet supported. 
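The consolidated ConvertBetweenNativeTypes above folds the old saturating float-to-integer specialization into a single constexpr-guarded converter; as a sketch of the resulting semantics for a float source and int8_t destination (example values only, mirroring the logic in the hunk):

#include <cstdint>
#include <limits>

// Same rules the converter applies for float -> int8_t: NaN maps to zero,
// out-of-range values clamp to the destination limits, and everything else
// is an ordinary (truncating) static_cast.
int8_t SaturatingFloatToS8(float src) {
  if (src != src) return 0;  // NaN is arbitrarily mapped to zero.
  if (src >= static_cast<float>(std::numeric_limits<int8_t>::max())) {
    return std::numeric_limits<int8_t>::max();  // e.g. 1e30f -> 127
  }
  if (src <= static_cast<float>(std::numeric_limits<int8_t>::lowest())) {
    return std::numeric_limits<int8_t>::lowest();  // e.g. -1e30f -> -128
  }
  return static_cast<int8_t>(src);  // e.g. 42.9f -> 42
}
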
default: @@ -1922,6 +1935,10 @@ bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const { return EqualElementsInternal(other, &multi_index); case BF16: return EqualElementsInternal(other, &multi_index); + case F8E5M2: + return EqualElementsInternal(other, &multi_index); + case F8E4M3FN: + return EqualElementsInternal(other, &multi_index); case C64: return EqualElementsInternal(other, &multi_index); case C128: @@ -2030,6 +2047,13 @@ bool Literal::Piece::IsAll(const Literal& scalar) const { case PRED: return AllElementsEqualValue(data(), scalar.GetFirstElement()); + case F8E5M2: + return AllElementsEqualValue(data(), + scalar.GetFirstElement()); + case F8E4M3FN: + return AllElementsEqualValue( + data(), + scalar.GetFirstElement()); case F16: return AllElementsEqualValue(data(), scalar.GetFirstElement()); @@ -2063,7 +2087,7 @@ bool LiteralBase::IsAll(int8_t value) const { } PrimitiveType ty = shape().element_type(); if (primitive_util::IsFloatingPointType(ty)) { - return IsAllFloat(value); + return IsAllFloatImpl(value, /*round_value=*/false); } if (primitive_util::IsUnsignedIntegralType(ty) && value < 0) { return false; @@ -2110,12 +2134,23 @@ bool LiteralBase::IsAll(int8_t value) const { } bool LiteralBase::IsAllFloat(float value) const { + return IsAllFloatImpl(value, /*round_value=*/true); +} + +bool LiteralBase::IsAllFloatImpl(float value, bool round_value) const { if (!shape().IsArray()) { return false; } PrimitiveType ty = shape().element_type(); Literal scalar(ShapeUtil::MakeScalarShape(ty)); switch (ty) { + case F8E5M2: + scalar.Set({}, static_cast(value)); + break; + case F8E4M3FN: + scalar.Set({}, + static_cast(value)); + break; case F16: scalar.Set({}, static_cast(value)); break; @@ -2131,6 +2166,9 @@ bool LiteralBase::IsAllFloat(float value) const { default: return false; } + if (!round_value && scalar.GetAsDouble({}) != value) { + return false; + } return root_piece().IsAll(scalar); } @@ -2207,6 +2245,12 @@ bool LiteralBase::IsR1Iota() const { return Get({idx}) == static_cast(idx); case BF16: return Get({idx}) == static_cast(idx); + case F8E5M2: + return Get({idx}) == + static_cast(idx); + case F8E4M3FN: + return Get({idx}) == + static_cast(idx); case C64: return Get({idx}) == complex64(idx, 0.0f); case C128: @@ -2316,6 +2360,12 @@ bool LiteralBase::IsZero(absl::Span indices) const { return Get(indices) == static_cast(0.0f); case BF16: return Get(indices) == static_cast(0.0f); + case F8E5M2: + return Get(indices) == + static_cast(0.0f); + case F8E4M3FN: + return Get(indices) == + static_cast(0.0f); case PRED: return Get(indices) == false; default: @@ -2399,6 +2449,16 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { ConvertEndianShort(proto->mutable_bf16s()); } break; + case F8E5M2: + *proto->mutable_f8e5m2s() = std::string( + reinterpret_cast(data().data()), + size_bytes_dense()); + break; + case F8E4M3FN: + *proto->mutable_f8e4m3fns() = std::string( + reinterpret_cast(data().data()), + size_bytes_dense()); + break; case F32: CopyToRepeatedField(proto->mutable_f32s(), data()); break; @@ -2506,6 +2566,19 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { ConvertEndianShort(reinterpret_cast(untyped_data()), s.size()); } } break; + case F8E5M2: { + const std::string& s(proto.f8e5m2s()); + TF_RET_CHECK(data().size() * sizeof(tsl::float8_e5m2) == + s.size()); + memcpy(untyped_data(), s.data(), s.size()); + } break; + case F8E4M3FN: { + const std::string& s(proto.f8e4m3fns()); + TF_RET_CHECK(data().size() * + 
sizeof(tsl::float8_e4m3fn) == + s.size()); + memcpy(untyped_data(), s.data(), s.size()); + } break; case F16: { const std::string& s(proto.f16s()); TF_RET_CHECK(data().size() * sizeof(half) == s.size()); diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index 8f8972fe75b..8607d805df9 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -40,12 +40,14 @@ limitations under the License. #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/printer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/tsl/lib/core/bitmap.h" +#include "tensorflow/tsl/platform/cpu_info.h" #include "tensorflow/tsl/platform/logging.h" #include "tensorflow/tsl/platform/protobuf.h" #include "tensorflow/tsl/platform/status.h" @@ -90,15 +92,38 @@ class LiteralBase { // array. std::string GetR1U8AsString() const; - // Returns a string representation of the literal value. The Shape of the + // Prints a string representation of the literal value. The Shape of the // literal is a prefix of the literal value in the string. + // + // Warning: this function can take minutes for multi-million element Literals. + void Print(Printer* printer) const; + + // Similar to Print, but prints the result in a compact one-line form. + void PrintOneline(Printer* printer) const; + + // Prints a string representation of the literal value which does *not* + // include the shape string. + void PrintWithoutShape(Printer* printer) const; + + // Similar to PrintWithoutShape, but prints the result in a compact one-line + // form. + void PrintWithoutShapeOneline(Printer* printer) const; + + // Prints a string representation of the literal value which includes the + // shape string with its layout.does *not* include the shape string. + void PrintWithLayout(Printer* printer) const; - // Warning: this function can take minutes for multi-million - // element Literals. + // Similar to PrintWithLayout, but prints the result in a compact one-line + // form. + void PrintWithLayoutOneline(Printer* printer) const; + + // Returns a string representation of the literal value. The Shape of the + // literal is a prefix of the literal value in the string. + // + // Warning: this function can take minutes for multi-million element Literals. std::string ToString() const; - // Similar to ToString, but return the result in a compact - // one-line form. + // Similar to ToString, but return the result in a compact one-line form. std::string ToStringOneline() const; // Returns a string representation of the literal value which does *not* @@ -113,8 +138,8 @@ class LiteralBase { // shape string with its layout.does *not* include the shape string. std::string ToStringWithLayout() const; - // Similar to ToStringWithLayout, but return the result in a compact - // one-line form. + // Similar to ToStringWithLayout, but return the result in a compact one-line + // form. std::string ToStringWithLayoutOneline() const; // Gets an element in the literal at the given index. 
The multi_index is @@ -152,7 +177,9 @@ class LiteralBase { template typename std::enable_if<(std::is_arithmetic::value || std::is_same::value || - std::is_same::value), + std::is_same::value || + std::is_same::value || + std::is_same::value), bool>::type IsEqualAt(absl::Span multi_index, T value) const { if (auto as_s64 = GetIntegralAsS64(multi_index)) { @@ -238,9 +265,8 @@ class LiteralBase { // if it's not an array. // // This casts value to the type of literal, then compares using ==, with the - // caveat that NaNs are considered equal. The usual admonishments about - // floating-point equality checks apply. We expect you to use this to check - // for values that can be expressed precisely as a float, e.g. -0.5. + // caveat that NaNs are considered equal. Unlike IsAll, this does not + // necessarily return false if the value does not fit in this literal's type. bool IsAllFloat(float value) const; bool IsAllComplex(complex64 value) const; @@ -772,6 +798,12 @@ class LiteralBase { template Literal SliceInternal(const Shape& result_shape, absl::Span start_indices) const; + + // Like IsAllFloat, but if round_value is false and the value is not + // representable with the literal's type (e.g., due to rounding error or + // overflow/underflow when casting the value to the literal's type), returns + // false. + bool IsAllFloatImpl(float value, bool round_value) const; }; // Abstract base class representing a mutable literal in XLA. @@ -1492,7 +1524,7 @@ TF_ATTRIBUTE_NOINLINE Status MutableLiteralBase::PopulateParallel( [&](absl::Span indexes, int thread_id) { return generator(indexes, thread_id); }, - /*parallel=*/true); + /*parallel=*/data().size() > 32); } template diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index a30dbc1a0d9..01d3e813e37 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -20,14 +20,17 @@ limitations under the License. #endif #include +#include #include #include "absl/base/casts.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/tsl/platform/env.h" +#include "tensorflow/tsl/platform/float8.h" using absl::StrAppend; using absl::StrAppendFormat; @@ -70,6 +73,19 @@ bool CompareEqual(NativeT lhs, NativeT rhs, // Specializations for floating types that do bitwise comparisons when equality // comparison is requested. 
template <> +bool CompareEqual(tsl::float8_e5m2 lhs, tsl::float8_e5m2 rhs, + absl::Span multi_index) { + return CompareFloatsBitwiseEqual(lhs, rhs, + multi_index); +} +template <> +bool CompareEqual(tsl::float8_e4m3fn lhs, + tsl::float8_e4m3fn rhs, + absl::Span multi_index) { + return CompareFloatsBitwiseEqual(lhs, rhs, + multi_index); +} +template <> bool CompareEqual(bfloat16 lhs, bfloat16 rhs, absl::Span multi_index) { return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); @@ -127,6 +143,18 @@ Status MakeErrorStatus(NativeT lhs, NativeT rhs, LiteralUtil::MultiIndexAsString(multi_index), StrCat(lhs), StrCat(rhs)); } +template <> +Status MakeErrorStatus(tsl::float8_e5m2 lhs, tsl::float8_e5m2 rhs, + absl::Span multi_index) { + return MakeBitwiseErrorStatus(lhs, rhs, + multi_index); +} +template <> +Status MakeErrorStatus(tsl::float8_e4m3fn lhs, tsl::float8_e4m3fn rhs, + absl::Span multi_index) { + return MakeBitwiseErrorStatus(lhs, rhs, + multi_index); +} template <> Status MakeErrorStatus(bfloat16 lhs, bfloat16 rhs, absl::Span multi_index) { @@ -239,6 +267,14 @@ bool IsNan(NativeT value) { } // Converts the given floating-point value to a string. +std::string FpValueToString(tsl::float8_e5m2 value) { + return absl::StrFormat("%5.3g", static_cast(value)); +} + +std::string FpValueToString(tsl::float8_e4m3fn value) { + return absl::StrFormat("%5.3g", static_cast(value)); +} + std::string FpValueToString(bfloat16 value) { return absl::StrFormat("%10.4g", static_cast(value)); } @@ -266,7 +302,7 @@ std::string FpValueToString(complex128 value) { } // A wrapper of std::abs to include data types that are not supported by -// std::abs, in particular, bfloat16 and half. +// std::abs, such as bfloat16 and half. template double FpAbsoluteValue(NativeT value) { return std::abs(value); @@ -282,6 +318,16 @@ double FpAbsoluteValue(half value) { return FpAbsoluteValue(static_cast(value)); } +template <> +double FpAbsoluteValue(tsl::float8_e5m2 value) { + return FpAbsoluteValue(static_cast(value)); +} + +template <> +double FpAbsoluteValue(tsl::float8_e4m3fn value) { + return FpAbsoluteValue(static_cast(value)); +} + // Helper class for comparing floating-point literals within an error bound. 
template class NearComparator { @@ -701,7 +747,11 @@ constexpr std::array NearComparator::kErrorBucketBounds; Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual, const ShapeIndex& shape_index, const MiscompareCallback& miscompare_callback) { - TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape())); + if (expected.shape().is_static() && actual.shape().is_static()) { + TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape())); + } else { + TF_RETURN_IF_ERROR(EqualDynamicShapesAndDimensions(expected, actual)); + } Status result; if (expected.shape().IsTuple()) { @@ -756,6 +806,14 @@ Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual, case U64: result = Equal(expected, actual, index, 0, miscompared_ptr); break; + case F8E5M2: + result = Equal(expected, actual, index, 0, + miscompared_ptr); + break; + case F8E4M3FN: + result = Equal(expected, actual, index, 0, + miscompared_ptr); + break; case BF16: result = Equal(expected, actual, index, 0, miscompared_ptr); break; @@ -798,7 +856,11 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, const ShapeIndex& shape_index, const ErrorSpec& error, std::optional detailed_message, const MiscompareCallback& miscompare_callback) { - TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape())); + if (expected.shape().is_static() && actual.shape().is_static()) { + TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape())); + } else { + TF_RETURN_IF_ERROR(EqualDynamicShapesAndDimensions(expected, actual)); + } if (expected.shape().IsTuple()) { Status return_status; @@ -840,6 +902,16 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, bool use_detailed_message = detailed_message.value_or( ShapeUtil::ElementsIn(expected.shape()) >= 64); switch (expected.shape().element_type()) { + case F8E5M2: + return NearComparator::Compare( + expected, actual, shape_index, error, use_detailed_message, + miscompare_callback); + break; + case F8E4M3FN: + return NearComparator::Compare( + expected, actual, shape_index, error, use_detailed_message, + miscompare_callback); + break; case BF16: return NearComparator::Compare(expected, actual, shape_index, error, use_detailed_message, @@ -932,6 +1004,53 @@ Status EqualShapes(const Shape& expected, const Shape& actual) { return OkStatus(); } +Status EqualDynamicShapesAndDimensions(const LiteralSlice& expected, + const LiteralSlice& actual) { + TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape())); + return ShapeUtil::ForEachSubshapeWithStatus( + expected.shape(), [&expected, &actual](const Shape& expected_shape, + const ShapeIndex& index) { + auto actual_shape = ShapeUtil::GetSubshape(actual.shape(), index); + for (int i = 0; i < expected_shape.dimensions().size(); ++i) { + if (!expected_shape.is_dynamic_dimension(i) && + !actual_shape.is_dynamic_dimension(i)) { + // We're only interested in dynamic dimensions. + continue; + } + if (expected_shape.is_dynamic_dimension(i) && + !actual_shape.is_dynamic_dimension(i)) { + return InvalidArgument( + "mismatch at dimension %d. the expected shape %s is dynamic " + "while " + "the actual shape %s is not.", + i, ShapeUtil::HumanString(expected.shape()), + ShapeUtil::HumanString(actual.shape())); + } + if (!expected_shape.is_dynamic_dimension(i) && + actual_shape.is_dynamic_dimension(i)) { + return InvalidArgument( + "mismatch at dimension %d. 
the expected shape %s is not " + "dynamic " + "while the actual shape %s is dynamic.", + i, ShapeUtil::HumanString(expected.shape()), + ShapeUtil::HumanString(actual.shape())); + } + // Both dimensions are dynamic. Check that they are equal. + int64_t expected_dynamic_size = expected.GetDynamicSize(i, index); + int64_t actual_dynamic_size = actual.GetDynamicSize(i, index); + if (expected_dynamic_size != actual_dynamic_size) { + return InvalidArgument( + "mismatch at dimension %d. The expected dynamic size does not " + "match " + "the actual dynamic size. %d vs. %d", + i, expected_dynamic_size, actual_dynamic_size); + } + } + + return OkStatus(); + }); +} + namespace { // If result is an error, extend the error message with the expected and actual diff --git a/tensorflow/compiler/xla/literal_comparison.h b/tensorflow/compiler/xla/literal_comparison.h index 49b2c688167..24a657d4ddb 100644 --- a/tensorflow/compiler/xla/literal_comparison.h +++ b/tensorflow/compiler/xla/literal_comparison.h @@ -30,6 +30,11 @@ namespace literal_comparison { // primitive types. Status EqualShapes(const Shape& expected, const Shape& actual); +// Returns ok if the given literals share identical dynamic shapes and +// dimension sizes. +Status EqualDynamicShapesAndDimensions(const LiteralSlice& expected, + const LiteralSlice& actual); + // Returns ok if the expected and actual literals are (bitwise) equal for all // elements in the literal. Also, asserts that the rank, dimensions sizes, and // primitive type are equal. diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc index d6e15f570cc..ff539de0b05 100644 --- a/tensorflow/compiler/xla/literal_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/tsl/lib/core/status_test_util.h" +#include "tensorflow/tsl/platform/float8.h" namespace xla { namespace { @@ -132,6 +133,19 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) { auto bf16_lit_truncated2 = LiteralUtil::CreateR0(static_cast(9.001f)); EXPECT_EQ("bf16[] 9", bf16_lit_truncated2.ToString()); + + auto f8e5m2_lit = + LiteralUtil::CreateR0(tsl::float8_e5m2(0.5)); + EXPECT_EQ("f8e5m2[] 0.5", f8e5m2_lit.ToString()); + + // 3.14 will be rounded to 3 in e5m2 format. 
+ auto f8e5m2_lit_truncated = + LiteralUtil::CreateR0(tsl::float8_e5m2(3.141)); + EXPECT_EQ("f8e5m2[] 3", f8e5m2_lit_truncated.ToString()); + + auto f8e4m3_lit = + LiteralUtil::CreateR0(tsl::float8_e4m3fn(0.5)); + EXPECT_EQ("f8e4m3fn[] 0.5", f8e4m3_lit.ToString()); } TEST_F(LiteralUtilTest, LiteralVectorToString) { @@ -554,6 +568,15 @@ TEST_F(LiteralUtilTest, IsAll) { bfloat16 b90(9.00f); EXPECT_TRUE(LiteralUtil::CreateR2({{b91}, {b90}}).IsAll(9.0)); + tsl::float8_e5m2 q16(8); + EXPECT_TRUE(LiteralUtil::CreateR1({q16}).IsAll(8)); + // 9 rounds to 8 in E5M2 but is not equal to 8, so this should be false + EXPECT_FALSE(LiteralUtil::CreateR1({q16}).IsAll(9)); + + tsl::float8_e4m3fn r16(9); // Exactly representable in e4m3 + EXPECT_FALSE(LiteralUtil::CreateR1({r16}).IsAll(8)); + EXPECT_TRUE(LiteralUtil::CreateR1({r16}).IsAll(9)); + complex64 c8_9 = {8, 9}; EXPECT_FALSE(LiteralUtil::CreateR2({{c8_9}, {c8_9}}).IsAll(8)); @@ -585,6 +608,10 @@ TEST_F(LiteralUtilTest, IsAllFloat) { EXPECT_FALSE(LiteralUtil::CreateR0(-.5).IsAllFloat(-.49)); EXPECT_FALSE( LiteralUtil::CreateR2({{0, 0, 0}, {0, .1, 0}}).IsAllFloat(0)); + + // IsAllFloat rounds the input scalar to the literal type + EXPECT_TRUE( + LiteralUtil::CreateR0(bfloat16(128.)).IsAllFloat(128.5)); } TEST_F(LiteralUtilTest, IsAllComplex) { @@ -1032,6 +1059,22 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2F16) { EXPECT_EQ(output, expected); } +TEST_F(LiteralUtilTest, PopulateWithValueR0F8e5m2) { + Literal output(ShapeUtil::MakeShape(F8E5M2, {})); + tsl::float8_e5m2 x(0.25f); + output.PopulateWithValue(x); + auto expected = LiteralUtil::CreateR0(x); + EXPECT_EQ(output, expected); +} + +TEST_F(LiteralUtilTest, PopulateWithValueR1F8e4m3) { + Literal output(ShapeUtil::MakeShape(F8E4M3FN, {3})); + tsl::float8_e4m3fn x(0.5f); + output.PopulateWithValue(x); + auto expected = LiteralUtil::CreateR1({x, x, x}); + EXPECT_EQ(output, expected); +} + TEST_F(LiteralUtilTest, ReplicateR2U32) { auto input = LiteralUtil::CreateR2( {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); @@ -1500,6 +1543,58 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { EXPECT_EQ(c128.Convert(S32).status().code(), tsl::error::UNIMPLEMENTED); } +TEST_F(LiteralUtilTest, ConvertIfTypesMatchF8) { + auto s8 = LiteralUtil::CreateR2WithLayout({{0, 1}, {2, 3}}, + layout_r2_dim0major_); + auto f32 = LiteralUtil::CreateR2WithLayout({{0., 1.}, {2., 3.}}, + layout_r2_dim0major_); + auto c128 = LiteralUtil::CreateR2WithLayout({{0., 1.}, {2., 3.}}, + layout_r2_dim0major_); + using e5 = tsl::float8_e5m2; + auto f8e5m2 = LiteralUtil::CreateR2WithLayout( + {{e5{0.}, e5{1.}}, {e5{2.}, e5{3.}}}, layout_r2_dim0major_); + using e4 = tsl::float8_e4m3fn; + auto f8e4m3 = LiteralUtil::CreateR2WithLayout( + {{e4{0.}, e4{1.}}, {e4{2.}, e4{3.}}}, layout_r2_dim0major_); + Literal conv; + + conv = s8.Convert(F8E5M2).value(); + EXPECT_EQ(conv, f8e5m2); + + conv = f32.Convert(F8E5M2).value(); + EXPECT_EQ(conv, f8e5m2); + + conv = f8e4m3.Convert(F8E5M2).value(); + EXPECT_EQ(conv, f8e5m2); + + conv = s8.Convert(F8E4M3FN).value(); + EXPECT_EQ(conv, f8e4m3); + + conv = f32.Convert(F8E4M3FN).value(); + EXPECT_EQ(conv, f8e4m3); + + conv = f8e5m2.Convert(F8E4M3FN).value(); + EXPECT_EQ(conv, f8e4m3); + + conv = f8e5m2.Convert(S8).value(); + EXPECT_EQ(conv, s8); + + conv = f8e5m2.Convert(F32).value(); + EXPECT_EQ(conv, f32); + + conv = f8e5m2.Convert(C128).value(); + EXPECT_EQ(conv, c128); + + conv = f8e4m3.Convert(S8).value(); + EXPECT_EQ(conv, s8); + + conv = f8e4m3.Convert(F32).value(); + EXPECT_EQ(conv, f32); + + conv = 
f8e4m3.Convert(C128).value(); + EXPECT_EQ(conv, c128); +} + TEST_F(LiteralUtilTest, BitcastConvert) { Literal original = LiteralUtil::CreateR1( {absl::bit_cast(2.5f), absl::bit_cast(-42.25f), @@ -1888,6 +1983,12 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { {bfloat16{-1.0}, bfloat16{2.0}, bfloat16{-3.0}}); auto vector_half = LiteralUtil::CreateR1({half{10.0}, half{20.0}, half{-30.0}}); + using e5 = tsl::float8_e5m2; + auto vector_f8e5m2 = + LiteralUtil::CreateR1({e5{10.0}, e5{20.0}, e5{-32.0}}); + using e4 = tsl::float8_e4m3fn; + auto vector_f8e4m3 = + LiteralUtil::CreateR1({e4{10.0}, e4{20.0}, e4{-32.0}}); auto matrix_pred = LiteralUtil::CreateR2({{true, false, true}, {false, false, true}}); auto tuple = LiteralUtil::MakeTuple( @@ -1906,6 +2007,8 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { EXPECT_EQ(vector_c64, to_from_proto(vector_c64)); EXPECT_EQ(vector_c128, to_from_proto(vector_c128)); EXPECT_EQ(vector_bfloat16, to_from_proto(vector_bfloat16)); + EXPECT_EQ(vector_f8e5m2, to_from_proto(vector_f8e5m2)); + EXPECT_EQ(vector_f8e4m3, to_from_proto(vector_f8e4m3)); EXPECT_EQ(matrix_pred, to_from_proto(matrix_pred)); EXPECT_EQ(tuple, to_from_proto(tuple)); EXPECT_EQ(nested_tuple, to_from_proto(nested_tuple)); @@ -2127,19 +2230,24 @@ TEST_F(LiteralUtilTest, IsEqualAt) { Literal c2 = LiteralUtil::CreateR0(10); EXPECT_TRUE(c2.IsEqualAt({}, val_double)); EXPECT_TRUE(c2.IsEqualAt({}, val_integral)); - complex128 val_complex = {10, 0}; - EXPECT_TRUE(c2.IsEqualAt({}, val_complex)); - EXPECT_TRUE(c1.IsEqualAt({}, val_complex)); - Literal c3 = LiteralUtil::CreateR0(val_complex); + Literal c3 = + LiteralUtil::CreateR0(tsl::float8_e5m2{val_double}); EXPECT_TRUE(c3.IsEqualAt({}, val_double)); EXPECT_TRUE(c3.IsEqualAt({}, val_integral)); + complex128 val_complex = {10, 0}; + EXPECT_TRUE(c1.IsEqualAt({}, val_complex)); + EXPECT_TRUE(c2.IsEqualAt({}, val_complex)); EXPECT_TRUE(c3.IsEqualAt({}, val_complex)); - EXPECT_FALSE(c3.IsEqualAt({}, std::numeric_limits::infinity())); + Literal c4 = LiteralUtil::CreateR0(val_complex); + EXPECT_TRUE(c4.IsEqualAt({}, val_double)); + EXPECT_TRUE(c4.IsEqualAt({}, val_integral)); + EXPECT_TRUE(c4.IsEqualAt({}, val_complex)); + EXPECT_FALSE(c4.IsEqualAt({}, std::numeric_limits::infinity())); complex128 val_true_complex = {10, 3}; complex64 val_smaller_complex = {10, 3}; - Literal c4 = LiteralUtil::CreateR0(val_true_complex); - EXPECT_TRUE(c4.IsEqualAt({}, val_true_complex)); - EXPECT_TRUE(c4.IsEqualAt({}, val_smaller_complex)); + Literal c5 = LiteralUtil::CreateR0(val_true_complex); + EXPECT_TRUE(c5.IsEqualAt({}, val_true_complex)); + EXPECT_TRUE(c5.IsEqualAt({}, val_smaller_complex)); } TEST_F(LiteralUtilTest, CreateFromShapeWithUnknownLeafArrays) { diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 9285b826a9b..a766e451d7f 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -107,6 +107,11 @@ Literal CreateScalar(PrimitiveType primitive_type, Args... 
args) { return CreateScalarImpl(F{}, std::forward(args)...); case S64: return CreateScalarImpl(F{}, std::forward(args)...); + case F8E5M2: + return CreateScalarImpl(F{}, std::forward(args)...); + case F8E4M3FN: + return CreateScalarImpl(F{}, + std::forward(args)...); case F16: return CreateScalarImpl(F{}, std::forward(args)...); case BF16: diff --git a/tensorflow/compiler/xla/mlir/tools/BUILD b/tensorflow/compiler/xla/mlir/backends/cpu/BUILD similarity index 54% rename from tensorflow/compiler/xla/mlir/tools/BUILD rename to tensorflow/compiler/xla/mlir/backends/cpu/BUILD index 41c0a7c96f1..06ddfdc6ea7 100644 --- a/tensorflow/compiler/xla/mlir/tools/BUILD +++ b/tensorflow/compiler/xla/mlir/backends/cpu/BUILD @@ -1,18 +1,17 @@ -load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_binary") package( - default_visibility = [ - "//tensorflow:internal", - "@tf_runtime//:friends", - ], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//tensorflow/compiler/xla/mlir:__subpackages__"], licenses = ["notice"], ) -tf_cc_binary( +xla_cc_binary( name = "xla-cpu-opt", - srcs = ["xla_cpu_opt.cc"], + srcs = ["xla-cpu-opt.cc"], deps = [ - "//tensorflow/compiler/xla/mlir/transforms/cpu:passes", + "//tensorflow/compiler/xla/mlir/backends/cpu/transforms:passes", + "//tensorflow/compiler/xla/mlir/xla_cpu/ir:xla_cpu", "//tensorflow/compiler/xla/mlir_hlo:all_passes", "//tensorflow/compiler/xla/mlir_hlo:gml_st", "//tensorflow/compiler/xla/mlir_hlo:gml_st_passes", @@ -21,23 +20,13 @@ tf_cc_binary( "//tensorflow/compiler/xla/mlir_hlo:lhlo", "//tensorflow/compiler/xla/mlir_hlo:thlo", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "@llvm-project//mlir:BufferizationTransforms", "@llvm-project//mlir:LinalgDialect", + "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:VectorDialect", "@stablehlo//:register", ], ) - -tf_cc_binary( - name = "xla-gpu-opt", - srcs = ["xla_gpu_opt.cc"], - deps = [ - "//tensorflow/compiler/xla/mlir/transforms/gpu:passes", - "//tensorflow/compiler/xla/mlir_hlo:lhlo", - "//tensorflow/compiler/xla/mlir_hlo:lhlo_gpu", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:GPUDialect", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:MlirOptLib", - ], -) diff --git a/tensorflow/compiler/xla/mlir/transforms/cpu/BUILD b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/BUILD similarity index 57% rename from tensorflow/compiler/xla/mlir/transforms/cpu/BUILD rename to tensorflow/compiler/xla/mlir/backends/cpu/transforms/BUILD index 1415f64b400..51902deb2c0 100644 --- a/tensorflow/compiler/xla/mlir/transforms/cpu/BUILD +++ b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/BUILD @@ -1,9 +1,10 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") package( - default_visibility = ["//tensorflow:internal"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//tensorflow/compiler/xla:internal"], licenses = ["notice"], ) @@ -28,20 +29,34 @@ gentbl_cc_library( cc_library( name = "passes", srcs = [ + "legalize_collective_ops.cc", + "legalize_i1_vector_transfers.cc", "lmhlo_to_cpu_runtime.cc", + 
"remove_copies_to_out_params.cc", "xla_abi_legalization.cc", + "xla_cpu_memref_element_cast_to_llvm.cc", ], hdrs = ["passes.h"], deps = [ ":passes_inc_gen", + "//tensorflow/compiler/xla/mlir/runtime/transforms:type_converter", "//tensorflow/compiler/xla/mlir/runtime/utils:custom_calls", + "//tensorflow/compiler/xla/mlir/xla_cpu/ir:xla_cpu", "//tensorflow/compiler/xla/mlir_hlo", "//tensorflow/compiler/xla/mlir_hlo:lhlo", + "//tensorflow/compiler/xla/service:hlo_parser", "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:DialectUtils", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:VectorDialect", ], ) diff --git a/tensorflow/compiler/xla/mlir/backends/cpu/transforms/legalize_collective_ops.cc b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/legalize_collective_ops.cc new file mode 100644 index 00000000000..61d8e500157 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/legalize_collective_ops.cc @@ -0,0 +1,303 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include +#include + +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/Dialect/Utils/StaticValueUtils.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/xla/mlir/backends/cpu/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace xla { +namespace cpu { +namespace { + +#define GEN_PASS_DEF_LEGALIZECOLLECTIVEOPSPASS +#include "tensorflow/compiler/xla/mlir/backends/cpu/transforms/passes.h.inc" + +using namespace mlir; // NOLINT + +class LegalizeCollectiveOpsPass + : public impl::LegalizeCollectiveOpsPassBase { + void runOnOperation() override; +}; + +Optional MatchReductionComputation(Region& region) { + if (!region.hasOneBlock()) { + return llvm::None; + } + + auto ret = dyn_cast(region.front().getTerminator()); + if (!ret || ret->getNumOperands() != 1) { + return llvm::None; + } + + auto computation = ret.getOperand(0).getDefiningOp(); + if (computation->getNumOperands() != 2 || + computation->getOperand(0) != region.front().getArgument(0) || + computation->getOperand(1) != region.front().getArgument(1)) { + return llvm::None; + } + + if (isa(computation)) { + return xla_cpu::ReductionKind::ALL_REDUCE_SUM; + } + if (isa(computation)) { + return xla_cpu::ReductionKind::ALL_REDUCE_PRODUCT; + } + if (isa(computation)) { + return xla_cpu::ReductionKind::ALL_REDUCE_MIN; + } + if (isa(computation)) { + return xla_cpu::ReductionKind::ALL_REDUCE_MAX; + } + + auto type = computation->getOperandTypes().front().dyn_cast(); + if (!type || !type.getElementType().isInteger(1)) { + return llvm::None; + } + + if (isa(computation)) { + return xla_cpu::ReductionKind::ALL_REDUCE_MIN; + } + if (isa(computation)) { + return xla_cpu::ReductionKind::ALL_REDUCE_MAX; + } + + return llvm::None; +} + +// Returns a `tensor.empty` with the same shape as `tensor`. +Value CreateEmptyLike(OpBuilder& b, Location loc, Value tensor) { + auto ty = tensor.getType().cast(); + auto sizes = tensor::getMixedSizes(b, loc, tensor); + return b.create(loc, sizes, ty.getElementType()); +} + +class AllReduceLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mhlo::AllReduceOp op, + PatternRewriter& rewriter) const override { + auto reduction_kind = MatchReductionComputation(op.getRegion()); + if (!reduction_kind) { + return failure(); + } + + SmallVector dsts; + for (auto operand : op->getOperands()) { + // The operands and results have the same shapes. + dsts.push_back(CreateEmptyLike(rewriter, op.getLoc(), operand)); + } + + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), op->getOperands(), dsts, + op.getReplicaGroupsAttr(), + rewriter.getI64IntegerAttr(op.getChannelHandle() + ? op.getChannelHandle()->getHandle() + : int64_t{0}), + rewriter.getI32IntegerAttr(op.getUseGlobalDeviceIdsAttr() ? 
1 : 0), + rewriter.getI32IntegerAttr(static_cast(*reduction_kind))); + + return success(); + }; +}; + +template +class IdLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(IdOp op, + PatternRewriter& rewriter) const override { + Value id = rewriter.create(op.getLoc()); + // Wrap the scalar in a tensor. + Value id_tensor = rewriter.create( + op.getLoc(), RankedTensorType::get({}, rewriter.getI32Type()), id); + // And convert it to unsigned. This becomes a noop later. + rewriter.replaceOpWithNewOp( + op, + RankedTensorType::get({}, IntegerType::get(rewriter.getContext(), 32, + IntegerType::Unsigned)), + id_tensor); + return success(); + }; +}; + +class CollectivePermuteLowering + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mhlo::CollectivePermuteOp op, + PatternRewriter& rewriter) const override { + // The result of collective_permute has the same shape as the operand. + Value dst = CreateEmptyLike(rewriter, op.getLoc(), op.getOperand()); + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), op->getOperand(0), dst, + op.getSourceTargetPairsAttr(), + rewriter.getI64IntegerAttr(op.getChannelHandle() + ? op.getChannelHandle()->getHandle() + : int64_t{0})); + return success(); + }; +}; + +class AllToAllLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mhlo::AllToAllOp op, + PatternRewriter& rewriter) const override { + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + + SmallVector dsts; + + if (!op.getConcatDimensionAttr()) { + for (auto operand : op->getOperands()) { + // The operands and results of TupleAllToAll the same shapes. + dsts.push_back(CreateEmptyLike(rewriter, op.getLoc(), operand)); + } + } else { + auto sizes = + getAsValues(b, b.getLoc(), + tensor::getMixedSizes(b, op.getLoc(), op->getOperand(0))); + uint64_t split_dimension = *op.getSplitDimension(); + Value split_count = b.create(*op.getSplitCount()); + sizes[split_dimension] = b.createOrFold( + b.getIndexType(), sizes[split_dimension], split_count); + uint64_t concat_dimension = *op.getConcatDimension(); + sizes[concat_dimension] = + b.createOrFold(sizes[concat_dimension], split_count); + + dsts.push_back(rewriter.create( + op.getLoc(), getAsOpFoldResult(sizes), + op->getResultTypes()[0].cast().getElementType())); + } + + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), op->getOperands(), dsts, + op.getReplicaGroupsAttr(), op.getSplitDimensionAttr(), + op.getConcatDimensionAttr(), op.getSplitCountAttr()); + return success(); + }; +}; + +class FftLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mhlo::FftOp op, + PatternRewriter& rewriter) const override { + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + + // TODO(jreiffers): Support dynamic sizes. 
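+    // For now the destination is created with the op's static result shape;
+    // handling dynamic dimensions would require querying the sizes with
+    // tensor.dim, as CreateEmptyLike does above.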
+ auto dst = b.create(op.getLoc(), op.getType().getShape(), + op.getType().getElementType()); + + auto lengths = + llvm::to_vector<3>(op.getFftLengthAttr().getValues()); + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), op->getOperand(0), dst, + static_cast(op.getFftType()), + rewriter.getI64ArrayAttr(lengths)); + return success(); + }; +}; + +class OutfeedLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mhlo::OutfeedOp op, + PatternRewriter& rewriter) const override { + SmallVector result_types; + for (auto operand : op.getInputs()) { + result_types.push_back( + TypeAttr::get(operand.getType().cast().getElementType())); + } + rewriter.create( + op.getLoc(), llvm::None, op.getInputs(), op.getOutfeedConfigAttr(), + ArrayAttr::get(op->getContext(), result_types)); + + // Replacing the op with the token. + rewriter.replaceOp(op, op.getToken()); + return success(); + }; +}; + +class RngBitGeneratorLowering + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mhlo::RngBitGeneratorOp op, + PatternRewriter& rewriter) const override { + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + + auto state_init = CreateEmptyLike(b, op.getLoc(), op.getOperand()); + auto output_init = + b.create(op.getLoc(), op.getType(1), ValueRange{}); + + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), op->getOperand(0), state_init, output_init, + op.getRngAlgorithmAttr()); + return success(); + }; +}; + +class AddDependencyLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mhlo::AddDependencyOp op, + PatternRewriter& rewriter) const override { + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), op->getOperands()); + return success(); + }; +}; + +void LegalizeCollectiveOpsPass::runOnOperation() { + func::FuncOp func = getOperation(); + MLIRContext* ctx = func.getContext(); + + // Convert mhlo collective operations to XLA cpu ops. + RewritePatternSet patterns(ctx); + patterns.insert, + IdLowering, + OutfeedLowering, RngBitGeneratorLowering>(ctx); + + if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) { + return signalPassFailure(); + } +} + +} // namespace + +std::unique_ptr> createLegalizeCollectiveOpsPass() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/mlir/backends/cpu/transforms/legalize_i1_vector_transfers.cc b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/legalize_i1_vector_transfers.cc new file mode 100644 index 00000000000..b108e1f21c8 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/legalize_i1_vector_transfers.cc @@ -0,0 +1,139 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include + +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Vector/IR/VectorOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/xla/mlir/backends/cpu/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu.h" + +namespace xla { +namespace cpu { +namespace { + +#define GEN_PASS_DEF_LEGALIZEI1VECTORTRANSFEROPSPASS +#include "tensorflow/compiler/xla/mlir/backends/cpu/transforms/passes.h.inc" + +using namespace mlir; // NOLINT + +class LegalizeI1VectorTransferOpsPass + : public impl::LegalizeI1VectorTransferOpsPassBase< + LegalizeI1VectorTransferOpsPass> { + void runOnOperation() override; +}; + +Value CastToI8(Value in, ImplicitLocOpBuilder& b, bool optional = false) { + auto ty = in.getType(); + assert(optional || getElementTypeOrSelf(ty).isInteger(1)); + if (!getElementTypeOrSelf(ty).isInteger(1)) { + return {}; + } + + if (auto vec_ty = ty.dyn_cast()) { + return b.create( + vec_ty.cloneWith(std::nullopt, b.getI8Type()), in); + } + if (auto memref_ty = ty.dyn_cast()) { + auto cast_ty = memref_ty.cloneWith(std::nullopt, b.getI8Type()); + return b.create(cast_ty, in); + } + if (ty == b.getI1Type()) { + return b.create(b.getI8Type(), in); + } + return {}; +} + +class I1TransferReadLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferReadOp op, + PatternRewriter& rewriter) const override { + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + b.setInsertionPoint(op); + Value cast_src = CastToI8(op.getSource(), b, /*optional=*/true); + if (!cast_src) { + return failure(); + } + + auto cast_result_ty = + op.getVector().getType().cloneWith(std::nullopt, b.getI8Type()); + TypedValue new_read = + b.create( + TypeRange{cast_result_ty}, cast_src, op.getIndices(), + op.getPermutationMap(), CastToI8(op.getPadding(), b), op.getMask(), + op.getInBoundsAttr()) + .getResult(); + Value zero = b.create( + DenseElementsAttr::get(new_read.getType(), b.getI8IntegerAttr(0))); + auto result = + b.create(arith::CmpIPredicate::ne, new_read, zero); + rewriter.replaceOp(op, {result}); + return success(); + }; +}; + +class I1TransferWriteLowering + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferWriteOp op, + PatternRewriter& rewriter) const override { + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + b.setInsertionPoint(op); + // Confusingly, the destination is called 'source'. 
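+    // If the destination is not an i1 memref, CastToI8 returns a null Value
+    // and this pattern does not apply.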
+ auto cast_dst = CastToI8(op.getSource(), b, /*optional=*/true); + if (!cast_dst) { + return failure(); + } + + op.getVectorMutable().assign(CastToI8(op.getVector(), b)); + op.getSourceMutable().assign(cast_dst); + return success(); + }; +}; + +void LegalizeI1VectorTransferOpsPass::runOnOperation() { + func::FuncOp func = getOperation(); + MLIRContext* ctx = func.getContext(); + + RewritePatternSet patterns(ctx); + patterns.insert(ctx); + // TODO(jreiffers): Handle other transfer ops if we need them (load, + // maskedload, etc.). + + if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) { + return signalPassFailure(); + } +} + +} // namespace + +std::unique_ptr> +createLegalizeI1VectorTransferOpsPass() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/mlir/backends/cpu/transforms/lmhlo_to_cpu_runtime.cc b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/lmhlo_to_cpu_runtime.cc new file mode 100644 index 00000000000..1444d8daece --- /dev/null +++ b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/lmhlo_to_cpu_runtime.cc @@ -0,0 +1,515 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include + +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/xla/mlir/backends/cpu/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir/runtime/transforms/type_converter.h" +#include "tensorflow/compiler/xla/mlir/runtime/utils/custom_calls.h" +#include "tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu.h" +#include "tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" + +namespace xla { +namespace cpu { +namespace { + +#define GEN_PASS_DEF_CONVERTLMHLOTOCPURUNTIMEPASS +#include "tensorflow/compiler/xla/mlir/backends/cpu/transforms/passes.h.inc" + +using namespace mlir; // NOLINT + +using mlir::lmhlo::CustomCallOp; +using mlir::lmhlo::InfeedOp; +using mlir::lmhlo::OutfeedOp; + +using xla_cpu::PartitionIdOp; +using xla_cpu::ReplicaIdOp; + +using xla::runtime::AppendCustomCallAttrs; +using xla::runtime::CustomCallDeclarations; + +class ConvertLmhloToCpuRuntimePass + : public 
impl::ConvertLmhloToCpuRuntimePassBase< + ConvertLmhloToCpuRuntimePass> { + void runOnOperation() override; + + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } +}; + +// Copies memrefs with non-identity layouts (e.g. results of memref.subviews) +// to newly allocated memrefs, ensuring all outputs have flat layouts. +// TODO(jreiffers): If the memref just as an offset, but its layout is otherwise +// default, the copy is overkill. +SmallVector EnsureFlatMemrefs(ValueRange values, + ImplicitLocOpBuilder& b) { + SmallVector out; + for (Value value : values) { + auto ty = value.getType().dyn_cast(); + if (!ty || ty.getLayout().isIdentity()) { + out.push_back(value); + } else { + auto default_layout_ty = + MemRefType::get(ty.getShape(), ty.getElementType()); + auto alloc = + out.emplace_back(b.create(default_layout_ty)); + b.create(value, alloc); + } + } + return out; +} + +// Replaces a DPS style collective op with a custom call. +func::CallOp CreateCallForDpsCollectiveOp(Operation* op, + CustomCallDeclarations& custom_calls, + StringRef call_target, + PatternRewriter& rewriter) { + ImplicitLocOpBuilder b(op->getLoc(), rewriter); + b.setInsertionPoint(op); + + // Subview ops result in strided Memrefs. The runtime can't deal with them, + // so we copy everything that doesn't have the default layout. + SmallVector new_operands = EnsureFlatMemrefs(op->getOperands(), b); + + func::FuncOp callee = custom_calls.GetOrCreate( + b, call_target, TypeRange(ValueRange(new_operands)), TypeRange()); + auto call = + b.create(callee.getName(), TypeRange(), new_operands); + + // Copy attributes from original op. + for (auto& attr : op->getAttrs()) { + call->setAttr(attr.getName(), attr.getValue()); + } + rewriter.eraseOp(op); + return call; +} + +//===----------------------------------------------------------------------===// + +class CustomCallOpLowering : public OpRewritePattern { + private: + static constexpr const char kCustomCallTarget[] = "xla.cpu.custom_call"; + + public: + CustomCallOpLowering(MLIRContext* ctx, CustomCallDeclarations& custom_calls) + : OpRewritePattern(ctx), custom_calls_(custom_calls) {} + + // Rewrite custom call with `API_VERSION_TYPED_FFI` version into XLA runtime + // custom calls bypassing custom call adaptor. + LogicalResult rewriteTypedCustomCall(CustomCallOp op, + PatternRewriter& rewriter) const { + // TODO(ezhulenev): Support target arg mapping, or explain why we do not + // need them for typed custom calls. + if (op.getTargetArgMapping()) + return op.emitOpError( + "API_VERSION_TYPED_FFI custom calls do not " + "support target arg mapping"); + + // Create a custom call function declaration. + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + func::FuncOp callee = + custom_calls_.GetOrCreate(b, op.getCallTargetName(), op); + callee->setAttr("rt.dynamic", UnitAttr::get(b.getContext())); + + // Forward backend config to the custom call implementation. + auto dict = op.getBackendConfig() + ? op.getBackendConfig()->cast() + : nullptr; + llvm::SmallVector backend_config(dict.begin(), dict.end()); + + // Call the custom call function forwarding user-defined attributes. + auto call = rewriter.replaceOpWithNewOp( + op, callee.getName(), TypeRange(), op.getOperands()); + AppendCustomCallAttrs(call, backend_config); + + return success(); + } + + LogicalResult matchAndRewrite(CustomCallOp op, + PatternRewriter& rewriter) const override { + // Typed custom calls lowered directly to XLA runtime custom calls. 
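+    // All other API versions are handled by the generic lowering below.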
+ if (op.getApiVersion() == mhlo::CustomCallApiVersion::API_VERSION_TYPED_FFI) + return rewriteTypedCustomCall(op, rewriter); + + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + + // By default all operands passed to the custom call handler. + llvm::SmallVector operands = op.getOperands(); + + // Get the number of outputs from operand_segment_sizes. + int64_t num_results = op->getAttrOfType( + op.getOperandSegmentSizesAttrName())[1]; + + // If custom call has target arguments mapping, then we need to pass empty + // memrefs in place of holes. + if (op.getTargetArgMapping().has_value()) { + auto mapping = *op.getTargetArgMapping(); + int64_t num_args = mapping.getNumArgs(); + num_results = mapping.getNumResults(); + + // Always create an `alloca` in the parent function entry block. + // See: https://llvm.org/docs/Frontend/PerformanceTips.html#use-of-allocas + Value hole = [&]() -> Value { + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart( + &op->getParentOfType().front()); + return b.create(MemRefType::get({0}, b.getI8Type())); + }(); + + // We represent holes as empty i8 memrefs. + operands = llvm::SmallVector(num_args + num_results, hole); + + // Update operands to mapped custom call arguments. + auto args = mapping.getArgsToTargetArgs(); + for (const auto& indexed : llvm::enumerate(args)) + operands[indexed.value()] = op.getArgs()[indexed.index()]; + + // Update operands to mapped custom call results. + auto res = mapping.getResultsToTargetResults(); + for (const auto& indexed : llvm::enumerate(res)) + operands[num_args + indexed.value()] = op.getOutput()[indexed.index()]; + } + + // TODO(jreiffers): This will break if an output has a non-default layout. + operands = EnsureFlatMemrefs(operands, b); + // Create a custom call function declaration. + func::FuncOp callee = custom_calls_.GetOrCreate( + b, kCustomCallTarget, TypeRange(ValueRange(operands)), TypeRange()); + + // The ABI is different depending on whether the original op was outputting + // a tuple or not. For multiple outputs this is trivial but for a single + // output we rely on the xla_shape attribute to distinguish the ABIs. + bool output_tuple = num_results > 1; + if (auto xla_shape = op->getAttrOfType("xla_shape")) + output_tuple = ParseShape(xla_shape.strref())->IsTuple(); + + // This is not equivalent to op.getApiVersionAttr() - that call returns null + // if the attribute is absent. getApiVersion returns the default. + Attribute api_version = + mhlo::CustomCallApiVersionAttr::get(getContext(), op.getApiVersion()); + llvm::SmallVector custom_call_attrs = { + {b.getStringAttr("num_results"), + b.getI32IntegerAttr(static_cast(num_results))}, + {b.getStringAttr("output_tuple"), b.getBoolAttr(output_tuple)}, + {b.getStringAttr("api_version"), api_version}, + {b.getStringAttr("call_target_name"), op.getCallTargetNameAttr()}}; + + // Call the runtime intrinsic with the original operands. 
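+    // Besides the (possibly remapped and flattened) operands, the call
+    // carries num_results, output_tuple, api_version and the call target
+    // name as attributes.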
+ auto call = rewriter.replaceOpWithNewOp( + op, callee.getName(), TypeRange(), operands); + AppendCustomCallAttrs(call, custom_call_attrs); + + return success(); + } + + private: + CustomCallDeclarations& custom_calls_; +}; + +//===----------------------------------------------------------------------===// + +class InfeedOpLowering : public OpRewritePattern { + private: + static constexpr const char kCallTarget[] = "xla.cpu.infeed"; + + public: + InfeedOpLowering(MLIRContext* ctx, CustomCallDeclarations& custom_calls) + : OpRewritePattern(ctx), custom_calls_(custom_calls) {} + + LogicalResult matchAndRewrite(InfeedOp op, + PatternRewriter& rewriter) const override { + ImplicitLocOpBuilder b(op->getLoc(), rewriter); + + // By default all operands are passed to the custom call handler. + llvm::SmallVector operands = op->getOperands(); + + // Create a custom call function declaration. + func::FuncOp callee = + custom_calls_.GetOrCreate(b, StringRef(kCallTarget), + TypeRange(ValueRange(operands)), TypeRange()); + + // Call the runtime intrinsic with the original operands. + rewriter.replaceOpWithNewOp(op, callee.getName(), TypeRange(), + operands); + return success(); + } + + private: + CustomCallDeclarations& custom_calls_; +}; + +//===----------------------------------------------------------------------===// + +template +class IdOpLowering : public OpRewritePattern { + public: + IdOpLowering(MLIRContext* ctx, llvm::StringRef call_target, + CustomCallDeclarations& custom_calls) + : OpRewritePattern(ctx), + call_target_(call_target), + custom_calls_(custom_calls) {} + + LogicalResult matchAndRewrite(IdOp op, + PatternRewriter& rewriter) const override { + ImplicitLocOpBuilder b(op->getLoc(), rewriter); + + // Create a custom call function declaration. + func::FuncOp callee = custom_calls_.GetOrCreate( + b, call_target_, TypeRange(), TypeRange(rewriter.getI32Type())); + + rewriter.replaceOpWithNewOp(op, callee.getName(), + TypeRange(rewriter.getI32Type())); + return success(); + } + + private: + llvm::StringRef call_target_; + CustomCallDeclarations& custom_calls_; +}; + +//===----------------------------------------------------------------------===// + +class AllReduceLowering : public OpRewritePattern { + public: + AllReduceLowering(MLIRContext* ctx, CustomCallDeclarations& custom_calls) + : OpRewritePattern(ctx), custom_calls_(custom_calls) {} + + LogicalResult matchAndRewrite(xla_cpu::AllReduceOp op, + PatternRewriter& rewriter) const override { + if (!op.getOperandTypes().front().isa()) { + return failure(); + } + + auto call = CreateCallForDpsCollectiveOp(op.getOperation(), custom_calls_, + kCallTarget, rewriter); + + // Set default attributes. 
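+    // If the op did not carry use_global_device_ids or op_id, fill in zeros
+    // so the resulting custom call always has both attributes.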
+ if (!call->hasAttr("use_global_device_ids")) { + call->setAttr("use_global_device_ids", rewriter.getI32IntegerAttr(0)); + } + if (!call->hasAttr("op_id")) { + call->setAttr("op_id", rewriter.getI64IntegerAttr(0)); + } + + return success(); + } + + private: + static constexpr const char kCallTarget[] = "xla.cpu.all_reduce"; + + CustomCallDeclarations& custom_calls_; +}; + +//===----------------------------------------------------------------------===// + +class AllToAllLowering : public OpRewritePattern { + public: + AllToAllLowering(MLIRContext* ctx, CustomCallDeclarations& custom_calls) + : OpRewritePattern(ctx), custom_calls_(custom_calls) {} + + LogicalResult matchAndRewrite(xla_cpu::AllToAllOp op, + PatternRewriter& rewriter) const override { + if (op.getSplitDimensionAttr()) { + op.emitOpError("ArrayAllToAll is not supported"); + return failure(); + } + CreateCallForDpsCollectiveOp(op.getOperation(), custom_calls_, kCallTarget, + rewriter); + return success(); + } + + private: + static constexpr const char kCallTarget[] = "xla.cpu.tuple_all_to_all"; + + CustomCallDeclarations& custom_calls_; +}; + +//===----------------------------------------------------------------------===// + +class CollectivePermuteLowering + : public OpRewritePattern { + public: + CollectivePermuteLowering(MLIRContext* ctx, + CustomCallDeclarations& custom_calls) + : OpRewritePattern(ctx), custom_calls_(custom_calls) {} + + LogicalResult matchAndRewrite(xla_cpu::CollectivePermuteOp op, + PatternRewriter& rewriter) const override { + if (!op.getOperandTypes().front().isa()) { + return failure(); + } + + CreateCallForDpsCollectiveOp(op.getOperation(), custom_calls_, kCallTarget, + rewriter); + return success(); + } + + private: + static constexpr const char kCallTarget[] = "xla.cpu.collective_permute"; + + CustomCallDeclarations& custom_calls_; +}; + +//===----------------------------------------------------------------------===// + +class FftLowering : public OpRewritePattern { + public: + FftLowering(MLIRContext* ctx, CustomCallDeclarations& custom_calls) + : OpRewritePattern(ctx), custom_calls_(custom_calls) {} + + LogicalResult matchAndRewrite(xla_cpu::FftOp op, + PatternRewriter& rewriter) const override { + CreateCallForDpsCollectiveOp(op.getOperation(), custom_calls_, kCallTarget, + rewriter); + return success(); + } + + private: + static constexpr const char kCallTarget[] = "xla.cpu.fft"; + + CustomCallDeclarations& custom_calls_; +}; + +//===----------------------------------------------------------------------===// + +class RngBitGeneratorLowering + : public OpRewritePattern { + public: + RngBitGeneratorLowering(MLIRContext* ctx, + CustomCallDeclarations& custom_calls) + : OpRewritePattern(ctx), custom_calls_(custom_calls) {} + + LogicalResult matchAndRewrite(xla_cpu::RngBitGeneratorOp op, + PatternRewriter& rewriter) const override { + auto algorithm = + op.getRngAlgorithmAttr().cast().getValue(); + op->removeAttr("rng_algorithm"); + + CreateCallForDpsCollectiveOp(op.getOperation(), custom_calls_, + algorithm == mhlo::RngAlgorithm::THREE_FRY + ? 
kThreeFryTarget + : kPhiloxTarget, + rewriter); + return success(); + } + + private: + static constexpr const char kThreeFryTarget[] = "xla.cpu.rng.three_fry"; + static constexpr const char kPhiloxTarget[] = "xla.cpu.rng.philox"; + + CustomCallDeclarations& custom_calls_; +}; + +//===----------------------------------------------------------------------===// + +class OutfeedLowering : public OpRewritePattern { + public: + OutfeedLowering(MLIRContext* ctx, CustomCallDeclarations& custom_calls) + : OpRewritePattern(ctx), custom_calls_(custom_calls) {} + + LogicalResult matchAndRewrite(xla_cpu::OutfeedOp op, + PatternRewriter& rewriter) const override { + ImplicitLocOpBuilder b(op->getLoc(), rewriter); + + // By default all operands are passed to the custom call handler. + llvm::SmallVector operands = EnsureFlatMemrefs(op->getOperands(), b); + + // Create a custom call function declaration. + func::FuncOp callee = + custom_calls_.GetOrCreate(b, StringRef(kCallTarget), + TypeRange(ValueRange(operands)), TypeRange()); + + llvm::SmallVector custom_call_attrs; + SmallVector types; + for (int i = 0; i < op.getResultType().size(); ++i) { + auto type_attr = cast(op.getResultType()[i]); + auto status_or_primitive_type = + xla::runtime::TypeConverter::ConvertElementType(type_attr.getValue()); + if (!status_or_primitive_type.ok()) { + return rewriter.notifyMatchFailure( + op, + "is not provided with a supported primitive type in the result " + "type attribute."); + } + types.push_back(status_or_primitive_type.value()); + } + + // Call the runtime intrinsic with the original operands. + auto call = rewriter.replaceOpWithNewOp( + op, callee.getName(), TypeRange(), operands); + call->setAttr("result_type", b.getI32ArrayAttr(types)); + + return success(); + } + + private: + static constexpr const char kCallTarget[] = "xla.cpu.outfeed"; + + CustomCallDeclarations& custom_calls_; +}; + +//===----------------------------------------------------------------------===// + +void ConvertLmhloToCpuRuntimePass::runOnOperation() { + ModuleOp module = getOperation(); + MLIRContext* ctx = module.getContext(); + + // Keep track of the custom calls created from the lowered operations. + SymbolTable sym_table(module); + CustomCallDeclarations custom_calls(std::move(sym_table)); + + // Convert lmhlo operations to XLA cpu runtime custom calls. + RewritePatternSet patterns(ctx); + patterns.insert( + ctx, custom_calls); + patterns.insert>(ctx, "xla.cpu.partition_id", + custom_calls); + patterns.insert>(ctx, "xla.cpu.replica_id", + custom_calls); + + if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) + return signalPassFailure(); +} + +} // namespace + +std::unique_ptr> +createConvertLmhloToCpuRuntimePass() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/mlir/transforms/cpu/passes.h b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/passes.h similarity index 67% rename from tensorflow/compiler/xla/mlir/transforms/cpu/passes.h rename to tensorflow/compiler/xla/mlir/backends/cpu/transforms/passes.h index f9b5096fb60..34aae02bdcb 100644 --- a/tensorflow/compiler/xla/mlir/transforms/cpu/passes.h +++ b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/passes.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_MLIR_TRANSFORMS_CPU_PASSES_H_ -#define TENSORFLOW_COMPILER_XLA_MLIR_TRANSFORMS_CPU_PASSES_H_ +#ifndef TENSORFLOW_COMPILER_XLA_MLIR_BACKENDS_CPU_TRANSFORMS_PASSES_H_ +#define TENSORFLOW_COMPILER_XLA_MLIR_BACKENDS_CPU_TRANSFORMS_PASSES_H_ #include @@ -35,12 +35,24 @@ createConvertLmhloToCpuRuntimePass(); std::unique_ptr> createXlaAbiLegalizationPass(); +std::unique_ptr> +createLegalizeCollectiveOpsPass(); + +std::unique_ptr> +createLegalizeI1VectorTransferOpsPass(); + +std::unique_ptr> +createConvertXlaCpuMemRefElementCastToLLVMPass(); + +std::unique_ptr> +createRemoveCopiesToOutParamsPass(); + //===-----------------------------------------------------------------------===/ #define GEN_PASS_REGISTRATION -#include "tensorflow/compiler/xla/mlir/transforms/cpu/passes.h.inc" +#include "tensorflow/compiler/xla/mlir/backends/cpu/transforms/passes.h.inc" } // namespace cpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_MLIR_TRANSFORMS_CPU_PASSES_H_ +#endif // TENSORFLOW_COMPILER_XLA_MLIR_BACKENDS_CPU_TRANSFORMS_PASSES_H_ diff --git a/tensorflow/compiler/xla/mlir/transforms/cpu/passes.td b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/passes.td similarity index 53% rename from tensorflow/compiler/xla/mlir/transforms/cpu/passes.td rename to tensorflow/compiler/xla/mlir/backends/cpu/transforms/passes.td index 19a8be700de..50aa6ff516e 100644 --- a/tensorflow/compiler/xla/mlir/transforms/cpu/passes.td +++ b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/passes.td @@ -51,4 +51,61 @@ def LegalizeXlaAbiPass : let constructor = "createXlaAbiLegalizationPass()"; } +def LegalizeCollectiveOpsPass : + Pass<"xla-legalize-collective-ops", "mlir::func::FuncOp"> { + let summary = "Legalizes collective ops to AllToAll and regular ops."; + + let description = [{ + Lowers collective ops to xla_cpu ops. + }]; + + let dependentDialects = [ + "mlir::mhlo::MhloDialect", "mlir::xla_cpu::XlaCpuDialect" + ]; + + let constructor = "createLegalizeCollectiveOpsPass()"; +} + +def LegalizeI1VectorTransferOpsPass : + Pass<"xla-legalize-i1-vector-transfers", "mlir::func::FuncOp"> { + let summary = "Legalizes transfer ops operating on vectors of i1."; + + let description = [{ + Replaces transfers involving vectors with memref casts to i8, and + vector comparisons. + }]; + + let dependentDialects = [ + "mlir::vector::VectorDialect", "mlir::xla_cpu::XlaCpuDialect" + ]; + + let constructor = "createLegalizeI1VectorTransferOpsPass()"; +} + +def ConvertXlaCpuMemRefElementCastToLLVMPass : + Pass<"xla-convert-memref-element-cast-to-llvm", "mlir::func::FuncOp"> { + let summary = "Converts xla_cpu.memref_element_cast ops to LLVM."; + + let description = [{ + Rewrites xla_cpu.memref_elements_cast ops as a new memref descriptor, + where the allocated and aligned pointers are updated. + }]; + + let dependentDialects = ["mlir::LLVM::LLVMDialect"]; + + let constructor = "createConvertXlaCpuMemRefElementCastToLLVMPass()"; +} + +def RemoveCopiesToOutParamsPass : + Pass<"xla-remove-copies-to-out-params", "mlir::func::FuncOp"> { + let summary = "Removes redundant alloc/copy pairs to out params."; + + let description = [{ + Removes redundant alloc/alloca + copy pairs that can remain after running + bufferization's BufferResultsToOutParams pass. 
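+
+    For example (illustrative sketch, assuming %out is an out-parameter
+    introduced by BufferResultsToOutParams):
+
+      %alloc = memref.alloc() : memref<8xf32>
+      ...                                      // %alloc is written here
+      memref.copy %alloc, %out : memref<8xf32> to memref<8xf32>
+      memref.dealloc %alloc : memref<8xf32>
+
+    becomes a direct use of %out, with the alloc, copy, and dealloc removed.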
+ }]; + + let constructor = "createRemoveCopiesToOutParamsPass()"; +} + #endif // TENSORFLOW_COMPILER_XLA_MLIR_TRANSFORMS_CPU_PASSES_TD_ diff --git a/tensorflow/compiler/xla/mlir/backends/cpu/transforms/remove_copies_to_out_params.cc b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/remove_copies_to_out_params.cc new file mode 100644 index 00000000000..3d1d7d813ff --- /dev/null +++ b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/remove_copies_to_out_params.cc @@ -0,0 +1,129 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/xla/mlir/backends/cpu/transforms/passes.h" + +namespace xla { +namespace cpu { +namespace { + +using ::mlir::LogicalResult; +using ::mlir::Operation; +using ::mlir::OperationPass; +using ::mlir::PatternRewriter; +using ::mlir::RewritePatternSet; +using ::mlir::Value; + +namespace memref = ::mlir::memref; +namespace func = ::mlir::func; + +#define GEN_PASS_DEF_REMOVECOPIESTOOUTPARAMSPASS +#include "tensorflow/compiler/xla/mlir/backends/cpu/transforms/passes.h.inc" + +LogicalResult AllocRemoval(memref::CopyOp copy, PatternRewriter &rewriter) { + Value from = copy.getSource(); + Value to = copy.getTarget(); + + Operation *alloc = + llvm::dyn_cast_or_null(from.getDefiningOp()); + if (!alloc) { + return mlir::failure(); + } + + // Only match if we dealloc immediately after the copy. + auto dealloc = llvm::dyn_cast_or_null(copy->getNextNode()); + if (!dealloc || dealloc.getMemref() != from) { + return mlir::failure(); + } + + // Only go up one level to grab the parent function; the match we're looking + // for is at the very end of a function. + auto func = llvm::dyn_cast_or_null(copy->getParentOp()); + if (!func) { + return mlir::failure(); + } + + // If the copy target is a function argument, use it directly. + if (llvm::is_contained(func.getArguments(), to)) { + rewriter.replaceAllUsesWith(from, to); + rewriter.eraseOp(alloc); + rewriter.eraseOp(dealloc); + rewriter.eraseOp(copy); + return mlir::success(); + } + return mlir::failure(); +} + +LogicalResult AllocaRemoval(memref::CopyOp copy, PatternRewriter &rewriter) { + Value from = copy.getSource(); + Value to = copy.getTarget(); + + Operation *alloca = + llvm::dyn_cast_or_null(from.getDefiningOp()); + if (!alloca) { + return mlir::failure(); + } + + // Only go up one level to grab the parent function; the match we're looking + // for is at the very end of a function. 
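+  // Unlike AllocRemoval above, there is no dealloc to match: allocas are
+  // freed implicitly when the function returns.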
+ auto func = llvm::dyn_cast_or_null(copy->getParentOp()); + if (!func) { + return mlir::failure(); + } + + // If the copy target is a function argument, use it directly. + if (llvm::is_contained(func.getArguments(), to)) { + rewriter.replaceAllUsesWith(from, to); + rewriter.eraseOp(alloca); + rewriter.eraseOp(copy); + return mlir::success(); + } + return mlir::failure(); +} + +class RemoveCopiesToOutParamsPass + : public impl::RemoveCopiesToOutParamsPassBase< + RemoveCopiesToOutParamsPass> { + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + patterns.add(AllocRemoval); + patterns.add(AllocaRemoval); + if (failed(mlir::applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr> +createRemoveCopiesToOutParamsPass() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/BUILD b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/BUILD new file mode 100644 index 00000000000..e65d104f2d4 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/BUILD @@ -0,0 +1,24 @@ +load("//tensorflow/tsl:tsl.default.bzl", "filegroup") +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) + +glob_lit_tests( + data = [":test_utilities"], + driver = "//tensorflow/compiler/xla:run_lit.sh", + test_file_exts = ["mlir"], +) + +# Bundle together all of the test utilities that are used by tests. +filegroup( + name = "test_utilities", + testonly = True, + data = [ + "//tensorflow/compiler/xla/mlir/backends/cpu:xla-cpu-opt", + "@llvm-project//llvm:FileCheck", + "@llvm-project//mlir:run_lit.sh", + ], +) diff --git a/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/collective_ops.mlir b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/collective_ops.mlir new file mode 100644 index 00000000000..f02a4eac6ab --- /dev/null +++ b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/collective_ops.mlir @@ -0,0 +1,256 @@ +// RUN: xla-cpu-opt %s -xla-legalize-collective-ops | FileCheck %s + +func.func @max_reduce(%arg0: tensor<10xf32>) -> tensor<10xf32> { + %0 = "mhlo.all_reduce"(%arg0) ({ + ^bb0(%lhs: tensor, %rhs: tensor): + %max = mhlo.maximum %lhs, %rhs : tensor + "mhlo.return"(%max) : (tensor) -> () + }) + { + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, + channel_handle = #mhlo.channel_handle< + handle = 5, + type = 2 + >, + use_global_device_ids + } : (tensor<10xf32>) -> tensor<10xf32> + func.return %0 : tensor<10xf32> +} + +// CHECK-LABEL: @max_reduce +// CHECK-SAME: %[[ARG0:.*]]: tensor<10xf32> +// CHECK: %[[DST:.*]] = tensor.empty() : tensor<10xf32> +// CHECK: %[[RET:.*]] = "xla_cpu.all_reduce"(%[[ARG0]], %[[DST]]) { +// CHECK-SAME: channel_handle = 5 : i64, +// CHECK-SAME: reduction_kind = 3 : i32, +// CHECK-SAME: replica_groups = dense<{{\[}}[0, 2, 4, 6], [1, 3, 5, 7]]> +// CHECK-SAME: use_global_device_ids = 1 +// CHECK: return %[[RET]] + +func.func @and_reduce(%arg0: tensor<1xi1>) -> tensor<1xi1> { + %0 = "mhlo.all_reduce"(%arg0) ({ + ^bb0(%lhs: tensor, %rhs: tensor): + %1 = mhlo.and %lhs, %rhs : tensor + mhlo.return %1 : tensor + }) { + replica_groups = dense<> : tensor<0x0xi64> + } : (tensor<1xi1>) -> tensor<1xi1> + func.return %0 : tensor<1xi1> +} + +// CHECK-LABEL: @and_reduce +// CHECK: 
reduction_kind = 2 : i32, + +func.func @or_reduce(%arg0: tensor<1xi1>) -> tensor<1xi1> { + %0 = "mhlo.all_reduce"(%arg0) ({ + ^bb0(%lhs: tensor, %rhs: tensor): + %1 = mhlo.or %lhs, %rhs : tensor + mhlo.return %1 : tensor + }) { + replica_groups = dense<> : tensor<0x0xi64> + } : (tensor<1xi1>) -> tensor<1xi1> + func.return %0 : tensor<1xi1> +} + +// CHECK-LABEL: @or_reduce +// CHECK: reduction_kind = 3 : i32, + +func.func @min_reduce_dynamic(%arg0: tensor) -> tensor { + %0 = "mhlo.all_reduce"(%arg0) ({ + ^bb0(%lhs: tensor, %rhs: tensor): + %max = mhlo.minimum %lhs, %rhs : tensor + "mhlo.return"(%max) : (tensor) -> () + }) + { + replica_groups = dense<> : tensor<0x0xi64>, + channel_handle = #mhlo.channel_handle< + handle = 5, + type = 2 + > + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: @min_reduce +// CHECK-SAME: %[[ARG0:.*]]: tensor +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 +// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK: %[[DST:.*]] = tensor.empty(%[[DIM]]) +// CHECK: "xla_cpu.all_reduce"(%[[ARG0]], %[[DST]]) +// CHECK-SAME: reduction_kind = 2 +// CHECK-SAME: use_global_device_ids = 0 + +func.func @partition_id() -> tensor { + %0 = "mhlo.partition_id"() : () -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: @partition_id +// CHECK: %[[ID:.*]] = "xla_cpu.partition_id"() : () -> i32 +// CHECK: %[[TENSOR:.*]] = tensor.from_elements %[[ID]] : tensor +// CHECK: %[[CAST:.*]] = mhlo.convert %[[TENSOR]] : (tensor) -> tensor +// CHECK: return %[[CAST]] + +func.func @replica_id() -> tensor { + %0 = "mhlo.replica_id"() : () -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: @replica_id +// CHECK: %[[ID:.*]] = "xla_cpu.replica_id"() : () -> i32 +// CHECK: %[[TENSOR:.*]] = tensor.from_elements %[[ID]] : tensor +// CHECK: %[[CAST:.*]] = mhlo.convert %[[TENSOR]] : (tensor) -> tensor +// CHECK: return %[[CAST]] + +func.func @collective_permute(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> { + %0 = "mhlo.collective_permute"(%arg0) { + source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>, + channel_handle = #mhlo.channel_handle + } : (tensor<16x8xf32>) -> tensor<16x8xf32> + func.return %0 : tensor<16x8xf32> +} + +// CHECK-LABEL: @collective_permute +// CHECK-SAME: %[[ARG0:.*]]: tensor<16x8xf32> +// CHECK: %[[DST:.*]] = tensor.empty() : tensor<16x8xf32> +// CHECK: %[[RET:.*]] = "xla_cpu.collective_permute"(%[[ARG0]], %[[DST]]) { +// CHECK-SAME: channel_handle = 1 +// CHECK-SAME: source_target_pairs = dense< +// CHECK: return %[[RET]] + +func.func @collective_permute_dynamic(%arg0: tensor<16x?xf32>) + -> tensor<16x?xf32> { + %0 = "mhlo.collective_permute"(%arg0) { + source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>, + channel_handle = #mhlo.channel_handle + } : (tensor<16x?xf32>) -> tensor<16x?xf32> + func.return %0 : tensor<16x?xf32> +} + +// CHECK-LABEL: @collective_permute_dynamic +// CHECK-SAME: %[[ARG0:.*]]: tensor<16x?xf32> +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 +// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C1]] +// CHECK: %[[DST:.*]] = tensor.empty(%[[DIM]]) : tensor<16x?xf32> +// CHECK: "xla_cpu.collective_permute"(%[[ARG0]], %[[DST]]) { + +func.func @all_to_all(%arg0: tensor<4x16xf32>) -> tensor<16x4xf32> { + %0 = "mhlo.all_to_all"(%arg0) { + split_dimension = 1 : i64, + concat_dimension = 0 : i64, + split_count = 4 : i64, + replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64> + } : (tensor<4x16xf32>) -> tensor<16x4xf32> + func.return %0 : tensor<16x4xf32> +} + +// CHECK-LABEL: 
@all_to_all +// CHECK-SAME: %[[ARG0:.*]]: tensor<4x16xf32> +// CHECK: %[[DST:.*]] = tensor.empty() : tensor<16x4xf32> +// CHECK: %[[RET:.*]] = "xla_cpu.all_to_all"(%[[ARG0]], %[[DST]]) { +// CHECK-SAME: concat_dimension = 0 +// CHECK-SAME: replica_groups = dense< +// CHECK-SAME: split_count = 4 +// CHECK-SAME: split_dimension = 1 +// CHECK: return %[[RET]] + +func.func @all_to_all_dynamic_concat_dim(%arg0: tensor) + -> tensor { + %0 = "mhlo.all_to_all"(%arg0) { + split_dimension = 1 : i64, + concat_dimension = 0 : i64, + split_count = 4 : i64, + replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64> + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: @all_to_all_dynamic_concat_dim +// CHECK-SAME: %[[ARG0:.*]]: tensor +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 +// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK: %[[CONCAT_DIM:.*]] = arith.muli %[[DIM]], %[[C4]] +// CHECK: %[[DST:.*]] = tensor.empty(%[[CONCAT_DIM]]) : tensor +// CHECK: "xla_cpu.all_to_all"(%[[ARG0]], %[[DST]]) { + +func.func @all_to_all_dynamic_split_dim(%arg0: tensor<4x?xf32>) + -> tensor<16x?xf32> { + %0 = "mhlo.all_to_all"(%arg0) { + split_dimension = 1 : i64, + concat_dimension = 0 : i64, + split_count = 4 : i64, + replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64> + } : (tensor<4x?xf32>) -> tensor<16x?xf32> + func.return %0 : tensor<16x?xf32> +} + +// CHECK-LABEL: @all_to_all_dynamic_split_dim +// CHECK-SAME: %[[ARG0:.*]]: tensor<4x?xf32> +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 +// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C1]] +// CHECK: %[[CONCAT_DIM:.*]] = arith.divui %[[DIM]], %[[C4]] +// CHECK: %[[DST:.*]] = tensor.empty(%[[CONCAT_DIM]]) : tensor<16x?xf32> +// CHECK: "xla_cpu.all_to_all"(%[[ARG0]], %[[DST]]) { + +func.func @all_to_all_tuple(%arg0: tensor<128x4xf32>, %arg1: tensor<128x4xf32>) + -> (tensor<128x4xf32>, tensor<128x4xf32>) { + %0:2 = "mhlo.all_to_all"(%arg0, %arg1) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> + } : (tensor<128x4xf32>, tensor<128x4xf32>) -> (tensor<128x4xf32>, tensor<128x4xf32>) + return %0#0, %0#1 : tensor<128x4xf32>, tensor<128x4xf32> +} + +// CHECK-LABEL: @all_to_all_tuple +// CHECK-SAME: %[[ARG0:.*]]: tensor<128x4xf32>, +// CHECK-SAME: %[[ARG1:.*]]: tensor<128x4xf32> +// CHECK: %[[DST0:.*]] = tensor.empty() : tensor<128x4xf32> +// CHECK: %[[DST1:.*]] = tensor.empty() : tensor<128x4xf32> +// CHECK: "xla_cpu.all_to_all"(%[[ARG0]], %[[ARG1]], %[[DST0]], %[[DST1]]) + +func.func @outfeed_0_input(%token: !mhlo.token) -> !mhlo.token { + %res = "mhlo.outfeed"(%token) {outfeed_config = "foobar"} : (!mhlo.token) -> !mhlo.token + func.return %res : !mhlo.token +} + +// CHECK-LABEL: @outfeed_0_input +// CHECK: "xla_cpu.outfeed"() {config = "foobar", result_type = []} : () -> () + +func.func @outfeed_1_input(%data: tensor<2xui32>, %token: !mhlo.token) + -> !mhlo.token attributes {xlaframework.result_mapping = 1 : i32} { + %res = "mhlo.outfeed"(%data, %token) { + outfeed_config = "", xla_shape = "token[]" + } : (tensor<2xui32>, !mhlo.token) -> !mhlo.token + func.return %res : !mhlo.token +} + +// CHECK-LABEL: @outfeed_1_input +// CHECK-SAME: %[[DATA:.*]]: tensor<2xui32> +// CHECK-SAME: %[[TOKEN:.*]]: !mhlo.token +// CHECK: "xla_cpu.outfeed"(%[[DATA]]) {config = "", result_type = [ui32]} : (tensor<2xui32>) -> () +// CHECK: return %[[TOKEN]] : !mhlo.token + +func.func @outfeed_2_input(%data1: tensor<3xui32>, %data2: tensor<3xi32>, %token: 
!mhlo.token) -> !mhlo.token { + %res = "mhlo.outfeed"(%data1, %data2, %token) {outfeed_config = "foobar"} + : (tensor<3xui32>, tensor<3xi32>, !mhlo.token) -> !mhlo.token + func.return %res : !mhlo.token +} + +// CHECK-LABEL: @outfeed_2_input +// CHECK-SAME: %[[ARG0:.*]]: tensor<3xui32> +// CHECK-SAME: %[[ARG1:.*]]: tensor<3xi32> +// CHECK: "xla_cpu.outfeed"(%[[ARG0]], %[[ARG1]]) {config = "foobar", result_type = [ui32, i32]} +// CHECK-SAME: (tensor<3xui32>, tensor<3xi32>) + +func.func @add_dependency(%arg0: tensor<16xf32>, %arg1: !mhlo.token) -> tensor<16xf32> { + %0 = "mhlo.add_dependency"(%arg0, %arg1) : (tensor<16xf32>, !mhlo.token) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: @add_dependency +// CHECK-SAME: %[[ARG0:.*]]: tensor<16xf32> +// CHECK-SAME: %[[ARG1:.*]]: !mhlo.token +// CHECK: %[[RES:.*]] = "xla_cpu.add_dependency" +// CHECK-SAME: %[[ARG0]], %[[ARG1]] +// CHECK: return %[[RES]] : tensor<16xf32> \ No newline at end of file diff --git a/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/collective_ops_to_cpu_runtime.mlir b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/collective_ops_to_cpu_runtime.mlir new file mode 100644 index 00000000000..f50cc3752f3 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/collective_ops_to_cpu_runtime.mlir @@ -0,0 +1,102 @@ +// RUN: xla-cpu-opt %s -split-input-file -xla-lmhlo-to-cpu-runtime | FileCheck %s + +func.func @partition_id() -> i32 { + %0 = "xla_cpu.partition_id"() : () -> i32 + func.return %0 : i32 +} + +// CHECK-LABEL: @partition_id +// CHECK: call @xla.cpu.partition_id() : () -> i32 + +// CHECK: func private @xla.cpu.partition_id() -> i32 attributes {rt.custom_call = "xla.cpu.partition_id"} + +// ----- + +func.func @replica_id() -> i32 { + %0 = "xla_cpu.replica_id"() : () -> i32 + func.return %0 : i32 +} + +// CHECK-LABEL: @replica_id +// CHECK: call @xla.cpu.replica_id() : () -> i32 + +// CHECK: func private @xla.cpu.replica_id() -> i32 attributes {rt.custom_call = "xla.cpu.replica_id"} + +// ----- + +#map = affine_map<(d0)[s0] -> (d0 + s0)> +func.func @all_reduce(%arg0: memref<32xf32, #map>, %arg1: memref<32xf32>) { + "xla_cpu.all_reduce"(%arg0, %arg1) { + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, + channel_handle = 42 : i64, + reduction_kind = 3 : i32, + use_global_device_ids = 0 : i32 + } : (memref<32xf32, #map>, memref<32xf32>) -> () + func.return +} + +// CHECK-LABEL: @all_reduce +// CHECK-SAME: %[[ARG0:.*]]: memref<32xf32, +// CHECK-SAME: %[[ARG1:.*]]: memref<32xf32> +// CHECK: %[[ALLOC:.*]] = memref.alloc +// CHECK: memref.copy %[[ARG0]], %[[ALLOC]] +// CHECK: call @xla.cpu.all_reduce(%[[ALLOC]], %[[ARG1]]) +// CHECK-SAME: channel_handle = 42 +// CHECK-SAME: op_id = 0 +// CHECK-SAME: reduction_kind = 3 +// CHECK-SAME: replica_groups = dense< +// CHECK: func.func private @xla.cpu.all_reduce( +// CHECK-SAME: memref<32xf32>, memref<32xf32>) +// CHECK-SAME: attributes {rt.custom_call = "xla.cpu.all_reduce"} + + +// ----- + +func.func @collective_permute(%arg0: memref<16x8xf32>, %arg1: memref<16x8xf32>) { + "xla_cpu.collective_permute"(%arg0, %arg1) { + source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>, + channel_handle = 42 : i64 + } : (memref<16x8xf32>, memref<16x8xf32>) -> () + func.return +} + +// CHECK-LABEL: @collective_permute +// CHECK-SAME: %[[ARG0:.*]]: memref<16x8xf32>, +// CHECK-SAME: %[[ARG1:.*]]: memref<16x8xf32> +// CHECK: call @xla.cpu.collective_permute(%[[ARG0]], %[[ARG1]]) +// 
CHECK-SAME: channel_handle = 42 +// CHECK-SAME: source_target_pairs = dense< +// CHECK: func.func private @xla.cpu.collective_permute( +// CHECK-SAME: attributes {rt.custom_call = "xla.cpu.collective_permute"} + +// ----- + +func.func @rng_bit_generator_default(%state: memref<3xui64>, + %state_out: memref<3xui64>, %values_out: memref<10xui32>) { + "xla_cpu.rng_bit_generator"(%state, %state_out, %values_out) + {rng_algorithm = #mhlo.rng_algorithm + } : (memref<3xui64>, memref<3xui64>, memref<10xui32>) -> () + return +} + +// CHECK-LABEL: @rng_bit_generator_default +// CHECK-SAME: %[[ARG0:.*]]: memref<3xui64>, %[[ARG1:.*]]: memref<3xui64>, +// CHECK-SAME: %[[ARG2:.*]]: memref<10xui32> +// CHECK: call @xla.cpu.rng.philox(%[[ARG0]], %[[ARG1]], %[[ARG2]]) +// CHECK: func.func private @xla.cpu.rng.philox( +// CHECK-SAME: attributes {rt.custom_call = "xla.cpu.rng.philox"} + +// ----- + +func.func @rng_bit_generator_three_fry(%state: memref<2xui64>, + %state_out: memref<2xui64>, %values_out: memref<10xui32>) { + "xla_cpu.rng_bit_generator"(%state, %state_out, %values_out) + {rng_algorithm = #mhlo.rng_algorithm + } : (memref<2xui64>, memref<2xui64>, memref<10xui32>) -> () + return +} + +// CHECK-LABEL: @rng_bit_generator_three_fry +// CHECK: call @xla.cpu.rng.three_fry( +// CHECK: func.func private @xla.cpu.rng.three_fry( +// CHECK-SAME: attributes {rt.custom_call = "xla.cpu.rng.three_fry"} diff --git a/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/fft.mlir b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/fft.mlir new file mode 100644 index 00000000000..a914a754148 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/fft.mlir @@ -0,0 +1,16 @@ +// RUN: xla-cpu-opt %s -xla-legalize-collective-ops | FileCheck %s + +func.func @fft(%arg0: tensor<3x5x4x8x256xf32>) -> tensor<3x5x4x8x129xcomplex> { + %0 = "mhlo.fft"(%arg0) { + fft_length = dense<[4, 8, 256]> : tensor<3xi64>, + fft_type = #mhlo + } : (tensor<3x5x4x8x256xf32>) -> tensor<3x5x4x8x129xcomplex> + func.return %0 : tensor<3x5x4x8x129xcomplex> +} + +// CHECK-LABEL: @fft +// CHECK-SAME: %[[ARG0:.*]]: tensor +// CHECK: %[[DST:.*]] = tensor.empty() : tensor<3x5x4x8x129xcomplex> +// CHECK: %[[FFT:.*]] = "xla_cpu.fft"(%[[ARG0]], %[[DST]]) +// CHECK-SAME: {fft_length = [4, 8, 256], fft_type = 2 : i32} +// CHECK: return %[[FFT]] diff --git a/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/legalize_i1_vector_transfers.mlir b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/legalize_i1_vector_transfers.mlir new file mode 100644 index 00000000000..39914b38770 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/legalize_i1_vector_transfers.mlir @@ -0,0 +1,35 @@ +// RUN: xla-cpu-opt %s -split-input-file -xla-legalize-i1-vector-transfers \ +// RUN: | FileCheck %s + +func.func @transfer_read(%in: memref<8xi1>) -> vector<8xi1> { + %pad = arith.constant true + %c1 = arith.constant 1 : index + %ret = vector.transfer_read %in[%c1], %pad : memref<8xi1>, vector<8xi1> + return %ret : vector<8xi1> +} + +// CHECK-LABEL: @transfer_read +// CHECK-SAME: %[[IN:.*]]: memref<8xi1> +// CHECK-DAG: %[[C1_I8:.*]] = arith.constant 1 : i8 +// CHECK-DAG: %[[C0_V:.*]] = arith.constant dense<0> : vector<8xi8> +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[CAST:.*]] = xla_cpu.memref_element_cast %[[IN]] +// CHECK: %[[READ:.*]] = vector.transfer_read %[[CAST]][%[[C1]]], +// CHECK-SAME: %[[C1_I8]] +// CHECK: %[[RET:.*]] = arith.cmpi ne, %[[READ]], %[[C0_V]] +// CHECK: 
return %[[RET]] + +func.func @transfer_write(%in: vector<8xi1>, %out: memref<8xi1>) { + %c0 = arith.constant 0 : index + vector.transfer_write %in, %out[%c0] : vector<8xi1>, memref<8xi1> + return +} + +// CHECK-LABEL: @transfer_write +// CHECK-SAME: %[[IN:.*]]: vector<8xi1> +// CHECK-SAME: %[[OUT:.*]]: memref<8xi1> +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[CAST_IN:.*]] = arith.extui %[[IN]] {{.*}} to vector<8xi8> +// CHECK-DAG: %[[CAST_OUT:.*]] = xla_cpu.memref_element_cast %[[OUT]] +// CHECK-NOT: vector.transfer_write {{.*}}%[[IN]] +// CHECK: vector.transfer_write %[[CAST_IN]], %[[CAST_OUT]][%[[C0]]] diff --git a/tensorflow/compiler/xla/mlir/transforms/cpu/tests/lmhlo_custom_call.mlir b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/lmhlo_custom_call.mlir similarity index 72% rename from tensorflow/compiler/xla/mlir/transforms/cpu/tests/lmhlo_custom_call.mlir rename to tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/lmhlo_custom_call.mlir index 05130d71154..12162cdfe90 100644 --- a/tensorflow/compiler/xla/mlir/transforms/cpu/tests/lmhlo_custom_call.mlir +++ b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/lmhlo_custom_call.mlir @@ -9,8 +9,9 @@ func.func @test(%arg0: memref) { // CHECK-SAME: api_version = 2 : i32 // CHECK-SAME: call_target_name = "target" // CHECK-SAME: num_results = 1 : i32 + // CHECK-SAME: output_tuple = false // CHECK-SAME: : (memref) -> () - "lmhlo.custom_call"(%arg0) { + "lmhlo.custom_call"(%arg0) ({}) { api_version = 2 : i32, call_target_name = "target", operand_segment_sizes = array @@ -43,7 +44,8 @@ func.func @test_with_mapping( // CHECK-SAME: api_version = 1 : i32 // CHECK-SAME: call_target_name = "target" // CHECK-SAME: num_results = 4 : i32 - "lmhlo.custom_call"(%arg0, %arg1, %arg2, %arg3, %arg4) { + // CHECK-SAME: output_tuple = true + "lmhlo.custom_call"(%arg0, %arg1, %arg2, %arg3, %arg4) ({}) { api_version = 1 : i32, call_target_name = "target", operand_segment_sizes = array, @@ -61,3 +63,24 @@ func.func @test_with_mapping( // CHECK-SAME: memref, memref<0xi8>, memref, memref, // CHECK-SAME: memref<0xi8>, memref) // CHECK-SAME: attributes {rt.custom_call = "xla.cpu.custom_call"} + +// ----- + +// CHECK: func @one_element_output_tuple +// CHECK: %[[ARG0:.*]]: memref +// CHECK: ) +func.func @one_element_output_tuple(%arg0: memref) { + // CHECK: call @[[CUSTOM_CALL:.*]](%[[ARG0]]) + // CHECK-SAME: api_version = 2 : i32 + // CHECK-SAME: call_target_name = "target" + // CHECK-SAME: num_results = 1 : i32 + // CHECK-SAME: output_tuple = true + // CHECK-SAME: : (memref) -> () + "lmhlo.custom_call"(%arg0) ({}) { + api_version = 2 : i32, + call_target_name = "target", + operand_segment_sizes = array, + xla_shape = "(f32[])" + } : (memref) -> () + return +} diff --git a/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/lmhlo_infeed.mlir b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/lmhlo_infeed.mlir new file mode 100644 index 00000000000..5db2be4725e --- /dev/null +++ b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/lmhlo_infeed.mlir @@ -0,0 +1,13 @@ +// RUN: xla-cpu-opt %s -xla-lmhlo-to-cpu-runtime | FileCheck %s + +// CHECK: func @cpu_infeed( +// CHECK: %[[ARG0:[a-z0-9]+]]: memref<8xf32> +// CHECK: ) +func.func @cpu_infeed(%arg0: memref<8xf32>) { + // CHECK: call @[[INFEED:.*]](%[[ARG0]]) : (memref<8xf32>) -> () + "lmhlo.infeed"(%arg0) {config = "abc"} : (memref<8xf32>) -> () + return +} + +// CHECK: func private @[[INFEED]](memref<8xf32>) +// CHECK-SAME: attributes 
{rt.custom_call = "xla.cpu.infeed"} diff --git a/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/remove_copies_to_out_params.mlir b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/remove_copies_to_out_params.mlir new file mode 100644 index 00000000000..fbe3b502ca6 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/remove_copies_to_out_params.mlir @@ -0,0 +1,127 @@ +// RUN: xla-cpu-opt %s -split-input-file -xla-remove-copies-to-out-params \ +// RUN: | FileCheck %s + +func.func @alloca(%arg0: memref, %arg1: memref) { + %0 = memref.load %arg0[] : memref + %1 = arith.addf %0, %0 : f64 + %alloca = memref.alloca() : memref + memref.store %1, %alloca[] : memref + memref.copy %alloca, %arg1 : memref to memref + return +} + +// CHECK-LABEL: func.func @alloca( +// CHECK-SAME: %[[ARG0:.*]]: memref, +// CHECK-SAME: %[[ARG1:.*]]: memref) { +// CHECK: %[[R0:.*]] = memref.load %[[ARG0]][] : memref +// CHECK: %[[R1:.*]] = arith.addf %[[R0]], %[[R0]] : f64 +// CHECK-NOT memref.alloca +// CHECK: memref.store %[[R1]], %[[ARG1]][] : memref +// CHECK-NOT: memref.copy +// CHECK-NEXT: return +// CHECK: } + +// ----- + +func.func @alloc_vectorized(%arg0: memref<1024xf64>, %arg1: memref<1024xf64>) { + %c1024 = arith.constant 1024 : index + %c0 = arith.constant 0 : index + %c8 = arith.constant 8 : index + %cst = arith.constant 0.000000e+00 : f64 + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1024xf64> + scf.parallel (%arg2) = (%c0) to (%c1024) step (%c8) { + %subview = memref.subview %alloc[%arg2] [8] [1] : + memref<1024xf64> to memref<8xf64, strided<[1], offset: ?>> + %0 = vector.transfer_read %arg0[%arg2], %cst {in_bounds = [true]} : + memref<1024xf64>, vector<8xf64> + %1 = arith.addf %0, %0 : vector<8xf64> + vector.transfer_write %1, %subview[%c0] {in_bounds = [true]} : + vector<8xf64>, memref<8xf64, strided<[1], offset: ?>> + scf.yield + } + memref.copy %alloc, %arg1 : memref<1024xf64> to memref<1024xf64> + memref.dealloc %alloc : memref<1024xf64> + return +} + +// CHECK-LABEL: func.func @alloc_vectorized( +// CHECK-SAME: %[[ARG0:.*]]: memref<1024xf64>, +// CHECK-SAME: %[[ARG1:.*]]: memref<1024xf64>) { +// CHECK-NOT: memref.alloc +// CHECK: scf.parallel +// CHECK: %[[SUBVIEW:.*]] = memref.subview %[[ARG1]] +// CHECK: %[[R0:.*]] = vector.transfer_read %[[ARG0]] +// CHECK: %[[R1:.*]] = arith.addf %[[R0]], %[[R0]] : vector<8xf64> +// CHECK: vector.transfer_write %[[R1]], %[[SUBVIEW]] +// CHECK: scf.yield +// CHECK: } +// CHECK-NOT: memref.copy +// CHECK-NOT: memref.dealloc +// CHECK-NEXT: return +// CHECK: } + +// ----- + +// Similar to alloc_vectorized, but with two output params (%arg1 and %arg2). 
+// Note: %arg1 = %arg0 + %arg0, and %arg2 = (%arg0 + %arg0) * %arg0 +func.func @alloc2_vectorized(%arg0: memref<256xf64>, + %arg1: memref<256xf64>, + %arg2: memref<256xf64>) { + %c256 = arith.constant 256 : index + %c0 = arith.constant 0 : index + %c8 = arith.constant 8 : index + %cst = arith.constant 0.000000e+00 : f64 + %alloc = memref.alloc() {alignment = 64 : i64} : memref<256xf64> + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<256xf64> + scf.parallel (%arg3) = (%c0) to (%c256) step (%c8) { + %alloca = memref.alloca() : memref<8xf64> + %0 = vector.transfer_read %arg0[%arg3], %cst {in_bounds = [true]} : memref<256xf64>, vector<8xf64> + %1 = arith.addf %0, %0 : vector<8xf64> + vector.transfer_write %1, %alloca[%c0] {in_bounds = [true]} : vector<8xf64>, memref<8xf64> + %subview = memref.subview %alloc_0[%arg3] [8] [1] : memref<256xf64> to memref<8xf64, strided<[1], offset: ?>> + memref.copy %alloca, %subview : memref<8xf64> to memref<8xf64, strided<[1], offset: ?>> + scf.yield + } + scf.parallel (%arg3) = (%c0) to (%c256) step (%c8) { + %subview = memref.subview %alloc[%arg3] [8] [1] : memref<256xf64> to memref<8xf64, strided<[1], offset: ?>> + %0 = vector.transfer_read %alloc_0[%arg3], %cst {in_bounds = [true]} : memref<256xf64>, vector<8xf64> + %1 = vector.transfer_read %arg0[%arg3], %cst {in_bounds = [true]} : memref<256xf64>, vector<8xf64> + %2 = arith.mulf %0, %1 : vector<8xf64> + vector.transfer_write %2, %subview[%c0] {in_bounds = [true]} : vector<8xf64>, memref<8xf64, strided<[1], offset: ?>> + scf.yield + } + memref.copy %alloc_0, %arg1 : memref<256xf64> to memref<256xf64> + memref.dealloc %alloc_0 : memref<256xf64> + memref.copy %alloc, %arg2 : memref<256xf64> to memref<256xf64> + memref.dealloc %alloc : memref<256xf64> + return +} + +// CHECK-LABEL: func.func @alloc2_vectorized( +// CHECK-SAME: %[[ARG0:[0-9a-z]*]]: memref<256xf64>, +// CHECK-SAME: %[[ARG1:.*]]: memref<256xf64>, +// CHECK-SAME: %[[ARG2:.*]]: memref<256xf64>) { +// CHECK-NOT: memref.alloc +// CHECK: scf.parallel +// CHECK: %[[ALLOCA:.*]] = memref.alloca() +// CHECK: %[[R0:.*]] = vector.transfer_read %[[ARG0]] +// CHECK: %[[R1:.*]] = arith.addf %[[R0]], %[[R0]] +// CHECK: vector.transfer_write %[[R1]], %[[ALLOCA]] +// CHECK: %[[SUBVIEW:.*]] = memref.subview %[[ARG1]] +// CHECK: memref.copy %[[ALLOCA]], %[[SUBVIEW]] +// CHECK: scf.yield +// CHECK: } +// CHECK-NOT: memref.copy +// CHECK-NOT: memref.dealloc +// CHECK-NEXT: scf.parallel +// CHECK: %[[SUBVIEW:.*]] = memref.subview %[[ARG2]] +// CHECK: %[[R0:.*]] = vector.transfer_read %[[ARG1]] +// CHECK: %[[R1:.*]] = vector.transfer_read %[[ARG0]] +// CHECK: %[[R2:.*]] = arith.mulf %[[R0]], %[[R1]] +// CHECK: vector.transfer_write %[[R2]], %[[SUBVIEW]] +// CHECK: scf.yield +// CHECK: } +// CHECK-NOT: memref.copy +// CHECK-NOT: memref.dealloc +// CHECK-NEXT: return +// CHECK: } diff --git a/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/rng_bit_generator.mlir b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/rng_bit_generator.mlir new file mode 100644 index 00000000000..c1b934dd693 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/rng_bit_generator.mlir @@ -0,0 +1,16 @@ +// RUN: xla-cpu-opt %s -xla-legalize-collective-ops | FileCheck %s + +func.func @rng_bit_generator(%state: tensor<2xui64>) -> (tensor<2xui64>, tensor<10xui32>) { + %new_state, %output = "mhlo.rng_bit_generator"(%state) { + rng_algorithm = #mhlo.rng_algorithm + } : (tensor<2xui64>) -> (tensor<2xui64>, tensor<10xui32>) + func.return 
%new_state, %output : tensor<2xui64>, tensor<10xui32> +} + +// CHECK-LABEL: @rng_bit_generator +// CHECK-SAME: %[[ARG0:.*]]: tensor +// CHECK: %[[STATE_INIT:.*]] = tensor.empty() : tensor<2xui64> +// CHECK: %[[DST_INIT:.*]] = tensor.empty() : tensor<10xui32> +// CHECK: "xla_cpu.rng_bit_generator"(%[[ARG0]], %[[STATE_INIT]], %[[DST_INIT]]) +// CHECK-SAME: {rng_algorithm = #mhlo.rng_algorithm} : +// CHECK-SAME: (tensor<2xui64>, tensor<2xui64>, tensor<10xui32>) -> (tensor<2xui64>, tensor<10xui32>) diff --git a/tensorflow/compiler/xla/mlir/transforms/cpu/tests/xla_abi_legalization.mlir b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/xla_abi_legalization.mlir similarity index 90% rename from tensorflow/compiler/xla/mlir/transforms/cpu/tests/xla_abi_legalization.mlir rename to tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/xla_abi_legalization.mlir index 03a649def74..9e6b3520a3f 100644 --- a/tensorflow/compiler/xla/mlir/transforms/cpu/tests/xla_abi_legalization.mlir +++ b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/xla_abi_legalization.mlir @@ -92,7 +92,7 @@ func.func @custom_call(%arg0: tensor, %arg1: tensor<2x3xf32>) -> (tensor<6x // CHECK-NOT: result_layouts // CHECK: %[[T1:.*]] = "mhlo.transpose"(%[[ARG1]]) {{.*}} -> tensor<3x2xf32> // CHECK: %[[R1:.*]] = mhlo.reshape %[[T1]] {{.*}} -> tensor<2x3xf32> -// CHECK: %[[CC:.*]]:2 = "mhlo.custom_call"(%[[ARG0]], %[[R1]]) +// CHECK: %[[CC:.*]]:2 = mhlo.custom_call @yolo(%[[ARG0]], %[[R1]]) // CHECK: %[[RR:.*]] = mhlo.reshape %[[CC]]#0 {{.*}} -> tensor<3x6xf32> // CHECK: %[[TR:.*]] = "mhlo.transpose"(%[[RR]]) {{.*}} -> tensor<6x3xf32> // CHECK: return %[[TR]], %[[CC]]#1 @@ -107,7 +107,7 @@ func.func @custom_call_i1_input(%arg0: tensor<42xi1>) { // CHECK-LABEL: @custom_call_i1_input // CHECK: %[[CONVERTED:.*]] = mhlo.convert {{.*}} : (tensor<42xi1>) -> tensor<42xui8> -// CHECK: "mhlo.custom_call"(%[[CONVERTED]]) +// CHECK: mhlo.custom_call @yolo(%[[CONVERTED]]) // ----- @@ -122,4 +122,18 @@ func.func @constant_with_layout() -> tensor<2x3xf32> { // CHECK-LABEL: @constant_with_layout // CHECK: %[[CST:.*]] = mhlo.constant {{.*}} : tensor<3x2xf32> // CHECK: %[[TR:.*]] = "mhlo.transpose"(%[[CST]]) {{.*}} -> tensor<2x3xf32> -// CHECK: return %[[TR]] \ No newline at end of file +// CHECK: return %[[TR]] + +// ----- + +func.func @non_tensor_inouts() -> !mhlo.token { + %0 = mhlo.create_token : !mhlo.token + %1 = "mhlo.custom_call"(%0) { + call_target_name = "yolo", + operand_layouts = [dense<> : tensor<0xindex>], + result_layouts = [dense<> : tensor<0xindex>] + } : (!mhlo.token) -> (!mhlo.token) + return %1 : !mhlo.token +} + +// CHECK-LABEL: @non_tensor_inouts diff --git a/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/xla_cpu_memref_element_cast_to_llvm.mlir b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/xla_cpu_memref_element_cast_to_llvm.mlir new file mode 100644 index 00000000000..6ec509beecc --- /dev/null +++ b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/xla_cpu_memref_element_cast_to_llvm.mlir @@ -0,0 +1,48 @@ +// RUN: xla-cpu-opt -xla-convert-memref-element-cast-to-llvm %s \ +// RUN: -split-input-file | FileCheck %s + +func.func @memref_cast(%arg0: memref<10xf32>) -> memref<10xi32> { + %ret = xla_cpu.memref_element_cast %arg0 : memref<10xf32> to memref<10xi32> + return %ret : memref<10xi32> +} +// CHECK-LABEL: func.func @memref_cast( +// CHECK-SAME: %[[SRC:.*]]: memref<10xf32>) -> memref<10xi32> +// CHECK: %[[SRC_DESC:.*]] = builtin.unrealized_conversion_cast %[[SRC]] +// 
CHECK-SAME: : memref<10xf32> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK-NEXT: %[[ALLOC_PTR:.*]] = llvm.extractvalue %[[SRC_DESC]][0] +// CHECK-NEXT: %[[ALIGN_PTR:.*]] = llvm.extractvalue %[[SRC_DESC]][1] + +// CHECK: %[[ALLOC_PTR_CAST:.*]] = llvm.bitcast %[[ALLOC_PTR]] : !llvm.ptr to !llvm.ptr +// CHECK-NEXT: %[[ALIGN_PTR_CAST:.*]] = llvm.bitcast %[[ALIGN_PTR]] : !llvm.ptr to !llvm.ptr + +// CHECK: %[[DST_DESC:.*]] = llvm.mlir.undef +// CHECK-SAME: : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK-NEXT: %[[DST_DESC_:.*]] = llvm.insertvalue %[[ALLOC_PTR_CAST]], %[[DST_DESC]][0] +// CHECK-NEXT: llvm.insertvalue %[[ALIGN_PTR_CAST]], %[[DST_DESC_]][1] + +// CHECK: builtin.unrealized_conversion_cast +// CHECK-SAME: : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<10xi32> + +// ----- + +func.func @memref_cast_i1(%arg0: memref<10xi1>) -> memref<10xi8> { + %ret = xla_cpu.memref_element_cast %arg0 : memref<10xi1> to memref<10xi8> + return %ret : memref<10xi8> +} +// CHECK-LABEL: func.func @memref_cast_i1( +// CHECK-SAME: %[[SRC:.*]]: memref<10xi1>) -> memref<10xi8> +// CHECK: %[[SRC_DESC:.*]] = builtin.unrealized_conversion_cast %[[SRC]] +// CHECK-SAME: : memref<10xi1> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK-NEXT: %[[ALLOC_PTR:.*]] = llvm.extractvalue %[[SRC_DESC]][0] +// CHECK-NEXT: %[[ALIGN_PTR:.*]] = llvm.extractvalue %[[SRC_DESC]][1] + +// CHECK: %[[ALLOC_PTR_CAST:.*]] = llvm.bitcast %[[ALLOC_PTR]] : !llvm.ptr to !llvm.ptr +// CHECK-NEXT: %[[ALIGN_PTR_CAST:.*]] = llvm.bitcast %[[ALIGN_PTR]] : !llvm.ptr to !llvm.ptr + +// CHECK: %[[DST_DESC:.*]] = llvm.mlir.undef +// CHECK-SAME: : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK-NEXT: %[[DST_DESC_:.*]] = llvm.insertvalue %[[ALLOC_PTR_CAST]], %[[DST_DESC]][0] +// CHECK-NEXT: llvm.insertvalue %[[ALIGN_PTR_CAST]], %[[DST_DESC_]][1] + +// CHECK: builtin.unrealized_conversion_cast +// CHECK-SAME: : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<10xi8> diff --git a/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/xla_cpu_outfeed.mlir b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/xla_cpu_outfeed.mlir new file mode 100644 index 00000000000..5ab7ae76411 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/tests/xla_cpu_outfeed.mlir @@ -0,0 +1,37 @@ +// RUN: xla-cpu-opt %s -split-input-file -xla-lmhlo-to-cpu-runtime \ +// RUN: | FileCheck %s + +func.func @cpu_onfeed(%arg0: memref<8xf32>, %arg1: memref<10xui32>) { + "xla_cpu.outfeed"(%arg0, %arg1) {config = "abc", result_type = [f32, ui32]} : (memref<8xf32>, memref<10xui32>) -> () + return +} + +// CHECK: func @cpu_onfeed( +// CHECK-SAME: %[[ARG0:[a-z0-9]+]]: memref<8xf32> +// CHECK-SAME: %[[ARG1:[a-z0-9]+]]: memref<10xui32> +// CHECK-SAME: ) +// CHECK: call @[[OUTFEED:.*]](%[[ARG0]], %[[ARG1]]) +// CHECK-SAME: {result_type = [11 : i32, 8 : i32]} : (memref<8xf32>, memref<10xui32>) -> () +// CHECK: func private @[[OUTFEED]](memref<8xf32>, memref<10xui32>) +// CHECK-SAME: attributes {rt.custom_call = "xla.cpu.outfeed"} + +// ----- + +func.func @cpu_onfeed_strided( + %arg0: memref<8x8xf32, strided<[?, 1], offset: ?>>, + %arg1: memref<10xui32>) { + "xla_cpu.outfeed"(%arg0, %arg1) {config = "abc", result_type = [f32, ui32]} + : (memref<8x8xf32, strided<[?, 1], offset: ?>>, memref<10xui32>) -> () + return +} + +// CHECK: func @cpu_onfeed_strided( +// CHECK-SAME: %[[ARG0:[a-z0-9]+]]: memref<8x8xf32, 
strided<[?, 1], offset: ?>> +// CHECK-SAME: %[[ARG1:[a-z0-9]+]]: memref<10xui32> +// CHECK-SAME: ) +// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc() +// CHECK-NEXT: memref.copy %[[ARG0]], %[[ALLOC]] +// CHECK: call @[[OUTFEED:.*]](%[[ALLOC]], %[[ARG1]]) +// CHECK-SAME: {result_type = [11 : i32, 8 : i32]} : (memref<8x8xf32>, memref<10xui32>) -> () +// CHECK: func private @[[OUTFEED]](memref<8x8xf32>, memref<10xui32>) +// CHECK-SAME: attributes {rt.custom_call = "xla.cpu.outfeed"} \ No newline at end of file diff --git a/tensorflow/compiler/xla/mlir/transforms/cpu/xla_abi_legalization.cc b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/xla_abi_legalization.cc similarity index 90% rename from tensorflow/compiler/xla/mlir/transforms/cpu/xla_abi_legalization.cc rename to tensorflow/compiler/xla/mlir/backends/cpu/transforms/xla_abi_legalization.cc index 6c1eac0e04f..fb300c49270 100644 --- a/tensorflow/compiler/xla/mlir/transforms/cpu/xla_abi_legalization.cc +++ b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/xla_abi_legalization.cc @@ -20,22 +20,22 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "tensorflow/compiler/xla/mlir/transforms/cpu/passes.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/xla/mlir/backends/cpu/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" namespace xla { namespace cpu { namespace { #define GEN_PASS_DEF_LEGALIZEXLAABIPASS -#include "tensorflow/compiler/xla/mlir/transforms/cpu/passes.h.inc" +#include "tensorflow/compiler/xla/mlir/backends/cpu/transforms/passes.h.inc" using namespace mlir; // NOLINT @@ -75,14 +75,14 @@ Value NormalizeTensor(ImplicitLocOpBuilder& b, TypedValue tensor, return b.create(tensor.getType(), transpose); } -void NormalizeInputInPlace(ImplicitLocOpBuilder& b, - TypedValue tensor, +void NormalizeInputInPlace(ImplicitLocOpBuilder& b, Value tensor, ArrayRef layout) { - if (IsDefaultLayout(layout)) { + auto typedTensor = tensor.dyn_cast>(); + if (!typedTensor || IsDefaultLayout(layout)) { return; } - Value normalized = NormalizeTensor(b, tensor, layout, /*isInput=*/true); + Value normalized = NormalizeTensor(b, typedTensor, layout, /*isInput=*/true); tensor.replaceAllUsesExcept( normalized, normalized.getDefiningOp()->getOperand(0).getDefiningOp()); } @@ -97,7 +97,9 @@ SmallVector> FlattenLayoutAttribute(Attribute attr) { }; if (auto array = attr.dyn_cast()) { - array.walkSubAttrs(visit_attr); + for (int64_t i = 0; i < array.size(); ++i) { + visit_attr(array[i]); + } } else { visit_attr(attr); } @@ -125,7 +127,7 @@ struct RewriteInputArgs : OpRewritePattern { ImplicitLocOpBuilder b(op.getLoc(), rewriter); b.setInsertionPointToStart(&op.getBody().front()); - BlockAndValueMapping bvm; + IRMapping bvm; for (const auto&& [param, layout] : llvm::zip(op.getArguments(), param_layouts)) { NormalizeInputInPlace(b, param, 
layout); @@ -161,7 +163,9 @@ struct RewriteReturnArgs : OpRewritePattern { results.push_back( IsDefaultLayout(layout) ? result - : NormalizeTensor(b, result, layout, /*isInput=*/false)); + : NormalizeTensor(b, result.cast>(), + layout, + /*isInput=*/false)); } func->removeAttr("xla_entry_computation_result_layout"); @@ -228,12 +232,13 @@ struct RewriteCustomCalls : OpRewritePattern { for (const auto& [index, operand] : llvm::enumerate(op.getOperands())) { const auto& layout = operand_layouts[index]; if (!IsDefaultLayout(layout)) { - Value normalized = NormalizeTensor(b, op.getOperand(index), layout, - /*isInput=*/false); + Value normalized = NormalizeTensor( + b, op.getOperand(index).cast>(), layout, + /*isInput=*/false); op.setOperand(index, normalized); } } - op.removeOperand_layoutsAttr(); + op.removeOperandLayoutsAttr(); } // Rewrite i1 inputs to ui8. @@ -257,7 +262,7 @@ struct RewriteCustomCalls : OpRewritePattern { NormalizeInputInPlace(b, result, layout); } - op.removeResult_layoutsAttr(); + op.removeResultLayoutsAttr(); } return success(); diff --git a/tensorflow/compiler/xla/mlir/backends/cpu/transforms/xla_cpu_memref_element_cast_to_llvm.cc b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/xla_cpu_memref_element_cast_to_llvm.cc new file mode 100644 index 00000000000..a9f4ab79dfe --- /dev/null +++ b/tensorflow/compiler/xla/mlir/backends/cpu/transforms/xla_cpu_memref_element_cast_to_llvm.cc @@ -0,0 +1,117 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Analysis/DataLayoutAnalysis.h" // from @llvm-project +#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/xla/mlir/backends/cpu/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu.h" + +namespace xla { +namespace cpu { +namespace { + +#define GEN_PASS_DEF_CONVERTXLACPUMEMREFELEMENTCASTTOLLVMPASS +#include "tensorflow/compiler/xla/mlir/backends/cpu/transforms/passes.h.inc" + +using namespace mlir; // NOLINT + +struct MemRefElementCastOpLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + xla_cpu::MemRefElementCastOp>::ConvertOpToLLVMPattern; + + LogicalResult matchAndRewrite( + xla_cpu::MemRefElementCastOp cast_op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto target_memref_ty = cast_op.getDst().getType().cast(); + + LLVMTypeConverter type_converter = *getTypeConverter(); + auto target_desc_ty = type_converter.convertType(target_memref_ty) + .dyn_cast_or_null(); + if (!target_desc_ty) { + return failure(); + } + + // Unpack the descriptor into the list of its fields. + Location loc = cast_op.getLoc(); + Type src_type = cast_op.getSrc().getType(); + + SmallVector desc_fields; + MemRefDescriptor::unpack(rewriter, loc, adaptor.getSrc(), + src_type.cast(), desc_fields); + + // Bitcast allocated and aligned pointers. + auto dst_elem_ty = + typeConverter->convertType(cast_op.getType().getElementType()); + auto dst_elem_ptr_ty = LLVM::LLVMPointerType::get( + dst_elem_ty, cast_op.getType().getMemorySpaceAsInt()); + desc_fields[0] = + rewriter.create(loc, dst_elem_ptr_ty, desc_fields[0]); + desc_fields[1] = + rewriter.create(loc, dst_elem_ptr_ty, desc_fields[1]); + + // Create descriptor. 
+ auto dst_desc = MemRefDescriptor::pack(rewriter, loc, type_converter, + cast_op.getType(), desc_fields); + rewriter.replaceOp(cast_op, {dst_desc}); + return success(); + } +}; + +struct ConvertXlaCpuMemRefElementCastToLLVMPass + : public impl::ConvertXlaCpuMemRefElementCastToLLVMPassBase< + ConvertXlaCpuMemRefElementCastToLLVMPass> { + ConvertXlaCpuMemRefElementCastToLLVMPass() = default; + + void runOnOperation() override { + Operation *op = getOperation(); + const auto &data_layout_analysis = getAnalysis(); + LowerToLLVMOptions options(&getContext(), + data_layout_analysis.getAtOrAbove(op)); + + LLVMTypeConverter type_converter(&getContext(), options, + &data_layout_analysis); + RewritePatternSet patterns(&getContext()); + patterns.add(type_converter); + + LLVMConversionTarget target(getContext()); + target.addLegalOp(); + if (failed(applyPartialConversion(op, target, std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr> +createConvertXlaCpuMemRefElementCastToLLVMPass() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/mlir/tools/xla_cpu_opt.cc b/tensorflow/compiler/xla/mlir/backends/cpu/xla-cpu-opt.cc similarity index 50% rename from tensorflow/compiler/xla/mlir/tools/xla_cpu_opt.cc rename to tensorflow/compiler/xla/mlir/backends/cpu/xla-cpu-opt.cc index c8d2027a611..eae0d1ec8c0 100644 --- a/tensorflow/compiler/xla/mlir/tools/xla_cpu_opt.cc +++ b/tensorflow/compiler/xla/mlir/backends/cpu/xla-cpu-opt.cc @@ -13,32 +13,40 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mlir-hlo/Dialect/gml_st/transforms/passes.h" -#include "mlir-hlo/Dialect/gml_st/transforms/test_passes.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Tools/mlir-opt/MlirOptMain.h" +#include "mlir/Dialect/Bufferization/Transforms/Passes.h" // from @llvm-project +#include "mlir/Dialect/Linalg/IR/Linalg.h" // from @llvm-project +#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project +#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/Dialect/Vector/IR/VectorOps.h" // from @llvm-project +#include "mlir/Tools/mlir-opt/MlirOptMain.h" // from @llvm-project #include "stablehlo/dialect/Register.h" // from @stablehlo -#include "tensorflow/compiler/xla/mlir/transforms/cpu/passes.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/transforms/passes.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/thlo/IR/thlo_ops.h" +#include "tensorflow/compiler/xla/mlir/backends/cpu/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu.h" +#include "tensorflow/compiler/xla/mlir_hlo/gml_st/IR/gml_st_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/test_passes.h" +#include 
"tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/register.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir_hlo/thlo/IR/thlo_ops.h" int main(int argc, char **argv) { mlir::mhlo::registerAllMhloPasses(); mlir::lmhlo::registerAllLmhloPasses(); mlir::gml_st::registerGmlStPasses(); mlir::gml_st::registerGmlStTestPasses(); + mlir::bufferization::registerBufferizationPasses(); mlir::DialectRegistry registry; mlir::mhlo::registerAllMhloDialects(registry); mlir::stablehlo::registerAllDialects(registry); registry.insert(); + mlir::linalg::LinalgDialect, mlir::tensor::TensorDialect, + mlir::vector::VectorDialect, mlir::xla_cpu::XlaCpuDialect>(); xla::cpu::registerCpuTransformsPasses(); diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/BUILD b/tensorflow/compiler/xla/mlir/backends/gpu/BUILD new file mode 100644 index 00000000000..1d0879a0234 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/backends/gpu/BUILD @@ -0,0 +1,22 @@ +load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_binary") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//tensorflow/compiler/xla/mlir:__subpackages__"], + licenses = ["notice"], +) + +xla_cc_binary( + name = "xla-gpu-opt", + srcs = ["xla-gpu-opt.cc"], + deps = [ + "//tensorflow/compiler/xla/mlir/backends/gpu/transforms:passes", + "//tensorflow/compiler/xla/mlir_hlo:lhlo", + "//tensorflow/compiler/xla/mlir_hlo:lhlo_gpu", + "//tensorflow/compiler/xla/stream_executor:stream_executor_impl", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:MlirOptLib", + ], +) diff --git a/tensorflow/compiler/xla/mlir/transforms/gpu/BUILD b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/BUILD similarity index 87% rename from tensorflow/compiler/xla/mlir/transforms/gpu/BUILD rename to tensorflow/compiler/xla/mlir/backends/gpu/transforms/BUILD index 00143ed1066..087efcb98d4 100644 --- a/tensorflow/compiler/xla/mlir/transforms/gpu/BUILD +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/BUILD @@ -1,9 +1,10 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") package( - default_visibility = ["//tensorflow:internal"], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//tensorflow/compiler/xla:internal"], licenses = ["notice"], ) @@ -29,11 +30,11 @@ cc_library( srcs = [ "add_hlo_trace_annotations.cc", "gpu_to_gpu_runtime.cc", - "launch_func_to_cuda_graph.cc", "lmhlo_gpu_to_gpu_runtime.cc", "lmhlo_to_gpu_launch.cc", "lmhlo_to_gpu_runtime.cc", "memref_get_global_to_arg.cc", + "outline_cuda_graphs.cc", "passes.cc", "uid_generator.h", ], @@ -52,9 +53,8 @@ cc_library( "//tensorflow/compiler/xla/service/gpu:nccl_collective_thunks", "//tensorflow/compiler/xla/stream_executor:blas", "//tensorflow/compiler/xla/translate/mhlo_to_hlo:location_exporter", - "@com_google_absl//absl/status", + "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", 
"@llvm-project//mlir:ControlFlowDialect", diff --git a/tensorflow/compiler/xla/mlir/transforms/gpu/add_hlo_trace_annotations.cc b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/add_hlo_trace_annotations.cc similarity index 80% rename from tensorflow/compiler/xla/mlir/transforms/gpu/add_hlo_trace_annotations.cc rename to tensorflow/compiler/xla/mlir/backends/gpu/transforms/add_hlo_trace_annotations.cc index 810a2b8d1bb..fc029349b61 100644 --- a/tensorflow/compiler/xla/mlir/transforms/gpu/add_hlo_trace_annotations.cc +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/add_hlo_trace_annotations.cc @@ -20,15 +20,15 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h" #include "tensorflow/compiler/xla/mlir/runtime/ir/rt_dialect.h" -#include "tensorflow/compiler/xla/mlir/transforms/gpu/passes.h" #include "tensorflow/compiler/xla/translate/mhlo_to_hlo/location_exporter.h" namespace xla { namespace gpu { #define GEN_PASS_DEF_ADDHLOTRACEANNOTATIONSPASS -#include "tensorflow/compiler/xla/mlir/transforms/gpu/passes.h.inc" +#include "tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h.inc" using namespace mlir; // NOLINT @@ -51,14 +51,6 @@ void AddHloTraceAnnotationsPass::runOnOperation() { ModuleOp module = getOperation(); SymbolTable sym_table(module); - // Get a unique mhlo id from the top level module. - auto uid = module->getAttrOfType("mhlo.unique_id"); - int64_t program_id = uid ? uid.getValue().getZExtValue() : -1; - - // XLA HLO -> MLIR export encodes module name in the location. - std::string module_name = - mlir::mhlo::GetDebugNameFromLocation(module->getLoc()); - getOperation().walk([&](func::CallOp call) { // Check if the callee is a custom call. auto callee = sym_table.lookup(call.getCallee()); @@ -66,7 +58,7 @@ void AddHloTraceAnnotationsPass::runOnOperation() { // HLO operation name is encoded in the operation location. std::string hlo_op = mlir::mhlo::GetDebugNameFromLocation(call->getLoc()); - auto annotation = HloTraceAttr::get(ctx, hlo_op, module_name, program_id); + auto annotation = HloTraceAttr::get(ctx, std::move(hlo_op)); call->setAttr("rt.trace", annotation); }); } diff --git a/tensorflow/compiler/xla/mlir/transforms/gpu/gpu_to_gpu_runtime.cc b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/gpu_to_gpu_runtime.cc similarity index 86% rename from tensorflow/compiler/xla/mlir/transforms/gpu/gpu_to_gpu_runtime.cc rename to tensorflow/compiler/xla/mlir/backends/gpu/transforms/gpu_to_gpu_runtime.cc index b3ed5ed7565..f0e67615918 100644 --- a/tensorflow/compiler/xla/mlir/transforms/gpu/gpu_to_gpu_runtime.cc +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/gpu_to_gpu_runtime.cc @@ -27,13 +27,14 @@ limitations under the License. 
#include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/xla/mlir/backends/gpu/transforms/uid_generator.h" #include "tensorflow/compiler/xla/mlir/runtime/utils/custom_calls.h" namespace xla { namespace gpu { #define GEN_PASS_DEF_CONVERTGPUTOGPURUNTIMEPASS -#include "tensorflow/compiler/xla/mlir/transforms/gpu/passes.h.inc" +#include "tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h.inc" using namespace mlir; // NOLINT @@ -139,8 +140,9 @@ class LaunchFuncOpLowering : public OpRewritePattern { static constexpr const char kCustomCallTarget[] = "xla.gpu.func.launch"; public: - LaunchFuncOpLowering(MLIRContext* ctx, CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), custom_calls_(custom_calls) {} + LaunchFuncOpLowering(MLIRContext* ctx, UidGenerator& uid, + CustomCallDeclarations& custom_calls) + : OpRewritePattern(ctx), uid_(uid), custom_calls_(custom_calls) {} LogicalResult matchAndRewrite(LaunchFuncOp op, PatternRewriter& rewriter) const override { @@ -157,6 +159,15 @@ class LaunchFuncOpLowering : public OpRewritePattern { cast(op.getGridSizeZ()), cast(op.getBlockSizeX()), cast(op.getBlockSizeY()), cast(op.getBlockSizeZ())}; + // Shared memory size is optional for the `gpu.launch` but mandatory for the + // Xla runtime kernel launch custom call. + if (op.getDynamicSharedMemorySize()) { + args.insert(args.begin(), op.getDynamicSharedMemorySize()); + } else { + auto zero = b.create(0, b.getI32Type()); + args.insert(args.begin(), zero); + } + // Add kernel arguments. llvm::copy(op.getKernelOperands(), std::back_inserter(args)); @@ -168,6 +179,9 @@ class LaunchFuncOpLowering : public OpRewritePattern { auto call = b.create(callee.getName(), TypeRange(), args); call->setAttr(b.getStringAttr("kernel"), op.getKernelName()); + // Assign a unique id to this instance of a kernel launch operation. + call->setAttr(b.getStringAttr("uid"), b.getI64IntegerAttr(uid_.uid())); + // Erase the original gpu launch operation. rewriter.eraseOp(op); @@ -175,6 +189,7 @@ class LaunchFuncOpLowering : public OpRewritePattern { } private: + UidGenerator& uid_; CustomCallDeclarations& custom_calls_; }; @@ -188,11 +203,14 @@ void ConvertGpuToGpuRuntimePass::runOnOperation() { SymbolTable sym_table(module); CustomCallDeclarations custom_calls(std::move(sym_table)); + // Each kernel launch operation gets a unique id. + UidGenerator kernel_uid; + // Convert gpu operations to XLA gpu runtime custom calls. RewritePatternSet patterns(ctx); patterns.insert(ctx); - patterns.insert( - ctx, custom_calls); + patterns.insert(ctx, kernel_uid, custom_calls); + patterns.insert(ctx, custom_calls); if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) return signalPassFailure(); diff --git a/tensorflow/compiler/xla/mlir/transforms/gpu/lmhlo_gpu_to_gpu_runtime.cc b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/lmhlo_gpu_to_gpu_runtime.cc similarity index 76% rename from tensorflow/compiler/xla/mlir/transforms/gpu/lmhlo_gpu_to_gpu_runtime.cc rename to tensorflow/compiler/xla/mlir/backends/gpu/transforms/lmhlo_gpu_to_gpu_runtime.cc index dfccbd08d2a..2ae32d3cc84 100644 --- a/tensorflow/compiler/xla/mlir/transforms/gpu/lmhlo_gpu_to_gpu_runtime.cc +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/lmhlo_gpu_to_gpu_runtime.cc @@ -21,7 +21,6 @@ limitations under the License. 
#include #include -#include "absl/strings/str_cat.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -32,18 +31,19 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/xla/mlir/backends/gpu/transforms/uid_generator.h" #include "tensorflow/compiler/xla/mlir/runtime/utils/custom_calls.h" -#include "tensorflow/compiler/xla/mlir/transforms/gpu/uid_generator.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" #include "tensorflow/compiler/xla/stream_executor/blas.h" namespace xla { namespace gpu { #define GEN_PASS_DEF_CONVERTLMHLOGPUTOGPURUNTIMEPASS -#include "tensorflow/compiler/xla/mlir/transforms/gpu/passes.h.inc" +#include "tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h.inc" using namespace mlir; // NOLINT @@ -53,6 +53,7 @@ using mlir::lmhlo_gpu::ConvBackwardInputOp; using mlir::lmhlo_gpu::ConvForwardFusedOp; using mlir::lmhlo_gpu::ConvForwardFusedSideInputOp; using mlir::lmhlo_gpu::ConvForwardOp; +using mlir::lmhlo_gpu::CublasLtMatmulF8Op; using mlir::lmhlo_gpu::CublasLtMatmulOp; using mlir::lmhlo_gpu::GEMMOp; @@ -130,16 +131,104 @@ class CublasLtMatmulOpLowering : public OpRewritePattern { LogicalResult matchAndRewrite(CublasLtMatmulOp op, PatternRewriter& rewriter) const override { // Get the custom call target. - std::string matmul; - switch (op.getOperands().size()) { - case 4: - matmul = kCustomCallTarget; + std::string matmul = kCustomCallTarget; + + switch (op.getEpilogue()) { + case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::Default: + case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::Relu: + case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::Gelu: + if (op.getNumOperands() != 4) { + return op.emitOpError("unexpected number of operands for matmul"); + } break; - case 5: - matmul = absl::StrCat(kCustomCallTarget, ".bias"); + case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::Bias: + case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::BiasRelu: + case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::BiasGelu: + if (op.getNumOperands() != 5) { + return op.emitOpError("unexpected number of operands for matmul"); + } + matmul += ".bias"; break; - default: - return op.emitOpError("unexpected number of operands for matmul"); + case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::GeluAux: + if (op.getNumOperands() != 5) { + return op.emitOpError("unexpected number of operands for matmul"); + } + matmul += ".aux"; + break; + case mlir::lmhlo_gpu::CublasLtMatmulEpilogue::BiasGeluAux: + if (op.getNumOperands() != 6) { + return op.emitOpError("unexpected number of operands for matmul"); + } + matmul += ".bias.aux"; + break; + } + + // Get or create a custom call function declaration. + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + func::FuncOp callee = custom_calls_.GetOrCreate(b, matmul, op); + + // Convert matmul to a function call. + auto call = rewriter.create(op.getLoc(), callee.getName(), + TypeRange(), op.getOperands()); + + // Assign a unique id to this instance of a matmul operation. 
+ call->setAttr(b.getStringAttr("uid"), b.getI64IntegerAttr(uid_.uid())); + + // Copy backend specific attributes. + call->setAttr(b.getStringAttr("algorithm"), op.getAlgorithmAttr()); + call->setAttr(b.getStringAttr("alpha_imag"), op.getAlphaImagAttr()); + call->setAttr(b.getStringAttr("alpha_real"), op.getAlphaRealAttr()); + call->setAttr(b.getStringAttr("beta"), op.getBetaAttr()); + call->setAttr(b.getStringAttr("dot_dims"), op.getDotDimensionNumbers()); + call->setAttr(b.getStringAttr("epilogue"), op.getEpilogueAttr()); + + // TODO(ezhulenev): Today we can't pass an array of enum attributes to the + // custom call. Also we do not have a corresponding precision enum on the + // SE/XLA side, so we encode it as an i32 array (tensor). + if (auto precisions = op.getPrecisionConfig()) { + llvm::SmallVector values; + for (auto precision : *precisions) { + auto value = precision.cast().getValue(); + values.push_back(static_cast(value)); + } + call->setAttr(b.getStringAttr("precision"), b.getI32TensorAttr(values)); + } else { + call->setAttr(b.getStringAttr("precision"), b.getI32TensorAttr({0, 0})); + } + + // Erase the original matmul operation. + rewriter.eraseOp(op); + + return success(); + } + + private: + UidGenerator& uid_; + CustomCallDeclarations& custom_calls_; +}; + +// As above for FP8 Custom Calls. +class CublasLtMatmulF8OpLowering : public OpRewritePattern { + private: + static constexpr const char kCustomCallTarget[] = + "xla.gpu.cublas.lt.matmul.f8"; + + public: + CublasLtMatmulF8OpLowering(MLIRContext* ctx, UidGenerator& uid, + CustomCallDeclarations& custom_calls) + : OpRewritePattern(ctx), + uid_(uid), + custom_calls_(custom_calls) {} + + LogicalResult matchAndRewrite(CublasLtMatmulF8Op op, + PatternRewriter& rewriter) const override { + // Get the custom call target. + std::string matmul = kCustomCallTarget; + + if (op.getNumOperands() == 9) { + matmul += ".d_amax"; + } else if (op.getNumOperands() != 8) { + return op.emitOpError("unexpected number of operands for matmul"); } // Get or create a custom call function declaration. @@ -228,7 +317,8 @@ class ConvOpLowering : public OpRewritePattern { call->setAttr(b.getStringAttr(name), attr); }; - auto set_xi64 = [&](StringRef name, Optional attr) { + auto set_xi64 = [&](StringRef name, + std::optional attr) { SmallVector values; if (attr.has_value()) values = llvm::to_vector(attr->getValues()); @@ -237,7 +327,7 @@ class ConvOpLowering : public OpRewritePattern { // Convert `BoolElementsAttr` to i64 before passing to the runtime. // TODO(ezhulenev): Allow passing boolean tensors to the XLA custom calls. - auto set_xi1 = [&](StringRef name, Optional attr) { + auto set_xi1 = [&](StringRef name, std::optional attr) { SmallVector values; if (attr.has_value()) values.assign(attr->getValues().begin(), @@ -378,8 +468,8 @@ void ConvertLmhloGpuToGpuRuntimePass::runOnOperation() { // Each unique Gemm/Matmul operation in the module will get assigned a uid. UidGenerator matmul_uid; - patterns.insert(ctx, matmul_uid, - custom_calls); + patterns.insert(ctx, matmul_uid, custom_calls); // Each unique Conv operation in the module will get assigned a uid. 
   UidGenerator conv_uid;
diff --git a/tensorflow/compiler/xla/mlir/transforms/gpu/lmhlo_to_gpu_launch.cc b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/lmhlo_to_gpu_launch.cc
similarity index 75%
rename from tensorflow/compiler/xla/mlir/transforms/gpu/lmhlo_to_gpu_launch.cc
rename to tensorflow/compiler/xla/mlir/backends/gpu/transforms/lmhlo_to_gpu_launch.cc
index 604ee6eed8f..a381d679261 100644
--- a/tensorflow/compiler/xla/mlir/transforms/gpu/lmhlo_to_gpu_launch.cc
+++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/lmhlo_to_gpu_launch.cc
@@ -13,9 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include 
 #include 
 #include 
 #include 
+#include 
 #include 
 #include 
 #include 
@@ -29,14 +31,12 @@ limitations under the License.
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
-#include "mlir/IR/PatternMatch.h"  // from @llvm-project
 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
 #include "mlir/IR/Visitors.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
 #include "tensorflow/compiler/xla/mlir/runtime/ir/rt_ops.h"
-#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
+#include "tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops.h"
 #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
@@ -49,7 +49,7 @@ namespace xla {
 namespace gpu {
 
 #define GEN_PASS_DEF_CONVERTLMHLOTOGPULAUNCHPASS
-#include "tensorflow/compiler/xla/mlir/transforms/gpu/passes.h.inc"
+#include "tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h.inc"
 
 using namespace mlir;  // NOLINT
 
@@ -79,6 +79,13 @@ class ConvertLmhloToGpuLaunchPass
   ThunkSequence* thunk_sequence_;
 };
 
+// XLA sometimes (ab)uses custom calls to represent operations for which we do
+// not want to define a separate `HloOpcode`. These operations are emitted as
+// device kernels (similar to fusions), so we detect such custom calls by name
+// and handle them the same way we handle fusions.
+static std::array kCustomCallIntrinsics = {
+    "SliceToDynamic", "PadToStatic", "__triton"};
+
 //===-----------------------------------------------------------------------===/
 
 static Value MakeBitPatternConstant(OpBuilder& b, Location loc, Type type,
@@ -157,17 +164,6 @@ static Value MakeBitPatternConstant(OpBuilder& b, Location loc, Type type,
   return b.create(loc, 0);
 }
 
-// Replaces lmhlo ops within a module with gpu.launch_func and gpu.memcpy ops.
-struct KernelOpsPattern : OpRewritePattern { - KernelOpsPattern(MLIRContext* context, ThunkSequence* thunk_sequence) - : OpRewritePattern(context), thunk_sequence(thunk_sequence) {} - - LogicalResult matchAndRewrite(ModuleOp module_op, - PatternRewriter& rewriter) const override; - - ThunkSequence* thunk_sequence; -}; - static void ExtractThunksForOp(Operation* op, ThunkSequence& thunk_sequence, ThunkSequence* thunks_for_op) { for (std::unique_ptr& thunk : thunk_sequence) { @@ -219,44 +215,44 @@ static absl::StatusOr> Match( return std::move(thunks_for_op); } -static void LowerThunkToGpuOp(Operation* op, PatternRewriter& rewriter, +static void LowerThunkToGpuOp(Operation* op, OpBuilder& b, GPUModuleOp gpu_module, Thunk* thunk); // Replaces op with gpu.launch_func, gpu.memcpy, gpu.memset ops. -static void Rewrite(Operation* op, PatternRewriter& rewriter, - SymbolTable& symbol_table, ThunkSequence* thunks) { - OpBuilder::InsertionGuard guard(rewriter); +static void Rewrite(Operation* op, OpBuilder& b, SymbolTable& symbol_table, + ThunkSequence* thunks) { + OpBuilder::InsertionGuard guard(b); auto loc = op->getLoc(); - rewriter.setInsertionPoint(op->getParentOfType()); - auto gpu_module = rewriter.create(loc, "gpu_module"); + b.setInsertionPoint(op->getParentOfType()); + auto gpu_module = b.create(loc, "gpu_module"); symbol_table.insert(gpu_module); for (const std::unique_ptr& thunk : *thunks) { - LowerThunkToGpuOp(op, rewriter, gpu_module, thunk.get()); + LowerThunkToGpuOp(op, b, gpu_module, thunk.get()); } - rewriter.eraseOp(op); + op->erase(); } -static void LowerThunkToGpuOp(Operation* op, PatternRewriter& rewriter, +static void LowerThunkToGpuOp(Operation* op, OpBuilder& b, GPUModuleOp gpu_module, Thunk* thunk) { auto loc = op->getLoc(); if (thunk->kind() == Thunk::kSequential) { const auto* seq_thunk = static_cast(thunk); for (const std::unique_ptr& thunk : seq_thunk->thunks()) { - LowerThunkToGpuOp(op, rewriter, gpu_module, thunk.get()); + LowerThunkToGpuOp(op, b, gpu_module, thunk.get()); } return; } if (thunk->kind() == Thunk::kCopy) { const auto* copy_thunk = static_cast(thunk); - rewriter.setInsertionPoint(op); - rewriter.create(loc, TypeRange(), ValueRange(), - copy_thunk->destination_value(), - copy_thunk->source_value()); + b.setInsertionPoint(op); + b.create(loc, TypeRange(), ValueRange(), + copy_thunk->destination_value(), + copy_thunk->source_value()); return; } @@ -264,11 +260,9 @@ static void LowerThunkToGpuOp(Operation* op, PatternRewriter& rewriter, uint32_t memset_value, Value buffer_arg) { auto element_type = buffer_arg.getType().cast().getElementType(); - rewriter.setInsertionPoint(op); - Value value = - MakeBitPatternConstant(rewriter, loc, element_type, memset_value); - rewriter.create(loc, TypeRange(), ValueRange(), buffer_arg, - value); + b.setInsertionPoint(op); + Value value = MakeBitPatternConstant(b, loc, element_type, memset_value); + b.create(loc, TypeRange(), ValueRange(), buffer_arg, value); }; if (thunk->kind() == Thunk::kMemset32BitValue) { @@ -285,25 +279,24 @@ static void LowerThunkToGpuOp(Operation* op, PatternRewriter& rewriter, } const auto* kernel_thunk = static_cast(thunk); - rewriter.setInsertionPointToStart(gpu_module.getBody()); + b.setInsertionPointToStart(gpu_module.getBody()); SmallVector kernel_args; for (auto kernel_arg : kernel_thunk->values()) kernel_args.push_back(kernel_arg); - auto func_type = rewriter.getType( - TypeRange(ValueRange(kernel_args)), TypeRange()); + auto func_type = + b.getType(TypeRange(ValueRange(kernel_args)), 
TypeRange()); - gpu::GPUFuncOp kernel_func = rewriter.create( - loc, kernel_thunk->kernel_name(), func_type); - kernel_func->setAttr(GPUDialect::getKernelFuncAttrName(), - rewriter.getUnitAttr()); - rewriter.setInsertionPointToEnd(&kernel_func.getBody().back()); - rewriter.create(loc); + gpu::GPUFuncOp kernel_func = + b.create(loc, kernel_thunk->kernel_name(), func_type); + kernel_func->setAttr(GPUDialect::getKernelFuncAttrName(), b.getUnitAttr()); + b.setInsertionPointToEnd(&kernel_func.getBody().back()); + b.create(loc); auto make_const_idx = [&](int64_t value) { - auto attr = rewriter.getIndexAttr(value); - return rewriter.create(loc, attr).getResult(); + auto attr = b.getIndexAttr(value); + return b.create(loc, attr).getResult(); }; auto make_kernel_dim3 = [&](const auto& dim3) { @@ -313,13 +306,15 @@ static void LowerThunkToGpuOp(Operation* op, PatternRewriter& rewriter, const auto& launch_dims = kernel_thunk->launch_dimensions(); - rewriter.setInsertionPoint(op); + b.setInsertionPoint(op); auto grid_size = make_kernel_dim3(launch_dims.block_counts()); auto block_size = make_kernel_dim3(launch_dims.thread_counts_per_block()); + auto shmem_size = b.create( + loc, + b.getI32IntegerAttr(kernel_thunk->launch_dimensions().SharedMemBytes())); - rewriter.create(loc, kernel_func, grid_size, block_size, - /*shared_memory_size_bytes=*/nullptr, - kernel_args); + b.create(loc, kernel_func, grid_size, block_size, shmem_size, + kernel_args); } // An overload set for defining predicates for operations that should @@ -331,36 +326,39 @@ static bool HasGpuEmitter(OpTy) { // Select custom calls that have corresponding GPU emitters. static bool HasGpuEmitter(lmhlo::CustomCallOp custom_call) { - llvm::StringRef target = custom_call.getCallTargetName(); - return target == "SliceToDynamic" || target == "PadToStatic"; + return llvm::any_of(kCustomCallIntrinsics, [&](std::string_view name) { + return custom_call.getCallTargetName().equals(name); + }); } -LogicalResult KernelOpsPattern::matchAndRewrite( - ModuleOp module_op, PatternRewriter& rewriter) const { +//===-----------------------------------------------------------------------===/ + +void ConvertLmhloToGpuLaunchPass::runOnOperation() { + ModuleOp module = getOperation(); + // No thunks to lower from. Skip pass. - if (thunk_sequence == nullptr) { - return failure(); - } + if (thunk_sequence_ == nullptr) return signalPassFailure(); + // Collect thunks for rewriting each compatible operation in the module into + // the sequence of device kernel launches. Some operation might have an empty + // thunk sequence (e.g. redundant copy operation that does not require running + // anything on device). absl::flat_hash_map> rewrites; // Get data to rewrite kernel ops without changing the IR. auto walk = [&](auto op_type_tag) { - using OpTy = decltype(op_type_tag); - - return module_op.walk([&](OpTy op) -> WalkResult { + return module.walk([&](decltype(op_type_tag) op) -> WalkResult { if (!HasGpuEmitter(op)) return success(); - auto data = Match(op, *thunk_sequence); - if (!data.ok()) - return rewriter.notifyMatchFailure(op, data.status().message()); + auto data = Match(op, *thunk_sequence_); + if (!data.ok()) return op.emitOpError(data.status().message()); rewrites[op] = std::move(*data); return success(); }); }; - // Compile all operations that have GPU code emitters to the GPU binary, + // Collect all operations that have GPU code emitters. 
if (walk(lmhlo::FusionOp()).wasInterrupted() || walk(lmhlo::RngGetAndUpdateStateOp()).wasInterrupted() || walk(lmhlo::ScatterOp()).wasInterrupted() || @@ -368,37 +366,21 @@ LogicalResult KernelOpsPattern::matchAndRewrite( walk(lmhlo::SortOp()).wasInterrupted() || walk(lmhlo::CustomCallOp()).wasInterrupted() || walk(LaunchFuncOp()).wasInterrupted()) - return failure(); + return signalPassFailure(); - if (rewrites.empty()) { - return rewriter.notifyMatchFailure(module_op, "No kernel ops"); - } + // No operations that should be lowered to sequence of device launches. + if (rewrites.empty()) return; - // Mark module as gpu.container_module. - rewriter.updateRootInPlace(module_op, [&] { - module_op->setAttr(GPUDialect::getContainerModuleAttrName(), - rewriter.getUnitAttr()); - }); + OpBuilder b(module); + SymbolTable symbol_table(module); - // Replace the kernel ops with gpu.launch_func. - SymbolTable symbol_table(module_op); - for (const auto& rewrite : rewrites) { - Rewrite(rewrite.first, rewriter, symbol_table, rewrite.second.get()); + // Replace matched operations with gpu.launch_func's. + for (const auto& [op, thunks] : rewrites) { + Rewrite(op, b, symbol_table, thunks.get()); } - return success(); -} - -//===-----------------------------------------------------------------------===/ - -void ConvertLmhloToGpuLaunchPass::runOnOperation() { - MLIRContext* ctx = &getContext(); - - RewritePatternSet patterns(ctx); - patterns.insert(ctx, thunk_sequence_); - - if (failed(applyOpPatternsAndFold(getOperation(), std::move(patterns)))) - return signalPassFailure(); + // Mark module as gpu.container_module. + module->setAttr(GPUDialect::getContainerModuleAttrName(), b.getUnitAttr()); } std::unique_ptr> diff --git a/tensorflow/compiler/xla/mlir/transforms/gpu/lmhlo_to_gpu_runtime.cc b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/lmhlo_to_gpu_runtime.cc similarity index 73% rename from tensorflow/compiler/xla/mlir/transforms/gpu/lmhlo_to_gpu_runtime.cc rename to tensorflow/compiler/xla/mlir/backends/gpu/transforms/lmhlo_to_gpu_runtime.cc index adfd93fb8f4..e2e0cba0a85 100644 --- a/tensorflow/compiler/xla/mlir/transforms/gpu/lmhlo_to_gpu_runtime.cc +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/lmhlo_to_gpu_runtime.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/functional/any_invocable.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" // from @llvm-project @@ -28,16 +29,19 @@ limitations under the License. 
#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project #include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/xla/mlir/backends/gpu/transforms/uid_generator.h" #include "tensorflow/compiler/xla/mlir/runtime/utils/custom_calls.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" #include "tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.h" #include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h" #include "tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.h" @@ -48,7 +52,7 @@ namespace xla { namespace gpu { #define GEN_PASS_DEF_CONVERTLMHLOTOGPURUNTIMEPASS -#include "tensorflow/compiler/xla/mlir/transforms/gpu/passes.h.inc" +#include "tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h.inc" using namespace mlir; // NOLINT @@ -142,8 +146,43 @@ class CustomCallOpLowering : public OpRewritePattern { CustomCallOpLowering(MLIRContext* ctx, CustomCallDeclarations& custom_calls) : OpRewritePattern(ctx), custom_calls_(custom_calls) {} + // Rewrite custom call with `API_VERSION_TYPED_FFI` version into XLA runtime + // custom calls bypassing custom call adaptor. + LogicalResult rewriteTypedCustomCall(CustomCallOp op, + PatternRewriter& rewriter) const { + // TODO(ezhulenev): Support target arg mapping, or explain why we do not + // need them for typed custom calls. + if (op.getTargetArgMapping()) + return op.emitOpError( + "API_VERSION_TYPED_FFI custom calls do not " + "support target arg mapping"); + + // Create a custom call function declaration. + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + func::FuncOp callee = + custom_calls_.GetOrCreate(b, op.getCallTargetName(), op); + callee->setAttr("rt.dynamic", UnitAttr::get(b.getContext())); + + // Forward backend config to the custom call implementation. + auto dict = op.getBackendConfig() + ? op.getBackendConfig()->cast() + : nullptr; + llvm::SmallVector backend_config(dict.begin(), dict.end()); + + // Call the custom call function forwarding user-defined attributes. + auto call = rewriter.replaceOpWithNewOp( + op, callee.getName(), TypeRange(), op.getOperands()); + AppendCustomCallAttrs(call, backend_config); + + return success(); + } + LogicalResult matchAndRewrite(CustomCallOp op, PatternRewriter& rewriter) const override { + // Typed custom calls lowered directly to XLA runtime custom calls. + if (op.getApiVersion() == mhlo::CustomCallApiVersion::API_VERSION_TYPED_FFI) + return rewriteTypedCustomCall(op, rewriter); + ImplicitLocOpBuilder b(op.getLoc(), rewriter); // By default all operands passed to the custom call handler. 
@@ -204,8 +243,9 @@ class FftOpLowering : public OpRewritePattern { static constexpr const char kCustomCallTarget[] = "xla.gpu.fft"; public: - FftOpLowering(MLIRContext* ctx, CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), custom_calls_(custom_calls) {} + FftOpLowering(MLIRContext* ctx, UidGenerator& uid, + CustomCallDeclarations& custom_calls) + : OpRewritePattern(ctx), uid_(uid), custom_calls_(custom_calls) {} LogicalResult matchAndRewrite(FftOp op, PatternRewriter& rewriter) const override { @@ -215,7 +255,8 @@ class FftOpLowering : public OpRewritePattern { llvm::SmallVector custom_call_attrs = { {b.getStringAttr("fft_length"), op.getFftLengthAttr()}, - {b.getStringAttr("fft_type"), op.getFftTypeAttr()}}; + {b.getStringAttr("fft_type"), op.getFftTypeAttr()}, + {b.getStringAttr("uid"), b.getI64IntegerAttr(uid_.uid())}}; // Convert Fft to a function call. auto call = rewriter.replaceOpWithNewOp( @@ -225,6 +266,7 @@ class FftOpLowering : public OpRewritePattern { } private: + UidGenerator& uid_; CustomCallDeclarations& custom_calls_; }; @@ -347,7 +389,7 @@ class WhileOpLowering : public OpRewritePattern { auto loop = b.create(lb, ub, c1, ValueRange()); // Move body region into the new loop operation. - BlockAndValueMapping mapping; + IRMapping mapping; rewriter.eraseOp(op.getBody().front().getTerminator()); rewriter.mergeBlockBefore(&op.getBody().front(), loop.getLoopBody().front().getTerminator()); @@ -369,7 +411,7 @@ class WhileOpLowering : public OpRewritePattern { Value pred = op.getOperand(0); // Inline condition and body regions into the new loop operation. - BlockAndValueMapping mapping; + IRMapping mapping; rewriter.inlineRegionBefore(op.getCond(), loop.getBefore(), loop.getBefore().begin()); rewriter.inlineRegionBefore(op.getBody(), loop.getAfter(), @@ -433,6 +475,8 @@ using mlir::lmhlo::ReduceScatterOp; using mlir::lmhlo::ReplicaIdOp; using mlir::lmhlo_gpu::AllReduceDoneOp; using mlir::lmhlo_gpu::AllReduceStartOp; +using mlir::lmhlo_gpu::CollectivePermuteDoneOp; +using mlir::lmhlo_gpu::CollectivePermuteStartOp; // We assign unique id to all collective operations in the module, so that we // can efficiently access per-op state at run time. Exception to this rule are @@ -454,7 +498,7 @@ class CollectiveUidGenerator { CollectiveUidGenerator() : cnt_(0) {} // Assings a unique event id to the pair of start and done operations. - int32_t AssignUid(AllReduceStartOp start, AllReduceDoneOp done) { + int32_t AssignUid(Operation* start, Operation* done) { int32_t id = next(); uids_[start] = id; uids_[done] = id; @@ -463,7 +507,8 @@ class CollectiveUidGenerator { FailureOr AssignedUid(Operation* op) { // Async operations must be assigned uid ahead of time. 
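    // For example (a sketch, using the usual lmhlo_gpu assembly names), a
    // start/done pair connected through its token result shares one id:
    //
    //   %token = "lmhlo_gpu.all_reduce_start"(%arg, %out) ...   // uid = 0
    //   "lmhlo_gpu.all_reduce_done"(%token) ...                 // uid = 0
    //
    // AssignUid() records the same id for both operations, so the lowered
    // start and done custom calls can address the shared per-op state at run
    // time.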
- if (isa(op)) { + if (isa(op)) { auto it = uids_.find(op); if (it == uids_.end()) return failure(); return it->second; @@ -488,6 +533,9 @@ class CollectiveOpLowering : public OpRewritePattern { static StringRef Target(CollectivePermuteOp) { return "xla.gpu.collective_permute"; } + static StringRef Target(CollectivePermuteStartOp) { + return "xla.gpu.collective_permute_start"; + } static StringRef Target(AllReduceStartOp) { return "xla.gpu.all_reduce_start"; } @@ -515,6 +563,13 @@ class CollectiveOpLowering : public OpRewritePattern { .config; } + static NcclCollectiveConfig GetNcclCollectiveConfig( + CollectivePermuteStartOp op, int replica_count, int num_partitions) { + return NcclCollectivePermuteStartThunk::GetNcclCollectivePermuteConfig( + op, replica_count, num_partitions) + .config; + } + template static LogicalResult TryDegenerateToMemCopy( NonCollectivePermuteOp op, const NcclCollectiveConfig& config, @@ -532,11 +587,11 @@ class CollectiveOpLowering : public OpRewritePattern { return success(); } - static LogicalResult TryDegenerateToMemCopy( - CollectivePermuteOp op, const NcclCollectiveConfig& config, - int replica_count, int num_partitions, PatternRewriter& rewriter) { - if (!NcclCollectivePermuteThunk::IsDegenerate(op, replica_count, - num_partitions)) { + template + static LogicalResult TryDegenerateCollectivePermuteToMemCopy( + OpT op, const NcclCollectiveConfig& config, int replica_count, + int num_partitions, PatternRewriter& rewriter) { + if (!ThunkT::IsDegenerate(op, replica_count, num_partitions)) { return failure(); } @@ -547,6 +602,21 @@ class CollectiveOpLowering : public OpRewritePattern { return success(); } + static LogicalResult TryDegenerateToMemCopy( + CollectivePermuteOp op, const NcclCollectiveConfig& config, + int replica_count, int num_partitions, PatternRewriter& rewriter) { + return TryDegenerateCollectivePermuteToMemCopy( + op, config, replica_count, num_partitions, rewriter); + } + + static LogicalResult TryDegenerateToMemCopy( + CollectivePermuteStartOp op, const NcclCollectiveConfig& config, + int replica_count, int num_partitions, PatternRewriter& rewriter) { + return TryDegenerateCollectivePermuteToMemCopy< + NcclCollectivePermuteStartThunk>(op, config, replica_count, + num_partitions, rewriter); + } + static bool CanImplement(AllGatherOp op) { return NcclAllGatherThunk::CanImplement(op); } @@ -571,6 +641,10 @@ class CollectiveOpLowering : public OpRewritePattern { return NcclCollectivePermuteThunk::CanImplement(op); } + static bool CanImplement(CollectivePermuteStartOp op) { + return NcclCollectivePermuteStartThunk::CanImplement(op); + } + template static LogicalResult SetSpecificAttrs(ImplicitLocOpBuilder& b, ReduceOp op, func::CallOp call) { @@ -599,9 +673,9 @@ class CollectiveOpLowering : public OpRewritePattern { return success(); } - static LogicalResult SetSpecificAttrs(ImplicitLocOpBuilder& b, - CollectivePermuteOp op, - func::CallOp call) { + template + static LogicalResult SetCollectivePermuteAttrs(ImplicitLocOpBuilder& b, + OpT op, func::CallOp call) { auto source_target_pairs_or = ConvertNx2Attribute(op.getSourceTargetPairs()); if (!source_target_pairs_or.ok()) { @@ -628,6 +702,18 @@ class CollectiveOpLowering : public OpRewritePattern { return success(); } + static LogicalResult SetSpecificAttrs(ImplicitLocOpBuilder& b, + CollectivePermuteOp op, + func::CallOp call) { + return SetCollectivePermuteAttrs(b, op, call); + } + + static LogicalResult SetSpecificAttrs(ImplicitLocOpBuilder& b, + CollectivePermuteStartOp op, + func::CallOp call) 
{ + return SetCollectivePermuteAttrs(b, op, call); + } + public: CollectiveOpLowering(MLIRContext* ctx, CollectiveUidGenerator& uid, CustomCallDeclarations& custom_calls) @@ -730,10 +816,12 @@ class CollectiveOpLowering : public OpRewritePattern { // For asynchonous start operation we need to produce a fake token, that // will be later removed, because corresponding `done` operation doesn't - // have the token argument. We rely on the `unrealized_conversion_cast` - // operation to create a fake token from the `i8` constant. - if (auto start = dyn_cast(op.getOperation())) { - Value token = start.getToken(); + // have a token argument. We rely on the `unrealized_conversion_cast` + // operation to create a fake token from the `i8` constant, and on the dead + // code elimination pass that will remove unused fake tokens. + if constexpr (std::is_same_v || + std::is_same_v) { + Value token = op.getToken(); Value c0 = b.create(b.getI8IntegerAttr(0)); auto fake = b.create(token.getType(), c0); token.replaceAllUsesWith(fake.getResult(0)); @@ -762,27 +850,27 @@ DEFINE_COLLECTIVE_OP_LOWERING(AllReduceStartOp); DEFINE_COLLECTIVE_OP_LOWERING(ReduceScatterOp); DEFINE_COLLECTIVE_OP_LOWERING(AllToAllOp); DEFINE_COLLECTIVE_OP_LOWERING(CollectivePermuteOp); +DEFINE_COLLECTIVE_OP_LOWERING(CollectivePermuteStartOp); #undef DEFINE_COLLECTIVE_OP_LOWERING -class AllReduceDoneOpLowering : public OpRewritePattern { - static constexpr const char kCustomCallTarget[] = "xla.gpu.all_reduce_done"; - +template +class AsyncDoneOpLowering : public OpRewritePattern { public: - AllReduceDoneOpLowering(MLIRContext* ctx, CollectiveUidGenerator& uid, - CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), uid_(uid), custom_calls_(custom_calls) {} + AsyncDoneOpLowering(MLIRContext* ctx, const char* custom_call_target, + CollectiveUidGenerator& uid, + CustomCallDeclarations& custom_calls) + : OpRewritePattern(ctx), + custom_call_target_(custom_call_target), + uid_(uid), + custom_calls_(custom_calls) {} - LogicalResult matchAndRewrite(AllReduceDoneOp op, + LogicalResult matchAndRewrite(OpT op, PatternRewriter& rewriter) const override { - // For done operation we drop the token argument and communicate async event - // dependency through the `uid` attribute. - llvm::SmallVector operands = op.getOperands().drop_front(); - // Get or create a custom call function declaration. ImplicitLocOpBuilder b(op.getLoc(), rewriter); - func::FuncOp callee = custom_calls_.GetOrCreate( - b, kCustomCallTarget, TypeRange(ValueRange(operands)), TypeRange()); + func::FuncOp callee = custom_calls_.GetOrCreate(b, custom_call_target_, + TypeRange(), TypeRange()); // Get a unique collective operation id. FailureOr uid = uid_.AssignedUid(op); @@ -793,18 +881,39 @@ class AllReduceDoneOpLowering : public OpRewritePattern { {b.getStringAttr("uid"), b.getI32IntegerAttr(*uid)}}; // Convert AllReduceDone to a function call. 
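    // A sketch of the resulting IR for an all-reduce pair (uid value assumed):
    // the call carries no operands, and the async event is identified solely
    // by the shared `uid` attribute.
    //
    //   call @xla.gpu.all_reduce_done() {uid = 0 : i32} : () -> ()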
- auto call = rewriter.replaceOpWithNewOp( - op, callee.getName(), TypeRange(), operands); + auto call = rewriter.replaceOpWithNewOp(op, callee.getName(), + TypeRange()); AppendCustomCallAttrs(call, custom_call_attributes); return success(); } private: + const char* custom_call_target_; CollectiveUidGenerator& uid_; CustomCallDeclarations& custom_calls_; }; +class AllReduceDoneOpLowering : public AsyncDoneOpLowering { + static constexpr const char kCustomCallTarget[] = "xla.gpu.all_reduce_done"; + + public: + AllReduceDoneOpLowering(MLIRContext* ctx, CollectiveUidGenerator& uid, + CustomCallDeclarations& custom_calls) + : AsyncDoneOpLowering(ctx, kCustomCallTarget, uid, custom_calls) {} +}; + +class CollectivePermuteDoneOpLowering + : public AsyncDoneOpLowering { + static constexpr const char kCustomCallTarget[] = + "xla.gpu.collective_permute_done"; + + public: + CollectivePermuteDoneOpLowering(MLIRContext* ctx, CollectiveUidGenerator& uid, + CustomCallDeclarations& custom_calls) + : AsyncDoneOpLowering(ctx, kCustomCallTarget, uid, custom_calls) {} +}; + template class CollectiveIdOpLowering : public OpRewritePattern { static StringRef Target(ReplicaIdOp) { return "xla.gpu.replica_id"; } @@ -841,6 +950,131 @@ class PartitionIdOpLowering : public CollectiveIdOpLowering { }; //===----------------------------------------------------------------------===// +// Point-to-Point communication ops lowering (Send/Recv). +//===----------------------------------------------------------------------===// + +template +class SendRecvOpLowering : public OpRewritePattern { + public: + SendRecvOpLowering(MLIRContext* ctx, const char* custom_call_target, + CustomCallDeclarations& custom_calls) + : OpRewritePattern(ctx), + custom_call_target_(custom_call_target), + custom_calls_(custom_calls) {} + + LogicalResult matchAndRewrite(OpT op, + PatternRewriter& rewriter) const override { + // Get or create a custom call function declaration. + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + func::FuncOp callee = custom_calls_.GetOrCreate( + b, custom_call_target_, TypeRange(op.getOperands()), TypeRange()); + + llvm::SmallVector custom_call_attributes = { + {b.getStringAttr("channel_handle"), op.getChannelHandle()}, + {b.getStringAttr("is_host_transfer"), op.getIsHostTransferAttr()}, + {b.getStringAttr("frontend_attributes"), op.getFrontendAttributes()}}; + + // Convert Send/Recv to a function call. + auto call = rewriter.create(op.getLoc(), callee.getName(), + TypeRange(), op.getOperands()); + AppendCustomCallAttrs(call, custom_call_attributes); + + // For communication operation we need to produce a fake token, that will be + // later removed, because corresponding `done` operation doesn't have the + // token argument. We rely on the `unrealized_conversion_cast` operation to + // create a fake token from the `i8` constant. + Value token = op.getResult(); + Value c0 = b.create(b.getI8IntegerAttr(0)); + auto fake = b.create(token.getType(), c0); + token.replaceAllUsesWith(fake.getResult(0)); + + // Erase the original operation. 
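    // A sketch of the rewrite for `lmhlo.send` (attribute values assumed; the
    // lmhlo_send_recv.mlir test below shows the full form):
    //
    //   call @xla.gpu.send(%arg0) {channel_handle = #mhlo.channel_handle<...>,
    //                              frontend_attributes = {...},
    //                              is_host_transfer = true}
    //     : (memref<4xf32>) -> ()
    //   %c0 = arith.constant 0 : i8
    //   %token = builtin.unrealized_conversion_cast %c0 : i8 to !mhlo.token
    //
    // The fake token keeps the matching done operation's operand alive until
    // that operation is lowered too; dead code elimination cleans up the cast.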
+ rewriter.eraseOp(op); + + return success(); + } + + private: + const char* custom_call_target_; + CustomCallDeclarations& custom_calls_; +}; + +template +class SendRecvDoneOpLowering : public OpRewritePattern { + public: + SendRecvDoneOpLowering(MLIRContext* ctx, const char* custom_call_target, + CustomCallDeclarations& custom_calls) + : OpRewritePattern(ctx), + custom_call_target_(custom_call_target), + custom_calls_(custom_calls) {} + + LogicalResult matchAndRewrite(OpT op, + PatternRewriter& rewriter) const override { + // Get or create a custom call function declaration. + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + func::FuncOp callee = custom_calls_.GetOrCreate(b, custom_call_target_, + TypeRange(), TypeRange()); + + llvm::SmallVector custom_call_attributes = { + {b.getStringAttr("channel_handle"), op.getChannelHandleAttr()}, + {b.getStringAttr("is_host_transfer"), op.getIsHostTransferAttr()}}; + + // Convert SendDone/RecvDone to a function call. + auto call = rewriter.replaceOpWithNewOp(op, callee.getName(), + TypeRange()); + AppendCustomCallAttrs(call, custom_call_attributes); + + return success(); + } + + private: + const char* custom_call_target_; + CustomCallDeclarations& custom_calls_; +}; + +struct SendOpLowering : public SendRecvOpLowering { + static constexpr const char kCustomCallTarget[] = "xla.gpu.send"; + SendOpLowering(MLIRContext* ctx, CustomCallDeclarations& custom_calls) + : SendRecvOpLowering(ctx, kCustomCallTarget, custom_calls) {} +}; + +struct SendDoneOpLowering : public SendRecvDoneOpLowering { + static constexpr const char kCustomCallTarget[] = "xla.gpu.send_done"; + SendDoneOpLowering(MLIRContext* ctx, CustomCallDeclarations& custom_calls) + : SendRecvDoneOpLowering(ctx, kCustomCallTarget, custom_calls) {} +}; + +struct RecvOpLowering : public SendRecvOpLowering { + static constexpr const char kCustomCallTarget[] = "xla.gpu.recv"; + RecvOpLowering(MLIRContext* ctx, CustomCallDeclarations& custom_calls) + : SendRecvOpLowering(ctx, kCustomCallTarget, custom_calls) {} +}; + +struct RecvDoneOpLowering : public SendRecvDoneOpLowering { + static constexpr const char kCustomCallTarget[] = "xla.gpu.recv_done"; + RecvDoneOpLowering(MLIRContext* ctx, CustomCallDeclarations& custom_calls) + : SendRecvDoneOpLowering(ctx, kCustomCallTarget, custom_calls) {} +}; + +//===----------------------------------------------------------------------===// + +template +static absl::AnyInvocable GetAsyncUidGenerator( + CollectiveUidGenerator& collective_uid) { + return [&collective_uid](StartOpT start) -> WalkResult { + Value token = start.getToken(); + + // We expect the token to be consumed just once. + if (!token.hasOneUse()) return start.emitOpError("token has multiple uses"); + + // Token must be consumed by the corresponding done operation. + auto done = dyn_cast(*token.getUsers().begin()); + if (!done) return start.emitOpError("illegal token user"); + + collective_uid.AssignUid(start, done); + return WalkResult::advance(); + }; +} void ConvertLmhloToGpuRuntimePass::runOnOperation() { ModuleOp module = getOperation(); @@ -853,25 +1087,21 @@ void ConvertLmhloToGpuRuntimePass::runOnOperation() { // Convert lmhlo operations to XLA gpu runtime custom calls. 
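  // The uid assignment below pairs every async start operation with its done
  // operation by following the token result, e.g. (a sketch, using the usual
  // lmhlo_gpu assembly names):
  //
  //   %t = "lmhlo_gpu.collective_permute_start"(...)   // uid = N
  //   "lmhlo_gpu.collective_permute_done"(%t)          // uid = N
  //
  // A token with more than one user, or with a user other than the matching
  // done operation, makes the walk emit an error and the pass fail.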
RewritePatternSet patterns(ctx); patterns.insert(ctx); - patterns.insert(ctx, custom_calls); + patterns.insert( + ctx, custom_calls); + + UidGenerator fft_uid; + patterns.insert(ctx, fft_uid, custom_calls); // Assign shared unique id to each unique pair of async start-done operations, // all other collective operations will get assigned uid. CollectiveUidGenerator collective_uid; - auto walked = module.walk([&](AllReduceStartOp start) -> WalkResult { - Value token = start.getToken(); - - // We expect the token to be consumed just once. - if (!token.hasOneUse()) return start.emitOpError("token has multiple uses"); - - // Token must be consumed by the corresponding done operation. - auto done = dyn_cast(*token.getUsers().begin()); - if (!done) return start.emitOpError("illegal token user"); - - collective_uid.AssignUid(start, done); - return WalkResult::advance(); - }); + auto walked = module.walk( + GetAsyncUidGenerator(collective_uid)); + if (walked.wasInterrupted()) return signalPassFailure(); + walked = module.walk( + GetAsyncUidGenerator( + collective_uid)); if (walked.wasInterrupted()) return signalPassFailure(); // Convert lmhlo collective operations to XLA gpu runtime custom calls. @@ -879,8 +1109,12 @@ void ConvertLmhloToGpuRuntimePass::runOnOperation() { custom_calls); patterns.insert( - ctx, collective_uid, custom_calls); + CollectivePermuteOpLowering, CollectivePermuteStartOpLowering, + ReduceScatterOpLowering>(ctx, collective_uid, custom_calls); + + // Convert lmhlo point-to-point communication operations to XLA gpu runtime. + patterns.insert(ctx, custom_calls); if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) return signalPassFailure(); @@ -892,7 +1126,8 @@ void ConvertLmhloToGpuRuntimePass::runOnOperation() { // This should be a part of lmhlo operation canonicalization. { RewritePatternSet patterns(ctx); - patterns.insert(ctx, collective_uid, custom_calls); + patterns.insert( + ctx, collective_uid, custom_calls); if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) return signalPassFailure(); } diff --git a/tensorflow/compiler/xla/mlir/transforms/gpu/memref_get_global_to_arg.cc b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/memref_get_global_to_arg.cc similarity index 97% rename from tensorflow/compiler/xla/mlir/transforms/gpu/memref_get_global_to_arg.cc rename to tensorflow/compiler/xla/mlir/backends/gpu/transforms/memref_get_global_to_arg.cc index 8ee51e3cfef..5eb8efcb0fc 100644 --- a/tensorflow/compiler/xla/mlir/transforms/gpu/memref_get_global_to_arg.cc +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/memref_get_global_to_arg.cc @@ -21,13 +21,13 @@ limitations under the License. 
#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project #include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "tensorflow/compiler/xla/mlir/transforms/gpu/passes.h" +#include "tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h" namespace xla { namespace gpu { #define GEN_PASS_DEF_CONVERTMEMREFGETGLOBALTOARGPASS -#include "tensorflow/compiler/xla/mlir/transforms/gpu/passes.h.inc" +#include "tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h.inc" using namespace mlir; // NOLINT diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/outline_cuda_graphs.cc b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/outline_cuda_graphs.cc new file mode 100644 index 00000000000..2c4955a8eb2 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/outline_cuda_graphs.cc @@ -0,0 +1,354 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project +#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Dominance.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/TypeRange.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project +#include "tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir/runtime/ir/rt_dialect.h" +#include "tensorflow/compiler/xla/mlir/runtime/ir/rt_ops.h" +#include "tensorflow/compiler/xla/mlir/runtime/utils/custom_calls.h" +#include "tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" + +namespace xla { +namespace gpu { + +#define GEN_PASS_DEF_OUTLINECUDAGRAPHSPASS +#include "tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h.inc" + +using namespace mlir; // NOLINT + +using mlir::gpu::LaunchFuncOp; + +class OutlineCudaGraphsPass + : public impl::OutlineCudaGraphsPassBase { + void runOnOperation() override; + + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } +}; + +//===----------------------------------------------------------------------===// + +struct OpCapturePattern { + // CUDA-graph-compatible operations can be either moved or cloned into the + // graph capture function. 
Most of the operations should be moved, as they + // have side effects, however small constants and pure operations like + // `memref.view` can be safely cloned into the graph region. We rely on later + // dead code elimination to erase them from the "main" function if they are + // not used by any other operations. + enum class Capture { kMove, kClone }; + + virtual ~OpCapturePattern() = default; + virtual FailureOr match(Operation* op) = 0; +}; + +using OpCapturePatternSet = std::vector>; + +// A sequence of operations to be outlined into cuda graph capture function. +using CaptureSequence = + llvm::SmallVector>; + +//===----------------------------------------------------------------------===// + +template +struct OpCapture : public OpCapturePattern { + FailureOr match(Operation* op) final { + if (isa(op)) return capture; + return failure(); + } +}; + +static constexpr auto kMove = OpCapturePattern::Capture::kMove; +static constexpr auto kClone = OpCapturePattern::Capture::kClone; + +template +using MoveOp = OpCapture; +template +using CloneOp = OpCapture; + +// Capture gpu operations by moving them intp graph capture function. +struct LaunchFuncOpCapture : public MoveOp {}; +struct ConvOpCapture : public MoveOp {}; + +// Capture pure operations by cloning them into graph capture function. +struct ConstantOpCapture : public CloneOp {}; +struct ViewOpCapture : public CloneOp {}; + +//===----------------------------------------------------------------------===// + +// Collect sequences of operations that can be outlined into Cuda Graphs. +static std::vector CollectCaptureSequences( + DominanceInfo& dominance, ModuleOp module, OpCapturePatternSet& patterns) { + std::vector seqs; + + // Match given operation with all capture patterns. + auto match = [&](Operation* op) -> FailureOr { + for (auto& pattern : patterns) { + if (auto matched = pattern->match(op); succeeded(matched)) return matched; + } + return failure(); + }; + + // Find graph-compatible sequences of operations in every block. + module.walk([&](Block* block) { + CaptureSequence* seq = &seqs.emplace_back(); + + for (Operation& op : *block) { + FailureOr matched = match(&op); + // Append matched operation to the current sequence. We only append + // operations that must be moved into the graph capture function (ops with + // side effects), and add cloneable operations later. + if (succeeded(matched) && *matched == kMove) + seq->emplace_back(&op, *matched); + + // Skip unsupported operation and start a new sequence. + if (failed(matched) && !seq->empty()) seq = &seqs.emplace_back(); + } + + // Remove the last sequence if it's empty. + if (seq->empty()) seqs.pop_back(); + }); + + // Remove cloneable operations accidentally captured by the sequence of ops, + // e.g. we can have `memref.view` between two kernel launch operations that + // is not used by operations in the captured sequence. + for (CaptureSequence& seq : seqs) { + llvm::DenseSet moveable_ops; + for (auto& [op, capture] : seq) + if (capture == kMove) moveable_ops.insert(op); + + llvm::erase_if(seq, [&](auto& pair) { + return pair.second == kClone && + llvm::none_of(pair.first->getUsers(), [&](Operation* user) { + return moveable_ops.contains(user); + }); + }); + } + + // Try to extend discovered sequences of ops following operands use-def chains + // and pulling cloneable operations defining operands into the graph capture + // sequence. 
In practice we just clone `arith.constant` and `memref.view` + // operations into the graph capture function, to make it cheaper to compute + // the hash of the arguments at run time. + for (CaptureSequence& seq : seqs) { + llvm::DenseSet seq_ops; // operations already in `seq` + llvm::SmallVector worklist; + + // Add operations that define `op` arguments to the worklist. + auto populate_worklist = [&](Operation* op) { + for (Value arg : op->getOperands()) + if (Operation* op = arg.getDefiningOp()) worklist.push_back(op); + }; + + for (auto& [op, _] : seq) { + seq_ops.insert(op); + populate_worklist(op); + } + + // Find cloneable ops and group them by block where they are defined. + llvm::DenseMap> cloneable; + + // Traverse use-def chains to collect all cloneable operations. + while (!worklist.empty()) { + Operation* op = worklist.pop_back_val(); + if (seq_ops.contains(op)) continue; + + // Check if operation can be cloned into graph capture function. + if (auto matched = match(op); + succeeded(matched) && *matched == OpCapturePattern::Capture::kClone) { + cloneable[op->getBlock()].push_back(op); + seq_ops.insert(op); + populate_worklist(op); + } + } + + // Traverse blocks according to their dominance to avoid used-before-defined + // invalid SSA region construction in graph capture function. + llvm::SmallVector blocks; + for (auto& [block, _] : cloneable) blocks.push_back(block); + llvm::sort(blocks, [&](Block* a, Block* b) { + return dominance.properlyDominates(a, b); + }); + + for (Block* block : llvm::reverse(blocks)) { + // Sort operations according to their original position in the block. + llvm::sort(cloneable[block], [](Operation* a, Operation* b) { + return a->isBeforeInBlock(b); + }); + + // Prepend all cloneable operations to the discovered ops sequence. + auto cloned = llvm::map_range(cloneable[block], [](Operation* op) { + return std::make_pair(op, OpCapturePattern::Capture::kClone); + }); + seq.insert(seq.begin(), cloned.begin(), cloned.end()); + } + } + + return seqs; +} + +//===----------------------------------------------------------------------===// + +using xla::runtime::CustomCallDeclarations; + +static std::vector GetGraphCaptureFuncArgs(const CaptureSequence& seq) { + llvm::SetVector args; + + // Values defined by operations in the capture sequence. + llvm::DenseSet defined_by_seq; + for (auto& [op, _] : seq) + defined_by_seq.insert(op->result_begin(), op->result_end()); + + // Add arguments defined outside of the capture sequence. + for (auto& [op, _] : seq) { + auto external_args = llvm::make_filter_range( + op->getOperands(), + [&](Value arg) { return !defined_by_seq.contains(arg); }); + args.insert(external_args.begin(), external_args.end()); + } + + return args.takeVector(); +} + +// Given a sequence of operations, outline them into a graph capture function +// and replace them with an XLA Gpu runtime function call. +static LogicalResult Outline(unsigned ordinal, + CustomCallDeclarations& custom_calls, + CaptureSequence& seq) { + // Only operations that have to be moved into the graph capture function + // represent Gpu computations. + unsigned num_move_captures = llvm::count_if(seq, [](auto capture) { + return capture.second == OpCapturePattern::Capture::kMove; + }); + if (num_move_captures < 2) return failure(); + + SymbolTable& sym_table = custom_calls.sym_table(); + MLIRContext* ctx = sym_table.getOp()->getContext(); + + // Create a fused location out of LaunchFuncOp operations. 
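  // For reference, the end state this function builds towards looks roughly
  // like the following (a sketch; types and kernel names assumed, see the
  // outline_cuda_graphs.mlir tests below):
  //
  //   rt.export @xla.gpu.cuda.graph.capture ordinal 1
  //   func.func @xla.gpu.cuda.graph.capture(%arg0: memref<16xi8>) {
  //     gpu.launch_func @gpu_module::@fn0 ... args(%arg0 : memref<16xi8>)
  //     gpu.launch_func @gpu_module::@fn1 ... args(%arg0 : memref<16xi8>)
  //     return
  //   }
  //
  //   call @xla.gpu.cuda.graph.launch(%arg0)
  //     {capture = @xla.gpu.cuda.graph.capture} : (memref<16xi8>) -> ()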
+ llvm::SmallVector locations; + for (auto& op : seq) locations.push_back(op.first->getLoc()); + ImplicitLocOpBuilder b(FusedLoc::get(ctx, locations), sym_table.getOp()); + + // Arguments of the graph capture function. + std::vector args = GetGraphCaptureFuncArgs(seq); + + // Create a function in the compiled module. + auto func = b.create( + "xla.gpu.cuda.graph.capture", + FunctionType::get(ctx, TypeRange(ValueRange(args)), TypeRange())); + + // Add graph capture function to the module. + sym_table.insert(func); + + // Export graph capture function to the runtime. + b.setInsertionPoint(func); + b.create(func, ordinal); + + // Create a custom call declaration corresponding to the outlined graph + // capture function. + func::FuncOp graph_launch = custom_calls.GetOrCreate( + b, "xla.gpu.cuda.graph.launch", TypeRange(ValueRange(args)), TypeRange()); + + // Call the cuda graph launch custom call right before the first moved op. + auto insertion_point = llvm::find_if(seq, [](auto capture) { + return capture.second == OpCapturePattern::Capture::kMove; + }); + b.setInsertionPoint(insertion_point->first); + + auto call = b.create(graph_launch.getName(), TypeRange(), args); + call->setAttr(b.getStringAttr("capture"), FlatSymbolRefAttr::get(func)); + + // At this point we successfully added new functions to the module, so we can + // move or clone captured operations from their original location to the graph + // capture function. + Block* body = func.addEntryBlock(); + + // We'll need to replace operands of cloned/moved operations inside the graph + // capture function. + llvm::SmallVector> mappings; // {from, to} mappings + for (auto mapping : llvm::zip(args, func.getArguments())) + mappings.emplace_back(std::get<0>(mapping), std::get<1>(mapping)); + + // Move or clone operations into the graph capture function. + for (auto& [op, capture] : seq) { + if (capture == OpCapturePattern::Capture::kMove) + op->moveBefore(body, body->end()); + + if (capture == OpCapturePattern::Capture::kClone) { + Operation* clone = op->clone(); + OpBuilder::atBlockEnd(body).insert(clone); + + for (auto mapping : llvm::zip(op->getResults(), clone->getResults())) + mappings.emplace_back(std::get<0>(mapping), std::get<1>(mapping)); + } + } + + // Update def-use chains inside the graph capture function. + for (auto mapping : mappings) { + replaceAllUsesInRegionWith(mapping.first, mapping.second, func.getBody()); + } + + // Add a return operation to the graph capture function. 
+ b.setInsertionPointToEnd(body); + b.create(ValueRange()); + + return success(); +} + +//===----------------------------------------------------------------------===// + +void OutlineCudaGraphsPass::runOnOperation() { + SymbolTable sym_table(getOperation()); + CustomCallDeclarations custom_calls(std::move(sym_table)); + + OpCapturePatternSet patterns; + patterns.emplace_back(new LaunchFuncOpCapture()); + patterns.emplace_back(new ConvOpCapture()); + patterns.emplace_back(new ConstantOpCapture()); + patterns.emplace_back(new ViewOpCapture()); + + unsigned ordinal = 1; // entry point will be exported with ordinal 0 + for (auto& seq : CollectCaptureSequences(getAnalysis(), + getOperation(), patterns)) { + if (succeeded(Outline(ordinal, custom_calls, seq))) ordinal++; + } +} + +std::unique_ptr> createOutlineCudaGraphsPass() { + return std::make_unique(); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/mlir/transforms/gpu/passes.cc b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.cc similarity index 74% rename from tensorflow/compiler/xla/mlir/transforms/gpu/passes.cc rename to tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.cc index 058e93632ce..e71c76ab0e7 100644 --- a/tensorflow/compiler/xla/mlir/transforms/gpu/passes.cc +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/mlir/transforms/gpu/passes.h" +#include "tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h" #include #include @@ -26,31 +26,28 @@ namespace gpu { using namespace mlir; // NOLINT -static bool UseExperimentalCudaGraphs() { - std::string_view flag = std::getenv("XLA_GPU_RUNTIME_USE_CUDA_GRAPHS"); - return flag == "true" || flag == "1"; -} - void populateXlaGpuRuntimePasses(mlir::OpPassManager& pm, - ThunkSequence* thunk_sequence) { + ThunkSequence* thunk_sequence, + const GpuPipelineOpts& opts) { // Lower operations with registered IR emitters to Gpu launches. pm.addPass(createConvertLmhloToGpuLaunchPass(thunk_sequence)); + // Clean up IR before converting it to the runtime operations. + pm.addPass(createCSEPass()); + pm.addPass(createCanonicalizerPass()); + // Convert global memrefs corresponding to constant arguments. pm.addPass(createConvertMemrefGetGlobalToArgPass()); pm.addPass(createSymbolDCEPass()); // Clean up unused global constants. + // Outline CUDA-Graph-compatible operations into graph capture functions. + if (opts.enable_cuda_graphs) { + pm.addPass(createOutlineCudaGraphsPass()); + } + // Lower all Gpu operations to the XLA Gpu runtime custom calls. pm.addPass(createConvertLmhloGpuToGpuRuntimePass()); pm.addPass(createConvertLmhloToGpuRuntimePass()); - - // Enable experimental pass that wraps all launch func operations into Cuda - // Graph. Currently it's intended to be a proof of concept and not anywhere - // near production readiness. - if (UseExperimentalCudaGraphs()) { - pm.addPass(createConvertLaunchFuncToCudaGraphPass()); - } - pm.addPass(createConvertGpuToGpuRuntimePass()); // Add performance tracing annotations. 
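  // A usage sketch for the new pipeline options (the pass manager setup and
  // the thunk sequence pointer are assumed to come from the caller):
  //
  //   mlir::OpPassManager pm("builtin.module");
  //   GpuPipelineOpts opts;
  //   opts.enable_cuda_graphs = true;  // opt in to CUDA graph outlining
  //   populateXlaGpuRuntimePasses(pm, thunk_sequence, opts);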
diff --git a/tensorflow/compiler/xla/mlir/transforms/gpu/passes.h b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h similarity index 80% rename from tensorflow/compiler/xla/mlir/transforms/gpu/passes.h rename to tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h index badbc2205cf..1777e06505f 100644 --- a/tensorflow/compiler/xla/mlir/transforms/gpu/passes.h +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_MLIR_TRANSFORMS_GPU_PASSES_H_ -#define TENSORFLOW_COMPILER_XLA_MLIR_TRANSFORMS_GPU_PASSES_H_ +#ifndef TENSORFLOW_COMPILER_XLA_MLIR_BACKENDS_GPU_TRANSFORMS_PASSES_H_ +#define TENSORFLOW_COMPILER_XLA_MLIR_BACKENDS_GPU_TRANSFORMS_PASSES_H_ #include @@ -30,16 +30,24 @@ namespace gpu { #define GEN_PASS_DECL_CONVERTLMHLOTOGPULAUNCHPASS #define GEN_PASS_DECL_CONVERTLMHLOTOGPURUNTIMEPASS #define GEN_PASS_DECL_CONVERTMEMREFGETGLOBALTOARGPASS -#define GEN_PASS_DECL_CONVERTLAUNCHFUNCTOCUDAGRAPHPASS -#include "tensorflow/compiler/xla/mlir/transforms/gpu/passes.h.inc" +#define GEN_PASS_DECL_OUTLINECUDAGRAPHSPASS +#include "tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h.inc" class ThunkSequence; // forward declare +struct GpuPipelineOpts { + // Enable experimental pass that outlines parts of the XLA computation into + // CUDA Graphs, which allows us to amortize the cost of launching multiple + // device kernels. + bool enable_cuda_graphs = false; +}; + // Populate passes that lower MLIR modules from a combination of LMHLO and // LMHLO_GPU dialects to the XLA Gpu runtime. This pipeline is composed from // the passes defined below, and few builtin MLIR passes. void populateXlaGpuRuntimePasses(mlir::OpPassManager& pm, - ThunkSequence* thunk_sequence); + ThunkSequence* thunk_sequence, + const GpuPipelineOpts& opts = {}); //===----------------------------------------------------------------------===// // Auxiliary passes for lowering to XLA Gpu runtime. @@ -83,18 +91,18 @@ std::unique_ptr> createAddHloTraceAnnotationsPass(); //===----------------------------------------------------------------------===// -// XLA runtime <-> Cuda Graphs experimental integration. +// XLA runtime <-> Cuda Graphs integration. 
//===----------------------------------------------------------------------===// std::unique_ptr> -createConvertLaunchFuncToCudaGraphPass(); +createOutlineCudaGraphsPass(); //===-----------------------------------------------------------------------===/ #define GEN_PASS_REGISTRATION -#include "tensorflow/compiler/xla/mlir/transforms/gpu/passes.h.inc" +#include "tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h.inc" } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_MLIR_TRANSFORMS_GPU_PASSES_H_ +#endif // TENSORFLOW_COMPILER_XLA_MLIR_BACKENDS_GPU_TRANSFORMS_PASSES_H_ diff --git a/tensorflow/compiler/xla/mlir/transforms/gpu/passes.td b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.td similarity index 89% rename from tensorflow/compiler/xla/mlir/transforms/gpu/passes.td rename to tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.td index 7374eaad7ed..9608e325811 100644 --- a/tensorflow/compiler/xla/mlir/transforms/gpu/passes.td +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.td @@ -149,17 +149,17 @@ def AddHloTraceAnnotationsPass : } //===----------------------------------------------------------------------===// -// Experimental passes for Xla Gpu <-> Cuda Graphs integration. +// Xla Gpu <-> Cuda Graphs integration. //===----------------------------------------------------------------------===// -def ConvertLaunchFuncToCudaGraphPass : - Pass<"xla-gpu-launch-func-to-cuda-graphs", "mlir::ModuleOp"> { - let summary = "Capture sequence of Gpu function launches as cuda graphs"; +def OutlineCudaGraphsPass : + Pass<"xla-gpu-outline-cuda-graphs", "mlir::ModuleOp"> { + let summary = "Outline sequences of Xla Gpu operations into CUDA Graphs"; let description = [{ - Converts sequences of two or more `gpu.launch_func` operations to Cuda - Graph building functions, and replaces the original sequences with calls to - the Xla Cuda Graph runtime API. + Converts sequences of supported Xla Gpu operations to Cuda Graph capture + functions, and replaces the original sequences with calls to the Xla Cuda + Graph runtime API. Example: @@ -171,20 +171,20 @@ def ConvertLaunchFuncToCudaGraphPass : becomes: ```mlir - // Export cuda graph builder function to Xla runtime. - rt.export @builder ordinal 1 - func.func @builder(@arg0: memref, %arg1: memref) { + // Export cuda graph capture function to Xla runtime. + rt.export @capture ordinal 1 + func.func @capture(@arg0: memref, %arg1: memref) { ... capture a graph corresponding to a sequence of `gpu.launch_func` ops } // Replace a sequence of graph launch operations with a call to runtime API. 
- call @xla.gpu.cuda.graph.execute(%arg0: memref, + call @xla.gpu.cuda.graph.launch(%arg0: memref, %arg1: memref) - attributes { builder = @builder } + attributes { capture = @capture } ``` }]; - let constructor = "createConvertLaunchFuncToCudaGraphPass()"; + let constructor = "createOutlineCudaGraphsPass()"; } #endif // XLA_GPU_PASSES diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/BUILD b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/BUILD new file mode 100644 index 00000000000..3d18ffe7130 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/BUILD @@ -0,0 +1,24 @@ +load("//tensorflow/tsl:tsl.default.bzl", "filegroup") +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) + +glob_lit_tests( + data = [":test_utilities"], + driver = "//tensorflow/compiler/xla:run_lit.sh", + test_file_exts = ["mlir"], +) + +# Bundle together all of the test utilities that are used by tests. +filegroup( + name = "test_utilities", + testonly = True, + data = [ + "//tensorflow/compiler/xla/mlir/backends/gpu:xla-gpu-opt", + "@llvm-project//llvm:FileCheck", + "@llvm-project//mlir:run_lit.sh", + ], +) diff --git a/tensorflow/compiler/xla/mlir/transforms/gpu/tests/add_hlo_trace.mlir b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/add_hlo_trace.mlir similarity index 77% rename from tensorflow/compiler/xla/mlir/transforms/gpu/tests/add_hlo_trace.mlir rename to tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/add_hlo_trace.mlir index bb5cb5dce72..d2f46c8d6bd 100644 --- a/tensorflow/compiler/xla/mlir/transforms/gpu/tests/add_hlo_trace.mlir +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/add_hlo_trace.mlir @@ -1,4 +1,4 @@ -// RUN: xla-gpu-opt %s -xla-add-hlo-trace-annotations | FileCheck %s --dump-input=always +// RUN: xla-gpu-opt %s -xla-add-hlo-trace-annotations | FileCheck %s module attributes { mhlo.unique_id = 42 : i64 } { @@ -7,7 +7,7 @@ func.func private @xla.foo() attributes { rt.custom_call = "xla.foo" } // CHECK: func @func() { func.func @func() { // CHECK: call @xla.foo() - // CHECK-SAME: rt.trace = #rt.hlo_trace<"gemm.name.42", "module-name", 42> + // CHECK-SAME: rt.trace = #rt.hlo_trace<"gemm.name.42"> call @xla.foo() : () -> () loc("gemm.name.42") return } diff --git a/tensorflow/compiler/xla/mlir/transforms/gpu/tests/gpu_launch.mlir b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/gpu_launch.mlir similarity index 68% rename from tensorflow/compiler/xla/mlir/transforms/gpu/tests/gpu_launch.mlir rename to tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/gpu_launch.mlir index 93a69774e55..e05ff982bb3 100644 --- a/tensorflow/compiler/xla/mlir/transforms/gpu/tests/gpu_launch.mlir +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/gpu_launch.mlir @@ -18,33 +18,37 @@ gpu.module @gpu_module attributes {binary = "kernel binary"} { // CHECK: ) func.func @func(%arg0: memref<4x4xf32>, %arg1: memref<4x4xf32>) { // Launch dimensions converted to i32 as a part of the lowering. 
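  // The lowered launch call now also takes the dynamic shared memory size as a
  // leading i32 operand: 0 when the launch does not specify one, and 256 for
  // the second launch below (see %[[C0]] and %[[C256]] in the CHECK lines).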
- // CHECK: %[[C1:.*]] = arith.constant 1 : i32 - // CHECK: %[[C2:.*]] = arith.constant 2 : i32 - // CHECK: %[[C3:.*]] = arith.constant 3 : i32 - // CHECK: %[[C4:.*]] = arith.constant 4 : i32 - // CHECK: %[[C5:.*]] = arith.constant 5 : i32 - // CHECK: %[[C6:.*]] = arith.constant 6 : i32 + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32 + // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : i32 + // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : i32 + // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : i32 + // CHECK-DAG: %[[C5:.*]] = arith.constant 5 : i32 + // CHECK-DAG: %[[C6:.*]] = arith.constant 6 : i32 + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32 + // CHECK-DAG: %[[C256:.*]] = arith.constant 256 : i32 %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index %c3 = arith.constant 3 : index %c4 = arith.constant 4 : index %c5 = arith.constant 5 : index %c6 = arith.constant 6 : index + %c256 = arith.constant 256 : i32 - // CHECK: call @[[LAUNCH:[_a-z.]+]](%[[C1]], %[[C2]], %[[C3]], %[[C4]], - // CHECK-SAME: %[[C5]], %[[C6]], %[[ARG0]], %[[ARG1]]) + // CHECK: call @[[LAUNCH:[_a-z.]+]](%[[C0]], %[[C1]], %[[C2]], %[[C3]], + // CHECK-SAME: %[[C4]], %[[C5]], %[[C6]], %[[ARG0]], %[[ARG1]]) // CHECK-SAME: kernel = "fn0" gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c2, %c3) threads in (%c4, %c5, %c6) args(%arg0 : memref<4x4xf32>, %arg1 : memref<4x4xf32>) - // CHECK: call @[[LAUNCH]](%[[C3]], %[[C2]], %[[C1]], %[[C6]], + // CHECK: call @[[LAUNCH]](%[[C256]], %[[C3]], %[[C2]], %[[C1]], %[[C6]], // CHECK-SAME: %[[C5]], %[[C4]], %[[ARG0]], %[[ARG1]]) // CHECK-DAG: kernel = "fn1" gpu.launch_func @gpu_module::@fn1 blocks in (%c3, %c2, %c1) threads in (%c6, %c5, %c4) + dynamic_shared_memory_size %c256 args(%arg0 : memref<4x4xf32>, %arg1 : memref<4x4xf32>) func.return diff --git a/tensorflow/compiler/xla/mlir/transforms/gpu/tests/gpu_memcpy.mlir b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/gpu_memcpy.mlir similarity index 100% rename from tensorflow/compiler/xla/mlir/transforms/gpu/tests/gpu_memcpy.mlir rename to tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/gpu_memcpy.mlir diff --git a/tensorflow/compiler/xla/mlir/transforms/gpu/tests/gpu_memset.mlir b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/gpu_memset.mlir similarity index 100% rename from tensorflow/compiler/xla/mlir/transforms/gpu/tests/gpu_memset.mlir rename to tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/gpu_memset.mlir diff --git a/tensorflow/compiler/xla/mlir/transforms/gpu/tests/lmhlo_case.mlir b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/lmhlo_case.mlir similarity index 100% rename from tensorflow/compiler/xla/mlir/transforms/gpu/tests/lmhlo_case.mlir rename to tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/lmhlo_case.mlir diff --git a/tensorflow/compiler/xla/mlir/transforms/gpu/tests/lmhlo_custom_call.mlir b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/lmhlo_custom_call.mlir similarity index 95% rename from tensorflow/compiler/xla/mlir/transforms/gpu/tests/lmhlo_custom_call.mlir rename to tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/lmhlo_custom_call.mlir index 1013cf446dd..0a1b067f50c 100644 --- a/tensorflow/compiler/xla/mlir/transforms/gpu/tests/lmhlo_custom_call.mlir +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/lmhlo_custom_call.mlir @@ -10,7 +10,7 @@ func.func @test(%arg0: memref) { // CHECK-SAME: backend_config = "" // CHECK-SAME: call_target_name = "target" // CHECK-SAME: : (memref) -> () - 
"lmhlo.custom_call"(%arg0) { + "lmhlo.custom_call"(%arg0) ({}) { api_version = 2 : i32, backend_config = "", call_target_name = "target", @@ -44,7 +44,7 @@ func.func @test_with_mapping( // CHECK-SAME: api_version = 1 : i32 // CHECK-SAME: backend_config = "" // CHECK-SAME: call_target_name = "target" - "lmhlo.custom_call"(%arg0, %arg1, %arg2, %arg3, %arg4) { + "lmhlo.custom_call"(%arg0, %arg1, %arg2, %arg3, %arg4) ({}) { api_version = 1 : i32, backend_config = "", call_target_name = "target", diff --git a/tensorflow/compiler/xla/mlir/transforms/gpu/tests/lmhlo_fft.mlir b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/lmhlo_fft.mlir similarity index 96% rename from tensorflow/compiler/xla/mlir/transforms/gpu/tests/lmhlo_fft.mlir rename to tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/lmhlo_fft.mlir index 99c97350fc9..aeb19228d01 100644 --- a/tensorflow/compiler/xla/mlir/transforms/gpu/tests/lmhlo_fft.mlir +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/lmhlo_fft.mlir @@ -10,6 +10,7 @@ func.func @compute(%arg0: memref<3x5x16x5xcomplex>, // CHECK: call @[[FFT:.*]](%[[ARG0]], %[[ARG1]]) // CHECK-SAME: fft_length = dense<[16, 8]> : tensor<2xi64> // CHECK-SAME: fft_type = #mhlo + // CHECK-SAME: uid = 0 : i64 "lmhlo.fft"(%arg0, %arg1) { fft_length = dense<[16, 8]> : tensor<2xi64>, fft_type = #mhlo diff --git a/tensorflow/compiler/xla/mlir/transforms/gpu/tests/lmhlo_gpu_cholesky.mlir b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_cholesky.mlir similarity index 100% rename from tensorflow/compiler/xla/mlir/transforms/gpu/tests/lmhlo_gpu_cholesky.mlir rename to tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_cholesky.mlir diff --git a/tensorflow/compiler/xla/mlir/transforms/gpu/tests/lmhlo_gpu_conv.mlir b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_conv.mlir similarity index 100% rename from tensorflow/compiler/xla/mlir/transforms/gpu/tests/lmhlo_gpu_conv.mlir rename to tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_conv.mlir diff --git a/tensorflow/compiler/xla/mlir/transforms/gpu/tests/lmhlo_gpu_cublas_lt_matmul.mlir b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_cublas_lt_matmul.mlir similarity index 94% rename from tensorflow/compiler/xla/mlir/transforms/gpu/tests/lmhlo_gpu_cublas_lt_matmul.mlir rename to tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_cublas_lt_matmul.mlir index 15f0b199906..d5f99754fd9 100644 --- a/tensorflow/compiler/xla/mlir/transforms/gpu/tests/lmhlo_gpu_cublas_lt_matmul.mlir +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_cublas_lt_matmul.mlir @@ -34,7 +34,8 @@ func.func @compute(%a: memref<2x6x2x2xf32>, lhs_contracting_dimensions = [3], rhs_contracting_dimensions = [2]>, epilogue = #lmhlo_gpu, - precision_config = [#mhlo, #mhlo] + precision_config = [#mhlo, #mhlo], + operand_segment_sizes = array } : (memref<2x6x2x2xf32>, memref<2x6x2x2xf32>, memref<2x6x2x2xf32>, memref<2x6x2x2xf32>) -> () @@ -70,7 +71,7 @@ func.func @compute(%a: memref<2x6x2x2xf32>, // CHECK-SAME: rhs_batching_dimensions = [0, 1], // CHECK-SAME: lhs_contracting_dimensions = [3], // CHECK-SAME: rhs_contracting_dimensions = [2]> - // CHECK-SAME: epilogue = #lmhlo_gpu + // CHECK-SAME: epilogue = #lmhlo_gpu // CHECK-SAME: precision = dense<0> : tensor<2xi32> // CHECK-SAME: uid = 0 : i64 "lmhlo_gpu.cublas.lt.matmul"(%a, %b, %c, %d, %bias) { @@ -83,8 +84,9 @@ func.func @compute(%a: memref<2x6x2x2xf32>, 
rhs_batching_dimensions = [0, 1], lhs_contracting_dimensions = [3], rhs_contracting_dimensions = [2]>, - epilogue = #lmhlo_gpu, - precision_config = [#mhlo, #mhlo] + epilogue = #lmhlo_gpu, + precision_config = [#mhlo, #mhlo], + operand_segment_sizes = array } : (memref<2x6x2x2xf32>, memref<2x6x2x2xf32>, memref<2x6x2x2xf32>, memref<2x6x2x2xf32>, memref<2x6x2x2xf32>) -> () diff --git a/tensorflow/compiler/xla/mlir/transforms/gpu/tests/lmhlo_gpu_gemm.mlir b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_gemm.mlir similarity index 100% rename from tensorflow/compiler/xla/mlir/transforms/gpu/tests/lmhlo_gpu_gemm.mlir rename to tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/lmhlo_gpu_gemm.mlir diff --git a/tensorflow/compiler/xla/mlir/transforms/gpu/tests/lmhlo_infeed.mlir b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/lmhlo_infeed.mlir similarity index 100% rename from tensorflow/compiler/xla/mlir/transforms/gpu/tests/lmhlo_infeed.mlir rename to tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/lmhlo_infeed.mlir diff --git a/tensorflow/compiler/xla/mlir/transforms/gpu/tests/lmhlo_outfeed.mlir b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/lmhlo_outfeed.mlir similarity index 100% rename from tensorflow/compiler/xla/mlir/transforms/gpu/tests/lmhlo_outfeed.mlir rename to tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/lmhlo_outfeed.mlir diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/lmhlo_send_recv.mlir b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/lmhlo_send_recv.mlir new file mode 100644 index 00000000000..f51bc6a1657 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/lmhlo_send_recv.mlir @@ -0,0 +1,102 @@ +// RUN: xla-gpu-opt %s -split-input-file -xla-lmhlo-to-gpu-runtime \ +// RUN: | FileCheck %s + +// CHECK: func @send( +// CHECK: %[[ARG0:[a-z0-9]+]]: memref<4xf32> +// CHECK: ) +func.func @send(%arg0: memref<4xf32>) { + // CHECK: call @xla.gpu.send(%[[ARG0]]) { + // CHECK-SAME: channel_handle = #mhlo.channel_handle, + // CHECK-SAME: frontend_attributes = { + // CHECK-SAME: _xla_dcn_recv_channel = "2", + // CHECK-SAME: _xla_host_transfer_handler_name = "undef", + // CHECK-SAME: _xla_host_transfer_is_lower_bits = "false", + // CHECK-SAME: _xla_host_transfer_original_type = "f32", + // CHECK-SAME: _xla_host_transfer_rendezvous = "undef" + // CHECK-SAME: }, + // CHECK-SAME: is_host_transfer = true + // CHECK-SAME: } : (memref<4xf32>) -> () + "lmhlo.send"(%arg0) { + channel_handle = #mhlo.channel_handle, + frontend_attributes = {_xla_dcn_recv_channel = "2", + _xla_host_transfer_handler_name = "undef", + _xla_host_transfer_is_lower_bits = "false", + _xla_host_transfer_original_type = "f32", + _xla_host_transfer_rendezvous = "undef"}, + is_host_transfer = true + } : (memref<4xf32>) -> !mhlo.token + return +} + +// CHECK: func private @xla.gpu.send(memref<4xf32>) +// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.send"} + +// ----- + +// CHECK: func @recv( +// CHECK: %[[ARG0:[a-z0-9]+]]: memref<4xf32> +// CHECK: ) +func.func @recv(%arg0: memref<4xf32>) { + // CHECK: call @xla.gpu.recv(%[[ARG0]]) { + // CHECK-SAME: channel_handle = #mhlo.channel_handle, + // CHECK-SAME: frontend_attributes = { + // CHECK-SAME: _xla_host_transfer_handler_name = "undef", + // CHECK-SAME: _xla_host_transfer_is_lower_bits = "false", + // CHECK-SAME: _xla_host_transfer_original_type = "f32", + // CHECK-SAME: _xla_host_transfer_rendezvous = "undef" + // CHECK-SAME: }, + // 
CHECK-SAME: is_host_transfer = true + // CHECK-SAME: } : (memref<4xf32>) -> () + "lmhlo.recv"(%arg0) { + channel_handle = #mhlo.channel_handle, + frontend_attributes = {_xla_host_transfer_handler_name = "undef", + _xla_host_transfer_is_lower_bits = "false", + _xla_host_transfer_original_type = "f32", + _xla_host_transfer_rendezvous = "undef"}, + is_host_transfer = true + } : (memref<4xf32>) -> !mhlo.token + return +} + +// CHECK: func private @xla.gpu.recv(memref<4xf32>) +// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.recv"} + +// ----- + +// CHECK: func @send_done( +// CHECK: %[[ARG0:[a-z0-9]+]]: !mhlo.token +// CHECK: ) +func.func @send_done(%arg0: !mhlo.token) { + // CHECK: call @xla.gpu.send_done() { + // CHECK-SAME: channel_handle = #mhlo.channel_handle, + // CHECK-SAME: is_host_transfer = true + // CHECK-SAME: } : () -> () + "lmhlo.send_done"(%arg0) { + channel_handle = #mhlo.channel_handle, + is_host_transfer = true + } : (!mhlo.token) -> () + return +} + +// CHECK: func private @xla.gpu.send_done() +// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.send_done"} + +// ----- + +// CHECK: func @recv_done( +// CHECK: %[[ARG0:[a-z0-9]+]]: !mhlo.token +// CHECK: ) +func.func @recv_done(%arg0: !mhlo.token) { + // CHECK: call @xla.gpu.recv_done() { + // CHECK-SAME: channel_handle = #mhlo.channel_handle, + // CHECK-SAME: is_host_transfer = true + // CHECK-SAME: } : () -> () + "lmhlo.recv_done"(%arg0) { + channel_handle = #mhlo.channel_handle, + is_host_transfer = true + } : (!mhlo.token) -> () + return +} + +// CHECK: func private @xla.gpu.recv_done() +// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.recv_done"} diff --git a/tensorflow/compiler/xla/mlir/transforms/gpu/tests/lmhlo_while.mlir b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/lmhlo_while.mlir similarity index 100% rename from tensorflow/compiler/xla/mlir/transforms/gpu/tests/lmhlo_while.mlir rename to tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/lmhlo_while.mlir diff --git a/tensorflow/compiler/xla/mlir/transforms/gpu/tests/memref_get_global_to_arg.mlir b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/memref_get_global_to_arg.mlir similarity index 99% rename from tensorflow/compiler/xla/mlir/transforms/gpu/tests/memref_get_global_to_arg.mlir rename to tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/memref_get_global_to_arg.mlir index 9583a5b7fd3..6361c77f145 100644 --- a/tensorflow/compiler/xla/mlir/transforms/gpu/tests/memref_get_global_to_arg.mlir +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/memref_get_global_to_arg.mlir @@ -38,6 +38,6 @@ func.func @get_global(%arg0: memref<24xi8> {lmhlo.constant_name = "cst0"}, %2 = memref.get_global @cst2 : memref<2x3xf32, #map> // CHECK: return %[[V0]], %[[V1]], %[[V2]] - // CHECK-SAME: : memref<2x3xf32>, memref, memref<2x3xf32, #map> + // CHECK-SAME: : memref<2x3xf32>, memref, memref<2x3xf32, #map{{[0-9]*}}> return %0, %1, %2 : memref<2x3xf32>, memref, memref<2x3xf32, #map> } diff --git a/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/outline_cuda_graphs.mlir b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/outline_cuda_graphs.mlir new file mode 100644 index 00000000000..72593f188c1 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/tests/outline_cuda_graphs.mlir @@ -0,0 +1,290 @@ +// RUN: xla-gpu-opt %s --split-input-file -xla-gpu-outline-cuda-graphs \ +// RUN: | FileCheck %s + +module attributes {gpu.container_module} { + +gpu.module @gpu_module attributes 
{binary = "kernel binary"} { + gpu.func @fn0(%arg0: memref) kernel { + gpu.return + } + gpu.func @fn1(%arg0: memref) kernel { + gpu.return + } +} + +// CHECK: @func( +// CHECK: %[[ARG0:.*]]: memref, +// CHECK: %[[ARG1:.*]]: memref +// CHECK: ) +func.func @func(%arg0: memref, %arg1: memref) { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + %c5 = arith.constant 5 : index + %c6 = arith.constant 6 : index + + // CHECK: call @xla.gpu.cuda.graph.launch(%[[ARG0]], %[[ARG1]]) + // CHECK-SAME: {capture = @xla.gpu.cuda.graph.capture} + // CHECK-NEXT: return + + gpu.launch_func @gpu_module::@fn0 + blocks in (%c1, %c2, %c3) + threads in (%c4, %c5, %c6) + args(%arg0 : memref) + + gpu.launch_func @gpu_module::@fn1 + blocks in (%c3, %c2, %c1) + threads in (%c6, %c5, %c4) + args(%arg1 : memref) + + func.return +} + +// CHECK: func @xla.gpu.cuda.graph.capture +// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 +// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 +// CHECK-NEXT: %[[C3:.*]] = arith.constant 3 +// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 +// CHECK-NEXT: %[[C5:.*]] = arith.constant 5 +// CHECK-NEXT: %[[C6:.*]] = arith.constant 6 +// CHECK-NEXT: gpu.launch_func @gpu_module::@fn0 +// CHECK-SAME: blocks in (%[[C1]], %[[C2]], %[[C3]]) +// CHECK-SAME: threads in (%[[C4]], %[[C5]], %[[C6]]) +// CHECK-NEXT: gpu.launch_func @gpu_module::@fn1 +// CHECK-SAME: blocks in (%[[C3]], %[[C2]], %[[C1]]) +// CHECK-SAME: threads in (%[[C6]], %[[C5]], %[[C4]]) +// CHECK-NEXT: return + +// CHECK: func private @xla.gpu.cuda.graph.launch(memref, memref) +// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.cuda.graph.launch"} +} + +// ----- +// Check that single function launch was not outlined into graph capture. + +module attributes {gpu.container_module} { + +gpu.module @gpu_module attributes {binary = "kernel binary"} { + gpu.func @fn0(%arg0: memref) kernel { + gpu.return + } +} + +// CHECK: @func(%[[ARG0:.*]]: memref) +func.func @func(%arg0: memref) { + %c1 = arith.constant 1 : index + + // CHECK: gpu.launch_func {{.*}} args(%[[ARG0]] : memref) + // CHECK-NOT: call @xla.gpu.cuda.graph.launch + gpu.launch_func @gpu_module::@fn0 + blocks in (%c1, %c1, %c1) + threads in (%c1, %c1, %c1) + args(%arg0 : memref) + + func.return +} + +} + +// ----- +// Check that two different sequences are outlined in different capture +// functions. + +module attributes {gpu.container_module} { + +gpu.module @gpu_module attributes {binary = "kernel binary"} { + gpu.func @fn0(%arg0: memref) kernel { + gpu.return + } + gpu.func @fn1(%arg0: memref) kernel { + gpu.return + } +} + +// CHECK: @func(%[[ARG0:.*]]: memref) +func.func @func(%arg0: memref) { + // CHECK: %[[C1:.*]] = arith.constant 1 + %c1 = arith.constant 1 : index + + // CHECK: call @xla.gpu.cuda.graph.launch(%[[ARG0]]) + // CHECK-SAME: {capture = @[[CAPTURE:.*]]} + + gpu.launch_func @gpu_module::@fn0 + blocks in (%c1, %c1, %c1) + threads in (%c1, %c1, %c1) + args(%arg0 : memref) + + gpu.launch_func @gpu_module::@fn1 + blocks in (%c1, %c1, %c1) + threads in (%c1, %c1, %c1) + args(%arg0 : memref) + + // CHECK: %[[C2:.*]] = arith.constant 2 + %c2 = arith.constant 2 : index + + // Use function call to break the captured ops sequence. 
+ // CHECK: call @external + call @external(): () -> () + + // CHECK: call @xla.gpu.cuda.graph.launch(%[[ARG0]]) + // CHECK-SAME: {capture = @[[CAPTURE_0:.*]]} + + gpu.launch_func @gpu_module::@fn1 + blocks in (%c2, %c2, %c2) + threads in (%c2, %c2, %c2) + args(%arg0 : memref) + + gpu.launch_func @gpu_module::@fn0 + blocks in (%c2, %c2, %c2) + threads in (%c2, %c2, %c2) + args(%arg0 : memref) + + func.return +} + +func.func private @external() + +// CHECK: rt.export @[[CAPTURE]] +// CHECK: func.func @[[CAPTURE]](%arg0: memref) +// CHECK-NEXT: arith.constant 1 +// CHECK-NEXT: gpu.launch_func @gpu_module::@fn0 +// CHECK-NEXT: gpu.launch_func @gpu_module::@fn1 + +// CHECK: rt.export @[[CAPTURE_0]] +// CHECK: func.func @[[CAPTURE_0]](%arg0: memref) +// CHECK-NEXT: arith.constant 2 +// CHECK-NEXT: gpu.launch_func @gpu_module::@fn1 +// CHECK-NEXT: gpu.launch_func @gpu_module::@fn0 + +} + +// ----- +// Check that constants from the different basic blocks are cloned into the +// graph capture function. + +module attributes {gpu.container_module} { + +gpu.module @gpu_module attributes {binary = "kernel binary"} { + gpu.func @fn0(%arg0: memref) kernel { + gpu.return + } + gpu.func @fn1(%arg0: memref) kernel { + gpu.return + } +} + +// CHECK: @func( +// CHECK: %[[ARG0:.*]]: memref, +// CHECK: %[[ARG1:.*]]: memref +// CHECK: ) +func.func @func(%arg0: memref, %arg1: memref) { + cf.br ^bb2 +^bb1: + // CHECK: call @xla.gpu.cuda.graph.launch(%[[ARG0]], %[[ARG1]]) + // CHECK-SAME: {capture = @xla.gpu.cuda.graph.capture} + // CHECK-NEXT: return + + gpu.launch_func @gpu_module::@fn0 + blocks in (%c1, %c1, %c1) + threads in (%c1, %c1, %c1) + args(%arg0 : memref) + + gpu.launch_func @gpu_module::@fn1 + blocks in (%c1, %c1, %c1) + threads in (%c1, %c1, %c1) + args(%arg1 : memref) + + func.return + +^bb2: + %c1 = arith.constant 1 : index + cf.br ^bb1 +} +} + +// CHECK: func @xla.gpu.cuda.graph.capture +// CHECK-NEXT: arith.constant 1 +// CHECK-NEXT: gpu.launch_func @gpu_module::@fn0 +// CHECK-NEXT: gpu.launch_func @gpu_module::@fn1 +// CHECK-NEXT: return + +// ----- +// Check that memref.view operations are cloned into the graph capture function. + +module attributes {gpu.container_module} { + +gpu.module @gpu_module attributes {binary = "kernel binary"} { + gpu.func @fn0(%arg0: memref<4xf32>) kernel { gpu.return } + gpu.func @fn1(%arg0: memref<4xf32>) kernel { gpu.return } +} + +// CHECK: @func(%[[ARG0:.*]]: memref<16xi8>) +func.func @func(%arg0: memref<16xi8>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %view = memref.view %arg0[%c0][] : memref<16xi8> to memref<4xf32> + + call @external() : () -> () + + // CHECK: call @xla.gpu.cuda.graph.launch(%[[ARG0]]) + // CHECK-SAME: {capture = @xla.gpu.cuda.graph.capture} + // CHECK-NEXT: return + gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) + threads in (%c1, %c1, %c1) args(%view : memref<4xf32>) + gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) + threads in (%c1, %c1, %c1) args(%view : memref<4xf32>) + + func.return +} + +func.func private @external() +} + +// CHECK: func @xla.gpu.cuda.graph.capture +// CHECK-NEXT: arith.constant 0 +// CHECK-NEXT: arith.constant 1 +// CHECK-NEXT: memref.view +// CHECK-NEXT: gpu.launch_func @gpu_module::@fn0 +// CHECK-NEXT: gpu.launch_func @gpu_module::@fn1 +// CHECK-NEXT: return + +// ----- +// Check that memref.view not used by operations in the captured graph will not +// be moved into the graph capture function. 
+ +module attributes {gpu.container_module} { + +gpu.module @gpu_module attributes {binary = "kernel binary"} { + gpu.func @fn0(%arg0: memref<16xi8>) kernel { gpu.return } + gpu.func @fn1(%arg0: memref<16xi8>) kernel { gpu.return } +} + +// CHECK: @func(%[[ARG0:.*]]: memref<16xi8>) +func.func @func(%arg0: memref<16xi8>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + + call @external() : () -> () + + // CHECK: call @xla.gpu.cuda.graph.launch(%[[ARG0]]) + // CHECK-SAME: {capture = @xla.gpu.cuda.graph.capture} + // CHECK-NEXT: memref.view + // CHECK-NEXT: return + gpu.launch_func @gpu_module::@fn0 blocks in (%c1, %c1, %c1) + threads in (%c1, %c1, %c1) args(%arg0 : memref<16xi8>) + %view = memref.view %arg0[%c0][] : memref<16xi8> to memref<4xf32> + gpu.launch_func @gpu_module::@fn1 blocks in (%c1, %c1, %c1) + threads in (%c1, %c1, %c1) args(%arg0 : memref<16xi8>) + + func.return +} + +func.func private @external() +} + +// CHECK: func @xla.gpu.cuda.graph.capture +// CHECK-NEXT: arith.constant 1 +// CHECK-NEXT: gpu.launch_func @gpu_module::@fn0 +// CHECK-NEXT: gpu.launch_func @gpu_module::@fn1 +// CHECK-NEXT: return diff --git a/tensorflow/compiler/xla/mlir/transforms/gpu/uid_generator.h b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/uid_generator.h similarity index 83% rename from tensorflow/compiler/xla/mlir/transforms/gpu/uid_generator.h rename to tensorflow/compiler/xla/mlir/backends/gpu/transforms/uid_generator.h index 4740363e13b..9b3a925e0a7 100644 --- a/tensorflow/compiler/xla/mlir/transforms/gpu/uid_generator.h +++ b/tensorflow/compiler/xla/mlir/backends/gpu/transforms/uid_generator.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_MLIR_TRANSFORMS_GPU_UID_GENERATOR_H_ -#define TENSORFLOW_COMPILER_XLA_MLIR_TRANSFORMS_GPU_UID_GENERATOR_H_ +#ifndef TENSORFLOW_COMPILER_XLA_MLIR_BACKENDS_GPU_TRANSFORMS_UID_GENERATOR_H_ +#define TENSORFLOW_COMPILER_XLA_MLIR_BACKENDS_GPU_TRANSFORMS_UID_GENERATOR_H_ #include @@ -39,4 +39,4 @@ class UidGenerator { } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_MLIR_TRANSFORMS_GPU_UID_GENERATOR_H_ +#endif // TENSORFLOW_COMPILER_XLA_MLIR_BACKENDS_GPU_TRANSFORMS_UID_GENERATOR_H_ diff --git a/tensorflow/compiler/xla/mlir/tools/xla_gpu_opt.cc b/tensorflow/compiler/xla/mlir/backends/gpu/xla-gpu-opt.cc similarity index 72% rename from tensorflow/compiler/xla/mlir/tools/xla_gpu_opt.cc rename to tensorflow/compiler/xla/mlir/backends/gpu/xla-gpu-opt.cc index ab76a186863..74a6552f191 100644 --- a/tensorflow/compiler/xla/mlir/tools/xla_gpu_opt.cc +++ b/tensorflow/compiler/xla/mlir/backends/gpu/xla-gpu-opt.cc @@ -17,15 +17,16 @@ limitations under the License. 
#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project #include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project #include "mlir/Tools/mlir-opt/MlirOptMain.h" // from @llvm-project -#include "tensorflow/compiler/xla/mlir/transforms/gpu/passes.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h" +#include "tensorflow/compiler/xla/mlir/backends/gpu/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" int main(int argc, char **argv) { mlir::DialectRegistry registry; - registry.insert(); + registry + .insert(); xla::gpu::registerGpuTransformsPasses(); diff --git a/tensorflow/compiler/xla/mlir/framework/ir/BUILD b/tensorflow/compiler/xla/mlir/framework/ir/BUILD new file mode 100644 index 00000000000..07f6d640466 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/framework/ir/BUILD @@ -0,0 +1,74 @@ +load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//visibility:public"], + licenses = ["notice"], +) + +td_library( + name = "td_files", + srcs = [ + "xla_framework_ops.td", + ], + compatible_with = get_compatible_with_cloud(), + deps = [ + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + ], +) + +gentbl_cc_library( + name = "xla_framework_inc_gen", + compatible_with = get_compatible_with_cloud(), + tbl_outs = [ + ( + ["-gen-op-decls"], + "xla_framework.h.inc", + ), + ( + ["-gen-op-defs"], + "xla_framework.cc.inc", + ), + ( + ["-gen-dialect-decls"], + "xla_framework_dialect.h.inc", + ), + ( + ["-gen-dialect-defs"], + "xla_framework_dialect.cc.inc", + ), + ( + ["-gen-typedef-decls"], + "xla_framework_types.h.inc", + ), + ( + ["-gen-typedef-defs"], + "xla_framework_types.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "xla_framework_ops.td", + deps = [":td_files"], +) + +cc_library( + name = "xla_framework", + srcs = [ + "xla_framework.cc", + "xla_framework.cc.inc", + "xla_framework.h.inc", + ], + hdrs = ["xla_framework.h"], + deps = [ + ":xla_framework_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:DialectUtils", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Support", + ], +) diff --git a/tensorflow/compiler/mlir/xla/ir/xla_framework.cc b/tensorflow/compiler/xla/mlir/framework/ir/xla_framework.cc similarity index 74% rename from tensorflow/compiler/mlir/xla/ir/xla_framework.cc rename to tensorflow/compiler/xla/mlir/framework/ir/xla_framework.cc index 4e9272ac79a..9ad145d764d 100644 --- a/tensorflow/compiler/mlir/xla/ir/xla_framework.cc +++ b/tensorflow/compiler/xla/mlir/framework/ir/xla_framework.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ // This file defines the operations used in the xla_framework dialect. 
-#include "tensorflow/compiler/mlir/xla/ir/xla_framework.h" +#include "tensorflow/compiler/xla/mlir/framework/ir/xla_framework.h" #include "llvm/ADT/TypeSwitch.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project @@ -22,9 +22,9 @@ limitations under the License. #include "mlir/IR/DialectImplementation.h" // from @llvm-project // Generated dialect definitions. -#include "tensorflow/compiler/mlir/xla/ir/xla_framework_dialect.cc.inc" +#include "tensorflow/compiler/xla/mlir/framework/ir/xla_framework_dialect.cc.inc" #define GET_TYPEDEF_CLASSES -#include "tensorflow/compiler/mlir/xla/ir/xla_framework_types.cc.inc" +#include "tensorflow/compiler/xla/mlir/framework/ir/xla_framework_types.cc.inc" namespace mlir { namespace xla_framework { @@ -33,13 +33,13 @@ namespace xla_framework { void XLAFrameworkDialect::initialize() { addOperations< #define GET_OP_LIST -#include "tensorflow/compiler/mlir/xla/ir/xla_framework.cc.inc" +#include "tensorflow/compiler/xla/mlir/framework/ir/xla_framework.cc.inc" #undef GET_OP_LIST >(); addTypes< #define GET_TYPEDEF_LIST -#include "tensorflow/compiler/mlir/xla/ir/xla_framework_types.cc.inc" +#include "tensorflow/compiler/xla/mlir/framework/ir/xla_framework_types.cc.inc" #undef GET_TYPEDEF_LIST >(); } @@ -48,4 +48,4 @@ void XLAFrameworkDialect::initialize() { } // namespace mlir #define GET_OP_CLASSES -#include "tensorflow/compiler/mlir/xla/ir/xla_framework.cc.inc" +#include "tensorflow/compiler/xla/mlir/framework/ir/xla_framework.cc.inc" diff --git a/tensorflow/compiler/mlir/xla/ir/xla_framework.h b/tensorflow/compiler/xla/mlir/framework/ir/xla_framework.h similarity index 72% rename from tensorflow/compiler/mlir/xla/ir/xla_framework.h rename to tensorflow/compiler/xla/mlir/framework/ir/xla_framework.h index 77bc4a264e4..1ab417c5a70 100644 --- a/tensorflow/compiler/mlir/xla/ir/xla_framework.h +++ b/tensorflow/compiler/xla/mlir/framework/ir/xla_framework.h @@ -15,8 +15,8 @@ limitations under the License. // This file defines the operations and types used in the XLAFramework dialect. // -#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_XLA_FRAMEWORK_H_ -#define TENSORFLOW_COMPILER_MLIR_XLA_IR_XLA_FRAMEWORK_H_ +#ifndef TENSORFLOW_COMPILER_XLA_MLIR_FRAMEWORK_IR_XLA_FRAMEWORK_H_ +#define TENSORFLOW_COMPILER_XLA_MLIR_FRAMEWORK_IR_XLA_FRAMEWORK_H_ #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project @@ -26,11 +26,11 @@ limitations under the License. 
#include "mlir/IR/OpImplementation.h" // from @llvm-project #define GET_TYPEDEF_CLASSES -#include "tensorflow/compiler/mlir/xla/ir/xla_framework_types.h.inc" +#include "tensorflow/compiler/xla/mlir/framework/ir/xla_framework_types.h.inc" #define GET_OP_CLASSES -#include "tensorflow/compiler/mlir/xla/ir/xla_framework.h.inc" -#include "tensorflow/compiler/mlir/xla/ir/xla_framework_dialect.h.inc" +#include "tensorflow/compiler/xla/mlir/framework/ir/xla_framework.h.inc" +#include "tensorflow/compiler/xla/mlir/framework/ir/xla_framework_dialect.h.inc" #undef GET_OP_CLASSES -#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_XLA_FRAMEWORK_H_ +#endif // TENSORFLOW_COMPILER_XLA_MLIR_FRAMEWORK_IR_XLA_FRAMEWORK_H_ diff --git a/tensorflow/compiler/mlir/xla/ir/xla_framework_ops.td b/tensorflow/compiler/xla/mlir/framework/ir/xla_framework_ops.td similarity index 98% rename from tensorflow/compiler/mlir/xla/ir/xla_framework_ops.td rename to tensorflow/compiler/xla/mlir/framework/ir/xla_framework_ops.td index d4a260f22e1..8b57eff2aa4 100644 --- a/tensorflow/compiler/mlir/xla/ir/xla_framework_ops.td +++ b/tensorflow/compiler/xla/mlir/framework/ir/xla_framework_ops.td @@ -37,8 +37,8 @@ def XLAFramework_Dialect : Dialect { static constexpr StringRef kXLAEntryAttrName = "xla_entry"; }]; - let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; let useDefaultTypePrinterParser = 1; + let useFoldAPI = kEmitFoldAdaptorFolder; } def XLAFramework_BufferType : TypeDef { diff --git a/tensorflow/compiler/xla/mlir/framework/tests/BUILD b/tensorflow/compiler/xla/mlir/framework/tests/BUILD new file mode 100644 index 00000000000..8fefb30ac1b --- /dev/null +++ b/tensorflow/compiler/xla/mlir/framework/tests/BUILD @@ -0,0 +1,28 @@ +load("//tensorflow/tsl:tsl.default.bzl", "filegroup") +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) + +glob_lit_tests( + data = [":test_utilities"], + driver = "@llvm-project//mlir:run_lit.sh", + test_file_exts = [ + "mlir", + "hlotxt", + ], +) + +# Bundle together all of the test utilities that are used by tests. +# This intentionally does not pull-in the top-level tf-opt to reduce the +# dependencies. 
+filegroup( + name = "test_utilities", + testonly = True, + data = [ + "//tensorflow/compiler/xla/translate/mhlo_to_lhlo_with_xla:xla-translate-opt", + "@llvm-project//llvm:FileCheck", + ], +) diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-xla-framework.mlir b/tensorflow/compiler/xla/mlir/framework/tests/legalize-xla-framework.mlir similarity index 98% rename from tensorflow/compiler/mlir/xla/tests/legalize-xla-framework.mlir rename to tensorflow/compiler/xla/mlir/framework/tests/legalize-xla-framework.mlir index 8ff92f186c2..3747f70feb7 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-xla-framework.mlir +++ b/tensorflow/compiler/xla/mlir/framework/tests/legalize-xla-framework.mlir @@ -1,4 +1,4 @@ -// RUN: xla-opt %s -xla-legalize-xla-framework-to-llvm | FileCheck %s +// RUN: xla-translate-opt %s -xla-legalize-xla-framework-to-llvm | FileCheck %s memref.global "private" constant @__constant_xf32 : memref = dense<42.0> diff --git a/tensorflow/compiler/mlir/xla/tests/outline-with-xla-framework.mlir b/tensorflow/compiler/xla/mlir/framework/tests/outline-with-xla-framework.mlir similarity index 86% rename from tensorflow/compiler/mlir/xla/tests/outline-with-xla-framework.mlir rename to tensorflow/compiler/xla/mlir/framework/tests/outline-with-xla-framework.mlir index 85aa1a9b4db..25dc64e8f61 100644 --- a/tensorflow/compiler/mlir/xla/tests/outline-with-xla-framework.mlir +++ b/tensorflow/compiler/xla/mlir/framework/tests/outline-with-xla-framework.mlir @@ -1,4 +1,4 @@ -// RUN: xla-opt %s -split-input-file -outline-with-xla-framework | FileCheck %s +// RUN: xla-translate-opt %s -split-input-file -outline-with-xla-framework | FileCheck %s // CHECK-LABEL: @func_to_outline_xla_framework // CHECK-SAME: %[[ARG0:.*]]: !xla_framework.buffer diff --git a/tensorflow/compiler/mlir/xla/tests/xla-framework.mlir b/tensorflow/compiler/xla/mlir/framework/tests/xla-framework.mlir similarity index 93% rename from tensorflow/compiler/mlir/xla/tests/xla-framework.mlir rename to tensorflow/compiler/xla/mlir/framework/tests/xla-framework.mlir index e081c578484..d8d6e9abf06 100644 --- a/tensorflow/compiler/mlir/xla/tests/xla-framework.mlir +++ b/tensorflow/compiler/xla/mlir/framework/tests/xla-framework.mlir @@ -1,4 +1,4 @@ -// RUN: xla-opt %s | xla-opt | FileCheck %s +// RUN: xla-translate-opt %s | FileCheck %s // CHECK-LABEL: @buffer_type func.func @buffer_type(%arg1: !xla_framework.buffer) -> !xla_framework.buffer diff --git a/tensorflow/compiler/xla/mlir/framework/transforms/BUILD b/tensorflow/compiler/xla/mlir/framework/transforms/BUILD new file mode 100644 index 00000000000..5a1fddaa0a9 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/framework/transforms/BUILD @@ -0,0 +1,57 @@ +load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//visibility:public"], + licenses = ["notice"], +) + +gentbl_cc_library( + name = "passes_inc_gen", + compatible_with = get_compatible_with_cloud(), + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=XlaFramework", + ], + "passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "passes.td", + deps = [ + "//tensorflow/compiler/xla/mlir_hlo:hlo_ops_td_files", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + 
"@llvm-project//mlir:TensorOpsTdFiles", + ], +) + +cc_library( + name = "passes", + srcs = [ + "outline_with_xla_framework.cc", + "xla_framework_to_llvm_pass.cc", + ], + hdrs = [ + "passes.h", + ], + deps = [ + ":passes_inc_gen", + "//tensorflow/compiler/xla/mlir/framework/ir:xla_framework", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncToLLVM", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + ], +) diff --git a/tensorflow/compiler/mlir/xla/transforms/outline_with_xla_framework.cc b/tensorflow/compiler/xla/mlir/framework/transforms/outline_with_xla_framework.cc similarity index 91% rename from tensorflow/compiler/mlir/xla/transforms/outline_with_xla_framework.cc rename to tensorflow/compiler/xla/mlir/framework/transforms/outline_with_xla_framework.cc index 1b7482c0300..76a5cff9cc3 100644 --- a/tensorflow/compiler/mlir/xla/transforms/outline_with_xla_framework.cc +++ b/tensorflow/compiler/xla/mlir/framework/transforms/outline_with_xla_framework.cc @@ -28,8 +28,8 @@ limitations under the License. #include "mlir/IR/TypeRange.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "tensorflow/compiler/mlir/xla/ir/xla_framework.h" -#include "tensorflow/compiler/mlir/xla/transforms/xla_passes.h" +#include "tensorflow/compiler/xla/mlir/framework/ir/xla_framework.h" +#include "tensorflow/compiler/xla/mlir/framework/transforms/passes.h" namespace mlir { namespace mhlo { @@ -54,14 +54,13 @@ struct OutlineXLAFunc : public RewritePattern { explicit OutlineXLAFunc(MLIRContext *context, PatternBenefit benefit = 1) : RewritePattern(func::FuncOp::getOperationName(), benefit, context) {} - static void filterFuncAttributes(ArrayRef attrs, - bool argAttrs, + static void filterFuncAttributes(func::FuncOp func, bool argAttrs, SmallVectorImpl &result) { - for (const auto &attr : attrs) { + for (const auto &attr : func->getAttrs()) { if (attr.getName() == SymbolTable::getSymbolAttrName() || - attr.getName() == FunctionOpInterface::getTypeAttrName() || + attr.getName() == func.getFunctionTypeAttrName() || attr.getName() == "std.varargs" || - (argAttrs && attr.getName() == func::FuncOp::getArgDictAttrName())) + (argAttrs && attr.getName() == func.getArgAttrsAttrName())) continue; result.push_back(attr); } @@ -91,7 +90,7 @@ struct OutlineXLAFunc : public RewritePattern { ::mlir::xla_framework::BufferType::get(ctx)); auto func_type = FunctionType::get(ctx, operands, result_array); SmallVector attrs; - filterFuncAttributes(func->getAttrs(), true, attrs); + filterFuncAttributes(func, true, attrs); SmallVector arg_attrs; func.getAllArgAttrs(arg_attrs); @@ -132,7 +131,7 @@ struct OutlineXLAFunc : public RewritePattern { }; #define GEN_PASS_DEF_OUTLINEWITHXLAFRAMEWORK -#include "tensorflow/compiler/mlir/xla/transforms/xla_passes.h.inc" +#include "tensorflow/compiler/xla/mlir/framework/transforms/passes.h.inc" class OutlineWithXLAFrameworkPass : public impl::OutlineWithXLAFrameworkBase { diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_passes.h b/tensorflow/compiler/xla/mlir/framework/transforms/passes.h similarity index 87% rename from tensorflow/compiler/mlir/xla/transforms/xla_passes.h rename to tensorflow/compiler/xla/mlir/framework/transforms/passes.h index 
35ebbd68300..eb175a0ab2e 100644 --- a/tensorflow/compiler/mlir/xla/transforms/xla_passes.h +++ b/tensorflow/compiler/xla/mlir/framework/transforms/passes.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_XLA_PASSES_H_ -#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_XLA_PASSES_H_ +#ifndef TENSORFLOW_COMPILER_XLA_MLIR_FRAMEWORK_TRANSFORMS_PASSES_H_ +#define TENSORFLOW_COMPILER_XLA_MLIR_FRAMEWORK_TRANSFORMS_PASSES_H_ #include @@ -51,9 +51,9 @@ void PopulateLegalizeXLAFrameworkToLLVMPatterns(llvm::StringRef device_type, #define GEN_PASS_DECL_LEGALIZEXLAFRAMEWORKTOLLVM #define GEN_PASS_DECL_OUTLINEWITHXLAFRAMEWORK #define GEN_PASS_DECL_PREPAREFOREXPORTPASS -#include "tensorflow/compiler/mlir/xla/transforms/xla_passes.h.inc" +#include "tensorflow/compiler/xla/mlir/framework/transforms/passes.h.inc" } // namespace mhlo } // namespace mlir -#endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_XLA_PASSES_H_ +#endif // TENSORFLOW_COMPILER_XLA_MLIR_FRAMEWORK_TRANSFORMS_PASSES_H_ diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_passes.td b/tensorflow/compiler/xla/mlir/framework/transforms/passes.td similarity index 100% rename from tensorflow/compiler/mlir/xla/transforms/xla_passes.td rename to tensorflow/compiler/xla/mlir/framework/transforms/passes.td diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_framework_to_llvm_pass.cc b/tensorflow/compiler/xla/mlir/framework/transforms/xla_framework_to_llvm_pass.cc similarity index 96% rename from tensorflow/compiler/mlir/xla/transforms/xla_framework_to_llvm_pass.cc rename to tensorflow/compiler/xla/mlir/framework/transforms/xla_framework_to_llvm_pass.cc index 9ca8864edd2..50ac0bdabb3 100644 --- a/tensorflow/compiler/mlir/xla/transforms/xla_framework_to_llvm_pass.cc +++ b/tensorflow/compiler/xla/mlir/framework/transforms/xla_framework_to_llvm_pass.cc @@ -26,7 +26,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/LLVMTypes.h" // from @llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project @@ -35,8 +35,8 @@ limitations under the License. 
#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "tensorflow/compiler/mlir/xla/ir/xla_framework.h" -#include "tensorflow/compiler/mlir/xla/transforms/xla_passes.h" +#include "tensorflow/compiler/xla/mlir/framework/ir/xla_framework.h" +#include "tensorflow/compiler/xla/mlir/framework/transforms/passes.h" namespace mlir { namespace mhlo { @@ -130,7 +130,7 @@ struct BarePtrFuncOpConversion : public ConvertOpToLLVMPattern { "xla_entry function lowered with result values when memrefs should " "be caller supplied"); - BlockAndValueMapping mapping; + IRMapping mapping; auto num_refs = funcOp.getFunctionType().getNumInputs(); auto result_index = 0; for (unsigned i = 0; i < num_refs; ++i) { @@ -179,7 +179,7 @@ struct BarePtrFuncOpConversion : public ConvertOpToLLVMPattern { rewriter.create( loc, ptr, rewriter.create(loc, ptr_type, first_load, - llvm::makeArrayRef(second_index))); + llvm::ArrayRef(second_index))); } else { // Non tuple outputs can be simply mapped to the first load op. @@ -217,7 +217,7 @@ struct BarePtrFuncOpConversion : public ConvertOpToLLVMPattern { }; #define GEN_PASS_DEF_LEGALIZEXLAFRAMEWORKTOLLVM -#include "tensorflow/compiler/mlir/xla/transforms/xla_passes.h.inc" +#include "tensorflow/compiler/xla/mlir/framework/transforms/passes.h.inc" class LegalizeXLAFrameworkToLLVMPass : public impl::LegalizeXLAFrameworkToLLVMBase< diff --git a/tensorflow/compiler/xla/mlir/math/BUILD b/tensorflow/compiler/xla/mlir/math/BUILD new file mode 100644 index 00000000000..39215827e06 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/math/BUILD @@ -0,0 +1,16 @@ +package_group( + name = "friends", + packages = [ + "//tensorflow/compiler/xla/mlir/...", + # copybara:uncomment_begin(google-only) + # # TODO(ezhulenev): Clean up dependencies that are leftovers from Autofusion project.
+ # "@tf_runtime//...", + # copybara:uncomment_end(google-only) + ], +) + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], + licenses = ["notice"], +) diff --git a/tensorflow/compiler/xla/mlir/transforms/math/BUILD b/tensorflow/compiler/xla/mlir/math/transforms/BUILD similarity index 66% rename from tensorflow/compiler/xla/mlir/transforms/math/BUILD rename to tensorflow/compiler/xla/mlir/math/transforms/BUILD index c26222ecc5a..23edd754501 100644 --- a/tensorflow/compiler/xla/mlir/transforms/math/BUILD +++ b/tensorflow/compiler/xla/mlir/math/transforms/BUILD @@ -1,12 +1,10 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") package( - default_visibility = [ - "//tensorflow:internal", - "@tf_runtime//:friends", - ], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//tensorflow/compiler/xla/mlir/math:friends"], licenses = ["notice"], ) @@ -29,15 +27,23 @@ gentbl_cc_library( cc_library( name = "passes", - srcs = ["math_optimization.cc"], + srcs = [ + "math_approximation.cc", + "math_legalization.cc", + "math_optimization.cc", + ], hdrs = ["passes.h"], compatible_with = get_compatible_with_cloud(), deps = [ ":passes_inc_gen", - "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:MathToLLVM", + "@llvm-project//mlir:MathToLibm", "@llvm-project//mlir:MathTransforms", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Transforms", diff --git a/tensorflow/compiler/xla/mlir/math/transforms/math_approximation.cc b/tensorflow/compiler/xla/mlir/math/transforms/math_approximation.cc new file mode 100644 index 00000000000..fffd0126049 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/math/transforms/math_approximation.cc @@ -0,0 +1,293 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/Math/Transforms/Passes.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "tensorflow/compiler/xla/mlir/math/transforms/passes.h" + +namespace xla { +namespace { + +#define GEN_PASS_DEF_MATHAPPROXIMATIONPASS +#include "tensorflow/compiler/xla/mlir/math/transforms/passes.h.inc" + +using ::llvm::ArrayRef; +using ::llvm::SmallVector; + +using ::mlir::ImplicitLocOpBuilder; +using ::mlir::LogicalResult; +using ::mlir::OperationPass; +using ::mlir::OpRewritePattern; +using ::mlir::PatternRewriter; +using ::mlir::RewritePatternSet; +using ::mlir::Type; +using ::mlir::Value; +using ::mlir::VectorType; + +namespace arith = ::mlir::arith; +namespace func = ::mlir::func; +namespace math = ::mlir::math; +namespace vector = ::mlir::vector; + +using TypePredicate = ::llvm::function_ref; + +// Returns vector shape if the element type is matching the predicate (scalars +// that do match the predicate have shape equal to `{1}`). +llvm::Optional> vectorShape(Type type, + TypePredicate pred) { + // If the type matches the predicate then its shape is `{1}`. + if (pred(type)) return SmallVector{1}; + + // Otherwise check if the type is a vector type. + auto vectorType = type.dyn_cast(); + if (vectorType && pred(vectorType.getElementType())) { + return llvm::to_vector<2>(vectorType.getShape()); + } + + return llvm::None; +} + +bool isF32(Type type) { return type.isF32(); } + +//----------------------------------------------------------------------------// +// Broadcast scalar types and values into vector types and values. +//----------------------------------------------------------------------------// + +// Returns true if shape != {1}. +bool isNonScalarShape(ArrayRef shape) { + return shape.size() > 1 || shape[0] > 1; +} + +// Broadcasts scalar type into vector type (iff shape is non-scalar). +Type broadcast(Type type, ArrayRef shape) { + assert(!type.isa() && "must be scalar type"); + return isNonScalarShape(shape) ? VectorType::get(shape, type) : type; +} + +// Broadcasts scalar value into vector (iff shape is non-scalar). +Value broadcast(ImplicitLocOpBuilder &builder, Value value, + ArrayRef shape) { + assert(!value.getType().isa() && "must be scalar value"); + auto type = broadcast(value.getType(), shape); + return isNonScalarShape(shape) + ? builder.create(type, value) + : value; +} + +//----------------------------------------------------------------------------// +// Helper functions to create constants. 
+//----------------------------------------------------------------------------// + +Value f32Cst(ImplicitLocOpBuilder &builder, float value) { + return builder.create(builder.getF32FloatAttr(value)); +} + +Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value) { + return builder.create(builder.getI32IntegerAttr(value)); +} + +Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits) { + Value i32v = i32Cst(builder, static_cast(bits)); + return builder.create(builder.getF32Type(), i32v); +} + +struct EigenExpM1Approximation : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(math::ExpM1Op op, + PatternRewriter &rewriter) const final; +}; + +LogicalResult EigenExpM1Approximation::matchAndRewrite( + math::ExpM1Op op, PatternRewriter &rewriter) const { + auto shape = vectorShape(op.getOperand().getType(), isF32); + if (!shape.has_value()) + return rewriter.notifyMatchFailure(op, "unsupported operand type"); + + ImplicitLocOpBuilder builder(op->getLoc(), rewriter); + auto bcast = [&](Value value) -> Value { + return broadcast(builder, value, *shape); + }; + + // expm1(x) = exp(x) - 1 = u - 1. + // We have to handle it carefully when x is near 0, i.e. u ~= 1, + // and when the input is ~= -inf, i.e. u - 1 ~= -1. + Value cstOne = bcast(f32Cst(builder, 1.0f)); + Value cstNegOne = bcast(f32Cst(builder, -1.0f)); + Value x = op.getOperand(); + Value u = builder.create(x); + Value uEqOneOrNaN = + builder.create(arith::CmpFPredicate::UEQ, u, cstOne); + Value uMinusOne = builder.create(u, cstOne); + Value uMinusOneEqNegOne = builder.create( + arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne); + // logU = log(u) ~= x + Value logU = builder.create(u); + + // Detect exp(x) = +inf; written this way to avoid having to form +inf. + Value isInf = + builder.create(arith::CmpFPredicate::OEQ, logU, u); + + // (u - 1) * (x / ~x) + Value expm1 = builder.create( + uMinusOne, builder.create(x, logU)); + expm1 = builder.create(isInf, u, expm1); + Value approximation = builder.create( + uEqOneOrNaN, x, + builder.create(uMinusOneEqNegOne, cstNegOne, expm1)); + rewriter.replaceOp(op, approximation); + + return mlir::success(); +} + +struct LogApproximation : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(math::LogOp op, + PatternRewriter &rewriter) const final; +}; + +LogicalResult LogApproximation::matchAndRewrite( + math::LogOp op, PatternRewriter &rewriter) const { + auto shape = vectorShape(op.getOperand().getType(), isF32); + if (!shape.has_value()) { + return rewriter.notifyMatchFailure(op, "unsupported operand type"); + } + + ImplicitLocOpBuilder builder(op->getLoc(), rewriter); + auto bcast = [&](Value value) -> Value { + return broadcast(builder, value, *shape); + }; + + Value cst_min_norm_pos = bcast(f32FromBits(builder, 0x00800000u)); + Value cst_zero = bcast(f32Cst(builder, 0.0f)); + + Value x = op.getOperand(); + + // Flush positive denormals to zero. + Value less_than_zero = + builder.create(arith::CmpFPredicate::OLT, x, cst_zero); + Value less_than_min_norm_pos = builder.create( + arith::CmpFPredicate::OLT, x, cst_min_norm_pos); + x = builder.create( + less_than_min_norm_pos, + builder.create(less_than_zero, x, cst_zero), x); + + // Emit Log2Op instead of LogOp to avoid an infinite match-and-rewrite loop. 
+ Value log2 = builder.create(x); + Value cst = bcast(f32Cst(builder, 6.93147181e-1f)); + Value res = builder.create(cst, log2); + rewriter.replaceOp(op, res); + return mlir::success(); +} + +struct Log1pApproximation : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(math::Log1pOp op, + PatternRewriter &rewriter) const final; +}; + +// Approximate log(1+x). +LogicalResult Log1pApproximation::matchAndRewrite( + math::Log1pOp op, PatternRewriter &rewriter) const { + auto shape = vectorShape(op.getOperand().getType(), isF32); + if (!shape.has_value()) { + return rewriter.notifyMatchFailure(op, "unsupported operand type"); + } + + ImplicitLocOpBuilder builder(op->getLoc(), rewriter); + auto bcast = [&](Value value) -> Value { + return broadcast(builder, value, *shape); + }; + + // Approximate log(1+x) using the following, due to W. Kahan: + // u = x + 1.0; + // if (u == 1.0 || u == inf) return x; + // return x * log(u) / (u - 1.0); + // ^^^^^^^^^^^^^^^^^^^^^^ + // "log_large" below. + Value cst_one = bcast(f32Cst(builder, 1.0f)); + Value x = op.getOperand(); + Value u = builder.create(x, cst_one); + Value u_small = + builder.create(arith::CmpFPredicate::OEQ, u, cst_one); + Value log_u = builder.create(u); + Value u_inf = + builder.create(arith::CmpFPredicate::OEQ, u, log_u); + Value log_large = builder.create( + x, builder.create( + log_u, builder.create(u, cst_one))); + Value approximation = builder.create( + builder.create(u_small, u_inf), x, log_large); + rewriter.replaceOp(op, approximation); + return mlir::success(); +} + +void populateMathApproximationPatterns(RewritePatternSet &patterns, + ArrayRef oplist) { + for (const std::string &op : oplist) { + if (op == "all") { + patterns + .add( + patterns.getContext()); + } else if (op == "expm1") { + patterns.add(patterns.getContext()); + } else if (op == "log") { + patterns.add(patterns.getContext()); + } else if (op == "log1p") { + patterns.add(patterns.getContext()); + } + } +} + +struct MathApproximationPass + : public impl::MathApproximationPassBase { + explicit MathApproximationPass(ArrayRef approx_oplist) { + this->oplist = approx_oplist; + } + + void runOnOperation() override; +}; + +void MathApproximationPass::runOnOperation() { + RewritePatternSet patterns(&getContext()); + populateMathApproximationPatterns(patterns, oplist); + if (failed(mlir::applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) + signalPassFailure(); +} + +} // namespace + +std::unique_ptr> CreateMathApproximationPass( + ArrayRef oplist) { + return std::make_unique(oplist); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/mlir/math/transforms/math_legalization.cc b/tensorflow/compiler/xla/mlir/math/transforms/math_legalization.cc new file mode 100644 index 00000000000..72e3d7b04db --- /dev/null +++ b/tensorflow/compiler/xla/mlir/math/transforms/math_legalization.cc @@ -0,0 +1,75 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "mlir/Dialect/Math/Transforms/Passes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" // from @llvm-project +#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" // from @llvm-project +#include "mlir/Conversion/MathToLibm/MathToLibm.h" // from @llvm-project +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/Dialect/Math/IR/Math.h" // from @llvm-project +#include "mlir/Dialect/Vector/IR/VectorOps.h" // from @llvm-project +#include "mlir/Dialect/X86Vector/X86VectorDialect.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/xla/mlir/math/transforms/passes.h" + +namespace xla { + +using namespace mlir; // NOLINT + +#define GEN_PASS_DEF_MATHLEGALIZATIONPASS +#include "tensorflow/compiler/xla/mlir/math/transforms/passes.h.inc" + +struct MathLegalizationPass + : public impl::MathLegalizationPassBase { + explicit MathLegalizationPass(bool enable_approximations) { + enable_approximations_ = enable_approximations; + } + void runOnOperation() override; +}; + +void MathLegalizationPass::runOnOperation() { + RewritePatternSet patterns(&getContext()); + LLVMTypeConverter converter(&getContext()); + + populateMathToLLVMConversionPatterns(converter, patterns); + int32_t libm_log1p_benefit = enable_approximations_ ? 0 : 2; + // MathToLibm patterns are a last resort, so they have a 0 benefit (except + // for log1p if approximations are disabled, because it has accuracy issues + // near 0 if implemented naively). + populateMathToLibmConversionPatterns(patterns, 0, {libm_log1p_benefit}); + + ConversionTarget target(getContext()); + target.addIllegalDialect(); + target.addLegalDialect(); + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) { + signalPassFailure(); + } +} + +std::unique_ptr> CreateMathLegalizationPass( + bool enable_approximations) { + return std::make_unique(enable_approximations); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/mlir/transforms/math/math_optimization.cc b/tensorflow/compiler/xla/mlir/math/transforms/math_optimization.cc similarity index 91% rename from tensorflow/compiler/xla/mlir/transforms/math/math_optimization.cc rename to tensorflow/compiler/xla/mlir/math/transforms/math_optimization.cc index f91c66be333..c2caee87646 100644 --- a/tensorflow/compiler/xla/mlir/transforms/math/math_optimization.cc +++ b/tensorflow/compiler/xla/mlir/math/transforms/math_optimization.cc @@ -21,15 +21,14 @@ limitations under the License. 
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" // from @llvm-project #include "mlir/Dialect/X86Vector/X86VectorDialect.h" // from @llvm-project -#include "tensorflow/compiler/xla/mlir/transforms/math/passes.h" +#include "tensorflow/compiler/xla/mlir/math/transforms/passes.h" namespace xla { -namespace runtime { using namespace mlir; // NOLINT #define GEN_PASS_DEF_MATHOPTIMIZATIONPASS -#include "tensorflow/compiler/xla/mlir/transforms/math/passes.h.inc" +#include "tensorflow/compiler/xla/mlir/math/transforms/passes.h.inc" struct MathOptimizationPass : public impl::MathOptimizationPassBase { @@ -56,5 +55,4 @@ std::unique_ptr> CreateMathOptimizationPass( return std::make_unique(enable_avx2); } -} // namespace runtime } // namespace xla diff --git a/tensorflow/compiler/xla/mlir/transforms/math/passes.h b/tensorflow/compiler/xla/mlir/math/transforms/passes.h similarity index 58% rename from tensorflow/compiler/xla/mlir/transforms/math/passes.h rename to tensorflow/compiler/xla/mlir/math/transforms/passes.h index fcf911f504c..38c6f8236bd 100644 --- a/tensorflow/compiler/xla/mlir/transforms/math/passes.h +++ b/tensorflow/compiler/xla/mlir/math/transforms/passes.h @@ -13,27 +13,35 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_MLIR_TRANSFORMS_MATH_PASSES_H_ -#define TENSORFLOW_COMPILER_XLA_MLIR_TRANSFORMS_MATH_PASSES_H_ +#ifndef TENSORFLOW_COMPILER_XLA_MLIR_MATH_TRANSFORMS_PASSES_H_ +#define TENSORFLOW_COMPILER_XLA_MLIR_MATH_TRANSFORMS_PASSES_H_ #include +#include #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project namespace xla { -namespace runtime { +#define GEN_PASS_DECL_MATHAPPROXIMATIONPASS #define GEN_PASS_DECL_MATHOPTIMIZATIONPASS -#include "tensorflow/compiler/xla/mlir/transforms/math/passes.h.inc" +#define GEN_PASS_DECL_MATHLEGALIZATIONPASS +#include "tensorflow/compiler/xla/mlir/math/transforms/passes.h.inc" std::unique_ptr> CreateMathOptimizationPass(bool enable_avx2 = false); +std::unique_ptr> +CreateMathApproximationPass(llvm::ArrayRef oplist = {}); + +std::unique_ptr> CreateMathLegalizationPass( + bool enable_approximations = true); + #define GEN_PASS_REGISTRATION -#include "tensorflow/compiler/xla/mlir/transforms/math/passes.h.inc" +#include "tensorflow/compiler/xla/mlir/math/transforms/passes.h.inc" -} // namespace runtime } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_MLIR_TRANSFORMS_MATH_PASSES_H_ +#endif // TENSORFLOW_COMPILER_XLA_MLIR_MATH_TRANSFORMS_PASSES_H_ diff --git a/tensorflow/compiler/xla/mlir/transforms/math/passes.td b/tensorflow/compiler/xla/mlir/math/transforms/passes.td similarity index 55% rename from tensorflow/compiler/xla/mlir/transforms/math/passes.td rename to tensorflow/compiler/xla/mlir/math/transforms/passes.td index 66e154bd0fa..2de3f189b68 100644 --- a/tensorflow/compiler/xla/mlir/transforms/math/passes.td +++ b/tensorflow/compiler/xla/mlir/math/transforms/passes.td @@ -32,7 +32,7 @@ def MathOptimizationPass "mlir::x86vector::X86VectorDialect" ]; - let constructor = "::xla::runtime::CreateMathOptimizationPass()"; + let constructor = "::xla::CreateMathOptimizationPass()"; let options = [ Option<"enable_avx2_", "enable-avx2", "bool", "false", @@ -40,4 +40,37 @@ def MathOptimizationPass ]; } +def 
MathApproximationPass + : Pass<"xla-math-approximation", "mlir::func::FuncOp"> { + let summary = "Approximate math operations for accuracy and speed."; + let constructor = "::xla::CreateMathApproximationPass()"; + let options = [ + ListOption<"oplist", "oplist", "std::string", + "List of math operations to be approximated. Use 'all' to select " + "all supported math operations.">, + ]; +} + +def MathLegalizationPass : Pass<"xla-math-legalization", "mlir::ModuleOp"> { + let summary = "Legalize operations from the `math` dialect."; + + let description = [{ + This pass lowers ops from the Math dialect to LLVM. + }]; + + let dependentDialects = [ + "mlir::arith::ArithDialect", + "mlir::func::FuncDialect", + "mlir::vector::VectorDialect", + "LLVM::LLVMDialect", + ]; + + let constructor = "::xla::CreateMathLegalizationPass()"; + + let options = [ + Option<"enable_approximations_", "enable-approximations", "bool", "true", + "Enable math approximations."> + ]; +} + #endif // XLA_MATH_PASSES diff --git a/tensorflow/compiler/xla/mlir/transforms/math/tests/BUILD b/tensorflow/compiler/xla/mlir/math/transforms/tests/BUILD similarity index 66% rename from tensorflow/compiler/xla/mlir/transforms/math/tests/BUILD rename to tensorflow/compiler/xla/mlir/math/transforms/tests/BUILD index 7f511fa1c49..153584bd1a0 100644 --- a/tensorflow/compiler/xla/mlir/transforms/math/tests/BUILD +++ b/tensorflow/compiler/xla/mlir/math/transforms/tests/BUILD @@ -1,11 +1,14 @@ -load("//tensorflow:tensorflow.default.bzl", "filegroup") +load("//tensorflow/tsl:tsl.default.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") -package(licenses = ["notice"]) +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) glob_lit_tests( data = [":test_utilities"], - driver = "//tensorflow/compiler/mlir:run_lit.sh", + driver = "//tensorflow/compiler/xla:run_lit.sh", test_file_exts = ["mlir"], ) diff --git a/tensorflow/compiler/xla/mlir/math/transforms/tests/math_legalization.mlir b/tensorflow/compiler/xla/mlir/math/transforms/tests/math_legalization.mlir new file mode 100644 index 00000000000..f60970f3c09 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/math/transforms/tests/math_legalization.mlir @@ -0,0 +1,19 @@ +// RUN: xla-runtime-opt %s --xla-math-legalization \ +// RUN: | FileCheck %s + +// RUN: xla-runtime-opt %s --xla-math-legalization=enable-approximations=0 \ +// RUN: | FileCheck --check-prefix=NO-APPROX %s + +// CHECK-LABEL: func @log1p( +// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1.0 +// CHECK: %[[P1:.*]] = llvm.fadd %[[C1]] +// CHECK: %[[RET:.*]] = llvm.intr.log(%[[P1]]) +// CHECK: return %[[RET]] + +// NO-APPROX-LABEL: func @log1p( +// NO-APPROX: call @log1pf + +func.func @log1p(%arg0: f32) -> f32 { + %0 = math.log1p %arg0 : f32 + func.return %0 : f32 +} \ No newline at end of file diff --git a/tensorflow/compiler/xla/mlir/transforms/math/tests/math_optimization.mlir b/tensorflow/compiler/xla/mlir/math/transforms/tests/math_optimization.mlir similarity index 100% rename from tensorflow/compiler/xla/mlir/transforms/math/tests/math_optimization.mlir rename to tensorflow/compiler/xla/mlir/math/transforms/tests/math_optimization.mlir diff --git a/tensorflow/compiler/xla/mlir/memref/BUILD b/tensorflow/compiler/xla/mlir/memref/BUILD new file mode 100644 index 00000000000..39215827e06 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/memref/BUILD @@ -0,0 +1,16 @@ +package_group( + name = "friends", + packages = [ + 
"//tensorflow/compiler/xla/mlir/...", + # copybara:uncomment_begin(google-only) + # # TODO(ezhulenev): Clean up dependencies that are leforvers from Autofusion project. + # "@tf_runtime//...", + # copybara:uncomment_end(google-only) + ], +) + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], + licenses = ["notice"], +) diff --git a/tensorflow/compiler/xla/mlir/transforms/memref/BUILD b/tensorflow/compiler/xla/mlir/memref/transforms/BUILD similarity index 81% rename from tensorflow/compiler/xla/mlir/transforms/memref/BUILD rename to tensorflow/compiler/xla/mlir/memref/transforms/BUILD index 0ac3f5886d3..cd77dc34eea 100644 --- a/tensorflow/compiler/xla/mlir/transforms/memref/BUILD +++ b/tensorflow/compiler/xla/mlir/memref/transforms/BUILD @@ -1,12 +1,10 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") package( - default_visibility = [ - "//tensorflow:internal", - "@tf_runtime//:friends", - ], + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//tensorflow/compiler/xla/mlir/memref:friends"], licenses = ["notice"], ) diff --git a/tensorflow/compiler/xla/mlir/transforms/memref/aligned_allocations.cc b/tensorflow/compiler/xla/mlir/memref/transforms/aligned_allocations.cc similarity index 92% rename from tensorflow/compiler/xla/mlir/transforms/memref/aligned_allocations.cc rename to tensorflow/compiler/xla/mlir/memref/transforms/aligned_allocations.cc index d3cd9e962a4..97d5c5b561a 100644 --- a/tensorflow/compiler/xla/mlir/transforms/memref/aligned_allocations.cc +++ b/tensorflow/compiler/xla/mlir/memref/transforms/aligned_allocations.cc @@ -20,15 +20,14 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project -#include "tensorflow/compiler/xla/mlir/transforms/memref/passes.h" +#include "tensorflow/compiler/xla/mlir/memref/transforms/passes.h" namespace xla { -namespace runtime { using namespace mlir; // NOLINT #define GEN_PASS_DEF_ALIGNEDALLOCATIONSPASS -#include "tensorflow/compiler/xla/mlir/transforms/memref/passes.h.inc" +#include "tensorflow/compiler/xla/mlir/memref/transforms/passes.h.inc" struct AlignedAllocationsPass : public impl::AlignedAllocationsPassBase { @@ -56,5 +55,4 @@ std::unique_ptr> CreateAlignedAllocationsPass( return std::make_unique(alignment); } -} // namespace runtime } // namespace xla diff --git a/tensorflow/compiler/xla/mlir/transforms/memref/passes.h b/tensorflow/compiler/xla/mlir/memref/transforms/passes.h similarity index 74% rename from tensorflow/compiler/xla/mlir/transforms/memref/passes.h rename to tensorflow/compiler/xla/mlir/memref/transforms/passes.h index 6db63fdf378..57f77bca519 100644 --- a/tensorflow/compiler/xla/mlir/transforms/memref/passes.h +++ b/tensorflow/compiler/xla/mlir/memref/transforms/passes.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_MLIR_TRANSFORMS_MEMREF_PASSES_H_ -#define TENSORFLOW_COMPILER_XLA_MLIR_TRANSFORMS_MEMREF_PASSES_H_ +#ifndef TENSORFLOW_COMPILER_XLA_MLIR_MEMREF_TRANSFORMS_PASSES_H_ +#define TENSORFLOW_COMPILER_XLA_MLIR_MEMREF_TRANSFORMS_PASSES_H_ #include @@ -22,18 +22,16 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project namespace xla { -namespace runtime { #define GEN_PASS_DECL_ALIGNEDALLOCATIONSPASS -#include "tensorflow/compiler/xla/mlir/transforms/memref/passes.h.inc" +#include "tensorflow/compiler/xla/mlir/memref/transforms/passes.h.inc" std::unique_ptr> CreateAlignedAllocationsPass(int64_t alignment = 64); #define GEN_PASS_REGISTRATION -#include "tensorflow/compiler/xla/mlir/transforms/memref/passes.h.inc" +#include "tensorflow/compiler/xla/mlir/memref/transforms/passes.h.inc" -} // namespace runtime } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_MLIR_TRANSFORMS_MEMREF_PASSES_H_ +#endif // TENSORFLOW_COMPILER_XLA_MLIR_MEMREF_TRANSFORMS_PASSES_H_ diff --git a/tensorflow/compiler/xla/mlir/transforms/memref/passes.td b/tensorflow/compiler/xla/mlir/memref/transforms/passes.td similarity index 94% rename from tensorflow/compiler/xla/mlir/transforms/memref/passes.td rename to tensorflow/compiler/xla/mlir/memref/transforms/passes.td index f50bf3a4450..cf657aafa49 100644 --- a/tensorflow/compiler/xla/mlir/transforms/memref/passes.td +++ b/tensorflow/compiler/xla/mlir/memref/transforms/passes.td @@ -28,7 +28,7 @@ def AlignedAllocationsPass configured for this pass. }]; - let constructor = "::xla::runtime::CreateAlignedAllocationsPass()"; + let constructor = "::xla::CreateAlignedAllocationsPass()"; let options = [ Option<"alignment_", "alignment", "int64_t", "64", diff --git a/tensorflow/compiler/xla/mlir/transforms/memref/tests/BUILD b/tensorflow/compiler/xla/mlir/memref/transforms/tests/BUILD similarity index 66% rename from tensorflow/compiler/xla/mlir/transforms/memref/tests/BUILD rename to tensorflow/compiler/xla/mlir/memref/transforms/tests/BUILD index 7f511fa1c49..153584bd1a0 100644 --- a/tensorflow/compiler/xla/mlir/transforms/memref/tests/BUILD +++ b/tensorflow/compiler/xla/mlir/memref/transforms/tests/BUILD @@ -1,11 +1,14 @@ -load("//tensorflow:tensorflow.default.bzl", "filegroup") +load("//tensorflow/tsl:tsl.default.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") -package(licenses = ["notice"]) +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) glob_lit_tests( data = [":test_utilities"], - driver = "//tensorflow/compiler/mlir:run_lit.sh", + driver = "//tensorflow/compiler/xla:run_lit.sh", test_file_exts = ["mlir"], ) diff --git a/tensorflow/compiler/xla/mlir/transforms/memref/tests/aligned_allocations.mlir b/tensorflow/compiler/xla/mlir/memref/transforms/tests/aligned_allocations.mlir similarity index 100% rename from tensorflow/compiler/xla/mlir/transforms/memref/tests/aligned_allocations.mlir rename to tensorflow/compiler/xla/mlir/memref/transforms/tests/aligned_allocations.mlir diff --git a/tensorflow/compiler/xla/mlir/runtime/BUILD b/tensorflow/compiler/xla/mlir/runtime/BUILD index 76340c2afa9..ea462a3865f 100644 --- a/tensorflow/compiler/xla/mlir/runtime/BUILD +++ b/tensorflow/compiler/xla/mlir/runtime/BUILD @@ -1,5 +1,5 @@ -load("//tensorflow:tensorflow.bzl", "tf_cc_binary") -load("//tensorflow:tensorflow.default.bzl", 
"get_compatible_with_cloud") +load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_binary") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") package_group( name = "friends", @@ -21,21 +21,23 @@ package_group( ) package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [":friends"], licenses = ["notice"], ) -tf_cc_binary( +xla_cc_binary( name = "xla-runtime-opt", srcs = ["xla-runtime-opt.cc"], compatible_with = get_compatible_with_cloud(), deps = [ + "//tensorflow/compiler/xla/mlir/math/transforms:passes", + "//tensorflow/compiler/xla/mlir/memref/transforms:passes", "//tensorflow/compiler/xla/mlir/runtime/ir/tests:testlib", "//tensorflow/compiler/xla/mlir/runtime/transforms:compilation_pipeline_cpu", "//tensorflow/compiler/xla/mlir/runtime/transforms:compilation_pipeline_gpu", "//tensorflow/compiler/xla/mlir/runtime/transforms:passes", - "//tensorflow/compiler/xla/mlir/transforms/math:passes", - "//tensorflow/compiler/xla/mlir/transforms/memref:passes", + "@llvm-project//mlir:AsyncDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:MemRefDialect", diff --git a/tensorflow/compiler/xla/mlir/runtime/ir/BUILD b/tensorflow/compiler/xla/mlir/runtime/ir/BUILD index cb91a60c683..2f7875bb911 100644 --- a/tensorflow/compiler/xla/mlir/runtime/ir/BUILD +++ b/tensorflow/compiler/xla/mlir/runtime/ir/BUILD @@ -1,8 +1,9 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//tensorflow/compiler/xla/mlir/runtime:friends"], licenses = ["notice"], ) @@ -52,14 +53,6 @@ gentbl_cc_library( ["-gen-typedef-defs"], "rt_types.cc.inc", ), - ( - ["-gen-attr-interface-decls"], - "rt_attr_interfaces.h.inc", - ), - ( - ["-gen-attr-interface-defs"], - "rt_attr_interfaces.cc.inc", - ), ( ["-gen-attrdef-decls"], "rt_attrs.h.inc", @@ -74,6 +67,24 @@ gentbl_cc_library( deps = [":rt_ops_td_files"], ) +gentbl_cc_library( + name = "rt_interfaces_inc_gen", + compatible_with = get_compatible_with_cloud(), + tbl_outs = [ + ( + ["-gen-attr-interface-decls"], + "rt_attr_interfaces.h.inc", + ), + ( + ["-gen-attr-interface-defs"], + "rt_attr_interfaces.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "rt_interfaces.td", + deps = [":rt_ops_td_files"], +) + cc_library( name = "rt", srcs = [ @@ -89,6 +100,7 @@ cc_library( compatible_with = get_compatible_with_cloud(), deps = [ ":rt_inc_gen", + ":rt_interfaces_inc_gen", "//tensorflow/compiler/xla/runtime:constraints", "@llvm-project//llvm:Support", "@llvm-project//mlir:ControlFlowInterfaces", diff --git a/tensorflow/compiler/xla/mlir/runtime/ir/rt_dialect.cc b/tensorflow/compiler/xla/mlir/runtime/ir/rt_dialect.cc index e798c601550..2a88d311da2 100644 --- a/tensorflow/compiler/xla/mlir/runtime/ir/rt_dialect.cc +++ b/tensorflow/compiler/xla/mlir/runtime/ir/rt_dialect.cc @@ -69,7 +69,7 @@ mlir::LogicalResult RuntimeDialect::verifyOperationAttribute( << " to be an integer attribute"; } - auto func = llvm::dyn_cast(op); + auto func = llvm::dyn_cast(op); if (!func) { return op->emitError() << attribute.getName() << " can only be applied to a function"; diff --git a/tensorflow/compiler/xla/mlir/runtime/ir/rt_dialect.td 
b/tensorflow/compiler/xla/mlir/runtime/ir/rt_dialect.td index c70d53f7d11..dc505f9e5d1 100644 --- a/tensorflow/compiler/xla/mlir/runtime/ir/rt_dialect.td +++ b/tensorflow/compiler/xla/mlir/runtime/ir/rt_dialect.td @@ -38,13 +38,12 @@ def RuntimeDialect : Dialect { let cppNamespace = "::xla::runtime"; - let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; - let useDefaultTypePrinterParser = 1; let useDefaultAttributePrinterParser = 1; let hasOperationAttrVerify = 1; + let useFoldAPI = kEmitFoldAdaptorFolder; } def RT_Ordinal : SignlessIntegerAttrBase; @@ -90,12 +89,10 @@ def HloTraceAttr : AttrDef:$hlo_op, - StringRefParameter<"module">:$module, - "int64_t":$program_id + StringRefParameter<"hlo_op">:$hlo_op ); - let assemblyFormat = "`<` $hlo_op `,` $module `,` $program_id `>`"; + let assemblyFormat = "`<` $hlo_op `>`"; } #endif // RT_DIALECT diff --git a/tensorflow/compiler/xla/mlir/runtime/ir/rt_ops.cc b/tensorflow/compiler/xla/mlir/runtime/ir/rt_ops.cc index bfdee7d699f..db7502cb933 100644 --- a/tensorflow/compiler/xla/mlir/runtime/ir/rt_ops.cc +++ b/tensorflow/compiler/xla/mlir/runtime/ir/rt_ops.cc @@ -36,24 +36,24 @@ using llvm::Optional; //===----------------------------------------------------------------------===// void ExportOp::build(OpBuilder &builder, OperationState &result, - func::FuncOp function_ref) { + FunctionOpInterface function_ref) { result.addAttribute("function_ref", SymbolRefAttr::get(function_ref)); } void ExportOp::build(OpBuilder &builder, OperationState &result, - func::FuncOp function_ref, unsigned ordinal) { + FunctionOpInterface function_ref, unsigned ordinal) { build(builder, result, function_ref); result.addAttribute("ordinal", builder.getI32IntegerAttr(ordinal)); } LogicalResult ExportOp::verifySymbolUses(SymbolTableCollection &symbolTable) { Operation *op = getOperation(); - auto func = symbolTable.lookupNearestSymbolFrom( + auto func = symbolTable.lookupNearestSymbolFrom( op, getFunctionRefAttr()); // Function reference must reference a valid FuncOp operation. if (!func) { - return op->emitError() << "func.func op named '" << getFunctionRef() + return op->emitError() << "func op named '" << getFunctionRef() << "' not found for export"; } @@ -65,16 +65,16 @@ Optional ExportOp::ordinal() { return llvm::None; } -mlir::func::FuncOp ExportOp::exported(mlir::SymbolTable &sym_table) { - return sym_table.lookupNearestSymbolFrom(getOperation(), - getFunctionRefAttr()); +FunctionOpInterface ExportOp::exported(mlir::SymbolTable &sym_table) { + return sym_table.lookupNearestSymbolFrom( + getOperation(), getFunctionRefAttr()); } //===----------------------------------------------------------------------===// // TraceOp //===----------------------------------------------------------------------===// -void TraceOp::getSuccessorRegions(Optional index, +void TraceOp::getSuccessorRegions(std::optional index, ArrayRef operands, SmallVectorImpl ®ions) { // If the predecessor is the TraceOp, branch into the body. 
@@ -122,8 +122,8 @@ void TraceOp::build(OpBuilder &builder, OperationState &result, //===----------------------------------------------------------------------===// MutableOperandRange YieldOp::getMutableSuccessorOperands( - Optional index) { - return operandsMutable(); + std::optional index) { + return getArgumentsMutable(); } } // namespace runtime diff --git a/tensorflow/compiler/xla/mlir/runtime/ir/rt_ops.td b/tensorflow/compiler/xla/mlir/runtime/ir/rt_ops.td index 0f7e147bf76..a7e21050182 100644 --- a/tensorflow/compiler/xla/mlir/runtime/ir/rt_ops.td +++ b/tensorflow/compiler/xla/mlir/runtime/ir/rt_ops.td @@ -67,13 +67,14 @@ def RT_ExportOp : RT_Op<"export", [ let skipDefaultBuilders = 1; let builders = [ - OpBuilder<(ins "mlir::func::FuncOp":$function_ref)>, - OpBuilder<(ins "mlir::func::FuncOp":$function_ref, "unsigned":$ordinal)>, + OpBuilder<(ins "mlir::FunctionOpInterface":$function_ref)>, + OpBuilder<(ins "mlir::FunctionOpInterface":$function_ref, + "unsigned":$ordinal)>, ]; let extraClassDeclaration = [{ llvm::Optional ordinal(); - mlir::func::FuncOp exported(mlir::SymbolTable& sym_table); + mlir::FunctionOpInterface exported(mlir::SymbolTable& sym_table); }]; let assemblyFormat = "$function_ref (`ordinal` $ordinal^)? attr-dict"; @@ -243,7 +244,7 @@ def RT_CallOp : RT_Op<"call"> { ExecutionContextType:$ctx, StrAttr:$callee, UnitAttr:$dynamic, - Variadic:$operands + Variadic:$arguments ); let results = (outs @@ -252,8 +253,8 @@ def RT_CallOp : RT_Op<"call"> { ); let assemblyFormat = [{ - (`dynamic` $dynamic^)? $ctx `[` $callee `]` `(` $operands `)` - attr-dict `:` functional-type($operands, $results) + (`dynamic` $dynamic^)? $ctx `[` $callee `]` `(` $arguments `)` + attr-dict `:` functional-type($arguments, $results) }]; } @@ -327,9 +328,9 @@ def RT_YieldOp : RT_Op<"yield", the users of `rt.trace` operation results. }]; - let arguments = (ins Variadic:$operands); + let arguments = (ins Variadic:$arguments); - let assemblyFormat = "($operands^ `:` type($operands))? attr-dict"; + let assemblyFormat = "($arguments^ `:` type($arguments))? 
attr-dict"; // Default builder needed for ensureTerminator let builders = [OpBuilder<(ins), "build($_builder, $_state, {});">]; diff --git a/tensorflow/compiler/xla/mlir/runtime/ir/tests/BUILD b/tensorflow/compiler/xla/mlir/runtime/ir/tests/BUILD index 9b29f4d0e76..b5d3ee388f1 100644 --- a/tensorflow/compiler/xla/mlir/runtime/ir/tests/BUILD +++ b/tensorflow/compiler/xla/mlir/runtime/ir/tests/BUILD @@ -1,10 +1,12 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("//tensorflow:tensorflow.default.bzl", "filegroup", "get_compatible_with_portable") +load("//tensorflow/tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") + +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) glob_lit_tests( data = [":test_utilities"], - driver = "//tensorflow/compiler/mlir:run_lit.sh", + driver = "//tensorflow/compiler/xla:run_lit.sh", test_file_exts = ["mlir"], ) diff --git a/tensorflow/compiler/xla/mlir/runtime/ir/tests/ops.mlir b/tensorflow/compiler/xla/mlir/runtime/ir/tests/ops.mlir index 25ba552ac44..36387c24a5e 100644 --- a/tensorflow/compiler/xla/mlir/runtime/ir/tests/ops.mlir +++ b/tensorflow/compiler/xla/mlir/runtime/ir/tests/ops.mlir @@ -69,13 +69,13 @@ func.func @opaque_arg(%ctx: !rt.execution_context, // CHECK: ) -> memref func.func @trace(%ctx: !rt.execution_context, %arg: memref) -> memref { - // CHECK: rt.trace #rt.hlo_trace<"fusion", "foo", 0>, %[[CTX]] - rt.trace #rt.hlo_trace<"fusion", "foo", 0>, %ctx {} + // CHECK: rt.trace #rt.hlo_trace<"fusion">, %[[CTX]] + rt.trace #rt.hlo_trace<"fusion">, %ctx {} - // CHECK: rt.trace #rt.hlo_trace<"fusion", "bar", 0> + // CHECK: rt.trace #rt.hlo_trace<"fusion"> // CHECK-SAME: %[[CTX]] -> memref // CHECK-NEXT: yield %[[ARG]] : memref - %0 = rt.trace #rt.hlo_trace<"fusion", "bar", 0>, %ctx -> memref { + %0 = rt.trace #rt.hlo_trace<"fusion">, %ctx -> memref { yield %arg : memref } diff --git a/tensorflow/compiler/xla/mlir/runtime/ir/tests/ops_verify.mlir b/tensorflow/compiler/xla/mlir/runtime/ir/tests/ops_verify.mlir index 997e59f69d7..4fb5231582a 100644 --- a/tensorflow/compiler/xla/mlir/runtime/ir/tests/ops_verify.mlir +++ b/tensorflow/compiler/xla/mlir/runtime/ir/tests/ops_verify.mlir @@ -1,7 +1,7 @@ // RUN: xla-runtime-opt -verify-diagnostics -split-input-file %s // ----- -// expected-error @+1 {{func.func op named 'foo' not found for export}} +// expected-error @+1 {{func op named 'foo' not found for export}} rt.export @foo // ----- diff --git a/tensorflow/compiler/xla/mlir/runtime/ir/tests/testlib.td b/tensorflow/compiler/xla/mlir/runtime/ir/tests/testlib.td index d6f38878cdd..fd3bbade524 100644 --- a/tensorflow/compiler/xla/mlir/runtime/ir/tests/testlib.td +++ b/tensorflow/compiler/xla/mlir/runtime/ir/tests/testlib.td @@ -34,8 +34,7 @@ def TestlibDialect : Dialect { let useDefaultAttributePrinterParser = 1; let useDefaultTypePrinterParser = 1; - - let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; + let useFoldAPI = kEmitFoldAdaptorFolder; } include "testlib_attrs.td" diff --git a/tensorflow/compiler/xla/mlir/runtime/transforms/BUILD b/tensorflow/compiler/xla/mlir/runtime/transforms/BUILD index 1960ae3eabc..fde2d7d6955 100644 --- a/tensorflow/compiler/xla/mlir/runtime/transforms/BUILD +++ b/tensorflow/compiler/xla/mlir/runtime/transforms/BUILD @@ -1,9 +1,11 @@ -load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_test") 
load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") +load("//tensorflow/tsl/platform:build_config.bzl", "if_llvm_system_z_available") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//tensorflow/compiler/xla/mlir/runtime:friends"], licenses = ["notice"], ) @@ -73,7 +75,7 @@ cc_library( ], ) -tf_cc_test( +xla_cc_test( name = "calling_convention_test", srcs = ["calling_convention_test.cc"], compatible_with = get_compatible_with_cloud(), @@ -100,9 +102,9 @@ cc_library( ":compiler", ":custom_call_encoding", ":passes", - "//tensorflow/compiler/xla/mlir/transforms/cpu:passes", - "//tensorflow/compiler/xla/mlir/transforms/math:passes", - "//tensorflow/compiler/xla/mlir/transforms/memref:passes", + "//tensorflow/compiler/xla/mlir/backends/cpu/transforms:passes", + "//tensorflow/compiler/xla/mlir/math/transforms:passes", + "//tensorflow/compiler/xla/mlir/memref/transforms:passes", "//tensorflow/compiler/xla/runtime:compiler", "@llvm-project//mlir:AMXToLLVMIRTranslation", "@llvm-project//mlir:AffineDialect", @@ -131,6 +133,7 @@ cc_library( "@llvm-project//mlir:ReconcileUnrealizedCasts", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:SparseTensorDialect", "@llvm-project//mlir:Transforms", "@llvm-project//mlir:VectorToLLVM", "@llvm-project//mlir:X86VectorToLLVMIRTranslation", @@ -153,6 +156,9 @@ cc_library( "//tensorflow/compiler/xla/mlir_hlo:lhlo", "//tensorflow/compiler/xla/mlir_hlo:lhlo_gpu", "//tensorflow/compiler/xla/runtime:compiler", + "@llvm-project//mlir:AsyncDialect", + "@llvm-project//mlir:AsyncToLLVM", + "@llvm-project//mlir:AsyncTransforms", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncToLLVM", "@llvm-project//mlir:IR", @@ -163,6 +169,7 @@ cc_library( "@llvm-project//mlir:ReconcileUnrealizedCasts", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:Transforms", ], alwayslink = 1, # has pipeline registration ) @@ -190,11 +197,13 @@ cc_library( "//tensorflow/compiler/xla/runtime:type_id", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:AsyncDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:LLVMCommonConversion", "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:Support", + "@tf_runtime//:async_value", ], ) @@ -213,12 +222,10 @@ cc_library( "//tensorflow/compiler/xla/runtime:arguments", "//tensorflow/compiler/xla/runtime:compiler", "//tensorflow/compiler/xla/runtime:constraints", - "//tensorflow/compiler/xla/runtime:errors", "//tensorflow/compiler/xla/runtime:executable", "//tensorflow/compiler/xla/runtime:symbolic_shape", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", "@llvm-project//llvm:Core", "@llvm-project//llvm:Support", "@llvm-project//mlir:ExecutionEngineUtils", @@ -226,7 +233,22 @@ cc_library( "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", "@llvm-project//mlir:ToLLVMIRTranslation", - ], + ] + select({ + "//tensorflow/tsl:arm_any": [ + "@llvm-project//llvm:AArch64AsmParser", + ], + "//tensorflow/tsl:linux_ppc64le": [ + "@llvm-project//llvm:PowerPCAsmParser", + ], + "//tensorflow/tsl:macos_arm64": [ + "@llvm-project//llvm:AArch64AsmParser", + 
], + "//conditions:default": [ + "@llvm-project//llvm:X86AsmParser", + ], + }) + if_llvm_system_z_available([ + "@llvm-project//llvm:SystemZAsmParser", + ]), ) cc_library( @@ -273,7 +295,7 @@ cc_library( ], ) -tf_cc_test( +xla_cc_test( name = "type_converter_test", srcs = ["type_converter_test.cc"], compatible_with = get_compatible_with_cloud(), diff --git a/tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc b/tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc index 893b94615bf..278a087c89b 100644 --- a/tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc +++ b/tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_cpu.cc @@ -40,6 +40,7 @@ limitations under the License. #include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project #include "mlir/Dialect/MemRef/Transforms/Passes.h" // from @llvm-project #include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.h" // from @llvm-project @@ -47,24 +48,25 @@ limitations under the License. #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project +#include "tensorflow/compiler/xla/mlir/backends/cpu/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir/math/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir/memref/transforms/passes.h" #include "tensorflow/compiler/xla/mlir/runtime/transforms/compiler.h" #include "tensorflow/compiler/xla/mlir/runtime/transforms/custom_call_encoding.h" #include "tensorflow/compiler/xla/mlir/runtime/transforms/passes.h" -#include "tensorflow/compiler/xla/mlir/transforms/cpu/passes.h" -#include "tensorflow/compiler/xla/mlir/transforms/math/passes.h" -#include "tensorflow/compiler/xla/mlir/transforms/memref/passes.h" namespace xla { namespace runtime { void RegisterDefaultXlaCpuRuntimeDialects(DialectRegistry& dialects) { // Register MLIR dialects supported by the compiled executables. - dialects->insert(); + dialects->insert< + mlir::AffineDialect, mlir::arith::ArithDialect, mlir::async::AsyncDialect, + mlir::cf::ControlFlowDialect, mlir::linalg::LinalgDialect, + mlir::math::MathDialect, mlir::memref::MemRefDialect, + mlir::scf::SCFDialect, mlir::func::FuncDialect, + mlir::sparse_tensor::SparseTensorDialect, mlir::tensor::TensorDialect, + mlir::vector::VectorDialect, RuntimeDialect>(); // Register MLIR dialects that can be translated to LLVM IR. mlir::registerArmNeonDialectTranslation(*dialects); @@ -76,8 +78,11 @@ void RegisterDefaultXlaCpuRuntimeDialects(DialectRegistry& dialects) { static void CreateDefaultXlaCpuRuntimeCompilationPipeline( mlir::OpPassManager& pm, const CpuPipelineOptions& opts) { + pm.addPass(mlir::createAsyncFuncToAsyncRuntimePass()); + // Convert entry function to the XLA entrypoint. 
pm.addPass(CreateExportRuntimeFunctionsPass()); + pm.addPass(cpu::createConvertLmhloToCpuRuntimePass()); pm.addPass(CreateConvertCustomCallsPass()); pm.addPass(CreateConvertAssertsPass()); @@ -85,10 +90,9 @@ static void CreateDefaultXlaCpuRuntimeCompilationPipeline( pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCSEPass()); - // Optimize operations from the math dialect before outlining compute regions - // into functions to see all constant operands. + // Enable math approximations to match XLA's FP accuracy spec. pm.addNestedPass( - xla::runtime::CreateMathOptimizationPass(opts.math_avx2)); + xla::CreateMathApproximationPass({"all"})); // Convert all linalg operations to parallel loops. pm.addNestedPass( @@ -107,10 +111,12 @@ static void CreateDefaultXlaCpuRuntimeCompilationPipeline( // Expand math operations into std/arith dialect operations. pm.addNestedPass(mlir::arith::createArithExpandOpsPass()); pm.addNestedPass(mlir::memref::createExpandOpsPass()); + pm.addNestedPass( + mlir::memref::createExpandStridedMetadataPass()); // Add alignment attribute to all memref allocations. pm.addNestedPass( - xla::runtime::CreateAlignedAllocationsPass(opts.alignment)); + xla::CreateAlignedAllocationsPass(opts.alignment)); // Lower everything down to LLVM dialect. pm.addPass(mlir::createConvertLinalgToLLVMPass()); @@ -118,8 +124,6 @@ static void CreateDefaultXlaCpuRuntimeCompilationPipeline( pm.addPass(mlir::createConvertSCFToCFPass()); // Convert runtime operations and custom calls to LLVM dialect. - pm.addPass(cpu::createConvertLmhloToCpuRuntimePass()); - pm.addPass(CreateConvertCustomCallsPass()); const CompilationPipelineOptions& copts = opts.common_options; ConvertRuntimeToLLvmOpts rt_to_llvm_opts = { copts.populate_type_id_names, copts.populate_type_conversions, @@ -130,17 +134,13 @@ static void CreateDefaultXlaCpuRuntimeCompilationPipeline( // Convert async dialect to LLVM once everything else is in the LLVM dialect. pm.addPass(mlir::createConvertAsyncToLLVMPass()); - { - mlir::OpPassManager& fpm = pm.nest(); - fpm.addPass(mlir::createConvertMathToLLVMPass()); - } - pm.addPass(mlir::createConvertMathToLibmPass()); + pm.addPass(xla::CreateMathLegalizationPass(/*enable_approximations=*/false)); // Convert everything else to LLVM dialect. mlir::LowerVectorToLLVMOptions vector_to_llvm_opts; if (opts.math_avx2) vector_to_llvm_opts.enableX86Vector(); pm.addPass(mlir::createConvertVectorToLLVMPass(vector_to_llvm_opts)); - pm.addPass(mlir::createMemRefToLLVMConversionPass()); + pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass()); pm.addPass(mlir::createConvertFuncToLLVMPass()); pm.addPass(mlir::createConvertComplexToLLVMPass()); pm.addPass(mlir::createReconcileUnrealizedCastsPass()); diff --git a/tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_gpu.cc b/tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_gpu.cc index aeeab6f50ae..cb05272ed15 100644 --- a/tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_gpu.cc +++ b/tensorflow/compiler/xla/mlir/runtime/transforms/compilation_pipeline_gpu.cc @@ -17,21 +17,24 @@ limitations under the License. 
#include +#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" // from @llvm-project #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" // from @llvm-project #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" // from @llvm-project #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" // from @llvm-project +#include "mlir/Dialect/Async/IR/Async.h" // from @llvm-project +#include "mlir/Dialect/Async/Passes.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project #include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/xla/mlir/runtime/ir/tests/testlib.h" #include "tensorflow/compiler/xla/mlir/runtime/transforms/compiler.h" #include "tensorflow/compiler/xla/mlir/runtime/transforms/passes.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" namespace xla { namespace runtime { @@ -41,7 +44,7 @@ void RegisterDefaultXlaGpuRuntimeDialects(DialectRegistry& dialects) { dialects->insert(); + mlir::async::AsyncDialect, RuntimeDialect>(); // Register MLIR dialects that can be translated to LLVM IR. mlir::registerLLVMDialectTranslation(*dialects); @@ -58,12 +61,19 @@ void RegisterTestlibDialect(DialectRegistry& dialects) { static void CreateDefaultXlaGpuRuntimeCompilationPipeline( mlir::OpPassManager& pm, const CompilationPipelineOptions& opts) { pm.addPass(mlir::createConvertSCFToCFPass()); + pm.addPass(mlir::createAsyncFuncToAsyncRuntimePass()); // Export functions to the XLA runtime. pm.addPass(CreateExportRuntimeFunctionsPass()); pm.addPass(CreateConvertCustomCallsPass()); pm.addPass(CreateConvertAssertsPass()); + // Lower from high level async operations to async runtime. + pm.addPass(mlir::createAsyncToAsyncRuntimePass()); + + // Add async.runtime reference counting operations. + pm.addPass(mlir::createAsyncRuntimePolicyBasedRefCountingPass()); + // Convert runtime operations and custom calls to LLVM dialect. ConvertRuntimeToLLvmOpts rt_to_llvm_opts = { opts.populate_type_id_names, opts.populate_type_conversions, @@ -71,10 +81,17 @@ static void CreateDefaultXlaGpuRuntimeCompilationPipeline( opts.populate_attr_encodings}; pm.addPass(CreateConvertRuntimeToLLVMPass(std::move(rt_to_llvm_opts))); - // Convert everythinG else to LLVM dialect. - pm.addPass(mlir::createMemRefToLLVMConversionPass()); + // Convert async dialect to LLVM once everything else is in the LLVM dialect. + pm.addPass(mlir::createConvertAsyncToLLVMPass()); + + // Convert everything else to LLVM dialect. + pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass()); pm.addPass(mlir::createConvertFuncToLLVMPass()); pm.addPass(mlir::createReconcileUnrealizedCastsPass()); + + // Clean up IR before passing it to LLVM. 
+ pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::createCSEPass()); } void CreateDefaultXlaGpuRuntimeCompilationPipeline( diff --git a/tensorflow/compiler/xla/mlir/runtime/transforms/custom_call_encoding.cc b/tensorflow/compiler/xla/mlir/runtime/transforms/custom_call_encoding.cc index a19a7033246..c4dd90e93a1 100644 --- a/tensorflow/compiler/xla/mlir/runtime/transforms/custom_call_encoding.cc +++ b/tensorflow/compiler/xla/mlir/runtime/transforms/custom_call_encoding.cc @@ -23,8 +23,10 @@ limitations under the License. #include #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/FormatVariadic.h" #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" // from @llvm-project #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Async/IR/AsyncTypes.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/LLVMTypes.h" // from @llvm-project @@ -32,6 +34,7 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project @@ -39,6 +42,8 @@ limitations under the License. #include "tensorflow/compiler/xla/runtime/custom_call.h" #include "tensorflow/compiler/xla/runtime/tracing.h" #include "tensorflow/compiler/xla/runtime/type_id.h" +#include "tfrt/concurrency/async_value_ref.h" // from @tf_runtime +#include "tfrt/concurrency/chain.h" // from @tf_runtime namespace Eigen { struct half; @@ -59,13 +64,13 @@ using llvm::ArrayRef; using EncodedArg = CustomCallArgEncodingSet::Encoded; -FailureOr CustomCallArgEncodingSet::Encode(Globals &g, +FailureOr CustomCallArgEncodingSet::Encode(Globals &g, Allocas &a, ImplicitLocOpBuilder &b, Value value, Value converted) const { for (auto &encoding : encodings_) if (succeeded(encoding->Match(value, converted))) - return encoding->Encode(g, b, value, converted); + return encoding->Encode(g, a, b, value, converted); return failure(); } @@ -75,13 +80,13 @@ FailureOr CustomCallArgEncodingSet::Encode(Globals &g, using EncodedRet = CustomCallRetEncodingSet::Encoded; -FailureOr CustomCallRetEncodingSet::Encode(Globals &g, +FailureOr CustomCallRetEncodingSet::Encode(Globals &g, Allocas &a, ImplicitLocOpBuilder &b, Type type, Type converted) const { for (auto &encoding : encodings_) if (succeeded(encoding->Match(type, converted))) - return encoding->Encode(g, b, type, converted); + return encoding->Encode(g, a, b, type, converted); return failure(); } @@ -114,13 +119,14 @@ FailureOr CustomCallAttrEncodingSet::Encode( // A set of helper functions for packing primitive attributes. 
//===----------------------------------------------------------------------===// -Value PackTypeId(Globals &g, ImplicitLocOpBuilder &b, TypeID type_id) { - auto global = g.GetOrCreate(b, type_id); - return Globals::AddrOf(b, global); +LLVM::GlobalOp EncodeTypeId(Globals &g, ImplicitLocOpBuilder &b, + TypeID type_id) { + return g.GetOrCreate(b, type_id); } -Value PackString(Globals &g, ImplicitLocOpBuilder &b, std::string_view strref, - std::string_view symbol_base) { +LLVM::GlobalOp EncodeString(Globals &g, ImplicitLocOpBuilder &b, + std::string_view strref, + std::string_view symbol_base) { MLIRContext *ctx = b.getContext(); int64_t size = strref.size(); @@ -142,15 +148,13 @@ Value PackString(Globals &g, ImplicitLocOpBuilder &b, std::string_view strref, }; auto value = b.getStringAttr(strref); - auto global = g.GetOrCreate(b, value, type, symbol_base, init); - return Globals::AddrOf(b, global); + return g.GetOrCreate(b, value, type, symbol_base, init); } -// Packs scalar attribute as a global constant. Returns `!llvm.ptr`. -Value PackScalarAttribute(Globals &g, ImplicitLocOpBuilder &b, Attribute value, - std::string_view symbol_base) { - auto global = g.GetOrCreate(b, value, symbol_base); - return Globals::AddrOf(b, global); +mlir::LLVM::GlobalOp EncodeScalar(Globals &g, mlir::ImplicitLocOpBuilder &b, + mlir::Attribute value, + std::string_view symbol_base) { + return g.GetOrCreate(b, value, symbol_base); } // Reshape dense elements as a one-dimensional array. @@ -165,11 +169,10 @@ static mlir::DenseElementsAttr Flatten(DenseIntOrFPElementsAttr dense) { // A set of helper functions for packing dense and array-like attributes. //===----------------------------------------------------------------------===// -// Packs dense elements attribute as a global constant. Returns -// `!llvm.ptr`. -static Value PackDenseElementsAttribute(Globals &g, ImplicitLocOpBuilder &b, - Attribute value, - std::string_view symbol_base) { +// Encodes dense elements attribute as a global constant. +static LLVM::GlobalOp EncodeDenseElementsAttribute( + Globals &g, ImplicitLocOpBuilder &b, Attribute value, + std::string_view symbol_base) { MLIRContext *ctx = b.getContext(); DenseIntOrFPElementsAttr dense = value.cast(); @@ -186,7 +189,7 @@ static Value PackDenseElementsAttribute(Globals &g, ImplicitLocOpBuilder &b, // cast pointers to dense elements attributes (shaped tensors) as pointers to // flat array attributes. // - // See `PackArrayAttribute` defined below. + // See `EncodeArrayAttribute` defined below. Type encoded_arr_type = LLVM::LLVMStructType::getLiteral(ctx, {b.getI64Type(), ptr}); @@ -229,15 +232,13 @@ static Value PackDenseElementsAttribute(Globals &g, ImplicitLocOpBuilder &b, ib.create(encoded); }; - auto global = g.GetOrCreate(b, value, type, symbol_base, init); - return Globals::AddrOf(b, global); + return g.GetOrCreate(b, value, type, symbol_base, init); } -// Create a global for the data array in an EncodedArray. -// Returns `!llvm.ptr> -static Value CreateGlobalFromArray(Globals &g, ImplicitLocOpBuilder &b, - ArrayAttr array, Type element_type, - std::string_view symbol_base) { +// Encodes the payload of an array attribute as a global constant. 
+static LLVM::GlobalOp EncodeArrayAttrData(Globals &g, ImplicitLocOpBuilder &b, + ArrayAttr array, Type element_type, + std::string_view symbol_base) { Type arr_type = LLVM::LLVMArrayType::get(element_type, array.size()); auto init = [&](ImplicitLocOpBuilder &ib, Attribute) { @@ -249,14 +250,13 @@ static Value CreateGlobalFromArray(Globals &g, ImplicitLocOpBuilder &b, ib.create(data); }; - auto global = g.GetOrCreate(b, array, arr_type, symbol_base, init); - return Globals::AddrOf(b, global); + return g.GetOrCreate(b, array, arr_type, symbol_base, init); } -// Packs array attribute as a global constant. Returns `!llvm.ptr`. -static Value PackArrayAttribute(Globals &g, ImplicitLocOpBuilder &b, - ArrayAttr array, Type element_type, - std::string_view symbol_base) { +// Encodes array attribute as a global constant. +static LLVM::GlobalOp EncodeArrayAttribute(Globals &g, ImplicitLocOpBuilder &b, + ArrayAttr array, Type element_type, + std::string_view symbol_base) { MLIRContext *ctx = b.getContext(); int64_t size = array.size(); @@ -269,7 +269,8 @@ static Value PackArrayAttribute(Globals &g, ImplicitLocOpBuilder &b, auto init = [&](ImplicitLocOpBuilder &ib, Attribute) { // Array size and the pointer to data. Value num_elements = ib.create(b.getI64IntegerAttr(size)); - Value data = CreateGlobalFromArray(g, b, array, element_type, symbol_base); + Value data = Globals::AddrOf( + b, EncodeArrayAttrData(g, b, array, element_type, symbol_base)); // Store size and values into the struct. Value encoded = ib.create(type); @@ -279,8 +280,7 @@ static Value PackArrayAttribute(Globals &g, ImplicitLocOpBuilder &b, ib.create(encoded); }; - auto global = g.GetOrCreate(b, array, type, symbol_base, init); - return Globals::AddrOf(b, global); + return g.GetOrCreate(b, array, type, symbol_base, init); } template @@ -295,10 +295,12 @@ static Value FillDataFromDenseArrayAttr( return data; } -static Value CreateGlobalFromDenseArray(Globals &g, ImplicitLocOpBuilder &b, - DenseArrayAttr base_array, - Type arr_type, - std::string_view symbol_base) { +// Encodes the payload of a dense array attribute as a global constant. +static LLVM::GlobalOp EncodeDenseArrayAttrData(Globals &g, + ImplicitLocOpBuilder &b, + DenseArrayAttr base_array, + Type arr_type, + std::string_view symbol_base) { auto init = [&](ImplicitLocOpBuilder &ib, Attribute) { Value data = ib.create(arr_type); llvm::TypeSwitch(base_array) @@ -332,13 +334,13 @@ static Value CreateGlobalFromDenseArray(Globals &g, ImplicitLocOpBuilder &b, ib.create(data); }; - auto global = g.GetOrCreate(b, base_array, arr_type, symbol_base, init); - return Globals::AddrOf(b, global); + return g.GetOrCreate(b, base_array, arr_type, symbol_base, init); } -static Value PackDenseArrayAttribute(Globals &g, ImplicitLocOpBuilder &b, - Attribute value, - std::string_view symbol_base) { +static LLVM::GlobalOp EncodeDenseArrayAttribute(Globals &g, + ImplicitLocOpBuilder &b, + Attribute value, + std::string_view symbol_base) { MLIRContext *ctx = b.getContext(); DenseArrayAttr base_array = value.cast(); @@ -347,7 +349,7 @@ static Value PackDenseArrayAttribute(Globals &g, ImplicitLocOpBuilder &b, Type ptr = LLVM::LLVMPointerType::get(ctx); // Stored array type: !llvm.array - Type element_type = base_array.getType().getElementType(); + Type element_type = base_array.getElementType(); Type arr_type = LLVM::LLVMArrayType::get(element_type, size); // Encoded array type: !llvm.struct<(i64, !llvm.ptr)>. 
@@ -357,8 +359,8 @@ static Value PackDenseArrayAttribute(Globals &g, ImplicitLocOpBuilder &b, auto init = [&](ImplicitLocOpBuilder &ib, Attribute) { // Array size and values. Value num_elements = ib.create(b.getI64IntegerAttr(size)); - Value data = - CreateGlobalFromDenseArray(g, ib, base_array, arr_type, symbol_base); + Value data = Globals::AddrOf( + b, EncodeDenseArrayAttrData(g, ib, base_array, arr_type, symbol_base)); // Store size and values into the struct. Value encoded = ib.create(type); @@ -368,13 +370,13 @@ static Value PackDenseArrayAttribute(Globals &g, ImplicitLocOpBuilder &b, ib.create(encoded); }; - auto global = g.GetOrCreate(b, value, type, symbol_base, init); - return Globals::AddrOf(b, global); + return g.GetOrCreate(b, value, type, symbol_base, init); } -static Value PackEmptyArrayAttribute(Globals &g, ImplicitLocOpBuilder &b, - Attribute value, - std::string_view symbol_base) { +static LLVM::GlobalOp EncodeEmptyArrayAttribute(Globals &g, + ImplicitLocOpBuilder &b, + Attribute value, + std::string_view symbol_base) { MLIRContext *ctx = b.getContext(); Type ptr = LLVM::LLVMPointerType::get(ctx); @@ -396,8 +398,7 @@ static Value PackEmptyArrayAttribute(Globals &g, ImplicitLocOpBuilder &b, ib.create(encoded); }; - auto global = g.GetOrCreate(b, value, type, symbol_base, init); - return Globals::AddrOf(b, global); + return g.GetOrCreate(b, value, type, symbol_base, init); } //===----------------------------------------------------------------------===// @@ -414,22 +415,14 @@ static FuncOp GetParentFunc(Value value) { } // Packs value on the stack. Returns allocation holding the value. -static LLVM::AllocaOp PackValue(ImplicitLocOpBuilder &b, Value value) { - Type ptr = LLVM::LLVMPointerType::get(b.getContext()); - - // Always create an `alloca` in the parent function entry block. - // See: https://llvm.org/docs/Frontend/PerformanceTips.html#use-of-allocas - LLVM::AllocaOp mem = [&]() -> LLVM::AllocaOp { - Block &block = GetParentFunc(value).getBody().front(); - OpBuilder::InsertionGuard guard(b); - b.setInsertionPointToStart(&block); - Value one = b.create(b.getI32IntegerAttr(1)); - return b.create(ptr, value.getType(), one, 0); - }(); +static LLVM::AllocaOp PackValue(ImplicitLocOpBuilder &b, Allocas &a, + Value value) { + LLVM::AllocaOp alloca = a.GetOrCreate(b, value.getType()); + // Start the lifetime of encoded value. + b.create(b.getI64IntegerAttr(-1), alloca); + b.create(value, alloca); - b.create(value, mem); - - return mem; + return alloca; } //===----------------------------------------------------------------------===// @@ -527,6 +520,50 @@ mlir::FailureOr Globals::TryGetOrCreate( global.getSymName()); } +//===----------------------------------------------------------------------===// +// A helper class to create alloca operations for encoded arguments. +//===----------------------------------------------------------------------===// + +Allocas::Allocas(Block *block, + llvm::DenseMap *allocas) + : block_(block), allocas_(allocas) { + for (auto &[_, v] : *allocas_) { + assert(v.offset == 0 && "expected zero offset"); + (void)v; + } +} + +Allocas::~Allocas() { + for (auto &[k, v] : *allocas_) v.offset = 0; +} + +mlir::LLVM::AllocaOp Allocas::GetOrCreate(mlir::ImplicitLocOpBuilder &b, + mlir::Type type) { + TypedAllocas &allocas = (*allocas_)[type]; + + // Reuse existing alloca for the given type. + if (allocas.offset < allocas.allocas.size()) { + return allocas.allocas[allocas.offset++]; + } + + // Create a new alloca at the beginning of the block. 
+ OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(block_); + Value c1 = b.create(b.getI32IntegerAttr(1)); + Type ptr = LLVM::LLVMPointerType::get(b.getContext()); + auto alloca = b.create(ptr, type, c1, 0); + + ++allocas.offset; + return allocas.allocas.emplace_back(alloca); +} + +Allocas EncodingAllocas::GetForOperation(mlir::Operation *op) { + // Always create an `alloca` in the parent function entry block. + // See: https://llvm.org/docs/Frontend/PerformanceTips.html#use-of-allocas + Block *block = &op->getParentOfType().getBody().front(); + return Allocas(block, &allocas_[block]); +} + //===----------------------------------------------------------------------===// // Helper functions for encoding attributes and values for custom calls. //===----------------------------------------------------------------------===// @@ -592,6 +629,8 @@ static PrimitiveType ScalarPrimitiveType(Type type) { if (type.isInteger(64)) return PrimitiveType::S64; // Floating point types. + if (type.isFloat8E4M3FN()) return PrimitiveType::F8E4M3FN; + if (type.isFloat8E5M2()) return PrimitiveType::F8E5M2; if (type.isF16()) return PrimitiveType::F16; if (type.isF32()) return PrimitiveType::F32; if (type.isF64()) return PrimitiveType::F64; @@ -608,12 +647,39 @@ static PrimitiveType ScalarPrimitiveType(Type type) { } static TypeID ArrayRuntimeTypeId(Type elem_type) { - if (elem_type.isInteger(8)) return TypeID::get>>(); - if (elem_type.isInteger(16)) return TypeID::get>>(); - if (elem_type.isInteger(32)) return TypeID::get>>(); - if (elem_type.isInteger(64)) return TypeID::get>>(); - if (elem_type.isF32()) return TypeID::get>>(); - if (elem_type.isF64()) return TypeID::get>>(); + if (elem_type.isInteger(8)) + return TypeID::get>>(); + if (elem_type.isInteger(16)) + return TypeID::get>>(); + if (elem_type.isInteger(32)) + return TypeID::get>>(); + if (elem_type.isInteger(64)) + return TypeID::get>>(); + + if (elem_type.isF32()) return TypeID::get>>(); + if (elem_type.isF64()) return TypeID::get>>(); + + assert(false && "unsupported type id"); + return TypeID::getFromOpaquePointer(reinterpret_cast(0xDEADBEEF)); +} + +static TypeID AsyncValueRuntimeTypeId(Type elem_type) { + if (elem_type.isInteger(1)) + return TypeID::get>>(); + if (elem_type.isInteger(8)) + return TypeID::get>>(); + if (elem_type.isInteger(16)) + return TypeID::get>>(); + if (elem_type.isInteger(32)) + return TypeID::get>>(); + if (elem_type.isInteger(64)) + return TypeID::get>>(); + if (elem_type.isF32()) + return TypeID::get>>(); + if (elem_type.isF64()) + return TypeID::get>>(); + if (elem_type.isa()) + return TypeID::get>>(); assert(false && "unsupported type id"); return TypeID::getFromOpaquePointer(reinterpret_cast(0xDEADBEEF)); @@ -651,9 +717,9 @@ FailureOr StringAttrEncoding::Encode(mlir::SymbolTable &, auto str = attr.cast(); Encoded encoded; - encoded.name = PackString(g, b, name, kAttrName); - encoded.type_id = PackTypeId(g, b, TypeID::get>()); - encoded.value = PackString(g, b, str.getValue(), kAttrValue); + encoded.name = EncodeString(g, b, name, kAttrName); + encoded.type_id = EncodeTypeId(g, b, TypeID::get>()); + encoded.value = EncodeString(g, b, str.getValue(), kAttrValue); return encoded; } @@ -673,9 +739,9 @@ FailureOr ScalarAttrEncoding::Encode(mlir::SymbolTable &, Type type = attr.cast().getType(); Encoded encoded; - encoded.name = PackString(g, b, name, kAttrName); - encoded.type_id = PackTypeId(g, b, ScalarRuntimeTypeId(type)); - encoded.value = PackScalarAttribute(g, b, attr, kAttrValue); + encoded.name = 
EncodeString(g, b, name, kAttrName); + encoded.type_id = EncodeTypeId(g, b, ScalarRuntimeTypeId(type)); + encoded.value = EncodeScalar(g, b, attr, kAttrValue); return encoded; } @@ -697,9 +763,9 @@ FailureOr DenseElementsAttrEncoding::Encode( Type elem_type = dense.getType().getElementType(); Encoded encoded; - encoded.name = PackString(g, b, name, kAttrName); - encoded.type_id = PackTypeId(g, b, DenseElementsRuntimeTypeId(elem_type)); - encoded.value = PackDenseElementsAttribute(g, b, attr, kAttrValue); + encoded.name = EncodeString(g, b, name, kAttrName); + encoded.type_id = EncodeTypeId(g, b, DenseElementsRuntimeTypeId(elem_type)); + encoded.value = EncodeDenseElementsAttribute(g, b, attr, kAttrValue); return encoded; } @@ -732,9 +798,9 @@ FailureOr ArrayAttrEncoding::Encode(mlir::SymbolTable &, if (!all_of_same_type) return failure(); Encoded encoded; - encoded.name = PackString(g, b, name, kAttrName); - encoded.type_id = PackTypeId(g, b, ArrayRuntimeTypeId(elem_type)); - encoded.value = PackArrayAttribute(g, b, array, elem_type, kAttrValue); + encoded.name = EncodeString(g, b, name, kAttrName); + encoded.type_id = EncodeTypeId(g, b, ArrayRuntimeTypeId(elem_type)); + encoded.value = EncodeArrayAttribute(g, b, array, elem_type, kAttrValue); return encoded; } @@ -755,12 +821,12 @@ FailureOr DenseArrayAttrEncoding::Encode(mlir::SymbolTable &, ImplicitLocOpBuilder &b, std::string_view name, Attribute attr) const { - Type elem_type = attr.cast().getType().getElementType(); + Type elem_type = attr.cast().getElementType(); Encoded encoded; - encoded.name = PackString(g, b, name, kAttrName); - encoded.type_id = PackTypeId(g, b, ArrayRuntimeTypeId(elem_type)); - encoded.value = PackDenseArrayAttribute(g, b, attr, kAttrValue); + encoded.name = EncodeString(g, b, name, kAttrName); + encoded.type_id = EncodeTypeId(g, b, ArrayRuntimeTypeId(elem_type)); + encoded.value = EncodeDenseArrayAttribute(g, b, attr, kAttrValue); return encoded; } @@ -782,9 +848,9 @@ FailureOr EmptyArrayAttrEncoding::Encode(mlir::SymbolTable &, std::string_view name, Attribute attr) const { Encoded encoded; - encoded.name = PackString(g, b, name, kAttrName); - encoded.type_id = PackTypeId(g, b, TypeID::get>()); - encoded.value = PackEmptyArrayAttribute(g, b, attr, kAttrValue); + encoded.name = EncodeString(g, b, name, kAttrName); + encoded.type_id = EncodeTypeId(g, b, TypeID::get>()); + encoded.value = EncodeEmptyArrayAttribute(g, b, attr, kAttrValue); return encoded; } @@ -815,9 +881,9 @@ FailureOr SymbolRefAttrEncoding::Encode( auto type_id = TypeID::get>(); Encoded encoded; - encoded.name = PackString(g, b, name, kAttrName); - encoded.type_id = PackTypeId(g, b, type_id); - encoded.value = PackScalarAttribute(g, b, ordinal, kAttrValue); + encoded.name = EncodeString(g, b, name, kAttrName); + encoded.type_id = EncodeTypeId(g, b, type_id); + encoded.value = EncodeScalar(g, b, ordinal, kAttrValue); return encoded; } @@ -833,14 +899,40 @@ FailureOr UnitAttrEncoding::Encode(mlir::SymbolTable &, Globals &g, ImplicitLocOpBuilder &b, std::string_view name, Attribute attr) const { - // Unit attribute encodes empty optional as a null pointer. 
- Type ptr = LLVM::LLVMPointerType::get(b.getContext()); - Encoded encoded; - encoded.name = PackString(g, b, name, kAttrName); - encoded.type_id = PackTypeId(g, b, TypeID::get>()); - encoded.value = b.create(ptr); + encoded.name = EncodeString(g, b, name, kAttrName); + encoded.type_id = EncodeTypeId(g, b, TypeID::get>()); + encoded.value = nullptr; // unit attribute encoded as null global op + + return encoded; +} + +//===----------------------------------------------------------------------===// + +LogicalResult DictionaryAttrEncoding::Match(mlir::SymbolTable &, + std::string_view, + Attribute attr) const { + return success(attr.isa()); +} +FailureOr DictionaryAttrEncoding::Encode( + mlir::SymbolTable &sym_table, Globals &g, ImplicitLocOpBuilder &b, + std::string_view name, Attribute attr) const { + // TODO(ezhulenev): Add current set of available encodings to `Encode` + // arguments and remove it from `AggregateAttrEncoding` constructor. + CustomCallAttrEncodingSet encoding = DefaultAttrEncodings(); + + auto dict = cast(attr); + auto encoded_dict = EncodeAttributes( + sym_table, g, b, encoding, "__rt_dictionary", + // We rely on the fact that dictionary keeps attributes sorted by name. + llvm::SmallVector(dict.begin(), dict.end())); + if (mlir::failed(encoded_dict)) return mlir::failure(); + + Encoded encoded; + encoded.name = EncodeString(g, b, name, kAttrName); + encoded.type_id = EncodeTypeId(g, b, TypeID::get>()); + encoded.value = *encoded_dict; return encoded; } @@ -848,18 +940,17 @@ FailureOr UnitAttrEncoding::Encode(mlir::SymbolTable &, Globals &g, // Encoding for collection of attributes. //===----------------------------------------------------------------------===// -FailureOr EncodeAttributes(mlir::SymbolTable &sym_table, Globals &g, - ImplicitLocOpBuilder &b, - const CustomCallAttrEncodingSet &encoding, - std::string_view symbol_base, - ArrayRef attrs) { +FailureOr EncodeAttributes( + mlir::SymbolTable &sym_table, Globals &g, ImplicitLocOpBuilder &b, + const CustomCallAttrEncodingSet &encoding, std::string_view symbol_base, + ArrayRef attrs) { using EncodedAttr = std::pair; // In addition to encoded attributes we encode the number of attributes. int64_t n_attrs = attrs.size(); - // We store encoded attribute as `!llvm.array x len>`. + // We store encoded attribute as `!llvm.array`. Type ptr = LLVM::LLVMPointerType::get(b.getContext()); Type type = LLVM::LLVMArrayType::get(ptr, 1 + n_attrs * 3); @@ -881,18 +972,26 @@ FailureOr EncodeAttributes(mlir::SymbolTable &sym_table, Globals &g, }; // Insert the number of encoded attributes. - Attribute num_attrs = b.getI64IntegerAttr(n_attrs); - Value size = PackScalarAttribute(g, b, num_attrs, "__rt_num_attrs"); - insert_value(size, 0); + LLVM::GlobalOp num_attrs = + EncodeScalar(g, b, b.getI64IntegerAttr(n_attrs), "__rt_num_attrs"); + insert_value(Globals::AddrOf(b, num_attrs), 0); // Insert encoded attributes into the allocated storage. for (auto &pair : llvm::enumerate(encoded_attrs)) { CustomCallAttrEncoding::Encoded encoded = pair.value().second; int64_t offset = 1 + pair.index() * 3; - insert_value(encoded.name, offset + 0); - insert_value(encoded.type_id, offset + 1); - insert_value(encoded.value, offset + 2); + insert_value(Globals::AddrOf(b, encoded.name), offset + 0); + insert_value(Globals::AddrOf(b, encoded.type_id), offset + 1); + + // For unit attributes we do not create any global operations, and just + // pass them as a null pointer. Attribute decoding treats null pointers as + // empty optional attributes. 
+ if (encoded.value) { + insert_value(Globals::AddrOf(b, encoded.value), offset + 2); + } else { + insert_value(b.create(ptr), offset + 2); + } } // Return attributes array from the global initializer block. @@ -907,8 +1006,8 @@ FailureOr EncodeAttributes(mlir::SymbolTable &sym_table, Globals &g, auto global = g.TryGetOrCreate(b, attrs_map, type, symbol_base, init); if (failed(global)) return failure(); - // Return an address of global encoding attributes. - return Globals::AddrOf(b, *global); + // Return global encoding attributes. + return *global; } //===----------------------------------------------------------------------===// @@ -919,15 +1018,24 @@ LogicalResult ScalarArgEncoding::Match(Value value, Value converted) const { return success(IsSupportedScalarType(value.getType())); } -FailureOr ScalarArgEncoding::Encode(Globals &g, +FailureOr ScalarArgEncoding::Encode(Globals &g, Allocas &a, ImplicitLocOpBuilder &b, Value value, Value converted) const { Type type = converted.getType(); Encoded encoded; - encoded.type_id = PackTypeId(g, b, ScalarRuntimeTypeId(type)); - encoded.value = PackValue(b, converted); + encoded.type_id = EncodeTypeId(g, b, ScalarRuntimeTypeId(type)); + + // Encode constant arguments as global values. + if (IntegerAttr cst; matchPattern(converted, m_Constant(&cst))) { + std::string name = llvm::formatv("__rt_c{0}", cst.getValue()); + encoded.value = g.GetOrCreate(b, cst, name); + } else if (FloatAttr cst; matchPattern(converted, m_Constant(&cst))) { + encoded.value = g.GetOrCreate(b, cst, "__rt_cst"); + } else { + encoded.value = PackValue(b, a, converted); + } return encoded; } @@ -951,13 +1059,13 @@ LogicalResult OpaqueArgEncoding::Match(Value value, Value converted) const { return failure(); } -FailureOr OpaqueArgEncoding::Encode(Globals &g, +FailureOr OpaqueArgEncoding::Encode(Globals &g, Allocas &a, ImplicitLocOpBuilder &b, Value value, Value converted) const { Encoded encoded; - encoded.type_id = PackTypeId(g, b, type_id_); - encoded.value = PackValue(b, converted); + encoded.type_id = EncodeTypeId(g, b, type_id_); + encoded.value = PackValue(b, a, converted); return encoded; } @@ -1012,7 +1120,7 @@ static Value EncodeMemRef(ImplicitLocOpBuilder &b, MemRefType memref_ty, llvm::SmallVector strides; int64_t memref_offset; if (failed(getStridesAndOffset(memref_ty, strides, memref_offset))) - strides.resize(memref_ty.getRank(), ShapedType::kDynamicStrideOrOffset); + strides.resize(memref_ty.getRank(), ShapedType::kDynamic); // Build encoded memref sizes + strides: !llvm.array<... x i64> Value payload = b.create(type.getBody()[3]); @@ -1024,10 +1132,9 @@ static Value EncodeMemRef(ImplicitLocOpBuilder &b, MemRefType memref_ty, ? desc->size(b, loc, i) : b.create(i64(dim_size)); - Value stride = - ShapedType::isDynamicStrideOrOffset(stride_size) && desc.has_value() - ? desc->stride(b, loc, i) - : b.create(i64(stride_size)); + Value stride = ShapedType::isDynamic(stride_size) && desc.has_value() + ? desc->stride(b, loc, i) + : b.create(i64(stride_size)); auto stride_pos = memref_ty.getRank() + i; @@ -1045,9 +1152,12 @@ static Value EncodeMemRef(ImplicitLocOpBuilder &b, MemRefType memref_ty, // dynamic values into the struct after all statically know values leads to a // better canonicalization and cleaner final LLVM IR. 
if (desc.has_value()) { + Value offset = b.create(i64(memref_offset)); + Value data = b.create(desc->getElementPtrType(), + desc->alignedPtr(b, loc), offset); auto ptr = LLVM::LLVMPointerType::get(b.getContext()); - Value data = b.create(ptr, desc->alignedPtr(b, loc)); - memref = b.create(memref, data, 2); + memref = b.create( + memref, b.create(ptr, data), 2); } return memref; @@ -1057,7 +1167,7 @@ LogicalResult MemrefArgEncoding::Match(Value value, Value converted) const { return success(value.getType().isa()); } -FailureOr MemrefArgEncoding::Encode(Globals &g, +FailureOr MemrefArgEncoding::Encode(Globals &g, Allocas &a, ImplicitLocOpBuilder &b, Value value, Value converted) const { @@ -1070,8 +1180,8 @@ FailureOr MemrefArgEncoding::Encode(Globals &g, : TypeID::get>(); Encoded encoded; - encoded.type_id = PackTypeId(g, b, type_id); - encoded.value = PackValue(b, EncodeMemRef(b, memref_type, converted)); + encoded.type_id = EncodeTypeId(g, b, type_id); + encoded.value = PackValue(b, a, EncodeMemRef(b, memref_type, converted)); return encoded; } @@ -1084,16 +1194,16 @@ LogicalResult ScalarRetEncoding::Match(Type type, Type converted) const { return success(IsSupportedScalarType(type)); } -FailureOr ScalarRetEncoding::Encode(Globals &g, +FailureOr ScalarRetEncoding::Encode(Globals &g, Allocas &a, ImplicitLocOpBuilder &b, Type type, Type converted) const { Encoded encoded; - encoded.type_id = PackTypeId(g, b, ScalarRuntimeTypeId(converted)); + encoded.type_id = EncodeTypeId(g, b, ScalarRuntimeTypeId(converted)); + encoded.value = a.GetOrCreate(b, converted); - Type ptr = LLVM::LLVMPointerType::get(b.getContext()); - Value one = b.create(b.getI32IntegerAttr(1)); - encoded.value = b.create(ptr, converted, one, 0); + // Start the lifetime of encoded result. + b.create(b.getI64IntegerAttr(-1), encoded.value); return encoded; } @@ -1121,16 +1231,16 @@ LogicalResult OpaqueRetEncoding::Match(Type type, Type converted) const { return failure(); } -FailureOr OpaqueRetEncoding::Encode(Globals &g, +FailureOr OpaqueRetEncoding::Encode(Globals &g, Allocas &a, ImplicitLocOpBuilder &b, Type value, Type converted) const { Encoded encoded; - encoded.type_id = PackTypeId(g, b, type_id_); + encoded.type_id = EncodeTypeId(g, b, type_id_); + encoded.value = a.GetOrCreate(b, converted); - Type ptr = LLVM::LLVMPointerType::get(b.getContext()); - Value one = b.create(b.getI32IntegerAttr(1)); - encoded.value = b.create(ptr, converted, one, 0); + // Start the lifetime of encoded result. 
+ b.create(b.getI64IntegerAttr(-1), encoded.value); return encoded; } @@ -1148,7 +1258,7 @@ LogicalResult MemrefRetEncoding::Match(Type type, Type converted) const { converted.isa()); } -FailureOr MemrefRetEncoding::Encode(Globals &g, +FailureOr MemrefRetEncoding::Encode(Globals &g, Allocas &a, ImplicitLocOpBuilder &b, Type type, Type converted) const { @@ -1159,11 +1269,11 @@ FailureOr MemrefRetEncoding::Encode(Globals &g, auto type_id = TypeID::get>(); Encoded encoded; - encoded.type_id = PackTypeId(g, b, type_id); + encoded.type_id = EncodeTypeId(g, b, type_id); // No memref descriptor for result, we only encode compile time known info: // dtype, rank, dims encoded.value = - PackValue(b, EncodeMemRef(b, memref_ty, /*descriptor=*/nullptr)); + PackValue(b, a, EncodeMemRef(b, memref_ty, /*descriptor=*/nullptr)); return encoded; } @@ -1194,7 +1304,6 @@ FailureOr MemrefRetEncoding::Decode(ImplicitLocOpBuilder &b, Type type, b.create(ptr, gep)); memref_desc.setAllocatedPtr(b, loc, data_ptr); memref_desc.setAlignedPtr(b, loc, data_ptr); - memref_desc.setConstantOffset(b, loc, 0); // Get the statically known strides and offset from the memref type. SmallVector strides; @@ -1203,6 +1312,8 @@ FailureOr MemrefRetEncoding::Decode(ImplicitLocOpBuilder &b, Type type, return failure(); } + memref_desc.setConstantOffset(b, loc, memref_offset); + // Fill memref descriptor dimensions and strides. for (unsigned i = 0; i < memref_type.getRank(); ++i) { memref_desc.setConstantSize(b, loc, i, memref_type.getDimSize(i)); @@ -1214,6 +1325,69 @@ FailureOr MemrefRetEncoding::Decode(ImplicitLocOpBuilder &b, Type type, return casted.getResult(0); } +//===----------------------------------------------------------------------===// + +LogicalResult AsyncValueRetEncoding::Match(Type type, Type converted) const { + return success( + (type.isa() || type.isa()) && + converted.isa()); +} + +FailureOr AsyncValueRetEncoding::Encode(Globals &g, Allocas &a, + ImplicitLocOpBuilder &b, + Type type, + Type converted) const { + Type ptr = LLVM::LLVMPointerType::get(b.getContext()); + Value one = b.create(b.getI32IntegerAttr(1)); + + auto type_id = type.isa() + ? AsyncValueRuntimeTypeId( + type.cast().getValueType()) + : TypeID::get>>(); + + Encoded encoded; + encoded.type_id = EncodeTypeId(g, b, type_id); + + // for !async.value encoding its dtype, rank and dims with + // EncodedMemRef struct; we use its data field to store async value ptr. + if (auto value_ty = type.dyn_cast()) { + if (auto memref_ty = value_ty.getValueType().dyn_cast()) { + encoded.value = + PackValue(b, a, EncodeMemRef(b, memref_ty, /*descriptor=*/nullptr)); + return encoded; + } + } + + encoded.value = b.create(ptr, converted, one, 0); + + return encoded; +} + +FailureOr AsyncValueRetEncoding::Decode(ImplicitLocOpBuilder &b, + Type type, Type converted, + LLVM::AllocaOp alloca) const { + if (auto value_ty = type.dyn_cast()) { + if (auto memref_ty = value_ty.getValueType().dyn_cast()) { + // TODO(ezhulenev): Add support for returning dynamically shaped memref. 
+ if (!memref_ty.hasStaticShape()) return failure(); + + Value c0 = b.create(b.getI64IntegerAttr(0)); + Value c2 = b.create(b.getI64IntegerAttr(2)); + Type ptr = LLVM::LLVMPointerType::get(b.getContext()); + LLVM::LLVMStructType encoded = GetEncodeMemRefType(b, memref_ty); + Value gep = + b.create(ptr, encoded, alloca, ValueRange({c0, c2})); + Value async_value = b.create(converted, gep); + auto casted = b.create(type, async_value); + return casted.getResult(0); + } + } + + auto async_value = Value{b.create(converted, alloca)}; + auto casted = b.create(type, async_value); + return casted.getResult(0); +} + //===----------------------------------------------------------------------===// // Default encodings for arguments, attributes, and results //===----------------------------------------------------------------------===// @@ -1223,13 +1397,11 @@ CustomCallAttrEncodingSet DefaultAttrEncodings() { encodings .Add(); + SymbolRefAttrEncoding, UnitAttrEncoding, DictionaryAttrEncoding>(); encodings.Add>( - encodings, AggregateAttrDef() - .Add("hlo_op", &HloTraceAttr::getHloOp) - .Add("module", &HloTraceAttr::getModule) - .Add("program_id", &HloTraceAttr::getProgramId)); + encodings, + AggregateAttrDef().Add("hlo_op", &HloTraceAttr::getHloOp)); return encodings; } @@ -1242,7 +1414,8 @@ CustomCallArgEncodingSet DefaultArgEncodings() { CustomCallRetEncodingSet DefaultRetEncodings() { CustomCallRetEncodingSet encodings; - encodings.Add(); + encodings.Add(); return encodings; } diff --git a/tensorflow/compiler/xla/mlir/runtime/transforms/custom_call_encoding.h b/tensorflow/compiler/xla/mlir/runtime/transforms/custom_call_encoding.h index d59a9d5f3a1..6e1e8dae14f 100644 --- a/tensorflow/compiler/xla/mlir/runtime/transforms/custom_call_encoding.h +++ b/tensorflow/compiler/xla/mlir/runtime/transforms/custom_call_encoding.h @@ -23,12 +23,14 @@ limitations under the License. #include #include #include +#include #include #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project @@ -60,23 +62,27 @@ namespace runtime { // // Custom call arguments are encoded as an array of pointers allocated on the // stack. Each individual argument is also encoded on the stack, because -// arguments are run time values and we can't encode them in the constant -// section. +// arguments are typically run time values and we can't encode them in the +// constant section. Statically known arguments (constants) can be encoded as +// global values together with attributes. // Forward declare class declared below. class Globals; +class Allocas; //===----------------------------------------------------------------------===// // Custom call arguments encoding. //===----------------------------------------------------------------------===// -// Encodes argument into stack allocated storage according to the ABI. If -// argument is a constant, then it can be packed as a global constant. +// Encodes argument into stack allocated storage according to the ABI. 
class CustomCallArgEncoding { public: struct Encoded { - mlir::Value type_id; // !llvm.ptr - mlir::Value value; // !llvm.ptr + mlir::LLVM::GlobalOp type_id; // llvm.mlir.global external $type_name : i64 + + // Statically known arguments might be encoded as global constants, + // otherwise it will be `!llvm.alloca 1 x ArgType`. + std::variant value; }; virtual ~CustomCallArgEncoding() = default; @@ -84,7 +90,7 @@ class CustomCallArgEncoding { virtual mlir::LogicalResult Match(mlir::Value value, mlir::Value conterted) const = 0; - virtual mlir::FailureOr Encode(Globals &g, + virtual mlir::FailureOr Encode(Globals &g, Allocas &a, mlir::ImplicitLocOpBuilder &b, mlir::Value value, mlir::Value converted) const = 0; @@ -97,7 +103,8 @@ class CustomCallArgEncodingSet { // Finds matching argument encoding and tries to encode the values. Returns // failure if didn't match values to any of the argument encodings. - mlir::FailureOr Encode(Globals &g, mlir::ImplicitLocOpBuilder &b, + mlir::FailureOr Encode(Globals &g, Allocas &a, + mlir::ImplicitLocOpBuilder &b, mlir::Value value, mlir::Value converted) const; @@ -128,8 +135,8 @@ class CustomCallArgEncodingSet { class CustomCallRetEncoding { public: struct Encoded { - mlir::Value type_id; // !llvm.ptr - mlir::LLVM::AllocaOp value; // !llvm.alloca 1 x ResultType + mlir::LLVM::GlobalOp type_id; // llvm.mlir.global external $type_name : i64 + mlir::LLVM::AllocaOp value; // !llvm.alloca 1 x ResultType }; virtual ~CustomCallRetEncoding() = default; @@ -137,7 +144,7 @@ class CustomCallRetEncoding { virtual mlir::LogicalResult Match(mlir::Type type, mlir::Type converted) const = 0; - virtual mlir::FailureOr Encode(Globals &g, + virtual mlir::FailureOr Encode(Globals &g, Allocas &a, mlir::ImplicitLocOpBuilder &b, mlir::Type type, mlir::Type converted) const = 0; @@ -154,7 +161,8 @@ class CustomCallRetEncodingSet { // Finds matching result encoding and tries to encode the values. Returns // failure if didn't match values to any of the result encodings. - mlir::FailureOr Encode(Globals &g, mlir::ImplicitLocOpBuilder &b, + mlir::FailureOr Encode(Globals &g, Allocas &a, + mlir::ImplicitLocOpBuilder &b, mlir::Type type, mlir::Type converted) const; // Convert the encoded value in alloca back to a value with the converted @@ -193,9 +201,9 @@ struct CustomCallAttrEncoding { static constexpr char kAttrValue[] = "__rt_attr_value"; struct Encoded { - mlir::Value name; // !llvm.ptr - mlir::Value type_id; // !llvm.ptr - mlir::Value value; // !llvm.ptr + mlir::LLVM::GlobalOp name; // llvm.mlir.global + mlir::LLVM::GlobalOp type_id; // llvm.mlir.global external $type_name : i64 + mlir::LLVM::GlobalOp value; // llvm.mlir.global }; virtual ~CustomCallAttrEncoding() = default; @@ -243,26 +251,28 @@ class CustomCallAttrEncodingSet { }; //===----------------------------------------------------------------------===// -// A set of helper functions for packing primitive attributes. +// A set of helper functions for packing encoding attributes. //===----------------------------------------------------------------------===// -// Packs TypeID as `i64` constant value and casts it to the `!llvm.ptr`, -// because type id internally is implemented as an opaque pointer. -mlir::Value PackTypeId(Globals &g, mlir::ImplicitLocOpBuilder &b, - mlir::TypeID type_id); - -// Packs string as a module global null-terminated string constant. We reuse -// the encoding scheme for arrays to store sting with its size, to avoid -// computing the length of the null-terminated string at run tine. 
-// -// Returns `!llvm.ptr>`. -mlir::Value PackString(Globals &g, mlir::ImplicitLocOpBuilder &b, - std::string_view strref, std::string_view symbol_base); - -// Packs scalar attribute as a global constant. Returns `!llvm.ptr`. -mlir::Value PackScalarAttribute(Globals &g, mlir::ImplicitLocOpBuilder &b, - mlir::Attribute value, - std::string_view symbol_base); +// Encodes type id as an external LLVM global of type `i64`. The global name is +// defined by the type id name registry. Internally type id implemented as an +// opaque pointer (void*), and type equality check at run time is just a pointer +// comparison. All type id symbols at run time must be resolved to the type id +// instances defined in the current process. +mlir::LLVM::GlobalOp EncodeTypeId(Globals &g, mlir::ImplicitLocOpBuilder &b, + TypeID type_id); + +// Encodes string as a module global null-terminated string constant + size. We +// reuse the encoding scheme for arrays to store sting with its size, to avoid +// computing the length of the null-terminated string at run time. +mlir::LLVM::GlobalOp EncodeString(Globals &g, mlir::ImplicitLocOpBuilder &b, + std::string_view strref, + std::string_view symbol_base); + +// Encodes scalar attribute as a global constant. +mlir::LLVM::GlobalOp EncodeScalar(Globals &g, mlir::ImplicitLocOpBuilder &b, + mlir::Attribute value, + std::string_view symbol_base); //===----------------------------------------------------------------------===// // A helper class to create global constants in the module. @@ -286,7 +296,7 @@ class Globals { // Creates a global external variable for the type id. mlir::LLVM::GlobalOp GetOrCreate(mlir::ImplicitLocOpBuilder &b, - mlir::TypeID type_id); + TypeID type_id); // Creates a global null-terminated string constant. mlir::LLVM::GlobalOp GetOrCreate(mlir::ImplicitLocOpBuilder &b, @@ -338,6 +348,55 @@ class Globals { TypeIDNameRegistry type_id_names_; }; +//===----------------------------------------------------------------------===// +// A helper class to create alloca operations for encoded arguments. +//===----------------------------------------------------------------------===// + +class EncodingAllocas; + +// We reuse allocas for encoding custom call arguments and results, because we +// potentially can have thousands of custom calls, and we do not want to +// accidentally blow up the stack size. It means that we might encode the same +// argument multiple times, but encoding is cheap (few store operations), and +// LLVM can potentially optimize them away. +// +// TODO(ezhulenev): Use `llvm.invariant.start` and `llvm.invariant.end` to mark +// encoded arguments allocas. +class Allocas { + public: + ~Allocas(); + + mlir::LLVM::AllocaOp GetOrCreate(mlir::ImplicitLocOpBuilder &b, + mlir::Type type); + + private: + friend class EncodingAllocas; + + struct TypedAllocas { + size_t offset = 0; + llvm::SmallVector allocas; + }; + + explicit Allocas(mlir::Block *block, + llvm::DenseMap *allocas); + + mlir::Block *block_; + llvm::DenseMap *allocas_; +}; + +// Mapping from basic block to allocas. +class EncodingAllocas { + public: + Allocas GetForOperation(mlir::Operation *op); + + private: + friend class Allocas; + + llvm::DenseMap> + allocas_; +}; + //===----------------------------------------------------------------------===// // Custom call attributes encoding. //===----------------------------------------------------------------------===// @@ -357,7 +416,7 @@ class Globals { // 2. 
Custom call attributes, where the attributes sorted lexicographically by // name, to be able to efficiently decode named attributes. // -mlir::FailureOr EncodeAttributes( +mlir::FailureOr EncodeAttributes( mlir::SymbolTable &sym_table, Globals &g, mlir::ImplicitLocOpBuilder &b, const CustomCallAttrEncodingSet &encoding, std::string_view symbol_base, llvm::ArrayRef attrs); @@ -434,6 +493,15 @@ struct UnitAttrEncoding : public CustomCallAttrEncoding { mlir::Attribute) const final; }; +struct DictionaryAttrEncoding : public CustomCallAttrEncoding { + mlir::LogicalResult Match(mlir::SymbolTable &, std::string_view, + mlir::Attribute) const final; + mlir::FailureOr Encode(mlir::SymbolTable &, Globals &, + mlir::ImplicitLocOpBuilder &, + std::string_view, + mlir::Attribute) const final; +}; + // Custom call attribute encoding that encodes enums using their underlying // scalar type. Type id is based on the enum type passed to the runtime. // @@ -471,13 +539,13 @@ struct EnumAttrEncoding : public CustomCallAttrEncoding { using T = std::underlying_type_t; T underlying_value = static_cast(run_time_enum); - mlir::TypeID type_id = mlir::TypeID::get>(); + TypeID type_id = TypeID::get>(); mlir::Attribute underlying_attr = AsAttr(b, underlying_value); Encoded encoded; - encoded.name = PackString(g, b, name, kAttrName); - encoded.type_id = PackTypeId(g, b, type_id); - encoded.value = PackScalarAttribute(g, b, underlying_attr, kAttrValue); + encoded.name = EncodeString(g, b, name, kAttrName); + encoded.type_id = EncodeTypeId(g, b, type_id); + encoded.value = EncodeScalar(g, b, underlying_attr, kAttrValue); return encoded; } @@ -557,15 +625,15 @@ struct AggregateAttrEncoding : public CustomCallAttrEncoding { attrs.emplace_back(bind(attr.cast(), b)); // Encode extracted attributes as an aggregate. 
- auto type_id = mlir::TypeID::get>(); + auto type_id = TypeID::get>(); auto sym = "__rt_aggregate_" + AttrType::getMnemonic(); auto aggregate = EncodeAttributes(sym_table, g, b, encoding, sym.str(), attrs); if (mlir::failed(aggregate)) return mlir::failure(); Encoded encoded; - encoded.name = PackString(g, b, name, kAttrName); - encoded.type_id = PackTypeId(g, b, type_id); + encoded.name = EncodeString(g, b, name, kAttrName); + encoded.type_id = EncodeTypeId(g, b, type_id); encoded.value = *aggregate; return encoded; } @@ -582,8 +650,9 @@ struct AggregateAttrEncoding : public CustomCallAttrEncoding { class ScalarArgEncoding : public CustomCallArgEncoding { public: mlir::LogicalResult Match(mlir::Value, mlir::Value) const final; - mlir::FailureOr Encode(Globals &g, mlir::ImplicitLocOpBuilder &b, - mlir::Value, mlir::Value) const final; + mlir::FailureOr Encode(Globals &g, Allocas &a, + mlir::ImplicitLocOpBuilder &b, mlir::Value, + mlir::Value) const final; }; // Encodes custom call arguments passed as an opaque LLVM pointer (!llvm.ptr) @@ -595,8 +664,9 @@ class OpaqueArgEncoding : public CustomCallArgEncoding { OpaqueArgEncoding(std::function match, TypeID type_id); mlir::LogicalResult Match(mlir::Value, mlir::Value) const final; - mlir::FailureOr Encode(Globals &g, mlir::ImplicitLocOpBuilder &b, - mlir::Value, mlir::Value) const final; + mlir::FailureOr Encode(Globals &g, Allocas &a, + mlir::ImplicitLocOpBuilder &b, mlir::Value, + mlir::Value) const final; template static auto Match() { @@ -612,8 +682,9 @@ class OpaqueArgEncoding : public CustomCallArgEncoding { class MemrefArgEncoding : public CustomCallArgEncoding { public: mlir::LogicalResult Match(mlir::Value, mlir::Value) const final; - mlir::FailureOr Encode(Globals &g, mlir::ImplicitLocOpBuilder &b, - mlir::Value, mlir::Value) const final; + mlir::FailureOr Encode(Globals &g, Allocas &a, + mlir::ImplicitLocOpBuilder &b, mlir::Value, + mlir::Value) const final; }; //===----------------------------------------------------------------------===// @@ -624,8 +695,9 @@ class MemrefArgEncoding : public CustomCallArgEncoding { class ScalarRetEncoding : public CustomCallRetEncoding { public: mlir::LogicalResult Match(mlir::Type, mlir::Type) const final; - mlir::FailureOr Encode(Globals &g, mlir::ImplicitLocOpBuilder &b, - mlir::Type, mlir::Type) const final; + mlir::FailureOr Encode(Globals &g, Allocas &a, + mlir::ImplicitLocOpBuilder &b, mlir::Type, + mlir::Type) const final; mlir::FailureOr Decode(mlir::ImplicitLocOpBuilder &b, mlir::Type, mlir::Type, mlir::LLVM::AllocaOp) const final; @@ -640,8 +712,9 @@ class OpaqueRetEncoding : public CustomCallRetEncoding { OpaqueRetEncoding(std::function match, TypeID type_id); mlir::LogicalResult Match(mlir::Type, mlir::Type) const final; - mlir::FailureOr Encode(Globals &g, mlir::ImplicitLocOpBuilder &b, - mlir::Type, mlir::Type) const final; + mlir::FailureOr Encode(Globals &g, Allocas &a, + mlir::ImplicitLocOpBuilder &b, mlir::Type, + mlir::Type) const final; mlir::FailureOr Decode(mlir::ImplicitLocOpBuilder &b, mlir::Type, mlir::Type, mlir::LLVM::AllocaOp) const final; @@ -660,8 +733,20 @@ class OpaqueRetEncoding : public CustomCallRetEncoding { class MemrefRetEncoding : public CustomCallRetEncoding { public: mlir::LogicalResult Match(mlir::Type, mlir::Type) const final; - mlir::FailureOr Encode(Globals &g, mlir::ImplicitLocOpBuilder &b, - mlir::Type, mlir::Type) const final; + mlir::FailureOr Encode(Globals &g, Allocas &a, + mlir::ImplicitLocOpBuilder &b, mlir::Type, + mlir::Type) const final; + 
mlir::FailureOr Decode(mlir::ImplicitLocOpBuilder &b, mlir::Type, + mlir::Type, + mlir::LLVM::AllocaOp) const final; +}; + +class AsyncValueRetEncoding : public CustomCallRetEncoding { + public: + mlir::LogicalResult Match(mlir::Type, mlir::Type) const final; + mlir::FailureOr Encode(Globals &g, Allocas &a, + mlir::ImplicitLocOpBuilder &b, mlir::Type, + mlir::Type) const final; mlir::FailureOr Decode(mlir::ImplicitLocOpBuilder &b, mlir::Type, mlir::Type, mlir::LLVM::AllocaOp) const final; diff --git a/tensorflow/compiler/xla/mlir/runtime/transforms/jit_compiler.cc b/tensorflow/compiler/xla/mlir/runtime/transforms/jit_compiler.cc index ed722569027..0e45d184c7b 100644 --- a/tensorflow/compiler/xla/mlir/runtime/transforms/jit_compiler.cc +++ b/tensorflow/compiler/xla/mlir/runtime/transforms/jit_compiler.cc @@ -30,6 +30,7 @@ limitations under the License. #include "llvm/Support/TargetSelect.h" #include "mlir/ExecutionEngine/OptUtils.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Target/LLVMIR/Export.h" // from @llvm-project @@ -51,7 +52,7 @@ static bool DebugJitCompiler() { #if defined(DEBUG_XLA_RUNTIME_COMPILER) return true; #endif - return false; + return VLOG_IS_ON(5); } static bool EnablePassTiming() { @@ -94,12 +95,13 @@ static LogicalResult RunPipeline( ModuleOp module, const std::function& create_pipeline) { if (!create_pipeline) return success(); - mlir::PassManager pm(module.getContext()); - SetupPassDebugging(module.getContext(), pm); - // Instrument the pass manager to capture timing information. DefaultTimingManager tm; TimingScope timing; + + mlir::PassManager pm(module.getContext()); + SetupPassDebugging(module.getContext(), pm); + if (EnablePassTiming()) { tm.setEnabled(true); timing = tm.getRootScope(); @@ -125,19 +127,17 @@ static LogicalResult RunSpecializationPipeline( //===----------------------------------------------------------------------===// -// Creates a new MLIR Context and registers all the dialects that are expected +// Configures MLIR Context and registers all the dialects that are expected // in the compiled module. -static std::unique_ptr CreateMlirContext( - const JitCompiler::Options& opts) { - DialectRegistry dialects; - - // Call user-provided callback to register all required dialects. - if (opts.register_dialects) opts.register_dialects(dialects); - - auto threading = MLIRContext::Threading::DISABLED; - auto ctx = std::make_unique(*dialects, threading); - ctx->loadAllAvailableDialects(); - return ctx; +static void ConfigureMlirContext(MLIRContext* context, + const JitCompiler::Options& opts) { + if (opts.register_dialects) { + // Call user-provided callback to register all required dialects. 
+ DialectRegistry dialects; + opts.register_dialects(dialects); + context->appendDialectRegistry(*dialects); + context->loadAllAvailableDialects(); + } } //===----------------------------------------------------------------------===// @@ -147,33 +147,36 @@ static std::unique_ptr CreateMlirContext( JitCompiler::JitCompiler(JitCompiler::Options opts, std::string_view mlir_module) : opts_(std::move(opts)), - context_(CreateMlirContext(opts_)), + owned_context_( + std::make_unique(MLIRContext::Threading::DISABLED)), + context_(owned_context_.get()), diagnostic_os_(diagnostic_), - handler_(source_mgr_, context_.get(), diagnostic_os_), + handler_(source_mgr_, context_, diagnostic_os_), specialized_(false) { + ConfigureMlirContext(context_, opts_); source_mgr_.AddNewSourceBuffer( llvm::MemoryBuffer::getMemBuffer(mlir_module, "xla.program"), llvm::SMLoc()); - module_ = parseSourceFile(source_mgr_, context_.get()); + module_ = parseSourceFile(source_mgr_, context_); } -/*static*/ absl::StatusOr> -JitCompiler::Instantiate(JitCompiler::Options opts, - std::string_view mlir_module, - absl::Span exported) { - std::unique_ptr compiler( - new JitCompiler(std::move(opts), mlir_module)); - - // Check that mlir source was parsed into module operation. - if (!compiler->module_) - return compiler->Error("failed to parse the mlir source"); +JitCompiler::JitCompiler(JitCompiler::Options opts, mlir::ModuleOp mlir_module) + : opts_(std::move(opts)), + context_(mlir_module.getContext()), + diagnostic_os_(diagnostic_), + handler_(source_mgr_, context_, diagnostic_os_), + module_(mlir_module), + specialized_(false) { + ConfigureMlirContext(context_, opts_); +} - ModuleOp module = *compiler->module_; - SymbolTable sym_table(module); +absl::Status JitCompiler::ComputeOrdinalsForExportedFunctions( + absl::Span exported) { + SymbolTable sym_table(*module_); // Add `rt.export` operations for all explicitly exported functions. for (auto& indexed : llvm::enumerate(exported)) { - if (auto func = sym_table.lookup(indexed.value())) { + if (auto func = sym_table.lookup(indexed.value())) { OpBuilder(func).create(func.getLoc(), func, indexed.index()); continue; } @@ -182,19 +185,50 @@ JitCompiler::Instantiate(JitCompiler::Options opts, // Assign unique ordinals to all exported functions, including functions that // were already exported with `rt.export` operations in the input IR. - mlir::PassManager pm(module.getContext()); + mlir::PassManager pm(module_->getContext()); pm.addPass(CreateOrdinalAssignmentPass()); - if (failed(pm.run(module))) - return compiler->Error("failed to run ordinal assignment pass"); + if (failed(pm.run(*module_))) + return Error("failed to run ordinal assignment pass"); // Resolve all functions exported from the module indexed by ordinal. - for (ExportOp op : module.getOps()) { + for (ExportOp op : module_->getOps()) { unsigned ordinal = *op.ordinal(); - if (ordinal >= compiler->exported_.size()) - compiler->exported_.resize(ordinal + 1); - compiler->exported_[ordinal] = op.exported(sym_table); + if (ordinal >= exported_.size()) exported_.resize(ordinal + 1); + exported_[ordinal] = op.exported(sym_table); } + return absl::OkStatus(); +} + +/*static*/ absl::StatusOr> +JitCompiler::Instantiate(JitCompiler::Options opts, + std::string_view mlir_module, + absl::Span exported) { + std::unique_ptr compiler( + new JitCompiler(std::move(opts), mlir_module)); + + // Check that mlir source was parsed into module operation. 
+ if (!compiler->module_) + return compiler->Error("failed to parse the mlir source"); + + auto status = compiler->ComputeOrdinalsForExportedFunctions(exported); + if (!status.ok()) return status; + + // Initialize LLVM compiler internals. + InitializeLlvmCompiler(); + + return {std::move(compiler)}; +} + +/*static*/ absl::StatusOr> +JitCompiler::Instantiate(JitCompiler::Options opts, ModuleOp mlir_module, + absl::Span exported) { + std::unique_ptr compiler( + new JitCompiler(std::move(opts), mlir_module)); + + auto status = compiler->ComputeOrdinalsForExportedFunctions(exported); + if (!status.ok()) return status; + // Initialize LLVM compiler internals. InitializeLlvmCompiler(); @@ -218,15 +252,17 @@ JitCompiler::Instantiate(JitCompiler::Options opts, std::vector exported; // names of exported functions for (auto& indexed : llvm::enumerate(compiler->exported())) { - func::FuncOp func = indexed.value(); + auto func = indexed.value(); std::string_view name = exported.emplace_back(func.getName()); // Get the signature of the exported function. - auto signature = opts.type_converter.Convert(func.getFunctionType()); + auto signature = opts.type_converter.Convert( + llvm::cast(func.getFunctionType())); if (!signature.ok()) return signature.status(); // Calling convention conversion can fail if some types are not supported. - auto runtime_type = opts.calling_convention(func.getFunctionType()); + auto runtime_type = opts.calling_convention( + llvm::cast(func.getFunctionType())); if (!runtime_type) return compiler->Error(StrFormat( "calling convention failed to convert function type for %s", name)); @@ -247,14 +283,11 @@ JitCompiler::Instantiate(JitCompiler::Options opts, // Add function with an unresolved function pointer; it will be updated once // we compile the input module to the native executable. - Executable::Function function{std::string(name), - /*fptr=*/nullptr, - std::move(*signature), - std::move(*runtime_signature), - std::move(*arguments_memory_layout), - std::move(*results_memory_layout)}; - - functions.push_back(std::move(function)); + functions.push_back(Executable::Function( + name, + /*fptr=*/nullptr, std::move(*signature), std::move(*runtime_signature), + std::move(*arguments_memory_layout), + std::move(*results_memory_layout))); } // Run the compilation pipeline to lower the module to LLVM dialect. @@ -335,7 +368,7 @@ absl::Status JitCompiler::Specialize(unsigned ordinal, ArgumentsRef arguments, assert(!specialized_ && "can specialize executable only once"); specialized_ = true; - func::FuncOp func = exported(ordinal); + auto func = exported(ordinal); // Update function signature and sink constant arguments into the body. if (auto specialized = SpecializeFunction(func, arguments, symbolic_shapes, diff --git a/tensorflow/compiler/xla/mlir/runtime/transforms/jit_compiler.h b/tensorflow/compiler/xla/mlir/runtime/transforms/jit_compiler.h index b15a4031ff7..0c18b831563 100644 --- a/tensorflow/compiler/xla/mlir/runtime/transforms/jit_compiler.h +++ b/tensorflow/compiler/xla/mlir/runtime/transforms/jit_compiler.h @@ -28,7 +28,6 @@ limitations under the License. 
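// Illustrative usage sketch (not part of this patch) of the `Instantiate`
// overload added above in jit_compiler.cc: a client hands an already-parsed
// module to the compiler instead of MLIR source text. The MLIRContext remains
// owned by the caller and must outlive the compiler; the exported function
// name and the wrapper name below are hypothetical.
absl::StatusOr<std::unique_ptr<JitCompiler>> CompileParsedModule(
    JitCompiler::Options opts, mlir::ModuleOp module) {
  return JitCompiler::Instantiate(std::move(opts), module,
                                  /*exported=*/{"main"});
}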
#include "llvm/Support/SourceMgr.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project -#include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/xla/mlir/runtime/transforms/calling_convention.h" #include "tensorflow/compiler/xla/mlir/runtime/transforms/specialization.h" #include "tensorflow/compiler/xla/mlir/runtime/transforms/type_converter.h" @@ -113,6 +112,11 @@ class JitCompiler { Options opts, std::string_view mlir_module, absl::Span exported); + // Instantiates compiler from the mlir module. + static absl::StatusOr> Instantiate( + Options opts, mlir::ModuleOp mlir_module, + absl::Span exported); + // Makes an executable from an instance of the JitCompiler. This is the end of // life for the `JitCompiler`, it effectively converts the MLIR module // to the executable (function pointer) using LLVM JIT code generation. @@ -154,15 +158,21 @@ class JitCompiler { size_t num_exported() const { return exported_.size(); } - absl::Span exported() const { return exported_; } + absl::Span exported() const { + return exported_; + } - mlir::func::FuncOp exported(unsigned ordinal) const { + mlir::FunctionOpInterface exported(unsigned ordinal) const { assert(exported_[ordinal] && "failed to resolve exported function"); return exported_[ordinal]; } private: JitCompiler(Options opts, std::string_view mlir_module); + JitCompiler(Options opts, mlir::ModuleOp mlir_module); + + absl::Status ComputeOrdinalsForExportedFunctions( + absl::Span exported); absl::Status Error(std::string_view error) { // TODO(ezhulenev): Pass diagnstic as a status payload. @@ -170,7 +180,8 @@ class JitCompiler { } Options opts_; - std::unique_ptr context_; + std::unique_ptr owned_context_; // set if context is owned + mlir::MLIRContext* context_; std::string diagnostic_; llvm::raw_string_ostream diagnostic_os_; @@ -178,8 +189,8 @@ class JitCompiler { llvm::SourceMgr source_mgr_; mlir::SourceMgrDiagnosticHandler handler_; - mlir::OwningOpRef module_; // can be null if failed to parse - std::vector exported_; // can be empty if failed to parse + mlir::OwningOpRef module_; // null if failed to parse + std::vector exported_; // empty if failed to parse bool specialized_; }; diff --git a/tensorflow/compiler/xla/mlir/runtime/transforms/rt_to_llvm.cc b/tensorflow/compiler/xla/mlir/runtime/transforms/rt_to_llvm.cc index a1e595c4cbc..d5fd07f68b9 100644 --- a/tensorflow/compiler/xla/mlir/runtime/transforms/rt_to_llvm.cc +++ b/tensorflow/compiler/xla/mlir/runtime/transforms/rt_to_llvm.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include #include "llvm/ADT/None.h" @@ -33,6 +34,7 @@ limitations under the License. #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/LLVMTypes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project @@ -104,10 +106,12 @@ struct RuntimeAPI { }; // Adds function declaration if it doesn't already exist. 
-static void AddDeclaration(ModuleOp module, std::string_view name, - FunctionType type) { +static void AddDeclaration(SymbolTable &sym_table, ModuleOp module, + std::string_view name, FunctionType type) { + assert(sym_table.getOp() == module && "incorrect symbol table"); + if (sym_table.lookup(name)) return; + auto b = ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody()); - if (module.lookupSymbol(name)) return; MLIRContext *ctx = module.getContext(); func::FuncOp func = b.create(name, type); @@ -116,12 +120,14 @@ static void AddDeclaration(ModuleOp module, std::string_view name, // TODO(ezhulenev): Add per-argument nocapture attributes? func->setAttr("passthrough", ArrayAttr::get(ctx, {StringAttr::get(ctx, "nounwind")})); + + sym_table.insert(func); } // Adds Runtime C API declarations to the module. -static void AddRuntimeApiDeclarations(ModuleOp module) { +static void AddRuntimeApiDeclarations(SymbolTable &sym_table, ModuleOp module) { auto add = [&](std::string_view name, FunctionType type) { - AddDeclaration(module, name, type); + AddDeclaration(sym_table, module, name, type); }; MLIRContext *ctx = module.getContext(); @@ -215,82 +221,136 @@ class IsOkOpLowering : public OpConversionPattern { // Convert rt.custom_call to the corresponding runtime API call. //===----------------------------------------------------------------------===// -static FailureOr EncodeArguments( - CallOp op, CustomCallArgEncodingSet &encodings, Globals &g, +static Value AsPtr(ImplicitLocOpBuilder &b, + std::variant &v) { + if (auto *alloca = std::get_if(&v)) + return alloca->getResult(); + return Globals::AddrOf(b, std::get(v)); +} + +static LLVM::GlobalOp EncodeEmptyArgsRets(Globals &g, ImplicitLocOpBuilder &b, + std::string_view symbol_base) { + // Empty args/rets is just an array with a single pointer to size (zero). + Type ptr = LLVM::LLVMPointerType::get(b.getContext()); + Type type = LLVM::LLVMArrayType::get(ptr, 1); + + auto init = [&](ImplicitLocOpBuilder &ib, Attribute) { + LLVM::GlobalOp zero = + EncodeScalar(g, b, b.getI64IntegerAttr(0), "__rt_zero"); + + Value arr = b.create(type); + arr = b.create(arr, Globals::AddrOf(b, zero), 0); + b.create(arr); + }; + + return g.GetOrCreate(b, b.getArrayAttr({}), type, symbol_base, init); +} + +static LLVM::GlobalOp EncodeTypeTable(Globals &g, ImplicitLocOpBuilder &b, + ArrayRef type_ids, + std::string_view symbol_base) { + // We store type table as `!llvm.array`. + Type ptr = LLVM::LLVMPointerType::get(b.getContext()); + Type type = LLVM::LLVMArrayType::get(ptr, type_ids.size()); + + // Global initializer that encodes type ids as pointers. + auto init = [&](ImplicitLocOpBuilder &ib, Attribute) -> LogicalResult { + Value arr = b.create(type); + for (auto &pair : llvm::enumerate(type_ids)) { + arr = b.create(arr, Globals::AddrOf(b, pair.value()), + pair.index()); + } + b.create(arr); + return success(); + }; + + // Put all type ids into an array attribute, so we can use it as a globals + // cache key, so we do not encode the same type table multiple times. 
+ llvm::SmallVector type_id_syms; + for (auto type_id : type_ids) type_id_syms.push_back(type_id.getSymName()); + auto arr_attr = b.getStrArrayAttr(type_id_syms); + + return g.GetOrCreate(b, arr_attr, type, symbol_base, init); +} + +struct EncodedArguments { + std::variant encoded; // `args` argument + SmallVector> values; +}; + +static FailureOr EncodeArguments( + CallOp op, CustomCallArgEncodingSet &encodings, Globals &g, Allocas &a, DenseMap &encoded_args, ImplicitLocOpBuilder &b, ValueRange operands, ValueRange converted) { llvm::SmallVector encoded; - // Encode all arguments as a set of pointers (skip the execution context). - for (auto tuple : llvm::drop_begin(llvm::zip(operands, converted))) { - // Check if the value was already encoded. - auto it = encoded_args.find(std::get<0>(tuple)); - if (it != encoded_args.end()) { - encoded.push_back(it->second); - continue; - } + // Encode empty arguments as a global array (skip the status type). + if (operands.drop_front().empty()) { + return EncodedArguments{EncodeEmptyArgsRets(g, b, "__rt_empty_args"), {}}; + } - // Otherwise encode it right after the converted value definition. - OpBuilder::InsertionGuard guard(b); - if (auto *defining_op = std::get<1>(tuple).getDefiningOp()) { - b.setInsertionPointAfter(defining_op); - } else { - b.setInsertionPointToStart(std::get<1>(tuple).getParentBlock()); - } + EncodedArguments arguments; + // Encode all arguments as a set of pointers (skip the execution context). + for (auto tuple : llvm::drop_begin(llvm::zip(operands, converted))) { auto encoded_arg = - encodings.Encode(g, b, std::get<0>(tuple), std::get<1>(tuple)); + encodings.Encode(g, a, b, std::get<0>(tuple), std::get<1>(tuple)); if (failed(encoded_arg)) return failure(); encoded.push_back(*encoded_arg); encoded_args.try_emplace(std::get<0>(tuple), *encoded_arg); } - // We store encoded arguments as `!llvm.array x len>`. + // We store encoded arguments as `!llvm.array`. + size_t len = encoded.empty() ? 1 : 2 + encoded.size(); Type ptr = LLVM::LLVMPointerType::get(b.getContext()); - Type type = LLVM::LLVMArrayType::get(ptr, 1 + encoded.size() * 2); + Type type = LLVM::LLVMArrayType::get(ptr, len); - // Prepare an array for encoding arguments. + // Prepare an array for encoded arguments. Value arr = b.create(type); auto insert_value = [&](Value value, int64_t offset) { arr = b.create(arr, value, offset); }; // Insert the number of encoded arguments. - Attribute num_args = b.getI64IntegerAttr(encoded.size()); - insert_value(PackScalarAttribute(g, b, num_args, "__rt_num_args"), 0); - - // Store encoded arguments into the allocated storage. + LLVM::GlobalOp num_args = + EncodeScalar(g, b, b.getI64IntegerAttr(encoded.size()), "__rt_num_args"); + insert_value(Globals::AddrOf(b, num_args), 0); + + // Package arguments type ids into a type table global value. + llvm::SmallVector type_ids; + for (auto &arg : encoded) type_ids.push_back(arg.type_id); + LLVM::GlobalOp type_table = + EncodeTypeTable(g, b, type_ids, "__rt_args_type_table"); + if (!encoded.empty()) insert_value(Globals::AddrOf(b, type_table), 1); + + // Store pointer to encoded arguments into the allocated storage. 
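// For reference, a sketch of what the pointer array assembled here looks like
// to a custom call handler at run time (assumed decoder-side view; struct and
// function names are illustrative). With N arguments the array holds 2 + N
// pointers; with zero arguments it degenerates to the single-pointer
// `__rt_empty_args` global defined above.
struct SketchDecodedArgs {
  int64_t size;           // *args[0]: number of encoded arguments
  void* const* type_ids;  //  args[1]: type id table, one pointer per argument
  void* const* values;    //  args[2 .. 2+N): pointers to the encoded values
};

static SketchDecodedArgs SketchDecodeArgs(void** args) {
  SketchDecodedArgs decoded;
  decoded.size = *reinterpret_cast<const int64_t*>(args[0]);
  decoded.type_ids =
      decoded.size ? reinterpret_cast<void* const*>(args[1]) : nullptr;
  decoded.values = decoded.size ? args + 2 : nullptr;
  return decoded;
}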
for (auto &pair : llvm::enumerate(encoded)) { CustomCallArgEncoding::Encoded encoded = pair.value(); - int64_t offset = 1 + pair.index() * 2; - - insert_value(encoded.type_id, offset + 0); - insert_value(encoded.value, offset + 1); + int64_t offset = 2 + pair.index(); + insert_value(AsPtr(b, encoded.value), offset); + arguments.values.push_back(encoded.value); } - // Always create an `alloca` in the parent function entry block. - // See: https://llvm.org/docs/Frontend/PerformanceTips.html#use-of-allocas - Value mem = [&]() -> Value { - Block &block = op->getParentOfType().getBody().front(); - OpBuilder::InsertionGuard guard(b); - b.setInsertionPointToStart(&block); - Value c1 = b.create(b.getI32IntegerAttr(1)); - return b.create(ptr, type, c1, 0); - }(); + // Get allocation for packed arguments pointers. + LLVM::AllocaOp alloca = a.GetOrCreate(b, type); + + // Start the lifetime of the encoded arguments pointers. + b.create(b.getI64IntegerAttr(-1), alloca); - // Store constructed arguments array on the stack and return a pointer to it. - b.create(arr, mem); + // Store constructed arguments pointers array into the alloca. + b.create(arr, alloca.getRes()); - // Return a pointer to the encoded arguments. - return mem; + // Alloca that encodes the custom call arguments. + arguments.encoded = alloca; + + return arguments; } // Encodes attributes into the global constant (array of pointers to the // attributes data, which are also stored as global constants). -static FailureOr EncodeAttributes(CustomCallAttrEncodingSet &encodings, - SymbolTable &sym_table, Globals &g, - ImplicitLocOpBuilder &b, - ArrayRef attrs) { +static FailureOr EncodeAttributes( + CustomCallAttrEncodingSet &encodings, SymbolTable &sym_table, Globals &g, + ImplicitLocOpBuilder &b, ArrayRef attrs) { // Forward attributes that are not part of the custom call operation itself. auto forward_attr = [](NamedAttribute attr) -> bool { return attr.getName() != "callee" && attr.getName() != "dynamic"; @@ -310,30 +370,34 @@ static FailureOr EncodeAttributes(CustomCallAttrEncodingSet &encodings, } struct EncodedResults { - Value result_array_ptr; // passed as 'rets' argument to custom call - SmallVector allocas; // storage for values of results + std::variant encoded; // `rets` argument + SmallVector allocas; // encoded returns }; static FailureOr EncodeResults( - CallOp op, CustomCallRetEncodingSet &encodings, Globals &g, + CallOp op, CustomCallRetEncodingSet &encodings, Globals &g, Allocas &a, ImplicitLocOpBuilder &b, TypeRange ret_types, TypeRange converted_types) { llvm::SmallVector encoded; + + // Encode empty returns as a global array (skip the status type). + if (ret_types.drop_front().empty()) { + return EncodedResults{EncodeEmptyArgsRets(g, b, "__rt_empty_rets"), {}}; + } + EncodedResults results; // Encode all returns as a set of pointers (skip the status type). for (auto tuple : llvm::drop_begin(llvm::zip(ret_types, converted_types))) { - Block &block = op->getParentOfType().getBody().front(); - OpBuilder::InsertionGuard guard(b); - b.setInsertionPointToStart(&block); auto encoded_ret = - encodings.Encode(g, b, std::get<0>(tuple), std::get<1>(tuple)); + encodings.Encode(g, a, b, std::get<0>(tuple), std::get<1>(tuple)); if (failed(encoded_ret)) return failure(); encoded.push_back(*encoded_ret); } - // We store encoded results as `!llvm.array x len>`. + // We store encoded results as `!llvm.array`. + size_t len = encoded.empty() ? 
1 : 2 + encoded.size(); Type ptr = LLVM::LLVMPointerType::get(b.getContext()); - Type type = LLVM::LLVMArrayType::get(ptr, 1 + encoded.size() * 2); + Type type = LLVM::LLVMArrayType::get(ptr, len); // Prepare an array for encoding results. Value arr = b.create(type); @@ -342,35 +406,37 @@ static FailureOr EncodeResults( }; // Insert the number of encoded results. - Attribute num_rets = b.getI64IntegerAttr(encoded.size()); - insert_value(PackScalarAttribute(g, b, num_rets, "__rt_num_rets"), 0); + LLVM::GlobalOp num_rets = + EncodeScalar(g, b, b.getI64IntegerAttr(encoded.size()), "__rt_num_rets"); + insert_value(Globals::AddrOf(b, num_rets), 0); + + // Package results type ids into a type table global value. + llvm::SmallVector type_ids; + for (auto &arg : encoded) type_ids.push_back(arg.type_id); + LLVM::GlobalOp type_table = + EncodeTypeTable(g, b, type_ids, "__rt_rets_type_table"); + if (!encoded.empty()) insert_value(Globals::AddrOf(b, type_table), 1); // Store encoded results into the allocated storage. for (auto &pair : llvm::enumerate(encoded)) { CustomCallRetEncoding::Encoded encoded_pair = pair.value(); - int64_t offset = 1 + pair.index() * 2; - - insert_value(encoded_pair.type_id, offset + 0); - insert_value(encoded_pair.value, offset + 1); - + int64_t offset = 2 + pair.index(); + insert_value(encoded_pair.value, offset); results.allocas.push_back(encoded_pair.value); } - // Always create an `alloca` in the parent function entry block. - // See: https://llvm.org/docs/Frontend/PerformanceTips.html#use-of-allocas - Value mem = [&]() -> Value { - Block &block = op->getParentOfType().getBody().front(); - OpBuilder::InsertionGuard guard(b); - b.setInsertionPointToStart(&block); - Value c1 = b.create(b.getI32IntegerAttr(1)); - return b.create(ptr, type, c1, 0); - }(); + // Get allocation for packed results pointers. + LLVM::AllocaOp alloca = a.GetOrCreate(b, type); + + // Start the lifetime of the encoded results pointers allocation. + b.create(b.getI64IntegerAttr(-1), alloca); - // Store constructed results array on the stack - b.create(arr, mem); + // Store constructed results pointers array on the stack + b.create(arr, alloca); + + // Alloca that encodes the custom call returns. + results.encoded = alloca; - // Return a pointer to the encoded results. - results.result_array_ptr = mem; return results; } @@ -398,6 +464,7 @@ class CallOpLowering : public OpConversionPattern { CallOpLowering(TypeConverter &converter, MLIRContext *ctx, SymbolTable &sym_table, Globals &globals, + EncodingAllocas &allocas, CustomCallArgEncodingSet &arg_encoding, CustomCallAttrEncodingSet &attr_encoding, CustomCallRetEncodingSet &ret_encoding, @@ -405,6 +472,7 @@ class CallOpLowering : public OpConversionPattern { : OpConversionPattern(converter, ctx), sym_table_(sym_table), globals_(globals), + allocas_(allocas), arg_encoding_(arg_encoding), attr_encoding_(attr_encoding), ret_encoding_(ret_encoding), @@ -415,9 +483,13 @@ class CallOpLowering : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { ImplicitLocOpBuilder b(op.getLoc(), rewriter); + // Reuse allocas for encoding custom call arguments. + Allocas allocas = allocas_.GetForOperation(op); + // Encode operation arguments as a runtime API arguments. 
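// The args/attrs/rets pointer arrays produced below ultimately feed the
// runtime call. A direct custom call registered by the user is resolved at
// link time against a symbol with (roughly) this C signature; this is a hedged
// sketch of the ABI, not a declaration taken from this patch, and the symbol
// name is hypothetical. The dynamic path instead goes through
// `runtimeCustomCall(ctx, callee, args, attrs, rets)`.
extern "C" bool my_custom_call(xla::runtime::ExecutionContext* ctx, void** args,
                               void** attrs, void** rets) {
  // args/attrs/rets follow the pointer-array layout built by this lowering.
  (void)ctx; (void)args; (void)attrs; (void)rets;
  return true;  // `true` maps to the i1 status later checked by `rt.is_ok`.
}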
- auto args = EncodeArguments(op, arg_encoding_, globals_, encoded_args_, b, - op->getOperands(), adaptor.getOperands()); + auto args = + EncodeArguments(op, arg_encoding_, globals_, allocas, encoded_args_, b, + op->getOperands(), adaptor.getOperands()); if (failed(args)) return op.emitOpError() << "failed to encode arguments"; // Encode operation attributes as a runtime API argument. @@ -431,8 +503,8 @@ class CallOpLowering : public OpConversionPattern { std::transform( ret_types.begin(), ret_types.end(), converted_ret_types.begin(), [&](Type type) { return getTypeConverter()->convertType(type); }); - auto rets = EncodeResults(op, ret_encoding_, globals_, b, ret_types, - converted_ret_types); + auto rets = EncodeResults(op, ret_encoding_, globals_, allocas, b, + ret_types, converted_ret_types); if (failed(rets)) return op.emitOpError() << "failed to encode results"; // Creates a dynamic custom call resolved by name at run time. @@ -442,19 +514,20 @@ class CallOpLowering : public OpConversionPattern { return b.create( kCustomCall, TypeRange(rewriter.getI1Type()), - ValueRange({adaptor.getCtx(), callee, *args, *attrs, - rets->result_array_ptr})); + ValueRange({adaptor.getCtx(), callee, AsPtr(b, args->encoded), + Globals::AddrOf(b, *attrs), AsPtr(b, rets->encoded)})); }; // Creates a direct custom call resolved at link time. auto call_direct = [&]() -> func::CallOp { auto type = RuntimeAPI::DirectCustomCallFunctionType(op.getContext()); - AddDeclaration(op->getParentOfType(), op.getCallee(), type); + AddDeclaration(sym_table_, op->getParentOfType(), + op.getCallee(), type); - return b.create(op.getCallee(), - TypeRange(rewriter.getI1Type()), - ValueRange({adaptor.getCtx(), *args, *attrs, - rets->result_array_ptr})); + return b.create( + op.getCallee(), TypeRange(rewriter.getI1Type()), + ValueRange({adaptor.getCtx(), AsPtr(b, args->encoded), + Globals::AddrOf(b, *attrs), AsPtr(b, rets->encoded)})); }; // Build a call operation and result decoding right after the original op. @@ -470,6 +543,23 @@ class CallOpLowering : public OpConversionPattern { if (failed(decoded_results)) return op.emitOpError() << "failed to decode results"; + auto size = b.getI64IntegerAttr(-1); + + // End the lifetime of encoded arguments and results pointers. + if (auto *alloca = std::get_if(&args->encoded)) + b.create(size, *alloca); + if (auto *alloca = std::get_if(&rets->encoded)) + b.create(size, *alloca); + + // End the lifetime of arguments encoded on a stack. + for (auto &arg : args->values) + if (auto *alloca = std::get_if(&arg)) + b.create(size, *alloca); + + // End the lifetime of results encoded on a stack. + for (LLVM::AllocaOp alloca : rets->allocas) + b.create(size, alloca); + rewriter.replaceOp(op, ValueRange(*decoded_results)); return success(); } @@ -477,6 +567,7 @@ class CallOpLowering : public OpConversionPattern { private: SymbolTable &sym_table_; Globals &globals_; + EncodingAllocas &allocas_; CustomCallArgEncodingSet &arg_encoding_; CustomCallAttrEncodingSet &attr_encoding_; CustomCallRetEncodingSet &ret_encoding_; @@ -588,8 +679,11 @@ void ConvertRuntimeToLLVMPass::runOnOperation() { ModuleOp module = getOperation(); MLIRContext *ctx = module.getContext(); + // A symbol table for resolving symbol references attributes. + SymbolTable sym_table(module); + // Add declarations for the runtime API functions. 
- AddRuntimeApiDeclarations(module); + AddRuntimeApiDeclarations(sym_table, module); RuntimeTypeConverter converter; RewritePatternSet patterns(ctx); @@ -637,12 +731,12 @@ void ConvertRuntimeToLLVMPass::runOnOperation() { PopulateTraceTypeIdNames(type_id_names); if (opts_.populate_type_id_names) opts_.populate_type_id_names(type_id_names); - // A symbol table for resolving symbol references attributes. - SymbolTable sym_table(module); - // A helper class to create unique global constants. Globals globals(module, type_id_names); + // A helper class to create allocas for values encoded on a stack. + EncodingAllocas allocas; + // Keep a cache of encoded values to encode each unique value just once. DenseMap encoded_args; @@ -667,8 +761,8 @@ void ConvertRuntimeToLLVMPass::runOnOperation() { if (opts_.populate_attr_encodings) opts_.populate_attr_encodings(attrs); if (opts_.populate_ret_encodings) opts_.populate_ret_encodings(rets); - patterns.add(llvm_converter, ctx, sym_table, globals, args, - attrs, rets, encoded_args); + patterns.add(llvm_converter, ctx, sym_table, globals, allocas, + args, attrs, rets, encoded_args); // Convert function signatures and call sites. mlir::populateFunctionOpInterfaceTypeConversionPattern( diff --git a/tensorflow/compiler/xla/mlir/runtime/transforms/specialization.cc b/tensorflow/compiler/xla/mlir/runtime/transforms/specialization.cc index 49b664d8d23..45f8dc16a1d 100644 --- a/tensorflow/compiler/xla/mlir/runtime/transforms/specialization.cc +++ b/tensorflow/compiler/xla/mlir/runtime/transforms/specialization.cc @@ -27,6 +27,8 @@ limitations under the License. #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/FunctionInterfaces.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project #include "tensorflow/compiler/xla/mlir/runtime/transforms/type_converter.h" @@ -145,7 +147,8 @@ static mlir::DenseElementsAttr GetMemrefValues(mlir::Builder& builder, return mlir::DenseElementsAttr::get(ranked_tensor, attributes); } -Status SpecializeFunction(mlir::func::FuncOp func, ArgumentsRef arguments, +Status SpecializeFunction(mlir::FunctionOpInterface func, + ArgumentsRef arguments, ArrayRef symbolic_shapes, ArrayRef constraints, const SpecializationListener* listener) { @@ -156,16 +159,17 @@ Status SpecializeFunction(mlir::func::FuncOp func, ArgumentsRef arguments, // Specialize all function inputs to the given arguments. llvm::SmallVector specialized_inputs(num_inputs); for (unsigned i = 0; i < num_inputs; ++i) { - auto specialized = - SpecializeOperandType(i, func.getFunctionType().getInput(i), - arguments[i], symbolic_shapes[i]); + auto specialized = SpecializeOperandType( + i, llvm::cast(func.getFunctionType()).getInput(i), + arguments[i], symbolic_shapes[i]); if (!specialized.ok()) return specialized.status(); specialized_inputs[i] = *specialized; } // Update function type to a new specialized one. auto specialized = mlir::FunctionType::get( - ctx, specialized_inputs, func.getFunctionType().getResults()); + ctx, specialized_inputs, + llvm::cast(func.getFunctionType()).getResults()); func.setType(specialized); // Update function entry block arguments. 
@@ -213,12 +217,13 @@ Status SpecializeFunction(mlir::func::FuncOp func, ArgumentsRef arguments, } // Sink small constants into the function body. - builder.setInsertionPointToStart(&func.getBody().front()); + builder.setInsertionPointToStart(&func.getFunctionBody().front()); for (int i = 0; i < constraints.size(); ++i) { if (constraints[i] != ArgumentConstraint::kValue) continue; // We only support sinking of Tensor arguments into the function body. - mlir::Type input = func.getFunctionType().getInput(i); + mlir::Type input = + llvm::cast(func.getFunctionType()).getInput(i); mlir::TensorType tensor = input.dyn_cast(); if (!tensor || !SupportsValueSpecialization(tensor)) { return InvalidArgumentError(StrCat( diff --git a/tensorflow/compiler/xla/mlir/runtime/transforms/specialization.h b/tensorflow/compiler/xla/mlir/runtime/transforms/specialization.h index 8b97f90df49..3fa163e5a74 100644 --- a/tensorflow/compiler/xla/mlir/runtime/transforms/specialization.h +++ b/tensorflow/compiler/xla/mlir/runtime/transforms/specialization.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_MLIR_RUNTIME_TRANSFORMS_SPECIALIZATION_H_ #define TENSORFLOW_COMPILER_XLA_MLIR_RUNTIME_TRANSFORMS_SPECIALIZATION_H_ -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/FunctionInterfaces.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "tensorflow/compiler/xla/runtime/arguments.h" #include "tensorflow/compiler/xla/runtime/constraints.h" @@ -58,7 +58,7 @@ struct SpecializationListener { // // Returns error if arguments are not compatible with the function signature. absl::Status SpecializeFunction( - mlir::func::FuncOp func, ArgumentsRef arguments, + mlir::FunctionOpInterface func, ArgumentsRef arguments, llvm::ArrayRef symbolic_shapes, llvm::ArrayRef constraints, const SpecializationListener* listener = nullptr); diff --git a/tensorflow/compiler/xla/mlir/runtime/transforms/tests/BUILD b/tensorflow/compiler/xla/mlir/runtime/transforms/tests/BUILD index e73af47cf56..f63e709402b 100644 --- a/tensorflow/compiler/xla/mlir/runtime/transforms/tests/BUILD +++ b/tensorflow/compiler/xla/mlir/runtime/transforms/tests/BUILD @@ -1,11 +1,14 @@ -load("//tensorflow:tensorflow.default.bzl", "filegroup") +load("//tensorflow/tsl:tsl.default.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") -package(licenses = ["notice"]) +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) glob_lit_tests( data = [":test_utilities"], - driver = "//tensorflow/compiler/mlir:run_lit.sh", + driver = "//tensorflow/compiler/xla:run_lit.sh", test_file_exts = ["mlir"], ) @@ -38,6 +41,8 @@ cc_library( "@llvm-project//mlir:FuncToLLVM", "@llvm-project//mlir:IR", "@llvm-project//mlir:LLVMToLLVMIRTranslation", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:MemRefToLLVM", "@llvm-project//mlir:Pass", "@llvm-project//mlir:ReconcileUnrealizedCasts", "@llvm-project//mlir:SCFDialect", diff --git a/tensorflow/compiler/xla/mlir/runtime/transforms/tests/convert_custom_calls.mlir b/tensorflow/compiler/xla/mlir/runtime/transforms/tests/convert_custom_calls.mlir index 975ad8c90c5..4b4e8495454 100644 --- a/tensorflow/compiler/xla/mlir/runtime/transforms/tests/convert_custom_calls.mlir +++ b/tensorflow/compiler/xla/mlir/runtime/transforms/tests/convert_custom_calls.mlir @@ -46,14 +46,14 @@ func.func 
@function_call_to_traced_custom_call( %arg0: !rt.execution_context, %arg1: memref ) -> memref attributes {rt.exported = 0 : i32} { - // CHECK: %[[RES:.*]]:2 = rt.trace #rt.hlo_trace<"fusion", "foo", 0>, %[[CTX]] + // CHECK: %[[RES:.*]]:2 = rt.trace #rt.hlo_trace<"fusion">, %[[CTX]] // CHECK-SAME: -> !rt.status, memref { // CHECK-NEXT: %[[STATUS:.*]], %[[RET:.*]] = call %[[CTX]]["target"] // CHECK-NOT: #rt.hlo_trace // CHECK-NEXT: yield %[[STATUS]], %[[RET]] : !rt.status, memref // CHECK-NEXT: } // CHECK: rt.is_ok %[[RES]]#0 - %0 = call @custom_call(%arg1) { rt.trace = #rt.hlo_trace<"fusion", "foo", 0> } + %0 = call @custom_call(%arg1) { rt.trace = #rt.hlo_trace<"fusion"> } : (memref) -> memref return %0 : memref } \ No newline at end of file diff --git a/tensorflow/compiler/xla/mlir/runtime/transforms/tests/rt_to_llvm.mlir b/tensorflow/compiler/xla/mlir/runtime/transforms/tests/rt_to_llvm.mlir index 348f869e461..31914e3cb36 100644 --- a/tensorflow/compiler/xla/mlir/runtime/transforms/tests/rt_to_llvm.mlir +++ b/tensorflow/compiler/xla/mlir/runtime/transforms/tests/rt_to_llvm.mlir @@ -151,6 +151,13 @@ func.func @custom_call(%arg0: !rt.execution_context) { // ----- // CHECK: global internal constant @__rt_custom_call_name("target\00") + +// CHECK: global internal constant @__rt_empty_rets() +// CHECK: { +// CHECK: llvm.mlir.undef : !llvm.array<1 x ptr> +// CHECK: llvm.mlir.addressof @__rt_zero : !llvm.ptr +// CHECK: } + // CHECK: global internal constant @__rt_num_attrs(0 : i64) // CHECK: global internal constant @__rt_custom_call_attrs() @@ -159,22 +166,21 @@ func.func @custom_call(%arg0: !rt.execution_context) { // CHECK: llvm.mlir.addressof @__rt_num_attrs : !llvm.ptr // CHECK: } -// CHECK: global internal constant @__rt_num_args(0 : i64) +// CHECK: global internal constant @__rt_empty_args() +// CHECK: { +// CHECK: llvm.mlir.undef : !llvm.array<1 x ptr> +// CHECK: llvm.mlir.addressof @__rt_zero : !llvm.ptr +// CHECK: } // CHECK: func @dynamic_custom_call( // CHECK: %[[CTX:.*]]: !llvm.ptr // CHECK: ) func.func @dynamic_custom_call(%arg0: !rt.execution_context) { - // CHECK: %[[C1:.*]] = arith.constant 1 : i32 - // CHECK: %[[RETS:.*]] = llvm.alloca %[[C1]] x !llvm.array<1 x ptr> - - // CHECK: %[[C1_0:.*]] = arith.constant 1 : i32 - // CHECK: %[[ARGS:.*]] = llvm.alloca %[[C1_0]] x !llvm.array<1 x ptr> - - // CHECK: %[[ATTRS:.*]] = llvm.mlir.addressof @__rt_custom_call_attrs - // CHECK: %[[CALLEE_ADDR:.*]] = llvm.mlir.addressof @__rt_custom_call_name + // CHECK: %[[ARGS:.*]] = llvm.mlir.addressof @__rt_empty_args + // CHECK: %[[ATTRS:.*]] = llvm.mlir.addressof @__rt_custom_call_attrs + // CHECK: %[[RETS:.*]] = llvm.mlir.addressof @__rt_empty_rets // CHECK: %[[STATUS:.*]] = call @runtimeCustomCall(%[[CTX]], %[[CALLEE_ADDR]], // CHECK-SAME: %[[ARGS]], %[[ATTRS]], @@ -200,9 +206,10 @@ func.func @dynamic_custom_call(%arg0: !rt.execution_context) { // CHECK: global internal constant @__rt_custom_call_attrs() // CHECK-SAME: : !llvm.array<4 x ptr> { +// CHECK: llvm.mlir.addressof @__rt_num_attrs // CHECK: llvm.mlir.addressof @__rt_attr_name // CHECK: llvm.mlir.addressof @__type_id_float -// CHECK: llvm.mlir.addressof @__rt_attr_value : !llvm.ptr +// CHECK: llvm.mlir.addressof @__rt_attr_value // CHECK: } // CHECK: func @custom_call( @@ -310,6 +317,15 @@ func.func @custom_call(%arg0: !rt.execution_context) { // ----- +// CHECK: llvm.mlir.global internal constant @__rt_empty_rets() + +// CHECK: llvm.mlir.global internal constant @__rt_num_attrs(0 : i64) +// CHECK: llvm.mlir.global internal constant 
@__rt_custom_call_attrs + +// CHECK: llvm.mlir.global internal constant @__rt_args_type_table +// CHECK: llvm.mlir.undef : !llvm.array<1 x ptr> +// CHECK: llvm.mlir.addressof @__type_id_float + // CHECK: func @custom_call( // CHECK: %[[CTX:.*]]: !llvm.ptr // CHECK: %[[ARG:.*]]: f32 @@ -318,19 +334,27 @@ func.func @custom_call(%arg0: !rt.execution_context, %arg1 : f32) { // CHECK-DAG: %[[MEM:.*]] = llvm.alloca {{.*}} x f32 // CHECK-DAG: %[[ARGS:.*]] = llvm.alloca {{.*}} x !llvm.array<3 x ptr> - // CHECK-DAG: %[[TYPE_ID:.*]] = llvm.mlir.addressof @__type_id_float // CHECK-DAG: %[[N_ARGS:.*]] = llvm.mlir.addressof @__rt_num_args - // CHECK-DAG: llvm.store %[[ARG]], %[[MEM]] - // CHECK-DAG: llvm.store {{.*}}, %[[ARGS]] : !llvm.array<3 x ptr>, !llvm.ptr + + // CHECK: %[[ARGS_TYPES:.*]] = llvm.mlir.addressof @__rt_args_type_table + // CHECK: llvm.insertvalue %[[ARGS_TYPES]], {{.*}}[1] : !llvm.array<3 x ptr> + // CHECK: llvm.intr.lifetime.start -1, %[[ARGS]] + // CHECK: llvm.store {{.*}}, %[[ARGS]] : !llvm.array<3 x ptr>, !llvm.ptr + + // CHECK: %[[RETS:.*]] = llvm.mlir.addressof @__rt_empty_rets // CHECK: call @target + // CHECK: llvm.intr.lifetime.end -1, %[[ARGS]] rt.call %arg0["target"] (%arg1) : (f32) -> () func.return } // ----- +// CHECK: llvm.mlir.global internal constant @__rt_args_type_table +// CHECK: llvm.mlir.addressof @__type_id_memref_view + // CHECK: func @custom_call( // CHECK: %[[CTX:.*]]: !llvm.ptr // CHECK: %[[ARG:.*]]: memref @@ -340,8 +364,6 @@ func.func @custom_call(%arg0: !rt.execution_context, %arg1 : memref) // CHECK: %[[DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG]] // CHECK-SAME: to !llvm.struct - // CHECK: %[[TYPE_ID:.*]] = llvm.mlir.addressof @__type_id_memref_view - // CHECK: llvm.mlir.undef : !llvm.array<4 x i64> // CHECK-NEXT: llvm.extractvalue %[[DESC]][3, 0] // CHECK-NEXT: arith.constant 256 : i64 @@ -359,6 +381,7 @@ func.func @custom_call(%arg0: !rt.execution_context, %arg1 : memref) // CHECK: llvm.insertvalue // CHECK: %[[N_ARGS:.*]] = llvm.mlir.addressof @__rt_num_args + // CHECK: %[[TYPES:.*]] = llvm.mlir.addressof @__rt_args_type_table // CHECK: call @target rt.call %arg0["target"] (%arg1) : (memref) -> () @@ -399,17 +422,18 @@ func.func @dynamic_custom_call(%arg0: !rt.execution_context) { // ----- -// CHECK: %[[C1:.*]] = arith.constant 1 : i32 -// CHECK: %[[RETS:.*]] = llvm.alloca %[[C1]] x !llvm.array<3 x ptr> +func.func @custom_call(%ctx: !rt.execution_context) -> (f32) { + // CHECK: %[[C1:.*]] = arith.constant 1 : i32 + // CHECK: %[[RETS:.*]] = llvm.alloca %[[C1]] x !llvm.array<3 x ptr> -// CHECK: %[[C1_0:.*]] = arith.constant 1 : i32 -// CHECK: %[[F32_ALLOCA:.*]] = llvm.alloca %[[C1_0]] x f32 + // CHECK: %[[C1_0:.*]] = arith.constant 1 : i32 + // CHECK: %[[F32_ALLOCA:.*]] = llvm.alloca %[[C1_0]] x f32 -// CHECK: %[[N_RETS:.*]] = llvm.mlir.addressof @__rt_num_rets + // CHECK: %[[N_RETS:.*]] = llvm.mlir.addressof @__rt_num_rets -// CHECK: call @f32_reduce -// CHECK: %[[LOAD2:.*]] = llvm.load %[[F32_ALLOCA]] -func.func @custom_call(%ctx: !rt.execution_context) -> (f32) { + // CHECK: call @f32_reduce + // CHECK: %[[LOAD2:.*]] = llvm.load %[[F32_ALLOCA]] + // CHECK: llvm.intr.lifetime.end -1, %[[F32_ALLOCA]] %status, %0 = rt.call %ctx["f32_reduce"] () : () -> (f32) return %0 : f32 } @@ -426,6 +450,9 @@ func.func @opaque_arg(%ctx: !rt.execution_context, %arg: !rt.opaque) { // ----- +// CHECK: llvm.mlir.global internal constant @__rt_args_type_table +// CHECK: llvm.mlir.addressof @__type_id_opaque : !llvm.ptr + // CHECK: func @opaque_custom_call_arg( 
// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr, // CHECK-SAME: %[[ARG1:.*]]: !llvm.ptr @@ -433,7 +460,6 @@ func.func @opaque_arg(%ctx: !rt.execution_context, %arg: !rt.opaque) { func.func @opaque_custom_call_arg(%ctx: !rt.execution_context, %arg: !rt.opaque) { // CHECK: %[[ALLOCA:.*]] = llvm.alloca {{.*}} x !llvm.ptr - // CHECK: llvm.mlir.addressof @__type_id_opaque : !llvm.ptr // CHECK: llvm.store %[[ARG1]], %[[ALLOCA]] : !llvm.ptr // CHECK: call @target %status = rt.call %ctx["target"] (%arg) : (!rt.opaque) -> () @@ -504,6 +530,22 @@ func.func @custom_call(%ctx: !rt.execution_context) -> (memref<2x2xf32>) { // ----- +// CHECK: %[[C1:.*]] = arith.constant 1 : i32 +// CHECK: %[[RETS_ALLOCA:.*]] = llvm.alloca %[[C1]] x !llvm.array<3 x ptr> + +// CHECK: %[[C1_0:.*]] = arith.constant 1 : i32 +// CHECK: %[[MEMREF_ALLOCA:.*]] = llvm.alloca %[[C1_0]] x !llvm.struct<(i8, i8, ptr, array<4 x i64>)> + +// CHECK: call @f32_reduce +func.func @custom_call(%ctx: !rt.execution_context) + -> (!async.value>) { + %status, %0 = rt.call %ctx["f32_reduce"] () + : () -> (!async.value>) + return %0 : !async.value> +} + +// ----- + // Test that custom call encoding can pass a reference to exported function as a // custom call attribute. func.func @init(%ctx: !rt.execution_context) @@ -536,9 +578,72 @@ func.func @trace(%ctx: !rt.execution_context) -> tensor { // CHECK: call @xla.trace.activity_start // CHECK: call @compute // CHECK: call @xla.trace.activity_end - %0 = rt.trace #rt.hlo_trace<"foo", "bar", 0>, %ctx -> tensor { + %0 = rt.trace #rt.hlo_trace<"foo">, %ctx -> tensor { %1 = func.call @compute(): () -> tensor yield %1 : tensor } return %0 : tensor } + +// ----- + +// CHECK: llvm.mlir.global internal constant @__rt_c123(123 : i32) + +// CHECK: func @custom_call( +// CHECK: %[[CTX:.*]]: !llvm.ptr +// CHECK: ) +func.func @custom_call(%arg0: !rt.execution_context) { + // CHECK: llvm.mlir.addressof @__rt_c123 : !llvm.ptr + // CHECK: call @target + %c123 = arith.constant 123 : i32 + rt.call %arg0["target"] (%c123) : (i32) -> () + func.return +} + +// ----- + +// CHECK: llvm.mlir.global internal constant @__rt_cst(1.234560e+02 : f32) + +// CHECK: func @custom_call( +// CHECK: %[[CTX:.*]]: !llvm.ptr +// CHECK: ) +func.func @custom_call(%arg0: !rt.execution_context) { + // CHECK: llvm.mlir.addressof @__rt_cst : !llvm.ptr + // CHECK: call @target + %cst = arith.constant 123.456 : f32 + rt.call %arg0["target"] (%cst) : (f32) -> () + func.return +} + +// ----- +// Check that we reuse allocas for encoding arguments on the stack. 
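// The reuse checked by the test below comes from the `Allocas` helper declared
// in custom_call_encoding.h. A minimal sketch of how such reuse could work
// (the real implementation lives in custom_call_encoding.cc and may differ):
// allocas of a given type are handed out in order and recycled per call site.
mlir::LLVM::AllocaOp SketchGetOrCreateAlloca(
    mlir::Block* entry_block,
    llvm::DenseMap<mlir::Type, llvm::SmallVector<mlir::LLVM::AllocaOp>>& cache,
    llvm::DenseMap<mlir::Type, size_t>& offsets, mlir::ImplicitLocOpBuilder& b,
    mlir::Type type) {
  auto& allocas = cache[type];
  size_t& offset = offsets[type];

  // Reuse a previously created alloca of this type if one is free.
  if (offset < allocas.size()) return allocas[offset++];

  // Otherwise create a new alloca in the parent function entry block, as
  // recommended by https://llvm.org/docs/Frontend/PerformanceTips.html.
  mlir::OpBuilder::InsertionGuard guard(b);
  b.setInsertionPointToStart(entry_block);
  mlir::Type ptr = mlir::LLVM::LLVMPointerType::get(b.getContext());
  mlir::Value c1 = b.create<mlir::arith::ConstantOp>(b.getI32IntegerAttr(1));
  auto alloca = b.create<mlir::LLVM::AllocaOp>(ptr, type, c1, /*alignment=*/0);
  allocas.push_back(alloca);
  ++offset;
  return alloca;
}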
+ +// CHECK: func @custom_call( +// CHECK: %[[CTX:.*]]: !llvm.ptr, +// CHECK: %[[ARG:.*]]: f32 +// CHECK: ) +func.func @custom_call(%arg0: !rt.execution_context, %arg1: f32) { + // CHECK: %[[ARGS:.*]] = llvm.alloca {{.*}} x !llvm.array<3 x ptr> + // CHECK: %[[ARG_ALLOCA:.*]] = llvm.alloca %{{.*}} x f32 + // CHECK-NOT: llvm.alloca + + // llvm.intr.lifetime.start -1, %[[ARG_ALLOCA]] : !llvm.ptr + // CHECK: llvm.store %[[ARG]], %[[ARG_ALLOCA]] : f32, !llvm.ptr + // llvm.intr.lifetime.start -1, %[[ARGS]] : !llvm.ptr + // CHECK: llvm.store {{.*}}, %[[ARGS]] + // CHECK: call @target + rt.call %arg0["target"] (%arg1) : (f32) -> () + // llvm.intr.lifetime.end -1, %[[ARGS]] : !llvm.ptr + // llvm.intr.lifetime.end -1, %[[ARG_ALLOCA]] : !llvm.ptr + + // llvm.intr.lifetime.start -1, %[[ARG_ALLOCA]] : !llvm.ptr + // CHECK: llvm.store %[[ARG]], %[[ARG_ALLOCA]] : f32, !llvm.ptr + // llvm.intr.lifetime.start -1, %[[ARGS]] : !llvm.ptr + // CHECK: llvm.store {{.*}}, %[[ARGS]] + // CHECK: call @target + rt.call %arg0["target"] (%arg1) : (f32) -> () + // llvm.intr.lifetime.end -1, %[[ARGS]] : !llvm.ptr + // llvm.intr.lifetime.end -1, %[[ARG_ALLOCA]] : !llvm.ptr + + func.return +} diff --git a/tensorflow/compiler/xla/mlir/runtime/transforms/tests/testlib_pipeline.cc b/tensorflow/compiler/xla/mlir/runtime/transforms/tests/testlib_pipeline.cc index 76311710863..e2d45689da7 100644 --- a/tensorflow/compiler/xla/mlir/runtime/transforms/tests/testlib_pipeline.cc +++ b/tensorflow/compiler/xla/mlir/runtime/transforms/tests/testlib_pipeline.cc @@ -19,12 +19,14 @@ limitations under the License. #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" // from @llvm-project #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" // from @llvm-project +#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" // from @llvm-project #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" // from @llvm-project #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Async/IR/Async.h" // from @llvm-project #include "mlir/Dialect/Async/Passes.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project #include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" // from @llvm-project @@ -39,7 +41,7 @@ void RegisterXlaRuntimeTestlibDialects(DialectRegistry& dialects) { // Register MLIR dialects supported by the Xla runtime tests. dialects->insert(); + mlir::memref::MemRefDialect, RuntimeDialect>(); // Register MLIR dialects that can be translated to LLVM IR. registerLLVMDialectTranslation(*dialects); @@ -47,9 +49,11 @@ void RegisterXlaRuntimeTestlibDialects(DialectRegistry& dialects) { void CreateXlaRuntimeTestlibPipeline(PassManager& passes) { passes->addPass(mlir::createConvertSCFToCFPass()); + passes->addPass(mlir::createAsyncFuncToAsyncRuntimePass()); // Export functions to the XLA runtime. passes->addPass(CreateExportRuntimeFunctionsPass()); + passes->addPass(CreateConvertCustomCallsPass()); passes->addPass(CreateConvertAssertsPass()); // Lower from high level async operations to async runtime. @@ -66,6 +70,7 @@ void CreateXlaRuntimeTestlibPipeline(PassManager& passes) { passes->addPass(mlir::createConvertAsyncToLLVMPass()); // Convert everything else to LLVM dialect. 
+  passes->addPass(mlir::createFinalizeMemRefToLLVMConversionPass());
   passes->addPass(mlir::createConvertFuncToLLVMPass());
   passes->addPass(mlir::createReconcileUnrealizedCastsPass());
 
diff --git a/tensorflow/compiler/xla/mlir/runtime/transforms/type_converter.cc b/tensorflow/compiler/xla/mlir/runtime/transforms/type_converter.cc
index c126b4017a7..89b2df9b3af 100644
--- a/tensorflow/compiler/xla/mlir/runtime/transforms/type_converter.cc
+++ b/tensorflow/compiler/xla/mlir/runtime/transforms/type_converter.cc
@@ -108,6 +108,8 @@ static std::unique_ptr<Type> ConvertCanonicalType(
 
 /*static*/ StatusOr<PrimitiveType> TypeConverter::ConvertElementType(
     mlir::Type type) {
+  if (type.isFloat8E4M3FN()) return PrimitiveType::F8E4M3FN;
+  if (type.isFloat8E5M2()) return PrimitiveType::F8E5M2;
   if (type.isIndex()) return PrimitiveType::S64;
   if (type.isBF16()) return PrimitiveType::BF16;
   if (type.isF16()) return PrimitiveType::F16;
diff --git a/tensorflow/compiler/xla/mlir/runtime/transforms/type_converter.h b/tensorflow/compiler/xla/mlir/runtime/transforms/type_converter.h
index 8b8506e781c..ef43fcec553 100644
--- a/tensorflow/compiler/xla/mlir/runtime/transforms/type_converter.h
+++ b/tensorflow/compiler/xla/mlir/runtime/transforms/type_converter.h
@@ -40,6 +40,13 @@ class TypeConverter {
   // time type if the conversion is successful, or `nullptr` if failed.
   using ConversionFn = std::function<std::unique_ptr<Type>(mlir::Type)>;
 
+  TypeConverter() = default;
+
+  template <typename... Fns>
+  explicit TypeConverter(Fns&&... fn) {
+    (AddConversion(std::forward<Fns>(fn)), ...);
+  }
+
   // Adds a type conversion function with a type predicate.
   //
   // Example:
   //
@@ -50,7 +57,7 @@ class TypeConverter {
   // result for all other types, and the type converter will try the next
   // conversion function (see `Convert` implementation).
   template <typename Fn, typename FnTraits = llvm::function_traits<Fn>>
-  void AddConversion(Fn fn) {
+  void AddConversion(Fn&& fn) {
     using ArgType = typename FnTraits::template arg_t<0>;
     conversions_.emplace_back(
         [fn = std::forward<Fn>(fn)](mlir::Type type) -> std::unique_ptr<Type> {
diff --git a/tensorflow/compiler/xla/mlir/runtime/utils/BUILD b/tensorflow/compiler/xla/mlir/runtime/utils/BUILD
index 0ace63c24a0..0464e84ccb2 100644
--- a/tensorflow/compiler/xla/mlir/runtime/utils/BUILD
+++ b/tensorflow/compiler/xla/mlir/runtime/utils/BUILD
@@ -1,7 +1,8 @@
-load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud")
+load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud")
 load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library")
 
 package(
+    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
     default_visibility = ["//tensorflow/compiler/xla/mlir/runtime:friends"],
     licenses = ["notice"],
 )
diff --git a/tensorflow/compiler/xla/mlir/runtime/utils/async_runtime_api.cc b/tensorflow/compiler/xla/mlir/runtime/utils/async_runtime_api.cc
index fe40566e0e5..27c7081fd0d 100644
--- a/tensorflow/compiler/xla/mlir/runtime/utils/async_runtime_api.cc
+++ b/tensorflow/compiler/xla/mlir/runtime/utils/async_runtime_api.cc
@@ -55,7 +55,7 @@ void ExtractAsyncValue(
 
   // Fast path if async value is already available.
   if (async_value->IsAvailable()) {
-    void *storage = AsyncRuntime::GetStorage(value);
+    auto *storage = AsyncRuntime::GetStorage(value);
     emplace_fn(storage, dst);
     AsyncRuntime::DropRef(AsyncRuntime::ToAsyncRuntimeObject(value));
     return;
@@ -63,7 +63,7 @@ void ExtractAsyncValue(
 
   // Wait for the async value completion, and emplace the `dst`.
async_value->AndThen([value, emplace_fn, dst = FormRef(dst)]() { - void *storage = AsyncRuntime::GetStorage(value); + auto *storage = AsyncRuntime::GetStorage(value); emplace_fn(storage, dst.get()); AsyncRuntime::DropRef(AsyncRuntime::ToAsyncRuntimeObject(value)); }); @@ -77,7 +77,7 @@ void ExtractAsyncValue( // Fast path if async value is already available. if (async_value->IsAvailable()) { - void *storage = AsyncRuntime::GetStorage(value); + auto *storage = AsyncRuntime::GetStorage(value); emplace_fn(storage, dst, context); AsyncRuntime::DropRef(AsyncRuntime::ToAsyncRuntimeObject(value)); return; @@ -85,7 +85,7 @@ void ExtractAsyncValue( // Wait for the async value completion, and emplace the `dst`. async_value->AndThen([value, emplace_fn, context, dst = FormRef(dst)]() { - void *storage = AsyncRuntime::GetStorage(value); + auto *storage = AsyncRuntime::GetStorage(value); emplace_fn(storage, dst.get(), context); AsyncRuntime::DropRef(AsyncRuntime::ToAsyncRuntimeObject(value)); }); diff --git a/tensorflow/compiler/xla/mlir/runtime/utils/constraints.cc b/tensorflow/compiler/xla/mlir/runtime/utils/constraints.cc index 3ce5deb51bf..61355cec648 100644 --- a/tensorflow/compiler/xla/mlir/runtime/utils/constraints.cc +++ b/tensorflow/compiler/xla/mlir/runtime/utils/constraints.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "llvm/ADT/SmallVector.h" +#include "mlir/IR/FunctionInterfaces.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project namespace xla { @@ -35,7 +36,7 @@ using absl::StrCat; using llvm::SmallVector; StatusOr> GetArgumentsConstraints( - func::FuncOp func) { + FunctionOpInterface func) { llvm::SmallVector constraints; constraints.reserve(func.getNumArguments()); @@ -52,7 +53,8 @@ StatusOr> GetArgumentsConstraints( }; for (int i = 0; i < func.getNumArguments(); ++i) { - auto arg_type = func.getFunctionType().getInput(i); + auto arg_type = + llvm::cast(func.getFunctionType()).getInput(i); auto constraint = parse(func.getArgAttr(i, kArgumentConstraintAttrName)); if (!constraint.ok()) return constraint.status(); diff --git a/tensorflow/compiler/xla/mlir/runtime/utils/constraints.h b/tensorflow/compiler/xla/mlir/runtime/utils/constraints.h index 9c460d2cf23..2ab66c44b64 100644 --- a/tensorflow/compiler/xla/mlir/runtime/utils/constraints.h +++ b/tensorflow/compiler/xla/mlir/runtime/utils/constraints.h @@ -17,17 +17,15 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_MLIR_RUNTIME_UTILS_CONSTRAINTS_H_ #include "absl/status/statusor.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/FunctionInterfaces.h" // from @llvm-project #include "tensorflow/compiler/xla/runtime/constraints.h" namespace xla { namespace runtime { - // Returns arguments constraints inferred from the function signature. absl::StatusOr> GetArgumentsConstraints( - mlir::func::FuncOp func); + mlir::FunctionOpInterface func); // Resolves argument constraint based on the argument type, if constraint is // fully satisfied by the type, returns `kResolved`. 
diff --git a/tensorflow/compiler/xla/mlir/runtime/utils/custom_calls.cc b/tensorflow/compiler/xla/mlir/runtime/utils/custom_calls.cc index 4b16c4f6d63..ddc413519a4 100644 --- a/tensorflow/compiler/xla/mlir/runtime/utils/custom_calls.cc +++ b/tensorflow/compiler/xla/mlir/runtime/utils/custom_calls.cc @@ -40,7 +40,7 @@ FuncOp CustomCallDeclarations::GetOrCreate(ImplicitLocOpBuilder& b, StringRef target, FunctionType type) { // Check if we already have a custom all declaration. - Key key = {target, type}; + Key key = {b.getStringAttr(target), type}; if (auto it = custom_calls_.find(key); it != custom_calls_.end()) return it->second; diff --git a/tensorflow/compiler/xla/mlir/runtime/utils/custom_calls.h b/tensorflow/compiler/xla/mlir/runtime/utils/custom_calls.h index 69dc72d7649..2e081047849 100644 --- a/tensorflow/compiler/xla/mlir/runtime/utils/custom_calls.h +++ b/tensorflow/compiler/xla/mlir/runtime/utils/custom_calls.h @@ -22,6 +22,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project @@ -59,7 +60,7 @@ class CustomCallDeclarations { private: mlir::SymbolTable sym_table_; - using Key = std::pair; + using Key = std::pair; llvm::DenseMap custom_calls_; }; diff --git a/tensorflow/compiler/xla/mlir/runtime/xla-runtime-opt.cc b/tensorflow/compiler/xla/mlir/runtime/xla-runtime-opt.cc index a1c38251095..34e05c414fa 100644 --- a/tensorflow/compiler/xla/mlir/runtime/xla-runtime-opt.cc +++ b/tensorflow/compiler/xla/mlir/runtime/xla-runtime-opt.cc @@ -13,24 +13,24 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "mlir/Dialect/Async/IR/Async.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Math/IR/Math.h" // from @llvm-project #include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project #include "mlir/Tools/mlir-opt/MlirOptMain.h" // from @llvm-project +#include "tensorflow/compiler/xla/mlir/math/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir/memref/transforms/passes.h" #include "tensorflow/compiler/xla/mlir/runtime/ir/tests/testlib.h" #include "tensorflow/compiler/xla/mlir/runtime/transforms/passes.h" -#include "tensorflow/compiler/xla/mlir/transforms/math/passes.h" -#include "tensorflow/compiler/xla/mlir/transforms/memref/passes.h" int main(int argc, char **argv) { mlir::DialectRegistry registry; registry.insert(); - - xla::runtime::registerMathTransformsPasses(); - xla::runtime::registerMemrefTransformsPasses(); + mlir::async::AsyncDialect, xla::runtime::TestlibDialect>(); + xla::registerMathTransformsPasses(); + xla::registerMemrefTransformsPasses(); xla::runtime::registerRuntimeTransformsPasses(); return failed(MlirOptMain(argc, argv, "Xla Runtime Pass Driver\n", registry)); diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_bisect/BUILD b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/BUILD new file mode 100644 index 00000000000..319ed6ac5b0 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/BUILD @@ -0,0 +1,57 @@ +load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_binary") +load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") + +xla_cc_binary( + name = "mlir-bisect", + srcs = ["mlir_bisect.cc"], + visibility = ["//visibility:public"], + deps = [ + ":bisect_lib", + "//tensorflow/compiler/xla/mlir/runtime/ir:rt", + "//tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites", + "//tensorflow/compiler/xla/mlir/tools/mlir_replay/public:execution_trace_utils", + "//tensorflow/compiler/xla/mlir_hlo:gml_st", + "//tensorflow/compiler/xla/mlir_hlo:gml_st_bufferizable_op_interface", + "//tensorflow/compiler/xla/mlir_hlo:gml_st_passes", + "//tensorflow/compiler/xla/mlir_hlo:gml_st_test_passes", + "//tensorflow/compiler/xla/mlir_hlo:hlo_dialect_registration", + "//tensorflow/compiler/xla/mlir_hlo:lhlo", + "//tensorflow/compiler/xla/mlir_hlo:lmhlo_passes", + "//tensorflow/compiler/xla/mlir_hlo:mhlo_passes", + "//tensorflow/compiler/xla/mlir_hlo:mlir_interpreter_dialects", + "//tensorflow/compiler/xla/mlir_hlo:mlir_interpreter_framework", + "//tensorflow/compiler/xla/mlir_hlo:thlo", + "//tensorflow/compiler/xla/mlir_hlo:thlo_passes", + "//tensorflow/compiler/xla/service:hlo_proto_cc", + "//tensorflow/tsl/platform:env", + "//tensorflow/tsl/platform:platform_port", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MlirReduceLib", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + ], +) + +cc_library( + name = "bisect_lib", + srcs = [ + "bisect_lib.cc", + "test_passes.cc", + ], + hdrs = [ + "bisect_lib.h", + "test_passes.h", + ], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla/mlir/tools/mlir_replay/public:execution_trace_proto_cc", + "//tensorflow/compiler/xla/mlir/tools/mlir_replay/public:execution_trace_proto_cc_impl", + "//tensorflow/compiler/xla/mlir/tools/mlir_replay/public:execution_trace_utils", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LinalgDialect", + 
"@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + ], +) diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_bisect/README.md b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/README.md new file mode 100644 index 00000000000..c18a21ac0c5 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/README.md @@ -0,0 +1,85 @@ +# MLIR HLO mlir_bisect + +This is a test case reduction tool, similar in purpose to `mlir-reduce`, but +specific to the `mlir-interpreter` infrastructure. In particular, reductions can +depend on concrete values encountered during execution, and reductions can (and +usually do) generate multiple candidates. + +For example, the `ReplaceOpWithConstant` reduction will attempt to replace each +op with each of its results. If the op is in a loop, each execution will be a +candidate for replacement. + +## Using this tool + +1. Run a JAX test with snapshots enabled: + + ``` + bazel test some-jax-test + --test_env=XLA_FLAGS="--xla_cpu_use_xla_runtime --xla_dump_to=/tmp/dump + --xla_dump_hlo_snapshots" --test_filter=SomeSpecific.Test + --test_sharding_strategy=disabled + ``` + +1. Figure out the culprit module and pass (sorry, no automation yet): + + ``` + bazel run tensorflow/compiler/xla/mlir/tools/mlir_replay:mlir_replay -- \ + --mlir_compilation_trace=/tmp/dump/module_0000.jit__something.mlir-trace.pb \ + --hlo_snapshot=/tmp/dump/module_0000.jit__something.snapshot.0.pb \ + --print_changes_only \ + --execution_trace_dir=/tmp/execution + ``` + + You should see a pass after which results change. You'll want to use the + .mlir file in `/tmp/execution` corresponding to the pass *before* that with + the bisect tool. + + Note: If the failing pass is bufferization, you may have to use an earlier + snapshot, e.g. before EmptyTensorToAllocTensor. +1. Run bisect: + + ``` + bazel run tensorflow/compiler/xla/mlir/tools/mlir_bisect:mlir-bisect -- \ + --hlo-snapshot=/tmp/dump/module_0000.jit_something.snapshot.0.pb \ + --pass-pipeline="builtin.module(empty-tensor-to-alloc-tensor,one-shot-bufferize{allow-return-allocs bufferize-function-boundaries create-deallocs=0})" \ + /tmp/execution/0052.ScalarizationPass.mlir + ``` + +## Adding a reduction + +To add a reduction, create a function that generates the candidates and register +it: + +``` +SmallVector> +FrobulateAndDefenestrate(BisectState&, dialect::SomeOp some_op) { + auto [cloned_module_1, cloned_op_1] = CloneModuleFor(some_op); + Frobulate(cloned_op_1); + + auto [cloned_module_2, cloned_op_2] = CloneModuleFor(some_op); + Defenestrate(cloned_op_2); + + return {cloned_module_1, cloned_module_2}; +} + +REGISTER_MLIR_REDUCE_STRATEGY(FrobulateAndDefenestrate); +``` + +Then, add a test for the strategy. Make sure your strategy is linked into +mlir-bisect and has `alwayslink` set. + +``` +// RUN: mlir-bisect %s --debug-strategy=FrobulateAndDefenestrate | FileCheck %s + +func.func @main() { + dialect.some_op() +} + +// CHECK: func @main() +// CHECK-NEXT: frobulated + +// CHECK: func @main() +// CHECK-NEXT: defenestrated +``` + +`--debug-strategy` will print all candidates generated by the given strategy. \ No newline at end of file diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_bisect/bisect_lib.cc b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/bisect_lib.cc new file mode 100644 index 00000000000..475a042e140 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/bisect_lib.cc @@ -0,0 +1,81 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/mlir/tools/mlir_bisect/bisect_lib.h"
+
+#include <cassert>
+#include <functional>
+#include <iterator>
+#include <utility>
+
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+
+namespace mlir {
+namespace bisect {
+
+Operation* FindInClone(Operation* op, ModuleOp clone) {
+  if (llvm::isa<ModuleOp>(op)) {
+    return clone;
+  }
+
+  auto* parent_clone = FindInClone(op->getParentOp(), clone);
+  auto cloned_ops =
+      parent_clone->getRegions()[op->getParentRegion()->getRegionNumber()]
+          .getOps();
+  for (auto [original_op, cloned_op] :
+       llvm::zip(op->getParentRegion()->getOps(), cloned_ops)) {
+    if (&original_op == op) {
+      return &cloned_op;
+    }
+  }
+
+  llvm_unreachable("Op not found in clone.");
+}
+
+std::pair<OwningOpRef<ModuleOp>, Operation*> CloneModuleFor(Operation* op) {
+  auto module = op->getParentOfType<ModuleOp>().clone();
+  return {OwningOpRef<ModuleOp>{module}, FindInClone(op, module)};
+}
+
+namespace detail {
+
+DenseMap<StringRef, std::function<CandidateVector(BisectState&, Operation*)>>&
+GetStrategies() {
+  static auto* strategies =
+      new DenseMap<StringRef,
+                   std::function<CandidateVector(BisectState&, Operation*)>>();
+  return *strategies;
+}
+
+void RegisterReduceStrategy(
+    StringRef name,
+    std::function<CandidateVector(BisectState&, Operation*)> fn) {
+  GetStrategies()[name] = fn;
+}
+
+CandidateVector GetCandidates(
+    const std::function<CandidateVector(BisectState&, Operation*)>& strategy,
+    BisectState& state, ModuleOp op) {
+  assert(strategy && "GetCandidates was passed a null strategy");
+  CandidateVector result;
+  op.lookupSymbol("main")->walk([&](Operation* subOp) {
+    llvm::move(strategy(state, subOp), std::back_inserter(result));
+  });
+  return result;
+}
+
+}  // namespace detail
+}  // namespace bisect
+}  // namespace mlir
diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_bisect/bisect_lib.h b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/bisect_lib.h
new file mode 100644
index 00000000000..c57b49f3fd0
--- /dev/null
+++ b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/bisect_lib.h
@@ -0,0 +1,96 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_MLIR_TOOLS_MLIR_BISECT_BISECT_LIB_H_ +#define TENSORFLOW_COMPILER_XLA_MLIR_TOOLS_MLIR_BISECT_BISECT_LIB_H_ + +#include +#include +#include + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/xla/mlir/tools/mlir_replay/public/execution_trace.pb.h" +#include "tensorflow/compiler/xla/mlir/tools/mlir_replay/public/execution_trace_utils.h" + +#define REGISTER_MLIR_REDUCE_STRATEGY(name) \ + static int name##_init = []() { \ + ::mlir::bisect::detail::RegisterReduceStrategy(#name, name); \ + return 1; \ + }(); + +namespace mlir { +namespace bisect { + +class BisectState { + public: + void SetTrace(mlir::interpreter::ExecutionTrace trace) { + trace_ = std::move(trace); + } + + // Returns all executions of the given op. + llvm::SmallVector GetExecutions( + mlir::Operation* op) const { + return interpreter::FindOpExecutionsInTrace(trace_, op); + } + + private: + mlir::interpreter::ExecutionTrace trace_; +}; + +std::pair, Operation*> CloneModuleFor(Operation* op); +Operation* FindInClone(Operation* op, ModuleOp clone); + +template +std::pair, Op> CloneModuleFor(Op op) { + auto [module, op_clone] = CloneModuleFor(op.getOperation()); + return {std::move(module), llvm::cast(op_clone)}; +} + +namespace detail { + +using CandidateVector = SmallVector>; + +CandidateVector GetCandidates( + const std::function& strategy, + BisectState& state, ModuleOp op); + +DenseMap>& +GetStrategies(); + +// Registers a strategy that applies to all ops. +void RegisterReduceStrategy( + StringRef name, + std::function fn); + +// Registers a strategy that applies to specific ops. +template +void RegisterReduceStrategy(StringRef name, + CandidateVector (*fn)(BisectState&, Op)) { + RegisterReduceStrategy( + name, [fn](BisectState& state, Operation* op) -> CandidateVector { + if (auto cast = llvm::dyn_cast(op)) { + return fn(state, cast); + } + return {}; + }); +} + +} // namespace detail + +} // namespace bisect +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_XLA_MLIR_TOOLS_MLIR_BISECT_BISECT_LIB_H_ diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_bisect/mlir_bisect.cc b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/mlir_bisect.cc new file mode 100644 index 00000000000..aca37f2ceea --- /dev/null +++ b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/mlir_bisect.cc @@ -0,0 +1,346 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/Verifier.h" // from @llvm-project +#include "mlir/InitAllDialects.h" // from @llvm-project +#include "mlir/InitAllPasses.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/FileUtilities.h" // from @llvm-project +#include "mlir/Tools/ParseUtilities.h" // from @llvm-project +#include "tensorflow/compiler/xla/mlir/runtime/ir/rt_dialect.h" +#include "tensorflow/compiler/xla/mlir/tools/mlir_bisect/bisect_lib.h" +#include "tensorflow/compiler/xla/mlir/tools/mlir_bisect/test_passes.h" +#include "tensorflow/compiler/xla/mlir/tools/mlir_replay/public/execution_trace_utils.h" +#include "tensorflow/compiler/xla/mlir_hlo/gml_st/IR/gml_st_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/gml_st/interfaces/bufferizable_op_interface_impl.h" +#include "tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/test_passes.h" +#include "tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/register.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir_hlo/thlo/IR/thlo_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/thlo/transforms/passes.h" +#include "tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/framework/interpreter.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/tsl/platform/env.h" +#include "tensorflow/tsl/platform/init_main.h" + +struct Options { + llvm::cl::opt input_filename{llvm::cl::Positional, + llvm::cl::desc(""), + llvm::cl::init("-")}; + llvm::cl::opt hlo_snapshot{ + "hlo-snapshot", + llvm::cl::desc( + "If set, get argument values from the given snapshot. If not set, " + "the input function must not have any arguments."), + llvm::cl::init("")}; + llvm::cl::opt debug_strategy{ + "debug-strategy", + llvm::cl::desc("If set, print all reductions for the given strategy and " + "exit. For testing."), + llvm::cl::init("")}; + llvm::cl::opt expected_error{ + "expected-error", + llvm::cl::desc("If set, expect the given error message after applying " + "the pass instead of a successful execution."), + llvm::cl::init("")}; + llvm::cl::opt max_steps_per_run{ + "max-steps-per-run", + llvm::cl::desc("Maximum number of steps to execute for each attempt."), + llvm::cl::init(100000)}; + mlir::PassPipelineCLParser pass_pipeline{"", "Passes to run"}; + llvm::cl::opt canonicalize{ + "enable-canonicalization", + llvm::cl::desc("If set, canonicalize candidates before trying them. 
Set " + "to false if you're bisecting --canonicalize."), + llvm::cl::init(true)}; +}; + +namespace mlir { +namespace bisect { +namespace { + +OwningOpRef ParseMlirInput(llvm::StringRef inputFilename, + MLIRContext* context) { + std::string error_message; + auto file = mlir::openInputFile(inputFilename, &error_message); + if (!file) { + llvm::errs() << error_message << "\n"; + return {}; + } + + auto source_mgr = std::make_shared(); + source_mgr->AddNewSourceBuffer(std::move(file), SMLoc()); + return { + llvm::cast(parseSourceFileForTool(source_mgr, context, + /*insertImplicitModule=*/true) + .release())}; +} + +LogicalResult RunPipeline(ModuleOp module, const Options& options) { + if (!options.pass_pipeline.hasAnyOccurrences()) { + return mlir::success(); + } + + auto error_handler = [&](const Twine& msg) { + llvm::errs() << msg << "\n"; + return failure(); + }; + PassManager pm(module.getContext()); + if (failed(options.pass_pipeline.addToPipeline(pm, error_handler)) || + failed(pm.run(module))) { + llvm::errs() << "pipeline failed\n"; + return failure(); + } + return success(); +} + +LogicalResult Run(ModuleOp module, interpreter::ExecutionTrace* trace, + const Options& options) { + SymbolTable symbol_table{module}; + interpreter::ExecutionTraceListener tracer(trace); + interpreter::InterpreterOptions interpreter_options; + interpreter_options.listener = &tracer; + interpreter_options.maxSteps = options.max_steps_per_run; + auto results_before_pass = interpreter::runInterpreter( + symbol_table, llvm::cast(symbol_table.lookup("main")), {}, + interpreter_options); + + if (!succeeded(results_before_pass)) { + llvm::errs() << "Interpreter failed\n"; + return failure(); + } + + if (!options.debug_strategy.empty()) { + return success(); + } + + OwningOpRef clone(module.clone()); + if (!succeeded(RunPipeline(*clone, options))) { + return failure(); + } + + SymbolTable symbol_table_after{*clone}; + interpreter_options.listener = nullptr; + bool found_expected_error = false; + if (!options.expected_error.empty()) { + auto original_handler = interpreter_options.errorHandler; + interpreter_options.errorHandler = [&](llvm::StringRef failure) { + found_expected_error |= + failure.find(options.expected_error) != std::string::npos; + original_handler(failure); + }; + } + + auto results_after_pass = interpreter::runInterpreter( + symbol_table_after, + llvm::cast(symbol_table_after.lookup("main")), {}, + interpreter_options); + + if (!succeeded(results_after_pass)) { + if (found_expected_error) { + return success(); + } + llvm::errs() << "Interpreter failed\n"; + return failure(); + } else if (!options.expected_error.empty()) { + llvm::errs() << "Expected error not seen\n"; + return failure(); + } + + // If the results are the same, the bug is no longer present. 
+ if (*results_before_pass == *results_after_pass) { + return failure(); + } + + llvm::errs() << "results before:\n"; + for (auto& result : *results_before_pass) { + llvm::errs() << " " << result.toString() << "\n"; + } + llvm::errs() << "\nresults after:\n"; + for (auto& result : *results_after_pass) { + llvm::errs() << " " << result.toString() << "\n"; + } + + return success(); +} + +LogicalResult Canonicalize(ModuleOp module) { + PassManager pm(module.getContext()); + pm.addPass(createCanonicalizerPass()); + return pm.run(module.getOperation()); +} + +OwningOpRef ReduceModule(OwningOpRef module, + BisectState& state, const Options& options) { + auto strategies = llvm::to_vector(mlir::bisect::detail::GetStrategies()); + + auto apply_step = [&]() -> std::optional> { + for (auto it = strategies.begin(); it != strategies.end(); ++it) { + for (auto& candidate : + detail::GetCandidates(it->second, state, *module)) { + if (!mlir::verify(*candidate).succeeded()) { + continue; + } + if (options.canonicalize && !Canonicalize(*candidate).succeeded()) { + continue; + } + + interpreter::ExecutionTrace trace; + // Verify that the candidate is still buggy. + if (!Run(*candidate, &trace, options).succeeded()) { + continue; + } + + // Print the new buggy module. + llvm::outs() << "module after " << it->first << ":\n" + << *candidate << "\n\n"; + + // Update the trace. + state.SetTrace(trace); + + // Move failed strategies to the end. + decltype(strategies) new_strategies; + std::copy(it, strategies.end(), std::back_inserter(new_strategies)); + std::copy(strategies.begin(), it, std::back_inserter(new_strategies)); + strategies = new_strategies; + return {std::move(candidate)}; + } + } + return std::nullopt; + }; + + while (auto new_module = apply_step()) { + module = std::move(*new_module); + } + return module; +} + +void ReplaceArgsWithConstants(ModuleOp module, + const xla::HloSnapshot& snapshot) { + auto main = llvm::cast(module.lookupSymbol("main")); + OpBuilder b(main.getBody()); + for (auto [arg, bbarg] : + llvm::zip(snapshot.arguments(), main.getBody().getArguments())) { + auto attr = interpreter::ValueToAttribute( + *interpreter::LiteralToValue(*xla::Literal::CreateFromProto(arg)), + bbarg.getType()); + CHECK_EQ(attr.size(), 1) << "unsupported argument"; + + bbarg.replaceAllUsesWith(b.create( + main.getLoc(), attr.front(), bbarg.getType())); + } + while (main.getBody().getNumArguments() > 0) { + main.getBody().eraseArgument(0); + } + main.setFunctionType(FunctionType::get(main.getContext(), /*inputs=*/{}, + main.getFunctionType().getResults())); + main.setArgAttrsAttr(b.getArrayAttr({})); +} + +} // namespace +} // namespace bisect +} // namespace mlir + +int main(int argc, char* argv[]) { + llvm::errs().tie(&llvm::outs()); + llvm::outs().tie(&llvm::errs()); + int dummy_argc = 1; + tsl::port::InitMain("", &dummy_argc, &argv); + + Options options; + llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR bisect tool\n"); + + mlir::DialectRegistry registry; + mlir::registerAllDialects(registry); + mlir::registerAllPasses(); + mlir::bisect::test::RegisterTestPasses(); + mlir::mhlo::registerAllMhloPasses(); + mlir::lmhlo::registerAllLmhloPasses(); + mlir::thlo::registerAllThloPasses(); + mlir::gml_st::registerGmlStPasses(); + mlir::gml_st::registerGmlStTestPasses(); + mlir::gml_st::registerBufferizableOpInterfaceExternalModels(registry); + mlir::mhlo::registerAllMhloDialects(registry); + + registry.insert(); + + mlir::MLIRContext context(registry); + auto module = 
mlir::bisect::ParseMlirInput(options.input_filename, &context); + + if (!options.hlo_snapshot.empty()) { + xla::HloSnapshot snapshot; + TF_CHECK_OK(tsl::ReadBinaryProto(tsl::Env::Default(), options.hlo_snapshot, + &snapshot)); + mlir::bisect::ReplaceArgsWithConstants(*module, snapshot); + } + + if (options.debug_strategy.empty()) { + llvm::outs() << "initial module:\n" << *module << "\n"; + } + + mlir::interpreter::ExecutionTrace trace; + if (!mlir::bisect::Run(*module, &trace, options).succeeded()) { + llvm::outs() << "Did not find bug in initial module\n"; + if (options.pass_pipeline.hasAnyOccurrences() && + mlir::succeeded(mlir::bisect::RunPipeline(*module, options))) { + llvm::outs() << "Module after running pipeline:\n" << *module << "\n"; + } + return 1; + } + + mlir::bisect::BisectState state; + state.SetTrace(trace); + if (!options.debug_strategy.empty()) { + bool some_failed = false; + for (auto& candidate : mlir::bisect::detail::GetCandidates( + mlir::bisect::detail::GetStrategies()[options.debug_strategy], + state, *module)) { + llvm::outs() << *candidate << "\n\n"; + if (!mlir::verify(*candidate).succeeded()) { + some_failed = true; + llvm::errs() << "verification failed\n"; + } + } + return some_failed ? 1 : 0; + } + + module = mlir::bisect::ReduceModule(std::move(module), state, options); + + llvm::outs() << "Final module:\n" << *module << "\n"; + if (options.pass_pipeline.hasAnyOccurrences() && + mlir::succeeded(mlir::bisect::RunPipeline(*module, options))) { + llvm::outs() << "Final module after running pipeline:\n" << *module << "\n"; + } + return 0; +} diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/BUILD b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/BUILD new file mode 100644 index 00000000000..ae56591cbf7 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/BUILD @@ -0,0 +1,24 @@ +load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") + +cc_library( + name = "rewrites", + srcs = [ + "func.cc", + "general.cc", + "gml_st.cc", + "scf.cc", + ], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla/mlir/tools/mlir_bisect:bisect_lib", + "//tensorflow/compiler/xla/mlir/tools/mlir_replay/public:execution_trace_utils", + "//tensorflow/compiler/xla/mlir_hlo:gml_st", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", + ], + alwayslink = 1, +) diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/func.cc b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/func.cc new file mode 100644 index 00000000000..08e5306f2a1 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/func.cc @@ -0,0 +1,75 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include
+#include
+#include
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "tensorflow/compiler/xla/mlir/tools/mlir_bisect/bisect_lib.h"
+
+namespace mlir {
+namespace bisect {
+namespace {
+
+void SetReturnValues(func::FuncOp func, ValueRange values) {
+  // We only operate on functions without arguments.
+  func.setFunctionType(mlir::FunctionType::get(func.getContext(), /*inputs=*/{},
+                                               values.getTypes()));
+  func.getBody().getBlocks().front().getTerminator()->setOperands(values);
+}
+
+SmallVector<OwningOpRef<ModuleOp>> TruncateFunction(BisectState&,
+                                                    func::FuncOp func) {
+  SmallVector<OwningOpRef<ModuleOp>> result;
+  for (auto& ret : func.getBody().getBlocks().front().without_terminator()) {
+    if (func.getBody().getBlocks().front().getTerminator()->getOperands() ==
+        ret.getResults()) {
+      continue;
+    }
+    auto [module, ret_clone] = CloneModuleFor(&ret);
+    SetReturnValues(ret_clone->getParentOfType<func::FuncOp>(),
+                    ret_clone->getResults());
+    result.push_back(std::move(module));
+  }
+  return result;
+}
+
+SmallVector<OwningOpRef<ModuleOp>> ReturnOperandsOfTerminatorOperands(
+    BisectState&, func::FuncOp func) {
+  SmallVector<OwningOpRef<ModuleOp>> result;
+  auto [module, func_clone] = CloneModuleFor(func);
+  auto* terminator = func_clone.getBody().getBlocks().front().getTerminator();
+  SmallVector<Value> new_operands;
+  for (auto operand : terminator->getOperands()) {
+    if (operand.getDefiningOp()) {
+      llvm::copy(operand.getDefiningOp()->getOperands(),
+                 std::back_inserter(new_operands));
+    } else {
+      return result;
+    }
+  }
+  SetReturnValues(func_clone, new_operands);
+  result.push_back(std::move(module));
+  return result;
+}
+
+REGISTER_MLIR_REDUCE_STRATEGY(TruncateFunction);
+REGISTER_MLIR_REDUCE_STRATEGY(ReturnOperandsOfTerminatorOperands);
+
+}  // namespace
+}  // namespace bisect
+}  // namespace mlir
diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/general.cc b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/general.cc
new file mode 100644
index 00000000000..24368d320e8
--- /dev/null
+++ b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/general.cc
@@ -0,0 +1,173 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/xla/mlir/tools/mlir_bisect/bisect_lib.h" +#include "tensorflow/compiler/xla/mlir/tools/mlir_replay/public/execution_trace_utils.h" + +namespace mlir { +namespace bisect { +namespace { + +bool IsTerminator(Operation* op) { + return op->hasTrait(); +} + +bool IsTopLevelOp(Operation* op) { + return !op->getBlock()->back().mightHaveTrait(); +} + +SmallVector> EraseOpWithoutResults(BisectState& state, + Operation* op) { + // Only erase ops with results if they're unused. + if (op->getNumResults() > 0 && !op->use_empty()) { + return {}; + } + + // Don't erase entire functions, constants, terminators. + if (IsTopLevelOp(op) || IsTerminator(op)) { + return {}; + } + + auto [module, cloned_op] = CloneModuleFor(op); + cloned_op->erase(); + SmallVector> ret; + ret.push_back(std::move(module)); + return ret; +} + +llvm::SmallVector> ReplaceOpWithConstant( + BisectState& state, Operation* op) { + llvm::SmallVector> result; + if (op->hasTrait() || IsTopLevelOp(op) || + IsTerminator(op) || op->use_empty() || op->getNumResults() == 0) { + return result; + } + + // Ops that are never executed won't be replaced here, but we have other + // strategies that get rid of them (e.g. deleting the entire region). + for (auto* execution : state.GetExecutions(op)) { + assert(execution->results_size() == op->getNumResults() && + "unexpected number of results"); + + auto [module_clone, op_clone] = CloneModuleFor(op); + SmallVector results; + OpBuilder b(op_clone); + bool all_replaced = true; + for (int64_t i = 0; i < op->getNumResults(); ++i) { + auto type = op->getResultTypes()[i]; + auto value = *interpreter::TracedValueToValue( + execution->results(static_cast(i))); + auto attribute = interpreter::ValueToAttribute(value, type); + if (attribute.size() == 1) { + op_clone->getResults()[i].replaceAllUsesWith( + b.create(op_clone->getLoc(), attribute.front(), + type)); + } else { + // We don't currently support tuples. + all_replaced = false; + } + } + if (all_replaced) { + result.push_back(std::move(module_clone)); + } + } + return result; +} + +llvm::SmallVector> ReplaceOperandWithConstant( + BisectState& state, Operation* op) { + llvm::SmallVector> result; + if (IsTopLevelOp(op) || op->getNumOperands() == 0) { + return result; + } + + for (auto* execution : state.GetExecutions(op)) { + for (int64_t i = 0; i < op->getNumOperands(); ++i) { + auto operand = op->getOperand(i); + if (operand.getDefiningOp() && + operand.getDefiningOp()->hasTrait()) { + continue; + } + auto type = op->getOperandTypes()[i]; + auto value = *interpreter::TracedValueToValue( + execution->args(static_cast(i))); + auto attribute = interpreter::ValueToAttribute(value, type); + if (attribute.size() == 1) { + auto [module_clone, op_clone] = CloneModuleFor(op); + OpBuilder b(op_clone); + op_clone->setOperand( + i, b.create(op_clone->getLoc(), + attribute.front(), type)); + result.push_back(std::move(module_clone)); + } + } + } + return result; +} + +// Replaces an op's result with some other value with the same type defined +// previously in the same region. 
+llvm::SmallVector> ReplaceOpWithValue(BisectState&, + Operation* op) { + llvm::SmallVector> ret; + if (op->hasTrait() || IsTopLevelOp(op) || + IsTerminator(op)) { + return ret; + } + + // TODO(jreiffers): Consider bbargs. + llvm::DenseMap>> + candidates_by_type; + for (auto* pred = op->getPrevNode(); pred != nullptr; + pred = pred->getPrevNode()) { + for (auto [index, result] : llvm::enumerate(pred->getResults())) { + candidates_by_type[result.getType()].emplace_back(pred, index); + } + } + + for (auto [index, result] : llvm::enumerate(op->getResults())) { + if (result.use_empty()) { + continue; + } + + for (auto [new_result_op, new_result_index] : + candidates_by_type[result.getType()]) { + auto [module_clone, op_clone] = CloneModuleFor(op); + op_clone->getResults()[index].replaceAllUsesWith( + FindInClone(new_result_op, module_clone.get()) + ->getResults()[new_result_index]); + ret.push_back(std::move(module_clone)); + } + } + return ret; +} + +REGISTER_MLIR_REDUCE_STRATEGY(EraseOpWithoutResults); +REGISTER_MLIR_REDUCE_STRATEGY(ReplaceOpWithConstant); +REGISTER_MLIR_REDUCE_STRATEGY(ReplaceOpWithValue); +REGISTER_MLIR_REDUCE_STRATEGY(ReplaceOperandWithConstant); + +} // namespace +} // namespace bisect +} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/gml_st.cc b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/gml_st.cc new file mode 100644 index 00000000000..140297e1351 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/gml_st.cc @@ -0,0 +1,54 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include
+#include
+
+#include "mlir/Dialect/Arith/IR/Arith.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "tensorflow/compiler/xla/mlir/tools/mlir_bisect/bisect_lib.h"
+#include "tensorflow/compiler/xla/mlir_hlo/gml_st/IR/gml_st_ops.h"
+
+namespace mlir {
+namespace bisect {
+namespace {
+
+SmallVector<OwningOpRef<ModuleOp>> ReduceGmlStParallelBounds(
+    BisectState&, gml_st::ParallelOp parallel_op) {
+  SmallVector<OwningOpRef<ModuleOp>> result;
+  for (int64_t i = 0; i < parallel_op.getUpperBound().size(); ++i) {
+    if (!parallel_op.getUpperBound()[i]
+             .getDefiningOp()
+             ->hasTrait<OpTrait::ConstantLike>()) {
+      continue;
+    }
+
+    auto [module, op] = CloneModuleFor(parallel_op);
+    OpBuilder b(op);
+    op.getUpperBoundMutable().slice(i, 1).assign(
+        b.createOrFold<arith::SubIOp>(
+            op->getLoc(), op.getUpperBound()[i],
+            b.create<arith::ConstantIndexOp>(op->getLoc(), 1)));
+    result.push_back(std::move(module));
+  }
+  return result;
+}
+
+REGISTER_MLIR_REDUCE_STRATEGY(ReduceGmlStParallelBounds);
+
+}  // namespace
+}  // namespace bisect
+}  // namespace mlir
diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/scf.cc b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/scf.cc
new file mode 100644
index 00000000000..4dcbd911106
--- /dev/null
+++ b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/scf.cc
@@ -0,0 +1,102 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mlir/Dialect/SCF/IR/SCF.h"  // from @llvm-project
+
+#include  // NOLINT
+#include  // NOLINT
+
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "tensorflow/compiler/xla/mlir/tools/mlir_bisect/bisect_lib.h"
+
+namespace mlir {
+namespace bisect {
+namespace {
+
+constexpr int64_t kMaxWhileIterations = 1;
+
+// Rewrites a while loop to execute its body a fixed number of times. The
+// condition is executed, but its result is ignored.
+// For ease of implementation, this generates scf.execute_region ops. These are
+// subsequently canonicalized away.
+llvm::SmallVector<OwningOpRef<ModuleOp>> InlineScfWhile(BisectState&,
+                                                        scf::WhileOp whileOp) {
+  llvm::SmallVector<OwningOpRef<ModuleOp>> result;
+  for (int64_t num_executions = 0; num_executions <= kMaxWhileIterations;
+       ++num_executions) {
+    using ::mlir::scf::ExecuteRegionOp;
+
+    auto [module, op] = CloneModuleFor(whileOp);
+    OpBuilder b(op);
+    llvm::SmallVector<ExecuteRegionOp> regions;
+
+    auto wrap_region_in_execute = [&, loc = op.getLoc()](mlir::Region& region) {
+      regions
+          .emplace_back(b.create<ExecuteRegionOp>(
+              loc,
+              region.getBlocks().front().getTerminator()->getOperandTypes(),
+              mlir::ValueRange{}))
+          .getRegion()
+          .takeBody(region);
+    };
+
+    wrap_region_in_execute(op.getBefore());
+    // Replace the condition terminator with a yield terminator.
+ { + auto& before_block = regions[0].getRegion().getBlocks().front(); + OpBuilder before_builder(before_block.getTerminator()); + IRRewriter before_rewriter(before_builder); + before_rewriter.replaceOpWithNewOp( + before_block.getTerminator(), + before_block.getTerminator()->getOperands()); + } + + // Clone the execute region ops the requested number of times. + if (num_executions > 0) { + wrap_region_in_execute(op.getAfter()); + for (int64_t i = 0; i < num_executions - 1; ++i) { + b.insert(regions.emplace_back(regions[0].clone())); + b.insert(regions.emplace_back(regions[1].clone())); + } + b.insert(regions.emplace_back(regions[0].clone())); + } + + // Rewire region arguments and erase them. + for (int64_t i = 0; i < regions.size(); ++i) { + auto args = i == 0 ? ValueRange{op.getOperands()} + : ValueRange{regions[i - 1].getResults()}; + bool is_after_region = (i & 1) == 1; + auto& region = regions[i].getRegion(); + for (int64_t arg = static_cast(region.getNumArguments()) - 1; + arg >= 0; --arg) { + region.getArgument(arg).replaceAllUsesWith( + args[is_after_region ? arg + 1 : arg]); + region.eraseArgument(arg); + } + } + op->replaceAllUsesWith(regions.back().getResults().drop_front(1)); + op->erase(); + result.push_back(std::move(module)); + } + return result; +} + +REGISTER_MLIR_REDUCE_STRATEGY(InlineScfWhile); + +} // namespace +} // namespace bisect +} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir/transforms/cpu/tests/BUILD b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/BUILD similarity index 58% rename from tensorflow/compiler/xla/mlir/transforms/cpu/tests/BUILD rename to tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/BUILD index b2e46760f92..cba567fb924 100644 --- a/tensorflow/compiler/xla/mlir/transforms/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/BUILD @@ -1,12 +1,14 @@ -load("//tensorflow:tensorflow.default.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +load("//tensorflow/tsl:tsl.default.bzl", "filegroup") package(licenses = ["notice"]) glob_lit_tests( data = [":test_utilities"], - driver = "//tensorflow/compiler/mlir:run_lit.sh", - test_file_exts = ["mlir"], + driver = "@llvm-project//mlir:run_lit.sh", + test_file_exts = [ + "mlir", + ], ) # Bundle together all of the test utilities that are used by tests. 
@@ -14,8 +16,7 @@ filegroup(
     name = "test_utilities",
     testonly = True,
     data = [
-        "//tensorflow/compiler/xla/mlir/tools:xla-cpu-opt",
+        "//tensorflow/compiler/xla/mlir/tools/mlir_bisect:mlir-bisect",
         "@llvm-project//llvm:FileCheck",
-        "@llvm-project//mlir:run_lit.sh",
     ],
 )
diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/erase-op-without-results.mlir b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/erase-op-without-results.mlir
new file mode 100644
index 00000000000..e918e112fe4
--- /dev/null
+++ b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/erase-op-without-results.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-bisect %s --debug-strategy=EraseOpWithoutResults | FileCheck %s
+
+func.func @main() -> memref<i32> {
+  %a = arith.constant 1 : i32
+  %b = memref.alloc() : memref<i32>
+  memref.store %a, %b[] : memref<i32>
+  func.return %b : memref<i32>
+}
+
+// CHECK: func.func @main()
+// CHECK: %[[ALLOC:.*]] = memref.alloc
+// CHECK-NEXT: return %[[ALLOC]]
diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/inline-scf-while.mlir b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/inline-scf-while.mlir
new file mode 100644
index 00000000000..6c9deddbc37
--- /dev/null
+++ b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/inline-scf-while.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-bisect %s --debug-strategy=InlineScfWhile | FileCheck %s
+
+func.func @main() -> i64 {
+  %c0 = arith.constant 0 : i64
+  %c1 = arith.constant 1 : i64
+  %c4 = arith.constant 4 : i64
+  %alloc = memref.alloc() : memref<i64>
+  memref.store %c0, %alloc[] : memref<i64>
+  %ret = scf.while(%arg0 = %c0): (i64) -> (i64) {
+    %cond = arith.cmpi slt, %arg0, %c4 : i64
+    scf.condition(%cond) %arg0 : i64
+  } do {
+  ^bb0(%arg1: i64):
+    %add = arith.addi %arg1, %c1 : i64
+    scf.yield %add : i64
+  }
+  return %ret : i64
+}
+
+// CHECK: func @main
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4
+// CHECK: %[[RET:.*]]:2 = scf.execute_region
+// CHECK: arith.cmpi slt, %[[C0]], %[[C4]]
+// CHECK: yield {{.*}}, %[[C0]]
+// CHECK: return %[[RET]]#1
+
+// CHECK: func @main
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1
+// CHECK: %[[BEFORE0:.*]]:2 = scf.execute_region
+// CHECK: arith.cmpi
+// CHECK: yield {{.*}}, %[[C0]]
+// CHECK: %[[AFTER:.*]] = scf.execute_region
+// CHECK: %[[ADD:.*]] = arith.addi %[[BEFORE0]]#1, %[[C1]]
+// CHECK: yield %[[ADD]]
+// CHECK: %[[BEFORE1:.*]]:2 = scf.execute_region
+// CHECK: arith.cmpi
+// CHECK: yield {{.*}}, %[[AFTER]]
+// CHECK: return %[[BEFORE1]]#1
\ No newline at end of file
diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/reduce-gml-st-parallel-bounds.mlir b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/reduce-gml-st-parallel-bounds.mlir
new file mode 100644
index 00000000000..d45683ddda4
--- /dev/null
+++ b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/reduce-gml-st-parallel-bounds.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-bisect %s --debug-strategy=ReduceGmlStParallelBounds | FileCheck %s
+
+func.func @main() -> tensor<8xindex> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c8 = arith.constant 8 : index
+  %init = tensor.empty() : tensor<8xindex>
+  %iota = gml_st.parallel (%i) = (%c0) to (%c8) step (%c1)
+      outs (%init_ = %init: tensor<8xindex>) {
+    %tile = gml_st.tile [%i] [1] [1] : !gml_st.tile<1>
+    gml_st.set_yield %i into %init_[%tile]
+      : index into tensor<8xindex>[!gml_st.tile<1>]
+  } 
: tensor<8xindex> + func.return %iota : tensor<8xindex> +} + +// CHECK: func @main() +// CHECK: %[[C7:.*]] = arith.constant 7 +// CHECK: gml_st.parallel {{.*}} to (%[[C7]]) diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/replace-op-with-constant.mlir b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/replace-op-with-constant.mlir new file mode 100644 index 00000000000..171472ad733 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/replace-op-with-constant.mlir @@ -0,0 +1,26 @@ +// RUN: mlir-bisect %s --debug-strategy=ReplaceOpWithConstant | FileCheck %s + +func.func @main() -> tensor<2xi32> { + %a = arith.constant dense<3> : tensor<2xi32> + %b = arith.constant dense<2> : tensor<2xi32> + %c = mhlo.add %a, %b : tensor<2xi32> + %d = mhlo.multiply %b, %c : tensor<2xi32> + func.return %d : tensor<2xi32> +} + +// CHECK: func.func @main() +// CHECK-NEXT: arith.constant dense<3> +// CHECK-NEXT: arith.constant dense<2> +// CHECK-NEXT: arith.constant dense<5> +// CHECK-NEXT: %[[ADD:.*]] = mhlo.add +// CHECK-NOT: %[[ADD]] +// CHECK-NEXT: mhlo.multiply +// CHECK-NEXT: return + +// CHECK: func.func @main() +// CHECK-NEXT: arith.constant dense<3> +// CHECK-NEXT: arith.constant dense<2> +// CHECK-NEXT: mhlo.add +// CHECK-NEXT: %[[D:.*]] = arith.constant dense<10> +// CHECK-NEXT: mhlo.multiply +// CHECK-NEXT: return %[[D]] diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/replace-op-with-value.mlir b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/replace-op-with-value.mlir new file mode 100644 index 00000000000..f89f647f14d --- /dev/null +++ b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/replace-op-with-value.mlir @@ -0,0 +1,16 @@ +// RUN: mlir-bisect %s --debug-strategy=ReplaceOpWithValue | FileCheck %s + +func.func @main() -> (memref, memref) { + %a = memref.alloc() : memref + %b = memref.alloc() : memref + %c0 = arith.constant 0 : i32 + memref.store %c0, %b[] : memref + return %a, %b : memref, memref +} + +// CHECK: func @main() +// CHECK: %[[ALLOC:.*]] = memref.alloc() +// CHECK-NEXT: memref.alloc +// CHECK-NEXT: constant +// CHECK-NEXT: memref.store {{.*}}, %[[ALLOC]] +// CHECK-NEXT: return %[[ALLOC]], %[[ALLOC]] diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/replace-operand-with-constant.mlir b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/replace-operand-with-constant.mlir new file mode 100644 index 00000000000..7619a8a500c --- /dev/null +++ b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/replace-operand-with-constant.mlir @@ -0,0 +1,28 @@ +// RUN: mlir-bisect %s --debug-strategy=ReplaceOperandWithConstant | FileCheck %s + +func.func @main() -> (tensor<2xi32>, tensor<2xi32>) { + %a = arith.constant dense<3> : tensor<2xi32> + %b = arith.constant dense<2> : tensor<2xi32> + %c = mhlo.add %a, %b : tensor<2xi32> + %d = mhlo.multiply %b, %c : tensor<2xi32> + func.return %c, %d : tensor<2xi32>, tensor<2xi32> +} + +// CHECK: func @main() +// CHECK: %[[C2:.*]] = arith.constant dense<2> +// CHECK: %[[ADD:.*]] = mhlo.add +// CHECK: %[[C5:.*]] = arith.constant dense<5> +// CHECK: %[[MUL:.*]] = mhlo.multiply %[[C2]], %[[C5]] : tensor<2xi32> +// CHECK: return %[[ADD]], %[[MUL]] + +// CHECK: func @main() +// CHECK: mhlo.add +// CHECK: %[[MUL:.*]] = mhlo.multiply %cst_0, %0 : tensor<2xi32> +// CHECK: %[[C5:.*]] = arith.constant dense<5> +// CHECK: return %[[C5]], %[[MUL]] + +// CHECK: func @main() +// CHECK: %[[ADD:.*]] = 
mhlo.add +// CHECK: mhlo.multiply +// CHECK: %[[C10:.*]] = arith.constant dense<10> +// CHECK: return %[[ADD]], %[[C10]] diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/return-operands-of-terminator-operands.mlir b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/return-operands-of-terminator-operands.mlir new file mode 100644 index 00000000000..8584e2a0008 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/return-operands-of-terminator-operands.mlir @@ -0,0 +1,15 @@ +// RUN: mlir-bisect %s --debug-strategy=ReturnOperandsOfTerminatorOperands | FileCheck %s + +func.func @main() -> tensor<2xi32> { + %a = arith.constant dense<3> : tensor<2xi32> + %b = arith.constant dense<2> : tensor<2xi32> + %c = mhlo.add %a, %b : tensor<2xi32> + %d = mhlo.multiply %b, %c : tensor<2xi32> + func.return %d : tensor<2xi32> +} + +// CHECK: @main +// CHECK: %[[C2:.*]] = arith.constant dense<2> +// CHECK: %[[ADD:.*]] = mhlo.add +// CHECK: mhlo.multiply +// CHECK: return %[[C2]], %[[ADD]] \ No newline at end of file diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/truncate-function.mlir b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/truncate-function.mlir new file mode 100644 index 00000000000..af06778bd47 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/rewrites/tests/truncate-function.mlir @@ -0,0 +1,31 @@ +// RUN: mlir-bisect %s --debug-strategy=TruncateFunction | FileCheck %s + +// Function to prevent constant folding below. +func.func private @cst() -> tensor<2xi32> { + %cst = arith.constant dense<2> : tensor<2xi32> + return %cst : tensor<2xi32> +} + +func.func @main() -> tensor<2xi32> { + %a = arith.constant dense<1> : tensor<2xi32> + %b = func.call @cst() : () -> tensor<2xi32> + %c = mhlo.add %a, %b : tensor<2xi32> + %d = mhlo.multiply %b, %c : tensor<2xi32> + func.return %d : tensor<2xi32> +} + +// CHECK: func @main() +// CHECK: %[[A:.*]] = arith.constant dense<1> +// CHECK: return %[[A]] + +// CHECK: func @main() +// CHECK: %[[B:.*]] = call @cst() +// CHECK: return %[[B]] + +// CHECK: func @main() +// CHECK: %[[A:.*]] = arith.constant dense<1> +// CHECK: %[[B:.*]] = call @cst() +// CHECK: %[[ADD:.*]] = mhlo.add +// CHECK-DAG: %[[A]] +// CHECK-DAG: %[[B]] +// CHECK: return %[[ADD]] diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_bisect/test_passes.cc b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/test_passes.cc new file mode 100644 index 00000000000..28daec2025e --- /dev/null +++ b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/test_passes.cc @@ -0,0 +1,48 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/mlir/tools/mlir_bisect/test_passes.h" + +#include "mlir/Dialect/Linalg/IR/Linalg.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace bisect { +namespace test { +namespace { + +struct BreakLinalgTransposePass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(BreakLinalgTransposePass) + + StringRef getArgument() const final { return "test-break-linalg-transpose"; } + StringRef getDescription() const final { return "breaks linalg transpose"; } + BreakLinalgTransposePass() = default; + + void runOnOperation() override { + getOperation().walk([](linalg::TransposeOp op) { + auto permutation = llvm::to_vector(op.getPermutation()); + std::swap(permutation[0], permutation[1]); + op.setPermutation(permutation); + }); + } +}; +} // namespace + +void RegisterTestPasses() { PassRegistration(); } + +} // namespace test +} // namespace bisect +} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_bisect/test_passes.h b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/test_passes.h new file mode 100644 index 00000000000..84511daec2d --- /dev/null +++ b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/test_passes.h @@ -0,0 +1,29 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_MLIR_TOOLS_MLIR_BISECT_TEST_PASSES_H_ +#define TENSORFLOW_COMPILER_XLA_MLIR_TOOLS_MLIR_BISECT_TEST_PASSES_H_ + +namespace mlir { +namespace bisect { +namespace test { + +void RegisterTestPasses(); + +} +} // namespace bisect +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_XLA_MLIR_TOOLS_MLIR_BISECT_TEST_PASSES_H_ diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_bisect/tests/BUILD b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/tests/BUILD new file mode 100644 index 00000000000..57fad554f46 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/tests/BUILD @@ -0,0 +1,24 @@ +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +load("//tensorflow/tsl:tsl.default.bzl", "filegroup") + +package(licenses = ["notice"]) + +glob_lit_tests( + data = [":test_utilities"], + driver = "@llvm-project//mlir:run_lit.sh", + test_file_exts = [ + "mlir", + ], +) + +# Bundle together all of the test utilities that are used by tests. 
+filegroup( + name = "test_utilities", + testonly = True, + data = [ + "snapshot.mlir.pb", + "//tensorflow/compiler/xla/mlir/tools/mlir_bisect:mlir-bisect", + "@llvm-project//llvm:FileCheck", + "@llvm-project//llvm:not", + ], +) diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_bisect/tests/bisect.mlir b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/tests/bisect.mlir new file mode 100644 index 00000000000..ca839d982c4 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/tests/bisect.mlir @@ -0,0 +1,46 @@ +// RUN: mlir-bisect %s \ +// RUN: --pass-pipeline="builtin.module(test-break-linalg-transpose)" \ +// RUN: --max-steps-per-run=200 \ +// RUN: | FileCheck %s + +func.func @main() -> (memref<2x2xindex>, memref<2x2xindex>) { + %a = memref.alloc() : memref<2x2xindex> + %b = memref.alloc() : memref<2x2xindex> + %c = memref.alloc() : memref<2x2xindex> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + scf.for %i = %c0 to %c2 step %c1 { + scf.for %j = %c0 to %c2 step %c1 { + memref.store %i, %a[%i, %j] : memref<2x2xindex> + memref.store %j, %b[%i, %j] : memref<2x2xindex> + } + } + + %i = scf.while: () -> (index) { + %value = memref.load %a[%c0, %c0] : memref<2x2xindex> + %cond = arith.cmpi slt, %value, %c3 : index + scf.condition(%cond) %value : index + } do { + ^bb0(%_: index): + %value = memref.load %a[%c0, %c0] : memref<2x2xindex> + %add = arith.addi %value, %c1 : index + memref.store %add, %a[%c0, %c0] : memref<2x2xindex> + linalg.transpose ins(%b : memref<2x2xindex>) outs(%c : memref<2x2xindex>) + permutation = [1, 0] + memref.copy %c, %b : memref<2x2xindex> to memref<2x2xindex> + scf.yield + } + + return %a, %b : memref<2x2xindex>, memref<2x2xindex> +} + +// CHECK: Final module +// CHECK: func @main() -> memref<2x2xindex> { +// CHECK-NOT: scf.while +// CHECK-NOT: scf.for +// CHECK: linalg.transpose {{.*}} permutation = [1, 0] + +// CHECK: Final module after running pipeline +// CHECK: linalg.transpose {{.*}} permutation = [0, 1] diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_bisect/tests/no-bug.mlir b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/tests/no-bug.mlir new file mode 100644 index 00000000000..df343f3bf8b --- /dev/null +++ b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/tests/no-bug.mlir @@ -0,0 +1,10 @@ +// RUN: not mlir-bisect %s \ +// RUN: --pass-pipeline="builtin.module(test-break-linalg-transpose)" \ +// RUN: | FileCheck %s + +func.func @main() -> memref<2x2xindex> { + %a = memref.alloc() : memref<2x2xindex> + return %a : memref<2x2xindex> +} + +// CHECK: Did not find bug in initial module diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_bisect/tests/snapshot.mlir b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/tests/snapshot.mlir new file mode 100644 index 00000000000..916ca47ab0f --- /dev/null +++ b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/tests/snapshot.mlir @@ -0,0 +1,12 @@ +// RUN: not mlir-bisect %s --hlo-snapshot=%s.pb \ +// RUN: --pass-pipeline="builtin.module(test-break-linalg-transpose)" \ +// RUN: | FileCheck %s + +func.func @main(%a: tensor<3x1xi32>, %b: tensor<3x1xi32>) -> tensor<3x1xi32> { + return %a : tensor<3x1xi32> +} + +// CHECK: initial module +// CHECK: func @main() -> tensor<3x1xi32> { +// CHECK{LITERAL}: arith.constant dense<[[2], [-4], [5]]> : tensor<3x1xi32> +// CHECK{LITERAL}: arith.constant dense<[[0], [7], [-5]]> : tensor<3x1xi32> diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_bisect/tests/snapshot.mlir.pb 
b/tensorflow/compiler/xla/mlir/tools/mlir_bisect/tests/snapshot.mlir.pb new file mode 100644 index 0000000000000000000000000000000000000000..ee3c8f759494db153cd7114783124b1cb7fb5da0 GIT binary patch literal 68 scmWeq;1UpEkz!(I)MDXcVq`F4Vqj3>VfynQ3K&_1u&Q8S{|#3H00LzcG5`Po literal 0 HcmV?d00001 diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_replay/BUILD b/tensorflow/compiler/xla/mlir/tools/mlir_replay/BUILD new file mode 100644 index 00000000000..ee7345122cb --- /dev/null +++ b/tensorflow/compiler/xla/mlir/tools/mlir_replay/BUILD @@ -0,0 +1,54 @@ +load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_binary") +load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") + +xla_cc_binary( + name = "mlir_replay", + srcs = ["mlir_replay.cc"], + deps = [ + ":mlir_replay_lib", + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla/mlir/runtime/ir:rt", + "//tensorflow/compiler/xla/mlir/tools/mlir_replay/public:compiler_trace_proto_cc", + "//tensorflow/compiler/xla/mlir/tools/mlir_replay/public:compiler_trace_proto_cc_impl", + "//tensorflow/compiler/xla/mlir/tools/mlir_replay/public:execution_trace_proto_cc", + "//tensorflow/compiler/xla/mlir/tools/mlir_replay/public:execution_trace_utils", + "//tensorflow/compiler/xla/mlir_hlo:gml_st", + "//tensorflow/compiler/xla/mlir_hlo:hlo_dialect_registration", + "//tensorflow/compiler/xla/mlir_hlo:lhlo", + "//tensorflow/compiler/xla/mlir_hlo:lhlo_gpu", + "//tensorflow/compiler/xla/mlir_hlo:mlir_interpreter_dialects", + "//tensorflow/compiler/xla/mlir_hlo:mlir_interpreter_framework", + "//tensorflow/compiler/xla/mlir_hlo:thlo", + "//tensorflow/compiler/xla/service:hlo_proto_cc", + "//tensorflow/tsl/platform:env", + "//tensorflow/tsl/platform:path", + "//tensorflow/tsl/platform:platform_port", + "//tensorflow/tsl/platform:status", + "//tensorflow/tsl/util:command_line_flags", + "@com_google_absl//absl/strings:str_format", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "mlir_replay_lib", + srcs = ["mlir_replay_lib.cc"], + hdrs = ["mlir_replay_lib.h"], + deps = [ + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/mlir/tools/mlir_replay/public:execution_trace_proto_cc", + "//tensorflow/compiler/xla/mlir/tools/mlir_replay/public:execution_trace_utils", + "//tensorflow/compiler/xla/mlir_hlo:mlir_interpreter_framework", + "//tensorflow/compiler/xla/service:hlo_proto_cc", + "//tensorflow/tsl/platform:errors", + "//tensorflow/tsl/platform:statusor", + "@com_google_absl//absl/random", + "@com_google_absl//absl/random:bit_gen_ref", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MlirReduceLib", + "@llvm-project//mlir:Support", + ], +) diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_replay/README.md b/tensorflow/compiler/xla/mlir/tools/mlir_replay/README.md new file mode 100644 index 00000000000..89478f5b1bf --- /dev/null +++ b/tensorflow/compiler/xla/mlir/tools/mlir_replay/README.md @@ -0,0 +1,48 @@ +# MLIR Replay tool + +This tool is mainly intended for helping debug miscompiles. It takes as inputs +an HLO snapshot proto with input tensors and a compiler trace proto with the +state of the IR after each pass. + +This tool is built on top of +[mlir-interpreter](https://github.com/tensorflow/mlir-hlo/tree/master/tools/mlir_interpreter/). 
+
+Example usage:
+
+```
+# Run a JAX test with debug flags enabled:
+$ bazel test :some_jax_test --compilation_mode=opt \
+  --test_env=XLA_FLAGS="--xla_cpu_use_xla_runtime --xla_dump_to=/tmp/test-dump --xla_dump_hlo_snapshots" \
+  --test_filter=SomeSpecific.TestCase \
+  --test_sharding_strategy=disabled
+
+# JAX tends to compile many modules, so first check which one is broken:
+./mlir_replay \
+  --mlir_compilation_trace_dir=/tmp/test-dump
+
+Failures for /tmp/test-dump/module_1234.jit_something.mlir-trace.pb:
+  Result mismatch for /tmp/test-dump/module_1234.jit_something.snapshot.56.pb: TensorOrMemref<3xi32>: [1, 2, 3] != TensorOrMemref<3xi32>: [1, 1, 1]
+  run :mlir_replay -- --mlir_compilation_trace=/tmp/test-dump/module_1234.jit_something.mlir-trace.pb --hlo_snapshot=/tmp/test-dump/module_1234.jit_something.snapshot.56.pb --print_changes_only --stop_after_first_failure
+```
+
+There may be multiple failing modules. You can run the provided command to
+replay a particular one:
+
+```
+# Run the IR after each pass. Note that JAX typically compiles many modules, so
+# you may have to check more than one.
+# There is one .mlir-trace.pb file per module (containing the intermediate IR)
+# and one .snapshot.pb file per execution (containing the inputs and outputs).
+$ ./mlir_replay \
+  --mlir_compilation_trace=/tmp/test-dump/module_1234.jit_something.mlir-trace.pb \
+  --hlo_snapshot=/tmp/test-dump/module_1234.jit_something.snapshot.56.pb \
+  --print_changes_only --stop_after_first_failure
+Running IR after APass
+Results: [1, 2, 3]
+
+Running IR after BPass
+Running IR after CPass
+Running IR after BrokenPass
+Results: [1, 1, 1]
+```
+
diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_replay/mlir_replay.cc b/tensorflow/compiler/xla/mlir/tools/mlir_replay/mlir_replay.cc
new file mode 100644
index 00000000000..11a24b294a7
--- /dev/null
+++ b/tensorflow/compiler/xla/mlir/tools/mlir_replay/mlir_replay.cc
@@ -0,0 +1,230 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include +#include +#include +#include + +#include "absl/strings/str_format.h" +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/InitAllDialects.h" // from @llvm-project +#include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/mlir/runtime/ir/rt_dialect.h" +#include "tensorflow/compiler/xla/mlir/tools/mlir_replay/mlir_replay_lib.h" +#include "tensorflow/compiler/xla/mlir/tools/mlir_replay/public/compiler_trace.pb.h" +#include "tensorflow/compiler/xla/mlir/tools/mlir_replay/public/execution_trace.pb.h" +#include "tensorflow/compiler/xla/mlir/tools/mlir_replay/public/execution_trace_utils.h" +#include "tensorflow/compiler/xla/mlir_hlo/gml_st/IR/gml_st_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/register.h" +#include "tensorflow/compiler/xla/mlir_hlo/thlo/IR/thlo_ops.h" +#include "tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/framework/interpreter_value.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/tsl/platform/env.h" +#include "tensorflow/tsl/platform/init_main.h" +#include "tensorflow/tsl/platform/path.h" +#include "tensorflow/tsl/platform/status.h" +#include "tensorflow/tsl/util/command_line_flags.h" + +struct ReplayOptions { + std::string hlo_snapshot; + std::string mlir_compilation_trace; + std::string mlir_compilation_trace_dir = ""; + std::string execution_trace_dir = ""; + std::string entry_point = "main"; + bool print_changes_only = false; + bool stop_after_first_failure = false; +}; + +bool ResultsMatch(const xla::HloSnapshot& snapshot, + const llvm::SmallVector& + first_pass_results, + std::vector& failures) { + auto actual = mlir::interpreter::LiteralToValue(snapshot.result()); + TF_CHECK_OK(actual.status()); + + // We assume this is MHLO, so multiple results will be in a tuple. 
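+  // A tuple comes back as a single interpreter value, so anything other than
+  // exactly one result is unexpected here.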
+ if (first_pass_results.size() != 1) { + failures.push_back("expected one result"); + return false; + } + + if (!(*actual == first_pass_results[0])) { + failures.push_back("result mismatch: " + actual->toString() + + " != " + first_pass_results[0].toString()); + return false; + } + return true; +} + +void TestAll(mlir::MLIRContext& context, const ReplayOptions& opts) { + std::vector traces; + TF_CHECK_OK(tsl::Env::Default()->GetMatchingPaths( + opts.mlir_compilation_trace_dir + "/*.mlir-trace.pb", &traces)); + + for (const auto& trace_path : traces) { + mlir::interpreter::MlirCompilationTrace trace; + TF_CHECK_OK(tsl::ReadBinaryProto(tsl::Env::Default(), trace_path, &trace)) + << "Failed to load " << trace_path; + + std::vector snapshots; + std::string prefix = + trace_path.substr(0, trace_path.length() - strlen(".mlir-trace.pb")); + TF_CHECK_OK(tsl::Env::Default()->GetMatchingPaths(prefix + "*.snapshot.*", + &snapshots)); + CHECK_NE(snapshots.size(), 0) + << "No snapshots found for module " << trace_path << "."; + + std::vector failures; + for (const auto& snapshot_path : snapshots) { + xla::HloSnapshot snapshot; + TF_CHECK_OK( + tsl::ReadBinaryProto(tsl::Env::Default(), snapshot_path, &snapshot)); + + auto results = + mlir::interpreter::Run(context, trace.passes(0).mlir_module(), + snapshot, nullptr, opts.entry_point); + if (!results.status().ok()) { + failures.push_back("Failed to execute " + snapshot_path + ": " + + results.status().ToString()); + } else { + if (!ResultsMatch(snapshot, *results, failures)) { + failures.push_back( + std::string("run :mlir_replay -- --mlir_compilation_trace=") + + trace_path + " --hlo_snapshot=" + snapshot_path + + " --print_changes_only --stop_after_first_failure"); + } + } + } + + if (!failures.empty()) { + llvm::errs() << "Failures for " << trace_path << ":\n " + << absl::StrJoin(failures, "\n ") << "\n"; + } + } +} + +int main(int argc, char* argv[]) { + // Flush llvm::outs before writing errors. + llvm::errs().tie(&llvm::outs()); + + ReplayOptions opts; + std::vector flag_list = { + tsl::Flag("hlo_snapshot", &opts.hlo_snapshot, + "Filename of an HloSnapshot proto. Only used to read inputs."), + tsl::Flag("mlir_compilation_trace", &opts.mlir_compilation_trace, + "Filename of an MlirCompilerTrace proto."), + tsl::Flag("mlir_compilation_trace_dir", &opts.mlir_compilation_trace_dir, + "Directory from which to load MlirCompilerTrace and " + "HloSnapshot protos. The tool will run all snapshots and " + "report the ones with bugs."), + tsl::Flag("execution_trace_dir", &opts.execution_trace_dir, + "Directory where to store the execution traces (optional)."), + tsl::Flag("entry_point", &opts.entry_point, + "Program entry function (optional, defaults to 'main')."), + tsl::Flag("print_changes_only", &opts.print_changes_only, + "If set, only print changed values"), + tsl::Flag("stop_after_first_failure", &opts.stop_after_first_failure, + "If set, stop after the first failed invocation."), + }; + xla::AppendDebugOptionsFlags(&flag_list); + + // The usage string includes the message at the top of the file, the + // DebugOptions flags and the flags defined above. 
+ std::string usage_string = tsl::Flags::Usage(argv[0], flag_list); + if (!tsl::Flags::Parse(&argc, argv, flag_list)) { + return 1; + } + tsl::port::InitMain(usage_string.c_str(), &argc, &argv); + + CHECK(opts.mlir_compilation_trace.empty() != + opts.mlir_compilation_trace_dir.empty()) + << "Exactly one of --mlir_compilation_trace and " + "--mlir_compilation_trace_dir must be specified."; + + CHECK(opts.mlir_compilation_trace_dir.empty() || opts.hlo_snapshot.empty()) + << "If --mlir_compilation_trace_dir is set, --hlo_snapshot must not be."; + + mlir::DialectRegistry registry; + mlir::registerAllDialects(registry); + mlir::mhlo::registerAllMhloDialects(registry); + registry.insert(); + + mlir::MLIRContext context(registry); + + if (!opts.mlir_compilation_trace_dir.empty()) { + TestAll(context, opts); + return 0; + } + + xla::HloSnapshot snapshot; + if (!opts.hlo_snapshot.empty()) { + TF_CHECK_OK(tsl::ReadBinaryProto(tsl::Env::Default(), opts.hlo_snapshot, + &snapshot)); + } + mlir::interpreter::MlirCompilationTrace trace; + TF_CHECK_OK(tsl::ReadBinaryProto(tsl::Env::Default(), + opts.mlir_compilation_trace, &trace)); + + llvm::SmallVector previous_results; + int pass_id = 0; + for (auto& state : trace.passes()) { + llvm::outs() << "Running IR after " << state.after_pass() << ".\n"; + mlir::interpreter::ExecutionTrace execution_trace; + auto results = mlir::interpreter::Run( + context, state.mlir_module(), snapshot, + opts.execution_trace_dir.empty() ? nullptr : &execution_trace, + opts.entry_point); + if (results.status().ok()) { + if (!opts.print_changes_only || (*results != previous_results)) { + llvm::outs() << "Results:\n"; + for (const auto& result : *results) { + llvm::outs() << result.toString() << "\n"; + } + previous_results = *results; + llvm::outs() << "\n"; + } + } else { + llvm::errs() << results.status().ToString() << "\n"; + if (opts.stop_after_first_failure) { + return 1; + } + } + + if (!opts.execution_trace_dir.empty()) { + TF_CHECK_OK( + tsl::Env::Default()->RecursivelyCreateDir(opts.execution_trace_dir)); + std::string filename = tsl::io::JoinPath( + opts.execution_trace_dir, + absl::StrFormat("%.4d.%s.mlir", pass_id, state.after_pass())); + TF_CHECK_OK(tsl::WriteStringToFile(tsl::Env::Default(), filename, + execution_trace.ir())); + + filename = tsl::io::JoinPath( + opts.execution_trace_dir, + absl::StrFormat("%.4d.%s.trace.pb", pass_id, state.after_pass())); + TF_CHECK_OK(tsl::WriteBinaryProto(tsl::Env::Default(), filename, + execution_trace)); + } + ++pass_id; + } + + return 0; +} diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_replay/mlir_replay_lib.cc b/tensorflow/compiler/xla/mlir/tools/mlir_replay/mlir_replay_lib.cc new file mode 100644 index 00000000000..e5b5e445104 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/tools/mlir_replay/mlir_replay_lib.cc @@ -0,0 +1,187 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/mlir/tools/mlir_replay/mlir_replay_lib.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/random/bit_gen_ref.h" +#include "absl/random/random.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Tools/ParseUtilities.h" // from @llvm-project +#include "tensorflow/compiler/xla/mlir/tools/mlir_replay/public/execution_trace_utils.h" +#include "tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/framework/interpreter.h" +#include "tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/framework/interpreter_value.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/tsl/platform/errors.h" +#include "tensorflow/tsl/platform/statusor.h" + +namespace mlir { +namespace interpreter { +namespace { + +tsl::StatusOr> LoadArgs( + const xla::HloSnapshot& snapshot) { + SmallVector result; + for (const auto& arg : snapshot.arguments()) { + TF_ASSIGN_OR_RETURN(auto converted, LiteralToValue(arg)); + result.push_back(std::move(converted)); + } + return result; +} + +namespace { +template class rng_t> +mlir::interpreter::InterpreterValue RandomTensor(absl::BitGenRef bitgen, + mlir::Type type) { + llvm::SmallVector shape; + auto shaped_ty = type.dyn_cast(); + if (shaped_ty) { + shape = llvm::to_vector(shaped_ty.getShape()); + } + + auto rng = rng_t{}; + auto result = mlir::interpreter::TensorOrMemref::empty(shape); + for (const auto& index : result.view.indices()) { + auto& elem = result.at(index) = rng(bitgen); + // Ints are typically indices, so scale them down to a more reasonable + // range. + if constexpr (std::is_same_v) { + elem >>= 60; + } + } + if (shaped_ty) { + return {result}; + } + return {result.at({})}; +} +} // namespace + +mlir::FailureOr MakeRandomInput( + absl::BitGenRef bitgen, mlir::Type type) { + auto elem_ty = + type.isa() ? 
type.cast().getElementType() : type; + if (elem_ty.isF32()) { + return RandomTensor(bitgen, type); + } + if (elem_ty.isF64()) { + return RandomTensor(bitgen, type); + } + if (elem_ty.isInteger(32)) { + return RandomTensor(bitgen, type); + } + if (elem_ty.isInteger(16)) { + return RandomTensor(bitgen, type); + } + if (elem_ty.isInteger(64)) { + return RandomTensor(bitgen, type); + } + if (elem_ty.isInteger(1)) { + return {{TensorOrMemref::empty(type.cast().getShape())}}; + } + + llvm::errs() << "Unsupported type: "; + type.print(llvm::errs()); + llvm::errs() << "\n"; + return failure(); +} + +} // namespace + +tsl::StatusOr> Run( + MLIRContext& context, const std::string& mlir_ir, + const xla::HloSnapshot& snapshot, ExecutionTrace* trace, + const std::string& entry) { + auto sourceMgr = std::make_shared(); + sourceMgr->AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(mlir_ir), + mlir::SMLoc()); + mlir::OwningOpRef module = + mlir::parseSourceFileForTool(sourceMgr, &context, false); + if (!module) { + return tsl::errors::InvalidArgument("failed to parse MLIR"); + } + + SymbolTable symbols(*module); + auto main = llvm::dyn_cast_or_null(symbols.lookup(entry)); + if (!main) { + return tsl::errors::InvalidArgument("failed to find entry function \"" + + entry + "\""); + } + + if (trace) { + llvm::raw_string_ostream os(*trace->mutable_ir()); + (*module)->print(os, OpPrintingFlags().printGenericOpForm()); + } + + // After xla-rt-export-functions, we have an execution context as the first + // argument. The interpreter currently cannot deal with these things, so we + // fail in that case. + auto function_args = main.getBody().getBlocks().front().getArguments(); + if (!llvm::all_of(function_args, [](Value arg) { + return arg.getType().isa(); + })) { + return tsl::errors::InvalidArgument( + "expected all function arguments to be shaped types"); + } + + TF_ASSIGN_OR_RETURN(auto args, LoadArgs(snapshot)); + auto out_args = + main.getBody().getBlocks().front().getArguments().drop_front(args.size()); + + std::seed_seq my_seed_seq({0}); + absl::BitGen bitgen(my_seed_seq); + llvm::SmallVector out_buffers; + // Add random inputs for output arguments and unspecified inputs. + for (auto arg : out_args) { + auto arg_or = MakeRandomInput(bitgen, arg.getType()); + if (!succeeded(arg_or)) { + return tsl::errors::InvalidArgument("failed to create input"); + } + out_buffers.push_back(*arg_or); + args.push_back(*arg_or); + } + + InterpreterOptions options; + ExecutionTraceListener tracer(trace); + if (trace) { + options.listener = &tracer; + } + auto results_or = runInterpreter(symbols, main, args, options); + if (!succeeded(results_or)) { + return tsl::errors::Internal("interpreter failed"); + } + + if (results_or->empty()) { + return out_buffers; + } + return *results_or; +} + +} // namespace interpreter +} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_replay/mlir_replay_lib.h b/tensorflow/compiler/xla/mlir/tools/mlir_replay/mlir_replay_lib.h new file mode 100644 index 00000000000..e4a16f663c8 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/tools/mlir_replay/mlir_replay_lib.h @@ -0,0 +1,39 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_MLIR_TOOLS_MLIR_REPLAY_MLIR_REPLAY_LIB_H_ +#define TENSORFLOW_COMPILER_XLA_MLIR_TOOLS_MLIR_REPLAY_MLIR_REPLAY_LIB_H_ + +#include + +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/xla/mlir/tools/mlir_replay/public/execution_trace.pb.h" +#include "tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/framework/interpreter_value.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/tsl/platform/statusor.h" + +namespace mlir { +namespace interpreter { + +// Runs the given IR on the inputs from `snapshot` and returns the result. +tsl::StatusOr> Run( + MLIRContext& context, const std::string& mlir_ir, + const xla::HloSnapshot& snapshot, ExecutionTrace* trace, + const std::string& entry); + +} // namespace interpreter +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_XLA_MLIR_TOOLS_MLIR_REPLAY_MLIR_REPLAY_LIB_H_ diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_replay/public/BUILD b/tensorflow/compiler/xla/mlir/tools/mlir_replay/public/BUILD new file mode 100644 index 00000000000..b4b88af76c3 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/tools/mlir_replay/public/BUILD @@ -0,0 +1,69 @@ +load("//tensorflow/tsl/platform:build_config.bzl", "tf_proto_library") +load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") +load("//tensorflow/compiler/xla:xla.bzl", "xla_cc_test") + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "compiler_trace_instrumentation", + srcs = ["compiler_trace_instrumentation.cc"], + hdrs = ["compiler_trace_instrumentation.h"], + deps = [ + ":compiler_trace_proto_cc", + ":compiler_trace_proto_cc_impl", + "//tensorflow/tsl/platform:env", + "//tensorflow/tsl/platform:logging", + "//tensorflow/tsl/platform:path", + "//tensorflow/tsl/platform:protobuf", + "@com_google_absl//absl/strings:str_format", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + ], +) + +cc_library( + name = "execution_trace_utils", + srcs = ["execution_trace_utils.cc"], + hdrs = ["execution_trace_utils.h"], + deps = [ + ":execution_trace_proto_cc", + ":execution_trace_proto_cc_impl", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/mlir_hlo:mlir_interpreter_framework", + "//tensorflow/tsl/platform:statusor", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +xla_cc_test( + name = "execution_trace_utils_test", + srcs = ["execution_trace_utils_test.cc"], + deps = [ + ":execution_trace_utils", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla/mlir_hlo:mlir_interpreter_framework", + "//tensorflow/tsl/platform:statusor", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Support", + ], +) + +tf_proto_library( + name = "execution_trace_proto", + srcs = ["execution_trace.proto"], + cc_api_version = 2, + make_default_target_header_only = 
True, +) + +tf_proto_library( + name = "compiler_trace_proto", + srcs = ["compiler_trace.proto"], + cc_api_version = 2, + make_default_target_header_only = True, +) diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_replay/public/README.md b/tensorflow/compiler/xla/mlir/tools/mlir_replay/public/README.md new file mode 100644 index 00000000000..c886abf5ffd --- /dev/null +++ b/tensorflow/compiler/xla/mlir/tools/mlir_replay/public/README.md @@ -0,0 +1,10 @@ +# Public API of mlir_replay + +This contains protocol buffers and utilities that can be reused for other +debugging tools: + +1. **The compiler trace proto**: A record of the state of the IR after each + compilation pass +1. A compiler instrumentation to create the above proto. +1. **The execution trace proto**: A record of SSA values as the IR is executed +1. Utilities for working with the above protos. diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_replay/public/compiler_trace.proto b/tensorflow/compiler/xla/mlir/tools/mlir_replay/public/compiler_trace.proto new file mode 100644 index 00000000000..bf144716208 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/tools/mlir_replay/public/compiler_trace.proto @@ -0,0 +1,31 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package mlir.interpreter; + +message MlirCompilationTraceEntry { + // The name of the pass that was previously executed. + optional string after_pass = 1; + + // MLIR module IR of the state after the pass. + optional string mlir_module = 2; +} + +message MlirCompilationTrace { + // MLIR modules corresponding to each stage of the compilation pipeline. + repeated MlirCompilationTraceEntry passes = 1; +} \ No newline at end of file diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_replay/public/compiler_trace_instrumentation.cc b/tensorflow/compiler/xla/mlir/tools/mlir_replay/public/compiler_trace_instrumentation.cc new file mode 100644 index 00000000000..8c9697723f5 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/tools/mlir_replay/public/compiler_trace_instrumentation.cc @@ -0,0 +1,61 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/mlir/tools/mlir_replay/public/compiler_trace_instrumentation.h" + +#include + +#include "absl/strings/str_format.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/tsl/platform/env.h" +#include "tensorflow/tsl/platform/logging.h" +#include "tensorflow/tsl/platform/path.h" + +namespace mlir { +namespace interpreter { + +void MlirCompilerTraceInstrumentation::runAfterPass(Pass* pass, Operation* op) { + ModuleOp module = llvm::dyn_cast(op); + if (!module) { + module = op->getParentOfType(); + } + if (!module) { + LOG(ERROR) << "Failed to find a ModuleOp: " << pass->getName().str() << "."; + return; + } + + auto* item = trace_.mutable_passes()->Add(); + item->set_after_pass(pass->getName().str()); + llvm::raw_string_ostream os(*item->mutable_mlir_module()); + module.print(os); +} + +MlirCompilerTraceInstrumentation::~MlirCompilerTraceInstrumentation() { + if (!trace_.passes().empty()) { + std::string filename; + absl::StrAppendFormat(&filename, "module_%04d", unique_id_); + if (!module_name_.empty()) { + absl::StrAppend(&filename, ".", module_name_); + } + absl::StrAppend(&filename, ".mlir-trace.pb"); + filename = tsl::io::JoinPath(dirname_, filename); + TF_CHECK_OK(tsl::Env::Default()->RecursivelyCreateDir(dirname_)); + TF_CHECK_OK(tsl::WriteBinaryProto(tsl::Env::Default(), filename, trace_)); + } +} + +} // namespace interpreter +} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_replay/public/compiler_trace_instrumentation.h b/tensorflow/compiler/xla/mlir/tools/mlir_replay/public/compiler_trace_instrumentation.h new file mode 100644 index 00000000000..fa687217005 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/tools/mlir_replay/public/compiler_trace_instrumentation.h @@ -0,0 +1,48 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_MLIR_TOOLS_MLIR_REPLAY_PUBLIC_COMPILER_TRACE_INSTRUMENTATION_H_ +#define TENSORFLOW_COMPILER_XLA_MLIR_TOOLS_MLIR_REPLAY_PUBLIC_COMPILER_TRACE_INSTRUMENTATION_H_ + +#include + +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassInstrumentation.h" // from @llvm-project +#include "tensorflow/compiler/xla/mlir/tools/mlir_replay/public/compiler_trace.pb.h" + +namespace mlir { +namespace interpreter { + +// Instrumentation that logs the state of the IR after each pass. 
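+// Typically attached to a pass manager before running the pipeline, e.g.
+// (sketch, with hypothetical `pm`, `dump_dir`, `module_id`, `module_name`):
+//   pm.addInstrumentation(std::make_unique<MlirCompilerTraceInstrumentation>(
+//       dump_dir, module_id, module_name));
+// The accumulated trace is written to
+// <dirname>/module_<unique_id>[.<module_name>].mlir-trace.pb when the
+// instrumentation is destroyed.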
+class MlirCompilerTraceInstrumentation : public PassInstrumentation { + public: + explicit MlirCompilerTraceInstrumentation(const std::string& dirname, + int unique_id, + const std::string& module_name) + : dirname_(dirname), unique_id_(unique_id), module_name_(module_name) {} + ~MlirCompilerTraceInstrumentation() override; + void runAfterPass(Pass* pass, Operation* op) override; + + private: + MlirCompilationTrace trace_; + std::string dirname_; + int unique_id_; + std::string module_name_; +}; + +} // namespace interpreter +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_XLA_MLIR_TOOLS_MLIR_REPLAY_PUBLIC_COMPILER_TRACE_INSTRUMENTATION_H_ diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_replay/public/execution_trace.proto b/tensorflow/compiler/xla/mlir/tools/mlir_replay/public/execution_trace.proto new file mode 100644 index 00000000000..6be6407505c --- /dev/null +++ b/tensorflow/compiler/xla/mlir/tools/mlir_replay/public/execution_trace.proto @@ -0,0 +1,72 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package mlir.interpreter; + +message TracedValue { + // The shape - includes vector dimensions. + // TODO(jreiffers): Model vector dimensions separately. + repeated int64 shape = 1; + optional bool is_scalar = 2; + + enum ElementType { + UNKNOWN = 0; + INTEGRAL = 1; + UNSIGNED = 2; + FLOAT = 3; + COMPLEX = 4; + TUPLE = 5; + } + + optional int32 bit_width = 3; + optional ElementType element_type = 4; + + repeated float floats = 5 [packed = true]; + repeated double doubles = 6 [packed = true]; + repeated int64 ints = 7 [packed = true]; + repeated uint64 uints = 8 [packed = true]; + repeated TracedValue tuple_elements = 9; +} + +message InstructionTrace { + optional string name = 1; + repeated TracedValue args = 2; + repeated TracedValue results = 3; + // TODO(jreiffers): Model side effects (e.g. memref.store). + + repeated RegionTrace regions = 4; +} + +message RegionTrace { + // The number of the region that is being executed (within the parent op). + // For example: '1' for an scf.while's `after` region. + optional int32 region_number = 1; + // The arguments that were passed to the region. + repeated TracedValue bbargs = 2; + // One instruction per instruction in the region. + repeated InstructionTrace instructions = 3; + repeated TracedValue results = 4; +} + +message ExecutionTrace { + // The IR that was executed. Note: this should always be filled in the generic + // format. + optional string ir = 1; + + // The trace of the entry function execution. 
+ optional RegionTrace trace = 2; +} \ No newline at end of file diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_replay/public/execution_trace_utils.cc b/tensorflow/compiler/xla/mlir/tools/mlir_replay/public/execution_trace_utils.cc new file mode 100644 index 00000000000..bbbc48df149 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/tools/mlir_replay/public/execution_trace_utils.cc @@ -0,0 +1,429 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/mlir/tools/mlir_replay/public/execution_trace_utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "tensorflow/compiler/xla/mlir/tools/mlir_replay/public/execution_trace.pb.h" +#include "tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/framework/interpreter_value.h" +#include "tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/framework/tensor_or_memref.h" +#include "tensorflow/tsl/platform/statusor.h" + +namespace mlir { +namespace interpreter { +namespace { + +// Visitor for converting an InterpreterValue to a TracedValue. 
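+// Applied via std::visit over InterpreterValue::storage (see ValueToTracedValue
+// below); each overload records the element type, shape and flattened contents.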
+struct TraceInterpreterValueVisitor { + TracedValue out; + + void Add(float v) { out.add_floats(v); } + void Add(double v) { out.add_doubles(v); } + void Add(std::complex v) { + out.add_floats(v.real()); + out.add_floats(v.imag()); + } + void Add(std::complex v) { + out.add_doubles(v.real()); + out.add_doubles(v.imag()); + } + void Add(int64_t v) { out.add_ints(v); } + void Add(int32_t v) { out.add_ints(v); } + void Add(int16_t v) { out.add_ints(v); } + void Add(int8_t v) { out.add_ints(v); } + void Add(uint64_t v) { out.add_uints(v); } + void Add(uint32_t v) { out.add_uints(v); } + void Add(uint16_t v) { out.add_uints(v); } + void Add(uint8_t v) { out.add_uints(v); } + void Add(bool v) { out.add_ints(static_cast(v)); } + + template + void operator()(T v) { + SetElementType(); + out.set_is_scalar(true); + Add(v); + } + + void operator()(const Tuple& t) { + out.set_element_type(TracedValue::TUPLE); + for (const auto& v : t.values) { + *out.add_tuple_elements() = ValueToTracedValue(*v); + } + } + + template + void operator()(const TensorOrMemref& v) { + for (int64_t size : v.view.sizes) { + out.add_shape(size); + } + SetElementType(); + for (const auto& index : v.view.indices()) { + Add(v.at(index)); + } + } + + template + void SetElementType() { + out.set_element_type(GetElementType(T{})); + if constexpr (std::is_same_v) { + out.set_bit_width(1); + } else { + out.set_bit_width(sizeof(T) * 8); + } + } + + template + static TracedValue::ElementType GetElementType(const T&) { + if constexpr (std::is_floating_point_v) { + return TracedValue::FLOAT; + } else if constexpr (std::is_integral_v) { + if constexpr (std::is_unsigned_v) { + return TracedValue::UNSIGNED; + } else { + return TracedValue::INTEGRAL; + } + } else { + T{"invalid type"} + 0; + return TracedValue::UNKNOWN; + } + } + + template + static TracedValue::ElementType GetElementType(const std::complex&) { + return TracedValue::COMPLEX; + } + + static TracedValue::ElementType GetElementType(const Tuple&) { + return TracedValue::UNKNOWN; + } +}; + +} // namespace + +void ExecutionTraceListener::beforeOp(ArrayRef args, + Operation* op) { + auto* inst = regions_.back()->add_instructions(); + inst->set_name(op->getName().getStringRef().str()); + for (const auto& arg : args) { + *inst->add_args() = ValueToTracedValue(arg); + } +} + +void ExecutionTraceListener::afterOp(ArrayRef results) { + auto* traced_results = + regions_.back()->mutable_instructions()->rbegin()->mutable_results(); + for (const auto& result : results) { + *traced_results->Add() = ValueToTracedValue(result); + } +} + +void ExecutionTraceListener::enterRegion(ArrayRef bbargs, + Region& region) { + if (regions_.empty()) { + regions_.push_back(trace_->mutable_trace()); + } else { + regions_.push_back( + regions_.back()->mutable_instructions()->rbegin()->add_regions()); + } + + auto& traced_region = *regions_.back(); + traced_region.set_region_number(region.getRegionNumber()); + for (const auto& bbarg : bbargs) { + *traced_region.add_bbargs() = ValueToTracedValue(bbarg); + } +} + +void ExecutionTraceListener::leaveRegion(ArrayRef yielded) { + for (const auto& result : yielded) { + *regions_.back()->add_results() = ValueToTracedValue(result); + } + regions_.pop_back(); +} + +llvm::SmallVector ValueToAttribute( + const InterpreterValue& value, mlir::Type type) { + if (std::holds_alternative(value.storage)) { + auto types = type.cast().getTypes(); + const auto& t = std::get(value.storage); + llvm::SmallVector attrs; + for (const auto& [v, ty] : llvm::zip(t.values, types)) { + 
auto attr = ValueToAttribute(*v, ty); + assert(attr.size() == 1 && "nested tuples not supported"); + attrs.push_back(attr.front()); + } + return attrs; + } + + if (!value.isTensor()) { + return {cast( + ValueToAttribute(value.asUnitTensor(), + mlir::RankedTensorType::get({}, type)) + .front()) + .getValues()[0]}; + } + + if (!type.isa()) { + return {}; + } + + return { + dispatchScalarType(type.cast().getElementType(), + [&](auto dummy) -> mlir::Attribute { + using T = decltype(dummy); + auto& t = std::get>(value.storage); + SmallVector vals; + for (const auto& index : t.view.indices()) { + vals.push_back(t.at(index)); + } + if constexpr (std::is_same_v) { + return mlir::DenseElementsAttr::get(type, vals); + } else { + return mlir::DenseElementsAttr::get(type, vals); + } + })}; +} + +namespace { +template +TensorOrMemref ArrayLiteralToTensor(const xla::Literal& literal) { + SmallVector layout; + if (literal.shape().has_layout()) { + llvm::copy(literal.shape().layout().minor_to_major(), + std::back_inserter(layout)); + } + SmallVector shape{literal.shape().dimensions().begin(), + literal.shape().dimensions().end()}; + auto result = TensorOrMemref::empty(shape, layout); + assert(literal.size_bytes() == result.buffer->getByteSize() && + "expected buffer sizes to match"); + memcpy(result.buffer->at(0, 0), literal.untyped_data(), + result.buffer->getByteSize()); + return result; +} +} // namespace + +tsl::StatusOr LiteralToValue(const xla::Literal& literal) { + if (literal.shape().IsTuple()) { + auto elements = literal.Clone().DecomposeTuple(); + Tuple result; + for (auto& element : elements) { + TF_ASSIGN_OR_RETURN(auto converted, LiteralToValue(element)); + result.values.push_back( + std::make_shared(std::move(converted))); + } + return {{result}}; + } + + if (literal.shape().IsToken()) { + return tsl::errors::Unimplemented("token arguments are not implemented"); + } + + if (literal.shape().IsArray()) { + switch (literal.shape().element_type()) { + case xla::PRED: + return {{ArrayLiteralToTensor(literal)}}; + case xla::S8: + return {{ArrayLiteralToTensor(literal)}}; + case xla::S16: + return {{ArrayLiteralToTensor(literal)}}; + case xla::S32: + return {{ArrayLiteralToTensor(literal)}}; + case xla::S64: + return {{ArrayLiteralToTensor(literal)}}; + case xla::U8: + return {{ArrayLiteralToTensor(literal)}}; + case xla::U16: + return {{ArrayLiteralToTensor(literal)}}; + case xla::U32: + return {{ArrayLiteralToTensor(literal)}}; + case xla::U64: + return {{ArrayLiteralToTensor(literal)}}; + case xla::F16: + return tsl::errors::Unimplemented("F16 not implemented"); + case xla::F32: + return {{ArrayLiteralToTensor(literal)}}; + case xla::BF16: + return tsl::errors::Unimplemented("BF16 not implemented"); + case xla::F64: + return {{ArrayLiteralToTensor(literal)}}; + case xla::F8E5M2: + return tsl::errors::Unimplemented("F8E5M2 not implemented"); + case xla::F8E4M3FN: + return tsl::errors::Unimplemented("F8E4M3FN not implemented"); + case xla::C64: + return {{ArrayLiteralToTensor>(literal)}}; + case xla::C128: + return {{ArrayLiteralToTensor>(literal)}}; + default: + // Fallthrough intended. 
+ break; + } + } + + return tsl::errors::InvalidArgument("unexpected literal type"); +} + +tsl::StatusOr LiteralToValue( + const xla::LiteralProto& literal) { + TF_ASSIGN_OR_RETURN(auto deserialized, + xla::Literal::CreateFromProto(literal)); + return LiteralToValue(deserialized); +} + +TracedValue ValueToTracedValue(const InterpreterValue& value) { + TraceInterpreterValueVisitor visitor; + std::visit(visitor, value.storage); + return visitor.out; +} + +tsl::StatusOr TracedValueToValue( + const TracedValue& traced_value) { + auto extract = [&](auto dummy, auto& elements) -> InterpreterValue { + using T = decltype(dummy); + if (traced_value.is_scalar()) { + return {static_cast(elements[0])}; + } + + auto result = + TensorOrMemref::empty(llvm::to_vector(traced_value.shape())); + for (auto [index, element] : llvm::zip(result.view.indices(), elements)) { + result.at(index) = element; + } + return {result}; + }; + auto extract_complex = [&](auto& elements) -> InterpreterValue { + using T = std::complex>; + if (traced_value.is_scalar()) { + return {T{elements[0], elements[1]}}; + } + + auto result = + TensorOrMemref::empty(llvm::to_vector(traced_value.shape())); + int64_t i = 0; + for (auto it = result.view.indices().begin(), + end = result.view.indices().end(); + it != end; ++it, i += 2) { + result.at(*it) = {elements[i], elements[i + 1]}; + } + return {result}; + }; + switch (traced_value.element_type()) { + case TracedValue::UNKNOWN: + break; + case TracedValue::FLOAT: + if (traced_value.bit_width() == 32) { + return extract(float{}, traced_value.floats()); + } + return extract(double{}, traced_value.doubles()); + case TracedValue::UNSIGNED: + switch (traced_value.bit_width()) { + case 1: + return extract(bool{}, traced_value.ints()); + case 8: + return extract(uint8_t{}, traced_value.uints()); + case 16: + return extract(uint16_t{}, traced_value.uints()); + case 32: + return extract(uint32_t{}, traced_value.uints()); + case 64: + return extract(uint64_t{}, traced_value.uints()); + } + break; + case TracedValue::INTEGRAL: + switch (traced_value.bit_width()) { + case 8: + return extract(int8_t{}, traced_value.ints()); + case 16: + return extract(int16_t{}, traced_value.ints()); + case 32: + return extract(int32_t{}, traced_value.ints()); + case 64: + return extract(int64_t{}, traced_value.ints()); + } + break; + case TracedValue::COMPLEX: + switch (traced_value.bit_width()) { + case 64: + return extract_complex(traced_value.floats()); + case 128: + return extract_complex(traced_value.doubles()); + } + break; + case TracedValue::TUPLE: + Tuple result; + for (const auto& elem : traced_value.tuple_elements()) { + TF_ASSIGN_OR_RETURN(auto converted, TracedValueToValue(elem)); + result.values.push_back( + std::make_shared(std::move(converted))); + } + return {{std::move(result)}}; + } + return tsl::errors::InvalidArgument("unexpected type: " + + traced_value.DebugString()); +} + +llvm::SmallVector FindOpExecutionsInTrace( + const ExecutionTrace& trace, mlir::Operation* op) { + llvm::SmallVector region_indices; + llvm::SmallVector op_indices; + + std::function get_op_path; + get_op_path = [&](mlir::Operation* op) { + auto* parent = op->getParentOp(); + if (!llvm::isa(parent)) { + get_op_path(parent); + region_indices.push_back(op->getParentRegion()->getRegionNumber()); + } + + int64_t index = 0; + while ((op = op->getPrevNode()) != nullptr) ++index; + op_indices.push_back(index); + }; + get_op_path(op); + + llvm::SmallVector result; + std::function step; + step = [&](const RegionTrace& trace, int 
index) { + auto& instruction_trace = trace.instructions(op_indices[index]); + if (region_indices.size() > index) { + for (const auto& region : instruction_trace.regions()) { + if (region.region_number() == region_indices[index]) { + step(region, index + 1); + } + } + } else { + result.push_back(&instruction_trace); + } + }; + step(trace.trace(), 0); + + return result; +} + +} // namespace interpreter +} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_replay/public/execution_trace_utils.h b/tensorflow/compiler/xla/mlir/tools/mlir_replay/public/execution_trace_utils.h new file mode 100644 index 00000000000..7a92d51585c --- /dev/null +++ b/tensorflow/compiler/xla/mlir/tools/mlir_replay/public/execution_trace_utils.h @@ -0,0 +1,73 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_MLIR_TOOLS_MLIR_REPLAY_PUBLIC_EXECUTION_TRACE_UTILS_H_ +#define TENSORFLOW_COMPILER_XLA_MLIR_TOOLS_MLIR_REPLAY_PUBLIC_EXECUTION_TRACE_UTILS_H_ + +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Region.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/mlir/tools/mlir_replay/public/execution_trace.pb.h" +#include "tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/framework/interpreter.h" +#include "tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/framework/interpreter_value.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/tsl/platform/statusor.h" + +namespace mlir { +namespace interpreter { + +// Interpreter listener that builds a trace of all executed ops and regions. +class ExecutionTraceListener : public InterpreterListener { + public: + explicit ExecutionTraceListener(ExecutionTrace* trace) : trace_(trace) {} + + void beforeOp(ArrayRef args, Operation* op) override; + void afterOp(ArrayRef results) override; + void enterRegion(ArrayRef bbargs, Region& region) override; + void leaveRegion(ArrayRef yielded) override; + + private: + ExecutionTrace* trace_; + SmallVector regions_; +}; + +// Returns an attribute with the given contents and type. +llvm::SmallVector ValueToAttribute( + const InterpreterValue& value, mlir::Type type); + +// Deserializes the given literal. +tsl::StatusOr LiteralToValue( + const xla::LiteralProto& literal); + +// Deserializes the given literal. +tsl::StatusOr LiteralToValue(const xla::Literal& literal); + +// Serializes the given interpreter value. +TracedValue ValueToTracedValue(const InterpreterValue& value); + +// Deserializes the given traced value. +tsl::StatusOr TracedValueToValue( + const TracedValue& traced_value); + +// Returns all executions of the given op in the given trace. 
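+// The op is located by its path of (op index, region number) pairs, which is
+// followed through the nested RegionTraces; an op inside a loop can therefore
+// map to several recorded executions.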
+llvm::SmallVector FindOpExecutionsInTrace( + const ExecutionTrace& trace, mlir::Operation* op); + +} // namespace interpreter +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_XLA_MLIR_TOOLS_MLIR_REPLAY_PUBLIC_EXECUTION_TRACE_UTILS_H_ diff --git a/tensorflow/compiler/xla/mlir/tools/mlir_replay/public/execution_trace_utils_test.cc b/tensorflow/compiler/xla/mlir/tools/mlir_replay/public/execution_trace_utils_test.cc new file mode 100644 index 00000000000..d23ef77b929 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/tools/mlir_replay/public/execution_trace_utils_test.cc @@ -0,0 +1,138 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/mlir/tools/mlir_replay/public/execution_trace_utils.h" + +#include +#include +#include +#include +#include +#include + +#include +#include "llvm/ADT/STLExtras.h" +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/framework/interpreter_value.h" +#include "tensorflow/tsl/platform/statusor.h" + +namespace mlir { +namespace interpreter { +namespace { + +class TracedValueRoundTripTest + : public ::testing::TestWithParam {}; + +TEST_P(TracedValueRoundTripTest, Run) { + auto traced_value = ValueToTracedValue(GetParam()); + TF_ASSERT_OK_AND_ASSIGN(auto value, TracedValueToValue(traced_value)); + EXPECT_EQ(GetParam(), value) << GetParam().toString(); +} + +template +InterpreterValue MakeTensor(ArrayRef shape, ArrayRef values) { + auto result = TensorOrMemref::empty(shape); + for (auto [indices, value] : llvm::zip(result.view.indices(), values)) { + result.at(indices) = value; + } + return {result}; +} + +template +std::shared_ptr WrapShared(T value) { + return std::make_shared(std::move(value)); +} + +INSTANTIATE_TEST_SUITE_P( + RoundTrip, TracedValueRoundTripTest, + ::testing::ValuesIn(std::vector{ + {uint8_t{42}}, + {uint16_t{43}}, + {uint32_t{44}}, + {uint64_t{45}}, + {int8_t{-47}}, + {int16_t{-48}}, + {int32_t{-49}}, + {int64_t{-50}}, + {float{42.0}}, + {double{42.0}}, + {std::complex{1.0, 2.0}}, + {std::complex{3.0, 4.0}}, + {true}, + {false}, + {MakeTensor({1, 2}, {42, 43})}, + {MakeTensor({2, 2}, {1.0, -INFINITY, INFINITY, NAN})}, + {MakeTensor>({}, {{1.0, 2.0}})}, + {Tuple{SmallVector>{ + WrapShared(InterpreterValue{42}), + WrapShared(InterpreterValue{43.0}), + }}}})); + +class FromLiteralTest + : public ::testing::TestWithParam< + std::pair, InterpreterValue>> {}; + +TEST_P(FromLiteralTest, Run) { + TF_ASSERT_OK_AND_ASSIGN(auto value, LiteralToValue(*GetParam().first)); + EXPECT_EQ(value, GetParam().second) + << value.toString() << " vs " << GetParam().second.toString(); +} + +std::vector, InterpreterValue>> +MakeInputs() { + using ::xla::LiteralUtil; + return { + {WrapShared(LiteralUtil::CreateR2({{41, 42}})), + MakeTensor({1, 2}, {41, 42})}, + 
{WrapShared(LiteralUtil::CreateR0(43)), + MakeTensor({}, {43})}, + {WrapShared(LiteralUtil::CreateR0(44)), + MakeTensor({}, {44})}, + {WrapShared(LiteralUtil::CreateR0(45)), + MakeTensor({}, {45})}, + {WrapShared(LiteralUtil::CreateR0(46)), + MakeTensor({}, {46})}, + {WrapShared(LiteralUtil::CreateR0(47)), + MakeTensor({}, {47})}, + {WrapShared(LiteralUtil::CreateR0(48)), + MakeTensor({}, {48})}, + {WrapShared(LiteralUtil::CreateR0(49)), + MakeTensor({}, {49})}, + {WrapShared(LiteralUtil::CreateR0(50.0)), + MakeTensor({}, {50.0})}, + {WrapShared(LiteralUtil::CreateR0(51.0)), + MakeTensor({}, {51.0})}, + {WrapShared(LiteralUtil::CreateR0>({52.0, 53.0})), + MakeTensor>({}, {{52.0, 53.0}})}, + {WrapShared(LiteralUtil::CreateR0>({54.0, 55.0})), + MakeTensor>({}, {{54.0, 55.0}})}, + {WrapShared(LiteralUtil::CreateR1({true, false})), + MakeTensor({2}, {true, false})}, + {WrapShared( + LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0(true), + LiteralUtil::CreateR0(56))), + InterpreterValue{Tuple{SmallVector>{ + std::make_shared(MakeTensor({}, {true})), + std::make_shared( + MakeTensor({}, {56}))}}}}}; +} + +INSTANTIATE_TEST_SUITE_P(Test, FromLiteralTest, + ::testing::ValuesIn(MakeInputs())); + +} // namespace +} // namespace interpreter +} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir/transforms/cpu/lmhlo_to_cpu_runtime.cc b/tensorflow/compiler/xla/mlir/transforms/cpu/lmhlo_to_cpu_runtime.cc deleted file mode 100644 index 35b09a5588c..00000000000 --- a/tensorflow/compiler/xla/mlir/transforms/cpu/lmhlo_to_cpu_runtime.cc +++ /dev/null @@ -1,159 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include -#include -#include -#include -#include - -#include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/SymbolTable.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "tensorflow/compiler/xla/mlir/runtime/utils/custom_calls.h" -#include "tensorflow/compiler/xla/mlir/transforms/cpu/passes.h" -#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h" - -namespace xla { -namespace cpu { -namespace { - -#define GEN_PASS_DEF_CONVERTLMHLOTOCPURUNTIMEPASS -#include "tensorflow/compiler/xla/mlir/transforms/cpu/passes.h.inc" - -using namespace mlir; // NOLINT - -using mlir::lmhlo::CustomCallOp; - -using xla::runtime::AppendCustomCallAttrs; -using xla::runtime::CustomCallDeclarations; - -class ConvertLmhloToCpuRuntimePass - : public impl::ConvertLmhloToCpuRuntimePassBase< - ConvertLmhloToCpuRuntimePass> { - void runOnOperation() override; - - void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); - } -}; - -//===----------------------------------------------------------------------===// - -class CustomCallOpLowering : public OpRewritePattern { - private: - static constexpr const char kCustomCallTarget[] = "xla.cpu.custom_call"; - - public: - CustomCallOpLowering(MLIRContext* ctx, CustomCallDeclarations& custom_calls) - : OpRewritePattern(ctx), custom_calls_(custom_calls) {} - - LogicalResult matchAndRewrite(CustomCallOp op, - PatternRewriter& rewriter) const override { - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - - // By default all operands passed to the custom call handler. - llvm::SmallVector operands = op.getOperands(); - - // Get the number of outputs from operand_segment_sizes. - int64_t num_results = op->getAttrOfType( - op.getOperandSegmentSizesAttrName())[1]; - - // If custom call has target arguments mapping, then we need to pass empty - // memrefs in place of holes. - if (op.getTargetArgMapping().has_value()) { - auto mapping = *op.getTargetArgMapping(); - int64_t num_args = mapping.getNumArgs(); - num_results = mapping.getNumResults(); - - // Always create an `alloca` in the parent function entry block. - // See: https://llvm.org/docs/Frontend/PerformanceTips.html#use-of-allocas - Value hole = [&]() -> Value { - OpBuilder::InsertionGuard guard(b); - b.setInsertionPointToStart( - &op->getParentOfType().front()); - return b.create(MemRefType::get({0}, b.getI8Type())); - }(); - - // We represent holes as empty i8 memrefs. - operands = llvm::SmallVector(num_args + num_results, hole); - - // Update operands to mapped custom call arguments. - auto args = mapping.getArgsToTargetArgs(); - for (const auto& indexed : llvm::enumerate(args)) - operands[indexed.value()] = op.getArgs()[indexed.index()]; - - // Update operands to mapped custom call results. 
- auto res = mapping.getResultsToTargetResults(); - for (const auto& indexed : llvm::enumerate(res)) - operands[num_args + indexed.value()] = op.getOutput()[indexed.index()]; - } - - // Create a custom call function declaration. - func::FuncOp callee = custom_calls_.GetOrCreate( - b, kCustomCallTarget, TypeRange(ValueRange(operands)), TypeRange()); - - llvm::SmallVector custom_call_attrs = { - {b.getStringAttr("num_results"), - b.getI32IntegerAttr(static_cast(num_results))}, - {b.getStringAttr("api_version"), op.getApiVersionAttr()}, - {b.getStringAttr("call_target_name"), op.getCallTargetNameAttr()}}; - - // Call the runtime intrinsic with the original operands. - auto call = rewriter.replaceOpWithNewOp( - op, callee.getName(), TypeRange(), operands); - AppendCustomCallAttrs(call, custom_call_attrs); - - return success(); - } - - private: - CustomCallDeclarations& custom_calls_; -}; - -//===----------------------------------------------------------------------===// - -void ConvertLmhloToCpuRuntimePass::runOnOperation() { - ModuleOp module = getOperation(); - MLIRContext* ctx = module.getContext(); - - // Keep track of the custom calls created from the lowered operations. - SymbolTable sym_table(module); - CustomCallDeclarations custom_calls(std::move(sym_table)); - - // Convert lmhlo operations to XLA cpu runtime custom calls. - RewritePatternSet patterns(ctx); - patterns.insert(ctx, custom_calls); - - if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) - return signalPassFailure(); -} - -} // namespace - -std::unique_ptr> -createConvertLmhloToCpuRuntimePass() { - return std::make_unique(); -} - -} // namespace cpu -} // namespace xla diff --git a/tensorflow/compiler/xla/mlir/transforms/gpu/launch_func_to_cuda_graph.cc b/tensorflow/compiler/xla/mlir/transforms/gpu/launch_func_to_cuda_graph.cc deleted file mode 100644 index 441cc4db232..00000000000 --- a/tensorflow/compiler/xla/mlir/transforms/gpu/launch_func_to_cuda_graph.cc +++ /dev/null @@ -1,185 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include -#include -#include -#include - -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project -#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/SymbolTable.h" // from @llvm-project -#include "mlir/Transforms/RegionUtils.h" // from @llvm-project -#include "tensorflow/compiler/xla/mlir/runtime/ir/rt_dialect.h" -#include "tensorflow/compiler/xla/mlir/runtime/ir/rt_ops.h" -#include "tensorflow/compiler/xla/mlir/runtime/utils/custom_calls.h" -#include "tensorflow/compiler/xla/mlir/transforms/gpu/passes.h" - -namespace xla { -namespace gpu { - -#define GEN_PASS_DEF_CONVERTLAUNCHFUNCTOCUDAGRAPHPASS -#include "tensorflow/compiler/xla/mlir/transforms/gpu/passes.h.inc" - -using namespace mlir; // NOLINT - -using mlir::gpu::LaunchFuncOp; - -class ConvertLaunchFuncToCudaGraphPass - : public impl::ConvertLaunchFuncToCudaGraphPassBase< - ConvertLaunchFuncToCudaGraphPass> { - void runOnOperation() override; - - void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); - } -}; - -//===----------------------------------------------------------------------===// - -// A sequence of launch func operation to be outlined into cuda graph -// constructor. -struct LaunchFuncSequence { - llvm::SmallVector ops; -}; - -// Collect sequences of LaunchFuncOp operations that can be outlined into -// Cuda Graph functions. -// -// TODO(ezhulenev): Do not collect launch func sequences if they are already -// inside a graph capture function. -static llvm::SmallVector CollectLaunchFuncSequences( - ModuleOp module) { - llvm::SmallVector seqs; - llvm::DenseSet outlined; - - module.walk([&](LaunchFuncOp op) { - // This launch operation is a part of already collected sequence. - if (outlined.contains(op)) return; - - // Find the first LaunchFuncOp in a sequence. - Operation* first = op; - while (Operation* prev = first->getPrevNode()) { - if (!isa(prev)) break; - first = prev; - } - - // Find the last LaunchFuncOp in a sequence. - Operation* last = op; - while (Operation* next = last->getNextNode()) { - if (!isa(next)) break; - last = next; - } - - // Skip sequences consisting of a single operation. - if (first == last) return; - - // Collect all launch func ops. - LaunchFuncSequence& seq = seqs.emplace_back(); - - auto r = llvm::make_range(Block::iterator(first), ++Block::iterator(last)); - llvm::transform(r, std::back_inserter(seq.ops), [&](Operation& op) { - auto launch = cast(op); - outlined.insert(launch); - return launch; - }); - }); - - return seqs; -} - -//===----------------------------------------------------------------------===// - -using xla::runtime::CustomCallDeclarations; - -// Given a sequence of LaunchFuncOp operations outline them into a function, -// and replace with an XLA Gpu runtime function call. -static void Outline(CustomCallDeclarations& custom_calls, - LaunchFuncSequence& seq) { - SymbolTable& sym_table = custom_calls.sym_table(); - MLIRContext* ctx = sym_table.getOp()->getContext(); - - // Create a fused location out of LaunchFuncOp operations. - llvm::SmallVector locations; - for (auto& op : seq.ops) locations.push_back(op.getLoc()); - ImplicitLocOpBuilder b(FusedLoc::get(ctx, locations), sym_table.getOp()); - - // Collect all arguments used by the launch func operations. 
- llvm::SetVector args; - for (LaunchFuncOp op : seq.ops) - args.insert(op.operand_begin(), op.operand_end()); - - llvm::SmallVector args_types; - for (Value arg : args) args_types.push_back(arg.getType()); - - // Create a function in the compiled module. - auto func_type = FunctionType::get(ctx, args_types, TypeRange()); - auto func = b.create("xla.gpu.cuda.graph.capture", func_type); - - // Add graph building function to the module. - sym_table.insert(func); - - // Export graph builder function to runtime. - b.setInsertionPoint(func); - b.create(func); - - // Create a custom call declaration corresponding to the outlined graph - // capture function. - func::FuncOp graph_launch = custom_calls.GetOrCreate( - b, "xla.gpu.cuda.graph.launch", args_types, TypeRange()); - - // Call the cuda graph launch custom call. - b.setInsertionPoint(seq.ops.front()); - auto call = b.create(graph_launch.getName(), TypeRange(), - args.getArrayRef()); - call->setAttr(b.getStringAttr("capture"), FlatSymbolRefAttr::get(func)); - - // At this point we successfully added new functions to the module, so we can - // move LaunchFuncOp operations from their original location to the graph - // capture function. - - // Move all launch func operations into the function body. - Block* body = func.addEntryBlock(); - for (LaunchFuncOp op : seq.ops) op->moveBefore(body, body->end()); - - // Replace uses of original values with block arguments. - for (auto p : llvm::zip(args, func.getArguments())) - replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), func.getBody()); - - // Add a return operation to the graph capture function. - b.setInsertionPointToEnd(body); - b.create(ValueRange()); -} - -//===----------------------------------------------------------------------===// - -void ConvertLaunchFuncToCudaGraphPass::runOnOperation() { - SymbolTable sym_table(getOperation()); - CustomCallDeclarations custom_calls(std::move(sym_table)); - - for (auto& seq : CollectLaunchFuncSequences(getOperation())) { - Outline(custom_calls, seq); - } -} - -std::unique_ptr> -createConvertLaunchFuncToCudaGraphPass() { - return std::make_unique(); -} - -} // namespace gpu -} // namespace xla diff --git a/tensorflow/compiler/xla/mlir/transforms/gpu/tests/gpu_launch_to_cuda_graph.mlir b/tensorflow/compiler/xla/mlir/transforms/gpu/tests/gpu_launch_to_cuda_graph.mlir deleted file mode 100644 index ba92e509236..00000000000 --- a/tensorflow/compiler/xla/mlir/transforms/gpu/tests/gpu_launch_to_cuda_graph.mlir +++ /dev/null @@ -1,149 +0,0 @@ -// RUN: xla-gpu-opt %s --split-input-file -xla-gpu-launch-func-to-cuda-graphs \ -// RUN: | FileCheck %s - -module attributes {gpu.container_module} { - -gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref) kernel { - gpu.return - } - gpu.func @fn1(%arg0: memref) kernel { - gpu.return - } -} - -// CHECK: @func( -// CHECK: %[[ARG0:.*]]: memref, -// CHECK: %[[ARG1:.*]]: memref -// CHECK: ) -func.func @func(%arg0: memref, %arg1: memref) { - // CHECK: %[[C1:.*]] = arith.constant 1 - // CHECK: %[[C2:.*]] = arith.constant 2 - // CHECK: %[[C3:.*]] = arith.constant 3 - // CHECK: %[[C4:.*]] = arith.constant 4 - // CHECK: %[[C5:.*]] = arith.constant 5 - // CHECK: %[[C6:.*]] = arith.constant 6 - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c3 = arith.constant 3 : index - %c4 = arith.constant 4 : index - %c5 = arith.constant 5 : index - %c6 = arith.constant 6 : index - - // CHECK: call @xla.gpu.cuda.graph.launch( - // CHECK: %[[C1]], %[[C2]], %[[C3]], %[[C4]], 
%[[C5]], %[[C6]], - // CHECK: %[[ARG0]], %[[ARG1]]) - // CHECK-SAME: {capture = @xla.gpu.cuda.graph.capture} - // CHECK-NEXT: return - - gpu.launch_func @gpu_module::@fn0 - blocks in (%c1, %c2, %c3) - threads in (%c4, %c5, %c6) - args(%arg0 : memref) - - gpu.launch_func @gpu_module::@fn1 - blocks in (%c3, %c2, %c1) - threads in (%c6, %c5, %c4) - args(%arg1 : memref) - - func.return -} - -// CHECK: func @xla.gpu.cuda.graph.capture -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn0 -// CHECK-NEXT: gpu.launch_func @gpu_module::@fn1 -// CHECK-NEXT: return - -// CHECK: func private @xla.gpu.cuda.graph.launch( -// CHECK-SAME: index, index, index, index, index, index, -// CHECK-SAME: memref, memref) -// CHECK-SAME: attributes {rt.custom_call = "xla.gpu.cuda.graph.launch"} -} - -// ----- -// Check that single function launch was not outlined into graph capture. - -module attributes {gpu.container_module} { - -gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref) kernel { - gpu.return - } -} - -// CHECK: @func(%[[ARG0:.*]]: memref) -func.func @func(%arg0: memref) { - %c1 = arith.constant 1 : index - - // CHECK: gpu.launch_func {{.*}} args(%[[ARG0]] : memref) - // CHECK-NOT: call @xla.gpu.cuda.graph.launch - gpu.launch_func @gpu_module::@fn0 - blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) - args(%arg0 : memref) - - func.return -} - -} - -// ----- -// Check that two different sequences are outlined in different capture -// functions. - -module attributes {gpu.container_module} { - -gpu.module @gpu_module attributes {binary = "kernel binary"} { - gpu.func @fn0(%arg0: memref) kernel { - gpu.return - } - gpu.func @fn1(%arg0: memref) kernel { - gpu.return - } -} - -// CHECK: @func(%[[ARG0:.*]]: memref) -func.func @func(%arg0: memref) { - // CHECK: %[[C1:.*]] = arith.constant 1 - %c1 = arith.constant 1 : index - - // CHECK: call @xla.gpu.cuda.graph.launch(%[[C1]], %[[ARG0]]) - // CHECK-SAME: {capture = @[[CAPTURE:.*]]} - - gpu.launch_func @gpu_module::@fn0 - blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) - args(%arg0 : memref) - - gpu.launch_func @gpu_module::@fn1 - blocks in (%c1, %c1, %c1) - threads in (%c1, %c1, %c1) - args(%arg0 : memref) - - // Use constant to break the large function launch sequence. 
- // CHECK: %[[C2:.*]] = arith.constant 2 - %c2 = arith.constant 2 : index - - // CHECK: call @xla.gpu.cuda.graph.launch(%[[C2]], %[[ARG0]]) - // CHECK-SAME: {capture = @[[CAPTURE_0:.*]]} - - gpu.launch_func @gpu_module::@fn1 - blocks in (%c2, %c2, %c2) - threads in (%c2, %c2, %c2) - args(%arg0 : memref) - - gpu.launch_func @gpu_module::@fn0 - blocks in (%c2, %c2, %c2) - threads in (%c2, %c2, %c2) - args(%arg0 : memref) - - func.return -} - -// CHECK: rt.export @[[CAPTURE]] -// CHECK: func.func @[[CAPTURE]](%arg0: index, %arg1: memref) - -// CHECK: rt.export @[[CAPTURE_0]] -// CHECK: func.func @[[CAPTURE_0]](%arg0: index, %arg1: memref) - -} diff --git a/tensorflow/compiler/xla/mlir/utils/BUILD b/tensorflow/compiler/xla/mlir/utils/BUILD index 64139a49eb7..9e9a6f974d9 100644 --- a/tensorflow/compiler/xla/mlir/utils/BUILD +++ b/tensorflow/compiler/xla/mlir/utils/BUILD @@ -1,9 +1,10 @@ -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ - "//tensorflow:internal", + "//tensorflow/compiler/xla:internal", ], licenses = ["notice"], ) diff --git a/tensorflow/compiler/xla/mlir/xla_cpu/ir/BUILD b/tensorflow/compiler/xla/mlir/xla_cpu/ir/BUILD new file mode 100644 index 00000000000..88c3b2aa49a --- /dev/null +++ b/tensorflow/compiler/xla/mlir/xla_cpu/ir/BUILD @@ -0,0 +1,106 @@ +load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], +) + +td_library( + name = "td_files", + srcs = [ + "xla_cpu_dialect.td", + "xla_cpu_enums.td", + "xla_cpu_ops.td", + ], + compatible_with = get_compatible_with_cloud(), + deps = [ + "//tensorflow/compiler/xla/mlir_hlo:hlo_ops_td_files", + "@llvm-project//mlir:BufferizableOpInterfaceTdFiles", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + ], +) + +gentbl_cc_library( + name = "xla_cpu_dialect_inc_gen", + compatible_with = get_compatible_with_cloud(), + tbl_outs = [ + ( + ["-gen-dialect-decls"], + "xla_cpu_dialect.h.inc", + ), + ( + ["-gen-dialect-defs"], + "xla_cpu_dialect.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "xla_cpu_dialect.td", + deps = [":td_files"], +) + +gentbl_cc_library( + name = "xla_cpu_inc_gen", + compatible_with = get_compatible_with_cloud(), + tbl_outs = [ + ( + ["-gen-op-decls"], + "xla_cpu.h.inc", + ), + ( + ["-gen-op-defs"], + "xla_cpu.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "xla_cpu_ops.td", + deps = [":td_files"], +) + +gentbl_cc_library( + name = "xla_cpu_enums_inc_gen", + compatible_with = get_compatible_with_cloud(), + tbl_outs = [ + ( + ["-gen-enum-decls"], + "xla_cpu_enums.h.inc", + ), + ( + ["-gen-enum-defs"], + "xla_cpu_enums.cc.inc", + ), + ( + ["-gen-attrdef-decls"], + "xla_cpu_attrdefs.h.inc", + ), + ( + ["-gen-attrdef-defs"], + "xla_cpu_attrdefs.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "xla_cpu_enums.td", + deps = [ + ":td_files", + ], +) + +cc_library( + name = "xla_cpu", + srcs = [ + "xla_cpu.cc", + ], + hdrs = ["xla_cpu.h"], + deps = [ + ":xla_cpu_dialect_inc_gen", + ":xla_cpu_enums_inc_gen", + 
":xla_cpu_inc_gen", + "//tensorflow/compiler/xla/mlir_hlo", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:BufferizationDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + ], +) diff --git a/tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu.cc b/tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu.cc new file mode 100644 index 00000000000..293c55b95b4 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu.cc @@ -0,0 +1,164 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu.h" + +#include "llvm/ADT/TypeSwitch.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu_dialect.cc.inc" +#include "tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu_enums.cc.inc" +#include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#define GET_ATTRDEF_CLASSES +#include "tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu_attrdefs.cc.inc" + +namespace mlir { +namespace xla_cpu { + +using ::mlir::mhlo::TokenType; + +void XlaCpuDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu.cc.inc" +#undef GET_OP_LIST + >(); +} + +template +LogicalResult BufferizeOp(Op op, RewriterBase &rewriter, + const bufferization::BufferizationOptions &options, + int64_t num_inputs) { + if (op.getOperands().front().getType().template isa()) { + return success(); + } + SmallVector new_operands; + for (auto operand : op.getOperands()) { + FailureOr maybe_buffer = getBuffer(rewriter, operand, options); + if (failed(maybe_buffer)) { + return failure(); + } + new_operands.push_back(*maybe_buffer); + } + rewriter.create(op.getLoc(), TypeRange{}, new_operands, + op.getOperation()->getAttrs()); + bufferization::replaceOpWithBufferizedValues( + rewriter, op.getOperation(), + llvm::ArrayRef(new_operands).drop_front(num_inputs)); + return success(); +} + +bool AllReduceOp::bufferizesToMemoryRead(OpOperand &opOperand, + const bufferization::AnalysisState &) { + return opOperand.getOperandNumber() < getNumOperands() / 2; +} + +bool AllReduceOp::bufferizesToMemoryWrite( + OpOperand &opOperand, const bufferization::AnalysisState &state) { + return !bufferizesToMemoryRead(opOperand, state); +} + +bufferization::AliasingOpResultList AllReduceOp::getAliasingOpResults( + OpOperand &opOperand, const bufferization::AnalysisState &) { + if (opOperand.getOperandNumber() < getNumOperands() / 2) { + return {}; + } + return 
{getOperation()->getOpResult(opOperand.getOperandNumber() - + getNumOperands() / 2)}; +} + +LogicalResult AllReduceOp::bufferize( + RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) { + return BufferizeOp(*this, rewriter, options, this->getNumOperands() / 2); +} + +bufferization::BufferRelation AllReduceOp::bufferRelation( + OpResult, const bufferization::AnalysisState &) { + return bufferization::BufferRelation::Equivalent; +} + +LogicalResult CollectivePermuteOp::bufferize( + RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) { + return BufferizeOp(*this, rewriter, options, this->getNumOperands() / 2); +} + +LogicalResult AllToAllOp::bufferize( + RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) { + return BufferizeOp(*this, rewriter, options, this->getNumOperands() / 2); +} + +LogicalResult FftOp::bufferize( + RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) { + return BufferizeOp(*this, rewriter, options, this->getNumOperands() / 2); +} + +LogicalResult OutfeedOp::bufferize( + RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) { + return BufferizeOp(*this, rewriter, options, this->getNumOperands()); +} + +LogicalResult RngBitGeneratorOp::bufferize( + RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) { + return BufferizeOp(*this, rewriter, options, 1); +} + +LogicalResult AddDependencyOp::bufferize( + RewriterBase &rewriter, + const bufferization::BufferizationOptions &options) { + FailureOr maybe_buffer = + getBuffer(rewriter, this->getOperand(), options); + if (failed(maybe_buffer)) { + return rewriter.notifyMatchFailure(*this, + "failed during bufferizing operand"); + } + bufferization::replaceOpWithBufferizedValues(rewriter, this->getOperation(), + *maybe_buffer); + return success(); +} + +LogicalResult MemRefElementCastOp::verify() { + auto src_memref_ty = getSrc().getType().cast(); + auto dst_memref_ty = getDst().getType().cast(); + if (src_memref_ty.getShape() != dst_memref_ty.getShape()) { + return emitOpError() << "expects matching shapes"; + } + + unsigned src_width = src_memref_ty.getElementType().getIntOrFloatBitWidth(); + unsigned dst_width = dst_memref_ty.getElementType().getIntOrFloatBitWidth(); + if ((src_width + CHAR_BIT - 1) / CHAR_BIT != + (dst_width + CHAR_BIT - 1) / CHAR_BIT) { + return emitOpError() << "cannot cast from " + << src_memref_ty.getElementType() << " to " + << dst_memref_ty.getElementType(); + } + return success(); +} + +} // namespace xla_cpu +} // namespace mlir + +#define GET_OP_CLASSES +#include "tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu.cc.inc" diff --git a/tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu.h b/tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu.h new file mode 100644 index 00000000000..de5c65c7600 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu.h @@ -0,0 +1,38 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the operations and types used in the XLAFramework dialect. +// +#ifndef TENSORFLOW_COMPILER_XLA_MLIR_XLA_CPU_IR_XLA_CPU_H_ +#define TENSORFLOW_COMPILER_XLA_MLIR_XLA_CPU_IR_XLA_CPU_H_ + +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project + +#define GET_OP_CLASSES +#include "tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu.h.inc" +#include "tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu_dialect.h.inc" +#include "tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu_enums.h.inc" +#define GET_ATTRDEF_CLASSES +#include "tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu_attrdefs.h.inc" +#undef GET_OP_CLASSES + +#endif // TENSORFLOW_COMPILER_XLA_MLIR_XLA_CPU_IR_XLA_CPU_H_ diff --git a/tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu_dialect.td b/tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu_dialect.td new file mode 100644 index 00000000000..906d665e2da --- /dev/null +++ b/tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu_dialect.td @@ -0,0 +1,33 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_MLIR_XLA_CPU_DIALECT_TD_ +#define TENSORFLOW_COMPILER_XLA_MLIR_XLA_CPU_DIALECT_TD_ + +include "mlir/IR/OpBase.td" + +def XlaCpuDialect : Dialect { + let name = "xla_cpu"; + + let summary = "Enums and operations for the xla_cpu dialect"; + let description = [{ + This dialect contains operations that bridge the gap between HLO and the + CPU runtime. + }]; + let cppNamespace = "::mlir::xla_cpu"; + let useFoldAPI = kEmitFoldAdaptorFolder; +} + +#endif // TENSORFLOW_COMPILER_XLA_MLIR_XLA_CPU_DIALECT_TD_ diff --git a/tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu_enums.td b/tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu_enums.td new file mode 100644 index 00000000000..7c3656ff1bc --- /dev/null +++ b/tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu_enums.td @@ -0,0 +1,39 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_MLIR_XLA_CPU_ENUMS_TD_ +#define TENSORFLOW_COMPILER_XLA_MLIR_XLA_CPU_ENUMS_TD_ + +include "tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu_dialect.td" +include "mlir/IR/EnumAttr.td" +include "mlir/IR/PatternBase.td" + +def ALL_REDUCE_SUM : I32EnumAttrCase<"ALL_REDUCE_SUM", 0>; +def ALL_REDUCE_PRODUCT : I32EnumAttrCase<"ALL_REDUCE_PRODUCT", 1>; +def ALL_REDUCE_MIN : I32EnumAttrCase<"ALL_REDUCE_MIN", 2>; +def ALL_REDUCE_MAX : I32EnumAttrCase<"ALL_REDUCE_MAX", 3>; + +def XlaCpuReductionKind : I32EnumAttr<"ReductionKind", + "Type of reduction to apply.", + [ALL_REDUCE_SUM, ALL_REDUCE_PRODUCT, ALL_REDUCE_MIN, ALL_REDUCE_MAX]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::xla_cpu"; +} + +def XlaCpuReductionKindEnum : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +#endif // TENSORFLOW_COMPILER_XLA_MLIR_XLA_CPU_ENUMS_TD_ diff --git a/tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu_ops.td b/tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu_ops.td new file mode 100644 index 00000000000..183b87c375f --- /dev/null +++ b/tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu_ops.td @@ -0,0 +1,339 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_MLIR_XLA_CPU_OPS_TD_ +#define TENSORFLOW_COMPILER_XLA_MLIR_XLA_CPU_OPS_TD_ + +include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td" +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu_dialect.td" +include "tensorflow/compiler/xla/mlir/xla_cpu/ir/xla_cpu_enums.td" +include "tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_base.td" + +// Base class for XLA CPU dialect ops. +class XlaCpu_Op traits = []> : + Op; + +def TensorOrMemref : + AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">; + +def AllReduceOp : XlaCpu_Op<"all_reduce", + [SameOperandsElementType, + SameVariadicOperandSize, + BufferizableOpInterface]> { + let summary = [{ + CPU-specific version of AllReduce. + }]; + + let description = [{ + The major differences between this and HLO's all_reduce are: + - It bufferizes to itself. + - It has no region. + - It uses destination passing style. 
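+
+    A generic-form instance (mirroring the bufferization test added in this
+    patch; shapes and attribute values are illustrative):
+
+      %1 = "xla_cpu.all_reduce"(%arg0, %0) {
+        channel_handle = 5 : i64,
+        reduction_kind = 3 : i32,
+        replica_groups = dense<[]> : tensor<0xi64>,
+        use_global_device_ids = 0 : i32
+      } : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>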
+ }]; + + let arguments = (ins + Variadic:$operand, + Variadic:$dsts, + I64ElementsAttr:$replica_groups, + I64Attr:$channel_handle, + I32Attr:$use_global_device_ids, + XlaCpuReductionKind:$reduction_kind + ); + let results = (outs + Variadic + ); + let extraClassDeclaration = [{ + // Declarations for BufferizableOpInterface: + bool bufferizesToMemoryRead(OpOperand &opOperand, + const bufferization::AnalysisState &state); + bool bufferizesToMemoryWrite(OpOperand &opOperand, + const bufferization::AnalysisState &state); + bufferization::AliasingOpResultList getAliasingOpResults( + OpOperand &opOperand, const bufferization::AnalysisState &state); + LogicalResult bufferize(RewriterBase &rewriter, + const bufferization::BufferizationOptions &options); + bufferization::BufferRelation bufferRelation(OpResult opResult, + const bufferization::AnalysisState &state); + }]; +} + +def ReplicaIdOp : XlaCpu_Op<"replica_id"> { + let summary = "CPU-specific version of ReplicaId"; + let description = [{ + ReplicaId, but returns a i32 instead of tensor. + }]; + let results = (outs I32); +} + +def PartitionIdOp : XlaCpu_Op<"partition_id"> { + let summary = "CPU-specific version of PartitionId"; + let description = [{ + PartitionId, but returns a i32 instead of tensor. + }]; + let results = (outs I32); +} + +def CollectivePermuteOp : XlaCpu_Op<"collective_permute", [BufferizableOpInterface]> { + let summary = "CPU-specific version of CollectivePermute"; + let description = [{ + The major differences between this and HLO's collective_permute are: + - It bufferizes to itself. + - It uses destination passing style. + }]; + + let arguments = (ins + TensorOrMemref:$operand, + TensorOrMemref:$dst, + I64ElementsAttr:$source_target_pairs, + I64Attr:$channel_handle + ); + let results = (outs Variadic); + let extraClassDeclaration = [{ + // Declarations for BufferizableOpInterface: + bool bufferizesToMemoryRead(OpOperand &opOperand, + const bufferization::AnalysisState &state) { + return opOperand.getOperandNumber() == 0; + } + bool bufferizesToMemoryWrite(OpOperand &opOperand, + const bufferization::AnalysisState &state) { + return opOperand.getOperandNumber() == 1; + } + bufferization::AliasingOpResultList getAliasingOpResults( + OpOperand &opOperand, const bufferization::AnalysisState &state) { + if (opOperand.getOperandNumber() == 0) return {}; + return {getOperation()->getOpResult(0)}; + } + LogicalResult bufferize(RewriterBase &rewriter, + const bufferization::BufferizationOptions &options); + bufferization::BufferRelation bufferRelation(OpResult opResult, + const bufferization::AnalysisState &state) { + return bufferization::BufferRelation::Equivalent; + } + }]; +} + +def AllToAllOp : XlaCpu_Op<"all_to_all", + [SameOperandsElementType, + SameVariadicOperandSize, + BufferizableOpInterface]> { + let summary = "CPU-specific version of AllToAll"; + let description = [{ + The major differences between this and HLO's all_to_all are: + - It bufferizes to itself. + - It uses destination passing style. 
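+
+    A generic-form instance (taken from the bufferization test added in this
+    patch; shapes and attribute values are illustrative):
+
+      %1 = "xla_cpu.all_to_all"(%arg0, %0) {
+        concat_dimension = 0 : i64,
+        replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
+        split_count = 4 : i64,
+        split_dimension = 1 : i64
+      } : (tensor<4x16xf32>, tensor<16x4xf32>) -> tensor<16x4xf32>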
+ }]; + + let arguments = (ins + Variadic:$operand, + Variadic:$dst, + I64ElementsAttr:$replica_groups, + OptionalAttr:$split_dimension, + OptionalAttr:$concat_dimension, + OptionalAttr:$split_count + ); + let results = (outs Variadic); + let extraClassDeclaration = [{ + // Declarations for BufferizableOpInterface: + bool bufferizesToMemoryRead(OpOperand &opOperand, + const bufferization::AnalysisState &state) { + return opOperand.getOperandNumber() < getNumOperands() / 2; + } + bool bufferizesToMemoryWrite(OpOperand &opOperand, + const bufferization::AnalysisState &state) { + return opOperand.getOperandNumber() >= getNumOperands() / 2; + } + bufferization::AliasingOpResultList getAliasingOpResults( + OpOperand &opOperand, const bufferization::AnalysisState &state) { + if (bufferizesToMemoryRead(opOperand, state)) return {}; + return {getOperation()->getOpResult(opOperand.getOperandNumber() - getNumOperands() / 2)}; + } + LogicalResult bufferize(RewriterBase &rewriter, + const bufferization::BufferizationOptions &options); + bufferization::BufferRelation bufferRelation(OpResult opResult, + const bufferization::AnalysisState &state) { + return bufferization::BufferRelation::Equivalent; + } + }]; +} + +def FftOp : XlaCpu_Op<"fft", [BufferizableOpInterface]> { + let summary = "CPU-specific version of FFT"; + let description = [{ + The major differences between this and HLO's fft are: + - It bufferizes to itself. + - It uses destination passing style. + }]; + + let arguments = (ins + TensorOrMemref:$operand, + TensorOrMemref:$dst, + I32Attr:$fft_type, + I64ArrayAttr:$fft_length + ); + let results = (outs Variadic); + let extraClassDeclaration = [{ + // Declarations for BufferizableOpInterface: + bool bufferizesToMemoryRead(OpOperand &opOperand, + const bufferization::AnalysisState &state) { + return opOperand.getOperandNumber() == 0; + } + bool bufferizesToMemoryWrite(OpOperand &opOperand, + const bufferization::AnalysisState &state) { + return opOperand.getOperandNumber() == 1; + } + bufferization::AliasingOpResultList getAliasingOpResults( + OpOperand &opOperand, const bufferization::AnalysisState &state) { + if (opOperand.getOperandNumber() == 0) return {}; + return {getOperation()->getOpResult(0)}; + } + LogicalResult bufferize(RewriterBase &rewriter, + const bufferization::BufferizationOptions &options); + bufferization::BufferRelation bufferRelation(OpResult opResult, + const bufferization::AnalysisState &state) { + return bufferization::BufferRelation::Equivalent; + } + }]; +} + +def OutfeedOp : XlaCpu_Op<"outfeed", [BufferizableOpInterface]> { + let summary = "CPU-specific version of Outfeed"; + let description = [{ + The major differences between this and HLO's outfeed are: + - It bufferizes to itself. + - It captures the output type to reinstate it after signless conversions. 
+ }]; + let arguments = (ins + Variadic:$operand, + DefaultValuedStrAttr:$config, + ArrayAttr:$result_type + ); + let extraClassDeclaration = [{ + // Declarations for BufferizableOpInterface: + bool bufferizesToMemoryRead(OpOperand &opOperand, + const bufferization::AnalysisState &state) { + return true; + } + bool bufferizesToMemoryWrite(OpOperand &opOperand, + const bufferization::AnalysisState &state) { + return false; + } + bufferization::AliasingOpResultList getAliasingOpResults( + OpOperand &opOperand, const bufferization::AnalysisState &state) { + return {}; + } + LogicalResult bufferize(RewriterBase &rewriter, + const bufferization::BufferizationOptions &options); + bufferization::BufferRelation bufferRelation(OpResult opResult, + const bufferization::AnalysisState &state) { + return bufferization::BufferRelation::Equivalent; + } + }]; +} + +def MemRefElementCastOp : XlaCpu_Op<"memref_element_cast", + [SameOperandsAndResultShape]> { + let summary = "MemRef reinterpret_cast on element types"; + let description = [{ + This op is the equivalent of C++'s reinterpret_cast on pointers. The element + types' storage sizes must be the same. Does not cast shapes. + }]; + let arguments = (ins + MemRefOf<[I1, I8, I16, I32, I64, BF16, F16, F32, F64]>:$src + ); + let results = (outs + MemRefOf<[I1, I8, I16, I32, I64, BF16, F16, F32, F64]>:$dst + ); + let assemblyFormat = "$src attr-dict `:` type($src) `to` type($dst)"; + let hasVerifier = 1; +} + +def RngBitGeneratorOp : XlaCpu_Op<"rng_bit_generator", [BufferizableOpInterface]> { + let summary = "CPU-specific version of rng_bit_generator"; + let description = [{ + The major differences between this and HLO's rng_bit_generator are: + - It bufferizes to itself. + - It uses destination passing style. + }]; + let arguments = (ins + TensorOrMemref:$state, + TensorOrMemref:$dst_state, + TensorOrMemref:$dst, + AnyAttr:$rng_algorithm + ); + let results = (outs Variadic); + let extraClassDeclaration = [{ + // Declarations for BufferizableOpInterface: + bool bufferizesToMemoryRead(OpOperand &opOperand, + const bufferization::AnalysisState &state) { + return opOperand.getOperandNumber() == 0; + } + bool bufferizesToMemoryWrite(OpOperand &opOperand, + const bufferization::AnalysisState &state) { + return opOperand.getOperandNumber() != 0; + } + bufferization::AliasingOpResultList getAliasingOpResults( + OpOperand &opOperand, const bufferization::AnalysisState &state) { + if (opOperand.getOperandNumber() == 0) return {}; + return {getOperation()->getOpResult(opOperand.getOperandNumber()-1)}; + } + LogicalResult bufferize(RewriterBase &rewriter, + const bufferization::BufferizationOptions &options); + bufferization::BufferRelation bufferRelation(OpResult opResult, + const bufferization::AnalysisState &state) { + return bufferization::BufferRelation::Equivalent; + } + }]; +} + +def AddDependencyOp : XlaCpu_Op<"add_dependency", [BufferizableOpInterface]> { + let summary = "CPU-specific version of AddDependency"; + let description = [{ + The major differences between this and HLO's add_dependency are: + - It bufferizes itself. 
+ }]; + let arguments = (ins + MHLO_TensorOrToken:$operand, + MHLO_Token:$token + ); + let results = (outs MHLO_TensorOrToken); + let extraClassDeclaration = [{ + // Declarations for BufferizableOpInterface: + bool bufferizesToMemoryRead(OpOperand &opOperand, + const bufferization::AnalysisState &state) { + return opOperand.getOperandNumber() == 0; + } + bool bufferizesToMemoryWrite(OpOperand &opOperand, + const bufferization::AnalysisState &state) { + return false; + } + bufferization::AliasingOpResultList getAliasingOpResults( + OpOperand &opOperand, const bufferization::AnalysisState &state) { + if (opOperand.getOperandNumber() == 0 || opOperand.getOperandNumber() == 1) + return {}; + return {getOperation()->getOpResult(0)}; + } + LogicalResult bufferize(RewriterBase &rewriter, + const bufferization::BufferizationOptions &options); + bufferization::BufferRelation bufferRelation(OpResult opResult, + const bufferization::AnalysisState &state) { + return bufferization::BufferRelation::Unknown; + } + }]; +} + +#endif // TENSORFLOW_COMPILER_XLA_MLIR_XLA_CPU_OPS_TD_ diff --git a/tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data/BUILD b/tensorflow/compiler/xla/mlir/xla_cpu/tests/BUILD similarity index 55% rename from tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data/BUILD rename to tensorflow/compiler/xla/mlir/xla_cpu/tests/BUILD index 30aa6a6ef3d..c29814eebc5 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data/BUILD +++ b/tensorflow/compiler/xla/mlir/xla_cpu/tests/BUILD @@ -1,23 +1,25 @@ +load("//tensorflow/tsl:tsl.default.bzl", "filegroup") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") -load("//tensorflow:tensorflow.bzl", "if_oss") package(licenses = ["notice"]) glob_lit_tests( data = [":test_utilities"], - driver = "//tensorflow/compiler/mlir:run_lit.sh", - features = if_oss(["--path=org_tensorflow/tensorflow/compiler/mlir/tfrt"]), - test_file_exts = ["mlir"], + driver = "//tensorflow/compiler/xla:run_lit.sh", + test_file_exts = [ + "mlir", + ], ) # Bundle together all of the test utilities that are used by tests. +# This intentionally does not pull-in the top-level tf-opt to reduce the +# dependencies. 
filegroup( name = "test_utilities", testonly = True, data = [ - "//tensorflow/compiler/mlir/tfrt:tf-tfrt-opt", + "//tensorflow/compiler/xla/mlir/backends/cpu:xla-cpu-opt", "@llvm-project//llvm:FileCheck", - "@llvm-project//llvm:not", "@llvm-project//mlir:run_lit.sh", ], ) diff --git a/tensorflow/compiler/xla/mlir/xla_cpu/tests/bufferize.mlir b/tensorflow/compiler/xla/mlir/xla_cpu/tests/bufferize.mlir new file mode 100644 index 00000000000..f0d5cde450c --- /dev/null +++ b/tensorflow/compiler/xla/mlir/xla_cpu/tests/bufferize.mlir @@ -0,0 +1,129 @@ +// RUN: xla-cpu-opt %s -split-input-file -empty-tensor-to-alloc-tensor \ +// RUN: -one-shot-bufferize | FileCheck %s + +func.func @max_reduce(%arg0: tensor<10xf32>) -> tensor<10xf32> { + %0 = tensor.empty() : tensor<10xf32> + %1 = "xla_cpu.all_reduce"(%arg0, %0) { + channel_handle = 5 : i64, + reduction_kind = 3 : i32, + replica_groups = dense<[]> : tensor<0xi64>, + use_global_device_ids = 0 : i32 + } : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> + return %1 : tensor<10xf32> +} + +// CHECK-LABEL: @max_reduce +// CHECK-SAME: %[[ARG0:.*]]: tensor<10xf32> +// CHECK: %[[ARG0_MEMREF:.*]] = bufferization.to_memref %[[ARG0]] +// CHECK: %[[OUT:.*]] = memref.alloc() {{.*}} memref<10xf32> +// CHECK: "xla_cpu.all_reduce"(%[[ARG0_MEMREF]], %[[OUT]]) { +// CHECK-SAME: channel_handle = 5 +// CHECK: %[[RESULT:.*]] = bufferization.to_tensor %[[OUT]] +// CHECK: return %[[RESULT]] + +// ----- + +func.func @collective_permute(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> { + %0 = tensor.empty() : tensor<16x8xf32> + %1 = "xla_cpu.collective_permute"(%arg0, %0) { + channel_handle = 1 : i64, + source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64> + } : (tensor<16x8xf32>, tensor<16x8xf32>) -> tensor<16x8xf32> + return %1 : tensor<16x8xf32> +} + +// CHECK-LABEL: @collective_permute +// CHECK-SAME: %[[ARG0:.*]]: tensor<16x8xf32> +// CHECK: %[[ARG0_MEMREF:.*]] = bufferization.to_memref %[[ARG0]] +// CHECK: %[[OUT:.*]] = memref.alloc() {{.*}} memref<16x8xf32> +// CHECK: "xla_cpu.collective_permute"(%[[ARG0_MEMREF]], %[[OUT]]) { +// CHECK-SAME: channel_handle = 1 +// CHECK: %[[RESULT:.*]] = bufferization.to_tensor %[[OUT]] +// CHECK: return %[[RESULT]] + +// ----- + +func.func @all_to_all(%arg0: tensor<4x16xf32>) -> tensor<16x4xf32> { + %0 = tensor.empty() : tensor<16x4xf32> + %1 = "xla_cpu.all_to_all"(%arg0, %0) { + concat_dimension = 0 : i64, + replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + split_count = 4 : i64, + split_dimension = 1 : i64 + } : (tensor<4x16xf32>, tensor<16x4xf32>) -> tensor<16x4xf32> + return %1 : tensor<16x4xf32> +} + +// CHECK-LABEL: @all_to_all +// CHECK-SAME: %[[ARG0:.*]]: tensor<4x16xf32> +// CHECK: %[[ARG0_MEMREF:.*]] = bufferization.to_memref %[[ARG0]] +// CHECK: %[[OUT:.*]] = memref.alloc() {{.*}} memref<16x4xf32> +// CHECK: "xla_cpu.all_to_all"(%[[ARG0_MEMREF]], %[[OUT]]) { +// CHECK-SAME: split_count = 4 +// CHECK: %[[RESULT:.*]] = bufferization.to_tensor %[[OUT]] +// CHECK: return %[[RESULT]] + + +// ----- + +func.func @all_to_all_tuple(%arg0: tensor<128x4xf32>, %arg1: tensor<128x4xf32>) + -> (tensor<128x4xf32>, tensor<128x4xf32>) { + %0 = tensor.empty() : tensor<128x4xf32> + %1 = tensor.empty() : tensor<128x4xf32> + %2:2 = "xla_cpu.all_to_all"(%arg0, %arg1, %0, %1) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> + } : (tensor<128x4xf32>, tensor<128x4xf32>, + tensor<128x4xf32>, tensor<128x4xf32>) -> + (tensor<128x4xf32>, tensor<128x4xf32>) + return %2#0, %2#1 : tensor<128x4xf32>, tensor<128x4xf32> +} 
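+// xla_cpu.all_to_all uses destination-passing style: operands are ordered as
+// (inputs..., destinations...), so %arg0 and %arg1 above are the inputs and
+// %0, %1 are the destination buffers that the checks below expect to be
+// allocated.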
+ +// CHECK-LABEL: @all_to_all_tuple +// CHECK-SAME: %[[ARG0:.*]]: tensor<128x4xf32>, +// CHECK-SAME: %[[ARG1:.*]]: tensor<128x4xf32> +// CHECK-DAG: %[[ARG0_MEMREF:.*]] = bufferization.to_memref %[[ARG0]] +// CHECK-DAG: %[[ARG1_MEMREF:.*]] = bufferization.to_memref %[[ARG1]] +// CHECK-DAG: "xla_cpu.all_to_all"(%[[ARG0_MEMREF]], %[[ARG1_MEMREF]], %[[OUT0:.*]], %[[OUT1:.*]]) { +// CHECK-DAG: %[[OUT0]] = memref.alloc() {{.*}} memref<128x4xf32> +// CHECK-DAG: %[[OUT1]] = memref.alloc() {{.*}} memref<128x4xf32> +// CHECK-DAG: %[[RESULT0:.*]] = bufferization.to_tensor %[[OUT0]] : +// CHECK-DAG: %[[RESULT1:.*]] = bufferization.to_tensor %[[OUT1]] : +// CHECK: return %[[RESULT0]], %[[RESULT1]] + +// ----- + +func.func @fft(%arg0: tensor<3x5x4x8x256xf32>) -> tensor<3x5x4x8x129xcomplex> { + %0 = tensor.empty() : tensor<3x5x4x8x129xcomplex> + %1 = "xla_cpu.fft"(%arg0, %0) { + fft_length = [4, 8, 256], + fft_type = 2 : i32 + } : (tensor<3x5x4x8x256xf32>,tensor<3x5x4x8x129xcomplex>) -> tensor<3x5x4x8x129xcomplex> + return %1 : tensor<3x5x4x8x129xcomplex> +} + +// CHECK-LABEL: @fft +// CHECK-SAME: %[[ARG0:.*]]: tensor<3x5x4x8x256xf32> +// CHECK: %[[ARG0_MEMREF:.*]] = bufferization.to_memref %[[ARG0]] +// CHECK: %[[OUT:.*]] = memref.alloc() {{.*}} +// CHECK: "xla_cpu.fft"(%[[ARG0_MEMREF]], %[[OUT]]) + + +// ----- + +func.func @rng_bit_generator(%state: tensor<2xui64>) -> (tensor<2xui64>, tensor<10x12xui32>) { + %new_state_init = tensor.empty() : tensor<2xui64> + %output_init = tensor.empty() : tensor<10x12xui32> + %new_state, %output = "xla_cpu.rng_bit_generator"(%state, %new_state_init, + %output_init) { + rng_algorithm = #mhlo.rng_algorithm + } : (tensor<2xui64>, tensor<2xui64>, tensor<10x12xui32>) + -> (tensor<2xui64>, tensor<10x12xui32>) + func.return %new_state, %output : tensor<2xui64>, tensor<10x12xui32> +} + +// CHECK-LABEL: @rng_bit_generator +// CHECK-SAME: %[[STATE:.*]]: tensor +// CHECK: %[[STATE_MEMREF:.*]] = bufferization.to_memref %[[STATE]] +// CHECK: %[[STATE_OUT:.*]] = memref.alloc() {{.*}}<2xui64> +// CHECK: %[[OUTPUT:.*]] = memref.alloc() {{.*}}<10x12xui32> +// CHECK: "xla_cpu.rng_bit_generator"(%[[STATE_MEMREF]], %[[STATE_OUT]], %[[OUTPUT]]) \ No newline at end of file diff --git a/tensorflow/compiler/xla/mlir/xla_cpu/tests/invalid.mlir b/tensorflow/compiler/xla/mlir/xla_cpu/tests/invalid.mlir new file mode 100644 index 00000000000..8f9584417e6 --- /dev/null +++ b/tensorflow/compiler/xla/mlir/xla_cpu/tests/invalid.mlir @@ -0,0 +1,7 @@ +// RUN: xla-cpu-opt %s -split-input-file -verify-diagnostics + +func.func @memref_cast_out_of_place(%arg0: memref<10xi1>) -> memref<10xi16> { + // expected-error @+1 {{cannot cast from 'i1' to 'i16'}} + %ret = xla_cpu.memref_element_cast %arg0 : memref<10xi1> to memref<10xi16> + return %ret : memref<10xi16> +} diff --git a/tensorflow/compiler/xla/mlir/xla_cpu/tests/ops.mlir b/tensorflow/compiler/xla/mlir/xla_cpu/tests/ops.mlir new file mode 100644 index 00000000000..7f06ab3fd3d --- /dev/null +++ b/tensorflow/compiler/xla/mlir/xla_cpu/tests/ops.mlir @@ -0,0 +1,16 @@ +// RUN: xla-cpu-opt %s -split-input-file -empty-tensor-to-alloc-tensor \ +// RUN: -one-shot-bufferize | FileCheck %s + +func.func @memref_cast(%arg0: memref<10xf32>) -> memref<10xi32> { + %ret = xla_cpu.memref_element_cast %arg0 : memref<10xf32> to memref<10xi32> + return %ret : memref<10xi32> +} + +// CHECK: xla_cpu.memref_element_cast {{.*}} : memref<10xf32> to memref<10xi32> + +func.func @memref_cast_i1(%arg0: memref<10xi1>) -> memref<10xi8> { + %ret = xla_cpu.memref_element_cast %arg0 : 
memref<10xi1> to memref<10xi8> + return %ret : memref<10xi8> +} + +// CHECK: xla_cpu.memref_element_cast {{.*}} : memref<10xi1> to memref<10xi8> \ No newline at end of file diff --git a/tensorflow/compiler/xla/mlir_hlo/BUILD b/tensorflow/compiler/xla/mlir_hlo/BUILD index 877c1c738c4..fb8b4700681 100644 --- a/tensorflow/compiler/xla/mlir_hlo/BUILD +++ b/tensorflow/compiler/xla/mlir_hlo/BUILD @@ -1,30 +1,31 @@ load("//tensorflow/tsl/platform:rules_cc.bzl", "cc_library") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_cloud") +load("//tensorflow/tsl:tsl.default.bzl", "get_compatible_with_cloud") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "gentbl_filegroup", "td_library") package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:public"], licenses = ["notice"], ) exports_files([ - "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td", - "include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.td", + "mhlo/IR/hlo_ops.td", + "lhlo/IR/lhlo_ops.td", ]) # Python extension sources. -exports_files(["python/MlirHloModule.cc"]) +exports_files(["bindings/python/MlirHloModule.cc"]) filegroup( name = "hlo_ops_td_filegroup", - srcs = glob(["include/mlir-hlo/Dialect/mhlo/IR/*.td"]), + srcs = glob(["mhlo/IR/*.td"]), ) td_library( name = "hlo_ops_td_files", - srcs = glob(["include/mlir-hlo/Dialect/mhlo/IR/*.td"]), + srcs = glob(["mhlo/IR/*.td"]), compatible_with = get_compatible_with_cloud(), - includes = ["include"], + includes = ["."], deps = [ "@llvm-project//mlir:BuiltinDialectTdFiles", "@llvm-project//mlir:ControlFlowInterfacesTdFiles", @@ -45,130 +46,133 @@ td_library( gentbl_cc_library( name = "mhlo_pass_inc_gen", compatible_with = get_compatible_with_cloud(), - strip_include_prefix = "include", + strip_include_prefix = ".", tbl_outs = [ ( [ "-gen-pass-decls", "-name=AllMhlo", ], - "include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc", + "mhlo/transforms/mhlo_passes.h.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td", + td_file = "mhlo/transforms/mhlo_passes.td", deps = ["@llvm-project//mlir:PassBaseTdFiles"], ) gentbl_cc_library( name = "lmhlo_pass_inc_gen", compatible_with = get_compatible_with_cloud(), - strip_include_prefix = "include", + strip_include_prefix = ".", tbl_outs = [ ( [ "-gen-pass-decls", "-name=AllLmhlo", ], - "include/mlir-hlo/Dialect/lhlo/transforms/lmhlo_passes.h.inc", + "lhlo/transforms/lmhlo_passes.h.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "include/mlir-hlo/Dialect/lhlo/transforms/lmhlo_passes.td", + td_file = "lhlo/transforms/lmhlo_passes.td", deps = ["@llvm-project//mlir:PassBaseTdFiles"], ) gentbl_cc_library( name = "hlo_ops_inc_gen", compatible_with = get_compatible_with_cloud(), - strip_include_prefix = "include", + strip_include_prefix = ".", tbl_outs = [ ( ["-gen-op-decls"], - "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc", + "mhlo/IR/hlo_ops.h.inc", ), ( ["-gen-op-defs"], - "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc", + "mhlo/IR/hlo_ops.cc.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td", + td_file = "mhlo/IR/hlo_ops.td", deps = [":hlo_ops_td_files"], ) gentbl_cc_library( name = "hlo_ops_attrs_inc_gen", compatible_with = get_compatible_with_cloud(), + strip_include_prefix = ".", tbl_outs = [ ( ["-gen-attrdef-decls"], - "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_attrs.h.inc", + "mhlo/IR/hlo_ops_attrs.h.inc", ), ( 
["-gen-attrdef-defs"], - "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_attrs.cc.inc", + "mhlo/IR/hlo_ops_attrs.cc.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td", + td_file = "mhlo/IR/hlo_ops.td", deps = [":hlo_ops_td_files"], ) gentbl_cc_library( name = "hlo_ops_enums_inc_gen", compatible_with = get_compatible_with_cloud(), + strip_include_prefix = ".", tbl_outs = [ ( ["-gen-enum-decls"], - "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_enums.h.inc", + "mhlo/IR/hlo_ops_enums.h.inc", ), ( ["-gen-enum-defs"], - "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_enums.cc.inc", + "mhlo/IR/hlo_ops_enums.cc.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td", + td_file = "mhlo/IR/hlo_ops.td", deps = [":hlo_ops_td_files"], ) gentbl_cc_library( name = "hlo_ops_typedefs_inc_gen", compatible_with = get_compatible_with_cloud(), + strip_include_prefix = ".", tbl_outs = [ ( [ "-gen-typedef-decls", "--typedefs-dialect=mhlo", ], - "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_typedefs.h.inc", + "mhlo/IR/hlo_ops_typedefs.h.inc", ), ( [ "-gen-typedef-defs", "--typedefs-dialect=mhlo", ], - "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_typedefs.cc.inc", + "mhlo/IR/hlo_ops_typedefs.cc.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td", + td_file = "mhlo/IR/hlo_ops.td", deps = [":hlo_ops_td_files"], ) gentbl_cc_library( name = "hlo_ops_pattern_inc_gen", compatible_with = get_compatible_with_cloud(), - strip_include_prefix = "lib/Dialect/mhlo/IR/", + strip_include_prefix = "mhlo/IR/", tbl_outs = [ ( ["-gen-rewriters"], - "lib/Dialect/mhlo/IR/hlo_patterns.cc.inc", + "mhlo/IR/hlo_patterns.cc.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "lib/Dialect/mhlo/IR/hlo_patterns.td", + td_file = "mhlo/IR/hlo_patterns.td", deps = [ ":hlo_ops_td_files", "@llvm-project//mlir:FuncTdFiles", @@ -179,95 +183,95 @@ gentbl_cc_library( gentbl_cc_library( name = "lhlo_ops_structs_inc_gen", compatible_with = get_compatible_with_cloud(), - strip_include_prefix = "include", + strip_include_prefix = ".", tbl_outs = [ ( ["-gen-attrdef-decls"], - "include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops_structs.h.inc", + "lhlo/IR/lhlo_ops_structs.h.inc", ), ( ["-gen-attrdef-defs"], - "include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops_structs.cc.inc", + "lhlo/IR/lhlo_ops_structs.cc.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops_structs.td", + td_file = "lhlo/IR/lhlo_ops_structs.td", deps = [":lhlo_ops_td_files"], ) gentbl_cc_library( name = "lhlo_ops_inc_gen", compatible_with = get_compatible_with_cloud(), - strip_include_prefix = "include", + strip_include_prefix = ".", tbl_outs = [ ( ["-gen-op-decls"], - "include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h.inc", + "lhlo/IR/lhlo_ops.h.inc", ), ( ["-gen-op-defs"], - "include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.cc.inc", + "lhlo/IR/lhlo_ops.cc.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.td", + td_file = "lhlo/IR/lhlo_ops.td", deps = [":lhlo_ops_td_files"], ) gentbl_cc_library( name = "lhlo_gpu_ops_enums_inc_gen", compatible_with = get_compatible_with_cloud(), - strip_include_prefix = "include", + strip_include_prefix = ".", tbl_outs = [ ( ["-gen-enum-decls"], - "include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops_enums.h.inc", + "lhlo_gpu/IR/lhlo_gpu_ops_enums.h.inc", ), ( ["-gen-enum-defs"], - 
"include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops_enums.cc.inc", + "lhlo_gpu/IR/lhlo_gpu_ops_enums.cc.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops_enums.td", + td_file = "lhlo_gpu/IR/lhlo_gpu_ops_enums.td", deps = [":lhlo_gpu_ops_td_files"], ) gentbl_cc_library( name = "lhlo_gpu_ops_dialect_inc_gen", compatible_with = get_compatible_with_cloud(), - strip_include_prefix = "include", + strip_include_prefix = ".", tbl_outs = [ ( ["-gen-dialect-decls"], - "include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops_dialect.h.inc", + "lhlo_gpu/IR/lhlo_gpu_ops_dialect.h.inc", ), ( ["-gen-dialect-defs"], - "include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops_dialect.cc.inc", + "lhlo_gpu/IR/lhlo_gpu_ops_dialect.cc.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops_enums.td", + td_file = "lhlo_gpu/IR/lhlo_gpu_ops_enums.td", deps = [":lhlo_gpu_ops_td_files"], ) gentbl_cc_library( name = "lhlo_gpu_ops_attrdefs_inc_gen", compatible_with = get_compatible_with_cloud(), - strip_include_prefix = "include", + strip_include_prefix = ".", tbl_outs = [ ( ["-gen-attrdef-decls"], - "include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops_attrdefs.h.inc", + "lhlo_gpu/IR/lhlo_gpu_ops_attrdefs.h.inc", ), ( ["-gen-attrdef-defs"], - "include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops_attrdefs.cc.inc", + "lhlo_gpu/IR/lhlo_gpu_ops_attrdefs.cc.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops_enums.td", + td_file = "lhlo_gpu/IR/lhlo_gpu_ops_enums.td", deps = [":lhlo_gpu_ops_td_files"], ) @@ -284,7 +288,7 @@ gentbl_filegroup( ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td", + td_file = "mhlo/IR/hlo_ops.td", deps = [":hlo_ops_td_files"], ) @@ -301,15 +305,15 @@ gentbl_filegroup( ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.td", + td_file = "lhlo/IR/lhlo_ops.td", deps = [":lhlo_ops_td_files"], ) cc_library( name = "hlo_ops_common", - srcs = ["lib/Dialect/mhlo/IR/hlo_ops_common.cc"], - hdrs = ["include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"], - includes = ["include"], + srcs = ["mhlo/IR/hlo_ops_common.cc"], + hdrs = ["mhlo/IR/hlo_ops_common.h"], + strip_include_prefix = ".", deps = [ "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", @@ -318,9 +322,9 @@ cc_library( td_library( name = "lhlo_gpu_ops_td_files", - srcs = glob(["include/mlir-hlo/Dialect/lhlo_gpu/IR/*.td"]), + srcs = glob(["lhlo_gpu/IR/*.td"]), compatible_with = get_compatible_with_cloud(), - includes = ["include"], + includes = ["."], deps = [ ":hlo_ops_td_files", ":lhlo_ops_td_files", @@ -331,19 +335,19 @@ td_library( gentbl_cc_library( name = "lhlo_gpu_ops_inc_gen", compatible_with = get_compatible_with_cloud(), - strip_include_prefix = "include", + strip_include_prefix = ".", tbl_outs = [ ( ["-gen-op-decls"], - "include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h.inc", + "lhlo_gpu/IR/lhlo_gpu_ops.h.inc", ), ( ["-gen-op-defs"], - "include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.cc.inc", + "lhlo_gpu/IR/lhlo_gpu_ops.cc.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.td", + td_file = "lhlo_gpu/IR/lhlo_gpu_ops.td", deps = [":lhlo_gpu_ops_td_files"], ) @@ -351,23 +355,23 @@ gentbl_cc_library( gentbl_cc_library( name = "canonicalize_inc_gen", compatible_with = 
get_compatible_with_cloud(), - strip_include_prefix = "lib/Dialect/mhlo/IR/", + strip_include_prefix = ".", tbl_outs = [ ( ["-gen-rewriters"], - "lib/Dialect/mhlo/IR/mhlo_canonicalize.inc", + "mhlo/IR/mhlo_canonicalize.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "lib/Dialect/mhlo/IR/mhlo_canonicalize.td", + td_file = "mhlo/IR/mhlo_canonicalize.td", deps = [":hlo_ops_td_files"], ) td_library( name = "lhlo_ops_td_files", - srcs = glob(["include/mlir-hlo/Dialect/lhlo/IR/*.td"]), + srcs = glob(["lhlo/IR/*.td"]), compatible_with = get_compatible_with_cloud(), - includes = ["include"], + includes = ["."], deps = [ ":hlo_ops_td_files", "@llvm-project//mlir:ControlFlowInterfacesTdFiles", @@ -385,29 +389,30 @@ td_library( gentbl_cc_library( name = "lhlo_structured_interface_inc_gen", compatible_with = get_compatible_with_cloud(), + strip_include_prefix = ".", tbl_outs = [ ( ["-gen-op-interface-decls"], - "include/mlir-hlo/Dialect/lhlo/IR/lhlo_structured_interface.h.inc", + "lhlo/IR/lhlo_structured_interface.h.inc", ), ( ["-gen-op-interface-defs"], - "include/mlir-hlo/Dialect/lhlo/IR/lhlo_structured_interface.cpp.inc", + "lhlo/IR/lhlo_structured_interface.cpp.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "include/mlir-hlo/Dialect/lhlo/IR/lhlo_structured_interface.td", + td_file = "lhlo/IR/lhlo_structured_interface.td", deps = [":lhlo_ops_td_files"], ) cc_library( name = "lhlo_structured_interface", - srcs = ["lib/Dialect/lhlo/IR/lhlo_structured_interface.cc"], + srcs = ["lhlo/IR/lhlo_structured_interface.cc"], hdrs = [ - "include/mlir-hlo/Dialect/lhlo/IR/lhlo_structured_interface.h", - "include/mlir-hlo/Dialect/lhlo/IR/lhlo_structured_interface.h.inc", + "lhlo/IR/lhlo_structured_interface.h", + "lhlo/IR/lhlo_structured_interface.h.inc", ], - includes = ["include"], + strip_include_prefix = ".", deps = [ ":lhlo_structured_interface_inc_gen", "@llvm-project//mlir:IR", @@ -417,9 +422,9 @@ cc_library( cc_library( name = "convert_op_folder", - srcs = ["lib/utils/convert_op_folder.cc"], - hdrs = ["include/mlir-hlo/utils/convert_op_folder.h"], - includes = ["include"], + srcs = ["utils/convert_op_folder.cc"], + hdrs = ["utils/convert_op_folder.h"], + strip_include_prefix = ".", deps = [ "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", @@ -429,24 +434,24 @@ cc_library( cc_library( name = "mlir_hlo", srcs = [ - "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc", - "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc", - "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_attrs.cc.inc", - "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_attrs.h.inc", - "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_enums.cc.inc", - "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_enums.h.inc", - "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_typedefs.cc.inc", - "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_typedefs.h.inc", - "lib/Dialect/mhlo/IR/hlo_ops.cc", - "lib/Dialect/mhlo/IR/mhlo_bytecode.cc", - "lib/utils/hlo_utils.cc", + "mhlo/IR/hlo_ops.cc", + "mhlo/IR/hlo_ops.cc.inc", + "mhlo/IR/hlo_ops.h.inc", + "mhlo/IR/hlo_ops_attrs.cc.inc", + "mhlo/IR/hlo_ops_attrs.h.inc", + "mhlo/IR/hlo_ops_enums.cc.inc", + "mhlo/IR/hlo_ops_enums.h.inc", + "mhlo/IR/hlo_ops_typedefs.cc.inc", + "mhlo/IR/hlo_ops_typedefs.h.inc", + "mhlo/IR/mhlo_bytecode.cc", + "utils/hlo_utils.cc", ], hdrs = [ - "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h", - "include/mlir-hlo/Dialect/mhlo/IR/mhlo_bytecode.h", - "include/mlir-hlo/utils/hlo_utils.h", + "mhlo/IR/hlo_ops.h", + "mhlo/IR/mhlo_bytecode.h", + "utils/hlo_utils.h", ], - includes = ["include"], + 
strip_include_prefix = ".", deps = [ ":canonicalize_inc_gen", ":convert_op_folder", @@ -465,6 +470,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:QuantOps", @@ -485,13 +491,13 @@ cc_library( cc_library( name = "lhlo", - srcs = ["lib/Dialect/lhlo/IR/lhlo_ops.cc"], + srcs = ["lhlo/IR/lhlo_ops.cc"], hdrs = [ - "include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h", - "include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops_structs.h", - "include/mlir-hlo/utils/lhlo_utils.h", + "lhlo/IR/lhlo_ops.h", + "lhlo/IR/lhlo_ops_structs.h", + "lhlo/utils/lhlo_utils.h", ], - includes = ["include"], + strip_include_prefix = ".", deps = [ ":hlo_ops_common", ":lhlo_ops_inc_gen", @@ -507,14 +513,15 @@ cc_library( "@llvm-project//mlir:LoopLikeInterface", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:ViewLikeInterface", + "@stablehlo//:stablehlo_type_inference", ], ) cc_library( name = "lhlo_gpu", - srcs = ["lib/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.cc"], - hdrs = ["include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h"], - includes = ["include"], + srcs = ["lhlo_gpu/IR/lhlo_gpu_ops.cc"], + hdrs = ["lhlo_gpu/IR/lhlo_gpu_ops.h"], + strip_include_prefix = ".", deps = [ ":hlo_ops_common", ":lhlo", @@ -532,9 +539,9 @@ cc_library( cc_library( name = "lhlo_gpu_ops_ops", - srcs = ["include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.cc.inc"], - hdrs = ["include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h.inc"], - includes = ["include"], + srcs = ["lhlo_gpu/IR/lhlo_gpu_ops.cc.inc"], + hdrs = ["lhlo_gpu/IR/lhlo_gpu_ops.h.inc"], + strip_include_prefix = ".", deps = [ "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", @@ -555,8 +562,9 @@ cc_library( cc_library( name = "hlo_dialect_registration", - srcs = ["lib/Dialect/mhlo/IR/init.cc"], - hdrs = ["include/mlir-hlo/Dialect/mhlo/IR/register.h"], + srcs = ["mhlo/IR/init.cc"], + hdrs = ["mhlo/IR/register.h"], + strip_include_prefix = ".", deps = [ ":mlir_hlo", "@llvm-project//mlir:IR", @@ -569,59 +577,62 @@ cc_library( cc_library( name = "mhlo_passes", srcs = [ - "include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc", - "lib/Dialect/mhlo/transforms/broadcast_propagation.cc", - "lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc", - "lib/Dialect/mhlo/transforms/collapse_elementwise_map.cc", - "lib/Dialect/mhlo/transforms/constraint_fusion_pass.cc", - "lib/Dialect/mhlo/transforms/convert_to_signless_pass.cc", - "lib/Dialect/mhlo/transforms/expand_hlo_tuples.cc", - "lib/Dialect/mhlo/transforms/generated_legalize_to_standard.inc", - "lib/Dialect/mhlo/transforms/generated_lower_complex.inc", - "lib/Dialect/mhlo/transforms/group_reduction_dimensions.cc", - "lib/Dialect/mhlo/transforms/hlo_legalize_shape_ops_to_standard.cc", - "lib/Dialect/mhlo/transforms/hlo_legalize_to_arithmetic.cc", - "lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc", - "lib/Dialect/mhlo/transforms/hlo_legalize_to_memref.cc", - "lib/Dialect/mhlo/transforms/hlo_legalize_to_stablehlo_pass.cc", - "lib/Dialect/mhlo/transforms/legalize_control_flow.cc", - "lib/Dialect/mhlo/transforms/legalize_einsum_to_dot_general.cc", - "lib/Dialect/mhlo/transforms/legalize_gather_to_torch_index_select.cc", - "lib/Dialect/mhlo/transforms/legalize_mhlo_to_thlo.cc", - "lib/Dialect/mhlo/transforms/legalize_shape_computations.cc", - "lib/Dialect/mhlo/transforms/legalize_sort.cc", - "lib/Dialect/mhlo/transforms/legalize_to_linalg.cc", 
- "lib/Dialect/mhlo/transforms/legalize_to_standard.cc", - "lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc", - "lib/Dialect/mhlo/transforms/lower_complex.cc", - "lib/Dialect/mhlo/transforms/lower_general_dot.cc", - "lib/Dialect/mhlo/transforms/materialize_broadcasts.cc", - "lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc", - "lib/Dialect/mhlo/transforms/merge_assuming_ops.cc", - "lib/Dialect/mhlo/transforms/mhlo_canonicalize_gather.cc", - "lib/Dialect/mhlo/transforms/mhlo_canonicalize_reduction.cc", - "lib/Dialect/mhlo/transforms/mhlo_canonicalize_scatter.cc", - "lib/Dialect/mhlo/transforms/mhlo_flatten_tuple.cc", - "lib/Dialect/mhlo/transforms/optimize_mhlo.cc", - "lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc", - "lib/Dialect/mhlo/transforms/prepare_for_export.cc", - "lib/Dialect/mhlo/transforms/rank_specialization.cc", - "lib/Dialect/mhlo/transforms/restrict_max_rank.cc", - "lib/Dialect/mhlo/transforms/shape_reification_pass.cc", - "lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc", - "lib/Dialect/mhlo/transforms/sparse_chlo_legalize_to_linalg.cc", - "lib/Dialect/mhlo/transforms/sparse_rewriting.cc", - "lib/Dialect/mhlo/transforms/stablehlo_legalize_to_hlo_pass.cc", - "lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc", - "lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc", + "mhlo/transforms/broadcast_propagation/broadcast_propagation.cc", + "mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_pass.cc", + "mhlo/transforms/collapse_elementwise_map/collapse_elementwise_map.cc", + "mhlo/transforms/constraint_fusion/constraint_fusion_pass.cc", + "mhlo/transforms/convert_to_signless/convert_to_signless_pass.cc", + "mhlo/transforms/expand_hlo_tuples/expand_hlo_tuples.cc", + "mhlo/transforms/expand_ops_simplifier/expand_ops_simplifier.cc", + "mhlo/transforms/group_reduction_dimensions/group_reduction_dimensions.cc", + "mhlo/transforms/hlo_legalize_shape_ops_to_standard/hlo_legalize_shape_ops_to_standard.cc", + "mhlo/transforms/hlo_legalize_to_arithmetic/hlo_legalize_to_arithmetic.cc", + "mhlo/transforms/hlo_legalize_to_lhlo/hlo_legalize_to_lhlo.cc", + "mhlo/transforms/hlo_legalize_to_memref/hlo_legalize_to_memref.cc", + "mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo_pass.cc", + "mhlo/transforms/legalize_control_flow/legalize_control_flow.cc", + "mhlo/transforms/legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc", + "mhlo/transforms/legalize_gather_to_torch_index_select/legalize_gather_to_torch_index_select.cc", + "mhlo/transforms/legalize_mhlo_to_thlo/legalize_mhlo_to_thlo.cc", + "mhlo/transforms/legalize_shape_computations/legalize_shape_computations.cc", + "mhlo/transforms/legalize_sort/legalize_sort.cc", + "mhlo/transforms/legalize_to_linalg/legalize_to_linalg.cc", + "mhlo/transforms/legalize_to_standard/generated_legalize_to_standard.inc", + "mhlo/transforms/legalize_to_standard/legalize_to_standard.cc", + "mhlo/transforms/legalize_trigonometric_to_approximation/legalize_trigonometric_to_approximation.cc", + "mhlo/transforms/lower_complex/generated_lower_complex.inc", + "mhlo/transforms/lower_complex/lower_complex.cc", + "mhlo/transforms/lower_general_dot/lower_general_dot.cc", + "mhlo/transforms/materialize_broadcasts/materialize_broadcasts.cc", + "mhlo/transforms/materialize_broadcasts/materialize_broadcasts_pass.cc", + "mhlo/transforms/merge_assuming_ops/merge_assuming_ops.cc", + "mhlo/transforms/mhlo_canonicalize_gather/mhlo_canonicalize_gather.cc", + 
"mhlo/transforms/mhlo_canonicalize_reduction/mhlo_canonicalize_reduction.cc", + "mhlo/transforms/mhlo_canonicalize_scatter/mhlo_canonicalize_scatter.cc", + "mhlo/transforms/mhlo_flatten_tuple/mhlo_flatten_tuple.cc", + "mhlo/transforms/mhlo_passes.h.inc", + "mhlo/transforms/optimize_mhlo/optimize_mhlo.cc", + "mhlo/transforms/optimize_mhlo/optimize_mhlo_pass.cc", + "mhlo/transforms/prepare_for_export/prepare_for_export.cc", + "mhlo/transforms/rank_specialization/rank_specialization.cc", + "mhlo/transforms/restrict_max_rank/restrict_max_rank.cc", + "mhlo/transforms/shape_reification/shape_reification_pass.cc", + "mhlo/transforms/shape_simplification/shape_simplification.cc", + "mhlo/transforms/sink_constants_to_control_flow/sink_constants_to_control_flow.cc", + "mhlo/transforms/sparse_chlo_legalize_to_linalg/sparse_chlo_legalize_to_linalg.cc", + "mhlo/transforms/sparse_rewriting/sparse_rewriting.cc", + "mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc", + "mhlo/transforms/symbolic_shape_optimization/symbolic_shape_optimization.cc", + "mhlo/transforms/test_infer_shaped_type/test_infer_shaped_type_pass.cc", + "mhlo/transforms/unfuse_batch_norm/unfuse_batch_norm_pass.cc", ], hdrs = [ - "include/mlir-hlo/Dialect/mhlo/transforms/bufferizable_op_interface_impl.h", - "include/mlir-hlo/Dialect/mhlo/transforms/legalize_to_linalg_utils.h", - "include/mlir-hlo/Dialect/mhlo/transforms/passes.h", - "include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h", + "mhlo/interfaces/bufferizable_op_interface_impl.h", + "mhlo/transforms/passes.h", + "mhlo/transforms/rewriters.h", + "mhlo/utils/legalize_to_linalg_utils.h", ], - includes = ["include"], + strip_include_prefix = ".", deps = [ ":chlo_legalize_to_hlo", ":gml_st_bufferizable_op_interface", @@ -636,6 +647,7 @@ cc_library( ":mhlo_pass_inc_gen", ":mhlo_scatter_gather_utils", ":mlir_hlo", + ":shape_component_analysis", ":stablehlo_legalize_to_hlo", ":thlo", ":thlo_bufferizable_op_interface", @@ -655,6 +667,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:LinalgDialect", + "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:LinalgUtils", "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:MemRefDialect", @@ -662,6 +675,7 @@ cc_library( "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:ShapeTransforms", + "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:SparseTensorDialect", "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", @@ -674,9 +688,9 @@ cc_library( cc_library( name = "type_conversion", - srcs = ["lib/Dialect/mhlo/transforms/type_conversion.cc"], - hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/type_conversion.h"], - includes = ["include"], + srcs = ["mhlo/utils/type_conversion.cc"], + hdrs = ["mhlo/utils/type_conversion.h"], + strip_include_prefix = ".", deps = [ ":mlir_hlo", "@llvm-project//mlir:FuncDialect", @@ -690,7 +704,8 @@ cc_library( cc_library( name = "map_lmhlo_to_scalar_op", - hdrs = ["include/mlir-hlo/Dialect/lhlo/transforms/map_lmhlo_to_scalar_op.h"], + hdrs = ["lhlo/transforms/map_lmhlo_to_scalar_op.h"], + strip_include_prefix = ".", deps = [ ":lhlo", ":map_lhlo_to_hlo_op", @@ -708,7 +723,8 @@ cc_library( cc_library( name = "map_mhlo_to_scalar_op", - hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h"], + hdrs = ["mhlo/transforms/map_mhlo_to_scalar_op.h"], + strip_include_prefix = ".", deps = [ ":mlir_hlo", "@llvm-project//llvm:Support", @@ 
-723,7 +739,8 @@ cc_library( cc_library( name = "map_chlo_to_hlo_op", - hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"], + hdrs = ["mhlo/transforms/map_chlo_to_hlo_op.h"], + strip_include_prefix = ".", deps = [ ":mlir_hlo", "@llvm-project//mlir:IR", @@ -733,7 +750,8 @@ cc_library( cc_library( name = "map_hlo_to_lhlo_op", - hdrs = ["include/mlir-hlo/Dialect/lhlo/transforms/map_hlo_to_lhlo_op.h"], + hdrs = ["lhlo/transforms/map_hlo_to_lhlo_op.h"], + strip_include_prefix = ".", deps = [ ":lhlo", ":mlir_hlo", @@ -742,7 +760,8 @@ cc_library( cc_library( name = "map_lhlo_to_hlo_op", - hdrs = ["include/mlir-hlo/Dialect/lhlo/transforms/map_lhlo_to_hlo_op.h"], + hdrs = ["lhlo/transforms/map_lhlo_to_hlo_op.h"], + strip_include_prefix = ".", deps = [ ":lhlo", ":mlir_hlo", @@ -751,7 +770,8 @@ cc_library( cc_library( name = "map_stablehlo_to_hlo_op", - hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/map_stablehlo_to_hlo_op.h"], + hdrs = ["mhlo/transforms/map_stablehlo_to_hlo_op.h"], + strip_include_prefix = ".", deps = [ ":mlir_hlo", "@llvm-project//mlir:IR", @@ -762,14 +782,14 @@ cc_library( cc_library( name = "lmhlo_passes", srcs = [ - "include/mlir-hlo/Dialect/lhlo/transforms/lmhlo_passes.h.inc", - "lib/Dialect/lhlo/transforms/legalize_to_tensor_op.cc", - "lib/Dialect/lhlo/transforms/lhlo_fuse_linalg.cc", - "lib/Dialect/lhlo/transforms/lhlo_legalize_to_affine.cc", - "lib/Dialect/lhlo/transforms/lhlo_legalize_to_gpu.cc", - "lib/Dialect/lhlo/transforms/lhlo_legalize_to_parallel_loops.cc", - ], - hdrs = ["include/mlir-hlo/Dialect/lhlo/transforms/passes.h"], + "lhlo/transforms/legalize_to_tensor_op/legalize_to_tensor_op.cc", + "lhlo/transforms/lhlo_legalize_to_affine/lhlo_legalize_to_affine.cc", + "lhlo/transforms/lhlo_legalize_to_gpu/lhlo_legalize_to_gpu.cc", + "lhlo/transforms/lhlo_legalize_to_parallel_loops/lhlo_legalize_to_parallel_loops.cc", + "lhlo/transforms/lmhlo_passes.h.inc", + ], + hdrs = ["lhlo/transforms/passes.h"], + strip_include_prefix = ".", deps = [ ":lhlo", ":lmhlo_pass_inc_gen", @@ -782,7 +802,6 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:IR", - "@llvm-project//mlir:LinalgAnalysis", "@llvm-project//mlir:LinalgDialect", "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:MemRefDialect", @@ -799,9 +818,9 @@ cc_library( cc_library( name = "codegen_utils", - srcs = ["lib/utils/codegen_utils.cc"], - hdrs = ["include/mlir-hlo/utils/codegen_utils.h"], - includes = ["include"], + srcs = ["utils/codegen_utils.cc"], + hdrs = ["utils/codegen_utils.h"], + strip_include_prefix = ".", deps = [ "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", @@ -815,15 +834,16 @@ cc_library( cc_library( name = "placement_utils", - hdrs = ["include/mlir-hlo/utils/placement_utils.h"], - includes = ["include"], + hdrs = ["utils/placement_utils.h"], + strip_include_prefix = ".", deps = ["@llvm-project//llvm:Support"], ) cc_library( name = "lhlo_elemental_utils", - srcs = ["lib/Dialect/lhlo/transforms/lhlo_elemental_utils.cc"], - hdrs = ["include/mlir-hlo/Dialect/lhlo/transforms/lhlo_elemental_utils.h"], + srcs = ["lhlo/transforms/lhlo_elemental_utils.cc"], + hdrs = ["lhlo/transforms/lhlo_elemental_utils.h"], + strip_include_prefix = ".", deps = [ ":codegen_utils", ":lhlo", @@ -842,8 +862,9 @@ cc_library( cc_library( name = "legalize_to_linalg_utils", - srcs = ["lib/Dialect/mhlo/transforms/legalize_to_linalg_utils.cc"], - hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/legalize_to_linalg_utils.h"], + 
srcs = ["mhlo/utils/legalize_to_linalg_utils.cc"], + hdrs = ["mhlo/utils/legalize_to_linalg_utils.h"], + strip_include_prefix = ".", deps = [ ":map_mhlo_to_scalar_op", ":mlir_hlo", @@ -868,8 +889,9 @@ cc_library( cc_library( name = "mhlo_scatter_gather_utils", - srcs = ["lib/Dialect/mhlo/transforms/mhlo_scatter_gather_utils.cc"], - hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/mhlo_scatter_gather_utils.h"], + srcs = ["mhlo/utils/mhlo_scatter_gather_utils.cc"], + hdrs = ["mhlo/utils/mhlo_scatter_gather_utils.h"], + strip_include_prefix = ".", deps = [ ":mlir_hlo", "@llvm-project//mlir:DialectUtils", @@ -880,15 +902,15 @@ cc_library( gentbl_cc_library( name = "legalize_to_standard_inc_gen", compatible_with = get_compatible_with_cloud(), - strip_include_prefix = "lib/Dialect/mhlo/transforms/", + strip_include_prefix = "mhlo/transforms/", tbl_outs = [ ( ["-gen-rewriters"], - "lib/Dialect/mhlo/transforms/generated_legalize_to_standard.inc", + "mhlo/transforms/legalize_to_standard/generated_legalize_to_standard.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "lib/Dialect/mhlo/transforms/legalize_to_standard_patterns.td", + td_file = "mhlo/transforms/legalize_to_standard/legalize_to_standard_patterns.td", deps = [ ":hlo_ops_td_files", "@llvm-project//mlir:ArithOpsTdFiles", @@ -900,15 +922,15 @@ gentbl_cc_library( gentbl_cc_library( name = "lower_complex_inc_gen", compatible_with = get_compatible_with_cloud(), - strip_include_prefix = "lib/Dialect/mhlo/transforms/", + strip_include_prefix = "mhlo/transforms/", tbl_outs = [ ( ["-gen-rewriters"], - "lib/Dialect/mhlo/transforms/generated_lower_complex.inc", + "mhlo/transforms/lower_complex/generated_lower_complex.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "lib/Dialect/mhlo/transforms/lower_complex_patterns.td", + td_file = "mhlo/transforms/lower_complex/lower_complex_patterns.td", deps = [ ":hlo_ops_td_files", "@llvm-project//mlir:FuncTdFiles", @@ -917,8 +939,9 @@ gentbl_cc_library( cc_library( name = "unfuse_batch_norm", - srcs = ["lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc"], - hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"], + srcs = ["mhlo/transforms/unfuse_batch_norm/unfuse_batch_norm.cc"], + hdrs = ["mhlo/transforms/rewriters.h"], + strip_include_prefix = ".", deps = [ ":mlir_hlo", "@llvm-project//llvm:Support", @@ -933,8 +956,9 @@ cc_library( cc_library( name = "chlo_legalize_to_hlo", - srcs = ["lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc"], - hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"], + srcs = ["mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc"], + hdrs = ["mhlo/transforms/rewriters.h"], + strip_include_prefix = ".", deps = [ ":chlo_legalize_to_hlo_inc_gen", ":map_chlo_to_hlo_op", @@ -955,22 +979,23 @@ cc_library( gentbl_cc_library( name = "chlo_legalize_to_hlo_inc_gen", compatible_with = get_compatible_with_cloud(), - strip_include_prefix = "lib/Dialect/mhlo/transforms/", + strip_include_prefix = "mhlo/transforms", tbl_outs = [ ( ["-gen-rewriters"], - "lib/Dialect/mhlo/transforms/generated_chlo_legalize_to_hlo.inc", + "mhlo/transforms/chlo_legalize_to_hlo/generated_chlo_legalize_to_hlo.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td", + td_file = "mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_patterns.td", deps = [":hlo_ops_td_files"], ) cc_library( name = "hlo_legalize_to_stablehlo", - srcs = 
["lib/Dialect/mhlo/transforms/hlo_legalize_to_stablehlo.cc"], - hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"], + srcs = ["mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc"], + hdrs = ["mhlo/transforms/rewriters.h"], + strip_include_prefix = ".", deps = [ ":map_stablehlo_to_hlo_op", ":mlir_hlo", @@ -985,12 +1010,14 @@ cc_library( cc_library( name = "stablehlo_legalize_to_hlo", - srcs = ["lib/Dialect/mhlo/transforms/stablehlo_legalize_to_hlo.cc"], - hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"], + srcs = ["mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc"], + hdrs = ["mhlo/transforms/rewriters.h"], + strip_include_prefix = ".", deps = [ ":map_stablehlo_to_hlo_op", ":mlir_hlo", "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", @@ -1004,18 +1031,19 @@ cc_library( srcs = [ # These are not exposed as headers in the dependent targets, and # shouldn't be. Ideally, this entire target should be removed. - "include/mlir-hlo/Dialect/gml_st/transforms/passes.h.inc", - "include/mlir-hlo/Dialect/lhlo/transforms/lmhlo_passes.h.inc", - "include/mlir-hlo/Dialect/thlo/transforms/thlo_passes.h.inc", - "include/mlir-hlo/Transforms/passes.h.inc", + "lhlo/transforms/lmhlo_passes.h.inc", + "gml_st/transforms/passes.h.inc", + "thlo/transforms/thlo_passes.h.inc", + "transforms/passes.h.inc", ], hdrs = [ - "include/mlir-hlo/Dialect/gml_st/transforms/passes.h", - "include/mlir-hlo/Dialect/lhlo/transforms/passes.h", - "include/mlir-hlo/Dialect/mhlo/transforms/passes.h", - "include/mlir-hlo/Dialect/thlo/transforms/passes.h", - "include/mlir-hlo/Transforms/passes.h", + "gml_st/transforms/passes.h", + "lhlo/transforms/passes.h", + "mhlo/transforms/passes.h", + "thlo/transforms/passes.h", + "transforms/passes.h", ], + strip_include_prefix = ".", deps = [ ":chlo_legalize_to_hlo", ":gml_st_passes", @@ -1033,6 +1061,7 @@ cc_library( ":userange_analysis", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:Pass", ], ) @@ -1040,30 +1069,29 @@ cc_library( cc_library( name = "transforms_passes", srcs = [ - "include/mlir-hlo/Transforms/passes.h.inc", - "include/mlir-hlo/Transforms/rewriters.h", - "lib/Analysis/test_shape_component_analysis.cc", - "lib/Analysis/test_userange_analysis.cc", - "lib/Transforms/alloc_to_arg_pass.cc", - "lib/Transforms/buffer_packing.cc", - "lib/Transforms/buffer_reuse.cc", - "lib/Transforms/bufferize.cc", - "lib/Transforms/bufferize_pass.cc", - "lib/Transforms/collapse_parallel_loops_to_1d_pass.cc", - "lib/Transforms/copy_removal.cc", - "lib/Transforms/detensorize_scf_ops.cc", - "lib/Transforms/generic_host_to_llvm.cc", - "lib/Transforms/inline_fusion_pass.cc", - "lib/Transforms/lower_index_cast_pass.cc", - "lib/Transforms/propagate_static_shapes_to_kernel.cc", - "lib/Transforms/scalarization.cc", - "lib/Transforms/shape_simplification.cc", - "lib/Transforms/symbolic_shape_optimization.cc", - "lib/Transforms/tile_loops_pass.cc", - "lib/Transforms/unbufferize_pass.cc", - "lib/Transforms/unroll_loops.cc", - ], - hdrs = ["include/mlir-hlo/Transforms/passes.h"], + "analysis/test_userange_analysis.cc", + "mhlo/analysis/test_shape_component_analysis.cc", + "transforms/alloc_to_arg_pass.cc", + "transforms/buffer_packing.cc", + "transforms/buffer_reuse.cc", + "transforms/bufferize.cc", + "transforms/bufferize_pass.cc", + 
"transforms/collapse_parallel_loops_to_1d_pass.cc", + "transforms/copy_removal.cc", + "transforms/detensorize_scf_ops.cc", + "transforms/generic_host_to_llvm.cc", + "transforms/lower_index_cast_pass.cc", + "transforms/propagate_static_shapes_to_kernel.cc", + "transforms/tile_loops_pass.cc", + "transforms/unbufferize_pass.cc", + "transforms/unroll_loops.cc", + ], + hdrs = [ + "transforms/passes.h", + "transforms/passes.h.inc", + "transforms/rewriters.h", + ], + strip_include_prefix = ".", deps = [ ":gml_st", ":gml_st_bufferizable_op_interface", @@ -1079,6 +1107,7 @@ cc_library( ":userange_analysis", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:AffineToStandard", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ArithToLLVM", @@ -1092,6 +1121,7 @@ cc_library( "@llvm-project//mlir:ControlFlowDialect", "@llvm-project//mlir:ControlFlowToLLVM", "@llvm-project//mlir:CopyOpInterface", + "@llvm-project//mlir:DialectUtils", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncToLLVM", "@llvm-project//mlir:FuncTransforms", @@ -1132,12 +1162,14 @@ cc_library( cc_library( name = "transforms_gpu_passes", srcs = [ - "include/mlir-hlo/Transforms/gpu_passes.h.inc", - "lib/Transforms/gpu_fusion_rewrite.cc", - "lib/Transforms/gpu_kernel_lowering_passes.cc", - "lib/Transforms/hlo_to_gpu_pipeline.cc", - ], - hdrs = ["include/mlir-hlo/Transforms/gpu_passes.h"], + "transforms/gpu_fusion_rewrite.cc", + "transforms/gpu_kernel_lowering_passes.cc", + "transforms/gpu_passes.h.inc", + "transforms/hlo_to_gpu_pipeline.cc", + "transforms/hlo_to_triton_pipeline.cc", + ], + hdrs = ["transforms/gpu_passes.h"], + strip_include_prefix = ".", deps = [ ":gml_st_passes", ":gpu_transforms_passes_inc_gen", @@ -1148,6 +1180,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:AffineToStandard", + "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ArithToLLVM", "@llvm-project//mlir:ArithTransforms", "@llvm-project//mlir:BufferizationDialect", @@ -1167,6 +1200,7 @@ cc_library( "@llvm-project//mlir:MathToLLVM", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:MemRefToLLVM", + "@llvm-project//mlir:MemRefTransforms", "@llvm-project//mlir:NVVMDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:ROCDLDialect", @@ -1174,39 +1208,43 @@ cc_library( "@llvm-project//mlir:SCFToControlFlow", "@llvm-project//mlir:SCFTransforms", "@llvm-project//mlir:ShapeToStandard", + "@llvm-project//mlir:TensorInferTypeOpInterfaceImpl", "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:VectorDialect", "@llvm-project//mlir:VectorToLLVM", + "@llvm-project//mlir:VectorTransforms", ], ) gentbl_cc_library( name = "gml_st_test_passes_inc_gen", compatible_with = get_compatible_with_cloud(), - strip_include_prefix = "include", + strip_include_prefix = ".", tbl_outs = [ ( [ "-gen-pass-decls", "-name=GmlStTest", ], - "include/mlir-hlo/Dialect/gml_st/transforms/test_passes.h.inc", + "gml_st/transforms/test_passes.h.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "include/mlir-hlo/Dialect/gml_st/transforms/test_passes.td", + td_file = "gml_st/transforms/test_passes.td", deps = ["@llvm-project//mlir:PassBaseTdFiles"], ) cc_library( name = "gml_st_test_passes", srcs = [ - "include/mlir-hlo/Dialect/gml_st/transforms/test_passes.h.inc", - "lib/Dialect/gml_st/transforms/test_passes.cc", + "gml_st/transforms/test_passes.cc", + "gml_st/transforms/test_passes.h.inc", ], - hdrs = 
["include/mlir-hlo/Dialect/gml_st/transforms/test_passes.h"], - includes = ["include"], + hdrs = ["gml_st/transforms/test_passes.h"], + strip_include_prefix = ".", deps = [ ":gml_st_bufferizable_op_interface", + ":gml_st_passes", ":gml_st_test_passes_inc_gen", ":gml_st_transforms", "@llvm-project//mlir:AffineDialect", @@ -1223,44 +1261,44 @@ cc_library( gentbl_cc_library( name = "transforms_passes_inc_gen", compatible_with = get_compatible_with_cloud(), - strip_include_prefix = "include", + strip_include_prefix = ".", tbl_outs = [ ( [ "-gen-pass-decls", "-name=LMHLOTransforms", ], - "include/mlir-hlo/Transforms/passes.h.inc", + "transforms/passes.h.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "include/mlir-hlo/Transforms/passes.td", + td_file = "transforms/passes.td", deps = ["@llvm-project//mlir:PassBaseTdFiles"], ) gentbl_cc_library( name = "gpu_transforms_passes_inc_gen", compatible_with = get_compatible_with_cloud(), - strip_include_prefix = "include", + strip_include_prefix = ".", tbl_outs = [ ( [ "-gen-pass-decls", "-name=LMHLOGPUTransforms", ], - "include/mlir-hlo/Transforms/gpu_passes.h.inc", + "transforms/gpu_passes.h.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "include/mlir-hlo/Transforms/gpu_passes.td", + td_file = "transforms/gpu_passes.td", deps = ["@llvm-project//mlir:PassBaseTdFiles"], ) cc_library( name = "userange_analysis", - srcs = ["lib/Analysis/userange_analysis.cc"], - hdrs = ["include/mlir-hlo/Analysis/userange_analysis.h"], - includes = ["include"], + srcs = ["analysis/userange_analysis.cc"], + hdrs = ["analysis/userange_analysis.h"], + strip_include_prefix = ".", deps = [ "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", @@ -1272,9 +1310,9 @@ cc_library( cc_library( name = "shape_component_analysis", - srcs = ["lib/Analysis/shape_component_analysis.cc"], - hdrs = ["include/mlir-hlo/Analysis/shape_component_analysis.h"], - includes = ["include"], + srcs = ["mhlo/analysis/shape_component_analysis.cc"], + hdrs = ["mhlo/analysis/shape_component_analysis.h"], + strip_include_prefix = ".", deps = [ ":mlir_hlo", "@llvm-project//llvm:Support", @@ -1289,31 +1327,50 @@ cc_library( cc_library( name = "gml_st_passes", srcs = [ - "include/mlir-hlo/Dialect/gml_st/transforms/passes.h.inc", - "include/mlir-hlo/Dialect/gml_st/transforms/transforms.h", - "lib/Dialect/gml_st/transforms/collapse_materialize_ops.cc", - "lib/Dialect/gml_st/transforms/fusion.cc", - "lib/Dialect/gml_st/transforms/gml_st_to_gpu.cc", - "lib/Dialect/gml_st/transforms/gml_st_to_scf.cc", - "lib/Dialect/gml_st/transforms/linalg_utils.cc", - "lib/Dialect/gml_st/transforms/tiling.cc", - "lib/Dialect/gml_st/transforms/tiling_cwise.cc", - "lib/Dialect/gml_st/transforms/tiling_gpu_warp.cc", - "lib/Dialect/gml_st/transforms/tiling_softmax.cc", - "lib/Dialect/gml_st/transforms/transform_map_for_cpu.cc", - "lib/Dialect/gml_st/transforms/transform_matmul_for_cpu.cc", - "lib/Dialect/gml_st/transforms/transform_scatter_for_cpu.cc", - "lib/Dialect/gml_st/transforms/transform_transpose_for_cpu.cc", - "lib/Dialect/gml_st/transforms/vectorization.cc", + "gml_st/transforms/add_debug_info/add_debug_info.cc", + "gml_st/transforms/collapse_shape/collapse_shape.cc", + "gml_st/transforms/compose_extract_insert_slice/compose_extract_insert_slice.cc", + "gml_st/transforms/cpu_tiling/cpu_tiling_pipeline.cc", + "gml_st/transforms/cpu_tiling/transform_map_for_cpu.cc", + "gml_st/transforms/cpu_tiling/transform_matmul_for_cpu.cc", + 
"gml_st/transforms/cpu_tiling/transform_reduce_for_cpu.cc", + "gml_st/transforms/cpu_tiling/transform_reverse_for_cpu.cc", + "gml_st/transforms/cpu_tiling/transform_scatter_for_cpu.cc", + "gml_st/transforms/cpu_tiling/transform_sort_for_cpu.cc", + "gml_st/transforms/cpu_tiling/transform_transpose_for_cpu.cc", + "gml_st/transforms/fusion/fusion.cc", + "gml_st/transforms/gml_st_simtfy/gml_st_simtfy.cc", + "gml_st/transforms/gml_st_to_gpu/gml_st_to_gpu.cc", + "gml_st/transforms/gml_st_to_scf/gml_st_to_scf.cc", + "gml_st/transforms/gpu_tiling/greedy_fusion.cc", + "gml_st/transforms/gpu_tiling/tiling_cwise.cc", + "gml_st/transforms/gpu_tiling/tiling_gpu_warp.cc", + "gml_st/transforms/passes.h.inc", + "gml_st/transforms/peeling/peeling.cc", + "gml_st/transforms/rewrite_vector_ops/rewrite_vector_contract.cc", + "gml_st/transforms/rewrite_vector_ops/rewrite_vector_multi_reduction.cc", + "gml_st/transforms/rewrite_vector_ops/rewrite_vector_transpose.cc", + "gml_st/transforms/scalarization/scalarization.cc", + "gml_st/transforms/tiling/tiling.cc", + "gml_st/transforms/tiling_softmax/tiling_softmax.cc", + "gml_st/transforms/transforms.h", + "gml_st/transforms/triton_tiling/transform_matmul_for_triton.cc", + "gml_st/transforms/vectorization/vectorization.cc", + "gml_st/transforms/vectorization/vectorize_copy.cc", + "gml_st/transforms/vectorization/vectorize_for_cpu.cc", + "gml_st/transforms/vectorization/vectorize_for_gpu.cc", + "gml_st/utils/linalg_utils.cc", ], hdrs = [ - "include/mlir-hlo/Dialect/gml_st/transforms/fusion.h", - "include/mlir-hlo/Dialect/gml_st/transforms/linalg_utils.h", - "include/mlir-hlo/Dialect/gml_st/transforms/passes.h", - "include/mlir-hlo/Dialect/gml_st/transforms/rewriters.h", - "include/mlir-hlo/Dialect/gml_st/transforms/tiling.h", - "include/mlir-hlo/Dialect/gml_st/transforms/vector_utils.h", - ], + "gml_st/transforms/fusion/fusion.h", + "gml_st/transforms/passes.h", + "gml_st/transforms/peeling/peeling.h", + "gml_st/transforms/tiling/tiling.h", + "gml_st/transforms/vectorization/vectorization.h", + "gml_st/utils/linalg_utils.h", + "gml_st/utils/vector_utils.h", + ], + strip_include_prefix = ".", deps = [ ":gml_st", ":gml_st_passes_inc_gen", @@ -1321,11 +1378,11 @@ cc_library( ":lhlo", ":mlir_hlo", ":thlo", - ":tiling_interface", - ":tiling_interface_impl", ":type_conversion", + "@llvm-project//llvm:BinaryFormat", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:AffineUtils", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ArithUtils", "@llvm-project//mlir:BufferizationDialect", @@ -1342,43 +1399,50 @@ cc_library( "@llvm-project//mlir:LinalgDialect", "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:LinalgUtils", + "@llvm-project//mlir:LoopLikeInterface", "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:SCFTransforms", + "@llvm-project//mlir:SCFUtils", "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:ShapeTransforms", "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TensorInferTypeOpInterfaceImpl", + "@llvm-project//mlir:TensorTilingInterfaceImpl", "@llvm-project//mlir:TensorTransforms", "@llvm-project//mlir:TensorUtils", + "@llvm-project//mlir:TilingInterface", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@llvm-project//mlir:VectorDialect", "@llvm-project//mlir:VectorTransforms", + "@llvm-project//mlir:X86VectorTransforms", 
"@stablehlo//:chlo_ops", ], ) CAPI_HEADERS = [ - "include/mlir-hlo-c/Attributes.h", - "include/mlir-hlo-c/Dialects.h", - "include/mlir-hlo-c/Passes.h", - "include/mlir-hlo-c/Types.h", + "bindings/c/Attributes.h", + "bindings/c/Dialects.h", + "bindings/c/Passes.h", + "bindings/c/Types.h", ] CAPI_SOURCES = [ - "lib/CAPI/Attributes.cc", - "lib/CAPI/Dialects.cc", - "lib/CAPI/Passes.cc", - "lib/CAPI/Types.cc", + "bindings/c/Attributes.cc", + "bindings/c/Dialects.cc", + "bindings/c/Passes.cc", + "bindings/c/Types.cc", ] cc_library( name = "CAPI", srcs = CAPI_SOURCES, hdrs = CAPI_HEADERS, + strip_include_prefix = ".", deps = [ ":all_passes", ":mlir_hlo", @@ -1390,7 +1454,7 @@ cc_library( cc_library( name = "CAPIHeaders", hdrs = CAPI_HEADERS, - includes = ["include"], + strip_include_prefix = ".", deps = ["@llvm-project//mlir:CAPIIRHeaders"], ) @@ -1399,6 +1463,7 @@ cc_library( name = "CAPIObjects", srcs = CAPI_SOURCES, hdrs = CAPI_HEADERS, + strip_include_prefix = ".", deps = [ ":all_passes", ":mlir_hlo", @@ -1435,7 +1500,6 @@ cc_binary( td_library( name = "MhloOpsPyTdFiles", srcs = ["@llvm-project//mlir:include/mlir/Bindings/Python/Attributes.td"], - includes = ["include"], deps = [ ":hlo_ops_td_files", "@llvm-project//mlir:OpBaseTdFiles", @@ -1450,27 +1514,27 @@ gentbl_filegroup( "-gen-python-op-bindings", "-bind-dialect=mhlo", ], - "python/mlir/dialects/_mhlo_ops_gen.py", + "bindings/python/mlir/dialects/_mhlo_ops_gen.py", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "python/mlir/dialects/MhloOps.td", + td_file = "bindings/python/mlir/dialects/MhloOps.td", deps = [":MhloOpsPyTdFiles"], ) filegroup( name = "MhloOpsPyFiles", srcs = [ - "python/mlir/dialects/mhlo.py", + "bindings/python/mlir/dialects/mhlo.py", ":MhloOpsPyGen", ], ) td_library( name = "gml_st_ops_td_files", - srcs = glob(["include/mlir-hlo/Dialect/gml_st/IR/*.td"]), + srcs = glob(["gml_st/IR/*.td"]), compatible_with = get_compatible_with_cloud(), - includes = ["include"], + includes = ["."], deps = [ "@llvm-project//mlir:ControlFlowInterfacesTdFiles", "@llvm-project//mlir:DialectUtilsTdFiles", @@ -1485,54 +1549,51 @@ td_library( gentbl_cc_library( name = "gml_st_ops_inc_gen", compatible_with = get_compatible_with_cloud(), - strip_include_prefix = "include", + strip_include_prefix = ".", tbl_outs = [ ( ["-gen-op-decls"], - "include/mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h.inc", + "gml_st/IR/gml_st_ops.h.inc", ), ( ["-gen-op-defs"], - "include/mlir-hlo/Dialect/gml_st/IR/gml_st_ops.cc.inc", + "gml_st/IR/gml_st_ops.cc.inc", ), ( ["-gen-dialect-decls"], - "include/mlir-hlo/Dialect/gml_st/IR/gml_st_dialect.h.inc", + "gml_st/IR/gml_st_dialect.h.inc", ), ( ["-gen-dialect-defs"], - "include/mlir-hlo/Dialect/gml_st/IR/gml_st_dialect.cc.inc", + "gml_st/IR/gml_st_dialect.cc.inc", ), ( ["-gen-typedef-decls"], - "include/mlir-hlo/Dialect/gml_st/IR/gml_st_types.h.inc", + "gml_st/IR/gml_st_types.h.inc", ), ( ["-gen-typedef-defs"], - "include/mlir-hlo/Dialect/gml_st/IR/gml_st_types.cc.inc", + "gml_st/IR/gml_st_types.cc.inc", ), ( ["-gen-attrdef-decls"], - "include/mlir-hlo/Dialect/gml_st/IR/gml_st_attrs.h.inc", + "gml_st/IR/gml_st_attrs.h.inc", ), ( ["-gen-attrdef-defs"], - "include/mlir-hlo/Dialect/gml_st/IR/gml_st_attrs.cc.inc", + "gml_st/IR/gml_st_attrs.cc.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "include/mlir-hlo/Dialect/gml_st/IR/gml_st_ops.td", - td_srcs = [ - "include/mlir-hlo/Dialect/gml_st/IR/gml_st_legacy_ops.td", - ], + td_file = "gml_st/IR/gml_st_ops.td", deps = [":gml_st_ops_td_files"], ) 
cc_library( name = "gml_st", - srcs = ["lib/Dialect/gml_st/IR/gml_st_ops.cc"], - hdrs = ["include/mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h"], - includes = ["include"], + srcs = ["gml_st/IR/gml_st_ops.cc"], + hdrs = ["gml_st/IR/gml_st_ops.h"], + strip_include_prefix = ".", deps = [ ":gml_st_ops_inc_gen", "@llvm-project//llvm:Support", @@ -1547,82 +1608,16 @@ cc_library( "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TensorUtils", - "@llvm-project//mlir:ViewLikeInterface", - ], -) - -td_library( - name = "tiling_interface_td_files", - srcs = ["include/mlir-hlo/Dialect/gml_st/transforms/tiling_interface.td"], - includes = ["include"], - deps = ["@llvm-project//mlir:OpBaseTdFiles"], -) - -gentbl_cc_library( - name = "tiling_interface_inc_gen", - compatible_with = get_compatible_with_cloud(), - strip_include_prefix = "include", - tbl_outs = [ - ( - ["-gen-op-interface-decls"], - "include/mlir-hlo/Dialect/gml_st/transforms/tiling_interface.h.inc", - ), - ( - ["-gen-op-interface-defs"], - "include/mlir-hlo/Dialect/gml_st/transforms/tiling_interface.cc.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "include/mlir-hlo/Dialect/gml_st/transforms/tiling_interface.td", - deps = ["@llvm-project//mlir:OpBaseTdFiles"], -) - -cc_library( - name = "tiling_interface", - srcs = [ - "include/mlir-hlo/Dialect/gml_st/transforms/tiling_interface.cc.inc", - "include/mlir-hlo/Dialect/gml_st/transforms/tiling_interface.h.inc", - "lib/Dialect/gml_st/transforms/tiling_interface.cc", - ], - hdrs = ["include/mlir-hlo/Dialect/gml_st/transforms/tiling_interface.h"], - includes = ["include"], - deps = [ - ":tiling_interface_inc_gen", - "@llvm-project//mlir:DialectUtils", - "@llvm-project//mlir:IR", + "@llvm-project//mlir:VectorDialect", "@llvm-project//mlir:ViewLikeInterface", ], ) -cc_library( - name = "tiling_interface_impl", - srcs = ["lib/Dialect/gml_st/transforms/tiling_interface_impl.cc"], - hdrs = ["include/mlir-hlo/Dialect/gml_st/transforms/tiling_interface_impl.h"], - includes = ["include"], - deps = [ - ":gml_st", - ":thlo", - ":tiling_interface", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:AffineDialect", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:DestinationStyleOpInterface", - "@llvm-project//mlir:DialectUtils", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:LinalgDialect", - "@llvm-project//mlir:LinalgTransforms", - "@llvm-project//mlir:LinalgUtils", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TensorUtils", - ], -) - cc_library( name = "gml_st_bufferizable_op_interface", - srcs = ["lib/Dialect/gml_st/transforms/bufferizable_op_interface_impl.cc"], - hdrs = ["include/mlir-hlo/Dialect/gml_st/transforms/bufferizable_op_interface_impl.h"], - includes = ["include"], + srcs = ["gml_st/interfaces/bufferizable_op_interface_impl.cc"], + hdrs = ["gml_st/interfaces/bufferizable_op_interface_impl.h"], + strip_include_prefix = ".", deps = [ ":gml_st", "@llvm-project//mlir:ArithDialect", @@ -1630,6 +1625,7 @@ cc_library( "@llvm-project//mlir:BufferizationDialect", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:ViewLikeInterface", ], ) @@ -1637,43 +1633,48 @@ cc_library( gentbl_cc_library( name = "gml_st_passes_inc_gen", compatible_with = get_compatible_with_cloud(), - strip_include_prefix = "include", + strip_include_prefix = ".", tbl_outs = [ ( [ "-gen-pass-decls", "-name=GmlSt", ], - 
"include/mlir-hlo/Dialect/gml_st/transforms/passes.h.inc", + "gml_st/transforms/passes.h.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "include/mlir-hlo/Dialect/gml_st/transforms/passes.td", + td_file = "gml_st/transforms/passes.td", deps = ["@llvm-project//mlir:PassBaseTdFiles"], ) cc_library( name = "gml_st_transforms", - srcs = ["lib/Dialect/gml_st/transforms/transforms.cc"], - hdrs = ["include/mlir-hlo/Dialect/gml_st/transforms/transforms.h"], - includes = ["include"], + srcs = ["gml_st/transforms/transforms.cc"], + hdrs = ["gml_st/transforms/transforms.h"], + strip_include_prefix = ".", deps = [ ":gml_st", + "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:ArithUtils", "@llvm-project//mlir:DialectUtils", + "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgTransforms", + "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFUtils", "@llvm-project//mlir:TensorUtils", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:VectorDialect", ], ) td_library( name = "thlo_ops_td_files", - srcs = glob(["include/mlir-hlo/Dialect/thlo/IR/*.td"]), + srcs = glob(["thlo/IR/*.td"]), compatible_with = get_compatible_with_cloud(), - includes = ["include"], + includes = ["."], deps = [ "@llvm-project//mlir:ControlFlowInterfacesTdFiles", "@llvm-project//mlir:OpBaseTdFiles", @@ -1684,43 +1685,42 @@ td_library( gentbl_cc_library( name = "thlo_ops_inc_gen", compatible_with = get_compatible_with_cloud(), - strip_include_prefix = "include", + strip_include_prefix = ".", tbl_outs = [ ( ["-gen-op-decls"], - "include/mlir-hlo/Dialect/thlo/IR/thlo_ops.h.inc", + "thlo/IR/thlo_ops.h.inc", ), ( ["-gen-op-defs"], - "include/mlir-hlo/Dialect/thlo/IR/thlo_ops.cc.inc", + "thlo/IR/thlo_ops.cc.inc", ), ( ["-gen-dialect-decls"], - "include/mlir-hlo/Dialect/thlo/IR/thlo_dialect.h.inc", + "thlo/IR/thlo_dialect.h.inc", ), ( ["-gen-dialect-defs"], - "include/mlir-hlo/Dialect/thlo/IR/thlo_dialect.cc.inc", + "thlo/IR/thlo_dialect.cc.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "include/mlir-hlo/Dialect/thlo/IR/thlo_ops.td", - td_srcs = ["include/mlir-hlo/Dialect/gml_st/transforms/tiling_interface.td"], + td_file = "thlo/IR/thlo_ops.td", deps = [ ":thlo_ops_td_files", - "@llvm-project//mlir:LinalgStructuredOpsTdFiles", + "@llvm-project//mlir:DestinationStyleOpInterfaceTdFiles", + "@llvm-project//mlir:TilingInterfaceTdFiles", ], ) cc_library( name = "thlo", - srcs = ["lib/Dialect/thlo/IR/thlo_ops.cc"], - hdrs = ["include/mlir-hlo/Dialect/thlo/IR/thlo_ops.h"], - includes = ["include"], + srcs = ["thlo/IR/thlo_ops.cc"], + hdrs = ["thlo/IR/thlo_ops.h"], + strip_include_prefix = ".", deps = [ ":gml_st", ":thlo_ops_inc_gen", - ":tiling_interface", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ArithUtils", @@ -1736,15 +1736,16 @@ cc_library( "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TensorUtils", + "@llvm-project//mlir:TilingInterface", "@llvm-project//mlir:ViewLikeInterface", ], ) cc_library( name = "thlo_bufferizable_op_interface", - srcs = ["lib/Dialect/thlo/transforms/bufferizable_op_interface_impl.cc"], - hdrs = ["include/mlir-hlo/Dialect/thlo/transforms/bufferizable_op_interface_impl.h"], - includes = ["include"], + srcs = ["thlo/interfaces/bufferizable_op_interface_impl.cc"], + hdrs = ["thlo/interfaces/bufferizable_op_interface_impl.h"], + strip_include_prefix = ".", deps = [ ":thlo", 
"@llvm-project//mlir:BufferizationDialect", @@ -1755,36 +1756,38 @@ cc_library( gentbl_cc_library( name = "thlo_passes_inc_gen", compatible_with = get_compatible_with_cloud(), - strip_include_prefix = "include", + strip_include_prefix = ".", tbl_outs = [ ( [ "-gen-pass-decls", "-name=AllThlo", ], - "include/mlir-hlo/Dialect/thlo/transforms/thlo_passes.h.inc", + "thlo/transforms/thlo_passes.h.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "include/mlir-hlo/Dialect/thlo/transforms/thlo_passes.td", + td_file = "thlo/transforms/thlo_passes.td", deps = ["@llvm-project//mlir:PassBaseTdFiles"], ) cc_library( name = "thlo_passes", srcs = [ - "include/mlir-hlo/Dialect/thlo/transforms/thlo_passes.h.inc", - "lib/Dialect/thlo/transforms/legalize_sort.cc", + "thlo/transforms/legalize_sort/legalize_sort.cc", + "thlo/transforms/thlo_passes.h.inc", ], hdrs = [ - "include/mlir-hlo/Dialect/thlo/transforms/passes.h", + "thlo/transforms/passes.h", ], + strip_include_prefix = ".", deps = [ ":thlo", ":thlo_passes_inc_gen", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ArithUtils", "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", @@ -1800,3 +1803,106 @@ cc_binary( linkstatic = False, deps = ["@llvm-project//mlir:mlir_c_runner_utils"], ) + +cc_library( + name = "mlir_interpreter_dialects", + srcs = glob( + [ + "tools/mlir_interpreter/dialects/*.cc", + ], + exclude = ["tools/mlir_interpreter/dialects/util.cc"], + ), + strip_include_prefix = ".", + deps = [ + ":gml_st", + ":mlir_hlo", + ":mlir_interpreter_dialect_utils", + ":mlir_interpreter_framework", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:BufferizationDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LinalgDialect", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:VectorDialect", + "@llvm-project//mlir:ViewLikeInterface", + ], + alwayslink = 1, +) + +cc_library( + name = "mlir_interpreter_dialect_utils", + srcs = [ + "tools/mlir_interpreter/dialects/util.cc", + ], + hdrs = [ + "tools/mlir_interpreter/dialects/comparators.h", + "tools/mlir_interpreter/dialects/cwise_math.h", + "tools/mlir_interpreter/dialects/util.h", + ], + strip_include_prefix = ".", + deps = [ + ":mlir_interpreter_framework", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:ViewLikeInterface", + ], +) + +cc_library( + name = "mlir_interpreter_framework", + srcs = [ + "tools/mlir_interpreter/framework/interpreter.cc", + "tools/mlir_interpreter/framework/interpreter_value.cc", + "tools/mlir_interpreter/framework/registration.cc", + "tools/mlir_interpreter/framework/tensor_or_memref.cc", + ], + hdrs = [ + "tools/mlir_interpreter/framework/interpreter.h", + "tools/mlir_interpreter/framework/interpreter_value.h", + "tools/mlir_interpreter/framework/interpreter_value_util.h", + "tools/mlir_interpreter/framework/registration.h", + "tools/mlir_interpreter/framework/tensor_or_memref.h", + ], + strip_include_prefix = ".", + deps = [ + "//tensorflow/tsl/platform:logging", + 
"@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:DialectUtils", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +cc_binary( + name = "mlir-interpreter-runner", + srcs = ["tools/mlir_interpreter/mlir-interpreter-runner.cc"], + deps = [ + ":gml_st", + ":hlo_dialect_registration", + ":lhlo", + ":lhlo_gpu", + ":mhlo_passes", + ":mlir_interpreter_dialects", + ":mlir_interpreter_framework", + ":thlo", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MlirReduceLib", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + ], +) diff --git a/tensorflow/compiler/xla/mlir_hlo/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/CMakeLists.txt index d9b372a5b53..f4e5bf8c2e3 100644 --- a/tensorflow/compiler/xla/mlir_hlo/CMakeLists.txt +++ b/tensorflow/compiler/xla/mlir_hlo/CMakeLists.txt @@ -125,6 +125,7 @@ include(HandleLLVMOptions) include_directories(${LLVM_INCLUDE_DIRS}) include_directories(${MLIR_INCLUDE_DIRS}) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) include_directories(${CMAKE_CURRENT_BINARY_DIR}/) link_directories(${LLVM_BUILD_LIBRARY_DIR}) @@ -153,22 +154,26 @@ endif() set(MLIR_HLO_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}) set(MLIR_HLO_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}) -set(MLIR_HLO_MAIN_INCLUDE_DIR ${MLIR_HLO_SOURCE_DIR}/include ) +set(MLIR_HLO_MAIN_INCLUDE_DIR ${MLIR_HLO_SOURCE_DIR}/include) set(MLIR_HLO_GEN_INCLUDE_DIR ${MLIR_HLO_BINARY_DIR}/include) set(MLIR_HLO_TOOLS_DIR ${MLIR_HLO_BINARY_DIR}/bin) set(MLIR_HLO_LIB_DIR ${MLIR_HLO_BINARY_DIR}/lib) add_custom_target(check-mlir-hlo) -add_subdirectory(include/mlir-hlo) -add_subdirectory(lib) +add_subdirectory(analysis) +add_subdirectory(bindings) +add_subdirectory(gml_st) +add_subdirectory(lhlo) +add_subdirectory(lhlo_gpu) +add_subdirectory(mhlo) add_subdirectory(stablehlo) +add_subdirectory(tests) +add_subdirectory(thlo) add_subdirectory(tools) add_subdirectory(tosa) -add_subdirectory(tests) +add_subdirectory(transforms) +add_subdirectory(utils) -if(MHLO_ENABLE_BINDINGS_PYTHON) - add_subdirectory(python) -endif() add_subdirectory(cmake/modules) diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Analysis/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/analysis/CMakeLists.txt similarity index 84% rename from tensorflow/compiler/xla/mlir_hlo/lib/Analysis/CMakeLists.txt rename to tensorflow/compiler/xla/mlir_hlo/analysis/CMakeLists.txt index a68ef2e84da..88a0fec1494 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Analysis/CMakeLists.txt +++ b/tensorflow/compiler/xla/mlir_hlo/analysis/CMakeLists.txt @@ -1,5 +1,4 @@ add_mlir_library(MLIRHLOAnalysis - shape_component_analysis.cc userange_analysis.cc DEPENDS @@ -11,7 +10,6 @@ add_mlir_library(MLIRHLOAnalysis ) add_mlir_library(MLIRHLOTestAnalysis - test_shape_component_analysis.cc test_userange_analysis.cc DEPENDS @@ -27,4 +25,4 @@ add_mlir_library(MLIRHLOTestAnalysis MLIRAnalysis MLIRPass MLIRTransforms - ) +) diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Analysis/test_userange_analysis.cc b/tensorflow/compiler/xla/mlir_hlo/analysis/test_userange_analysis.cc similarity index 91% rename from tensorflow/compiler/xla/mlir_hlo/lib/Analysis/test_userange_analysis.cc rename to tensorflow/compiler/xla/mlir_hlo/analysis/test_userange_analysis.cc index 97aca6e3b67..f29916297ed 100644 --- 
a/tensorflow/compiler/xla/mlir_hlo/lib/Analysis/test_userange_analysis.cc +++ b/tensorflow/compiler/xla/mlir_hlo/analysis/test_userange_analysis.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mlir-hlo/Analysis/userange_analysis.h" -#include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h" +#include "analysis/userange_analysis.h" +#include "lhlo/IR/lhlo_ops.h" #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" #include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h" #include "mlir/Pass/Pass.h" @@ -22,7 +22,7 @@ limitations under the License. namespace mlir { #define GEN_PASS_DEF_TESTUSERANGE -#include "mlir-hlo/Transforms/passes.h.inc" +#include "transforms/passes.h.inc" namespace { diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Analysis/userange_analysis.cc b/tensorflow/compiler/xla/mlir_hlo/analysis/userange_analysis.cc similarity index 99% rename from tensorflow/compiler/xla/mlir_hlo/lib/Analysis/userange_analysis.cc rename to tensorflow/compiler/xla/mlir_hlo/analysis/userange_analysis.cc index 51d90f33dbe..05731828b1b 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Analysis/userange_analysis.cc +++ b/tensorflow/compiler/xla/mlir_hlo/analysis/userange_analysis.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mlir-hlo/Analysis/userange_analysis.h" +#include "analysis/userange_analysis.h" #include #include diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Analysis/userange_analysis.h b/tensorflow/compiler/xla/mlir_hlo/analysis/userange_analysis.h similarity index 97% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Analysis/userange_analysis.h rename to tensorflow/compiler/xla/mlir_hlo/analysis/userange_analysis.h index 219f556b65b..ba6131ae86e 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Analysis/userange_analysis.h +++ b/tensorflow/compiler/xla/mlir_hlo/analysis/userange_analysis.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef MLIR_HLO_ANALYSIS_USERANGE_ANALYSIS_H #define MLIR_HLO_ANALYSIS_USERANGE_ANALYSIS_H +#include #include #include "mlir/Analysis/Liveness.h" @@ -110,7 +111,7 @@ class UserangeAnalysis { /// empty Optional if the value has no uses. 
llvm::Optional getFirstUseIndex(Value value) const { auto &intervals = useIntervalMap.find(value)->second; - if (intervals.empty()) return llvm::None; + if (intervals.empty()) return std::nullopt; return intervals.begin()->start; } @@ -118,7 +119,7 @@ class UserangeAnalysis { llvm::Optional getUserangeInterval( Value value) const { auto intervals = useIntervalMap.find(value); - if (intervals == useIntervalMap.end()) return llvm::None; + if (intervals == useIntervalMap.end()) return std::nullopt; return &intervals->second; } @@ -128,7 +129,7 @@ class UserangeAnalysis { Value value) const { auto usePosition = usePositionMap.find(value); if (usePosition == usePositionMap.end() || usePosition->second.empty()) - return llvm::None; + return std::nullopt; return &usePosition->second; } diff --git a/tensorflow/compiler/xla/mlir_hlo/bindings/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/bindings/CMakeLists.txt new file mode 100644 index 00000000000..ec1a1c36e99 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/bindings/CMakeLists.txt @@ -0,0 +1,5 @@ +add_subdirectory(c) + +if(MHLO_ENABLE_BINDINGS_PYTHON) + add_subdirectory(python) +endif() diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/CAPI/Attributes.cc b/tensorflow/compiler/xla/mlir_hlo/bindings/c/Attributes.cc similarity index 79% rename from tensorflow/compiler/xla/mlir_hlo/lib/CAPI/Attributes.cc rename to tensorflow/compiler/xla/mlir_hlo/bindings/c/Attributes.cc index 101dcf9aa1d..7e5b16faea3 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/CAPI/Attributes.cc +++ b/tensorflow/compiler/xla/mlir_hlo/bindings/c/Attributes.cc @@ -10,9 +10,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mlir-hlo-c/Attributes.h" +#include "bindings/c/Attributes.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include + +#include "mhlo/IR/hlo_ops.h" #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Support.h" @@ -26,10 +28,9 @@ MlirAttribute mlirMhloScatterDimensionNumbersGet( const int64_t *insertedWindowDims, intptr_t nScatteredDimsToOperandDims, const int64_t *scatteredDimsToOperandDims, int64_t indexVectorDim) { return wrap(mlir::mhlo::ScatterDimensionNumbersAttr::get( - unwrap(ctx), llvm::makeArrayRef(updateWindowDims, nUpdateWindowDims), - llvm::makeArrayRef(insertedWindowDims, nInsertedWindowDims), - llvm::makeArrayRef(scatteredDimsToOperandDims, - nScatteredDimsToOperandDims), + unwrap(ctx), llvm::ArrayRef(updateWindowDims, nUpdateWindowDims), + llvm::ArrayRef(insertedWindowDims, nInsertedWindowDims), + llvm::ArrayRef(scatteredDimsToOperandDims, nScatteredDimsToOperandDims), indexVectorDim)); } @@ -98,9 +99,9 @@ MlirAttribute mlirMhloGatherDimensionNumbersGet( intptr_t nStartIndexMap, const int64_t *startIndexMap, int64_t indexVectorDim) { return wrap(mlir::mhlo::GatherDimensionNumbersAttr::get( - unwrap(ctx), llvm::makeArrayRef(offsetDims, nOffsetDims), - llvm::makeArrayRef(collapsedSliceDims, nCollapsedSliceDims), - llvm::makeArrayRef(startIndexMap, nStartIndexMap), indexVectorDim)); + unwrap(ctx), llvm::ArrayRef(offsetDims, nOffsetDims), + llvm::ArrayRef(collapsedSliceDims, nCollapsedSliceDims), + llvm::ArrayRef(startIndexMap, nStartIndexMap), indexVectorDim)); } bool mlirMhloAttributeIsAGatherDimensionNumbers(MlirAttribute attr) { @@ -169,10 +170,10 @@ MlirAttribute mlirMhloDotDimensionNumbersGet( const int64_t *rhsContractingDimensions) { return wrap(mlir::mhlo::DotDimensionNumbersAttr::get( 
unwrap(ctx), - llvm::makeArrayRef(lhsBatchingDimensions, nLhsBatchingDimensions), - llvm::makeArrayRef(rhsBatchingDimensions, nRhsBatchingDimensions), - llvm::makeArrayRef(lhsContractingDimensions, nLhsContractingDimensions), - llvm::makeArrayRef(rhsContractingDimensions, nRhsContractingDimensions))); + llvm::ArrayRef(lhsBatchingDimensions, nLhsBatchingDimensions), + llvm::ArrayRef(rhsBatchingDimensions, nRhsBatchingDimensions), + llvm::ArrayRef(lhsContractingDimensions, nLhsContractingDimensions), + llvm::ArrayRef(rhsContractingDimensions, nRhsContractingDimensions))); } bool mlirMhloAttributeIsADotDimensionNumbers(MlirAttribute attr) { @@ -252,11 +253,11 @@ MlirAttribute mlirMhloConvDimensionNumbersGet( intptr_t nOutputSpatialDimensions, const int64_t *outputSpatialDimensions) { return wrap(mlir::mhlo::ConvDimensionNumbersAttr::get( unwrap(ctx), inputBatchDimension, inputFeatureDimension, - llvm::makeArrayRef(inputSpatialDimensions, nInputSpatialDimensions), + llvm::ArrayRef(inputSpatialDimensions, nInputSpatialDimensions), kernelInputFeatureDimension, kernelOutputFeatureDimension, - llvm::makeArrayRef(kernelSpatialDimensions, nKernelSpatialDimensions), + llvm::ArrayRef(kernelSpatialDimensions, nKernelSpatialDimensions), outputBatchDimension, outputFeatureDimension, - llvm::makeArrayRef(outputSpatialDimensions, nOutputSpatialDimensions))); + llvm::ArrayRef(outputSpatialDimensions, nOutputSpatialDimensions))); } bool mlirMhloAttributeIsAConvDimensionNumbers(MlirAttribute attr) { @@ -358,9 +359,8 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirMhloOutputOperandAliasGet( const int64_t *outputTupleIndices, int64_t operandIndex, intptr_t nOperandTupleIndices, const int64_t *operandTupleIndices) { return wrap(mlir::mhlo::OutputOperandAliasAttr::get( - unwrap(ctx), llvm::makeArrayRef(outputTupleIndices, nOutputTupleIndices), - operandIndex, - llvm::makeArrayRef(operandTupleIndices, nOperandTupleIndices))); + unwrap(ctx), llvm::ArrayRef(outputTupleIndices, nOutputTupleIndices), + operandIndex, llvm::ArrayRef(operandTupleIndices, nOperandTupleIndices))); } bool mlirMhloAttributeIsAOutputOperandAlias(MlirAttribute attr) { @@ -407,20 +407,19 @@ int64_t mlirMhloOutputOperandAliasGetOperandTupleIndicesElem(MlirAttribute attr, // ComparisonDirectionAttr. 
// MlirAttribute mlirMhloComparisonDirectionAttrGet(MlirContext ctx, - MlirStringRef direction) { - llvm::Optional compareDirection = - mlir::mhlo::symbolizeComparisonDirection(unwrap(direction)); - if (!compareDirection) - llvm_unreachable("Invalid comparison-direction specified."); + MlirStringRef value) { + std::optional comparisonDirection = + mlir::mhlo::symbolizeComparisonDirection(unwrap(value)); + if (!comparisonDirection) llvm_unreachable("Invalid value."); return wrap(mlir::mhlo::ComparisonDirectionAttr::get( - unwrap(ctx), compareDirection.value())); + unwrap(ctx), comparisonDirection.value())); } bool mlirMhloAttributeIsAComparisonDirectionAttr(MlirAttribute attr) { return unwrap(attr).isa(); } -MlirStringRef mlirMhloComparisonDirectionAttrGetDirection(MlirAttribute attr) { +MlirStringRef mlirMhloComparisonDirectionAttrGetValue(MlirAttribute attr) { return wrap(mlir::mhlo::stringifyComparisonDirection( unwrap(attr).cast().getValue())); } @@ -430,19 +429,19 @@ MlirStringRef mlirMhloComparisonDirectionAttrGetDirection(MlirAttribute attr) { // MlirAttribute mlirMhloComparisonTypeAttrGet(MlirContext ctx, - MlirStringRef type) { - llvm::Optional compareType = - mlir::mhlo::symbolizeComparisonType(unwrap(type)); - if (!compareType) llvm_unreachable("Invalid comparison-type specified."); + MlirStringRef value) { + std::optional comparisonType = + mlir::mhlo::symbolizeComparisonType(unwrap(value)); + if (!comparisonType) llvm_unreachable("Invalid value."); return wrap( - mlir::mhlo::ComparisonTypeAttr::get(unwrap(ctx), compareType.value())); + mlir::mhlo::ComparisonTypeAttr::get(unwrap(ctx), comparisonType.value())); } bool mlirMhloAttributeIsAComparisonTypeAttr(MlirAttribute attr) { return unwrap(attr).isa(); } -MlirStringRef mlirMhloComparisonTypeAttrGetType(MlirAttribute attr) { +MlirStringRef mlirMhloComparisonTypeAttrGetValue(MlirAttribute attr) { return wrap(mlir::mhlo::stringifyComparisonType( unwrap(attr).cast().getValue())); } @@ -451,10 +450,10 @@ MlirStringRef mlirMhloComparisonTypeAttrGetType(MlirAttribute attr) { // DomainKindAttr. // -MlirAttribute mlirMhloDomainKindAttrGet(MlirContext ctx, MlirStringRef kind) { - llvm::Optional domainKind = - mlir::mhlo::symbolizeDomainKind(unwrap(kind)); - if (!domainKind) llvm_unreachable("Invalid domain kind specified."); +MlirAttribute mlirMhloDomainKindAttrGet(MlirContext ctx, MlirStringRef value) { + std::optional domainKind = + mlir::mhlo::symbolizeDomainKind(unwrap(value)); + if (!domainKind) llvm_unreachable("Invalid value."); return wrap(mlir::mhlo::DomainKindAttr::get(unwrap(ctx), domainKind.value())); } @@ -462,7 +461,7 @@ bool mlirMhloAttributeIsADomainKindAttr(MlirAttribute attr) { return unwrap(attr).isa(); } -MlirStringRef mlirMhloDomainKindAttrGetType(MlirAttribute attr) { +MlirStringRef mlirMhloDomainKindAttrGetValue(MlirAttribute attr) { return wrap(mlir::mhlo::stringifyDomainKind( unwrap(attr).cast().getValue())); } @@ -471,19 +470,18 @@ MlirStringRef mlirMhloDomainKindAttrGetType(MlirAttribute attr) { // PrecisionAttr. 
// -MlirAttribute mlirMhloPrecisionAttrGet(MlirContext ctx, MlirStringRef type) { - llvm::Optional precisionType = - mlir::mhlo::symbolizePrecision(unwrap(type)); - if (!precisionType) llvm_unreachable("Invalid precision-type specified."); - return wrap( - mlir::mhlo::PrecisionAttr::get(unwrap(ctx), precisionType.value())); +MlirAttribute mlirMhloPrecisionAttrGet(MlirContext ctx, MlirStringRef value) { + std::optional precision = + mlir::mhlo::symbolizePrecision(unwrap(value)); + if (!precision) llvm_unreachable("Invalid value specified."); + return wrap(mlir::mhlo::PrecisionAttr::get(unwrap(ctx), precision.value())); } bool mlirMhloAttributeIsAPrecisionAttr(MlirAttribute attr) { return unwrap(attr).isa(); } -MlirStringRef mlirMhloPrecisionAttrGetPrecision(MlirAttribute attr) { +MlirStringRef mlirMhloPrecisionAttrGetValue(MlirAttribute attr) { return wrap(mlir::mhlo::stringifyPrecision( unwrap(attr).cast().getValue())); } @@ -492,10 +490,10 @@ MlirStringRef mlirMhloPrecisionAttrGetPrecision(MlirAttribute attr) { // FftTypeAttr. // -MlirAttribute mlirMhloFftTypeAttrGet(MlirContext ctx, MlirStringRef type) { - llvm::Optional fftType = - mlir::mhlo::symbolizeFftType(unwrap(type)); - if (!fftType) llvm_unreachable("Invalid fft-type specified."); +MlirAttribute mlirMhloFftTypeAttrGet(MlirContext ctx, MlirStringRef value) { + std::optional fftType = + mlir::mhlo::symbolizeFftType(unwrap(value)); + if (!fftType) llvm_unreachable("Invalid value."); return wrap(mlir::mhlo::FftTypeAttr::get(unwrap(ctx), fftType.value())); } @@ -503,7 +501,7 @@ bool mlirMhloAttributeIsAFftTypeAttr(MlirAttribute attr) { return unwrap(attr).isa(); } -MlirStringRef mlirMhloFftTypeAttrGetFftType(MlirAttribute attr) { +MlirStringRef mlirMhloFftTypeAttrGetValue(MlirAttribute attr) { return wrap(mlir::mhlo::stringifyFftType( unwrap(attr).cast().getValue())); } @@ -513,10 +511,10 @@ MlirStringRef mlirMhloFftTypeAttrGetFftType(MlirAttribute attr) { // MlirAttribute mlirMhloDequantizeModeAttrGet(MlirContext ctx, - MlirStringRef mode) { - llvm::Optional dequantizeMode = - mlir::mhlo::symbolizeDequantizeMode(unwrap(mode)); - if (!dequantizeMode) llvm_unreachable("Invalid dequantize-mode specified."); + MlirStringRef value) { + std::optional dequantizeMode = + mlir::mhlo::symbolizeDequantizeMode(unwrap(value)); + if (!dequantizeMode) llvm_unreachable("Invalid value."); return wrap( mlir::mhlo::DequantizeModeAttr::get(unwrap(ctx), dequantizeMode.value())); } @@ -525,7 +523,7 @@ bool mlirMhloAttributeIsADequantizeModeAttr(MlirAttribute attr) { return unwrap(attr).isa(); } -MlirStringRef mlirMhloDequantizeModeAttrGetDequantizeMode(MlirAttribute attr) { +MlirStringRef mlirMhloDequantizeModeAttrGetValue(MlirAttribute attr) { return wrap(mlir::mhlo::stringifyDequantizeMode( unwrap(attr).cast().getValue())); } @@ -534,19 +532,18 @@ MlirStringRef mlirMhloDequantizeModeAttrGetDequantizeMode(MlirAttribute attr) { // TransposeAttr. 
// -MlirAttribute mlirMhloTransposeAttrGet(MlirContext ctx, MlirStringRef type) { - llvm::Optional transposeType = - mlir::mhlo::symbolizeTranspose(unwrap(type)); - if (!transposeType) llvm_unreachable("Invalid transpose-type specified."); - return wrap( - mlir::mhlo::TransposeAttr::get(unwrap(ctx), transposeType.value())); +MlirAttribute mlirMhloTransposeAttrGet(MlirContext ctx, MlirStringRef value) { + std::optional transpose = + mlir::mhlo::symbolizeTranspose(unwrap(value)); + if (!transpose) llvm_unreachable("Invalid value."); + return wrap(mlir::mhlo::TransposeAttr::get(unwrap(ctx), transpose.value())); } bool mlirMhloAttributeIsATransposeAttr(MlirAttribute attr) { return unwrap(attr).isa(); } -MlirStringRef mlirMhloTransposeAttrGetTranspose(MlirAttribute attr) { +MlirStringRef mlirMhloTransposeAttrGetValue(MlirAttribute attr) { return wrap(mlir::mhlo::stringifyTranspose( unwrap(attr).cast().getValue())); } @@ -555,10 +552,10 @@ MlirStringRef mlirMhloTransposeAttrGetTranspose(MlirAttribute attr) { // FusionKindAttr. // -MlirAttribute mlirMhloFusionKindAttrGet(MlirContext ctx, MlirStringRef kind) { - llvm::Optional fusionKind = - mlir::mhlo::symbolizeFusionKind(unwrap(kind)); - if (!fusionKind) llvm_unreachable("Invalid fusion-kind specified."); +MlirAttribute mlirMhloFusionKindAttrGet(MlirContext ctx, MlirStringRef value) { + std::optional fusionKind = + mlir::mhlo::symbolizeFusionKind(unwrap(value)); + if (!fusionKind) llvm_unreachable("Invalid value."); return wrap(mlir::mhlo::FusionKindAttr::get(unwrap(ctx), fusionKind.value())); } @@ -566,7 +563,7 @@ bool mlirMhloAttributeIsAFusionKindAttr(MlirAttribute attr) { return unwrap(attr).isa(); } -MlirStringRef mlirMhloFusionKindAttrGetFusionKind(MlirAttribute attr) { +MlirStringRef mlirMhloFusionKindAttrGetValue(MlirAttribute attr) { return wrap(mlir::mhlo::stringifyFusionKind( unwrap(attr).cast().getValue())); } @@ -576,10 +573,10 @@ MlirStringRef mlirMhloFusionKindAttrGetFusionKind(MlirAttribute attr) { // MlirAttribute mlirMhloRngDistributionAttrGet(MlirContext ctx, - MlirStringRef distribution) { - llvm::Optional rngDistribution = - mlir::mhlo::symbolizeRngDistribution(unwrap(distribution)); - if (!rngDistribution) llvm_unreachable("Invalid rng-distribution specified."); + MlirStringRef value) { + std::optional rngDistribution = + mlir::mhlo::symbolizeRngDistribution(unwrap(value)); + if (!rngDistribution) llvm_unreachable("Invalid value."); return wrap(mlir::mhlo::RngDistributionAttr::get(unwrap(ctx), rngDistribution.value())); } @@ -588,8 +585,7 @@ bool mlirMhloAttributeIsARngDistributionAttr(MlirAttribute attr) { return unwrap(attr).isa(); } -MlirStringRef mlirMhloRngDistributionAttrGetRngDistribution( - MlirAttribute attr) { +MlirStringRef mlirMhloRngDistributionAttrGetValue(MlirAttribute attr) { return wrap(mlir::mhlo::stringifyRngDistribution( unwrap(attr).cast().getValue())); } @@ -599,10 +595,10 @@ MlirStringRef mlirMhloRngDistributionAttrGetRngDistribution( // MlirAttribute mlirMhloRngAlgorithmAttrGet(MlirContext ctx, - MlirStringRef algorithm) { - llvm::Optional rngAlgorithm = - mlir::mhlo::symbolizeRngAlgorithm(unwrap(algorithm)); - if (!rngAlgorithm) llvm_unreachable("Invalid rng-algorithm specified."); + MlirStringRef value) { + std::optional rngAlgorithm = + mlir::mhlo::symbolizeRngAlgorithm(unwrap(value)); + if (!rngAlgorithm) llvm_unreachable("Invalid value."); return wrap( mlir::mhlo::RngAlgorithmAttr::get(unwrap(ctx), rngAlgorithm.value())); } @@ -611,7 +607,7 @@ bool 
mlirMhloAttributeIsARngAlgorithmAttr(MlirAttribute attr) { return unwrap(attr).isa(); } -MlirStringRef mlirMhloRngAlgorithmAttrGetRngAlgorithm(MlirAttribute attr) { +MlirStringRef mlirMhloRngAlgorithmAttrGetValue(MlirAttribute attr) { return wrap(mlir::mhlo::stringifyRngAlgorithm( unwrap(attr).cast().getValue())); } @@ -644,7 +640,7 @@ int64_t mlirMhloChannelHandleGetType(MlirAttribute attr) { MlirAttribute mlirMhloTypeExtensionsGet(MlirContext ctx, intptr_t nBounds, const int64_t *bounds) { return wrap(mlir::mhlo::TypeExtensionsAttr::get( - unwrap(ctx), llvm::makeArrayRef(bounds, nBounds))); + unwrap(ctx), llvm::ArrayRef(bounds, nBounds))); } bool mlirMhloAttributeIsTypeExtensions(MlirAttribute attr) { diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo-c/Attributes.h b/tensorflow/compiler/xla/mlir_hlo/bindings/c/Attributes.h similarity index 86% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo-c/Attributes.h rename to tensorflow/compiler/xla/mlir_hlo/bindings/c/Attributes.h index e91c0a5714c..624440c2ce8 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo-c/Attributes.h +++ b/tensorflow/compiler/xla/mlir_hlo/bindings/c/Attributes.h @@ -9,8 +9,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MLIR_HLO_C_ATTRIBUTES_H -#define MLIR_HLO_C_ATTRIBUTES_H +#ifndef MLIR_HLO_BINDINGS_C_ATTRIBUTES_H +#define MLIR_HLO_BINDINGS_C_ATTRIBUTES_H #include @@ -197,9 +197,9 @@ MLIR_CAPI_EXPORTED int64_t mlirMhloOutputOperandAliasGetOperandTupleIndicesElem( // ComparisonDirectionAttr. // // Creates a new ComparisonDirection attribute with the given -// 'direction' string parameter. +// 'value' string parameter. MLIR_CAPI_EXPORTED MlirAttribute -mlirMhloComparisonDirectionAttrGet(MlirContext ctx, MlirStringRef direction); +mlirMhloComparisonDirectionAttrGet(MlirContext ctx, MlirStringRef value); // Returns true if the given attribute is a ComparisonDirection attribute. MLIR_CAPI_EXPORTED bool mlirMhloAttributeIsAComparisonDirectionAttr( @@ -207,15 +207,15 @@ MLIR_CAPI_EXPORTED bool mlirMhloAttributeIsAComparisonDirectionAttr( // Returns the direction string associated with ComparisonDirection attribute. MLIR_CAPI_EXPORTED MlirStringRef -mlirMhloComparisonDirectionAttrGetDirection(MlirAttribute attr); +mlirMhloComparisonDirectionAttrGetValue(MlirAttribute attr); // // ComparisonTypeAttr. // -// Creates a new ComparisonType attribute with the given 'type' string +// Creates a new ComparisonType attribute with the given 'value' string // parameter. MLIR_CAPI_EXPORTED MlirAttribute -mlirMhloComparisonTypeAttrGet(MlirContext ctx, MlirStringRef type); +mlirMhloComparisonTypeAttrGet(MlirContext ctx, MlirStringRef value); // Returns true if the given attribute is a ComparisonType attribute. MLIR_CAPI_EXPORTED bool mlirMhloAttributeIsAComparisonTypeAttr( @@ -223,128 +223,127 @@ MLIR_CAPI_EXPORTED bool mlirMhloAttributeIsAComparisonTypeAttr( // Returns the type string associated with ComparisonType attribute. MLIR_CAPI_EXPORTED MlirStringRef -mlirMhloComparisonTypeAttrGetType(MlirAttribute attr); +mlirMhloComparisonTypeAttrGetValue(MlirAttribute attr); // // DomainKindAttr. // -// Creates a new DomainKind attribute with the given 'kind' string +// Creates a new DomainKind attribute with the given 'value' string // parameter. 
MLIR_CAPI_EXPORTED MlirAttribute mlirMhloDomainKindAttrGet(MlirContext ctx, - MlirStringRef kind); + MlirStringRef value); // Returns true if the given attribute is a DomainKind attribute. MLIR_CAPI_EXPORTED bool mlirMhloAttributeIsADomainKindAttr(MlirAttribute attr); -// Returns the type string associated with DomainKind attribute. +// Returns the value string associated with DomainKind attribute. MLIR_CAPI_EXPORTED MlirStringRef -mlirMhloDomainKindAttrGetType(MlirAttribute attr); +mlirMhloDomainKindAttrGetValue(MlirAttribute attr); // // PrecisionAttr. // -// Creates a new Precision attribute with the given 'type' string +// Creates a new Precision attribute with the given 'value' string // parameter. MLIR_CAPI_EXPORTED MlirAttribute mlirMhloPrecisionAttrGet(MlirContext ctx, - MlirStringRef type); + MlirStringRef value); // Returns true if the given attribute is a Precision attribute. MLIR_CAPI_EXPORTED bool mlirMhloAttributeIsAPrecisionAttr(MlirAttribute attr); -// Returns the type string associated with Precision attribute. +// Returns the value string associated with Precision attribute. MLIR_CAPI_EXPORTED MlirStringRef -mlirMhloPrecisionAttrGetPrecision(MlirAttribute attr); +mlirMhloPrecisionAttrGetValue(MlirAttribute attr); // // FftTypeAttr. // -// Creates a new FftType attribute with the given 'type' string parameter. +// Creates a new FftType attribute with the given 'value' string parameter. MLIR_CAPI_EXPORTED MlirAttribute mlirMhloFftTypeAttrGet(MlirContext ctx, - MlirStringRef type); + MlirStringRef value); // Returns true if the given attribute is a FftType attribute. MLIR_CAPI_EXPORTED bool mlirMhloAttributeIsAFftTypeAttr(MlirAttribute attr); -// Returns the type string associated with FftType attribute. +// Returns the value string associated with FftType attribute. MLIR_CAPI_EXPORTED MlirStringRef -mlirMhloFftTypeAttrGetFftType(MlirAttribute attr); +mlirMhloFftTypeAttrGetValue(MlirAttribute attr); // // DequantizeModeAttr. // -// Creates a new DequantizeMode attribute with the given 'mode' string +// Creates a new DequantizeMode attribute with the given 'value' string // parameter. MLIR_CAPI_EXPORTED MlirAttribute -mlirMhloDequantizeModeAttrGet(MlirContext ctx, MlirStringRef mode); +mlirMhloDequantizeModeAttrGet(MlirContext ctx, MlirStringRef value); // Returns true if the given attribute is a DequantizeMode attribute. MLIR_CAPI_EXPORTED bool mlirMhloAttributeIsADequantizeModeAttr( MlirAttribute attr); -// Returns the mode string associated with DequantizeMode attribute. +// Returns the value string associated with DequantizeMode attribute. MLIR_CAPI_EXPORTED MlirStringRef -mlirMhloDequantizeModeAttrGetDequantizeMode(MlirAttribute attr); +mlirMhloDequantizeModeAttrGetValue(MlirAttribute attr); // // TransposeAttr. // -// Creates a new Transpose attribute with the given 'type' string parameter. +// Creates a new Transpose attribute with the given 'value' string parameter. MLIR_CAPI_EXPORTED MlirAttribute mlirMhloTransposeAttrGet(MlirContext ctx, - MlirStringRef type); + MlirStringRef value); // Returns true if the given attribute is a Transpose attribute. MLIR_CAPI_EXPORTED bool mlirMhloAttributeIsATransposeAttr(MlirAttribute attr); -// Returns the type string associated with Transpose attribute. +// Returns the value string associated with Transpose attribute. MLIR_CAPI_EXPORTED MlirStringRef -mlirMhloTransposeAttrGetTranspose(MlirAttribute attr); +mlirMhloTransposeAttrGetValue(MlirAttribute attr); // // FusionKindAttr. 
// -// Creates a new FusionKind attribute with the given 'kind' string parameter. +// Creates a new FusionKind attribute with the given 'value' string parameter. MLIR_CAPI_EXPORTED MlirAttribute mlirMhloFusionKindAttrGet(MlirContext ctx, - MlirStringRef kind); + MlirStringRef value); // Returns true if the given attribute is a FusionKind attribute. MLIR_CAPI_EXPORTED bool mlirMhloAttributeIsAFusionKindAttr(MlirAttribute attr); -// Returns the fusion-kind string associated with FusionKind attribute. +// Returns the value string associated with FusionKind attribute. MLIR_CAPI_EXPORTED MlirStringRef -mlirMhloFusionKindAttrGetFusionKind(MlirAttribute attr); +mlirMhloFusionKindAttrGetValue(MlirAttribute attr); // // RngDistributionAttr. // -// Creates a new RngDistribution attribute with the given 'distribution' string +// Creates a new RngDistribution attribute with the given 'value' string // parameter. MLIR_CAPI_EXPORTED MlirAttribute -mlirMhloRngDistributionAttrGet(MlirContext ctx, MlirStringRef distribution); +mlirMhloRngDistributionAttrGet(MlirContext ctx, MlirStringRef value); // Returns true if the given attribute is a RngDistribution attribute. MLIR_CAPI_EXPORTED bool mlirMhloAttributeIsARngDistributionAttr( MlirAttribute attr); -// Returns the rng-distribution string associated with RngDistribution -// attribute. +// Returns the value string associated with RngDistribution attribute. MLIR_CAPI_EXPORTED MlirStringRef -mlirMhloRngDistributionAttrGetRngDistribution(MlirAttribute attr); +mlirMhloRngDistributionAttrGetValue(MlirAttribute attr); // // RngAlgorithmAttr. // -// Creates a new RngAlgorithm attribute with the given 'algorithm' string +// Creates a new RngAlgorithm attribute with the given 'value' string // parameter. MLIR_CAPI_EXPORTED MlirAttribute -mlirMhloRngAlgorithmAttrGet(MlirContext ctx, MlirStringRef algorithm); +mlirMhloRngAlgorithmAttrGet(MlirContext ctx, MlirStringRef value); // Returns true if the given attribute is a RngAlgorithm attribute. MLIR_CAPI_EXPORTED bool mlirMhloAttributeIsARngAlgorithmAttr( MlirAttribute attr); -// Returns the rng-algorithm string associated with RngAlgorithm attribute. +// Returns the value string associated with RngAlgorithm attribute. MLIR_CAPI_EXPORTED MlirStringRef -mlirMhloRngAlgorithmAttrGetRngAlgorithm(MlirAttribute attr); +mlirMhloRngAlgorithmAttrGetValue(MlirAttribute attr); // // ChannelHandle @@ -386,4 +385,4 @@ mlirMhloTypeExtensionsGetBoundsElem(MlirAttribute attr, intptr_t pos); } #endif -#endif // MLIR_HLO_C_ATTRIBUTES_H +#endif // MLIR_HLO_BINDINGS_C_ATTRIBUTES_H diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/CAPI/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/bindings/c/CMakeLists.txt similarity index 100% rename from tensorflow/compiler/xla/mlir_hlo/lib/CAPI/CMakeLists.txt rename to tensorflow/compiler/xla/mlir_hlo/bindings/c/CMakeLists.txt diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/CAPI/Dialects.cc b/tensorflow/compiler/xla/mlir_hlo/bindings/c/Dialects.cc similarity index 90% rename from tensorflow/compiler/xla/mlir_hlo/lib/CAPI/Dialects.cc rename to tensorflow/compiler/xla/mlir_hlo/bindings/c/Dialects.cc index 67917cced59..edb68eb11a3 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/CAPI/Dialects.cc +++ b/tensorflow/compiler/xla/mlir_hlo/bindings/c/Dialects.cc @@ -10,9 +10,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "mlir-hlo-c/Dialects.h" +#include "bindings/c/Dialects.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mhlo/IR/hlo_ops.h" #include "mlir/CAPI/Registration.h" MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Mhlo, mhlo, mlir::mhlo::MhloDialect) diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo-c/Dialects.h b/tensorflow/compiler/xla/mlir_hlo/bindings/c/Dialects.h similarity index 88% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo-c/Dialects.h rename to tensorflow/compiler/xla/mlir_hlo/bindings/c/Dialects.h index 895213f5a6e..ca18c4980e4 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo-c/Dialects.h +++ b/tensorflow/compiler/xla/mlir_hlo/bindings/c/Dialects.h @@ -10,8 +10,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MLIR_HLO_C_DIALECTS_H -#define MLIR_HLO_C_DIALECTS_H +#ifndef MLIR_HLO_BINDINGS_C_DIALECTS_H +#define MLIR_HLO_BINDINGS_C_DIALECTS_H #include "mlir-c/RegisterEverything.h" @@ -26,4 +26,4 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Mhlo, mhlo); } #endif -#endif // MLIR_HLO_C_DIALECTS_H +#endif // MLIR_HLO_BINDINGS_C_DIALECTS_H diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/CAPI/Passes.cc b/tensorflow/compiler/xla/mlir_hlo/bindings/c/Passes.cc similarity index 89% rename from tensorflow/compiler/xla/mlir_hlo/lib/CAPI/Passes.cc rename to tensorflow/compiler/xla/mlir_hlo/bindings/c/Passes.cc index 8f25bf3e3ed..0a47ced1836 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/CAPI/Passes.cc +++ b/tensorflow/compiler/xla/mlir_hlo/bindings/c/Passes.cc @@ -10,8 +10,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mlir-hlo-c/Passes.h" +#include "bindings/c/Passes.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mhlo/transforms/passes.h" void mlirRegisterAllMhloPasses() { mlir::mhlo::registerAllMhloPasses(); } diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo-c/Passes.h b/tensorflow/compiler/xla/mlir_hlo/bindings/c/Passes.h similarity index 88% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo-c/Passes.h rename to tensorflow/compiler/xla/mlir_hlo/bindings/c/Passes.h index c0d254dd2a2..a2cfb784575 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo-c/Passes.h +++ b/tensorflow/compiler/xla/mlir_hlo/bindings/c/Passes.h @@ -10,8 +10,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef MLIR_HLO_C_PASSES_H -#define MLIR_HLO_C_PASSES_H +#ifndef MLIR_HLO_BINDINGS_C_PASSES_H +#define MLIR_HLO_BINDINGS_C_PASSES_H #include "mlir-c/IR.h" #include "mlir-c/Support.h" @@ -27,4 +27,4 @@ MLIR_CAPI_EXPORTED void mlirRegisterAllMhloPasses(); } #endif -#endif // MLIR_HLO_C_PASSES_H +#endif // MLIR_HLO_BINDINGS_C_PASSES_H diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/CAPI/Types.cc b/tensorflow/compiler/xla/mlir_hlo/bindings/c/Types.cc similarity index 92% rename from tensorflow/compiler/xla/mlir_hlo/lib/CAPI/Types.cc rename to tensorflow/compiler/xla/mlir_hlo/bindings/c/Types.cc index 252eed07c02..b4669eccb8e 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/CAPI/Types.cc +++ b/tensorflow/compiler/xla/mlir_hlo/bindings/c/Types.cc @@ -10,9 +10,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mlir-hlo-c/Types.h" +#include "bindings/c/Types.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mhlo/IR/hlo_ops.h" #include "mlir/CAPI/IR.h" MlirType mlirMhloTokenTypeGet(MlirContext ctx) { diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo-c/Types.h b/tensorflow/compiler/xla/mlir_hlo/bindings/c/Types.h similarity index 90% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo-c/Types.h rename to tensorflow/compiler/xla/mlir_hlo/bindings/c/Types.h index 54a3871b3aa..6869997aa03 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo-c/Types.h +++ b/tensorflow/compiler/xla/mlir_hlo/bindings/c/Types.h @@ -10,8 +10,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MLIR_HLO_C_TYPES_H -#define MLIR_HLO_C_TYPES_H +#ifndef MLIR_HLO_BINDINGS_C_TYPES_H +#define MLIR_HLO_BINDINGS_C_TYPES_H #include "mlir-c/IR.h" #include "mlir-c/Support.h" @@ -30,4 +30,4 @@ MLIR_CAPI_EXPORTED bool mlirMhloTypeIsAToken(MlirType type); } #endif -#endif // MLIR_HLO_C_TYPES_H +#endif // MLIR_HLO_BINDINGS_C_TYPES_H diff --git a/tensorflow/compiler/xla/mlir_hlo/python/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/bindings/python/CMakeLists.txt similarity index 100% rename from tensorflow/compiler/xla/mlir_hlo/python/CMakeLists.txt rename to tensorflow/compiler/xla/mlir_hlo/bindings/python/CMakeLists.txt diff --git a/tensorflow/compiler/xla/mlir_hlo/python/MlirHloModule.cc b/tensorflow/compiler/xla/mlir_hlo/bindings/python/MlirHloModule.cc similarity index 80% rename from tensorflow/compiler/xla/mlir_hlo/python/MlirHloModule.cc rename to tensorflow/compiler/xla/mlir_hlo/bindings/python/MlirHloModule.cc index 6c4f821959f..3215acd8aa0 100644 --- a/tensorflow/compiler/xla/mlir_hlo/python/MlirHloModule.cc +++ b/tensorflow/compiler/xla/mlir_hlo/bindings/python/MlirHloModule.cc @@ -10,11 +10,11 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "bindings/c/Attributes.h" +#include "bindings/c/Dialects.h" +#include "bindings/c/Passes.h" +#include "bindings/c/Types.h" #include "mlir-c/IR.h" -#include "mlir-hlo-c/Attributes.h" -#include "mlir-hlo-c/Dialects.h" -#include "mlir-hlo-c/Passes.h" -#include "mlir-hlo-c/Types.h" #include "mlir/Bindings/Python/PybindAdaptors.h" namespace py = pybind11; @@ -348,138 +348,128 @@ PYBIND11_MODULE(_mlirHlo, m) { m, "ComparisonDirectionAttr", mlirMhloAttributeIsAComparisonDirectionAttr) .def_classmethod( "get", - [](py::object cls, const std::string &direction, MlirContext ctx) { + [](py::object cls, const std::string &value, MlirContext ctx) { return cls(mlirMhloComparisonDirectionAttrGet( - ctx, mlirStringRefCreate(direction.c_str(), direction.size()))); + ctx, mlirStringRefCreate(value.c_str(), value.size()))); }, - py::arg("cls"), py::arg("comparison_direction"), - py::arg("context") = py::none(), - "Creates a ComparisonDirection attribute with the given direction.") - .def_property_readonly("comparison_direction", [](MlirAttribute self) { - return toPyString(mlirMhloComparisonDirectionAttrGetDirection(self)); + py::arg("cls"), py::arg("value"), py::arg("context") = py::none(), + "Creates a ComparisonDirection attribute with the given value.") + .def_property_readonly("value", [](MlirAttribute self) { + return toPyString(mlirMhloComparisonDirectionAttrGetValue(self)); }); mlir::python::adaptors::mlir_attribute_subclass( m, "ComparisonTypeAttr", mlirMhloAttributeIsAComparisonTypeAttr) .def_classmethod( "get", - [](py::object cls, const std::string &type, MlirContext ctx) { + [](py::object cls, const std::string &value, MlirContext ctx) { return cls(mlirMhloComparisonTypeAttrGet( - ctx, mlirStringRefCreate(type.c_str(), type.size()))); + ctx, mlirStringRefCreate(value.c_str(), value.size()))); }, - py::arg("cls"), py::arg("comparison_type"), - py::arg("context") = py::none(), - "Creates a ComparisonType attribute with the given type.") - .def_property_readonly("comparison_type", [](MlirAttribute self) { - return toPyString(mlirMhloComparisonTypeAttrGetType(self)); + py::arg("cls"), py::arg("value"), py::arg("context") = py::none(), + "Creates a ComparisonType attribute with the given value.") + .def_property_readonly("value", [](MlirAttribute self) { + return toPyString(mlirMhloComparisonTypeAttrGetValue(self)); }); mlir::python::adaptors::mlir_attribute_subclass( m, "PrecisionAttr", mlirMhloAttributeIsAPrecisionAttr) .def_classmethod( "get", - [](py::object cls, const std::string &type, MlirContext ctx) { + [](py::object cls, const std::string &value, MlirContext ctx) { return cls(mlirMhloPrecisionAttrGet( - ctx, mlirStringRefCreate(type.c_str(), type.size()))); + ctx, mlirStringRefCreate(value.c_str(), value.size()))); }, - py::arg("cls"), py::arg("precision_type"), - py::arg("context") = py::none(), - "Creates a Precision attribute with the given type.") - .def_property_readonly("precision_type", [](MlirAttribute self) { - return toPyString(mlirMhloPrecisionAttrGetPrecision(self)); + py::arg("cls"), py::arg("value"), py::arg("context") = py::none(), + "Creates a Precision attribute with the given value.") + .def_property_readonly("value", [](MlirAttribute self) { + return toPyString(mlirMhloPrecisionAttrGetValue(self)); }); mlir::python::adaptors::mlir_attribute_subclass( m, "FftTypeAttr", mlirMhloAttributeIsAFftTypeAttr) .def_classmethod( "get", - [](py::object cls, const std::string &type, MlirContext 
ctx) { + [](py::object cls, const std::string &value, MlirContext ctx) { return cls(mlirMhloFftTypeAttrGet( - ctx, mlirStringRefCreate(type.c_str(), type.size()))); + ctx, mlirStringRefCreate(value.c_str(), value.size()))); }, - py::arg("cls"), py::arg("fft_type"), py::arg("context") = py::none(), - "Creates a FftType attribute with the given type.") - .def_property_readonly("fft_type", [](MlirAttribute self) { - return toPyString(mlirMhloFftTypeAttrGetFftType(self)); + py::arg("cls"), py::arg("value"), py::arg("context") = py::none(), + "Creates a FftType attribute with the given value.") + .def_property_readonly("value", [](MlirAttribute self) { + return toPyString(mlirMhloFftTypeAttrGetValue(self)); }); mlir::python::adaptors::mlir_attribute_subclass( m, "DequantizeModeAttr", mlirMhloAttributeIsADequantizeModeAttr) .def_classmethod( "get", - [](py::object cls, const std::string &type, MlirContext ctx) { + [](py::object cls, const std::string &value, MlirContext ctx) { return cls(mlirMhloDequantizeModeAttrGet( - ctx, mlirStringRefCreate(type.c_str(), type.size()))); + ctx, mlirStringRefCreate(value.c_str(), value.size()))); }, - py::arg("cls"), py::arg("dequantize_mode"), - py::arg("context") = py::none(), - "Creates a DequantizeMode attribute with the given mode.") - .def_property_readonly("dequantize_mode", [](MlirAttribute self) { - return toPyString(mlirMhloDequantizeModeAttrGetDequantizeMode(self)); + py::arg("cls"), py::arg("value"), py::arg("context") = py::none(), + "Creates a DequantizeMode attribute with the given value.") + .def_property_readonly("value", [](MlirAttribute self) { + return toPyString(mlirMhloDequantizeModeAttrGetValue(self)); }); mlir::python::adaptors::mlir_attribute_subclass( m, "TransposeAttr", mlirMhloAttributeIsATransposeAttr) .def_classmethod( "get", - [](py::object cls, const std::string &type, MlirContext ctx) { + [](py::object cls, const std::string &value, MlirContext ctx) { return cls(mlirMhloTransposeAttrGet( - ctx, mlirStringRefCreate(type.c_str(), type.size()))); + ctx, mlirStringRefCreate(value.c_str(), value.size()))); }, - py::arg("cls"), py::arg("transpose_type"), - py::arg("context") = py::none(), - "Creates a Transpose attribute with the given type.") - .def_property_readonly("transpose_type", [](MlirAttribute self) { - return toPyString(mlirMhloTransposeAttrGetTranspose(self)); + py::arg("cls"), py::arg("value"), py::arg("context") = py::none(), + "Creates a Transpose attribute with the given value.") + .def_property_readonly("value", [](MlirAttribute self) { + return toPyString(mlirMhloTransposeAttrGetValue(self)); }); mlir::python::adaptors::mlir_attribute_subclass( m, "FusionKindAttr", mlirMhloAttributeIsAFusionKindAttr) .def_classmethod( "get", - [](py::object cls, const std::string &kind, MlirContext ctx) { + [](py::object cls, const std::string &value, MlirContext ctx) { return cls(mlirMhloFusionKindAttrGet( - ctx, mlirStringRefCreate(kind.c_str(), kind.size()))); + ctx, mlirStringRefCreate(value.c_str(), value.size()))); }, - py::arg("cls"), py::arg("fusion_kind"), - py::arg("context") = py::none(), - "Creates a FusionKind attribute with the given kind.") - .def_property_readonly("fusion_kind", [](MlirAttribute self) { - return toPyString(mlirMhloFusionKindAttrGetFusionKind(self)); + py::arg("cls"), py::arg("value"), py::arg("context") = py::none(), + "Creates a FusionKind attribute with the given value.") + .def_property_readonly("value", [](MlirAttribute self) { + return toPyString(mlirMhloFusionKindAttrGetValue(self)); }); 
mlir::python::adaptors::mlir_attribute_subclass( m, "RngDistributionAttr", mlirMhloAttributeIsARngDistributionAttr) .def_classmethod( "get", - [](py::object cls, const std::string &distribution, MlirContext ctx) { + [](py::object cls, const std::string &value, MlirContext ctx) { return cls(mlirMhloRngDistributionAttrGet( - ctx, mlirStringRefCreate(distribution.c_str(), - distribution.size()))); + ctx, mlirStringRefCreate(value.c_str(), value.size()))); }, - py::arg("cls"), py::arg("rng_distribution"), - py::arg("context") = py::none(), - "Creates a RngDistribution attribute with the given rng " - "distribution.") - .def_property_readonly("rng_distribution", [](MlirAttribute self) { - auto distribution = mlirMhloRngDistributionAttrGetRngDistribution(self); - return py::str(distribution.data, distribution.length); + py::arg("cls"), py::arg("value"), py::arg("context") = py::none(), + "Creates a RngDistribution attribute with the given value.") + .def_property_readonly("value", [](MlirAttribute self) { + auto value = mlirMhloRngDistributionAttrGetValue(self); + return py::str(value.data, value.length); }); mlir::python::adaptors::mlir_attribute_subclass( m, "RngAlgorithmAttr", mlirMhloAttributeIsARngAlgorithmAttr) .def_classmethod( "get", - [](py::object cls, const std::string &algorithm, MlirContext ctx) { + [](py::object cls, const std::string &value, MlirContext ctx) { return cls(mlirMhloRngAlgorithmAttrGet( - ctx, mlirStringRefCreate(algorithm.c_str(), algorithm.size()))); + ctx, mlirStringRefCreate(value.c_str(), value.size()))); }, - py::arg("cls"), py::arg("rng_algorithm"), - py::arg("context") = py::none(), - "Creates a RngAlgorithm attribute with the given rng algorithm.") - .def_property_readonly("rng_algorithm", [](MlirAttribute self) { - auto algorithm = mlirMhloRngAlgorithmAttrGetRngAlgorithm(self); - return py::str(algorithm.data, algorithm.length); + py::arg("cls"), py::arg("value"), py::arg("context") = py::none(), + "Creates a RngAlgorithm attribute with the given value.") + .def_property_readonly("value", [](MlirAttribute self) { + auto value = mlirMhloRngAlgorithmAttrGetValue(self); + return py::str(value.data, value.length); }); mlir::python::adaptors::mlir_attribute_subclass( diff --git a/tensorflow/compiler/xla/mlir_hlo/python/mlir/dialects/MhloOps.td b/tensorflow/compiler/xla/mlir_hlo/bindings/python/mlir/dialects/MhloOps.td similarity index 94% rename from tensorflow/compiler/xla/mlir_hlo/python/mlir/dialects/MhloOps.td rename to tensorflow/compiler/xla/mlir_hlo/bindings/python/mlir/dialects/MhloOps.td index eb9d1c52e17..f7bfea08eb2 100644 --- a/tensorflow/compiler/xla/mlir_hlo/python/mlir/dialects/MhloOps.td +++ b/tensorflow/compiler/xla/mlir_hlo/bindings/python/mlir/dialects/MhloOps.td @@ -19,6 +19,6 @@ limitations under the License. 
#define PYTHON_BINDINGS_MHLO_OPS include "mlir/Bindings/Python/Attributes.td" -include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" +include "mhlo/IR/hlo_ops.td" #endif diff --git a/tensorflow/compiler/xla/mlir_hlo/python/mlir/dialects/mhlo.py b/tensorflow/compiler/xla/mlir_hlo/bindings/python/mlir/dialects/mhlo.py similarity index 100% rename from tensorflow/compiler/xla/mlir_hlo/python/mlir/dialects/mhlo.py rename to tensorflow/compiler/xla/mlir_hlo/bindings/python/mlir/dialects/mhlo.py diff --git a/tensorflow/compiler/xla/mlir_hlo/cmake/modules/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/cmake/modules/CMakeLists.txt index dea85e1a660..462a7f0475d 100644 --- a/tensorflow/compiler/xla/mlir_hlo/cmake/modules/CMakeLists.txt +++ b/tensorflow/compiler/xla/mlir_hlo/cmake/modules/CMakeLists.txt @@ -23,6 +23,7 @@ set(MHLO_CONFIG_CMAKE_DIR "${mhlo_cmake_builddir}") set(MHLO_CONFIG_LLVM_CMAKE_DIR "${llvm_cmake_builddir}") set(MHLO_CONFIG_INCLUDE_EXPORTS "include(\"\${MHLO_CMAKE_DIR}/MHLOTargets.cmake\")") set(MHLO_CONFIG_INCLUDE_DIRS + "${MLIR_HLO_SOURCE_DIR}" "${MLIR_HLO_MAIN_INCLUDE_DIR}" "${MLIR_HLO_GEN_INCLUDE_DIR}" ) diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/gml_st/CMakeLists.txt similarity index 92% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/CMakeLists.txt rename to tensorflow/compiler/xla/mlir_hlo/gml_st/CMakeLists.txt index 88672e5e298..47c038050ca 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/CMakeLists.txt +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/CMakeLists.txt @@ -13,4 +13,6 @@ # limitations under the License. add_subdirectory(IR) +add_subdirectory(interfaces) add_subdirectory(transforms) +add_subdirectory(utils) diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/IR/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/gml_st/IR/CMakeLists.txt similarity index 72% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/IR/CMakeLists.txt rename to tensorflow/compiler/xla/mlir_hlo/gml_st/IR/CMakeLists.txt index 1f00694bf40..b1fe0fda4dd 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/IR/CMakeLists.txt +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/IR/CMakeLists.txt @@ -24,3 +24,27 @@ mlir_tablegen(gml_st_attrs.cc.inc -gen-attrdef-defs) add_public_tablegen_target(MLIRgml_st_opsIncGen) add_dependencies(mlir-headers MLIRgml_st_opsIncGen) + +include_directories(BEFORE + ${CMAKE_CURRENT_BINARY_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}) + +add_mlir_dialect_library(GmlStDialect + gml_st_ops.cc + + DEPENDS + MLIRgml_st_opsIncGen + + LINK_LIBS PUBLIC + MLIRArithUtils + MLIRControlFlowInterfaces + MLIRIR + MLIRInferTypeOpInterface + MLIRLoopLikeInterface + MLIRMemRefDialect + MLIRSideEffectInterfaces + MLIRSupport + MLIRTensorDialect + MLIRViewLikeInterface + MLIRVectorDialect +) diff --git a/tensorflow/compiler/xla/mlir_hlo/gml_st/IR/gml_st_ops.cc b/tensorflow/compiler/xla/mlir_hlo/gml_st/IR/gml_st_ops.cc new file mode 100644 index 00000000000..a1985a06609 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/IR/gml_st_ops.cc @@ -0,0 +1,1460 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "gml_st/IR/gml_st_ops.h" + +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Utils/Utils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/ViewLikeInterface.h" + +namespace mlir { +namespace { + +void printShapeTypeDimensionsList(AsmPrinter &printer, + ArrayRef integers) { + llvm::interleave( + integers, printer, + [&](int64_t val) { + if (val == ShapedType::kDynamic) { + printer << '?'; + } else { + printer << val; + } + }, + "x"); +} + +ParseResult parseShapeTypeDimensionsList(AsmParser &parser, + SmallVectorImpl &dims) { + SmallVector vals; + if (failed(parser.parseDimensionList(vals, /*allowDynamic=*/true, + /*withTrailingX=*/false))) { + return failure(); + } + dims = vals; + return success(); +} + +Type inferReturnType(ShapedType sourceType, ArrayRef tileShape) { + return sourceType.clone(tileShape, sourceType.getElementType()); +} + +LogicalResult verifyCompatibleExtractedSubset(Operation *op, + ShapedType shapedType, + Type extractedType, + ArrayRef tileShape) { + auto sourceRank = shapedType.getRank(); + auto elementType = shapedType.getElementType(); + + // If the result is a scalar, check that the tile had a single element. + if (!extractedType.isa()) { + if (extractedType != elementType) { + return op->emitOpError("expected the result type ") + << extractedType << " to match source element type " + << elementType; + } + if (!ShapedType::isDynamicShape(tileShape) && + ShapedType::getNumElements(tileShape) == 1) + return success(); + + return op->emitOpError("expected tile type ") + << tileShape << " to have a single element shape"; + } + + // If the result is a shaped type, compare with the inferred type. + auto extractedShapedType = extractedType.cast(); + unsigned tileRank = tileShape.size(); + if (tileRank != sourceRank) { + return op->emitOpError("expected source rank = ") + << sourceRank << " to match tile rank = " << tileRank; + } + + auto inferredType = shapedType.clone(tileShape, shapedType.getElementType()); + if (extractedShapedType != inferredType) { + return op->emitOpError("expected result type = ") + << extractedShapedType + << " to match the inferred type = " << inferredType; + } + + return success(); +} + +} // namespace +} // namespace mlir + +// Generated dialect definitions. 
+#include "gml_st/IR/gml_st_dialect.cc.inc" + +// Generated type classes. +#define GET_TYPEDEF_CLASSES +#include "gml_st/IR/gml_st_types.cc.inc" + +// Generated attribute classes. +#define GET_ATTRDEF_CLASSES +#include "gml_st/IR/gml_st_attrs.cc.inc" + +namespace mlir { +namespace gml_st { + +//===----------------------------------------------------------------------===// +// GmlStDialect +//===----------------------------------------------------------------------===// + +void GmlStDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "gml_st/IR/gml_st_ops.cc.inc" + >(); + addTypes< +#define GET_TYPEDEF_LIST +#include "gml_st/IR/gml_st_types.cc.inc" + >(); + addAttributes< +#define GET_ATTRDEF_LIST +#include "gml_st/IR/gml_st_attrs.cc.inc" + >(); +} + +Operation *GmlStDialect::materializeConstant(OpBuilder &builder, Attribute attr, + Type type, Location loc) { + if (type.isa()) { + int64_t intValue = attr.cast().getInt(); + return builder.create(loc, intValue); + } + return {}; +} + +//===----------------------------------------------------------------------===// +// MaterializeOp +//===----------------------------------------------------------------------===// + +void MaterializeOp::build(OpBuilder &b, OperationState &result, Value source, + ArrayRef offsets, + ArrayRef sizes, + ArrayRef strides) { + SmallVector staticOffsets, staticSizes, staticStrides; + SmallVector dynamicOffsets, dynamicSizes, dynamicStrides; + dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); + dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes); + dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); + auto sourceType = source.getType().cast(); + Type resultType = inferReturnType(sourceType, staticSizes); + build(b, result, resultType, source, dynamicOffsets, dynamicSizes, + dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets), + b.getDenseI64ArrayAttr(staticSizes), + b.getDenseI64ArrayAttr(staticStrides)); +} + +void MaterializeOp::build(OpBuilder &b, OperationState &result, Type resultType, + Value source, ArrayRef offsets, + ArrayRef sizes, + ArrayRef strides) { + SmallVector staticOffsets, staticSizes, staticStrides; + SmallVector dynamicOffsets, dynamicSizes, dynamicStrides; + dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); + dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes); + dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); + build(b, result, resultType, source, dynamicOffsets, dynamicSizes, + dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets), + b.getDenseI64ArrayAttr(staticSizes), + b.getDenseI64ArrayAttr(staticStrides)); +} + +void MaterializeOp::build(OpBuilder &b, OperationState &result, Value source, + ArrayRef offsets) { + SmallVector unitSizesAndStrides(offsets.size(), + b.getIndexAttr(1)); + build(b, result, source, offsets, unitSizesAndStrides, unitSizesAndStrides); +} + +LogicalResult MaterializeOp::verify() { + return verifyCompatibleExtractedSubset(getOperation(), getSource().getType(), + getType(), getStaticSizes()); +} + +LogicalResult MaterializeOp::reifyResultShapes( + OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { + reifiedReturnShapes.push_back( + getAsValues(builder, getLoc(), getMixedSizes())); + return success(); +} + +namespace { + +/// Adapted from OpWithOffsetSizesAndStridesConstantArgumentFolder, which makes +/// slightly incompatible assumptions about the op. 
+struct FoldConstantsIntoMaterializeOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(MaterializeOp op, + PatternRewriter &rewriter) const override { + // No constant operand, just return; + if (llvm::none_of(op.getOperands(), [](Value operand) { + return matchPattern(operand, matchConstantIndex()); + })) + return failure(); + + // At least one of offsets/sizes/strides is a new constant. + // Form the new list of operands and constant attributes from the existing. + SmallVector mixedOffsets(op.getMixedOffsets()); + SmallVector mixedSizes(op.getMixedSizes()); + SmallVector mixedStrides(op.getMixedStrides()); + canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamic); + canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic); + canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamic); + + SmallVector staticSizes; + SmallVector dynamicSizes; + dispatchIndexOpFoldResults(mixedSizes, dynamicSizes, staticSizes); + + Type opResultType = op.getType(); + Type newResultType = + opResultType.isa() + ? inferReturnType(op.getSource().getType(), staticSizes) + : opResultType; + // Create the new tile in canonical form. + auto newMaterializeOp = rewriter.create( + op.getLoc(), newResultType, op.getSource(), mixedOffsets, mixedSizes, + mixedStrides); + + // Cast the result back to the original type. + if (opResultType != newResultType) { + rewriter.replaceOpWithNewOp(op, opResultType, + newMaterializeOp.getResult()); + } else { + rewriter.replaceOp(op, newMaterializeOp.getResult()); + } + return success(); + } +}; + +/// Folds tensor::CastOp sources into MaterializeOp. +struct FoldSrcCastIntoMaterialize : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(MaterializeOp op, + PatternRewriter &rewriter) const override { + auto cast = op.getSource().getDefiningOp(); + if (!cast) return failure(); + + auto src = cast.getSource(); + auto shape = op.getStaticSizes(); + rewriter.replaceOpWithNewOp( + op, inferReturnType(src.getType(), shape), src, op.getMixedOffsets(), + op.getMixedSizes(), op.getMixedStrides()); + return success(); + } +}; +} // namespace + +void MaterializeOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add( + context); +} + +//===----------------------------------------------------------------------===// +// LoopLikeOp +//===----------------------------------------------------------------------===// + +namespace { + +ParseResult parseLoopLikeOpOutputArgs( + OpAsmParser &parser, OperationState &result, + SmallVectorImpl ®ionOperands, + SmallVectorImpl ®ionTypes, int32_t *outputCount) { + SmallVector outputs, outputRegionArgs; + SmallVector outputTypes; + + auto parseElt = [&]() -> ParseResult { + if (parser.parseOperand(outputRegionArgs.emplace_back(), + /*allowResultNumber=*/false) || + parser.parseEqual()) { + return failure(); + } + if (parser.parseOperand(outputs.emplace_back()) || parser.parseColon() || + parser.parseType(outputTypes.emplace_back())) { + return failure(); + } + *outputCount = static_cast(outputs.size()); + return success(); + }; + if (succeeded(parser.parseOptionalKeyword("outs"))) { + SMLoc loc = parser.getCurrentLocation(); + + if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Paren, parseElt)) + return failure(); + if (parser.resolveOperands(outputs, outputTypes, loc, result.operands)) + return failure(); + } + regionOperands.append(outputRegionArgs); + regionTypes.append(outputTypes); + return 
success(); +} + +} // namespace + +template +ParseResult parseLoopLikeOp(OpAsmParser &parser, OperationState &result) { + auto &builder = parser.getBuilder(); + // Parse an opening `(` followed by induction variables followed by `)` + SmallVector ivs; + if (parser.parseOperandList(ivs, OpAsmParser::Delimiter::Paren, + /*allowResultNumber=*/false)) + return failure(); + + // Parse loop bounds. + SmallVector lower; + if (parser.parseEqual() || + parser.parseOperandList(lower, ivs.size(), + OpAsmParser::Delimiter::Paren) || + parser.resolveOperands(lower, builder.getIndexType(), result.operands)) + return failure(); + + SmallVector upper; + if (parser.parseKeyword("to") || + parser.parseOperandList(upper, ivs.size(), + OpAsmParser::Delimiter::Paren) || + parser.resolveOperands(upper, builder.getIndexType(), result.operands)) + return failure(); + + // Parse step values. + SmallVector steps; + if (parser.parseKeyword("step") || + parser.parseOperandList(steps, ivs.size(), + OpAsmParser::Delimiter::Paren) || + parser.resolveOperands(steps, builder.getIndexType(), result.operands)) + return failure(); + + // Parse the output tensors and the body. + SmallVector regionOperands(ivs); + SmallVector regionTypes(ivs.size(), builder.getIndexType()); + + int32_t outputCount = 0; + if (failed(parseLoopLikeOpOutputArgs(parser, result, regionOperands, + regionTypes, &outputCount))) + return failure(); + + // Parse distribution type (only for ParallelOp) + if (std::is_same::value) { + if (succeeded(parser.parseOptionalKeyword("distribution"))) { + StringAttr distributionType; + if (parser.parseLParen() || parser.parseAttribute(distributionType) || + parser.parseRParen()) + return failure(); + result.addAttribute(ParallelOp::getDistributionTypeAttrName(result.name), + distributionType); + } + } + + SmallVector regionArgs; + for (auto argAndType : llvm::zip(regionOperands, regionTypes)) { + auto &arg = regionArgs.emplace_back(); + std::tie(arg.ssaName, arg.type) = argAndType; + } + Region *body = result.addRegion(); + if (parser.parseRegion(*body, regionArgs)) return failure(); + + // Parse attributes. + if (parser.parseOptionalAttrDict(result.attributes)) return failure(); + + // Parser result types. + if (parser.parseOptionalColonTypeList(result.types)) return failure(); + + // Add segment sizes. 
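+  // The attribute records how the flat operand list is split into lower
+  // bounds, upper bounds, steps and outputs, in that order; e.g. a 2-D loop
+  // with one output gets the segment sizes {2, 2, 2, 1}.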
+ result.addAttribute(LoopTy::getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr( + {static_cast(lower.size()), + static_cast(upper.size()), + static_cast(steps.size()), outputCount})); + + return success(); +} + +template +void buildLoopLikeOp( + OpBuilder &builder, OperationState &result, TypeRange resultTypes, + ValueRange lowerBounds, ValueRange upperBounds, ValueRange steps, + ValueRange outputs, + function_ref + bodyBuilderFn) { + result.addOperands(lowerBounds); + result.addOperands(upperBounds); + result.addOperands(steps); + result.addOperands(outputs); + result.addTypes(resultTypes); + result.addAttribute( + LoopTy::getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr({static_cast(lowerBounds.size()), + static_cast(upperBounds.size()), + static_cast(steps.size()), + static_cast(outputs.size())})); + + OpBuilder::InsertionGuard guard(builder); + unsigned numIvs = steps.size(); + SmallVector argTypes(numIvs, builder.getIndexType()); + SmallVector argLocs(numIvs, result.location); + for (Value output : outputs) { + argTypes.push_back(output.getType()); + argLocs.push_back(output.getLoc()); + } + Region *bodyRegion = result.addRegion(); + Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs); + + if (bodyBuilderFn) { + builder.setInsertionPointToStart(bodyBlock); + bodyBuilderFn(builder, result.location, + bodyBlock->getArguments().take_front(numIvs), + bodyBlock->getArguments().take_back(outputs.size())); + LoopTy::ensureTerminator(*bodyRegion, builder, result.location); + } +} + +template +struct CollapseSingleIterationLoops : public OpRewritePattern { + explicit CollapseSingleIterationLoops( + MLIRContext *context, + llvm::function_ref filterFn = nullptr, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), filterFn(filterFn) {} + + LogicalResult matchAndRewrite(LoopLikeOp op, + PatternRewriter &rewriter) const override { + if (filterFn && !filterFn(op)) + return rewriter.notifyMatchFailure(op, "did not match filter"); + + IRMapping mapping; + // Compute new loop bounds that omit all single-iteration loop dimensions. + SmallVector newLowerBounds, newUpperBounds, newSteps; + newLowerBounds.reserve(op.getLowerBound().size()); + newUpperBounds.reserve(op.getUpperBound().size()); + newSteps.reserve(op.getStep().size()); + auto getConstant = [](Value v) -> Optional { + auto constant = + dyn_cast_or_null(v.getDefiningOp()); + if (constant) return constant.value(); + return std::nullopt; + }; + for (auto [lowerBound, upperBound, step, iv] : + llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(), + op.getInductionVars())) { + // Collect the statically known loop bounds. + auto lowerBoundConstant = getConstant(lowerBound); + auto upperBoundConstant = getConstant(upperBound); + auto stepConstant = getConstant(step); + // Remove the loop if it performs zero iterations. + if (lowerBoundConstant && upperBoundConstant && + *lowerBoundConstant == *upperBoundConstant) { + rewriter.replaceOp(op, op.getOutputs()); + return success(); + } + // Replace the loop induction variable by the lower bound if the loop + // performs a single iteration. Otherwise, copy the loop bounds. 
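+      // Illustrative values (assumed): lowerBound = 0, upperBound = 4 and
+      // step = 4 give exactly one iteration, so the induction variable can
+      // simply be replaced by the lower bound.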
+ if (lowerBoundConstant && upperBoundConstant && stepConstant && + (*upperBoundConstant - *lowerBoundConstant) > 0 && + (*upperBoundConstant - *lowerBoundConstant) <= *stepConstant) { + mapping.map(iv, lowerBound); + } else { + newLowerBounds.push_back(lowerBound); + newUpperBounds.push_back(upperBound); + newSteps.push_back(step); + } + } + // Exit if none of the loop dimensions perform a single iteration. + if (newLowerBounds.size() == op.getLowerBound().size()) return failure(); + + // All of the loop dimensions perform a single iteration. Inline loop body. + if (newLowerBounds.empty()) { + mapping.map(op.getRegionOutputArgs(), op.getOutputs()); + for (auto &bodyOp : op.getBody()->without_terminator()) { + rewriter.clone(bodyOp, mapping); + } + SmallVector results; + results.reserve(op.getResults().size()); + SetYieldOp terminator = op.getTerminator(); + for (const auto &[dst, src, set] : + llvm::zip(terminator.getDsts(), terminator.getSrcs(), + terminator.getSets())) { + auto tileOp = set.template getDefiningOp(); + + if (!tileOp) { + return terminator.emitOpError( + "expected the SetYieldOp terminator of gml_st loop to have a " + "TileOp set"); + } + auto getMappedValues = [&](ValueRange values) { + return llvm::to_vector(llvm::map_range(values, [&](Value value) { + return mapping.lookupOrDefault(value); + })); + }; + + if (dst.getType().template isa()) { + Value srcVal = mapping.lookupOrDefault(src); + if (srcVal.getType().isa()) { + results.push_back(rewriter.create( + op.getLoc(), dst.getType(), srcVal, + mapping.lookupOrDefault(dst), + getMappedValues(tileOp.getOffsets()), + getMappedValues(tileOp.getSizes()), + getMappedValues(tileOp.getStrides()), tileOp.getStaticOffsets(), + tileOp.getStaticSizes(), tileOp.getStaticStrides())); + } else { + SmallVector mappedOffsets = + getMappedValues(tileOp.getOffsets()); + SmallVector ofrs; + int idx = 0; + for (int64_t offset : tileOp.getStaticOffsets()) { + if (ShapedType::isDynamic(offset)) { + ofrs.push_back(mappedOffsets[idx++]); + } else { + ofrs.push_back(rewriter.getIndexAttr(offset)); + } + } + results.push_back(rewriter.create( + op.getLoc(), srcVal, mapping.lookupOrDefault(dst), + getAsValues(rewriter, op.getLoc(), ofrs))); + } + } else if (dst.getType().template isa()) { + results.push_back(rewriter.create( + op.getLoc(), dst.getType(), mapping.lookupOrDefault(src), + mapping.lookupOrDefault(dst), + rewriter.getI64ArrayAttr(tileOp.getStaticSizes()), + rewriter.getI64ArrayAttr(tileOp.getStaticStrides()))); + } else { + return op.emitOpError( + "expected output of gml_st loop to be either a tensor or a " + "vector"); + } + } + rewriter.replaceOp(op, results); + return success(); + } + + // Replace the loop by a lower-dimensional loop. + LoopLikeOp newOp; + if constexpr (std::is_same_v) { + auto parallelLoop = cast(op); + newOp = rewriter.create(op.getLoc(), op.getResultTypes(), + newLowerBounds, newUpperBounds, + newSteps, parallelLoop.getOutputs()); + } else { + newOp = rewriter.create(op.getLoc(), op.getResultTypes(), + newLowerBounds, newUpperBounds, newSteps, + op.getOutputs(), nullptr); + } + // The new loop needs to keep all attributes from the old one, except for + // "operand_segment_sizes" which captures the outdated information of the + // old iteration domain. 
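+    // The newly built loop already carries correct segment sizes; copying the
+    // stale attribute would overwrite them with counts from the old iteration
+    // domain.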
+ for (const auto &namedAttr : op->getAttrs()) { + if (namedAttr.getName() == LoopLikeOp::getOperandSegmentSizeAttr()) + continue; + newOp->setAttr(namedAttr.getName(), namedAttr.getValue()); + } + + // Clone the loop body and remap the block arguments of the collapsed loops + // (inlining does not support a cancellable block argument mapping). + rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(), + newOp.getRegion().begin(), mapping); + rewriter.replaceOp(op, newOp.getResults()); + return success(); + } + + private: + llvm::function_ref filterFn; +}; + +//===----------------------------------------------------------------------===// +// ParallelOp +//===----------------------------------------------------------------------===// + +Region &ParallelOp::getLoopBody() { return getRegion(); } + +SetYieldOp ParallelOp::getTerminator() { + return cast(getBody()->getTerminator()); +} + +LogicalResult ParallelOp::verify() { + if (getNumResults() != getNumOutputs()) { + return emitOpError() << "expected the number of output arguments to match " + "the number of results"; + } + + // Check if types of output arguments match region args types. + for (auto &item : llvm::enumerate( + llvm::zip(getOutputs(), getRegionOutputArgs(), getResultTypes()))) { + Value output, outputRegionArg; + Type resultType; + unsigned index = item.index(); + std::tie(output, outputRegionArg, resultType) = item.value(); + if (output.getType() != outputRegionArg.getType()) { + return emitOpError("expected output arg ") + << index << " with type = " << output.getType() + << " to match region arg " << index + getNumLoops() + << " type = " << outputRegionArg.getType(); + } + if (output.getType() != resultType) { + return emitOpError("expected output arg ") + << index << " with type = " << output.getType() + << " to match resultType " << index << " type = " << resultType; + } + auto terminator = getTerminator(); + auto numDstOperands = terminator.getNumDstOperands(); + if (index >= numDstOperands) { + const auto *s = index ? 
"s" : ""; + return terminator.emitOpError("expected to have at least ") + << index + 1 << " destination operand" << s << " (currently " + << numDstOperands << ")"; + } + + if (terminator.getDstOperand(index)->get() != outputRegionArg) { + return terminator.emitOpError("expected output block argument ") + << index << " to match set_yield destination"; + } + } + return success(); +} + +void ParallelOp::build( + OpBuilder &builder, OperationState &result, TypeRange resultTypes, + ValueRange lowerBounds, ValueRange upperBounds, ValueRange steps, + ValueRange outputs, std::optional distributionType, + function_ref + bodyBuilderFn) { + if (distributionType.has_value()) + result.addAttribute(getDistributionTypeAttrName(result.name), + distributionType.value()); + + buildLoopLikeOp(builder, result, resultTypes, lowerBounds, + upperBounds, steps, outputs, bodyBuilderFn); +} + +void ParallelOp::print(OpAsmPrinter &p) { + p << " (" << getInductionVars() << ") = (" << getLowerBound() << ") to (" + << getUpperBound() << ") step (" << getStep() << ") "; + + if (!getOutputs().empty()) { + p << "outs ("; + llvm::interleaveComma( + llvm::zip(getRegionOutputArgs(), getOutputs()), p, [&](auto it) { + Value outputRegionArg, output; + std::tie(outputRegionArg, output) = it; + p << outputRegionArg << " = " << output << ": " << output.getType(); + }); + p << ") "; + } + + if (getDistributionType().has_value()) + p << "distribution (" << getDistributionTypeAttr() << ") "; + + p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); + p.printOptionalAttrDict( + getOperation()->getAttrs(), + /*elidedAttrs=*/{ParallelOp::getOperandSegmentSizeAttr(), + getDistributionTypeAttrName()}); + + if (!getResultTypes().empty()) { + p << " : "; + llvm::interleave(getResultTypes(), p, ", "); + } +} + +ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) { + return parseLoopLikeOp(parser, result); +} + +namespace { + +/// Fold tensor.dim(gml_st.parallel outs(... = %t)) to tensor.dim(%t). +struct DimOfParallelOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::DimOp dimOp, + PatternRewriter &rewriter) const final { + auto parallelOp = dimOp.getSource().getDefiningOp(); + if (!parallelOp) return failure(); + + OpOperand &out = + parallelOp.getOpOperandForResult(dimOp.getSource().cast()); + rewriter.updateRootInPlace( + dimOp, [&]() { dimOp.getSourceMutable().assign(out.get()); }); + return success(); + } +}; + +/// Fold tensor.casts into the output arguments of gml_st.parallel. +struct FoldTensorCastIntoParallelOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + struct TypeCast { + Type srcType; + Type dstType; + }; + + LogicalResult matchAndRewrite(gml_st::ParallelOp parallelOp, + PatternRewriter &rewriter) const final { + llvm::SmallMapVector tensorCastProducers; + llvm::SmallVector newOutputTensors = parallelOp.getOutputs(); + for (auto &en : llvm::enumerate(newOutputTensors)) { + if (auto castOp = en.value().getDefiningOp()) { + tensorCastProducers[en.index()] = + TypeCast{castOp.getSource().getType(), castOp.getType()}; + en.value() = castOp.getSource(); + } + } + + if (tensorCastProducers.empty()) return failure(); + + // Create new loop. 
+ Location loc = parallelOp.getLoc(); + std::optional distTypeAttr; + if (auto distType = parallelOp.getDistributionType()) + distTypeAttr = rewriter.getStringAttr(*distType); + auto newParallelOp = rewriter.create( + loc, TypeRange{ValueRange{newOutputTensors}}, + parallelOp.getLowerBound(), parallelOp.getUpperBound(), + parallelOp.getStep(), newOutputTensors, distTypeAttr, nullptr); + + Block *loopBody = newParallelOp.getBody(); + + // Cast bbArgs back to the original types. + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPointToStart(loopBody); + SmallVector castBBArgs = + ValueRange{newParallelOp.getRegionOutputArgs()}; + for (auto &item : tensorCastProducers) { + Value &oldTypeBBArg = castBBArgs[item.first]; + oldTypeBBArg = rewriter.create(loc, item.second.dstType, + oldTypeBBArg); + } + + // Move old body into new parallel loop. + SmallVector blockArgs = newParallelOp.getInductionVars(); + blockArgs.append(castBBArgs); + rewriter.mergeBlocks(parallelOp.getBody(), loopBody, blockArgs); + + // Cast `set_yield` destination operands to the new types. + SetYieldOp terminator = newParallelOp.getTerminator(); + rewriter.setInsertionPoint(terminator); + SmallVector castDsts = terminator.getDsts(); + for (auto &item : tensorCastProducers) { + Value &newTypeDsts = castDsts[item.first]; + newTypeDsts = rewriter.create(loc, item.second.srcType, + newTypeDsts); + } + terminator.getDstsMutable().assign(castDsts); + + // Cast results back to the original types. + rewriter.setInsertionPointAfter(newParallelOp); + SmallVector castResults = newParallelOp.getResults(); + for (auto &item : tensorCastProducers) { + Value &oldTypeResult = castResults[item.first]; + oldTypeResult = rewriter.create(loc, item.second.dstType, + oldTypeResult); + } + rewriter.replaceOp(parallelOp, castResults); + + return success(); + } +}; + +} // namespace + +void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add>( + context, + [&](ParallelOp op) { return !op.getDistributionType().has_value(); }); + results.add(context); +} + +//===----------------------------------------------------------------------===// +// ForOp +//===----------------------------------------------------------------------===// + +Region &ForOp::getLoopBody() { return getRegion(); } + +SetYieldOp ForOp::getTerminator() { + return cast(getBody()->getTerminator()); +} + +LogicalResult ForOp::verify() { + if (getNumResults() != getNumOutputs()) { + return emitOpError() << "expected the number of output arguments to match " + "the number of results"; + } + + // Check if types of output arguments match region args types. + for (auto &item : llvm::enumerate( + llvm::zip(getOutputs(), getRegionOutputArgs(), getResultTypes()))) { + Value output, outputRegionArg; + Type resultType; + unsigned index = item.index(); + std::tie(output, outputRegionArg, resultType) = item.value(); + if (output.getType() != outputRegionArg.getType()) { + return emitOpError("expected output arg ") + << index << " with type = " << output.getType() + << " to match region arg " << index + getNumLoops() + << " type = " << outputRegionArg.getType(); + } + if (output.getType() != resultType) { + return emitOpError("expected output arg ") + << index << " with type = " << output.getType() + << " to match resultType " << index << " type = " << resultType; + } + auto terminator = getTerminator(); + auto numDstOperands = terminator.getNumDstOperands(); + if (index >= numDstOperands) { + const auto *s = index ? 
"s" : ""; + return terminator.emitOpError("expected to have at least ") + << index + 1 << " destination operand" << s << " (currently " + << numDstOperands << ")"; + } + + if (terminator.getDstOperand(index)->get() != outputRegionArg) { + return terminator.emitOpError("expected output block argument ") + << index << " to match set_yield destination"; + } + } + return success(); +} + +void ForOp::build( + OpBuilder &builder, OperationState &result, TypeRange resultTypes, + ValueRange lowerBounds, ValueRange upperBounds, ValueRange steps, + ValueRange outputs, + function_ref + bodyBuilderFn) { + buildLoopLikeOp(builder, result, resultTypes, lowerBounds, upperBounds, + steps, outputs, bodyBuilderFn); +} + +void ForOp::print(OpAsmPrinter &p) { + p << " (" << getInductionVars() << ") = (" << getLowerBound() << ") to (" + << getUpperBound() << ") step (" << getStep() << ")"; + + if (!getOutputs().empty()) { + p << " outs ("; + llvm::interleaveComma( + llvm::zip(getRegionOutputArgs(), getOutputs()), p, [&](auto it) { + Value outputRegionArg, output; + std::tie(outputRegionArg, output) = it; + p << outputRegionArg << " = " << output << ": " << output.getType(); + }); + p << ")"; + } + + p << ' '; + p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); + p.printOptionalAttrDict(getOperation()->getAttrs(), + /*elidedAttrs=*/{ForOp::getOperandSegmentSizeAttr()}); + + if (!getResultTypes().empty()) { + p << " : "; + llvm::interleave(getResultTypes(), p, ", "); + } +} + +ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) { + return parseLoopLikeOp(parser, result); +} + +namespace { +/// Folds CastOp of loop outputs into ForOp +struct RefineForOpShape : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ForOp op, + PatternRewriter &rewriter) const override { + if (llvm::all_of(op.getOutputs(), [](auto out) { + return out.template getDefiningOp() == nullptr; + })) + return failure(); + + Location loc = op.getLoc(); + // Scans through output args to find what args are produced by `tensor.cast` + // ops. Also cache the info since we are gonna reuse it a lot. + SmallVector newOutputs{op.getOutputs()}; + SmallVector newTypes{op.getResultTypes()}; + SmallVector castOutputs; + for (auto &&[out, type] : llvm::zip(newOutputs, newTypes)) { + if (auto cast = + castOutputs.emplace_back(out.getDefiningOp())) { + out = cast.getSource(); + type = out.getType(); + } + } + + auto newFor = rewriter.create(loc, newTypes, op.getLowerBound(), + op.getUpperBound(), op.getStep(), + newOutputs, nullptr); + // The new loop needs to keep all attributes from the old one. + newFor->setAttrs(op->getAttrs()); + + // Map outputs, insert `tensor.cast` if necessary. + IRMapping bvm; + bvm.map(op.getInductionVars(), newFor.getInductionVars()); + + auto innerBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, newFor.getBody()); + rewriter.setInsertionPointAfter(newFor); + + for (const auto &[oldArg, newArg, cast] : + llvm::zip(op.getRegionOutputArgs(), newFor.getRegionOutputArgs(), + castOutputs)) { + bvm.map(oldArg, + cast ? innerBuilder.create(cast.getType(), newArg) + : Value(newArg)); + } + // Cast the loop results for downstream uses of the loop if necessary. + SmallVector newResults{newFor.getResults()}; + for (auto &&[res, cast] : llvm::zip(newResults, castOutputs)) { + if (cast) res = rewriter.create(loc, cast.getType(), res); + } + + // Clone loop body. 
+ for (auto &o : *(op.getBody())) innerBuilder.clone(o, bvm); + + // Update set_yield destinations to the new type. + auto term = cast(newFor.getTerminator()); + rewriter.updateRootInPlace(term, [&]() { + term.getDstsMutable().assign(newFor.getRegionOutputArgs()); + }); + + // Update the original loop by the new loop + CastOp. + rewriter.replaceOp(op, newResults); + return success(); + } +}; + +// Fold away ForOp iter arguments when: +// 1) The op yields the iter arguments. +// 2) The iter arguments have no use and the corresponding outer region +// iterators (inputs) are yielded. +// 3) The iter arguments have no use and the corresponding (operation) results +// have no use. +// +// These arguments must be defined outside of the ForOp region and can just be +// forwarded after simplifying the op inits, yields and returns. +// +// The implementation uses `mergeBlockBefore` to steal the content of the +// original ForOp and avoid cloning. +struct ForOpIterArgsFolder : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ForOp forOp, + PatternRewriter &rewriter) const final { + bool canonicalize = false; + auto yieldOp = forOp.getTerminator(); + + // An internal flat vector of block transfer + // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to + // transformed block argument mappings. This plays the role of a + // IRMapping for the particular use case of calling into + // `mergeBlockBefore`. + SmallVector keepMask; + keepMask.reserve(yieldOp.getNumUpdates()); + SmallVector newBlockTransferArgs, newOutputArgs, newResultValues; + newBlockTransferArgs.reserve(1 + forOp.getNumOutputs()); + newBlockTransferArgs.push_back(Value()); // iv placeholder with null value + newOutputArgs.reserve(forOp.getNumOutputs()); + newResultValues.reserve(forOp.getNumResults()); + // [iter from outside, iter inside region, op results, yield sources] + for (auto [out, arg, res, yieldSrc] : + llvm::zip(forOp.getOutputs(), forOp.getRegionOutputArgs(), + forOp.getResults(), yieldOp.getSrcs())) { + // Forwarded is `true` when: + // 1) The region `iter` argument is yielded. + // 2) The region `iter` argument has no use, and the corresponding iter + // operand (input) is yielded. + // 3) The region `iter` argument has no use, and the corresponding op + // result has no use. + bool forwarded = + ((arg == yieldSrc) || + (arg.use_empty() && (out == yieldSrc || res.use_empty()))); + keepMask.push_back(!forwarded); + canonicalize |= forwarded; + if (forwarded) { + newBlockTransferArgs.push_back(out); + newResultValues.push_back(out); + continue; + } + newOutputArgs.push_back(out); + newBlockTransferArgs.push_back(Value()); // placeholder with null value + newResultValues.push_back(Value()); // placeholder with null value + } + + if (!canonicalize) return failure(); + + auto newForOp = rewriter.create( + forOp.getLoc(), + llvm::to_vector<1>(llvm::map_range( + newOutputArgs, [&](Value v) -> Type { return v.getType(); })), + forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(), + newOutputArgs, nullptr); + // The new loop needs to keep all attributes from the old one, except for + // "operand_segment_sizes" which captures the outdated information of the + // old iteration domain. 
+ for (const auto &namedAttr : forOp->getAttrs()) { + if (namedAttr.getName() == ForOp::getOperandSegmentSizeAttr()) continue; + newForOp->setAttr(namedAttr.getName(), namedAttr.getValue()); + } + Block &newBlock = newForOp.getRegion().front(); + + // Replace the null placeholders with newly constructed values. + newBlockTransferArgs[0] = newBlock.getArgument(0); // iv + for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size(); + idx != e; ++idx) { + Value &blockTransferArg = newBlockTransferArgs[1 + idx]; + Value &newResultVal = newResultValues[idx]; + assert((blockTransferArg && newResultVal) || + (!blockTransferArg && !newResultVal)); + if (!blockTransferArg) { + blockTransferArg = newForOp.getRegionOutputArgs()[collapsedIdx]; + newResultVal = newForOp.getResult(collapsedIdx++); + } + } + + Block &oldBlock = forOp.getRegion().front(); + assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() && + "unexpected argument size mismatch"); + + // No terminator case: merge and rewrite the merged terminator. + auto cloneFilteredTerminator = [&](SetYieldOp mergedTerminator) { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(mergedTerminator); + SmallVector filteredSrcs, filteredDsts, filteredSets; + filteredSrcs.reserve(newResultValues.size()); + filteredDsts.reserve(newResultValues.size()); + filteredSets.reserve(newResultValues.size()); + for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx) { + if (keepMask[idx]) { + filteredSrcs.push_back(mergedTerminator.getSrcs()[idx]); + filteredDsts.push_back(mergedTerminator.getDsts()[idx]); + filteredSets.push_back(mergedTerminator.getSets()[idx]); + } + } + rewriter.create(mergedTerminator.getLoc(), filteredSrcs, + filteredDsts, filteredSets); + }; + + rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs); + auto mergedYieldOp = newForOp.getTerminator(); + cloneFilteredTerminator(mergedYieldOp); + rewriter.eraseOp(mergedYieldOp); + rewriter.replaceOp(forOp, newResultValues); + return success(); + } +}; +} // namespace + +void ForOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add, RefineForOpShape, + ForOpIterArgsFolder>(context); +} + +//===----------------------------------------------------------------------===// +// TileOp +//===----------------------------------------------------------------------===// + +namespace { +/// Fold gml_st.tile [%c0] ... into gml_st.tile [0] ... +/// Adapted from OpWithOffsetSizesAndStridesConstantArgumentFolder, which makes +/// slightly incompatible assumptions about the op. +struct FoldConstantsIntoTileType : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TileOp op, + PatternRewriter &rewriter) const override { + // No constant operand, just return; + if (llvm::none_of(op.getOperands(), [](Value operand) { + return matchPattern(operand, matchConstantIndex()); + })) + return failure(); + + // At least one of offsets/sizes/strides is a new constant. + // Form the new list of operands and constant attributes from the existing. + SmallVector mixedOffsets(op.getMixedOffsets()); + SmallVector mixedSizes(op.getMixedSizes()); + SmallVector mixedStrides(op.getMixedStrides()); + canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamic); + canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic); + canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamic); + + // Create the new tile in canonical form. 
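+    // Illustrative fold (operand values assumed):
+    //   %c2 = arith.constant 2 : index
+    //   %t = gml_st.tile [%c2] [4] [1] : !gml_st.tile<4>
+    // becomes
+    //   %t = gml_st.tile [2] [4] [1] : !gml_st.tile<4>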
+ TileOp newOp = rewriter.create(op.getLoc(), mixedOffsets, + mixedSizes, mixedStrides); + // Cast the result back to the original type. This will be folded further + // materialize ops. + rewriter.replaceOpWithNewOp( + op, TypeRange{op.getType()}, ValueRange{newOp}); + + return success(); + } +}; +} // namespace + +void TileOp::build(OpBuilder &b, OperationState &result, + ArrayRef offsets, ArrayRef sizes, + ArrayRef strides, + ArrayRef attrs) { + SmallVector staticOffsets, staticSizes, staticStrides; + SmallVector dynamicOffsets, dynamicSizes, dynamicStrides; + dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); + dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes); + dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); + auto tileType = TileType::get(b.getContext(), staticSizes); + build(b, result, tileType, dynamicOffsets, dynamicSizes, dynamicStrides, + b.getDenseI64ArrayAttr(staticOffsets), + b.getDenseI64ArrayAttr(staticSizes), + b.getDenseI64ArrayAttr(staticStrides)); + result.addAttributes(attrs); +} + +void TileOp::build(OpBuilder &b, OperationState &result, + ArrayRef offsets, + ArrayRef attrs) { + SmallVector unitSizesAndStrides(offsets.size(), + b.getIndexAttr(1)); + return build(b, result, offsets, unitSizesAndStrides, unitSizesAndStrides, + attrs); +} + +LogicalResult TileOp::inferReturnTypes( + MLIRContext *ctx, std::optional /*loc*/, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // Derive result shape. + TileOp::Adaptor adaptor(operands, attributes, regions); + SmallVector shape = llvm::to_vector(adaptor.getStaticSizes()); + + auto resultTy = TileType::get(ctx, shape); + inferredReturnTypes.push_back(resultTy); + return success(); +} + +LogicalResult TileOp::verify() { + auto resultType = getType(); + auto rank = resultType.getRank(); + if (failed(mlir::verifyListOfOperandsOrIntegers( + getOperation(), "size", rank, getStaticSizes(), getSizes()))) { + return failure(); + } + if (failed(mlir::verifyListOfOperandsOrIntegers( + getOperation(), "offset", rank, getStaticOffsets(), getOffsets()))) { + return failure(); + } + if (failed(mlir::verifyListOfOperandsOrIntegers( + getOperation(), "stride", rank, getStaticStrides(), getStrides()))) { + return failure(); + } + for (auto [tileSize, offset, size, stride] : + llvm::zip(resultType.getShape(), getStaticOffsets(), getStaticSizes(), + getStaticStrides())) { + if (offset < 0 && offset != ShapedType::kDynamic) { + return emitOpError("expected offset = ") + << offset << " to be non-negative"; + } + if (size < 0 && size != ShapedType::kDynamic) { + return emitOpError("expected size = ") << size << " to be non-negative"; + } + if (stride < 0 && stride != ShapedType::kDynamic) { + return emitOpError("expected stride = ") + << stride << " to be non-negative"; + } + if (tileSize != size) { + return emitOpError("size arg = ") + << size << " does not match tile size = " << tileSize; + } + } + return success(); +} + +void TileOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +//===----------------------------------------------------------------------===// +// SetYieldOp +//===----------------------------------------------------------------------===// + +using AccumulatorRegionBuilderFn = + function_ref; + +void SetYieldOp::build(OpBuilder &builder, OperationState &result) { + build(builder, result, std::nullopt, std::nullopt, std::nullopt); +} + +void 
SetYieldOp::build(OpBuilder &builder, OperationState &result, + ValueRange srcs, ValueRange dsts, ValueRange sets) { + SmallVector accumulatorFlags(srcs.size(), false); + build(builder, result, srcs, dsts, sets, + builder.getBoolArrayAttr(accumulatorFlags), std::nullopt); +} + +void SetYieldOp::build( + OpBuilder &builder, OperationState &result, ValueRange srcs, + ValueRange dsts, ValueRange sets, ArrayAttr accumulatorFlags, + ArrayRef accumulatorBuilderFns) { + assert(dsts.size() == srcs.size() && + "`dsts` and `srcs` should have the same size"); + assert(sets.size() == srcs.size() && + "`sets` and `srcs` should have the same size"); + assert(accumulatorFlags.size() == srcs.size() && + "`accumulatorFlags` and `srcs` should have the same size"); + + auto accumulatorCount = llvm::count_if(accumulatorFlags, [](Attribute attr) { + return attr.cast().getValue(); + }); + (void)accumulatorCount; + assert(accumulatorCount == + static_cast(accumulatorBuilderFns.size()) && + "the number of flags set in `accumulatorFlags` attribute should be " + "equal to the number of `accumulatorBuilderFns`"); + + result.addOperands(srcs); + result.addOperands(dsts); + result.addOperands(sets); + result.addAttribute(SetYieldOp::getAccumulatorFlagsAttrName(result.name), + accumulatorFlags); + + const auto *builderFnIt = accumulatorBuilderFns.begin(); + for (auto item : llvm::zip(srcs, accumulatorFlags)) { + Value src = std::get<0>(item); + auto accumulatorFlag = std::get<1>(item).cast(); + + if (!accumulatorFlag.getValue()) continue; + Region *region = result.addRegion(); + OpBuilder::InsertionGuard g(builder); + SmallVector argTypes(2, src.getType()); + builder.createBlock(region); + Block &bodyBlock = region->front(); + bodyBlock.addArguments(argTypes, {result.location, result.location}); + + builder.setInsertionPointToStart(&bodyBlock); + (*builderFnIt)(builder, result.location, bodyBlock.getArgument(0), + bodyBlock.getArgument(1)); + ++builderFnIt; + } +} + +LogicalResult SetYieldOp::verify() { + for (const auto [dst, src, set] : + llvm::zip(getDsts(), getSrcs(), getSets())) { + if (failed(verifyCompatibleExtractedSubset( + getOperation(), dst.getType().cast(), src.getType(), + set.getType().cast().getShape()))) + return failure(); + } + auto accumulatorCount = llvm::count_if( + getAccumulatorFlags(), + [](Attribute attr) { return attr.cast().getValue(); }); + if (accumulatorCount != static_cast(getAccumulators().size())) + return emitOpError("expected the number of accumulator regions ") + << getAccumulators().size() + << " to match the number of set accumulator flags " + << accumulatorCount; + + auto *regionIt = getAccumulators().begin(); + for (auto item : llvm::zip(getSrcs(), getAccumulatorFlags())) { + Type srcType = std::get<0>(item).getType(); + BoolAttr accumulatorFlag = std::get<1>(item).cast(); + if (!accumulatorFlag.getValue()) continue; + + Block &block = regionIt->front(); + if (block.getArgumentTypes() != SmallVector{srcType, srcType}) + return emitOpError() + << "expected accumulator region to have 2 arguments of type " + << srcType; + ++regionIt; + } + return success(); +} + +void SetYieldOp::print(OpAsmPrinter &p) { + p.printOptionalAttrDict(getOperation()->getAttrs(), /*elidedAttrs = */ + {getAccumulatorFlagsAttrName().str()}); + + auto *regionIt = getOperation()->getRegions().begin(); + for (auto &en : llvm::enumerate( + llvm::zip(getSrcs(), getDsts(), getSets(), getAccumulatorFlags()))) { + if (en.index() > 0) { + p << ','; + p.printNewline(); + } + Value src = std::get<0>(en.value()); + 
Value dst = std::get<1>(en.value()); + Value set = std::get<2>(en.value()); + auto accumulatorFlag = std::get<3>(en.value()).cast(); + + p << ' ' << src << " into " << dst << '[' << set << ']'; + + if (accumulatorFlag.getValue()) { + auto &block = regionIt->getBlocks().front(); + Value newValue = block.getArgument(0); + Value oldValue = block.getArgument(1); + p << " acc (" << newValue << ", " << oldValue << ": " + << oldValue.getType() << ") "; + + p.printRegion(*regionIt, false); + ++regionIt; + } + + p << " : " << src.getType() << " into " << dst.getType() << '[' + << set.getType() << ']'; + } +} + +ParseResult SetYieldOp::parse(OpAsmParser &parser, OperationState &result) { + if (parser.parseOptionalAttrDict(result.attributes)) return failure(); + + SmallVector accumulatorFlags; + SmallVector srcs, dsts, sets; + SmallVector srcTypes, dstTypes, setTypes; + + auto parseElt = [&]() -> ParseResult { + OpAsmParser::UnresolvedOperand src; + auto parseResult = parser.parseOptionalOperand(src); + + if (!parseResult.has_value()) return success(); + srcs.push_back(src); + + if (parser.parseKeyword("into") || + parser.parseOperand(dsts.emplace_back()) || parser.parseLSquare() || + parser.parseOperand(sets.emplace_back()) || parser.parseRSquare()) + return failure(); + + OpBuilder b(parser.getBuilder().getContext()); + bool hasAccumulatorRegion = succeeded(parser.parseOptionalKeyword("acc")); + accumulatorFlags.push_back(hasAccumulatorRegion); + if (hasAccumulatorRegion) { + auto region = std::make_unique(); + OpAsmParser::UnresolvedOperand newValue, oldValue; + Type argType; + if (parser.parseLParen() || parser.parseOperand(newValue) || + parser.parseComma() || parser.parseOperand(oldValue) || + parser.parseColonType(argType) || parser.parseRParen()) + return failure(); + + SmallVector regionArgs; + for (auto value : {newValue, oldValue}) { + auto &arg = regionArgs.emplace_back(); + arg.ssaName = value; + arg.type = argType; + } + + if (parser.parseRegion(*region, regionArgs)) return failure(); + result.addRegion(std::move(region)); + } + if (parser.parseColon() || parser.parseType(srcTypes.emplace_back()) || + parser.parseKeyword("into") || + parser.parseType(dstTypes.emplace_back()) || parser.parseLSquare() || + parser.parseType(setTypes.emplace_back()) || parser.parseRSquare()) + return failure(); + + return success(); + }; + if (parser.parseCommaSeparatedList(AsmParser::Delimiter::None, parseElt)) + return failure(); + + if (parser.resolveOperands(srcs, srcTypes, parser.getCurrentLocation(), + result.operands) || + parser.resolveOperands(dsts, dstTypes, parser.getCurrentLocation(), + result.operands) || + parser.resolveOperands(sets, setTypes, parser.getCurrentLocation(), + result.operands)) + return failure(); + + result.addAttribute(SetYieldOp::getAccumulatorFlagsAttrName(result.name), + parser.getBuilder().getBoolArrayAttr(accumulatorFlags)); + return success(); +} + +namespace { +/// Folds UnrealizedConversionCast of TileType into SetYieldOp. 
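+/// The cast is bypassed so that the terminator uses the original tile value
+/// directly; a shaped source is additionally re-cast to the tile's static
+/// shape (see the pattern body below).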
+struct FoldTileCastIntoSetYield : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SetYieldOp op, + PatternRewriter &rewriter) const override { + if (!llvm::any_of(op.getSets(), [](auto set) { + return set.template getDefiningOp() != + nullptr; + })) + return failure(); + SmallVector newSrcs{op.getSrcs()}; + SmallVector newSets{op.getSets()}; + for (auto &&[src, set] : llvm::zip(newSrcs, newSets)) { + auto cast = set.getDefiningOp(); + if (!cast) continue; + set = cast.getOperand(0); + Type castResultType = src.getType(); + if (auto shapedType = dyn_cast(castResultType)) { + castResultType = + shapedType.clone(set.getType().cast().getShape(), + shapedType.getElementType()); + src = rewriter.create(op.getLoc(), castResultType, src); + } + } + rewriter.replaceOpWithNewOp(op, newSrcs, op.getDsts(), newSets); + return success(); + } +}; +} // namespace + +void SetYieldOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +//===----------------------------------------------------------------------===// +// YieldOp +//===----------------------------------------------------------------------===// + +LogicalResult YieldOp::verify() { return success(); } + +} // namespace gml_st +} // namespace mlir + +// Generated op classes. +#define GET_OP_CLASSES +#include "gml_st/IR/gml_st_ops.cc.inc" diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h b/tensorflow/compiler/xla/mlir_hlo/gml_st/IR/gml_st_ops.h similarity index 78% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h rename to tensorflow/compiler/xla/mlir_hlo/gml_st/IR/gml_st_ops.h index ad8d7bede60..eb56b4b437a 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/IR/gml_st_ops.h @@ -15,8 +15,8 @@ limitations under the License. // This file defines the operations used in the GML ST dialect. -#ifndef MLIR_HLO_DIALECT_GML_ST_IR_GML_ST_OPS_H -#define MLIR_HLO_DIALECT_GML_ST_IR_GML_ST_OPS_H +#ifndef MLIR_HLO_GML_ST_IR_GML_ST_OPS_H +#define MLIR_HLO_GML_ST_IR_GML_ST_OPS_H #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/Dialect.h" @@ -29,18 +29,18 @@ limitations under the License. #include "mlir/Interfaces/ViewLikeInterface.h" // Generated dialect declarations. -#include "mlir-hlo/Dialect/gml_st/IR/gml_st_dialect.h.inc" +#include "gml_st/IR/gml_st_dialect.h.inc" // Generated custom type declarations. #define GET_TYPEDEF_CLASSES -#include "mlir-hlo/Dialect/gml_st/IR/gml_st_types.h.inc" +#include "gml_st/IR/gml_st_types.h.inc" // Generated attribute classes. #define GET_ATTRDEF_CLASSES -#include "mlir-hlo/Dialect/gml_st/IR/gml_st_attrs.h.inc" +#include "gml_st/IR/gml_st_attrs.h.inc" // Generated operation classes. 
#define GET_OP_CLASSES -#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h.inc" +#include "gml_st/IR/gml_st_ops.h.inc" -#endif // MLIR_HLO_DIALECT_GML_ST_IR_GML_ST_OPS_H +#endif // MLIR_HLO_GML_ST_IR_GML_ST_OPS_H diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_ops.td b/tensorflow/compiler/xla/mlir_hlo/gml_st/IR/gml_st_ops.td similarity index 51% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_ops.td rename to tensorflow/compiler/xla/mlir_hlo/gml_st/IR/gml_st_ops.td index 43497e6dff6..8da88d0a0e8 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_ops.td +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/IR/gml_st_ops.td @@ -19,11 +19,12 @@ limitations under the License. #define GML_ST_OPS include "mlir/IR/OpBase.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/Interfaces/ViewLikeInterface.td" -include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops_base.td" -include "mlir-hlo/Dialect/gml_st/IR/gml_st_legacy_ops.td" +include "gml_st/IR/gml_st_ops_base.td" /////////////////////////////////////////////////////////////////////////////// // Types @@ -34,9 +35,37 @@ class GMLST_Set : TypeDef { } def GMLST_TileType : GMLST_Set<"Tile"> { let mnemonic = "tile"; - let summary = "Type that represents a tile of an index space."; + + let summary = "Type that represents a tile in an N-d iteration space."; + let description = [{ + Values with the tile type represent an N-dimensional rectangle. Unlike a + TensorType, there is no element type specified for the tile. Each dimension + may be a static non-negative decimal constant or be dynamically determined + (indicated by `?`). + + Examples: + + ```mlir + // Fully dynamic dimensions. + !gml_st.tile<? x ?> + + // Partially dynamic dimensions. + !gml_st.tile<? x 8> + + // Fully static shape. + !gml_st.tile<16 x 8> + + // Tile with rank zero. + !gml_st.tile<> + + // Zero-element dimensions are allowed. 
+ !gml_st.tile<0 x 42> + ``` + }]; + + let parameters = (ins ArrayRefParameter<"int64_t">:$shape); let assemblyFormat = "`<` custom($shape) `>`"; + let extraClassDeclaration = [{ unsigned getRank() const { return getShape().size(); } bool hasStaticShape() const { @@ -50,11 +79,17 @@ def GMLST_TileType : GMLST_Set<"Tile"> { def AnySet : Type, "subset type">; +def Vector : AnyTypeOf<[ + AnyVectorOfAnyRank ], "", "::mlir::ShapedType">; +def VectorOrScalar : AnyTypeOf<[ + AnyVectorOfAnyRank, AnyFloat, AnyInteger, AnyComplex, Index ]>; def RankedTensorOrVector : AnyTypeOf<[ AnyRankedTensor, AnyVectorOfAnyRank ], "", "::mlir::ShapedType">; def RankedTensorOrVectorOrScalar : AnyTypeOf<[ - AnyRankedTensor, AnyVectorOfAnyRank, AnyFloat, AnyInteger, AnyComplex + AnyRankedTensor, AnyVectorOfAnyRank, AnyFloat, AnyInteger, AnyComplex, Index ]>; /////////////////////////////////////////////////////////////////////////////// @@ -62,26 +97,34 @@ def RankedTensorOrVectorOrScalar : AnyTypeOf<[ /////////////////////////////////////////////////////////////////////////////// def GMLST_TileOp : GMLST_Op<"tile", [ - Pure, - AttrSizedOperandSegments, - OffsetSizeAndStrideOpInterface, - DeclareOpInterfaceMethods]> { + AttrSizedOperandSegments, + DeclareOpInterfaceMethods, + OffsetSizeAndStrideOpInterface, + Pure + ]> { + let summary = "Operation to generate a TileType from offsets, sizes, strides"; + let description = [{ + This is a metadata operation that encapsulates mixed dynamic and static + offsets, sizes and strides. + + Examples: + + ```mlir + // Tile with partially dynamic sizes. + %tile_2d = gml_st.tile [0, 0] [%x, 16] [1, 1] : !gml_st.tile<? x 16> + + // Tile with rank zero. + %tile_0d = gml_st.tile [] [] [] : !gml_st.tile<> + ``` + }]; let arguments = (ins Variadic:$offsets, Variadic:$sizes, Variadic:$strides, - I64ArrayAttr:$static_offsets, - I64ArrayAttr:$static_sizes, - I64ArrayAttr:$static_strides); + DenseI64ArrayAttr:$static_offsets, + DenseI64ArrayAttr:$static_sizes, + DenseI64ArrayAttr:$static_strides); let results = (outs GMLST_TileType:$result); - let assemblyFormat = [{ - custom($offsets, $static_offsets, - "ShapedType::kDynamicStrideOrOffset") - custom($sizes, $static_sizes, - "ShapedType::kDynamicSize") - custom($strides, $static_strides, - "ShapedType::kDynamicStrideOrOffset") - attr-dict `:` qualified(type($result)) - }]; + let builders = [ OpBuilder<(ins "ArrayRef":$offsets, "ArrayRef":$sizes, "ArrayRef":$strides, @@ -89,6 +132,14 @@ def GMLST_TileOp : GMLST_Op<"tile", [ OpBuilder<(ins "ArrayRef":$offsets, CArg<"ArrayRef", "{}">:$attrs)>, ]; + + let assemblyFormat = [{ + custom($offsets, $static_offsets) + custom($sizes, $static_sizes) + custom($strides, $static_strides) + attr-dict `:` qualified(type($result)) + }]; + let extraClassDeclaration = [{ /// Return the expected rank of each of the `static_offsets`, `static_sizes` /// and `static_strides` attributes. @@ -104,18 +155,76 @@ def GMLST_TileOp : GMLST_Op<"tile", [ let hasVerifier = 1; } -def GMLST_MaterializeOp : GMLST_Op<"materialize", [Pure]> { - let arguments = (ins RankedTensorOrVector:$source, AnySet:$set); - let results = (outs RankedTensorOrVectorOrScalar:$result); +def GMLST_MaterializeOp : GMLST_Op<"materialize", [ + AttrSizedOperandSegments, + DeclareOpInterfaceMethods, + OffsetSizeAndStrideOpInterface, Pure]> { + let summary = "Operation to extract values from tensor or vector"; + let description = [{ + The "materialize" operation extracts a subvector or a single element from a vector. 
+ MaterializeOp accepts a list of offsets-sizes-strides. + + ``` + %subvector = gml_st.materialize %vector[0, 0][42, 16][1, 1] + : vector[!gml_st.tile<42x16>] to vector<42x16xf32> + ``` + + If `sizes` define a shape with a single element, it is also possible to + extract an element from a vector. + + ``` + %element = gml_st.materialize %vector[0, 0][1, 1][1, 1] + : vector<3x1xindex>[!gml_st.tile<1x1>] to index + + %element = gml_st.materialize %vector[%tile_0d] + : vector[][][] to f32 - let builders = [OpBuilder<(ins "Value":$source, "Value":$set)>]; + %subvector = gml_st.materialize %vector[%tile_with_single_element] + : vector<3x1xindex>[0, 0][1, 1][1, 1] to vector<1x1xindex> + ``` + }]; + let arguments = (ins + Vector:$source, + Variadic:$offsets, + Variadic:$sizes, + Variadic:$strides, + DenseI64ArrayAttr:$static_offsets, + DenseI64ArrayAttr:$static_sizes, + DenseI64ArrayAttr:$static_strides + ); + let results = (outs VectorOrScalar:$result); + + let builders = [ + OpBuilder<(ins "Value":$source, "ArrayRef":$offsets, + "ArrayRef":$sizes, "ArrayRef":$strides)>, + OpBuilder<(ins "Type":$resultType, "Value":$source, + "ArrayRef":$offsets, "ArrayRef":$sizes, + "ArrayRef":$strides)>, + OpBuilder<(ins "Value":$source, "ArrayRef":$offsets)>, + ]; let assemblyFormat = [{ - $source`[` $set `]` attr-dict `:` type($source) `[` type($set) `]` - `to` type($result) + $source + custom($offsets, $static_offsets) + custom($sizes, $static_sizes) + custom($strides, $static_strides) + attr-dict + `:` type($source) `to` type($result) }]; + let extraClassDeclaration = [{ + /// Return the expected rank of each of the`static_offsets`, `static_sizes` + /// and `static_strides` attributes. + std::array getArrayAttrMaxRanks() { + unsigned rank = getSource().getType().cast().getRank(); + return {rank, rank, rank}; + } + /// Return the number of leading operands before the `offsets`, `sizes` and + /// and `strides` operands. + static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; } + }]; let hasCanonicalizer = 1; + let hasVerifier = 1; } class GMLST_LoopLikeOp traits = []> @@ -146,6 +255,58 @@ class GMLST_LoopLikeOp traits = []> /// Return terminator of the loop body. SetYieldOp getTerminator(); + + /// Number of output operands + unsigned getNumOutputs() { return getOutputs().size(); } + + /// Get the region output args. + Block::BlockArgListType getRegionOutputArgs() { + return getBody()->getArguments().take_back(getNumOutputs()); + } + + /// Get the region output arg that corresponds to an OpOperand. + BlockArgument getRegionOutputArgForOpOperand(OpOperand &opOperand) { + assert(opOperand.getOperandNumber() >= getNumControlOperands() && + "expected an output args operand"); + assert(opOperand.getOwner() == getOperation() && + "opOperand does not belong to this gml_st::ForOp operation"); + return getBody()->getArgument(opOperand.getOperandNumber() - + getNumControlOperands() + getNumLoops()); + } + + /// Get the OpOperand& that corresponds to a region output arg. + OpOperand &getOpOperandForRegionOutputArg(BlockArgument bbArg) { + assert(bbArg.getArgNumber() >= getNumLoops() && + "expected a bbArg that is not an induction variable"); + assert(bbArg.getOwner()->getParentOp() == getOperation() && + "bbArg does not belong to the gml_st::ForOp body"); + return getOperation()->getOpOperand( + getNumControlOperands() + bbArg.getArgNumber() - getNumLoops()); + } + + /// Get the OpResult that corresponds to an OpOperand. 
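+    /// The output operands come right after the loop control operands (lower
+    /// bounds, upper bounds and steps), hence the index arithmetic below.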
+ OpResult getResultForOpOperand(OpOperand &opOperand) { + assert(opOperand.getOperandNumber() >= getNumControlOperands() && + "expected an output args operand"); + assert(opOperand.getOwner() == getOperation() && + "opOperand does not belong to this gml_st::ForOp operation"); + return getOperation()->getResult( + opOperand.getOperandNumber() - getNumControlOperands()); + } + + /// Get the OpOperand& that corresponds to an OpResultOpOperand. + OpOperand &getOpOperandForResult(OpResult opResult) { + assert(opResult.getDefiningOp() == getOperation() && + "opResult does not belong to the gml_st::ForOp operation"); + return getOperation()->getOpOperand( + getNumControlOperands() + opResult.getResultNumber()); + } + + /// Return the destinations for a gml_st.for op. + ValueRange getLoopLikeOpInits() { + return getOutputs(); + } + }]; let hasCustomAssemblyFormat = 1; @@ -154,99 +315,171 @@ class GMLST_LoopLikeOp traits = []> def GMLST_ParallelOp : GMLST_LoopLikeOp<"parallel", []> { let summary = "Loop-like operation for parallel loops"; let description = [{ - This is a loop-like operation with additional properties. The arguments - also include the output tensors or memrefs. - - Tensor-based version: + This is a multi-dimensional loop-like operation to support distribution on + tensors. The loop can have variadic number of results. - The body region of the loop contains set operations applied to - every output tensor argument of LoopOp. - - The body region must contain exactly one block that terminates with - `gml_st.set_yield` which yields a tensor into a subset of outs. + The body region contains exactly one block that terminates with + `gml_st.set_yield` which specifies how to combine partial results computed + in every iteration of the loop. Example: ```mlir - %space = gml_st.space [8, 16] : !gml_st.tile<8x16> + %add = gml_st.parallel (%i, %j) + = (%c0, %c0) to (%c8, %c16) step (%c4, %c4) { - %result = gml_st.parallel (%i) = (%c0, %c0) to (%c8, %c16) step (%c4, %c4) { - %tile = gml_st.tile %space [%i, %j] [4, 4] [1, 1] - : ! gml_st.tile<8x16> to !gml_st.tile<4x4> + %tile = gml_st.tile [%i, %j] [4, 4] [1, 1] + : !gml_st.tile<8x16> to !gml_st.tile<4x4> %lhs_sub = gml_st.materialize %lhs_[%tile] - : tensor<8x16xf32>[!gml_st.tile<4x4>] + : tensor<8x16xf32>[!gml_st.tile<4x4>] to tensor<4x4xf32> %rhs_sub = gml_st.materialize %rhs_[%tile] - : tensor<8x16xf32>[!gml_st.tile<4x4>] + : tensor<8x16xf32>[!gml_st.tile<4x4>] to tensor<4x4xf32> %out_sub = gml_st.materialize %out_[%tile] - : tensor<8x16xf32>[!gml_st.tile<4x4>] + : tensor<8x16xf32>[!gml_st.tile<4x4>] to tensor<4x4xf32> - %result_sub = linalg.generic (%lhs_sub, %rhs_sub, %out_sub) ... + %add_sub = linalg.map + ins(%lhs_sub: tensor<4x4xf32>, %rhs_sub: tensor<4x4xf32>) + outs(%out_sub: tensor<4x4xf32>) + (%lhs_elem: f32, %rhs_elem: f32) { + %0 = arith.addf %lhs_elem, %rhs_elem : f32 + linalg.yield %0 : f32 + } - gml_st.set_yield %result_sub into %out[%tile] + gml_st.set_yield %add_sub into %out[%tile] : tensor<4x4xf32> into tensor<16x64xf32>[!gml_st.tile<4x4>] + } : tensor<16x64xf32> + ``` + + The terminator specifies, that the partial result `%add_sub` is to become a + part of the final result initialized by `%out` at the position defined by + `%tile`. 
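+
+    A `gml_st.parallel` op can also carry an optional distribution type, which
+    is printed after the `outs` clause. A minimal sketch (the "thread" string
+    is only an illustrative value, not one prescribed by the dialect):
+
+    ```mlir
+    %res = gml_st.parallel (%i, %j) = (%c0, %c0) to (%c8, %c16) step (%c4, %c4)
+        outs (%out_ = %out: tensor<8x16xf32>) distribution ("thread") {
+      ...
+      gml_st.set_yield %result_sub into %out_[%tile]
+        : tensor<4x4xf32> into tensor<8x16xf32>[!gml_st.tile<4x4>]
+    } : tensor<8x16xf32>
+    ```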
+ + Example with concurrent updates: + + %sum = gml_st.parallel (%i) = (%c0) to (%c8) step (%c1) { + %tile = gml_st.tile [%i] [1] [1] : !gml_st.tile<1> + %in_sub = gml_st.materialize %in[%tile] + : tensor<8xf32>[!gml_st.tile<1>] to tensor<1xf32> + + gml_st.set_yield %in_sub into %out[%tile_0d] + acc (%new, %old: tensor<1xf32>) { + %combined = mhlo.add %new, %old : tensor<1xf32> + gml_st.yield %combined : tensor<1xf32> + } : tensor<1xf32> into tensor<8xf32>[!gml_st.tile<1>] + } : tensor<8xf32> + + Every iteration of this loop extracts an element of the input and combines + it with the overlapping subset of the output. In that case, + `gml_st.set_yield` has an optional 'accumulator' region, that models the + concurrent update. The code in the accumulator can be lowered to atomic RMW + or to some other synchronization primitive. + + After bufferization the loop does not produce any results. + + ```mlir + gml_st.parallel (%i, %j) = (%c0, %c0) to (%c8, %c16) step (%c4, %c4) { + %lhs_sub = memref.subview %lhs[%tile] + : memref<8x16xf32> to memref<4x4xf32, #map> + %rhs_sub = memref.subview %rhs[%tile] + : memref<8x16xf32> to memref<4x4xf32, #map> + %out_sub = memref.subview %out[%tile] + : memref<8x16xf32> to memref<4x4xf32, #map> + + linalg.map + ins(%lhs_sub: memref<4x4xf32>, %rhs_sub: memref<4x4xf32>) + outs(%out_sub: memref<4x4xf32>) + (%lhs_elem: f32, %rhs_elem: f32) { + %0 = arith.addf %lhs_elem, %rhs_elem : f32 + linalg.yield %0 : f32 + } + + gml_st.set_yield } ``` }]; + let arguments = (ins Variadic:$lowerBound, Variadic:$upperBound, Variadic:$step, + Variadic:$outputs, OptionalAttr:$distributionType); - // The default builder does not generate the block with induction variables - // as arguments, and conflicts with the custom one. Prevent tablegen from - // generating it. - let skipDefaultBuilders = 1; + let builders = [ OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$lowerBounds, - "ValueRange":$upperBounds, "ValueRange":$steps, - CArg<"Optional", "llvm::None">:$distributionType, - CArg<"function_ref", + "ValueRange":$upperBounds, "ValueRange":$steps, "ValueRange":$outputs, + CArg<"std::optional", "std::nullopt">:$distributionType, + CArg<"function_ref", "nullptr">:$bodyBuilderFn)>, ]; + let skipDefaultBuilders = 1; - let extraClassDeclaration = extraBaseClassDeclaration # [{ - /// Return the destinations for a gml_st.parallel op. - ValueRange getLoopLikeOpInits(); - }]; + let extraClassDeclaration = extraBaseClassDeclaration; + let hasCanonicalizer = 1; } def GMLST_ForOp : GMLST_LoopLikeOp<"for", []> { let summary = "Loop-like operation for sequential loops"; let description = [{ - This is a loop-like operation with additional properties. The arguments - also include the output tensors or memrefs. + This is a multi-dimensional loop-like operation with sequential semantics. + The arguments also include the loop-carried tensor or vector variables. + The loop can have variadic number of results. - Tensor-based version: - - The body region of the loop contains set operations applied to - every output tensor argument of LoopOp. - - The body region must contain exactly one block that terminates with - `gml_st.set_yield` which yields a tensor into a subset of outs. + The body region contains exactly one block that terminates with + `gml_st.set_yield` which specifies how to update a subset of the + loop-carried variable on every iteration. 
Example: ```mlir - %space = gml_st.space [8, 16] : !gml_st.tile<8x16> - - %result = gml_st.for (%i) = (%c0, %c0) to (%c8, %c16) step (%c4, %c4) - outs(%out_ = %output: tensor<8x16xf32>) { - %tile = gml_st.tile %in_space [%i, %j] [4, 4] [1, 1] - : ! gml_st.tile<8x16> to !gml_st.tile<4x4> + %add = gml_st.for (%i) = (%c0, %c0) to (%c8, %c16) step (%c4, %c4) + outs(%out_ = %out: tensor<8x16xf32>) { + %tile = gml_st.tile [%i, %j] [4, 4] [1, 1] + : !gml_st.tile<8x16> to !gml_st.tile<4x4> - %lhs_sub = gml_st.materialize %lhs_[%tile] + %lhs_sub = gml_st.materialize %lhs[%tile] : tensor<8x16xf32>[!gml_st.tile<4x4>] - %rhs_sub = gml_st.materialize %rhs_[%tile] + %rhs_sub = gml_st.materialize %rhs[%tile] : tensor<8x16xf32>[!gml_st.tile<4x4>] %out_sub = gml_st.materialize %out_[%tile] : tensor<8x16xf32>[!gml_st.tile<4x4>] - %result_sub = linalg.generic (%lhs_sub, %rhs_sub, %out_sub) ... + %add_sub = linalg.map + ins(%lhs_sub: tensor<4x4xf32>, %rhs_sub: tensor<4x4xf32>) + outs(%out_sub: tensor<4x4xf32>) + (%lhs_elem: f32, %rhs_elem: f32) { + %0 = arith.addf %lhs_elem, %rhs_elem : f32 + linalg.yield %0 : f32 + } - gml_st.set_yield %result_sub into %out_[%tile] + gml_st.set_yield %add_sub into %out_[%tile] : tensor<4x4xf32> into tensor<16x64xf32>[!gml_st.tile<4x4>] } ``` + + After bufferization the `outs` argument list becomes empty and the loop + does not produce any results. + + ```mlir + gml_st.for (%i, %j) = (%c0, %c0) to (%c8, %c16) step (%c4, %c4) { + %lhs_sub = memref.subview %lhs[%tile] + : memref<8x16xf32> to memref<4x4xf32, #map> + %rhs_sub = memref.subview %rhs[%tile] + : memref<8x16xf32> to memref<4x4xf32, #map> + %out_sub = memref.subview %out[%tile] + : memref<8x16xf32> to memref<4x4xf32, #map> + + linalg.map + ins(%lhs_sub: memref<4x4xf32>, %rhs_sub: memref<4x4xf32>) + outs(%out_sub: memref<4x4xf32>) + (%lhs_elem: f32, %rhs_elem: f32) { + %0 = arith.addf %lhs_elem, %rhs_elem : f32 + linalg.yield %0 : f32 + } + + gml_st.set_yield + } + ``` }]; let arguments = (ins Variadic:$lowerBound, @@ -261,86 +494,50 @@ def GMLST_ForOp : GMLST_LoopLikeOp<"for", []> { "/*outputs=*/ValueRange)>", "nullptr">:$bodyBuilderFn)>, ]; - let extraClassDeclaration = extraBaseClassDeclaration # [{ - /// Number of output operands - unsigned getNumOutputs() { return getOutputs().size(); } - - /// Get the region output args. - Block::BlockArgListType getRegionOutputArgs() { - return getBody()->getArguments().take_back(getNumOutputs()); - } - - /// Get the region output arg that corresponds to an OpOperand. - BlockArgument getRegionOutputArgForOpOperand(OpOperand &opOperand) { - assert(opOperand.getOperandNumber() >= getNumControlOperands() && - "expected an output args operand"); - assert(opOperand.getOwner() == getOperation() && - "opOperand does not belong to this gml_st::ForOp operation"); - return getBody()->getArgument(opOperand.getOperandNumber() - - getNumControlOperands() + getNumLoops()); - } - - /// Get the OpOperand& that corresponds to a region output arg. - OpOperand &getOpOperandForRegionOutputArg(BlockArgument bbArg) { - assert(bbArg.getArgNumber() >= getNumLoops() && - "expected a bbArg that is not an induction variable"); - assert(bbArg.getOwner()->getParentOp() == getOperation() && - "bbArg does not belong to the gml_st::ForOp body"); - return getOperation()->getOpOperand( - getNumControlOperands() + bbArg.getArgNumber() - getNumLoops()); - } - - /// Get the OpResult that corresponds to an OpOperand. 
- OpResult getResultForOpOperand(OpOperand &opOperand) { - assert(opOperand.getOperandNumber() >= getNumControlOperands() && - "expected an output args operand"); - assert(opOperand.getOwner() == getOperation() && - "opOperand does not belong to this gml_st::ForOp operation"); - return getOperation()->getResult( - opOperand.getOperandNumber() - getNumControlOperands()); - } - - /// Get the OpOperand& that corresponds to an OpResultOpOperand. - OpOperand &getOpOperandForResult(OpResult opResult) { - assert(opResult.getDefiningOp() == getOperation() && - "opResult does not belong to the gml_st::ForOp operation"); - return getOperation()->getOpOperand( - getNumControlOperands() + opResult.getResultNumber()); - } - - /// Return the destinations for a gml_st.for op. - ValueRange getLoopLikeOpInits() { - return getOutputs(); - } - }]; - + let extraClassDeclaration = extraBaseClassDeclaration; let hasCanonicalizer = 1; } -def GMLST_SetYieldOp : GMLST_Op<"set_yield", [Pure, ReturnLike, - Terminator, SameVariadicOperandSize, - SingleBlockImplicitTerminator<"YieldOp"> - ]> { +def GMLST_SetYieldOp : GMLST_Op<"set_yield", [ + Pure, + ReturnLike, + SameVariadicOperandSize, + SingleBlockImplicitTerminator<"YieldOp">, + Terminator + ]> { let summary = "Set yield operation"; let description = [{ - `gml_st.set_yield` is a special terminator operation for - `gml_st.parallel` or `gml_st.for` body. + `gml_st.set_yield` is a special terminator operation for `gml_st.parallel` + or `gml_st.for` body. It specifies how to combine a source tensor, vector + or scalar with the destination tensor or vector. Example: ```mlir - gml_st.set_yield %result_sub at %tile into %dst + // `src` is a tensor. + gml_st.set_yield %src at %tile into %dst : tensor<4x4xf32> into tensor<16x64xf32>[!gml_st.tile<4x4>] + + // `src` is a scalar. + gml_st.set_yield %src at %tile into %dst + : f32 into tensor<16x64xf32>[!gml_st.tile<1x1>] + + // `src` and `dst` are vectors. + gml_st.set_yield %src at %tile into %dst + : vector<4x4xf32> into vector<16x64xf32>[!gml_st.tile<4x4>] + ``` + + The operation is designed to be polymorphic to support non-rectangular + subsets. It will accept `set` arguments of types other than `!gml_st.tile`. }]; + let arguments = (ins Variadic:$srcs, Variadic:$dsts, Variadic:$sets, BoolArrayAttr:$accumulatorFlags); let regions = (region VariadicRegion>:$accumulators); - let hasCustomAssemblyFormat = 1; - let skipDefaultBuilders = 1; let builders = [ OpBuilder<(ins)>, @@ -353,9 +550,10 @@ def GMLST_SetYieldOp : GMLST_Op<"set_yield", [Pure, ReturnLike, "ArrayRef>" :$combiners)> ]; + let skipDefaultBuilders = 1; + let hasCustomAssemblyFormat = 1; let extraClassDeclaration = [{ - unsigned getNumUpdates() { return getSrcs().size(); } // Methods for `dst` arguments. @@ -385,10 +583,29 @@ def GMLST_SetYieldOp : GMLST_Op<"set_yield", [Pure, ReturnLike, let hasCanonicalizer = 1; } +def GMLST_YieldOp : GMLST_Op<"yield", [Pure, ReturnLike, Terminator, + HasParent<"::mlir::gml_st::SetYieldOp">]>, + Arguments<(ins AnyType:$value)> { + let summary = "Yield operation"; + let description = [{ + `gml_st.yield` is a special terminator operation for accumulator regions of + `gml_st.set_yield`. + + Example: + + ```mlir + gml_st.yield %f0: tensor + ``` + }]; + let assemblyFormat = "attr-dict $value `:` type($value)"; +} + // TODO(b/253560795): Figure out where this operation shoud live, and how to // model it properly. 
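One detail worth illustrating for the accumulator regions terminated by
`gml_st.yield`: after bufferization, a simple additive combiner could plausibly
be lowered to an atomic read-modify-write on the output buffer. A sketch using
the upstream `memref` op, assuming f32 elements and a 1-d output (not the
actual lowering emitted by these passes):

```mlir
// Each concurrent iteration folds its value into the shared output location.
%res = memref.atomic_rmw addf %new_elem, %out[%c0] : (f32, memref<8xf32>) -> f32
```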
-def GMLST_DistributeOp : GMLST_Op<"distribute", [Pure, - AllElementTypesMatch<["source", "result"]>]> { +def GMLST_DistributeOp : GMLST_Op<"distribute", [ + Pure, + AllElementTypesMatch<["source", "result"]> + ]> { let summary = "Tile combining operation"; let description = [{ `gml_st.distribute` is in a sense an inverse operation to @@ -403,6 +620,7 @@ def GMLST_DistributeOp : GMLST_Op<"distribute", [Pure, pipeline. This is also the reason why it does not need a destination argument. }]; + let arguments = (ins AnyVectorOfAnyRank:$source, AnySet:$set); let results = (outs AnyVectorOfAnyRank:$result); diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_ops_base.td b/tensorflow/compiler/xla/mlir_hlo/gml_st/IR/gml_st_ops_base.td similarity index 87% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_ops_base.td rename to tensorflow/compiler/xla/mlir_hlo/gml_st/IR/gml_st_ops_base.td index ebfe4f5854d..2560de459ca 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_ops_base.td +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/IR/gml_st_ops_base.td @@ -23,12 +23,16 @@ include "mlir/IR/OpBase.td" def GmlSt_Dialect : Dialect { let name = "gml_st"; let cppNamespace = "::mlir::gml_st"; + let description = [{ + The GmlSt (Google ML Structured) dialect is intended to hold operations, + types and transformations to assist structured code generation. + }]; let dependentDialects = ["tensor::TensorDialect"]; - let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; let useDefaultTypePrinterParser = 1; let useDefaultAttributePrinterParser = 1; let hasConstantMaterializer = 1; + let useFoldAPI = kEmitFoldAdaptorFolder; } class GMLST_Op traits> : diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/README.md b/tensorflow/compiler/xla/mlir_hlo/gml_st/README.md similarity index 100% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/README.md rename to tensorflow/compiler/xla/mlir_hlo/gml_st/README.md diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo_gpu/IR/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/gml_st/interfaces/CMakeLists.txt similarity index 73% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo_gpu/IR/CMakeLists.txt rename to tensorflow/compiler/xla/mlir_hlo/gml_st/interfaces/CMakeLists.txt index c25af334864..652f653b2ee 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo_gpu/IR/CMakeLists.txt +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/interfaces/CMakeLists.txt @@ -12,19 +12,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# + include_directories(BEFORE ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR}) -add_mlir_dialect_library(LmhloGPUDialect - lhlo_gpu_ops.cc +set(LLVM_OPTIONAL_SOURCES + bufferizable_op_interface_impl.cc +) - DEPENDS - MLIRlhlo_gpu_opsIncGen +add_mlir_library(GmlStBufferizableOpInterface + bufferizable_op_interface_impl.cc LINK_LIBS PUBLIC - MhloDialect + GmlStDialect + MLIRBufferizationDialect + MLIRBufferizationTransforms + MLIRDestinationStyleOpInterface MLIRIR - HloOpsCommon + MLIRSupport ) + diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/bufferizable_op_interface_impl.cc b/tensorflow/compiler/xla/mlir_hlo/gml_st/interfaces/bufferizable_op_interface_impl.cc similarity index 57% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/bufferizable_op_interface_impl.cc rename to tensorflow/compiler/xla/mlir_hlo/gml_st/interfaces/bufferizable_op_interface_impl.cc index 12dbcd4621d..fff37a31150 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/bufferizable_op_interface_impl.cc +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/interfaces/bufferizable_op_interface_impl.cc @@ -13,211 +13,49 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mlir-hlo/Dialect/gml_st/transforms/bufferizable_op_interface_impl.h" +#include "gml_st/interfaces/bufferizable_op_interface_impl.h" #include +#include #include -#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h" +#include "gml_st/IR/gml_st_ops.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Interfaces/ViewLikeInterface.h" #include "mlir/Support/LogicalResult.h" +using mlir::bufferization::AliasingOpOperandList; +using mlir::bufferization::AliasingOpResultList; using mlir::bufferization::AnalysisState; using mlir::bufferization::BufferizableOpInterface; using mlir::bufferization::BufferizationOptions; using mlir::bufferization::BufferRelation; -using mlir::bufferization::ToMemrefOp; using mlir::bufferization::ToTensorOp; +using mlir::tensor::ExtractSliceOp; namespace mlir { namespace gml_st { namespace { -/// Bufferization of gml_st.loop. Replace with a new gml_st.loop -/// that operates entirely on memrefs. -struct LoopOpInterface - : public BufferizableOpInterface::ExternalModel { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - auto loopOp = cast(op); - - // gml_st.loop operands alone do not bufferize to a memory read, but - // one of the uses of their matching bbArgs may. - return state.isValueRead(loopOp.getTiedBlockArgument(opOperand)); - } - - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - // Only operands with an aliasing OpResult (i.e., output operands) bufferize - // to a memory write. - auto bufferizableOp = cast(op); - return !bufferizableOp.getAliasingOpResult(opOperand, state).empty(); - } - - SmallVector getAliasingOpResult( - Operation *op, OpOperand &opOperand, - const AnalysisState & /*state*/) const { - auto loopOp = cast(op); - - // Output operands are tied to their corresponding OpResults. 
- OpResult opResult = loopOp.getTiedOpResult(opOperand); - if (!opResult) return {}; - return {opResult}; - } - - BufferRelation bufferRelation(Operation * /*op*/, OpResult /*opResult*/, - const AnalysisState & /*state*/) const { - return BufferRelation::Equivalent; - } - - bool isWritable(Operation * /*op*/, Value /*value*/, - const AnalysisState & /*state*/) const { - // Interestingly, LoopOp's bbArgs can **always** be viewed - // inplace from the perspective of nested ops: - // 1. Either the matching iter operand is not bufferized inplace and an - // alloc + optional copy makes the bbArg itself inplaceable. - // 2. Or the matching iter operand is bufferized inplace and bbArg just - // bufferizes to that too. - return true; - } - - FailureOr getBufferType( - Operation *op, Value value, const BufferizationOptions &options, - const DenseMap &fixedTypes) const { - auto loopOp = cast(op); - if (auto opResult = value.dyn_cast()) { - return bufferization::getBufferType( - loopOp.getOutputs()[opResult.getResultNumber()], options, fixedTypes); - } - BlockArgument bbArg = value.cast(); - return bufferization::getBufferType(loopOp.getTiedOperand(bbArg).get(), - options, fixedTypes); - } - - LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions &options) const { - auto loopOp = cast(op); - - // Compute new inputs, outputs and results. - SmallVector newInputs, newOutputs, newResults; - for (unsigned i = loopOp.getNumControlOperands(); - i < loopOp->getNumOperands(); ++i) { - OpOperand &operand = loopOp->getOpOperand(i); - Value rewrittenValue = operand.get(); - if (rewrittenValue.getType().isa()) { - FailureOr maybeBuffer = - getBuffer(rewriter, operand.get(), options); - if (failed(maybeBuffer)) return failure(); - rewrittenValue = *maybeBuffer; - } - if (i < loopOp.getNumControlOperands() + loopOp.getNumInputs()) { - newInputs.push_back(rewrittenValue); - } else { - newOutputs.push_back(rewrittenValue); - if (operand.get().getType().isa()) - newResults.push_back(rewrittenValue); - } - } - - // Create new TiledLoopOp. - auto newLoopOp = rewriter.create( - loopOp.getLoc(), loopOp.getLowerBound(), loopOp.getUpperBound(), - loopOp.getStep(), newInputs, newOutputs, loopOp.getIteratorTypes(), - loopOp.getDistributionTypes()); - - // Remove terminator. - if (!newLoopOp.getBody()->empty()) - rewriter.eraseOp(loopOp.getBody()->getTerminator()); - - // Compute new loop body arguments. 
- SmallVector newBlockArgs, newRegionInOutArgs, oldRegionInOutArgs; - ValueRange newInductionVars = newLoopOp.getInductionVars(); - newBlockArgs.append(newInductionVars.begin(), newInductionVars.end()); - - ValueRange newRegionInArgs = newLoopOp.getRegionInputArgs(); - ValueRange newRegionOutArgs = newLoopOp.getRegionOutputArgs(); - newRegionInOutArgs.append(newRegionInArgs.begin(), newRegionInArgs.end()); - newRegionInOutArgs.append(newRegionOutArgs.begin(), newRegionOutArgs.end()); - - ValueRange oldRegionInArgs = loopOp.getRegionInputArgs(); - ValueRange oldRegionOutArgs = loopOp.getRegionOutputArgs(); - oldRegionInOutArgs.append(oldRegionInArgs.begin(), oldRegionInArgs.end()); - oldRegionInOutArgs.append(oldRegionOutArgs.begin(), oldRegionOutArgs.end()); - assert(newRegionInArgs.size() == oldRegionInArgs.size() && - "expected same number of input args"); - assert(newRegionOutArgs.size() == oldRegionOutArgs.size() && - "expected same number of output args"); - - for (auto it : llvm::zip(oldRegionInOutArgs, newRegionInOutArgs)) { - Value oldArg = std::get<0>(it); - Value newArg = std::get<1>(it); - rewriter.setInsertionPointToStart(newLoopOp.getBody()); - if (oldArg.getType().isa()) { - newBlockArgs.push_back(rewriter.create( - oldArg.getLoc(), newArg)); - } else { - newBlockArgs.push_back(newArg); - } - } - - // Move old body into new loop. - rewriter.mergeBlocks(loopOp.getBody(), newLoopOp.getBody(), newBlockArgs); - - // Replace previous terminator with a new one that does not yield anything. - auto oldTerminator = - cast(newLoopOp.getBody()->getTerminator()); - rewriter.setInsertionPointToEnd(newLoopOp.getBody()); - auto newTerminator = - rewriter.create(oldTerminator->getLoc()); - - // Copy buffer of yielded tensor to output buffer. If everything bufferized - // inplace, this copy will fold away. - rewriter.setInsertionPoint(newTerminator); - for (auto it : llvm::zip(oldTerminator.getValues(), newOutputs)) { - Value output = std::get<1>(it); - Value toMemrefOp = rewriter.create( - newTerminator.getLoc(), output.getType(), std::get<0>(it)); - if (failed(options.createMemCpy(rewriter, newTerminator.getLoc(), - toMemrefOp, output))) - return failure(); - } - - // Erase old terminator. - rewriter.eraseOp(oldTerminator); - - // Replace results and delete old op. - bufferization::replaceOpWithBufferizedValues(rewriter, op, newResults); - - return success(); - } -}; - // Returns a scalar or a memref type result of `gml_st.materialize` op after // bufferization. 
FailureOr materializeExtraction(OpBuilder &b, Value memref, MaterializeOp materializeOp) { - Value set = materializeOp.getSet(); - - Operation *setDefiningOp = set.getDefiningOp(); - - Location loc = set.getLoc(); - if (auto tile = dyn_cast(setDefiningOp)) { - if (!materializeOp.getType().isa()) { - auto indices = - getValueOrCreateConstantIndexOp(b, loc, tile.getMixedOffsets()); - return b.create(loc, memref, indices).getResult(); - } - Value subview = b.create( - loc, memref, tile.getMixedOffsets(), tile.getMixedSizes(), - tile.getMixedStrides()); - return subview; - } - return failure(); + Location loc = materializeOp.getLoc(); + if (!materializeOp.getType().isa()) { + auto indices = getValueOrCreateConstantIndexOp( + b, loc, materializeOp.getMixedOffsets()); + return b.create(loc, memref, indices).getResult(); + } + Value subview = b.create( + loc, memref, materializeOp.getMixedOffsets(), + materializeOp.getMixedSizes(), materializeOp.getMixedStrides()); + return subview; } LogicalResult materializeInsertion(OpBuilder &b, Value update, Value set, @@ -259,9 +97,9 @@ LogicalResult materializeInsertion(OpBuilder &b, Value update, Value set, struct MaterializeOpInterface : public BufferizableOpInterface::ExternalModel { - bool bufferizesToMemoryRead(Operation * /*op*/, OpOperand & /*opOperand*/, + bool bufferizesToMemoryRead(Operation * /*op*/, OpOperand &opOperand, const AnalysisState & /*state*/) const { - return false; + return opOperand.getOperandNumber() == 0; } bool bufferizesToMemoryWrite(Operation * /*op*/, OpOperand & /*opOperand*/, @@ -269,7 +107,7 @@ struct MaterializeOpInterface return false; } - SmallVector getAliasingOpResult( + AliasingOpResultList getAliasingOpResults( Operation *op, OpOperand &opOperand, const AnalysisState & /*state*/) const { auto result = op->getOpResult(0); @@ -281,7 +119,7 @@ struct MaterializeOpInterface BufferRelation bufferRelation(Operation * /*op*/, OpResult /*opResult*/, const AnalysisState & /*state*/) const { - return BufferRelation::None; + return BufferRelation::Unknown; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, @@ -306,20 +144,28 @@ struct MaterializeOpInterface struct ParallelOpInterface : public BufferizableOpInterface::ExternalModel { - SmallVector getAliasingOpOperand( - Operation *op, OpResult opResult, const AnalysisState & /*state*/) const { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { auto parallelOp = cast(op); - return { - parallelOp.getTerminator().getDstOperand(opResult.getResultNumber())}; + + // gml_st.parallel alone doesn't bufferize to a memory read, one of the uses + // of its matching bbArg may. + return state.isValueRead( + parallelOp.getRegionOutputArgForOpOperand(opOperand)); } - bool isMemoryWrite(Operation *, OpResult, const AnalysisState &) const { - // This op is a memory write. Stop lookup here to avoid finding false - // conflicts involving this op and one of the ops in the region. This is - // similar to how scf.if ops are analyzed. + bool bufferizesToMemoryWrite(Operation * /*op*/, OpOperand & /*opOperand*/, + const AnalysisState & /*state*/) const { + // Outputs of gml_st::ParallelOp are always considered as a write. 
return true; } + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, + const AnalysisState &) const { + auto parallelOp = cast(op); + return {parallelOp.getResultForOpOperand(opOperand)}; + } + BufferRelation bufferRelation(Operation * /*op*/, OpResult /*opResult*/, const AnalysisState & /*state*/) const { return BufferRelation::Equivalent; @@ -331,25 +177,72 @@ struct ParallelOpInterface } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationOptions & /*options*/) const { - auto loopOp = cast(op); + const BufferizationOptions &options) const { + auto parallelOp = cast(op); - // Create new TiledLoopOp. - Optional distTypeAttr; + // Get the bufferized output arguments. + Location loc = op->getLoc(); + SmallVector bufferizedOutputs; + bufferizedOutputs.reserve(parallelOp.getNumOutputs()); + for (Value output : parallelOp.getOutputs()) { + FailureOr maybeBuffer = getBuffer(rewriter, output, options); + if (failed(maybeBuffer)) return failure(); + bufferizedOutputs.push_back(*maybeBuffer); + } + + // Create new ParallelOp. + std::optional distTypeAttr; if (auto distType = cast(op).getDistributionType()) distTypeAttr = rewriter.getStringAttr(*distType); - auto newLoopOp = rewriter.create( - loopOp.getLoc(), TypeRange{llvm::None}, loopOp.getLowerBound(), - loopOp.getUpperBound(), loopOp.getStep(), distTypeAttr); - // Move the old body into the new loop. - rewriter.mergeBlocks(loopOp.getBody(), newLoopOp.getBody(), - newLoopOp.getInductionVars()); + auto newParallelOp = rewriter.create( + loc, TypeRange{}, parallelOp.getLowerBound(), + parallelOp.getUpperBound(), parallelOp.getStep(), ValueRange{}, + distTypeAttr, nullptr); + Block *loopBody = newParallelOp.getBody(); - // Remove the old op. - rewriter.eraseOp(op); + // Add conversions to tensor so that we can reuse the old loop body. + rewriter.setInsertionPointToStart(loopBody); + SmallVector outputsToTensors; + for (auto buf : bufferizedOutputs) { + Value tensor = rewriter.create(loc, buf); + outputsToTensors.push_back(tensor); + } + SmallVector blockArgs = newParallelOp.getInductionVars(); + blockArgs.append(outputsToTensors); + + // Move old body into new for loop. + rewriter.mergeBlocks(parallelOp.getBody(), loopBody, blockArgs); + + // Replace results and delete old op. + bufferization::replaceOpWithBufferizedValues(rewriter, op, + bufferizedOutputs); return success(); } + + FailureOr getBufferType( + Operation *op, Value value, const BufferizationOptions &options, + const DenseMap &fixedTypes) const { + auto parallelOp = cast(op); + + if (auto bbArg = value.dyn_cast()) { + // A tensor block argument has the same bufferized type as the + // corresponding output operand. + return bufferization::getBufferType( + parallelOp.getOpOperandForRegionOutputArg(bbArg).get(), options, + fixedTypes); + } + + // The bufferized result type is the same as the bufferized type of the + // corresponding output operand. 
+ return bufferization::getBufferType( + parallelOp.getOutputs()[value.cast().getResultNumber()], + options, fixedTypes); + } + + bool isRepetitiveRegion(Operation * /*op*/, unsigned /*index*/) const { + return true; + } }; struct ForOpInterface @@ -365,7 +258,7 @@ struct ForOpInterface return true; } - SmallVector getAliasingOpResult( + AliasingOpResultList getAliasingOpResults( Operation *op, OpOperand &opOperand, const AnalysisState & /*state*/) const { auto forOp = cast(op); @@ -451,13 +344,13 @@ struct ForOpInterface struct SetYieldOpInterface : public BufferizableOpInterface::ExternalModel { - SmallVector getAliasingOpResult( + AliasingOpResultList getAliasingOpResults( Operation * /*op*/, OpOperand & /*opOperand*/, const AnalysisState & /*state*/) const { return {}; } - bool bufferizesToMemoryRead(Operation * /*op*/, OpOperand & /*opOperand*/, + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState & /*state*/) const { return true; } @@ -476,8 +369,6 @@ struct SetYieldOpInterface const BufferizationOptions &options) const { auto yieldOp = cast(op); Operation *loop = yieldOp->getParentOp(); - if (!isa(loop)) - return yieldOp->emitError("unsupported gml_st::SetYieldOp parent"); rewriter.setInsertionPoint(op); for (const auto &it : @@ -509,7 +400,8 @@ struct SetYieldOpInterface Value resultToTensor = rewriter.create(loop->getLoc(), dstBuffer); - for (OpOperand &use : loopResult.getUses()) { + for (OpOperand &use : + llvm::make_early_inc_range(loopResult.getUses())) { rewriter.updateRootInPlace(use.getOwner(), [&]() { use.set(resultToTensor); }); } @@ -519,11 +411,135 @@ struct SetYieldOpInterface return success(); } - bool isNotConflicting(Operation * /*op*/, OpOperand * /*uRead*/, - OpOperand * /*uConflictingWrite*/, - const AnalysisState & /*state*/) const { + LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter, + const AnalysisState &state) const { + OpBuilder::InsertionGuard g(rewriter); + SmallVector outOfPlaceOpOperands; + DenseSet copiedOpOperands; + DenseSet escapingOpOperandCopies; + + // Find all out-of-place OpOperands. + for (OpOperand &opOperand : op->getOpOperands()) { + Type operandType = opOperand.get().getType(); + if (!operandType.isa()) continue; + if (state.isInPlace(opOperand)) continue; + if (operandType.isa()) + return op->emitError("copies of unranked tensors are not supported"); + + AliasingOpResultList aliasingOpResults = + state.getAliasingOpResults(opOperand); + // Is the result yielded from a block? Or are deallocations turned off + // entirely? In either case, mark the allocation as "escaping", so that it + // will not be deallocated. + bool escape = !state.getOptions().createDeallocs || + llvm::any_of(aliasingOpResults, [&](Value v) { + return state.isTensorYielded(v); + }); + + // In all other cases, make a copy of the OpOperand. + outOfPlaceOpOperands.push_back(&opOperand); + if (!state.canOmitTensorCopy(opOperand)) + copiedOpOperands.insert(&opOperand); + if (escape) escapingOpOperandCopies.insert(&opOperand); + } + + // Insert copies of OpOperands before the loop. 
+ rewriter.setInsertionPoint(op->getParentOp()); + for (OpOperand *opOperand : outOfPlaceOpOperands) { + FailureOr copy = allocateTensorForShapedValue( + rewriter, op->getLoc(), opOperand->get(), + escapingOpOperandCopies.contains(opOperand), state.getOptions(), + copiedOpOperands.contains(opOperand)); + if (failed(copy)) return failure(); + rewriter.updateRootInPlace(op, [&]() { opOperand->set(*copy); }); + } + + return success(); + } + + bool areEquivalentSlices(const AnalysisState &state, + ExtractSliceOp extractSliceOp, SetYieldOp setYieldOp, + int64_t updateIdx) const { + if (!extractSliceOp || !setYieldOp) return false; + if (extractSliceOp != setYieldOp && + !state.areEquivalentBufferizedValues(extractSliceOp.getSource(), + setYieldOp.getDsts()[updateIdx])) { + return false; + } + if (!sameOffsetsSizesAndStrides( + extractSliceOp, + setYieldOp.getSets()[updateIdx].getDefiningOp(), + isEqualConstantIntOrValue)) + return false; return true; } + + /// Return true if `value` is originating from an ExtractSliceOp that matches + /// the given SetYieldOp. + bool matchesInsertDestination(const AnalysisState &state, Value value, + SetYieldOp setYieldOp, + int64_t updateIdx) const { + // Look for matching slices. + auto matchesSlice = [&](Value val) { + if (auto materializeOp = val.getDefiningOp()) { + if (areEquivalentSlices(state, materializeOp, setYieldOp, updateIdx)) { + return true; + } + } + return false; + }; + return llvm::all_of( + state.findValueInReverseUseDefChain(value, matchesSlice), matchesSlice); + } + + // Copied and modified for gml_st.materialize/gml_st.set_yield pairs from + // mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp + // Takes into account that gml_st.set_yield can have multiple src/dst pairs. + bool isNotConflicting(Operation *op, OpOperand *uRead, + OpOperand *uConflictingWrite, + const AnalysisState &state) const { + if (llvm::isa(op->getParentOp())) { + return true; + } + Operation *readingOp = uRead->getOwner(); + Operation *conflictingWritingOp = uConflictingWrite->getOwner(); + + // Special rules for matching SetYieldOp/ExtractSliceOp pairs. If + // uRead is an SetYieldOp... + if (auto setYieldOp = dyn_cast(readingOp)) { + for (int64_t updateIdx : + llvm::seq(0, setYieldOp.getNumUpdates())) { + OpOperand &srcOpOperand = setYieldOp->getOpOperand(updateIdx); + OpOperand *dstOpOperand = setYieldOp.getDstOperand(updateIdx); + + if (uRead == dstOpOperand /*dest*/ && + matchesInsertDestination(state, uConflictingWrite->get(), + setYieldOp, updateIdx)) + return true; + + if (uRead == &srcOpOperand /*source*/ && + uConflictingWrite == dstOpOperand /*dest*/ && + matchesInsertDestination(state, uRead->get(), setYieldOp, + updateIdx)) + return true; + } + } + + // If uConflictingWrite is an SetYieldOp... 
+ if (auto setYieldOp = dyn_cast(conflictingWritingOp)) { + for (int64_t updateIdx : + llvm::seq(0, setYieldOp.getNumUpdates())) { + if (uConflictingWrite == setYieldOp.getDstOperand(updateIdx) && + state.areEquivalentBufferizedValues( + uRead->get(), setYieldOp.getSrcs()[updateIdx]) && + matchesInsertDestination(state, setYieldOp.getSrcs()[updateIdx], + setYieldOp, updateIdx)) + return true; + } + } + + return false; + } }; } // namespace @@ -535,7 +551,6 @@ void mlir::gml_st::registerBufferizableOpInterfaceExternalModels( registry.addExtension( +[](MLIRContext *ctx, gml_st::GmlStDialect * /*dialect*/) { ForOp::attachInterface(*ctx); - LoopOp::attachInterface(*ctx); MaterializeOp::attachInterface(*ctx); ParallelOp::attachInterface(*ctx); SetYieldOp::attachInterface(*ctx); diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/bufferizable_op_interface_impl.h b/tensorflow/compiler/xla/mlir_hlo/gml_st/interfaces/bufferizable_op_interface_impl.h similarity index 78% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/bufferizable_op_interface_impl.h rename to tensorflow/compiler/xla/mlir_hlo/gml_st/interfaces/bufferizable_op_interface_impl.h index 6a4844771ef..5b3e35fc7b1 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/bufferizable_op_interface_impl.h +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/interfaces/bufferizable_op_interface_impl.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_BUFFERIZABLE_OP_INTERFACE_IMPL_H -#define MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_BUFFERIZABLE_OP_INTERFACE_IMPL_H +#ifndef MLIR_HLO_GML_ST_INTERFACES_BUFFERIZABLE_OP_INTERFACE_IMPL_H +#define MLIR_HLO_GML_ST_INTERFACES_BUFFERIZABLE_OP_INTERFACE_IMPL_H namespace mlir { class DialectRegistry; @@ -26,4 +26,4 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry); } // namespace gml_st } // namespace mlir -#endif // MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_BUFFERIZABLE_OP_INTERFACE_IMPL_H +#endif // MLIR_HLO_GML_ST_INTERFACES_BUFFERIZABLE_OP_INTERFACE_IMPL_H diff --git a/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/CMakeLists.txt new file mode 100644 index 00000000000..ccc7ca94d0d --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/CMakeLists.txt @@ -0,0 +1,114 @@ +# +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +set(LLVM_TARGET_DEFINITIONS passes.td) +mlir_tablegen(passes.h.inc -gen-pass-decls -name GmlSt) +add_public_tablegen_target(MLIRGmlStPassIncGen) + +set(LLVM_TARGET_DEFINITIONS test_passes.td) +mlir_tablegen(test_passes.h.inc -gen-pass-decls -name GmlStTest) +add_public_tablegen_target(MLIRGmlStTestPassIncGen) + +include_directories(BEFORE + ${CMAKE_CURRENT_BINARY_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}) + +add_mlir_library(GmlStPasses + add_debug_info/add_debug_info.cc + collapse_shape/collapse_shape.cc + compose_extract_insert_slice/compose_extract_insert_slice.cc + cpu_tiling/cpu_tiling_pipeline.cc + cpu_tiling/transform_map_for_cpu.cc + cpu_tiling/transform_matmul_for_cpu.cc + cpu_tiling/transform_reduce_for_cpu.cc + cpu_tiling/transform_reverse_for_cpu.cc + cpu_tiling/transform_scatter_for_cpu.cc + cpu_tiling/transform_sort_for_cpu.cc + cpu_tiling/transform_transpose_for_cpu.cc + fusion/fusion.cc + gml_st_simtfy/gml_st_simtfy.cc + gml_st_to_gpu/gml_st_to_gpu.cc + gml_st_to_scf/gml_st_to_scf.cc + gpu_tiling/greedy_fusion.cc + gpu_tiling/tiling_cwise.cc + gpu_tiling/tiling_gpu_warp.cc + peeling/peeling.cc + rewrite_vector_ops/rewrite_vector_contract.cc + rewrite_vector_ops/rewrite_vector_multi_reduction.cc + rewrite_vector_ops/rewrite_vector_transpose.cc + scalarization/scalarization.cc + tiling/tiling.cc + tiling_softmax/tiling_softmax.cc + triton_tiling/transform_matmul_for_triton.cc + vectorization/vectorization.cc + vectorization/vectorize_copy.cc + vectorization/vectorize_for_cpu.cc + vectorization/vectorize_for_gpu.cc + + DEPENDS + MLIRGmlStPassIncGen + MLIRGmlStUtils + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRDestinationStyleOpInterface + MhloDialect + MLIRDialectUtils + MLIRAffineDialect + MLIRArithDialect + MLIRFuncDialect + MLIRGPUOps + MLIRIR + MLIRLinalgDialect + MLIRLinalgTransforms + MLIRMemRefDialect + MLIRPass + MLIRSCFUtils + MLIRSupport + MLIRVectorDialect +) + +add_mlir_library(GmlStTransforms + transforms.cc + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + GmlStDialect + MLIRAffineDialect + MLIRDialectUtils + MLIRIR +) + +add_mlir_library(GmlStTestPasses + test_passes.cc + + DEPENDS + MLIRGmlStTestPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + GmlStBufferizableOpInterface + GmlStDialect + GmlStTransforms + MLIRPass + MLIRTransforms +) diff --git a/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/add_debug_info/add_debug_info.cc b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/add_debug_info/add_debug_info.cc new file mode 100644 index 00000000000..8cf648cb9b2 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/add_debug_info/add_debug_info.cc @@ -0,0 +1,74 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include + +#include "gml_st/transforms/passes.h" +#include "llvm/BinaryFormat/Dwarf.h" +#include "llvm/Support/Path.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace gml_st { +namespace { + +#define GEN_PASS_DEF_ADDDEBUGINFOPASS +#include "gml_st/transforms/passes.h.inc" + +struct AddDebugInfoPass : public impl::AddDebugInfoPassBase { + void runOnOperation() override { + auto module = getOperation(); + auto *context = &getContext(); + OpBuilder builder(context); + std::string inputFilePath("-"); + + if (auto fileLoc = module.getLoc().dyn_cast()) + inputFilePath = fileLoc.getFilename().getValue(); + + auto fileAttr = + LLVM::DIFileAttr::get(context, llvm::sys::path::filename(inputFilePath), + llvm::sys::path::parent_path(inputFilePath)); + + auto producer = StringAttr::get(context, "XLA CPU"); + auto cuAttr = LLVM::DICompileUnitAttr::get( + context, llvm::dwarf::DW_LANG_C_plus_plus_17, fileAttr, producer, + /*isOptimized=*/false, LLVM::DIEmissionKind::LineTablesOnly); + module.walk([&](func::FuncOp funcOp) { + StringAttr funcName = StringAttr::get(context, funcOp.getName()); + auto bT = LLVM::DIBasicTypeAttr::get( + context, llvm::dwarf::DW_TAG_base_type, "void", /*sizeInBits=*/0, + /*encoding=*/1); + auto subTypeAttr = LLVM::DISubroutineTypeAttr::get( + context, llvm::dwarf::DW_CC_normal, {bT}); + auto spAttr = LLVM::DISubprogramAttr::get( + context, cuAttr, fileAttr, funcName, funcName, fileAttr, /*line=*/1, + /*scopeline=*/1, LLVM::DISubprogramFlags::Definition, subTypeAttr); + funcOp->setLoc(builder.getFusedLoc({funcOp->getLoc()}, spAttr)); + }); + } +}; +} // namespace + +std::unique_ptr> createAddDebugInfoPass() { + return std::make_unique(); +} + +} // namespace gml_st +} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/collapse_shape/collapse_shape.cc b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/collapse_shape/collapse_shape.cc new file mode 100644 index 00000000000..7da420ef633 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/collapse_shape/collapse_shape.cc @@ -0,0 +1,351 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include + +#include "gml_st/transforms/passes.h" +#include "gml_st/transforms/transforms.h" +#include "gml_st/utils/linalg_utils.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace gml_st { +namespace { + +#define GEN_PASS_DEF_COLLAPSESHAPEPASS +#include "gml_st/transforms/passes.h.inc" + +// Creates reassociation indices for `shape_collapse` and `shape_expand` ops. +// Given `rank`(N) and `retainTrailingDims`(M), returns the following +// reassociation: +// [[0, 1, ..., N-M-1], [N-M], [N-M+1], ..., [N-1]] +// |--- retainTrailingDims ---| +// |-------------------- rank --------------------| +SmallVector getCollapsingReassociationIndices( + int64_t rank, int64_t retainTrailingDims) { + SmallVector reassociation; + reassociation.reserve(retainTrailingDims + 1); + if (rank > retainTrailingDims) { + auto seq = llvm::seq(0, rank - retainTrailingDims); + reassociation.emplace_back(seq.begin(), seq.end()); + } + for (int64_t i = rank - retainTrailingDims; i < rank; ++i) + reassociation.push_back({i}); + return reassociation; +} + +struct CollapseBcastPattern : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + CollapseBcastPattern(MLIRContext* ctx, int64_t retainTrailingDims) + : OpRewritePattern(ctx), + retainTrailingDims(retainTrailingDims) {} + + LogicalResult matchAndRewrite(linalg::BroadcastOp op, + PatternRewriter& rewriter) const override { + Value init = op.getInit(); + auto initTy = init.getType().cast(); + int64_t initRank = initTy.getRank(); + int64_t numCollapsedDims = initRank - retainTrailingDims; + + if (numCollapsedDims < 2) { + return rewriter.notifyMatchFailure(op, "no dimension to collapse"); + } + + // Dimensions to be collapsed must either be all broadcasted or not + // broadcasted. + llvm::ArrayRef nonBroadcastedDims = op.getDimensions(); + + bool firstDimsBroadcasted = true; + if (!nonBroadcastedDims.empty()) { + int64_t i = 0; + while (i < (int64_t)nonBroadcastedDims.size() && + nonBroadcastedDims[i] == i && i < numCollapsedDims) { + ++i; + } + if (i >= numCollapsedDims) { + firstDimsBroadcasted = false; + } else if (llvm::any_of(nonBroadcastedDims, + [numCollapsedDims](unsigned dim) { + return dim < numCollapsedDims; + })) { + return rewriter.notifyMatchFailure( + op, "collapsed dims are not broadcasted in order"); + } + } + + Value operand = op.getInput(); + auto operandTy = operand.getType().cast(); + int64_t operandRank = operandTy.getRank(); + llvm::DenseSet nonBroadcastedDimsSet(nonBroadcastedDims.begin(), + nonBroadcastedDims.end()); + llvm::SmallVector collapsedNonBroadcastedDims; + collapsedNonBroadcastedDims.reserve(numCollapsedDims + + (firstDimsBroadcasted ? 1 : 0)); + for (int64_t dim = numCollapsedDims; dim < initRank; ++dim) { + if (nonBroadcastedDimsSet.contains(dim)) { + collapsedNonBroadcastedDims.push_back(dim - numCollapsedDims + 1); + } + } + int64_t operandRetainTrailingDims = + retainTrailingDims - collapsedNonBroadcastedDims.size(); + + // Collapse operand and init tensor. 
+ // For bcasts, this retains the last `retainTrailingDims` dimensions of the + // *result* and collapses all others. + Location loc = op.getLoc(); + Value collapsedOperand = operand; + if (operandRank > operandRetainTrailingDims + 1) { + SmallVector operandReassociation = + getCollapsingReassociationIndices(operandRank, + operandRetainTrailingDims); + collapsedOperand = rewriter.createOrFold( + loc, operand, operandReassociation); + } + SmallVector initReassociation = + getCollapsingReassociationIndices(initRank, retainTrailingDims); + Value collapsedInit = + rewriter.create(loc, init, initReassociation); + + // Create collapsed bcast op. + if (!firstDimsBroadcasted) { + collapsedNonBroadcastedDims.push_back(0); + } + Value collapsedBcastOp = + rewriter + .create( + loc, collapsedOperand, collapsedInit, + ArrayRef(collapsedNonBroadcastedDims)) + .getResult() + .front(); + + // Re-expand broadcast op and replace the original. + auto reexpandedBcastOp = rewriter.create( + loc, initTy, collapsedBcastOp, initReassociation); + rewriter.replaceOp(op, reexpandedBcastOp.getResult()); + return success(); + } + + private: + int64_t retainTrailingDims; +}; + +struct CollapseReductionPattern : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + CollapseReductionPattern(MLIRContext* ctx, int64_t retainTrailingDims) + : OpRewritePattern(ctx), + retainTrailingDims(retainTrailingDims) {} + + LogicalResult matchAndRewrite(linalg::ReduceOp op, + PatternRewriter& rewriter) const override { + if (op.getNumDpsInits() != 1 || op.getDimensions().empty()) + return failure(); + int64_t reductionDim = op.getDimensions()[0]; + + Value operand = op.getInputs().front(); + auto operandTy = operand.getType().cast(); + int64_t operandRank = operandTy.getRank(); + + if (operandRank <= retainTrailingDims + 1) { + return rewriter.notifyMatchFailure(op, "no dimension to collapse"); + } + + if (operandRank - 1 - reductionDim >= retainTrailingDims) { + return rewriter.notifyMatchFailure( + op, "reduction dimension must be retained"); + } + + Value init = op.getInits().front(); + auto initTy = init.getType().cast(); + int64_t initRank = initTy.getRank(); + + // Collapse operand and init tensor. + // For reductions, this retains the last `retainTrailingDims` dimensions of + // the *operand* and collapses all others. + Location loc = op.getLoc(); + SmallVector operandReassociation = + getCollapsingReassociationIndices(operandRank, retainTrailingDims); + Value collapsedOperand = rewriter.create( + loc, operand, operandReassociation); + SmallVector initReassociation = + getCollapsingReassociationIndices(initRank, retainTrailingDims - 1); + Value collapsedInit = + rewriter.create(loc, init, initReassociation); + + auto collapsedOperandTy = + collapsedOperand.getType().cast(); + int64_t collapsedOperandRank = collapsedOperandTy.getRank(); + auto collapsedInitTy = collapsedInit.getType().cast(); + + // Create collapsed reduction op. + int64_t collapsedReductionDim = + reductionDim - operandRank + collapsedOperandRank; + SmallVector collapsedIteratorTypes( + collapsedOperandRank, utils::IteratorType::parallel); + collapsedIteratorTypes[collapsedReductionDim] = + utils::IteratorType::reduction; + auto collapsedReductionOp = rewriter.create( + loc, collapsedInitTy, collapsedOperand, collapsedInit, + ArrayRef({collapsedReductionDim})); + collapsedReductionOp.getRegion().takeBody(op.getBodyRegion()); + + // Re-expand reduction op and replace the original. 
+ auto reexpandedReductionOp = rewriter.create( + loc, initTy, collapsedReductionOp.getResults().front(), + initReassociation); + rewriter.replaceOp(op, reexpandedReductionOp.getResult()); + return success(); + } + + private: + int64_t retainTrailingDims; +}; + +linalg::MapOp createCollapsedMapOp( + linalg::MapOp mapOp, PatternRewriter& rewriter, + const SmallVector& reassociation) { + // Collapsed operands and init tensor. + Location loc = mapOp.getLoc(); + SmallVector collapsedOperands = llvm::to_vector( + llvm::map_range(mapOp.getInputs(), [&](Value it) -> Value { + return rewriter.create(loc, it, reassociation); + })); + Value init = mapOp.getInit(); + Value collapsedInit = + rewriter.create(loc, init, reassociation); + + // Create collapsed map op. + auto collapsedInitTy = collapsedInit.getType().cast(); + auto collapsedMapOp = rewriter.create( + loc, collapsedInitTy, collapsedOperands, collapsedInit); + IRMapping bvm; + mapOp.getBodyRegion().cloneInto(&collapsedMapOp.getRegion(), bvm); + return collapsedMapOp; +} + +struct CollapseMapPattern : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + CollapseMapPattern(MLIRContext* ctx, int64_t retainTrailingDims) + : OpRewritePattern(ctx), + retainTrailingDims(retainTrailingDims) {} + + LogicalResult matchAndRewrite(linalg::MapOp op, + PatternRewriter& rewriter) const override { + Value init = op.getInit(); + auto initTy = init.getType().cast(); + int64_t rank = initTy.getRank(); + + if (rank <= retainTrailingDims + 1) { + return rewriter.notifyMatchFailure(op, "no dimension to collapse"); + } + + SmallVector reassociation = + getCollapsingReassociationIndices(rank, retainTrailingDims); + auto collapsedMapOp = createCollapsedMapOp(op, rewriter, reassociation); + + // Re-expand map op and replace the original. + auto reexpandedMapOp = rewriter.create( + op.getLoc(), initTy, collapsedMapOp.getResult().front(), reassociation); + rewriter.replaceOp(op, reexpandedMapOp.getResult()); + return success(); + } + + private: + int64_t retainTrailingDims; +}; + +struct MoveCollapseBeforeMapPattern + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + explicit MoveCollapseBeforeMapPattern(MLIRContext* ctx) + : OpRewritePattern(ctx) {} + + LogicalResult matchAndRewrite(tensor::CollapseShapeOp op, + PatternRewriter& rewriter) const override { + auto mapOp = op.getSrc().getDefiningOp(); + if (!mapOp) return failure(); + auto collapsedMapOp = + createCollapsedMapOp(mapOp, rewriter, op.getReassociationIndices()); + rewriter.replaceOp(op, collapsedMapOp.getResult()); + return success(); + } +}; + +struct CollapseShapePass + : public impl::CollapseShapePassBase { + using CollapseShapePassBase::CollapseShapePassBase; + + void getDependentDialects(DialectRegistry& registry) const override { + CollapseShapePassBase::getDependentDialects(registry); + + // TODO(frgossen): Move these iface implementations into the tensor dialect. + // Some of its canonicalizations depend on it. Until then, we have to + // register them explicitly. + tensor::registerInferTypeOpInterfaceExternalModels(registry); + } + + void runOnOperation() override { + func::FuncOp f = getOperation(); + MLIRContext* ctx = &getContext(); + + // Populate shape-collapsing patterns for cwise ops, reductions, and bcasts. + RewritePatternSet patterns(ctx); + patterns.add(ctx, retainTrailingDims); + // By moving CollapseShapeOp before MapOp, we can potentially remove it if + // it cancels out with an ExpandShapeOp. 
+ patterns.add(ctx); + + // Collect some related canonicalization patterns. + linalg::BroadcastOp::getCanonicalizationPatterns(patterns, ctx); + linalg::FillOp::getCanonicalizationPatterns(patterns, ctx); + linalg::MapOp::getCanonicalizationPatterns(patterns, ctx); + linalg::ReduceOp::getCanonicalizationPatterns(patterns, ctx); + tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, ctx); + tensor::EmptyOp::getCanonicalizationPatterns(patterns, ctx); + tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, ctx); + tensor::populateFoldTensorEmptyPatterns(patterns); + + if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr> createCollapseShapePass() { + return std::make_unique(); +} + +std::unique_ptr> createCollapseShapePass( + const CollapseShapePassOptions& options) { + return std::make_unique(options); +} + +} // namespace gml_st +} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/compose_extract_insert_slice/compose_extract_insert_slice.cc b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/compose_extract_insert_slice/compose_extract_insert_slice.cc new file mode 100644 index 00000000000..3765bcb2a04 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/compose_extract_insert_slice/compose_extract_insert_slice.cc @@ -0,0 +1,55 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "gml_st/transforms/passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace gml_st { +namespace { + +#define GEN_PASS_DEF_COMPOSEEXTRACTINSERTSLICEPASS +#include "gml_st/transforms/passes.h.inc" + +struct ComposeExtractInsertSlicePass + : public impl::ComposeExtractInsertSlicePassBase< + ComposeExtractInsertSlicePass> { + void runOnOperation() override { + MLIRContext* ctx = &getContext(); + RewritePatternSet patterns(ctx); + tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns); + + if (failed(applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr> +createComposeExtractInsertSlicePass() { + return std::make_unique(); +} + +} // namespace gml_st +} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/cpu_tiling/cpu_tiling_pipeline.cc b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/cpu_tiling/cpu_tiling_pipeline.cc new file mode 100644 index 00000000000..4fc64e49f2f --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/cpu_tiling/cpu_tiling_pipeline.cc @@ -0,0 +1,51 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "gml_st/transforms/passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" + +namespace mlir { +namespace gml_st { + +void addCPUTilingPipeline(OpPassManager& pm, + const GmlStCPUPipelineOptions& options) { + using func::FuncOp; + + pm.addNestedPass(createTransformScatterForCpuPass()); + pm.addNestedPass(createTransformReduceForCpuPass( + options.vectorSize, options.reduction1DTileSize, + options.reduction2DTileSizes)); + pm.addNestedPass(createTransformMatmulForCpuPass( + options.matmulTileSizes, options.lowerToMmt4d)); + pm.addNestedPass(createTransformTransposeForCpuPass()); + pm.addNestedPass(createTransformMapForCpuPass(options.vectorSize)); + pm.addNestedPass(createTransformSortForCpuPass()); + pm.addNestedPass( + mlir::gml_st::createTransformReverseForCpuPass()); + + pm.addPass(createCSEPass()); + pm.addPass(createCanonicalizerPass()); + + pm.addNestedPass(createComposeExtractInsertSlicePass()); + pm.addNestedPass(createVectorizeForCPUPass()); + pm.addNestedPass(createScalarizationPass()); + pm.addNestedPass(createRewriteVectorContractPass()); +} + +} // namespace gml_st +} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/transform_map_for_cpu.cc b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_map_for_cpu.cc similarity index 51% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/transform_map_for_cpu.cc rename to tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_map_for_cpu.cc index ddbb6aa07e1..8bb766a6720 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/transform_map_for_cpu.cc +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_map_for_cpu.cc @@ -16,14 +16,16 @@ limitations under the License. 
#include #include -#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h" -#include "mlir-hlo/Dialect/gml_st/transforms/passes.h" -#include "mlir-hlo/Dialect/gml_st/transforms/tiling.h" -#include "mlir-hlo/Dialect/gml_st/transforms/tiling_interface_impl.h" -#include "mlir-hlo/Dialect/gml_st/transforms/transforms.h" +#include "gml_st/IR/gml_st_ops.h" +#include "gml_st/transforms/fusion/fusion.h" +#include "gml_st/transforms/passes.h" +#include "gml_st/transforms/peeling/peeling.h" +#include "gml_st/transforms/tiling/tiling.h" +#include "gml_st/transforms/transforms.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" @@ -33,33 +35,82 @@ namespace mlir::gml_st { namespace { #define GEN_PASS_DEF_TRANSFORMMAPFORCPUPASS -#include "mlir-hlo/Dialect/gml_st/transforms/passes.h.inc" +#include "gml_st/transforms/passes.h.inc" + +static constexpr llvm::StringRef kMapTransformedLabel = + "__map_transformed_label__"; struct TileMapPattern : public OpRewritePattern { - TileMapPattern(MLIRContext *context, TilingOptions options, + TileMapPattern(MLIRContext *context, int64_t innerDimTileSize, PatternBenefit benefit = 1) : OpRewritePattern(context, benefit), - options(std::move(options)) {} + innerDimTileSize(innerDimTileSize) {} LogicalResult matchAndRewrite(linalg::MapOp op, PatternRewriter &rewriter) const override { - if (hasTransformationAttr(op)) return failure(); + if (hasLabel(op, kMapTransformedLabel)) return failure(); - auto tilingResult = - tile(options, rewriter, cast(op.getOperation())); - if (failed(tilingResult)) return failure(); + if (isa(op->getParentOp())) + return rewriter.notifyMatchFailure( + op, "has already been tiled by another pass."); + + auto fuseFilterFn = [](Operation *op) { + return isa(op); + }; + + // Find there another linalg.map where this op can be fused. + op = findRootMap(op, fuseFilterFn); + + if (hasLabel(op, kMapTransformedLabel)) return failure(); + + mlir::gml_st::TilingOptions opts; + opts.tileSizeComputationFn = [&](OpBuilder &b, Operation *op) { + auto numLoops = cast(op).getNumLoops(); + SmallVector tiles( + numLoops, b.create(op->getLoc(), 1)); + if (!tiles.empty()) + tiles.back() = + b.create(op->getLoc(), innerDimTileSize); + return tiles; + }; + + auto tiledLoop = tileUsingGmlStParallelAndFuseGreedily( + rewriter, op, opts, kMapTransformedLabel, fuseFilterFn); + if (failed(tiledLoop)) return failure(); + + // Peel parallel loops. + auto peelingResult = peelAllLoops(*tiledLoop, rewriter); + setLabel(*tiledLoop, kPerfectlyTiledLoopLabel); + + // Tile ops in the peeled loop again, to size 1, so they can be + // scalarized. + if (failed(tilePeeledOpsToScalars(rewriter, peelingResult, + kMapTransformedLabel, fuseFilterFn))) + return failure(); - // If we did not tile (e.g. when all tile sizes are 0), do not replace - // original op and just mark it as transformed then return. - if (tilingResult->loop != nullptr) { - rewriter.replaceOp(op, tilingResult->loop->getResults()); - } - setTransformationAttr(rewriter, tilingResult->tiledOp); return success(); } private: - TilingOptions options; + // Find the root of the fusion cluster. 
+ linalg::MapOp findRootMap( + linalg::MapOp op, + llvm::function_ref fuseFilterFn) const { + linalg::MapOp rootMap = op; + + Operation *curOp = op; + while (fuseFilterFn(curOp)) { + auto users = llvm::to_vector(curOp->getUsers()); + // The op has more than 1 user. It will no be fused. + if (users.size() != 1) break; + curOp = users[0]; + + if (auto curMap = dyn_cast(curOp)) rootMap = curMap; + } + return rootMap; + } + + int64_t innerDimTileSize; }; struct TransformMapForCpuPass @@ -69,32 +120,21 @@ struct TransformMapForCpuPass void getDependentDialects(DialectRegistry ®istry) const final { registry.insert(); - mlir::gml_st::registerGmlStTilingInterfaceExternalModels(registry); + linalg::registerTilingInterfaceExternalModels(registry); } void runOnOperation() override { func::FuncOp f = getOperation(); MLIRContext *context = &getContext(); - mlir::gml_st::TilingOptions opts; - - opts.tileSizeComputationFn = [&](OpBuilder &b, Operation *op) { - auto numLoops = cast(op).getNumLoops(); - SmallVector tiles( - numLoops, b.create(op->getLoc(), 1)); - if (!tiles.empty()) - tiles.back() = b.create(op->getLoc(), tileSize); - return tiles; - }; - RewritePatternSet patterns(context); - patterns.add(context, opts); + patterns.add(context, tileSize); if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) { return signalPassFailure(); } - f.walk([](linalg::MapOp op) { gml_st::removeTransformationAttr(op); }); + f.walk([](linalg::MapOp op) { removeLabel(op, kMapTransformedLabel); }); } }; diff --git a/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_matmul_for_cpu.cc b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_matmul_for_cpu.cc new file mode 100644 index 00000000000..4a9b59d58eb --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_matmul_for_cpu.cc @@ -0,0 +1,788 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include +#include +#include +#include + +#include "gml_st/IR/gml_st_ops.h" +#include "gml_st/transforms/fusion/fusion.h" +#include "gml_st/transforms/passes.h" +#include "gml_st/transforms/peeling/peeling.h" +#include "gml_st/transforms/tiling/tiling.h" +#include "gml_st/transforms/transforms.h" +#include "llvm/ADT/TypeSwitch.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" +#include "mlir/Dialect/SCF/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h" +#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir::gml_st { +namespace { + +#define GEN_PASS_DEF_TRANSFORMMATMULFORCPUPASS +#define GEN_PASS_DEF_SIMPLIFYDEADCOPYPASS +#include "gml_st/transforms/passes.h.inc" + +static constexpr llvm::StringRef kMatmulTransformedLabel = + "__matmul_transformed_label__"; + +// Helper to pick the tile shapes to use as the 2 inner dimensions of the +// 4D shapes appearing in a Mmt4D. +class Mmt4DTileParams { + public: + Mmt4DTileParams(ArrayRef m0k0n0, const llvm::StringRef comment) + : m0(m0k0n0[0]), k0(m0k0n0[1]), n0(m0k0n0[2]), comment(comment) {} + std::array lhs() const { return {m0, k0}; } + std::array rhs() const { return {k0, n0}; } + std::array acc() const { return {m0, n0}; } + std::array rhsTranspose() const { return {n0, k0}; } + const std::string &getComment() const { return comment; } + + private: + const int64_t m0; + const int64_t k0; + const int64_t n0; + const std::string comment; +}; + +Value getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim) { + ShapedType type = v.getType().cast(); + if (!type.isDynamicDim(dim)) { + return builder.create(loc, type.getDimSize(dim)); + } + return TypeSwitch(v.getType()) + .Case([&](RankedTensorType /*t*/) -> Value { + return builder.create(loc, v, dim); + }) + .Case([&](MemRefType /*t*/) -> Value { + return builder.create(loc, v, dim); + }); +} + +OpFoldResult getDim(OpBuilder &builder, Location loc, Value v, int64_t dim) { + auto t = v.getType().cast(); + if (t.isDynamicDim(dim)) { + return getDimValue(builder, loc, v, dim); + } + return builder.getI64IntegerAttr(t.getDimSize(dim)); +} + +// Returns dimensions of |shapedTypeValue|, handling both static and dynamic +// shapes. 
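+// Illustrative example: for a value of type tensor<4x?xf32>, getDims returns
+// {getI64IntegerAttr(4), %d}, where %d is the result of a tensor.dim op
+// created for the dynamic dimension 1.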
+SmallVector getDims(OpBuilder &builder, Location loc, + Value shapedTypeValue) { + return llvm::to_vector(llvm::map_range( + llvm::seq( + 0, shapedTypeValue.getType().cast().getRank()), + [&](int64_t dim) { return getDim(builder, loc, shapedTypeValue, dim); })); +} + +Optional getPaddingValue(Value &source) { + auto padOp = source.getDefiningOp(); + if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad()) + return std::nullopt; + + Value constantPaddingValue = padOp.getConstantPaddingValue(); + if (!constantPaddingValue) return std::nullopt; + + source = padOp.getSource(); + return constantPaddingValue; +} + +// Returns a tiled and packed value of |source|, the data layout is described by +// |innerDimsPos|, |innerTileSizes| and |outerDimsPerm|. +Value pack(Location loc, PatternRewriter &rewriter, Value source, + ArrayRef innerDimsPos, ArrayRef innerTileSizes, + ArrayRef outerDimsPerm) { + SmallVector innerTileSizesOfr = + getAsOpFoldResult(rewriter.getI64ArrayAttr(innerTileSizes)); + auto empty = tensor::PackOp::createDestinationTensor( + rewriter, loc, source, innerTileSizesOfr, innerDimsPos, outerDimsPerm); + Optional paddingValue = getPaddingValue(source); + return rewriter.create(loc, source, empty, innerDimsPos, + innerTileSizesOfr, paddingValue, + outerDimsPerm); +} + +// Returns an unpacked value of |source|, the data layout is described by +// |innerDimsPos|, |innerTileSizes| and |outerDimsPerm|. |resultShapeValue| is +// used to create the destination tensor for the resulting unpacked value. +Value unpack(Location loc, PatternRewriter &rewriter, Value source, + Value resultShapeValue, ArrayRef innerDimsPos, + ArrayRef innerTileSizes, + ArrayRef outerDimsPerm) { + SmallVector resultDims = + getDims(rewriter, loc, resultShapeValue); + auto empty = rewriter.create( + loc, resultDims, + source.getType().cast().getElementType()); + + SmallVector innerTileSizesOfr = + getAsOpFoldResult(rewriter.getI64ArrayAttr(innerTileSizes)); + + return rewriter.create(loc, source, empty, innerDimsPos, + innerTileSizesOfr, outerDimsPerm); +} + +// Returns true if an input of the given |inputShape| needs padding to +// ensure that its shape will be a multiple of |tileShape|. That's always true +// in the dynamic shape case. +bool needsPadding(ArrayRef inputShape, ArrayRef tileShape) { + assert(inputShape.size() == tileShape.size()); + for (size_t i = 0; i < inputShape.size(); i++) { + if (inputShape[i] == ShapedType::kDynamic) { + return true; + } + if (inputShape[i] % tileShape[i] != 0) { + return true; + } + } + return false; +} + +// Pads |input| on the bottom and on the right to the next multiple of +// |tileShape|. +Value pad(Location loc, PatternRewriter &rewriter, Value input, + ArrayRef tileShape) { + SmallVector lowPadding, highPadding; + SmallVector resultTypeShape; + auto inputType = input.getType().cast(); + ArrayRef inputShape = inputType.getShape(); + if (!needsPadding(inputShape, tileShape)) { + return input; + } + int64_t rank = inputType.getRank(); + for (int64_t i = 0; i < rank; ++i) { + // No 'low' padding i.e. no padding at the top and on the left. + lowPadding.push_back(rewriter.getIndexAttr(0)); + // 'High' padding i.e. padding at the bottom and on the right, and the + // result type shape, will be dynamic in any dimension if and only if the + // input shape is. + if (inputShape[i] == ShapedType::kDynamic) { + resultTypeShape.push_back(ShapedType::kDynamic); + // There only remains to compute the 'high' padding Value. 
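+      // Worked example (hypothetical runtime value): for a runtime dim of 5
+      // and a tile dim of 8, distanceToNextMultipleOf below computes
+      // (8 - 1) - ((5 + (8 - 1)) % 8) = 7 - 4 = 3, so 3 elements of 'high'
+      // padding round the dimension up to the next multiple of 8.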
+ auto add = [&](Value a, Value b) { + return rewriter.create(loc, a, b); + }; + auto sub = [&](Value a, Value b) { + return rewriter.create(loc, a, b); + }; + auto rem = [&](Value a, Value b) { + return rewriter.create(loc, a, b); + }; + // Compare to the plainer distanceToNextMultipleOf in the static + // dimension case below. + auto distanceToNextMultipleOf = [&](Value a, Value b) { + Value one = rewriter.create(loc, 1); + Value bMinusOne = sub(b, one); + return sub(bMinusOne, rem(add(a, bMinusOne), b)); + }; + Value inputDim = rewriter.create(loc, input, i); + Value tileDim = + rewriter.create(loc, tileShape[i]); + Value padding = distanceToNextMultipleOf(inputDim, tileDim); + highPadding.push_back(padding); + } else { + auto distanceToNextMultipleOf = [=](int64_t a, int64_t b) { + int64_t bMinusOne = b - 1; + return bMinusOne - ((a + bMinusOne) % b); + }; + int64_t inputDim = inputShape[i]; + int64_t tileDim = tileShape[i]; + int64_t padding = distanceToNextMultipleOf(inputDim, tileDim); + resultTypeShape.push_back(inputDim + padding); + highPadding.push_back(rewriter.getIndexAttr(padding)); + } + } + Type elementType = inputType.getElementType(); + RankedTensorType resultType = + RankedTensorType::get(resultTypeShape, elementType); + Value padValue; + if (auto complexTy = elementType.dyn_cast()) { + auto zero = rewriter.getZeroAttr(complexTy.getElementType()); + padValue = rewriter.create( + loc, elementType, rewriter.getArrayAttr({zero, zero})); + } else { + auto zero = rewriter.getZeroAttr(elementType); + padValue = rewriter.create(loc, elementType, zero); + } + return rewriter.create(loc, resultType, input, lowPadding, + highPadding, padValue); +} + +// Returns a top-left slice from |input| shaped like |likeWhat|. +Value extractSliceLike(Location loc, PatternRewriter &rewriter, Value input, + Value likeWhat) { + SmallVector offsets, dims, strides; + auto resultType = likeWhat.getType().cast(); + int64_t rank = resultType.getRank(); + auto resultShape = likeWhat.getType().cast().getShape(); + for (int i = 0; i < rank; ++i) { + offsets.push_back(rewriter.getIndexAttr(0)); + strides.push_back(rewriter.getIndexAttr(1)); + if (resultShape[i] == ShapedType::kDynamic) { + dims.emplace_back(rewriter.create(loc, likeWhat, i)); + } else { + dims.push_back(rewriter.getIndexAttr(resultShape[i])); + } + } + return rewriter.create(loc, resultType, input, + offsets, dims, strides); +} + +bool haveEqualShapeDim(Value x, Value y, int i) { + return x.getType().cast().getDimSize(i) == + y.getType().cast().getDimSize(i); +} + +// Pattern to convert linalg.matmul to an equivalent subgraph using +// linalg.mmt4d. Currently, m0, n0 and k0 (packing parameters, aka layout tiling +// parameters) are compile-time constants. +struct MatmulToMmt4dPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + explicit MatmulToMmt4dPattern(MLIRContext *context, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit) {} + + LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp, + PatternRewriter &rewriter) const override { + Location loc = matmulOp.getLoc(); + + Value lhs = matmulOp.getDpsInputOperand(0)->get(); + Value rhs = matmulOp.getDpsInputOperand(1)->get(); + Value acc = matmulOp.getDpsInitOperand(0)->get(); + + // This transformation supports any mixing of static and dynamic dimensions, + // with one exception: the dynamic-ness of each dimension of the accumulator + // must match the dynamic-ness of the corresponding lhs/rhs dimension. 
+ // This limitation is not inherent to this transformation's code, it's just + // here to avoid a current linalg folding limitation: at the moment, + // removing this gives the following error in e2e matmul tests, + // "error: failed to legalize operation 'tensor.cast' that was explicitly + // marked illegal" + // apparently due to some missing folding of tensor.cast op into reshapes. + if (!haveEqualShapeDim(lhs, acc, 0) || !haveEqualShapeDim(rhs, acc, 1)) { + return failure(); + } + + ShapedType lhsType = lhs.getType().cast(); + ShapedType rhsType = rhs.getType().cast(); + int64_t shapeM = lhsType.getShape()[0]; + int64_t shapeN = rhsType.getShape()[1]; + auto chooseMatMulOrMatVec = [=](ArrayRef m0k0n0, + ArrayRef m0k0n0ForMatVec, + ArrayRef m0k0n0ForWhenRhsHas2Columns, + std::string comment) { + assert(m0k0n0ForMatVec[2] == 1 && "not a matrix*vector shape"); + assert(m0k0n0ForWhenRhsHas2Columns[2] == 2 && + "N=2 is expected when RHS has 2 columns"); + + SmallVector params; + if (shapeN == 1 || shapeM == 1) { + params.assign(m0k0n0ForMatVec.begin(), m0k0n0ForMatVec.end()); + } else if (shapeN == 2 || shapeM == 2) { + params.assign(m0k0n0ForWhenRhsHas2Columns.begin(), + m0k0n0ForWhenRhsHas2Columns.end()); + } else { + return Mmt4DTileParams(m0k0n0, comment); + } + + if (shapeN == 1 || shapeN == 2) { + comment += ", matrix * narrow matrix, where the narrow matrix has " + + std::to_string(shapeN) + " column(s)"; + } else { + // The vector*matrix case is intentionally derived from the + // matrix*vector case by swapping M and N dims so that in kernel + // codegen we can reuse matrix*vector kernels by swapping LHS and RHS. + std::swap(params[0], params[2]); + comment += ", narrow matrix * matrix, where the narrow matrix has " + + std::to_string(shapeM) + " column(s)"; + } + return Mmt4DTileParams(params, comment); + }; + + const auto &tileParams = chooseMatMulOrMatVec( + {8, 1, 8}, {8, 1, 1}, {8, 1, 2}, "f32*f32->f32, generic"); + + Value paddedLhs = pad(loc, rewriter, lhs, tileParams.lhs()); + Value paddedRhs = pad(loc, rewriter, rhs, tileParams.rhs()); + Value paddedAcc = pad(loc, rewriter, acc, tileParams.acc()); + + Value packed4DLhs = + pack(loc, rewriter, paddedLhs, {0, 1}, tileParams.lhs(), {}); + Value packed4DRhs = pack(loc, rewriter, paddedRhs, {1, 0}, + tileParams.rhsTranspose(), {1, 0}); + Value packed4DAcc = + pack(loc, rewriter, paddedAcc, {0, 1}, tileParams.acc(), {}); + + auto mmt4d = rewriter.create( + loc, packed4DAcc.getType(), ValueRange{packed4DLhs, packed4DRhs}, + ValueRange{packed4DAcc}); + mmt4d->setAttr(StringAttr::get(getContext(), "comment"), + StringAttr::get(getContext(), tileParams.getComment())); + + Value paddedResult = unpack(loc, rewriter, mmt4d.getResult(0), paddedAcc, + {0, 1}, tileParams.acc(), {}); + + Value result = extractSliceLike(loc, rewriter, paddedResult, acc); + rewriter.replaceOp(matmulOp, ArrayRef{result}); + + return success(); + } +}; + +FailureOr tileMatmul(PatternRewriter &rewriter, Operation *op, + ArrayRef tileSizes) { + TilingOptions opts; + opts.setTileSizeComputationFn(tileSizes); + opts.distribute = true; + return tileUsingGmlSt(opts, rewriter, cast(op)); +} + +/// Splits the tile sizes in `parallelSizes` into `reductionSizes` for the +/// reduction loops. 
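+/// Illustrative example: for a linalg.matmul with iterator types
+/// [parallel, parallel, reduction] and tile sizes {8, 8, 4}, `parallelSizes`
+/// becomes {8, 8, 0} and `reductionSizes` becomes {0, 0, 4}.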
+void splitParallelAndReductionTiles(linalg::LinalgOp op, + SmallVectorImpl ¶llelSizes, + SmallVectorImpl &reductionSizes) { + reductionSizes.assign(parallelSizes.begin(), parallelSizes.end()); + for (auto [index, iteratorType] : + llvm::enumerate(op.getIteratorTypesArray())) { + if (iteratorType == utils::IteratorType::parallel) { + reductionSizes[index] = 0; + } else { + parallelSizes[index] = 0; + } + } +} + +FailureOr tileUsingSCFForAndReplace( + PatternRewriter &rewriter, Operation *op, + const scf::SCFTilingOptions &tilingOptions) { + auto tilingResult = scf::tileUsingSCFForOp(rewriter, op, tilingOptions); + if (failed(tilingResult) || tilingResult->loops.empty()) return failure(); + rewriter.replaceOp(op, tilingResult->replacements); + return tilingResult->tiledOps.front(); +} + +/// Pattern to tile `linalg.mmt4d`. +struct Mmt4DTransformPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + explicit Mmt4DTransformPattern(MLIRContext *context, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit) {} + + LogicalResult matchAndRewrite(linalg::Mmt4DOp mmt4dOp, + PatternRewriter &rewriter) const override { + if (hasLabel(mmt4dOp, kMatmulTransformedLabel)) { + return rewriter.notifyMatchFailure(mmt4dOp, + "has already been transformed."); + } + + // Tile tensor.pack ops. + auto packTilingOptions = + scf::SCFTilingOptions().setTileSizeComputationFunction( + [&](OpBuilder b, Operation *op) { + auto numLoops = + cast(op).getLoopIteratorTypes().size(); + SmallVector tiles( + numLoops, b.create(op->getLoc(), 1)); + return tiles; + }); + + auto *lhsOp = mmt4dOp.getInputs()[0].getDefiningOp(); + if (failed(tileUsingSCFForAndReplace(rewriter, lhsOp, packTilingOptions))) + return failure(); + + auto *rhsOp = mmt4dOp.getInputs()[1].getDefiningOp(); + if (failed(tileUsingSCFForAndReplace(rewriter, rhsOp, packTilingOptions))) + return failure(); + + auto *accOp = mmt4dOp.getOutputs()[0].getDefiningOp(); + if (failed(tileUsingSCFForAndReplace(rewriter, accOp, packTilingOptions))) + return failure(); + + // Tile tensor.unpack op. + auto unpackTilingOptions = + scf::SCFTilingOptions().setTileSizeComputationFunction( + [](OpBuilder &builder, Operation *op) { + Location loc = op->getLoc(); + auto unpackOp = cast(op); + auto numLoops = unpackOp.getDestRank(); + auto dimAndTileMapping = unpackOp.getDimAndTileMapping(); + SmallVector tileSizes; + for (size_t i = 0; i < numLoops; ++i) { + if (dimAndTileMapping.count(i)) { + tileSizes.push_back(getValueOrCreateConstantIndexOp( + builder, loc, dimAndTileMapping[i])); + } else { + tileSizes.push_back( + getDimValue(builder, loc, unpackOp.getDest(), i)); + } + } + return tileSizes; + }); + + auto *unpackOp = *mmt4dOp->user_begin(); + if (failed( + tileUsingSCFForAndReplace(rewriter, unpackOp, unpackTilingOptions))) + return failure(); + + // Compute the tile sizes. Note that at this stage we only do layout tiling. + // Later we might also want to do traversal tiling (only on M and N dims). + auto getL1TileSizes = [&]() -> SmallVector { + auto lhsShape = + mmt4dOp.getInputs()[0].getType().cast().getShape(); + auto rhsShape = + mmt4dOp.getInputs()[1].getType().cast().getShape(); + int64_t m0 = lhsShape[2]; + int64_t n0 = rhsShape[2]; + int64_t k0 = lhsShape[3]; + return {1, 1, 1, m0, n0, k0}; + }; + + SmallVector parallelTileSizes = getL1TileSizes(); + SmallVector reductionTileSizes; + + // Search the number of outer parallel loops to separate them from possible + // inner reduction dimensions. 
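+    // Illustrative example: linalg.mmt4d iterates as [parallel, parallel,
+    // reduction, parallel, parallel, reduction], so the split below turns the
+    // L1 tile sizes {1, 1, 1, m0, n0, k0} into parallel tiles
+    // {1, 1, 0, m0, n0, 0} and reduction tiles {0, 0, 1, 0, 0, k0}.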
+ auto iterTypes = mmt4dOp.getIteratorTypesArray(); + // Make sure to only look at the leading loops for tiling---we will scan + // this array to find the first non-parallel loop later and use that for + // indexing into the tile sizes. + if (iterTypes.size() > parallelTileSizes.size()) { + iterTypes.resize(parallelTileSizes.size()); + } + + splitParallelAndReductionTiles(mmt4dOp.getOperation(), parallelTileSizes, + reductionTileSizes); + + // Tile the parallel loops. + auto tiledOp = tileUsingSCFForAndReplace( + rewriter, mmt4dOp.getOperation(), + scf::SCFTilingOptions().setTileSizes(parallelTileSizes)); + if (failed(tiledOp)) return failure(); + mmt4dOp = cast(*tiledOp); + + // Tile the reduction loops. + tiledOp = tileUsingSCFForAndReplace( + rewriter, mmt4dOp.getOperation(), + scf::SCFTilingOptions().setTileSizes(reductionTileSizes)); + if (failed(tiledOp)) return failure(); + mmt4dOp = cast(*tiledOp); + + setLabel(mmt4dOp, kMatmulTransformedLabel); + return success(); + } +}; + +/// Pattern to tile `linalg.matmul`, fuse `linalg.fill` into generated +/// `gml_st.parallel`, and peel the generated loops. +struct MatmulTransformPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + explicit MatmulTransformPattern(MLIRContext *context, + int64_t lhsParallelDimTileSize = 2, + int64_t rhsParallelDimTileSize = 4, + int64_t reductionDimTileSize = 8, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + lhsParallelDimTileSize(lhsParallelDimTileSize), + rhsParallelDimTileSize(rhsParallelDimTileSize), + reductionDimTileSize(reductionDimTileSize) {} + + LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp, + PatternRewriter &rewriter) const override { + if (hasLabel(matmulOp, kMatmulTransformedLabel)) + return rewriter.notifyMatchFailure(matmulOp, + "has already been transformed."); + if (isa(matmulOp->getParentOp())) + return rewriter.notifyMatchFailure( + matmulOp, "has already been tiled by another pass."); + + auto cluster = findMapFusionCluster(matmulOp); + auto fusionCluster = cluster.operations; + auto *tilingRoot = cluster.root; + + // Tiling of linalg.map requires two dimensions, linalg.matmul requires + // three. + SmallVector parallelDimsTileSizes{lhsParallelDimTileSize, + rhsParallelDimTileSize}; + if (isa(tilingRoot)) parallelDimsTileSizes.push_back(0); + + // First level tiling: parallel dimensions. + auto tilingParallelDimsResult = + tileMatmul(rewriter, tilingRoot, parallelDimsTileSizes); + if (failed(tilingParallelDimsResult)) return failure(); + + // Update the results if tiling occurred. + if (tilingParallelDimsResult->loop != nullptr) { + rewriter.replaceOp(tilingRoot, + tilingParallelDimsResult->loop->getResults()); + tilingRoot = tilingParallelDimsResult->tiledOps.front(); + + // Fuse ops into the loop. + fuseGreedily(rewriter, *tilingRoot->getBlock(), + [&](Operation *op) { return fusionCluster.contains(op); }); + (void)fuseFillOpsIntoParallelOp( + rewriter, cast(tilingParallelDimsResult->loop)); + } + + // Second level tiling: reduction dimension for matmuls. + SmallVector tilingReductionDimsResults; + for (auto op : + llvm::to_vector(tilingRoot->getBlock()->getOps())) { + auto result = tileMatmulReductionDims(rewriter, op); + if (failed(result)) return failure(); + tilingReductionDimsResults.push_back(*result); + } + + // Peel parallel loops. + // + // We only want to peel (1) the parallel loop then (2) our kernel. 
+ if (auto loop = + dyn_cast_or_null(tilingParallelDimsResult->loop)) { + auto peelingResult = peelAllLoops(loop, rewriter); + } + + // Peel reduction loop inside the main parallel loop, label the main loop as + // "perfectly tiled" one, to enable vectorization after canonicalization. + for (auto &res : tilingReductionDimsResults) { + if (res.loops.size() == 1) { + auto peelingResult = peelSCFForOp(rewriter, res.loops.front()); + setLabel(peelingResult.mainLoop, kPerfectlyTiledLoopLabel); + } + } + return success(); + } + + private: + FailureOr tileMatmulReductionDims( + PatternRewriter &rewriter, linalg::MatmulOp matmulOp) const { + SmallVector reductionDimsTileSizes{0, 0, reductionDimTileSize}; + scf::SCFTilingOptions opts; + opts.setTileSizes(reductionDimsTileSizes); + auto tilingReductionDimsResult = + scf::tileUsingSCFForOp(rewriter, matmulOp.getOperation(), opts); + if (failed(tilingReductionDimsResult)) return failure(); + + // Update the results if tiling occurred. + if (!tilingReductionDimsResult->loops.empty()) { + rewriter.replaceOp(matmulOp, tilingReductionDimsResult->replacements); + matmulOp = + cast(tilingReductionDimsResult->tiledOps.front()); + } + + setLabel(matmulOp, kMatmulTransformedLabel); + return tilingReductionDimsResult; + } + + int64_t lhsParallelDimTileSize; + int64_t rhsParallelDimTileSize; + int64_t reductionDimTileSize; +}; + +struct TransformMatmulForCpuPass + : public impl::TransformMatmulForCpuPassBase { + TransformMatmulForCpuPass() = default; + + explicit TransformMatmulForCpuPass(llvm::ArrayRef matmulTileSizes, + bool lowerToMmt4DOp) { + tileSizes = matmulTileSizes; + lowerToMmt4D = lowerToMmt4DOp; + } + + void getDependentDialects(DialectRegistry ®istry) const final { + registry.insert(); + linalg::registerTilingInterfaceExternalModels(registry); + tensor::registerTilingInterfaceExternalModels(registry); + tensor::registerInferTypeOpInterfaceExternalModels(registry); + } + + void runOnOperation() override { + func::FuncOp f = getOperation(); + MLIRContext *ctx = &getContext(); + + // Just do tiling and fusion on linalg.matmul. + if (!lowerToMmt4D) { + if (tileSizes.empty()) { + tileSizes = {4, 4, 4}; + } + assert(tileSizes.size() == 3 && + "Tiling sizes for MatMul should have 3 elements"); + RewritePatternSet patterns(ctx); + patterns.add(ctx, tileSizes[0], tileSizes[1], + tileSizes[2]); + if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) { + return signalPassFailure(); + } + // Ensure we drop the marker in the end. + f.walk([](linalg::MatmulOp op) { + removeLabel(op, kMatmulTransformedLabel); + }); + return; + } + + // Lower linalg.matmul to linalg.mmt4d (packed matmul). + { + // Convert linalg.matmul to linalg.mmt4d. + RewritePatternSet patterns(ctx); + patterns.add(ctx); + + // Canonicalization. + tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, ctx); + tensor::EmptyOp::getCanonicalizationPatterns(patterns, ctx); + linalg::FillOp::getCanonicalizationPatterns(patterns, ctx); + + if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) { + return signalPassFailure(); + } + // Ensure we drop the marker in the end. + f.walk([](Operation *op) { + if (isa(op) || isa(op)) + removeLabel(op, kMatmulTransformedLabel); + }); + } + // Tiling pack, unpack and mmt4d ops. + { + RewritePatternSet patterns(ctx); + // We tile towards SIMD codegen, so the tile sizes depend on the target + // architecture (vector instruction sizes, etc.). 
Luckily, this + // information is already captured in linalg.mmt4d during linalg.matmul -> + // linalg.mmt4d lowering phase. It is hardcoded for AVX on x86 for now. + patterns.add(ctx); + + if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) { + return signalPassFailure(); + } + // Ensure we drop the marker in the end. + f.walk( + [](linalg::Mmt4DOp op) { removeLabel(op, kMatmulTransformedLabel); }); + } + // Expanding pack and unpack ops to other primitive tensor/linalg ops and + // canonicalize tiled ops. + { + RewritePatternSet patterns(ctx); + linalg::populateLinalgTilingCanonicalizationPatterns(patterns); + patterns.add(ctx); + patterns.add(ctx); + if (failed(applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + return signalPassFailure(); + } + } + } +}; + +/// Remove memref::CopyOp whose target (can be either a memref::SubViewOp or +/// memref::AllocOp) has no other users. +struct SimplifyDeadCopyPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::CopyOp op, + PatternRewriter &rewriter) const override { + auto valueIt = op.getTarget(); + Operation *onlyNonStoreLikeUser = op; + for (auto subviewOp = valueIt.getDefiningOp(); subviewOp; + onlyNonStoreLikeUser = subviewOp, valueIt = subviewOp.getSource(), + subviewOp = valueIt.getDefiningOp()) { + // TODO(vuson) simplify if other uses are also memref.copy writing to + // subview + // %alloc_4 = memref.alloc() + // %subview_5 = memref.subview %alloc_4 + // %subview_6 = memref.subview %alloc_4 + // memref.copy %arg0, %subview_6 + // memref.copy %arg1, %subview_5 + if (!subviewOp->hasOneUse()) return failure(); + } + + auto hasOnlyStoreLikeUsers = [&](Value alloc) { + return !llvm::any_of(alloc.getUsers(), [&](Operation *op) { + if (op == onlyNonStoreLikeUser) return false; + // TODO(vuson) remove this exception when MemoryEffectOpInterface gets + // corrected for linalg::FillOp. Right now it has MemoryEffects::Read + // while the only thing it ever reads is metadata such as dynamic sizes. 
+ if (isa(op)) return false; + if (auto effect = dyn_cast(op)) { + return effect.getEffectOnValue(alloc) + .has_value() || + !effect.getEffectOnValue(alloc) + .has_value(); + } + return true; + }); + }; + if (!valueIt.getDefiningOp() || + !hasOnlyStoreLikeUsers(valueIt)) + return failure(); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct SimplifyDeadCopyPass + : public impl::SimplifyDeadCopyPassBase { + void runOnOperation() override { + auto func = getOperation(); + auto *ctx = func.getContext(); + + RewritePatternSet patterns(ctx); + patterns.add(ctx); + memref::AllocOp::getCanonicalizationPatterns(patterns, ctx); + if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) { + return signalPassFailure(); + } + } +}; +} // namespace + +std::unique_ptr> +createTransformMatmulForCpuPass() { + return std::make_unique(); +} + +std::unique_ptr> +createTransformMatmulForCpuPass(llvm::ArrayRef matmulTileSizes, + bool lowerToMmt4DOp) { + return std::make_unique( + matmulTileSizes, lowerToMmt4DOp); +} + +std::unique_ptr> createSimplifyDeadCopyPass() { + return std::make_unique(); +} + +} // namespace mlir::gml_st diff --git a/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_reduce_for_cpu.cc b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_reduce_for_cpu.cc new file mode 100644 index 00000000000..fc81040b954 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_reduce_for_cpu.cc @@ -0,0 +1,529 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include + +#include "gml_st/IR/gml_st_ops.h" +#include "gml_st/transforms/fusion/fusion.h" +#include "gml_st/transforms/passes.h" +#include "gml_st/transforms/peeling/peeling.h" +#include "gml_st/transforms/tiling/tiling.h" +#include "gml_st/transforms/transforms.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir::gml_st { +namespace { + +#define GEN_PASS_DEF_TRANSFORMREDUCEFORCPUPASS +#include "gml_st/transforms/passes.h.inc" + +constexpr llvm::StringRef kReduceTransformedLabel = + "__reduce_transformed_label__"; + +FailureOr tileReduce(PatternRewriter &rewriter, + linalg::ReduceOp reduceOp, + ArrayRef tileSizes) { + TilingOptions opts; + opts.setTileSizeComputationFn(tileSizes); + opts.distribute = true; + return tileUsingGmlSt(opts, rewriter, + cast(reduceOp.getOperation())); +} + +SmallVector getParallelDimTileSizes(int64_t reductionDim, + int64_t parallelDimTileSize) { + return reductionDim ? SmallVector{parallelDimTileSize, 0} + : SmallVector{0, parallelDimTileSize}; +} + +SmallVector getReductionDimTileSizes(int64_t reductionDim, + int64_t reductionDimTileSize) { + return reductionDim ? SmallVector{0, reductionDimTileSize} + : SmallVector{reductionDimTileSize, 0}; +} + +LogicalResult validateOp(linalg::ReduceOp reduceOp, PatternRewriter &rewriter, + int64_t expectedRank) { + ArrayRef reduceDimensions = reduceOp.getDimensions(); + if (reduceDimensions.size() != 1) { + return rewriter.notifyMatchFailure( + reduceOp, "expects 1 reduction dimension element. 0 or > 1 received."); + } + OpOperandVector operands = reduceOp.getDpsInputOperands(); + if (operands.size() != 1) { + return rewriter.notifyMatchFailure(reduceOp, + "expects 1 operand. 0 or > 1 received."); + } + const int64_t operandRank = + operands[0]->get().getType().cast().getRank(); + if (operandRank != expectedRank) { + return rewriter.notifyMatchFailure(reduceOp, [&](::mlir::Diagnostic &diag) { + diag << "expects rank " << expectedRank << ". " << operandRank + << "received."; + }); + } + return success(); +} + +struct Reduce1DTransformPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + explicit Reduce1DTransformPattern(MLIRContext *context, int64_t vectorSize, + int64_t tileSize, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + vectorSize(vectorSize), + tileSize(tileSize) {} + + LogicalResult matchAndRewrite(linalg::ReduceOp reduceOp, + PatternRewriter &rewriter) const override { + if (hasLabel(reduceOp, kReduceTransformedLabel)) { + return rewriter.notifyMatchFailure(reduceOp, + "has already been transformed."); + } + + if (isa(reduceOp->getParentOp())) { + return rewriter.notifyMatchFailure( + reduceOp, "has already been tiled by another pass."); + } + + if (failed(validateOp(reduceOp, rewriter, /*expectedRank=*/1))) + return failure(); + + Location loc = reduceOp.getLoc(); + + // Constants. 
+ Value zero = rewriter.create(loc, 0); + Value tileSizeValue = + rewriter.create(loc, tileSize); + + // Input. + Value input = reduceOp.getInputs().front(); + Value inputSize = rewriter.create(loc, input, 0); + + // Loop boundaries. + // tileableBound = inputSize - inputSize % tileSize + // remainderSize = inputSize - tileableBound + Value tileableBound = getTileableBound(rewriter, loc, inputSize); + Value remainderSize = + getRemainderSize(rewriter, loc, tileableBound, inputSize); + + // 0-d tensor with the neutral elements. + auto fillOp = reduceOp.getInits().front().getDefiningOp(); + if (!fillOp) return failure(); + auto neutralValue = fillOp.value(); + + // fillOp.getValue(); + Type elementType = neutralValue.getType(); + + // Create tensor with neutral elements for tile loop + // init. + Value emptyVector = rewriter.create( + loc, llvm::ArrayRef({vectorSize}), elementType); + Value filledVector = + rewriter.create(loc, neutralValue, emptyVector) + .getResult(0); + + auto tiledLoopBodyBuilder = [&](OpBuilder &b, Location loc, Value iv, + ValueRange inits) { + // Tile input as tensor and reshape into + // tensor<(TILE_SIZE/VECTOR_SIZE)xVECTOR_SIZExELEM_TYPE>. + Value inputSlice = tileAndReshapeInput(b, loc, iv, input, elementType); + + tensor::ExtractSliceOp initSlice = create1DSlice( + b, loc, inits.front(), b.getIndexAttr(0), b.getIndexAttr(vectorSize)); + + // Create `linalg.reduce` to combine + // `tensor<(TILE_SIZE/VECTOR_SIZE)xVECTOR_SIZExELEM_TYPE> input with the + // `tensor` accumulator. + auto tiledReduceOp = b.create( + loc, ValueRange{inputSlice}, ValueRange{initSlice}, + /*dimensions=*/SmallVector{0}, + /*bodyBuilder=*/nullptr, linalg::getPrunedAttributeList(reduceOp)); + OpBuilder::InsertionGuard g(rewriter); + Region ®ion = tiledReduceOp.getRegion(); + rewriter.cloneRegionBefore(reduceOp.getRegion(), region, region.end()); + setLabel(tiledReduceOp, kReduceTransformedLabel); + + b.create(loc, tiledReduceOp.getResults()); + }; + + // Create a tiled loop + auto tiledLoop = + rewriter.create(loc, zero, tileableBound, tileSizeValue, + filledVector, tiledLoopBodyBuilder); + setLabel(tiledLoop, kPerfectlyTiledLoopLabel); + + // Create `linalg.reduce` from tensor to + // tensor. + auto horizontalReduce = + cloneReduceOp(rewriter, reduceOp, tiledLoop.getResult(0), + reduceOp.getInits().front()); + + auto remainderLoopBodyBuilder = [&](OpBuilder &b, Location loc, Value iv, + ValueRange inits) { + Value inputSlice = create1DSlice(b, loc, input, iv, remainderSize); + + Value initSlice = b.create( + loc, inits.front(), /*offsets=*/SmallVector{}, + /*sizes=*/SmallVector{}, + /*strides=*/SmallVector{}); + + auto newReduce = cloneReduceOp(b, reduceOp, inputSlice, initSlice); + b.create(loc, newReduce); + }; + + // Combine `horizontal reduce` with the tail of the input. The tail is + // always smaller than TILE_SIZE. 
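+    // Worked example (hypothetical sizes): for inputSize = 70 and
+    // tileSize = 32, tileableBound = 70 - 70 % 32 = 64, so the main tiled
+    // loop reduces elements [0, 64) and this remainder loop folds the
+    // trailing 6 elements into the horizontally reduced result.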
+ auto remainderLoop = + rewriter + .create(loc, tileableBound, inputSize, tileSizeValue, + horizontalReduce, remainderLoopBodyBuilder) + .getResult(0); + + rewriter.replaceOp(reduceOp, remainderLoop); + + return success(); + } + + private: + Value getTileableBound(OpBuilder &b, Location loc, Value inputSize) const { + if (tileSize == 1) return inputSize; + + auto inputSizeInt = getConstantIntValue(inputSize); + if (inputSizeInt && *inputSizeInt % tileSize == 0) return inputSize; + + AffineExpr sym0; + bindSymbols(b.getContext(), sym0); + + auto modMap = AffineMap::get(0, 1, {sym0 - sym0 % tileSize}); + return b.createOrFold(loc, modMap, ValueRange{inputSize}); + } + + Value getRemainderSize(OpBuilder &b, Location loc, Value tileableBound, + Value inputSize) const { + AffineExpr sym0, sym1; + bindSymbols(b.getContext(), sym0, sym1); + auto diffMap = AffineMap::get(0, 2, {sym1 - sym0}); + return b.create(loc, diffMap, + ValueRange{tileableBound, inputSize}); + } + + tensor::ExtractSliceOp create1DSlice(OpBuilder &b, Location loc, Value source, + OpFoldResult offset, + OpFoldResult size) const { + SmallVector offsets{offset}; + SmallVector sizes{size}; + SmallVector strides{b.getIndexAttr(1)}; + + return b.create(loc, source, offsets, sizes, + strides); + } + + Value cloneReduceOp(OpBuilder &b, linalg::ReduceOp reduceOp, + ValueRange newInputs, Value newInit) const { + IRMapping bvm; + bvm.map(reduceOp.getInputs(), newInputs); + bvm.map(reduceOp.getInits(), ValueRange{newInit}); + + auto *newReduceOp = b.clone(*reduceOp.getOperation(), bvm); + setLabel(newReduceOp, kReduceTransformedLabel); + return newReduceOp->getResult(0); + } + + Value tileAndReshapeInput(OpBuilder &b, Location loc, Value iv, Value input, + Type elementType) const { + Value inputSlice = + create1DSlice(b, loc, input, iv, b.getIndexAttr(tileSize)); + + auto reshapeType = + RankedTensorType::get({tileSize / vectorSize, vectorSize}, elementType); + SmallVector ri = {{0, 1}}; + return b.create(loc, reshapeType, inputSlice, ri); + } + + int64_t vectorSize; + int64_t tileSize; +}; + +/// Pattern to tile `linalg.reduce` and fuse `linalg.fill` into generated +/// `gml_st.parallel`. +struct Reduce2DTransformPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + explicit Reduce2DTransformPattern(MLIRContext *context, + int64_t parallelDimTileSize = 4, + int64_t reductionDimTileSize = 2, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + parallelDimTileSize(parallelDimTileSize), + reductionDimTileSize(reductionDimTileSize) {} + + LogicalResult matchAndRewrite(linalg::ReduceOp reduceOp, + PatternRewriter &rewriter) const override { + if (hasLabel(reduceOp, kReduceTransformedLabel)) { + return rewriter.notifyMatchFailure(reduceOp, + "has already been transformed."); + } + + if (failed(validateOp(reduceOp, rewriter, /*expectedRank=*/2))) + return failure(); + + auto cluster = getFusionCluster(reduceOp); + auto fusionCluster = cluster.operations; + auto *tilingRoot = cluster.root; + if (!isa(tilingRoot) && !isa(tilingRoot)) { + return rewriter.notifyMatchFailure( + tilingRoot, + "Expected MapOp or ReduceOp as a root of fusion cluster."); + } + + // First level tiling: parallel dimension. + auto tilingParallelDimsResult = + tileParallelDimensions(tilingRoot, rewriter); + if (failed(tilingParallelDimsResult)) return failure(); + + // Update the results if tiling occurred. 
+ rewriter.replaceOp(tilingRoot, + tilingParallelDimsResult->loop->getResults()); + tilingRoot = (tilingParallelDimsResult->tiledOps.front()); + + // Fuse greedily into root op. + fuseGreedily(rewriter, *tilingRoot->getBlock(), + [&](Operation *op) { return fusionCluster.contains(op); }); + (void)fuseFillOpsIntoParallelOp( + rewriter, cast(tilingParallelDimsResult->loop)); + + // Process all reduces in a fusion cluster. + for (auto tiledReduceOp : + llvm::to_vector(tilingRoot->getBlock()->getOps())) { + // Second level tiling: reduction dimension. + auto tilingReductionDimsResult = + tileReductionDims(rewriter, tiledReduceOp); + if (failed(tilingReductionDimsResult)) return failure(); + + // Update the results if tiling occurred. + if (!tilingReductionDimsResult->loops.empty()) { + rewriter.replaceOp(tiledReduceOp, + tilingReductionDimsResult->replacements); + tiledReduceOp = + cast(tilingReductionDimsResult->tiledOps.front()); + fuseGreedily(rewriter, *tiledReduceOp->getBlock(), + [&](Operation *op) { return isa(op); }); + } + setLabel(tiledReduceOp, kReduceTransformedLabel); + + // Peel reduction loops. + if (failed(peelReduction(rewriter, tilingParallelDimsResult.value(), + tilingReductionDimsResult.value()))) + return failure(); + } + + return success(); + } + + private: + // Find a cluster of operations that can be tiled and fused together around + // the root op. + FusionCluster getFusionCluster(linalg::ReduceOp reduceOp) const { + // Find a chain of MapOp users and use the last one as a root of cluster. + DenseSet resultOps; + Operation *rootOp = reduceOp.getOperation(); + + while (true) { + auto users = llvm::to_vector(rootOp->getUsers()); + + if (users.size() != 1) break; + if (!isa(users[0])) break; + resultOps.insert(rootOp); + + rootOp = users[0]; + } + + // Run DFS to find all MapOps, TransposeOps, BroadcastOps that can be fused + // in the root op. 
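+  // Illustrative example (hypothetical IR): for the chain
+  //   %b = linalg.broadcast ins(%arg0 ...)
+  //   %m = linalg.map ins(%b, ...)
+  //   %r = linalg.reduce ins(%m ...)
+  //   %out = linalg.map ins(%r ...)
+  // the tiling root is %out, and the code below collects %r, %m and %b so
+  // they can later be fused into the tiled loop.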
+ SmallVector remainingProducers; + remainingProducers.reserve(reduceOp.getDpsInputOperands().size()); + resultOps.insert(reduceOp.getOperation()); + for (auto *operand : reduceOp.getDpsInputOperands()) + remainingProducers.push_back(operand->get().getDefiningOp()); + + while (!remainingProducers.empty()) { + Operation *curOp = remainingProducers.pop_back_val(); + if (!curOp || resultOps.contains(curOp)) continue; + auto linalgOp = dyn_cast(curOp); + if (linalgOp && + isa(curOp)) { + resultOps.insert(curOp); + for (auto *operand : linalgOp.getDpsInputOperands()) + remainingProducers.push_back(operand->get().getDefiningOp()); + } + } + return {resultOps, rootOp}; + } + + FailureOr tileParallelDimensions( + Operation *tilingRoot, PatternRewriter &rewriter) const { + FailureOr tilingParallelDimsResult; + if (auto reduceOp = dyn_cast(tilingRoot)) { + tilingParallelDimsResult = + tileReduce(rewriter, reduceOp, + getParallelDimTileSizes(reduceOp.getDimensions()[0], + parallelDimTileSize)); + } else if (isa(tilingRoot)) { + TilingOptions opts; + opts.setTileSizeComputationFn({parallelDimTileSize}); + opts.distribute = true; + + tilingParallelDimsResult = + tileUsingGmlSt(opts, rewriter, cast(tilingRoot)); + } else { + return failure(); + } + + return tilingParallelDimsResult; + } + + FailureOr tileReductionDims( + PatternRewriter &rewriter, linalg::ReduceOp reduceOp) const { + scf::SCFTilingOptions tilingOptions; + tilingOptions.setTileSizes(getReductionDimTileSizes( + reduceOp.getDimensions()[0], reductionDimTileSize)); + return scf::tileUsingSCFForOp(rewriter, reduceOp.getOperation(), + tilingOptions); + } + + LogicalResult peelReduction( + PatternRewriter &rewriter, const TilingResult &tilingParallelDimsResult, + const scf::SCFTilingResult &tilingReductionDimsResult) const { + // Peel parallel loops. + if (auto loop = + dyn_cast_or_null(tilingParallelDimsResult.loop)) { + auto peelingResult = peelAllLoops(loop, rewriter); + } + + // Peel reduction loop inside the main parallel loop, label the main loop as + // "perfectly tiled" one, to enable vectorization after canonicalization. + if (!tilingReductionDimsResult.loops.empty()) { + scf::ForOp forLoop = tilingReductionDimsResult.loops.front(); + SCFForPeelingResult peelingResult = peelSCFForOp(rewriter, forLoop); + if (peelingResult.mainLoop) { + setLabel(peelingResult.mainLoop, kPerfectlyTiledLoopLabel); + } + + if (!peelingResult.tailLoop) return success(); + // Tile ops in the peeled loop again, to size 1, so they can be + // scalarized. 
+ scf::ForOp peeledLoop = peelingResult.tailLoop; + auto yieldOp = cast(peeledLoop.getBody()->getTerminator()); + auto reduceOp = getRootReduce(yieldOp); + if (!reduceOp) return failure(); + + scf::SCFTilingOptions opts; + opts.setTileSizes( + getReductionDimTileSizes(reduceOp.getDimensions()[0], 1)); + + if (failed(tileUsingSCFForOpAndFuseGreedily( + rewriter, reduceOp, opts, kReduceTransformedLabel, + [&](Operation *op) { return isa(op); }))) + return failure(); + } + return success(); + } + + linalg::ReduceOp getRootReduce(scf::YieldOp yieldOp) const { + if (yieldOp.getResults().size() != 1) return nullptr; + + Value reduceResult = yieldOp.getResults().front(); + if (auto insertSliceOp = + reduceResult.getDefiningOp()) { + reduceResult = insertSliceOp.getSource(); + } + return reduceResult.getDefiningOp(); + } + + int64_t parallelDimTileSize; + int64_t reductionDimTileSize; +}; + +struct TransformReduceForCpuPass + : public impl::TransformReduceForCpuPassBase { + TransformReduceForCpuPass() = default; + + explicit TransformReduceForCpuPass(int64_t reduceVectorSize = 8, + int64_t reduceTileSize1D = 32, + ArrayRef reduceTileSizes2D = {}) { + vectorSize = reduceVectorSize; + tileSize1D = reduceTileSize1D; + tileSizes2D = reduceTileSizes2D; + } + + void getDependentDialects(DialectRegistry ®istry) const final { + registry.insert(); + linalg::registerTilingInterfaceExternalModels(registry); + } + + void runOnOperation() override { + func::FuncOp f = getOperation(); + MLIRContext *ctx = &getContext(); + + if (tileSizes2D.empty()) { + tileSizes2D = {4, 2}; + } + + assert(tileSizes2D.size() == 2 && + "Tiling sizes for Reduce should have 2 element."); + + RewritePatternSet patterns(ctx); + patterns.add(ctx, vectorSize, tileSize1D); + patterns.add(ctx, tileSizes2D[0], tileSizes2D[1]); + + if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) { + return signalPassFailure(); + } + + // Ensure we drop the marker in the end. + f.walk([](linalg::ReduceOp reduceOp) { + removeLabel(reduceOp, kReduceTransformedLabel); + }); + } +}; + +} // namespace + +std::unique_ptr> +createTransformReduceForCpuPass(int64_t vectorSize, int64_t tileSize1D, + ArrayRef tileSizes2D) { + return std::make_unique( + vectorSize, tileSize1D, tileSizes2D); +} + +} // namespace mlir::gml_st diff --git a/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_reverse_for_cpu.cc b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_reverse_for_cpu.cc new file mode 100644 index 00000000000..3c183e1b018 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_reverse_for_cpu.cc @@ -0,0 +1,162 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "gml_st/IR/gml_st_ops.h" +#include "gml_st/transforms/passes.h" +#include "gml_st/transforms/peeling/peeling.h" +#include "gml_st/transforms/tiling/tiling.h" +#include "gml_st/transforms/transforms.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "thlo/IR/thlo_ops.h" + +namespace mlir::gml_st { +namespace { + +#define GEN_PASS_DEF_TRANSFORMREVERSEFORCPUPASS +#include "gml_st/transforms/passes.h.inc" + +constexpr llvm::StringRef kReverseTransformedLabel = + "__reverse_transformed_label__"; + +FailureOr tileReverseAndUpdateResultIfTiled( + PatternRewriter &rewriter, thlo::ReverseOp &reverseOp, + ArrayRef tileSizes, bool distribute) { + TilingOptions opts; + opts.setTileSizeComputationFn(tileSizes); + opts.distribute = distribute; + auto tilingResult = tileUsingGmlSt( + opts, rewriter, cast(reverseOp.getOperation())); + + if (failed(tilingResult)) return failure(); + + // Update the results if tiling occurred. + if (tilingResult->loop != nullptr) { + rewriter.replaceOp(reverseOp, tilingResult->loop->getResults()); + reverseOp = cast(tilingResult->tiledOps.front()); + } + + return tilingResult; +} + +SmallVector getTileSizes(int64_t rank, int64_t vectorSize, + bool tileToScalarize) { + SmallVector sizes(rank, 1); + if (!tileToScalarize) sizes[rank - 1] = vectorSize; + return sizes; +} + +/// Pattern to tile `thlo.reverse`. +struct ReverseTransformPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + explicit ReverseTransformPattern(MLIRContext *context, int64_t vectorSize, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + vectorSize(vectorSize) {} + + LogicalResult matchAndRewrite(thlo::ReverseOp reverseOp, + PatternRewriter &rewriter) const override { + if (hasLabel(reverseOp, kReverseTransformedLabel)) + return rewriter.notifyMatchFailure(reverseOp, + "has already been transformed."); + if (isa(reverseOp->getParentOp())) + return rewriter.notifyMatchFailure( + reverseOp, "has already been tiled by another pass."); + + // Parallel dimension tiling. Tiling will be of the form + // 1x1x..x1xVectorSize. + int64_t rank = reverseOp.getInput().getType().getRank(); + auto tilingResult = tileReverseAndUpdateResultIfTiled( + rewriter, reverseOp, getTileSizes(rank, vectorSize, false), + /*distribute=*/true); + + // Peel parallel loop. + if (auto loop = dyn_cast_or_null(tilingResult->loop)) { + auto peelingResult = peelAllLoops(loop, rewriter); + + // If last dim is to be reversed. + if (llvm::is_contained(reverseOp.getReverseDimensions(), rank - 1)) { + // If we have a remaining loop, we tile this to sizes of 1. 
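+        // Illustrative example (hypothetical shape): reversing a
+        // tensor<10xf32> with vectorSize = 8 leaves a peeled tail loop
+        // covering the last two elements; the walk below re-tiles the
+        // reverse inside that tail to size 1 so it can later be scalarized.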
+ for (auto *remParLoop : peelingResult.tailLoops) { + remParLoop->walk([&](Operation *childOp) { + if (isa(childOp)) { + auto innerReverseOp = dyn_cast(*childOp); + auto secondTiling = tileReverseAndUpdateResultIfTiled( + rewriter, innerReverseOp, + getTileSizes(rank, vectorSize, true), + /*distribute=*/true); + setLabel(innerReverseOp, kReverseTransformedLabel); + } + }); + } + } + } + + setLabel(reverseOp, kReverseTransformedLabel); + return success(); + } + + private: + int64_t vectorSize; +}; + +struct TransformReverseForCpuPass + : public impl::TransformReverseForCpuPassBase { + TransformReverseForCpuPass() = default; + + explicit TransformReverseForCpuPass(int64_t reverseVectorSize = 8) { + vectorSize = reverseVectorSize; + } + + void getDependentDialects(DialectRegistry ®istry) const final { + registry.insert(); + linalg::registerTilingInterfaceExternalModels(registry); + } + + void runOnOperation() override { + func::FuncOp f = getOperation(); + MLIRContext *ctx = &getContext(); + + RewritePatternSet patterns(ctx); + patterns.add(ctx, vectorSize); + + if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) { + return signalPassFailure(); + } + + // Ensure we drop the marker in the end. + f.walk([](thlo::ReverseOp reverseOp) { + removeLabel(reverseOp, kReverseTransformedLabel); + }); + } +}; + +} // namespace + +std::unique_ptr> +createTransformReverseForCpuPass(int64_t vectorSize) { + return std::make_unique(vectorSize); +} + +} // namespace mlir::gml_st diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/transform_scatter_for_cpu.cc b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_scatter_for_cpu.cc similarity index 52% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/transform_scatter_for_cpu.cc rename to tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_scatter_for_cpu.cc index 34c9097cb62..31c3aecafbb 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/transform_scatter_for_cpu.cc +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_scatter_for_cpu.cc @@ -16,61 +16,87 @@ limitations under the License. 
#include #include -#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h" -#include "mlir-hlo/Dialect/gml_st/transforms/passes.h" -#include "mlir-hlo/Dialect/gml_st/transforms/tiling.h" -#include "mlir-hlo/Dialect/gml_st/transforms/tiling_interface_impl.h" -#include "mlir-hlo/Dialect/thlo/IR/thlo_ops.h" +#include "gml_st/IR/gml_st_ops.h" +#include "gml_st/transforms/passes.h" +#include "gml_st/transforms/tiling/tiling.h" +#include "gml_st/transforms/transforms.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "thlo/IR/thlo_ops.h" namespace mlir::gml_st { namespace { #define GEN_PASS_DEF_TRANSFORMSCATTERFORCPUPASS -#include "mlir-hlo/Dialect/gml_st/transforms/passes.h.inc" +#include "gml_st/transforms/passes.h.inc" -struct TransformScatterForCpuPass - : public impl::TransformScatterForCpuPassBase { - void getDependentDialects(DialectRegistry ®istry) const final { - registry.insert(); - mlir::gml_st::registerGmlStTilingInterfaceExternalModels(registry); - } +constexpr llvm::StringRef kScatterTransformedLabel = + "__scatter_transformed_label__"; - void runOnOperation() override { - func::FuncOp f = getOperation(); - MLIRContext *ctx = &getContext(); +struct TileScatterPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(thlo::ScatterOp op, + PatternRewriter &rewriter) const override { + if (hasLabel(op, kScatterTransformedLabel)) return failure(); - mlir::gml_st::TilingOptions opts; - opts.distribute = false; // Tile to `for` loops. + if (isa(op->getParentOp())) { + return rewriter.notifyMatchFailure( + op, "has already been tiled by another pass."); + } // Tile everything to points. - opts.tileSizeComputationFn = [](OpBuilder &b, Operation *op) { + scf::SCFTilingOptions opts; + opts.setTileSizeComputationFunction([](OpBuilder &b, Operation *op) { OpBuilder::InsertionGuard guard(b); b.setInsertionPointToStart( &op->getParentOfType().getBody().front()); - auto loops = cast(op).getLoopIteratorTypes(); + auto loops = cast(op).getLoopIteratorTypes(); return SmallVector( loops.size(), b.create(op->getLoc(), 1)); - }; + }); + + auto tilingResult = scf::tileUsingSCFForOp( + rewriter, cast(op.getOperation()), opts); + if (failed(tilingResult)) return failure(); + + // If we did not tile, do not replace original op and just mark it as + // transformed then return. + if (!tilingResult->loops.empty()) { + rewriter.replaceOp(op, tilingResult->replacements); + } + setLabel(tilingResult->tiledOps.front(), kScatterTransformedLabel); + return success(); + } +}; - auto filterFn = [&](Operation *op) { - if (isa(op)) - return success(); - return failure(); - }; +struct TransformScatterForCpuPass + : public impl::TransformScatterForCpuPassBase { + void getDependentDialects(DialectRegistry ®istry) const final { + registry.insert(); + } + + void runOnOperation() override { + func::FuncOp f = getOperation(); + MLIRContext *ctx = &getContext(); RewritePatternSet patterns(ctx); - populateTilingPatterns(ctx, filterFn, opts, &patterns); + patterns.add(ctx); - if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) { + if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) return signalPassFailure(); - } + + // Ensure we drop the marker in the end. 
+ f.walk([](thlo::ScatterOp scatterOp) { + removeLabel(scatterOp, kScatterTransformedLabel); + }); } }; diff --git a/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_sort_for_cpu.cc b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_sort_for_cpu.cc new file mode 100644 index 00000000000..b7dfee3a992 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_sort_for_cpu.cc @@ -0,0 +1,118 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "gml_st/IR/gml_st_ops.h" +#include "gml_st/transforms/passes.h" +#include "gml_st/transforms/tiling/tiling.h" +#include "gml_st/transforms/transforms.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "thlo/IR/thlo_ops.h" + +namespace mlir::gml_st { +namespace { + +#define GEN_PASS_DEF_TRANSFORMSORTFORCPUPASS +#include "gml_st/transforms/passes.h.inc" + +using mlir::arith::ConstantIndexOp; +using mlir::thlo::SortOp; + +constexpr llvm::StringRef kSortTransformedLabel = "__sort_transformed_label__"; + +struct TileSortPattern : public OpRewritePattern { + TileSortPattern(MLIRContext *context, TilingOptions options, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + options(std::move(options)) {} + + LogicalResult matchAndRewrite(SortOp op, + PatternRewriter &rewriter) const override { + if (hasLabel(op, kSortTransformedLabel)) return failure(); + + if (isa(op->getParentOp())) + return rewriter.notifyMatchFailure( + op, "has already been tiled by another pass."); + + auto tilingResult = tileUsingGmlSt( + options, rewriter, cast(op.getOperation())); + if (failed(tilingResult)) return failure(); + + // If we did not tile (e.g. when all tile sizes are 0), do not replace + // original op and just mark it as transformed then return. 
+ if (tilingResult->loop != nullptr) { + rewriter.replaceOp(op, tilingResult->loop->getResults()); + } + setLabel(tilingResult->tiledOps.front(), kSortTransformedLabel); + return success(); + } + + private: + TilingOptions options; +}; + +struct TransformSortForCpuPass + : public impl::TransformSortForCpuPassBase { + TransformSortForCpuPass() = default; + + void getDependentDialects(DialectRegistry ®istry) const final { + registry.insert(); + linalg::registerTilingInterfaceExternalModels(registry); + } + + void runOnOperation() override { + auto getTileSize = [&](mlir::OpBuilder b, Operation *op) { + // use tile sizes 1 by default + auto sortOp = llvm::cast(op); + auto size = sortOp.getLoopIteratorTypes().size(); + return SmallVector(size, + b.create(op->getLoc(), 1)); + }; + + func::FuncOp f = getOperation(); + MLIRContext *ctx = &getContext(); + + TilingOptions tilingOptions; + tilingOptions.tileSizeComputationFn = getTileSize; + + RewritePatternSet patterns(ctx); + patterns.add(ctx, tilingOptions); + + if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) { + return signalPassFailure(); + } + + // Ensure we drop the marker in the end. + f.walk([](thlo::SortOp sortOp) { + removeLabel(sortOp, kSortTransformedLabel); + }); + } +}; + +} // namespace + +std::unique_ptr> +createTransformSortForCpuPass() { + return std::make_unique(); +} + +} // namespace mlir::gml_st diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/transform_transpose_for_cpu.cc b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_transpose_for_cpu.cc similarity index 75% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/transform_transpose_for_cpu.cc rename to tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_transpose_for_cpu.cc index 31060bf2cea..bc6703f0d3f 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/transform_transpose_for_cpu.cc +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/cpu_tiling/transform_transpose_for_cpu.cc @@ -16,15 +16,18 @@ limitations under the License. 
#include #include -#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h" -#include "mlir-hlo/Dialect/gml_st/transforms/passes.h" -#include "mlir-hlo/Dialect/gml_st/transforms/tiling.h" -#include "mlir-hlo/Dialect/gml_st/transforms/tiling_interface_impl.h" -#include "mlir-hlo/Dialect/gml_st/transforms/transforms.h" +#include "gml_st/IR/gml_st_ops.h" +#include "gml_st/transforms/fusion/fusion.h" +#include "gml_st/transforms/passes.h" +#include "gml_st/transforms/peeling/peeling.h" +#include "gml_st/transforms/tiling/tiling.h" +#include "gml_st/transforms/transforms.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -32,10 +35,13 @@ namespace mlir::gml_st { namespace { #define GEN_PASS_DEF_TRANSFORMTRANSPOSEFORCPUPASS -#include "mlir-hlo/Dialect/gml_st/transforms/passes.h.inc" +#include "gml_st/transforms/passes.h.inc" using mlir::arith::ConstantIndexOp; +static constexpr llvm::StringRef kTransposeTransformedLabel = + "__transpose_transformed_label__"; + struct TileTransposePattern : public OpRewritePattern { TileTransposePattern(MLIRContext *context, TilingOptions options, PatternBenefit benefit = 1) @@ -44,10 +50,14 @@ struct TileTransposePattern : public OpRewritePattern { LogicalResult matchAndRewrite(linalg::TransposeOp op, PatternRewriter &rewriter) const override { - if (hasTransformationAttr(op)) return failure(); + if (hasLabel(op, kTransposeTransformedLabel)) return failure(); + + if (isa(op->getParentOp())) + return rewriter.notifyMatchFailure( + op, "has already been tiled by another pass."); - auto tilingResult = - tile(options, rewriter, cast(op.getOperation())); + auto tilingResult = tileUsingGmlSt( + options, rewriter, cast(op.getOperation())); if (failed(tilingResult)) return failure(); // If we did not tile (e.g. when all tile sizes are 0), do not replace @@ -55,7 +65,22 @@ struct TileTransposePattern : public OpRewritePattern { if (tilingResult->loop != nullptr) { rewriter.replaceOp(op, tilingResult->loop->getResults()); } - setTransformationAttr(rewriter, tilingResult->tiledOp); + setLabel(tilingResult->tiledOps.front(), kTransposeTransformedLabel); + + // Peel parallel loops, label the main loop as "perfectly tiled" one, to + // enable vectorization after canonicalization. + if (auto loop = dyn_cast_or_null(tilingResult->loop)) { + auto peelingResult = peelAllLoops(loop, rewriter); + setLabel(loop, kPerfectlyTiledLoopLabel); + + // Tile ops in the peeled loop again, to size 1, so they can be + // scalarized. + if (failed(tilePeeledOpsToScalars(rewriter, peelingResult, + kTransposeTransformedLabel, + /*fuseFilterFn=*/nullptr))) + return failure(); + } + return success(); } @@ -75,7 +100,7 @@ struct TransformTransposeForCpuPass void getDependentDialects(DialectRegistry ®istry) const final { registry.insert(); - registerGmlStTilingInterfaceExternalModels(registry); + linalg::registerTilingInterfaceExternalModels(registry); } void runOnOperation() override { @@ -130,7 +155,9 @@ struct TransformTransposeForCpuPass } // Ensure we drop the marker in the end. 
- func.walk([](linalg::TransposeOp op) { removeTransformationAttr(op); }); + func.walk([](linalg::TransposeOp op) { + removeLabel(op, kTransposeTransformedLabel); + }); } }; diff --git a/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/fusion/fusion.cc b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/fusion/fusion.cc new file mode 100644 index 00000000000..9bcf4e4dce4 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/fusion/fusion.cc @@ -0,0 +1,537 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "gml_st/transforms/fusion/fusion.h" + +#include +#include +#include + +#include "gml_st/IR/gml_st_ops.h" +#include "gml_st/transforms/passes.h" +#include "gml_st/transforms/peeling/peeling.h" +#include "gml_st/transforms/transforms.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "thlo/IR/thlo_ops.h" + +namespace mlir { +namespace gml_st { +namespace { + +#define GEN_PASS_DEF_FUSIONPASS +#include "gml_st/transforms/passes.h.inc" + +// TODO(frgossen): Move this to the shape reification pass. +struct DimOpFissionPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::ExtractOp extract, + PatternRewriter& rewriter) const override { + auto shapeDef = llvm::dyn_cast_or_null( + extract.getTensor().getDefiningOp()); + if (!shapeDef || extract.getIndices().size() != 1) return failure(); + rewriter.replaceOpWithNewOp(extract, shapeDef.getArg(), + extract.getIndices().front()); + return success(); + } +}; + +// TODO(frgossen): Implement this through the shape reification interface and +// move this pattern to the shape reification pass. +struct DimOpReificationPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::DimOp op, + PatternRewriter& rewriter) const override { + Operation* def = op.getSource().getDefiningOp(); + if (!def) return failure(); + + // TODO(pifon): Split this pattern into many. + // Case tensor::ExtractSliceOp. + if (auto extractSliceOp = llvm::dyn_cast(def)) { + assert(extractSliceOp->getNumResults() == 1 && "assume single result"); + auto dimConstantIndex = op.getConstantIndex(); + if (!dimConstantIndex.has_value()) return failure(); + + rewriter.replaceOp(op, extractSliceOp.getSizes()[*dimConstantIndex]); + return success(); + } + // Case LinalgOp. 
+ if (auto linalgOp = llvm::dyn_cast(def)) { + if (linalgOp->getNumResults() != 1 || !linalgOp.hasTensorSemantics()) { + return failure(); + } + Value outputOperand = linalgOp.getDpsInitOperand(0)->get(); + rewriter.replaceOpWithNewOp(op, outputOperand, + op.getIndex()); + return success(); + } + + // Case EmptyOp. + if (auto emptyTensorOp = llvm::dyn_cast(def)) { + if (auto indexConstantOp = llvm::dyn_cast_or_null( + op.getIndex().getDefiningOp())) { + int64_t idx = + indexConstantOp.getValue().dyn_cast().getInt(); + OpFoldResult dim = emptyTensorOp.getMixedSizes()[idx]; + Value dimValue; + if (dim.is()) { + dimValue = dim.get(); + } else { + assert(dim.is() && "expected Value or Attribute"); + int64_t dimInt = dim.get().cast().getInt(); + dimValue = + rewriter.create(op.getLoc(), dimInt); + } + assert(dimValue); + rewriter.replaceOp(op, ValueRange{dimValue}); + return success(); + } + } + + // Case ConcatenateOp. + if (auto concat = llvm::dyn_cast(def)) { + rewriter.replaceOpWithNewOp(op, concat.getInit(), + op.getIndex()); + return success(); + } + + // Case DynamicBroadcastInDimOp. + if (auto bcast = llvm::dyn_cast(def)) { + rewriter.replaceOpWithNewOp(op, bcast.getInit(), + op.getIndex()); + return success(); + } + + return failure(); + } +}; + +class FusionPattern : public OpRewritePattern { + public: + FusionPattern(MLIRContext* context, + function_ref filterFn, + mlir::PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + filterFn(filterFn) {} + + LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractSliceOp, + PatternRewriter& rewriter) const override { + assert(filterFn && "expect filter function"); + if (failed(filterFn(extractSliceOp))) + return rewriter.notifyMatchFailure(extractSliceOp, "filtered"); + + // If there is an output argument produced by `linalg.fill`, then we can + // also fuse it into the parallel loop. To check that, we verify that the + // `src` of `extract_slice` is the bbArg of `gml_st.parallel` and that the + // corresponding operand of `gml_st.parallel` is defined by `linalg.fill`. + if (auto bbArg = dyn_cast(extractSliceOp.getSource())) { + if (auto parallelOp = + dyn_cast_or_null(bbArg.getOwner()->getParentOp())) { + Value loopOperand = + parallelOp.getOpOperandForRegionOutputArg(bbArg).get(); + if (loopOperand.getDefiningOp()) + return fuseFillOpsIntoParallelOp(rewriter, parallelOp); + } + } + return fuse(rewriter, extractSliceOp); + } + + private: + function_ref filterFn; +}; + +struct FusionPass : public impl::FusionPassBase { + FusionPass(StringRef producer, StringRef consumer) { + this->producerLabel = producer.str(); + this->consumerLabel = consumer.str(); + } + + void getDependentDialects(DialectRegistry& registry) const final { + registry.insert(); + } + + void runOnOperation() final { + MLIRContext* ctx = &getContext(); + + auto filterFn = [&](tensor::ExtractSliceOp op) { + Operation* producerOp = op.getSource().getDefiningOp(); + if (auto bbArg = dyn_cast(op.getSource())) { + if (isa(bbArg.getOwner()->getParentOp())) return success(); + } + if (!producerOp || (!producerLabel.empty() && + !hasMatchingLabel(producerOp, producerLabel))) { + return failure(); + } + + Operation* consumerOp = nullptr; + if (!consumerLabel.empty()) { + for (Operation* user : op.getResult().getUsers()) { + if (hasMatchingLabel(user, consumerLabel)) { + consumerOp = user; + break; + } + } + return success(consumerOp != nullptr); + } + + return success(); + }; + + // Populate patterns. 
+    RewritePatternSet patterns(ctx);
+    populateFusionPatterns(ctx, filterFn, &patterns);
+
+    if (failed(applyPatternsAndFoldGreedily(getOperation(),
+                                            std::move(patterns)))) {
+      return signalPassFailure();
+    }
+  }
+};
+
+bool isEqualOp(const Operation* lhsC, const Operation* rhsC) {
+  return OperationEquivalence::isEquivalentTo(
+      const_cast(lhsC), const_cast(rhsC),
+      OperationEquivalence::exactValueMatch,
+      /*markEquivalent=*/nullptr, OperationEquivalence::IgnoreLocations);
+}
+
+template 
+void eliminateEqualOps(PatternRewriter& rewriter, Block& block) {
+  SmallVector uniqueOps;
+  for (auto op : llvm::make_early_inc_range(block.getOps())) {
+    auto* it = llvm::find_if(
+        uniqueOps, [&](OpTy uniqueOp) { return isEqualOp(uniqueOp, op); });
+    if (it == uniqueOps.end()) {
+      uniqueOps.push_back(op);
+    } else {
+      rewriter.replaceOp(op, it->getResult());
+    }
+  }
+}
+
+void eliminateTriviallyDeadUsers(PatternRewriter& rewriter, Operation* op) {
+  for (auto* user :
+       DenseSet(op->getUsers().begin(), op->getUsers().end())) {
+    if (isOpTriviallyDead(user)) rewriter.eraseOp(user);
+  }
+}
+
+void reifyDimOp(PatternRewriter& rewriter, tensor::DimOp dimOp) {
+  auto dimValue = dimOp.getSource().template dyn_cast();
+  if (!dimValue) return;
+  auto rankedShapeTypeOp =
+      dyn_cast(dimValue.getOwner());
+  if (!rankedShapeTypeOp) return;
+
+  std::optional dimIndex = dimOp.getConstantIndex();
+  if (!dimIndex) return;
+
+  SmallVector> reifiedResultShapes;
+  if (failed(
+          rankedShapeTypeOp.reifyResultShapes(rewriter, reifiedResultShapes)))
+    return;
+
+  if (reifiedResultShapes.size() != rankedShapeTypeOp->getNumResults()) return;
+
+  unsigned resultNumber = dimValue.getResultNumber();
+  auto sourceType = dimValue.getType().dyn_cast();
+  if (reifiedResultShapes[resultNumber].size() !=
+      static_cast(sourceType.getRank()))
+    return;
+
+  rewriter.replaceOp(dimOp, reifiedResultShapes[resultNumber][*dimIndex]);
+}
+
+void reifyDimOpsUsers(PatternRewriter& rewriter, Operation* op) {
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointAfter(op);
+
+  for (auto* user : llvm::make_early_inc_range(op->getUsers())) {
+    auto dimOp = dyn_cast(user);
+    if (dimOp) reifyDimOp(rewriter, dimOp);
+  }
+}
+
+// Iterates over tensor::ExtractSliceOp inside the block, finds a suitable
+// candidate for fusion and fuses it. The fusion candidate should satisfy the
+// filter function and not have uses outside of the block. Fails if nothing
+// can be fused.
+LogicalResult fuseGreedilyOneOpIntoBlock(
+    PatternRewriter& rewriter, Block& block,
+    llvm::function_ref filterFn) {
+  // Ad-hoc CSE to eliminate duplicate MaterializeOp that could have been
+  // added after previous fusions. Running the whole CSE pass would be too
+  // expensive here and unnecessary. Without removing those duplicates, some
+  // ops will be fused multiple times, resulting in exponential code growth.
+  eliminateEqualOps(rewriter, block);
+  eliminateEqualOps(rewriter, block);
+
+  for (auto extractSliceOp : block.getOps()) {
+    auto* fusionCandidate = extractSliceOp.getSource().getDefiningOp();
+    // Do not fuse if there is no defining op. For example, if it's a
+    // materialize from a function argument.
+    if (!fusionCandidate) continue;
+
+    if (filterFn && !filterFn(fusionCandidate)) continue;
+
+    // Ad-hoc DCE to trim the fusion candidate from dead users that could have
+    // been added in the previous fusion cycles. Normally those ops would be
+    // garbage collected after the pattern rewriter driver has finished
+    // working, but here they require manual handling.
+    eliminateTriviallyDeadUsers(rewriter, fusionCandidate);
+
+    // Push tensor.dim ops 'above' the fusion candidate. This is normally done
+    // by canonicalization passes, but running the whole canonicalization
+    // pipeline here is too expensive.
+    reifyDimOpsUsers(rewriter, fusionCandidate);
+
+    // After the previous steps, extractSliceOp should be the only user of the
+    // fusion candidate. Otherwise this candidate should not be fused.
+    auto fusionCandidateUsers = llvm::to_vector(fusionCandidate->getUsers());
+    if (fusionCandidateUsers.size() != 1 ||
+        fusionCandidateUsers[0] != extractSliceOp)
+      continue;
+
+    if (succeeded(fuse(rewriter, extractSliceOp))) {
+      return success();
+    }
+  }
+  return failure();
+}
+
+}  // namespace
+
+FailureOr fuse(PatternRewriter& rewriter,
+               tensor::ExtractSliceOp extractSliceOp) {
+  Location loc = extractSliceOp.getLoc();
+  FailureOr fusedOr = createFusedOp(rewriter, extractSliceOp);
+  if (failed(fusedOr)) return failure();  // Match failure already notified.
+
+  // Insert cast if needed.
+  Value fused = *fusedOr;
+  if (fused.getType() != extractSliceOp.getType()) {
+    // The result should be a tensor; cast it to the correct shape.
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPointAfter(fused.getDefiningOp());
+    fused =
+        rewriter.create(loc, extractSliceOp.getType(), fused);
+  }
+
+  rewriter.replaceOp(extractSliceOp, fused);
+  return fused.getDefiningOp();
+}
+
+void fuseGreedily(PatternRewriter& rewriter, Block& block,
+                  llvm::function_ref filterFn) {
+  while (succeeded(fuseGreedilyOneOpIntoBlock(rewriter, block, filterFn)))
+    ;
+}
+
+FusionCluster findMapFusionCluster(Operation* op) {
+  // Find the root operation in the chain of elementwise ops. The current
+  // approach doesn't work well if the maps don't form a chain.
+  Operation* rootOp = op;
+  while (true) {
+    auto users = llvm::to_vector(rootOp->getUsers());
+
+    if (users.size() != 1) break;
+    if (!isa(users[0])) break;
+
+    rootOp = users[0];
+  }
+
+  // Run a graph search to find all linalg.map ops that can be fused into
+  // the root op.
+  DenseSet resultOps;
+  SmallVector remainingProducers{rootOp};
+
+  while (!remainingProducers.empty()) {
+    Operation* curOp = remainingProducers.pop_back_val();
+    if (!curOp) continue;
+
+    if (auto mapOp = dyn_cast(curOp)) {
+      resultOps.insert(curOp);
+      for (auto* operand : mapOp.getDpsInputOperands())
+        remainingProducers.push_back(operand->get().getDefiningOp());
+    } else if (curOp->getName() == op->getName()) {
+      for (auto* u : curOp->getUsers()) {
+        // Do not fuse curOp that is used by another op of the same type.
+        if (u->getName() == op->getName()) continue;
+      }
+      resultOps.insert(curOp);
+    }
+  }
+  return {resultOps, rootOp};
+}
+
+LogicalResult fuseFillOpsIntoParallelOp(PatternRewriter& rewriter,
+                                        ParallelOp parallelOp) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPointToStart(parallelOp.getBody());
+  bool fillOpsWereFused = false;
+  for (OpOperand& output :
+       parallelOp->getOpOperands().take_back(parallelOp.getNumOutputs())) {
+    auto fillOp = output.get().getDefiningOp();
+    if (!fillOp) continue;
+
+    fillOpsWereFused = true;
+
+    // Clone `linalg.fill` op inside the loop, update the uses of bbArg.
+ BlockArgument regionOutputArg = + parallelOp.getRegionOutputArgForOpOperand(output); + auto clonedFill = cast( + mlir::clone(rewriter, fillOp, fillOp.getResultTypes(), + {fillOp.value(), regionOutputArg})); + + output.set(fillOp.output()); + + SmallVector sliceOps; + regionOutputArg.replaceUsesWithIf( + clonedFill.getResult(0), [&](OpOperand& operand) { + Operation* owner = operand.getOwner(); + if (auto sliceOp = dyn_cast_or_null(owner)) + sliceOps.push_back(sliceOp); + return owner != clonedFill && !isa(owner) && + owner->getParentOfType() == parallelOp; + }); + + // Use standard fusion logic to swap extract_slice(fill) -> + // fill(extract_slice). + for (tensor::ExtractSliceOp sliceOp : sliceOps) + (void)fuse(rewriter, sliceOp); + } + return success(fillOpsWereFused); +} + +FailureOr tileUsingGmlStParallelAndFuseGreedily( + PatternRewriter& rewriter, Operation* op, + const mlir::gml_st::TilingOptions& opts, StringRef label, + llvm::function_ref fuseFilterFn) { + assert(opts.distribute == true && + "gml_st.for should not be used for CPU pipeline"); + auto tilingResult = tileUsingGmlSt(opts, rewriter, cast(op)); + if (failed(tilingResult)) return failure(); + + // If we did not tile (e.g. when all tile sizes are 0), do not replace + // original op and just mark it as transformed then return. + if (tilingResult->loop != nullptr) { + rewriter.replaceOp(op, tilingResult->loop->getResults()); + + // Fuse ops into the loop. + fuseGreedily(rewriter, *tilingResult->tiledOps.front()->getBlock(), + fuseFilterFn); + } + setLabel(tilingResult->tiledOps.front(), label); + return cast(tilingResult->loop); +} + +FailureOr tileUsingSCFForOpAndFuseGreedily( + PatternRewriter& rewriter, Operation* op, const scf::SCFTilingOptions& opts, + StringRef label, llvm::function_ref fuseFilterFn) { + auto tilingResult = scf::tileUsingSCFForOp(rewriter, op, opts); + if (failed(tilingResult)) return failure(); + + // If we did not tile (e.g. when all tile sizes are 0), do not replace + // original op and just mark it as transformed then return. + if (!tilingResult->loops.empty()) { + rewriter.replaceOp(op, tilingResult->replacements); + + // Fuse ops into the loop. 
+ fuseGreedily(rewriter, *tilingResult->loops.back().getBody(), fuseFilterFn); + } + setLabel(tilingResult->tiledOps.front(), label); + return tilingResult; +} + +LogicalResult tilePeeledOpsToScalars( + PatternRewriter& rewriter, const GmlStPeelingResult& peelingResult, + StringRef label, llvm::function_ref fuseFilterFn) { + for (auto* loop : peelingResult.tailLoops) { + ParallelOp peeledLoop = dyn_cast(loop); + auto* terminatorOp = peeledLoop->getRegion(0).front().getTerminator(); + if (!terminatorOp) return failure(); + + auto* definingOp = terminatorOp->getOperand(0).getDefiningOp(); + if (!definingOp) return failure(); + + mlir::gml_st::TilingOptions opts; + opts.setTileSizeComputationFn(SmallVector( + cast(definingOp).getNumLoops(), 1)); + + if (failed(tileUsingGmlStParallelAndFuseGreedily(rewriter, definingOp, opts, + label, fuseFilterFn))) + return failure(); + } + return success(); +} + +FailureOr createFusedOp(PatternRewriter& rewriter, + tensor::ExtractSliceOp extractSliceOp) { + Value src = extractSliceOp.getSource(); + if (!src) return failure(); + auto tileableOp = src.getDefiningOp(); + if (!tileableOp) { + return rewriter.notifyMatchFailure( + extractSliceOp, + "expected source to be defined by tiling interface op "); + } + + SmallVector offsets = extractSliceOp.getMixedOffsets(); + SmallVector sizes = extractSliceOp.getMixedSizes(); + + // Tile the producer. + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(extractSliceOp); + FailureOr tiledProducer = tileableOp.generateResultTileValue( + rewriter, /*resultNumber=*/0, offsets, sizes); + if (failed(tiledProducer)) { + return rewriter.notifyMatchFailure(tileableOp, + "failed to tile the producer"); + } + + return tiledProducer; +} + +void populateFusionPatterns( + MLIRContext* ctx, + function_ref filterFn, + RewritePatternSet* patterns) { + patterns->insert(ctx, filterFn); + // clang-format off + patterns->insert< + DimOpFissionPattern, + DimOpReificationPattern>(ctx); + // clang-format on +} + +std::unique_ptr> createFusionPass( + StringRef producer, StringRef consumer) { + return std::make_unique(producer, consumer); +} + +} // namespace gml_st +} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/fusion/fusion.h b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/fusion/fusion.h new file mode 100644 index 00000000000..2f7b23bc637 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/fusion/fusion.h @@ -0,0 +1,86 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef MLIR_HLO_GML_ST_TRANSFORMS_FUSION_FUSION_H
+#define MLIR_HLO_GML_ST_TRANSFORMS_FUSION_FUSION_H
+
+#include "gml_st/transforms/peeling/peeling.h"
+#include "gml_st/transforms/tiling/tiling.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace gml_st {
+
+// Creates a fused operation based on the specified subset. The result is
+// equivalent to the given `tensor.extract_slice` op.
+FailureOr createFusedOp(PatternRewriter &rewriter,
+                        tensor::ExtractSliceOp materializeOp);
+
+// Fuses an op into `tensor.extract_slice` and performs the necessary updates
+// to the surrounding loop, if any.
+FailureOr fuse(PatternRewriter &rewriter,
+               tensor::ExtractSliceOp materializeOp);
+
+// Finds `tensor.extract_slice` ops in the block and fuses ops into them.
+// Verifies that the fusion candidate doesn't have any uses except the one
+// `tensor.extract_slice` in the block to avoid exponential code growth.
+void fuseGreedily(PatternRewriter &rewriter, Block &block,
+                  llvm::function_ref filterFn = nullptr);
+
+/// Populate fusion patterns.
+void populateFusionPatterns(
+    MLIRContext *ctx,
+    function_ref filterFn,
+    RewritePatternSet *patterns);
+
+struct FusionCluster {
+  DenseSet operations;
+  Operation *root;
+};
+
+// Finds a cluster of operations that can be tiled and fused together around
+// the root op. We want to fuse the output of the fusion op with elementwise
+// ops. In the general case a cluster is a tree that can have multiple
+// leaf-node ops, e.g. map(op, map(op)).
+// The first element of the cluster is always the root for tiling.
+FusionCluster findMapFusionCluster(Operation *op);
+
+// Fuses a linalg.fill that is used as an output argument of the ParallelOp.
+LogicalResult fuseFillOpsIntoParallelOp(PatternRewriter &rewriter,
+                                        ParallelOp parallelOp);
+
+// Tiles the op to gml_st.parallel and fuses greedily according to the filter.
+FailureOr tileUsingGmlStParallelAndFuseGreedily(
+    PatternRewriter &rewriter, Operation *op,
+    const mlir::gml_st::TilingOptions &opts, StringRef label,
+    llvm::function_ref fuseFilterFn);
+
+// Tiles the op to scf.for and fuses greedily according to the filter.
+FailureOr tileUsingSCFForOpAndFuseGreedily(
+    PatternRewriter &rewriter, Operation *op, const scf::SCFTilingOptions &opts,
+    StringRef label, llvm::function_ref fuseFilterFn);
+
+// Tiles the op to size 1 for all dimensions and fuses greedily according to
+// the filter function.
+LogicalResult tilePeeledOpsToScalars(
+    PatternRewriter &rewriter, const GmlStPeelingResult &peelingResult,
+    StringRef label, llvm::function_ref fuseFilterFn);
+
+}  // namespace gml_st
+}  // namespace mlir
+
+#endif  // MLIR_HLO_GML_ST_TRANSFORMS_FUSION_FUSION_H
diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/gml_st_to_gpu.cc b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/gml_st_simtfy/gml_st_simtfy.cc
similarity index 51%
rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/gml_st_to_gpu.cc
rename to tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/gml_st_simtfy/gml_st_simtfy.cc
index 79c9f7f3807..6795c25b35e 100644
--- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/gml_st_to_gpu.cc
+++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/gml_st_simtfy/gml_st_simtfy.cc
@@ -18,46 +18,38 @@ limitations under the License.
#include #include #include +#include #include +#include "gml_st/IR/gml_st_ops.h" +#include "gml_st/transforms/passes.h" #include "llvm/ADT/STLExtras.h" -#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h" -#include "mlir-hlo/Dialect/gml_st/transforms/passes.h" -#include "mlir-hlo/Dialect/gml_st/transforms/vector_utils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/IRMapping.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/ValueRange.h" #include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#define GEN_PASS_DEF_GMLSTTOGPUPASS -#include "mlir-hlo/Dialect/gml_st/transforms/passes.h.inc" +#define GEN_PASS_DEF_GMLSTSIMTFYPASS +#include "gml_st/transforms/passes.h.inc" using namespace mlir; using namespace mlir::gml_st; using mlir::gpu::LaunchOp; -using mlir::memref::SubViewOp; -using mlir::vector::CombiningKind; -using mlir::vector::ExtractOp; -using mlir::vector::MultiDimReductionOp; -using mlir::vector::TransferReadOp; -using mlir::vector::TransferWriteOp; namespace { + /// Converts a sequence of 3 nested gml_st.parallel ops into a gpu.launch op. -/// Throughout thes pass we will call the first level of nesting "block", the -/// second "warp", and the 3rd "thread" level. The intention is to allude to the -/// fact that these will likely correspond to the CUDA programming concepts of -/// the same name when the IR is lowered to PTX. However, this pass does not -/// make, nor verify all the requirements (e.g., that the warp-level iteration -/// contains exactly 32 steps) for mapping to this level. +/// Throughout this pass we call the first level of nesting "block", the second +/// "warp", and the 3rd "thread" level. The intent is to allude to the fact that +/// these will likely correspond to the CUDA programming concepts of the same +/// name when the IR is lowered to PTX. However, this pass does not make, nor +/// verify all the requirements (e.g., that the warp-level iteration contains +/// exactly 32 steps) for mapping to this level. /// /// Each gml_st.parallel is expected to only have a single induction variable. /// The loops representing the block, warp, and thread level are mapped to @@ -74,58 +66,26 @@ namespace { /// long as they have the same iteration space, i.e., the SSA values defining /// the lower bound, upper bound and the step of all parallels on the same level /// of nesting are the same values. 
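+///
+/// Illustrative sketch only (hypothetical IR, not taken from this patch): an
+/// input of the form
+///
+///   gml_st.parallel (%b) = (%c0) to (%nb) step (%c1)      // block level
+///     gml_st.parallel (%w) = (%c0) to (%nw) step (%c1)    // warp level
+///       gml_st.parallel (%t) = (%c0) to (%c32) step (%c1) { ... }
+///
+/// is flattened into a single gpu.launch whose block and thread dimensions are
+/// derived from the three iteration spaces above.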
-struct ParallelOpToGpuPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ParallelOp root, - PatternRewriter& rewriter) const override; -}; - -struct MultiDimReductionOpToWarpReductionPattern - : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp, - PatternRewriter& rewriter) const override; -}; - -struct EliminateMaterializeOfTransferReadPattern - : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(MaterializeOp materialize, - PatternRewriter& rewriter) const override; -}; - -struct EliminateDistributeIntoTransferWritePattern - : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TransferWriteOp transferWrite, - PatternRewriter& rewriter) const override; -}; +LogicalResult simtfyOp(ParallelOp root, RewriterBase& rewriter); /// Implements the GmlStToGpuPass declared in /// include/mlir-hlo/Dialect/gml_st/transforms/passes.td. -struct GmlStToGpuPass : public ::impl::GmlStToGpuPassBase { +struct GmlStSimtfyPass : public ::impl::GmlStSimtfyPassBase { + using GmlStSimtfyPassBase::GmlStSimtfyPassBase; + void runOnOperation() override { - RewritePatternSet patterns(&getContext()); - patterns - .add(&getContext()); func::FuncOp func = getOperation(); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) - signalPassFailure(); - // Make sure there are no gml_st.parallel ops left. - // TODO(b/254497967): Properly handle a full conversion from GmlSt to GPU - // once SIMTfication is split from conversion of other GmlSt operations. - // For now, only verify that we do not have a ParallelOp, since there - // might still be other gml_st operations that will be removed by a - // subsequent conversion from GmlSt to SCF. - WalkResult walk = func.walk([](ParallelOp op) { - op->emitOpError("failed to simtfy"); - return WalkResult::interrupt(); + IRRewriter rewriter(&getContext()); + + WalkResult walk = func.walk([&](ParallelOp op) { + if (op.getDistributionType().has_value() && + op.getDistributionType().value() == blockDistributionLabel) { + if (failed(simtfyOp(op, rewriter))) { + op->emitOpError("failed to simtfy"); + return WalkResult::interrupt(); + } + } + return WalkResult::skip(); }); if (walk.wasInterrupted()) signalPassFailure(); } @@ -136,7 +96,7 @@ struct GmlStToGpuPass : public ::impl::GmlStToGpuPassBase { /// thread. The idea is to update those later, as we discover the correct values /// from the nesting structure. static LaunchOp createInitialGpuLaunchOp(Location loc, Value defaultSize, - PatternRewriter& rewriter) { + RewriterBase& rewriter) { auto launch = rewriter.create(loc, defaultSize, defaultSize, defaultSize, defaultSize, defaultSize, defaultSize); @@ -186,7 +146,7 @@ static LogicalResult verifyLoopBoundsMatch(Value currentBound, /// `inductionVar`) that result from using the approximated value. static Value handleImperfectTile(Location loc, LaunchOp launch, Value upperBound, Value inductionVar, - PatternRewriter& rewriter) { + RewriterBase& rewriter) { // We are assuming that imperfect tiling is expressed through an affine.min // op with an affine map of the form ()[] -> // (, tileSize), where s possibly depend on values @@ -212,7 +172,7 @@ static Value handleImperfectTile(Location loc, LaunchOp launch, // Create a constant corresponding to the tile size, and return it as the // iteration-independent upper bound. 
- PatternRewriter::InsertionGuard guard(rewriter); + RewriterBase::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(launch); return rewriter.create(loc, tileSize.getValue()); } @@ -221,9 +181,9 @@ static Value handleImperfectTile(Location loc, LaunchOp launch, /// space of `parallel`. Returns an SSA value that is a part of the `launch`'s /// region, and represents the value of `parallel`'s induction variable. static Value matchLaunchSpaceToLoop(ParallelOp parallel, - const BlockAndValueMapping& bvm, + const IRMapping& bvm, unsigned launchIdx, LaunchOp launch, - PatternRewriter& rewriter) { + RewriterBase& rewriter) { Location loc = parallel.getLoc(); Value upperBound = bvm.lookupOrDefault(parallel.getUpperBound().front()); Value lowerBound = bvm.lookupOrDefault(parallel.getLowerBound().front()); @@ -247,7 +207,7 @@ static Value matchLaunchSpaceToLoop(ParallelOp parallel, /*dimCount=*/1, /*symbolCount=*/2, (rewriter.getAffineDimExpr(0) - rewriter.getAffineSymbolExpr(0)) .ceilDiv(rewriter.getAffineSymbolExpr(1))); - PatternRewriter::InsertionGuard guard(rewriter); + RewriterBase::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(launch); launch.setOperand( launchIdx, rewriter.create( @@ -256,37 +216,40 @@ static Value matchLaunchSpaceToLoop(ParallelOp parallel, return inductionVar; } +namespace { // Converts the 3 nested gml_st.parallel ops rooted at `root` into a // gpu.launch op. We do this by creating an empty gpu.launch region and // copying all the operations in gml_st.parallel into that region, // recursively copying the bodies of any nested gml_st.parallel regions that // we encounter. -LogicalResult ParallelOpToGpuPattern::matchAndRewrite( - ParallelOp root, PatternRewriter& rewriter) const { +LogicalResult simtfyOp(ParallelOp root, RewriterBase& rewriter) { + rewriter.setInsertionPoint(root); Location loc = root.getLoc(); - if (root->getParentOfType()) - return rewriter.notifyMatchFailure(root, "should be the root parallel"); - Value defaultSize = rewriter.create(loc, 1); LaunchOp launch = createInitialGpuLaunchOp(loc, defaultSize, rewriter); - BlockAndValueMapping bvm; + constexpr size_t kNumberOfNestedLoops = 3; + + IRMapping bvm; // We need to keep track of which value in the gpu.launch region represents // which level of the induction variable in the nested region. This is because // we might have multiple gml_st.parallel operations on the same level, and // their induction variables should map to the same value in the flattened // gpu.launch region. - SmallVector nestingLevelToInductionVarMap; + SmallVector nestingLevelToInductionVarMap; // This is our stack holding in-flight operations of gml_st.parallel regions // that we started to copy over to the gpu.launch region, but are on hold // while we are processing a nested gml_st.parallel. 
- SmallVector, 3> loopIterators; + SmallVector, kNumberOfNestedLoops> + loopIterators; // This functor implements the processing of a single parallel op: // 1) update of GPU launch bounds according to the interation space // 2) addition of a nesting level to `loopIterators`, with the iterator // over `parallel`'s body + // 3) propagation of the distribution level of the op to all its + // (non-ParallelOp) children if it is not at the innermost nesting level auto processParallelOp = [&](ParallelOp parallel) { unsigned nestingLevel = loopIterators.size(); unsigned inductionVarIdx = 0; @@ -314,10 +277,31 @@ LogicalResult ParallelOpToGpuPattern::matchAndRewrite( parallel, bvm, inductionVarIdx, launch, rewriter)); } + Block* body = parallel.getBody(); + + // Check that the nesting level is not the innermost one (i.e. that we + // are not at the thread level, but either at the block, or at the warp + // level). + if (nestingLevel < kNumberOfNestedLoops - 1) { + if (!parallel.getDistributionType().has_value()) + return rewriter.notifyMatchFailure( + parallel, + "expected parallel operation to define a distribution type"); + + const StringAttr distributionTypeAttr = + parallel.getDistributionTypeAttr(); + body->walk( + [&distributionTypeAttr](Operation* nestedOp) { + if (dyn_cast(nestedOp)) return WalkResult::skip(); + nestedOp->setAttr(kDistributionLabelKey, distributionTypeAttr); + return WalkResult::advance(); + }); + } + bvm.map(parallel.getInductionVars().front(), nestingLevelToInductionVarMap[nestingLevel]); - Block* body = parallel.getBody(); loopIterators.push_back(llvm::make_range(body->begin(), body->end())); + return success(); }; @@ -360,195 +344,11 @@ LogicalResult ParallelOpToGpuPattern::matchAndRewrite( rewriter.eraseOp(root); return success(); } +} // namespace -static Value createCombineOp(Location loc, Value lhs, Value rhs, - CombiningKind kind, PatternRewriter& rewriter) { - auto helper = [&](auto dummy) { - return rewriter.create(loc, lhs, rhs); - }; - switch (kind) { - case CombiningKind::ADD: - return helper(arith::AddFOp()); - case CombiningKind::MUL: - return helper(arith::MulFOp()); - case CombiningKind::MINUI: - return helper(arith::MinUIOp()); - case CombiningKind::MINSI: - return helper(arith::MinSIOp()); - case CombiningKind::MINF: - return helper(arith::MinFOp()); - case CombiningKind::MAXUI: - return helper(arith::MaxUIOp()); - case CombiningKind::MAXSI: - return helper(arith::MaxSIOp()); - case CombiningKind::MAXF: - return helper(arith::MaxFOp()); - case CombiningKind::AND: - return helper(arith::AndIOp()); - case CombiningKind::OR: - return helper(arith::OrIOp()); - case CombiningKind::XOR: - return helper(arith::XOrIOp()); - default: - llvm_unreachable("unhandled"); - } -} - -LogicalResult MultiDimReductionOpToWarpReductionPattern::matchAndRewrite( - MultiDimReductionOp reductionOp, PatternRewriter& rewriter) const { - auto inType = reductionOp.getSourceVectorType(); - int64_t width = inType.getNumElements(); - std::initializer_list supportedWidths = {1, 2, 4, 8, 16, 32}; - if (!llvm::is_contained(supportedWidths, width)) { - return rewriter.notifyMatchFailure( - reductionOp, "expected input vector with size 2^N, <=32"); - } - auto hasOneElement = [](auto type) { - return type && type.getNumElements() == 1; - }; - auto outType = reductionOp.getDestType().dyn_cast(); - if (!hasOneElement(outType)) { - return rewriter.notifyMatchFailure(reductionOp, "expected 1-vector output"); - } - auto distribute = reductionOp.getSource().getDefiningOp(); - if (!distribute) { - 
return rewriter.notifyMatchFailure( - reductionOp, "source not defined by gml_st.distribute"); - } - // Even if this value was not written into the tile corresponding to the - // current thread's lane id, this is fine, since it doesn't matter which - // thread processes which element within a reduction. - TypedValue lhsVector = distribute.getSource(); - if (!hasOneElement(lhsVector.getType())) { - return rewriter.notifyMatchFailure(distribute, "expected 1-vector input"); - } - - // Preamble: extract element from input - Location loc = reductionOp->getLoc(); - Value lhs = rewriter.create( - loc, lhsVector, SmallVector(lhsVector.getType().getRank(), 0)); - - auto createConstant = [&](int32_t value) { - return rewriter.create( - loc, rewriter.getI32IntegerAttr(value)); - }; - Value cWidth = createConstant(width); - // Create warp shuffles of increasing offset and interleave with a clone of - // the accumulate block. - for (int64_t i = 1; i < width; i *= 2) { - auto shuffleOp = rewriter.create( - loc, lhs, createConstant(i), cWidth, gpu::ShuffleMode::XOR); - lhs = createCombineOp(loc, lhs, shuffleOp.getShuffleResult(), - reductionOp.getKind(), rewriter); - } - - // Combine with init element and broadcast result back to vector. - Value acc = rewriter.create(loc, reductionOp.getAcc(), 0); - lhs = createCombineOp(loc, lhs, acc, reductionOp.getKind(), rewriter); - rewriter.replaceOpWithNewOp(reductionOp, outType, lhs); - - return success(); -} - -SubViewOp createSubView(Location loc, Value source, TileOp tile, - PatternRewriter& rewriter) { - auto asIntArray = [](ArrayAttr array) { - return llvm::to_vector(llvm::map_range(array, [](Attribute attr) { - return attr.cast().getInt(); - })); - }; - Type memRefType = SubViewOp::inferResultType( - source.getType().cast(), asIntArray(tile.getStaticOffsets()), - asIntArray(tile.getStaticSizes()), asIntArray(tile.getStaticStrides())); - return rewriter.create( - loc, memRefType, source, tile.getOffsets(), tile.getSizes(), - tile.getStrides(), tile.getStaticOffsets(), tile.getStaticSizes(), - tile.getStaticStrides()); -} - -LogicalResult EliminateMaterializeOfTransferReadPattern::matchAndRewrite( - MaterializeOp materialize, PatternRewriter& rewriter) const { - // Match the following pattern: - // gml_st.materialize( - // vector.transfer_read Memref:$src[(arith.constant 0)...] - // gml_st.tile [$offsets] [$sizes] [$strides]) - auto transferRead = materialize.getSource().getDefiningOp(); - if (!transferRead) { - return rewriter.notifyMatchFailure( - materialize, "expected vector.transfer_read as source"); - } - Value source = transferRead.getSource(); - if (!source.getType().isa()) { - return rewriter.notifyMatchFailure(transferRead, - "expected memref as source"); - } - if (failed(matchSimpleTransferOp(transferRead, rewriter))) return failure(); - - auto tile = materialize.getSet().getDefiningOp(); - if (!tile) { - return rewriter.notifyMatchFailure(materialize, - "expected gml_st.tile as set"); - } - - // Rewrite the pattern as: - // vector.transfer_read - // (memref.subview $src [$offsets] [$sizes] [$strides]) - // [(arith.constant 0)...] - // TODO(b/254271932): This might not be correct if there is someone writing - // to `source` in between `transferRead` and `materialize`. This won't happen - // for elementwise fusion and softmax, but might become a problem down the - // line. 
- auto subview = createSubView(materialize.getLoc(), source, tile, rewriter); - Type resultType = materialize.getResult().getType(); - if (!resultType.isa()) { - // We have a transfer to a single element: just use memref.load directly. - rewriter.replaceOpWithNewOp(materialize, subview, - transferRead.getIndices()); - return success(); - } - rewriter.replaceOpWithNewOp( - materialize, resultType, subview, transferRead.getIndices(), - transferRead.getPermutationMap(), transferRead.getPadding(), - /*mask=*/nullptr, transferRead.getInBounds().value_or(nullptr)); - return success(); -} - -LogicalResult EliminateDistributeIntoTransferWritePattern::matchAndRewrite( - TransferWriteOp transferWrite, PatternRewriter& rewriter) const { - // Match the following pattern: - // vector.transfer_write - // (gml_st.distribute $src into - // [(gml_st.tile [$offsets] [$sizes] [$strides])]) - // Memref:$dst[(arith.constant 0)] - Value destination = transferWrite.getSource(); - if (!destination.getType().isa()) { - return rewriter.notifyMatchFailure(transferWrite, - "expected memref as destination"); - } - if (failed(matchSimpleTransferOp(transferWrite, rewriter))) return failure(); - - auto distribute = transferWrite.getVector().getDefiningOp(); - if (!distribute) { - return rewriter.notifyMatchFailure(transferWrite, - "expected distribute as source"); - } - Value source = distribute.getSource(); - - auto tile = distribute.getSet().getDefiningOp(); - if (!tile) { - return rewriter.notifyMatchFailure(distribute, - "expected gml_st.tile as set"); - } - - // Rewrite the pattern as: - // vector.transfer_write $src, - // (memref.subview $dst [$offsets] [$sizes] [$strides]) - // [(arith.constant 0)...] - auto subview = - createSubView(transferWrite.getLoc(), destination, tile, rewriter); - rewriter.replaceOpWithNewOp( - transferWrite, /*resultType=*/llvm::None, source, subview, - transferWrite.getIndices(), transferWrite.getPermutationMap(), - /*mask=*/nullptr, transferWrite.getInBounds().value_or(nullptr)); - return success(); +std::unique_ptr> +mlir::gml_st::createGmlStSimtfyPass(StringRef blockDistributionLabel) { + const GmlStSimtfyPassOptions passOptions = { + /*.warpDistributionLabel=*/std::string(blockDistributionLabel)}; + return std::make_unique(passOptions); } diff --git a/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/gml_st_to_gpu/gml_st_to_gpu.cc b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/gml_st_to_gpu/gml_st_to_gpu.cc new file mode 100644 index 00000000000..737a51df36f --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/gml_st_to_gpu/gml_st_to_gpu.cc @@ -0,0 +1,364 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "gml_st/IR/gml_st_ops.h" +#include "gml_st/transforms/passes.h" +#include "gml_st/utils/vector_utils.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define GEN_PASS_DEF_GMLSTTOGPUPASS +#include "gml_st/transforms/passes.h.inc" + +using namespace mlir; +using namespace mlir::gml_st; +using mlir::memref::SubViewOp; +using mlir::vector::CombiningKind; +using mlir::vector::ExtractOp; +using mlir::vector::MultiDimReductionOp; +using mlir::vector::TransferReadOp; +using mlir::vector::TransferWriteOp; + +namespace { + +struct MultiDimReductionOpToWarpReductionPattern + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + MultiDimReductionOpToWarpReductionPattern(MLIRContext* context, + StringRef warpDistributionLabel) + : OpRewritePattern(context), + warpDistributionLabel(warpDistributionLabel) {} + + LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp, + PatternRewriter& rewriter) const override; + + private: + std::string warpDistributionLabel; +}; + +struct EliminateMaterializeOfTransferReadPattern + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(MaterializeOp materialize, + PatternRewriter& rewriter) const override; +}; + +struct EliminateDistributeIntoTransferWritePattern + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TransferWriteOp transferWrite, + PatternRewriter& rewriter) const override; +}; + +/// Implements the GmlStToGpuPass declared in +/// gml_st/transforms/passes.td. 
+struct GmlStToGpuPass : public ::impl::GmlStToGpuPassBase { + using GmlStToGpuPassBase::GmlStToGpuPassBase; + + void runOnOperation() override { + MLIRContext& ctx = getContext(); + RewritePatternSet patterns(&ctx); + + patterns.add(&ctx); + patterns.add( + &ctx, warpDistributionLabel); + + func::FuncOp func = getOperation(); + if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) + signalPassFailure(); + } +}; + +Value createCombineOp(Location loc, Value lhs, Value rhs, CombiningKind kind, + PatternRewriter& rewriter, Type elementType) { + auto helper = [&](auto dummy) { + return rewriter.create(loc, lhs, rhs); + }; + bool isInt = elementType.isa(); + switch (kind) { + case CombiningKind::ADD: + if (isInt) return helper(arith::AddIOp()); + return helper(arith::AddFOp()); + case CombiningKind::MUL: + if (isInt) return helper(arith::MulIOp()); + return helper(arith::MulFOp()); + case CombiningKind::MINUI: + return helper(arith::MinUIOp()); + case CombiningKind::MINSI: + return helper(arith::MinSIOp()); + case CombiningKind::MINF: + return helper(arith::MinFOp()); + case CombiningKind::MAXUI: + return helper(arith::MaxUIOp()); + case CombiningKind::MAXSI: + return helper(arith::MaxSIOp()); + case CombiningKind::MAXF: + return helper(arith::MaxFOp()); + case CombiningKind::AND: + return helper(arith::AndIOp()); + case CombiningKind::OR: + return helper(arith::OrIOp()); + case CombiningKind::XOR: + return helper(arith::XOrIOp()); + } + llvm_unreachable("unhandled"); +} + +} // namespace + +LogicalResult MultiDimReductionOpToWarpReductionPattern::matchAndRewrite( + MultiDimReductionOp reductionOp, PatternRewriter& rewriter) const { + auto distributionLevelAttr = + reductionOp->getAttrOfType(kDistributionLabelKey); + + if (!distributionLevelAttr || + distributionLevelAttr.getValue() != warpDistributionLabel) { + return rewriter.notifyMatchFailure(reductionOp, + "expected warp-level operation"); + } + + auto inType = reductionOp.getSourceVectorType(); + auto elementType = inType.getElementType(); + if (!elementType.isIntOrFloat() || elementType.getIntOrFloatBitWidth() > 32) { + return rewriter.notifyMatchFailure( + reductionOp, "expected int or float element type <= 32b"); + } + int64_t width = inType.getNumElements(); + std::initializer_list supportedWidths = {1, 2, 4, 8, 16, 32}; + if (!llvm::is_contained(supportedWidths, width)) { + return rewriter.notifyMatchFailure( + reductionOp, "expected input vector with size 2^N, <=32"); + } + auto hasOneElement = [](auto type) { + return type && type.getNumElements() == 1; + }; + auto outType = reductionOp.getDestType().dyn_cast(); + if (!hasOneElement(outType)) { + return rewriter.notifyMatchFailure(reductionOp, "expected 1-vector output"); + } + auto distribute = reductionOp.getSource().getDefiningOp(); + if (!distribute) { + return rewriter.notifyMatchFailure( + reductionOp, "source not defined by gml_st.distribute"); + } + // Even if this value was not written into the tile corresponding to the + // current thread's lane id, this is fine, since it doesn't matter which + // thread processes which element within a reduction. + TypedValue distributeSource = distribute.getSource(); + if (!hasOneElement(distributeSource.getType())) { + return rewriter.notifyMatchFailure(distribute, "expected 1-vector input"); + } + + // Preamble: extract element from input. 
+ Location loc = reductionOp->getLoc(); + Value result = rewriter.create( + loc, distributeSource, + SmallVector(distributeSource.getType().getRank(), 0)); + + auto createConstant = [&](int32_t value) { + return rewriter.create( + loc, rewriter.getI32IntegerAttr(value)); + }; + // Always have all lanes participate. This assumes that the lanes are either + // in convergence or that they have exited the kernel. + Value cWarpWidth = createConstant(32); + // Create warp shuffles of increasing offset and interleave with a clone of + // the accumulate block. + unsigned bitWidth = elementType.getIntOrFloatBitWidth(); + for (int64_t i = 1; i < width; i *= 2) { + Value shuffle = result; + if (bitWidth < 32) { + shuffle = rewriter.create( + loc, rewriter.getI32Type(), + rewriter.create( + loc, rewriter.getIntegerType(bitWidth), shuffle)); + } + shuffle = rewriter + .create( + loc, shuffle, createConstant(static_cast(i)), + cWarpWidth, gpu::ShuffleMode::XOR) + .getShuffleResult(); + if (bitWidth < 32) { + shuffle = rewriter.create( + loc, elementType, + rewriter.create( + loc, rewriter.getIntegerType(bitWidth), shuffle)); + } + result = createCombineOp(loc, result, shuffle, reductionOp.getKind(), + rewriter, elementType); + } + + // Combine with init element and broadcast result back to vector. + Value acc = rewriter.create(loc, reductionOp.getAcc(), 0); + result = createCombineOp(loc, acc, result, reductionOp.getKind(), rewriter, + elementType); + rewriter.replaceOpWithNewOp(reductionOp, outType, + result); + + return success(); +} + +namespace { +SubViewOp createSubView(Location loc, Value source, + ArrayRef offsets, + ArrayRef sizes, + ArrayRef strides, + PatternRewriter& rewriter) { + Type memRefType = SubViewOp::inferResultType( + source.getType().cast(), offsets, sizes, strides); + return rewriter.create(loc, memRefType.cast(), source, + offsets, sizes, strides); +} + +// Matches a simple version of vector.transfer_read `op`. +// 1. it has a minor identity permutation map +// 2. it has no mask +LogicalResult matchNonPermutingTransferRead(vector::TransferReadOp op, + PatternRewriter& rewriter) { + if (!op.getPermutationMap().isMinorIdentity()) { + return rewriter.notifyMatchFailure(op, + "expected cannonical permutation map"); + } + if (op.getMask()) { + return rewriter.notifyMatchFailure(op, "should have no mask"); + } + return success(); +} + +} // namespace + +LogicalResult EliminateMaterializeOfTransferReadPattern::matchAndRewrite( + MaterializeOp materialize, PatternRewriter& rewriter) const { + // Match the following pattern: + // gml_st.materialize( + // vector.transfer_read Memref:$src[(arith.constant 0)...] + // gml_st.tile [$offsets] [$sizes] [$strides]) + auto transferRead = materialize.getSource().getDefiningOp(); + if (!transferRead) { + return rewriter.notifyMatchFailure( + materialize, "expected vector.transfer_read as source"); + } + Value source = transferRead.getSource(); + if (!source.getType().isa()) { + return rewriter.notifyMatchFailure(transferRead, + "expected memref as source"); + } + if (failed(matchNonPermutingTransferRead(transferRead, rewriter))) + return failure(); + + // Rewrite the pattern as: + // vector.transfer_read + // (memref.subview $src [$offsets] [$sizes] [$strides]) + // [(arith.constant 0)...] + // TODO(b/254271932): This might not be correct if there is someone writing + // to `source` in between `transferRead` and `materialize`. This won't happen + // for elementwise fusion and softmax, but might become a problem down the + // line. 
+ SmallVector offsets; + for (auto en : llvm::zip(transferRead.getIndices(), + getAsValues(rewriter, materialize.getLoc(), + materialize.getMixedOffsets()))) { + Value transferReadOffset = std::get<0>(en); + Value materializeOffset = std::get<1>(en); + offsets.push_back({rewriter.createOrFold( + materialize.getLoc(), transferReadOffset, materializeOffset)}); + } + SmallVector zeros( + transferRead.getIndices().size(), + rewriter.create(materialize.getLoc(), 0)); + auto subview = createSubView(materialize.getLoc(), source, offsets, + materialize.getMixedSizes(), + materialize.getMixedStrides(), rewriter); + Type resultType = materialize.getResult().getType(); + if (!resultType.isa()) { + // We have a transfer to a single element: just use memref.load directly. + rewriter.replaceOpWithNewOp(materialize, subview, zeros); + return success(); + } + rewriter.replaceOpWithNewOp( + materialize, resultType, subview, zeros, transferRead.getPermutationMap(), + transferRead.getPadding(), + /*mask=*/nullptr, transferRead.getInBounds().value_or(nullptr)); + return success(); +} + +LogicalResult EliminateDistributeIntoTransferWritePattern::matchAndRewrite( + TransferWriteOp transferWrite, PatternRewriter& rewriter) const { + // Match the following pattern: + // vector.transfer_write + // (gml_st.distribute $src into + // [(gml_st.tile [$offsets] [$sizes] [$strides])]) + // Memref:$dst[(arith.constant 0)] + Value destination = transferWrite.getSource(); + if (!destination.getType().isa()) { + return rewriter.notifyMatchFailure(transferWrite, + "expected memref as destination"); + } + if (failed(matchSimpleTransferOp(transferWrite, rewriter))) return failure(); + + auto distribute = transferWrite.getVector().getDefiningOp(); + if (!distribute) { + return rewriter.notifyMatchFailure(transferWrite, + "expected distribute as source"); + } + Value source = distribute.getSource(); + + auto tile = distribute.getSet().getDefiningOp(); + if (!tile) { + return rewriter.notifyMatchFailure(distribute, + "expected gml_st.tile as set"); + } + + // Rewrite the pattern as: + // vector.transfer_write $src, + // (memref.subview $dst [$offsets] [$sizes] [$strides]) + // [(arith.constant 0)...] + auto subview = + createSubView(transferWrite.getLoc(), destination, tile.getMixedOffsets(), + tile.getMixedSizes(), tile.getMixedStrides(), rewriter); + rewriter.replaceOpWithNewOp( + transferWrite, /*resultType=*/std::nullopt, source, subview, + transferWrite.getIndices(), transferWrite.getPermutationMap(), + /*mask=*/nullptr, transferWrite.getInBounds().value_or(nullptr)); + return success(); +} + +std::unique_ptr> mlir::gml_st::createGmlStToGpuPass( + StringRef warpDistributionLabel) { + const GmlStToGpuPassOptions passOptions = { + /*.warpDistributionLabel=*/std::string(warpDistributionLabel)}; + return std::make_unique(passOptions); +} diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/gml_st_to_scf.cc b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/gml_st_to_scf/gml_st_to_scf.cc similarity index 51% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/gml_st_to_scf.cc rename to tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/gml_st_to_scf/gml_st_to_scf.cc index 055bf8812ba..0ff52d6ccfb 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/gml_st_to_scf.cc +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/gml_st_to_scf/gml_st_to_scf.cc @@ -15,14 +15,15 @@ limitations under the License. 
#include #include +#include #include #include -#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h" -#include "mlir-hlo/Dialect/gml_st/transforms/passes.h" +#include "gml_st/IR/gml_st_ops.h" +#include "gml_st/transforms/passes.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Transforms.h" -#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/IRMapping.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { @@ -30,72 +31,7 @@ namespace gml_st { namespace { #define GEN_PASS_DEF_GMLSTTOSCF -#include "mlir-hlo/Dialect/gml_st/transforms/passes.h.inc" - -/// Converts gml_st.loop to SCF loop nest. All parallel dimensions are collected -/// into an scf.parallel loop and all sequential dimensions will result in a -/// nested scf.for loop nest. The pattern assumes that a gml_st.loop with -/// iterator_types ["reduction", "parallel", "reduction"] can be reordered. -struct LoopToSCFPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(LoopOp loop, - PatternRewriter &rewriter) const override { - // Fail conversion if the `gml_st.loop` has not been bufferized. - if (!loop.hasBufferSemantics()) return failure(); - - // Collect loop control parameters for parallel and sequential dimensions. - SmallVector seqLBs, seqUBs, seqSteps, seqIVs; - SmallVector parLBs, parUBs, parSteps, parIVs; - for (const auto &en : - llvm::enumerate(llvm::zip(loop.getLowerBound(), loop.getUpperBound(), - loop.getStep(), loop.getInductionVars()))) { - Value lb, ub, step, iv; - std::tie(lb, ub, step, iv) = en.value(); - if (loop.isParallelDimension(en.index())) { - parLBs.push_back(lb); - parUBs.push_back(ub); - parSteps.push_back(step); - parIVs.push_back(iv); - } else { - seqLBs.push_back(lb); - seqUBs.push_back(ub); - seqSteps.push_back(step); - seqIVs.push_back(iv); - } - } - - Location loc = loop.getLoc(); - auto generateForLoopNestAndCloneBody = [&](OpBuilder &builder, Location loc, - ValueRange ivs) { - BlockAndValueMapping bvm; - bvm.map(parIVs, ivs); - bvm.map(loop.getRegionInputArgs(), loop.getInputs()); - bvm.map(loop.getRegionOutputArgs(), loop.getOutputs()); - - // If not all dimensions of the gml_st.loop are parallel, an scf.for loop - // nest is generated. - if (!seqIVs.empty()) { - scf::LoopNest nest = - scf::buildLoopNest(builder, loc, seqLBs, seqUBs, seqSteps, - [&](OpBuilder & /*builder*/, Location /*loc*/, - ValueRange ivs) { bvm.map(seqIVs, ivs); }); - builder.setInsertionPointToStart(nest.loops.back().getBody()); - } - for (auto &op : loop.getBody()->without_terminator()) - builder.clone(op, bvm); - }; - - if (parIVs.empty()) { - generateForLoopNestAndCloneBody(rewriter, loc, llvm::None); - } else { - rewriter.create(loc, parLBs, parUBs, parSteps, - generateForLoopNestAndCloneBody); - } - rewriter.eraseOp(loop); - return success(); - } -}; +#include "gml_st/transforms/passes.h.inc" /// Converts gml_st.parallel to SCF loop nest. 
struct ParallelOpToSCFPattern : public OpRewritePattern { @@ -107,7 +43,7 @@ struct ParallelOpToSCFPattern : public OpRewritePattern { if (!loop.hasBufferSemantics()) return failure(); auto cloneBody = [&](OpBuilder &builder, Location /*loc*/, ValueRange ivs) { - BlockAndValueMapping bvm; + IRMapping bvm; bvm.map(loop.getInductionVars(), ivs); for (auto &op : loop.getBody()->without_terminator()) @@ -131,14 +67,14 @@ struct ForOpToSCFPattern : public OpRewritePattern { PatternRewriter &rewriter) const override { auto cloneBody = [&](OpBuilder &builder, Location /*loc*/, ValueRange ivs, ValueRange iterArgs) { - BlockAndValueMapping bvm; + IRMapping bvm; bvm.map(loop.getInductionVars(), ivs); bvm.map(loop.getRegionOutputArgs(), iterArgs); for (auto &op : loop.getBody()->without_terminator()) builder.clone(op, bvm); - std::vector result; + scf::ValueVector result; llvm::transform(loop.getTerminator().getSrcs(), std::back_inserter(result), [&](Value src) { return bvm.lookupOrDefault(src); }); @@ -148,11 +84,7 @@ struct ForOpToSCFPattern : public OpRewritePattern { scf::LoopNest nest = scf::buildLoopNest( rewriter, loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(), loop.getOutputs(), cloneBody); - // TODO(csigg): just nest.getResults() once https://reviews.llvm.org/D136926 - // has landed. - ValueRange results; - if (!nest.loops.empty()) results = nest.getResults(); - rewriter.replaceOp(loop, results); + rewriter.replaceOp(loop, nest.results); return success(); } }; @@ -161,7 +93,7 @@ struct GmlStToScfPass : public impl::GmlStToScfBase { void runOnOperation() override { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); - patterns.add( + patterns.add( patterns.getContext()); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { diff --git a/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/gpu_tiling/greedy_fusion.cc b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/gpu_tiling/greedy_fusion.cc new file mode 100644 index 00000000000..135413aa6e1 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/gpu_tiling/greedy_fusion.cc @@ -0,0 +1,159 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include + +#include "gml_st/IR/gml_st_ops.h" +#include "gml_st/transforms/fusion/fusion.h" +#include "gml_st/transforms/passes.h" +#include "gml_st/transforms/tiling/tiling.h" +#include "gml_st/transforms/transforms.h" +#include "gml_st/utils/linalg_utils.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace gml_st { +namespace { + +#define GEN_PASS_DEF_GREEDYFUSIONPASS +#include "gml_st/transforms/passes.h.inc" + +namespace { + +class FuseTensorExtractPattern : public OpRewritePattern { + public: + explicit FuseTensorExtractPattern(MLIRContext *context) + : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(tensor::ExtractOp extractOp, + PatternRewriter &rewriter) const override { + if (extractOp->getParentOfType()) + return rewriter.notifyMatchFailure(extractOp, "already fused"); + + if (extractOp->getUsers().empty()) + return rewriter.notifyMatchFailure(extractOp, "op is trivially dead"); + + ParallelOp outerMostParallelOp; + for (Operation *user : extractOp->getUsers()) { + auto parallelOp = user->getParentOfType(); + while (parallelOp && parallelOp->getParentOfType()) + parallelOp = parallelOp->getParentOfType(); + + if (!parallelOp) + return rewriter.notifyMatchFailure(extractOp, "consumer is not fused"); + + if (!outerMostParallelOp) { + outerMostParallelOp = parallelOp; + } else if (outerMostParallelOp != parallelOp) { + return rewriter.notifyMatchFailure( + extractOp, + "consumers are not all nested under the same ParallelOp"); + } + } + + rewriter.setInsertionPointToStart(outerMostParallelOp.getBody()); + Value newExtractOp = rewriter.create( + extractOp.getLoc(), extractOp.getTensor(), extractOp.getIndices()); + rewriter.replaceAllUsesWith(extractOp, newExtractOp); + + return success(); + } +}; + +} // namespace + +struct GreedyFusionPass : public impl::GreedyFusionPassBase { + GreedyFusionPass() = default; + GreedyFusionPass(bool distr, ArrayRef ts, StringRef dl) { + this->distribute = distr; + this->tileSizes = ts; + this->distributionLabel = dl.str(); + } + + void getDependentDialects(DialectRegistry ®istry) const final { + registry + .insert(); + linalg::registerTilingInterfaceExternalModels(registry); + } + + void runOnOperation() override { + func::FuncOp f = getOperation(); + MLIRContext *ctx = &getContext(); + + TilingOptions opts; + opts.distribute = distribute; + opts.distributionLabel = distributionLabel; + SmallVector ts(tileSizes.begin(), tileSizes.end()); + opts.tileSizeComputationFn = [ts](OpBuilder &b, Operation *op) { + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart( + &op->getParentOfType().getBody().front()); + return llvm::to_vector(llvm::map_range(ts, [&](int64_t s) { + Value v = b.create(op->getLoc(), s); + return v; + })); + }; + + auto tilingFilterFn = [&](TilingInterface op) { + return success(llvm::none_of(op->getUsers(), [](Operation *user) { + return llvm::isa(user) || + llvm::isa(user); + })); + }; + + { + RewritePatternSet patterns(ctx); + populateTilingPatterns(ctx, tilingFilterFn, opts, &patterns); + + auto fusionFilterFn = [](tensor::ExtractSliceOp) { return success(); }; + populateFusionPatterns(ctx, fusionFilterFn, &patterns); + + if 
(failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) + return signalPassFailure(); + } + + RewritePatternSet patterns(ctx); + + patterns.add(ctx); + + if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) + return signalPassFailure(); + + // Clean up by removing temporary attributes. + removeTilingLabels(f); + } +}; + +} // namespace + +std::unique_ptr> createGreedyFusionPass() { + return std::make_unique(); +} + +std::unique_ptr> createGreedyFusionPass( + bool distribute, ArrayRef tileSizes, StringRef distributionLabel) { + return std::make_unique(distribute, tileSizes, + distributionLabel); +} + +} // namespace gml_st +} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/tiling_cwise.cc b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/gpu_tiling/tiling_cwise.cc similarity index 77% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/tiling_cwise.cc rename to tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/gpu_tiling/tiling_cwise.cc index 6d9361330f1..c04b77c6b04 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/tiling_cwise.cc +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/gpu_tiling/tiling_cwise.cc @@ -17,19 +17,19 @@ limitations under the License. #include #include +#include "gml_st/IR/gml_st_ops.h" +#include "gml_st/transforms/fusion/fusion.h" +#include "gml_st/transforms/passes.h" +#include "gml_st/transforms/tiling/tiling.h" +#include "gml_st/transforms/transforms.h" +#include "gml_st/utils/linalg_utils.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h" -#include "mlir-hlo/Dialect/gml_st/transforms/fusion.h" -#include "mlir-hlo/Dialect/gml_st/transforms/linalg_utils.h" -#include "mlir-hlo/Dialect/gml_st/transforms/passes.h" -#include "mlir-hlo/Dialect/gml_st/transforms/tiling.h" -#include "mlir-hlo/Dialect/gml_st/transforms/tiling_interface_impl.h" -#include "mlir-hlo/Dialect/gml_st/transforms/transforms.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -38,12 +38,13 @@ namespace gml_st { namespace { #define GEN_PASS_DEF_TILINGCWISEPASS -#include "mlir-hlo/Dialect/gml_st/transforms/passes.h.inc" +#include "gml_st/transforms/passes.h.inc" bool isRootOfCwiseExpr(Operation *op) { return isCwiseGenericOp(op) && llvm::none_of(op->getUsers(), [](Operation *user) { - return isCwiseGenericOp(user) || llvm::isa(user); + return isCwiseGenericOp(user) || + llvm::isa(user); }); } @@ -59,7 +60,7 @@ struct TilingCwisePass : public impl::TilingCwisePassBase { void getDependentDialects(DialectRegistry ®istry) const final { registry .insert(); - registerGmlStTilingInterfaceExternalModels(registry); + linalg::registerTilingInterfaceExternalModels(registry); } void runOnOperation() override { @@ -88,15 +89,11 @@ struct TilingCwisePass : public impl::TilingCwisePassBase { opts.distributionLabel = distributionLabel_; // Tile the roots of cwise expressions and fuse all cwise operands greedily. 
- auto tileRootOfCwiseExprFn = [](Operation *op) { - if (!isRootOfCwiseExpr(op)) return failure(); - return success(); + auto tileRootOfCwiseExprFn = [](TilingInterface op) { + return success(isRootOfCwiseExpr(op)); }; - auto fuseCwiseOperandsGreedilyFn = [](Operation *op) { - Operation *producerOp = - llvm::cast(op).getSource().getDefiningOp(); - if (!isCwiseGenericOp(producerOp)) return failure(); - return success(); + auto fuseCwiseOperandsGreedilyFn = [](tensor::ExtractSliceOp op) { + return success(isCwiseGenericOp(op.getSource().getDefiningOp())); }; // Populate tiling and fusion patterns. @@ -109,7 +106,7 @@ struct TilingCwisePass : public impl::TilingCwisePassBase { } // Clean up by removing temporary attributes. - f.walk([](Operation *op) { removeTransformationAttr(op); }); + removeTilingLabels(f); } }; diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/tiling_gpu_warp.cc b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/gpu_tiling/tiling_gpu_warp.cc similarity index 51% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/tiling_gpu_warp.cc rename to tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/gpu_tiling/tiling_gpu_warp.cc index 5dc8fd60b6c..75ea8a1dafe 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/tiling_gpu_warp.cc +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/gpu_tiling/tiling_gpu_warp.cc @@ -16,33 +16,51 @@ limitations under the License. #include #include -#include "llvm/ADT/SmallVector.h" -#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h" -#include "mlir-hlo/Dialect/gml_st/transforms/fusion.h" -#include "mlir-hlo/Dialect/gml_st/transforms/linalg_utils.h" -#include "mlir-hlo/Dialect/gml_st/transforms/passes.h" -#include "mlir-hlo/Dialect/gml_st/transforms/tiling_interface_impl.h" -#include "mlir-hlo/Dialect/gml_st/transforms/transforms.h" +#include "gml_st/IR/gml_st_ops.h" +#include "gml_st/transforms/fusion/fusion.h" +#include "gml_st/transforms/passes.h" +#include "gml_st/transforms/transforms.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/IRMapping.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #define GEN_PASS_DEF_TILINGGPUWARPPASS -#include "mlir-hlo/Dialect/gml_st/transforms/passes.h.inc" +#include "gml_st/transforms/passes.h.inc" namespace mlir { namespace gml_st { - namespace { +constexpr llvm::StringRef kTileGpuWarpAppliedLabel = + "__tile_gpu_warp_applied_label__"; + constexpr const char* kWarpDistributionLabel = "warp"; constexpr const char* kThreadDistributionLabel = "thread"; +using OpFoldResults = SmallVector; + +Value materializePoint(OpBuilder& b, Location loc, Value valueToTile, + ArrayRef offsets) { + auto tensorType = valueToTile.getType().cast(); + int64_t rank = tensorType.getRank(); + + IntegerAttr oneAttr = b.getIndexAttr(1); + SmallVector sizes(rank, oneAttr); + SmallVector strides(rank, oneAttr); + + Value slice = b.create(loc, valueToTile, offsets, + sizes, strides); + Value zero = b.create(loc, 0); + return b.create(loc, slice, + SmallVector(rank, zero)); +} + // Returns 'count' rounded up to power of two, up to warp size (32). 
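+// For example, a count of 5 is handled by a group of 8 lanes, while any count
+// of 32 or more (or a dynamic, i.e. negative, count) uses the full warp of 32.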
-static int64_t getGroupSize(int64_t count) { +int64_t getGroupSize(int64_t count) { constexpr int64_t kWarpSize = 32; if (count < 0) return kWarpSize; for (int64_t i = 1; i < kWarpSize; i *= 2) @@ -57,35 +75,36 @@ bool isWarpLevelOp(Operation* op) { *parentPloop.getDistributionType() == kWarpDistributionLabel; } -struct TilingCwisePattern : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct TilingCwisePattern : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(linalg::GenericOp genericOp, + LogicalResult matchAndRewrite(linalg::MapOp mapOp, PatternRewriter& rewriter) const override { - if (hasTransformationAttr(genericOp)) { - return rewriter.notifyMatchFailure(genericOp, "already transformed"); + if (hasLabel(mapOp, kTileGpuWarpAppliedLabel)) { + return rewriter.notifyMatchFailure(mapOp, "already transformed"); } - // Match only cwise `linalg.generic` ops on the shape 1x?. - auto genericOpTy = - genericOp.getResultTypes().front().dyn_cast(); - if (!isCwiseGenericOp(genericOp) || !genericOpTy || - genericOpTy.getRank() != 2 || genericOpTy.getDimSize(0) != 1) { - return rewriter.notifyMatchFailure(genericOp, - "not a cwise op on tensor<1x?>"); + // Match only `linalg.map` ops on the shape 1x?. + if (mapOp.getNumDpsInits() != 1) { + return rewriter.notifyMatchFailure(mapOp, "not element-wise"); + } + Value mapOpResult = mapOp.getResult().front(); + auto ploopTy = mapOpResult.getType().dyn_cast(); + if (!ploopTy || ploopTy.getRank() != 2 || ploopTy.getDimSize(0) != 1) { + return rewriter.notifyMatchFailure(mapOp, "result no tensor<1x?>"); } // Only tile root ops on the warp level. - if (!isWarpLevelOp(genericOp) || !genericOp->hasOneUse() || - !llvm::isa(*genericOp->getUsers().begin())) { - return rewriter.notifyMatchFailure(genericOp, "not a warp level root op"); + if (!isWarpLevelOp(mapOp) || !mapOp->hasOneUse() || + !llvm::isa(*mapOp->getUsers().begin())) { + return rewriter.notifyMatchFailure(mapOp, "not a warp level root op"); } // The number of threads per row (power of two, <= kWarpSize). - int64_t groupSize = getGroupSize(genericOpTy.getDimSize(1)); + int64_t groupSize = getGroupSize(ploopTy.getDimSize(1)); // Constants and attributes. - Location loc = genericOp.getLoc(); + Location loc = mapOp.getLoc(); Value c0 = rewriter.create(loc, 0); Value c1 = rewriter.create(loc, 1); Value cGroupSize = rewriter.create(loc, groupSize); @@ -98,15 +117,13 @@ struct TilingCwisePattern : OpRewritePattern { rewriter.getStringAttr(kThreadDistributionLabel); // Create `gml_st.parallel` loop to distribute among lanes. - Value init = genericOp.getOutputs().front(); - Value genericOpResult = genericOp.getResults().front(); - Value dimSize = - rewriter.createOrFold(loc, genericOpResult, c1); + Value init = mapOp.getInit(); + Value dimSize = rewriter.createOrFold(loc, mapOpResult, c1); Value dimSizePlusWarpSizeMinusOne = rewriter.createOrFold(loc, dimSize, cGroupSizeMinusOne); auto ploop = rewriter.create( - loc, genericOpTy, c0, cGroupSize, c1, threadDistrLabel, - [&](OpBuilder& b, Location loc, ValueRange ivs) { + loc, ploopTy, c0, cGroupSize, c1, ValueRange{init}, threadDistrLabel, + [&](OpBuilder& b, Location loc, ValueRange ivs, ValueRange initBbArg) { // Compute the lane tile with a stride of `warpSize`. This tile // defines the subset of the result that is produced by the lane. // The `laneId` defines the initial offset into the tensor. 
The @@ -121,16 +138,13 @@ struct TilingCwisePattern : OpRewritePattern { loc, b.create(loc, dimSizePlusWarpSizeMinusOne, laneId), cGroupSize); - Value laneTile = b.createOrFold( - loc, SmallVector{zeroAttr, laneId}, - SmallVector{oneAttr, laneTileSize}, - SmallVector{oneAttr, groupSizeAttr}); + Value laneInit = b.create( + loc, initBbArg.front(), OpFoldResults{zeroAttr, laneId}, + OpFoldResults{oneAttr, laneTileSize}, + OpFoldResults{oneAttr, groupSizeAttr}); // Create `gml_st.for` loop to iterate over the lane's tile. - Type elemTy = genericOpTy.getElementType(); - auto sloopTy = - RankedTensorType::get({1, ShapedType::kDynamicSize}, elemTy); - Value laneInit = b.create(loc, init, laneTile); + auto sloopTy = ploopTy.clone({1, ShapedType::kDynamic}); auto sloop = b.create( loc, sloopTy, c0, laneTileSize, c1, laneInit, [&](OpBuilder& b, Location loc, ValueRange ivs, ValueRange aggr) { @@ -139,89 +153,74 @@ struct TilingCwisePattern : OpRewritePattern { Value i = ivs.front(); Value iterTileOffset = b.create( loc, laneId, b.create(loc, i, cGroupSize)); - Value iterTile = b.create( - loc, SmallVector{zeroAttr, iterTileOffset}, - SmallVector{oneAttr, oneAttr}, - SmallVector{oneAttr, oneAttr}); // Materialize scalar subsets per operand. - SmallVector iterOperands = - llvm::to_vector(llvm::map_range( - genericOp.getInputs(), [&](Value arg) -> Value { - return b.create(loc, elemTy, - arg, iterTile); - })); - - // Create scalar computation from `linalg.generic` body by (i) + SmallVector iterOperands = llvm::to_vector( + llvm::map_range(mapOp.getInputs(), [&](Value arg) -> Value { + return materializePoint( + b, loc, arg, OpFoldResults{zeroAttr, iterTileOffset}); + })); + + // Create scalar computation from `linalg.map` body by (i) // mapping its block arguments to the newly materialized // scalar operands, and (ii) cloning the body. - BlockAndValueMapping bvm; - for (const auto& [blockArg, iterOperand] : llvm::zip( - genericOp.getBlock()->getArguments(), iterOperands)) { - bvm.map(blockArg, iterOperand); - } - for (auto& innerop : - genericOp.getBody()->without_terminator()) { - rewriter.clone(innerop, bvm); + IRMapping bvm; + bvm.map(mapOp.getBlock()->getArguments(), iterOperands); + for (auto& innerOp : mapOp.getBody()->without_terminator()) { + rewriter.clone(innerOp, bvm); } // Yield iteration result. 
- Value iterResult = bvm.lookup(genericOp.getBody() - ->getTerminator() - ->getOperands() - .front()); - Value iterTileInLaneTile = b.create( - loc, SmallVector{zeroAttr, i}, - SmallVector{oneAttr, oneAttr}, - SmallVector{oneAttr, oneAttr}); + Value iterResult = + bvm.lookup(mapOp.getBody()->getTerminator()->getOperand(0)); + Value iterTileInLaneTile = + b.create(loc, OpFoldResults{zeroAttr, i}, + OpFoldResults{oneAttr, oneAttr}, + OpFoldResults{oneAttr, oneAttr}); b.create(loc, iterResult, aggr, iterTileInLaneTile); }); - b.create(loc, sloop.getResults().front(), init, + Value laneTile = b.createOrFold( + loc, OpFoldResults{zeroAttr, laneId}, + OpFoldResults{oneAttr, laneTileSize}, + OpFoldResults{oneAttr, groupSizeAttr}); + b.create(loc, sloop.getResult(0), initBbArg, laneTile); }); - rewriter.replaceOp(genericOp, ploop.getResults()); + rewriter.replaceOp(mapOp, ploop.getResults()); return success(); } }; -struct TilingReductionPattern : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct TilingReductionPattern : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(linalg::GenericOp genericOp, + LogicalResult matchAndRewrite(linalg::ReduceOp reduceOp, PatternRewriter& rewriter) const override { - if (hasTransformationAttr(genericOp)) { - return rewriter.notifyMatchFailure(genericOp, "already transformed"); + if (hasLabel(reduceOp, kTileGpuWarpAppliedLabel)) { + return rewriter.notifyMatchFailure(reduceOp, "already transformed"); } // Only tile ops on the warp level. - if (!isWarpLevelOp(genericOp)) { - return rewriter.notifyMatchFailure(genericOp, "not a warp level op"); + if (!isWarpLevelOp(reduceOp)) { + return rewriter.notifyMatchFailure(reduceOp, "not a warp level op"); } - // Match only if it's a linalg.generic tensor<1x?xf32> -> tensor<1xf32> with - // iterator_types = ["parallel", "reduction"]. - auto itTypes = llvm::to_vector( - genericOp.getIteratorTypes().getAsValueRange()); - if (itTypes.size() != 2 || itTypes[0] != getParallelIteratorTypeName() || - itTypes[1] != getReductionIteratorTypeName()) { - return rewriter.notifyMatchFailure(genericOp, - "Expected ['parallel', 'reduction']"); - } - if (genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1) { - return rewriter.notifyMatchFailure(genericOp, + // Match only if it's a linalg.reduce tensor<1x?xf32> -> tensor<1xf32> + if (reduceOp.getNumDpsInputs() != 1 || reduceOp.getNumDpsInits() != 1) { + return rewriter.notifyMatchFailure(reduceOp, "Expected single input and output"); } auto inputTy = - genericOp.getInputs().front().getType().dyn_cast(); + reduceOp.getInputs().front().getType().dyn_cast(); // The number of threads per row (power of two, <= kWarpSize). int64_t groupSize = getGroupSize(inputTy.getDimSize(1)); // Attributes and constants. - Location loc = genericOp->getLoc(); + Location loc = reduceOp->getLoc(); Value c0 = rewriter.create(loc, 0); Value c1 = rewriter.create(loc, 1); Value cGroupSize = rewriter.create(loc, groupSize); @@ -231,8 +230,8 @@ struct TilingReductionPattern : OpRewritePattern { StringAttr threadDistrLabel = rewriter.getStringAttr(kThreadDistributionLabel); - Value operand = genericOp.getInputs().front(); - Value init = genericOp.getOutputs().front(); + Value operand = reduceOp.getInputs().front(); + Value init = reduceOp.getInits().front(); Type scalarTy = inputTy.getElementType(); @@ -240,23 +239,21 @@ struct TilingReductionPattern : OpRewritePattern { // Create warp-sized partial reduction result tensor. 
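+    // The partial-result tensor has shape 1xGroupSize; it is initialized by
+    // filling it with the scalar materialized from `init`, so each lane
+    // accumulates into its own column.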
Value warpResult = rewriter.create( - loc, SmallVector{oneAttr, groupSizeAttr}, scalarTy); - Value initTile = - rewriter.create(loc, SmallVector{zeroAttr}); + loc, OpFoldResults{oneAttr, groupSizeAttr}, scalarTy); Value initMaterialized = - rewriter.create(loc, scalarTy, init, initTile); + materializePoint(rewriter, loc, init, OpFoldResults{zeroAttr}); warpResult = rewriter.create(loc, initMaterialized, warpResult) - .getResults() - .front(); + .getResult(0); // Create gml_st.parallel finalizing the partial result. auto parallelOpBodyBuilderFn = [&](OpBuilder& b, Location loc, - ValueRange ivs) { + ValueRange ivs, + ValueRange parallelLoopOutputs) { Value laneId = ivs.front(); - Value laneTile = - b.create(loc, SmallVector{zeroAttr, laneId}); - Value laneResult = b.create(loc, warpResult, laneTile); + Value laneResult = b.create( + loc, parallelLoopOutputs.front(), OpFoldResults{zeroAttr, laneId}, + OpFoldResults{oneAttr, oneAttr}, OpFoldResults{oneAttr, oneAttr}); // Create gml_st.for sequentially reducing parts of the row. auto forOpBodyBuilderFn = [&](OpBuilder& b, Location loc, ValueRange ivs, @@ -265,49 +262,48 @@ struct TilingReductionPattern : OpRewritePattern { Value laneAcc = outputs.front(); // Materialize operand subset. - Value operandTile = b.create( - loc, ArrayRef{zeroAttr, iterationId}); - Value operandMaterialized = - b.create(loc, scalarTy, operand, operandTile); + Value operandMaterialized = materializePoint( + b, loc, operand, ArrayRef{zeroAttr, iterationId}); // Materialize intermediate result. - Value iterationTile = rewriter.create( - loc, SmallVector{zeroAttr, zeroAttr}); - Value iterationResult = rewriter.create( - loc, scalarTy, laneAcc, iterationTile); - - // Create scalar computation based on `linalg.generic` body. - BlockAndValueMapping bvm; - bvm.map(genericOp.getBlock()->getArguments()[0], operandMaterialized); - bvm.map(genericOp.getBlock()->getArguments()[1], iterationResult); - for (Operation& inner : genericOp.getBody()->without_terminator()) { + Value iterationResult = materializePoint( + rewriter, loc, laneAcc, OpFoldResults{zeroAttr, zeroAttr}); + + // Create scalar computation based on `linalg.reduce` body. + IRMapping bvm; + bvm.map(reduceOp.getBlock()->getArguments()[0], operandMaterialized); + bvm.map(reduceOp.getBlock()->getArguments()[1], iterationResult); + for (Operation& inner : reduceOp.getBody()->without_terminator()) { rewriter.clone(inner, bvm); } - iterationResult = bvm.lookup( - genericOp.getBody()->getTerminator()->getOperands().front()); + iterationResult = + bvm.lookup(reduceOp.getBody()->getTerminator()->getOperand(0)); + Value iterationTile = + rewriter.create(loc, OpFoldResults{zeroAttr, zeroAttr}); b.create(loc, iterationResult, laneAcc, iterationTile); }; laneResult = b.create(loc, laneResult.getType(), laneId, reductionDimSize, cGroupSize, laneResult, forOpBodyBuilderFn) - .getResults() - .front(); + .getResult(0); - b.create(loc, laneResult, warpResult, laneTile); + Value laneTile = b.create(loc, OpFoldResults{zeroAttr, laneId}); + b.create(loc, laneResult, parallelLoopOutputs.front(), + laneTile); }; warpResult = rewriter .create( loc, warpResult.getType(), c0, cGroupSize, c1, - threadDistrLabel, parallelOpBodyBuilderFn) - .getResults() - .front(); + /*outputs=*/ValueRange{warpResult}, threadDistrLabel, + parallelOpBodyBuilderFn) + .getResult(0); // Change existing linalg.generic to warp-reduce the partial results. 
- rewriter.updateRootInPlace(genericOp, [&] { - genericOp->setOperand(0, warpResult); - gml_st::setTransformationAttr(rewriter, genericOp); + rewriter.updateRootInPlace(reduceOp, [&] { + reduceOp->setOperand(0, warpResult); + setLabel(reduceOp, kTileGpuWarpAppliedLabel); }); return success(); @@ -319,7 +315,7 @@ struct TilingGPUWarpPass void getDependentDialects(DialectRegistry& registry) const final { ::impl::TilingGPUWarpPassBase::getDependentDialects( registry); - registerGmlStTilingInterfaceExternalModels(registry); + linalg::registerTilingInterfaceExternalModels(registry); } void runOnOperation() override { @@ -331,11 +327,14 @@ struct TilingGPUWarpPass // Populate fusion patterns. auto fuseGreedilyFilterFn = [](Operation* op) { - auto materializeOp = llvm::dyn_cast(op); + auto materializeOp = llvm::dyn_cast(op); Operation* source = materializeOp.getSource().getDefiningOp(); - // Do not fuse wap-level reductions. - if (isSimpleReduction(source) && isWarpLevelOp(source)) return failure(); + // Do not fuse warp-level reductions. + auto reductionOp = llvm::dyn_cast_or_null(source); + if (reductionOp && reductionOp.getNumDpsInits() == 1 && + isWarpLevelOp(source)) + return failure(); return success(); }; @@ -347,13 +346,13 @@ struct TilingGPUWarpPass } // Clean up by removing temporary attributes. - func.walk([](Operation* op) { removeTransformationAttr(op); }); + func.walk([](Operation* op) { removeLabel(op, kTileGpuWarpAppliedLabel); }); } }; } // namespace -std::unique_ptr> createTilingGPUWarpPass() { +std::unique_ptr> createTilingGpuWarpPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/passes.h b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/passes.h new file mode 100644 index 00000000000..6a83bdec9f3 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/passes.h @@ -0,0 +1,202 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MLIR_HLO_GML_ST_TRANSFORMS_PASSES_H +#define MLIR_HLO_GML_ST_TRANSFORMS_PASSES_H + +#include +#include +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Pass/Pass.h" + +#define GEN_PASS_DECL +#include "gml_st/transforms/passes.h.inc" + +namespace mlir { +namespace gml_st { + +/// The key to the attribute corresponding to the distribution type of +/// operations that have been SIMTfied. +inline constexpr const char kDistributionLabelKey[] = + "gml-st-distribution-label"; + +/// Pass to tile ops using TilingInterface. +std::unique_ptr> createTilingPass( + StringRef opName = "", StringRef opLabel = "", bool distribute = true, + ArrayRef tileSizes = {}); + +/// Pass to fuse producers into a tiled consumer. +std::unique_ptr> createFusionPass( + StringRef producer = "", StringRef consumer = ""); + +/// Pass to tile and fuse all cwise ops. 
+std::unique_ptr> createTilingCwisePass( + bool distribute, ArrayRef tileSizes, + StringRef distributionLabel = ""); +std::unique_ptr> createTilingCwisePass(); + +/// Pass to tile warp-level ops on GPU. +std::unique_ptr> createTilingGpuWarpPass(); + +/// Pass to match, tile, and fuse softmax implementations. +std::unique_ptr> createTilingSoftmaxPass( + bool distribute, ArrayRef tileSizes, + StringRef distributionLabel = ""); +std::unique_ptr> createTilingSoftmaxPass(); + +/// Pass to tile the root operation and to greedily fuse producers into it. +std::unique_ptr> createGreedyFusionPass( + bool distribute, ArrayRef tileSizes, StringRef distributionLabel); +std::unique_ptr> createGreedyFusionPass(); + +// Pass to collapse dimensions of bcasts, reductions, and cwise ops. +std::unique_ptr> createCollapseShapePass(); +std::unique_ptr> createCollapseShapePass( + const CollapseShapePassOptions &options); + +/// Pass to compose tensor.extract_slice/insert_slice ops. +std::unique_ptr> +createComposeExtractInsertSlicePass(); + +/// Pass to lower `gml_st.parallel` to `gpu.launch`, transforming the code into +/// its SIMT interpretation. +std::unique_ptr> createGmlStSimtfyPass( + StringRef blockDistributionLabel = "block"); + +/// Pass to eliminate the remaining `gml_st` ops after SIMTfication. +std::unique_ptr> createGmlStToGpuPass( + StringRef warpDistributionLabel = "warp"); + +/// Create a pass to convert `gml_st.loop` to `scf.for` and `scf.parallel` +/// loops and memref.load/memref.store accesses. +std::unique_ptr> createGmlStToScfPass(); + +/// Pass to vectorize compute ops and gml_st.loops. +std::unique_ptr> createVectorizeForGPUPass( + bool vectorizeGmlStOps = false, + ArrayRef distributionLabels = {}); + +/// Pass to vectorize compute ops and scf.for loops that are tiled perfectly. +std::unique_ptr> createVectorizeForCPUPass(); + +/// Pass to vectorize `memref.copy`. +std::unique_ptr> createVectorizeCopyPass(); + +/// Pass to eliminate dead `memref.copy`. +std::unique_ptr> createSimplifyDeadCopyPass(); + +/// Pass to rewrite vector.contract. +std::unique_ptr> createRewriteVectorContractPass(); + +/// Pass to rewrite vector.transpose. +std::unique_ptr> createRewriteVectorTransposePass(); + +/// Pass to rewrite vector.multi_reduction. +std::unique_ptr> +createRewriteVectorMultiReductionPass(); + +/// Pass to transform a thlo.scatter op for CPU backend. +std::unique_ptr> createTransformScatterForCpuPass(); + +/// Pass to transform a linalg.matmul op for CPU backend. +std::unique_ptr> createTransformMatmulForCpuPass( + ArrayRef matmulTileSizes = std::nullopt, + bool lowerToMmt4DOp = false); + +/// Pass to transform a linalg.matmul op for Triton. +std::unique_ptr> createTransformMatmulForTritonPass( + ArrayRef matmulTileSizes = std::nullopt, + StringRef distributionLabel = ""); + +/// Pass to fuse linalg on tensor operations. +std::unique_ptr> createFusionOfTensorOpsPass(); + +/// Pass to convert ops on tensors with 1 element to scalar ops. +std::unique_ptr> createScalarizationPass(); + +/// Pass to transform a linalg.map op for CPU backend. +std::unique_ptr> +createTransformMapForCpuPass(int64_t tileSize = 1); + +/// Pass to transform a linalg.reduce op for CPU backend. +std::unique_ptr> +createTransformReduceForCpuPass(int64_t vectorSize = 8, int64_t tileSize1D = 32, + ArrayRef tileSizes2D = {}); + +/// Pass to transform a thlo.reverse op for CPU backend. +std::unique_ptr> +createTransformReverseForCpuPass(int64_t vectorSize = 8); + +/// Pass to transform a linalg.transpose op for CPU backend. 
+std::unique_ptr> +createTransformTransposeForCpuPass(ArrayRef tileSizes = std::nullopt); + +/// Pass to transform a thlo.sort op for CPU backend. +std::unique_ptr> +createTransformSortForCpuPass(); + +/// Pass to add debug info to be propagated into LLVM backend. +std::unique_ptr> createAddDebugInfoPass(); + +struct GmlStCPUPipelineOptions + : public mlir::PassPipelineOptions { + Option vectorize{*this, "vectorize", + llvm::cl::desc("Enable tiling for vectorization."), + llvm::cl::init(false)}; + + Option vectorSize{*this, "vector-size", + llvm::cl::desc("Vector size for a 1D reduction."), + llvm::cl::init(8)}; + + Option reduction1DTileSize{ + *this, "reduction-1d-tile-size", + llvm::cl::desc("Tile size for a 1D reduction."), llvm::cl::init(32)}; + + ListOption reduction2DTileSizes{ + *this, "reduction-2d-tile-sizes", + llvm::cl::desc("Tile sizes for a 2D reduction."), + llvm::cl::list_init({4, 4}), llvm::cl::ZeroOrMore}; + + ListOption matmulTileSizes{ + *this, "matmul-tile-sizes", + llvm::cl::desc("Tile sizes for `linalg.matmul`."), + llvm::cl::list_init({4, 4, 4}), llvm::cl::ZeroOrMore}; + + Option lowerToMmt4d{ + *this, "lower-to-mmt4d", + llvm::cl::desc("Enable the specific code generation (packing) for matmul " + "operations."), + llvm::cl::init(false)}; +}; + +// Make GmlStCPUPipelineOptions hashable. +inline ::llvm::hash_code hashValue(const GmlStCPUPipelineOptions &opts) { + return ::llvm::hash_value(static_cast(opts.vectorize)); +} + +// Adds tiling-fusion-vectorization passes for tHLO/Linalg ops mix. +void addCPUTilingPipeline(OpPassManager &pm, + const GmlStCPUPipelineOptions &options); + +#define GEN_PASS_REGISTRATION +#include "gml_st/transforms/passes.h.inc" + +} // namespace gml_st +} // namespace mlir + +#endif // MLIR_HLO_GML_ST_TRANSFORMS_PASSES_H diff --git a/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/passes.td b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/passes.td new file mode 100644 index 00000000000..e2ccbfe5153 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/passes.td @@ -0,0 +1,322 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+include "mlir/Pass/PassBase.td"
+
+def TilingPass : Pass<"gml-tiling", "mlir::func::FuncOp"> {
+  let summary = "Tile operations using TilingInterface to produce gml_st.for";
+  let constructor = "::mlir::gml_st::createTilingPass()";
+  let options = [
+    Option<"opName", "op-name", "std::string", /*default=*/"",
+           "Operation with this name is the anchor to latch on.">,
+    Option<"opLabel", "op-label", "std::string", /*default=*/"",
+           "Operation with this label is the anchor to latch on.">,
+    Option<"distribute", "distribute", "bool", /*default=*/"true",
+           "Generate gml_st.parallel or gml_st.for">,
+    ListOption<"tileSizes", "tile-sizes", "int64_t", "Tile sizes",
+               "llvm::cl::ZeroOrMore">,
+  ];
+}
+
+def FusionPass : Pass<"gml-fusion", "mlir::func::FuncOp"> {
+  let summary = "Fuse producers into `gml_st.materialize` operations";
+  let constructor = "::mlir::gml_st::createFusionPass()";
+  let options = [
+    Option<"producerLabel", "producer-label", "std::string", /*default=*/"",
+           "Producer label.">,
+    Option<"consumerLabel", "consumer-label", "std::string", /*default=*/"",
+           "Consumer label.">,
+  ];
+}
+
+def TilingCwisePass : Pass<"gml-tiling-cwise", "mlir::func::FuncOp"> {
+  let summary = "Tile and fuse all cwise ops";
+  let constructor = "::mlir::gml_st::createTilingCwisePass()";
+  let options = [
+    Option<"distribute_", "distribute", "bool", /*default=*/"true",
+           "Generate gml_st.parallel or gml_st.for">,
+    ListOption<"tileSizes_", "tile-sizes", "int64_t",
+               "Right-aligned tile sizes. Do not tile possible remaining "
+               "dimensions", "llvm::cl::ZeroOrMore">,
+    Option<"distributionLabel_", "distribution-label", "std::string",
+           /*default=*/"", "Distribution label for generated gml_st.parallel">,
+  ];
+}
+
+def TilingGPUWarpPass : Pass<"gml-tiling-gpu-warp", "mlir::func::FuncOp"> {
+  let summary = "Tile warp-level ops for GPU";
+  let constructor = "::mlir::gml_st::createTilingGpuWarpPass()";
+  let dependentDialects = ["::mlir::gml_st::GmlStDialect",
+                           "::mlir::arith::ArithDialect"];
+}
+
+def TilingSoftmaxPass : Pass<"gml-tiling-softmax", "mlir::func::FuncOp"> {
+  let summary = "Match, tile, and fuse softmax implementations";
+  let constructor = "::mlir::gml_st::createTilingSoftmaxPass()";
+  let options = [
+    Option<"distribute", "distribute", "bool", /*default=*/"true",
+           "Generate gml_st.parallel or gml_st.for">,
+    ListOption<"tileSizes", "tile-sizes", "int64_t",
+               "Right-aligned tile sizes. 
Do not tile possible remaining " + "dimensions", "llvm::cl::ZeroOrMore">, + Option<"distributionLabel", "distribution-label", "std::string", + /*default=*/"", "Distribution label for generated gml_st.parallel">, + ]; +} + +def GreedyFusionPass : Pass<"gml-greedy-fusion", "mlir::func::FuncOp"> { + let summary = "Pass to tile the root operation and to greedily fuse " + "producers into it."; + let constructor = "::mlir::gml_st::createGreedyFusionPass()"; + let options = [ + Option<"distribute", "distribute", "bool", /*default=*/"true", + "Generate gml_st.parallel or gml_st.for">, + ListOption<"tileSizes", "tile-sizes", "int64_t", + "Tile sizes", "llvm::cl::ZeroOrMore">, + Option<"distributionLabel", "distribution-label", "std::string", + /*default=*/"", "Distribution label for generated gml_st.parallel">, + ]; +} + +def CollapseShapePass : Pass<"gml-collapse-shape", "mlir::func::FuncOp"> { + let summary = "Collapse dimensions of bcasts, reductions, and cwise ops"; + let description = [{ + Pass to collapse dimensions of bcasts, reductions, and cwise ops. A given + number of trailing dimensions remains untouched while the remaining leading + dimensions will be collapsed where possible. + }]; + let constructor = "::mlir::gml_st::createCollapseShapePass()"; + let options = [ + Option<"retainTrailingDims", "retain-trailing-dims", "int64_t", + /*default=*/"0", + "Number of trailing dimensions that will not be collapsed.">, + ]; + let dependentDialects = ["::mlir::tensor::TensorDialect"]; +} + +def ComposeExtractInsertSlicePass : Pass<"gml-compose-extract-insert-slice", + "mlir::func::FuncOp"> { + let summary = "Compose tensor.extract_slice/insert_slice ops."; + let constructor = "::mlir::gml_st::createComposeExtractInsertSlicePass()"; +} + +def GmlStToScf : Pass<"gml-st-to-scf", "mlir::func::FuncOp"> { + let summary = "Lower `gml_st.loop` to SCF loops and parallel loops"; + let constructor = "::mlir::gml_st::createGmlStToScfPass()"; + let dependentDialects = ["::mlir::scf::SCFDialect"]; +} + +def GmlStToGpuPass : Pass<"gml-st-to-gpu", "mlir::func::FuncOp"> { + let summary = "Lower nested `gml_st.parallel` to `gpu.launch`"; + let constructor = "::mlir::gml_st::createGmlStToGpuPass()"; + let dependentDialects = ["::mlir::gpu::GPUDialect", + "::mlir::vector::VectorDialect", + "::mlir::memref::MemRefDialect"]; + let options = [ + Option<"warpDistributionLabel", "warp-distribution-label", + "std::string", /*default=*/"\"warp\"", + "Direct children of `gml_st.parallel` loops with this distribution " + "type are distributed over warps."> + ]; +} + +def GmlStSimtfyPass : Pass<"gml-st-simtfy", "mlir::func::FuncOp"> { + let summary = "Lower nested `gml_st.parallel` to `gpu.launch`"; + let constructor = "::mlir::gml_st::createGmlStSimtfyPass()"; + let dependentDialects = ["::mlir::AffineDialect", + "::mlir::arith::ArithDialect", + "::mlir::gpu::GPUDialect", + "::mlir::scf::SCFDialect"]; + let options = [ + Option<"blockDistributionLabel", "block-distribution-label", + "std::string", /*default=*/"\"block\"", + "Direct children of `gml_st.parallel` loops with this distribution " + "type are distributed over blocks."> + ]; +} + +def VectorizeForGPUPass : Pass<"vectorize-for-gpu", "mlir::func::FuncOp"> { + let summary = "Pass to vectorize compute ops and gml_st.loops."; + let constructor = "::mlir::gml_st::createVectorizeForGPUPass()"; + let options = [ + Option<"vectorizeGmlStOps", "vectorize-gml-st-ops", "bool", "false", + "If true, vectorizes GmlSt ops in addition to linalg ops">, + 
ListOption<"distributionLabels", "included-distribution-labels", + "std::string", "Distribution labels of gml_st.parallel ops " + "where vectorization is allowed. Empty list signifies that " + "vectorization is allowed within all loops.", + "llvm::cl::ZeroOrMore">, + ]; + let dependentDialects = ["::mlir::vector::VectorDialect"]; +} + +def VectorizeForCPUPass : Pass<"vectorize-for-cpu", "mlir::func::FuncOp"> { + let summary = "Pass to vectorize gml_st.for loops that are tiled perfectly."; + let constructor = "::mlir::gml_st::createVectorizeForCPUPass()"; + let dependentDialects = [ + "::mlir::vector::VectorDialect", + "::mlir::tensor::TensorDialect" + ]; +} + +def VectorizeCopyPass : + Pass<"vectorize-copy", "mlir::func::FuncOp"> { + let summary = "Pass to vectorize `memref.copy`."; + let constructor = "::mlir::gml_st::createVectorizeCopyPass()"; + let dependentDialects = ["::mlir::vector::VectorDialect"]; +} + +def SimplifyDeadCopyPass : + Pass<"simplify-dead-copy", "mlir::func::FuncOp"> { + let summary = "Pass to simplify dead `memref.copy`."; + let constructor = "::mlir::gml_st::createSimplifyDeadCopyPass()"; + let dependentDialects = ["::mlir::vector::VectorDialect", + "::mlir::memref::MemRefDialect"]; +} + +def RewriteVectorContractPass : + Pass<"rewrite-vector-contract", "mlir::func::FuncOp"> { + let summary = "Pass to rewrite vector.contract."; + let constructor = "::mlir::gml_st::createRewriteVectorContractPass()"; + let dependentDialects = ["::mlir::vector::VectorDialect"]; +} + + +def RewriteVectorMultiReductionPass : + Pass<"rewrite-vector-multi-reduction", "mlir::func::FuncOp"> { + let summary = "Pass to rewrite vector.multi_reduction."; + let constructor = "::mlir::gml_st::createRewriteVectorMultiReductionPass()"; + let dependentDialects = ["::mlir::vector::VectorDialect"]; +} + +def RewriteVectorTransposePass : Pass<"rewrite-vector-transpose", "mlir::func::FuncOp"> { + let summary = "Pass to rewrite vector.transpose."; + let constructor = "::mlir::gml_st::createRewriteVectorTransposePass()"; + let dependentDialects = [ + "::mlir::LLVM::LLVMDialect", + "::mlir::vector::VectorDialect", + ]; +} + +def ScalarizationPass : Pass<"scalarize", "mlir::func::FuncOp"> { + let summary = "Converts ops on tensors with 1 element to scalar ops."; + let dependentDialects = [ + "arith::ArithDialect", + "gml_st::GmlStDialect", + "scf::SCFDialect", + "tensor::TensorDialect" + ]; + let constructor = "createScalarizationPass()"; +} + +def TransformScatterForCpuPass : + Pass<"xla-cpu-transform-scatter", "mlir::func::FuncOp"> { + let summary = "Transform scatter ops for running on CPU"; + + let constructor = "createTransformScatterForCpuPass()"; +} + +def TransformMatmulForCpuPass : + Pass<"xla-cpu-transform-matmul", "mlir::func::FuncOp"> { + let summary = "Transform matmul ops for running on CPU"; + + let constructor = "createTransformMatmulForCpuPass()"; + + let options = [ + Option<"lowerToMmt4D", "lower-to-mmt4d", "bool", "false", + "If true, lower linalg.matmul into linalg.mmt4d">, + ListOption<"tileSizes", "tile-sizes", "int64_t", + "Tile sizes for a `linalg.matmul`">, + ]; +} + +def TransformMatmulForTritonPass : + Pass<"xla-triton-transform-matmul", "mlir::func::FuncOp"> { + let summary = "Transform matmul ops for lowering to Triton"; + + let constructor = "createTransformMatmulForTritonPass()"; + + let options = [ + ListOption<"tileSizes", "tile-sizes", "int64_t", + "Tile sizes for a `linalg.matmul`">, + Option<"distributionLabel", "distribution-label", "std::string", + /*default=*/"", 
"Distribution label for generated gml_st.parallel">, + ]; +} + +def TransformMapForCpuPass : + Pass <"gml-st-cpu-transform-map", "mlir::func::FuncOp"> { + let summary = "Transform map ops for running on CPU"; + + let constructor = "::mlir::gml_st::createTransformMapForCpuPass()"; + + let options = [ + Option<"tileSize", "tile-size", "int64_t", "1", + "Tile size for the innermost dimension of `linalg.map`">, + ]; +} + +def TransformTransposeForCpuPass : + Pass<"gml-st-cpu-transform-transpose", "mlir::func::FuncOp"> { + let summary = "Transform transpose ops for running on CPU"; + + let constructor = "createTransformTransposeForCpuPass()"; + + let options = [ + ListOption<"tileSizes", "tile-sizes", "int64_t", + "Tile sizes for a `linalg.transpose`">, + ]; +} + +def TransformReduceForCpuPass : + Pass<"xla-cpu-transform-reduce", "mlir::func::FuncOp"> { + let summary = "Transform reduce ops for running on CPU"; + + let constructor = "createTransformReduceForCpuPass()"; + + let options = [ + Option<"vectorSize", "vector-size", "int64_t", "8", + "Vector size for a 1D `linalg.reduce`">, + Option<"tileSize1D", "tile-size-1d", "int64_t", "32", + "Tile size for a 1D `linalg.reduce`">, + ListOption<"tileSizes2D", "tile-sizes-2d", "int64_t", + "Tile sizes for a `linalg.reduce`. tileSizes[0] is the parallel " + "dimension and tileSizes[1] is the reduction dimension.">, + ]; +} + +def TransformReverseForCpuPass : + Pass<"xla-cpu-transform-reverse", "mlir::func::FuncOp"> { + let summary = "Transform reverse ops for running on CPU"; + let constructor = "createTransformReverseForCpuPass()"; + let options = [ + Option<"vectorSize", "vector-size", "int64_t", "8", + "Vector size for 'thlo.reverse`">, + ]; + } + +def TransformSortForCpuPass : + Pass<"gml-st-cpu-transform-sort", "mlir::func::FuncOp"> { + let summary = "Transform sort ops for running on CPU"; + + let constructor = "createTransformSortForCpuPass()"; +} + +def AddDebugInfoPass : + Pass<"add-debug-info", "mlir::ModuleOp"> { + let summary = "Add debug info for the whole module"; + let constructor = "::mlir::gml_st::createAddDebugInfoPass()"; + let dependentDialects = ["::mlir::LLVM::LLVMDialect"]; +} diff --git a/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/peeling/peeling.cc b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/peeling/peeling.cc new file mode 100644 index 00000000000..20d6a11042d --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/peeling/peeling.cc @@ -0,0 +1,180 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "gml_st/transforms/peeling/peeling.h" + +#include +#include +#include +#include +#include + +#include "gml_st/IR/gml_st_ops.h" +#include "gml_st/transforms/transforms.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/SCF/Transforms/Transforms.h" +#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h" +#include "mlir/IR/IRMapping.h" + +namespace mlir { +namespace gml_st { +namespace { + +bool isATensor(Type t) { return t.isa(); } + +/// Return true if the given op has only tensor-typed results or operands. +bool hasTensorSemantics(Operation *op) { + return llvm::all_of(op->getResultTypes(), isATensor) || + llvm::all_of(op->getOperandTypes(), isATensor); +} + +LogicalResult peelLoop(RewriterBase &b, ParallelOp loopOp, int64_t idx, + ParallelOp &result, Value &splitBound) { + if (!hasTensorSemantics(loopOp)) return failure(); + + Value lb = loopOp.getLowerBound()[idx], ub = loopOp.getUpperBound()[idx], + step = loopOp.getStep()[idx]; + auto ubInt = getConstantIntValue(ub); + + auto loc = loopOp.getLoc(); + AffineExpr exprLb, exprUb, exprStep; + bindSymbols(b.getContext(), exprLb, exprUb, exprStep); + // New upper bound: %ub - (%ub - %lb) mod %step + auto modMap = AffineMap::get(0, 3, exprUb - ((exprUb - exprLb) % exprStep)); + SmallVector operands{lb, ub, step}; + canonicalizeMapAndOperands(&modMap, &operands); + modMap = simplifyAffineMap(modMap); + RewriterBase::InsertionGuard guard(b); + b.setInsertionPoint(loopOp); + splitBound = b.createOrFold(loc, modMap, operands); + + // No specialization necessary if step already divides upper bound evenly. + if (splitBound == ub || (ubInt && ubInt == getConstantIntValue(splitBound))) + return failure(); + + // Create remainder loop. + IRMapping bvm; + for (const auto &[res, termDst] : + llvm::zip(loopOp.getResults(), loopOp.getLoopLikeOpInits())) { + bvm.map(termDst, res); + } + b.setInsertionPointAfter(loopOp); + auto remainderLoop = cast(b.clone(*loopOp.getOperation(), bvm)); + + Operation *remainderLoopOp = remainderLoop.getOperation(); + + for (const auto &[oldRes, newRes] : + llvm::zip(loopOp.getResults(), remainderLoop.getResults())) { + SmallPtrSet exceptions({remainderLoopOp}); + for (OpOperand &use : oldRes.getUses()) { + Operation *user = use.getOwner(); + if (user->getParentOp() == remainderLoopOp) exceptions.insert(user); + } + oldRes.replaceAllUsesExcept(newRes, exceptions); + } + + // Set new loop bounds. 
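+  // After this update the main loop iterates over [lb, splitBound) and the
+  // cloned remainder loop over [splitBound, ub), where
+  // splitBound = ub - ((ub - lb) mod step). E.g. lb = 0, ub = 10, step = 4
+  // gives splitBound = 8: the main loop runs two full steps and the remainder
+  // loop handles the final partial tile of size 2.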
+ b.updateRootInPlace(loopOp, [&]() { + SmallVector ubs = loopOp.getUpperBound(); + ubs[idx] = splitBound; + loopOp.getUpperBoundMutable().assign(ubs); + }); + SmallVector lbs = remainderLoop.getLowerBound(); + lbs[idx] = splitBound; + b.updateRootInPlace(remainderLoop, [&]() { + remainderLoop.getLowerBoundMutable().assign(lbs); + }); + + result = remainderLoop; + return success(); +} + +template +void rewriteAffineOpAfterPeeling(RewriterBase &rewriter, Operation *mainLoop, + Operation *remainderLoop, Value mainIv, + Value remainderIv, Value ub, Value step) { + mainLoop->walk([&](OpTy affineOp) { + (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, mainIv, ub, step, + /*insideLoop=*/true); + }); + remainderLoop->walk([&](OpTy affineOp) { + (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, remainderIv, ub, step, + /*insideLoop=*/false); + }); +} + +} // namespace + +GmlStPeelingResult peelAllLoops(ParallelOp loop, + mlir::PatternRewriter &rewriter) { + setLabel(loop, kPeelingAppliedLabel); + GmlStPeelingResult peelingResult; + + bool hasMainLoop = true; + for (unsigned peeledIdx = 0; peeledIdx < loop.getNumLoops(); ++peeledIdx) { + int64_t numLoops = loop.getNumLoops(); + if (peeledIdx < 0 || numLoops <= peeledIdx) continue; + + Value ub = loop.getUpperBound()[peeledIdx]; + Value step = loop.getStep()[peeledIdx]; + auto ubInt = getConstantIntValue(ub); + auto stepInt = getConstantIntValue(step); + + // If the loop is smaller than the step, then append loop as tail. Needs to + // be done only once. + if (ubInt && stepInt && ubInt < stepInt) { + if (hasMainLoop) { + peelingResult.tailLoops.push_back(loop); + hasMainLoop = false; + } + continue; + } + + ParallelOp remainderLoop; + Value splitBound; + if (failed(peelLoop(rewriter, loop, peeledIdx, remainderLoop, splitBound))) + continue; + + // Rewrite affine.min and affine.max ops. + Value mainIv = loop.getInductionVars()[peeledIdx], + remainderIv = remainderLoop.getInductionVars()[peeledIdx]; + + rewriteAffineOpAfterPeeling(rewriter, loop, remainderLoop, + mainIv, remainderIv, ub, step); + rewriteAffineOpAfterPeeling(rewriter, loop, remainderLoop, + mainIv, remainderIv, ub, step); + + // Mark the new loop if one was created. + setLabel(remainderLoop.getOperation(), kPeelingAppliedLabel); + peelingResult.tailLoops.push_back(remainderLoop); + } + + // Update main loop if applicable. + if (hasMainLoop) peelingResult.mainLoop = loop; + + return peelingResult; +} + +SCFForPeelingResult peelSCFForOp(RewriterBase &rewriter, scf::ForOp loop) { + // Peeling fails, if the step divides the upper bound. In that case, + // we still want to return {loop, nullptr}. + scf::ForOp tailLoop; + return succeeded(scf::peelAndCanonicalizeForLoop(rewriter, loop, tailLoop)) + ? SCFForPeelingResult{loop, tailLoop} + : SCFForPeelingResult{loop, nullptr}; +} + +} // namespace gml_st +} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/peeling/peeling.h b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/peeling/peeling.h new file mode 100644 index 00000000000..02aa905518a --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/peeling/peeling.h @@ -0,0 +1,57 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MLIR_HLO_GML_ST_TRANSFORMS_PEELING_PEELING_H +#define MLIR_HLO_GML_ST_TRANSFORMS_PEELING_PEELING_H + +#include +#include + +#include "gml_st/IR/gml_st_ops.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/PatternMatch.h" + +namespace mlir { +namespace gml_st { + +constexpr llvm::StringRef kPeelingAppliedLabel = "__peeling_applied_label__"; + +struct GmlStPeelingResult { + Operation *mainLoop = nullptr; + SmallVector tailLoops = {}; +}; + +/// Rewrite a gml_st::ParallelOp with bounds/step that potentially do not divide +/// evenly into a gml_st::ParallelOp where the step divides the iteration space +/// evenly, followed by another gml_st::ParallelOp for the last (partial) +/// iteration (if any). This transformation is called "loop peeling". +/// +/// These functions peel all loops in the loop nest by calling +/// peelAndCanonicalizeGmlStLoop. Additionally, they mark all loops (main and +/// remainder loops) as peeled, so the same loop is not rewritten a second time. +GmlStPeelingResult peelAllLoops(ParallelOp loop, + mlir::PatternRewriter &rewriter); + +struct SCFForPeelingResult { + scf::ForOp mainLoop = nullptr; + scf::ForOp tailLoop = nullptr; +}; +SCFForPeelingResult peelSCFForOp(RewriterBase &rewriter, scf::ForOp); + +} // namespace gml_st +} // namespace mlir + +#endif // MLIR_HLO_GML_ST_TRANSFORMS_PEELING_PEELING_H diff --git a/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/rewrite_vector_ops/rewrite_vector_contract.cc b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/rewrite_vector_ops/rewrite_vector_contract.cc new file mode 100644 index 00000000000..8eb3dfeb84c --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/rewrite_vector_ops/rewrite_vector_contract.cc @@ -0,0 +1,132 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "gml_st/transforms/passes.h" +#include "gml_st/transforms/transforms.h" +#include "gml_st/utils/vector_utils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace gml_st { +namespace { + +using vector::OuterProductOp; + +#define GEN_PASS_DEF_REWRITEVECTORCONTRACTPASS +#include "gml_st/transforms/passes.h.inc" + +struct OuterProductOpCanonicalizationPattern + : public OpRewritePattern { + OuterProductOpCanonicalizationPattern( + MLIRContext *context, llvm::function_ref filterFn, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), filterFn(filterFn) {} + + LogicalResult matchAndRewrite(OuterProductOp op, + PatternRewriter &rewriter) const override { + if (!filterFn(op)) + return rewriter.notifyMatchFailure(op, "did not match filter"); + + bool changed = false; + SmallVector newAccs{op.getAcc()}; + for (auto &acc : newAccs) { + auto materializeOp = acc.getDefiningOp(); + auto src = materializeOp.getSource(); + auto srcType = src.getType().cast(); + if (auto resType = op.getResult().getType().dyn_cast()) { + if (resType.hasStaticShape() && srcType == resType) { + acc = src; + changed = true; + } + } + } + if (!changed) return failure(); + rewriter.updateRootInPlace(op, + [&]() { op.getAccMutable().assign(newAccs); }); + return success(); + } + + private: + llvm::function_ref filterFn; +}; + +struct RewriteVectorContractPass + : public impl::RewriteVectorContractPassBase { + RewriteVectorContractPass() = default; + + void runOnOperation() override { + auto func = getOperation(); + auto *ctx = func.getContext(); + + // Reduce vector.contract dimensions to fit one of the lowering patterns to + // vector.outerproduct. + { + RewritePatternSet castAwayUnitDimPatterns(ctx); + vector::populateCastAwayVectorLeadingOneDimPatterns( + castAwayUnitDimPatterns); + if (failed(applyPatternsAndFoldGreedily( + func, std::move(castAwayUnitDimPatterns)))) { + return signalPassFailure(); + } + + RewritePatternSet reductionToContractPatterns(ctx); + vector::populateVectorReductionToContractPatterns( + reductionToContractPatterns); + vector::ExtractOp::getCanonicalizationPatterns( + reductionToContractPatterns, ctx); + if (failed(applyPatternsAndFoldGreedily( + func, std::move(reductionToContractPatterns)))) { + return signalPassFailure(); + } + } + + RewritePatternSet patterns(ctx); + + auto outerProductOpFilter = [&](OuterProductOp op) { + return (llvm::any_of(op.getAcc(), [](auto acc) { + return acc.template getDefiningOp() != nullptr; + })); + }; + + vector::populateVectorToVectorCanonicalizationPatterns(patterns); + // Currently we always lower vector.contract into vector.outerproduct. 
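The comment above states the central design choice of this pass: vector.contract is always routed through vector.outerproduct. A plain C++ sketch of the equivalent scalar computation, a matmul accumulated as one rank-1 outer-product update per step of the reduction dimension (illustrative only, not part of the patch):

// --- standalone illustration, not part of the patch ---
#include <array>
#include <cstdio>

constexpr int M = 2, N = 3, K = 4;

int main() {
  std::array<std::array<float, K>, M> a{};
  std::array<std::array<float, N>, K> b{};
  std::array<std::array<float, N>, M> c{};  // accumulator, starts at zero

  for (int i = 0; i < M; ++i)
    for (int k = 0; k < K; ++k) a[i][k] = 1.0f * (i + k);
  for (int k = 0; k < K; ++k)
    for (int j = 0; j < N; ++j) b[k][j] = 1.0f * (k - j);

  // One outer product per step of the reduction dimension K:
  //   C += a[:, k] (x) b[k, :]
  // Each such rank-1 update is what a single vector.outerproduct expresses
  // after the lowering described above.
  for (int k = 0; k < K; ++k)
    for (int i = 0; i < M; ++i)
      for (int j = 0; j < N; ++j) c[i][j] += a[i][k] * b[k][j];

  std::printf("c[0][0] = %f\n", c[0][0]);
  return 0;
}
// --- end illustration ---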
+ patterns.add( + vector::VectorTransformsOptions().setVectorTransformsOptions( + vector::VectorContractLowering::OuterProduct), + ctx, 2); + patterns.add(ctx, + outerProductOpFilter); + vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); + + if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) { + return signalPassFailure(); + } + } +}; +} // namespace + +std::unique_ptr> createRewriteVectorContractPass() { + return std::make_unique(); +} + +} // namespace gml_st +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_rewrite_vector_multi_reduction.cc b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/rewrite_vector_ops/rewrite_vector_multi_reduction.cc similarity index 55% rename from tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_rewrite_vector_multi_reduction.cc rename to tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/rewrite_vector_ops/rewrite_vector_multi_reduction.cc index 9033ef7962c..1058d12d798 100644 --- a/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_rewrite_vector_multi_reduction.cc +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/rewrite_vector_ops/rewrite_vector_multi_reduction.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,24 +16,20 @@ limitations under the License. #include #include +#include "gml_st/transforms/passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h" -namespace tensorflow { +namespace mlir { +namespace gml_st { namespace { #define GEN_PASS_DEF_REWRITEVECTORMULTIREDUCTIONPASS -#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h.inc" - -using mlir::MLIRContext; -using mlir::Operation; -using mlir::vector::MultiDimReductionOp; -using mlir::vector::VectorMultiReductionLowering; +#include "gml_st/transforms/passes.h.inc" struct RewriteVectorMultiReductionPass : public impl::RewriteVectorMultiReductionPassBase< @@ -41,43 +37,42 @@ struct RewriteVectorMultiReductionPass void runOnOperation() override { MLIRContext* ctx = &getContext(); Operation* op = getOperation(); - if (failed(RewriteTwoAndMoreDimReductions(ctx, op))) signalPassFailure(); - if (failed(RewriteOneDimReductions(ctx, op))) signalPassFailure(); + if (failed(rewriteTwoAndMoreDimReductions(ctx, op))) signalPassFailure(); + if (failed(rewriteOneDimReductions(ctx, op))) signalPassFailure(); } // Rewrite N-D reductions as the sequence of vector operations without // horizontal reduction, i.e. `vector.reduction`. 
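The two helpers referenced in the comment above split the lowering into two stages: N-D reductions first become lane-wise accumulation, and only the final 1-D case turns into a horizontal vector.reduction. A standalone C++ sketch of the same idea, with invented names (not part of the patch):

// --- standalone illustration, not part of the patch ---
#include <array>
#include <cstdio>

// Stage 1 (InnerParallel): reduce a 4x8 "vector" over its leading dimension
// with purely lane-wise adds; no horizontal reduction is introduced yet.
std::array<float, 8> reduceLeadingDim(
    const std::array<std::array<float, 8>, 4> &v) {
  std::array<float, 8> acc{};
  for (const auto &row : v)
    for (int j = 0; j < 8; ++j) acc[j] += row[j];  // elementwise, vectorizable
  return acc;
}

// Stage 2 (InnerReduction): the remaining 1-D reduction becomes a single
// horizontal reduction, the only place where lanes are combined.
float reduceHorizontally(const std::array<float, 8> &v) {
  float acc = 0.0f;
  for (float x : v) acc += x;
  return acc;
}

int main() {
  std::array<std::array<float, 8>, 4> v{};
  for (int i = 0; i < 4; ++i)
    for (int j = 0; j < 8; ++j) v[i][j] = 1.0f;
  std::printf("%f\n", reduceHorizontally(reduceLeadingDim(v)));  // 32.0
  return 0;
}
// --- end illustration ---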
- mlir::LogicalResult RewriteTwoAndMoreDimReductions(MLIRContext* ctx, - Operation* op) const { - mlir::ConversionTarget target(*ctx); - target.addLegalDialect(); - target.addDynamicallyLegalOp( - [&](MultiDimReductionOp op) { + LogicalResult rewriteTwoAndMoreDimReductions(MLIRContext* ctx, + Operation* op) const { + ConversionTarget target(*ctx); + target.addLegalDialect(); + target.addDynamicallyLegalOp( + [&](vector::MultiDimReductionOp op) { return op.getSourceVectorType().getRank() == 1; }); - mlir::RewritePatternSet patterns(ctx); - mlir::vector::populateVectorMultiReductionLoweringPatterns( - patterns, VectorMultiReductionLowering::InnerParallel); + RewritePatternSet patterns(ctx); + vector::populateVectorMultiReductionLoweringPatterns( + patterns, vector::VectorMultiReductionLowering::InnerParallel); return applyPartialConversion(op, target, std::move(patterns)); } // Rewrite 1D reductions as a `vector.reduction`. - mlir::LogicalResult RewriteOneDimReductions(MLIRContext* ctx, - Operation* op) const { - mlir::RewritePatternSet patterns(ctx); - mlir::vector::populateVectorMultiReductionLoweringPatterns( - patterns, VectorMultiReductionLowering::InnerReduction); + LogicalResult rewriteOneDimReductions(MLIRContext* ctx, Operation* op) const { + RewritePatternSet patterns(ctx); + vector::populateVectorMultiReductionLoweringPatterns( + patterns, vector::VectorMultiReductionLowering::InnerReduction); return applyPatternsAndFoldGreedily(op, std::move(patterns)); } }; } // namespace -std::unique_ptr> +std::unique_ptr> createRewriteVectorMultiReductionPass() { return std::make_unique(); } -} // namespace tensorflow +} // namespace gml_st +} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/rewrite_vector_ops/rewrite_vector_transpose.cc b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/rewrite_vector_ops/rewrite_vector_transpose.cc new file mode 100644 index 00000000000..ab2e372be70 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/rewrite_vector_ops/rewrite_vector_transpose.cc @@ -0,0 +1,66 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "gml_st/transforms/passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" +#include "mlir/Dialect/X86Vector/Transforms.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace gml_st { +namespace { + +#define GEN_PASS_DEF_REWRITEVECTORTRANSPOSEPASS +#include "gml_st/transforms/passes.h.inc" + +struct RewriteVectorTransposePass + : public impl::RewriteVectorTransposePassBase { + void runOnOperation() override { + auto avxLoweringOptions = + x86vector::avx2::LoweringOptions().setTransposeOptions( + x86vector::avx2::TransposeLoweringOptions() + .lower4x8xf32() + .lower8x8xf32()); + + func::FuncOp funcOp = getOperation(); + MLIRContext *context = funcOp.getContext(); + RewritePatternSet patterns(context); + vector::VectorTransformsOptions vectorTransformOptions; + vectorTransformOptions = vectorTransformOptions.setVectorTransposeLowering( + vector::VectorTransposeLowering::EltWise); + vector::populateVectorTransposeLoweringPatterns(patterns, + vectorTransformOptions); + x86vector::avx2::populateSpecializedTransposeLoweringPatterns( + patterns, avxLoweringOptions, /*benefit=*/10); + + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr> +createRewriteVectorTransposePass() { + return std::make_unique(); +} + +} // namespace gml_st +} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Transforms/scalarization.cc b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/scalarization/scalarization.cc similarity index 53% rename from tensorflow/compiler/xla/mlir_hlo/lib/Transforms/scalarization.cc rename to tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/scalarization/scalarization.cc index f538c4ccde2..11d3731504d 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Transforms/scalarization.cc +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/scalarization/scalarization.cc @@ -14,11 +14,13 @@ limitations under the License. ==============================================================================*/ #include +#include #include -#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h" -#include "mlir-hlo/Dialect/thlo/IR/thlo_ops.h" -#include "mlir-hlo/Transforms/passes.h" +#include "gml_st/IR/gml_st_ops.h" +#include "gml_st/transforms/passes.h" +#include "gml_st/transforms/transforms.h" +#include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -26,104 +28,113 @@ limitations under the License. 
#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Utils/Utils.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "thlo/IR/thlo_ops.h" namespace mlir { +namespace gml_st { namespace { #define GEN_PASS_DEF_SCALARIZATIONPASS -#include "mlir-hlo/Transforms/passes.h.inc" +#include "gml_st/transforms/passes.h.inc" -using linalg::GenericOp; +using linalg::LinalgOp; using tensor::ExtractOp; using tensor::FromElementsOp; using tensor::InsertOp; -template -bool hasSingleElement(ShapedTy type) { - return type.hasStaticShape() && type.getNumElements() == 1; -} - -struct ScalarizeGenericOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +Value materializePoint(OpBuilder &b, Location loc, Value valueToTile, + ArrayRef offsets) { + auto tensorType = valueToTile.getType().cast(); + int64_t rank = tensorType.getRank(); - LogicalResult matchAndRewrite(GenericOp genericOp, - PatternRewriter &rewriter) const override { - auto isNonScalar = [](Type type) { - return type.isa() && - !hasSingleElement(type.cast()); - }; - if (llvm::any_of(genericOp.getOperandTypes(), isNonScalar) || - llvm::any_of(genericOp.getResultTypes(), isNonScalar)) - return failure(); + IntegerAttr oneAttr = b.getIndexAttr(1); + SmallVector sizes(rank, oneAttr); + SmallVector strides(rank, oneAttr); + + Value slice = b.create(loc, valueToTile, offsets, + sizes, strides); + Value zero = b.create(loc, 0); + return b.create(loc, slice, + SmallVector(rank, zero)); +} - // Map block arguments of genericOp to tensor.extract ops of its args. - Location loc = genericOp.getLoc(); - BlockAndValueMapping bvm; - for (OpOperand &opOperand : genericOp->getOpOperands()) { - Value operandValue = opOperand.get(); - Type operandType = operandValue.getType(); - auto bbArg = genericOp.getMatchingBlockArgument(&opOperand); - if (!operandType.isa()) continue; - - SmallVector indices( - operandType.cast().getRank(), - rewriter.create(loc, 0)); - Value extractedElement = - rewriter.create(loc, operandValue, indices); - bvm.map(bbArg, extractedElement); - } +struct ScalarizeLinalgOp : public OpInterfaceRewritePattern { + using OpInterfaceRewritePattern::OpInterfaceRewritePattern; + static LogicalResult inlinePayload(PatternRewriter &rewriter, Location loc, + LinalgOp linalgOp, ValueRange argValues) { // Clone everything but terminator. - Block *body = genericOp.getBody(); - for (Operation &op : body->without_terminator()) { - // `linalg.index` can only result in 0 for scalar linalg.generic. - if (auto indexOp = dyn_cast(op)) { + Block *body = linalgOp.getBlock(); + IRMapping map; + map.map(body->getArguments(), argValues); + for (auto &op : body->without_terminator()) { + if (auto indexOp = dyn_cast(&op)) { Value zero = rewriter.create(loc, 0); - bvm.map(indexOp.getResult(), zero); + map.map(indexOp.getResult(), zero); continue; } - rewriter.clone(op, bvm); + rewriter.clone(op, map); } // Wrap every scalar result into a tensor using `tensor.from_elements`. 
SmallVector newResults; for (auto [resultType, yieldOperand] : - llvm::zip(genericOp->getResultTypes(), + llvm::zip(linalgOp->getResultTypes(), body->getTerminator()->getOperands())) { - auto scalarValue = bvm.lookupOrDefault(yieldOperand); + auto scalarValue = map.lookupOrDefault(yieldOperand); newResults.push_back( rewriter.create(loc, resultType, scalarValue)); } - rewriter.replaceOp(genericOp, newResults); - + rewriter.replaceOp(linalgOp, newResults); return success(); } -}; - -// Extracts a point using gml_st.materialize and gml_st.tile with 1 element. -Value getPoint(OpBuilder &b, Location loc, Value tensor, ValueRange indices) { - IntegerAttr oneAttr = b.getIndexAttr(1); - - auto tensorType = tensor.getType().cast(); - int64_t tensorRank = tensorType.getRank(); - SmallVector offsets(indices.begin(), indices.end()); - SmallVector sizes(tensorRank, oneAttr); - SmallVector strides(tensorRank, oneAttr); + LogicalResult matchAndRewrite(LinalgOp linalgOp, + PatternRewriter &rewriter) const override { + // Fail if not every argument is a scalar or a single-element tensor. + if (!hasSingleElementOperandsAndResults(linalgOp)) return failure(); + + // TODO(aliia): fix scalarization of FillOp. + if (auto *fillOp = dyn_cast(&linalgOp)) return failure(); + + // Load the data corresponding to the block arguments that + // represent input operands. + SmallVector indexedValues; + indexedValues.reserve(linalgOp->getNumOperands()); + Location loc = linalgOp->getLoc(); + auto zero = rewriter.create(loc, 0); + for (OpOperand &operand : linalgOp->getOpOperands()) { + if (!linalgOp.payloadUsesValueFromOperand(&operand)) { + indexedValues.push_back(nullptr); + continue; + } + if (linalgOp.isScalar(&operand)) { + indexedValues.push_back(operand.get()); + continue; + } + Value operandValue = operand.get(); + Type operandType = operandValue.getType(); + SmallVector indices(operandType.cast().getRank(), + zero); + Value load = rewriter.create(loc, operandValue, indices); + indexedValues.push_back(load); + } - Value tile = b.create(loc, offsets, sizes, strides); - return b.create(loc, tensorType.getElementType(), - tensor, tile); -} + // Inline the op payload and rewrite the operation. + return inlinePayload(rewriter, loc, linalgOp, indexedValues); + } +}; // Returns `startIndices`[0, :] for `startIndices` of shape 1xn. Returns None if // startIndices has a different shape. 
Optional> extractStartIndices( - ImplicitLocOpBuilder &b, TypedValue startIndices) { + ImplicitLocOpBuilder &b, TypedValue startIndices) { if (startIndices.getType().getRank() != 2 || startIndices.getType().getDimSize(0) != 1) { - return llvm::None; + return std::nullopt; } int64_t indexVectorSize = startIndices.getType().getDimSize(1); @@ -163,17 +174,10 @@ struct ScalarizeScatterOp : public OpRewritePattern { auto initType = init.getType().dyn_cast(); if (!initType) return failure(); - int64_t initRank = initType.getRank(); - SmallVector initDimSizes = tensor::getMixedSizes(b, loc, init); auto initDimValues = getValueOrCreateConstantIndexOp(b, loc, initDimSizes); - Value initTile = b.create( - loc, SmallVector(initRank, b.getI64IntegerAttr(0)), - initDimSizes, - SmallVector(initRank, b.getI64IntegerAttr(1))); - Value zero = b.create(0); Value one = b.create(1); @@ -181,56 +185,68 @@ struct ScalarizeScatterOp : public OpRewritePattern { SmallVector lbs(updatesRank, zero); SmallVector steps(updatesRank, one); - auto loop = b.create( - TypeRange(ValueRange{init}), lbs, updatesDimValues, steps, init, - [&](OpBuilder &nestedBuilder, Location bodyLoc, ValueRange updateIndex, - ValueRange loopInits) { - Value initBlockArg = loopInits.front(); - - auto initIndex = llvm::to_vector(updateIndex.drop_front()); - for (const auto &en : llvm::enumerate(*scatterIndices)) { - initIndex[en.index()] = nestedBuilder.create( - bodyLoc, initIndex[en.index()], en.value()); - } + SmallVector limitIndex{ + ArrayRef(updatesDimValues).drop_front()}; + for (const auto &en : llvm::enumerate(*scatterIndices)) { + limitIndex[en.index()] = + b.create(loc, limitIndex[en.index()], en.value()); + } + for (auto &value : limitIndex) { + value = b.create(loc, value, one); + } - Value indexIsInBounds = - isValidIndex(nestedBuilder, loc, initIndex, initDimValues, zero); - Value maybeUpdatedInit = - nestedBuilder - .create( - loc, initType, indexIsInBounds, - [&](OpBuilder &thenBuilder, Location thenLoc) { - Value updateValue = - getPoint(thenBuilder, loc, updates, updateIndex); - Value currentValue = - getPoint(thenBuilder, loc, initBlockArg, initIndex); - - // Combine update with the value in the output. - Block *body = scatterOp.getBody(); - BlockAndValueMapping bvm; - bvm.map(body->getArgument(0), updateValue); - bvm.map(body->getArgument(1), currentValue); - - for (Operation &op : body->without_terminator()) - thenBuilder.clone(op, bvm); - - // Wrap every scalar result into a tensor using - // `tensor.from_elements`. 
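The rewritten scatter hoists the bounds check out of the update loop: it derives the largest index the update will touch and guards the whole loop nest with a single scf.if. A 1-D standalone C++ sketch of that control structure follows; the additive combiner is an assumption made for the example, the real combiner is whatever the scatter body computes:

// --- standalone illustration, not part of the patch ---
#include <cstdio>
#include <vector>

// Scatter a contiguous block `updates` into `init` at `start`. The bounds
// check is hoisted: the update runs only if both the first and the last
// touched index are in range (mirroring limitIndex / isValidIndex above);
// otherwise `init` is returned unchanged.
void scatter1D(std::vector<float> &init, const std::vector<float> &updates,
               long start) {
  long limit = start + static_cast<long>(updates.size()) - 1;
  bool inBounds = start >= 0 && limit < static_cast<long>(init.size());
  if (!inBounds) return;  // the scf.if else-branch: yield init as-is
  for (std::size_t i = 0; i < updates.size(); ++i)
    init[start + i] += updates[i];  // combiner: plain addition for the example
}

int main() {
  std::vector<float> init(8, 0.0f);
  scatter1D(init, {1.0f, 2.0f, 3.0f}, 4);  // applied
  scatter1D(init, {1.0f, 2.0f, 3.0f}, 7);  // out of bounds, skipped entirely
  std::printf("init[4]=%f init[7]=%f\n", init[4], init[7]);
  return 0;
}
// --- end illustration ---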
- auto combinedValue = - bvm.lookup(body->getTerminator()->getOperand(0)); - Value updatedInit = thenBuilder.create( - thenLoc, combinedValue, initBlockArg, initIndex); - thenBuilder.create(thenLoc, updatedInit); - }, - [&](OpBuilder &elseBuilder, Location elseLoc) { - elseBuilder.create(elseLoc, initBlockArg); - }) - .getResult(0); - - nestedBuilder.create(bodyLoc, maybeUpdatedInit, - initBlockArg, initTile); + Value indexIsInBounds = + isValidIndex(b, loc, limitIndex, initDimValues, zero); + indexIsInBounds = b.create( + loc, indexIsInBounds, + isValidIndex(b, loc, *scatterIndices, initDimValues, zero)); + auto ifOp = b.create( + loc, indexIsInBounds, + [&](OpBuilder &thenBuilder, Location thenLoc) { + scf::LoopNest loopNest = scf::buildLoopNest( + thenBuilder, thenLoc, lbs, updatesDimValues, steps, + ValueRange{init}, + [&](OpBuilder &nestedBuilder, Location bodyLoc, + ValueRange updateIndex, ValueRange loopInits) { + Value initBlockArg = loopInits.front(); + + auto initIndex = llvm::to_vector(updateIndex.drop_front()); + for (const auto &en : llvm::enumerate(*scatterIndices)) { + initIndex[en.index()] = nestedBuilder.create( + bodyLoc, initIndex[en.index()], en.value()); + } + + Value updateValue = materializePoint( + thenBuilder, loc, updates, getAsOpFoldResult(updateIndex)); + Value currentValue = + materializePoint(thenBuilder, loc, initBlockArg, + getAsOpFoldResult(initIndex)); + + // Combine update with the value in the output. + Block *body = scatterOp.getBody(); + IRMapping bvm; + bvm.map(body->getArgument(0), updateValue); + bvm.map(body->getArgument(1), currentValue); + + for (Operation &op : body->without_terminator()) + thenBuilder.clone(op, bvm); + + // Wrap every scalar result into a tensor using + // `tensor.from_elements`. + auto combinedValue = + bvm.lookup(body->getTerminator()->getOperand(0)); + Value updatedInit = thenBuilder.create( + thenLoc, combinedValue, initBlockArg, initIndex); + + return scf::ValueVector({updatedInit}); + }); + + thenBuilder.create(thenLoc, loopNest.results); + }, + [&](OpBuilder &elseBuilder, Location elseLoc) { + elseBuilder.create(elseLoc, init); }); - rewriter.replaceOp(scatterOp, loop.getResults()); + rewriter.replaceOp(scatterOp, ifOp.getResults()); return success(); } @@ -277,12 +293,25 @@ struct ScalarizeGatherOp : public OpRewritePattern { TypedValue operand = gatherOp.getOperand(); auto operandSizes = getValueOrCreateConstantIndexOp( b, loc, tensor::createDimValues(b, loc, operand)); - Value zero = b.create(0); Value one = b.create(1); + + SmallVector sliceSizes{initDimSizeValues.begin() + 1, + initDimSizeValues.end()}; + while (sliceSizes.size() < startIndices->size()) { + sliceSizes.push_back(one); + } + + // Clamp the indices. + for (auto &&[startIndex, max, sliceSize] : + llvm::zip(*startIndices, operandSizes, sliceSizes)) { + auto maxMinusSize = b.createOrFold(loc, max, sliceSize); + startIndex = b.create(loc, startIndex, maxMinusSize); + startIndex = b.create(loc, startIndex, zero); + } + SmallVector lbs(initRank, zero); SmallVector steps(initRank, one); - rewriter.replaceOpWithNewOp( gatherOp, TypeRange(ValueRange{init}), lbs, initDimSizeValues, steps, init, @@ -290,28 +319,19 @@ struct ScalarizeGatherOp : public OpRewritePattern { ValueRange loopInits) { // Compute the index in the operand. 
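The gather rewrite above clamps each start index once, before the loop nest, to the range [0, dimSize - sliceSize], so the per-element reads inside the loop need no further guarding. A 1-D plain C++ sketch with invented names (not part of the patch):

// --- standalone illustration, not part of the patch ---
#include <algorithm>
#include <cstdio>
#include <vector>

// Extract a window of `sliceSize` elements starting at `start`, clamping the
// start index up front instead of clamping every read index inside the loop.
std::vector<float> gather1D(const std::vector<float> &operand, long start,
                            long sliceSize) {
  long maxStart = static_cast<long>(operand.size()) - sliceSize;
  start = std::max(0L, std::min(start, maxStart));  // clamp once, outside loop
  std::vector<float> out(sliceSize);
  for (long i = 0; i < sliceSize; ++i)
    out[i] = operand[start + i];  // readIndex = iv + clamped start index
  return out;
}

int main() {
  std::vector<float> data{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f};
  auto w = gather1D(data, 6, 4);  // start clamped from 6 down to 4
  std::printf("%f %f %f %f\n", w[0], w[1], w[2], w[3]);  // 4 5 6 7
  return 0;
}
// --- end illustration ---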
SmallVector readIndices(operand.getType().getRank(), zero); - llvm::copy(ivs, readIndices.begin()); + llvm::copy(ivs.drop_front(1), readIndices.begin()); for (auto &&[readIndex, startIndex] : llvm::zip(readIndices, *startIndices)) { readIndex = nestedBuilder.create(bodyLoc, readIndex, startIndex); } - // Clamp the indices. - for (auto &&[readIndex, max] : llvm::zip(readIndices, operandSizes)) { - auto maxMinusOne = - nestedBuilder.createOrFold(bodyLoc, max, one); - readIndex = nestedBuilder.create(bodyLoc, readIndex, - maxMinusOne); - readIndex = - nestedBuilder.create(bodyLoc, readIndex, zero); - } - // Materialize the value and yield it. SmallVector ones(initRank, oneAttr); Value tile = nestedBuilder.create( bodyLoc, SmallVector(ivs), ones, ones); - Value val = getPoint(nestedBuilder, bodyLoc, operand, readIndices); + Value val = materializePoint(nestedBuilder, bodyLoc, operand, + getAsOpFoldResult(readIndices)); nestedBuilder.create(bodyLoc, val, loopInits.front(), tile); }); @@ -328,7 +348,7 @@ struct ScalarizeConcatenateOp : public OpRewritePattern { LogicalResult matchAndRewrite(thlo::ConcatenateOp concatenateOp, PatternRewriter &rewriter) const override { Location loc = concatenateOp.getLoc(); - int64_t concatDim = concatenateOp.getDimension(); + int64_t concatDim = concatenateOp.getDimension().getSExtValue(); auto initTensor = concatenateOp.getInit(); auto initType = initTensor.getType(); @@ -352,10 +372,10 @@ struct ScalarizeConcatenateOp : public OpRewritePattern { sizes.emplace_back(rewriter.create(loc, initTensor, i)); } } - Value tile = rewriter.create(loc, offsets, sizes, strides); auto materializeAndInsert = [&](OpBuilder &b, Location l, Value input) { - Value slice = b.create(l, input, tile); + Value slice = + b.create(l, input, offsets, sizes, strides); return b.create(l, slice, initTensor, offsets, sizes, strides); }; @@ -389,8 +409,7 @@ struct ScalarizeConcatenateOp : public OpRewritePattern { return b .create( - loc, resultType, - tensorHasElement(b, loc, inputs.front(), concatDim), + loc, tensorHasElement(b, loc, inputs.front(), concatDim), [&](OpBuilder &thenBuilder, Location thenLoc) { thenBuilder.create( thenLoc, @@ -406,56 +425,124 @@ struct ScalarizeConcatenateOp : public OpRewritePattern { } }; +namespace { +LogicalResult scalarizeOp(Operation *op, PatternRewriter &rewriter, + TypedValue &input, + TypedValue &output) { + ImplicitLocOpBuilder b(op->getLoc(), rewriter); + + auto outputType = output.getType().dyn_cast(); + if (!outputType) { + return rewriter.notifyMatchFailure( + op, "failed to cast output to RankedTensorType"); + } + if (!hasSingleElement(outputType)) { + return rewriter.notifyMatchFailure( + op, "has output with number of elements not equal to 1"); + } + + auto inputType = input.getType().dyn_cast(); + if (!inputType) { + return rewriter.notifyMatchFailure( + op, "failed to cast input to RankedTensorType"); + } + + Value zero = b.create(0); + llvm::SmallVector indicesInput(inputType.getRank(), zero); + llvm::SmallVector indicesOutput(outputType.getRank(), zero); + + Value extractedValue = b.create(input, indicesInput); + Value result = b.create(outputType, extractedValue); + + rewriter.replaceOp(op, result); + return success(); +} + +} // namespace + struct ScalarizeDynamicBroadcastInDimOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(thlo::DynamicBroadcastInDimOp broadcastOp, PatternRewriter &rewriter) const override { - Location loc = broadcastOp.getLoc(); - ImplicitLocOpBuilder b(loc, 
rewriter); - - auto output = broadcastOp.getInit(); - auto outputType = output.getType().dyn_cast(); - if (!outputType) return failure(); - - if (!hasSingleElement(outputType)) return failure(); - auto input = broadcastOp.getOperand(); - auto inputType = input.getType().dyn_cast(); - if (!inputType) return failure(); - - Value zero = b.create(0); - llvm::SmallVector indicesInput(inputType.getRank(), zero); - llvm::SmallVector indicesOutput(outputType.getRank(), zero); + auto output = broadcastOp.getInit(); + return scalarizeOp(broadcastOp, rewriter, input, output); + } +}; - Value extractedValue = b.create(input, indicesInput); - Value result = - b.create(extractedValue, output, indicesOutput); +struct ScalarizeReverseOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - rewriter.replaceOp(broadcastOp, result); - return success(); + LogicalResult matchAndRewrite(thlo::ReverseOp reverseOp, + PatternRewriter &rewriter) const override { + auto input = reverseOp.getInput(); + auto output = reverseOp.getInit(); + return scalarizeOp(reverseOp, rewriter, input, output); } }; -// Fold `tensor.extract(gml_st.materialize -> tensor<1x1xf32>)` into -// `gml_st.materialize -> f32` for single-element tensors. -struct FoldTensorExtractIntoMaterialize : public OpRewritePattern { +struct ScalarizeIfOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(ExtractOp extractOp, + LogicalResult matchAndRewrite(scf::IfOp op, PatternRewriter &rewriter) const override { - auto materializeOp = - extractOp.getTensor().getDefiningOp(); - if (!materializeOp) return failure(); + // Analyse result types and determine what we can scalarize. + int64_t numResults = op.getNumResults(); + SmallVector isScalarizableResult(numResults, false); + SmallVector unscalarizedResultType = + llvm::to_vector(op.getResultTypes()); + SmallVector scalarizedResultType = + llvm::to_vector(op.getResultTypes()); + bool isAnyResultScalarizable = false; + for (int64_t i = 0; i < numResults; ++i) { + auto rankedTy = scalarizedResultType[i].dyn_cast(); + if (!rankedTy || !hasSingleElement(rankedTy)) continue; + isScalarizableResult[i] = true; + scalarizedResultType[i] = rankedTy.getElementType(); + isAnyResultScalarizable = true; + } + + if (!isAnyResultScalarizable) { + return rewriter.notifyMatchFailure(op, "cannot scalarize any result"); + } - auto tileType = - materializeOp.getSet().getType().dyn_cast(); - if (!tileType || !hasSingleElement(tileType)) return failure(); + // Create new if op. + Location loc = op.getLoc(); + Value zero = rewriter.create(loc, 0); + auto scalarizedOp = rewriter.create(loc, scalarizedResultType, + op.getCondition()); + scalarizedOp.getThenRegion().takeBody(op.getThenRegion()); + scalarizedOp.getElseRegion().takeBody(op.getElseRegion()); + for (int64_t i = 0; i < numResults; ++i) { + if (!isScalarizableResult[i]) continue; + + // Insert `extract` ops to yield value as a scalar. + llvm::SmallVector zeroIndices( + unscalarizedResultType[i].cast().getRank(), zero); + rewriter.setInsertionPoint(scalarizedOp.thenYield()); + Value thenScalar = rewriter.createOrFold( + loc, scalarizedOp.thenYield().getOperand(i), zeroIndices); + scalarizedOp.thenYield().setOperand(i, thenScalar); + rewriter.setInsertionPoint(scalarizedOp.elseYield()); + Value elseScalar = rewriter.createOrFold( + loc, scalarizedOp.elseYield().getOperand(i), zeroIndices); + scalarizedOp.elseYield().setOperand(i, elseScalar); + } + + // Insert `from_elements` op to be type compatible. 
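ScalarizeIfOp only changes the types that cross the scf.if boundary: each branch yields the extracted scalar, and a from_elements after the op restores the tensor type for any remaining tensor users. A minimal C++ analogy (illustrative, not part of the patch):

// --- standalone illustration, not part of the patch ---
#include <array>
#include <cstdio>

using Tensor1 = std::array<float, 1>;  // stand-in for tensor<1xf32>

Tensor1 thenBranch() { return {1.0f}; }
Tensor1 elseBranch() { return {2.0f}; }

int main() {
  bool cond = true;

  // Before: the conditional produces a single-element tensor.
  Tensor1 before = cond ? thenBranch() : elseBranch();

  // After ScalarizeIfOp: each branch "yields" the extracted scalar, the
  // conditional itself carries only an f32, and a from_elements-style wrap
  // after it rebuilds the tensor for the remaining tensor users.
  float scalar = cond ? thenBranch()[0] : elseBranch()[0];
  Tensor1 after{scalar};

  std::printf("%f %f\n", before[0], after[0]);
  return 0;
}
// --- end illustration ---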
+ rewriter.setInsertionPointAfter(scalarizedOp); + SmallVector results(scalarizedOp.getResults()); + for (int64_t i = 0; i < numResults; ++i) { + if (!isScalarizableResult[i]) continue; + + // Wrap scalar. + results[i] = rewriter.create( + loc, unscalarizedResultType[i], results[i]); + } - rewriter.replaceOpWithNewOp( - extractOp, extractOp.getType(), materializeOp.getSource(), - materializeOp.getSet()); + rewriter.replaceOp(op, results); return success(); } }; @@ -492,8 +579,7 @@ struct FoldTensorFromElementsIntoSetYield }; void populateTensorInsertExtractFoldingPatterns(RewritePatternSet *patterns) { - patterns->add(patterns->getContext()); + patterns->add(patterns->getContext()); } struct ScalarizationPass @@ -508,9 +594,10 @@ struct ScalarizationPass ScalarizeConcatenateOp, ScalarizeDynamicBroadcastInDimOp, ScalarizeGatherOp, - ScalarizeGenericOp, - ScalarizeScatterOp - >(context); + ScalarizeIfOp, + ScalarizeLinalgOp, + ScalarizeReverseOp, + ScalarizeScatterOp>(context); // clang-format on populateTensorInsertExtractFoldingPatterns(&patterns); FromElementsOp::getCanonicalizationPatterns(patterns, context); @@ -525,4 +612,5 @@ std::unique_ptr> createScalarizationPass() { return std::make_unique(); } +} // namespace gml_st } // namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/test_passes.cc b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/test_passes.cc new file mode 100644 index 00000000000..b823f88e6db --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/test_passes.cc @@ -0,0 +1,95 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "gml_st/transforms/test_passes.h" + +#include +#include + +#include "gml_st/interfaces/bufferizable_op_interface_impl.h" +#include "gml_st/transforms/fusion/fusion.h" +#include "gml_st/transforms/peeling/peeling.h" +#include "gml_st/transforms/transforms.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" +#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace gml_st { +namespace { + +#define GEN_PASS_DEF_TESTGMLSTGREEDYFUSION +#include "gml_st/transforms/test_passes.h.inc" + +static constexpr llvm::StringRef kTestFusionAppliedLabel = + "__test_fusion_applied_label__"; + +struct GreedyFusionPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(gml_st::ParallelOp op, + PatternRewriter &rewriter) const override { + if (hasLabel(op, kTestFusionAppliedLabel)) return failure(); + + rewriter.updateRootInPlace(op, [&]() { + fuseGreedily(rewriter, op.getRegion().front(), [](Operation *op) { + return isa(op); + }); + }); + + setLabel(op, kTestFusionAppliedLabel); + return success(); + } +}; + +struct TestGmlStGreedyFusionPass + : public impl::TestGmlStGreedyFusionBase { + void getDependentDialects(DialectRegistry ®istry) const override { + registry + .insert(); + linalg::registerTilingInterfaceExternalModels(registry); + } + + void runOnOperation() override { + func::FuncOp funcOp = getOperation(); + + MLIRContext *ctx = funcOp.getContext(); + RewritePatternSet patterns(ctx); + + patterns.add(ctx); + + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) + return signalPassFailure(); + + funcOp.walk([](gml_st::ParallelOp op) { + removeLabel(op, kTestFusionAppliedLabel); + }); + } +}; + +} // namespace + +std::unique_ptr> createTestGmlStGreedyFusionPass() { + return std::make_unique(); +} + +} // namespace gml_st +} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/test_passes.h b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/test_passes.h similarity index 62% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/test_passes.h rename to tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/test_passes.h index c4cc30d423d..6458bc5bc07 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/test_passes.h +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/test_passes.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_TEST_PASSES_H -#define MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_TEST_PASSES_H +#ifndef MLIR_HLO_GML_ST_TRANSFORMS_TEST_PASSES_H +#define MLIR_HLO_GML_ST_TRANSFORMS_TEST_PASSES_H #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" @@ -24,18 +24,14 @@ namespace mlir { namespace gml_st { #define GEN_PASS_DECL -#include "mlir-hlo/Dialect/gml_st/transforms/test_passes.h.inc" +#include "gml_st/transforms/test_passes.h.inc" -std::unique_ptr> createTestGmlStLoopPeelingPass(); - -std::unique_ptr> createTestGmlStLoopTilingPass(); - -std::unique_ptr> createTestGmlStBufferizationPass(); +std::unique_ptr> createTestGmlStGreedyFusionPass(); #define GEN_PASS_REGISTRATION -#include "mlir-hlo/Dialect/gml_st/transforms/test_passes.h.inc" +#include "gml_st/transforms/test_passes.h.inc" } // namespace gml_st } // namespace mlir -#endif // MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_TEST_PASSES_H +#endif // MLIR_HLO_GML_ST_TRANSFORMS_TEST_PASSES_H diff --git a/tensorflow/compiler/xla/mlir_hlo/tosa/include/mhlo_tosa/Transforms/passes.td b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/test_passes.td similarity index 75% rename from tensorflow/compiler/xla/mlir_hlo/tosa/include/mhlo_tosa/Transforms/passes.td rename to tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/test_passes.td index 90394952933..2be3577fc79 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tosa/include/mhlo_tosa/Transforms/passes.td +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/test_passes.td @@ -15,8 +15,7 @@ limitations under the License. include "mlir/Pass/PassBase.td" -def TosaLegalizeMhloPass : Pass<"tosa-legalize-mhlo", "mlir::func::FuncOp"> { - let summary = "Legalize from MHLO to TOSA"; - let constructor = "createLegalizeMhloPass()"; - let dependentDialects = ["::mlir::tosa::TosaDialect"]; -} \ No newline at end of file +def TestGmlStGreedyFusion : Pass<"test-gml-st-greedy-fusion", "mlir::func::FuncOp"> { + let summary = "Fuse ops greedily into gml-st loops."; + let constructor = "::mlir::gml_st::createTestGmlStGreedyFusionPass()"; +} diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/tiling.cc b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/tiling/tiling.cc similarity index 70% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/tiling.cc rename to tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/tiling/tiling.cc index 677918933b6..59a86a311a5 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/tiling.cc +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/tiling/tiling.cc @@ -13,30 +13,29 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "mlir-hlo/Dialect/gml_st/transforms/tiling.h" +#include "gml_st/transforms/tiling/tiling.h" #include #include +#include #include #include #include +#include "gml_st/IR/gml_st_ops.h" +#include "gml_st/transforms/passes.h" +#include "gml_st/transforms/transforms.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h" -#include "mlir-hlo/Dialect/gml_st/transforms/passes.h" -#include "mlir-hlo/Dialect/gml_st/transforms/tiling_interface.h" -#include "mlir-hlo/Dialect/gml_st/transforms/tiling_interface_impl.h" -#include "mlir-hlo/Dialect/gml_st/transforms/transforms.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tensor/Utils/Utils.h" -#include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { @@ -46,8 +45,7 @@ void TilingOptions::setTileSizeComputationFn(ArrayRef ts) { SmallVector tileSizes(ts.begin(), ts.end()); tileSizeComputationFn = [tileSizes](OpBuilder &b, Operation *op) { return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { - Value v = b.create(op->getLoc(), s); - return v; + return b.create(op->getLoc(), s).getResult(); })); }; } @@ -55,7 +53,9 @@ void TilingOptions::setTileSizeComputationFn(ArrayRef ts) { namespace { #define GEN_PASS_DEF_TILINGPASS -#include "mlir-hlo/Dialect/gml_st/transforms/passes.h.inc" +#include "gml_st/transforms/passes.h.inc" + +constexpr llvm::StringRef kTileAppliedLabel = "__tile_applied_label__"; // Compute tile size for the tile that starts at `offset`, has size `tileSize` // for the tensor with the dimension size `dimSize`. 
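The helper described by the comment above reduces to a clamp against the dimension boundary, with the min skipped when the tile size is 1 or divides the dimension evenly. A standalone C++ sketch (not part of the patch):

// --- standalone illustration, not part of the patch ---
#include <algorithm>
#include <cstdio>

// Size of the tile that starts at `offset`: the full `tileSize` everywhere
// except possibly the last tile, which is clipped to the dimension boundary.
// When the static tile size is 1 or divides the static dim size evenly, the
// function above returns `tileSize` directly instead of emitting a min.
long long computeTileSizeInDim(long long tileSize, long long dimSize,
                               long long offset) {
  return std::min(tileSize, dimSize - offset);
}

int main() {
  const long long dim = 10, tile = 4;
  for (long long off = 0; off < dim; off += tile)
    std::printf("offset %lld -> size %lld\n", off,
                computeTileSizeInDim(tile, dim, off));
  // offset 0 -> 4, offset 4 -> 4, offset 8 -> 2
  return 0;
}
// --- end illustration ---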
@@ -66,8 +66,8 @@ namespace { OpFoldResult computeTileSizeInDim(OpBuilder &builder, Location loc, OpFoldResult tileSize, OpFoldResult dimSize, OpFoldResult offset) { - Optional tileCst = getConstantIntValue(tileSize); - Optional dimCst = getConstantIntValue(dimSize); + std::optional tileCst = getConstantIntValue(tileSize); + std::optional dimCst = getConstantIntValue(dimSize); bool hasTileSizeOne = tileCst && *tileCst == 1; bool dividesEvenly = tileCst && dimCst && ((*dimCst % *tileCst) == 0); @@ -127,7 +127,7 @@ Operation *generateTileLoopNest(OpBuilder &builder, Location loc, nestedBuilder, bodyLoc, steps[index], ubs[index], iv); } }; - Optional distributionLabelAttr; + std::optional distributionLabelAttr; if (!distributionLabel.empty()) { distributionLabelAttr = StringAttr::get(builder.getContext(), distributionLabel); @@ -139,9 +139,9 @@ Operation *generateTileLoopNest(OpBuilder &builder, Location loc, getValueOrCreateConstantIndexOp(builder, loc, lbs), getValueOrCreateConstantIndexOp(builder, loc, ubs), getValueOrCreateConstantIndexOp(builder, loc, steps), - distributionLabelAttr, + dstOperands, distributionLabelAttr, [&](OpBuilder &nestedBuilder, Location bodyLoc, - ValueRange ivs) { + ValueRange ivs, ValueRange /*outputs*/) { buildBody(nestedBuilder, bodyLoc, ivs); }) .getOperation() @@ -160,38 +160,11 @@ Operation *generateTileLoopNest(OpBuilder &builder, Location loc, return loop; } -struct DimOfMaterializedTilePattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tensor::DimOp op, - PatternRewriter &rewriter) const override { - Operation *def = op.getSource().getDefiningOp(); - if (!def) return failure(); - - auto materializeOp = llvm::dyn_cast(def); - if (!materializeOp) return failure(); - - auto tileOp = materializeOp.getSet().getDefiningOp(); - if (!tileOp) return failure(); - - Optional indexOr = op.getConstantIndex(); - if (!indexOr.has_value()) return failure(); - - Value tileSizeValue = - tileOp.isDynamicSize(*indexOr) - ? tileOp.getDynamicSize(*indexOr) - : rewriter.create( - op.getLoc(), tileOp.getStaticSize(*indexOr)); - rewriter.replaceOp(op, tileSizeValue); - return success(); - } -}; - /// Pattern to tile an op that implements the `TilingInterface` using /// `gml_st.for` for iterating over the tiles. struct TilingPattern : public OpInterfaceRewritePattern { TilingPattern(MLIRContext *context, - llvm::function_ref filterFn, + llvm::function_ref filterFn, TilingOptions options, PatternBenefit benefit = 1) : OpInterfaceRewritePattern(context, benefit), filterFn(filterFn), @@ -199,10 +172,10 @@ struct TilingPattern : public OpInterfaceRewritePattern { LogicalResult matchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const override { - if (!filterFn || failed(filterFn(op)) || hasTransformationAttr(op)) + if (!filterFn || failed(filterFn(op)) || hasLabel(op, kTileAppliedLabel)) return failure(); - auto tilingResult = tile(options, rewriter, op); + auto tilingResult = tileUsingGmlSt(options, rewriter, op); if (failed(tilingResult)) return failure(); // If we did not tile (e.g. 
when all tile sizes are 0), do not replace @@ -210,12 +183,12 @@ struct TilingPattern : public OpInterfaceRewritePattern { if (tilingResult->loop != nullptr) { rewriter.replaceOp(op, tilingResult->loop->getResults()); } - setTransformationAttr(rewriter, tilingResult->tiledOp); + setLabel(tilingResult->tiledOps.front(), kTileAppliedLabel); return success(); } private: - llvm::function_ref filterFn; + llvm::function_ref filterFn; TilingOptions options; }; @@ -230,9 +203,9 @@ struct TilingPass : public impl::TilingPassBase { } void getDependentDialects(DialectRegistry ®istry) const final { - registry - .insert(); - registerGmlStTilingInterfaceExternalModels(registry); + registry.insert(); + linalg::registerTilingInterfaceExternalModels(registry); } void runOnOperation() override { @@ -252,7 +225,7 @@ struct TilingPass : public impl::TilingPassBase { })); }; - auto filterFn = [&](Operation *op) { + auto filterFn = [&](TilingInterface op) { if (!opName.empty() && op->getName().getStringRef() != opName) return failure(); if (!opLabel.empty() && !hasMatchingLabel(op, opLabel)) return failure(); @@ -260,23 +233,49 @@ struct TilingPass : public impl::TilingPassBase { }; RewritePatternSet patterns(ctx); populateTilingPatterns(ctx, filterFn, opts, &patterns); - patterns.add(ctx); if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) return signalPassFailure(); // Clean up by removing temporary attributes. - f.walk([](Operation *op) { removeTransformationAttr(op); }); + removeTilingLabels(f); } }; +template +void insertTerminatorAndUpdateOutputs(PatternRewriter &rewriter, + const TilingResult &tilingResult, + SetYieldOp terminator, + ValueRange dstOperands, + ValueRange outputTiles) { + auto parallelLoop = cast(tilingResult.loop); + rewriter.replaceOpWithNewOp( + terminator, tilingResult.tiledOps.front()->getResults(), + parallelLoop.getRegionOutputArgs(), outputTiles); + + if (auto dstOp = dyn_cast( + tilingResult.tiledOps.front())) { + for (auto [dst, regionArg] : + llvm::zip(dstOperands, parallelLoop.getRegionOutputArgs())) { + dst.replaceUsesWithIf(regionArg, [&](OpOperand &operand) { + Operation *owner = operand.getOwner(); + return isa(owner) && + owner->getParentOfType() == parallelLoop.getOperation(); + }); + } + } +} + } // namespace -FailureOr tile(const TilingOptions &options, - PatternRewriter &rewriter, TilingInterface op) { +FailureOr tileUsingGmlSt(const TilingOptions &options, + PatternRewriter &rewriter, + TilingInterface op) { + rewriter.setInsertionPoint(op); if (!options.tileSizeComputationFn) { return rewriter.notifyMatchFailure( op, "missing tile size computation function"); } + Location loc = op.getLoc(); // 1. Get the range of the loops that are represented by the operation. SmallVector iterationDomain = op.getIterationDomain(rewriter); @@ -294,57 +293,71 @@ FailureOr tile(const TilingOptions &options, } if (tileSizeVector.size() < iterationDomain.size()) { - auto zero = rewriter.create(op.getLoc(), 0); + auto zero = rewriter.create(loc, 0); tileSizeVector.append(numLoops - tileSizeVector.size(), zero); } if (llvm::all_of(tileSizeVector, mlir::gml_st::isZero)) { - return TilingResult{op, nullptr}; + return TilingResult{{op}, nullptr}; } // 3. Materialize an empty loop nest that iterates over the tiles. 
- auto dstOperands = op.getDestinationOperands(rewriter); + SmallVector dstOperands; + if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, dstOperands))) + return rewriter.notifyMatchFailure(op, "failed to get destinations"); SmallVector offsets, sizes; TilingResult tilingResult; tilingResult.loop = generateTileLoopNest( - rewriter, op.getLoc(), iterationDomain, tileSizeVector, dstOperands, + rewriter, loc, iterationDomain, tileSizeVector, dstOperands, options.distribute, options.distributionLabel, offsets, sizes); Block *loopBody = &tilingResult.loop->getRegion(0).front(); - Operation *terminator = loopBody->getTerminator(); + auto terminator = cast(loopBody->getTerminator()); rewriter.setInsertionPoint(terminator); // 4. Insert the tiled implementation within the loop. - TilingInterface tiledOp = op.getTiledImplementation(rewriter, offsets, sizes); - tilingResult.tiledOp = tiledOp.getOperation(); - - // 5. Add `gml_st.set_yield` terminator. - SmallVector dstSubsets; - for (Value dst : tiledOp.getDestinationOperands(rewriter)) - dstSubsets.push_back(dst.getDefiningOp().getSet()); - rewriter.replaceOpWithNewOp( - terminator, tilingResult.tiledOp->getResults(), dstOperands, dstSubsets); - - // 6. Replace the uses of `outputs` with the output block arguments. - if (!options.distribute) { - auto forLoop = cast(tilingResult.loop); - for (auto [dst, regionArg] : - llvm::zip(dstOperands, forLoop.getRegionOutputArgs())) { - dst.replaceUsesWithIf(regionArg, [&](OpOperand &operand) { - return operand.getOwner()->getBlock() == loopBody; - }); + tilingResult.tiledOps = op.getTiledImplementation(rewriter, offsets, sizes); + + // 5. Compute tiles for the insertion. + int64_t numResults = op->getNumResults(); + SmallVector outputTiles; + auto oneAttr = rewriter.getI64IntegerAttr(1); + for (const auto &result : llvm::enumerate(op->getResults())) { + SmallVector resultOffsetsList(numResults), + resultSizesList(numResults); + if (failed(op.getResultTilePosition(rewriter, result.index(), offsets, + sizes, resultOffsetsList, + resultSizesList))) { + return rewriter.notifyMatchFailure( + op, "failed to get slice of result produced"); } + outputTiles.push_back(rewriter.createOrFold( + loc, resultOffsetsList, resultSizesList, + SmallVector(resultSizesList.size(), oneAttr))); } + // 6. Add a `set_yield` terminator, update the uses of `outputs` with the + // output bbArgs. 
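Steps 1 through 6 above are easier to follow against a scalar mock-up. The plain C++ sketch below mirrors the driver for a 1-D elementwise op; it is illustrative only, the real code builds gml_st loops and set_yield terminators rather than writing through vectors:

// --- standalone illustration, not part of the patch ---
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  const long dim = 10, tileSize = 4;
  std::vector<float> input(dim, 1.0f);
  std::vector<float> init(dim, 0.0f);  // destination operand (dst)

  for (long offset = 0; offset < dim; offset += tileSize) {  // loop nest
    long size = std::min(tileSize, dim - offset);            // tile size
    std::vector<float> tile(size);
    for (long i = 0; i < size; ++i)                          // tiled op
      tile[i] = input[offset + i] + 1.0f;
    for (long i = 0; i < size; ++i)                          // set_yield:
      init[offset + i] = tile[i];                            // write result tile
  }

  std::printf("init[0]=%f init[9]=%f\n", init[0], init[9]);  // 2.0 2.0
  return 0;
}
// --- end illustration ---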
+ if (options.distribute) { + insertTerminatorAndUpdateOutputs( + rewriter, tilingResult, terminator, dstOperands, outputTiles); + } else { + insertTerminatorAndUpdateOutputs(rewriter, tilingResult, terminator, + dstOperands, outputTiles); + } return tilingResult; } void populateTilingPatterns( MLIRContext *context, - llvm::function_ref filterFn, + llvm::function_ref filterFn, const TilingOptions &opts, RewritePatternSet *patterns) { patterns->add(context, filterFn, opts); } +void removeTilingLabels(Operation *op) { + op->walk([](Operation *op) { removeLabel(op, kTileAppliedLabel); }); +} + std::unique_ptr> createTilingPass( StringRef opName, StringRef opLabel, bool distribute, ArrayRef tileSizes) { diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/tiling.h b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/tiling/tiling.h similarity index 74% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/tiling.h rename to tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/tiling/tiling.h index d856e33dd6c..c0a24c9246d 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/tiling.h +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/tiling/tiling.h @@ -13,20 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_TILING_H -#define MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_TILING_H +#ifndef MLIR_HLO_GML_ST_TRANSFORMS_TILING_TILING_H +#define MLIR_HLO_GML_ST_TRANSFORMS_TILING_TILING_H #include #include -#include "mlir-hlo/Dialect/gml_st/transforms/tiling_interface.h" +#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/TilingInterface.h" namespace mlir { namespace gml_st { struct TilingResult { - Operation *tiledOp = nullptr; + SmallVector tiledOps; Operation *loop = nullptr; }; @@ -53,16 +54,20 @@ struct TilingOptions { /// Create tiled operation based on the specified tiling options. The result is /// equivalent to original op. -FailureOr tile(const TilingOptions &options, - PatternRewriter &rewriter, TilingInterface op); +FailureOr tileUsingGmlSt(const TilingOptions &options, + PatternRewriter &rewriter, + TilingInterface op); /// Populate tiling patterns. void populateTilingPatterns( MLIRContext *context, - llvm::function_ref filterFn, + llvm::function_ref filterFn, const TilingOptions &opts, RewritePatternSet *patterns); +/// Cleans up attributes from applying above tiling patterns. +void removeTilingLabels(Operation *op); + } // namespace gml_st } // namespace mlir -#endif // MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_TILING_H +#endif // MLIR_HLO_GML_ST_TRANSFORMS_TILING_TILING_H diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/tiling_softmax.cc b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/tiling_softmax/tiling_softmax.cc similarity index 71% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/tiling_softmax.cc rename to tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/tiling_softmax/tiling_softmax.cc index 4191055aa26..67e465d272d 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/tiling_softmax.cc +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/tiling_softmax/tiling_softmax.cc @@ -14,17 +14,18 @@ limitations under the License. 
==============================================================================*/ #include +#include #include #include -#include "mlir-hlo/Dialect/gml_st/transforms/fusion.h" -#include "mlir-hlo/Dialect/gml_st/transforms/linalg_utils.h" -#include "mlir-hlo/Dialect/gml_st/transforms/passes.h" -#include "mlir-hlo/Dialect/gml_st/transforms/tiling.h" -#include "mlir-hlo/Dialect/gml_st/transforms/tiling_interface_impl.h" -#include "mlir-hlo/Dialect/gml_st/transforms/transforms.h" +#include "gml_st/transforms/fusion/fusion.h" +#include "gml_st/transforms/passes.h" +#include "gml_st/transforms/tiling/tiling.h" +#include "gml_st/transforms/transforms.h" +#include "gml_st/utils/linalg_utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -33,11 +34,15 @@ namespace gml_st { namespace { #define GEN_PASS_DEF_TILINGSOFTMAXPASS -#include "mlir-hlo/Dialect/gml_st/transforms/passes.h.inc" +#include "gml_st/transforms/passes.h.inc" + +constexpr llvm::StringRef kTileSoftmaxAppliedLabel = + "__tile_softmax_applied_label__"; Operation *fuseIthOperandInPlace(PatternRewriter &rewriter, Operation *op, int64_t i) { - auto matOp = llvm::cast(op->getOperand(i).getDefiningOp()); + auto matOp = + llvm::cast(op->getOperand(i).getDefiningOp()); FailureOr fused = createFusedOp(rewriter, matOp); assert(succeeded(fused) && "expect success after matching"); rewriter.replaceOp(matOp, *fused); @@ -49,8 +54,6 @@ LogicalResult tilePartialSoftmax( llvm::function_ref(Operation *, int64_t)> tileOperationFn) { // Match cwise root op. - if (!isCwiseGenericOp(op)) return failure(); - // Match all operands to be derived from the same source value in one of two // ways: // i) by a reduction and subsequent bcast in one dimension, or @@ -58,42 +61,49 @@ LogicalResult tilePartialSoftmax( Value commonSource; Optional commonReductionDim; SmallVector> simpleBcastReductions; - auto genericOp = llvm::dyn_cast_or_null(op.getOperation()); - for (Value operand : genericOp.getInputs()) { + auto mapOp = llvm::dyn_cast_or_null(op.getOperation()); + if (!mapOp || mapOp.getNumDpsInits() != 1) + return rewriter.notifyMatchFailure(op, "no mapOp"); + for (Value operand : mapOp.getInputs()) { // Case i. SimpleBcastReduction bcastReduction; int64_t reductionDim; if (isSimpleBcastReduction(operand.getDefiningOp(), &reductionDim, &bcastReduction)) { if (commonSource && commonSource != bcastReduction.operand) { - return failure(); + return rewriter.notifyMatchFailure(bcastReduction.bcast, + "no common reduction source"); } commonSource = bcastReduction.operand; if (commonReductionDim && *commonReductionDim != reductionDim) { - return failure(); + return rewriter.notifyMatchFailure(bcastReduction.reduction, + "no common reduction dim"); } commonReductionDim = reductionDim; simpleBcastReductions.push_back(bcastReduction); - // foundBcastReduction = true; continue; } // Case ii. 
- if (commonSource && commonSource != operand) return failure(); + if (commonSource && commonSource != operand) + return rewriter.notifyMatchFailure(op, "common source != operand"); commonSource = operand; - simpleBcastReductions.push_back(llvm::None); + simpleBcastReductions.push_back(std::nullopt); } - if (!commonReductionDim || !commonSource) return failure(); + if (!commonReductionDim || !commonSource) + return rewriter.notifyMatchFailure(op, "no common dim/src"); // Tile or fuse cwise root op. FailureOr tiledOp = tileOperationFn(op, *commonReductionDim); - if (failed(tiledOp)) return failure(); - setTransformationAttr(rewriter, *tiledOp); + if (failed(tiledOp)) + return rewriter.notifyMatchFailure(op, "call to tileOperationFn failed"); + setLabel(*tiledOp, kTileSoftmaxAppliedLabel); // Fuse through the bcast reduction chains. Value commonTiledSource; - for (int64_t i = 0; i < simpleBcastReductions.size(); i++) { + for (int64_t i = 0; i < static_cast(simpleBcastReductions.size()); + i++) { if (!simpleBcastReductions[i]) continue; // Fuse. @@ -110,7 +120,7 @@ LogicalResult tilePartialSoftmax( } // Also use the common tiled source value for the remaining operands. - for (int64_t i = 0; i < simpleBcastReductions.size(); i++) { + for (size_t i = 0; i < simpleBcastReductions.size(); i++) { if (simpleBcastReductions[i]) continue; (*tiledOp)->setOperand(i, commonTiledSource); } @@ -133,13 +143,15 @@ struct TilePartialSoftmaxPattern LogicalResult matchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const override { - if (hasTransformationAttr(op)) return failure(); + if (hasLabel(op, kTileSoftmaxAppliedLabel)) + return rewriter.notifyMatchFailure(op, "has tranformation attr"); // Only apply to non-fusable occurrences. bool hasFusableOccurrences = llvm::any_of( op->getUsers(), - [](Operation *op) { return llvm::isa(op); }); - if (hasFusableOccurrences) return failure(); + [](Operation *op) { return llvm::isa(op); }); + if (hasFusableOccurrences) + return rewriter.notifyMatchFailure(op, "has fusable occurrences"); return tilePartialSoftmax( op, rewriter, @@ -151,7 +163,8 @@ struct TilePartialSoftmaxPattern [&](OpBuilder &b, Operation *op) -> SmallVector { Location loc = op->getLoc(); SmallVector tileSizeValues; - for (int64_t i = 0; i < tileSizes.size(); i++) { + for (int64_t i = 0; i < static_cast(tileSizes.size()); + i++) { // Skip tiling the reduction dimension. By convention, this is a // tile size of 0. int64_t tileSizeInDim = @@ -165,12 +178,12 @@ struct TilePartialSoftmaxPattern tilingOptions.distributionLabel = distributionLabel; // Tile. 
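// Worked example for the callback above (values illustrative): with
// tileSizes = {8, 16} and a common reduction dimension of 1, the materialized
// tile sizes are [8, 0]; dimension 1 is left untiled, so each softmax row is
// processed as a whole.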
FailureOr tilingResult = - tile(tilingOptions, rewriter, op); + tileUsingGmlSt(tilingOptions, rewriter, op); if (failed(tilingResult)) return failure(); rewriter.replaceOp(op, tilingResult->loop->getResults()); - setTransformationAttr(rewriter, tilingResult->tiledOp); - return tilingResult->tiledOp; + setLabel(tilingResult->tiledOps.front(), kTileSoftmaxAppliedLabel); + return tilingResult->tiledOps.front(); }); } @@ -180,10 +193,11 @@ struct TilePartialSoftmaxPattern std::string distributionLabel; }; -struct FusePartialSoftmaxPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct FusePartialSoftmaxPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(MaterializeOp op, + LogicalResult matchAndRewrite(tensor::ExtractSliceOp op, PatternRewriter &rewriter) const override { Value source = op.getSource(); Operation *def = source.getDefiningOp(); @@ -196,23 +210,25 @@ struct FusePartialSoftmaxPattern : public OpRewritePattern { [&](Operation *cwiseOp, int64_t /*commonReductionDim*/) -> FailureOr { auto iface = llvm::dyn_cast_or_null(cwiseOp); - if (!iface) return failure(); + if (!iface) { + return rewriter.notifyMatchFailure( + cwiseOp, "doesn't implement tiling iface"); + } // By construction, we assume that the tile spans the operand in the // common reduction dimension (`commonReductionDim`). // TODO(frgossen): Assert this assumption when we have moved to // unnested tiles. - // Extract tile offsets and sizes. - auto tile = op.getSet().getDefiningOp(); - if (!tile) return failure(); - // Fuse. - SmallVector offsets = tile.getMixedOffsets(); - SmallVector sizes = tile.getMixedSizes(); + SmallVector offsets = op.getMixedOffsets(); + SmallVector sizes = op.getMixedSizes(); FailureOr result = iface.generateResultTileValue(rewriter, 0, offsets, sizes); - if (failed(result)) return failure(); + if (failed(result)) { + return rewriter.notifyMatchFailure( + cwiseOp, "failed to generate result tile"); + } rewriter.replaceOp(op, *result); return result->getDefiningOp(); @@ -220,15 +236,15 @@ struct FusePartialSoftmaxPattern : public OpRewritePattern { } }; -struct FuseUnaryCwisePattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct FuseUnaryCwisePattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(MaterializeOp op, + LogicalResult matchAndRewrite(tensor::ExtractSliceOp op, PatternRewriter &rewriter) const override { // Match unary cwise ops. Operation *source = op.getSource().getDefiningOp(); - if (!isUnaryCwiseGenericOp(source)) return failure(); - + auto mapOp = dyn_cast_or_null(source); + if (!mapOp || mapOp.getNumDpsInputs() != 1) return failure(); // Fuse. FailureOr fused = createFusedOp(rewriter, op); if (failed(fused)) return failure(); @@ -250,7 +266,7 @@ struct TilingSoftmaxPass void getDependentDialects(DialectRegistry ®istry) const final { registry .insert(); - registerGmlStTilingInterfaceExternalModels(registry); + linalg::registerTilingInterfaceExternalModels(registry); } void runOnOperation() override { @@ -271,7 +287,7 @@ struct TilingSoftmaxPass } // Clean up by removing temporary attributes. 
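// For reference, the label helpers used throughout this file are thin
// wrappers around a UnitAttr (see gml_st/transforms/transforms.cc added
// below):
//   setLabel(op, kTileSoftmaxAppliedLabel);     // op->setAttr(name, UnitAttr)
//   hasLabel(op, kTileSoftmaxAppliedLabel);     // op->hasAttr(name)
//   removeLabel(op, kTileSoftmaxAppliedLabel);  // op->removeAttr(name)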
- f.walk([](Operation *op) { removeTransformationAttr(op); }); + f.walk([](Operation *op) { removeLabel(op, kTileSoftmaxAppliedLabel); }); } }; diff --git a/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/transforms.cc b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/transforms.cc new file mode 100644 index 00000000000..fb20d311ed8 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/transforms.cc @@ -0,0 +1,312 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "gml_st/transforms/transforms.h" + +#include +#include +#include + +#include "gml_st/IR/gml_st_ops.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h" +#include "mlir/Dialect/Tensor/Utils/Utils.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace gml_st { + +bool isZero(Value v) { return matchPattern(v, m_Zero()); } +bool isOne(Value v) { return matchPattern(v, m_One()); } + +bool hasSingleElementOperandsAndResults(Operation *op) { + auto isScalar = [](Type type) { + return !type.isa() || + (type.isa() && + hasSingleElement(type.cast())); + }; + return llvm::all_of(op->getOperandTypes(), isScalar) && + llvm::all_of(op->getResultTypes(), isScalar); +} + +/// Hoisting after vectorization +namespace { + +using mlir::vector::TransferReadOp; +using mlir::vector::TransferWriteOp; + +bool isLoopInvariantTransferWriteOp(ForOp forOp, TransferWriteOp candidate) { + // Indexing must not depend on `forOp`. + for (Value operand : candidate.getIndices()) + if (!forOp.isDefinedOutsideOfLoop(operand)) return false; + return candidate->hasOneUse(); +} + +/// Look for a TransferReadOp, in the given tensor users, accessing the same +/// offset as `write`. +FailureOr findMatchingTransferRead(TransferWriteOp write, + Value srcTensor) { + SmallVector users(srcTensor.getUsers().begin(), + srcTensor.getUsers().end()); + while (!users.empty()) { + Operation *user = users.pop_back_val(); + + auto read = dyn_cast(user); + if (read && read.getIndices() == write.getIndices() && + read.getVectorType() == write.getVectorType()) + return read; + } + return failure(); +} + +/// Check if the chunk of data inserted by `write` is read by any +/// other op than `candidateRead` or `terminator`. +bool tensorChunkAccessedByUnknownOp(TransferWriteOp write, + TransferReadOp candidateRead, Value tensor, + SetYieldOp terminator) { + // Make sure none of the other uses read the part of the tensor modified + // by the transfer_write. 
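// Schematic IR for the pair being matched (not from this patch): inside a
// gml_st.for with output argument %arg,
//   %r = vector.transfer_read  %arg[%c0, %c0], %pad
//   ...
//   %w = vector.transfer_write %val, %arg[%c0, %c0]
// %r is the matching read for %w (same indices, same vector type); the check
// below then verifies that no other user of the tensor can observe the chunk
// that %w overwrites.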
+ llvm::SmallVector uses; + uses.push_back(tensor.getUses()); + while (!uses.empty()) { + for (OpOperand &use : uses.pop_back_val()) { + Operation *user = use.getOwner(); + // Skip the candidate and terminator uses, only inspect the "other" uses. + if (user == candidateRead || user == write || user == terminator) + continue; + // Consider all transitive uses through a extract_slice / insert_slice. + // Consider all transitive uses through a vector.transfer_write. + // Consider all nested uses through a gml_st::ForOp. We may have + // pass-through tensor arguments left from previous level of hoisting. + // TODO(vuson): atm we just bail because a stronger analysis is needed for + // these cases. + if (isa(user)) + return true; + + auto read = dyn_cast(user); + if (!read || !vector::isDisjointTransferIndices( + cast(read.getOperation()), + cast(write.getOperation()))) { + return true; + } + } + } + return false; +} + +ForOp replaceLoopWithNewYields(OpBuilder &builder, ForOp loop, + ValueRange newOutputOperands, + ValueRange newYieldValues, Value yieldSet) { + assert(newOutputOperands.size() == newYieldValues.size() && + "expected as many new yield values as new iter operands"); + // Create a new loop before the existing one, with the extra operands. + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(loop); + auto operands = llvm::to_vector(loop.getOutputs()); + operands.append(newOutputOperands.begin(), newOutputOperands.end()); + auto newLoop = builder.create( + loop.getLoc(), + llvm::to_vector<1>(llvm::map_range( + operands, [&](Value v) -> Type { return v.getType(); })), + loop.getLowerBound(), loop.getUpperBound(), loop.getStep(), operands, + nullptr); + + Block *loopBody = loop.getBody(); + Block *newLoopBody = newLoop.getBody(); + + // Move the body of the original loop to the new loop. + builder.setInsertionPointToStart(newLoopBody); + IRMapping bvm; + for (Operation &bodyMember : loopBody->without_terminator()) { + builder.clone(bodyMember, bvm); + } + + // Generate the new yield values to use by using the callback and append the + // yield values to the set_yield operation. + auto oldYield = loop.getTerminator(); + ArrayRef newBBArgs = + newLoopBody->getArguments().take_back(newOutputOperands.size()); + { + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointToEnd(newLoopBody); + auto getMappedValues = [&](ValueRange values) { + return llvm::to_vector(llvm::map_range( + values, [&](Value value) { return bvm.lookupOrDefault(value); })); + }; + auto srcs = getMappedValues(oldYield.getSrcs()); + srcs.append(getMappedValues(newYieldValues)); + auto dsts = getMappedValues(oldYield.getDsts()); + dsts.append(newBBArgs.begin(), newBBArgs.end()); + auto sets = getMappedValues(oldYield.getSets()); + sets.append(newYieldValues.size(), bvm.lookupOrDefault(yieldSet)); + builder.create(newLoop.getLoc(), srcs, dsts, sets); + } + + // Remap the BlockArguments from the original loop to the new loop + // BlockArguments. + ArrayRef bbArgs = loopBody->getArguments(); + for (auto it : + llvm::zip(bbArgs, newLoopBody->getArguments().take_front(bbArgs.size()))) + std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); + + // Replace all uses of `newOutputOperands` with the corresponding basic block + // arguments. 
+ for (auto it : llvm::zip(newOutputOperands, newBBArgs)) { + std::get<0>(it).replaceUsesWithIf(std::get<1>(it), [&](OpOperand &use) { + Operation *user = use.getOwner(); + return newLoop->isProperAncestor(user); + }); + } + + // Replace all uses of the original loop with corresponding values from the + // new loop. + loop.replaceAllUsesWith( + newLoop.getResults().take_front(loop.getNumResults())); + + return newLoop; +} + +/// Mechanical hoisting of a matching transfeSeread / transfer_write pair. +void hoistReadWrite(TransferReadOp read, TransferWriteOp write, + BlockArgument tensorBBArg, Value yieldSet) { + auto forOp = cast(tensorBBArg.getOwner()->getParentOp()); + + // Hoist the transfer_read op. + forOp.moveOutOfLoop(read); + + // FIXME: don't hardcode /*numIvs=*/1. + assert(tensorBBArg.getArgNumber() >= /*numIvs=*/1); + unsigned initArgNumber = tensorBBArg.getArgNumber() - /*numIvs=*/1; + + // Update the source tensor. + read.getSourceMutable().assign(forOp.getOutputs()[initArgNumber]); + + // Hoist write after. + write->moveAfter(forOp); + + // Update the yield. + auto setYieldOp = forOp.getTerminator(); + setYieldOp->setOperand(initArgNumber, write.getSource()); + + // Rewrite `loop` with additional new yields. + OpBuilder b(read); + auto newForOp = replaceLoopWithNewYields(b, forOp, read.getVector(), + write.getVector(), yieldSet); + + // Transfer write has been hoisted, need to update the vector and tensor + // source. Replace the result of the loop to use the new tensor created + // outside the loop. + // Depending on whether a insert_slice is present or not, it carries the + // update on the tensor operands. + newForOp.getResult(initArgNumber).replaceAllUsesWith(write.getResult()); + write.getSourceMutable().assign(newForOp.getResult(initArgNumber)); + + // Always update with the newly yield tensor and vector. + write.getVectorMutable().assign(newForOp.getResults().back()); +} +} // namespace + +bool isIdentitySlice(ValueRange offsets, ValueRange strides) { + // Offsets must be all 0s and strides must be all 1s. + return llvm::all_of(offsets, [](Value v) { return isZero(v); }) && + llvm::all_of(strides, [](Value v) { return isOne(v); }); +} + +bool haveSameStaticShape(Value lhs, Value rhs) { + auto lhsType = lhs.getType().cast(); + auto rhsType = rhs.getType().cast(); + if (!lhsType.hasStaticShape() || !rhsType.hasStaticShape()) return false; + return lhsType == rhsType; +} + +void hoistRedundantVectorTransfersOnTensor(func::FuncOp func) { + bool changed = true; + while (changed) { + changed = false; + func.walk([&](ForOp forOp) { + auto terminator = forOp.getTerminator(); + for (const auto &[src, set, dst, outputArg] : + llvm::zip(terminator.getSrcs(), terminator.getSets(), + terminator.getDsts(), forOp.getRegionOutputArgs())) { + auto write = src.getDefiningOp(); + if (!write) continue; + if (!isLoopInvariantTransferWriteOp(forOp, write)) continue; + + auto srcTensor = write.getSource(); + if (srcTensor != outputArg) continue; + + auto tileOp = set.getDefiningOp(); + if (!tileOp || + !isIdentitySlice(tileOp.getOffsets(), tileOp.getStrides()) || + !haveSameStaticShape(src, dst)) + continue; + + // Find a read with the same type and indices. + auto matchingRead = findMatchingTransferRead(write, srcTensor); + + // Make sure none of the other uses reads the part of the tensor + // modified by the transfer_write. 
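// Note on the tile check a few lines above (values illustrative): hoisting
// only fires when the set_yield tile covers the whole destination, e.g.
// gml_st.tile [0, 0] [8, 4] [1, 1] for a tensor<8x4xf32> output, i.e.
// all-zero offsets and all-one strides (isIdentitySlice) with matching static
// shapes (haveSameStaticShape). Any nonzero offset, non-unit stride, or shape
// mismatch keeps the transfers inside the loop.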
+ if (failed(matchingRead) || + tensorChunkAccessedByUnknownOp(write, *matchingRead, srcTensor, + terminator)) + continue; + + hoistReadWrite(*matchingRead, write, outputArg, set); + changed = true; + forOp.erase(); + + // Need to interrupt and restart: erasing the loop messes up the walk. + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + // Apply canonicalization so the newForOp + yield folds immediately, thus + // cleaning up the IR and potentially enabling more hoisting. + if (changed) { + auto *ctx = func->getContext(); + RewritePatternSet patterns(ctx); + ForOp::getCanonicalizationPatterns(patterns, ctx); + (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + } + } +} + +void setLabel(Operation *op, StringRef name) { + op->setAttr(name, UnitAttr::get(op->getContext())); +} + +void removeLabel(Operation *op, StringRef name) { op->removeAttr(name); } + +bool hasLabel(Operation *op, StringRef name) { return op->hasAttr(name); } + +constexpr llvm::StringLiteral kOpLabel = "op_label"; + +bool hasMatchingLabel(Operation *op, StringRef label) { + auto opLabelAttr = op->getAttr(kOpLabel); + if (!opLabelAttr) return false; + + return opLabelAttr.cast().getValue() == label; +} + +} // namespace gml_st +} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/transforms.h b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/transforms.h new file mode 100644 index 00000000000..22da1122e88 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/transforms.h @@ -0,0 +1,116 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MLIR_HLO_GML_ST_TRANSFORMS_TRANSFORMS_H +#define MLIR_HLO_GML_ST_TRANSFORMS_TRANSFORMS_H + +#include "gml_st/IR/gml_st_ops.h" +#include "llvm/ADT/Hashing.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassOptions.h" + +namespace mlir { + +class OpPassManager; + +namespace linalg { + +class LinalgOp; +struct TiledLinalgOp; +struct LinalgTilingOptions; + +} // namespace linalg +} // namespace mlir + +namespace mlir { +namespace gml_st { + +constexpr llvm::StringRef kPerfectlyTiledLoopLabel = + "__perfectly_tiled_loop_label__"; + +bool isZero(Value v); +bool isOne(Value v); + +template +bool hasSingleElement(ShapedTy type) { + return type.hasStaticShape() && type.getNumElements() == 1; +} +bool hasSingleElementOperandsAndResults(Operation *op); + +/// Hoist vector.transfer_read/vector.transfer_write pairs out of immediately +/// enclosing gml_st::ForOp iteratively, if the following conditions are true: +/// 1. The two ops access the same tensor with the same indices. +/// 2. All operands are invariant under the enclosing gml_st::ForOp. +/// 3. No uses of the tensor either dominate the transfer_read or are +/// dominated by the transfer_write (i.e. 
no aliasing between the write and +/// the read across the loop) +/// The transformation follows this logic: +/// 1. Look for transfer_write with a single use from ForOp terminator +/// 2. Check the uses of the matching block argument and look for a +/// transfer_read with the same indices. +/// 3. Check that all the other uses of the tensor argument are either +/// disjoint tensor_read or transfer_write. For transfer_write uses recurse to +/// make sure the new tensor has the same restrictions on its uses. +/// 4. Hoist the tensor_read/tensor_write and update the tensor SSA links. +/// +/// Example: +/// %for = gml_st.for ... outs (%arg6 = %out: tensor<8x4xf32>) { +/// %tile = gml_st.tile [0, 0] [8, 4] [1, 1] : !gml_st.tile<8x4> +/// ... +/// %read = vector.transfer_read %arg6[%c0, %c0] +/// %compute = foo(%read) : vector<8x4xf32> +/// %write = vector.transfer_write %compute, %arg6[%c0, %c0] +/// gml_st.set_yield %write into %arg6[%tile] +/// } : tensor<8x4xf32> +/// +/// will be transformed into: +/// +/// %read = vector.transfer_read %out[%c0, %c0] +/// %for = gml_st.for ... outs (%arg6 = %read: vector<8x4xf32>) { +/// %tile = gml_st.tile [0, 0] [8, 4] [1, 1] : !gml_st.tile<8x4> +/// ... +/// %compute = foo(%read) : vector<8x4xf32> +/// gml_st.set_yield %compute into %arg6[%tile] +/// } : vector<8x4xf32> +/// %write = vector.transfer_write %for, %out[%c0, %c0] +/// +/// After this transformation the gml_st.ForOp may have unused arguments that +/// can be remove by the canonicalization pass. +void hoistRedundantVectorTransfersOnTensor(func::FuncOp func); + +/// Returns true if `candidate`'s offsets are all 0s and strides are all 1s. +bool isIdentitySlice(ValueRange offsets, ValueRange strides); + +/// Returns true if `lhs` and `rhs` are of same static shape. +bool haveSameStaticShape(Value lhs, Value rhs); + +// Sets the attribute to the `op` that indicates that the op was transformed. +void setLabel(Operation *op, StringRef name); + +// Removes the attribute that indicates that it was transformed. +void removeLabel(Operation *op, StringRef name); + +// Checks if `op` has the attribute that indicates that it was transformed. +bool hasLabel(Operation *op, StringRef name); + +// Checks if `op` has the matching label attribute. +bool hasMatchingLabel(Operation *op, StringRef label); + +} // namespace gml_st +} // namespace mlir + +#endif // MLIR_HLO_GML_ST_TRANSFORMS_TRANSFORMS_H diff --git a/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/triton_tiling/transform_matmul_for_triton.cc b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/triton_tiling/transform_matmul_for_triton.cc new file mode 100644 index 00000000000..3e2c663d6ac --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/triton_tiling/transform_matmul_for_triton.cc @@ -0,0 +1,192 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include "gml_st/IR/gml_st_ops.h" +#include "gml_st/transforms/fusion/fusion.h" +#include "gml_st/transforms/passes.h" +#include "gml_st/transforms/tiling/tiling.h" +#include "gml_st/transforms/transforms.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir::gml_st { +namespace { + +#define GEN_PASS_DEF_TRANSFORMMATMULFORTRITONPASS +#include "gml_st/transforms/passes.h.inc" + +static constexpr llvm::StringRef kMatmulTransformedLabel = + "__matmul_transformed_label__"; + +FailureOr tileMatmul(PatternRewriter &rewriter, Operation *op, + ArrayRef tileSizes, bool distribute, + StringRef distributionLabel = "") { + TilingOptions opts; + opts.setTileSizeComputationFn(tileSizes); + opts.distribute = distribute; + opts.distributionLabel = distributionLabel; + return tileUsingGmlSt(opts, rewriter, cast(op)); +} + +/// Pattern to tile `linalg.matmul`, fuse `linalg.fill` into generated +/// `gml_st.parallel`, and peel the generated loops. +struct MatmulTransformPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + explicit MatmulTransformPattern(MLIRContext *context, + int64_t lhsParallelDimTileSize = 2, + int64_t rhsParallelDimTileSize = 4, + int64_t reductionDimTileSize = 8, + StringRef distributionLabel = "", + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + lhsParallelDimTileSize(lhsParallelDimTileSize), + rhsParallelDimTileSize(rhsParallelDimTileSize), + reductionDimTileSize(reductionDimTileSize), + distributionLabel(distributionLabel) {} + + LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp, + PatternRewriter &rewriter) const override { + if (hasLabel(matmulOp, kMatmulTransformedLabel)) + return rewriter.notifyMatchFailure(matmulOp, + "has already been transformed."); + + if (isa(matmulOp->getParentOp())) + return rewriter.notifyMatchFailure( + matmulOp, "has already been tiled by another pass."); + + auto cluster = findMapFusionCluster(matmulOp); + auto fusionCluster = cluster.operations; + Operation *tilingRoot = cluster.root; + // Tiling of linalg.map requires two dimensions, linalg.matmul requires + // three. + SmallVector parallelDimsTileSizes{lhsParallelDimTileSize, + rhsParallelDimTileSize}; + if (isa(tilingRoot)) parallelDimsTileSizes.push_back(0); + + auto tilingParallelDimsResult = + tileMatmul(rewriter, tilingRoot, parallelDimsTileSizes, + /*distribute=*/true, distributionLabel); + if (failed(tilingParallelDimsResult)) return failure(); + + // Update the results if tiling occurred. + if (tilingParallelDimsResult->loop != nullptr) { + rewriter.replaceOp(tilingRoot, + tilingParallelDimsResult->loop->getResults()); + tilingRoot = tilingParallelDimsResult->tiledOps.front(); + // Fuse ops into the loop. + fuseGreedily(rewriter, *tilingRoot->getBlock(), + [&](Operation *op) { return fusionCluster.contains(op); }); + (void)fuseFillOpsIntoParallelOp( + rewriter, cast(tilingParallelDimsResult->loop)); + } + + auto inputFusionFilterFn = [&](Operation *op) { + return isa(op); + }; + + // Second level tiling: reduction dimension. 
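// With the default tileSizes = {4, 4, 4} (see runOnOperation below), the
// parallel level above tiles the M and N dimensions with sizes {4, 4} (plus a
// trailing 0 when the tiling root is a linalg.map), while this second level
// uses {0, 0, 4}: only the K dimension of the matmul is tiled, yielding the
// serial reduction loop.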
+ SmallVector reductionDimsTileSizes{0, 0, reductionDimTileSize}; + for (auto op : + llvm::to_vector(tilingRoot->getBlock()->getOps())) { + fuseGreedily(rewriter, *op->getBlock(), inputFusionFilterFn); + + auto tilingReductionDimsResult = tileMatmul( + rewriter, op, reductionDimsTileSizes, /*distribute=*/false); + if (failed(tilingReductionDimsResult)) return failure(); + + // Update the results if tiling occurred. + if (tilingReductionDimsResult->loop != nullptr) { + rewriter.replaceOp(op, tilingReductionDimsResult->loop->getResults()); + op = + cast(tilingReductionDimsResult->tiledOps.front()); + + fuseGreedily(rewriter, *op->getBlock(), inputFusionFilterFn); + } + + setLabel(op, kMatmulTransformedLabel); + } + + return success(); + } + + private: + int64_t lhsParallelDimTileSize; + int64_t rhsParallelDimTileSize; + int64_t reductionDimTileSize; + std::string distributionLabel; +}; + +struct TransformMatmulForTritonPass + : public impl::TransformMatmulForTritonPassBase< + TransformMatmulForTritonPass> { + TransformMatmulForTritonPass() = default; + + explicit TransformMatmulForTritonPass(llvm::ArrayRef matmulTileSizes, + StringRef distributionLabelParam) { + tileSizes = matmulTileSizes; + distributionLabel = distributionLabelParam.str(); + } + + void getDependentDialects(DialectRegistry ®istry) const final { + registry.insert(); + linalg::registerTilingInterfaceExternalModels(registry); + } + + void runOnOperation() override { + func::FuncOp f = getOperation(); + MLIRContext *ctx = &getContext(); + + // Just do tiling and fusion on linalg.matmul. + if (tileSizes.empty()) { + tileSizes = {4, 4, 4}; + } + assert(tileSizes.size() == 3 && + "Tiling sizes for MatMul should have 3 elements"); + RewritePatternSet patterns(ctx); + patterns.add(ctx, tileSizes[0], tileSizes[1], + tileSizes[2], distributionLabel); + if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) { + return signalPassFailure(); + } + // Ensure we drop the marker in the end. + f.walk( + [](linalg::MatmulOp op) { removeLabel(op, kMatmulTransformedLabel); }); + } +}; + +} // namespace + +std::unique_ptr> +createTransformMatmulForTritonPass(llvm::ArrayRef matmulTileSizes, + StringRef distributionLabel) { + return std::make_unique( + matmulTileSizes, distributionLabel); +} + +} // namespace mlir::gml_st diff --git a/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/vectorization/vectorization.cc b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/vectorization/vectorization.cc new file mode 100644 index 00000000000..94d505633d6 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/vectorization/vectorization.cc @@ -0,0 +1,93 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
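// Illustrative usage sketch, not part of this patch: registering the
// transformation defined above in a pipeline. The tile sizes and distribution
// label are placeholder values, and the element type of the stripped
// `ArrayRef` parameter is assumed to be int64_t.
void addMatmulForTritonSketch(mlir::OpPassManager &pm) {
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::gml_st::createTransformMatmulForTritonPass(
          /*matmulTileSizes=*/{32, 32, 16}, /*distributionLabel=*/"block"));
}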
+==============================================================================*/ + +#include "gml_st/transforms/vectorization/vectorization.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +namespace gml_st { +namespace { + +using mlir::tensor::ExpandShapeOp; +using mlir::vector::TransferReadOp; +using mlir::vector::TransferWriteOp; + +// Rewrite `vector.transfer_read(linalg.expand_shape)` as +// `vector.shape_cast(vector.transfer_read)`. +struct TransferReadOfOneDimExpandShape + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + TransferReadOp vectorRead, + mlir::PatternRewriter &rewriter) const override { + auto expand = vectorRead.getSource().getDefiningOp(); + if (!expand) return failure(); + + auto expandSrc = expand.getSrc(); + auto expandSrcType = expand.getSrcType(); + auto expandDstType = expand.getResultType(); + if (expandSrcType.getRank() != 1 || expandDstType.getRank() != 2) + return failure(); + + auto resultType = vectorRead.getType().dyn_cast(); + if (!resultType || resultType.getShape() != expandDstType.getShape()) + return failure(); + + auto zero = rewriter.create(vectorRead.getLoc(), 0); + auto map = mlir::AffineMap::get(1, 0, {rewriter.getAffineDimExpr(0)}, + vectorRead.getContext()); + // TODO(pifon): Also support canonicalization in case the map is not an + // identity. + if (!map.isIdentity()) return failure(); + + auto newRead = rewriter.create( + vectorRead.getLoc(), + mlir::VectorType::get(expandSrcType.getShape(), + expandSrcType.getElementType()), + expandSrc, mlir::ValueRange{zero}, mlir::AffineMapAttr::get(map), + vectorRead.getPadding(), + /*mask=*/mlir::Value(), rewriter.getBoolArrayAttr({true})); + rewriter.replaceOpWithNewOp( + vectorRead, vectorRead.getType(), newRead); + return success(); + } +}; + +} // namespace + +void populateTransferReadOfOneDimExpandShapePattern( + RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + +RewritePatternSet getDefaultVectorizationPatterns(MLIRContext *ctx) { + RewritePatternSet patterns(ctx); + mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); + mlir::vector::populateVectorReductionToContractPatterns(patterns); + patterns.add(ctx, + /*benefit=*/2); + TransferWriteOp::getCanonicalizationPatterns(patterns, ctx); + return patterns; +} + +} // namespace gml_st +} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/vectorization/vectorization.h b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/vectorization/vectorization.h new file mode 100644 index 00000000000..c3e2ca2e04c --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/vectorization/vectorization.h @@ -0,0 +1,61 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
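// The TransferReadOfOneDimExpandShape rewrite above, schematically
// (illustrative types, not from this patch):
//   %e = tensor.expand_shape %src [[0, 1]]
//        : tensor<8xf32> into tensor<1x8xf32>
//   %v = vector.transfer_read %e[%c0, %c0], %pad
//        : tensor<1x8xf32>, vector<1x8xf32>
// becomes
//   %r = vector.transfer_read %src[%c0], %pad : tensor<8xf32>, vector<8xf32>
//   %v = vector.shape_cast %r : vector<8xf32> to vector<1x8xf32>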
+==============================================================================*/ + +#ifndef MLIR_HLO_GML_ST_TRANSFORMS_VECTORIZATION_VECTORIZATION_H +#define MLIR_HLO_GML_ST_TRANSFORMS_VECTORIZATION_VECTORIZATION_H + +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/LogicalResult.h" + +namespace mlir { +namespace gml_st { + +// The upper limit for vectorization of untiled `linalg.fill`. If a tensor has a +// static shape with more elements, then `linalg.fill` won't be vectorized. It +// is expected that such operations are tiled to get to small static shapes. +static constexpr int64_t kNumElementsThreshold = 1024; + +// TODO(manany): This should be parameterized later on depending on hardware. +static constexpr int64_t kNumElementsVectorization = 8; + +template +struct VectorizationPattern : public mlir::OpRewritePattern { + VectorizationPattern(MLIRContext *context, + llvm::function_ref matchFn, + mlir::PatternBenefit benefit = 1) + : mlir::OpRewritePattern(context, benefit), filterFn(matchFn) {} + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + if (!filterFn(op)) + return rewriter.notifyMatchFailure(op, "did not match filter"); + return mlir::linalg::vectorize(rewriter, op); + } + + private: + llvm::function_ref filterFn; +}; + +void populateTransferReadOfOneDimExpandShapePattern( + RewritePatternSet &patterns); + +RewritePatternSet getDefaultVectorizationPatterns(MLIRContext *ctx); + +} // namespace gml_st +} // namespace mlir + +#endif // MLIR_HLO_GML_ST_TRANSFORMS_VECTORIZATION_VECTORIZATION_H diff --git a/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/vectorization/vectorize_copy.cc b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/vectorization/vectorize_copy.cc new file mode 100644 index 00000000000..d37a1c30207 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/vectorization/vectorize_copy.cc @@ -0,0 +1,108 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "gml_st/transforms/passes.h" +#include "gml_st/transforms/vectorization/vectorization.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace gml_st { +namespace { + +#define GEN_PASS_DEF_VECTORIZECOPYPASS +#include "gml_st/transforms/passes.h.inc" + +/// Custom vectorization pattern for small and non-contiguous memref::CopyOp. 
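// For example (shapes illustrative): a copy between
// memref<4x8xf32, strided<[8, 1]>> operands is static and contiguous
// row-major (the strides match the running products computed below), so it is
// left for the later llvm.memcpy lowering; a copy involving a padded
// memref<4x8xf32, strided<[16, 1]>> is non-contiguous and, provided both
// sides are static with 0 < #elements < kNumElementsThreshold, is vectorized
// here instead.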
+struct CopyVectorizationPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::CopyOp op, + PatternRewriter &rewriter) const override { + auto srcType = op.getSource().getType().cast(); + auto targetType = op.getTarget().getType().cast(); + + auto isStaticShapeAndContiguousRowMajor = [](MemRefType type) { + if (!type.hasStaticShape()) return false; + + SmallVector strides; + int64_t offset; + if (failed(getStridesAndOffset(type, strides, offset))) return false; + + int64_t runningStride = 1; + for (unsigned i = strides.size(); i > 0; --i) { + if (strides[i - 1] != runningStride) return false; + runningStride *= type.getDimSize(i - 1); + } + return true; + }; + + auto isContiguousMemrefType = [&](BaseMemRefType type) { + auto memrefType = type.dyn_cast(); + return memrefType && (memrefType.getLayout().isIdentity() || + isStaticShapeAndContiguousRowMajor(memrefType)); + }; + + auto isSmallMemrefType = [&](BaseMemRefType type) { + auto memrefType = type.dyn_cast(); + return memrefType && memrefType.hasStaticShape() && + memrefType.getNumElements() > 0 && + memrefType.getNumElements() < kNumElementsThreshold; + }; + + // If memref has an identity layout or is contiguous with an arbitrary + // offset, it will be turned into llvm.memcpy intrinsic later, do not + // vectorize it. + if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType)) { + return failure(); + } + + // If memref is too big, vectorizing it actually explodes the compilation + // time. Also, ignore empty memrefs, which will be handled by memrefCopy + // function. + if (!isSmallMemrefType(srcType) || !isSmallMemrefType(targetType)) { + return failure(); + } + return linalg::vectorizeCopy(rewriter, op); + } +}; + +struct VectorizeCopyPass + : public impl::VectorizeCopyPassBase { + void runOnOperation() override { + auto func = getOperation(); + auto *ctx = func.getContext(); + + RewritePatternSet patterns(ctx); + patterns.add(ctx); + if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr> createVectorizeCopyPass() { + return std::make_unique(); +} + +} // namespace gml_st +} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/vectorization/vectorize_for_cpu.cc b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/vectorization/vectorize_for_cpu.cc new file mode 100644 index 00000000000..3b9eb9b3686 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/vectorization/vectorize_for_cpu.cc @@ -0,0 +1,264 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include + +#include "gml_st/IR/gml_st_ops.h" +#include "gml_st/transforms/passes.h" +#include "gml_st/transforms/transforms.h" +#include "gml_st/transforms/vectorization/vectorization.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Hoisting.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "thlo/IR/thlo_ops.h" + +namespace mlir { +namespace gml_st { +namespace { + +#define GEN_PASS_DEF_VECTORIZEFORCPUPASS +#include "gml_st/transforms/passes.h.inc" + +using mlir::linalg::BroadcastOp; +using mlir::linalg::FillOp; +using mlir::linalg::GenericOp; +using mlir::linalg::MapOp; +using mlir::linalg::MatmulOp; +using mlir::linalg::Mmt4DOp; +using mlir::linalg::ReduceOp; +using mlir::linalg::TransposeOp; +using mlir::tensor::ExpandShapeOp; +using mlir::thlo::ReverseOp; +using mlir::vector::TransferReadOp; +using mlir::vector::TransferWriteOp; + +// Rewrite `vector.transfer_read(linalg.expand_shape)` as +// `vector.shape_cast(vector.transfer_read)`. +struct TransferReadOfOneDimExpandShape + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + vector::TransferReadOp vectorRead, + mlir::PatternRewriter &rewriter) const override { + auto expand = vectorRead.getSource().getDefiningOp(); + if (!expand) return failure(); + + auto expandSrc = expand.getSrc(); + auto expandSrcType = expand.getSrcType(); + auto expandDstType = expand.getResultType(); + if (expandSrcType.getRank() != 1 || expandDstType.getRank() != 2) + return failure(); + + auto resultType = vectorRead.getType().dyn_cast(); + if (!resultType || resultType.getShape() != expandDstType.getShape()) + return failure(); + + auto zero = rewriter.create(vectorRead.getLoc(), 0); + auto map = mlir::AffineMap::get(1, 0, {rewriter.getAffineDimExpr(0)}, + vectorRead.getContext()); + // TODO(pifon): Also support canonicalization in case the map is not an + // identity. + if (!map.isIdentity()) return failure(); + + auto newRead = rewriter.create( + vectorRead.getLoc(), + mlir::VectorType::get(expandSrcType.getShape(), + expandSrcType.getElementType()), + expandSrc, mlir::ValueRange{zero}, mlir::AffineMapAttr::get(map), + vectorRead.getPadding(), + /*mask=*/mlir::Value(), rewriter.getBoolArrayAttr({true})); + rewriter.replaceOpWithNewOp( + vectorRead, vectorRead.getType(), newRead); + return success(); + } +}; + +// This currently matches for all thlo.reverse of the form 1x1x..x1xVectorSize. +// DimSize < kNumElementsVectorization will be handled by Scalarization. 
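// Example of a matching op (shape illustrative): with
// kNumElementsVectorization == 8, a thlo.reverse over tensor<1x1x8xf32> that
// reverses dimension 2 is accepted; the pattern below turns it into a
// vector.transfer_read of vector<8xf32>, a vector.shuffle with mask
// [7, 6, 5, 4, 3, 2, 1, 0], and a vector.transfer_write into the init tensor.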
+bool isPerfectlyTiledReverse(thlo::ReverseOp reverseOp) { + auto inputType = reverseOp.getInput().getType(); + for (unsigned i = 0; i < inputType.getRank(); ++i) { + if (inputType.isDynamicDim(i)) { + return false; + } + if (i == inputType.getRank() - 1) { + return inputType.getDimSize(i) == kNumElementsVectorization && + llvm::is_contained(reverseOp.getReverseDimensions(), i); + } + if (inputType.getDimSize(i) != 1) { + return false; + } + } + return false; +} + +// Rewrite thlo.reverse of pattern 1x1x..x1xVectorSize as vector.transfer_read +// followed by vector.shuffle followed by vector.transfer_write. +struct ThloReverseVectorizationPattern + : public mlir::OpRewritePattern { + explicit ThloReverseVectorizationPattern(MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : mlir::OpRewritePattern(context, benefit) {} + + LogicalResult matchAndRewrite(thlo::ReverseOp op, + PatternRewriter &rewriter) const override { + if (!isPerfectlyTiledReverse(op)) + return rewriter.notifyMatchFailure(op, "did not match filter"); + + auto inputType = op.getInput().getType(); + auto vecTargetType = + RankedTensorType::get(inputType.getShape()[inputType.getRank() - 1], + inputType.getElementType()); + Value zero = rewriter.create(op.getLoc(), 0); + SmallVector indices(op.getInit().getType().getRank(), zero); + + auto readInput = rewriter.create( + op.getLoc(), + VectorType::get(vecTargetType.getShape(), + vecTargetType.getElementType()), + op.getInput(), indices); + + SmallVector mask; + int64_t maskSize = inputType.getShape()[inputType.getRank() - 1]; + mask.reserve(maskSize); + for (int64_t i = maskSize - 1; i >= 0; --i) { + mask.push_back(i); + } + auto shuffle = rewriter.create(op.getLoc(), readInput, + readInput, mask); + + rewriter.replaceOpWithNewOp( + op, shuffle.getResult(), op.getInit(), indices); + return success(); + } +}; + +struct IdentityTransposeOpFoldingPattern + : public OpRewritePattern { + explicit IdentityTransposeOpFoldingPattern(MLIRContext *context, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit) {} + + LogicalResult matchAndRewrite(TransposeOp op, + PatternRewriter & /*rewriter*/) const override { + auto perm = op.getPermutation(); + for (int64_t i = 0; static_cast(i) < perm.size(); ++i) { + if (perm[i] != i) return failure(); + } + + if (!hasSingleElementOperandsAndResults(op)) return failure(); + + op.replaceAllUsesWith(SmallVector(1, op.getInput())); + return success(); + } +}; + +bool isInsideGmlStLoop(Operation *op) { + Operation *parent = op->getParentOp(); + return isa(parent) || isa(parent); +} + +bool isFillTiledOrSmall(linalg::FillOp fill) { + if (isInsideGmlStLoop(fill)) return true; + + // Allow vectorization for static shapes with low number of elements. 
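// E.g. a linalg.fill of tensor<16x32xf32> (512 elements, below
// kNumElementsThreshold == 1024) is vectorized even outside a gml_st loop,
// while a fill of tensor<64x64xf32> (4096 elements) is only vectorized once
// it has been tiled into a gml_st.parallel or gml_st.for.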
+ auto outputType = fill.output().getType().dyn_cast(); + return outputType && outputType.hasStaticShape() && + outputType.getNumElements() < kNumElementsThreshold; +} + +struct VectorizeForCPUPass + : public impl::VectorizeForCPUPassBase { + void runOnOperation() override { + auto func = getOperation(); + auto *ctx = func.getContext(); + + auto hasSmallStaticOutputs = [&](Operation *op) { + return llvm::all_of(op->getResultTypes(), [](Type type) { + auto outputType = type.dyn_cast(); + return outputType && outputType.hasStaticShape() && + outputType.getNumElements() < kNumElementsThreshold; + }); + }; + auto isPerfectlyTiledLoop = [&](Operation *op) { + return (isa(op)) && + hasLabel(op, kPerfectlyTiledLoopLabel); + }; + auto isInsidePerfectlyTiledLoop = [&](Operation *op) { + return isPerfectlyTiledLoop(op->getParentOp()); + }; + auto isInsidePerfectlyTiledLoopOrSmall = [&](Operation *op) { + return !hasSingleElementOperandsAndResults(op) && + (isInsidePerfectlyTiledLoop(op) || hasSmallStaticOutputs(op)); + }; + { + RewritePatternSet patterns = getDefaultVectorizationPatterns(ctx); + TransferReadOp::getCanonicalizationPatterns(patterns, ctx); + // clang-format off + patterns.add< + VectorizationPattern, + VectorizationPattern, + VectorizationPattern, + VectorizationPattern, + VectorizationPattern, + VectorizationPattern, + VectorizationPattern + >(ctx, isInsidePerfectlyTiledLoopOrSmall); + // clang-format on + patterns.add>(ctx, isFillTiledOrSmall); + populateTransferReadOfOneDimExpandShapePattern(patterns); + patterns.add(ctx); + (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + } + + { + RewritePatternSet patterns = getDefaultVectorizationPatterns(ctx); + TransferReadOp::getCanonicalizationPatterns(patterns, ctx); + linalg::populatePadOpVectorizationPatterns(patterns); + patterns.add(ctx); + (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + } + + // Hoisting transfer_read/transfer_write. + { + RewritePatternSet patterns(ctx); + (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + + hoistRedundantVectorTransfersOnTensor(func); + } + // Hoisting transfer_read/transfer_write. + linalg::hoistRedundantVectorTransfersOnTensor(func); + } +}; + +} // namespace + +std::unique_ptr> createVectorizeForCPUPass() { + return std::make_unique(); +} + +} // namespace gml_st +} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/vectorization.cc b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/vectorization/vectorize_for_gpu.cc similarity index 62% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/vectorization.cc rename to tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/vectorization/vectorize_for_gpu.cc index bad92c2c8f6..11ba182e597 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/vectorization.cc +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/transforms/vectorization/vectorize_for_gpu.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,158 +15,41 @@ limitations under the License. 
#include #include +#include #include -#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h" -#include "mlir-hlo/Dialect/gml_st/transforms/passes.h" -#include "mlir-hlo/Dialect/gml_st/transforms/vector_utils.h" +#include "gml_st/IR/gml_st_ops.h" +#include "gml_st/transforms/passes.h" +#include "gml_st/transforms/transforms.h" +#include "gml_st/transforms/vectorization/vectorization.h" +#include "gml_st/utils/vector_utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/OpDefinition.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { namespace gml_st { namespace { -#define GEN_PASS_DEF_VECTORIZEGMLSTLOOPSPASS -#include "mlir-hlo/Dialect/gml_st/transforms/passes.h.inc" +#define GEN_PASS_DEF_VECTORIZEFORGPUPASS +#include "gml_st/transforms/passes.h.inc" +using mlir::linalg::BroadcastOp; using mlir::linalg::FillOp; using mlir::linalg::GenericOp; +using mlir::linalg::LinalgOp; +using mlir::linalg::MapOp; using mlir::linalg::MatmulOp; -using mlir::tensor::ExpandShapeOp; +using mlir::linalg::ReduceOp; using mlir::vector::TransferReadOp; using mlir::vector::TransferWriteOp; -// The upper limit for vectorization of untiled `linalg.fill`. If a tensor has a -// static shape with more elements, then `linalg.fill` won't be vectorized. It -// is expected that such operations are tiled to get to small static shapes. -constexpr int64_t kNumElementsThreshold = 1024; - -// Rewrite `vector.transfer_read(linalg.expand_shape)` as -// `vector.shape_cast(vector.transfer_read)`. -struct TransferReadOfOneDimExpandShape - : public mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - TransferReadOp vectorRead, - mlir::PatternRewriter &rewriter) const override { - auto expand = vectorRead.getSource().getDefiningOp(); - if (!expand) return failure(); - - auto expandSrc = expand.getSrc(); - auto expandSrcType = expand.getSrcType(); - auto expandDstType = expand.getResultType(); - if (expandSrcType.getRank() != 1 || expandDstType.getRank() != 2) - return failure(); - - auto resultType = vectorRead.getType().dyn_cast(); - if (!resultType || resultType.getShape() != expandDstType.getShape()) - return failure(); - - auto zero = rewriter.create(vectorRead.getLoc(), 0); - auto map = mlir::AffineMap::get(1, 0, {rewriter.getAffineDimExpr(0)}, - vectorRead.getContext()); - // TODO(pifon): Also support canonicalization in case the map is not an - // identity. - if (!map.isIdentity()) return failure(); - - auto newRead = rewriter.create( - vectorRead.getLoc(), - mlir::VectorType::get(expandSrcType.getShape(), - expandSrcType.getElementType()), - expandSrc, mlir::ValueRange{zero}, mlir::AffineMapAttr::get(map), - vectorRead.getPadding(), - /*mask=*/mlir::Value(), rewriter.getBoolArrayAttr({true})); - rewriter.replaceOpWithNewOp( - vectorRead, vectorRead.getType(), newRead); - return success(); - } -}; - -// Rewrite materialize of scalar from 1-element vector into a vector.extract / -// vector.extractelement. 
-struct MaterializeFromSingleElementToExtractPattern - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(MaterializeOp op, - PatternRewriter &rewriter) const override { - Value source = op.getSource(); - auto sourceType = source.getType().dyn_cast(); - if (!sourceType || sourceType.getNumDynamicDims() > 0 || - sourceType.getNumElements() > 1) { - return rewriter.notifyMatchFailure( - op, "source should be a single element vector"); - } - if (op.getResult().getType().isa()) - return rewriter.notifyMatchFailure(op, "result should be a scalar"); - - int64_t rank = sourceType.getRank(); - if (rank == 0) { - // vector.extract doesn't support 0D tensors at the moment, - // use vector.extractelement. - rewriter.replaceOpWithNewOp(op, source); - return success(); - } - rewriter.replaceOpWithNewOp( - op, source, SmallVector(rank, 0)); - return success(); - } -}; - -// Prepend a set_yield of scalar into 1-element vector with a vector.insert. -struct SetYieldOfScalarToVectorPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(SetYieldOp op, - PatternRewriter &rewriter) const override { - auto tryRewrite = [&](Value dst, Value set, OpOperand &src) { - if (!dst.getType().isa()) return failure(); - if (src.get().getType().isa()) return failure(); - auto tileOp = set.getDefiningOp(); - if (!tileOp || !tileOp.getOffsets().empty()) return failure(); - - src.set(rewriter.create(op.getLoc(), src.get(), dst, - tileOp.getStaticOffsets())); - return success(); - }; - - if (llvm::none_of( - llvm::zip_first(op.getDsts(), op.getSets(), op->getOpOperands()), - [&](auto &&tuple) { - return succeeded(std::apply(tryRewrite, tuple)); - })) { - return rewriter.notifyMatchFailure( - op, "expected scalar srcs and static offsets"); - } - - return success(); - } -}; - -template -struct VectorizationPattern : public mlir::OpRewritePattern { - VectorizationPattern(MLIRContext *context, - llvm::function_ref matchFn, - mlir::PatternBenefit benefit = 1) - : mlir::OpRewritePattern(context, benefit), filterFn(matchFn) {} - - LogicalResult matchAndRewrite(OpTy op, - PatternRewriter &rewriter) const override { - if (!filterFn(op)) - return rewriter.notifyMatchFailure(op, "did not match filter"); - return mlir::linalg::vectorize(rewriter, op); - } - - private: - llvm::function_ref filterFn; -}; - // Generates an offset of all 0s suitable as the index paramter for the builder // of vector.transfer_read or vector.transfer_write with input or output // `value`, respectively. @@ -181,7 +64,7 @@ SmallVector generateDefaultOffsetFor(Value value, // Converts the ranked-tensor-typed `bvm`-mapped operands of `op` into vectors // via vector.transfer_read. Updates `bvm`'s mapping of `op`'s operands to the // newly created vector values. -void convertTensorOperandsToVector(Operation *op, BlockAndValueMapping &bvm, +void convertTensorOperandsToVector(Operation *op, IRMapping &bvm, OpBuilder &builder) { OpBuilder::InsertionGuard guard(builder); for (Value operand : op->getOperands()) { @@ -203,8 +86,7 @@ void convertTensorOperandsToVector(Operation *op, BlockAndValueMapping &bvm, // `op`'s results to the newly generated tensors. Expects that the operation's // results are vectors, and the destinations tensors. 
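// Note: IRMapping is the upstream rename of BlockAndValueMapping; the API
// used in this file is otherwise unchanged, e.g.
//   IRMapping bvm;
//   bvm.map(oldValue, newValue);
//   builder.clone(someOp, bvm);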
void convertVectorResultsToTensor(ValueRange results, ValueRange destinations, - BlockAndValueMapping &bvm, - OpBuilder &builder) { + IRMapping &bvm, OpBuilder &builder) { for (auto [result, dest] : llvm::zip(results, destinations)) { Value mappedResult = bvm.lookupOrDefault(result); // Skip over scalars and leave them as is. @@ -213,140 +95,13 @@ void convertVectorResultsToTensor(ValueRange results, ValueRange destinations, "op's result should be a vector"); assert(dest.getType().isa() && "destination should be a tensor"); - auto writeOp = builder.create( + auto writeOp = builder.create( mappedResult.getLoc(), mappedResult, dest, generateDefaultOffsetFor(dest, builder)); bvm.map(result, writeOp.getResult()); } } -// Rewrite tensor.extract on single-element tensors into a vector.extract. -struct TensorToElementVectorizationPattern - : public mlir::OpRewritePattern { - TensorToElementVectorizationPattern( - MLIRContext *context, llvm::function_ref matchFn, - mlir::PatternBenefit benefit = 1) - : mlir::OpRewritePattern(context, benefit), - filterFn(matchFn) {} - - LogicalResult matchAndRewrite(tensor::ExtractOp op, - PatternRewriter &rewriter) const override { - if (!filterFn(op)) - return rewriter.notifyMatchFailure(op, "did not match filter"); - TensorType tensorType = op.getTensor().getType(); - if (tensorType.getNumDynamicDims() > 0 || tensorType.getNumElements() > 1) - return rewriter.notifyMatchFailure(op, "should have a single element"); - - BlockAndValueMapping bvm; - convertTensorOperandsToVector(op, bvm, rewriter); - if (tensorType.getRank() == 0) { - // ExtractOp only supports ranks > 0, for rank = 0 use ExtractElementOp - rewriter.replaceOpWithNewOp( - op, bvm.lookupOrDefault(op.getTensor())); - } else { - rewriter.replaceOpWithNewOp( - op, bvm.lookupOrDefault(op.getTensor()), - SmallVector(tensorType.getRank(), 0)); - } - return success(); - } - - private: - llvm::function_ref filterFn; -}; - -// Rewrite vector.transfer_read(tensor.empty) into a constant vector of the -// right size. This is our temporary way of expressing the nonexistent -// vector.undef, which creates a vector to be used in destination-passing-style -// ops. -// TODO(b/255779480): Figure out how to properly solve this issue. 
-struct TensorEmptyToVectorBroadcastPattern - : public OpRewritePattern { - TensorEmptyToVectorBroadcastPattern( - MLIRContext *context, llvm::function_ref filterFn, - PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), filterFn(filterFn) {} - - LogicalResult matchAndRewrite(TransferReadOp op, - PatternRewriter &rewriter) const override { - if (failed(matchSimpleTransferOp(op, rewriter))) return failure(); - auto tensorEmpty = op.getSource().getDefiningOp(); - if (!tensorEmpty) - return rewriter.notifyMatchFailure(op, "source should be tensor.empty"); - VectorType vectorType = op.getResult().getType().dyn_cast(); - if (!vectorType) - return rewriter.notifyMatchFailure(op, "result should be a vector"); - Type elementType = vectorType.getElementType(); - TypedAttr nanAttr; - if (elementType.isa()) { - nanAttr = rewriter.getIntegerAttr(elementType, 0l); - } else if (elementType.isa()) { - nanAttr = rewriter.getFloatAttr(elementType, - std::numeric_limits::quiet_NaN()); - } else { - return rewriter.notifyMatchFailure( - op, "should operate on integer or floating point vectors"); - } - - rewriter.replaceOpWithNewOp( - op, DenseElementsAttr::get(vectorType, nanAttr)); - return success(); - } - - private: - llvm::function_ref filterFn; -}; - -struct MaterializeOpVectorizationPattern - : public OpRewritePattern { - MaterializeOpVectorizationPattern( - MLIRContext *context, llvm::function_ref filterFn, - PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), filterFn(filterFn) {} - - LogicalResult matchAndRewrite(MaterializeOp op, - PatternRewriter &rewriter) const override { - if (!filterFn(op)) - return rewriter.notifyMatchFailure(op, "did not match filter"); - TypedValue source = op.getSource(); - ShapedType sourceType = source.getType(); - // TODO(b/244314345): Support imperfect tiling, which results in dynamic - // shapes. - if (!sourceType.isa() || - sourceType.getNumDynamicDims() > 0 || - !op.getSet().getType().cast().hasStaticShape()) - return rewriter.notifyMatchFailure(op, "input is not statically shaped"); - - Location loc = op.getLoc(); - BlockAndValueMapping bvm; - convertTensorOperandsToVector(op, bvm, rewriter); - Type newResult = op.getResult().getType(); - if (auto tensorResult = newResult.dyn_cast()) { - newResult = VectorType::get(tensorResult.getShape(), - tensorResult.getElementType()); - } - Value vectorMaterialize = rewriter.create( - loc, newResult, bvm.lookupOrDefault(source), op.getSet()); - bvm.map(op, vectorMaterialize); - if (auto vectorType = newResult.dyn_cast()) { - // The result is not a scalar, generate a TransferWrite back to tensor. - // transfer_write uses destination passing style, so we need to "invent" a - // destination tensor. The entinre tensor_write op, together with the - // invented tensor will be folded when vectorizing the final - // gml_st.set_yield op. - auto emptyTensor = rewriter.create( - loc, vectorType.getShape(), vectorType.getElementType()); - convertVectorResultsToTensor(op->getResults(), {emptyTensor}, bvm, - rewriter); - } - rewriter.replaceOp(op, bvm.lookupOrDefault(op)); - return success(); - } - - private: - llvm::function_ref filterFn; -}; - // Converts static tensors among `types` to their equivalent vectors. 
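+// (Illustrative: a static tensor<4x8xf32> entry would become vector<4x8xf32>.)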
SmallVector convertToVectorTypes(TypeRange types) { return llvm::to_vector<1>(llvm::map_range(types, [&](Type type) -> Type { @@ -361,8 +116,7 @@ SmallVector convertToVectorTypes(TypeRange types) { // Copies the body of a loop `op` that is being vectorized, vectorizing the // terminator, and stores the mapping to new values into `bvm`. void copyLoopBodyAndVectorizeTerminator(LoopLikeOpInterface op, - OpBuilder &builder, - BlockAndValueMapping &bvm) { + OpBuilder &builder, IRMapping &bvm) { auto &blocks = op.getLoopBody().getBlocks(); assert(blocks.size() == 1 && "loop body should contain a single block"); Block &block = blocks.front(); @@ -375,24 +129,30 @@ void copyLoopBodyAndVectorizeTerminator(LoopLikeOpInterface op, // Vectorizes a gml_st.parallel `op`, and stores the mapping from old to new // values into `bvm`. -ParallelOp vectorizeLoopLikeOp(ParallelOp op, BlockAndValueMapping &bvm, +ParallelOp vectorizeLoopLikeOp(ParallelOp op, IRMapping &bvm, PatternRewriter &rewriter) { - Optional distTypeAttr; + convertTensorOperandsToVector(op, bvm, rewriter); + auto outputs = llvm::to_vector(llvm::map_range( + op.getOutputs(), [&](Value v) { return bvm.lookupOrDefault(v); })); + + std::optional distTypeAttr; if (auto distType = op.getDistributionType()) distTypeAttr = rewriter.getStringAttr(*distType); return rewriter.create( op.getLoc(), convertToVectorTypes(op->getResultTypes()), - op.getLowerBound(), op.getUpperBound(), op.getStep(), distTypeAttr, - [&](OpBuilder &builder, Location, ValueRange inductionVars) { + op.getLowerBound(), op.getUpperBound(), op.getStep(), outputs, + distTypeAttr, + [&](OpBuilder &builder, Location, ValueRange inductionVars, + ValueRange outputs) { bvm.map(op.getInductionVars(), inductionVars); + bvm.map(op.getRegionOutputArgs(), outputs); copyLoopBodyAndVectorizeTerminator(op, builder, bvm); }); } // Vectorizes a gml_st.for `op`, and stores the mapping from old to new // values into `bvm`. -ForOp vectorizeLoopLikeOp(ForOp op, BlockAndValueMapping &bvm, - PatternRewriter &rewriter) { +ForOp vectorizeLoopLikeOp(ForOp op, IRMapping &bvm, PatternRewriter &rewriter) { convertTensorOperandsToVector(op, bvm, rewriter); auto outputs = llvm::to_vector(llvm::map_range( op.getOutputs(), [&](Value v) { return bvm.lookupOrDefault(v); })); @@ -430,15 +190,17 @@ struct LoopLikeOpVectorizationPattern : public OpRewritePattern { auto dstTensor = dstType.template dyn_cast(); // TODO(b/244314345): Support imperfect tiling, which results in dynamic // shapes. 
- if (!dstTensor || dstTensor.getNumDynamicDims() > 0) + if (!dstTensor || dstTensor.getNumDynamicDims() > 0) { return rewriter.notifyMatchFailure( op, "destination tensors should be statically shaped"); + } hasTensor = true; if (!srcType.template isa()) continue; auto srcTensor = srcType.template dyn_cast(); - if (!srcTensor || srcTensor.getNumDynamicDims() > 0) + if (!srcTensor || srcTensor.getNumDynamicDims() > 0) { return rewriter.notifyMatchFailure( op, "source tensors should be statically shaped"); + } } if (!hasTensor) { return rewriter.notifyMatchFailure( @@ -451,8 +213,7 @@ struct LoopLikeOpVectorizationPattern : public OpRewritePattern { op, "shoud not use set_yield accumulators"); } - Location loc = op.getLoc(); - BlockAndValueMapping bvm; + IRMapping bvm; auto vectorLoopLikeOp = vectorizeLoopLikeOp(op, bvm, rewriter); bvm.map(op.getResults(), vectorLoopLikeOp.getResults()); @@ -463,49 +224,373 @@ struct LoopLikeOpVectorizationPattern : public OpRewritePattern { op.getResults(), [&](Value v) { return bvm.lookupOrDefault(v); })); rewriter.replaceOp(op, mappedResults); + return success(); } private: llvm::function_ref filterFn; }; +// Prepend a set_yield of scalar into 1-element vector with a vector.insert. +struct SetYieldOfScalarToVectorPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; -RewritePatternSet getDefaultVectorizationPatterns(MLIRContext *ctx) { - RewritePatternSet patterns(ctx); - mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); - mlir::vector::populateVectorReductionToContractPatterns(patterns); - patterns.add(ctx, - /*benefit=*/2); - TransferReadOp::getCanonicalizationPatterns(patterns, ctx); - TransferWriteOp::getCanonicalizationPatterns(patterns, ctx); - return patterns; -} + LogicalResult matchAndRewrite(SetYieldOp op, + PatternRewriter &rewriter) const override { + auto tryRewrite = [&](Value dst, Value set, OpOperand &src) { + if (!dst.getType().isa()) return failure(); + if (src.get().getType().isa()) return failure(); + auto tileOp = set.getDefiningOp(); + if (!tileOp || !tileOp.getOffsets().empty()) return failure(); + + src.set(rewriter.create(op.getLoc(), src.get(), dst, + tileOp.getStaticOffsets())); + return success(); + }; + + if (llvm::none_of( + llvm::zip_first(op.getDsts(), op.getSets(), op->getOpOperands()), + [&](auto &&tuple) { + return succeeded(std::apply(tryRewrite, tuple)); + })) { + return rewriter.notifyMatchFailure( + op, "expected scalar srcs and static offsets"); + } + + return success(); + } +}; + +// Rewrite materialize of scalar from 1-element vector into a vector.extract / +// vector.extractelement. +struct MaterializeFromSingleElementToExtractPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(MaterializeOp op, + PatternRewriter &rewriter) const override { + Value source = op.getSource(); + auto sourceType = source.getType().dyn_cast(); + if (!sourceType || sourceType.getNumDynamicDims() > 0 || + sourceType.getNumElements() > 1) { + return rewriter.notifyMatchFailure( + op, "source should be a single element vector"); + } + if (op.getResult().getType().isa()) + return rewriter.notifyMatchFailure(op, "result should be a scalar"); + + int64_t rank = sourceType.getRank(); + if (rank == 0) { + // vector.extract doesn't support 0D tensors at the moment, + // use vector.extractelement. 
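+      // (Illustrative: for a 0-D vector the extraction takes no position,
+      // e.g. %s = vector.extractelement %v[] : vector<f32>.)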
+ rewriter.replaceOpWithNewOp(op, source); + return success(); + } + rewriter.replaceOpWithNewOp( + op, source, SmallVector(rank, 0)); + return success(); + } +}; + +/// Update tensor operand of vector.transfer_write that uses MaterializeOp. +struct MaterializeUpdateTransferWriteTensorOperand + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferWriteOp op, + PatternRewriter &rewriter) const override { + if (!op->getParentOfType()) return failure(); + + // Sanity checks of TransferWriteOp. + if (op.hasOutOfBoundsDim()) return failure(); + if (op.getVectorType().getRank() != op.getShapedType().getRank()) + return failure(); + if (op.getMask()) return failure(); + // Fold only if the TransferWriteOp completely overwrites the `source` + // with a vector, i.e. the result of the TransferWriteOp is a new tensor + // whose content is the data of the vector. + if (!llvm::equal(op.getVectorType().getShape(), + op.getShapedType().getShape())) + return failure(); + if (!op.getPermutationMap().isIdentity()) return failure(); + + auto src = op.getSource().getDefiningOp(); + if (!src) return failure(); + + SmallVector indices = getValueOrCreateConstantIndexOp( + rewriter, op.getLoc(), src.getMixedOffsets()); + SmallVector inBounds(op.getTransferRank(), true); + rewriter.setInsertionPointAfter(op); + auto newOp = rewriter.create( + op.getLoc(), op.getVector(), src.getSource(), indices, + ArrayRef{inBounds}); + rewriter.replaceOpWithNewOp( + op, op.getResult().getType().cast(), + newOp.getResult(), src.getOffsets(), src.getSizes(), src.getStrides(), + src.getStaticOffsets(), src.getStaticSizes(), src.getStaticStrides()); + + return success(); + } +}; + +/// Update tensor operand of vector.transfer_write used by SetYieldOp. +struct SetYieldUpdateTransferWriteTensorOperand + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SetYieldOp op, + PatternRewriter &rewriter) const override { + if (!op->getParentOfType()) return failure(); + + bool changed = false; + for (const auto &[src, dst, set] : + llvm::zip(op.getSrcs(), op.getDsts(), op.getSets())) { + auto xferOp = src.getDefiningOp(); + + // Sanity checks of TransferWriteOp. + if (!xferOp) continue; + if (xferOp.getSource() == dst) continue; + if (xferOp.hasOutOfBoundsDim()) continue; + if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank()) + continue; + if (xferOp.getMask()) continue; + // Fold only if the TransferWriteOp completely overwrites the `source` + // with a vector, i.e. the result of the TransferWriteOp is a new tensor + // whose content is the data of the vector. 
+      if (!llvm::equal(xferOp.getVectorType().getShape(),
+                       xferOp.getShapedType().getShape()))
+        continue;
+      if (!xferOp.getPermutationMap().isIdentity()) continue;
+
+      auto tileOp = set.getDefiningOp();
+
+      if (!tileOp) continue;
+
+      SmallVector indices = getValueOrCreateConstantIndexOp(
+          rewriter, op.getLoc(), tileOp.getMixedOffsets());
+      SmallVector inBounds(xferOp.getTransferRank(), true);
+      auto newOp = rewriter.create(
+          xferOp.getLoc(), xferOp.getVector(), dst, indices,
+          ArrayRef{inBounds});
+      rewriter.replaceOpWithNewOp(
+          xferOp, xferOp.getResult().getType().cast(),
+          newOp.getResult(), tileOp.getOffsets(), tileOp.getSizes(),
+          tileOp.getStrides(), tileOp.getStaticOffsets(),
+          tileOp.getStaticSizes(), tileOp.getStaticStrides());
+      changed = true;
+    }
+    return success(changed);
+  }
+};
+
+struct MaterializeOpVectorizationPattern
+    : public OpRewritePattern {
+  MaterializeOpVectorizationPattern(
+      MLIRContext *context,
+      llvm::function_ref filterFn,
+      PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit), filterFn(filterFn) {}
+
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!filterFn(op))
+      return rewriter.notifyMatchFailure(op, "did not match filter");
+    TypedValue source = op.getSource();
+    ShapedType sourceType = source.getType();
+    // TODO(b/244314345): Support imperfect tiling, which results in dynamic
+    // shapes.
+    if (!sourceType.isa() ||
+        sourceType.getNumDynamicDims() > 0 ||
+        ShapedType::isDynamicShape(op.getStaticSizes()))
+      return rewriter.notifyMatchFailure(op, "input is not statically shaped");
+
+    Location loc = op.getLoc();
+    IRMapping bvm;
+    convertTensorOperandsToVector(op, bvm, rewriter);
+    Type newResult = op.getResult().getType();
+    if (auto tensorResult = newResult.dyn_cast()) {
+      newResult = VectorType::get(tensorResult.getShape(),
+                                  tensorResult.getElementType());
+    }
+    Value vectorMaterialize = rewriter.create(
+        loc, newResult, bvm.lookupOrDefault(source), op.getMixedOffsets(),
+        op.getMixedSizes(), op.getMixedStrides());
+    bvm.map(op, vectorMaterialize);
+    if (auto vectorType = newResult.dyn_cast()) {
+      // The result is not a scalar; generate a TransferWrite back to a tensor.
+      // transfer_write uses destination-passing style, so we need to "invent" a
+      // destination tensor. The entire transfer_write op, together with the
+      // invented tensor, will be folded when vectorizing the final
+      // gml_st.set_yield op.
+      auto emptyTensor = rewriter.create(
+          loc, vectorType.getShape(), vectorType.getElementType());
+      convertVectorResultsToTensor(op->getResults(), {emptyTensor}, bvm,
+                                   rewriter);
+    }
+    rewriter.replaceOp(op, bvm.lookupOrDefault(op));
+    return success();
+  }
+
+ private:
+  llvm::function_ref filterFn;
+};
+
+// TODO(pifon): Remove patterns that use gml_st.materialize, once GmlSt loops
+// are removed/upstreamed.
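+// Folds a vector.extract at the all-zeros position of a materialized slice
+// into a direct materialization of that element from the original source.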
+struct FoldVectorExtractOfMaterialize + : public OpRewritePattern { + explicit FoldVectorExtractOfMaterialize(MLIRContext *context, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit) {} + + LogicalResult matchAndRewrite(vector::ExtractOp op, + PatternRewriter &rewriter) const override { + auto materializeOp = op.getVector().getDefiningOp(); + if (!materializeOp) return failure(); + + if (llvm::any_of(op.getPosition().getAsRange(), + [](IntegerAttr pos) { return pos.getInt() != 0; })) + return failure(); + + rewriter.replaceOpWithNewOp( + op, op.getType(), materializeOp.getSource(), + materializeOp.getMixedOffsets(), materializeOp.getMixedSizes(), + materializeOp.getMixedStrides()); + return success(); + } +}; + +// Rewrite vector.transfer_read(tensor.empty) into a constant vector of the +// right size. This is our temporary way of expressing the nonexistent +// vector.undef, which creates a vector to be used in destination-passing-style +// ops. +// TODO(b/255779480): Figure out how to properly solve this issue. +struct TensorEmptyToVectorBroadcastPattern + : public OpRewritePattern { + TensorEmptyToVectorBroadcastPattern( + MLIRContext *context, llvm::function_ref filterFn, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), filterFn(filterFn) {} + + LogicalResult matchAndRewrite(TransferReadOp op, + PatternRewriter &rewriter) const override { + if (failed(matchSimpleTransferOp(op, rewriter))) return failure(); + auto tensorEmpty = op.getSource().getDefiningOp(); + if (!tensorEmpty) + return rewriter.notifyMatchFailure(op, "source should be tensor.empty"); + auto vectorType = op.getResult().getType().dyn_cast(); + if (!vectorType) + return rewriter.notifyMatchFailure(op, "result should be a vector"); + Type elementType = vectorType.getElementType(); + TypedAttr nanAttr; + if (elementType.isa()) { + nanAttr = rewriter.getIntegerAttr(elementType, 0l); + } else if (elementType.isa()) { + nanAttr = rewriter.getFloatAttr(elementType, + std::numeric_limits::quiet_NaN()); + } else { + return rewriter.notifyMatchFailure( + op, "should operate on integer or floating point vectors"); + } + + rewriter.replaceOpWithNewOp( + op, DenseElementsAttr::get(vectorType, nanAttr)); + return success(); + } + + private: + llvm::function_ref filterFn; +}; + +struct IdentityMaterializeOpFoldingPattern + : public OpRewritePattern { + explicit IdentityMaterializeOpFoldingPattern(MLIRContext *context, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit) {} + + LogicalResult matchAndRewrite(tensor::ExtractSliceOp op, + PatternRewriter &rewriter) const override { + auto src = op.getSource(); + // Only fold identity materialize of ForOp's block argument. + // Set has to be an identity tile op and source and result are static and + // have the same shapes. + if (!op->getParentOfType() || !src.isa() || + !isIdentitySlice(op.getOffsets(), op.getStrides()) || + !haveSameStaticShape(src, op.getResult())) + return rewriter.notifyMatchFailure(op, "did not match filter"); + + op.replaceAllUsesWith(src); + return success(); + } +}; + +// Rewrite tensor.extract on single-element tensors into a vector.extract. 
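+// Illustrative example (shapes assumed):
+//   %e = tensor.extract %t[%c0] : tensor<1xf32>
+// becomes a vector.transfer_read of %t into a vector %v, followed by
+//   %e = vector.extract %v[0] : vector<1xf32>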
+struct TensorToElementVectorizationPattern + : public mlir::OpRewritePattern { + TensorToElementVectorizationPattern( + MLIRContext *context, llvm::function_ref matchFn, + mlir::PatternBenefit benefit = 1) + : mlir::OpRewritePattern(context, benefit), + filterFn(matchFn) {} + + LogicalResult matchAndRewrite(tensor::ExtractOp op, + PatternRewriter &rewriter) const override { + if (!filterFn(op)) + return rewriter.notifyMatchFailure(op, "did not match filter"); + TensorType tensorType = op.getTensor().getType(); + if (tensorType.getNumDynamicDims() > 0 || tensorType.getNumElements() > 1) + return rewriter.notifyMatchFailure(op, "should have a single element"); + + IRMapping bvm; + convertTensorOperandsToVector(op, bvm, rewriter); + if (tensorType.getRank() == 0) { + // ExtractOp only supports ranks > 0, for rank = 0 use ExtractElementOp + rewriter.replaceOpWithNewOp( + op, bvm.lookupOrDefault(op.getTensor())); + } else { + rewriter.replaceOpWithNewOp( + op, bvm.lookupOrDefault(op.getTensor()), + SmallVector(tensorType.getRank(), 0)); + } + return success(); + } + + private: + llvm::function_ref filterFn; +}; bool isInsideGmlStLoop(Operation *op) { Operation *parent = op->getParentOp(); - return isa(parent) || isa(parent) || isa(parent); + return isa(parent) || isa(parent); } + bool isFillTiledOrSmall(FillOp fill) { if (isInsideGmlStLoop(fill)) return true; // Allow vectorization for static shapes with low number of elements. - auto outputType = fill.output().getType().cast(); - return outputType.hasStaticShape() && + auto outputType = fill.output().getType().dyn_cast(); + return outputType && outputType.hasStaticShape() && outputType.getNumElements() < kNumElementsThreshold; } -bool isGenericOpTiledOrOneDimReduction(GenericOp generic) { +bool isLinalgOpTiledOrOneDimReduction(linalg::LinalgOp op) { + if (isInsideGmlStLoop(op)) return true; + + // Allow vectorization of 1D reductions. + return op.getNumLoops() == 1 && op.getNumReductionLoops() == 1; +} + +bool isGenericOpTiledOrOneDimReduction(linalg::GenericOp generic) { if (isInsideGmlStLoop(generic)) return true; // Allow vectorization of 1D reductions. return generic.getNumLoops() == 1 && generic.getNumReductionLoops() == 1; } -struct VectorizeGmlStLoopsPass - : public impl::VectorizeGmlStLoopsPassBase { - VectorizeGmlStLoopsPass(bool vectorizeGmlStOpsParam, - ArrayRef distributionLabelsParam) { +struct VectorizeForGPUPass + : public impl::VectorizeForGPUPassBase { + VectorizeForGPUPass(bool vectorizeGmlStOpsParam, + ArrayRef distributionLabelsParam) { vectorizeGmlStOps = vectorizeGmlStOpsParam; for (StringRef distribution : distributionLabelsParam) distributionLabels.push_back(distribution.str()); @@ -521,7 +606,7 @@ struct VectorizeGmlStLoopsPass auto isValidDistribution = [&](Operation *op) { if (distributionLabels.empty()) return true; - ParallelOp parent = op->getParentOfType(); + auto parent = op->getParentOfType(); if (!parent || !parent.getDistributionType().has_value()) return false; return llvm::find(distributionLabels, parent.getDistributionType().value()) != @@ -530,7 +615,11 @@ struct VectorizeGmlStLoopsPass // These lambdas have to be assigned to local variables, so that they // survive beyond patterns.add() and applyPatternsAndFoldGreedily() calls. 
auto fillOpFilter = [&](FillOp op) { - return isValidDistribution(op) && isFillTiledOrSmall(op); + bool filter = isValidDistribution(op) && isFillTiledOrSmall(op); + return filter; + }; + auto linalgOpFilter = [&](LinalgOp op) { + return isValidDistribution(op) && isLinalgOpTiledOrOneDimReduction(op); }; auto genericOpFilter = [&](GenericOp op) { return isValidDistribution(op) && isGenericOpTiledOrOneDimReduction(op); @@ -542,7 +631,7 @@ struct VectorizeGmlStLoopsPass op.getResult(0).getType().cast(); return outputType.hasStaticShape(); }; - auto materializeOpFilter = [&](MaterializeOp op) { + auto materializeOpFilter = [&](tensor::ExtractSliceOp op) { // Materialize op should only be vectorized if the producer of its // source is within the vectorized region, otherwise we vectorize one // level too much. (E.g., for GPU, if we are vectorizing up to warp level, @@ -550,34 +639,63 @@ struct VectorizeGmlStLoopsPass // block-level tiles, since it means we are inserting a // vector.transfer_read on the source, i.e., a block-level tile). Operation *sourceOp = op.getSource().getDefiningOp(); - return sourceOp && isValidDistribution(sourceOp); + // Only vectorize MaterializeOp inside a loop, since we are only enabling + // this pattern when vectorizing ForOp and ParallelOp anyway. + Operation *parent = op->getParentOp(); + bool opInsideLoop = isa(parent) || isa(parent); + return sourceOp != nullptr && opInsideLoop && + isValidDistribution(sourceOp); }; + { + RewritePatternSet patterns = getDefaultVectorizationPatterns(ctx); + + populateTransferReadOfOneDimExpandShapePattern(patterns); + patterns.add(ctx); + patterns.add>(ctx, fillOpFilter); + patterns.add>(ctx, genericOpFilter); + patterns.add, + VectorizationPattern, VectorizationPattern>( + ctx, linalgOpFilter); + patterns.add>(ctx, matmulOpFilter); + patterns.add(ctx, + isValidDistribution); + if (vectorizeGmlStOps) { + patterns.add(ctx, + materializeOpFilter); + patterns.add, + LoopLikeOpVectorizationPattern>( + ctx, isValidDistribution); + } + (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + } + + { + RewritePatternSet patterns = getDefaultVectorizationPatterns(ctx); + patterns.add(ctx); + (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + } + + // Hoisting transfer_read/transfer_write. 
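+    // Loop-invariant transfer_read/transfer_write pairs are moved out of the
+    // enclosing loops, so the loop body carries vectors instead of re-reading
+    // and re-writing the same tensor on every iteration.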
+ { + RewritePatternSet patterns(ctx); + patterns.add(ctx); + (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); - RewritePatternSet patterns = getDefaultVectorizationPatterns(ctx); - patterns.add(ctx); - patterns.add>(ctx, fillOpFilter); - patterns.add>(ctx, genericOpFilter); - patterns.add>(ctx, matmulOpFilter); - patterns.add(ctx, isValidDistribution); - if (vectorizeGmlStOps) { - patterns.add(ctx, materializeOpFilter); - patterns.add, - LoopLikeOpVectorizationPattern>(ctx, - isValidDistribution); + hoistRedundantVectorTransfersOnTensor(func); } - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); } }; } // namespace -std::unique_ptr> createVectorizeGmlStLoopsPass( +std::unique_ptr> createVectorizeForGPUPass( bool vectorizeGmlStOps, ArrayRef distributionLabels) { - return std::make_unique(vectorizeGmlStOps, - distributionLabels); + return std::make_unique(vectorizeGmlStOps, + distributionLabels); } } // namespace gml_st diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/gml_st/utils/CMakeLists.txt similarity index 85% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/CMakeLists.txt rename to tensorflow/compiler/xla/mlir_hlo/gml_st/utils/CMakeLists.txt index eae7839da74..2bdbbdbe291 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/CMakeLists.txt +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/utils/CMakeLists.txt @@ -13,5 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # -add_subdirectory(Dialect) -add_subdirectory(Transforms) + +add_mlir_dialect_library(MLIRGmlStUtils + linalg_utils.cc + + LINK_LIBS PUBLIC + MLIRLinalgDialect +) diff --git a/tensorflow/compiler/xla/mlir_hlo/gml_st/utils/linalg_utils.cc b/tensorflow/compiler/xla/mlir_hlo/gml_st/utils/linalg_utils.cc new file mode 100644 index 00000000000..d53e3ef3b24 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/utils/linalg_utils.cc @@ -0,0 +1,70 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "gml_st/utils/linalg_utils.h" + +#include "mlir/Dialect/Linalg/Utils/Utils.h" + +namespace mlir { +namespace gml_st { + +bool isCwiseGenericOp(Operation *op, int64_t *arity) { + auto genericOp = llvm::dyn_cast_or_null(op); + if (!genericOp || genericOp.getNumDpsInits() != 1) return false; + + // Check all-parallel iterator types. + if (!llvm::all_of(genericOp.getIteratorTypesArray(), + linalg::isParallelIterator)) + return false; + + // Check all-identity maps. + if (!llvm::all_of(genericOp.getIndexingMapsArray(), + [](AffineMap map) { return map.isIdentity(); })) { + return false; + } + + // Allow for pattern matching the arity. + if (arity != nullptr) *arity = genericOp.getNumDpsInputs(); + return true; +} + +bool isSimpleBcastReduction(Operation *op, int64_t *dimension, + SimpleBcastReduction *chain) { + // Match bcast. 
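+  // (`op` itself is expected to be the broadcast; its operand has to be
+  // produced by a reduction, which is matched below.)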
+ auto broadcastOp = llvm::dyn_cast_or_null(op); + if (!broadcastOp) return false; + + // Match reduction. + auto reduceOp = llvm::dyn_cast_or_null( + broadcastOp.getOperands().front().getDefiningOp()); + if (!reduceOp || reduceOp.getNumDpsInits() != 1) return false; + + // Check that bcast and reduction dimensions match. + auto bcstDimensions = broadcastOp.getDimensions(); + if (!bcstDimensions.empty() && bcstDimensions != reduceOp.getDimensions()) + return false; + + // Allow for pattern matching the reduction dimension and operation chain. + if (dimension != nullptr) *dimension = bcstDimensions.front(); + if (chain != nullptr) { + chain->bcast = op; + chain->reduction = reduceOp; + chain->operand = reduceOp.getInputs().front(); + } + return true; +} + +} // namespace gml_st +} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/linalg_utils.h b/tensorflow/compiler/xla/mlir_hlo/gml_st/utils/linalg_utils.h similarity index 76% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/linalg_utils.h rename to tensorflow/compiler/xla/mlir_hlo/gml_st/utils/linalg_utils.h index e0e4185456b..4f18251c6dd 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/linalg_utils.h +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/utils/linalg_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_LINALG_UTILS_H -#define MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_LINALG_UTILS_H +#ifndef MLIR_HLO_GML_ST_UTILS_LINALG_UTILS_H +#define MLIR_HLO_GML_ST_UTILS_LINALG_UTILS_H #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -24,17 +24,9 @@ namespace gml_st { // Helper functions to match `linalg.generic` ops that implement simple // reductions, bcasts, and cwise ops. -bool isSimpleReduction(Operation *op, int64_t *dimension = nullptr, - Value *operand = nullptr); - // Returns whether 'op' is element-wise linalg.generic with single result. bool isCwiseGenericOp(Operation *op, int64_t *arity = nullptr); -bool isUnaryCwiseGenericOp(Operation *op); - -bool isSimpleBcast(Operation *op, int64_t *dimension = nullptr, - Value *operand = nullptr); - struct SimpleBcastReduction { Operation *bcast; Operation *reduction; diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/vector_utils.h b/tensorflow/compiler/xla/mlir_hlo/gml_st/utils/vector_utils.h similarity index 94% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/vector_utils.h rename to tensorflow/compiler/xla/mlir_hlo/gml_st/utils/vector_utils.h index 1a2c2b3cd52..752a8c6b152 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/vector_utils.h +++ b/tensorflow/compiler/xla/mlir_hlo/gml_st/utils/vector_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_VECTOR_UTILS_H -#define MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_VECTOR_UTILS_H +#ifndef MLIR_HLO_GML_ST_UTILS_VECTOR_UTILS_H +#define MLIR_HLO_GML_ST_UTILS_VECTOR_UTILS_H #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/PatternMatch.h" diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/CMakeLists.txt deleted file mode 100644 index fdb06328302..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/CMakeLists.txt +++ /dev/null @@ -1,20 +0,0 @@ -# -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -add_subdirectory(gml_st) -add_subdirectory(lhlo) -add_subdirectory(lhlo_gpu) -add_subdirectory(mhlo) -add_subdirectory(thlo) diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/CMakeLists.txt deleted file mode 100644 index 88672e5e298..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/CMakeLists.txt +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright 2022 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -add_subdirectory(IR) -add_subdirectory(transforms) diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_legacy_ops.td b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_legacy_ops.td deleted file mode 100644 index 86e02eb245c..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_legacy_ops.td +++ /dev/null @@ -1,339 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -// This is the operation definition file for ST ops. - -#ifndef GML_ST_LEGACY_OPS -#define GML_ST_LEGACY_OPS - -include "mlir/Interfaces/ControlFlowInterfaces.td" -include "mlir/Interfaces/LoopLikeInterface.td" -include "mlir/Interfaces/SideEffectInterfaces.td" - -def GMLST_LoopOp : GMLST_Op<"loop", [ - AttrSizedOperandSegments, - DeclareOpInterfaceMethods, - RecursiveMemoryEffects, - SingleBlockImplicitTerminator<"gml_st::YieldOp"> - ]> { - let summary = "Loop-like operation"; - let description = [{ - This is a loop-like operation with additional properties. The arguments - also include the input and the output tensors or memrefs and the attributes - to specify the iterator types. - - Parsing LoopOp will set all elements of the `iterator_types` attribute - to "parallel" type, when it is absent from the custom format. - - Tensor-based version: - - The body region of the loop contains `extract_slice` operations applied to - every tensor argument of LoopOp. - - The body region must contain exactly one block that terminates with - `gml_st.yield` with the operands resulting from `insert_slice` operations. - - Example: - - ```mlir - %0 = gml_st.loop (%i) = (%c0) to (%c24) step (%c4) - ins(%lhs, %rhs : tensor<24x64xi8>, tensor<24x64xi8>) - outs(%out : tensor<24x64xi8>) - iterators("parallel") - distribution("block_x") { - %lhs_sub = tensor.extract_slice %lhs[%i, 0] [%c4, %c64] [1, 1] - : tensor<24x64xi8> to tensor - %rhs_sub = tensor.extract_slice %rhs[%i, 0] [%c4, %c64] [1, 1] - : tensor<24x64xi8> to tensor - %out_sub = tensor.extract_slice %out[%i, 0] [%c4, %c64] [1, 1] - : tensor<24x64xi8> to tensor - - %result_sub = linalg.generic ... - - %result = tensor.insert_slice %result_sub into %out[%i, 0][%c4, %c64][1, 1] - : tensor into tensor<24x64xi8> - gml_st.yield %result : tensor<24x64xi8> - } - ``` - - MemRef-based version: - - The body region of the loop contains `subview` operations applied to - every memref argument of LoopOp. - - The body region must contain exactly one block that terminates with - `gml_st.yield` with no operands. - - Example: - - ```mlir - gml_st.loop (%i) = (%c0) to (%c24) step (%c4) - ins(%lhs, %rhs : memref<24x64xi8>, memref<24x64xi8>) - outs(%out : memref<24x64xi8>) - iterators("parallel") - distribution("block_x") { - %lhs_sub = subview %lhs[%i, 0] [%c4, %c64] [1, 1] - : memref<24x64xi8> to memref - %rhs_sub = subview %rhs[%i, 0] [%c4, %c64] [1, 1] - : memref<24x64xi8> to memref - %out_sub = subview %out[%i, 0] [%c4, %c64] [1, 1] - : memref<24x64xi8> to memref - - %result_sub = linalg.generic ... 
- gml_st.yield - } - ``` - }]; - - let arguments = (ins Variadic:$lowerBound, - Variadic:$upperBound, - Variadic:$step, - Variadic:$inputs, - Variadic:$outputs, - IteratorTypeArrayAttr:$iterator_types, - OptionalAttr:$distribution_types); - let results = (outs Variadic:$results); - let regions = (region SizedRegion<1>:$region); - - let builders = [ - OpBuilder<(ins "ValueRange":$lowerBounds, "ValueRange":$upperBounds, - "ValueRange":$steps, "ValueRange":$inputs, "ValueRange":$outputs, - "ArrayAttr":$iteratorTypes, "Optional":$distributionTypes, - CArg<"function_ref", - "nullptr">:$bodyBuilderFn)>, - OpBuilder<(ins "ValueRange":$lowerBounds, "ValueRange":$upperBounds, - "ValueRange":$steps, "ValueRange":$inputs, "ValueRange":$outputs, - "ArrayAttr":$iteratorTypes, - CArg<"function_ref", - "nullptr">:$bodyBuilderFn)>, - ]; - - let extraClassDeclaration = [{ - /// Number of loops - unsigned getNumLoops() { return getStep().size(); } - - /// Number of input operands - unsigned getNumInputs() { return getInputs().size(); } - - /// Number of output operands - unsigned getNumOutputs() { return getOutputs().size(); } - - /// Number of operands controlling the loop: lbs, ubs, steps - unsigned getNumControlOperands() { return 3 * getNumLoops(); } - - ValueRange getInductionVars() { - return getBody()->getArguments().take_front(getNumLoops()); - } - ValueRange getRegionInputArgs() { - return getBody()->getArguments().slice(getNumLoops(), getInputs().size()); - } - ValueRange getRegionOutputArgs() { - return getBody()->getArguments().take_back(getOutputs().size()); - } - - void setDistributionTypes(Builder& b, ArrayRef types) { - assert(types.size() == getNumLoops() && - "expected distribution type for every dimension"); - setDistributionTypesAttr(b.getStrArrayAttr(types)); - } - - void setLowerBounds(ValueRange lowerBounds) { - unsigned numLoops = getNumLoops(); - assert(lowerBounds.size() == numLoops && - "expected lower bounds for every loop dimension"); - for (unsigned i = 0; i < numLoops; ++i) - setOperand(i, lowerBounds[i]); - } - - void setUpperBounds(ValueRange upperBounds) { - unsigned numLoops = getNumLoops(); - assert(upperBounds.size() == numLoops && - "expected upper bounds for every loop dimension"); - for (unsigned i = 0, pos = numLoops; i < numLoops; ++i, ++pos) - setOperand(pos, upperBounds[i]); - } - - void setSteps(ValueRange steps) { - unsigned numLoops = getNumLoops(); - assert(steps.size() == numLoops && - "expected upper bounds for every loop dimension"); - for (unsigned i = 0, pos = 2 * numLoops; i < numLoops; ++i, ++pos) - setOperand(pos, steps[i]); - } - - /// Operand that corresponds to the `bbArg` block argument. - OpOperand& getTiedOperand(BlockArgument& bbArg) { - return getOperation()->getOpOperand(getNumControlOperands() + - bbArg.getArgNumber() - getNumLoops()); - } - - /// Block argument that corresponds to the `input` or `output` operand. - BlockArgument getTiedBlockArgument(OpOperand& operand) { - auto operandIndex = operand.getOperandNumber(); - assert( - operandIndex >= getNumControlOperands() && - operandIndex < getNumOperands() && - "tied block arg is defined only for `input` and `output` arguments"); - return getBody()->getArgument(operandIndex - 2 * getNumLoops()); - } - - /// Result that corresponds to the `outputs` argument of tensor type. - OpResult getTiedOpResult(OpOperand& opOperand) { - // No result can correspond to a memref argument. 
- if (opOperand.get().getType().isa()) return OpResult(); - - // Check whether the operand index is in bounds of `outputs()` arg. - int operandIndex = opOperand.getOperandNumber(); - int outputIndexStart = - getNumControlOperands() + getInputs().size(); - int outputIndexEnd = outputIndexStart + getOutputs().size(); - if (operandIndex < outputIndexStart || operandIndex >= outputIndexEnd) - return OpResult(); - - // Count tensor arguments in `outputs` to compute the result index. - int tensorId = -1; - for (int i = outputIndexStart; i <= operandIndex; ++i) - tensorId += getOperand(i).getType().isa(); - return getOperation()->getResult(tensorId); - } - - /// Append `operand` to the `input` arguments. - OpOperand& appendInputOperand(OpBuilder& builder, Value operand) { - int numLoops = getNumLoops(); - int numInputs = getNumInputs(); - int numOutputs = getNumOutputs(); - - getOperation()->insertOperands(getNumControlOperands() + numInputs, - operand); - getBody()->insertArgument(numLoops + numInputs, operand.getType(), - getLoc()); - getOperation()->setAttr( - LoopOp::getOperandSegmentSizeAttr(), - builder.getDenseI32ArrayAttr( - {numLoops, numLoops, numLoops, numInputs + 1, numOutputs})); - return getOperation()->getOpOperand(getNumControlOperands() + numInputs); - } - - /// Append `operand` to the `output` arguments. - OpOperand& appendOutputOperand(OpBuilder& builder, Value operand) { - int numLoops = getNumLoops(); - int numInputs = getNumInputs(); - int numOutputs = getNumOutputs(); - - getOperation()->insertOperands( - getNumControlOperands() + numInputs + numOutputs, operand); - getBody()->insertArgument(numLoops + numInputs + numOutputs, - operand.getType(), getLoc()); - getOperation()->setAttr( - LoopOp::getOperandSegmentSizeAttr(), - builder.getDenseI32ArrayAttr( - {numLoops, numLoops, numLoops, numInputs, numOutputs + 1})); - return getOperation()->getOpOperand(getNumControlOperands() + numInputs + - numOutputs); - } - - /// Erase `operand` from the `input` or `output` arguments. - void eraseOperand(OpBuilder& builder, OpOperand& operand) { - int numInputs = getNumInputs(); - int numLoops = getNumLoops(); - int numOutputs = getNumOutputs(); - int numControlOperands = getNumControlOperands(); - - int operandIndex = operand.getOperandNumber(); - assert(operandIndex >= numControlOperands && - operandIndex < static_cast(getNumOperands()) && - "Can erase only `input` or `output` operand"); - - if (operandIndex >= numControlOperands + numInputs) - --numOutputs; - else - --numInputs; - - getOperation()->eraseOperand(operandIndex); - getBody()->eraseArgument(operandIndex - 2 * numLoops); - getOperation()->setAttr( - LoopOp::getOperandSegmentSizeAttr(), - builder.getDenseI32ArrayAttr( - {numLoops, numLoops, numLoops, numInputs, numOutputs})); - } - - OpOperand* findInputOperand(Value value) { - OperandRange::iterator it = llvm::find(getInputs(), value); - if (it == getInputs().end()) return nullptr; - return it.getBase(); - } - - OpOperand* findOutputOperand(Value value) { - OperandRange::iterator it = llvm::find(getOutputs(), value); - if (it == getOutputs().end()) return nullptr; - return it.getBase(); - } - - /// Return whether the op has only MemRef input and outputs. 
- bool hasBufferSemantics() { - Operation* op = this->getOperation(); - return op->getNumResults() == 0 && - llvm::all_of(op->getOpOperands(), [&](OpOperand & operand) { - return !operand.get().getType().template isa() || - operand.get().getType().template isa(); - }); - } - - static constexpr StringRef getDistributionTypesAttrStrName() { - return "distribution_types"; - } - static constexpr StringRef getIteratorTypesAttrStrName() { - return "iterator_types"; - } - - /// Return whether the loop dimension is parallel or not. - bool isParallelDimension(unsigned dim) { - IteratorTypeAttr attr = - this->getIteratorTypes()[dim].cast(); - return attr.getValue() == utils::IteratorType::parallel; - } - - /// Return the destinations for a gml_st.loop op. - ValueRange getLoopLikeOpInits() { - return getOutputs(); - } - }]; - - let hasCanonicalizer = 1; - let hasCustomAssemblyFormat = 1; - let hasFolder = 1; -} - -def GMLST_YieldOp : GMLST_Op<"yield", [Pure, ReturnLike, Terminator, - HasParent<"::mlir::gml_st::LoopOp, ::mlir::gml_st::SetYieldOp">]>, - Arguments<(ins Variadic:$values)> { - let summary = "Yield operation"; - let description = [{ - `gml_st.yield` is a special terminator operation for `gml_st.loop` body or - for accumulator regions of `gml_st.set_yield`. - - Example: - - ```mlir - gml_st.yield %f0, %f1 : tensor, tensor - ``` - }]; - let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>]; - let assemblyFormat = "attr-dict ($values^ `:` type($values))?"; -} - -#endif // GML_ST_LEGACY_OPS diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/CMakeLists.txt deleted file mode 100644 index f6a9948ec57..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/CMakeLists.txt +++ /dev/null @@ -1,28 +0,0 @@ -# -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -set(LLVM_TARGET_DEFINITIONS passes.td) -mlir_tablegen(passes.h.inc -gen-pass-decls -name GmlSt) -add_public_tablegen_target(MLIRGmlStPassIncGen) - -set(LLVM_TARGET_DEFINITIONS test_passes.td) -mlir_tablegen(test_passes.h.inc -gen-pass-decls -name GmlStTest) -add_public_tablegen_target(MLIRGmlStTestPassIncGen) - -set(LLVM_TARGET_DEFINITIONS tiling_interface.td) -mlir_tablegen(tiling_interface.h.inc -gen-op-interface-decls) -mlir_tablegen(tiling_interface.cc.inc -gen-op-interface-defs) -add_public_tablegen_target(MLIRGmlStTilingInterfaceIncGen) diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/fusion.h b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/fusion.h deleted file mode 100644 index d182c9d9255..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/fusion.h +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_FUSION_H -#define MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_FUSION_H - -#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h" -#include "mlir/IR/PatternMatch.h" - -namespace mlir { -namespace gml_st { - -// Create fused operation based on the specificed subset. The result is -// equivalent to the given `materialize` op. -FailureOr createFusedOp(PatternRewriter &rewriter, - MaterializeOp materializeOp); - -/// Populate fusion patterns. -void populateFusionPatterns(MLIRContext *ctx, - function_ref filterFn, - RewritePatternSet *patterns); - -} // namespace gml_st -} // namespace mlir - -#endif // MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_FUSION_H diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/passes.h b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/passes.h deleted file mode 100644 index 20a98220b99..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/passes.h +++ /dev/null @@ -1,89 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_PASSES_H -#define MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_PASSES_H - -#include -#include - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Pass/Pass.h" - -#define GEN_PASS_DECL -#include "mlir-hlo/Dialect/gml_st/transforms/passes.h.inc" - -namespace mlir { -namespace gml_st { - -/// Pass to tile ops using TilingInterface. -std::unique_ptr> createTilingPass( - StringRef opName = "", StringRef opLabel = "", bool distribute = true, - ArrayRef tileSizes = {}); - -/// Pass to fuse producers into a tiled consumer. -std::unique_ptr> createFusionPass( - StringRef producer = "", StringRef consumer = ""); - -/// Pass to tile and fuse all cwise ops. -std::unique_ptr> createTilingCwisePass( - bool distribute, ArrayRef tileSizes, - StringRef distributionLabel = ""); -std::unique_ptr> createTilingCwisePass(); - -/// Pass to tile warp-level ops on GPU. -std::unique_ptr> createTilingGPUWarpPass(); - -/// Pass to match, tile, and fuse softmax implementations. 
-std::unique_ptr> createTilingSoftmaxPass( - bool distribute, ArrayRef tileSizes, - StringRef distributionLabel = ""); -std::unique_ptr> createTilingSoftmaxPass(); - -/// Pass to collapse (or uncollapse) materialize operations. -std::unique_ptr> createCollapseMaterializeOpsPass(); - -/// Create a pass to convert `gml_st.loop` to `scf.for` and `scf.parallel` -/// loops and memref.load/memref.store accesses. -std::unique_ptr> createGmlStToScfPass(); - -/// Pass to vectorize linalg.generic ops tiled to gml_st.parallel and gml_st.for -/// loops. -std::unique_ptr> createVectorizeGmlStLoopsPass( - bool vectorizeGmlStOps = false, - ArrayRef distributionLabels = {}); - -/// Pass to transform a thlo.scatter op for CPU backend. -std::unique_ptr> createTransformScatterForCpuPass(); - -/// Pass to transform a linalg.matmul op for CPU backend. -std::unique_ptr> createTransformMatmulForCpuPass( - ArrayRef tileSizes = llvm::None); - -/// Pass to transform a linalg.map op for CPU backend. -std::unique_ptr> -createTransformMapForCpuPass(int64_t tileSize = 1); - -/// Pass to transform a linalg.transpose op for CPU backend. -std::unique_ptr> -createTransformTransposeForCpuPass(ArrayRef tileSizes = llvm::None); - -#define GEN_PASS_REGISTRATION -#include "mlir-hlo/Dialect/gml_st/transforms/passes.h.inc" - -} // namespace gml_st -} // namespace mlir - -#endif // MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_PASSES_H diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/passes.td b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/passes.td deleted file mode 100644 index 80237e3edfe..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/passes.td +++ /dev/null @@ -1,159 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -include "mlir/Pass/PassBase.td" - -def TilingPass : Pass<"gml-tiling", "mlir::func::FuncOp"> { - let summary = "Tile operations using TilingInterface to produce gml_st.for"; - let constructor = "::mlir::gml_st::createTilingPass()"; - let options = [ - Option<"opName", "op-name", "std::string", /*default=*/"", - "Operation with this name is the anchor to latch on.">, - Option<"opLabel", "op-label", "std::string", /*default=*/"", - "Operation with this label is the anchor to latch on.">, - Option<"distribute", "distribute", "bool", /*default=*/"true", - "Generate gml_st.parallel or gml_st.for">, - ListOption<"tileSizes", "tile-sizes", "int64_t", "Tile sizes", - "llvm::cl::ZeroOrMore">, - ]; -} - -def FusionPass : Pass<"gml-fusion", "mlir::func::FuncOp"> { - let summary = "Fuse producers in into `gml_st.materialize` operations"; - let constructor = "::mlir::gml_st::createFusionPass()"; - let options = [ - Option<"producerLabel", "producer-label", "std::string", /*default=*/"", - "Producer label.">, - Option<"consumerLabel", "consumer-label", "std::string", /*default=*/"", - "Consumer label.">, - ]; -} - -def TilingCwisePass : Pass<"gml-tiling-cwise", "mlir::func::FuncOp"> { - let summary = "Tile and fuse all cwise ops"; - let constructor = "::mlir::gml_st::createTilingCwisePass()"; - let options = [ - Option<"distribute_", "distribute", "bool", /*default=*/"true", - "Generate gml_st.parallel or gml_st.for">, - ListOption<"tileSizes_", "tile-sizes", "int64_t", - "Right-aligned tile sizes. Do not tile possible remaining " - "dimensions", "llvm::cl::ZeroOrMore">, - Option<"distributionLabel_", "distribution-label", "std::string", - /*default=*/"", "Distribution label for generated gml_st.parallel">, - ]; -} - -def TilingGPUWarpPass : Pass<"gml-tiling-gpu-warp", "mlir::func::FuncOp"> { - let summary = "Tile warp-level ops for GPU"; - let constructor = "::mlir::gml_st::createTilingGPUWarpPass()"; - let dependentDialects = ["::mlir::gml_st::GmlStDialect", - "::mlir::arith::ArithDialect"]; -} - -def TilingSoftmaxPass : Pass<"gml-tiling-softmax", "mlir::func::FuncOp"> { - let summary = "Match, tile, and fuse softmax implementations"; - let constructor = "::mlir::gml_st::createTilingSoftmaxPass()"; - let options = [ - Option<"distribute", "distribute", "bool", /*default=*/"true", - "Generate gml_st.parallel or gml_st.for">, - ListOption<"tileSizes", "tile-sizes", "int64_t", - "Right-aligned tile sizes. 
Do not tile possible remaining " - "dimensions", "llvm::cl::ZeroOrMore">, - Option<"distributionLabel", "distribution-label", "std::string", - /*default=*/"", "Distribution label for generated gml_st.parallel">, - ]; -} - -def CollapseMaterializeOpsPass : Pass<"gml-collapse-materialize-ops", - "mlir::func::FuncOp"> { - let summary = "Collapse (or uncollapse) materialize operations."; - let constructor = "::mlir::gml_st::createCollapseMaterializeOpsPass()"; -} - -def GmlStToScf : Pass<"gml-st-to-scf", "mlir::func::FuncOp"> { - let summary = "Lower `gml_st.loop` to SCF loops and parallel loops"; - let constructor = "::mlir::gml_st::createGmlStToScfPass()"; - let dependentDialects = ["::mlir::scf::SCFDialect"]; -} - -def GmlStToGpuPass : Pass<"gml-st-to-gpu", "mlir::func::FuncOp"> { - let summary = "Lower nested `gml_st.parallel` to `gpu.launch`"; - let dependentDialects = ["::mlir::AffineDialect", - "::mlir::arith::ArithDialect", - "::mlir::gpu::GPUDialect", "::mlir::scf::SCFDialect", - "::mlir::vector::VectorDialect", - "::mlir::memref::MemRefDialect"]; -} - -def VectorizeGmlStLoopsPass : - Pass<"vectorize-gml-st-loops", "mlir::func::FuncOp"> { - let summary = - "Pass to vectorize linalg.generic ops tiled to gml_st.parallel and " # - "gml_st.for loops."; - let constructor = "::mlir::gml_st::createVectorizeGmlStLoopsPass()"; - let options = [ - Option<"vectorizeGmlStOps", "vectorize-gml-st-ops", "bool", "false", - "If true, vectorizes GmlSt ops in addition to linalg ops">, - ListOption<"distributionLabels", "included-distribution-labels", - "std::string", "Distribution labels of gml_st.parallel ops " - "where vectorization is allowed. Empty list signifies that " - "vectorization is allowed within all loops.", - "llvm::cl::ZeroOrMore">, - ]; - let dependentDialects = ["::mlir::vector::VectorDialect"]; -} - -def TransformScatterForCpuPass : - Pass<"xla-cpu-transform-scatter", "mlir::func::FuncOp"> { - let summary = "Transform scatter ops for running on CPU"; - - let constructor = "createTransformScatterForCpuPass()"; -} - -def TransformMatmulForCpuPass : - Pass<"xla-cpu-transform-matmul", "mlir::func::FuncOp"> { - let summary = "Transform matmul ops for running on CPU"; - - let constructor = "createTransformMatmulForCpuPass()"; - - let options = [ - ListOption<"tileSizes", "tile-sizes", "int64_t", - "Tile sizes for a `linalg.matmul`">, - ]; -} - -def TransformMapForCpuPass : - Pass <"gml-st-cpu-transform-map", "mlir::func::FuncOp"> { - let summary = "Transform map ops for running on CPU"; - - let constructor = "::mlir::gml_st::createTransformMapForCpuPass()"; - - let options = [ - Option<"tileSize", "tile-size", "int64_t", "1", - "Tile size for the innermost dimension of `linalg.map`">, - ]; -} - -def TransformTransposeForCpuPass : - Pass<"gml-st-cpu-transform-transpose", "mlir::func::FuncOp"> { - let summary = "Transform transpose ops for running on CPU"; - - let constructor = "createTransformTransposeForCpuPass()"; - - let options = [ - ListOption<"tileSizes", "tile-sizes", "int64_t", - "Tile sizes for a `linalg.transpose`">, - ]; -} diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/rewriters.h b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/rewriters.h deleted file mode 100644 index bf7950f3bd1..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/rewriters.h +++ /dev/null @@ -1,45 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_REWRITERS_H
-#define MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_REWRITERS_H
-
-#include
-
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Operation.h"
-
-namespace mlir {
-namespace bufferization {
-class BufferizeTypeConverter;
-} // namespace bufferization
-class MLIRContext;
-class RewritePatternSet;
-
-namespace gml_st {
-
-/// Populate pattern to bufferize `linalg.tiled_loop`.
-void populateTiledLoopBufferizePattern(
- MLIRContext *context,
- mlir::bufferization::BufferizeTypeConverter *converter,
- RewritePatternSet *patterns);
-
-void populateCollapseMaterializeOpsPatterns(MLIRContext *, RewritePatternSet *);
-
-} // namespace gml_st
-} // namespace mlir
-
-#endif // MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_REWRITERS_H
diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/test_passes.td b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/test_passes.td
deleted file mode 100644
index ad35a35da47..00000000000
--- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/test_passes.td
+++ /dev/null
@@ -1,52 +0,0 @@
-/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-include "mlir/Pass/PassBase.td"
-
-def TestGmlStLoopPeeling : Pass<"test-gml-st-loop-peeling", "mlir::func::FuncOp"> {
- let summary = "Peel `gml_st.loop`";
- let constructor = "::mlir::gml_st::createTestGmlStLoopPeelingPass()";
- let options = [
- Option<"skip_partial", "skip-partial", "bool", /*default=*/"false",
- "Skip loops inside partial iterations during peeling">,
- ListOption<"dims", "dims", "unsigned", "Dimensions to peel",
- "llvm::cl::OneOrMore">,
- ];
-}
-
-def TestGmlStLoopTiling : Pass<"test-gml-st-loop-tiling", "mlir::func::FuncOp"> {
- let summary = "Tile `gml_st.loop`.";
- let constructor = "::mlir::gml_st::createTestGmlStLoopTilingPass()";
- let dependentDialects = [
- "AffineDialect",
- "gml_st::GmlStDialect",
- "linalg::LinalgDialect",
- "memref::MemRefDialect"
- ];
- let options = [
- ListOption<"tile_sizes", "tile-sizes", "int64_t", "Tile sizes",
- "llvm::cl::ZeroOrMore">,
- ListOption<"distribution_types", "distribution-types", "std::string",
- "Distribution types",
- "llvm::cl::ZeroOrMore">
-
- ];
-}
-
-def TestGmlStBufferization
- : Pass<"test-gml-st-bufferization", "mlir::ModuleOp"> {
- let summary = "Bufferize `gml_st.loop`.";
- let constructor = "::mlir::gml_st::createTestGmlStBufferizationPass()";
-}
diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/tiling_interface.td b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/tiling_interface.td
deleted file mode 100644
index 3907368cad6..00000000000
--- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/tiling_interface.td
+++ /dev/null
@@ -1,113 +0,0 @@
-/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef GML_ST_TILING_INTERFACE
-#define GML_ST_TILING_INTERFACE
-
-include "mlir/IR/OpBase.td"
-
-def TilingInterface : OpInterface<"TilingInterface"> {
- let description = [{
- Interface for operations to expose information needed to tile them.
- }];
- let cppNamespace = "::mlir::gml_st";
- let methods = [
- InterfaceMethod<
- /*desc=*/[{
- Returns a list of operands into which the result of the
- tiled implementation is written into. With `tensor`
- operands, this will be used as the initial tensor into which
- the tiled results are inserted into. With `memref` operands,
- this will be the operand into which the result of the tiled
- operation is written into.
- }],
- /*retType=*/"SmallVector",
- /*methodName=*/"getDestinationOperands",
- /*args=*/(ins "OpBuilder &":$b)
- >,
- InterfaceMethod<
- /*desc=*/[{
- Returns a list of iterator types that describe the number of loops.
- }],
- /*retType=*/"SmallVector",
- /*methodName=*/"getLoopIteratorTypes",
- /*args=*/(ins)
- >,
- InterfaceMethod<
- /*desc=*/[{
- Returns a list of ranges that describe the loop bounds and
- step for the loops of the operation.
- }],
- /*retTy=*/"SmallVector",
- /*methodName=*/"getIterationDomain",
- /*args=*/(ins "OpBuilder &":$b)
- >,
- InterfaceMethod<
- /*desc=*/[{
- Method to generate the tiled implementation of an operation.
-
- The iteration space of the operation is returned by
- `getIterationDomain`. The caller provides the information of the
- tile within this iteration space whose implementation the
- caller needs.
-
- `offsets` provides the offset of the tile in the coordinate system
- of the original iteration space, i.e., if an iteration space
- dimension had non-zero offset, it must be included in the offset
- provided here (as opposed to zero-based offset "relative" to the
- iteration space).
-
- `sizes` provides the size of the tile.
-
- The method returns the operation that is the tiled
- implementation.
- }],
- /*retType=*/"mlir::gml_st::TilingInterface",
- /*methodName=*/"getTiledImplementation",
- /*args=*/(ins
- "OpBuilder &":$b,
- "ArrayRef":$offsets,
- "ArrayRef":$sizes)
- >,
- InterfaceMethod<
- /*desc=*/[{
- Generates the IR that computes the tile of a result of the
- operation. The `offsets` and `sizes` describe the tile of
- the output required. This is different from
- `getTiledImplementation` which generates the tiled
- implementation of the operation given a tile of the
- iteration space. This method generates a tiled
- implementation of the operation based on the tile of the
- result required. This method enables fusion by using tile
- and fuse. The method returns failure if the operation can't be
- tiled to generate the result tile. In practical terms this
- implies it cannot be tiled and fused with its consumers.
-
-
- `offsets` provides the offset of the tile in the coordinate system
- of the original iteration space, i.e., if an iteration space
- dimension had non-zero offset, it must be included in the offset
- provided here (as opposed to zero-based offset "relative" to the
- iteration space).
-
- `sizes` provides the size of the tile.
- }],
- /*retType=*/"FailureOr",
- /*methodName=*/"generateResultTileValue",
- /*args=*/(ins
- "OpBuilder &":$b,
- "unsigned":$resultNumber,
- "ArrayRef":$offsets,
- "ArrayRef":$sizes)
- >
- ];
-}
-#endif // GML_ST_TILING_INTERFACE
diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/transforms.h b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/transforms.h
deleted file mode 100644
index 3e6ca879e39..00000000000
--- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/transforms.h
+++ /dev/null
@@ -1,109 +0,0 @@
-/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
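The tiling contract that the deleted tiling_interface.td above documents (an op reports its iteration domain, and is then asked for a tiled implementation at given offsets and sizes) can be pictured with a small standalone C++ sketch. The names below are hypothetical and use plain containers rather than MLIR types such as OpBuilder or Range; this is a conceptual illustration under those assumptions, not XLA or MLIR code.

// Conceptual sketch only (hypothetical names, no MLIR dependencies).
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

struct Range { int64_t offset, size, step; };  // one loop of the iteration domain
struct Tile { std::vector<int64_t> offsets, sizes; };

// A toy "op" with a rows x cols iteration domain, mirroring getIterationDomain.
struct ToyCwiseOp {
  int64_t rows, cols;
  std::vector<Range> getIterationDomain() const {
    return {{0, rows, 1}, {0, cols, 1}};
  }
  // Mirrors the spirit of getTiledImplementation: enumerate the tiles
  // (offsets + sizes) a tiling pass would materialize for the requested tile
  // sizes. Partial tiles at the boundary get clamped sizes.
  std::vector<Tile> tile(int64_t tileRows, int64_t tileCols) const {
    std::vector<Tile> tiles;
    for (int64_t i = 0; i < rows; i += tileRows)
      for (int64_t j = 0; j < cols; j += tileCols)
        tiles.push_back({{i, j},
                         {std::min(tileRows, rows - i),
                          std::min(tileCols, cols - j)}});
    return tiles;
  }
};

int main() {
  ToyCwiseOp op{10, 7};
  for (const Tile& t : op.tile(4, 4))
    std::cout << "tile at (" << t.offsets[0] << ", " << t.offsets[1]
              << ") size " << t.sizes[0] << "x" << t.sizes[1] << "\n";
  return 0;
}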
-==============================================================================*/ - -#ifndef MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_TRANSFORMS_H -#define MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_TRANSFORMS_H - -#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h" -#include "mlir/IR/PatternMatch.h" - -namespace mlir { -namespace linalg { -class LinalgOp; -struct TiledLinalgOp; -struct LinalgTilingOptions; -} // namespace linalg -} // namespace mlir - -namespace mlir { -namespace gml_st { - -bool isZero(Value v); - -/// Rewrite a gml_st::LoopOp/ParallelOp/ForOp with bounds/step that potentially -/// do not divide evenly into a gml_st::LoopOp/ParallelOp/ForOp where the step -/// divides the iteration space evenly, followed by another -/// gml_st::LoopOp/ParallelOp/ForOp for the last (partial) iteration (if any). -/// This transformation is called "loop peeling". -/// -/// This function peels the `idx`-th loop of the -/// gml_st::LoopOp/ParallelOp/ForOp. To tile all loops in the loop nest, this -/// function must be called multiple times. -/// -/// After loop peeling, this function tries to simplify/canonicalize affine.min -/// and affine.max ops in the body of the two gml_st::LoopOp/ParallelOp/ForOps. -/// For more details, refer to `mlir::scf::peelAndCanonicalizeForLoop`. -/// -/// The return value indicates whether the loop was rewritten or not. Loops are -/// not rewritten if: -/// * Loop step size is 1 or -/// * Loop bounds and step size are static, and step already divides the -/// iteration space evenly. -/// -/// Note: This function rewrites the given gml_st::LoopOp/ParallelOp/ForOp -/// in-place and clones the gml_st::LoopOp/ParallelOp/ForOp operation for the -/// last iteration. It replaces all uses of the unpeeled -/// gml_st::LoopOp/ParallelOp/ForOp with the results of the newly generated -/// gml_st::LoopOp/ParallelOp/ForOp. -LogicalResult peelAndCanonicalizeGmlStLoop(RewriterBase &rewriter, - LoopOp loopOp, int64_t idx, - LoopOp &result); -LogicalResult peelAndCanonicalizeGmlStLoop(RewriterBase &rewriter, - ParallelOp loopOp, int64_t idx, - ParallelOp &result); -LogicalResult peelAndCanonicalizeGmlStLoop(RewriterBase &rewriter, ForOp loopOp, - int64_t idx, ForOp &result); - -/// Perform standalone tiling of a single LinalgOp by `tileSizes`. -/// An empty vector is interpreted as the identity permutation and the -/// transformation returns early. -/// -/// Return a struct containing the tiled loops in the specified order -/// and the cloned op if successful, llvm::None otherwise. -FailureOr tileLinalgOp( - RewriterBase &b, linalg::LinalgOp op, - const linalg::LinalgTilingOptions &options); - -// Sets the attribute to the `op` that indicates that the op was transformed. -void setTransformationAttr(OpBuilder &b, Operation *op); - -// Removes the attribute that indicates that it was transformed. -void removeTransformationAttr(Operation *op); - -// Checks if `op` has the attribute that indicates that it was transformed. -bool hasTransformationAttr(Operation *op); - -// Checks if `op` has the matching label attribute. -bool hasMatchingLabel(Operation *op, StringRef label); - -// Uncollapse materialize operations with nested tile chains t1, t2, ..., tn. A -// materialize op of the form ... -// `materialize(t1(t2(...(tn(sn)))), arg)` -// ... is expanded into ... -// `materialize(t1(s1), materialize(t2(...(tn(sn))), arg))`. -FailureOr uncollapseMaterializeOp(OpBuilder &b, - MaterializeOp op); - -// Collapse materialize operations with nested tile chains t1, t2, ..., tn, and -// u1, u2, ..., un. 
A materialize op of the form ... -// `materialize(t1(t2(...(tn(sn)))), materialize(u1(u2(...(un(sn')))), arg))` -// ... is collapsed as ... -// `materialize(t1(t2(...(tn(u1(u2(...(un(sn'))))))), arg)`. -FailureOr collapseMaterializeOp(OpBuilder &b, MaterializeOp op); - -} // namespace gml_st -} // namespace mlir - -#endif // MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_TRANSFORMS_H diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/transforms/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/transforms/CMakeLists.txt deleted file mode 100644 index c53be9d74f4..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/transforms/CMakeLists.txt +++ /dev/null @@ -1,19 +0,0 @@ -# -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -set(LLVM_TARGET_DEFINITIONS lmhlo_passes.td) -mlir_tablegen(lmhlo_passes.h.inc -gen-pass-decls -name AllLmhlo) -add_public_tablegen_target(MLIRLmhloPassIncGen) diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/CMakeLists.txt deleted file mode 100644 index e138afa587f..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/CMakeLists.txt +++ /dev/null @@ -1,17 +0,0 @@ -# -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -add_subdirectory(IR) -add_subdirectory(transforms) diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt deleted file mode 100644 index 6441c084a85..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt +++ /dev/null @@ -1,27 +0,0 @@ -# -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -# Need a separate function because of the .cc vs .cpp used in the one provided by MLIR -set(LLVM_TARGET_DEFINITIONS hlo_ops.td) -mlir_tablegen(hlo_ops.h.inc -gen-op-decls) -mlir_tablegen(hlo_ops.cc.inc -gen-op-defs) -mlir_tablegen(hlo_ops_enums.h.inc -gen-enum-decls) -mlir_tablegen(hlo_ops_enums.cc.inc -gen-enum-defs) -mlir_tablegen(hlo_ops_attrs.h.inc -gen-attrdef-decls) -mlir_tablegen(hlo_ops_attrs.cc.inc -gen-attrdef-defs) -mlir_tablegen(hlo_ops_typedefs.h.inc -gen-typedef-decls --typedefs-dialect=mhlo) -mlir_tablegen(hlo_ops_typedefs.cc.inc -gen-typedef-defs --typedefs-dialect=mhlo) -add_public_tablegen_target(MLIRhlo_opsIncGen) -add_dependencies(mlir-headers MLIRhlo_opsIncGen) diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h deleted file mode 100644 index f1c3094926e..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h +++ /dev/null @@ -1,127 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_COMMON_H -#define MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_COMMON_H - -// This file defines functionality shared between chlo/mhlo/lhlo dialects. - -#include - -#include "llvm/ADT/SmallSet.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/Operation.h" - -namespace mlir { -namespace hlo { - -// TODO(b/236017415): remove when mhlo uses prefix accessor. -namespace accessor_dispatch { -template -auto getReplicaGroups(OpT op, int) - -> decltype(op.getReplicaGroups(), DenseIntElementsAttr{}) { - return op.getReplicaGroups(); -} -template -auto getReplicaGroups(OpT op, char) - -> decltype(op.replica_groups(), DenseIntElementsAttr{}) { - return op.replica_groups(); -} -} // namespace accessor_dispatch - -// Verifies replica groups attached to collective communication operations. -// If the attribute is not empty, it must be a rank 2 tensor, and each replica -// should appear exactly once. If `is_uniform_sized` is true, then we also check -// that each group is of the same size. If the operation has -// `use_global_device_ids` set, then replica group cannot be empty. -template -LogicalResult verifyReplicaGroups(OpT op, bool isUniformSized) { - DenseIntElementsAttr attr = accessor_dispatch::getReplicaGroups(op, 0); - auto replicaGroupType = attr.getType().dyn_cast(); - if (!replicaGroupType || replicaGroupType.getRank() != 2 || - !replicaGroupType.getElementType().isInteger(/*width=*/64)) - return op.emitOpError( - "replica groups should be a rank 2 tensor of 64 bit integers"); - - if (replicaGroupType.getShape().equals(ArrayRef{0, 0})) { - // verifyReplicaGroups() is used by MHLO and LMHLO, note that MHLO does not - // have attr 'use_global_device_ids' actually. 
- if (op->hasAttr("use_global_device_ids") && - op->getAttr("use_global_device_ids") - .template cast() - .getValue()) { - return op.emitOpError( - "if `use_global_device_ids` is set, the replica groups cannot be " - "empty"); - } - return success(); - } - - int64_t maxReplicaIdSeen = 0; - llvm::SmallSet replicaSeen; - for (int64_t id : attr.getValues()) { - // Replica groups are stored in a 2D tensor. If the op supports non-uniform - // groups, null replica IDs are stored as -1. - if (id == -1) { - if (isUniformSized) { - return op.emitOpError("Invalid replica id -1"); - } - continue; - } - - if (!replicaSeen.insert(id).second) { - return op.emitOpError("replica id #") << id << " seen more than once"; - } - maxReplicaIdSeen = std::max(maxReplicaIdSeen, id); - } - - for (int64_t id = 0; id <= maxReplicaIdSeen; id++) { - if (!replicaSeen.contains(id)) { - return op.emitOpError("replica id #") - << id << " not seen in replica groups"; - } - } - return success(); -} - -// Verifies the source target pairs attached to collective permute. -LogicalResult verifyCollectivePermuteSourceTargetPairs( - Operation* op, DenseIntElementsAttr attr); - -LogicalResult verifyReduceScatter(Operation* op, TypeRange operandTypes, - TypeRange resultTypes, - uint64_t scatterDimension); - -// Custom formatting for convolution window attributes. -void printWindowAttributes(OpAsmPrinter& p, Operation* op, - llvm::Optional windowStrides, - llvm::Optional padding, - llvm::Optional lhsDilation, - llvm::Optional rhsDilation, - llvm::Optional windowReversal); - -ParseResult parseWindowAttributes(OpAsmParser& parser, - DenseIntElementsAttr& windowStrides, - DenseIntElementsAttr& padding, - DenseIntElementsAttr& lhsDilation, - DenseIntElementsAttr& rhsDilation, - DenseElementsAttr& windowReversal); - -} // namespace hlo -} // namespace mlir - -#endif // MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_COMMON_H diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_enums.td b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_enums.td deleted file mode 100644 index 81937627495..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_enums.td +++ /dev/null @@ -1,237 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_ENUMS -#define MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_ENUMS - -include "mlir/IR/EnumAttr.td" -include "mlir/IR/PatternBase.td" - -//===----------------------------------------------------------------------===// -// Precision Config enum definitions. -//===----------------------------------------------------------------------===// - -// These mirror the XLA PrecisionConfig proto enum. 
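The replica-group rules spelled out in the deleted hlo_ops_common.h above (every replica id listed at most once, ids forming a dense 0..max range, -1 allowed only as padding when groups may differ in size) can be restated over plain integers. The following standalone C++ sketch uses hypothetical names and std::vector in place of DenseIntElementsAttr; it illustrates the checks and is not the XLA implementation.

// Illustrative re-statement of the replica-group checks (hypothetical names).
#include <algorithm>
#include <cstdint>
#include <set>
#include <string>
#include <vector>

// Returns an empty string on success, otherwise a diagnostic message.
std::string verifyReplicaGroups(const std::vector<std::vector<int64_t>>& groups,
                                bool isUniformSized) {
  int64_t maxSeen = -1;
  std::set<int64_t> seen;
  const size_t groupSize = groups.empty() ? 0 : groups.front().size();
  for (const auto& group : groups) {
    if (isUniformSized && group.size() != groupSize)
      return "all replica groups must have the same size";
    for (int64_t id : group) {
      if (id == -1) {  // padding entry for non-uniform groups
        if (isUniformSized) return "invalid replica id -1";
        continue;
      }
      if (!seen.insert(id).second)
        return "replica id " + std::to_string(id) + " seen more than once";
      maxSeen = std::max(maxSeen, id);
    }
  }
  for (int64_t id = 0; id <= maxSeen; ++id)  // ids must form a dense range
    if (!seen.count(id))
      return "replica id " + std::to_string(id) + " not seen in replica groups";
  return "";
}

int main() { return verifyReplicaGroups({{0, 1}, {2, 3}}, true).empty() ? 0 : 1; }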
-def HLO_PRECISION_DEFAULT : I32EnumAttrCase<"DEFAULT", 0>; -def HLO_PRECISION_HIGH : I32EnumAttrCase<"HIGH", 1>; -def HLO_PRECISION_HIGHEST : I32EnumAttrCase<"HIGHEST", 2>; -def HLO_PRECISION_PACKED_NIBBLE : I32EnumAttrCase<"PACKED_NIBBLE", 3>; - -def HLO_Precision : I32EnumAttr<"Precision", - "XLA precision for an operand. Has backend specific meaning.", - [HLO_PRECISION_DEFAULT, HLO_PRECISION_HIGH, HLO_PRECISION_HIGHEST, HLO_PRECISION_PACKED_NIBBLE]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::mhlo"; -} - -def HLO_PrecisionAttr : EnumAttr; - -// TODO(b/129153247) See if it's possible to also validate the size. -def HLO_PrecisionConfigAttr: - OptionalAttr< - TypedArrayAttrBase>; - -//===----------------------------------------------------------------------===// -// Domain Metadata Kind enum definitions. -//===----------------------------------------------------------------------===// - -// These mirror the XLA FftType proto enum. -def HLO_DOMAIN_KIND_SHARDING : I32EnumAttrCase<"sharding", 0>; - -def HLO_DomainKind : I32EnumAttr<"DomainKind", - "Kind of domain metatdata attached to an HLO domain.", - [HLO_DOMAIN_KIND_SHARDING]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::mhlo"; -} - -def HLO_DomainKindAttr : EnumAttr; - -//===----------------------------------------------------------------------===// -// Fast Fourier Transform Type enum definitions. -//===----------------------------------------------------------------------===// - -// These mirror the XLA FftType proto enum. -def HLO_FFT_TYPE_FFT : I32EnumAttrCase<"FFT", 0>; -def HLO_FFT_TYPE_IFFT : I32EnumAttrCase<"IFFT", 1>; -def HLO_FFT_TYPE_RFFT : I32EnumAttrCase<"RFFT", 2>; -def HLO_FFT_TYPE_IRFFT : I32EnumAttrCase<"IRFFT", 3>; - -def HLO_FftType : I32EnumAttr<"FftType", - "XLA fast fourier transform type.", - [HLO_FFT_TYPE_FFT, HLO_FFT_TYPE_IFFT, - HLO_FFT_TYPE_RFFT, HLO_FFT_TYPE_IRFFT]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::mhlo"; -} - -def HLO_FftTypeAttr : EnumAttr; - -//===----------------------------------------------------------------------===// -// Custom call enum definitions. -//===----------------------------------------------------------------------===// - -// TODO(b/189822916): Remove this enum when all clients are migrated to the -// status-returning API. -def HLO_CUSTOM_CALL_API_VERISON_UNSPECIFIED : - I32EnumAttrCase<"API_VERSION_UNSPECIFIED", 0>; -def HLO_CUSTOM_CALL_API_VERSION_ORIGINAL : - I32EnumAttrCase<"API_VERSION_ORIGINAL", 1>; -def HLO_CUSTOM_CALL_API_VERSION_STATUS_RETURNING : - I32EnumAttrCase<"API_VERSION_STATUS_RETURNING", 2>; -def HLO_CUSTOM_CALL_API_VERSION_STATUS_RETURNING_UNIFIED : - I32EnumAttrCase<"API_VERSION_STATUS_RETURNING_UNIFIED", 3>; -def HLO_CustomCallApiVersionAttr : - I32EnumAttr<"CustomCallApiVersion", "Custom call API version", [ - HLO_CUSTOM_CALL_API_VERISON_UNSPECIFIED, - HLO_CUSTOM_CALL_API_VERSION_ORIGINAL, - HLO_CUSTOM_CALL_API_VERSION_STATUS_RETURNING, - HLO_CUSTOM_CALL_API_VERSION_STATUS_RETURNING_UNIFIED - ]> { - let cppNamespace = "::mlir::mhlo"; -} - -//===----------------------------------------------------------------------===// -// Comparison op definitions. -//===----------------------------------------------------------------------===// - -// These mirror the XLA ComparisonDirection enum. 
-def HLO_COMPARISON_DIRECTION_EQ : I32EnumAttrCase<"EQ", 0>; -def HLO_COMPARISON_DIRECTION_NE : I32EnumAttrCase<"NE", 1>; -def HLO_COMPARISON_DIRECTION_GE : I32EnumAttrCase<"GE", 2>; -def HLO_COMPARISON_DIRECTION_GT : I32EnumAttrCase<"GT", 3>; -def HLO_COMPARISON_DIRECTION_LE : I32EnumAttrCase<"LE", 4>; -def HLO_COMPARISON_DIRECTION_LT : I32EnumAttrCase<"LT", 5>; - -def HLO_ComparisonDirection : I32EnumAttr<"ComparisonDirection", - "Which comparison operation to perform.", - [ - HLO_COMPARISON_DIRECTION_EQ, - HLO_COMPARISON_DIRECTION_NE, - HLO_COMPARISON_DIRECTION_GE, - HLO_COMPARISON_DIRECTION_GT, - HLO_COMPARISON_DIRECTION_LE, - HLO_COMPARISON_DIRECTION_LT - ]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::mhlo"; -} - -def HLO_ComparisonDirectionAttr : EnumAttr; - -def HLO_DEFAULT_COMPARISON_TYPE : NativeCodeCall<"::mlir::mhlo::ComparisonTypeAttr()">; -def HLO_COMPARISON_TYPE_NOTYPE : I32EnumAttrCase<"NOTYPE", 0>; -def HLO_COMPARISON_TYPE_FLOAT : I32EnumAttrCase<"FLOAT", 1>; -def HLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER : I32EnumAttrCase<"TOTALORDER", 2>; -def HLO_COMPARISON_TYPE_SIGNED : I32EnumAttrCase<"SIGNED", 3>; -def HLO_COMPARISON_TYPE_UNSIGNED : I32EnumAttrCase<"UNSIGNED", 4>; - -def HLO_ComparisonType : I32EnumAttr<"ComparisonType", - "Which comparison type to use.", - [ - HLO_COMPARISON_TYPE_NOTYPE, - HLO_COMPARISON_TYPE_FLOAT, - HLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER, - HLO_COMPARISON_TYPE_SIGNED, - HLO_COMPARISON_TYPE_UNSIGNED - ]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::mhlo"; -} - -def HLO_ComparisonTypeAttr : EnumAttr; - -// These mirror the XLA Dequantize mode string enum. -def HLO_MIN_COMBINED : I32EnumAttrCase<"MIN_COMBINED", 0>; - -def HLO_DequantizeMode : I32EnumAttr<"DequantizeMode", - "Dequantization mode. Only MIN_COMBINED is supported.", - [HLO_MIN_COMBINED]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::mhlo"; -} - -def HLO_DequantizeModeAttr : EnumAttr; - -// These mirror the XLA Transpose enum in Triangular Solve options. 
-def HLO_TRANSPOSE_INVALID : I32EnumAttrCase<"TRANSPOSE_INVALID", 0>; -def HLO_NO_TRANSPOSE : I32EnumAttrCase<"NO_TRANSPOSE", 1>; -def HLO_TRANSPOSE : I32EnumAttrCase<"TRANSPOSE", 2>; -def HLO_ADJOINT : I32EnumAttrCase<"ADJOINT", 3>; - -def HLO_Transpose : I32EnumAttr<"Transpose", - "Transpose options", - [ - HLO_TRANSPOSE_INVALID, - HLO_NO_TRANSPOSE, - HLO_TRANSPOSE, - HLO_ADJOINT - ]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::mhlo"; -} - -def HLO_TransposeAttr : EnumAttr; - -def HLO_LOOP_FUSION : I32EnumAttrCase<"kLoop", 0>; -def HLO_INPUT_FUSION : I32EnumAttrCase<"kInput", 1>; -def HLO_OUTPUT_FUSION : I32EnumAttrCase<"kOutput", 2>; -def HLO_CUSTOM_FUSION : I32EnumAttrCase<"kCustom", 3>; -def HLO_FusionKind : I32EnumAttr<"FusionKind", "fusion kind", [ - HLO_LOOP_FUSION, HLO_INPUT_FUSION, HLO_OUTPUT_FUSION, HLO_CUSTOM_FUSION -]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::mhlo"; -} - -def HLO_RNG_DISTRIBUTION_UNIFORM : I32EnumAttrCase<"UNIFORM", 1>; -def HLO_RNG_DISTRIBUTION_NORMAL : I32EnumAttrCase<"NORMAL", 2>; - -def HLO_RNG_DISTRIBUTION : I32EnumAttr<"RngDistribution", - "XLA PRNG distribution to be used.", - [ - HLO_RNG_DISTRIBUTION_UNIFORM, - HLO_RNG_DISTRIBUTION_NORMAL - ]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::mhlo"; -} - -def HLO_RngDistributionAttr : EnumAttr { - let assemblyFormat = "`<` $value `>`"; -} - -def HLO_FusionKindAttr : EnumAttr; - -def HLO_RNG_ALGORITHM_DEFAULT : I32EnumAttrCase<"DEFAULT", 0>; -def HLO_RNG_ALGORITHM_THREE_FRY : I32EnumAttrCase<"THREE_FRY", 1>; -def HLO_RNG_ALGORITHM_PHILOX : I32EnumAttrCase<"PHILOX", 2>; - -def HLO_RNG_ALGORITHM : I32EnumAttr<"RngAlgorithm", - "XLA PRNG algorithm to be used.", - [ - HLO_RNG_ALGORITHM_DEFAULT, - HLO_RNG_ALGORITHM_THREE_FRY, - HLO_RNG_ALGORITHM_PHILOX - ]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::mhlo"; -} - -def HLO_RngAlgorithmAttr : EnumAttr { - let assemblyFormat = "`<` $value `>`"; -} - -#endif // MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_ENUMS diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Transforms/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Transforms/CMakeLists.txt deleted file mode 100644 index a3176719006..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Transforms/CMakeLists.txt +++ /dev/null @@ -1,23 +0,0 @@ -# -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -set(LLVM_TARGET_DEFINITIONS passes.td) -mlir_tablegen(passes.h.inc -gen-pass-decls -name LMHLOTransforms) -add_public_tablegen_target(LMHLOTransformsPassIncGen) - -set(LLVM_TARGET_DEFINITIONS gpu_passes.td) -mlir_tablegen(gpu_passes.h.inc -gen-pass-decls -name LMHLOGPUTransforms) -add_public_tablegen_target(LMHLOGPUTransformsPassIncGen) \ No newline at end of file diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Transforms/gml_st_pipeline.h b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Transforms/gml_st_pipeline.h deleted file mode 100644 index 38b94d5a54a..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Transforms/gml_st_pipeline.h +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef MLIR_HLO_TRANSFORMS_GML_ST_PIPELINE_H -#define MLIR_HLO_TRANSFORMS_GML_ST_PIPELINE_H - -#include - -#include "mlir/Pass/PassManager.h" -#include "mlir/Pass/PassOptions.h" - -namespace mlir { -struct GmlStPipelineOptions - : public mlir::PassPipelineOptions { - ListOption tileSizes{*this, "tile-sizes", - llvm::cl::desc("Tile sizes")}; - Option lowerToLoops{ - *this, "lower-to-loops", - llvm::cl::desc("Enable bufferization and lowering to SCF dialect for " - "GmlSt and Linalg ops."), - llvm::cl::init(false)}; -}; - -void createGmlStPipeline(mlir::OpPassManager& pm, - const GmlStPipelineOptions& options); - -} // namespace mlir - -#endif // MLIR_HLO_TRANSFORMS_GML_ST_PIPELINE_H diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/lhlo/CMakeLists.txt similarity index 100% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/CMakeLists.txt rename to tensorflow/compiler/xla/mlir_hlo/lhlo/CMakeLists.txt diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/lhlo/IR/CMakeLists.txt similarity index 75% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/CMakeLists.txt rename to tensorflow/compiler/xla/mlir_hlo/lhlo/IR/CMakeLists.txt index 8f2c2057af1..1cffce7c86f 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/CMakeLists.txt +++ b/tensorflow/compiler/xla/mlir_hlo/lhlo/IR/CMakeLists.txt @@ -28,3 +28,31 @@ endfunction() add_mlir_hlo_dialect_separate_files(lhlo_ops) add_mlir_interface(lhlo_structured_interface) + +include_directories(BEFORE + ${CMAKE_CURRENT_BINARY_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}) + +add_mlir_library(LmhloStructuredInterface + lhlo_structured_interface.cc + + LINK_LIBS PUBLIC + MLIRIR + + DEPENDS + MLIRlhlo_structured_interfaceIncGen +) + +add_mlir_dialect_library(LmhloDialect + lhlo_ops.cc + + DEPENDS + MLIRlhlo_opsIncGen + + LINK_LIBS PUBLIC + HloOpsCommon + LmhloStructuredInterface + MhloDialect + MLIRIR +) + diff --git 
a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_dialect.td b/tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_dialect.td similarity index 94% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_dialect.td rename to tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_dialect.td index 684b5797301..2f973a38e10 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_dialect.td +++ b/tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_dialect.td @@ -22,8 +22,7 @@ include "mlir/IR/OpBase.td" def LHLO_Dialect : Dialect { let name = "lmhlo"; let cppNamespace = "::mlir::lmhlo"; - - let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; + let useFoldAPI = kEmitFoldAdaptorFolder; } #endif // LHLO_DIALECT diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo/IR/lhlo_ops.cc b/tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops.cc similarity index 91% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo/IR/lhlo_ops.cc rename to tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops.cc index c282009f9a8..e9ebf1e4de0 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo/IR/lhlo_ops.cc +++ b/tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops.cc @@ -15,14 +15,16 @@ limitations under the License. // This file defines the operations used in the LMHLO dialect. -#include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h" +#include "lhlo/IR/lhlo_ops.h" #include #include #include +#include #include +#include "lhlo/utils/lhlo_utils.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" @@ -32,13 +34,12 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/FormatVariadic.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h" -#include "mlir-hlo/utils/lhlo_utils.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/IR/hlo_ops_common.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Attributes.h" -#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/IRMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" @@ -54,21 +55,23 @@ limitations under the License. #include "mlir/IR/Value.h" #define GET_ATTRDEF_CLASSES -#include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops_structs.cc.inc" +#include "lhlo/IR/lhlo_ops_structs.cc.inc" namespace mlir { namespace lmhlo { +using mhlo::TokenType; + LmhloDialect::LmhloDialect(MLIRContext* context) : Dialect(getDialectNamespace(), context, TypeID::get()) { context->loadDialect(); addOperations< #define GET_OP_LIST -#include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.cc.inc" +#include "lhlo/IR/lhlo_ops.cc.inc" >(); addAttributes< #define GET_ATTRDEF_LIST -#include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops_structs.cc.inc" +#include "lhlo/IR/lhlo_ops_structs.cc.inc" >(); } @@ -120,13 +123,19 @@ LogicalResult AbsOp::verify() { // TODO(jurahul): Add verification for output shape. LogicalResult AllGatherOp::verify() { AllGatherOp op = *this; - return mlir::hlo::verifyReplicaGroups(op, /*isUniformSized=*/true); + return mlir::hlo::verifyReplicaGroups(op.getLoc(), op.getReplicaGroups(), + /*allGroupsMustHaveSameSize=*/true, + op.getUseGlobalDeviceIds(), + /*expectedGroupSize=*/std::nullopt); } // TODO(jurahul): Add verification for output shape. 
LogicalResult AllToAllOp::verify() { AllToAllOp op = *this; - return mlir::hlo::verifyReplicaGroups(op, /*isUniformSized=*/true); + return mlir::hlo::verifyReplicaGroups(op.getLoc(), op.getReplicaGroups(), + /*allGroupsMustHaveSameSize=*/true, + /*useGlobalDeviceIds=*/false, + /*expectedGroupSize=*/std::nullopt); } //===----------------------------------------------------------------------===// @@ -144,7 +153,10 @@ LogicalResult AllReduceOp::verify() { LogicalResult ReduceScatterOp::verify() { ReduceScatterOp op = *this; - if (failed(mlir::hlo::verifyReplicaGroups(op, /*isUniformSized=*/true))) + if (failed(hlo::verifyReplicaGroups(op.getLoc(), op.getReplicaGroups(), + /*allGroupsMustHaveSameSize=*/true, + op.getUseGlobalDeviceIds(), + /*expectedGroupSize=*/std::nullopt))) return failure(); if (failed(mlir::hlo::verifyReduceScatter( op, /*operandTypes=*/op.getInputs().getTypes(), @@ -158,7 +170,7 @@ LogicalResult ReduceScatterOp::verify() { // CaseOp //===----------------------------------------------------------------------===// -void CaseOp::getSuccessorRegions(Optional index, +void CaseOp::getSuccessorRegions(std::optional index, ArrayRef /*operands*/, SmallVectorImpl& regions) { // If the predecessor is the CaseOp, branch to all other branches. @@ -356,7 +368,7 @@ struct RemoveCopyInReduceBody : public OpRewritePattern { SmallVector(oldReduceBody.getNumArguments(), reduce.getLoc())); - mlir::BlockAndValueMapping bvm; + mlir::IRMapping bvm; for (auto item : llvm::zip(reduce.getBody().front().getArguments(), newBlock->getArguments())) { bvm.map(std::get<0>(item), std::get<1>(item)); @@ -396,7 +408,7 @@ LogicalResult ReduceWindowOp::verify() { // WhileOp //===----------------------------------------------------------------------===// -void WhileOp::getSuccessorRegions(Optional index, +void WhileOp::getSuccessorRegions(std::optional index, ArrayRef /*operands*/, SmallVectorImpl& regions) { // If the predecessor is the WhileOp or the body region, branch into the @@ -422,7 +434,7 @@ using mlir::hlo::printWindowAttributes; } // namespace mlir #define GET_OP_CLASSES -#include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.cc.inc" +#include "lhlo/IR/lhlo_ops.cc.inc" namespace mlir { namespace lmhlo { @@ -436,7 +448,7 @@ void FusionOp::build(OpBuilder& builder, OperationState& result, FusionOp::ensureTerminator(*bodyRegion, builder, result.location); } -void FusionOp::getSuccessorRegions(Optional index, +void FusionOp::getSuccessorRegions(std::optional index, ArrayRef /*operands*/, SmallVectorImpl& regions) { // If the predecessor is the fusion region, jump back to the parent op. diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h b/tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops.h similarity index 84% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h rename to tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops.h index cd355cfd6bf..f5a1f7c013c 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h +++ b/tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops.h @@ -15,13 +15,13 @@ limitations under the License. // This file defines the operations used in the LHLO dialect. 
-#ifndef MLIR_HLO_DIALECT_LHLO_IR_LHLO_OPS_H -#define MLIR_HLO_DIALECT_LHLO_IR_LHLO_OPS_H +#ifndef MLIR_HLO_LHLO_IR_LHLO_OPS_H +#define MLIR_HLO_LHLO_IR_LHLO_OPS_H +#include "lhlo/IR/lhlo_ops_structs.h" +#include "lhlo/IR/lhlo_structured_interface.h" #include "llvm/ADT/StringRef.h" -#include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops_structs.h" -#include "mlir-hlo/Dialect/lhlo/IR/lhlo_structured_interface.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -59,6 +59,6 @@ class LmhloDialect : public Dialect { } // end namespace mlir #define GET_OP_CLASSES -#include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h.inc" +#include "lhlo/IR/lhlo_ops.h.inc" -#endif // MLIR_HLO_DIALECT_LHLO_IR_LHLO_OPS_H +#endif // MLIR_HLO_LHLO_IR_LHLO_OPS_H diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.td b/tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops.td similarity index 93% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.td rename to tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops.td index a7e1af82d60..f6b36221f96 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.td +++ b/tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops.td @@ -20,7 +20,7 @@ limitations under the License. // to merge these two files together, but we need to consider the following // obstacles: // * We need to have a common representation for arguments. That is to say, -// HLO_Array translates to HLO_Tensor in HLO dialect, and +// HLO_Array translates to MHLO_Tensor in HLO dialect, and // Arg, "", [Mem(Read|Write)]> in LHLO. Array types within // tuples also need to be transformed. // * As of now, TableGen's dag functions are not sufficient to accomplish the @@ -40,10 +40,10 @@ include "mlir/Interfaces/CopyOpInterface.td" include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" -include "mlir-hlo/Dialect/lhlo/IR/lhlo_dialect.td" -include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops_base.td" -include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops_structs.td" -include "mlir-hlo/Dialect/lhlo/IR/lhlo_structured_interface.td" +include "lhlo/IR/lhlo_dialect.td" +include "lhlo/IR/lhlo_ops_base.td" +include "lhlo/IR/lhlo_ops_structs.td" +include "lhlo/IR/lhlo_structured_interface.td" //===----------------------------------------------------------------------===// // LMHLO nullary op definitions. @@ -161,6 +161,15 @@ def LHLO_CosineOp: LHLO_UnaryElementwiseOp<"cosine", LHLO_FpOrComplexBuffer> { https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. }]; } +def LHLO_TanOp: LHLO_UnaryElementwiseOp<"tan", LHLO_FpOrComplexBuffer> { + let summary = "Tan operator"; + let description = [{ + Returns `Tan(operand)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. 
+ }]; +} def LHLO_ExpOp: LHLO_UnaryElementwiseOp<"exponential", LHLO_FpOrComplexBuffer> { let summary = "Exponential operator"; let description = [{ @@ -631,15 +640,16 @@ def LHLO_CustomCallOp : LHLO_Op<"custom_call", [AttrSizedOperandSegments]> { Arg, "", [MemWrite]>:$output, StrAttr:$call_target_name, DefaultValuedOptionalAttr:$has_side_effect, - DefaultValuedStrAttr:$backend_config, + OptionalAttr>:$backend_config, // TODO(b/189822916): Remove this field when all clients are migrated to // the status-returning API. - DefaultValuedOptionalAttr: $api_version, OptionalAttr:$target_arg_mapping ); let hasVerifier = 1; + let regions = (region AnyRegion:$called_computation); } //===----------------------------------------------------------------------===// @@ -662,8 +672,8 @@ def LHLO_CompareOp: LHLO_Op<"compare", [Elementwise]> { Arg:$rhs, Arg:$out, OptionalAttr:$broadcast_dimensions, - HLO_ComparisonDirectionAttr:$comparison_direction, - OptionalAttr:$compare_type + MHLO_ComparisonDirectionAttr:$comparison_direction, + OptionalAttr:$compare_type ); } @@ -874,7 +884,7 @@ def LHLO_ConvolutionOp : LHLO_Op<"convolution", []> { Arg:$lhs, Arg:$rhs, Arg:$output), - ConvolutionAttributes.attributes); + MHLO_ConvolutionAttributes.attributes); code extraClassDeclaration = [{ bool hasWindowReversal() { @@ -921,8 +931,8 @@ def LHLO_DotOp: LHLO_Op<"dot", []> { let arguments = (ins Arg:$lhs, Arg:$rhs, - DotDimensionNumbers:$dot_dimension_numbers, - HLO_PrecisionConfigAttr:$precision_config, + MHLO_DotDimensionNumbers:$dot_dimension_numbers, + MHLO_PrecisionConfigAttr:$precision_config, Arg:$output ); } @@ -931,7 +941,7 @@ def LHLO_GatherOp: LHLO_Op<"gather", []> { let arguments = (ins Arg:$operand, Arg:$start_indices, - GatherDimensionNumbers:$dimension_numbers, + MHLO_GatherDimensionNumbers:$dimension_numbers, I64ElementsAttr:$slice_sizes, Arg:$output ); @@ -964,7 +974,7 @@ def LHLO_ScatterOp: LHLO_Op<"scatter", []> { Arg:$scatter_indices, Arg:$updates, Arg:$output, - ScatterDimensionNumbers:$scatter_dimension_numbers, + MHLO_ScatterDimensionNumbers:$scatter_dimension_numbers, DefaultValuedOptionalAttr:$indices_are_sorted, DefaultValuedOptionalAttr:$unique_indices ); @@ -1091,7 +1101,7 @@ class LHLO_CollectiveCommunicationOp traits = []> : Arg, "", [MemWrite]>:$outputs, I64ElementsAttr:$replica_groups, DefaultValuedOptionalAttr:$constrain_layout, - OptionalAttr:$channel_id, + OptionalAttr:$channel_id, DefaultValuedOptionalAttr:$use_global_device_ids ); let hasVerifier = 1; @@ -1162,7 +1172,7 @@ def LHLO_CollectivePermuteOp: LHLO_Op<"collective_permute", [SameTypeOperands]> Arg:$operand, Arg:$output, I64ElementsAttr:$source_target_pairs, - OptionalAttr:$channel_id + OptionalAttr:$channel_id ); let hasVerifier = 1; } @@ -1178,7 +1188,7 @@ def LHLO_FftOp: LHLO_Op<"fft", []> { let arguments = (ins Arg:$operand, Arg:$output, - HLO_FftTypeAttr:$fft_type, + MHLO_FftTypeAttr:$fft_type, I64ElementsAttr:$fft_length ); } @@ -1289,10 +1299,10 @@ def LHLO_TriangularSolveOp: LHLO_Op<"triangular_solve", [SameOperandsElementType BoolAttr:$left_side, BoolAttr:$lower, BoolAttr:$unit_diagonal, - HLO_TransposeAttr:$transpose_a, - HLO_LayoutAttr:$layout_a, - HLO_LayoutAttr:$layout_b, - HLO_LayoutAttr:$layout_output + MHLO_TransposeAttr:$transpose_a, + MHLO_LayoutAttr:$layout_a, + MHLO_LayoutAttr:$layout_b, + MHLO_LayoutAttr:$layout_output ); } @@ -1345,6 +1355,90 @@ def LHLO_SortOp: LHLO_Op<"sort", [SameVariadicOperandSize, SameOperandsShape]> { let regions = (region SizedRegion<1>:$comparator); } 
+//===----------------------------------------------------------------------===// +// Point-to-point communication operations. +//===----------------------------------------------------------------------===// + +def LHLO_SendOp : LHLO_Op<"send", []> { + + let summary = "Send operator"; + + let description = [{ + Sends the given operand data to a Recv instruction in another computation + that shares the same channel handle. Does not return any data. Send is an + asynchronous operation, and must be paired with a SendDone operation to + wait for the completion of the data transfer. + + See https://www.tensorflow.org/xla/operation_semantics#send. + }]; + + let arguments = (ins + Arg, "", [MemRead]>:$inputs, + MHLO_ChannelHandle:$channel_handle, + DefaultValuedOptionalAttr:$is_host_transfer, + DefaultValuedOptionalAttr:$frontend_attributes + ); + + let results = (outs MHLO_Token:$token); +} + +def LHLO_RecvOp : LHLO_Op<"recv", []> { + + let summary = "Recv operator"; + + let description = [{ + Receives data of the given shape from a Send instruction in another + computation that shares the same channel handle. Recv is an asynchronous + operation, and must be paired with a RecvDone operation to wait for the + completion of the data transfer. + + See https://www.tensorflow.org/xla/operation_semantics#recv. + }]; + + let arguments = (ins + Arg, "", [MemWrite]>:$outputs, + MHLO_ChannelHandle:$channel_handle, + DefaultValuedOptionalAttr:$is_host_transfer, + DefaultValuedOptionalAttr:$frontend_attributes + ); + + let results = (outs MHLO_Token:$token); +} + +def LHLO_SendDoneOp : LHLO_Op<"send_done", []> { + + let summary = "SendDone operator"; + + let description = [{ + Waits for the completion of corresponding Send operation data transfer. + + See https://www.tensorflow.org/xla/operation_semantics#send. + }]; + + let arguments = (ins + MHLO_Token:$token, + MHLO_ChannelHandle:$channel_handle, + DefaultValuedOptionalAttr:$is_host_transfer + ); +} + +def LHLO_RecvDoneOp : LHLO_Op<"recv_done", []> { + + let summary = "RecvDone operator"; + + let description = [{ + Waits for the completion of corresponding Recv operation data transfer. + + See https://www.tensorflow.org/xla/operation_semantics#recv. 
+ }]; + + let arguments = (ins + MHLO_Token:$token, + MHLO_ChannelHandle:$channel_handle, + DefaultValuedOptionalAttr:$is_host_transfer + ); +} + //===----------------------------------------------------------------------===// // Late operations //===----------------------------------------------------------------------===// @@ -1426,7 +1520,7 @@ def TerminatorOp : }]; let builders = [ OpBuilder<(ins "ValueRange":$operands), - [{ build($_builder, $_state, llvm::None, operands, llvm::None); }]>]; + [{ build($_builder, $_state, std::nullopt, operands, std::nullopt); }]>]; } def LHLO_RealDynamicSliceOp: LHLO_Op< @@ -1475,8 +1569,8 @@ def LHLO_DotGeneralOp: LHLO_Op<"dot_general", []> { let arguments = (ins Arg:$lhs, Arg:$rhs, - DotDimensionNumbers:$dot_dimension_numbers, - HLO_PrecisionConfigAttr:$precision_config, + MHLO_DotDimensionNumbers:$dot_dimension_numbers, + MHLO_PrecisionConfigAttr:$precision_config, Arg:$output ); } @@ -1491,7 +1585,7 @@ def LHLO_DynamicGatherOp: LHLO_Op<"dynamic_gather", []> { Arg:$operand, Arg:$start_indices, Arg:$slice_sizes, - GatherDimensionNumbers:$dimension_numbers, + MHLO_GatherDimensionNumbers:$dimension_numbers, Arg:$output ); } @@ -1568,7 +1662,7 @@ def LHLO_DynamicConvOp : LHLO_Op<"dynamic_conv", []> { Arg:$rhs, Arg:$d_padding, Arg:$output), - ConvolutionAttributes.attributes); + MHLO_ConvolutionAttributes.attributes); } def LHLO_DynamicReshapeOp: LHLO_Op<"dynamic_reshape", []> { diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops_base.td b/tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops_base.td similarity index 82% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops_base.td rename to tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops_base.td index 58cf86ddc8d..0ba78dfd698 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops_base.td +++ b/tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops_base.td @@ -18,14 +18,14 @@ limitations under the License. include "mlir/Dialect/MemRef/IR/MemRefBase.td" include "mlir/IR/OpBase.td" -include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.td" +include "mhlo/IR/hlo_ops_common.td" //===----------------------------------------------------------------------===// // LMHLO type definitions. 
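The lmhlo.send/send_done and recv/recv_done pairs added above split an asynchronous transfer into a start step that yields a token and a completion step that waits on it. A rough standalone C++ analogy is sketched below, with hypothetical names and std::future standing in for the MHLO token type; it only illustrates the start/wait pairing, not the actual channel semantics.

// Conceptual analogy only: "send" starts an async transfer and returns a
// token; "send_done" waits on that token, mirroring lmhlo.send / send_done.
#include <chrono>
#include <future>
#include <iostream>
#include <thread>
#include <utility>
#include <vector>

using Token = std::future<void>;

Token send(std::vector<float> buffer, int channelId) {
  return std::async(std::launch::async, [buf = std::move(buffer), channelId] {
    std::this_thread::sleep_for(std::chrono::milliseconds(10));  // pretend transfer
    std::cout << "channel " << channelId << ": sent " << buf.size()
              << " elements\n";
  });
}

void sendDone(Token token) { token.wait(); }  // completion step blocks here

int main() {
  Token token = send({1.0f, 2.0f, 3.0f}, /*channelId=*/7);
  // Other work may overlap with the in-flight transfer here.
  sendDone(std::move(token));
  return 0;
}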
//===----------------------------------------------------------------------===// // Any integer tensor types -def LHLO_IntBuffer : MemRefOf<[HLO_Int]>; +def LHLO_IntBuffer : MemRefOf<[MHLO_Int]>; // Any floating-point tensor types def LHLO_FpBuffer : MemRefOf<[AnyFloat]>; @@ -34,16 +34,16 @@ def LHLO_ComplexBuffer : MemRefOf<[AnyComplex]>; def LHLO_FpOrComplexBuffer : MemRefOf<[AnyFloat, AnyComplex]>; -def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>; +def LHLO_PredBuffer : MemRefOf<[MHLO_Pred]>; // Any integer or floating-point tensor types -def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>; +def LHLO_IntOrFpBuffer : MemRefOf<[MHLO_Int, AnyFloat]>; -def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>; +def LHLO_PredOrIntBuffer : MemRefOf<[MHLO_Int, MHLO_Pred]>; def LHLO_Buffer : MemRefOf<[AnyFloat, AnyInteger, AnyComplex]>; -def LHLO_DimensionValue : AnyTypeOf<[Index, HLO_Pred, HLO_Int]>; +def LHLO_DimensionValue : AnyTypeOf<[Index, MHLO_Pred, MHLO_Int]>; // Dynamic representation of a shape vector def LHLO_DimensionBuffer : MemRefRankOf<[LHLO_DimensionValue], [1]>; diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops_structs.h b/tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops_structs.h similarity index 81% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops_structs.h rename to tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops_structs.h index 534d9ffdc12..593249a71eb 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops_structs.h +++ b/tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops_structs.h @@ -15,8 +15,8 @@ limitations under the License. // This file defines structures used in LMHLO dialect. -#ifndef MLIR_HLO_DIALECT_LHLO_IR_LHLO_OPS_STRUCTS_H -#define MLIR_HLO_DIALECT_LHLO_IR_LHLO_OPS_STRUCTS_H +#ifndef MLIR_HLO_LHLO_IR_LHLO_OPS_STRUCTS_H +#define MLIR_HLO_LHLO_IR_LHLO_OPS_STRUCTS_H #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" @@ -25,6 +25,6 @@ limitations under the License. // Order matters, this .inc header is not self-contained, and relies on the // #includes above. #define GET_ATTRDEF_CLASSES -#include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops_structs.h.inc" +#include "lhlo/IR/lhlo_ops_structs.h.inc" -#endif // MLIR_HLO_DIALECT_LHLO_IR_LHLO_OPS_STRUCTS_H +#endif // MLIR_HLO_LHLO_IR_LHLO_OPS_STRUCTS_H diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops_structs.td b/tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops_structs.td similarity index 97% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops_structs.td rename to tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops_structs.td index dd0601cfc28..44a3650fa69 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops_structs.td +++ b/tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_ops_structs.td @@ -16,7 +16,7 @@ limitations under the License. 
#ifndef LHLO_OPS_STRUCTS #define LHLO_OPS_STRUCTS -include "mlir-hlo/Dialect/lhlo/IR/lhlo_dialect.td" +include "lhlo/IR/lhlo_dialect.td" include "mlir/IR/AttrTypeBase.td" // This attribute defines information about how arguments to the LHLO custom diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo/IR/lhlo_structured_interface.cc b/tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_structured_interface.cc similarity index 84% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo/IR/lhlo_structured_interface.cc rename to tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_structured_interface.cc index a6cb1103cb7..73bd3450c16 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo/IR/lhlo_structured_interface.cc +++ b/tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_structured_interface.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mlir-hlo/Dialect/lhlo/IR/lhlo_structured_interface.h" +#include "lhlo/IR/lhlo_structured_interface.h" namespace mlir { namespace lmhlo { -#include "mlir-hlo/Dialect/lhlo/IR/lhlo_structured_interface.cpp.inc" +#include "lhlo/IR/lhlo_structured_interface.cpp.inc" } // namespace lmhlo } // namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_structured_interface.h b/tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_structured_interface.h similarity index 74% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_structured_interface.h rename to tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_structured_interface.h index 29623435b00..0a584db58c4 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_structured_interface.h +++ b/tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_structured_interface.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MLIR_HLO_DIALECT_LHLO_IR_LHLO_STRUCTURED_INTERFACE_H -#define MLIR_HLO_DIALECT_LHLO_IR_LHLO_STRUCTURED_INTERFACE_H +#ifndef MLIR_HLO_LHLO_IR_LHLO_STRUCTURED_INTERFACE_H +#define MLIR_HLO_LHLO_IR_LHLO_STRUCTURED_INTERFACE_H #include "mlir/IR/OpDefinition.h" /// Include the generated interface declarations. 
-#include "mlir-hlo/Dialect/lhlo/IR/lhlo_structured_interface.h.inc" +#include "lhlo/IR/lhlo_structured_interface.h.inc" -#endif // MLIR_HLO_DIALECT_LHLO_IR_LHLO_STRUCTURED_INTERFACE_H +#endif // MLIR_HLO_LHLO_IR_LHLO_STRUCTURED_INTERFACE_H diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_structured_interface.td b/tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_structured_interface.td similarity index 100% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_structured_interface.td rename to tensorflow/compiler/xla/mlir_hlo/lhlo/IR/lhlo_structured_interface.td diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo/transforms/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/CMakeLists.txt similarity index 73% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo/transforms/CMakeLists.txt rename to tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/CMakeLists.txt index f88a2e43250..b1a44b270e6 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo/transforms/CMakeLists.txt +++ b/tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/CMakeLists.txt @@ -13,17 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. # + +set(LLVM_TARGET_DEFINITIONS lmhlo_passes.td) +mlir_tablegen(lmhlo_passes.h.inc -gen-pass-decls -name AllLmhlo) +add_public_tablegen_target(MLIRLmhloPassIncGen) + include_directories(BEFORE ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR}) add_mlir_library(LmhloPasses - legalize_to_tensor_op.cc + legalize_to_tensor_op/legalize_to_tensor_op.cc lhlo_elemental_utils.cc - lhlo_fuse_linalg.cc - lhlo_legalize_to_affine.cc - lhlo_legalize_to_gpu.cc - lhlo_legalize_to_parallel_loops.cc + lhlo_legalize_to_affine/lhlo_legalize_to_affine.cc + lhlo_legalize_to_gpu/lhlo_legalize_to_gpu.cc + lhlo_legalize_to_parallel_loops/lhlo_legalize_to_parallel_loops.cc DEPENDS MLIRlhlo_opsIncGen @@ -38,7 +42,6 @@ add_mlir_library(LmhloPasses MLIRComplexDialect MLIRGPUOps MLIRLinalgDialect - MLIRLinalgAnalysis MLIRLinalgTransforms MLIRMhloUtils MLIRIR diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo/transforms/legalize_to_tensor_op.cc b/tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/legalize_to_tensor_op/legalize_to_tensor_op.cc similarity index 96% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo/transforms/legalize_to_tensor_op.cc rename to tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/legalize_to_tensor_op/legalize_to_tensor_op.cc index 0c5ea02abe1..a53eb342c94 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo/transforms/legalize_to_tensor_op.cc +++ b/tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/legalize_to_tensor_op/legalize_to_tensor_op.cc @@ -16,7 +16,7 @@ limitations under the License. // This file implements logic for lowering bufferization.to_tensor ops that are // inserted during `mhlo-legalize-to-lmhlo`. 
-#include "mlir-hlo/Dialect/lhlo/transforms/passes.h" +#include "lhlo/transforms/passes.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -34,7 +34,7 @@ namespace mlir { namespace lmhlo { #define GEN_PASS_DEF_LEGALIZETOTENSOROPPASS -#include "mlir-hlo/Dialect/lhlo/transforms/lmhlo_passes.h.inc" +#include "lhlo/transforms/lmhlo_passes.h.inc" namespace { using shape::ShapeOfOp; diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo/transforms/lhlo_elemental_utils.cc b/tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/lhlo_elemental_utils.cc similarity index 96% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo/transforms/lhlo_elemental_utils.cc rename to tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/lhlo_elemental_utils.cc index 116df06fd0e..85183e5b020 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo/transforms/lhlo_elemental_utils.cc +++ b/tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/lhlo_elemental_utils.cc @@ -16,12 +16,11 @@ limitations under the License. // This file provides basic utilities for the elemental lowering of // each node -#include "mlir-hlo/Dialect/lhlo/transforms/lhlo_elemental_utils.h" +#include "lhlo/transforms/lhlo_elemental_utils.h" +#include "lhlo/IR/lhlo_ops.h" +#include "lhlo/transforms/map_lmhlo_to_scalar_op.h" #include "llvm/Support/Debug.h" -#include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h" -#include "mlir-hlo/Dialect/lhlo/transforms/map_lmhlo_to_scalar_op.h" -#include "mlir-hlo/utils/codegen_utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -32,6 +31,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" +#include "utils/codegen_utils.h" using mlir::memref::DimOp; using mlir::memref::LoadOp; @@ -173,7 +173,7 @@ Value elementalLowerImplForBroadcastInDimOps(OpBuilder* b, Location loc, auto zero = b->create( loc, b->getIndexType(), b->getIntegerAttr(b->getIndexType(), 0)); inputIndex.push_back(zero); - } else if (staticDimSize == ShapedType::kDynamicSize) { + } else if (staticDimSize == ShapedType::kDynamic) { // we are not sure if this dim is to be broadcasted at compile time auto dimSize = b->create(loc, operandMemref, inputDim); auto one = b->create( @@ -253,7 +253,7 @@ memref::ReinterpretCastOp createMemRef1DReinterpretCast(OpBuilder& b, Value zero = b.create( loc, b.getIndexType(), b.getIntegerAttr(b.getIndexType(), 0)); auto memref1dType = - MemRefType::get({ShapedType::kDynamicSize}, memrefTy.getElementType(), + MemRefType::get({ShapedType::kDynamic}, memrefTy.getElementType(), b.getMultiDimIdentityMap(1), memrefTy.getMemorySpace()); return b.create( loc, memref1dType, memref, zero, ValueRange{size}, ValueRange{stride}); diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/transforms/lhlo_elemental_utils.h b/tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/lhlo_elemental_utils.h similarity index 92% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/transforms/lhlo_elemental_utils.h rename to tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/lhlo_elemental_utils.h index 88dea945c59..ac906415a4a 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/transforms/lhlo_elemental_utils.h +++ b/tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/lhlo_elemental_utils.h @@ -13,8 +13,8 @@ See the License for 
the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MLIR_HLO_DIALECT_LHLO_TRANSFORMS_LHLO_ELEMENTAL_UTILS_H -#define MLIR_HLO_DIALECT_LHLO_TRANSFORMS_LHLO_ELEMENTAL_UTILS_H +#ifndef MLIR_HLO_LHLO_TRANSFORMS_LHLO_ELEMENTAL_UTILS_H +#define MLIR_HLO_LHLO_TRANSFORMS_LHLO_ELEMENTAL_UTILS_H #include "mlir/IR/Builders.h" @@ -71,4 +71,4 @@ memref::LoadOp createOffsetLoad(OpBuilder& b, Location loc, Value memref, } // namespace lmhlo } // namespace mlir -#endif // MLIR_HLO_DIALECT_LHLO_TRANSFORMS_LHLO_ELEMENTAL_UTILS_H +#endif // MLIR_HLO_LHLO_TRANSFORMS_LHLO_ELEMENTAL_UTILS_H diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo/transforms/lhlo_legalize_to_affine.cc b/tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/lhlo_legalize_to_affine/lhlo_legalize_to_affine.cc similarity index 99% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo/transforms/lhlo_legalize_to_affine.cc rename to tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/lhlo_legalize_to_affine/lhlo_legalize_to_affine.cc index ad53a5ab9ae..0e75f30e72b 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo/transforms/lhlo_legalize_to_affine.cc +++ b/tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/lhlo_legalize_to_affine/lhlo_legalize_to_affine.cc @@ -15,10 +15,11 @@ limitations under the License. // This file implements logic for lowering LHLO dialect to Affine dialect. +#include #include -#include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h" -#include "mlir-hlo/Dialect/lhlo/transforms/map_lmhlo_to_scalar_op.h" +#include "lhlo/IR/lhlo_ops.h" +#include "lhlo/transforms/map_lmhlo_to_scalar_op.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/BuiltinTypes.h" @@ -30,7 +31,7 @@ namespace mlir { namespace lmhlo { #define GEN_PASS_DEF_LHLOLEGALIZETOAFFINEPASS -#include "mlir-hlo/Dialect/lhlo/transforms/lmhlo_passes.h.inc" +#include "lhlo/transforms/lmhlo_passes.h.inc" namespace { @@ -212,7 +213,7 @@ static void fillBuffer(Location loc, Value buffer, Value fillValue, SmallVector ivs(rank); AffineForOp forOp; for (unsigned i = 0; i < rank; ++i) { - forOp = builder.create(loc, llvm::None, lbMap, dimSizes[i], + forOp = builder.create(loc, std::nullopt, lbMap, dimSizes[i], idSymMap); builder.setInsertionPointToStart(forOp.getBody()); ivs[i] = forOp.getInductionVar(); @@ -223,7 +224,7 @@ static void fillBuffer(Location loc, Value buffer, Value fillValue, fillValueType.isIntOrFloat()) && "init value has to be a 0-d memref or int or fp"); Value initVal = fillMemRefType ? builder.create( - loc, fillValue, /*indices=*/llvm::None) + loc, fillValue, /*indices=*/std::nullopt) : fillValue; builder.create(loc, initVal, buffer, ivs); } diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo/transforms/lhlo_legalize_to_gpu.cc b/tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/lhlo_legalize_to_gpu/lhlo_legalize_to_gpu.cc similarity index 94% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo/transforms/lhlo_legalize_to_gpu.cc rename to tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/lhlo_legalize_to_gpu/lhlo_legalize_to_gpu.cc index 0fd72820d6f..e3794a13620 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo/transforms/lhlo_legalize_to_gpu.cc +++ b/tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/lhlo_legalize_to_gpu/lhlo_legalize_to_gpu.cc @@ -16,10 +16,11 @@ limitations under the License. 
// This file implements logic for lowering LHLO dialect to GPU dialect. #include +#include +#include "lhlo/IR/lhlo_ops.h" +#include "lhlo/transforms/map_lmhlo_to_scalar_op.h" #include "llvm/ADT/ArrayRef.h" -#include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h" -#include "mlir-hlo/Dialect/lhlo/transforms/map_lmhlo_to_scalar_op.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -28,7 +29,7 @@ limitations under the License. #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Attributes.h" -#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/IRMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" @@ -43,7 +44,7 @@ namespace mlir { namespace lmhlo { #define GEN_PASS_DEF_LHLOLEGALIZETOGPUPASS -#include "mlir-hlo/Dialect/lhlo/transforms/lmhlo_passes.h.inc" +#include "lhlo/transforms/lmhlo_passes.h.inc" namespace { @@ -125,9 +126,8 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern { // inline the body. auto output = *reduceOp.getOut().begin(); auto resType = MemRefType::get( - llvm::None, getElementTypeOrSelf(output.getType()), - makeStridedLinearLayoutMap(llvm::None, - MemRefType::getDynamicStrideOrOffset(), + std::nullopt, getElementTypeOrSelf(output.getType()), + makeStridedLinearLayoutMap(std::nullopt, ShapedType::kDynamic, rewriter.getContext())); OpFoldResult offset = launchOp.getThreadIds().x; auto oneAttr = rewriter.getI64IntegerAttr(1); @@ -152,7 +152,7 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern { // Now copy over the actual body of the reduction, leaving out the // terminator. - BlockAndValueMapping mapping; + IRMapping mapping; mapping.map(reduceOp.getBody().getArgument(0), accumulator); mapping.map(reduceOp.getBody().getArgument(1), rhs); mapping.map(reduceOp.getBody().getArgument(2), accumulator); diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo/transforms/lhlo_legalize_to_parallel_loops.cc b/tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/lhlo_legalize_to_parallel_loops/lhlo_legalize_to_parallel_loops.cc similarity index 98% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo/transforms/lhlo_legalize_to_parallel_loops.cc rename to tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/lhlo_legalize_to_parallel_loops/lhlo_legalize_to_parallel_loops.cc index bc6c354ce5a..e192c0491b5 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo/transforms/lhlo_legalize_to_parallel_loops.cc +++ b/tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/lhlo_legalize_to_parallel_loops/lhlo_legalize_to_parallel_loops.cc @@ -13,10 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include + +#include "lhlo/IR/lhlo_ops.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -30,7 +32,7 @@ namespace mlir { namespace lmhlo { #define GEN_PASS_DEF_LHLOLEGALIZETOPARALLELLOOPSPASS -#include "mlir-hlo/Dialect/lhlo/transforms/lmhlo_passes.h.inc" +#include "lhlo/transforms/lmhlo_passes.h.inc" namespace { @@ -56,7 +58,7 @@ Value applySingleResultLhloCode(Location loc, ValueRange operands, b->create(loc, operand.value(), argBufs[operand.index()]); } // Clone the ops from `lhlo_block`. - BlockAndValueMapping mapping; + IRMapping mapping; mapping.map(lhloBlock->getArguments(), argBufs); for (auto& nested : lhloBlock->without_terminator()) { auto* clone = b->clone(nested, mapping); @@ -84,7 +86,7 @@ void convertToReductionOperator(Location loc, scf::ReduceOp reduceOp, // to extract dimension at runtime. Value getStaticOrDynamicDim(mlir::Location loc, Value shapedValue, size_t dimIndex, int64_t dim, OpBuilder* b) { - return dim == ShapedType::kDynamicSize + return dim == ShapedType::kDynamic ? (Value)b->create(loc, shapedValue, dimIndex) : (Value)b->create(loc, dim); } @@ -208,7 +210,7 @@ class ReduceOpConverter : public OpConversionPattern { createReduceOpInNestedParallelLoops(reduceOp, &rewriter); convertToReductionOperator(reduceOp.getLoc(), scfReduceOp, &reduceOp.getBody().front(), &rewriter); - rewriter.replaceOp(reduceOp, llvm::None); + rewriter.replaceOp(reduceOp, std::nullopt); return success(); } @@ -384,7 +386,7 @@ class ReduceWindowOpConverter convertToReductionOperator(reduceWindowOp.getLoc(), reduceOp, &reduceWindowOp.getBody().front(), &rewriter); - rewriter.replaceOp(reduceWindowOp, llvm::None); + rewriter.replaceOp(reduceWindowOp, std::nullopt); return success(); } @@ -519,7 +521,7 @@ class SelectAndScatterOpConverter &sAndSOp.getScatter().front(), &rmwBuilder); rmwBuilder.create(loc, accResult); - rewriter.replaceOp(sAndSOp, llvm::None); + rewriter.replaceOp(sAndSOp, std::nullopt); return success(); } diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/transforms/lmhlo_passes.td b/tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/lmhlo_passes.td similarity index 74% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/transforms/lmhlo_passes.td rename to tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/lmhlo_passes.td index 14ee607c671..8cddbbf7dc9 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/transforms/lmhlo_passes.td +++ b/tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/lmhlo_passes.td @@ -15,18 +15,6 @@ limitations under the License. 
include "mlir/Pass/PassBase.td" -def LhloFuseLinalgPass : Pass<"lhlo-fuse-linalg", "func::FuncOp"> { - let summary = "Greedily fuse linalg ops obtained after LHLO lowering."; - let constructor = "createLhloFuseLinalgPass()"; - let options = [ - Option<"use_parallel_loops_", "use-parallel-loops", "bool", - /*default=*/"false", "Tiles GenericOp consumer to parallel loops before linalg fusion">, - ListOption<"tile_sizes_", "tile-sizes", "unsigned", - "Faster memory space number to promote fusion buffers to", - "llvm::cl::ZeroOrMore">, - ]; -} - def LhloLegalizeToAffinePass : Pass<"lhlo-legalize-to-affine", "func::FuncOp"> { let summary = "Legalize from LHLO dialect to affine dialect."; let constructor = "createLhloLegalizeToAffinePass()"; diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/transforms/map_hlo_to_lhlo_op.h b/tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/map_hlo_to_lhlo_op.h similarity index 91% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/transforms/map_hlo_to_lhlo_op.h rename to tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/map_hlo_to_lhlo_op.h index a4c0515bcf3..b3849af1c8e 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/transforms/map_hlo_to_lhlo_op.h +++ b/tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/map_hlo_to_lhlo_op.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MLIR_HLO_DIALECT_LHLO_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H -#define MLIR_HLO_DIALECT_LHLO_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H +#ifndef MLIR_HLO_LHLO_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H +#define MLIR_HLO_LHLO_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H #include -#include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "lhlo/IR/lhlo_ops.h" +#include "mhlo/IR/hlo_ops.h" namespace mlir { namespace mhlo { @@ -98,6 +98,7 @@ MAP_HLO_TO_LHLO(SineOp); MAP_HLO_TO_LHLO(SliceOp); MAP_HLO_TO_LHLO(SqrtOp); MAP_HLO_TO_LHLO(SubtractOp); +MAP_HLO_TO_LHLO(TanOp); MAP_HLO_TO_LHLO(TanhOp); MAP_HLO_TO_LHLO(TransposeOp); MAP_HLO_TO_LHLO(XorOp); @@ -110,4 +111,4 @@ MAP_HLO_TO_LHLO(RoundOp); } // namespace mhlo } // namespace mlir -#endif // MLIR_HLO_DIALECT_LHLO_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H +#endif // MLIR_HLO_LHLO_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/transforms/map_lhlo_to_hlo_op.h b/tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/map_lhlo_to_hlo_op.h similarity index 91% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/transforms/map_lhlo_to_hlo_op.h rename to tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/map_lhlo_to_hlo_op.h index 9666e4ffd6f..d1e7b520693 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/transforms/map_lhlo_to_hlo_op.h +++ b/tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/map_lhlo_to_hlo_op.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef MLIR_HLO_DIALECT_LHLO_TRANSFORMS_MAP_LHLO_TO_HLO_OP_H -#define MLIR_HLO_DIALECT_LHLO_TRANSFORMS_MAP_LHLO_TO_HLO_OP_H +#ifndef MLIR_HLO_LHLO_TRANSFORMS_MAP_LHLO_TO_HLO_OP_H +#define MLIR_HLO_LHLO_TRANSFORMS_MAP_LHLO_TO_HLO_OP_H #include -#include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "lhlo/IR/lhlo_ops.h" +#include "mhlo/IR/hlo_ops.h" namespace mlir { namespace lmhlo { @@ -93,6 +93,7 @@ MAP_LHLO_TO_HLO(SineOp); MAP_LHLO_TO_HLO(SliceOp); MAP_LHLO_TO_HLO(SqrtOp); MAP_LHLO_TO_HLO(SubtractOp); +MAP_LHLO_TO_HLO(TanOp); MAP_LHLO_TO_HLO(TanhOp); MAP_LHLO_TO_HLO(TransposeOp); MAP_LHLO_TO_HLO(XorOp); @@ -105,4 +106,4 @@ MAP_LHLO_TO_HLO(RoundOp); } // namespace lmhlo } // namespace mlir -#endif // MLIR_HLO_DIALECT_LHLO_TRANSFORMS_MAP_LHLO_TO_HLO_OP_H +#endif // MLIR_HLO_LHLO_TRANSFORMS_MAP_LHLO_TO_HLO_OP_H diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/transforms/map_lmhlo_to_scalar_op.h b/tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/map_lmhlo_to_scalar_op.h similarity index 88% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/transforms/map_lmhlo_to_scalar_op.h rename to tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/map_lmhlo_to_scalar_op.h index a9d450188a5..fb4a2e86672 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/map_lmhlo_to_scalar_op.h @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MLIR_HLO_DIALECT_LHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H -#define MLIR_HLO_DIALECT_LHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H +#ifndef MLIR_HLO_LHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H +#define MLIR_HLO_LHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H -#include "mlir-hlo/Dialect/lhlo/transforms/map_lhlo_to_hlo_op.h" -#include "mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h" +#include "lhlo/transforms/map_lhlo_to_hlo_op.h" +#include "mhlo/transforms/map_mhlo_to_scalar_op.h" namespace mlir { namespace lmhlo { @@ -61,4 +61,4 @@ struct LhloOpToStdScalarOp { } // namespace lmhlo } // namespace mlir -#endif // MLIR_HLO_DIALECT_LHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H +#endif // MLIR_HLO_LHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/transforms/passes.h b/tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/passes.h similarity index 68% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/transforms/passes.h rename to tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/passes.h index 580788168e1..8225dfa238c 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/transforms/passes.h +++ b/tensorflow/compiler/xla/mlir_hlo/lhlo/transforms/passes.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef MLIR_HLO_DIALECT_LHLO_TRANSFORMS_PASSES_H -#define MLIR_HLO_DIALECT_LHLO_TRANSFORMS_PASSES_H +#ifndef MLIR_HLO_LHLO_TRANSFORMS_PASSES_H +#define MLIR_HLO_LHLO_TRANSFORMS_PASSES_H #include @@ -38,7 +38,7 @@ class FusionOp; namespace lmhlo { #define GEN_PASS_DECL -#include "mlir-hlo/Dialect/lhlo/transforms/lmhlo_passes.h.inc" +#include "lhlo/transforms/lmhlo_passes.h.inc" // Lowers from LHLO dialect to Affine dialect. std::unique_ptr> createLhloLegalizeToAffinePass(); @@ -46,18 +46,6 @@ std::unique_ptr> createLhloLegalizeToAffinePass(); // Lowers from LHLO dialect to GPU dialect. std::unique_ptr> createLegalizeToGpuPass(); -// Fuses linalg ops obtained after LHLO lowering. To enable fusion, -// operations are first tiled. -// -// When 'use_parallel_loops' is set, the tiling will use scf.parallel -// operations. Otherwise, scf.for operations are used. -// -// 'tile_sizes' provides the tile sizes to use for tiling. If the linalg -// operation has more dimensions than tile sizes provided, 1 is used as -// default. -std::unique_ptr> createLhloFuseLinalgPass( - bool useParallelLoops = false, llvm::ArrayRef tileSizes = {}); - // Lowers from LHLO dialect to parallel loops. std::unique_ptr> createLegalizeLhloToParallelLoopsPass(); @@ -69,9 +57,9 @@ std::unique_ptr> createLegalizeToTensorOpPass(); std::unique_ptr> createInputInlineFusionPass(); #define GEN_PASS_REGISTRATION -#include "mlir-hlo/Dialect/lhlo/transforms/lmhlo_passes.h.inc" +#include "lhlo/transforms/lmhlo_passes.h.inc" } // namespace lmhlo } // namespace mlir -#endif // MLIR_HLO_DIALECT_LHLO_TRANSFORMS_PASSES_H +#endif // MLIR_HLO_LHLO_TRANSFORMS_PASSES_H diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/utils/lhlo_utils.h b/tensorflow/compiler/xla/mlir_hlo/lhlo/utils/lhlo_utils.h similarity index 81% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/utils/lhlo_utils.h rename to tensorflow/compiler/xla/mlir_hlo/lhlo/utils/lhlo_utils.h index aac5863530b..007e9ddc7ea 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/utils/lhlo_utils.h +++ b/tensorflow/compiler/xla/mlir_hlo/lhlo/utils/lhlo_utils.h @@ -13,13 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MLIR_HLO_UTILS_LHLO_UTILS_H -#define MLIR_HLO_UTILS_LHLO_UTILS_H +#ifndef MLIR_HLO_LHLO_UTILS_LHLO_UTILS_H +#define MLIR_HLO_LHLO_UTILS_LHLO_UTILS_H -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h" +#include + +#include "mhlo/IR/hlo_ops_common.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Types.h" +#include "stablehlo/dialect/TypeInference.h" namespace mlir { namespace lmhlo { @@ -47,7 +50,10 @@ auto getInputs(OpT op, char) -> decltype(op.operands(), ValueRange{}) { template static LogicalResult verifyAllReduce(OpT op) { - if (failed(mlir::hlo::verifyReplicaGroups(op, /*is_uniform_sized=*/false))) + if (failed(hlo::verifyReplicaGroups(op.getLoc(), op.getReplicaGroups(), + /*allGroupsMustHaveSameSize=*/false, + op.getUseGlobalDeviceIds(), + /*expectedGroupSize=*/std::nullopt))) return failure(); // AllReduce has variadic operands and results that have the same size. 
@@ -69,4 +75,4 @@ static LogicalResult verifyAllReduce(OpT op) { } // namespace lmhlo } // namespace mlir -#endif // MLIR_HLO_UTILS_LHLO_UTILS_H +#endif // MLIR_HLO_LHLO_UTILS_LHLO_UTILS_H diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/CMakeLists.txt similarity index 100% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/CMakeLists.txt rename to tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/CMakeLists.txt diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/CMakeLists.txt similarity index 83% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/CMakeLists.txt rename to tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/CMakeLists.txt index 94b73785217..81912905556 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/CMakeLists.txt +++ b/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/CMakeLists.txt @@ -28,3 +28,19 @@ mlir_tablegen(lhlo_gpu_ops_dialect.cc.inc -gen-dialect-defs) add_public_tablegen_target(MLIRlhlo_gpu_opsIncGen) add_dependencies(mlir-headers MLIRlhlo_gpu_opsIncGen) + +include_directories(BEFORE + ${CMAKE_CURRENT_BINARY_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}) + +add_mlir_dialect_library(LmhloGPUDialect + lhlo_gpu_ops.cc + + DEPENDS + MLIRlhlo_gpu_opsIncGen + + LINK_LIBS PUBLIC + MhloDialect + MLIRIR + HloOpsCommon +) diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.cc b/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.cc similarity index 79% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.cc rename to tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.cc index 8aee61e8ee0..1977edd2b2e 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.cc +++ b/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.cc @@ -15,12 +15,13 @@ limitations under the License. // This file defines the operations used in the LMHLO GPU dialect. -#include "mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h" +#include "lhlo_gpu/IR/lhlo_gpu_ops.h" #include #include #include +#include "lhlo/utils/lhlo_utils.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" @@ -29,9 +30,8 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/FormatVariadic.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h" -#include "mlir-hlo/utils/lhlo_utils.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/IR/hlo_ops_common.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" @@ -73,10 +73,10 @@ static FailureOr> parseI64Array(AsmParser &parser) { } // namespace mlir // Include order below matters. 
-#include "mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops_dialect.cc.inc" -#include "mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops_enums.cc.inc" +#include "lhlo_gpu/IR/lhlo_gpu_ops_dialect.cc.inc" +#include "lhlo_gpu/IR/lhlo_gpu_ops_enums.cc.inc" #define GET_ATTRDEF_CLASSES -#include "mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops_attrdefs.cc.inc" +#include "lhlo_gpu/IR/lhlo_gpu_ops_attrdefs.cc.inc" namespace mlir { namespace lmhlo_gpu { @@ -87,11 +87,11 @@ void LmhloGpuDialect::initialize() { getContext()->loadDialect(); addOperations< #define GET_OP_LIST -#include "mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.cc.inc" +#include "lhlo_gpu/IR/lhlo_gpu_ops.cc.inc" >(); addAttributes< #define GET_ATTRDEF_LIST -#include "mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops_attrdefs.cc.inc" +#include "lhlo_gpu/IR/lhlo_gpu_ops_attrdefs.cc.inc" >(); } @@ -109,8 +109,18 @@ mlir::LogicalResult AllReduceStartOp::verify() { return lmhlo::verifyAllReduce(op); } +//===----------------------------------------------------------------------===// +// CollectivePermuteStartOp +//===----------------------------------------------------------------------===// + +mlir::LogicalResult CollectivePermuteStartOp::verify() { + CollectivePermuteStartOp op = *this; + return mlir::hlo::verifyCollectivePermuteSourceTargetPairs( + op, op.getSourceTargetPairs()); +} + } // namespace lmhlo_gpu } // namespace mlir #define GET_OP_CLASSES -#include "mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.cc.inc" +#include "lhlo_gpu/IR/lhlo_gpu_ops.cc.inc" diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h b/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h similarity index 72% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h rename to tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h index dc1a0caae76..75d5f338693 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h +++ b/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.h @@ -15,11 +15,11 @@ limitations under the License. // This file defines the operations used in the LHLO dialect. -#ifndef MLIR_HLO_DIALECT_LHLO_GPU_IR_LHLO_GPU_OPS_H -#define MLIR_HLO_DIALECT_LHLO_GPU_IR_LHLO_GPU_OPS_H +#ifndef MLIR_HLO_LHLO_GPU_IR_LHLO_GPU_OPS_H +#define MLIR_HLO_LHLO_GPU_IR_LHLO_GPU_OPS_H #include "llvm/ADT/StringRef.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mhlo/IR/hlo_ops.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" @@ -35,11 +35,11 @@ class OpBuilder; } // namespace mlir // Include order below matters. 
-#include "mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops_dialect.h.inc" -#include "mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops_enums.h.inc" +#include "lhlo_gpu/IR/lhlo_gpu_ops_dialect.h.inc" +#include "lhlo_gpu/IR/lhlo_gpu_ops_enums.h.inc" #define GET_ATTRDEF_CLASSES -#include "mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops_attrdefs.h.inc" +#include "lhlo_gpu/IR/lhlo_gpu_ops_attrdefs.h.inc" #define GET_OP_CLASSES -#include "mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h.inc" +#include "lhlo_gpu/IR/lhlo_gpu_ops.h.inc" -#endif // MLIR_HLO_DIALECT_LHLO_GPU_IR_LHLO_GPU_OPS_H +#endif // MLIR_HLO_LHLO_GPU_IR_LHLO_GPU_OPS_H diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.td b/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td similarity index 75% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.td rename to tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td index 426e261fd8f..e81d405b981 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.td +++ b/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops.td @@ -21,9 +21,9 @@ limitations under the License. include "mlir/IR/OpBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" -include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops_base.td" -include "mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops_base.td" -include "mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops_enums.td" +include "lhlo/IR/lhlo_ops_base.td" +include "lhlo_gpu/IR/lhlo_gpu_ops_base.td" +include "lhlo_gpu/IR/lhlo_gpu_ops_enums.td" include "stablehlo/dialect/Base.td" class LHLOGPU_Op traits = []> : @@ -42,7 +42,7 @@ def I32Buffer : MemRefOf<[I32]>; class GpuConvolutionAttributes { dag attributes = !con( - ConvolutionAttributes.attributes, + MHLO_ConvolutionAttributes.attributes, (ins F64Attr:$result_scale), extraAttribs, (ins ConvolutionBackendConfigAttr:$backend_config)); @@ -131,23 +131,44 @@ def LHLOGPU_GEMMOp : LHLOGPU_Op<"gemm"> { Arg:$a, Arg:$b, Arg:$c, - DotDimensionNumbers:$dot_dimension_numbers, - HLO_PrecisionConfigAttr:$precision_config, + MHLO_DotDimensionNumbers:$dot_dimension_numbers, + MHLO_PrecisionConfigAttr:$precision_config, F64Attr:$alpha_real, F64Attr:$alpha_imag, F64Attr:$beta, OptionalAttr:$algorithm); } -def LHLOGPU_CublasLtMatmulOp : LHLOGPU_Op<"cublas.lt.matmul"> { +def LHLOGPU_CublasLtMatmulOp : LHLOGPU_Op<"cublas.lt.matmul", [AttrSizedOperandSegments]> { let arguments = (ins Arg:$a, Arg:$b, Arg:$c, Arg:$d, Arg, "", [MemRead]>:$bias, - DotDimensionNumbers:$dot_dimension_numbers, - HLO_PrecisionConfigAttr:$precision_config, + Arg, "", [MemRead, MemWrite]>:$aux, + MHLO_DotDimensionNumbers:$dot_dimension_numbers, + MHLO_PrecisionConfigAttr:$precision_config, + F64Attr:$alpha_real, + F64Attr:$alpha_imag, + F64Attr:$beta, + CublasLtMatmulEpilogueAttr:$epilogue, + I64Attr:$algorithm); +} + +def LHLOGPU_CublasLtMatmulF8Op : LHLOGPU_Op<"cublas.lt.matmul.f8"> { + let arguments = (ins + Arg:$a, + Arg:$b, + Arg:$c, + Arg:$a_scale, + Arg:$b_scale, + Arg:$c_scale, + Arg:$d_scale, + Arg:$d, + Arg, "", [MemWrite]>:$d_amax, + MHLO_DotDimensionNumbers:$dot_dimension_numbers, + MHLO_PrecisionConfigAttr:$precision_config, F64Attr:$alpha_real, F64Attr:$alpha_imag, F64Attr:$beta, @@ -164,8 +185,14 @@ def LHLOGPU_CholeskyOp : LHLOGPU_Op<"cholesky"> { BoolAttr:$is_lower); } +class LHLOGPU_AsyncCollectiveCommunicationOp traits = []> : + LHLOGPU_Op { + let results = (outs MHLO_Token:$token); + let hasVerifier = 1; +} + def LHLOGPU_AllReduceStartOp : - 
LHLOGPU_Op<"all_reduce_start", [SameOperandsElementType, SameVariadicOperandSize]> { + LHLOGPU_AsyncCollectiveCommunicationOp<"all_reduce_start", [SameOperandsElementType]> { let summary = "AllReduceStart operator"; let description = [{ Performs an asynchronous custom reduction across replicas. @@ -175,22 +202,31 @@ def LHLOGPU_AllReduceStartOp : Arg, "", [MemWrite]>:$outputs, I64ElementsAttr:$replica_groups, DefaultValuedOptionalAttr:$constrain_layout, - OptionalAttr:$channel_id, + OptionalAttr:$channel_id, DefaultValuedOptionalAttr:$use_global_device_ids ); - let results = (outs HLO_Token:$token); let regions = (region SizedRegion<1>:$computation); - let hasVerifier = 1; } -def LHLOGPU_AllReduceDoneOp: - LHLOGPU_Op<"all_reduce_done", [SameVariadicOperandSize]> { +def LHLOGPU_AllReduceDoneOp: LHLOGPU_Op<"all_reduce_done"> { let summary = "AllReduceDone operator"; + let arguments = (ins MHLO_Token:$token); +} + +def LHLOGPU_CollectivePermuteStartOp : + LHLOGPU_AsyncCollectiveCommunicationOp<"collective_permute_start"> { + let summary = "CollectivePermuteStart operator"; let arguments = (ins - HLO_Token:$token, - Arg, "", [MemRead]>:$inputs, - Arg, "", [MemWrite]>:$outputs + Arg:$operand, + Arg:$output, + I64ElementsAttr:$source_target_pairs, + OptionalAttr:$channel_id ); } +def LHLOGPU_CollectivePermuteDoneOp: LHLOGPU_Op<"collective_permute_done"> { + let summary = "CollectivePermuteDone operator"; + let arguments = (ins MHLO_Token:$token); +} + #endif // LHLO_GPU_OPS diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops_base.td b/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops_base.td similarity index 94% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops_base.td rename to tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops_base.td index e54c4378a17..8019a504912 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops_base.td +++ b/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops_base.td @@ -25,7 +25,7 @@ def LmhloGpuDialect : Dialect { let cppNamespace = "::mlir::lmhlo_gpu"; let useDefaultAttributePrinterParser = 1; - let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; + let useFoldAPI = kEmitFoldAdaptorFolder; } #endif // LHLO_GPU_OPS_BASE diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops_enums.td b/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops_enums.td similarity index 86% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops_enums.td rename to tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops_enums.td index 9f8f865efa8..6dda1577c05 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops_enums.td +++ b/tensorflow/compiler/xla/mlir_hlo/lhlo_gpu/IR/lhlo_gpu_ops_enums.td @@ -20,7 +20,7 @@ include "mlir/IR/OpBase.td" include "mlir/IR/EnumAttr.td" include "mlir/IR/AttrTypeBase.td" -include "mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops_base.td" +include "lhlo_gpu/IR/lhlo_gpu_ops_base.td" def ActivationModeNone : I32EnumAttrCase<"None", 0>; def ActivationModeSigmoid : I32EnumAttrCase<"Sigmoid", 1>; @@ -29,12 +29,13 @@ def ActivationModeRelu : I32EnumAttrCase<"Relu", 3>; def ActivationModeRelu6 : I32EnumAttrCase<"Relu6", 4>; def ActivationModeReluX : I32EnumAttrCase<"ReluX", 5>; def ActivationModeBandPass : I32EnumAttrCase<"BandPass", 6>; +def ActivationModeElu: I32EnumAttrCase<"Elu", 7>; def 
Activation: I32EnumAttr<"Activation", "Activation applied with fused convolution", [ActivationModeNone, ActivationModeSigmoid, ActivationModeTanh, ActivationModeRelu, ActivationModeRelu6, ActivationModeReluX, - ActivationModeBandPass]> { + ActivationModeBandPass, ActivationModeElu]> { let genSpecializedAttr = 0; let cppNamespace = "::mlir::lmhlo_gpu"; } @@ -86,12 +87,18 @@ def CublasLtMatmulEpilogueDefault : I32EnumAttrCase<"Default", 0>; def CublasLtMatmulEpilogueBias : I32EnumAttrCase<"Bias", 1>; def CublasLtMatmulEpilogueRelu : I32EnumAttrCase<"Relu", 2>; def CublasLtMatmulEpilogueBiasRelu : I32EnumAttrCase<"BiasRelu", 3>; +def CublasLtMatmulEpilogueGelu : I32EnumAttrCase<"Gelu", 4>; +def CublasLtMatmulEpilogueBiasGelu : I32EnumAttrCase<"BiasGelu", 5>; +def CublasLtMatmulEpilogueGeluAux : I32EnumAttrCase<"GeluAux", 6>; +def CublasLtMatmulEpilogueBiasGeluAux : I32EnumAttrCase<"BiasGeluAux", 7>; def CublasLtMatmulEpilogue: I32EnumAttr<"CublasLtMatmulEpilogue", "Epilogue for cublasLt matmul", [CublasLtMatmulEpilogueDefault, CublasLtMatmulEpilogueBias, - CublasLtMatmulEpilogueRelu, CublasLtMatmulEpilogueBiasRelu]> { + CublasLtMatmulEpilogueRelu, CublasLtMatmulEpilogueBiasRelu, + CublasLtMatmulEpilogueGelu, CublasLtMatmulEpilogueBiasGelu, + CublasLtMatmulEpilogueGeluAux, CublasLtMatmulEpilogueBiasGeluAux]> { let genSpecializedAttr = 0; let cppNamespace = "::mlir::lmhlo_gpu"; } diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/lib/CMakeLists.txt deleted file mode 100644 index 4e35b973cc3..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/lib/CMakeLists.txt +++ /dev/null @@ -1,20 +0,0 @@ -# -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -add_subdirectory(Analysis) -add_subdirectory(CAPI) -add_subdirectory(Dialect) -add_subdirectory(Transforms) -add_subdirectory(utils) diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/CMakeLists.txt deleted file mode 100644 index 67607173fe4..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/CMakeLists.txt +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -add_subdirectory(gml_st) -add_subdirectory(lhlo) -add_subdirectory(lhlo_gpu) -add_subdirectory(mhlo) -add_subdirectory(thlo) diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/IR/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/IR/CMakeLists.txt deleted file mode 100644 index bd318090516..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/IR/CMakeLists.txt +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright 2022 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -include_directories(BEFORE - ${CMAKE_CURRENT_BINARY_DIR} - ${CMAKE_CURRENT_SOURCE_DIR}) - -add_mlir_dialect_library(GmlStDialect - gml_st_ops.cc - - DEPENDS - MLIRgml_st_opsIncGen - - LINK_LIBS PUBLIC - MLIRArithUtils - MLIRControlFlowInterfaces - MLIRIR - MLIRInferTypeOpInterface - MLIRLoopLikeInterface - MLIRMemRefDialect - MLIRSideEffectInterfaces - MLIRSupport - MLIRTensorDialect - MLIRViewLikeInterface -) diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/IR/gml_st_ops.cc b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/IR/gml_st_ops.cc deleted file mode 100644 index 96e0eca39fb..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/IR/gml_st_ops.cc +++ /dev/null @@ -1,1959 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h" - -#include -#include -#include -#include -#include - -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/Sequence.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/Casting.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Arith/Utils/Utils.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tensor/Utils/Utils.h" -#include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/DialectImplementation.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/Interfaces/ViewLikeInterface.h" - -namespace mlir { -namespace { - -void printShapeTypeDimensionsList(AsmPrinter &printer, - ArrayRef integers) { - llvm::interleave( - integers, printer, - [&](int64_t val) { - if (val == ShapedType::kDynamicSize) - printer << '?'; - else - printer << val; - }, - "x"); -} - -ParseResult parseShapeTypeDimensionsList( - AsmParser &parser, FailureOr> &dims) { - SmallVector vals; - if (failed(parser.parseDimensionList(vals, /*allowDynamic=*/true, - /*withTrailingX=*/false))) { - return failure(); - } - dims = vals; - return success(); -} - -ParseResult parseAssignmentListWithTypes( - OpAsmParser &parser, SmallVectorImpl &lhs, - SmallVectorImpl &rhs, - SmallVectorImpl &types) { - auto parseElt = [&]() -> ParseResult { - if (parser.parseOperand(lhs.emplace_back(), /*allowResultNumber=*/false) || - parser.parseEqual() || parser.parseOperand(rhs.emplace_back()) || - parser.parseColon() || parser.parseType(types.emplace_back())) { - return failure(); - } - return success(); - }; - return parser.parseCommaSeparatedList(AsmParser::Delimiter::Paren, parseElt); -} - -} // namespace -} // namespace mlir - -// Generated dialect definitions. -#include "mlir-hlo/Dialect/gml_st/IR/gml_st_dialect.cc.inc" - -// Generated type classes. -#define GET_TYPEDEF_CLASSES -#include "mlir-hlo/Dialect/gml_st/IR/gml_st_types.cc.inc" - -// Generated attribute classes. -#define GET_ATTRDEF_CLASSES -#include "mlir-hlo/Dialect/gml_st/IR/gml_st_attrs.cc.inc" - -namespace mlir { -namespace gml_st { - -//===----------------------------------------------------------------------===// -// GmlStDialect -//===----------------------------------------------------------------------===// - -void GmlStDialect::initialize() { - addOperations< -#define GET_OP_LIST -#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.cc.inc" - >(); - addTypes< -#define GET_TYPEDEF_LIST -#include "mlir-hlo/Dialect/gml_st/IR/gml_st_types.cc.inc" - >(); - addAttributes< -#define GET_ATTRDEF_LIST -#include "mlir-hlo/Dialect/gml_st/IR/gml_st_attrs.cc.inc" - >(); -} - -// Helper function to ensure index types for some attrbutes when folding. 
-static OpFoldResult ensureIndexTypeForAttribute(OpFoldResult foldResult) { - if (foldResult.is()) { - auto attr = foldResult.get().dyn_cast(); - if (!attr.getType().isa()) { - Builder b(attr.getContext()); - return b.getIndexAttr(attr.getInt()); - } - } - return foldResult; -} - -Operation *GmlStDialect::materializeConstant(OpBuilder &builder, Attribute attr, - Type type, Location loc) { - if (type.isa()) { - int64_t intValue = attr.cast().getInt(); - return builder.create(loc, intValue); - } - return {}; -} - -//===----------------------------------------------------------------------===// -// MaterializeOp -//===----------------------------------------------------------------------===// - -static Type inferReturnType(ShapedType sourceType, Type setType) { - if (auto tileType = setType.dyn_cast()) { - return sourceType.clone(tileType.getShape(), sourceType.getElementType()); - } - assert(false && "could not infer result type"); - return {}; -} - -void MaterializeOp::build(OpBuilder &builder, OperationState &result, - Value source, Value set) { - auto sourceType = source.getType().cast(); - auto resultType = inferReturnType(sourceType, set.getType()); - build(builder, result, resultType, source, set); -} - -LogicalResult verifyCompatibleExtractedSubset(Operation *op, - ShapedType shapedType, - Type extractedType, - Type setType) { - auto sourceRank = shapedType.getRank(); - auto elementType = shapedType.getElementType(); - - // If the result is a scalar, check that the tile had a single element. - if (!extractedType.isa()) { - auto tileType = setType.cast(); - if (extractedType != elementType) { - return op->emitOpError("expected the result type ") - << extractedType << " to match source element type " - << elementType; - } - if (tileType.hasStaticShape() && tileType.getNumElements() == 1) - return success(); - - return op->emitOpError("expected tile type ") - << tileType << " to have a single element shape"; - } - - // If the result is a shaped type, compare with the inferred type. - auto extractedShapedType = extractedType.cast(); - auto tileType = setType.cast(); - int64_t tileRank = tileType.getRank(); - if (tileRank != sourceRank) { - return op->emitOpError("expected source rank = ") - << sourceRank << " to match tile rank = " << tileRank; - } - - auto inferredType = - shapedType.clone(tileType.getShape(), shapedType.getElementType()); - if (extractedShapedType != inferredType) { - return op->emitOpError("expected result type = ") - << extractedShapedType - << " to match the inferred type = " << inferredType; - } - - return success(); -} - -LogicalResult MaterializeOp::verify() { - // TODO(pifon): Add verification that was removed from TileOp::verify. - return verifyCompatibleExtractedSubset(getOperation(), getSource().getType(), - getType(), getSet().getType()); -} - -namespace { -/// Cleans up UnrealizedConversionCast sets from materialize ops. -struct FoldMaterializeUnrealizedConversionCast - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(MaterializeOp op, - PatternRewriter &rewriter) const override { - auto cast = op.getSet().getDefiningOp(); - if (!cast) return failure(); - - auto set = cast.getOperand(0); - auto newOp = rewriter.create( - op.getLoc(), inferReturnType(op.getSource().getType(), set.getType()), - op.getSource(), set); - rewriter.replaceOpWithNewOp(op, op.getType(), newOp); - return success(); - } -}; - -/// Folds tensor::CastOp sources into MaterializeOp. 
-struct FoldSrcCastIntoMaterialize : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(MaterializeOp op, - PatternRewriter &rewriter) const override { - auto cast = op.getSource().getDefiningOp(); - if (!cast) return failure(); - - auto src = cast.getSource(); - auto set = op.getSet(); - rewriter.replaceOpWithNewOp( - op, inferReturnType(src.getType(), set.getType()), src, set); - return success(); - } -}; -} // namespace - -void MaterializeOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results - .add( - context); -} - -//===----------------------------------------------------------------------===// -// LoopOp -//===----------------------------------------------------------------------===// - -void LoopOp::build(OpBuilder &builder, OperationState &result, - ValueRange lowerBounds, ValueRange upperBounds, - ValueRange steps, ValueRange inputs, ValueRange outputs, - ArrayAttr iteratorTypes, - function_ref - bodyBuilderFn) { - build(builder, result, lowerBounds, upperBounds, steps, inputs, outputs, - iteratorTypes, llvm::None, bodyBuilderFn); -} - -void LoopOp::build(OpBuilder &builder, OperationState &result, - ValueRange lowerBounds, ValueRange upperBounds, - ValueRange steps, ValueRange inputs, ValueRange outputs, - ArrayAttr iteratorTypes, - Optional distributionTypes, - function_ref - bodyBuilderFn) { - result.addOperands(lowerBounds); - result.addOperands(upperBounds); - result.addOperands(steps); - result.addOperands(inputs); - result.addOperands(outputs); - result.addAttribute( - LoopOp::getOperandSegmentSizeAttr(), - builder.getDenseI32ArrayAttr({static_cast(lowerBounds.size()), - static_cast(upperBounds.size()), - static_cast(steps.size()), - static_cast(inputs.size()), - static_cast(outputs.size())})); - result.addAttribute(getIteratorTypesAttrStrName(), iteratorTypes); - - if (distributionTypes.has_value()) - result.addAttribute(getDistributionTypesAttrStrName(), - distributionTypes.value()); - - // Add output types for `RankedTensorType` output arguments. 
- for (Value output : outputs) { - Type outputType = output.getType(); - if (outputType.isa()) result.addTypes(outputType); - } - - OpBuilder::InsertionGuard guard(builder); - unsigned numIVs = steps.size(); - SmallVector argTypes(numIVs, builder.getIndexType()); - SmallVector argLocs(numIVs, result.location); - for (Value input : inputs) { - argTypes.push_back(input.getType()); - argLocs.push_back(input.getLoc()); - } - for (Value output : outputs) { - argTypes.push_back(output.getType()); - argLocs.push_back(output.getLoc()); - } - Region *bodyRegion = result.addRegion(); - Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs); - - if (bodyBuilderFn) { - builder.setInsertionPointToStart(bodyBlock); - bodyBuilderFn(builder, result.location, - bodyBlock->getArguments().take_front(numIVs), - bodyBlock->getArguments().slice(numIVs, inputs.size()), - bodyBlock->getArguments().take_back(outputs.size())); - LoopOp::ensureTerminator(*bodyRegion, builder, result.location); - } -} - -void LoopOp::print(OpAsmPrinter &p) { - p << " (" << getInductionVars() << ") = (" << getLowerBound() << ") to (" - << getUpperBound() << ") step (" << getStep() << ")"; - - if (!getInputs().empty()) { - p << " ins ("; - llvm::interleaveComma(llvm::zip(getRegionInputArgs(), getInputs()), p, - [&](auto it) { - p << std::get<0>(it) << " = " << std::get<1>(it) - << ": " << std::get<1>(it).getType(); - }); - p << ")"; - } - if (!getOutputs().empty()) { - p << " outs ("; - llvm::interleaveComma(llvm::zip(getRegionOutputArgs(), getOutputs()), p, - [&](auto it) { - p << std::get<0>(it) << " = " << std::get<1>(it) - << ": " << std::get<1>(it).getType(); - }); - p << ")"; - } - - if (llvm::any_of(getIteratorTypes(), [](Attribute attr) { - return attr.cast().getValue() != - utils::IteratorType::parallel; - })) - p << " iterators" << getIteratorTypes(); - - if (getDistributionTypes().has_value()) - p << " distribution" << getDistributionTypes().value(); - - p << ' '; - p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); - p.printOptionalAttrDict( - getOperation()->getAttrs(), - /*elidedAttrs=*/{LoopOp::getOperandSegmentSizeAttr(), - LoopOp::getIteratorTypesAttrName(), - LoopOp::getDistributionTypesAttrName()}); -} - -ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) { - auto &builder = parser.getBuilder(); - // Parse an opening `(` followed by induction variables followed by `)` - SmallVector ivs; - if (parser.parseOperandList(ivs, OpAsmParser::Delimiter::Paren, - /*allowResultNumber=*/false)) - return failure(); - - // Parse loop bounds. - SmallVector lower; - if (parser.parseEqual() || - parser.parseOperandList(lower, ivs.size(), - OpAsmParser::Delimiter::Paren) || - parser.resolveOperands(lower, builder.getIndexType(), result.operands)) - return failure(); - - SmallVector upper; - if (parser.parseKeyword("to") || - parser.parseOperandList(upper, ivs.size(), - OpAsmParser::Delimiter::Paren) || - parser.resolveOperands(upper, builder.getIndexType(), result.operands)) - return failure(); - - // Parse step values. - SmallVector steps; - if (parser.parseKeyword("step") || - parser.parseOperandList(steps, ivs.size(), - OpAsmParser::Delimiter::Paren) || - parser.resolveOperands(steps, builder.getIndexType(), result.operands)) - return failure(); - - // Parse input tensors. 
- SmallVector inputs, inputRegionArgs; - SmallVector inputTypes; - if (succeeded(parser.parseOptionalKeyword("ins"))) { - SMLoc inputsOperandsLoc = parser.getCurrentLocation(); - - if (parseAssignmentListWithTypes(parser, inputRegionArgs, inputs, - inputTypes)) - return failure(); - - if (parser.resolveOperands(inputs, inputTypes, inputsOperandsLoc, - result.operands)) - return failure(); - } - - // Parse output tensors. - SmallVector outputs, outputRegionArgs; - SmallVector outputTypes; - if (succeeded(parser.parseOptionalKeyword("outs"))) { - SMLoc outputsOperandsLoc = parser.getCurrentLocation(); - - if (parseAssignmentListWithTypes(parser, outputRegionArgs, outputs, - outputTypes)) - return failure(); - - if (parser.resolveOperands(outputs, outputTypes, outputsOperandsLoc, - result.operands)) - return failure(); - for (Type outputType : outputTypes) - if (outputType.isa()) result.addTypes(outputType); - } - - Attribute iterTypes; - if (succeeded(parser.parseOptionalKeyword("iterators"))) { - if (parser.parseAttribute(iterTypes)) return failure(); - } else { - // Set all loop iterator types to "parallel" if they are not printed in IR. - auto parallelIter = - builder.getAttr(utils::IteratorType::parallel); - iterTypes = builder.getArrayAttr( - SmallVector(ivs.size(), parallelIter)); - } - - result.addAttribute(LoopOp::getIteratorTypesAttrStrName(), iterTypes); - - if (succeeded(parser.parseOptionalKeyword("distribution"))) { - Attribute distributionTypes; - if (failed(parser.parseAttribute(distributionTypes))) return failure(); - result.addAttribute(LoopOp::getDistributionTypesAttrStrName(), - distributionTypes); - } - - result.addAttribute( - LoopOp::getOperandSegmentSizeAttr(), - builder.getDenseI32ArrayAttr({static_cast(lower.size()), - static_cast(upper.size()), - static_cast(steps.size()), - static_cast(inputs.size()), - static_cast(outputs.size())})); - - // Parse the body. - Region *body = result.addRegion(); - - SmallVector regionTypes(ivs.size(), builder.getIndexType()); - regionTypes.append(inputTypes); - regionTypes.append(outputTypes); - - SmallVector regionOperands(ivs); - regionOperands.append(inputRegionArgs); - regionOperands.append(outputRegionArgs); - - SmallVector regionArgs; - - for (auto argAndType : llvm::zip(regionOperands, regionTypes)) { - auto &arg = regionArgs.emplace_back(); - arg.ssaName = std::get<0>(argAndType); - arg.type = std::get<1>(argAndType); - } - - if (parser.parseRegion(*body, regionArgs)) return failure(); - - // Parse optional attributes. - if (parser.parseOptionalAttrDict(result.attributes)) return failure(); - - return success(); -} - -Region &LoopOp::getLoopBody() { return getRegion(); } - -LogicalResult LoopOp::verify() { - // Check if iterator types are provided for every loop dimension. - if (getIteratorTypes().size() != getNumLoops()) - return emitOpError("expected iterator types array attribute size = ") - << getIteratorTypes().size() - << " to match the number of loops = " << getNumLoops(); - - // Check if types of input arguments match region args types. 
- for (auto &item : - llvm::enumerate(llvm::zip(getInputs(), getRegionInputArgs()))) { - Value input, inputRegionArg; - unsigned index = item.index(); - std::tie(input, inputRegionArg) = item.value(); - if (input.getType() != inputRegionArg.getType()) - return emitOpError("expected input arg ") - << index << " with type = " << input.getType() - << " to match region arg " << index + getNumLoops() - << " type = " << inputRegionArg.getType(); - } - - // Check if types of output arguments match region args types. - for (auto &item : - llvm::enumerate(llvm::zip(getOutputs(), getRegionOutputArgs()))) { - Value output, outputRegionArg; - unsigned index = item.index(); - std::tie(output, outputRegionArg) = item.value(); - if (output.getType() != outputRegionArg.getType()) - return emitOpError("expected output arg ") - << index << " with type = " << output.getType() - << " to match region arg " - << index + getNumLoops() + getInputs().size() - << " type = " << outputRegionArg.getType(); - } - return success(); -} - -//===----------------------------------------------------------------------===// -// LoopLikeOp -//===----------------------------------------------------------------------===// - -namespace { - -ParseResult parseForOpOutputArgs( - OpAsmParser &parser, OperationState &result, - SmallVectorImpl ®ionOperands, - SmallVectorImpl ®ionTypes, int32_t *outputCount) { - SmallVector outputs, outputRegionArgs; - SmallVector outputTypes; - - auto parseElt = [&]() -> ParseResult { - if (parser.parseOperand(outputRegionArgs.emplace_back(), - /*allowResultNumber=*/false) || - parser.parseEqual()) { - return failure(); - } - if (parser.parseOperand(outputs.emplace_back()) || parser.parseColon() || - parser.parseType(outputTypes.emplace_back())) { - return failure(); - } - *outputCount = outputs.size(); - return success(); - }; - if (succeeded(parser.parseOptionalKeyword("outs"))) { - SMLoc loc = parser.getCurrentLocation(); - - if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Paren, parseElt)) - return failure(); - if (parser.resolveOperands(outputs, outputTypes, loc, result.operands)) - return failure(); - } - regionOperands.append(outputRegionArgs); - regionTypes.append(outputTypes); - return success(); -} - -} // namespace - -template -ParseResult parseLoopLikeOp(OpAsmParser &parser, OperationState &result) { - auto &builder = parser.getBuilder(); - // Parse an opening `(` followed by induction variables followed by `)` - SmallVector ivs; - if (parser.parseOperandList(ivs, OpAsmParser::Delimiter::Paren, - /*allowResultNumber=*/false)) - return failure(); - - // Parse loop bounds. - SmallVector lower; - if (parser.parseEqual() || - parser.parseOperandList(lower, ivs.size(), - OpAsmParser::Delimiter::Paren) || - parser.resolveOperands(lower, builder.getIndexType(), result.operands)) - return failure(); - - SmallVector upper; - if (parser.parseKeyword("to") || - parser.parseOperandList(upper, ivs.size(), - OpAsmParser::Delimiter::Paren) || - parser.resolveOperands(upper, builder.getIndexType(), result.operands)) - return failure(); - - // Parse step values. 
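// A small standalone sketch (all names and numbers invented, not part of the
// original code) of the index arithmetic the LoopOp verifier above relies on:
// the region's block arguments are laid out as
// [inductionVars..., inputArgs..., outputArgs...], so input i corresponds to
// region argument i + numLoops and output i to region argument
// i + numLoops + numInputs.
#include <cassert>
#include <cstddef>

static std::size_t inputRegionArgIndex(std::size_t i, std::size_t numLoops) {
  return numLoops + i;
}

static std::size_t outputRegionArgIndex(std::size_t i, std::size_t numLoops,
                                        std::size_t numInputs) {
  return numLoops + numInputs + i;
}

int main() {
  // Hypothetical loop with 2 induction variables, 1 input and 2 outputs.
  assert(inputRegionArgIndex(0, 2) == 2);
  assert(outputRegionArgIndex(0, 2, 1) == 3);
  assert(outputRegionArgIndex(1, 2, 1) == 4);
  return 0;
}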
- SmallVector steps; - if (parser.parseKeyword("step") || - parser.parseOperandList(steps, ivs.size(), - OpAsmParser::Delimiter::Paren) || - parser.resolveOperands(steps, builder.getIndexType(), result.operands)) - return failure(); - - SmallVector segmentSizes{static_cast(lower.size()), - static_cast(upper.size()), - static_cast(steps.size())}; - - // Parse distribution type (only for ParallelOp) - if (std::is_same::value) { - if (succeeded(parser.parseOptionalKeyword("distribution"))) { - StringAttr distributionType; - if (parser.parseLParen() || parser.parseAttribute(distributionType) || - parser.parseRParen()) - return failure(); - result.addAttribute(ParallelOp::getDistributionTypeAttrName(result.name), - distributionType); - } - } - - // Parse the output tensors (only for ForOp) and the body. - SmallVector regionOperands(ivs); - SmallVector regionTypes(ivs.size(), builder.getIndexType()); - - if (std::is_same::value) { - int32_t outputCount = 0; - if (parseForOpOutputArgs(parser, result, regionOperands, regionTypes, - &outputCount)) - return failure(); - segmentSizes.push_back(outputCount); - } - - SmallVector regionArgs; - for (auto argAndType : llvm::zip(regionOperands, regionTypes)) { - auto &arg = regionArgs.emplace_back(); - std::tie(arg.ssaName, arg.type) = argAndType; - } - Region *body = result.addRegion(); - if (parser.parseRegion(*body, regionArgs)) return failure(); - - // Parse attributes. - if (parser.parseOptionalAttrDict(result.attributes)) return failure(); - - // Parser result types. - if (parser.parseOptionalColonTypeList(result.types)) return failure(); - - // Add segment sizes. - result.addAttribute(LoopTy::getOperandSegmentSizeAttr(), - builder.getDenseI32ArrayAttr(segmentSizes)); - - return success(); -} - -//===----------------------------------------------------------------------===// -// ParallelOp -//===----------------------------------------------------------------------===// - -Region &ParallelOp::getLoopBody() { return getRegion(); } - -SetYieldOp ParallelOp::getTerminator() { - return cast(getBody()->getTerminator()); -} - -LogicalResult ParallelOp::verify() { return success(); } - -void ParallelOp::build( - OpBuilder &builder, OperationState &result, TypeRange resultTypes, - ValueRange lowerBounds, ValueRange upperBounds, ValueRange steps, - Optional distributionType, - function_ref bodyBuilderFn) { - result.addOperands(lowerBounds); - result.addOperands(upperBounds); - result.addOperands(steps); - result.addTypes(resultTypes); - result.addAttribute( - LoopOp::getOperandSegmentSizeAttr(), - builder.getDenseI32ArrayAttr({static_cast(lowerBounds.size()), - static_cast(upperBounds.size()), - static_cast(steps.size())})); - - if (distributionType.has_value()) - result.addAttribute(getDistributionTypeAttrName(result.name), - distributionType.value()); - - OpBuilder::InsertionGuard guard(builder); - unsigned numIvs = steps.size(); - SmallVector argTypes(numIvs, builder.getIndexType()); - SmallVector argLocs(numIvs, result.location); - Region *bodyRegion = result.addRegion(); - Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs); - - if (bodyBuilderFn) { - builder.setInsertionPointToStart(bodyBlock); - bodyBuilderFn(builder, result.location, - bodyBlock->getArguments().take_front(numIvs)); - ParallelOp::ensureTerminator(*bodyRegion, builder, result.location); - } -} - -void ParallelOp::print(OpAsmPrinter &p) { - p << " (" << getInductionVars() << ") = (" << getLowerBound() << ") to (" - << getUpperBound() << ") step (" << getStep() << ") 
"; - - if (getDistributionType().has_value()) - p << "distribution (" << getDistributionTypeAttr() << ") "; - - p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); - p.printOptionalAttrDict( - getOperation()->getAttrs(), - /*elidedAttrs=*/{ParallelOp::getOperandSegmentSizeAttr(), - getDistributionTypeAttrName()}); - - if (!getResultTypes().empty()) { - p << " : "; - llvm::interleave(getResultTypes(), p, ", "); - } -} - -ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) { - return parseLoopLikeOp(parser, result); -} - -ValueRange ParallelOp::getLoopLikeOpInits() { - return getTerminator().getDsts(); -} - -//===----------------------------------------------------------------------===// -// ForOp -//===----------------------------------------------------------------------===// - -Region &ForOp::getLoopBody() { return getRegion(); } - -SetYieldOp ForOp::getTerminator() { - return cast(getBody()->getTerminator()); -} - -LogicalResult ForOp::verify() { - // Check if types of output arguments match region args types. - for (auto &item : - llvm::enumerate(llvm::zip(getOutputs(), getRegionOutputArgs()))) { - Value output, outputRegionArg; - unsigned index = item.index(); - std::tie(output, outputRegionArg) = item.value(); - if (output.getType() != outputRegionArg.getType()) { - return emitOpError("expected output arg ") - << index << " with type = " << output.getType() - << " to match region arg " << index + getNumLoops() - << " type = " << outputRegionArg.getType(); - } - auto terminator = getTerminator(); - auto numDstOperands = terminator.getNumDstOperands(); - if (index >= numDstOperands) { - const auto *s = index ? "s" : ""; - return terminator.emitOpError("expected to have at least ") - << index + 1 << " destination operand" << s << " (currently " - << numDstOperands << ")"; - } - - if (terminator.getDstOperand(index)->get() != outputRegionArg) { - return terminator.emitOpError("expected output block argument ") - << index << " to match set_yield destination"; - } - } - return success(); -} - -void ForOp::build( - OpBuilder &builder, OperationState &result, TypeRange resultTypes, - ValueRange lowerBounds, ValueRange upperBounds, ValueRange steps, - ValueRange outputs, - function_ref - bodyBuilderFn) { - result.addOperands(lowerBounds); - result.addOperands(upperBounds); - result.addOperands(steps); - result.addOperands(outputs); - result.addTypes(resultTypes); - result.addAttribute( - LoopOp::getOperandSegmentSizeAttr(), - builder.getDenseI32ArrayAttr({static_cast(lowerBounds.size()), - static_cast(upperBounds.size()), - static_cast(steps.size()), - static_cast(outputs.size())})); - - OpBuilder::InsertionGuard guard(builder); - unsigned numIvs = steps.size(); - SmallVector argTypes(numIvs, builder.getIndexType()); - SmallVector argLocs(numIvs, result.location); - for (Value output : outputs) { - argTypes.push_back(output.getType()); - argLocs.push_back(output.getLoc()); - } - Region *bodyRegion = result.addRegion(); - Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes, argLocs); - - if (bodyBuilderFn) { - builder.setInsertionPointToStart(bodyBlock); - bodyBuilderFn(builder, result.location, - bodyBlock->getArguments().take_front(numIvs), - bodyBlock->getArguments().take_back(outputs.size())); - ForOp::ensureTerminator(*bodyRegion, builder, result.location); - } -} - -void ForOp::print(OpAsmPrinter &p) { - p << " (" << getInductionVars() << ") = (" << getLowerBound() << ") to (" - << getUpperBound() << ") step (" << getStep() << ")"; - - if 
(!getOutputs().empty()) { - p << " outs ("; - llvm::interleaveComma( - llvm::zip(getRegionOutputArgs(), getOutputs()), p, [&](auto it) { - Value outputRegionArg, output; - std::tie(outputRegionArg, output) = it; - p << outputRegionArg << " = " << output << ": " << output.getType(); - }); - p << ")"; - } - - p << ' '; - p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); - p.printOptionalAttrDict(getOperation()->getAttrs(), - /*elidedAttrs=*/{ForOp::getOperandSegmentSizeAttr()}); - - if (!getResultTypes().empty()) { - p << " : "; - llvm::interleave(getResultTypes(), p, ", "); - } -} - -ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) { - return parseLoopLikeOp(parser, result); -} - -namespace { -// Collapse loop dimensions that perform a single iteration. -// This is a partial copy of the corresponding pattern from SCF. -struct CollapseSingleIterationLoops : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ForOp op, - PatternRewriter &rewriter) const override { - BlockAndValueMapping mapping; - // Compute new loop bounds that omit all single-iteration loop dimensions. - SmallVector newLowerBounds, newUpperBounds, newSteps; - newLowerBounds.reserve(op.getLowerBound().size()); - newUpperBounds.reserve(op.getUpperBound().size()); - newSteps.reserve(op.getStep().size()); - auto getConstant = [](Value v) -> Optional { - auto constant = - dyn_cast_or_null(v.getDefiningOp()); - if (constant) return constant.value(); - return None; - }; - for (auto [lowerBound, upperBound, step, iv] : - llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(), - op.getInductionVars())) { - // Collect the statically known loop bounds. - auto lowerBoundConstant = getConstant(lowerBound); - auto upperBoundConstant = getConstant(upperBound); - auto stepConstant = getConstant(step); - // Replace the loop induction variable by the lower bound if the loop - // performs a single iteration. Otherwise, copy the loop bounds. - if (lowerBoundConstant && upperBoundConstant && stepConstant && - (*upperBoundConstant - *lowerBoundConstant) > 0 && - (*upperBoundConstant - *lowerBoundConstant) <= *stepConstant) { - mapping.map(iv, lowerBound); - } else { - newLowerBounds.push_back(lowerBound); - newUpperBounds.push_back(upperBound); - newSteps.push_back(step); - } - } - // Exit if none of the loop dimensions perform a single iteration. - if (newLowerBounds.size() == op.getLowerBound().size()) return failure(); - - // Replace the parallel loop by lower-dimensional parallel loop. - auto newOp = rewriter.create(op.getLoc(), op.getResultTypes(), - newLowerBounds, newUpperBounds, - newSteps, op.getOutputs(), nullptr); - // Clone the loop body and remap the block arguments of the collapsed loops - // (inlining does not support a cancellable block argument mapping). - rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(), - newOp.getRegion().begin(), mapping); - rewriter.replaceOp(op, newOp.getResults()); - return success(); - } -}; - -/// Folds CastOp of loop outputs into ForOp -struct RefineForOpShape : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ForOp op, - PatternRewriter &rewriter) const override { - if (llvm::all_of(op.getOutputs(), [](auto out) { - return out.template getDefiningOp() == nullptr; - })) - return failure(); - - Location loc = op.getLoc(); - // Scans through output args to find what args are produced by `tensor.cast` - // ops. 
Also cache the info since we are gonna reuse it a lot. - SmallVector newOutputs{op.getOutputs()}; - SmallVector newTypes{op.getResultTypes()}; - SmallVector castOutputs; - for (auto &&[out, type] : llvm::zip(newOutputs, newTypes)) { - if (auto cast = - castOutputs.emplace_back(out.getDefiningOp())) { - out = cast.getSource(); - type = out.getType(); - } - } - - auto newFor = rewriter.create(loc, newTypes, op.getLowerBound(), - op.getUpperBound(), op.getStep(), - newOutputs, nullptr); - - // Map outputs, insert `tensor.cast` if necessary. - BlockAndValueMapping bvm; - bvm.map(op.getInductionVars(), newFor.getInductionVars()); - - auto innerBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, newFor.getBody()); - rewriter.setInsertionPointAfter(newFor); - - for (const auto &[oldArg, newArg, cast] : - llvm::zip(op.getRegionOutputArgs(), newFor.getRegionOutputArgs(), - castOutputs)) { - bvm.map(oldArg, - cast ? innerBuilder.create(cast.getType(), newArg) - : Value(newArg)); - } - // Cast the loop results for downstream uses of the loop if necessary. - SmallVector newResults{newFor.getResults()}; - for (auto &&[res, cast] : llvm::zip(newResults, castOutputs)) { - if (cast) res = rewriter.create(loc, cast.getType(), res); - } - - // Clone loop body. - for (auto &o : *(op.getBody())) innerBuilder.clone(o, bvm); - - // Update set_yield destinations to the new type. - auto term = cast(newFor.getTerminator()); - rewriter.updateRootInPlace(term, [&]() { - term.getDstsMutable().assign(newFor.getRegionOutputArgs()); - }); - - // Update the original loop by the new loop + CastOp. - rewriter.replaceOp(op, newResults); - return success(); - } -}; -} // namespace - -void ForOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add(context); -} - -namespace { - -static constexpr int64_t kNoMatch = -1; - -// Folds away LoopOp inputs if they have no uses within the body. -// -// Example: -// -// %0 = gml_st.loop ... ins (%in_ = %in: tensor<...>, -// %in_buf_ = %in_buf: memref<...>) {...} -// Becomes -// -// gml_st.loop ... ins (%in_buf_ = %in_buf: memref<...>) {...} -struct LoopInputsFolder : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(LoopOp loop, - PatternRewriter &rewriter) const final { - SmallVector newInputs, regionInputTensorArgs; - // Store ids of the corresponding old and new input operands. - SmallVector oldInputIdToNew(loop.getInputs().size(), kNoMatch); - for (const auto &en : llvm::enumerate( - llvm::zip(loop.getInputs(), loop.getRegionInputArgs()))) { - Value in, bbArg; - size_t index = en.index(); - std::tie(in, bbArg) = en.value(); - if (!bbArg.use_empty()) { - oldInputIdToNew[index] = newInputs.size(); - newInputs.push_back(in); - } - } - if (newInputs.size() == loop.getInputs().size()) return failure(); - Location loc = loop.getLoc(); - auto newLoop = rewriter.create( - loc, loop.getLowerBound(), loop.getUpperBound(), loop.getStep(), - newInputs, loop.getOutputs(), loop.getIteratorTypes(), - loop.getDistributionTypes()); - - // Clone the region. 
- BlockAndValueMapping bvm; - bvm.map(loop.getInductionVars(), newLoop.getInductionVars()); - bvm.map(loop.getRegionOutputArgs(), newLoop.getRegionOutputArgs()); - for (const auto &en : llvm::enumerate(oldInputIdToNew)) - if (en.value() != kNoMatch) - bvm.map(loop.getRegionInputArgs()[en.index()], - newLoop.getRegionInputArgs()[en.value()]); - OpBuilder innerBuilder = - OpBuilder::atBlockEnd(newLoop.getBody(), rewriter.getListener()); - for (auto &op : *loop.getBody()) innerBuilder.clone(op, bvm); - rewriter.replaceOp(loop, newLoop.getResults()); - - return success(); - } -}; - -} // namespace - -/// A simple, conservative analysis to determine if the loop is shape -/// conserving. I.e., the type of the arg-th yielded value is the same as the -/// type of the corresponding basic block argument of the loop. -/// Note: This function handles only simple cases. Expand as needed. -static bool isShapePreserving(LoopOp loopOp, int64_t arg) { - auto yieldOp = cast(loopOp.getLoopBody().front().getTerminator()); - if (yieldOp.getValues().empty()) - // Loop either has no outputs or is a "memref-based version". In either - // case, the loop is shape conserving. - return true; - assert(arg < static_cast(yieldOp.getValues().size()) && - "arg is out of bounds"); - Value value = yieldOp.getValues()[arg]; - while (value) { - if (value == loopOp.getRegionOutputArgs()[arg]) return true; - OpResult opResult = value.dyn_cast(); - if (!opResult) return false; - - using tensor::InsertSliceOp; - value = llvm::TypeSwitch(opResult.getOwner()) - .template Case( - [&](InsertSliceOp op) { return op.getDest(); }) - .template Case([&](LoopOp loopOp) { - return isShapePreserving(loopOp, opResult.getResultNumber()) - ? loopOp.getOutputs()[opResult.getResultNumber()] - : Value(); - }) - .Default([&](auto /*op*/) { return Value(); }); - } - return false; -} - -namespace { - -/// Fold dim(x) where `x` is an input/output argument of a LoopOp block -/// to dim(y) where `y` is the initial input/output value of the argument. -/// -/// E.g.: -/// %y = ... : tensor<...> -/// gml_st.loop ... ins(%x = %y : tensor<...>) { -/// tensor.dim %x, %c0 : tensor<...> -/// } -/// -/// is folded to: -/// %y = ... : tensor<...> -/// gml_st.loop ... ins(%x = %y : tensor<...>) { -/// tensor.dim %y, %c0 : tensor<...> -/// } -/// -/// Note: Dim ops are folded only if it can be proven that the runtime type of -/// the yielded value (in case of outputs) does not change with loop iterations. 
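// A self-contained sketch (values invented, not part of the original code) of
// the old-index to new-index bookkeeping shared by LoopInputsFolder above and
// LoopResultsFolder further below: operands whose block argument is unused
// are dropped, and kNoMatch marks entries with no counterpart in the new loop
// so the cloned body can be remapped correctly.
#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  constexpr int kNoMatch = -1;
  // Per-input flag: is the corresponding block argument still used in the body?
  std::vector<bool> argIsUsed = {true, false, true};
  std::vector<int> inputs = {10, 20, 30};

  std::vector<int> newInputs;
  std::vector<int> oldInputIdToNew(inputs.size(), kNoMatch);
  for (std::size_t i = 0; i < inputs.size(); ++i) {
    if (!argIsUsed[i]) continue;
    oldInputIdToNew[i] = static_cast<int>(newInputs.size());
    newInputs.push_back(inputs[i]);
  }

  assert(newInputs == (std::vector<int>{10, 30}));
  assert(oldInputIdToNew == (std::vector<int>{0, kNoMatch, 1}));
  return 0;
}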
-template -struct DimOfLoopInsOutsFolder : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(OpTy dimOp, - PatternRewriter &rewriter) const final { - auto src = dimOp.getSource().template dyn_cast(); - if (!src) return failure(); - auto loopOp = dyn_cast(src.getOwner()->getParent()->getParentOp()); - if (!loopOp) return failure(); - unsigned numLoops = loopOp.getNumLoops(); - unsigned numInputArgs = loopOp.getRegionInputArgs().size(); - if (src.getArgNumber() >= numInputArgs + numLoops && - !isShapePreserving(loopOp, - src.getArgNumber() - numInputArgs - numLoops)) - return failure(); - - auto inputArgs = loopOp.getRegionInputArgs(); - auto it1 = llvm::find(inputArgs, src); - if (it1 != inputArgs.end()) { - rewriter.updateRootInPlace(dimOp, [&] { - dimOp.getSourceMutable().assign( - loopOp.getInputs()[it1 - inputArgs.begin()]); - }); - return success(); - } - - auto outputArgs = loopOp.getRegionOutputArgs(); - auto it2 = llvm::find(outputArgs, src); - if (it2 != outputArgs.end()) { - rewriter.updateRootInPlace(dimOp, [&] { - dimOp.getSourceMutable().assign( - loopOp.getOutputs()[it2 - outputArgs.begin()]); - }); - return success(); - } - - return failure(); - } -}; - -/// Fold dim(r) where `r` is the result of a LoopOp to dim(y) where `y` -/// is the initial output value of the loop. -/// -/// E.g.: -/// %y = ... : tensor<...> -/// %r = gml_st.loop ... outs(%i = %y : tensor<...>) { -/// ... -/// } -/// %0 = tensor.dim %r, %c0 : tensor<...> -/// -/// is folded to: -/// %y = ... : tensor<...> -/// gml_st.loop ... outs(%i = %y : tensor<...>) { -/// ... -/// } -/// %0 = tensor.dim %y, %c0 : tensor<...> -/// -/// Note: Dim ops are folded only if it can be proven that the runtime type of -/// the yielded value (in case of outputs) does not change with loop iterations. -template -struct DimOfLoopResultFolder : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(OpTy dimOp, - PatternRewriter &rewriter) const final { - auto loopOp = dimOp.getSource().template getDefiningOp(); - if (!loopOp) return failure(); - auto opResult = dimOp.getSource().template cast(); - unsigned resultNumber = opResult.getResultNumber(); - if (!isShapePreserving(loopOp, resultNumber)) return failure(); - rewriter.updateRootInPlace(dimOp, [&]() { - dimOp.getSourceMutable().assign(loopOp.getOutputs()[resultNumber]); - }); - return success(); - } -}; - -// Folds away LoopOp output tensors when the following conditions are met: -// * result of `gml_st.loop` has no uses -// * output tensor is the argument of `gml_st.yield` -// -// Example: -// -// %0 = gml_st.loop ... outs (%o_ = %out: tensor<...>, -// %obuf_ = %out_buf: memref<...>) { -// ... -// gml_st.yield %o_ : tensor ... -// } -// -// Becomes -// -// gml_st.loop ... outs (%obuf_ = %out_buf: memref<...>) { -// ... -// gml_st.yield -// } -struct LoopResultsFolder : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(LoopOp loop, - PatternRewriter &rewriter) const final { - if (loop.getNumResults() == 0) return failure(); - - Block *block = loop.getBody(); - auto yieldOp = cast(block->getTerminator()); - - // Match the pattern and collect output buffers that will replace the output - // tensors and also the ops that will be ignored when cloning the body. - SmallVector newOutputOperands, newYieldArgs; - int resultId = 0; - // Store ids of the corresponding old and new output operands. 
- SmallVector oldOutputIdToNew(loop.getOutputs().size(), - kNoMatch); - // Store ids of the corresponding old and new results. - SmallVector oldResultIdToNew(loop.getNumResults(), kNoMatch); - SmallVector resultReplacement(loop.getNumResults()); - for (const auto &en : llvm::enumerate( - llvm::zip(loop.getOutputs(), loop.getRegionOutputArgs()))) { - size_t index = en.index(); - Value out = std::get<0>(en.value()); - Value outRegionArg = std::get<1>(en.value()); - - if (!out.getType().isa()) { - oldOutputIdToNew[index] = newOutputOperands.size(); - newOutputOperands.push_back(out); - continue; - } - Value result = loop.getResult(resultId); - Value yieldArg = yieldOp.getOperand(resultId); - if (yieldArg != outRegionArg || !result.use_empty()) { - oldOutputIdToNew[index] = newOutputOperands.size(); - oldResultIdToNew[resultId] = newYieldArgs.size(); - resultReplacement[resultId] = out; - newOutputOperands.push_back(out); - newYieldArgs.push_back(yieldArg); - } - ++resultId; - } - if (newOutputOperands.size() == loop.getOutputs().size()) return failure(); - - Location loc = loop.getLoc(); - auto newLoop = rewriter.create( - loc, loop.getLowerBound(), loop.getUpperBound(), loop.getStep(), - loop.getInputs(), newOutputOperands, loop.getIteratorTypes(), - loop.getDistributionTypes()); - - // Clone the region. - BlockAndValueMapping bvm; - bvm.map(loop.getInductionVars(), newLoop.getInductionVars()); - bvm.map(loop.getRegionInputArgs(), newLoop.getRegionInputArgs()); - for (const auto &en : llvm::enumerate(oldOutputIdToNew)) { - if (en.value() != kNoMatch) - bvm.map(loop.getRegionOutputArgs()[en.index()], - newLoop.getRegionOutputArgs()[en.value()]); - else - bvm.map(loop.getRegionOutputArgs()[en.index()], - loop.getOutputs()[en.index()]); - } - OpBuilder innerBuilder = - OpBuilder::atBlockEnd(newLoop.getBody(), rewriter.getListener()); - for (auto &op : loop.getBody()->without_terminator()) - innerBuilder.clone(op, bvm); - innerBuilder.create( - loc, llvm::to_vector<2>(llvm::map_range( - newYieldArgs, [&](Value arg) { return bvm.lookup(arg); }))); - - for (const auto &en : llvm::enumerate(oldResultIdToNew)) - if (en.value() != kNoMatch) - resultReplacement[en.index()] = newLoop.getResult(en.value()); - rewriter.replaceOp(loop, resultReplacement); - - return success(); - } -}; - -/// Pull `gml_st.loop` input/output arguments that are produced by -/// `tensor.cast` ops inside `gml_st.loop`: -/// -/// ``` -/// %in = tensor.cast %t0 : tensor<32x1024xf32> to tensor -/// %out = tensor.cast %t1 : tensor<32x1024xf32> to tensor -/// %result = gml_st.loop %i = %c0 to %c1024 step %c32 -/// ins (%in_ = %in: tensor) -/// outs (%out_ = %out: tensor) { -/// %0 = call @do(%in_, %out_) -/// : (tensor, tensor) -> tensor -/// scf.yield %0 : tensor -/// } -/// %result_cast = tensor.cast %result -/// : tensor to tensor<32x1024xf32> -/// use_of(%result_cast) -/// ``` -/// -/// folds into: -// -/// ``` -/// %result = gml_st.loop %i = %c0 to %c1024 step %c32 -/// ins (%in_ = %t0: tensor<32x1024xf32>) -/// outs (%out_ = %t1: tensor<32x1024xf32>) { -/// %in_cast = tensor.cast %in_ : tensor<32x1024xf32> to tensor -/// %out_cast = tensor.cast %out_ : tensor<32x1024xf32> to tensor -/// %0 = call @do(%in_, %out_) -/// : (tensor, tensor) -> tensor -/// %0_cast = tensor.cast %0 : tensor to tensor<32x1024xf32> -/// scf.yield %0 : tensor<32x1024xf32> -/// } -/// use_of(%result) -/// ``` -struct TensorCastOfLoopInsOutsFolder : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult 
matchAndRewrite(LoopOp loop, - PatternRewriter &rewriter) const override { - CastOpsOfArgs inputCasts = findTensorCastOps(loop.getInputs()); - CastOpsOfArgs outputCasts = findTensorCastOps(loop.getOutputs()); - if (!inputCasts.castFound && !outputCasts.castFound) return failure(); - - auto newLoop = rewriter.create( - loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), - loop.getStep(), inputCasts.updatedArgs, outputCasts.updatedArgs, - loop.getIteratorTypes(), loop.getDistributionTypes()); - - rewriter.replaceOp(loop, insertCastsAndCloneBody(inputCasts, outputCasts, - loop, newLoop, rewriter)); - return success(); - } - - private: - struct CastOpsOfArgs { - SmallVector ops; - // Contains either old arguments or arguments of `tensor.cast`. - SmallVector updatedArgs; - bool castFound = false; - }; - - // Scans through args to find what args are produced by `tensor.cast` ops. - CastOpsOfArgs findTensorCastOps(ValueRange args) const { - CastOpsOfArgs result; - for (auto arg : args) { - if (auto cast = arg.getDefiningOp()) { - result.ops.push_back(cast); - result.updatedArgs.push_back(cast.getSource()); - result.castFound = true; - continue; - } - result.ops.push_back(nullptr); - result.updatedArgs.push_back(arg); - } - return result; - } - - SmallVector insertCastsAndCloneBody( - const CastOpsOfArgs &inputCasts, const CastOpsOfArgs &outputCasts, - LoopOp loop, LoopOp newLoop, PatternRewriter &rewriter) const { - auto loc = newLoop.getLoc(); - BlockAndValueMapping bvm; - bvm.map(loop.getInductionVars(), newLoop.getInductionVars()); - - auto innerBuilder = - OpBuilder::atBlockEnd(newLoop.getBody(), rewriter.getListener()); - - Value oldArg, newArg, yieldArg, result; - tensor::CastOp argCast; - - // Map inputs, insert `tensor.cast` if necessary. - for (auto item : llvm::zip(loop.getRegionInputArgs(), - newLoop.getRegionInputArgs(), inputCasts.ops)) { - std::tie(oldArg, newArg, argCast) = item; - if (!argCast) { - bvm.map(oldArg, newArg); - continue; - } - Value newCast = - innerBuilder.create(loc, argCast.getType(), newArg); - bvm.map(oldArg, newCast); - } - - // Map outputs, insert `tensor.cast` and cast the loop results if necessary. - SmallVector newResults; - rewriter.setInsertionPointAfter(newLoop); - for (auto item : - llvm::zip(loop.getRegionOutputArgs(), newLoop.getRegionOutputArgs(), - outputCasts.ops, newLoop.getResults())) { - std::tie(oldArg, newArg, argCast, result) = item; - if (!argCast) { - bvm.map(oldArg, newArg); - newResults.push_back(result); - continue; - } - Value newCast = - innerBuilder.create(loc, argCast.getType(), newArg); - bvm.map(oldArg, newCast); - - newResults.push_back( - rewriter.create(loc, argCast.getType(), result)); - } - - // Clone loop body. - for (auto &op : loop.getBody()->without_terminator()) - innerBuilder.clone(op, bvm); - - // Cast yield arguments to the new type. - SmallVector yieldArgs = - loop.getBody()->getTerminator()->getOperands(); - SmallVector newYieldArgs; - for (auto item : llvm::zip(yieldArgs, outputCasts.ops)) { - std::tie(yieldArg, argCast) = item; - if (!argCast) { - newYieldArgs.push_back(bvm.lookup(yieldArg)); - continue; - } - newYieldArgs.push_back(innerBuilder.create( - loc, argCast.getSource().getType(), bvm.lookup(yieldArg))); - } - innerBuilder.create(loc, newYieldArgs); - return newResults; - } -}; - -/// Removes loops in which at least one lower/upper bound pair consists -/// of the same values - such loops have an empty iteration domain. 
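// A standalone arithmetic sketch (values invented; the real patterns inspect
// SSA values and ConstantIndexOp results) of the two trip-count special cases
// handled in this file: CollapseSingleIterationLoops earlier collapses a
// dimension whose static bounds satisfy 0 < ub - lb <= step, i.e. exactly one
// iteration, and FoldEmptyLoops below removes the loop when some dimension
// has lb == ub, i.e. an empty iteration domain.
#include <cassert>

static bool isSingleIteration(int lb, int ub, int step) {
  return (ub - lb) > 0 && (ub - lb) <= step;
}

static bool isEmptyDimension(int lb, int ub) { return lb == ub; }

int main() {
  assert(isSingleIteration(0, 4, 8));    // one step covers the whole range
  assert(!isSingleIteration(0, 16, 8));  // two iterations remain
  assert(isEmptyDimension(4, 4));        // zero iterations
  assert(!isEmptyDimension(0, 4));
  return 0;
}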
-struct FoldEmptyLoops : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(LoopOp op, - PatternRewriter &rewriter) const override { - for (auto dim : llvm::zip(op.getLowerBound(), op.getUpperBound())) { - if (std::get<0>(dim) != std::get<1>(dim)) continue; - SmallVector tensorOutputs; - for (Value out : op.getOutputs()) { - if (out.getType().isa()) tensorOutputs.push_back(out); - } - rewriter.replaceOp(op, tensorOutputs); - return success(); - } - return failure(); - } -}; - -} // namespace - -void LoopOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results - .add, - DimOfLoopInsOutsFolder, - DimOfLoopResultFolder, - DimOfLoopResultFolder, TensorCastOfLoopInsOutsFolder>( - context); -} - -/// This is used for patterns of the form -/// ``` -/// gml_st.loop(memrefcast(%src)) -> gml_st.loop(%src) -/// ``` -/// It folds the source of the memref.cast into the root operation directly. -LogicalResult LoopOp::fold(ArrayRef, - SmallVectorImpl &) { - LoopOp op = *this; - bool folded = false; - Location loc = op->getLoc(); - - Block *body = op.getBody(); - OpBuilder b = OpBuilder::atBlockBegin(body); - - // Update `input` and `output` operands and block arguments if necessary. - // Operands list: [lbs, ubs, steps, inputs, outputs]. - // Block args list: [ivs, inputs, outputs]. - for (size_t operandIndex = op.getNumControlOperands(), - bbArgIndex = op.getNumLoops(), e = op.getNumOperands(); - operandIndex < e; ++operandIndex, ++bbArgIndex) { - OpOperand &operand = op->getOpOperand(operandIndex); - - auto castOp = operand.get().getDefiningOp(); - if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) { - operand.set(castOp.getOperand()); - BlockArgument newBbArg = body->insertArgument( - bbArgIndex, castOp.getOperand().getType(), op.getLoc()); - BlockArgument oldBbArg = body->getArgument(newBbArg.getArgNumber() + 1); - - // Insert memref.cast back to the original type. - oldBbArg.replaceAllUsesWith( - b.create(loc, oldBbArg.getType(), newBbArg)); - body->eraseArgument(oldBbArg.getArgNumber()); - - folded = true; - } - } - return success(folded); -} - -//===----------------------------------------------------------------------===// -// YieldOp -//===----------------------------------------------------------------------===// - -LogicalResult YieldOp::verify() { - auto *parentOp = getOperation()->getParentOp(); - - if (auto setYield = dyn_cast(parentOp)) { - if (getValues().size() != 1) - return emitOpError( - "expected a single argument for the terminator of accumulator " - "region"); - return success(); - } - auto loopOp = cast(parentOp); - // Check if output args with tensor types match results types. 
- SmallVector tensorOuts; - llvm::copy_if( - loopOp.getOutputs(), std::back_inserter(tensorOuts), - [&](Value out) { return out.getType().isa(); }); - if (tensorOuts.size() != getValues().size()) - return emitOpError("expected number of tensor output args = ") - << tensorOuts.size() - << " to match the number of yield operands = " << getValues().size(); - - TypeRange tensorTypes{ValueRange{tensorOuts}}; - for (auto &item : - llvm::enumerate(llvm::zip(tensorTypes, getOperandTypes()))) { - Type outType, resultType; - unsigned index = item.index(); - std::tie(outType, resultType) = item.value(); - if (outType != resultType) - return emitOpError("expected yield operand ") - << index << " with type = " << resultType - << " to match output arg type = " << outType; - } - return success(); -} - -//===----------------------------------------------------------------------===// -// TileOp -//===----------------------------------------------------------------------===// - -namespace { -/// Fold gml_st.tile [%c0] ... into gml_st.tile [0] ... -/// Adapted from OpWithOffsetSizesAndStridesConstantArgumentFolder, which makes -/// slightly incompatible assumptions about the op. -struct FoldConstantsIntoTileType : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TileOp op, - PatternRewriter &rewriter) const override { - // No constant operand, just return; - if (llvm::none_of(op.getOperands(), [](Value operand) { - return matchPattern(operand, matchConstantIndex()); - })) - return failure(); - - // At least one of offsets/sizes/strides is a new constant. - // Form the new list of operands and constant attributes from the existing. - SmallVector mixedOffsets(op.getMixedOffsets()); - SmallVector mixedSizes(op.getMixedSizes()); - SmallVector mixedStrides(op.getMixedStrides()); - canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset); - canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic); - canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset); - - // Create the new tile in canonical form. - TileOp newOp = rewriter.create(op.getLoc(), mixedOffsets, - mixedSizes, mixedStrides); - // Cast the result back to the original type. This will be folded further - // materialize ops. 
- rewriter.replaceOpWithNewOp( - op, TypeRange{op.getType()}, ValueRange{newOp}); - - return success(); - } -}; -} // namespace - -void TileOp::build(OpBuilder &b, OperationState &result, - ArrayRef offsets, ArrayRef sizes, - ArrayRef strides, - ArrayRef attrs) { - SmallVector staticOffsets, staticSizes, staticStrides; - SmallVector dynamicOffsets, dynamicSizes, dynamicStrides; - dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets, - ShapedType::kDynamicStrideOrOffset); - dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes, - ShapedType::kDynamicSize); - dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, - ShapedType::kDynamicStrideOrOffset); - auto tileType = TileType::get(b.getContext(), staticSizes); - build(b, result, tileType, dynamicOffsets, dynamicSizes, dynamicStrides, - b.getI64ArrayAttr(staticOffsets), b.getI64ArrayAttr(staticSizes), - b.getI64ArrayAttr(staticStrides)); - result.addAttributes(attrs); -} - -void TileOp::build(OpBuilder &b, OperationState &result, - ArrayRef offsets, - ArrayRef attrs) { - SmallVector unitSizesAndStrides(offsets.size(), - b.getIndexAttr(1)); - return build(b, result, offsets, unitSizesAndStrides, unitSizesAndStrides, - attrs); -} - -LogicalResult TileOp::inferReturnTypes( - MLIRContext *ctx, Optional /*loc*/, ValueRange operands, - DictionaryAttr attributes, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - // Derive result shape. - TileOp::Adaptor adaptor(operands, attributes, regions); - SmallVector shape = llvm::to_vector( - llvm::map_range(adaptor.getStaticSizes(), [&](const auto &size) { - return size.template dyn_cast() - .getValue() - .getSExtValue(); - })); - - auto resultTy = TileType::get(ctx, shape); - inferredReturnTypes.push_back(resultTy); - return success(); -} - -LogicalResult TileOp::verify() { - auto resultType = getType(); - auto rank = resultType.getRank(); - if (failed(mlir::verifyListOfOperandsOrIntegers(getOperation(), "size", rank, - getStaticSizes(), getSizes(), - ShapedType::isDynamic))) { - return failure(); - } - if (failed(mlir::verifyListOfOperandsOrIntegers( - getOperation(), "offset", rank, getStaticOffsets(), getOffsets(), - ShapedType::isDynamicStrideOrOffset))) { - return failure(); - } - if (failed(mlir::verifyListOfOperandsOrIntegers( - getOperation(), "stride", rank, getStaticStrides(), getStrides(), - ShapedType::isDynamicStrideOrOffset))) { - return failure(); - } - for (auto it : llvm::zip(resultType.getShape(), getStaticOffsets(), - getStaticSizes(), getStaticStrides())) { - auto offset = - std::get<1>(it).dyn_cast().getValue().getSExtValue(); - if (offset < 0 && offset != ShapedType::kDynamicStrideOrOffset) { - return emitOpError("expected offset = ") - << offset << " to be non-negative"; - } - auto size = - std::get<2>(it).dyn_cast().getValue().getSExtValue(); - if (size < 0 && size != ShapedType::kDynamicSize) { - return emitOpError("expected size = ") << size << " to be non-negative"; - } - auto stride = - std::get<3>(it).dyn_cast().getValue().getSExtValue(); - if (stride < 0 && stride != ShapedType::kDynamicStrideOrOffset) { - return emitOpError("expected stride = ") - << stride << " to be non-negative"; - } - auto tileSize = std::get<0>(it); - if (tileSize != size) { - return emitOpError("size arg = ") - << size << " does not match tile size = " << tileSize; - } - } - return success(); -} - -void TileOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add(context); -} - -namespace { - -OpFoldResult 
multiplyOperandsOrIntegers(OpBuilder &builder, Location loc, - OpFoldResult lhs, OpFoldResult rhs) { - // Both operands are static. - if (lhs.is() && rhs.is()) { - return builder.getI64IntegerAttr( - lhs.get().cast().getInt() * - rhs.get().cast().getInt()); - } - - // Exploit commutativity and move static operand to the left (if any). - if (rhs.is()) std::swap(lhs, rhs); - - // Create constant if needed. - if (lhs.is()) { - int64_t lhsInt = lhs.get().cast().getInt(); - - // Exploit static operand if possible. - if (lhsInt == 0) return lhs; - if (lhsInt == 1) return rhs; - - lhs = builder.create(loc, lhsInt).getResult(); - } - - // Multiply. - return builder.create(loc, lhs.get(), rhs.get()) - .getResult(); -} - -OpFoldResult addOperandsOrIntegers(OpBuilder &builder, Location loc, - OpFoldResult lhs, OpFoldResult rhs) { - // Both operands are static. - if (lhs.is() && rhs.is()) { - return builder.getI64IntegerAttr( - lhs.get().cast().getInt() + - rhs.get().cast().getInt()); - } - - // Exploit commutativity and move static operand to the left (if any). - if (rhs.is()) std::swap(lhs, rhs); - - // Create constant if needed. - if (lhs.is()) { - int64_t lhsInt = lhs.get().cast().getInt(); - - // Exploit static operand if possible. - if (lhsInt == 0) return rhs; - - lhs = builder.create(loc, lhsInt).getResult(); - } - - // Add. - return builder.create(loc, lhs.get(), rhs.get()) - .getResult(); -} - -// Compose offsets with newOffset = supersetOffset + supersetStride * offset. -SmallVector composeOffsets( - const llvm::SmallVectorImpl &supersetOffsets, - const llvm::SmallVectorImpl &supersetStrides, - const llvm::SmallVectorImpl &offsets, Location loc, - OpBuilder &builder) { - SmallVector composedOffsets; - for (auto it : llvm::zip(supersetOffsets, supersetStrides, offsets)) { - composedOffsets.push_back(addOperandsOrIntegers( - builder, loc, std::get<0>(it), - multiplyOperandsOrIntegers(builder, loc, std::get<1>(it), - std::get<2>(it)))); - } - return composedOffsets; -} - -// Compose strides with newStride = supersetStride * stride. 
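// A plain numeric sketch (all numbers invented, not part of the original
// code) of the composition rules that composeOffsets above and composeStrides
// below implement when a tile nested inside another tile is flattened into a
// single tile:
//   newOffset = supersetOffset + supersetStride * offset
//   newStride = supersetStride * stride
#include <cassert>

int main() {
  // The outer (superset) tile starts at element 8 with stride 2; the inner
  // tile selects element 3 of the outer tile with stride 4.
  int supersetOffset = 8, supersetStride = 2;
  int offset = 3, stride = 4;

  int newOffset = supersetOffset + supersetStride * offset;  // 8 + 2 * 3 = 14
  int newStride = supersetStride * stride;                   // 2 * 4 = 8

  assert(newOffset == 14);
  assert(newStride == 8);
  return 0;
}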
-SmallVector composeStrides( - OpBuilder &builder, Location loc, - const llvm::SmallVectorImpl &supersetStrides, - const llvm::SmallVectorImpl &strides) { - SmallVector composedStrides; - for (auto it : llvm::zip(supersetStrides, strides)) { - composedStrides.push_back(multiplyOperandsOrIntegers( - builder, loc, std::get<0>(it), std::get<1>(it))); - } - return composedStrides; -} - -} // namespace - -//===----------------------------------------------------------------------===// -// SetYieldOp -//===----------------------------------------------------------------------===// - -using AccumulatorRegionBuilderFn = - function_ref; - -void SetYieldOp::build(OpBuilder &builder, OperationState &result) { - build(builder, result, llvm::None, llvm::None, llvm::None); -} - -void SetYieldOp::build(OpBuilder &builder, OperationState &result, - ValueRange srcs, ValueRange dsts, ValueRange sets) { - SmallVector accumulatorFlags(srcs.size(), false); - build(builder, result, srcs, dsts, sets, - builder.getBoolArrayAttr(accumulatorFlags), llvm::None); -} - -void SetYieldOp::build( - OpBuilder &builder, OperationState &result, ValueRange srcs, - ValueRange dsts, ValueRange sets, ArrayAttr accumulatorFlags, - ArrayRef accumulatorBuilderFns) { - assert(dsts.size() == srcs.size() && - "`dsts` and `srcs` should have the same size"); - assert(sets.size() == srcs.size() && - "`sets` and `srcs` should have the same size"); - assert(accumulatorFlags.size() == srcs.size() && - "`accumulatorFlags` and `srcs` should have the same size"); - - auto accumulatorCount = llvm::count_if(accumulatorFlags, [](Attribute attr) { - return attr.cast().getValue(); - }); - (void)accumulatorCount; - assert(accumulatorCount == - static_cast(accumulatorBuilderFns.size()) && - "the number of flags set in `accumulatorFlags` attribute should be " - "equal to the number of `accumulatorBuilderFns`"); - - result.addOperands(srcs); - result.addOperands(dsts); - result.addOperands(sets); - result.addAttribute(SetYieldOp::getAccumulatorFlagsAttrName(result.name), - accumulatorFlags); - - const auto *builderFnIt = accumulatorBuilderFns.begin(); - for (auto item : llvm::zip(srcs, accumulatorFlags)) { - Value src = std::get<0>(item); - auto accumulatorFlag = std::get<1>(item).cast(); - - if (!accumulatorFlag.getValue()) continue; - Region *region = result.addRegion(); - OpBuilder::InsertionGuard g(builder); - SmallVector argTypes(2, src.getType()); - builder.createBlock(region); - Block &bodyBlock = region->front(); - bodyBlock.addArguments(argTypes, {result.location, result.location}); - - builder.setInsertionPointToStart(&bodyBlock); - (*builderFnIt)(builder, result.location, bodyBlock.getArgument(0), - bodyBlock.getArgument(1)); - ++builderFnIt; - } -} - -LogicalResult SetYieldOp::verify() { - for (const auto [dst, src, set] : - llvm::zip(getDsts(), getSrcs(), getSets())) { - if (failed(verifyCompatibleExtractedSubset(getOperation(), - dst.getType().cast(), - src.getType(), set.getType()))) - return failure(); - } - auto accumulatorCount = llvm::count_if( - getAccumulatorFlags(), - [](Attribute attr) { return attr.cast().getValue(); }); - if (accumulatorCount != static_cast(getAccumulators().size())) - return emitOpError("expected the number of accumulator regions ") - << getAccumulators().size() - << " to match the number of set accumulator flags " - << accumulatorCount; - - auto *regionIt = getAccumulators().begin(); - for (auto item : llvm::zip(getSrcs(), getAccumulatorFlags())) { - Type srcType = std::get<0>(item).getType(); - BoolAttr 
accumulatorFlag = std::get<1>(item).cast(); - if (!accumulatorFlag.getValue()) continue; - - Block &block = regionIt->front(); - if (block.getArgumentTypes() != SmallVector{srcType, srcType}) - return emitOpError() - << "expected accumulator region to have 2 arguments of type " - << srcType; - ++regionIt; - } - return success(); -} - -void SetYieldOp::print(OpAsmPrinter &p) { - p.printOptionalAttrDict(getOperation()->getAttrs(), /*elidedAttrs = */ - {getAccumulatorFlagsAttrName().str()}); - - auto *regionIt = getOperation()->getRegions().begin(); - for (auto &en : llvm::enumerate( - llvm::zip(getSrcs(), getDsts(), getSets(), getAccumulatorFlags()))) { - if (en.index() > 0) { - p << ','; - p.printNewline(); - } - Value src = std::get<0>(en.value()); - Value dst = std::get<1>(en.value()); - Value set = std::get<2>(en.value()); - auto accumulatorFlag = std::get<3>(en.value()).cast(); - - p << ' ' << src << " into " << dst << '[' << set << ']'; - - if (accumulatorFlag.getValue()) { - auto &block = regionIt->getBlocks().front(); - Value newValue = block.getArgument(0); - Value oldValue = block.getArgument(1); - p << " acc (" << newValue << ", " << oldValue << ": " - << oldValue.getType() << ") "; - - p.printRegion(*regionIt, false); - ++regionIt; - } - - p << " : " << src.getType() << " into " << dst.getType() << '[' - << set.getType() << ']'; - } -} - -ParseResult SetYieldOp::parse(OpAsmParser &parser, OperationState &result) { - if (parser.parseOptionalAttrDict(result.attributes)) return failure(); - - SmallVector accumulatorFlags; - SmallVector srcs, dsts, sets; - SmallVector srcTypes, dstTypes, setTypes; - - auto parseElt = [&]() -> ParseResult { - OpAsmParser::UnresolvedOperand src; - auto parseResult = parser.parseOptionalOperand(src); - - if (!parseResult.has_value()) return success(); - srcs.push_back(src); - - if (parser.parseKeyword("into") || - parser.parseOperand(dsts.emplace_back()) || parser.parseLSquare() || - parser.parseOperand(sets.emplace_back()) || parser.parseRSquare()) - return failure(); - - OpBuilder b(parser.getBuilder().getContext()); - bool hasAccumulatorRegion = succeeded(parser.parseOptionalKeyword("acc")); - accumulatorFlags.push_back(hasAccumulatorRegion); - if (hasAccumulatorRegion) { - auto region = std::make_unique(); - OpAsmParser::UnresolvedOperand newValue, oldValue; - Type argType; - if (parser.parseLParen() || parser.parseOperand(newValue) || - parser.parseComma() || parser.parseOperand(oldValue) || - parser.parseColonType(argType) || parser.parseRParen()) - return failure(); - - SmallVector regionArgs; - for (auto value : {newValue, oldValue}) { - auto &arg = regionArgs.emplace_back(); - arg.ssaName = value; - arg.type = argType; - } - - if (parser.parseRegion(*region, regionArgs)) return failure(); - result.addRegion(std::move(region)); - } - if (parser.parseColon() || parser.parseType(srcTypes.emplace_back()) || - parser.parseKeyword("into") || - parser.parseType(dstTypes.emplace_back()) || parser.parseLSquare() || - parser.parseType(setTypes.emplace_back()) || parser.parseRSquare()) - return failure(); - - return success(); - }; - if (parser.parseCommaSeparatedList(AsmParser::Delimiter::None, parseElt)) - return failure(); - - if (parser.resolveOperands(srcs, srcTypes, parser.getCurrentLocation(), - result.operands) || - parser.resolveOperands(dsts, dstTypes, parser.getCurrentLocation(), - result.operands) || - parser.resolveOperands(sets, setTypes, parser.getCurrentLocation(), - result.operands)) - return failure(); - - 
result.addAttribute(SetYieldOp::getAccumulatorFlagsAttrName(result.name), - parser.getBuilder().getBoolArrayAttr(accumulatorFlags)); - return success(); -} - -namespace { -/// Folds UnrealizedConversionCast of TileType into SetYieldOp. -struct FoldTileCastIntoSetYield : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(SetYieldOp op, - PatternRewriter &rewriter) const override { - if (!llvm::any_of(op.getSets(), [](auto set) { - return set.template getDefiningOp() != - nullptr; - })) - return failure(); - SmallVector newSrcs{op.getSrcs()}; - SmallVector newSets{op.getSets()}; - for (auto &&[src, set] : llvm::zip(newSrcs, newSets)) { - auto cast = set.getDefiningOp(); - if (!cast) continue; - set = cast.getOperand(0); - Type castResultType = src.getType(); - if (auto shapedType = dyn_cast(castResultType)) { - castResultType = - shapedType.clone(set.getType().cast().getShape(), - shapedType.getElementType()); - src = rewriter.create(op.getLoc(), castResultType, src); - } - } - rewriter.replaceOpWithNewOp(op, newSrcs, op.getDsts(), newSets); - return success(); - } -}; -} // namespace - -void SetYieldOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add(context); -} - -} // namespace gml_st -} // namespace mlir - -// Generated op classes. -#define GET_OP_CLASSES -#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.cc.inc" diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/CMakeLists.txt deleted file mode 100644 index ea3c6cfb6bd..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/CMakeLists.txt +++ /dev/null @@ -1,134 +0,0 @@ -# Copyright 2022 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -include_directories(BEFORE - ${CMAKE_CURRENT_BINARY_DIR} - ${CMAKE_CURRENT_SOURCE_DIR}) - -set(LLVM_OPTIONAL_SOURCES - bufferizable_op_interface_impl.cc -) - -add_mlir_library(GmlStTilingInterface - tiling_interface.cc - - LINK_LIBS PUBLIC - MLIRIR - MLIRSupport - - DEPENDS - MLIRGmlStTilingInterfaceIncGen -) - -add_mlir_library(GmlStTilingInterfaceImpl - tiling_interface_impl.cc - - LINK_LIBS PUBLIC - GmlStDialect - GmlStTilingInterface - MLIRArithDialect - MLIRAffineDialect - MLIRDestinationStyleOpInterface - MLIRLinalgDialect - MLIRLinalgTransforms - MLIRLinalgUtils - MLIRTensorDialect - MLIRTensorUtils - MLIRIR - MLIRSupport - THLODialect -) - -add_mlir_library(GmlStBufferizableOpInterface - bufferizable_op_interface_impl.cc - - LINK_LIBS PUBLIC - GmlStDialect - MLIRBufferizationDialect - MLIRBufferizationTransforms - MLIRDestinationStyleOpInterface - MLIRIR - MLIRSupport -) - -add_mlir_library(GmlStPasses - collapse_materialize_ops.cc - fusion.cc - gml_st_to_scf.cc - gml_st_to_gpu.cc - linalg_utils.cc - tiling.cc - tiling_cwise.cc - tiling_gpu_warp.cc - tiling_softmax.cc - transform_map_for_cpu.cc - transform_matmul_for_cpu.cc - transform_transpose_for_cpu.cc - transform_scatter_for_cpu.cc - vectorization.cc - - DEPENDS - MLIRGmlStPassIncGen - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - GmlStTilingInterface - GmlStTilingInterfaceImpl - MLIRDestinationStyleOpInterface - MhloDialect - MLIRDialectUtils - MLIRAffineDialect - MLIRArithDialect - MLIRFuncDialect - MLIRGPUOps - MLIRIR - MLIRLinalgDialect - MLIRLinalgTransforms - MLIRPass - MLIRSupport - MLIRVectorDialect -) - -add_mlir_library(GmlStTransforms - transforms.cc - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - GmlStDialect - MLIRAffineDialect - MLIRDialectUtils - MLIRIR - MLIRSCFUtils -) - -add_mlir_library(GmlStTestPasses - test_passes.cc - - DEPENDS - MLIRGmlStTestPassIncGen - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - GmlStBufferizableOpInterface - GmlStDialect - GmlStTransforms - MLIRPass - MLIRTransforms -) diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/collapse_materialize_ops.cc b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/collapse_materialize_ops.cc deleted file mode 100644 index 4f1e9a0a07d..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/collapse_materialize_ops.cc +++ /dev/null @@ -1,183 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include -#include -#include - -#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h" -#include "mlir-hlo/Dialect/gml_st/transforms/passes.h" -#include "mlir-hlo/Dialect/gml_st/transforms/rewriters.h" -#include "mlir-hlo/Dialect/gml_st/transforms/transforms.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -namespace mlir { -namespace gml_st { -namespace { - -#define GEN_PASS_DEF_COLLAPSEMATERIALIZEOPSPASS -#include "mlir-hlo/Dialect/gml_st/transforms/passes.h.inc" -OpFoldResult multiplyOperandsOrIntegers(PatternRewriter& rewriter, Location loc, - OpFoldResult lhs, OpFoldResult rhs) { - // Both operands are static. - if (lhs.is() && rhs.is()) { - return rewriter.getI64IntegerAttr( - lhs.get().cast().getInt() * - rhs.get().cast().getInt()); - } - - // Exploit commutativity and move static operand to the left (if any). - if (rhs.is()) std::swap(lhs, rhs); - - // Create constant if needed. - if (lhs.is()) { - int64_t lhsInt = lhs.get().cast().getInt(); - - // Exploit static operand if possible. - if (lhsInt == 0) return lhs; - if (lhsInt == 1) return rhs; - - lhs = rewriter.create(loc, lhsInt).getResult(); - } - - // Multiply. - return rewriter.create(loc, lhs.get(), rhs.get()) - .getResult(); -} - -OpFoldResult addOperandsOrIntegers(PatternRewriter& rewriter, Location loc, - OpFoldResult lhs, OpFoldResult rhs) { - // Both operands are static. - if (lhs.is() && rhs.is()) { - return rewriter.getI64IntegerAttr( - lhs.get().cast().getInt() + - rhs.get().cast().getInt()); - } - - // Exploit commutativity and move static operand to the left (if any). - if (rhs.is()) std::swap(lhs, rhs); - - // Create constant if needed. - if (lhs.is()) { - int64_t lhsInt = lhs.get().cast().getInt(); - - // Exploit static operand if possible. - if (lhsInt == 0) return rhs; - - lhs = rewriter.create(loc, lhsInt).getResult(); - } - - // Add. - return rewriter.create(loc, lhs.get(), rhs.get()) - .getResult(); -} - -// Compose offsets with newOffset = supersetOffset + supersetStride * offset. -SmallVector composeOffsets( - const llvm::SmallVectorImpl& supersetOffsets, - const llvm::SmallVectorImpl& supersetStrides, - const llvm::SmallVectorImpl& offsets, Location loc, - PatternRewriter& rewriter) { - SmallVector composedOffsets; - for (auto it : llvm::zip(supersetOffsets, supersetStrides, offsets)) { - composedOffsets.push_back(addOperandsOrIntegers( - rewriter, loc, std::get<0>(it), - multiplyOperandsOrIntegers(rewriter, loc, std::get<1>(it), - std::get<2>(it)))); - } - return composedOffsets; -} - -// Compose strides with newStride = supersetStride * stride. -SmallVector composeStrides( - PatternRewriter& rewriter, Location loc, - const llvm::SmallVectorImpl& supersetStrides, - const llvm::SmallVectorImpl& strides) { - SmallVector composedStrides; - for (auto it : llvm::zip(supersetStrides, strides)) { - composedStrides.push_back(multiplyOperandsOrIntegers( - rewriter, loc, std::get<0>(it), std::get<1>(it))); - } - return composedStrides; -} - -// Collapse materialize operations with nested tile chains t1, t2, ..., tn, and -// u1, u2, ..., un. A materialize op of the form ... -// `materialize(materialize(tensor2, t2), t1) -// ... is collapsed as ... 
-// `materialize(t2, composed_tile(t1, t2)) -struct CollapseMaterializeOpPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(MaterializeOp op, - PatternRewriter& rewriter) const override { - auto tileOp = op.getSet().getDefiningOp(); - if (!tileOp) return failure(); - - auto producerMaterializeOp = op.getSource().getDefiningOp(); - if (!producerMaterializeOp) return failure(); - - auto producerTileOp = - producerMaterializeOp.getSet().getDefiningOp(); - if (!producerTileOp) return failure(); - - // Compose tileOp and producerTileOp. - auto loc = op.getLoc(); - auto producerStrides = producerTileOp.getMixedStrides(); - auto composedOffsets = - composeOffsets(producerTileOp.getMixedOffsets(), producerStrides, - tileOp.getMixedOffsets(), loc, rewriter); - auto composedStrides = composeStrides(rewriter, loc, producerStrides, - tileOp.getMixedStrides()); - auto composedTileOp = rewriter.create( - loc, composedOffsets, tileOp.getMixedSizes(), composedStrides); - - rewriter.replaceOpWithNewOp( - op, producerMaterializeOp.getSource(), composedTileOp); - return success(); - } -}; - -struct CollapseMaterializeOpsPass - : public impl::CollapseMaterializeOpsPassBase { - void runOnOperation() override { - MLIRContext* ctx = &getContext(); - RewritePatternSet patterns(ctx); - populateCollapseMaterializeOpsPatterns(ctx, &patterns); - - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { - return signalPassFailure(); - } - } -}; - -} // namespace - -void populateCollapseMaterializeOpsPatterns(MLIRContext* ctx, - RewritePatternSet* patterns) { - patterns->add(ctx); -} - -std::unique_ptr> -createCollapseMaterializeOpsPass() { - return std::make_unique(); -} - -} // namespace gml_st -} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/fusion.cc b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/fusion.cc deleted file mode 100644 index 7fcdf3b7e49..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/fusion.cc +++ /dev/null @@ -1,302 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "mlir-hlo/Dialect/gml_st/transforms/fusion.h" - -#include -#include - -#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h" -#include "mlir-hlo/Dialect/gml_st/transforms/passes.h" -#include "mlir-hlo/Dialect/gml_st/transforms/rewriters.h" -#include "mlir-hlo/Dialect/gml_st/transforms/tiling_interface.h" -#include "mlir-hlo/Dialect/gml_st/transforms/tiling_interface_impl.h" -#include "mlir-hlo/Dialect/gml_st/transforms/transforms.h" -#include "mlir-hlo/Dialect/thlo/IR/thlo_ops.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Shape/IR/Shape.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -namespace mlir { -namespace gml_st { -namespace { - -#define GEN_PASS_DEF_FUSIONPASS -#include "mlir-hlo/Dialect/gml_st/transforms/passes.h.inc" - -// TODO(frgossen): Move this to the shape reification pass. -struct DimOpFissionPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tensor::ExtractOp extract, - PatternRewriter& rewriter) const override { - auto shapeDef = llvm::dyn_cast_or_null( - extract.getTensor().getDefiningOp()); - if (!shapeDef || extract.getIndices().size() != 1) return failure(); - rewriter.replaceOpWithNewOp(extract, shapeDef.getArg(), - extract.getIndices().front()); - return success(); - } -}; - -// TODO(frgossen): Implement this through the shape reification interface and -// move this pattern to the shape reification pass. -struct DimOpReificationPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tensor::DimOp op, - PatternRewriter& rewriter) const override { - Operation* def = op.getSource().getDefiningOp(); - if (!def) return failure(); - - // TODO(pifon): Split this pattern into many. - // Case MaterializeOp. - if (auto materializeOp = llvm::dyn_cast(def)) { - assert(materializeOp->getNumResults() == 1 && "assume single result"); - auto dimConstantIndex = op.getConstantIndex(); - if (!dimConstantIndex.has_value()) return failure(); - - auto tileOp = materializeOp.getSet().getDefiningOp(); - if (!tileOp) return failure(); - rewriter.replaceOp(op, tileOp.getSizes()[*dimConstantIndex]); - return success(); - } - // Case GenericOp. - if (auto genericOp = llvm::dyn_cast(def)) { - if (genericOp.getNumResults() != 1 || !genericOp.hasTensorSemantics()) { - return failure(); - } - Value outputOperand = genericOp.getDpsInitOperand(0)->get(); - rewriter.replaceOpWithNewOp(op, outputOperand, - op.getIndex()); - return success(); - } - - // Case EmptyOp. - if (auto emptyTensorOp = llvm::dyn_cast(def)) { - if (auto indexConstantOp = llvm::dyn_cast_or_null( - op.getIndex().getDefiningOp())) { - int64_t idx = - indexConstantOp.getValue().dyn_cast().getInt(); - OpFoldResult dim = emptyTensorOp.getMixedSizes()[idx]; - Value dimValue; - if (dim.is()) { - dimValue = dim.get(); - } else { - assert(dim.is() && "expected Value or Attribute"); - int64_t dimInt = dim.get().cast().getInt(); - dimValue = - rewriter.create(op.getLoc(), dimInt); - } - assert(dimValue); - rewriter.replaceOp(op, ValueRange{dimValue}); - return success(); - } - } - - // Case ConcatenateOp. 
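// Note: as with the cases above, both ops handled next are destination-style,
// so the result shape matches the init operand's shape and the tensor.dim
// query can safely be redirected to the init tensor.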
- if (auto concat = llvm::dyn_cast(def)) { - rewriter.replaceOpWithNewOp(op, concat.getInit(), - op.getIndex()); - return success(); - } - - // Case DynamicBroadcastInDimOp. - if (auto bcast = llvm::dyn_cast(def)) { - rewriter.replaceOpWithNewOp(op, bcast.getInit(), - op.getIndex()); - return success(); - } - - return failure(); - } -}; - -// Finds the `dst` operand of `setYieldOp` that matches `currentDst` and then -// replaces it with the corresponding `init` operand of the defining op of -// `currentDst`. At the moment this update is restricted to `linalg.fill` only, -// later it can be relaxed to support fusion of transposes into -// `gml_st.parallel`. -LogicalResult replaceSetYieldDstByProducerInit(SetYieldOp setYieldOp, - Value currentDst) { - auto fillOp = currentDst.getDefiningOp(); - if (!fillOp) return failure(); - - Value init = fillOp.getDpsInitOperand(0)->get(); - for (OpOperand& operand : setYieldOp->getOpOperands()) { - if (operand.get() != currentDst) continue; - operand.set(init); - return success(); - } - return failure(); -} - -class FusionPattern : public OpRewritePattern { - public: - FusionPattern(MLIRContext* context, - function_ref filterFn, - mlir::PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), filterFn(filterFn) {} - - LogicalResult matchAndRewrite(MaterializeOp materializeOp, - PatternRewriter& rewriter) const override { - assert(filterFn && "expect filter function"); - if (failed(filterFn(materializeOp))) - return rewriter.notifyMatchFailure(materializeOp, "filtered"); - - Location loc = materializeOp.getLoc(); - FailureOr fusedOr = createFusedOp(rewriter, materializeOp); - if (failed(fusedOr)) return failure(); // Match failure aleady notified. - - // Insert cast if needed. - Value fused = *fusedOr; - if (fused.getType() != materializeOp.getType()) { - if (!materializeOp.getType().isa()) { - // the result should be a scalar, insert tensor.extract - auto tensorType = fused.getType().dyn_cast(); - assert(tensorType && tensorType.getNumElements() == 1 && - "resulting tensor should contain a single element"); - auto zero = rewriter.create(loc, 0); - fused = rewriter.create( - loc, fused, SmallVector(tensorType.getRank(), zero)); - } else { - // The result should be a tensor, cast it to the correct shape - fused = rewriter.create(loc, materializeOp.getType(), - fused); - } - } - - // Update destination argument of SetYieldOp if we are fusing into the - // output tile. 
- if (auto parallelOp = dyn_cast(materializeOp->getParentOp())) { - SetYieldOp setYieldOp = parallelOp.getTerminator(); - Value src = materializeOp.getSource(); - if (llvm::is_contained(src.getUsers(), setYieldOp)) { - if (failed(replaceSetYieldDstByProducerInit(setYieldOp, src))) - return failure(); - } - } - - rewriter.replaceOp(materializeOp, fused); - return success(); - } - - private: - function_ref filterFn; -}; - -struct FusionPass : public impl::FusionPassBase { - FusionPass(StringRef producer, StringRef consumer) { - this->producerLabel = producer.str(); - this->consumerLabel = consumer.str(); - } - - void getDependentDialects(DialectRegistry& registry) const final { - registry.insert(); - registerGmlStTilingInterfaceExternalModels(registry); - } - - void runOnOperation() final { - MLIRContext* ctx = &getContext(); - - auto filterFn = [&](Operation* op) { - auto materializeOp = cast(op); - Operation* producerOp = materializeOp.getSource().getDefiningOp(); - if (!producerOp || (!producerLabel.empty() && - !hasMatchingLabel(producerOp, producerLabel))) { - return failure(); - } - - Operation* consumerOp = nullptr; - if (!consumerLabel.empty()) { - for (Operation* user : materializeOp.getResult().getUsers()) { - if (hasMatchingLabel(user, consumerLabel)) { - consumerOp = user; - break; - } - } - return success(consumerOp != nullptr); - } - - return success(); - }; - - // Populate patterns. - RewritePatternSet patterns(ctx); - populateFusionPatterns(ctx, filterFn, &patterns); - - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { - return signalPassFailure(); - } - } -}; - -} // namespace - -FailureOr createFusedOp(PatternRewriter& rewriter, - MaterializeOp materializeOp) { - auto tileableOp = materializeOp.getSource().getDefiningOp(); - if (!tileableOp) { - return rewriter.notifyMatchFailure( - materializeOp, "expected source to be defined by tiling interface op "); - } - - auto tileOp = materializeOp.getSet().getDefiningOp(); - if (!tileOp) { - return rewriter.notifyMatchFailure( - materializeOp, "expected set to be defined by gml_st.tile"); - } - - SmallVector offsets = tileOp.getMixedOffsets(); - SmallVector sizes = tileOp.getMixedSizes(); - - // Tile the producer. - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(materializeOp); - FailureOr tiledProducer = tileableOp.generateResultTileValue( - rewriter, /*resultNumber=*/0, offsets, sizes); - if (failed(tiledProducer)) { - return rewriter.notifyMatchFailure(tileableOp, - "failed to tile the producer"); - } - - return tiledProducer; -} - -void populateFusionPatterns(MLIRContext* ctx, - function_ref filterFn, - RewritePatternSet* patterns) { - patterns->insert(ctx, filterFn); - // clang-format off - patterns->insert< - DimOpFissionPattern, - DimOpReificationPattern>(ctx); - // clang-format on -} - -std::unique_ptr> createFusionPass( - StringRef producer, StringRef consumer) { - return std::make_unique(producer, consumer); -} - -} // namespace gml_st -} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/linalg_utils.cc b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/linalg_utils.cc deleted file mode 100644 index f13e4d097f7..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/linalg_utils.cc +++ /dev/null @@ -1,187 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "mlir-hlo/Dialect/gml_st/transforms/linalg_utils.h" - -namespace mlir { -namespace gml_st { - -namespace { - -bool hasUniqueInputAndOutputMaps(linalg::GenericOp genericOp, - AffineMap &inputMap, AffineMap &outputMap) { - if (genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1) { - return false; - } - inputMap = genericOp.getIndexingMapsArray().front(); - outputMap = genericOp.getIndexingMapsArray().back(); - return true; -} - -// Checks if an affine map maps all dimensions in sequence, skipping a unique -// dimension. This can be the output map of a reduction, or the input map of a -// bcast. For example: -// - affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> -// - affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)> -// - affine_map<(d0, d1) -> (d0)> -// - affine_map<(d0, d1) -> (d1)> -bool isBcastOrReductionMap(AffineMap map, int64_t &dim) { - const auto *it = map.getResults().begin(); - const auto *end = map.getResults().end(); - auto consumeIotaSeq = [&](int64_t &i) { - while (it != end) { - auto expr = it->dyn_cast(); - if (!expr || expr.getPosition() != i) break; - it++; - i++; - } - }; - int64_t i = 0; - consumeIotaSeq(i); - dim = i++; - consumeIotaSeq(i); - return i == map.getNumDims(); -} - -} // namespace - -bool isSimpleReduction(Operation *op, int64_t *dimension, Value *operand) { - auto genericOp = llvm::dyn_cast_or_null(op); - if (!genericOp || genericOp.getNumDpsInits() != 1) return false; - - // Expect monadic op. - AffineMap inputMap, outputMap; - if (!hasUniqueInputAndOutputMaps(genericOp, inputMap, outputMap)) - return false; - - // Check identity of operand map. - if (!inputMap.isIdentity()) return false; - - // Check that the output map is a reduction: it maps all dimensions in - // seqence, skipping the unique reduction dimension. - int64_t dim; - if (!isBcastOrReductionMap(outputMap, dim)) return false; - - // Check uniqueness of reduction dimension and remaining parallel iterator - // types. - auto iterTys = genericOp.getIteratorTypes(); - for (int i = 0; i < iterTys.size(); i++) { - StringRef expectedTy = i == dim ? getReductionIteratorTypeName() - : getParallelIteratorTypeName(); - StringRef actualTy = - genericOp.getIteratorTypes()[i].cast().getValue(); - if (expectedTy != actualTy) return false; - } - - // Allow for pattern matching the reduction dimension and operand. - if (dimension != nullptr) *dimension = dim; - if (operand != nullptr) *operand = genericOp.getInputs().front(); - - return true; -} - -bool isCwiseGenericOp(Operation *op, int64_t *arity) { - auto genericOp = llvm::dyn_cast_or_null(op); - if (!genericOp || genericOp.getNumDpsInits() != 1) return false; - - // Check all-parallel iterator types. - if (!llvm::all_of(genericOp.getIteratorTypes(), [](Attribute it) { - return it.cast().getValue() == - getParallelIteratorTypeName(); - })) { - return false; - } - - // Check all-identity maps. 
- if (!llvm::all_of(genericOp.getIndexingMapsArray(), - [](AffineMap map) { return map.isIdentity(); })) { - return false; - } - - // Allow for pattern matching the arity. - if (arity != nullptr) *arity = genericOp.getNumDpsInputs(); - return true; -} - -bool isUnaryCwiseGenericOp(Operation *op) { - int64_t arity; - return isCwiseGenericOp(op, &arity) && arity == 1; -} - -bool isSimpleBcast(Operation *op, int64_t *dimension, Value *operand) { - auto genericOp = llvm::dyn_cast_or_null(op); - if (!genericOp) return false; - - // Expect monadic op. - AffineMap inputMap, outputMap; - if (!hasUniqueInputAndOutputMaps(genericOp, inputMap, outputMap)) - return false; - - // Check all-parallel iterator types. - if (!llvm::all_of(genericOp.getIteratorTypes(), [](Attribute it) { - return it.cast().getValue() == - getParallelIteratorTypeName(); - })) { - return false; - } - - // Check that the operand map is a degenerate bcast: it maps all dimensions in - // seqence, skipping the unique bcast dimension. - int64_t dim; - if (!isBcastOrReductionMap(inputMap, dim)) return false; - - // Check that the output map is the identity. - if (!outputMap.isIdentity()) return false; - - // Allow for pattern matching the reduction dimension and operand. - if (dimension != nullptr) *dimension = dim; - if (operand != nullptr) *operand = genericOp.getInputs().front(); - - return true; -} - -bool isSimpleBcastReduction(Operation *op, int64_t *dimension, - SimpleBcastReduction *chain) { - // Match bcast. - int64_t bcastDim; - Value bcastOperand; - if (!isSimpleBcast(op, &bcastDim, &bcastOperand)) { - return false; - } - - // Match reduction. - Operation *reduction = bcastOperand.getDefiningOp(); - int64_t reductionDim; - Value operand; - if (!isSimpleReduction(reduction, &reductionDim, &operand)) { - return false; - } - - // Check that bcast and reduction dimensions match. - if (bcastDim != reductionDim) return false; - - // Allow for pattern matching the reduction dimension and operation chain. - if (dimension != nullptr) *dimension = bcastDim; - if (chain != nullptr) { - chain->bcast = op; - chain->operand = operand; - chain->operand = operand; - } - - return true; -} - -} // namespace gml_st -} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/test_passes.cc b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/test_passes.cc deleted file mode 100644 index 583cc64e67e..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/test_passes.cc +++ /dev/null @@ -1,208 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "mlir-hlo/Dialect/gml_st/transforms/test_passes.h" - -#include -#include - -#include "mlir-hlo/Dialect/gml_st/transforms/bufferizable_op_interface_impl.h" -#include "mlir-hlo/Dialect/gml_st/transforms/transforms.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" -#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" -#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -namespace mlir { -namespace gml_st { -namespace { - -#define GEN_PASS_DEF_TESTGMLSTBUFFERIZATION -#define GEN_PASS_DEF_TESTGMLSTLOOPPEELING -#define GEN_PASS_DEF_TESTGMLSTLOOPTILING -#include "mlir-hlo/Dialect/gml_st/transforms/test_passes.h.inc" - -static constexpr char kPeeledLoopsLabel[] = "__peeled_loops__"; -static constexpr char kPartialIterationLabel[] = "__partial_iteration__"; - -/// Peel LoopOps, i.e., split them into two loops: One loop where the -/// `idx`-th loop contains only "full" iterations and a second loop for the -/// remaining partial iteration (if any). -struct TiledLoopPeelingPattern : public OpRewritePattern { - TiledLoopPeelingPattern(MLIRContext *ctx, int64_t idx, bool skipPartial) - : OpRewritePattern(ctx), idx(idx), skipPartial(skipPartial) {} - - LogicalResult matchAndRewrite(LoopOp loopOp, - PatternRewriter &rewriter) const override { - SmallVector peeledLoops; - if (loopOp->hasAttr(kPeeledLoopsLabel)) { - auto attr = loopOp->getAttr(kPeeledLoopsLabel).cast(); - peeledLoops = - llvm::to_vector<4>(llvm::map_range(attr, [](Attribute attr) { - return attr.cast().getInt(); - })); - // Check if the loop was already peeled. - if (llvm::find(peeledLoops, idx) != peeledLoops.end()) return failure(); - } - if (skipPartial && loopOp->hasAttr(kPartialIterationLabel)) - // No peeling of loop nests with a partial iteration. - return failure(); - - if (static_cast(loopOp.getIteratorTypes().size()) <= idx) - return failure(); - - // Peel loop and canonicalize. - LoopOp result; - if (failed(peelAndCanonicalizeGmlStLoop(rewriter, loopOp, idx, result))) - return failure(); - - // Apply label, so that the same loop is not rewritten a second time. - peeledLoops.push_back(idx); - rewriter.updateRootInPlace(loopOp, [&]() { - loopOp->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops)); - }); - result->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops)); - result->setAttr(kPartialIterationLabel, rewriter.getUnitAttr()); - - return success(); - } - - /// Index of loop to peel. - int64_t idx; - - /// If set to true, do not peel LoopOps with a partial iteration. - bool skipPartial; -}; - -class TestGmlStLoopPeelingPass - : public impl::TestGmlStLoopPeelingBase { - void runOnOperation() final { - auto funcOp = getOperation(); - MLIRContext *ctx = funcOp.getContext(); - RewritePatternSet patterns(ctx); - for (unsigned idx : dims) - patterns.add(ctx, idx, skip_partial); - - (void)(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))); - - // Drop the markers. 
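// Note: the __peeled_loops__ / __partial_iteration__ attributes are internal
// bookkeeping used by TiledLoopPeelingPattern to avoid peeling the same loop
// twice; stripping them here keeps the markers from leaking into the output IR.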
- funcOp.walk([](LoopOp op) { - op->removeAttr(kPeeledLoopsLabel); - op->removeAttr(kPartialIterationLabel); - }); - } -}; - -struct LinalgTilingPattern - : public OpInterfaceRewritePattern { - LinalgTilingPattern(MLIRContext *context, linalg::LinalgTilingOptions options, - PatternBenefit benefit = 1) - : OpInterfaceRewritePattern(context, benefit), - options(std::move(options)) {} - - LogicalResult matchAndRewrite(linalg::LinalgOp op, - PatternRewriter &rewriter) const override { - if (hasTransformationAttr(op)) return failure(); - - FailureOr res = - gml_st::tileLinalgOp(rewriter, op, options); - if (failed(res)) return failure(); - - setTransformationAttr(rewriter, res->op); - - if (res->tensorResults.empty()) - rewriter.eraseOp(op); - else - rewriter.replaceOp(op, res->tensorResults); - - return success(); - } - - private: - linalg::LinalgTilingOptions options; -}; - -struct TestGmlStLoopTilingPass - : public impl::TestGmlStLoopTilingBase { - TestGmlStLoopTilingPass() = default; - TestGmlStLoopTilingPass(ArrayRef tileSizes, - ArrayRef distributionTypes) { - this->tile_sizes = tileSizes; - this->distribution_types = llvm::to_vector<2>(llvm::map_range( - distributionTypes, [](StringRef ref) { return ref.str(); })); - } - - void runOnOperation() override { - func::FuncOp funcOp = getOperation(); - - auto distTypes = llvm::to_vector<2>(llvm::map_range( - distribution_types, [](std::string &str) { return StringRef(str); })); - auto options = linalg::LinalgTilingOptions() - .setTileSizes(tile_sizes) - .setDistributionTypes(distTypes); - MLIRContext *ctx = funcOp.getContext(); - RewritePatternSet patterns(ctx); - - patterns.add(ctx, options); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); - - funcOp.walk([](linalg::LinalgOp op) { removeTransformationAttr(op); }); - } -}; - -struct TestGmlStBufferizationPass - : public impl::TestGmlStBufferizationBase { - void getDependentDialects(DialectRegistry ®istry) const override { - registry - .insert(); - linalg::registerBufferizableOpInterfaceExternalModels(registry); - gml_st::registerBufferizableOpInterfaceExternalModels(registry); - } - - void runOnOperation() override { - bufferization::OneShotBufferizationOptions opts; - opts.allowReturnAllocs = true; - opts.bufferizeFunctionBoundaries = true; - opts.functionBoundaryTypeConversion = - bufferization::BufferizationOptions::LayoutMapOption::IdentityLayoutMap; - - ModuleOp module = getOperation(); - if (failed(bufferization::runOneShotModuleBufferize(module, opts))) { - signalPassFailure(); - return; - } - } -}; - -} // namespace - -std::unique_ptr> createTestGmlStLoopPeelingPass() { - return std::make_unique(); -} - -std::unique_ptr> createTestGmlStLoopTilingPass() { - return std::make_unique(); -} - -std::unique_ptr> createTestGmlStBufferizationPass() { - return std::make_unique(); -} - -} // namespace gml_st -} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/tiling_interface_impl.cc b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/tiling_interface_impl.cc deleted file mode 100644 index 03c851cfa9b..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/tiling_interface_impl.cc +++ /dev/null @@ -1,182 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "mlir-hlo/Dialect/gml_st/transforms/tiling_interface_impl.h" - -#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h" -#include "mlir-hlo/Dialect/gml_st/transforms/tiling_interface.h" -#include "mlir-hlo/Dialect/thlo/IR/thlo_ops.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Utils/Utils.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tensor/Utils/Utils.h" -#include "mlir/Dialect/Utils/StructuredOpsUtils.h" -#include "mlir/Interfaces/DestinationStyleOpInterface.h" - -namespace mlir { -namespace gml_st { -namespace { - -template -struct ExternalLinalgOpTilingInterface - : public TilingInterface::ExternalModel< - ExternalLinalgOpTilingInterface, LinalgOpTy> { - /// Return the destination operands. - SmallVector getDestinationOperands(Operation *op, OpBuilder &) const { - return cast(op).getDpsInitOperands(); - } - - /// Return the loop iterator type. - SmallVector getLoopIteratorTypes(Operation *op) const { - auto linalgOp = cast(op); - return llvm::to_vector(llvm::map_range( - linalgOp.getIteratorTypesArray(), [](StringRef iteratorType) { - return utils::symbolizeIteratorType(iteratorType).value(); - })); - } - - /// Return the iteration domain range. - SmallVector getIterationDomain(Operation *op, OpBuilder &b) const { - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(op); - Location loc = op->getLoc(); - linalg::LinalgOp linalgOp = cast(op); - SmallVector allShapesSizes = - linalgOp.createFlatListOfOperandDims(b, loc); - AffineMap map = linalgOp.getShapesToLoopsMap(); - - IRRewriter rewriter(b); - return llvm::to_vector( - llvm::map_range(map.getResults(), [&](AffineExpr loopExpr) { - OpFoldResult ofr = makeComposedFoldedAffineApply( - rewriter, loc, loopExpr, allShapesSizes); - return Range{b.getIndexAttr(0), ofr, b.getIndexAttr(1)}; - })); - } - - // Instantiate the tiled implementation of the operation. - TilingInterface getTiledImplementation(Operation *op, OpBuilder &b, - ArrayRef offsets, - ArrayRef sizes) const { - Location loc = op->getLoc(); - linalg::LinalgOp linalgOp = cast(op); - OperandRange valuesToTile = linalgOp->getOperands(); - SmallVector> allSliceParams = - linalg::computeAllSliceParameters(b, loc, linalgOp, valuesToTile, - offsets, sizes, {}, true); - - SmallVector tiledOperands; - for (const auto &[valueToTile, sliceParams] : - llvm::zip(valuesToTile, allSliceParams)) { - // Use the original operand if it is not a ranked tensor. This could be a - // scalar, e.g. for `linalg.fill`. - auto valueToTileTy = - valueToTile.getType().template dyn_cast(); - if (!valueToTileTy) { - tiledOperands.push_back(valueToTile); - continue; - } - - int64_t rank = valueToTileTy.getRank(); - SmallVector valueToTileSizes{ - tensor::getMixedSizes(b, loc, valueToTile)}; - SmallVector zeros(rank, b.getI64IntegerAttr(0)); - SmallVector ones(rank, b.getI64IntegerAttr(1)); - Value set = - sliceParams.has_value() - ? 
b.create(loc, sliceParams->offsets, sliceParams->sizes, - sliceParams->strides) - : b.create(loc, zeros, valueToTileSizes, ones); - - Value materializedTile = b.create(loc, valueToTile, set); - tiledOperands.push_back(materializedTile); - } - - SmallVector resultTensorTypes = llvm::to_vector(llvm::map_range( - linalgOp.getDpsInitOperands(), [&](OpOperand *opOperand) { - return tiledOperands[opOperand->getOperandNumber()].getType(); - })); - - Operation *tiledOp = - linalgOp.clone(b, loc, resultTensorTypes, tiledOperands); - offsetIndices(b, cast(tiledOp), offsets); - - return {tiledOp}; - } - - FailureOr generateResultTileValue(Operation *op, OpBuilder &b, - unsigned resultNumber, - ArrayRef offsets, - ArrayRef sizes) const { - auto linalgOp = cast(op); - - // Check that the indexing map used for the output is a projected - // permutation. This could be relaxed with a more general approach that can - // map the offsets and sizes from the result to iteration space tiles - // (filling in full extent for dimensions not used to access the result). - AffineMap indexingMap = - linalgOp.getIndexingMapMatchingResult(op->getResult(resultNumber)); - if (!indexingMap.isProjectedPermutation()) { - return op->emitOpError( - "unhandled tiled implementation generation when result is not " - "accessed using a permuted projection"); - } - - auto numLoops = linalgOp.getNumLoops(); - auto tilingInterfaceOp = cast(op); - SmallVector iterationTileOffsets(numLoops), - iterationTileSizes(numLoops); - if (!indexingMap.isPermutation()) { - SmallVector iterationDomain = - tilingInterfaceOp.getIterationDomain(b); - for (const auto &range : llvm::enumerate(iterationDomain)) { - iterationTileOffsets[range.index()] = range.value().offset; - iterationTileSizes[range.index()] = range.value().size; - } - } - for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) { - unsigned dimPosition = - resultExpr.value().cast().getPosition(); - iterationTileOffsets[dimPosition] = offsets[resultExpr.index()]; - iterationTileSizes[dimPosition] = sizes[resultExpr.index()]; - } - - TilingInterface tiledOp = tilingInterfaceOp.getTiledImplementation( - b, iterationTileOffsets, iterationTileSizes); - - return tiledOp->getResult(resultNumber); - } -}; - -} // namespace - -void registerGmlStTilingInterfaceExternalModels(DialectRegistry ®istry) { - registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *) { - linalg::FillOp::attachInterface< - ExternalLinalgOpTilingInterface>(*ctx); - linalg::GenericOp::attachInterface< - ExternalLinalgOpTilingInterface>(*ctx); - linalg::MapOp::attachInterface< - ExternalLinalgOpTilingInterface>(*ctx); - linalg::MatmulOp::attachInterface< - ExternalLinalgOpTilingInterface>(*ctx); - linalg::TransposeOp::attachInterface< - ExternalLinalgOpTilingInterface>(*ctx); - }); -} - -} // namespace gml_st -} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/transform_matmul_for_cpu.cc b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/transform_matmul_for_cpu.cc deleted file mode 100644 index aec7d3be33b..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/transform_matmul_for_cpu.cc +++ /dev/null @@ -1,116 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h" -#include "mlir-hlo/Dialect/gml_st/transforms/passes.h" -#include "mlir-hlo/Dialect/gml_st/transforms/tiling.h" -#include "mlir-hlo/Dialect/gml_st/transforms/tiling_interface_impl.h" -#include "mlir-hlo/Dialect/gml_st/transforms/transforms.h" -#include "mlir-hlo/Dialect/thlo/IR/thlo_ops.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -namespace mlir::gml_st { -namespace { - -#define GEN_PASS_DEF_TRANSFORMMATMULFORCPUPASS -#include "mlir-hlo/Dialect/gml_st/transforms/passes.h.inc" - -struct TransformMatmulForCpuPass - : public impl::TransformMatmulForCpuPassBase { - TransformMatmulForCpuPass() = default; - explicit TransformMatmulForCpuPass( - llvm::ArrayRef matmulTileSizes) { - tileSizes = matmulTileSizes; - } - - void getDependentDialects(DialectRegistry ®istry) const final { - registry.insert(); - mlir::gml_st::registerGmlStTilingInterfaceExternalModels(registry); - } - - void runOnOperation() override { - func::FuncOp f = getOperation(); - MLIRContext *ctx = &getContext(); - - mlir::gml_st::TilingOptions opts; - - if ((*tileSizes).empty()) { - tileSizes = {2, 2, 2}; - } - - assert(tileSizes.size() == 3 && - "Tiling sizes for MatMul should have 3 elements"); - - auto filter_fn = [&](Operation *op) { - return success(isa(op)); - }; - - /////////////////////////////// - // Tiling parallel dimensions - opts.setTileSizeComputationFn({(*tileSizes)[0], (*tileSizes)[1], 0}); - - RewritePatternSet patterns(ctx); - populateTilingPatterns(ctx, filter_fn, opts, &patterns); - - if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) { - return signalPassFailure(); - } - - // Ensure we drop the marker in the end. - f.walk([](linalg::LinalgOp op) { gml_st::removeTransformationAttr(op); }); - - /////////////////////////////// - // Tiling reduction dimension - opts.setTileSizeComputationFn({0, 0, (*tileSizes).back()}); - opts.distribute = false; - - RewritePatternSet newpatterns(ctx); - populateTilingPatterns(ctx, filter_fn, opts, &newpatterns); - - if (failed(applyPatternsAndFoldGreedily(f, std::move(newpatterns)))) { - return signalPassFailure(); - } - - // Ensure we drop the marker in the end. 
- f.walk([](linalg::LinalgOp op) { gml_st::removeTransformationAttr(op); }); - } -}; - -} // namespace -} // namespace mlir::gml_st - -namespace mlir::gml_st { - -std::unique_ptr> -createTransformMatmulForCpuPass() { - return std::make_unique(); -} - -std::unique_ptr> -createTransformMatmulForCpuPass(llvm::ArrayRef matmulTileSizes) { - return std::make_unique( - matmulTileSizes); -} - -} // namespace mlir::gml_st diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/transforms.cc b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/transforms.cc deleted file mode 100644 index d40c6682f1a..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/transforms.cc +++ /dev/null @@ -1,396 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "mlir-hlo/Dialect/gml_st/transforms/transforms.h" - -#include -#include - -#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Arith/Utils/Utils.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h" -#include "mlir/Dialect/Tensor/Utils/Utils.h" -#include "mlir/Dialect/Utils/StaticValueUtils.h" -#include "mlir/Dialect/Utils/StructuredOpsUtils.h" -#include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/IR/Matchers.h" - -namespace mlir { -namespace gml_st { -bool isZero(Value v) { return matchPattern(v, m_Zero()); } -namespace { - -bool isATensor(Type t) { return t.isa(); } - -/// Return true if the given op has only tensor-typed results or operands. -bool hasTensorSemantics(Operation *op) { - return llvm::all_of(op->getResultTypes(), isATensor) || - llvm::all_of(op->getOperandTypes(), isATensor); -} - -/// Rewrite a LoopOp/ParallelOp/ForOp with bounds/step that potentially do not -/// divide evenly into two LoopOp/ParallelOp/ForOps: One where the step divides -/// the iteration space evenly, followed another one for the last (partial) -/// iteration (if any). This function only rewrites the `idx`-th loop of the -/// loop nest represented by the LoopOp/ParallelOp/ForOp. To peel the entire -/// loop nest, this function must be called multiple times. -/// -/// This function rewrites the given LoopOp/ParallelOp/ForOp in-place and -/// creates a new LoopOp/ParallelOp/ForOp for the last iteration. It replaces -/// all uses of the original LoopOp/ParallelOp/ForOp with the results of the -/// newly generated one. -/// -/// The newly generated LoopOp/ParallelOp/ForOp is returned via `result`. The -/// boundary at which the loop is split (new upper bound) is returned via -/// `splitBound`. The return value indicates whether the -/// LoopOp/ParallelOp/ForOp was rewritten or not. 
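/// For illustration (values chosen here for exposition, not taken from any
/// caller): with lb = 0, ub = 10, step = 4, the split bound
/// %ub - (%ub - %lb) mod %step computed below is 10 - (10 mod 4) = 8, so the
/// rewritten loop covers [0, 8) in full steps and the cloned remainder loop
/// covers [8, 10).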
-template -LogicalResult peelLoop(RewriterBase &b, LoopTy loopOp, int64_t idx, - LoopTy &result, Value &splitBound) { - if (!hasTensorSemantics(loopOp)) return failure(); - - Value lb = loopOp.getLowerBound()[idx], ub = loopOp.getUpperBound()[idx], - step = loopOp.getStep()[idx]; - auto ubInt = getConstantIntValue(ub); - - auto loc = loopOp.getLoc(); - AffineExpr exprLb, exprUb, exprStep; - bindSymbols(b.getContext(), exprLb, exprUb, exprStep); - // New upper bound: %ub - (%ub - %lb) mod %step - auto modMap = AffineMap::get(0, 3, {exprUb - ((exprUb - exprLb) % exprStep)}); - SmallVector operands{lb, ub, step}; - canonicalizeMapAndOperands(&modMap, &operands); - modMap = simplifyAffineMap(modMap); - RewriterBase::InsertionGuard guard(b); - b.setInsertionPoint(loopOp); - splitBound = b.createOrFold(loc, modMap, operands); - // No specialization necessary if step already divides upper bound evenly. - if (splitBound == ub || (ubInt && ubInt == getConstantIntValue(splitBound))) - return failure(); - - // Create remainder loop. - BlockAndValueMapping bvm; - for (const auto &[res, termDst] : - llvm::zip(loopOp.getResults(), loopOp.getLoopLikeOpInits())) { - bvm.map(termDst, res); - } - b.setInsertionPointAfter(loopOp); - auto remainderLoop = cast(b.clone(*loopOp.getOperation(), bvm)); - - Operation *remainderLoopOp = remainderLoop.getOperation(); - - for (const auto &[oldRes, newRes] : - llvm::zip(loopOp.getResults(), remainderLoop.getResults())) { - SmallPtrSet exceptions({remainderLoopOp}); - for (OpOperand &use : oldRes.getUses()) { - Operation *user = use.getOwner(); - if (user->getParentOp() == remainderLoopOp) exceptions.insert(user); - } - oldRes.replaceAllUsesExcept(newRes, exceptions); - } - - // Set new loop bounds. - b.updateRootInPlace(loopOp, [&]() { - SmallVector ubs = loopOp.getUpperBound(); - ubs[idx] = splitBound; - loopOp.getUpperBoundMutable().assign(ubs); - }); - SmallVector lbs = remainderLoop.getLowerBound(); - lbs[idx] = splitBound; - b.updateRootInPlace(remainderLoop, [&]() { - remainderLoop.getLowerBoundMutable().assign(lbs); - }); - - result = remainderLoop; - return success(); -} - -template -void rewriteAffineOpAfterPeeling(RewriterBase &rewriter, Operation *mainLoop, - Operation *remainderLoop, Value mainIv, - Value remainderIv, Value ub, Value step) { - mainLoop->walk([&](OpTy affineOp) { - AffineMap map = affineOp.getAffineMap(); - (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map, - affineOp.operands(), IsMin, mainIv, ub, - step, /*insideLoop=*/true); - }); - remainderLoop->walk([&](OpTy affineOp) { - AffineMap map = affineOp.getAffineMap(); - (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map, - affineOp.operands(), IsMin, remainderIv, - ub, step, /*insideLoop=*/false); - }); -} - -using ::mlir::linalg::LinalgOp; - -void generateLoopNest(OpBuilder &b, Location loc, ArrayRef loopRanges, - LinalgOp linalgOp, ArrayRef iteratorTypes, - function_ref - bodyBuilderFn, - ArrayRef distributionTypes) { - SmallVector lbs, ubs, steps; - for (Range range : loopRanges) { - lbs.emplace_back(range.offset); - ubs.emplace_back(range.size); - steps.emplace_back(range.stride); - } - - auto wrappedBuilderFn = [&](OpBuilder &nestedBuilder, Location nestedLoc, - ValueRange ivs, ValueRange inputs, - ValueRange outputs) { - SmallVector operandValuesToUse = inputs; - operandValuesToUse.append(outputs.begin(), outputs.end()); - scf::ValueVector results = - bodyBuilderFn(nestedBuilder, nestedLoc, ivs, operandValuesToUse); - nestedBuilder.create(nestedLoc, results); - }; - - 
SmallVector inputs{linalgOp.getDpsInputOperands()}; - SmallVector outputs{linalgOp.getDpsInitOperands()}; - - SmallVector lbsValue = - mlir::getValueOrCreateConstantIndexOp(b, loc, lbs); - SmallVector ubsValue = - mlir::getValueOrCreateConstantIndexOp(b, loc, ubs); - SmallVector stepsValue = - mlir::getValueOrCreateConstantIndexOp(b, loc, steps); - auto tiledLoop = - b.create(loc, lbsValue, ubsValue, stepsValue, inputs, outputs, - b.getArrayAttr(iteratorTypes), wrappedBuilderFn); - if (!distributionTypes.empty()) - tiledLoop.setDistributionTypes(b, distributionTypes); -} - -// Insert a tile `source` into the destination tensor `dest`. The position at -// which the tile is inserted (as well as size of tile) is taken from a given -// ExtractSliceOp `sliceOp`. -Value insertSliceIntoTensor(RewriterBase &b, Location loc, - tensor::ExtractSliceOp sliceOp, Value source, - Value dest) { - return b.create( - loc, sliceOp.getSource().getType(), source, dest, sliceOp.getOffsets(), - sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(), - sliceOp.getStaticSizes(), sliceOp.getStaticStrides()); -} - -FailureOr tileLinalgOpImpl( - RewriterBase &b, LinalgOp op, ValueRange tileSizes, - const linalg::LinalgTilingOptions &options) { - auto nLoops = op.getNumLoops(); - // Initial tile sizes may be too big, only take the first nLoops. - tileSizes = tileSizes.take_front(nLoops); - - if (llvm::all_of(tileSizes, isZero)) { - linalg::TiledLinalgOp tiledOp; - tiledOp.op = cast(b.clone(*op.getOperation())); - tiledOp.tensorResults.assign(tiledOp.op->result_begin(), - tiledOp.op->result_end()); - return tiledOp; - } - - SmallVector tileSizesFold; - for (Value tileSize : tileSizes) tileSizesFold.push_back(tileSize); - - // 1. Build the tiled loop ranges. - auto allShapeSizes = op.createFlatListOfOperandDims(b, op.getLoc()); - AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap(); - if (!shapeSizesToLoopsMap) return failure(); - - SmallVector loopRanges; - mlir::linalg::LoopIndexToRangeIndexMap loopIndexToRangeIndex; - std::tie(loopRanges, loopIndexToRangeIndex) = - mlir::linalg::makeTiledLoopRanges(b, op.getLoc(), shapeSizesToLoopsMap, - allShapeSizes, tileSizesFold); - - SmallVector iteratorTypes; - for (const auto &attr : enumerate(op.getIteratorTypesArray())) { - if (loopIndexToRangeIndex.count(attr.index())) - iteratorTypes.push_back(IteratorTypeAttr::get( - b.getContext(), utils::symbolizeIteratorType(attr.value()).value())); - } - - // 2. Create the tiled loops. - LinalgOp res = op; - SmallVector ivs, tensorResults; - auto tiledLoopBodyBuilder = - [&](OpBuilder & /*builder*/, Location loc, ValueRange localIvs, - ValueRange operandValuesToUse) -> scf::ValueVector { - ivs.assign(localIvs.begin(), localIvs.end()); - - // Tile the `operandValuesToUse` that either match the `op` operands - // themselves or the tile loop arguments forwarding them. 
- assert(operandValuesToUse.size() == op->getNumOperands() && - "expect the number of operands and inputs and outputs to match"); - SmallVector valuesToTile = operandValuesToUse; - auto sizeBounds = makeComposedFoldedMultiResultAffineApply( - b, loc, shapeSizesToLoopsMap, allShapeSizes); - SmallVector ivsFold(ivs.begin(), ivs.end()); - SmallVector tiledOperands = makeTiledShapes( - b, loc, op, valuesToTile, ivsFold, tileSizesFold, sizeBounds, - /*omitPartialTileCheck=*/false); - - SmallVector resultTensorTypes; - for (OpOperand *opOperand : op.getDpsInitOperands()) - resultTensorTypes.push_back( - tiledOperands[opOperand->getOperandNumber()].getType()); - - res = op.clone(b, loc, resultTensorTypes, tiledOperands); - - // Insert a insert_slice for each output tensor. - unsigned resultIdx = 0; - for (OpOperand *opOperand : op.getDpsInitOperands()) { - Value outputTensor = tiledOperands[opOperand->getOperandNumber()]; - IRRewriter rewriter(b); - if (auto sliceOp = outputTensor.getDefiningOp()) { - tensorResults.push_back(insertSliceIntoTensor(rewriter, loc, sliceOp, - res->getResult(resultIdx), - sliceOp.getSource())); - } else { - tensorResults.push_back(res->getResult(resultIdx)); - } - ++resultIdx; - } - return scf::ValueVector(tensorResults.begin(), tensorResults.end()); - }; - generateLoopNest(b, op.getLoc(), loopRanges, op, iteratorTypes, - tiledLoopBodyBuilder, options.distributionTypes); - - // 3. Transform IndexOp results w.r.t. the tiling. - mlir::linalg::transformIndexOps(b, res, ivs, loopIndexToRangeIndex); - - // 4. Gather the newly created loops and return them with the new op. - SmallVector loops; - loops.reserve(ivs.size()); - for (auto iv : ivs) { - if (iv.isa()) { - loops.push_back(iv.cast().getOwner()->getParentOp()); - assert(loops.back() && "no owner found for induction variable!"); - } else { - loops.push_back(nullptr); - } - } - - // 5. Get the tensor results from the outermost loop if available. Otherwise - // use the previously captured `tensorResults`. - Operation *outermostLoop = nullptr; - for (Operation *loop : loops) - if ((outermostLoop = loop)) break; - - return linalg::TiledLinalgOp{ - res, loops, outermostLoop ? outermostLoop->getResults() : tensorResults}; -} - -template -LogicalResult peelAndCanonicalizeGmlStLoopImpl(RewriterBase &rewriter, - LoopTy loopOp, int64_t idx, - LoopTy &result) { - int64_t numLoops = loopOp.getNumLoops(); - if (idx < 0 || numLoops <= idx) return failure(); - - Value ub = loopOp.getUpperBound()[idx]; - LoopTy remainderLoop; - Value splitBound; - if (failed( - peelLoop(rewriter, loopOp, idx, remainderLoop, splitBound))) - return failure(); - - // Rewrite affine.min and affine.max ops. 
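// Note: after peeling, an affine.min/affine.max bound that depends on the
// induction variable typically becomes statically resolvable -- in the main
// loop it reduces to the full step, in the remainder loop to the leftover
// extent -- and scf::rewritePeeledMinMaxOp below performs that simplification,
// hence the insideLoop = true / false distinction.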
- Value mainIv = loopOp.getInductionVars()[idx], step = loopOp.getStep()[idx], - remainderIv = remainderLoop.getInductionVars()[idx]; - - rewriteAffineOpAfterPeeling( - rewriter, loopOp, remainderLoop, mainIv, remainderIv, ub, step); - rewriteAffineOpAfterPeeling( - rewriter, loopOp, remainderLoop, mainIv, remainderIv, ub, step); - - result = remainderLoop; - return success(); -} -} // namespace - -LogicalResult peelAndCanonicalizeGmlStLoop(RewriterBase &rewriter, - LoopOp loopOp, int64_t idx, - LoopOp &result) { - return peelAndCanonicalizeGmlStLoopImpl(rewriter, loopOp, idx, - result); -} - -LogicalResult peelAndCanonicalizeGmlStLoop(RewriterBase &rewriter, - ParallelOp loopOp, int64_t idx, - ParallelOp &result) { - return peelAndCanonicalizeGmlStLoopImpl(rewriter, loopOp, idx, - result); -} - -LogicalResult peelAndCanonicalizeGmlStLoop(RewriterBase &rewriter, ForOp loopOp, - int64_t idx, ForOp &result) { - return peelAndCanonicalizeGmlStLoopImpl(rewriter, loopOp, idx, result); -} - -FailureOr tileLinalgOp( - RewriterBase &b, linalg::LinalgOp op, - const linalg::LinalgTilingOptions &options) { - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(op); - - if (!options.tileSizeComputationFunction) return failure(); - - // Enforce the convention that "tiling by zero" skips tiling a particular - // dimension. This convention is significantly simpler to handle instead of - // adjusting affine maps to account for missing dimensions. - auto nLoops = op.getNumLoops(); - SmallVector tileSizeVector = - options.tileSizeComputationFunction(b, op); - if (tileSizeVector.size() < nLoops) { - auto zero = b.create(op.getLoc(), 0); - tileSizeVector.append(nLoops - tileSizeVector.size(), zero); - } - - return tileLinalgOpImpl(b, op, tileSizeVector, options); -} - -constexpr llvm::StringLiteral kTransformMarker = - "__internal_transformation_marker__"; - -void setTransformationAttr(mlir::OpBuilder &b, Operation *op) { - op->setAttr(kTransformMarker, b.getBoolAttr(true)); -} - -void removeTransformationAttr(Operation *op) { - op->removeAttr(kTransformMarker); -} - -bool hasTransformationAttr(Operation *op) { - auto marker = op->getAttr(kTransformMarker); - if (!marker) return false; - return marker && marker.cast().getValue(); -} - -constexpr llvm::StringLiteral kOpLabel = "op_label"; - -bool hasMatchingLabel(Operation *op, StringRef label) { - auto opLabelAttr = op->getAttr(kOpLabel); - if (!opLabelAttr) return false; - - return opLabelAttr.cast().getValue() == label; -} - -} // namespace gml_st -} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo/CMakeLists.txt deleted file mode 100644 index e138afa587f..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo/CMakeLists.txt +++ /dev/null @@ -1,17 +0,0 @@ -# -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -add_subdirectory(IR) -add_subdirectory(transforms) diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo/transforms/lhlo_fuse_linalg.cc b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo/transforms/lhlo_fuse_linalg.cc deleted file mode 100644 index c5b2dfef9de..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo/transforms/lhlo_fuse_linalg.cc +++ /dev/null @@ -1,224 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file implements logic for fusing linalg ops obtained after LHLO -// lowering. - -#include - -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/STLExtras.h" -#include "mlir-hlo/Dialect/lhlo/transforms/passes.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Interfaces/ViewLikeInterface.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -namespace mlir { -namespace lmhlo { - -#define GEN_PASS_DEF_LHLOFUSELINALGPASS -#include "mlir-hlo/Dialect/lhlo/transforms/lmhlo_passes.h.inc" - -namespace { - -using linalg::LinalgOp; - -class LhloFuseLinalgPass - : public impl::LhloFuseLinalgPassBase { - void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); - } - - public: - LhloFuseLinalgPass() = default; - LhloFuseLinalgPass(const LhloFuseLinalgPass&) = default; - LhloFuseLinalgPass(bool useParallelLoops, - llvm::ArrayRef tileSizes) { - tile_sizes_ = tileSizes; - use_parallel_loops_.setValue(useParallelLoops); - } - - void runOnOperation() override { - auto func = getOperation(); - - // TODO(pifon): Remove assumption that the function has a single block. - if (!llvm::hasSingleElement(func)) { - emitError(func.getLoc(), "The function needs to have a single block."); - signalPassFailure(); - return; - } - - // The fusion in Linalg is currently possible only when the consumer op is - // tiled. In order to greedily fuse the ops, we have to start from the tiled - // root linalg ops, i.e. linalg ops that write to output buffers of the - // function or are returned in case of escaping allocations. - llvm::SmallDenseSet resultBuffers; - for (auto funcArg : func.getArguments()) { - resultBuffers.insert(funcArg); - } - for (auto& block : func) { - auto returnOp = - mlir::dyn_cast(block.getTerminator()); - if (!returnOp) continue; - for (auto operand : returnOp.getOperands()) { - resultBuffers.insert(operand); - } - } - // Resolve aliasing operations (like casts) on the result to identify - // results. This only handles escaping results. - // TODO(herhut): Use BufferizeAliasAnalysis for this. 
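// Note: the worklist below chases aliases of the result buffers (view-like
// ops, bufferization.to_tensor / to_memref, tensor.cast, and values returned
// through region branches) so that linalg ops writing through such aliases
// are still recognized as fusion roots.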
- llvm::SmallVector worklist(resultBuffers.begin(), - resultBuffers.end()); - while (!worklist.empty()) { - Value result = worklist.pop_back_val(); - auto* definingOp = result.getDefiningOp(); - if (!definingOp) { - continue; - } - - if (auto viewLike = dyn_cast(definingOp)) { - auto alias = viewLike.getViewSource(); - if (resultBuffers.insert(alias).second) { - worklist.push_back(alias); - } - continue; - } - - if (auto toTensor = dyn_cast(definingOp)) { - auto alias = toTensor.getMemref(); - if (resultBuffers.insert(alias).second) { - worklist.push_back(alias); - } - continue; - } - - if (auto toMemref = dyn_cast(definingOp)) { - auto alias = toMemref.getTensor(); - if (resultBuffers.insert(alias).second) { - worklist.push_back(alias); - } - continue; - } - - if (auto tensorCast = dyn_cast(definingOp)) { - auto alias = tensorCast.getSource(); - if (resultBuffers.insert(alias).second) { - worklist.push_back(alias); - } - continue; - } - - if (auto regionInterface = - dyn_cast(definingOp)) { - for (Region& region : regionInterface.getOperation()->getRegions()) { - // Only consider regions that can return to the parent region. - SmallVector successorRegions; - regionInterface.getSuccessorRegions(region.getRegionNumber(), - successorRegions); - if (llvm::none_of(successorRegions, [&](auto successorRegion) { - return successorRegion.isParent(); - })) - continue; - - // Iterate over all immediate terminators and record the values - // corresponding to result_buffers of interest. - for (Block& block : region) { - if (block.empty()) continue; - Operation& operation = block.back(); - if (!operation.hasTrait()) continue; - auto idx = result.dyn_cast().getResultNumber(); - if (resultBuffers.insert(operation.getOperand(idx)).second) { - worklist.push_back(operation.getOperand(idx)); - } - } - } - } - } - - MLIRContext* ctx = func.getContext(); - OpBuilder b(func); - func.walk([&](linalg::GenericOp genericOp) { - SmallVector tileSizes(tile_sizes_.begin(), tile_sizes_.end()); - if (tileSizes.empty()) { - tileSizes = SmallVector(genericOp.getNumLoops(), 1); - } - auto op = cast(genericOp.getOperation()); - for (OpOperand* opOperand : op.getDpsInitOperands()) { - if (!resultBuffers.count(opOperand->get())) continue; - if (tileGenericOp(op, tileSizes, &b)) { - genericOp.erase(); - return; - } - } - }); - auto patterns = linalg::getLinalgTilingCanonicalizationPatterns(ctx); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) - return signalPassFailure(); - - // Fuse producers of tiled linalg ops. 
- llvm::SmallDenseSet eraseSet; - SmallVector linalgOps; - func.walk([&](LinalgOp op) { linalgOps.push_back(op); }); - for (LinalgOp op : llvm::reverse(linalgOps)) { - for (OpOperand* inputOperand : op.getDpsInputOperands()) { - linalg::Aliases aliases; - linalg::LinalgDependenceGraph graph(aliases, linalgOps); - auto info = fuseProducerOfBuffer(b, *inputOperand, graph); - if (failed(info)) continue; - auto* originalOp = info->originalProducer.getOperation(); - eraseSet.insert(originalOp); - auto* originalOpInLinalgOpsVector = - std::find_if(linalgOps.begin(), linalgOps.end(), - [&](const Operation* op) { return op == originalOp; }); - *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); - } - - auto patterns = linalg::getLinalgTilingCanonicalizationPatterns(ctx); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) - return signalPassFailure(); - } - for (auto* e : eraseSet) e->erase(); - } - - private: - bool tileGenericOp(LinalgOp op, ArrayRef tileSizes, OpBuilder* b) { - auto loopType = use_parallel_loops_ - ? linalg::LinalgTilingLoopType::ParallelLoops - : linalg::LinalgTilingLoopType::Loops; - IRRewriter rewriter(*b); - return succeeded(linalg::tileLinalgOp( - rewriter, op, - linalg::LinalgTilingOptions().setTileSizes(tileSizes).setLoopType( - loopType))); - } -}; - -} // namespace - -std::unique_ptr> createLhloFuseLinalgPass( - bool useParallelLoops, ArrayRef tileSizes) { - return std::make_unique(useParallelLoops, tileSizes); -} - -} // namespace lmhlo -} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo_gpu/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo_gpu/CMakeLists.txt deleted file mode 100644 index b16dd4a6fd4..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo_gpu/CMakeLists.txt +++ /dev/null @@ -1,16 +0,0 @@ -# -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -add_subdirectory(IR) diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_stablehlo.cc b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_stablehlo.cc deleted file mode 100644 index 410bf14d476..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_stablehlo.cc +++ /dev/null @@ -1,226 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/map_stablehlo_to_hlo_op.h" -#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/Types.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/DialectConversion.h" -#include "stablehlo/dialect/StablehloOps.h" - -namespace mlir { -namespace stablehlo { -namespace { - -#define RETURN_CONVERTED_ENUM_ATTR(Name) \ - auto hloValue = mhlo::stringify##Name(attr.getValue()); \ - auto stablehloValue = stablehlo::symbolize##Name(hloValue); \ - if (!stablehloValue.has_value()) return {}; \ - return stablehlo::Name##Attr::get(attr.getContext(), stablehloValue.value()) - -Attribute convertAttr(Attribute hloAttr) { - // Handle MHLO attributes. - // The logic that handles attributes from other dialects (e.g. builtin - // attributes) lives below. - if (auto attr = hloAttr.dyn_cast()) { - return stablehlo::ChannelHandleAttr::get(attr.getContext(), - attr.getHandle(), attr.getType()); - } - if (auto attr = hloAttr.dyn_cast()) { - RETURN_CONVERTED_ENUM_ATTR(ComparisonDirection); - } - if (auto attr = hloAttr.dyn_cast()) { - RETURN_CONVERTED_ENUM_ATTR(ComparisonType); - } - if (auto attr = hloAttr.dyn_cast()) { - return stablehlo::ConvDimensionNumbersAttr::get( - attr.getContext(), attr.getInputBatchDimension(), - attr.getInputFeatureDimension(), attr.getInputSpatialDimensions(), - attr.getKernelInputFeatureDimension(), - attr.getKernelOutputFeatureDimension(), - attr.getKernelSpatialDimensions(), attr.getOutputBatchDimension(), - attr.getOutputFeatureDimension(), attr.getOutputSpatialDimensions()); - } - if (auto attr = hloAttr.dyn_cast()) { - RETURN_CONVERTED_ENUM_ATTR(CustomCallApiVersion); - } - if (auto attr = hloAttr.dyn_cast()) { - return stablehlo::DotDimensionNumbersAttr::get( - attr.getContext(), attr.getLhsBatchingDimensions(), - attr.getRhsBatchingDimensions(), attr.getLhsContractingDimensions(), - attr.getRhsContractingDimensions()); - } - if (auto attr = hloAttr.dyn_cast()) { - RETURN_CONVERTED_ENUM_ATTR(FftType); - } - if (auto attr = hloAttr.dyn_cast()) { - return stablehlo::GatherDimensionNumbersAttr::get( - attr.getContext(), attr.getOffsetDims(), attr.getCollapsedSliceDims(), - attr.getStartIndexMap(), attr.getIndexVectorDim()); - } - if (auto attr = hloAttr.dyn_cast()) { - // This precision value is used to experiment with int4 support. - // Needs more experimental data before we decide whether or not to propose - // it to StableHLO. 
- if (attr.getValue() == mhlo::Precision::PACKED_NIBBLE) return {}; - RETURN_CONVERTED_ENUM_ATTR(Precision); - } - if (auto attr = hloAttr.dyn_cast()) { - RETURN_CONVERTED_ENUM_ATTR(RngAlgorithm); - } - if (auto attr = hloAttr.dyn_cast()) { - RETURN_CONVERTED_ENUM_ATTR(RngDistribution); - } - if (auto attr = hloAttr.dyn_cast()) { - return stablehlo::ScatterDimensionNumbersAttr::get( - attr.getContext(), attr.getUpdateWindowDims(), - attr.getInsertedWindowDims(), attr.getScatterDimsToOperandDims(), - attr.getIndexVectorDim()); - } - if (auto attr = hloAttr.dyn_cast()) { - RETURN_CONVERTED_ENUM_ATTR(Transpose); - } - if (hloAttr.getDialect().getNamespace() == - mhlo::MhloDialect::getDialectNamespace()) { - // Our guiding principle is to support all StableHLO functionality in MHLO. - // The inverse is not necessarily true - some MHLO attributes are missing - // from StableHLO (either deliberately or haven't yet been proposed). - // As a result, these MHLO attributes will fail here. - return {}; - } - - // Handle non-MHLO attributes. - // If an attribute is not defined in MHLO, then it is unchanged, - // with the exception of ArrayAttr which is converted recursively. - if (auto hloAttrs = hloAttr.dyn_cast()) { - SmallVector stablehloAttrs; - for (auto hloAttr : hloAttrs) { - auto stablehloAttr = convertAttr(hloAttr); - if (!stablehloAttr) return {}; - stablehloAttrs.push_back(stablehloAttr); - } - return ArrayAttr::get(hloAttrs.getContext(), stablehloAttrs); - } - return hloAttr; -} - -#undef RETURN_CONVERTED_ENUM_ATTR - -template -class HloToStablehloOpConverter : public OpConversionPattern { - public: - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - HloOpTy hloOp, typename HloOpTy::Adaptor adaptor, - ConversionPatternRewriter& rewriter) const final { - // Most MHLO ops which end up here are fully supported by StableHLO. - // However, some of these ops are supported only partially because they - // have features that either haven't been proposed to StableHLO yet - // or aren't planned to be proposed to StableHLO. - // The check below makes sure we only proceed for supported ops. - if constexpr (std::is_same::value) { - // Added to MHLO per feature request from JAX. - // Hasn't been proposed to StableHLO yet. - if (!hloOp.getOutputOperandAliases().empty()) return failure(); - } - - // Convert MHLO types to StableHLO equivalents. - // If a type is not defined in MHLO, then it is unchanged, - // with the exception of RankedTensorType and TupleType which are - // converted recursively. - // See `HloToStablehloTypeConverter` for more information on when this - // conversion will succeed or fail. - SmallVector stablehloTypes; - if (failed(this->getTypeConverter()->convertTypes(hloOp->getResultTypes(), - stablehloTypes))) - return failure(); - - // These operands have already been converted to StableHLO by - // the dialect conversion infrastructure. - ValueRange stablehloOperands = adaptor.getOperands(); - - // Convert MHLO attributes to StableHLO equivalents. - // If an attribute is not defined in MHLO, then it is unchanged, - // with the exception of ArrayAttr which is converted recursively. - SmallVector stablehloAttrs; - for (NamedAttribute hloAttr : hloOp->getAttrs()) { - auto stablehloAttr = convertAttr(hloAttr.getValue()); - if (!stablehloAttr) return failure(); - stablehloAttrs.push_back({hloAttr.getName(), stablehloAttr}); - } - - // Convert the MHLO operation to a StableHLO equivalent. 
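[Editor's note: the enum attribute cases above all use the RETURN_CONVERTED_ENUM_ATTR idiom: stringify the MHLO enum value, re-symbolize it in the StableHLO namespace, and return an empty attribute when the target dialect has no such case. A minimal standalone sketch of that round-trip follows; the enum names and the "missing ADJOINT case" are purely illustrative of the failure path, not a claim about the real StableHLO enums.]

// Minimal model of the stringify/symbolize round trip: map an enum from one
// dialect into the other through its string form, and fail (empty optional)
// when the target has no matching case.
#include <iostream>
#include <map>
#include <optional>
#include <string>

enum class MhloTranspose { NO_TRANSPOSE, TRANSPOSE, ADJOINT };
enum class StablehloTranspose { NO_TRANSPOSE, TRANSPOSE };  // ADJOINT missing (illustrative only)

std::string stringify(MhloTranspose v) {
  switch (v) {
    case MhloTranspose::NO_TRANSPOSE: return "NO_TRANSPOSE";
    case MhloTranspose::TRANSPOSE: return "TRANSPOSE";
    case MhloTranspose::ADJOINT: return "ADJOINT";
  }
  return "";
}

std::optional<StablehloTranspose> symbolize(const std::string& s) {
  static const std::map<std::string, StablehloTranspose> table = {
      {"NO_TRANSPOSE", StablehloTranspose::NO_TRANSPOSE},
      {"TRANSPOSE", StablehloTranspose::TRANSPOSE}};
  auto it = table.find(s);
  if (it == table.end()) return std::nullopt;
  return it->second;
}

int main() {
  for (auto v : {MhloTranspose::TRANSPOSE, MhloTranspose::ADJOINT}) {
    auto converted = symbolize(stringify(v));
    std::cout << stringify(v)
              << (converted ? ": converted\n" : ": no target equivalent, conversion fails\n");
  }
  return 0;
}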
- // This can almost be done in a generic fashion, except for stablehlo.case - // that uses a variadic number of regions which means an additional argument - // for the generic builder. - HloToStablehloOp stablehloOp; - if constexpr (std::is_same::value) { - stablehloOp = rewriter.replaceOpWithNewOp( - hloOp, stablehloTypes, stablehloOperands, stablehloAttrs, - hloOp.getBranches().size()); - } else { - stablehloOp = rewriter.replaceOpWithNewOp>( - hloOp, stablehloTypes, stablehloOperands, stablehloAttrs); - } - - // Finally, populate the regions while converting argument types - // and nested operations. - for (auto [hloRegion, stablehloRegion] : - llvm::zip(hloOp->getRegions(), stablehloOp->getRegions())) { - rewriter.inlineRegionBefore(hloRegion, stablehloRegion, - stablehloRegion.end()); - } - return success(); - } -}; - -template -void populateHloToStablehloPatterns(RewritePatternSet* patterns, - TypeConverter* converter, - MLIRContext* context) { - patterns - ->add>...>( - *converter, context); -} - -} // namespace - -void populateHloToStablehloPatterns(RewritePatternSet* patterns, - TypeConverter* converter, - MLIRContext* context) { - // Populate conversion patterns for all StableHLO ops. - // Our guiding principle is to support all StableHLO functionality in MHLO. - // The inverse is not necessarily true - some MHLO ops are missing from - // StableHLO (either deliberately or haven't yet been proposed to StableHLO). - // As a result, these MHLO ops will not be added to these patterns and - // will fail the conversion. - populateHloToStablehloPatterns< -#define GET_OP_LIST -#include "stablehlo/dialect/StablehloOps.cpp.inc" - >(patterns, converter, context); -} - -} // namespace stablehlo -} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/lower_complex_patterns.td b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/lower_complex_patterns.td deleted file mode 100644 index eae54f1f4e3..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/lower_complex_patterns.td +++ /dev/null @@ -1,136 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This is the legalization pattern that converts complex operations into -// equivalent real value operations. - -include "mlir/IR/OpBase.td" -include "mlir/Dialect/Func/IR/FuncOps.td" -include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" - -//===----------------------------------------------------------------------===// -// Binary op patterns. -//===----------------------------------------------------------------------===// - -// Add and subtraction are elementwise and can be distributed across the real -// and imaginary components. 
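[Editor's note: the populateHloToStablehloPatterns function removed above instantiates one conversion pattern per op type in a template parameter pack, with the pack supplied by the GET_OP_LIST include. A small standalone sketch of that "one registration per pack element" technique follows; the PatternSet/AddOp/MulOp names are illustrative, not the MLIR API.]

// Minimal model of registering one handler per type in a parameter pack via
// a C++17 fold expression: adding a type to the list automatically adds its
// pattern, which is how the generated op list drives pattern population.
#include <iostream>
#include <string>
#include <typeinfo>
#include <vector>

struct PatternSet {
  std::vector<std::string> names;
  template <typename... OpTys>
  void add() {
    (names.push_back(typeid(OpTys).name()), ...);  // one entry per pack element
  }
};

struct AddOp {};
struct MulOp {};
struct CaseOp {};

template <typename... OpTys>
void populatePatterns(PatternSet& patterns) {
  patterns.add<OpTys...>();
}

int main() {
  PatternSet patterns;
  populatePatterns<AddOp, MulOp, CaseOp>(patterns);  // stands in for the generated op list
  for (const auto& n : patterns.names) std::cout << n << "\n";
  return 0;
}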
-foreach elementwiseOp = [HLO_AddOp, HLO_SubtractOp] in - def : Pat<(elementwiseOp HLO_ComplexTensor:$lhs, - HLO_ComplexTensor:$rhs), - (HLO_ComplexOp - (elementwiseOp (HLO_RealOp $lhs), (HLO_RealOp $rhs)), - (elementwiseOp (HLO_ImagOp $lhs), (HLO_ImagOp $rhs)))>; - -// Complex multiplication results in a cross product multiplication between the -// real and imaginary components such that: -// result.real = lhs.real * rhs.real - lhs.imag * rhs.imag -// result.imag = lhs.imag * rhs.real + lhs.real * rhs.imag -def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs, - HLO_ComplexTensor:$rhs), - (HLO_ComplexOp - (HLO_SubtractOp - (HLO_MulOp - (HLO_RealOp:$lhs_real $lhs), - (HLO_RealOp:$rhs_real $rhs)), - (HLO_MulOp - (HLO_ImagOp:$lhs_imag $lhs), - (HLO_ImagOp:$rhs_imag $rhs))), - (HLO_AddOp - (HLO_MulOp $lhs_real, $rhs_imag), - (HLO_MulOp $lhs_imag, $rhs_real)))>; - - -// Division is performed by normalizing the denominator by multiplying by the -// conjugate of the rhs. -// numerator = lhs * conj(rhs) -// denominator = rhs * conj(rhs) -def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_ComplexTensor:$rhs), - (HLO_ComplexOp - (HLO_DivOp - (HLO_RealOp (HLO_MulOp:$num $lhs, - (HLO_ComplexOp:$conj - (HLO_RealOp $rhs), - (HLO_NegOp (HLO_ImagOp $rhs))))), - (HLO_AddOp:$den - (HLO_MulOp (HLO_RealOp $rhs), (HLO_RealOp $rhs)), - (HLO_MulOp (HLO_ImagOp $rhs), (HLO_ImagOp $rhs)))), - (HLO_DivOp (HLO_ImagOp $num), $den))>; - -// Absolute value is evaluated as: -// result = sqrt(val.real * val.real + val.imag * val.imag) -def : Pat<(HLO_AbsOp HLO_ComplexTensor:$val), - (HLO_SqrtOp - (HLO_AddOp - (HLO_MulOp (HLO_RealOp:$real $val), $real), - (HLO_MulOp (HLO_ImagOp:$imag $val), $imag)))>; - -// Can deconstruct sin(a + ib) as follows: -// sin(a) * cosh(b) + icos(a) * sinh(b) -// sinh(b) = (e^x - e^-x) / 2 -// cosh(b) = (e^x + e^-x) / 2 -def : Pat<(HLO_SineOp HLO_ComplexTensor:$val), - (HLO_ComplexOp - (HLO_DivOp - (HLO_MulOp - (HLO_SineOp (HLO_RealOp:$real $val)), - (HLO_AddOp - (HLO_ExpOp:$exp (HLO_ImagOp:$imag $val)), - (HLO_ExpOp:$nexp (HLO_NegOp $imag)))), - (HLO_ConstantOp : $two (ConstantSplat<"2.0"> $real))), - (HLO_DivOp - (HLO_MulOp - (HLO_CosineOp $real), - (HLO_SubtractOp $exp, $nexp)), $two))>; - -// Can deconstruct cos(a + ib) as follows: -// cos(a) * cosh(b) - isin(a) * sinh(b) -// sinh(b) = (e^x - e^-x) / 2 -// cosh(b) = (e^x + e^-x) / 2 -def : Pat<(HLO_CosineOp HLO_ComplexTensor:$val), - (HLO_ComplexOp - (HLO_DivOp - (HLO_MulOp - (HLO_CosineOp (HLO_RealOp:$real $val)), - (HLO_AddOp - (HLO_ExpOp:$exp (HLO_ImagOp:$imag $val)), - (HLO_ExpOp:$nexp (HLO_NegOp $imag)))), - (HLO_ConstantOp : $two (ConstantSplat<"2.0"> $real))), - (HLO_DivOp - (HLO_MulOp - (HLO_SineOp $real), - (HLO_SubtractOp $nexp, $exp)), $two))>; - -// Exponential can be lowered to an exponential on the real component and a -// sum of sinusoids of the imaginary component, which equates to a normal -// exponential operator multiplied by Euler's formula. 
-// -// Exp(a + ib) = Exp(a) * Exp(ib) = Exp(a) * Cos(b) + Exp(a) * iSin(b)) -class HLO_ComparisonDirectionValue : - ConstantAttr; - -def : Pat<(HLO_ExpOp HLO_ComplexTensor:$val), - (HLO_ComplexOp - (HLO_MulOp - (HLO_CosineOp (HLO_ImagOp:$imag $val)), - (HLO_ExpOp:$exp (HLO_RealOp:$real $val))), - (HLO_MulOp (HLO_SineOp $imag), $exp))>; - -foreach pair = [[HLO_ComparisonDirectionValue<"NE">, HLO_OrOp], - [HLO_ComparisonDirectionValue<"EQ">, HLO_AndOp]] in { -def : Pat<(HLO_CompareOp HLO_ComplexTensor:$lhs, HLO_ComplexTensor:$rhs, pair[0], $compare_type), - (pair[1] - (HLO_CompareOp (HLO_RealOp $lhs), (HLO_RealOp $rhs), pair[0], $compare_type), - (HLO_CompareOp (HLO_ImagOp $lhs), (HLO_ImagOp $rhs), pair[0], $compare_type))>; -} diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Transforms/alloc_to_arg_pass.cc b/tensorflow/compiler/xla/mlir_hlo/lib/Transforms/alloc_to_arg_pass.cc deleted file mode 100644 index 7441ee2dc12..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Transforms/alloc_to_arg_pass.cc +++ /dev/null @@ -1,71 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This files implements a pass that partially bufferized IR. 
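[Editor's note: the deleted lower_complex_patterns.td file rewrites complex multiply, divide, and abs into real/imaginary component arithmetic. The short standalone check below verifies those same identities against std::complex; it is a verification sketch only, with made-up sample values, and is not part of the patch.]

// Numerically checks the component-wise expansions used by the deleted patterns:
//   (a+ib)(c+id)  = (ac - bd) + i(bc + ad)
//   (a+ib)/(c+id) = ((a+ib)(c-id)) / (c*c + d*d)   (normalize by the conjugate)
//   |a+ib|        = sqrt(a*a + b*b)
#include <cassert>
#include <cmath>
#include <complex>
#include <cstdio>

int main() {
  double a = 1.5, b = -2.0, c = 0.75, d = 3.25;

  // Multiplication via real/imag cross products.
  double mulRe = a * c - b * d;
  double mulIm = b * c + a * d;

  // Division: multiply the numerator by conj(rhs), divide by |rhs|^2.
  double den = c * c + d * d;
  double divRe = (a * c + b * d) / den;
  double divIm = (b * c - a * d) / den;

  std::complex<double> lhs(a, b), rhs(c, d);
  assert(std::abs((lhs * rhs).real() - mulRe) < 1e-12);
  assert(std::abs((lhs * rhs).imag() - mulIm) < 1e-12);
  assert(std::abs((lhs / rhs).real() - divRe) < 1e-12);
  assert(std::abs((lhs / rhs).imag() - divIm) < 1e-12);
  assert(std::abs(std::abs(lhs) - std::sqrt(a * a + b * b)) < 1e-12);

  std::printf("component-wise expansions match std::complex\n");
  return 0;
}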
- -#include -#include -#include -#include - -#include "mlir-hlo/Transforms/passes.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/IR/Value.h" - -namespace mlir { - -#define GEN_PASS_DEF_ALLOCTOARGPASS -#include "mlir-hlo/Transforms/passes.h.inc" - -using ::mlir::func::FuncOp; - -namespace { -class AllocToArgPass : public impl::AllocToArgPassBase { - public: - using AllocToArgPassBase::AllocToArgPassBase; - - private: - void runOnOperation() override; -}; -} // namespace - -void AllocToArgPass::runOnOperation() { - FuncOp funcOp = getOperation(); - IRRewriter rewriter(funcOp.getContext()); - BitVector resultsToErase(funcOp.getNumResults()); - Operation *terminator = funcOp.getBody().back().getTerminator(); - for (OpOperand &result : terminator->getOpOperands()) { - Operation *allocOp = result.get().getDefiningOp(); - if (!allocOp || !isa(allocOp)) { - terminator->emitOpError("expected operand #") - << result.getOperandNumber() << " to be defined by an memref.alloc"; - return signalPassFailure(); - } - resultsToErase.set(result.getOperandNumber()); - auto attrs = funcOp.getResultAttrDict(result.getOperandNumber()); - funcOp.insertArgument(funcOp.getNumArguments(), result.get().getType(), - attrs, result.get().getLoc()); - rewriter.replaceOp(allocOp, funcOp.getArguments().back()); - } - funcOp.eraseResults(resultsToErase); - terminator->eraseOperands(resultsToErase); -} - -std::unique_ptr> hlo::createAllocToArgPass() { - return std::make_unique(); -} - -} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Transforms/gml_st_pipeline.cc b/tensorflow/compiler/xla/mlir_hlo/lib/Transforms/gml_st_pipeline.cc deleted file mode 100644 index 7be184c8228..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Transforms/gml_st_pipeline.cc +++ /dev/null @@ -1,66 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "mlir-hlo/Transforms/gml_st_pipeline.h" - -#include "mlir-hlo/Dialect/gml_st/transforms/passes.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "mlir-hlo/Transforms/passes.h" -#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" -#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" -#include "mlir/Dialect/Bufferization/Transforms/Passes.h" -#include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Transforms/Passes.h" - -namespace mlir { - -using ::mlir::func::FuncOp; - -void createGmlStPipeline(mlir::OpPassManager& pm, - const GmlStPipelineOptions& options) { - // Transforms HLO to Linalg + THLO. - pm.addNestedPass(mhlo::createLegalizeMHLOToTHLOPass()); - pm.addNestedPass(mhlo::createLegalizeHloToLinalgPass()); - - // Perform tiling, fusion, vectorization and other transformations. 
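[Editor's note: the deleted AllocToArgPass requires each returned value to be defined by a memref.alloc and then turns that result into an extra function argument, replacing the alloc with the new argument. Conceptually this is the familiar "return a freshly allocated buffer" to "write into a caller-provided buffer" rewrite; the plain C++ analogy below illustrates only that idea and is not the pass itself.]

// Before: the callee allocates and returns the result buffer.
// After:  the caller owns the buffer and passes it in, mirroring how the pass
//         turns alloc-backed results into function arguments.
#include <cstdio>
#include <vector>

std::vector<float> produceBefore(int n) {
  std::vector<float> out(n);                    // allocation inside the function
  for (int i = 0; i < n; ++i) out[i] = float(i) * 2.0f;
  return out;                                   // result backed by that allocation
}

void produceAfter(std::vector<float>& out) {    // buffer becomes a parameter
  for (size_t i = 0; i < out.size(); ++i) out[i] = float(i) * 2.0f;
}

int main() {
  auto a = produceBefore(4);
  std::vector<float> b(4);
  produceAfter(b);
  std::printf("%s\n", a == b ? "same contents" : "mismatch");
  return 0;
}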
- pm.addNestedPass( - gml_st::createTilingPass("", "", true, options.tileSizes)); - pm.addNestedPass(gml_st::createFusionPass()); - pm.addNestedPass(createScalarizationPass()); - pm.addPass(createCanonicalizerPass()); - pm.addPass(createCSEPass()); - - if (!options.lowerToLoops) return; - - // Bufferization-related passes. - pm.addNestedPass(bufferization::createEmptyTensorToAllocTensorPass()); - pm.addPass(hlo::createOneShotBufferizePass()); - pm.addPass(createCSEPass()); - pm.addPass(createCanonicalizerPass()); - pm.addNestedPass(bufferization::createBufferDeallocationPass()); - - // Convert Linalg + GmlSt to SCF loops. - pm.addNestedPass(createConvertLinalgToLoopsPass()); - pm.addNestedPass(gml_st::createVectorizeGmlStLoopsPass()); - pm.addNestedPass(gml_st::createGmlStToScfPass()); - - pm.addNestedPass(createLowerAffinePass()); - pm.addPass(createConvertSCFToCFPass()); - pm.addPass(hlo::createGenericHostToLLVMPass()); -} - -} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Transforms/hlo_to_gpu_pipeline.cc b/tensorflow/compiler/xla/mlir_hlo/lib/Transforms/hlo_to_gpu_pipeline.cc deleted file mode 100644 index 57272dbbd4b..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Transforms/hlo_to_gpu_pipeline.cc +++ /dev/null @@ -1,140 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -/// This files contains a pipeline which converts HLO operations to GPU kernels -/// written in a combination of LLVM and NVVM dialects. - -#include "mlir-hlo/Dialect/gml_st/transforms/passes.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "mlir-hlo/Transforms/gpu_passes.h" -#include "mlir-hlo/Transforms/passes.h" -#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" -#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" -#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" -#include "mlir/Dialect/Arith/Transforms/Passes.h" -#include "mlir/Dialect/Bufferization/Transforms/Passes.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/GPU/Transforms/Passes.h" -#include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Dialect/SCF/Transforms/Passes.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Transforms/Passes.h" - -using namespace mlir; -using ::mlir::func::FuncOp; -using ::mlir::gpu::GPUModuleOp; - -static constexpr const char* kBlockDistributionLabel = "block"; -static constexpr const char* kWarpDistributionLabel = "warp"; -static constexpr const char* kThreadDistributionLabel = "thread"; - -// TODO(b/233761238): We only want to have this pipeline temporarily, as it is -// not yet clear how exactly it will look like. The goal is to merge this with -// the unified kernel generator + autofusion + XLA Next pipeline once we have -// it, and once this code stabilizes. 
-void mlir::createHloToGpuPipeline(OpPassManager& pm, - ArrayRef blockTileDim, - ArrayRef warpTileDim, - ArrayRef threadTileDim, - bool experimentalSoftmax) { - pm.addNestedPass(hlo::createUnbufferizePass()); - pm.addNestedPass(hlo::createInlineFusionPass()); - pm.addPass(createCSEPass()); // Combine repeated subtract(broadcast). - - // HLO -> Linalg - pm.addNestedPass(mhlo::createChloLegalizeToHloPass()); - pm.addPass(createCanonicalizerPass()); // Clean up shape.assuming ops. - pm.addNestedPass(mhlo::createLegalizeHloToLinalgPass()); - - // Perform tiling either for softmax or for element-wise. - if (experimentalSoftmax) { - // Simplify unit dimension. - pm.addPass(mlir::createLinalgFoldUnitExtentDimsPass()); - - // Tile parallel dimensions of the softmax-like patterns and distribute them - // across warps. Warps remain independant of each other. - pm.addNestedPass(gml_st::createTilingSoftmaxPass( - /*distribute=*/true, blockTileDim, kBlockDistributionLabel)); - pm.addNestedPass(gml_st::createTilingSoftmaxPass( - /*distribute=*/true, warpTileDim, kWarpDistributionLabel)); - - // GPU-specific tiling for ops on the warp level. - pm.addNestedPass(gml_st::createTilingGPUWarpPass()); - pm.addNestedPass(createScalarizationPass()); - - pm.addNestedPass(gml_st::createVectorizeGmlStLoopsPass( - /*vectorizeGmlStOps=*/true, /*distributionLabels=*/{ - kWarpDistributionLabel, kThreadDistributionLabel})); - } else { - // TODO(b/244313563): This is a workaround to avoid temporary allocs within - // threads. It works for as long as all of our operations are cwise. - // Vectorize the inner loops instead. - // TODO(frgossen): We should not have to skip this pass for softmax. - pm.addNestedPass(createLinalgElementwiseOpFusionPass()); - - // Tiling - pm.addNestedPass(gml_st::createTilingCwisePass( - /*distribute=*/true, blockTileDim, kBlockDistributionLabel)); - pm.addNestedPass(gml_st::createTilingCwisePass( - /*distribute=*/true, warpTileDim, kWarpDistributionLabel)); - pm.addNestedPass(gml_st::createTilingCwisePass( - /*distribute=*/true, threadTileDim, kThreadDistributionLabel)); - pm.addNestedPass(createScalarizationPass()); - } - - pm.addPass(createCanonicalizerPass()); - pm.addPass(createCSEPass()); - - // Bufferization-related passes. - pm.addNestedPass(bufferization::createEmptyTensorToAllocTensorPass()); - pm.addPass(hlo::createOneShotBufferizePass()); - // We do not deallocate buffers, since grid-level buffers get converted into - // functions arguments, while block- (and lower-)level buffers become shared - // memory. None of which have to be deallocated. 
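[Editor's note: the pipeline above tiles the same computation at block, warp, and thread granularity before distributing it. The standalone sketch below shows what such nested tiling of a 1-D iteration space amounts to; the tile sizes and the spot-checked element are illustrative values, not the pipeline's defaults.]

// Standalone model of the block/warp/thread tiling levels: each level splits
// the remaining range into tiles of its own size, so a scalar iteration is
// addressed by (block, warp, thread, element) coordinates.
#include <cstdio>

int main() {
  const int domain = 64;  // illustrative problem size, divisible by all tile sizes
  const int blockTile = 32, warpTile = 8, threadTile = 2;

  for (int b = 0; b < domain; b += blockTile)
    for (int w = b; w < b + blockTile; w += warpTile)
      for (int t = w; t < w + warpTile; t += threadTile)
        for (int e = t; e < t + threadTile; ++e)
          if (e == 37)  // spot-check one element's coordinates
            std::printf("element %d -> block %d, warp %d, thread %d\n",
                        e, b / blockTile, (w - b) / warpTile, (t - w) / threadTile);
  return 0;
}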
- pm.addPass(createCanonicalizerPass()); - pm.addPass(createCSEPass()); - // Canonicalize away memory copies into itself - pm.addPass(createCanonicalizerPass()); - - // Linalg + GmlSt -> GPU - pm.addNestedPass(createGmlStToGpuPass()); - pm.addNestedPass(gml_st::createGmlStToScfPass()); - pm.addNestedPass(arith::createArithExpandOpsPass()); - pm.addNestedPass(createConvertLinalgToLoopsPass()); - pm.addNestedPass(createCanonicalizerPass()); - pm.addPass(createGpuLauchSinkIndexComputationsPass()); - constexpr llvm::StringRef kGpuDataLayoutSpec = - "#dlti.dl_spec<#dlti.dl_entry>"; - pm.addPass(createGpuKernelOutliningPass(kGpuDataLayoutSpec)); - pm.addNestedPass(createForLoopSpecializationPass()); - pm.addNestedPass(hlo::createUnrollLoopsPass()); - pm.addNestedPass(createLowerAffinePass()); - pm.addNestedPass(createCanonicalizerPass()); - pm.addNestedPass(createConvertSCFToCFPass()); - - // GPU -> low-level IR -#if TENSORFLOW_USE_ROCM - pm.addNestedPass(createGpuKernelToRocdlPass()); -#else - pm.addNestedPass(createGpuKernelToNvvmPass()); -#endif - pm.addPass(createPropagateStaticShapesToKernelPass()); - pm.addNestedPass(createCSEPass()); - // Some instructions crash ptxas down the line if they have debug info - // attached. - pm.addNestedPass(createStripDebugInfoPass()); - pm.addNestedPass(hlo::createAllocToArgPass()); -} diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Transforms/inline_fusion_pass.cc b/tensorflow/compiler/xla/mlir_hlo/lib/Transforms/inline_fusion_pass.cc deleted file mode 100644 index f91540153c5..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Transforms/inline_fusion_pass.cc +++ /dev/null @@ -1,71 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This files implements a pass that inlines all mlho.fusion op regions. - -#include -#include -#include -#include - -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Transforms/passes.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/Value.h" - -namespace mlir { - -#define GEN_PASS_DEF_INLINEFUSIONPASS -#include "mlir-hlo/Transforms/passes.h.inc" - -using ::mlir::func::FuncOp; - -namespace { -// Inlines all mhlo.fusion op regions. 
-class InlineFusionPass : public impl::InlineFusionPassBase { - public: - using InlineFusionPassBase::InlineFusionPassBase; - - private: - void runOnOperation() override; -}; -} // namespace - -void InlineFusionPass::runOnOperation() { - FuncOp funcOp = getOperation(); - IRRewriter rewriter(funcOp.getContext()); - funcOp->walk([&](mhlo::FusionOp fusionOp) { - assert(fusionOp.getFusedComputation().hasOneBlock()); - rewriter.setInsertionPoint(fusionOp); - BlockAndValueMapping bvm; - Block& body = *fusionOp.getFusedComputation().begin(); - bvm.map(body.getArguments(), fusionOp.getInputs()); - for (auto& op : body.without_terminator()) rewriter.clone(op, bvm); - auto results = llvm::map_range( - body.getTerminator()->getOperands(), - [&](Value operand) { return bvm.lookupOrDefault(operand); }); - rewriter.replaceOp(fusionOp, llvm::to_vector(results)); - }); -} - -std::unique_ptr> hlo::createInlineFusionPass() { - return std::make_unique(); -} - -} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/mhlo/CMakeLists.txt similarity index 92% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/CMakeLists.txt rename to tensorflow/compiler/xla/mlir_hlo/mhlo/CMakeLists.txt index e138afa587f..a4a8881e204 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/CMakeLists.txt +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/CMakeLists.txt @@ -13,5 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +add_subdirectory(analysis) add_subdirectory(IR) add_subdirectory(transforms) +add_subdirectory(utils) diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/CMakeLists.txt similarity index 68% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/CMakeLists.txt rename to tensorflow/compiler/xla/mlir_hlo/mhlo/IR/CMakeLists.txt index d2750882594..b982ed73435 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/CMakeLists.txt +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/CMakeLists.txt @@ -13,6 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
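[Editor's note: the deleted InlineFusionPass clones each op of the fusion body at the fusion op's position while a BlockAndValueMapping remaps the body's block arguments to the fusion operands, then replaces the fusion's results with the remapped terminator operands. The standalone sketch below models that clone-with-remap step on a toy instruction list; all names are illustrative.]

// Minimal model of inlining-by-cloning with a value map: body instructions
// refer to "block arguments"; when cloned into the parent those references
// are remapped to the fusion's actual operands, and the (remapped) terminator
// operands become the fusion's replacement results.
#include <cstdio>
#include <map>
#include <string>
#include <vector>

struct Instr {  // toy op: result = add lhs, rhs (names only)
  std::string result, lhs, rhs;
};

int main() {
  // Fusion body: %arg0/%arg1 are block arguments, %t is a body-local value.
  std::vector<Instr> body = {{"%t", "%arg0", "%arg1"}, {"%out", "%t", "%arg0"}};
  std::string terminatorOperand = "%out";

  // BlockAndValueMapping analogue: block args -> fusion operands.
  std::map<std::string, std::string> bvm = {{"%arg0", "%x"}, {"%arg1", "%y"}};
  auto lookup = [&](const std::string& v) {
    auto it = bvm.find(v);
    return it == bvm.end() ? v : it->second;  // lookupOrDefault
  };

  for (const Instr& op : body) {              // clone ops in order
    Instr cloned{op.result + "_inlined", lookup(op.lhs), lookup(op.rhs)};
    bvm[op.result] = cloned.result;           // later uses see the clone
    std::printf("%s = add %s, %s\n", cloned.result.c_str(),
                cloned.lhs.c_str(), cloned.rhs.c_str());
  }
  std::printf("fusion result replaced by %s\n", lookup(terminatorOperand).c_str());
  return 0;
}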
# +# Need a separate function because of the .cc vs .cpp used in the one provided by MLIR +set(LLVM_TARGET_DEFINITIONS hlo_ops.td) +mlir_tablegen(hlo_ops.h.inc -gen-op-decls) +mlir_tablegen(hlo_ops.cc.inc -gen-op-defs) +mlir_tablegen(hlo_ops_enums.h.inc -gen-enum-decls) +mlir_tablegen(hlo_ops_enums.cc.inc -gen-enum-defs) +mlir_tablegen(hlo_ops_attrs.h.inc -gen-attrdef-decls) +mlir_tablegen(hlo_ops_attrs.cc.inc -gen-attrdef-defs) +mlir_tablegen(hlo_ops_typedefs.h.inc -gen-typedef-decls --typedefs-dialect=mhlo) +mlir_tablegen(hlo_ops_typedefs.cc.inc -gen-typedef-defs --typedefs-dialect=mhlo) +add_public_tablegen_target(MLIRhlo_opsIncGen) +add_dependencies(mlir-headers MLIRhlo_opsIncGen) + include_directories(BEFORE ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR}) @@ -49,12 +62,14 @@ target_link_libraries(MhloDialect MLIRQuantDialect MLIRSparseTensorDialect HloOpsCommon + StablehloAssemblyFormat StablehloBase + StablehloTypeInference ) target_include_directories(MhloDialect PUBLIC + $ $ - $ ) add_mlir_dialect_library(MhloRegisterDialects diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/chlo_canonicalize.td b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/chlo_canonicalize.td similarity index 83% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/chlo_canonicalize.td rename to tensorflow/compiler/xla/mlir_hlo/mhlo/IR/chlo_canonicalize.td index a8ffd26b7de..dcc0ab46366 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/chlo_canonicalize.td +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/chlo_canonicalize.td @@ -16,8 +16,8 @@ limitations under the License. // This is the canonicalize pattern definition file. include "mlir/IR/OpBase.td" -include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" -include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td" +include "mhlo/IR/hlo_ops.td" +include "mhlo/IR/hlo_utils.td" def UnaryToBinaryEinsumEq : NativeCodeCall< "$_builder.getStringAttr(\",\" + $0.getValue().str())">; @@ -25,6 +25,6 @@ def UnaryToBinaryEinsumEq : NativeCodeCall< // Convert UnaryEinsumOp to EinsumOp with two operands with redundant first // operand. def UnaryEinsumToEinsum : Pat< - (HLO_UnaryEinsumOp $operand, $equation), - (HLO_EinsumOp (HLO_ConstantOp (GetScalarOfType<1> $operand)), + (MHLO_UnaryEinsumOp $operand, $equation), + (MHLO_EinsumOp (MHLO_ConstantOp (GetScalarOfType<1> $operand)), $operand, (UnaryToBinaryEinsumEq $equation))>; diff --git a/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_base.td b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_base.td new file mode 100644 index 00000000000..cf50dc870d6 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_base.td @@ -0,0 +1,122 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef MLIR_HLO_DIALECT_MHLO_IR_HLO_BASE +#define MLIR_HLO_DIALECT_MHLO_IR_HLO_BASE + +include "stablehlo/dialect/Base.td" + +//===----------------------------------------------------------------------===// +// MHLO type definitions. +//===----------------------------------------------------------------------===// + +defvar MHLO_Pred = HLO_Pred; + +defvar MHLO_SInt = HLO_SInt; +defvar MHLO_UInt = HLO_UInt; +defvar MHLO_Int = HLO_Int; + +defvar MHLO_Float = HLO_Float; +defvar MHLO_Float32Or64 = HLO_Float32Or64; + +defvar MHLO_Complex = HLO_Complex; + +//===----------------------------------------------------------------------===// +// Quantized element type definitions. +//===----------------------------------------------------------------------===// + +// Integer-based uniform quantized types. The definitions can be used to specify +// operand's tensor types. +defvar MHLO_QuantizedSignedInt = HLO_QuantizedSignedInt; +defvar MHLO_QuantizedUnsignedInt = HLO_QuantizedUnsignedInt; +defvar MHLO_QuantizedInt = HLO_QuantizedInt; + +// The broadcasting dimensions correspond to a tuple that describes how a +// smaller rank shape is broadcast into a larger rank shape. For example, +// given a 2x3x4 cuboid and a 3x4 matrix, a broadcasting tuple (1,2) means +// matching the matrix to dimensions 1 and 2 of the cuboid. +defvar MHLO_BroadcastDimAttr = I64ElementsAttr; + +// Token type. +defvar MHLO_Token = HLO_Token; + +// Any integer tensor types +defvar MHLO_IntTensor = HLO_IntTensor; + +// Any integer tensor type with rank 0 (i.e. representing a single integer). +defvar MHLO_ScalarIntTensor = HLO_ScalarIntTensor; + +// Any floating-point tensor types +defvar MHLO_FpTensor = HLO_FpTensor; + +// 32 or 64 bits floating-point tensor types +defvar MHLO_Fp32Or64Tensor = HLO_Fp32Or64Tensor; + +// Any quantized integer tensor types +defvar MHLO_QuantizedIntTensor = HLO_QuantizedIntTensor; + +defvar MHLO_PredTensor = HLO_PredTensor; + +defvar MHLO_Tensor = HLO_Tensor; + +defvar MHLO_ComplexTensor = HLO_ComplexTensor; + +defvar MHLO_Tuple = HLO_Tuple; + +defvar MHLO_TensorOrToken = HLO_TensorOrToken; + +defvar MHLO_TensorOrTokenOrTuple = AnyTypeOf<[MHLO_Tensor, MHLO_Token, MHLO_Tuple]>; + +defvar MHLO_DimensionValue = HLO_DimensionValue; + +// Dynamic representation of a shape vector as a tensor. +defvar MHLO_DimensionTensor = HLO_DimensionTensor; + +//===----------------------------------------------------------------------===// +// MHLO combined type definitions. +//===----------------------------------------------------------------------===// + +// Any integer or floating-point tensor types + +// Any integer or floating-point tensor types +defvar MHLO_IntOrFpTensor = HLO_IntOrFpTensor; + +// Any integer or predicate tensor types +defvar MHLO_PredOrIntTensor = HLO_PredOrIntTensor; + +// Any floating-point or complex tensor types +defvar MHLO_FpOrComplexTensor = HLO_FpOrComplexTensor; + +// Any int, floating-point or complex tensor types +defvar MHLO_IntFpOrComplexTensor = HLO_IntFpOrComplexTensor; + +// Any pred, int or floating-point tensor types +defvar MHLO_PredIntOrFpTensor = HLO_PredIntOrFpTensor; + +//===----------------------------------------------------------------------===// +// MHLO static shape type definitions. 
+//===----------------------------------------------------------------------===// + +// In general, static shaped tensor constraints should be avoided unless +// it is for a legacy op which is only correct with static shapes. +defvar MHLO_StaticShapeTensor = HLO_StaticShapeTensor; + +defvar MHLO_StaticShapeTensorOrToken = HLO_StaticShapeTensorOrToken; + +defvar MHLO_StaticShapeIntOrFpTensor = HLO_StaticShapeIntOrFpTensor; + +defvar MHLO_StaticShapeIntFpOrComplexTensor = HLO_StaticShapeIntFpOrComplexTensor; + +#endif // MLIR_HLO_DIALECT_MHLO_IR_HLO_BASE diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.cc similarity index 68% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/hlo_ops.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.cc index 27d2ce4a410..2957fb2f18d 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.cc @@ -15,7 +15,7 @@ limitations under the License. // This file defines the operations used in the MHLO dialect. -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mhlo/IR/hlo_ops.h" #include #include @@ -27,6 +27,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -47,11 +48,9 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/MathExtras.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h" -#include "mlir-hlo/Dialect/mhlo/IR/mhlo_bytecode.h" -#include "mlir-hlo/utils/convert_op_folder.h" -#include "mlir-hlo/utils/hlo_utils.h" +#include "mhlo/IR/hlo_ops.h.inc" +#include "mhlo/IR/hlo_ops_common.h" +#include "mhlo/IR/mhlo_bytecode.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -81,16 +80,21 @@ limitations under the License. #include "mlir/Transforms/InliningUtils.h" #include "stablehlo/dialect/AssemblyFormat.h" #include "stablehlo/dialect/TypeInference.h" +#include "utils/convert_op_folder.h" +#include "utils/hlo_utils.h" namespace mlir { #include "hlo_patterns.cc.inc" } // namespace mlir -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_enums.cc.inc" +using mlir::hlo::parseDimSizes; +using mlir::hlo::printDimSizes; + +#include "mhlo/IR/hlo_ops_enums.cc.inc" #define GET_ATTRDEF_CLASSES -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_attrs.cc.inc" +#include "mhlo/IR/hlo_ops_attrs.cc.inc" #define GET_TYPEDEF_CLASSES -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_typedefs.cc.inc" +#include "mhlo/IR/hlo_ops_typedefs.cc.inc" namespace mlir { namespace mhlo { @@ -162,12 +166,6 @@ void createArgs(ArrayRef operands, } } -// Checks if the vector `nums` has duplicates. -const auto hasDuplicates = [](const ArrayRef nums) { - llvm::SmallDenseSet set(nums.begin(), nums.end()); - return set.size() != nums.size(); -}; - //===----------------------------------------------------------------------===// // Utilities for the canonicalize patterns //===----------------------------------------------------------------------===// @@ -232,7 +230,7 @@ static void replaceOpWithRegion(PatternRewriter& rewriter, Operation* op, rewriter.eraseOp(terminator); } -#include "mhlo_canonicalize.inc" +#include "mhlo/IR/mhlo_canonicalize.inc" // Common shape function helper for RngNormal and RngUniform. 
static LogicalResult rngInferReturnTypeComponents( @@ -260,7 +258,7 @@ static LogicalResult rngInferReturnTypeComponents( inferredReturnShapes.emplace_back(elementType); return success(); } - shapeVector.resize(size, ShapedType::kDynamicSize); + shapeVector.resize(size, ShapedType::kDynamic); inferredReturnShapes.emplace_back(shapeVector, elementType); return success(); } @@ -277,11 +275,13 @@ static LogicalResult rngInferReturnTypeComponents( // an integer or index type. Value maybeCastTo(OpBuilder& b, Location loc, Value value, Type type) { if (type == value.getType()) return value; + // DISC-Begin if (!type.isIndex() && !value.getType().isIndex()) { // in case of i32 -> i64 or vice versa Value casted = b.create(loc, b.getIndexType(), value); return b.create(loc, type, casted); } + // DISC-End return b.create(loc, type, value); } @@ -351,10 +351,10 @@ FailureOr>> convertNx2Attribute( //===----------------------------------------------------------------------===// LogicalResult TypeExtensionsAttr::verifyEncoding( - llvm::ArrayRef bounds, mlir::Type elementType, + llvm::ArrayRef shape, mlir::Type elementType, llvm::function_ref emitError) const { return hlo::verifyBounds( - getBounds(), RankedTensorType::get(bounds, elementType), emitError); + getBounds(), RankedTensorType::get(shape, elementType), emitError); } //===----------------------------------------------------------------------===// @@ -373,24 +373,9 @@ void CollectivePermuteOp::build(OpBuilder& odsBuilder, OperationState& odsState, //===----------------------------------------------------------------------===// LogicalResult ReduceScatterOp::verify() { - if (failed(mlir::hlo::verifyReplicaGroups(*this, /*isUniformSized=*/true))) - return failure(); - auto operandType = getOperand().getType().cast(); - bool operandTypeRanked = operandType.isa(); - Block& block = getComputation().front(); - SmallVector accumulatorSubshapes; - if (failed(hlo::verifyReducerShape( - this->getLoc(), block, {operandType}, - {RankedTensorType::get({}, operandType.getElementType())}, - /*numInputs=*/1, /*allowedDimensions=*/{}, - /*allInputsUnranked=*/!operandTypeRanked, accumulatorSubshapes))) - return failure(); - - return mlir::hlo::verifyReduceScatter( - *this, - /*operandTypes=*/{getOperand().getType()}, - /*resultTypes=*/{getType()}, - /*scatterDimension=*/getScatterDimension()); + return hlo::verifyReduceScatterOp( + getLoc(), getOperand(), getScatterDimension(), getReplicaGroups(), + getUseGlobalDeviceIds(), getComputation(), getResult()); } void ReduceScatterOp::build(OpBuilder& odsBuilder, OperationState& odsState, @@ -460,6 +445,7 @@ INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SignOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SineOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SqrtOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SubtractOp) +INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(TanOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(TanhOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(XorOp) @@ -468,7 +454,7 @@ INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(XorOp) //===----------------------------------------------------------------------===// Type maybeTupleFromTypes(MLIRContext* ctx, ArrayRef types) { - if (types.size() == 1) return types[0]; + if (types.size() == 1 && !types[0].isa()) return types[0]; return TupleType::get(ctx, TypeRange(types)); } @@ -500,7 +486,7 @@ LogicalResult AsyncStartOp::verify() { << ", but expected: " << calleeType.getNumInputs() << "."; } - for (int i = 0; i < getOperands().size(); ++i) { + for (int i = 0; i < 
static_cast(getOperands().size()); ++i) { if (calleeType.getInput(i) != getOperandTypes()[i]) { return emitOpError() << "type mismatch on argument #" << i << " of " << getCalledComputation() @@ -631,12 +617,23 @@ LogicalResult AsyncDoneOp::inferReturnTypes( return success(); } +//===----------------------------------------------------------------------===// +// AfterAllOp +//===----------------------------------------------------------------------===// + +LogicalResult AfterAllOp::inferReturnTypes( + MLIRContext* context, Optional location, ValueRange, + DictionaryAttr, RegionRange, SmallVectorImpl& inferredReturnTypes) { + auto dialect = context->getLoadedDialect(); + return hlo::inferAfterAllOp(dialect, location, inferredReturnTypes); +} + //===----------------------------------------------------------------------===// // ConstantOp //===----------------------------------------------------------------------===// -OpFoldResult ConstantOp::fold(ArrayRef operands) { - assert(operands.empty() && "constant has no operands"); +OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { + assert(adaptor.getOperands().empty() && "constant has no operands"); // Return the held attribute value. return getValue(); @@ -669,13 +666,12 @@ void ConstantOp::build(OpBuilder& /*builder*/, OperationState& result, } LogicalResult ConstantOp::inferReturnTypes( - MLIRContext*, Optional, ValueRange operands, + MLIRContext*, Optional location, ValueRange operands, DictionaryAttr attributes, RegionRange, SmallVectorImpl& inferredReturnTypes) { ConstantOpAdaptor adaptor(operands, attributes); - Type type = adaptor.getValue().getType(); - inferredReturnTypes.push_back(type); - return success(); + return hlo::inferConstantOp(location, adaptor.getValue(), + inferredReturnTypes); } bool ConstantOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { @@ -737,6 +733,73 @@ void ConstantOp::print(::mlir::OpAsmPrinter& p) { p.printStrippedAttrOrType(getValueAttr()); } +//===----------------------------------------------------------------------===// +// Helper function to verify output operand aliasing (FusionOp and CustomCallOp) +//===----------------------------------------------------------------------===// + +template +LogicalResult verifyOutputOperandAliasing(CallableOpType* op) { + auto aliasArrayAttr = op->getOutputOperandAliases(); + for (auto attr : aliasArrayAttr) { + auto alias = attr.template cast(); + auto outputTupleIndices = alias.getOutputTupleIndices(); + auto operandIndex = alias.getOperandIndex(); + auto operandTupleIndices = alias.getOperandTupleIndices(); + if (operandIndex < 0 || + operandIndex >= static_cast(op->getInputs().size())) + return op->emitOpError() + << "expects operandIndex in the output_operand_alias attribute " + "to be in range [0, " + << op->getInputs().size() << "); got: " << operandIndex << "."; + Type operandPart = op->getOperand(operandIndex).getType(); + for (auto i : operandTupleIndices) { + if (!operandPart.isa() || + i >= static_cast(operandPart.cast().size()) || + i < 0) + return op->emitOpError() + << "operand_tuple_indices in the output_operand_alias " + "attribute out of bounds"; + operandPart = operandPart.cast().getType(i); + } + Type outputPart = + op->getNumResults() > 1 + ? 
TupleType::get(op->getContext(), op->getResultTypes()) + : op->getResult(0).getType(); + for (auto i : outputTupleIndices) { + if (!outputPart.isa() || + i >= static_cast(outputPart.cast().size()) || + i < 0) + return op->emitOpError() + << "output_tuple_indices in the output_operand_alias " + "attribute out of bounds"; + outputPart = outputPart.cast().getType(i); + } + if (operandPart != outputPart) + return op->emitOpError() + << "shapes mismatch in the output_operand_alias attribute: " + << "operand part has type " << operandPart + << " and output part has type " << outputPart; + } + return success(); +} + +//===----------------------------------------------------------------------===// +// FusionOp +//===----------------------------------------------------------------------===// + +LogicalResult FusionOp::verify() { return verifyOutputOperandAliasing(this); } + +//===----------------------------------------------------------------------===// +// CreateTokenOp +//===----------------------------------------------------------------------===// + +LogicalResult CreateTokenOp::inferReturnTypes( + MLIRContext* context, Optional location, ValueRange, + DictionaryAttr, RegionRange, SmallVectorImpl& inferredReturnTypes) { + auto dialect = context->getLoadedDialect(); + return hlo::inferCreateTokenOp(dialect, location, inferredReturnTypes); +} + //===----------------------------------------------------------------------===// // CustomCallOp //===----------------------------------------------------------------------===// @@ -749,10 +812,12 @@ void CustomCallOp::build( ::mlir::mhlo::CustomCallApiVersionAttr apiVersion, ::mlir::ArrayAttr calledComputations, ::mlir::ArrayAttr operandLayouts, ::mlir::ArrayAttr resultLayouts) { - return CustomCallOp::build(odsBuilder, odsState, resultType, operands, - callTargetName, hasSideEffect, backendConfig, - apiVersion, calledComputations, operandLayouts, - resultLayouts, nullptr); + return CustomCallOp::build( + odsBuilder, odsState, resultType, operands, callTargetName, hasSideEffect, + backendConfig, apiVersion, calledComputations, + CustomCallScheduleAttr::get(odsBuilder.getContext(), + CustomCallSchedule::NONE), + operandLayouts, resultLayouts, nullptr); } LogicalResult CustomCallOp::verify() { @@ -846,46 +911,25 @@ LogicalResult CustomCallOp::verify() { } // Check output_operand_aliases + if (failed(verifyOutputOperandAliasing(this))) return failure(); - auto aliasArrayAttr = getOutputOperandAliases(); - for (auto attr : aliasArrayAttr) { - auto alias = attr.cast(); - auto outputTupleIndices = alias.getOutputTupleIndices(); - auto operandIndex = alias.getOperandIndex(); - auto operandTupleIndices = alias.getOperandTupleIndices(); - - if (operandIndex < 0 || operandIndex >= getInputs().size()) - return emitOpError() - << "expects operandIndex in the output_operand_alias attribute " - "to be in range [0, " - << getInputs().size() << "); got: " << operandIndex << "."; - - Type operandPart = getOperand(operandIndex).getType(); - for (auto i : operandTupleIndices) { - if (!operandPart.isa() || - i >= operandPart.cast().size() || i < 0) + // Check backend_config attribute. + if (auto backendConfig = getBackendConfig()) { + if (getApiVersion() == CustomCallApiVersion::API_VERSION_TYPED_FFI) { + // Typed FFI custom calls require `backend_config` to be a DictionaryAttr. 
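[Editor's note: verifyOutputOperandAliasing above walks operand_tuple_indices and output_tuple_indices into nested tuple types, rejecting non-tuple intermediates and out-of-range indices, and finally compares the two leaf types. The standalone sketch below models that index-path descent with a toy recursive type; the type strings are illustrative.]

// Standalone model of resolving a tuple index path: descend one index at a
// time, fail if the current type is not a tuple or the index is out of
// bounds, and return the leaf so both sides of the alias can be compared.
#include <iostream>
#include <memory>
#include <optional>
#include <string>
#include <vector>

struct Type {
  std::string scalar;                        // non-empty => leaf type
  std::vector<std::shared_ptr<Type>> tuple;  // non-empty => tuple type
};

std::optional<std::string> resolve(std::shared_ptr<Type> t,
                                   const std::vector<int>& path) {
  for (int i : path) {
    if (t->tuple.empty()) return std::nullopt;                              // not a tuple
    if (i < 0 || i >= static_cast<int>(t->tuple.size())) return std::nullopt;  // out of bounds
    t = t->tuple[i];                                                        // descend
  }
  return t->scalar.empty() ? std::optional<std::string>("tuple") : t->scalar;
}

int main() {
  auto f32 = std::make_shared<Type>(Type{"tensor<4xf32>", {}});
  auto i32 = std::make_shared<Type>(Type{"tensor<i32>", {}});
  auto inner = std::make_shared<Type>(Type{"", {f32, i32}});
  auto result = std::make_shared<Type>(Type{"", {inner, i32}});  // ((f32, i32), i32)

  auto leaf = resolve(result, {0, 1});
  std::cout << (leaf ? *leaf : "out of bounds") << "\n";              // tensor<i32>
  std::cout << (resolve(result, {2}) ? "ok" : "index out of bounds") << "\n";
  return 0;
}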
+ if (backendConfig->isa()) return emitOpError() - << "operand_tuple_indices in the output_operand_alias " - "attribute out of bounds"; - operandPart = operandPart.cast().getType(i); - } - Type outputPart = getNumResults() > 1 - ? TupleType::get(getContext(), getResultTypes()) - : getResult(0).getType(); - for (auto i : outputTupleIndices) { - if (!outputPart.isa() || - i >= outputPart.cast().size() || i < 0) + << "unsupported user-encoded backend config," + " backend config must be a dictionary attribute."; + } else { + // Older API versions require user-encoded `backend_config` string. + if (backendConfig->isa()) return emitOpError() - << "output_tuple_indices in the output_operand_alias " - "attribute out of bounds"; - outputPart = outputPart.cast().getType(i); + << "unsupported dictionary attribute backend config, backend" + " config must be a user-encoded string attribute."; } - if (operandPart != outputPart) - return emitOpError() - << "shapes mismatch in the output_operand_alias attribute: " - << "operand part has type " << operandPart - << " and output part has type " << outputPart; } + return success(); } @@ -906,111 +950,24 @@ void CustomCallOp::getEffects( // CholeskyOp //===----------------------------------------------------------------------===// -// The following properties are already enforced by the ODS: -// P0. a.element_type is floating or complex -// We intend to verify the following properties -// P1. The 'a' argument to Cholesky must have rank >= 2, got shape %s -// P2. The two minor dimensions of 'a' must have equal size, got %s. LogicalResult CholeskyOp::inferReturnTypeComponents( MLIRContext*, Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { CholeskyOp::Adaptor adaptor(operands, attributes, regions); - Type aType = adaptor.getA().getType(); - RankedTensorType aRankedType = aType.dyn_cast(); - if (!aRankedType) { - inferredReturnShapes.emplace_back( - aType.cast().getElementType()); - return success(); - } - - ArrayRef aShape = aRankedType.getShape(); - if (aShape.size() < 2) { - return emitOptionalError( - location, "argument 'a' must have rank >= 2, got shape ", aShape, "."); - } - - int64_t lastDim = aShape[aShape.size() - 1]; - int64_t penultimateDim = aShape[aShape.size() - 2]; - if (!hlo::isDynamicDimSize(lastDim) && - !hlo::isDynamicDimSize(penultimateDim) && lastDim != penultimateDim) { - return emitOptionalError( - location, "minor dimensions of 'a' must have equal size, got shape ", - aShape, "."); - } - inferredReturnShapes.emplace_back(aRankedType.getShape(), - aRankedType.getElementType()); - return success(); + return hlo::inferCholeskyOp(location, adaptor.getA(), inferredReturnShapes); } //===----------------------------------------------------------------------===// // DotOp //===----------------------------------------------------------------------===// -namespace { -bool dimCompatible(int64_t a, int64_t b) { - return hlo::isDynamicDimSize(a) || hlo::isDynamicDimSize(b) || a == b; -} - -ShapedType inferDotReturnType(ShapedType lhs, ShapedType rhs) { - auto elementType = lhs.getElementType(); - if (!lhs.hasRank() || !rhs.hasRank()) { - return UnrankedTensorType::get(elementType); - } - - // vector dot vector - if (1 == lhs.getRank() && 1 == rhs.getRank() && - dimCompatible(lhs.getDimSize(0), rhs.getDimSize(0))) { - return RankedTensorType::get({}, elementType); - } - // matrix dot vector - if (2 == lhs.getRank() && 1 == rhs.getRank() && - 
dimCompatible(lhs.getDimSize(1), rhs.getDimSize(0))) { - return RankedTensorType::get({lhs.getDimSize(0)}, elementType); - } - // vector dot matrix - if (1 == lhs.getRank() && 2 == rhs.getRank() && - dimCompatible(lhs.getDimSize(0), rhs.getDimSize(0))) { - return RankedTensorType::get({rhs.getDimSize(1)}, elementType); - } - // matrix dot matrix - if (2 == lhs.getRank() && 2 == rhs.getRank() && - dimCompatible(lhs.getDimSize(1), rhs.getDimSize(0))) { - int64_t shape[2] = {lhs.getDimSize(0), rhs.getDimSize(1)}; - return RankedTensorType::get(shape, elementType); - } - return {}; -} -} // namespace - -LogicalResult DotOp::inferReturnTypes( - MLIRContext*, Optional, ValueRange operands, DictionaryAttr, - RegionRange, SmallVectorImpl& inferredReturnTypes) { - DotOp::Adaptor op(operands); - auto lhsType = op.getLhs().getType().cast(); - auto rhsType = op.getRhs().getType().cast(); - inferredReturnTypes.push_back(inferDotReturnType(lhsType, rhsType)); - return success(); -} LogicalResult DotOp::verify() { - auto lhsType = getLhs().getType().cast(); - auto rhsType = getRhs().getType().cast(); - auto resultType = getType().cast(); - auto expectReturnType = inferDotReturnType(lhsType, rhsType); - if (!expectReturnType) { - return emitError() << "Unexpected operands type: " << lhsType << " and " - << rhsType; - } - if (resultType.hasRank() && expectReturnType.hasRank()) { - if (resultType.getShape() != expectReturnType.getShape()) { - return emitError() << "Unexpected result type: has " << resultType - << " but inferred " << expectReturnType - << " from operands " << lhsType << " and " << rhsType; - } - } - return success(); + return hlo::verifyDotOp(getLoc(), getLhs(), getRhs(), getPrecisionConfig(), + getResult()); } +// DISC-Begin LogicalResult DotOp::reifyReturnTypeShapes( OpBuilder& builder, ValueRange operands, SmallVectorImpl& reifiedReturnShapes) { @@ -1052,117 +1009,20 @@ LogicalResult DotOp::reifyReturnTypeShapes( return success(); } - +// DISC-End //===----------------------------------------------------------------------===// // DotGeneralOp //===----------------------------------------------------------------------===// LogicalResult DotGeneralOp::verify() { - auto dimNumbers = this->getDotDimensionNumbers(); - - ArrayRef lhsBatchingDims = dimNumbers.getLhsBatchingDimensions(); - ArrayRef rhsBatchingDims = dimNumbers.getRhsBatchingDimensions(); - ArrayRef lhsContractingDims = - dimNumbers.getLhsContractingDimensions(); - ArrayRef rhsContractingDims = - dimNumbers.getRhsContractingDimensions(); - - if (lhsBatchingDims.size() != rhsBatchingDims.size()) { - return emitOpError() << "lhs and rhs should have the same number of " - "batching dimensions"; - } - if (lhsContractingDims.size() != rhsContractingDims.size()) { - return emitOpError() << "lhs and rhs should have the same number of " - "contracting dimensions"; - } - - llvm::SmallDenseSet dimSet; - - auto checkDimsDistinct = - [this](ArrayRef batchingDims, ArrayRef contractingDims, - llvm::SmallDenseSet& dimSet, llvm::StringRef lhs, - llvm::StringRef rhs) -> LogicalResult { - auto dims = llvm::concat(batchingDims, contractingDims); - for (auto dim : dims) { - auto [_, wasInserted] = dimSet.insert(dim); - if (!wasInserted) { - return emitOpError() << "has duplicated dimension from " << lhs - << " and " << rhs << ": " << dim; - } - } - return success(); - }; - - if (failed(checkDimsDistinct(lhsBatchingDims, lhsContractingDims, dimSet, - "lhs_batching_dimensions", - "lhs_contracting_dimensions"))) { - return failure(); - } - 
dimSet.clear(); - if (failed(checkDimsDistinct(rhsBatchingDims, rhsContractingDims, dimSet, - "rhs_batching_dimensions", - "rhs_contracting_dimensions"))) { - return failure(); - } - - auto checkDimsInRange = [this](int64_t rank, ArrayRef dims, - llvm::StringRef dimName) -> LogicalResult { - auto inRange = [&](int64_t i) -> bool { return 0 <= i && i < rank; }; - const auto* dimsNotInRange = - std::find_if_not(dims.begin(), dims.end(), inRange); - if (dimsNotInRange != dims.end()) { - return emitOpError() << dimName << " value: " << *dimsNotInRange - << " is out of range: " - << "[0, " << rank << ")"; - } - return success(); - }; - - auto lhsType = this->getLhs().getType().dyn_cast(); - auto rhsType = this->getRhs().getType().dyn_cast(); - - if (lhsType) { - if (failed(checkDimsInRange(lhsType.getRank(), lhsBatchingDims, - "lhs_batching_dimensions")) || - failed(checkDimsInRange(lhsType.getRank(), lhsContractingDims, - "lhs_contracting_dimensions"))) { - return failure(); - } - } - if (rhsType) { - if (failed(checkDimsInRange(rhsType.getRank(), rhsBatchingDims, - "rhs_batching_dimensions")) || - failed(checkDimsInRange(rhsType.getRank(), rhsContractingDims, - "rhs_contracting_dimensions"))) { - return failure(); - } - } - - // =================== BEGIN Added for DISC ====================== - // tf.BatchMatmul(tensor, tensor<4x?x?xf32>) is valid, while the - // tf.BatchMatmul tf2mhlo converter does not handle shape propagation between - // the lhs & rhs, leading to following check fail. Just disable the check - // as a workaround. - // if (lhsType && rhsType) { - // // Dimension sizes must be compatible for lhs/rhs. - // auto lhsShape = lhsType.getShape(); - // auto rhsShape = rhsType.getShape(); - - // for (auto [lhs, rhs] : llvm::zip(lhsBatchingDims, rhsBatchingDims)) { - // if (lhsShape[lhs] != rhsShape[rhs]) { - // return emitOpError() << "batching dimension sizes must match for " - // "lhs/rhs"; - // } - // } - // for (auto [lhs, rhs] : llvm::zip(lhsContractingDims, rhsContractingDims)) { - // if (lhsShape[lhs] != rhsShape[rhs]) { - // return emitOpError() << "contracting dimension sizes must match for " - // "lhs/rhs"; - // } - // } - // } - return success(); + return hlo::verifyDotGeneralOp( + getLoc(), getLhs(), getRhs(), + getDotDimensionNumbersAttr().getLhsBatchingDimensions(), + getDotDimensionNumbersAttr().getRhsBatchingDimensions(), + getDotDimensionNumbersAttr().getLhsContractingDimensions(), + getDotDimensionNumbersAttr().getRhsContractingDimensions(), + getPrecisionConfig(), getResult()); } namespace { @@ -1177,19 +1037,34 @@ struct DotGeneralToDot : public OpRewritePattern { auto lhsTy = lhs.getType().cast(); auto rhsTy = rhs.getType().cast(); - if (lhsTy.getRank() != 2) return failure(); - if (rhsTy.getRank() != 2) return failure(); + int64_t lhsRank = lhsTy.getRank(); + int64_t rhsRank = rhsTy.getRank(); + if ((lhsRank != 1 && lhsRank != 2) || (rhsRank != 1 && rhsRank != 2)) { + return rewriter.notifyMatchFailure( + dot, "input tensors must have rank of 1 or 2"); + } auto nums = dot.getDotDimensionNumbers(); - if (!nums.getLhsBatchingDimensions().empty()) return failure(); - if (!nums.getRhsBatchingDimensions().empty()) return failure(); + if ((!nums.getLhsBatchingDimensions().empty()) || + (!nums.getRhsBatchingDimensions().empty())) { + return rewriter.notifyMatchFailure(dot, "cannot have batch dimensions"); + } auto lhsContract = nums.getLhsContractingDimensions(); auto rhsContract = nums.getRhsContractingDimensions(); - if (lhsContract.size() != 1 || rhsContract.size() 
!= 1) return failure(); - if (lhsContract.front() != 1) return failure(); - if (rhsContract.front() != 0) return failure(); + if (lhsContract.size() != 1 || rhsContract.size() != 1) { + return rewriter.notifyMatchFailure( + dot, "input tensors must only have 1 contracting dimension"); + } + if (rhsContract.front() != 0) { + return rewriter.notifyMatchFailure( + dot, "rhs must contract the first dimension"); + } + if (lhsContract.front() != lhsRank - 1) { + return rewriter.notifyMatchFailure( + dot, "lhs must contract the last dimension"); + } rewriter.replaceOpWithNewOp( dot, dot.getType(), lhs, rhs, @@ -1247,107 +1122,15 @@ LogicalResult DotGeneralOp::reifyReturnTypeShapes( // FftOp //===----------------------------------------------------------------------===// -// We intend to verify the following properties -// P1. 1 <= rank <= 3 -// P2. Element types agree with fft_type -// P3. Operand shape dimensions agree with fft_length for the given fft_type LogicalResult FftOp::inferReturnTypeComponents( MLIRContext*, Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { FftOp::Adaptor adaptor(operands, attributes, regions); - auto fftLength = adaptor.getFftLength().getValues(); - int64_t fftRank = fftLength.size(); - - // P1. - if (fftRank > 3 || fftRank < 1) { - return emitOptionalError(location, "rank must be between 1 and 3, but got ", - fftRank, "."); - } - - // P2. Element type agreement - // FFT : C -> C - // IFFT : C -> C - // RFFT : R -> C - // IRFFT : C -> R - auto fftType = adaptor.getFftType(); - auto operandType = adaptor.getOperand().getType().cast(); - Type operandElementType = operandType.getElementType(); - // Check the input element type and infer return element type - if (fftType == FftType::RFFT) { - if (!operandElementType.isF32() && !operandElementType.isF64()) { - return emitOptionalError( - location, "RFFT requires f32 or f64 input type, but is given ", - operandElementType, "."); - } - } else { - if (!operandElementType.isa()) { - return emitOptionalError( - location, stringifyFftType(fftType), - " takes a complex tensor as input, but is given ", operandType, "."); - } - } - // Generate the output element type - Type resultElementType = operandElementType; - if (fftType == FftType::RFFT) { // RFFT : R -> C - resultElementType = ComplexType::get(resultElementType); - } else if (fftType == FftType::IRFFT) { // IRFFT : C -> R - resultElementType = operandElementType.cast().getElementType(); - } - - // P3. Check input shape and infer return shape - operandType = operandType.dyn_cast(); - if (!operandType) { - inferredReturnShapes.emplace_back(resultElementType); - return success(); - } - auto operandShape = operandType.getShape(); - if (static_cast(operandShape.size()) < fftRank) { - return emitOptionalError( - location, "operand rank must not be less than fft rank of ", fftRank, - " for operand of type ", operandType, "."); - } - - SmallVector resultShape = to_vector(operandShape); - - if (fftType == FftType::RFFT) { - auto shapeBack = operandShape.take_back(fftRank); - for (auto [operandDim, fftDim] : llvm::zip(shapeBack, fftLength)) { - if (operandDim != fftDim) { - return emitOptionalError( - location, - "RFFT requires innermost dimensions match fft_length. 
Got: ", - operandShape, " but wanted ", fftLength, "."); - } - } - if (fftLength[fftRank - 1] != 0) { - resultShape[resultShape.size() - 1] = fftLength[fftRank - 1] / 2 + 1; - } - } - if (fftType == FftType::IRFFT) { - auto shapeBack = operandShape.take_back(fftRank).drop_back(); - for (auto [operandDim, fftDim] : llvm::zip(shapeBack, fftLength)) { - if (operandDim != fftDim) { - return emitOptionalError(location, - "IRFFT requires non-final dimensions " - "match fft_length. Got: ", - operandShape, " but wanted ", fftLength, - ", and ", operandDim, " != ", fftDim, "."); - } - } - if ((operandShape[operandShape.size() - 1] != 0 || - fftLength[fftRank - 1] != 0) && - operandShape[operandShape.size() - 1] != fftLength[fftRank - 1] / 2 + 1) - return emitOptionalError(location, - "IRFFT requires innermost dimension match " - "fft_length[-1]/2+1. Got: ", - operandShape, " but fft_length is ", fftLength, - "."); - resultShape[resultShape.size() - 1] = fftLength[fftRank - 1]; - } - - inferredReturnShapes.emplace_back(resultShape, resultElementType); - return success(); + return hlo::inferFftOp(location, adaptor.getOperand(), + adaptor.getFftType() == FftType::RFFT, + adaptor.getFftType() == FftType::IRFFT, + adaptor.getFftLength(), inferredReturnShapes); } //===----------------------------------------------------------------------===// @@ -1467,267 +1250,6 @@ void getSliceSizeValues(DynamicGatherOp* /*dGather*/, OpBuilder& builder, } } -// Verify the following properties: -// P1. Verify no repeat in start_index_map. -// P2. Verify 0 <= start_index_map[i] < rank(operand), for every i. -// P3. Verify 0 <= index_vector_dim <= rank(start_indices). -// P4. Verify size(start_index_map) == shape(start_indices)[index_vector_dim]. -// P5. Verify offset_dims is_sorted and no repeated. -// P6. Verify collapsed_slice_dims is_sorted and no repeated. -// P7. Verify rank(operand) == size(offset_dims) + size(collapsed_slice_dims). -// P8. Verify slice_sizes has rank of 1. -// P9. Verify size(slice_sizes) == rank(operand). -// P10. Verify 0 <= collapsed_slice_dims[i] < size(slice_sizes) for all items. -static LogicalResult verifyGather( - ShapeAdaptor operandShape, ShapeAdaptor startIndicesShape, - ShapeAdaptor sliceSizesShape, GatherDimensionNumbersAttr dimensionNumbers, - llvm::function_ref errorEmitter) { - int64_t indexVectorDim = dimensionNumbers.getIndexVectorDim(); - - // Check startIndexMap - auto startIndexMap = to_vector(dimensionNumbers.getStartIndexMap()); - // P1. - if (hasDuplicates(startIndexMap)) - return errorEmitter() << "expects start_index_map to not repeat, got: [" - << startIndexMap << "]"; - - // P2. - for (auto i : llvm::seq(0, startIndexMap.size())) - if (startIndexMap[i] < 0 || - (operandShape.hasRank() && startIndexMap[i] >= operandShape.getRank())) - return errorEmitter() - << "start_index_map[" << i << "]: " << startIndexMap[i] - << " is out of bounds for " - << "operand rank " << operandShape.getRank(); - - if (startIndicesShape.hasRank()) { - // P3. - // index_vector_dim == start_indices.rank implies a trailing 1 on the shape - // of start_indices. 
- if (indexVectorDim > startIndicesShape.getRank() || indexVectorDim < 0) - return errorEmitter() << "index_vector_dim " << indexVectorDim - << " is out of bounds for start indices with rank " - << startIndicesShape.getRank(); - - bool impliedTrailingDim = indexVectorDim == startIndicesShape.getRank(); - if (impliedTrailingDim || !startIndicesShape.isDynamicDim(indexVectorDim)) { - int64_t effectiveDimSize; - if (impliedTrailingDim) - effectiveDimSize = 1; - else - effectiveDimSize = startIndicesShape.getDimSize(indexVectorDim); - // P4. - if (effectiveDimSize != - static_cast(dimensionNumbers.getStartIndexMap().size())) - return errorEmitter() << "start_index_map size (" - << dimensionNumbers.getStartIndexMap().size() - << ") is not equal to size of index dimension (" - << indexVectorDim << ") of start_indices (" - << effectiveDimSize << ")"; - } - } - - // P5. - auto offsetDims = to_vector(dimensionNumbers.getOffsetDims()); - if (!llvm::is_sorted(offsetDims)) - return errorEmitter() << "expects offset_dims to be sorted, got: [" - << offsetDims << "]"; - if (hasDuplicates(offsetDims)) - return errorEmitter() << "expects offset_dims to not repeat, got: [" - << offsetDims << "]"; - - // P6. - auto collapsedSliceDims = to_vector(dimensionNumbers.getCollapsedSliceDims()); - if (!llvm::is_sorted(collapsedSliceDims)) - return errorEmitter() << "expects collapsed_slice_dims to be sorted, got: [" - << collapsedSliceDims << "]"; - if (hasDuplicates(collapsedSliceDims)) - return errorEmitter() - << "expects collapsed_slice_dims to not repeat, got: [" - << collapsedSliceDims << "]"; - - // P7. - int64_t impliedOperandRank = dimensionNumbers.getOffsetDims().size() + - dimensionNumbers.getCollapsedSliceDims().size(); - if (operandShape.hasRank() && operandShape.getRank() != impliedOperandRank) - return errorEmitter() << "offset_dims size (" - << dimensionNumbers.getOffsetDims().size() - << ") plus collapse_slice_dims size (" - << dimensionNumbers.getCollapsedSliceDims().size() - << ") is not equal to operand rank (" - << operandShape.getRank() << ")"; - - // P8. - // This should be fully expressible with type constraints, but it isn't - // obvious how to do that with the current infrastructure. - if (sliceSizesShape.hasRank() && sliceSizesShape.getRank() != 1) - return errorEmitter() << "slice_sizes.rank != 1"; - if (sliceSizesShape.hasStaticShape()) { - int64_t sliceSize = sliceSizesShape.getNumElements(); - - // P9. - if (sliceSize != impliedOperandRank) - return errorEmitter() << "slice_sizes size (" << sliceSize - << ") not equal to (implied) operand rank (" - << impliedOperandRank << ")"; - - // P10. - for (auto dim : dimensionNumbers.getCollapsedSliceDims()) - if (dim < 0 || dim >= sliceSize) - return errorEmitter() << "collapsed dimension " << dim - << " is out of bounds for slice_sizes.size (" - << sliceSize << ")"; - } - - return success(); -} - -// Verify the following properties: -// P1. Verifications by verifyGather(). -// P2. Verify slice_sizes[i] <= 1 for i in collapsed_slice_dims. -// P3. Verify 0 <= slice_sizes[i] < shape(operand)[i], for every i. -static LogicalResult verifyStaticGather( - ShapeAdaptor operandShape, ShapeAdaptor startIndicesShape, - DenseIntElementsAttr sliceSizes, - GatherDimensionNumbersAttr dimensionNumbers, - llvm::function_ref errorEmitter) { - // P1. 
- // For some reason the getType call is necessary here - if (failed(verifyGather( - /*operandShape=*/operandShape, - /*startIndicesShape=*/startIndicesShape, - /*sliceSizesShape=*/sliceSizes.getType(), dimensionNumbers, - errorEmitter))) - return failure(); - - // P2. - for (auto dim : dimensionNumbers.getCollapsedSliceDims()) { - int64_t sliceDimSize = sliceSizes.getValues()[dim]; - if (sliceDimSize > 1) { - return errorEmitter() << "slice_sizes collapsed dimension " << dim - << " should <= 1 but got " << sliceDimSize; - } - } - - // P3. - if (operandShape.hasRank()) { - for (const auto& it : llvm::enumerate(sliceSizes.getValues())) { - if (operandShape.isDynamicDim(it.index())) continue; - auto operandDimSize = operandShape.getDimSize(it.index()); - auto sliceDimSize = it.value(); - if (sliceDimSize < 0 || sliceDimSize > operandDimSize) - return errorEmitter() << "slice size (" << sliceDimSize - << ") is out of bounds for operand dimension (" - << operandDimSize << ") at index " << it.index(); - } - } - return success(); -} - -template -static void inferGatherShape( - int64_t resultRank, llvm::function_ref getStartIndicesDim, - llvm::function_ref getSliceDim, - GatherDimensionNumbersAttr dimensionNumbers, - SmallVectorImpl& shape) { - ArrayRef collapsedSliceDims = - dimensionNumbers.getCollapsedSliceDims(); - int64_t indexVectorDim = dimensionNumbers.getIndexVectorDim(); - - // We don't necessarily know the rank of sliceSizes, but we do know that it - // can't be larger than the highest collapsed dimension. So go through those - // and populate the leading dimensions of adjustedSliceSizes. The trailing - // dimensions can just be adjusted by an offset. - const auto* maxCollapsedDimIt = - std::max_element(collapsedSliceDims.begin(), collapsedSliceDims.end()); - int64_t maxCollapsedDim = -1; - if (maxCollapsedDimIt != collapsedSliceDims.end()) - maxCollapsedDim = *maxCollapsedDimIt; - - SmallVector adjustedSliceSizePrefix; - for (int dimIndex = 0; dimIndex <= maxCollapsedDim; ++dimIndex) { - if (llvm::is_contained(collapsedSliceDims, dimIndex)) continue; - adjustedSliceSizePrefix.push_back(getSliceDim(dimIndex)); - } - auto getAdjustedSliceDim = [&](int64_t index) -> dimTy { - if (index < static_cast(adjustedSliceSizePrefix.size())) - return adjustedSliceSizePrefix[index]; - return getSliceDim(index + collapsedSliceDims.size()); - }; - - ArrayRef offsetDims = dimensionNumbers.getOffsetDims(); - - // Dimensions in the output that aren't offset dimensions are called batch - // dimensions. - SmallVector batchDims; - for (int dim = 0; dim < resultRank; ++dim) - if (!llvm::is_contained(offsetDims, dim)) batchDims.push_back(dim); - - for (int i = 0; i < resultRank; ++i) { - const auto* offsetDimsIt = - std::find(offsetDims.begin(), offsetDims.end(), i); - if (offsetDimsIt != offsetDims.end()) { - auto index = std::distance(offsetDims.begin(), offsetDimsIt); - shape.push_back(getAdjustedSliceDim(index)); - continue; - } - auto* batchDimsIt = std::find(batchDims.begin(), batchDims.end(), i); - assert(batchDimsIt != batchDims.end()); - auto index = std::distance(batchDims.begin(), batchDimsIt); - // This can never run into the special case where start_indices gets - // implicitly expanded with a trailing 1 if - // index_vector_dim = start_indices.rank because then index would equal - // index_vector_dim, which means we'd be looking at index+1, which would be - // out of bounds anyway. 
- if (index >= indexVectorDim) ++index; - shape.push_back(getStartIndicesDim(index)); - } -} - -// Verify the following properties: -// P1. Verify 0 <= offset_dims[i] < output_shape_rank, for every i. -// (output_shape_rank = size(offset_dims) + rank(start_indices) -1) -static LogicalResult inferGatherReturnTypeComponents( - ShapeAdaptor operandShape, ShapeAdaptor startIndicesShape, - llvm::function_ref getSliceDim, - GatherDimensionNumbersAttr dimensionNumbers, - SmallVectorImpl& inferredReturnShapes, - llvm::function_ref errorEmitter) { - Type elementType = operandShape.getElementType(); - - // We need this to determine the result rank. We could still place bounds on - // the result rank if that was something ShapedTypeComponents could express. - if (!startIndicesShape.hasRank()) { - inferredReturnShapes.push_back(elementType); - return success(); - } - - ArrayRef offsetDims = dimensionNumbers.getOffsetDims(); - int64_t startIndicesRank = startIndicesShape.getRank(); - // If index_vector_dim == start_indices.rank, then an implicit trailing 1 is - // appended to start_indices shape. - if (dimensionNumbers.getIndexVectorDim() == startIndicesRank) - ++startIndicesRank; - int64_t resultRank = offsetDims.size() + startIndicesRank - 1; - // P1. - for (auto i : llvm::seq(0, offsetDims.size())) - if (offsetDims[i] < 0 || offsetDims[i] >= resultRank) - return errorEmitter() << "offset_dims[" << i << "]: " << offsetDims[i] - << " is out of bounds for " - << "implied result rank " << resultRank; - - auto getStartIndicesDim = [&](int64_t index) { - return startIndicesShape.getDimSize(index); - }; - - SmallVector shape; - inferGatherShape(resultRank, getStartIndicesDim, getSliceDim, - dimensionNumbers, shape); - - inferredReturnShapes.emplace_back(shape, elementType); - return success(); -} - template LogicalResult reifyGatherShape(Op* op, OpBuilder& builder, ValueRange operands, SmallVectorImpl& reifiedReturnShapes) { @@ -1741,7 +1263,7 @@ LogicalResult reifyGatherShape(Op* op, OpBuilder& builder, ValueRange operands, Location loc = op->getLoc(); int resultRank = resultTy.getRank(); - Type shapeElTy = startIndices.getType().cast().getElementType(); + Type shapeElTy = builder.getIndexType(); auto toShapeElType = [&](Value v) { return maybeCastTo(builder, loc, v, shapeElTy); }; @@ -1759,8 +1281,12 @@ LogicalResult reifyGatherShape(Op* op, OpBuilder& builder, ValueRange operands, auto getSliceDim = [&sliceSizes](int64_t index) -> Value { return sliceSizes[index]; }; - inferGatherShape(resultRank, getStartIndicesDim, getSliceDim, - op->getDimensionNumbers(), shapeValues); + hlo::reifyGatherDimSizes(resultRank, getStartIndicesDim, getSliceDim, + op->getDimensionNumbers().getOffsetDims(), + op->getDimensionNumbers().getCollapsedSliceDims(), + op->getDimensionNumbers().getStartIndexMap(), + op->getDimensionNumbers().getIndexVectorDim(), + shapeValues); Value outputShape = builder.create( loc, RankedTensorType::get({resultRank}, shapeElTy), shapeValues); @@ -1777,46 +1303,18 @@ LogicalResult GatherOp::reifyReturnTypeShapes( return reifyGatherShape(this, builder, operands, reifiedReturnShapes); } -// The following properties are already enforced by the ODS: -// P0. Verify the start_indices has element type of integer. -// Verify the following properties: -// Verifications by verifyStaticGather() and verifyGather() inside it. -// Verifications by inferGatherReturnTypeComponents. 
LogicalResult GatherOp::inferReturnTypeComponents( MLIRContext* context, Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { - // TODO(zhouxin) remove this comment after the ordering issue is clear. - // This can get called before other op verify methods, so we have to do a - // bunch of verification up front. With a better story for ordering and/or - // multi-phase op verification, this should hopefully all go away. - Location loc = location.value_or(UnknownLoc::get(context)); - auto errorEmitter = [&loc]() { - return mlir::emitError(loc) - << "'" << GatherOp::getOperationName() << "' op "; - }; GatherOp::Adaptor adaptor(operands, attributes, regions); - if (failed(adaptor.verify(loc))) return failure(); - - // We want the ShapeAdaptors, so can't route via the adaptor :-/ - ShapeAdaptor operandShape = operands.getShape(0); - ShapeAdaptor startIndicesShape = operands.getShape(1); - GatherDimensionNumbersAttr dimensionNumbers = adaptor.getDimensionNumbers(); - DenseIntElementsAttr sliceSizesAttr = adaptor.getSliceSizes(); - - if (failed(verifyStaticGather(/*operandShape=*/operandShape, - /*startIndicesShape=*/startIndicesShape, - /*sliceSizes=*/sliceSizesAttr, dimensionNumbers, - errorEmitter))) - return failure(); - - auto getSliceDim = [&sliceSizesAttr](int64_t index) -> int64_t { - return sliceSizesAttr.getValues()[index]; - }; - - return inferGatherReturnTypeComponents(operandShape, startIndicesShape, - getSliceDim, dimensionNumbers, - inferredReturnShapes, errorEmitter); + return hlo::inferGatherOp( + location, adaptor.getOperand(), adaptor.getStartIndices(), + adaptor.getDimensionNumbers().getOffsetDims(), + adaptor.getDimensionNumbers().getCollapsedSliceDims(), + adaptor.getDimensionNumbers().getStartIndexMap(), + adaptor.getDimensionNumbers().getIndexVectorDim(), + adaptor.getSliceSizes(), inferredReturnShapes); } //===----------------------------------------------------------------------===// @@ -1826,18 +1324,34 @@ LogicalResult GatherOp::inferReturnTypeComponents( // Canonicalize mhlo.dynamic_gather to mhlo.gather when slice_sizes is constant. LogicalResult simplifyDynamicGatherToGather(DynamicGatherOp op, PatternRewriter& rewriter) { - DenseIntElementsAttr sliceSizes; - if (!matchPattern(op.getSliceSizes(), m_Constant(&sliceSizes))) { + DenseIntElementsAttr dynamicGatherSliceSizes; + if (!matchPattern(op.getSliceSizes(), m_Constant(&dynamicGatherSliceSizes))) { return failure(); } + + // DynamicGatherOp's slice_sizes is 1DTensorOf<[HLO_DimensionValue]> + // where HLO_DimensionValue is AnyTypeOf<[Index, HLO_Int]>. + // However, GatherOp's slice_sizes is I64ElementsAttr. + // Therefore, we need to convert the elements in case there is a mismatch + // of element types. 
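+  // Illustrative sketch (hypothetical IR, not taken from any test): a
+  // dynamic_gather whose slice sizes are i32, e.g.
+  //   %sizes = mhlo.constant dense<[1, 8]> : tensor<2xi32>
+  //   %0 = "mhlo.dynamic_gather"(%operand, %indices, %sizes) {...}
+  // carries slice sizes that must be sign-extended to i64 below before they
+  // can be rebuilt as the I64ElementsAttr slice_sizes of the new gather.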
+ DenseIntElementsAttr gatherSliceSizes = dynamicGatherSliceSizes; + if (!dynamicGatherSliceSizes.getType().getElementType().isInteger(64)) { + SmallVector sliceSizes; + for (APInt sliceSize : dynamicGatherSliceSizes.getValues()) { + sliceSizes.push_back(sliceSize.getSExtValue()); + } + gatherSliceSizes = rewriter.getI64TensorAttr(sliceSizes); + } + rewriter.replaceOpWithNewOp( op, op.getOperand(), op.getStartIndices(), op.getDimensionNumbersAttr(), - sliceSizes, op.getIndicesAreSortedAttr()); + gatherSliceSizes, op.getIndicesAreSortedAttr()); return success(); } void DynamicGatherOp::getCanonicalizationPatterns(RewritePatternSet& result, MLIRContext* context) { + // Disc disable // result.add(simplifyDynamicGatherToGather); } @@ -1851,43 +1365,29 @@ LogicalResult DynamicGatherOp::inferReturnTypeComponents( MLIRContext* context, Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { - // This can get called before other op verify methods, so we have to do a - // bunch of verification up front. With a better story for ordering and/or - // multi-phase op verification, this should hopefully all go away. - Location loc = location.value_or(UnknownLoc::get(context)); - auto errorEmitter = [&loc]() { - return mlir::emitError(loc) - << "'" << DynamicGatherOp::getOperationName() << "' op "; - }; DynamicGatherOp::Adaptor adaptor(operands, attributes, regions); - if (failed(adaptor.verify(loc))) return failure(); - - // We want the ShapeAdaptors, so can't route via the adaptor :-/ - ShapeAdaptor operandShape = operands.getShape(0); - ShapeAdaptor startIndicesShape = operands.getShape(1); - ShapeAdaptor sliceSizesShape = operands.getShape(2); - GatherDimensionNumbersAttr dimensionNumbers = adaptor.getDimensionNumbers(); - - if (failed(verifyGather(/*operandShape=*/operandShape, - /*startIndicesShape=*/startIndicesShape, - /*sliceSizesShape=*/sliceSizesShape, dimensionNumbers, - errorEmitter))) - return failure(); - - auto getSliceDim = [](int64_t index) { return ShapedType::kDynamicSize; }; - return inferGatherReturnTypeComponents(operandShape, startIndicesShape, - getSliceDim, dimensionNumbers, - inferredReturnShapes, errorEmitter); + return hlo::inferDynamicGatherOp( + location, adaptor.getOperand(), adaptor.getStartIndices(), + adaptor.getSliceSizes(), adaptor.getDimensionNumbers().getOffsetDims(), + adaptor.getDimensionNumbers().getCollapsedSliceDims(), + adaptor.getDimensionNumbers().getStartIndexMap(), + adaptor.getDimensionNumbers().getIndexVectorDim(), inferredReturnShapes); } //===----------------------------------------------------------------------===// // GetDimensionSizeOp //===----------------------------------------------------------------------===// -// + LogicalResult GetDimensionSizeOp::verify() { return verifyDimAttr(*this); } +LogicalResult GetDimensionSizeOp::inferReturnTypes( + MLIRContext* context, Optional location, ValueRange, + DictionaryAttr, RegionRange, SmallVectorImpl& inferredReturnTypes) { + return hlo::inferGetDimensionSizeOp(context, location, inferredReturnTypes); +} + /// Fold get_dimension_size when the said shape dimension is a constant. 
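+/// For example (illustrative): with an operand of type tensor<4x?xf32>,
+/// dimension = 0 folds to the constant dense<4> : tensor<i32>, while
+/// dimension = 1 is dynamic and does not fold.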
-OpFoldResult GetDimensionSizeOp::fold(ArrayRef attrs) { +OpFoldResult GetDimensionSizeOp::fold(FoldAdaptor) { RankedTensorType type = getOperand().getType().dyn_cast(); if (!type) return {}; @@ -1903,16 +1403,7 @@ OpFoldResult GetDimensionSizeOp::fold(ArrayRef attrs) { //===----------------------------------------------------------------------===// LogicalResult IotaOp::verify() { - auto shape = getType().cast(); - if (!shape.hasRank()) return success(); - - if (shape.getRank() == 0) return emitOpError() << "does not support scalars."; - - auto iotaDimension = static_cast(this->getIotaDimension()); - if (iotaDimension >= shape.getRank() || iotaDimension < 0) - return emitOpError() - << "iota dimension cannot go beyond the output rank or be negative."; - return success(); + return hlo::verifyIotaOp(getLoc(), getIotaDimension(), getResult()); } // Iota operations across multiple dimensions can be reduced to an iota and a @@ -1949,7 +1440,7 @@ void IotaOp::getCanonicalizationPatterns(RewritePatternSet& results, results.add(context); } -OpFoldResult IotaOp::fold(ArrayRef operands) { +OpFoldResult IotaOp::fold(FoldAdaptor /*adaptor*/) { auto dimension = getIotaDimension(); auto resultTy = getResult().getType().cast(); if (resultTy.hasRank() && resultTy.getDimSize(dimension) == 1) { @@ -2115,30 +1606,17 @@ LogicalResult DynamicIotaOp::reifyReturnTypeShapes( // DynamicUpdateSliceOp //===----------------------------------------------------------------------===// -LogicalResult DynamicUpdateSliceOp::verify() { - OperandRange indices = getStartIndices(); - if (indices.size() <= 1) return success(); - - // Note: start_indices is constrained to Variadic, so it - // is OK to cast indices to ShapedType here. - auto idxTensor = indices.take_front().front().getType().cast(); - Type firstElemTy = idxTensor.getElementType(); - Type elemTy; - - for (auto idx : llvm::drop_begin(indices, 1)) { - idxTensor = idx.getType().cast(); - elemTy = idxTensor.getElementType(); - - if (firstElemTy != elemTy) { - return emitOpError() << "start indices must have same element type " - "(encountered mismatch: " - << firstElemTy << " vs " << elemTy << ")"; - } - } - return success(); +LogicalResult DynamicUpdateSliceOp::inferReturnTypeComponents( + MLIRContext*, Optional location, ValueShapeRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl& inferredReturnShapes) { + DynamicUpdateSliceOp::Adaptor adaptor(operands, attributes, regions); + return hlo::inferDynamicUpdateSliceOp( + location, adaptor.getOperand(), adaptor.getUpdate(), + adaptor.getStartIndices(), inferredReturnShapes); } -OpFoldResult DynamicUpdateSliceOp::fold(ArrayRef operands) { +OpFoldResult DynamicUpdateSliceOp::fold(FoldAdaptor /*adaptor*/) { auto operandShape = this->getOperand().getType().cast(); auto updateShape = this->getUpdate().getType().cast(); @@ -2169,25 +1647,11 @@ OpFoldResult DynamicUpdateSliceOp::fold(ArrayRef operands) { //===----------------------------------------------------------------------===// LogicalResult AbsOp::inferReturnTypes( - MLIRContext*, Optional, ValueRange operands, DictionaryAttr, - RegionRange, SmallVectorImpl& inferredReturnTypes) { - auto operandTy = (*operands.begin()).getType().cast(); - Type elementTy = operandTy.getElementType(); - if (auto complexTy = elementTy.dyn_cast()) { - elementTy = complexTy.getElementType(); - } - - Type resultTy; - if (auto rankedOperandTy = operandTy.dyn_cast()) { - resultTy = RankedTensorType::get(operandTy.getShape(), elementTy, - 
rankedOperandTy.getEncoding()); - } else if (operandTy.hasRank()) { - resultTy = RankedTensorType::get(operandTy.getShape(), elementTy); - } else { - resultTy = UnrankedTensorType::get(elementTy); - } - inferredReturnTypes.push_back(resultTy); - return success(); + MLIRContext*, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl& inferredReturnTypes) { + AbsOp::Adaptor adaptor(operands, attributes, regions); + return hlo::inferAbsOp(location, adaptor.getOperand(), inferredReturnTypes); } //===----------------------------------------------------------------------===// @@ -2195,8 +1659,7 @@ LogicalResult AbsOp::inferReturnTypes( //===----------------------------------------------------------------------===// LogicalResult CollectivePermuteOp::verify() { - return mlir::hlo::verifyCollectivePermuteSourceTargetPairs( - *this, getSourceTargetPairs()); + return hlo::verifyCollectivePermuteOp(getLoc(), getSourceTargetPairs()); } //===----------------------------------------------------------------------===// @@ -2205,7 +1668,7 @@ LogicalResult CollectivePermuteOp::verify() { namespace { - +// DISC-Begin template LogicalResult ConvReifyReturnTypeImpl( Op* op, OpBuilder& builder, ValueRange operands, @@ -2272,7 +1735,7 @@ LogicalResult ConvReifyReturnTypeImpl( if (lhs_dilation_attr) { Value input_dilation = to_shape_scalar_type(builder.create( - loc, lhs_dilation_attr.getValue().getValues()[i])); + loc, lhs_dilation_attr.value().getValues()[i])); effective_input_value = builder.create( loc, builder.create( @@ -2298,7 +1761,7 @@ LogicalResult ConvReifyReturnTypeImpl( if (rhs_dilation_attr) { Value kernel_dilation = to_shape_scalar_type(builder.create( - loc, rhs_dilation_attr.getValue().getValues()[i])); + loc, rhs_dilation_attr.value().getValues()[i])); effective_kernel_size_value = builder.create( loc, one, builder.create( @@ -2314,7 +1777,7 @@ LogicalResult ConvReifyReturnTypeImpl( if (window_strides_attr) { Value stride_value = to_shape_scalar_type(builder.create( - loc, window_strides_attr.getValue().getValues()[i])); + loc, window_strides_attr.value().getValues()[i])); output_dim_value = builder.create(loc, output_dim_value, stride_value); } @@ -2327,179 +1790,7 @@ LogicalResult ConvReifyReturnTypeImpl( reifiedReturnShapes.push_back(output_shape); return success(); } - -// Checks: -// P1. Same sizes for input, kernel and output spatial_dims. -// P2. Spatial and non-spatial dimentions (for input,kernel, &output) should -// be unique and in range [0, num_dims), where num_dims = rank of input -// (lhs/rhs) tensors. -// -// Note that the spatial + non-spatial dimensions may not cover all the -// dimensions in the range [0,num) because of the presence of 'unknown' -// dimensions (ref. cl/415132294). -LogicalResult isSpatialDimensionsValid(ConvolutionOp op) { - auto inputSpatialDimensions = - op.getDimensionNumbers().getInputSpatialDimensions(); - auto kernelSpatialDimensions = - op.getDimensionNumbers().getKernelSpatialDimensions(); - auto outputSpatialDimensions = - op.getDimensionNumbers().getOutputSpatialDimensions(); - - // P1. - if ((inputSpatialDimensions.size() != kernelSpatialDimensions.size()) || - (inputSpatialDimensions.size() != outputSpatialDimensions.size())) - return op.emitOpError() << "expects the same size for input, kernel and " - "output spatial-dimensions, but got " - << inputSpatialDimensions.size() << ", " - << kernelSpatialDimensions.size() << ", and " - << outputSpatialDimensions.size() << " resp."; - - // P2. 
- SmallVector inputDnums(inputSpatialDimensions.size() + 2); - inputDnums[0] = op.getDimensionNumbers().getInputBatchDimension(); - inputDnums[1] = op.getDimensionNumbers().getInputFeatureDimension(); - std::copy(inputSpatialDimensions.begin(), inputSpatialDimensions.end(), - inputDnums.begin() + 2); - - SmallVector windowDnums(kernelSpatialDimensions.size() + 2); - windowDnums[0] = op.getDimensionNumbers().getKernelInputFeatureDimension(); - windowDnums[1] = op.getDimensionNumbers().getKernelOutputFeatureDimension(); - std::copy(kernelSpatialDimensions.begin(), kernelSpatialDimensions.end(), - windowDnums.begin() + 2); - - SmallVector outputDnums(outputSpatialDimensions.size() + 2); - outputDnums[0] = op.getDimensionNumbers().getOutputBatchDimension(); - outputDnums[1] = op.getDimensionNumbers().getOutputFeatureDimension(); - std::copy(outputSpatialDimensions.begin(), outputSpatialDimensions.end(), - outputDnums.begin() + 2); - - auto numDims = op.getLhs().getType().cast().getRank(); - const auto inRange = [numDims](int64_t i) { return 0 <= i && i < numDims; }; - - if (!llvm::all_of(inputDnums, inRange) || - !llvm::all_of(windowDnums, inRange) || - !llvm::all_of(outputDnums, inRange)) - return op.emitOpError() << "expects input, kernel, and output " - "dimension-numbers to be in-range [0, " - << numDims << ")."; - - if (hasDuplicates(inputDnums)) - return op.emitOpError() - << "expects input dimension-numbers to be unique, got {" - << inputDnums << "}."; - - if (hasDuplicates(windowDnums)) - return op.emitOpError() - << "expects kernel dimension-numbers to be unique, got {" - << windowDnums << "}."; - - if (hasDuplicates(outputDnums)) - return op.emitOpError() - << "expects output dimension-numbers to be unique, got {" - << outputDnums << "}."; - - return success(); -} - -// Verifies the following properties: -// P1. The input, kernel, and output spatial-dimentions are valid. -// P2. Given, -// input-dimensions: b * input-spatial-dims * f -// kernel-dimensions: kernel-spatial-dims * i * o -// output-dimensions: b' * out-spatial-dims * f' -// where b = input-batch-dims -// where f = input-feature-dims -// where i = kernel-input-feature-dims -// where o = kernel-output-feature-dims -// where b' = output-batch-dims -// where f' = output-feature-dims -// Check the following properties w.r.t feature_group_count (fgc) and -// batch_group_count (bgc). -// fgc > 0, bgc > 1 and !(fgc > 1 && bgc > 1) -// b % bgc == 0 -// f % fgc == 0 and i = f / fgc -// o (or f') % bgc == 0 and o (or f') % fgc == 0 -LogicalResult verifyConvolutionAttributes(ConvolutionOp op) { - // P1. - if (failed(isSpatialDimensionsValid(op))) return failure(); - - // P2. - const int64_t featureGroupCount = op.getFeatureGroupCount(); - const int64_t batchGroupCount = op.getBatchGroupCount(); - - if (featureGroupCount <= 0) - return op.emitOpError() - << "expects feature_group_count to be a positive number, got " - << featureGroupCount << "."; - - if (batchGroupCount <= 0) - return op.emitOpError() - << "expects batch_group_count to be a positive number, got " - << batchGroupCount << "."; - - if (batchGroupCount > 1 && featureGroupCount > 1) - return op.emitOpError() - << "expects batch_group_count and feature_group_count not to be " - "both greater than 1. 
Got " - << batchGroupCount << " and " << featureGroupCount << " resp."; - - auto lhsType = op.getLhs().getType().cast(); - const int64_t inputFeatures = - lhsType.getShape()[op.getDimensionNumbers().getInputFeatureDimension()]; - const int64_t inputBatch = - lhsType.getShape()[op.getDimensionNumbers().getInputBatchDimension()]; - - auto rhsType = op.getRhs().getType().cast(); - const int64_t kernelInputFeatures = - rhsType.getShape()[op.getDimensionNumbers() - .getKernelInputFeatureDimension()]; - const int64_t kernelOutputFeatures = - rhsType.getShape()[op.getDimensionNumbers() - .getKernelOutputFeatureDimension()]; - - if (!hlo::isDynamicDimSize(kernelOutputFeatures)) { - if (kernelOutputFeatures % batchGroupCount != 0) - return op.emitOpError() << "expects output feature dimension size (" - << kernelOutputFeatures - << ") to be a multiple of " - "batch_group_count. Got batch_group_count = " - << batchGroupCount << "."; - - if (kernelOutputFeatures % featureGroupCount != 0) - return op.emitOpError() - << "expects kernel output feature dimension (" - << kernelOutputFeatures - << ") to be divisible by " - "feature_group_count. For feature_group_count = " - << featureGroupCount << "."; - } - - if (!hlo::isDynamicDimSize(inputFeatures)) { - if (inputFeatures % featureGroupCount != 0) - return op.emitOpError() - << "expects input feature dimension (" << inputFeatures - << ") to be a multiple of " - "feature_group_count. Got feature_group_count = " - << featureGroupCount << "."; - - if (!hlo::isDynamicDimSize(kernelInputFeatures) && - inputFeatures / featureGroupCount != kernelInputFeatures) - return op.emitOpError() - << "expects input feature dimension (" << inputFeatures - << ") / " - "feature_group_count = kernel input feature dimension (" - << kernelInputFeatures - << "). Got feature_group_count = " << featureGroupCount << "."; - } - - if (!hlo::isDynamicDimSize(inputBatch) && inputBatch % batchGroupCount != 0) - return op.emitOpError() << "expects input batch dimension (" << inputBatch - << ") to be divisible by " - "batch_group_count. Got batch_group_count = " - << batchGroupCount << "."; - - return success(); -} +// DISC-End // Infer the return-shape of ConvolutionOp. // Precondition: @@ -2511,6 +1802,8 @@ SmallVector inferConvolutionOpReturnShape( // output-shape. To do that we initilize the output dimensions with the shape // of the return-type and updates only the spatial + non-spatial dimensions. // Precondition 2 ensures that size of output-shape == size of input-shape. + // NOTE: This is a divergence from StableHLO which doesn't allow us to fully + // share ConvolutionOp's verification / shape inference logic with StableHLO. SmallVector outputDimensions = to_vector(op.getResult().getType().cast().getShape()); @@ -2537,7 +1830,7 @@ SmallVector inferConvolutionOpReturnShape( .getKernelOutputFeatureDimension()]; outputDimensions[op.getDimensionNumbers().getOutputBatchDimension()] = - hlo::isDynamicDimSize(inputBatch) ? ShapedType::kDynamicSize + hlo::isDynamicDimSize(inputBatch) ? ShapedType::kDynamic : inputBatch / op.getBatchGroupCount(); outputDimensions[op.getDimensionNumbers().getOutputFeatureDimension()] = kernelOutputFeatures; @@ -2657,7 +1950,6 @@ void ConvolutionOp::getCanonicalizationPatterns(RewritePatternSet& results, * P2. Verify the convolution atributes. * P3. Verify and collect the window atributes. * P4. Verify the return shape. - * TODO(b/232574102): Verify the element-type of return-value. 
*/ LogicalResult ConvolutionOp::verify() { auto lhsType = getLhs().getType().dyn_cast(); @@ -2680,7 +1972,19 @@ LogicalResult ConvolutionOp::verify() { << lhsType << " and " << rhsType << "."; // P2. - if (failed(verifyConvolutionAttributes(*this))) return failure(); + if (failed(hlo::verifyConvolutionAttributes( + getLoc(), getLhs(), getRhs(), + getDimensionNumbers().getInputBatchDimension(), + getDimensionNumbers().getInputFeatureDimension(), + getDimensionNumbers().getInputSpatialDimensions(), + getDimensionNumbers().getKernelInputFeatureDimension(), + getDimensionNumbers().getKernelOutputFeatureDimension(), + getDimensionNumbers().getKernelSpatialDimensions(), + getDimensionNumbers().getOutputBatchDimension(), + getDimensionNumbers().getOutputFeatureDimension(), + getDimensionNumbers().getOutputSpatialDimensions(), + getFeatureGroupCount(), getBatchGroupCount(), getPrecisionConfig()))) + return failure(); // P3. auto kernelSpatialDimensions = @@ -2696,12 +2000,14 @@ LogicalResult ConvolutionOp::verify() { auto windowOrErr = hlo::verifyWindowAttributesAndInferWindowDimensions( windowDimensions, convertDenseIntAttr(getWindowStrides()), padding, convertDenseIntAttr(getLhsDilation()), - convertDenseIntAttr(getRhsDilation()), getLoc()); + convertDenseIntAttr(getRhsDilation()), + *hlo::convertWindowReversalAttribute(getWindowReversal(), getLoc(), + "window_reversal"), + getLoc()); if (failed(windowOrErr)) return failure(); // P4. auto actualReturnType = getResult().getType().cast(); - auto actualReturnElementType = actualReturnType.getElementType(); if (!actualReturnType.hasRank()) return success(); auto actualReturnRankedType = actualReturnType.cast(); @@ -2712,49 +2018,71 @@ LogicalResult ConvolutionOp::verify() { << actualReturnRankedType.getRank() << "."; auto expectedReturnShape = inferConvolutionOpReturnShape(*this, *windowOrErr); - auto expectedReturnType = - RankedTensorType::get(expectedReturnShape, actualReturnElementType); - if (failed(verifyCompatibleShape(expectedReturnType, actualReturnRankedType))) - return emitOpError() - << "has shape mismatch between the expected return-type (" - << expectedReturnType << ") and actual return-type (" - << actualReturnRankedType << ")."; + if (failed(verifyCompatibleShape(expectedReturnShape, + actualReturnRankedType.getShape()))) + return emitOpError() << "inferred shape '" + << hlo::dimSizesToString(expectedReturnShape) << "' " + << "is incompatible with return type of operation " + << actualReturnRankedType; return success(); } -LogicalResult ConvolutionOp::reifyReturnTypeShapes( - OpBuilder& builder, ValueRange operands, - SmallVectorImpl& reifiedReturnShapes) { - ConvolutionOp::Adaptor adaptor(operands); - Location loc = this->getLoc(); +//===----------------------------------------------------------------------===// +// DynamicConvOp +//===----------------------------------------------------------------------===// - DenseIntElementsAttr padding_attr = this->getPadding().getValue(); - Type shape_scalar_type = builder.getIndexType(); +namespace { - SmallVector spatial_padding_values; - if (padding_attr) { - for (int64_t pad : padding_attr.getValues()) { - Value pad_value = builder.create(loc, pad); - pad_value = maybeCastTo(builder, loc, pad_value, shape_scalar_type); - spatial_padding_values.push_back(pad_value); +struct DynamicConvIsConv : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(mhlo::DynamicConvOp op, + PatternRewriter& rewriter) const override { + DenseIntElementsAttr 
padAttr; + if (!matchPattern(op.getDPadding(), m_Constant(&padAttr))) { + return rewriter.notifyMatchFailure(op, "non-constant d_padding found"); } + + SmallVector padArray; + for (APInt pad : padAttr.getValues()) { + padArray.push_back(pad.getZExtValue()); + } + + int64_t paddedDimCount = padArray.size() / 2; + auto newPadAttr = DenseIntElementsAttr::get( + RankedTensorType::get({paddedDimCount, 2}, rewriter.getI64Type()), + padArray); + + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getLhs(), op.getRhs(), op.getWindowStridesAttr(), + newPadAttr, op.getLhsDilationAttr(), op.getRhsDilationAttr(), + op.getWindowReversalAttr(), op.getDimensionNumbers(), + op.getFeatureGroupCount(), op.getBatchGroupCount(), + op.getPrecisionConfigAttr()); + return success(); } +}; + +} // namespace - return ConvReifyReturnTypeImpl(this, builder, operands, reifiedReturnShapes, - spatial_padding_values, shape_scalar_type); +void DynamicConvOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { + // DISC disable + // results.add(context); } +// DISC-Begin LogicalResult DynamicConvOp::reifyReturnTypeShapes( - OpBuilder& builder, ValueRange operands, - SmallVectorImpl& reifiedReturnShapes) { + OpBuilder &builder, ValueRange operands, + SmallVectorImpl &reifiedReturnShapes) { DynamicConvOp::Adaptor adaptor(operands); Value d_padding = adaptor.getDPadding(); RankedTensorType padding_type = d_padding.getType().dyn_cast(); // Not support unranked type a.t.m. - if (!padding_type) return failure(); + if (!padding_type) + return failure(); Location loc = this->getLoc(); Type shape_scalar_type = padding_type.getElementType(); @@ -2774,9 +2102,11 @@ LogicalResult DynamicConvOp::reifyReturnTypeShapes( spatial_padding_values.push_back(pad_value); } - return ConvReifyReturnTypeImpl(this, builder, operands, reifiedReturnShapes, - spatial_padding_values, shape_scalar_type); + return ConvReifyReturnTypeImpl( + this, builder, operands, reifiedReturnShapes, spatial_padding_values, + shape_scalar_type); } +// DISC-End //===----------------------------------------------------------------------===// // ConvertOp @@ -2794,7 +2124,8 @@ void ConvertOp::build(OpBuilder& builder, OperationState& result, Value operand, build(builder, result, resultTy, operand); } -OpFoldResult ConvertOp::fold(ArrayRef operands) { +OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); auto operandTy = getOperand().getType().cast(); auto resultTy = getResult().getType().cast(); if (operandTy == resultTy) return getOperand(); @@ -2883,56 +2214,6 @@ LogicalResult StochasticConvertOp::verify() { return success(); } -//===----------------------------------------------------------------------===// -// GetTupleElementOp -//===----------------------------------------------------------------------===// - -LogicalResult GetTupleElementOp::verify() { - auto indexVal = getIndex(); - auto operandType = getOperand().getType().cast(); - if (indexVal >= operandType.size()) { - return emitOpError( - llvm::formatv("index {0} is out of bounds of operand with size {1}", - indexVal, operandType.size())); - } - - auto expectedType = operandType.getType(indexVal); - if (getType() != expectedType) { - return emitOpError(llvm::formatv("has return type {0}, but expected {1}", - getType(), expectedType)); - } - return success(); -} - -OpFoldResult GetTupleElementOp::fold(ArrayRef operands) { - if (auto tupleOp = getOperand().getDefiningOp()) { - return tupleOp.getOperand(getIndex()); - } - - return {}; -} - 
-//===----------------------------------------------------------------------===// -// TupleOp -//===----------------------------------------------------------------------===// - -LogicalResult TupleOp::verify() { - auto opType = getType().dyn_cast(); - if (!opType) return emitOpError("tuple op with non-tuple result"); - if (getNumOperands() != opType.size()) - return emitOpError( - "number of operands to tuple expected to match number of types in " - "resultant tuple type"); - for (const auto& it : - llvm::enumerate(llvm::zip_first(getOperandTypes(), opType.getTypes()))) { - if (std::get<0>(it.value()) != std::get<1>(it.value())) - return emitOpError("has return type mismatch at ") - << it.index() << "th value (" << std::get<0>(it.value()) - << " != " << std::get<1>(it.value()) << ")"; - } - return success(); -} - namespace { // Pattern for unpacking and repacking the same tuple. @@ -2982,44 +2263,62 @@ LogicalResult AllToAllOp::inferReturnTypeComponents( DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { AllToAllOp::Adaptor adaptor(operands, attributes, regions); - Type operandType = adaptor.getOperand().getType(); - RankedTensorType operandRankedType = operandType.dyn_cast(); - if (!operandRankedType) { - inferredReturnShapes.emplace_back( - operandType.cast().getElementType()); + + bool isArrayAllToAll = adaptor.getSplitDimension() && + adaptor.getConcatDimension() && + adaptor.getSplitCount(); + if (!isArrayAllToAll) { + if (adaptor.getSplitDimension() || adaptor.getConcatDimension() || + adaptor.getSplitCount()) { + return emitOptionalError(location, + "TupleAllToAll should not have split_dimension, " + "concat_dimension or split_count attributes"); + } + + // TupleAllToAll has identical result and operand shapes. + for (size_t i = 0; i < operands.size(); ++i) { + auto rankedOperand = operands[i].getType().dyn_cast(); + if (rankedOperand) + inferredReturnShapes.emplace_back(rankedOperand.getShape(), + rankedOperand.getElementType(), + rankedOperand.getEncoding()); + else + inferredReturnShapes.emplace_back( + operands[i].getType().cast()); + } + return success(); } - int64_t inputRank = operandRankedType.getRank(); - int64_t splitDimension = static_cast(adaptor.getSplitDimension()); - int64_t concatDimension = static_cast(adaptor.getConcatDimension()); - if (splitDimension >= inputRank || splitDimension < 0) { - return emitOptionalError(location, "AllToAll split_dimension ", - splitDimension, - " is out-of-bounds for input rank ", inputRank); - } - if (concatDimension >= inputRank || concatDimension < 0) { - return emitOptionalError(location, "AllToAll concat_dimension ", - concatDimension, - " is out-of-bounds for input rank ", inputRank); - } - - // If operand is ranked, size of split dimension should be a multiple of split - // count. 
- int64_t splitCount = adaptor.getSplitCount(); - auto splitDimSize = operandRankedType.getDimSize(splitDimension); - if (splitDimSize % splitCount != 0) { - return emitOptionalError( - location, "split dimension has size ", splitDimSize, - ", expected to be a multiple of split_count ", splitCount); - } - SmallVector resultShape(operandRankedType.getShape().begin(), - operandRankedType.getShape().end()); - resultShape[splitDimension] /= splitCount; - resultShape[concatDimension] *= splitCount; - inferredReturnShapes.emplace_back(resultShape, - operandRankedType.getElementType()); - return success(); + if (adaptor.getOperand().size() != 1) { + return emitOptionalError(location, + "ArrayAllToAll should have exactly one operand"); + } + + return hlo::inferAllToAllOp( + location, adaptor.getOperand()[0], *adaptor.getSplitDimension(), + *adaptor.getConcatDimension(), *adaptor.getSplitCount(), + adaptor.getReplicaGroups(), inferredReturnShapes); +} + +void AllToAllOp::build(OpBuilder& odsBuilder, OperationState& odsState, + Type resultType, Value operand, + IntegerAttr splitDimension, IntegerAttr concatDimension, + IntegerAttr splitCount, + DenseIntElementsAttr replicaGroups) { + AllToAllOp::build(odsBuilder, odsState, resultType, operand, splitDimension, + concatDimension, splitCount, replicaGroups, + /*channel_handle=*/nullptr); +} + +void AllToAllOp::build(OpBuilder& odsBuilder, OperationState& odsState, + ::mlir::TypeRange resultType, ::mlir::ValueRange operand, + IntegerAttr splitDimension, IntegerAttr concatDimension, + IntegerAttr splitCount, + DenseIntElementsAttr replicaGroups) { + AllToAllOp::build(odsBuilder, odsState, resultType, operand, splitDimension, + concatDimension, splitCount, replicaGroups, + /*channel_handle=*/nullptr); } //===----------------------------------------------------------------------===// @@ -3027,27 +2326,9 @@ LogicalResult AllToAllOp::inferReturnTypeComponents( //===----------------------------------------------------------------------===// LogicalResult AllGatherOp::verify() { - // If operand and result are both ranked, then the size of the gather - // dimension in the result should be a multiple of the size of the gather - // dimension in the operand. 
- auto operandType = getOperand().getType().dyn_cast(); - auto resultType = getType().dyn_cast(); - uint64_t allGatherDimIndex = getAllGatherDim(); - if (!operandType || !resultType || - operandType.isDynamicDim(allGatherDimIndex) || - resultType.isDynamicDim(allGatherDimIndex)) - return success(); - if (operandType.getDimSize(allGatherDimIndex) == 0) - return emitOpError() << "operand gather dimension cannot be zero."; - if ((resultType.getDimSize(allGatherDimIndex) % - operandType.getDimSize(allGatherDimIndex)) != 0) - return emitOpError() - << "result gather dimension has size " - << resultType.getDimSize(allGatherDimIndex) - << ", expected to be a multiple of operand gather dimension size " - << operandType.getDimSize(allGatherDimIndex); - - return success(); + return hlo::verifyAllGatherOp(getLoc(), getOperand(), getAllGatherDim(), + getReplicaGroups(), getUseGlobalDeviceIds(), + getResult()); } void AllGatherOp::build(OpBuilder& odsBuilder, OperationState& odsState, @@ -3061,49 +2342,17 @@ void AllGatherOp::build(OpBuilder& odsBuilder, OperationState& odsState, } //===----------------------------------------------------------------------===// -// BatchNormGradOp +// AllReduceOp //===----------------------------------------------------------------------===// -LogicalResult verifyBatchNorm(Location loc, Value operand, - int64_t feature_index, Value scale) { - auto operandType = operand.getType().cast(); - if (feature_index >= operandType.getRank()) - return emitError(loc) << "expects feature_index to be smaller " - "than the rank of operand type; got feature_index " - << feature_index << ", and rank " - << operandType.getRank() << "."; - - if (feature_index < 0) - return emitError(loc) << "expects feature_index to be a " - << "non-negative number, got " << feature_index - << "."; - - // Note: the above checks '0 <= feature-index < operandType.getRank()' - // imply 'operand_type.getRank() >= 1'. - - const int64_t featureCount = operandType.getDimSize(feature_index); - const int64_t scaleShape = - scale.getType().cast().getDimSize(0); - // As ODS enforces `scale`, `mean`, `variance`, `offset` are AllShapesMatch, - // this also infers that featureCount is aligned with them. - if (scaleShape != featureCount) - return emitError(loc) << "expects the size of scale factor to be " - "same as the feature count," - " but the size of scale factor is " - << scaleShape << " and the feature count is " - << featureCount << "."; - - return success(); +LogicalResult AllReduceOp::verify() { + return hlo::verifyAllReduceOp(getLoc(), getOperand(), getReplicaGroups(), + getUseGlobalDeviceIds(), getComputation()); } -// Refer ODS for properties that are already enforced including shapes and -// element types. This verifier includes additional checks. 
-LogicalResult BatchNormGradOp::verify() { - if (failed(verifyBatchNorm(getLoc(), getOperand(), getFeatureIndex(), - getScale()))) - return failure(); - return success(); -} +//===----------------------------------------------------------------------===// +// BatchNormGradOp +//===----------------------------------------------------------------------===// LogicalResult BatchNormGradOp::inferReturnTypeComponents( MLIRContext* context, Optional location, ValueShapeRange operands, @@ -3111,30 +2360,22 @@ LogicalResult BatchNormGradOp::inferReturnTypeComponents( SmallVectorImpl& inferredReturnShapes) { BatchNormGradOp::Adaptor adaptor(operands, attributes, regions); return hlo::inferBatchNormGradOp( - location, adaptor.getOperand(), adaptor.getScale(), - adaptor.getFeatureIndex(), inferredReturnShapes); + location, adaptor.getOperand(), adaptor.getScale(), adaptor.getMean(), + adaptor.getVariance(), adaptor.getGradOutput(), adaptor.getFeatureIndex(), + inferredReturnShapes); } //===----------------------------------------------------------------------===// // BatchNormTrainingOp //===----------------------------------------------------------------------===// -// Refer ODS for properties that are already enforced including shapes and -// element types. This verifier includes additional checks. -LogicalResult BatchNormTrainingOp::verify() { - if (failed(verifyBatchNorm(getLoc(), getOperand(), getFeatureIndex(), - getScale()))) - return failure(); - return success(); -} - LogicalResult BatchNormTrainingOp::inferReturnTypeComponents( MLIRContext* context, Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { BatchNormTrainingOp::Adaptor adaptor(operands, attributes, regions); return hlo::inferBatchNormTrainingOp( - location, adaptor.getOperand(), adaptor.getScale(), + location, adaptor.getOperand(), adaptor.getScale(), adaptor.getOffset(), adaptor.getFeatureIndex(), inferredReturnShapes); } @@ -3142,23 +2383,36 @@ LogicalResult BatchNormTrainingOp::inferReturnTypeComponents( // BatchNormInferenceOp //===----------------------------------------------------------------------===// -// Refer ODS for properties that are already enforced including shapes and -// element types. This verifier includes additional checks. 
-LogicalResult BatchNormInferenceOp::verify() { - if (failed(verifyBatchNorm(getLoc(), getOperand(), getFeatureIndex(), - getScale()))) - return failure(); - return success(); -} - LogicalResult BatchNormInferenceOp::inferReturnTypeComponents( MLIRContext* context, Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { BatchNormInferenceOp::Adaptor adaptor(operands, attributes, regions); return hlo::inferBatchNormInferenceOp( - location, adaptor.getOperand(), adaptor.getScale(), - adaptor.getFeatureIndex(), inferredReturnShapes); + location, adaptor.getOperand(), adaptor.getScale(), adaptor.getOffset(), + adaptor.getMean(), adaptor.getVariance(), adaptor.getFeatureIndex(), + inferredReturnShapes); +} + +//===----------------------------------------------------------------------===// +// BitcastOp +//===----------------------------------------------------------------------===// + +OpFoldResult BitcastOp::fold(FoldAdaptor) { + if (getResult().getType() != getOperand().getType()) { + return {}; + } + + auto sourceLayout = + getOperation()->getAttrOfType("source_layout"); + auto resultLayout = + getOperation()->getAttrOfType("result_layout"); + + if (sourceLayout == resultLayout) { + return getOperand(); + } + + return {}; } //===----------------------------------------------------------------------===// @@ -3187,109 +2441,16 @@ LogicalResult BitcastConvertOp::reifyReturnTypeShapes( &reifiedReturnShapes); } -/* - * We intend to verify the following properties - * P1. We cannot convert between complex and real types (cf xla) - * P3. The dimensions of the operand and the target - * shape must match, except that the shape with the smaller element bitwidth has - * an appropriately-sized additional innermost dimension, e.g. - * ... x f32 => [bitcast_convert] => ... x 4 x i8 - * ... x 4 x i8 => [bitcast_convert] => ... x f32 - */ LogicalResult BitcastConvertOp::verify() { - auto operandTensorType = getOperand().getType().cast(); - auto targetTensorType = getResult().getType().cast(); - - // P1. - auto targetElt = targetTensorType.getElementType(); - auto operandElt = operandTensorType.getElementType(); - if (targetElt.isa() != operandElt.isa()) { - return emitOpError() - << "cannot convert between real and complex types, but got: " - << operandTensorType << " and " << targetTensorType; - } - - auto targetEltBitwidth = hlo::potentiallyComplexBitwidth(targetElt); - auto operandEltBitwidth = hlo::potentiallyComplexBitwidth(operandElt); - - // P2. 
- auto operandType = operandTensorType.dyn_cast(); - auto targetType = targetTensorType.dyn_cast(); - if (!operandType || !targetType) return success(); - - auto targetShape = targetType.getShape(); - auto operandShape = operandType.getShape(); - ArrayRef smallerEltShape, biggerEltShape; - Type smallerElt, biggerElt; - if (operandEltBitwidth < targetEltBitwidth) { - smallerEltShape = operandShape; - smallerElt = operandElt; - biggerEltShape = targetShape; - biggerElt = targetElt; - } else { - smallerEltShape = targetShape; - smallerElt = targetElt; - biggerEltShape = operandShape; - biggerElt = operandElt; - } - - ArrayRef smallerEltPrefix; - auto smallerEltBitwidth = std::min(targetEltBitwidth, operandEltBitwidth); - auto biggerEltBitwidth = std::max(targetEltBitwidth, operandEltBitwidth); - if (operandEltBitwidth != targetEltBitwidth) { - if (smallerEltShape.empty()) { - return emitOpError() << "does not allow the smaller element type to be " - "part of a 0d tensor, but got: " - << operandType << " and " << targetType << "."; - } - smallerEltPrefix = smallerEltShape.drop_back(); - if (!hlo::isDynamicDimSize(smallerEltShape.back()) && - smallerEltShape.back() * smallerEltBitwidth != biggerEltBitwidth) { - return emitOpError() << "requires compatible bitwidths. " - << "Got: " << operandType << " and " << targetType - << ", but " << smallerEltBitwidth << " * " - << smallerEltShape.back() - << " != " << biggerEltBitwidth << "."; - } - } else { - smallerEltPrefix = smallerEltShape; - } - - for (auto it : llvm::zip(smallerEltPrefix, biggerEltShape)) { - auto targetDim = std::get<0>(it); - auto operandDim = std::get<1>(it); - if (!hlo::isDynamicDimSize(targetDim) && - !hlo::isDynamicDimSize(operandDim)) { - if (targetDim != operandDim) { - return emitOpError() << "operand and result shapes must match except " - "for the innermost dimension of the shape with " - "the smaller element type. Got: " - << operandType << " and " << targetType << "."; - } - } - } - - return success(); + return hlo::verifyBitcastConvertOp(getLoc(), getOperand(), getResult()); } //===----------------------------------------------------------------------===// // BroadcastOp //===----------------------------------------------------------------------===// -// TODO(b/129012527) These should be expressed as type constraints. 
-LogicalResult BroadcastOp::verify() { - auto sizes = getBroadcastSizes(); - auto sizesType = sizes.getType(); - auto sizesRank = sizesType.getRank(); - if (sizesRank != 1) { - return emitOpError(llvm::formatv( - "broadcast_sizes has rank {0} instead of rank 1", sizesRank)); - } - - return success(); -} - -OpFoldResult BroadcastOp::fold(ArrayRef attrs) { +OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) { + auto attrs = adaptor.getOperands(); auto type = getType().cast(); auto sizesType = getBroadcastSizes().getType(); if (sizesType.getNumElements() == 0) { @@ -3324,22 +2485,9 @@ LogicalResult BroadcastOp::inferReturnTypeComponents( DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { BroadcastOp::Adaptor adaptor(operands, attributes, regions); - Value operand = adaptor.getOperand(); - auto operandType = operand.getType().dyn_cast(); - if (!operandType) return failure(); - - Type elementTy = operandType.getElementType(); - auto dimensionAttr = adaptor.getBroadcastSizes(); - for (int64_t size : dimensionAttr.getValues()) { - if (size < 0) - return emitOptionalError(location, - "Broadcast with negative dimension size ", size); - } - SmallVector shapeValues(dimensionAttr.getValues()); - llvm::append_range(shapeValues, operandType.getShape()); - - inferredReturnShapes.emplace_back(shapeValues, elementTy); - return success(); + return hlo::inferBroadcastOp(location, adaptor.getOperand(), + adaptor.getBroadcastSizes(), + inferredReturnShapes); } LogicalResult BroadcastOp::reifyReturnTypeShapes( @@ -3381,70 +2529,12 @@ LogicalResult BroadcastOp::reifyReturnTypeShapes( //===----------------------------------------------------------------------===// LogicalResult BroadcastInDimOp::verify() { - auto operandType = getOperand().getType().dyn_cast(); - if (!operandType) { - // The following verification checks all depend on knowing the rank of - // the operand. Bail out now if we don't know the rank of the operand. 
- return success(); - } - - auto operandRank = operandType.getRank(); - if (!getBroadcastDimensions()) { - if (operandRank == 0) { - return success(); - } - return emitOpError( - llvm::formatv("broadcast_dimensions is absent, but required because " - "operand has non-zero rank ({0})", - operandRank)); - } - - auto dimensionsType = getBroadcastDimensions().getType(); - auto dimensionsRank = dimensionsType.getRank(); - if (dimensionsRank != 1) { - return emitOpError(llvm::formatv( - "broadcast_dimensions has rank {0} instead of rank 1", dimensionsRank)); - } - - auto dimensionsSize = dimensionsType.getNumElements(); - if (dimensionsSize != operandRank) { - return emitOpError(llvm::formatv( - "broadcast_dimensions size ({0}) does not match operand rank ({1})", - dimensionsSize, operandRank)); - } - - auto dimensions = - llvm::to_vector(getBroadcastDimensions().getValues()); - if (hasDuplicates(dimensions)) - return emitOpError("broadcast_dimensions should not have duplicates"); - - auto resultType = getResult().getType().cast(); - auto resultRank = resultType.getRank(); - for (int i = 0; i != dimensionsSize; ++i) { - auto dimIndex = dimensions[i]; - if (dimIndex >= resultRank) { - return emitOpError( - llvm::formatv("broadcast_dimensions contains invalid value {0} for " - "result with rank {1}", - dimIndex, resultRank)); - } - - if (!operandType.isDynamicDim(i)) { - auto dimSize = operandType.getDimSize(i); - auto resultDimSize = resultType.getDimSize(dimIndex); - if (dimSize != 1 && dimSize != resultDimSize) { - return emitOpError( - llvm::formatv("size of operand dimension {0} ({1}) is not equal to " - "1 or size of result dimension {2} ({3})", - i, dimSize, dimIndex, resultDimSize)); - } - } - } - - return success(); + return hlo::verifyBroadcastInDimOp(getLoc(), getOperand(), + getBroadcastDimensions(), getResult()); } -OpFoldResult BroadcastInDimOp::fold(ArrayRef attrs) { +OpFoldResult BroadcastInDimOp::fold(FoldAdaptor adaptor) { + auto attrs = adaptor.getOperands(); auto type = getType().cast(); if (type == getOperand().getType()) { auto broadcastValues = getBroadcastDimensions().getValues(); @@ -3537,101 +2627,10 @@ void BroadcastInDimOp::getCanonicalizationPatterns(RewritePatternSet& results, //===----------------------------------------------------------------------===// LogicalResult DynamicBroadcastInDimOp::verify() { - auto operandType = getOperand().getType().dyn_cast(); - auto resultType = getResult().getType().dyn_cast(); - - // If either the operand or result are unranked, there is very little - // to verify statically. - if (!operandType || !resultType) { - return success(); - } - - auto outputDimensionsType = - getOutputDimensions().getType().cast(); - auto outputDimensionsSize = outputDimensionsType.getDimSize(0); - auto operandRank = operandType.getRank(); - auto resultRank = resultType.getRank(); - - // Verify broadcast_dimensions. - auto bcastDimensions = getBroadcastDimensions(); - auto bcastDimensionsType = getBroadcastDimensions().getType(); - auto bcastDimensionsRank = bcastDimensionsType.getRank(); - // TODO(laurenzo): Update the BroadcastDimAttr to constrain its rank to 1. 
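// For reference (a sketch, assuming standard broadcast_in_dim semantics),
// broadcast_dimensions maps operand dimension i to result dimension
// broadcast_dimensions[i], and each mapped operand size must be 1 or equal to
// the corresponding result size:
//
//   %0 = "mhlo.broadcast_in_dim"(%arg)
//        {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
//        : (tensor<1x4xf32>) -> tensor<5x3x4xf32>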
- if (bcastDimensionsRank != 1) { - return emitOpError( - llvm::formatv("broadcast_dimensions has rank {0} instead of rank 1", - bcastDimensionsRank)); - } - - auto bcastDimensionsSize = bcastDimensionsType.getNumElements(); - if (bcastDimensionsSize != operandRank) { - return emitOpError(llvm::formatv( - "broadcast_dimensions size ({0}) does not match operand rank ({1})", - bcastDimensionsSize, operandRank)); - } - - if (resultRank < operandRank) { - return emitOpError( - llvm::formatv("result rank ({0}) is less than operand rank ({1})", - resultRank, operandRank)); - } - - for (int i = 0; i != bcastDimensionsSize; ++i) { - auto dimIndex = bcastDimensions.getValues()[i]; - if (dimIndex >= resultRank) { - return emitOpError( - llvm::formatv("broadcast_dimensions contains invalid value {0} for " - "result with rank {1}", - dimIndex, resultRank)); - } - - auto dimSize = operandType.getDimSize(i); - auto resultDimSize = resultType.getDimSize(dimIndex); - // Note: verifyCompatibleShapes doesn't consider size-1 broadcasting, so we - // add a manual check for this. - if (dimSize != 1 && failed(verifyCompatibleShape(dimSize, resultDimSize))) { - return emitOpError( - llvm::formatv("size of operand dimension {0} ({1}) is not compatible " - "with size of result dimension {2} ({3})", - i, dimSize, dimIndex, resultDimSize)); - } - } - - if (outputDimensionsSize != resultRank) { - return emitOpError( - llvm::formatv("result rank ({0}) is not equal to number of output " - "dimensions ({1})", - resultRank, outputDimensionsSize)); - } - - // Verify that the known expanding and non-expanding dimensions are a subset - // of the operand's dimensions. - int64_t numKnownExpansionBehavior = 0; - DenseSet knownExpansionBehavior; - auto collectExpansionBehaviorDims = - [&](const Optional& attr) { - if (!attr) return; - for (const APInt& it : *attr) { - numKnownExpansionBehavior++; - knownExpansionBehavior.insert(it.getLimitedValue()); - } - }; - collectExpansionBehaviorDims(getKnownExpandingDimensions()); - collectExpansionBehaviorDims(getKnownNonexpandingDimensions()); - if (knownExpansionBehavior.size() != numKnownExpansionBehavior) { - return emitOpError( - "duplicate expansion hint for at least one operand dimension"); - } - for (int64_t i : knownExpansionBehavior) { - if (i < 0 || i >= operandRank) { - return emitOpError( - llvm::formatv("hint for expanding dimension {0} does not refer to a " - "valid operand dimension", - i)); - } - } - - return success(); + return hlo::verifyDynamicBroadcastInDimOp( + getLoc(), getOperand(), getOutputDimensions(), getBroadcastDimensions(), + getKnownExpandingDimensions(), getKnownNonexpandingDimensions(), + getResult()); } namespace { @@ -3702,12 +2701,37 @@ class ChainedDynamicBroadcastInDimCanonicalization return success(); } }; + +// If all dimensions are known to be nonexpanding from the attribute, replace +// the dynamic broadcast with a cast. 
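// For example (a sketch; the concrete types are made up for illustration),
// when known_nonexpanding_dimensions covers every result dimension the
// broadcast cannot change any extent, so
//
//   %0 = "mhlo.dynamic_broadcast_in_dim"(%arg, %shape)
//        {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>,
//         known_nonexpanding_dimensions = dense<[0, 1]> : tensor<2xi64>}
//        : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<2x?xf32>
//
// can be rewritten into
//
//   %0 = tensor.cast %arg : tensor<?x?xf32> to tensor<2x?xf32>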
+class DynamicBroadcastInDimAllDimsNonExpanding + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(DynamicBroadcastInDimOp op, + PatternRewriter& rewriter) const override { + auto resultType = op.getResult().getType().dyn_cast(); + if (!resultType) + return rewriter.notifyMatchFailure(op, "requires ranked result type"); + + if (!op.getKnownNonexpandingDimensions().has_value() || + op.getKnownNonexpandingDimensions()->size() != resultType.getRank()) { + return rewriter.notifyMatchFailure( + op, "known_nonexpanding_dimensions don't cover all output dims"); + } + + auto cast = rewriter.createOrFold(op.getLoc(), resultType, + op.getOperand()); + rewriter.replaceOp(op, cast); + return success(); + } +}; } // namespace void DynamicBroadcastInDimOp::getCanonicalizationPatterns( RewritePatternSet& results, MLIRContext* context) { results.add( context); @@ -3755,104 +2779,19 @@ LogicalResult ClampOp::verify() { return success(); } -LogicalResult ClampOp::inferReturnTypeComponents( - MLIRContext*, Optional /*location*/, ValueShapeRange operands, - DictionaryAttr attributes, RegionRange regions, - SmallVectorImpl& inferredReturnShapes) { - ClampOp::Adaptor adaptor(operands, attributes, regions); - RankedTensorType operandType = - adaptor.getOperand().getType().cast(); - inferredReturnShapes.emplace_back(operandType.getShape(), - operandType.getElementType()); - return success(); -} - -LogicalResult ClampOp::reifyReturnTypeShapes( - OpBuilder& builder, ValueRange operands, - SmallVectorImpl& reifiedReturnShapes) { - // For `mhlo.clamp`, the first operand may be a scalar. - return hlo::deriveShapeFromOperand(&builder, getOperation(), operands[1], - &reifiedReturnShapes); -} - -OpFoldResult ClampOp::fold(ArrayRef operands) { - - auto val = operands[0].dyn_cast_or_null(); - if (!val) return {}; - - auto type = getElementTypeOrSelf(getType()); - if (!type.isF32() && !type.isF64()) return {}; - - auto shapedType = getType().cast(); - if (!shapedType.hasStaticShape()) return {}; - - DenseElementsAttr min_val = operands[1].dyn_cast_or_null(); - DenseElementsAttr max_val = operands[2].dyn_cast_or_null(); - if (!min_val || !max_val) return {}; - - // min/max/val should be the same shape. - // Or min/max must be a scalar of the same type of val - int64_t val_num = val.getNumElements(); - int64_t min_num = min_val.getNumElements(); - int64_t max_num = max_val.getNumElements(); - if (!((val_num == min_num && val_num == max_num) || - (val_num == max_num && min_num == 1) || - (val_num == min_num && max_num == 1) || (min_num == max_num == 1))) - return {}; - - int bitWidth = type.getIntOrFloatBitWidth(); - auto convertValue = [](APFloat value, int bitWidth) { - double converted_value = bitWidth == 32 ? 
value.convertToFloat() - : value.convertToDouble(); - return converted_value; - }; - double first_min_value = convertValue(*min_val.getValues().begin(), bitWidth); - double first_max_value = convertValue(*max_val.getValues().begin(), bitWidth); - auto val_start = val.getValues().begin(); - auto min_start = min_val.getValues().begin(); - auto max_start = max_val.getValues().begin(); - llvm::SmallVector values; - values.reserve(val.getNumElements()); - for (int64_t i=0; i(cur_value)); - else - values.emplace_back(cur_value); - } - - return DenseFPElementsAttr::get(shapedType, values); - -} - //===----------------------------------------------------------------------===// // ComplexOp //===----------------------------------------------------------------------===// LogicalResult ComplexOp::inferReturnTypes( - MLIRContext*, Optional, ValueRange operands, DictionaryAttr, - RegionRange, SmallVectorImpl& inferredReturnTypes) { - TensorType operandType = operands[0].getType().cast(); - ComplexType elementTy = ComplexType::get(operandType.getElementType()); - inferredReturnTypes.push_back( - hlo::getSameShapeTensorType(operandType, elementTy)); - return success(); + MLIRContext*, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl& inferredReturnTypes) { + ComplexOp::Adaptor adaptor(operands, attributes, regions); + return hlo::inferComplexOp(location, adaptor.getLhs(), inferredReturnTypes); } -OpFoldResult ComplexOp::fold(ArrayRef operands) { +OpFoldResult ComplexOp::fold(FoldAdaptor) { auto realOp = getOperand(0).getDefiningOp(); auto imagOp = getOperand(1).getDefiningOp(); if (realOp && imagOp && realOp.getOperand() == imagOp.getOperand()) { @@ -3866,25 +2805,15 @@ OpFoldResult ComplexOp::fold(ArrayRef operands) { // ImagOp //===----------------------------------------------------------------------===// -namespace { -Type createRealType(TensorType type) { - auto elementTy = type.getElementType(); - if (auto complexTy = elementTy.dyn_cast()) { - elementTy = complexTy.getElementType(); - } - return hlo::getSameShapeTensorType(type, elementTy); -} -} // namespace - LogicalResult ImagOp::inferReturnTypes( - MLIRContext*, Optional, ValueRange operands, DictionaryAttr, - RegionRange, SmallVectorImpl& inferredReturnTypes) { - inferredReturnTypes.push_back( - createRealType(operands[0].getType().cast())); - return success(); + MLIRContext*, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl& inferredReturnTypes) { + ImagOp::Adaptor adaptor(operands, attributes, regions); + return hlo::inferImagOp(location, adaptor.getOperand(), inferredReturnTypes); } -OpFoldResult ImagOp::fold(ArrayRef operands) { +OpFoldResult ImagOp::fold(FoldAdaptor) { if (auto complexOp = getOperand().getDefiningOp()) { return complexOp.getOperand(1); } @@ -3897,13 +2826,12 @@ OpFoldResult ImagOp::fold(ArrayRef operands) { //===----------------------------------------------------------------------===// LogicalResult IsFiniteOp::inferReturnTypes( - MLIRContext* ctx, Optional, ValueRange operands, DictionaryAttr, - RegionRange, SmallVectorImpl& inferredReturnTypes) { - auto argTy = operands.front().getType().cast(); - Builder b(ctx); - inferredReturnTypes.push_back( - hlo::getSameShapeTensorType(argTy, b.getI1Type())); - return success(); + MLIRContext* ctx, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl& inferredReturnTypes) { + IsFiniteOp::Adaptor 
adaptor(operands, attributes, regions); + return hlo::inferIsFiniteOp(ctx, location, adaptor.getX(), + inferredReturnTypes); } //===----------------------------------------------------------------------===// @@ -3911,14 +2839,14 @@ LogicalResult IsFiniteOp::inferReturnTypes( //===----------------------------------------------------------------------===// LogicalResult RealOp::inferReturnTypes( - MLIRContext*, Optional, ValueRange operands, DictionaryAttr, - RegionRange, SmallVectorImpl& inferredReturnTypes) { - inferredReturnTypes.push_back( - createRealType(operands[0].getType().cast())); - return success(); + MLIRContext*, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl& inferredReturnTypes) { + RealOp::Adaptor adaptor(operands, attributes, regions); + return hlo::inferRealOp(location, adaptor.getOperand(), inferredReturnTypes); } -OpFoldResult RealOp::fold(ArrayRef operands) { +OpFoldResult RealOp::fold(FoldAdaptor) { if (auto complexOp = getOperand().getDefiningOp()) { return complexOp.getOperand(0); } @@ -4012,84 +2940,9 @@ LogicalResult ConcatenateOp::inferReturnTypes( MLIRContext*, Optional location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { - if (operands.empty()) { - return failure(); - } - - auto dimensionAttr = attributes.get("dimension").cast(); - auto dimension = dimensionAttr.getInt(); - - auto firstType = (*operands.begin()).getType().cast(); - auto outElement = firstType.getElementType(); - - // Find the first ranked input to determine the output rank. - for (auto type : operands.getTypes()) { - auto shapedType = type.cast(); - if (shapedType.hasRank()) { - firstType = shapedType; - break; - } - } - - // If all inputs are unranked, the result must be unranked. - if (!firstType.hasRank()) { - inferredReturnTypes.push_back(UnrankedTensorType::get(outElement)); - return success(); - } - - auto outShape = llvm::to_vector<6>(firstType.getShape()); - - // Determine what the non-concatenate dimensions should be. - for (auto type : operands.getTypes()) { - auto shapedTy = type.cast(); - if (!shapedTy.hasRank()) { - continue; - } - - for (const auto& it : llvm::enumerate(shapedTy.getShape())) { - // If a dimension is not dynamic, the output shape should match. - if (ShapedType::isDynamic(outShape[it.index()])) { - outShape[it.index()] = it.value(); - } - } - } - - outShape[dimension] = 0; - - for (auto operand : operands.getTypes()) { - auto type = operand.cast(); - if (!type.hasRank()) { - inferredReturnTypes.push_back(UnrankedTensorType::get(outElement)); - return success(); - } - - // If the dimension is dynamic we know the output dimension is dynamic. - auto dim = type.getShape()[dimension]; - if (ShapedType::isDynamic(dim)) { - outShape[dimension] = ShapedType::kDynamicSize; - break; - } - - outShape[dimension] += dim; - } - - bool allSparse = llvm::all_of(operands.getTypes(), [](Type t) -> bool { - return sparse_tensor::getSparseTensorEncoding(t) != nullptr; - }); - - sparse_tensor::SparseTensorEncodingAttr enc; - if (allSparse) { - // Picks the encoding from an abitrary input is fine and it will be lowered - // correctly by the sparse compiler (though efficiency might vary). - // TODO: Extra rules are needed to infer sparse encoding when inputs have - // different encodings for better efficiency. 
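// Shape-wise (assuming the usual concatenate semantics), the result keeps
// every non-concat dimension and sums the extents along the concat dimension,
// going dynamic as soon as any input is dynamic there:
//
//   concatenate(tensor<2x3xf32>, tensor<2x4xf32>) {dimension = 1} -> tensor<2x7xf32>
//   concatenate(tensor<2x3xf32>, tensor<2x?xf32>) {dimension = 1} -> tensor<2x?xf32>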
- enc = sparse_tensor::getSparseTensorEncoding(operands.getTypes()[0]); - } - - inferredReturnTypes.push_back( - RankedTensorType::get(outShape, outElement, enc)); - - return success(); + ConcatenateOp::Adaptor adaptor(operands, attributes, regions); + return hlo::inferConcatenateOp(location, adaptor.getVal(), + adaptor.getDimension(), inferredReturnTypes); } void ConcatenateOp::getCanonicalizationPatterns(RewritePatternSet& results, @@ -4145,7 +2998,8 @@ static Attribute foldConcatenate(ConcatenateOp* op, return {}; } -OpFoldResult ConcatenateOp::fold(ArrayRef operands) { +OpFoldResult ConcatenateOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); if (getNumOperands() == 1) return getOperand(0); ShapedType type = getResult().getType().cast(); @@ -4166,58 +3020,6 @@ OpFoldResult ConcatenateOp::fold(ArrayRef operands) { return DenseElementsAttr::get(type, ArrayRef()); } -LogicalResult ConcatenateOp::verify() { - RankedTensorType firstRankedType; - int firstRankedIndex; - int numOperands = getNumOperands(); - int64_t concatDimension = static_cast(getDimension()); - if (concatDimension < 0) { - return emitOpError( - llvm::formatv("dimension {0} is negative", concatDimension)); - } - for (int i = 0; i < numOperands; i++) { - auto secondType = getOperand(i).getType().dyn_cast(); - if (!secondType.hasRank()) { - continue; - } - - if (!firstRankedType) { - firstRankedType = secondType.cast(); - firstRankedIndex = i; - if (firstRankedType.getRank() == 0) - return emitOpError( - llvm::formatv("rank-0 values cannot be concatenated")); - if (concatDimension >= firstRankedType.getRank()) { - return emitOpError( - llvm::formatv("dimension {0} is out-of-bounds for input rank {1}", - concatDimension, firstRankedType.getRank())); - } - continue; - } - - if (firstRankedType.getRank() != secondType.getRank()) { - return emitOpError(llvm::formatv( - "operands ({0}) and ({1}) do not match rank", firstRankedIndex, i)); - } - - auto firstShape = firstRankedType.getShape(); - auto secondShape = secondType.getShape(); - for (int d = 0; d < firstRankedType.getRank(); ++d) { - if (!ShapedType::isDynamic(firstShape[d]) && - !ShapedType::isDynamic(secondShape[d]) && - firstShape[d] != secondShape[d] && d != concatDimension) { - return emitOpError(llvm::formatv( - "shapes of operand ({0}) and ({1}) do not match at non-concat " - "index: ({2}) != ({3}) at non-concat index {4}", - firstRankedIndex, i, - llvm::make_range(firstShape.begin(), firstShape.end()), - llvm::make_range(secondShape.begin(), secondShape.end()), d)); - } - } - } - return success(); -} - LogicalResult ConcatenateOp::reifyReturnTypeShapes( OpBuilder& builder, ValueRange operands, SmallVectorImpl& reifiedReturnShapes) { @@ -4277,15 +3079,7 @@ LogicalResult ConcatenateOp::reifyReturnTypeShapes( //===----------------------------------------------------------------------===// LogicalResult DynamicReshapeOp::verify() { - auto resultType = getResult().getType().dyn_cast(); - auto outputShapeType = - getOutputShape().getType().dyn_cast(); - if (resultType && outputShapeType && outputShapeType.hasStaticShape() && - outputShapeType.getDimSize(0) != resultType.getRank()) { - return emitError() << "output should have a rank equal to the number of " - "elements in output_shape"; - } - return success(); + return hlo::verifyDynamicReshapeOp(getLoc(), getOutputShape(), getResult()); } LogicalResult DynamicReshapeOp::reifyReturnTypeShapes( @@ -4456,55 +3250,14 @@ void DynamicSliceOp::getCanonicalizationPatterns(RewritePatternSet& results, 
results.add(context); } -// Verifies that the number of slice sizes and the number of start indices match -LogicalResult DynamicSliceOp::verify() { - int numSliceSizes = getSliceSizes().getNumElements(); - int numStartIndices = getStartIndices().size(); - if (numStartIndices != numSliceSizes) { - return emitOpError() << "has mismatched number of slice sizes (" - << numSliceSizes << ") and number of start indices (" - << numStartIndices << ")"; - } - auto operandType = getOperand().getType().dyn_cast(); - if (!operandType) return failure(); - - if (operandType.getRank() != numStartIndices) { - return emitOpError() << "has mismatched number of start indices (" - << numStartIndices << ") and the rank of operand (" - << operandType.getRank() << ")"; - } - - for (int i = 0; i < numSliceSizes; ++i) { - int64_t sliceSize = getSliceSizes().getValues()[i]; - if (sliceSize < 0) { - return emitOpError() << "has negative size index to dynamic slice: " - << sliceSize; - } - if (!operandType.isDynamicDim(i)) { - int64_t dimSize = operandType.getDimSize(i); - if (sliceSize > dimSize) { - return emitOpError() << "has slice size " << sliceSize - << " greater than dimension size " << dimSize - << " in dimension " << i << " of operand"; - } - } - } - return success(); -} - LogicalResult DynamicSliceOp::inferReturnTypeComponents( MLIRContext*, Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { DynamicSliceOp::Adaptor adaptor(operands, attributes, regions); - Value operand = adaptor.getOperand(); - auto operandType = operand.getType().dyn_cast(); - if (!operandType) return failure(); - - auto sliceSizes = adaptor.getSliceSizes(); - Type elementTy = operandType.getElementType(); - inferredReturnShapes.emplace_back(sliceSizes.getValues(), elementTy); - return success(); + return hlo::inferDynamicSliceOp( + location, adaptor.getOperand(), adaptor.getStartIndices(), + adaptor.getSliceSizes(), inferredReturnShapes); } //===----------------------------------------------------------------------===// @@ -4512,94 +3265,66 @@ LogicalResult DynamicSliceOp::inferReturnTypeComponents( //===----------------------------------------------------------------------===// // Verifies that operand rank matches start_indices/limit_indices/strides size LogicalResult RealDynamicSliceOp::verify() { - auto inputType = getOperand().getType().dyn_cast(); - // If operand is unranked, there is very little to verify statically. 
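// e.g. (a minimal sketch, assuming the standard operand layout) a rank-2
// input requires all three index operands to be 1-d tensors of two elements:
//
//   %0 = "mhlo.real_dynamic_slice"(%arg, %starts, %limits, %strides)
//        : (tensor<?x?xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>)
//        -> tensor<?x?xf32>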
- if (!inputType) return success(); - int inputRank = inputType.getRank(); - - auto startType = getStartIndices().getType().cast(); - auto limitType = getLimitIndices().getType().cast(); - auto stridesType = getStrides().getType().cast(); - - if (inputRank != startType.getNumElements()) { - return emitOpError() << "has mismatched number of operand rank (" - << inputRank << ") and start_indices size (" - << startType.getNumElements() << ")"; - } - - if (inputRank != limitType.getNumElements()) { - return emitOpError() << "has mismatched number of operand rank (" - << inputRank << ") and limit_indices size (" - << limitType.getNumElements() << ")"; - } - - if (inputRank != stridesType.getNumElements()) { - return emitOpError() << "has mismatched number of operand rank (" - << inputRank << ") and strides size (" - << stridesType.getNumElements() << ")"; - } - - return success(); + return hlo::verifyRealDynamicSliceOp(getLoc(), getOperand(), + getStartIndices(), getLimitIndices(), + getStrides()); } namespace { -// Canonicalizes RealDynamicSlice ops that can be replaced instead with Slice -// ops. This canonicalization is applied the case when the `begin` input values -// are compile time constants and thus can be made into a tensor. -struct RealDynamicSliceIsStatic : public OpRewritePattern { +struct RealDSliceToDSlice : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(RealDynamicSliceOp realDynamicSlice, + LogicalResult matchAndRewrite(RealDynamicSliceOp op, PatternRewriter& rewriter) const override { - Location loc = realDynamicSlice.getLoc(); - Value input = realDynamicSlice.getOperand(); - Value output = realDynamicSlice.getResult(); - auto inputTy = input.getType().dyn_cast(); - auto outputTy = output.getType().dyn_cast(); - - if (!inputTy || !outputTy || !inputTy.hasStaticShape() || - !outputTy.hasStaticShape()) { - return failure(); - } - - int64_t inputRank = inputTy.getRank(); - - auto startVal = realDynamicSlice.getStartIndices(); - auto limitVal = realDynamicSlice.getLimitIndices(); - auto strideVal = realDynamicSlice.getStrides(); - auto startOp = startVal.getDefiningOp(); - auto limitOp = limitVal.getDefiningOp(); - auto strideOp = strideVal.getDefiningOp(); - if (!startOp || !limitOp || !strideOp) return failure(); - - auto startAttr = - startOp.getValue().dyn_cast_or_null(); - auto limitAttr = - limitOp.getValue().dyn_cast_or_null(); - auto strideAttr = - strideOp.getValue().dyn_cast_or_null(); - if (!startAttr || !limitAttr || !strideAttr) return failure(); - - SmallVector tempStartIndices; - SmallVector tempLimitIndices; - SmallVector tempStride; - for (int64_t dimIdx = 0; dimIdx < inputRank; dimIdx++) { - int64_t start = startAttr.getValues()[dimIdx].getInt(); - tempStartIndices.push_back(start); - int64_t limit = limitAttr.getValues()[dimIdx].getInt(); - tempLimitIndices.push_back(limit); - int64_t end = strideAttr.getValues()[dimIdx].getInt(); - tempStride.push_back(end); - } - - DenseIntElementsAttr sliceStartIndices = - rewriter.getI64TensorAttr(tempStartIndices); - DenseIntElementsAttr sliceLimitIndices = - rewriter.getI64TensorAttr(tempLimitIndices); - DenseIntElementsAttr sliceStrides = rewriter.getI64TensorAttr(tempStride); - auto result = rewriter.create(loc, input, sliceStartIndices, - sliceLimitIndices, sliceStrides); - rewriter.replaceOp(realDynamicSlice, {result}); + // This rewrite only works for unit strides because DynamicSliceOp + // doesn't support strides (i.e. it implicitly has unit strides). 
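// In IR terms the rewrite targets cases like the following (a sketch; %sizes
// is assumed to be a constant, so %limits is %starts plus a constant):
//
//   %limits = mhlo.add %starts, %sizes : tensor<2xi64>
//   %0 = "mhlo.real_dynamic_slice"(%arg, %starts, %limits, %strides)
//        : (tensor<8x8xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>)
//        -> tensor<2x2xf32>
//
// and replaces them with an mhlo.dynamic_slice whose slice_sizes come from
// the constant %sizes and whose start indices are 0-d slices of %starts.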
+ DenseIntElementsAttr stridesAttr; + if (!matchPattern(op.getStrides(), m_Constant(&stridesAttr))) + return rewriter.notifyMatchFailure(op, "requires constant strides"); + if (!llvm::all_of(stridesAttr.getValues(), + [&](APInt stride) { return stride == 1; })) + return rewriter.notifyMatchFailure(op, "requires unit strides"); + + // Check that slice sizes are fully static (DynamicSliceOp style). + // To detect that, we check whether `limit_indices` is defined as + // `start_indices + constant` or `constant + start_indices`. + DenseIntElementsAttr sliceSizesAttr; + auto m_startIndices = matchers::m_Val(op.getStartIndices()); + if (!matchPattern( + op.getLimitIndices(), + m_Op(m_startIndices, m_Constant(&sliceSizesAttr))) && + !matchPattern(op.getLimitIndices(), + m_Op(m_Constant(&sliceSizesAttr), m_startIndices))) + return rewriter.notifyMatchFailure( + op, "requires limit indices equal to start indices plus constant"); + + // RealDynamicSliceOp can take tensors of integer or index element types. + // DynamicSliceOp::slice_sizes only supports i64 element type. + // Adapt accordingly in order to be compatible with DynamicSliceOp. + SmallVector sliceSizes; + for (auto element : sliceSizesAttr.getValues()) { + sliceSizes.push_back(element.getSExtValue()); + } + + // RealDynamicSliceOp::start_indices is a 1-dimensional tensor. + // DynamicSliceOp::start_indices is a vararg of 0-dimensional tensors. + // Adapt accordingly in order to be compatible with DynamicSliceOp. + SmallVector startIndices; + for (auto i = 0; i < static_cast(sliceSizes.size()); ++i) { + auto startIndex1D = rewriter.create( + op.getLoc(), op.getStartIndices(), rewriter.getI64TensorAttr(i), + rewriter.getI64TensorAttr(i + 1), rewriter.getI64TensorAttr(1)); + auto startIndex0DType = RankedTensorType::get( + {}, + op.getStartIndices().getType().cast().getElementType()); + auto startIndex0D = rewriter.create( + op.getLoc(), startIndex0DType, startIndex1D); + startIndices.push_back(startIndex0D); + } + + rewriter.replaceOpWithNewOp( + op, op.getOperand(), startIndices, + rewriter.getI64TensorAttr(sliceSizes)); return success(); } }; @@ -4607,7 +3332,8 @@ struct RealDynamicSliceIsStatic : public OpRewritePattern { void RealDynamicSliceOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { - results.add(context); + // DISC disable + // results.add(context); } LogicalResult RealDynamicSliceOp::reifyReturnTypeShapes( @@ -4661,56 +3387,9 @@ LogicalResult RealDynamicSliceOp::reifyReturnTypeShapes( // InfeedOp //===----------------------------------------------------------------------===// -// Checks that the result type is of the form `zero_or_more_type(s), -// mhlo::token` LogicalResult InfeedOp::verify() { - auto resultTypes = getResultTypes(); - if (resultTypes.empty()) - return emitOpError() - << "result is expected to be at least of size 1, but got " - << resultTypes.size(); - - if (!resultTypes[resultTypes.size() - 1].isa()) - return emitOpError() << "last element of result types is expected to " - "be of token type, but got " - << resultTypes[resultTypes.size() - 1]; - - // Verify layout attribute - constexpr char kLayoutAttr[] = "layout"; - if (!getOperation()->hasAttr(kLayoutAttr)) return success(); - - mlir::ArrayAttr layout = - getOperation()->getAttrOfType(kLayoutAttr); - if (!layout) - return emitOpError() << "layout-attribute expected to be of array-type."; - - if (layout.size() != resultTypes.size() - 1) { - return emitOpError() << "layout-attribute size must be " - << resultTypes.size() - 1 
- << " (which is the number of " - "op-results - 1 (for token result)), but got " - << layout.size(); - } - - for (auto childLayout : layout) { - mlir::ArrayAttr childLayoutArr = childLayout.dyn_cast(); - if (!childLayoutArr) { - return emitOpError() << "layout-attribute expected to have " - "elements of type array, but got " - << childLayout; - } - - for (auto i : childLayoutArr) { - mlir::IntegerAttr attr = i.dyn_cast(); - if (!attr) { - return emitOpError() << "layout-attribute's leaf elements are " - "expected to be of type integer, but got " - << i; - } - } - } - - return success(); + auto dialect = getContext()->getLoadedDialect(); + return hlo::verifyInfeedOp(dialect, getLoc(), getLayout(), getResults()); } //===----------------------------------------------------------------------===// @@ -4726,7 +3405,7 @@ LogicalResult MapOp::inferReturnTypeComponents( adaptor.getComputation(), inferredReturnShapes); } -OpFoldResult MapOp::fold(ArrayRef operands) { +OpFoldResult MapOp::fold(FoldAdaptor) { mlir::Block& bb = getComputation().front(); mlir::Operation& frontOp = bb.front(); @@ -4748,57 +3427,47 @@ LogicalResult MapOp::reifyReturnTypeShapes( &reifiedReturnShapes); } +//===----------------------------------------------------------------------===// +// OutfeedOp +//===----------------------------------------------------------------------===// + +LogicalResult OutfeedOp::inferReturnTypes( + MLIRContext* context, Optional location, ValueRange, + DictionaryAttr, RegionRange, SmallVectorImpl& inferredReturnTypes) { + auto dialect = context->getLoadedDialect(); + return hlo::inferOutfeedOp(dialect, location, inferredReturnTypes); +} + +//===----------------------------------------------------------------------===// +// SendOp +//===----------------------------------------------------------------------===// + +LogicalResult SendOp::inferReturnTypes( + MLIRContext* context, Optional location, ValueRange, + DictionaryAttr, RegionRange, SmallVectorImpl& inferredReturnTypes) { + auto dialect = context->getLoadedDialect(); + return hlo::inferSendOp(dialect, location, inferredReturnTypes); +} + //===----------------------------------------------------------------------===// // RecvOp //===----------------------------------------------------------------------===// -// Checks that the result type is of the form `zero_or_more_type(s), -// mhlo::token` LogicalResult RecvOp::verify() { - auto resultTypes = getResultTypes(); - if (resultTypes.empty()) - return emitOpError() - << "result is expected to be at least of size 1, but got " - << resultTypes.size(); - if (!resultTypes[resultTypes.size() - 1].isa()) - return emitOpError() << "last element of result types is expected to " - "be of token type, but got " - << resultTypes[resultTypes.size() - 1]; - return success(); + auto dialect = getContext()->getLoadedDialect(); + return hlo::verifyRecvOp(dialect, getLoc(), getResults()); } //===----------------------------------------------------------------------===// // CopyOp //===----------------------------------------------------------------------===// -OpFoldResult CopyOp::fold(ArrayRef operands) { return getOperand(); } +OpFoldResult CopyOp::fold(FoldAdaptor) { return getOperand(); } //===----------------------------------------------------------------------===// // ReduceWindowOp //===----------------------------------------------------------------------===// -namespace { -// Infer the return-type of ReduceWindowOp. 
-SmallVector inferReduceWindowOpReturnType( - ArrayRef inputTypes, ArrayRef initTypes, - const ArrayRef window) { - SmallVector outputTypes; - for (size_t i = 0; i < inputTypes.size(); ++i) { - if (!inputTypes[i].hasRank()) { - outputTypes.push_back( - UnrankedTensorType::get(initTypes[i].getElementType())); - continue; - } - - outputTypes.push_back(RankedTensorType::get( - inferWindowOutputShape(inputTypes[i].getShape(), window), - initTypes[i].getElementType())); - } - - return outputTypes; -} -} // namespace - LogicalResult ReduceWindowOp::inferReturnTypeComponents( MLIRContext*, Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, @@ -4808,7 +3477,14 @@ LogicalResult ReduceWindowOp::inferReturnTypeComponents( location, adaptor.getInputs(), adaptor.getInitValues(), adaptor.getWindowDimensions(), adaptor.getWindowStrides(), adaptor.getBaseDilations(), adaptor.getWindowDilations(), - adaptor.getPadding(), adaptor.getBody(), inferredReturnShapes); + adaptor.getPadding(), inferredReturnShapes); +} + +LogicalResult ReduceWindowOp::verify() { + return hlo::verifyReduceWindowOp(getLoc(), getInputs(), getInitValues(), + getWindowDimensions(), getWindowStrides(), + getBaseDilations(), getWindowDilations(), + getPadding(), getBody()); } // Get the operation used for reduction applied to `result_index`th result. Its @@ -4831,6 +3507,117 @@ Operation* ReduceWindowOp::getReductionOp(int resultIndex) { return nullptr; } +bool isSplatZero(SplatElementsAttr attr) { + if (!attr) return false; + if (attr.getElementType().isa()) { + return attr.getSplatValue().isZero(); + } + if (attr.getElementType().isa()) { + return attr.getSplatValue().isZero(); + } + return false; +} + +LogicalResult ReduceWindowOp::fold(FoldAdaptor adaptor, + SmallVectorImpl& results) { + auto operands = adaptor.getOperands(); + const auto emptyOrAllEq = [](const Optional opt, + const int64_t n) { + return !opt.has_value() || + (opt->isSplat() && opt->getSplatValue().getInt() == n); + }; + const auto isSumReductionBody = [](mlir::Region& body) { + if (body.getNumArguments() != 2) return false; + auto returnOp = dyn_cast_or_null(body.back().getTerminator()); + if (!returnOp || returnOp.getNumOperands() != 1) return false; + auto addOp = returnOp.getOperand(0).getDefiningOp(); + if (!addOp) return false; + return (addOp.getLhs() == body.getArgument(0) && + addOp.getRhs() == body.getArgument(1)) || + (addOp.getLhs() == body.getArgument(1) && + addOp.getRhs() == body.getArgument(0)); + }; + + // Fold no-op single input sum reduction. 
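// i.e. (roughly, assuming the usual attribute spellings) something like
//
//   %0 = "mhlo.reduce_window"(%arg, %zero) ({ ... mhlo.add reducer ... })
//        {window_dimensions = dense<1> : tensor<2xi64>,
//         window_strides = dense<1> : tensor<2xi64>,
//         padding = dense<0> : tensor<2x2xi64>}
//        : (tensor<4x4xf32>, tensor<f32>) -> tensor<4x4xf32>
//
// sums each element with zero over a 1x1 window, so it folds to %arg.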
+ if (getInputs().size() == 1 && + isSplatZero(operands[1].dyn_cast_or_null()) && + emptyOrAllEq(getWindowDimensionsAttr(), 1) && + emptyOrAllEq(getWindowStrides(), 1) && + emptyOrAllEq(getBaseDilations(), 1) && + emptyOrAllEq(getWindowDilations(), 1) && emptyOrAllEq(getPadding(), 0) && + isSumReductionBody(getBody())) { + results.push_back(getInputs()[0]); + return success(); + } + + return failure(); +} + +// Builder that takes a constructor for its region and infers result types +void ReduceWindowOp::build( + OpBuilder& odsBuilder, OperationState& odsState, ValueRange inputs, + ValueRange init_values, DenseIntElementsAttr window_dimensions, + /*optional*/ DenseIntElementsAttr window_strides, + /*optional*/ DenseIntElementsAttr base_dilations, + /*optional*/ DenseIntElementsAttr window_dilations, + /*optional*/ DenseIntElementsAttr padding, + function_ref bodyBuilder) { + odsState.addOperands(inputs); + odsState.addOperands(init_values); + odsState.addAttribute(getWindowDimensionsAttrName(odsState.name), + window_dimensions); + if (window_strides) { + odsState.addAttribute(getWindowStridesAttrName(odsState.name), + window_strides); + } + if (base_dilations) { + odsState.addAttribute(getBaseDilationsAttrName(odsState.name), + base_dilations); + } + if (window_dilations) { + odsState.addAttribute(getWindowDilationsAttrName(odsState.name), + window_dilations); + } + if (padding) { + odsState.addAttribute(getPaddingAttrName(odsState.name), padding); + } + Region* region = odsState.addRegion(); + + llvm::SmallVector blockArgTypes; + llvm::SmallVector locs; + auto numValues = inputs.size() + init_values.size(); + blockArgTypes.reserve(numValues); + locs.reserve(numValues); + for (auto i : inputs) { + auto iType = i.getType().cast(); + blockArgTypes.push_back(iType.cloneWith( + llvm::ArrayRef(std::nullopt), iType.getElementType())); + locs.push_back(i.getLoc()); + } + for (auto i : init_values) { + auto iType = i.getType().cast(); + blockArgTypes.push_back(iType.cloneWith( + llvm::ArrayRef(std::nullopt), iType.getElementType())); + locs.push_back(i.getLoc()); + } + + { + OpBuilder::InsertionGuard g(odsBuilder); + Block* body = + odsBuilder.createBlock(region, /*insertPt=*/{}, blockArgTypes, locs); + bodyBuilder(odsBuilder, odsState.location, body->getArguments()); + } + + llvm::SmallVector inferredReturnTypes; + if (mlir::succeeded(ReduceWindowOp::inferReturnTypes( + odsBuilder.getContext(), odsState.location, odsState.operands, + odsState.attributes.getDictionary(odsState.getContext()), + odsState.regions, inferredReturnTypes))) + odsState.addTypes(inferredReturnTypes); + else + llvm::report_fatal_error("Failed to infer result type(s)."); +} + //===----------------------------------------------------------------------===// // ReducePrecisionOp //===----------------------------------------------------------------------===// @@ -4841,10 +3628,8 @@ Operation* ReduceWindowOp::getReductionOp(int resultIndex) { // We intend to verify the following properties // P2. 
exponent_bits >= 1 LogicalResult ReducePrecisionOp::verify() { - if (getExponentBits() < 1) { - return emitOpError() << "exponent_bits must be at least 1."; - } - return success(); + return hlo::verifyReducePrecisionOp(getLoc(), getExponentBits(), + getMantissaBits()); } //===----------------------------------------------------------------------===// @@ -4908,7 +3693,8 @@ static Attribute foldReverseHelper(DenseElementsAttr& attr, ShapedType& type, return DenseElementsAttr::get(type, result); } -OpFoldResult ReverseOp::fold(ArrayRef operands) { +OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); Value input = getOperand(); // No dimensions to reverse. @@ -4938,39 +3724,11 @@ OpFoldResult ReverseOp::fold(ArrayRef operands) { return {}; } -LogicalResult ReverseOp::reifyReturnTypeShapes( - OpBuilder& builder, ValueRange operands, - SmallVectorImpl& reifiedReturnShapes) { - return mlir::hlo::deriveShapeFromOperand(&builder, getOperation(), operands.front(), - &reifiedReturnShapes); -} - //===----------------------------------------------------------------------===// // ReduceOp //===----------------------------------------------------------------------===// -// Returns the result type after reducing operand of the given type across the -// specified dimensions. -static TensorType getReduceResultType(Type operandTy, - DenseIntElementsAttr dimensions) { - Type elementTy = getElementTypeOrSelf(operandTy); - - auto rankedTy = operandTy.dyn_cast(); - if (!rankedTy) return UnrankedTensorType::get(elementTy); - - int64_t rank = rankedTy.getRank(); - llvm::SmallVector dimsMask(rank, false); - for (int64_t dim : dimensions.getValues()) dimsMask[dim] = true; - - SmallVector shape; - for (int64_t i = 0; i < rank; ++i) { - if (!dimsMask[i]) shape.push_back(rankedTy.getDimSize(i)); - } - - return RankedTensorType::get(shape, elementTy); -} - -LogicalResult ReduceOp::fold(ArrayRef operands, +LogicalResult ReduceOp::fold(FoldAdaptor /*adaptor*/, SmallVectorImpl& results) { // No dimensions to reduce. 
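// For instance (roughly, in the custom reduce syntax) a reduction across an
// empty dimension list leaves the operand untouched and folds to it:
//
//   %0 = mhlo.reduce(%arg init: %init) across dimensions = []
//          : (tensor<4x4xf32>, tensor<f32>) -> tensor<4x4xf32>
//          reducer(%a: tensor<f32>, %b: tensor<f32>) {
//            %s = mhlo.add %a, %b : tensor<f32>
//            mhlo.return %s : tensor<f32>
//          }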
if (getDimensions().getNumElements() == 0) { @@ -5311,22 +4069,18 @@ ParseResult ReduceOp::parse(OpAsmParser& parser, OperationState& result) { } LogicalResult ReduceOp::inferReturnTypeComponents( - MLIRContext*, Optional, ValueShapeRange operands, + MLIRContext*, Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { ReduceOp::Adaptor adaptor(operands, attributes, regions); - for (auto input : adaptor.getInputs()) { - ShapedType outputType = - getReduceResultType(input.getType(), adaptor.getDimensions()); - inferredReturnShapes.emplace_back(outputType); - } - return success(); + return hlo::inferReduceOp(location, adaptor.getInputs(), + adaptor.getInitValues(), adaptor.getDimensions(), + inferredReturnShapes); } LogicalResult ReduceOp::verify() { - SmallVector unusedReturnShapes; - return hlo::inferReduceOp(getLoc(), getInputs(), getInitValues(), - getDimensions(), getBody(), unusedReturnShapes); + return hlo::verifyReduceOp(getLoc(), getInputs(), getInitValues(), + getDimensions(), getBody()); } // Enable constant folding to occur within the region of the ReduceOp @@ -5429,19 +4183,44 @@ LogicalResult ReduceOp::reifyReturnTypeShapes( return success(); } +//===----------------------------------------------------------------------===// +// OptimizationBarrierOp +//===----------------------------------------------------------------------===// +LogicalResult OptimizationBarrierOp::inferReturnTypes( + MLIRContext*, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange, + SmallVectorImpl& inferredReturnTypes) { + OptimizationBarrierOp::Adaptor adaptor(operands, attributes); + return hlo::inferOptimizationBarrierOp(location, adaptor.getOperand(), + inferredReturnTypes); +} + +//===----------------------------------------------------------------------===// +// ReturnOp +//===----------------------------------------------------------------------===// +LogicalResult ReturnOp::inferReturnTypes( + MLIRContext*, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange, + SmallVectorImpl& inferredReturnTypes) { + ReturnOp::Adaptor adaptor(operands, attributes); + return hlo::inferReturnOp(location, inferredReturnTypes); +} + +//===----------------------------------------------------------------------===// +// ReverseOp +//===----------------------------------------------------------------------===// +LogicalResult ReverseOp::verify() { + return hlo::verifyReverseOp(getLoc(), getOperand(), getDimensions()); +} + //===----------------------------------------------------------------------===// // RngBitGeneratorOp //===----------------------------------------------------------------------===// // Verify that input state has the same shape as output shape LogicalResult RngBitGeneratorOp::verify() { - auto initialShape = getInitialState().getType().dyn_cast(); - auto outputShape = getOutputState().getType().dyn_cast(); - if (initialShape.getShape() != outputShape.getShape()) - return emitOpError() - << "output state shape must match initial state shape. 
Got: " - << initialShape << " and " << outputShape; - return success(); + return hlo::verifyRngBitGeneratorOp(getLoc(), getInitialState(), + getOutputState()); } //===----------------------------------------------------------------------===// @@ -5449,16 +4228,8 @@ LogicalResult RngBitGeneratorOp::verify() { //===----------------------------------------------------------------------===// LogicalResult RngOp::verify() { - auto dist = getRngDistribution(); - if (dist == RngDistribution::UNIFORM) { - return success(); - } - auto muTy = getA().getType().cast().getElementType(); - auto sigmaTy = getB().getType().cast().getElementType(); - if (muTy.isa() && sigmaTy.isa()) { - return success(); - } - return emitOpError() << "mu and sigma must be floats"; + return hlo::verifyRngOp(getLoc(), getA(), getB(), + getRngDistribution() == RngDistribution::UNIFORM); } LogicalResult RngOp::inferReturnTypeComponents( @@ -5509,29 +4280,8 @@ LogicalResult XlaRngGetAndUpdateStateOp::inferReturnTypes( // SelectOp //===----------------------------------------------------------------------===// -LogicalResult SelectOp::verify() { - // The operands 'on_true' and 'on_false' should have compatible types, i.e., - // (a) have the same element type, and - // (b) have compatible shapes (i.e. the same shape and/or at least one - // dynamic shape) - if (!hlo::compatibleShapeAndElementType(getOnTrue().getType(), - getOnFalse().getType())) - return emitOpError() - << "requires compatible types for non-predicate operands"; - - // The predicate, if not-scalar, should have the same shape as the remaining - // operands. - auto predTy = getPred().getType().dyn_cast(); - bool predMayBeScalar = !predTy || predTy.getRank() == 0; - if (predMayBeScalar) return success(); - - if (failed(verifyCompatibleShape(getPred().getType(), getOnTrue().getType()))) - return emitOpError() << "requires the same shape for all operands"; - - return success(); -} - -OpFoldResult SelectOp::fold(ArrayRef operands) { +OpFoldResult SelectOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); if (getOnTrue() == getOnFalse()) { return getOnTrue(); } @@ -5554,24 +4304,9 @@ OpFoldResult SelectOp::fold(ArrayRef operands) { return {}; } -// simplify select(not(%pred), true_value, false_value) => select(%pred, -// false_value, true_value) -static LogicalResult selectCanonicalization(SelectOp selectOp, - PatternRewriter& rewriter) { - auto notOp = selectOp.getPred().getDefiningOp(); - if (!notOp) { - return failure(); - } - std::array newOperands = {notOp.getOperand(), selectOp.getOnFalse(), - selectOp.getOnTrue()}; - rewriter.updateRootInPlace( - selectOp, [&]() { selectOp.getOperation()->setOperands(newOperands); }); - return success(); -} - void SelectOp::getCanonicalizationPatterns(RewritePatternSet& results, - MLIRContext* /*context*/) { - results.add(&selectCanonicalization); + MLIRContext* context) { + results.add(context); } // Makes it such that a SelectOp that is a non-root operation in a DRR infers @@ -5581,28 +4316,8 @@ LogicalResult SelectOp::inferReturnTypeComponents( DictionaryAttr attributes, RegionRange, SmallVectorImpl& inferredReturnShapes) { SelectOp::Adaptor op(operands, attributes); - auto trueType = op.getOnTrue().getType().cast(); - auto falseType = op.getOnFalse().getType().cast(); - - // The output shape should be the most general of the operand shapes at each - // dimension. 
- ShapedTypeComponents& outputType = inferredReturnShapes.emplace_back(); - if (trueType == falseType || !trueType.hasRank()) { - outputType = ShapedTypeComponents(trueType.cast()); - } else if (!falseType.hasRank()) { - outputType = ShapedTypeComponents(falseType.cast()); - } else { - assert(trueType.getRank() == falseType.getRank()); - llvm::SmallVector dims; - dims.reserve(trueType.getRank()); - for (auto dim : llvm::zip(trueType.getShape(), falseType.getShape())) { - dims.push_back(std::get<0>(dim) == std::get<1>(dim) - ? std::get<0>(dim) - : ShapedType::kDynamicSize); - } - outputType = ShapedTypeComponents(dims, trueType.getElementType()); - } - return success(); + return hlo::inferSelectOp(location, op.getPred(), op.getOnTrue(), + op.getOnFalse(), inferredReturnShapes); } LogicalResult SelectOp::reifyReturnTypeShapes( @@ -5626,7 +4341,8 @@ LogicalResult SetDimensionSizeOp::verify() { return verifyDimAttr(*this); } -OpFoldResult SetDimensionSizeOp::fold(ArrayRef operands) { +OpFoldResult SetDimensionSizeOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); DenseElementsAttr input = operands[0].dyn_cast_or_null(); if (input) return input; @@ -5667,13 +4383,13 @@ LogicalResult SetDimensionSizeOp::inferReturnTypes( } auto shape = llvm::to_vector<4>(inputType.getShape()); - llvm::SmallVector bounds(rank, ShapedType::kDynamicSize); + llvm::SmallVector bounds(rank, ShapedType::kDynamic); if (auto encoding = inputType.getEncoding().dyn_cast_or_null()) bounds = llvm::to_vector<4>(encoding.getBounds()); - if (shape[dim] != ShapedType::kDynamicSize) bounds[dim] = shape[dim]; - shape[dim] = ShapedType::kDynamicSize; + if (shape[dim] != ShapedType::kDynamic) bounds[dim] = shape[dim]; + shape[dim] = ShapedType::kDynamic; DenseIntElementsAttr sizeAttr; if (matchPattern(adaptor.getSize(), m_Constant(&sizeAttr))) { @@ -5681,14 +4397,13 @@ LogicalResult SetDimensionSizeOp::inferReturnTypes( sizeAttr.getSplatValue().getValue().getSExtValue(); if (splat == bounds[dim]) { shape[dim] = splat; - bounds[dim] = ShapedType::kDynamicSize; + bounds[dim] = ShapedType::kDynamic; } } auto extensions = TypeExtensionsAttr::get(context, bounds); auto resultType = - llvm::all_of(bounds, - [](int64_t v) { return v == ShapedType::kDynamicSize; }) + llvm::all_of(bounds, [](int64_t v) { return v == ShapedType::kDynamic; }) ? 
RankedTensorType::get(shape, inputType.getElementType()) : RankedTensorType::get(shape, inputType.getElementType(), extensions); @@ -5700,79 +4415,15 @@ LogicalResult SetDimensionSizeOp::inferReturnTypes( // PadOp //===----------------------------------------------------------------------===// -LogicalResult PadOp::inferReturnTypeComponents( - MLIRContext*, Optional location, ValueShapeRange operands, +LogicalResult PadOp::inferReturnTypes( + MLIRContext*, Optional location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, - SmallVectorImpl& inferredReturnShapes) { + SmallVectorImpl& inferredReturnTypes) { PadOp::Adaptor adaptor(operands, attributes, regions); - auto inputType = adaptor.getOperand().getType().cast(); - auto padType = adaptor.getPaddingValue().getType().cast(); - - if (padType.getRank() != 0) { - return emitOptionalError( - location, llvm::formatv("padding value type should be a rank-0 " - "tensor, is rank {0}", - padType.getRank())); - } - - const auto& paddingLow = adaptor.getEdgePaddingLow(); - if (paddingLow.getType().getNumElements() != inputType.getRank()) { - return emitOptionalError( - location, - llvm::formatv( - "edge_padding_low length ({0}) must match operand rank ({1})", - paddingLow.getType().getNumElements(), inputType.getRank())); - } - - const auto& paddingHigh = adaptor.getEdgePaddingHigh(); - if (paddingHigh.getType().getNumElements() != inputType.getRank()) { - return emitOptionalError( - location, - llvm::formatv( - "edge_padding_high length ({0}) must match operand rank ({1})", - paddingHigh.getType().getNumElements(), inputType.getRank())); - } - - const auto& paddingInterior = adaptor.getInteriorPadding(); - if (paddingInterior.getType().getNumElements() != inputType.getRank()) { - return emitOptionalError( - location, - llvm::formatv( - "interior_padding length ({0}) must match operand rank ({1})", - paddingInterior.getType().getNumElements(), inputType.getRank())); - } - - auto inputShape = inputType.getShape(); - SmallVector resultShape; - for (int i = 0, e = inputShape.size(); i < e; i++) { - if (hlo::isDynamicDimSize(inputShape[i])) { - resultShape.push_back(ShapedType::kDynamicSize); - continue; - } - - int64_t paddingLowVal = paddingLow.getValues()[i].getSExtValue(); - int64_t paddingHighVal = paddingHigh.getValues()[i].getSExtValue(); - int64_t paddingInteriorVal = - paddingInterior.getValues()[i].getSExtValue(); - if (paddingInteriorVal < 0) { - return emitOptionalError( - location, llvm::formatv("Interior padding cannot be negative: {0}", - paddingInteriorVal)); - } - int64_t expectedOutput = - inputShape[i] + paddingLowVal + paddingHighVal + - std::max(inputShape[i] - 1, 0LL) * paddingInteriorVal; - if (expectedOutput < 0) { - return emitOptionalError( - location, - llvm::formatv("Padding result in negative size for dimension {0}", - i)); - } - resultShape.push_back(expectedOutput); - } - inferredReturnShapes.emplace_back(resultShape, inputType.getElementType()); - - return success(); + return hlo::inferPadOp(location, adaptor.getOperand(), + adaptor.getPaddingValue(), adaptor.getEdgePaddingLow(), + adaptor.getEdgePaddingHigh(), + adaptor.getInteriorPadding(), inferredReturnTypes); } template @@ -5816,7 +4467,8 @@ OpFoldResult padOpFoldHelper(DenseElementsAttr input, DenseElementsAttr padding, return DenseElementsAttr::get(returnType, result); } -OpFoldResult PadOp::fold(ArrayRef operands) { +OpFoldResult PadOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); // If all padding is zero then it is an 
identity pad. auto isZero = [](const APInt& i) { return i == 0; }; if (llvm::all_of(getEdgePaddingLow().getValues(), isZero) && @@ -6013,49 +4665,9 @@ void DynamicPadOp::getCanonicalizationPatterns(RewritePatternSet& results, } LogicalResult DynamicPadOp::verify() { - auto inputType = getOperand().getType().dyn_cast(); - // If operand is unranked, there is very little to verify statically. - if (!inputType) return success(); - int inputRank = inputType.getRank(); - - auto padType = getPaddingValue().getType().cast(); - if (padType.getRank() != 0) { - return emitOpError() << "padding value type should be a rank-0"; - } - - auto paddingLowType = getEdgePaddingLow().getType().cast(); - if (paddingLowType.getNumElements() != inputRank) { - return emitOpError() << "edge_padding_low length(" - << paddingLowType.getNumElements() - << ") must match operand rank(" << inputRank << ")."; - } - - auto paddingHighType = - getEdgePaddingHigh().getType().cast(); - if (paddingHighType.getNumElements() != inputRank) { - return emitOpError() << "edge_padding_high length(" - << paddingHighType.getNumElements() - << ") must match operand rank(" << inputRank << ")."; - } - - auto interiorPaddingType = - getInteriorPadding().getType().cast(); - if (interiorPaddingType.getNumElements() != inputRank) { - return emitOpError() << "edge_padding_interior length(" - << interiorPaddingType.getNumElements() - << ") must match operand rank(" << inputRank << ")."; - } - - auto outputType = getResult().getType().dyn_cast(); - // If result is unranked, there is very little to verify statically. - if (!outputType) return success(); - int outputRank = outputType.getRank(); - if (inputRank != outputRank) { - return emitOpError() << "operand rank(" << inputRank - << ") must match result(" << outputRank << ")."; - } - - return success(); + return hlo::verifyDynamicPadOp(getLoc(), getOperand(), getPaddingValue(), + getEdgePaddingLow(), getEdgePaddingHigh(), + getInteriorPadding(), getResult()); } LogicalResult DynamicPadOp::reifyReturnTypeShapes( @@ -6126,26 +4738,11 @@ LogicalResult DynamicPadOp::reifyReturnTypeShapes( //===----------------------------------------------------------------------===// LogicalResult ReshapeOp::verify() { - // If the operand type is dynamically shaped there is nothing to verify. - auto operandTy = getOperand().getType().dyn_cast(); - if (!operandTy || !operandTy.hasStaticShape()) return success(); - - // If the operand type is statically shaped (not required) the number of - // elements must match that of the result type. 
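// In other words, a fully static reshape only has to preserve the element
// count (generic syntax shown for illustration):
//
//   %0 = "mhlo.reshape"(%arg) : (tensor<2x3xf32>) -> tensor<6xf32>  // ok, 6 == 6
//   %1 = "mhlo.reshape"(%arg) : (tensor<2x3xf32>) -> tensor<5xf32>  // rejected, 6 != 5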
- auto resultTy = getType().cast(); - assert(resultTy && resultTy.hasStaticShape() && - "result type must be statically shaped"); - int64_t numResultElements = resultTy.getNumElements(); - int64_t numOperandElements = operandTy.getNumElements(); - if (numResultElements != numOperandElements) - return emitOpError() << "number of output elements (" << numResultElements - << ") doesn't match expected number of elements (" - << numOperandElements << ")"; - - return success(); + return hlo::verifyReshapeOp(getLoc(), getOperand(), getResult()); } -OpFoldResult ReshapeOp::fold(ArrayRef operands) { +OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); if (getOperand().getType() == getType()) { return getOperand(); } @@ -6173,11 +4770,19 @@ void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet& results, //===----------------------------------------------------------------------===// LogicalResult ReplicaIdOp::inferReturnTypes( - MLIRContext* context, Optional, ValueRange operands, + MLIRContext* context, Optional location, ValueRange /*operands*/, DictionaryAttr, RegionRange, SmallVectorImpl& inferredReturnTypes) { - inferredReturnTypes.push_back(RankedTensorType::get( - /*shape=*/{}, IntegerType::get(context, 32, IntegerType::Unsigned))); - return success(); + return hlo::inferReplicaIdOp(context, location, inferredReturnTypes); +} + +//===----------------------------------------------------------------------===// +// PartitionId Op +//===----------------------------------------------------------------------===// + +LogicalResult PartitionIdOp::inferReturnTypes( + MLIRContext* context, Optional location, ValueRange /*operands*/, + DictionaryAttr, RegionRange, SmallVectorImpl& inferredReturnTypes) { + return hlo::inferPartitionIdOp(context, location, inferredReturnTypes); } //===----------------------------------------------------------------------===// @@ -6357,13 +4962,28 @@ struct Sign { Optional operator()(const FloatOrInt& fi) { return compute(fi); } }; +template +struct Abs { + APFloat compute(const APFloat& f) { return abs(f); } + + APInt compute(const APInt& i) { return i.abs(); } + + Optional operator()(const FloatOrInt& fi) { return compute(fi); } +}; + double rsqrt(double d) { return 1.0 / std::sqrt(d); } double logistic(double d) { return 1.0 / (1.0 + std::exp(-d)); } // NOLINTBEGIN(bugprone-macro-parentheses) #define UNARY_FOLDER(Op, Func) \ - OpFoldResult Op::fold(ArrayRef attrs) { \ + OpFoldResult Op::fold(FoldAdaptor adaptor) { \ + auto attrs = adaptor.getOperands(); \ + /* AbsOp could take complex but return float */ \ + if (getElementTypeOrSelf(getOperation()->getOperand(0).getType()) != \ + getElementTypeOrSelf(getType())) { \ + return {}; \ + } \ if (getElementTypeOrSelf(getType()).isa()) \ return UnaryFolder>(this, attrs); \ if (getElementTypeOrSelf(getType()).isa()) \ @@ -6372,14 +4992,16 @@ double logistic(double d) { return 1.0 / (1.0 + std::exp(-d)); } } #define UNARY_FOLDER_INT(Op, Func) \ - OpFoldResult Op::fold(ArrayRef attrs) { \ + OpFoldResult Op::fold(FoldAdaptor adaptor) { \ + auto attrs = adaptor.getOperands(); \ if (getElementTypeOrSelf(getType()).isa()) \ return UnaryFolder>(this, attrs); \ return {}; \ } #define UNARY_FOLDER_FLOAT(Op, Func) \ - OpFoldResult Op::fold(ArrayRef attrs) { \ + OpFoldResult Op::fold(FoldAdaptor adaptor) { \ + auto attrs = adaptor.getOperands(); \ if (getElementTypeOrSelf(getType()).isa()) \ return UnaryFolder(this, attrs); \ return {}; \ @@ -6401,7 +5023,8 @@ double logistic(double d) { 
return 1.0 / (1.0 + std::exp(-d)); } return result; \ } \ }; \ - OpFoldResult Op::fold(ArrayRef attrs) { \ + OpFoldResult Op::fold(FoldAdaptor adaptor) { \ + auto attrs = adaptor.getOperands(); \ if (getElementTypeOrSelf(getType()).isa()) \ return UnaryFolder>(this, attrs); \ @@ -6411,6 +5034,7 @@ double logistic(double d) { return 1.0 / (1.0 + std::exp(-d)); } UNARY_FOLDER(NegOp, std::negate) UNARY_FOLDER(SignOp, Sign) +UNARY_FOLDER(AbsOp, Abs) UNARY_FOLDER_INT(NotOp, std::bit_not) UNARY_FOLDER_FLOAT(RoundNearestEvenOp, RoundNearestEven) UNARY_FOLDER_FLOAT(RoundOp, Round) @@ -6418,9 +5042,11 @@ UNARY_FOLDER_FLOAT(RoundOp, Round) UNARY_FOLDER_UPCAST_TO_F64(CosineOp, std::cos, AnyValue) UNARY_FOLDER_UPCAST_TO_F64(ExpOp, std::exp, AnyValue) UNARY_FOLDER_UPCAST_TO_F64(LogisticOp, logistic, AnyValue) +UNARY_FOLDER_UPCAST_TO_F64(LogOp, std::log, PositiveValue) UNARY_FOLDER_UPCAST_TO_F64(RsqrtOp, rsqrt, PositiveValue) UNARY_FOLDER_UPCAST_TO_F64(SineOp, std::sin, AnyValue) UNARY_FOLDER_UPCAST_TO_F64(SqrtOp, std::sqrt, NonNegativeValue) +UNARY_FOLDER_UPCAST_TO_F64(TanOp, std::tan, AnyValue) UNARY_FOLDER_UPCAST_TO_F64(TanhOp, std::tanh, AnyValue) #undef UNARY_FOLDER @@ -6551,9 +5177,10 @@ struct Min { return BinaryFolder>(this, attrs); \ return {}; -#define BINARY_FOLDER(Op, Func) \ - OpFoldResult Op::fold(ArrayRef attrs) { \ - BINARY_FOLDER_INTERNAL(Op, Func) \ +#define BINARY_FOLDER(Op, Func) \ + OpFoldResult Op::fold(FoldAdaptor adaptor) { \ + auto attrs = adaptor.getOperands(); \ + BINARY_FOLDER_INTERNAL(Op, Func) \ } // Addition, subtraction and multiplication use the std:: versions of the ops. @@ -6566,18 +5193,8 @@ BINARY_FOLDER(RemOp, Remainder) BINARY_FOLDER(MaxOp, Max) BINARY_FOLDER(MinOp, Min) -bool isSplatZero(SplatElementsAttr attr) { - if (!attr) return false; - if (attr.getElementType().isa()) { - return attr.getSplatValue().isZero(); - } - if (attr.getElementType().isa()) { - return attr.getSplatValue().isZero(); - } - return false; -} - -OpFoldResult AddOp::fold(ArrayRef attrs) { +OpFoldResult AddOp::fold(FoldAdaptor adaptor) { + auto attrs = adaptor.getOperands(); // Handle special case where one operand is 0: x + 0 => x if (attrs[0] || attrs[1]) { SplatElementsAttr splatLhs = attrs[0].dyn_cast_or_null(); @@ -6604,7 +5221,8 @@ bool isSplatOne(SplatElementsAttr attr) { return false; } -OpFoldResult MulOp::fold(ArrayRef attrs) { +OpFoldResult MulOp::fold(FoldAdaptor adaptor) { + auto attrs = adaptor.getOperands(); // Handle special case where one operand is 1: x * 1 => x if (attrs[0] || attrs[1]) { SplatElementsAttr splatLhs = attrs[0].dyn_cast_or_null(); @@ -6624,7 +5242,8 @@ OpFoldResult MulOp::fold(ArrayRef attrs) { // Logical Ops //===----------------------------------------------------------------------===// -OpFoldResult AndOp::fold(ArrayRef operands) { +OpFoldResult AndOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); if (getLhs() == getRhs()) return getLhs(); auto lhsVal = operands[0].dyn_cast_or_null(); @@ -6655,7 +5274,8 @@ OpFoldResult AndOp::fold(ArrayRef operands) { this, operands); } -OpFoldResult OrOp::fold(ArrayRef operands) { +OpFoldResult OrOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); if (getLhs() == getRhs()) return getLhs(); auto lhsVal = operands[0].dyn_cast_or_null(); @@ -6686,7 +5306,8 @@ OpFoldResult OrOp::fold(ArrayRef operands) { operands); } -OpFoldResult XorOp::fold(ArrayRef operands) { +OpFoldResult XorOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); // Fold x^x to 0. 
Attributes only support static shapes. auto rType = getType().cast(); if (getLhs() == getRhs() && rType.hasStaticShape()) { @@ -6718,96 +5339,69 @@ OpFoldResult XorOp::fold(ArrayRef operands) { #undef BINARY_FOLDER //===----------------------------------------------------------------------===// -// SliceOp +// ClampOp //===----------------------------------------------------------------------===// -// Returns output dimension size for slice result for the given arguments. -// Returns -1 if arguments are illegal. -static int64_t inferSliceDim(int64_t inputDim, int64_t start, int64_t end, - int64_t stride) { - if (inputDim == -1 || start < 0 || start > end || end > inputDim || - stride == 0) - return -1; - - return llvm::divideCeil(end - start, stride); +OpFoldResult ClampOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); + auto operand = operands[1].dyn_cast_or_null(); + auto min = operands[0].dyn_cast_or_null(); + auto max = operands[2].dyn_cast_or_null(); + if (!operand || !min || !max) { + return {}; + } + if (min.getType().getRank() == 0) { + min = DenseElementsAttr::get(operand.getType(), + min.getValues()[0]); + } + if (max.getType().getRank() == 0) { + max = DenseElementsAttr::get(operand.getType(), + max.getValues()[0]); + } + Attribute result = {}; + if (operand.getType().getElementType().isa()) { + result = BinaryFolder>( + this, ArrayRef{min, operand}); + result = BinaryFolder>( + this, ArrayRef{max, result}); + } else if (operand.getType().getElementType().isa()) { + result = BinaryFolder>( + this, ArrayRef{min, operand}); + result = BinaryFolder>( + this, ArrayRef{max, result}); + } + return result; } -// The following properties are already enforced by the ODS: -// type(start_indices) == type(limit_indices) == type(strides). -// Verify the following properties: -// P1. Verify rank(start_indices) == 1. -// P2. Verify size(start_indices) == rank(operand). -// P3~5. Verify 0 <= start_indices[i] <= limit_indices[i] <= shape(operand)[i]. -// P6. Verify stride[i] > 0. -LogicalResult SliceOp::inferReturnTypes( - MLIRContext* context, Optional location, ValueRange operands, +LogicalResult ClampOp::inferReturnTypeComponents( + MLIRContext*, Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, - SmallVectorImpl& inferredReturnTypes) { - SliceOpAdaptor slice(operands, attributes); - Type ty = slice.getOperand().getType(); - RankedTensorType rankedTy = ty.dyn_cast(); - if (!rankedTy) { - // The operand type is unranked, so the best we can infer for the result - // type is an unranked tensor with the same element type as the operand - // type. - inferredReturnTypes.assign({ty}); - return success(); - } + SmallVectorImpl& inferredReturnShapes) { + ClampOp::Adaptor adaptor(operands, attributes, regions); + return hlo::inferClampOp(location, adaptor.getMin(), adaptor.getOperand(), + adaptor.getMax(), inferredReturnShapes); +} - ShapedType attrTy = slice.getStartIndices().getType(); - // P1. - // Note: ODS has type(start_indices) == type(limit_indices) == type(strides) - // So this implies rank(limit_indices) == rank(strides) == 1 also. - if (attrTy.getRank() != 1) { - return emitOptionalError(location, "start_indices has rank ", - attrTy.getRank(), " instead of required rank 1"); - } +LogicalResult ClampOp::reifyReturnTypeShapes( + OpBuilder& builder, ValueRange operands, + SmallVectorImpl& reifiedReturnShapes) { + // For `mhlo.clamp`, the first operand may be a scalar. 
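The ClampOp constant folder added above broadcasts rank-0 min/max operands and then composes the Max and Min binary folders, which element-wise is simply min(max(x, lo), hi). A small sketch under that reading, in plain C++ with illustrative names:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Clamp each element of x into [lo, hi]. Rank-0 bounds apply uniformly,
// mirroring the splat broadcast performed by the fold above.
std::vector<int64_t> clampElements(const std::vector<int64_t>& x, int64_t lo, int64_t hi) {
  std::vector<int64_t> out;
  out.reserve(x.size());
  for (int64_t v : x) out.push_back(std::min(std::max(v, lo), hi));
  return out;
}

int main() {
  for (int64_t v : clampElements({-3, 0, 7}, -1, 5)) std::cout << v << " ";  // -1 0 5
  std::cout << "\n";
}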
+ return hlo::deriveShapeFromOperand(&builder, getOperation(), operands[1], + &reifiedReturnShapes); +} - // P2. - int64_t rank = rankedTy.getRank(); - if (attrTy.getNumElements() != rank) { - return emitOptionalError( - location, "the number of elements in start_indices (", - attrTy.getNumElements(), ") does not match the rank of the operand (", - rank, ")"); - } - - SmallVector start(slice.getStartIndices().getValues()); - SmallVector limit(slice.getLimitIndices().getValues()); - SmallVector strideVals(slice.getStrides().getValues()); - - SmallVector shape; - shape.reserve(rank); - for (int64_t i = 0, e = rank; i != e; i++) { - if (hlo::isDynamicDimSize(rankedTy.getDimSize(i))) { - shape.push_back(ShapedType::kDynamicSize); - continue; - } - // P3. - if (start[i] < 0) - return emitOptionalError(location, "negative start index ", start[i], - " in dimension ", i); - // P4. - if (limit[i] > rankedTy.getDimSize(i)) - return emitOptionalError(location, "limit index ", limit[i], - " is larger than dimension size ", - rankedTy.getDimSize(i), " in dimension ", i); - // P5. - if (start[i] > limit[i]) - return emitOptionalError(location, "start index ", start[i], - " is larger than limit index ", limit[i], - " in dimension ", i); - // P6. - if (strideVals[i] <= 0) - return emitOptionalError(location, "stride must be positive but got ", - strideVals[i], " in dimension ", i); - - shape.push_back(inferSliceDim(rankedTy.getDimSize(i), start[i], limit[i], - strideVals[i])); - } - inferredReturnTypes.assign( - {RankedTensorType::get(shape, rankedTy.getElementType())}); - return success(); +//===----------------------------------------------------------------------===// +// SliceOp +//===----------------------------------------------------------------------===// + +LogicalResult SliceOp::inferReturnTypes( + MLIRContext* /*context*/, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange /*regions*/, + SmallVectorImpl& inferredReturnTypes) { + SliceOpAdaptor adaptor(operands, attributes); + return hlo::inferSliceOp(location, adaptor.getOperand(), + adaptor.getStartIndices(), adaptor.getLimitIndices(), + adaptor.getStrides(), inferredReturnTypes); } template @@ -6873,7 +5467,8 @@ static Attribute foldSlice(SliceOp* op, I values) { outValues); } -OpFoldResult SliceOp::fold(ArrayRef operands) { +OpFoldResult SliceOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); // Check if the SliceOp is a NoOp operation. 
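The slice bounds checks that moved into hlo::inferSliceOp above enforce 0 <= start <= limit <= dim and stride > 0, and size each result dimension as ceil((limit - start) / stride). A standalone sketch of that rule, illustrative only and not the shared HLO helper:

#include <cstdint>
#include <iostream>

// Result size of one sliced dimension, or -1 if the arguments are invalid.
int64_t slicedDimSize(int64_t dim, int64_t start, int64_t limit, int64_t stride) {
  if (start < 0 || start > limit || limit > dim || stride <= 0) return -1;
  return (limit - start + stride - 1) / stride;  // ceil((limit - start) / stride)
}

int main() {
  std::cout << slicedDimSize(10, 2, 9, 3) << "\n";  // ceil(7/3) = 3
  std::cout << slicedDimSize(10, 4, 2, 1) << "\n";  // start > limit -> -1
}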
auto operandType = getOperand().getType().cast(); auto resultType = getResult().getType().cast(); @@ -7022,19 +5617,16 @@ void SortOp::build(OpBuilder& builder, OperationState& state, } LogicalResult SortOp::inferReturnTypeComponents( - MLIRContext*, Optional, ValueShapeRange operands, + MLIRContext*, Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { SortOp::Adaptor adaptor(operands, attributes, regions); - for (auto resultType : adaptor.getInputs().getTypes()) - inferredReturnShapes.emplace_back(resultType.cast()); - return success(); + return hlo::inferSortOp(location, adaptor.getInputs(), inferredReturnShapes); } LogicalResult SortOp::verify() { - SmallVector unusedReturnShapes; - return hlo::inferSortOp(getLoc(), getInputs(), getDimension(), - getComparator(), unusedReturnShapes); + return hlo::verifySortOp(getLoc(), getInputs(), getDimension(), + getComparator()); } /// Drops the operands if the results are not used and they are not used in @@ -7115,8 +5707,8 @@ void SortOp::getCanonicalizationPatterns(RewritePatternSet& results, // TransposeOp //===----------------------------------------------------------------------===// +// DISC-Begin namespace { - // Transpose the given elements attr according to the specified permutation. mlir::ElementsAttr TransposeElementsAttr( const mlir::ElementsAttr& elements, const DenseIntElementsAttr& perm_attr) { @@ -7163,21 +5755,28 @@ mlir::ElementsAttr TransposeElementsAttr( } } // namespace +// DISC-End + +OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); -OpFoldResult TransposeOp::fold(ArrayRef operands) { + // DISC-BEGIN // If the result has non-static shape, a transpsed op is necessary to go from // static shape to non-static shape. auto resultTy = getResult().getType().dyn_cast(); if (!resultTy || !resultTy.hasStaticShape()) return {}; + // DISC-END if (auto elements = operands.front().dyn_cast_or_null()) { return reshape(elements, getResult().getType().cast()); } + // DISC-BEGIN // operand is const, thus fold it directly. if (auto elementsAttr = operands.front().dyn_cast_or_null()) { return TransposeElementsAttr(elementsAttr, getPermutation()); } + // DISC-END for (const auto& it : llvm::enumerate(getPermutation().getValues())) { if (it.index() != it.value()) { @@ -7307,47 +5906,13 @@ LogicalResult TransposeOp::reifyReturnTypeShapes( return success(); } -// Method for InferTypeOpInterface: infer the return type from the operand type -// and the permutation. 
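The transpose type inference that the next hunk delegates to hlo::inferTransposeOp amounts to two steps visible in the removed code: check that the attribute is a permutation of [0, rank), then map result dimension i to operand dimension permutation[i]. A plain C++ sketch of both steps, with illustrative names:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Returns the permuted shape, or an empty vector if perm is not a valid
// permutation of [0, shape.size()).
std::vector<int64_t> transposedShape(const std::vector<int64_t>& shape,
                                     const std::vector<int64_t>& perm) {
  std::vector<int64_t> range(shape.size());
  std::iota(range.begin(), range.end(), 0);
  if (!std::is_permutation(range.begin(), range.end(), perm.begin(), perm.end()))
    return {};
  std::vector<int64_t> result;
  result.reserve(perm.size());
  for (int64_t d : perm) result.push_back(shape[d]);
  return result;
}

int main() {
  for (int64_t d : transposedShape({2, 3, 5}, {2, 0, 1})) std::cout << d << " ";  // 5 2 3
  std::cout << "\n";
}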
LogicalResult TransposeOp::inferReturnTypes( - MLIRContext* /*context*/, Optional loc, ValueRange operands, - DictionaryAttr attributes, RegionRange, + MLIRContext*, Optional loc, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { - auto type = operands[0].getType(); - auto rankedTy = type.dyn_cast(); - if (!rankedTy) { - auto shapedTy = type.dyn_cast(); - inferredReturnTypes.emplace_back(shapedTy); - return success(); - } - auto permutation = attributes.getAs("permutation"); - int64_t rank = rankedTy.getRank(); - if (permutation.getType().getRank() != 1) - return emitOptionalError(loc, "TransposeOp permutation has rank ", - permutation.getType().getRank(), - " instead of rank 1"); - - if (permutation.size() != rank) - return emitOptionalError(loc, "TransposeOp operand rank ", rank, - " does not match permutation size ", - permutation.size()); - - std::vector range(rank); - std::iota(range.begin(), range.end(), 0); - if (!std::is_permutation(range.begin(), range.end(), permutation.begin())) - return emitOptionalError(loc, - "attribute permutation must be a permutation" - " of [", - range, "] but got ", permutation); - - SmallVector resultShape; - ArrayRef inputShape = rankedTy.getShape(); - for (int64_t dim : permutation.getValues()) { - resultShape.push_back(inputShape[dim]); - } - inferredReturnTypes.emplace_back(RankedTensorType::get( - resultShape, rankedTy.getElementType(), rankedTy.getEncoding())); - return success(); + TransposeOp::Adaptor adaptor(operands, attributes, regions); + return hlo::inferTransposeOp(loc, adaptor.getOperand(), + adaptor.getPermutation(), inferredReturnTypes); } //===----------------------------------------------------------------------===// @@ -7370,20 +5935,21 @@ LogicalResult TriangularSolveOp::inferReturnTypeComponents( // GetTupleElementOp //===----------------------------------------------------------------------===// -LogicalResult GetTupleElementOp::inferReturnTypes( - MLIRContext*, Optional, ValueRange operands, - DictionaryAttr attributes, RegionRange, - SmallVectorImpl& inferredReturnTypes) { - auto tupleType = operands[0].getType().dyn_cast(); - if (!tupleType) return failure(); +OpFoldResult GetTupleElementOp::fold(FoldAdaptor /*adaptor*/) { + if (auto tupleOp = getOperand().getDefiningOp()) { + return tupleOp.getOperand(getIndex()); + } - auto indexAttr = attributes.get("index").cast(); - auto index = indexAttr.getInt(); - if (index < 0 || index >= static_cast(tupleType.size())) - return failure(); + return {}; +} - inferredReturnTypes.push_back(tupleType.getType(index)); - return success(); +LogicalResult GetTupleElementOp::inferReturnTypes( + MLIRContext*, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl& inferredReturnTypes) { + GetTupleElementOp::Adaptor adaptor(operands, attributes, regions); + return hlo::inferGetTupleElementOp(location, adaptor.getOperand(), + adaptor.getIndex(), inferredReturnTypes); } //===----------------------------------------------------------------------===// @@ -7391,11 +5957,12 @@ LogicalResult GetTupleElementOp::inferReturnTypes( //===----------------------------------------------------------------------===// LogicalResult TupleOp::inferReturnTypes( - MLIRContext* context, Optional, ValueRange operands, - DictionaryAttr attributes, RegionRange, + MLIRContext* context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferredReturnTypes) 
{ - inferredReturnTypes.push_back(TupleType::get(context, TypeRange(operands))); - return success(); + TupleOp::Adaptor adaptor(operands, attributes, regions); + return hlo::inferTupleOp(context, location, adaptor.getVal(), + inferredReturnTypes); } //===----------------------------------------------------------------------===// @@ -7420,17 +5987,12 @@ void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs, } LogicalResult CompareOp::inferReturnTypeComponents( - mlir::MLIRContext* ctx, llvm::Optional, - ValueShapeRange operands, mlir::DictionaryAttr, mlir::RegionRange, - llvm::SmallVectorImpl& inferredReturnTypes) { - ShapedTypeComponents& components = - inferredReturnTypes.emplace_back(IntegerType::get(ctx, /*width=*/1)); - auto argTy = operands.front().getType().cast(); - if (argTy.hasRank()) { - components = - ShapedTypeComponents(argTy.getShape(), components.getElementType()); - } - return success(); + MLIRContext* context, Optional location, ValueShapeRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl& inferredReturnShapes) { + CompareOp::Adaptor adaptor(operands, attributes, regions); + return hlo::inferCompareOp(context, location, adaptor.getLhs(), + inferredReturnShapes); } LogicalResult CompareOp::reifyReturnTypeShapes( @@ -7475,7 +6037,8 @@ static Attribute CompareFolder(CompareOp op, ArrayRef attrs) { return DenseElementsAttr::get(resultTy, values); } -OpFoldResult CompareOp::fold(ArrayRef operands) { +OpFoldResult CompareOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); auto resultTy = getType().cast(); if (!resultTy.hasStaticShape()) return {}; @@ -7554,421 +6117,42 @@ OpFoldResult CompareOp::fold(ArrayRef operands) { // SelectAndScatterOp //===----------------------------------------------------------------------===// -namespace { -// Infer the return-type of SelectAndScatterOp. -TensorType inferSelectAndScatterOpReturnType( - TensorType operandType, const ArrayRef window) { - if (!operandType.hasRank()) - return UnrankedTensorType::get(operandType.getElementType()); - - return RankedTensorType::get( - inferWindowOutputShape(operandType.getShape(), window), - operandType.getElementType()); +LogicalResult SelectAndScatterOp::inferReturnTypes( + MLIRContext*, Optional, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl& inferredReturnTypes) { + SelectAndScatterOp::Adaptor adaptor(operands, attributes, regions); + return hlo::inferSelectAndScatterOp(adaptor.getOperand(), + inferredReturnTypes); } -} // namespace -// We intend to verify the following properties: -// P1. Check if the select function has a proper shape of (T,T) -> PRED, where -// T is a 0-D tensor with element-type same as 'operand' element-type. -// P2. Verify scatter-computation type. -// P3. size-of(window_dimension) == rank-of(input), -// where input is an element of 'inputs'. -// P4. Verify and collect the window attributes. -// P5. Verify the return type matches the operand-type. -// P6. Check if the result type of window operation matches the source type. LogicalResult SelectAndScatterOp::verify() { - auto operandType = getOperand().getType().cast(); - auto initValueType = getInitValue().getType().cast(); - auto sourceType = getSource().getType().cast(); - auto resultType = getResult().getType().cast(); - - // P1. 
- Block& selectBlock = getSelect().front(); - - if (selectBlock.getArguments().size() != 2) - return emitOpError() - << "expects the select-region to take 2 parameters, but takes " - << selectBlock.getArguments().size(); - - Type expectedSelectArgType = - RankedTensorType::get({}, operandType.getElementType()); - for (const auto& selectArgIt : llvm::enumerate(selectBlock.getArguments())) - if (!hlo::compatibleShapeAndElementType(expectedSelectArgType, - selectArgIt.value().getType(), - /*ignoreFpPrecision=*/true)) - return emitOpError() - << "expects the type of select-region's parameter at index " - << selectArgIt.index() << " to be " << expectedSelectArgType - << ", but got " << selectArgIt.value().getType(); - - auto selectResult = selectBlock.getTerminator()->getOperands(); - if (selectResult.size() != 1) - return emitOpError() - << "expects select-region to return single value, but got: " - << selectResult.size(); - - auto selectResultType = selectResult[0].getType().dyn_cast(); - if (!selectResultType || !selectResultType.getElementType().isInteger(1) || - (selectResultType.hasRank() && - selectResultType.cast().getRank() != 0)) - return emitOpError() << "expects the return-type of select-region to be " - "tensor, but got: " - << selectResult[0].getType(); - - // P2. - Block& scatterBlock = getScatter().front(); - SmallVector accumulatorSubshapes; - if (failed(hlo::verifyReducerShape( - this->getLoc(), scatterBlock, - {RankedTensorType::get({}, sourceType.getElementType())}, - {initValueType}, - /*numInputs=*/1, /*allowedDimensions=*/{}, - /*allInputsUnranked=*/false, accumulatorSubshapes))) - return failure(); - - // P3. - SmallVector windowDims = - convertDenseIntAttr(this->getWindowDimensions()); - if (operandType.hasRank()) { - if (operandType.getRank() != static_cast(windowDims.size())) - return emitOpError() - << "expects window-dimensions size == operand rank, but got " - "window-dimensions size: " - << windowDims.size() << " and operand-type: " << operandType - << " with rank = " << operandType.getRank() << "."; - } - - // P4. - auto paddingOrErr = convertNx2Attribute(this->getPadding(), getLoc()); - if (failed(paddingOrErr)) return failure(); - SmallVector> padding = *paddingOrErr; - - auto windowOrErr = hlo::verifyWindowAttributesAndInferWindowDimensions( - windowDims, convertDenseIntAttr(getWindowStrides()), padding, - /*lhs_dilation=*/{}, /*rhs_dilation=*/{}, getLoc()); - if (failed(windowOrErr)) return failure(); - - // P5. - if (!hlo::compatibleShapeAndElementType(operandType, resultType)) - return emitOpError() - << "expects the return-type to match the operand-type, but got " - << resultType << " and " << operandType << " resp."; - - // P6. - auto windowResultType = - inferSelectAndScatterOpReturnType(operandType, *windowOrErr); - - if (!hlo::compatibleShapeAndElementType(windowResultType, sourceType, - /*ignoreFpPrecision=*/true)) - return emitOpError() << "expects source-type to be " << windowResultType - << ", but got" << sourceType; - - return success(); + return hlo::verifySelectAndScatterOp(getLoc(), getOperand(), getSource(), + getInitValue(), getWindowDimensions(), + getWindowStrides(), getPadding(), + getSelect(), getScatter()); } //===----------------------------------------------------------------------===// // ScatterOp //===----------------------------------------------------------------------===// -/* - * We intend to verify the following properties: - * P1. The 'update_window_dims' must be valid indices of 'updates' tensor. - * P2. 
The 'inserted_window_dims' must be valid indices of 'operand' tensor. - * P3. Check if the rank-of('operand') == size-of('update_window_dims') + - * size-of('inserted_window_dims') - * P4. size-of('scatter_dims_to_operand_dims') = - * 'scatter_indices'['index_vector_dim'] & - * 'scatter_dims_to_operand_dims' must be valid indices of 'operand' tensor. - */ -LogicalResult validateScatterDimensionNumbers( - ShapedType operandType, ArrayRef scatterIndicesShape, - ShapedType updateType, bool operandTypeRanked, - bool scatterIndicesTypeRanked, bool updatesTypeRanked, - ScatterDimensionNumbersAttr dimNumbers, Location loc) { - // P1. - auto updateWindowDims = to_vector(dimNumbers.getUpdateWindowDims()); - if (!llvm::is_sorted(updateWindowDims)) - return mlir::emitError(loc) - << "Expects update_window_dims to be sorted; got: [" - << updateWindowDims << "]."; - - if (hasDuplicates(updateWindowDims)) - return mlir::emitError(loc) - << "Expects update_window_dims to not repeat; got: [" - << updateWindowDims << "]."; - - if (updatesTypeRanked) { - for (int64_t windowDim : updateWindowDims) { - if (windowDim < 0 || windowDim >= updateType.getRank()) { - return mlir::emitError(loc) - << "Expects each element of update_window_dims to be in range " - "[0, " - "rank-of('updates') i.e. [0, " - << updateType.getRank() << "). got: " << windowDim << "."; - } - } - } - - // P2. - auto insertedWindowDims = to_vector(dimNumbers.getInsertedWindowDims()); - if (!llvm::is_sorted(insertedWindowDims)) - return mlir::emitError(loc) - << "Expects inserted_window_dims to be sorted; got: [" - << insertedWindowDims << "]."; - - if (hasDuplicates(insertedWindowDims)) - return mlir::emitError(loc) - << "Expects inserted_window_dims to not repeat; got: [" - << insertedWindowDims << "]."; - - if (operandTypeRanked) { - for (int64_t insertedDim : insertedWindowDims) { - if (insertedDim < 0 || insertedDim >= operandType.getRank()) { - return mlir::emitError(loc) - << "Expects each element of inserted_window_dims to be in range " - "[0, rank-of('operand') i.e. [0, " - << operandType.getRank() << "). got: " << insertedDim << "."; - } - } - } - - // P3. - if (operandTypeRanked) { - auto windowSize = updateWindowDims.size() + insertedWindowDims.size(); - if (operandType.getRank() != static_cast(windowSize)) - return mlir::emitError(loc) - << "Expects rank-of operand to match " - "size-of('update_window_dims') + " - "size-of('inserted_window_dims') i.e. " - << windowSize << " but got " << operandType.getRank() << "."; - } - - // P4. - auto scatterDimsToOperandDims = - to_vector(dimNumbers.getScatterDimsToOperandDims()); - auto indexVectorDim = dimNumbers.getIndexVectorDim(); - if (scatterIndicesTypeRanked) { - if (!hlo::isDynamicDimSize(scatterIndicesShape[indexVectorDim]) && - static_cast(scatterDimsToOperandDims.size()) != - scatterIndicesShape[dimNumbers.getIndexVectorDim()]) - return mlir::emitError(loc) - << "Scatter op has " << scatterDimsToOperandDims.size() - << " elements in scatter_dims_to_operand_dims and the bound of " - "dimension index_vector_dim=" - << dimNumbers.getIndexVectorDim() << " of scatter_indices is " - << scatterIndicesShape[dimNumbers.getIndexVectorDim()] - << ". 
These two numbers must be equal."; - } - - if (operandTypeRanked) { - for (int64_t i = 0; - i < static_cast(scatterDimsToOperandDims.size()); ++i) { - int64_t scatterDimToOperandDim = scatterDimsToOperandDims[i]; - if (scatterDimToOperandDim < 0 || - scatterDimToOperandDim >= operandType.getRank()) - return mlir::emitError(loc) - << "Invalid scatter_dims_to_operand_dims mapping; domain is [0, " - << operandType.getRank() << "), got: " << i << "->" - << scatterDimToOperandDim << "."; - } - } - - if (hasDuplicates(scatterDimsToOperandDims)) - return mlir::emitError(loc) - << "Expects scatter_dims_to_operand_dims to not repeat; got: [" - << scatterDimsToOperandDims << "]."; - - return success(); +LogicalResult ScatterOp::inferReturnTypes( + MLIRContext*, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl& inferredReturnTypes) { + ScatterOp::Adaptor adaptor(operands, attributes, regions); + return hlo::inferScatterOp(location, adaptor.getInputs(), + inferredReturnTypes); } -/* - * We intend to verify the following properties: - * P0. scatter_indices argument must be an integral tensor. Enforced by ODS. - * P1. Scatter index leaf dimension must be within [0, rank(scatter_indices)" - * " + 1). - * P2. Verify reducer shape. - * P3. rank-of('updates[i]') == size-of('update_window_dims') + - * rank-of('scatter_indices') - 1, where 'scatter_indices' is expanded by a - * trailing 1 dimension if 'index_vector_dim' == rank-of('scatter_indices') - * for all values of `i`. - * P4. Validate the scatter-dimensions-numbers. - * P5. Valide the bounds of each of the 'updates' w.r.t the operands. - * P6. Validate the bounds of each of the 'updates' w.r.t the - * 'scatter_indices'. - * P7. Check return types. - */ -LogicalResult ScatterOp::verify() { - // Get the first operand and update, since variadic Scatter is not yet - // implemented - auto numOperands = getInputs().size(); - auto scatterIndicesType = - getScatterIndices().getType().dyn_cast(); - - SmallVector operandTypes = - llvm::to_vector(llvm::map_range(getInputs().getTypes(), [](Type type) { - return type.cast(); - })); - SmallVector updatesTypes = - llvm::to_vector(llvm::map_range(getUpdates().getTypes(), [](Type type) { - return type.cast(); - })); - bool allOperandTypesRanked = - llvm::all_of(getInputs().getTypes(), - [](Type type) { return type.isa(); }); - bool scatterIndicesTypeRanked = scatterIndicesType.isa(); - - // P1. - int64_t indexVectorDim = getScatterDimensionNumbers().getIndexVectorDim(); - if (scatterIndicesTypeRanked) { - if (indexVectorDim > scatterIndicesType.getRank() || indexVectorDim < 0) - return emitOpError() - << "expects scatter index leaf dimension to be within [0, " - "rank(scatter_indices) + 1." - " rank(scatter_indices) is " - << scatterIndicesType.getRank() - << " and scatter index leaf dimension is " << indexVectorDim - << "."; - } - - // P2. - Block& block = getUpdateComputation().front(); - SmallVector accumulatorSubshapes; - SmallVector inputTypes, initValueTypes; - for (int64_t i = 0; i < static_cast(numOperands); i++) { - inputTypes.push_back(operandTypes[i]); - initValueTypes.push_back( - RankedTensorType::get({}, updatesTypes[i].getElementType())); - } - if (failed(hlo::verifyReducerShape( - this->getLoc(), block, inputTypes, initValueTypes, numOperands, - /*allowedDimensions=*/{}, - /*allInputsUnranked=*/!allOperandTypesRanked, accumulatorSubshapes))) - return failure(); - - // P3. 
- auto updateWindowDims = getScatterDimensionNumbers().getUpdateWindowDims(); - SmallVector expandedScatterIndicesShape; - if (scatterIndicesTypeRanked) { - expandedScatterIndicesShape = - llvm::to_vector(scatterIndicesType.getShape()); - if (static_cast(expandedScatterIndicesShape.size()) == - indexVectorDim) - expandedScatterIndicesShape.push_back(1); - } - - for (int64_t i = 0; i < static_cast(numOperands); i++) { - if (scatterIndicesTypeRanked && updatesTypes[i].isa()) { - int64_t expectedUpdatesRank = - expandedScatterIndicesShape.size() - 1 + updateWindowDims.size(); - if (updatesTypes[i].getRank() != expectedUpdatesRank) - return emitOpError() - << "expects updates tensor must be of rank " - << expectedUpdatesRank - << " ( == rank-of('scatter_indices') - 1 + " - "size-of('update_window_dims'), where 'scatter_indices' is " - "expanded by a trailing 1 dimension if 'index_vector_dim' == " - "rank-of('scatter_indices')), but got " - << updatesTypes[i].getRank() << "."; - } - } - - // P4. - for (int64_t i = 0; i < static_cast(numOperands); i++) { - if (failed(validateScatterDimensionNumbers( - operandTypes[i], expandedScatterIndicesShape, updatesTypes[i], - operandTypes[i].isa(), scatterIndicesTypeRanked, - updatesTypes[i].isa(), - getScatterDimensionNumbers(), getLoc()))) - return failure(); - } - - // P5. - for (int64_t i = 0; i < static_cast(numOperands); i++) { - if (updatesTypes[i].isa()) { - auto updatesShape = updatesTypes[i].getShape(); - if (operandTypes[i].isa()) { - auto operandShape = operandTypes[i].getShape(); - auto insertedWindowDims = - getScatterDimensionNumbers().getInsertedWindowDims(); - - int64_t insertedDimsSeen = 0; - SmallVector maxUpdateSliceSizes; - const auto dimensionsSize = operandTypes[i].getRank(); - maxUpdateSliceSizes.reserve(dimensionsSize); - for (int i = 0; i < dimensionsSize; ++i) { - if (insertedDimsSeen < - static_cast(insertedWindowDims.size()) && - insertedWindowDims[insertedDimsSeen] == i) { - ++insertedDimsSeen; - } else { - maxUpdateSliceSizes.push_back(operandShape[i]); - } - } - - for (int64_t i = 0; i < static_cast(updateWindowDims.size()); - ++i) { - auto updateWindowDim = updateWindowDims[i]; - - if (hlo::isDynamicDimSize(updatesShape[updateWindowDim]) || - hlo::isDynamicDimSize(maxUpdateSliceSizes[i])) - continue; - - if (updatesShape[updateWindowDim] > maxUpdateSliceSizes[i]) { - return emitOpError() - << "expects bounds of the window dimensions of " - "updates to not exceed the " - "bounds of the corresponding dimensions of " - "operand. For dimension " - << updateWindowDim << ", updates bound is " - << updatesShape[updateWindowDim] << ", operand bound is " - << maxUpdateSliceSizes[i] << "."; - } - } - } - - // P6. - if (scatterIndicesTypeRanked) { - int64_t scatterDimsSeen = 0; - for (int64_t i = 0; i < static_cast(updatesShape.size()); - ++i) { - bool isUpdateWindowDim = std::binary_search( - updateWindowDims.begin(), updateWindowDims.end(), i); - - if (isUpdateWindowDim) continue; - if (scatterDimsSeen == indexVectorDim) ++scatterDimsSeen; - - if (!hlo::isDynamicDimSize(updatesShape[i]) && - !hlo::isDynamicDimSize( - expandedScatterIndicesShape[scatterDimsSeen]) && - (updatesShape[i] != - expandedScatterIndicesShape[scatterDimsSeen])) { - return emitOpError() - << "expects bounds of the scatter dimensions of " - "updates to be same as the " - "bounds of the corresponding dimensions of " - "scatter indices. 
For " - "scatter dimension " - << i << ", updates bound is " << updatesShape[i] - << " , scatter_indices " - "bound is " - << expandedScatterIndicesShape[scatterDimsSeen] << "."; - } - ++scatterDimsSeen; - } - } - } - } - - // P7. - for (int64_t i = 0; i < static_cast(numOperands); i++) { - if (!hlo::compatibleShapeAndElementType(operandTypes[i], - getResult(i).getType())) - return emitOpError() - << "expects the return type to be same as the operand type: " - << operandTypes[i] << ", but got " << getResult(i).getType() - << "."; - } - return success(); +LogicalResult ScatterOp::verify() { + return hlo::verifyScatterOp( + getLoc(), getInputs(), getScatterIndices(), getUpdates(), + getScatterDimensionNumbers().getUpdateWindowDims(), + getScatterDimensionNumbers().getInsertedWindowDims(), + getScatterDimensionNumbers().getScatterDimsToOperandDims(), + getScatterDimensionNumbers().getIndexVectorDim(), getUpdateComputation()); } llvm::SmallVector evaluateMhloRegion(Region& region, @@ -7999,8 +6183,8 @@ llvm::SmallVector evaluateMhloRegion(Region& region, } LogicalResult ScatterOp::fold( - ArrayRef args, - llvm::SmallVectorImpl& foldResults) { + FoldAdaptor adaptor, llvm::SmallVectorImpl& foldResults) { + auto args = adaptor.getOperands(); // Variadic Scatter not yet implemented if (getInputs().size() != 1 || getUpdates().size() != 1) return failure(); auto index = args[1].dyn_cast_or_null(); @@ -8183,8 +6367,11 @@ LogicalResult WhileOp::inferReturnTypes( DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferredReturnTypes) { WhileOp::Adaptor adaptor(operands, attributes, regions); - return hlo::inferWhileOp(location, adaptor.getOperand(), adaptor.getCond(), - adaptor.getBody(), inferredReturnTypes); + return hlo::inferWhileOp(location, adaptor.getOperand(), inferredReturnTypes); +} + +LogicalResult WhileOp::verify() { + return hlo::verifyWhileOp(getLoc(), getOperand(), getCond(), getBody()); } /// Print a `while` op. @@ -8253,10 +6440,13 @@ ParseResult WhileOp::parse(OpAsmParser& parser, OperationState& result) { return success(); } -LogicalResult WhileOp::fold(ArrayRef /*operands*/, +LogicalResult WhileOp::fold(FoldAdaptor /*adaptor*/, SmallVectorImpl& results) { DenseIntElementsAttr condValue; - auto condReturnOp = cast(getCond().front().back()); + // TODO: This folder is executed on invalid mhlo.while ops during + // LegalizeMhlo, mlir_hlo/tosa/tests/unary.mlir. Broken pattern? 
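The rank bookkeeping spelled out in the removed ScatterOp verifier above (now owned by hlo::verifyScatterOp) reduces to two relations: rank(operand) == |update_window_dims| + |inserted_window_dims|, and rank(updates) == rank(scatter_indices) - 1 + |update_window_dims|, where scatter_indices gains a trailing 1 dimension when index_vector_dim == rank(scatter_indices). A small sketch of just those checks, with illustrative names; the real verifier also checks sortedness, duplicates, and bounds:

#include <cstddef>
#include <cstdint>
#include <iostream>

// Checks the two scatter rank relations described above. Purely illustrative.
bool scatterRanksConsistent(int64_t operandRank, int64_t updatesRank,
                            int64_t scatterIndicesRank, int64_t indexVectorDim,
                            size_t numUpdateWindowDims,
                            size_t numInsertedWindowDims) {
  int64_t expandedIndicesRank =
      scatterIndicesRank + (indexVectorDim == scatterIndicesRank ? 1 : 0);
  bool operandOk = operandRank ==
      static_cast<int64_t>(numUpdateWindowDims + numInsertedWindowDims);
  bool updatesOk = updatesRank ==
      expandedIndicesRank - 1 + static_cast<int64_t>(numUpdateWindowDims);
  return operandOk && updatesOk;
}

int main() {
  // operand: 3-D, indices: 2-D with index_vector_dim == 1,
  // update_window_dims = {1, 2}, inserted_window_dims = {0}, updates: 3-D.
  std::cout << scatterRanksConsistent(3, 3, 2, 1, 2, 1) << "\n";  // 1
}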
+ auto condReturnOp = dyn_cast(getCond().front().back()); + if (!condReturnOp) return failure(); if (!matchPattern(condReturnOp.getOperand(0), m_Constant(&condValue))) return failure(); if (condValue.getSplatValue().getValue()) @@ -8330,16 +6520,12 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet& results, } LogicalResult UniformDequantizeOp::inferReturnTypeComponents( - MLIRContext*, Optional /*location*/, ValueShapeRange operands, + MLIRContext*, Optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl& inferredReturnShapes) { UniformDequantizeOp::Adaptor adaptor(operands, attributes, regions); - auto operandType = (*operands.begin()).getType().cast(); - // Trait HLO_QuantizedIntTensor in ODS guarantees QuantizedType; - auto quantType = operandType.getElementType().cast(); - auto shape = operandType.dyn_cast().getShape(); - inferredReturnShapes.emplace_back(shape, quantType.getExpressedType()); - return success(); + return hlo::inferUniformDequantizeOp(location, adaptor.getOperand(), + inferredReturnShapes); } using mlir::hlo::parseWindowAttributes; @@ -8348,23 +6534,29 @@ using mlir::hlo::printWindowAttributes; } // namespace mhlo } // namespace mlir -// clang-format off -using mlir::hlo::printSameOperandsAndResultType; +using mlir::hlo::parseComplexOpType; +using mlir::hlo::parseCustomCallTarget; +using mlir::hlo::parseDenseI64Array; +using mlir::hlo::parseExponentMantissa; +using mlir::hlo::parsePairwiseOpType; using mlir::hlo::parseSameOperandsAndResultType; -using mlir::hlo::printVariadicSameOperandsAndResultType; +using mlir::hlo::parseSelectOpType; +using mlir::hlo::parseTupleOpType; +using mlir::hlo::parseVariadicOperandWithAttribute; using mlir::hlo::parseVariadicSameOperandsAndResultType; using mlir::hlo::printComplexOpType; -using mlir::hlo::parseComplexOpType; +using mlir::hlo::printCustomCallTarget; +using mlir::hlo::printDenseI64Array; +using mlir::hlo::printExponentMantissa; using mlir::hlo::printPairwiseOpType; -using mlir::hlo::parsePairwiseOpType; +using mlir::hlo::printSameOperandsAndResultType; using mlir::hlo::printSelectOpType; -using mlir::hlo::parseSelectOpType; using mlir::hlo::printTupleOpType; -using mlir::hlo::parseTupleOpType; -// clang-format on +using mlir::hlo::printVariadicOperandWithAttribute; +using mlir::hlo::printVariadicSameOperandsAndResultType; #define GET_OP_CLASSES -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc" +#include "mhlo/IR/hlo_ops.cc.inc" namespace mlir { namespace mhlo { @@ -8374,7 +6566,7 @@ namespace mhlo { //===----------------------------------------------------------------------===// namespace { -struct HLOInlinerInterface : public DialectInlinerInterface { +struct MhloDialectInlinerInterface : public DialectInlinerInterface { using DialectInlinerInterface::DialectInlinerInterface; // Allow all call operations to be inlined. @@ -8385,21 +6577,27 @@ struct HLOInlinerInterface : public DialectInlinerInterface { // We don't have any special restrictions on what can be inlined into // destination regions (e.g. while/conditional bodies). Always allow it. bool isLegalToInline(Region* dest, Region* src, bool wouldBeCloned, - BlockAndValueMapping& valueMapping) const final { + IRMapping& valueMapping) const final { return true; } // Operations in mhlo dialect are always legal to inline since they are // pure. 
bool isLegalToInline(Operation*, Region*, bool, - BlockAndValueMapping&) const final { + IRMapping&) const final { return true; } }; -struct HLOBoundedDialectInterface : public hlo::BoundedDialectInterface { - using BoundedDialectInterface::BoundedDialectInterface; +struct MhloHloDialectInterface : public hlo::HloDialectInterface { + using HloDialectInterface::HloDialectInterface; + + Type createTokenType() const override { + return TokenType::get(getDialect()->getContext()); + } + + bool isTokenType(Type type) const override { return type.isa(); } - Attribute createBoundedAttr(ArrayRef bounds) const override { + Attribute createTypeExtensions(ArrayRef bounds) const override { return TypeExtensionsAttr::get(getDialect()->getContext(), bounds); } }; @@ -8413,15 +6611,15 @@ MhloDialect::MhloDialect(MLIRContext* context) : Dialect(getDialectNamespace(), context, TypeID::get()) { addOperations< #define GET_OP_LIST -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc" +#include "mhlo/IR/hlo_ops.cc.inc" >(); - addInterfaces(); - addInterfaces(); + addInterfaces(); + addInterfaces(); addBytecodeInterface(this); addTypes(); addAttributes< #define GET_ATTRDEF_LIST -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_attrs.cc.inc" +#include "mhlo/IR/hlo_ops_attrs.cc.inc" >(); context->loadDialect(); } @@ -8466,37 +6664,28 @@ void MhloDialect::printAttribute(Attribute attr, DialectAsmPrinter& os) const { } /// Helpers for attributes parsing. -static ParseResult parseDims(AsmParser& parser, SmallVector& dims) { - dims.clear(); - return parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&] { - dims.emplace_back(); - return parser.parseInteger(dims.back()); - }); +static ParseResult parseDims(AsmParser& parser, + SmallVector& dimSizes) { + dimSizes.clear(); + auto failOrDims = parseDimSizes(parser); + if (failed(failOrDims)) { + return failure(); + } + dimSizes = std::move(*failOrDims); + return success(); } static ParseResult parseDimsWithMinimumElements(AsmParser& parser, - SmallVector& dims, + SmallVector& dimSizes, int minElements) { - if (failed(parseDims(parser, dims))) return failure(); - if (static_cast(dims.size()) < minElements) + if (failed(parseDims(parser, dimSizes))) return failure(); + if (static_cast(dimSizes.size()) < minElements) return parser.emitError(parser.getCurrentLocation()) << "expected at least " << minElements << " element(s), found " - << dims.size(); + << dimSizes.size(); return success(); } -FailureOr> parseIntArray(AsmParser& parser) { - SmallVector ints; - if (failed(parseDims(parser, ints))) return failure(); - return ints; -} - -void printIntArray(AsmPrinter& printer, ArrayRef ints) { - printer << '['; - llvm::interleaveComma(ints, printer); - printer << ']'; -} - /// Parse a custom attribute that resembles a struct of the form /// < /// foo = something_parsed_by_custom_parser, @@ -9168,6 +7357,92 @@ static LogicalResult verifyArgResultAliasAttr(StringAttr attrName, return success(); } +// Each CrossProgramPrefetchAttr specifies a parameter and a ShapeIndex +// (1) the parameter must be valid +// (2) there must be a subshape at the given indices +LogicalResult verifyCrossProgramPrefetchAttr(CrossProgramPrefetchAttr cpp, + ModuleOp module) { + func::FuncOp main = module.lookupSymbol("main"); + if (cpp.getParameter() >= main.getNumArguments()) + return module->emitOpError() + << "cross_program_prefetch: parameter " << cpp.getParameter() + << " out of range. 
main has only " << main.getNumArguments() + << " arguments"; + auto type = getTypeFromTupleIndices(main.getArgument(cpp.getParameter()) + .getType() + .dyn_cast_or_null(), + cpp.getIndices()); + if (!type) + return module->emitOpError() + << "cross_program_prefetch: no subshape at given index: " + << cpp.getIndices(); + return success(); +} + +// Each DynamicParameterBinding specifies a dynamic parameter, a target +// parameter, a shape index of each and a target dimension. +// (1) the parameters must be valid +// (2) there must be a subshape at the given ShapeIndex for each parameter +// (3) the given subshape for the dynamic parameter must be of type tensor +// (4) there must be a dimension at the given dimension number for the given +// subshape of the target parameter +// (5) that dimension is dynamic +LogicalResult verifyDynamicParameterBinding(DynamicParameterBindingAttr bind, + ModuleOp module) { + func::FuncOp main = module.lookupSymbol("main"); + + // (1) + if (bind.getDynamicParamNum() >= main.getNumArguments() || + bind.getTargetParamNum() >= main.getNumArguments()) + return module->emitOpError() + << "dynamic_parameter_binding: parameters " + << bind.getDynamicParamNum() << " and " << bind.getTargetParamNum() + << " out of range. main has only " << main.getNumArguments() + << " arguments"; + + // (2) + auto dynamicParamSubshape = + getTypeFromTupleIndices( + main.getArgument(bind.getDynamicParamNum()).getType(), + bind.getDynamicParamIndices()) + .dyn_cast_or_null(); + if (!dynamicParamSubshape) + return module->emitOpError() << "dynamic_parameter_binding: no ranked " + "tensor type at dynamic_param_indices: " + << bind.getDynamicParamIndices(); + // (3) + if (dynamicParamSubshape.getRank() != 0 || + !dynamicParamSubshape.getElementType().isInteger(32)) + return module->emitOpError() + << "dynamic_parameter_binding: dynamic size must be tensor"; + + // (2) + auto targetParamSubshape = + getTypeFromTupleIndices( + main.getArgument(bind.getTargetParamNum()).getType(), + bind.getTargetParamIndices()) + .dyn_cast_or_null(); + if (!targetParamSubshape) + return module->emitOpError() << "dynamic_parameter_binding: no ranked " + "tensor type at target_param_indices: " + << bind.getTargetParamIndices(); + // (4) + if (targetParamSubshape.getRank() <= bind.getTargetParamDimNum()) + return module->emitOpError() + << "dynamic_parameter_binding: no dimension number " + << bind.getTargetParamDimNum() << " in target subshape " + << targetParamSubshape; + + // (5) + if (!targetParamSubshape.isDynamicDim(bind.getTargetParamDimNum())) + return module->emitOpError() + << "dynamic_parameter_binding: dimension number " + << bind.getTargetParamDimNum() << " in target subshape " + << targetParamSubshape << " is not dynamic"; + + return success(); +} + //===----------------------------------------------------------------------===// // Builder utilities //===----------------------------------------------------------------------===// @@ -9215,7 +7490,7 @@ SortOp createSortOp(PatternRewriter* rewriter, const Location& loc, // Use TOTALORDER comparison type instead of the default comparison if the // element type is of type float. 
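The comment above, together with the code that follows, switches float sorts to the TOTALORDER comparison, presumably so that NaNs compare consistently rather than making the comparator fail to be a strict order. One common way to picture a total order over floats is the bit trick of mapping the sign-magnitude pattern to a monotonically ordered integer. A hedged sketch of that idea only; the authoritative TOTALORDER semantics are defined by XLA, not by this code:

#include <cstdint>
#include <cstring>
#include <iostream>
#include <limits>

// Maps a float's bit pattern to an unsigned key whose natural ordering is a
// total order over all floats, with NaNs ordered after +inf.
std::uint32_t totalOrderKey(float f) {
  std::uint32_t bits;
  std::memcpy(&bits, &f, sizeof bits);
  // Negative values (sign bit set): flip all bits so more-negative values get
  // smaller keys. Non-negative values: set the top bit so they sort after all
  // negative values.
  return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u);
}

int main() {
  float nan = std::numeric_limits<float>::quiet_NaN();
  std::cout << (totalOrderKey(-0.0f) < totalOrderKey(0.0f)) << "\n";    // 1: -0.0 before +0.0
  std::cout << (totalOrderKey(nan) > totalOrderKey(1.0e30f)) << "\n";   // 1: NaN after finite values
}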
- llvm::Optional compareType = llvm::None; + llvm::Optional compareType = std::nullopt; for (auto const& elementType : elementTypes) if (elementType.isa()) { compareType.emplace("TOTALORDER"); @@ -9242,6 +7517,16 @@ Operation* MhloDialect::materializeConstant(OpBuilder& builder, Attribute value, return builder.create(loc, type, elementsAttr); } +int64_t getNumLeafBuffers(Type type) { + if (auto tuple = type.dyn_cast()) { + auto ans = 0; + for (auto type : tuple.getTypes()) ans += getNumLeafBuffers(type); + return ans; + } else { + return 1; + } +} + LogicalResult MhloDialect::verifyRegionArgAttribute(Operation* op, unsigned /*regionIndex*/, unsigned argIndex, @@ -9251,6 +7536,27 @@ LogicalResult MhloDialect::verifyRegionArgAttribute(Operation* op, verifyArgResultAliasAttr(attr.getName(), aliasAttr, argIndex, op))) return failure(); } + if (attr.getName() == "mhlo.parameter_replication") { + auto arrayAttr = attr.getValue().dyn_cast(); + if (!arrayAttr) + return op->emitOpError() << "parameter_replication: must be an array"; + auto func = dyn_cast(op); + if (!func) { + return op->emitOpError() + << "has parameter_replication but is not a function"; + } + // parameter_replication = [] or [false] is equivalent to + // [false,...,false] and parameter_replication = [true] means + // [true,...,true] + if (arrayAttr.size() == 0 || arrayAttr.size() == 1) return success(); + auto num_leaf_buffers = + getNumLeafBuffers(func.getArgumentTypes()[argIndex]); + if ((size_t)num_leaf_buffers != arrayAttr.size()) + return op->emitOpError() + << "parameter_replication: arg " << argIndex << " has " + << num_leaf_buffers << " leaf_buffers, but parameter_replication" + << " expects " << arrayAttr.size(); + } return success(); } @@ -9262,6 +7568,58 @@ LogicalResult MhloDialect::verifyOperationAttribute(Operation* op, << "attribute " << attr.getName() << " can only be used on function-like operations"; } + if (attr.getName() == "mhlo.cross_program_prefetches") { + auto arrayAttr = attr.getValue().dyn_cast(); + if (!arrayAttr) + return op->emitOpError() << "cross_program_prefetches must be an array"; + for (auto attrElt : arrayAttr) { + auto prefetchAttr = attrElt.dyn_cast(); + if (!prefetchAttr) + return op->emitOpError() << "cross_program_prefetches must be an array " + "of cross_program_prefetch attrs"; + auto module = dyn_cast(op); + if (!module) + return op->emitOpError() + << "has cross_program_prefetches but is not a module"; + auto res = verifyCrossProgramPrefetchAttr(prefetchAttr, module); + if (failed(res)) return res; + } + } + if (attr.getName() == "mhlo.dynamic_parameter_bindings") { + auto arrayAttr = attr.getValue().dyn_cast(); + if (!arrayAttr) + return op->emitOpError() << "dynamic_parameter_bindings must be an array"; + auto module = dyn_cast(op); + if (!module) + return op->emitOpError() + << "has dynamic_parameter_bindings but is not a module"; + for (auto attrElt : arrayAttr) { + auto bindingAttr = attrElt.dyn_cast(); + if (!bindingAttr) + return op->emitOpError() << "dynamic_parameter_bindings must be an " + "array of dynamic_parameter_binding attrs"; + auto res = verifyDynamicParameterBinding(bindingAttr, module); + if (failed(res)) return res; + } + } + if (attr.getName() == "mhlo.spmd_parameters_sharding") { + auto arrayAttr = attr.getValue().dyn_cast(); + if (!arrayAttr) + return op->emitOpError() << "spmd_parameters_sharding: must be an array"; + auto module = dyn_cast(op); + if (!module) + return op->emitOpError() + << "has spmd_paramters_sharding but is not a module"; + // Check that 
the "main" function exists: + auto main = module.lookupSymbol("main"); + if (!main) + return module.emitOpError() << "spmd_parameters_sharding: main not found"; + if (main.getNumArguments() != arrayAttr.size()) + return module.emitOpError() + << "spmd_parameters_sharding: main has " << main.getNumArguments() + << " arguments, but spmd_parameters_sharding expects " + << arrayAttr.size(); + } return success(); } diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h similarity index 88% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h rename to tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h index a53c763b5f4..9b54a8494a8 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.h @@ -15,8 +15,8 @@ limitations under the License. // This file defines the operations used in the MHLO dialect. -#ifndef MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H -#define MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H +#ifndef MLIR_HLO_MHLO_IR_HLO_OPS_H +#define MLIR_HLO_MHLO_IR_HLO_OPS_H #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Quant/QuantTypes.h" @@ -37,11 +37,11 @@ limitations under the License. #include "stablehlo/dialect/Base.h" // Include order below matters. -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_enums.h.inc" +#include "mhlo/IR/hlo_ops_enums.h.inc" #define GET_ATTRDEF_CLASSES -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_attrs.h.inc" +#include "mhlo/IR/hlo_ops_attrs.h.inc" #define GET_TYPEDEF_CLASSES -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_typedefs.h.inc" +#include "mhlo/IR/hlo_ops_typedefs.h.inc" namespace mlir { class OpBuilder; @@ -92,13 +92,11 @@ void printConvolutionDimensions(AsmPrinter &p, Operation *, ParseResult parseConvolutionDimensions(AsmParser &parser, ConvDimensionNumbersAttr &dnums); -FailureOr> parseIntArray(AsmParser &parser); -void printIntArray(AsmPrinter &printer, ArrayRef ints); } // end namespace mhlo } // end namespace mlir #define GET_OP_CLASSES -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc" +#include "mhlo/IR/hlo_ops.h.inc" namespace mlir { namespace mhlo { @@ -111,4 +109,4 @@ SortOp createSortOp(PatternRewriter *rewriter, const Location &loc, } // end namespace mhlo } // end namespace mlir -#endif // MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H +#endif // MLIR_HLO_MHLO_IR_HLO_OPS_H diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.td similarity index 75% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td rename to tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.td index cf5f253f29c..d76ca1c002e 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops.td @@ -23,11 +23,11 @@ include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/OpAsmInterface.td" include "mlir/IR/OpBase.td" -include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td" -include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.td" +include "mhlo/IR/hlo_utils.td" +include "mhlo/IR/hlo_ops_common.td" -class HLO_Op traits> : - Op { +class MHLO_Op traits> : + Op { // Whether this operation has a custom conversion to HLO or not. 
bit hasCustomHLOConverter = 0b0; @@ -40,8 +40,8 @@ class HLO_Op traits> : }]; } -class HLO_ShapedInterfaceOp traits> : - HLO_Op traits> : + MHLO_Op]> { let extraClassDeclaration = [{ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) { @@ -54,7 +54,7 @@ class HLO_ShapedInterfaceOp traits> : // MHLO nullary op definitions. //===----------------------------------------------------------------------===// -def HLO_ConstantOp : HLO_Op<"constant", +def MHLO_ConstantOp : MHLO_Op<"constant", [ConstantLike, Pure, DeclareOpInterfaceMethods]> { let summary = "Constant operator"; let description = [{ @@ -65,7 +65,7 @@ def HLO_ConstantOp : HLO_Op<"constant", ); let results = (outs - HLO_StaticShapeTensor:$output + MHLO_StaticShapeTensor:$output ); let builders = [ @@ -83,14 +83,14 @@ def HLO_ConstantOp : HLO_Op<"constant", }]; } -def HLO_IotaOp : HLO_Op<"iota", [Pure]> { +def MHLO_IotaOp : MHLO_Op<"iota", [Pure]> { let summary = "Iota operator"; let description = [{ Creates a rank 1 array of values starting at zero and incrementing by one. }]; let arguments = (ins I64Attr:$iota_dimension); - let results = (outs HLO_IntFpOrComplexTensor:$output); + let results = (outs MHLO_StaticShapeIntFpOrComplexTensor:$output); // TODO(b/130357376): Iota has special conversion logic to HLO. let hasCustomHLOConverter = 1; @@ -99,7 +99,7 @@ def HLO_IotaOp : HLO_Op<"iota", [Pure]> { let hasVerifier = 1; } -def HLO_DynamicIotaOp: HLO_ShapedInterfaceOp<"dynamic_iota", [Pure]> { +def MHLO_DynamicIotaOp: MHLO_ShapedInterfaceOp<"dynamic_iota", [Pure]> { let summary = "Create linear increasing values from 0 to length -1."; let description = [{ Produces an HLO Tensor of the specified shape, with an incremental set of @@ -109,8 +109,8 @@ def HLO_DynamicIotaOp: HLO_ShapedInterfaceOp<"dynamic_iota", [Pure]> { - The output length of the tensor result. }]; - let arguments = (ins HLO_DimensionTensor:$output_shape, I64Attr:$iota_dimension); - let results = (outs HLO_Tensor:$result); + let arguments = (ins MHLO_DimensionTensor:$output_shape, I64Attr:$iota_dimension); + let results = (outs MHLO_Tensor:$result); let hasCanonicalizer = 1; // Cannot be exported to legacy formats. @@ -118,7 +118,8 @@ def HLO_DynamicIotaOp: HLO_ShapedInterfaceOp<"dynamic_iota", [Pure]> { } -def HLO_CreateTokenOp : HLO_Op<"create_token", [Pure]> { +def MHLO_CreateTokenOp : MHLO_Op<"create_token", [Pure, + DeclareOpInterfaceMethods]> { let summary = "Create Token operator"; let description = [{ @@ -133,7 +134,7 @@ def HLO_CreateTokenOp : HLO_Op<"create_token", [Pure]> { ``` }]; - let results = (outs HLO_Token:$output); + let results = (outs MHLO_Token:$output); let assemblyFormat = "attr-dict `:` type(results)"; } @@ -143,8 +144,8 @@ def HLO_CreateTokenOp : HLO_Op<"create_token", [Pure]> { //===----------------------------------------------------------------------===// // See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions -class HLO_UnaryElementwiseOp traits, - Type OperandType, Type ResultType = OperandType> : HLO_Op traits, + Type OperandType, Type ResultType = OperandType> : MHLO_Op { let arguments = (ins OperandType:$operand); let results = (outs ResultType:$result); @@ -168,10 +169,10 @@ class HLO_UnaryElementwiseOp traits, } // Abs supports complex to real, so element type is not guaranteed to match. 
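The folder guard added for AbsOp earlier in this change skips folding whenever operand and result element types differ, precisely because abs maps complex inputs to real outputs while keeping the type for integers and floats. A tiny standalone illustration of that type change in plain C++, illustrative only:

#include <complex>
#include <cstdlib>
#include <iostream>

int main() {
  std::complex<double> z(3.0, 4.0);
  double magnitude = std::abs(z);   // complex operand, real result: |3 + 4i| = 5
  int i = std::abs(-7);             // integral abs keeps its element type
  std::cout << magnitude << " " << i << "\n";  // 5 7
}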
-def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs", +def MHLO_AbsOp: MHLO_UnaryElementwiseOp<"abs", [Pure, DeclareOpInterfaceMethods], - TensorOf<[HLO_SInt, HLO_Float, HLO_Complex]>> { + TensorOf<[MHLO_SInt, MHLO_Float, MHLO_Complex]>> { let summary = "Absolute value operator"; let description = [{ Returns `abs(operand)` element-wise. @@ -185,10 +186,11 @@ def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs", %0 = mhlo.abs %arg0 : tensor<3xi32> ``` }]; + let hasFolder = 1; } -def HLO_CbrtOp: HLO_UnaryElementwiseOp<"cbrt", - [Pure, HLO_CompatibleOperandsAndResultType], HLO_FpTensor> { +def MHLO_CbrtOp: MHLO_UnaryElementwiseOp<"cbrt", + [Pure, HLO_CompatibleOperandsAndResultType], MHLO_FpOrComplexTensor> { let summary = "Cubic root operator"; let description = [{ Returns element-wise cubic root of the operand. @@ -203,8 +205,8 @@ def HLO_CbrtOp: HLO_UnaryElementwiseOp<"cbrt", ``` }]; } -def HLO_CeilOp: HLO_UnaryElementwiseOp<"ceil", - [Pure, HLO_CompatibleOperandsAndResultType], HLO_FpTensor> { +def MHLO_CeilOp: MHLO_UnaryElementwiseOp<"ceil", + [Pure, HLO_CompatibleOperandsAndResultType], MHLO_FpTensor> { let summary = "Ceil operator"; let description = [{ Returns `Ceil(operand)` element-wise. @@ -219,8 +221,8 @@ def HLO_CeilOp: HLO_UnaryElementwiseOp<"ceil", ``` }]; } -def HLO_ConvertOp : HLO_UnaryElementwiseOp<"convert", - [Pure, SameOperandsAndResultShape], HLO_Tensor> { +def MHLO_ConvertOp : MHLO_UnaryElementwiseOp<"convert", + [Pure, SameOperandsAndResultShape], MHLO_Tensor> { let summary = "Convert operator"; let description = [{ Performs element-wise conversion of values from one type to another, e.g. @@ -244,8 +246,8 @@ def HLO_ConvertOp : HLO_UnaryElementwiseOp<"convert", let hasCustomHLOConverter = 1; } -def HLO_ClzOp: HLO_UnaryElementwiseOp<"count_leading_zeros", - [Pure, HLO_CompatibleOperandsAndResultType], HLO_IntTensor> { +def MHLO_ClzOp: MHLO_UnaryElementwiseOp<"count_leading_zeros", + [Pure, HLO_CompatibleOperandsAndResultType], MHLO_IntTensor> { let summary = "Count-leading-zeros (Clz) operator"; let description = [{ Returns the number of leading zeros in each operand element-wise. @@ -261,8 +263,8 @@ def HLO_ClzOp: HLO_UnaryElementwiseOp<"count_leading_zeros", }]; } -def HLO_CosineOp: HLO_UnaryElementwiseOp<"cosine", - [Pure, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> { +def MHLO_CosineOp: MHLO_UnaryElementwiseOp<"cosine", + [Pure, HLO_CompatibleOperandsAndResultType], MHLO_FpOrComplexTensor> { let summary = "Cos operator"; let description = [{ Returns `Cos(operand)` element-wise. @@ -281,8 +283,8 @@ def HLO_CosineOp: HLO_UnaryElementwiseOp<"cosine", let hasCustomHLOConverter = 1; } -def HLO_ExpOp: HLO_UnaryElementwiseOp<"exponential", - [Pure, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> { +def MHLO_ExpOp: MHLO_UnaryElementwiseOp<"exponential", + [Pure, HLO_CompatibleOperandsAndResultType], MHLO_FpOrComplexTensor> { let summary = "Exponential operator"; let description = [{ Returns `e^(operand)` element-wise. @@ -298,8 +300,8 @@ def HLO_ExpOp: HLO_UnaryElementwiseOp<"exponential", }]; let hasFolder = 1; } -def HLO_Expm1Op: HLO_UnaryElementwiseOp<"exponential_minus_one", - [Pure, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> { +def MHLO_Expm1Op: MHLO_UnaryElementwiseOp<"exponential_minus_one", + [Pure, HLO_CompatibleOperandsAndResultType], MHLO_FpOrComplexTensor> { let summary = "Exponential minus one operator"; let description = [{ Returns `e^(operand) - 1` element-wise. 
@@ -314,8 +316,8 @@ def HLO_Expm1Op: HLO_UnaryElementwiseOp<"exponential_minus_one", ``` }]; } -def HLO_FloorOp: HLO_UnaryElementwiseOp<"floor", - [Pure, HLO_CompatibleOperandsAndResultType], HLO_FpTensor> { +def MHLO_FloorOp: MHLO_UnaryElementwiseOp<"floor", + [Pure, HLO_CompatibleOperandsAndResultType], MHLO_FpTensor> { let summary = "Floor operator"; let description = [{ Returns `Floor(operand)` element-wise. @@ -330,9 +332,9 @@ def HLO_FloorOp: HLO_UnaryElementwiseOp<"floor", ``` }]; } -def HLO_ImagOp: HLO_UnaryElementwiseOp<"imag", +def MHLO_ImagOp: MHLO_UnaryElementwiseOp<"imag", [Pure, DeclareOpInterfaceMethods], - HLO_FpOrComplexTensor, HLO_FpTensor> { + MHLO_FpOrComplexTensor, MHLO_FpTensor> { let summary = "Imag operator"; let description = [{ Returns `Imag(operand)` element-wise. @@ -348,8 +350,8 @@ def HLO_ImagOp: HLO_UnaryElementwiseOp<"imag", } -def HLO_IsFiniteOp: HLO_UnaryElementwiseOp<"is_finite", [Pure, - DeclareOpInterfaceMethods], HLO_Tensor> { +def MHLO_IsFiniteOp: MHLO_UnaryElementwiseOp<"is_finite", [Pure, + DeclareOpInterfaceMethods], MHLO_Tensor> { let summary = "IsFinite operator"; let description = [{ Tests whether each element of operand is finite, i.e., is not positive or @@ -366,16 +368,16 @@ def HLO_IsFiniteOp: HLO_UnaryElementwiseOp<"is_finite", [Pure, %0 = mhlo.is_finite %arg0 : (tensor<2xf32>) -> tensor<2xi1> ``` }]; - let arguments = (ins HLO_FpTensor:$x); - let results = (outs HLO_PredTensor:$y); + let arguments = (ins MHLO_FpTensor:$x); + let results = (outs MHLO_PredTensor:$y); let assemblyFormat = [{ $x attr-dict `:` functional-type(operands, results) }]; } -def HLO_LogOp: HLO_UnaryElementwiseOp<"log", - [Pure, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> { +def MHLO_LogOp: MHLO_UnaryElementwiseOp<"log", + [Pure, HLO_CompatibleOperandsAndResultType], MHLO_FpOrComplexTensor> { let summary = "Logarithm operator"; let description = [{ Returns `log(operand)` element-wise. @@ -389,9 +391,10 @@ def HLO_LogOp: HLO_UnaryElementwiseOp<"log", %0 = mhlo.log %arg0 : tensor<2xf32> ``` }]; + let hasFolder = 1; } -def HLO_Log1pOp: HLO_UnaryElementwiseOp<"log_plus_one", - [Pure, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> { +def MHLO_Log1pOp: MHLO_UnaryElementwiseOp<"log_plus_one", + [Pure, HLO_CompatibleOperandsAndResultType], MHLO_FpOrComplexTensor> { let summary = "Log1p operator"; let description = [{ Returns `log(operand+1)` element-wise. @@ -406,8 +409,8 @@ def HLO_Log1pOp: HLO_UnaryElementwiseOp<"log_plus_one", ``` }]; } -def HLO_LogisticOp: HLO_UnaryElementwiseOp<"logistic", - [Pure, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> { +def MHLO_LogisticOp: MHLO_UnaryElementwiseOp<"logistic", + [Pure, HLO_CompatibleOperandsAndResultType], MHLO_FpOrComplexTensor> { let summary = "Logistic operator"; let description = [{ Returns `logistic(operand)` element-wise. @@ -423,12 +426,12 @@ def HLO_LogisticOp: HLO_UnaryElementwiseOp<"logistic", }]; let hasFolder = 1; } -def HLO_NotOp: HLO_UnaryElementwiseOp<"not", - [Pure, HLO_CompatibleOperandsAndResultType], HLO_PredOrIntTensor> { +def MHLO_NotOp: MHLO_UnaryElementwiseOp<"not", + [Pure, HLO_CompatibleOperandsAndResultType], MHLO_PredOrIntTensor> { let summary = "Not operator"; let description = [{ Returns biwise-NOT of `operand` element-wise. The input tensor must be - of type integer `HLO_Int` or boolean `HLO_Pred`. + of type integer `MHLO_Int` or boolean `MHLO_Pred`. Note: For boolean tensor, the bitwise-NOT is equivalent to logical-NOT. 
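As an editorial illustration (not part of the patch), the renamed `not` op keeps its existing `mhlo` assembly; only the ODS record name changes to the `MHLO_` prefix, and the `MHLO_PredOrIntTensor` operand type admits both integer and boolean tensors. Shapes below are arbitrary:

```mlir
// Bitwise NOT on an integer tensor and logical NOT on a boolean tensor.
%0 = mhlo.not %arg0 : tensor<4xi32>
%1 = mhlo.not %arg1 : tensor<4xi1>
```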
@@ -441,8 +444,8 @@ def HLO_NotOp: HLO_UnaryElementwiseOp<"not", let hasFolder = 1; } -def HLO_NegOp: HLO_UnaryElementwiseOp<"negate", - [Pure, HLO_CompatibleOperandsAndResultType], HLO_IntFpOrComplexTensor> { +def MHLO_NegOp: MHLO_UnaryElementwiseOp<"negate", + [Pure, HLO_CompatibleOperandsAndResultType], MHLO_IntFpOrComplexTensor> { let summary = "Negation operator"; let description = [{ Returns `-operand` element-wise. @@ -459,8 +462,8 @@ def HLO_NegOp: HLO_UnaryElementwiseOp<"negate", let hasFolder = 1; } -def HLO_PopulationCountOp: HLO_UnaryElementwiseOp<"popcnt", - [Pure, HLO_CompatibleOperandsAndResultType], HLO_IntTensor> { +def MHLO_PopulationCountOp: MHLO_UnaryElementwiseOp<"popcnt", + [Pure, HLO_CompatibleOperandsAndResultType], MHLO_IntTensor> { let summary = "PopulationCount operator"; let description = [{ Returns the number of bits set in each operand element-wise. @@ -475,9 +478,9 @@ def HLO_PopulationCountOp: HLO_UnaryElementwiseOp<"popcnt", ``` }]; } -def HLO_RealOp: HLO_UnaryElementwiseOp<"real", +def MHLO_RealOp: MHLO_UnaryElementwiseOp<"real", [Pure, DeclareOpInterfaceMethods], - HLO_FpOrComplexTensor, HLO_FpTensor> { + MHLO_FpOrComplexTensor, MHLO_FpTensor> { let summary = "Real operator"; let description = [{ Returns `Real(operand)` element-wise. @@ -491,8 +494,8 @@ def HLO_RealOp: HLO_UnaryElementwiseOp<"real", let hasFolder = 1; } -def HLO_RoundOp: HLO_UnaryElementwiseOp<"round_nearest_afz", - [Pure, HLO_CompatibleOperandsAndResultType], HLO_FpTensor> { +def MHLO_RoundOp: MHLO_UnaryElementwiseOp<"round_nearest_afz", + [Pure, HLO_CompatibleOperandsAndResultType], MHLO_FpTensor> { let summary = "Round operator, ties away from zero"; let description = [{ Returns `Round(operand)` element-wise, rounding to nearest integer with @@ -510,8 +513,8 @@ def HLO_RoundOp: HLO_UnaryElementwiseOp<"round_nearest_afz", let hasFolder = 1; } -def HLO_RoundNearestEvenOp: HLO_UnaryElementwiseOp<"round_nearest_even", - [Pure, HLO_CompatibleOperandsAndResultType], HLO_FpTensor> { +def MHLO_RoundNearestEvenOp: MHLO_UnaryElementwiseOp<"round_nearest_even", + [Pure, HLO_CompatibleOperandsAndResultType], MHLO_FpTensor> { let summary = "Round operator, ties to even"; let description = [{ Returns `Round(operand)` element-wise, rounding to nearest integer with @@ -529,8 +532,8 @@ def HLO_RoundNearestEvenOp: HLO_UnaryElementwiseOp<"round_nearest_even", let hasFolder = 1; } -def HLO_RsqrtOp: HLO_UnaryElementwiseOp<"rsqrt", - [Pure, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> { +def MHLO_RsqrtOp: MHLO_UnaryElementwiseOp<"rsqrt", + [Pure, HLO_CompatibleOperandsAndResultType], MHLO_FpOrComplexTensor> { let summary = "Reciprocal Square-root operator"; let description = [{ Returns `1.0 / sqrt(operand)` element-wise. 
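A brief illustrative snippet (editorial, not part of the diff): the renamed unary elementwise ops above keep their assembly syntax, so IR like the following continues to parse unchanged; the shapes are arbitrary.

```mlir
%0 = mhlo.negate %arg0 : tensor<2xf32>
%1 = mhlo.popcnt %arg1 : tensor<4xi64>
%2 = mhlo.round_nearest_even %arg0 : tensor<2xf32>
%3 = mhlo.rsqrt %arg0 : tensor<2xf32>
```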
@@ -547,9 +550,9 @@ def HLO_RsqrtOp: HLO_UnaryElementwiseOp<"rsqrt", let hasFolder = 1; } -def HLO_SignOp: HLO_UnaryElementwiseOp<"sign", +def MHLO_SignOp: MHLO_UnaryElementwiseOp<"sign", [Pure, HLO_CompatibleOperandsAndResultType], - TensorOf<[HLO_SInt, HLO_Float, HLO_Complex]>> { + TensorOf<[MHLO_SInt, MHLO_Float, MHLO_Complex]>> { let summary = "Sign operator"; let description = [{ Returns `sign(operand)` element-wise, where @@ -574,8 +577,8 @@ def HLO_SignOp: HLO_UnaryElementwiseOp<"sign", let hasFolder = 1; } -def HLO_SineOp: HLO_UnaryElementwiseOp<"sine", - [Pure, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> { +def MHLO_SineOp: MHLO_UnaryElementwiseOp<"sine", + [Pure, HLO_CompatibleOperandsAndResultType], MHLO_FpOrComplexTensor> { let summary = "Sin operator"; let description = [{ Returns `Sin(operand)` element-wise. @@ -593,8 +596,27 @@ def HLO_SineOp: HLO_UnaryElementwiseOp<"sine", let hasCustomHLOConverter = 1; } -def HLO_SqrtOp: HLO_UnaryElementwiseOp<"sqrt", - [Pure, HLO_CompatibleOperandsAndResultType], HLO_FpOrComplexTensor> { +def MHLO_TanOp: MHLO_UnaryElementwiseOp<"tan", + [Pure, HLO_CompatibleOperandsAndResultType], MHLO_FpOrComplexTensor> { + let summary = "Tan operator"; + let description = [{ + Returns `Tan(operand)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + + Example: + + ```mlir + %0 = mhlo.tan %arg0 : tensor<2xf32> + ``` + }]; + let hasFolder = 1; + let hasCustomHLOConverter = 1; +} + +def MHLO_SqrtOp: MHLO_UnaryElementwiseOp<"sqrt", + [Pure, HLO_CompatibleOperandsAndResultType], MHLO_FpOrComplexTensor> { let summary = "Square-root operator"; let description = [{ Returns `sqrt(operand)` element-wise. @@ -611,9 +633,9 @@ def HLO_SqrtOp: HLO_UnaryElementwiseOp<"sqrt", let hasFolder = 1; } -def HLO_TanhOp: HLO_UnaryElementwiseOp<"tanh", +def MHLO_TanhOp: MHLO_UnaryElementwiseOp<"tanh", [Pure, HLO_CompatibleOperandsAndResultType], - HLO_FpOrComplexTensor> { + MHLO_FpOrComplexTensor> { let summary = "Tanh operator"; let description = [{ Returns `tanh(operand)` element-wise. 
@@ -634,13 +656,13 @@ def HLO_TanhOp: HLO_UnaryElementwiseOp<"tanh", //===----------------------------------------------------------------------===// // See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations -// TODO(gleasonk): Merge HLO_BinaryElementwiseOp and HLO_BinaryElementwiseOpNoAssembly -class HLO_BinaryElementwiseOpNoAssembly traits> : - HLO_Op traits, + Type OperandType = MHLO_Tensor, Type ResultType = OperandType> : + MHLO_Op { let arguments = (ins - HLO_Tensor:$lhs, - HLO_Tensor:$rhs + OperandType:$lhs, + OperandType:$rhs ); let extraClassDeclaration = [{ @@ -656,18 +678,15 @@ class HLO_BinaryElementwiseOpNoAssembly traits> : } }]; - let results = (outs HLO_Tensor:$result); -} + let results = (outs ResultType:$result); -class HLO_BinaryElementwiseOp traits> : - HLO_BinaryElementwiseOpNoAssembly { let assemblyFormat = [{ $lhs `,` $rhs attr-dict `:` custom(type($lhs), type($rhs), type($result)) }]; } -def HLO_AddOp : HLO_BinaryElementwiseOp<"add", +def MHLO_AddOp : MHLO_BinaryElementwiseOp<"add", [Commutative, Pure, HLO_CompatibleOperandsAndResultType]> { let summary = "Addition operator"; let description = [{ @@ -686,8 +705,8 @@ def HLO_AddOp : HLO_BinaryElementwiseOp<"add", let hasFolder = 1; } -def HLO_Atan2Op : HLO_BinaryElementwiseOp<"atan2", - [Pure, HLO_CompatibleOperandsAndResultType]> { +def MHLO_Atan2Op : MHLO_BinaryElementwiseOp<"atan2", + [Pure, HLO_CompatibleOperandsAndResultType], MHLO_FpOrComplexTensor> { let summary = "Atan2 operator"; let description = [{ Returns `atan2(lhs/rhs)` element-wise. @@ -703,8 +722,9 @@ def HLO_Atan2Op : HLO_BinaryElementwiseOp<"atan2", }]; } -def HLO_ComplexOp: HLO_BinaryElementwiseOpNoAssembly<"complex", [Pure, - SameOperandsElementType, DeclareOpInterfaceMethods]> { +def MHLO_ComplexOp: MHLO_BinaryElementwiseOp<"complex", [Pure, + SameOperandsElementType, SameOperandsAndResultShape, + DeclareOpInterfaceMethods]> { let summary = "Complex operator"; let description = [{ Performs element-wise conversion of a pair of real and imaginary values to @@ -717,8 +737,8 @@ def HLO_ComplexOp: HLO_BinaryElementwiseOpNoAssembly<"complex", [Pure, ``` }]; - let arguments = (ins HLO_Fp32Or64Tensor:$lhs, HLO_Fp32Or64Tensor:$rhs); - let results = (outs HLO_ComplexTensor:$result); + let arguments = (ins MHLO_Fp32Or64Tensor:$lhs, MHLO_Fp32Or64Tensor:$rhs); + let results = (outs MHLO_ComplexTensor:$result); let hasFolder = 1; @@ -728,8 +748,8 @@ def HLO_ComplexOp: HLO_BinaryElementwiseOpNoAssembly<"complex", [Pure, }]; } -def HLO_DivOp : HLO_BinaryElementwiseOp<"divide", - [Pure, HLO_CompatibleOperandsAndResultType]> { +def MHLO_DivOp : MHLO_BinaryElementwiseOp<"divide", + [Pure, HLO_CompatibleOperandsAndResultType], MHLO_IntFpOrComplexTensor> { let summary = "Division operator"; let description = [{ Returns `lhs / rhs` element-wise. 
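To illustrate the merged binary base class (an editorial sketch, not part of the patch): ops defined through `MHLO_BinaryElementwiseOp` share one assembly form, and element-type restrictions such as `MHLO_FpOrComplexTensor` for `atan2` are now enforced by the ODS operand types. Arbitrary shapes:

```mlir
%0 = mhlo.add %arg0, %arg1 : tensor<2xf32>
%1 = mhlo.atan2 %arg0, %arg1 : tensor<2xf32>
%2 = mhlo.divide %arg0, %arg1 : tensor<2xf32>
```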
@@ -746,7 +766,7 @@ def HLO_DivOp : HLO_BinaryElementwiseOp<"divide", let hasFolder = 1; } -def HLO_MaxOp : HLO_BinaryElementwiseOp<"maximum", +def MHLO_MaxOp : MHLO_BinaryElementwiseOp<"maximum", [Commutative, Pure, HLO_CompatibleOperandsAndResultType]> { let summary = "Maximum operator"; let description = [{ @@ -764,7 +784,7 @@ def HLO_MaxOp : HLO_BinaryElementwiseOp<"maximum", let hasFolder = 1; } -def HLO_MinOp : HLO_BinaryElementwiseOp<"minimum", +def MHLO_MinOp : MHLO_BinaryElementwiseOp<"minimum", [Commutative, Pure, HLO_CompatibleOperandsAndResultType]> { let summary = "Minimum operator"; let description = [{ @@ -782,7 +802,7 @@ def HLO_MinOp : HLO_BinaryElementwiseOp<"minimum", let hasFolder = 1; } -def HLO_MulOp : HLO_BinaryElementwiseOp<"multiply", +def MHLO_MulOp : MHLO_BinaryElementwiseOp<"multiply", [Commutative, Pure, HLO_CompatibleOperandsAndResultType]> { let summary = "Multiplication operator"; let description = [{ @@ -800,8 +820,8 @@ def HLO_MulOp : HLO_BinaryElementwiseOp<"multiply", let hasFolder = 1; } -def HLO_PowOp : HLO_BinaryElementwiseOp<"power", - [Pure, HLO_CompatibleOperandsAndResultType]> { +def MHLO_PowOp : MHLO_BinaryElementwiseOp<"power", + [Pure, HLO_CompatibleOperandsAndResultType], MHLO_IntFpOrComplexTensor> { let summary = "Power operator"; let description = [{ Returns `lhs ^ rhs` element-wise. @@ -816,8 +836,8 @@ def HLO_PowOp : HLO_BinaryElementwiseOp<"power", ``` }]; } -def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder", - [Pure, HLO_CompatibleOperandsAndResultType]> { +def MHLO_RemOp : MHLO_BinaryElementwiseOp<"remainder", + [Pure, HLO_CompatibleOperandsAndResultType], MHLO_IntFpOrComplexTensor> { let summary = "Remainder operator"; let description = [{ Returns `lhs % rhs` element-wise. @@ -828,14 +848,14 @@ def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder", Example: ```mlir - %0 = mhlo.remainder %arg0, %arg1 : (ensor<4xi64> + %0 = mhlo.remainder %arg0, %arg1 : tensor<4xi64> ``` }]; let hasFolder = 1; } -def HLO_ShiftLeftOp : HLO_BinaryElementwiseOp<"shift_left", - [Pure, HLO_CompatibleOperandsAndResultType]> { +def MHLO_ShiftLeftOp : MHLO_BinaryElementwiseOp<"shift_left", + [Pure, HLO_CompatibleOperandsAndResultType], MHLO_IntTensor> { let summary = "Shift Left operator"; let description = [{ Returns `lhs << rhs` element-wise. @@ -851,8 +871,8 @@ def HLO_ShiftLeftOp : HLO_BinaryElementwiseOp<"shift_left", }]; } -def HLO_ShiftRightArithmeticOp : HLO_BinaryElementwiseOp<"shift_right_arithmetic", - [Pure, HLO_CompatibleOperandsAndResultType]> { +def MHLO_ShiftRightArithmeticOp : MHLO_BinaryElementwiseOp<"shift_right_arithmetic", + [Pure, HLO_CompatibleOperandsAndResultType], MHLO_IntTensor> { let summary = "Shift right arithmetic operator"; let description = [{ Returns arithmetic `lhs >> rhs` element-wise. @@ -869,8 +889,8 @@ def HLO_ShiftRightArithmeticOp : HLO_BinaryElementwiseOp<"shift_right_arithmetic }]; } -def HLO_ShiftRightLogicalOp : HLO_BinaryElementwiseOp<"shift_right_logical", - [Pure, HLO_CompatibleOperandsAndResultType]> { +def MHLO_ShiftRightLogicalOp : MHLO_BinaryElementwiseOp<"shift_right_logical", + [Pure, HLO_CompatibleOperandsAndResultType], MHLO_IntTensor> { let summary = "Shift right logical operator"; let description = [{ Returns logical `lhs >> rhs` element-wise. 
@@ -886,8 +906,8 @@ def HLO_ShiftRightLogicalOp : HLO_BinaryElementwiseOp<"shift_right_logical", }]; } -def HLO_SubtractOp : HLO_BinaryElementwiseOp<"subtract", - [Pure, HLO_CompatibleOperandsAndResultType]> { +def MHLO_SubtractOp : MHLO_BinaryElementwiseOp<"subtract", + [Pure, HLO_CompatibleOperandsAndResultType], MHLO_IntFpOrComplexTensor> { let summary = "Subtraction operator"; let description = [{ Returns `lhs - rhs` element-wise. @@ -906,7 +926,7 @@ def HLO_SubtractOp : HLO_BinaryElementwiseOp<"subtract", } // TODO(b/232442915): Implement stochastic_convert MHLO once HLO interface is submitted. -def HLO_StochasticConvertOp : HLO_Op<"stochastic_convert", +def MHLO_StochasticConvertOp : MHLO_Op<"stochastic_convert", [Pure, AllShapesMatch<["operand", "random", "result"]>]> { let summary = "Stochastic convert operator"; let description = [{ @@ -914,8 +934,8 @@ def HLO_StochasticConvertOp : HLO_Op<"stochastic_convert", one with stochastic rounding using the random number passed in. }]; - let arguments = (ins HLO_FpTensor:$operand, TensorOf<[HLO_UInt]>:$random); - let results = (outs HLO_Tensor:$result); + let arguments = (ins MHLO_FpTensor:$operand, TensorOf<[MHLO_UInt]>:$random); + let results = (outs MHLO_Tensor:$result); let hasCustomHLOConverter = 1; let hasVerifier = 1; } @@ -925,22 +945,22 @@ def HLO_StochasticConvertOp : HLO_Op<"stochastic_convert", //===----------------------------------------------------------------------===// // See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations -class HLO_BinaryBiwiseOrLogicalElementwiseOp : - HLO_BinaryElementwiseOp : + MHLO_BinaryElementwiseOp { let arguments = (ins - HLO_PredOrIntTensor:$lhs, - HLO_PredOrIntTensor:$rhs + MHLO_PredOrIntTensor:$lhs, + MHLO_PredOrIntTensor:$rhs ); let hasFolder = 1; } -def HLO_AndOp: HLO_BinaryBiwiseOrLogicalElementwiseOp<"and"> { +def MHLO_AndOp: MHLO_BinaryBiwiseOrLogicalElementwiseOp<"and"> { let summary = "And operator"; let description = [{ Returns biwise-AND of `lhs` and `rhs` element-wise. The input tensors must - be of type integer `HLO_Int` or boolean `HLO_Pred`. + be of type integer `MHLO_Int` or boolean `MHLO_Pred`. Note: For boolean tensor, the bitwise-AND is equivalent to logical-AND. @@ -952,11 +972,11 @@ def HLO_AndOp: HLO_BinaryBiwiseOrLogicalElementwiseOp<"and"> { }]; } -def HLO_OrOp: HLO_BinaryBiwiseOrLogicalElementwiseOp<"or"> { +def MHLO_OrOp: MHLO_BinaryBiwiseOrLogicalElementwiseOp<"or"> { let summary = "Or operator"; let description = [{ Returns biwise-OR of `lhs` and `rhs` element-wise. The input tensors must - be of type integer `HLO_Int` or boolean `HLO_Pred`. + be of type integer `MHLO_Int` or boolean `MHLO_Pred`. Note: For boolean tensor, the bitwise-OR is equivalent to logical-OR. @@ -968,11 +988,11 @@ def HLO_OrOp: HLO_BinaryBiwiseOrLogicalElementwiseOp<"or"> { }]; } -def HLO_XorOp : HLO_BinaryBiwiseOrLogicalElementwiseOp<"xor"> { +def MHLO_XorOp : MHLO_BinaryBiwiseOrLogicalElementwiseOp<"xor"> { let summary = "Xor operator"; let description = [{ Returns biwise-XOR of `lhs` and `rhs` element-wise. The input tensors must - be of type integer `HLO_Int` or boolean `HLO_Pred`. + be of type integer `MHLO_Int` or boolean `MHLO_Pred`. Note: For boolean tensor, the bitwise-XOR is equivalent to logical-XOR. @@ -990,7 +1010,7 @@ def HLO_XorOp : HLO_BinaryBiwiseOrLogicalElementwiseOp<"xor"> { // InfeedOp corresponds to 'InfeedWithToken' xla client API and not 'Infeed'. // InfeedWithToken allows ordering of infeed HLO instructions using tokens. 
-def HLO_InfeedOp : HLO_Op<"infeed", []> { +def MHLO_InfeedOp : MHLO_Op<"infeed", []> { let summary = "Infeed operator"; @@ -1009,18 +1029,19 @@ def HLO_InfeedOp : HLO_Op<"infeed", []> { }]; let arguments = (ins - HLO_Token:$token, + MHLO_Token:$token, DefaultValuedStrAttr:$infeed_config, OptionalAttr:$layout ); - let results = (outs Variadic); + let results = (outs Variadic); let hasCustomHLOConverter = 1; let hasVerifier = 1; } // OutfeedOp corresponds to 'OutfeedWithToken' xla client API and not 'Outfeed'. // OutfeedWithToken allows ordering of outfeed HLO instructions using tokens. -def HLO_OutfeedOp : HLO_Op<"outfeed", []> { +def MHLO_OutfeedOp : MHLO_Op<"outfeed", + [DeclareOpInterfaceMethods]> { let summary = "Outfeed operator"; @@ -1033,15 +1054,16 @@ def HLO_OutfeedOp : HLO_Op<"outfeed", []> { }]; let arguments = (ins - Variadic:$inputs, - HLO_Token:$token, + Variadic:$inputs, + MHLO_Token:$token, DefaultValuedStrAttr:$outfeed_config ); - let results = (outs HLO_Token); + let results = (outs MHLO_Token); let hasCustomHLOConverter = 1; } -def HLO_SendOp : HLO_Op<"send", []> { +def MHLO_SendOp : MHLO_Op<"send", + [DeclareOpInterfaceMethods]> { let summary = "Send operator"; @@ -1056,17 +1078,17 @@ def HLO_SendOp : HLO_Op<"send", []> { }]; let arguments = (ins - Variadic:$inputs, - HLO_Token:$token, - ChannelHandle:$channel_handle, + Variadic:$inputs, + MHLO_Token:$token, + MHLO_ChannelHandle:$channel_handle, DefaultValuedOptionalAttr:$is_host_transfer ); - let results = (outs HLO_Token); + let results = (outs MHLO_Token); let hasCustomHLOConverter = 1; } -def HLO_RecvOp : HLO_Op<"recv", []> { +def MHLO_RecvOp : MHLO_Op<"recv", []> { let summary = "Recv operator"; @@ -1082,12 +1104,12 @@ def HLO_RecvOp : HLO_Op<"recv", []> { }]; let arguments = (ins - HLO_Token:$token, - ChannelHandle:$channel_handle, + MHLO_Token:$token, + MHLO_ChannelHandle:$channel_handle, DefaultValuedOptionalAttr:$is_host_transfer ); - let results = (outs Variadic); + let results = (outs Variadic); let hasCustomHLOConverter = 1; let hasVerifier = 1; } @@ -1096,7 +1118,7 @@ def HLO_RecvOp : HLO_Op<"recv", []> { // MHLO parallelism related op definitions. //===----------------------------------------------------------------------===// -def HLO_ReplicaIdOp : HLO_Op<"replica_id", [Pure, +def MHLO_ReplicaIdOp : MHLO_Op<"replica_id", [Pure, DeclareOpInterfaceMethods]> { let summary = "ReplicaId operator"; let description = [{ @@ -1124,7 +1146,7 @@ def HLO_ReplicaIdOp : HLO_Op<"replica_id", [Pure, // MHLO control flow op definitions. 
//===----------------------------------------------------------------------===// -def HLO_AddDependencyOp : HLO_Op<"add_dependency", [Pure, +def MHLO_AddDependencyOp : MHLO_Op<"add_dependency", [Pure, DeclareOpInterfaceMethods]> { let summary = "AddDependency operator"; let description = [{ @@ -1140,15 +1162,16 @@ def HLO_AddDependencyOp : HLO_Op<"add_dependency", [Pure, ``` }]; - let arguments = (ins HLO_TensorOrToken:$operand, HLO_Token:$token); - let results = (outs HLO_TensorOrToken:$output); + let arguments = (ins MHLO_TensorOrToken:$operand, MHLO_Token:$token); + let results = (outs MHLO_TensorOrToken:$output); let hasCustomHLOConverter = 1; let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; } -def HLO_AfterAllOp : HLO_Op<"after_all", [Pure]> { +def MHLO_AfterAllOp : MHLO_Op<"after_all", [Pure, + DeclareOpInterfaceMethods]> { let summary = "AfterAll operator"; @@ -1167,8 +1190,8 @@ def HLO_AfterAllOp : HLO_Op<"after_all", [Pure]> { ``` }]; - let arguments = (ins Variadic:$inputs); - let results = (outs HLO_Token:$result); + let arguments = (ins Variadic:$inputs); + let results = (outs MHLO_Token:$result); let assemblyFormat = [{ $inputs attr-dict @@ -1176,7 +1199,7 @@ def HLO_AfterAllOp : HLO_Op<"after_all", [Pure]> { }]; } -def HLO_AsyncStartOp : HLO_Op<"async_start", []> { +def MHLO_AsyncStartOp : MHLO_Op<"async_start", []> { let summary = "AsyncStart operator"; let description = [{ @@ -1208,18 +1231,18 @@ def HLO_AsyncStartOp : HLO_Op<"async_start", []> { }]; let arguments = (ins - Variadic:$inputs, + Variadic:$inputs, FlatSymbolRefAttr:$called_computation, StrAttr:$execution_thread, OptionalAttr:$group_id ); - let results = (outs HLO_AsyncBundle); + let results = (outs MHLO_AsyncBundle); let hasCustomHLOConverter = 1; let hasVerifier = 1; } -def HLO_AsyncUpdateOp : HLO_Op<"async_update", [DeclareOpInterfaceMethods]> { +def MHLO_AsyncUpdateOp : MHLO_Op<"async_update", [DeclareOpInterfaceMethods]> { let summary = "AsyncUpdate operator"; let description = [{ @@ -1230,19 +1253,19 @@ def HLO_AsyncUpdateOp : HLO_Op<"async_update", [DeclareOpInterfaceMethods:$group_id ); - let results = (outs HLO_AsyncBundle); + let results = (outs MHLO_AsyncBundle); let hasVerifier = 1; let hasCustomHLOConverter = 1; } -def HLO_AsyncDoneOp : HLO_Op<"async_done", [DeclareOpInterfaceMethods]> { +def MHLO_AsyncDoneOp : MHLO_Op<"async_done", [DeclareOpInterfaceMethods]> { let summary = "AsyncDone operator"; let description = [{ @@ -1253,13 +1276,13 @@ def HLO_AsyncDoneOp : HLO_Op<"async_done", [DeclareOpInterfaceMethods:$group_id ); - let results = (outs Variadic); + let results = (outs Variadic); let hasVerifier = 1; let hasCustomHLOConverter = 1; } @@ -1267,7 +1290,7 @@ def HLO_AsyncDoneOp : HLO_Op<"async_done", [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { @@ -1286,13 +1309,13 @@ def HLO_IfOp: HLO_Op<"if", [ }]; let arguments = (ins - HLO_PredTensor:$pred + MHLO_PredTensor:$pred ); let regions = (region SizedRegion<1>:$true_branch, SizedRegion<1>:$false_branch); - let results = (outs Variadic); + let results = (outs Variadic); // TODO(b/129422361): ConditionalOp has special conversion logic to HLO. let hasCustomHLOConverter = 1; @@ -1303,7 +1326,7 @@ def HLO_IfOp: HLO_Op<"if", [ // Xla Client API has two separate calls for indexed and predicated conditional, // although both eventually map to kConditional HLO. CaseOp maps to indexed // conditional use of kConditional HLO. 
-def HLO_CaseOp: HLO_Op<"case", [ +def MHLO_CaseOp: MHLO_Op<"case", [ RecursiveMemoryEffects, SingleBlockImplicitTerminator<"ReturnOp">, DeclareOpInterfaceMethods @@ -1326,14 +1349,14 @@ def HLO_CaseOp: HLO_Op<"case", [ let regions = (region VariadicRegion>:$branches); - let results = (outs Variadic); + let results = (outs Variadic); let hasCustomHLOConverter = 1; let hasCanonicalizer = 1; } -def HLO_WhileOp: HLO_Op<"while", [ +def MHLO_WhileOp: MHLO_Op<"while", [ RecursiveMemoryEffects, SingleBlockImplicitTerminator<"ReturnOp">, DeclareOpInterfaceMethods, @@ -1346,11 +1369,11 @@ def HLO_WhileOp: HLO_Op<"while", [ See https://www.tensorflow.org/xla/operation_semantics#while. }]; - let arguments = (ins Variadic:$operand); + let arguments = (ins Variadic:$operand); let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body); - let results = (outs Variadic); + let results = (outs Variadic); let extraClassDeclaration = [{ // Method of OpAsmOpInterface used during custom printing to name the block @@ -1372,9 +1395,10 @@ def HLO_WhileOp: HLO_Op<"while", [ let hasCanonicalizer = 1; let hasCustomAssemblyFormat = 1; let hasFolder = 1; + let hasVerifier = 1; } -def HLO_AllGatherOp : HLO_Op<"all_gather", [SameOperandsAndResultElementType]> { +def MHLO_AllGatherOp : MHLO_Op<"all_gather", [SameOperandsAndResultElementType]> { string summary = "AllGather operator"; @@ -1385,13 +1409,13 @@ def HLO_AllGatherOp : HLO_Op<"all_gather", [SameOperandsAndResultElementType]> { }]; let arguments = (ins - HLO_Tensor:$operand, + MHLO_Tensor:$operand, I64Attr:$all_gather_dim, I64ElementsAttr:$replica_groups, - OptionalAttr:$channel_handle, + OptionalAttr:$channel_handle, UnitAttr:$use_global_device_ids ); - let results = (outs HLO_Tensor); + let results = (outs MHLO_Tensor); // use_global_device_ids is rarely used, so we add simplified builder methods // for convenience. let builders = [ @@ -1405,7 +1429,7 @@ def HLO_AllGatherOp : HLO_Op<"all_gather", [SameOperandsAndResultElementType]> { let hasVerifier = 1; } -def HLO_AllReduceOp : HLO_Op<"all_reduce", +def MHLO_AllReduceOp : MHLO_Op<"all_reduce", [HLO_CompatibleOperandsAndResultType]> { let summary = "AllReduce operator"; let description = [{ @@ -1415,25 +1439,20 @@ def HLO_AllReduceOp : HLO_Op<"all_reduce", }]; let arguments = (ins - HLO_Tensor:$operand, + MHLO_Tensor:$operand, I64ElementsAttr:$replica_groups, - OptionalAttr:$channel_handle, + OptionalAttr:$channel_handle, UnitAttr:$use_global_device_ids ); let regions = (region SizedRegion<1>:$computation); - let results = (outs HLO_Tensor); - // use_global_device_ids is rarely used, so we add a simplified builder method - // for convenience. 
- let builders = [ - OpBuilder<(ins - "::mlir::Type":$result_type, "::mlir::Value":$operand, - "::mlir::DenseIntElementsAttr":$replica_groups, - "::mlir::mhlo::ChannelHandleAttr":$channel_handle)>]; + let results = (outs MHLO_Tensor); + let hasVerifier = 1; let hasCustomHLOConverter = 1; + let hasVerifier = 1; } -def HLO_ReduceScatterOp : HLO_Op<"reduce_scatter", +def MHLO_ReduceScatterOp : MHLO_Op<"reduce_scatter", [SameOperandsAndResultElementType]> { let summary = "ReduceScatter operator"; let description = [{ @@ -1443,14 +1462,14 @@ def HLO_ReduceScatterOp : HLO_Op<"reduce_scatter", }]; let arguments = (ins - HLO_Tensor:$operand, + MHLO_Tensor:$operand, I64Attr:$scatter_dimension, I64ElementsAttr:$replica_groups, - OptionalAttr:$channel_handle, + OptionalAttr:$channel_handle, UnitAttr:$use_global_device_ids ); let regions = (region SizedRegion<1>:$computation); - let results = (outs HLO_Tensor); + let results = (outs MHLO_Tensor); // use_global_device_ids is rarely used, so we add simplified builder methods // for convenience. let builders = [ @@ -1464,21 +1483,42 @@ def HLO_ReduceScatterOp : HLO_Op<"reduce_scatter", let hasVerifier = 1; } -def HLO_AllToAllOp : HLO_Op<"all_to_all", - [Pure, SameOperandsElementType, SameOperandsShape, +def MHLO_AllToAllOp : MHLO_Op<"all_to_all", + [Pure, SameOperandsElementType, SameOperandsShape, SameVariadicOperandSize, InferTensorType]> { let arguments = (ins - HLO_Tensor:$operand, - I64Attr:$split_dimension, - I64Attr:$concat_dimension, - I64Attr:$split_count, - I64ElementsAttr:$replica_groups + // ArrayAllToAll must have exactly one operand, TupleAllToAll at least one. + Variadic:$operand, + // split_dimension, concat_dimension and split_count are present for array + // all-to-all, absent for tuple all-to-all. + OptionalAttr:$split_dimension, + OptionalAttr:$concat_dimension, + OptionalAttr:$split_count, + I64ElementsAttr:$replica_groups, + OptionalAttr:$channel_handle ); - let results = (outs HLO_Tensor); + let results = (outs Variadic); + let hasCustomHLOConverter = 1; + + // channel_handle is only used for the SPMD partitioner, so we add a + // simplified builder method for convenience. + let builders = [ + OpBuilder<(ins + "::mlir::Type":$result_type, "::mlir::Value":$operand, + "::mlir::IntegerAttr": $split_dimension, + "::mlir::IntegerAttr": $concat_dimension, + "::mlir::IntegerAttr": $split_count, + "::mlir::DenseIntElementsAttr": $replica_groups)>, + OpBuilder<(ins + "::mlir::TypeRange":$result_type, "::mlir::ValueRange":$operand, + "::mlir::IntegerAttr": $split_dimension, + "::mlir::IntegerAttr": $concat_dimension, + "::mlir::IntegerAttr": $split_count, + "::mlir::DenseIntElementsAttr": $replica_groups)>]; } -def HLO_ReduceOp: HLO_ShapedInterfaceOp<"reduce", [ +def MHLO_ReduceOp: MHLO_ShapedInterfaceOp<"reduce", [ RecursiveMemoryEffects, SameVariadicOperandSize, SingleBlockImplicitTerminator<"ReturnOp">, @@ -1492,12 +1532,12 @@ def HLO_ReduceOp: HLO_ShapedInterfaceOp<"reduce", [ See https://www.tensorflow.org/xla/operation_semantics#reduce. }]; let arguments = (ins - Variadic:$inputs, - Variadic:$init_values, + Variadic:$inputs, + Variadic:$init_values, I64ElementsAttr:$dimensions ); - let results = (outs Variadic); + let results = (outs Variadic); let hasCanonicalizer = 1; let hasCustomAssemblyFormat = 1; @@ -1515,7 +1555,7 @@ def HLO_ReduceOp: HLO_ShapedInterfaceOp<"reduce", [ //===----------------------------------------------------------------------===// // MHLO tuple op definitions. 
//===----------------------------------------------------------------------===// -def HLO_GetTupleElementOp: HLO_Op<"get_tuple_element", [Pure, +def MHLO_GetTupleElementOp: MHLO_Op<"get_tuple_element", [Pure, DeclareOpInterfaceMethods]> { let summary = "GetTupleElement operator"; let description = [{ @@ -1524,21 +1564,20 @@ def HLO_GetTupleElementOp: HLO_Op<"get_tuple_element", [Pure, See https://www.tensorflow.org/xla/operation_semantics#gettupleelement. }]; let arguments = (ins - HLO_Tuple:$operand, + MHLO_Tuple:$operand, I32Attr:$index ); - let results = (outs HLO_TensorOrTokenOrTuple); + let results = (outs MHLO_TensorOrTokenOrTuple); let hasFolder = 1; - let hasVerifier = 1; let assemblyFormat = [{ $operand `[` $index `]` attr-dict `:` functional-type(operands, results) }]; } -def HLO_TupleOp : HLO_Op<"tuple", [Pure, +def MHLO_TupleOp : MHLO_Op<"tuple", [Pure, DeclareOpInterfaceMethods]> { let summary = "XLA's tuple op"; let description = [{ @@ -1552,18 +1591,17 @@ def HLO_TupleOp : HLO_Op<"tuple", [Pure, %0 = mhlo.tuple %arg0, %arg0 : tuple, tensor> ``` }]; - let arguments = (ins Variadic:$val); - let results = (outs HLO_Tuple:$result); + let arguments = (ins Variadic:$val); + let results = (outs MHLO_Tuple:$result); let hasCanonicalizer = 1; - let hasVerifier = 1; let assemblyFormat = [{ $val attr-dict `:` custom(type($val), type($result)) }]; } -def HLO_CompareOp: HLO_Op<"compare", [Pure, SameOperandsElementType, +def MHLO_CompareOp: MHLO_Op<"compare", [Pure, SameOperandsElementType, SameOperandsAndResultShape, Elementwise, InferTensorTypeWithReify]> { let summary = "Comparison operator"; let description = [{ @@ -1583,12 +1621,12 @@ def HLO_CompareOp: HLO_Op<"compare", [Pure, SameOperandsElementType, ``` }]; let arguments = (ins - HLO_Tensor:$lhs, - HLO_Tensor:$rhs, - HLO_ComparisonDirectionAttr:$comparison_direction, - OptionalAttr:$compare_type + MHLO_Tensor:$lhs, + MHLO_Tensor:$rhs, + MHLO_ComparisonDirectionAttr:$comparison_direction, + OptionalAttr:$compare_type ); - let results = (outs HLO_PredTensor); + let results = (outs MHLO_PredTensor); let hasFolder = 1; @@ -1611,25 +1649,25 @@ def HLO_CompareOp: HLO_Op<"compare", [Pure, SameOperandsElementType, // MHLO Slice definitions. //===----------------------------------------------------------------------===// -def HLO_SliceOp: HLO_Op< +def MHLO_SliceOp: MHLO_Op< "slice", [Pure, SameOperandsAndResultElementType, AllTypesMatch<["start_indices", "limit_indices", "strides"]>, DeclareOpInterfaceMethods]> { let arguments = (ins - HLO_Tensor:$operand, + MHLO_Tensor:$operand, I64ElementsAttr:$start_indices, I64ElementsAttr:$limit_indices, I64ElementsAttr:$strides ); - let results = (outs HLO_Tensor); + let results = (outs MHLO_Tensor); let hasCanonicalizer = 1; let hasFolder = 1; } -def HLO_DynamicSliceOp: HLO_Op<"dynamic_slice", +def MHLO_DynamicSliceOp: MHLO_Op<"dynamic_slice", [Pure, AllElementTypesMatch<["operand", "result"]>, InferTensorType]> { let summary = "Dynamic Slice operator"; @@ -1639,19 +1677,18 @@ def HLO_DynamicSliceOp: HLO_Op<"dynamic_slice", See https://www.tensorflow.org/xla/operation_semantics#dynamicslice. 
}]; let arguments = (ins - HLO_Tensor:$operand, - Variadic:$start_indices, + MHLO_Tensor:$operand, + Variadic:$start_indices, I64ElementsAttr:$slice_sizes ); - let results = (outs HLO_Tensor:$result); + let results = (outs MHLO_Tensor:$result); let hasCanonicalizer = 1; - let hasVerifier = 1; } -def HLO_DynamicUpdateSliceOp: HLO_Op<"dynamic_update_slice", +def MHLO_DynamicUpdateSliceOp: MHLO_Op<"dynamic_update_slice", [Pure, AllElementTypesMatch<["operand", "update", "result"]>, - AllShapesMatch<["operand", "result"]>]> { + InferTensorType]> { let summary = "Dynamic Update Slice operator"; let description = [{ DynamicUpdateSlice generates a result which is the value of the input array @@ -1667,13 +1704,12 @@ def HLO_DynamicUpdateSliceOp: HLO_Op<"dynamic_update_slice", ``` }]; let arguments = (ins - HLO_Tensor:$operand, - HLO_Tensor:$update, - Variadic:$start_indices + MHLO_Tensor:$operand, + MHLO_Tensor:$update, + Variadic:$start_indices ); - let results = (outs HLO_Tensor:$result); + let results = (outs MHLO_Tensor:$result); let hasFolder = 1; - let hasVerifier = 1; let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; } @@ -1683,7 +1719,7 @@ def HLO_DynamicUpdateSliceOp: HLO_Op<"dynamic_update_slice", // MHLO Other op definitions. //===----------------------------------------------------------------------===// -def HLO_DomainOp : HLO_Op<"domain", [HLO_CompatibleOperandsAndResultType, InferTypeOpInterface, Pure]> { +def MHLO_DomainOp : MHLO_Op<"domain", [HLO_CompatibleOperandsAndResultType, InferTypeOpInterface, Pure]> { let summary = "Marks groups of instructions (domains) with a property"; let description = [{ Domain instructions are used to group instructions with the same @@ -1699,21 +1735,17 @@ def HLO_DomainOp : HLO_Op<"domain", [HLO_CompatibleOperandsAndResultType, InferT one on the operand side and one on the user side of the domain. 
}]; let arguments = (ins - HLO_TensorOrToken:$operand, - HLO_DomainKindAttr:$kind, + MHLO_TensorOrToken:$operand, + MHLO_DomainKindAttr:$kind, StrAttr:$entry_metadata, StrAttr:$exit_metadata ); - let results = (outs HLO_TensorOrToken:$result); + let results = (outs MHLO_TensorOrToken:$result); let hasCustomHLOConverter = 1; } -def HLO_BatchNormGradOp : HLO_Op<"batch_norm_grad", [Pure, - AllShapesMatch<["scale", "mean", "variance", "grad_scale", - "grad_offset"]>, - AllShapesMatch<["operand", "grad_output"]>, - AllElementTypesMatch<["operand", "grad_scale", "grad_offset"]>, - AllTypesMatch<["operand", "grad_operand"]>, +def MHLO_BatchNormGradOp : MHLO_Op<"batch_norm_grad", [Pure, + AllElementTypesMatch<["operand", "grad_operand", "grad_scale", "grad_offset"]>, InferTensorType]> { let summary = "Batch Normalization Gradient"; let description = [{ @@ -1723,28 +1755,25 @@ def HLO_BatchNormGradOp : HLO_Op<"batch_norm_grad", [Pure, }]; let arguments = (ins - RankedTensorOf<[HLO_Float]>:$operand, - 1DTensorOf<[HLO_Float]>:$scale, - 1DTensorOf<[HLO_Float]>:$mean, - 1DTensorOf<[HLO_Float]>:$variance, - RankedTensorOf<[HLO_Float]>:$grad_output, + RankedTensorOf<[MHLO_Float]>:$operand, + 1DTensorOf<[MHLO_Float]>:$scale, + 1DTensorOf<[MHLO_Float]>:$mean, + 1DTensorOf<[MHLO_Float]>:$variance, + RankedTensorOf<[MHLO_Float]>:$grad_output, F32Attr:$epsilon, I64Attr:$feature_index ); let results = (outs - RankedTensorOf<[HLO_Float]>:$grad_operand, - 1DTensorOf<[HLO_Float]>:$grad_scale, - 1DTensorOf<[HLO_Float]>:$grad_offset); + RankedTensorOf<[MHLO_Float]>:$grad_operand, + 1DTensorOf<[MHLO_Float]>:$grad_scale, + 1DTensorOf<[MHLO_Float]>:$grad_offset); let hasCustomHLOConverter = 1; - let hasVerifier = 1; } -def HLO_BatchNormInferenceOp : HLO_Op<"batch_norm_inference", - [Pure, AllTypesMatch<["operand", "result"]>, - AllShapesMatch<["scale", "offset", "mean", "variance"]>, - InferTensorType]> { +def MHLO_BatchNormInferenceOp : MHLO_Op<"batch_norm_inference", + [Pure, AllElementTypesMatch<["operand", "result"]>, InferTensorType]> { let summary = "Batch Normalization for Inference"; let description = [{ Normalizes an array across batch and spatial dimensions. 
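For reference, an editorial example (not part of the patch) of the generic form of `mhlo.batch_norm_inference`; with the relaxed traits, only element types must match, and shape compatibility is checked via type inference. The shapes and attribute values here are made up:

```mlir
%0 = "mhlo.batch_norm_inference"(%operand, %scale, %offset, %mean, %variance)
    {epsilon = 1.000000e-03 : f32, feature_index = 1 : i64}
    : (tensor<8x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>)
    -> tensor<8x4xf32>
```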
@@ -1753,24 +1782,20 @@ def HLO_BatchNormInferenceOp : HLO_Op<"batch_norm_inference", }]; let arguments = (ins - RankedTensorOf<[HLO_Float]>:$operand, - 1DTensorOf<[HLO_Float]>:$scale, - 1DTensorOf<[HLO_Float]>:$offset, - 1DTensorOf<[HLO_Float]>:$mean, - 1DTensorOf<[HLO_Float]>:$variance, + RankedTensorOf<[MHLO_Float]>:$operand, + 1DTensorOf<[MHLO_Float]>:$scale, + 1DTensorOf<[MHLO_Float]>:$offset, + 1DTensorOf<[MHLO_Float]>:$mean, + 1DTensorOf<[MHLO_Float]>:$variance, F32Attr:$epsilon, I64Attr:$feature_index ); - let results = (outs RankedTensorOf<[HLO_Float]>:$result); - - let hasVerifier = 1; + let results = (outs RankedTensorOf<[MHLO_Float]>:$result); } -def HLO_BatchNormTrainingOp : HLO_Op<"batch_norm_training", - [Pure, AllTypesMatch<["operand", "output"]>, - AllElementTypesMatch<["operand", "batch_mean", "batch_var"]>, - AllShapesMatch<["scale", "offset", "batch_mean", "batch_var"]>, +def MHLO_BatchNormTrainingOp : MHLO_Op<"batch_norm_training", + [Pure, AllElementTypesMatch<["operand", "output", "batch_mean", "batch_var"]>, InferTensorType]> { let summary = "Batch Normalization for Training"; let description = [{ @@ -1780,23 +1805,22 @@ def HLO_BatchNormTrainingOp : HLO_Op<"batch_norm_training", }]; let arguments = (ins - RankedTensorOf<[HLO_Float]>:$operand, - 1DTensorOf<[HLO_Float]>:$scale, - 1DTensorOf<[HLO_Float]>:$offset, + RankedTensorOf<[MHLO_Float]>:$operand, + 1DTensorOf<[MHLO_Float]>:$scale, + 1DTensorOf<[MHLO_Float]>:$offset, F32Attr:$epsilon, I64Attr:$feature_index ); let results = (outs - RankedTensorOf<[HLO_Float]>:$output, - 1DTensorOf<[HLO_Float]>:$batch_mean, - 1DTensorOf<[HLO_Float]>:$batch_var); + RankedTensorOf<[MHLO_Float]>:$output, + 1DTensorOf<[MHLO_Float]>:$batch_mean, + 1DTensorOf<[MHLO_Float]>:$batch_var); - let hasVerifier = 1; let hasCustomHLOConverter = 1; } -def HLO_BitcastConvertOp : HLO_ShapedInterfaceOp<"bitcast_convert", +def MHLO_BitcastConvertOp : MHLO_ShapedInterfaceOp<"bitcast_convert", [Pure]> { let summary = "BitcastConvert operator"; let description = [{ @@ -1815,15 +1839,15 @@ def HLO_BitcastConvertOp : HLO_ShapedInterfaceOp<"bitcast_convert", ``` }]; - let arguments = (ins HLO_Tensor:$operand); - let results = (outs HLO_Tensor); + let arguments = (ins MHLO_Tensor:$operand); + let results = (outs MHLO_Tensor); let hasVerifier = 1; let hasCustomHLOConverter = 1; let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; } -def HLO_BroadcastOp : HLO_ShapedInterfaceOp<"broadcast", +def MHLO_BroadcastOp : MHLO_ShapedInterfaceOp<"broadcast", [Pure, SameOperandsAndResultElementType, InferTensorType]> { let summary = "Broadcast a tensor to a higher rank by prepending dimensions"; let description = [{ @@ -1838,17 +1862,16 @@ def HLO_BroadcastOp : HLO_ShapedInterfaceOp<"broadcast", See https://www.tensorflow.org/xla/operation_semantics#broadcast. }]; let arguments = (ins - HLO_Tensor:$operand, + MHLO_Tensor:$operand, I64ElementsAttr:$broadcast_sizes ); - let results = (outs HLO_Tensor); + let results = (outs MHLO_Tensor); let hasFolder = 1; - let hasVerifier = 1; } -def HLO_BroadcastInDimOp : HLO_Op<"broadcast_in_dim", +def MHLO_BroadcastInDimOp : MHLO_Op<"broadcast_in_dim", [Pure, SameOperandsAndResultElementType]> { let summary = "Broadcast a tensor into the given shape by adding dimensions."; let description = [{ @@ -1868,11 +1891,11 @@ def HLO_BroadcastInDimOp : HLO_Op<"broadcast_in_dim", See https://www.tensorflow.org/xla/broadcasting. 
}]; let arguments = (ins - HLO_Tensor:$operand, - BroadcastDimAttr:$broadcast_dimensions + MHLO_Tensor:$operand, + MHLO_BroadcastDimAttr:$broadcast_dimensions ); - let results = (outs HLO_StaticShapeTensor); + let results = (outs MHLO_StaticShapeTensor); let hasFolder = 1; let hasCanonicalizer = 1; @@ -1881,7 +1904,7 @@ def HLO_BroadcastInDimOp : HLO_Op<"broadcast_in_dim", let hasCustomHLOConverter = 1; } -def HLO_DynamicBroadcastInDimOp : HLO_ShapedInterfaceOp< +def MHLO_DynamicBroadcastInDimOp : MHLO_ShapedInterfaceOp< "dynamic_broadcast_in_dim", [Pure]> { let summary = "Broadcast a tensor into the given dynamic shape by adding dimensions."; let description = [{ @@ -1897,14 +1920,14 @@ def HLO_DynamicBroadcastInDimOp : HLO_ShapedInterfaceOp< must be disjoint and they must be a subset of the operand's dimensions. }]; let arguments = (ins - HLO_Tensor:$operand, - HLO_DimensionTensor:$output_dimensions, - BroadcastDimAttr:$broadcast_dimensions, - OptionalAttr:$known_expanding_dimensions, - OptionalAttr:$known_nonexpanding_dimensions + MHLO_Tensor:$operand, + MHLO_DimensionTensor:$output_dimensions, + MHLO_BroadcastDimAttr:$broadcast_dimensions, + OptionalAttr:$known_expanding_dimensions, + OptionalAttr:$known_nonexpanding_dimensions ); - let results = (outs HLO_Tensor); + let results = (outs MHLO_Tensor); let builders = [ OpBuilder<(ins @@ -1922,11 +1945,11 @@ def HLO_DynamicBroadcastInDimOp : HLO_ShapedInterfaceOp< let hasCustomHLOConverter = 1; } -// Note: There is no HLO_CallOp because the standard call operation mlir::func::CallOp +// Note: There is no MHLO_CallOp because the standard call operation mlir::func::CallOp // is used instead. A mlir::func::CallOp is exported to a HLO call instruction // directly. -def HLO_CholeskyOp : HLO_Op<"cholesky", +def MHLO_CholeskyOp : MHLO_Op<"cholesky", [Pure, SameOperandsAndResultElementType, InferTensorType]> { let summary = "Cholesky operator"; let description = [{ @@ -1951,14 +1974,14 @@ def HLO_CholeskyOp : HLO_Op<"cholesky", See https://www.tensorflow.org/xla/operation_semantics#cholesky. 
}]; let arguments = (ins - HLO_FpOrComplexTensor:$a, + MHLO_FpOrComplexTensor:$a, DefaultValuedOptionalAttr:$lower ); - let results = (outs HLO_FpOrComplexTensor); + let results = (outs MHLO_FpOrComplexTensor); } -def HLO_ClampOp : HLO_ShapedInterfaceOp<"clamp", [Pure, +def MHLO_ClampOp : MHLO_ShapedInterfaceOp<"clamp", [Pure, SameOperandsAndResultElementType, HLO_BroadcastingElementwise, InferTensorType]> { let summary = "Clamp operator"; @@ -1980,11 +2003,11 @@ def HLO_ClampOp : HLO_ShapedInterfaceOp<"clamp", [Pure, }]; let arguments = (ins - HLO_Tensor:$min, - HLO_Tensor:$operand, - HLO_Tensor:$max + MHLO_Tensor:$min, + MHLO_Tensor:$operand, + MHLO_Tensor:$max ); - let results = (outs HLO_Tensor:$result); + let results = (outs MHLO_Tensor:$result); let hasVerifier = 1; let hasFolder = 1; @@ -1995,7 +2018,7 @@ def HLO_ClampOp : HLO_ShapedInterfaceOp<"clamp", [Pure, }]; } -def HLO_ConcatenateOp : HLO_ShapedInterfaceOp<"concatenate", +def MHLO_ConcatenateOp : MHLO_ShapedInterfaceOp<"concatenate", [Pure, SameOperandsAndResultElementType, DeclareOpInterfaceMethods]> { let summary = "XLA's concatenate op"; @@ -2006,18 +2029,17 @@ def HLO_ConcatenateOp : HLO_ShapedInterfaceOp<"concatenate", }]; let arguments = (ins - Variadic:$val, + Variadic:$val, I64Attr:$dimension ); - let results = (outs HLO_Tensor); + let results = (outs MHLO_Tensor); let hasCanonicalizer = 1; let hasFolder = 1; - let hasVerifier = 1; } -def HLO_CollectivePermuteOp: HLO_Op<"collective_permute", +def MHLO_CollectivePermuteOp: MHLO_Op<"collective_permute", [Pure, HLO_CompatibleOperandsAndResultType]> { let summary = "CollectivePermute operator"; let description = [{ @@ -2034,11 +2056,11 @@ def HLO_CollectivePermuteOp: HLO_Op<"collective_permute", }]; let arguments = (ins - HLO_Tensor:$operand, + MHLO_Tensor:$operand, I64ElementsAttr:$source_target_pairs, - OptionalAttr:$channel_handle + OptionalAttr:$channel_handle ); - let results = (outs HLO_Tensor); + let results = (outs MHLO_Tensor); let hasVerifier = 1; // channel_handle is only used for the SPMD partitioner, so we add a // simplified builder method for convenience. @@ -2048,7 +2070,7 @@ def HLO_CollectivePermuteOp: HLO_Op<"collective_permute", "::mlir::DenseIntElementsAttr":$source_target_pairs)>]; } -def HLO_ConvolutionOp : HLO_ShapedInterfaceOp<"convolution", [Pure]> { +def MHLO_ConvolutionOp : MHLO_Op<"convolution", [Pure]> { let summary = "Convolution operator"; let description = [{ Computes a convolution of the kind used in neural networks. @@ -2057,11 +2079,11 @@ def HLO_ConvolutionOp : HLO_ShapedInterfaceOp<"convolution", [Pure]> { }]; let arguments = !con( (ins - HLO_Tensor:$lhs, - HLO_Tensor:$rhs), - ConvolutionAttributes.attributes); + MHLO_Tensor:$lhs, + MHLO_Tensor:$rhs), + MHLO_ConvolutionAttributes.attributes); - let results = (outs HLO_Tensor); + let results = (outs MHLO_Tensor); let hasCanonicalizer = 1; let hasCustomHLOConverter = 1; let hasVerifier = 1; @@ -2084,7 +2106,7 @@ def HLO_ConvolutionOp : HLO_ShapedInterfaceOp<"convolution", [Pure]> { }]; } -def HLO_CopyOp: HLO_Op<"copy", [Pure, HLO_CompatibleOperandsAndResultType]> { +def MHLO_CopyOp: MHLO_Op<"copy", [Pure, HLO_CompatibleOperandsAndResultType]> { let summary = "Copy operator"; let description = [{ Returns a copy of `operand`. 
@@ -2096,10 +2118,10 @@ def HLO_CopyOp: HLO_Op<"copy", [Pure, HLO_CompatibleOperandsAndResultType]> { ``` }]; let arguments = (ins - HLO_Tensor:$operand, - UnitAttr:$is_cross_program_prefetch + MHLO_TensorOrTokenOrTuple:$operand, + OptionalAttr:$cross_program_prefetch_index ); - let results = (outs HLO_Tensor:$result); + let results = (outs MHLO_TensorOrTokenOrTuple:$result); let hasCustomHLOConverter = 1; let hasFolder = 1; @@ -2109,7 +2131,7 @@ def HLO_CopyOp: HLO_Op<"copy", [Pure, HLO_CompatibleOperandsAndResultType]> { }]; } -def HLO_CrossReplicaSumOp : HLO_Op<"cross-replica-sum", +def MHLO_CrossReplicaSumOp : MHLO_Op<"cross-replica-sum", [Pure, HLO_CompatibleOperandsAndResultType]> { let summary = "Sums input across replicated instances."; let description = [{ @@ -2125,14 +2147,14 @@ def HLO_CrossReplicaSumOp : HLO_Op<"cross-replica-sum", }]; let arguments = (ins - HLO_Tensor:$operand, + MHLO_Tensor:$operand, I64ElementsAttr:$replica_groups ); - let results = (outs HLO_Tensor); + let results = (outs MHLO_Tensor); } -def HLO_CustomCallOp: HLO_Op<"custom_call", +def MHLO_CustomCallOp: MHLO_Op<"custom_call", [DeclareOpInterfaceMethods]> { let summary = "CustomCall operator"; let description = [{ @@ -2142,8 +2164,29 @@ def HLO_CustomCallOp: HLO_Op<"custom_call", backend, a call instruction is emitted which targets a symbol with the name `call_target_name`. + If XLA runtime is enabled for a backend, then custom calls use the runtime + custom call calling convention to call into the external functions. This + calling convention defines an ABI for encoding arguments, attributes and + results. + `call_target_name` and `backend_config` can be arbitrary strings, but `call_target_name` should be short as it may be used in labels. + + Depending on the API version there are two ways to pass extra bits of static + information to the external function: + + 1. For `API_VERSION_TYPED_FFI` custom calls `backend_config` must be a + dictionary attribute, that will be encoded according to the custom call + calling convention and passed to the external function as the attributes + argument. External code is expected to use declarative bindings (see + `xla/runtime/custom_call.h`) to decode them at run time. + + 2. For previous API versions it is the user responsibility to encode extra + bits of static information as a string `backend_config` attribute, and + decode it at run time. + + `API_VERSION_TYPED_FFI` custom calls only supported if XLA uses XLA runtime. + `backend_config` can encode arbitrarily large amounts of information. `has_side_effect` must be true if the custom call has side-effects. @@ -2168,28 +2211,36 @@ def HLO_CustomCallOp: HLO_Op<"custom_call", result tuple. See https://www.tensorflow.org/xla/operation_semantics#customcall. + + Example: + + ```mlir + %1 = mhlo.custom_call @foo(%arg0, %arg1) {backend_config = "bar", has_side_effect = true} + : (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<1x2x3xf32> + ``` }]; let arguments = (ins - Variadic:$inputs, + Variadic:$inputs, StrAttr:$call_target_name, DefaultValuedOptionalAttr:$has_side_effect, - DefaultValuedOptionalStrAttr:$backend_config, + OptionalAttr>:$backend_config, // TODO(b/189822916): Remove this field when all clients are migrated to // the status-returning API. 
DefaultValuedOptionalAttr< - HLO_CustomCallApiVersionAttr, + MHLO_CustomCallApiVersionAttr, "::mlir::mhlo::CustomCallApiVersion::API_VERSION_ORIGINAL">: $api_version, - DefaultValuedOptionalAttr:$called_computations, - OptionalAttr:$operand_layouts, - OptionalAttr:$result_layouts, + DefaultValuedOptionalAttr:$called_computations, + DefaultValuedOptionalAttr:$custom_call_schedule, + OptionalAttr:$operand_layouts, + OptionalAttr:$result_layouts, DefaultValuedOptionalAttr< TypedArrayAttrBase< - OutputOperandAlias, + MHLO_OutputOperandAlias, "Aliasing attribute for outputs and operands of CustomCall">, "{}">:$output_operand_aliases ); - let results = (outs Variadic); + let results = (outs Variadic); let hasCustomHLOConverter = 1; let hasVerifier = 1; @@ -2205,11 +2256,14 @@ def HLO_CustomCallOp: HLO_Op<"custom_call", "::mlir::ArrayAttr":$called_computations, "::mlir::ArrayAttr":$operand_layouts, "::mlir::ArrayAttr":$result_layouts)>]; + + let assemblyFormat = [{ + custom($call_target_name) `(` $inputs `)` + attr-dict `:` functional-type(operands, results) + }]; } -def HLO_DotOp: HLO_Op<"dot", - [Pure, DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods]> { let summary = "Dot operator"; let description = [{ @@ -2219,24 +2273,18 @@ def HLO_DotOp: HLO_Op<"dot", See https://www.tensorflow.org/xla/operation_semantics#dot. }]; let arguments = ( - ins HLO_Tensor:$lhs, - HLO_Tensor:$rhs, - HLO_PrecisionConfigAttr:$precision_config + ins MHLO_Tensor:$lhs, + MHLO_Tensor:$rhs, + MHLO_PrecisionConfigAttr:$precision_config ); - let results = (outs HLO_Tensor); + let results = (outs MHLO_Tensor); // Dot op required custom exporter to pass the preferred element type // to Xla builder. let hasCustomHLOConverter = 1; let hasVerifier = 1; - - let extraClassDeclaration = [{ - static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) { - return succeeded(mlir::verifyCompatibleShapes(l, r)); - } - }]; } -def HLO_DotGeneralOp: HLO_ShapedInterfaceOp<"dot_general", [Pure]> { +def MHLO_DotGeneralOp: MHLO_ShapedInterfaceOp<"dot_general", [Pure]> { let summary = "General Dot operator"; let description = [{ Performs general dot products between vectors, vector/matrix and @@ -2245,13 +2293,13 @@ def HLO_DotGeneralOp: HLO_ShapedInterfaceOp<"dot_general", [Pure]> { See https://www.tensorflow.org/xla/operation_semantics#dotgeneral. }]; let arguments = (ins - HLO_Tensor:$lhs, - HLO_Tensor:$rhs, - DotDimensionNumbers:$dot_dimension_numbers, - HLO_PrecisionConfigAttr:$precision_config + MHLO_Tensor:$lhs, + MHLO_Tensor:$rhs, + MHLO_DotDimensionNumbers:$dot_dimension_numbers, + MHLO_PrecisionConfigAttr:$precision_config ); - let results = (outs HLO_Tensor); + let results = (outs MHLO_Tensor); let hasCanonicalizer = 1; // DotGeneral op required custom exporter to pass the preferred element type // to Xla builder. @@ -2270,26 +2318,26 @@ class BASE_EinsumOp { }]; } -def HLO_EinsumOp: HLO_Op<"einsum", [Pure]>, BASE_EinsumOp { +def MHLO_EinsumOp: MHLO_Op<"einsum", [Pure]>, BASE_EinsumOp { let arguments = (ins - HLO_Tensor:$lhs, - HLO_Tensor:$rhs, + MHLO_Tensor:$lhs, + MHLO_Tensor:$rhs, StrAttr:$einsum_config ); - let results = (outs HLO_Tensor); + let results = (outs MHLO_Tensor); // TODO(hinsu): Canonicalize to lower this client side HLO op to server // side HLO ops. 
} -def HLO_UnaryEinsumOp: HLO_Op<"unary_einsum", [Pure]>, BASE_EinsumOp { +def MHLO_UnaryEinsumOp: MHLO_Op<"unary_einsum", [Pure]>, BASE_EinsumOp { let arguments = (ins - HLO_Tensor:$operand, + MHLO_Tensor:$operand, StrAttr:$einsum_config ); - let results = (outs HLO_Tensor); + let results = (outs MHLO_Tensor); let hasCanonicalizer = 1; @@ -2298,7 +2346,7 @@ def HLO_UnaryEinsumOp: HLO_Op<"unary_einsum", [Pure]>, BASE_EinsumOp { let hasCustomHLOConverter = 1; } -def HLO_FftOp: HLO_Op<"fft", [InferTensorType, Pure]> { +def MHLO_FftOp: MHLO_Op<"fft", [InferTensorType, Pure]> { let summary = "Fast fourier transform operator"; let description = [{ Returns the fast-fourier-transform of the input array. @@ -2307,15 +2355,15 @@ def HLO_FftOp: HLO_Op<"fft", [InferTensorType, Pure]> { https://www.tensorflow.org/xla/operation_semantics#fft. }]; let arguments = (ins - HLO_Tensor:$operand, - HLO_FftTypeAttr:$fft_type, + MHLO_Tensor:$operand, + MHLO_FftTypeAttr:$fft_type, I64ElementsAttr:$fft_length ); - let results = (outs HLO_Tensor); + let results = (outs MHLO_Tensor); } -def HLO_GatherOp: HLO_Op<"gather", [InferTensorTypeWithReify, Pure]> { +def MHLO_GatherOp: MHLO_Op<"gather", [InferTensorTypeWithReify, Pure]> { let summary = "Gather operator"; let description = [{ Stitches together several slices of `operand` from offsets specified in @@ -2325,19 +2373,20 @@ def HLO_GatherOp: HLO_Op<"gather", [InferTensorTypeWithReify, Pure]> { }]; let arguments = (ins - HLO_Tensor:$operand, - HLO_IntTensor:$start_indices, - GatherDimensionNumbers:$dimension_numbers, + MHLO_Tensor:$operand, + MHLO_IntTensor:$start_indices, + MHLO_GatherDimensionNumbers:$dimension_numbers, I64ElementsAttr:$slice_sizes, DefaultValuedOptionalAttr:$indices_are_sorted ); - let results = (outs HLO_Tensor); + let results = (outs MHLO_Tensor); let hasCanonicalizer = 1; } -def HLO_GetDimensionSizeOp: HLO_Op<"get_dimension_size", [Pure]> { +def MHLO_GetDimensionSizeOp: MHLO_Op<"get_dimension_size", [Pure, + DeclareOpInterfaceMethods]> { let summary = "GetDimensionSize operator"; let description = [{ Returns the size of the given dimension of the operand. @@ -2346,7 +2395,7 @@ def HLO_GetDimensionSizeOp: HLO_Op<"get_dimension_size", [Pure]> { https://www.tensorflow.org/xla/operation_semantics#getdimensionsize. }]; let arguments = (ins - HLO_Tensor:$operand, + MHLO_Tensor:$operand, I64Attr:$dimension ); // TODO(hinsu): Allow 64-bit result types once XLA HLO dialect based on the @@ -2358,7 +2407,7 @@ def HLO_GetDimensionSizeOp: HLO_Op<"get_dimension_size", [Pure]> { let hasVerifier = 1; } -def HLO_MapOp: HLO_ShapedInterfaceOp<"map", +def MHLO_MapOp: MHLO_ShapedInterfaceOp<"map", [RecursiveMemoryEffects, SameOperandsAndResultShape, SingleBlockImplicitTerminator<"ReturnOp">, InferTensorTypeWithReify]> { let summary = "Map operator"; @@ -2375,16 +2424,16 @@ def HLO_MapOp: HLO_ShapedInterfaceOp<"map", See https://www.tensorflow.org/xla/operation_semantics#map. 
}]; let arguments = (ins - Variadic:$inputs, + Variadic:$inputs, I64ElementsAttr:$dimensions ); let regions = (region SizedRegion<1>:$computation); - let results = (outs HLO_Tensor); + let results = (outs MHLO_Tensor); let hasFolder = 1; let hasCustomHLOConverter = 1; } -def HLO_ReshapeOp: HLO_Op<"reshape", +def MHLO_ReshapeOp: MHLO_Op<"reshape", [Pure, SameOperandsAndResultElementType]> { let summary = "Reshape operator"; let description = [{ @@ -2399,9 +2448,9 @@ def HLO_ReshapeOp: HLO_Op<"reshape", ``` }]; - let arguments = (ins HLO_Tensor:$operand); + let arguments = (ins MHLO_Tensor:$operand); - let results = (outs HLO_StaticShapeTensor); + let results = (outs MHLO_StaticShapeTensor); let hasFolder = 1; let hasCanonicalizer = 1; let hasVerifier = 1; @@ -2411,7 +2460,7 @@ def HLO_ReshapeOp: HLO_Op<"reshape", let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; } -def HLO_DynamicReshapeOp: HLO_ShapedInterfaceOp<"dynamic_reshape", [Pure]> { +def MHLO_DynamicReshapeOp: MHLO_ShapedInterfaceOp<"dynamic_reshape", [Pure]> { let summary = "Reshape a tensor to a given, possibly dynamic, shape."; let description = [{ Reshapes `operand` to `output_shape`. @@ -2429,8 +2478,8 @@ def HLO_DynamicReshapeOp: HLO_ShapedInterfaceOp<"dynamic_reshape", [Pure]> { ``` }]; - let arguments = (ins HLO_Tensor:$operand, HLO_DimensionTensor:$output_shape); - let results = (outs HLO_Tensor:$result); + let arguments = (ins MHLO_Tensor:$operand, MHLO_DimensionTensor:$output_shape); + let results = (outs MHLO_Tensor:$result); let hasCanonicalizer = 1; // Cannot be exported to legacy formats. @@ -2440,7 +2489,9 @@ def HLO_DynamicReshapeOp: HLO_ShapedInterfaceOp<"dynamic_reshape", [Pure]> { let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; } -def HLO_ScatterOp: HLO_Op<"scatter", [SameVariadicOperandSize, RecursiveMemoryEffects]> { +def MHLO_ScatterOp: MHLO_Op<"scatter", + [SameVariadicOperandSize, RecursiveMemoryEffects, + DeclareOpInterfaceMethods]> { let summary = "Scatter operator"; let description = [{ Generates a result which is the value of the input array `operand`, @@ -2450,17 +2501,17 @@ def HLO_ScatterOp: HLO_Op<"scatter", [SameVariadicOperandSize, RecursiveMemoryEf See https://www.tensorflow.org/xla/operation_semantics#scatter. 
}]; let arguments = (ins - Variadic:$inputs, + Variadic:$inputs, TensorOf<[AnyInteger, Index]>:$scatter_indices, - Variadic:$updates, - ScatterDimensionNumbers:$scatter_dimension_numbers, + Variadic:$updates, + MHLO_ScatterDimensionNumbers:$scatter_dimension_numbers, DefaultValuedOptionalAttr:$indices_are_sorted, DefaultValuedOptionalAttr:$unique_indices ); let regions = (region SizedRegion<1>:$update_computation); - let results = (outs Variadic); + let results = (outs Variadic); let hasCustomHLOConverter = 1; @@ -2469,7 +2520,7 @@ def HLO_ScatterOp: HLO_Op<"scatter", [SameVariadicOperandSize, RecursiveMemoryEf let hasCanonicalizer = 1; } -def HLO_SelectOp: HLO_Op<"select", [Pure, HLO_BroadcastingElementwise, +def MHLO_SelectOp: MHLO_Op<"select", [Pure, HLO_BroadcastingElementwise, InferTensorTypeWithReify]> { let summary = "Select operator"; let description = [{ @@ -2488,15 +2539,14 @@ def HLO_SelectOp: HLO_Op<"select", [Pure, HLO_BroadcastingElementwise, ``` }]; let arguments = (ins - HLO_PredTensor:$pred, - HLO_Tensor:$on_true, - HLO_Tensor:$on_false + MHLO_PredTensor:$pred, + MHLO_Tensor:$on_true, + MHLO_Tensor:$on_false ); - let results = (outs HLO_Tensor:$result); + let results = (outs MHLO_Tensor:$result); let hasFolder = 1; - let hasVerifier = 1; let hasCanonicalizer = 1; let assemblyFormat = [{ @@ -2505,8 +2555,8 @@ def HLO_SelectOp: HLO_Op<"select", [Pure, HLO_BroadcastingElementwise, }]; } -def HLO_SelectAndScatterOp: HLO_Op<"select_and_scatter", - [RecursiveMemoryEffects]> { +def MHLO_SelectAndScatterOp: MHLO_Op<"select_and_scatter", + [RecursiveMemoryEffects, DeclareOpInterfaceMethods]> { let summary = "SelectAndScatter operator"; let description = [{ Runs a windowed selection `select` function over `operand` with shape @@ -2519,9 +2569,9 @@ def HLO_SelectAndScatterOp: HLO_Op<"select_and_scatter", See https://www.tensorflow.org/xla/operation_semantics#selectandscatter. }]; let arguments = (ins - HLO_Tensor:$operand, - HLO_Tensor:$source, - HLO_Tensor:$init_value, + MHLO_Tensor:$operand, + MHLO_Tensor:$source, + MHLO_Tensor:$init_value, OptionalAttr:$window_dimensions, OptionalAttr:$window_strides, OptionalAttr:$padding @@ -2529,13 +2579,13 @@ def HLO_SelectAndScatterOp: HLO_Op<"select_and_scatter", let regions = (region SizedRegion<1>:$select, SizedRegion<1>:$scatter); - let results = (outs HLO_Tensor); + let results = (outs MHLO_Tensor); let hasVerifier = 1; let hasCustomHLOConverter = 1; } -def HLO_SetDimensionSizeOp: HLO_Op<"set_dimension_size", [Pure, +def MHLO_SetDimensionSizeOp: MHLO_Op<"set_dimension_size", [Pure, DeclareOpInterfaceMethods]> { let summary = "SetDimensionSize operator"; let description = [{ @@ -2546,17 +2596,17 @@ def HLO_SetDimensionSizeOp: HLO_Op<"set_dimension_size", [Pure, See https://www.tensorflow.org/xla/operation_semantics#setdimensionsize. }]; let arguments = (ins - HLO_Tensor:$operand, + MHLO_Tensor:$operand, I32Tensor:$size, I64Attr:$dimension ); - let results = (outs HLO_Tensor); + let results = (outs MHLO_Tensor); let hasFolder = 1; let hasVerifier = 1; } -def HLO_SortOp : HLO_Op<"sort", +def MHLO_SortOp : MHLO_Op<"sort", [RecursiveMemoryEffects, SameOperandsAndResultShape, InferTensorType]> { let summary = "Sort operator"; let description = [{ @@ -2566,12 +2616,12 @@ def HLO_SortOp : HLO_Op<"sort", See https://www.tensorflow.org/xla/operation_semantics#sort. 
}]; let arguments = (ins - Variadic:$inputs, + Variadic:$inputs, DefaultValuedOptionalAttr:$dimension, DefaultValuedOptionalAttr:$is_stable ); - let results = (outs Variadic); + let results = (outs Variadic); let regions = (region SizedRegion<1>:$comparator); @@ -2587,8 +2637,8 @@ def HLO_SortOp : HLO_Op<"sort", let hasVerifier = 1; } -def HLO_ReverseOp: HLO_ShapedInterfaceOp<"reverse", - [Pure, SameOperandsAndResultType, HLO_CompatibleOperandsAndResultType]> { +def MHLO_ReverseOp: MHLO_Op<"reverse", + [Pure, HLO_CompatibleOperandsAndResultType]> { let summary = "Reverse operator"; let description = [{ Reverses the specified dimensions of `operand` according to the given @@ -2597,16 +2647,34 @@ def HLO_ReverseOp: HLO_ShapedInterfaceOp<"reverse", See https://www.tensorflow.org/xla/operation_semantics#rev_reverse. }]; let arguments = (ins - HLO_Tensor:$operand, + MHLO_Tensor:$operand, I64ElementsAttr:$dimensions ); - let results = (outs HLO_Tensor); + // DISC-Begin + let extraClassDeclaration = [{ + LogicalResult reifyReturnTypeShapes( + OpBuilder& builder, ValueRange operands, + SmallVectorImpl& reifiedReturnShapes) { + return ::mlir::hlo::deriveShapeFromOperand(&builder, getOperation(), + operands.front(), + &reifiedReturnShapes); + } + static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) { + return mlir::hlo::isCompatibleForHloTypeInference(l, r); + } + }]; + // DISC-end + + let hasVerifier = 1; + + let results = (outs MHLO_Tensor); let hasFolder = 1; } -def HLO_PartitionIdOp : HLO_Op<"partition_id", []> { +def MHLO_PartitionIdOp : MHLO_Op<"partition_id", [ + DeclareOpInterfaceMethods]> { let summary = "PartitionId operator"; let description = [{ Returns the value of the partition id of the currently executing device. @@ -2628,8 +2696,9 @@ def HLO_PartitionIdOp : HLO_Op<"partition_id", []> { let assemblyFormat = "attr-dict `:` type(results)"; } -def HLO_PadOp: HLO_ShapedInterfaceOp<"pad", - [Pure, SameOperandsAndResultElementType, InferTensorType]> { +def MHLO_PadOp: MHLO_ShapedInterfaceOp<"pad", + [Pure, SameOperandsAndResultElementType, + DeclareOpInterfaceMethods]> { let summary = "Pad operator"; let description = [{ Pads edges and between the elements of `operand` with the `padding_value` @@ -2655,14 +2724,14 @@ def HLO_PadOp: HLO_ShapedInterfaceOp<"pad", }]; let arguments = (ins - HLO_Tensor:$operand, - HLO_Tensor:$padding_value, + MHLO_Tensor:$operand, + MHLO_Tensor:$padding_value, I64ElementsAttr:$edge_padding_low, I64ElementsAttr:$edge_padding_high, I64ElementsAttr:$interior_padding ); - let results = (outs HLO_Tensor); + let results = (outs MHLO_Tensor); // TODO(b/129422361): PadOp has a custom constructor for HLO. let hasCustomHLOConverter = 1; @@ -2671,7 +2740,7 @@ def HLO_PadOp: HLO_ShapedInterfaceOp<"pad", let hasFolder = 1; } -def HLO_TraceOp: HLO_Op<"trace", []> { +def MHLO_TraceOp: MHLO_Op<"trace", []> { let summary = "Trace operator"; let description = [{ Emits a logging message `tag` with the `operand`. 
@@ -2683,14 +2752,14 @@ def HLO_TraceOp: HLO_Op<"trace", []> { ``` }]; let arguments = (ins - HLO_Tensor:$operand, + MHLO_Tensor:$operand, StrAttr:$tag ); let hasCustomHLOConverter = 1; let assemblyFormat = "$operand `,` $tag attr-dict `:` type($operand)"; } -def HLO_TransposeOp: HLO_ShapedInterfaceOp<"transpose", +def MHLO_TransposeOp: MHLO_ShapedInterfaceOp<"transpose", [Pure, SameOperandsAndResultElementType, DeclareOpInterfaceMethods]> { let summary = "Transpose operator"; @@ -2702,16 +2771,16 @@ def HLO_TransposeOp: HLO_ShapedInterfaceOp<"transpose", See https://www.tensorflow.org/xla/operation_semantics#transpose. }]; let arguments = (ins - HLO_Tensor:$operand, + MHLO_Tensor:$operand, I64ElementsAttr:$permutation ); - let results = (outs HLO_Tensor); + let results = (outs MHLO_Tensor); let hasFolder = 1; let hasCanonicalizer = 1; } -def HLO_TriangularSolveOp: HLO_Op<"triangular_solve", +def MHLO_TriangularSolveOp: MHLO_Op<"triangular_solve", [Pure, SameOperandsAndResultElementType, InferTensorType]> { let summary = "TriangularSolve operator"; let description = [{ @@ -2734,17 +2803,17 @@ def HLO_TriangularSolveOp: HLO_Op<"triangular_solve", See https://www.tensorflow.org/xla/operation_semantics#triangularsolve. }]; let arguments = (ins - HLO_FpOrComplexTensor:$a, - HLO_FpOrComplexTensor:$b, + MHLO_FpOrComplexTensor:$a, + MHLO_FpOrComplexTensor:$b, BoolAttr:$left_side, BoolAttr:$lower, BoolAttr:$unit_diagonal, - HLO_TransposeAttr:$transpose_a + MHLO_TransposeAttr:$transpose_a ); - let results = (outs HLO_FpOrComplexTensor); + let results = (outs MHLO_FpOrComplexTensor); } -def HLO_ReduceWindowOp: HLO_Op<"reduce_window", [ +def MHLO_ReduceWindowOp: MHLO_Op<"reduce_window", [ RecursiveMemoryEffects, SameVariadicOperandSize, SingleBlockImplicitTerminator<"ReturnOp">, @@ -2762,8 +2831,8 @@ def HLO_ReduceWindowOp: HLO_Op<"reduce_window", [ // attributes are 1-d. Attributes' leading dimension should match rank of the // operands. let arguments = (ins - Variadic:$inputs, - Variadic:$init_values, + Variadic:$inputs, + Variadic:$init_values, I64ElementsAttr:$window_dimensions, // If strides or dilations attributes are missing then the default value is // one for each of the operand dimensions. Similarly, padding values are zero @@ -2774,7 +2843,7 @@ def HLO_ReduceWindowOp: HLO_Op<"reduce_window", [ OptionalAttr:$padding ); - let results = (outs Variadic); + let results = (outs Variadic); // TODO(hinsu): Verify that the attached body arguments and results are // compatible with reduce op's operands. @@ -2793,10 +2862,21 @@ def HLO_ReduceWindowOp: HLO_Op<"reduce_window", [ build($_builder, $_state, TypeRange(result_type), ValueRange(operand), ValueRange(init_value), window_dimensions, window_strides, base_dilations, window_dilations, padding); - }]> + }]>, + OpBuilder<(ins "ValueRange":$operands, + "ValueRange":$init_values, + "DenseIntElementsAttr":$window_dimensions, + "DenseIntElementsAttr":$window_strides, + "DenseIntElementsAttr":$base_dilations, + "DenseIntElementsAttr":$window_dilations, + "DenseIntElementsAttr":$padding, + "function_ref":$bodyBuilder + )>, ]; let hasCustomHLOConverter = 1; + let hasFolder = 1; + let hasVerifier = 1; // TODO(hinsu): Implement custom printer and parser. 
let extraClassDeclaration = [{ @@ -2809,7 +2889,9 @@ def HLO_ReduceWindowOp: HLO_Op<"reduce_window", [ }]; } -def HLO_ReturnOp : HLO_Op<"return", [Pure, Terminator]> { +def MHLO_ReturnOp : MHLO_Op<"return", + [Pure, Terminator, + DeclareOpInterfaceMethods]> { let summary = [{ The `hlo.return` operation terminates a region and returns values. @@ -2824,7 +2906,7 @@ def HLO_ReturnOp : HLO_Op<"return", [Pure, Terminator]> { }]; let arguments = (ins - Variadic:$results + Variadic:$results ); // Disable conversion operator for return op as the op is not an actual XLA @@ -2834,22 +2916,45 @@ def HLO_ReturnOp : HLO_Op<"return", [Pure, Terminator]> { let assemblyFormat = "$results attr-dict (`:` type($results)^)?"; } -def HLO_TorchIndexSelectOp : HLO_Op<"torch_index_select", [Pure]> { +def MHLO_TorchIndexSelectOp : MHLO_Op<"torch_index_select", [Pure]> { + let summary = "Torch Index Select operator"; + let description = [{ + Returns a new tensor which indexes the input tensor along dimension `dim` + using the entries in `index`. + + The returned tensor has the same dimensions as `operand`, except for the + `dim`th dimension which is replaced by the shape of `index` without the + leading `batch_dims` dimensions; + + The `batch_dims` attribute specifies the number of major batch dimensions + (0 or more) that act like a multidimensional loop over both the input and + the index. + + Example: + + ```mlir + %0 = "mhlo.torch_index_select"(%arg0, %arg1) { + batch_dims = 1 : i64, dim = 2 : i64 + } : (tensor<8x128x3072x64xf32>, tensor<8x16x1024xi32>) -> tensor<8x128x16x1024x64xf32> + ``` + }]; + let arguments = (ins - HLO_Tensor:$operand, - HLO_Tensor:$index, + MHLO_Tensor:$operand, + MHLO_Tensor:$index, I64Attr:$dim, I64Attr:$batch_dims ); - let results = (outs HLO_Tensor); + let results = (outs MHLO_Tensor); // TODO(hinsu): Canonicalize to lower this client side HLO op to server // side HLO ops. } -def HLO_OptimizationBarrierOp : HLO_Op<"optimization_barrier", - [Pure, HLO_PairwiseSameOperandAndResultType]> { +def MHLO_OptimizationBarrierOp : MHLO_Op<"optimization_barrier", + [Pure, HLO_PairwiseSameOperandAndResultType, + DeclareOpInterfaceMethods]> { let summary = [{ The `mhlo.optimization_barrier` op blocks optimizations. @@ -2876,9 +2981,9 @@ def HLO_OptimizationBarrierOp : HLO_Op<"optimization_barrier", ``` }]; - let arguments = (ins Variadic:$operand); + let arguments = (ins Variadic:$operand); - let results = (outs Variadic:$result); + let results = (outs Variadic:$result); let hasCustomHLOConverter = 1; @@ -2898,7 +3003,7 @@ def HLO_OptimizationBarrierOp : HLO_Op<"optimization_barrier", // MHLO RNG Operators. //===----------------------------------------------------------------------===// -def HLO_RngOp : HLO_Op<"rng", [InferTensorTypeWithReify, AllElementTypesMatch<["a", "b", "result"]>]> { +def MHLO_RngOp : MHLO_Op<"rng", [InferTensorTypeWithReify, AllElementTypesMatch<["a", "b", "result"]>]> { let summary = "RNG with uniform distribution."; let description = [{ Constructs an output of a given shape with random numbers generated @@ -2917,19 +3022,19 @@ def HLO_RngOp : HLO_Op<"rng", [InferTensorTypeWithReify, AllElementTypesMatch<[" See https://www.tensorflow.org/xla/operation_semantics#rngnormal. 
}]; let arguments = (ins - 0DTensorOf<[HLO_Pred, HLO_Int, HLO_Float]>:$a, - 0DTensorOf<[HLO_Pred, HLO_Int, HLO_Float]>:$b, - HLO_DimensionTensor:$shape, - HLO_RngDistributionAttr:$rng_distribution + 0DTensorOf<[MHLO_Pred, MHLO_Int, MHLO_Float]>:$a, + 0DTensorOf<[MHLO_Pred, MHLO_Int, MHLO_Float]>:$b, + MHLO_DimensionTensor:$shape, + MHLO_RngDistributionAttr:$rng_distribution ); - let results = (outs HLO_PredIntOrFpTensor:$result); + let results = (outs MHLO_PredIntOrFpTensor:$result); let hasCustomHLOConverter = 1; let hasVerifier = 1; } -def HLO_RngBitGeneratorOp : HLO_Op<"rng_bit_generator", [Pure]> { +def MHLO_RngBitGeneratorOp : MHLO_Op<"rng_bit_generator", [Pure]> { let summary = "Uniform random number generator operator"; let description = [{ Returns an output with a given shape filled with uniform random bits using @@ -2939,13 +3044,13 @@ def HLO_RngBitGeneratorOp : HLO_Op<"rng_bit_generator", [Pure]> { See https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator. }]; let arguments = (ins - HLO_RngAlgorithmAttr:$rng_algorithm, - HLO_IntOrFpTensor:$initial_state + MHLO_RngAlgorithmAttr:$rng_algorithm, + MHLO_IntOrFpTensor:$initial_state ); let results = (outs - HLO_IntOrFpTensor:$output_state, - HLO_IntOrFpTensor:$output + MHLO_IntOrFpTensor:$output_state, + MHLO_StaticShapeIntOrFpTensor:$output ); let hasVerifier = 1; @@ -2953,7 +3058,7 @@ def HLO_RngBitGeneratorOp : HLO_Op<"rng_bit_generator", [Pure]> { let hasCustomHLOConverter = 1; } -def HLO_XlaRngGetAndUpdateStateOp: HLO_Op<"xla.rng_get_and_update_state", [DeclareOpInterfaceMethods]> { +def MHLO_XlaRngGetAndUpdateStateOp: MHLO_Op<"xla.rng_get_and_update_state", [DeclareOpInterfaceMethods]> { let summary = "RNG state change"; let description = [{ This instruction represents the change of the global random number generator @@ -2979,9 +3084,9 @@ def HLO_XlaRngGetAndUpdateStateOp: HLO_Op<"xla.rng_get_and_update_state", [Decla //===----------------------------------------------------------------------===// // TODO(b/230662142): Implement unknown scales/zero_point cases. -def HLO_UniformQuantizeOp : HLO_UnaryElementwiseOp<"uniform_quantize", - [Pure], TensorOf<[F32, BF16, HLO_QuantizedInt]>, - HLO_QuantizedIntTensor> { +def MHLO_UniformQuantizeOp : MHLO_UnaryElementwiseOp<"uniform_quantize", + [Pure], TensorOf<[F32, BF16, MHLO_QuantizedInt]>, + MHLO_QuantizedIntTensor> { let summary = "Uniform quantize operator"; let description = [{ Converts floating point tensors or uniform quantized integer tensors to @@ -3000,8 +3105,8 @@ def HLO_UniformQuantizeOp : HLO_UnaryElementwiseOp<"uniform_quantize", let hasCustomHLOConverter = 1; } -def HLO_UniformDequantizeOp : HLO_UnaryElementwiseOp<"uniform_dequantize", - [InferTensorType, Pure], HLO_QuantizedIntTensor, TensorOf<[F32, BF16]>> { +def MHLO_UniformDequantizeOp : MHLO_UnaryElementwiseOp<"uniform_dequantize", + [InferTensorType, Pure], MHLO_QuantizedIntTensor, TensorOf<[F32, BF16]>> { let summary = "Uniform dequantize operator"; let description = [{ Converts quantized array of integers to floating-points according to the @@ -3019,7 +3124,7 @@ def HLO_UniformDequantizeOp : HLO_UnaryElementwiseOp<"uniform_dequantize", let hasCustomHLOConverter = 1; } -def HLO_FusionOp : HLO_Op<"fusion", []> { +def MHLO_FusionOp : MHLO_Op<"fusion", []> { let summary = "Fusion operator"; let description = [{ Models the fusion instruction. 
@@ -3031,20 +3136,27 @@ def HLO_FusionOp : HLO_Op<"fusion", []> { let regions = (region SizedRegion<1>:$fused_computation); let arguments = (ins - Variadic:$inputs, - OptionalAttr:$fusion_kind + Variadic:$inputs, + OptionalAttr:$fusion_kind, + DefaultValuedOptionalAttr< + TypedArrayAttrBase< + MHLO_OutputOperandAlias, + "Aliasing attribute for outputs and operands of Fusion">, + "{}">:$output_operand_aliases ); let results = (outs - Variadic>:$results + Variadic>:$results ); // FusionOp has special conversion logic to HLO. let hasCustomHLOConverter = 1; + + let hasVerifier = 1; } // This is an op for purposes internal to XLA/GPU. -def HLO_BitcastOp : HLO_Op<"bitcast", [Pure]> { +def MHLO_BitcastOp : MHLO_Op<"bitcast", [Pure]> { let summary = "Bitcast operator"; let description = [{ This op changes the shape of the input in the way that the physical @@ -3061,15 +3173,16 @@ def HLO_BitcastOp : HLO_Op<"bitcast", [Pure]> { ``` }]; - let arguments = (ins HLO_Tensor:$operand); - let results = (outs HLO_Tensor); + let arguments = (ins MHLO_Tensor:$operand); + let results = (outs MHLO_Tensor); let hasCustomHLOConverter = 1; + let hasFolder = 1; let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; } -def HLO_ReducePrecisionOp : - HLO_Op<"reduce_precision", [HLO_CompatibleOperandsAndResultType]> { +def MHLO_ReducePrecisionOp : + MHLO_Op<"reduce_precision", [HLO_CompatibleOperandsAndResultType, Pure]> { let summary = "Reduce precision operator"; let description = [{ Models the effect of converting floating - point values to a lower - @@ -3080,17 +3193,26 @@ def HLO_ReducePrecisionOp : implementations. See https://www.tensorflow.org/xla/operation_semantics#reduceprecision. + + ```mlir + %0 = mhlo.reduce_precision %arg0, format = e8m10 : tensor<3x4xf32> + ``` }]; let arguments = (ins - HLO_FpTensor:$operand, + MHLO_FpTensor:$operand, I32Attr:$exponent_bits, I32Attr:$mantissa_bits ); let hasVerifier = 1; - let results = (outs HLO_FpTensor:$output); + let results = (outs MHLO_FpTensor:$output); + + let assemblyFormat = [{ + $operand `,` `format` `=` custom($exponent_bits, $mantissa_bits) + attr-dict `:` custom(type($operand), type($output)) + }]; } -def HLO_RealDynamicSliceOp: HLO_ShapedInterfaceOp< +def MHLO_RealDynamicSliceOp: MHLO_ShapedInterfaceOp< "real_dynamic_slice", [Pure, AllElementTypesMatch<["operand", "result"]>, AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> { @@ -3109,12 +3231,12 @@ def HLO_RealDynamicSliceOp: HLO_ShapedInterfaceOp< ``` }]; let arguments = (ins - HLO_Tensor:$operand, - HLO_DimensionTensor:$start_indices, - HLO_DimensionTensor:$limit_indices, - HLO_DimensionTensor:$strides + MHLO_Tensor:$operand, + MHLO_DimensionTensor:$start_indices, + MHLO_DimensionTensor:$limit_indices, + MHLO_DimensionTensor:$strides ); - let results = (outs HLO_Tensor:$result); + let results = (outs MHLO_Tensor:$result); let hasCanonicalizer = 1; let hasCustomHLOConverter = 1; let hasVerifier = 1; @@ -3122,7 +3244,7 @@ def HLO_RealDynamicSliceOp: HLO_ShapedInterfaceOp< let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; } -def HLO_DynamicPadOp: HLO_ShapedInterfaceOp<"dynamic_pad", +def MHLO_DynamicPadOp: MHLO_ShapedInterfaceOp<"dynamic_pad", [Pure, AllElementTypesMatch<["operand", "padding_value", "result"]>, AllTypesMatch<["edge_padding_low", "edge_padding_high", "interior_padding"]>]> { let summary = "Dynamic Pad operator"; @@ -3142,13 +3264,13 @@ def HLO_DynamicPadOp: HLO_ShapedInterfaceOp<"dynamic_pad", ``` }]; let arguments 
= (ins - HLO_Tensor:$operand, - HLO_Tensor:$padding_value, - HLO_DimensionTensor:$edge_padding_low, - HLO_DimensionTensor:$edge_padding_high, - HLO_DimensionTensor:$interior_padding + MHLO_Tensor:$operand, + MHLO_Tensor:$padding_value, + MHLO_DimensionTensor:$edge_padding_low, + MHLO_DimensionTensor:$edge_padding_high, + MHLO_DimensionTensor:$interior_padding ); - let results = (outs HLO_Tensor:$result); + let results = (outs MHLO_Tensor:$result); let description = [{ Dynamically Pads the `operand`, with amount of padding added at low-end/high-end/interior is passed through input tensors. @@ -3160,7 +3282,7 @@ def HLO_DynamicPadOp: HLO_ShapedInterfaceOp<"dynamic_pad", let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; } -def HLO_DynamicGatherOp: HLO_Op<"dynamic_gather", +def MHLO_DynamicGatherOp: MHLO_Op<"dynamic_gather", [InferTensorTypeWithReify, Pure]> { string summary = "Dynamic Gather operator"; string description = [{ @@ -3169,19 +3291,20 @@ def HLO_DynamicGatherOp: HLO_Op<"dynamic_gather", }]; let arguments = (ins - HLO_Tensor:$operand, - HLO_IntTensor:$start_indices, - HLO_IntTensor:$slice_sizes, - GatherDimensionNumbers:$dimension_numbers, + MHLO_Tensor:$operand, + MHLO_IntTensor:$start_indices, + MHLO_IntTensor:$slice_sizes, + MHLO_GatherDimensionNumbers:$dimension_numbers, DefaultValuedOptionalAttr:$indices_are_sorted ); - let results = (outs HLO_Tensor); + let results = (outs MHLO_Tensor); let hasCustomHLOConverter = 1; let hasCanonicalizer = 1; } -def HLO_DynamicConvOp : HLO_ShapedInterfaceOp<"dynamic_conv", [Pure]> { +def MHLO_DynamicConvOp : MHLO_Op<"dynamic_conv", [Pure, DeclareOpInterfaceMethods]> { let summary = "Dynamic Convolution operator"; let description = [{ The dynamic shape version of ConvOp. Computes a convolution with dynamic padding. @@ -3189,16 +3312,17 @@ def HLO_DynamicConvOp : HLO_ShapedInterfaceOp<"dynamic_conv", [Pure]> { let arguments = !con( (ins - HLO_Tensor:$lhs, - HLO_Tensor:$rhs, - HLO_Tensor:$d_padding), - ConvolutionAttributes.attributes); - let results = (outs HLO_Tensor); + MHLO_Tensor:$lhs, + MHLO_Tensor:$rhs, + MHLO_Tensor:$d_padding), + MHLO_ConvolutionAttributes.attributes); + let results = (outs MHLO_Tensor); + let hasCanonicalizer = 1; let hasCustomHLOConverter = 1; } -def HLO_ComputeReshapeShapeOp : - HLO_Op<"compute_reshape_shape", [Pure]> { +def MHLO_ComputeReshapeShapeOp : + MHLO_Op<"compute_reshape_shape", [Pure]> { string summary = "Compute input for reshape with any dynamic dim resolved"; string description = [{ @@ -3230,8 +3354,8 @@ def HLO_ComputeReshapeShapeOp : let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; } -def HLO_CstrReshapableOp : - HLO_Op<"cstr_reshapable", [Pure]> { +def MHLO_CstrReshapableOp : + MHLO_Op<"cstr_reshapable", [Pure]> { string summary = "Compute input for reshape with any dynamic dim resolved"; string description = [{ diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_attrs.td b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops_attrs.td similarity index 58% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_attrs.td rename to tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops_attrs.td index 05bbbd6a0b9..c90b50cc950 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_attrs.td +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops_attrs.td @@ -20,66 +20,66 @@ limitations under the License. 
include "mlir/IR/OpBase.td" include "mlir/IR/TensorEncoding.td" -def HLODim : ArrayRefParameter<"int64_t", "Dimension"> { - let parser = "mlir::mhlo::parseIntArray($_parser)"; - let printer = "mlir::mhlo::printIntArray($_printer, $_self)"; +def MHLO_Dims : ArrayRefParameter<"int64_t", "Dimension"> { + let parser = "parseDimSizes($_parser)"; + let printer = "printDimSizes($_printer, $_self)"; } -def ScatterDimensionNumbers : AttrDef { +def MHLO_ScatterDimensionNumbers : AttrDef { let mnemonic = "scatter"; let summary = "Attribute that models the dimension information for scatter"; let parameters = (ins - HLODim:$updateWindowDims, - HLODim:$insertedWindowDims, - HLODim:$scatterDimsToOperandDims, + MHLO_Dims:$updateWindowDims, + MHLO_Dims:$insertedWindowDims, + MHLO_Dims:$scatterDimsToOperandDims, "int64_t":$indexVectorDim ); let hasCustomAssemblyFormat = 1; } -def GatherDimensionNumbers : AttrDef { +def MHLO_GatherDimensionNumbers : AttrDef { let mnemonic = "gather"; let summary = "Attribute that models the dimension information for gather"; let parameters = (ins - HLODim:$offsetDims, - HLODim:$collapsedSliceDims, - HLODim:$startIndexMap, + MHLO_Dims:$offsetDims, + MHLO_Dims:$collapsedSliceDims, + MHLO_Dims:$startIndexMap, "int64_t":$indexVectorDim ); let hasCustomAssemblyFormat = 1; } -def DotDimensionNumbers : AttrDef { +def MHLO_DotDimensionNumbers : AttrDef { let mnemonic = "dot"; let summary = "Attribute that models the dimension information for dot."; let parameters = (ins - HLODim:$lhsBatchingDimensions, - HLODim:$rhsBatchingDimensions, - HLODim:$lhsContractingDimensions, - HLODim:$rhsContractingDimensions + MHLO_Dims:$lhsBatchingDimensions, + MHLO_Dims:$rhsBatchingDimensions, + MHLO_Dims:$lhsContractingDimensions, + MHLO_Dims:$rhsContractingDimensions ); let hasCustomAssemblyFormat = 1; } -def ConvDimensionNumbers : AttrDef { +def MHLO_ConvDimensionNumbers : AttrDef { let cppNamespace = "::mlir::mhlo"; let mnemonic = "conv"; let summary = "Structure of dimension information for conv op"; let parameters = (ins "int64_t":$inputBatchDimension, "int64_t":$inputFeatureDimension, - HLODim:$inputSpatialDimensions, + MHLO_Dims:$inputSpatialDimensions, "int64_t":$kernelInputFeatureDimension, "int64_t":$kernelOutputFeatureDimension, - HLODim:$kernelSpatialDimensions, + MHLO_Dims:$kernelSpatialDimensions, "int64_t":$outputBatchDimension, "int64_t":$outputFeatureDimension, - HLODim:$outputSpatialDimensions + MHLO_Dims:$outputSpatialDimensions ); let hasCustomAssemblyFormat = 1; } -def OutputOperandAlias : AttrDef { +def MHLO_OutputOperandAlias : AttrDef { let cppNamespace = "::mlir::mhlo"; let mnemonic = "output_operand_alias"; let summary = @@ -112,9 +112,9 @@ def OutputOperandAlias : AttrDef { ``` }]; let parameters = (ins - HLODim:$outputTupleIndices, + MHLO_Dims:$outputTupleIndices, "int64_t":$operandIndex, - HLODim:$operandTupleIndices + MHLO_Dims:$operandTupleIndices ); let assemblyFormat = [{ `<` `output_tuple_indices` `=` $outputTupleIndices `,` @@ -123,7 +123,7 @@ def OutputOperandAlias : AttrDef { }]; } -def ArgResultAlias : AttrDef { +def MHLO_ArgResultAlias : AttrDef { let cppNamespace = "::mlir::mhlo"; let mnemonic = "result_alias"; let summary = @@ -148,9 +148,9 @@ def ArgResultAlias : AttrDef { ``` }]; let parameters = (ins - HLODim:$argTupleIndices, + MHLO_Dims:$argTupleIndices, "int64_t":$resultIndex, - HLODim:$resultTupleIndices, + MHLO_Dims:$resultTupleIndices, "bool":$isMustAlias ); let hasCustomAssemblyFormat = 1; @@ -159,16 +159,104 @@ def ArgResultAlias : AttrDef { // 
Represents a unique identifier for each Send/Recv instruction pair or // optionally for collective instructions (AllReduce, CollectivePermute, // AllToAll). Non-positive channel_id handle is equivalent to no channel id. -def ChannelHandle : AttrDef { +def MHLO_ChannelHandle : AttrDef { let mnemonic = "channel_handle"; let parameters = (ins "int64_t":$handle, "int64_t":$type); let summary = "two 64-bit integers 'handle' and 'type'"; let assemblyFormat = "`<` struct(params) `>`"; } +def MHLO_CrossProgramPrefetch : AttrDef { + let mnemonic = "cross_program_prefetch"; + let parameters = (ins "int64_t":$parameter, MHLO_Dims:$indices, "Optional":$offset); + let summary = "Argument that is prefetched from another program"; + let description = [{ + This attribute captures an argument that is prefetched from another program. + For a given `CrossProgramPrefetchAttr`, `parameter` tells us which argument + of the `main` function of the module is prefetched, and `indices` is a shape + index telling us what subshape of that argument is prefetched. + + A shape has a subshape iff it is a tuple. In that case, the subshape of + the tuple by `indices` is the shape achieved after indexing by each + element of `indices` in turn. For example, the [1,0] subshape of + `tuple, tuple, token>>` is `tensor`. + + An empty value for `indices` means the whole shape is prefetched. + + For example, + + ```mlir + module attributes { mhlo.cross_program_prefetch = [ #mhlo.cross_program_prefetch< parameter = 0, indices = [0]> ]} { + func.func @copy(%arg0 : tuple, tensor>) -> tuple, tensor> { + %0 = "mhlo.copy"(%arg0) {is_cross_program_prefetch} + return %0 : tuple, tensor> + } + func.func @main(%arg0 : tuple, tensor>) -> tuple, tensor> { + %1 = "mhlo.async_start"(%arg0) {called_computation=@copy} + %2 = "mhlo.async_done"(%1) {called_computation=@copy} + return %2 : tuple, tensor> + } + } + ``` + + The `parameter = 0` tells us that the async copy of the `0`th parameter is + a `cross_program_prefetch`, while the `index` of `[0]` tells us that the + `0`th element of the tuple is prefetched while the other element of the + tuple is not. + }]; + let assemblyFormat = "`<` struct(params) `>`"; +} + +def MHLO_DynamicParameterBinding : AttrDef { + let mnemonic = "dynamic_parameter_binding"; + let summary = "Indicates which parameters represent dynamic dimension sizes."; + let description = [{ + This attribute indicates which parameters to the `main` function of a + given module have the runtime values for dynamic dimension sizes for + other parameters. + + Each binding here specifies that parameter containing the runtime value of + the dynamic dimension size `dynamic_param_num` (which contains the + runtime size at a subshape given by the shape index + `dynamic_param_index`) and the dimension for which that is the size, + which is the `target_param_dim_num` dimension of the `target_param_num` + parameter to `main` at subshape `target_param_index`. + + See cross_program_prefetch for a discussion of subshapes. + + For example, + + module @main attributes { + mhlo.dynamic_parameter_bindings = [ + #mhlo.dynamic_parameter_binding< + dynamic_param_num = 0, + dynamic_param_indices = [], + target_param_num = 1, + target_param_indices = [], + target_param_dim_num = 0>] } { + func.func @main(%a : tensor, %b : tensor>) -> () { + func.return + } + } + + Here, 'b' (param index 1) has a dynamic shape whose real size is + determined at runtime. 'a' represents the real size of b's zeroth + dimension. 
+ + }]; + let parameters = (ins + "int64_t":$dynamic_param_num, + MHLO_Dims:$dynamic_param_indices, + "int64_t":$target_param_num, + MHLO_Dims:$target_param_indices, + "int64_t":$target_param_dim_num + ); + let assemblyFormat = "`<` struct(params) `>`"; +} + // Note: This is an experimental attribute and shouldn't be relied upon for // production. -def TypeExtensions : AttrDef, DeclareAttrInterfaceMethods]> { let mnemonic = "type_extensions"; @@ -188,11 +276,12 @@ def TypeExtensions : AttrDef().getType().getRank() == 1}]>]>, @@ -203,12 +292,12 @@ def HLO_LayoutAttr : Attr< } // An array of layout (1D tensor) attributes. -def HLO_ArrayOfLayoutAttr : TypedArrayAttrBase; // An array of FlatSymbolRef attributes that can be used as a default valued // attribute. -def HLO_FlatSymbolRefArrayAttr : +def MHLO_FlatSymbolRefArrayAttr : TypedArrayAttrBase { let constBuilderCall = "::mlir::ArrayAttr::get($_builder.getContext(), $0)"; } @@ -217,7 +306,7 @@ def HLO_FlatSymbolRefArrayAttr : // Common convolution attributes //===----------------------------------------------------------------------===// -def BoolElementsAttr : +def MHLO_BoolElementsAttr : ElementsAttrBase< And<[CPred<"$_self.isa<::mlir::DenseIntOrFPElementsAttr>()">, CPred<"$_self.cast<::mlir::DenseIntOrFPElementsAttr>().getType().getElementType().isInteger(1)">]>, @@ -228,7 +317,7 @@ def BoolElementsAttr : let convertFromStorage = "$_self"; } -def ConvolutionAttributes { +def MHLO_ConvolutionAttributes { dag attributes = (ins // Default value: one for each of the spatial dimension. OptionalAttr:$window_strides, @@ -239,11 +328,11 @@ def ConvolutionAttributes { // Default value: one for each of the spatial dimension. OptionalAttr:$rhs_dilation, // Default value: false for each of the spatial dimension. - OptionalAttr:$window_reversal, - ConvDimensionNumbers:$dimension_numbers, + OptionalAttr:$window_reversal, + MHLO_ConvDimensionNumbers:$dimension_numbers, I64Attr:$feature_group_count, I64Attr:$batch_group_count, - HLO_PrecisionConfigAttr:$precision_config + MHLO_PrecisionConfigAttr:$precision_config ); } diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/hlo_ops_common.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops_common.cc similarity index 98% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/hlo_ops_common.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops_common.cc index a4c196356b8..64147cb52cb 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/hlo_ops_common.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops_common.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h" +#include "mhlo/IR/hlo_ops_common.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringSet.h" @@ -26,7 +26,7 @@ namespace hlo { // Verifies the source target pairs attached to collective permute. 
LogicalResult verifyCollectivePermuteSourceTargetPairs( Operation *op, DenseIntElementsAttr attr) { - auto type = attr.getType().dyn_cast(); + auto type = attr.getType().cast(); if (type.getRank() != 2) return op->emitError() << "expect source_target_pairs attribute to be of " "rank 2, but got rank " @@ -40,6 +40,10 @@ LogicalResult verifyCollectivePermuteSourceTargetPairs( llvm::DenseSet targets; for (auto i = attr.begin(), e = attr.end(); i != e; ++i) { auto val = (*i).getSExtValue(); + if (val < 0) + return op->emitError() + << "replica ids in source_target_pairs must be >= 0."; + if (i.getIndex() % 2 == 0) { bool isUnique = sources.insert(val).second; if (!isUnique) return op->emitError() << "duplicate sources not allowed."; diff --git a/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops_common.h b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops_common.h new file mode 100644 index 00000000000..69e1b43f5c1 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops_common.h @@ -0,0 +1,57 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MLIR_HLO_MHLO_IR_HLO_OPS_COMMON_H +#define MLIR_HLO_MHLO_IR_HLO_OPS_COMMON_H + +// This file defines functionality shared between chlo/mhlo/lhlo dialects. + +#include + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/Operation.h" + +namespace mlir { +namespace hlo { + +// Verifies the source target pairs attached to collective permute. +LogicalResult verifyCollectivePermuteSourceTargetPairs( + Operation* op, DenseIntElementsAttr attr); + +LogicalResult verifyReduceScatter(Operation* op, TypeRange operandTypes, + TypeRange resultTypes, + uint64_t scatterDimension); + +// Custom formatting for convolution window attributes. +void printWindowAttributes(OpAsmPrinter& p, Operation* op, + llvm::Optional windowStrides, + llvm::Optional padding, + llvm::Optional lhsDilation, + llvm::Optional rhsDilation, + llvm::Optional windowReversal); + +ParseResult parseWindowAttributes(OpAsmParser& parser, + DenseIntElementsAttr& windowStrides, + DenseIntElementsAttr& padding, + DenseIntElementsAttr& lhsDilation, + DenseIntElementsAttr& rhsDilation, + DenseElementsAttr& windowReversal); + +} // namespace hlo +} // namespace mlir + +#endif // MLIR_HLO_MHLO_IR_HLO_OPS_COMMON_H diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.td b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops_common.td similarity index 77% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.td rename to tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops_common.td index d2e2e13e041..67aac3a94ae 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.td +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops_common.td @@ -16,18 +16,18 @@ limitations under the License. 
#ifndef MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_COMMON #define MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_COMMON -def HLO_Dialect : Dialect { +def MHLO_Dialect : Dialect { let name = "mhlo"; let cppNamespace = "::mlir::mhlo"; - let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; let useDefaultAttributePrinterParser = 0; let useDefaultTypePrinterParser = 0; + let useFoldAPI = kEmitFoldAdaptorFolder; } -include "stablehlo/dialect/Base.td" -include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_enums.td" -include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_attrs.td" -include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_typedefs.td" +include "mhlo/IR/hlo_base.td" +include "mhlo/IR/hlo_ops_enums.td" +include "mhlo/IR/hlo_ops_attrs.td" +include "mhlo/IR/hlo_ops_typedefs.td" #endif // MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_COMMON diff --git a/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops_enums.td b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops_enums.td new file mode 100644 index 00000000000..729dbf48cb2 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops_enums.td @@ -0,0 +1,264 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_ENUMS +#define MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_ENUMS + +include "mlir/IR/EnumAttr.td" +include "mlir/IR/PatternBase.td" + +//===----------------------------------------------------------------------===// +// Precision Config enum definitions. +//===----------------------------------------------------------------------===// + +// These mirror the XLA PrecisionConfig proto enum. +def MHLO_PRECISION_DEFAULT : I32EnumAttrCase<"DEFAULT", 0>; +def MHLO_PRECISION_HIGH : I32EnumAttrCase<"HIGH", 1>; +def MHLO_PRECISION_HIGHEST : I32EnumAttrCase<"HIGHEST", 2>; +def MHLO_PRECISION_PACKED_NIBBLE : I32EnumAttrCase<"PACKED_NIBBLE", 3>; + +def MHLO_Precision : I32EnumAttr<"Precision", + "XLA precision for an operand. Has backend specific meaning.", + [MHLO_PRECISION_DEFAULT, MHLO_PRECISION_HIGH, MHLO_PRECISION_HIGHEST, MHLO_PRECISION_PACKED_NIBBLE]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::mhlo"; +} + +def MHLO_PrecisionAttr : EnumAttr; + +// TODO(b/129153247) See if it's possible to also validate the size. +def MHLO_PrecisionConfigAttr: + OptionalAttr< + TypedArrayAttrBase>; + +//===----------------------------------------------------------------------===// +// Custom call schedule hints +//===----------------------------------------------------------------------===// + +// These mirror the XLA CustomCallSchedule proto enum. +def MHLO_CUSTOM_CALL_SCHEDULE_NONE : I32EnumAttrCase<"NONE", 0>; +def MHLO_CUSTOM_CALL_SCHEDULE_LATEST : I32EnumAttrCase<"LATEST", 1>; +def MHLO_CUSTOM_CALL_SCHEDULE_EARLIEST : I32EnumAttrCase<"EARLIEST", 2>; + +// mhlo.custom_call_schedule gives us a scheduling hint for placing calls +// LATEST indicates that the operation should be scheduled just before the first +// user in the use/def chain. 
+// EARLIEST indicates that the operation should be scheduled just after the last
+// operation that defines an argument of this operation in the use/def chain.
+// NONE indicates no hint for the compiler.
+def MHLO_CustomCallSchedule : I32EnumAttr<"CustomCallSchedule",
+    "Specifies the desired schedule for the custom-call.",
+    [MHLO_CUSTOM_CALL_SCHEDULE_NONE, MHLO_CUSTOM_CALL_SCHEDULE_LATEST, MHLO_CUSTOM_CALL_SCHEDULE_EARLIEST]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::mhlo";
+}
+
+def MHLO_CustomCallScheduleAttr : EnumAttr;
+
+//===----------------------------------------------------------------------===//
+// Domain Metadata Kind enum definitions.
+//===----------------------------------------------------------------------===//
+
+// These mirror the XLA domain metadata kinds.
+def MHLO_DOMAIN_KIND_SHARDING : I32EnumAttrCase<"sharding", 0>;
+
+def MHLO_DomainKind : I32EnumAttr<"DomainKind",
+    "Kind of domain metadata attached to an HLO domain.",
+    [MHLO_DOMAIN_KIND_SHARDING]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::mhlo";
+}
+
+def MHLO_DomainKindAttr : EnumAttr;
+
+//===----------------------------------------------------------------------===//
+// Fast Fourier Transform Type enum definitions.
+//===----------------------------------------------------------------------===//
+
+// These mirror the XLA FftType proto enum.
+def MHLO_FFT_TYPE_FFT : I32EnumAttrCase<"FFT", 0>;
+def MHLO_FFT_TYPE_IFFT : I32EnumAttrCase<"IFFT", 1>;
+def MHLO_FFT_TYPE_RFFT : I32EnumAttrCase<"RFFT", 2>;
+def MHLO_FFT_TYPE_IRFFT : I32EnumAttrCase<"IRFFT", 3>;
+
+def MHLO_FftType : I32EnumAttr<"FftType",
+    "XLA fast fourier transform type.",
+    [MHLO_FFT_TYPE_FFT, MHLO_FFT_TYPE_IFFT,
+     MHLO_FFT_TYPE_RFFT, MHLO_FFT_TYPE_IRFFT]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::mhlo";
+}
+
+def MHLO_FftTypeAttr : EnumAttr;
+
+//===----------------------------------------------------------------------===//
+// Custom call enum definitions.
+//===----------------------------------------------------------------------===//
+
+// TODO(b/189822916): Remove this enum when all clients are migrated to the
+// status-returning API.
+def MHLO_CUSTOM_CALL_API_VERISON_UNSPECIFIED :
+    I32EnumAttrCase<"API_VERSION_UNSPECIFIED", 0>;
+def MHLO_CUSTOM_CALL_API_VERSION_ORIGINAL :
+    I32EnumAttrCase<"API_VERSION_ORIGINAL", 1>;
+def MHLO_CUSTOM_CALL_API_VERSION_STATUS_RETURNING :
+    I32EnumAttrCase<"API_VERSION_STATUS_RETURNING", 2>;
+def MHLO_CUSTOM_CALL_API_VERSION_STATUS_RETURNING_UNIFIED :
+    I32EnumAttrCase<"API_VERSION_STATUS_RETURNING_UNIFIED", 3>;
+def MHLO_CUSTOM_CALL_API_VERSION_TYPED_FFI :
+    I32EnumAttrCase<"API_VERSION_TYPED_FFI", 4>;
+def MHLO_CustomCallApiVersionAttr :
+    I32EnumAttr<"CustomCallApiVersion", "Custom call API version", [
+      MHLO_CUSTOM_CALL_API_VERISON_UNSPECIFIED,
+      MHLO_CUSTOM_CALL_API_VERSION_ORIGINAL,
+      MHLO_CUSTOM_CALL_API_VERSION_STATUS_RETURNING,
+      MHLO_CUSTOM_CALL_API_VERSION_STATUS_RETURNING_UNIFIED,
+      MHLO_CUSTOM_CALL_API_VERSION_TYPED_FFI
+    ]> {
+  let cppNamespace = "::mlir::mhlo";
+}
+
+//===----------------------------------------------------------------------===//
+// Comparison op definitions.
+//===----------------------------------------------------------------------===//
+
+// These mirror the XLA ComparisonDirection enum.
+def MHLO_COMPARISON_DIRECTION_EQ : I32EnumAttrCase<"EQ", 0>;
+def MHLO_COMPARISON_DIRECTION_NE : I32EnumAttrCase<"NE", 1>;
+def MHLO_COMPARISON_DIRECTION_GE : I32EnumAttrCase<"GE", 2>;
+def MHLO_COMPARISON_DIRECTION_GT : I32EnumAttrCase<"GT", 3>;
+def MHLO_COMPARISON_DIRECTION_LE : I32EnumAttrCase<"LE", 4>;
+def MHLO_COMPARISON_DIRECTION_LT : I32EnumAttrCase<"LT", 5>;
+
+def MHLO_ComparisonDirection : I32EnumAttr<"ComparisonDirection",
+    "Which comparison operation to perform.",
+    [
+      MHLO_COMPARISON_DIRECTION_EQ,
+      MHLO_COMPARISON_DIRECTION_NE,
+      MHLO_COMPARISON_DIRECTION_GE,
+      MHLO_COMPARISON_DIRECTION_GT,
+      MHLO_COMPARISON_DIRECTION_LE,
+      MHLO_COMPARISON_DIRECTION_LT
+    ]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::mhlo";
+}
+
+def MHLO_ComparisonDirectionAttr : EnumAttr;
+
+def MHLO_DEFAULT_COMPARISON_TYPE : NativeCodeCall<"::mlir::mhlo::ComparisonTypeAttr()">;
+def MHLO_COMPARISON_TYPE_NOTYPE : I32EnumAttrCase<"NOTYPE", 0>;
+def MHLO_COMPARISON_TYPE_FLOAT : I32EnumAttrCase<"FLOAT", 1>;
+def MHLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER : I32EnumAttrCase<"TOTALORDER", 2>;
+def MHLO_COMPARISON_TYPE_SIGNED : I32EnumAttrCase<"SIGNED", 3>;
+def MHLO_COMPARISON_TYPE_UNSIGNED : I32EnumAttrCase<"UNSIGNED", 4>;
+
+def MHLO_ComparisonType : I32EnumAttr<"ComparisonType",
+    "Which comparison type to use.",
+    [
+      MHLO_COMPARISON_TYPE_NOTYPE,
+      MHLO_COMPARISON_TYPE_FLOAT,
+      MHLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER,
+      MHLO_COMPARISON_TYPE_SIGNED,
+      MHLO_COMPARISON_TYPE_UNSIGNED
+    ]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::mhlo";
+}
+
+def MHLO_ComparisonTypeAttr : EnumAttr;
+
+// These mirror the XLA Dequantize mode string enum.
+def MHLO_MIN_COMBINED : I32EnumAttrCase<"MIN_COMBINED", 0>;
+
+def MHLO_DequantizeMode : I32EnumAttr<"DequantizeMode",
+    "Dequantization mode. Only MIN_COMBINED is supported.",
+    [MHLO_MIN_COMBINED]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::mhlo";
+}
+
+def MHLO_DequantizeModeAttr : EnumAttr;
+
+// These mirror the XLA Transpose enum in Triangular Solve options.
+def MHLO_TRANSPOSE_INVALID : I32EnumAttrCase<"TRANSPOSE_INVALID", 0>; +def MHLO_NO_TRANSPOSE : I32EnumAttrCase<"NO_TRANSPOSE", 1>; +def MHLO_TRANSPOSE : I32EnumAttrCase<"TRANSPOSE", 2>; +def MHLO_ADJOINT : I32EnumAttrCase<"ADJOINT", 3>; + +def MHLO_Transpose : I32EnumAttr<"Transpose", + "Transpose options", + [ + MHLO_TRANSPOSE_INVALID, + MHLO_NO_TRANSPOSE, + MHLO_TRANSPOSE, + MHLO_ADJOINT + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::mhlo"; +} + +def MHLO_TransposeAttr : EnumAttr; + +def MHLO_LOOP_FUSION : I32EnumAttrCase<"kLoop", 0>; +def MHLO_INPUT_FUSION : I32EnumAttrCase<"kInput", 1>; +def MHLO_OUTPUT_FUSION : I32EnumAttrCase<"kOutput", 2>; +def MHLO_CUSTOM_FUSION : I32EnumAttrCase<"kCustom", 3>; +def MHLO_FusionKind : I32EnumAttr<"FusionKind", "fusion kind", [ + MHLO_LOOP_FUSION, MHLO_INPUT_FUSION, MHLO_OUTPUT_FUSION, MHLO_CUSTOM_FUSION +]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::mhlo"; +} + +def MHLO_RNG_DISTRIBUTION_UNIFORM : I32EnumAttrCase<"UNIFORM", 1>; +def MHLO_RNG_DISTRIBUTION_NORMAL : I32EnumAttrCase<"NORMAL", 2>; + +def MHLO_RngDistribution : I32EnumAttr<"RngDistribution", + "XLA PRNG distribution to be used.", + [ + MHLO_RNG_DISTRIBUTION_UNIFORM, + MHLO_RNG_DISTRIBUTION_NORMAL + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::mhlo"; +} + +def MHLO_RngDistributionAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def MHLO_FusionKindAttr : EnumAttr; + +def MHLO_RNG_ALGORITHM_DEFAULT : I32EnumAttrCase<"DEFAULT", 0>; +def MHLO_RNG_ALGORITHM_THREE_FRY : I32EnumAttrCase<"THREE_FRY", 1>; +def MHLO_RNG_ALGORITHM_PHILOX : I32EnumAttrCase<"PHILOX", 2>; + +def MHLO_RngAlgorithm : I32EnumAttr<"RngAlgorithm", + "XLA PRNG algorithm to be used.", + [ + MHLO_RNG_ALGORITHM_DEFAULT, + MHLO_RNG_ALGORITHM_THREE_FRY, + MHLO_RNG_ALGORITHM_PHILOX + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::mhlo"; +} + +def MHLO_RngAlgorithmAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +#endif // MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_ENUMS diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_typedefs.td b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops_typedefs.td similarity index 86% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_typedefs.td rename to tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops_typedefs.td index 44acebbb923..532915ab980 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_typedefs.td +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_ops_typedefs.td @@ -17,7 +17,7 @@ limitations under the License. #ifndef MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_TYPEDEFS #define MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_TYPEDEFS -def AsyncBundle : TypeDef { +def MHLO_AsyncBundleTypeDef : TypeDef { let mnemonic = "async_bundle"; let summary = "Opaque collection of other types"; let parameters = (ins ArrayRefParameter<"Type">:$types); @@ -51,11 +51,11 @@ def AsyncBundle : TypeDef { } // Whether a type is a AsyncBundleType. 
-def IsAsyncBundleTypePred : CPred<"$_self.isa<::mlir::mhlo::AsyncBundleType>()">; +def MHLO_IsAsyncBundleTypePred : CPred<"$_self.isa<::mlir::mhlo::AsyncBundleType>()">; -def HLO_AsyncBundle : - MixedContainerType, IsAsyncBundleTypePred, +def MHLO_AsyncBundle : + MixedContainerType, MHLO_IsAsyncBundleTypePred, "AsyncBundleType::getFlattenedTypes($_self.cast<::mlir::mhlo::AsyncBundleType>())", "async_bundle">; -#endif // MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_TYPEDEFS \ No newline at end of file +#endif // MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_TYPEDEFS diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/hlo_patterns.td b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_patterns.td similarity index 64% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/hlo_patterns.td rename to tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_patterns.td index 9dae5e88813..2c1534e6568 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/hlo_patterns.td +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_patterns.td @@ -17,56 +17,74 @@ limitations under the License. include "mlir/Dialect/Shape/IR/ShapeOps.td" include "mlir/Dialect/Tensor/IR/TensorOps.td" -include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" +include "mhlo/IR/hlo_ops.td" def HasSameType : Constraint>; +// Checks if the value has only one user. +def HasOneUse : Constraint>; + // Canonicalization patterns. def DynamicBroadcastToOwnShape_1 : Pat< - (HLO_DynamicBroadcastInDimOp:$op $x, + (MHLO_DynamicBroadcastInDimOp:$op $x, (Shape_ToExtentTensorOp (Shape_ShapeOfOp $x)), $_, $_, $_), (replaceWithValue $x)>; def DynamicBroadcastToOwnShape_2 : Pat< - (HLO_DynamicBroadcastInDimOp:$op $x, + (MHLO_DynamicBroadcastInDimOp:$op $x, (Shape_ShapeOfOp $x), $attr0, $attr1, $attr2), (replaceWithValue $x)>; def DynamicBroadcastToOwnShape_3 : Pat< - (HLO_DynamicBroadcastInDimOp:$op $x, + (MHLO_DynamicBroadcastInDimOp:$op $x, (Tensor_CastOp (Shape_ToExtentTensorOp (Shape_ShapeOfOp $x))), $attr0, $attr1, $attr2), (Tensor_CastOp $x)>; def DynamicBroadcastToOwnShape_4 : Pat< - (HLO_DynamicBroadcastInDimOp:$op $x, + (MHLO_DynamicBroadcastInDimOp:$op $x, (Tensor_CastOp (Shape_ShapeOfOp $x)), $attr0, $attr1, $attr2), (Tensor_CastOp $x)>; def ShapeOfDynamicReshape : Pat< - (Shape_ShapeOfOp:$op (HLO_DynamicReshapeOp $x, $shape)), + (Shape_ShapeOfOp:$op (MHLO_DynamicReshapeOp $x, $shape)), (replaceWithValue $shape), [(HasSameType $shape, $op)]>; def IdentityBroadcastReshape : Pat< - (HLO_ReshapeOp:$op (HLO_BroadcastOp $input, $dims)), + (MHLO_ReshapeOp:$op (MHLO_BroadcastOp $input, $dims)), (replaceWithValue $input), [(HasSameType $input, $op)]>; def IdentityBroadcastInDimReshape : Pat< - (HLO_ReshapeOp:$op (HLO_BroadcastInDimOp $input, $dims)), + (MHLO_ReshapeOp:$op (MHLO_BroadcastInDimOp $input, $dims)), (replaceWithValue $input), [(HasSameType $input, $op)]>; def EliminateIdentityConvert : Pat< - (HLO_ConvertOp:$res $src), + (MHLO_ConvertOp:$res $src), (replaceWithValue $src), [(HasSameType $res, $src)]>; def EliminateRedundantReshape : Pat< - (HLO_ReshapeOp:$res (HLO_ReshapeOp $src)), + (MHLO_ReshapeOp:$res (MHLO_ReshapeOp $src)), (replaceWithValue $src), [(HasSameType $res, $src)]>; def EliminateIdentityReshape : Pat< - (HLO_ReshapeOp:$res $src), + (MHLO_ReshapeOp:$res $src), (replaceWithValue $src), [(HasSameType $res, $src)]>; + +// select(not(p), t, f) => select(p, f, t) +def FusePredNegIntoSelect : Pat< + (MHLO_SelectOp (MHLO_NotOp $pred), $on_true, $on_false), + (MHLO_SelectOp $pred, $on_false, $on_true)>; + +// select(broadcast(not(p)), t, f) => 
select(broadcast(p), f, t) +def FuseBroadcastedPredNegIntoSelect : Pat< + (MHLO_SelectOp + (MHLO_BroadcastInDimOp:$b (MHLO_NotOp $pred), $broadcast_dimensions), + $on_true, $on_false), + (MHLO_SelectOp + (MHLO_BroadcastInDimOp $pred, $broadcast_dimensions, (returnType $b)), + $on_false, $on_true), + [(HasOneUse $b)]>; diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_utils.td similarity index 92% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td rename to tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_utils.td index 2146fa2c5a4..b5eb6de56fe 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/hlo_utils.td @@ -30,16 +30,16 @@ def CastIntElementsAttr : NativeCodeCall<"$0.cast()">; class ConstantSplat : NativeCodeCall< "hlo::getSplat(&$_builder, $0, " # value # ")">; -class HLO_ConstantLike : NativeCodeCall< +class MHLO_ConstantLike : NativeCodeCall< "chlo::getConstantLike($_builder, $_loc, " # value # ", $0)">; -def HLO_ConstantLikeMaxFiniteValue : NativeCodeCall< +def MHLO_ConstantLikeMaxFiniteValue : NativeCodeCall< "chlo::getConstantLikeMaxFiniteValue($_builder, $_loc, $0)">; -def HLO_ConstantLikePosInfValue : NativeCodeCall< +def MHLO_ConstantLikePosInfValue : NativeCodeCall< "chlo::getConstantLikeInfValue($_builder, $_loc, $0, /*negative=*/false)">; -def HLO_ConstantLikeNegInfValue : NativeCodeCall< +def MHLO_ConstantLikeNegInfValue : NativeCodeCall< "chlo::getConstantLikeInfValue($_builder, $_loc, $0, /*negative=*/true)">; def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">; diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/init.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/init.cc similarity index 92% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/init.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/IR/init.cc index b8393b735cf..9cd5e2a9265 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/init.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/init.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/IR/register.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/IR/register.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "stablehlo/dialect/ChloOps.h" diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/mhlo_bytecode.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/mhlo_bytecode.cc similarity index 99% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/mhlo_bytecode.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/IR/mhlo_bytecode.cc index 12d0c1584d9..9efce77c1a3 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/mhlo_bytecode.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/mhlo_bytecode.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "mlir-hlo/Dialect/mhlo/IR/mhlo_bytecode.h" +#include "mhlo/IR/mhlo_bytecode.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mhlo/IR/hlo_ops.h" #include "mlir/Bytecode/BytecodeImplementation.h" #include "mlir/IR/Diagnostics.h" #include "stablehlo/dialect/Base.h" diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/mhlo_bytecode.h b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/mhlo_bytecode.h similarity index 85% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/mhlo_bytecode.h rename to tensorflow/compiler/xla/mlir_hlo/mhlo/IR/mhlo_bytecode.h index e7c36a044c6..a2c78c82c4d 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/mhlo_bytecode.h +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/mhlo_bytecode.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MLIR_HLO_DIALECT_MHLO_IR_MHLO_BYTECODE_H -#define MLIR_HLO_DIALECT_MHLO_IR_MHLO_BYTECODE_H +#ifndef MLIR_HLO_MHLO_IR_MHLO_BYTECODE_H +#define MLIR_HLO_MHLO_IR_MHLO_BYTECODE_H namespace mlir { namespace mhlo { @@ -26,4 +26,4 @@ void addBytecodeInterface(MhloDialect *dialect); } // namespace mhlo } // namespace mlir -#endif // MLIR_HLO_DIALECT_MHLO_IR_MHLO_BYTECODE_H +#endif // MLIR_HLO_MHLO_IR_MHLO_BYTECODE_H diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/mhlo_canonicalize.td b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/mhlo_canonicalize.td similarity index 67% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/mhlo_canonicalize.td rename to tensorflow/compiler/xla/mlir_hlo/mhlo/IR/mhlo_canonicalize.td index df79559d589..28ea9a469ad 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/IR/mhlo_canonicalize.td +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/mhlo_canonicalize.td @@ -16,8 +16,8 @@ limitations under the License. // This is the canonicalize pattern definition file. include "mlir/IR/OpBase.td" -include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" -include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td" +include "mhlo/IR/hlo_ops.td" +include "mhlo/IR/hlo_utils.td" def UnaryToBinaryEinsumEq : NativeCodeCall< "$_builder.getStringAttr(\",\" + $0.getValue().str())">; @@ -28,48 +28,48 @@ def GetI64DenseElementsAttr : NativeCodeCall< // Convert UnaryEinsumOp to EinsumOp with two operands with redundant first // operand. def UnaryEinsumToEinsum : Pat< - (HLO_UnaryEinsumOp $operand, $equation), - (HLO_EinsumOp (HLO_ConstantOp (GetScalarOfType<1> $operand)), + (MHLO_UnaryEinsumOp $operand, $equation), + (MHLO_EinsumOp (MHLO_ConstantOp (GetScalarOfType<1> $operand)), $operand, (UnaryToBinaryEinsumEq $equation))>; // A dynamic reshape of a dynamic reshape is a dynamic reshape. def RemoveRedundantDynamicReshape : Pat< - (HLO_DynamicReshapeOp (HLO_DynamicReshapeOp $operand, $shape1), $shape2), - (HLO_DynamicReshapeOp $operand, $shape2)>; + (MHLO_DynamicReshapeOp (MHLO_DynamicReshapeOp $operand, $shape1), $shape2), + (MHLO_DynamicReshapeOp $operand, $shape2)>; // A dynamic broadcast of a dynamic reshape with the same shape operand // is a dynamic reshape. 
def RemoveRedundantDynamicBroadcast : Pat< - (HLO_DynamicBroadcastInDimOp - (HLO_DynamicReshapeOp $operand, $shape), + (MHLO_DynamicBroadcastInDimOp + (MHLO_DynamicReshapeOp $operand, $shape), $shape, IdentityBroadcastDims, $known_expanding_dimensions, $known_nonexpanding_dimensions), - (HLO_DynamicReshapeOp $operand, $shape)>; + (MHLO_DynamicReshapeOp $operand, $shape)>; // Convert DPad to Pad if edge_padding_low, edge_padding_high and -// interior_padding are HLO_ConstantOp +// interior_padding are constant. def DPadToPad: Pat< - (HLO_DynamicPadOp $input, + (MHLO_DynamicPadOp $input, $padding_value, (ConstantLikeMatcher AnyIntElementsAttr:$edge_padding_low), (ConstantLikeMatcher AnyIntElementsAttr:$edge_padding_high), (ConstantLikeMatcher AnyIntElementsAttr:$interior_padding)), - (HLO_PadOp $input, $padding_value, + (MHLO_PadOp $input, $padding_value, (GetI64DenseElementsAttr (CastIntElementsAttr $edge_padding_low)), (GetI64DenseElementsAttr (CastIntElementsAttr $edge_padding_high)), (GetI64DenseElementsAttr (CastIntElementsAttr $interior_padding)))>; // Convert RealDynamicSliceOp to SliceOp if start_indices, limit_indices and -// strides are HLO_ConstantOp +// strides are constant. def RealDSliceToSlice: Pat< - (HLO_RealDynamicSliceOp $operand, - (HLO_ConstantOp I64ElementsAttr:$start_indices), - (HLO_ConstantOp I64ElementsAttr:$limit_indices), - (HLO_ConstantOp I64ElementsAttr:$strides)), - (HLO_SliceOp $operand, - (CastIntElementsAttr $start_indices), - (CastIntElementsAttr $limit_indices), - (CastIntElementsAttr $strides))>; + (MHLO_RealDynamicSliceOp $operand, + (ConstantLikeMatcher AnyIntElementsAttr:$start_indices), + (ConstantLikeMatcher AnyIntElementsAttr:$limit_indices), + (ConstantLikeMatcher AnyIntElementsAttr:$strides)), + (MHLO_SliceOp $operand, + (GetI64DenseElementsAttr (CastIntElementsAttr $start_indices)), + (GetI64DenseElementsAttr (CastIntElementsAttr $limit_indices)), + (GetI64DenseElementsAttr (CastIntElementsAttr $strides)))>; diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h b/tensorflow/compiler/xla/mlir_hlo/mhlo/IR/register.h similarity index 100% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h rename to tensorflow/compiler/xla/mlir_hlo/mhlo/IR/register.h diff --git a/tensorflow/compiler/xla/mlir_hlo/mhlo/analysis/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/mhlo/analysis/CMakeLists.txt new file mode 100644 index 00000000000..4e15d7389ca --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/analysis/CMakeLists.txt @@ -0,0 +1,24 @@ +add_mlir_library(MhloAnalysis + shape_component_analysis.cc + + DEPENDS + mlir-headers + + LINK_LIBS PUBLIC + MLIRAnalysis + MLIRIR + LmhloDialect +) + +add_mlir_library(MhloTestAnalysis + test_shape_component_analysis.cc + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRHLOAnalysis + MLIRAnalysis + MLIRPass + MLIRTransforms +) diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Analysis/shape_component_analysis.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/analysis/shape_component_analysis.cc similarity index 99% rename from tensorflow/compiler/xla/mlir_hlo/lib/Analysis/shape_component_analysis.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/analysis/shape_component_analysis.cc index 320c2969e0a..93e900b9572 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Analysis/shape_component_analysis.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/analysis/shape_component_analysis.cc @@ -13,13 +13,14 @@ See the License for the specific language governing 
permissions and limitations under the License. ==============================================================================*/ -#include "mlir-hlo/Analysis/shape_component_analysis.h" +#include "mhlo/analysis/shape_component_analysis.h" #include +#include #include #include "llvm/ADT/STLExtras.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Shape/IR/Shape.h" @@ -717,7 +718,7 @@ struct ShapeVisitor { ArrayRef lookup(ShapeOrValueInfo requestedInfo) { auto i = symbolicExprsMap->find(requestedInfo); assert(i != symbolicExprsMap->end() && "op not processed yet?"); - return llvm::makeArrayRef(i->second); + return llvm::ArrayRef(i->second); } // Inserts a new entry into the cache and returns a reference to its result @@ -748,7 +749,7 @@ ShapeComponentAnalysis::ShapeComponentAnalysis::GetShapeInfo(Value value) { compute(request); auto found = symbolicExprsMap.find(request); if (found == symbolicExprsMap.end()) return {}; - return llvm::makeArrayRef(found->second); + return llvm::ArrayRef(found->second); } Optional> @@ -757,7 +758,7 @@ ShapeComponentAnalysis::ShapeComponentAnalysis::GetValueInfo(Value shape) { compute(request); auto found = symbolicExprsMap.find(request); if (found == symbolicExprsMap.end()) return {}; - return llvm::makeArrayRef(found->second); + return llvm::ArrayRef(found->second); } void ShapeComponentAnalysis::reset() { @@ -819,7 +820,7 @@ llvm::Optional SymbolicExpr::singleton() const { assert(symbols.size() == 1); return symbols[0]; } - return llvm::None; + return std::nullopt; } void SymbolicExpr::dump(llvm::raw_ostream &os) const { diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Analysis/shape_component_analysis.h b/tensorflow/compiler/xla/mlir_hlo/mhlo/analysis/shape_component_analysis.h similarity index 96% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Analysis/shape_component_analysis.h rename to tensorflow/compiler/xla/mlir_hlo/mhlo/analysis/shape_component_analysis.h index 0fcefbfcea2..d1c3b13af11 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Analysis/shape_component_analysis.h +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/analysis/shape_component_analysis.h @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef MLIR_HLO_ANALYSIS_SHAPE_COMPONENT_ANALYSIS_H -#define MLIR_HLO_ANALYSIS_SHAPE_COMPONENT_ANALYSIS_H +#ifndef MLIR_HLO_MHLO_ANALYSIS_SHAPE_COMPONENT_ANALYSIS_H +#define MLIR_HLO_MHLO_ANALYSIS_SHAPE_COMPONENT_ANALYSIS_H #include "llvm/Support/raw_ostream.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mhlo/IR/hlo_ops.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Value.h" @@ -165,4 +165,4 @@ struct DenseMapInfo { } // namespace llvm -#endif // MLIR_HLO_ANALYSIS_SHAPE_COMPONENT_ANALYSIS_H +#endif // MLIR_HLO_MHLO_ANALYSIS_SHAPE_COMPONENT_ANALYSIS_H diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Analysis/test_shape_component_analysis.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/analysis/test_shape_component_analysis.cc similarity index 94% rename from tensorflow/compiler/xla/mlir_hlo/lib/Analysis/test_shape_component_analysis.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/analysis/test_shape_component_analysis.cc index a1ffa95ba0b..a749f7fad73 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Analysis/test_shape_component_analysis.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/analysis/test_shape_component_analysis.cc @@ -13,15 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mlir-hlo/Analysis/shape_component_analysis.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/analysis/shape_component_analysis.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/Pass.h" namespace mlir { #define GEN_PASS_DEF_TESTSHAPECOMPONENTANALYSIS -#include "mlir-hlo/Transforms/passes.h.inc" +#include "transforms/passes.h.inc" using SymbolicExpr = ShapeComponentAnalysis::SymbolicExpr; diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/bufferizable_op_interface_impl.h b/tensorflow/compiler/xla/mlir_hlo/mhlo/interfaces/bufferizable_op_interface_impl.h similarity index 81% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/bufferizable_op_interface_impl.h rename to tensorflow/compiler/xla/mlir_hlo/mhlo/interfaces/bufferizable_op_interface_impl.h index 4fc8a0b13b6..a0ccb1d7a21 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/bufferizable_op_interface_impl.h +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/interfaces/bufferizable_op_interface_impl.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef MLIR_HLO_DIALECT_MHLO_TRANSFORMS_BUFFERIZABLE_OP_INTERFACE_IMPL_H -#define MLIR_HLO_DIALECT_MHLO_TRANSFORMS_BUFFERIZABLE_OP_INTERFACE_IMPL_H +#ifndef MLIR_HLO_MHLO_INTERFACES_BUFFERIZABLE_OP_INTERFACE_IMPL_H +#define MLIR_HLO_MHLO_INTERFACES_BUFFERIZABLE_OP_INTERFACE_IMPL_H #include #include @@ -30,4 +30,4 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry); } // namespace mhlo } // namespace mlir -#endif // MLIR_HLO_DIALECT_MHLO_TRANSFORMS_BUFFERIZABLE_OP_INTERFACE_IMPL_H +#endif // MLIR_HLO_MHLO_INTERFACES_BUFFERIZABLE_OP_INTERFACE_IMPL_H diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt similarity index 56% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/CMakeLists.txt rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt index 6a8ef48a961..8c3395d5829 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/CMakeLists.txt +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt @@ -13,68 +13,64 @@ # See the License for the specific language governing permissions and # limitations under the License. # + +set(LLVM_TARGET_DEFINITIONS mhlo_passes.td) +mlir_tablegen(mhlo_passes.h.inc -gen-pass-decls -name AllMhlo) +add_public_tablegen_target(MLIRMhloPassIncGen) + include_directories(BEFORE ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR}) -set(LLVM_TARGET_DEFINITIONS lower_complex_patterns.td) -mlir_tablegen(generated_lower_complex.inc -gen-rewriters) +set(LLVM_TARGET_DEFINITIONS lower_complex/lower_complex_patterns.td) +mlir_tablegen(lower_complex/generated_lower_complex.inc -gen-rewriters) add_public_tablegen_target(MLIRMhloLowerComplexIncGen) -set(LLVM_TARGET_DEFINITIONS legalize_to_standard_patterns.td) -mlir_tablegen(generated_legalize_to_standard.inc -gen-rewriters) +set(LLVM_TARGET_DEFINITIONS legalize_to_standard/legalize_to_standard_patterns.td) +mlir_tablegen(legalize_to_standard/generated_legalize_to_standard.inc -gen-rewriters) add_public_tablegen_target(MLIRMhloLegalizeToStandardIncGen) -set(LLVM_TARGET_DEFINITIONS chlo_legalize_to_hlo_patterns.td) -mlir_tablegen(generated_chlo_legalize_to_hlo.inc -gen-rewriters) +set(LLVM_TARGET_DEFINITIONS chlo_legalize_to_hlo/chlo_legalize_to_hlo_patterns.td) +mlir_tablegen(chlo_legalize_to_hlo/generated_chlo_legalize_to_hlo.inc -gen-rewriters) add_public_tablegen_target(MLIRChloLegalizeToHloIncGen) -add_mlir_library(MhloScatterUtils - mhlo_scatter_gather_utils.cc - - DEPENDS - MLIRhlo_opsIncGen - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - MhloDialect -) add_mlir_library(MhloPasses - broadcast_propagation.cc - collapse_elementwise_map.cc - constraint_fusion_pass.cc - convert_to_signless_pass.cc - expand_hlo_tuples.cc - group_reduction_dimensions.cc - legalize_einsum_to_dot_general.cc - legalize_gather_to_torch_index_select.cc - legalize_shape_computations.cc - legalize_trigonometric_to_approximation.cc - lower_complex.cc - lower_complex_patterns.td - lower_general_dot.cc - materialize_broadcasts.cc - materialize_broadcasts_pass.cc - merge_assuming_ops.cc - mhlo_canonicalize_gather.cc - mhlo_canonicalize_reduction.cc - mhlo_canonicalize_scatter.cc - mhlo_flatten_tuple.cc - prepare_for_export.cc - optimize_mhlo.cc - optimize_mhlo_pass.cc - rank_specialization.cc - restrict_max_rank.cc - shape_reification_pass.cc - 
sink_constants_to_control_flow.cc - sparse_rewriting.cc - test_infer_shaped_type_pass.cc - unfuse_batch_norm.cc - unfuse_batch_norm_pass.cc + broadcast_propagation/broadcast_propagation.cc + collapse_elementwise_map/collapse_elementwise_map.cc + constraint_fusion/constraint_fusion_pass.cc + convert_to_signless/convert_to_signless_pass.cc + expand_hlo_tuples/expand_hlo_tuples.cc + expand_ops_simplifier/expand_ops_simplifier.cc + group_reduction_dimensions/group_reduction_dimensions.cc + legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc + legalize_gather_to_torch_index_select/legalize_gather_to_torch_index_select.cc + legalize_shape_computations/legalize_shape_computations.cc + legalize_trigonometric_to_approximation/legalize_trigonometric_to_approximation.cc + lower_complex/lower_complex.cc + lower_complex/lower_complex_patterns.td + lower_general_dot/lower_general_dot.cc + materialize_broadcasts/materialize_broadcasts.cc + materialize_broadcasts/materialize_broadcasts_pass.cc + merge_assuming_ops/merge_assuming_ops.cc + mhlo_canonicalize_gather/mhlo_canonicalize_gather.cc + mhlo_canonicalize_reduction/mhlo_canonicalize_reduction.cc + mhlo_canonicalize_scatter/mhlo_canonicalize_scatter.cc + mhlo_flatten_tuple/mhlo_flatten_tuple.cc + prepare_for_export/prepare_for_export.cc + optimize_mhlo/optimize_mhlo.cc + optimize_mhlo/optimize_mhlo_pass.cc + rank_specialization/rank_specialization.cc + restrict_max_rank/restrict_max_rank.cc + shape_reification/shape_reification_pass.cc + shape_simplification/shape_simplification.cc + sink_constants_to_control_flow/sink_constants_to_control_flow.cc + sparse_rewriting/sparse_rewriting.cc + symbolic_shape_optimization/symbolic_shape_optimization.cc + test_infer_shaped_type/test_infer_shaped_type_pass.cc + unfuse_batch_norm/unfuse_batch_norm.cc + unfuse_batch_norm/unfuse_batch_norm_pass.cc DEPENDS MLIRhlo_opsIncGen @@ -86,22 +82,26 @@ add_mlir_library(MhloPasses LINK_LIBS PUBLIC ChloOps + MhloAnalysis MhloDialect MhloScatterUtils MhloTypeConversion MLIRIR + MLIRLinalgDialect + MLIRMathDialect MLIRMhloUtils MLIRPass + MLIRSCFDialect + MLIRSideEffectInterfaces MLIRTransformUtils StablehloBroadcastUtils ) add_mlir_library(MhloToThloConversion - legalize_mhlo_to_thlo.cc + legalize_mhlo_to_thlo/legalize_mhlo_to_thlo.cc DEPENDS MLIRMhloPassIncGen - MLIRGmlStTilingInterfaceIncGen THLODialect LINK_COMPONENTS @@ -109,6 +109,7 @@ add_mlir_library(MhloToThloConversion LINK_LIBS PUBLIC MhloDialect + MhloToArithmeticConversion MhloTypeConversion THLODialect MLIRIR @@ -118,7 +119,7 @@ add_mlir_library(MhloToThloConversion ) add_mlir_library(MhloToLhloConversion - hlo_legalize_to_lhlo.cc + hlo_legalize_to_lhlo/hlo_legalize_to_lhlo.cc DEPENDS MLIRhlo_opsIncGen @@ -142,23 +143,8 @@ add_mlir_library(MhloToLhloConversion MLIRTransforms ) -add_mlir_library(MhloTypeConversion - type_conversion.cc - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - MhloDialect - MLIRIR - MLIRFuncDialect - MLIRFuncTransforms - MLIRTensorDialect - StablehloOps -) - add_mlir_library(MhloToMemrefConversion - hlo_legalize_to_memref.cc + hlo_legalize_to_memref/hlo_legalize_to_memref.cc DEPENDS MLIRhlo_opsIncGen @@ -168,6 +154,7 @@ add_mlir_library(MhloToMemrefConversion Core LINK_LIBS PUBLIC + LmhloDialect MhloDialect MhloTypeConversion MLIRIR @@ -178,7 +165,7 @@ add_mlir_library(MhloToMemrefConversion ) add_mlir_library(MhloToArithmeticConversion - hlo_legalize_to_arithmetic.cc + hlo_legalize_to_arithmetic/hlo_legalize_to_arithmetic.cc DEPENDS MLIRhlo_opsIncGen @@ -193,14 +180,15 @@ 
add_mlir_library(MhloToArithmeticConversion MLIRIR MLIRPass MLIRMathDialect + MLIRSCFDialect MLIRTransforms MLIRTransformUtils ) add_mlir_library(MhloToStandard - legalize_control_flow.cc - legalize_sort.cc - legalize_to_standard.cc + legalize_control_flow/legalize_control_flow.cc + legalize_sort/legalize_sort.cc + legalize_to_standard/legalize_to_standard.cc DEPENDS MLIRhlo_opsIncGen @@ -215,38 +203,17 @@ add_mlir_library(MhloToStandard LINK_LIBS PUBLIC MhloDialect MLIRIR + MLIRMathDialect MLIRPass + MLIRSCFDialect MLIRTensorDialect MLIRTransformUtils ) -add_mlir_library(HloToLinalgUtils - legalize_to_linalg_utils.cc - - DEPENDS - MLIRhlo_opsIncGen - MLIRMhloPassIncGen - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - LmhloDialect - MhloDialect - MhloTypeConversion - MLIRBufferizationDialect - MLIRComplexDialect - MLIRIR - MLIRLinalgUtils - MLIRPass - MLIRRewrite - MLIRTransformUtils -) - add_mlir_library(ChloPasses - chlo_legalize_to_hlo.cc - chlo_legalize_to_hlo_pass.cc - sparse_chlo_legalize_to_linalg.cc + chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc + chlo_legalize_to_hlo/chlo_legalize_to_hlo_pass.cc + sparse_chlo_legalize_to_linalg/sparse_chlo_legalize_to_linalg.cc DEPENDS MLIRhlo_opsIncGen @@ -268,7 +235,7 @@ add_mlir_library(ChloPasses ) add_mlir_library(MhloToLinalg - legalize_to_linalg.cc + legalize_to_linalg/legalize_to_linalg.cc DEPENDS MLIRhlo_opsIncGen @@ -281,10 +248,12 @@ add_mlir_library(MhloToLinalg HloToLinalgUtils LmhloDialect MhloDialect + MhloToArithmeticConversion MhloTypeConversion MLIRBufferizationDialect MLIRComplexDialect MLIRIR + MLIRLinalgTransforms MLIRLinalgUtils MLIRPass MLIRRewrite @@ -292,7 +261,7 @@ add_mlir_library(MhloToLinalg ) add_mlir_library(MhloShapeOpsToStandard - hlo_legalize_shape_ops_to_standard.cc + hlo_legalize_shape_ops_to_standard/hlo_legalize_shape_ops_to_standard.cc DEPENDS MLIRhlo_opsIncGen @@ -313,8 +282,8 @@ add_mlir_library(MhloShapeOpsToStandard ) add_mlir_library(MhloToStablehlo - hlo_legalize_to_stablehlo.cc - hlo_legalize_to_stablehlo_pass.cc + hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc + hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo_pass.cc DEPENDS MLIRMhloPassIncGen @@ -333,8 +302,8 @@ add_mlir_library(MhloToStablehlo ) add_mlir_library(StablehloToMhlo - stablehlo_legalize_to_hlo.cc - stablehlo_legalize_to_hlo_pass.cc + stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc + stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc DEPENDS MLIRMhloPassIncGen @@ -345,6 +314,7 @@ add_mlir_library(StablehloToMhlo LINK_LIBS PUBLIC MhloDialect MhloTypeConversion + MLIRAsmParser MLIRIR MLIRPass MLIRSupport diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/broadcast_propagation.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/broadcast_propagation/broadcast_propagation.cc similarity index 98% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/broadcast_propagation.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/broadcast_propagation/broadcast_propagation.cc index 777c57bb8e0..44e09a135a1 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/broadcast_propagation.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/broadcast_propagation/broadcast_propagation.cc @@ -23,8 +23,8 @@ limitations under the License. 
#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Location.h" @@ -38,7 +38,7 @@ namespace mlir { namespace mhlo { #define GEN_PASS_DEF_BROADCASTPROPAGATIONPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace { @@ -328,7 +328,7 @@ void transitivelyEraseUnusedSideEffectFreeOps(Operation *root, if (opsToEraseSet.count(op)) continue; // Erase only operations that are unused and free of side effects. - if (!MemoryEffectOpInterface::hasNoEffect(op) || + if (!isMemoryEffectFree(op) || !llvm::all_of(op->getUsers(), [opsToEraseSet](Operation *user) { return opsToEraseSet.count(user); })) { diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc similarity index 97% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc index 8e058be70f1..3bd2c477890 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc @@ -24,10 +24,9 @@ limitations under the License. #include #include "llvm/ADT/SmallVector.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h" -#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" -#include "mlir-hlo/utils/hlo_utils.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/map_chlo_to_hlo_op.h" +#include "mhlo/transforms/rewriters.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" @@ -42,6 +41,7 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" #include "stablehlo/dialect/BroadcastUtils.h" #include "stablehlo/dialect/ChloOps.h" +#include "utils/hlo_utils.h" namespace mlir { namespace chlo { @@ -311,19 +311,19 @@ Value materializeErfcApproximationF64ForMagnituteGeOne( Value expZ = rewriter.create(loc, z); Value absX = rewriter.create(loc, x); Value polP = materializePolynomialApproximation( - rewriter, loc, absX, llvm::makeArrayRef(kErfcPCoefficients)); + rewriter, loc, absX, llvm::ArrayRef(kErfcPCoefficients)); Value expZMulPolyP = rewriter.create(loc, expZ, polP); Value polQ = materializePolynomialApproximation( - rewriter, loc, absX, llvm::makeArrayRef(kErfcQCoefficients)); + rewriter, loc, absX, llvm::ArrayRef(kErfcQCoefficients)); Value erfcApprox18 = rewriter.create(loc, expZMulPolyP, polQ); // Materialize polynomial approximation for x in >= 8 as // erfc(x) exp(z) R(|x|) / S(|x|). 
Value polR = materializePolynomialApproximation( - rewriter, loc, absX, llvm::makeArrayRef(kErfcRCoefficients)); + rewriter, loc, absX, llvm::ArrayRef(kErfcRCoefficients)); Value expZMulPolyR = rewriter.create(loc, expZ, polR); Value polS = materializePolynomialApproximation( - rewriter, loc, absX, llvm::makeArrayRef(kErfcSCoefficients)); + rewriter, loc, absX, llvm::ArrayRef(kErfcSCoefficients)); Value erfcApprox8Inf = rewriter.create(loc, expZMulPolyR, polS); // Combine polynomial approximations for x >= 1. @@ -375,10 +375,10 @@ Value materializeErfApproximationF64ForMagnituteLeOne( // erf(x) = x T(x^2) / U(x^2). Value xSq = rewriter.create(loc, x, x); Value polyT = materializePolynomialApproximation( - rewriter, loc, xSq, llvm::makeArrayRef(kErfTCoefficients)); + rewriter, loc, xSq, llvm::ArrayRef(kErfTCoefficients)); Value xMulPolyT = rewriter.create(loc, x, polyT); Value polyU = materializePolynomialApproximation( - rewriter, loc, xSq, llvm::makeArrayRef(kErfUCoefficients)); + rewriter, loc, xSq, llvm::ArrayRef(kErfUCoefficients)); return rewriter.create(loc, xMulPolyT, polyU); } @@ -475,9 +475,9 @@ Value materializeErfcApproximationF32ForMagnitudeGeOne( Value absXLtTwo = rewriter.create( loc, absX, two, mhlo::ComparisonDirection::LT); Value polP = materializePolynomialApproximation( - rewriter, loc, reciprocalXSq, llvm::makeArrayRef(kErfcPCoefficients)); + rewriter, loc, reciprocalXSq, llvm::ArrayRef(kErfcPCoefficients)); Value polR = materializePolynomialApproximation( - rewriter, loc, reciprocalXSq, llvm::makeArrayRef(kErfcRCoefficients)); + rewriter, loc, reciprocalXSq, llvm::ArrayRef(kErfcRCoefficients)); Value poly = rewriter.create(loc, absXLtTwo, polP, polR); Value erfcApprox = rewriter.create(loc, expZMulOneDivAbsX, poly); @@ -519,7 +519,7 @@ Value materializeErfApproximationF32ForMagnitudeLeOne( // erf(x) = x T(x^2). Value xSq = rewriter.create(loc, x, x); Value polyT = materializePolynomialApproximation( - rewriter, loc, xSq, llvm::makeArrayRef(kErfTCoefficients)); + rewriter, loc, xSq, llvm::ArrayRef(kErfTCoefficients)); return rewriter.create(loc, x, polyT); } @@ -547,10 +547,10 @@ Value materializeErfApproximationF32(ConversionPatternRewriter &rewriter, // Materialize polynomial approximation for x in [-4, 4] as // erf(x) = x * Alpha(x^2) / Beta(x^2). 
- Value alphaPoly = materializePolynomialApproximation( - rewriter, loc, xSq, llvm::makeArrayRef(kAlpha)); - Value betaPoly = materializePolynomialApproximation( - rewriter, loc, xSq, llvm::makeArrayRef(kBeta)); + Value alphaPoly = materializePolynomialApproximation(rewriter, loc, xSq, + llvm::ArrayRef(kAlpha)); + Value betaPoly = materializePolynomialApproximation(rewriter, loc, xSq, + llvm::ArrayRef(kBeta)); Value xMulAlphaPoly = rewriter.create(loc, x, alphaPoly); return rewriter.create(loc, xMulAlphaPoly, betaPoly); } @@ -1337,26 +1337,6 @@ struct ConvertSinhOp : public OpConversionPattern { } }; -Value materializeTan(ConversionPatternRewriter &rewriter, Location loc, - ValueRange operands) { - TanOp::Adaptor transformed(operands); - return rewriter.create( - loc, rewriter.create(loc, transformed.getOperand()), - rewriter.create(loc, transformed.getOperand())); -} - -struct ConvertTanOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite( - TanOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOp( - op, materializeWithUpcast(rewriter, op.getLoc(), adaptor.getOperands(), - rewriter.getF32Type(), &materializeTan)); - return success(); - } -}; - // Converts chlo.top_k to MHLO iota, sort, and slice ops. // // chlo.top_k sorts along last dimension of the input tensor and then returns @@ -1397,7 +1377,7 @@ struct ConvertTopKOp : public OpConversionPattern { int64_t operandRank = operandType.getRank(); int64_t lastDimIndex = operandRank - 1; int64_t lastDimSize = operandType.getDimSize(lastDimIndex); - assert(lastDimSize != ShapedType::kDynamicSize); + assert(lastDimSize != ShapedType::kDynamic); // Create an Iota op for indices. auto i32Type = rewriter.getIntegerType(32); @@ -1697,7 +1677,7 @@ class ConvertDynamicReshapeOp } }; -#include "generated_chlo_legalize_to_hlo.inc" +#include "chlo_legalize_to_hlo/generated_chlo_legalize_to_hlo.inc" } // namespace void populateChloBroadcastingPatterns(MLIRContext *context, @@ -1729,7 +1709,6 @@ void populateDecomposeChloPatterns(MLIRContext *context, ConvertNextAfterOp, ConvertPolygammaOp, ConvertSinhOp, - ConvertTanOp, ConvertTopKOp, ConvertZetaOp>(context); // clang-format on diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_pass.cc similarity index 93% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_pass.cc index 0cfa79b48f9..9ef25181fe1 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_pass.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/passes.h" +#include "mhlo/transforms/rewriters.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" @@ -28,7 +28,7 @@ namespace mlir { namespace mhlo { #define GEN_PASS_DEF_CHLOLEGALIZETOHLOPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace { diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_patterns.td similarity index 57% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_patterns.td index 0559e595b6c..3090ef17a48 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_patterns.td @@ -21,11 +21,11 @@ limitations under the License. include "mlir/Dialect/Shape/IR/ShapeOps.td" include "mlir/IR/OpBase.td" -include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" +include "mhlo/IR/hlo_ops.td" include "stablehlo/dialect/ChloOps.td" -class HLO_ComparisonDirectionValue : - ConstantAttr; +class MHLO_ComparisonDirectionValue : + ConstantAttr; //===----------------------------------------------------------------------===// // Unary op patterns. @@ -39,29 +39,29 @@ class HLO_ComparisonDirectionValue : // using the following formula. // acos(x) = -(i * log(x + i * sqrt((1 + x) * (1 - x)))) def : Pat<(CHLO_AcosOp NonComplexElementType:$input), - (HLO_SelectOp - (HLO_CompareOp + (MHLO_SelectOp + (MHLO_CompareOp $input, - (HLO_ConstantLike<"-1"> $input), - HLO_ComparisonDirectionValue<"NE">, - (HLO_DEFAULT_COMPARISON_TYPE) + (MHLO_ConstantLike<"-1"> $input), + MHLO_ComparisonDirectionValue<"NE">, + (MHLO_DEFAULT_COMPARISON_TYPE) ), - (HLO_MulOp - (HLO_ConstantLike<"2"> $input), - (HLO_Atan2Op - (HLO_SqrtOp - (HLO_SubtractOp - (HLO_ConstantLike<"1"> $input), - (HLO_MulOp $input, $input) + (MHLO_MulOp + (MHLO_ConstantLike<"2"> $input), + (MHLO_Atan2Op + (MHLO_SqrtOp + (MHLO_SubtractOp + (MHLO_ConstantLike<"1"> $input), + (MHLO_MulOp $input, $input) ) ), - (HLO_AddOp - (HLO_ConstantLike<"1"> $input), + (MHLO_AddOp + (MHLO_ConstantLike<"1"> $input), $input ) ) ), - (HLO_ConstantLike<"M_PI"> $input) + (MHLO_ConstantLike<"M_PI"> $input) )>; // Expand acosh to MHLO dialect as follows: @@ -73,40 +73,40 @@ def : Pat<(CHLO_AcosOp NonComplexElementType:$input), // log(2*x) = log(2) + log(x). (Note this works because negative x never // overflows; x < -1 simply yields nan. 
def : Pat<(CHLO_AcoshOp NonComplexElementType:$input), - (HLO_SelectOp - (HLO_CompareOp + (MHLO_SelectOp + (MHLO_CompareOp $input, - (HLO_ConstantLike<"-1"> $input), - HLO_ComparisonDirectionValue<"LT">, - (HLO_DEFAULT_COMPARISON_TYPE) + (MHLO_ConstantLike<"-1"> $input), + MHLO_ComparisonDirectionValue<"LT">, + (MHLO_DEFAULT_COMPARISON_TYPE) ), - (HLO_ConstantLike<"NAN"> $input), - (HLO_SelectOp - (HLO_CompareOp + (MHLO_ConstantLike<"NAN"> $input), + (MHLO_SelectOp + (MHLO_CompareOp $input, - (HLO_SqrtOp - (HLO_ConstantLikeMaxFiniteValue $input) + (MHLO_SqrtOp + (MHLO_ConstantLikeMaxFiniteValue $input) ), - HLO_ComparisonDirectionValue<"GE">, - (HLO_DEFAULT_COMPARISON_TYPE) + MHLO_ComparisonDirectionValue<"GE">, + (MHLO_DEFAULT_COMPARISON_TYPE) ), - (HLO_AddOp - (HLO_LogOp $input), - (HLO_LogOp - (HLO_ConstantLike<"2"> $input) + (MHLO_AddOp + (MHLO_LogOp $input), + (MHLO_LogOp + (MHLO_ConstantLike<"2"> $input) ) ), - (HLO_LogOp - (HLO_AddOp + (MHLO_LogOp + (MHLO_AddOp $input, - (HLO_SqrtOp - (HLO_MulOp - (HLO_AddOp - (HLO_ConstantLike<"1"> $input), + (MHLO_SqrtOp + (MHLO_MulOp + (MHLO_AddOp + (MHLO_ConstantLike<"1"> $input), $input ), - (HLO_AddOp - (HLO_ConstantLike<"-1"> $input), + (MHLO_AddOp + (MHLO_ConstantLike<"-1"> $input), $input ) ) @@ -124,18 +124,18 @@ def : Pat<(CHLO_AcoshOp NonComplexElementType:$input), // complex type, because we don't yet have exhaustive tests for complex trig // functions". def : Pat<(CHLO_AcoshOp ComplexElementType:$input), - (HLO_LogOp - (HLO_AddOp + (MHLO_LogOp + (MHLO_AddOp $input, - (HLO_SqrtOp - (HLO_MulOp - (HLO_AddOp + (MHLO_SqrtOp + (MHLO_MulOp + (MHLO_AddOp $input, - (HLO_ConstantLike<"1"> $input) + (MHLO_ConstantLike<"1"> $input) ), - (HLO_SubtractOp + (MHLO_SubtractOp $input, - (HLO_ConstantLike<"1"> $input) + (MHLO_ConstantLike<"1"> $input) ) ) ) @@ -146,16 +146,16 @@ def : Pat<(CHLO_AcoshOp ComplexElementType:$input), // Expand asin to MHLO dialect as follows: // asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))) def : Pat<(CHLO_AsinOp $input), - (HLO_MulOp - (HLO_ConstantLike<"2"> $input), - (HLO_Atan2Op + (MHLO_MulOp + (MHLO_ConstantLike<"2"> $input), + (MHLO_Atan2Op $input, - (HLO_AddOp - (HLO_ConstantLike<"1"> $input), - (HLO_SqrtOp - (HLO_SubtractOp - (HLO_ConstantLike<"1"> $input), - (HLO_MulOp $input, $input) + (MHLO_AddOp + (MHLO_ConstantLike<"1"> $input), + (MHLO_SqrtOp + (MHLO_SubtractOp + (MHLO_ConstantLike<"1"> $input), + (MHLO_MulOp $input, $input) ) ) ) @@ -183,48 +183,48 @@ def : Pat<(CHLO_AsinOp $input), // the result as x + abs(x) = 0 but we are saved by the fact that asinh(-x) = // -asinh(x). 
def : Pat<(CHLO_AsinhOp NonComplexElementType:$input), - (HLO_MulOp - (HLO_SignOp $input), - (HLO_SelectOp - (HLO_CompareOp - (HLO_AbsOp $input), - (HLO_SqrtOp - (HLO_ConstantLikeMaxFiniteValue $input) + (MHLO_MulOp + (MHLO_SignOp $input), + (MHLO_SelectOp + (MHLO_CompareOp + (MHLO_AbsOp $input), + (MHLO_SqrtOp + (MHLO_ConstantLikeMaxFiniteValue $input) ), - HLO_ComparisonDirectionValue<"GE">, - (HLO_DEFAULT_COMPARISON_TYPE) + MHLO_ComparisonDirectionValue<"GE">, + (MHLO_DEFAULT_COMPARISON_TYPE) ), - (HLO_AddOp - (HLO_LogOp - (HLO_AbsOp $input) + (MHLO_AddOp + (MHLO_LogOp + (MHLO_AbsOp $input) ), - (HLO_LogOp - (HLO_ConstantLike<"2"> $input) + (MHLO_LogOp + (MHLO_ConstantLike<"2"> $input) ) ), - (HLO_SelectOp - (HLO_CompareOp - (HLO_AbsOp $input), - (HLO_ConstantLike<"1"> $input), - HLO_ComparisonDirectionValue<"LE">, - (HLO_DEFAULT_COMPARISON_TYPE) + (MHLO_SelectOp + (MHLO_CompareOp + (MHLO_AbsOp $input), + (MHLO_ConstantLike<"1"> $input), + MHLO_ComparisonDirectionValue<"LE">, + (MHLO_DEFAULT_COMPARISON_TYPE) ), - (HLO_Log1pOp - (HLO_AddOp - (HLO_AbsOp $input), - (HLO_MulOp - (HLO_AbsOp $input), - (HLO_DivOp - (HLO_AbsOp $input), - (HLO_AddOp - (HLO_ConstantLike<"1"> $input), - (HLO_SqrtOp - (HLO_AddOp - (HLO_MulOp - (HLO_AbsOp $input), - (HLO_AbsOp $input) + (MHLO_Log1pOp + (MHLO_AddOp + (MHLO_AbsOp $input), + (MHLO_MulOp + (MHLO_AbsOp $input), + (MHLO_DivOp + (MHLO_AbsOp $input), + (MHLO_AddOp + (MHLO_ConstantLike<"1"> $input), + (MHLO_SqrtOp + (MHLO_AddOp + (MHLO_MulOp + (MHLO_AbsOp $input), + (MHLO_AbsOp $input) ), - (HLO_ConstantLike<"1"> $input) + (MHLO_ConstantLike<"1"> $input) ) ) ) @@ -232,16 +232,16 @@ def : Pat<(CHLO_AsinhOp NonComplexElementType:$input), ) ) ), - (HLO_LogOp - (HLO_AddOp - (HLO_AbsOp $input), - (HLO_SqrtOp - (HLO_AddOp - (HLO_MulOp - (HLO_AbsOp $input), - (HLO_AbsOp $input) + (MHLO_LogOp + (MHLO_AddOp + (MHLO_AbsOp $input), + (MHLO_SqrtOp + (MHLO_AddOp + (MHLO_MulOp + (MHLO_AbsOp $input), + (MHLO_AbsOp $input) ), - (HLO_ConstantLike<"1"> $input) + (MHLO_ConstantLike<"1"> $input) ) ) ) @@ -258,13 +258,13 @@ def : Pat<(CHLO_AsinhOp NonComplexElementType:$input), // complex type, because we don't yet have exhaustive tests for complex trig // functions". 
def : Pat<(CHLO_AsinhOp ComplexElementType:$input), - (HLO_LogOp - (HLO_AddOp + (MHLO_LogOp + (MHLO_AddOp $input, - (HLO_SqrtOp - (HLO_AddOp - (HLO_MulOp $input, $input), - (HLO_ConstantLike<"1"> $input) + (MHLO_SqrtOp + (MHLO_AddOp + (MHLO_MulOp $input, $input), + (MHLO_ConstantLike<"1"> $input) ) ) ) @@ -273,31 +273,31 @@ def : Pat<(CHLO_AsinhOp ComplexElementType:$input), // Express `atan` as // atan(x) = atan2(x, 1) def : Pat<(CHLO_AtanOp $input), - (HLO_Atan2Op + (MHLO_Atan2Op $input, - (HLO_ConstantLike<"1"> $input) + (MHLO_ConstantLike<"1"> $input) )>; // Express `atanh` for non-complex arguments as follows: // atanh(x) = 0.5 * log((1 + x) / (1 - x)) if abs(x) <= 1 // atanh(x) = nan otherwise def : Pat<(CHLO_AtanhOp NonComplexElementType:$input), - (HLO_SelectOp - (HLO_CompareOp - (HLO_AbsOp $input), - (HLO_ConstantLike<"1"> $input), - HLO_ComparisonDirectionValue<"GT">, - (HLO_DEFAULT_COMPARISON_TYPE) + (MHLO_SelectOp + (MHLO_CompareOp + (MHLO_AbsOp $input), + (MHLO_ConstantLike<"1"> $input), + MHLO_ComparisonDirectionValue<"GT">, + (MHLO_DEFAULT_COMPARISON_TYPE) ), - (HLO_ConstantLike<"NAN"> $input), - (HLO_MulOp - (HLO_SubtractOp - (HLO_Log1pOp $input), - (HLO_Log1pOp - (HLO_NegOp $input) + (MHLO_ConstantLike<"NAN"> $input), + (MHLO_MulOp + (MHLO_SubtractOp + (MHLO_Log1pOp $input), + (MHLO_Log1pOp + (MHLO_NegOp $input) ) ), - (HLO_ConstantLike<"0.5"> $input) + (MHLO_ConstantLike<"0.5"> $input) ) )>; @@ -308,47 +308,50 @@ def : Pat<(CHLO_AtanhOp NonComplexElementType:$input), // "For now, we ignore the nan edge case for complex inputs, // because we don't yet have exhaustive tests for complex trig functions". def : Pat<(CHLO_AtanhOp ComplexElementType:$input), - (HLO_MulOp - (HLO_SubtractOp - (HLO_Log1pOp $input), - (HLO_Log1pOp - (HLO_NegOp $input) + (MHLO_MulOp + (MHLO_SubtractOp + (MHLO_Log1pOp $input), + (MHLO_Log1pOp + (MHLO_NegOp $input) ) ), - (HLO_ConstantLike<"0.5"> $input) + (MHLO_ConstantLike<"0.5"> $input) )>; // Express `conj` as // conj(x) = (re(x), -im(x)). 
def : Pat<(CHLO_ConjOp $v), - (HLO_ComplexOp (HLO_RealOp $v), (HLO_NegOp (HLO_ImagOp $v)))>; + (MHLO_ComplexOp (MHLO_RealOp $v), (MHLO_NegOp (MHLO_ImagOp $v)))>; // Express `is_inf` as // is_inf(x) = is_pos_inf(|x|) def : Pat<(CHLO_IsInfOp NonComplexElementType:$input), (CHLO_IsPosInfOp - (HLO_AbsOp $input) + (MHLO_AbsOp $input) )>; // Express `is_pos_inf` as // is_pos_inf(x) = (x == +inf) def : Pat<(CHLO_IsPosInfOp NonComplexElementType:$input), - (HLO_CompareOp + (MHLO_CompareOp $input, - (HLO_ConstantLikePosInfValue $input), - HLO_ComparisonDirectionValue<"EQ">, - (HLO_DEFAULT_COMPARISON_TYPE) + (MHLO_ConstantLikePosInfValue $input), + MHLO_ComparisonDirectionValue<"EQ">, + (MHLO_DEFAULT_COMPARISON_TYPE) )>; // Express `is_neg_inf` as // is_neg_inf(x) = (x == -inf) def : Pat<(CHLO_IsNegInfOp NonComplexElementType:$input), - (HLO_CompareOp + (MHLO_CompareOp $input, - (HLO_ConstantLikeNegInfValue $input), - HLO_ComparisonDirectionValue<"EQ">, - (HLO_DEFAULT_COMPARISON_TYPE) + (MHLO_ConstantLikeNegInfValue $input), + MHLO_ComparisonDirectionValue<"EQ">, + (MHLO_DEFAULT_COMPARISON_TYPE) )>; def : Pat<(CHLO_ConstantOp $v), - (HLO_ConstantOp $v)>; + (MHLO_ConstantOp $v)>; + +def : Pat<(CHLO_TanOp $v), + (MHLO_TanOp $v)>; diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/collapse_elementwise_map.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/collapse_elementwise_map/collapse_elementwise_map.cc similarity index 92% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/collapse_elementwise_map.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/collapse_elementwise_map/collapse_elementwise_map.cc index cf8b2309ae0..adc3937a9e0 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/collapse_elementwise_map.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/collapse_elementwise_map/collapse_elementwise_map.cc @@ -15,11 +15,11 @@ limitations under the License. 
#include -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/passes.h" +#include "mhlo/transforms/rewriters.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/IRMapping.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -28,7 +28,7 @@ namespace mlir { namespace mhlo { #define GEN_PASS_DEF_COLLAPSEELEMENTWISEMAPPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace { @@ -58,7 +58,7 @@ struct ConvertMapOfElementwiseOps : public OpRewritePattern { } rewriter.setInsertionPointAfter(map); - BlockAndValueMapping blockAndValueMap; + IRMapping blockAndValueMap; for (mlir::BlockArgument barg : map.getComputation().front().getArguments()) { blockAndValueMap.map(barg, map->getOperand(barg.getArgNumber())); diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/constraint_fusion_pass.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/constraint_fusion/constraint_fusion_pass.cc similarity index 99% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/constraint_fusion_pass.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/constraint_fusion/constraint_fusion_pass.cc index 5ddb16cc50d..4dae80389f5 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/constraint_fusion_pass.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/constraint_fusion/constraint_fusion_pass.cc @@ -21,7 +21,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mhlo/transforms/passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Pass/Pass.h" @@ -30,7 +30,7 @@ namespace mlir { namespace mhlo { #define GEN_PASS_DEF_CONSTRAINTFUSIONPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace { diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/convert_to_signless_pass.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/convert_to_signless/convert_to_signless_pass.cc similarity index 96% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/convert_to_signless_pass.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/convert_to_signless/convert_to_signless_pass.cc index 2ca30707a0d..3037b380d07 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/convert_to_signless_pass.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/convert_to_signless/convert_to_signless_pass.cc @@ -21,8 +21,8 @@ limitations under the License. 
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "mlir-hlo/Dialect/mhlo/transforms/type_conversion.h" +#include "mhlo/transforms/passes.h" +#include "mhlo/utils/type_conversion.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Attributes.h" @@ -40,7 +40,7 @@ namespace mlir { namespace mhlo { #define GEN_PASS_DEF_CONVERTTOSIGNLESSPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace { diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/expand_hlo_tuples.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/expand_hlo_tuples/expand_hlo_tuples.cc similarity index 89% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/expand_hlo_tuples.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/expand_hlo_tuples/expand_hlo_tuples.cc index fc6401ec083..d9a3e6d9520 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/expand_hlo_tuples.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/expand_hlo_tuples/expand_hlo_tuples.cc @@ -13,14 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" @@ -36,7 +37,7 @@ namespace mlir { namespace mhlo { #define GEN_PASS_DEF_EXPANDHLOTUPLESPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace { @@ -102,27 +103,27 @@ class ExpandHloTuplesPass // Update output signatures. auto returnOp = cast(func.getBody().back().back()); + OpBuilder builder(returnOp); // Expand all tuples in old return operands. 
SmallVector expandedReturnOperands; SmallVector expandedResultTypes; for (auto value : returnOp.getOperands()) { - auto tuple = dyn_cast_or_null(value.getDefiningOp()); - if (!tuple) { + if (auto tupleTy = value.getType().dyn_cast()) { + llvm::copy(tupleTy.getTypes(), std::back_inserter(expandedResultTypes)); + for (auto [index, ty] : llvm::enumerate(tupleTy.getTypes())) { + expandedReturnOperands.push_back( + builder.createOrFold(value.getLoc(), ty, + value, index)); + } + } else { expandedReturnOperands.push_back(value); expandedResultTypes.push_back(value.getType()); - continue; - } - - for (auto tupleOperand : tuple.getOperands()) { - expandedReturnOperands.push_back(tupleOperand); - expandedResultTypes.push_back(tupleOperand.getType()); } } if (returnOp.getOperands() == expandedReturnOperands) return; - OpBuilder builder(returnOp); builder.create(returnOp.getLoc(), expandedReturnOperands); returnOp.erase(); diff --git a/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/expand_ops_simplifier/expand_ops_simplifier.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/expand_ops_simplifier/expand_ops_simplifier.cc new file mode 100644 index 00000000000..1a47c0bb72f --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/expand_ops_simplifier/expand_ops_simplifier.cc @@ -0,0 +1,229 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file replaces some complicated HLOs such as SelectAndScatter with a +// sequence of simpler HLOs. 
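+//
+// Currently the only rewrite implemented in this file expands
+// mhlo.select_and_scatter into a sequence of iota, reduce_window and scatter
+// ops; see SelectAndScatterExpanderPattern below.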
+ +#include +#include +#include +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "mhlo/IR/hlo_ops.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace mhlo { + +#define GEN_PASS_DEF_MHLOEXPANDOPSSIMPLIFIERPASS +#include "mhlo/transforms/mhlo_passes.h.inc" + +namespace { + +ShapedType getScalarizedType(ShapedType t) { + return t.cloneWith(llvm::ArrayRef(std::nullopt), t.getElementType()); +} + +struct SelectAndScatterExpanderPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SelectAndScatterOp sas, + PatternRewriter& rewriter) const override { + // Capture original values with variables + ImplicitLocOpBuilder builder(sas.getLoc(), rewriter); + TypedValue operand = sas.getOperand(); + llvm::ArrayRef operandShape = operand.getType().getShape(); + TypedValue source = sas.getSource(); + Value initValue = sas.getInitValue(); + Region& select = sas.getSelect(); + Region& scatter = sas.getScatter(); + TensorType sasType = sas.getType(); + + // Useful shapes + const auto iotaShape = + operand.getType().cloneWith(operandShape, rewriter.getI64Type()); + const auto sourceShape = source.getType().getShape(); + const auto iotaShapeReduced = + source.getType().cloneWith(sourceShape, rewriter.getI64Type()); + const auto scalarIota = getScalarizedType(iotaShapeReduced); + + // Construct one iota for each dimension. This will reduced in the reduction + // to determine the indices to be scattered to. + llvm::SmallVector iotas; + iotas.reserve(operandShape.size()); + for (size_t i = 0; i < operandShape.size(); ++i) { + iotas.push_back(builder.create(iotaShape, i)); + } + + // ReduceWindow arguments + auto numReduceValues = iotas.size() + 1; + auto negOne = builder.create( + mlir::DenseIntElementsAttr::get(scalarIota, (uint64_t)-1)); + llvm::SmallVector reduceInitValues(numReduceValues, negOne); + reduceInitValues.front() = initValue; + + // ReduceWindow arguments + llvm::SmallVector ops; + ops.reserve(numReduceValues); + ops.push_back(operand); + ops.insert(ops.end(), iotas.begin(), iotas.end()); + + // Construct ReduceWindow and its region. + auto reduceWindow = builder.create( + ops, reduceInitValues, sas.getWindowDimensionsAttr(), + sas.getWindowStridesAttr(), /*dilations=*/nullptr, + /*dilations=*/nullptr, sas.getPaddingAttr(), + [&](OpBuilder& b, Location loc, ValueRange /*values*/) { + ImplicitLocOpBuilder builder(loc, b); + Block* block = b.getBlock(); + auto rhsBegin = static_cast(numReduceValues); + auto lhsBegin = 0; + auto firstIota = 1; + Value firstLhsIota = block->getArgument(firstIota); + Value firstRhsIota = block->getArgument(firstIota + rhsBegin); + Value lhsFirstInWindow = builder.create( + firstLhsIota, negOne, mhlo::ComparisonDirection::NE); + // Current implementations of ReduceWindow do not need the following + // line in their implementations, but it is actually required in the + // documented behavior of the implementation which allows the seed + // value to occur on both lhs and rhs sides when padding occurs. 
+ Value rhsFirstInWindow = builder.create( + firstRhsIota, negOne, mhlo::ComparisonDirection::NE); + auto rhsNotFirstInWindow = + builder.create(rhsFirstInWindow); + + Value operandLhs = block->getArgument(0); + Value operandRhs = block->getArgument(rhsBegin); + llvm::SmallVector selectIns; + selectIns.push_back(operandLhs); + selectIns.push_back(operandRhs); + rewriter.mergeBlocks(&select.front(), block, selectIns); + Value call = block->back().getOperand(0); + rewriter.eraseOp(&block->back()); + + Value pred = builder.create(call, lhsFirstInWindow); + pred = builder.create(pred, rhsNotFirstInWindow); + + llvm::SmallVector resultTuple; + for (auto i = lhsBegin; i < rhsBegin; ++i) { + Value iotaLhs = block->getArgument(i); + Value iotaRhs = block->getArgument(i + rhsBegin); + resultTuple.push_back( + builder.create(pred, iotaLhs, iotaRhs)); + } + builder.create(resultTuple); + }); + + // Handle the results of the reduction + llvm::SmallVector iotaIndices; + llvm::SmallVector broadcastedIotaDims; + broadcastedIotaDims.reserve(iotaShapeReduced.getRank() + 1); + broadcastedIotaDims.insert(broadcastedIotaDims.end(), + iotaShapeReduced.getShape().begin(), + iotaShapeReduced.getShape().end()); + broadcastedIotaDims.push_back(1); + auto broadcastedIotaShape = RankedTensorType::get( + broadcastedIotaDims, iotaShapeReduced.getElementType()); + + for (size_t i = 1; i < numReduceValues; ++i) { + Value element = reduceWindow.getResult(i); + iotaIndices.push_back( + builder.create(broadcastedIotaShape, element) + .getResult()); + } + + // Prepare scatter inputs + llvm::SmallVector scatterDims(operandShape.size()); + std::iota(scatterDims.begin(), scatterDims.end(), 0); + Value broadcastedInitValue = builder.create( + initValue, mlir::DenseIntElementsAttr::get( + RankedTensorType::get(sasType.getShape().size(), + rewriter.getIntegerType(64, true)), + sasType.getShape())); + + llvm::SmallVector concatenatedIotasDims; + concatenatedIotasDims.reserve( + iotaIndices.front().getType().cast().getRank()); + concatenatedIotasDims.insert(concatenatedIotasDims.end(), + broadcastedIotaDims.begin(), + broadcastedIotaDims.end()); + concatenatedIotasDims.back() = static_cast(iotaIndices.size()); + Value indices = builder.create( + RankedTensorType::get(concatenatedIotasDims, + iotaShape.getElementType()), + iotaIndices, iotaShape.getRank()); + + // Scatter + auto dimNums = mhlo::ScatterDimensionNumbersAttr::get( + sas->getContext(), + /*updateWindowDims=*/{}, + /*insertedWindowDims=*/scatterDims, + /*scatterDimsToOperandDims=*/scatterDims, + /*indexVectorDim=*/source.getType().getRank()); + auto scatterOp = builder.create( + /*shape=*/sasType, /*operand=*/broadcastedInitValue, + /*scatter_indices=*/indices, /*updates=*/source, + /*scatter_dim_numbers=*/dimNums, + /*indices_are_sorted=*/false, /*unique_indices=*/false); + + // Prepare ScatterOp block and then copy SelectAndScatter's body + llvm::SmallVector scatterIns; + llvm::SmallVector scatterLocs; + scatterIns.push_back(RankedTensorType::get( + {}, + broadcastedInitValue.getType().cast().getElementType())); + scatterIns.push_back( + RankedTensorType::get({}, source.getType().getElementType())); + scatterLocs.push_back(broadcastedInitValue.getLoc()); + scatterLocs.push_back(source.getLoc()); + + rewriter.inlineRegionBefore(scatter, scatterOp.getUpdateComputation(), + scatterOp.getUpdateComputation().end()); + rewriter.replaceOp(sas, scatterOp.getResults()); + return success(); + } +}; + +struct MhloExpandOpsSimplifierPass + : impl::MhloExpandOpsSimplifierPassBase { 
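+  // Registers the SelectAndScatterExpanderPattern defined above and applies
+  // it with the greedy pattern rewriter, so every mhlo.select_and_scatter in
+  // the function is rewritten into the iota/reduce_window/scatter sequence
+  // built by that pattern.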
+ void runOnOperation() override { + auto* ctx = &getContext(); + RewritePatternSet patterns(ctx); + patterns.add(ctx); + + if (failed( + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr> +createMhloExpandOpsSimplifierPass() { + return std::make_unique(); +} + +} // namespace mhlo +} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/group_reduction_dimensions.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/group_reduction_dimensions/group_reduction_dimensions.cc similarity index 98% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/group_reduction_dimensions.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/group_reduction_dimensions/group_reduction_dimensions.cc index 3c21fdcc460..b7a9718a4cd 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/group_reduction_dimensions.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/group_reduction_dimensions/group_reduction_dimensions.cc @@ -20,9 +20,9 @@ limitations under the License. #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/passes.h" +#include "mhlo/transforms/rewriters.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinTypes.h" @@ -34,7 +34,7 @@ namespace mlir { namespace mhlo { #define GEN_PASS_DEF_GROUPREDUCTIONDIMENSIONSPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace { diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/hlo_legalize_shape_ops_to_standard.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/hlo_legalize_shape_ops_to_standard/hlo_legalize_shape_ops_to_standard.cc similarity index 96% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/hlo_legalize_shape_ops_to_standard.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/hlo_legalize_shape_ops_to_standard/hlo_legalize_shape_ops_to_standard.cc index 024346d993a..e539564c87d 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/hlo_legalize_shape_ops_to_standard.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/hlo_legalize_shape_ops_to_standard/hlo_legalize_shape_ops_to_standard.cc @@ -20,9 +20,9 @@ limitations under the License. 
#include #include -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" -#include "mlir-hlo/Dialect/mhlo/transforms/type_conversion.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/rewriters.h" +#include "mhlo/utils/type_conversion.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Shape/IR/Shape.h" @@ -44,7 +44,7 @@ namespace mlir { namespace mhlo { #define GEN_PASS_DEF_HLOLEGALIZESHAPEOPSTOSTANDARDPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace { @@ -137,7 +137,7 @@ struct CstrReshapableConversion : rewriter.create( loc, extentType, adaptor.getOperands()[1]); auto reduction = rewriter.create( - loc, newShape, llvm::makeArrayRef({one, zero, zero})); + loc, newShape, llvm::ArrayRef({one, zero, zero})); { PatternRewriter::InsertionGuard g(rewriter); auto* body = reduction.getBody(); @@ -158,7 +158,7 @@ struct CstrReshapableConversion Value totalElements = rewriter.create( loc, extentOrOne, body->getArgument(2)); rewriter.create( - loc, llvm::makeArrayRef({totalElements, totalDynamic, totalInvalid})); + loc, llvm::ArrayRef({totalElements, totalDynamic, totalInvalid})); } // Avoid division by zero. Value isZeroElements = rewriter.create( diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_arithmetic.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_arithmetic/hlo_legalize_to_arithmetic.cc similarity index 92% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_arithmetic.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_arithmetic/hlo_legalize_to_arithmetic.cc index 057ee6067bd..955182f3b52 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_arithmetic.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_arithmetic/hlo_legalize_to_arithmetic.cc @@ -18,10 +18,10 @@ limitations under the License. 
#include #include -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/map_mhlo_to_scalar_op.h" +#include "mhlo/transforms/passes.h" +#include "mhlo/transforms/rewriters.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -33,7 +33,7 @@ namespace mlir { namespace mhlo { #define GEN_PASS_DEF_HLOLEGALIZETOARITHMETICPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace { @@ -111,11 +111,18 @@ struct RngGetAndUpdateStatePattern template struct ScalarHloToArithmeticPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + ScalarHloToArithmeticPattern( + TypeConverter& typeConverter, MLIRContext* context, + llvm::function_ref filterFn = nullptr, + PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit), + filterFn(filterFn) {} LogicalResult matchAndRewrite( OpTy op, typename OpTy::Adaptor adaptor, ConversionPatternRewriter& rewriter) const final { + if (filterFn && !filterFn(op)) return failure(); + auto isScalar = [&](Value v) { return v.getType().cast().getRank() == 0; }; @@ -141,6 +148,9 @@ struct ScalarHloToArithmeticPattern : public OpConversionPattern { scalarResult); return success(); } + + private: + llvm::function_ref filterFn; }; struct HloLegalizeToArithmeticPass @@ -177,7 +187,8 @@ void populateHloToArithmeticConversionPatterns(RewritePatternSet* patterns) { void populateScalarHloToArithmeticConversionPatterns( MLIRContext* context, TypeConverter& typeConverter, - RewritePatternSet* patterns) { + RewritePatternSet* patterns, + llvm::function_ref filterFn) { // clang-format off patterns->add< ScalarHloToArithmeticPattern, @@ -225,9 +236,10 @@ void populateScalarHloToArithmeticConversionPatterns( ScalarHloToArithmeticPattern, ScalarHloToArithmeticPattern, ScalarHloToArithmeticPattern, + ScalarHloToArithmeticPattern, ScalarHloToArithmeticPattern, ScalarHloToArithmeticPattern - >(typeConverter, context); + >(typeConverter, context, filterFn); // clang-format on } diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_lhlo/hlo_legalize_to_lhlo.cc similarity index 94% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_lhlo/hlo_legalize_to_lhlo.cc index fc0a9a7cbe7..55b32099e6e 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_lhlo/hlo_legalize_to_lhlo.cc @@ -16,13 +16,14 @@ limitations under the License. // This file implements logic for lowering HLO dialect to LHLO dialect. 
#include +#include #include -#include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h" -#include "mlir-hlo/Dialect/lhlo/transforms/map_hlo_to_lhlo_op.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "lhlo/IR/lhlo_ops.h" +#include "lhlo/transforms/map_hlo_to_lhlo_op.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/passes.h" +#include "mhlo/transforms/rewriters.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" @@ -35,7 +36,7 @@ limitations under the License. #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" -#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/IRMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" @@ -50,7 +51,7 @@ namespace mlir { namespace mhlo { #define GEN_PASS_DEF_HLOLEGALIZETOLHLOPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace { @@ -60,17 +61,14 @@ using BaseOpConversion = OpConversionPattern; Value insertDynamicAlloc(Location loc, Value result, Value shapeOperand, ConversionPatternRewriter* rewriter) { auto resultType = result.getType().dyn_cast(); - if (!resultType) { - result.getDefiningOp()->emitOpError() - << "tensor to buffer conversion expects ranked results"; - } + assert(resultType); auto memrefType = MemRefType::get(resultType.getShape(), resultType.getElementType()); // Extract the required element out of the vector. SmallVector dynamicOperands; for (const auto& shapeElement : llvm::enumerate(resultType.getShape())) { - if (shapeElement.value() != ShapedType::kDynamicSize) continue; + if (shapeElement.value() != ShapedType::kDynamic) continue; Value index = rewriter->create(loc, shapeElement.index()); Value allocOperand = @@ -88,10 +86,7 @@ Value insertDynamicAlloc(Location loc, Value result, Value shapeOperand, Value insertAlloc(Location loc, OpResult result, ConversionPatternRewriter* rewriter) { auto resultType = result.getType().dyn_cast(); - if (!resultType || !resultType.hasStaticShape()) { - result.getDefiningOp()->emitOpError() - << "tensor to buffer conversion expects statically shaped results"; - } + assert(resultType && resultType.hasStaticShape()); auto memrefType = MemRefType::get(resultType.getShape(), resultType.getElementType()); OpBuilder::InsertionGuard guard(*rewriter); @@ -150,10 +145,11 @@ class HloToLhloOpConverter : public BaseOpConversion { Operation* op = hloOp.getOperation(); SmallVector bufferArgs(adaptor.getOperands()); if (failed(convertResults(op, bufferArgs, rewriter))) return failure(); - rewriter.create>(op->getLoc(), llvm::None, + rewriter.create>(op->getLoc(), std::nullopt, bufferArgs, op->getAttrs()); - rewriter.replaceOp(op, llvm::makeArrayRef(bufferArgs) - .drop_front(adaptor.getOperands().size())); + rewriter.replaceOp( + op, + llvm::ArrayRef(bufferArgs).drop_front(adaptor.getOperands().size())); return success(); } }; @@ -174,7 +170,7 @@ class HloToLhloOpConverter : public BaseOpConversion { SmallVector bufferArgs(adaptor.getOperands()); if (failed(convertResults(op, bufferArgs, rewriter))) return failure(); - auto dotOp = rewriter.create(op->getLoc(), llvm::None, + auto dotOp = rewriter.create(op->getLoc(), std::nullopt, bufferArgs, op->getAttrs()); // MHLO's Dot uses rank-2 operands, of the form 
([N, M], [M, O]) -> [N, O]. auto dimensionNumbers = mhlo::DotDimensionNumbersAttr::get( @@ -201,7 +197,7 @@ struct HloToLhloCustomCallOpConverter if (failed(convertResults(op, bufferArgs, rewriter))) return failure(); auto lhloOp = rewriter.create( - op->getLoc(), llvm::None, bufferArgs, op->getAttrs()); + op->getLoc(), std::nullopt, bufferArgs, op->getAttrs()); // Setup AttrSizedOperandSegments attribute to indicate number of operands // for args and outputs. const int32_t segments[2] = { @@ -248,7 +244,7 @@ struct HloToLhloDotGeneralOpConverter resultsShape.front(), &rewriter); } - rewriter.create(op->getLoc(), llvm::None, bufferArgs, + rewriter.create(op->getLoc(), std::nullopt, bufferArgs, op->getAttrs()); rewriter.replaceOp(op, bufferArgs[2]); return success(); @@ -273,7 +269,7 @@ struct HloToLhloReduceLikeOpConverter : public BaseOpConversion { SmallVector bufferArgs(adaptor.getOperands()); if (failed(convertResults(op, bufferArgs, rewriter))) return failure(); auto newOp = rewriter.create>( - loc, llvm::None, bufferArgs, op->getAttrs()); + loc, std::nullopt, bufferArgs, op->getAttrs()); // Copy over the operations inside the region. rewriter.inlineRegionBefore(hloOp.getBody(), newOp.getBody(), @@ -571,6 +567,7 @@ void populateHloToLhloConversionPattern( HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, + HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_memref.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_memref/hlo_legalize_to_memref.cc similarity index 70% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_memref.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_memref/hlo_legalize_to_memref.cc index 2071140afa7..5f95e643520 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_memref.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_memref/hlo_legalize_to_memref.cc @@ -17,12 +17,13 @@ limitations under the License. #include #include +#include #include -#include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/bufferizable_op_interface_impl.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "lhlo/IR/lhlo_ops.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/interfaces/bufferizable_op_interface_impl.h" +#include "mhlo/transforms/passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" @@ -37,10 +38,11 @@ namespace mlir { namespace mhlo { #define GEN_PASS_DEF_HLOLEGALIZETOMEMREFPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace { +using bufferization::AliasingOpResultList; using bufferization::AnalysisState; using bufferization::BufferizableOpInterface; using bufferization::BufferizationOptions; @@ -60,7 +62,7 @@ struct CustomCallOpInterface return false; // Arguments are read-only. 
} - SmallVector getAliasingOpResult(Operation *, OpOperand &, + AliasingOpResultList getAliasingOpResults(Operation *, OpOperand &, const AnalysisState &) const { return {}; } @@ -68,20 +70,33 @@ struct CustomCallOpInterface LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto customCallOp = cast(op); + Value tokenArgument; // Bufferize arguments. SmallVector bufferArgs; for (OpOperand &operand : customCallOp->getOpOperands()) { + auto &newBuffer = bufferArgs.emplace_back(); + if (operand.get().getType().isa()) { + // Remember the token for later. We need it for the return value but + // it's not getting passed to LMHLO. + if (tokenArgument) return failure(); + tokenArgument = operand.get(); + continue; + } if (!operand.get().getType().isa()) return failure(); FailureOr operandBuffer = getBuffer(rewriter, operand.get(), options); if (failed(operandBuffer)) return failure(); - bufferArgs.push_back(*operandBuffer); + newBuffer = *operandBuffer; } // Allocate outputs. for (OpResult result : customCallOp->getOpResults()) { - auto tensorType = result.getType().cast(); + auto &newBuffer = bufferArgs.emplace_back(); + if (result.getType().isa()) { + continue; + } + auto tensorType = result.getType().dyn_cast(); if (!tensorType) return failure(); // TODO(springerm): Create alloc_tensor ops during TensorCopyInsertion. AnalysisState analysisState(options); @@ -92,21 +107,121 @@ struct CustomCallOpInterface if (failed(tensorAlloc)) return failure(); auto memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); - Value resultBuffer = rewriter.create( + newBuffer = rewriter.create( op->getLoc(), memrefType, *tensorAlloc); - bufferArgs.push_back(resultBuffer); + } + + lmhlo::CustomCallTargetArgMappingAttr targetMapping; + auto numArguments = static_cast(customCallOp->getNumOperands()); + auto numResults = static_cast(customCallOp->getNumResults()); + + // Take the result buffers and fill in the token input in the gaps. + auto bufferResults = llvm::to_vector(llvm::map_range( + llvm::ArrayRef(bufferArgs).slice(numArguments), + [&](Value buffer) { return buffer ? buffer : tokenArgument; })); + + if (tokenArgument) { + // If there was a token, squeeze all the non-token arguments and results + // (in-place) and remember the mapping. + int nextIndex = 0; + llvm::SmallVector argToTargetArgMapping; + for (int i = 0; i < numArguments; ++i) { + if (bufferArgs[i]) { + argToTargetArgMapping.push_back(i); + bufferArgs[nextIndex++] = bufferArgs[i]; + } + } + llvm::SmallVector resultToTargetResultMapping; + for (int32_t i = numArguments; + i < static_cast(bufferArgs.size()); ++i) { + if (bufferArgs[i]) { + resultToTargetResultMapping.push_back(i - numArguments); + bufferArgs[nextIndex++] = bufferArgs[i]; + } + } + + // Build the mapping attribute. + targetMapping = lmhlo::CustomCallTargetArgMappingAttr::get( + rewriter.getContext(), numArguments, numResults, + argToTargetArgMapping, resultToTargetResultMapping); + + // Drop the remaining operands and adjust num_arguments and num_results + // for LMHLO creation. 
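+      // Editor's note (illustrative example, not part of the original patch):
+      // with operands (%buf0, %token, %buf1) and results (%token2, %buf2),
+      // the token slots above stay empty, so argToTargetArgMapping == [0, 2],
+      // resultToTargetResultMapping == [1], and the squeezed bufferArgs
+      // handed to lmhlo.custom_call are {%buf0, %buf1, %buf2}.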
+ bufferArgs.resize(nextIndex); + numArguments = static_cast(argToTargetArgMapping.size()); + numResults = static_cast(resultToTargetResultMapping.size()); } auto lhloOp = rewriter.create( - op->getLoc(), llvm::None, bufferArgs, op->getAttrs()); + op->getLoc(), std::nullopt, bufferArgs, op->getAttrs()); + if (targetMapping) lhloOp.setTargetArgMappingAttr(targetMapping); // lmhlo.custom_call uses a segment_size attribute to tell input from output // arguments. lhloOp->setAttr(lhloOp.getOperandSegmentSizeAttr(), - rewriter.getDenseI32ArrayAttr( - {static_cast(op->getNumOperands()), - static_cast(op->getNumResults())})); - bufferization::replaceOpWithBufferizedValues( - rewriter, op, makeArrayRef(bufferArgs).slice(op->getNumOperands())); + rewriter.getDenseI32ArrayAttr({numArguments, numResults})); + bufferization::replaceOpWithBufferizedValues(rewriter, op, bufferResults); + return success(); + } +}; + +struct InfeedOpInterface + : public BufferizableOpInterface::ExternalModel { + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const BufferizationOptions &options) const { + // Allocate buffers for the outputs of infeed. + SmallVector bufferArgs; + for (OpResult result : op->getOpResults()) { + if (!result.getType().isa()) continue; + AnalysisState analysisState(options); + auto tensorType = result.getType().cast(); + FailureOr tensorAlloc = + bufferization::allocateTensorForShapedValue( + rewriter, op->getLoc(), result, + analysisState.isTensorYielded(result), options); + if (failed(tensorAlloc)) return failure(); + auto memrefType = + MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + bufferArgs.push_back(rewriter.create( + op->getLoc(), memrefType, *tensorAlloc)); + } + rewriter.create(op->getLoc(), std::nullopt, bufferArgs, + op->getAttrs()); + // Pass the token along. + bufferArgs.push_back((op->getOperand(0))); + bufferization::replaceOpWithBufferizedValues(rewriter, op, bufferArgs); + return success(); + } +}; + +struct OutfeedOpInterface + : public BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *, OpOperand &, + const AnalysisState &) const { + return true; + } + + bool bufferizesToMemoryWrite(Operation *, OpOperand &, + const AnalysisState &) const { + return false; // Arguments are read-only. + } + + AliasingOpResultList getAliasingOpResults(Operation *, OpOperand &, + const AnalysisState &) const { + return {}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const BufferizationOptions &options) const { + // Outfeed trivially bufferizes to lmhlo. Just pass the token operand along. 
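+    // Editor's note (illustrative, assuming the usual (data, token) operand
+    // order): `mhlo.outfeed(%data, %token)` becomes an outfeed on %data's
+    // buffer, and the incoming %token is reused directly as the bufferized
+    // result token.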
+ FailureOr operandBuffer = + getBuffer(rewriter, op->getOperand(0), options); + if (failed(operandBuffer)) return failure(); + rewriter.create(op->getLoc(), std::nullopt, + *operandBuffer, op->getAttrs()); + bufferization::replaceOpWithBufferizedValues(rewriter, op, + {op->getOperand(1)}); return success(); } }; @@ -124,7 +239,7 @@ struct ReshapeOpInterface return false; } - SmallVector getAliasingOpResult( + AliasingOpResultList getAliasingOpResults( Operation *op, OpOperand & /*opOperand*/, const AnalysisState & /*state*/) const { return {op->getResult(0)}; @@ -169,7 +284,7 @@ struct DynamicReshapeOpInterface return false; } - SmallVector getAliasingOpResult( + AliasingOpResultList getAliasingOpResults( Operation *op, OpOperand & /*opOperand*/, const AnalysisState & /*state*/) const { return {op->getResult(0)}; @@ -308,8 +423,7 @@ FailureOr insertDynamicMemrefCastOp( } // Type-erased memref type with static rank and dynamic strides. - SmallVector dynamicLayout(resultRank, - ShapedType::kDynamicStrideOrOffset); + SmallVector dynamicLayout(resultRank, ShapedType::kDynamic); auto typeErasedMemrefType = MemRefType::get( resultType.getShape(), operandType.getElementType(), makeStridedLinearLayoutMap(dynamicLayout, @@ -334,7 +448,7 @@ struct DynamicBroadcastInDimOpInterface return false; } - SmallVector getAliasingOpResult( + AliasingOpResultList getAliasingOpResults( Operation *op, OpOperand & /*opOperand*/, const AnalysisState & /*state*/) const { return {op->getResult(0)}; @@ -343,7 +457,7 @@ struct DynamicBroadcastInDimOpInterface BufferRelation bufferRelation(Operation * /*op*/, OpResult /*opResult*/, const AnalysisState & /*state*/) const { // The op may allocate. - return BufferRelation::None; + return BufferRelation::Unknown; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, @@ -367,8 +481,7 @@ struct DynamicBroadcastInDimOpInterface struct HloLegalizeToMemrefPass : public impl::HloLegalizeToMemrefPassBase { void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry.insert(); registerBufferizableOpInterfaceExternalModels(registry); } @@ -390,10 +503,16 @@ std::unique_ptr> createLegalizeToMemrefPass() { void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, MhloDialect * /*dialect*/) { CustomCallOp::attachInterface(*ctx); + InfeedOp::attachInterface(*ctx); + OutfeedOp::attachInterface(*ctx); ReshapeOp::attachInterface(*ctx); DynamicReshapeOp::attachInterface(*ctx); DynamicBroadcastInDimOp::attachInterface( *ctx); + + // Load additional dialects of which ops may get created. + ctx->loadDialect(); }); } diff --git a/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc new file mode 100644 index 00000000000..2eb43736f90 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc @@ -0,0 +1,384 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/map_stablehlo_to_hlo_op.h" +#include "mhlo/transforms/rewriters.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Types.h" +#include "mlir/Support/DebugStringHelper.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "stablehlo/dialect/StablehloOps.h" + +namespace mlir { +namespace stablehlo { +namespace { + +// PRIVATE MHLO features are internal to XLA and not used by any ML frontends. +// These should never be converted to StableHLO, as they are not a good fit for +// StableHLO. +template +bool hasPrivateFeaturesNotInStablehlo(HloOpTy hloOp) { + // To the best of our knowledge, none of the ML frontends are using these ops + // directly or indirectly, so we categorized them as private to XLA. + // Please let us know if we missed something, and we'll recategorize them. + if (isa(hloOp.getOperation())) { + return true; + } + if constexpr (std::is_same::value) { + // StableHLO convolution doesn't support "unknown" dimensions. + // This is an esoteric feature of MHLO convolutions, and it's different + // from the notion of dynamic dimensions. For more context, here's the + // commit which introduced it: + // https://github.com/tensorflow/mlir-hlo/commit/4d6dc3163c1c9289d86455d9f4de5711465c50fb + // This feature isn't supported in HLO and doesn't have documentation, so + // we may end up removing it from MHLO as well. + auto dimensionNumbers = debugString(hloOp.getDimensionNumbers()); + if (dimensionNumbers.find('?') != std::string::npos) return true; + } + if constexpr (std::is_same::value) { + // To the best of our knowledge, none of the ML frontends are using this + // enum, so we categorized it as private to XLA. + // Please let us know if we missed something, and we'll recategorize it. + if (hloOp.getCustomCallSchedule() != mhlo::CustomCallSchedule::NONE) + return true; + } + return false; +} + +bool hasPackedNibble(Optional precisionConfigAttr) { + if (!precisionConfigAttr) return false; + return llvm::any_of(*precisionConfigAttr, [&](Attribute attr) { + auto precisionAttr = attr.cast(); + return precisionAttr.getValue() == mhlo::Precision::PACKED_NIBBLE; + }); +} + +// EXPERIMENTAL MHLO features are being explored by ML frontends but do not have +// any agreed upon compatibility guarantees. By default, these features cannot +// be converted to StableHLO, although the allow-experimental-features flag can +// be used to manually enable the conversion. Such features might be a good fit +// for StableHLO, and they are usually accompanied by a StableHLO GitHub ticket. 
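+// Editor's note (illustrative, not from the original patch): the tuple form
+// of mhlo.all_to_all (more than one operand), checked below, is one such
+// experimental case; it converts only when the pass is run with experimental
+// features allowed, and otherwise fails to legalize.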
+template +bool hasExperimentalFeaturesNotInStablehlo(HloOpTy hloOp) { + if constexpr (std::is_same::value) { + // StableHLO AllToAll doesn't support the tuple form yet. + // Proposal: https://github.com/openxla/stablehlo/issues/574. + if (hloOp.getNumOperands() != 1) return true; + } + if constexpr (std::is_same::value) { + // StableHLO ConvolutionOp doesn't support PACKED_NIBBLE yet. + // Proposal: https://github.com/openxla/stablehlo/issues/742. + if (hasPackedNibble(hloOp.getPrecisionConfig())) return true; + } + if constexpr (std::is_same::value) { + // StableHLO CustomCall doesn't support API_VERSION_TYPED_FFI yet. + // Proposal: https://github.com/openxla/stablehlo/issues/637. + if (hloOp.getApiVersion() == + mhlo::CustomCallApiVersion::API_VERSION_TYPED_FFI) + return true; + } + if constexpr (std::is_same::value) { + // StableHLO DotGeneral doesn't support PACKED_NIBBLE yet. + // Proposal: https://github.com/openxla/stablehlo/issues/742. + if (hasPackedNibble(hloOp.getPrecisionConfig())) return true; + } + if constexpr (std::is_same::value) { + // StableHLO Dot doesn't support PACKED_NIBBLE yet. + // Proposal: https://github.com/openxla/stablehlo/issues/742. + if (hasPackedNibble(hloOp.getPrecisionConfig())) return true; + } + return false; +} + +// PUBLIC MHLO features are not yet in StableHLO but are agreed upon internally +// to have limited compatibility guarantees. These features are used by ML +// frontends but are not yet part of StableHLO. Such features might be a good +// fit for StableHLO, and are usually accompanied by a StableHLO GitHub ticket. +template +bool hasPublicFeaturesNotInStablehlo(HloOpTy) { + return false; +} + +#define RETURN_CONVERTED_ENUM_ATTR(Name) \ + auto hloValue = mhlo::stringify##Name(attr.getValue()); \ + auto stablehloValue = stablehlo::symbolize##Name(hloValue); \ + if (!stablehloValue.has_value()) return {}; \ + return stablehlo::Name##Attr::get(attr.getContext(), stablehloValue.value()) + +Attribute convertAttr(Attribute hloAttr) { + // Handle MHLO attributes. + // The logic that handles attributes from other dialects (e.g. builtin + // attributes) lives below. + if (auto attr = hloAttr.dyn_cast()) { + return stablehlo::ChannelHandleAttr::get(attr.getContext(), + attr.getHandle(), attr.getType()); + } + if (auto attr = hloAttr.dyn_cast()) { + RETURN_CONVERTED_ENUM_ATTR(ComparisonDirection); + } + if (auto attr = hloAttr.dyn_cast()) { + RETURN_CONVERTED_ENUM_ATTR(ComparisonType); + } + if (auto attr = hloAttr.dyn_cast()) { + return stablehlo::ConvDimensionNumbersAttr::get( + attr.getContext(), attr.getInputBatchDimension(), + attr.getInputFeatureDimension(), attr.getInputSpatialDimensions(), + attr.getKernelInputFeatureDimension(), + attr.getKernelOutputFeatureDimension(), + attr.getKernelSpatialDimensions(), attr.getOutputBatchDimension(), + attr.getOutputFeatureDimension(), attr.getOutputSpatialDimensions()); + } + // NOTE: We cannot process CustomCallApiVersionAttr here because + // `dyn_cast()` succeeds for IntegerAttr too. 
+ if (auto attr = hloAttr.dyn_cast()) { + return stablehlo::DotDimensionNumbersAttr::get( + attr.getContext(), attr.getLhsBatchingDimensions(), + attr.getRhsBatchingDimensions(), attr.getLhsContractingDimensions(), + attr.getRhsContractingDimensions()); + } + if (auto attr = hloAttr.dyn_cast()) { + RETURN_CONVERTED_ENUM_ATTR(FftType); + } + if (auto attr = hloAttr.dyn_cast()) { + return stablehlo::GatherDimensionNumbersAttr::get( + attr.getContext(), attr.getOffsetDims(), attr.getCollapsedSliceDims(), + attr.getStartIndexMap(), attr.getIndexVectorDim()); + } + if (auto attr = hloAttr.dyn_cast()) { + return stablehlo::OutputOperandAliasAttr::get( + attr.getContext(), attr.getOutputTupleIndices(), attr.getOperandIndex(), + attr.getOperandTupleIndices()); + } + if (auto attr = hloAttr.dyn_cast()) { + // StableHLO Precision doesn't support PACKED_NIBBLE yet. + // Proposal: https://github.com/openxla/stablehlo/issues/742. + if (attr.getValue() == mhlo::Precision::PACKED_NIBBLE) return {}; + RETURN_CONVERTED_ENUM_ATTR(Precision); + } + if (auto attr = hloAttr.dyn_cast()) { + RETURN_CONVERTED_ENUM_ATTR(RngAlgorithm); + } + if (auto attr = hloAttr.dyn_cast()) { + RETURN_CONVERTED_ENUM_ATTR(RngDistribution); + } + if (auto attr = hloAttr.dyn_cast()) { + return stablehlo::ScatterDimensionNumbersAttr::get( + attr.getContext(), attr.getUpdateWindowDims(), + attr.getInsertedWindowDims(), attr.getScatterDimsToOperandDims(), + attr.getIndexVectorDim()); + } + if (auto attr = hloAttr.dyn_cast()) { + RETURN_CONVERTED_ENUM_ATTR(Transpose); + } + if (hloAttr.getDialect().getNamespace() == + mhlo::MhloDialect::getDialectNamespace()) { + // Our guiding principle is to support all StableHLO functionality in MHLO. + // The inverse is not necessarily true - some MHLO attributes are missing + // from StableHLO (either deliberately or haven't yet been proposed). + // As a result, these MHLO attributes will fail here. + return {}; + } + + // Handle non-MHLO attributes. + // If an attribute is not defined in MHLO, then it is unchanged, + // with the exception of ArrayAttr which is converted recursively. + if (auto hloAttrs = hloAttr.dyn_cast()) { + SmallVector stablehloAttrs; + for (auto hloAttr : hloAttrs) { + auto stablehloAttr = convertAttr(hloAttr); + if (!stablehloAttr) return {}; + stablehloAttrs.push_back(stablehloAttr); + } + return ArrayAttr::get(hloAttrs.getContext(), stablehloAttrs); + } + return hloAttr; +} + +#undef RETURN_CONVERTED_ENUM_ATTR + +template +class HloToStablehloOpConverter : public OpConversionPattern { + public: + HloToStablehloOpConverter(TypeConverter& converter, MLIRContext* context, + bool allowExperimentalFeatures) + : OpConversionPattern::OpConversionPattern(converter, context), + allowExperimentalFeatures(allowExperimentalFeatures) {} + + LogicalResult matchAndRewrite( + HloOpTy hloOp, typename HloOpTy::Adaptor adaptor, + ConversionPatternRewriter& rewriter) const final { + // Most MHLO ops which end up here are fully supported by StableHLO. + // However, some of these ops are supported only partially because they + // have features that are not supported in StableHLO. + // These MHLO features fall into two distinct categories: + // 1) Features that are private to the XLA compiler, so they are not + // a good fit for StableHLO. Conversion of such features should fail. + // 2) Features that might be a good fit for StableHLO but haven't yet + // been proposed or approved in StableHLO. Conversion of such features + // should succeed using custom_call extensibility protocol (see below). 
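+    // Editor's sketch of the protocol below (hypothetical op and attribute,
+    // not from the original patch): an op such as
+    //   %r = "mhlo.foo"(%x) {bar = 1 : i64} : (tensor<4xf32>) -> tensor<4xf32>
+    // would be emitted as
+    //   %r = "stablehlo.custom_call"(%x) {call_target_name = "mhlo.foo",
+    //        backend_config = "{bar = 1 : i64}"} : (tensor<4xf32>) -> tensor<4xf32>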
+ if (hasPrivateFeaturesNotInStablehlo(hloOp)) { + return failure(); + } + + // Convert MHLO types to StableHLO equivalents. + // If a type is not defined in MHLO, then it is unchanged, + // with the exception of RankedTensorType and TupleType which are + // converted recursively. + // See `HloToStablehloTypeConverter` for more information on when this + // conversion will succeed or fail. + SmallVector stablehloTypes; + if (failed(this->getTypeConverter()->convertTypes(hloOp->getResultTypes(), + stablehloTypes))) + return failure(); + + // These operands have already been converted to StableHLO by + // the dialect conversion infrastructure. + ValueRange stablehloOperands = adaptor.getOperands(); + + // Extensibility protocol for MHLO ops with public MHLO features that + // are not yet supported in StableHLO. + // 1) The op is represented by stablehlo::CustomCallOp. + // 2) The full name, e.g. "mhlo.all_to_all" is stored in the + // `call_target_name` attribute of the CustomCallOp. + // 3) The operands become operands of the CustomCallOp. + // 4) The attributes are wrapped in a DictionaryAttr, which is + // prettyprinted and then stored in the `backend_config` attribute + // of the CustomCallOp. + // 5) The result types become result types of the CustomCallOp. + // + // This StableHLO representation does not come with any compatibility + // guarantees. For example, when it is roundtripped back to MHLO, it may + // turn out that the original MHLO op no longer exists or has different + // attributes in the current version. + bool hasExperimentalFeatures = hasExperimentalFeaturesNotInStablehlo(hloOp); + if (!allowExperimentalFeatures && hasExperimentalFeatures) return failure(); + if (hasPublicFeaturesNotInStablehlo(hloOp) || hasExperimentalFeatures) { + if (hloOp->getNumRegions() != 0) { + // Extensibility protocol for regions hasn't been implemented yet. + // In principle, it should be straightforward to implement by + // converting regions into functions and calling them out in + // "called_computations". + // https://github.com/openxla/stablehlo/issues/593. + return failure(); + } + + auto stablehloCallTargetName = hloOp->getName().getStringRef(); + std::string stablehloBackendConfig; + llvm::raw_string_ostream os(stablehloBackendConfig); + os << hloOp->getAttrDictionary(); + + SmallVector stablehloAttrs; + stablehloAttrs.push_back(rewriter.getNamedAttr( + "call_target_name", rewriter.getStringAttr(stablehloCallTargetName))); + stablehloAttrs.push_back(rewriter.getNamedAttr( + "backend_config", rewriter.getStringAttr(stablehloBackendConfig))); + rewriter.replaceOpWithNewOp( + hloOp, stablehloTypes, stablehloOperands, stablehloAttrs); + return success(); + } + + // Convert MHLO attributes to StableHLO equivalents. + // If an attribute is not defined in MHLO, then it is unchanged, + // with the exception of ArrayAttr which is converted recursively. + SmallVector stablehloAttrs; + for (NamedAttribute hloAttr : hloOp->getAttrs()) { + if constexpr (std::is_same::value) { + // custom_call_schedule is private to XLA, but we still want to allow + // #mhlo (by ignoring it). + if (hloAttr.getName() == "custom_call_schedule" && + hloOp.getCustomCallSchedule() == mhlo::CustomCallSchedule::NONE) + continue; + } + auto stablehloAttr = convertAttr(hloAttr.getValue()); + if (!stablehloAttr) return failure(); + stablehloAttrs.push_back({hloAttr.getName(), stablehloAttr}); + } + + // Convert the MHLO operation to a StableHLO equivalent. 
+    // This can almost be done in a generic fashion, except for stablehlo.case
+    // that uses a variadic number of regions, which requires an additional
+    // argument for the generic builder.
+    HloToStablehloOp<HloOpTy> stablehloOp;
+    if constexpr (std::is_same<HloOpTy, mhlo::CaseOp>::value) {
+      stablehloOp = rewriter.replaceOpWithNewOp<stablehlo::CaseOp>(
+          hloOp, stablehloTypes, stablehloOperands, stablehloAttrs,
+          hloOp.getBranches().size());
+    } else {
+      stablehloOp = rewriter.replaceOpWithNewOp<HloToStablehloOp<HloOpTy>>(
+          hloOp, stablehloTypes, stablehloOperands, stablehloAttrs);
+    }
+
+    // Finally, populate the regions while converting argument types
+    // and nested operations.
+    for (auto [hloRegion, stablehloRegion] :
+         llvm::zip(hloOp->getRegions(), stablehloOp->getRegions())) {
+      rewriter.inlineRegionBefore(hloRegion, stablehloRegion,
+                                  stablehloRegion.end());
+      if (failed(rewriter.convertRegionTypes(&stablehloRegion,
+                                             *this->getTypeConverter(),
+                                             /*entryConversion=*/nullptr)))
+        return failure();
+    }
+    return success();
+  }
+
+  bool allowExperimentalFeatures;
+};
+
+template <typename... StablehloOpTypes>
+void populateHloToStablehloPatterns(RewritePatternSet* patterns,
+                                    TypeConverter* converter,
+                                    MLIRContext* context,
+                                    bool allowExperimentalFeatures) {
+  patterns
+      ->add<HloToStablehloOpConverter<StablehloToHloOp<StablehloOpTypes>>...>(
+          *converter, context, allowExperimentalFeatures);
+}
+
+}  // namespace
+
+void populateHloToStablehloPatterns(RewritePatternSet* patterns,
+                                    TypeConverter* converter,
+                                    MLIRContext* context,
+                                    bool allowExperimentalFeatures) {
+  // Populate conversion patterns for all StableHLO ops.
+  // Our guiding principle is to support all StableHLO functionality in MHLO.
+  // The inverse is not necessarily true - some MHLO ops are missing from
+  // StableHLO (either deliberately or haven't yet been proposed to StableHLO).
+  // As a result, these MHLO ops will not be added to these patterns and
+  // will fail the conversion.
+  populateHloToStablehloPatterns<
+#define GET_OP_LIST
+#include "stablehlo/dialect/StablehloOps.cpp.inc"
+      >(patterns, converter, context, allowExperimentalFeatures);
+}
+
+}  // namespace stablehlo
+}  // namespace mlir
diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_stablehlo_pass.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo_pass.cc
similarity index 83%
rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_stablehlo_pass.cc
rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo_pass.cc
index b73ae40f90e..05bbd455f86 100644
--- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_stablehlo_pass.cc
+++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo_pass.cc
@@ -18,10 +18,10 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" -#include "mlir-hlo/Dialect/mhlo/transforms/type_conversion.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/passes.h" +#include "mhlo/transforms/rewriters.h" +#include "mhlo/utils/type_conversion.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/IR/PatternMatch.h" @@ -35,7 +35,7 @@ namespace mlir { namespace mhlo { #define GEN_PASS_DEF_HLOLEGALIZETOSTABLEHLOPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace { @@ -48,8 +48,8 @@ struct HloLegalizeToStablehloPass stablehlo::HloToStablehloTypeConverter converter; RewritePatternSet patterns(&getContext()); - stablehlo::populateHloToStablehloPatterns(&patterns, &converter, - &getContext()); + stablehlo::populateHloToStablehloPatterns( + &patterns, &converter, &getContext(), allow_experimental_features_); stablehlo::registerFuncOpsForTypeConversion(target, patterns, converter); if (failed(applyPartialConversion(getOperation(), target, diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/legalize_control_flow.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/legalize_control_flow/legalize_control_flow.cc similarity index 98% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/legalize_control_flow.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/legalize_control_flow/legalize_control_flow.cc index 8c21200246f..70deb9d89ef 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/legalize_control_flow.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/legalize_control_flow/legalize_control_flow.cc @@ -19,8 +19,8 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Casting.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" @@ -44,7 +44,7 @@ namespace mlir { namespace mhlo { #define GEN_PASS_DEF_LEGALIZECONTROLFLOWPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace { diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/legalize_einsum_to_dot_general.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc similarity index 97% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/legalize_einsum_to_dot_general.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc index 819863f6b49..e0efda1c92e 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/legalize_einsum_to_dot_general.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc @@ -16,9 +16,9 @@ limitations under the License. 
#include #include -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/passes.h" +#include "mhlo/transforms/rewriters.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" @@ -28,7 +28,7 @@ namespace mlir { namespace mhlo { #define GEN_PASS_DEF_LEGALIZEEINSUMTODOTGENERALPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace { diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/legalize_gather_to_torch_index_select.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/legalize_gather_to_torch_index_select/legalize_gather_to_torch_index_select.cc similarity index 96% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/legalize_gather_to_torch_index_select.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/legalize_gather_to_torch_index_select/legalize_gather_to_torch_index_select.cc index 7b105d7501c..445e982a925 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/legalize_gather_to_torch_index_select.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/legalize_gather_to_torch_index_select/legalize_gather_to_torch_index_select.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/passes.h" +#include "mhlo/transforms/rewriters.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" @@ -25,7 +25,7 @@ namespace mlir { namespace mhlo { #define GEN_PASS_DEF_LEGALIZEGATHERTOTORCHINDEXSELECTPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace { diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/legalize_mhlo_to_thlo.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/legalize_mhlo_to_thlo/legalize_mhlo_to_thlo.cc similarity index 85% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/legalize_mhlo_to_thlo.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/legalize_mhlo_to_thlo/legalize_mhlo_to_thlo.cc index 6d003d74a5d..04d7efa05de 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/legalize_mhlo_to_thlo.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/legalize_mhlo_to_thlo/legalize_mhlo_to_thlo.cc @@ -21,42 +21,34 @@ limitations under the License. 
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/legalize_to_linalg_utils.h" -#include "mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h" -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_scatter_gather_utils.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" -#include "mlir-hlo/Dialect/mhlo/transforms/type_conversion.h" -#include "mlir-hlo/Dialect/thlo/IR/thlo_ops.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/map_mhlo_to_scalar_op.h" +#include "mhlo/transforms/passes.h" +#include "mhlo/transforms/rewriters.h" +#include "mhlo/utils/legalize_to_linalg_utils.h" +#include "mhlo/utils/mhlo_scatter_gather_utils.h" +#include "mhlo/utils/type_conversion.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "thlo/IR/thlo_ops.h" namespace mlir { namespace mhlo { #define GEN_PASS_DEF_LEGALIZEMHLOTOTHLOPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace { -bool isIotaArray(llvm::ArrayRef array, int expectedSize = -1) { - if (expectedSize != -1 && static_cast(array.size()) != expectedSize) - return false; - for (int64_t i = 0, e = array.size(); i < e; ++i) { - if (i != array[i]) return false; - } - return true; -} - Value castToIndex(OpBuilder& b, Location loc, TensorType originalType, Value value) { Type elementTy = originalType.getElementType(); @@ -89,7 +81,7 @@ struct ConcatenateOpPattern : public OpConversionPattern { SmallVector dynamicInitSizes; for (int64_t i = 0; i < rank; ++i) { // No need to materialize anything for static dimensions. 
- if (staticInitSizes[i] != ShapedType::kDynamicSize) { + if (staticInitSizes[i] != ShapedType::kDynamic) { continue; } @@ -107,7 +99,7 @@ struct ConcatenateOpPattern : public OpConversionPattern { Value dynamicSum; for (const Value operand : adaptor.getVal()) { auto operandTy = operand.getType().cast(); - if (operandTy.getDimSize(concatDim) == ShapedType::kDynamicSize) { + if (operandTy.getDimSize(concatDim) == ShapedType::kDynamic) { const Value dynamicSummand = rewriter.create(loc, operand, concatDim); if (dynamicSum) { @@ -133,7 +125,8 @@ struct ConcatenateOpPattern : public OpConversionPattern { auto emptyTensor = rewriter.create( loc, staticInitSizes, resultTy.getElementType(), dynamicInitSizes); rewriter.replaceOpWithNewOp( - op, resultTy, adaptor.getVal(), emptyTensor, concatDim); + op, resultTy, adaptor.getVal(), emptyTensor, + rewriter.getIndexAttr(concatDim)); return success(); } }; @@ -169,7 +162,7 @@ struct DynamicBroadcastInDimOpPattern dynamicDims.push_back(rewriter.create( loc, outputDimensions, ValueRange{rewriter.create(loc, i)})); - staticShapeInfo.push_back(ShapedType::kDynamicSize); + staticShapeInfo.push_back(ShapedType::kDynamic); } auto emptyTensor = rewriter.create( loc, staticShapeInfo, resultTy.getElementType(), dynamicDims); @@ -214,7 +207,7 @@ struct GatherPattern : public OpConversionPattern { typeConverter->convertType(op.getType()).cast(); SmallVector sizes; sizes.reserve(resultType.getRank()); - if (resultType.getDimSize(0) != ShapedType::kDynamicSize) { + if (resultType.getDimSize(0) != ShapedType::kDynamic) { sizes.push_back(rewriter.getI64IntegerAttr(resultType.getDimSize(0))); } else { sizes.push_back( @@ -236,22 +229,6 @@ struct GatherPattern : public OpConversionPattern { } }; -static SmallVector getReduceOpEmptyTensorDynSizes( - OpBuilder& b, Location loc, Value operand, int64_t srcRank, - RankedTensorType resultType, ArrayRef reductionDims) { - SmallVector dynShape; - for (size_t i = 0, j = 0; i < srcRank; ++i) { - if (j < reductionDims.size() && reductionDims[j] == i) { - ++j; - continue; - } - size_t resultIndex = i - j; - if (!resultType.isDynamicDim(resultIndex)) continue; - dynShape.push_back(b.create(loc, operand, resultIndex)); - } - return dynShape; -} - bool isInBodyOfThloOp(Operation* op) { auto* parentOp = op->getParentRegion()->getParentOp(); return isa(*parentOp) || isa(*parentOp); @@ -357,7 +334,7 @@ struct SortPattern : public OpConversionPattern { auto thloSort = rewriter.create( loc, resultTypes, adaptor.getInputs(), outputs, - rewriter.getI64IntegerAttr(dimension), rewriter.getBoolAttr(isStable)); + rewriter.getIndexAttr(dimension), rewriter.getBoolAttr(isStable)); Region& region = thloSort.getComparator(); rewriter.inlineRegionBefore(op.getComparator(), region, region.end()); @@ -384,8 +361,41 @@ struct SortPattern : public OpConversionPattern { } }; +struct ReversePattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + mhlo::ReverseOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const final { + auto reverseDimensions = + llvm::to_vector(op.getDimensions().getValues()); + Type resultType = typeConverter->convertType(op->getResultTypes()[0]); + if (!resultType) + return rewriter.notifyMatchFailure(op, "failed to convert result type"); + Location loc = op.getLoc(); + auto operandType = + adaptor.getOperand().getType().dyn_cast(); + if (!operandType) + return rewriter.notifyMatchFailure(op, "expects known-rank operand"); + auto tensorResultType = 
resultType.cast(); + SmallVector dynShape = + tensor::createDynamicDimValues(rewriter, loc, adaptor.getOperand()); + Value initTensor = rewriter.create( + loc, tensorResultType.getShape(), tensorResultType.getElementType(), + dynShape); + rewriter.replaceOpWithNewOp( + op, resultType, adaptor.getOperand(), initTensor, reverseDimensions); + return success(); + } +}; + class LegalizeMHLOToTHLOPass : public impl::LegalizeMHLOToTHLOPassBase { + public: + explicit LegalizeMHLOToTHLOPass(bool enableExperimentalOps) { + enableExperimental = enableExperimentalOps; + } + + private: void runOnOperation() final { MLIRContext* ctx = &getContext(); RewritePatternSet patterns(ctx); @@ -404,18 +414,24 @@ class LegalizeMHLOToTHLOPass auto typeConverter = std::make_unique(); - populateScalarHloToArithmeticConversionPatterns(ctx, *typeConverter, - &patterns); + populateScalarHloToArithmeticConversionPatterns( + ctx, *typeConverter, &patterns, + [](Operation* op) { return isInBodyOfThloOp(op); }); // List of patterns. // clang-format off patterns.insert< - ConcatenateOpPattern, - DynamicBroadcastInDimOpPattern, - GatherPattern, + ReversePattern, ScatterPattern, SortPattern, ThloRegionReturnOpConversion>(*typeConverter, ctx); + + if (enableExperimental) { + patterns.insert< + ConcatenateOpPattern, + DynamicBroadcastInDimOpPattern, + GatherPattern>(*typeConverter, ctx); + } // clang-format on if (failed(applyPartialConversion(getOperation(), target, @@ -427,8 +443,9 @@ class LegalizeMHLOToTHLOPass } // namespace -std::unique_ptr> createLegalizeMHLOToTHLOPass() { - return std::make_unique(); +std::unique_ptr> createLegalizeMHLOToTHLOPass( + bool enableExperimentalOps) { + return std::make_unique(enableExperimentalOps); } } // namespace mhlo diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/legalize_shape_computations.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/legalize_shape_computations/legalize_shape_computations.cc similarity index 97% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/legalize_shape_computations.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/legalize_shape_computations/legalize_shape_computations.cc index 866ac7dfa24..0ec48e417a0 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/legalize_shape_computations.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/legalize_shape_computations/legalize_shape_computations.cc @@ -24,9 +24,9 @@ limitations under the License. 
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringSet.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h" -#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/map_mhlo_to_scalar_op.h" +#include "mhlo/transforms/rewriters.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Math/IR/Math.h" @@ -52,7 +52,7 @@ namespace mlir { namespace mhlo { #define GEN_PASS_DEF_HLOLEGALIZESHAPECOMPUTATIONSPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace { diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/legalize_sort.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/legalize_sort/legalize_sort.cc similarity index 98% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/legalize_sort.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/legalize_sort/legalize_sort.cc index 264e2ef6eaf..9c9dd4f72d1 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/legalize_sort.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/legalize_sort/legalize_sort.cc @@ -18,8 +18,8 @@ limitations under the License. #include #include "llvm/ADT/STLExtras.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" @@ -28,7 +28,7 @@ limitations under the License. #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" // TF:llvm-project #include "mlir/IR/Block.h" -#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/IRMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/ImplicitLocOpBuilder.h" @@ -45,7 +45,7 @@ namespace mlir { namespace mhlo { #define GEN_PASS_DEF_HLOLEGALIZESORTPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace { @@ -66,7 +66,7 @@ Value emitComparison(ImplicitLocOpBuilder& b, SmallVector& lhs, assert(block.getTerminator()->getOperands().size() == 1 && "Comparator must return a single value"); - BlockAndValueMapping mapping; + IRMapping mapping; for (auto [idx, arg] : llvm::enumerate(comparator.getArguments())) { Value value = idx % 2 == 0 ? 
lhs[idx / 2] : rhs[idx / 2]; Type type = RankedTensorType::get({}, value.getType()); @@ -416,8 +416,8 @@ struct Slicer { MemRefType toSlicedType(MemRefType sourceType) { return memref::SubViewOp::inferRankReducedResultType( - {ShapedType::kDynamicSize} /*1D output*/, sourceType, offsets, - sizes, strides) + {ShapedType::kDynamic} /*1D output*/, sourceType, offsets, sizes, + strides) .cast(); } @@ -510,7 +510,7 @@ struct SortOpPattern : public OpRewritePattern { forOps.reserve(inputRank - 1); ivs.reserve(inputRank - 1); for (int64_t i = 0; i < inputRank; ++i) { - if (i != op.getDimension()) { + if (i != static_cast(op.getDimension())) { Value dim = b.create(i); Value ub = b.create(firstOperand, dim); scf::ForOp& forOp = forOps.emplace_back( diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/legalize_to_linalg/legalize_to_linalg.cc similarity index 86% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/legalize_to_linalg/legalize_to_linalg.cc index 8bc2cad5f10..affac5688b0 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/legalize_to_linalg/legalize_to_linalg.cc @@ -17,7 +17,9 @@ limitations under the License. #include #include +#include #include +#include #include #include @@ -25,19 +27,21 @@ limitations under the License. #include "llvm/ADT/BitVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/iterator_range.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/legalize_to_linalg_utils.h" -#include "mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" -#include "mlir-hlo/Dialect/mhlo/transforms/type_conversion.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/map_mhlo_to_scalar_op.h" +#include "mhlo/transforms/passes.h" +#include "mhlo/transforms/rewriters.h" +#include "mhlo/utils/legalize_to_linalg_utils.h" +#include "mhlo/utils/type_conversion.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -67,7 +71,7 @@ namespace mlir { namespace mhlo { #define GEN_PASS_DEF_HLOLEGALIZETOLINALGPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace { @@ -146,6 +150,15 @@ Value extractIndexFromTensor(OpBuilder& builder, Location loc, Value tensor, loc, builder.getIndexType(), extracted); } +/// Ensures a tensor has the same shape (not including the element type) as +/// another. 
+Value coerceTensorShape(OpBuilder& builder, Location loc, + TypedValue value, ShapedType targetType) { + return builder.createOrFold( + loc, targetType.cloneWith(std::nullopt, value.getType().getElementType()), + value); +} + /// Returns true if the given `dimensionNumbers` from a mhlo.convolution op /// follows a canonical form: /// @@ -304,15 +317,15 @@ struct RngUniformConversion : public OpConversionPattern { // Looks through a set of dimension that has been marked as reduction axes, // if it is found within the set, then we set it as "reduction", otherwise // we can label it as "parallel". -SmallVector getEinsumLoopsAttrs( +SmallVector getEinsumLoopsAttrs( const llvm::SmallSetVector& inputInd, const llvm::SmallSetVector& reductionDims) { - SmallVector res; + SmallVector res; for (StringRef dim : inputInd) { if (!reductionDims.contains(dim)) { - res.push_back(getParallelIteratorTypeName()); + res.push_back(utils::IteratorType::parallel); } else { - res.push_back(getReductionIteratorTypeName()); + res.push_back(utils::IteratorType::reduction); } } return res; @@ -332,14 +345,14 @@ SmallVector extractDynamicEinsumSizes( // Query from lhs vars. auto dimIndPos = dimIndIt - lhsLoopVec.begin(); auto lhsShape = lhs.getType().dyn_cast().getShape(); - if (lhsShape[dimIndPos] != ShapedType::kDynamicSize) continue; + if (lhsShape[dimIndPos] != ShapedType::kDynamic) continue; dimSize = b.create(loc, lhs, dimIndPos); } else { // query from rhs vars. dimIndIt = std::find(rhsLoopVec.begin(), rhsLoopVec.end(), dimInd); auto dimIndPos = dimIndIt - rhsLoopVec.begin(); auto rhsShape = rhs.getType().dyn_cast().getShape(); - if (rhsShape[dimIndPos] != ShapedType::kDynamicSize) continue; + if (rhsShape[dimIndPos] != ShapedType::kDynamic) continue; dimSize = b.create(loc, rhs, dimIndPos); } dynSizes.push_back(dimSize); @@ -689,6 +702,30 @@ class BroadcastConverter } }; +class BroadcastOpToBroadcastConverter + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::BroadcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto resultTy = typeConverter->convertType(op.getType()).cast(); + + int64_t numPrependedDims = op.getBroadcastSizes().size(); + SmallVector dimensions = + llvm::to_vector(llvm::seq(0, numPrependedDims)); + + auto loc = op.getLoc(); + Value emptyTensor = + getEmptyTensorFor(rewriter, loc, resultTy, op, adaptor.getOperands()); + + rewriter.replaceOpWithNewOp( + op, op.getOperand(), emptyTensor, dimensions, + linalg::getPrunedAttributeList(op)); + return success(); + } +}; + class HloBroadcastInDimConverter : public DataMovementOpConverter { @@ -730,6 +767,120 @@ class HloBroadcastInDimConverter } }; +Value collapseExpandingDims(PatternRewriter& rewriter, Location loc, + Value operand, SmallVector& dimensions, + llvm::function_ref isExpandingDim) { + auto operandTy = operand.getType().cast(); + + SmallVector reassociationMap; + ReassociationIndices currentIndices; + + ArrayRef operandShape = operandTy.getShape(); + SmallVector newOperandShape; + SmallVector newDimensions; + + for (auto& [idx, dim] : llvm::enumerate(dimensions)) { + currentIndices.push_back(idx); + + if (!isExpandingDim(idx)) { + reassociationMap.push_back(currentIndices); + currentIndices.clear(); + newOperandShape.push_back(operandShape[idx]); + newDimensions.push_back(dim); + } + } + + if (!reassociationMap.empty()) + reassociationMap.back().insert(reassociationMap.back().end(), + currentIndices.begin(), + 
currentIndices.end()); + + if (dimensions.size() != newDimensions.size()) { + dimensions = newDimensions; + + auto newOperandType = + RankedTensorType::get(newOperandShape, operandTy.getElementType()); + operand = rewriter.create( + loc, newOperandType, operand, reassociationMap); + } + return operand; +} + +// Insert linalg.transpose if broadcasted dimensions are not in sorded order. +// linalg.broadcast does not support implicit transpose, so the input needs to +// be explicitely transposed. +Value transposeBroadcastOperand(PatternRewriter& rewriter, Location loc, + Value operand, + SmallVector& dimensions) { + // Do not insert `transpose` is dimensions are already sorted. + if (llvm::is_sorted(dimensions)) return operand; + + SmallVector permutation = + llvm::to_vector(llvm::seq(0, dimensions.size())); + llvm::sort(permutation, [&](int64_t lhs, int64_t rhs) { + return dimensions[lhs] < dimensions[rhs]; + }); + + auto operandTy = operand.getType().cast(); + ArrayRef operandShape = operandTy.getShape(); + SmallVector transposedOperandShape, transposedDimensions; + + for (int64_t index : permutation) { + transposedOperandShape.push_back(operandShape[index]); + transposedDimensions.push_back(dimensions[index]); + } + dimensions = transposedDimensions; + + return rewriter.create( + loc, + RankedTensorType::get(transposedOperandShape, operandTy.getElementType()), + operand, rewriter.getI64VectorAttr(permutation)); +} + +class BroadcastInDimOpToBroadcastConverter + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::BroadcastInDimOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + Location loc = op.getLoc(); + + SmallVector broadcastDimensions = + llvm::to_vector(op.getBroadcastDimensions().getValues()); + + Value operand = adaptor.getOperand(); + auto operandTy = operand.getType().cast(); + auto resultTy = typeConverter->convertType(op.getType()).cast(); + + ArrayRef operandShape = operandTy.getShape(); + ArrayRef resultShape = resultTy.getShape(); + + operand = collapseExpandingDims( + rewriter, loc, operand, broadcastDimensions, [&](int64_t i) { + return operandShape[i] == 1 && + resultShape[broadcastDimensions[i]] != 1; + }); + operand = + transposeBroadcastOperand(rewriter, loc, operand, broadcastDimensions); + + Value emptyTensor = + getEmptyTensorFor(rewriter, loc, resultTy, op, adaptor.getOperands()); + + SmallVector addedDimensions; + for (int64_t dim : llvm::seq(0, resultTy.getRank())) { + if (!llvm::is_contained(broadcastDimensions, dim)) + addedDimensions.push_back(dim); + } + + rewriter.replaceOpWithNewOp( + op, operand, emptyTensor, addedDimensions, + linalg::getPrunedAttributeList(op)); + return success(); + } +}; + // If the input has a static shape we know exactly when the broadcast must // expand (the dimension is 1, which also trivially expands to 1) or will never // expand (the dimension is not 1). 
We can also source the information from the @@ -801,10 +952,9 @@ class HloDynamicBroadcastInDimConverter rewriter.replaceOpWithNewOp( op, TypeRange{emptyTensor.getType()}, ValueRange{operand}, /*outputBuffers=*/ValueRange{emptyTensor}, - llvm::makeArrayRef( - {AffineMap::get(/*dimCount=*/nloops, /*symbolCount=*/0, dimExprs, - rewriter.getContext()), - rewriter.getMultiDimIdentityMap(nloops)}), + llvm::ArrayRef({AffineMap::get(/*dimCount=*/nloops, /*symbolCount=*/0, + dimExprs, rewriter.getContext()), + rewriter.getMultiDimIdentityMap(nloops)}), getNParallelLoopsAttrs(nloops), [&](OpBuilder& nestedBuilder, Location /*nested_loc*/, ValueRange args) { @@ -815,6 +965,128 @@ class HloDynamicBroadcastInDimConverter } }; +class DynamicBroadcastInDimOpToBroadcastConverter + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::DynamicBroadcastInDimOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const final { + Location loc = op.getLoc(); + + Value operand = adaptor.getOperand(); + auto operandTy = operand.getType().dyn_cast(); + if (!operandTy) return failure(); + auto resultTy = + typeConverter->convertType(op.getType()).dyn_cast(); + if (!resultTy) return failure(); + + SmallVector broadcastDimensions = + llvm::to_vector(op.getBroadcastDimensions().getValues()); + + SmallVector> expansionBehavior(broadcastDimensions.size()); + + // Use static type info. + for (const auto& [idx, dim] : llvm::enumerate(operandTy.getShape())) { + if (ShapedType::isDynamic(dim)) continue; + expansionBehavior[idx] = (dim == 1); + } + + // Use annotated expansion behavior, if available. + if (op.getKnownExpandingDimensions()) { + for (const auto& it : + op.getKnownExpandingDimensions()->getValues()) { + expansionBehavior[it] = true; + } + } + if (op.getKnownNonexpandingDimensions()) { + for (const auto& it : + op.getKnownNonexpandingDimensions()->getValues()) { + expansionBehavior[it] = false; + } + } + + // Fail if unknown expansion behavior remains. + if (llvm::any_of(expansionBehavior, [](auto v) { return !v.has_value(); })) + return failure(); + + auto isExpandingDim = [&](int64_t i) { + return expansionBehavior[i].value(); + }; + + // Use attribute information to insert 1s into operand type. 
+ operand = getBroadcastOperand(rewriter, loc, operand, isExpandingDim); + + auto broadcastResultTy = getBroadcastResultType( + operand, resultTy, broadcastDimensions, isExpandingDim); + + operand = collapseExpandingDims(rewriter, loc, operand, broadcastDimensions, + isExpandingDim); + operand = + transposeBroadcastOperand(rewriter, loc, operand, broadcastDimensions); + + Value emptyTensor = getEmptyTensorFor(rewriter, loc, broadcastResultTy, op, + adaptor.getOperands()); + + SmallVector addedDimensions; + for (int64_t dim : llvm::seq(0, resultTy.getRank())) { + if (!llvm::is_contained(broadcastDimensions, dim)) + addedDimensions.push_back(dim); + } + + Value result = rewriter + .create( + loc, operand, emptyTensor, addedDimensions, + linalg::getPrunedAttributeList(op)) + .getResults()[0]; + + if (resultTy != broadcastResultTy) { + result = rewriter.create(loc, resultTy, result); + } + + rewriter.replaceOp(op, result); + return success(); + } + + private: + static Value getBroadcastOperand( + PatternRewriter& rewriter, Location loc, Value operand, + llvm::function_ref isExpandingDim) { + auto operandTy = operand.getType().dyn_cast(); + + SmallVector updatedOperandShape = + llvm::to_vector(operandTy.getShape()); + for (size_t i = 0; i < updatedOperandShape.size(); ++i) { + if (isExpandingDim(i)) updatedOperandShape[i] = 1; + } + + auto updatedOperandTy = + RankedTensorType::get(updatedOperandShape, operandTy.getElementType()); + + if (updatedOperandTy != operandTy) { + operand = rewriter.create(loc, updatedOperandTy, operand); + } + + return operand; + } + + static ShapedType getBroadcastResultType( + Value operand, RankedTensorType resultTy, ArrayRef dimensions, + llvm::function_ref isExpandingDim) { + auto operandShape = operand.getType().cast().getShape(); + auto broadcastResultShape = llvm::to_vector(resultTy.getShape()); + + for (auto [operandIndex, resultIndex] : llvm::enumerate(dimensions)) { + if (isExpandingDim(operandIndex)) continue; + broadcastResultShape[resultIndex] = operandShape[operandIndex]; + } + + return RankedTensorType::get(broadcastResultShape, + resultTy.getElementType()); + } +}; + template class TransposeConverter : public DataMovementOpConverter, OpTy> { @@ -853,7 +1125,7 @@ class TransposeOpToTransposeConverter llvm::to_vector(op.getPermutation().getValues())); rewriter.replaceOpWithNewOp( - op, op.getOperand(), emptyTensor, permutation, + op, adaptor.getOperand(), emptyTensor, permutation, linalg::getPrunedAttributeList(op)); return success(); } @@ -1116,8 +1388,7 @@ class ReshapeOpConverter : public OpConversionPattern { // source. if (resultType.isDynamicDim(map.index())) continue; for (auto targetDim : map.value()) { - if (shape[targetDim] == ShapedType::kDynamicSize) - shape[targetDim] = 1; + if (shape[targetDim] == ShapedType::kDynamic) shape[targetDim] = 1; } } // Insert a cast if types are not the same (ignoring sparse encoding). 
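The broadcast converters above collapse expanding operand dimensions and sort the remaining dimension mapping, because linalg.broadcast neither expands size-1 dimensions implicitly nor permutes its input; it only appends the result dimensions that the operand does not cover. A small standalone sketch of that bookkeeping on concrete shapes (plain C++, no MLIR dependency, purely illustrative):

// Illustrative only: mirrors the dimension bookkeeping of the converters above.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Example: mhlo.broadcast_in_dim with operand shape [1, 4],
  // broadcast_dimensions = [2, 0] and result shape [4, 3, 5].
  std::vector<int64_t> operandShape = {1, 4};
  std::vector<int64_t> resultShape = {4, 3, 5};
  std::vector<int64_t> broadcastDims = {2, 0};

  // Step 1: drop operand dims that expand (size 1 mapped to a non-1 result dim).
  std::vector<int64_t> keptDims;
  for (size_t i = 0; i < broadcastDims.size(); ++i)
    if (!(operandShape[i] == 1 && resultShape[broadcastDims[i]] != 1))
      keptDims.push_back(broadcastDims[i]);

  // Step 2: linalg.broadcast cannot permute, so the kept mapping must be
  // sorted; the converter inserts an explicit transpose when it is not.
  std::sort(keptDims.begin(), keptDims.end());

  // Step 3: result dims not covered by the operand are the ones the
  // linalg.broadcast op has to add.
  std::vector<int64_t> addedDims;
  for (int64_t d = 0; d < static_cast<int64_t>(resultShape.size()); ++d)
    if (std::find(keptDims.begin(), keptDims.end(), d) == keptDims.end())
      addedDims.push_back(d);

  for (int64_t d : addedDims) std::cout << d << ' ';  // prints: 1 2
  std::cout << '\n';
  return 0;
}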
@@ -1205,7 +1476,7 @@ class IotaConverter : public OpConversionPattern { ValueRange{getEmptyTensorFor(rewriter, loc, resultShapedType, iotaOp, adaptor.getOperands())}, - llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)), + llvm::ArrayRef(rewriter.getMultiDimIdentityMap(nloops)), getNParallelLoopsAttrs(nloops), [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange /*args*/) { Value indexOp = nestedBuilder.create( @@ -1262,7 +1533,7 @@ struct ConcatenateConverter : public OpConversionPattern { op, /*resultTensorTypes=*/resultType, /*inputs=*/ValueRange{}, /*outputBuffers=*/result, - llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)), + llvm::ArrayRef(rewriter.getMultiDimIdentityMap(nloops)), getNParallelLoopsAttrs(nloops), [&](OpBuilder& nestedBuilder, Location loc, ValueRange) { OpBuilder b = nestedBuilder; @@ -1386,9 +1657,9 @@ class SliceConverter : public OpConversionPattern { // Say that there are k elements in total, we have condition: // start + (k - 1) * strides <= limit - 1 // -> - // k <= (limit - 1 - start) / strides + 1 + // k <= (limit - 1 - start + strides) / strides sizes.push_back( - rewriter.getI64IntegerAttr((limit - 1 - start) / stride + 1)); + rewriter.getI64IntegerAttr((limit - 1 - start + stride) / stride)); strides.push_back(rewriter.getI64IntegerAttr(stride)); } rewriter.replaceOpWithNewOp( @@ -1524,8 +1795,7 @@ DotOperationType getDotOperationType(mhlo::DotOp dotOp) { ArrayRef rhsShape = dotOp.getRhs().getType().cast().getShape(); auto shapeMatches = [](int64_t a, int64_t b) { - return a == ShapedType::kDynamicSize || b == ShapedType::kDynamicSize || - a == b; + return a == ShapedType::kDynamic || b == ShapedType::kDynamic || a == b; }; if (lhsShape.size() == 1 && rhsShape.size() == 1 && shapeMatches(lhsShape[0], rhsShape[0])) { @@ -1723,15 +1993,17 @@ class MapOpToMapConverter : public OpConversionPattern { "Expected a pointwise map"); Location loc = op.getLoc(); - Value output = - getEmptyTensorFor(rewriter, loc, resultType, op, adaptor.getOperands()); + Value operand0 = adaptor.getOperands()[0]; + Value operand1 = coerceTensorShape( + rewriter, loc, cast>(adaptor.getOperands()[1]), + operand0.getType()); + Value output = rewriter.create( + loc, tensor::getMixedSizes(rewriter, loc, operand0), + resultType.getElementType()); auto linalgOp = rewriter.create( - loc, resultType, adaptor.getOperands(), output); - // TODO(shyshkov): Add a builder for linalg::MapOp that accepts (inputs, - // init, attrs). Default builder can do either (inputs, init) or (all - // operands, attrs). - linalgOp->setAttrs(linalg::getPrunedAttributeList(op)); + loc, ValueRange{operand0, operand1}, output, + /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(op)); // Convert the signature of the body. We scalarize the operands and add a // scalar operand representing the output tensor. 
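The rewritten slice-size bound above is the same integer arithmetic as before, just folded into a single division: from start + (k - 1) * stride <= limit - 1 one gets k = (limit - 1 - start + stride) / stride. A tiny numeric check of that formula (illustrative only, not part of the patch):

// Illustrative check: start=1, limit=10, stride=3 selects indices 1, 4, 7,
// so the formula must yield k == 3.
#include <cassert>

int main() {
  long start = 1, limit = 10, stride = 3;
  long k = (limit - 1 - start + stride) / stride;  // (10 - 1 - 1 + 3) / 3 == 3
  assert(k == 3);
  return 0;
}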
@@ -1748,17 +2020,13 @@ class MapOpToMapConverter : public OpConversionPattern { rewriter.applySignatureConversion(®ion, signatureConverter, getTypeConverter()); - rewriter.replaceOp(op, linalgOp.getResults()); + auto result = rewriter.createOrFold(loc, resultType, + linalgOp.getResults()); + rewriter.replaceOp(op, result); return success(); } }; -bool isInBodyOfLinalgOps(Operation* op) { - auto* parentOp = op->getParentRegion()->getParentOp(); - return parentOp->getDialect() == - parentOp->getContext()->getLoadedDialect(); -} - SmallVector getReduceOpEmptyTensorDynSizes( OpBuilder& b, Location loc, Value arg, ShapedType resultType, ArrayRef reductionDims) { @@ -1799,7 +2067,7 @@ class ReduceRegionReturnOpConversion } }; -class ReduceConversion : public OpConversionPattern { +class ReduceOpToGenericConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( @@ -1899,6 +2167,109 @@ class ReduceConversion : public OpConversionPattern { } }; +struct ReduceOpToReduceConverter : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const final { + auto reductionDims = + llvm::to_vector(op.getDimensions().getValues()); + // mhlo.reduce doesn't specify the order of the reduction dimensions. + llvm::sort(reductionDims); + + auto toRankedTensor = [](Value v) -> RankedTensorType { + return v.getType().dyn_cast(); + }; + + SmallVector outputs; + SmallVector operandTypes, initTypes; + SmallVector resultTypes; + if (failed(typeConverter->convertTypes(op.getResultTypes(), resultTypes))) + return failure(); + + Location loc = op.getLoc(); + for (auto [operand, initValue, resultType] : + llvm::zip(adaptor.getInputs(), adaptor.getInitValues(), resultTypes)) { + auto initType = toRankedTensor(initValue); + if (!initType) + return rewriter.notifyMatchFailure(op, + "expects known-rank init values"); + initTypes.push_back(initType); + auto operandType = toRankedTensor(operand); + if (!operandType) + return rewriter.notifyMatchFailure(op, "expects known-rank operands"); + operandTypes.push_back(operandType); + initValue = rewriter.createOrFold(loc, initValue); + auto tensorResultType = resultType.cast(); + // For linalg.reduce, the result type's dimensions must match the input's + // dimensions, whereas MHLO allows replacing static dimensions with + // dynamic ones. + SmallVector resultShape; + SmallVector dynShape; + for (auto [index, dim] : + llvm::enumerate(operand.getType().cast().getShape())) { + if (!llvm::is_contained(reductionDims, index)) { + resultShape.push_back(dim); + if (ShapedType::isDynamic(dim)) { + dynShape.push_back( + rewriter.create(loc, operand, index)); + } + } + } + + Value emptyTensor = rewriter.create( + loc, resultShape, tensorResultType.getElementType(), dynShape); + Value filledTensor = + rewriter.create(loc, initValue, emptyTensor).result(); + outputs.push_back(filledTensor); + } + + auto linalgOp = rewriter.create( + loc, adaptor.getInputs(), outputs, reductionDims, + /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(op)); + + Region& region = linalgOp.getRegion(); + rewriter.inlineRegionBefore(op.getBody(), region, region.end()); + + // Convert the signature of the body. The reduce op 'computation' region + // apply function has a signature with tensor types, this is converted to a + // function with element types. E.g. 
the signature "(tensor, + // tensor) -> tensor" will be converted to "(f32, f32) -> f32". + // Also, we need to swap the operands of the function. The mhlo.reduce op + // expects the init values to be the first parameters of the apply function, + // while the linalg.reduction op expects the init values as the last + // parameters of the 'combiner' region apply function. + TypeConverter::SignatureConversion signatureConverter( + linalgOp.getNumDpsInputs() * 2); + assert(linalgOp.getNumDpsInputs() == linalgOp.getNumDpsInits()); + for (const auto& [idx, val] : llvm::enumerate(operandTypes)) { + signatureConverter.addInputs( + /*origInputNo=*/idx + linalgOp.getNumDpsInputs(), + // type for new operand number 'idx'. + typeConverter->convertType(val.getElementType())); + } + for (const auto& [idx, val] : llvm::enumerate(initTypes)) { + signatureConverter.addInputs( + /*origInputNo=*/idx, + // type for new operand number 'idx' + linalgOp.getNumInputs() + typeConverter->convertType(val.getElementType())); + } + rewriter.applySignatureConversion(®ion, signatureConverter, + getTypeConverter()); + + // Cast the result to the correct type. + SmallVector results; + for (auto [result, resultType] : + llvm::zip(linalgOp.getResults(), resultTypes)) { + results.push_back( + rewriter.createOrFold(loc, resultType, result)); + } + rewriter.replaceOp(op, results); + return success(); + } +}; + // Decomposes a pad with negative edge padding into a pad without negative edge // padding and a tensor.extract_slice. struct PadOpNegativePaddingConversion @@ -2069,6 +2440,29 @@ Value applyConvolutionPadding(Location loc, Value input, DenseIntElementsAttr::get(attrType, padInterior)); } +// If the ConvolutionOp has a window reversal, applies it to the filter. +Value applyConvolutionReversal(Location loc, OpBuilder& b, ConvolutionOp op, + Value filter) { + auto reversals = op.getWindowReversal(); + if (!reversals.has_value()) { + return filter; + } + llvm::SmallVector reversedDims; + for (auto [idx, reversed] : + llvm::enumerate(reversals.value().getValues())) { + if (reversed) { + reversedDims.push_back( + op.getDimensionNumbers().getKernelSpatialDimensions()[idx]); + } + } + + return b.create( + loc, filter, + mlir::DenseIntElementsAttr::get( + RankedTensorType::get(reversedDims.size(), b.getI64Type()), + reversedDims)); +} + /// Converts mhlo.conv operation to linalg named op. This only covers normal /// convolution cases. The op must have canonical dimension numbers. Depthwise /// convolution and pointwise convolution are not handled in the conversion. @@ -2087,10 +2481,18 @@ struct NormalConvolutionOpConversion Location loc = op.getLoc(); Value input = adaptor.getLhs(); Value filter = adaptor.getRhs(); + filter = applyConvolutionReversal(loc, rewriter, op, filter); auto resultType = typeConverter->convertType(op.getResult().getType()).cast(); int64_t rank = resultType.getRank(); + // Immediately emit an EmptyOp for output tensors with zero dimension. + if (llvm::is_contained(resultType.getShape(), 0)) { + rewriter.replaceOpWithNewOp(op, resultType.getShape(), + resultType.getElementType()); + return success(); + } + // The output shape is N spatial_dims F. SmallVector dynSizes; if (resultType.isDynamicDim(0)) { @@ -2193,6 +2595,13 @@ struct ConvolutionOpGeneralConversion auto reshapedResultShape = resultType.getShape().vec(); if (!resultType.hasStaticShape()) return failure(); + // Immediately emit an EmptyOp for output tensors with zero dimension. 
+ if (llvm::is_contained(reshapedResultShape, 0)) { + rewriter.replaceOpWithNewOp(op, reshapedResultShape, + resultType.getElementType()); + return success(); + } + auto dimensionNumbers = op.getDimensionNumbers(); auto inputBatchDimension = dimensionNumbers.getInputBatchDimension(); auto inputFeatureDimension = dimensionNumbers.getInputFeatureDimension(); @@ -2226,25 +2635,7 @@ struct ConvolutionOpGeneralConversion Value modifiedRhs = applyConvolutionPadding( op.getLoc(), adaptor.getRhs(), nullptr, adaptor.getRhsDilationAttr(), op.getDimensionNumbers().getKernelSpatialDimensions(), rewriter); - - // Decompose the reversal dims into its own step - auto reversals = op.getWindowReversal(); - if (reversals.has_value()) { - llvm::SmallVector reversedDims; - for (auto& idxAndBool : - llvm::enumerate(reversals.value().getValues())) - if (idxAndBool.value()) - reversedDims.push_back( - op.getDimensionNumbers() - .getKernelSpatialDimensions()[idxAndBool.index()]); - - modifiedRhs = rewriter.create( - loc, modifiedRhs, - mlir::DenseIntElementsAttr::get( - RankedTensorType::get(reversedDims.size(), - rewriter.getIntegerType(64)), - reversedDims)); - } + modifiedRhs = applyConvolutionReversal(loc, rewriter, op, modifiedRhs); // Non-one values for feature or batch group counts will result in reshaped // inputs and outputs. These mappings are used to keep track of the the new @@ -2290,10 +2681,10 @@ struct ConvolutionOpGeneralConversion // If batch or feature count groupings exist, represent this through // reshaping the input to have an additional dimension that these groupings // exist along, and reduce in that dimension - SmallVector iterationLoops; + SmallVector iterationLoops; if (featureGroupCount != 1) { auto parallelDim = mlir::getAffineDimExpr(nextDim++, ctx); - iterationLoops.push_back(getParallelIteratorTypeName()); + iterationLoops.push_back(utils::IteratorType::parallel); // Reshape LHS { srcExprs.insert(srcExprs.begin() + inputFeatureDimension, parallelDim); @@ -2335,7 +2726,7 @@ struct ConvolutionOpGeneralConversion } if (batchGroupCount != 1) { - iterationLoops.push_back(getParallelIteratorTypeName()); + iterationLoops.push_back(utils::IteratorType::parallel); auto parallelDim = mlir::getAffineDimExpr(nextDim++, ctx); // Reshape LHS { @@ -2379,7 +2770,7 @@ struct ConvolutionOpGeneralConversion // Handle input feature dimension { - iterationLoops.push_back(getReductionIteratorTypeName()); + iterationLoops.push_back(utils::IteratorType::reduction); auto inputFeatureDim = mlir::getAffineDimExpr(nextDim++, ctx); srcExprs[lhsIndexMapping[inputFeatureDimension]] = inputFeatureDim; windowExprs[rhsIndexMapping[kernelInputFeatureDimension]] = @@ -2388,7 +2779,7 @@ struct ConvolutionOpGeneralConversion // Handle output feature dimension { - iterationLoops.push_back(getParallelIteratorTypeName()); + iterationLoops.push_back(utils::IteratorType::parallel); auto outputFeatureDim = mlir::getAffineDimExpr(nextDim++, ctx); dstExprs[resultIndexMapping[outputFeatureDimension]] = outputFeatureDim; windowExprs[rhsIndexMapping[kernelOutputFeatureDimension]] = @@ -2398,8 +2789,8 @@ struct ConvolutionOpGeneralConversion // Handle spatial Dimensions int64_t numSpatialDims = rank - 2; for (int64_t i = 0; i < numSpatialDims; i++) { - iterationLoops.push_back(getParallelIteratorTypeName()); - iterationLoops.push_back(getReductionIteratorTypeName()); + iterationLoops.push_back(utils::IteratorType::parallel); + iterationLoops.push_back(utils::IteratorType::reduction); auto dim0 = mlir::getAffineDimExpr(nextDim++, 
ctx); auto dim1 = mlir::getAffineDimExpr(nextDim++, ctx); @@ -2415,7 +2806,7 @@ struct ConvolutionOpGeneralConversion // Handle batch dimension { - iterationLoops.push_back(getParallelIteratorTypeName()); + iterationLoops.push_back(utils::IteratorType::parallel); auto batchDim = mlir::getAffineDimExpr(nextDim++, ctx); srcExprs[lhsIndexMapping[inputBatchDimension]] = batchDim; @@ -2435,10 +2826,10 @@ struct ConvolutionOpGeneralConversion .create( loc, /*resultTensors=*/ - llvm::makeArrayRef(zeroTensor.getType()), + llvm::ArrayRef(zeroTensor.getType()), /*inputs=*/ - llvm::makeArrayRef({modifiedLhs, modifiedRhs}), - /*outputs=*/llvm::makeArrayRef(zeroTensor), inferredMaps, + llvm::ArrayRef({modifiedLhs, modifiedRhs}), + /*outputs=*/llvm::ArrayRef(zeroTensor), inferredMaps, iterationLoops, /*bodyBuild=*/ [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange) { @@ -2517,6 +2908,13 @@ struct DepthwiseConvolutionOpConversion "expected output has static shapes"); } + // Immediately emit an EmptyOp for output tensors with zero dimension. + if (llvm::is_contained(resultType.getShape(), 0)) { + rewriter.replaceOpWithNewOp(op, resultType.getShape(), + resultType.getElementType()); + return success(); + } + // Apply padding and input dilation. llvm::SmallVector spatialDimMapping(spatialRank); std::iota(spatialDimMapping.begin(), spatialDimMapping.end(), 1); @@ -2943,7 +3341,7 @@ struct ReduceWindowOpConversion SmallVector resultDynamicDims; for (auto& en : llvm::enumerate(resultType.getShape())) { - if (en.value() != ShapedType::kDynamicSize) continue; + if (en.value() != ShapedType::kDynamic) continue; Value dimSize = rewriter.create(loc, input, en.index()); if (en.index() == 0 || static_cast(en.index()) == rank - 1) { // batch dims and channel dims can be derived from input dims @@ -3192,12 +3590,6 @@ struct GatherConversion : public OpConversionPattern { ArrayRef startIndexMap = gatherOp.getDimensionNumbers().getStartIndexMap(); - auto extractAsIndex = [&](Value input, ArrayRef index) -> Value { - return rewriter.create( - loc, rewriter.getIndexType(), - rewriter.create(loc, input, index)); - }; - // We'll need these later and creating them on demand we end up with // duplicates, which also makes lit tests really hard to write. SmallVector constants; @@ -3206,25 +3598,8 @@ struct GatherConversion : public OpConversionPattern { rewriter.create(loc, rewriter.getIndexAttr(i))); } - // Create ops to calculate the dynamic dimensions of the return shape, which - // are needed for the init tensor. - SmallVector dynDimSizes; - if (!resultType.hasStaticShape()) { - SmallVector returnShapes; - if (failed(gatherOp.reifyReturnTypeShapes(rewriter, adaptor.getOperands(), - returnShapes))) - return rewriter.notifyMatchFailure(gatherOp, - "could not reify return shape"); - assert(returnShapes.size() == 1); - Value returnShape = returnShapes[0]; - - for (int i = 0; i < resultRank; ++i) - if (resultType.isDynamicDim(i)) - dynDimSizes.push_back(extractAsIndex(returnShape, constants[i])); - } - - Value emptyOp = rewriter.create( - loc, resultType.getShape(), resultType.getElementType(), dynDimSizes); + auto emptyOp = getEmptyTensorFor(rewriter, loc, resultType, gatherOp, + adaptor.getOperands()); ValueRange ins; SmallVector indexingMaps( @@ -3348,7 +3723,7 @@ struct GatherConversion : public OpConversionPattern { extractOperand = operand; } else { // Cannot extract from unranked tensors, cast to ranked first. 
- SmallVector dims(operandRank, ShapedType::kDynamicSize); + SmallVector dims(operandRank, ShapedType::kDynamic); auto type = RankedTensorType::get( dims, operand.getType().cast().getElementType()); extractOperand = rewriter.create(loc, type, operand); @@ -3453,104 +3828,6 @@ class DotGeneralOpConversion : public OpConversionPattern { } }; -class SelectOpToMapConverter : public OpConversionPattern { - public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - mhlo::SelectOp op, OpAdaptor adaptor, - ConversionPatternRewriter& rewriter) const final { - Location loc = op.getLoc(); - - // Find result type, if on tensors. - Optional resultTy; - resultTy = typeConverter->convertType(op.getType()).dyn_cast(); - - // Check result type compatibility. - if (!resultTy || !resultTy->hasRank() || - !(resultTy->getElementType().isSignlessIntOrFloat() || - resultTy->getElementType().isa())) { - return rewriter.notifyMatchFailure( - op, "mismatched operand/result types or iterator count"); - } - - auto isScalar = [&](Value v) { - return v.getType().cast().getRank() == 0; - }; - - // Predicate in mhlo.select can be a shaped type with the same size as other - // operands, or a scalar. - const bool isScalarPred = isScalar(op.getPred()); - const bool allOperandsAreScalar = - isScalarPred && isScalar(op.getOnTrue()) && isScalar(op.getOnFalse()); - - // Within a linalg op, we can immediately de-tensorsize if the computation - // is scalar. We do not do this on the top-level, as that would break the - // nice invariant that all programs are exclusively on tensors, which is - // currently relied on for fusion in some pipelines. - if (allOperandsAreScalar && isInBodyOfLinalgOps(op)) { - SmallVector inputs; - for (auto input : adaptor.getOperands()) { - inputs.push_back(rewriter.create(loc, input)); - } - - Value scalarResult = mhlo::MhloOpToStdScalarOp::mapOp( - op, resultTy->getElementType(), inputs, &rewriter); - - rewriter.replaceOpWithNewOp(op, *resultTy, - scalarResult); - return success(); - } - - Value predValue; - ValueRange mappedInputs = adaptor.getOperands(); - // If predicate is a scalar, do not pass it as an argument to linalg.map, - // because linalg.map does not support broadcasting scalar values. Instead, - // extract the value and use it in the map block directly. - if (isScalarPred) { - predValue = rewriter.create(loc, adaptor.getPred()); - mappedInputs = mappedInputs.drop_front(); - } - - auto emptyTensor = - getEmptyTensorFor(rewriter, loc, *resultTy, op, op.getOperands()); - - auto linalgOp = rewriter.create(loc, *resultTy, mappedInputs, - emptyTensor); - linalgOp->setAttrs(linalg::getPrunedAttributeList(op)); - - { - OpBuilder::InsertionGuard guard(rewriter); - Region& region = linalgOp.getRegion(); - - SmallVector blockArgTypes; - SmallVector blockArgLocs; - for (Value v : mappedInputs) { - blockArgTypes.push_back(getElementTypeOrSelf(v)); - blockArgLocs.push_back(v.getLoc()); - } - - Block* block = rewriter.createBlock(®ion, region.end(), blockArgTypes, - blockArgLocs); - - // If predicate is scalar, the block has two arguments (on_true, on_false) - // and the predicate value is extracted outside of the block. - // If predicate is shaped, the block has three arguments (pred, on_true, - // on_false). - Value innerResult = rewriter.create( - loc, getElementTypeOrSelf(emptyTensor), - isScalarPred ? 
ValueRange{predValue, block->getArgument(0), - block->getArgument(1)} - : block->getArguments()); - - rewriter.create(loc, innerResult); - } - - rewriter.replaceOp(op, linalgOp.getResult()); - return success(); - } -}; - /// Converts a HLO operation to a linalg.map op that contains the corresponding /// scalar operations. template @@ -3561,16 +3838,17 @@ class PointwiseToLinalgMapConverter : public OpConversionPattern { LogicalResult matchAndRewrite( OpTy op, typename OpTy::Adaptor adaptor, ConversionPatternRewriter& rewriter) const final { - auto getRank = [](Value v) { - return v.getType().cast().getRank(); - }; - int64_t maxRank = getRank(adaptor.getOperands().front()); + auto loc = op.getLoc(); + int64_t maxRank = getMaxRank(adaptor); - // Apply only if all operands have the same rank. - if (!llvm::all_of(adaptor.getOperands(), - [&](Value v) { return getRank(v) == maxRank; })) { - return rewriter.notifyMatchFailure(op, - "Operands must have the same rank."); + // Apply only if all operands are scalar or have the same rank. Some ops, + // like `mhlo.select`, support implicit broadcasting of scalars. + if (!llvm::all_of(adaptor.getOperands(), [&](Value v) { + int64_t r = getRank(v); + return r == 0 || r == maxRank; + })) { + return rewriter.notifyMatchFailure( + op, "Operands must be of same rank or scalar."); } // Find result type, if on tensors. @@ -3586,59 +3864,114 @@ class PointwiseToLinalgMapConverter : public OpConversionPattern { op, "mismatched operand/result types or iterator count"); } - auto loc = op.getLoc(); - // Within a thlo.map region, we can immediately de-tensorsize if the - // computation is scalar. We do not do this on the top-level, as that would - // break the nice invariant that all programs are exclusively on tensors, - // which is currently relied on for fusion in some pipelines. - if (maxRank == 0 && isInBodyOfLinalgOps(op)) { - SmallVector inputs; - for (auto input : adaptor.getOperands()) { - inputs.push_back( - rewriter.create(loc, input, ValueRange())); - } - Value scalarResult = mhlo::MhloOpToStdScalarOp::mapOp( - op, resultTy->getElementType(), inputs, &rewriter); - if (!scalarResult) return failure(); - rewriter.replaceOpWithNewOp(op, *resultTy, - scalarResult); - return success(); - } + // All-scalar pointwise ops inside of linalg ops are processes by + // ScalarHloToArithmeticPattern. + if (maxRank == 0 && isInBodyOfLinalgOps(op)) return failure(); // Find input/output values and types. - ValueRange inputs = adaptor.getOperands(); Value emptyTensor = getEmptyTensorFor(rewriter, loc, *resultTy, op, adaptor.getOperands()); - auto mapOp = - rewriter.create(loc, *resultTy, inputs, emptyTensor); - mapOp->setAttrs(linalg::getPrunedAttributeList(op)); - { - OpBuilder::InsertionGuard guard(rewriter); - auto& region = mapOp.getRegion(); - - SmallVector blockArgTypes; - SmallVector blockArgLocs; - for (Value v : inputs) { - blockArgTypes.push_back(getElementTypeOrSelf(v)); - blockArgLocs.push_back(v.getLoc()); + // Mapped inputs are cast to the same shape as the init tensor. + // Values from scalar inputs are extracted and used directly in the block. 
+ SmallVector mappedInputs; + SmallVector scalarInputs; + for (Value input : adaptor.getOperands()) { + if (getRank(input) == maxRank) { + mappedInputs.push_back(coerceTensorShape( + rewriter, loc, cast>(input), + emptyTensor.getType())); + scalarInputs.push_back(nullptr); + } else { + scalarInputs.push_back(rewriter.create(loc, input)); } - Block* block = rewriter.createBlock(®ion, region.end(), blockArgTypes, - blockArgLocs); - - Value innerResult = mhlo::MhloOpToStdScalarOp::mapOp( - op, getElementTypeOrSelf(emptyTensor), block->getArguments(), - &rewriter); - rewriter.create(loc, innerResult); } + auto mapOp = rewriter.create( + loc, mappedInputs, emptyTensor, + [&](OpBuilder& b, Location loc, ValueRange args) { + Value innerResult = mhlo::MhloOpToStdScalarOp::mapOp( + op, getElementTypeOrSelf(emptyTensor), + interleaveScalarAndBlockArgs(scalarInputs, args), &b); + b.create(loc, innerResult); + }, + linalg::getPrunedAttributeList(op)); + rewriter.replaceOp(op, mapOp->getResults()); return success(); } + + protected: + int64_t getRank(Value v) const { + return v.getType().cast().getRank(); + } + + int64_t getMaxRank(typename OpTy::Adaptor adaptor) const { + int64_t maxRank = 0; + for (auto operand : adaptor.getOperands()) { + maxRank = std::max(maxRank, getRank(operand)); + } + return maxRank; + } + + // Inserts block arguments in places where scalar inputs have a nullptr. + SmallVector interleaveScalarAndBlockArgs(ValueRange scalarInputs, + ValueRange blockArgs) const { + SmallVector result; + auto argsIter = blockArgs.begin(); + for (Value scalarInput : scalarInputs) { + if (scalarInput) { + result.push_back(scalarInput); + } else { + result.push_back(*argsIter); + ++argsIter; + } + } + return result; + } +}; + +class SetDimensionSizeConverter + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::SetDimensionSizeOp setDimensionSizeOp, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const final { + // We can lower SetDimensionSize to tensor extract. This turns into a + // regular dynamic shape. Note that the bounds annotation is still around + // but may be no longer valid depending on choices made by bufferization. 
+ Location loc = setDimensionSizeOp.getLoc(); + auto resultType = setDimensionSizeOp.getType().cast(); + + SmallVector offsets(resultType.getRank(), + rewriter.getIndexAttr(0)); + SmallVector strides(resultType.getRank(), + rewriter.getIndexAttr(1)); + SmallVector sizes(llvm::map_range( + resultType.getShape(), [&](int64_t dim) -> OpFoldResult { + return rewriter.getIndexAttr(dim); + })); + Value dimensionSize = + rewriter.create(loc, setDimensionSizeOp.getSize()); + sizes[setDimensionSizeOp.getDimension()] = + rewriter + .create(loc, rewriter.getIndexType(), + dimensionSize) + .getResult(); + + rewriter.replaceOpWithNewOp( + setDimensionSizeOp, resultType, adaptor.getOperand(), offsets, sizes, + strides); + return success(); + } }; struct HloLegalizeToLinalgPass : public impl::HloLegalizeToLinalgPassBase { + using HloLegalizeToLinalgPassBase::HloLegalizeToLinalgPassBase; + void getDependentDialects(DialectRegistry& registry) const override { registry.insertadd< BitcastConvertConverter, - BroadcastConverter, ConcatenateConverter, - ConstConverterTensor, HloDynamicBroadcastInDimConverter, - HloBroadcastInDimConverter, IotaConverter, + ConcatenateConverter, + ConstConverterTensor, + IotaConverter, EinsumToLinalgConverter, IotaConverter, RealDynamicSliceConverter, ReshapeOpConverter, ReverseConverter, + SetDimensionSizeConverter, SliceConverter, DynamicSliceConverter, DynamicUpdateSliceConverter, GatherConversion, PadOpConversion, PadOpNegativePaddingConversion, - ReduceConversion, ReduceWindowOpOnTensorsGenericConversion, ReduceWindowOpConversion, RngUniformConversion, @@ -3699,6 +4035,9 @@ void populateHloToLinalgConversionPattern(MLIRContext* context, if (enablePrimitiveOps) { patterns->add< + BroadcastInDimOpToBroadcastConverter, + BroadcastOpToBroadcastConverter, + DynamicBroadcastInDimOpToBroadcastConverter, MapOpToMapConverter, PointwiseToLinalgMapConverter, PointwiseToLinalgMapConverter, @@ -3737,6 +4076,7 @@ void populateHloToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgMapConverter, PointwiseToLinalgMapConverter, PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, PointwiseToLinalgMapConverter, PointwiseToLinalgMapConverter, PointwiseToLinalgMapConverter, @@ -3744,13 +4084,17 @@ void populateHloToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgMapConverter, PointwiseToLinalgMapConverter, PointwiseToLinalgMapConverter, + PointwiseToLinalgMapConverter, PointwiseToLinalgMapConverter, PointwiseToLinalgMapConverter, - TransposeOpToTransposeConverter, - SelectOpToMapConverter + ReduceOpToReduceConverter, + TransposeOpToTransposeConverter >(typeConverter, context); } else { patterns->add< + BroadcastConverter, + HloBroadcastInDimConverter, + HloDynamicBroadcastInDimConverter, MapOpToGenericConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, @@ -3797,8 +4141,10 @@ void populateHloToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + ReduceOpToGenericConverter, TransposeConverter >(typeConverter, context); } @@ -3817,11 +4163,15 @@ void populateHloToLinalgConversionPattern(MLIRContext* context, patterns->add< ConvolutionOpGeneralConversion, DotGeneralOpConversion>(typeConverter, context, PatternBenefit(1)); + linalg::populateEraseUnusedOperandsAndResultsPatterns(*patterns); // clang-format on } -std::unique_ptr> createLegalizeHloToLinalgPass() { - return 
std::make_unique(); +std::unique_ptr> createLegalizeHloToLinalgPass( + bool enablePrimitiveOps) { + HloLegalizeToLinalgPassOptions options; + options.enablePrimitiveOps = enablePrimitiveOps; + return std::make_unique(options); } std::unique_ptr createHloToLinalgTypeConverter() { diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/legalize_to_standard.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/legalize_to_standard/legalize_to_standard.cc similarity index 95% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/legalize_to_standard.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/legalize_to_standard/legalize_to_standard.cc index 7748db855e5..28ec1333929 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/legalize_to_standard.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/legalize_to_standard/legalize_to_standard.cc @@ -15,11 +15,12 @@ limitations under the License. // This file implements logic for lowering MHLO dialect to Standard dialect. +#include #include -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/passes.h" +#include "mhlo/transforms/rewriters.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Math/IR/Math.h" @@ -29,12 +30,12 @@ limitations under the License. namespace mlir { namespace { -#include "generated_legalize_to_standard.inc" +#include "legalize_to_standard/generated_legalize_to_standard.inc" } // end anonymous namespace namespace mhlo { #define GEN_PASS_DEF_LEGALIZETOSTANDARDPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace { @@ -56,7 +57,7 @@ class CompareIConvert : public OpRewritePattern { !rhsType.getElementType().isSignlessInteger()) return failure(); - Optional comparePredicate = llvm::None; + Optional comparePredicate = std::nullopt; switch (op.getComparisonDirection()) { case ComparisonDirection::EQ: comparePredicate = arith::CmpIPredicate::eq; @@ -104,7 +105,7 @@ class CompareFConvert : public OpRewritePattern { !rhsType.getElementType().isa()) return failure(); - Optional comparePredicate = llvm::None; + Optional comparePredicate = std::nullopt; switch (op.getComparisonDirection()) { case ComparisonDirection::EQ: comparePredicate = arith::CmpFPredicate::OEQ; diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/legalize_to_standard_patterns.td b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/legalize_to_standard/legalize_to_standard_patterns.td similarity index 75% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/legalize_to_standard_patterns.td rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/legalize_to_standard/legalize_to_standard_patterns.td index c139b69ff7f..2436dadb870 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/legalize_to_standard_patterns.td +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/legalize_to_standard/legalize_to_standard_patterns.td @@ -19,13 +19,13 @@ include "mlir/IR/OpBase.td" include "mlir/Dialect/Arith/IR/ArithOps.td" include "mlir/Dialect/Math/IR/MathOps.td" include "mlir/Dialect/Func/IR/FuncOps.td" -include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" +include "mhlo/IR/hlo_ops.td" //===----------------------------------------------------------------------===// 
// Nullary op patterns. //===----------------------------------------------------------------------===// -def : Pat<(HLO_ConstantOp ElementsAttr:$value), +def : Pat<(MHLO_ConstantOp ElementsAttr:$value), (Arith_ConstantOp $value)>; //===----------------------------------------------------------------------===// @@ -43,46 +43,46 @@ def createFastMathNone : NativeCodeCall< // Unary Lowering Patterns. -def : Pat<(HLO_CeilOp HLO_FpTensor:$i), (Math_CeilOp $i)>; +def : Pat<(MHLO_CeilOp MHLO_FpTensor:$i), (Math_CeilOp $i, (createFastMathNone ))>; // Binary Lowering Patterns. -def : Pat<(HLO_AndOp HLO_IntTensor:$l, HLO_IntTensor:$r), +def : Pat<(MHLO_AndOp MHLO_IntTensor:$l, MHLO_IntTensor:$r), (Arith_AndIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_OrOp HLO_IntTensor:$l, HLO_IntTensor:$r), +def : Pat<(MHLO_OrOp MHLO_IntTensor:$l, MHLO_IntTensor:$r), (Arith_OrIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_AddOp HLO_FpTensor:$l, HLO_FpTensor:$r), +def : Pat<(MHLO_AddOp MHLO_FpTensor:$l, MHLO_FpTensor:$r), (Arith_AddFOp $l, $r, (createFastMathNone )), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_SubtractOp HLO_FpTensor:$l, HLO_FpTensor:$r), +def : Pat<(MHLO_SubtractOp MHLO_FpTensor:$l, MHLO_FpTensor:$r), (Arith_SubFOp $l, $r, (createFastMathNone )), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_MulOp HLO_FpTensor:$l, HLO_FpTensor:$r), +def : Pat<(MHLO_MulOp MHLO_FpTensor:$l, MHLO_FpTensor:$r), (Arith_MulFOp $l, $r, (createFastMathNone )), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_DivOp HLO_FpTensor:$l, HLO_FpTensor:$r), +def : Pat<(MHLO_DivOp MHLO_FpTensor:$l, MHLO_FpTensor:$r), (Arith_DivFOp $l, $r, (createFastMathNone )), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_RemOp HLO_FpTensor:$l, HLO_FpTensor:$r), +def : Pat<(MHLO_RemOp MHLO_FpTensor:$l, MHLO_FpTensor:$r), (Arith_RemFOp $l, $r, (createFastMathNone )), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_AddOp HLO_IntTensor:$l, HLO_IntTensor:$r), +def : Pat<(MHLO_AddOp MHLO_IntTensor:$l, MHLO_IntTensor:$r), (Arith_AddIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_SubtractOp HLO_IntTensor:$l, HLO_IntTensor:$r), +def : Pat<(MHLO_SubtractOp MHLO_IntTensor:$l, MHLO_IntTensor:$r), (Arith_SubIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_MulOp HLO_IntTensor:$l, HLO_IntTensor:$r), +def : Pat<(MHLO_MulOp MHLO_IntTensor:$l, MHLO_IntTensor:$r), (Arith_MulIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_DivOp HLO_IntTensor:$l, HLO_IntTensor:$r), +def : Pat<(MHLO_DivOp MHLO_IntTensor:$l, MHLO_IntTensor:$r), (Arith_DivSIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_RemOp HLO_IntTensor:$l, HLO_IntTensor:$r), +def : Pat<(MHLO_RemOp MHLO_IntTensor:$l, MHLO_IntTensor:$r), (Arith_RemSIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(HLO_SelectOp $pred, $tv, $fv), +def : Pat<(MHLO_SelectOp $pred, $tv, $fv), (SelectOp $pred, $tv, $fv), [(IsSameSizeConstraint $pred, $tv), (IsSameSizeConstraint $tv, $fv)]>; diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/legalize_trigonometric_to_approximation/legalize_trigonometric_to_approximation.cc similarity index 97% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/legalize_trigonometric_to_approximation/legalize_trigonometric_to_approximation.cc index 
ceb9be3cc3b..e2fb93aa146 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/legalize_trigonometric_to_approximation/legalize_trigonometric_to_approximation.cc @@ -18,8 +18,8 @@ limitations under the License. #include -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mhlo/transforms/passes.h" +#include "mhlo/transforms/rewriters.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Math/IR/Math.h" @@ -31,7 +31,7 @@ namespace mlir { namespace mhlo { #define GEN_PASS_DEF_LEGALIZETANHTOAPPROXIMATIONPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace { diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/lower_complex.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/lower_complex/lower_complex.cc similarity index 89% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/lower_complex.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/lower_complex/lower_complex.cc index 62aaffb43d0..977db01f842 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/lower_complex.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/lower_complex/lower_complex.cc @@ -24,10 +24,9 @@ limitations under the License. #include #include "llvm/ADT/STLExtras.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" -#include "mlir-hlo/utils/hlo_utils.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/passes.h" +#include "mhlo/transforms/rewriters.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinOps.h" @@ -38,12 +37,13 @@ limitations under the License. #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassRegistry.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "utils/hlo_utils.h" namespace mlir { namespace mhlo { #define GEN_PASS_DEF_LOWERCOMPLEXPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace { class LowerComplexPass : public impl::LowerComplexPassBase { @@ -52,7 +52,7 @@ class LowerComplexPass : public impl::LowerComplexPassBase { void runOnOperation() override; }; -#include "generated_lower_complex.inc" +#include "lower_complex/generated_lower_complex.inc" // Lowers the complex operations that can be represented using other operations. void LowerComplexPass::runOnOperation() { diff --git a/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/lower_complex/lower_complex_patterns.td b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/lower_complex/lower_complex_patterns.td new file mode 100644 index 00000000000..1fc4128cba0 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/lower_complex/lower_complex_patterns.td @@ -0,0 +1,136 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is the legalization pattern that converts complex operations into +// equivalent real value operations. + +include "mlir/IR/OpBase.td" +include "mlir/Dialect/Func/IR/FuncOps.td" +include "mhlo/IR/hlo_ops.td" + +//===----------------------------------------------------------------------===// +// Binary op patterns. +//===----------------------------------------------------------------------===// + +// Add and subtraction are elementwise and can be distributed across the real +// and imaginary components. +foreach elementwiseOp = [MHLO_AddOp, MHLO_SubtractOp] in + def : Pat<(elementwiseOp MHLO_ComplexTensor:$lhs, + MHLO_ComplexTensor:$rhs), + (MHLO_ComplexOp + (elementwiseOp (MHLO_RealOp $lhs), (MHLO_RealOp $rhs)), + (elementwiseOp (MHLO_ImagOp $lhs), (MHLO_ImagOp $rhs)))>; + +// Complex multiplication results in a cross product multiplication between the +// real and imaginary components such that: +// result.real = lhs.real * rhs.real - lhs.imag * rhs.imag +// result.imag = lhs.imag * rhs.real + lhs.real * rhs.imag +def : Pat<(MHLO_MulOp MHLO_ComplexTensor:$lhs, + MHLO_ComplexTensor:$rhs), + (MHLO_ComplexOp + (MHLO_SubtractOp + (MHLO_MulOp + (MHLO_RealOp:$lhs_real $lhs), + (MHLO_RealOp:$rhs_real $rhs)), + (MHLO_MulOp + (MHLO_ImagOp:$lhs_imag $lhs), + (MHLO_ImagOp:$rhs_imag $rhs))), + (MHLO_AddOp + (MHLO_MulOp $lhs_real, $rhs_imag), + (MHLO_MulOp $lhs_imag, $rhs_real)))>; + + +// Division is performed by normalizing the denominator by multiplying by the +// conjugate of the rhs. 
+// numerator = lhs * conj(rhs) +// denominator = rhs * conj(rhs) +def : Pat<(MHLO_DivOp MHLO_ComplexTensor:$lhs, MHLO_ComplexTensor:$rhs), + (MHLO_ComplexOp + (MHLO_DivOp + (MHLO_RealOp (MHLO_MulOp:$num $lhs, + (MHLO_ComplexOp:$conj + (MHLO_RealOp $rhs), + (MHLO_NegOp (MHLO_ImagOp $rhs))))), + (MHLO_AddOp:$den + (MHLO_MulOp (MHLO_RealOp $rhs), (MHLO_RealOp $rhs)), + (MHLO_MulOp (MHLO_ImagOp $rhs), (MHLO_ImagOp $rhs)))), + (MHLO_DivOp (MHLO_ImagOp $num), $den))>; + +// Absolute value is evaluated as: +// result = sqrt(val.real * val.real + val.imag * val.imag) +def : Pat<(MHLO_AbsOp MHLO_ComplexTensor:$val), + (MHLO_SqrtOp + (MHLO_AddOp + (MHLO_MulOp (MHLO_RealOp:$real $val), $real), + (MHLO_MulOp (MHLO_ImagOp:$imag $val), $imag)))>; + +// Can deconstruct sin(a + ib) as follows: +// sin(a) * cosh(b) + icos(a) * sinh(b) +// sinh(b) = (e^x - e^-x) / 2 +// cosh(b) = (e^x + e^-x) / 2 +def : Pat<(MHLO_SineOp MHLO_ComplexTensor:$val), + (MHLO_ComplexOp + (MHLO_DivOp + (MHLO_MulOp + (MHLO_SineOp (MHLO_RealOp:$real $val)), + (MHLO_AddOp + (MHLO_ExpOp:$exp (MHLO_ImagOp:$imag $val)), + (MHLO_ExpOp:$nexp (MHLO_NegOp $imag)))), + (MHLO_ConstantOp : $two (ConstantSplat<"2.0"> $real))), + (MHLO_DivOp + (MHLO_MulOp + (MHLO_CosineOp $real), + (MHLO_SubtractOp $exp, $nexp)), $two))>; + +// Can deconstruct cos(a + ib) as follows: +// cos(a) * cosh(b) - isin(a) * sinh(b) +// sinh(b) = (e^x - e^-x) / 2 +// cosh(b) = (e^x + e^-x) / 2 +def : Pat<(MHLO_CosineOp MHLO_ComplexTensor:$val), + (MHLO_ComplexOp + (MHLO_DivOp + (MHLO_MulOp + (MHLO_CosineOp (MHLO_RealOp:$real $val)), + (MHLO_AddOp + (MHLO_ExpOp:$exp (MHLO_ImagOp:$imag $val)), + (MHLO_ExpOp:$nexp (MHLO_NegOp $imag)))), + (MHLO_ConstantOp : $two (ConstantSplat<"2.0"> $real))), + (MHLO_DivOp + (MHLO_MulOp + (MHLO_SineOp $real), + (MHLO_SubtractOp $nexp, $exp)), $two))>; + +// Exponential can be lowered to an exponential on the real component and a +// sum of sinusoids of the imaginary component, which equates to a normal +// exponential operator multiplied by Euler's formula. +// +// Exp(a + ib) = Exp(a) * Exp(ib) = Exp(a) * Cos(b) + Exp(a) * iSin(b)) +class MHLO_ComparisonDirectionValue : + ConstantAttr; + +def : Pat<(MHLO_ExpOp MHLO_ComplexTensor:$val), + (MHLO_ComplexOp + (MHLO_MulOp + (MHLO_CosineOp (MHLO_ImagOp:$imag $val)), + (MHLO_ExpOp:$exp (MHLO_RealOp:$real $val))), + (MHLO_MulOp (MHLO_SineOp $imag), $exp))>; + +foreach pair = [[MHLO_ComparisonDirectionValue<"NE">, MHLO_OrOp], + [MHLO_ComparisonDirectionValue<"EQ">, MHLO_AndOp]] in { +def : Pat<(MHLO_CompareOp MHLO_ComplexTensor:$lhs, MHLO_ComplexTensor:$rhs, pair[0], $compare_type), + (pair[1] + (MHLO_CompareOp (MHLO_RealOp $lhs), (MHLO_RealOp $rhs), pair[0], $compare_type), + (MHLO_CompareOp (MHLO_ImagOp $lhs), (MHLO_ImagOp $rhs), pair[0], $compare_type))>; +} diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/lower_general_dot.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/lower_general_dot/lower_general_dot.cc similarity index 86% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/lower_general_dot.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/lower_general_dot/lower_general_dot.cc index d7085fdfbcd..28b9288154c 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/lower_general_dot.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/lower_general_dot/lower_general_dot.cc @@ -17,12 +17,13 @@ limitations under the License. 
#include +#include #include #include "llvm/ADT/STLExtras.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/passes.h" +#include "mhlo/transforms/rewriters.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/IR/Attributes.h" @@ -37,7 +38,7 @@ namespace mlir { namespace mhlo { #define GEN_PASS_DEF_LEGALIZEGENERALDOTPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace { @@ -51,14 +52,14 @@ Value transposeReshape(Value arg, Location loc, int64_t leftSize = 1; for (auto dim : leftDims) { leftSize = (ShapedType::isDynamic(argShape[dim]) || leftSize < 0) - ? ShapedType::kDynamicSize + ? ShapedType::kDynamic : leftSize * argShape[dim]; } int64_t rightSize = 1; for (auto dim : rightDims) { rightSize = (ShapedType::isDynamic(argShape[dim]) || rightSize < 0) - ? ShapedType::kDynamicSize + ? ShapedType::kDynamic : rightSize * argShape[dim]; } @@ -73,7 +74,7 @@ Value transposeReshape(Value arg, Location loc, auto transposePermutationAttr = DenseIntElementsAttr::get(transposePermutationType, - llvm::makeArrayRef(transposePermutation)) + llvm::ArrayRef(transposePermutation)) .cast(); // Compute the resulting shape. @@ -111,16 +112,18 @@ Value transposeReshape(Value arg, Location loc, SmallVector reshapeDims; auto multiplyDynamicDims = [&](llvm::ArrayRef dims) -> Value { Value dynamicSize = rewriter.create( - loc, RankedTensorType::get({1}, rewriter.getI32Type()), arg, - rewriter.getI64IntegerAttr(dims.front())); - + loc, arg, rewriter.getI64IntegerAttr(dims.front())); + Value dynamicSizeReshaped = rewriter.create( + loc, RankedTensorType::get({1}, rewriter.getI32Type()), dynamicSize); for (auto idx : dims.drop_front()) { Value dim = rewriter.create( - loc, RankedTensorType::get({1}, rewriter.getI32Type()), arg, - rewriter.getI64IntegerAttr(idx)); - dynamicSize = rewriter.create(loc, dynamicSize, dim); + loc, arg, rewriter.getI64IntegerAttr(idx)); + Value dimReshaped = rewriter.create( + loc, RankedTensorType::get({1}, rewriter.getI32Type()), dim); + dynamicSizeReshaped = + rewriter.create(loc, dynamicSizeReshaped, dimReshaped); } - return dynamicSize; + return dynamicSizeReshaped; }; if (leftSize < 0) { @@ -140,7 +143,6 @@ Value transposeReshape(Value arg, Location loc, Value reshapeDimsTensor = rewriter.create( loc, RankedTensorType::get({2}, rewriter.getI32Type()), reshapeDims, rewriter.getI64IntegerAttr(0)); - return rewriter.create(loc, reshapedType, transposeResult, reshapeDimsTensor); } @@ -206,13 +208,13 @@ struct GeneralDotConvert : public OpRewritePattern { RankedTensorType rhsTy = rhs.getType().dyn_cast(); if (!lhsTy || !rhsTy) return failure(); - lhs = processDotArg(op.getLhs(), op.getLoc(), - dotNumbers.getLhsContractingDimensions(), - /*outerDimsFirst=*/true, rewriter); + lhs = llvm::cast>(processDotArg( + op.getLhs(), op.getLoc(), dotNumbers.getLhsContractingDimensions(), + /*outerDimsFirst=*/true, rewriter)); - rhs = processDotArg(op.getRhs(), op.getLoc(), - dotNumbers.getRhsContractingDimensions(), - /*outerDimsFirst=*/false, rewriter); + rhs = llvm::cast>(processDotArg( + op.getRhs(), op.getLoc(), dotNumbers.getRhsContractingDimensions(), + /*outerDimsFirst=*/false, rewriter)); // Accept only static shaped types. 
auto lhsShapeType = lhs.getType().dyn_cast_or_null(); @@ -221,16 +223,11 @@ struct GeneralDotConvert : public OpRewritePattern { ArrayAttr precisionConfig; if (op.getPrecisionConfig()) precisionConfig = *op.getPrecisionConfig(); - SmallVector results; - LogicalResult res = - DotOp::inferReturnTypes(rewriter.getContext(), None, {lhs, rhs}, - op->getAttrDictionary(), {}, results); - (void)res; - assert(succeeded(res) && "invalid input to dot"); ShapedType resultTy = op.getType().cast(); - ShapedType newTy = - results.front().cast().clone(resultTy.getElementType()); + ShapedType newTy = RankedTensorType::get( + {lhsShapeType.getShape()[0], rhsShapeType.getShape()[1]}, + resultTy.getElementType()); Value newDotOp = rewriter.create(op.getLoc(), newTy, lhs, rhs, precisionConfig); if (static_cast(lhsContractingDims.size()) == @@ -266,18 +263,22 @@ struct GeneralDotConvert : public OpRewritePattern { for (auto contractingDim : contractingDims) { for (; index < contractingDim; index++) { staticDims.push_back(ty.getDimSize(index)); - dynDims.push_back(rewriter.create( - loc, RankedTensorType::get({1}, rewriter.getI32Type()), arg, - rewriter.getI64IntegerAttr(index))); + Value dynDim = rewriter.create( + loc, arg, rewriter.getI64IntegerAttr(index)); + Value dynDimReshaped = rewriter.create( + loc, RankedTensorType::get({1}, rewriter.getI32Type()), dynDim); + dynDims.push_back(dynDimReshaped); } index++; } for (; index < ty.getRank(); index++) { staticDims.push_back(ty.getDimSize(index)); - dynDims.push_back(rewriter.create( - loc, RankedTensorType::get({1}, rewriter.getI32Type()), arg, - rewriter.getI64IntegerAttr(index))); + Value dynDim = rewriter.create( + loc, arg, rewriter.getI64IntegerAttr(index)); + Value dynDimReshaped = rewriter.create( + loc, RankedTensorType::get({1}, rewriter.getI32Type()), dynDim); + dynDims.push_back(dynDimReshaped); } }; diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/map_chlo_to_hlo_op.h similarity index 94% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/map_chlo_to_hlo_op.h index ce277ab9b03..a2974a2e49d 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/map_chlo_to_hlo_op.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_CHLO_TO_HLO_OP_H -#define MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_CHLO_TO_HLO_OP_H +#ifndef MLIR_HLO_MHLO_TRANSFORMS_MAP_CHLO_TO_HLO_OP_H +#define MLIR_HLO_MHLO_TRANSFORMS_MAP_CHLO_TO_HLO_OP_H #include -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mhlo/IR/hlo_ops.h" #include "mlir/IR/PatternMatch.h" #include "stablehlo/dialect/ChloOps.h" @@ -49,9 +49,8 @@ inline llvm::Optional mhloComparisonDirection( return mhlo::ComparisonDirection::LE; case chlo::ComparisonDirection::LT: return mhlo::ComparisonDirection::LT; - default: - return {}; } + return {}; } inline llvm::Optional mhloComparisonType( @@ -67,9 +66,8 @@ inline llvm::Optional mhloComparisonType( return mhlo::ComparisonType::SIGNED; case chlo::ComparisonType::UNSIGNED: return mhlo::ComparisonType::UNSIGNED; - default: - return {}; } + return {}; } struct HloCompareAdaptor { @@ -135,4 +133,4 @@ void populateForBroadcastingBinaryOp(MLIRContext *context, } // namespace chlo } // namespace mlir -#endif // MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_CHLO_TO_HLO_OP_H +#endif // MLIR_HLO_MHLO_TRANSFORMS_MAP_CHLO_TO_HLO_OP_H diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h similarity index 96% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h index ceb0576bdb6..7e59059449a 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h @@ -13,13 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_MHLO_TO_SCALAR_OP_H -#define MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_MHLO_TO_SCALAR_OP_H +#ifndef MLIR_HLO_MHLO_TRANSFORMS_MAP_MHLO_TO_SCALAR_OP_H +#define MLIR_HLO_MHLO_TRANSFORMS_MAP_MHLO_TO_SCALAR_OP_H + +#include #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Math/IR/Math.h" @@ -56,6 +58,10 @@ struct MhloToScalarOp { using UOp = ::mlir::arith::AndIOp; }; template <> +struct MhloToScalarOp { + using FOp = ::mlir::math::CbrtOp; +}; +template <> struct MhloToScalarOp { using FOp = ::mlir::arith::CmpFOp; using IOp = ::mlir::arith::CmpIOp; @@ -159,6 +165,11 @@ struct MhloToScalarOp { using COp = ::mlir::complex::SinOp; }; template <> +struct MhloToScalarOp { + using FOp = ::mlir::math::TanOp; + using COp = ::mlir::complex::TanOp; +}; +template <> struct MhloToScalarOp { using FOp = ::mlir::math::Atan2Op; using COp = ::mlir::complex::Atan2Op; @@ -200,7 +211,8 @@ template struct MapMhloOpToScalarOpImpl { Value operator()(Location loc, ArrayRef resultTypes, ArrayRef /*argTypes*/, ValueRange args, OpBuilder* b) { - return b->template create(loc, resultTypes, args, mlir::None); + return b->template create(loc, resultTypes, args, + std::nullopt); } }; @@ -211,7 +223,7 @@ struct MapMhloOpToScalarOpImpl { Type elementType = getElementTypeOrSelf(argTypes.front()); if (SupportedType{}(elementType)) { return b->template create(loc, resultTypes, args, - mlir::None); + std::nullopt); } return MapMhloOpToScalarOpImpl{}(loc, resultTypes, argTypes, args, b); @@ -318,31 +330,10 @@ inline Value getConstantOrSplat(OpBuilder* b, Location loc, Type t, return b->create(loc, t, v); } -template <> -inline Value mapMhloOpToStdScalarOp(Location loc, - ArrayRef resultTypes, - ArrayRef argTypes, - mhlo::CbrtOp::Adaptor adaptor, - OpBuilder* b) { - Type elementType = getElementTypeOrSelf(argTypes.front()); - if (auto floatType = elementType.dyn_cast()) { - // Convert cbrt(x) to copysign(cbrt(abs(x), 1.0 / 3.0), x). - // This is to allow cbrt using pow while still handling negative numbers. It - // should match most cbrt intrinsics. - Value abs = b->create(loc, adaptor.getOperand()); - Value third = b->create( - loc, b->getFloatAttr(floatType, 1.0 / 3.0)); - Value pow = b->create(loc, resultTypes[0], abs, third); - return b->create(loc, floatType, pow, - adaptor.getOperand()); - } - return nullptr; -} - template inline Optional getCmpPredicate(mhlo::ComparisonDirection, bool) { - return llvm::None; + return std::nullopt; } template <> @@ -357,7 +348,7 @@ inline Optional getCmpPredicate( .Case("GT", arith::CmpFPredicate::OGT) .Case("LE", arith::CmpFPredicate::OLE) .Case("LT", arith::CmpFPredicate::OLT) - .Default(llvm::None); + .Default(std::nullopt); } template <> @@ -375,7 +366,7 @@ inline Optional getCmpPredicate( isSigned ? arith::CmpIPredicate::sle : arith::CmpIPredicate::ule) .Case("LT", isSigned ? 
arith::CmpIPredicate::slt : arith::CmpIPredicate::ult) - .Default(llvm::None); + .Default(std::nullopt); } template <> @@ -627,7 +618,8 @@ inline Value mapConvertOpToStdScalarOp(Location loc, ArrayRef targetTypes, if (sourceType.isSignlessInteger(1) && mlir::arith::UIToFPOp::areCastCompatible(convertedSourceType, targetType)) { - return b->create(loc, resultTypes, args, mlir::None); + return b->create(loc, resultTypes, args, + std::nullopt); } if (sourceType.isUnsignedInteger() && mlir::arith::UIToFPOp::areCastCompatible(convertedSourceType, @@ -644,17 +636,19 @@ inline Value mapConvertOpToStdScalarOp(Location loc, ArrayRef targetTypes, // End of Added by DISC } if (mlir::arith::SIToFPOp::areCastCompatible(sourceType, targetType)) { - return b->create(loc, resultTypes, args, mlir::None); + return b->create(loc, resultTypes, args, + std::nullopt); } if (sourceType.isa() && targetType.isa()) { auto src = sourceType.cast(); auto res = targetType.cast(); if (src.getWidth() > res.getWidth()) { return b->create(loc, resultTypes, args, - mlir::None); + std::nullopt); } if (src.getWidth() < res.getWidth()) { - return b->create(loc, resultTypes, args, mlir::None); + return b->create(loc, resultTypes, args, + std::nullopt); } // There's no direct conversion between different 16 bit floating point // types, so go through 32 bit float. @@ -687,16 +681,16 @@ inline Value mapConvertOpToStdScalarOp(Location loc, ArrayRef targetTypes, auto res = targetType.cast(); if (src.getWidth() > res.getWidth()) { return b->create(loc, resultTypes, args, - mlir::None); + std::nullopt); } if (src.getWidth() < res.getWidth()) { // Special case boolean values, so they get casted to `1` instead of `-1`. if (IsUnsignedIntegerType{}(src)) { return b->create(loc, resultTypes, args, - mlir::None); + std::nullopt); } return b->create(loc, resultTypes, args, - mlir::None); + std::nullopt); } // No conversion is needed for the same width integers return args.front(); @@ -704,11 +698,13 @@ inline Value mapConvertOpToStdScalarOp(Location loc, ArrayRef targetTypes, if (targetType.isUnsignedInteger() && mlir::arith::FPToUIOp::areCastCompatible(convertedSourceType, targetType)) { - return b->create(loc, resultTypes, args, mlir::None); + return b->create(loc, resultTypes, args, + std::nullopt); } if (mlir::arith::FPToSIOp::areCastCompatible(convertedSourceType, targetType)) { - return b->create(loc, resultTypes, args, mlir::None); + return b->create(loc, resultTypes, args, + std::nullopt); } if (targetType.isa()) { Type targetElementType = targetType.cast().getElementType(); @@ -1290,4 +1286,4 @@ struct MhloOpToStdScalarOp { } // namespace mhlo } // namespace mlir -#endif // MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_MHLO_TO_SCALAR_OP_H +#endif // MLIR_HLO_MHLO_TRANSFORMS_MAP_MHLO_TO_SCALAR_OP_H diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_stablehlo_to_hlo_op.h b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/map_stablehlo_to_hlo_op.h similarity index 93% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_stablehlo_to_hlo_op.h rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/map_stablehlo_to_hlo_op.h index 433d0e88430..3667563ac07 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_stablehlo_to_hlo_op.h +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/map_stablehlo_to_hlo_op.h @@ -13,28 +13,24 @@ See the License for the specific language governing permissions and limitations under the License. 
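[Editorial sketch, not part of the patch.] On the boolean special case in the integer-widening hunk above: sign-extending the 1-bit value 1 replicates the bit and yields -1, while zero-extending it yields 1, which is why i1 sources are routed to the unsigned extension. A minimal C++ illustration of the two interpretations:

    #include <cassert>
    #include <cstdint>

    int main() {
      unsigned bit = 1;                              // an i1 holding "true"
      int64_t sext = -static_cast<int64_t>(bit & 1); // 1-bit sign extension: 0b1 -> all ones == -1
      int64_t zext = static_cast<int64_t>(bit & 1);  // 1-bit zero extension: 0b1 -> 1
      assert(sext == -1 && zext == 1);
      return 0;
    }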
==============================================================================*/ -#ifndef MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_STABLEHLO_TO_HLO_OP_H -#define MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_STABLEHLO_TO_HLO_OP_H +#ifndef MLIR_HLO_MHLO_TRANSFORMS_MAP_STABLEHLO_TO_HLO_OP_H +#define MLIR_HLO_MHLO_TRANSFORMS_MAP_STABLEHLO_TO_HLO_OP_H #include -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mhlo/IR/hlo_ops.h" #include "stablehlo/dialect/StablehloOps.h" namespace mlir { namespace stablehlo { template -struct HloToStablehloOpImpl { - using Type = std::false_type; -}; +struct HloToStablehloOpImpl; template using HloToStablehloOp = typename HloToStablehloOpImpl::Type; template -struct StablehloToHloOpImpl { - using Type = std::false_type; -}; +struct StablehloToHloOpImpl; template using StablehloToHloOp = typename StablehloToHloOpImpl::Type; @@ -118,6 +114,7 @@ MAP_STABLEHLO_TO_HLO(OptimizationBarrierOp) MAP_STABLEHLO_TO_HLO(OrOp) MAP_STABLEHLO_TO_HLO(OutfeedOp) MAP_STABLEHLO_TO_HLO(PadOp) +MAP_STABLEHLO_TO_HLO(PartitionIdOp) MAP_STABLEHLO_TO_HLO(PopulationCountOp) MAP_STABLEHLO_TO_HLO(PowOp) MAP_STABLEHLO_TO_HLO(RealDynamicSliceOp) @@ -168,4 +165,4 @@ MAP_STABLEHLO_TO_HLO(XorOp) } // namespace stablehlo } // namespace mlir -#endif // MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_STABLEHLO_TO_HLO_OP_H +#endif // MLIR_HLO_MHLO_TRANSFORMS_MAP_STABLEHLO_TO_HLO_OP_H diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/materialize_broadcasts.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/materialize_broadcasts/materialize_broadcasts.cc similarity index 98% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/materialize_broadcasts.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/materialize_broadcasts/materialize_broadcasts.cc index 4730eeeccb2..4b2cb559b5e 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/materialize_broadcasts.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/materialize_broadcasts/materialize_broadcasts.cc @@ -15,7 +15,7 @@ limitations under the License. #include -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/materialize_broadcasts/materialize_broadcasts_pass.cc similarity index 92% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/materialize_broadcasts/materialize_broadcasts_pass.cc index 81ab6936b32..e7eb7209aba 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/materialize_broadcasts/materialize_broadcasts_pass.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/rewriters.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/MLIRContext.h" @@ -27,7 +27,7 @@ namespace mlir { namespace mhlo { #define GEN_PASS_DEF_TESTMATERIALIZEBROADCASTSPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace { diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/merge_assuming_ops.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/merge_assuming_ops/merge_assuming_ops.cc similarity index 96% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/merge_assuming_ops.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/merge_assuming_ops/merge_assuming_ops.cc index 6d5c4a17516..8f7ca771e98 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/merge_assuming_ops.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/merge_assuming_ops/merge_assuming_ops.cc @@ -20,13 +20,13 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/passes.h" +#include "mhlo/transforms/rewriters.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/IRMapping.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" @@ -41,7 +41,7 @@ namespace mlir { namespace mhlo { #define GEN_PASS_DEF_MERGEASSUMINGOPSPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace { @@ -146,7 +146,7 @@ LogicalResult moveUpIntoAssumingOpMatchAndRewrite(Operation *op, assumingOp.getLoc(), assumingOp.getWitness(), [&](OpBuilder &b, Location) { // Copy body. - BlockAndValueMapping mapping; + IRMapping mapping; for (auto &nested : body->without_terminator()) b.clone(nested, mapping); @@ -202,7 +202,7 @@ struct MoveElementwiseOpsUpIntoAssumingOpPattern : public RewritePattern { !op->hasTrait()) { return failure(); } - if (!MemoryEffectOpInterface::hasNoEffect(op)) return failure(); + if (!isMemoryEffectFree(op)) return failure(); return moveUpIntoAssumingOpMatchAndRewrite(op, rewriter); } @@ -250,7 +250,7 @@ struct MoveElementwiseOpsDownIntoAssumingOpPattern : public RewritePattern { !op->hasTrait()) { return failure(); } - if (!MemoryEffectOpInterface::hasNoEffect(op)) return failure(); + if (!isMemoryEffectFree(op)) return failure(); return moveDownIntoAssumingOpMatchAndRewrite(op, rewriter); } @@ -302,7 +302,7 @@ struct MoveUpOutOfAssumingOpPattern : public OpRewritePattern { assumingOp.getLoc(), assumingOp.getWitness(), [&](OpBuilder &b, Location) { // Copy body. 
- BlockAndValueMapping mapping; + IRMapping mapping; for (Operation &nested : body->without_terminator()) { b.clone(nested, mapping); } @@ -358,7 +358,7 @@ struct MergeAssumingOpsPattern : public OpRewritePattern { auto newAssumingOp = rewriter.create( precedingOp.getLoc(), newWitness, [&](OpBuilder &b, Location) { // Copy preceding op's body. - BlockAndValueMapping mapping; + IRMapping mapping; for (auto &nested : body_a->without_terminator()) { b.clone(nested, mapping); } @@ -432,7 +432,7 @@ struct MergeAssumingOpsPass RewritePatternSet patterns(ctx); mhlo::populateMergeAssumingOpsPatterns(ctx, &patterns); GreedyRewriteConfig config; - config.maxIterations = GreedyRewriteConfig::kNoIterationLimit; + config.maxIterations = GreedyRewriteConfig::kNoLimit; if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), config))) { return signalPassFailure(); diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/mhlo_canonicalize_gather.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/mhlo_canonicalize_gather/mhlo_canonicalize_gather.cc similarity index 92% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/mhlo_canonicalize_gather.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/mhlo_canonicalize_gather/mhlo_canonicalize_gather.cc index 75f7e72fbb2..62cf99f5e16 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/mhlo_canonicalize_gather.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/mhlo_canonicalize_gather/mhlo_canonicalize_gather.cc @@ -21,9 +21,9 @@ limitations under the License. #include #include "llvm/ADT/STLExtras.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_scatter_gather_utils.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/passes.h" +#include "mhlo/utils/mhlo_scatter_gather_utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -37,7 +37,7 @@ namespace mhlo { namespace { #define GEN_PASS_DEF_HLOCANONICALIZEGATHERPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" // Given an input tensor, collapse dimensions 1+collapsedSliceDims[...]. Value collapseSliceDims(ImplicitLocOpBuilder& b, TypedValue input, @@ -65,7 +65,8 @@ Value expandBatchDimension(ImplicitLocOpBuilder& b, originalGatherOp.getStartIndices().getType().getShape()}; // Erase the index vector dimension if it wasn't implicit. int64_t indexDim = originalGatherOp.getDimensionNumbers().getIndexVectorDim(); - if (indexDim < newShape.size()) newShape.erase(newShape.begin() + indexDim); + if (indexDim < static_cast(newShape.size())) + newShape.erase(newShape.begin() + indexDim); // `input` has one batch dimension, if we still have one now, there is nothing // to do. 
@@ -79,7 +80,7 @@ Value expandBatchDimension(ImplicitLocOpBuilder& b, RankedTensorType::get(newShape, input.getType().getElementType()); auto reassociation = *getReassociationIndicesForReshape(input.getType(), newType); - if (newShape.size() > input.getType().getRank()) { + if (static_cast(newShape.size()) > input.getType().getRank()) { return b.create(newType, input, reassociation); } return b.create(newType, input, reassociation); @@ -167,13 +168,16 @@ struct CanonicalizeGatherPattern : public OpRewritePattern { result, b.getI64TensorAttr(operandPermutationInverse)); // Collapse the requested dimensions. - result = collapseSliceDims(b, result, dims.getCollapsedSliceDims()); + result = cast>( + collapseSliceDims(b, result, dims.getCollapsedSliceDims())); // Expand the start index dimensions. - result = expandBatchDimension(b, result, gatherOp); + result = + cast>(expandBatchDimension(b, result, gatherOp)); // Move the offset dims to the final locations. - result = moveOffsetDimensions(b, result, gatherOp); + result = + cast>(moveOffsetDimensions(b, result, gatherOp)); rewriter.replaceOp(gatherOp.getOperation(), {result}); return success(); diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/mhlo_canonicalize_reduction.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/mhlo_canonicalize_reduction/mhlo_canonicalize_reduction.cc similarity index 97% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/mhlo_canonicalize_reduction.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/mhlo_canonicalize_reduction/mhlo_canonicalize_reduction.cc index 4b3a97e1630..65cedd54595 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/mhlo_canonicalize_reduction.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/mhlo_canonicalize_reduction/mhlo_canonicalize_reduction.cc @@ -16,7 +16,7 @@ limitations under the License. // This file canonicalize reduction ops in hlo dialect to match the // capacity of codegen backend. 
-#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -26,7 +26,7 @@ namespace mlir { namespace mhlo { #define GEN_PASS_DEF_HLOCANONICALIZEREDUCTIONPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace { @@ -212,10 +212,9 @@ struct HloCanonicalizeReductionPass for (Value operand : op.getInputs()) { newOperands.push_back(b.create( loc, - RankedTensorType::get( - SmallVector(newOperandDims.size(), - ShapedType::kDynamicSize), - elemTy), + RankedTensorType::get(SmallVector(newOperandDims.size(), + ShapedType::kDynamic), + elemTy), operand, newOperandShape)); } auto newOp = diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/mhlo_canonicalize_scatter.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/mhlo_canonicalize_scatter/mhlo_canonicalize_scatter.cc similarity index 95% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/mhlo_canonicalize_scatter.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/mhlo_canonicalize_scatter/mhlo_canonicalize_scatter.cc index a056bfc4d60..54892f63ed6 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/mhlo_canonicalize_scatter.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/mhlo_canonicalize_scatter/mhlo_canonicalize_scatter.cc @@ -17,12 +17,13 @@ limitations under the License. #include #include +#include #include #include "llvm/ADT/STLExtras.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_scatter_gather_utils.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/passes.h" +#include "mhlo/utils/mhlo_scatter_gather_utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Pass/Pass.h" @@ -33,7 +34,7 @@ namespace mhlo { namespace { #define GEN_PASS_DEF_HLOCANONICALIZESCATTERPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" DenseIntElementsAttr getI64ElementsAttr(ArrayRef values, Builder* builder) { @@ -119,7 +120,9 @@ SmallVector reshapeUpdatesToEnsureSingleScatterDimension( } if (numScatterDims == 0) { return to_vector(llvm::map_range(updates, [&](Value update) -> Value { - return insertDegenerateDimensions(b, loc, update, {0}); + return insertDegenerateDimensions( + b, loc, cast>(update), + {0}); })); } return updates; @@ -137,8 +140,9 @@ SmallVector reshapeUpdatesToMatchOperandShape( shiftedScatterDimsToOperandDims.push_back(i + 1); return to_vector(map_range(updates, [&](Value update) -> Value { - return insertDegenerateDimensions(b, loc, update, - shiftedScatterDimsToOperandDims); + return insertDegenerateDimensions( + b, loc, cast>(update), + shiftedScatterDimsToOperandDims); })); } @@ -205,7 +209,7 @@ struct CanonicalizeScatterPattern : public OpRewritePattern { rewriter.getContext(), /*updateWindowDims=*/ llvm::to_vector<4>(llvm::seq(1, operandRank + 1)), - /*insertedWindowDims=*/llvm::None, + /*insertedWindowDims=*/std::nullopt, /*scatterDimsToOperandDims=*/ llvm::to_vector<4>(llvm::seq(0, scatterIndicesVectorSize)), /*indexVectorDim=*/1); diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/mhlo_flatten_tuple.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/mhlo_flatten_tuple/mhlo_flatten_tuple.cc 
similarity index 96% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/mhlo_flatten_tuple.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/mhlo_flatten_tuple/mhlo_flatten_tuple.cc index 5432f200268..48ad9df9da6 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/mhlo_flatten_tuple.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/mhlo_flatten_tuple/mhlo_flatten_tuple.cc @@ -23,8 +23,8 @@ limitations under the License. #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Location.h" @@ -37,7 +37,7 @@ namespace mlir { namespace mhlo { #define GEN_PASS_DEF_FLATTENTUPLEPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace { diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td similarity index 94% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td index 3b9e66ca04a..8ac42b37fba 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td @@ -232,6 +232,17 @@ def ShapeReificationPass : Pass<"shape-reification", "func::FuncOp"> { let constructor = "createShapeReificationPass()"; } +def SymbolicShapeOptimization : Pass<"symbolic-shape-optimization", "func::FuncOp"> { + let summary = "Analyzes shapes and performs shape-related optimizations"; + let constructor = "createSymbolicShapeOptimizationPass()"; +} + +def ShapeSimplification + : Pass<"shape-simplification", "mlir::func::FuncOp"> { + let summary = "Simplify shape ops"; + let constructor = "createShapeSimplification()"; +} + def ConstraintFusionPass : Pass<"constraint-fusion", "func::FuncOp"> { let summary = "Fuse shape constraints and merge all assuming regions."; let constructor = "createConstraintFusionPass()"; @@ -302,6 +313,12 @@ def RankSpecializationToSCFPass ]; } +def MhloExpandOpsSimplifierPass + : Pass<"mhlo-expand-ops-simplifier", "func::FuncOp"> { + let summary = "Expand feature rich mhlo ops into a set of simpler mhlo ops."; + let constructor = "createMhloExpandOpsSimplifierPass()"; +} + def CollapseElementwiseMapPass : Pass<"mhlo-collapse-elementwise-map", "func::FuncOp"> { let summary = "Collapse the mhlo.map if the map only has elementwise ops."; @@ -312,6 +329,12 @@ def HloLegalizeToStablehloPass : Pass<"hlo-legalize-to-stablehlo", "ModuleOp"> { let summary = "Legalize HLO to StableHLO."; let constructor = "createHloLegalizeToStablehloPass()"; let dependentDialects = ["stablehlo::StablehloDialect"]; + let options = [ + Option<"allow_experimental_features_", "allow-experimental-features", + "bool", /*default=*/"false", + "Allow legalization of experimental MHLO features via StableHLO " + "custom_call"> + ]; } def StablehloLegalizeToHloPass : Pass<"stablehlo-legalize-to-hlo", "ModuleOp"> { diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/optimize_mhlo.cc 
b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/optimize_mhlo/optimize_mhlo.cc similarity index 97% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/optimize_mhlo.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/optimize_mhlo/optimize_mhlo.cc index 30505a0c8e7..e8f40574b43 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/optimize_mhlo.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/optimize_mhlo/optimize_mhlo.cc @@ -22,9 +22,8 @@ limitations under the License. #include #include "llvm/ADT/STLExtras.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "mlir-hlo/utils/hlo_utils.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/passes.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" @@ -33,6 +32,7 @@ limitations under the License. #include "mlir/IR/Types.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassRegistry.h" +#include "utils/hlo_utils.h" namespace mlir { namespace mhlo { diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/optimize_mhlo/optimize_mhlo_pass.cc similarity index 89% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/optimize_mhlo/optimize_mhlo_pass.cc index 036901146ce..6576edce296 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/optimize_mhlo/optimize_mhlo_pass.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/passes.h" +#include "mhlo/transforms/rewriters.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" @@ -28,7 +28,7 @@ namespace mlir { namespace mhlo { #define GEN_PASS_DEF_OPTIMIZEMHLOPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace { class OptimizeMhloPass : public impl::OptimizeMhloPassBase { diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/passes.h similarity index 89% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/passes.h index 95545f16384..10a5aef783a 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/passes.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSES_H -#define MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSES_H +#ifndef MLIR_HLO_MHLO_TRANSFORMS_PASSES_H +#define MLIR_HLO_MHLO_TRANSFORMS_PASSES_H #include #include @@ -38,7 +38,7 @@ class FusionOp; namespace mhlo { #define GEN_PASS_DECL -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" /// Lowers HLO control flow ops to SCF. std::unique_ptr> createLegalizeControlFlowPass(); @@ -61,6 +61,10 @@ createLegalizeSparseChloToLinalgPass(); std::unique_ptr> createHloCanonicalizeReductionPass(); +// Expand feature rich mhlo ops to simpler mhlo ops. +std::unique_ptr> +createMhloExpandOpsSimplifierPass(); + // Rewrites scatter into transposes, reshapes and a simpler scatter. std::unique_ptr> createHloCanonicalizeScatterPass(); @@ -83,10 +87,12 @@ std::unique_ptr> createLegalizeHloShapeOpsToStandardPass(); /// Lowers from MHLO dialect to THLO dialect. -std::unique_ptr> createLegalizeMHLOToTHLOPass(); +std::unique_ptr> createLegalizeMHLOToTHLOPass( + bool enableExperimentalOps = false); /// Lowers from HLO dialect to Linalg dialect. -std::unique_ptr> createLegalizeHloToLinalgPass(); +std::unique_ptr> createLegalizeHloToLinalgPass( + bool enablePrimitiveOps = false); /// Lowers from HLO dialects dim operations. std::unique_ptr> @@ -119,6 +125,14 @@ std::unique_ptr> createMergeAssumingOpsPass(); // Iteratively reifies all shape computations in the function. std::unique_ptr> createShapeReificationPass(); +/// Creates a pass to analyze shapes and to use that information for +/// shape-related optimizations. +std::unique_ptr> +createSymbolicShapeOptimizationPass(); + +// Pass to simplify shape ops. +std::unique_ptr> createShapeSimplification(); + // Fuse shape constraints and merge all assuming regions. std::unique_ptr> createConstraintFusionPass(); @@ -171,9 +185,9 @@ std::unique_ptr createTestMaterializeBroadcastsPass(); std::unique_ptr createTestUnfuseBatchNormPass(); #define GEN_PASS_REGISTRATION -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" } // namespace mhlo } // namespace mlir -#endif // MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSES_H +#endif // MLIR_HLO_MHLO_TRANSFORMS_PASSES_H diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/prepare_for_export.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/prepare_for_export/prepare_for_export.cc similarity index 96% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/prepare_for_export.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/prepare_for_export/prepare_for_export.cc index 04676cc4716..192994676bb 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/prepare_for_export.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/prepare_for_export/prepare_for_export.cc @@ -19,8 +19,8 @@ limitations under the License. #include #include "llvm/ADT/STLExtras.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" @@ -38,11 +38,11 @@ limitations under the License. 
namespace mlir { namespace mhlo { -namespace { #define GEN_PASS_DEF_PREPAREFOREXPORTPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" +namespace { // Prepare module for export to XLA HLO. struct PrepareForExportPass : public impl::PrepareForExportPassBase { @@ -163,9 +163,5 @@ void PrepareForExportPass::runOnOperation() { }); } -std::unique_ptr createPrepareForExportPass() { - return std::make_unique(); -} - } // end namespace mhlo } // end namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/rank_specialization.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/rank_specialization/rank_specialization.cc similarity index 95% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/rank_specialization.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/rank_specialization/rank_specialization.cc index 2e8fbdbb7ab..0a4310d880f 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/rank_specialization.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/rank_specialization/rank_specialization.cc @@ -21,9 +21,9 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/passes.h" +#include "mhlo/transforms/rewriters.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -31,7 +31,7 @@ limitations under the License. #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Block.h" -#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/IRMapping.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" @@ -54,7 +54,7 @@ namespace mhlo { #define GEN_PASS_DEF_RANKSPECIALIZATIONCLUSTERPASS #define GEN_PASS_DEF_RANKSPECIALIZATIONTOSCFPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace { @@ -142,7 +142,7 @@ struct RankSpecializationClusterPattern : public RewritePattern { SmallVector(operandTypes.size(), loc)); // Copy operations into the body. - BlockAndValueMapping bvm; + IRMapping bvm; for (auto it : llvm::zip(operands, block->getArguments())) bvm.map(std::get<0>(it), std::get<1>(it)); rewriter.setInsertionPointToStart(block); @@ -191,10 +191,10 @@ struct MergeRankSpecializationClusterOpsPattern // Merge cluster operands. Consider only those operands of the second // cluster that do not originate in the preceding cluster. SmallVector newOperands; - for (Value v : precedingOp.operands()) newOperands.push_back(v); - for (Value v : op.operands()) { + for (Value v : precedingOp.getOperands()) newOperands.push_back(v); + for (Value v : op.getOperands()) { if (v.getDefiningOp() != precedingOp && - !llvm::is_contained(precedingOp.operands(), v)) { + !llvm::is_contained(precedingOp.getOperands(), v)) { newOperands.push_back(v); } } @@ -229,7 +229,7 @@ struct MergeRankSpecializationClusterOpsPattern // Map operands and copy operations of the preceding cluster into the new // body. 
- BlockAndValueMapping bvm; + IRMapping bvm; for (const auto &it : llvm::enumerate(precedingBody->getArguments())) bvm.map(it.value(), newBody->getArgument(it.index())); for (Operation &nestedOp : precedingBody->without_terminator()) @@ -238,7 +238,7 @@ struct MergeRankSpecializationClusterOpsPattern // Map operands and copy operations of the second cluster. If they result // from the preceeding cluster, we can simply map the corresponding value // internally. - for (auto it : llvm::zip(body->getArguments(), op.operands())) { + for (auto it : llvm::zip(body->getArguments(), op.getOperands())) { Value blockArg, operand; std::tie(blockArg, operand) = it; if (operand.getDefiningOp() == precedingOp) { @@ -247,7 +247,7 @@ struct MergeRankSpecializationClusterOpsPattern bvm.map(blockArg, bvm.lookup(precedingYieldOp.getOperand(where.getIndex()))); } else { - auto where = llvm::find(newOp.operands(), operand); + auto where = llvm::find(newOp.getOperands(), operand); bvm.map(blockArg, newBody->getArgument(where.getIndex())); } } @@ -314,7 +314,7 @@ bool isScalarShapeType(Type ty) { Type deriveRankedTensorTypes(Type ty, int64_t rank) { auto tensorTy = ty.dyn_cast(); if (!tensorTy) return ty; - SmallVector shape(rank, ShapedType::kDynamicSize); + SmallVector shape(rank, ShapedType::kDynamic); return RankedTensorType::get(shape, tensorTy.getElementType()); } @@ -325,7 +325,7 @@ Type deriveUnrankedTensorTypes(Type ty) { } SmallVector materializeRankedOperations( - OpBuilder &b, Location loc, BlockAndValueMapping &bvm, + OpBuilder &b, Location loc, IRMapping &bvm, chlo::RankSpecializationClusterOp op) { // Create ranked operations. for (Operation &nestedOp : op.SingleBlock::getBody()->without_terminator()) { @@ -413,7 +413,7 @@ SmallVector materializeFinalReshape( // Replace all remaining uses of the original cluster's block args. for (auto it : - llvm::zip(op.operands(), op.SingleBlock::getBody()->getArguments())) { + llvm::zip(op.getOperands(), op.SingleBlock::getBody()->getArguments())) { Value operand, barg; std::tie(operand, barg) = it; barg.replaceUsesWithIf(operand, [&](OpOperand &operand) { @@ -443,7 +443,7 @@ Value materializeScalarRankSpecializationCase( // non-scalars. Value one = b.create(loc, 1); Value allOthersAreScalar; - for (auto it : llvm::zip(op.operands(), shapes)) { + for (auto it : llvm::zip(op.getOperands(), shapes)) { Value operand, shape; std::tie(operand, shape) = it; if (llvm::is_contained(nonScalarsOfSameShape, operand) || @@ -461,11 +461,11 @@ Value materializeScalarRankSpecializationCase( } auto ifOp = b.create( - loc, op->getResultTypes(), allOthersAreScalar, + loc, allOthersAreScalar, [&](OpBuilder &b, Location loc) { // Compute flat non-scalar shape. SmallVector nonScalarShapes; - for (auto it : llvm::zip(op.operands(), shapes)) { + for (auto it : llvm::zip(op.getOperands(), shapes)) { Value operand, shape; std::tie(operand, shape) = it; if (llvm::is_contained(nonScalarsOfSameShape, operand)) @@ -475,7 +475,7 @@ Value materializeScalarRankSpecializationCase( // Derive ranked operands. auto rankedOperands = llvm::to_vector<8>( - llvm::map_range(op.operands(), [&](Value v) -> Value { + llvm::map_range(op.getOperands(), [&](Value v) -> Value { if (isScalarTensorType(v.getType())) return v; if (!llvm::is_contained(nonScalarsOfSameShape, v)) { return b @@ -492,7 +492,7 @@ Value materializeScalarRankSpecializationCase( })); // Materialize ranked variants for the element-wise operations. 
- BlockAndValueMapping bvm; + IRMapping bvm; for (auto it : llvm::zip(op.SingleBlock::getBody()->getArguments(), rankedOperands)) bvm.map(std::get<0>(it), std::get<1>(it)); @@ -532,12 +532,12 @@ Value materializeEqualShapesRankSpecializationCase( } auto ifOp = b.create( - loc, op->getResultTypes(), allShapesEqOrScalar, + loc, allShapesEqOrScalar, [&](OpBuilder &b, Location loc) { // Flatten non-scalar operands. Value flatShape = materializeFlatShape(b, loc, nonScalarShapes); auto flatOperands = llvm::to_vector<8>( - llvm::map_range(op.operands(), [&](Value v) -> Value { + llvm::map_range(op.getOperands(), [&](Value v) -> Value { if (isScalarTensorType(v.getType())) return v; return b.create( loc, deriveRankedTensorTypes(v.getType(), /*rank=*/1), v, @@ -545,7 +545,7 @@ Value materializeEqualShapesRankSpecializationCase( })); // Materialize ranked variants for the element-wise operations. - BlockAndValueMapping bvm; + IRMapping bvm; for (auto it : llvm::zip(op.SingleBlock::getBody()->getArguments(), flatOperands)) bvm.map(std::get<0>(it), std::get<1>(it)); @@ -575,7 +575,7 @@ Value materializeTargetRankSpecializationCase( mlir::DenseIntElementsAttr::get(extentTensorTy, SmallVector(targetRank, 1))); SmallVector rankedOperands; - for (auto it : llvm::zip(op.operands(), shapes)) { + for (auto it : llvm::zip(op.getOperands(), shapes)) { Value operand, shape; std::tie(operand, shape) = it; if (operand.getType().isa()) { @@ -594,7 +594,7 @@ Value materializeTargetRankSpecializationCase( } // Materialize ranked versions of the element-wise operations. - BlockAndValueMapping bvm; + IRMapping bvm; for (auto it : llvm::zip(op.getBody().front().getArguments(), rankedOperands)) bvm.map(std::get<0>(it), std::get<1>(it)); @@ -707,9 +707,9 @@ materializeRankSpecializationForSingleNonScalarShapeEquivalenceClass( Value flatShape = materializeFlatShape(rewriter, loc, nonScalarShapes); // Materialize ranked variants for the element-wise operations. - BlockAndValueMapping bvm; + IRMapping bvm; for (auto it : - llvm::zip(op.SingleBlock::getBody()->getArguments(), op.operands())) { + llvm::zip(op.SingleBlock::getBody()->getArguments(), op.getOperands())) { Value operand; Value bbArg; std::tie(bbArg, operand) = it; @@ -740,9 +740,10 @@ Value materializeRankSpecializationForTwoNonScalarShapeEquivalenceClasses( SmallVector, 4> nonScalarEqs, int64_t maxTargetRank) { assert(nonScalarEqs.size() == 2 && "Expect two non-scalar equivalence classes."); - auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) { - return rewriter.create(loc, v).getResult(); - })); + auto shapes = + llvm::to_vector<8>(llvm::map_range(op.getOperands(), [&](Value v) { + return rewriter.create(loc, v).getResult(); + })); ValueRange lhsNonScalarEqs = nonScalarEqs[0]; ValueRange rhsNonScalarEqs = nonScalarEqs[1]; @@ -769,9 +770,10 @@ Value materializeDefaultRankSpecialization(PatternRewriter &rewriter, Location loc, chlo::RankSpecializationClusterOp op, int64_t maxTargetRank) { - auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) { - return rewriter.create(loc, v).getResult(); - })); + auto shapes = + llvm::to_vector<8>(llvm::map_range(op.getOperands(), [&](Value v) { + return rewriter.create(loc, v).getResult(); + })); // Materialize all the different cases. Value unshapedResult = materializeDefaultRankSpecializationCases( @@ -788,7 +790,7 @@ SmallVector, 4> findNonScalarShapeEquivalences( // Bridge the equivalences between operands and block arguments. 
for (auto it : - llvm::zip(op.operands(), op.SingleBlock::getBody()->getArguments())) + llvm::zip(op.getOperands(), op.SingleBlock::getBody()->getArguments())) eqs.unionSets(std::get<0>(it), std::get<1>(it)); // Find equalities through `SameOperandsAndResultShape` trait. @@ -848,7 +850,7 @@ SmallVector, 4> findNonScalarShapeEquivalences( // Convert to a list-like equivalence class representation. SmallVector, 4> nonScalarEqs; - for (Value v : op.operands()) { + for (Value v : op.getOperands()) { if (isScalarTensorType(v.getType())) continue; bool inserted = false; for (auto &eqClass : nonScalarEqs) { diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/restrict_max_rank.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/restrict_max_rank/restrict_max_rank.cc similarity index 98% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/restrict_max_rank.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/restrict_max_rank/restrict_max_rank.cc index e497c1dded2..a4db09b8eb9 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/restrict_max_rank.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/restrict_max_rank/restrict_max_rank.cc @@ -18,8 +18,8 @@ limitations under the License. #include #include "llvm/Support/Casting.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinTypes.h" @@ -35,7 +35,7 @@ namespace mlir { namespace mhlo { #define GEN_PASS_DEF_RESTRICTMAXRANKPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace { diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/rewriters.h similarity index 96% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/rewriters.h index 550195b3597..dcca8e78cd8 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/rewriters.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REWRITERS_H -#define MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REWRITERS_H +#ifndef MLIR_HLO_MHLO_TRANSFORMS_REWRITERS_H +#define MLIR_HLO_MHLO_TRANSFORMS_REWRITERS_H #include #include @@ -69,7 +69,8 @@ void populateHloToArithmeticConversionPatterns(RewritePatternSet *patterns); // arguments to arithmetic dialect. void populateScalarHloToArithmeticConversionPatterns( MLIRContext *context, TypeConverter &typeConverter, - RewritePatternSet *patterns); + RewritePatternSet *patterns, + llvm::function_ref filterFn = nullptr); // Collection of rewrite patterns for lowering of shape operations from the HLO // dialect to the standard dialect. @@ -187,7 +188,8 @@ namespace stablehlo { // illegal types also get converted. void populateHloToStablehloPatterns(RewritePatternSet *patterns, TypeConverter *converter, - MLIRContext *context); + MLIRContext *context, + bool allowExperimentalFeatures); // Populates StableHLO ops to MHLO ops rewriting patterns. 
// Also see `stablehlo::registerFuncOpsForTypeConversion` for helper patterns @@ -201,4 +203,4 @@ void populateStablehloToHloPatterns(RewritePatternSet *patterns, } // namespace mlir -#endif // MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REWRITERS_H +#endif // MLIR_HLO_MHLO_TRANSFORMS_REWRITERS_H diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/shape_reification_pass.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/shape_reification/shape_reification_pass.cc similarity index 97% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/shape_reification_pass.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/shape_reification/shape_reification_pass.cc index b0381eb5fbb..08f7551d12e 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/shape_reification_pass.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/shape_reification/shape_reification_pass.cc @@ -19,8 +19,8 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mhlo/transforms/passes.h" +#include "mhlo/transforms/rewriters.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -32,7 +32,7 @@ namespace mlir { namespace mhlo { #define GEN_PASS_DEF_SHAPEREIFICATIONPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace { diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Transforms/shape_simplification.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/shape_simplification/shape_simplification.cc similarity index 94% rename from tensorflow/compiler/xla/mlir_hlo/lib/Transforms/shape_simplification.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/shape_simplification/shape_simplification.cc index 5491cc00e35..06ddde5e6a7 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Transforms/shape_simplification.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/shape_simplification/shape_simplification.cc @@ -20,8 +20,8 @@ limitations under the License. #include #include "llvm/ADT/Optional.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Transforms/passes.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Shape/IR/Shape.h" @@ -31,9 +31,10 @@ limitations under the License. #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { +namespace mhlo { #define GEN_PASS_DEF_SHAPESIMPLIFICATION -#include "mlir-hlo/Transforms/passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace { @@ -63,16 +64,16 @@ struct BroadcastRemoveSubsumedOperandsPattern if (extents.size() > knownExtents.size()) { knownExtents.insert(knownExtents.begin(), extents.size() - knownExtents.size(), - ShapedType::kDynamicSize); + ShapedType::kDynamic); } for (size_t i = 0, e = extents.size(); i != e; ++i) { int64_t extent = extents[e - i - 1]; - if (extent != ShapedType::kDynamicSize && extent != 1) { + if (extent != ShapedType::kDynamic && extent != 1) { int64_t &knownExtent = knownExtents[knownExtents.size() - i - 1]; // A dynamic dimension is subsumed by a static one, but bail out for // known conflicting shapes. 
- if (knownExtent != extent && knownExtent != ShapedType::kDynamicSize) + if (knownExtent != extent && knownExtent != ShapedType::kDynamic) return failure(); knownExtent = extent; } @@ -80,7 +81,7 @@ struct BroadcastRemoveSubsumedOperandsPattern } // If we've figured out all shapes to be constants we're done. - if (!llvm::is_contained(knownExtents, ShapedType::kDynamicSize)) { + if (!llvm::is_contained(knownExtents, ShapedType::kDynamic)) { rewriter.replaceOpWithNewOp( op, op->getResultTypes(), rewriter.getIndexTensorAttr(knownExtents)); return success(); @@ -111,14 +112,14 @@ struct BroadcastRemoveSubsumedOperandsPattern // - a dynamic dim but the result is known to be constant. int64_t knownExtent = knownExtents[knownExtents.size() - i - 1]; assert(knownExtent != 1); - if (knownExtent != ShapedType::kDynamicSize && - extent == ShapedType::kDynamicSize) + if (knownExtent != ShapedType::kDynamic && + extent == ShapedType::kDynamic) continue; // - a constant non-1 dimension equal to the "known" dim. // In this case we also have to check whether this operand is the only // contributor of that constant. - if (knownExtent != ShapedType::kDynamicSize && extent == knownExtent && + if (knownExtent != ShapedType::kDynamic && extent == knownExtent && llvm::count_if(operandExtents, [&](ArrayRef operandShape) { return i < operandShape.size() && operandShape[operandShape.size() - i - 1] == knownExtent; @@ -251,4 +252,5 @@ std::unique_ptr> createShapeSimplification() { return std::make_unique(); } +} // namespace mhlo } // namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/sink_constants_to_control_flow/sink_constants_to_control_flow.cc similarity index 96% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/sink_constants_to_control_flow/sink_constants_to_control_flow.cc index c7bb4ae05e1..47a72c5e707 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/sink_constants_to_control_flow/sink_constants_to_control_flow.cc @@ -15,7 +15,7 @@ limitations under the License. 
#include "llvm/ADT/DenseMap.h" #include "llvm/Support/Casting.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Operation.h" #include "mlir/Pass/Pass.h" @@ -27,7 +27,7 @@ namespace mlir { namespace mhlo { #define GEN_PASS_DEF_SINKCONSTANTSTOCONTROLFLOWPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace { diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/sparse_chlo_legalize_to_linalg.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/sparse_chlo_legalize_to_linalg/sparse_chlo_legalize_to_linalg.cc similarity index 96% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/sparse_chlo_legalize_to_linalg.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/sparse_chlo_legalize_to_linalg/sparse_chlo_legalize_to_linalg.cc index fc593e9cedb..9809697212c 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/sparse_chlo_legalize_to_linalg.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/sparse_chlo_legalize_to_linalg/sparse_chlo_legalize_to_linalg.cc @@ -22,9 +22,9 @@ limitations under the License. #include #include "llvm/ADT/STLExtras.h" -#include "mlir-hlo/Dialect/mhlo/transforms/legalize_to_linalg_utils.h" -#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" -#include "mlir-hlo/Dialect/mhlo/transforms/type_conversion.h" +#include "mhlo/transforms/rewriters.h" +#include "mhlo/utils/legalize_to_linalg_utils.h" +#include "mhlo/utils/type_conversion.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -46,7 +46,7 @@ namespace mlir { namespace mhlo { #define GEN_PASS_DEF_CHLOLEGALIZETOLINALGPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace { diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/sparse_rewriting.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/sparse_rewriting/sparse_rewriting.cc similarity index 95% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/sparse_rewriting.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/sparse_rewriting/sparse_rewriting.cc index 4699e2dca86..b749bc81ca9 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/sparse_rewriting.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/sparse_rewriting/sparse_rewriting.cc @@ -18,9 +18,9 @@ limitations under the License. 
#include #include "llvm/Support/Debug.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/passes.h" +#include "mhlo/transforms/rewriters.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/IR/Operation.h" @@ -31,7 +31,7 @@ namespace mlir { namespace mhlo { #define GEN_PASS_DEF_SPARSEREWRITINGPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace { diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/stablehlo_legalize_to_hlo.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc similarity index 80% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/stablehlo_legalize_to_hlo.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc index 171e0244c60..2bd0e3ba0e9 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/stablehlo_legalize_to_hlo.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc @@ -16,9 +16,10 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/map_stablehlo_to_hlo_op.h" -#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/map_stablehlo_to_hlo_op.h" +#include "mhlo/transforms/rewriters.h" +#include "mlir/AsmParser/AsmParser.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Location.h" @@ -85,6 +86,11 @@ Attribute convertAttr(Attribute stablehloAttr) { attr.getContext(), attr.getOffsetDims(), attr.getCollapsedSliceDims(), attr.getStartIndexMap(), attr.getIndexVectorDim()); } + if (auto attr = stablehloAttr.dyn_cast()) { + return mhlo::OutputOperandAliasAttr::get( + attr.getContext(), attr.getOutputTupleIndices(), attr.getOperandIndex(), + attr.getOperandTupleIndices()); + } if (auto attr = stablehloAttr.dyn_cast()) { RETURN_CONVERTED_ENUM_ATTR(Precision); } @@ -151,6 +157,36 @@ class StablehloToHloOpConverter : public OpConversionPattern { // the dialect conversion infrastructure. ValueRange hloOperands = adaptor.getOperands(); + // Extensibility protocol for public MHLO features that are not yet + // supported in StableHLO. See hlo_legalize_to_stablehlo.cc for details. + if constexpr (std::is_same::value) { + if (stablehloOp.getCallTargetName().starts_with("mhlo.")) { + // Only call_target_name and backend_config are compatible with + // the extensibility protocol. + for (NamedAttribute stablehloAttr : stablehloOp->getAttrs()) { + auto stablehloName = stablehloAttr.getName().getValue(); + if (stablehloName != "call_target_name" && + stablehloName != "backend_config") + return failure(); + } + + // Dynamically create the corresponding MHLO op using call_target_name + // and backend_config. (It is quite neat that we have an API for this!). 
+ OperationState hloOpState(stablehloOp.getLoc(), + stablehloOp.getCallTargetName()); + hloOpState.addOperands(hloOperands); + hloOpState.addTypes(hloTypes); + auto hloAttrs = parseAttribute(stablehloOp.getBackendConfig(), + stablehloOp.getContext()) + .template dyn_cast_or_null(); + if (!hloAttrs) return failure(); + hloOpState.addAttributes(hloAttrs.getValue()); + Operation* hloOp = rewriter.create(hloOpState); + rewriter.replaceOp(stablehloOp, hloOp->getResults()); + return success(); + } + } + // Convert StableHLO attributes to MHLO equivalents. // If an attribute is not defined in StableHLO, then it is unchanged, // with the exception of ArrayAttr which is converted recursively. @@ -180,6 +216,10 @@ class StablehloToHloOpConverter : public OpConversionPattern { for (auto [stablehloRegion, hloRegion] : llvm::zip(stablehloOp->getRegions(), hloOp->getRegions())) { rewriter.inlineRegionBefore(stablehloRegion, hloRegion, hloRegion.end()); + if (failed(rewriter.convertRegionTypes(&hloRegion, + *this->getTypeConverter(), + /*entryConversion=*/nullptr))) + return failure(); } return success(); } diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/stablehlo_legalize_to_hlo_pass.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc similarity index 89% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/stablehlo_legalize_to_hlo_pass.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc index 35a41e7cad4..848cf5597e8 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/stablehlo_legalize_to_hlo_pass.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo_pass.cc @@ -18,10 +18,10 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" -#include "mlir-hlo/Dialect/mhlo/transforms/type_conversion.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/passes.h" +#include "mhlo/transforms/rewriters.h" +#include "mhlo/utils/type_conversion.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/IR/PatternMatch.h" @@ -35,7 +35,7 @@ namespace mlir { namespace mhlo { #define GEN_PASS_DEF_STABLEHLOLEGALIZETOHLOPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace { diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Transforms/symbolic_shape_optimization.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/symbolic_shape_optimization/symbolic_shape_optimization.cc similarity index 87% rename from tensorflow/compiler/xla/mlir_hlo/lib/Transforms/symbolic_shape_optimization.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/symbolic_shape_optimization/symbolic_shape_optimization.cc index e9306348d80..7eb478c7b3e 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Transforms/symbolic_shape_optimization.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/symbolic_shape_optimization/symbolic_shape_optimization.cc @@ -18,14 +18,15 @@ limitations under the License. 
#include #include #include +#include #include #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" -#include "mlir-hlo/Analysis/shape_component_analysis.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Transforms/passes.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/analysis/shape_component_analysis.h" +#include "mhlo/transforms/passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -38,9 +39,10 @@ limitations under the License. #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { +namespace mhlo { #define GEN_PASS_DEF_SYMBOLICSHAPEOPTIMIZATION -#include "mlir-hlo/Transforms/passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" using ShapeOrValueInfo = ShapeComponentAnalysis::ShapeOrValueInfo; using Symbol = ShapeComponentAnalysis::Symbol; @@ -78,7 +80,7 @@ struct SimplifyBroadcasts : public mlir::OpRewritePattern { // Compute broadcast symbolically. SmallVector> symResult(rank, - llvm::None); + std::nullopt); for (const auto &sInfo : llvm::enumerate(shapesInfo)) { size_t dimOffset = rank - sInfo.value().size(); for (const auto &symExpr : llvm::enumerate(sInfo.value())) { @@ -183,7 +185,7 @@ struct AnnotateExpandingDimensionsInDynamicBroadcastInDim // Collect possibly already annotated info. auto insertAll = [](llvm::SmallSetVector &dst, - Optional src) { + std::optional src) { if (!src) return; for (auto it : *src) dst.insert(it.getLimitedValue()); }; @@ -338,7 +340,7 @@ struct RemoveRedundantCstrReshapable final if (!isSymbolicProduct( dim, [&](int64_t c) { - if (c != ShapedType::kDynamicSize) concreteProductDynShape *= c; + if (c != -1) concreteProductDynShape *= c; }, [&](Symbol s) { partialSymbolicFactorsDynShape.push_back(s); })) { return failure(); @@ -455,7 +457,7 @@ bool isUnpairedUnitDimension( int64_t getShapedTypyDimSize(const SymbolicProduct &symProduct) { return symProduct.symbolic.empty() ? symProduct.concrete - : ShapedType::kDynamicSize; + : ShapedType::kDynamic; } // Iterate over the operand's and the result's shape dimensions and find @@ -587,7 +589,7 @@ LogicalResult findExpandingAndCollapsingDimensionGroups( int64_t tyDimSize = getShapedTypyDimSize(gcd); // Allow no more than one dynamic dimension per expansion group. - if (tyDimSize == ShapedType::kDynamicSize) { + if (tyDimSize == ShapedType::kDynamic) { numDynamicDims++; if (numDynamicDims > 1) return failure(); } @@ -640,7 +642,7 @@ SmallVector concretizeOperandShape( return result; } -llvm::Optional> requiresReassociationOfKind( +std::optional> requiresReassociationOfKind( DimensionGroupKind kind, const SmallVector &dimGroups) { SmallVector reassociation; reassociation.reserve(dimGroups.size()); @@ -659,7 +661,7 @@ llvm::Optional> requiresReassociationOfKind( // Return the reassociation if expansion is required. if (isStrictlyReassociating) return reassociation; - return llvm::None; + return std::nullopt; } LogicalResult materializeReshapeAsExpandAndCollapse( @@ -786,6 +788,92 @@ struct CstrBroadcastableOpLowering } }; +// Returns a shape tensor if the shapes can be broadcasted to a known shape. +// Will either return one of the shapes or a generated mix of the shapes. +llvm::Optional simplifyBroadcast(ShapeComponentAnalysis &analysis, + ValueRange shapes, Location loc, + OpBuilder *builder) { + // First find the input shape with the largest rank. 
+ SmallVector> shapesFound; + size_t maxRank = 0; + for (const auto &shape : llvm::enumerate(shapes)) { + auto foundShape = analysis.GetValueInfo(shape.value()); + if (!foundShape) return {}; + shapesFound.push_back(*foundShape); + maxRank = std::max(maxRank, foundShape->size()); + } + if (maxRank == 0) { + return Value(builder->create( + loc, shapes[0].getType(), SmallVector())); + } + + SmallVector joinedDimensions( + maxRank); + SmallVector> shapeAndRankForDim(maxRank); + for (const auto &shape : llvm::enumerate(shapesFound)) { + for (const auto &dim : llvm::enumerate(llvm::reverse(shape.value()))) { + // 1 dimensions don't contribute to the final result. + if (dim.value().isConstant(1)) continue; + // If it's not a 1 dimension it will be present in the result. Remember + // where it came from. + auto index = maxRank - dim.index() - 1; + if (!joinedDimensions[index]) { + joinedDimensions[index] = &dim.value(); + shapeAndRankForDim[index] = + std::make_pair(shapes[shape.index()], shape.value().size()); + continue; + } + // Bail if the dimensions are neither equal nor 1. + if (*joinedDimensions[index] != dim.value()) return {}; + } + } + // If the output is the same as one of the inputs just return that. + if (llvm::all_equal(shapeAndRankForDim) && shapeAndRankForDim[0].first) { + return shapeAndRankForDim[0].first; + } + // Otherwise rematerialize the shape from the pieces we have. + SmallVector elements; + for (size_t i = 0; i != maxRank; ++i) { + // 1 dimensions are filtered above, recreate the constant. + if (!shapeAndRankForDim[i].first) { + auto one = builder->getIntegerAttr( + shapes[0].getType().cast().getElementType(), 1); + elements.push_back(builder->create(loc, one)); + continue; + } + // Extract from one of the shapes, accounting for the reverse indexing + // performed by broadcast. + Value index = builder->create( + loc, i - maxRank + shapeAndRankForDim[i].second); + elements.push_back(builder->create( + loc, shapeAndRankForDim[i].first, index)); + } + return Value(builder->create(loc, elements)); +} + +// Replace shape.broadcast with a shape if it's statically known. +struct BroadcastOpLowering final + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite( + shape::BroadcastOp op, mlir::PatternRewriter &rewriter) const override { + ShapeComponentAnalysis shapeComponentAnalysis; + auto newBroadcast = simplifyBroadcast( + shapeComponentAnalysis, op.getShapes(), op.getLoc(), &rewriter); + if (!newBroadcast) return failure(); + + // Insert cast, if needed. + Type expectedTy = op.getType(); + if (newBroadcast->getType() != expectedTy) { + newBroadcast = rewriter.create(op.getLoc(), expectedTy, + *newBroadcast); + } + + rewriter.replaceOp(op, {*newBroadcast}); + return success(); + } +}; + class SymbolicShapeOptimizationPass final : public impl::SymbolicShapeOptimizationBase< SymbolicShapeOptimizationPass> { @@ -800,13 +888,17 @@ class SymbolicShapeOptimizationPass final // clang-format off patterns.insert< AnnotateExpandingDimensionsInDynamicBroadcastInDim, + BroadcastOpLowering, CstrBroadcastableOpLowering, DynamicReshapeToExpandAndCollapseShape, RemoveComputeReshapeShape, RemoveRedundantCstrReshapable, SimplifyBroadcasts>(ctx); // clang-format on + + // Collect some relevant canonicalization patterns. 
shape::AssumingOp::getCanonicalizationPatterns(patterns, ctx); + shape::ShapeOfOp::getCanonicalizationPatterns(patterns, ctx); if (failed(mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { @@ -822,4 +914,5 @@ createSymbolicShapeOptimizationPass() { return std::make_unique(); } +} // end namespace mhlo } // end namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/test_infer_shaped_type/test_infer_shaped_type_pass.cc similarity index 98% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/test_infer_shaped_type/test_infer_shaped_type_pass.cc index 9a0918d14a6..727c70440a8 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/test_infer_shaped_type/test_infer_shaped_type_pass.cc @@ -28,7 +28,7 @@ limitations under the License. namespace mlir { #define GEN_PASS_DEF_TESTINFERSHAPEDTYPEMETHODSPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace mhlo { namespace { diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/unfuse_batch_norm/unfuse_batch_norm.cc similarity index 99% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/unfuse_batch_norm/unfuse_batch_norm.cc index 74a0254f3bd..8e300a15b4a 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/unfuse_batch_norm/unfuse_batch_norm.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "llvm/ADT/SmallVector.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Shape/IR/Shape.h" diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/unfuse_batch_norm/unfuse_batch_norm_pass.cc similarity index 91% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/unfuse_batch_norm/unfuse_batch_norm_pass.cc index 3d2f707b255..f912946c8ef 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/unfuse_batch_norm/unfuse_batch_norm_pass.cc @@ -15,8 +15,8 @@ limitations under the License. 
#include -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/rewriters.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Shape/IR/Shape.h" @@ -32,7 +32,7 @@ namespace mlir { namespace mhlo { #define GEN_PASS_DEF_TESTUNFUSEBATCHNORMPASS -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +#include "mhlo/transforms/mhlo_passes.h.inc" namespace { diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo/IR/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/mhlo/utils/CMakeLists.txt similarity index 57% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo/IR/CMakeLists.txt rename to tensorflow/compiler/xla/mlir_hlo/mhlo/utils/CMakeLists.txt index 9b83249d16b..ad7cc55334b 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/lhlo/IR/CMakeLists.txt +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/utils/CMakeLists.txt @@ -13,30 +13,58 @@ # See the License for the specific language governing permissions and # limitations under the License. # + include_directories(BEFORE ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR}) -add_mlir_library(LmhloStructuredInterface - lhlo_structured_interface.cc +add_mlir_library(MhloScatterUtils + mhlo_scatter_gather_utils.cc + + DEPENDS + MLIRhlo_opsIncGen + + LINK_COMPONENTS + Core LINK_LIBS PUBLIC - MLIRIR + MhloDialect +) - DEPENDS - MLIRlhlo_structured_interfaceIncGen +add_mlir_library(MhloTypeConversion + type_conversion.cc + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MhloDialect + MLIRIR + MLIRFuncDialect + MLIRFuncTransforms + MLIRTensorDialect + StablehloOps ) -add_mlir_dialect_library(LmhloDialect - lhlo_ops.cc +add_mlir_library(HloToLinalgUtils + legalize_to_linalg_utils.cc DEPENDS - MLIRlhlo_opsIncGen + MLIRhlo_opsIncGen + MLIRMhloPassIncGen + + LINK_COMPONENTS + Core LINK_LIBS PUBLIC - HloOpsCommon - LmhloStructuredInterface + LmhloDialect MhloDialect + MhloTypeConversion + MLIRBufferizationDialect + MLIRComplexDialect MLIRIR + MLIRLinalgUtils + MLIRPass + MLIRRewrite + MLIRTransformUtils ) - diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg_utils.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/utils/legalize_to_linalg_utils.cc similarity index 79% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg_utils.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/utils/legalize_to_linalg_utils.cc index ea073213497..b83f4c82bc6 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg_utils.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/utils/legalize_to_linalg_utils.cc @@ -16,7 +16,7 @@ limitations under the License. // This file implements utilities for lowering CHLO/HLO/LHLO dialect to Linalg // dialect. 
-#include "mlir-hlo/Dialect/mhlo/transforms/legalize_to_linalg_utils.h" +#include "mhlo/utils/legalize_to_linalg_utils.h" #include #include @@ -39,21 +39,23 @@ bool hasIntegralShapeType(Operation* op) { } // namespace -SmallVector getParallelAndReductionIterators( +SmallVector getParallelAndReductionIterators( unsigned nLoops, unsigned nReduction) { - SmallVector res(nLoops - nReduction, - getParallelIteratorTypeName()); - res.append(nReduction, getReductionIteratorTypeName()); + SmallVector res(nLoops - nReduction, + utils::IteratorType::parallel); + res.append(nReduction, utils::IteratorType::reduction); return res; } -SmallVector getNParallelLoopsAttrs(unsigned nParallelLoops) { +SmallVector getNParallelLoopsAttrs( + unsigned nParallelLoops) { return getParallelAndReductionIterators(nParallelLoops, 0); } Value getEmptySparseTensor(OpBuilder& b, Location loc, ShapedType type, ArrayRef dynSizes) { - return b.create(loc, type, dynSizes, + return b.create(loc, type.cast(), + dynSizes, /*copy=*/Value(), /*memory_space=*/IntegerAttr()); } @@ -61,7 +63,8 @@ Value getEmptySparseTensor(OpBuilder& b, Location loc, ShapedType type, Value getEmptyTensor(OpBuilder& b, Location loc, ShapedType type, ArrayRef dynSizes) { return b.create(loc, type.getShape(), type.getElementType(), - dynSizes); + dynSizes, + type.cast().getEncoding()); } Value getEmptyTensorFor(OpBuilder& b, Location loc, ShapedType resultType, @@ -79,7 +82,7 @@ Value getEmptyTensorFor(OpBuilder& b, Location loc, ShapedType resultType, assert(reifiedShapes.size() == 1 && "Expected one reified result"); // Construct sizes for the required dimensions. for (auto& en : llvm::enumerate(resultType.getShape())) { - if (en.value() != ShapedType::kDynamicSize) continue; + if (en.value() != ShapedType::kDynamic) continue; sizes.push_back(b.create( loc, reifiedShapes[0], ValueRange{b.create(loc, en.index())})); @@ -122,6 +125,19 @@ Value postSparsify(Operation* op, Value semiring, Value result, OpBuilder* b) { return result; } +bool allOperandsAreScalarTensors(Operation* op) { + return llvm::all_of(op->getOperands(), [](Value operand) { + auto operandTy = operand.getType().dyn_cast(); + return operandTy && operandTy.getRank() == 0; + }); +} + +bool isInBodyOfLinalgOps(Operation* op) { + auto* parentOp = op->getParentRegion()->getParentOp(); + return parentOp->getDialect() == + parentOp->getContext()->getLoadedDialect(); +} + } // namespace mhlo } // namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/legalize_to_linalg_utils.h b/tensorflow/compiler/xla/mlir_hlo/mhlo/utils/legalize_to_linalg_utils.h similarity index 81% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/legalize_to_linalg_utils.h rename to tensorflow/compiler/xla/mlir_hlo/mhlo/utils/legalize_to_linalg_utils.h index e0bd885e5ec..a91dbe9ac5e 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/legalize_to_linalg_utils.h +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/utils/legalize_to_linalg_utils.h @@ -26,7 +26,7 @@ limitations under the License. 
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringSet.h" -#include "mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h" +#include "mhlo/transforms/map_mhlo_to_scalar_op.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -50,11 +50,12 @@ namespace mhlo { /// Returns an ArrayAttr that contains `nLoops` attributes. All the attributes /// are "parallel" except the last `nReduction` elements, where are "reduction" /// attributes. -SmallVector getParallelAndReductionIterators(unsigned nLoops, - unsigned nReduction); +SmallVector getParallelAndReductionIterators( + unsigned nLoops, unsigned nReduction); /// Returns an ArrayAttr that contains `nParallelLoops` "parallel" attributes. -SmallVector getNParallelLoopsAttrs(unsigned nParallelLoops); +SmallVector getNParallelLoopsAttrs( + unsigned nParallelLoops); /// Generates an init sparse tensor. Value getEmptySparseTensor(OpBuilder& b, Location loc, ShapedType type, @@ -88,6 +89,12 @@ Value preSparsify(Operation* op, llvm::SmallVector& values, Type rtp, /// Finalizes sparse semi-ring construction. Value postSparsify(Operation* op, Value semiring, Value result, OpBuilder* b); +/// Returns true if all operands are tensors with rank 0. +bool allOperandsAreScalarTensors(Operation* op); + +/// Returns true if parent op is linalg. +bool isInBodyOfLinalgOps(Operation* op); + /// Converts a HLO operation to a linalg.generic op that contains the /// corresponding scalar operations. template @@ -98,6 +105,7 @@ class PointwiseToLinalgConverter : public OpConversionPattern { LogicalResult matchAndRewrite( OpTy op, typename OpTy::Adaptor adaptor, ConversionPatternRewriter& rewriter) const final { + auto loc = op.getLoc(); // Find maximum rank / number of loops. auto getRank = [](Value v) { return v.getType().cast().getRank(); @@ -131,25 +139,8 @@ class PointwiseToLinalgConverter : public OpConversionPattern { op, "mismatched operand/result types or iterator count"); } - auto loc = op.getLoc(); - // Within a linalg op, we can immediately de-tensorsize if the computation - // is scalar. We do not do this on the top-level, as that would break the - // nice invariant that all programs are exclusively on tensors, which is - // currently relied on for fusion in some pipelines. - if (nloops == 0 && isInBodyOfLinalgOps(op)) { - // No need to create a linalg.generic if all inputs are scalars. - SmallVector inputs; - for (auto input : adaptor.getOperands()) { - inputs.push_back( - rewriter.create(loc, input, ValueRange())); - } - Value scalarResult = mhlo::MhloOpToStdScalarOp::mapOp( - op, resultTy->getElementType(), inputs, &rewriter); - if (!scalarResult) return failure(); - rewriter.replaceOpWithNewOp(op, *resultTy, - scalarResult); - return success(); - } + if (allOperandsAreScalarTensors(op) && isInBodyOfLinalgOps(op)) + return failure(); // Find input/output values and types. 
ValueRange inputs = adaptor.getOperands(); @@ -188,13 +179,6 @@ class PointwiseToLinalgConverter : public OpConversionPattern { rewriter.replaceOp(op, linalgOp->getResults()); return success(); } - - private: - static bool isInBodyOfLinalgOps(Operation* op) { - auto* parentOp = op->getParentRegion()->getParentOp(); - return parentOp->getDialect() == - parentOp->getContext()->getLoadedDialect(); - } }; } // namespace mhlo diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/mhlo_scatter_gather_utils.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/utils/mhlo_scatter_gather_utils.cc similarity index 96% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/mhlo_scatter_gather_utils.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/utils/mhlo_scatter_gather_utils.cc index 9a846be903c..1db992ad266 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/mhlo_scatter_gather_utils.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/utils/mhlo_scatter_gather_utils.cc @@ -16,7 +16,7 @@ limitations under the License. // This file implements utilities for the canonicalization of ScatterOp and // GatherOp. -#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_scatter_gather_utils.h" +#include "mhlo/utils/mhlo_scatter_gather_utils.h" #include @@ -85,11 +85,11 @@ makeOperandStartIndexPermutations(ArrayRef dimMap, int operandRank) { TypedValue insertDegenerateDimensions( OpBuilder& b, Location loc, TypedValue tensor, ArrayRef dimsToInsert) { + assert(llvm::is_sorted(dimsToInsert) && "dimsToInsert must be sorted"); if (dimsToInsert.empty()) return tensor; TensorType type = tensor.getType(); SmallVector newShape{type.getShape()}; - for (int64_t dim : llvm::reverse(dimsToInsert)) - newShape.insert(newShape.begin() + dim, 1); + for (int64_t dim : dimsToInsert) newShape.insert(newShape.begin() + dim, 1); auto newType = RankedTensorType::get(newShape, type.getElementType()); return b diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_scatter_gather_utils.h b/tensorflow/compiler/xla/mlir_hlo/mhlo/utils/mhlo_scatter_gather_utils.h similarity index 98% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_scatter_gather_utils.h rename to tensorflow/compiler/xla/mlir_hlo/mhlo/utils/mhlo_scatter_gather_utils.h index bd84a7e17bd..cf7461be5a0 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_scatter_gather_utils.h +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/utils/mhlo_scatter_gather_utils.h @@ -21,7 +21,7 @@ limitations under the License. #include -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mhlo/IR/hlo_ops.h" namespace mlir { namespace mhlo { diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/type_conversion.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/utils/type_conversion.cc similarity index 95% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/type_conversion.cc rename to tensorflow/compiler/xla/mlir_hlo/mhlo/utils/type_conversion.cc index 831d3e374f0..998fd02a25c 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/mhlo/transforms/type_conversion.cc +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/utils/type_conversion.cc @@ -13,9 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "mlir-hlo/Dialect/mhlo/transforms/type_conversion.h" +#include "mhlo/utils/type_conversion.h" -#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include + +#include "mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/FuncConversions.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -53,7 +55,7 @@ llvm::Optional materializeCastFromIllegal(OpBuilder& builder, Type type, Type toType = getElementTypeOrSelf(type); if ((!fromType.isSignedInteger() && !fromType.isUnsignedInteger()) || !toType.isSignlessInteger()) - return llvm::None; + return std::nullopt; // Use unrealized conversion casts to do signful->signless conversions. return builder.create(loc, type, inputs[0]) ->getResult(0); @@ -66,7 +68,7 @@ llvm::Optional materializeCastToIllegal(OpBuilder& builder, Type type, Type toType = getElementTypeOrSelf(type); if (!fromType.isSignlessInteger() || (!toType.isSignedInteger() && !toType.isUnsignedInteger())) - return llvm::None; + return std::nullopt; // Use unrealized conversion casts to do signless->signful conversions. return builder.create(loc, type, inputs[0]) ->getResult(0); @@ -76,7 +78,7 @@ llvm::Optional scalarToTensor(OpBuilder& builder, Type /*type*/, ValueRange inputs, Location loc) { assert(inputs.size() == 1); if (inputs.front().getType().isa()) { - return llvm::None; + return std::nullopt; } return builder .create( @@ -148,6 +150,9 @@ HloToStablehloTypeConverter::HloToStablehloTypeConverter() addConversion([](mhlo::TokenType type) -> Type { return stablehlo::TokenType::get(type.getContext()); }); + // Consider implementing stablehlo::CustomType to provide an escape hatch + // for modelling MHLO types that aren't yet in StableHLO. + // Proposal: https://github.com/openxla/stablehlo/issues/743. } bool HloToStablehloTypeConverter::isSourceDialect(Dialect& dialect) { diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/type_conversion.h b/tensorflow/compiler/xla/mlir_hlo/mhlo/utils/type_conversion.h similarity index 95% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/type_conversion.h rename to tensorflow/compiler/xla/mlir_hlo/mhlo/utils/type_conversion.h index 987347de43d..ae3f0b963e8 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/type_conversion.h +++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/utils/type_conversion.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef MLIR_HLO_DIALECT_MHLO_TRANSFORMS_TYPE_CONVERSION_H -#define MLIR_HLO_DIALECT_MHLO_TRANSFORMS_TYPE_CONVERSION_H +#ifndef MLIR_HLO_MHLO_UTILS_TYPE_CONVERSION_H +#define MLIR_HLO_MHLO_UTILS_TYPE_CONVERSION_H #include "mlir/IR/Attributes.h" #include "mlir/IR/Dialect.h" @@ -96,4 +96,4 @@ void registerFuncOpsForTypeConversion(ConversionTarget& target, } // namespace stablehlo } // namespace mlir -#endif // MLIR_HLO_DIALECT_MHLO_TRANSFORMS_TYPE_CONVERSION_H +#endif // MLIR_HLO_MHLO_UTILS_TYPE_CONVERSION_H diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/BUILD b/tensorflow/compiler/xla/mlir_hlo/tests/BUILD index 7ebd172013a..766fdff76cb 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/BUILD +++ b/tensorflow/compiler/xla/mlir_hlo/tests/BUILD @@ -2,7 +2,10 @@ load("@llvm-project//llvm:lit_test.bzl", "lit_test", "package_path") load("@bazel_skylib//rules:build_test.bzl", "build_test") load("@bazel_skylib//rules:expand_template.bzl", "expand_template") -package(licenses = ["notice"]) +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) [ lit_test( diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_hlo_broadcasts.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_hlo_broadcasts.mlir index 5e0d3ba8497..ce5a493fd6e 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_hlo_broadcasts.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_hlo_broadcasts.mlir @@ -274,26 +274,26 @@ func.func @remainderWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) // ----- // CHECK-LABEL: @shift_leftWithoutBroadcast -func.func @shift_leftWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { +func.func @shift_leftWithoutBroadcast(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { // CHECK: mhlo.shift_left %arg0, %arg1 - %0 = chlo.broadcast_shift_left %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - func.return %0 : tensor<4xf32> + %0 = chlo.broadcast_shift_left %arg0, %arg1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + func.return %0 : tensor<4xi32> } // ----- // CHECK-LABEL: @shift_right_arithmeticWithoutBroadcast -func.func @shift_right_arithmeticWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { +func.func @shift_right_arithmeticWithoutBroadcast(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { // CHECK: mhlo.shift_right_arithmetic %arg0, %arg1 - %0 = chlo.broadcast_shift_right_arithmetic %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - func.return %0 : tensor<4xf32> + %0 = chlo.broadcast_shift_right_arithmetic %arg0, %arg1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + func.return %0 : tensor<4xi32> } // ----- // CHECK-LABEL: @shift_right_logicalWithoutBroadcast -func.func @shift_right_logicalWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { +func.func @shift_right_logicalWithoutBroadcast(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { // CHECK: mhlo.shift_right_logical %arg0, %arg1 - %0 = chlo.broadcast_shift_right_logical %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - func.return %0 : tensor<4xf32> + %0 = chlo.broadcast_shift_right_logical %arg0, %arg1 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + func.return %0 : tensor<4xi32> } // ----- diff --git 
a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/add_debug_info.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/add_debug_info.mlir new file mode 100644 index 00000000000..cc3690fea32 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/add_debug_info.mlir @@ -0,0 +1,22 @@ +// RUN: mlir-hlo-opt %s --add-debug-info --mlir-print-debuginfo | FileCheck %s + +builtin.module { + func.func @foo() { + return + } +} + +// CHECK: module +// CHECK: func.func @[[SUBPROGRAM_NAME:.*]]() { +// CHECK: return loc(#[[RET_LOC:.*]]) +// CHECK: } loc(#[[FUSED_SUBPROGRAM_LOC:.*]]) +// CHECK: } loc(#[[MODULE_LOC:.*]]) +// CHECK: #di_basic_type = #llvm.di_basic_type +// CHECK: #di_file = #llvm.di_file<"[[FILE_NAME:.*]]" in "[[DIR_NAME:.*]]"> +// CHECK: #[[MODULE_LOC]] = loc("[[DIR_NAME]]/[[FILE_NAME]]":[[#MODULE_LINE:]]:1) +// CHECK: #[[SUBPROGRAM_LOC:.*]] = loc("[[DIR_NAME]]/[[FILE_NAME]]":[[#MODULE_LINE+1]]:3) +// CHECK: #[[RET_LOC]] = loc("[[DIR_NAME]]/[[FILE_NAME]]":[[#MODULE_LINE+2]]:5) +// CHECK: #di_compile_unit = #llvm.di_compile_unit +// CHECK: #di_subroutine_type = #llvm.di_subroutine_type +// CHECK: #di_subprogram = #llvm.di_subprogram +// CHECK: #[[FUSED_SUBPROGRAM_LOC]] = loc(fused<#di_subprogram>[#[[SUBPROGRAM_LOC]]]) \ No newline at end of file diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/bufferization.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/bufferization.mlir index 0dea74401c6..4813f3ffdeb 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/bufferization.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/bufferization.mlir @@ -1,5 +1,6 @@ -// RUN: mlir-hlo-opt %s -test-gml-st-bufferization -canonicalize -cse \ -// RUN: -split-input-file | FileCheck %s +// RUN: mlir-hlo-opt %s -empty-tensor-to-alloc-tensor \ +// RUN: -hlo-one-shot-bufferize -canonicalize -cse -canonicalize \ +// RUN: -split-input-file | FileCheck %s func.func @set_tile(%input: tensor) -> tensor<2x4xf32> { %c0 = arith.constant 0 : index @@ -8,10 +9,8 @@ func.func @set_tile(%input: tensor) -> tensor<2x4xf32> { %dim_0 = tensor.dim %input, %c0 : tensor %dim_1 = tensor.dim %input, %c1 : tensor - %tile = gml_st.tile [0, 1][2, 4][1, 1] : !gml_st.tile<2x4> - - %slice = gml_st.materialize %input[%tile] - : tensor[!gml_st.tile<2x4>] to tensor<2x4xf32> + %slice = tensor.extract_slice %input[0, 1][2, 4][1, 1] + : tensor to tensor<2x4xf32> return %slice : tensor<2x4xf32> } @@ -26,26 +25,26 @@ func.func @set_tile(%input: tensor) -> tensor<2x4xf32> { #map = affine_map<(d0, d1) -> (d0, d1)> func.func @parallel_with_tiles(%lhs: tensor, %rhs: tensor, - %init : tensor) -> tensor { + %out : tensor) -> tensor { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c4 = arith.constant 4 : index %dim_0 = tensor.dim %lhs, %c0 : tensor %dim_1 = tensor.dim %lhs, %c1 : tensor - %result = gml_st.parallel (%i, %j) = (%c0, %c0) to (%dim_0, %dim_1) step (%c4, %c1) { + %result = gml_st.parallel (%i, %j) = (%c0, %c0) to (%dim_0, %dim_1) + step (%c4, %c1) outs (%out_ = %out: tensor) { %7 = arith.addi %i, %c4 : index %8 = arith.cmpi sgt, %7, %dim_0 : index %9 = arith.subi %dim_0, %i : index %size_0 = arith.select %8, %9, %c4 : index - %tile = gml_st.tile [%i, %j] [%size_0, 1] [1, 1] : !gml_st.tile - %lhs_tile = gml_st.materialize %lhs[%tile] - : tensor[!gml_st.tile] to tensor - %rhs_tile = gml_st.materialize %rhs[%tile] - : tensor[!gml_st.tile] to tensor - %init_tile = gml_st.materialize %init[%tile] - : tensor[!gml_st.tile] to tensor + %lhs_tile 
= tensor.extract_slice %lhs[%i, %j] [%size_0, 1] [1, 1] + : tensor to tensor + %rhs_tile = tensor.extract_slice %rhs[%i, %j] [%size_0, 1] [1, 1] + : tensor to tensor + %init_tile = tensor.extract_slice %out_[%i, %j] [%size_0, 1] [1, 1] + : tensor to tensor %sum = linalg.generic { indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} @@ -55,7 +54,8 @@ func.func @parallel_with_tiles(%lhs: tensor, %rhs: tensor, %add = arith.addf %l, %r : f32 linalg.yield %add : f32 } -> tensor - gml_st.set_yield %sum into %init[%tile] + %tile = gml_st.tile [%i, %j] [%size_0, 1] [1, 1] : !gml_st.tile + gml_st.set_yield %sum into %out_[%tile] : tensor into tensor[!gml_st.tile] } : tensor return %result : tensor @@ -99,12 +99,14 @@ func.func @materialize_and_yield_with_constants( %c2 = arith.constant 2 : index %c8 = arith.constant 8 : index - %1 = gml_st.parallel (%i, %j) = (%c0, %c0) to (%c8, %c2) step (%c1, %c1) { - %2 = gml_st.tile [%i, %j] [1, 1] [1, 1] : !gml_st.tile<1x1> - %3 = gml_st.materialize %in[%2] - : tensor<8x2xf32>[!gml_st.tile<1x1>] to f32 + %1 = gml_st.parallel (%i, %j) = (%c0, %c0) to (%c8, %c2) step (%c1, %c1) + outs(%out_ = %out: tensor<8x2xf32>) { + %2 = tensor.extract_slice %in[%i, %j] [1, 1] [1, 1] + : tensor<8x2xf32> to tensor<1x1xf32> + %3 = tensor.extract %2[%c0, %c0] : tensor<1x1xf32> %4 = math.absf %3: f32 - gml_st.set_yield %4 into %out[%2] + %5 = gml_st.tile [%i, %j] [1, 1] [1, 1] : !gml_st.tile<1x1> + gml_st.set_yield %4 into %out_[%5] : f32 into tensor<8x2xf32>[!gml_st.tile<1x1>] } : tensor<8x2xf32> return %1 : tensor<8x2xf32> @@ -113,23 +115,26 @@ func.func @materialize_and_yield_with_constants( // CHECK-SAME: %[[IN:.*]]: memref<8x2xf32>, %[[OUT:.*]]: memref<8x2xf32>) // CHECK: gml_st.parallel (%[[I:.*]], %[[J:.*]]) = -// CHECK-NEXT: %[[ELEM:.*]] = memref.load %[[IN]][%[[I]], %[[J]]] +// CHECK-NEXT: %[[SLICE:.*]] = memref.subview %[[IN]][%[[I]], %[[J]]] +// CHECK-NEXT: %[[ELEM:.*]] = memref.load %[[SLICE]] // CHECK-NEXT: %[[ABS:.*]] = math.absf %[[ELEM]] : f32 // CHECK-NEXT: memref.store %[[ABS]], %[[OUT]][%[[I]], %[[J]]] // CHECK-NEXT: gml_st.set_yield // ----- -func.func @parallel_with_vector(%in: vector<8xf32>, %init : vector<8xf32>) -> vector<8xf32> { + +func.func @parallel_with_vector(%in: vector<8xf32>, %out : vector<8xf32>) -> vector<8xf32> { %c0 = arith.constant 0 : index %c4 = arith.constant 4 : index %c8 = arith.constant 8 : index - %result = gml_st.parallel (%i) = (%c0) to (%c8) step (%c4) { - %tile = gml_st.tile [%i] [4] [1] : !gml_st.tile<4> - %in_tile = gml_st.materialize %in[%tile] - : vector<8xf32>[!gml_st.tile<4>] to vector<4xf32> + %result = gml_st.parallel (%i) = (%c0) to (%c8) step (%c4) + outs (%out_ = %out: vector<8xf32>) { + %in_tile = gml_st.materialize %in[%i] [4] [1] + : vector<8xf32> to vector<4xf32> %neg = arith.negf %in_tile : vector<4xf32> - gml_st.set_yield %neg into %init[%tile] + %tile = gml_st.tile [%i] [4] [1] : !gml_st.tile<4> + gml_st.set_yield %neg into %out_[%tile] : vector<4xf32> into vector<8xf32>[!gml_st.tile<4>] } : vector<8xf32> @@ -144,33 +149,35 @@ func.func @parallel_with_vector(%in: vector<8xf32>, %init : vector<8xf32>) -> ve // ----- -func.func @nested_parallel_with_vector(%init : tensor) +func.func @nested_parallel_with_vector(%out : tensor) -> tensor { %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c4 = arith.constant 4 : index %c32 = arith.constant 32 : index - %dim_0 = tensor.dim %init, %c0 : tensor + %dim_0 = tensor.dim %out, %c0 : tensor - 
%result = gml_st.parallel (%i) = (%c0) to (%dim_0) step (%c1) { - %tile = gml_st.tile [%i, 0] [1, 32] [1, 1] : !gml_st.tile<1x32> - %init_tile = gml_st.materialize %init[%tile] - : tensor[!gml_st.tile<1x32>] to tensor<1x32xf32> - %init_vec = vector.transfer_read %init_tile[%c0, %c0], %cst + %result = gml_st.parallel (%i) = (%c0) to (%dim_0) step (%c1) + outs (%out_ = %out: tensor) { + %out_tile = tensor.extract_slice %out_[%i, 0] [1, 32] [1, 1] + : tensor to tensor<1x32xf32> + %out_vec = vector.transfer_read %out_tile[%c0, %c0], %cst {in_bounds = [true, true]}: tensor<1x32xf32>, vector<1x32xf32> - %result_vec = gml_st.parallel (%j) = (%c0) to (%c32) step (%c4) { + %result_vec = gml_st.parallel (%j) = (%c0) to (%c32) step (%c4) + outs (%vec_out_ = %out_vec: vector<1x32xf32>) { + %inner_tile = gml_st.materialize %vec_out_[0, %j] [1, 4] [1, 1] + : vector<1x32xf32> to vector<1x4xf32> %vtile = gml_st.tile [0, %j] [1, 4] [1, 1] : !gml_st.tile<1x4> - %inner_tile = gml_st.materialize %init_vec[%vtile] - : vector<1x32xf32>[!gml_st.tile<1x4>] to vector<1x4xf32> - gml_st.set_yield %inner_tile into %init_vec[%vtile] + gml_st.set_yield %inner_tile into %vec_out_[%vtile] : vector<1x4xf32> into vector<1x32xf32>[!gml_st.tile<1x4>] } : vector<1x32xf32> - %result = vector.transfer_write %result_vec, %init_tile[%c0, %c0] + %result = vector.transfer_write %result_vec, %out_tile[%c0, %c0] {in_bounds = [true, true]} : vector<1x32xf32>, tensor<1x32xf32> - gml_st.set_yield %result into %init[%tile] + %tile = gml_st.tile [%i, 0] [1, 32] [1, 1] : !gml_st.tile<1x32> + gml_st.set_yield %result into %out_[%tile] : tensor<1x32xf32> into tensor[!gml_st.tile<1x32>] } : tensor @@ -186,11 +193,12 @@ func.func @nested_parallel_with_vector(%init : tensor) // CHECK-DAG: %[[INITVEC:.*]] = vector.transfer_read %[[INITTILE]] // CHECK-SAME: memref<1x32xf32, {{.*}}>, vector<1x32xf32> // CHECK: %[[RESVEC:.*]] = gml_st.parallel -// CHECK: gml_st.materialize %[[INITVEC]] +// CHECK-SAME: outs (%[[VEC_OUT_:.*]] = %[[INITVEC]]: +// CHECK: gml_st.materialize %[[VEC_OUT_]] // CHECK: gml_st.set_yield // CHECK-SAME: vector<1x4xf32> into vector<1x32xf32>[!gml_st.tile<1x4>] // CHECK: vector.transfer_write %[[RESVEC]], %[[INITTILE]] -// CHWECK-SAME: vector<1x32xf32>, memref<1x32xf32 +// CHECK-SAME: vector<1x32xf32>, memref<1x32xf32 // CHECK: return %[[INIT]] : memref @@ -207,9 +215,9 @@ func.func @scalarized_reduction(%arg: tensor<1x?xf32>) -> tensor<1xf32> { %dim = tensor.dim %arg, %c1 : tensor<1x?xf32> %result = gml_st.for (%i) = (%c0) to (%dim) step (%c1) outs (%out = %fill: tensor<1xf32>) { - %tile = gml_st.tile [0, %i] [1, 1] [1, 1] : !gml_st.tile<1x1> - %elem = gml_st.materialize %arg[%tile] - : tensor<1x?xf32>[!gml_st.tile<1x1>] to f32 + %slice = tensor.extract_slice %arg[0, %i] [1, 1] [1, 1] + : tensor<1x?xf32> to tensor<1x1xf32> + %elem = tensor.extract %slice[%c0, %c0] : tensor<1x1xf32> %extracted = tensor.extract %out[%c0] : tensor<1xf32> %sum = arith.addf %extracted, %elem : f32 @@ -231,7 +239,8 @@ func.func @scalarized_reduction(%arg: tensor<1x?xf32>) -> tensor<1xf32> { // CHECK: %[[DIM:.*]] = memref.dim %[[ARG]], %[[C1]] : memref<1x?xf32> // CHECK-NEXT: gml_st.for (%[[I:.*]]) = (%[[C0]]) to (%[[DIM]]) step (%[[C1]]) { -// CHECK-NEXT: %[[ARG_ELEM:.*]] = memref.load %[[ARG]][%[[C0]], %[[I]]] +// CHECK-NEXT: %[[ARG_SLICE:.*]] = memref.subview %[[ARG]][0, %[[I]]] +// CHECK-NEXT: %[[ARG_ELEM:.*]] = memref.load %[[ARG_SLICE]][%[[C0]], %[[C0]]] // CHECK-NEXT: %[[ACC:.*]] = memref.load %[[ALLOC]][%[[C0]]] : memref<1xf32> // CHECK-NEXT: 
%[[SUM:.*]] = arith.addf %[[ACC]], %[[ARG_ELEM]] : f32 // CHECK-NEXT: memref.store %[[SUM]], %[[ALLOC]][%[[C0]]] : memref<1xf32> @@ -252,37 +261,34 @@ func.func @matmul(%lhs: tensor<128x16xf32>, %c64 = arith.constant 64 : index %c128 = arith.constant 128 : index %matmul = gml_st.parallel (%i, %j) - = (%c0, %c0) to (%c128, %c64) step (%c8, %c4) { - %lhs_tile = gml_st.tile [%i, 0] [8, 16] [1, 1] : !gml_st.tile<8x16> - %lhs_sub = gml_st.materialize %lhs[%lhs_tile] - : tensor<128x16xf32>[!gml_st.tile<8x16>] to tensor<8x16xf32> - %rhs_tile = gml_st.tile [0, %j] [16, 4] [1, 1] : !gml_st.tile<16x4> - %rhs_sub = gml_st.materialize %rhs[%rhs_tile] - : tensor<16x64xf32>[!gml_st.tile<16x4>] to tensor<16x4xf32> - %out_tile = gml_st.tile [%i, %j] [8, 4] [1, 1] : !gml_st.tile<8x4> - %out_sub = gml_st.materialize %out[%out_tile] - : tensor<128x64xf32>[!gml_st.tile<8x4>] to tensor<8x4xf32> + = (%c0, %c0) to (%c128, %c64) step (%c8, %c4) + outs (%out_ = %out: tensor<128x64xf32>) { + %lhs_sub = tensor.extract_slice %lhs[%i, 0] [8, 16] [1, 1] + : tensor<128x16xf32> to tensor<8x16xf32> + %rhs_sub = tensor.extract_slice %rhs[0, %j] [16, 4] [1, 1] + : tensor<16x64xf32> to tensor<16x4xf32> + %out_sub = tensor.extract_slice %out_[%i, %j] [8, 4] [1, 1] + : tensor<128x64xf32> to tensor<8x4xf32> %mat_sub = gml_st.for (%k) = (%c0) to (%c16) step (%c2) outs (%out_sub_ = %out_sub: tensor<8x4xf32>) { - %lhs_tile2 = gml_st.tile [0, %k] [8, 2] [1, 1] : !gml_st.tile<8x2> - %lhs_sub2 = gml_st.materialize %lhs_sub[%lhs_tile2] - : tensor<8x16xf32>[!gml_st.tile<8x2>] to tensor<8x2xf32> - %rhs_tile2 = gml_st.tile [%k, 0] [2, 4] [1, 1] : !gml_st.tile<2x4> - %rhs_sub2 = gml_st.materialize %rhs_sub[%rhs_tile2] - : tensor<16x4xf32>[!gml_st.tile<2x4>] to tensor<2x4xf32> - %out_tile2 = gml_st.tile [0, 0] [8, 4] [1, 1] : !gml_st.tile<8x4> - %out_sub2 = gml_st.materialize %out_sub_[%out_tile2] - : tensor<8x4xf32>[!gml_st.tile<8x4>] to tensor<8x4xf32> + %lhs_sub2 = tensor.extract_slice %lhs_sub[0, %k] [8, 2] [1, 1] + : tensor<8x16xf32> to tensor<8x2xf32> + %rhs_sub2 = tensor.extract_slice %rhs_sub[%k, 0] [2, 4] [1, 1] + : tensor<16x4xf32> to tensor<2x4xf32> + %out_sub2 = tensor.extract_slice %out_sub_[0, 0] [8, 4] [1, 1] + : tensor<8x4xf32> to tensor<8x4xf32> %mat_sub2 = linalg.matmul ins(%lhs_sub2, %rhs_sub2 : tensor<8x2xf32>, tensor<2x4xf32>) outs(%out_sub2 : tensor<8x4xf32>) -> tensor<8x4xf32> + %out_tile2 = gml_st.tile [0, 0] [8, 4] [1, 1] : !gml_st.tile<8x4> gml_st.set_yield %mat_sub2 into %out_sub_[%out_tile2] : tensor<8x4xf32> into tensor<8x4xf32>[!gml_st.tile<8x4>] } : tensor<8x4xf32> - gml_st.set_yield %mat_sub into %out[%out_tile] + %out_tile = gml_st.tile [%i, %j] [8, 4] [1, 1] : !gml_st.tile<8x4> + gml_st.set_yield %mat_sub into %out_[%out_tile] : tensor<8x4xf32> into tensor<128x64xf32>[!gml_st.tile<8x4>] } : tensor<128x64xf32> return %matmul : tensor<128x64xf32> @@ -290,9 +296,96 @@ func.func @matmul(%lhs: tensor<128x16xf32>, // CHECK-LABEL: func.func @matmul // CHECK-NOT: alloc // CHECK: gml_st.parallel -// CHECK-3: memref.subview +// CHECK-COUNT-3: memref.subview // CHECK-NOT: alloc // CHECK: gml_st.for -// CHECK-4: memref.subview +// CHECK-COUNT-2: memref.subview // CHECK-NOT: alloc // CHECK: linalg.matmul + +// ----- + +func.func @materialize_out_of_place(%arg0: tensor<1xi32>) -> tensor<1xi32> { + %c0 = arith.constant 0 : index + %c42 = arith.constant 42 : i32 + + %0 = tensor.insert %c42 into %arg0[%c0] : tensor<1xi32> + %1 = tensor.extract_slice %arg0[0][1][1] : tensor<1xi32> to tensor<1xi32> + %2 = tensor.extract 
%1[%c0] : tensor<1xi32> + %3 = tensor.insert %2 into %0[%c0] : tensor<1xi32> + + return %3 : tensor<1xi32> +} + +// CHECK-LABEL: @materialize_out_of_place +// CHECK-SAME: %[[ARG0:.*]]: memref<1xi32> +// CHECK-DAG: %[[C42:.*]] = arith.constant 42 +// CHECK: %[[ALLOC:.*]] = memref.alloc +// CHECK: memref.copy %{{.*}}, %[[ALLOC]] +// CHECK: memref.store %[[C42]], %[[ALLOC]] +// CHECK: %[[LOADED:.*]] = memref.load %[[ARG0]] +// CHECK: memref.store %[[LOADED]], %[[ALLOC]] +// CHECK: return %[[ALLOC]] + +// ----- + +func.func @same_enclosing_repetitive_region(%2: tensor<320xf32>, + %3: tensor<320x10240xf32>) + -> tensor<320xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %cst = arith.constant -0.000000e+00 : f32 + %c320 = arith.constant 320 : index + %4 = gml_st.parallel (%i) = (%c0) to (%c320) step (%c1) + outs(%arg1 = %2: tensor<320xf32>) { + %5 = tensor.extract_slice %3[%i, 0] [1, 10240] [1, 1] : tensor<320x10240xf32> to tensor<1x10240xf32> + %6 = tensor.extract_slice %arg1[%i] [1] [1] : tensor<320xf32> to tensor<1xf32> + %7 = linalg.fill ins(%cst : f32) outs(%6 : tensor<1xf32>) -> tensor<1xf32> + %8 = linalg.fill ins(%cst : f32) outs(%7 : tensor<1xf32>) -> tensor<1xf32> + + %tile = gml_st.tile [%i] [1] [1] : !gml_st.tile<1> + gml_st.set_yield %8 into %arg1[%tile] + : tensor<1xf32> into tensor<320xf32>[!gml_st.tile<1>] + } : tensor<320xf32> + return %4 : tensor<320xf32> +} +// CHECK-LABEL: @same_enclosing_repetitive_region +// CHECK-NOT: memref.alloc + +// ----- + +// CHECK-LABEL: func @gml_st_parallel_private_var( +// CHECK-SAME: %[[t:.*]]: memref<10xf32 +func.func @gml_st_parallel_private_var(%t: tensor<10xf32>) -> f32 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c5 = arith.constant 5 : index + + // A copy is inserted for the uses of %t in the loop. + // CHECK: %[[t_copy:.*]] = memref.alloc() {{.*}} : memref<10xf32> + // CHECK: memref.copy %[[t]], %[[t_copy]] + + // CHECK: gml_st.parallel + + // Load from the copy and store into the shared output. 
+ // CHECK: %[[subview:.*]] = memref.subview %[[t]] + // CHECK: memref.load %[[t_copy]] + // CHECK: memref.store %{{.*}}, %[[subview]] + %0 = gml_st.parallel (%tid) = (%c0) to (%c2) step (%c1) + outs(%o = %t: tensor<10xf32>) { + %offset = arith.muli %c5, %tid : index + %slice = tensor.extract_slice %o[%offset] [5] [1] + : tensor<10xf32> to tensor<5xf32> + %r2 = tensor.extract %t[%tid] : tensor<10xf32> + %i = tensor.insert %r2 into %slice[%c2] : tensor<5xf32> + + %tile = gml_st.tile [%offset][5][1] : !gml_st.tile<5> + gml_st.set_yield %i into %o[%tile] + : tensor<5xf32> into tensor<10xf32>[!gml_st.tile<5>] + } : tensor<10xf32> + + %r = tensor.extract %0[%c2] : tensor<10xf32> + return %r : f32 +} + diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/canonicalize.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/canonicalize.mlir index c1ff5ca4018..b655d985782 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/canonicalize.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/canonicalize.mlir @@ -1,278 +1,4 @@ -// RUN: mlir-hlo-opt %s -canonicalize -split-input-file | FileCheck %s - -#map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> - -// CHECK-LABEL: func @memref_cast_into_loop( -func.func @memref_cast_into_loop(%arg0: memref<192xf32>) { - %0 = memref.cast %arg0 - : memref<192xf32> to memref<192xf32, #map> - %cst = arith.constant 0.000000e+00 : f32 - %c24 = arith.constant 24 : index - %c0 = arith.constant 0 : index - %c192 = arith.constant 192 : index - // CHECK: gml_st.loop - // CHECK-SAME: outs (%{{.*}} = %{{.*}}: memref<192xf32>) - gml_st.loop (%arg3) = (%c0) to (%c192) step (%c24) - outs (%out = %0: memref<192xf32, #map>) { - %14 = affine.min affine_map<(d0) -> (-d0 + 192, 24)>(%arg3) - %16 = memref.subview %out[%arg3] [%14] [1] - : memref<192xf32, #map> to memref - linalg.fill ins(%cst : f32) outs(%16 : memref) - gml_st.yield - } - func.return -} - -// ----- - -func.func private @foo(%A: memref<48xf32>, %B: tensor<48xf32>, - %C: memref<48xf32>) -> (tensor<48xf32>) - -func.func @fold_loop_results(%A: memref<48xf32>, %B: tensor<48xf32>, - %C: memref<48xf32>, %C_tensor: tensor<48xf32>) -> tensor<48xf32> { - %c0 = arith.constant 0 : index - %c24 = arith.constant 24 : index - %c48 = arith.constant 48 : index - %useful, %useless = gml_st.loop (%i) = (%c0) to (%c48) step (%c24) - ins (%A_ = %A: memref<48xf32>) - outs (%B_ = %B: tensor<48xf32>, - %CT_ = %C_tensor: tensor<48xf32>, - %C_ = %C: memref<48xf32>) { - %result = func.call @foo(%A_, %B_, %C_) - : (memref<48xf32>, tensor<48xf32>, memref<48xf32>)-> (tensor<48xf32>) - gml_st.yield %result, %CT_ : tensor<48xf32>, tensor<48xf32> - } - func.return %useful : tensor<48xf32> -} - -// CHECK-LABEL: func @fold_loop_results( -// CHECK-SAME: %[[A:.*]]: [[BUF_TY:memref<48xf32>]], %[[B:.*]]: [[TY:tensor<48xf32>]], -// CHECK-SAME: %[[C:.*]]: [[BUF_TY]], %[[C_TENSOR:.*]]: [[TY]]) -> [[TY]] { - -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C24:.*]] = arith.constant 24 : index -// CHECK-DAG: %[[C48:.*]] = arith.constant 48 : index - -// CHECK-NOT: %{{.*}} = gml_st.loop -// CHECK: %[[RESULT:.*]] = gml_st.loop (%{{.*}}) = (%[[C0]]) -// CHECK-SAME: to (%[[C48]]) step (%[[C24]]) -// CHECK-SAME: ins (%[[A_:.*]] = %[[A]]: [[BUF_TY]]) -// CHECK-SAME: outs (%[[B_:.*]] = %[[B]]: [[TY]], %[[C_:.*]] = %[[C]]: [[BUF_TY]]) { -// CHECK-NEXT: %[[RES:.*]] = func.call @foo(%[[A_]], %[[B_]], %[[C_]]) -// CHECK-NEXT: gml_st.yield %[[RES]] : - -// CHECK: return %[[RESULT]] - -// ----- - -func.func 
private @foo(%A: memref<192xf32>, %B: tensor<192xf32>) -> tensor<192xf32> - -func.func @fold_loop_inputs(%A: memref<192xf32>, %A_tensor: tensor<192xf32>, - %B_tensor: tensor<192xf32>) -> tensor<192xf32> { - %c0 = arith.constant 0 : index - %c24 = arith.constant 24 : index - %c192 = arith.constant 192 : index - %result = gml_st.loop (%i) = (%c0) to (%c192) step (%c24) - ins (%A_ = %A: memref<192xf32>, %AT_ = %A_tensor: tensor<192xf32>) - outs (%BT_ = %B_tensor: tensor<192xf32>) { - %0 = func.call @foo(%A_, %BT_) : (memref<192xf32>, tensor<192xf32>) -> tensor<192xf32> - gml_st.yield %0 : tensor<192xf32> - } - func.return %result : tensor<192xf32> -} - -// CHECK-LABEL: func @fold_loop_inputs -// CHECK: %[[RESULT:.*]] = gml_st.loop -// CHECK-SAME: ins (%{{.*}} = %{{.*}}: memref<192xf32>) - -// CHECK: return %[[RESULT]] - -// ----- - -// CHECK-LABEL: func @dim_of_loop_input_no_canonicalize( -// CHECK-SAME: %[[arg0:.*]]: tensor, %[[arg1:.*]]: tensor, %[[arg2:.*]]: tensor -// CHECK: %[[c0:.*]] = arith.constant 0 : index -// CHECK: gml_st.loop {{.*}} outs (%[[o:.*]] = -// CHECK: %[[dim:.*]] = tensor.dim %[[o]], %[[c0]] -// CHECK: arith.index_cast %[[dim]] -func.func @dim_of_loop_input_no_canonicalize(%arg0: tensor, %arg1: tensor, %arg2: tensor, %s: index) - -> tensor { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %d0 = tensor.dim %arg0, %c0 : tensor - %d1 = tensor.dim %arg0, %c1 : tensor - %r = gml_st.loop (%iv0, %iv1) = (%c0, %c0) - to (%d0, %d1) step (%c1, %c1) - ins (%in0 = %arg0 : tensor, %in1 = %arg1 : tensor) - outs (%out1 = %arg2 : tensor) { - %inner_dim = tensor.dim %out1, %c0 : tensor - %cast1 = arith.index_cast %inner_dim : index to i32 - %cast2 = arith.sitofp %cast1 : i32 to f32 - %fill = linalg.fill ins(%cast2 : f32) outs(%out1 : tensor) -> tensor - %slice = tensor.extract_slice %fill[0, 0][%s, %s][1, 1] : tensor to tensor - gml_st.yield %slice : tensor - } - func.return %r : tensor -} - -// ----- - -// CHECK-LABEL: func @dim_of_loop_input( -// CHECK-SAME: %[[arg0:.*]]: tensor, %[[arg1:.*]]: tensor, %[[arg2:.*]]: tensor -// CHECK: %[[c0:.*]] = arith.constant 0 : index -// CHECK: gml_st.loop -// CHECK: %[[dim:.*]] = tensor.dim %[[arg1]], %[[c0]] -// CHECK: arith.index_cast %[[dim]] -func.func @dim_of_loop_input(%arg0: tensor, %arg1: tensor, %arg2: tensor) - -> tensor { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %d0 = tensor.dim %arg0, %c0 : tensor - %d1 = tensor.dim %arg0, %c1 : tensor - %r = gml_st.loop (%iv0, %iv1) = (%c0, %c0) - to (%d0, %d1) step (%c1, %c1) - ins (%in0 = %arg0 : tensor, %in1 = %arg1 : tensor) - outs (%out1 = %arg2 : tensor) { - %inner_dim = tensor.dim %in1, %c0 : tensor - %cast1 = arith.index_cast %inner_dim : index to i32 - %cast2 = arith.sitofp %cast1 : i32 to f32 - %fill = linalg.fill ins(%cast2 : f32) outs(%out1 : tensor) -> tensor - gml_st.yield %fill : tensor - } - func.return %r : tensor -} - -// ----- - -// CHECK-LABEL: func @dim_of_loop_result( -// CHECK-SAME: %[[arg0:.*]]: tensor, %[[arg1:.*]]: tensor, %[[arg2:.*]]: tensor -// CHECK: %[[c0:.*]] = arith.constant 0 : index -// CHECK: tensor.dim %[[arg2]], %[[c0]] -func.func @dim_of_loop_result(%arg0: tensor, %arg1: tensor, %arg2: tensor, %s: index) - -> index { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %d0 = tensor.dim %arg0, %c0 : tensor - %d1 = tensor.dim %arg0, %c1 : tensor - %r = gml_st.loop (%iv0, %iv1) = (%c0, %c0) - to (%d0, %d1) step (%c1, %c1) - ins (%in0 = %arg0 : tensor, %in1 = %arg1 : tensor) - outs (%out1 = %arg2 : tensor) 
{ - %1 = tensor.insert_slice %arg0 into %out1 [0, 0] [%s, %s] [1, 1] : tensor into tensor - gml_st.yield %1 : tensor - } - %r2 = tensor.dim %r, %c0 : tensor - func.return %r2 : index -} - -// ----- - -// CHECK-LABEL: func @dim_of_loop_result_no_canonicalize( -// CHECK-SAME: %[[arg0:.*]]: tensor, %[[arg1:.*]]: tensor, %[[arg2:.*]]: tensor -// CHECK: %[[c0:.*]] = arith.constant 0 : index -// CHECK: %[[r:.*]] = gml_st.loop -// CHECK: tensor.dim %[[r]], %[[c0]] -func.func @dim_of_loop_result_no_canonicalize(%arg0: tensor, %arg1: tensor, %arg2: tensor, %s: index) - -> index { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %d0 = tensor.dim %arg0, %c0 : tensor - %d1 = tensor.dim %arg0, %c1 : tensor - %r = gml_st.loop (%iv0, %iv1) = (%c0, %c0) - to (%d0, %d1) step (%c1, %c1) - ins (%in0 = %arg0 : tensor, %in1 = %arg1 : tensor) - outs (%out1 = %arg2 : tensor) { - %1 = tensor.insert_slice %arg0 into %arg1 [0, 0] [%s, %s] [1, 1] : tensor into tensor - gml_st.yield %1 : tensor - } - %r2 = tensor.dim %r, %c0 : tensor - func.return %r2 : index -} - -// ----- - -func.func private @do(%A: tensor, %B: tensor) -> tensor - -func.func @fold_tensor_cast(%in: tensor<4x600xf32>, - %out: tensor<4xf32>) -> tensor<4xf32> { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c600 = arith.constant 600 : index - - %in_cast = tensor.cast %in : tensor<4x600xf32> to tensor - %out_cast = tensor.cast %out : tensor<4xf32> to tensor - - %result = gml_st.loop (%i) = (%c0) to (%c600) step (%c4) - ins (%in_ = %in_cast: tensor) - outs (%out_ = %out_cast: tensor) - iterators[#gml_st.iterator_type] { - %dim_in = tensor.dim %in_, %c0 : tensor - %dim_out = tensor.dim %out_, %c0 : tensor - - %in_sub = tensor.extract_slice %in_[0, %i] [%dim_in, 4] [1, 1] - : tensor to tensor - %out_sub = tensor.extract_slice %out_[0] [%dim_out] [1] - : tensor to tensor - %result_sub = func.call @do(%in_sub, %out_sub): - (tensor, tensor) -> tensor - %out_update = tensor.insert_slice %result_sub into %out_[0] [%dim_out] [1] - : tensor into tensor - gml_st.yield %out_update : tensor - } - %result_cast = tensor.cast %result : tensor to tensor<4xf32> - func.return %result_cast : tensor<4xf32> -} - -// CHECK-LABEL: func @fold_tensor_cast( -// CHECK-SAME: %[[IN:.*]]: tensor<4x600xf32>, %[[OUT:.*]]: tensor<4xf32>) - -// CHECK-DAG: %[[C600:.*]] = arith.constant 600 : index -// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index - -// CHECK: %[[RESULT:.*]] = gml_st.loop -// CHECK-SAME: ins (%[[IN_:.*]] = %[[IN]]: tensor<4x600xf32>) -// CHECK-SAME: outs (%[[OUT_:.*]] = %[[OUT]]: tensor<4xf32>) iterators - -// CHECK: %[[IN_SUB:.*]] = tensor.extract_slice -// CHECK: %[[IN_SUB_CAST:.*]] = tensor.cast %[[IN_SUB]] -// CHECK-SAME: : tensor<4x4xf32> to tensor - -// CHECK: %[[OUT_SUB:.*]] = tensor.cast %[[OUT_]] -// CHECK-SAME: : tensor<4xf32> to tensor - -// CHECK: %[[RESULT_SUB:.*]] = func.call @do(%[[IN_SUB_CAST]], %[[OUT_SUB]]) -// CHECK: %[[RESULT_CAST:.*]] = tensor.cast %[[RESULT_SUB]] -// CHECK: gml_st.yield %[[RESULT_CAST]] : tensor<4xf32> -// CHECK: } -// CHECK: return %[[RESULT]] : tensor<4xf32> - -// ----- - -func.func private @reduce(%A: tensor<4xf32>, %B: tensor) -> tensor - -// CHECK-LABEL: @remove_empty_loop -func.func @remove_empty_loop(%in: tensor<16xf32>, %out: tensor, - %buf: memref) -> tensor{ - // CHECK-NOT: gml_st.loop - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c16 = arith.constant 16 : index - %0 = gml_st.loop (%i, %j) = (%c0, %c0) to 
(%c16, %c0) step (%c4, %c4) - ins (%in_ = %in: tensor<16xf32>) - outs (%out_ = %out: tensor, %buf_ = %buf: memref) - iterators[#gml_st.iterator_type, - #gml_st.iterator_type] { - %in_sub = tensor.extract_slice %in_[%i][4][1] - : tensor<16xf32> to tensor<4xf32> - %result = func.call @reduce(%in_sub, %out_): - (tensor<4xf32>, tensor) -> tensor - gml_st.yield %result : tensor - } - func.return %0 : tensor -} - -// ----- +// RUN: mlir-hlo-opt %s -canonicalize="test-convergence" -split-input-file | FileCheck %s // CHECK-LABEL: @fold_unit_dim func.func @fold_unit_dim() -> tensor<8x10xf32> { @@ -300,15 +26,37 @@ func.func @fold_unit_dim() -> tensor<8x10xf32> { // ----- +// CHECK-LABEL: @remove_empty_for +func.func @remove_empty_for() -> tensor<8x10xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c5 = arith.constant 5 : index + %c8 = arith.constant 8 : index + // CHECK: %[[INIT:.*]] = tensor.empty + %init = tensor.empty() : tensor<8x10xf32> + // CHECK-NOT: gml_st.for + %out = gml_st.for (%i, %j) = (%c0, %c4) to (%c4, %c4) step (%c1, %c1) + outs(%out_ = %init : tensor<8x10xf32>) { + %tile = gml_st.tile [%i, %j] [4, 1] [1, 1] : !gml_st.tile<4x1> + + %val = tensor.empty() : tensor<4x1xf32> + gml_st.set_yield %val into %out_[%tile] + : tensor<4x1xf32> into tensor<8x10xf32>[!gml_st.tile<4x1>] + } : tensor<8x10xf32> + // CHECK: return %[[INIT]] + func.return %out : tensor<8x10xf32> +} + +// ----- + // CHECK-LABEL: @fold_constant_tile_through_materialize func.func @fold_constant_tile_through_materialize(%in: tensor<4xf32>) -> tensor { %c2 = arith.constant 2 : index - // CHECK: %[[TILE:.*]] = gml_st.tile [2] [2] [1] : !gml_st.tile<2> - %tile = gml_st.tile [%c2] [%c2] [1] : !gml_st.tile - // CHECK: %[[MAT:.*]] = gml_st.materialize {{.*}}[%[[TILE]]] : tensor<4xf32>[!gml_st.tile<2>] - %mat = gml_st.materialize %in[%tile] : tensor<4xf32>[!gml_st.tile] - to tensor + // CHECK: %[[MAT:.*]] = tensor.extract_slice + // CHECK-SAME: [2] [2] [1] : tensor<4xf32> to tensor<2xf32> + %mat = tensor.extract_slice %in[%c2] [%c2] [1] : tensor<4xf32> to tensor // CHECK: %[[RET:.*]] = tensor.cast %[[MAT]] : tensor<2xf32> to tensor // CHECK: return %[[RET]] func.return %mat : tensor @@ -324,11 +72,11 @@ func.func @fold_constant_set_yield(%in: tensor, %cst = arith.constant 0.000000e+00 : f32 %1 = gml_st.for (%arg0) = (%c0) to (%c8) step (%c2) outs (%arg1 = %out: tensor) { - %tile = gml_st.tile [0, 0] [%c2, %c2] [1, 1] : !gml_st.tile - %out_sub = gml_st.materialize %out[%tile] : - tensor[!gml_st.tile] to tensor + %out_sub = tensor.extract_slice %out[0, 0] [%c2, %c2] [1, 1] : + tensor to tensor %fill = linalg.fill ins(%cst : f32) outs(%out_sub : tensor) -> tensor + %tile = gml_st.tile [0, 0] [%c2, %c2] [1, 1] : !gml_st.tile gml_st.set_yield %fill into %arg1[%tile] : tensor into tensor[!gml_st.tile] } : tensor @@ -337,10 +85,10 @@ func.func @fold_constant_set_yield(%in: tensor, // CHECK-LABEL: @fold_constant_set_yield // CHECK: %[[FOR:.*]] = gml_st.for{{.*}}: tensor -// CHECK: %[[TILE:.*]] = gml_st.tile [0, 0] [2, 2] {{.*}} !gml_st.tile<2x2> -// CHECK-NOT: builtin.unrealized_conversion_cast -// CHECK-NEXT: %[[SLICE:.*]] = gml_st.materialize %{{.*}}[%[[TILE]]] {{.*}} to tensor<2x2xf32> -// CHECK: %[[FILL:.*]] = linalg.fill {{.*}} outs(%[[SLICE]] : tensor<2x2xf32>) +// CHECK-NEXT: %[[SLICE:.*]] = tensor.extract_slice +// CHECK-SAME: [0, 0] [2, 2] +// CHECK-NEXT: %[[FILL:.*]] = linalg.fill {{.*}} outs(%[[SLICE]] : tensor<2x2xf32>) +// CHECK-NEXT: %[[TILE:.*]] = gml_st.tile 
[0, 0] [2, 2] {{.*}} !gml_st.tile<2x2> // CHECK-NEXT: gml_st.set_yield %[[FILL]] into %{{.*}}[%[[TILE]]] : tensor<2x2xf32> into tensor[!gml_st.tile<2x2>] // ----- @@ -364,10 +112,11 @@ func.func @fold_constant_set_yield_scalar(%in: tensor, } // CHECK-LABEL: @fold_constant_set_yield_scalar -// CHECK: %[[FOR:.*]] = gml_st.for{{.*}}: tensor -// CHECK: %[[TILE:.*]] = gml_st.tile [0] [1] {{.*}} !gml_st.tile<1> +// CHECK: %[[FOR:.*]] = gml_st.for (%{{.*}}) outs +// CHECK-SAME: (%[[INIT_:.*]] = %[[INIT:.*]]: tensor) +// CHECK: %[[TILE:.*]] = gml_st.tile [0] [1] [1] : !gml_st.tile<1> // CHECK-NOT: builtin.unrealized_conversion_cast -// CHECK-NEXT: gml_st.set_yield %[[SCALAR:.*]] into %{{.*}}[%[[TILE]]] : f32 into tensor[!gml_st.tile<1>] +// CHECK: gml_st.set_yield %[[SCALAR:.*]] into %[[INIT_]][%[[TILE]]] : f32 into tensor[!gml_st.tile<1>] // ----- @@ -377,20 +126,19 @@ func.func @fold_constant_for(%in: tensor, %c2 = arith.constant 2 : index %c8 = arith.constant 8 : index %cst = arith.constant 0.000000e+00 : f32 - %1 = gml_st.tile [0, 0] [8, 2] [1, 1] : !gml_st.tile<8x2> - %3 = gml_st.materialize %out[%1] : - tensor[!gml_st.tile<8x2>] to tensor<8x2xf32> + %3 = tensor.extract_slice %out[0, 0] [8, 2] [1, 1] : + tensor to tensor<8x2xf32> %cast_3 = tensor.cast %3 : tensor<8x2xf32> to tensor %4 = gml_st.for (%arg0) = (%c0) to (%c8) step (%c2) outs (%arg1 = %cast_3: tensor) { - %tile = gml_st.tile [0, %arg0] [8, 2] [1, 1] : !gml_st.tile<8x2> - %2 = builtin.unrealized_conversion_cast %tile : - !gml_st.tile<8x2> to !gml_st.tile - %out_sub = gml_st.materialize %arg1[%tile] : - tensor[!gml_st.tile<8x2>] to tensor<8x2xf32> + %out_sub = tensor.extract_slice %arg1[0, %arg0] [8, 2] [1, 1] : + tensor to tensor<8x2xf32> %fill = linalg.fill ins(%cst : f32) outs(%out_sub : tensor<8x2xf32>) -> tensor<8x2xf32> %cast_fill = tensor.cast %fill : tensor<8x2xf32> to tensor + %tile = gml_st.tile [0, %arg0] [8, 2] [1, 1] : !gml_st.tile<8x2> + %2 = builtin.unrealized_conversion_cast %tile : + !gml_st.tile<8x2> to !gml_st.tile gml_st.set_yield %cast_fill into %arg1[%2] : tensor into tensor[!gml_st.tile] } : tensor @@ -398,11 +146,11 @@ func.func @fold_constant_for(%in: tensor, } // CHECK-LABEL: @fold_constant_for -// CHECK: %[[SLICE:.*]] = gml_st.materialize {{.*}} to tensor<8x2xf32> +// CHECK: %[[SLICE:.*]] = tensor.extract_slice {{.*}} to tensor<8x2xf32> // CHECK-NOT: tensor.cast -// CHECK: %[[FOR1:.*]] = gml_st.for (%{{.*}} = (%c0) to {{.*}} outs (%[[ARG1:.*]] = %[[SLICE]]: tensor<8x2xf32> +// CHECK: %[[FOR1:.*]] = gml_st.for (%[[I:.*]]) = (%c0) to {{.*}} outs (%[[ARG1:.*]] = %[[SLICE]]: tensor<8x2xf32> +// CHECK-NEXT: %[[FOR1_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, %[[I]]] [8, 2] [1, 1] : tensor<8x2xf32> to tensor<8x2xf32> // CHECK: %[[FOR1_TILE:.*]] = gml_st.tile {{.*}} [8, 2] {{.*}} !gml_st.tile<8x2> -// CHECK-NEXT: %[[FOR1_SLICE:.*]] = gml_st.materialize %{{.*}}[%[[FOR1_TILE]]] {{.*}} to tensor<8x2xf32> // CHECK: gml_st.set_yield %{{.*}} into %[[ARG1]][%[[FOR1_TILE]]] : tensor<8x2xf32> into tensor<8x2xf32>[!gml_st.tile<8x2>] // CHECK-NEXT: } : tensor<8x2xf32> // CHECK: %[[CAST:.*]] = tensor.cast %[[FOR1]] : tensor<8x2xf32> to tensor @@ -412,16 +160,251 @@ func.func @fold_constant_for(%in: tensor, func.func @fold_cast_to_materialize_source(%in: tensor<4xf32>) -> tensor<2xf32> { - %tile = gml_st.tile [2] [2] [1] : !gml_st.tile<2> %cast = tensor.cast %in : tensor<4xf32> to tensor - %mat = gml_st.materialize %cast[%tile] : tensor[!gml_st.tile<2>] - to tensor<2xf32> + %mat = tensor.extract_slice %cast[2] [2] [1] + : 
tensor to tensor<2xf32> func.return %mat : tensor<2xf32> } // CHECK-LABEL: @fold_cast_to_materialize_source // CHECK-SAME: %[[IN:.*]]: tensor<4xf32> -// CHECK: %[[TILE:.*]] = gml_st.tile [2] [2] [1] : !gml_st.tile<2> // CHECK-NOT: tensor.cast -// CHECK: %[[MAT:.*]] = gml_st.materialize %[[IN]][%[[TILE]]] : tensor<4xf32>[!gml_st.tile<2>] +// CHECK: %[[MAT:.*]] = tensor.extract_slice %[[IN]][2] [2] [1] : tensor<4xf32> to tensor<2xf32> // CHECK: return %[[MAT]] + +// ----- + +func.func @inline_single_iteration_for( + %in: tensor<8x8xf32>) -> tensor<8x8xf32> { + %c8 = arith.constant 8 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<8x8xf32> + %13 = gml_st.for (%arg4) = (%c0) to (%c1) step (%c8) + outs (%arg5 = %0: tensor<8x8xf32>) { + %19 = gml_st.tile [0, 0] [8, 8] [1, 1] : !gml_st.tile<8x8> + %11 = linalg.fill ins(%cst : f32) outs(%arg5 : tensor<8x8xf32>) + -> tensor<8x8xf32> + gml_st.set_yield %11 into %arg5[%19] : tensor<8x8xf32> + into tensor<8x8xf32>[!gml_st.tile<8x8>] + } : tensor<8x8xf32> + return %13 : tensor<8x8xf32> +} + +// CHECK-LABEL: @inline_single_iteration_for +// CHECK-NOT: gml_st.for +// CHECK: linalg.fill + +// ----- + +func.func @inline_single_iteration_parallel( + %in: tensor<8x8xf32>) -> tensor<8x8xf32> { + %c8 = arith.constant 8 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<8x8xf32> + %13 = gml_st.parallel (%arg4, %arg5) = (%c0, %c0) to (%c1, %c1) + step (%c8, %c8) outs (%out_ = %0: tensor<8x8xf32>) { + %20 = tensor.extract_slice %out_[%arg4, %arg5] [8, 8] [1, 1] + : tensor<8x8xf32> to tensor<8x8xf32> + %11 = linalg.fill ins(%cst : f32) outs(%20 : tensor<8x8xf32>) + -> tensor<8x8xf32> + %19 = gml_st.tile [%arg4, %arg5] [8, 8] [1, 1] : !gml_st.tile<8x8> + gml_st.set_yield %11 into %out_[%19] : tensor<8x8xf32> + into tensor<8x8xf32>[!gml_st.tile<8x8>] + } : tensor<8x8xf32> + return %13 : tensor<8x8xf32> +} + +// CHECK-LABEL: @inline_single_iteration_parallel +// CHECK-NOT: gml_st.parallel +// CHECK: tensor.empty +// CHECK-NEXT: linalg.fill + +// ----- + +func.func @collapse_one_dim_parallel(%in: tensor<8x8xf32>) -> tensor<8x8xf32> { + %c8 = arith.constant 8 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<8x8xf32> + %13 = gml_st.parallel (%arg4, %arg5) = (%c0, %c0) to (%c1, %c16) + step (%c8, %c8) outs (%out_ = %0: tensor<8x8xf32>) { + %19 = gml_st.tile [%arg4, %arg5] [8, 8] [1, 1] : !gml_st.tile<8x8> + %11 = linalg.fill ins(%cst : f32) outs(%out_ : tensor<8x8xf32>) + -> tensor<8x8xf32> + gml_st.set_yield %11 into %out_[%19] : tensor<8x8xf32> + into tensor<8x8xf32>[!gml_st.tile<8x8>] + } : tensor<8x8xf32> + return %13 : tensor<8x8xf32> +} + +// CHECK-LABEL: @collapse_one_dim_parallel +// CHECK: gml_st.parallel (%[[ARG:.*]]) = (%c0) to (%c16) step (%c8) +// CHECK: gml_st.tile [0, %[[ARG]]] +// CHECK: linalg.fill +// CHECK: gml_st.set_yield + +// ----- + +func.func @remove_empty_parallel(%in: tensor<8x8xf32>) -> tensor<8x8xf32> { + %c8 = arith.constant 8 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<8x8xf32> + %13 = gml_st.parallel (%arg4, %arg5) = (%c0, %c16) to (%c1, %c16) + step (%c8, %c8) outs (%out_ = %0: 
tensor<8x8xf32>) { + %19 = gml_st.tile [%arg4, %arg5] [8, 8] [1, 1] : !gml_st.tile<8x8> + %11 = linalg.fill ins(%cst : f32) outs(%out_ : tensor<8x8xf32>) + -> tensor<8x8xf32> + gml_st.set_yield %11 into %out_[%19] : tensor<8x8xf32> + into tensor<8x8xf32>[!gml_st.tile<8x8>] + } : tensor<8x8xf32> + return %13 : tensor<8x8xf32> +} + +// CHECK-LABEL: @remove_empty_parallel +// CHECK-NOT: gml_st.parallel +// CHECK: %[[EMPTY:.*]] = tensor.empty +// CHECK: return %[[EMPTY]] + +// ----- + +func.func @fold_for_iter_arg(%in: tensor<8x8xf32>) -> tensor<8x8xf32> { + %c8 = arith.constant 8 : index + %c0 = arith.constant 0 : index + %c16 = arith.constant 16 : index + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<8x8xf32> + %1 = tensor.empty() : tensor<8x8xf32> + %13:2 = gml_st.for (%arg4) = (%c0) to (%c16) step (%c8) + outs (%arg5 = %0: tensor<8x8xf32>, %arg6 = %1: tensor<8x8xf32>) { + %19 = gml_st.tile [0, 0] [8, 8] [1, 1] : !gml_st.tile<8x8> + %11 = linalg.fill ins(%cst : f32) outs(%arg5 : tensor<8x8xf32>) + -> tensor<8x8xf32> + gml_st.set_yield %11 into %arg5[%19] + : tensor<8x8xf32> into tensor<8x8xf32>[!gml_st.tile<8x8>], + %arg6 into %arg6[%19] + : tensor<8x8xf32> into tensor<8x8xf32>[!gml_st.tile<8x8>], + } : tensor<8x8xf32>, tensor<8x8xf32> + return %13#0 : tensor<8x8xf32> +} + +// CHECK-LABEL: @fold_for_iter_arg +// CHECK: %[[INIT:.*]] = tensor.empty() +// CHECK-NOT: tensor.empty() +// CHECK: %[[FOR:.*]] = gml_st.for {{.*}} outs (%[[ARG:.*]] = %[[INIT]]: tensor<8x8xf32>) { +// CHECK: gml_st.set_yield {{.*}} into %[[ARG]][{{.*}}] : tensor<8x8xf32> into tensor<8x8xf32> +// CHECK: } : tensor<8x8xf32> +// CHECK: return %[[FOR]] : tensor<8x8xf32 + +// ----- + +func.func @fold_for_iter_arg_no_args(%in: tensor<8x8xf32>) -> tensor<8x8xf32> { + %c0 = arith.constant 0 : index + %c8 = arith.constant 8 : index + %1 = tensor.empty() : tensor<8x8xf32> + %13 = gml_st.for (%arg4) = (%c0) to (%c8) step (%c8) outs (%arg6 = %1: tensor<8x8xf32>) { + %19 = gml_st.tile [0, 0] [8, 8] [1, 1] : !gml_st.tile<8x8> + gml_st.set_yield %arg6 into %arg6[%19] + : tensor<8x8xf32> into tensor<8x8xf32>[!gml_st.tile<8x8>], + } : tensor<8x8xf32> + return %13 : tensor<8x8xf32> +} + +// CHECK-LABEL: @fold_for_iter_arg_no_args +// CHECK: %[[INIT:.*]] = tensor.empty() +// CHECK-NEXT: return %[[INIT]] : tensor<8x8xf32 + +// ----- + +func.func @collapse_empty_for_vector(%in: vector<8x8xf32>) -> vector<8x8xf32> { + %c8 = arith.constant 8 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %cst = arith.constant 0.000000e+00 : f32 + %cst_0 = arith.constant dense<0.000000e+00> : vector<8x8xf32> + %0 = tensor.empty() : tensor<8x8xf32> + %6 = vector.transfer_read %0[%c0, %c0], %cst {in_bounds = [true, true]} : + tensor<8x8xf32>, vector<8x8xf32> + %13 = gml_st.for (%arg4) = (%c0) to (%c1) step (%c8) + outs (%arg5 = %6: vector<8x8xf32>) { + %19 = gml_st.tile [0, 0] [8, 8] [1, 1] : !gml_st.tile<8x8> + %20 = tensor.extract_slice %0[0, 0] [8, 8] [1, 1] + : tensor<8x8xf32> to tensor<8x8xf32> + %7 = vector.transfer_write %arg5, %20[%c0, %c0] {in_bounds = [true, true]} : + vector<8x8xf32>, tensor<8x8xf32> + %11 = linalg.fill ins(%cst : f32) outs(%7 : tensor<8x8xf32>) + -> tensor<8x8xf32> + %8 = vector.transfer_read %11[%c0, %c0], %cst {in_bounds = [true, true]} : + tensor<8x8xf32>, vector<8x8xf32> + gml_st.set_yield %8 into %arg5[%19] : vector<8x8xf32> + into vector<8x8xf32>[!gml_st.tile<8x8>] + } : vector<8x8xf32> + return %13 : vector<8x8xf32> +} + +// CHECK-LABEL: @collapse_empty_for_vector +// 
CHECK-NOT: gml_st.for +// CHECK: linalg.fill +// CHECK: %[[READ:.*]] = vector.transfer_read +// CHECK: return %[[READ]] : vector<8x8xf32> + +// ----- + +func.func @fold_tensor_cast_into_parallel( + %in: tensor<2xi32>, %out: tensor<2xi32>) -> tensor<2xi32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %cst = arith.constant 100500 : i32 + + %out_cast = tensor.cast %out : tensor<2xi32> to tensor + %result = gml_st.parallel (%i) = (%c0) to (%c2) step (%c1) + outs (%out_ = %out_cast: tensor) { + %tile = gml_st.tile [%i] [1] [1] : !gml_st.tile<1> + gml_st.set_yield %cst into %out_[%tile] + : i32 into tensor[!gml_st.tile<1>] + } : tensor + %result_cast = tensor.cast %result + : tensor to tensor<2xi32> + + func.return %result_cast : tensor<2xi32> +} +// CHECK-LABEL: @fold_tensor_cast_into_parallel +// CHECK: gml_st.parallel +// CHECK-NEXT: gml_st.tile +// CHECK-NEXT: gml_st.set_yield +// CHECK-SAME: i32 into tensor<2xi32> +// CHECK-NEXT: } : tensor<2xi32> +// CHECK-NEXT: return + +// ----- + +func.func @dim_of_parallel_loop( + %in: tensor<2x10xi32>, %out: tensor<2x10xi32>) -> index { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c10 = arith.constant 10 : index + %cst = arith.constant 100500 : i32 + + %result = gml_st.parallel (%i, %j) = (%c0, %c0) to (%c2, %c10) + step (%c1, %c1) outs (%out_ = %out: tensor<2x10xi32>) { + %tile = gml_st.tile [%i, %j] [1, 1] [1, 1] : !gml_st.tile<1x1> + gml_st.set_yield %cst into %out_[%tile] + : i32 into tensor<2x10xi32>[!gml_st.tile<1x1>] + } : tensor<2x10xi32> + + %dim = tensor.dim %result, %c1 : tensor<2x10xi32> + func.return %dim : index +} +// CHECK-LABEL: @dim_of_parallel_loop +// CHECK: %[[C10:.*]] = arith.constant 10 +// CHECK-NEXT: return %[[C10]] diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/collapse-shape.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/collapse-shape.mlir new file mode 100644 index 00000000000..e165ac57023 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/collapse-shape.mlir @@ -0,0 +1,288 @@ +// RUN: mlir-hlo-opt %s --split-input-file --gml-collapse-shape | FileCheck %s + +// RUN: mlir-hlo-opt %s --split-input-file \ +// RUN: --gml-collapse-shape="retain-trailing-dims=1" | \ +// RUN: FileCheck %s --check-prefix=CHECK-1 + +// RUN: mlir-hlo-opt %s --split-input-file \ +// RUN: --gml-collapse-shape="retain-trailing-dims=2" | \ +// RUN: FileCheck %s --check-prefix=CHECK-2 + +// RUN: mlir-hlo-opt %s --split-input-file \ +// RUN: --gml-collapse-shape="retain-trailing-dims=3" | \ +// RUN: FileCheck %s --check-prefix=CHECK-3 + +func.func @bcast(%arg0: tensor<2x4x2048xf32>) -> tensor<2x4x2048x4096xf32> { + %0 = tensor.empty() : tensor<2x4x2048x4096xf32> + %1 = linalg.broadcast + ins(%arg0 : tensor<2x4x2048xf32>) + outs(%0 : tensor<2x4x2048x4096xf32>) + dimensions = [3] + return %1 : tensor<2x4x2048x4096xf32> +} + +// CHECK: func.func @bcast(%[[ARG0:.*]]: tensor<2x4x2048xf32>) +// CHECK-NOT: collapse_shape +// CHECK-NOT: expand_shape + +// CHECK-1: func.func @bcast(%[[ARG0:.*]]: tensor<2x4x2048xf32>) +// CHECK-1: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG0]] [ +// CHECK-1-SAME: [0, 1, 2]] +// CHECK-1: %[[EMPTY:.*]] = tensor.empty() +// CHECK-1: %[[BROADCAST:.*]] = linalg.broadcast +// CHECK-1: ins(%[[COLLAPSED]] : tensor<16384xf32>) +// CHECK-1: outs(%[[EMPTY]] : tensor<16384x4096xf32>) +// CHECK-1: dimensions = [1] +// CHECK-1: %[[EXPANDED:.*]] = tensor.expand_shape 
%[[BROADCAST]] [ +// CHECK-1-SAME: [0, 1, 2], [3]] +// CHECK-1: return %[[EXPANDED]] + +// CHECK-2: func.func @bcast(%[[ARG0:.*]]: tensor<2x4x2048xf32>) +// CHECK-2: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG0]] [ +// CHECK-2-SAME: [0, 1], [2]] +// CHECK-2: %[[EMPTY:.*]] = tensor.empty() +// CHECK-2: %[[BROADCASTED:.*]] = linalg.broadcast +// CHECK-2-SAME: ins(%[[COLLAPSED]] : tensor<8x2048xf32>) +// CHECK-2-SAME: outs(%[[EMPTY]] : tensor<8x2048x4096xf32>) +// CHECK-2: dimensions = [2] +// CHECK-2: %[[EXPANDED:.*]] = tensor.expand_shape %[[BROADCASTED]] [ +// CHECK-2-SAME: [0, 1], [2], [3]] +// CHECK-2: return %[[EXPANDED]] + +// CHECK-3: func.func @bcast(%[[ARG0:.*]]: tensor<2x4x2048xf32>) +// CHECK-3-NOT: collapse_shape +// CHECK-3-NOT: expand_shape + +// ----- + +func.func @bcast_from_scalar() -> tensor<2x4x2048x4096xf32> { + %0 = tensor.empty() : tensor<2x4x2048x4096xf32> + %cst = arith.constant 0xFF800000 : f32 + %1 = tensor.empty() : tensor + %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor) -> tensor + %3 = linalg.broadcast + ins(%2 : tensor) + outs(%0 : tensor<2x4x2048x4096xf32>) + dimensions = [0, 1, 2, 3] + return %3 : tensor<2x4x2048x4096xf32> +} + +// CHECK: func.func @bcast_from_scalar() +// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<67108864xf32> +// CHECK: %[[BROADCAST:.*]] = linalg.broadcast +// CHECK: ins(%{{.*}} : tensor) +// CHECK: outs(%[[EMPTY]] : tensor<67108864xf32>) +// CHECK: dimensions = [0] +// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[BROADCAST]] [ +// CHECK-SAME: 0, 1, 2, 3]] +// CHECK: return %[[EXPANDED]] + +// CHECK-1: func.func @bcast_from_scalar() +// CHECK-1: %[[EMPTY:.*]] = tensor.empty() : tensor<16384x4096xf32> +// CHECK-1: %[[BROADCAST:.*]] = linalg.broadcast +// CHECK-1-SAME: ins(%{{.*}} : tensor) +// CHECK-1-SAME: outs(%[[EMPTY]] : tensor<16384x4096xf32>) +// CHECK-1-SAME: dimensions = [1, 0] +// CHECK-1: %[[EXPANDED:.*]] = tensor.expand_shape %[[BROADCAST]] [ +// CHECK-1-SAME: [0, 1, 2], [3]] +// CHECK-1: return %[[EXPANDED]] + +// CHECK-2: func.func @bcast_from_scalar() +// CHECK-2: %[[EMPTY:.*]] = tensor.empty() : tensor<8x2048x4096xf32> +// CHECK-2: %[[BROADCAST:.*]] = linalg.broadcast +// CHECK-2-SAME: ins(%{{.*}} : tensor +// CHECK-2-SAME: outs(%[[EMPTY]] : tensor<8x2048x4096xf32>) +// CHECK-2-SAME: dimensions = [1, 2, 0] +// CHECK-2: %[[EXPANDED:.*]] = tensor.expand_shape %[[BROADCAST]] [ +// CHECK-2-SAME: [0, 1], [2], [3]] +// CHECK-2: return %[[EXPANDED]] + +// CHECK-3: func.func @bcast_from_scalar() +// CHECK-3-NOT: collapse_shape +// CHECK-3-NOT: expand_shape + +// ----- + +func.func @reduction(%arg0: tensor<2x4x2048x4096xf32>) -> tensor<2x4x2048xf32> { + %cst = arith.constant 0xFF800000 : f32 + %0 = tensor.empty() : tensor<2x4x2048xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x4x2048xf32>) + -> tensor<2x4x2048xf32> + %2 = linalg.reduce { arith.maxf } + ins(%arg0 : tensor<2x4x2048x4096xf32>) + outs(%1 : tensor<2x4x2048xf32>) + dimensions = [3] + return %2 : tensor<2x4x2048xf32> +} + +// CHECK: func.func @reduction(%[[ARG0:.*]]: tensor<2x4x2048x4096xf32>) +// CHECK-NOT: collapse_shape +// CHECK-NOT: expand_shape + +// CHECK-1: func.func @reduction(%[[ARG0:.*]]: tensor<2x4x2048x4096xf32>) +// CHECK-1-DAG: %[[CST:.*]] = arith.constant 0xFF800000 : f32 +// CHECK-1: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG0]] [ +// CHECK-1-SAME: [0, 1, 2], [3]] +// CHECK-1: %[[EMPTY:.*]] = tensor.empty() +// CHECK-1: %[[FILL:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EMPTY]] : tensor<16384xf32>) +// CHECK-1: 
%[[REDUCED:.*]] = linalg.reduce { arith.maxf } +// CHECK-1-SAME: ins(%[[COLLAPSED]] : tensor<16384x4096xf32>) +// CHECK-1-SAME: outs(%[[FILL]] : tensor<16384xf32>) +// CHECK-1: %[[EXPANDED:.*]] = tensor.expand_shape %[[REDUCED]] [ +// CHECK-1-SAME: [0, 1, 2]] +// CHECK-1: return %[[EXPANDED]] + + +// ----- + +func.func @cwise(%arg0: tensor<2x4x2048x4096xf32>, + %arg1: tensor<2x4x2048x4096xf32>) -> tensor<2x4x2048x4096xf32> { + %0 = tensor.empty() : tensor<2x4x2048x4096xf32> + %1 = linalg.map { arith.subf } + ins(%arg0, %arg1 : tensor<2x4x2048x4096xf32>, tensor<2x4x2048x4096xf32>) + outs(%0 : tensor<2x4x2048x4096xf32>) + return %1 : tensor<2x4x2048x4096xf32> +} + +// CHECK: func.func @cwise(%[[ARG0:.*]]: tensor<2x4x2048x4096xf32>, %[[ARG1:.*]]: tensor<2x4x2048x4096xf32>) +// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG0]] [ +// CHECK-SAME: [0, 1, 2, 3]] +// CHECK: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[ARG1]] [ +// CHECK-SAME: [0, 1, 2, 3]] +// CHECK: %[[EMPTY:.*]] = tensor.empty() +// CHECK: %[[MAP:.*]] = linalg.map { arith.subf } +// CHECK: ins(%[[COLLAPSED]], %[[COLLAPSED_0]] : tensor<67108864xf32>, tensor<67108864xf32>) +// CHECK: outs(%[[EMPTY]] : tensor<67108864xf32>) +// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[MAP]] [ +// CHECK-SAME: [0, 1, 2, 3]] +// CHECK: return %[[EXPANDED]] + +// CHECK-1: func.func @cwise(%[[ARG0:.*]]: tensor<2x4x2048x4096xf32>, %[[ARG1:.*]]: tensor<2x4x2048x4096xf32>) +// CHECK-1: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG0]] [ +// CHECK-1-SAME: [0, 1, 2], [3]] +// CHECK-1: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[ARG1]] [ +// CHECK-1-SAME: [0, 1, 2], [3]] +// CHECK-1: %[[EMPTY:.*]] = tensor.empty() +// CHECK-1: %[[MAP:.*]] = linalg.map { arith.subf } +// CHECK-1-SAME: ins(%[[COLLAPSED]], %[[COLLAPSED_0]] : tensor<16384x4096xf32>, tensor<16384x4096xf32>) +// CHECK-1-SAME outs(%[[EMPTY]] : tensor<16384x4096xf32>) +// CHECK-1: %[[EXPANDED:.*]] = tensor.expand_shape %[[MAP]] [ +// CHECK-1-SAME: [0, 1, 2], [3]] +// CHECK-1: return %[[EXPANDED]] + +// CHECK-2: func.func @cwise(%[[ARG0:.*]]: tensor<2x4x2048x4096xf32>, %[[ARG1:.*]]: tensor<2x4x2048x4096xf32>) +// CHECK-2: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG0]] [ +// CHECK-2-SAME: [0, 1], [2], [3]] +// CHECK-2: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[ARG1]] [ +// CHECK-2-SAME: [0, 1], [2], [3]] +// CHECK-2: %[[EMPTY:.*]] = tensor.empty() +// CHECK-2: %[[MAP:.*]] = linalg.map { arith.subf } +// CHECK-2-SAME: ins(%[[COLLAPSED]], %[[COLLAPSED_0]] : tensor<8x2048x4096xf32>, tensor<8x2048x4096xf32>) +// CHECK-2-SAME outs(%[[EMPTY]] : tensor<8x2048x4096xf32>) +// CHECK-2: %[[EXPANDED:.*]] = tensor.expand_shape %[[MAP]] [ +// CHECK-2-SAME: [0, 1], [2], [3]] +// CHECK-2: return %[[EXPANDED]] + +// CHECK-3: func.func @cwise(%[[ARG0:.*]]: tensor<2x4x2048x4096xf32>, %[[ARG1:.*]]: tensor<2x4x2048x4096xf32>) +// CHECK-3-NOT: collapse_shape +// CHECK-3-NOT: expand_shape + +// ----- + +func.func @partial_softmax(%arg0: tensor<2x4x2048x4096xf32>) + -> tensor<2x4x2048x4096xf32> { + %cst = arith.constant 0xFF800000 : f32 + %0 = tensor.empty() : tensor<2x4x2048xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x4x2048xf32>) + -> tensor<2x4x2048xf32> + %2 = linalg.reduce { arith.maxf } + ins(%arg0 : tensor<2x4x2048x4096xf32>) + outs(%1 : tensor<2x4x2048xf32>) + dimensions = [3] + %3 = tensor.empty() : tensor<2x4x2048x4096xf32> + %4 = linalg.broadcast + ins(%2 : tensor<2x4x2048xf32>) + outs(%3 : tensor<2x4x2048x4096xf32>) + dimensions = [3] + %5 = linalg.map { arith.subf } 
+ ins(%arg0, %4 : tensor<2x4x2048x4096xf32>, tensor<2x4x2048x4096xf32>) + outs(%3 : tensor<2x4x2048x4096xf32>) + return %5 : tensor<2x4x2048x4096xf32> +} + +// CHECK-1: func.func @partial_softmax(%[[ARG0:.*]]: tensor<2x4x2048x4096xf32>) +// CHECK-1-DAG: %[[CST:.*]] = arith.constant 0xFF800000 : f32 +// CHECK-1: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG0]] [ +// CHECK-1-SAME: [0, 1, 2], [3]] +// CHECK-1: %[[EMPTY:.*]] = tensor.empty() +// CHECK-1: %[[FILL:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EMPTY]] : tensor<16384xf32>) +// CHECK-1: %[[REDUCE:.*]] = linalg.reduce { arith.maxf } +// CHECK-1-SAME: ins(%[[COLLAPSED]] : tensor<16384x4096xf32>) +// CHECK-1-SAME: outs(%[[FILL]] : tensor<16384xf32>) +// CHECK-1-SAME: dimensions = [1] +// CHECK-1: %[[EMPTY_0:.*]] = tensor.empty() +// CHECK-1: %[[BROADCAST:.*]] = linalg.broadcast +// CHECK-1-SAME: ins(%[[REDUCE]] : tensor<16384xf32>) +// CHECK-1-SAME: outs(%[[EMPTY_0]] : tensor<16384x4096xf32>) +// CHECK-1-SAME: dimensions = [1] +// CHECK-1: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[ARG0]] [ +// CHECK-1-SAME: [0, 1, 2], [3]] +// CHECK-1: %[[EMPTY_1:.*]] = tensor.empty() +// CHECK-1: %[[MAP:.*]] = linalg.map { arith.subf } +// CHECK-1-SAME: ins(%[[COLLAPSED_0]], %[[BROADCAST]] : tensor<16384x4096xf32>, tensor<16384x4096xf32>) +// CHECK-1-SAME: outs(%[[EMPTY_1]] : tensor<16384x4096xf32>) +// CHECK-1: %[[EXPANDED:.*]] = tensor.expand_shape %[[MAP]] [ +// CHECK-1-SAME: [0, 1, 2], [3]] +// CHECK-1: return %[[EXPANDED]] + +// CHECK-2: func.func @partial_softmax(%[[ARG0:.*]]: tensor<2x4x2048x4096xf32>) +// CHECK-2-DAG: %[[CST:.*]] = arith.constant 0xFF800000 : f32 +// CHECK-2: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[ARG0]] [ +// CHECK-2-SAME: [0, 1], [2], [3]] +// CHECK-2: %[[EMPTY:.*]] = tensor.empty() +// CHECK-2: %[[FILL:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EMPTY]] : tensor<8x2048xf32>) +// CHECK-2: %[[REDUCE:.*]] = linalg.reduce { arith.maxf } +// CHECK-2-SAME: ins(%[[COLLAPSED]] : tensor<8x2048x4096xf32>) +// CHECK-2-SAME: outs(%[[FILL]] : tensor<8x2048xf32>) +// CHECK-2-SAME: dimensions = [2] +// CHECK-2: %[[EMPTY_0:.*]] = tensor.empty() +// CHECK-2: %[[BROADCAST:.*]] = linalg.broadcast +// CHECK-2-SAME: ins(%[[REDUCE]] : tensor<8x2048xf32>) +// CHECK-2-SAME: outs(%[[EMPTY_0]] : tensor<8x2048x4096xf32>) +// CHECK-2-SAME: dimensions = [2] +// CHECK-2: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[ARG0]] [ +// CHECK-2-SAME: [0, 1], [2], [3]] +// CHECK-2: %[[EMPTY_1:.*]] = tensor.empty() +// CHECK-2: %[[MAP:.*]] = linalg.map { arith.subf } +// CHECK-2-SAME: ins(%[[COLLAPSED_0]], %[[BROADCAST]] : tensor<8x2048x4096xf32>, tensor<8x2048x4096xf32>) +// CHECK-2-SAME: outs(%[[EMPTY_1]] : tensor<8x2048x4096xf32>) +// CHECK-2: %[[EXPANDED:.*]] = tensor.expand_shape %[[MAP]] [ +// CHECK-2-SAME: [0, 1], [2], [3]] +// CHECK-2: return %[[EXPANDED]] + +// CHECK-3: func.func @partial_softmax(%[[ARG0:.*]]: tensor<2x4x2048x4096xf32>) +// CHECK-3-NOT: collapse_shape +// CHECK-3-NOT: expand_shape + +// ----- + + +func.func @collapse_shape_of_cwise(%arg0: tensor<2x4xf32>) -> tensor<8xf32> { + %0 = tensor.empty() : tensor<2x4xf32> + %1 = linalg.map { arith.negf } + ins(%arg0 : tensor<2x4xf32>) + outs(%0 : tensor<2x4xf32>) + %3 = tensor.collapse_shape %1 [[0, 1]] : tensor<2x4xf32> into tensor<8xf32> + return %3 : tensor<8xf32> +} + +// CHECK: func.func @collapse_shape_of_cwise +// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape {{.*}} [ +// CHECK-SAME: [0, 1]] : tensor<2x4xf32> into tensor<8xf32> +// CHECK: %[[MAPPED:.*]] = linalg.map 
+// CHECK: ins(%[[COLLAPSED]] : tensor<8xf32>) + +// CHECK-1: func.func @collapse_shape_of_cwise +// CHECK-2: func.func @collapse_shape_of_cwise +// CHECK-3: func.func @collapse_shape_of_cwise + diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/collapse_materialize_ops.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/collapse_materialize_ops.mlir deleted file mode 100644 index 4bc3575ab34..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/collapse_materialize_ops.mlir +++ /dev/null @@ -1,26 +0,0 @@ -// RUN: mlir-hlo-opt %s --gml-collapse-materialize-ops | \ -// RUN: FileCheck %s - -func.func @compose_tiles(%arg: tensor, %i: index, %j: index, %k: index, - %n: index, %a: index, %b: index) -> tensor<4x?xf32> { - %1 = gml_st.tile [%i, %j] [4, 128] [2, %a] - : !gml_st.tile<4x128> - %4 = gml_st.materialize %arg[%1] : tensor[!gml_st.tile<4x128>] to tensor<4x128xf32> - %3 = gml_st.tile [0, %k] [4, %n] [1, %b] - : !gml_st.tile<4x?> - %5 = gml_st.materialize %4[%3] : tensor<4x128xf32>[!gml_st.tile<4x?>] to tensor<4x?xf32> - return %5 : tensor<4x?xf32> -} -// CHECK-LABEL: @compose_tiles -// CHECK-SAME: %[[ARG:[a-z0-9]+]]: tensor, %[[I:[a-z0-9]+]]: index, -// CHECK-SAME: %[[J:[a-z0-9]+]]: index, %[[K:[a-z0-9]+]]: index, -// CHECK-SAME: %[[N:[a-z0-9]+]]: index, %[[A:[a-z0-9]+]]: index, -// CHECK-SAME: %[[B:[a-z0-9]+]]: index) - -// CHECK-DAG: %[[AK:.*]] = arith.muli %[[A]], %[[K]] -// CHECK-DAG: %[[J_PLUS_AK:.*]] = arith.addi %[[J]], %[[AK]] -// CHECK-DAG: %[[AB:.*]] = arith.muli %[[A]], %[[B]] -// CHECK: %[[TILE:.*]] = gml_st.tile [%[[I]], %[[J_PLUS_AK]]] [4, %[[N]]] -// CHECK-SAME: [2, %[[AB]]] : !gml_st.tile<4x?> -// CHECK-NEXT: %[[RES:.*]] = gml_st.materialize %[[ARG]][%[[TILE]]] -// CHECK-SAME: : tensor[!gml_st.tile<4x?>] diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/compose_extract_insert_slice.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/compose_extract_insert_slice.mlir new file mode 100644 index 00000000000..eb58d45d08f --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/compose_extract_insert_slice.mlir @@ -0,0 +1,21 @@ +// RUN: mlir-hlo-opt %s --gml-compose-extract-insert-slice | FileCheck %s + +func.func @compose_tiles(%arg: tensor, %i: index, %j: index, %k: index, + %n: index, %a: index, %b: index) -> tensor<4x?xf32> { + %4 = tensor.extract_slice %arg[%i, %j] [4, 128] [2, %a] + : tensor to tensor<4x128xf32> + %5 = tensor.extract_slice %4[0, %k] [4, %n] [1, %b] + : tensor<4x128xf32> to tensor<4x?xf32> + return %5 : tensor<4x?xf32> +} +// CHECK-LABEL: @compose_tiles +// CHECK-SAME: %[[ARG:[a-z0-9]+]]: tensor, %[[I:[a-z0-9]+]]: index, +// CHECK-SAME: %[[J:[a-z0-9]+]]: index, %[[K:[a-z0-9]+]]: index, +// CHECK-SAME: %[[N:[a-z0-9]+]]: index, %[[A:[a-z0-9]+]]: index, +// CHECK-SAME: %[[B:[a-z0-9]+]]: index) + +// CHECK-DAG: %[[J_PLUS_AK:.*]] = affine.apply +// CHECK-DAG: %[[AB:.*]] = affine.apply +// CHECK-NEXT: %[[RES:.*]] = tensor.extract_slice %[[ARG]] +// CHECK-SAME: [%[[I]], %[[J_PLUS_AK]]] [4, %[[N]]] [2, %[[AB]]] +// CHECK-SAME: : tensor diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/map_bcast_map.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/map_bcast_map.mlir new file mode 100644 index 00000000000..dc5210abd97 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/map_bcast_map.mlir @@ -0,0 +1,35 @@ +// RUN: mlir-hlo-opt %s --gml-st-cpu-tiling-pipeline | FileCheck %s + +func.func 
@map_bcast_map(%arg0: tensor, %arg1: tensor, + %init0: tensor, + %init1: tensor) -> tensor { + %abs = linalg.map { math.absf } + ins(%arg0:tensor) + outs(%init0:tensor) + + %bcast = linalg.broadcast + ins(%abs : tensor) + outs(%init1 : tensor) + dimensions = [1, 2] + + %mapped = linalg.map { arith.addf } + ins(%bcast, %arg1 : tensor, tensor) + outs(%init1:tensor) + func.return %mapped : tensor +} + +// CHECK-LABEL: func.func @map_bcast_map + +// CHECK: gml_st.parallel +// CHECK: math.absf %{{.*}} : f32 +// CHECK: vector.broadcast %{{.*}} : vector<1xf32> to vector<1x8x1xf32> +// CHECK: vector.transpose %{{.*}}, [2, 0, 1] : vector<1x8x1xf32> to vector<1x1x8xf32> +// CHECK: arith.addf %{{.*}} : vector<8xf32> +// CHECK: gml_st.set_yield + +// CHECK: gml_st.parallel +// CHECK: gml_st.parallel +// CHECK: math.absf %{{.*}} : f32 +// CHECK: arith.addf %{{.*}} : f32 +// CHECK: gml_st.set_yield +// CHECK: gml_st.set_yield diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/map_matmul.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/map_matmul.mlir new file mode 100644 index 00000000000..4be58e566fd --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/map_matmul.mlir @@ -0,0 +1,70 @@ +// RUN: mlir-hlo-opt %s --split-input-file --gml-st-cpu-tiling-pipeline | FileCheck %s + +func.func @map_matmul(%arg0: tensor, + %arg1: tensor, %arg2: tensor) -> tensor { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %dim0 = tensor.dim %arg0, %c0 : tensor + %dim1 = tensor.dim %arg1, %c1 : tensor + %init = tensor.empty(%dim0, %dim1) : tensor + %cst = arith.constant 0.000000e+00 : f32 + %filled = linalg.fill ins(%cst : f32) + outs(%init : tensor) -> tensor + %4 = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) + outs(%filled : tensor) -> tensor + %5 = linalg.matmul ins(%arg0, %arg2 : tensor, tensor) + outs(%filled : tensor) -> tensor + %6 = linalg.map { math.absf } + ins(%5 : tensor) + outs(%init : tensor) + + %result = linalg.map { arith.addf } + ins(%4, %6 : tensor, tensor) + outs(%init : tensor) + return %result : tensor +} + +// CHECK-LABEL: @map_matmul + +// CHECK: gml_st.parallel +// CHECK: scf.for +// CHECK-COUNT-2: vector.transfer_read +// CHECK: vector.transpose +// CHECK-COUNT-4: vector.outerproduct +// CHECK: scf.yield +// CHECK: scf.for +// CHECK: linalg.matmul +// CHECK: scf.yield +// CHECK: scf.for +// CHECK-COUNT-2: vector.transfer_read +// CHECK: vector.transpose +// CHECK-COUNT-4: vector.outerproduct +// CHECK: scf.yield +// CHECK: scf.for +// CHECK: linalg.matmul +// CHECK: scf.yield +// CHECK: math.absf %{{.*}} : vector<4x4xf32> +// CHECK: arith.addf %{{.*}} : vector<4x4xf32> +// CHECK: gml_st.set_yield + +// CHECK: gml_st.parallel +// CHECK: scf.for +// CHECK: linalg.matmul +// CHECK: scf.yield +// CHECK: scf.for +// CHECK: linalg.matmul +// CHECK: scf.yield +// CHECK: linalg.map +// CHECK: linalg.map +// CHECK: gml_st.set_yield + +// CHECK: gml_st.parallel +// CHECK: scf.for +// CHECK: linalg.matmul +// CHECK: scf.yield +// CHECK: scf.for +// CHECK: linalg.matmul +// CHECK: scf.yield +// CHECK: linalg.map +// CHECK: linalg.map +// CHECK: gml_st.set_yield diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/map_reduce_map.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/map_reduce_map.mlir new file mode 100644 index 00000000000..7cf0fe560c3 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/map_reduce_map.mlir @@ -0,0 
+1,24 @@ +// RUN: mlir-hlo-opt %s --split-input-file --gml-st-cpu-tiling-pipeline \ +// RUN: | FileCheck %s + +func.func @reduce_map_fuse_map(%arg0: tensor<10x100xf32>, + %arg1: tensor<10x100xf32>, %output: tensor<10xf32>) -> tensor<10xf32> { + %map_init = tensor.empty() : tensor<10x100xf32> + %reduce_init = tensor.empty() : tensor<10xf32> + %mapped = linalg.map { arith.addf } + ins(%arg0, %arg1 : tensor<10x100xf32>, tensor<10x100xf32>) + outs(%map_init : tensor<10x100xf32>) + + %reduce = linalg.reduce { arith.addf } + ins(%mapped: tensor<10x100xf32>) + outs(%reduce_init: tensor<10xf32>) + dimensions = [1] + + %res = linalg.map { math.absf } + ins(%reduce: tensor<10xf32>) + outs(%output : tensor<10xf32>) + return %res : tensor<10xf32> +} +// CHECK-LABEL: @reduce_map_fuse_map + +// TODO(pifon): The lowering is severely broken. Fixing it in a follow-up. diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/matmul.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/matmul.mlir new file mode 100644 index 00000000000..c1d1558f06b --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/matmul.mlir @@ -0,0 +1,206 @@ +// RUN: mlir-hlo-opt %s --split-input-file --gml-st-cpu-tiling-pipeline | FileCheck %s +// RUN: mlir-hlo-opt %s --gml-st-cpu-tiling-pipeline="lower-to-mmt4d=true" | FileCheck %s --check-prefixes=PACKED + +func.func @matmul_static(%lhs: tensor<128x16xf32>, %rhs: tensor<16x64xf32>, + %output: tensor<128x64xf32>) -> tensor<128x64xf32> { + %2 = linalg.matmul ins(%lhs, %rhs : tensor<128x16xf32>, tensor<16x64xf32>) + outs(%output : tensor<128x64xf32>) -> tensor<128x64xf32> + return %2 : tensor<128x64xf32> +} + +// CHECK-LABEL: @matmul_static + +// CHECK: gml_st.parallel +// CHECK: vector.transfer_read +// CHECK-NEXT: scf.for +// CHECK-COUNT-2: vector.transfer_read +// CHECK: vector.transpose +// CHECK-COUNT-4: vector.outerproduct +// CHECK: scf.yield {{.*}} : vector<4x4xf32> +// CHECK: vector.transfer_write +// CHECK: gml_st.set_yield + +// PACKED-LABEL: @matmul_static + +// PACKED: tensor.empty() : tensor<16x16x8x1xf32> +// PACKED-COUNT-2: scf.for +// PACKED: vector.transfer_read +// PACKED: vector.transfer_write +// PACKED: scf.yield %{{.*}} : tensor<16x16x8x1xf32> +// PACKED: scf.yield %{{.*}} : tensor<16x16x8x1xf32> + +// PACKED: tensor.empty() : tensor<8x16x8x1xf32> +// PACKED-COUNT-2: scf.for +// PACKED: vector.broadcast +// PACKED: vector.transpose +// PACKED: scf.yield %{{.*}} : tensor<8x16x8x1xf32> +// PACKED: scf.yield %{{.*}} : tensor<8x16x8x1xf32> + +// PACKED: tensor.empty() : tensor<16x8x8x8xf32> +// PACKED-COUNT-2: scf.for +// PACKED: vector.transfer_read +// PACKED: vector.transfer_write +// PACKED: scf.yield +// PACKED: scf.yield + +// PACKED-COUNT-2: scf.for +// PACKED: vector.broadcast +// PACKED: scf.for +// PACKED: vector.transfer_read +// PACKED: vector.transfer_read +// PACKED: vector.transpose +// PACKED: vector.transpose +// PACKED: vector.outerproduct +// PACKED: vector.broadcast +// PACKED: vector.broadcast +// PACKED: scf.yield +// PACKED: scf.yield +// PACKED: scf.yield + +// PACKED: tensor.empty() : tensor<128x64xf32> +// PACKED-COUNT-2: scf.for +// PACKED: vector.transfer_read +// PACKED: vector.transfer_write +// PACKED: scf.yield %{{.*}} : tensor<128x64xf32> +// PACKED: scf.yield %{{.*}} : tensor<128x64xf32> + + + +// ----- + +func.func @matmul(%lhs: tensor, + %rhs: tensor) -> tensor { + %c0 = arith.constant 0 : index + %0 = tensor.dim %lhs, %c0 : tensor + %c1 = arith.constant 1 : 
index + %1 = tensor.dim %rhs, %c1 : tensor + %2 = tensor.empty(%0, %1) : tensor + %cst = arith.constant 0.000000e+00 : f32 + %3 = linalg.fill ins(%cst : f32) + outs(%2 : tensor) -> tensor + %4 = linalg.matmul ins(%lhs, %rhs : tensor, tensor) + outs(%3 : tensor) -> tensor + return %4 : tensor +} +// CHECK-LABEL: @matmul + +// CHECK: gml_st.parallel +// CHECK: scf.for +// CHECK-COUNT-2: vector.transfer_read +// CHECK: vector.transpose +// CHECK-COUNT-4: vector.outerproduct +// CHECK-NEXT: scf.yield %{{.*}} : vector<4x4xf32> +// CHECK: vector.transfer_write + +// CHECK-NEXT: scf.for +// CHECK: linalg.matmul {{.*}} -> tensor<4x4xf32> +// CHECK: scf.yield {{.*}} : tensor +// CHECK: gml_st.set_yield + +// CHECK: gml_st.parallel +// CHECK: linalg.fill +// CHECK: scf.for +// CHECK: linalg.matmul {{.*}} -> tensor<4x?xf32> +// CHECK: scf.yield {{.*}} : tensor<4x?xf32> +// CHECK: gml_st.set_yield + +// CHECK: gml_st.parallel +// CHECK: linalg.fill +// CHECK: scf.for +// CHECK: linalg.matmul +// CHECK: scf.yield {{.*}} : tensor +// CHECK: gml_st.set_yield + +// ----- + +func.func @matmul_narrow_static(%lhs: tensor<2x16xf32>, %rhs: tensor<16x64xf32>, + %output: tensor<2x64xf32>) -> tensor<2x64xf32> { + %2 = linalg.matmul ins(%lhs, %rhs : tensor<2x16xf32>, tensor<16x64xf32>) + outs(%output : tensor<2x64xf32>) -> tensor<2x64xf32> + return %2 : tensor<2x64xf32> +} +// CHECK-LABEL: @matmul_narrow_static + +// CHECK: gml_st.parallel +// CHECK: scf.for +// CHECK-COUNT-2: vector.transfer_read +// CHECK: vector.transpose +// CHECK-COUNT-4: vector.outerproduct +// CHECK: scf.yield {{.*}} : vector<2x4xf32> +// CHECK: vector.transfer_write +// CHECK: gml_st.set_yield + +// PACKED-LABEL: @matmul_narrow_static + +// PACKED: tensor.empty() : tensor<1x16x2x1xf32> +// PACKED: scf.for +// PACKED: vector.transfer_read +// PACKED: vector.transfer_write +// PACKED: scf.yield %{{.*}} : tensor<1x16x2x1xf32> +// PACKED: } + +// PACKED: tensor.empty() : tensor<8x16x8x1xf32> +// PACKED-COUNT: scf.for +// PACKED: vector.broadcast +// PACKED: vector.transpose +// PACKED: scf.yield %{{.*}} : tensor<8x16x8x1xf32> +// PACKED: scf.yield %{{.*}} : tensor<8x16x8x1xf32> + +// PACKED: tensor.empty() : tensor<1x8x2x8xf32> +// PACKED: scf.for +// PACKED: vector.transfer_read +// PACKED: vector.transfer_write +// PACKED: scf.yield %{{.*}} : tensor<1x8x2x8xf32> +// PACKED: scf.for +// PACKED: vector.broadcast +// PACKED: scf.for +// PACKED: vector.transpose +// PACKED: vector.transpose +// PACKED: vector.outerproduct +// PACKED: vector.broadcast +// PACKED: vector.broadcast +// PACKED: scf.yield %{{.*}} : vector<1x1x2x8xf32> +// PACKED: scf.yield + +// PACKED: tensor.empty() : tensor<2x64xf32> +// PACKED: scf.for +// PACKED: vector.transfer_read +// PACKED: vector.transfer_write +// PACKED: scf.yield %{{.*}} : tensor<2x64xf32> + +// ----- + +func.func @matmul_small_static_peeling(%lhs: tensor<2x4xf32>, %arg1: tensor<4x6xf32>, + %output: tensor<2x6xf32>) -> tensor<2x6xf32> { + %2 = linalg.matmul ins(%lhs, %arg1 : tensor<2x4xf32>, tensor<4x6xf32>) + outs(%output : tensor<2x6xf32>) -> tensor<2x6xf32> + return %2 : tensor<2x6xf32> +} +// CHECK-LABEL: @matmul_small_static_peeling + +// CHECK-NOT: gml_st.parallel +// CHECK-NOT: scf.for +// CHECK: vector.transpose +// CHECK-COUNT-4: vector.outerproduct +// CHECK: vector.transpose +// CHECK-COUNT-4: vector.outerproduct + +// ----- + +func.func @matvec_static(%lhs: tensor<1x16xf32>, %arg1: tensor<16x64xf32>, + %output: tensor<1x64xf32>) -> tensor<1x64xf32> { + %2 = linalg.matmul ins(%lhs, %arg1 : 
tensor<1x16xf32>, tensor<16x64xf32>) + outs(%output : tensor<1x64xf32>) -> tensor<1x64xf32> + return %2 : tensor<1x64xf32> +} +// CHECK-LABEL: @matvec_static + +// CHECK: gml_st.parallel +// CHECK: vector.transfer_read +// CHECK-NEXT: vector.broadcast +// CHECK-NEXT: scf.for +// CHECK-COUNT-2: vector.transfer_read +// CHECK-COUNT-4: vector.outerproduct +// CHECK: scf.yield {{.*}} : vector<1x4xf32> +// CHECK: vector.transfer_write +// CHECK: gml_st.set_yield diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/reduce_1d.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/reduce_1d.mlir new file mode 100644 index 00000000000..5a7f5f9c75a --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/reduce_1d.mlir @@ -0,0 +1,52 @@ +// RUN: mlir-hlo-opt %s --split-input-file --gml-st-cpu-tiling-pipeline \ +// RUN: | FileCheck %s + +func.func @reduce_1d_static(%arg0: tensor<100xf32>) -> tensor { + %1 = tensor.empty() : tensor + %cst = arith.constant 0.0 : f32 + %init = linalg.fill ins(%cst : f32) outs(%1 : tensor) -> tensor + %res = linalg.reduce { arith.addf } + ins(%arg0: tensor<100xf32>) outs(%init: tensor) dimensions = [0] + return %res : tensor +} +// CHECK-LABEL: @reduce_1d_static + +// CHECK: arith.constant dense<0.000000e+00> : vector<8xf32> +// CHECK: tensor.empty() : tensor + +// CHECK: scf.for +// CHECK: vector.multi_reduction +// CHECK-SAME: : vector<4x8xf32> to vector<8xf32> +// CHECK: scf.yield %{{.*}} : vector<8xf32> + +// CHECK: vector.multi_reduction +// CHECK-SAME: : vector<8xf32> to f32 +// CHECK: vector.multi_reduction +// CHECK-SAME: : vector<4xf32> to f32 + +// ----- + +func.func @reduce_1d_dynamic(%arg0: tensor) -> tensor { + %1 = tensor.empty() : tensor + %cst = arith.constant 0.0 : f32 + %init = linalg.fill ins(%cst : f32) outs(%1 : tensor) -> tensor + %res = linalg.reduce { arith.addf } + ins(%arg0: tensor) outs(%init: tensor) dimensions = [0] + return %res : tensor +} +// CHECK-LABEL: func @reduce_1d_dynamic + +// CHECK: arith.constant dense<0.000000e+00> : vector<8xf32> +// CHECK: tensor.empty() : tensor + +// CHECK: scf.for +// CHECK: vector.multi_reduction +// CHECK-SAME: : vector<4x8xf32> to vector<8xf32> +// CHECK: scf.yield %{{.*}} : vector<8xf32> + +// CHECK: vector.multi_reduction +// CHECK-SAME: : vector<8xf32> to f32 + +// CHECK: scf.for +// CHECK: linalg.reduce +// CHECK: scf.yield %{{.*}} : tensor diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/reduce_2d.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/reduce_2d.mlir new file mode 100644 index 00000000000..c257d0ffdb2 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/reduce_2d.mlir @@ -0,0 +1,55 @@ +// RUN: mlir-hlo-opt %s --split-input-file --gml-st-cpu-tiling-pipeline \ +// RUN: | FileCheck %s + +func.func @reduce_static(%input: tensor<100x10xf32>, + %output: tensor<10xf32>) -> tensor<10xf32> { + %res = linalg.reduce { arith.addf } + ins(%input: tensor<100x10xf32>) + outs(%output: tensor<10xf32>) + dimensions = [0] + return %res : tensor<10xf32> +} +// CHECK-LABEL: @reduce_static + +// CHECK: gml_st.parallel +// CHECK: scf.for +// CHECK: vector.multi_reduction +// CHECK-SAME: : vector<4x4xf32> to vector<4xf32> +// CHECK-NEXT: scf.yield %{{.*}} : vector<4xf32> +// CHECK: gml_st.set_yield + +// ----- + +func.func @reduce_dynamic(%input: tensor, + %output: tensor) -> tensor { + %c0 = arith.constant 0 : index + %0 = tensor.dim %output, %c0 : tensor + %1 = 
tensor.empty(%0) : tensor + %cst = arith.constant 0.000000e+00 : f32 + %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor) -> tensor + %res = linalg.reduce { arith.mulf } + ins(%input: tensor) + outs(%2: tensor) + dimensions = [1] + return %res : tensor +} +// CHECK-LABEL: @reduce_dynamic + +// CHECK: gml_st.parallel +// CHECK: scf.for +// CHECK: vector.multi_reduction +// CHECK-SAME: : vector<4x4xf32> to vector<4xf32> +// CHECK-NEXT: scf.yield %{{.*}} : vector<4xf32> + +// CHECK: scf.for +// CHECK: vector.multi_reduction +// CHECK-SAME: : vector<4x1xf32> to vector<4xf32> +// CHECK-NEXT: scf.yield %{{.*}} : vector<4xf32> +// CHECK: gml_st.set_yield + +// CHECK: gml_st.parallel +// CHECK: linalg.fill +// CHECK: scf.for +// CHECK: linalg.reduce +// CHECK: scf.yield +// CHECK: gml_st.set_yield diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/reverse.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/reverse.mlir new file mode 100644 index 00000000000..e65c630345c --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/reverse.mlir @@ -0,0 +1,40 @@ +// RUN: mlir-hlo-opt %s --split-input-file --gml-st-cpu-tiling-pipeline \ +// RUN: | FileCheck %s + +func.func @reverse_static_perfect_tiles( + %input: tensor<64xf32>, %init: tensor<64xf32>) -> tensor<64xf32> { + %res = thlo.reverse + ins(%input: tensor<64xf32>) + outs(%init: tensor<64xf32>) + reverse_dimensions = [0] + func.return %res : tensor<64xf32> +} + +// CHECK-LABEL: @reverse_static_perfect_tiles + +// CHECK: gml_st.parallel +// CHECK: vector.shuffle +// CHECK: gml_st.set_yield + +// ----- + +func.func @reverse_dynamic( + %input: tensor, %init: tensor) -> tensor { + %res = thlo.reverse + ins(%input: tensor) + outs(%init: tensor) + reverse_dimensions = [0, 1] + func.return %res : tensor +} + +// CHECK-LABEL: @reverse_dynamic + +// CHECK: gml_st.parallel +// CHECK: vector.shuffle +// CHECK: gml_st.set_yield + +// CHECK: gml_st.parallel +// CHECK: gml_st.parallel +// CHECK: tensor.extract_slice +// CHECK: gml_st.set_yield +// CHECK: gml_st.set_yield diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/transform_scatter_for_cpu.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/scatter.mlir similarity index 77% rename from tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/transform_scatter_for_cpu.mlir rename to tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/scatter.mlir index f242ee41a18..33874e4e88a 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/transform_scatter_for_cpu.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/scatter.mlir @@ -1,7 +1,5 @@ // RUN: mlir-hlo-opt %s -xla-cpu-transform-scatter | FileCheck %s -#id_map = affine_map<(d0, d1) -> (d0, d1)> - func.func @scatter_small_vector_dim(%indices: tensor, %updates: tensor, %init: tensor) -> tensor { %result = thlo.scatter @@ -15,5 +13,6 @@ func.func @scatter_small_vector_dim(%indices: tensor, } // CHECK-LABEL: @scatter_small_vector_dim -// CHECK: gml_st.for -// CHECK: thlo.scatter ins(%{{.*}} : tensor<1x2xindex>, %{{.*}} : tensor<1x?x?xf32>) +// CHECK: scf.for +// CHECK: thlo.scatter +// CHECK-SAME: ins(%{{.*}} : tensor<1x2xindex>, %{{.*}} : tensor<1x?x?xf32>) diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/sort.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/sort.mlir new file mode 100644 index 00000000000..fcf9ccc7467 --- /dev/null +++ 
b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/sort.mlir @@ -0,0 +1,22 @@ +// RUN: mlir-hlo-opt %s --gml-st-cpu-tiling-pipeline | FileCheck %s + +func.func @sort(%input1: tensor<64x8x4xf32>, %input2: tensor<64x8x4xf32>, + %init1: tensor<64x8x4xf32>, %init2: tensor<64x8x4xf32>) { + thlo.sort + ins(%input1: tensor<64x8x4xf32>, %input2: tensor<64x8x4xf32>) + outs(%init1: tensor<64x8x4xf32>, %init2: tensor<64x8x4xf32>) + dimension = 1 + is_stable = true + (%e11: f32, %e12: f32, %e21: f32, %e22: f32) { + %gt = arith.cmpf ogt, %e11, %e12: f32 + thlo.yield %gt : i1 + } + func.return +} +// CHECK-LABEL: func.func @sort( + +// CHECK: gml_st.parallel +// CHECK: thlo.sort +// CHECK-SAME: ins(%{{.*}} : tensor<1x8x1xf32>, %{{.*}} : tensor<1x8x1xf32>) +// CHECK-SAME: dimension = 1 +// CHECK: gml_st.set_yield diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/transpose.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/transpose.mlir new file mode 100644 index 00000000000..231987e6b6f --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/cpu_tiling/transpose.mlir @@ -0,0 +1,41 @@ +// RUN: mlir-hlo-opt %s --gml-st-cpu-tiling-pipeline \ +// RUN: | FileCheck %s + +func.func @transpose(%input: tensor<16x32x64xf32>, + %init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> { + %transpose = linalg.transpose + ins(%input:tensor<16x32x64xf32>) + outs(%init:tensor<32x64x16xf32>) + permutation = [1, 2, 0] + func.return %transpose : tensor<32x64x16xf32> +} +// CHECK-LABEL: func.func @transpose + +// CHECK: gml_st.parallel +// CHECK: vector.transpose +// CHECK-SAME: [1, 2, 0] : vector<8x1x8xf32> to vector<1x8x8xf32> +// CHECK: gml_st.set_yield + +// ----- + +func.func @peel_transpose(%input: tensor<16x32x65xf32>, + %init: tensor<32x65x16xf32>) -> tensor<32x65x16xf32> { + %transpose = linalg.transpose + ins(%input:tensor<16x32x65xf32>) + outs(%init:tensor<32x65x16xf32>) + permutation = [1, 2, 0] + func.return %transpose : tensor<32x65x16xf32> +} + +// CHECK-LABEL: @peel_transpose + +// CHECK: gml_st.parallel +// CHECK: vector.transpose +// CHECK-SAME: [1, 2, 0] : vector<8x1x8xf32> to vector<1x8x8xf32> +// CHECK: gml_st.set_yield + +// CHECK: gml_st.parallel +// CHECK: gml_st.parallel +// CHECK: tensor.extract +// CHECK: gml_st.set_yield +// CHECK: gml_st.set_yield diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/fusion.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/fusion.mlir index 81bd095d343..148421863ba 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/fusion.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/fusion.mlir @@ -18,9 +18,8 @@ func.func @dynamic_broadcast_in_dim(%arg : tensor, broadcast_dimensions = [0, 2] { op_label = "producer" } - %tile = gml_st.tile [%i, %j, %k] [3, 4, %arg_dim] [1, 1, 1] : !gml_st.tile<3x4x?> - %bcast_sub = gml_st.materialize %bcast[%tile] - : tensor[!gml_st.tile<3x4x?>] to tensor<3x4x?xf32> + %bcast_sub = tensor.extract_slice %bcast[%i, %j, %k] [3, 4, %arg_dim] [1, 1, 1] + : tensor to tensor<3x4x?xf32> func.return { op_label = "consumer" } %bcast_sub : tensor<3x4x?xf32> } // CHECK-LABEL: @dynamic_broadcast_in_dim @@ -35,8 +34,6 @@ func.func @dynamic_broadcast_in_dim(%arg : tensor, // CHECK: %[[EXTRACT_1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] // CHECK: %[[EXTRACT_2:.*]] = tensor.extract %[[SHAPE]][%[[C2]]] // CHECK: %[[INIT:.*]] = tensor.empty(%[[EXTRACT_0]], %[[EXTRACT_1]], %[[EXTRACT_2]]) -// CHECK: %[[INIT_TILE:.*]] = 
gml_st.tile -// CHECK-SAME: [%[[I]], %[[J]], %[[K]]] [3, 4, %[[ARG_DIM]]] // CHECK: %[[DIM_0:.*]] = tensor.dim %[[ARG]], %[[C0]] // CHECK: %[[DIM_1:.*]] = tensor.dim %[[ARG]], %[[C1]] // CHECK: %[[CMPI_0:.*]] = arith.cmpi ne, %[[DIM_0]], %[[EXTRACT_0]] @@ -45,12 +42,12 @@ func.func @dynamic_broadcast_in_dim(%arg : tensor, // CHECK: %[[SELECT_0:.*]] = arith.select %[[CMPI_1]], %[[C0]], %[[K]] // CHECK: %[[SELECT_1:.*]] = arith.select %[[CMPI_0]], %[[C1]], %[[C3]] // CHECK: %[[SELECT_2:.*]] = arith.select %[[CMPI_1]], %[[C1]], %[[ARG_DIM]] -// CHECK: %[[ARG_TILE:.*]] = gml_st.tile +// CHECK: %[[INIT_SUB:.*]] = tensor.extract_slice %[[INIT]] +// CHECK-SAME: [%[[I]], %[[J]], %[[K]]] [3, 4, %[[ARG_DIM]]] +// CHECK: %[[ARG_SUB:.*]] = tensor.extract_slice %[[ARG]] // CHECK-SAME: [%[[SELECT]], %[[SELECT_0]]] // CHECK-SAME: [%[[SELECT_1]], %[[SELECT_2]]] // CHECK-SAME: [1, 1] -// CHECK: %[[INIT_SUB:.*]] = gml_st.materialize %[[INIT]][%[[INIT_TILE]]] -// CHECK: %[[ARG_SUB:.*]] = gml_st.materialize %[[ARG]][%[[ARG_TILE]]] // CHECK: %[[DYNAMIC:.*]] = thlo.dynamic_broadcast_in_dim // CHECK-SAME: ins(%[[ARG_SUB]] : tensor) // CHECK-SAME: outs(%[[INIT_SUB]] : tensor<3x4x?xf32>) @@ -59,69 +56,62 @@ func.func @dynamic_broadcast_in_dim(%arg : tensor, // ----- -func.func @concatenate_at_tile(%init : tensor, %a: tensor, - %b: tensor, %c: tensor, %i: index, %j: index, - %arg_dim0: index, %arg_dim1: index) -> tensor { - %tile = gml_st.tile [%i, %j] [%arg_dim0, %arg_dim1] [1, 1] : !gml_st.tile - %concat = thlo.concatenate - ins(%a : tensor, %b : tensor, %c : tensor) - outs(%init : tensor) { - dimension = 1 : i64, - op_label = "producer" } - %concat_sub = gml_st.materialize %concat[%tile] - : tensor[!gml_st.tile] to tensor - func.return { op_label = "consumer" } %concat_sub : tensor +// CHECK-LABEL: @nary_concatenate_with_unit_dims +// CHECK-SAME: %[[INIT:.*]]: tensor, %[[ARG_A:.*]]: tensor, %[[ARG_B:.*]]: tensor, %[[ARG_C:.*]]: tensor, %[[I:.*]]: index, %[[J:.*]]: index, %[[N:.*]]: index +func.func @nary_concatenate_with_unit_dims(%init : tensor, + %a: tensor, %b: tensor, %c: tensor, %i: index, + %j: index, %n: index) -> tensor { + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 + // CHECK: %[[DIM:.*]] = tensor.dim %[[ARG_A]], %[[C1]] + // CHECK: %[[CMPI:.*]] = arith.cmpi ult, %[[J]], %[[DIM]] + // CHECK: %[[IF:.*]] = scf.if %[[CMPI]] -> (tensor) + // CHECK: %[[MATERIALIZE:.*]] = tensor.extract_slice %[[ARG_A]][%[[I]], %[[J]]] [%[[N]], 1] [1, 1] + // CHECK: scf.yield %[[MATERIALIZE]] + // CHECK: else + // CHECK: %[[SUBI:.*]] = arith.subi %[[J]], %[[DIM]] + // CHECK: %[[DIM_0:.*]] = tensor.dim %[[ARG_B]], %[[C1]] + // CHECK: %[[CMPI_0:.*]] = arith.cmpi ult, %[[SUBI]], %[[DIM_0]] + // CHECK: %[[IF_0:.*]] = scf.if %[[CMPI_0]] -> (tensor) + // CHECK: %[[MATERIALIZE_0:.*]] = tensor.extract_slice %[[ARG_B]][%[[I]], %[[SUBI]]] [%[[N]], 1] [1, 1] + // CHECK: scf.yield %[[MATERIALIZE_0]] + // CHECK: else + // CHECK: %[[SUBI_0:.*]] = arith.subi %[[SUBI]], %[[DIM_0]] + // CHECK: %[[MATERIALIZE_1:.*]] = tensor.extract_slice %[[ARG_C]][%[[I]], %[[SUBI_0]]] [%[[N]], 1] [1, 1] + // CHECK: scf.yield %[[MATERIALIZE_1]] + // CHECK: scf.yield %[[IF_0]] + // CHECK: return {op_label = "consumer"} %[[IF]] + %concat = thlo.concatenate ins(%a : tensor, %b : tensor, + %c : tensor) outs(%init : tensor) dimension = 1 + { op_label = "producer" } + %tiled_concat = tensor.extract_slice %concat[%i, %j] [%n, 1] [1, 1] + : tensor to tensor + func.return { op_label = "consumer" } %tiled_concat : tensor } -// CHECK-LABEL: @concatenate -// CHECK-SAME: 
(%[[INIT:[a-z0-9]+]]: tensor, %[[A:[a-z0-9]+]]: tensor, -// CHECK-SAME: %[[B:[a-z0-9]+]]: tensor, %[[C:[a-z0-9]+]]: tensor, -// CHECK-SAME: %[[I:[a-z0-9]+]]: index, %[[J:[a-z0-9]+]]: index, -// CHECK-SAME: %[[ARG_DIM0:[a-z0-9]+]]: index, %[[ARG_DIM1:[a-z0-9]+]]: index) -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 -// CHECK: %[[INIT_TILE:.*]] = gml_st.tile -// CHECK-SAME: [%[[I]], %[[J]]] [%[[ARG_DIM0]], %[[ARG_DIM1]]] +// ----- -// CHECK: %[[DIM_2:.*]] = tensor.dim %[[A]], %[[C1]] -// CHECK: %[[MINUI:.*]] = arith.minui %[[J]], %[[DIM_2]] -// CHECK: %[[SUBI:.*]] = arith.subi %[[DIM_2]], %[[MINUI]] -// CHECK: %[[MINUI_0:.*]] = arith.minui %[[SUBI]], %[[ARG_DIM1]] -// CHECK: %[[A_TILE:.*]] = gml_st.tile -// CHECK-SAME: [%[[I]], %[[MINUI]]] -// CHECK-SAME: [%[[ARG_DIM0]], %[[MINUI_0]]] -// CHECK: %[[A_SUB:.*]] = gml_st.materialize %[[A]][%[[A_TILE]]] - -// CHECK: %[[CMPI:.*]] = arith.cmpi ule, %[[J]], %[[DIM_2]] -// CHECK: %[[SUBI_0:.*]] = arith.subi %[[J]], %[[DIM_2]] -// CHECK: %[[SELECT:.*]] = arith.select %[[CMPI]], %[[C0]], %[[SUBI_0]] -// CHECK: %[[DIM_3:.*]] = tensor.dim %[[B]], %[[C1]] -// CHECK: %[[MINUI_1:.*]] = arith.minui %[[SELECT]], %[[DIM_3]] -// CHECK: %[[SUBI_1:.*]] = arith.subi %[[DIM_3]], %[[MINUI_1]] -// CHECK: %[[MINUI_2:.*]] = arith.minui %[[SUBI_1]], %[[ARG_DIM1]] -// CHECK: %[[B_TILE:.*]] = gml_st.tile -// CHECK-SAME: [%[[I]], %[[MINUI_1]]] -// CHECK-SAME: [%[[ARG_DIM0]], %[[MINUI_2]]] -// CHECK: %[[B_SUB:.*]] = gml_st.materialize %[[B]][%[[B_TILE]]] - -// CHECK: %[[CMPI_0:.*]] = arith.cmpi ule, %[[SELECT]], %[[DIM_3]] -// CHECK: %[[SUBI_2:.*]] = arith.subi %[[SELECT]], %[[DIM_3]] -// CHECK: %[[SELECT_0:.*]] = arith.select %[[CMPI_0]], %[[C0]], %[[SUBI_2]] -// CHECK: %[[DIM_4:.*]] = tensor.dim %[[C]], %[[C1]] -// CHECK: %[[MINUI_3:.*]] = arith.minui %[[SELECT_0]], %[[DIM_4]] -// CHECK: %[[SUBI_3:.*]] = arith.subi %[[DIM_4]], %[[MINUI_3]] -// CHECK: %[[MINUI_4:.*]] = arith.minui %[[SUBI_3]], %[[ARG_DIM1]] -// CHECK: %[[C_TILE:.*]] = gml_st.tile -// CHECK-SAME: [%[[I]], %[[MINUI_3]]] -// CHECK-SAME: [%[[ARG_DIM0]], %[[MINUI_4]]] -// CHECK: %[[C_SUB:.*]] = gml_st.materialize %[[C]][%[[C_TILE]]] -// CHECK: %[[INIT_SUB:.*]] = gml_st.materialize %[[INIT]][%[[INIT_TILE]]] -// CHECK: %[[CONCATENATE:.*]] = thlo.concatenate -// CHECK-SAME: ins(%[[A_SUB]] : tensor, %[[B_SUB]] : tensor, -// CHECK-SAME: %[[C_SUB]] : tensor) -// CHECK-SAME: outs(%[[INIT_SUB]] : tensor) -// CHECK-SAME: dimension = 1 -// CHECK: return {op_label = "consumer"} %[[CONCATENATE]] +// CHECK-LABEL: @binary_concatenate_with_unit_dims +// CHECK-SAME: %[[INIT:.*]]: tensor, %[[ARG_A:.*]]: tensor, %[[ARG_B:.*]]: tensor, %[[I:.*]]: index, %[[J:.*]]: index, %[[N:.*]]: index +func.func @binary_concatenate_with_unit_dims(%init : tensor, + %a: tensor, %b: tensor, %i: index, %j: index, %n: index) + -> tensor { + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 + // CHECK: %[[DIM:.*]] = tensor.dim %[[ARG_A]], %[[C1]] + // CHECK: %[[CMPI:.*]] = arith.cmpi ult, %[[J]], %[[DIM]] + // CHECK: %[[IF:.*]] = scf.if %[[CMPI]] -> (tensor) + // CHECK: %[[MATERIALIZE:.*]] = tensor.extract_slice %[[ARG_A]][%[[I]], %[[J]]] [%[[N]], 1] [1, 1] + // CHECK: scf.yield %[[MATERIALIZE]] + // CHECK: else + // CHECK: %[[SUBI:.*]] = arith.subi %[[J]], %[[DIM]] + // CHECK: %[[MATERIALIZE_0:.*]] = tensor.extract_slice %[[ARG_B]][%[[I]], %[[SUBI]]] [%[[N]], 1] [1, 1] + // CHECK: scf.yield %[[MATERIALIZE_0]] + // CHECK: return {op_label = "consumer"} %[[IF]] + %concat = thlo.concatenate ins(%a : tensor, %b : 
tensor) + outs(%init : tensor) dimension = 1 { op_label = "producer" } + %tiled_concat = tensor.extract_slice %concat[%i, %j] [%n, 1] [1, 1] + : tensor to tensor + func.return { op_label = "consumer" } %tiled_concat : tensor +} // ----- @@ -140,9 +130,8 @@ func.func @add(%lhs: tensor<32x32xf32>, %rhs: tensor<32x32xf32>, %i: index, %add = arith.addf %lhs_scalar, %rhs_scalar : f32 linalg.yield %add : f32 } -> tensor<32x32xf32> - %tile = gml_st.tile [%i, %j] [%arg_dim0, %arg_dim1] [1, 1] : !gml_st.tile - %result = gml_st.materialize %linalg[%tile] - : tensor<32x32xf32>[!gml_st.tile] to tensor + %result = tensor.extract_slice %linalg[%i, %j] [%arg_dim0, %arg_dim1] [1, 1] + : tensor<32x32xf32> to tensor return { op_label = "consumer" } %result : tensor } // CHECK-LABEL: @add @@ -152,11 +141,12 @@ func.func @add(%lhs: tensor<32x32xf32>, %rhs: tensor<32x32xf32>, %i: index, // CHECK-SAME: %[[ARG_DIM0:[a-z0-9]+]]: index, %[[ARG_DIM1:[a-z0-9]+]]: index) // CHECK: %[[INIT:.*]] = tensor.empty() -// CHECK: %[[TILE:.*]] = gml_st.tile +// CHECK: %[[LHS_SUB:.*]] = tensor.extract_slice %[[LHS]] +// CHECK-SAME: [%[[I]], %[[J]]] [%[[ARG_DIM0]], %[[ARG_DIM1]]] +// CHECK: %[[RHS_SUB:.*]] = tensor.extract_slice %[[RHS]] +// CHECK-SAME: [%[[I]], %[[J]]] [%[[ARG_DIM0]], %[[ARG_DIM1]]] +// CHECK: %[[INIT_SUB:.*]] = tensor.extract_slice %[[INIT]] // CHECK-SAME: [%[[I]], %[[J]]] [%[[ARG_DIM0]], %[[ARG_DIM1]]] -// CHECK: %[[LHS_SUB:.*]] = gml_st.materialize %[[LHS]][%[[TILE]]] -// CHECK: %[[RHS_SUB:.*]] = gml_st.materialize %[[RHS]][%[[TILE]]] -// CHECK: %[[INIT_SUB:.*]] = gml_st.materialize %[[INIT]][%[[TILE]]] // CHECK: %[[GENERIC:.*]] = linalg.generic // CHECK-SAME: iterator_types = ["parallel", "parallel"] // CHECK-SAME: ins(%[[LHS_SUB]], %[[RHS_SUB]] : tensor, @@ -187,9 +177,8 @@ func.func @empty(%lhs: tensor, %rhs: tensor, linalg.yield %arg2 : f32 } -> tensor - %tile = gml_st.tile [%i, %j] [1, 1] [1, 1] : !gml_st.tile<1x1> - %elem = gml_st.materialize %result[%tile] - : tensor[!gml_st.tile<1x1>] to tensor<1x1xf32> + %elem = tensor.extract_slice %result[%i, %j] [1, 1] [1, 1] + : tensor to tensor<1x1xf32> return { op_label = "consumer" } %elem : tensor<1x1xf32> } // CHECK-LABEL: @empty @@ -202,8 +191,8 @@ func.func @empty(%lhs: tensor, %rhs: tensor, // CHECK: %[[DIM_0:.*]] = tensor.dim %[[ARG0]], %[[C1]] // CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM]], %[[DIM_0]]) -// CHECK: %[[TILE:.*]] = gml_st.tile [%[[I]], %[[J]]] [1, 1] -// CHECK: %[[MATERIALIZE:.*]] = gml_st.materialize %[[INIT]][%[[TILE]]] +// CHECK: %[[MATERIALIZE:.*]] = tensor.extract_slice %[[INIT]] +// CHECK-SAME: [%[[I]], %[[J]]] [1, 1] // CHECK: return {op_label = "consumer"} %[[MATERIALIZE]] // ----- @@ -226,8 +215,8 @@ func.func @dim_reification_fission(%arg: tensor) -> index { func.func @dim_reification_materialize(%arg: tensor, %arg_dim0: index, %arg_dim1: index) -> index { %c0 = arith.constant 0 : index - %tile = gml_st.tile [0, 0] [%arg_dim0, %arg_dim1] [1, 1] : !gml_st.tile - %0 = gml_st.materialize %arg[%tile] : tensor[!gml_st.tile] to tensor + %0 = tensor.extract_slice %arg[0, 0] [%arg_dim0, %arg_dim1] [1, 1] + : tensor to tensor %1 = tensor.dim %0, %c0 : tensor return %1 : index } @@ -296,7 +285,7 @@ func.func @dim_reification_concatenate(%init : tensor, %concat = thlo.concatenate ins(%a : tensor, %b : tensor, %c : tensor) outs(%init : tensor) - {dimension = 1 : i64} + dimension = 1 %dim = tensor.dim %concat, %c1 : tensor func.return %dim : index } @@ -321,13 +310,15 @@ func.func @fusion_into_materialize_element( %0 = arith.negf %in : f32 
linalg.yield %0 : f32 } -> tensor - %tile = gml_st.tile [%idx] [1] [1] : !gml_st.tile<1> - %res = gml_st.materialize %neg[%tile] : tensor[!gml_st.tile<1>] to f32 - return { op_label="consumer" } %res : f32 + %res_slice = tensor.extract_slice %neg[%idx] [1] [1] + : tensor to tensor<1xf32> + %c0 = arith.constant 0 : index + %res = tensor.extract %res_slice[%c0] { op_label="consumer" } : tensor<1xf32> + return %res : f32 } // CHECK-LABEL: @fusion_into_materialize_element -// CHECK: %[[RES:.*]] = tensor.extract -// CHECK: return {{.*}} %[[RES]] +// CHECK-COUNT-2: tensor.extract_slice +// CHECK: linalg.generic // ----- @@ -350,22 +341,20 @@ func.func @matmul(%lhs: tensor<128x16xf32>, %fill = linalg.fill { op_label = "producer" } ins(%cst : f32) outs(%init : tensor<128x256xf32>) -> tensor<128x256xf32> %matmul = gml_st.parallel (%i, %j) = (%c0, %c0) to (%c128, %c256) - step (%c8, %c8) { - %lhs_tile = gml_st.tile [%i, 0] [8, 16] [1, 1] : !gml_st.tile<8x16> - %lhs_sub = gml_st.materialize %lhs[%lhs_tile] - : tensor<128x16xf32>[!gml_st.tile<8x16>] to tensor<8x16xf32> - %rhs_tile = gml_st.tile [0, %j] [16, 8] [1, 1] : !gml_st.tile<16x8> - %rhs_sub = gml_st.materialize %rhs[%rhs_tile] - : tensor<16x256xf32>[!gml_st.tile<16x8>] to tensor<16x8xf32> - %out_tile = gml_st.tile [%i, %j] [8, 8] [1, 1] : !gml_st.tile<8x8> - %out_sub = gml_st.materialize %fill[%out_tile] - : tensor<128x256xf32>[!gml_st.tile<8x8>] to tensor<8x8xf32> + step (%c8, %c8) outs (%out = %fill: tensor<128x256xf32>) { + %lhs_sub = tensor.extract_slice %lhs[%i, 0] [8, 16] [1, 1] + : tensor<128x16xf32> to tensor<8x16xf32> + %rhs_sub = tensor.extract_slice %rhs[0, %j] [16, 8] [1, 1] + : tensor<16x256xf32> to tensor<16x8xf32> + %out_sub = tensor.extract_slice %out[%i, %j] [8, 8] [1, 1] + : tensor<128x256xf32> to tensor<8x8xf32> %matmul_sub = linalg.matmul { op_label="consumer" } ins(%lhs_sub, %rhs_sub : tensor<8x16xf32>, tensor<16x8xf32>) outs(%out_sub : tensor<8x8xf32>) -> tensor<8x8xf32> - gml_st.set_yield %matmul_sub into %fill[%out_tile] + %out_tile = gml_st.tile [%i, %j] [8, 8] [1, 1] : !gml_st.tile<8x8> + gml_st.set_yield %matmul_sub into %out[%out_tile] : tensor<8x8xf32> into tensor<128x256xf32>[!gml_st.tile<8x8>] } : tensor<128x256xf32> return %matmul : tensor<128x256xf32> @@ -377,12 +366,12 @@ func.func @matmul(%lhs: tensor<128x16xf32>, // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<128x256xf32> // CHECK: gml_st.parallel (%[[I:[a-z0-9]+]], %[[J:[a-z0-9]+]]) -// CHECK: %[[OUT_TILE:.*]] = gml_st.tile [%[[I]], %[[J]]] [8, 8] [1, 1] -// CHECK: %[[OUT_SUB:.*]] = gml_st.materialize %[[EMPTY]][%[[OUT_TILE]]] +// CHECK-SAME: outs (%[[OUT_:.*]] = %[[EMPTY]]: +// CHECK: %[[OUT_SUB:.*]] = tensor.extract_slice %[[OUT_]][%[[I]], %[[J]]] [8, 8] [1, 1] // CHECK: %[[FILL:.*]] = linalg.fill // CHECK-SAME: outs(%[[OUT_SUB]] : tensor<8x8xf32>) -> tensor<8x8xf32> // CHECK: %[[MATMUL:.*]] = linalg.matmul // CHECK-SAME: outs(%[[FILL]] : tensor<8x8xf32>) -> tensor<8x8xf32> -// CHECK: gml_st.set_yield %[[MATMUL]] into %[[EMPTY]][%[[OUT_TILE]]] +// CHECK: %[[OUT_TILE:.*]] = gml_st.tile [%[[I]], %[[J]]] [8, 8] [1, 1] +// CHECK: gml_st.set_yield %[[MATMUL]] into %[[OUT_]][%[[OUT_TILE]]] // CHECK-SAME: : tensor<8x8xf32> into tensor<128x256xf32>[!gml_st.tile<8x8>] - diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/gml_st_simtfy.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/gml_st_simtfy.mlir new file mode 100644 index 00000000000..df1c36465c2 --- /dev/null +++ 
b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/gml_st_simtfy.mlir @@ -0,0 +1,221 @@ +// RUN: mlir-hlo-opt %s -split-input-file -verify-diagnostics \ +// RUN: --gml-st-simtfy="block-distribution-label=block" -cse \ +// RUN: | FileCheck %s +// We run CSE above to deduplicate constant definitions, which would confuse +// FileCheck. + +#map = affine_map<(d0)[s0] -> (d0 + s0)> + +func.func @simple(%arg2: memref<2048xf32>) -> memref<2048xf32> { + %c0 = arith.constant 0 : index + %c2048 = arith.constant 2048 : index + %c128 = arith.constant 128 : index + %c32 = arith.constant 32 : index + %c1 = arith.constant 1 : index + %2 = memref.alloc() {alignment = 64 : i64} : memref<2048xf32> + gml_st.parallel (%arg3) = (%c0) to (%c2048) step (%c128) distribution ("block") { + %3 = memref.subview %2[%arg3] [128] [1] : memref<2048xf32> to memref<128xf32, #map> + %4 = memref.subview %arg2[%arg3] [128] [1] : memref<2048xf32> to memref<128xf32, #map> + gml_st.parallel (%arg4) = (%c0) to (%c128) step (%c32) distribution ("warp") { + %5 = memref.subview %3[%arg4] [32] [1] : memref<128xf32, #map> to memref<32xf32, #map> + %6 = memref.subview %4[%arg4] [32] [1] : memref<128xf32, #map> to memref<32xf32, #map> + gml_st.parallel (%arg5) = (%c0) to (%c32) step (%c1) distribution ("thread") { + %7 = memref.load %6[%arg5] : memref<32xf32, #map> + %8 = math.log %7 : f32 + memref.store %8, %5[%arg5] : memref<32xf32, #map> + gml_st.set_yield + } + gml_st.set_yield + } + gml_st.set_yield + } + return %2 : memref<2048xf32> +} + +// CHECK-LABEL: @simple +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index +// CHECK-DAG: %[[C128:.*]] = arith.constant 128 : index +// CHECK-DAG: %[[C2048:.*]] = arith.constant 2048 : index +// CHECK-DAG: %[[MAP0:.*]] = affine.apply #map(%[[C2048]])[%[[C0]], %[[C128]]] +// CHECK-DAG: %[[MAP1:.*]] = affine.apply #map(%[[C128]])[%[[C0]], %[[C32]]] +// CHECK-DAG: %[[MAP2:.*]] = affine.apply #map(%[[C32]])[%[[C0]], %[[C1]]] +// CHECK: gpu.launch blocks +// CHECK-SAME: ({{.*}}) in ({{.*}} = %[[MAP0]], {{.*}} = %[[C1]], {{.*}} = %[[C1]]) threads +// CHECK-SAME: ({{.*}}) in ({{.*}} = %[[MAP2]], {{.*}} = %[[MAP1]], {{.*}} = %[[C1]]) +// CHECK: affine.apply {{.*}}[%[[C0]], %[[C128]]] +// CHECK-NEXT: memref.subview +// CHECK-SAME: "gml-st-distribution-label" = "block" +// CHECK: affine.apply {{.*}}[%[[C0]], %[[C32]]] +// CHECK-NEXT: memref.subview +// CHECK-SAME: "gml-st-distribution-label" = "warp" +// CHECK: affine.apply {{.*}}[%[[C0]], %[[C1]]] +// CHECK-NOT: scf.if +// CHECK: memref.load +// CHECK-NOT: "gml-st-distribution-label" +// CHECK: math.log +// CHECK: memref.store + +// ----- + +#map = affine_map<(d0)[s0] -> (d0 + s0)> + +func.func @sibling_parallels(%arg2: memref<2048xf32>) -> memref<2048xf32> { + %c0 = arith.constant 0 : index + %c2048 = arith.constant 2048 : index + %c128 = arith.constant 128 : index + %c32 = arith.constant 32 : index + %c1 = arith.constant 1 : index + %2 = memref.alloc() {alignment = 64 : i64} : memref<2048xf32> + gml_st.parallel (%arg3) = (%c0) to (%c2048) step (%c128) distribution ("block") { + %3 = memref.subview %2[%arg3] [128] [1] : memref<2048xf32> to memref<128xf32, #map> + %4 = memref.subview %arg2[%arg3] [128] [1] : memref<2048xf32> to memref<128xf32, #map> + gml_st.parallel (%arg4) = (%c0) to (%c128) step (%c32) distribution ("warp") { + %5 = memref.subview %3[%arg4] [32] [1] : memref<128xf32, #map> to memref<32xf32, #map> + %6 = memref.subview 
%4[%arg4] [32] [1] : memref<128xf32, #map> to memref<32xf32, #map> + %7 = memref.alloc() {alignment = 64 : i64} : memref<32xf32> + gml_st.parallel (%arg5) = (%c0) to (%c32) step (%c1) distribution ("thread") { + %8 = memref.load %6[%arg5] : memref<32xf32, #map> + %9 = math.log %8 : f32 + memref.store %9, %7[%arg5] : memref<32xf32> + gml_st.set_yield + } + gml_st.parallel (%arg6) = (%c0) to (%c32) step (%c1) distribution ("thread") { + %10 = memref.load %7[%arg6] : memref<32xf32> + %11 = math.absf %10 : f32 + memref.store %11, %5[%arg6] : memref<32xf32, #map> + gml_st.set_yield + } + gml_st.set_yield + } + gml_st.set_yield + } + return %2 : memref<2048xf32> +} + +// CHECK-LABEL: @sibling_parallels +// CHECK: gpu.launch blocks +// CHECK: affine.apply +// CHECK: affine.apply +// CHECK: affine.apply +// CHECK: memref.load +// CHECK: math.log +// CHECK: memref.store +// CHECK-NOT: affine.apply +// CHECK: memref.load +// CHECK: math.absf +// CHECK: memref.store + +// ----- + +func.func @too_deep_nesting() { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %alloc = memref.alloc() : memref + // expected-error@+1 {{failed to simtfy}} + gml_st.parallel (%arg3) = (%c0) to (%c1) step (%c1) distribution ("block") { + gml_st.parallel (%arg4) = (%c0) to (%c1) step (%c1) distribution ("warp") { + gml_st.parallel (%arg5) = (%c0) to (%c1) step (%c1) distribution ("thread") { + gml_st.parallel (%arg6) = (%c0) to (%c1) step (%c1) { + memref.store %c0, %alloc[] : memref + gml_st.set_yield + } + gml_st.set_yield + } + gml_st.set_yield + } + gml_st.set_yield + } + return +} + + +// ----- + +func.func @mismatched_bounds() { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %alloc1 = memref.alloc() : memref + %alloc2 = memref.alloc() : memref + // expected-error@+1 {{failed to simtfy}} + gml_st.parallel (%arg3) = (%c0) to (%c1) step (%c1) distribution ("block") { + gml_st.parallel (%arg4) = (%c0) to (%c1) step (%c1) distribution ("warp") { + gml_st.parallel (%arg5) = (%c0) to (%c1) step (%c1) distribution ("thread") { + memref.store %c0, %alloc1[] : memref + gml_st.set_yield + } + gml_st.parallel (%arg6) = (%c0) to (%c2) step (%c1) distribution ("thread") { + memref.store %c0, %alloc2[] : memref + gml_st.set_yield + } + gml_st.set_yield + } + gml_st.set_yield + } + return +} + +// ----- + +func.func @mmultple_induction_vars() { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %alloc = memref.alloc() : memref + // expected-error@+1 {{failed to simtfy}} + gml_st.parallel (%arg1, %arg2) = (%c0, %c0) to (%c1, %c1) step (%c1, %c1) distribution ("block") { + memref.store %c0, %alloc[] : memref + gml_st.set_yield + } + return +} + +// ----- + +#layout = strided<[1], offset: ?> + +func.func @imperfect_tiling(%arg0: memref<2051xf32>) -> memref<2051xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c2051 = arith.constant 2051 : index + %0 = memref.alloc() {alignment = 64 : i64} : memref<2051xf32> + gml_st.parallel (%arg1) = (%c0) to (%c2051) step (%c128) distribution ("block") { + %1 = affine.min affine_map<(d0) -> (-d0 + 2051, 128)>(%arg1) + %2 = memref.subview %arg0[%arg1] [%1] [1] : memref<2051xf32> to memref + %3 = memref.subview %0[%arg1] [%1] [1] : memref<2051xf32> to memref + gml_st.parallel (%arg2) = (%c0) to (%1) step (%c32) distribution ("warp") { + %4 = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 
32)>(%arg2)[%1] + %5 = memref.subview %2[%arg2] [%4] [1] : memref to memref + %6 = memref.subview %3[%arg2] [%4] [1] : memref to memref + gml_st.parallel (%arg3) = (%c0) to (%4) step (%c1) distribution ("thread") { + %7 = memref.load %5[%arg3] : memref + %8 = math.log %7 : f32 + memref.store %8, %6[%arg3] : memref + gml_st.set_yield + } + gml_st.set_yield + } + gml_st.set_yield + } + return %0 : memref<2051xf32> +} + +// CHECK-LABEL: @imperfect_tiling +// CHECK: gpu.launch blocks(%[[BLOCKID:.*]], %{{.*}}, %{{.*}}) in {{.*}} threads +// CHECK-SAME: (%[[THREADID:.*]], %[[WARPID:.*]], %{{.*}}) in +// CHECK-DAG: %[[ARG1:.*]] = affine.apply {{.*}}(%[[BLOCKID]]) +// CHECK-DAG: %[[BTILESIZE:.*]] = affine.min {{.*}}(%[[ARG1]]) +// CHECK-DAG: %[[ARG2:.*]] = affine.apply {{.*}}(%[[WARPID]]) +// CHECK-DAG: %[[WCOND:.*]] = arith.cmpi slt, %[[ARG2:.*]], %[[BTILESIZE]] +// CHECK-DAG: scf.if %[[WCOND]] +// CHECK-DAG: %[[WTILESIZE:.*]] = affine.min {{.*}}(%[[ARG2]])[%[[BTILESIZE]]] +// CHECK-DAG: %[[ARG3:.*]] = affine.apply {{.*}}(%[[THREADID]]) +// CHECK-DAG: %[[TCOND:.*]] = arith.cmpi slt, %[[ARG3:.*]], %[[WTILESIZE]] +// CHECK-DAG: scf.if %[[TCOND]] +// CHECK: memref.load +// CHECK: math.log +// CHECK: memref.store diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/gml_st_to_gpu.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/gml_st_to_gpu.mlir index 86308229cb5..af05f1175de 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/gml_st_to_gpu.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/gml_st_to_gpu.mlir @@ -1,221 +1,11 @@ // RUN: mlir-hlo-opt %s -split-input-file -verify-diagnostics \ -// RUN: -gml-st-to-gpu -cse \ +// RUN: --gml-st-to-gpu="warp-distribution-label=warp" -cse \ // RUN: | FileCheck %s // We run CSE above to deduplicate constant definitions, which would confuse // FileCheck. 
-#map = affine_map<(d0)[s0] -> (d0 + s0)> - -func.func @simple(%arg2: memref<2048xf32>) -> memref<2048xf32> { - %c0 = arith.constant 0 : index - %c2048 = arith.constant 2048 : index - %c128 = arith.constant 128 : index - %c32 = arith.constant 32 : index - %c1 = arith.constant 1 : index - %2 = memref.alloc() {alignment = 64 : i64} : memref<2048xf32> - gml_st.parallel (%arg3) = (%c0) to (%c2048) step (%c128) { - %3 = memref.subview %2[%arg3] [128] [1] : memref<2048xf32> to memref<128xf32, #map> - %4 = memref.subview %arg2[%arg3] [128] [1] : memref<2048xf32> to memref<128xf32, #map> - gml_st.parallel (%arg4) = (%c0) to (%c128) step (%c32) { - %5 = memref.subview %3[%arg4] [32] [1] : memref<128xf32, #map> to memref<32xf32, #map> - %6 = memref.subview %4[%arg4] [32] [1] : memref<128xf32, #map> to memref<32xf32, #map> - gml_st.parallel (%arg5) = (%c0) to (%c32) step (%c1) { - %7 = memref.load %6[%arg5] : memref<32xf32, #map> - %8 = math.log %7 : f32 - memref.store %8, %5[%arg5] : memref<32xf32, #map> - gml_st.set_yield - } - gml_st.set_yield - } - gml_st.set_yield - } - return %2 : memref<2048xf32> -} - -// CHECK-LABEL: @simple -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index -// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index -// CHECK-DAG: %[[C128:.*]] = arith.constant 128 : index -// CHECK: gpu.launch blocks -// CHECK-SAME: ({{.*}}) in ({{.*}} = %[[C16]], {{.*}} = %[[C1]], {{.*}} = %[[C1]]) threads -// CHECK-SAME: ({{.*}}) in ({{.*}} = %[[C32]], {{.*}} = %[[C4]], {{.*}} = %[[C1]]) -// CHECK: affine.apply {{.*}}[%[[C0]], %[[C128]]] -// CHECK: affine.apply {{.*}}[%[[C0]], %[[C32]]] -// CHECK: affine.apply {{.*}}[%[[C0]], %[[C1]]] -// CHECK-NOT: scf.if -// CHECK: memref.load -// CHECK: math.log -// CHECK: memref.store - -// ----- - -#map = affine_map<(d0)[s0] -> (d0 + s0)> - -func.func @sibling_parallels(%arg2: memref<2048xf32>) -> memref<2048xf32> { - %c0 = arith.constant 0 : index - %c2048 = arith.constant 2048 : index - %c128 = arith.constant 128 : index - %c32 = arith.constant 32 : index - %c1 = arith.constant 1 : index - %2 = memref.alloc() {alignment = 64 : i64} : memref<2048xf32> - gml_st.parallel (%arg3) = (%c0) to (%c2048) step (%c128) { - %3 = memref.subview %2[%arg3] [128] [1] : memref<2048xf32> to memref<128xf32, #map> - %4 = memref.subview %arg2[%arg3] [128] [1] : memref<2048xf32> to memref<128xf32, #map> - gml_st.parallel (%arg4) = (%c0) to (%c128) step (%c32) { - %5 = memref.subview %3[%arg4] [32] [1] : memref<128xf32, #map> to memref<32xf32, #map> - %6 = memref.subview %4[%arg4] [32] [1] : memref<128xf32, #map> to memref<32xf32, #map> - %7 = memref.alloc() {alignment = 64 : i64} : memref<32xf32> - gml_st.parallel (%arg5) = (%c0) to (%c32) step (%c1) { - %8 = memref.load %6[%arg5] : memref<32xf32, #map> - %9 = math.log %8 : f32 - memref.store %9, %7[%arg5] : memref<32xf32> - gml_st.set_yield - } - gml_st.parallel (%arg6) = (%c0) to (%c32) step (%c1) { - %10 = memref.load %7[%arg6] : memref<32xf32> - %11 = math.absf %10 : f32 - memref.store %11, %5[%arg6] : memref<32xf32, #map> - gml_st.set_yield - } - gml_st.set_yield - } - gml_st.set_yield - } - return %2 : memref<2048xf32> -} - -// CHECK-LABEL: @sibling_parallels -// CHECK: gpu.launch blocks -// CHECK: affine.apply -// CHECK: affine.apply -// CHECK: affine.apply -// CHECK: memref.load -// CHECK: math.log -// CHECK: memref.store -// CHECK-NOT: affine.apply -// 
CHECK: memref.load -// CHECK: math.absf -// CHECK: memref.store - -// ----- - -func.func @too_deep_nesting() { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %alloc = memref.alloc() : memref - gml_st.parallel (%arg3) = (%c0) to (%c1) step (%c1) { - gml_st.parallel (%arg4) = (%c0) to (%c1) step (%c1) { - gml_st.parallel (%arg5) = (%c0) to (%c1) step (%c1) { - // expected-error@+1 {{failed to simtfy}} - gml_st.parallel (%arg6) = (%c0) to (%c1) step (%c1) { - memref.store %c0, %alloc[] : memref - gml_st.set_yield - } - gml_st.set_yield - } - gml_st.set_yield - } - gml_st.set_yield - } - return -} - - -// ----- - -func.func @mismatched_bounds() { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %alloc1 = memref.alloc() : memref - %alloc2 = memref.alloc() : memref - gml_st.parallel (%arg3) = (%c0) to (%c1) step (%c1) { - gml_st.parallel (%arg4) = (%c0) to (%c1) step (%c1) { - // expected-error@+1 {{failed to simtfy}} - gml_st.parallel (%arg5) = (%c0) to (%c1) step (%c1) { - memref.store %c0, %alloc1[] : memref - gml_st.set_yield - } - gml_st.parallel (%arg6) = (%c0) to (%c2) step (%c1) { - memref.store %c0, %alloc2[] : memref - gml_st.set_yield - } - gml_st.set_yield - } - gml_st.set_yield - } - return -} - -// ----- - -func.func @mmultple_induction_vars() { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %alloc = memref.alloc() : memref - // expected-error@+1 {{failed to simtfy}} - gml_st.parallel (%arg1, %arg2) = (%c0, %c0) to (%c1, %c1) step (%c1, %c1) { - memref.store %c0, %alloc[] : memref - gml_st.set_yield - } - return -} - -// ----- - -#layout = strided<[1], offset: ?> - -func.func @imperfect_tiling(%arg0: memref<2051xf32>) -> memref<2051xf32> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c32 = arith.constant 32 : index - %c128 = arith.constant 128 : index - %c2051 = arith.constant 2051 : index - %0 = memref.alloc() {alignment = 64 : i64} : memref<2051xf32> - gml_st.parallel (%arg1) = (%c0) to (%c2051) step (%c128) { - %1 = affine.min affine_map<(d0) -> (-d0 + 2051, 128)>(%arg1) - %2 = memref.subview %arg0[%arg1] [%1] [1] : memref<2051xf32> to memref - %3 = memref.subview %0[%arg1] [%1] [1] : memref<2051xf32> to memref - gml_st.parallel (%arg2) = (%c0) to (%1) step (%c32) { - %4 = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 32)>(%arg2)[%1] - %5 = memref.subview %2[%arg2] [%4] [1] : memref to memref - %6 = memref.subview %3[%arg2] [%4] [1] : memref to memref - gml_st.parallel (%arg3) = (%c0) to (%4) step (%c1) { - %7 = memref.load %5[%arg3] : memref - %8 = math.log %7 : f32 - memref.store %8, %6[%arg3] : memref - gml_st.set_yield - } - gml_st.set_yield - } - gml_st.set_yield - } - return %0 : memref<2051xf32> -} - -// CHECK-LABEL: @imperfect_tiling -// CHECK: gpu.launch blocks(%[[BLOCKID:.*]], %{{.*}}, %{{.*}}) in {{.*}} threads -// CHECK-SAME: (%[[THREADID:.*]], %[[WARPID:.*]], %{{.*}}) in -// CHECK-DAG: %[[ARG1:.*]] = affine.apply {{.*}}(%[[BLOCKID]]) -// CHECK-DAG: %[[BTILESIZE:.*]] = affine.min {{.*}}(%[[ARG1]]) -// CHECK-DAG: %[[ARG2:.*]] = affine.apply {{.*}}(%[[WARPID]]) -// CHECK-DAG: %[[WCOND:.*]] = arith.cmpi slt, %[[ARG2:.*]], %[[BTILESIZE]] -// CHECK-DAG: scf.if %[[WCOND]] -// CHECK-DAG: %[[WTILESIZE:.*]] = affine.min {{.*}}(%[[ARG2]])[%[[BTILESIZE]]] -// CHECK-DAG: %[[ARG3:.*]] = affine.apply {{.*}}(%[[THREADID]]) -// CHECK-DAG: %[[TCOND:.*]] = arith.cmpi slt, %[[ARG3:.*]], %[[WTILESIZE]] -// CHECK-DAG: scf.if %[[TCOND]] -// 
CHECK: memref.load -// CHECK: math.log -// CHECK: memref.store - -// ----- - -#layout = strided<[1], offset: ?> +#map = affine_map<(d0)[s0, s1] -> ((d0 - s0) ceildiv s1)> +#map1 = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> func.func @vectorized_tiling(%arg0: memref<2048xf32>) -> memref<2048xf32> { %c2048 = arith.constant 2048 : index @@ -223,38 +13,32 @@ func.func @vectorized_tiling(%arg0: memref<2048xf32>) -> memref<2048xf32> { %c1024 = arith.constant 1024 : index %c128 = arith.constant 128 : index %c4 = arith.constant 4 : index - %cst = arith.constant 0.000000e+00 : f32 + %c0f = arith.constant 0.0 : f32 %alloc = memref.alloc() {alignment = 64 : i64} : memref<2048xf32> - gml_st.parallel (%arg1) = (%c0) to (%c2048) step (%c1024) { - %subview = memref.subview %arg0[%arg1] [1024] [1] - : memref<2048xf32> to memref<1024xf32, #layout> - %subview_0 = memref.subview %alloc[%arg1] [1024] [1] - : memref<2048xf32> to memref<1024xf32, #layout> - gml_st.parallel (%arg2) = (%c0) to (%c1024) step (%c128) { - %subview_1 = memref.subview %subview[%arg2] [128] [1] - : memref<1024xf32, #layout> to memref<128xf32, #layout> - %0 = vector.transfer_read %subview_1[%c0], %cst {in_bounds = [true]} - : memref<128xf32, #layout>, vector<128xf32> - %subview_2 = memref.subview %subview_0[%arg2] [128] [1] - : memref<1024xf32, #layout> to memref<128xf32, #layout> - %1 = vector.transfer_read %subview_2[%c0], %cst {in_bounds = [true]} - : memref<128xf32, #layout>, vector<128xf32> - %2 = gml_st.parallel (%arg3) = (%c0) to (%c128) step (%c4) { - %4 = gml_st.tile [%arg3] [4] [1] : !gml_st.tile<4> - %5 = gml_st.materialize %0[%4] - : vector<128xf32>[!gml_st.tile<4>] to vector<4xf32> - %6 = math.absf %5 : vector<4xf32> - gml_st.set_yield %6 into %1[%4] - : vector<4xf32> into vector<128xf32>[!gml_st.tile<4>] - } : vector<128xf32> - vector.transfer_write %2, %subview_2[%c0] {in_bounds = [true]} - : vector<128xf32>, memref<128xf32, #layout> - gml_st.set_yield - } - gml_st.set_yield + %c1 = arith.constant 1 : index + %map0 = affine.apply #map(%c2048)[%c0, %c1024] + %map1 = affine.apply #map(%c1024)[%c0, %c128] + %map2 = affine.apply #map(%c128)[%c0, %c4] + gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %map0, %grid_y = %c1, %grid_z = %c1) threads(%tx, %ty, %tz) in (%block_x = %map2, %block_y = %map1, %block_z = %c1) { + %apply_bx = affine.apply #map1(%bx)[%c0, %c1024] + %block_arg = memref.subview %arg0[%apply_bx] [1024] [1] {"gml-st-distribution-label" = "block"} : memref<2048xf32> to memref<1024xf32, strided<[1], offset: ?>> + %block_out = memref.subview %alloc[%apply_bx] [1024] [1] {"gml-st-distribution-label" = "block"} : memref<2048xf32> to memref<1024xf32, strided<[1], offset: ?>> + %apply_ty = affine.apply #map1(%ty)[%c0, %c128] + %warp_arg = memref.subview %block_arg[%apply_ty] [128] [1] {"gml-st-distribution-label" = "warp"} : memref<1024xf32, strided<[1], offset: ?>> to memref<128xf32, strided<[1], offset: ?>> + %transfer_read = vector.transfer_read %warp_arg[%c0], %c0f {"gml-st-distribution-label" = "warp", in_bounds = [true]} : memref<128xf32, strided<[1], offset: ?>>, vector<128xf32> + %warp_out = memref.subview %block_out[%apply_ty] [128] [1] {"gml-st-distribution-label" = "warp"} : memref<1024xf32, strided<[1], offset: ?>> to memref<128xf32, strided<[1], offset: ?>> + %apply_tx = affine.apply #map1(%tx)[%c0, %c4] + %materialized_tile = gml_st.materialize %transfer_read[%apply_tx] [4] [1] + : vector<128xf32> to vector<4xf32> + %result = math.absf %materialized_tile : vector<4xf32> + %tile = gml_st.tile [%apply_tx] 
[4] [1] : !gml_st.tile<4> + %distribute = gml_st.distribute %result into[%tile] : vector<4xf32> into vector<128xf32>[!gml_st.tile<4>] + vector.transfer_write %distribute, %warp_out[%c0] {"gml-st-distribution-label" = "warp", in_bounds = [true]} : vector<128xf32>, memref<128xf32, strided<[1], offset: ?>> + gpu.terminator } return %alloc : memref<2048xf32> } + // CHECK-LABEL: @vectorized_tiling // CHECK-SAME: %[[ARG:.*]]: memref // CHECK: %[[OUT:.*]] = memref.alloc @@ -282,9 +66,8 @@ func.func @materialize_scalar_of_transfer_read( %c0 = arith.constant 0 : index %vector = vector.transfer_read %in[%c0], %pad {in_bounds = [true]} : memref<32xf32>, vector<32xf32> - %tile = gml_st.tile [%idx] [1] [1] : !gml_st.tile<1> - %value = gml_st.materialize %vector[%tile] - : vector<32xf32>[!gml_st.tile<1>] to f32 + %value = gml_st.materialize %vector[%idx] [1] [1] + : vector<32xf32> to f32 return %value : f32 } // CHECK-LABEL: @materialize_scalar_of_transfer_read( @@ -292,4 +75,4 @@ func.func @materialize_scalar_of_transfer_read( // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[SUBVIEW:.*]] = memref.subview %[[IN]][%[[IDX]]] // CHECK: %[[VALUE:.*]] = memref.load %[[SUBVIEW]][%[[C0]]] -// CHECK: return %[[VALUE]] : f32 \ No newline at end of file +// CHECK: return %[[VALUE]] : f32 diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/gml_st_to_scf.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/gml_st_to_scf.mlir index fcfe5bc5763..90b7132247f 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/gml_st_to_scf.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/gml_st_to_scf.mlir @@ -1,113 +1,5 @@ // RUN: mlir-hlo-opt %s -gml-st-to-scf -split-input-file | FileCheck %s -#map0 = affine_map<(d0) -> (24, -d0 + 192)> -#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)> -#map2 = affine_map<(d0) -> (16, -d0 + 192)> - -func.func @loop(%A: memref<192x192xf32>, - %B: memref<192x192xf32>, - %C: memref<192x192xf32>) { - %cst = arith.constant 0.000000e+00 : f32 - %c24 = arith.constant 24 : index - %c16 = arith.constant 16 : index - %c0 = arith.constant 0 : index - %c192 = arith.constant 192 : index - - gml_st.loop (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16) - ins (%A_ = %A: memref<192x192xf32>, %B_ = %B: memref<192x192xf32>) - outs (%C_ = %C: memref<192x192xf32>) { - %0 = affine.min #map0(%i) - %1 = memref.subview %A_[%i, 0] [%0, 192] [1, 1] - : memref<192x192xf32> to memref - %2 = affine.min #map2(%j) - %3 = memref.subview %B_[0, %j] [192, %2] [1, 1] - : memref<192x192xf32> to memref<192x?xf32, #map1> - %4 = memref.subview %C_[%i, %j] [%0, %2] [1, 1] - : memref<192x192xf32> to memref - linalg.fill ins(%cst : f32) outs(%4 : memref) - linalg.matmul ins(%1, %3 : memref, - memref<192x?xf32, #map1>) - outs(%4 : memref) - gml_st.yield - } - func.return -} - -// CHECK-LABEL: @loop -// CHECK-SAME: %[[A:.*]]: memref<192x192xf32>, %[[B:.*]]: memref<192x192xf32>, -// CHECK-SAME: %[[C:.*]]: memref<192x192xf32>) { -// CHECK-DAG: %[[C24:.*]] = arith.constant 24 : index -// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C192:.*]] = arith.constant 192 : index -// CHECK: scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) -// CHECK-SAME: to (%[[C192]], %[[C192]]) step (%[[C24]], %[[C16]]) { -// CHECK: %[[A_sub:.*]] = memref.subview %[[A]][%[[I]] -// CHECK: %[[B_sub:.*]] = memref.subview %[[B]][0, %[[J]]] -// CHECK: %[[C_sub:.*]] = memref.subview %[[C]][%[[I]] -// 
CHECK: linalg.fill -// CHECK: linalg.matmul - -// ----- - - -func.func @parallel(%A: memref<192x192xf32>) { - %cst = arith.constant 0.000000e+00 : f32 - %c24 = arith.constant 24 : index - %c16 = arith.constant 16 : index - %c0 = arith.constant 0 : index - %c192 = arith.constant 192 : index - - gml_st.parallel (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16) { - linalg.fill ins(%cst : f32) outs(%A : memref<192x192xf32>) - gml_st.set_yield - } - func.return -} - -// CHECK-LABEL: @parallel -// CHECK-SAME: %[[A:.*]]: memref<192x192xf32> -// CHECK-DAG: %[[C24:.*]] = arith.constant 24 : index -// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C192:.*]] = arith.constant 192 : index -// CHECK: scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) -// CHECK-SAME: to (%[[C192]], %[[C192]]) step (%[[C24]], %[[C16]]) { -// CHECK: linalg.fill - -// ----- - -func.func @loop_reduction(%A: memref<192x192xf32>, - %B: memref<192x192xf32>, - %C: memref) { - %c24 = arith.constant 24 : index - %c16 = arith.constant 16 : index - %c0 = arith.constant 0 : index - %c192 = arith.constant 192 : index - %cst = arith.constant 0.000000e+00 : f32 - - gml_st.loop (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16) - ins (%A_ = %A: memref<192x192xf32>, %B_ = %B: memref<192x192xf32>) - outs (%C_ = %C: memref) - iterators[#gml_st.iterator_type, - #gml_st.iterator_type] { - linalg.fill ins(%cst : f32) outs(%A_ : memref<192x192xf32>) - gml_st.yield - } - func.return -} - -// CHECK-LABEL: @loop_reduction -// CHECK-DAG: %[[C24:.*]] = arith.constant 24 : index -// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C192:.*]] = arith.constant 192 : index -// CHECK: scf.for %{{.*}} = %[[C0]] to %[[C192]] step %[[C24]] -// CHECK: scf.for %{{.*}} = %[[C0]] to %[[C192]] step %[[C16]] -// CHECK: linalg.fill - -// ----- - func.func @for(%A: memref<192x192xf32>) { %c24 = arith.constant 24 : index %c16 = arith.constant 16 : index @@ -133,112 +25,6 @@ func.func @for(%A: memref<192x192xf32>) { // ----- -#strided_1d = affine_map<(d0)[s0] -> (d0 + s0)> -#strided_2d = affine_map<(d0, d1)[s0] -> (d0 * 8 + s0 + d1)> - -func.func @loop_row_reduction(%A: memref<10x8xf32>, - %B: memref<8xf32>) { - %c0 = arith.constant 0 : index - %c2 = arith.constant 2 : index - %c4 = arith.constant 4 : index - %c8 = arith.constant 8 : index - %c10 = arith.constant 10 : index - %cst = arith.constant 0.000000e+00 : f32 - - gml_st.loop (%i, %j) = (%c0, %c0) to (%c10, %c8) step (%c2, %c4) - ins (%A_ = %A: memref<10x8xf32>) - outs (%B_ = %B: memref<8xf32>) - iterators[#gml_st.iterator_type, - #gml_st.iterator_type] { - %A_sub = memref.subview %A_[%i, %j][2, 4][1, 1] - : memref<10x8xf32> to memref<2x4xf32, #strided_2d> - %B_sub = memref.subview %B_[%j][4][1] - : memref<8xf32> to memref<4xf32, #strided_1d> - linalg.generic { - indexing_maps = [affine_map<(i, j) -> (i, j)>, - affine_map<(i, j) -> (j)>], - iterator_types = ["reduction", "parallel"]} - ins(%A_sub : memref<2x4xf32, #strided_2d>) - outs(%B_sub : memref<4xf32, #strided_1d>) { - ^bb(%a: f32, %b: f32) : - %0 = arith.addf %a, %b: f32 - linalg.yield %0 : f32 - } - gml_st.yield - } - func.return -} - -// CHECK-LABEL: @loop_row_reduction - -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index -// 
CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index - -// CHECK: scf.parallel (%[[J:.*]]) = (%[[C0]]) to (%[[C8]]) step (%[[C4]]) -// CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C10]] step %[[C2]] -// CHECK-NEXT: memref.subview %arg{{[0-9]+}}[%[[I]], %[[J]]] [2, 4] [1, 1] -// CHECK-SAME: : memref<10x8xf32> to memref<2x4xf32, #map{{[0-9]*}}> -// CHECK-NEXT: memref.subview %arg{{[0-9]+}}[%[[J]]] [4] [1] -// CHECK-SAME: : memref<8xf32> to memref<4xf32, #map{{[0-9]*}}> - -// ----- - -#strided_1d = affine_map<(d0)[s0] -> (d0 + s0)> -#strided_2d = affine_map<(d0, d1)[s0] -> (d0 * 8 + s0 + d1)> - -func.func @loop_col_reduction(%A: memref<10x8xf32>, - %B: memref<10xf32>) { - %c0 = arith.constant 0 : index - %c2 = arith.constant 2 : index - %c4 = arith.constant 4 : index - %c8 = arith.constant 8 : index - %c10 = arith.constant 10 : index - %cst = arith.constant 0.000000e+00 : f32 - - gml_st.loop (%i, %j) = (%c0, %c0) to (%c10, %c8) step (%c2, %c4) - ins (%A_ = %A: memref<10x8xf32>) - outs (%B_ = %B: memref<10xf32>) - iterators[#gml_st.iterator_type, - #gml_st.iterator_type] { - %A_sub = memref.subview %A_[%i, %j][2, 4][1, 1] - : memref<10x8xf32> to memref<2x4xf32, #strided_2d> - %B_sub = memref.subview %B_[%i][2][1] - : memref<10xf32> to memref<2xf32, #strided_1d> - linalg.generic { - indexing_maps = [affine_map<(i, j) -> (i, j)>, - affine_map<(i, j) -> (i)>], - iterator_types = ["parallel", "reduction"]} - ins(%A_sub : memref<2x4xf32, #strided_2d>) - outs(%B_sub : memref<2xf32, #strided_1d>) { - ^bb(%a: f32, %b: f32) : - %0 = arith.addf %a, %b: f32 - linalg.yield %0 : f32 - } - gml_st.yield - } - func.return -} - -// CHECK-LABEL: @loop_col_reduction - -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index - -// CHECK: scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[C10]]) step (%[[C2]]) -// CHECK-NEXT: scf.for %[[J:.*]] = %[[C0]] to %[[C8]] step %[[C4]] -// CHECK-NEXT: memref.subview %arg{{[0-9]+}}[%[[I]], %[[J]]] [2, 4] [1, 1] -// CHECK-SAME: : memref<10x8xf32> to memref<2x4xf32, #map{{[0-9]*}}> -// CHECK-NEXT: memref.subview %arg{{[0-9]+}}[%[[I]]] [2] [1] -// CHECK-SAME: : memref<10xf32> to memref<2xf32, #map{{[0-9]*}}> - -// ----- - func.func @for_with_result(%arg: vector<4xf32>) -> vector<4xf32> { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/tiling_cwise.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/gpu_tiling/tiling_cwise.mlir similarity index 70% rename from tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/tiling_cwise.mlir rename to tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/gpu_tiling/tiling_cwise.mlir index bda22ae481e..70172d03e2e 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/tiling_cwise.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/gpu_tiling/tiling_cwise.mlir @@ -42,21 +42,28 @@ func.func @cwise_expr(%a: tensor, %b: tensor, // CHECK-DAG: %[[C1024:.*]] = arith.constant 1024 // CHECK-DAG: %[[A_D0:.*]] = tensor.dim %[[A]], %[[C0]] // CHECK-DAG: %[[INIT:.*]] = tensor.empty(%[[A_D0]]) -// CHECK: %[[ABC:.*]] = gml_st.parallel +// CHECK: %[[ABC:.*]] = gml_st.parallel // CHECK-SAME: (%[[I:.*]], %[[J:.*]], %[[K:.*]]) = (%[[C0]], %[[C0]], %[[C0]]) // CHECK-SAME: to (%[[A_D0]], %[[C1024]], %[[C1024]]) // CHECK-SAME: step (%[[C1]], 
%[[C4]], %[[C8]]) +// CHECK-SAME: outs (%[[INIT_:.*]] = %[[INIT]]: // CHECK-SAME: distribution ("test") -// CHECK-DAG: %[[TILE:.*]] = gml_st.tile [%[[I]], %[[J]], %[[K]]] [1, 4, 8] [1, 1, 1] -// CHECK-DAG: %[[A_SUB:.*]] = gml_st.materialize %[[A]][%[[TILE]]] -// CHECK-DAG: %[[B_SUB:.*]] = gml_st.materialize %[[B]][%[[TILE]]] -// CHECK-DAG: %[[INIT_SUB:.*]] = gml_st.materialize %[[INIT]][%[[TILE]]] -// CHECK-DAG: %[[AB_SUB:.*]] = linalg.generic -// CHECK-SAME: ins(%[[A_SUB]], %[[B_SUB]] : tensor<1x4x8xf32>, tensor<1x4x8xf32>) +// CHECK-DAG: %[[A_SUB:.*]] = tensor.extract_slice %[[A]] +// CHECK-SAME: [%[[I]], %[[J]], %[[K]]] [1, 4, 8] [1, 1, 1] +// CHECK-DAG: %[[B_SUB:.*]] = tensor.extract_slice %[[B]] +// CHECK-SAME: [%[[I]], %[[J]], %[[K]]] [1, 4, 8] [1, 1, 1] +// CHECK-DAG: %[[INIT_SUB:.*]] = tensor.extract_slice %[[INIT]] +// CHECK-SAME: [%[[I]], %[[J]], %[[K]]] [1, 4, 8] [1, 1, 1] +// CHECK-DAG: %[[AB_SUB:.*]] = linalg.generic +// CHECK-SAME: ins(%[[A_SUB]], %[[B_SUB]] : tensor<1x4x8xf32>, tensor<1x4x8xf32>) // CHECK-SAME: outs(%[[INIT_SUB]] : tensor<1x4x8xf32>) -// CHECK-DAG: %[[C_SUB:.*]] = gml_st.materialize %[[C]][%[[TILE]]] +// CHECK-DAG: %[[C_SUB:.*]] = tensor.extract_slice %[[C]] +// CHECK-SAME: [%[[I]], %[[J]], %[[K]]] [1, 4, 8] [1, 1, 1] +// CHECK-DAG: %[[INIT_SUB_:.*]] = tensor.extract_slice %[[INIT_]] +// CHECK-SAME: [%[[I]], %[[J]], %[[K]]] [1, 4, 8] [1, 1, 1] // CHECK-DAG: %[[ABC_SUB:.*]] = linalg.generic -// CHECK-SAME: ins(%[[AB_SUB]], %[[C_SUB]] : tensor<1x4x8xf32>, tensor<1x4x8xf32>) -// CHECK-SAME: outs(%[[INIT_SUB]] : tensor<1x4x8xf32>) -// CHECK: gml_st.set_yield %[[ABC_SUB]] into %[[INIT]][%[[TILE]]] +// CHECK-SAME: ins(%[[AB_SUB]], %[[C_SUB]] : tensor<1x4x8xf32>, tensor<1x4x8xf32>) +// CHECK-SAME: outs(%[[INIT_SUB_]] : tensor<1x4x8xf32>) +// CHECK-DAG: %[[TILE:.*]] = gml_st.tile [%[[I]], %[[J]], %[[K]]] [1, 4, 8] [1, 1, 1] +// CHECK: gml_st.set_yield %[[ABC_SUB]] into %[[INIT_]][%[[TILE]]] // CHECK: return %[[ABC]] diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/gpu_tiling/tiling_gpu_warp.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/gpu_tiling/tiling_gpu_warp.mlir new file mode 100644 index 00000000000..6018b2b5d3a --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/gpu_tiling/tiling_gpu_warp.mlir @@ -0,0 +1,316 @@ +// RUN: mlir-hlo-opt %s --split-input-file --gml-tiling-gpu-warp | \ +// RUN: FileCheck %s + +// CHECK-LABEL: @tiling_warp_level_reduction +// CHECK-SAME: %[[ARG0:.*]]: tensor<7x13xf32> +func.func @tiling_warp_level_reduction(%arg0: tensor<7x13xf32>) + -> tensor<7xf32> { + // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index + // CHECK-DAG: %[[C13:.*]] = arith.constant 13 : index + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[C1024:.*]] = arith.constant 1024 : index + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[CST:.*]] = arith.constant 0xFF800000 : f32 + // CHECK: %[[EMPTY:.*]] = tensor.empty() + // CHECK: %[[PARALLEL:.*]] = gml_st.parallel (%[[ARG1:.*]]) = (%[[C0]]) to (%[[C1024]]) step (%[[C1]]) + // CHECK-SAME: outs (%[[EMPTY_:.*]] = %[[EMPTY]]: + // CHECK-SAME: distribution ("warp") + // CHECK: %[[MATERIALIZE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG1]], 0] [1, 13] [1, 1] + // CHECK: %[[MATERIALIZE_0:.*]] = tensor.extract_slice %[[EMPTY_]][%[[ARG1]]] [1] [1] + // CHECK: %[[FILL:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[MATERIALIZE_0]] : tensor<1xf32>) + // CHECK: %[[EMPTY_0:.*]] = tensor.empty() + // CHECK: %[[EXTRACTED:.*]] = 
tensor.extract %[[FILL]][%[[C0]]] + // CHECK: %[[PARALLEL_0:.*]] = gml_st.parallel (%[[ARG2:.*]]) = (%[[C0]]) to (%[[C16]]) step (%[[C1]]) + // CHECK-SAME: outs (%[[EMPTY_0_:.*]] = %[[EMPTY_0]] + // CHECK-SAME: distribution ("thread") + // CHECK: %[[MATERIALIZE_2:.*]] = tensor.extract_slice %[[EMPTY_0_]][0, %[[ARG2]]] [1, 1] [1, 1] + // CHECK: %[[FILL_1:.*]] = linalg.fill ins(%[[EXTRACTED]] : f32) outs(%[[MATERIALIZE_2]] : tensor<1x1xf32>) + // CHECK: %[[FOR:.*]] = gml_st.for (%[[ARG3:.*]]) = (%[[ARG2]]) to (%[[C13]]) step (%[[C16]]) outs (%[[ARG4:.*]] = %[[FILL_1]]: tensor<1x1xf32>) + // CHECK: %[[MATERIALIZE_3_:.*]] = tensor.extract_slice %[[MATERIALIZE]][0, %[[ARG3]]] [1, 1] [1, 1] : tensor<1x13xf32> to tensor<1x1xf32> + // CHECK: %[[MATERIALIZE_3:.*]] = tensor.extract %[[MATERIALIZE_3_]] + // CHECK: %[[MATERIALIZE_4:.*]] = tensor.extract %[[ARG4]] + // CHECK: %[[MAXF:.*]] = arith.maxf %[[MATERIALIZE_4]], %[[MATERIALIZE_3]] : f32 + // CHECK: %[[TILE_6_:.*]] = gml_st.tile [0, 0] [1, 1] [1, 1] + // CHECK: gml_st.set_yield %[[MAXF]] into %[[ARG4]][%[[TILE_6_]]] : f32 into tensor<1x1xf32>[!gml_st.tile<1x1>] + // CHECK: %[[TILE_3:.*]] = gml_st.tile [0, %[[ARG2]]] [1, 1] [1, 1] + // CHECK: gml_st.set_yield %[[FOR]] into %[[EMPTY_0_]][%[[TILE_3]]] + // CHECK: %[[REDUCE:.*]] = linalg.reduce { arith.maxf } + // CHECK: ins(%[[PARALLEL_0]] : tensor<1x16xf32>) + // CHECK: outs(%[[FILL]] : tensor<1xf32>) + // CHECK: dimensions = [1] + // CHECK: %[[TILE_0:.*]] = gml_st.tile [%[[ARG1]]] [1] [1] + // CHECK: gml_st.set_yield %[[REDUCE]] into %[[EMPTY_]][%[[TILE_0]]] + // CHECK: return %[[PARALLEL]] + %c1 = arith.constant 1 : index + %c1024 = arith.constant 1024 : index + %c0 = arith.constant 0 : index + %cst = arith.constant 0xFF800000 : f32 + %0 = tensor.empty() : tensor<7xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<7xf32>) + -> tensor<7xf32> + %2 = gml_st.parallel (%arg1) = (%c0) to (%c1024) step (%c1) + outs (%out_ = %1: tensor<7xf32>) + distribution ("warp") { + %4 = tensor.extract_slice %arg0[%arg1, 0] [1, 13] [1, 1] + : tensor<7x13xf32> to tensor<1x13xf32> + %6 = tensor.extract_slice %out_[%arg1] [1] [1] + : tensor<7xf32> to tensor<1xf32> + %7 = linalg.reduce { arith.maxf } + ins(%4 : tensor<1x13xf32>) + outs(%6 : tensor<1xf32>) + dimensions = [1] + %5 = gml_st.tile [%arg1] [1] [1] : !gml_st.tile<1> + gml_st.set_yield %7 into %out_[%5] + : tensor<1xf32> into tensor<7xf32>[!gml_st.tile<1>] + } : tensor<7xf32> + return %2 : tensor<7xf32> +} + +// ----- + +// CHECK-LABEL: @tiling_warp_level_cwise +// CHECK-SAME: %[[ARG0:.*]]: tensor<7x13xf32>, %[[ARG1:.*]]: tensor<7x13xf32> +func.func @tiling_warp_level_cwise(%arg0: tensor<7x13xf32>, + %arg1: tensor<7x13xf32>) -> tensor<7x13xf32> { + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 + // CHECK-DAG: %[[C16:.*]] = arith.constant 16 + // CHECK-DAG: %[[C1024:.*]] = arith.constant 1024 + // CHECK-DAG: %[[C28:.*]] = arith.constant 28 + // CHECK-DAG: %[[EMPTY:.*]] = tensor.empty() : tensor<7x13xf32> + // CHECK: %[[PARALLEL:.*]] = gml_st.parallel + // CHECK-SAME: (%[[ARG2:.*]]) = (%[[C0]]) to (%[[C1024]]) + // CHECK-SAME: step (%[[C1]]) outs (%[[EMPTY_:.*]] = %[[EMPTY]]: + // CHECK-SAME: distribution ("warp") + // CHECK: %[[MATERIALIZE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], 0] [1, 13] [1, 1] + // CHECK: %[[MATERIALIZE_0:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG2]], 0] [1, 13] [1, 1] + // CHECK: %[[MATERIALIZE_1:.*]] = tensor.extract_slice %[[EMPTY_]][%[[ARG2]], 0] [1, 13] [1, 1] + // 
CHECK: %[[PARALLEL_0:.*]] = gml_st.parallel + // CHECK-SAME: (%[[ARG3:.*]]) = (%[[C0]]) to (%[[C16]]) + // CHECK-SAME: step (%[[C1]]) outs (%[[MATERIALIZE_1_:.*]] = %[[MATERIALIZE_1]]: + // CHECK-SAME: distribution ("thread") + // CHECK: %[[SUBI:.*]] = arith.subi %[[C28]], %[[ARG3]] + // CHECK: %[[DIVUI:.*]] = arith.divui %[[SUBI]], %[[C16]] + // CHECK: %[[MATERIALIZE_2:.*]] = tensor.extract_slice %[[MATERIALIZE_1_]][0, %[[ARG3]]] [1, %[[DIVUI]]] [1, 16] + // CHECK: %[[FOR:.*]] = gml_st.for (%[[ARG4:.*]]) = (%[[C0]]) + // CHECK-SAME: to (%[[DIVUI]]) step (%[[C1]]) + // CHECK-SAME: outs (%[[ARG5:.*]] = %[[MATERIALIZE_2]]: tensor<1x?xf32>) + // CHECK: %[[MULI:.*]] = arith.muli %[[ARG4]], %[[C16]] : index + // CHECK: %[[ADDI:.*]] = arith.addi %[[ARG3]], %[[MULI]] : index + // CHECK: %[[MATERIALIZE_3_:.*]] = tensor.extract_slice %[[MATERIALIZE]][0, %[[ADDI]]] [1, 1] [1, 1] + // CHECK: %[[MATERIALIZE_3:.*]] = tensor.extract %[[MATERIALIZE_3_]] + // CHECK: %[[MATERIALIZE_4_:.*]] = tensor.extract_slice %[[MATERIALIZE_0]][0, %[[ADDI]]] [1, 1] [1, 1] + // CHECK: %[[MATERIALIZE_4:.*]] = tensor.extract %[[MATERIALIZE_4_]] + // CHECK: %[[SUBF:.*]] = arith.subf %[[MATERIALIZE_3]], %[[MATERIALIZE_4]] + // CHECK: %[[TILE_2:.*]] = gml_st.tile [0, %[[ARG4]]] [1, 1] [1, 1] + // CHECK: gml_st.set_yield %[[SUBF]] into %[[ARG5]][%[[TILE_2]]] + // CHECK: %[[TILE_0:.*]] = gml_st.tile [0, %[[ARG3]]] [1, %[[DIVUI]]] [1, 16] + // CHECK: gml_st.set_yield %[[FOR]] into %[[MATERIALIZE_1_]][%[[TILE_0]]] + // CHECK: %[[TILE:.*]] = gml_st.tile [%[[ARG2]], 0] [1, 13] [1, 1] + // CHECK: gml_st.set_yield %[[PARALLEL_0]] into %[[EMPTY_]][%[[TILE]]] + // CHECK: return %[[PARALLEL]] + %c1 = arith.constant 1 : index + %c1024 = arith.constant 1024 : index + %c0 = arith.constant 0 : index + %0 = tensor.empty() : tensor<7x13xf32> + %1 = gml_st.parallel (%arg2) = (%c0) to (%c1024) step (%c1) + outs (%out_ = %0: tensor<7x13xf32>) distribution ("warp") { + %3 = tensor.extract_slice %arg0 [%arg2, 0] [1, 13] [1, 1] + : tensor<7x13xf32> to tensor<1x13xf32> + %4 = tensor.extract_slice %arg1 [%arg2, 0] [1, 13] [1, 1] + : tensor<7x13xf32> to tensor<1x13xf32> + %5 = tensor.extract_slice %out_ [%arg2, 0] [1, 13] [1, 1] + : tensor<7x13xf32> to tensor<1x13xf32> + %6 = linalg.map { arith.subf } + ins(%3, %4 : tensor<1x13xf32>, tensor<1x13xf32>) + outs(%5 : tensor<1x13xf32>) + %2 = gml_st.tile [%arg2, 0] [1, 13] [1, 1] : !gml_st.tile<1x13> + gml_st.set_yield %6 into %out_[%2] + : tensor<1x13xf32> into tensor<7x13xf32>[!gml_st.tile<1x13>] + } : tensor<7x13xf32> + return %1 : tensor<7x13xf32> +} + +// ----- + +// CHECK-LABEL: @softmax +// CHECK-SAME: %[[ARG0:.*]]: tensor<2048x4096xf32> +func.func @softmax(%arg0: tensor<2048x4096xf32>) -> tensor<2048x4096xf32> { + // CHECK-DAG: %[[C4096:.*]] = arith.constant 4096 : index + // CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index + // CHECK-DAG: %[[C4127:.*]] = arith.constant 4127 : index + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[C1024:.*]] = arith.constant 1024 : index + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[C2048:.*]] = arith.constant 2048 : index + // CHECK-DAG: %[[CST:.*]] = arith.constant -0.000000e+00 : f32 + // CHECK-DAG: %[[CST_0:.*]] = arith.constant 0xFF800000 : f32 + // CHECK-DAG: %[[EMPTY:.*]] = tensor.empty() : tensor<2048xf32> + // CHECK-DAG: %[[EMPTY_0:.*]] = tensor.empty() : tensor<2048x4096xf32> + // CHECK: %[[PARALLEL:.*]] = gml_st.parallel (%[[ARG1:.*]]) = (%[[C0]]) to (%[[C2048]]) step (%[[C1024]]) + // CHECK-SAME: outs 
(%[[BLOCK_OUT_:.*]] = %[[EMPTY_0]]: + // CHECK-SAME: distribution ("block") + // CHECK: %[[MATERIALIZE:.*]] = tensor.extract_slice %[[BLOCK_OUT_]][%[[ARG1]], 0] [1024, 4096] [1, 1] + + // CHECK: %[[PARALLEL_0:.*]] = gml_st.parallel (%[[ARG2:.*]]) = (%[[C0]]) to (%[[C1024]]) step (%[[C1]]) + // CHECK-SAME: outs (%[[WARP_OUT_:.*]] = %[[MATERIALIZE]]: + // CHECK-SAME: distribution ("warp") + // CHECK: %[[ADDI:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : index + // CHECK: %[[MATERIALIZE_0:.*]] = tensor.extract_slice %[[ARG0]][%[[ADDI]], 0] [1, 4096] [1, 1] + // CHECK: %[[MATERIALIZE_1:.*]] = tensor.extract_slice %[[EMPTY]][%[[ADDI]]] [1] [1] + // CHECK: %[[FILL:.*]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[MATERIALIZE_1]] : tensor<1xf32>) + // CHECK: %[[EMPTY_1:.*]] = tensor.empty() + // CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[FILL]][%[[C0]]] + // CHECK: %[[PARALLEL_1:.*]] = gml_st.parallel (%[[ARG3:.*]]) = (%[[C0]]) to (%[[C32]]) step (%[[C1]]) + // CHECK-SAME: outs (%[[THREAD_OUT0_:.*]] = %[[EMPTY_1]] + // CHECK-SAME: distribution ("thread") + // CHECK: %[[MATERIALIZE_3:.*]] = tensor.extract_slice %[[THREAD_OUT0_]][0, %[[ARG3]]] [1, 1] [1, 1] + // CHECK: %[[FILL_1:.*]] = linalg.fill ins(%[[EXTRACTED]] : f32) outs(%[[MATERIALIZE_3]] : tensor<1x1xf32>) + // CHECK: %[[FOR:.*]] = gml_st.for (%[[ARG4:.*]]) = (%[[ARG3]]) to (%[[C4096]]) step (%[[C32]]) outs (%[[ARG5:.*]] = %[[FILL_1]]: tensor<1x1xf32>) + // CHECK: %[[MATERIALIZE_4_:.*]] = tensor.extract_slice %[[MATERIALIZE_0]][0, %[[ARG4]]] [1, 1] [1, 1] : tensor<1x4096xf32> to tensor<1x1xf32> + // CHECK: %[[MATERIALIZE_4:.*]] = tensor.extract %[[MATERIALIZE_4_]] + // CHECK: %[[MATERIALIZE_5:.*]] = tensor.extract %[[ARG5]] + // CHECK: %[[MAXF:.*]] = arith.maxf %[[MATERIALIZE_5]], %[[MATERIALIZE_4]] : f32 + // CHECK: %[[TILE_7_:.*]] = gml_st.tile [0, 0] [1, 1] [1, 1] + // CHECK: gml_st.set_yield %[[MAXF]] into %[[ARG5]][%[[TILE_7_]]] : f32 into tensor<1x1xf32>[!gml_st.tile<1x1>] + // CHECK: %[[TILE_4:.*]] = gml_st.tile [0, %[[ARG3]]] [1, 1] [1, 1] + // CHECK: gml_st.set_yield %[[FOR]] into %[[THREAD_OUT0_]][%[[TILE_4]]] + // CHECK: %[[REDUCE:.*]] = linalg.reduce { arith.maxf } + // CHECK: ins(%[[PARALLEL_1]] : tensor<1x32xf32>) + // CHECK: outs(%[[FILL]] : tensor<1xf32>) + // CHECK: dimensions = [1] + // CHECK: %[[MATERIALIZE_6:.*]] = tensor.extract_slice %[[EMPTY_0]][%[[ADDI]], 0] [1, 4096] [1, 1] + // CHECK: %[[MATERIALIZE_7:.*]] = tensor.extract_slice %[[EMPTY]][%[[ADDI]]] [1] [1] + // CHECK: %[[FILL_2:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[MATERIALIZE_7]] : tensor<1xf32>) + // CHECK: %[[EMPTY_2:.*]] = tensor.empty() + // CHECK: %[[EXTRACTED_1:.*]] = tensor.extract %[[FILL_2]][%[[C0]]] + // CHECK: %[[PARALLEL_2:.*]] = gml_st.parallel (%[[ARG3]]) = (%[[C0]]) to (%[[C32]]) step (%[[C1]]) + // CHECK-SAME: outs (%[[THREAD_OUT1_:.*]] = %[[EMPTY_2]] + // CHECK-SAME: distribution ("thread") + // CHECK: %[[MATERIALIZE_9:.*]] = tensor.extract_slice %[[THREAD_OUT1_]][0, %[[ARG3]]] [1, 1] [1, 1] + // CHECK: %[[FILL_4:.*]] = linalg.fill ins(%[[EXTRACTED_1]] : f32) outs(%[[MATERIALIZE_9]] : tensor<1x1xf32>) + // CHECK: %[[FOR_0:.*]] = gml_st.for (%[[ARG4_0:.*]]) = (%[[ARG3]]) to (%[[C4096]]) step (%[[C32]]) outs (%[[ARG5_0:.*]] = %[[FILL_4]]: tensor<1x1xf32>) + // CHECK: %[[MATERIALIZE_10:.*]] = tensor.extract_slice %[[MATERIALIZE_0]][0, %[[ARG4_0]]] [1, 1] [1, 1] + // CHECK: %[[MATERIALIZE_12:.*]] = tensor.extract_slice %[[MATERIALIZE_6]][0, %[[ARG4_0]]] [1, 1] [1, 1] + // CHECK: %[[BROADCAST:.*]] = linalg.broadcast + // CHECK: ins(%[[REDUCE]] 
: tensor<1xf32>) + // CHECK: outs(%[[MATERIALIZE_12]] : tensor<1x1xf32>) + // CHECK: dimensions = [1] + // CHECK: %[[MATERIALIZE_13:.*]] = tensor.extract_slice %[[MATERIALIZE_6]][0, %[[ARG4_0]]] [1, 1] [1, 1] + // CHECK: %[[MAP:.*]] = linalg.map { arith.subf } + // CHECK: ins(%[[MATERIALIZE_10]], %[[BROADCAST]] : tensor<1x1xf32>, tensor<1x1xf32>) + // CHECK: outs(%[[MATERIALIZE_13]] : tensor<1x1xf32>) + // CHECK: %[[MATERIALIZE_14:.*]] = tensor.extract_slice %[[MATERIALIZE_6]][0, %[[ARG4_0]]] [1, 1] [1, 1] + // CHECK: %[[MAP_0:.*]] = linalg.map { math.exp } + // CHECK: ins(%[[MAP]] : tensor<1x1xf32>) + // CHECK: outs(%[[MATERIALIZE_14]] : tensor<1x1xf32>) + // CHECK: %[[EXTRACTED_2:.*]] = tensor.extract %[[MAP_0]][%[[C0]], %[[C0]]] + // CHECK: %[[MATERIALIZE_15:.*]] = tensor.extract %[[ARG5_0]] + + // CHECK: %[[ADDF:.*]] = arith.addf %[[MATERIALIZE_15]], %[[EXTRACTED_2]] : f32 + // CHECK: %[[TILE_17_:.*]] = gml_st.tile [0, 0] [1, 1] [1, 1] + // CHECK: gml_st.set_yield %[[ADDF]] into %[[ARG5_0]][%[[TILE_17_]]] : f32 into tensor<1x1xf32>[!gml_st.tile<1x1>] + // CHECK: %[[TILE_10:.*]] = gml_st.tile [0, %[[ARG3]]] [1, 1] [1, 1] + // CHECK: gml_st.set_yield %[[FOR_0]] into %[[THREAD_OUT1_]][%[[TILE_10]]] + // CHECK: %[[REDUCE_0:.*]] = linalg.reduce { arith.addf } + // CHECK: ins(%[[PARALLEL_2]] : tensor<1x32xf32>) + // CHECK: outs(%[[FILL_2]] : tensor<1xf32>) + // CHECK: dimensions = [1] + // CHECK: %[[PARALLEL_3:.*]] = gml_st.parallel (%[[ARG3]]) = (%[[C0]]) to (%[[C32]]) step (%[[C1]]) + // CHECK-SAME: outs (%[[THREAD_OUT2_:.*]] = %[[MATERIALIZE_6]]: + // CHECK-SAME: distribution ("thread") + // CHECK: %[[SUBI:.*]] = arith.subi %[[C4127]], %[[ARG3]] : index + // CHECK: %[[DIVUI:.*]] = arith.divui %[[SUBI]], %[[C32]] : index + // CHECK: %[[MATERIALIZE_16:.*]] = tensor.extract_slice %[[THREAD_OUT2_]][0, %[[ARG3]]] [1, %[[DIVUI]]] [1, 32] + // CHECK: %[[FOR_1:.*]] = gml_st.for (%[[ARG4_1:.*]]) = (%[[C0]]) to (%[[DIVUI]]) step (%[[C1]]) outs (%[[ARG5_1:.*]] = %[[MATERIALIZE_16]]: tensor<1x?xf32>) + // CHECK: %[[MULI:.*]] = arith.muli %[[ARG4_1]], %[[C32]] : index + // CHECK: %[[ADDI_0:.*]] = arith.addi %[[ARG3]], %[[MULI]] : index + // CHECK: %[[MATERIALIZE_17:.*]] = tensor.extract_slice %[[MATERIALIZE_0]][0, %[[ADDI_0]]] [1, 1] [1, 1] + // CHECK: %[[MATERIALIZE_19:.*]] = tensor.extract_slice %[[MATERIALIZE_6]][0, %[[ADDI_0]]] [1, 1] [1, 1] + // CHECK: %[[BROADCAST_0:.*]] = linalg.broadcast + // CHECK: ins(%[[REDUCE]] : tensor<1xf32>) + // CHECK: outs(%[[MATERIALIZE_19]] : tensor<1x1xf32>) + // CHECK: dimensions = [1] + // CHECK: %[[MATERIALIZE_20:.*]] = tensor.extract_slice %[[MATERIALIZE_6]][0, %[[ADDI_0]]] [1, 1] [1, 1] + // CHECK: %[[MAP_1:.*]] = linalg.map { arith.subf } + // CHECK: ins(%[[MATERIALIZE_17]], %[[BROADCAST_0]] : tensor<1x1xf32>, tensor<1x1xf32>) + // CHECK: outs(%[[MATERIALIZE_20]] : tensor<1x1xf32>) + // CHECK: %[[MATERIALIZE_21:.*]] = tensor.extract_slice %[[MATERIALIZE_6]][0, %[[ADDI_0]]] [1, 1] [1, 1] + // CHECK: %[[MAP_2:.*]] = linalg.map { math.exp } + // CHECK: ins(%[[MAP_1]] : tensor<1x1xf32>) + // CHECK: outs(%[[MATERIALIZE_21]] : tensor<1x1xf32>) + // CHECK: %[[EXTRACTED_2_0:.*]] = tensor.extract %[[MAP_2]][%[[C0]], %[[C0]]] + // CHECK: %[[MATERIALIZE_23:.*]] = tensor.extract_slice %[[MATERIALIZE_6]][0, %[[ADDI_0]]] [1, 1] [1, 1] + // CHECK: %[[BROADCAST_1:.*]] = linalg.broadcast + // CHECK: ins(%[[REDUCE_0]] : tensor<1xf32>) + // CHECK: outs(%[[MATERIALIZE_23]] : tensor<1x1xf32>) + // CHECK: dimensions = [1] + // CHECK: %[[EXTRACTED_3:.*]] = tensor.extract 
%[[BROADCAST_1]][%[[C0]], %[[C0]]] + // CHECK: %[[DIVF:.*]] = arith.divf %[[EXTRACTED_2_0]], %[[EXTRACTED_3]] : f32 + // CHECK: %[[TILE_26:.*]] = gml_st.tile [0, %[[ARG4_1]]] [1, 1] [1, 1] + // CHECK: gml_st.set_yield %[[DIVF]] into %[[ARG5_1]][%[[TILE_26]]] : f32 into tensor<1x?xf32>[!gml_st.tile<1x1>] + // CHECK: %[[TILE_18_:.*]] = gml_st.tile [0, %[[ARG3]]] [1, %[[DIVUI]]] [1, 32] + // CHECK: gml_st.set_yield %[[FOR_1]] into %[[THREAD_OUT2_]][%[[TILE_18_]]] + // CHECK: %[[TILE_0:.*]] = gml_st.tile [%[[ARG2]], 0] [1, 4096] [1, 1] + // CHECK: gml_st.set_yield %[[PARALLEL_3]] into %[[WARP_OUT_]][%[[TILE_0]]] + // CHECK: %[[TILE:.*]] = gml_st.tile [%[[ARG1]], 0] [1024, 4096] [1, 1] + // CHECK: gml_st.set_yield %[[PARALLEL_0]] into %[[BLOCK_OUT_]][%[[TILE]]] + // CHECK: return %[[PARALLEL]] + %c1 = arith.constant 1 : index + %c1024 = arith.constant 1024 : index + %c0 = arith.constant 0 : index + %c2048 = arith.constant 2048 : index + %cst = arith.constant -0.000000e+00 : f32 + %cst_0 = arith.constant 0xFF800000 : f32 + %0 = tensor.empty() : tensor<2048xf32> + %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<2048xf32>) + -> tensor<2048xf32> + %2 = tensor.empty() : tensor<2048x4096xf32> + %3 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2048xf32>) + -> tensor<2048xf32> + %4 = gml_st.parallel (%arg1) = (%c0) to (%c2048) step (%c1024) + outs (%block_out_ = %2: tensor<2048x4096xf32>) distribution ("block") { + %6 = tensor.extract_slice %block_out_[%arg1, 0] [1024, 4096] [1, 1] + : tensor<2048x4096xf32> to tensor<1024x4096xf32> + %7 = gml_st.parallel (%arg2) = (%c0) to (%c1024) step (%c1) + outs (%warp_out_ = %6: tensor<1024x4096xf32>) distribution ("warp") { + %9 = arith.addi %arg1, %arg2 : index + %11 = tensor.extract_slice %arg0[%9, 0] [1, 4096] [1, 1] + : tensor<2048x4096xf32> to tensor<1x4096xf32> + %13 = tensor.extract_slice %1[%9] [1] [1] + : tensor<2048xf32> to tensor<1xf32> + %14 = linalg.reduce { arith.maxf } + ins(%11 : tensor<1x4096xf32>) + outs(%13 : tensor<1xf32>) dimensions = [1] + %15 = tensor.extract_slice %2[%9, 0] [1, 4096] [1, 1] + : tensor<2048x4096xf32> to tensor<1x4096xf32> + %16 = linalg.broadcast + ins(%14 : tensor<1xf32>) outs(%15 : tensor<1x4096xf32>) + dimensions = [1] + %17 = linalg.map { arith.subf } + ins(%11, %16 : tensor<1x4096xf32>, tensor<1x4096xf32>) + outs(%15 : tensor<1x4096xf32>) + %18 = linalg.map { math.exp } + ins(%17 : tensor<1x4096xf32>) + outs(%15 : tensor<1x4096xf32>) + %19 = tensor.extract_slice %3[%9] [1] [1] + : tensor<2048xf32> to tensor<1xf32> + %20 = linalg.reduce { arith.addf } + ins(%18 : tensor<1x4096xf32>) + outs(%19 : tensor<1xf32>) dimensions = [1] + %21 = linalg.broadcast + ins(%20 : tensor<1xf32>) outs(%15 : tensor<1x4096xf32>) + dimensions = [1] + %22 = linalg.map { arith.divf } + ins(%18, %21 : tensor<1x4096xf32>, tensor<1x4096xf32>) + outs(%15 : tensor<1x4096xf32>) + %8 = gml_st.tile [%arg2, 0] [1, 4096] [1, 1] : !gml_st.tile<1x4096> + gml_st.set_yield %22 into %warp_out_[%8] + : tensor<1x4096xf32> into tensor<1024x4096xf32>[!gml_st.tile<1x4096>] + } : tensor<1024x4096xf32> + %5 = gml_st.tile [%arg1, 0] [1024, 4096] [1, 1] : !gml_st.tile<1024x4096> + gml_st.set_yield %7 into %block_out_[%5] + : tensor<1024x4096xf32> into tensor<2048x4096xf32>[!gml_st.tile<1024x4096>] + } : tensor<2048x4096xf32> + return %4 : tensor<2048x4096xf32> +} diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/greedy_fusion.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/greedy_fusion.mlir new file mode 100644 index 
00000000000..c269727ed58 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/greedy_fusion.mlir @@ -0,0 +1,131 @@ +// RUN: mlir-hlo-opt %s --split-input-file \ +// RUN: --gml-tiling="tile-sizes=2 op-label=root" --test-gml-st-greedy-fusion | \ +// RUN: FileCheck %s + +// CHECK-LABEL: func @fuse_broadcast_map +// CHECK-SAME: (%[[ARG0:.*]]: tensor<16xf32>, %[[ARG1:.*]]: tensor<16x32xf32>) +func.func @fuse_broadcast_map(%arg0: tensor<16xf32>, %arg1: tensor<16x32xf32>) + -> tensor<16x32xf32> { + %init = tensor.empty() : tensor<16x32xf32> + %bcast = linalg.broadcast + ins(%arg0 : tensor<16xf32>) + outs(%init : tensor<16x32xf32>) + dimensions = [1] + + %result = linalg.map { arith.addf } + ins(%bcast, %arg1 : tensor<16x32xf32>, tensor<16x32xf32>) + outs(%init : tensor<16x32xf32>) + { op_label = "root" } + func.return %result : tensor<16x32xf32> +} + +// CHECK: %[[INIT:.*]] = tensor.empty() +// CHECK: %[[RESULT:.*]] = gml_st.parallel +// CHECK-SAME: outs (%[[INIT_:.*]] = %[[INIT]]: +// CHECK-DAG: %[[INIT_SLICE:.*]] = tensor.extract_slice %[[INIT]] +// CHECK-DAG: %[[ARG0_SLICE:.*]] = tensor.extract_slice %[[ARG0]] +// CHECK: %[[BCAST:.*]] = linalg.broadcast +// CHECK-SAME: ins(%[[ARG0_SLICE]] +// CHECK-SAME: outs(%[[INIT_SLICE]] +// CHECK: %[[ARG1_SLICE:.*]] = tensor.extract_slice %[[ARG1]] +// CHECK-DAG: %[[INIT_SLICE_:.*]] = tensor.extract_slice %[[INIT_]] +// CHECK: %[[MAPPED:.*]] = linalg.map +// CHECK-SAME: ins(%[[BCAST]], %[[ARG1_SLICE]] +// CHECK-SAME: outs(%[[INIT_SLICE_]] +// CHECK: gml_st.set_yield %[[MAPPED]] +// CHECK: return %[[RESULT]] + +// ----- + +// CHECK-LABEL: func @do_not_fuse_map_reduce +// CHECK-SAME: (%[[ARG0:.*]]: tensor<16x32xf32>, %[[ARG1:.*]]: tensor<16xf32>) +func.func @do_not_fuse_map_reduce(%arg0: tensor<16x32xf32>, %arg1: tensor<16xf32>) + -> tensor<16xf32> { + %init = tensor.empty() : tensor<16xf32> + %reduce = linalg.reduce { arith.addf } + ins(%arg0 : tensor<16x32xf32>) + outs(%init : tensor<16xf32>) + dimensions = [1] + + %result = linalg.map { arith.addf } + ins(%reduce, %arg1 : tensor<16xf32>, tensor<16xf32>) + outs(%init : tensor<16xf32>) + { op_label = "root" } + func.return %result : tensor<16xf32> +} + +// CHECK: %[[INIT:.*]] = tensor.empty() +// CHECK: %[[REDUCE:.*]] = linalg.reduce +// CHECK: %[[RESULT:.*]] = gml_st.parallel +// CHECK-SAME: outs (%[[INIT_:.*]] = %[[INIT]]: +// CHECK-DAG: %[[REDUCE_SLICE:.*]] = tensor.extract_slice %[[REDUCE]] +// CHECK-DAG: %[[ARG1_SLICE:.*]] = tensor.extract_slice %[[ARG1]] +// CHECK-DAG: %[[INIT_SLICE:.*]] = tensor.extract_slice %[[INIT_]] +// CHECK: %[[MAPPED:.*]] = linalg.map +// CHECK-SAME: ins(%[[REDUCE_SLICE]], %[[ARG1_SLICE]] +// CHECK-SAME: outs(%[[INIT_SLICE]] +// CHECK: gml_st.set_yield %[[MAPPED]] +// CHECK: return %[[RESULT]] + +// ----- + +// Only basic checks that all maps and fills were fused into gml_st.parallel. +// This test verifies that ops are fused in the correct order. If something is +// broken, the test will take exponential time and/or memory to finish.
+// CHECK-LABEL: func @fuse_fibonacci +// CHECK-NOT: linalg.fill +// CHECK-NOT: linalg.map +// CHECK: gml_st.parallel +// CHECK-COUNT-2: linalg.fill +// CHECK-COUNT-38: linalg.map +// CHECK-NOT: linalg.fill +// CHECK-NOT: linalg.map +// CHECK: gml_st.set_yield +// CHECK: return +func.func @fuse_fibonacci(%init : tensor) -> tensor { + %c0 = arith.constant 0 : i64 + %c1 = arith.constant 1 : i64 + + %0 = linalg.fill ins(%c0 : i64) outs(%init : tensor) -> tensor + %1 = linalg.fill ins(%c1 : i64) outs(%init : tensor) -> tensor + %2 = linalg.map { arith.addi } ins(%0, %1 : tensor, tensor) outs(%init : tensor) + %3 = linalg.map { arith.addi } ins(%1, %2 : tensor, tensor) outs(%init : tensor) + %4 = linalg.map { arith.addi } ins(%2, %3 : tensor, tensor) outs(%init : tensor) + %5 = linalg.map { arith.addi } ins(%3, %4 : tensor, tensor) outs(%init : tensor) + %6 = linalg.map { arith.addi } ins(%4, %5 : tensor, tensor) outs(%init : tensor) + %7 = linalg.map { arith.addi } ins(%5, %6 : tensor, tensor) outs(%init : tensor) + %8 = linalg.map { arith.addi } ins(%6, %7 : tensor, tensor) outs(%init : tensor) + %9 = linalg.map { arith.addi } ins(%7, %8 : tensor, tensor) outs(%init : tensor) + %10 = linalg.map { arith.addi } ins(%8, %9 : tensor, tensor) outs(%init : tensor) + %11 = linalg.map { arith.addi } ins(%9, %10 : tensor, tensor) outs(%init : tensor) + %12 = linalg.map { arith.addi } ins(%10, %11 : tensor, tensor) outs(%init : tensor) + %13 = linalg.map { arith.addi } ins(%11, %12 : tensor, tensor) outs(%init : tensor) + %14 = linalg.map { arith.addi } ins(%12, %13 : tensor, tensor) outs(%init : tensor) + %15 = linalg.map { arith.addi } ins(%13, %14 : tensor, tensor) outs(%init : tensor) + %16 = linalg.map { arith.addi } ins(%14, %15 : tensor, tensor) outs(%init : tensor) + %17 = linalg.map { arith.addi } ins(%15, %16 : tensor, tensor) outs(%init : tensor) + %18 = linalg.map { arith.addi } ins(%16, %17 : tensor, tensor) outs(%init : tensor) + %19 = linalg.map { arith.addi } ins(%17, %18 : tensor, tensor) outs(%init : tensor) + %20 = linalg.map { arith.addi } ins(%18, %19 : tensor, tensor) outs(%init : tensor) + %21 = linalg.map { arith.addi } ins(%19, %20 : tensor, tensor) outs(%init : tensor) + %22 = linalg.map { arith.addi } ins(%20, %21 : tensor, tensor) outs(%init : tensor) + %23 = linalg.map { arith.addi } ins(%21, %22 : tensor, tensor) outs(%init : tensor) + %24 = linalg.map { arith.addi } ins(%22, %23 : tensor, tensor) outs(%init : tensor) + %25 = linalg.map { arith.addi } ins(%23, %24 : tensor, tensor) outs(%init : tensor) + %26 = linalg.map { arith.addi } ins(%24, %25 : tensor, tensor) outs(%init : tensor) + %27 = linalg.map { arith.addi } ins(%25, %26 : tensor, tensor) outs(%init : tensor) + %28 = linalg.map { arith.addi } ins(%26, %27 : tensor, tensor) outs(%init : tensor) + %29 = linalg.map { arith.addi } ins(%27, %28 : tensor, tensor) outs(%init : tensor) + %30 = linalg.map { arith.addi } ins(%28, %29 : tensor, tensor) outs(%init : tensor) + %31 = linalg.map { arith.addi } ins(%29, %30 : tensor, tensor) outs(%init : tensor) + %32 = linalg.map { arith.addi } ins(%30, %31 : tensor, tensor) outs(%init : tensor) + %33 = linalg.map { arith.addi } ins(%31, %32 : tensor, tensor) outs(%init : tensor) + %34 = linalg.map { arith.addi } ins(%32, %33 : tensor, tensor) outs(%init : tensor) + %35 = linalg.map { arith.addi } ins(%33, %34 : tensor, tensor) outs(%init : tensor) + %36 = linalg.map { arith.addi } ins(%34, %35 : tensor, tensor) outs(%init : tensor) + %37 = linalg.map { arith.addi } ins(%35, 
%36 : tensor, tensor) outs(%init : tensor) + %38 = linalg.map { arith.addi } ins(%36, %37 : tensor, tensor) outs(%init : tensor) + %39 = linalg.map { arith.addi } ins(%37, %38 : tensor, tensor) outs(%init : tensor) + { op_label = "root" } + func.return %39 : tensor +} diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/greedy_tiling_and_fusion.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/greedy_tiling_and_fusion.mlir new file mode 100644 index 00000000000..ae763cd9fdd --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/greedy_tiling_and_fusion.mlir @@ -0,0 +1,143 @@ +// RUN: mlir-hlo-opt %s --split-input-file \ +// RUN: --gml-greedy-fusion="tile-sizes=8,16 distribute=true distribution-label=test" \ +// RUN: --canonicalize --cse | \ +// RUN: FileCheck %s + +#map0 = affine_map<(d0, d1) -> (d0, d1)> +#map1 = affine_map<(d0, d1) -> (d0)> + +func.func @softmax(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { + %cst = arith.constant -0.000000e+00 : f32 + %cst_0 = arith.constant 0xFF800000 : f32 + %0 = tensor.empty() : tensor<64xf32> + %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<64xf32>) -> tensor<64xf32> + %2 = linalg.generic {indexing_maps = [#map0, #map1], + iterator_types = ["parallel", "reduction"]} + ins(%arg0 : tensor<64x128xf32>) outs(%1 : tensor<64xf32>) { + ^bb0(%arg1: f32, %arg2: f32): + %11 = arith.maxf %arg2, %arg1 : f32 + linalg.yield %11 : f32 + } -> tensor<64xf32> + %3 = tensor.empty() : tensor<64x128xf32> + %4 = linalg.generic {indexing_maps = [#map1, #map0], + iterator_types = ["parallel", "parallel"]} ins(%2 : tensor<64xf32>) + outs(%3 : tensor<64x128xf32>) { + ^bb0(%arg1: f32, %arg2: f32): + linalg.yield %arg1 : f32 + } -> tensor<64x128xf32> + %5 = linalg.generic {indexing_maps = [#map0, #map0, #map0], + iterator_types = ["parallel", "parallel"]} + ins(%arg0, %4 : tensor<64x128xf32>, tensor<64x128xf32>) + outs(%3 : tensor<64x128xf32>) { + ^bb0(%arg1: f32, %arg2: f32, %arg3: f32): + %11 = arith.subf %arg1, %arg2 : f32 + linalg.yield %11 : f32 + } -> tensor<64x128xf32> + %6 = linalg.generic {indexing_maps = [#map0, #map0], + iterator_types = ["parallel", "parallel"]} ins(%5 : tensor<64x128xf32>) + outs(%3 : tensor<64x128xf32>) { + ^bb0(%arg1: f32, %arg2: f32): + %11 = math.exp %arg1 : f32 + linalg.yield %11 : f32 + } -> tensor<64x128xf32> + %7 = linalg.fill ins(%cst : f32) outs(%0 : tensor<64xf32>) -> tensor<64xf32> + %8 = linalg.generic {indexing_maps = [#map0, #map1], + iterator_types = ["parallel", "reduction"]} ins(%6 : tensor<64x128xf32>) + outs(%7 : tensor<64xf32>) { + ^bb0(%arg1: f32, %arg2: f32): + %11 = arith.addf %arg2, %arg1 : f32 + linalg.yield %11 : f32 + } -> tensor<64xf32> + %9 = linalg.generic {indexing_maps = [#map1, #map0], + iterator_types = ["parallel", "parallel"]} ins(%8 : tensor<64xf32>) + outs(%3 : tensor<64x128xf32>) { + ^bb0(%arg1: f32, %arg2: f32): + linalg.yield %arg1 : f32 + } -> tensor<64x128xf32> + %10 = linalg.generic {indexing_maps = [#map0, #map0, #map0], + iterator_types = ["parallel", "parallel"]} + ins(%6, %9 : tensor<64x128xf32>, tensor<64x128xf32>) + outs(%3 : tensor<64x128xf32>) { + ^bb0(%arg1: f32, %arg2: f32, %arg3: f32): + %11 = arith.divf %arg1, %arg2 : f32 + linalg.yield %11 : f32 + } -> tensor<64x128xf32> + return %10 : tensor<64x128xf32> +} + +// CHECK-LABEL: @softmax +// CHECK-NOT: linalg.generic +// CHECK: %[[PARALLEL:.*]] = gml_st.parallel +// CHECK: gml_st.set_yield +// CHECK-NOT: linalg.generic +// CHECK: return %[[PARALLEL]] + +// ----- + +#map0 = affine_map<(d0, 
d1) -> (d0, d1)> +#map1 = affine_map<(d0, d1) -> (d0)> + +func.func @softmax_fuse_tensor_extract(%arg0: tensor<64x128xf32>, + %cst_tensor: tensor) -> tensor<64x128xf32> { + %cst = tensor.extract %cst_tensor[] : tensor + %cst_0 = arith.constant 0xFF800000 : f32 + %0 = tensor.empty() : tensor<64xf32> + %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<64xf32>) -> tensor<64xf32> + %2 = linalg.generic {indexing_maps = [#map0, #map1], + iterator_types = ["parallel", "reduction"]} + ins(%arg0 : tensor<64x128xf32>) outs(%1 : tensor<64xf32>) { + ^bb0(%arg1: f32, %arg2: f32): + %11 = arith.maxf %arg2, %arg1 : f32 + linalg.yield %11 : f32 + } -> tensor<64xf32> + %3 = tensor.empty() : tensor<64x128xf32> + %4 = linalg.generic {indexing_maps = [#map1, #map0], + iterator_types = ["parallel", "parallel"]} ins(%2 : tensor<64xf32>) + outs(%3 : tensor<64x128xf32>) { + ^bb0(%arg1: f32, %arg2: f32): + linalg.yield %arg1 : f32 + } -> tensor<64x128xf32> + %5 = linalg.generic {indexing_maps = [#map0, #map0, #map0], + iterator_types = ["parallel", "parallel"]} + ins(%arg0, %4 : tensor<64x128xf32>, tensor<64x128xf32>) + outs(%3 : tensor<64x128xf32>) { + ^bb0(%arg1: f32, %arg2: f32, %arg3: f32): + %11 = arith.subf %arg1, %arg2 : f32 + linalg.yield %11 : f32 + } -> tensor<64x128xf32> + %6 = linalg.generic {indexing_maps = [#map0, #map0], + iterator_types = ["parallel", "parallel"]} ins(%5 : tensor<64x128xf32>) + outs(%3 : tensor<64x128xf32>) { + ^bb0(%arg1: f32, %arg2: f32): + %11 = math.exp %arg1 : f32 + linalg.yield %11 : f32 + } -> tensor<64x128xf32> + %7 = linalg.fill ins(%cst : f32) outs(%0 : tensor<64xf32>) -> tensor<64xf32> + %8 = linalg.generic {indexing_maps = [#map0, #map1], + iterator_types = ["parallel", "reduction"]} ins(%6 : tensor<64x128xf32>) + outs(%7 : tensor<64xf32>) { + ^bb0(%arg1: f32, %arg2: f32): + %11 = arith.addf %arg2, %arg1 : f32 + linalg.yield %11 : f32 + } -> tensor<64xf32> + %9 = linalg.generic {indexing_maps = [#map1, #map0], + iterator_types = ["parallel", "parallel"]} ins(%8 : tensor<64xf32>) + outs(%3 : tensor<64x128xf32>) { + ^bb0(%arg1: f32, %arg2: f32): + linalg.yield %arg1 : f32 + } -> tensor<64x128xf32> + %10 = linalg.generic {indexing_maps = [#map0, #map0, #map0], + iterator_types = ["parallel", "parallel"]} + ins(%6, %9 : tensor<64x128xf32>, tensor<64x128xf32>) + outs(%3 : tensor<64x128xf32>) { + ^bb0(%arg1: f32, %arg2: f32, %arg3: f32): + %11 = arith.divf %arg1, %arg2 : f32 + linalg.yield %11 : f32 + } -> tensor<64x128xf32> + return %10 : tensor<64x128xf32> +} + +// CHECK-LABEL: @softmax +// CHECK-NOT: tensor.extract +// CHECK: %[[PARALLEL_BLOCK:.*]] = gml_st.parallel +// CHECK: tensor.extract diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/invalid.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/invalid.mlir index 4479f0980b6..601481657b9 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/invalid.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/invalid.mlir @@ -1,158 +1,8 @@ // RUN: mlir-hlo-opt %s -split-input-file -verify-diagnostics - -func.func @materialize_rank_mismatch(%tensor: tensor, - %tile: !gml_st.tile<4>) { - // expected-error @+1 {{expected source rank = 2 to match tile rank = 1}} - %0 = gml_st.materialize %tensor[%tile] - : tensor[!gml_st.tile<4>] to tensor<4xf32> -} - -// ----- - -func.func @materialize_inferred_type_mismatch(%tensor: tensor, - %tile: !gml_st.tile) { - // expected-error @+1 {{expected result type = 'tensor<4x?xf32>' to match the inferred type = 'tensor}} - %0 = 
gml_st.materialize %tensor[%tile] - : tensor[!gml_st.tile] to tensor<4x?xf32> -} - -// ----- - -func.func @materialize_scalar_with_dynamic_tile( - %tensor: tensor, %tile: !gml_st.tile) { - // expected-error @+1 {{expected tile type '!gml_st.tile' to have a single element shape}} - %0 = gml_st.materialize %tensor[%tile] - : tensor[!gml_st.tile] to f32 -} - -// ----- - -func.func @materialize_scalar_with_nonsingle_element_tile( - %tensor: tensor, %tile: !gml_st.tile<1x2>) { - // expected-error @+1 {{expected tile type '!gml_st.tile<1x2>' to have a single element shape}} - %0 = gml_st.materialize %tensor[%tile] - : tensor[!gml_st.tile<1x2>] to f32 -} - -// ----- - -func.func @materialize_scalar_element_type_mismatch( - %tensor: tensor, %tile: !gml_st.tile<1x1>) { - // expected-error @+1 {{expected the result type 'i32' to match source element type 'f32'}} - %0 = gml_st.materialize %tensor[%tile] - : tensor[!gml_st.tile<1x1>] to i32 -} - -// ----- - -#map0 = affine_map<(d0) -> (24, -d0 + 192)> -#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)> -#map2 = affine_map<(d0) -> (16, -d0 + 192)> - -func.func private @foo(%A: memref<192x192xf32>, %B: memref<192x192xf32>, - %C: memref<192x192xf32>) -> () - -func.func @loop_incorrent_num_yield_operands(%A: memref<192x192xf32>, - %B: memref<192x192xf32>, %C: memref<192x192xf32>, - %C_tensor: tensor<192x192xf32>) { - %c24 = arith.constant 24 : index - %c0 = arith.constant 0 : index - %c192 = arith.constant 192 : index - %0 = gml_st.loop (%i, %j) = (%c0, %c0) to (%c192, %c192) - step (%c24, %c24) - ins (%A_ = %A: memref<192x192xf32>, %B_ = %B: memref<192x192xf32>) - outs (%CT_ = %C_tensor: tensor<192x192xf32>, - %C_ = %C: memref<192x192xf32>) { - func.call @foo(%A_, %B_, %C_) - : (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>)-> () - // expected-error @+1 {{expected number of tensor output args = 1 to match the number of yield operands = 0}} - gml_st.yield - } - func.return -} - -// ----- - -#map0 = affine_map<(d0) -> (24, -d0 + 192)> -#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)> -#map2 = affine_map<(d0) -> (16, -d0 + 192)> - -func.func private @foo(%A: memref<192x192xf32>, %B: memref<192x192xf32>, - %C: memref<192x192xf32>) -> tensor - -func.func @loop_incorrent_yield_operand_type(%A: memref<192x192xf32>, - %B: memref<192x192xf32>, %C: memref<192x192xf32>, - %C_tensor: tensor<192x192xf32>) { - %c24 = arith.constant 24 : index - %c0 = arith.constant 0 : index - %c192 = arith.constant 192 : index - %0 = gml_st.loop (%i, %j) = (%c0, %c0) to (%c192, %c192) - step (%c24, %c24) - ins (%A_ = %A: memref<192x192xf32>, %B_ = %B: memref<192x192xf32>) - outs (%CT_ = %C_tensor: tensor<192x192xf32>, - %C_ = %C: memref<192x192xf32>) { - %1 = func.call @foo(%A_, %B_, %C_) - : (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>)-> tensor - // expected-error @+1 {{expected yield operand 0 with type = 'tensor' to match output arg type = 'tensor<192x192xf32>}} - gml_st.yield %1 : tensor - } - func.return -} - -// ----- - -func.func private @foo(%A: memref<192x192xf32>, %B: memref<192x192xf32>, - %C: memref<192x192xf32>) -> () - -func.func @loop_incorrent_iterator_types_count(%A: memref<192x192xf32>, - %B: memref<192x192xf32>, %C: memref<192x192xf32>, - %C_tensor: tensor<192x192xf32>) { - %c24 = arith.constant 24 : index - %c0 = arith.constant 0 : index - %c192 = arith.constant 192 : index - // expected-error @+1 {{expected iterator types array attribute size = 1 to match the number of loops = 2}} - %0 = "gml_st.loop"(%c0, %c0, 
%c192, %c192, %c24, %c24, %A, %B, %C_tensor, %C) ({ - ^bb0(%arg4: index, %arg5: index, %A_: memref<192x192xf32>, - %B_: memref<192x192xf32>, %CT_: tensor<192x192xf32>, - %C_: memref<192x192xf32>): - func.call @foo(%A_, %B_, %C_) - : (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>)-> () - gml_st.yield %CT_ : tensor<192x192xf32> - }) { - iterator_types = [#gml_st.iterator_type], - operand_segment_sizes = array - } : (index, index, index, index, index, index, memref<192x192xf32>, - memref<192x192xf32>, tensor<192x192xf32>, memref<192x192xf32> - ) -> tensor<192x192xf32> - func.return -} - -// ----- - -func.func private @foo(%A: memref<100xf32>) -> () - -func.func @loop_incorrent_block_arg_type(%A: memref<192xf32>) { - %c0 = arith.constant 0 : index - %c192 = arith.constant 192 : index - %c24 = arith.constant 24 : index - // expected-error @+1 {{expected output arg 0 with type = 'memref<192xf32>' to match region arg 1 type = 'memref<100xf32>'}} - "gml_st.loop"(%c0, %c192, %c24, %A) ({ - ^bb0(%arg4: index, %A_: memref<100xf32>): - func.call @foo(%A_) : (memref<100xf32>)-> () - gml_st.yield - }) { - iterator_types = [#gml_st.iterator_type], - operand_segment_sizes = array - } : (index, index, index, memref<192xf32>) -> () - func.return -} - -// ----- - func.func @tile_op_mismatch_sizes_and_static_sizes(%i: index) { // expected-error@+1 {{expected 0 dynamic size values}} - %1 = "gml_st.tile"(%i) { static_offsets = [0, 0], static_sizes = [1, 1], static_strides = [1, 1], operand_segment_sizes = array } : (index) -> !gml_st.tile + %1 = "gml_st.tile"(%i) { static_offsets = array, static_sizes = array, static_strides = array, operand_segment_sizes = array } : (index) -> !gml_st.tile func.return } @@ -160,7 +10,7 @@ func.func @tile_op_mismatch_sizes_and_static_sizes(%i: index) { func.func @tile_op_mismatch_offsets_and_static_offsets(%i: index) -> !gml_st.tile<8x8> { // expected-error@+1 {{expected 0 dynamic offset values}} - %1 = "gml_st.tile"(%i) {static_offsets = [0, 0], static_sizes = [8, 8], static_strides = [1, 1], operand_segment_sizes = array} : (index) -> !gml_st.tile<8x8> + %1 = "gml_st.tile"(%i) {static_offsets = array, static_sizes = array, static_strides = array, operand_segment_sizes = array} : (index) -> !gml_st.tile<8x8> func.return %1 : !gml_st.tile<8x8> } @@ -168,7 +18,7 @@ func.func @tile_op_mismatch_offsets_and_static_offsets(%i: index) -> !gml_st.til func.func @tile_op_mismatch_strides_and_static_strides(%i: index) -> !gml_st.tile<8x8> { // expected-error@+1 {{expected 0 dynamic stride values}} - %1 = "gml_st.tile"(%i) {static_offsets = [0, 0], static_sizes = [8, 8], static_strides = [1, 1], operand_segment_sizes = array} : (index) -> !gml_st.tile<8x8> + %1 = "gml_st.tile"(%i) {static_offsets = array, static_sizes = array, static_strides = array, operand_segment_sizes = array} : (index) -> !gml_st.tile<8x8> func.return %1 : !gml_st.tile<8x8> } @@ -176,7 +26,7 @@ func.func @tile_op_mismatch_strides_and_static_strides(%i: index) -> !gml_st.ti func.func @tile_op_negative_static_size(%i: index) -> !gml_st.tile { // expected-error@+1 {{'gml_st.tile' op expected size = -2 to be non-negative}} - %1 = "gml_st.tile"(%i) {static_offsets = [0, 0], static_sizes = [-1, -2], static_strides = [1, 1], operand_segment_sizes = array} : (index) -> !gml_st.tile + %1 = "gml_st.tile"(%i) {static_offsets = array, static_sizes = array, static_strides = array, operand_segment_sizes = array} : (index) -> !gml_st.tile func.return %1 : !gml_st.tile } @@ -184,7 +34,7 @@ func.func 
@tile_op_negative_static_size(%i: index) -> !gml_st.tile { func.func @tile_op_negative_static_stride(%i: index) -> !gml_st.tile { // expected-error@+1 {{'gml_st.tile' op expected stride = -2 to be non-negative}} - %1 = "gml_st.tile"(%i) {static_offsets = [0, 0], static_sizes = [-1, 8], static_strides = [1, -2], operand_segment_sizes = array} : (index) -> !gml_st.tile + %1 = "gml_st.tile"(%i) {static_offsets = array, static_sizes = array, static_strides = array, operand_segment_sizes = array} : (index) -> !gml_st.tile func.return %1 : !gml_st.tile } @@ -192,7 +42,7 @@ func.func @tile_op_negative_static_stride(%i: index) -> !gml_st.tile { func.func @tile_op_negative_static_offset(%i: index) -> !gml_st.tile { // expected-error@+1 {{'gml_st.tile' op expected offset = -2 to be non-negative}} - %1 = "gml_st.tile"(%i) {static_offsets = [0, -2], static_sizes = [-1, 8], static_strides = [1, 1], operand_segment_sizes = array} : (index) -> !gml_st.tile + %1 = "gml_st.tile"(%i) {static_offsets = array, static_sizes = array, static_strides = array, operand_segment_sizes = array} : (index) -> !gml_st.tile func.return %1 : !gml_st.tile } @@ -204,19 +54,18 @@ func.func @for_loop_wrong_yield_target( %c4 = arith.constant 4 : index %c8 = arith.constant 8 : index - %identity = gml_st.tile[][][] : !gml_st.tile<> %sum = gml_st.for (%i) = (%c0) to (%c8) step (%c4) outs(%out_ = %output : tensor) { - %tile = gml_st.tile [%i] [4] [1] : !gml_st.tile<4> - %arg_sub = gml_st.materialize %arg[%tile] - : tensor<8xf32>[!gml_st.tile<4>] to tensor<4xf32> - %out_sub = gml_st.materialize %out_[%identity] - : tensor[!gml_st.tile<>] to tensor + %arg_sub = tensor.extract_slice %arg[%i] [4] [1] + : tensor<8xf32> to tensor<4xf32> + %out_sub = tensor.extract_slice %out_[][][] + : tensor to tensor %result_sub = linalg.dot ins(%arg_sub, %arg_sub : tensor<4xf32>, tensor<4xf32>) outs(%out_sub : tensor) -> tensor + %identity = gml_st.tile[][][] : !gml_st.tile<> // expected-error@+1 {{'gml_st.set_yield' op expected output block argument 0 to match set_yield destination}} gml_st.set_yield %result_sub into %output[%identity] : tensor into tensor[!gml_st.tile<>] @@ -232,20 +81,20 @@ func.func @yield_with_accumulator_mismatched_type( %c4 = arith.constant 4 : index %c8 = arith.constant 8 : index - %identity = gml_st.tile[][][] : !gml_st.tile<> - %sum = gml_st.parallel (%i) = (%c0) to (%c8) step (%c4) { - %tile = gml_st.tile [%i] [4] [1] : !gml_st.tile<4> - %arg_sub = gml_st.materialize %arg[%tile] - : tensor<8xf32>[!gml_st.tile<4>] to tensor<4xf32> - %out_sub = gml_st.materialize %output[%identity] - : tensor[!gml_st.tile<>] to tensor + %sum = gml_st.parallel (%i) = (%c0) to (%c8) step (%c4) + outs (%out_ = %output: tensor) { + %arg_sub = tensor.extract_slice %arg[%i] [4] [1] + : tensor<8xf32> to tensor<4xf32> + %out_sub = tensor.extract_slice %out_[][][] + : tensor to tensor %result_sub = linalg.dot ins(%arg_sub, %arg_sub : tensor<4xf32>, tensor<4xf32>) outs(%out_sub : tensor) -> tensor + %identity = gml_st.tile[][][] : !gml_st.tile<> // expected-error@+1 {{'gml_st.set_yield' op expected accumulator region to have 2 arguments of type 'tensor'}} - gml_st.set_yield %result_sub into %output[%identity] + gml_st.set_yield %result_sub into %out_[%identity] acc (%in, %out: memref) { gml_st.yield %in : memref }: tensor into tensor[!gml_st.tile<>] @@ -263,12 +112,10 @@ func.func @for_loop_wrong_yield_operands( %sum = gml_st.for (%i) = (%c0) to (%c8) step (%c4) outs(%out_ = %output : tensor) { - %tile_0d =gml_st.tile [%i] [4] [1] : !gml_st.tile<> - 
%tile = gml_st.tile [%i] [4] [1] : !gml_st.tile<4> - %arg_sub = gml_st.materialize %arg[%tile] - : tensor<8xf32>[!gml_st.tile<4>] to tensor<4xf32> - %out_sub = gml_st.materialize %out_[%tile_0d] - : tensor[!gml_st.tile<>] to tensor + %arg_sub = tensor.extract_slice %arg[%i] [4] [1] + : tensor<8xf32> to tensor<4xf32> + %out_sub = tensor.extract_slice %out_[][][] + : tensor to tensor %result_sub = linalg.dot ins(%arg_sub, %arg_sub : tensor<4xf32>, tensor<4xf32>) @@ -279,3 +126,24 @@ func.func @for_loop_wrong_yield_operands( } : tensor func.return %sum : tensor } + +// ----- + +func.func @missing_output_tensors(%in: tensor<8x8xf32>) -> tensor<8x8xf32> { + %c8 = arith.constant 8 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c16 = arith.constant 16 : index + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<8x8xf32> + // expected-error@+1 {{expected the number of output arguments to match the number of results}} + %13 = gml_st.parallel (%arg4, %arg5) = (%c0, %c16) to (%c1, %c16) + step (%c8, %c8) { + %19 = gml_st.tile [%arg4, %arg5] [8, 8] [1, 1] : !gml_st.tile<8x8> + %11 = linalg.fill ins(%cst : f32) outs(%0 : tensor<8x8xf32>) + -> tensor<8x8xf32> + gml_st.set_yield %11 into %0[%19] : tensor<8x8xf32> + into tensor<8x8xf32>[!gml_st.tile<8x8>] + } : tensor<8x8xf32> + return %13 : tensor<8x8xf32> +} diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/legacy_bufferization.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/legacy_bufferization.mlir deleted file mode 100644 index 92f2f8a47cc..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/legacy_bufferization.mlir +++ /dev/null @@ -1,133 +0,0 @@ -// RUN: mlir-hlo-opt %s -test-gml-st-bufferization --canonicalize -cse \ -// RUN: -split-input-file | FileCheck %s - -func.func private @some_use(memref) - -#TILE_MAP = affine_map<(d0)[s0] -> (3, -d0 + s0)> - -// CHECK-DAG: #[[$TILE_MAP:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 3)> - -// CHECK: func @tiled_dot( -// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref -// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: memref -// CHECK-SAME: %[[c:[a-zA-Z0-9]*]]: memref -func.func @tiled_dot(%A: tensor {bufferization.writeable = false}, - %B: tensor {bufferization.writeable = false}, - %c: tensor {bufferization.writeable = true}, - %effecting: memref) -> tensor { - %c3 = arith.constant 3 : index - %c0 = arith.constant 0 : index - - // CHECK: %[[M:.*]] = memref.dim %[[A]], {{.*}} : memref - %0 = tensor.dim %A, %c0 : tensor - - // CHECK: gml_st.loop {{.*}} to (%[[M]]) {{.*}} %[[A]]{{.*}}%[[B]]{{.*}}outs{{.*}}%[[c]] - // CHECK-NOT: copy - %1 = gml_st.loop (%arg3) = (%c0) to (%0) step (%c3) - ins (%arg4 = %A: tensor, %use = %effecting : memref, - %arg5 = %B: tensor) - outs (%arg6 = %c: tensor) - iterators[#gml_st.iterator_type] { - // CHECK-NOT: alloc - - %2 = tensor.dim %arg4, %c0 : tensor - %3 = affine.min #TILE_MAP(%arg3)[%2] - - // CHECK: %[[SV_A:.*]] = memref.subview {{.*}} - %4 = tensor.extract_slice %arg4[%arg3] [%3] [1] - : tensor to tensor - %5 = tensor.dim %arg5, %c0 : tensor - %6 = affine.min #TILE_MAP(%arg3)[%5] - - // CHECK: %[[SV_B:.*]] = memref.subview {{.*}} - %7 = tensor.extract_slice %arg5[%arg3] [%6] [1] : tensor to tensor - - // CHECK: linalg.dot ins(%[[SV_A]], %[[SV_B]] : memref, memref>) outs(%{{.*}} : memref) - %8 = linalg.dot ins(%4, %7 : tensor, tensor) - outs(%arg6 : tensor) -> tensor - - // CHECK: call @some_use(%{{.*}}) : (memref) -> () - func.call @some_use(%use) : (memref) -> () - - gml_st.yield %8 : tensor - 
// CHECK: gml_st.yield - // CHECK-NOT: tensor - } - - // CHECK: return - // CHECK-NOT: tensor - func.return %1 : tensor -} - -// ----- - -#TILE_MAP = affine_map<(d0)[s0] -> (3, -d0 + s0)> - -// CHECK: func @tiled_fill( -// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref -func.func @tiled_fill(%A: tensor {bufferization.writeable = true}) -> tensor { - %c3 = arith.constant 3 : index - %c0 = arith.constant 0 : index - %f0 = arith.constant 0.0 : f32 - - // CHECK: %[[M:.*]] = memref.dim %[[A]], {{.*}} : memref - %0 = tensor.dim %A, %c0 : tensor - - // CHECK: gml_st.loop {{.*}} to (%[[M]]) {{.*}} outs{{.*}}%[[A]] - %1 = gml_st.loop (%arg3) = (%c0) to (%0) step (%c3) - outs (%arg1 = %A: tensor) - iterators[#gml_st.iterator_type] { - // CHECK-NOT: alloc - - %2 = tensor.dim %arg1, %c0 : tensor - %3 = affine.min #TILE_MAP(%arg3)[%2] - - // CHECK: %[[SV_A:.*]] = memref.subview {{.*}} - %4 = tensor.extract_slice %arg1[%arg3] [%3] [1] : tensor to tensor - - // CHECK: linalg.fill ins(%{{.*}}: f32) outs(%[[SV_A]] : memref) - %5 = linalg.fill ins(%f0: f32) outs(%4: tensor) - -> tensor - %6 = tensor.insert_slice %5 into %arg1[%arg3] [%3] [1] : tensor into tensor - - gml_st.yield %6 : tensor - // CHECK: gml_st.yield - // CHECK-NOT: tensor - } - - // CHECK: return - // CHECK-NOT: tensor - func.return %1 : tensor -} - -// ----- - -// CHECK: func @tiled_loop_yield_out_of_place( -// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref -// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: memref -func.func @tiled_loop_yield_out_of_place( - %A: tensor {bufferization.writeable = true}, - %B: tensor {bufferization.writeable = true}) -> tensor { - %c3 = arith.constant 3 : index - %c0 = arith.constant 0 : index - %f0 = arith.constant 0.0 : f32 - - // CHECK: %[[M:.*]] = memref.dim %[[A]], {{.*}} : memref - %0 = tensor.dim %A, %c0 : tensor - - // CHECK: gml_st.loop {{.*}} to (%[[M]]) {{.*}} outs{{.*}}%[[A]] - %1 = gml_st.loop (%arg3) = (%c0) to (%0) step (%c3) - outs (%arg1 = %A: tensor) - iterators[#gml_st.iterator_type] - { - // CHECK-NOT: alloc - // CHECK: memref.copy %[[B]], %[[A]] - gml_st.yield %B : tensor - // CHECK: gml_st.yield - // CHECK-NOT: tensor - } - - // CHECK: return - // CHECK-NOT: tensor - func.return %1 : tensor -} diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/legacy_loop_tiling.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/legacy_loop_tiling.mlir deleted file mode 100644 index ea600613372..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/legacy_loop_tiling.mlir +++ /dev/null @@ -1,93 +0,0 @@ -// RUN: mlir-hlo-opt %s -split-input-file \ -// RUN: -test-gml-st-loop-tiling="tile-sizes=2,3,4 distribution-types=block_x,block_y,none" \ -// RUN: | FileCheck %s - -func.func @matmul_tensors( - %arg0: tensor, %arg1: tensor, %arg2: tensor) - -> tensor { - %0 = linalg.matmul ins(%arg0, %arg1: tensor, tensor) - outs(%arg2: tensor) - -> tensor - func.return %0 : tensor -} - -// CHECK-LABEL: func @matmul_tensors -// CHECK-SAME: (%[[ARG_0:.*]]: [[TY:.*]], %[[ARG_1:.*]]: [[TY]], -// CHECK-SAME: %[[ARG_2:.*]]: [[TY]]) -> [[TY]] { - -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index -// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index - -// CHECK: %[[ARG_0_X:.*]] = tensor.dim %[[ARG_0]], %[[C0]] : [[TY]] -// CHECK: %[[ARG_0_Y:.*]] = tensor.dim %[[ARG_0]], %[[C1]] : [[TY]] -// CHECK: %[[ARG_1_Y:.*]] = tensor.dim %[[ARG_1]], %[[C1]] : [[TY]] 
- -// CHECK: %{{.*}} = gml_st.loop (%[[I:.*]], %[[J:.*]], %[[K:.*]]) = -// CHECK-SAME: (%[[C0]], %[[C0]], %[[C0]]) -// CHECK-SAME: to (%[[ARG_0_X]], %[[ARG_1_Y]], %[[ARG_0_Y]]) -// CHECK-SAME: step (%[[C2]], %[[C3]], %[[C4]]) -// CHECK-SAME: ins (%[[A0:.*]] = %[[ARG_0]]: [[TY]], %[[A1:.*]] = %[[ARG_1]]: [[TY]]) -// CHECK-SAME: outs (%[[A2:.*]] = %[[ARG_2]]: [[TY]]) -// CHECK-SAME: iterators[#gml_st.iterator_type, -// CHECK-SAME: #gml_st.iterator_type, #gml_st.iterator_type] -// CHECK-SAME: distribution["block_x", "block_y", "none"] { - -// CHECK: %[[SUB_ARG_0:.*]] = tensor.extract_slice %[[A0]][%[[I]], %[[K]]] -// CHECK: %[[SUB_ARG_1:.*]] = tensor.extract_slice %[[A1]][%[[K]], %[[J]]] -// CHECK: %[[SUB_ARG_2:.*]] = tensor.extract_slice %[[A2]][%[[I]], %[[J]]] - -// CHECK: %[[PROD:.*]] = linalg.matmul ins(%[[SUB_ARG_0]], %[[SUB_ARG_1]] -// CHECK-SE: outs(%[[SUB_ARG_2]] : [[TY]]) -> [[TY]] - -// CHECK: %[[O:.*]] = tensor.insert_slice %[[PROD]] into %[[A2]][%[[I]], %[[J]]] -// CHECK: gml_st.yield %[[O]] : [[TY]] - -// ----- - -func.func @generic_op_tensors( - %arg0 : tensor, %arg1 : tensor) -> tensor { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %0 = tensor.dim %arg0, %c0 : tensor - %1 = tensor.dim %arg0, %c1 : tensor - %2 = tensor.dim %arg0, %c2 : tensor - %3 = tensor.empty(%0, %1, %2) : tensor - %4 = linalg.generic - {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, - affine_map<(d0, d1, d2) -> (d0, d2, d1)>, - affine_map<(d0, d1, d2) -> (d2, d1, d0)>], - iterator_types = ["parallel", "parallel", "parallel"]} - ins(%arg0, %arg1 : tensor, tensor) - outs(%3 : tensor) { - ^bb0(%arg2 : f32, %arg3: f32, %arg4: f32): - %5 = arith.addf %arg2, %arg3 : f32 - linalg.yield %5 : f32 - } -> tensor - func.return %4 : tensor -} -// CHECK-LABEL: func @generic_op_tensors( -// CHECK-SAME: %[[ARG_0:.*]]: [[TY:.*]], -// CHECK-SAME: %[[ARG_1:.*]]: [[TY]]) -> [[TY]] { - -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index -// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index - -// CHECK: %[[INIT:.*]] = tensor.empty -// CHECK: %[[ARG_0_X:.*]] = tensor.dim %[[ARG_0]], %[[C0]] : [[TY]] -// CHECK: %[[ARG_0_Y:.*]] = tensor.dim %[[ARG_0]], %[[C1]] : [[TY]] -// CHECK: %[[ARG_0_Z:.*]] = tensor.dim %[[ARG_0]], %[[C2]] : [[TY]] - -// CHECK: %{{.*}} = gml_st.loop (%{{.*}}, %{{.*}}, %{{.*}}) = -// CHECK-SAME: (%[[C0]], %[[C0]], %[[C0]]) -// CHECK-SAME: to (%[[ARG_0_X]], %[[ARG_0_Y]], %[[ARG_0_Z]]) -// CHECK-SAME: step (%[[C2]], %[[C3]], %[[C4]]) -// CHECK-SAME: ins (%{{.*}} = %[[ARG_0]]: [[TY]], %{{.*}} = %[[ARG_1]]: [[TY]]) -// CHECK-SAME: outs (%{{.*}} = %[[INIT]]: [[TY]]) -// CHECK-SAME: distribution["block_x", "block_y", "none"] { diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/legacy_peeling.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/legacy_peeling.mlir deleted file mode 100644 index 8f3f26d2b3d..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/legacy_peeling.mlir +++ /dev/null @@ -1,171 +0,0 @@ -// RUN: mlir-hlo-opt %s -allow-unregistered-dialect -test-gml-st-loop-peeling="dims=2" -split-input-file | FileCheck %s -check-prefix=CHECK-TILE-2 -// RUN: mlir-hlo-opt %s -allow-unregistered-dialect -test-gml-st-loop-peeling="dims=0,1,2" -split-input-file | FileCheck %s -check-prefix=CHECK-TILE-012 -// RUN: mlir-hlo-opt %s 
-allow-unregistered-dialect -test-gml-st-loop-peeling="dims=0,1,2 skip-partial" -split-input-file | FileCheck %s -check-prefix=CHECK-TILE-012-SKIP-PARTIAL - -// CHECK-TILE-2-LABEL: func @loop_3d_tensor( -// CHECK-TILE-2-SAME: %[[input:.*]]: tensor, %[[s0:.*]]: index, %[[s1:.*]]: index, %[[s2:.*]]: index -// CHECK-TILE-2-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-TILE-2-DAG: %[[c1:.*]] = arith.constant 1 : index -// CHECK-TILE-2-DAG: %[[c2:.*]] = arith.constant 2 : index -// CHECK-TILE-2: %[[dim0:.*]] = tensor.dim %[[input]], %[[c0]] -// CHECK-TILE-2: %[[dim1:.*]] = tensor.dim %[[input]], %[[c1]] -// CHECK-TILE-2: %[[dim2:.*]] = tensor.dim %[[input]], %[[c2]] -// CHECK-TILE-2: %[[init_tensor:.*]] = tensor.empty -// CHECK-TILE-2: %[[split_bound:.*]] = affine.apply -// CHECK-TILE-2: %[[r1:.*]] = gml_st.loop (%[[iv0:.*]], %[[iv1:.*]], %[[iv2:.*]]) = (%[[c0]], %[[c0]], %[[c0]]) -// CHECK-TILE-2-SAME: to (%[[dim0]], %[[dim1]], %[[split_bound]]) -// CHECK-TILE-2-SAME: step (%[[s0]], %[[s1]], %[[s2]]) -// CHECK-TILE-2-SAME: ins (%[[loop_in1:.*]] = %[[input]]: tensor) -// CHECK-TILE-2-SAME: outs (%[[loop_out1:.*]] = %[[init_tensor]]: tensor) { -// CHECK-TILE-2: %[[min0_1:.*]] = affine.min -// CHECK-TILE-2: %[[min1_1:.*]] = affine.min -// CHECK-TILE-2: %[[in_slice1:.*]] = tensor.extract_slice %[[loop_in1]][%[[iv0]], %[[iv1]], %[[iv2]]] [%[[min0_1]], %[[min1_1]], %[[s2]]] -// CHECK-TILE-2: %[[out_slice1:.*]] = tensor.extract_slice %[[loop_out1]][%[[iv0]], %[[iv1]], %[[iv2]]] [%[[min0_1]], %[[min1_1]], %[[s2]]] -// CHECK-TILE-2: %[[mod_slice1:.*]] = tensor.insert_slice %{{.*}} into %[[loop_out1]][%[[iv0]], %[[iv1]], %[[iv2]]] [%[[min0_1]], %[[min1_1]], %[[s2]]] -// CHECK-TILE-2: gml_st.yield %[[mod_slice1]] -// CHECK-TILE-2: %[[r2:.*]] = gml_st.loop (%[[iv0:.*]], %[[iv1:.*]], %[[iv2:.*]]) = (%[[c0]], %[[c0]], %[[split_bound]]) -// CHECK-TILE-2-SAME: to (%[[dim0]], %[[dim1]], %[[dim2]]) -// CHECK-TILE-2-SAME: step (%[[s0]], %[[s1]], %[[s2]]) -// CHECK-TILE-2-SAME: ins (%[[loop_in2:.*]] = %[[input]]: tensor) -// CHECK-TILE-2-SAME: outs (%[[loop_out2:.*]] = %[[r1]]: tensor) { -// CHECK-TILE-2: %[[min0_2:.*]] = affine.min -// CHECK-TILE-2: %[[min1_2:.*]] = affine.min -// CHECK-TILE-2: %[[apply2:.*]] = affine.apply -// CHECK-TILE-2: %[[in_slice2:.*]] = tensor.extract_slice %[[loop_in1]][%[[iv0]], %[[iv1]], %[[iv2]]] [%[[min0_2]], %[[min1_2]], %[[apply2]]] -// CHECK-TILE-2: %[[out_slice2:.*]] = tensor.extract_slice %[[loop_out1]][%[[iv0]], %[[iv1]], %[[iv2]]] [%[[min0_2]], %[[min1_2]], %[[apply2]]] -// CHECK-TILE-2: %[[mod_slice2:.*]] = tensor.insert_slice %{{.*}} into %[[loop_out1]][%[[iv0]], %[[iv1]], %[[iv2]]] [%[[min0_2]], %[[min1_2]], %[[apply2]]] -// CHECK-TILE-2: gml_st.yield %[[mod_slice2]] -// CHECK-TILE-2: return %[[r2]] - -// CHECK-TILE-012-LABEL: func @loop_3d_tensor -// CHECK-TILE-012: gml_st.loop {{.*}} { -// CHECK-TILE-012: gml_st.yield -// CHECK-TILE-012: } -// CHECK-TILE-012: gml_st.loop {{.*}} { -// CHECK-TILE-012: gml_st.yield -// CHECK-TILE-012: } -// CHECK-TILE-012: gml_st.loop {{.*}} { -// CHECK-TILE-012: gml_st.yield -// CHECK-TILE-012: } -// CHECK-TILE-012: gml_st.loop {{.*}} { -// CHECK-TILE-012: gml_st.yield -// CHECK-TILE-012: } -// CHECK-TILE-012: gml_st.loop {{.*}} { -// CHECK-TILE-012: gml_st.yield -// CHECK-TILE-012: } -// CHECK-TILE-012: gml_st.loop {{.*}} { -// CHECK-TILE-012: gml_st.yield -// CHECK-TILE-012: } -// CHECK-TILE-012: gml_st.loop {{.*}} { -// CHECK-TILE-012: gml_st.yield -// CHECK-TILE-012: } -// CHECK-TILE-012: gml_st.loop {{.*}} { -// CHECK-TILE-012: 
gml_st.yield -// CHECK-TILE-012: } -// CHECK-TILE-012-NOT: gml_st.loop - -// CHECK-TILE-012-SKIP-PARTIAL: func @loop_3d_tensor( -// CHECK-TILE-012-SKIP-PARTIAL-SAME: %[[input:.*]]: tensor -// CHECK-TILE-012-SKIP-PARTIAL-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-TILE-012-SKIP-PARTIAL-DAG: %[[c1:.*]] = arith.constant 1 : index -// CHECK-TILE-012-SKIP-PARTIAL-DAG: %[[c2:.*]] = arith.constant 2 : index -// CHECK-TILE-012-SKIP-PARTIAL-DAG: %[[dim0:.*]] = tensor.dim %[[input]], %[[c0]] -// CHECK-TILE-012-SKIP-PARTIAL-DAG: %[[dim1:.*]] = tensor.dim %[[input]], %[[c1]] -// CHECK-TILE-012-SKIP-PARTIAL-DAG: %[[dim2:.*]] = tensor.dim %[[input]], %[[c2]] -// CHECK-TILE-012-SKIP-PARTIAL: %[[p0:.*]] = affine.apply #{{.*}}()[%[[dim0]] -// CHECK-TILE-012-SKIP-PARTIAL: %[[p1:.*]] = affine.apply #{{.*}}()[%[[dim1]] -// CHECK-TILE-012-SKIP-PARTIAL: %[[p2:.*]] = affine.apply #{{.*}}()[%[[dim2]] -// CHECK-TILE-012-SKIP-PARTIAL: gml_st.loop {{.*}} = (%[[c0]], %[[c0]], %[[c0]]) to (%[[p0]], %[[p1]], %[[p2]]) -// CHECK-TILE-012-SKIP-PARTIAL: gml_st.loop {{.*}} = (%[[c0]], %[[c0]], %[[p2]]) to (%[[p0]], %[[p1]], %[[dim2]]) -// CHECK-TILE-012-SKIP-PARTIAL: gml_st.loop {{.*}} = (%[[c0]], %[[p1]], %[[c0]]) to (%[[p0]], %[[dim1]], %[[dim2]]) -// CHECK-TILE-012-SKIP-PARTIAL: gml_st.loop {{.*}} = (%[[p0]], %[[c0]], %[[c0]]) to (%[[dim0]], %[[dim1]], %[[dim2]]) -func.func @loop_3d_tensor(%arg0: tensor, %s0: index, %s1: index, - %s2: index) -> tensor { - %cst = arith.constant 0.000000e+00 : f32 - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c8 = arith.constant 8 : index - %dim0 = tensor.dim %arg0, %c0 : tensor - %dim1 = tensor.dim %arg0, %c1 : tensor - %dim2 = tensor.dim %arg0, %c2 : tensor - %output = tensor.empty(%dim0, %dim1, %dim2) : tensor - %result = gml_st.loop - (%arg1, %arg2, %arg3) = (%c0, %c0, %c0) to (%dim0, %dim1, %dim2) - step (%s0, %s1, %s2) ins (%arg4 = %arg0: tensor) - outs (%arg5 = %output: tensor) { - %min0 = affine.min affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>(%arg1, %s0)[%dim0] - %min1 = affine.min affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>(%arg2, %s1)[%dim1] - %min2 = affine.min affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>(%arg3, %s2)[%dim2] - %in_slice = tensor.extract_slice %arg4[%arg1, %arg2, %arg3] [%min0, %min1, %min2] [1, 1, 1]: tensor to tensor - %out_slice = tensor.extract_slice %arg5[%arg1, %arg2, %arg3] [%min0, %min1, %min2] [1, 1, 1] : tensor to tensor - %comp = "computation"(%in_slice, %out_slice) : (tensor, tensor) -> tensor - %updated_slice = tensor.insert_slice %comp into %arg5[%arg1, %arg2, %arg3] [%min0, %min1, %min2] [1, 1, 1] : tensor into tensor - gml_st.yield %updated_slice : tensor - } - func.return %result : tensor -} - -// ----- - -// CHECK-TILE-2-LABEL: func @step_1_do_not_peel -// CHECK-TILE-2: gml_st.loop -// CHECK-TILE-2-NOT: gml_st.loop - -// CHECK-TILE-012-LABEL: func @step_1_do_not_peel - -func.func @step_1_do_not_peel(%arg0: tensor) -> tensor { - %cst = arith.constant 0.000000e+00 : f32 - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c8 = arith.constant 8 : index - %dim0 = tensor.dim %arg0, %c0 : tensor - %dim1 = tensor.dim %arg0, %c1 : tensor - %dim2 = tensor.dim %arg0, %c2 : tensor - %output = tensor.empty(%dim0, %dim1, %dim2) : tensor - %result = gml_st.loop - (%arg1, %arg2, %arg3) = (%c0, %c0, %c0) to (%dim0, %dim1, %dim2) - step (%c1, %c1, %c1) ins (%arg4 = %arg0: tensor) - outs (%arg5 = %output: tensor) { - %in_slice = tensor.extract_slice 
%arg4[%arg1, %arg2, %arg3] [%c1, %c1, %c1] [1, 1, 1]: tensor to tensor - %out_slice = tensor.extract_slice %arg5[%arg1, %arg2, %arg3] [%c1, %c1, %c1] [1, 1, 1] : tensor to tensor - %comp = "computation"(%in_slice, %out_slice) : (tensor, tensor) -> tensor - %updated_slice = tensor.insert_slice %comp into %arg5[%arg1, %arg2, %arg3] [%c1, %c1, %c1] [1, 1, 1] : tensor into tensor - gml_st.yield %updated_slice : tensor - } - func.return %result : tensor -} - -// ----- - -// CHECK-TILE-2-LABEL: func @divides_evenly_do_not_peel -// CHECK-TILE-2: gml_st.loop -// CHECK-TILE-2-NOT: gml_st.loop - -// CHECK-TILE-012-LABEL: func @divides_evenly_do_not_peel - -func.func @divides_evenly_do_not_peel(%arg0: tensor, %s: index) - -> tensor { - %cst = arith.constant 0.000000e+00 : f32 - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c8 = arith.constant 8 : index - %c64 = arith.constant 64 : index - %dim0 = tensor.dim %arg0, %c0 : tensor - %dim1 = tensor.dim %arg0, %c1 : tensor - %dim2 = tensor.dim %arg0, %c2 : tensor - %output = tensor.empty(%dim0, %dim1, %dim2) : tensor - %result = gml_st.loop - (%arg1, %arg2, %arg3) = (%c0, %c0, %c0) to (%dim0, %dim1, %c64) - step (%s, %s, %c8) ins (%arg4 = %arg0: tensor) - outs (%arg5 = %output: tensor) { - %in_slice = tensor.extract_slice %arg4[%arg1, %arg2, %arg3] [%c1, %c1, %c1] [1, 1, 1]: tensor to tensor - %out_slice = tensor.extract_slice %arg5[%arg1, %arg2, %arg3] [%c1, %c1, %c1] [1, 1, 1] : tensor to tensor - %comp = "computation"(%in_slice, %out_slice) : (tensor, tensor) -> tensor - %updated_slice = tensor.insert_slice %comp into %arg5[%arg1, %arg2, %arg3] [%c1, %c1, %c1] [1, 1, 1] : tensor into tensor - gml_st.yield %updated_slice : tensor - } - func.return %result : tensor -} diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/nested_tiling.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/nested_tiling.mlir index 4fae330546b..ff31254de32 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/nested_tiling.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/nested_tiling.mlir @@ -16,17 +16,19 @@ func.func @add(%lhs : tensor, %rhs : tensor) // CHECK-DAG: %[[C512:.*]] = arith.constant 512 // CHECK: %[[INIT:.*]] = tensor.empty // CHECK: %[[LOOP:.*]] = gml_st.parallel - // CHECK: %[[LHS_SUB:.*]] = gml_st.materialize %[[LHS]] - // CHECK: %[[RHS_SUB:.*]] = gml_st.materialize %[[RHS]] - // CHECK: %[[INIT_SUB:.*]] = gml_st.materialize %[[INIT]] + // CHECK-SAME: outs (%[[INIT_:.*]] = %[[INIT]]: + // CHECK: %[[LHS_SUB:.*]] = tensor.extract_slice %[[LHS]] + // CHECK: %[[RHS_SUB:.*]] = tensor.extract_slice %[[RHS]] + // CHECK: %[[INIT_SUB:.*]] = tensor.extract_slice %[[INIT_]] // CHECK: %[[LOOP_:.*]] = gml_st.parallel - // CHECK: %[[LHS_SUB_2:.*]] = gml_st.materialize %[[LHS_SUB]] - // CHECK: %[[RHS_SUB_2:.*]] = gml_st.materialize %[[RHS_SUB]] - // CHECK: %[[INIT_SUB_2:.*]] = gml_st.materialize %[[INIT_SUB]] + // CHECK-SAME: outs (%[[INIT_SUB_:.*]] = %[[INIT_SUB]]: + // CHECK: %[[LHS_SUB_2:.*]] = tensor.extract_slice %[[LHS_SUB]] + // CHECK: %[[RHS_SUB_2:.*]] = tensor.extract_slice %[[RHS_SUB]] + // CHECK: %[[INIT_SUB_2:.*]] = tensor.extract_slice %[[INIT_SUB_]] // CHECK: %[[GENERIC:.*]] = linalg.generic - // CHECK: gml_st.set_yield %[[GENERIC]] into %[[INIT_SUB]] - // CHECK: gml_st.set_yield %[[LOOP_]] into %[[INIT]] + // CHECK: gml_st.set_yield %[[GENERIC]] into %[[INIT_SUB_]] + // CHECK: gml_st.set_yield %[[LOOP_]] into %[[INIT_]] // CHECK: return %[[LOOP]] %c0 = 
arith.constant 0 : index %c1 = arith.constant 1 : index diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/nested_tiling_cwise.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/nested_tiling_cwise.mlir index 2b151a1b855..3ee7d1ef746 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/nested_tiling_cwise.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/nested_tiling_cwise.mlir @@ -23,35 +23,38 @@ func.func @cwise_expr(%a: tensor, %b: tensor, // CHECK-SAME: (%[[I:.*]], %[[J:.*]], %[[K:.*]]) = (%[[C0]], %[[C0]], %[[C0]]) // CHECK-SAME: to (%[[A_D0]], %[[C1024]], %[[C1024]]) // CHECK-SAME: step (%[[C1]], %[[C512]], %[[C1024]]) - // CHECK-DAG: %[[A_SUB:.*]] = gml_st.materialize %[[A]][%{{.*}}] - // CHECK-DAG: %[[B_SUB:.*]] = gml_st.materialize %[[B]][%{{.*}}] - // CHECK-DAG: %[[C_SUB:.*]] = gml_st.materialize %[[C]][%{{.*}}] - // CHECK-DAG: %[[INIT_SUB:.*]] = gml_st.materialize %[[INIT]][%{{.*}}] + // CHECK-SAME: outs (%[[INIT_:.*]] = %[[INIT]]: + // CHECK-DAG: %[[A_SUB:.*]] = tensor.extract_slice %[[A]] + // CHECK-DAG: %[[B_SUB:.*]] = tensor.extract_slice %[[B]] + // CHECK-DAG: %[[C_SUB:.*]] = tensor.extract_slice %[[C]] + // CHECK-DAG: %[[INIT_SUB:.*]] = tensor.extract_slice %[[INIT_]] // CHECK: %[[PLOOP_:.*]] = gml_st.parallel // CHECK-SAME: (%[[I_:.*]], %[[J_:.*]], %[[K_:.*]]) = (%[[C0]], %[[C0]], %[[C0]]) // CHECK-SAME: to (%[[C1]], %[[C512]], %[[C1024]]) // CHECK-SAME: step (%[[C1]], %[[C64]], %[[C128]]) - // CHECK-DAG: %[[A_SUB_SUB:.*]] = gml_st.materialize %[[A_SUB]][%{{.*}}] - // CHECK-DAG: %[[B_SUB_SUB:.*]] = gml_st.materialize %[[B_SUB]][%{{.*}}] - // CHECK-DAG: %[[C_SUB_SUB:.*]] = gml_st.materialize %[[C_SUB]][%{{.*}}] - // CHECK-DAG: %[[INIT_SUB_SUB:.*]] = gml_st.materialize %[[INIT_SUB]][%{{.*}}] + // CHECK-SAME: outs (%[[INIT_SUB_:.*]] = %[[INIT_SUB]]: + // CHECK-DAG: %[[A_SUB_SUB:.*]] = tensor.extract_slice %[[A_SUB]] + // CHECK-DAG: %[[B_SUB_SUB:.*]] = tensor.extract_slice %[[B_SUB]] + // CHECK-DAG: %[[C_SUB_SUB:.*]] = tensor.extract_slice %[[C_SUB]] + // CHECK-DAG: %[[INIT_SUB_SUB:.*]] = tensor.extract_slice %[[INIT_SUB_]] // CHECK: %[[PLOOP__:.*]] = gml_st.parallel // CHECK-SAME: (%[[I__:.*]], %[[J__:.*]], %[[K__:.*]]) = (%[[C0]], %[[C0]], %[[C0]]) // CHECK-SAME: to (%[[C1]], %[[C64]], %[[C128]]) // CHECK-SAME: step (%[[C1]], %[[C1]], %[[C32]]) - // CHECK-DAG: %[[A_SUB_SUB_SUB:.*]] = gml_st.materialize %[[A_SUB_SUB]][%{{.*}}] - // CHECK-DAG: %[[B_SUB_SUB_SUB:.*]] = gml_st.materialize %[[B_SUB_SUB]][%{{.*}}] - // CHECK-DAG: %[[INIT_SUB_SUB_SUB:.*]] = gml_st.materialize %[[INIT_SUB_SUB]][%{{.*}}] + // CHECK-SAME: outs (%[[INIT_SUB_SUB_:.*]] = %[[INIT_SUB_SUB]]: + // CHECK-DAG: %[[A_SUB_SUB_SUB:.*]] = tensor.extract_slice %[[A_SUB_SUB]] + // CHECK-DAG: %[[B_SUB_SUB_SUB:.*]] = tensor.extract_slice %[[B_SUB_SUB]] // CHECK: %[[AB_SUB_SUB_SUB:.*]] = linalg.generic // CHECK-SAME: ins(%[[A_SUB_SUB_SUB]], %[[B_SUB_SUB_SUB]] : tensor<1x1x32xf32>, tensor<1x1x32xf32>) - // CHECK-SAME: outs(%[[INIT_SUB_SUB_SUB]] : tensor<1x1x32xf32>) - // CHECK-DAG: %[[C_SUB_SUB_SUB:.*]] = gml_st.materialize %[[C_SUB_SUB]][%{{.*}}] + // CHECK-SAME: outs(%{{.*}} : tensor<1x1x32xf32>) + // CHECK-DAG: %[[C_SUB_SUB_SUB:.*]] = tensor.extract_slice %[[C_SUB_SUB]] + // CHECK-DAG: %[[INIT_SUB_SUB_SUB_:.*]] = tensor.extract_slice %[[INIT_SUB_SUB_]] // CHECK: %[[ABC_SUB_SUB_SUB:.*]] = linalg.generic // CHECK-SAME: ins(%[[AB_SUB_SUB_SUB]], %[[C_SUB_SUB_SUB]] : tensor<1x1x32xf32>, tensor<1x1x32xf32>) - // CHECK-SAME: outs(%[[INIT_SUB_SUB_SUB]] : 
tensor<1x1x32xf32>) - // CHECK: gml_st.set_yield %[[ABC_SUB_SUB_SUB]] into %[[INIT_SUB_SUB]][%{{.*}}] - // CHECK: gml_st.set_yield %[[PLOOP__]] into %[[INIT_SUB]][%{{.*}}] - // CHECK: gml_st.set_yield %[[PLOOP_]] into %[[INIT]][%{{.*}}] + // CHECK-SAME: outs(%[[INIT_SUB_SUB_SUB_]] : tensor<1x1x32xf32>) + // CHECK: gml_st.set_yield %[[ABC_SUB_SUB_SUB]] into %[[INIT_SUB_SUB_]][%{{.*}}] + // CHECK: gml_st.set_yield %[[PLOOP__]] into %[[INIT_SUB_]][%{{.*}}] + // CHECK: gml_st.set_yield %[[PLOOP_]] into %[[INIT_]][%{{.*}}] // CHECK: return %[[PLOOP]] %c0 = arith.constant 0 : index %d0 = tensor.dim %a, %c0 : tensor diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/nested_tiling_softmax.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/nested_tiling_softmax.mlir index 40bf9afc825..a2691a97232 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/nested_tiling_softmax.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/nested_tiling_softmax.mlir @@ -5,67 +5,51 @@ // RUN: --canonicalize --cse | \ // RUN: FileCheck %s -// CHECK: #map{{[0-9]*}} = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: #map{{[0-9]*}} = affine_map<(d0, d1) -> (d0)> -#map0 = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> (d0)> - func.func @softmax(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { %cst = arith.constant -0.000000e+00 : f32 %cst_0 = arith.constant 0xFF800000 : f32 %0 = tensor.empty() : tensor<64xf32> %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<64xf32>) -> tensor<64xf32> - %2 = linalg.generic {indexing_maps = [#map0, #map1], - iterator_types = ["parallel", "reduction"]} - ins(%arg0 : tensor<64x128xf32>) outs(%1 : tensor<64xf32>) { - ^bb0(%arg1: f32, %arg2: f32): - %11 = arith.maxf %arg2, %arg1 : f32 - linalg.yield %11 : f32 - } -> tensor<64xf32> + %2 = linalg.reduce ins(%arg0 : tensor<64x128xf32>) + outs(%1 : tensor<64xf32>) dimensions = [1] + (%arg1: f32, %arg2: f32) { + %11 = arith.maxf %arg1, %arg2 : f32 + linalg.yield %11 : f32 + } %3 = tensor.empty() : tensor<64x128xf32> - %4 = linalg.generic {indexing_maps = [#map1, #map0], - iterator_types = ["parallel", "parallel"]} ins(%2 : tensor<64xf32>) - outs(%3 : tensor<64x128xf32>) { - ^bb0(%arg1: f32, %arg2: f32): - linalg.yield %arg1 : f32 - } -> tensor<64x128xf32> - %5 = linalg.generic {indexing_maps = [#map0, #map0, #map0], - iterator_types = ["parallel", "parallel"]} - ins(%arg0, %4 : tensor<64x128xf32>, tensor<64x128xf32>) - outs(%3 : tensor<64x128xf32>) { - ^bb0(%arg1: f32, %arg2: f32, %arg3: f32): + %4 = linalg.broadcast + ins(%2 : tensor<64xf32>) + outs(%3 : tensor<64x128xf32>) + dimensions = [1] + %5 = linalg.map ins(%arg0, %4 : tensor<64x128xf32>, tensor<64x128xf32>) + outs(%3 : tensor<64x128xf32>) + (%arg1: f32, %arg2: f32) { %11 = arith.subf %arg1, %arg2 : f32 linalg.yield %11 : f32 - } -> tensor<64x128xf32> - %6 = linalg.generic {indexing_maps = [#map0, #map0], - iterator_types = ["parallel", "parallel"]} ins(%5 : tensor<64x128xf32>) - outs(%3 : tensor<64x128xf32>) { - ^bb0(%arg1: f32, %arg2: f32): + } + %6 = linalg.map ins(%5 : tensor<64x128xf32>) + outs(%3 : tensor<64x128xf32>) + (%arg1: f32) { %11 = math.exp %arg1 : f32 linalg.yield %11 : f32 - } -> tensor<64x128xf32> + } %7 = linalg.fill ins(%cst : f32) outs(%0 : tensor<64xf32>) -> tensor<64xf32> - %8 = linalg.generic {indexing_maps = [#map0, #map1], - iterator_types = ["parallel", "reduction"]} ins(%6 : tensor<64x128xf32>) - outs(%7 : tensor<64xf32>) { - ^bb0(%arg1: f32, %arg2: f32): - %11 = arith.addf %arg2, %arg1 : f32 - 
linalg.yield %11 : f32 - } -> tensor<64xf32> - %9 = linalg.generic {indexing_maps = [#map1, #map0], - iterator_types = ["parallel", "parallel"]} ins(%8 : tensor<64xf32>) - outs(%3 : tensor<64x128xf32>) { - ^bb0(%arg1: f32, %arg2: f32): - linalg.yield %arg1 : f32 - } -> tensor<64x128xf32> - %10 = linalg.generic {indexing_maps = [#map0, #map0, #map0], - iterator_types = ["parallel", "parallel"]} - ins(%6, %9 : tensor<64x128xf32>, tensor<64x128xf32>) - outs(%3 : tensor<64x128xf32>) { - ^bb0(%arg1: f32, %arg2: f32, %arg3: f32): + %8 = linalg.reduce ins(%6 : tensor<64x128xf32>) + outs(%7 : tensor<64xf32>) dimensions = [1] + (%arg1: f32, %arg2: f32) { + %11 = arith.addf %arg2, %arg1 : f32 + linalg.yield %11 : f32 + } + %9 = linalg.broadcast + ins(%8 : tensor<64xf32>) + outs(%3 : tensor<64x128xf32>) + dimensions = [1] + %10 = linalg.map ins(%6, %9 : tensor<64x128xf32>, tensor<64x128xf32>) + outs(%3 : tensor<64x128xf32>) + (%arg1: f32, %arg2: f32) { %11 = arith.divf %arg1, %arg2 : f32 linalg.yield %11 : f32 - } -> tensor<64x128xf32> + } return %10 : tensor<64x128xf32> } // CHECK-LABEL: @softmax @@ -84,62 +68,51 @@ func.func @softmax(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { // CHECK-SAME: outs(%[[EMPTY]] : tensor<64xf32>) // CHECK: %[[PARALLEL:.*]] = gml_st.parallel (%[[ARG1:.*]]) = (%[[C0]]) to (%[[C64]]) step (%[[C8]]) -// CHECK-NEXT: %[[TILE:.*]] = gml_st.tile [%[[ARG1]], 0] [8, 128] [1, 1] -// CHECK-NEXT: %[[MATERIALIZE:.*]] = gml_st.materialize %[[ARG0]][%[[TILE]]] -// CHECK-NEXT: %[[TILE_0:.*]] = gml_st.tile [%[[ARG1]]] [8] [1] -// CHECK-NEXT: %[[MATERIALIZE_0:.*]] = gml_st.materialize %[[FILL]][%[[TILE_0]]] -// CHECK-NEXT: %[[MATERIALIZE_1:.*]] = gml_st.materialize %[[EMPTY_0]][%[[TILE]]] -// CHECK-NEXT: %[[MATERIALIZE_3:.*]] = gml_st.materialize %[[FILL_0]][%[[TILE_0]]] +// CHECK-SAME: outs (%[[EMPTY_:.*]] = %[[EMPTY_0]]: +// CHECK-DAG: %[[MATERIALIZE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG1]], 0] [8, 128] [1, 1] +// CHECK-DAG: %[[MATERIALIZE_0:.*]] = tensor.extract_slice %[[FILL]][%[[ARG1]]] [8] [1] +// CHECK-DAG: %[[MATERIALIZE_1:.*]] = tensor.extract_slice %[[EMPTY_0]][%[[ARG1]], 0] [8, 128] [1, 1] +// CHECK-DAG: %[[MATERIALIZE_3:.*]] = tensor.extract_slice %[[FILL_0]][%[[ARG1]]] [8] [1] +// CHECK-DAG: %[[EMPTY_SUB:.*]] = tensor.extract_slice %[[EMPTY_]] // CHECK: %[[PARALLEL_0:.*]] = gml_st.parallel (%[[ARG2:.*]]) = (%[[C0]]) to (%[[C8]]) step (%[[C1]]) -// CHECK-NEXT: %[[TILE_4:.*]] = gml_st.tile [%[[ARG2]], 0] [1, 128] [1, 1] -// CHECK-NEXT: %[[MATERIALIZE_4:.*]] = gml_st.materialize %[[MATERIALIZE]][%[[TILE_4]]] -// CHECK-NEXT: %[[TILE_5:.*]] = gml_st.tile [%[[ARG2]]] [1] [1] -// CHECK-NEXT: %[[MATERIALIZE_5:.*]] = gml_st.materialize %[[MATERIALIZE_0]][%[[TILE_5]]] - -// CHECK-NEXT: %[[GENERIC:.*]] = linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0:map[0-9]*]], #[[MAP1:map[0-9]*]]] -// CHECK-SAME: iterator_types = ["parallel", "reduction"] +// CHECK-SAME: outs (%[[EMPTY_SUB_:.*]] = %[[EMPTY_SUB]]: +// CHECK-NEXT: %[[MATERIALIZE_4:.*]] = tensor.extract_slice %[[MATERIALIZE]][%[[ARG2]], 0] [1, 128] [1, 1] +// CHECK-NEXT: %[[MATERIALIZE_5:.*]] = tensor.extract_slice %[[MATERIALIZE_0]][%[[ARG2]]] [1] [1] +// CHECK-NEXT: %[[REDUCE:.*]] = linalg.reduce // CHECK-SAME: ins(%[[MATERIALIZE_4]] : tensor<1x128xf32>) // CHECK-SAME: outs(%[[MATERIALIZE_5]] : tensor<1xf32>) +// CHECK-SAME: dimensions = [1] -// CHECK: %[[MATERIALIZE_6:.*]] = gml_st.materialize %[[MATERIALIZE_1]][%[[TILE_4]]] -// CHECK-NEXT: %[[GENERIC_0:.*]] = linalg.generic -// CHECK-SAME: indexing_maps = 
[#[[MAP1]], #[[MAP0]]] -// CHECK-SAME: iterator_types = ["parallel", "parallel"] -// CHECK-SAME: ins(%[[GENERIC]] : tensor<1xf32>) +// CHECK: %[[MATERIALIZE_6:.*]] = tensor.extract_slice %[[MATERIALIZE_1]][%[[ARG2]], 0] [1, 128] [1, 1] +// CHECK-NEXT: %[[BROADCAST:.*]] = linalg.broadcast +// CHECK-SAME: ins(%[[REDUCE]] : tensor<1xf32>) // CHECK-SAME: outs(%[[MATERIALIZE_6]] : tensor<1x128xf32>) +// CHECK-SAME: dimensions = [1] -// CHECK: %[[GENERIC_1:.*]] = linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]]] -// CHECK-SAME: iterator_types = ["parallel", "parallel"] -// CHECK-SAME: ins(%[[MATERIALIZE_4]], %[[GENERIC_0]] : tensor<1x128xf32>, tensor<1x128xf32>) +// CHECK: %[[MAP:.*]] = linalg.map +// CHECK-SAME: ins(%[[MATERIALIZE_4]], %[[BROADCAST]] : tensor<1x128xf32>, tensor<1x128xf32>) // CHECK-SAME: outs(%[[MATERIALIZE_6]] : tensor<1x128xf32>) -// CHECK: %[[GENERIC_2:.*]] = linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]]] -// CHECK-SAME: iterator_types = ["parallel", "parallel"] -// CHECK-SAME: ins(%[[GENERIC_1]] : tensor<1x128xf32>) +// CHECK: %[[MAP_0:.*]] = linalg.map +// CHECK-SAME: ins(%[[MAP]] : tensor<1x128xf32>) // CHECK-SAME: outs(%[[MATERIALIZE_6]] : tensor<1x128xf32>) -// CHECK: %[[MATERIALIZE_8:.*]] = gml_st.materialize %[[MATERIALIZE_3]][%[[TILE_5]]] -// CHECK-NEXT: %[[GENERIC_3:.*]] = linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] -// CHECK-SAME: iterator_types = ["parallel", "reduction"] -// CHECK-SAME: ins(%[[GENERIC_2]] : tensor<1x128xf32>) +// CHECK: %[[MATERIALIZE_8:.*]] = tensor.extract_slice %[[MATERIALIZE_3]][%[[ARG2]]] [1] [1] +// CHECK-NEXT: %[[REDUCE_0:.*]] = linalg.reduce +// CHECK-SAME: ins(%[[MAP_0]] : tensor<1x128xf32>) // CHECK-SAME: outs(%[[MATERIALIZE_8]] : tensor<1xf32>) -// CHECK: %[[GENERIC_4:.*]] = linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP0]]] -// CHECK-SAME: iterator_types = ["parallel", "parallel"] -// CHECK-SAME: ins(%[[GENERIC_3]] : tensor<1xf32>) +// CHECK: %[[BROADCAST_0:.*]] = linalg.broadcast +// CHECK-SAME: ins(%[[REDUCE_0]] : tensor<1xf32>) // CHECK-SAME: outs(%[[MATERIALIZE_6]] : tensor<1x128xf32>) -// CHECK: %[[GENERIC_5:.*]] = linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]] -// CHECK-SAME: iterator_types = ["parallel", "parallel"] -// CHECK-SAME: ins(%[[GENERIC_2]], %[[GENERIC_4]] : tensor<1x128xf32>, tensor<1x128xf32>) -// CHECK-SAME: outs(%[[MATERIALIZE_6]] : tensor<1x128xf32>) -// CHECK: gml_st.set_yield %[[GENERIC_5]] into %[[MATERIALIZE_1]][%[[TILE_4]]] -// CHECK: gml_st.set_yield %[[PARALLEL_0]] into %[[EMPTY_0]][%[[TILE]]] +// CHECK-NEXT: %[[MATERIALIZE_7:.*]] = tensor.extract_slice %[[EMPTY_SUB_]] +// CHECK: %[[MAP_1:.*]] = linalg.map +// CHECK-SAME: ins(%[[MAP_0]], %[[BROADCAST_0]] : tensor<1x128xf32>, tensor<1x128xf32>) +// CHECK-SAME: outs(%[[MATERIALIZE_7]] : tensor<1x128xf32>) +// CHECK: %[[TILE_4:.*]] = gml_st.tile [%[[ARG2]], 0] [1, 128] [1, 1] +// CHECK: gml_st.set_yield %[[MAP_1]] into %[[EMPTY_SUB_]][%[[TILE_4]]] +// CHECK: %[[TILE:.*]] = gml_st.tile [%[[ARG1]], 0] [8, 128] [1, 1] +// CHECK: gml_st.set_yield %[[PARALLEL_0]] into %[[EMPTY_]][%[[TILE]]] // CHECK: return %[[PARALLEL]] diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/ops.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/ops.mlir index 6c6e2adb757..d08ec1e958d 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/ops.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/ops.mlir @@ 
-22,46 +22,25 @@ func.func @dynamic_types(%size : index) { // ----- -// CHECK-LABEL: @materialize_static_tensor -// CHECK-SAME: %[[TENSOR:.*]]: tensor<64x32xf32>, %[[TILE:.*]]: !gml_st.tile<42x16> -func.func @materialize_static_tensor(%tensor: tensor<64x32xf32>, %tile: !gml_st.tile<42x16>) { - // CHECK: %{{.*}} = gml_st.materialize %[[TENSOR]][%[[TILE]]] : tensor<64x32xf32>[!gml_st.tile<42x16>] - %0 = gml_st.materialize %tensor[%tile] : tensor<64x32xf32>[!gml_st.tile<42x16>] to tensor<42x16xf32> - func.return -} - -// ----- - -// CHECK-LABEL: @materialize_dynamic_tensor -// CHECK-SAME: %[[TENSOR:.*]]: tensor, %[[TILE:.*]]: !gml_st.tile<42x16> -func.func @materialize_dynamic_tensor(%tensor: tensor, %tile: !gml_st.tile<42x16>) { - // CHECK: %{{.*}} = gml_st.materialize %[[TENSOR]][%[[TILE]]] : tensor[!gml_st.tile<42x16>] - %0 = gml_st.materialize %tensor[%tile] : tensor[!gml_st.tile<42x16>] to tensor<42x16xf32> - func.return -} - // CHECK-LABEL: @materialize_vector -// CHECK-SAME: %[[VECTOR:.*]]: vector<64x32xf32>, -// CHECK-SAME: %[[TILE:.*]]: !gml_st.tile<42x16> -func.func @materialize_vector(%vector: vector<64x32xf32>, - %tile: !gml_st.tile<42x16>) { - // CHECK: %{{.*}} = gml_st.materialize %[[VECTOR]][%[[TILE]]] - // CHECK-SAME: : vector<64x32xf32>[!gml_st.tile<42x16>] - %0 = gml_st.materialize %vector[%tile] - : vector<64x32xf32>[!gml_st.tile<42x16>] to vector<42x16xf32> +// CHECK-SAME: %[[VECTOR:.*]]: vector<64x32xf32> +func.func @materialize_vector(%vector: vector<64x32xf32>) { + // CHECK: %{{.*}} = gml_st.materialize %[[VECTOR]] + // CHECK-SAME: : vector<64x32xf32> + %0 = gml_st.materialize %vector[0, 0][42, 16][1, 1] + : vector<64x32xf32> to vector<42x16xf32> func.return } // ----- // CHECK-LABEL: @materialize_0d_vector -// CHECK-SAME: %[[VECTOR:.*]]: vector, -// CHECK-SAME: %[[TILE:.*]]: !gml_st.tile<> -func.func @materialize_0d_vector(%vector: vector, %tile: !gml_st.tile<>) { - // CHECK: %{{.*}} = gml_st.materialize %[[VECTOR]][%[[TILE]]] - // CHECK-SAME: : vector[!gml_st.tile<>] to vector - %0 = gml_st.materialize %vector[%tile] - : vector[!gml_st.tile<>] to vector +// CHECK-SAME: %[[VECTOR:.*]]: vector +func.func @materialize_0d_vector(%vector: vector) { + // CHECK: %{{.*}} = gml_st.materialize %[[VECTOR]] + // CHECK-SAME: : vector to vector + %0 = gml_st.materialize %vector[][][] + : vector to vector func.return } @@ -94,163 +73,6 @@ func.func @distribute_0d_vector(%vector: vector, %tile: !gml_st.tile<>) { // ----- -#cwise_trait = { - indexing_maps = [ - affine_map<(i, j) -> (i, j)>, - affine_map<(i, j) -> (i, j)>, - affine_map<(i, j) -> (i, j)> - ], - iterator_types = ["parallel", "parallel"] -} - -func.func @tiled_loop(%lhs: tensor<24x64xi8>, %rhs: tensor<24x64xi8>, - %out: tensor<24x64xi8>) -> tensor<24x64xi8> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c4 = arith.constant 4 : index - %c24 = arith.constant 24 : index - %c64 = arith.constant 64 : index - %prod = gml_st.loop (%i) = (%c0) to (%c24) step (%c4) - ins(%lhs_ = %lhs: tensor<24x64xi8>, %rhs_ = %rhs: tensor<24x64xi8>) - outs(%out_ = %out: tensor<24x64xi8>) { - %lhs_sub = tensor.extract_slice %lhs_[%i, 0] [%c4, %c64] [1, 1] - : tensor<24x64xi8> to tensor - %rhs_sub = tensor.extract_slice %rhs_[%i, 0] [%c4, %c64] [1, 1] - : tensor<24x64xi8> to tensor - %out_sub = tensor.extract_slice %out_[%i, 0] [%c4, %c64] [1, 1] - : tensor<24x64xi8> to tensor - - %sum = linalg.generic #cwise_trait - ins(%lhs_sub, %rhs_sub : tensor, tensor) - outs(%out_sub : tensor) { - ^bb(%l: i8, %r: i8, %o: i8) : - %s = 
arith.addi %l, %r : i8 - linalg.yield %s : i8 - } -> tensor - - %sum_sub = tensor.insert_slice %sum into %out_[%i, 0][%c4, %c64][1, 1] - : tensor into tensor<24x64xi8> - gml_st.yield %sum_sub : tensor<24x64xi8> - } - func.return %prod : tensor<24x64xi8> -} -// CHECK-LABEL: func @tiled_loop -// CHECK-NOT: iterators[ - -// ----- - -#reduction_trait = { - indexing_maps = [ - affine_map<(d0, d1, d2) -> (d0, d1, d2)>, - affine_map<(d0, d1, d2) -> (d0, d2)>, - affine_map<(d0, d1, d2) -> (d1)>, - affine_map<(d0, d1, d2) -> (d1)> - ], - iterator_types = ["reduction", "parallel", "reduction"] -} - -func.func @tiled_loop_reduction(%input_3d: tensor<16x24x32xf32>, - %input_2d: tensor<16x32xf32>, - %input_1d: tensor<24xf32>, - %output: tensor<24xf32>) -> tensor<24xf32> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c4 = arith.constant 4 : index - %c8 = arith.constant 8 : index - %X = tensor.dim %input_3d, %c0 : tensor<16x24x32xf32> - %Y = tensor.dim %input_3d, %c1 : tensor<16x24x32xf32> - %Z = tensor.dim %input_3d, %c2 : tensor<16x24x32xf32> - %result = gml_st.loop (%i, %j, %k) - = (%c0, %c0, %c0) to (%X, %Y, %Z) step (%c2, %c4, %c8) - ins(%i3d_ = %input_3d: tensor<16x24x32xf32>, - %i2d_ = %input_2d: tensor<16x32xf32>, - %i1d_ = %input_1d: tensor<24xf32>) - outs(%o_ = %output: tensor<24xf32>) - iterators[#gml_st.iterator_type, - #gml_st.iterator_type, - #gml_st.iterator_type] - distribution["block_x", "block_y", "none"] { - %sub_3d = tensor.extract_slice %i3d_[%i, %j, %k][2, 4, 8][1, 1, 1] - : tensor<16x24x32xf32> to tensor<2x4x8xf32> - %sub_2d = tensor.extract_slice %i2d_[%i, %k][2, 8][1, 1] - : tensor<16x32xf32> to tensor<2x8xf32> - %sub_1d = tensor.extract_slice %i1d_[%j] [4] [1] - : tensor<24xf32> to tensor<4xf32> - %sub_out = tensor.extract_slice %o_[%j] [4] [1] - : tensor<24xf32> to tensor<4xf32> - %acc = linalg.generic #reduction_trait - ins(%sub_3d, %sub_2d, %sub_1d - : tensor<2x4x8xf32>, tensor<2x8xf32>, tensor<4xf32>) - outs(%sub_out : tensor<4xf32>) { - ^bb0(%i3d: f32, %i2d: f32, %i1d: f32, %o: f32): - %0 = arith.addf %i3d, %i2d : f32 - %1 = arith.addf %0, %i1d : f32 - linalg.yield %1 : f32 - } -> tensor<4xf32> - - %sum_sub = tensor.insert_slice %acc into %o_[%j][4][1] - : tensor<4xf32> into tensor<24xf32> - gml_st.yield %sum_sub : tensor<24xf32> - } - func.return %result : tensor<24xf32> -} -// CHECK-LABEL: func @tiled_loop_reduction -// CHECK: iterators[ - -#map_1 = affine_map<(d0, d1, d2)[s0] -> (d0 * 768 + s0 + d1 * 32 + d2)> -#map_2 = affine_map<(d0, d1)[s0] -> (d0 * 32 + s0 + d1)> -#map_3 = affine_map<(d0)[s0] -> (d0 + s0)> - -func.func @tiled_loop_on_buffers(%input_3d: memref<16x24x32xf32>, - %input_2d: memref<16x32xf32>, - %input_1d: memref<24xf32>, - %output: memref<24xf32>) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c4 = arith.constant 4 : index - %c8 = arith.constant 8 : index - %X = memref.dim %input_3d, %c0 : memref<16x24x32xf32> - %Y = memref.dim %input_3d, %c1 : memref<16x24x32xf32> - %Z = memref.dim %input_3d, %c2 : memref<16x24x32xf32> - gml_st.loop (%i, %j, %k) = (%c0, %c0, %c0) - to (%X, %Y, %Z) step (%c2, %c4, %c8) - ins(%i3d_ = %input_3d: memref<16x24x32xf32>, - %i2d_ = %input_2d: memref<16x32xf32>, - %i1d_ = %input_1d: memref<24xf32>) - outs(%o_ = %output: memref<24xf32>) - iterators[#gml_st.iterator_type, - #gml_st.iterator_type, - #gml_st.iterator_type] { - %sub_3d = memref.subview %i3d_[%i, %j, %k][2, 4, 8][1, 1, 1] - : memref<16x24x32xf32> to 
memref<2x4x8xf32, #map_1> - %sub_2d = memref.subview %i2d_[%i, %k][2, 8][1, 1] - : memref<16x32xf32> to memref<2x8xf32, #map_2> - %sub_1d = memref.subview %i1d_[%j] [4] [1] - : memref<24xf32> to memref<4xf32, #map_3> - %sub_out = memref.subview %o_[%j] [4] [1] - : memref<24xf32> to memref<4xf32, #map_3> - linalg.generic #reduction_trait - ins(%sub_3d, %sub_2d, %sub_1d - : memref<2x4x8xf32, #map_1>, - memref<2x8xf32, #map_2>, - memref<4xf32, #map_3>) - outs(%sub_out : memref<4xf32, #map_3>) { - ^bb0(%i3d: f32, %i2d: f32, %i1d: f32, %o: f32): - %0 = arith.addf %i3d, %i2d : f32 - %1 = arith.addf %0, %i1d : f32 - linalg.yield %1 : f32 - } - gml_st.yield - } - func.return -} -// CHECK-LABEL: func @tiled_loop_on_buffers -// CHECK: iterators[ - -// ----- - #id_1d = affine_map<(d0) -> (d0)> func.func @parallel_loop(%lhs: tensor<8xf32>, %rhs: tensor<8xf32>, @@ -259,14 +81,14 @@ func.func @parallel_loop(%lhs: tensor<8xf32>, %rhs: tensor<8xf32>, %c4 = arith.constant 4 : index %c8 = arith.constant 8 : index - %sum = gml_st.parallel (%i) = (%c0) to (%c8) step (%c4) { - %tile = gml_st.tile [%i] [4] [1] : !gml_st.tile<4> - %lhs_sub = gml_st.materialize %lhs[%tile] - : tensor<8xf32>[!gml_st.tile<4>] to tensor<4xf32> - %rhs_sub = gml_st.materialize %rhs[%tile] - : tensor<8xf32>[!gml_st.tile<4>] to tensor<4xf32> - %out_sub = gml_st.materialize %output[%tile] - : tensor<8xf32>[!gml_st.tile<4>] to tensor<4xf32> + %sum = gml_st.parallel (%i) = (%c0) to (%c8) step (%c4) + outs(%out_ = %output : tensor<8xf32>) { + %lhs_sub = tensor.extract_slice %lhs[%i] [4] [1] + : tensor<8xf32> to tensor<4xf32> + %rhs_sub = tensor.extract_slice %rhs[%i] [4] [1] + : tensor<8xf32> to tensor<4xf32> + %out_sub = tensor.extract_slice %output[%i] [4] [1] + : tensor<8xf32> to tensor<4xf32> %result_sub = linalg.generic { indexing_maps = [#id_1d, #id_1d, #id_1d], @@ -278,7 +100,8 @@ func.func @parallel_loop(%lhs: tensor<8xf32>, %rhs: tensor<8xf32>, linalg.yield %s : f32 } -> tensor<4xf32> - gml_st.set_yield %result_sub into %output[%tile] + %tile = gml_st.tile [%i] [4] [1] : !gml_st.tile<4> + gml_st.set_yield %result_sub into %out_[%tile] : tensor<4xf32> into tensor<8xf32>[!gml_st.tile<4>] } : tensor<8xf32> func.return %sum : tensor<8xf32> @@ -293,9 +116,10 @@ func.func @loop_on_points(%output: tensor<8xf32>) -> tensor<8xf32> { %c8 = arith.constant 8 : index %c0_f32 = arith.constant 0.0 : f32 - %sum = gml_st.parallel (%i) = (%c0) to (%c8) step (%c1) { + %sum = gml_st.parallel (%i) = (%c0) to (%c8) step (%c1) + outs(%out_ = %output : tensor<8xf32>) { %tile = gml_st.tile [%i] [1] [1] : !gml_st.tile<1> - gml_st.set_yield %c0_f32 into %output[%tile] + gml_st.set_yield %c0_f32 into %out_[%tile] : f32 into tensor<8xf32>[!gml_st.tile<1>] } : tensor<8xf32> func.return %sum : tensor<8xf32> @@ -309,9 +133,10 @@ func.func @parallel_with_distribution(%output: tensor<8xf32>) -> tensor<8xf32> { %c8 = arith.constant 8 : index %c0_f32 = arith.constant 0.0 : f32 - %sum = gml_st.parallel (%i) = (%c0) to (%c8) step (%c1) distribution ("x") { + %sum = gml_st.parallel (%i) = (%c0) to (%c8) step (%c1) + outs(%out_ = %output : tensor<8xf32>) distribution ("x") { %tile = gml_st.tile [%i] [1] [1] : !gml_st.tile<1> - gml_st.set_yield %c0_f32 into %output[%tile] + gml_st.set_yield %c0_f32 into %out_[%tile] : f32 into tensor<8xf32>[!gml_st.tile<1>] } : tensor<8xf32> func.return %sum : tensor<8xf32> @@ -327,9 +152,10 @@ func.func @loop_on_vector(%output: vector<8xf32>, %fill: vector<2xf32>) %c2 = arith.constant 2 : index %c8 = arith.constant 8 : index - %sum = 
gml_st.parallel (%i) = (%c0) to (%c8) step (%c2) { + %sum = gml_st.parallel (%i) = (%c0) to (%c8) step (%c2) + outs(%out_ = %output : vector<8xf32>) { %tile = gml_st.tile [%i] [2] [1] : !gml_st.tile<2> - gml_st.set_yield %fill into %output[%tile] + gml_st.set_yield %fill into %out_[%tile] : vector<2xf32> into vector<8xf32>[!gml_st.tile<2>] } : vector<8xf32> func.return %sum : vector<8xf32> @@ -349,15 +175,14 @@ func.func @for_loop(%lhs: tensor<8xf32>, %rhs: tensor<8xf32>, %sum, %sum2 = gml_st.for (%i) = (%c0) to (%c8) step (%c4) outs(%out_ = %output : tensor<8xf32>, %out2_ = %output2 : tensor<8xf32>) { - %tile = gml_st.tile [%i] [4] [1] : !gml_st.tile<4> - %lhs_sub = gml_st.materialize %lhs[%tile] - : tensor<8xf32>[!gml_st.tile<4>] to tensor<4xf32> - %rhs_sub = gml_st.materialize %rhs[%tile] - : tensor<8xf32>[!gml_st.tile<4>] to tensor<4xf32> - %out_sub = gml_st.materialize %out_[%tile] - : tensor<8xf32>[!gml_st.tile<4>] to tensor<4xf32> - %out2_sub = gml_st.materialize %out_[%tile] - : tensor<8xf32>[!gml_st.tile<4>] to tensor<4xf32> + %lhs_sub = tensor.extract_slice %lhs [%i] [4] [1] + : tensor<8xf32> to tensor<4xf32> + %rhs_sub = tensor.extract_slice %rhs [%i] [4] [1] + : tensor<8xf32> to tensor<4xf32> + %out_sub = tensor.extract_slice %out_ [%i] [4] [1] + : tensor<8xf32> to tensor<4xf32> + %out2_sub = tensor.extract_slice %out_ [%i] [4] [1] + : tensor<8xf32> to tensor<4xf32> %result_sub = linalg.generic { indexing_maps = [#id_1d, #id_1d, #id_1d], @@ -369,6 +194,7 @@ func.func @for_loop(%lhs: tensor<8xf32>, %rhs: tensor<8xf32>, linalg.yield %s : f32 } -> tensor<4xf32> + %tile = gml_st.tile [%i] [4] [1] : !gml_st.tile<4> gml_st.set_yield %result_sub into %out_[%tile] : tensor<4xf32> into tensor<8xf32>[!gml_st.tile<4>], %result_sub into %out2_[%tile] @@ -388,14 +214,14 @@ func.func @trivial_acc_region(%lhs: tensor<8xf32>, %c4 = arith.constant 4 : index %c8 = arith.constant 8 : index - %sum = gml_st.parallel (%i) = (%c0) to (%c8) step (%c4) { - %tile = gml_st.tile [%i] [4] [1] : !gml_st.tile<4> - %lhs_sub = gml_st.materialize %lhs[%tile] - : tensor<8xf32>[!gml_st.tile<4>] to tensor<4xf32> - %rhs_sub = gml_st.materialize %rhs[%tile] - : tensor<8xf32>[!gml_st.tile<4>] to tensor<4xf32> - %out_sub = gml_st.materialize %output[%tile] - : tensor<8xf32>[!gml_st.tile<4>] to tensor<4xf32> + %sum = gml_st.parallel (%i) = (%c0) to (%c8) step (%c4) + outs(%out_ = %output : tensor<8xf32>) { + %lhs_sub = tensor.extract_slice %lhs [%i] [4] [1] + : tensor<8xf32> to tensor<4xf32> + %rhs_sub = tensor.extract_slice %rhs[%i] [4] [1] + : tensor<8xf32> to tensor<4xf32> + %out_sub = tensor.extract_slice %output[%i] [4] [1] + : tensor<8xf32> to tensor<4xf32> %result_sub = linalg.generic { indexing_maps = [#id_1d, #id_1d, #id_1d], @@ -407,7 +233,8 @@ func.func @trivial_acc_region(%lhs: tensor<8xf32>, linalg.yield %s : f32 } -> tensor<4xf32> - gml_st.set_yield %result_sub into %output[%tile] + %tile = gml_st.tile [%i] [4] [1] : !gml_st.tile<4> + gml_st.set_yield %result_sub into %out_[%tile] acc (%new, %old: tensor<4xf32>) { gml_st.yield %new : tensor<4xf32> } : tensor<4xf32> into tensor<8xf32>[!gml_st.tile<4>] @@ -431,16 +258,16 @@ func.func @two_acc_region(%lhs: tensor<8xf32>, %rhs: tensor<8xf32>, %c4 = arith.constant 4 : index %c8 = arith.constant 8 : index - %result:2 = gml_st.parallel (%i) = (%c0) to (%c8) step (%c4) { - %tile = gml_st.tile [%i] [4] [1] : !gml_st.tile<4> - %lhs_sub = gml_st.materialize %lhs[%tile] - : tensor<8xf32>[!gml_st.tile<4>] to tensor<4xf32> - %rhs_sub = gml_st.materialize %rhs[%tile] - 
: tensor<8xf32>[!gml_st.tile<4>] to tensor<4xf32> - %out_sub = gml_st.materialize %output[%tile] - : tensor<8xf32>[!gml_st.tile<4>] to tensor<4xf32> - %out_2_sub = gml_st.materialize %output_2[%tile] - : tensor<8xf32>[!gml_st.tile<4>] to tensor<4xf32> + %result:2 = gml_st.parallel (%i) = (%c0) to (%c8) step (%c4) + outs(%out_ = %output : tensor<8xf32>, %out2_ = %output_2 : tensor<8xf32>) { + %lhs_sub = tensor.extract_slice %lhs [%i] [4] [1] + : tensor<8xf32> to tensor<4xf32> + %rhs_sub = tensor.extract_slice %rhs [%i] [4] [1] + : tensor<8xf32> to tensor<4xf32> + %out_sub = tensor.extract_slice %output [%i] [4] [1] + : tensor<8xf32> to tensor<4xf32> + %out_2_sub = tensor.extract_slice %output_2 [%i] [4] [1] + : tensor<8xf32> to tensor<4xf32> %result_sub = linalg.generic { indexing_maps = [#id_1d, #id_1d, #id_1d], @@ -452,11 +279,12 @@ func.func @two_acc_region(%lhs: tensor<8xf32>, %rhs: tensor<8xf32>, linalg.yield %s : f32 } -> tensor<4xf32> + %tile = gml_st.tile [%i] [4] [1] : !gml_st.tile<4> gml_st.set_yield - %result_sub into %output[%tile] acc (%new, %old: tensor<4xf32>) { + %result_sub into %out_[%tile] acc (%new, %old: tensor<4xf32>) { gml_st.yield %new : tensor<4xf32> } : tensor<4xf32> into tensor<8xf32>[!gml_st.tile<4>], - %result_sub into %output_2[%tile] acc (%new, %old: tensor<4xf32>) { + %result_sub into %out2_[%tile] acc (%new, %old: tensor<4xf32>) { %sum = linalg.generic { indexing_maps = [#id_1d, #id_1d], iterator_types = ["parallel"]} @@ -493,14 +321,14 @@ func.func @accumulator_region(%lhs: tensor<8xf32>, %c4 = arith.constant 4 : index %c8 = arith.constant 8 : index - %sum = gml_st.parallel (%i) = (%c0) to (%c8) step (%c4) { - %tile = gml_st.tile [%i] [4] [1] : !gml_st.tile<4> - %lhs_sub = gml_st.materialize %lhs[%tile] - : tensor<8xf32>[!gml_st.tile<4>] to tensor<4xf32> - %rhs_sub = gml_st.materialize %rhs[%tile] - : tensor<8xf32>[!gml_st.tile<4>] to tensor<4xf32> - %out_sub = gml_st.materialize %output[%tile] - : tensor<8xf32>[!gml_st.tile<4>] to tensor<4xf32> + %sum = gml_st.parallel (%i) = (%c0) to (%c8) step (%c4) + outs(%out_ = %output : tensor<8xf32>) { + %lhs_sub = tensor.extract_slice %lhs [%i] [4] [1] + : tensor<8xf32> to tensor<4xf32> + %rhs_sub = tensor.extract_slice %rhs [%i] [4] [1] + : tensor<8xf32> to tensor<4xf32> + %out_sub = tensor.extract_slice %output [%i] [4] [1] + : tensor<8xf32> to tensor<4xf32> %result_sub = linalg.generic { indexing_maps = [#id_1d, #id_1d, #id_1d], @@ -512,7 +340,8 @@ func.func @accumulator_region(%lhs: tensor<8xf32>, linalg.yield %s : f32 } -> tensor<4xf32> - gml_st.set_yield %result_sub into %output[%tile] + %tile = gml_st.tile [%i] [4] [1] : !gml_st.tile<4> + gml_st.set_yield %result_sub into %out_[%tile] acc (%new, %old: tensor<4xf32>) { %sum = linalg.generic { indexing_maps = [#id_1d, #id_1d], @@ -547,11 +376,10 @@ func.func @reduce_tiles(%arg: tensor<8xf32>, %c8 = arith.constant 8 : index %c0_f32 = arith.constant 0.0 : f32 - %init_tile = gml_st.tile [] [] [] : !gml_st.tile<> - %sum = gml_st.parallel (%i) = (%c0) to (%c8) step (%c4) { - %tile = gml_st.tile [%i] [4] [1] : !gml_st.tile<4> - %arg_sub = gml_st.materialize %arg[%tile] - : tensor<8xf32>[!gml_st.tile<4>] to tensor<4xf32> + %sum = gml_st.parallel (%i) = (%c0) to (%c8) step (%c4) + outs(%out_ = %output : tensor) { + %arg_sub = tensor.extract_slice %arg [%i] [4] [1] + : tensor<8xf32> to tensor<4xf32> %local_init = tensor.empty() : tensor %local_fill = linalg.fill @@ -568,7 +396,8 @@ func.func @reduce_tiles(%arg: tensor<8xf32>, linalg.yield %s : f32 } -> tensor - 
gml_st.set_yield %result_sub into %output[%init_tile] + %init_tile = gml_st.tile [] [] [] : !gml_st.tile<> + gml_st.set_yield %result_sub into %out_[%init_tile] acc (%in, %out: tensor) { %in_pt = tensor.extract %in[] : tensor %out_pt = tensor.extract %out[] : tensor @@ -598,10 +427,10 @@ func.func @column_reduction(%arg: tensor<128x16xf32>, %c128 = arith.constant 128 : index %cst = arith.constant 0.000000e+00 : f32 - %sum = gml_st.parallel (%i, %j) = (%c0, %c0) to (%c128, %c16) step (%c8, %c8) { - %arg_tile = gml_st.tile [%i, %j] [8, 8] [1, 1] : !gml_st.tile<8x8> - %arg_sub = gml_st.materialize %arg[%arg_tile] - : tensor<128x16xf32>[!gml_st.tile<8x8>] to tensor<8x8xf32> + %sum = gml_st.parallel (%i, %j) = (%c0, %c0) to (%c128, %c16) step (%c8, %c8) + outs(%out_ = %out : tensor<16xf32>) { + %arg_sub = tensor.extract_slice %arg[%i, %j] [8, 8] [1, 1] + : tensor<128x16xf32> to tensor<8x8xf32> %init = tensor.empty() : tensor<8xf32> %fill = linalg.fill ins(%cst : f32) @@ -618,8 +447,7 @@ func.func @column_reduction(%arg: tensor<128x16xf32>, } -> tensor<8xf32> %out_tile = gml_st.tile [%j] [8] [1] : !gml_st.tile<8> - - gml_st.set_yield %result_sub into %out[%out_tile] + gml_st.set_yield %result_sub into %out_[%out_tile] acc (%new, %old: tensor<8xf32>) { %acc = linalg.generic { indexing_maps = [#id_1d, #id_1d], @@ -655,13 +483,11 @@ func.func @sequential_column_reduction(%arg: tensor<128x16xf32>, %sum = gml_st.for (%i, %j) = (%c0, %c0) to (%c128, %c16) step (%c8, %c8) outs(%out_ = %out : tensor<16xf32>) { - %arg_tile = gml_st.tile [%i, %j] [8, 8] [1, 1] : !gml_st.tile<8x8> - %arg_sub = gml_st.materialize %arg[%arg_tile] - : tensor<128x16xf32>[!gml_st.tile<8x8>] to tensor<8x8xf32> + %arg_sub = tensor.extract_slice %arg[%i, %j] [8, 8] [1, 1] + : tensor<128x16xf32> to tensor<8x8xf32> - %out_tile = gml_st.tile [%j] [8] [1] : !gml_st.tile<8> - %out_sub = gml_st.materialize %out_[%out_tile] - : tensor<16xf32>[!gml_st.tile<8>] to tensor<8xf32> + %out_sub = tensor.extract_slice %out_[%j] [8] [1] + : tensor<16xf32> to tensor<8xf32> %result_sub = linalg.generic { indexing_maps = [#id_2d, #map_1d], @@ -673,6 +499,7 @@ func.func @sequential_column_reduction(%arg: tensor<128x16xf32>, linalg.yield %s : f32 } -> tensor<8xf32> + %out_tile = gml_st.tile [%j] [8] [1] : !gml_st.tile<8> gml_st.set_yield %result_sub into %out_[%out_tile] : tensor<8xf32> into tensor<16xf32>[!gml_st.tile<8>] } : tensor<16xf32> diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/rewrite_vector_contract.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/rewrite_vector_contract.mlir new file mode 100644 index 00000000000..eabf0c4bbdd --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/rewrite_vector_contract.mlir @@ -0,0 +1,166 @@ +// RUN: mlir-hlo-opt %s --split-input-file --rewrite-vector-contract | FileCheck %s + +func.func @lower_vector_contract(%arg0: tensor<8x8xf32>, %arg1: tensor<8x8xf32>) + -> tensor<8x8xf32> { + %c0 = arith.constant 0 : index + %cst_0 = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<8x8xf32> + %2 = vector.transfer_read %arg0[%c0, %c0], %cst_0 {in_bounds = [true, true]} : tensor<8x8xf32>, vector<8x8xf32> + %3 = vector.transfer_read %arg1[%c0, %c0], %cst_0 {in_bounds = [true, true]} : tensor<8x8xf32>, vector<8x8xf32> + %4 = vector.transfer_read %0[%c0, %c0], %cst_0 {in_bounds = [true, true]} : tensor<8x8xf32>, vector<8x8xf32> + %5 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, 
affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %2, %3, %4 : vector<8x8xf32>, vector<8x8xf32> into vector<8x8xf32> + %6 = vector.transfer_write %5, %0[%c0, %c0] {in_bounds = [true, true]} : vector<8x8xf32>, tensor<8x8xf32> + return %6 : tensor<8x8xf32> +} + +// CHECK-LABEL: func @lower_vector_contract( +// CHECK-SAME: %[[LHS:.*]]: tensor<8x8xf32>, %[[RHS:.*]]: tensor<8x8xf32>) + +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[INIT:.*]] = tensor.empty + +// CHECK: %[[LHS_READ:.*]] = vector.transfer_read %[[LHS]] +// CHECK: %[[RHS_READ:.*]] = vector.transfer_read %[[RHS]] +// CHECK: %[[OUT_READ:.*]] = vector.transfer_read %[[INIT]] +// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[LHS_READ]] +// CHECK: %[[EXTRACT_LHS0:.*]] = vector.extract %[[TRANSPOSE]][0] +// CHECK: %[[EXTRACT_RHS0:.*]] = vector.extract %[[RHS_READ]][0] +// CHECK: %[[PRODUCT0:.*]] = vector.outerproduct %[[EXTRACT_LHS0]], %[[EXTRACT_RHS0]], %[[OUT_READ]] +// CHECK: %[[EXTRACT_LHS1:.*]] = vector.extract %[[TRANSPOSE]][1] +// CHECK: %[[EXTRACT_RHS1:.*]] = vector.extract %[[RHS_READ]][1] +// CHECK: %[[PRODUCT1:.*]] = vector.outerproduct %[[EXTRACT_LHS1]], %[[EXTRACT_RHS1]], %[[PRODUCT0]] +// CHECK: %[[PRODUCT2:.*]] = vector.outerproduct %[[EXTRACT_LHS2:.*]], %[[EXTRACT_RHS2:.*]], %[[PRODUCT1]] +// CHECK: %[[PRODUCT3:.*]] = vector.outerproduct %[[EXTRACT_LHS3:.*]], %[[EXTRACT_RHS3:.*]], %[[PRODUCT2]] +// CHECK: %[[PRODUCT4:.*]] = vector.outerproduct %[[EXTRACT_LHS4:.*]], %[[EXTRACT_RHS4:.*]], %[[PRODUCT3]] +// CHECK: %[[PRODUCT5:.*]] = vector.outerproduct %[[EXTRACT_LHS5:.*]], %[[EXTRACT_RHS5:.*]], %[[PRODUCT4]] +// CHECK: %[[PRODUCT6:.*]] = vector.outerproduct %[[EXTRACT_LHS6:.*]], %[[EXTRACT_RHS6:.*]], %[[PRODUCT5]] +// CHECK: %[[PRODUCT7:.*]] = vector.outerproduct %[[EXTRACT_LHS7:.*]], %[[EXTRACT_RHS7:.*]], %[[PRODUCT6]] +// CHECK: %[[RET:.*]] = vector.transfer_write %[[PRODUCT7]], %[[INIT]] +// CHECK: return %[[RET]] + +// ----- + +func.func @canonicalize_outer_product(%arg0: tensor<8x8xf32>, %arg1: tensor<8x8xf32>) + -> tensor<8x8xf32> { + %c0 = arith.constant 0 : index + %cst = arith.constant dense<0.000000e+00> : vector<8x8xf32> + %cst_0 = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<8x8xf32> + %2 = vector.transfer_read %arg0[%c0, %c0], %cst_0 {in_bounds = [true, true]} : tensor<8x8xf32>, vector<8x8xf32> + %3 = vector.transfer_read %arg1[%c0, %c0], %cst_0 {in_bounds = [true, true]} : tensor<8x8xf32>, vector<8x8xf32> + %5 = gml_st.materialize %cst[0, 0] [8, 8] [1, 1] : vector<8x8xf32> to vector<8x8xf32> + %6 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %2, %3, %5 : vector<8x8xf32>, vector<8x8xf32> into vector<8x8xf32> + %7 = vector.transfer_write %6, %0[%c0, %c0] {in_bounds = [true, true]} : vector<8x8xf32>, tensor<8x8xf32> + return %7 : tensor<8x8xf32> +} + +// CHECK-LABEL: func @canonicalize_outer_product( +// CHECK-SAME: %[[LHS:.*]]: tensor<8x8xf32>, %[[RHS:.*]]: tensor<8x8xf32>) + +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x8xf32> +// CHECK: %[[INIT:.*]] = tensor.empty + +// CHECK: %[[LHS_READ:.*]] = vector.transfer_read %[[LHS]] +// CHECK: %[[RHS_READ:.*]] = vector.transfer_read %[[RHS]] +// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[LHS_READ]] 
+// CHECK: %[[EXTRACT_LHS0:.*]] = vector.extract %[[TRANSPOSE]][0] +// CHECK: %[[EXTRACT_RHS0:.*]] = vector.extract %[[RHS_READ]][0] +// CHECK: %[[PRODUCT0:.*]] = vector.outerproduct %[[EXTRACT_LHS0]], %[[EXTRACT_RHS0]], %[[CST]] + +// ----- + +func.func @lower_vector_contract_4d(%arg0: tensor<1x1x8x1xf32>, + %arg1: tensor<1x1x8x1xf32>) + -> tensor<1x1x8x8xf32> { + %c0 = arith.constant 0 : index + %4 = tensor.empty() : tensor<1x1x8x8xf32> + %cst = arith.constant 0.000000e+00 : f32 + %20 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %cst + {in_bounds = [true, true, true, true]} : tensor<1x1x8x1xf32>, + vector<1x1x8x1xf32> + %21 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst + {in_bounds = [true, true, true, true]} : tensor<1x1x8x1xf32>, + vector<1x1x8x1xf32> + %22 = vector.transfer_read %4[%c0, %c0, %c0, %c0], %cst + {in_bounds = [true, true, true, true]} : tensor<1x1x8x8xf32>, + vector<1x1x8x8xf32> + %23 = vector.contract {indexing_maps = + [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>], + iterator_types = ["parallel", "parallel", "reduction", + "parallel", "parallel", "reduction"], + kind = #vector.kind} + %20, %21, %22 : vector<1x1x8x1xf32>, vector<1x1x8x1xf32> + into vector<1x1x8x8xf32> + %14 = vector.transfer_write %23, %4[%c0, %c0, %c0, %c0] + {in_bounds = [true, true, true, true]} : vector<1x1x8x8xf32>, + tensor<1x1x8x8xf32> + return %14 : tensor<1x1x8x8xf32> +} + +// CHECK-LABEL: func @lower_vector_contract_4d( +// CHECK-SAME: %[[LHS:.*]]: tensor<1x1x8x1xf32>, %[[RHS:.*]]: tensor<1x1x8x1xf32>) + +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[INIT:.*]] = tensor.empty + +// CHECK: %[[LHS_READ:.*]] = vector.transfer_read %[[LHS]]{{.*}} tensor<1x1x8x1xf32>, vector<8x1xf32> +// CHECK: %[[RHS_READ:.*]] = vector.transfer_read %[[RHS]]{{.*}} tensor<1x1x8x1xf32>, vector<8x1xf32> +// CHECK: %[[OUT_READ:.*]] = vector.transfer_read %[[INIT]]{{.*}} vector<8x8xf32> +// CHECK: %[[LHS_TRANSPOSE:.*]] = vector.transpose %[[LHS_READ]]{{.*}} : vector<8x1xf32> to vector<1x8xf32> +// CHECK: %[[RHS_TRANSPOSE:.*]] = vector.transpose %[[RHS_READ]]{{.*}} : vector<8x1xf32> to vector<1x8xf32> +// CHECK: %[[EXTRACT_LHS:.*]] = vector.extract %[[LHS_TRANSPOSE]][0] +// CHECK: %[[EXTRACT_RHS:.*]] = vector.extract %[[RHS_TRANSPOSE]][0] +// CHECK: %[[PRODUCT:.*]] = vector.outerproduct %[[EXTRACT_LHS]], %[[EXTRACT_RHS]], %[[OUT_READ]] +// CHECK: %[[RET:.*]] = vector.transfer_write %[[PRODUCT]], %[[INIT]]{{.*}} vector<8x8xf32>, tensor<1x1x8x8xf32> +// CHECK: return %[[RET]] + +// ----- + +func.func @lower_vector_contract_4d_matvec(%arg0: tensor<1x1x1x1xf32>, + %arg1: tensor<1x1x8x1xf32>) + -> tensor<1x1x1x8xf32> { + %c0 = arith.constant 0 : index + %4 = tensor.empty() : tensor<1x1x1x8xf32> + %cst = arith.constant 0.000000e+00 : f32 + %20 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %cst + {in_bounds = [true, true, true, true]} : tensor<1x1x1x1xf32>, + vector<1x1x1x1xf32> + %21 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst + {in_bounds = [true, true, true, true]} : tensor<1x1x8x1xf32>, + vector<1x1x8x1xf32> + %22 = vector.transfer_read %4[%c0, %c0, %c0, %c0], %cst + {in_bounds = [true, true, true, true]} : tensor<1x1x1x8xf32>, + vector<1x1x1x8xf32> + %23 = vector.contract {indexing_maps = + [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>, + affine_map<(d0, d1, d2, d3, d4, d5) 
-> (d0, d1, d3, d4)>], + iterator_types = ["parallel", "parallel", "reduction", + "parallel", "parallel", "reduction"], + kind = #vector.kind} + %20, %21, %22 : vector<1x1x1x1xf32>, vector<1x1x8x1xf32> + into vector<1x1x1x8xf32> + %14 = vector.transfer_write %23, %4[%c0, %c0, %c0, %c0] + {in_bounds = [true, true, true, true]} : vector<1x1x1x8xf32>, + tensor<1x1x1x8xf32> + return %14 : tensor<1x1x1x8xf32> +} + +// CHECK-LABEL: func @lower_vector_contract_4d_matvec( +// CHECK-SAME: %[[LHS:.*]]: tensor<1x1x1x1xf32>, %[[RHS:.*]]: tensor<1x1x8x1xf32>) + +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[INIT:.*]] = tensor.empty + +// CHECK: %[[LHS_READ:.*]] = vector.transfer_read %[[LHS]]{{.*}} tensor<1x1x1x1xf32>, vector<1xf32> +// CHECK: %[[RHS_READ:.*]] = vector.transfer_read %[[RHS]]{{.*}} tensor<1x1x8x1xf32>, vector<8x1xf32> +// CHECK: %[[OUT_READ:.*]] = vector.transfer_read %[[INIT]]{{.*}} vector<8xf32> +// CHECK: %[[RHS_TRANSPOSE:.*]] = vector.transpose %[[RHS_READ]]{{.*}} : vector<8x1xf32> to vector<1x8xf32> +// CHECK: %[[EXTRACT_RHS:.*]] = vector.extract %[[RHS_TRANSPOSE]][0] +// CHECK: %[[EXTRACT_LHS:.*]] = vector.extract %[[LHS_READ]][0] +// CHECK: %[[PRODUCT:.*]] = vector.outerproduct %[[EXTRACT_RHS]], %[[EXTRACT_LHS]], %[[OUT_READ]] +// CHECK: %[[RET:.*]] = vector.transfer_write %[[PRODUCT]], %[[INIT]]{{.*}} vector<8xf32>, tensor<1x1x1x8xf32> +// CHECK: return %[[RET]] diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_rewrite_vector_multi_reduction.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/rewrite_vector_multi_reduction.mlir similarity index 53% rename from tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_rewrite_vector_multi_reduction.mlir rename to tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/rewrite_vector_multi_reduction.mlir index 5af902233a4..de73d162fe2 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_rewrite_vector_multi_reduction.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/rewrite_vector_multi_reduction.mlir @@ -1,19 +1,4 @@ -// Copyright 2022 The TensorFlow Runtime Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// RUN: tf-tfrt-opt %s --tf-jitrt-rewrite-vector-multi-reduction \ -// RUN: | FileCheck %s +// RUN: mlir-hlo-opt %s --rewrite-vector-multi-reduction | FileCheck %s // CHECK-LABEL: func @vector_row func.func @vector_row(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> { diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/simplify_dead_copy.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/simplify_dead_copy.mlir new file mode 100644 index 00000000000..0344e7b03d3 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/simplify_dead_copy.mlir @@ -0,0 +1,198 @@ +// RUN: mlir-hlo-opt %s --split-input-file --simplify-dead-copy | FileCheck %s + +func.func @target_is_alloc(%arg0: memref<8x8xf32>) -> memref<8x8xf32> { + %c4 = arith.constant 4 : index + %cst_0 = arith.constant 0.000000e+00 : f32 + %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<8x8xf32> + memref.copy %arg0, %alloc_4: memref<8x8xf32> to memref<8x8xf32> + return %arg0 : memref<8x8xf32> +} + +// CHECK-LABEL: func @target_is_alloc( +// CHECK-SAME: %[[INPUT:.*]]: memref<8x8xf32>) + +// CHECK-NOT: memref.copy +// CHECK: return %[[INPUT]] + +// ----- + +func.func @target_is_alloc_with_other_stores(%arg0: memref<8x8xf32>) + -> memref<8x8xf32> { + %c4 = arith.constant 4 : index + %cst_0 = arith.constant 0.000000e+00 : f32 + %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<8x8xf32> + memref.copy %arg0, %alloc_4: memref<8x8xf32> to memref<8x8xf32> + linalg.fill ins(%cst_0 : f32) outs(%alloc_4 : memref<8x8xf32>) + memref.store %cst_0, %alloc_4[%c4, %c4] : memref<8x8xf32> + return %arg0 : memref<8x8xf32> +} + +// CHECK-LABEL: func @target_is_alloc_with_other_stores( +// CHECK-SAME: %[[INPUT:.*]]: memref<8x8xf32>) + +// CHECK: memref.alloc +// CHECK-NOT: memref.copy +// CHECK: linalg.fill +// CHECK: memref.store +// CHECK: return %[[INPUT]] + +// ----- + +func.func @target_is_subview(%arg0: memref<8x8xf32>) -> memref<8x8xf32> { + %c4 = arith.constant 4 : index + %cst_0 = arith.constant 0.000000e+00 : f32 + %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<8x8xf32> + %subview_5 = memref.subview %alloc_4[0, 0] [%c4, %c4] [1, 1] : + memref<8x8xf32> to memref> + memref.copy %arg0, %subview_5 : + memref<8x8xf32> to memref> + return %arg0 : memref<8x8xf32> +} + +// CHECK-LABEL: func @target_is_subview( +// CHECK-SAME: %[[INPUT:.*]]: memref<8x8xf32>) + +// CHECK-NOT: memref.copy +// CHECK: return %[[INPUT]] + +// ----- + +func.func @target_is_subview_of_subview(%arg0: memref<8x8xf32>) + -> memref<8x8xf32> { + %c4 = arith.constant 4 : index + %cst_0 = arith.constant 0.000000e+00 : f32 + %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<8x8xf32> + %subview_5 = memref.subview %alloc_4[0, 0] [%c4, %c4] [1, 1] : + memref<8x8xf32> to memref> + %subview_6 = memref.subview %subview_5[0, 0] [%c4, %c4] [1, 1] : + memref> to memref> + memref.copy %arg0, %subview_6 : + memref<8x8xf32> to memref> + return %arg0 : memref<8x8xf32> +} + +// CHECK-LABEL: func @target_is_subview_of_subview( +// CHECK-SAME: %[[INPUT:.*]]: memref<8x8xf32>) + +// CHECK-NOT: memref.copy +// CHECK: return %[[INPUT]] + +// ----- + +func.func @do_not_simplify_subview_of_subview(%arg0: memref<8x8xf32>) + -> vector<8x8xf32> { + %c4 = arith.constant 4 : index + %c0 = arith.constant 0 : index + %cst_0 = arith.constant 0.000000e+00 : f32 + %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<8x8xf32> + %subview_5 = memref.subview %alloc_4[0, 0] [%c4, %c4] [1, 1] : + memref<8x8xf32> to memref> + 
%subview_6 = memref.subview %subview_5[0, 0] [%c4, %c4] [1, 1] : + memref> to memref> + memref.copy %arg0, %subview_6 : + memref<8x8xf32> to memref> + %27 = vector.transfer_read %subview_5[%c0, %c0], %cst_0 : + memref>, vector<8x8xf32> + return %27 : vector<8x8xf32> +} + +// CHECK-LABEL: func @do_not_simplify_subview_of_subview( + +// CHECK: memref.alloc +// CHECK: memref.subview +// CHECK: memref.subview +// CHECK: memref.copy + +// ----- + +func.func @do_not_simplify_subview(%arg0: memref<8x8xf32>) -> vector<8x8xf32> { + %c4 = arith.constant 4 : index + %c0 = arith.constant 0 : index + %cst_0 = arith.constant 0.000000e+00 : f32 + %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<8x8xf32> + %subview_5 = memref.subview %alloc_4[0, 0] [%c4, %c4] [1, 1] : + memref<8x8xf32> to memref> + memref.copy %arg0, %subview_5 : + memref<8x8xf32> to memref> + %27 = vector.transfer_read %subview_5[%c0, %c0], %cst_0 : + memref>, vector<8x8xf32> + return %27 : vector<8x8xf32> +} + +// CHECK-LABEL: func @do_not_simplify_subview( + +// CHECK: memref.alloc +// CHECK: memref.subview +// CHECK: memref.copy + +// ----- + +func.func @do_not_simplify_alloc(%arg0: memref<8x8xf32>) -> vector<8x8xf32> { + %c4 = arith.constant 4 : index + %c0 = arith.constant 0 : index + %cst_0 = arith.constant 0.000000e+00 : f32 + %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<8x8xf32> + memref.copy %arg0, %alloc_4 : memref<8x8xf32> to memref<8x8xf32> + %27 = vector.transfer_read %alloc_4[%c0, %c0], %cst_0 : + memref<8x8xf32>, vector<8x8xf32> + return %27 : vector<8x8xf32> +} + +// CHECK-LABEL: func @do_not_simplify_alloc( + +// CHECK: memref.alloc +// CHECK: memref.copy + +// ----- + +func.func @do_not_simplify_subview_with_other_use(%arg0: memref<8x8xf32>) + -> memref<8x8xf32> { + %c4 = arith.constant 4 : index + %cst_0 = arith.constant 0.000000e+00 : f32 + %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<8x8xf32> + %subview_5 = memref.subview %alloc_4[0, 0] [%c4, %c4] [1, 1] : + memref<8x8xf32> to memref> + %subview_6 = memref.subview %alloc_4[0, 0] [%c4, %c4] [1, 1] : + memref<8x8xf32> to memref> + memref.copy %arg0, %subview_6 : + memref<8x8xf32> to memref> + memref.copy %arg0, %subview_5 : + memref<8x8xf32> to memref> + return %arg0 : memref<8x8xf32> +} + + +// CHECK-LABEL: func @do_not_simplify_subview_with_other_use( + +// CHECK: memref.alloc +// CHECK: memref.subview +// CHECK: memref.subview +// CHECK: memref.copy +// CHECK: memref.copy + +// ----- + +func.func @target_is_alloc_with_loads_stores(%arg0: memref<8x8xf32>) + -> memref<8x8xf32> { + %c4 = arith.constant 4 : index + %cst_0 = arith.constant 0.000000e+00 : f32 + %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<8x8xf32> + memref.copy %arg0, %alloc_4: memref<8x8xf32> to memref<8x8xf32> + "lmhlo.custom_call"(%alloc_4, %alloc_4) ({ + }) { + backend_config = "", + call_target_name = "foo", + has_side_effect = false, + operand_segment_sizes = array + } : (memref<8x8xf32>, memref<8x8xf32>) -> () + + return %arg0 : memref<8x8xf32> +} + +// CHECK-LABEL: func @target_is_alloc_with_loads_stores( +// CHECK-SAME: %[[INPUT:.*]]: memref<8x8xf32>) + +// CHECK: memref.alloc +// CHECK: memref.copy +// CHECK: "lmhlo.custom_call" +// CHECK: return %[[INPUT]] diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/tiling.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/tiling.mlir index 24777b633fc..51db848f008 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/tiling.mlir +++ 
b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/tiling.mlir @@ -3,213 +3,16 @@ // RUN: --gml-tiling="tile-sizes=1,1 distribute=false op-label=tile-2d-point" \ // RUN: --gml-tiling="tile-sizes=1 distribute=false op-label=tile-1d-point" \ // RUN: --gml-tiling="tile-sizes=256,512 distribute=false op-label=tile-3d" \ +// RUN: --gml-tiling="tile-sizes=10 distribute=false op-label=tile-1d" \ +// RUN: --gml-tiling="tile-sizes=2,4 distribute=false op-label=tile-pad" \ // RUN: --cse | \ // RUN: FileCheck %s --check-prefix=CHECK-FOR // RUN: mlir-hlo-opt %s --split-input-file \ -// RUN: --gml-tiling="tile-sizes=256,512 distribute=true op-label=tile-2d" \ +// RUN: --gml-tiling="tile-sizes=256,512 distribute=true op-label=tile-3d" \ // RUN: --cse | \ // RUN: FileCheck %s --check-prefix=CHECK-PARALLEL -#id_map = affine_map<(d0, d1) -> (d0, d1)> - -func.func @add_static(%lhs: tensor<1024x1024xf32>, %rhs: tensor<1024x1024xf32>) - -> tensor<1024x1024xf32> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %init = tensor.empty() : tensor<1024x1024xf32> - %add = linalg.generic { - indexing_maps = [#id_map, #id_map, #id_map], - iterator_types = ["parallel", "parallel"], - op_label = "tile-2d"} - ins(%lhs, %rhs : tensor<1024x1024xf32>, tensor<1024x1024xf32>) - outs(%init : tensor<1024x1024xf32>) { - ^bb0(%lhs_scalar: f32, %rhs_scalar: f32, %_: f32): - %add_scalar = arith.addf %lhs_scalar, %rhs_scalar : f32 - linalg.yield %add_scalar : f32 - } -> tensor<1024x1024xf32> - func.return %add : tensor<1024x1024xf32> -} - -// CHECK-FOR-LABEL: @add_static -// CHECK-FOR-SAME: %[[ARG0:.*]]: tensor<1024x1024xf32>, %[[ARG1:.*]]: tensor<1024x1024xf32> - -// CHECK-FOR-DAG: %[[C0:.*]] = arith.constant 0 -// CHECK-FOR-DAG: %[[C256:.*]] = arith.constant 256 -// CHECK-FOR-DAG: %[[C512:.*]] = arith.constant 512 -// CHECK-FOR-DAG: %[[C1024:.*]] = arith.constant 1024 -// CHECK-FOR: %[[INIT:.*]] = tensor.empty() -// CHECK-FOR: %[[FOR:.*]] = gml_st.for (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) -// CHECK-FOR-SAME: to (%[[C1024]], %[[C1024]]) -// CHECK-FOR-SAME: step (%[[C256]], %[[C512]]) -// CHECK-FOR-SAME: outs (%[[ARG4:.*]] = %[[INIT]]: tensor<1024x1024xf32>) -// CHECK-FOR: %[[TILE:.*]] = gml_st.tile [%[[I]], %[[J]]] [256, 512] [1, 1] -// CHECK-FOR: %[[MATERIALIZE:.*]] = gml_st.materialize %[[ARG0]][%[[TILE]]] -// CHECK-FOR: %[[MATERIALIZE_0:.*]] = gml_st.materialize %[[ARG1]][%[[TILE]]] -// CHECK-FOR: %[[MATERIALIZE_1:.*]] = gml_st.materialize %[[ARG4]][%[[TILE]]] -// CHECK-FOR: %[[GENERIC:.*]] = linalg.generic -// CHECK-FOR-SAME: iterator_types = ["parallel", "parallel"] -// CHECK-FOR-SAME: ins(%[[MATERIALIZE]], %[[MATERIALIZE_0]] : tensor<256x512xf32>, tensor<256x512xf32>) -// CHECK-FOR-SAME: outs(%[[MATERIALIZE_1]] : tensor<256x512xf32>) -// CHECK-FOR-SAME: attrs = {op_label = "tile-2d"} -// CHECK-FOR: ^bb0(%[[ARG5:.*]]: f32, %[[ARG6:.*]]: f32, %[[ARG7:.*]]: f32): -// CHECK-FOR: %[[ADDF:.*]] = arith.addf %[[ARG5]], %[[ARG6]] -// CHECK-FOR: linalg.yield %[[ADDF]] -// CHECK-FOR: gml_st.set_yield %[[GENERIC]] into %[[ARG4]][%[[TILE]]] -// CHECK-FOR: return %[[FOR]] - -// ----- - -#id_map = affine_map<(d0, d1) -> (d0, d1)> - -func.func @add(%lhs: tensor, %rhs: tensor) - -> tensor { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %d0 = tensor.dim %lhs, %c0 : tensor - %d1 = tensor.dim %lhs, %c1 : tensor - %init = tensor.empty(%d0, %d1) : tensor - %add = linalg.generic { - indexing_maps = [#id_map, #id_map, #id_map], - iterator_types = ["parallel", "parallel"], - op_label = "tile-2d"} - 
ins(%lhs, %rhs : tensor, tensor) - outs(%init : tensor) { - ^bb0(%lhs_scalar: f32, %rhs_scalar: f32, %_: f32): - %add_scalar = arith.addf %lhs_scalar, %rhs_scalar : f32 - linalg.yield %add_scalar : f32 - } -> tensor - func.return %add : tensor -} - - -// CHECK-FOR-LABEL: @add( -// CHECK-FOR-SAME: %[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor - -// CHECK-FOR: %[[C0:.*]] = arith.constant 0 -// CHECK-FOR: %[[C1:.*]] = arith.constant 1 -// CHECK-FOR: %[[C256:.*]] = arith.constant 256 -// CHECK-FOR: %[[C512:.*]] = arith.constant 512 -// CHECK-FOR: %[[LHS_DIM_0:.*]] = tensor.dim %[[ARG0]], %[[C0]] -// CHECK-FOR: %[[LHS_DIM_1:.*]] = tensor.dim %[[ARG0]], %[[C1]] -// CHECK-FOR: %[[INIT:.*]] = tensor.empty(%[[LHS_DIM_0]], %[[LHS_DIM_1]]) -// CHECK-FOR: %[[FOR:.*]] = gml_st.for (%[[ARG2:.*]], %[[ARG3:.*]]) = (%[[C0]], %[[C0]]) -// CHECK-FOR-SAME: to (%[[LHS_DIM_0]], %[[LHS_DIM_1]]) -// CHECK-FOR-SAME: step (%[[C256]], %[[C512]]) -// CHECK-FOR-SAME: outs (%[[OUT:.*]] = %[[INIT]]: tensor) -// CHECK-FOR: %[[MIN:.*]] = affine.min #map{{[0-9]*}}(%[[ARG2]])[%[[LHS_DIM_0]]] -// CHECK-FOR: %[[MIN_0:.*]] = affine.min #map{{[0-9]*}}(%[[ARG3]])[%[[LHS_DIM_1]]] -// CHECK-FOR: %[[TILE:.*]] = gml_st.tile [%[[ARG2]], %[[ARG3]]] [%[[MIN]], %[[MIN_0]]] [1, 1] -// CHECK-FOR: %[[MATERIALIZE:.*]] = gml_st.materialize %[[ARG0]][%[[TILE]]] -// CHECK-FOR: %[[MATERIALIZE_0:.*]] = gml_st.materialize %[[ARG1]][%[[TILE]]] -// CHECK-FOR: %[[MATERIALIZE_1:.*]] = gml_st.materialize %[[OUT]][%[[TILE]]] -// CHECK-FOR: %[[GENERIC:.*]] = linalg.generic -// CHECK-FOR-SAME: iterator_types = ["parallel", "parallel"] -// CHECK-FOR-SAME: ins(%[[MATERIALIZE]], %[[MATERIALIZE_0]] : tensor, tensor) -// CHECK-FOR-SAME: outs(%[[MATERIALIZE_1]] : tensor) -// CHECK-FOR-SAME: attrs = {op_label = "tile-2d"} -// CHECK-FOR: ^bb0(%[[ARG5:.*]]: f32, %[[ARG6:.*]]: f32, %[[ARG7:.*]]: f32): -// CHECK-FOR: %[[ADDF:.*]] = arith.addf %[[ARG5]], %[[ARG6]] -// CHECK-FOR: linalg.yield %[[ADDF]] -// CHECK-FOR: gml_st.set_yield %[[GENERIC]] into %[[OUT]][%[[TILE]]] -// CHECK-FOR: return %[[FOR]] - - -// CHECK-PARALLEL-LABEL: @add( -// CHECK-PARALLEL-SAME: %[[LHS:.*]]: tensor, %[[RHS:.*]]: tensor - -// CHECK-PARALLEL: %[[C0:.*]] = arith.constant 0 -// CHECK-PARALLEL: %[[C1:.*]] = arith.constant 1 -// CHECK-PARALLEL: %[[C256:.*]] = arith.constant 256 -// CHECK-PARALLEL: %[[C512:.*]] = arith.constant 512 -// CHECK-PARALLEL: %[[LHS_DIM_0:.*]] = tensor.dim %[[LHS]], %[[C0]] -// CHECK-PARALLEL: %[[LHS_DIM_1:.*]] = tensor.dim %[[LHS]], %[[C1]] -// CHECK-PARALLEL: %[[INIT:.*]] = tensor.empty(%[[LHS_DIM_0]], %[[LHS_DIM_1]]) -// CHECK-PARALLEL: %[[PARALLEL:.*]] = gml_st.parallel (%[[ARG2:.*]], %[[ARG3:.*]]) = (%[[C0]], %[[C0]]) -// CHECK-PARALLEL-SAME: to (%[[LHS_DIM_0]], %[[LHS_DIM_1]]) -// CHECK-PARALLEL-SAME: step (%[[C256]], %[[C512]]) -// CHECK-PARALLEL: %[[MIN:.*]] = affine.min #map{{[0-9]*}}(%[[ARG2]])[%[[LHS_DIM_0]]] -// CHECK-PARALLEL: %[[MIN_0:.*]] = affine.min #map{{[0-9]*}}(%[[ARG3]])[%[[LHS_DIM_1]]] -// CHECK-PARALLEL: %[[TILE:.*]] = gml_st.tile [%[[ARG2]], %[[ARG3]]] [%[[MIN]], %[[MIN_0]]] [1, 1] -// CHECK-PARALLEL: %[[MATERIALIZE:.*]] = gml_st.materialize %[[LHS]][%[[TILE]]] -// CHECK-PARALLEL: %[[MATERIALIZE_0:.*]] = gml_st.materialize %[[RHS]][%[[TILE]]] -// CHECK-PARALLEL: %[[MATERIALIZE_1:.*]] = gml_st.materialize %[[INIT]][%[[TILE]]] -// CHECK-PARALLEL: %[[GENERIC:.*]] = linalg.generic -// CHECK-PARALLEL-SAME: iterator_types = ["parallel", "parallel"] -// CHECK-PARALLEL-SAME: ins(%[[MATERIALIZE]], %[[MATERIALIZE_0]] : tensor, tensor) -// 
CHECK-PARALLEL-SAME: outs(%[[MATERIALIZE_1]] : tensor) -// CHECK-PARALLEL-SAME: attrs = {op_label = "tile-2d"} -// CHECK-PARALLEL: ^bb0(%[[OUT:.*]]: f32, %[[ARG5:.*]]: f32, %[[ARG6:.*]]: f32): -// CHECK-PARALLEL: %[[ADDF:.*]] = arith.addf %[[OUT]], %[[ARG5]] -// CHECK-PARALLEL: linalg.yield %[[ADDF]] -// CHECK-PARALLEL: gml_st.set_yield %[[GENERIC]] into %[[INIT]][%[[TILE]]] -// CHECK-PARALLEL: return %[[PARALLEL]] - -// ----- - -func.func @reduce_row(%lhs: tensor, - %rhs: tensor) -> tensor { - %cst = arith.constant 0.000000e+00 : f32 - %c0 = arith.constant 0 : index - %0 = tensor.dim %lhs, %c0 : tensor - - %init = tensor.empty(%0) : tensor - %fill = linalg.fill ins(%cst : f32) - outs(%init : tensor) -> tensor - %sum_of_prod = linalg.generic { - indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0)>], - iterator_types = ["parallel", "reduction"], - op_label = "tile-2d"} - ins(%lhs, %rhs : tensor, tensor) - outs(%fill : tensor) { - ^bb0(%l: f32, %r: f32, %o: f32): - %prod = arith.mulf %l, %r : f32 - %add = arith.addf %prod, %o : f32 - linalg.yield %add : f32 - } -> tensor - func.return %sum_of_prod : tensor -} - - -// CHECK-FOR-LABEL: @reduce_row -// CHECK-FOR-SAME: %[[LHS:.*]]: tensor, %[[RHS:.*]]: tensor - -// CHECK-FOR-DAG: %[[C0_0:.*]] = arith.constant 0 -// CHECK-FOR-DAG: %[[C1_0:.*]] = arith.constant 1 -// CHECK-FOR-DAG: %[[LHS_DIM_0:.*]] = tensor.dim %[[LHS]], %[[C0_0]] -// CHECK-FOR-DAG: %[[LHS_DIM_1:.*]] = tensor.dim %[[LHS]], %[[C1_0]] -// CHECK-FOR-DAG: %[[C256_0:.*]] = arith.constant 256 -// CHECK-FOR-DAG: %[[C512_0:.*]] = arith.constant 512 -// CHECK-FOR-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 -// CHECK-FOR-DAG: %[[INIT_0:.*]] = tensor.empty(%[[LHS_DIM_0]]) -// CHECK-FOR-DAG: %[[FILL:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT_0]] : tensor) -// CHECK-FOR: %[[FOR_0:.*]] = gml_st.for (%[[ARG2_0:.*]], %[[ARG3_0:.*]]) = (%[[C0_0]], %[[C0_0]]) -// CHECK-FOR-SAME: to (%[[LHS_DIM_0]], %[[LHS_DIM_1]]) -// CHECK-FOR-SAME: step (%[[C256_0]], %[[C512_0]]) -// CHECK-FOR-SAME: outs (%[[OUT_0:.*]] = %[[FILL]]: tensor) -// CHECK-FOR: %[[MIN_1:.*]] = affine.min #map{{[0-9]*}}(%[[ARG2_0]])[%[[LHS_DIM_0]]] -// CHECK-FOR: %[[MIN_2:.*]] = affine.min #map{{[0-9]*}}(%[[ARG3_0]])[%[[LHS_DIM_1]]] -// CHECK-FOR: %[[TILE_2:.*]] = gml_st.tile [%[[ARG2_0]], %[[ARG3_0]]] [%[[MIN_1]], %[[MIN_2]]] [1, 1] -// CHECK-FOR: %[[MATERIALIZE_2:.*]] = gml_st.materialize %[[LHS]][%[[TILE_2]]] -// CHECK-FOR: %[[MATERIALIZE_3:.*]] = gml_st.materialize %[[RHS]][%[[TILE_2]]] -// CHECK-FOR: %[[TILE_4:.*]] = gml_st.tile [%[[ARG2_0]]] [%[[MIN_1]]] [1] -// CHECK-FOR: %[[MATERIALIZE_4:.*]] = gml_st.materialize %[[OUT_0]][%[[TILE_4]]] -// CHECK-FOR: %[[GENERIC_0:.*]] = linalg.generic -// CHECK-FOR-SAME: iterator_types = ["parallel", "reduction"]} -// CHECK-FOR-SAME: ins(%[[MATERIALIZE_2]], %[[MATERIALIZE_3]] : tensor, tensor) -// CHECK-FOR-SAME: outs(%[[MATERIALIZE_4]] : tensor) -// CHECK-FOR-SAME: attrs = {op_label = "tile-2d"} -// CHECK-FOR: ^bb0(%[[ARG5_0:.*]]: f32, %[[ARG6_0:.*]]: f32, %[[ARG7_0:.*]]: f32): -// CHECK-FOR: %[[MULF:.*]] = arith.mulf %[[ARG5_0]], %[[ARG6_0]] -// CHECK-FOR: %[[ADDF_0:.*]] = arith.addf %[[MULF]], %[[ARG7_0]] -// CHECK-FOR: linalg.yield %[[ADDF_0]] -// CHECK-FOR: gml_st.set_yield %[[GENERIC_0]] into %[[OUT_0]][%[[TILE_4]]] -// CHECK-FOR: return %[[FOR_0]] - -// ----- - func.func @dynamic_broadcast_in_dim_at_tile(%init : tensor, %arg : tensor) -> tensor { %bcast = thlo.dynamic_broadcast_in_dim ins(%arg: tensor) @@ 
-236,27 +39,61 @@ func.func @dynamic_broadcast_in_dim_at_tile(%init : tensor, // CHECK-FOR-SAME: outs (%[[OUT:.*]] = %[[INIT]]: tensor) // CHECK-FOR: %[[MIN:.*]] = affine.min #map{{[0-9]*}}(%[[I]])[%[[INIT_DIM_0]]] // CHECK-FOR: %[[MIN_0:.*]] = affine.min #map{{[0-9]*}}(%[[J]])[%[[INIT_DIM_1]]] -// CHECK-FOR: %[[TILE:.*]] = gml_st.tile [%[[I]], %[[J]], %[[C0]]] [%[[MIN]], %[[MIN_0]], %[[INIT_DIM_2]]] [1, 1, 1] // CHECK-FOR: %[[ARG_DIM_0:.*]] = tensor.dim %[[ARG]], %[[C0]] // CHECK-FOR: %[[ARG_DIM_1:.*]] = tensor.dim %[[ARG]], %[[C1]] -// CHECK-FOR: %[[OUT_DIM_0:.*]] = tensor.dim %[[OUT]], %[[C0]] -// CHECK-FOR: %[[CMPI:.*]] = arith.cmpi ne, %[[ARG_DIM_0]], %[[OUT_DIM_0]] -// CHECK-FOR: %[[OUT_DIM_2:.*]] = tensor.dim %[[OUT]], %[[C2]] -// CHECK-FOR: %[[CMPI_0:.*]] = arith.cmpi ne, %[[ARG_DIM_1]], %[[OUT_DIM_2]] +// CHECK-FOR: %[[CMPI:.*]] = arith.cmpi ne, %[[ARG_DIM_0]], %[[INIT_DIM_0]] +// CHECK-FOR: %[[CMPI_0:.*]] = arith.cmpi ne, %[[ARG_DIM_1]], %[[INIT_DIM_2]] // CHECK-FOR: %[[SELECT:.*]] = arith.select %[[CMPI]], %[[C0]], %[[I]] // CHECK-FOR: %[[SELECT_0:.*]] = arith.select %[[CMPI]], %[[C1]], %[[MIN]] // CHECK-FOR: %[[SELECT_1:.*]] = arith.select %[[CMPI_0]], %[[C1]], %[[INIT_DIM_2]] -// CHECK-FOR: %[[TILE_0:.*]] = gml_st.tile [%[[SELECT]], %[[C0]]] [%[[SELECT_0]], %[[SELECT_1]]] [1, 1] -// CHECK-FOR: %[[MATERIALIZE:.*]] = gml_st.materialize %[[OUT]][%[[TILE]]] -// CHECK-FOR: %[[MATERIALIZE_0:.*]] = gml_st.materialize %[[ARG]][%[[TILE_0]]] +// CHECK-FOR: %[[MATERIALIZE:.*]] = tensor.extract_slice %[[OUT]] +// CHECK-FOR-SAME: [%[[I]], %[[J]], %[[C0]]] [%[[MIN]], %[[MIN_0]], %[[INIT_DIM_2]]] [1, 1, 1] +// CHECK-FOR: %[[MATERIALIZE_0:.*]] = tensor.extract_slice %[[ARG]] +// CHECK-FOR-SAME: [%[[SELECT]], %[[C0]]] [%[[SELECT_0]], %[[SELECT_1]]] [1, 1] // CHECK-FOR: %[[DYNAMIC:.*]] = thlo.dynamic_broadcast_in_dim // CHECK-FOR-SAME: ins(%[[MATERIALIZE_0]] // CHECK-FOR-SAME: outs(%[[MATERIALIZE]] // CHECK-FOR-SAME: broadcast_dimensions = [0, 2] +// CHECK-FOR: %[[TILE:.*]] = gml_st.tile [%[[I]], %[[J]], %[[C0]]] [%[[MIN]], %[[MIN_0]], %[[INIT_DIM_2]]] [1, 1, 1] // CHECK-FOR: gml_st.set_yield %[[DYNAMIC]] into %[[OUT]][%[[TILE]]] // CHECK-FOR: return %[[FOR]] // CHECK-PARALLEL-LABEL: @dynamic_broadcast_in_dim_at_tile +// CHECK-PARALLEL-SAME: %[[INIT:.*]]: tensor, %[[ARG:.*]]: tensor + +// CHECK-PARALLEL: %[[C0:.*]] = arith.constant 0 +// CHECK-PARALLEL: %[[C1:.*]] = arith.constant 1 +// CHECK-PARALLEL: %[[C2:.*]] = arith.constant 2 +// CHECK-PARALLEL: %[[C256:.*]] = arith.constant 256 +// CHECK-PARALLEL: %[[C512:.*]] = arith.constant 512 +// CHECK-PARALLEL: %[[INIT_DIM_0:.*]] = tensor.dim %[[INIT]], %[[C0]] +// CHECK-PARALLEL: %[[INIT_DIM_1:.*]] = tensor.dim %[[INIT]], %[[C1]] +// CHECK-PARALLEL: %[[INIT_DIM_2:.*]] = tensor.dim %[[INIT]], %[[C2]] +// CHECK-PARALLEL: %[[PARALLEL:.*]] = gml_st.parallel (%[[I:.*]], %[[J:.*]]) = +// CHECK-PARALLEL-SAME: (%[[C0]], %[[C0]]) +// CHECK-PARALLEL-SAME: to (%[[INIT_DIM_0]], %[[INIT_DIM_1]]) +// CHECK-PARALLEL-SAME: step (%[[C256]], %[[C512]]) +// CHECK-PARALLEL-SAME: outs (%[[OUT:.*]] = %[[INIT]]: tensor) +// CHECK-PARALLEL: %[[MIN:.*]] = affine.min #map{{[0-9]*}}(%[[I]])[%[[INIT_DIM_0]]] +// CHECK-PARALLEL: %[[MIN_0:.*]] = affine.min #map{{[0-9]*}}(%[[J]])[%[[INIT_DIM_1]]] +// CHECK-PARALLEL: %[[ARG_DIM_0:.*]] = tensor.dim %[[ARG]], %[[C0]] +// CHECK-PARALLEL: %[[ARG_DIM_1:.*]] = tensor.dim %[[ARG]], %[[C1]] +// CHECK-PARALLEL: %[[CMPI:.*]] = arith.cmpi ne, %[[ARG_DIM_0]], %[[INIT_DIM_0]] +// CHECK-PARALLEL: %[[CMPI_0:.*]] = arith.cmpi ne, 
%[[ARG_DIM_1]], %[[INIT_DIM_2]] +// CHECK-PARALLEL: %[[SELECT:.*]] = arith.select %[[CMPI]], %[[C0]], %[[I]] +// CHECK-PARALLEL: %[[SELECT_0:.*]] = arith.select %[[CMPI]], %[[C1]], %[[MIN]] +// CHECK-PARALLEL: %[[SELECT_1:.*]] = arith.select %[[CMPI_0]], %[[C1]], %[[INIT_DIM_2]] +// CHECK-PARALLEL: %[[MATERIALIZE:.*]] = tensor.extract_slice %[[OUT]] +// CHECK-PARALLEL-SAME: [%[[I]], %[[J]], %[[C0]]] [%[[MIN]], %[[MIN_0]], %[[INIT_DIM_2]]] [1, 1, 1] +// CHECK-PARALLEL: %[[MATERIALIZE_0:.*]] = tensor.extract_slice %[[ARG]] +// CHECK-PARALLEL-SAME: [%[[SELECT]], %[[C0]]] [%[[SELECT_0]], %[[SELECT_1]]] [1, 1] +// CHECK-PARALLEL: %[[DYNAMIC:.*]] = thlo.dynamic_broadcast_in_dim +// CHECK-PARALLEL-SAME: ins(%[[MATERIALIZE_0]] +// CHECK-PARALLEL-SAME: outs(%[[MATERIALIZE]] +// CHECK-PARALLEL-SAME: broadcast_dimensions = [0, 2] +// CHECK-PARALLEL: %[[TILE:.*]] = gml_st.tile [%[[I]], %[[J]], %[[C0]]] [%[[MIN]], %[[MIN_0]], %[[INIT_DIM_2]]] [1, 1, 1] +// CHECK-PARALLEL: gml_st.set_yield %[[DYNAMIC]] into %[[OUT]][%[[TILE]]] +// CHECK-PARALLEL: return %[[PARALLEL]] // ----- @@ -283,13 +120,14 @@ func.func @scatter_i64(%indices: tensor, // CHECK-FOR: %[[INDICES_COUNT:.*]] = tensor.dim %[[INDICES]], %c0 // CHECK-FOR: gml_st.for (%{{.*}}) = (%[[C0]]) to (%[[INDICES_COUNT]]) +// CHECK-FOR-SAME: outs (%[[INIT_:.*]] = %[[INIT]]: tensor) { -// CHECK-FOR: %[[UPDATE_SUB:.*]] = gml_st.materialize %[[UPDATES]] -// CHECK-FOR-SAME: : tensor[!gml_st.tile<1x?x?>] -// CHECK-FOR: %[[INDICES_SUB:.*]] = gml_st.materialize %[[INDICES]] -// CHECK-FOR-SAME: : tensor[!gml_st.tile<1x2>] -// CHECK-FOR: %[[INIT_SUB:.*]] = gml_st.materialize -// CHECK-FOR-SAME: : tensor[!gml_st.tile] +// CHECK-FOR: %[[UPDATE_SUB:.*]] = tensor.extract_slice %[[UPDATES]] +// CHECK-FOR-SAME: : tensor +// CHECK-FOR: %[[INDICES_SUB:.*]] = tensor.extract_slice %[[INDICES]] +// CHECK-FOR-SAME: : tensor +// CHECK-FOR: %[[INIT_SUB:.*]] = tensor.extract_slice %[[INIT_]] +// CHECK-FOR-SAME: : tensor // CHECK-FOR: %[[SCATTER:.*]] = thlo.scatter // CHECK-FOR-SAME: ins(%[[INDICES_SUB]] : tensor<1x2xindex>, @@ -318,11 +156,11 @@ func.func @gather(%operand: tensor, %indices: tensor, // CHECK-FOR: %[[RESULT:.*]] = gml_st.for (%[[I:.*]]) = // CHECK-FOR-SAME: (%[[INIT_:[a-z0-9]+]] = %[[INIT]]: tensor) -// CHECK-FOR: %[[INDEX_TILE:.*]] = gml_st.tile [%[[I]], 0] [1, 4] [1, 1] -// CHECK-FOR: %[[INDEX_SLICE:.*]] = gml_st.materialize %[[INDICES]][%[[INDEX_TILE]]] +// CHECK-FOR: %[[INDEX_SLICE:.*]] = tensor.extract_slice %[[INDICES]] +// CHECK-FOR-SAME: [%[[I]], 0] [1, 4] [1, 1] -// CHECK-FOR: %[[INIT_TILE:.*]] = gml_st.tile [%[[I]], 0] [1, 10] [1, 1] -// CHECK-FOR: %[[INIT_SLICE:.*]] = gml_st.materialize %[[INIT_]][%[[INIT_TILE]]] +// CHECK-FOR: %[[INIT_SLICE:.*]] = tensor.extract_slice %[[INIT_]] +// CHECK-FOR-SAME: [%[[I]], 0] [1, 10] [1, 1] // CHECK-FOR: %[[GATHER_SLICE:.*]] = thlo.gather // CHECK-FOR-SAME: ins(%[[OPERAND]] : tensor, // CHECK-FOR-SAME: %[[INDEX_SLICE]] : tensor<1x4xindex>) @@ -336,9 +174,9 @@ func.func @concatenate_at_tile(%init : tensor, %a: tensor, -> tensor { %concat = thlo.concatenate ins(%a : tensor, %b : tensor, %c : tensor) - outs(%init : tensor) { - dimension = 1 : i64, - op_label = "tile-2d" } + outs(%init : tensor) + dimension = 1 + { op_label = "tile-2d" } func.return %concat : tensor } @@ -358,13 +196,12 @@ func.func @concatenate_at_tile(%init : tensor, %a: tensor, // CHECK-FOR-SAME: outs (%[[ARG6:.*]] = %[[ARG0]]: tensor) // CHECK-FOR: %[[MIN:.*]] = affine.min #map{{[0-9]*}}(%[[ARG4]])[%[[DIM]]] // CHECK-FOR: %[[MIN_0:.*]] = affine.min 
#map{{[0-9]*}}(%[[ARG5]])[%[[DIM_0]]] -// CHECK-FOR: %[[TILE:.*]] = gml_st.tile [%[[ARG4]], %[[ARG5]]] [%[[MIN]], %[[MIN_0]]] [1, 1] // CHECK-FOR: %[[DIM_4:.*]] = tensor.dim %[[ARG1]], %[[C1]] // CHECK-FOR: %[[MINUI:.*]] = arith.minui %[[ARG5]], %[[DIM_4]] // CHECK-FOR: %[[SUBI:.*]] = arith.subi %[[DIM_4]], %[[MINUI]] // CHECK-FOR: %[[MINUI_0:.*]] = arith.minui %[[SUBI]], %[[MIN_0]] -// CHECK-FOR: %[[TILE_0:.*]] = gml_st.tile [%[[ARG4]], %[[MINUI]]] [%[[MIN]], %[[MINUI_0]]] [%[[C1]], %[[C1]]] -// CHECK-FOR: %[[MATERIALIZE:.*]] = gml_st.materialize %[[ARG1]][%[[TILE_0]]] +// CHECK-FOR: %[[MATERIALIZE:.*]] = tensor.extract_slice %[[ARG1]] +// CHECK-FOR-SAME: [%[[ARG4]], %[[MINUI]]] [%[[MIN]], %[[MINUI_0]]] [1, 1] // CHECK-FOR: %[[CMPI:.*]] = arith.cmpi ule, %[[ARG5]], %[[DIM_4]] // CHECK-FOR: %[[SUBI_0:.*]] = arith.subi %[[ARG5]], %[[DIM_4]] // CHECK-FOR: %[[SELECT:.*]] = arith.select %[[CMPI]], %[[C0]], %[[SUBI_0]] @@ -372,8 +209,8 @@ func.func @concatenate_at_tile(%init : tensor, %a: tensor, // CHECK-FOR: %[[MINUI_1:.*]] = arith.minui %[[SELECT]], %[[DIM_5]] // CHECK-FOR: %[[SUBI_1:.*]] = arith.subi %[[DIM_5]], %[[MINUI_1]] // CHECK-FOR: %[[MINUI_2:.*]] = arith.minui %[[SUBI_1]], %[[MIN_0]] -// CHECK-FOR: %[[TILE_1:.*]] = gml_st.tile [%[[ARG4]], %[[MINUI_1]]] [%[[MIN]], %[[MINUI_2]]] [%[[C1]], %[[C1]]] -// CHECK-FOR: %[[MATERIALIZE_0:.*]] = gml_st.materialize %[[ARG2]][%[[TILE_1]]] +// CHECK-FOR: %[[MATERIALIZE_0:.*]] = tensor.extract_slice %[[ARG2]] +// CHECK-FOR-SAME: [%[[ARG4]], %[[MINUI_1]]] [%[[MIN]], %[[MINUI_2]]] [1, 1] // CHECK-FOR: %[[CMPI_0:.*]] = arith.cmpi ule, %[[SELECT]], %[[DIM_5]] // CHECK-FOR: %[[SUBI_2:.*]] = arith.subi %[[SELECT]], %[[DIM_5]] // CHECK-FOR: %[[SELECT_0:.*]] = arith.select %[[CMPI_0]], %[[C0]], %[[SUBI_2]] @@ -381,13 +218,15 @@ func.func @concatenate_at_tile(%init : tensor, %a: tensor, // CHECK-FOR: %[[MINUI_3:.*]] = arith.minui %[[SELECT_0]], %[[DIM_6]] // CHECK-FOR: %[[SUBI_3:.*]] = arith.subi %[[DIM_6]], %[[MINUI_3]] // CHECK-FOR: %[[MINUI_4:.*]] = arith.minui %[[SUBI_3]], %[[MIN_0]] -// CHECK-FOR: %[[TILE_2:.*]] = gml_st.tile [%[[ARG4]], %[[MINUI_3]]] [%[[MIN]], %[[MINUI_4]]] [%[[C1]], %[[C1]]] -// CHECK-FOR: %[[MATERIALIZE_1:.*]] = gml_st.materialize %[[ARG3]][%[[TILE_2]]] -// CHECK-FOR: %[[MATERIALIZE_2:.*]] = gml_st.materialize %[[ARG6]][%[[TILE]]] +// CHECK-FOR: %[[MATERIALIZE_1:.*]] = tensor.extract_slice %[[ARG3]] +// CHECK-FOR-SAME: [%[[ARG4]], %[[MINUI_3]]] [%[[MIN]], %[[MINUI_4]]] [1, 1] +// CHECK-FOR: %[[MATERIALIZE_2:.*]] = tensor.extract_slice %[[ARG6]] +// CHECK-FOR: [%[[ARG4]], %[[ARG5]]] [%[[MIN]], %[[MIN_0]]] [1, 1] // CHECK-FOR: %[[CONCATENATE:.*]] = thlo.concatenate // CHECK-FOR-SAME: ins(%[[MATERIALIZE]] : tensor, %[[MATERIALIZE_0]] : tensor, %[[MATERIALIZE_1]] : tensor) // CHECK-FOR-SAME: outs(%[[MATERIALIZE_2]] : tensor) -// CHECK-FOR-SAME: {dimension = 1 : i64} +// CHECK-FOR-SAME: dimension = 1 +// CHECK-FOR: %[[TILE:.*]] = gml_st.tile [%[[ARG4]], %[[ARG5]]] [%[[MIN]], %[[MIN_0]]] [1, 1] // CHECK-FOR: gml_st.set_yield %[[CONCATENATE]] into %[[ARG6]][%[[TILE]]] // CHECK-FOR: return %[[FOR]] @@ -401,7 +240,9 @@ func.func @sort(%input1: tensor, %input2: tensor, %sorted1, %sorted2 = thlo.sort ins(%input1: tensor, %input2: tensor) outs(%init1: tensor, %init2: tensor) - { dimension = 1 : i64, is_stable = true, op_label = "tile-3d" } + dimension = 1 + is_stable = true + {op_label = "tile-3d" } (%e11: f32, %e12: f32, %e21: i32, %e22: i32) { %gt = arith.cmpf ogt, %e11, %e12: f32 thlo.yield %gt : i1 @@ -426,20 +267,32 @@ func.func 
@sort(%input1: tensor, %input2: tensor, // CHECK-FOR-DAG: %[[TILE_SIZE0:.*]] = affine.min #map{{[0-9]*}}(%[[START0]])[%[[DIM0]]] // CHECK-FOR-DAG: %[[TILE_SIZE2:.*]] = affine.min #map{{[0-9]*}}(%[[START2]])[%[[DIM2]]] // CHECK-FOR-DAG: %[[DIM1:.*]] = tensor.dim %[[IN0]], %[[C1]] -// CHECK-FOR: %[[TILE:.*]] = gml_st.tile +// CHECK-FOR-DAG: %[[IN0_SUB:.*]] = tensor.extract_slice %[[IN0]] +// CHECK-FOR-SAME: [%[[START0]], 0, %[[START2]]] +// CHECK-FOR-SAME: [%[[TILE_SIZE0]], %[[DIM1]], %[[TILE_SIZE2]]] +// CHECK-FOR-SAME: [1, 1, 1] +// CHECK-FOR-DAG: %[[IN1_SUB:.*]] = tensor.extract_slice %[[IN1]] +// CHECK-FOR-SAME: [%[[START0]], 0, %[[START2]]] +// CHECK-FOR-SAME: [%[[TILE_SIZE0]], %[[DIM1]], %[[TILE_SIZE2]]] +// CHECK-FOR-SAME: [1, 1, 1] +// CHECK-FOR-DAG: %[[INIT0_SUB:.*]] = tensor.extract_slice %[[INIT0_]] +// CHECK-FOR-SAME: [%[[START0]], 0, %[[START2]]] +// CHECK-FOR-SAME: [%[[TILE_SIZE0]], %[[DIM1]], %[[TILE_SIZE2]]] +// CHECK-FOR-SAME: [1, 1, 1] +// CHECK-FOR-DAG: %[[INIT1_SUB:.*]] = tensor.extract_slice %[[INIT1_]] // CHECK-FOR-SAME: [%[[START0]], 0, %[[START2]]] // CHECK-FOR-SAME: [%[[TILE_SIZE0]], %[[DIM1]], %[[TILE_SIZE2]]] // CHECK-FOR-SAME: [1, 1, 1] -// CHECK-FOR-DAG: %[[IN0_SUB:.*]] = gml_st.materialize %[[IN0]][%[[TILE]]] -// CHECK-FOR-DAG: %[[IN1_SUB:.*]] = gml_st.materialize %[[IN1]][%[[TILE]]] -// CHECK-FOR-DAG: %[[INIT0_SUB:.*]] = gml_st.materialize %[[INIT0_]][%[[TILE]]] -// CHECK-FOR-DAG: %[[INIT1_SUB:.*]] = gml_st.materialize %[[INIT1_]][%[[TILE]]] // CHECK-FOR: thlo.sort // CHECK-FOR-SAME: ins(%[[IN0_SUB]] : tensor, %[[IN1_SUB]] : tensor) // CHECK-FOR-SAME: outs(%[[INIT0_SUB]] : tensor, %[[INIT1_SUB]] : tensor) +// CHECK-FOR: %[[TILE:.*]] = gml_st.tile +// CHECK-FOR-SAME: [%[[START0]], 0, %[[START2]]] +// CHECK-FOR-SAME: [%[[TILE_SIZE0]], %[[DIM1]], %[[TILE_SIZE2]]] +// CHECK-FOR-SAME: [1, 1, 1] // CHECK-FOR: gml_st.set_yield -// CHECK-FOR-SAME: %[[RESULT_TILE:.*]]#0 into %[[INIT0_]][%[[TILE]]] -// CHECK-FOR: %[[RESULT_TILE]]#1 into %[[INIT1_]][%[[TILE]]] +// CHECK-FOR-SAME: %[[RESULT_TILE:.*]]0 into %[[INIT0_]][%[[TILE]]] +// CHECK-FOR: %[[RESULT_TILE]]1 into %[[INIT1_]][%[[TILE]]] // ----- @@ -453,7 +306,9 @@ func.func @sort2(%input1: tensor<1024x2048x4096xf32>, %input2: tensor<1024x2048x4096xi32>) outs(%init1: tensor<1024x2048x4096xf32>, %init2: tensor<1024x2048x4096xi32>) - { dimension = 1 : i64, is_stable = true, op_label = "tile-3d" } + dimension = 1 + is_stable = true + { op_label = "tile-3d" } (%e11: f32, %e12: f32, %e21: i32, %e22: i32) { %gt = arith.cmpf ogt, %e11, %e12: f32 thlo.yield %gt : i1 @@ -463,3 +318,70 @@ func.func @sort2(%input1: tensor<1024x2048x4096xf32>, } // CHECK-FOR-LABEL: func.func @sort2 + +// ----- + +func.func @reverse_static(%input: tensor<100xf32>, %init: tensor<100xf32>) + -> tensor<100xf32> { + %res = thlo.reverse + ins(%input: tensor<100xf32>) + outs(%init: tensor<100xf32>) + reverse_dimensions = [0] + { op_label = "tile-1d" } + func.return %res : tensor<100xf32> +} + +// CHECK-FOR-LABEL: func @reverse_static +// CHECK-FOR-SAME: %[[ARG0:.*]]: tensor<100xf32>, %[[ARG1:.*]]: tensor<100xf32> +// CHECK-FOR-DAG: %[[C10:.*]] = arith.constant 10 +// CHECK-FOR-DAG: %[[C100:.*]] = arith.constant 100 +// CHECK-FOR: %[[FOR:.*]] = gml_st.for (%[[I:.*]]) = +// CHECK-FOR-SAME: outs (%[[ARG3:.*]] = %[[ARG1]] +// CHECK-FOR: %[[TEMP_SUB_RES:.*]] = arith.subi %[[C100]], %[[I]] +// CHECK-FOR: %[[IN_TILE_DIM:.*]] = arith.subi %[[TEMP_SUB_RES]], %[[C10]] +// CHECK-FOR-DAG: %[[IN_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IN_TILE_DIM]]] +// 
CHECK-FOR-DAG: %[[INIT_SLICE:.*]] = tensor.extract_slice %[[ARG3]][%[[I]]] +// CHECK-FOR: %[[REVERSED:.*]] = thlo.reverse ins(%[[IN_SLICE]] +// CHECK-FOR: outs(%[[INIT_SLICE]] +// CHECK-FOR-DAG: %[[INIT_TILE:.*]] = gml_st.tile [%[[I]]] +// CHECK-FOR: gml_st.set_yield %[[REVERSED]] into %[[ARG3]][%[[INIT_TILE]]] +// CHECK-FOR: return %[[FOR]] + +// ----- + +func.func @reverse_dynamic(%input: tensor, %init: tensor) + -> tensor { + %res = thlo.reverse + ins(%input: tensor) + outs(%init: tensor) + reverse_dimensions = [0, 1] + { op_label = "tile-2d" } + func.return %res : tensor +} + +// CHECK-FOR-LABEL: func @reverse_dynamic( +// CHECK-FOR-SAME: %[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor +// CHECK-FOR-DAG: %[[C0:.*]] = arith.constant 0 +// CHECK-FOR-DAG: %[[C1:.*]] = arith.constant 1 +// CHECK-FOR-DAG: %[[DIM:.*]] = tensor.dim %[[ARG1]], %[[C0]] +// CHECK-FOR-DAG: %[[DIM0:.*]] = tensor.dim %[[ARG1]], %[[C1]] +// CHECK-FOR: %[[FOR:.*]] = gml_st.for (%[[I:.*]], %[[J:.*]]) = +// CHECK-FOR-SAME: (%[[C0]], %[[C0]]) to (%[[DIM]], %[[DIM0]]) +// CHECK-FOR-SAME: outs (%[[ARG4:.*]] = %[[ARG1]] +// CHECK-FOR-DAG: %[[AFFINE_MIN1:.*]] = affine.min +// CHECK-FOR-DAG: %[[AFFINE_MIN2:.*]] = affine.min +// CHECK-FOR-DAG: %[[DIM1:.*]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-FOR-DAG: %[[DIM2:.*]] = tensor.dim %[[ARG0]], %[[C1]] +// CHECK-FOR-DAG: %[[TEMP_SUB_RES0:.*]] = arith.subi %[[DIM1]], %[[I]] +// CHECK-FOR-DAG: %[[IN_TILE_DIM0:.*]] = arith.subi %[[TEMP_SUB_RES0]], %[[AFFINE_MIN1]] +// CHECK-FOR-DAG: %[[TEMP_SUB_RES1:.*]] = arith.subi %[[DIM2]], %[[J]] +// CHECK-FOR-DAG: %[[IN_TILE_DIM1:.*]] = arith.subi %[[TEMP_SUB_RES1]], %[[AFFINE_MIN2]] +// CHECK-FOR-DAG: %[[IN_SLICE:.*]] = tensor.extract_slice %[[ARG0]] +// CHECK-FOR-SAME: [%[[IN_TILE_DIM0]], %[[IN_TILE_DIM1]]] +// CHECK-FOR-DAG: %[[INIT_SLICE:.*]] = tensor.extract_slice %[[ARG4]] +// CHECK-FOR-SAME: [%[[I]], %[[J]]] +// CHECK-FOR: %[[REVERSED:.*]] = thlo.reverse ins(%[[IN_SLICE]] +// CHECK-FOR-SAME: outs(%[[INIT_SLICE]] +// CHECK-FOR: %[[INIT_TILE:.*]] = gml_st.tile [%[[I]], %[[J]]] +// CHECK-FOR: gml_st.set_yield %[[REVERSED]] into %[[ARG4]][%[[INIT_TILE]]] +// CHECK-FOR: return %[[FOR]] diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/tiling_and_fusion.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/tiling_and_fusion.mlir index 8e79065a3d3..9b7b37bf353 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/tiling_and_fusion.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/tiling_and_fusion.mlir @@ -52,12 +52,12 @@ func.func @reduce_cwise(%lhs: tensor<32x16xf32>, %rhs: tensor<32x16xf32>) // CHECK-SAME: to (%[[C32]]) // CHECK-SAME: step (%[[C8]]) // CHECK-SAME: outs (%[[ARG3:.*]] = %[[FILL]]: tensor<32xf32>) -// CHECK: %[[TILE:.*]] = gml_st.tile [%[[ARG2]], 0] [8, 16] [1, 1] -// CHECK: %[[MATERIALIZE:.*]] = gml_st.materialize %[[ARG0]][%[[TILE]]] -// CHECK: %[[TILE_0:.*]] = gml_st.tile [%[[ARG2]], 0] [8, 16] [1, 1] -// CHECK: %[[MATERIALIZE_0:.*]] = gml_st.materialize %[[ARG1]][%[[TILE_0]]] -// CHECK: %[[TILE_1:.*]] = gml_st.tile [%[[ARG2]], 0] [8, 16] [1, 1] -// CHECK: %[[MATERIALIZE_1:.*]] = gml_st.materialize %[[INIT]][%[[TILE_1]]] +// CHECK: %[[MATERIALIZE:.*]] = tensor.extract_slice %[[ARG0]] +// CHECK-SAME: [%[[ARG2]], 0] [8, 16] [1, 1] +// CHECK: %[[MATERIALIZE_0:.*]] = tensor.extract_slice %[[ARG1]] +// CHECK-SAME: [%[[ARG2]], 0] [8, 16] [1, 1] +// CHECK: %[[MATERIALIZE_1:.*]] = tensor.extract_slice %[[INIT]] +// CHECK-SAME: [%[[ARG2]], 0] [8, 16] [1, 1] // CHECK: %[[GENERIC:.*]] = 
linalg.generic // CHECK-SAME: iterator_types = ["parallel", "parallel"] // CHECK-SAME: ins(%[[MATERIALIZE]], %[[MATERIALIZE_0]] : tensor<8x16xf32>, tensor<8x16xf32>) @@ -66,8 +66,7 @@ func.func @reduce_cwise(%lhs: tensor<32x16xf32>, %rhs: tensor<32x16xf32>) // CHECK: ^bb0(%[[ARG4:.*]]: f32, %[[ARG5:.*]]: f32, %[[ARG6:.*]]: f32): // CHECK: %[[MULF:.*]] = arith.mulf %[[ARG4]], %[[ARG5]] // CHECK: linalg.yield %[[MULF]] -// CHECK: %[[TILE_2:.*]] = gml_st.tile [%[[ARG2]]] [8] [1] -// CHECK: %[[MATERIALIZE_2:.*]] = gml_st.materialize %[[ARG3]][%[[TILE_2]]] +// CHECK: %[[MATERIALIZE_2:.*]] = tensor.extract_slice %[[ARG3]][%[[ARG2]]] [8] [1] // CHECK: %[[GENERIC_0:.*]] = linalg.generic // CHECK-SAME: iterator_types = ["parallel", "reduction"] // CHECK-SAME: ins(%[[GENERIC]] : tensor<8x16xf32>) @@ -76,5 +75,6 @@ func.func @reduce_cwise(%lhs: tensor<32x16xf32>, %rhs: tensor<32x16xf32>) // CHECK: ^bb0(%[[ARG4_0:.*]]: f32, %[[ARG5_0:.*]]: f32): // CHECK: %[[ADDF:.*]] = arith.addf %[[ARG4_0]], %[[ARG5_0]] // CHECK: linalg.yield %[[ADDF]] -// CHECK: gml_st.set_yield %[[GENERIC_0]] into %[[ARG3]][%[[TILE_2]]] +// CHECK: %[[TILE_2_:.*]] = gml_st.tile [%[[ARG2]]] [8] [1] +// CHECK: gml_st.set_yield %[[GENERIC_0]] into %[[ARG3]][%[[TILE_2_]]] // CHECK: return %[[FOR]] diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/tiling_gpu_warp.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/tiling_gpu_warp.mlir deleted file mode 100644 index 596ab7f1bf3..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/tiling_gpu_warp.mlir +++ /dev/null @@ -1,389 +0,0 @@ -// RUN: mlir-hlo-opt %s --split-input-file --gml-tiling-gpu-warp | \ -// RUN: FileCheck %s - -#map0 = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> (d0)> - -// CHECK-LABEL: @tiling_warp_level_reduction -// CHECK-SAME: %[[ARG0:.*]]: tensor<7x13xf32> -func.func @tiling_warp_level_reduction(%arg0: tensor<7x13xf32>) - -> tensor<7xf32> { - // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index - // CHECK-DAG: %[[C13:.*]] = arith.constant 13 : index - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index - // CHECK-DAG: %[[C1024:.*]] = arith.constant 1024 : index - // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index - // CHECK-DAG: %[[CST:.*]] = arith.constant 0xFF800000 : f32 - // CHECK: %[[EMPTY:.*]] = tensor.empty() - // CHECK: %[[PARALLEL:.*]] = gml_st.parallel (%[[ARG1:.*]]) = (%[[C0]]) to (%[[C1024]]) step (%[[C1]]) distribution ("warp") - // CHECK: %[[TILE:.*]] = gml_st.tile [%[[ARG1]], 0] [1, 13] [1, 1] - // CHECK: %[[MATERIALIZE:.*]] = gml_st.materialize %[[ARG0]][%[[TILE]]] - // CHECK: %[[TILE_0:.*]] = gml_st.tile [%[[ARG1]]] [1] [1] - // CHECK: %[[TILE_1:.*]] = gml_st.tile [%[[ARG1]]] [1] [1] - // CHECK: %[[MATERIALIZE_0:.*]] = gml_st.materialize %[[EMPTY]][%[[TILE_1]]] - // CHECK: %[[FILL:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[MATERIALIZE_0]] : tensor<1xf32>) - // CHECK: %[[EMPTY_0:.*]] = tensor.empty() - // CHECK: %[[TILE_2:.*]] = gml_st.tile [0] [1] [1] - // CHECK: %[[MATERIALIZE_1:.*]] = gml_st.materialize %[[MATERIALIZE_0]][%[[TILE_2]]] - // CHECK: %[[FILL_0:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[MATERIALIZE_1]] : tensor<1xf32>) - // CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[FILL_0]][%[[C0]]] - // CHECK: %[[PARALLEL_0:.*]] = gml_st.parallel (%[[ARG2:.*]]) = (%[[C0]]) to (%[[C16]]) step (%[[C1]]) distribution ("thread") - // CHECK: %[[TILE_3:.*]] = gml_st.tile [0, %[[ARG2]]] [1, 1] [1, 1] - // CHECK: %[[TILE_4:.*]] = gml_st.tile [0, %[[ARG2]]] [1, 1] [1, 1] - // CHECK: 
%[[MATERIALIZE_2:.*]] = gml_st.materialize %[[EMPTY_0]][%[[TILE_4]]] - // CHECK: %[[FILL_1:.*]] = linalg.fill ins(%[[EXTRACTED]] : f32) outs(%[[MATERIALIZE_2]] : tensor<1x1xf32>) - // CHECK: %[[FOR:.*]] = gml_st.for (%[[ARG3:.*]]) = (%[[ARG2]]) to (%[[C13]]) step (%[[C16]]) outs (%[[ARG4:.*]] = %[[FILL_1]]: tensor<1x1xf32>) - // CHECK: %[[TILE_5:.*]] = gml_st.tile [0, %[[ARG3]]] [1, 1] [1, 1] - // CHECK: %[[MATERIALIZE_3:.*]] = gml_st.materialize %[[MATERIALIZE]][%[[TILE_5]]] : tensor<1x13xf32>[!gml_st.tile<1x1>] to f32 - // CHECK: %[[TILE_6:.*]] = gml_st.tile [0, 0] [1, 1] [1, 1] - // CHECK: %[[MATERIALIZE_4:.*]] = gml_st.materialize %[[ARG4]][%[[TILE_6]]] : tensor<1x1xf32>[!gml_st.tile<1x1>] to f32 - // CHECK: %[[MAXF:.*]] = arith.maxf %[[MATERIALIZE_4]], %[[MATERIALIZE_3]] : f32 - // CHECK: gml_st.set_yield %[[MAXF]] into %[[ARG4]][%[[TILE_6]]] : f32 into tensor<1x1xf32>[!gml_st.tile<1x1>] - // CHECK: gml_st.set_yield %[[FOR]] into %[[EMPTY_0]][%[[TILE_3]]] - // CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%[[PARALLEL_0]] : tensor<1x16xf32>) outs(%[[FILL]] : tensor<1xf32>) - // CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): - // CHECK: %[[MAXF_0:.*]] = arith.maxf %[[OUT]], %[[IN]] : f32 - // CHECK: linalg.yield %[[MAXF_0]] : f32 - // CHECK: gml_st.set_yield %[[GENERIC]] into %[[EMPTY]][%[[TILE_0]]] - // CHECK: return %[[PARALLEL]] - %c1 = arith.constant 1 : index - %c1024 = arith.constant 1024 : index - %c0 = arith.constant 0 : index - %cst = arith.constant 0xFF800000 : f32 - %0 = tensor.empty() : tensor<7xf32> - %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<7xf32>) - -> tensor<7xf32> - %2 = gml_st.parallel (%arg1) = (%c0) to (%c1024) step (%c1) - distribution ("warp") { - %3 = gml_st.tile [%arg1, 0] [1, 13] [1, 1] : !gml_st.tile<1x13> - %4 = gml_st.materialize %arg0[%3] - : tensor<7x13xf32>[!gml_st.tile<1x13>] to tensor<1x13xf32> - %5 = gml_st.tile [%arg1] [1] [1] : !gml_st.tile<1> - %6 = gml_st.materialize %1[%5] - : tensor<7xf32>[!gml_st.tile<1>] to tensor<1xf32> - %7 = linalg.generic { indexing_maps = [#map0, #map1], - iterator_types = ["parallel", "reduction"]} ins(%4 : tensor<1x13xf32>) - outs(%6 : tensor<1xf32>) { - ^bb0(%in: f32, %out: f32): - %8 = arith.maxf %out, %in : f32 - linalg.yield %8 : f32 - } -> tensor<1xf32> - gml_st.set_yield %7 into %1[%5] - : tensor<1xf32> into tensor<7xf32>[!gml_st.tile<1>] - } : tensor<7xf32> - return %2 : tensor<7xf32> -} - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> - -// CHECK-LABEL: @tiling_warp_level_cwise -// CHECK-SAME: %[[ARG0:.*]]: tensor<7x13xf32>, %[[ARG1:.*]]: tensor<7x13xf32> -func.func @tiling_warp_level_cwise(%arg0: tensor<7x13xf32>, - %arg1: tensor<7x13xf32>) -> tensor<7x13xf32> { - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index - // CHECK-DAG: %[[C0:.*]] = arith.constant 0 - // CHECK-DAG: %[[C16:.*]] = arith.constant 16 - // CHECK-DAG: %[[C1024:.*]] = arith.constant 1024 - // CHECK-DAG: %[[C28:.*]] = arith.constant 28 - // CHECK-DAG: %[[EMPTY:.*]] = tensor.empty() : tensor<7x13xf32> - // CHECK: %[[PARALLEL:.*]] = gml_st.parallel - // CHECK-SAME: (%[[ARG2:.*]]) = (%[[C0]]) to (%[[C1024]]) - // CHECK-SAME: step (%[[C1]]) distribution ("warp") - // CHECK: %[[TILE:.*]] = gml_st.tile [%[[ARG2]], 0] [1, 13] [1, 1] - // CHECK: %[[MATERIALIZE:.*]] = gml_st.materialize %[[ARG0]][%[[TILE]]] - // CHECK: %[[MATERIALIZE_0:.*]] = gml_st.materialize %[[ARG1]][%[[TILE]]] - // CHECK: %[[MATERIALIZE_1:.*]] = gml_st.materialize %[[EMPTY]][%[[TILE]]] - // 
CHECK: %[[PARALLEL_0:.*]] = gml_st.parallel - // CHECK-SAME: (%[[ARG3:.*]]) = (%[[C0]]) to (%[[C16]]) step (%[[C1]]) - // CHECK-SAME: distribution ("thread") - // CHECK: %[[SUBI:.*]] = arith.subi %[[C28]], %[[ARG3]] - // CHECK: %[[DIVUI:.*]] = arith.divui %[[SUBI]], %[[C16]] - // CHECK: %[[TILE_0:.*]] = gml_st.tile [0, %[[ARG3]]] [1, %[[DIVUI]]] [1, 16] - // CHECK: %[[MATERIALIZE_2:.*]] = gml_st.materialize %[[MATERIALIZE_1]][%[[TILE_0]]] - // CHECK: %[[FOR:.*]] = gml_st.for (%[[ARG4:.*]]) = (%[[C0]]) - // CHECK-SAME: to (%[[DIVUI]]) step (%[[C1]]) - // CHECK-SAME: outs (%[[ARG5:.*]] = %[[MATERIALIZE_2]]: tensor<1x?xf32>) - // CHECK: %[[MULI:.*]] = arith.muli %[[ARG4]], %[[C16]] : index - // CHECK: %[[ADDI:.*]] = arith.addi %[[ARG3]], %[[MULI]] : index - // CHECK: %[[TILE_1:.*]] = gml_st.tile [0, %[[ADDI]]] [1, 1] [1, 1] - // CHECK: %[[MATERIALIZE_3:.*]] = gml_st.materialize %[[MATERIALIZE]][%[[TILE_1]]] - // CHECK: %[[MATERIALIZE_4:.*]] = gml_st.materialize %[[MATERIALIZE_0]][%[[TILE_1]]] - // CHECK: %[[SUBF:.*]] = arith.subf %[[MATERIALIZE_3]], %[[MATERIALIZE_4]] - // CHECK: %[[TILE_2:.*]] = gml_st.tile [0, %[[ARG4]]] [1, 1] [1, 1] - // CHECK: gml_st.set_yield %[[SUBF]] into %[[ARG5]][%[[TILE_2]]] - // CHECK: gml_st.set_yield %[[FOR]] into %[[MATERIALIZE_1]][%[[TILE_0]]] - // CHECK: gml_st.set_yield %[[PARALLEL_0]] into %[[EMPTY]][%[[TILE]]] - // CHECK: return %[[PARALLEL]] - %c1 = arith.constant 1 : index - %c1024 = arith.constant 1024 : index - %c0 = arith.constant 0 : index - %0 = tensor.empty() : tensor<7x13xf32> - %1 = gml_st.parallel (%arg2) = (%c0) to (%c1024) step (%c1) - distribution ("warp") { - %2 = gml_st.tile [%arg2, 0] [1, 13] [1, 1] : !gml_st.tile<1x13> - %3 = gml_st.materialize %arg0[%2] - : tensor<7x13xf32>[!gml_st.tile<1x13>] to tensor<1x13xf32> - %4 = gml_st.materialize %arg1[%2] - : tensor<7x13xf32>[!gml_st.tile<1x13>] to tensor<1x13xf32> - %5 = gml_st.materialize %0[%2] - : tensor<7x13xf32>[!gml_st.tile<1x13>] to tensor<1x13xf32> - %6 = linalg.generic {indexing_maps = [#map, #map, #map], - iterator_types = ["parallel", "parallel"]} - ins(%3, %4 : tensor<1x13xf32>, tensor<1x13xf32>) - outs(%5 : tensor<1x13xf32>) { - ^bb0(%in: f32, %in_0: f32, %out: f32): - %7 = arith.subf %in, %in_0 : f32 - linalg.yield %7 : f32 - } -> tensor<1x13xf32> - gml_st.set_yield %6 into %0[%2] - : tensor<1x13xf32> into tensor<7x13xf32>[!gml_st.tile<1x13>] - } : tensor<7x13xf32> - return %1 : tensor<7x13xf32> -} - -// ----- - -// CHECK-LABEL: @softmax -// CHECK-SAME: %[[ARG0:.*]]: tensor<2048x4096xf32> -func.func @softmax(%arg0: tensor<2048x4096xf32>) -> tensor<2048x4096xf32> { - // CHECK-DAG: %[[C4096:.*]] = arith.constant 4096 : index - // CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index - // CHECK-DAG: %[[C4127:.*]] = arith.constant 4127 : index - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index - // CHECK-DAG: %[[C1024:.*]] = arith.constant 1024 : index - // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index - // CHECK-DAG: %[[C2048:.*]] = arith.constant 2048 : index - // CHECK-DAG: %[[CST:.*]] = arith.constant -0.000000e+00 : f32 - // CHECK-DAG: %[[CST_0:.*]] = arith.constant 0xFF800000 : f32 - // CHECK: %[[EMPTY:.*]] = tensor.empty() - // CHECK: %[[EMPTY_0:.*]] = tensor.empty() - // CHECK: %[[PARALLEL:.*]] = gml_st.parallel (%[[ARG1:.*]]) = (%[[C0]]) to (%[[C2048]]) step (%[[C1024]]) distribution ("block") - // CHECK: %[[TILE:.*]] = gml_st.tile [%[[ARG1]], 0] [1024, 4096] [1, 1] - // CHECK: %[[MATERIALIZE:.*]] = gml_st.materialize %[[EMPTY_0]][%[[TILE]]] - // CHECK: %[[PARALLEL_0:.*]] 
= gml_st.parallel (%[[ARG2:.*]]) = (%[[C0]]) to (%[[C1024]]) step (%[[C1]]) distribution ("warp") - // CHECK: %[[TILE_0:.*]] = gml_st.tile [%[[ARG2]], 0] [1, 4096] [1, 1] - // CHECK: %[[ADDI:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : index - // CHECK: %[[TILE_1:.*]] = gml_st.tile [%[[ADDI]], 0] [1, 4096] [1, 1] - // CHECK: %[[MATERIALIZE_0:.*]] = gml_st.materialize %[[ARG0]][%[[TILE_1]]] - // CHECK: %[[TILE_2:.*]] = gml_st.tile [%[[ADDI]]] [1] [1] - // CHECK: %[[MATERIALIZE_1:.*]] = gml_st.materialize %[[EMPTY]][%[[TILE_2]]] - // CHECK: %[[FILL:.*]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[MATERIALIZE_1]] : tensor<1xf32>) - // CHECK: %[[EMPTY_1:.*]] = tensor.empty() - // CHECK: %[[TILE_3:.*]] = gml_st.tile [0] [1] [1] - // CHECK: %[[MATERIALIZE_2:.*]] = gml_st.materialize %[[MATERIALIZE_1]][%[[TILE_3]]] - // CHECK: %[[FILL_0:.*]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[MATERIALIZE_2]] : tensor<1xf32>) - // CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[FILL_0]][%[[C0]]] - // CHECK: %[[PARALLEL_1:.*]] = gml_st.parallel (%[[ARG3:.*]]) = (%[[C0]]) to (%[[C32]]) step (%[[C1]]) distribution ("thread") - // CHECK: %[[TILE_4:.*]] = gml_st.tile [0, %[[ARG3]]] [1, 1] [1, 1] - // CHECK: %[[TILE_5:.*]] = gml_st.tile [0, %[[ARG3]]] [1, 1] [1, 1] - // CHECK: %[[MATERIALIZE_3:.*]] = gml_st.materialize %[[EMPTY_1]][%[[TILE_5]]] - // CHECK: %[[FILL_1:.*]] = linalg.fill ins(%[[EXTRACTED]] : f32) outs(%[[MATERIALIZE_3]] : tensor<1x1xf32>) - // CHECK: %[[FOR:.*]] = gml_st.for (%[[ARG4:.*]]) = (%[[ARG3]]) to (%[[C4096]]) step (%[[C32]]) outs (%[[ARG5:.*]] = %[[FILL_1]]: tensor<1x1xf32>) - // CHECK: %[[TILE_6:.*]] = gml_st.tile [0, %[[ARG4]]] [1, 1] [1, 1] - // CHECK: %[[MATERIALIZE_4:.*]] = gml_st.materialize %[[MATERIALIZE_0]][%[[TILE_6]]] : tensor<1x4096xf32>[!gml_st.tile<1x1>] to f32 - // CHECK: %[[TILE_7:.*]] = gml_st.tile [0, 0] [1, 1] [1, 1] - // CHECK: %[[MATERIALIZE_5:.*]] = gml_st.materialize %[[ARG5]][%[[TILE_7]]] : tensor<1x1xf32>[!gml_st.tile<1x1>] to f32 - // CHECK: %[[MAXF:.*]] = arith.maxf %[[MATERIALIZE_5]], %[[MATERIALIZE_4]] : f32 - // CHECK: gml_st.set_yield %[[MAXF]] into %[[ARG5]][%[[TILE_7]]] : f32 into tensor<1x1xf32>[!gml_st.tile<1x1>] - // CHECK: gml_st.set_yield %[[FOR]] into %[[EMPTY_1]][%[[TILE_4]]] - // CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%[[PARALLEL_1]] : tensor<1x32xf32>) outs(%[[FILL]] : tensor<1xf32>) - // CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): - // CHECK: %[[MAXF_0:.*]] = arith.maxf %[[OUT]], %[[IN]] : f32 - // CHECK: linalg.yield %[[MAXF_0]] : f32 - // CHECK: %[[MATERIALIZE_6:.*]] = gml_st.materialize %[[EMPTY_0]][%[[TILE_1]]] - // CHECK: %[[TILE_8:.*]] = gml_st.tile [%[[ADDI]]] [1] [1] - // CHECK: %[[MATERIALIZE_7:.*]] = gml_st.materialize %[[EMPTY]][%[[TILE_8]]] - // CHECK: %[[FILL_2:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[MATERIALIZE_7]] : tensor<1xf32>) - // CHECK: %[[EMPTY_2:.*]] = tensor.empty() - // CHECK: %[[TILE_9:.*]] = gml_st.tile [0] [1] [1] - // CHECK: %[[MATERIALIZE_8:.*]] = gml_st.materialize %[[MATERIALIZE_7]][%[[TILE_9]]] - // CHECK: %[[FILL_3:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[MATERIALIZE_8]] : tensor<1xf32>) - // CHECK: %[[EXTRACTED_1:.*]] = tensor.extract %[[FILL_3]][%[[C0]]] - // CHECK: %[[PARALLEL_2:.*]] = gml_st.parallel (%[[ARG3]]) = (%[[C0]]) to (%[[C32]]) step (%[[C1]]) distribution ("thread") - // CHECK: %[[TILE_10:.*]] = gml_st.tile [0, %[[ARG3]]] [1, 1] [1, 1] - // CHECK: %[[TILE_11:.*]] = gml_st.tile [0, %[[ARG3]]] [1, 1] [1, 1] - 
// CHECK: %[[MATERIALIZE_9:.*]] = gml_st.materialize %[[EMPTY_2]][%[[TILE_11]]] - // CHECK: %[[FILL_4:.*]] = linalg.fill ins(%[[EXTRACTED_1]] : f32) outs(%[[MATERIALIZE_9]] : tensor<1x1xf32>) - // CHECK: %[[FOR_0:.*]] = gml_st.for (%[[ARG4_0:.*]]) = (%[[ARG3]]) to (%[[C4096]]) step (%[[C32]]) outs (%[[ARG5_0:.*]] = %[[FILL_4]]: tensor<1x1xf32>) - // CHECK: %[[TILE_12:.*]] = gml_st.tile [0, %[[ARG4_0]]] [1, 1] [1, 1] - // CHECK: %[[MATERIALIZE_10:.*]] = gml_st.materialize %[[MATERIALIZE_0]][%[[TILE_12]]] - // CHECK: %[[TILE_13:.*]] = gml_st.tile [0] [1] [1] - // CHECK: %[[MATERIALIZE_11:.*]] = gml_st.materialize %[[GENERIC]][%[[TILE_13]]] - // CHECK: %[[TILE_14:.*]] = gml_st.tile [0, %[[ARG4_0]]] [1, 1] [1, 1] - // CHECK: %[[MATERIALIZE_12:.*]] = gml_st.materialize %[[MATERIALIZE_6]][%[[TILE_14]]] - // CHECK: %[[GENERIC_0:.*]] = linalg.generic {indexing_maps = [#map1, #map], iterator_types = ["parallel", "parallel"]} ins(%[[MATERIALIZE_11]] : tensor<1xf32>) outs(%[[MATERIALIZE_12]] : tensor<1x1xf32>) - // CHECK: ^bb0(%[[IN_0:.*]]: f32, %[[OUT_0:.*]]: f32): - // CHECK: linalg.yield %[[IN_0]] : f32 - // CHECK: %[[TILE_15:.*]] = gml_st.tile [0, %[[ARG4_0]]] [1, 1] [1, 1] - // CHECK: %[[MATERIALIZE_13:.*]] = gml_st.materialize %[[MATERIALIZE_6]][%[[TILE_15]]] - // CHECK: %[[GENERIC_1:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[MATERIALIZE_10]], %[[GENERIC_0]] : tensor<1x1xf32>, tensor<1x1xf32>) outs(%[[MATERIALIZE_13]] : tensor<1x1xf32>) - // CHECK: ^bb0(%[[IN_1:.*]]: f32, %[[IN_3:.*]]: f32, %[[OUT_1:.*]]: f32): - // CHECK: %[[SUBF:.*]] = arith.subf %[[IN_1]], %[[IN_3]] : f32 - // CHECK: linalg.yield %[[SUBF]] : f32 - // CHECK: %[[TILE_16:.*]] = gml_st.tile [0, %[[ARG4_0]]] [1, 1] [1, 1] - // CHECK: %[[MATERIALIZE_14:.*]] = gml_st.materialize %[[MATERIALIZE_6]][%[[TILE_16]]] - // CHECK: %[[GENERIC_2:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[GENERIC_1]] : tensor<1x1xf32>) outs(%[[MATERIALIZE_14]] : tensor<1x1xf32>) - // CHECK: ^bb0(%[[IN_2:.*]]: f32, %[[OUT_2:.*]]: f32): - // CHECK: %[[EXP:.*]] = math.exp %[[IN_2]] : f32 - // CHECK: linalg.yield %[[EXP]] : f32 - // CHECK: %[[EXTRACTED_2:.*]] = tensor.extract %[[GENERIC_2]][%[[C0]], %[[C0]]] - // CHECK: %[[TILE_17:.*]] = gml_st.tile [0, 0] [1, 1] [1, 1] - // CHECK: %[[MATERIALIZE_15:.*]] = gml_st.materialize %[[ARG5_0]][%[[TILE_17]]] : tensor<1x1xf32>[!gml_st.tile<1x1>] to f32 - // CHECK: %[[ADDF:.*]] = arith.addf %[[MATERIALIZE_15]], %[[EXTRACTED_2]] : f32 - // CHECK: gml_st.set_yield %[[ADDF]] into %[[ARG5_0]][%[[TILE_17]]] : f32 into tensor<1x1xf32>[!gml_st.tile<1x1>] - // CHECK: gml_st.set_yield %[[FOR_0]] into %[[EMPTY_2]][%[[TILE_10]]] - // CHECK: %[[GENERIC_3:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%[[PARALLEL_2]] : tensor<1x32xf32>) outs(%[[FILL_2]] : tensor<1xf32>) - // CHECK: ^bb0(%[[IN_4:.*]]: f32, %[[OUT_3:.*]]: f32): - // CHECK: %[[ADDF_0:.*]] = arith.addf %[[OUT_3]], %[[IN_4]] : f32 - // CHECK: linalg.yield %[[ADDF_0]] : f32 - // CHECK: %[[PARALLEL_3:.*]] = gml_st.parallel (%[[ARG3]]) = (%[[C0]]) to (%[[C32]]) step (%[[C1]]) distribution ("thread") - // CHECK: %[[SUBI:.*]] = arith.subi %[[C4127]], %[[ARG3]] : index - // CHECK: %[[DIVUI:.*]] = arith.divui %[[SUBI]], %[[C32]] : index - // CHECK: %[[TILE_18:.*]] = gml_st.tile [0, %[[ARG3]]] [1, %[[DIVUI]]] [1, 32] - // CHECK: %[[MATERIALIZE_16:.*]] = gml_st.materialize %[[MATERIALIZE_6]][%[[TILE_18]]] 
- // CHECK: %[[FOR_1:.*]] = gml_st.for (%[[ARG4_1:.*]]) = (%[[C0]]) to (%[[DIVUI]]) step (%[[C1]]) outs (%[[ARG5_1:.*]] = %[[MATERIALIZE_16]]: tensor<1x?xf32>) - // CHECK: %[[MULI:.*]] = arith.muli %[[ARG4_1]], %[[C32]] : index - // CHECK: %[[ADDI_0:.*]] = arith.addi %[[ARG3]], %[[MULI]] : index - // CHECK: %[[TILE_19:.*]] = gml_st.tile [0, %[[ADDI_0]]] [1, 1] [1, 1] - // CHECK: %[[MATERIALIZE_17:.*]] = gml_st.materialize %[[MATERIALIZE_0]][%[[TILE_19]]] - // CHECK: %[[TILE_20:.*]] = gml_st.tile [0] [1] [1] - // CHECK: %[[MATERIALIZE_18:.*]] = gml_st.materialize %[[GENERIC]][%[[TILE_20]]] - // CHECK: %[[TILE_21:.*]] = gml_st.tile [0, %[[ADDI_0]]] [1, 1] [1, 1] - // CHECK: %[[MATERIALIZE_19:.*]] = gml_st.materialize %[[MATERIALIZE_6]][%[[TILE_21]]] - // CHECK: %[[GENERIC_4:.*]] = linalg.generic {indexing_maps = [#map1, #map], iterator_types = ["parallel", "parallel"]} ins(%[[MATERIALIZE_18]] : tensor<1xf32>) outs(%[[MATERIALIZE_19]] : tensor<1x1xf32>) - // CHECK: ^bb0(%[[IN_5:.*]]: f32, %[[OUT_4:.*]]: f32): - // CHECK: linalg.yield %[[IN_5]] : f32 - // CHECK: %[[TILE_22:.*]] = gml_st.tile [0, %[[ADDI_0]]] [1, 1] [1, 1] - // CHECK: %[[MATERIALIZE_20:.*]] = gml_st.materialize %[[MATERIALIZE_6]][%[[TILE_22]]] - // CHECK: %[[GENERIC_5:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[MATERIALIZE_17]], %[[GENERIC_4]] : tensor<1x1xf32>, tensor<1x1xf32>) outs(%[[MATERIALIZE_20]] : tensor<1x1xf32>) - // CHECK: ^bb0(%[[IN_6:.*]]: f32, %[[IN_4_0:.*]]: f32, %[[OUT_5:.*]]: f32): - // CHECK: %[[SUBF_0:.*]] = arith.subf %[[IN_6]], %[[IN_4_0]] : f32 - // CHECK: linalg.yield %[[SUBF_0]] : f32 - // CHECK: %[[TILE_23:.*]] = gml_st.tile [0, %[[ADDI_0]]] [1, 1] [1, 1] - // CHECK: %[[MATERIALIZE_21:.*]] = gml_st.materialize %[[MATERIALIZE_6]][%[[TILE_23]]] - // CHECK: %[[GENERIC_6:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[GENERIC_5]] : tensor<1x1xf32>) outs(%[[MATERIALIZE_21]] : tensor<1x1xf32>) - // CHECK: ^bb0(%[[IN_7:.*]]: f32, %[[OUT_6:.*]]: f32): - // CHECK: %[[EXP_0:.*]] = math.exp %[[IN_7]] : f32 - // CHECK: linalg.yield %[[EXP_0]] : f32 - // CHECK: %[[EXTRACTED_2_0:.*]] = tensor.extract %[[GENERIC_6]][%[[C0]], %[[C0]]] - // CHECK: %[[TILE_24:.*]] = gml_st.tile [0] [1] [1] - // CHECK: %[[MATERIALIZE_22:.*]] = gml_st.materialize %[[GENERIC_3]][%[[TILE_24]]] - // CHECK: %[[TILE_25:.*]] = gml_st.tile [0, %[[ADDI_0]]] [1, 1] [1, 1] - // CHECK: %[[MATERIALIZE_23:.*]] = gml_st.materialize %[[MATERIALIZE_6]][%[[TILE_25]]] - // CHECK: %[[GENERIC_7:.*]] = linalg.generic {indexing_maps = [#map1, #map], iterator_types = ["parallel", "parallel"]} ins(%[[MATERIALIZE_22]] : tensor<1xf32>) outs(%[[MATERIALIZE_23]] : tensor<1x1xf32>) - // CHECK: ^bb0(%[[IN_8:.*]]: f32, %[[OUT_7:.*]]: f32): - // CHECK: linalg.yield %[[IN_8]] : f32 - // CHECK: %[[EXTRACTED_3:.*]] = tensor.extract %[[GENERIC_7]][%[[C0]], %[[C0]]] - // CHECK: %[[DIVF:.*]] = arith.divf %[[EXTRACTED_2_0]], %[[EXTRACTED_3]] : f32 - // CHECK: %[[TILE_26:.*]] = gml_st.tile [0, %[[ARG4_1]]] [1, 1] [1, 1] - // CHECK: gml_st.set_yield %[[DIVF]] into %[[ARG5_1]][%[[TILE_26]]] : f32 into tensor<1x?xf32>[!gml_st.tile<1x1>] - // CHECK: gml_st.set_yield %[[FOR_1]] into %[[MATERIALIZE_6]][%[[TILE_18]]] - // CHECK: gml_st.set_yield %[[PARALLEL_3]] into %[[MATERIALIZE]][%[[TILE_0]]] - // CHECK: gml_st.set_yield %[[PARALLEL_0]] into %[[EMPTY_0]][%[[TILE]]] - // CHECK: return %[[PARALLEL]] - %c1 = arith.constant 1 : index - %c1024 = arith.constant 1024 : 
index - %c0 = arith.constant 0 : index - %c2048 = arith.constant 2048 : index - %cst = arith.constant -0.000000e+00 : f32 - %cst_0 = arith.constant 0xFF800000 : f32 - %0 = tensor.empty() : tensor<2048xf32> - %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<2048xf32>) - -> tensor<2048xf32> - %2 = tensor.empty() : tensor<2048x4096xf32> - %3 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2048xf32>) - -> tensor<2048xf32> - %4 = gml_st.parallel (%arg1) = (%c0) to (%c2048) step (%c1024) - distribution ("block") { - %5 = gml_st.tile [%arg1, 0] [1024, 4096] [1, 1] : !gml_st.tile<1024x4096> - %6 = gml_st.materialize %2[%5] - : tensor<2048x4096xf32>[!gml_st.tile<1024x4096>] to tensor<1024x4096xf32> - %7 = gml_st.parallel (%arg2) = (%c0) to (%c1024) step (%c1) - distribution ("warp") { - %8 = gml_st.tile [%arg2, 0] [1, 4096] [1, 1] : !gml_st.tile<1x4096> - %9 = arith.addi %arg1, %arg2 : index - %10 = gml_st.tile [%9, 0] [1, 4096] [1, 1] : !gml_st.tile<1x4096> - %11 = gml_st.materialize %arg0[%10] - : tensor<2048x4096xf32>[!gml_st.tile<1x4096>] to tensor<1x4096xf32> - %12 = gml_st.tile [%9] [1] [1] : !gml_st.tile<1> - %13 = gml_st.materialize %1[%12] - : tensor<2048xf32>[!gml_st.tile<1>] to tensor<1xf32> - %14 = linalg.generic { - indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0)>], - iterator_types = ["parallel", "reduction"]} - ins(%11 : tensor<1x4096xf32>) outs(%13 : tensor<1xf32>) { - ^bb0(%in: f32, %out: f32): - %23 = arith.maxf %out, %in : f32 - linalg.yield %23 : f32 - } -> tensor<1xf32> - %15 = gml_st.materialize %2[%10] - : tensor<2048x4096xf32>[!gml_st.tile<1x4096>] to tensor<1x4096xf32> - %16 = linalg.generic { - indexing_maps = [affine_map<(d0, d1) -> (d0)>, - affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"]} - ins(%14 : tensor<1xf32>) outs(%15 : tensor<1x4096xf32>) { - ^bb0(%in: f32, %out: f32): - linalg.yield %in : f32 - } -> tensor<1x4096xf32> - %17 = linalg.generic { - indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"]} - ins(%11, %16 : tensor<1x4096xf32>, tensor<1x4096xf32>) - outs(%15 : tensor<1x4096xf32>) { - ^bb0(%in: f32, %in_1: f32, %out: f32): - %23 = arith.subf %in, %in_1 : f32 - linalg.yield %23 : f32 - } -> tensor<1x4096xf32> - %18 = linalg.generic { - indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"]} - ins(%17 : tensor<1x4096xf32>) outs(%15 : tensor<1x4096xf32>) { - ^bb0(%in: f32, %out: f32): - %23 = math.exp %in : f32 - linalg.yield %23 : f32 - } -> tensor<1x4096xf32> - %19 = gml_st.materialize %3[%12] - : tensor<2048xf32>[!gml_st.tile<1>] to tensor<1xf32> - %20 = linalg.generic { - indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0)>], - iterator_types = ["parallel", "reduction"]} - ins(%18 : tensor<1x4096xf32>) outs(%19 : tensor<1xf32>) { - ^bb0(%in: f32, %out: f32): - %23 = arith.addf %out, %in : f32 - linalg.yield %23 : f32 - } -> tensor<1xf32> - %21 = linalg.generic { - indexing_maps = [affine_map<(d0, d1) -> (d0)>, - affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"]} - ins(%20 : tensor<1xf32>) outs(%15 : tensor<1x4096xf32>) { - ^bb0(%in: f32, %out: f32): - linalg.yield %in : f32 - } -> tensor<1x4096xf32> - %22 = linalg.generic { - indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0, d1)>], 
- iterator_types = ["parallel", "parallel"]} - ins(%18, %21 : tensor<1x4096xf32>, tensor<1x4096xf32>) - outs(%15 : tensor<1x4096xf32>) { - ^bb0(%in: f32, %in_1: f32, %out: f32): - %23 = arith.divf %in, %in_1 : f32 - linalg.yield %23 : f32 - } -> tensor<1x4096xf32> - gml_st.set_yield %22 into %6[%8] - : tensor<1x4096xf32> into tensor<1024x4096xf32>[!gml_st.tile<1x4096>] - } : tensor<1024x4096xf32> - gml_st.set_yield %7 into %2[%5] - : tensor<1024x4096xf32> into tensor<2048x4096xf32>[!gml_st.tile<1024x4096>] - } : tensor<2048x4096xf32> - return %4 : tensor<2048x4096xf32> -} diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/tiling_softmax.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/tiling_softmax.mlir index 1cc1dae2461..d39b28e3c9c 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/tiling_softmax.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/tiling_softmax.mlir @@ -3,11 +3,6 @@ // RUN: --canonicalize --cse | \ // RUN: FileCheck %s -// CHECK: #map{{[0-9]*}} = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: #map{{[0-9]*}} = affine_map<(d0, d1) -> (d0)> -#map0 = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> (d0)> - // CHECK-LABEL: @partial_softmax // CHECK-SAME: %[[ARG0:.*]]: tensor<64x128xf32> func.func @partial_softmax(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { @@ -19,74 +14,45 @@ func.func @partial_softmax(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { // CHECK-DAG: %[[FILL:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<64xf32>) // CHECK-DAG: %[[INIT_0:.*]] = tensor.empty() : tensor<64x128xf32> // CHECK: %[[PARALLEL:.*]] = gml_st.parallel - // CHECK-SAME: (%[[ARG1:.*]]) = (%[[C0]]) - // CHECK-SAME: to (%[[C64]]) step (%[[C8]]) - // CHECK-SAME: distribution ("test") - // CHECK-DAG: %[[TILE:.*]] = gml_st.tile [%[[ARG1]], 0] [8, 128] [1, 1] - // CHECK-DAG: %[[MATERIALIZE:.*]] = gml_st.materialize %[[ARG0]][%[[TILE]]] - // CHECK-DAG: %[[TILE_0:.*]] = gml_st.tile [%[[ARG1]]] [8] [1] - // CHECK-DAG: %[[MATERIALIZE_0:.*]] = gml_st.materialize %[[FILL]][%[[TILE_0]]] - // CHECK: %[[GENERIC:.*]] = linalg.generic - // CHECK-SAME: indexing_maps = [#[[MAP0:map[0-9]*]], #[[MAP1:map[0-9]*]]], - // CHECK-SAME: iterator_types = ["parallel", "reduction"]} + // CHECK-SAME: (%[[ARG1:.*]]) = (%[[C0]]) to (%[[C64]]) step (%[[C8]]) + // CHECK-SAME: outs (%[[INIT_0_:.*]] = %[[INIT_0]]: + // CHECK: %[[MATERIALIZE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG1]], 0] [8, 128] [1, 1] + // CHECK: %[[MATERIALIZE_0:.*]] = tensor.extract_slice %[[FILL]][%[[ARG1]]] [8] [1] + // CHECK: %[[REDUCE:.*]] = linalg.reduce { arith.maxf } // CHECK-SAME: ins(%[[MATERIALIZE]] : tensor<8x128xf32>) // CHECK-SAME: outs(%[[MATERIALIZE_0]] : tensor<8xf32>) - // CHECK: ^bb0(%[[ARG3:.*]]: f32, %[[ARG4:.*]]: f32): - // CHECK: %[[MAXF:.*]] = arith.maxf %[[ARG4]], %[[ARG3]] - // CHECK: linalg.yield %[[MAXF]] - // CHECK: %[[MATERIALIZE_1:.*]] = gml_st.materialize %[[INIT_0]][%[[TILE]]] - // CHECK: %[[GENERIC_0:.*]] = linalg.generic - // CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP0]]], - // CHECK-SAME: iterator_types = ["parallel", "parallel"]} - // CHECK-SAME: ins(%[[GENERIC]] : tensor<8xf32>) - // CHECK-SAME: outs(%[[MATERIALIZE_1]] : tensor<8x128xf32>) - // CHECK: ^bb0(%[[ARG3_0:.*]]: f32, %[[ARG4_0:.*]]: f32): - // CHECK: linalg.yield %[[ARG3_0]] - // CHECK: %[[GENERIC_1:.*]] = linalg.generic - // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]]], - // CHECK-SAME: iterator_types = ["parallel", "parallel"]} - // 
CHECK-SAME: ins(%[[MATERIALIZE]], %[[GENERIC_0]] : tensor<8x128xf32>, tensor<8x128xf32>) + // CHECK-SAME: dimensions = [1] + // CHECK: %[[MATERIALIZE_1:.*]] = tensor.extract_slice %[[INIT_0]][%[[ARG1]], 0] [8, 128] [1, 1] + // CHECK: %[[BROADCAST:.*]] = linalg.broadcast + // CHECK-SAME: ins(%[[REDUCE]] : tensor<8xf32>) // CHECK-SAME: outs(%[[MATERIALIZE_1]] : tensor<8x128xf32>) - // CHECK: ^bb0(%[[ARG3_1:.*]]: f32, %[[ARG4_1:.*]]: f32, %[[ARG5:.*]]: f32): - // CHECK: %[[SUBF:.*]] = arith.subf %[[ARG3_1]], %[[ARG4_1]] - // CHECK: linalg.yield %[[SUBF]] - // CHECK: gml_st.set_yield %[[GENERIC_1]] into %[[INIT_0]][%[[TILE]]] + // CHECK-SAME: dimensions = [1] + // CHECK: %[[INIT_0_SUB:.*]] = tensor.extract_slice %[[INIT_0_]][%[[ARG1]], 0] [8, 128] [1, 1] + // CHECK: %[[MAP:.*]] = linalg.map { arith.subf } + // CHECK-SAME: ins(%[[MATERIALIZE]], %[[BROADCAST]] : tensor<8x128xf32>, tensor<8x128xf32>) + // CHECK-SAME: outs(%[[INIT_0_SUB]] : tensor<8x128xf32>) + // CHECK: return %[[PARALLEL]] %cst = arith.constant 0xFF800000 : f32 %0 = tensor.empty() : tensor<64xf32> %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<64xf32>) -> tensor<64xf32> - %2 = linalg.generic {indexing_maps = [#map0, #map1], - iterator_types = ["parallel", "reduction"]} - ins(%arg0 : tensor<64x128xf32>) outs(%1 : tensor<64xf32>) { - ^bb0(%arg1: f32, %arg2: f32): - %6 = arith.maxf %arg2, %arg1 : f32 - linalg.yield %6 : f32 - } -> tensor<64xf32> + %2 = linalg.reduce { arith.maxf } + ins(%arg0 : tensor<64x128xf32>) + outs(%1 : tensor<64xf32>) + dimensions = [1] %3 = tensor.empty() : tensor<64x128xf32> - %4 = linalg.generic {indexing_maps = [#map1, #map0], - iterator_types = ["parallel", "parallel"]} ins(%2 : tensor<64xf32>) - outs(%3 : tensor<64x128xf32>) { - ^bb0(%arg1: f32, %arg2: f32): - linalg.yield %arg1 : f32 - } -> tensor<64x128xf32> - %5 = linalg.generic {indexing_maps = [#map0, #map0, #map0], - iterator_types = ["parallel", "parallel"]} - ins(%arg0, %4 : tensor<64x128xf32>, tensor<64x128xf32>) - outs(%3 : tensor<64x128xf32>) { - ^bb0(%arg1: f32, %arg2: f32, %arg3: f32): - %6 = arith.subf %arg1, %arg2 : f32 - linalg.yield %6 : f32 - } -> tensor<64x128xf32> + %4 = linalg.broadcast + ins(%2 : tensor<64xf32>) + outs(%3 : tensor<64x128xf32>) + dimensions = [1] + %5 = linalg.map { arith.subf } + ins(%arg0, %4 : tensor<64x128xf32>, tensor<64x128xf32>) + outs(%3 : tensor<64x128xf32>) return %5 : tensor<64x128xf32> } // ----- -// CHECK: #map{{[0-9]*}} = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: #map{{[0-9]*}} = affine_map<(d0, d1) -> (d0)> -#map0 = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> (d0)> - // CHECK-LABEL: @partial_softmax_fusion // CHECK-SAME: %[[ARG0:.*]]: tensor<64x128xf32>, %[[ARG1:.*]]: index func.func @partial_softmax_fusion(%arg0: tensor<64x128xf32>, %arg1: index) @@ -95,73 +61,43 @@ func.func @partial_softmax_fusion(%arg0: tensor<64x128xf32>, %arg1: index) // CHECK-DAG: %[[INIT:.*]] = tensor.empty() : tensor<64xf32> // CHECK-DAG: %[[FILL:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<64xf32>) // CHECK-DAG: %[[INIT_0:.*]] = tensor.empty() : tensor<64x128xf32> - // CHECK-DAG: %[[TILE:.*]] = gml_st.tile [%[[ARG1]], 0] [8, 128] [1, 1] - // CHECK-DAG: %[[MATERIALIZE:.*]] = gml_st.materialize %[[ARG0]][%[[TILE]]] - // CHECK-DAG: %[[TILE_0:.*]] = gml_st.tile [%[[ARG1]]] [8] [1] - // CHECK-DAG: %[[MATERIALIZE_0:.*]] = gml_st.materialize %[[FILL]][%[[TILE_0]]] - // CHECK: %[[GENERIC:.*]] = linalg.generic - // CHECK-SAME: indexing_maps = [#[[MAP0:map[0-9]*]], #[[MAP1:map[0-9]*]]], - 
// CHECK-SAME: iterator_types = ["parallel", "reduction"]} + // CHECK-DAG: %[[MATERIALIZE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG1]], 0] [8, 128] [1, 1] + // CHECK-DAG: %[[MATERIALIZE_0:.*]] = tensor.extract_slice %[[FILL]][%[[ARG1]]] [8] [1] + // CHECK: %[[REDUCE:.*]] = linalg.reduce { arith.maxf } // CHECK-SAME: ins(%[[MATERIALIZE]] : tensor<8x128xf32>) // CHECK-SAME: outs(%[[MATERIALIZE_0]] : tensor<8xf32>) - // CHECK: ^bb0(%[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32): - // CHECK-DAG: %[[MAXF:.*]] = arith.maxf %[[ARG3]], %[[ARG2]] - // CHECK: linalg.yield %[[MAXF]] - // CHECK-DAG: %[[MATERIALIZE_1:.*]] = gml_st.materialize %[[INIT_0]][%[[TILE]]] - // CHECK: %[[GENERIC_0:.*]] = linalg.generic - // CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP0]]], - // CHECK-SAME: iterator_types = ["parallel", "parallel"]} - // CHECK-SAME: ins(%[[GENERIC]] : tensor<8xf32>) + // CHECK-SAME: dimensions = [1] + // CHECK-DAG: %[[MATERIALIZE_1:.*]] = tensor.extract_slice %[[INIT_0]][%[[ARG1]], 0] [8, 128] [1, 1] + // CHECK: %[[BROADCAST:.*]] = linalg.broadcast + // CHECK-SAME: ins(%[[REDUCE]] : tensor<8xf32>) // CHECK-SAME: outs(%[[MATERIALIZE_1]] : tensor<8x128xf32>) - // CHECK: ^bb0(%[[ARG2_0:.*]]: f32, %[[ARG3_0:.*]]: f32): - // CHECK: linalg.yield %[[ARG2_0]] - // CHECK: %[[GENERIC_1:.*]] = linalg.generic - // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]]], - // CHECK-SAME: iterator_types = ["parallel", "parallel"]} - // CHECK-SAME: ins(%[[MATERIALIZE]], %[[GENERIC_0]] : tensor<8x128xf32>, tensor<8x128xf32>) + // CHECK-SAME: dimensions = [1] + // CHECK: %[[MAP:.*]] = linalg.map { arith.subf } + // CHECK-SAME: ins(%[[MATERIALIZE]], %[[BROADCAST]] : tensor<8x128xf32>, tensor<8x128xf32>) // CHECK-SAME: outs(%[[MATERIALIZE_1]] : tensor<8x128xf32>) - // CHECK: ^bb0(%[[ARG2_1:.*]]: f32, %[[ARG3_1:.*]]: f32, %[[ARG4:.*]]: f32): - // CHECK-DAG: %[[SUBF:.*]] = arith.subf %[[ARG2_1]], %[[ARG3_1]] - // CHECK: linalg.yield %[[SUBF]] - // CHECK: return %[[GENERIC_1]] + // CHECK: return %[[MAP]] %cst = arith.constant 0xFF800000 : f32 %0 = tensor.empty() : tensor<64xf32> %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<64xf32>) -> tensor<64xf32> - %2 = linalg.generic {indexing_maps = [#map0, #map1], - iterator_types = ["parallel", "reduction"]} - ins(%arg0 : tensor<64x128xf32>) outs(%1 : tensor<64xf32>) { - ^bb0(%arg2: f32, %arg3: f32): - %9 = arith.maxf %arg3, %arg2 : f32 - linalg.yield %9 : f32 - } -> tensor<64xf32> + %2 = linalg.reduce { arith.maxf } + ins(%arg0 : tensor<64x128xf32>) + outs(%1 : tensor<64xf32>) + dimensions = [1] %3 = tensor.empty() : tensor<64x128xf32> - %4 = linalg.generic {indexing_maps = [#map1, #map0], - iterator_types = ["parallel", "parallel"]} ins(%2 : tensor<64xf32>) - outs(%3 : tensor<64x128xf32>) { - ^bb0(%arg2: f32, %arg3: f32): - linalg.yield %arg2 : f32 - } -> tensor<64x128xf32> - %5 = linalg.generic {indexing_maps = [#map0, #map0, #map0], - iterator_types = ["parallel", "parallel"]} - ins(%arg0, %4 : tensor<64x128xf32>, tensor<64x128xf32>) - outs(%3 : tensor<64x128xf32>) { - ^bb0(%arg2: f32, %arg3: f32, %arg4: f32): - %9 = arith.subf %arg2, %arg3 : f32 - linalg.yield %9 : f32 - } -> tensor<64x128xf32> - %7 = gml_st.tile [%arg1, 0] [8, 128] [1, 1] : !gml_st.tile<8x128> - %8 = gml_st.materialize %5[%7] - : tensor<64x128xf32>[!gml_st.tile<8x128>] to tensor<8x128xf32> + %4 = linalg.broadcast + ins(%2 : tensor<64xf32>) + outs(%3 : tensor<64x128xf32>) + dimensions = [1] + %5 = linalg.map { arith.subf } + ins(%arg0, %4 : tensor<64x128xf32>, tensor<64x128xf32>) + outs(%3 : 
tensor<64x128xf32>) + %8 = tensor.extract_slice %5[%arg1, 0] [8, 128] [1, 1] + : tensor<64x128xf32> to tensor<8x128xf32> return %8 : tensor<8x128xf32> } // ----- -// CHECK: #map{{[0-9]*}} = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: #map{{[0-9]*}} = affine_map<(d0, d1) -> (d0)> -#map0 = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> (d0)> - // CHECK-LABEL: @softmax // CHECK-SAME: %[[ARG0:.*]]: tensor<64x128xf32> func.func @softmax(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { @@ -176,122 +112,69 @@ func.func @softmax(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { // CHECK-DAG: %[[FILL_0:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<64xf32>) // CHECK: %[[PARALLEL:.*]] = gml_st.parallel // CHECK-SAME: (%[[ARG1:.*]]) = (%[[C0]]) to (%[[C64]]) step (%[[C8]]) - // CHECK: %[[TILE:.*]] = gml_st.tile [%[[ARG1]], 0] [8, 128] [1, 1] - // CHECK: %[[MATERIALIZE:.*]] = gml_st.materialize %[[ARG0]][%[[TILE]]] - // CHECK: %[[TILE_0:.*]] = gml_st.tile [%[[ARG1]]] [8] [1] - // CHECK: %[[MATERIALIZE_0:.*]] = gml_st.materialize %[[FILL]][%[[TILE_0]]] - // CHECK: %[[GENERIC:.*]] = linalg.generic - // CHECK-SAME: indexing_maps = [#[[MAP0:map[0-9]*]], #[[MAP1:map[0-9]*]]], - // CHECK-SAME: iterator_types = ["parallel", "reduction"]} + // CHECK-SAME: outs (%[[INIT_0_:.*]] = %[[INIT_0]]: + // CHECK: %[[MATERIALIZE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG1]], 0] [8, 128] [1, 1] + // CHECK: %[[MATERIALIZE_0:.*]] = tensor.extract_slice %[[FILL]][%[[ARG1]]] [8] [1] + // CHECK: %[[REDUCE:.*]] = linalg.reduce { arith.maxf } // CHECK-SAME: ins(%[[MATERIALIZE]] : tensor<8x128xf32>) // CHECK-SAME: outs(%[[MATERIALIZE_0]] : tensor<8xf32>) - // CHECK: ^bb0(%[[ARG3:.*]]: f32, %[[ARG4:.*]]: f32): - // CHECK: %[[MAXF:.*]] = arith.maxf %[[ARG4]], %[[ARG3]] - // CHECK: linalg.yield %[[MAXF]] - // CHECK: %[[MATERIALIZE_1:.*]] = gml_st.materialize %[[INIT_0]][%[[TILE]]] - // CHECK: %[[GENERIC_0:.*]] = linalg.generic - // CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP0]]], - // CHECK-SAME: iterator_types = ["parallel", "parallel"]} - // CHECK-SAME: ins(%[[GENERIC]] : tensor<8xf32>) + // CHECK-SAME: dimensions = [1] + // CHECK: %[[MATERIALIZE_1:.*]] = tensor.extract_slice %[[INIT_0]][%[[ARG1]], 0] [8, 128] [1, 1] + // CHECK: %[[BROADCAST:.*]] = linalg.broadcast + // CHECK-SAME: ins(%[[REDUCE]] : tensor<8xf32>) // CHECK-SAME: outs(%[[MATERIALIZE_1]] : tensor<8x128xf32>) - // CHECK: ^bb0(%[[ARG3_0:.*]]: f32, %[[ARG4_0:.*]]: f32): - // CHECK: linalg.yield %[[ARG3_0]] - // CHECK: %[[GENERIC_1:.*]] = linalg.generic - // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]]], - // CHECK-SAME: iterator_types = ["parallel", "parallel"]} - // CHECK-SAME: ins(%[[MATERIALIZE]], %[[GENERIC_0]] : tensor<8x128xf32>, tensor<8x128xf32>) + // CHECK-SAME: dimensions = [1] + // CHECK: %[[MAP:.*]] = linalg.map { arith.subf } + // CHECK-SAME: ins(%[[MATERIALIZE]], %[[BROADCAST]] : tensor<8x128xf32>, tensor<8x128xf32>) // CHECK-SAME: outs(%[[MATERIALIZE_1]] : tensor<8x128xf32>) - // CHECK: ^bb0(%[[ARG3_1:.*]]: f32, %[[ARG4_1:.*]]: f32, %[[ARG5:.*]]: f32): - // CHECK: %[[SUBF:.*]] = arith.subf %[[ARG3_1]], %[[ARG4_1]] - // CHECK: linalg.yield %[[SUBF]] - // CHECK: %[[GENERIC_2:.*]] = linalg.generic - // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]]], - // CHECK-SAME: iterator_types = ["parallel", "parallel"]} - // CHECK-SAME: ins(%[[GENERIC_1]] : tensor<8x128xf32>) + // CHECK: %[[MAP_0:.*]] = linalg.map { math.exp } + // CHECK-SAME: ins(%[[MAP]] : tensor<8x128xf32>) // CHECK-SAME: 
outs(%[[MATERIALIZE_1]] : tensor<8x128xf32>) - // CHECK: ^bb0(%[[ARG3_2:.*]]: f32, %[[ARG4_2:.*]]: f32): - // CHECK: %[[EXP:.*]] = math.exp %[[ARG3_2]] - // CHECK: linalg.yield %[[EXP]] - // CHECK: %[[MATERIALIZE_3:.*]] = gml_st.materialize %[[FILL_0]][%[[TILE_0]]] - // CHECK: %[[GENERIC_3:.*]] = linalg.generic - // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]], - // CHECK-SAME: iterator_types = ["parallel", "reduction"]} - // CHECK-SAME: ins(%[[GENERIC_2]] : tensor<8x128xf32>) + // CHECK: %[[MATERIALIZE_3:.*]] = tensor.extract_slice %[[FILL_0]][%[[ARG1]]] [8] [1] + // CHECK: %[[REDUCE_0:.*]] = linalg.reduce { arith.addf } + // CHECK-SAME: ins(%[[MAP_0]] : tensor<8x128xf32>) // CHECK-SAME: outs(%[[MATERIALIZE_3]] : tensor<8xf32>) - // CHECK: ^bb0(%[[ARG3_3:.*]]: f32, %[[ARG4_3:.*]]: f32): - // CHECK: %[[ADDF:.*]] = arith.addf %[[ARG4_3]], %[[ARG3_3]] - // CHECK: linalg.yield %[[ADDF]] - // CHECK: %[[GENERIC_4:.*]] = linalg.generic - // CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP0]]], - // CHECK-SAME: iterator_types = ["parallel", "parallel"]} - // CHECK-SAME: ins(%[[GENERIC_3]] : tensor<8xf32>) + // CHECK-SAME: dimensions = [1] + // CHECK: %[[BROADCAST_0:.*]] = linalg.broadcast + // CHECK-SAME: ins(%[[REDUCE_0]] : tensor<8xf32>) // CHECK-SAME: outs(%[[MATERIALIZE_1]] : tensor<8x128xf32>) - // CHECK: ^bb0(%[[ARG3_4:.*]]: f32, %[[ARG4_4:.*]]: f32): - // CHECK: linalg.yield %[[ARG3_4]] - // CHECK: %[[GENERIC_5:.*]] = linalg.generic - // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP0]]], - // CHECK-SAME: iterator_types = ["parallel", "parallel"]} - // CHECK-SAME: ins(%[[GENERIC_2]], %[[GENERIC_4]] : tensor<8x128xf32>, tensor<8x128xf32>) - // CHECK-SAME: outs(%[[MATERIALIZE_1]] : tensor<8x128xf32>) - // CHECK: ^bb0(%[[ARG3_5:.*]]: f32, %[[ARG4_5:.*]]: f32, %[[ARG5_0:.*]]: f32): - // CHECK: %[[DIVF:.*]] = arith.divf %[[ARG3_5]], %[[ARG4_5]] - // CHECK: linalg.yield %[[DIVF]] - // CHECK: gml_st.set_yield %[[GENERIC_5]] into %[[INIT_0]][%[[TILE]]] + // CHECK-SAME: dimensions = [1] + // CHECK: %[[INIT_0_SUB:.*]] = tensor.extract_slice %[[INIT_0_]][%[[ARG1]], 0] [8, 128] [1, 1] + // CHECK: %[[MAP_1:.*]] = linalg.map { arith.divf } + // CHECK-SAME: ins(%[[MAP_0]], %[[BROADCAST_0]] : tensor<8x128xf32>, tensor<8x128xf32>) + // CHECK-SAME: outs(%[[INIT_0_SUB]] : tensor<8x128xf32>) + // CHECK: %[[TILE:.*]] = gml_st.tile [%[[ARG1]], 0] [8, 128] [1, 1] + // CHECK: gml_st.set_yield %[[MAP_1]] into %[[INIT_0_]][%[[TILE]]] // CHECK: return %[[PARALLEL]] %cst = arith.constant -0.000000e+00 : f32 %cst_0 = arith.constant 0xFF800000 : f32 %0 = tensor.empty() : tensor<64xf32> %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<64xf32>) -> tensor<64xf32> - %2 = linalg.generic {indexing_maps = [#map0, #map1], - iterator_types = ["parallel", "reduction"]} - ins(%arg0 : tensor<64x128xf32>) outs(%1 : tensor<64xf32>) { - ^bb0(%arg1: f32, %arg2: f32): - %11 = arith.maxf %arg2, %arg1 : f32 - linalg.yield %11 : f32 - } -> tensor<64xf32> + %2 = linalg.reduce { arith.maxf } + ins(%arg0 : tensor<64x128xf32>) + outs(%1 : tensor<64xf32>) dimensions = [1] %3 = tensor.empty() : tensor<64x128xf32> - %4 = linalg.generic {indexing_maps = [#map1, #map0], - iterator_types = ["parallel", "parallel"]} ins(%2 : tensor<64xf32>) - outs(%3 : tensor<64x128xf32>) { - ^bb0(%arg1: f32, %arg2: f32): - linalg.yield %arg1 : f32 - } -> tensor<64x128xf32> - %5 = linalg.generic {indexing_maps = [#map0, #map0, #map0], - iterator_types = ["parallel", "parallel"]} - ins(%arg0, %4 : tensor<64x128xf32>, tensor<64x128xf32>) - outs(%3 : 
tensor<64x128xf32>) { - ^bb0(%arg1: f32, %arg2: f32, %arg3: f32): - %11 = arith.subf %arg1, %arg2 : f32 - linalg.yield %11 : f32 - } -> tensor<64x128xf32> - %6 = linalg.generic {indexing_maps = [#map0, #map0], - iterator_types = ["parallel", "parallel"]} ins(%5 : tensor<64x128xf32>) - outs(%3 : tensor<64x128xf32>) { - ^bb0(%arg1: f32, %arg2: f32): - %11 = math.exp %arg1 : f32 - linalg.yield %11 : f32 - } -> tensor<64x128xf32> + %4 = linalg.broadcast + ins(%2 : tensor<64xf32>) + outs(%3 : tensor<64x128xf32>) + dimensions = [1] + %5 = linalg.map { arith.subf } + ins(%arg0, %4 : tensor<64x128xf32>, tensor<64x128xf32>) + outs(%3 : tensor<64x128xf32>) + %6 = linalg.map { math.exp } + ins(%5 : tensor<64x128xf32>) + outs(%3 : tensor<64x128xf32>) %7 = linalg.fill ins(%cst : f32) outs(%0 : tensor<64xf32>) -> tensor<64xf32> - %8 = linalg.generic {indexing_maps = [#map0, #map1], - iterator_types = ["parallel", "reduction"]} ins(%6 : tensor<64x128xf32>) - outs(%7 : tensor<64xf32>) { - ^bb0(%arg1: f32, %arg2: f32): - %11 = arith.addf %arg2, %arg1 : f32 - linalg.yield %11 : f32 - } -> tensor<64xf32> - %9 = linalg.generic {indexing_maps = [#map1, #map0], - iterator_types = ["parallel", "parallel"]} ins(%8 : tensor<64xf32>) - outs(%3 : tensor<64x128xf32>) { - ^bb0(%arg1: f32, %arg2: f32): - linalg.yield %arg1 : f32 - } -> tensor<64x128xf32> - %10 = linalg.generic {indexing_maps = [#map0, #map0, #map0], - iterator_types = ["parallel", "parallel"]} - ins(%6, %9 : tensor<64x128xf32>, tensor<64x128xf32>) - outs(%3 : tensor<64x128xf32>) { - ^bb0(%arg1: f32, %arg2: f32, %arg3: f32): - %11 = arith.divf %arg1, %arg2 : f32 - linalg.yield %11 : f32 - } -> tensor<64x128xf32> + %8 = linalg.reduce { arith.addf } + ins(%6 : tensor<64x128xf32>) + outs(%7 : tensor<64xf32>) + dimensions = [1] + %9 = linalg.broadcast + ins(%8 : tensor<64xf32>) + outs(%3 : tensor<64x128xf32>) + dimensions = [1] + %10 = linalg.map { arith.divf } + ins(%6, %9 : tensor<64x128xf32>, tensor<64x128xf32>) + outs(%3 : tensor<64x128xf32>) return %10 : tensor<64x128xf32> } diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/transform_map_for_cpu.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/transform_map_for_cpu.mlir deleted file mode 100644 index 1e19feab4c9..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/transform_map_for_cpu.mlir +++ /dev/null @@ -1,45 +0,0 @@ -// RUN: mlir-hlo-opt %s --gml-st-cpu-transform-map="tile-size=8" | FileCheck %s - -func.func @map_unary(%input: tensor, %init: tensor) - -> tensor { - %abs = linalg.map - ins(%input:tensor) - outs(%init:tensor) - (%input_elem: f32) { - %0 = math.absf %input_elem: f32 - linalg.yield %0: f32 - } - func.return %abs : tensor -} - -// CHECK-LABEL: func.func @map_unary( -// CHECK-SAME: %[[INPUT:.*]]: tensor, -// CHECK-SAME: %[[INIT:.*]]: tensor) - -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 -// CHECK-DAG: %[[C8:.*]] = arith.constant 8 - -// CHECK-DAG: %[[DIM_0:.*]] = tensor.dim %[[INPUT]], %[[C0]] -// CHECK-DAG: %[[DIM_1:.*]] = tensor.dim %[[INPUT]], %[[C1]] - -// CHECK-NEXT: %[[RESULT:.*]] = gml_st.parallel (%[[I:.*]], %[[J:.*]]) = -// CHECK-SAME: (%[[C0]], %[[C0]]) to (%[[DIM_0]], %[[DIM_1]]) -// CHECK-SAME: step (%[[C1]], %[[C8]]) { -// CHECK: %[[MIN_DIM:.*]] = affine.min #map(%[[J]])[%[[DIM_1]]] -// CHECK-NEXT: %[[INPUT_TILE:.*]] = gml_st.tile [%[[I]], %[[J]]] -// CHECK-SAME: [1, %[[MIN_DIM]]] [1, 1] -// CHECK-NEXT: %[[INPUT_SLICE:.*]] = gml_st.materialize %[[INPUT]] -// CHECK-NEXT: 
%[[INIT_TILE:.*]] = gml_st.tile [%[[I]], %[[J]]] -// CHECK-SAME: [1, %[[MIN_DIM]]] [1, 1] -// CHECK-NEXT: %[[INIT_SLICE:.*]] = gml_st.materialize %[[INIT]] -// CHECK-NEXT: %[[MAPPED:.*]] = linalg.map -// CHECK-NEXT: ins(%[[INPUT_SLICE]] : tensor<1x?xf32>) -// CHECK-NEXT: outs(%[[INIT_SLICE]] : tensor<1x?xf32>) -// CHECK-NEXT: (%[[IN_ELEM:.*]]: f32) { -// CHECK-NEXT: %[[RES_ELEM:.*]] = math.absf %[[IN_ELEM]] : f32 -// CHECK-NEXT: linalg.yield %[[RES_ELEM]] : f32 -// CHECK-NEXT: } -// CHECK-NEXT: gml_st.set_yield %[[MAPPED]] into %[[INIT]][%[[INIT_TILE]]] -// CHECK-NEXT: } -// CHECK-NEXT: return %[[RESULT]] diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/transform_matmul_for_cpu.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/transform_matmul_for_cpu.mlir deleted file mode 100644 index 3f4764ef5c1..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/transform_matmul_for_cpu.mlir +++ /dev/null @@ -1,126 +0,0 @@ -// RUN: mlir-hlo-opt %s -xla-cpu-transform-matmul="tile-sizes=8,4,2" | FileCheck %s - -#id_map = affine_map<(d0, d1) -> (d0, d1)> - -func.func @matmul_static(%arg0: tensor<128x16xf32>, %arg1: tensor<16x64xf32>, - %output: tensor<128x64xf32>) -> tensor<128x64xf32> { - %cst = arith.constant 0.000000e+00 : f32 - %2 = linalg.matmul ins(%arg0, %arg1 : tensor<128x16xf32>, tensor<16x64xf32>) - outs(%output : tensor<128x64xf32>) -> tensor<128x64xf32> - return %2 : tensor<128x64xf32> -} - -// CHECK-LABEL: func @matmul_static( -// CHECK-SAME: %[[LHS:.*]]: tensor<128x16xf32>, -// CHECK-SAME: %[[RHS:.*]]: tensor<16x64xf32>, -// CHECK-SAME: %[[OUT:.*]]: tensor<128x64xf32>) - -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[C_LHS_COL:.*]] = arith.constant 16 : index -// CHECK-DAG: %[[C_RHS_COL:.*]] = arith.constant 64 : index -// CHECK-DAG: %[[C_LHS_ROW:.*]] = arith.constant 128 : index - -// CHECK: gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) -// CHECK-SAME: to (%[[C_LHS_ROW]], %[[C_RHS_COL]]) step (%[[C8]], %[[C4]]) - -// CHECK: %[[LHS_TILE:.*]] = gml_st.tile [%[[I]], 0] [8, 16] -// CHECK: %[[LHS_SLICE:.*]] = gml_st.materialize %[[LHS]][%[[LHS_TILE]]] - -// CHECK: %[[RHS_TILE:.*]] = gml_st.tile [0, %[[J]]] [16, 4] -// CHECK: %[[RHS_SLICE:.*]] = gml_st.materialize %[[RHS]][%[[RHS_TILE]]] - -// CHECK: %[[OUT_TILE:.*]] = gml_st.tile [%[[I]], %[[J]]] [8, 4] -// CHECK: %[[OUT_SLICE:.*]] = gml_st.materialize %[[OUT]][%[[OUT_TILE]]] - -// CHECK: %[[FOR:.*]] = gml_st.for (%[[K:.*]]) = (%[[C0]]) -// CHECK-SAME: to (%[[C_LHS_COL]]) step (%[[C2]]) -// CHECK-SAME: outs (%[[OUT_SUB_ARG:.*]] = %[[OUT_SLICE]]: - -// CHECK: %[[LHS_TILE_2:.*]] = gml_st.tile [0, %[[K]]] [8, 2] -// CHECK: %[[LHS_SLICE_2:.*]] = gml_st.materialize %[[LHS_SLICE]][%[[LHS_TILE_2]]] - -// CHECK: %[[RHS_TILE_2:.*]] = gml_st.tile [%[[K]], 0] [2, 4] -// CHECK: %[[RHS_SLICE_2:.*]] = gml_st.materialize %[[RHS_SLICE]][%[[RHS_TILE_2]]] - -// CHECK: %[[OUT_TILE_2:.*]] = gml_st.tile [0, 0] [8, 4] -// CHECK: %[[OUT_SLICE_2:.*]] = gml_st.materialize %[[OUT_SUB_ARG]][%[[OUT_TILE_2]]] - -// CHECK: %[[MATMUL:.*]] = linalg.matmul -// CHECK-SAME: ins(%[[LHS_SLICE_2]], %[[RHS_SLICE_2]] : -// CHECK: outs(%[[OUT_SLICE_2]] : - -// CHECK-NEXT: gml_st.set_yield %[[MATMUL]] into %[[OUT_SUB_ARG]][%[[OUT_TILE_2]]] -// CHECK: gml_st.set_yield %[[FOR]] into %[[OUT]][%[[OUT_TILE]]] - -// ----- - -func.func @matmul(%arg0: 
tensor, %arg1: tensor) - -> tensor { - %c0 = arith.constant 0 : index - %0 = tensor.dim %arg0, %c0 : tensor - %c1 = arith.constant 1 : index - %1 = tensor.dim %arg1, %c1 : tensor - %2 = tensor.empty(%0, %1) : tensor - %cst = arith.constant 0.000000e+00 : f32 - %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor) -> tensor - %4 = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) - outs(%3 : tensor) -> tensor - return %4 : tensor -} - -// CHECK-LABEL: func @matmul( -// CHECK-SAME: %[[LHS:.*]]: tensor, %[[RHS:.*]]: tensor) - -// CHECK-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index - -// CHECK-DAG: %[[DIM_0:.*]] = tensor.dim %[[LHS]], %[[C0]] : [[TY_2D:.*]] -// CHECK-DAG: %[[DIM_1:.*]] = tensor.dim %[[RHS]], %[[C1]] : [[TY_2D]] -// CHECK: %[[INIT:.*]] = tensor.empty(%[[DIM_0]], %[[DIM_1]]) : [[TY_2D]] -// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[C0_F32]]{{.*}}outs(%[[INIT]] -// CHECK-DAG: %[[LHS_ROW:.*]] = tensor.dim %[[LHS]], %[[C0]] : [[TY_2D]] -// CHECK-DAG: %[[LHS_COL:.*]] = tensor.dim %[[LHS]], %[[C1]] : [[TY_2D]] -// CHECK-DAG: %[[RHS_COL:.*]] = tensor.dim %[[RHS]], %[[C1]] : [[TY_2D]] - -// CHECK: gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) -// CHECK-SAME: to (%[[LHS_ROW]], %[[RHS_COL]]) step (%[[C8]], %[[C4]]) - -// CHECK: %[[LHS_TILE:.*]] = gml_st.tile [%[[I]], 0] -// CHECK: %[[LHS_SLICE:.*]] = gml_st.materialize %[[LHS]][%[[LHS_TILE]]] - -// CHECK: %[[RHS_TILE:.*]] = gml_st.tile [0, %[[J]]] -// CHECK: %[[RHS_SLICE:.*]] = gml_st.materialize %[[RHS]][%[[RHS_TILE]]] - -// CHECK: %[[OUT_TILE:.*]] = gml_st.tile [%[[I]], %[[J]]] -// CHECK: %[[OUT_SLICE:.*]] = gml_st.materialize %[[FILL]][%[[OUT_TILE]]] - -// CHECK: %[[LHS_SUB_ROW:.*]] = tensor.dim %[[LHS_SLICE]], %[[C0]] : [[TY_2D]] -// CHECK: %[[LHS_SUB_COL:.*]] = tensor.dim %[[LHS_SLICE]], %[[C1]] : [[TY_2D]] -// CHECK: %[[RHS_SUB_COL:.*]] = tensor.dim %[[RHS_SLICE]], %[[C1]] : [[TY_2D]] - -// CHECK: %[[FOR:.*]] = gml_st.for (%[[K:.*]]) = (%[[C0]]) -// CHECK-SAME: to (%[[LHS_SUB_COL]]) step (%[[C2]]) -// CHECK-SAME: outs (%[[OUT_SUB_ARG:.*]] = %[[OUT_SLICE]]: [[TY_2D]]) - -// CHECK: %[[LHS_TILE_2:.*]] = gml_st.tile [0, %[[K]]] -// CHECK: %[[LHS_SLICE_2:.*]] = gml_st.materialize %[[LHS_SLICE]][%[[LHS_TILE_2]]] - -// CHECK: %[[RHS_TILE_2:.*]] = gml_st.tile [%[[K]], 0] -// CHECK: %[[RHS_SLICE_2:.*]] = gml_st.materialize %[[RHS_SLICE]][%[[RHS_TILE_2]]] - -// CHECK: %[[OUT_TILE_2:.*]] = gml_st.tile [0, 0] [%[[LHS_SUB_ROW]], %[[RHS_SUB_COL]]] -// CHECK: %[[OUT_SLICE_2:.*]] = gml_st.materialize %[[OUT_SUB_ARG]][%[[OUT_TILE_2]]] - -// CHECK: %[[MATMUL:.*]] = linalg.matmul -// CHECK-SAME: ins(%[[LHS_SLICE_2]], %[[RHS_SLICE_2]] : [[TY_2D]], [[TY_2D]]) -// CHECK: outs(%[[OUT_SLICE_2]] : [[TY_2D]]) - -// CHECK-NEXT: gml_st.set_yield %[[MATMUL]] into %[[OUT_SUB_ARG]][%[[OUT_TILE_2]]] -// CHECK: gml_st.set_yield %[[FOR]] into %[[FILL]][%[[OUT_TILE]]] diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/transform_transpose_for_cpu.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/transform_transpose_for_cpu.mlir deleted file mode 100644 index 6d4375f055a..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/transform_transpose_for_cpu.mlir +++ /dev/null @@ -1,26 +0,0 @@ -// RUN: mlir-hlo-opt %s --gml-st-cpu-transform-transpose 
| FileCheck %s - -func.func @transpose_permutation(%input: tensor<16x32x64xf32>, - %init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> { - %transpose = linalg.transpose - ins(%input:tensor<16x32x64xf32>) - outs(%init:tensor<32x64x16xf32>) - permutation = [1, 2, 0] - func.return %transpose : tensor<32x64x16xf32> -} - -// CHECK-LABEL: func.func @transpose_permutation( -// CHECK-SAME: %[[INPUT:.*]]: tensor<16x32x64xf32>, -// CHECK-SAME: %[[INIT:.*]]: tensor<32x64x16xf32>) - -// CHECK: gml_st.parallel -// CHECK: %[[INPUT_SUB:.*]] = gml_st.materialize %[[INPUT]] -// CHECK: : tensor<16x32x64xf32>[!gml_st.tile<8x1x8>] to tensor<8x1x8xf32> - -// CHECK: %[[INIT_SUB:.*]] = gml_st.materialize %[[INIT]] -// CHECK: : tensor<32x64x16xf32>[!gml_st.tile<1x8x8>] to tensor<1x8x8xf32> - -// CHECK: linalg.transpose -// CHECK-NEXT: ins(%[[INPUT_SUB]] : tensor<8x1x8xf32>) -// CHECK-NEXT: outs(%[[INIT_SUB]] : tensor<1x8x8xf32>) -// CHECK-NEXT: permutation = [1, 2, 0] \ No newline at end of file diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/triton_tiling/transform_matmul_for_triton.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/triton_tiling/transform_matmul_for_triton.mlir new file mode 100644 index 00000000000..b07a8e58bd7 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/triton_tiling/transform_matmul_for_triton.mlir @@ -0,0 +1,109 @@ +// RUN: mlir-hlo-opt %s --split-input-file \ +// RUN: -xla-triton-transform-matmul="tile-sizes=8,4,2 distribution-label=test" \ +// RUN: | FileCheck %s --dump-input=always + +func.func @matmul_static(%arg0: tensor<128x16xf32>, %arg1: tensor<16x64xf32>, + %output: tensor<128x64xf32>) -> tensor<128x64xf32> { + %2 = linalg.matmul ins(%arg0, %arg1 : tensor<128x16xf32>, tensor<16x64xf32>) + outs(%output : tensor<128x64xf32>) -> tensor<128x64xf32> + return %2 : tensor<128x64xf32> +} + +// CHECK-LABEL: func @matmul_static( +// CHECK-SAME: %[[LHS:.*]]: tensor<128x16xf32>, +// CHECK-SAME: %[[RHS:.*]]: tensor<16x64xf32>, +// CHECK-SAME: %[[OUT:.*]]: tensor<128x64xf32>) + +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) +// CHECK-SAME: distribution ("test") +// CHECK: %[[FOR:.*]] = gml_st.for (%[[K:.*]]) = (%[[C0]]) +// CHECK: %[[MATMUL:.*]] = linalg.matmul +// CHECK-SAME: -> tensor<8x4xf32> +// CHECK: gml_st.set_yield %[[MATMUL]] +// CHECK: gml_st.set_yield %[[FOR]] + +// ----- + +func.func @matmul_fuse_output(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, + %arg2: tensor<?x?xf32>) + -> tensor<?x?xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %dim0 = tensor.dim %arg0, %c0 : tensor<?x?xf32> + %dim1 = tensor.dim %arg1, %c1 : tensor<?x?xf32> + %init = tensor.empty(%dim0, %dim1) : tensor<?x?xf32> + %cst = arith.constant 0.000000e+00 : f32 + %filled = linalg.fill ins(%cst : f32) + outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32> + %4 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) + outs(%filled : tensor<?x?xf32>) -> tensor<?x?xf32> + %5 = linalg.matmul ins(%arg0, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>) + outs(%filled : tensor<?x?xf32>) -> tensor<?x?xf32> + %6 = linalg.map { math.absf } + ins(%5 : tensor<?x?xf32>) + outs(%init : tensor<?x?xf32>) + + %result = linalg.map { arith.addf } + ins(%4, %6 : tensor<?x?xf32>, tensor<?x?xf32>) + outs(%init : tensor<?x?xf32>) + return %result : tensor<?x?xf32> +} + +// CHECK-LABEL: func @matmul_fuse_output( +// CHECK: %[[C0:.*]] = arith.constant 0 : index + +// CHECK: gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) +// CHECK: gml_st.for (%[[K:.*]]) = (%[[C0]]) +// CHECK: %[[MATMUL:.*]] = linalg.matmul +// CHECK: gml_st.set_yield %[[MATMUL]] + +// CHECK: gml_st.for 
+// CHECK: %[[MATMUL:.*]] = linalg.matmul +// CHECK: gml_st.set_yield %[[MATMUL]] + +// CHECK: linalg.map +// CHECK: linalg.map + +// CHECK: gml_st.set_yield + +// ----- + +func.func @matmul_fuse_input_and_output( + %arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>, + %init: tensor<?x?xf32>) + -> tensor<?x?xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %cst = arith.constant 0.000000e+00 : f32 + %filled = linalg.fill ins(%cst : f32) + outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32> + %mapped = linalg.map { math.absf } + ins(%arg0 : tensor<?x?xf32>) + outs(%init : tensor<?x?xf32>) + %bcast = linalg.broadcast + ins(%arg1 : tensor<?xf32>) + outs(%init : tensor<?x?xf32>) + dimensions = [1] + + %matmul = linalg.matmul + ins(%mapped, %bcast : tensor<?x?xf32>, tensor<?x?xf32>) + outs(%filled : tensor<?x?xf32>) -> tensor<?x?xf32> + + %result = linalg.map { math.absf } + ins(%matmul : tensor<?x?xf32>) + outs(%init : tensor<?x?xf32>) + return %result : tensor<?x?xf32> +} + +// CHECK-LABEL: func @matmul_fuse_input_and_output( +// CHECK: %[[C0:.*]] = arith.constant 0 : index + +// CHECK: gml_st.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) +// CHECK: gml_st.for (%[[K:.*]]) = (%[[C0]]) +// CHECK: linalg.map +// CHECK: linalg.broadcast +// CHECK: %[[MATMUL:.*]] = linalg.matmul +// CHECK: gml_st.set_yield %[[MATMUL]] +// CHECK: linalg.map +// CHECK: gml_st.set_yield diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/vectorization.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/vectorization.mlir deleted file mode 100644 index a7f9fbc6a17..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/vectorization.mlir +++ /dev/null @@ -1,295 +0,0 @@ -// RUN: mlir-hlo-opt %s --vectorize-gml-st-loops --split-input-file |\ -// RUN: FileCheck %s - -#map0 = affine_map<(d0) -> (d0)> -func.func @tiled_add(%A: tensor<8xf32>, %B: tensor<8xf32>, - %C: tensor<8xf32>) -> tensor<8xf32> { - %c0 = arith.constant 0 : index - %c2 = arith.constant 2 : index - %c8 = arith.constant 8 : index - %sum = gml_st.loop (%i) = (%c0) to (%c8) step (%c2) - ins (%A_ = %A: tensor<8xf32>, %B_ = %B: tensor<8xf32>) - outs (%C_ = %C: tensor<8xf32>) { - %A_sub = tensor.extract_slice %A_[%i] [2] [1] - : tensor<8xf32> to tensor<2xf32> - %B_sub = tensor.extract_slice %B_[%i] [2] [1] - : tensor<8xf32> to tensor<2xf32> - %C_sub = tensor.extract_slice %C_[%i] [2] [1] - : tensor<8xf32> to tensor<2xf32> - %sum_sub = linalg.generic { - indexing_maps = [#map0, #map0, #map0], - iterator_types = ["parallel"] - } ins(%A_sub, %B_sub : tensor<2xf32>, tensor<2xf32>) - outs(%C_sub : tensor<2xf32>) { - ^bb0(%a: f32, %b: f32, %c: f32): - %0 = arith.addf %a, %b : f32 - linalg.yield %0 : f32 - } -> tensor<2xf32> - %update = tensor.insert_slice %sum_sub into %C_[%i] [2] [1] - : tensor<2xf32> into tensor<8xf32> - gml_st.yield %update : tensor<8xf32> - } - func.return %sum : tensor<8xf32> -} -// CHECK-LABEL: func @tiled_add - -// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index - -// CHECK: gml_st.loop (%[[IV:arg[0-9]]]) = -// CHECK-SAME: ins (%[[A:arg[0-9]]] = %{{arg[0-9]}}: tensor<8xf32>, -// CHECK-SAME: %[[B:arg[0-9]]] = %{{arg[0-9]}}: tensor<8xf32> -// CHECK-SAME: outs (%[[C:arg[0-9]]] = %{{arg[0-9]}}: tensor<8xf32>) - -// CHECK-NEXT: %[[LHS:.*]] = vector.transfer_read %[[A]][%[[IV]]], %[[CST]] -// CHECK-SAME: {in_bounds = [true]} : tensor<8xf32>, vector<2xf32> -// CHECK-NEXT: %[[RHS:.*]] = vector.transfer_read %[[B]][%[[IV]]], %[[CST]] -// CHECK-SAME: {in_bounds = [true]} : tensor<8xf32>, vector<2xf32> - -// CHECK-NEXT: %[[SUM:.*]] = arith.addf %[[LHS]], %[[RHS]] : 
vector<2xf32> - -// CHECK-NEXT: %{{.*}} = vector.transfer_write %[[SUM]], %[[C]][%[[IV]]] -// CHECK-SAME: {in_bounds = [true]} : vector<2xf32>, tensor<8xf32> - -// ----- - -func.func @tiled_reduction_2d(%in: tensor<80x60xf32>) -> tensor<80xf32> { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c60 = arith.constant 60 : index - %c80 = arith.constant 80 : index - %cst = arith.constant 0.000000e+00 : f32 - - %init = tensor.empty() : tensor<80xf32> - %out = linalg.fill ins(%cst : f32) outs(%init : tensor<80xf32>) -> tensor<80xf32> - - %sum = gml_st.loop (%i, %j) = (%c0, %c0) to (%c80, %c60) step (%c4, %c4) - ins (%in_ = %in: tensor<80x60xf32>, %cst_ = %cst: f32) - outs (%out_ = %out: tensor<80xf32>) - iterators[#gml_st.iterator_type, - #gml_st.iterator_type] { - %in_sub = tensor.extract_slice %in_[%i, %j] [4, 4] [1, 1] - : tensor<80x60xf32> to tensor<4x4xf32> - %out_sub = tensor.extract_slice %out_[%i] [4] [1] - : tensor<80xf32> to tensor<4xf32> - %local_fill = linalg.fill ins(%cst_ : f32) outs(%out_sub : tensor<4xf32>) -> tensor<4xf32> - %reduced_tile = linalg.generic { - indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0)>], - iterator_types = ["parallel", "reduction"]} - ins(%in_sub : tensor<4x4xf32>) - outs(%local_fill : tensor<4xf32>) { - ^bb0(%a: f32, %b: f32): - %0 = arith.addf %a, %b : f32 - linalg.yield %0 : f32 - } -> tensor<4xf32> - - %acc = linalg.generic { - indexing_maps = [affine_map<(d0) -> (d0)>, - affine_map<(d0) -> (d0)>], - iterator_types = ["parallel"]} - ins(%reduced_tile : tensor<4xf32>) - outs(%out_sub : tensor<4xf32>) { - ^bb0(%a: f32, %b: f32): - %1 = arith.addf %a, %b : f32 - linalg.yield %1 : f32 - } -> tensor<4xf32> - %update = tensor.insert_slice %acc into %out_[%i] [4] [1] - : tensor<4xf32> into tensor<80xf32> - gml_st.yield %update : tensor<80xf32> - } - func.return %sum : tensor<80xf32> -} - -// CHECK-LABEL: func @tiled_reduction_2d - -// CHECK: gml_st.loop -// CHECK-SAME: ins (%{{arg[0-9]}} = %{{arg[0-9]}}: tensor<80x60xf32>, -// CHECK-SAME: %[[CST:arg[0-9]]] = %{{.*}}: f32 - -// CHECK: %[[BCAST:.*]] = vector.broadcast %[[CST]] : f32 to vector<4xf32> -// CHECK-NOT: vector.transfer_write %[[BCAST]] -// CHECK: vector.multi_reduction , %{{.*}}, %[[BCAST]] [1] : vector<4x4xf32> to vector<4xf32> - -// ----- - -#map0 = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> (d1)> -#map2 = affine_map<(d0) -> (d0)> -#map3 = affine_map<(d0) -> ()> -func.func @reduction_1d(%arg0: tensor<16xf32>) -> tensor { - %cst = arith.constant 0.000000e+00 : f32 - %c16 = arith.constant 16 : index - %c0 = arith.constant 0 : index - %c8 = arith.constant 8 : index - %0 = tensor.empty() : tensor - %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor) -> tensor - %2 = tensor.empty() : tensor<8xf32> - %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<8xf32>) -> tensor<8xf32> - %4 = gml_st.loop (%arg1) = (%c0) to (%c16) step (%c8) - ins (%arg2 = %arg0: tensor<16xf32>) - outs (%arg3 = %3: tensor<8xf32>) - iterators[#gml_st.iterator_type] { - %6 = tensor.extract_slice %arg2[%arg1] [8] [1] - : tensor<16xf32> to tensor<8xf32> - %7 = tensor.expand_shape %6 [[0, 1]] - : tensor<8xf32> into tensor<1x8xf32> - %8 = linalg.generic {indexing_maps = [#map0, #map1], - iterator_types = ["reduction", "parallel"]} - ins(%7 : tensor<1x8xf32>) - outs(%arg3 : tensor<8xf32>) { - ^bb0(%arg4: f32, %arg5: f32): - %9 = arith.addf %arg4, %arg5 : f32 - linalg.yield %9 : f32 - } -> tensor<8xf32> - gml_st.yield %8 : tensor<8xf32> - } - %5 = linalg.generic 
{indexing_maps = [#map2, #map3], - iterator_types = ["reduction"]} - ins(%4 : tensor<8xf32>) - outs(%1 : tensor) { - ^bb0(%arg1: f32, %arg2: f32): - %6 = arith.addf %arg1, %arg2 : f32 - linalg.yield %6 : f32 - } -> tensor - func.return %5 : tensor -} -// CHECK-LABEL: func @reduction_1d - -// CHECK: gml_st.loop -// CHECK-SAME: ins (%[[IN:arg[0-9]]] = %{{arg[0-9]}}: tensor<16xf32>) - -// CHECK: %[[VECTOR:.*]] = vector.transfer_read %[[IN]] -// CHECK: vector.shape_cast %[[VECTOR]] : vector<8xf32> to vector<1x8xf32> -// CHECK-NOT: tensor.expand_shape -// CHECK: vector.multi_reduction - -// ----- - -#map0 = affine_map<(d0, d1) -> (d0, d1)> -func.func @test_transfer_read_of_one_dim_expand_shape( - %in: tensor<10xf32>) -> tensor<5xf32> { - %c0 = arith.constant 0 : index - %min_float = arith.constant dense<-3.402820e+38> : vector<5xf32> - %zero_float = arith.constant 0.000000e+00 : f32 - %0 = tensor.expand_shape %in [[0, 1]] : tensor<10xf32> into tensor<2x5xf32> - %1 = tensor.empty() : tensor<5xf32> - %2 = vector.transfer_read %0[%c0, %c0], %zero_float - {in_bounds = [true, true], permutation_map = #map0} - : tensor<2x5xf32>, vector<2x5xf32> - %3 = vector.multi_reduction , %2, %min_float [0] - : vector<2x5xf32> to vector<5xf32> - %4 = vector.transfer_write %3, %1[%c0] {in_bounds = [true]} - : vector<5xf32>, tensor<5xf32> - func.return %4 : tensor<5xf32> -} -// CHECK-LABEL: func @test_transfer_read_of_one_dim_expand_shape( -// CHECK-SAME: %[[IN:.*]]: tensor<10xf32> -// CHECK-DAG: %[[MIN_FLOAT:.*]] = arith.constant dense<-3.402820e+38> : vector<5xf32> -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[ZERO_FLOAT:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[INIT_TENSOR:.*]] = tensor.empty() : tensor<5xf32> -// CHECK: %[[TRANSFER_READ:.*]] = vector.transfer_read %[[IN]][%[[C0]]], %[[ZERO_FLOAT]] {in_bounds = [true]} : tensor<10xf32>, vector<10xf32> -// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[TRANSFER_READ]] : vector<10xf32> to vector<2x5xf32> -// CHECK: %[[MULTI_REDUCTION:.*]] = vector.multi_reduction , %[[SHAPE_CAST]], %[[MIN_FLOAT]] [0] : vector<2x5xf32> to vector<5xf32> -// CHECK: %[[TRANSFER_WRITE:.*]] = vector.transfer_write %[[MULTI_REDUCTION]], %[[INIT_TENSOR]][%[[C0]]] {in_bounds = [true]} : vector<5xf32>, tensor<5xf32> -// CHECK: return %[[TRANSFER_WRITE]] : tensor<5xf32> - -// ----- - -func.func @tiled_matmul(%arg0: tensor<128x16xf32>, %arg1: tensor<16x64xf32>, - %arg2: tensor<128x64xf32>) -> tensor<128x64xf32> { - %c2 = arith.constant 2 : index - %c16 = arith.constant 16 : index - %c8 = arith.constant 8 : index - %c4 = arith.constant 4 : index - %c0 = arith.constant 0 : index - %c128 = arith.constant 128 : index - %c64 = arith.constant 64 : index - %0 = gml_st.parallel (%arg3, %arg4) = - (%c0, %c0) to (%c128, %c64) step (%c8, %c4) { - %1 = gml_st.tile [%arg3, 0] [8, 16] [1, 1] : !gml_st.tile<8x16> - %2 = gml_st.materialize %arg0[%1] : - tensor<128x16xf32>[!gml_st.tile<8x16>] to tensor<8x16xf32> - %3 = gml_st.tile [0, %arg4] [16, 4] [1, 1] : !gml_st.tile<16x4> - %4 = gml_st.materialize %arg1[%3] : - tensor<16x64xf32>[!gml_st.tile<16x4>] to tensor<16x4xf32> - %5 = gml_st.tile [%arg3, %arg4] [8, 4] [1, 1] : !gml_st.tile<8x4> - %6 = gml_st.materialize %arg2[%5] : - tensor<128x64xf32>[!gml_st.tile<8x4>] to tensor<8x4xf32> - %7 = gml_st.for (%arg5) = - (%c0) to (%c16) step (%c2) outs (%arg6 = %6: tensor<8x4xf32>) { - %8 = gml_st.tile [0, %arg5] [8, 2] [1, 1] : !gml_st.tile<8x2> - %9 = gml_st.materialize %2[%8] : - tensor<8x16xf32>[!gml_st.tile<8x2>] to 
tensor<8x2xf32> - %10 = gml_st.tile [%arg5, 0] [2, 4] [1, 1] : !gml_st.tile<2x4> - %11 = gml_st.materialize %4[%10] : - tensor<16x4xf32>[!gml_st.tile<2x4>] to tensor<2x4xf32> - %12 = gml_st.tile [0, 0] [8, 4] [1, 1] : !gml_st.tile<8x4> - %13 = gml_st.materialize %arg6[%12] : - tensor<8x4xf32>[!gml_st.tile<8x4>] to tensor<8x4xf32> - %14 = linalg.matmul ins(%9, %11 : tensor<8x2xf32>, tensor<2x4xf32>) - outs(%13 : tensor<8x4xf32>) -> tensor<8x4xf32> - gml_st.set_yield %14 into %arg6[%12] : - tensor<8x4xf32> into tensor<8x4xf32>[!gml_st.tile<8x4>] - } : tensor<8x4xf32> - gml_st.set_yield %7 into %arg2[%5] : - tensor<8x4xf32> into tensor<128x64xf32>[!gml_st.tile<8x4>] - } : tensor<128x64xf32> - return %0 : tensor<128x64xf32> -} - -// CHECK-LABEL: func @tiled_matmul - -// CHECK: gml_st.for - -// CHECK: %[[LHS:.*]] = vector.transfer_read {{.*}} : tensor<8x2xf32>, vector<8x2xf32> -// CHECK: %[[RHS:.*]] = vector.transfer_read {{.*}} : tensor<2x4xf32>, vector<2x4xf32> -// CHECK: %[[OUT:.*]] = vector.transfer_read {{.*}} : tensor<8x4xf32>, vector<8x4xf32> -// CHECK: vector.contract {{{.*}}} %[[LHS]], %[[RHS]], %[[OUT]] - -// CHECK-NOT: linalg.matmul - -// ----- - -#map0 = affine_map<(d0, d1) -> (d0, 0)> -func.func @test_transfer_read_of_one_dim_expand_shape_different_shape( - %in: tensor<1xf32>) -> tensor<18xf32> { - %c0 = arith.constant 0 : index - %min_float = arith.constant dense<-3.402820e+38> : vector<18xf32> - %zero_float = arith.constant 0.000000e+00 : f32 - %0 = tensor.expand_shape %in [[0, 1]] : tensor<1xf32> into tensor<1x1xf32> - %1 = tensor.empty() : tensor<18xf32> - %2 = vector.transfer_read %0[%c0, %c0], %zero_float - {in_bounds = [true, true], permutation_map = #map0} - : tensor<1x1xf32>, vector<1x18xf32> - %3 = vector.multi_reduction , %2, %min_float [0] - : vector<1x18xf32> to vector<18xf32> - %4 = vector.transfer_write %3, %1[%c0] {in_bounds = [true]} - : vector<18xf32>, tensor<18xf32> - func.return %4 : tensor<18xf32> -} -// CHECK-LABEL: func @test_transfer_read_of_one_dim_expand_shape_different_shape -// CHECK: %{{.*}} = tensor.expand_shape - -// ----- - -func.func @do_not_vectorize_large_untiled_fill() -> tensor<2x1000xf32> { - %cst = arith.constant 0.000000e+00 : f32 - %init = tensor.empty() : tensor<2x1000xf32> - %out = linalg.fill ins(%cst : f32) outs(%init : tensor<2x1000xf32>) -> tensor<2x1000xf32> - func.return %out : tensor<2x1000xf32> -} -// CHECK-LABEL: func @do_not_vectorize_large_untiled_fill -// CHECK: linalg.fill - -// ----- - -func.func @vectorize_small_untiled_fill() -> tensor<128xf32> { - %cst = arith.constant 0.000000e+00 : f32 - %init = tensor.empty() : tensor<128xf32> - %out = linalg.fill ins(%cst : f32) outs(%init : tensor<128xf32>) -> tensor<128xf32> - func.return %out : tensor<128xf32> -} -// CHECK-LABEL: func @vectorize_small_untiled_fill -// CHECK: vector.transfer_write diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/vectorize.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/vectorize.mlir deleted file mode 100644 index 9654093035f..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/vectorize.mlir +++ /dev/null @@ -1,213 +0,0 @@ -// Test vectorization of gml_st.parallel and gml_st.for loops. 
-// RUN: mlir-hlo-opt %s --split-input-file --vectorize-gml-st-loops \ -// RUN: | FileCheck %s - -#map0 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> -#map1 = affine_map<(d0, d1) -> (d0, d1)> - -// CHECK-LABEL: @parallel_with_tiles( -func.func @parallel_with_tiles( - %arg0: memref, %arg1: memref, %arg2: memref) - -> memref { - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %0 = memref.dim %arg0, %c0 : memref - %1 = memref.dim %arg0, %c1 : memref - gml_st.parallel (%arg3, %arg4) = (%c0, %c0) to (%0, %1) step (%c4, %c1) { - %6 = memref.subview %arg2[%arg3, %arg4] [4, 1] [1, 1] - : memref to memref<4x1xf32, #map0> - %7 = memref.subview %arg1[%arg3, %arg4] [4, 1] [1, 1] - : memref to memref<4x1xf32, #map0> - %8 = memref.subview %arg0[%arg3, %arg4] [4, 1] [1, 1] - : memref to memref<4x1xf32, #map0> - linalg.generic {indexing_maps = [#map1, #map1, #map1], - iterator_types = ["parallel", "parallel"]} - ins(%8, %7 : memref<4x1xf32, #map0>, memref<4x1xf32, #map0>) - outs(%6 : memref<4x1xf32, #map0>) { - ^bb0(%arg5: f32, %arg6: f32, %arg7: f32): - %9 = arith.addf %arg5, %arg6 : f32 - linalg.yield %9 : f32 - } - gml_st.set_yield - } - func.return %arg2 : memref -} -// CHECK-NOT: linalg.generic -// CHECK: %[[LHS:.*]] = vector.transfer_read {{%.*}}[%c0, %c0] -// CHECK: %[[RHS:.*]] = vector.transfer_read {{%.*}}[%c0, %c0] -// CHECK: %[[ADD:.*]] = arith.addf %[[LHS]], %[[RHS]] : vector<4x1xf32> -// CHECK: vector.transfer_write %[[ADD]], {{%.*}}[%c0, %c0] - -// ----- - -#map0 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> -#map1 = affine_map<(d0, d1) -> (d0, d1)> - -// CHECK-LABEL: @for_with_tiles( -func.func @for_with_tiles( - %arg0: memref, %arg1: memref, %arg2: memref) - -> memref { - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %0 = memref.dim %arg0, %c0 : memref - %1 = memref.dim %arg0, %c1 : memref - gml_st.for (%arg3, %arg4) = (%c0, %c0) to (%0, %1) step (%c4, %c1) { - %6 = memref.subview %arg2[%arg3, %arg4] [4, 1] [1, 1] - : memref to memref<4x1xf32, #map0> - %7 = memref.subview %arg1[%arg3, %arg4] [4, 1] [1, 1] - : memref to memref<4x1xf32, #map0> - %8 = memref.subview %arg0[%arg3, %arg4] [4, 1] [1, 1] - : memref to memref<4x1xf32, #map0> - linalg.generic {indexing_maps = [#map1, #map1, #map1], - iterator_types = ["parallel", "parallel"]} - ins(%8, %7 : memref<4x1xf32, #map0>, memref<4x1xf32, #map0>) - outs(%6 : memref<4x1xf32, #map0>) { - ^bb0(%arg5: f32, %arg6: f32, %arg7: f32): - %9 = arith.addf %arg5, %arg6 : f32 - linalg.yield %9 : f32 - } - gml_st.set_yield - } - func.return %arg2 : memref -} -// CHECK-NOT: linalg.generic -// CHECK: %[[LHS:.*]] = vector.transfer_read {{%.*}}[%c0, %c0] -// CHECK: %[[RHS:.*]] = vector.transfer_read {{%.*}}[%c0, %c0] -// CHECK: %[[ADD:.*]] = arith.addf %[[LHS]], %[[RHS]] : vector<4x1xf32> -// CHECK: vector.transfer_write %[[ADD]], {{%.*}}[%c0, %c0] - -// ----- - -#map3 = affine_map<(d0) -> (d0)> - -// CHECK-LABEL: @parallel_on_tensor( -func.func @parallel_on_tensor( - %arg0: tensor, %arg1: tensor, %arg2: tensor) - -> tensor { - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %0 = tensor.dim %arg0, %c0 : tensor - %2 = gml_st.parallel (%i) = (%c0) to (%0) step (%c4) { - %tile = gml_st.tile [%i] [4] [1] : !gml_st.tile<4> - %6 = gml_st.materialize %arg0[%tile] - : tensor[!gml_st.tile<4>] to tensor<4xf32> - %7 = gml_st.materialize %arg1[%tile] - : tensor[!gml_st.tile<4>] to tensor<4xf32> - %8 = 
gml_st.materialize %arg2[%tile] - : tensor[!gml_st.tile<4>] to tensor<4xf32> - %9 = linalg.generic {indexing_maps = [#map3, #map3, #map3], - iterator_types = ["parallel"]} - ins(%6, %7 : tensor<4xf32>, tensor<4xf32>) - outs(%8 : tensor<4xf32>) { - ^bb0(%arg5: f32, %arg6: f32, %arg7: f32): - %10 = arith.addf %arg5, %arg6 : f32 - linalg.yield %10 : f32 - } -> tensor<4xf32> - gml_st.set_yield %9 into %arg2[%tile] - : tensor<4xf32> into tensor[!gml_st.tile<4>] - } : tensor - func.return %2 : tensor -} -// CHECK-NOT: linalg.generic -// CHECK: %[[LHS:.*]] = vector.transfer_read {{%.*}}[%c0] -// CHECK: %[[RHS:.*]] = vector.transfer_read {{%.*}}[%c0] -// CHECK: %[[ADD:.*]] = arith.addf %[[LHS]], %[[RHS]] : vector<4xf32> -// CHECK: vector.transfer_write %[[ADD]], {{%.*}}[%c0] - -// ----- - -// CHECK-LABEL: @single_element_tensor_to_element( -// CHECK-SAME: %[[IN:.*]]: vector<1xf32> -func.func @single_element_tensor_to_element(%in : vector<1xf32>) -> f32 { - %c0 = arith.constant 0 : index - %pad = arith.constant 0.0 : f32 - %empty = tensor.empty() : tensor<1xf32> - %r = vector.transfer_write %in, %empty[%c0] {in_bounds = [true]} - : vector<1xf32>, tensor<1xf32> - %v = tensor.extract %r[%c0] : tensor<1xf32> - return %v : f32 -} -// CHECK: %[[RESULT:.*]] = vector.extract %[[IN]][0] -// CHECK: return %[[RESULT]] - -// ----- - -// CHECK-LABEL: @zero_dim_element_tensor_to_element( -// CHECK-SAME: %[[IN:.*]]: vector -func.func @zero_dim_element_tensor_to_element(%in : vector) -> f32 { - %pad = arith.constant 0.0 : f32 - %empty = tensor.empty() : tensor - %r = vector.transfer_write %in, %empty[] {in_bounds = []} - : vector, tensor - %v = tensor.extract %r[] : tensor - return %v : f32 -} -// CHECK: %[[RESULT:.*]] = vector.extractelement %[[IN]][] -// CHECK: return %[[RESULT]] - -// ----- - -// CHECK-LABEL: @read_of_empty_float_to_constant( -func.func @read_of_empty_float_to_constant(%pad : f32) -> vector<32xf32> { - %empty = tensor.empty() : tensor<32xf32> - %c0 = arith.constant 0 : index - %r = vector.transfer_read %empty[%c0], %pad {in_bounds = [true]} - : tensor<32xf32>, vector<32xf32> - return %r : vector<32xf32> -} -// CHECK: %[[RESULT:.*]] = arith.constant dense<0x7FC00000> : vector<32xf32> -// CHECK: return %[[RESULT]] - -// ----- - -// CHECK-LABEL: @read_of_empty_int_to_constant( -func.func @read_of_empty_int_to_constant(%pad : i8) -> vector<32xi8> { - %empty = tensor.empty() : tensor<32xi8> - %c0 = arith.constant 0 : index - %r = vector.transfer_read %empty[%c0], %pad {in_bounds = [true]} - : tensor<32xi8>, vector<32xi8> - return %r : vector<32xi8> -} -// CHECK: %[[RESULT:.*]] = arith.constant dense<0> : vector<32xi8> -// CHECK: return %[[RESULT]] -// ----- - -// CHECK-LABEL: @materialize_scalar_from_0D_vector( -// CHECK-SAME: %[[V:.*]]: vector -func.func @materialize_scalar_from_0D_vector(%v : vector) -> f32 { - %tile = gml_st.tile [] [] [] : !gml_st.tile<> - %r = gml_st.materialize %v[%tile] : vector[!gml_st.tile<>] to f32 - return %r : f32 -} -// CHECK: %[[R:.*]] = vector.extractelement %[[V]][] -// CHECK: return %[[R]] - -// ----- - -// CHECK-LABEL: @materialize_scalar_from_single_element_vector( -// CHECK-SAME: %[[V:.*]]: vector<1x1xf32> -func.func @materialize_scalar_from_single_element_vector( - %v : vector<1x1xf32>) -> f32 { - %tile = gml_st.tile [0, 0] [1, 1] [1, 1] : !gml_st.tile<1x1> - %r = gml_st.materialize %v[%tile] : vector<1x1xf32>[!gml_st.tile<1x1>] to f32 - return %r : f32 -} -// CHECK: %[[R:.*]] = vector.extract %[[V]][0, 0] -// CHECK: return %[[R]] - - -// ----- - -// 
CHECK-LABEL: @set_yield_scalar_into_vector( -// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<1x1xf32>) -func.func @set_yield_scalar_into_vector( - %f: f32, %v: vector<1x1xf32>) { - %tile = gml_st.tile [0, 0] [1, 1] [1, 1] : !gml_st.tile<1x1> - gml_st.set_yield %f into %v[%tile] - : f32 into vector<1x1xf32>[!gml_st.tile<1x1>] -} -// CHECK: %[[R:.*]] = vector.insert %[[F]], %[[V]] [0, 0] -// CHECK: gml_st.set_yield %[[R]] into %[[V]] diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/vectorize_copy.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/vectorize_copy.mlir new file mode 100644 index 00000000000..ba6a604b95e --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/vectorize_copy.mlir @@ -0,0 +1,28 @@ +// RUN: mlir-hlo-opt %s --vectorize-copy --split-input-file | FileCheck %s + +func.func @vectorize_copy(%arg: memref<10x16xf32>) -> memref<10x10xf32> { + %subview = memref.subview %arg[0, 0] [10, 10] [1, 1] : memref<10x16xf32> to memref<10x10xf32, strided<[16, 1]>> + %alloc = memref.alloc() : memref<10x10xf32> + memref.copy %subview, %alloc : memref<10x10xf32, strided<[16, 1]>> to memref<10x10xf32> + return %alloc : memref<10x10xf32> +} + +// CHECK-LABEL: func @vectorize_copy + +// CHECK-NOT: memref.copy +// CHECK: vector.transfer_read +// CHECK: vector.transfer_write + +// ----- + +func.func @do_not_vectorize_copy(%arg: memref<10x10xf32>) -> memref<10x10xf32> { + %alloc_10 = memref.alloc() : memref<10x10xf32> + memref.copy %arg, %alloc_10 : memref<10x10xf32> to memref<10x10xf32> + return %alloc_10 : memref<10x10xf32> +} + +// CHECK-LABEL: func @do_not_vectorize_copy + +// CHECK-NOT: vector.transfer_read +// CHECK-NOT: vector.transfer_write +// CHECK: memref.copy diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/vectorize_for_cpu.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/vectorize_for_cpu.mlir new file mode 100644 index 00000000000..4aafd2d5d76 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/vectorize_for_cpu.mlir @@ -0,0 +1,224 @@ +// RUN: mlir-hlo-opt %s --vectorize-for-cpu --split-input-file |\ +// RUN: FileCheck %s + + +func.func @vectorize_tiled_matmul(%lhs: tensor<8x16xf32>, + %rhs: tensor<16x4xf32>, %fill: tensor<8x4xf32>) -> tensor<8x4xf32> { + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %c16 = arith.constant 16 : index + + %7 = gml_st.for (%i) = + (%c0) to (%c16) step (%c2) outs (%arg6 = %fill: tensor<8x4xf32>) { + %9 = tensor.extract_slice %lhs[0, %i] [8, 2] [1, 1] : + tensor<8x16xf32> to tensor<8x2xf32> + + %11 = tensor.extract_slice %rhs[%i, 0] [2, 4] [1, 1] : + tensor<16x4xf32> to tensor<2x4xf32> + + %13 = tensor.extract_slice %arg6[0, 0] [8, 4] [1, 1] : + tensor<8x4xf32> to tensor<8x4xf32> + + %14 = linalg.matmul ins(%9, %11 : tensor<8x2xf32>, tensor<2x4xf32>) + outs(%13 : tensor<8x4xf32>) -> tensor<8x4xf32> + + %12 = gml_st.tile [0, 0] [8, 4] [1, 1] : !gml_st.tile<8x4> + gml_st.set_yield %14 into %arg6[%12] : + tensor<8x4xf32> into tensor<8x4xf32>[!gml_st.tile<8x4>] + } {__perfectly_tileable_loop_label__} : tensor<8x4xf32> + return %7 : tensor<8x4xf32> +} + +// CHECK-LABEL: func @vectorize_tiled_matmul + +// CHECK: %[[OUT_READ:.*]] = vector.transfer_read %[[OUT:.*]] +// CHECK: %[[FOR:.*]] = gml_st.for {{.*}} outs (%[[ARG:.*]] = +// CHECK: %[[LHS:.*]] = vector.transfer_read +// CHECK-SAME: : tensor<8x16xf32>, vector<8x2xf32> +// CHECK: %[[RHS:.*]] = vector.transfer_read +// CHECK-SAME: : tensor<16x4xf32>, vector<2x4xf32> +// CHECK: 
%[[CONTRACT:.*]] = vector.contract +// CHECK-SAME: %[[LHS]], %[[RHS]], %[[ARG]] +// CHECK: gml_st.set_yield %[[CONTRACT]] into %[[ARG]] +// CHECK: vector.transfer_write %[[FOR]] + +// ----- + +func.func @vectorize_static_matmul(%lhs: tensor<128x16xf32>, + %rhs: tensor<16x64xf32>, %fill: tensor<128x64xf32>) -> tensor<128x64xf32> { + %c2 = arith.constant 2 : index + %c16 = arith.constant 16 : index + %c8 = arith.constant 8 : index + %c4 = arith.constant 4 : index + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c64 = arith.constant 64 : index + %0 = gml_st.parallel (%i, %j) = (%c0, %c0) to (%c128, %c64) step (%c8, %c4) + outs (%out_ = %fill: tensor<128x64xf32>) { + %2 = tensor.extract_slice %lhs[%i, 0] [8, 16] [1, 1] : + tensor<128x16xf32> to tensor<8x16xf32> + %4 = tensor.extract_slice %rhs[0, %j] [16, 4] [1, 1] : + tensor<16x64xf32> to tensor<16x4xf32> + %6 = tensor.extract_slice %fill[%i, %j] [8, 4] [1, 1] : + tensor<128x64xf32> to tensor<8x4xf32> + %7 = gml_st.for (%k) = + (%c0) to (%c16) step (%c2) outs (%arg6 = %6: tensor<8x4xf32>) { + %9 = tensor.extract_slice %2[0, %k] [8, 2] [1, 1] : + tensor<8x16xf32> to tensor<8x2xf32> + %11 = tensor.extract_slice %4[%k, 0] [2, 4] [1, 1] : + tensor<16x4xf32> to tensor<2x4xf32> + %13 = tensor.extract_slice %arg6[0, 0] [8, 4] [1, 1] : + tensor<8x4xf32> to tensor<8x4xf32> + %14 = linalg.matmul ins(%9, %11 : tensor<8x2xf32>, tensor<2x4xf32>) + outs(%13 : tensor<8x4xf32>) -> tensor<8x4xf32> + %12 = gml_st.tile [0, 0] [8, 4] [1, 1] : !gml_st.tile<8x4> + gml_st.set_yield %14 into %arg6[%12] : + tensor<8x4xf32> into tensor<8x4xf32>[!gml_st.tile<8x4>] + } : tensor<8x4xf32> + %5 = gml_st.tile [%i, %j] [8, 4] [1, 1] : !gml_st.tile<8x4> + gml_st.set_yield %7 into %out_[%5] : + tensor<8x4xf32> into tensor<128x64xf32>[!gml_st.tile<8x4>] + } : tensor<128x64xf32> + return %0 : tensor<128x64xf32> +} +// CHECK-LABEL: func @vectorize_static_matmul + +// CHECK: %[[OUT_READ:.*]] = vector.transfer_read {{.*}} : tensor<8x4xf32>, vector<8x4xf32> +// CHECK: %[[FOR:.*]] = gml_st.for {{.*}} outs (%[[ARG:.*]] = %[[OUT_READ]] +// CHECK-NOT: linalg.matmul +// CHECK: %[[LHS:.*]] = vector.transfer_read {{.*}} : tensor<128x16xf32>, vector<8x2xf32> +// CHECK: %[[RHS:.*]] = vector.transfer_read {{.*}} : tensor<16x64xf32>, vector<2x4xf32> +// CHECK-NOT: vector.transfer_read +// CHECK: %[[CONTRACT:.*]] = vector.contract {{{.*}}} %[[LHS]], %[[RHS]], %[[ARG]] +// CHECK: gml_st.set_yield %[[CONTRACT]] into %[[ARG]] +// CHECK: vector.transfer_write %[[FOR]] + +// ----- + +func.func @pad(%arg0: tensor<10x10xf32>) -> tensor<16x10xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %padded = tensor.pad %arg0 low[0, 0] high[6, 0] { + ^bb0(%arg3: index, %arg4: index): + tensor.yield %cst : f32 + } : tensor<10x10xf32> to tensor<16x10xf32> + + return %padded : tensor<16x10xf32> +} + +// CHECK-LABEL: func @pad( + +// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<16x10xf32> +// CHECK: %[[FILL:.*]] = linalg.fill {{.*}} outs(%[[EMPTY]] +// CHECK: %[[READ:.*]] = vector.transfer_read +// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[FILL]] +// CHECK: return %[[WRITE]] + +// ----- + +func.func @transpose(%input: tensor<4x5x6xf32>, + %init: tensor<5x6x4xf32>) -> tensor<5x6x4xf32> { + %transpose = linalg.transpose + ins(%input:tensor<4x5x6xf32>) + outs(%init:tensor<5x6x4xf32>) + permutation = [1, 2, 0] + func.return %transpose : tensor<5x6x4xf32> +} + +// CHECK-LABEL: func @transpose( +// CHECK-SAME: %[[INPUT:.*]]: tensor<4x5x6xf32> +// CHECK-SAME: %[[INIT:.*]]: 
tensor<5x6x4xf32> + +// CHECK: %[[READ:.*]] = vector.transfer_read %[[INPUT]] +// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[READ]], [1, 2, 0] +// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[TRANSPOSE]], %[[INIT]] +// CHECK: return %[[WRITE]] + +// ----- + +func.func @simplify_identity_transpose(%input: tensor<1x1xf32>, + %init: tensor<1x1xf32>) -> tensor<1x1xf32> { + %transpose = linalg.transpose + ins(%input:tensor<1x1xf32>) + outs(%init:tensor<1x1xf32>) + permutation = [0, 1] + func.return %transpose : tensor<1x1xf32> +} + +// CHECK-LABEL: func @simplify_identity_transpose( + +// CHECK-NOT: linalg.transpose +// CHECK: return + +// ----- + +func.func @do_not_simplify_transpose(%input: tensor<1x1xf32>, + %init: tensor<1x1xf32>) -> tensor<1x1xf32> { + %transpose = linalg.transpose + ins(%input:tensor<1x1xf32>) + outs(%init:tensor<1x1xf32>) + permutation = [1, 0] + func.return %transpose : tensor<1x1xf32> +} + +// CHECK-LABEL: func @do_not_simplify_transpose( + +// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose +// CHECK: return %[[TRANSPOSE]] + +// ----- + +func.func @perfectly_tiled_reverse_1d(%input: tensor<8xf32>, + %init: tensor<8xf32>) -> tensor<8xf32> { + %res = thlo.reverse + ins(%input: tensor<8xf32>) + outs(%init: tensor<8xf32>) + reverse_dimensions = [0] + func.return %res : tensor<8xf32> +} + +// CHECK-LABEL: func @perfectly_tiled_reverse_1d( +// CHECK-SAME: %[[ARG0:.*]]: tensor<8xf32>, %[[ARG1:.*]]: tensor<8xf32> +// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]] +// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[READ]] +// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[SHUFFLE]], %[[ARG1]] +// CHECK: return %[[WRITE]] + +// ----- + +func.func @perfectly_tiled_reverse_2d(%input: tensor<1x8xf32>, + %init: tensor<1x8xf32>) -> tensor<1x8xf32> { + %res = thlo.reverse + ins(%input: tensor<1x8xf32>) + outs(%init: tensor<1x8xf32>) + reverse_dimensions = [1] + func.return %res : tensor<1x8xf32> +} + +// CHECK-LABEL: func @perfectly_tiled_reverse_2d( +// CHECK-SAME: %[[ARG0:.*]]: tensor<1x8xf32>, %[[ARG1:.*]]: tensor<1x8xf32> +// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]] +// CHECK-SAME: : tensor<1x8xf32>, vector<8xf32> +// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[READ]] +// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[SHUFFLE]], %[[ARG1]] +// CHECK-SAME: : vector<8xf32>, tensor<1x8xf32> +// CHECK: return %[[WRITE]] + +// ----- + +func.func @perfectly_tiled_reverse_4d(%input: tensor<1x1x1x8xf32>, + %init: tensor<1x1x1x8xf32>) -> tensor<1x1x1x8xf32> { + %res = thlo.reverse + ins(%input: tensor<1x1x1x8xf32>) + outs(%init: tensor<1x1x1x8xf32>) + reverse_dimensions = [3] + func.return %res : tensor<1x1x1x8xf32> +} + +// CHECK-LABEL: func @perfectly_tiled_reverse_4d( +// CHECK-SAME: %[[ARG0:.*]]: tensor<1x1x1x8xf32>, %[[ARG1:.*]]: tensor<1x1x1x8xf32> +// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]] +// CHECK-SAME: : tensor<1x1x1x8xf32>, vector<8xf32> +// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[READ]] +// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[SHUFFLE]], %[[ARG1]] +// CHECK-SAME: : vector<8xf32>, tensor<1x1x1x8xf32> +// CHECK: return %[[WRITE]] diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/vectorize_for_gpu.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/vectorize_for_gpu.mlir new file mode 100644 index 00000000000..525b8f88eef --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/vectorize_for_gpu.mlir @@ -0,0 +1,300 @@ +// RUN: mlir-hlo-opt %s --vectorize-for-gpu --split-input-file |\ +// RUN: 
FileCheck %s + +#map0 = affine_map<(d0, d1) -> (d0, d1)> +func.func @test_transfer_read_of_one_dim_expand_shape( + %in: tensor<10xf32>) -> tensor<5xf32> { + %c0 = arith.constant 0 : index + %min_float = arith.constant dense<-3.402820e+38> : vector<5xf32> + %zero_float = arith.constant 0.000000e+00 : f32 + %0 = tensor.expand_shape %in [[0, 1]] : tensor<10xf32> into tensor<2x5xf32> + %1 = tensor.empty() : tensor<5xf32> + %2 = vector.transfer_read %0[%c0, %c0], %zero_float + {in_bounds = [true, true], permutation_map = #map0} + : tensor<2x5xf32>, vector<2x5xf32> + %3 = vector.multi_reduction <maxf>, %2, %min_float [0] + : vector<2x5xf32> to vector<5xf32> + %4 = vector.transfer_write %3, %1[%c0] {in_bounds = [true]} + : vector<5xf32>, tensor<5xf32> + func.return %4 : tensor<5xf32> +} +// CHECK-LABEL: func @test_transfer_read_of_one_dim_expand_shape( +// CHECK-SAME: %[[IN:.*]]: tensor<10xf32> +// CHECK-DAG: %[[MIN_FLOAT:.*]] = arith.constant dense<-3.402820e+38> : vector<5xf32> +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[ZERO_FLOAT:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[INIT_TENSOR:.*]] = tensor.empty() : tensor<5xf32> +// CHECK: %[[TRANSFER_READ:.*]] = vector.transfer_read %[[IN]][%[[C0]]], %[[ZERO_FLOAT]] {in_bounds = [true]} : tensor<10xf32>, vector<10xf32> +// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[TRANSFER_READ]] : vector<10xf32> to vector<2x5xf32> +// CHECK: %[[MULTI_REDUCTION:.*]] = vector.multi_reduction <maxf>, %[[SHAPE_CAST]], %[[MIN_FLOAT]] [0] : vector<2x5xf32> to vector<5xf32> +// CHECK: %[[TRANSFER_WRITE:.*]] = vector.transfer_write %[[MULTI_REDUCTION]], %[[INIT_TENSOR]][%[[C0]]] {in_bounds = [true]} : vector<5xf32>, tensor<5xf32> +// CHECK: return %[[TRANSFER_WRITE]] : tensor<5xf32> + +// ----- + +#map0 = affine_map<(d0, d1) -> (d0, 0)> +func.func @test_transfer_read_of_one_dim_expand_shape_different_shape( + %in: tensor<1xf32>) -> tensor<18xf32> { + %c0 = arith.constant 0 : index + %min_float = arith.constant dense<-3.402820e+38> : vector<18xf32> + %zero_float = arith.constant 0.000000e+00 : f32 + %0 = tensor.expand_shape %in [[0, 1]] : tensor<1xf32> into tensor<1x1xf32> + %1 = tensor.empty() : tensor<18xf32> + %2 = vector.transfer_read %0[%c0, %c0], %zero_float + {in_bounds = [true, true], permutation_map = #map0} + : tensor<1x1xf32>, vector<1x18xf32> + %3 = vector.multi_reduction <maxf>, %2, %min_float [0] + : vector<1x18xf32> to vector<18xf32> + %4 = vector.transfer_write %3, %1[%c0] {in_bounds = [true]} + : vector<18xf32>, tensor<18xf32> + func.return %4 : tensor<18xf32> +} +// CHECK-LABEL: func @test_transfer_read_of_one_dim_expand_shape_different_shape +// CHECK: %{{.*}} = tensor.expand_shape + +// ----- + +func.func @do_not_vectorize_large_untiled_fill() -> tensor<2x1000xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %init = tensor.empty() : tensor<2x1000xf32> + %out = linalg.fill ins(%cst : f32) outs(%init : tensor<2x1000xf32>) -> tensor<2x1000xf32> + func.return %out : tensor<2x1000xf32> +} +// CHECK-LABEL: func @do_not_vectorize_large_untiled_fill +// CHECK: linalg.fill + +// ----- + +func.func @vectorize_small_untiled_fill() -> tensor<128xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %init = tensor.empty() : tensor<128xf32> + %out = linalg.fill ins(%cst : f32) outs(%init : tensor<128xf32>) -> tensor<128xf32> + func.return %out : tensor<128xf32> +} +// CHECK-LABEL: func @vectorize_small_untiled_fill +// CHECK: vector.transfer_write + +// ----- + +#map0 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> + +// CHECK-LABEL: 
@parallel_with_tiles( +func.func @parallel_with_tiles( + %arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) + -> memref<?x?xf32> { + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %0 = memref.dim %arg0, %c0 : memref<?x?xf32> + %1 = memref.dim %arg0, %c1 : memref<?x?xf32> + gml_st.parallel (%arg3, %arg4) = (%c0, %c0) to (%0, %1) step (%c4, %c1) { + %6 = memref.subview %arg2[%arg3, %arg4] [4, 1] [1, 1] + : memref<?x?xf32> to memref<4x1xf32, #map0> + %7 = memref.subview %arg1[%arg3, %arg4] [4, 1] [1, 1] + : memref<?x?xf32> to memref<4x1xf32, #map0> + %8 = memref.subview %arg0[%arg3, %arg4] [4, 1] [1, 1] + : memref<?x?xf32> to memref<4x1xf32, #map0> + linalg.map { arith.addf } + ins(%8, %7 : memref<4x1xf32, #map0>, memref<4x1xf32, #map0>) + outs(%6 : memref<4x1xf32, #map0>) + gml_st.set_yield + } + func.return %arg2 : memref<?x?xf32> +} +// CHECK-NOT: linalg.map +// CHECK: %[[LHS:.*]] = vector.transfer_read {{%.*}}[%c0, %c0] +// CHECK: %[[RHS:.*]] = vector.transfer_read {{%.*}}[%c0, %c0] +// CHECK: %[[ADD:.*]] = arith.addf %[[LHS]], %[[RHS]] : vector<4x1xf32> +// CHECK: vector.transfer_write %[[ADD]], {{%.*}}[%c0, %c0] + +// ----- + +#map0 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> + +// CHECK-LABEL: @for_with_tiles( +func.func @for_with_tiles( + %arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) + -> memref<?x?xf32> { + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %0 = memref.dim %arg0, %c0 : memref<?x?xf32> + %1 = memref.dim %arg0, %c1 : memref<?x?xf32> + gml_st.for (%arg3, %arg4) = (%c0, %c0) to (%0, %1) step (%c4, %c1) { + %6 = memref.subview %arg2[%arg3, %arg4] [4, 1] [1, 1] + : memref<?x?xf32> to memref<4x1xf32, #map0> + %7 = memref.subview %arg1[%arg3, %arg4] [4, 1] [1, 1] + : memref<?x?xf32> to memref<4x1xf32, #map0> + %8 = memref.subview %arg0[%arg3, %arg4] [4, 1] [1, 1] + : memref<?x?xf32> to memref<4x1xf32, #map0> + linalg.map { arith.addf } + ins(%8, %7 : memref<4x1xf32, #map0>, memref<4x1xf32, #map0>) + outs(%6 : memref<4x1xf32, #map0>) + gml_st.set_yield + } + func.return %arg2 : memref<?x?xf32> +} +// CHECK-NOT: linalg.map +// CHECK: %[[LHS:.*]] = vector.transfer_read {{%.*}}[%c0, %c0] +// CHECK: %[[RHS:.*]] = vector.transfer_read {{%.*}}[%c0, %c0] +// CHECK: %[[ADD:.*]] = arith.addf %[[LHS]], %[[RHS]] : vector<4x1xf32> +// CHECK: vector.transfer_write %[[ADD]], {{%.*}}[%c0, %c0] + +// ----- + +// CHECK-LABEL: @parallel_on_tensor( +// CHECK: {{%.*}}: tensor<?xf32>, {{%.*}}: tensor<?xf32>, %[[ARG2:.*]]: tensor<?xf32>) +func.func @parallel_on_tensor( + %arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) + -> tensor<?xf32> { + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %0 = tensor.dim %arg0, %c0 : tensor<?xf32> + %2 = gml_st.parallel (%i) = (%c0) to (%0) step (%c4) + outs (%out_ = %arg2: tensor<?xf32>) { + %6 = tensor.extract_slice %arg0[%i] [4] [1] + : tensor<?xf32> to tensor<4xf32> + %7 = tensor.extract_slice %arg1[%i] [4] [1] + : tensor<?xf32> to tensor<4xf32> + %8 = tensor.extract_slice %out_[%i] [4] [1] + : tensor<?xf32> to tensor<4xf32> + %9 = linalg.map { arith.addf } + ins(%6, %7 : tensor<4xf32>, tensor<4xf32>) + outs(%8 : tensor<4xf32>) + %tile = gml_st.tile [%i] [4] [1] : !gml_st.tile<4> + gml_st.set_yield %9 into %out_[%tile] + : tensor<4xf32> into tensor<?xf32>[!gml_st.tile<4>] + } : tensor<?xf32> + func.return %2 : tensor<?xf32> +} +// CHECK-NOT: linalg.map +// CHECK: gml_st.parallel (%[[ITER:.*]]) = (%[[C0:[a-z0-9]+]]) +// CHECK: %[[LHS:.*]] = vector.transfer_read {{%[a-z0-9_]+}}[%[[C0]]] +// CHECK: %[[RHS:.*]] = vector.transfer_read {{%[a-z0-9_]+}}[%[[C0]]] +// CHECK: %[[ADD:.*]] = arith.addf %[[LHS]], %[[RHS]] : 
vector<4xf32> + +// ----- + +// CHECK-LABEL: @single_element_tensor_to_element( +// CHECK-SAME: %[[IN:.*]]: vector<1xf32> +func.func @single_element_tensor_to_element(%in : vector<1xf32>) -> f32 { + %c0 = arith.constant 0 : index + %pad = arith.constant 0.0 : f32 + %empty = tensor.empty() : tensor<1xf32> + %r = vector.transfer_write %in, %empty[%c0] {in_bounds = [true]} + : vector<1xf32>, tensor<1xf32> + %v = tensor.extract %r[%c0] : tensor<1xf32> + return %v : f32 +} +// CHECK: %[[RESULT:.*]] = vector.extract %[[IN]][0] +// CHECK: return %[[RESULT]] + +// ----- + +// CHECK-LABEL: @zero_dim_element_tensor_to_element( +// CHECK-SAME: %[[IN:.*]]: vector<f32> +func.func @zero_dim_element_tensor_to_element(%in : vector<f32>) -> f32 { + %pad = arith.constant 0.0 : f32 + %empty = tensor.empty() : tensor<f32> + %r = vector.transfer_write %in, %empty[] {in_bounds = []} + : vector<f32>, tensor<f32> + %v = tensor.extract %r[] : tensor<f32> + return %v : f32 +} +// CHECK: %[[RESULT:.*]] = vector.extractelement %[[IN]][] +// CHECK: return %[[RESULT]] + +// ----- + +// CHECK-LABEL: @read_of_empty_float_to_constant( +func.func @read_of_empty_float_to_constant(%pad : f32) -> vector<32xf32> { + %empty = tensor.empty() : tensor<32xf32> + %c0 = arith.constant 0 : index + %r = vector.transfer_read %empty[%c0], %pad {in_bounds = [true]} + : tensor<32xf32>, vector<32xf32> + return %r : vector<32xf32> +} +// CHECK: %[[RESULT:.*]] = arith.constant dense<0x7FC00000> : vector<32xf32> +// CHECK: return %[[RESULT]] + +// ----- + +// CHECK-LABEL: @read_of_empty_int_to_constant( +func.func @read_of_empty_int_to_constant(%pad : i8) -> vector<32xi8> { + %empty = tensor.empty() : tensor<32xi8> + %c0 = arith.constant 0 : index + %r = vector.transfer_read %empty[%c0], %pad {in_bounds = [true]} + : tensor<32xi8>, vector<32xi8> + return %r : vector<32xi8> +} +// CHECK: %[[RESULT:.*]] = arith.constant dense<0> : vector<32xi8> +// CHECK: return %[[RESULT]] +// ----- + +// CHECK-LABEL: @materialize_scalar_from_0D_vector( +// CHECK-SAME: %[[V:.*]]: vector<f32> +func.func @materialize_scalar_from_0D_vector(%v : vector<f32>) -> f32 { + %r = gml_st.materialize %v[][][] : vector<f32> to f32 + return %r : f32 +} +// CHECK: %[[R:.*]] = vector.extractelement %[[V]][] +// CHECK: return %[[R]] + +// ----- + +// CHECK-LABEL: @materialize_scalar_from_single_element_vector( +// CHECK-SAME: %[[V:.*]]: vector<1x1xf32> +func.func @materialize_scalar_from_single_element_vector( + %v : vector<1x1xf32>) -> f32 { + %r = gml_st.materialize %v[0, 0] [1, 1] [1, 1] + : vector<1x1xf32> to f32 + return %r : f32 +} +// CHECK: %[[R:.*]] = vector.extract %[[V]][0, 0] +// CHECK: return %[[R]] + + +// ----- + +// CHECK-LABEL: @set_yield_scalar_into_vector( +// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<1x1xf32>) +func.func @set_yield_scalar_into_vector( + %f: f32, %v: vector<1x1xf32>) { + %tile = gml_st.tile [0, 0] [1, 1] [1, 1] : !gml_st.tile<1x1> + gml_st.set_yield %f into %v[%tile] + : f32 into vector<1x1xf32>[!gml_st.tile<1x1>] +} +// CHECK: %[[R:.*]] = vector.insert %[[F]], %[[V]] [0, 0] +// CHECK: gml_st.set_yield %[[R]] into %[[V]] + +// ----- + +func.func @fold_identity_materialize(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) + -> tensor<8x8xf32> { + %c0 = arith.constant 0 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %cst_0 = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<8x8xf32> + %6 = gml_st.for (%arg4) = (%c0) to (%c16) step (%c8) outs (%arg5 = %0: tensor<8x8xf32>) { + %19 = tensor.extract_slice %arg0[%c0, %arg4] [8, 8] [1, 1] : 
tensor<8x16xf32> to tensor<8x8xf32> + %21 = tensor.extract_slice %arg1[%arg4, %c0] [8, 8] [1, 1] : tensor<16x8xf32> to tensor<8x8xf32> + %23 = tensor.extract_slice %arg5[0, 0] [8, 8] [1, 1] : tensor<8x8xf32> to tensor<8x8xf32> + %28 = linalg.fill ins(%cst_0 : f32) outs(%23 : tensor<8x8xf32>) -> tensor<8x8xf32> + %29 = tensor.extract_slice %28[0, 0] [8, 8] [1, 1] : tensor<8x8xf32> to tensor<8x8xf32> + %22 = gml_st.tile [0, 0] [8, 8] [1, 1] : !gml_st.tile<8x8> + gml_st.set_yield %29 into %arg5[%22] : tensor<8x8xf32> into tensor<8x8xf32>[!gml_st.tile<8x8>] + } : tensor<8x8xf32> + return %6 : tensor<8x8xf32> +} + +// CHECK-LABEL: func @fold_identity_materialize( + +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x8xf32> +// CHECK: %[[INIT:.*]] = tensor.empty + +// CHECK: gml_st.for {{.*}} outs (%[[ARG:.*]] = %[[INIT]] +// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[CST]], %[[ARG]] +// CHECK: %[[TILE:.*]] = gml_st.tile [0, 0] [8, 8] [1, 1] +// CHECK: gml_st.set_yield %[[WRITE]] into %[[ARG]] diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/vectorize_for_gpu_distributed.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/vectorize_for_gpu_distributed.mlir new file mode 100644 index 00000000000..7a1ed54e747 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/vectorize_for_gpu_distributed.mlir @@ -0,0 +1,235 @@ +// RUN: mlir-hlo-opt %s --split-input-file \ +// RUN: --vectorize-for-gpu="vectorize-gml-st-ops=true included-distribution-labels=test" \ +// RUN: | FileCheck %s + +func.func @vectorize_gml_st_parallel_op( + %arg0: tensor<32xf32>, %arg1: tensor<32xf32>) + -> tensor<32xf32> { + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c32 = arith.constant 32 : index + // We need this outer trivial loop to make sure the inner loop has a parent + // with the correct distribution label. 
+ %2 = gml_st.parallel (%unused) = (%c0) to (%c1) step (%c1) + outs (%arg1_ = %arg1: tensor<32xf32>) distribution ("test") { + %arg0tile = tensor.extract_slice %arg0[0][32][1] + : tensor<32xf32> to tensor<32xf32> + %arg1tile = tensor.extract_slice %arg1_[0][32][1] + : tensor<32xf32> to tensor<32xf32> + %3 = gml_st.parallel (%i) = (%c0) to (%c32) step (%c4) + outs (%arg1tile_ = %arg1tile: tensor<32xf32>) distribution ("test") { + %6 = tensor.extract_slice %arg0tile[%i] [4] [1] + : tensor<32xf32> to tensor<4xf32> + %7 = tensor.extract_slice %arg1tile_[%i] [4] [1] + : tensor<32xf32> to tensor<4xf32> + %9 = linalg.map {arith.negf } + ins(%6: tensor<4xf32>) + outs(%7 : tensor<4xf32>) + %tile = gml_st.tile [%i] [4] [1] : !gml_st.tile<4> + gml_st.set_yield %9 into %arg1tile_[%tile] + : tensor<4xf32> into tensor<32xf32>[!gml_st.tile<4>] + } : tensor<32xf32> + %tile32 = gml_st.tile [0][32][1] : !gml_st.tile<32> + gml_st.set_yield %3 into %arg1_[%tile32] + : tensor<32xf32> into tensor<32xf32>[!gml_st.tile<32>] + } : tensor<32xf32> + func.return %2 : tensor<32xf32> +} +// CHECK-LABEL: @vectorize_gml_st_parallel_op( +// CHECK-SAME: %[[ARG0:.*]]: tensor<32xf32>, %[[ARG1:.*]]: tensor<32xf32> + +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: gml_st.parallel +// CHECK-SAME: outs (%[[ARG1_:.*]] = %[[ARG1]]: +// CHECK-DAG: vector.transfer_read %[[ARG1_]][%[[C0]]] +// CHECK: %[[RESULT:.*]] = gml_st.parallel +// CHECK: %[[LHSTILE:.*]] = tensor.extract_slice %[[ARG0]] +// CHECK: %[[LHSVEC:.*]] = vector.transfer_read %[[LHSTILE]] +// CHECK: %[[NEG:.*]] = arith.negf %[[LHSVEC]] : vector<4xf32> +// CHECK: gml_st.set_yield %[[NEG]] +// CHECK-SAME: vector<4xf32> into vector<32xf32> +// CHECK: vector.transfer_write %[[RESULT]], {{%.*}}[%c0] + +// ----- + +func.func @vectorize_gml_st_for_op( + %arg0: tensor<32xf32>, %arg1: tensor<32xf32>) + -> tensor<32xf32> { + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c32 = arith.constant 32 : index + // We need this outer trivial loop to make sure the inner loop has a parent + // with the correct distribution label. 
+ %2 = gml_st.parallel (%unused) = (%c0) to (%c1) step (%c1) + outs (%out_ = %arg1 : tensor<32xf32>) distribution ("test") { + %arg0tile = tensor.extract_slice %arg0[0][32][1] + : tensor<32xf32> to tensor<32xf32> + %out_tile = tensor.extract_slice %out_[0][32][1] + : tensor<32xf32> to tensor<32xf32> + %3 = gml_st.for (%i) = (%c0) to (%c32) step (%c4) + outs(%out = %out_tile : tensor<32xf32>) { + %6 = tensor.extract_slice %arg0tile[%i][4][1] + : tensor<32xf32> to tensor<4xf32> + %7 = tensor.extract_slice %out[%i][4][1] + : tensor<32xf32> to tensor<4xf32> + %9 = linalg.map { arith.negf } + ins(%6: tensor<4xf32>) + outs(%7 : tensor<4xf32>) + %tile = gml_st.tile [%i] [4] [1] : !gml_st.tile<4> + gml_st.set_yield %9 into %out[%tile] + : tensor<4xf32> into tensor<32xf32>[!gml_st.tile<4>] + } : tensor<32xf32> + %tile32 = gml_st.tile [0][32][1] : !gml_st.tile<32> + gml_st.set_yield %3 into %out_[%tile32] + : tensor<32xf32> into tensor<32xf32>[!gml_st.tile<32>] + } : tensor<32xf32> + func.return %2 : tensor<32xf32> +} +// CHECK-LABEL: @vectorize_gml_st_for_op( +// CHECK-SAME: %[[ARG0:.*]]: tensor<32xf32>, %[[ARG1:.*]]: tensor<32xf32> + +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: gml_st.parallel +// CHECK-SAME: outs (%[[OUT_:.*]] = %[[ARG1]]: +// CHECK-DAG: %[[RES:.*]] = vector.transfer_read %[[OUT_]][%[[C0]]] +// CHECK: %[[RESULT:.*]] = gml_st.for +// CHECK-SAME: outs (%[[OUT:.*]] = %[[RES]]: vector<32xf32>) +// CHECK-DAG: %[[LHSTILE:.*]] = tensor.extract_slice %[[ARG0]] +// CHECK-DAG: %[[LHSVEC:.*]] = vector.transfer_read %[[LHSTILE]] +// CHECK: %[[NEG:.*]] = arith.negf %[[LHSVEC]] : vector<4xf32> +// CHECK: gml_st.set_yield %[[NEG]] into %[[OUT]] +// CHECK-SAME: vector<4xf32> into vector<32xf32> + +// ----- + +func.func @vectorize_loop_on_scalars( + %arg0: tensor<32xf32>, %arg1: tensor<32xf32>) -> tensor<32xf32> { + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c32 = arith.constant 32 : index + // We need this outer trivial loop to make sure the inner loop has a parent + // with the correct distribution label. 
+ %2 = gml_st.parallel (%unused) = (%c0) to (%c1) step (%c1) + outs (%out_ = %arg1 : tensor<32xf32>) distribution ("test") { + %arg0tile = tensor.extract_slice %arg0[0][32][1] + : tensor<32xf32> to tensor<32xf32> + %out_tile = tensor.extract_slice %out_[0][32][1] + : tensor<32xf32> to tensor<32xf32> + %3 = gml_st.for (%i) = (%c0) to (%c32) step (%c4) + outs(%out = %out_tile : tensor<32xf32>) { + %6 = tensor.extract_slice %arg0tile[%i][1][1] + : tensor<32xf32> to tensor<1xf32> + %7 = tensor.extract %6[%c0] : tensor<1xf32> + %9 = arith.negf %7 : f32 + %tile = gml_st.tile [%i] [1] [1] : !gml_st.tile<1> + gml_st.set_yield %9 into %out[%tile] + : f32 into tensor<32xf32>[!gml_st.tile<1>] + } : tensor<32xf32> + %tile32 = gml_st.tile [0][32][1] : !gml_st.tile<32> + gml_st.set_yield %3 into %out_[%tile32] + : tensor<32xf32> into tensor<32xf32>[!gml_st.tile<32>] + } : tensor<32xf32> + func.return %2 : tensor<32xf32> +} +// CHECK-LABEL: @vectorize_loop_on_scalars( +// CHECK-SAME: %[[ARG0:.*]]: tensor<32xf32>, %[[ARG1:.*]]: tensor<32xf32> + +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: gml_st.parallel +// CHECK-SAME: outs (%[[OUT_:.*]] = %[[ARG1]]: +// CHECK-DAG: %[[RES:.*]] = vector.transfer_read %[[OUT_]][%[[C0]]] +// CHECK: %[[RESULT:.*]] = gml_st.for +// CHECK-SAME: outs (%[[OUT:.*]] = %[[RES]]: vector<32xf32>) +// CHECK: %[[LHSTILE:.*]] = tensor.extract_slice %[[ARG0]] +// CHECK: %[[LHSVEC:.*]] = vector.transfer_read %[[LHSTILE]][%c0] +// CHECK: %[[LHSELEM:.*]] = vector.extract %[[LHSVEC]] +// CHECK: %[[NEG:.*]] = arith.negf %[[LHSELEM]] : f32 +// CHECK: gml_st.set_yield %[[NEG]] into %[[OUT]] +// CHECK-SAME: f32 into vector<32xf32> + +// ----- + +// CHECK-LABEL: @skip_vectorization_with_wrong_label( +func.func @skip_vectorization_with_wrong_label( + %arg0: tensor<32xf32>, %arg1: tensor<32xf32>) + -> tensor<32xf32> { + %c4 = arith.constant 4 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %c32 = arith.constant 32 : index + %2 = gml_st.parallel (%unused) = (%c0) to (%c1) step (%c1) + outs (%out_ = %arg1 : tensor<32xf32>) distribution ("no_vec") { + %3 = gml_st.parallel (%i) = (%c0) to (%c32) step (%c4) + outs (%out2_ = %out_ : tensor<32xf32>) distribution ("no_vec") { + %6 = tensor.extract_slice %arg0[%i][4][1] + : tensor<32xf32> to tensor<4xf32> + %7 = tensor.extract_slice %out2_[%i][4][1] + : tensor<32xf32> to tensor<4xf32> + %9 = linalg.map { arith.negf } + ins(%6: tensor<4xf32>) + outs(%7 : tensor<4xf32>) + %tile = gml_st.tile [%i] [4] [1] : !gml_st.tile<4> + gml_st.set_yield %9 into %out2_[%tile] + : tensor<4xf32> into tensor<32xf32>[!gml_st.tile<4>] + } : tensor<32xf32> + %tile32 = gml_st.tile [0][32][1] : !gml_st.tile<32> + gml_st.set_yield %3 into %out_[%tile32] + : tensor<32xf32> into tensor<32xf32>[!gml_st.tile<32>] + } : tensor<32xf32> + func.return %2 : tensor<32xf32> +} +// CHECK-NOT: vector.transfer_read + +// ----- + +// CHECK-LABEL: @materialize_to_scalar( +func.func @materialize_to_scalar(%arg1 : tensor<4xf32>) -> tensor<4xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %empty = tensor.empty() : tensor<4xf32> + %1 = gml_st.parallel (%arg2) = (%c0) to (%c1) step (%c1) + outs (%out_ = %empty : tensor<4xf32>) distribution ("test") { + %5 = tensor.extract_slice %arg1[1][4][1] + : tensor<4xf32> to tensor<4xf32> + %3 = tensor.extract_slice %5[1][1][1] + : tensor<4xf32> to tensor<1xf32> + %4 = tensor.extract %3[%c0] : tensor<1xf32> + // CHECK: gml_st.materialize {{.*}} : vector<4xf32> to f32 + %2 = arith.negf %4 : f32 + 
%point = gml_st.tile [1][1][1] : !gml_st.tile<1> + gml_st.set_yield %2 into %out_[%point] + : f32 into tensor<4xf32>[!gml_st.tile<1>] + } : tensor<4xf32> + return %1 : tensor<4xf32> +} + +// ----- + +// CHECK-LABEL: @materialize_to_dynamic_tile( +func.func @materialize_to_dynamic_tile(%arg1 : tensor<4xf32>, %size : index) + -> tensor<4xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %empty = tensor.empty() : tensor<4xf32> + %0 = gml_st.parallel (%arg3) = (%c0) to (%c1) step (%c1) + outs (%out_ = %empty : tensor<4xf32>) distribution ("test") { + %1 = gml_st.parallel (%arg2) = (%c0) to (%c1) step (%c1) + outs (%out2_ = %out_ : tensor<4xf32>) distribution ("test") { + %2 = tensor.extract_slice %arg1[1][4][1] + : tensor<4xf32> to tensor<4xf32> + %3 = tensor.extract_slice %2[1][%size][1] + : tensor<4xf32> to tensor + %dynTile = gml_st.tile [1][%size][1] : !gml_st.tile + gml_st.set_yield %3 into %out2_[%dynTile] + : tensor into tensor<4xf32>[!gml_st.tile] + } : tensor<4xf32> + %tile = gml_st.tile [1][4][1] : !gml_st.tile<4> + gml_st.set_yield %1 into %out_[%tile] + : tensor<4xf32> into tensor<4xf32>[!gml_st.tile<4>] + } : tensor<4xf32> + return %0 : tensor<4xf32> +} +// CHECK-NOT: vector diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/vectorize_gml_st.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/vectorize_gml_st.mlir deleted file mode 100644 index b852ddce1d5..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/vectorize_gml_st.mlir +++ /dev/null @@ -1,249 +0,0 @@ -// RUN: mlir-hlo-opt %s --split-input-file \ -// RUN: --vectorize-gml-st-loops="vectorize-gml-st-ops=true included-distribution-labels=test" \ -// RUN: | FileCheck %s - -#map0 = affine_map<(d0) -> (d0)> - -// CHECK-LABEL: @vectorize_gml_st_parallel_op( -func.func @vectorize_gml_st_parallel_op( - %arg0: tensor<32xf32>, %arg1: tensor<32xf32>) - -> tensor<32xf32> { - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %c32 = arith.constant 32 : index - %tile32 = gml_st.tile [0][32][1] : !gml_st.tile<32> - // We need this outer trivial loop to make sure the inner loop has a parent - // with the correct distribution label. 
- %2 = gml_st.parallel (%unused) = (%c0) to (%c1) step (%c1) - distribution ("test") { - %arg0tile = gml_st.materialize %arg0[%tile32] - : tensor<32xf32>[!gml_st.tile<32>] to tensor<32xf32> - %arg1tile = gml_st.materialize %arg1[%tile32] - : tensor<32xf32>[!gml_st.tile<32>] to tensor<32xf32> - %3 = gml_st.parallel (%i) = (%c0) to (%c32) step (%c4) - distribution ("test") { - %tile = gml_st.tile [%i] [4] [1] : !gml_st.tile<4> - %6 = gml_st.materialize %arg0tile[%tile] - : tensor<32xf32>[!gml_st.tile<4>] to tensor<4xf32> - %7 = gml_st.materialize %arg1tile[%tile] - : tensor<32xf32>[!gml_st.tile<4>] to tensor<4xf32> - %9 = linalg.generic {indexing_maps = [#map0, #map0], - iterator_types = ["parallel"]} - ins(%6: tensor<4xf32>) - outs(%7 : tensor<4xf32>) { - ^bb0(%arg5: f32, %arg6: f32): - %10 = arith.negf %arg5 : f32 - linalg.yield %10 : f32 - } -> tensor<4xf32> - gml_st.set_yield %9 into %arg1tile[%tile] - : tensor<4xf32> into tensor<32xf32>[!gml_st.tile<4>] - } : tensor<32xf32> - gml_st.set_yield %3 into %arg1[%tile32] - : tensor<32xf32> into tensor<32xf32>[!gml_st.tile<32>] - } : tensor<32xf32> - func.return %2 : tensor<32xf32> -} -// CHECK: gml_st.parallel -// CHECK-DAG: %[[ARG0TILE:.*]] = gml_st.materialize %arg0 -// CHECK-DAG: %[[LHS:.*]] = vector.transfer_read %[[ARG0TILE]][%c0] -// CHECK: %[[RESULT:.*]] = gml_st.parallel -// CHECK-DAG: %[[LHSTILE:.*]] = gml_st.materialize %[[LHS]] -// CHECK: %[[NEG:.*]] = arith.negf %[[LHSTILE]] : vector<4xf32> -// CHECK: gml_st.set_yield %[[NEG]] -// CHECK-SAME: vector<4xf32> into vector<32xf32> -// CHECK: vector.transfer_write %[[RESULT]], {{%.*}}[%c0] - -// ----- - -#map0 = affine_map<(d0) -> (d0)> - -// CHECK-LABEL: @vectorize_gml_st_for_op( -func.func @vectorize_gml_st_for_op( - %arg0: tensor<32xf32>, %arg1: tensor<32xf32>) - -> tensor<32xf32> { - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %c32 = arith.constant 32 : index - %tile32 = gml_st.tile [0][32][1] : !gml_st.tile<32> - // We need this outer trivial loop to make sure the inner loop has a parent - // with the correct distribution label. 
- %2 = gml_st.parallel (%unused) = (%c0) to (%c1) step (%c1) - distribution ("test") { - %arg0tile = gml_st.materialize %arg0[%tile32] - : tensor<32xf32>[!gml_st.tile<32>] to tensor<32xf32> - %arg1tile = gml_st.materialize %arg1[%tile32] - : tensor<32xf32>[!gml_st.tile<32>] to tensor<32xf32> - %3 = gml_st.for (%i) = (%c0) to (%c32) step (%c4) - outs(%out = %arg1tile : tensor<32xf32>) { - %tile = gml_st.tile [%i] [4] [1] : !gml_st.tile<4> - %6 = gml_st.materialize %arg0tile[%tile] - : tensor<32xf32>[!gml_st.tile<4>] to tensor<4xf32> - %7 = gml_st.materialize %out[%tile] - : tensor<32xf32>[!gml_st.tile<4>] to tensor<4xf32> - %9 = linalg.generic {indexing_maps = [#map0, #map0], - iterator_types = ["parallel"]} - ins(%6: tensor<4xf32>) - outs(%7 : tensor<4xf32>) { - ^bb0(%arg5: f32, %arg6: f32): - %10 = arith.negf %arg5 : f32 - linalg.yield %10 : f32 - } -> tensor<4xf32> - gml_st.set_yield %9 into %out[%tile] - : tensor<4xf32> into tensor<32xf32>[!gml_st.tile<4>] - } : tensor<32xf32> - gml_st.set_yield %3 into %arg1[%tile32] - : tensor<32xf32> into tensor<32xf32>[!gml_st.tile<32>] - } : tensor<32xf32> - func.return %2 : tensor<32xf32> -} -// CHECK: gml_st.parallel -// CHECK-DAG: %[[ARG0TILE:.*]] = gml_st.materialize %arg0 -// CHECK-DAG: %[[ARG1TILE:.*]] = gml_st.materialize %arg1 -// CHECK-DAG: %[[LHS:.*]] = vector.transfer_read %[[ARG0TILE]][%c0] -// CHECK-DAG: %[[RES:.*]] = vector.transfer_read %[[ARG1TILE]][%c0] -// CHECK: %[[RESULT:.*]] = gml_st.for -// CHECK-SAME: outs (%[[OUT:.*]] = %[[RES]]: vector<32xf32>) -// CHECK-DAG: %[[LHSTILE:.*]] = gml_st.materialize %[[LHS]] -// CHECK: %[[NEG:.*]] = arith.negf %[[LHSTILE]] : vector<4xf32> -// CHECK: gml_st.set_yield %[[NEG]] into %[[OUT]] -// CHECK-SAME: vector<4xf32> into vector<32xf32> -// CHECK: vector.transfer_write %[[RESULT]], %[[ARG1TILE]][%c0] - -// ----- - -// CHECK-LABEL: @vectorize_loop_on_scalars( -func.func @vectorize_loop_on_scalars( - %arg0: tensor<32xf32>, %arg1: tensor<32xf32>) - -> tensor<32xf32> { - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %c32 = arith.constant 32 : index - %tile32 = gml_st.tile [0][32][1] : !gml_st.tile<32> - // We need this outer trivial loop to make sure the inner loop has a parent - // with the correct distribution label. 
- %2 = gml_st.parallel (%unused) = (%c0) to (%c1) step (%c1) - distribution ("test") { - %arg0tile = gml_st.materialize %arg0[%tile32] - : tensor<32xf32>[!gml_st.tile<32>] to tensor<32xf32> - %arg1tile = gml_st.materialize %arg1[%tile32] - : tensor<32xf32>[!gml_st.tile<32>] to tensor<32xf32> - %3 = gml_st.for (%i) = (%c0) to (%c32) step (%c4) - outs(%out = %arg1tile : tensor<32xf32>) { - %tile = gml_st.tile [%i] [1] [1] : !gml_st.tile<1> - %6 = gml_st.materialize %arg0tile[%tile] - : tensor<32xf32>[!gml_st.tile<1>] to f32 - %9 = arith.negf %6 : f32 - gml_st.set_yield %9 into %out[%tile] - : f32 into tensor<32xf32>[!gml_st.tile<1>] - } : tensor<32xf32> - gml_st.set_yield %3 into %arg1[%tile32] - : tensor<32xf32> into tensor<32xf32>[!gml_st.tile<32>] - } : tensor<32xf32> - func.return %2 : tensor<32xf32> -} -// CHECK: gml_st.parallel -// CHECK-DAG: %[[ARG0TILE:.*]] = gml_st.materialize %arg0 -// CHECK-DAG: %[[ARG1TILE:.*]] = gml_st.materialize %arg1 -// CHECK-DAG: %[[LHS:.*]] = vector.transfer_read %[[ARG0TILE]][%c0] -// CHECK-DAG: %[[RES:.*]] = vector.transfer_read %[[ARG1TILE]][%c0] -// CHECK: %[[RESULT:.*]] = gml_st.for -// CHECK-SAME: outs (%[[OUT:.*]] = %[[RES]]: vector<32xf32>) -// CHECK-DAG: %[[LHSTILE:.*]] = gml_st.materialize %[[LHS]] -// CHECK: %[[NEG:.*]] = arith.negf %[[LHSTILE]] : f32 -// CHECK: gml_st.set_yield %[[NEG]] into %[[OUT]] -// CHECK-SAME: f32 into vector<32xf32> -// CHECK: vector.transfer_write %[[RESULT]], %[[ARG1TILE]][%c0] - -// ----- - -#map0 = affine_map<(d0) -> (d0)> - -// CHECK-LABEL: @skip_vectorization_with_wrong_label( -func.func @skip_vectorization_with_wrong_label( - %arg0: tensor<32xf32>, %arg1: tensor<32xf32>) - -> tensor<32xf32> { - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %c32 = arith.constant 32 : index - %tile32 = gml_st.tile [0][32][1] : !gml_st.tile<32> - %2 = gml_st.parallel (%unused) = (%c0) to (%c1) step (%c1) - distribution ("no_vec") { - %3 = gml_st.parallel (%i) = (%c0) to (%c32) step (%c4) - distribution ("no_vec") { - %tile = gml_st.tile [%i] [4] [1] : !gml_st.tile<4> - %6 = gml_st.materialize %arg0[%tile] - : tensor<32xf32>[!gml_st.tile<4>] to tensor<4xf32> - %7 = gml_st.materialize %arg1[%tile] - : tensor<32xf32>[!gml_st.tile<4>] to tensor<4xf32> - %9 = linalg.generic {indexing_maps = [#map0, #map0], - iterator_types = ["parallel"]} - ins(%6 : tensor<4xf32>) - outs(%7 : tensor<4xf32>) { - ^bb0(%arg5: f32, %arg6: f32): - %10 = arith.negf %arg5 : f32 - linalg.yield %10 : f32 - } -> tensor<4xf32> - gml_st.set_yield %9 into %arg1[%tile] - : tensor<4xf32> into tensor<32xf32>[!gml_st.tile<4>] - } : tensor<32xf32> - gml_st.set_yield %3 into %arg1[%tile32] - : tensor<32xf32> into tensor<32xf32>[!gml_st.tile<32>] - } : tensor<32xf32> - func.return %2 : tensor<32xf32> -} -// CHECK-NOT: vector.transfer_read - -// ----- - -// CHECK-LABEL: @materialize_to_scalar( -func.func @materialize_to_scalar(%arg1 : tensor<4xf32>) -> tensor<4xf32> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %empty = tensor.empty() : tensor<4xf32> - %1 = gml_st.parallel (%arg2) = (%c0) to (%c1) step (%c1) - distribution ("test") { - %tile = gml_st.tile [1][4][1] : !gml_st.tile<4> - %point = gml_st.tile [1][1][1] : !gml_st.tile<1> - %5 = gml_st.materialize %arg1[%tile] - : tensor<4xf32>[!gml_st.tile<4>] to tensor<4xf32> - %3 = gml_st.materialize %5[%point] - : tensor<4xf32>[!gml_st.tile<1>] to f32 - // CHECK: gml_st.materialize {{.*}} : vector<4xf32>[!gml_st.tile<1>] to f32 - %2 = arith.negf %3 : 
f32 - gml_st.set_yield %2 into %empty[%point] - : f32 into tensor<4xf32>[!gml_st.tile<1>] - } : tensor<4xf32> - return %1 : tensor<4xf32> -} - -// ----- - -// CHECK-LABEL: @materialize_to_dynamic_tile( -func.func @materialize_to_dynamic_tile(%arg1 : tensor<4xf32>, %size : index) - -> tensor<4xf32> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %empty = tensor.empty() : tensor<4xf32> - %0 = gml_st.parallel (%arg3) = (%c0) to (%c1) step (%c1) - distribution ("test") { - %tile = gml_st.tile [1][4][1] : !gml_st.tile<4> - %1 = gml_st.parallel (%arg2) = (%c0) to (%c1) step (%c1) - distribution ("test") { - %2 = gml_st.materialize %arg1[%tile] - : tensor<4xf32>[!gml_st.tile<4>] to tensor<4xf32> - %dynTile = gml_st.tile [1][%size][1] : !gml_st.tile - %3 = gml_st.materialize %2[%dynTile] - : tensor<4xf32>[!gml_st.tile] to tensor - gml_st.set_yield %3 into %empty[%dynTile] - : tensor into tensor<4xf32>[!gml_st.tile] - } : tensor<4xf32> - gml_st.set_yield %1 into %empty[%tile] - : tensor<4xf32> into tensor<4xf32>[!gml_st.tile<4>] - } : tensor<4xf32> - return %0 : tensor<4xf32> -} -// CHECK-NOT: vector diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/warp_reduce.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/warp_reduce.mlir new file mode 100644 index 00000000000..91ebf82b2ef --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/warp_reduce.mlir @@ -0,0 +1,226 @@ +// RUN: mlir-hlo-opt -split-input-file %s \ +// RUN: -gml-st-simtfy="block-distribution-label=block" \ +// RUN: -gml-st-to-gpu="warp-distribution-label=warp" \ +// RUN: | FileCheck %s + +// CHECK-LABEL: func @vector_reduce_add +func.func @vector_reduce_add( + %arg0: vector<1xf32>, + %arg1: vector<1xf32> +) -> vector<1xf32> { + + %lane = gpu.lane_id + %tile = gml_st.tile [%lane] [1] [1] : !gml_st.tile<1> + %dist = gml_st.distribute %arg1 into[%tile] + : vector<1xf32> into vector<1x32xf32>[!gml_st.tile<1>] + + // CHECK: %[[X0:.*]] = vector.extract %arg1[0] + // CHECK: %[[Y0:.*]], %{{.*}} = gpu.shuffle xor %[[X0]], %c1 + // CHECK: %[[X1:.*]] = arith.addf %[[X0]], %[[Y0]] + // CHECK: %[[Y1:.*]], %{{.*}} = gpu.shuffle xor %[[X1]], %c2 + // CHECK: %[[X2:.*]] = arith.addf %[[X1]], %[[Y1]] + // CHECK: %[[Y2:.*]], %{{.*}} = gpu.shuffle xor %[[X2]], %c4 + // CHECK: %[[X3:.*]] = arith.addf %[[X2]], %[[Y2]] + // CHECK: %[[Y3:.*]], %{{.*}} = gpu.shuffle xor %[[X3]], %c8 + // CHECK: %[[X4:.*]] = arith.addf %[[X3]], %[[Y3]] + // CHECK: %[[Y4:.*]], %{{.*}} = gpu.shuffle xor %[[X4]], %c16 + // CHECK: %[[X5:.*]] = arith.addf %[[X4]], %[[Y4]] + // CHECK: %[[Y5:.*]] = vector.extract %arg0[0] + // CHECK: %[[X6:.*]] = arith.addf %[[Y5]], %[[X5]] + // CHECK: %[[RESULT:.*]] = vector.broadcast %[[X6]] + %result = vector.multi_reduction <add>, %dist, %arg0 + {"gml-st-distribution-label" = "warp"} [1] + : vector<1x32xf32> to vector<1xf32> + + // CHECK: return %[[RESULT]] + func.return %result : vector<1xf32> +} + +// ----- + +// CHECK-LABEL: func @vector_reduce_add_int +func.func @vector_reduce_add_int( + %arg0: vector<1xi32>, + %arg1: vector<1xi32> +) -> vector<1xi32> { + + %lane = gpu.lane_id + %tile = gml_st.tile [%lane] [1] [1] : !gml_st.tile<1> + %dist = gml_st.distribute %arg1 into[%tile] + : vector<1xi32> into vector<1x32xi32>[!gml_st.tile<1>] + + // CHECK: %[[X0:.*]] = vector.extract %arg1[0] + // CHECK: %[[Y0:.*]], %{{.*}} = gpu.shuffle xor %[[X0]], %c1 + // CHECK: %[[X1:.*]] = arith.addi %[[X0]], %[[Y0]] + // CHECK: %[[Y1:.*]], %{{.*}} = gpu.shuffle xor %[[X1]], %c2 + // CHECK: %[[X2:.*]]
= arith.addi %[[X1]], %[[Y1]] + // CHECK: %[[Y2:.*]], %{{.*}} = gpu.shuffle xor %[[X2]], %c4 + // CHECK: %[[X3:.*]] = arith.addi %[[X2]], %[[Y2]] + // CHECK: %[[Y3:.*]], %{{.*}} = gpu.shuffle xor %[[X3]], %c8 + // CHECK: %[[X4:.*]] = arith.addi %[[X3]], %[[Y3]] + // CHECK: %[[Y4:.*]], %{{.*}} = gpu.shuffle xor %[[X4]], %c16 + // CHECK: %[[X5:.*]] = arith.addi %[[X4]], %[[Y4]] + // CHECK: %[[Y5:.*]] = vector.extract %arg0[0] + // CHECK: %[[X6:.*]] = arith.addi %[[Y5]], %[[X5]] + // CHECK: %[[RESULT:.*]] = vector.broadcast %[[X6]] + %result = vector.multi_reduction <add>, %dist, %arg0 + {"gml-st-distribution-label" = "warp"} [1] + : vector<1x32xi32> to vector<1xi32> + + // CHECK: return %[[RESULT]] + func.return %result : vector<1xi32> +} + +// ----- + +// CHECK-LABEL: func @vector_reduce_mul +func.func @vector_reduce_mul( + %arg0: vector<1xf32>, + %arg1: vector<1xf32> +) -> vector<1xf32> { + %lane = gpu.lane_id + %tile = gml_st.tile [%lane] [1] [1] : !gml_st.tile<1> + %dist = gml_st.distribute %arg1 into[%tile] + : vector<1xf32> into vector<1x32xf32>[!gml_st.tile<1>] + + // CHECK: arith.mulf + %result = vector.multi_reduction <mul>, %dist, %arg0 + {"gml-st-distribution-label" = "warp"} [1] + : vector<1x32xf32> to vector<1xf32> + func.return %result : vector<1xf32> +} + +// ----- + +// CHECK-LABEL: func @vector_reduce_mul_int +func.func @vector_reduce_mul_int( + %arg0: vector<1xi32>, + %arg1: vector<1xi32> +) -> vector<1xi32> { + %lane = gpu.lane_id + %tile = gml_st.tile [%lane] [1] [1] : !gml_st.tile<1> + %dist = gml_st.distribute %arg1 into[%tile] + : vector<1xi32> into vector<1x32xi32>[!gml_st.tile<1>] + + // CHECK: arith.muli + %result = vector.multi_reduction <mul>, %dist, %arg0 + {"gml-st-distribution-label" = "warp"} [1] + : vector<1x32xi32> to vector<1xi32> + func.return %result : vector<1xi32> +} + +// ----- + +// CHECK-LABEL: func @vector_reduce_small +func.func @vector_reduce_small( + %arg0: vector<1xf32>, + %arg1: vector<1xf32> +) -> vector<1xf32> { + + %lane = gpu.lane_id + %tile = gml_st.tile [%lane] [1] [1] : !gml_st.tile<1> + %dist = gml_st.distribute %arg1 into[%tile] + : vector<1xf32> into vector<1x4xf32>[!gml_st.tile<1>] + + // CHECK: %[[X0:.*]] = vector.extract %arg1[0] + // CHECK: %[[Y0:.*]], %{{.*}} = gpu.shuffle xor %[[X0]], %c1 + // CHECK: %[[X1:.*]] = arith.addf %[[X0]], %[[Y0]] + // CHECK: %[[Y1:.*]], %{{.*}} = gpu.shuffle xor %[[X1]], %c2 + // CHECK: %[[X2:.*]] = arith.addf %[[X1]], %[[Y1]] + // CHECK: %[[Y2:.*]] = vector.extract %arg0[0] + // CHECK: %[[X3:.*]] = arith.addf %[[Y2]], %[[X2]] + // CHECK: %[[RESULT:.*]] = vector.broadcast %[[X3]] + %result = vector.multi_reduction <add>, %dist, %arg0 + {"gml-st-distribution-label" = "warp"} [1] + : vector<1x4xf32> to vector<1xf32> + + // CHECK: return %[[RESULT]] + func.return %result : vector<1xf32> +} + +// ----- + +// CHECK-LABEL: func @vector_reduce_fp16 +func.func @vector_reduce_fp16( + %arg0: vector<1xf16>, + %arg1: vector<1xf16> +) -> vector<1xf16> { + + %lane = gpu.lane_id + %tile = gml_st.tile [%lane] [1] [1] : !gml_st.tile<1> + %dist = gml_st.distribute %arg1 into[%tile] + : vector<1xf16> into vector<1x2xf16>[!gml_st.tile<1>] + + // CHECK: %[[X0:.*]] = vector.extract %arg1[0] + // CHECK: %[[A0:.*]] = arith.bitcast %[[X0]] + // CHECK: %[[B0:.*]] = arith.extui %[[A0]] + // CHECK: %[[C0:.*]], %{{.*}} = gpu.shuffle xor %[[B0]], %c1 + // CHECK: %[[D0:.*]] = arith.trunci %[[C0]] + // CHECK: %[[Y0:.*]] = arith.bitcast %[[D0]] + // CHECK: %[[X1:.*]] = arith.maxf %[[X0]], %[[Y0]] + // CHECK: %[[Y1:.*]] = vector.extract %arg0[0] + //
CHECK: %[[X2:.*]] = arith.maxf %[[Y1]], %[[X1]] + // CHECK: %[[RESULT:.*]] = vector.broadcast %[[X2]] + %result = vector.multi_reduction <maxf>, %dist, %arg0 + {"gml-st-distribution-label" = "warp"} [1] + : vector<1x2xf16> to vector<1xf16> + + // CHECK: return %[[RESULT]] + func.return %result : vector<1xf16> +} + +// ----- + +#stride1 = strided<[1], offset: ?> + +// CHECK-LABEL: func @gpu_launch +func.func @gpu_launch() -> memref<64xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %cst = arith.constant dense<0.0> : vector<1xf32> + %0 = memref.alloc() : memref<64xf32> + // CHECK: gpu.launch + gml_st.parallel (%arg1) = (%c0) to (%c64) step (%c4) distribution ("block") { + %1 = memref.subview %0[%arg1] [4] [1] + : memref<64xf32> to memref<4xf32, #stride1> + gml_st.parallel (%arg2) = (%c0) to (%c4) step (%c1) distribution ("warp") { + %2 = memref.subview %1[%arg2] [1] [1] + : memref<4xf32, #stride1> to memref<1xf32, #stride1> + + %init = vector.broadcast %cst : vector<1xf32> to vector<1x32xf32> + %3 = gml_st.parallel (%arg3) = (%c0) to (%c32) step (%c1) + outs (%out_ = %init: vector<1x32xf32>) distribution ("thread") { + %tile = gml_st.tile [0, %arg3] [1, 1] [1, 1] : !gml_st.tile<1x1> + %elem = arith.constant dense<1.0> : vector<1x1xf32> + gml_st.set_yield %elem into %out_[%tile] + : vector<1x1xf32> into vector<1x32xf32>[!gml_st.tile<1x1>] + } : vector<1x32xf32> + + // CHECK-NOT: vector.multi_reduction + %sum = vector.multi_reduction <add>, %3, %cst [1] + : vector<1x32xf32> to vector<1xf32> + vector.transfer_write %sum, %2[%c0] {in_bounds = [true]} + : vector<1xf32>, memref<1xf32, #stride1> + gml_st.set_yield + } + gml_st.set_yield + } + return %0 : memref<64xf32> +} + +// ----- + +func.func @transform_only_warp_level_multi_reduction(%in: vector<4x10xi32>) + -> i32 { + %acc = arith.constant 0 : i32 + %result = vector.multi_reduction <add>, %in, %acc + {"gml-st-distribution-level" = "not-warp"} [0, 1] : vector<4x10xi32> to i32 + func.return %result : i32 +} + +// CHECK-LABEL: @transform_only_warp_level_multi_reduction +// CHECK: vector.multi_reduction <add>, %[[IN:.*]], %[[ACC:.*]] +// CHECK-SAME: {"gml-st-distribution-level" = "not-warp"} [0, 1] diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/lhlo/lhlo-fuse-linalg.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/lhlo/lhlo-fuse-linalg.mlir deleted file mode 100644 index be1db0627ea..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/lhlo/lhlo-fuse-linalg.mlir +++ /dev/null @@ -1,429 +0,0 @@ -// RUN: mlir-hlo-opt -lhlo-fuse-linalg %s -split-input-file | FileCheck %s --dump-input=always -// RUN: mlir-hlo-opt -lhlo-fuse-linalg=tile-sizes=2,3 %s -split-input-file | FileCheck %s -check-prefix=TILED -// RUN: mlir-hlo-opt -lhlo-fuse-linalg=use-parallel-loops %s -split-input-file | FileCheck %s -check-prefix=PLOOP - -#map0 = affine_map<(d0, d1) -> (d0, d1)> -#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], - iterator_types = ["parallel", "parallel"]} -func.func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>, - %summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) { - %temp_result = memref.alloc() : memref<6x6xf32> - linalg.generic #pointwise_2d_trait - ins(%summand_1, %summand_2 : memref<6x6xf32>, memref<6x6xf32>) - outs(%temp_result : memref<6x6xf32>) { - ^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32): - %out = arith.addf %summand_1_in, %summand_2_in : f32 -
linalg.yield %out : f32 - } - linalg.generic #pointwise_2d_trait - ins(%temp_result, %multiplier : memref<6x6xf32>, memref<6x6xf32>) - outs(%result : memref<6x6xf32>) { - ^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32): - %out = arith.mulf %temp_result_in, %multiplier_in : f32 - linalg.yield %out : f32 - } - memref.dealloc %temp_result : memref<6x6xf32> - func.return -} -// CHECK-LABEL: func @fusion -// CHECK: %[[C1:.*]] = arith.constant 1 -// CHECK-NOT: linalg.generic -// CHECK: scf.for {{.*}} step %[[C1]] -// CHECK: scf.for {{.*}} step %[[C1]] -// CHECK-NOT: scf.for -// CHECK: linalg.generic -// CHECK: addf -// CHECK: linalg.generic -// CHECK: mulf - -// TILED-LABEL: func @fusion -// TILED-DAG: %[[C2:.*]] = arith.constant 2 -// TILED-DAG: %[[C3:.*]] = arith.constant 3 -// TILED-NOT: linalg.generic -// TILED: scf.for {{.*}} step %[[C2]] -// TILED: scf.for {{.*}} step %[[C3]] -// TILED-NOT: scf.for -// TILED: linalg.generic -// TILED: addf -// TILED: linalg.generic -// TILED: mulf - -// PLOOP-LABEL: func @fusion -// PLOOP-NOT: linalg.generic -// PLOOP: scf.parallel -// PLOOP-NOT: scf.parallel -// PLOOP: linalg.generic -// PLOOP: addf -// PLOOP: linalg.generic -// PLOOP: mulf - -// ----- - -func.func @fusion_of_three(%arg0: memref<100x10xf32>, - %arg1: memref<100xf32>, - %arg2: memref<100x10xf32>) { - %0 = memref.alloc() : memref<100x10xf32> - linalg.generic { - indexing_maps = [affine_map<(d0, d1) -> (d0)>, - affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"]} - ins(%arg1 : memref<100xf32>) - outs(%0 : memref<100x10xf32>) { - ^bb0(%arg3: f32, %arg4: f32): - linalg.yield %arg3 : f32 - } - %1 = memref.alloc() : memref<100x10xf32> - linalg.generic { - indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"]} - ins(%arg0, %0 : memref<100x10xf32>, memref<100x10xf32>) - outs(%1 : memref<100x10xf32>) { - ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): - %2 = arith.subf %arg3, %arg4 : f32 - linalg.yield %2 : f32 - } - memref.dealloc %0 : memref<100x10xf32> - linalg.generic { - indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"]} - ins(%1 : memref<100x10xf32>) - outs(%arg2 : memref<100x10xf32>) { - ^bb0(%arg3: f32, %arg4: f32): - %2 = math.exp %arg3 : f32 - linalg.yield %2 : f32 - } - memref.dealloc %1 : memref<100x10xf32> - func.return -} -// CHECK-LABEL: func @fusion -// CHECK: %[[C1:.*]] = arith.constant 1 : -// CHECK-NOT: linalg.generic -// CHECK: scf.for {{.*}} step %[[C1]] -// CHECK: scf.for {{.*}} step %[[C1]] -// CHECK-NOT: scf.for -// CHECK: linalg.generic -// CHECK: linalg.generic -// CHECK: subf -// CHECK: linalg.generic -// CHECK: exp - -// TILED-LABEL: func @fusion_of_three -// TILED-DAG: %[[C2:.*]] = arith.constant 2 -// TILED-DAG: %[[C3:.*]] = arith.constant 3 -// TILED-NOT: linalg.generic -// TILED: scf.for {{.*}} step %[[C2]] -// TILED: scf.for {{.*}} step %[[C3]] -// TILED-NOT: scf.for -// TILED: linalg.generic -// TILED: linalg.generic -// TILED: subf -// TILED: linalg.generic -// TILED: exp - -// PLOOP-LABEL: func @fusion_of_three -// PLOOP-NOT: linalg.generic -// PLOOP: scf.parallel -// PLOOP-NOT: scf.parallel -// PLOOP: linalg.generic -// PLOOP: linalg.generic -// PLOOP: subf -// PLOOP: linalg.generic -// PLOOP: exp - -// ----- - -#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -#pointwise_4d_trait = {indexing_maps = [#map0, #map0, #map0], - 
iterator_types = ["parallel", "parallel", "parallel", - "parallel"]} -func.func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>, - %summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) { - %temp_result = memref.alloc() : memref<6x6x6x6xf32> - linalg.generic #pointwise_4d_trait - ins(%summand_1, %summand_2 : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>) - outs(%temp_result : memref<6x6x6x6xf32>) { - ^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32): - %out = arith.addf %summand_1_in, %summand_2_in : f32 - linalg.yield %out : f32 - } - linalg.generic #pointwise_4d_trait - ins(%temp_result, %multiplier : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>) - outs(%result : memref<6x6x6x6xf32>) { - ^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32): - %out = arith.mulf %temp_result_in, %multiplier_in : f32 - linalg.yield %out : f32 - } - memref.dealloc %temp_result : memref<6x6x6x6xf32> - func.return -} -// CHECK-LABEL: func @fusion_4d -// CHECK: %[[C1:.*]] = arith.constant 1 -// CHECK-NOT: linalg.generic -// CHECK: scf.for {{.*}} step %[[C1]] -// CHECK: scf.for {{.*}} step %[[C1]] -// CHECK: scf.for {{.*}} step %[[C1]] -// CHECK: scf.for {{.*}} step %[[C1]] -// CHECK-NOT: scf.for -// CHECK: linalg.generic -// CHECK: addf -// CHECK: linalg.generic -// CHECK: mulf - -// TILED-LABEL: func @fusion_4d -// TILED-DAG: %[[C2:.*]] = arith.constant 2 -// TILED-DAG: %[[C3:.*]] = arith.constant 3 -// TILED-NOT: linalg.generic -// TILED: scf.for {{.*}} step %[[C2]] -// TILED: scf.for {{.*}} step %[[C3]] -// TILED-NOT: scf.for -// TILED: linalg.generic -// TILED: addf -// TILED: linalg.generic -// TILED: mulf - -// PLOOP-LABEL: func @fusion_4d -// PLOOP-NOT: linalg.generic -// PLOOP: scf.parallel -// PLOOP-NOT: scf.parallel -// PLOOP: linalg.generic -// PLOOP: addf -// PLOOP: linalg.generic -// PLOOP: mulf - -// ----- - -#map0 = affine_map<(d0, d1) -> (d0, d1)> -#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], - iterator_types = ["parallel", "parallel"]} -func.func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>, - %summand_2: memref<6x6xf32>) -> memref<6x6xf32> { - %temp_result = memref.alloc() : memref<6x6xf32> - linalg.generic #pointwise_2d_trait - ins(%summand_1, %summand_2 : memref<6x6xf32>, memref<6x6xf32>) - outs(%temp_result : memref<6x6xf32>) { - ^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32): - %out = arith.addf %summand_1_in, %summand_2_in : f32 - linalg.yield %out : f32 - } - %result = memref.alloc() : memref<6x6xf32> - linalg.generic #pointwise_2d_trait - ins(%temp_result, %multiplier : memref<6x6xf32>, memref<6x6xf32>) - outs(%result : memref<6x6xf32>) { - ^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32): - %out = arith.mulf %temp_result_in, %multiplier_in : f32 - linalg.yield %out : f32 - } - memref.dealloc %temp_result : memref<6x6xf32> - func.return %result : memref<6x6xf32> -} - -// CHECK-LABEL: func @fusion -// CHECK: %[[C1:.*]] = arith.constant 1 -// CHECK-NOT: linalg.generic -// CHECK: scf.for {{.*}} step %[[C1]] -// CHECK: scf.for {{.*}} step %[[C1]] -// CHECK-NOT: scf.for -// CHECK: linalg.generic -// CHECK: addf -// CHECK: linalg.generic -// CHECK: mulf - -// TILED-LABEL: func @fusion -// TILED-DAG: %[[C2:.*]] = arith.constant 2 -// TILED-DAG: %[[C3:.*]] = arith.constant 3 -// TILED-NOT: linalg.generic -// TILED: scf.for {{.*}} step %[[C2]] -// TILED: scf.for {{.*}} step %[[C3]] -// TILED-NOT: scf.for -// TILED: linalg.generic -// TILED: addf -// TILED: 
linalg.generic -// TILED: mulf - -// PLOOP-LABEL: func @fusion -// PLOOP-NOT: linalg.generic -// PLOOP: scf.parallel -// PLOOP-NOT: scf.parallel -// PLOOP: linalg.generic -// PLOOP: addf -// PLOOP: linalg.generic -// PLOOP: mulf - -// ----- - -func.func @view_result(%arg0: memref, %arg1: memref, %arg2: index) - -> memref<*xf32> { - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %1 = memref.alloc(%arg2) : memref - linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, - affine_map<(d0) -> (d0)>], - iterator_types = ["parallel"]} - ins(%arg0 : memref) outs(%1 : memref) { - ^bb0(%arg3: f32, %arg4: f32): - %13 = math.absf %arg3 : f32 - linalg.yield %13 : f32 - } - %2 = memref.reshape %1(%arg1) - : (memref, memref) -> memref<*xf32> - func.return %2 : memref<*xf32> -} - -// CHECK-LABEL: func @view_result -// CHECK: %[[C1:.*]] = arith.constant 1 -// CHECK-NOT: linalg.generic -// CHECK: scf.for {{.*}} step %[[C1]] -// CHECK-NOT: scf.for -// CHECK: linalg.generic -// CHECK: math.absf -// CHECK: memref.reshape - -// TILED-LABEL: func @view_result -// TILED-DAG: %[[C2:.*]] = arith.constant 2 -// TILED-NOT: linalg.generic -// TILED: scf.for {{.*}} step %[[C2]] -// TILED-NOT: scf.for -// TILED: linalg.generic -// TILED: math.absf -// TILED: memref.reshape - - -// PLOOP-LABEL: func @view_result -// PLOOP-NOT: linalg.generic -// PLOOP: scf.parallel -// PLOOP-NOT: scf.parallel -// PLOOP: linalg.generic -// PLOOP: math.absf -// PLOOP: memref.reshape - - - -// ----- - -// Confirm that tiling information is passed through RegionBranchOpInterfaces. -// This test also uses memref.reshape, just to have a value to return through -// the if statement. -func.func @branching_result(%arg0: memref, %arg1: memref, %arg2: index) - -> memref<*xf32> { - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %1 = memref.alloc(%arg2) : memref - linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, - affine_map<(d0) -> (d0)>], - iterator_types = ["parallel"]} - ins(%arg0 : memref) outs(%1 : memref) { - ^bb0(%arg3: f32, %arg4: f32): - %13 = math.absf %arg3 : f32 - linalg.yield %13 : f32 - } - %true = arith.constant 1 : i1 - %3 = scf.if %true -> memref<*xf32> { - %2 = memref.reshape %1(%arg1) - : (memref, memref) -> memref<*xf32> - scf.yield %2 : memref<*xf32> - } else { - %2 = memref.reshape %1(%arg1) - : (memref, memref) -> memref<*xf32> - scf.yield %2 : memref<*xf32> - } - func.return %3 : memref<*xf32> -} - -// CHECK-LABEL: func @branching_result -// CHECK: %[[C1:.*]] = arith.constant 1 -// CHECK-NOT: linalg.generic -// CHECK: scf.for {{.*}} step %[[C1]] -// CHECK-NOT: scf.for -// CHECK: linalg.generic -// CHECK: math.absf -// CHECK: scf.if -// CHECK: memref.reshape -// CHECK: scf.yield -// CHECK: else -// CHECK: memref.reshape -// CHECK: scf.yield - -// TILED-LABEL: func @branching_result -// TILED-DAG: %[[C2:.*]] = arith.constant 2 -// TILED-NOT: linalg.generic -// TILED: scf.for {{.*}} step %[[C2]] -// TILED-NOT: scf.for -// TILED: linalg.generic -// TILED: math.absf -// TILED: scf.if -// TILED: memref.reshape -// TILED: scf.yield -// TILED: else -// TILED: memref.reshape -// TILED: scf.yield - -// PLOOP-LABEL: func @branching_result -// PLOOP-NOT: linalg.generic -// PLOOP: scf.parallel -// PLOOP-NOT: scf.parallel -// PLOOP: linalg.generic -// PLOOP: math.absf -// PLOOP: scf.if -// PLOOP: memref.reshape -// PLOOP: scf.yield -// PLOOP: else -// PLOOP: memref.reshape -// PLOOP: scf.yield - -// ----- - -// Confirm that tiling information is passed through tensor_load, tensor.cast -// 
and memref_to_tensor operations. -func.func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>) - -> memref { - %c1 = arith.constant 1 : index - %1 = memref.alloc() : memref<32xf32> - linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, - affine_map<(d0) -> (d0)>], - iterator_types = ["parallel"]} - ins(%arg0 : memref<32xf32>) outs(%1 : memref<32xf32>) { - ^bb0(%arg3: f32, %arg4: f32): - %13 = math.absf %arg3 : f32 - linalg.yield %13 : f32 - } - %2 = bufferization.to_tensor %1 : memref<32xf32> - %3 = tensor.cast %2 : tensor<32xf32> to tensor - %4 = bufferization.to_memref %3 : memref - func.return %4 : memref -} - -// CHECK-LABEL: func @tensor_ops -// CHECK: %[[C1:.*]] = arith.constant 1 -// CHECK-NOT: linalg.generic -// CHECK: scf.for {{.*}} step %[[C1]] -// CHECK-NOT: scf.for -// CHECK: linalg.generic -// CHECK: math.absf -// CHECK: bufferization.to_tensor -// CHECK: tensor.cast -// CHECK: bufferization.to_memref - -// TILED-LABEL: func @tensor_ops -// TILED-DAG: %[[C2:.*]] = arith.constant 2 -// TILED-NOT: linalg.generic -// TILED: scf.for {{.*}} step %[[C2]] -// TILED-NOT: scf.for -// TILED: linalg.generic -// TILED: math.absf -// TILED: bufferization.to_tensor -// TILED: tensor.cast -// TILED: bufferization.to_memref - - -// PLOOP-LABEL: func @tensor_ops -// PLOOP-NOT: linalg.generic -// PLOOP: scf.parallel -// PLOOP-NOT: scf.parallel -// PLOOP: linalg.generic -// PLOOP: math.absf -// PLOOP: bufferization.to_tensor -// PLOOP: tensor.cast -// PLOOP: bufferization.to_memref diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/lhlo/lhlo-legalize-select-and-scatter.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/lhlo/lhlo-legalize-select-and-scatter.mlir index 15d161192fb..8a1f6c0225f 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/lhlo/lhlo-legalize-select-and-scatter.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/lhlo/lhlo-legalize-select-and-scatter.mlir @@ -131,7 +131,7 @@ func.func @select_and_scatter(%arg: memref<112x112xf32>, // CHECK: } // Use selected ivs to load element from the SRC buffer. 
-// CHECK: %[[SRC_ELEM:.*]] = "memref.load"(%[[SRC_BUF]], %[[II]], %[[JJ]]) : (memref<56x56xf32>, index, index) -> f32 +// CHECK: %[[SRC_ELEM:.*]] = "memref.load"(%[[SRC_BUF]], %[[II]], %[[JJ]]) {nontemporal = false} : (memref<56x56xf32>, index, index) -> f32 // Update of RESULT[SELECTED_I, SELECTED_J] should be done atomically, because // it may happen that several other threads select the same IVs if the windows diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/lhlo/ops.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/lhlo/ops.mlir index fbcc9c8c15e..b6a0604d625 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/lhlo/ops.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/lhlo/ops.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | mlir-hlo-opt | FileCheck %s +// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | mlir-hlo-opt -split-input-file | FileCheck %s // ----- @@ -78,7 +78,7 @@ func.func @invalid_alltoall(%input0: memref<2xf32>, %output: memref<8xf32>) { // ----- func.func @invalid_alltoall(%input0: memref<2xf32>, %output: memref<8xf32>) { - // expected-error@+1 {{replica groups should be a rank 2 tensor of 64 bit integers}} + // expected-error@+1 {{replica groups should be a rank 2 tensor}} "lmhlo.all_to_all"(%input0, %output) {channel_id = #mhlo.channel_handle, constrain_layout = false, replica_groups = dense<0> : tensor<1xi64>, @@ -1081,7 +1081,7 @@ func.func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>, // CHECK-LABEL: func @valid_custom_call func.func @valid_custom_call(%arg0:memref<1xf32>, %arg1:memref<1xf32>) -> () { - "lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) { + "lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) ({}) { backend_config = "", call_target_name = "foo", has_side_effects = false, @@ -1100,7 +1100,7 @@ func.func @valid_custom_call(%arg0:memref<1xf32>, %arg1:memref<1xf32>) -> () { func.func @invalid_custom_call(%arg0:memref<1xf32>, %arg1:memref<1xf32>) -> () { // expected-error @+1 {{number of entries in the mapping for args (1) should match the number of args for the operation (2)}} - "lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) { + "lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) ({}) { backend_config = "", call_target_name = "foo", has_side_effects = false, @@ -1119,7 +1119,7 @@ func.func @invalid_custom_call(%arg0:memref<1xf32>, %arg1:memref<1xf32>) -> () { func.func @invalid_custom_call(%arg0:memref<1xf32>, %arg1:memref<1xf32>) -> () { // expected-error @+1 {{number of entries in the mapping for results (1) should match the number of results for the operation (2)}} - "lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) { + "lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) ({}) { backend_config = "", call_target_name = "foo", has_side_effects = false, @@ -1138,7 +1138,7 @@ func.func @invalid_custom_call(%arg0:memref<1xf32>, %arg1:memref<1xf32>) -> () { func.func @invalid_custom_call(%arg0:memref<1xf32>, %arg1:memref<1xf32>) -> () { // expected-error @+1 {{entry 0 cannot appear more than once in the mapping for args}} - "lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) { + "lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) ({}) { backend_config = "", call_target_name = "foo", has_side_effects = false, @@ -1157,7 +1157,7 @@ func.func @invalid_custom_call(%arg0:memref<1xf32>, %arg1:memref<1xf32>) -> () { func.func @invalid_custom_call(%arg0:memref<1xf32>, %arg1:memref<1xf32>) -> () { // expected-error @+1 {{entry 1 cannot appear more than once in the mapping for results}} - 
"lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) { + "lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) ({}) { backend_config = "", call_target_name = "foo", has_side_effects = false, @@ -1176,7 +1176,7 @@ func.func @invalid_custom_call(%arg0:memref<1xf32>, %arg1:memref<1xf32>) -> () { func.func @invalid_custom_call(%arg0:memref<1xf32>, %arg1:memref<1xf32>) -> () { // expected-error @+1 {{entries in mapping for args must be >= 0 and less than target's number of args (4)}} - "lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) { + "lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) ({}) { backend_config = "", call_target_name = "foo", has_side_effects = false, @@ -1195,7 +1195,7 @@ func.func @invalid_custom_call(%arg0:memref<1xf32>, %arg1:memref<1xf32>) -> () { func.func @invalid_custom_call(%arg0:memref<1xf32>, %arg1:memref<1xf32>) -> () { // expected-error @+1 {{entries in mapping for results must be >= 0 and less than target's number of results (3)}} - "lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) { + "lmhlo.custom_call"(%arg0, %arg0, %arg1, %arg1) ({}) { backend_config = "", call_target_name = "foo", has_side_effects = false, @@ -1226,3 +1226,63 @@ func.func @invalid_float_abs_call(%input:memref<2xf32>, %result:memref<2xf64>) - "lmhlo.abs"(%input, %result) : (memref<2xf32>, memref<2xf64>) -> () func.return } + +// ----- + +// CHECK-LABEL: func @send_memrefs +func.func @send_memrefs(%arg0: memref<3xf32>) -> !mhlo.token { + // CHECK: lmhlo.send + // CHECK: channel_handle = #mhlo.channel_handle + // CHECK: frontend_attributes = {foo = "bar"} + // CHECK: is_host_transfer = true + %token = "lmhlo.send"(%arg0) { + channel_handle = #mhlo.channel_handle, + frontend_attributes = {foo = "bar"}, + is_host_transfer = true + } : (memref<3xf32>) -> (!mhlo.token) + return %token : !mhlo.token +} + +// ----- + +// CHECK-LABEL: func @send_done +func.func @send_done(%arg0: !mhlo.token) { + // CHECK: lmhlo.send_done + // CHECK: channel_handle = #mhlo.channel_handle + // CHECK: is_host_transfer = true + "lmhlo.send_done"(%arg0) { + channel_handle = #mhlo.channel_handle, + is_host_transfer = true + } : (!mhlo.token) -> () + return +} + +// ----- + +// CHECK-LABEL: func @recv_memrefs +func.func @recv_memrefs(%arg0: memref<3xf32>) -> !mhlo.token { + // CHECK: lmhlo.recv + // CHECK: channel_handle = #mhlo.channel_handle + // CHECK: frontend_attributes = {foo = "bar"} + // CHECK: is_host_transfer = true + %token = "lmhlo.recv"(%arg0) { + channel_handle = #mhlo.channel_handle, + frontend_attributes = {foo = "bar"}, + is_host_transfer = true + } : (memref<3xf32>) -> (!mhlo.token) + return %token : !mhlo.token +} + +// ----- + +// CHECK-LABEL: func @recv_done +func.func @recv_done(%arg0: !mhlo.token) { + // CHECK: lmhlo.recv_done + // CHECK: channel_handle = #mhlo.channel_handle + // CHECK: is_host_transfer = true + "lmhlo.recv_done"(%arg0) { + channel_handle = #mhlo.channel_handle, + is_host_transfer = true + } : (!mhlo.token) -> () + return +} \ No newline at end of file diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/attrs.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/attrs.mlir new file mode 100644 index 00000000000..f42240a1628 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/attrs.mlir @@ -0,0 +1,35 @@ +// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file -allow-unregistered-dialect | FileCheck %s + +// CHECK-LABEL: parameter_replication +func.func @parameter_replication(%arg0: tensor {mhlo.parameter_replication = [true]}, %arg1: tuple, tuple>> 
{mhlo.parameter_replication = [false, true]}) -> tensor { + return %arg0 : tensor +} + +// ----- + +// CHECK-LABEL: parameter_replication +func.func @parameter_replication_empty(%arg0: tensor {mhlo.parameter_replication = []}, %arg1: tuple, tuple>> {mhlo.parameter_replication = []}) -> tensor { + return %arg0 : tensor +} + +// ----- + +// CHECK-LABEL: parameter_replication +func.func @parameter_replication_single_false(%arg0: tensor {mhlo.parameter_replication = [false]}, %arg1: tuple, tuple>> {mhlo.parameter_replication = [false]}) -> tensor { + return %arg0 : tensor +} + +// ----- + +// CHECK-LABEL: parameter_replication +func.func @parameter_replication_single_true(%arg0: tensor {mhlo.parameter_replication = [true]}, %arg1: tuple, tuple>> {mhlo.parameter_replication = [true]}) -> tensor { + return %arg0 : tensor +} + + +// ----- + +// expected-error@+1 {{parameter_replication: arg 0 has 1 leaf_buffers, but parameter_replication expects 2}} +func.func @parameter_replication_num_leaf_buffer_mismatch(%arg0: tensor {mhlo.parameter_replication = [true, false]}) -> tensor { + return %arg0 : tensor +} \ No newline at end of file diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/bitcast.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/bitcast.mlir new file mode 100644 index 00000000000..4786e01d4a9 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/bitcast.mlir @@ -0,0 +1,71 @@ +// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s + +// CHECK-LABEL:@no_layout +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x4xf32> +// CHECK-NOT: bitcast +// CHECK: return %[[ARG0]] +func.func @no_layout(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> { + %0 = "mhlo.bitcast"(%arg) : (tensor<2x4xf32>) -> tensor<2x4xf32> + func.return %0 : tensor<2x4xf32> +} + +// ----- + +// CHECK-LABEL:@same_layout +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x4xf32> +// CHECK-NOT: bitcast +// CHECK: return %[[ARG0]] +func.func @same_layout(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> { + %0 = "mhlo.bitcast"(%arg) { + source_layout = dense<[1, 0]> : tensor<2xindex>, + result_layout = dense<[1, 0]> : tensor<2xindex> + }: (tensor<2x4xf32>) -> tensor<2x4xf32> + func.return %0 : tensor<2x4xf32> +} + +// ----- + +// CHECK-LABEL:@different_layout +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x4xf32> +// CHECK: bitcast +func.func @different_layout(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> { + %0 = "mhlo.bitcast"(%arg) { + source_layout = dense<[0, 1]> : tensor<2xindex>, + result_layout = dense<[1, 0]> : tensor<2xindex> + }: (tensor<2x4xf32>) -> tensor<2x4xf32> + func.return %0 : tensor<2x4xf32> +} + +// ----- + +// CHECK-LABEL:@source_layout_only +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x4xf32> +// CHECK: bitcast +func.func @source_layout_only(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> { + %0 = "mhlo.bitcast"(%arg) { + source_layout = dense<[0, 1]> : tensor<2xindex> + }: (tensor<2x4xf32>) -> tensor<2x4xf32> + func.return %0 : tensor<2x4xf32> +} + +// ----- + +// CHECK-LABEL:@result_layout_only +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x4xf32> +// CHECK: bitcast +func.func @result_layout_only(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> { + %0 = "mhlo.bitcast"(%arg) { + result_layout = dense<[1, 0]> : tensor<2xindex> + }: (tensor<2x4xf32>) -> tensor<2x4xf32> + func.return %0 : tensor<2x4xf32> +} + +// ----- + +// CHECK-LABEL:@type_cast +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x4xf32> +// CHECK: bitcast +func.func @type_cast(%arg: 
tensor<2x4xf32>) -> tensor<2x4xi32> { + %0 = "mhlo.bitcast"(%arg): (tensor<2x4xf32>) -> tensor<2x4xi32> + func.return %0 : tensor<2x4xi32> +} diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir index 166c5833b39..4ec3dd133a4 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt %s -pass-pipeline='func.func(canonicalize)' | FileCheck %s +// RUN: mlir-hlo-opt %s -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s // CHECK-LABEL: add_fold func.func @add_fold() -> tensor<4xi64> { @@ -252,6 +252,39 @@ func.func @min_fold_float() -> tensor<6xf32> { func.return %2 : tensor<6xf32> } +// CHECK-LABEL: clamp_scalar_fold +func.func @clamp_scalar_fold() -> tensor<5xi64> { + %0 = mhlo.constant dense<149> : tensor + %1 = mhlo.constant dense<[-1, 100, 200, 0, 149]> : tensor<5xi64> + %2 = mhlo.constant dense<0> : tensor + // CHECK{LITERAL}: mhlo.constant dense<[0, 100, 149, 0, 149]> + // CHECK-NOT: mhlo.clamp + %3 = mhlo.clamp %2, %1, %0 : (tensor, tensor<5xi64>, tensor) -> tensor<5xi64> + return %3 : tensor<5xi64> +} + +// CHECK-LABEL: clamp_fold +func.func @clamp_fold() -> tensor<5xi64> { + %0 = mhlo.constant dense<[149, 101, -1, 30, 50]> : tensor<5xi64> + %1 = mhlo.constant dense<[-1, 100, 200, 0, 149]> : tensor<5xi64> + %2 = mhlo.constant dense<[0, 10, -10, 10, -100]> : tensor<5xi64> + // CHECK{LITERAL}: mhlo.constant dense<[0, 100, -1, 10, 50]> + // CHECK-NOT: mhlo.clamp + %3 = mhlo.clamp %2, %1, %0 : (tensor<5xi64>, tensor<5xi64>, tensor<5xi64>) -> tensor<5xi64> + return %3 : tensor<5xi64> +} + +// CHECK-LABEL: clamp_fold_float +func.func @clamp_fold_float() -> tensor<6xf32> { + %0 = mhlo.constant dense<[5.0, 66.0, 0xFFFFFFFF, -2.0, 0xFFFFFFFF, 6.0]> : tensor<6xf32> + %1 = mhlo.constant dense<[5.0, 3.0, 2.0, 0xFFFFFFFF, 0xFFFFFFFF, 4.0]> : tensor<6xf32> + %2 = mhlo.constant dense<[5.0, 1.0, 1.0, 0xFFFFFFFF, 0xFFFFFFFF, 5.0]> : tensor<6xf32> + // CHECK{LITERAL}: mhlo.constant dense<[5.000000e+00, 3.000000e+00, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 5.000000e+00] + // CHECK-NOT: mhlo.clamp + %3 = mhlo.clamp %2, %1, %0 : (tensor<6xf32>, tensor<6xf32>, tensor<6xf32>) -> tensor<6xf32> + return %3 : tensor<6xf32> +} + // CHECK-LABEL: concatenate_noop func.func @concatenate_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> { // CHECK-SAME: [[ARG:%.+]]: tensor<4xi32> @@ -781,6 +814,19 @@ func.func @dynamic_broadcast_in_dim_to_same_shape_4(%arg0: tensor<*xf32>) -> ten func.return %3 : tensor } +// CHECK-LABEL: func @dynamic_broadcast_in_dim_all_dims_non_expanding +func.func @dynamic_broadcast_in_dim_all_dims_non_expanding(%arg0: tensor<*xf32>, %arg1: tensor<1xindex>) -> tensor { + // CHECK-SAME: %[[ARG:.*]]: tensor<*xf32> + %1 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { + broadcast_dimensions = dense<0> : tensor<1xi64>, + known_expanding_dimensions = dense<> : tensor<0xi64>, + known_nonexpanding_dimensions = dense<0> : tensor<1xi64> + } : (tensor<*xf32>, tensor<1xindex>) -> tensor + // CHECK: %[[RES:.*]] = tensor.cast %[[ARG]] : tensor<*xf32> to tensor + // CHECK: return %[[RES]] : tensor + func.return %1 : tensor +} + // CHECK-LABEL: func @broadcast_in_dim_constant_fold_0d func.func @broadcast_in_dim_constant_fold_0d() -> tensor<1x64x224x224xf32> { %cst = mhlo.constant dense<0.000000e+00> : tensor 
@@ -837,7 +883,7 @@ func.func @dynamic_iota_is_static(%arg0 : tensor<1xindex>) -> tensor<4xi32> { // CHECK-LABEL: @dynamic_iota_is_static_constant_arg func.func @dynamic_iota_is_static_constant_arg(%arg0: tensor<5xi32>) -> tensor { - // CHECK-NOTE: mhlo.dynamic_iota + // CHECK-NOT: mhlo.dynamic_iota // CHECK: [[RESULT:%.*]] = "mhlo.iota" // CHECK: [[CAST:%.*]] = tensor.cast [[RESULT]] // CHECK: return [[CAST]] @@ -1457,14 +1503,25 @@ func.func @fold_select_vector(%arg0 : tensor<4xf32>, %arg1 : tensor<4xf32>) -> t } // CHECK-LABEL: func @simplify_not_as_select_pred( -// CHECK-SAME: [[ARGV0:%[a-zA-Z0-9_]+]]: tensor<4xi1> -// CHECK-SAME: [[ARGV1:%[a-zA-Z0-9_]+]]: tensor<4xf32> -// CHECK-SAME: [[ARGV2:%[a-zA-Z0-9_]+]]: tensor<4xf32> func.func @simplify_not_as_select_pred(%arg0 : tensor<4xi1>, %arg1 : tensor<4xf32>, %arg2 : tensor<4xf32>) -> tensor<4xf32> { %0 = "mhlo.not"(%arg0) : (tensor<4xi1>) -> tensor<4xi1> %1 = "mhlo.select"(%0, %arg1, %arg2) : (tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - // CHECK: mhlo.select [[ARGV0]], [[ARGV2]], [[ARGV1]] func.return %1 : tensor<4xf32> + + // CHECK: %[[R:.*]] = mhlo.select %arg0, %arg2, %arg1 + // CHECK: return %[[R]] +} + +// CHECK-LABEL: func @simplify_broadcasted_not_as_select_pred( +func.func @simplify_broadcasted_not_as_select_pred(%arg0 : tensor<1xi1>, %arg1 : tensor<4xf32>, %arg2 : tensor<4xf32>) -> tensor<4xf32> { + %0 = "mhlo.not"(%arg0) : (tensor<1xi1>) -> tensor<1xi1> + %1 = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<[0]> : tensor<1xi64> } : (tensor<1xi1>) -> tensor<4xi1> + %2 = "mhlo.select"(%1, %arg1, %arg2) : (tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + func.return %2 : tensor<4xf32> + + // CHECK: %[[B:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xi1>) -> tensor<4xi1> + // CHECK: %[[R:.*]] = mhlo.select %[[B]], %arg2, %arg1 + // CHECK: return %[[R]] } // CHECK-LABEL: gather_to_slice @@ -1838,6 +1895,15 @@ func.func @not_fold_rsqrt_const_zero() -> tensor<4xf32> { func.return %1 : tensor<4xf32> } +// CHECK-LABEL: func @fold_abs +func.func @fold_abs() -> tensor<4xf32> { + %0 = mhlo.constant dense<-1.0> : tensor<4xf32> + %1 = "mhlo.abs"(%0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: mhlo.constant dense<1.000000e+00> : tensor<4xf32> + // CHECK-NOT: mhlo.abs + func.return %1 : tensor<4xf32> +} + // CHECK-LABEL: func @fold_sine func.func @fold_sine() -> tensor<4xf32> { %0 = mhlo.constant dense<2.0> : tensor<4xf32> @@ -1883,6 +1949,24 @@ func.func @fold_logistic() -> tensor<4xf32> { func.return %1 : tensor<4xf32> } +// CHECK-LABEL: func @fold_log +func.func @fold_log() -> tensor<4xf32> { + %0 = mhlo.constant dense<2.0> : tensor<4xf32> + %1 = "mhlo.log"(%0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: mhlo.constant dense<0.693147182> : tensor<4xf32> + // CHECK-NOT: mhlo.log + func.return %1 : tensor<4xf32> +} + +// CHECK-LABEL: func @not_fold_log_neg_constants +func.func @not_fold_log_neg_constants() -> tensor<4xf32> { + %0 = mhlo.constant dense<-1.0> : tensor<4xf32> + %1 = "mhlo.log"(%0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: mhlo.constant dense<-1.000000e+00> : tensor<4xf32> + // CHECK: mhlo.log + func.return %1 : tensor<4xf32> +} + // CHECK-LABEL: func @fold_if_true( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]] @@ -2596,11 +2680,68 @@ func.func public @reshape_splat_of_bools() -> tensor<2x1xi1> { return %1 : tensor<2x1xi1> } -// CHECK-LABEL: @simplify_dynamic_gather -func.func 
@simplify_dynamic_gather(%arg0: tensor<375682x256xf16>, %arg1: tensor<16x64xi64>) -> tensor<16x64x256xf16> { +// CHECK-LABEL: @simplify_dynamic_gather_i64 +func.func @simplify_dynamic_gather_i64(%arg0: tensor<375682x256xf16>, %arg1: tensor<16x64xi64>) -> tensor<16x64x256xf16> { %0 = "arith.constant"() {value = dense<[1, 256]> : tensor<2xi64>} : () -> tensor<2xi64> %1 = "mhlo.dynamic_gather"(%arg0, %arg1, %0) {dimension_numbers = #mhlo.gather, indices_are_sorted = false} : (tensor<375682x256xf16>, tensor<16x64xi64>, tensor<2xi64>) -> tensor<16x64x256xf16> // CHECK: %[[RET:.+]] = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = #mhlo.gather, indices_are_sorted = false, slice_sizes = dense<[1, 256]> : tensor<2xi64>} : (tensor<375682x256xf16>, tensor<16x64xi64>) -> tensor<16x64x256xf16> // CHECK: return %[[RET]] return %1 : tensor<16x64x256xf16> } + +// CHECK-LABEL: @simplify_dynamic_gather_i32 +func.func @simplify_dynamic_gather_i32(%arg0: tensor<375682x256xf16>, %arg1: tensor<16x64xi64>) -> tensor<16x64x256xf16> { + %0 = "arith.constant"() {value = dense<[1, 256]> : tensor<2xi32>} : () -> tensor<2xi32> + %1 = "mhlo.dynamic_gather"(%arg0, %arg1, %0) {dimension_numbers = #mhlo.gather, indices_are_sorted = false} : (tensor<375682x256xf16>, tensor<16x64xi64>, tensor<2xi32>) -> tensor<16x64x256xf16> + // CHECK: %[[RET:.+]] = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = #mhlo.gather, indices_are_sorted = false, slice_sizes = dense<[1, 256]> : tensor<2xi64>} : (tensor<375682x256xf16>, tensor<16x64xi64>) -> tensor<16x64x256xf16> + // CHECK: return %[[RET]] + return %1 : tensor<16x64x256xf16> +} + +// CHECK-LABEL: @fold_reduce_window +func.func @fold_reduce_window(%arg0: tensor<1x1x20xf32>) -> tensor<1x1x20xf32> { + %cst_0 = mhlo.constant dense<0.000000e+00> : tensor + %r = "mhlo.reduce_window"(%arg0, %cst_0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %s = mhlo.add %arg1, %arg2 : tensor + mhlo.return %s : tensor + }) { + padding = dense<0> : tensor<3x2xi64>, + window_dimensions = dense<1> : tensor<3xi64>, + window_strides = dense<1> : tensor<3xi64> + } : (tensor<1x1x20xf32>, tensor) -> tensor<1x1x20xf32> + func.return %r : tensor<1x1x20xf32> + + // CHECK: return %arg0 : tensor<1x1x20xf32> +} + +// CHECK-LABEL: @simplify_real_dynamic_slice_to_slice +func.func @simplify_real_dynamic_slice_to_slice(%arg0: tensor) -> tensor<1x4xf32> { + %0 = mhlo.constant dense<[0, 0]> : tensor<2xi32> + %1 = mhlo.constant dense<[1, 4]> : tensor<2xi32> + %2 = mhlo.constant dense<[1, 1]> : tensor<2xi32> + %3 = mhlo.real_dynamic_slice %arg0, %0, %1, %2 : (tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x4xf32> + // CHECK: %[[RESULT:.*]] = "mhlo.slice"(%arg0) + // CHECK-DAG-SAME: start_indices = dense<[0, 0]> : tensor<2xi64> + // CHECK-DAG-SAME: limit_indices = dense<[1, 4]> : tensor<2xi64> + // CHECK-DAG-SAME: strides = dense<[1, 1]> : tensor<2xi64>} + // CHECK: return %[[RESULT]] : tensor<1x4xf32> + return %3 : tensor<1x4xf32> +} + +// CHECK-LABEL: @simplify_real_dynamic_slice_to_dynamic_slice +func.func @simplify_real_dynamic_slice_to_dynamic_slice(%arg0: tensor, %arg1: tensor<2xi32>) -> tensor<1x4xf32> { + %0 = mhlo.constant dense<[1, 4]> : tensor<2xi32> + %1 = mhlo.add %arg1, %0 : tensor<2xi32> + %2 = mhlo.constant dense<[1, 1]> : tensor<2xi32> + %3 = mhlo.real_dynamic_slice %arg0, %arg1, %1, %2 : (tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x4xf32> + return %3 : tensor<1x4xf32> + // CHECK: [[START_INDEX_0_1D:%.*]] = "mhlo.slice"(%arg1) {limit_indices = dense<1> : tensor<1xi64>, 
start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> + // CHECK-NEXT: [[START_INDEX_0_0D:%.*]] = mhlo.reshape [[START_INDEX_0_1D]] : (tensor<1xi32>) -> tensor + // CHECK-NEXT: [[START_INDEX_1_1D:%.*]] = "mhlo.slice"(%arg1) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32> + // CHECK-NEXT: [[START_INDEX_1_0D:%.*]] = mhlo.reshape [[START_INDEX_1_1D]] : (tensor<1xi32>) -> tensor + // CHECK-NEXT: [[RESULT:%.*]] = "mhlo.dynamic_slice"(%arg0, [[START_INDEX_0_0D]], [[START_INDEX_1_0D]]) { + // CHECK-SAME: slice_sizes = dense<[1, 4]> : tensor<2xi64> + // CHECK-SAME: } : (tensor, tensor, tensor) -> tensor<1x4xf32> + // CHECK-NEXT: return [[RESULT]] : tensor<1x4xf32> +} diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/concatenate.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/concatenate.mlir index 50574cb0a24..fc937c1657e 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/concatenate.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/concatenate.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt %s -pass-pipeline='func.func(canonicalize)' | FileCheck %s +// RUN: mlir-hlo-opt %s -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s // CHECK-LABEL: func @single_operand // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/convert.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/convert.mlir index 5e3ee42651d..8bd55f3e070 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/convert.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/convert.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func.func(canonicalize)' | FileCheck %s +// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s // ----- diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/convolution.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/convolution.mlir index 4c9cddeaa76..452168c923c 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/convolution.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/convolution.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func.func(canonicalize)' | FileCheck %s +// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s // CHECK-LABEL: @dot_general_is_dot func.func @dot_general_is_dot(%arg0: tensor<5x6xf32>, %arg1: tensor<6x?xf32>) -> tensor<5x?xf32> { @@ -101,4 +101,28 @@ func.func @conv_grouped_is_dot_transpose_out(%arg0: tensor<5x4xf32>, %arg1: tens %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [b, f]x[i, o]->[f, b], window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []} {batch_group_count = 1 : i64, feature_group_count = 2 : i64, precision_config = [#mhlo, #mhlo]} : (tensor<5x4xf32>, tensor<2x6xf32>) -> tensor<6x5xf32> // CHECK: return %[[OUT]] return %0 : tensor<6x5xf32> -} \ No newline at end of file +} + +// ----- + +// CHECK-LABEL: @dynamic_conv2d_padding +func.func @dynamic_conv2d_padding(%arg0: tensor<1x8x8x32x207xf32>, %arg1: tensor<3x3x32x207x16xf32>) -> tensor<32x1x8x8x16xf32> { + 
%pad = arith.constant dense<[2, 0, 1, 1]> : tensor<4xi64> + // CHECK: %[[CONV:.+]] = mhlo.convolution + // CHECK-SAME: pad = {{\[\[}}2, 0], [1, 1]] + %0 = "mhlo.dynamic_conv"(%arg0, %arg1, %pad) {batch_group_count = 1 : i64, + dimension_numbers = #mhlo.conv, feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, precision_config = [#mhlo, #mhlo], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : + (tensor<1x8x8x32x207xf32>, tensor<3x3x32x207x16xf32>, tensor<4xi64>) -> tensor<32x1x8x8x16xf32> + // CHECK: return %[[CONV]] + func.return %0 : tensor<32x1x8x8x16xf32> +} diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/custom_call.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/custom_call.mlir index 7a5abf99886..ca31f880b7f 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/custom_call.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/custom_call.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func.func(canonicalize)' | FileCheck %s +// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s // CHECK-LABEL:@noeffect func.func @noeffect(%arg0: tensor<8xf32>) -> tensor<8xf32> { diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/folder_limit.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/folder_limit.mlir index a18184f364a..1d62442436c 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/folder_limit.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/folder_limit.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt %s -pass-pipeline='func.func(canonicalize)' | FileCheck %s +// RUN: mlir-hlo-opt %s -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s // CHECK-LABEL: func @add_lhs_zero func.func @add_lhs_zero(%lhs: tensor<65537xi8>) -> tensor<65537xi8> { diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/reduce.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/reduce.mlir index 4fb78183951..6c38f5111dd 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/reduce.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/reduce.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func.func(canonicalize)' | FileCheck %s +// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s // ----- diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/reshape.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/reshape.mlir index 8aad9915069..f15829fd2ad 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/reshape.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/reshape.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func.func(canonicalize)' | FileCheck %s +// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s // CHECK-LABEL: func @const_fold_collapse_to_scalar func.func @const_fold_collapse_to_scalar() -> tensor { diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/reverse.mlir 
b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/reverse.mlir index 5afe47a7551..33880f6adb1 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/reverse.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/reverse.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt %s -pass-pipeline='func.func(canonicalize)' | FileCheck %s +// RUN: mlir-hlo-opt %s -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s // CHECK-LABEL: func @noop // CHECK-SAME: (%[[ARG0:.*]]: tensor<1x2xf32>) diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/scatter.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/scatter.mlir index cd8d3e869ed..1a7a6f343e9 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/scatter.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/scatter.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func.func(canonicalize)' | FileCheck %s +// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s // Folding this case would explode the IR func.func @scatter_fold_explosion() -> tensor<512x1x6400x6400xf32> { diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/transpose.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/transpose.mlir index 719c1f21c3d..38f2cce820d 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/transpose.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/transpose.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func.func(canonicalize)' | FileCheck %s +// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s // CHECK-LABEL: func @transpose_splat_constant func.func @transpose_splat_constant() -> tensor<5x10xf32> { diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/tuple.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/tuple.mlir index 769018d2ad1..77f31ab6eb0 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/tuple.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/tuple.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func.func(canonicalize)' | FileCheck %s +// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s // CHECK-LABEL: func @fold_access // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/expand_hlo_tuples.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/expand_hlo_tuples.mlir index 6f2c0036118..30b97f1fca9 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/expand_hlo_tuples.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/expand_hlo_tuples.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt %s -split-input-file -expand-hlo-tuples='entry-function=main' | FileCheck %s +// RUN: mlir-hlo-opt %s -split-input-file -expand-hlo-tuples='entry-function=main' -allow-unregistered-dialect | FileCheck %s // Check if the `expand-hlo-tuples` pass adds the right variable to return_op and function return type. 
func.func @main(%arg0: tensor<1x1xf32>, %arg1: tensor<1x8x8x16xf32>) -> tuple, tensor<1xf32>> { @@ -22,10 +22,10 @@ func.func @main(%arg0: tuple, tensor<1xf32>>) -> tuple, tensor<1xf32>> } -// CHECK: func @main(%[[VAL_0:.*]]: tensor<1024xf32>, %[[VAL_1:.*]]: tensor<1xf32>) -> (tensor<1024xf32>, tensor<1xf32>) { -// CHECK: %[[VAL_2:.*]] = mhlo.tuple %[[VAL_0]], %[[VAL_1]] : tuple, tensor<1xf32>> -// CHECK: return %[[VAL_0]], %[[VAL_1]] : tensor<1024xf32>, tensor<1xf32> -// CHECK: } +// CHECK: func @main(%[[VAL_0:.*]]: tensor<1024xf32>, %[[VAL_1:.*]]: tensor<1xf32>) -> (tensor<1024xf32>, tensor<1xf32>) { +// CHECK: %[[VAL_2:.*]] = mhlo.tuple %[[VAL_0]], %[[VAL_1]] : tuple, tensor<1xf32>> +// CHECK: return %[[VAL_0]], %[[VAL_1]] : tensor<1024xf32>, tensor<1xf32> +// CHECK: } // ----- @@ -35,5 +35,18 @@ func.func @main() -> tuple<> { } // CHECK-LABEL: func @main() { -// CHECK: return{{$}} -// CHECK: } +// CHECK: return{{$}} +// CHECK: } + +// ----- + +func.func @main() -> tuple, tensor<1xi32>> { + %0 = "test.dummy"() : () -> tuple, tensor<1xi32>> + func.return %0 : tuple, tensor<1xi32>> +} + +// CHECK-LABEL: func @main() +// CHECK: %[[TUPLE:.*]] = "test.dummy"() +// CHECK: %[[T0:.*]] = mhlo.get_tuple_element %[[TUPLE]][0] +// CHECK: %[[T1:.*]] = mhlo.get_tuple_element %[[TUPLE]][1] +// CHECK: return %[[T0]], %[[T1]] diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/expand_ops_simplifier.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/expand_ops_simplifier.mlir new file mode 100644 index 00000000000..58f60bb795c --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/expand_ops_simplifier.mlir @@ -0,0 +1,57 @@ +// RUN: mlir-hlo-opt %s -split-input-file -mhlo-expand-ops-simplifier | FileCheck %s + +func.func @main(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> { + %0 = mhlo.constant dense<0.000000e+00> : tensor + %1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %2 = "mhlo.compare"(%arg3, %arg4) {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor, tensor) -> tensor + "mhlo.return"(%2) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + %2 = mhlo.add %arg3, %arg4 : tensor + "mhlo.return"(%2) : (tensor) -> () + }) { + window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64> + } : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor) -> tensor<10x24x24x64xf32> + func.return %1 : tensor<10x24x24x64xf32> +} + +// CHECK-LABEL: func @main +// CHECK-SAME: %[[OPERAND:.*]]: tensor<10x24x24x64xf32>, +// CHECK-SAME: %[[SOURCE:.*]]: tensor<10x12x12x64xf32> +// CHECK-SAME: -> tensor<10x24x24x64xf32> +// CHECK-DAG: %[[NEG_1:.*]] = mhlo.constant dense<-1> : tensor +// CHECK-DAG: %[[INIT:.*]] = mhlo.constant dense<0.000000e+00> : tensor<10x24x24x64xf32> +// CHECK-DAG: %[[C0:.*]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK: %[[IOTA_0:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<10x24x24x64xi64> +// CHECK: %[[IOTA_1:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<10x24x24x64xi64> +// CHECK: %[[IOTA_2:.*]] = "mhlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<10x24x24x64xi64> +// CHECK: %[[IOTA_3:.*]] = "mhlo.iota"() {iota_dimension = 3 : i64} : () -> tensor<10x24x24x64xi64> +// CHECK: %[[REDUCE_WINDOW:.*]]:5 = "mhlo.reduce_window"(%[[OPERAND]], %[[IOTA_0]], %[[IOTA_1]], %[[IOTA_2]], %[[IOTA_3]], %[[C0]], %[[NEG_1]], %[[NEG_1]], %[[NEG_1]], %[[NEG_1]]) ({ +// 
CHECK: ^bb0(%[[VAL_10:.*]]: tensor, %[[VAL_11:.*]]: tensor, %[[VAL_12:.*]]: tensor, %[[VAL_13:.*]]: tensor, %[[VAL_14:.*]]: tensor, %[[VAL_15:.*]]: tensor, %[[VAL_16:.*]]: tensor, %[[VAL_17:.*]]: tensor, %[[VAL_18:.*]]: tensor, %[[VAL_19:.*]]: tensor): +// CHECK: %[[VAL_20:.*]] = mhlo.compare NE, %[[VAL_11]], %[[NEG_1]] +// CHECK: %[[VAL_21:.*]] = mhlo.compare NE, %[[VAL_16]], %[[NEG_1]] +// CHECK: %[[VAL_22:.*]] = mhlo.not %[[VAL_21]] : tensor +// CHECK: %[[VAL_23:.*]] = mhlo.compare GE, %[[VAL_10]], %[[VAL_15]] +// CHECK: %[[VAL_24:.*]] = mhlo.and %[[VAL_23]], %[[VAL_20]] : tensor +// CHECK: %[[VAL_25:.*]] = mhlo.or %[[VAL_24]], %[[VAL_22]] : tensor +// CHECK: %[[SELECTED_0:.*]] = mhlo.select %[[VAL_25]], %[[VAL_10]], %[[VAL_15]] +// CHECK: %[[SELECTED_1:.*]] = mhlo.select %[[VAL_25]], %[[VAL_11]], %[[VAL_16]] +// CHECK: %[[SELECTED_2:.*]] = mhlo.select %[[VAL_25]], %[[VAL_12]], %[[VAL_17]] +// CHECK: %[[SELECTED_3:.*]] = mhlo.select %[[VAL_25]], %[[VAL_13]], %[[VAL_18]] +// CHECK: %[[SELECTED_4:.*]] = mhlo.select %[[VAL_25]], %[[VAL_14]], %[[VAL_19]] +// CHECK: mhlo.return %[[SELECTED_0]], %[[SELECTED_1]], %[[SELECTED_2]], %[[SELECTED_3]], %[[SELECTED_4]] +// CHECK: }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<10x24x24x64xf32>, tensor<10x24x24x64xi64>, tensor<10x24x24x64xi64>, tensor<10x24x24x64xi64>, tensor<10x24x24x64xi64>, tensor, tensor, tensor, tensor, tensor) -> (tensor<10x12x12x64xf32>, tensor<10x12x12x64xi64>, tensor<10x12x12x64xi64>, tensor<10x12x12x64xi64>, tensor<10x12x12x64xi64>) +// CHECK: %[[RESHAPE_0:.*]] = mhlo.reshape %[[REDUCE_WINDOW]]#1 : (tensor<10x12x12x64xi64>) -> tensor<10x12x12x64x1xi64> +// CHECK: %[[RESHAPE_1:.*]] = mhlo.reshape %[[REDUCE_WINDOW]]#2 : (tensor<10x12x12x64xi64>) -> tensor<10x12x12x64x1xi64> +// CHECK: %[[RESHAPE_2:.*]] = mhlo.reshape %[[REDUCE_WINDOW]]#3 : (tensor<10x12x12x64xi64>) -> tensor<10x12x12x64x1xi64> +// CHECK: %[[RESHAPE_3:.*]] = mhlo.reshape %[[REDUCE_WINDOW]]#4 : (tensor<10x12x12x64xi64>) -> tensor<10x12x12x64x1xi64> +// CHECK: %[[CONCAT:.*]] = "mhlo.concatenate"(%[[RESHAPE_0]], %[[RESHAPE_1]], %[[RESHAPE_2]], %[[RESHAPE_3]]) {dimension = 4 : i64} +// CHECK: %[[SCATTER:.*]] = "mhlo.scatter"(%[[INIT]], %[[CONCAT]], %[[SOURCE]]) ({ +// CHECK: ^bb0(%[[VAL_38:.*]]: tensor, %[[VAL_39:.*]]: tensor): +// CHECK: %[[UPDATE:.*]] = mhlo.add %[[VAL_38]], %[[VAL_39]] : tensor +// CHECK: mhlo.return %[[UPDATE]] : tensor +// CHECK: }) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter, unique_indices = false} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64x4xi64>, tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> +// CHECK: return %[[SCATTER]] : tensor<10x24x24x64xf32> +// CHECK: } \ No newline at end of file diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-lhlo.mlir index 129c5b2f1b8..afae8cf8ab3 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-lhlo.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-lhlo.mlir @@ -619,7 +619,8 @@ func.func @transpose(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK-LABEL: func @custom_call // CHECK-SAME:([[ARG0:%.*]]: memref<2x2xf32>, [[ARG1:%.*]]: memref<2x3xf32>) func.func @custom_call(%arg0: tensor<2x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<4x4xf16> { - // CHECK: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}) {backend_config = "", 
call_target_name = "foo", has_side_effect = false, operand_segment_sizes = array} + // CHECK: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}) ({ + // CHECK-NEXT: }) {backend_config = "", call_target_name = "foo", has_side_effect = false, operand_segment_sizes = array} %result = "mhlo.custom_call"(%arg0, %arg1) {backend_config = "", call_target_name = "foo", has_side_effect = false} : (tensor<2x2xf32>, tensor<2x3xf32>) -> tensor<4x4xf16> @@ -631,7 +632,8 @@ func.func @custom_call(%arg0: tensor<2x2xf32>, %arg1: tensor<2x3xf32>) -> tensor // CHECK-LABEL: func @custom_call_multiout // CHECK-SAME:([[ARG0:%.*]]: memref<2x2xf32>, [[ARG1:%.*]]: memref<2x3xf32>) func.func @custom_call_multiout(%arg0: tensor<2x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<4x4xf16> { - // CHECK: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}, %{{.*}}) {backend_config = "", call_target_name = "foo", has_side_effect = false, operand_segment_sizes = array} + // CHECK: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}, %{{.*}}) ({ + // CHECK-NEXT: }) {backend_config = "", call_target_name = "foo", has_side_effect = false, operand_segment_sizes = array} %temp:2 = "mhlo.custom_call"(%arg0, %arg1) {backend_config = "", call_target_name = "foo", has_side_effect = false} : (tensor<2x2xf32>, tensor<2x3xf32>) -> (tensor<4x4xf16>, tensor<4x4xf16>) diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir index 1f4be3c535e..45150a6f0ee 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir @@ -21,7 +21,6 @@ func.func @float_add(%lhs: tensor<2x2xf32>, // CHECK-PRIMITIVE: linalg.map // CHECK-PRIMITIVE: arith.addf - // CHECK-PRIMITIVE: linalg.yield %0 = "mhlo.add"(%lhs, %rhs) {someattr} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> func.return %0 : tensor<2x2xf32> @@ -29,6 +28,27 @@ func.func @float_add(%lhs: tensor<2x2xf32>, // ----- +// CHECK-LABEL: func @float_add_dynamic_encoding +// CHECK-PRIMITIVE-LABEL: func @float_add_dynamic_encoding +func.func @float_add_dynamic_encoding( + %lhs: tensor<2x?xf32, #mhlo.type_extensions>, + %rhs: tensor<2x?xf32, #mhlo.type_extensions>) + -> tensor<2x?xf32, #mhlo.type_extensions> { + // CHECK: linalg.generic + // CHECK: arith.addf + // CHECK: linalg.yield + + // CHECK-PRIMITIVE: linalg.map + // CHECK-PRIMITIVE: arith.addf + %0 = "mhlo.add"(%lhs, %rhs) {someattr} + : (tensor<2x?xf32, #mhlo.type_extensions>, + tensor<2x?xf32, #mhlo.type_extensions>) + -> tensor<2x?xf32, #mhlo.type_extensions> + func.return %0 : tensor<2x?xf32, #mhlo.type_extensions> +} + +// ----- + // CHECK-LABEL: integer_add // CHECK-PRIMITIVE-LABEL: integer_add func.func @integer_add(%lhs: tensor<2x2xi32>, @@ -213,19 +233,13 @@ func.func @complex_rsqrt(%operand: tensor<2x2xcomplex>) func.func @float_cbrt(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> { %tensor_result = "mhlo.cbrt"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK: %[[THIRD:.+]] = arith.constant 0.333333343 // CHECK: ^{{[a-z0-9_]*}} // CHECK-SAME: %[[IN:[a-zA-Z0-9_]*]]: f32 - // CHECK-SAME: %[[OUT:[a-zA-Z0-9_]*]]: f32 - // CHECK: %[[ABS:.+]] = math.absf %[[IN]] - // CHECK: %[[POW:.+]] = math.powf %[[ABS]], %[[THIRD]] - // CHECK: %[[RESULT:.+]] = math.copysign %[[POW]], %[[IN]] + // CHECK: %[[RESULT:.+]] = math.cbrt %[[IN]] // CHECK: linalg.yield %[[RESULT]] // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE: 
math.absf - // CHECK-PRIMITIVE: math.powf - // CHECK-PRIMITIVE: math.copysign + // CHECK-PRIMITIVE: math.cbrt func.return %tensor_result : tensor<2x2xf32> } @@ -283,11 +297,10 @@ func.func @float_abs(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic // CHECK-SAME: {someattr} // CHECK: math.absf - // CHECK-PRIMITIVE: linalg.map - // CHECK-PRIMITIVE-NEXT: ins( - // CHECK-PRIMITIVE-NEXT: outs( + // CHECK-PRIMITIVE: linalg.map { math.absf } + // CHECK-PRIMITIVE-SAME: ins( + // CHECK-PRIMITIVE-SAME: outs( // CHECK-PRIMITIVE-SAME: {someattr} - // CHECK-PRIMITIVE: math.absf %0 = "mhlo.abs"(%arg0) {someattr} : (tensor<2x2xf32>) -> tensor<2x2xf32> func.return %0 : tensor<2x2xf32> } @@ -657,8 +670,8 @@ func.func @float_cmp_totalorder(%lhs: tensor<2x2xbf16>, // CHECK-PRIMITIVE-DAG: %[[C0:.*]] = arith.constant 0 : i16 // CHECK-PRIMITIVE-DAG: %[[C32767:.*]] = arith.constant 32767 : i16 // CHECK-PRIMITIVE: linalg.map -// CHECK-PRIMITIVE-NEXT: ins( -// CHECK-PRIMITIVE-NEXT: outs( +// CHECK-PRIMITIVE-SAME: ins( +// CHECK-PRIMITIVE-SAME: outs( // CHECK-PRIMITIVE-NEXT: (%[[LHS_IN:[a-zA-Z0-9]*]]: bf16, %[[RHS_IN:.*]]: bf16) { // CHECK-PRIMITIVE-NEXT: %[[LHS_INT:.*]] = arith.bitcast %[[LHS_IN]] : bf16 to i16 // CHECK-PRIMITIVE-NEXT: %[[LHS_CMP:.*]] = arith.cmpi slt, %[[LHS_INT]], %[[C0]] : i16 @@ -858,12 +871,9 @@ func.func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>, // CHECK-PRIMITIVE-LABEL: func @select // CHECK-PRIMITIVE: tensor.empty() : tensor<2x2xf32> -// CHECK-PRIMITIVE: linalg.map -// CHECK-PRIMITIVE-NEXT: ins( -// CHECK-PRIMITIVE-NEXT: outs( -// CHECK-PRIMITIVE-NEXT: (%[[PRED_IN:[a-zA-Z0-9]*]]: i1, %[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32) { -// CHECK-PRIMITIVE-NEXT: %[[RESULT:.*]] = arith.select %[[PRED_IN]], %[[LHS_IN]], %[[RHS_IN]] : f32 -// CHECK-PRIMITIVE-NEXT: linalg.yield %[[RESULT]] : f32 +// CHECK-PRIMITIVE: linalg.map { arith.select } +// CHECK-PRIMITIVE-SAME: ins( +// CHECK-PRIMITIVE-SAME: outs( // ----- @@ -895,8 +905,8 @@ func.func @select_scalar_pred_dyn(%pred : tensor, %lhs: tensor<2x?xf32>, %rh // CHECK-PRIMITIVE-DAG: %[[DST:.*]] = tensor.empty(%[[DIM]]) // CHECK-PRIMITIVE-DAG: %[[PRED_ELEM:.*]] = tensor.extract %[[PRED]] // CHECK-PRIMITIVE: linalg.map -// CHECK-PRIMITIVE-NEXT: ins(%[[LHS]], %[[RHS]] : tensor<2x?xf32>, tensor<2x?xf32>) -// CHECK-PRIMITIVE-NEXT: outs(%[[DST]] : tensor<2x?xf32>) +// CHECK-PRIMITIVE-SAME: ins(%[[LHS]], %[[RHS]] : tensor<2x?xf32>, tensor<2x?xf32>) +// CHECK-PRIMITIVE-SAME: outs(%[[DST]] : tensor<2x?xf32>) // CHECK-PRIMITIVE-SAME: {someattr} // CHECK-PRIMITIVE: (%[[LHS_:.*]]: f32, %[[RHS_:.*]]: f32) { // CHECK-PRIMITIVE: %[[RES:.*]] = arith.select %[[PRED_ELEM]], %[[LHS_]], %[[RHS_]] : f32 @@ -904,6 +914,16 @@ func.func @select_scalar_pred_dyn(%pred : tensor, %lhs: tensor<2x?xf32>, %rh // ----- +// CHECK-LABEL: func @select_mixed +func.func @select_mixed(%pred: tensor<2x?xi1>, %lhs: tensor, + %rhs: tensor<2x2xf32>) -> tensor { + %0 = "mhlo.select"(%pred, %lhs, %rhs) + : (tensor<2x?xi1>, tensor, tensor<2x2xf32>) -> (tensor) + func.return %0 : tensor +} + +// ----- + // CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1, d2) -> ()> // CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK-LABEL: func @broadcast_scalar @@ -916,6 +936,13 @@ func.func @broadcast_scalar(%arg: tensor) -> tensor<4x2x1xf32> { // CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32): // CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 +// CHECK-PRIMITIVE-LABEL: func @broadcast_scalar +// CHECK-PRIMITIVE: tensor.empty() : 
tensor<4x2x1xf32> +// CHECK-PRIMITIVE: linalg.broadcast +// CHECK-PRIMITIVE-SAME: ins( +// CHECK-PRIMITIVE-SAME: outs( +// CHECK-PRIMITIVE-SAME: dimensions = [0, 1, 2] + // ----- // CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)> @@ -932,6 +959,13 @@ func.func @broadcast(%arg: tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> { // CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32): // CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 +// CHECK-PRIMITIVE-LABEL: func @broadcast +// CHECK-PRIMITIVE-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-PRIMITIVE: %[[DIM:.*]] = tensor.dim %{{.*}}, %[[C1]] : tensor<4x?x16xf32> +// CHECK-PRIMITIVE: %{{.*}} = tensor.empty(%[[DIM]]) : tensor<4x2x1x4x?x16xf32> +// CHECK-PRIMITIVE: linalg.broadcast +// CHECK-PRIMITIVE: dimensions = [0, 1, 2] + // ----- // CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, 0)> @@ -948,6 +982,14 @@ func.func @broadcast_in_dim(%operand: tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf3 // CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32): // CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 +// CHECK-PRIMITIVE-LABEL: func @broadcast_in_dim +// CHECK-PRIMITIVE: tensor.collapse_shape +// CHECK-PRIMITIVE: linalg.transpose +// CHECK-PRIMITIVE: permutation = [1, 0] +// CHECK-PRIMITIVE: tensor.empty() : tensor<7x10x6x4x5xf32> +// CHECK-PRIMITIVE: linalg.broadcast +// CHECK-PRIMITIVE: dimensions = [1, 2, 3] + // ----- // CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, 0)> @@ -966,6 +1008,15 @@ func.func @broadcast_in_dim_ui32(%operand: tensor<5x7x1xui32>) -> tensor<7x10x6x // CHECK-NEXT: linalg.yield %[[OPERAND]] : i32 // CHECK: builtin.unrealized_conversion_cast %[[RES]] : tensor<7x10x6x4x5xi32> to tensor<7x10x6x4x5xui32> +// CHECK-PRIMITIVE-LABEL: func @broadcast_in_dim_ui32 +// CHECK-PRIMITIVE: tensor.collapse_shape +// CHECK-PRIMITIVE: linalg.transpose +// CHECK-PRIMITIVE: permutation = [1, 0] +// CHECK-PRIMITIVE: tensor.empty() : tensor<7x10x6x4x5xi32> +// CHECK-PRIMITIVE: %[[RES:.*]] = linalg.broadcast +// CHECK-PRIMITIVE: dimensions = [1, 2, 3] +// CHECK-PRIMITIVE: builtin.unrealized_conversion_cast %[[RES]] : tensor<7x10x6x4x5xi32> to tensor<7x10x6x4x5xui32> + // ----- // CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1) -> (d0)> @@ -983,12 +1034,43 @@ func.func @broadcast_in_dim_with_one_to_one( // CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32): // CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 +// CHECK-PRIMITIVE-LABEL: func @broadcast_in_dim_with_one_to_one +// CHECK-PRIMITIVE-NOT: tensor.collapse_shape +// CHECK-PRIMITIVE-NOT: linalg.transpose +// CHECK-PRIMITIVE: linalg.broadcast +// CHECK-PRIMITIVE: dimensions = [1] + +// ----- + +// CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d0, d1)> +// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func @broadcast_in_dim_with_transpose +func.func @broadcast_in_dim_with_transpose( + %operand: tensor<2x3x4xf32>) -> tensor<3x4x2x5xf32> { + %0 = "mhlo.broadcast_in_dim"(%operand) + {broadcast_dimensions = dense<[2, 0, 1]> : tensor<3xi64>} + : (tensor<2x3x4xf32>) -> tensor<3x4x2x5xf32> + func.return %0 : tensor<3x4x2x5xf32> +} +// CHECK: tensor.empty() : tensor<3x4x2x5xf32> +// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] +// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32): +// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 + +// CHECK-PRIMITIVE-LABEL: func @broadcast_in_dim_with_transpose 
+// CHECK-PRIMITIVE: tensor.empty() : tensor<3x4x2xf32> +// CHECK-PRIMITIVE: linalg.transpose +// CHECK-PRIMITIVE: permutation = [1, 2, 0] +// CHECK-PRIMITIVE: tensor.empty() : tensor<3x4x2x5xf32> +// CHECK-PRIMITIVE: linalg.broadcast +// CHECK-PRIMITIVE: dimensions = [3] + // ----- // CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2) -> ()> // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK-LABEL: func @broadcast_scalar -func.func @broadcast_scalar(%operand: tensor) -> tensor<7x10x6xf32> { +// CHECK-LABEL: func @broadcast_in_dim_scalar +func.func @broadcast_in_dim_scalar(%operand: tensor) -> tensor<7x10x6xf32> { %0 = "mhlo.broadcast_in_dim"(%operand) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor) -> tensor<7x10x6xf32> @@ -999,6 +1081,11 @@ func.func @broadcast_scalar(%operand: tensor) -> tensor<7x10x6xf32> { // CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %{{.*}}: f32): // CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 +// CHECK-PRIMITIVE-LABEL: func @broadcast_in_dim_scalar +// CHECK-PRIMITIVE: tensor.empty() : tensor<7x10x6xf32> +// CHECK-PRIMITIVE: linalg.broadcast +// CHECK-PRIMITIVE: dimensions = [0, 1, 2] + // ----- // CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d0, d3, d2)> @@ -1044,11 +1131,23 @@ func.func @transpose_dynamic(%arg0: tensor) -> tensor // CHECK-PRIMITIVE: %[[D3:.*]] = tensor.dim %arg0, %[[C3]] // CHECK-PRIMITIVE: %[[INIT:.*]] = tensor.empty(%[[D1]], %[[D0]], %[[D3]]) : tensor // CHECK-PRIMITIVE: linalg.transpose -// CHECK-PRIMITIVE-NEXT: ins(%arg0 : tensor) -// CHECK-PRIMITIVE-NEXT: outs(%[[INIT]] : tensor) -// CHECK-PRIMITIVE-NEXT: permutation = [1, 0, 3, 2] +// CHECK-PRIMITIVE-SAME: ins(%arg0 : tensor) +// CHECK-PRIMITIVE-SAME: outs(%[[INIT]] : tensor) +// CHECK-PRIMITIVE-SAME: permutation = [1, 0, 3, 2] // CHECK-PRIMITIVE-SAME: {someattr} +func.func @transpose_unsigned(%arg0: tensor<2x2xui32>) -> tensor<2x2xui32> { + %0 = "mhlo.transpose"(%arg0) { + permutation = dense<[1, 0]> : tensor<2xi64>, + result_layout = dense<[0, 1]> : tensor<2xindex> + } : (tensor<2x2xui32>) -> tensor<2x2xui32> + return %0 : tensor<2x2xui32> +} + +// Regression test. Just check that unsigned ints lower successfully. 
+// CHECK-LABEL: func @transpose_unsigned +// CHECK-PRIMITIVE-LABEL: func @transpose_unsigned + // ----- // CHECK-LABEL: func @real_dynamic_slice @@ -2108,7 +2207,7 @@ func.func @dynamic_broadcast_in_dim(%shape: tensor<1xindex>) -> tensor { } : (tensor, tensor<1xindex>) -> tensor func.return %result : tensor } -// CHECK: [[CST:%.*]] = arith.constant +// CHECK: [[CST:%.*]] = arith.constant dense // CHECK: [[INIT:%.*]] = tensor.empty // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] @@ -2117,6 +2216,13 @@ func.func @dynamic_broadcast_in_dim(%shape: tensor<1xindex>) -> tensor { // CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32): // CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 +// CHECK-PRIMITIVE-LABEL: func @dynamic_broadcast_in_dim +// CHECK-PRIMITIVE: [[CST:%.*]] = arith.constant dense +// CHECK-PRIMITIVE: [[INIT:%.*]] = tensor.empty +// CHECK-PRIMITIVE: linalg.broadcast +// CHECK-PRIMITIVE-SAME: ins([[CST]] +// CHECK-PRIMITIVE-SAME: outs([[INIT]] + // ----- // CHECK: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> ()> @@ -2139,6 +2245,10 @@ func.func @dynamic_broadcast_in_dim(%scalar: tensor, %shape: tensor<2xindex // CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32): // CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 +// CHECK-PRIMITIVE-LABEL: func @dynamic_broadcast_in_dim +// CHECK-PRIMITIVE: tensor.empty +// CHECK-PRIMITIVE: linalg.broadcast + // ----- // CHECK: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2) -> (d1)> @@ -2161,9 +2271,15 @@ func.func @dynamic_broadcast_in_dim(%vector: tensor<42xf32>, %shape: tensor<3xin // CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32): // CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 +// CHECK-PRIMITIVE-LABEL: func @dynamic_broadcast_in_dim +// CHECK-PRIMITIVE: tensor.empty +// CHECK-PRIMITIVE: %[[RESULT:.*]] = linalg.broadcast +// CHECK-PRIMITIVE: tensor.cast %[[RESULT]] : tensor to tensor + // ----- // CHECK-LABEL: func @dynamic_broadcast_in_dim( +// CHECK-PRIMITIVE-LABEL: func @dynamic_broadcast_in_dim // Note: this test requires no checks. The tensor.empty verifier will // fail if the %shape i32 -> index cast is not performed properly. 
func.func @dynamic_broadcast_in_dim(%scalar: tensor, %shape: tensor<2xi32>) @@ -2197,6 +2313,13 @@ func.func @dynamic_broadcast_in_dim(%shape: tensor<1xindex>, %cst: tensor) // CHECK: [[RES:%.*]] = builtin.unrealized_conversion_cast [[GENERIC]] : tensor to tensor // CHECK: return [[RES]] : tensor +// CHECK-PRIMITIVE-LABEL: func @dynamic_broadcast_in_dim +// CHECK-PRIMITIVE: tensor.empty +// CHECK-PRIMITIVE: %[[BROADCASTED:.*]] = linalg.broadcast +// CHECK-PRIMITIVE: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[BROADCASTED]] +// CHECK-PRIMITIVE-SAME: tensor to tensor +// CHECK-PRIMITIVE: return %[[RES]] : tensor + // ----- // CHECK: #[[ARG_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (0, 0, d3, d4, 0, d6)> @@ -2204,30 +2327,10 @@ func.func @dynamic_broadcast_in_dim(%shape: tensor<1xindex>, %cst: tensor) // CHECK-LABEL: @dynamic_broadcast_in_dim // CHECK-SAME: %[[ARG:.*]]: tensor, %[[SHAPE:.*]]: tensor<7xindex> +// CHECK-PRIMITIVE-LABEL: func @dynamic_broadcast_in_dim +// CHECK-PRIMITIVE-SAME: %[[ARG:.*]]: tensor, %[[SHAPE:.*]]: tensor<7xindex> func.func @dynamic_broadcast_in_dim(%arg: tensor, %shape: tensor<7xindex>) -> tensor { - // CHECK-DAG: %[[C0:.*]] = arith.constant 0 - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 - // CHECK-DAG: %[[C2:.*]] = arith.constant 2 - // CHECK-DAG: %[[C3:.*]] = arith.constant 3 - // CHECK-DAG: %[[C4:.*]] = arith.constant 4 - // CHECK-DAG: %[[C5:.*]] = arith.constant 5 - // CHECK-DAG: %[[C6:.*]] = arith.constant 6 - // CHECK-DAG: %[[DIM0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] - // CHECK-DAG: %[[DIM1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] - // CHECK-DAG: %[[DIM2:.*]] = tensor.extract %[[SHAPE]][%[[C2]]] - // CHECK-DAG: %[[DIM3:.*]] = tensor.extract %[[SHAPE]][%[[C3]]] - // CHECK-DAG: %[[DIM4:.*]] = tensor.extract %[[SHAPE]][%[[C4]]] - // CHECK-DAG: %[[DIM5:.*]] = tensor.extract %[[SHAPE]][%[[C5]]] - // CHECK-DAG: %[[DIM6:.*]] = tensor.extract %[[SHAPE]][%[[C6]]] - // CHECK-DAG: %[[INIT:.*]] = tensor.empty(%[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]], %[[DIM4]], %[[DIM5]], %[[DIM6]]) : tensor - // CHECK: %[[RES:.*]] = linalg.generic { - // CHECK-SAME: indexing_maps = [#[[ARG_MAP]], #[[RES_MAP]]], - // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} - // CHECK-SAME: ins(%[[ARG]] : tensor) outs(%[[INIT]] : tensor) { - // CHECK: ^bb0(%[[ARG_:.*]]: f32, %{{.*}}: f32): - // CHECK: linalg.yield %[[ARG_]] - // CHECK: return %[[RES]] %result = "mhlo.dynamic_broadcast_in_dim"(%arg, %shape) { broadcast_dimensions = dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi64>, known_expanding_dimensions = dense<[0, 1]> : tensor<2xi64>, @@ -2236,6 +2339,55 @@ func.func @dynamic_broadcast_in_dim(%arg: tensor, func.return %result : tensor } +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 +// CHECK-DAG: %[[C3:.*]] = arith.constant 3 +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 +// CHECK-DAG: %[[C5:.*]] = arith.constant 5 +// CHECK-DAG: %[[C6:.*]] = arith.constant 6 +// CHECK-DAG: %[[DIM0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] +// CHECK-DAG: %[[DIM1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] +// CHECK-DAG: %[[DIM2:.*]] = tensor.extract %[[SHAPE]][%[[C2]]] +// CHECK-DAG: %[[DIM3:.*]] = tensor.extract %[[SHAPE]][%[[C3]]] +// CHECK-DAG: %[[DIM4:.*]] = tensor.extract %[[SHAPE]][%[[C4]]] +// CHECK-DAG: %[[DIM5:.*]] = tensor.extract %[[SHAPE]][%[[C5]]] +// CHECK-DAG: %[[DIM6:.*]] = tensor.extract %[[SHAPE]][%[[C6]]] +// 
CHECK-DAG: %[[INIT:.*]] = tensor.empty(%[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]], %[[DIM4]], %[[DIM5]], %[[DIM6]]) : tensor +// CHECK: %[[RES:.*]] = linalg.generic { +// CHECK-SAME: indexing_maps = [#[[ARG_MAP]], #[[RES_MAP]]], +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} +// CHECK-SAME: ins(%[[ARG]] : tensor) outs(%[[INIT]] : tensor) { +// CHECK: ^bb0(%[[ARG_:.*]]: f32, %{{.*}}: f32): +// CHECK: linalg.yield %[[ARG_]] +// CHECK: return %[[RES]] + +// CHECK-PRIMITIVE-DAG: %[[C0:.*]] = arith.constant 0 +// CHECK-PRIMITIVE-DAG: %[[C1:.*]] = arith.constant 1 +// CHECK-PRIMITIVE-DAG: %[[C2:.*]] = arith.constant 2 +// CHECK-PRIMITIVE-DAG: %[[C3:.*]] = arith.constant 3 +// CHECK-PRIMITIVE-DAG: %[[C4:.*]] = arith.constant 4 +// CHECK-PRIMITIVE-DAG: %[[C5:.*]] = arith.constant 5 +// CHECK-PRIMITIVE: tensor.cast %[[ARG]] +// CHECK-PRIMITIVE-SAME: tensor to tensor<1x1x?x?x1x42xf32> +// CHECK-PRIMITIVE: %[[COLLAPSED:.*]] = tensor.collapse_shape +// CHECK-PRIMITIVE-SAME{literal}: [[0, 1, 2], [3], [4, 5]] +// CHECK-PRIMITIVE-SAME: tensor<1x1x?x?x1x42xf32> into tensor +// CHECK-PRIMITIVE-DAG: %[[DIM0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] +// CHECK-PRIMITIVE-DAG: %[[DIM1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] +// CHECK-PRIMITIVE-DAG: %[[DIM2:.*]] = tensor.extract %[[SHAPE]][%[[C2]]] +// CHECK-PRIMITIVE-DAG: %[[DIM3:.*]] = tensor.extract %[[SHAPE]][%[[C3]]] +// CHECK-PRIMITIVE-DAG: %[[DIM4:.*]] = tensor.extract %[[SHAPE]][%[[C4]]] +// CHECK-PRIMITIVE-DAG: %[[DIM5:.*]] = tensor.extract %[[SHAPE]][%[[C5]]] +// CHECK-PRIMITIVE: %[[INIT:.*]] = tensor.empty(%[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]], %[[DIM4]], %[[DIM5]]) : tensor +// CHECK-PRIMITIVE: %[[BROADCASTED:.*]] = linalg.broadcast +// CHECK-PRIMITIVE-SAME: ins(%[[COLLAPSED]] : tensor) +// CHECK-PRIMITIVE-SAME: outs(%[[INIT]] : tensor) +// CHECK-PRIMITIVE-SAME: dimensions = [0, 1, 2, 5] +// CHECK-PRIMITIVE: %[[RES:.*]] = tensor.cast %[[BROADCASTED]] +// CHECK-PRIMITIVE-SAME: tensor to tensor +// CHECK-PRIMITIVE: return %[[RES]] + // ----- func.func @dot_matmul(%arg0: tensor<2x3xf32>, @@ -2694,9 +2846,9 @@ func.func @einsum_dynamic_size_broadcast_dot(%arg0: tensor, %arg1: te // ----- -// CHECK-LABEL: @clamp +// CHECK-LABEL: @clamp_static // CHECK-SAME: %[[LB:.*]]: tensor<4xf32>, %[[X:.*]]: tensor<4xf32>, %[[UB:.*]]: tensor<4xf32> -func.func @clamp(%lb : tensor<4xf32>, %x : tensor<4xf32>, %ub : tensor<4xf32>) +func.func @clamp_static(%lb : tensor<4xf32>, %x : tensor<4xf32>, %ub : tensor<4xf32>) -> tensor<4xf32> { // CHECK: %[[INIT:.*]] = tensor.empty // CHECK: %[[RESULT:.*]] = linalg.generic {{.*}} ins(%[[LB]], %[[X]], %[[UB]] : tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) outs(%[[INIT]] : tensor<4xf32>) @@ -2711,6 +2863,17 @@ func.func @clamp(%lb : tensor<4xf32>, %x : tensor<4xf32>, %ub : tensor<4xf32>) func.return %0 : tensor<4xf32> } +// CHECK-PRIMITIVE-LABEL: @clamp_static +// CHECK-PRIMITIVE-SAME: %[[LB:.*]]: tensor<4xf32>, %[[X:.*]]: tensor<4xf32>, %[[UB:.*]]: tensor<4xf32> + +// CHECK-PRIMITIVE: %[[INIT:.*]] = tensor.empty +// CHECK-PRIMITIVE: %[[RESULT:.*]] = linalg.map ins(%[[LB]], %[[X]], %[[UB]] : tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) outs(%[[INIT]] : tensor<4xf32>) +// CHECK-PRIMITIVE: (%[[SCALAR_LB:.*]]: f32, %[[SCALAR_X:.*]]: f32, %[[SCALAR_UB:.*]]: f32) +// CHECK-PRIMITIVE: %[[MAX:.*]] = arith.maxf %[[SCALAR_LB]], %[[SCALAR_X]] : f32 +// CHECK-PRIMITIVE: %[[MIN:.*]] = arith.minf %[[MAX]], %[[SCALAR_UB]] : f32 +// CHECK-PRIMITIVE: linalg.yield %[[MIN]] 
+// CHECK-PRIMITIVE: return %[[RESULT]] : tensor<4xf32> + // ----- // CHECK-LABEL: @clamp_dynamic @@ -2730,6 +2893,65 @@ func.func @clamp_dynamic(%lb : tensor, %x : tensor, %ub : tensor } +// CHECK-PRIMITIVE-LABEL: @clamp_dynamic +// CHECK-PRIMITIVE: linalg.map + +// ----- + +func.func @clamp_mixed(%lb : tensor<4xf32>, %x : tensor, %ub : tensor) + -> tensor { + %0 = "mhlo.clamp"(%lb, %x, %ub) : (tensor<4xf32>, tensor, + tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: @clamp_mixed +// CHECK: linalg.generic + +// CHECK-PRIMITIVE-LABEL: @clamp_mixed +// CHECK-PRIMITIVE: linalg.map + +// ----- + +func.func @clamp_scalar(%lb : tensor, %x : tensor, %ub : tensor) + -> tensor { + %0 = "mhlo.clamp"(%lb, %x, %ub) : (tensor, tensor, + tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: @clamp_scalar +// CHECK: linalg.generic + +// CHECK-PRIMITIVE-LABEL: @clamp_scalar +// CHECK-PRIMITIVE-SAME: %[[LB:.*]]: tensor, %[[X:.*]]: tensor, %[[UB:.*]]: tensor + +// CHECK-PRIMITIVE-DAG: %[[INIT:.*]] = tensor.empty +// CHECK-PRIMITIVE-DAG: %[[SCALAR_LB:.*]] = tensor.extract %[[LB]] +// CHECK-PRIMITIVE-DAG: %[[SCALAR_UB:.*]] = tensor.extract %[[UB]] +// CHECK-PRIMITIVE: %[[RESULT:.*]] = linalg.map ins(%[[X]] : tensor) outs(%[[INIT]] : tensor) +// CHECK-PRIMITIVE: (%[[SCALAR_X:.*]]: f32) +// CHECK-PRIMITIVE: %[[MAX:.*]] = arith.maxf %[[SCALAR_LB]], %[[SCALAR_X]] : f32 +// CHECK-PRIMITIVE: %[[MIN:.*]] = arith.minf %[[MAX]], %[[SCALAR_UB]] : f32 +// CHECK-PRIMITIVE: linalg.yield %[[MIN]] +// CHECK-PRIMITIVE: return %[[RESULT]] + + +// ----- + +func.func @clamp_scalar_mixed(%lb : tensor, %x : tensor, %ub : tensor) + -> tensor { + %0 = "mhlo.clamp"(%lb, %x, %ub) : (tensor, tensor, + tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: @clamp_scalar_mixed +// CHECK: linalg.generic + +// CHECK-PRIMITIVE-LABEL: @clamp_scalar_mixed +// CHECK-PRIMITIVE: linalg.map + // ----- func.func @map_compare(%arg0: tensor>, @@ -2770,8 +2992,8 @@ func.func @map_compare(%arg0: tensor>, // CHECK-PRIMITIVE: %[[INIT:.+]] = tensor.empty // CHECK-PRIMITIVE: %[[MAP:.+]] = linalg.map -// CHECK-PRIMITIVE-NEXT: ins(%[[ARG0]], %[[ARG1]] -// CHECK-PRIMITIVE-NEXT: outs(%[[INIT]] : tensor) +// CHECK-PRIMITIVE-SAME: ins(%[[ARG0]], %[[ARG1]] +// CHECK-PRIMITIVE-SAME: outs(%[[INIT]] : tensor) // CHECK-PRIMITIVE-NEXT: (%[[A:.+]]: complex, %[[B:.+]]: complex) { // CHECK-PRIMITIVE: %[[RE1:.+]] = complex.re %[[A]] : complex // CHECK-PRIMITIVE: %[[RE2:.+]] = complex.re %[[B]] : complex @@ -2779,6 +3001,26 @@ func.func @map_compare(%arg0: tensor>, // CHECK-PRIMITIVE: linalg.yield %[[CMP]] : i1 // CHECK-PRIMITIVE: } // CHECK-PRIMITIVE: return %[[MAP]] : tensor + +// ----- + +func.func @map_mixed(%arg0: tensor, + %arg1: tensor<4xf32>) -> tensor { + %0 = "mhlo.map"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = mhlo.add %arg2, %arg3 : tensor + "mhlo.return"(%1) : (tensor) -> () + }) {dimensions = dense<0> : tensor<1xi64>} + : (tensor, tensor<4xf32>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: @map_mixed +// CHECK: linalg.generic + +// CHECK-PRIMITIVE-LABEL: @map_mixed +// CHECK-PRIMITIVE: linalg.map + // ----- func.func @reduce_add(%arg0: tensor<5x4xi32>, %arg1: tensor) -> tensor<5xi32> { @@ -2805,9 +3047,19 @@ func.func @reduce_add(%arg0: tensor<5x4xi32>, %arg1: tensor) -> tensor<5xi3 // CHECK-NEXT: %[[RESULT:.*]] = arith.addi %[[RHS_IN]], %[[LHS_IN]] : i32 // CHECK-NEXT: linalg.yield %[[RESULT]] : i32 +// CHECK-PRIMITIVE-LABEL: @reduce_add +// CHECK-PRIMITIVE-DAG: %[[INIT:.*]] = 
tensor.extract %{{.*}} : tensor +// CHECK-PRIMITIVE-DAG: %[[INIT_TENSOR:.*]] = tensor.empty() +// CHECK-PRIMITIVE-DAG: %[[FILL_TENSOR:.*]] = linalg.fill ins(%[[INIT]]{{.*}}outs(%[[INIT_TENSOR]] +// CHECK-PRIMITIVE: linalg.reduce { arith.addi } +// CHECK-PRIMITIVE-SAME: ins(%{{.*}}tensor<5x4xi32>) +// CHECK-PRIMITIVE-SAME: outs(%[[FILL_TENSOR]] : tensor<5xi32>) +// CHECK-PRIMITIVE-SAME: dimensions = [1] {someattr} + // ----- // CHECK-LABEL: @reduce_add_unranked +// CHECK-PRIMITIVE-LABEL: @reduce_add_unranked func.func @reduce_add_unranked(%arg0: tensor<*xi32>, %arg1: tensor) -> tensor<*xi32> { %0 = "mhlo.reduce"(%arg0, %arg1) ({ ^bb0(%arg3: tensor, %arg4 : tensor): @@ -2817,6 +3069,7 @@ func.func @reduce_add_unranked(%arg0: tensor<*xi32>, %arg1: tensor) -> tens func.return %0 : tensor<*xi32> } // CHECK: mhlo.reduce +// CHECK-PRIMITIVE: mhlo.reduce // ----- @@ -2843,6 +3096,33 @@ func.func @reduce_dim0(%arg0: tensor<5x4xi32>, %arg1: tensor) -> tensor<4xi // CHECK-NEXT: %[[RESULT:.*]] = arith.maxsi %[[RHS_IN]], %[[LHS_IN]] : i32 // CHECK-NEXT: linalg.yield %[[RESULT]] : i32 +// CHECK-PRIMITIVE-LABEL: @reduce_dim0 +// CHECK-PRIMITIVE-DAG: %[[INIT:.*]] = tensor.extract %{{.*}} : tensor +// CHECK-PRIMITIVE-DAG: %[[INIT_TENSOR:.*]] = tensor.empty() +// CHECK-PRIMITIVE-DAG: %[[FILL_TENSOR:.*]] = linalg.fill ins(%[[INIT]]{{.*}}outs(%[[INIT_TENSOR]] +// CHECK-PRIMITIVE: linalg.reduce { arith.maxsi } +// CHECK-PRIMITIVE-SAME: ins(%{{.*}}tensor<5x4xi32>) +// CHECK-PRIMITIVE-SAME: outs(%[[FILL_TENSOR]] : tensor<4xi32>) +// CHECK-PRIMITIVE-SAME: dimensions = [0] + +// ----- + +func.func @reduce_dynamic_output(%arg0: tensor<5x4xi32>, %arg1: tensor) -> tensor { + %0 = "mhlo.reduce"(%arg0, %arg1) ({ + ^bb0(%arg3: tensor, %arg4 : tensor): + %1 = mhlo.maximum %arg3, %arg4 : tensor + "mhlo.return"(%1) : (tensor) -> () + }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<5x4xi32>, tensor) -> tensor + func.return %0 : tensor +} + +// Regression test: just check that this lowers successfully. 
+// CHECK-LABEL: @reduce_dynamic_output +// CHECK: linalg.generic + +// CHECK-PRIMITIVE-LABEL: @reduce_dynamic_output +// CHECK-PRIMITIVE: linalg.reduce + // ----- func.func @reduce_init_const(%arg0: tensor<1x10xf32>) -> tensor<1xf32> { @@ -2938,7 +3218,7 @@ func.func @reduce_lexicographic_min_complex(%arg0: tensor>, // CHECK: arith.select // CHECK-PRIMITIVE-LABEL: @reduce_lexicographic_min_complex -// CHECK-PRIMITIVE: linalg.generic +// CHECK-PRIMITIVE: linalg.reduce // CHECK-PRIMITIVE: complex.re // CHECK-PRIMITIVE: complex.re // CHECK-PRIMITIVE: arith.cmpf @@ -3017,8 +3297,6 @@ func.func @variadic_reduce(%arg0: tensor<9x2xi32>, %arg1: tensor<9x2xi32>) -> (t // CHECK-NEXT: %[[T6:.*]] = arith.select %[[T3]], %[[T4]], %[[T5]] : i32 // CHECK-NEXT: linalg.yield %[[T2]], %[[T6]] -// CHECK-PRIMITIVE-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1, d0)> -// CHECK-PRIMITIVE-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)> // CHECK-PRIMITIVE: func @variadic_reduce // CHECK-PRIMITIVE-SAME: %[[ARG0:[a-zA-Z0-9_]*]] // CHECK-PRIMITIVE-SAME: %[[ARG1:[a-zA-Z0-9_]*]] @@ -3028,12 +3306,11 @@ func.func @variadic_reduce(%arg0: tensor<9x2xi32>, %arg1: tensor<9x2xi32>) -> (t // CHECK-PRIMITIVE: %[[FILL0:.*]] = linalg.fill ins(%[[CST0]]{{.*}}outs(%[[INIT0]] // CHECK-PRIMITIVE: %[[INIT1:.*]] = tensor.empty() : tensor<2xi32> // CHECK-PRIMITIVE: %[[FILL1:.*]] = linalg.fill ins(%[[CST1]]{{.*}}outs(%[[INIT1]] -// CHECK-PRIMITIVE: %[[RES:.+]]:2 = linalg.generic { -// CHECK-PRIMITIVE-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP1]], #[[MAP1]]] -// CHECK-PRIMITIVE-SAME: iterator_types = ["parallel", "reduction"] +// CHECK-PRIMITIVE: %[[RES:.+]]:2 = linalg.reduce // CHECK-PRIMITIVE-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<9x2xi32>, tensor<9x2xi32>) // CHECK-PRIMITIVE-SAME: outs(%[[FILL0]], %[[FILL1]] : tensor<2xi32>, tensor<2xi32>) -// CHECK-PRIMITIVE-NEXT: ^bb0(%[[IN0:.*]]: i32, %[[IN1:.*]]: i32, %[[OUT0:.*]]: i32, %[[OUT1:.*]]: i32): +// CHECK-PRIMITIVE-SAME: dimensions = [0] +// CHECK-PRIMITIVE-NEXT: (%[[IN0:.*]]: i32, %[[IN1:.*]]: i32, %[[OUT0:.*]]: i32, %[[OUT1:.*]]: i32) { // CHECK-PRIMITIVE-NEXT: %[[T1:.*]] = arith.cmpi sge, %[[OUT0]], %[[IN0]] : i32 // CHECK-PRIMITIVE-NEXT: %[[T2:.*]] = arith.select %[[T1]], %[[OUT0]], %[[IN0]] : i32 // CHECK-PRIMITIVE-NEXT: %[[T3:.*]] = arith.cmpi eq, %[[OUT0]], %[[IN0]] : i32 @@ -3078,8 +3355,6 @@ func.func @variadic_diff_type_reduce(%arg0: tensor<128x10xf32>, %arg1: tensor<12 // CHECK-NEXT: %[[RES1:.*]] = arith.select %[[B0]], %[[RHS1]], %[[LHS1]] : i32 // CHECK-NEXT: linalg.yield %[[RES0]], %[[RES1]] : f32, i32 -// CHECK-PRIMITIVE-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK-PRIMITIVE-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0)> // CHECK-PRIMITIVE: func @variadic_diff_type_reduce // CHECK-PRIMITIVE-SAME: %[[ARG0:[a-zA-Z0-9_]*]] // CHECK-PRIMITIVE-SAME: %[[ARG1:[a-zA-Z0-9_]*]] @@ -3089,12 +3364,11 @@ func.func @variadic_diff_type_reduce(%arg0: tensor<128x10xf32>, %arg1: tensor<12 // CHECK-PRIMITIVE: %[[FILL0:.*]] = linalg.fill ins(%[[CST0]]{{.*}}outs(%[[INIT0]] // CHECK-PRIMITIVE: %[[INIT1:.*]] = tensor.empty() : tensor<128xi32> // CHECK-PRIMITIVE: %[[FILL1:.*]] = linalg.fill ins(%[[CST1]]{{.*}}outs(%[[INIT1]] -// CHECK-PRIMITIVE: %[[RES:.+]]:2 = linalg.generic { -// CHECK-PRIMITIVE-SAME: indexing_maps = [#[[MAP0]], #[[MAP0]], #[[MAP1]], #[[MAP1]]] -// CHECK-PRIMITIVE-SAME: iterator_types = ["parallel", "reduction"] +// CHECK-PRIMITIVE: %[[RES:.+]]:2 = linalg.reduce // CHECK-PRIMITIVE-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<128x10xf32>, 
tensor<128x10xi32>) -// CHECK-PRIMITIVE-SAME: outs(%[[FILL0]], %[[FILL1]] : tensor<128xf32>, tensor<128xi32>) -// CHECK-PRIMITIVE-NEXT: ^bb0(%[[LHS0:.*]]: f32, %[[LHS1:.*]]: i32, %[[RHS0:.*]]: f32, %[[RHS1:.*]]: i32): +// CHECK-PRIMITIVE-SAME: outs(%[[FILL0]], %[[FILL1]] : tensor<128xf32>, tensor<128xi32>) +// CHECK-PRIMITIVE-SAME: dimensions = [1] +// CHECK-PRIMITIVE-NEXT: (%[[LHS0:.*]]: f32, %[[LHS1:.*]]: i32, %[[RHS0:.*]]: f32, %[[RHS1:.*]]: i32) { // CHECK-PRIMITIVE-NEXT: %[[B0:.*]] = arith.cmpf oge, %[[RHS0]], %[[LHS0]] : f32 // CHECK-PRIMITIVE-NEXT: %[[RES0:.*]] = arith.select %[[B0]], %[[RHS0]], %[[LHS0]] : f32 // CHECK-PRIMITIVE-NEXT: %[[RES1:.*]] = arith.select %[[B0]], %[[RHS1]], %[[LHS1]] : i32 @@ -3154,6 +3428,19 @@ func.func @slice_with_strides2(%arg0: tensor<6xi32>) -> tensor<3xi32> { // ----- +func.func @slice_with_empty_result(%arg0: tensor<3x3x5xf64>) -> tensor<3x0x5xf64> { + %0 = "mhlo.slice"(%arg0) { + limit_indices = dense<[3, 2, 5]> : tensor<3xi64>, + start_indices = dense<[0, 2, 0]> : tensor<3xi64>, + strides = dense<[1, 2, 1]> : tensor<3xi64> + } : (tensor<3x3x5xf64>) -> tensor<3x0x5xf64> + func.return %0 : tensor<3x0x5xf64> +} +// CHECK-LABEL: func @slice_with_empty_result +// CHECK: tensor.extract_slice %{{.*}}[0, 2, 0] [3, 0, 5] [1, 2, 1] : tensor<3x3x5xf64> to tensor<3x0x5xf64> + +// ----- + func.func @dynamic_slice(%arg: tensor<3x4xf32>, %start1: tensor, %start2: tensor) -> tensor<1x4xf32> { %0 = "mhlo.dynamic_slice"(%arg, %start1, %start2) { slice_sizes = dense<[1, 4]> : tensor<2xi64> @@ -3179,7 +3466,7 @@ func.func @dynamic_slice(%arg: tensor<3x4xf32>, %start1: tensor, %start2: t // ----- func.func @dynamic_slice_unsigned_index( - %arg: tensor<3x4xui32>, %start1: tensor, %start2: tensor) + %arg: tensor<3x4xui32>, %start1: tensor, %start2: tensor) -> tensor<1x4xui32> { %0 = "mhlo.dynamic_slice"(%arg, %start1, %start2) { slice_sizes = dense<[1, 4]> : tensor<2xi64> @@ -4369,7 +4656,7 @@ func.func @reduce_window_generic_scalar(%arg0: tensor, %arg1: tensor) // CHECK: #[[MAP:.+]] = affine_map<() -> ()> // CHECK-LABEL: func @reduce_window_generic_scalar -// CHECK: linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]]] +// CHECK: linalg.generic {indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]] // ----- @@ -4750,19 +5037,21 @@ func.func @torch_index_select(%arg0: tensor<5x1x5xi32>, func.return %0 : tensor<2x1x5xi32> } // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK: func @torch_index_select // CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]] // CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]] -// CHECK: %[[INIT:.+]] = tensor.empty() : +// CHECK: %[[INIT1:.+]] = tensor.empty() : +// CHECK: %[[INIT2:.+]] = tensor.empty() : // CHECK: linalg.generic { // CHECK-SAME: indexing_maps -// CHECK-SAME: #[[MAP0]], #[[MAP1]] +// CHECK-SAME: #[[MAP0]], #[[MAP1]], #[[MAP2]] // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] -// CHECK-SAME: ins(%[[INDEX]] : -// CHECK-SAME: outs(%[[INIT]] : +// CHECK-SAME: ins(%[[INDEX]], %[[INIT1]] : +// CHECK-SAME: outs(%[[INIT2]] : // CHECK-SAME: {someattr} -// CHECK: ^{{.+}}(%[[VAL:.+]]: i32, %{{.+}}: i32): +// CHECK: ^{{.+}}(%[[VAL:.+]]: i32, %{{.+}}: i32, %{{.+}}: i32): // CHECK: %[[CAST:.+]] = arith.index_cast %[[VAL]] : i32 to index // CHECK: %[[J:.+]] = linalg.index 1 // CHECK: %[[K:.+]] = linalg.index 2 @@ -4898,17 +5187,19 
@@ func.func @torch_index_select_unsigned(%arg0: tensor<5x1x5xui32>, func.return %0 : tensor<2x1x5xui32> } // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK: func @torch_index_select_unsigned // CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]] // CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]] // CHECK: %[[INPUT_SIGNLESS:.*]] = builtin.unrealized_conversion_cast %[[INPUT]] : tensor<5x1x5xui32> to tensor<5x1x5xi32> +// CHECK: %[[INIT:.*]] = tensor.empty() : tensor<1x5xi32> // CHECK: %[[RES:.+]] = linalg.generic { // CHECK-SAME: indexing_maps -// CHECK-SAME: #[[MAP0]], #[[MAP1]] +// CHECK-SAME: #[[MAP0]], #[[MAP1]], #[[MAP2]] // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] -// CHECK-SAME: ins(%[[INDEX]] : tensor<2xi32>) -// CHECK: ^{{.+}}(%[[VAL:.+]]: i32, %{{.+}}: i32): +// CHECK-SAME: ins(%[[INDEX]], %[[INIT]] : tensor<2xi32>, tensor<1x5xi32>) +// CHECK: ^{{.+}}(%[[VAL:.+]]: i32, %{{.+}}: i32, %{{.+}}: i32): // CHECK: %[[CAST:.+]] = arith.index_cast %[[VAL]] : i32 to index // CHECK: %[[J:.+]] = linalg.index 1 // CHECK: %[[K:.+]] = linalg.index 2 @@ -4933,11 +5224,12 @@ func.func @torch_index_select_scalar(%arg0: tensor<4x8xf32>, // CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]] // CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]] // CHECK: %[[T0:.+]] = tensor.empty() : tensor<8xf32> +// CHECK: %[[T1:.+]] = tensor.empty() : tensor<8xf32> // CHECK: linalg.generic { // CHECK-SAME: indexing_maps -// CHECK-SAME: #[[MAP0]], #[[MAP1]] +// CHECK-SAME: #[[MAP0]], #[[MAP1]], #[[MAP1]] // CHECK-SAME: iterator_types = ["parallel"] -// CHECK-SAME: ins(%[[INDEX]] : tensor) outs(%[[T0]] : tensor<8xf32>) +// CHECK-SAME: ins(%[[INDEX]], %[[T0]] : tensor, tensor<8xf32>) outs(%[[T1]] : tensor<8xf32>) // CHECK: ^{{.+}}(%[[VAL:[a-zA-Z0-9_]+]]: i32, %{{.+}}: f32): // CHECK: %[[CAST:.+]] = arith.index_cast %[[VAL]] : i32 to index // CHECK: %[[I:.+]] = linalg.index 0 @@ -4955,16 +5247,18 @@ func.func @torch_index_select_batch(%arg0: tensor<4x7x8x2xf32>, func.return %0 : tensor<4x7x1x2xf32> } // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> // CHECK: func @torch_index_select_batch // CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]] // CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]] +// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<4x7x2xf32> // CHECK: linalg.generic { // CHECK-SAME: indexing_maps -// CHECK-SAME: #[[MAP0]], #[[MAP1]] +// CHECK-SAME: #[[MAP0]], #[[MAP1]], #[[MAP2]] // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"] -// CHECK-SAME: ins(%[[INDEX]] : -// CHECK-NEXT: ^{{.+}}(%[[VAL:.+]]: i32, %{{.+}}: f32): +// CHECK-SAME: ins(%[[INDEX]], %[[INIT]] : +// CHECK-NEXT: ^{{.+}}(%[[VAL:.+]]: i32, %{{.+}}: f32, %{{.+}}: f32): // CHECK: %[[CAST:.+]] = arith.index_cast %[[VAL]] : i32 to index // CHECK: %[[I:.+]] = linalg.index 0 // CHECK: %[[J:.+]] = linalg.index 1 @@ -4983,7 +5277,8 @@ func.func @torch_index_select_dynamic(%input: tensor, func.return %0 : tensor } // CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)> -// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2, 
d3) -> (d0, d1, d3)> +// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> // CHECK: func @torch_index_select_dynamic // CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]] // CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]] @@ -4994,14 +5289,18 @@ func.func @torch_index_select_dynamic(%input: tensor, // CHECK: %[[D1:.+]] = tensor.dim %[[INPUT]], %[[C1]] // CHECK: %[[D2:.+]] = tensor.dim %[[INDEX]], %[[C1]] // CHECK: %[[D3:.+]] = tensor.dim %[[INPUT]], %[[C3]] -// CHECK: %[[INIT:.+]] = tensor.empty(%[[D0]], %[[D1]], %[[D2]], %[[D3]]) +// CHECK: %[[D4:.+]] = tensor.dim %[[INPUT]], %[[C0]] +// CHECK: %[[D5:.+]] = tensor.dim %[[INPUT]], %[[C1]] +// CHECK: %[[D6:.+]] = tensor.dim %[[INPUT]], %[[C3]] +// CHECK: %[[INIT0:.+]] = tensor.empty(%[[D4]], %[[D5]], %[[D6]]) : tensor +// CHECK: %[[INIT1:.+]] = tensor.empty(%[[D0]], %[[D1]], %[[D2]], %[[D3]]) // CHECK: %[[RESULT:.+]] = linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]] +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]] // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"] -// CHECK-SAME: ins(%[[INDEX]] : tensor) -// CHECK-SAME: outs(%[[INIT]] : tensor) +// CHECK-SAME: ins(%[[INDEX]], %[[INIT0]] : tensor, tensor) +// CHECK-SAME: outs(%[[INIT1]] : tensor) // CHECK: ^{{.+}}( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: i32, %{{[a-zA-Z0-9_]+}}: f32) +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: i32, %{{[a-zA-Z0-9_]+}}: f32, %{{[a-zA-Z0-9_]+}}: f32) // CHECK: %[[POS:.+]] = arith.index_cast %[[ARG0]] // CHECK: %[[IDX0:.+]] = linalg.index 0 // CHECK: %[[IDX1:.+]] = linalg.index 1 @@ -5562,3 +5861,84 @@ func.func @convolution_without_reversing_and_stride(%arg0: tensor<2x14x12x2xf64> : (tensor<2x14x12x2xf64>, tensor<7x7x1x2xf64>) -> tensor<2x12x16x2xf64> return %0 : tensor<2x12x16x2xf64> } + +// ----- + +// CHECK-DAG: affine_map<(d0, d1, d2, d3) -> (-d0 + 2, -d1 + 2, d2, d3)> +// CHECK-LABEL: @normal_convolution_with_reversal +func.func @normal_convolution_with_reversal(%arg0: tensor<1x3x3x3xf32>, + %arg1: tensor<3x3x3x1xf32>) -> tensor<1x1x1x1xf32> { + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + stride = [1, 1], + pad = [[0, 0], [0, 0]], + lhs_dilate = [1, 1], + rhs_dilate = [1, 1], + reverse = [1, 1] + } { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64, precision_config = [ + #mhlo, + #mhlo] + } : (tensor<1x3x3x3xf32>, tensor<3x3x3x1xf32>) -> tensor<1x1x1x1xf32> + return %0 : tensor<1x1x1x1xf32> +} + +// ----- + +// CHECK-LABEL: set_dimension_size +// CHECK-SAME: %[[VALUE:.*]]: tensor<2x?xf32, #mhlo.type_extensions +func.func @set_dimension_size( + %value: tensor<2x?xf32, #mhlo.type_extensions>, + %dimension: tensor) + -> tensor<2x?xf32, #mhlo.type_extensions> { + // CHECK: tensor.extract_slice %[[VALUE]][0, 0] [2, %{{.*}}] [1, 1] : tensor<2x?xf32, #mhlo.type_extensions> to tensor<2x?xf32, #mhlo.type_extensions> + %0 = "mhlo.set_dimension_size"(%value, %dimension) { dimension = 1 } + : (tensor<2x?xf32, #mhlo.type_extensions>, tensor) + -> tensor<2x?xf32, #mhlo.type_extensions> + func.return %0 : tensor<2x?xf32, #mhlo.type_extensions> +} + +// ----- +// The following test checks that an EmptyOp is emitted for mhlo.convolution +// when the output shape has a zero-sized dimension. This goes through +// ConvolutionOpGeneralConversion rewrite pattern. 
+ +// CHECK-LABEL: @general_convolution_with_zero_sized_dimension_in_output +// CHECK-SAME: %[[LHS:.*]]: tensor<2x4x9x0xi64> +// CHECK-SAME: %[[RHS:.*]]: tensor<4x5x2x4xi64> +// CHECK-SAME: -> tensor<2x5x0x4xi64> +// CHECK-NEXT: %[[RES:.*]] = tensor.empty +// CHECK-NEXT: return %[[RES]] + +func.func @general_convolution_with_zero_sized_dimension_in_output(%arg0: tensor<2x4x9x0xi64> {bufferization.writable = false, xla_framework.input_mapping = 2 : i32}, +%arg1: tensor<4x5x2x4xi64> {bufferization.writable = false, xla_framework.input_mapping = 0 : i32}) +-> tensor<2x5x0x4xi64> attributes {xla_framework.result_mapping = 1 : i32} { + %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [b, f, 0, 1]x[0, 1, i, o]->[b, 0, 1, f], + window = {stride = [2, 1], pad = [[1, 2], [2, 0]], lhs_dilate = [1, 4], rhs_dilate = [1, 1], reverse = [0, 0]} + {batch_group_count = 1 : i64, feature_group_count = 2 : i64, precision_config = [#mhlo, #mhlo]} + : (tensor<2x4x9x0xi64>, tensor<4x5x2x4xi64>) -> tensor<2x5x0x4xi64> + return %0 : tensor<2x5x0x4xi64> +} + +// ----- +// This test is similar to the previous one, but runs through a different +// rewrite pattern (NormalConvolutionOpConversion). + +// CHECK-LABEL: @normal_convolution_with_zero_sized_dimension_in_output +// CHECK-SAME: %[[LHS:.*]]: tensor<3x9x0x2xi16> +// CHECK-SAME: %[[RHS:.*]]: tensor<4x5x2x2xi16> +// CHECK-SAME: -> tensor<3x9x0x2xi16> +// CHECK-NEXT: %[[RES:.*]] = tensor.empty +// CHECK-NEXT: return %[[RES]] + +func.func @normal_convolution_with_zero_sized_dimension_in_output(%arg0: tensor<3x9x0x2xi16> {bufferization.writable = false, xla_framework.input_mapping = 2 : i16}, +%arg1: tensor<4x5x2x2xi16> {bufferization.writable = false, xla_framework.input_mapping = 0 : i16}) +-> tensor<3x9x0x2xi16> attributes {xla_framework.result_mapping = 1 : i16} { + %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = {stride = [1, 1], pad = [[1, 2], [2, 0]], lhs_dilate = [1, 2], rhs_dilate = [1, 4], reverse = [0, 0]} + {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#mhlo, #mhlo]} + : (tensor<3x9x0x2xi16>, tensor<4x5x2x2xi16>) -> tensor<3x9x0x2xi16> + return %0 : tensor<3x9x0x2xi16> +} diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-memref.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-memref.mlir index f6fc852e63a..dbca93c7d9f 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-memref.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-memref.mlir @@ -24,7 +24,7 @@ func.func @dyn_broadcast(%operand: tensor) -> tensor { // CHECK: %[[EXPAND_2:.*]] = arith.cmpi slt, %[[OPER_DIM_1]], %[[C1]] : index // CHECK: %[[STRIDE_2:.*]] = arith.select %[[EXPAND_2]], %[[C0]], %[[C1]] : index -// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref.reinterpret_cast %[[OPERAND]] to offset: [0], sizes: [%[[C1]], %[[C1]], %[[C1]]], strides: [%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]] : memref to memref +// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref.reinterpret_cast %[[OPERAND]] to offset: [0], sizes: [%[[C1]], %[[C1]], %[[C1]]], strides: [%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]] : memref to memref // CHECK: %[[RESULT:.*]] = bufferization.to_tensor %[[TRANSFORMED_MEMREF]] // CHECK: return %[[RESULT]] @@ -63,7 +63,7 @@ func.func @dyn_broadcast_unsigned(%operand: tensor, %shape: tensor<3xi6 // CHECK: %[[EXPAND_2:.*]] = arith.cmpi slt, %[[OPER_DIM_1]], %[[SIZE_2]] : index // CHECK: %[[STRIDE_2:.*]] = 
arith.select %[[EXPAND_2]], %[[C0]], %[[C1]] : index -// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref.reinterpret_cast %[[OPERAND]] to offset: [0], sizes: [%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: [%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]] : memref to memref +// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref.reinterpret_cast %[[OPERAND]] to offset: [0], sizes: [%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: [%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]] : memref to memref // CHECK: %[[RESULT:.*]] = bufferization.to_tensor %[[TRANSFORMED_MEMREF]] @@ -118,10 +118,49 @@ func.func @custom_call_multiple_inputs_outputs(%x: tensor<2xf32>, // CHECK-DAG: %[[I0:.+]] = bufferization.to_memref %[[ARG0]] : memref<2xf32> // CHECK-DAG: %[[I1:.+]] = bufferization.to_memref %[[ARG1]] : memref<5xi32> -// CHECK-DAG: %[[O0:.*]] = memref.alloc() {alignment = 128 : i64} : memref<2xf32> -// CHECK-DAG: %[[O1:.*]] = memref.alloc() {alignment = 128 : i64} : memref<2xf32> -// CHECK-DAG: %[[O2:.*]] = memref.alloc() {alignment = 128 : i64} : memref<5xi32> -// CHECK: "lmhlo.custom_call"(%[[I0]], %[[I1]], %[[O0]], %[[O1]], %[[O2]]) {backend_config = "", call_target_name = "foo", has_side_effect = false, operand_segment_sizes = array} : (memref<2xf32>, memref<5xi32>, memref<2xf32>, memref<2xf32>, memref<5xi32>) -> () +// CHECK-DAG: %[[O0:.*]] = memref.alloc() {{.*}} : memref<2xf32> +// CHECK-DAG: %[[O1:.*]] = memref.alloc() {{.*}} : memref<2xf32> +// CHECK-DAG: %[[O2:.*]] = memref.alloc() {{.*}} : memref<5xi32> +// CHECK: "lmhlo.custom_call"(%[[I0]], %[[I1]], %[[O0]], %[[O1]], %[[O2]]) ({ +// CHECK-NEXT: }) {backend_config = "", call_target_name = "foo", has_side_effect = false, operand_segment_sizes = array} : (memref<2xf32>, memref<5xi32>, memref<2xf32>, memref<2xf32>, memref<5xi32>) -> () // CHECK-DAG: %[[T0:.+]] = bufferization.to_tensor %[[O0]] : memref<2xf32> // CHECK-DAG: %[[T1:.+]] = bufferization.to_tensor %[[O1]] : memref<2xf32> // CHECK: return %[[T0]], %[[T1]] : tensor<2xf32>, tensor<2xf32> + +// ----- + +// CHECK-LABEL: func @custom_call_side_effect +// CHECK-SAME: %[[ARG0:.*]]: tensor<2xf32> +// CHECK-SAME: %[[ARG1:.*]]: tensor<5xi32> +func.func @custom_call_side_effect(%x: tensor<2xf32>, + %y: tensor<5xi32>) -> !mhlo.token { + %token = mhlo.create_token : !mhlo.token + %0:2 = "mhlo.custom_call"(%x, %y, %token) { + backend_config = "", + call_target_name = "bar", + has_side_effect = true + } : (tensor<2xf32>, tensor<5xi32>, !mhlo.token) + -> (!mhlo.token, tensor<2xi32>) + func.return %0#0 : !mhlo.token +} + +// CHECK-DAG: %[[TOKEN:.*]] = mhlo.create_token +// CHECK-DAG: %[[I0:.+]] = bufferization.to_memref %[[ARG0]] : memref<2xf32> +// CHECK-DAG: %[[I1:.+]] = bufferization.to_memref %[[ARG1]] : memref<5xi32> +// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc +// CHECK: "lmhlo.custom_call"(%[[I0]], %[[I1]], %[[ALLOC]]) ({ +// CHECK-NEXT: }) {backend_config = "", call_target_name = "bar", has_side_effect = true, operand_segment_sizes = array, target_arg_mapping = #lmhlo.custom_call_target_arg_mapping} : (memref<2xf32>, memref<5xi32>, memref<2xi32>) +// CHECK: return %[[TOKEN]] : !mhlo.token + +// ----- + +// CHECK-LABEL: func @infeed_outfeed +func.func @infeed_outfeed(%arg0: tensor) { + %0 = mhlo.create_token : !mhlo.token + %1:2 = "mhlo.infeed"(%0) {infeed_config = "", layout = [[1, 0]]} : (!mhlo.token) -> (tensor<3x4xf32>, !mhlo.token) +// CHECK: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<3x4xf32> +// CHECK: "lmhlo.infeed"(%[[ALLOC]]) {config = "", infeed_config = "", layout = {{\[}}[1, 0]]} : (memref<3x4xf32>) -> 
() + %2 = "mhlo.outfeed"(%1#0, %1#1) {outfeed_config = ""} : (tensor<3x4xf32>, !mhlo.token) -> !mhlo.token +// CHECK: "lmhlo.outfeed"(%[[ALLOC]]) {config = "", outfeed_config = ""} : (memref<3x4xf32>) -> () + func.return +} diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo-experimental.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo-experimental.mlir new file mode 100644 index 00000000000..6d8d5854381 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo-experimental.mlir @@ -0,0 +1,50 @@ +// RUN: mlir-hlo-opt --hlo-legalize-to-stablehlo --mlir-print-op-generic --split-input-file --verify-diagnostics %s +// RUN: mlir-hlo-opt --hlo-legalize-to-stablehlo=allow-experimental-features --mlir-print-op-generic %s | FileCheck %s + +// This test file runs both FileCheck and diagnostic check. These tests all +// error when the experimental flag is disabled, and pass when it is enabled. + +func.func @op_all_to_all_tuple(%arg0: tensor<128x4xf32>, %arg1: tensor<128x4xf32>) -> (tensor<128x4xf32>, tensor<128x4xf32>) { + // CHECK: "stablehlo.custom_call"(%arg0, %arg1) { + // CHECK-SAME{LITERAL}: backend_config = "{replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>}", + // CHECK-SAME: call_target_name = "mhlo.all_to_all" + // CHECK-SAME: } : (tensor<128x4xf32>, tensor<128x4xf32>) + // expected-error@+1 {{failed to legalize operation 'mhlo.all_to_all' that was explicitly marked illegal}} + %0:2 = "mhlo.all_to_all"(%arg0, %arg1) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> + } : (tensor<128x4xf32>, tensor<128x4xf32>) -> (tensor<128x4xf32>, tensor<128x4xf32>) + return %0#0, %0#1 : tensor<128x4xf32>, tensor<128x4xf32> +} +// CHECK-LABEL: "op_all_to_all_tuple" + +// ----- + +func.func @op_custom_call_api_version_typed_ffi(%arg0: tensor) -> tensor { + // CHECK: "stablehlo.custom_call"(%arg0) { + // CHECK-SAME: backend_config = "{api_version = 4 : i32, backend_config = {foo = \22bar\22}, call_target_name = \22foo\22}" + // CHECK-SAME: call_target_name = "mhlo.custom_call" + // CHECK-SAME: } : (tensor) -> tensor + // expected-error@+1 {{failed to legalize operation 'mhlo.custom_call' that was explicitly marked illegal}} + %0 = "mhlo.custom_call"(%arg0) { + call_target_name = "foo", + backend_config = {foo = "bar"}, + api_version = 4 : i32 + } : (tensor) -> tensor + return %0 : tensor +} +// CHECK-LABEL: "op_custom_call_api_version_typed_ffi" + +// ----- + +func.func @attr_precision_packed_nibble(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + // CHECK: "stablehlo.custom_call"(%arg0, %arg1) { + // CHECK-SAME: backend_config = "{precision_config = [#mhlo]}", + // CHECK-SAME: call_target_name = "mhlo.dot" + // CHECK-SAME: } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + // expected-error@+1 {{failed to legalize operation 'mhlo.dot' that was explicitly marked illegal}} + %0 = "mhlo.dot"(%arg0, %arg1) { + precision_config = [#mhlo] + } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} +// CHECK-LABEL: "attr_precision_packed_nibble" diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir index 4be23d50f00..7075982f6d5 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir +++ 
b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir @@ -1,4 +1,5 @@ // RUN: mlir-hlo-opt --hlo-legalize-to-stablehlo --mlir-print-op-generic --split-input-file --verify-diagnostics %s | FileCheck %s +// RUN: mlir-hlo-opt --hlo-legalize-to-stablehlo=allow-experimental-features --mlir-print-op-generic --split-input-file --verify-diagnostics %s | FileCheck %s // ============ ATTRIBUTES ============ @@ -151,6 +152,7 @@ func.func @attr_custom_call_api_version_status_returning_unified(%arg0: tensor is unsupported at the moment (see negative test below). // DequantizeMode aka #mhlo is unused at the moment. // DomainKind aka #mhlo is unsupported at the moment (see negative test below). // DotDimensionNumbers aka #mhlo.dot is covered below. @@ -198,34 +200,32 @@ func.func @attr_fft_type_irfft(%arg0: tensor<9xcomplex>) -> tensor<16xf32> // FusionKind aka #mhlo is unsupported at the moment (see negative test below). // GatherDimensionNumbers aka #mhlo.gather is covered below. -func.func @attr_precision_config_default(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { +func.func @attr_precision_default(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { %0 = "mhlo.dot"(%arg0, %arg1) { // CHECK: precision_config = [#stablehlo] precision_config = [#mhlo] } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> func.return %0 : tensor<8x8xf32> } -// CHECK-LABEL: "attr_precision_config_default" +// CHECK-LABEL: "attr_precision_default" -func.func @attr_precision_config_high(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { +func.func @attr_precision_high(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { %0 = "mhlo.dot"(%arg0, %arg1) { // CHECK: precision_config = [#stablehlo] precision_config = [#mhlo] } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> func.return %0 : tensor<8x8xf32> } -// CHECK-LABEL: "attr_precision_config_high" +// CHECK-LABEL: "attr_precision_high" -func.func @attr_precision_config_highest(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { +func.func @attr_precision_highest(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { %0 = "mhlo.dot"(%arg0, %arg1) { // CHECK: precision_config = [#stablehlo] precision_config = [#mhlo] } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> func.return %0 : tensor<8x8xf32> } -// CHECK-LABEL: "attr_precision_config_highest" - -// Precision::PACKED_NIBBLE is unsupported at the moment (see negative test below). +// CHECK-LABEL: "attr_precision_highest" func.func @attr_rng_algorithm_default(%arg0: tensor) -> (tensor, tensor) { %0:2 = "mhlo.rng_bit_generator"(%arg0) { @@ -312,9 +312,11 @@ func.func @attr_transpose_adjoint(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16x // TypeExtensionsAttr aka #mhlo.type_extensions is covered below. 
-func.func @attr_type_extensions_bounds(%arg0: tensor>) -> tensor> { - // CHECK: "func.return"(%arg0) : (tensor>) -> () - func.return %arg0 : tensor> +func.func @attr_type_extensions_bounds( + %arg0: tensor>) + -> tensor> { + // CHECK: "func.return"(%arg0) : (tensor>) -> () + func.return %arg0 : tensor> } // CHECK-LABEL: "attr_type_extensions_bounds" @@ -385,6 +387,7 @@ func.func @op_all_reduce(%arg0: tensor) -> tensor { func.func @op_all_to_all(%arg0: tensor<4x16xf32>) -> tensor<16x4xf32> { // CHECK: "stablehlo.all_to_all"(%arg0) { + // CHECK-SAME: channel_handle = #stablehlo.channel_handle, // CHECK-SAME: concat_dimension = 0 : i64, // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, // CHECK-SAME: split_count = 4 : i64, @@ -394,11 +397,11 @@ func.func @op_all_to_all(%arg0: tensor<4x16xf32>) -> tensor<16x4xf32> { split_dimension = 1 : i64, concat_dimension = 0 : i64, split_count = 4 : i64, - replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64> + replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + channel_handle = #mhlo.channel_handle } : (tensor<4x16xf32>) -> tensor<16x4xf32> func.return %0 : tensor<16x4xf32> } -// CHECK-LABEL: "op_all_to_all" func.func @op_and(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: "stablehlo.and"(%arg0, %arg1) : (tensor, tensor) -> tensor @@ -607,10 +610,10 @@ func.func @op_convert(%arg0: tensor) -> tensor { } // CHECK-LABEL: "op_convert" -func.func @op_convolution(%arg0: tensor<1x8x8x32x207xf32>, %arg1: tensor<3x3x32x207x16xf32>) -> tensor<32x1x8x8x16xf32> { +func.func @op_convolution(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { // CHECK: "stablehlo.convolution"(%arg0, %arg1) { // CHECK-SAME: batch_group_count = 1 : i64, - // CHECK-SAME: dimension_numbers = #stablehlo.conv<[b, 0, 1, ?, f]x[0, 1, ?, i, o]->[?, b, 0, 1, f]>, + // CHECK-SAME: dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, // CHECK-SAME: feature_group_count = 1 : i64, // CHECK-SAME: lhs_dilation = dense<1> : tensor<2xi64>, // CHECK-SAME: padding = dense<1> : tensor<2x2xi64>, @@ -618,19 +621,19 @@ func.func @op_convolution(%arg0: tensor<1x8x8x32x207xf32>, %arg1: tensor<3x3x32x // CHECK-SAME: rhs_dilation = dense<1> : tensor<2xi64>, // CHECK-SAME: window_reversal = dense : tensor<2xi1>, // CHECK-SAME: window_strides = dense<1> : tensor<2xi64> - // CHECK-SAME: } : (tensor<1x8x8x32x207xf32>, tensor<3x3x32x207x16xf32>) -> tensor<32x1x8x8x16xf32> + // CHECK-SAME: } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> %0 = "mhlo.convolution"(%arg0, %arg1) { window_strides = dense<1> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, lhs_dilation = dense<1> : tensor<2xi64>, rhs_dilation = dense<1> : tensor<2xi64>, window_reversal = dense : tensor<2xi1>, - dimension_numbers = #mhlo.conv<[b, 0, 1, ?, f]x[0, 1, ?, i, o]->[?, b, 0, 1, f]>, + dimension_numbers = #mhlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, feature_group_count = 1 : i64, batch_group_count = 1 : i64, precision_config = [#mhlo, #mhlo] - } : (tensor<1x8x8x32x207xf32>, tensor<3x3x32x207x16xf32>) -> tensor<32x1x8x8x16xf32> - func.return %0 : tensor<32x1x8x8x16xf32> + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> + func.return %0 : tensor<1x8x8x16xf32> } // CHECK-LABEL: "op_convolution" @@ -669,7 +672,7 @@ func.func @op_cstr_reshapable(%arg0: index, %arg1: tensor<1xindex>) -> !shape.wi // CHECK-LABEL: "op_cstr_reshapable" func.func @called_computation() { func.return } 
-func.func @op_custom_call(%arg0: tensor) -> tensor { +func.func @op_custom_call_api_version_original(%arg0: tensor) -> tensor { // CHECK: "stablehlo.custom_call"(%arg0) { // CHECK-SAME: api_version = 1 : i32, // CHECK-SAME: backend_config = "", @@ -677,6 +680,11 @@ func.func @op_custom_call(%arg0: tensor) -> tensor { // CHECK-SAME: called_computations = [@foo], // CHECK-SAME: has_side_effect = false, // CHECK-SAME: operand_layouts = [dense<> : tensor<0xindex>], + // CHECK-SAME: output_operand_aliases = [ + // CHECK-SAME: #stablehlo.output_operand_alias< + // CHECK-SAME: output_tuple_indices = [], + // CHECK-SAME: operand_index = 0, + // CHECK-SAME: operand_tuple_indices = []>] // CHECK-SAME: result_layouts = [dense<> : tensor<0xindex>] // CHECK-SAME: } : (tensor) -> tensor %0 = "mhlo.custom_call"(%arg0) { @@ -686,12 +694,16 @@ func.func @op_custom_call(%arg0: tensor) -> tensor { api_version = 1 : i32, called_computations = [@foo], operand_layouts = [dense<> : tensor<0xindex>], + output_operand_aliases = [ + #mhlo.output_operand_alias + ], result_layouts = [dense<> : tensor<0xindex>] - // CustomCallOp::output_operand_aliases is unsupported at the moment (see negative test below). } : (tensor) -> tensor func.return %0 : tensor } -// CHECK-LABEL: "op_custom_call" +// CHECK-LABEL: "op_custom_call_api_version_original" func.func @op_divide(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: "stablehlo.divide"(%arg0, %arg1) : (tensor, tensor) -> tensor @@ -751,10 +763,10 @@ func.func @op_dynamic_broadcast_in_dim(%arg0: tensor, %arg1: tensor<2xind } // CHECK-LABEL: "op_dynamic_broadcast_in_dim" -func.func @op_dynamic_conv(%arg0: tensor<1x8x8x32x207xf32>, %arg1: tensor<3x3x32x207x16xf32>, %arg2: tensor<4xi32>) -> tensor<32x1x?x?x16xf32> { +func.func @op_dynamic_conv(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>, %arg2: tensor<4xi32>) -> tensor<1x?x?x16xf32> { // CHECK: "stablehlo.dynamic_conv"(%arg0, %arg1, %arg2) { // CHECK-SAME: batch_group_count = 1 : i64, - // CHECK-SAME: dimension_numbers = #stablehlo.conv<[b, 0, 1, ?, f]x[0, 1, ?, i, o]->[?, b, 0, 1, f]>, + // CHECK-SAME: dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, // CHECK-SAME: feature_group_count = 1 : i64, // CHECK-SAME: lhs_dilation = dense<1> : tensor<2xi64>, // CHECK-SAME: padding = dense<1> : tensor<2x2xi64>, @@ -762,19 +774,19 @@ func.func @op_dynamic_conv(%arg0: tensor<1x8x8x32x207xf32>, %arg1: tensor<3x3x32 // CHECK-SAME: rhs_dilation = dense<1> : tensor<2xi64>, // CHECK-SAME: window_reversal = dense : tensor<2xi1>, // CHECK-SAME: window_strides = dense<1> : tensor<2xi64> - // CHECK-SAME: } : (tensor<1x8x8x32x207xf32>, tensor<3x3x32x207x16xf32>, tensor<4xi32>) -> tensor<32x1x?x?x16xf32> + // CHECK-SAME: } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<4xi32>) -> tensor<1x?x?x16xf32> %0 = "mhlo.dynamic_conv"(%arg0, %arg1, %arg2) { window_strides = dense<1> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, lhs_dilation = dense<1> : tensor<2xi64>, rhs_dilation = dense<1> : tensor<2xi64>, window_reversal = dense : tensor<2xi1>, - dimension_numbers = #mhlo.conv<[b, 0, 1, ?, f]x[0, 1, ?, i, o]->[?, b, 0, 1, f]>, + dimension_numbers = #mhlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, feature_group_count = 1 : i64, batch_group_count = 1 : i64, precision_config = [#mhlo, #mhlo] - } : (tensor<1x8x8x32x207xf32>, tensor<3x3x32x207x16xf32>, tensor<4xi32>) -> tensor<32x1x?x?x16xf32> - func.return %0 : tensor<32x1x?x?x16xf32> + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, 
tensor<4xi32>) -> tensor<1x?x?x16xf32> + func.return %0 : tensor<1x?x?x16xf32> } // CHECK-LABEL: "op_dynamic_conv" @@ -927,13 +939,13 @@ func.func @op_get_dimension_size(%arg0: tensor) -> tensor { } // CHECK-LABEL: "op_get_dimension_size" -func.func @op_get_tuple_element(%arg0: tuple>) -> tensor { +func.func @op_get_tuple_element(%arg0: tuple, tensor, tensor, tensor, tensor>) -> tensor { // CHECK: "stablehlo.get_tuple_element"(%arg0) { - // CHECK-SAME: index = 0 : i32 - // CHECK-SAME: } : (tuple>) -> tensor + // CHECK-SAME: index = 4 : i32 + // CHECK-SAME: } : (tuple, tensor, tensor, tensor, tensor>) -> tensor %0 = "mhlo.get_tuple_element"(%arg0) { - index = 0 : i32 - } : (tuple>) -> tensor + index = 4 : i32 + } : (tuple, tensor, tensor, tensor, tensor>) -> tensor func.return %0 : tensor } // CHECK-LABEL: "op_get_tuple_element" @@ -1106,7 +1118,12 @@ func.func @op_pad(%arg0: tensor<8xf32>, %arg1: tensor) -> tensor<16xf32> { } // CHECK-LABEL: "op_pad" -// PartitionIdOp aka mhlo.partition_id is unsupported at the moment (see negative test below). +func.func @op_partition_id() -> tensor { + // CHECK: "stablehlo.partition_id"() : () -> tensor + %0 = "mhlo.partition_id"() : () -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_partition_id" func.func @op_popcnt(%arg0: tensor) -> tensor { // CHECK: "stablehlo.popcnt"(%arg0) : (tensor) -> tensor @@ -1688,6 +1705,20 @@ func.func @type_ui64(%arg0: tensor, %arg1: tensor) -> tensor { } // CHECK-LABEL: "type_ui64" +func.func @type_f8E4M3FN(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "type_f8E4M3FN" + +func.func @type_f8E5M2(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "type_f8E5M2" + func.func @type_bf16(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor %0 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor @@ -1775,6 +1806,25 @@ func.func @type_token_caller(%arg0: !mhlo.token) -> !mhlo.token { // CHECK: function_type = (!stablehlo.token) -> !stablehlo.token // CHECK-LABEL: "type_token_caller" +func.func @type_token_region(%arg0: tensor, %arg1: !mhlo.token) { + // CHECK: "stablehlo.while"(%arg1) ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG2:arg.*]]: !stablehlo.token): + // CHECK-NEXT: "stablehlo.return"(%arg0) : (tensor) -> () + // CHECK-NEXT: }, { + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG2:arg.*]]: !stablehlo.token): + // CHECK-NEXT: "stablehlo.return"(%[[ARG2]]) : (!stablehlo.token) -> () + // CHECK-NEXT: }) : (!stablehlo.token) -> !stablehlo.token + %0 = "mhlo.while"(%arg1) ({ + ^bb0(%arg2: !mhlo.token): + mhlo.return %arg0 : tensor + }, { + ^bb0(%arg2: !mhlo.token): + mhlo.return %arg2 : !mhlo.token + }) : (!mhlo.token) -> !mhlo.token + return +} +// CHECK-LABEL: "type_token_region" + func.func @type_tuple(%arg0: tuple>) -> tuple { %0 = "mhlo.custom_call"(%arg0) { call_target_name = "foo" @@ -1786,17 +1836,8 @@ func.func @type_tuple(%arg0: tuple>) -> tuple { // ============ NEGATIVE TESTS ============ // Some ops, attributes and types used in MHLO programs are not supported in StableHLO. -// For those cases, we have negative tests below. 
- -// ----- - -func.func @attr_precision_config_packed_nibble(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { - // expected-error@+1 {{failed to legalize operation 'mhlo.dot' that was explicitly marked illegal}} - %0 = "mhlo.dot"(%arg0, %arg1) { - precision_config = [#mhlo] - } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> - func.return %0 : tensor<8x8xf32> -} +// The following features are private, and not convertable to StableHLO even +// with the experimental flag. // ----- @@ -1881,17 +1922,29 @@ func.func @op_copy(%arg0: tensor) -> tensor { // ----- -func.func @op_custom_call_output_operand_aliases(%arg0: tensor) -> tensor { +func.func @op_convolution_unknown_dimension_numbers(%arg0: tensor<1x8x8x32x207xf32>, %arg1: tensor<3x3x32x207x16xf32>) -> tensor<32x1x8x8x16xf32> { + // expected-error@+1 {{failed to legalize operation 'mhlo.convolution' that was explicitly marked illegal}} + %0 = "mhlo.convolution"(%arg0, %arg1) { + window_strides = dense<1> : tensor<2xi64>, + padding = dense<1> : tensor<2x2xi64>, + lhs_dilation = dense<1> : tensor<2xi64>, + rhs_dilation = dense<1> : tensor<2xi64>, + window_reversal = dense : tensor<2xi1>, + dimension_numbers = #mhlo.conv<[b, 0, 1, ?, f]x[0, 1, ?, i, o]->[?, b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64, + precision_config = [#mhlo, #mhlo] + } : (tensor<1x8x8x32x207xf32>, tensor<3x3x32x207x16xf32>) -> tensor<32x1x8x8x16xf32> + func.return %0 : tensor<32x1x8x8x16xf32> +} + +// ----- + +func.func @op_custom_call_custom_call_schedule(%arg0: tensor) -> tensor { // expected-error@+1 {{failed to legalize operation 'mhlo.custom_call' that was explicitly marked illegal}} %0 = "mhlo.custom_call"(%arg0) { call_target_name = "foo", - output_operand_aliases = [ - #mhlo.output_operand_alias< - output_tuple_indices = [], - operand_index = 0, - operand_tuple_indices = [] - > - ] + custom_call_schedule = #mhlo } : (tensor) -> tensor func.return %0 : tensor } @@ -1916,21 +1969,18 @@ func.func @op_fusion(%arg0: tensor) -> tensor { ^bb0(%arg1: tensor): "mhlo.return"(%arg1) : (tensor) -> () }) { - fusion_kind = #mhlo + fusion_kind = #mhlo, + output_operand_aliases = [ + #mhlo.output_operand_alias + ] } : (tensor) -> tensor func.return %0 : tensor } // ----- -func.func @op_partition_id() -> tensor { - // expected-error@+1 {{failed to legalize operation 'mhlo.partition_id' that was explicitly marked illegal}} - %0 = "mhlo.partition_id"() : () -> tensor - func.return %0 : tensor -} - -// ----- - func.func @op_stochastic_convert(%arg0: tensor, %arg1: tensor) -> tensor { // expected-error@+1 {{failed to legalize operation 'mhlo.stochastic_convert' that was explicitly marked illegal}} %0 = "mhlo.stochastic_convert"(%arg0, %arg1) : (tensor, tensor) -> tensor diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/invalid.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/invalid.mlir index 917250012ac..fdbca0a67fa 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/invalid.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/invalid.mlir @@ -111,6 +111,15 @@ func.func @complex_type_not_complex(%arg0: tensor<1xf64>) -> () { // ----- +func.func @dense_array_nested(%arg0: tensor<1x2xf64>) -> () { + // expected-error @+2 {{custom op 'stablehlo.transpose' expected integer value}} + // expected-error @+1 {{expected ']'}} + %0 = stablehlo.transpose %arg0, dims = [1, [0]] : tensor<1xf64> + func.return +} + +// ----- + func.func @select_type_wrong_type(%arg0: tensor<2x3xi1>, 
%arg1: tensor<2x3xi32>) -> () { // expected-error @+1 {{custom op 'mhlo.select' expected functional type or list of two types}} %0 = mhlo.select %arg0, %arg1, %arg1 : tensor<2x3xi1> @@ -132,3 +141,59 @@ func.func @pairwise_type_not_type(%arg0: tensor<1xf64>) -> tensor<1xf64> { %0 = mhlo.select %arg0, %arg1, %arg1 : %arg0 func.return %0 : tensor<1xf64> } + +// ----- + +func.func @reduce_precision_no_e_num(%arg0: tensor<3x4xf32>) -> (tensor<3x4xf32>) { + // expected-error @+1 {{custom op 'mhlo.reduce_precision' expected exponent mantissa in format e#m#, saw em2}} + %0 = mhlo.reduce_precision %arg0, format = em2 : tensor<3x4xf32> + func.return %0 : tensor +} + +// ----- + +func.func @reduce_precision_not_literal(%arg0: tensor<3x4xf32>) -> (tensor<3x4xf32>) { + // expected-error @+1 {{custom op 'mhlo.reduce_precision' expected valid keyword}} + %0 = mhlo.reduce_precision %arg0, format = "e2m2" : tensor<3x4xf32> + func.return %0 : tensor +} + +// ----- + +func.func @reduce_precision_no_em(%arg0: tensor<3x4xf32>) -> (tensor<3x4xf32>) { + // expected-error @+1 {{custom op 'mhlo.reduce_precision' expected exponent mantissa in format e#m#, saw z4f2}} + %0 = mhlo.reduce_precision %arg0, format = z4f2 : tensor<3x4xf32> + func.return %0 : tensor +} + +// ----- + +func.func @reduce_precision_em_order(%arg0: tensor<3x4xf32>) -> (tensor<3x4xf32>) { + // expected-error @+1 {{custom op 'mhlo.reduce_precision' expected exponent mantissa in format e#m#, saw m2e2}} + %0 = mhlo.reduce_precision %arg0, format = m2e2 : tensor<3x4xf32> + func.return %0 : tensor +} + +// ----- + +func.func @reduce_precision_no_e(%arg0: tensor<3x4xf32>) -> (tensor<3x4xf32>) { + // expected-error @+1 {{custom op 'mhlo.reduce_precision' expected exponent mantissa in format e#m#, saw m2}} + %0 = mhlo.reduce_precision %arg0, format = m2 : tensor<3x4xf32> + func.return %0 : tensor +} + +// ----- + +func.func @reduce_precision_no_m(%arg0: tensor<3x4xf32>) -> (tensor<3x4xf32>) { + // expected-error @+1 {{custom op 'mhlo.reduce_precision' expected exponent mantissa in format e#m#, saw e2}} + %0 = mhlo.reduce_precision %arg0, format = e2 : tensor<3x4xf32> + func.return %0 : tensor +} + +// ----- + +func.func @reduce_precision_no_m_num(%arg0: tensor<3x4xf32>) -> (tensor<3x4xf32>) { + // expected-error @+1 {{custom op 'mhlo.reduce_precision' expected exponent mantissa in format e#m#, saw e2m}} + %0 = mhlo.reduce_precision %arg0, format = e2m : tensor<3x4xf32> + func.return %0 : tensor +} \ No newline at end of file diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/legalize-mhlo-to-thlo.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/legalize-mhlo-to-thlo.mlir index 6dc9eddd2f2..061a4591738 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/legalize-mhlo-to-thlo.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/legalize-mhlo-to-thlo.mlir @@ -11,9 +11,9 @@ func.func @dynamic_broadcast_in_dim(%arg : tensor, %shape : tensor<3xin // CHECK-DAG: %[[SHAPE_D2:.*]] = tensor.extract %[[SHAPE]][%[[C2]]] // CHECK-DAG: %[[INIT:.*]] = tensor.empty(%[[SHAPE_D0]], %[[SHAPE_D1]], %[[SHAPE_D2]]) : tensor // CHECK-NEXT: %[[BCAST:.*]] = thlo.dynamic_broadcast_in_dim - // CHECK-SAME: ins(%[[ARG]] : tensor) - // CHECK-SAME: outs(%[[INIT]] : tensor) - // CHECK-SAME: broadcast_dimensions = [0, 2] + // CHECK-SAME: ins(%[[ARG]] : tensor) + // CHECK-SAME: outs(%[[INIT]] : tensor) + // CHECK-SAME: broadcast_dimensions = [0, 2] // CHECK: return %[[BCAST]] %0 = "mhlo.dynamic_broadcast_in_dim"(%arg, %shape) { 
broadcast_dimensions = dense<[0, 2]> : tensor<2xi64> } @@ -48,9 +48,9 @@ func.func @dynamic_broadcast_in_dim_with_known_expanding(%arg : tensor // CHECK-NEXT: %[[BCAST:.*]] = thlo.dynamic_broadcast_in_dim - // CHECK-SAME: ins(%[[ARG]] : tensor) - // CHECK-SAME: outs(%[[INIT]] : tensor) - // CHECK-SAME: broadcast_dimensions = [0, 2, 3] + // CHECK-SAME: ins(%[[ARG]] : tensor) + // CHECK-SAME: outs(%[[INIT]] : tensor) + // CHECK-SAME: broadcast_dimensions = [0, 2, 3] // CHECK-SAME: {known_expanding_dimensions = array, known_nonexpanding_dimensions = array} // CHECK: return %[[BCAST]] %0 = "mhlo.dynamic_broadcast_in_dim"(%arg, %shape) { @@ -74,9 +74,9 @@ func.func @concatenate(%a: tensor, %b: tensor, %c: tensor, %[[B]] : tensor, %[[C]] : tensor) - // CHECK-SAME: outs(%[[INIT]] : tensor) - // CHECK-SAME: {dimension = 1 : i64} + // CHECK-SAME: ins(%[[A]] : tensor, %[[B]] : tensor, %[[C]] : tensor) + // CHECK-SAME: outs(%[[INIT]] : tensor) + // CHECK-SAME: dimension = 1 // CHECK: return %[[CONCATENATE]] %concat = "mhlo.concatenate"(%a, %b, %c) { dimension = 1 } : (tensor, tensor, tensor) -> tensor func.return %concat : tensor @@ -91,9 +91,9 @@ func.func @concatenate_with_static_info(%a: tensor<64x32xi32>, %b: tensor<64x16x // CHECK-DAG: %[[CONCAT_DIM_SUM:.*]] = arith.addi %[[CONCAT_DIM_C]], %[[C48]] // CHECK-DAG: %[[INIT:.*]] = tensor.empty(%[[CONCAT_DIM_SUM]]) // CHECK: %[[CONCAT:.*]] = thlo.concatenate - // CHECK-SAME: ins(%[[A]] : tensor<64x32xi32>, %[[B]] : tensor<64x16xi32>, %[[C]] : tensor<64x?xi32>) - // CHECK-SAME: outs(%[[INIT]] : tensor<64x?xi32>) - // CHECK-SAME: {dimension = 1 : i64} + // CHECK-SAME: ins(%[[A]] : tensor<64x32xi32>, %[[B]] : tensor<64x16xi32>, %[[C]] : tensor<64x?xi32>) + // CHECK-SAME: outs(%[[INIT]] : tensor<64x?xi32>) + // CHECK-SAME: dimension = 1 // CHECK: return %[[CONCAT]] %concat = "mhlo.concatenate"(%a, %b, %c) { dimension = 1 } : (tensor<64x32xi32>, tensor<64x16xi32>, tensor<64x?xi32>) -> tensor<64x?xi32> func.return %concat : tensor<64x?xi32> @@ -184,7 +184,7 @@ func.func @gather_dynamic( // CHECK: %[[DIM:.*]] = tensor.dim {{.*}} %[[C0]] : tensor // CHECK: %[[INIT:.*]] = tensor.empty(%dim) : tensor // CHECK: thlo.gather -// CHECK-SAME: outs(%[[INIT]] : tensor) +// CHECK-SAME: outs(%[[INIT]] : tensor) func.func @unsupported_gather(%operand: tensor<3x3xf32>, %indices: tensor<3x2xi64>) -> tensor<3xf32> { @@ -227,10 +227,11 @@ func.func @simple_scatter(%dst: tensor<3x3xf32>, %indices: tensor<2x2xi32>, // CHECK-SAME: (%[[DST:.*]]: tensor<3x3xf32>, %[[INDICES:.*]]: tensor<2x2xi32>, // CHECK-SAME: %[[UPDATE:.*]]: tensor<2x1x3xf32>) // CHECK: %[[CAST:.*]] = arith.index_cast %[[INDICES]] {{.*}} to tensor<2x2xindex> -// CHECK: thlo.scatter ins(%[[CAST]] : tensor<2x2xindex>, -// CHECK-SAME: %[[UPDATE]] : tensor<2x1x3xf32>) -// CHECK-SAME: outs(%[[DST]] : tensor<3x3xf32>) -// CHECK-SAME: (%[[UPD:.*]]: f32, %[[CUR:.*]]: f32) { +// CHECK: thlo.scatter +// CHECK-SAME: ins(%[[CAST]] : tensor<2x2xindex>, +// CHECK-SAME: %[[UPDATE]] : tensor<2x1x3xf32>) +// CHECK-SAME: outs(%[[DST]] : tensor<3x3xf32>) +// CHECK-NEXT: (%[[UPD:.*]]: f32, %[[CUR:.*]]: f32) { // CHECK-NEXT: %[[CUR_T:.*]] = tensor.from_elements %[[CUR]] : tensor // CHECK-NEXT: %[[UPD_T:.*]] = tensor.from_elements %[[UPD]] : tensor // CHECK-NEXT: %[[CUR:.*]] = tensor.extract %[[CUR_T]][] : tensor @@ -261,8 +262,8 @@ func.func @sort(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { // CHECK: thlo.sort // CHECK-SAME: ins(%[[IN0]] : tensor<16x16xf32>, %[[IN1]] : tensor<16x16xi32>) // CHECK-SAME: 
outs(%[[INIT0]] : tensor<16x16xf32>, %[[INIT1]] : tensor<16x16xi32>) -// CHECK-DAG: dimension = 1 : i64 -// CHECK-DAG: is_stable = true +// CHECK-SAME: dimension = 1 +// CHECK-SAME: is_stable = true // CHECK: (%[[FLOAT0:.*]]: f32, %[[FLOAT1:.*]]: f32, %[[INT0:.*]]: i32, %[[INT1:.*]]: i32) // CHECK-DAG: %[[TENSOR0:.*]] = tensor.from_elements %[[FLOAT0]] : tensor // CHECK-DAG: %[[TENSOR1:.*]] = tensor.from_elements %[[FLOAT1]] : tensor @@ -272,3 +273,39 @@ func.func @sort(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { // CHECK-NEXT: %[[RESULT:.*]] = tensor.from_elements %[[CMPRESULT]] : tensor // CHECK-NEXT: %[[EXTRACTED_RESULT:.*]] = tensor.extract %[[RESULT]][] : tensor // CHECK-NEXT: thlo.yield %[[EXTRACTED_RESULT]] : i1 + +func.func @reverse_static(%input: tensor<100xf32>) + -> tensor<100xf32> { + %res = "mhlo.reverse"(%input) {dimensions = dense<[0]> : tensor<1xi64>} : + (tensor<100xf32>) -> tensor<100xf32> + func.return %res : tensor<100xf32> +} + +// CHECK-LABEL: func @reverse_static +// CHECK-SAME: (%[[ARG0:.*]]: tensor<100xf32>) -> tensor<100xf32> +// CHECK: %[[EMPTY:.*]] = tensor.empty +// CHECK: %[[REVERSED:.*]] = thlo.reverse +// CHECK-SAME: ins(%[[ARG0]] +// CHECK-SAME: outs(%[[EMPTY]] +// CHECK-SAME: reverse_dimensions = [0] +// CHECK-NEXT: return %[[REVERSED]] + +func.func @reverse_dynamic(%input: tensor) + -> tensor { + %res = "mhlo.reverse"(%input) {dimensions = dense<[0, 1]> : tensor<2xi64>} : + (tensor) -> tensor + func.return %res : tensor +} + +// CHECK-LABEL: func @reverse_dynamic +// CHECK-SAME: (%[[ARG0:.*]]: tensor) -> tensor +// CHECK: %[[C0:.*]] = arith.constant +// CHECK: %[[DIM0:.*]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK: %[[C1:.*]] = arith.constant +// CHECK: %[[DIM1:.*]] = tensor.dim %[[ARG0]], %[[C1]] +// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM0]], %[[DIM1]]) +// CHECK: %[[REVERSED:.*]] = thlo.reverse +// CHECK-SAME: ins(%[[ARG0]] +// CHECK-SAME: outs(%[[EMPTY]] +// CHECK-SAME: reverse_dimensions = [0, 1] +// CHECK-NEXT: return %[[REVERSED]] \ No newline at end of file diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/lower-general-dot.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/lower-general-dot.mlir index 4f0eeb60054..175a1820cb5 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/lower-general-dot.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/lower-general-dot.mlir @@ -39,7 +39,7 @@ func.func @testDebatch2(%arg0: tensor<2x3xf32>, %arg1: tensor<1x1x2xf32>) -> ten // ----- // CHECK-LABEL: @testBatchPassthrough -func.func @testBatchPassthrough(%arg0: tensor<2x2x3xf32>, %arg1: tensor<2x1x2xf32>) -> tensor<3x2x1xf32> { +func.func @testBatchPassthrough(%arg0: tensor<2x2x3xf32>, %arg1: tensor<2x1x2xf32>) -> tensor<2x3x1xf32> { // CHECK-NEXT: "mhlo.dot_general"(%arg0, %arg1) %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #mhlo.dot< @@ -49,8 +49,8 @@ func.func @testBatchPassthrough(%arg0: tensor<2x2x3xf32>, %arg1: tensor<2x1x2xf3 rhs_contracting_dimensions = [2] >, precision_config = [#mhlo, #mhlo] - } : (tensor<2x2x3xf32>, tensor<2x1x2xf32>) -> tensor<3x2x1xf32> - func.return %0 : tensor<3x2x1xf32> + } : (tensor<2x2x3xf32>, tensor<2x1x2xf32>) -> tensor<2x3x1xf32> + func.return %0 : tensor<2x3x1xf32> } // ----- @@ -111,18 +111,22 @@ func.func @dot_general_to_dot_dynamic(%arg0: tensor<128x4x?x32xf32>, %arg1: tens // CHECK-DAG: %[[C8:.+]] = mhlo.constant dense<8> : tensor<1xi32> // CHECK-DAG: %[[TRANS0:.+]] = "mhlo.transpose"(%arg0) {permutation = dense<[2, 3, 0, 1]> : 
tensor<4xi64>} // CHECK-DAG: %[[DIM0:.+]] = "mhlo.get_dimension_size"(%arg0) {dimension = 2 : i64} -// CHECK-DAG: %[[MUL0:.+]] = mhlo.multiply %[[DIM0]], %[[C32]] +// CHECK-DAG: %[[RESHAPE0:.+]] = mhlo.reshape %[[DIM0]] : (tensor) -> tensor<1xi32> +// CHECK-DAG: %[[MUL0:.+]] = mhlo.multiply %[[RESHAPE0]], %[[C32]] // CHECK-DAG: %[[CONCAT1:.+]] = "mhlo.concatenate"(%[[MUL0]], %[[C512]]) {dimension = 0 : i64} // CHECK-DAG: %[[DR1:.+]] = mhlo.dynamic_reshape %[[TRANS0]], %[[CONCAT1]] // CHECK-DAG: %[[TRANS1:.+]] = "mhlo.transpose"(%arg1) {permutation = dense<[2, 3, 0, 1]> : tensor<4xi64>} // CHECK-DAG: %[[DIM1:.+]] = "mhlo.get_dimension_size"(%arg1) {dimension = 1 : i64} -// CHECK-DAG: %[[MUL1:.+]] = mhlo.multiply %[[DIM1]], %[[C8]] +// CHECK-DAG: %[[RESHAPE1:.+]] = mhlo.reshape %[[DIM1]] : (tensor) -> tensor<1xi32> +// CHECK-DAG: %[[MUL1:.+]] = mhlo.multiply %[[RESHAPE1]], %[[C8]] // CHECK-DAG: %[[CONCAT2:.+]] = "mhlo.concatenate"(%[[C512]], %[[MUL1]]) {dimension = 0 : i64} // CHECK-DAG: %[[DR2:.+]] = mhlo.dynamic_reshape %[[TRANS1]], %[[CONCAT2]] // CHECK-DAG: %[[DOT:.+]] = "mhlo.dot"(%[[DR1:.+]], %[[DR2:.+]]) // CHECK-DAG: %[[DIM2:.+]] = "mhlo.get_dimension_size"(%arg0) {dimension = 2 : i64} +// CHECK-DAG: %[[RESHAPE2:.+]] = mhlo.reshape %[[DIM2]] : (tensor) -> tensor<1xi32> // CHECK-DAG: %[[DIM3:.+]] = "mhlo.get_dimension_size"(%arg1) {dimension = 1 : i64} -// CHECK-DAG: %[[CONCAT3:.+]] = "mhlo.concatenate"(%[[DIM2]], %[[C32]], %[[C8]], %[[DIM3]]) {dimension = 0 : i64} +// CHECK-DAG: %[[RESHAPE3:.+]] = mhlo.reshape %[[DIM3]] : (tensor) -> tensor<1xi32> +// CHECK-DAG: %[[CONCAT3:.+]] = "mhlo.concatenate"(%[[RESHAPE2]], %[[C32]], %[[C8]], %[[RESHAPE3]]) {dimension = 0 : i64} // CHECK-DAG: %[[DR3:.+]] = mhlo.dynamic_reshape %[[DOT]], %[[CONCAT3]] // CHECK: return %[[DR3]] diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_canonicalize_scatter.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_canonicalize_scatter.mlir index 9cf54592be4..653bd3cf913 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_canonicalize_scatter.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_canonicalize_scatter.mlir @@ -201,3 +201,34 @@ func.func @zero_dim_scatter_indices(%dst: tensor<4x4xf32>, // CHECK: update_window_dims = [1, 2], // CHECK-SAME: scatter_dims_to_operand_dims = [0, 1] // CHECK-SAME: index_vector_dim = 1 + +// ----- + +func.func @multiple_window_and_scatter_dims( + %dst: tensor<1x2x3x4x5xf32>, %indices: tensor<6x7x2xi32>, + %updates: tensor<2x6x4x7xf32>) -> tensor<1x2x3x4x5xf32> { + %0 = "mhlo.scatter"(%dst, %indices, %updates) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + mhlo.return %arg3 : tensor + }) { + indices_are_sorted = false, + scatter_dimension_numbers = #mhlo.scatter< + inserted_window_dims = [0, 2, 4], + update_window_dims = [0, 2], + scatter_dims_to_operand_dims = [0, 1], + index_vector_dim = 2 + >, unique_indices = false + } : (tensor<1x2x3x4x5xf32>, tensor<6x7x2xi32>, tensor<2x6x4x7xf32>) -> + tensor<1x2x3x4x5xf32> + return %0 : tensor<1x2x3x4x5xf32> +} + +// CHECK-LABEL: @multiple_window_and_scatter_dims( +// CHECK-SAME: %[[DST:.*]]: tensor<1x2x3x4x5xf32>, +// CHECK-SAME: %[[IND:.*]]: tensor<6x7x2xi32>, +// CHECK-SAME: %[[UPD:.*]]: tensor<2x6x4x7xf32> +// CHECK: %[[IND0:.*]] = tensor.collapse_shape %[[IND]] {{.*}} into tensor<42x2xi32> +// CHECK: %[[UPD0:.*]] = "mhlo.transpose"(%[[UPD]]) {{.*}} -> tensor<6x7x2x4xf32> +// CHECK: %[[UPD1:.*]] = tensor.collapse_shape %[[UPD0]] {{.*}} into tensor<42x2x4xf32> +// CHECK: 
%[[UPD2:.*]] = tensor.expand_shape %[[UPD1]] {{.*}} into tensor<42x1x2x1x4x1xf32> +// CHECK: "mhlo.scatter"(%[[DST]], %[[IND0]], %[[UPD2]]) \ No newline at end of file diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_flatten_tuple.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_flatten_tuple.mlir index f38c43f418f..a6853afc8f2 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_flatten_tuple.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_flatten_tuple.mlir @@ -2,8 +2,8 @@ // CHECK-LABEL: @custom_call // CHECK-SAME: %[[X:.*]]: tensor<6x3xf32> -// CHECK: %[[CALL:.+]]:2 = "mhlo.custom_call"(%[[X]]) {api_version = 2 : i32, call_target_name = "f"} : (tensor<6x3xf32>) -> (tensor<6xf32>, tensor<3xf32>) -// CHECK: return %[[CALL]]#0, %[[CALL]]#1 : tensor<6xf32>, tensor<3xf32> +// CHECK: %[[CALL:.+]]:2 = mhlo.custom_call @f(%[[X]]) {api_version = 2 : i32} : (tensor<6x3xf32>) -> (tensor<6xf32>, tensor<3xf32>) +// CHECK: return %[[CALL]]#0, %[[CALL]]#1 : tensor<6xf32>, tensor<3xf32> func.func @custom_call(%x: tensor<6x3xf32>) -> (tensor<6xf32>, tensor<3xf32>) { %0 = "mhlo.custom_call"(%x) {api_version = 2 : i32, call_target_name = "f"} : (tensor<6x3xf32>) -> tuple, tensor<3xf32>> diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_infer_shape_type_methods.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_infer_shape_type_methods.mlir index f7791080074..cc2ee6aff7a 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_infer_shape_type_methods.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_infer_shape_type_methods.mlir @@ -1,20 +1,5 @@ // RUN: mlir-hlo-opt --mhlo-test-infer-shaped-type-methods --allow-unregistered-dialect --split-input-file --verify-diagnostics %s | FileCheck %s -// CHECK-LABEL: @select -// CHECK-SAME: (%{{.*}}: tensor, %[[SHAPED_ARG:.*]]: tensor<2x?xf32>, %{{.*}}: tensor<2x?xf32> -func.func @select(%pred : tensor, %a : tensor<2x?xf32>, %b : tensor<2x?xf32>) - -> tensor<2xindex> { - // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[SHAPED_ARG]] : tensor<2x?xf32> -> tensor<2xindex> - // CHECK: return %[[SHAPE]] : tensor<2xindex> - %0 = "mhlo.select"(%pred, %a, %b) - : (tensor, tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xf32> - %1 = "mhlo_test.reify_return_type_shapes"(%0) - : (tensor<2x?xf32>) -> tensor<2xindex> - func.return %1 : tensor<2xindex> -} - -// ----- - // CHECK-LABEL: @compare // CHECK-SAME: (%[[A:.*]]: tensor<2x?xf32>, func.func @compare(%a : tensor<2x?xf32>, %b : tensor<2x?xf32>) -> tensor<2xindex> { @@ -30,14 +15,14 @@ func.func @compare(%a : tensor<2x?xf32>, %b : tensor<2x?xf32>) -> tensor<2xindex // ----- // CHECK-LABEL: @select -func.func @select(%pred : tensor, %a : tensor<2x2xf32>, %b : tensor<2x2xf32>) - -> tensor<2x2xindex> { +func.func @select(%pred : tensor, %a : tensor, %b : tensor<1x?x3xf32>) + -> tensor<1x2x3xindex> { %0 = "mhlo.select"(%pred, %a, %b) - : (tensor, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + : (tensor, tensor, tensor<1x?x3xf32>) -> tensor<*xf32> %1 = "mhlo_test.get_return_type_components"(%0) - : (tensor<2x2xf32>) -> tensor<2x2xindex> -// CHECK: %1 = "mhlo_test.return_type_components"(%0) {dims0 = [2, 2], element_type0 = f32} : (tensor<2x2xf32>) -> tensor<2x2xindex> - func.return %1 : tensor<2x2xindex> + : (tensor<*xf32>) -> tensor<1x2x3xindex> + // CHECK: %1 = "mhlo_test.return_type_components"(%0) {dims0 = [1, 2, 3], element_type0 = f32} : (tensor<*xf32>) -> tensor<1x2x3xindex> + func.return 
%1 : tensor<1x2x3xindex> } // ----- @@ -87,16 +72,42 @@ func.func @dynamic_slice(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tens // ----- // CHECK-LABEL: @pad -func.func @pad(%arg0: tensor<1x2x3xf16>, %arg1: tensor) -> tensor<2x4x7xf16> { +func.func @pad(%arg0: tensor<1x2x3xf16>, %arg1: tensor) -> tensor<2x4x7xindex> { %0 = "mhlo.pad"(%arg0, %arg1) { edge_padding_high = dense<[1, 1, 0]> : tensor<3xi64>, edge_padding_low = dense<[0, 1, 2]> : tensor<3xi64>, interior_padding = dense<[0, 0, 1]> : tensor<3xi64> } : (tensor<1x2x3xf16>, tensor) -> tensor<2x4x7xf16> - %1 = "mhlo_test.get_return_type_components"(%0) - : (tensor<2x4x7xf16>) -> tensor<2x4x7xindex> -// CHECK: %1 = "mhlo_test.return_type_components"(%0) {dims0 = [2, 4, 7], element_type0 = f16} : (tensor<2x4x7xf16>) -> tensor<2x4x7xindex> - func.return %0 : tensor<2x4x7xf16> + %1 = "mhlo_test.get_return_type_components"(%0) : (tensor<2x4x7xf16>) -> tensor<2x4x7xindex> +// CHECK: %1 = "mhlo_test.get_return_type_components"(%0) : (tensor<2x4x7xf16>) -> tensor<2x4x7xindex> + func.return %1 : tensor<2x4x7xindex> +} + +// ----- + +// CHECK-LABEL: @pad_with_bounds +func.func @pad_with_bounds(%arg0: tensor<3x?x?xf16, #mhlo.type_extensions>, %arg1: tensor) -> tensor<*xindex> { + %0 = "mhlo.pad"(%arg0, %arg1) { + edge_padding_low = dense<[2, 2, 0]> : tensor<3xi64>, + edge_padding_high = dense<[0, 0, 0]> : tensor<3xi64>, + interior_padding = dense<[1, 1, 1]> : tensor<3xi64> + } : (tensor<3x?x?xf16, #mhlo.type_extensions>, tensor) -> tensor<*xf16> + %1 = "mhlo_test.get_return_type_components"(%0) : (tensor<*xf16>) -> tensor<*xindex> + // CHECK: %1 = "mhlo_test.get_return_type_components"(%0) : (tensor<*xf16>) -> tensor<*xindex> + func.return %1 : tensor<*xindex> +} + +// ----- + +func.func @pad_with_negative_inferred_bounds(%arg0: tensor<3x?x?xf16, #mhlo.type_extensions>, %arg1: tensor) -> tensor<*xindex> { + // expected-error@+1 {{Padding result in negative bound for dimension 1}} + %0 = "mhlo.pad"(%arg0, %arg1) { + edge_padding_low = dense<[2, -10, 0]> : tensor<3xi64>, + edge_padding_high = dense<[0, 0, 0]> : tensor<3xi64>, + interior_padding = dense<[1, 1, 1]> : tensor<3xi64> + } : (tensor<3x?x?xf16, #mhlo.type_extensions>, tensor) -> tensor<*xf16> + %1 = "mhlo_test.get_return_type_components"(%0) : (tensor<*xf16>) -> tensor<*xindex> + func.return %1 : tensor<*xindex> } // ----- @@ -104,9 +115,8 @@ func.func @pad(%arg0: tensor<1x2x3xf16>, %arg1: tensor) -> tensor<2x4x7xf16 // CHECK-LABEL: @cholesky func.func @cholesky(%arg0: tensor<1x2x2xf32>) -> tensor<1x2x2xindex> { %0 = "mhlo.cholesky"(%arg0) { lower = true } : (tensor<1x2x2xf32>) -> tensor<1x2x2xf32> - %1 = "mhlo_test.get_return_type_components"(%0) - : (tensor<1x2x2xf32>) -> tensor<1x2x2xindex> -// CHECK: %1 = "mhlo_test.return_type_components"(%0) {dims0 = [1, 2, 2], element_type0 = f32} : (tensor<1x2x2xf32>) -> tensor<1x2x2xindex> + %1 = "mhlo_test.get_return_type_components"(%0) : (tensor<1x2x2xf32>) -> tensor<1x2x2xindex> + // CHECK: %1 = "mhlo_test.return_type_components"(%0) {dims0 = [1, 2, 2], element_type0 = f32} : (tensor<1x2x2xf32>) -> tensor<1x2x2xindex> func.return %1: tensor<1x2x2xindex> } @@ -128,6 +138,21 @@ func.func @alltoall(%data: tensor<4x16xf32>) -> tensor<16x4xindex> { // ----- +// CHECK-LABEL: func @alltoall_bounds +func.func @alltoall_bounds(%data: tensor<16x?xf32, #mhlo.type_extensions>) -> tensor<*xindex> { + %0 = "mhlo.all_to_all"(%data) { + split_dimension = 0 : i64, + concat_dimension = 1 : i64, + split_count = 4 : i64, + replica_groups = dense<[[0, 1, 2, 
3]]> : tensor<1x4xi64> + } : (tensor<16x?xf32, #mhlo.type_extensions>) -> tensor<*xf32> + // CHECK: types0 = tensor<4x?xf32, #mhlo.type_extensions> + %1 = "mhlo_test.get_return_types"(%0) : (tensor<*xf32>) -> tensor<*xindex> + func.return %1 : tensor<*xindex> +} + +// ----- + // CHECK-LABEL: func @abs func.func @abs(%arg0: tensor<1x2xf32>) -> tensor<1x2xindex> { %0 = "mhlo.abs"(%arg0) {} : (tensor<1x2xf32>) -> tensor<1x2xf32> @@ -140,7 +165,7 @@ func.func @abs(%arg0: tensor<1x2xf32>) -> tensor<1x2xindex> { // ----- // CHECK-LABEL: @concat -func.func @concat(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<3xindex> { +func.func @concat(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<3xindex> { %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32> %1 = "mhlo_test.get_return_type_components"(%0) : (tensor<3xi32>) -> tensor<3xindex> @@ -150,6 +175,175 @@ func.func @concat(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<3xindex // ----- +// ----- + +// Inference rules to concat dimensions with bounds (lhs/rhs are commutative): +// Dim of lhs Dim of rhs Infer +// c0: X Y X+Y +// c1: X ? ? +// c2: X ?, B ?, X+B +// c3: ? ? ? +// c4: ? ?, B ? +// c5: ?, B ?, C ?, B+C + +// CHECK-LABEL: @concat_bounds_c0 +func.func @concat_bounds_c0( + %arg0: tensor<5x1xi32, #mhlo.type_extensions>, + %arg1: tensor<5x2xi32, #mhlo.type_extensions>) -> tensor<*xindex> { + %result = "mhlo.concatenate"(%arg0, %arg1) { dimension = 1 : i64 } : ( + tensor<5x1xi32, #mhlo.type_extensions>, + tensor<5x2xi32, #mhlo.type_extensions>) -> tensor + // CHECK: types0 = tensor<5x3xi32> + %1 = "mhlo_test.get_return_types"(%result) : (tensor) -> tensor<*xindex> + func.return %1 : tensor<*xindex> +} + +// ----- + +// CHECK-LABEL: @concat_bounds_c1 +func.func @concat_bounds_c1( + %arg0: tensor<5x2xi32, #mhlo.type_extensions>, + %arg1: tensor<5x?xi32, #mhlo.type_extensions>) -> tensor<*xindex> { + %result = "mhlo.concatenate"(%arg0, %arg1) { dimension = 1 : i64 } : ( + tensor<5x2xi32, #mhlo.type_extensions>, + tensor<5x?xi32, #mhlo.type_extensions>) -> tensor + // CHECK: types0 = tensor<5x?xi32> + %1 = "mhlo_test.get_return_types"(%result) : (tensor) -> tensor<*xindex> + + %result_swap = "mhlo.concatenate"(%arg1, %arg0) { dimension = 1 : i64 } : ( + tensor<5x?xi32, #mhlo.type_extensions>, + tensor<5x2xi32, #mhlo.type_extensions>) -> tensor + // CHECK: types0 = tensor<5x?xi32> + %2 = "mhlo_test.get_return_types"(%result_swap) : (tensor) -> tensor<*xindex> + + func.return %1 : tensor<*xindex> +} + +// ----- + +// CHECK-LABEL: @concat_bounds_c2 +func.func @concat_bounds_c2( + %arg0: tensor<5x2xi32, #mhlo.type_extensions>, + %arg1: tensor<5x?xi32, #mhlo.type_extensions>) -> tensor<*xindex> { + %result = "mhlo.concatenate"(%arg0, %arg1) { dimension = 1 : i64 } : ( + tensor<5x2xi32, #mhlo.type_extensions>, + tensor<5x?xi32, #mhlo.type_extensions>) -> tensor + // CHECK: types0 = tensor<5x?xi32, #mhlo.type_extensions> + %1 = "mhlo_test.get_return_types"(%result) : (tensor) -> tensor<*xindex> + + %result_swap = "mhlo.concatenate"(%arg1, %arg0) { dimension = 1 : i64 } : ( + tensor<5x?xi32, #mhlo.type_extensions>, + tensor<5x2xi32, #mhlo.type_extensions>) -> tensor + // CHECK: types0 = tensor<5x?xi32, #mhlo.type_extensions> + %2 = "mhlo_test.get_return_types"(%result_swap) : (tensor) -> tensor<*xindex> + + func.return %1 : tensor<*xindex> +} + +// ----- + +// CHECK-LABEL: @concat_bounds_c3 +func.func @concat_bounds_c3( + %arg0: tensor<5x?xi32, #mhlo.type_extensions>, + %arg1: 
tensor<5x?xi32, #mhlo.type_extensions>) -> tensor<*xindex> { + %result = "mhlo.concatenate"(%arg0, %arg1) { dimension = 1 : i64 } : ( + tensor<5x?xi32, #mhlo.type_extensions>, + tensor<5x?xi32, #mhlo.type_extensions>) -> tensor + // CHECK: types0 = tensor<5x?xi32> + %1 = "mhlo_test.get_return_types"(%result) : (tensor) -> tensor<*xindex> + func.return %1 : tensor<*xindex> +} + +// ----- + +// CHECK-LABEL: @concat_bounds_c4 +func.func @concat_bounds_c4( + %arg0: tensor<5x?xi32, #mhlo.type_extensions>, + %arg1: tensor<5x?xi32, #mhlo.type_extensions>) -> tensor<*xindex> { + %result = "mhlo.concatenate"(%arg0, %arg1) { dimension = 1 : i64 } : ( + tensor<5x?xi32, #mhlo.type_extensions>, + tensor<5x?xi32, #mhlo.type_extensions>) -> tensor + // CHECK: types0 = tensor<5x?xi32> + %1 = "mhlo_test.get_return_types"(%result) : (tensor) -> tensor<*xindex> + + %result_swap = "mhlo.concatenate"(%arg1, %arg0) { dimension = 1 : i64 } : ( + tensor<5x?xi32, #mhlo.type_extensions>, + tensor<5x?xi32, #mhlo.type_extensions>) -> tensor + // CHECK: types0 = tensor<5x?xi32> + %2 = "mhlo_test.get_return_types"(%result_swap) : (tensor) -> tensor<*xindex> + + func.return %1 : tensor<*xindex> +} + +// ----- + +// CHECK-LABEL: @concat_bounds_c5 +func.func @concat_bounds_c5( + %arg0: tensor<5x?xi32, #mhlo.type_extensions>, + %arg1: tensor<5x?xi32, #mhlo.type_extensions>) -> tensor<*xindex> { + %result = "mhlo.concatenate"(%arg0, %arg1) { dimension = 1 : i64 } : ( + tensor<5x?xi32, #mhlo.type_extensions>, + tensor<5x?xi32, #mhlo.type_extensions>) -> tensor + // CHECK: types0 = tensor<5x?xi32, #mhlo.type_extensions> + %1 = "mhlo_test.get_return_types"(%result) : (tensor) -> tensor<*xindex> + func.return %1 : tensor<*xindex> +} + +// ----- + +// Note: unranked input types can't be ignored, consider these input types: +// c0: (<5x?xf32>, <*xf32>) with concat dim 0 should infer +// c1: (<5x?xf32>, <*xf32>) with concat dim 1 should infer <5x?xf32> +// Instead, they should be replaced with dynamic tensors: tensor + +// CHECK-LABEL: @concat_bounds_unranked_c0 +func.func @concat_bounds_unranked_c0( + %arg0: tensor<*xi32>, + %arg1: tensor<5x?xi32, #mhlo.type_extensions>) -> tensor<*xindex> { + %result = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : ( + tensor<*xi32>, + tensor<5x?xi32, #mhlo.type_extensions>) -> tensor<5x?xi32> + // CHECK: types0 = tensor> + %1 = "mhlo_test.get_return_types"(%result) : (tensor<5x?xi32>) -> tensor<*xindex> + func.return %1 : tensor<*xindex> +} + +// ----- + +// CHECK-LABEL: @concat_bounds_unranked_c1 +func.func @concat_bounds_unranked_c1( + %arg0: tensor<*xi32>, + %arg1: tensor<5x?xi32, #mhlo.type_extensions>) -> tensor<*xindex> { + %result = "mhlo.concatenate"(%arg0, %arg1) { dimension = 1 : i64 } : ( + tensor<*xi32>, + tensor<5x?xi32, #mhlo.type_extensions>) -> tensor<5x?xi32> + // CHECK: types0 = tensor<5x?xi32> + %1 = "mhlo_test.get_return_types"(%result) : (tensor<5x?xi32>) -> tensor<*xindex> + func.return %1 : tensor<*xindex> +} + +// ----- + +// CHECK-LABEL: while_bounds +func.func @while_bounds( + %while_arg_1: tensor<2x?xi32, #mhlo.type_extensions>, + %while_arg_2: tensor<3xf32>) -> tensor<*xindex> { + %1:2 = "mhlo.while"(%while_arg_1, %while_arg_2) ({ + ^bb0(%arg1: tensor<2x?xi32, #mhlo.type_extensions>, %arg2: tensor<3xf32>): + %2 = mhlo.constant dense<1> : tensor + "mhlo.return"(%2) : (tensor) -> () + }, { + ^bb0(%arg1: tensor<2x?xi32, #mhlo.type_extensions>, %arg2: tensor<3xf32>): + "mhlo.return"(%arg1, %arg2) : (tensor<2x?xi32, #mhlo.type_extensions>, tensor<3xf32>) -> () 
+ }) : (tensor<2x?xi32, #mhlo.type_extensions>, tensor<3xf32>) -> (tensor<*xi32>, tensor<*xf32>) + // CHECK: types0 = tensor<2x?xi32, #mhlo.type_extensions>, + // CHECK-SAME: types1 = tensor<3xf32> + %3 = "mhlo_test.get_return_types"(%1) : (tensor<*xi32>) -> tensor<*xindex> + func.return %3 : tensor<*xindex> +} + +// ----- + // CHECK-LABEL: @gather func.func @gather(%operand : tensor<2x4x9xi32>, %start_indices : tensor<1x5x2xi32>) -> tensor<1x5x8xindex> { %res = "mhlo.gather"(%operand, %start_indices) { @@ -205,6 +399,25 @@ func.func @slice(%arg0: tensor<3x4xi32>) -> tensor<1x2xindex> { // ----- +// CHECK-LABEL: func @slice_with_bounds +func.func @slice_with_bounds(%arg0: tensor<3x?x?xi32, #mhlo.type_extensions>) -> tensor<*xindex> { + %0 = "mhlo.slice"(%arg0) {start_indices = dense<[1, 0, 0]> : tensor<3xi64>, limit_indices = dense<[2, 4, 4]> : tensor<3xi64>, strides = dense<[1, 2, 2]> : tensor<3xi64>} : (tensor<3x?x?xi32, #mhlo.type_extensions>) -> tensor<*xi32> + // CHECK: %1 = "mhlo_test.get_return_type_components"(%0) : (tensor<*xi32>) -> tensor<*xindex> + %1 = "mhlo_test.get_return_type_components"(%0) : (tensor<*xi32>) -> tensor<*xindex> + func.return %1 : tensor<*xindex> +} + +// ----- + +func.func @slice_with_index_larger_than_bound_dim(%arg0: tensor<3x?x?xi32, #mhlo.type_extensions>) -> tensor<*xindex> { + // expected-error@+1 {{limit index 5 is larger than dimension bound 4 in dimension 1}} + %0 = "mhlo.slice"(%arg0) {start_indices = dense<[1, 0, 0]> : tensor<3xi64>, limit_indices = dense<[2, 5, 4]> : tensor<3xi64>, strides = dense<[1, 2, 2]> : tensor<3xi64>} : (tensor<3x?x?xi32, #mhlo.type_extensions>) -> tensor<*xi32> + %1 = "mhlo_test.get_return_type_components"(%0) : (tensor<*xi32>) -> tensor<*xindex> + func.return %1 : tensor<*xindex> +} + +// ----- + // CHECK-LABEL: func @clamp func.func @clamp(%arg0: tensor<1xi32>) -> tensor<1xindex> { %0 = "mhlo.clamp"(%arg0, %arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> @@ -239,41 +452,110 @@ func.func @fft(%arg0: tensor<3x9xcomplex>) -> tensor<3x9xindex> { // ----- // CHECK-LABEL: func @batch_norm_grad -func.func @batch_norm_grad(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %mean: tensor<2xf32>, %variance: tensor<2xf32>, %grad_output: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xindex> { +func.func @batch_norm_grad(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %mean: tensor<2xf32>, %variance: tensor<2xf32>, %grad_output: tensor<2x2x2x2xf32>) -> tensor<*xindex> { %0:3 = "mhlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2x2x2x2xf32>) -> (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>) - // CHECK: (tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xindex> - %1 = "mhlo_test.get_return_type_components"(%0#0) : (tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xindex> - // CHECK: (tensor<2xf32>) -> tensor<2xindex> - %2 = "mhlo_test.get_return_type_components"(%0#1) : (tensor<2xf32>) -> tensor<2xindex> - // CHECK: (tensor<2xf32>) -> tensor<2xindex> - %3 = "mhlo_test.get_return_type_components"(%0#2) : (tensor<2xf32>) -> tensor<2xindex> - func.return %1 : tensor<2x2x2x2xindex> + // CHECK: types0 = tensor<2x2x2x2xf32> + // CHECK-SAME: types1 = tensor<2xf32> + // CHECK-SAME: types2 = tensor<2xf32> + %1 = "mhlo_test.get_return_types"(%0#0) : (tensor<2x2x2x2xf32>) -> tensor<*xindex> + func.return %1 : tensor<*xindex> } // ----- // CHECK-LABEL: func @batch_norm_train -func.func 
@batch_norm_train(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %offset: tensor<2xf32>) -> tensor<2x2x2x2xindex> { - %0:3 = "mhlo.batch_norm_training" (%input, %scale, %offset) {epsilon = 0.001 : f32, feature_index = 1 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>) - // CHECK: (tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xindex> - %1 = "mhlo_test.get_return_type_components"(%0#0) : (tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xindex> - // CHECK: (tensor<2xf32>) -> tensor<2xindex> - %2 = "mhlo_test.get_return_type_components"(%0#1) : (tensor<2xf32>) -> tensor<2xindex> - // CHECK: (tensor<2xf32>) -> tensor<2xindex> - %3 = "mhlo_test.get_return_type_components"(%0#2) : (tensor<2xf32>) -> tensor<2xindex> - func.return %1 : tensor<2x2x2x2xindex> +func.func @batch_norm_train(%input: tensor<2x?x2x2xf32>, %scale: tensor<2xf32>, %offset: tensor<2xf32>) -> tensor<*xindex> { + %0:3 = "mhlo.batch_norm_training" (%input, %scale, %offset) {epsilon = 0.001 : f32, feature_index = 1 : i64} : (tensor<2x?x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> (tensor<2x?x2x2xf32>, tensor, tensor) + // CHECK: types0 = tensor<2x?x2x2xf32> + // CHECK-SAME: types1 = tensor + // CHECK-SAME: types2 = tensor + %1 = "mhlo_test.get_return_types"(%0#0) : (tensor<2x?x2x2xf32>) -> tensor<*xindex> + func.return %1 : tensor<*xindex> } // ----- // CHECK-LABEL: @batch_norm_inference -func.func @batch_norm_inference(%input: tensor<4x256xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>, %mean: tensor<256xf32>, %variance: tensor<256xf32>) -> (tensor<4x256xindex>) { +func.func @batch_norm_inference(%input: tensor<4x256xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>, %mean: tensor<256xf32>, %variance: tensor<256xf32>) -> (tensor<*xindex>) { %0 = "mhlo.batch_norm_inference" (%input, %scale, %offset, %mean, %variance) {epsilon = 1.001000e-05 : f32, feature_index = 1 : i64} : (tensor<4x256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>) -> tensor<4x256xf32> - // CHECK: (tensor<4x256xf32>) -> tensor<4x256xindex> - %1 = "mhlo_test.get_return_type_components"(%0) : (tensor<4x256xf32>) -> tensor<4x256xindex> - func.return %1 : tensor<4x256xindex> + // CHECK: types0 = tensor<4x256xf32> + %1 = "mhlo_test.get_return_types"(%0) : (tensor<4x256xf32>) -> tensor<*xindex> + func.return %1 : tensor<*xindex> +} + +// ----- + +// CHECK-LABEL: @batch_norm_inference_bounds +func.func @batch_norm_inference_bounds( + %input: tensor<4x?xf32, #mhlo.type_extensions>, %scale: tensor, + %offset: tensor, %mean: tensor, %variance: tensor +) -> (tensor<*xindex>) { + %0 = "mhlo.batch_norm_inference" (%input, %scale, %offset, %mean, %variance) { + epsilon = 1.001000e-05 : f32, feature_index = 1 : i64 + } : (tensor<4x?xf32, #mhlo.type_extensions>, tensor, tensor, tensor, tensor) -> tensor<4x?xf32, #mhlo.type_extensions> + // CHECK: types0 = tensor<4x?xf32, #mhlo.type_extensions> + %1 = "mhlo_test.get_return_types"(%0) : (tensor<4x?xf32, #mhlo.type_extensions>) -> tensor<*xindex> + func.return %1 : tensor<*xindex> +} + +// ----- + +// CHECK-LABEL: func @batch_norm_grad_bounds +func.func @batch_norm_grad_bounds( + %input: tensor<2x?xf32, #mhlo.type_extensions>, + %scale: tensor>, + %mean: tensor>, + %variance: tensor>, + %grad_output: tensor<2x?xf32, #mhlo.type_extensions> +) -> tensor<*xindex> { + %0:3 = "mhlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) { + epsilon = 0.001 : f32, feature_index = 1 : i64 + } : ( + tensor<2x?xf32, 
#mhlo.type_extensions>, + tensor>, + tensor>, + tensor>, + tensor<2x?xf32, #mhlo.type_extensions> + ) -> + ( + tensor<2x?xf32, #mhlo.type_extensions>, + tensor>, + tensor> + ) + // CHECK: types0 = tensor<2x?xf32, #mhlo.type_extensions> + // CHECK-SAME: types1 = tensor> + // CHECK-SAME: types2 = tensor> + %1 = "mhlo_test.get_return_types"(%0#0) : (tensor<2x?xf32, #mhlo.type_extensions>) -> tensor<*xindex> + func.return %1 : tensor<*xindex> +} + +// ----- + +// CHECK-LABEL: func @batch_norm_train_bounds +func.func @batch_norm_train_bounds( + %input: tensor<2x?xf32, #mhlo.type_extensions>, + %scale: tensor>, + %offset: tensor> +) -> tensor<*xindex> { + %0:3 = "mhlo.batch_norm_training" (%input, %scale, %offset) { + epsilon = 0.001 : f32, feature_index = 1 : i64 + } : ( + tensor<2x?xf32, #mhlo.type_extensions>, + tensor>, + tensor> + ) -> + ( + tensor<2x?xf32, #mhlo.type_extensions>, + tensor>, + tensor> + ) + // CHECK: types0 = tensor<2x?xf32, #mhlo.type_extensions> + // CHECK-SAME: types1 = tensor> + // CHECK-SAME: types2 = tensor> + %1 = "mhlo_test.get_return_types"(%0#0) : (tensor<2x?xf32, #mhlo.type_extensions>) -> tensor<*xindex> + func.return %1 : tensor<*xindex> } // ----- @@ -303,7 +585,7 @@ func.func @triangular_solve(%arg0: tensor<10x5x4x4xf32>, %arg1: tensor<10x5x4x4x // ----- // CHECK-LABEL: func @if -func.func @if(%pred : tensor, %branch_operand : tensor<2xf32>, %wrong_type : tensor<2xf32>) { +func.func @if(%pred : tensor, %branch_operand : tensor<2xf32>, %wrong_type : tensor<2xf32>) -> tensor<2xindex> { %0 = "mhlo.if"(%pred) ({ "mhlo.return"(%wrong_type) : (tensor<2xf32>) -> () }, { @@ -311,13 +593,13 @@ func.func @if(%pred : tensor, %branch_operand : tensor<2xf32>, %wrong_type : }) : (tensor) -> tensor<2xf32> // CHECK: (tensor<2xf32>) -> tensor<2xindex> %1 = "mhlo_test.get_return_type_components"(%0) : (tensor<2xf32>) -> tensor<2xindex> - func.return + func.return %1 : tensor<2xindex> } // ----- // CHECK-LABEL: func @case -func.func @case(%index : tensor, %branch_operand : tensor<2xf32>) { +func.func @case(%index : tensor, %branch_operand : tensor<2xf32>) -> tensor<2xindex> { %0 = "mhlo.case"(%index) ({ "mhlo.return"(%branch_operand) : (tensor<2xf32>) -> () }, { @@ -325,13 +607,13 @@ func.func @case(%index : tensor, %branch_operand : tensor<2xf32>) { }) : (tensor) -> tensor<2xf32> // CHECK: (tensor<2xf32>) -> tensor<2xindex> %1 = "mhlo_test.get_return_type_components"(%0) : (tensor<2xf32>) -> tensor<2xindex> - func.return + func.return %1 : tensor<2xindex> } // ----- // CHECK-LABEL: func @sort -func.func @sort(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { +func.func @sort(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) -> (tensor<16x16xindex>, tensor<16x16xindex>) { %0:2 = "mhlo.sort"(%input0, %input1) ({ ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor @@ -341,11 +623,43 @@ func.func @sort(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { %1 = "mhlo_test.get_return_type_components"(%0#0) : (tensor<16x16xf32>) -> tensor<16x16xindex> // CHECK: (tensor<16x16xi32>) -> tensor<16x16xindex> %2 = "mhlo_test.get_return_type_components"(%0#1) : (tensor<16x16xi32>) -> tensor<16x16xindex> + func.return %1, %2 : tensor<16x16xindex>, tensor<16x16xindex> +} + +// ----- + +// CHECK-LABEL: @sort_bounds_and_unknown_rank +func.func @sort_bounds_and_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<5x?x?xi32, #mhlo.type_extensions>) { + %0, %1 = 
"mhlo.sort"(%input0, %input1) ({ + ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): + %pred = "mhlo.compare"(%arg0, %arg1) { + comparison_direction = #mhlo + } : (tensor, tensor) -> tensor + "mhlo.return"(%pred) : (tensor) -> () + }) { dimension = 1 : i64, is_stable = true } : ( + tensor<*xf32>, + tensor<5x?x?xi32, #mhlo.type_extensions> + ) -> (tensor<*xf32>, tensor<*xi32>) + // CHECK: types0 = tensor<*xf32> + // CHECK-SAME: types1 = tensor<5x?x?xi32, #mhlo.type_extensions> + %2 = "mhlo_test.get_return_types"(%0) : (tensor<*xf32>) -> tensor<*xindex> func.return } // ----- +// CHECK-LABEL: func @outfeed +func.func @outfeed(%arg0: tensor<3x3x3xi32>, %arg1: !mhlo.token) -> !mhlo.token { + %0 = "mhlo.outfeed"(%arg0, %arg1) { + outfeed_config = "" + } : (tensor<3x3x3xi32>, !mhlo.token) -> !mhlo.token + %1 = "mhlo_test.get_return_types"(%0) : (!mhlo.token) -> !mhlo.token + // CHECK: %1 = "mhlo_test.return_types"(%0) {types0 = !mhlo.token} : (!mhlo.token) -> !mhlo.token + func.return %1 : !mhlo.token +} + +// ----- + // CHECK-LABEL: func @while func.func @while(%arg0: tensor<4xf32>, %arg1: tensor, %arg2: tensor, %arg3: tensor<4xf32>, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor) -> tensor { %cst = arith.constant dense<-1> : tensor @@ -372,6 +686,148 @@ func.func @while(%arg0: tensor<4xf32>, %arg1: tensor, %arg2: tensor, % // ----- +// CHECK-LABEL: func @get_dimension_size +func.func @get_dimension_size(%arg0: tensor<4x2xf32>) -> tensor { + %0 = "mhlo.get_dimension_size"(%arg0) {dimension = 1 : i64} : (tensor<4x2xf32>) -> tensor + %1 = "mhlo_test.get_return_types"(%0) : (tensor) -> tensor + // CHECK: %1 = "mhlo_test.return_types"(%0) {types0 = tensor} : (tensor) -> tensor + func.return %1 : tensor +} + +// ----- + +// CHECK-LABEL: @dynamic_update_slice +func.func @dynamic_update_slice(%arg0: tensor<4x4xi32>, %arg1: tensor<2x2xi32>, %arg2: tensor, %arg3: tensor) -> tensor<4x4xindex> { + %0 = "mhlo.dynamic_update_slice"(%arg0, %arg1, %arg2, %arg3) : (tensor<4x4xi32>, tensor<2x2xi32>, tensor, tensor) -> tensor<4x4xi32> + %1 = "mhlo_test.get_return_type_components"(%0) : (tensor<4x4xi32>) -> tensor<4x4xindex> + // CHECK: %1 = "mhlo_test.return_type_components"(%0) {dims0 = [4, 4], element_type0 = i32} : (tensor<4x4xi32>) -> tensor<4x4xindex> + func.return %1 : tensor<4x4xindex> +} + +// ----- + +// CHECK-LABEL: @dynamic_update_slice_with_bounds +func.func @dynamic_update_slice_with_bounds(%input: tensor<3x?x?xi64, #mhlo.type_extensions>, %update: tensor<1x4x3xi64>, %start1: tensor, %start2: tensor, %start3 : tensor) -> tensor<*xindex> { + %0 = "mhlo.dynamic_update_slice"(%input, %update, %start1, %start2, %start3) : (tensor<3x?x?xi64, #mhlo.type_extensions>, tensor<1x4x3xi64>, tensor, tensor, tensor) -> tensor<3x?x?xi64> + // CHECK: types0 = tensor<3x?x?xi64, #mhlo.type_extensions> + %1 = "mhlo_test.get_return_types"(%0) : (tensor<3x?x?xi64>) -> tensor<*xindex> + func.return %1 : tensor<*xindex> +} + +// ----- + +// CHECK-LABEL: func @create_token +func.func @create_token() -> !mhlo.token { + %0 = "mhlo.create_token"() : () -> !mhlo.token + %1 = "mhlo_test.get_return_types"(%0) : (!mhlo.token) -> !mhlo.token + // CHECK: %1 = "mhlo_test.return_types"(%0) {types0 = !mhlo.token} : (!mhlo.token) -> !mhlo.token + func.return %1 : !mhlo.token +} + +// ----- + +// CHECK-LABEL: func @after_all_empty_arg +func.func @after_all_empty_arg() -> !mhlo.token { + %0 = "mhlo.after_all"() : () -> !mhlo.token + %1 = "mhlo_test.get_return_types"(%0) : (!mhlo.token) -> 
!mhlo.token + // CHECK: %1 = "mhlo_test.return_types"(%0) {types0 = !mhlo.token} : (!mhlo.token) -> !mhlo.token + func.return %1 : !mhlo.token +} + +// ----- + +// CHECK-LABEL: func @after_all +func.func @after_all(%arg0: !mhlo.token, %arg1: !mhlo.token) -> !mhlo.token { + %0 = "mhlo.after_all"(%arg0, %arg1) : (!mhlo.token, !mhlo.token) -> !mhlo.token + %1 = "mhlo_test.get_return_types"(%0) : (!mhlo.token) -> !mhlo.token + // CHECK: %1 = "mhlo_test.return_types"(%0) {types0 = !mhlo.token} : (!mhlo.token) -> !mhlo.token + func.return %1 : !mhlo.token +} + +// ----- + +// CHECK: func @select_and_scatter +func.func @select_and_scatter( + %arg0: tensor<10x24x24x64xf32>, + %arg1: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xindex> { + %0 = mhlo.constant dense<0.000000e+00> : tensor + + %1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %2 = "mhlo.compare"(%arg3, %arg4) { + compare_type = #mhlo, + comparison_direction = #mhlo + } : (tensor, tensor) -> tensor + "mhlo.return"(%2) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + %2 = mhlo.add %arg3, %arg4 : tensor + "mhlo.return"(%2) : (tensor) -> () + }) { + window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64> + } : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor) -> + tensor<10x24x24x64xf32> + %3 = "mhlo_test.get_return_types"(%1) : (tensor<10x24x24x64xf32>) -> tensor<10x24x24x64xindex> + // CHECK: %2 = "mhlo_test.return_types"(%1) {types0 = tensor<10x24x24x64xf32>} : (tensor<10x24x24x64xf32>) -> tensor<10x24x24x64xindex> + func.return %3 : tensor<10x24x24x64xindex> +} + +// ----- + +// CHECK-LABEL: func @scatter +func.func @scatter(%input_tensor: tensor<200x100x300xf32>, + %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) -> + tensor<200x100x300xindex> { + %0 = "mhlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ + ^bb0(%lhs: tensor, %rhs: tensor): + %add = mhlo.add %lhs, %rhs : tensor + "mhlo.return"(%add) : (tensor) -> () + }) { + scatter_dimension_numbers = #mhlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0, 1], + scatter_dims_to_operand_dims = [0, 1], + index_vector_dim = 1 + >, + indices_are_sorted = true, + unique_indices = true + } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> + tensor<200x100x300xf32> + %1 = "mhlo_test.get_return_types"(%0) : (tensor<200x100x300xf32>) -> tensor<200x100x300xindex> + // CHECK: %1 = "mhlo_test.return_types"(%0) {types0 = tensor<200x100x300xf32>} : (tensor<200x100x300xf32>) -> tensor<200x100x300xindex> + func.return %1 : tensor<200x100x300xindex> +} + +// ----- + +// CHECK-LABEL: func @scatter_bounds +func.func @scatter_bounds(%input_tensor: tensor<200x?x?xf32, #mhlo.type_extensions>, + %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) -> + tensor<*xindex> { + %0 = "mhlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ + ^bb0(%lhs: tensor, %rhs: tensor): + %add = mhlo.add %lhs, %rhs : tensor + "mhlo.return"(%add) : (tensor) -> () + }) { + scatter_dimension_numbers = #mhlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0, 1], + scatter_dims_to_operand_dims = [0, 1], + index_vector_dim = 1 + >, + indices_are_sorted = true, + unique_indices = true + } : (tensor<200x?x?xf32, #mhlo.type_extensions>, tensor<10x2xi32>, tensor<10x300xf32>) -> + tensor<200x?x?xf32> + + %1 = "mhlo_test.get_return_types"(%0) : (tensor<200x?x?xf32>) -> tensor<*xindex> + // CHECK: types0 = 
tensor<200x?x?xf32, #mhlo.type_extensions> + func.return %1 : tensor<*xindex> +} + +// ----- + //===----------------------------------------------------------------------===// // Sparsity //===----------------------------------------------------------------------===// @@ -385,7 +841,7 @@ func.func @tanh_sparsity(%arg0: tensor<10x10xf32, #CSR>) -> tensor<10x10xindex> %0 = "mhlo.tanh"(%arg0) : (tensor<10x10xf32, #CSR>) -> tensor<10x10xf32> %1 = "mhlo_test.get_return_types"(%0) : (tensor<10x10xf32>) -> tensor<10x10xindex> -// CHECK: %1 = "mhlo_test.return_types"(%0) {types0 = tensor<10x10xf32, {{.*}}>} : (tensor<10x10xf32>) -> tensor<10x10xindex> + // CHECK: %1 = "mhlo_test.return_types"(%0) {types0 = tensor<10x10xf32, {{.*}}>} : (tensor<10x10xf32>) -> tensor<10x10xindex> func.return %1 : tensor<10x10xindex> } @@ -400,7 +856,7 @@ func.func @abs_sparsity(%arg0: tensor<10x10xf32, #CSR>) -> tensor<10x10xindex> { %0 = "mhlo.abs"(%arg0) : (tensor<10x10xf32, #CSR>) -> tensor<10x10xf32> %1 = "mhlo_test.get_return_types"(%0) : (tensor<10x10xf32>) -> tensor<10x10xindex> -// CHECK: %1 = "mhlo_test.return_types"(%0) {types0 = tensor<10x10xf32, {{.*}}>} : (tensor<10x10xf32>) -> tensor<10x10xindex> + // CHECK: %1 = "mhlo_test.return_types"(%0) {types0 = tensor<10x10xf32, {{.*}}>} : (tensor<10x10xf32>) -> tensor<10x10xindex> func.return %1 : tensor<10x10xindex> } @@ -415,7 +871,7 @@ func.func @real_sparsity(%arg0: tensor<10x10xcomplex, #CSR>) -> tensor<10x1 %0 = "mhlo.real"(%arg0) : (tensor<10x10xcomplex, #CSR>) -> tensor<10x10xf32> %1 = "mhlo_test.get_return_types"(%0) : (tensor<10x10xf32>) -> tensor<10x10xindex> -// CHECK: %1 = "mhlo_test.return_types"(%0) {types0 = tensor<10x10xf32, {{.*}}>} : (tensor<10x10xf32>) -> tensor<10x10xindex> + // CHECK: %1 = "mhlo_test.return_types"(%0) {types0 = tensor<10x10xf32, {{.*}}>} : (tensor<10x10xf32>) -> tensor<10x10xindex> func.return %1 : tensor<10x10xindex> } @@ -430,7 +886,7 @@ func.func @imag_sparsity(%arg0: tensor<10x10xcomplex, #CSR>) -> tensor<10x1 %0 = "mhlo.imag"(%arg0) : (tensor<10x10xcomplex, #CSR>) -> tensor<10x10xf32> %1 = "mhlo_test.get_return_types"(%0) : (tensor<10x10xf32>) -> tensor<10x10xindex> -// CHECK: %1 = "mhlo_test.return_types"(%0) {types0 = tensor<10x10xf32, {{.*}}>} : (tensor<10x10xf32>) -> tensor<10x10xindex> + // CHECK: %1 = "mhlo_test.return_types"(%0) {types0 = tensor<10x10xf32, {{.*}}>} : (tensor<10x10xf32>) -> tensor<10x10xindex> func.return %1 : tensor<10x10xindex> } @@ -445,26 +901,70 @@ func.func @complex_sparsity(%arg0: tensor<10x10xf32, #CSR>, %arg1: tensor<10x10x %0 = "mhlo.complex"(%arg0, %arg1) : (tensor<10x10xf32, #CSR>, tensor<10x10xf32, #CSR>) -> tensor<10x10xcomplex> %1 = "mhlo_test.get_return_types"(%0) : (tensor<10x10xcomplex>) -> tensor<10x10xindex> -// CHECK: %1 = "mhlo_test.return_types"(%0) {types0 = tensor<10x10xcomplex, {{.*}}>} : (tensor<10x10xcomplex>) -> tensor<10x10xindex> + // CHECK: %1 = "mhlo_test.return_types"(%0) {types0 = tensor<10x10xcomplex, {{.*}}>} : (tensor<10x10xcomplex>) -> tensor<10x10xindex> func.return %1 : tensor<10x10xindex> } // ----- // CHECK-LABEL: func @reduce -func.func @reduce(%arg0: tensor<4x4xf32>, %arg1 : tensor<4xf32>) - -> (tensor<4xindex>) { +func.func @reduce(%arg0: tensor<7x5xf32>, %arg1 : tensor<5xf32>) + -> (tensor<5xindex>) { %0 = "mhlo.reduce"(%arg0, %arg1) ({ - ^bb0(%arg2: tensor<4xf32>, %arg3: tensor<4xf32> ): - %1 = "mhlo.add"(%arg2, %arg3) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - "mhlo.return"(%1) : (tensor<4xf32>) -> () + ^bb0(%arg2: tensor<5xf32>, %arg3: 
tensor<5xf32> ): + %1 = "mhlo.add"(%arg2, %arg3) : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32> + "mhlo.return"(%1) : (tensor<5xf32>) -> () + + }) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor<7x5xf32>, tensor<5xf32>) -> tensor<5xf32> - }) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK: {dims0 = [5], element_type0 = f32} %2 = "mhlo_test.get_return_type_components"(%0) - : (tensor<4xf32>) -> tensor<4xindex> -// CHECK: %1 = "mhlo_test.return_type_components"(%0) {dims0 = [4], element_type0 = f32} : (tensor<4xf32>) -> tensor<4xindex> - func.return %2: tensor<4xindex> + : (tensor<5xf32>) -> tensor<5xindex> + + func.return %2: tensor<5xindex> +} + +// ----- + +// CHECK-LABEL: func @reduce_with_bounds +func.func @reduce_with_bounds(%arg0: tensor>, %arg1 : tensor<5xf32>) + -> (tensor<*xindex>) { + %0 = "mhlo.reduce"(%arg0, %arg1) ({ + + ^bb0(%arg2: tensor<5xf32>, %arg3: tensor<5xf32> ): + %1 = "mhlo.add"(%arg2, %arg3) : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32> + "mhlo.return"(%1) : (tensor<5xf32>) -> () + + }) {dimensions = dense<[0]> : tensor<1xi64>} + : (tensor>, tensor<5xf32>) + -> tensor> + + // CHECK: types0 = tensor> + %2 = "mhlo_test.get_return_types"(%0) + : (tensor>) -> tensor<*xindex> + + func.return %2: tensor<*xindex> +} + +// ----- + +// CHECK-LABEL: func @unranked_reduce +func.func @unranked_reduce(%arg0: tensor<*xf32>, %arg1 : tensor) + -> (tensor<*xindex>) { + %0 = "mhlo.reduce"(%arg0, %arg1) ({ + + ^bb0(%arg2: tensor, %arg3: tensor ): + %1 = "mhlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () + + }) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor<*xf32>, tensor) -> tensor<*xf32> + + // CHECK: {element_type0 = f32} + %2 = "mhlo_test.get_return_type_components"(%0) + : (tensor<*xf32>) -> tensor<*xindex> + + func.return %2: tensor<*xindex> } // ----- @@ -504,7 +1004,7 @@ func.func @reduce_window(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, func.func @tensor_bounds(%arg0: tensor<3x5xf32>, %arg1: tensor) -> tensor<*xindex> { %result = "mhlo.set_dimension_size"(%arg0, %arg1) {dimension = 0 : i64} : (tensor<3x5xf32>, tensor) -> tensor<*xf32> - // CHECK: types0 = tensor> + // CHECK: types0 = tensor> %1 = "mhlo_test.get_return_types"(%result) : (tensor<*xf32>) -> tensor<*xindex> func.return %1 : tensor<*xindex> } @@ -512,9 +1012,9 @@ func.func @tensor_bounds(%arg0: tensor<3x5xf32>, %arg1: tensor) -> tensor<* // ----- // CHECK-LABEL: @static_tensor_bounds -func.func @static_tensor_bounds(%arg0: tensor>) -> tensor<*xindex> { +func.func @static_tensor_bounds(%arg0: tensor>) -> tensor<*xindex> { %bounds = mhlo.constant dense<8> : tensor - %result = "mhlo.set_dimension_size"(%arg0, %bounds) {dimension = 0 : i64} : (tensor>, tensor) -> tensor<*xf32> + %result = "mhlo.set_dimension_size"(%arg0, %bounds) {dimension = 0 : i64} : (tensor>, tensor) -> tensor<*xf32> // CHECK: types0 = tensor<8x5xf32> %1 = "mhlo_test.get_return_types"(%result) : (tensor<*xf32>) -> tensor<*xindex> @@ -524,8 +1024,8 @@ func.func @static_tensor_bounds(%arg0: tensor>, %arg1: tensor) -> tensor<*xindex> { - %result = "mhlo.set_dimension_size"(%arg0, %arg1) {dimension = 1 : i64} : (tensor>, tensor) -> tensor<*xf32> +func.func @edit_tensor_bounds(%arg0: tensor>, %arg1: tensor) -> tensor<*xindex> { + %result = "mhlo.set_dimension_size"(%arg0, %arg1) {dimension = 1 : i64} : (tensor>, tensor) -> tensor<*xf32> // CHECK: types0 = tensor> %1 = "mhlo_test.get_return_types"(%result) : (tensor<*xf32>) -> 
tensor<*xindex> @@ -535,10 +1035,10 @@ func.func @edit_tensor_bounds(%arg0: tensor>, %arg1: tensor) -> tensor<*xindex> { - %result = "mhlo.set_dimension_size"(%arg0, %arg1) {dimension = 0 : i64} : (tensor>, tensor) -> tensor<*xf32> +func.func @retain_tensor_bounds(%arg0: tensor>, %arg1: tensor) -> tensor<*xindex> { + %result = "mhlo.set_dimension_size"(%arg0, %arg1) {dimension = 0 : i64} : (tensor>, tensor) -> tensor<*xf32> - // CHECK: types0 = tensor> + // CHECK: types0 = tensor> %1 = "mhlo_test.get_return_types"(%result) : (tensor<*xf32>) -> tensor<*xindex> func.return %1 : tensor<*xindex> } @@ -546,10 +1046,10 @@ func.func @retain_tensor_bounds(%arg0: tensor>, %arg1: tensor) -> tensor<*xindex> { - %result = "mhlo.set_dimension_size"(%arg0, %arg1) {dimension = 1 : i64} : (tensor>, tensor) -> tensor<*xf32> +func.func @unknown_bounds(%arg0: tensor>, %arg1: tensor) -> tensor<*xindex> { + %result = "mhlo.set_dimension_size"(%arg0, %arg1) {dimension = 1 : i64} : (tensor>, tensor) -> tensor<*xf32> - // CHECK: types0 = tensor> + // CHECK: types0 = tensor> %1 = "mhlo_test.get_return_types"(%result) : (tensor<*xf32>) -> tensor<*xindex> func.return %1 : tensor<*xindex> } @@ -572,21 +1072,21 @@ func.func @unranked_input(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor<*x // See PairwiseSameOperandAndResultType::inferDimWithBound() // CHECK-LABEL: @add_bounds func.func @add_bounds( - %arg0: tensor<3x3x3x?x?x?x?xf32, #mhlo.type_extensions>, - %arg1: tensor<3x?x?x?x?x?x?xf32, #mhlo.type_extensions>) -> tensor<*xindex> { + %arg0: tensor<3x3x3x?x?x?x?xf32, #mhlo.type_extensions>, + %arg1: tensor<3x?x?x?x?x?x?xf32, #mhlo.type_extensions>) -> tensor<*xindex> { %result1 = "mhlo.add"(%arg0, %arg1) : ( - tensor<3x3x3x?x?x?x?xf32, #mhlo.type_extensions>, - tensor<3x?x?x?x?x?x?xf32, #mhlo.type_extensions>) + tensor<3x3x3x?x?x?x?xf32, #mhlo.type_extensions>, + tensor<3x?x?x?x?x?x?xf32, #mhlo.type_extensions>) -> tensor %result2 = "mhlo.add"(%arg1, %arg0) : ( - tensor<3x?x?x?x?x?x?xf32, #mhlo.type_extensions>, - tensor<3x3x3x?x?x?x?xf32, #mhlo.type_extensions>) + tensor<3x?x?x?x?x?x?xf32, #mhlo.type_extensions>, + tensor<3x3x3x?x?x?x?xf32, #mhlo.type_extensions>) -> tensor - // CHECK: types0 = tensor<3x3x3x?x?x?x?xf32, #mhlo.type_extensions> + // CHECK: types0 = tensor<3x3x3x?x?x?x?xf32, #mhlo.type_extensions> %1 = "mhlo_test.get_return_types"(%result1) : (tensor) -> tensor<*xindex> - // CHECK: types0 = tensor<3x3x3x?x?x?x?xf32, #mhlo.type_extensions> + // CHECK: types0 = tensor<3x3x3x?x?x?x?xf32, #mhlo.type_extensions> %2 = "mhlo_test.get_return_types"(%result2) : (tensor) -> tensor<*xindex> func.return %1 : tensor<*xindex> } @@ -596,11 +1096,11 @@ func.func @add_bounds( // This test covers "Error out" case for type inference of binary op with bounds // See PairwiseSameOperandAndResultType::inferDimWithBound() func.func @add_bounds_mismatch( - %arg0: tensor<3xf32, #mhlo.type_extensions>, + %arg0: tensor<3xf32, #mhlo.type_extensions>, %arg1: tensor>) -> tensor<*xindex> { // expected-error@+1 {{requires compatible types for all operands and results}} %result = "mhlo.add"(%arg0, %arg1) : ( - tensor<3xf32, #mhlo.type_extensions>, + tensor<3xf32, #mhlo.type_extensions>, tensor>) -> tensor %1 = "mhlo_test.get_return_types"(%result) : (tensor) -> tensor<*xindex> func.return %1 : tensor<*xindex> @@ -617,3 +1117,285 @@ func.func @add_bounds_unranked( %1 = "mhlo_test.get_return_types"(%result) : (tensor<*xf32>) -> tensor<*xindex> func.return %1 : tensor<*xindex> } + +// ----- + +// CHECK-LABEL: @partition_id +func.func 
@partition_id() -> tensor<*xindex> { + %result = "mhlo.partition_id"() : () -> tensor + // CHECK: types0 = tensor + %1 = "mhlo_test.get_return_types"(%result) : (tensor) -> tensor<*xindex> + func.return %1 : tensor<*xindex> +} + +// ----- + +// CHECK-LABEL: @send +func.func @send(%arg0: !mhlo.token) -> !mhlo.token { + %result = "mhlo.send"(%arg0) { + channel_handle = #mhlo.channel_handle + } : (!mhlo.token) -> !mhlo.token + // CHECK: types0 = !mhlo.token + %1 = "mhlo_test.get_return_types"(%result) : (!mhlo.token) -> !mhlo.token + func.return %1 : !mhlo.token +} + +// ----- + +// CHECK-LABEL: func @gather +// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x4x2xi32>, %[[ARG1:.*]]: tensor +func.func @gather(%operand : tensor<3x4x2xi32>, %start_indices : tensor) -> tensor<4xindex> { + // CHECK: %[[C2:.*]] = arith.constant 2 : index + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[C3:.*]] = arith.constant 3 : index + // CHECK: %[[DIM:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor + // CHECK: %[[RES:.*]] = tensor.from_elements %[[DIM]], %[[C3]], %[[C2]], %[[C2]] : tensor<4xindex> + // CHECK: return %[[RES]] : tensor<4xindex> + %result = "mhlo.gather"(%operand, %start_indices) { + dimension_numbers = #mhlo.gather< + offset_dims = [2, 3], + collapsed_slice_dims = [0], + start_index_map = [1, 0], + index_vector_dim = 2>, + slice_sizes = dense<[1, 2, 2]> : tensor<3xi64>, + indices_are_sorted = false + } : (tensor<3x4x2xi32>, tensor) -> tensor + %1 = "mhlo_test.reify_return_type_shapes"(%result) : (tensor) -> tensor<4xindex> + func.return %1 : tensor<4xindex> +} + +// ----- + +// CHECK-LABEL: func @pad +// CHECK-SAME: (%[[ARG0:.*]]: tensor +func.func @pad(%arg0: tensor) -> tensor<4xindex> { + // CHECK: %[[CST0:.*]] = arith.constant 0 : index + // CHECK: %[[CST1:.*]] = arith.constant 48 : index + // CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[CST0]] : tensor + // CHECK: %[[RES:.*]] = tensor.from_elements %[[DIM]], %[[CST1]], %[[CST1]], %[[CST1]] : tensor<4xindex> + // CHECK: return %[[RES]] : tensor<4xindex> + %0 = "mhlo.constant"() {value = dense<0.000000e+00> : tensor} : () -> tensor + %result = "mhlo.pad"(%arg0, %0) { + edge_padding_high = dense<[0, 0, 0, 16]> : tensor<4xi64>, + edge_padding_low = dense<0> : tensor<4xi64>, + interior_padding = dense<0> : tensor<4xi64> + } : (tensor, tensor) -> tensor + %1 = "mhlo_test.reify_return_type_shapes"(%result) : (tensor) -> tensor<4xindex> + func.return %1 : tensor<4xindex> +} + +// ----- + +// CHECK-LABEL: func @cholesky_bounds +func.func @cholesky_bounds(%input: tensor<2x?x?xf32, #mhlo.type_extensions>) -> tensor<*xindex> { + %0 = "mhlo.cholesky"(%input) { lower = true } : (tensor<2x?x?xf32, #mhlo.type_extensions>) -> tensor<*xf32> + // CHECK: types0 = tensor<2x?x?xf32, #mhlo.type_extensions> + %1 = "mhlo_test.get_return_types"(%0) : (tensor<*xf32>) -> tensor<*xindex> + func.return %1 : tensor<*xindex> +} + +// CHECK-LABEL: func @concatenate +// CHECK-SAME: (%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor +func.func @concatenate(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor<2xindex> { + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[C1:.*]] = arith.constant 1 : index + // CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor + // CHECK: %[[DIM0:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor + // CHECK: %[[DIM1:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor + // CHECK: %[[DIM2:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor + // CHECK: %[[V0:.*]] = arith.addi %[[DIM]], %[[DIM1]] : index + // CHECK: 
%[[V1:.*]] = arith.addi %[[V0]], %[[DIM2]] : index + // CHECK: %[[RES:.*]] = tensor.from_elements %[[V1]], %[[DIM0]] : tensor<2xindex> + // CHECK: return %[[RES]] : tensor<2xindex> + %result = "mhlo.concatenate"(%arg0, %arg1, %arg2) { + dimension = 0 : i64 + } : (tensor, tensor, tensor) -> tensor + %1 = "mhlo_test.reify_return_type_shapes"(%result) : (tensor) -> tensor<2xindex> + func.return %1 : tensor<2xindex> +} + +// ----- + +// CHECK-LABEL: func @reduce +// CHECK-SAME: (%[[ARG0:.*]]: tensor<4x?xf32>, +func.func @reduce(%arg0: tensor<4x?xf32>, %arg1 : tensor<4xf32>)-> (tensor<1xindex>) { + // CHECK: %[[C1:.*]] = arith.constant 1 : index + // CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<4x?xf32> + // CHECK: %[[RES:.*]] = tensor.from_elements %[[DIM]] : tensor<1xindex> + // CHECK: return %[[RES]] : tensor<1xindex> + %result = "mhlo.reduce"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor<4xf32>, %arg3: tensor<4xf32> ): + %1 = "mhlo.add"(%arg2, %arg3) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + "mhlo.return"(%1) : (tensor<4xf32>) -> () + }) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor<4x?xf32>, tensor<4xf32>) -> tensor + %1 = "mhlo_test.reify_return_type_shapes"(%result): (tensor) -> tensor<1xindex> + func.return %1: tensor<1xindex> +} + +// ----- + +// CHECK-LABEL: func @real_dynamic_slice +// CHECK-SAME: (%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor<1xindex>, %[[ARG2:.*]]: tensor<1xindex>, %[[ARG3:.*]]: tensor<1xindex> +func.func @real_dynamic_slice(%arg0: tensor, %arg1: tensor<1xindex>, %arg2: tensor<1xindex>, %arg3: tensor<1xindex>) -> tensor<1xindex> { + // CHECK: %[[C1:.*]] = arith.constant 1 : index + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[EXTD:.*]] = tensor.extract %[[ARG1]][%[[C0]]] : tensor<1xindex> + // CHECK: %[[EXTD0:.*]] = tensor.extract %[[ARG2]][%[[C0]]] : tensor<1xindex> + // CHECK: %[[EXTD1:.*]] = tensor.extract %[[ARG3]][%[[C0]]] : tensor<1xindex> + // CHECK: %[[V0:.*]] = arith.subi %[[EXTD0]], %[[EXTD]] : index + // CHECK: %[[V1:.*]] = arith.addi %[[EXTD1]], %[[V0]] : index + // CHECK: %[[V2:.*]] = arith.subi %[[V1]], %[[C1]] : index + // CHECK: %[[V3:.*]] = arith.divsi %[[V2]], %[[EXTD1]] : index + // CHECK: %[[RES:.*]] = tensor.from_elements %[[V3]] : tensor<1xindex> + // CHECK: return %[[RES]] : tensor<1xindex> + %result = "mhlo.real_dynamic_slice"(%arg0, %arg1, %arg2, %arg3) : (tensor, tensor<1xindex>, tensor<1xindex>, tensor<1xindex>) -> tensor + %1 = "mhlo_test.reify_return_type_shapes"(%result): (tensor) -> tensor<1xindex> + func.return %1: tensor<1xindex> +} + +// ----- + +// CHECK-LABEL: func @dot_general +// CHECK-SAME: (%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor +func.func @dot_general(%arg0: tensor, %arg1: tensor) -> tensor<3xindex> { + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[C2:.*]] = arith.constant 2 : index + // CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor + // CHECK: %[[DIM0:.*]] = tensor.dim %[[ARG0]], %[[C2]] : tensor + // CHECK: %[[DIM1:.*]] = tensor.dim %[[ARG1]], %[[C2]] : tensor + // CHECK: %[[RES:.*]] = tensor.from_elements %[[DIM]], %[[DIM0]], %[[DIM1]] : tensor<3xindex> + // CHECK: return %[[RES]] : tensor<3xindex> + %result = "mhlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #mhlo.dot< + lhs_batching_dimensions = [0], + rhs_batching_dimensions = [0], + lhs_contracting_dimensions = [1], + rhs_contracting_dimensions = [1] + > + } : (tensor, tensor) -> tensor + %1 = "mhlo_test.reify_return_type_shapes"(%result): (tensor) -> tensor<3xindex> + func.return %1: 
tensor<3xindex> +} + +// ----- + +// CHECK-LABEL: func @dynamic_pad +// CHECK-SAME: (%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor<1xindex>, %[[ARG3:.*]]: tensor<1xindex>, %[[ARG4:.*]]: tensor<1xindex> +func.func @dynamic_pad(%arg0: tensor, %arg1: tensor, %arg2: tensor<1xindex>, %arg3: tensor<1xindex>, %arg4: tensor<1xindex>) -> tensor<1xindex> { + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[C1:.*]] = arith.constant 1 : index + // CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor + // CHECK: %[[EXTD:.*]] = tensor.extract %[[ARG2]][%[[C0]]] : tensor<1xindex> + // CHECK: %[[EXTD0:.*]] = tensor.extract %[[ARG3]][%[[C0]]] : tensor<1xindex> + // CHECK: %[[EXTD1:.*]] = tensor.extract %[[ARG4]][%[[C0]]] : tensor<1xindex> + // CHECK: %[[V0:.*]] = arith.cmpi slt, %[[DIM]], %[[C1]] : index + // CHECK: %[[V1:.*]] = arith.subi %[[DIM]], %[[C1]] : index + // CHECK: %[[V2:.*]] = arith.select %[[V0]], %[[C0]], %[[V1]] : index + // CHECK: %[[V3:.*]] = arith.muli %[[EXTD1]], %[[V2]] : index + // CHECK: %[[V4:.*]] = arith.addi %[[V3]], %[[DIM]] : index + // CHECK: %[[V5:.*]] = arith.addi %[[V4]], %[[EXTD]] : index + // CHECK: %[[V6:.*]] = arith.addi %[[V5]], %[[EXTD0]] : index + // CHECK: %[[RES:.*]] = tensor.from_elements %[[V6]] : tensor<1xindex> + // CHECK: return %[[RES]] : tensor<1xindex> + %result = "mhlo.dynamic_pad"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor, tensor, tensor<1xindex>, tensor<1xindex>, tensor<1xindex>) -> tensor + %1 = "mhlo_test.reify_return_type_shapes"(%result): (tensor) -> tensor<1xindex> + func.return %1: tensor<1xindex> +} + +// ----- + +// CHECK-LABEL: func @broadcast +// CHECK-SAME: (%[[ARG0:.*]]: tensor +func.func @broadcast(%arg0: tensor) -> tensor<3xindex> { + // CHECK: %[[C1:.*]] = arith.constant 1 : index + // CHECK: %[[C2:.*]] = arith.constant 2 : index + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor + // CHECK: %[[RES:.*]] = tensor.from_elements %[[C1]], %[[C2]], %[[DIM]] : tensor<3xindex> + // CHECK: return %[[RES]] : tensor<3xindex> + %result = "mhlo.broadcast"(%arg0) {broadcast_sizes = dense<[1, 2]> : tensor<2xi64>} : (tensor) -> tensor<1x2x?xi32> + %1 = "mhlo_test.reify_return_type_shapes"(%result): (tensor<1x2x?xi32>) -> tensor<3xindex> + func.return %1: tensor<3xindex> +} + +// ----- + +// CHECK-LABEL: func @transpose +// CHECK-SAME: (%[[ARG0:.*]]: tensor +func.func @transpose(%arg0: tensor) -> tensor<4xindex> { + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[C1:.*]] = arith.constant 1 : index + // CHECK: %[[C2:.*]] = arith.constant 2 : index + // CHECK: %[[C3:.*]] = arith.constant 3 : index + // CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor + // CHECK: %[[DIM0:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor + // CHECK: %[[DIM1:.*]] = tensor.dim %[[ARG0]], %[[C2]] : tensor + // CHECK: %[[DIM2:.*]] = tensor.dim %[[ARG0]], %[[C3]] : tensor + // CHECK: %[[RES:.*]] = tensor.from_elements %[[DIM0]], %[[DIM]], %[[DIM2]], %[[DIM1]] : tensor<4xindex> + // CHECK: return %[[RES]] : tensor<4xindex> + %result = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor) -> tensor + %1 = "mhlo_test.reify_return_type_shapes"(%result): (tensor) -> tensor<4xindex> + func.return %1: tensor<4xindex> +} + +// ----- + +// CHECK-LABEL: func @dynamic_iota +// CHECK-SAME: (%[[ARG0:.*]]: tensor<1xindex> +func.func @dynamic_iota(%arg0: tensor<1xindex>) -> tensor<1xindex> { + // CHECK: return %[[ARG0]] : tensor<1xindex> 
+ %result = "mhlo.dynamic_iota"(%arg0) { + iota_dimension = 0 : i64 + } : (tensor<1xindex>) -> tensor + %1 = "mhlo_test.reify_return_type_shapes"(%result): (tensor) -> tensor<1xindex> + func.return %1: tensor<1xindex> +} + +// ----- + +// CHECK: func @select_and_scatter_bound +func.func @select_and_scatter_bound( + %arg0: tensor>, + %arg1: tensor>) -> tensor<*xindex> { + %0 = mhlo.constant dense<0.000000e+00> : tensor + %1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %2 = "mhlo.compare"(%arg3, %arg4) { + compare_type = #mhlo, + comparison_direction = #mhlo + } : (tensor, tensor) -> tensor + "mhlo.return"(%2) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + %2 = mhlo.add %arg3, %arg4 : tensor + "mhlo.return"(%2) : (tensor) -> () + }) { + window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64> + } : (tensor>, + tensor>, + tensor) -> tensor<*xf32> + // CHECK: types0 = tensor> + %3 = "mhlo_test.get_return_types"(%1) : (tensor<*xf32>) -> tensor<*xindex> + func.return %3 : tensor<*xindex> +} + +// ----- + +// CHECK-LABEL: func @reduce_window_bound +func.func @reduce_window_bound(%arg0: tensor<4x?x?x?xf32, #mhlo.type_extensions>, + %init0: tensor) -> (tensor<*xindex>) { + %0:1 = "mhlo.reduce_window"(%arg0, %init0) ({ + ^bb0(%a0: tensor, %b0: tensor): + %2 = mhlo.add %a0, %b0 : tensor + "mhlo.return"(%2) : (tensor) -> () + }) { + padding = dense<[[0, 0], [0, 0], [2, 2], [0, 0]]> : tensor<4x2xi64>, + window_dimensions = dense<[1, 1, 5, 1]> : tensor<4xi64>, + window_strides = dense<[1, 1, 3, 1]> : tensor<4xi64> + } : (tensor<4x?x?x?xf32, #mhlo.type_extensions>, + tensor) -> (tensor<*xf32>) + // CHECK: types0 = tensor<4x?x?x?xf32, #mhlo.type_extensions> + %1 = "mhlo_test.get_return_types"(%0#0) : (tensor<*xf32>) -> tensor<*xindex> + func.return %1: tensor<*xindex> +} diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_ops_prettyprint.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_ops_prettyprint.mlir index fde33bc810e..de812256b57 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_ops_prettyprint.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/mhlo_ops_prettyprint.mlir @@ -88,35 +88,35 @@ func.func @unary_ops(%arg0 : tensor<2xi32>, %arg1 : tensor<2xf32>) -> () { // ----- // CHECK-LABEL: func @binary_ops -func.func @binary_ops(%arg0: tensor<2xi1>, %arg1 : tensor<2xf32>) -> tensor<2xi1> { +func.func @binary_ops(%arg0: tensor<2xi1>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xi32>) -> tensor<2xi1> { // CHECK: %0 = mhlo.add %arg0, %arg0 : tensor<2xi1> // CHECK-NEXT: %1 = mhlo.and %arg0, %arg0 : tensor<2xi1> - // CHECK-NEXT: %2 = mhlo.atan2 %arg0, %arg0 : tensor<2xi1> - // CHECK-NEXT: %3 = mhlo.divide %arg0, %arg0 : tensor<2xi1> + // CHECK-NEXT: %2 = mhlo.atan2 %arg1, %arg1 : tensor<2xf32> + // CHECK-NEXT: %3 = mhlo.divide %arg1, %arg1 : tensor<2xf32> // CHECK-NEXT: %4 = mhlo.maximum %arg1, %arg1 : tensor<2xf32> // CHECK-NEXT: %5 = mhlo.minimum %arg1, %arg1 : tensor<2xf32> // CHECK-NEXT: %6 = mhlo.multiply %arg1, %arg1 : tensor<2xf32> // CHECK-NEXT: %7 = mhlo.or %arg0, %arg0 : tensor<2xi1> // CHECK-NEXT: %8 = mhlo.power %arg1, %arg1 : tensor<2xf32> // CHECK-NEXT: %9 = mhlo.remainder %arg1, %arg1 : tensor<2xf32> - // CHECK-NEXT: %10 = mhlo.shift_left %arg1, %arg1 : tensor<2xf32> - // CHECK-NEXT: %11 = mhlo.shift_right_arithmetic %arg1, %arg1 : tensor<2xf32> - // CHECK-NEXT: %12 = mhlo.shift_right_logical %arg1, %arg1 : 
tensor<2xf32> + // CHECK-NEXT: %10 = mhlo.shift_left %arg2, %arg2 : tensor<2xi32> + // CHECK-NEXT: %11 = mhlo.shift_right_arithmetic %arg2, %arg2 : tensor<2xi32> + // CHECK-NEXT: %12 = mhlo.shift_right_logical %arg2, %arg2 : tensor<2xi32> // CHECK-NEXT: %13 = mhlo.subtract %arg1, %arg1 : tensor<2xf32> // CHECK-NEXT: %14 = mhlo.xor %arg0, %arg0 : tensor<2xi1> %0 = "mhlo.add"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> %1 = "mhlo.and"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> - %2 = "mhlo.atan2"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> - %3 = "mhlo.divide"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> + %2 = "mhlo.atan2"(%arg1, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> + %3 = "mhlo.divide"(%arg1, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> %4 = "mhlo.maximum"(%arg1, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> %5 = "mhlo.minimum"(%arg1, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> %6 = "mhlo.multiply"(%arg1, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> %7 = "mhlo.or"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> %8 = "mhlo.power"(%arg1, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> %9 = "mhlo.remainder"(%arg1, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> - %10 = "mhlo.shift_left"(%arg1, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> - %11 = "mhlo.shift_right_arithmetic"(%arg1, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> - %12 = "mhlo.shift_right_logical"(%arg1, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> + %10 = "mhlo.shift_left"(%arg2, %arg2) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %11 = "mhlo.shift_right_arithmetic"(%arg2, %arg2) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + %12 = "mhlo.shift_right_logical"(%arg2, %arg2) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> %13 = "mhlo.subtract"(%arg1, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> %14 = "mhlo.xor"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> func.return %0 : tensor<2xi1> @@ -132,7 +132,7 @@ func.func @type_convert_ops(%arg0 : tensor<2xf32>) -> () { // CHECK-NEXT: %3 = mhlo.bitcast %arg0 : (tensor<2xf32>) -> tensor<2x1xf32> %0 = "mhlo.convert"(%arg0) : (tensor<2xf32>) -> tensor<2xf64> %1 = "mhlo.reshape"(%arg0) : (tensor<2xf32>) -> tensor<1x2xf32> - %2 = "mhlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xi32> + %2 = "mhlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xi32> %3 = "mhlo.bitcast"(%arg0) : (tensor<2xf32>) -> tensor<2x1xf32> "mhlo.return"() : () -> () } @@ -221,10 +221,10 @@ func.func @compare_op(%arg0 : tensor<3xi32>) -> () { // ----- // CHECK-LABEL: func @extensions -func.func @extensions(%arg0 : tensor>, +func.func @extensions(%arg0 : tensor>, %arg1 : tensor) -> () { - // CHECK: %0 = "mhlo.set_dimension_size"(%arg0, %arg1) {dimension = 1 : i64} : (tensor>, tensor) -> tensor<*xf32> - %0 = "mhlo.set_dimension_size"(%arg0, %arg1) {dimension = 1 : i64} : (tensor>, tensor) -> tensor<*xf32> + // CHECK: %0 = "mhlo.set_dimension_size"(%arg0, %arg1) {dimension = 1 : i64} : (tensor>, tensor) -> tensor<*xf32> + %0 = "mhlo.set_dimension_size"(%arg0, %arg1) {dimension = 1 : i64} : (tensor>, tensor) -> tensor<*xf32> "mhlo.return"() : () -> () } diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir index 7bdc51baaa2..8fc79709ef2 100644 --- 
a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir @@ -16,7 +16,6 @@ func.func private @invalid_type() -> !mhlo.foobar // CHECK-LABEL: func @reduce_scatter func.func @reduce_scatter(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { %0 = "mhlo.reduce_scatter"(%data) ({ - // reduction computation ^bb0(%arg2: tensor, %arg3: tensor): %1 = mhlo.add %arg2, %arg3 : tensor "mhlo.return"(%1) : (tensor) -> () @@ -31,7 +30,6 @@ func.func @reduce_scatter(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { func.func @invalid_reduce_scatter(%data: tensor<4x16xf32>) -> tensor<4x5xf32> { // expected-error@+1 {{operand scatter dimension has size 16, expected to be a multiple of result scatter dimension size 5}} %0 = "mhlo.reduce_scatter"(%data) ({ - // reduction computation ^bb0(%arg2: tensor, %arg3: tensor): %1 = mhlo.add %arg2, %arg3 : tensor "mhlo.return"(%1) : (tensor) -> () @@ -62,10 +60,222 @@ func.func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> { // ----- +func.func @all_reduce_invalid_reducer(%operand: tensor<10xf32>) -> tensor<10xf32> { + // expected-error@+1 {{Reduction-region must take 2 parameters, but takes 3 parameter(s)}} + %0 = "mhlo.all_reduce"(%operand) ({ + ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor): + %max = mhlo.maximum %arg0, %arg1 : tensor + "mhlo.return"(%max) : (tensor) -> () + }) + { + replica_groups = dense<[[0, 2, 4, -1], [1, 3, -1, -1]]> : tensor<2x4xi64> + } : (tensor<10xf32>) -> tensor<10xf32> + func.return %0 : tensor<10xf32> +} + +// ----- + +func.func @all_reduce_invalid_reducer(%operand: tensor<10xf32>) -> tensor<10xf32> { + // expected-error@+1 {{The reduction-region expected to return some value(s)}} + %0 = "mhlo.all_reduce"(%operand) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %max = mhlo.maximum %arg0, %arg1 : tensor + "mhlo.return"() : () -> () + }) + { + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> + } : (tensor<10xf32>) -> tensor<10xf32> + func.return %0 : tensor<10xf32> +} + +// ----- + +func.func @all_reduce_invalid_reducer(%operand: tensor<10xf32>) -> tensor<10xf32> { + // expected-error@+1 {{Reduction-region here must produce 1 tensors, but produces 2 instead}} + %0 = "mhlo.all_reduce"(%operand) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %max = mhlo.maximum %arg0, %arg1 : tensor + "mhlo.return"(%max, %max) : (tensor, tensor) -> () + }) + { + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> + } : (tensor<10xf32>) -> tensor<10xf32> + func.return %0 : tensor<10xf32> +} + +// ----- + +func.func @all_reduce_invalid_reducer(%operand: tensor<10xf32>) -> tensor<10xf32> { + // expected-error@+1 {{Reduction-region here must produce tensor-typed result(s), but produces 'tuple, tensor>' instead}} + %0 = "mhlo.all_reduce"(%operand) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %max = mhlo.maximum %arg0, %arg1 : tensor + %tup = "mhlo.tuple"(%max, %max) : (tensor, tensor) -> tuple, tensor> + "mhlo.return"(%tup) : (tuple, tensor>) -> () + }) + { + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> + } : (tensor<10xf32>) -> tensor<10xf32> + func.return %0 : tensor<10xf32> +} + +// ----- + +func.func @all_reduce_invalid_reducer(%operand: tensor<10xf32>) -> tensor<10xf32> { + // expected-error@+1 {{The type of reduction-region's parameter at index 1 is different than the corresponding result type: 'tensor' vs 'tensor'}} + %0 = "mhlo.all_reduce"(%operand) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %max = mhlo.maximum %arg0, %arg0 : tensor 
+ "mhlo.return"(%max) : (tensor) -> () + }) + { + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> + } : (tensor<10xf32>) -> tensor<10xf32> + func.return %0 : tensor<10xf32> +} + +// ----- + +func.func @all_reduce_invalid_reducer(%operand: tensor<10xf32>) -> tensor<10xf32> { + // expected-error@+1 {{The type of reduction-region's parameter at index 0 is different than the corresponding result type: 'tensor' vs 'tensor'}} + %0 = "mhlo.all_reduce"(%operand) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %max = mhlo.maximum %arg0, %arg1 : tensor + %maxint = "mhlo.convert"(%max) : (tensor) -> tensor + "mhlo.return"(%maxint) : (tensor) -> () + }) + { + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> + } : (tensor<10xf32>) -> tensor<10xf32> + func.return %0 : tensor<10xf32> +} + +// ----- + +func.func @all_reduce_invalid_reducer(%operand: tensor<10xf32>) -> tensor<10xf32> { + // expected-error@+1 {{The type of reduction-region's result type at index 0 differs from the op's corresponding init-value type: 'tensor' vs 'tensor'}} + %0 = "mhlo.all_reduce"(%operand) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %max = mhlo.maximum %arg0, %arg1 : tensor + "mhlo.return"(%max) : (tensor) -> () + }) + { + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> + } : (tensor<10xf32>) -> tensor<10xf32> + func.return %0 : tensor<10xf32> +} + +// ----- + +func.func @all_reduce_invalid_reducer(%operand: tensor<10xf32>) -> tensor<10xf32> { + // expected-error@+1 {{The type of reduction-region's result type at index 0 differs from the op's corresponding init-value type: 'tensor<4xf32>' vs 'tensor'}} + %0 = "mhlo.all_reduce"(%operand) ({ + ^bb0(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>): + %max = mhlo.maximum %arg0, %arg1 : tensor<4xf32> + "mhlo.return"(%max) : (tensor<4xf32>) -> () + }) + { + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> + } : (tensor<10xf32>) -> tensor<10xf32> + func.return %0 : tensor<10xf32> +} + +// ----- + +func.func @all_reduce_invalid_return_type(%operand: tensor<10xf32>) -> tensor<10x4xf32> { + // expected-error@+1 {{requires compatible types for all operands and results}} + %0 = "mhlo.all_reduce"(%operand) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %max = mhlo.maximum %arg0, %arg1 : tensor + "mhlo.return"(%max) : (tensor) -> () + }) + { + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> + } : (tensor<10xf32>) -> tensor<10x4xf32> + func.return %0 : tensor<10x4xf32> +} + +// ----- + +func.func @all_reduce_invalid_return_type(%operand: tensor<10xf32>) -> tensor<10xi32> { + // expected-error@+1 {{requires compatible types for all operands and results}} + %0 = "mhlo.all_reduce"(%operand) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %max = mhlo.maximum %arg0, %arg1 : tensor + "mhlo.return"(%max) : (tensor) -> () + }) + { + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> + } : (tensor<10xf32>) -> tensor<10xi32> + func.return %0 : tensor<10xi32> +} + +// ----- + +func.func @all_reduce_invalid_replica_group(%operand: tensor<10xf32>) -> tensor<10xf32> { + // expected-error@+1 {{replica groups should be a rank 2 tensor}} + %0 = "mhlo.all_reduce"(%operand) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %max = mhlo.maximum %arg0, %arg1 : tensor + "mhlo.return"(%max) : (tensor) -> () + }) + { + replica_groups = dense<0> : tensor<1xi64> + } : (tensor<10xf32>) -> tensor<10xf32> + func.return %0 : tensor<10xf32> +} + +// ----- + +func.func @all_reduce_invalid_replica_group(%operand: 
tensor<10xf32>) -> tensor<10xf32> { + // expected-error@+1 {{replica id #1 seen more than once}} + %0 = "mhlo.all_reduce"(%operand) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %max = mhlo.maximum %arg0, %arg1 : tensor + "mhlo.return"(%max) : (tensor) -> () + }) + { + replica_groups = dense<[[0, 1, 1, 3]]> : tensor<1x4xi64> + } : (tensor<10xf32>) -> tensor<10xf32> + func.return %0 : tensor<10xf32> +} + +// ----- + +func.func @all_reduce_invalid_replica_group(%operand: tensor<10xf32>) -> tensor<10xf32> { + // expected-error@+1 {{replica id #2 not seen in replica groups}} + %0 = "mhlo.all_reduce"(%operand) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %max = mhlo.maximum %arg0, %arg1 : tensor + "mhlo.return"(%max) : (tensor) -> () + }) + { + replica_groups = dense<[[0, 1, 3]]> : tensor<1x3xi64> + } : (tensor<10xf32>) -> tensor<10xf32> + func.return %0 : tensor<10xf32> +} + +// ----- + +func.func @all_reduce_invalid_replica_group(%operand: tensor<10xf32>) -> tensor<10xf32> { + // expected-error@+1 {{replica groups cannot be empty}} + %0 = "mhlo.all_reduce"(%operand) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %max = mhlo.maximum %arg0, %arg1 : tensor + "mhlo.return"(%max) : (tensor) -> () + }) + { + replica_groups = dense<0> : tensor<0x2xi64>, + use_global_device_ids + } : (tensor<10xf32>) -> tensor<10xf32> + func.return %0 : tensor<10xf32> +} + +// ----- + func.func @invalid_reduce_scatter(%data: tensor<4x0xf32>) -> tensor<4x4xf32> { // expected-error@+1 {{operand scatter dimension cannot be zero}} %0 = "mhlo.reduce_scatter"(%data) ({ - // reduction computation ^bb0(%arg2: tensor, %arg3: tensor): %1 = mhlo.add %arg2, %arg3 : tensor "mhlo.return"(%1) : (tensor) -> () @@ -79,7 +289,6 @@ func.func @invalid_reduce_scatter(%data: tensor<4x0xf32>) -> tensor<4x4xf32> { func.func @invalid_reduce_scatter(%data: tensor<4x16xf32>) -> tensor<4x0xf32> { // expected-error@+1 {{result scatter dimension cannot be zero}} %0 = "mhlo.reduce_scatter"(%data) ({ - // reduction computation ^bb0(%arg2: tensor, %arg3: tensor): %1 = mhlo.add %arg2, %arg3 : tensor "mhlo.return"(%1) : (tensor) -> () @@ -93,7 +302,6 @@ func.func @invalid_reduce_scatter(%data: tensor<4x16xf32>) -> tensor<4x0xf32> { func.func @invalid_reduce_scatter(%data: tensor<4x16xf32>) -> tensor<4xf32> { // expected-error@+1 {{operand and result should have same rank}} %0 = "mhlo.reduce_scatter"(%data) ({ - // reduction computation ^bb0(%arg2: tensor, %arg3: tensor): %1 = mhlo.add %arg2, %arg3 : tensor "mhlo.return"(%1) : (tensor) -> () @@ -107,7 +315,6 @@ func.func @invalid_reduce_scatter(%data: tensor<4x16xf32>) -> tensor<4xf32> { func.func @invalid_reduce_scatter(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { // expected-error@+1 {{scatter dim should be less than operand/result rank}} %0 = "mhlo.reduce_scatter"(%data) ({ - // reduction computation ^bb0(%arg2: tensor, %arg3: tensor): %1 = mhlo.add %arg2, %arg3 : tensor "mhlo.return"(%1) : (tensor) -> () @@ -121,7 +328,6 @@ func.func @invalid_reduce_scatter(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { func.func @invalid_reduce_scatter(%data: tensor<4x16xf32>) -> tensor<3x4xf32> { // expected-error@+1 {{non scatter dimensions should be same for operand (4) and result (3)}} %0 = "mhlo.reduce_scatter"(%data) ({ - // reduction computation ^bb0(%arg2: tensor, %arg3: tensor): %1 = mhlo.add %arg2, %arg3 : tensor "mhlo.return"(%1) : (tensor) -> () @@ -133,7 +339,7 @@ func.func @invalid_reduce_scatter(%data: tensor<4x16xf32>) -> tensor<3x4xf32> { // ----- func.func @reduce_scatter(%data: tensor<4x16xf32>) -> 
tensor<4x4xf32> { - // expected-error@+1 {{replica groups should be a rank 2 tensor of 64 bit integers}} + // expected-error@+1 {{replica groups should be a rank 2 tensor}} %0 = "mhlo.reduce_scatter"(%data) ({ ^bb0(%arg2: tensor, %arg3: tensor): %1 = mhlo.add %arg2, %arg3 : tensor @@ -210,6 +416,20 @@ func.func @reduce_scatter(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { // ----- +// CHECK-LABEL: func @reduce_scatter_dynamic +func.func @reduce_scatter_dynamic(%data: tensor) -> tensor { + %0 = "mhlo.reduce_scatter"(%data) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = mhlo.add %arg2, %arg3 : tensor + "mhlo.return"(%1) : (tensor) -> () + }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + scatter_dimension = 1 : i64, + use_global_device_ids} : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + // CHECK-LABEL: func @alltoall func.func @alltoall(%data: tensor<4x16xf32>) -> tensor<16x4xf32> { %0 = "mhlo.all_to_all"(%data) { @@ -236,8 +456,60 @@ func.func @alltoall_unranked_input(%data: tensor<*xf32>) -> tensor<*xf32> { // ----- +// CHECK-LABEL: func @alltoall_dynamic_split_dim +func.func @alltoall_dynamic_split_dim(%data: tensor<4x?xf32>) -> tensor<20x?xf32> { + %0 = "mhlo.all_to_all"(%data) { + split_dimension = 1 : i64, + concat_dimension = 0 : i64, + split_count = 5 : i64, + replica_groups = dense<[[0, 1, 2, 3, 4]]> : tensor<1x5xi64> + } : (tensor<4x?xf32>) -> tensor<20x?xf32> + func.return %0 : tensor<20x?xf32> +} + +// ----- + +// CHECK-LABEL: func @alltoall_dynamic_concat_dim +func.func @alltoall_dynamic_concat_dim(%data: tensor) -> tensor { + %0 = "mhlo.all_to_all"(%data) { + split_dimension = 1 : i64, + concat_dimension = 0 : i64, + split_count = 4 : i64, + replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64> + } : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @alltoall_dynamic_split_dim +func.func @alltoall_dynamic_split_dim(%data: tensor<4x?xf32>) -> tensor<20x?xf32> { + %0 = "mhlo.all_to_all"(%data) { + split_dimension = 1 : i64, + concat_dimension = 0 : i64, + split_count = 5 : i64, + replica_groups = dense<[[0, 1, 2, 3, 4]]> : tensor<1x5xi64> + } : (tensor<4x?xf32>) -> tensor<20x?xf32> + func.return %0 : tensor<20x?xf32> +} + +// ----- + +// CHECK-LABEL: func @alltoall_dynamic_concat_dim +func.func @alltoall_dynamic_concat_dim(%data: tensor) -> tensor { + %0 = "mhlo.all_to_all"(%data) { + split_dimension = 1 : i64, + concat_dimension = 0 : i64, + split_count = 4 : i64, + replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64> + } : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + func.func @alltoall_negative_split_dimension(%data: tensor<4x16xf32>) -> tensor<16x4xf32> { - // expected-error@+1 {{AllToAll split_dimension -1 is out-of-bounds for input rank 2}} + // expected-error@+1 {{AllToAll split_dimension cannot be negative}} %0 = "mhlo.all_to_all"(%data) { split_dimension = -1 : i64, concat_dimension = 0 : i64, @@ -263,7 +535,7 @@ func.func @alltoall_out_bound_split_dimension(%data: tensor<4x16xf32>) -> tensor // ----- func.func @alltoall_negative_concat_dimension(%data: tensor<4x16xf32>) -> tensor<16x4xf32> { - // expected-error@+1 {{AllToAll concat_dimension -1 is out-of-bounds for input rank 2}} + // expected-error@+1 {{AllToAll concat_dimension cannot be negative}} %0 = "mhlo.all_to_all"(%data) { split_dimension = 1 : i64, concat_dimension = -1 : i64, @@ -288,6 +560,19 @@ func.func @alltoall_out_bound_concat_dimension(%data: tensor<4x16xf32>) -> tenso // ----- +func.func 
@alltoall_invalid_split_count(%data: tensor<4x16xf32>) -> tensor<16x4xf32> { + // expected-error@+1 {{AllToAll split_count must be > 0}} + %0 = "mhlo.all_to_all"(%data) { + split_dimension = 1 : i64, + concat_dimension = 0 : i64, + split_count = 0 : i64, + replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64> + } : (tensor<4x16xf32>) -> tensor<16x4xf32> + func.return %0 : tensor<16x4xf32> +} + +// ----- + func.func @alltoall_invalid_split_dim_size(%data: tensor<4x16xf32>) -> tensor<16x4xf32> { // expected-error@+1 {{split dimension has size 16, expected to be a multiple of split_count 5}} %0 = "mhlo.all_to_all"(%data) { @@ -301,39 +586,226 @@ func.func @alltoall_invalid_split_dim_size(%data: tensor<4x16xf32>) -> tensor<16 // ----- -func.func @allgather_incompatible_types(%arg0: tensor<128x32xf32>) -> tensor<128x100xf32> { - // expected-error@+1 {{result gather dimension has size 100, expected to be a multiple of operand gather dimension size 32}} +func.func @alltoall_invalid_replica_group(%data: tensor<4x16xf32>) -> tensor<16x4xf32> { + // expected-error@+1 {{replica groups should be a rank 2 tensor}} + %0 = "mhlo.all_to_all"(%data) { + split_dimension = 1 : i64, + concat_dimension = 0 : i64, + split_count = 4 : i64, + replica_groups = dense<[[[0], [1], [2], [3]]]> : tensor<1x4x1xi64> + } : (tensor<4x16xf32>) -> tensor<16x4xf32> + func.return %0 : tensor<16x4xf32> +} + +// ----- + +func.func @alltoall_invalid_replica_group(%data: tensor<4x16xf32>) -> tensor<16x4xf32> { + // expected-error@+1 {{replica id #1 not seen in replica groups}} + %0 = "mhlo.all_to_all"(%data) { + split_dimension = 1 : i64, + concat_dimension = 0 : i64, + split_count = 4 : i64, + replica_groups = dense<[[-5, -4, -3, 0]]> : tensor<1x4xi64> + } : (tensor<4x16xf32>) -> tensor<16x4xf32> + func.return %0 : tensor<16x4xf32> +} + +// ----- + +func.func @alltoall_invalid_replica_group(%data: tensor<4x16xf32>) -> tensor<16x4xf32> { + // expected-error@+1 {{replica id #2 seen more than once}} + %0 = "mhlo.all_to_all"(%data) { + split_dimension = 1 : i64, + concat_dimension = 0 : i64, + split_count = 4 : i64, + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 2]]> : tensor<2x4xi64> + } : (tensor<4x16xf32>) -> tensor<16x4xf32> + func.return %0 : tensor<16x4xf32> +} + +// ----- + +func.func @alltoall_invalid_replica_group(%data: tensor<4x16xf32>) -> tensor<16x4xf32> { + // expected-error@+1 {{replica id #4 not seen in replica groups}} + %0 = "mhlo.all_to_all"(%data) { + split_dimension = 1 : i64, + concat_dimension = 0 : i64, + split_count = 4 : i64, + replica_groups = dense<[[0, 2, 6, 8], [1, 3, 5, 7]]> : tensor<2x4xi64> + } : (tensor<4x16xf32>) -> tensor<16x4xf32> + func.return %0 : tensor<16x4xf32> +} + +// ----- + +func.func @alltoall_invalid_replica_group(%data: tensor<4x16xf32>) -> tensor<16x4xf32> { + // expected-error@+1 {{group size of replica_groups must be 4}} + %0 = "mhlo.all_to_all"(%data) { + split_dimension = 1 : i64, + concat_dimension = 0 : i64, + split_count = 4 : i64, + replica_groups = dense<[[0, 2, 4], [1, 3, 5]]> : tensor<2x3xi64> + } : (tensor<4x16xf32>) -> tensor<16x4xf32> + func.return %0 : tensor<16x4xf32> +} + +// ----- + +func.func @allgather_incompatible_types(%arg0: tensor<128x32xf32>) -> tensor<128x100xf32> { + // expected-error@+1 {{result gather dimension has size 100, expected to be a multiple of operand gather dimension size 32}} + %0 = "mhlo.all_gather"(%arg0) { + all_gather_dim = 1 : i64, + channel_handle = #mhlo.channel_handle, + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : 
tensor<2x4xi64> + } : (tensor<128x32xf32>) -> tensor<128x100xf32> + func.return %0 : tensor<128x100xf32> +} + +// ----- + +func.func @allgather_gather_along_zero_dimension(%arg0: tensor<128x0x32xf32>) -> tensor<128x100xf32> { + // expected-error@+1 {{dimension size of operand at 'all_gather_dim' cannot be zero}} + %0 = "mhlo.all_gather"(%arg0) { + all_gather_dim = 1 : i64, + channel_handle = #mhlo.channel_handle, + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> + } : (tensor<128x0x32xf32>) -> tensor<128x100xf32> + func.return %0 : tensor<128x100xf32> +} + +// ----- + +// CHECK-LABEL: func @allgather_dynamic_gather_dim +func.func @allgather_dynamic_gather_dim(%arg0: tensor<128x32xf32>) -> tensor<128x?xf32> { + %0 = "mhlo.all_gather"(%arg0) { + all_gather_dim = 1 : i64, + channel_handle = #mhlo.channel_handle, + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, + use_global_device_ids + } : (tensor<128x32xf32>) -> tensor<128x?xf32> + func.return %0 : tensor<128x?xf32> +} + +// ----- + +// CHECK-LABEL: func @allgather_dynamic_non_gather_dim +func.func @allgather_dynamic_non_gather_dim(%arg0: tensor<128x32xf32>) -> tensor { + %0 = "mhlo.all_gather"(%arg0) { + all_gather_dim = 1 : i64, + channel_handle = #mhlo.channel_handle, + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, + use_global_device_ids + } : (tensor<128x32xf32>) -> tensor + func.return %0 : tensor +} + +// ----- + +func.func @all_gather_invalid_dim(%arg0: tensor<8x2xf32>) -> tensor<8x8xf32> { + // expected-error@+1 {{all_gather_dim must be a valid index of operand}} + %0 = "mhlo.all_gather"(%arg0) { + all_gather_dim = 2 : i64, + channel_handle = #mhlo.channel_handle, + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> + } : (tensor<8x2xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} + +// ----- + +func.func @all_gather_invalid_dim(%arg0: tensor<8x2xf32>) -> tensor<8x8xf32> { + // expected-error@+1 {{all_gather_dim cannot be negative}} + %0 = "mhlo.all_gather"(%arg0) { + all_gather_dim = -1 : i64, + channel_handle = #mhlo.channel_handle, + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> + } : (tensor<8x2xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} + +// ----- + +func.func @all_gather_invalid_result_shape(%arg0: tensor<8x2x32xf32>) -> tensor<8x8xf32> { + // expected-error@+1 {{operand and return must have the same rank}} + %0 = "mhlo.all_gather"(%arg0) { + all_gather_dim = 1 : i64, + channel_handle = #mhlo.channel_handle, + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> + } : (tensor<8x2x32xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} + +// ----- + +func.func @all_gather_invalid_result_shape(%arg0: tensor<8x2xf32>) -> tensor<4x8xf32> { + // expected-error@+1 {{operand and result should have the same shape except for the dimension size at 'all_gather_dim'}} + %0 = "mhlo.all_gather"(%arg0) { + all_gather_dim = 1 : i64, + channel_handle = #mhlo.channel_handle, + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> + } : (tensor<8x2xf32>) -> tensor<4x8xf32> + func.return %0 : tensor<4x8xf32> +} + +// ----- + +func.func @all_gather_invalid_replica_group_shape(%arg0: tensor<8x2xf32>) -> tensor<8x8xf32> { + // expected-error@+1 {{replica groups should be a rank 2 tensor}} + %0 = "mhlo.all_gather"(%arg0) { + all_gather_dim = 1 : i64, + channel_handle = #mhlo.channel_handle, + replica_groups = dense<[[[0], [1], [2], [3]]]> : 
tensor<1x4x1xi64> + } : (tensor<8x2xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} + +// ----- + +func.func @all_gather_invalid_replica_group_shape(%arg0: tensor<8x2xf32>) -> tensor<8x8xf32> { + // expected-error@+1 {{replica groups cannot be empty}} + %0 = "mhlo.all_gather"(%arg0) { + all_gather_dim = 1 : i64, + channel_handle = #mhlo.channel_handle, + replica_groups = dense<0> : tensor<0x2xi64>, + use_global_device_ids + } : (tensor<8x2xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} + +// ----- + +func.func @all_gather_invalid_replica_group(%arg0: tensor<8x2xf32>) -> tensor<8x8xf32> { + // expected-error@+1 {{replica id #1 not seen in replica groups}} %0 = "mhlo.all_gather"(%arg0) { all_gather_dim = 1 : i64, channel_handle = #mhlo.channel_handle, - replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> - } : (tensor<128x32xf32>) -> tensor<128x100xf32> - func.return %0 : tensor<128x100xf32> + replica_groups = dense<[[-5, -4, -3, 0]]> : tensor<1x4xi64> + } : (tensor<8x2xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> } // ----- -func.func @allgather_gather_along_zero_dimension(%arg0: tensor<128x0x32xf32>) -> tensor<128x100xf32> { - // expected-error@+1 {{operand gather dimension cannot be zero}} +func.func @all_gather_invalid_replica_group(%arg0: tensor<8x2xf32>) -> tensor<8x8xf32> { + // expected-error@+1 {{replica id #2 seen more than once}} %0 = "mhlo.all_gather"(%arg0) { all_gather_dim = 1 : i64, channel_handle = #mhlo.channel_handle, - replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> - } : (tensor<128x0x32xf32>) -> tensor<128x100xf32> - func.return %0 : tensor<128x100xf32> + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 2]]> : tensor<2x4xi64> + } : (tensor<8x2xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> } // ----- -// CHECK-LABEL: func @allgather_dynamic_gather_dim -func.func @allgather_dynamic_gather_dim(%arg0: tensor<128x32xf32>) -> tensor<128x?xf32> { +func.func @all_gather_invalid_replica_group(%arg0: tensor<8x2xf32>) -> tensor<8x8xf32> { + // expected-error@+1 {{replica id #4 not seen in replica groups}} %0 = "mhlo.all_gather"(%arg0) { all_gather_dim = 1 : i64, channel_handle = #mhlo.channel_handle, - replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, - use_global_device_ids - } : (tensor<128x32xf32>) -> tensor<128x?xf32> - func.return %0 : tensor<128x?xf32> + replica_groups = dense<[[0, 2, 6, 8], [1, 3, 5, 7]]> : tensor<2x4xi64> + } : (tensor<8x2xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> } // ----- @@ -452,7 +924,7 @@ func.func @broadcast_in_dim_bad_rank_decrease(%arg0: tensor<1x2x3xi32>) -> tenso func.func @broadcast_in_dim_duplicate_bcast_dimensions(%arg0: tensor<1x1x3xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{broadcast_dimensions should not have duplicates}} - %0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0,0,2]> : tensor<3xi64>} : (tensor<1x1x3xi32>) -> tensor<1x2x3xi32> + %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0,0,2]> : tensor<3xi64>} : (tensor<1x1x3xi32>) -> tensor<1x2x3xi32> func.return %0 : tensor<1x2x3xi32> } @@ -786,7 +1258,7 @@ func.func @comp_mismatch_return_shape(%arg0: tensor<3xi32>, %arg1: tensor<3xi32> // ----- -func.func @collective_permute_duplicate_sources(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { +func.func @collective_permute_invalid_sources(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { // expected-error@+1 {{duplicate sources not allowed}} %0 = 
"mhlo.collective_permute"(%arg0) { source_target_pairs = dense<[[0, 1], [0, 2], [2, 3]]> : tensor<3x2xi64> @@ -796,7 +1268,7 @@ func.func @collective_permute_duplicate_sources(%arg0: tensor<128x32xf32>) -> te // ----- -func.func @collective_permute_duplicate_targets(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { +func.func @collective_permute_invalid_destinations(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { // expected-error@+1 {{duplicate targets not allowed}} %0 = "mhlo.collective_permute"(%arg0) { source_target_pairs = dense<[[0, 1], [1, 2], [2, 1]]> : tensor<3x2xi64> @@ -806,7 +1278,7 @@ func.func @collective_permute_duplicate_targets(%arg0: tensor<128x32xf32>) -> te // ----- -func.func @collective_permute_duplicate_sources(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { +func.func @collective_permute_invalid_source_target_pairs(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { // expected-error@+1 {{expect source_target_pairs attribute to be of rank 2, but got rank 1}} %0 = "mhlo.collective_permute"(%arg0) { source_target_pairs = dense<[0, 1]> : tensor<2xi64> @@ -816,7 +1288,7 @@ func.func @collective_permute_duplicate_sources(%arg0: tensor<128x32xf32>) -> te // ----- -func.func @collective_permute_duplicate_sources(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { +func.func @collective_permute_invalid_source_target_pairs(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { // expected-error@+1 {{expect source_target_pairs attribute of shape (N, 2), but got (2, 3)}} %0 = "mhlo.collective_permute"(%arg0) { source_target_pairs = dense<[[0, 1, 2], [3, 4, 5]]> : tensor<2x3xi64> @@ -826,6 +1298,16 @@ func.func @collective_permute_duplicate_sources(%arg0: tensor<128x32xf32>) -> te // ----- +func.func @collective_permute_invalid_source_target_pairs(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { + // expected-error@+1 {{replica ids in source_target_pairs must be >= 0}} + %0 = "mhlo.collective_permute"(%arg0) { + source_target_pairs = dense<[[0, 1], [-1, 0]]> : tensor<2x2xi64> + } : (tensor<128x32xf32>) -> tensor<128x32xf32> + func.return %0 : tensor<128x32xf32> +} + +// ----- + func.func @concat_0D(%arg0: tensor, %arg1: tensor) -> tensor<2xi32> { // expected-error@+1 {{rank-0 values cannot be concatenated}} %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor, tensor) -> tensor<2xi32> @@ -909,7 +1391,7 @@ func.func @concat_outofbounds_dim(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) - // ----- func.func @concat_mismatch_rank(%arg0: tensor<1xi32>, %arg1: tensor<2x2xi32>) -> tensor<3xi32> { - // expected-error@+1 {{op operands (0) and (1) do not match rank}} + // expected-error@+1 {{operands (0) and (1) do not match rank}} %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2x2xi32>) -> tensor<3xi32> func.return %0 : tensor<3xi32> } @@ -1009,7 +1491,7 @@ func.func @cholesky_invalid_rank(%arg0: tensor<1xf32>) -> tensor<1xf32> { // ----- func.func @cholesky_invalid_elt(%arg0: tensor<1x2x2xi32>) -> tensor<1x2x2xi32> { - // expected-error@+1 {{operand #0 must be tensor of 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements values, but got 'tensor<1x2x2xi32>}} + // expected-error@+1 {{operand #0 must be tensor of f8E4M3FN type or f8E5M2 type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements values, but got 'tensor<1x2x2xi32>}} %0 = "mhlo.cholesky"(%arg0) { lower = true } : 
(tensor<1x2x2xi32>) -> tensor<1x2x2xi32> func.return %0: tensor<1x2x2xi32> } @@ -1024,6 +1506,13 @@ func.func @cholesky_wrong_infer_shape(%arg0: tensor<1x2x2xf32>) -> tensor<1x2x2x // ----- +func.func @create_token() -> !mhlo.token { + %0 = "mhlo.create_token"() : () -> !mhlo.token + func.return %0: !mhlo.token +} + +// ----- + // CHECK-LABEL: func @dot_vector func.func @dot_vector(%arg0: tensor<1x2xi32>, %arg1: tensor<2x1xi32>) -> tensor<1x1xi32> { %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor<1x1xi32> @@ -1048,6 +1537,14 @@ func.func @dot_precision_config(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) // ----- +func.func @dot_precision_invalid_precision_config(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) -> tensor<2x2xi32> { + // expected-error@+1 {{expects precision config to be empty or have <= 2 elements}} + %0 = "mhlo.dot"(%arg0, %arg1) {precision_config = [#mhlo, #mhlo, #mhlo]} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> + func.return %0: tensor<2x2xi32> +} + +// ----- + func.func @dot_bad_precision_config(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) -> tensor<2x2xi32> { // expected-error@+1 {{'precision_config' failed to satisfy constraint}} %0 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["FOO", #mhlo]} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> @@ -1056,22 +1553,36 @@ func.func @dot_bad_precision_config(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi3 // ----- -func.func @dot_illegal_input_type(%arg0: tensor<3xf32>, %arg1: tensor) -> tensor { - // expected-error@+1 {{Unexpected result type: has 'tensor' but inferred 'tensor<3xf32>' from operands 'tensor<3xf32>' and 'tensor}} +func.func @dot_more_dynamic_output_type(%arg0: tensor<3xf32>, %arg1: tensor) -> tensor { %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<3xf32>, tensor) -> tensor func.return %0 : tensor } // ----- -func.func @dot_illegal_result_type(%arg0: tensor, %arg1: tensor<3xf32>) -> tensor<3x?xf32> { - // expected-error@+1 {{Unexpected result type: has 'tensor<3x?xf32>' but inferred 'tensor' from operands 'tensor' and 'tensor<3xf32>'}} +func.func @dot_cannot_infer_type(%arg0: tensor, %arg1: tensor) -> tensor<*xf32> { + // expected-error@+1 {{expected both lhs/rhs ranks to be either 1 or 2}} + %0 = "mhlo.dot"(%arg0, %arg1) : (tensor, tensor) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +// ----- + +func.func @dot_result_type_mismatch_with_inferred_type(%arg0: tensor, %arg1: tensor<3xf32>) -> tensor<3x?xf32> { + // expected-error@+1 {{inferred shape '[?]' is incompatible with return type of operation 'tensor<3x?xf32>'}} %0 = "mhlo.dot"(%arg0, %arg1) : (tensor, tensor<3xf32>) -> tensor<3x?xf32> func.return %0 : tensor<3x?xf32> } // ----- +func.func @dot_result_type_match_with_inferred_type(%arg0: tensor, %arg1: tensor<3xf32>) -> tensor<*xf32> { + %0 = "mhlo.dot"(%arg0, %arg1) : (tensor, tensor<3xf32>) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +// ----- + // CHECK-LABEL: func @dot_legal_unranked_rank_type func.func @dot_legal_unranked_rank_type(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<2x2xf32> { // unrank legal test @@ -1103,7 +1614,7 @@ func.func @imag_fp_input(%arg0: tensor<*xf32>) -> tensor<*xf32> { // ----- func.func @imag_int_input(%arg0: tensor<*xi32>) -> tensor<*xi32> { - // expected-error@+1 {{operand #0 must be tensor of 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements values, but got 'tensor<*xi32>'}} + // expected-error@+1 {{operand #0 must be tensor of 
f8E4M3FN type or f8E5M2 type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements values, but got 'tensor<*xi32>'}} %0 = "mhlo.imag"(%arg0) : (tensor<*xi32>) -> tensor<*xi32> func.return %0 : tensor<*xi32> } @@ -1334,6 +1845,16 @@ func.func @map_mismatch_arguments_and_dimensions(%arg0: tensor<4x5xf32>, %arg1: // ----- +// CHECK-LABEL: func @outfeed +func.func @outfeed(%arg0: tensor<3x3x3xi32>, %arg1: !mhlo.token) -> !mhlo.token { + %0 = "mhlo.outfeed"(%arg0, %arg1) { + outfeed_config = "" + } : (tensor<3x3x3xi32>, !mhlo.token) -> !mhlo.token + func.return %0 : !mhlo.token +} + +// ----- + // CHECK-LABEL: func @real_fp_input func.func @real_fp_input(%arg0: tensor<*xf32>) -> tensor<*xf32> { %0 = "mhlo.real"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> @@ -1343,7 +1864,7 @@ func.func @real_fp_input(%arg0: tensor<*xf32>) -> tensor<*xf32> { // ----- func.func @real_int_input(%arg0: tensor<*xi32>) -> tensor<*xi32> { - // expected-error@+1 {{operand #0 must be tensor of 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements values, but got 'tensor<*xi32>'}} + // expected-error@+1 {{operand #0 must be tensor of f8E4M3FN type or f8E5M2 type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or complex type with 32-bit float or 64-bit float elements values, but got 'tensor<*xi32>'}} %0 = "mhlo.real"(%arg0) : (tensor<*xi32>) -> tensor<*xi32> func.return %0 : tensor<*xi32> } @@ -1399,22 +1920,24 @@ func.func @replica_id() -> tensor { // CHECK-LABEL: func @rng_bit_generator func.func @rng_bit_generator(%arg0: tensor<2xui64>) -> (tensor<2xui64>, tensor<10x12xui32>) { - %4 = mhlo.constant dense<[10, 12]> : tensor<2xui64> - %0 = mhlo.constant dense<[10, 12]> : tensor<2xi32> - %1 = mhlo.constant dense<3> : tensor - %2, %3 = "mhlo.rng_bit_generator"(%4) {rng_algorithm = #mhlo.rng_algorithm} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<10x12xui32>) - func.return %2, %3 : tensor<2xui64>, tensor<10x12xui32> + %0, %1 = "mhlo.rng_bit_generator"(%arg0) {rng_algorithm = #mhlo.rng_algorithm} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<10x12xui32>) + func.return %0, %1 : tensor<2xui64>, tensor<10x12xui32> } // ----- func.func @rng_bit_generator(%arg0: tensor<2xui64>) -> (tensor<2xui64>, tensor<10x12xui32>) { - %4 = mhlo.constant dense<[10, 12]> : tensor<2xui64> - %0 = mhlo.constant dense<[10, 12]> : tensor<2xi32> - %1 = mhlo.constant dense<3> : tensor - // expected-error@+1 {{output state shape must match initial state shape. Got: 'tensor<2xui64>' and 'tensor<3xui64>'}} - %2, %3 = "mhlo.rng_bit_generator"(%4) {rng_algorithm = #mhlo.rng_algorithm} : (tensor<2xui64>) -> (tensor<3xui64>, tensor<10x12xui32>) - func.return %2, %3 : tensor<3xui64>, tensor<10x12xui32> + // expected-error@+1 {{output state shape must be compatible with initial state shape. 
Got: 'tensor<2xui64>' and 'tensor<3xui64>'}} + %0, %1 = "mhlo.rng_bit_generator"(%arg0) {rng_algorithm = #mhlo.rng_algorithm} : (tensor<2xui64>) -> (tensor<3xui64>, tensor<10x12xui32>) + func.return %0, %1 : tensor<3xui64>, tensor<10x12xui32> +} + +// ----- + +// CHECK-LABEL: func @rng_bit_generator_dynamic +func.func @rng_bit_generator_dynamic(%arg0: tensor) -> (tensor, tensor<10x12xui32>) { + %0, %1 = "mhlo.rng_bit_generator"(%arg0) {rng_algorithm = #mhlo.rng_algorithm} : (tensor) -> (tensor, tensor<10x12xui32>) + func.return %0, %1 : tensor, tensor<10x12xui32> } // ----- @@ -1455,7 +1978,7 @@ func.func @rng_normal_invalid_shape(%arg0: tensor, %arg1: tensor) { func.func @rng_normal_invalid_mu_rank(%mu: tensor<1xf32>, %sigma: tensor) -> tensor<2x3x5xf32> { %shape = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64> - // expected-error@+1 {{#0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}} + // expected-error@+1 {{#0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3FN type or f8E5M2 type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}} %0 = "mhlo.rng"(%mu, %sigma, %shape) {rng_distribution = #mhlo.rng_distribution}: (tensor<1xf32>, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> func.return %0 : tensor<2x3x5xf32> } @@ -1464,7 +1987,7 @@ func.func @rng_normal_invalid_mu_rank(%mu: tensor<1xf32>, %sigma: tensor) - func.func @rng_normal_invalid_sigma_rank(%mu: tensor, %sigma: tensor<1xf32>) -> tensor<2x3x5xf32> { %shape = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64> - // expected-error@+1 {{#1 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}} + // expected-error@+1 {{#1 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3FN type or f8E5M2 type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}} %0 = "mhlo.rng"(%mu, %sigma, %shape) {rng_distribution = #mhlo.rng_distribution}: (tensor, tensor<1xf32>, tensor<3xi64>) -> tensor<2x3x5xf32> func.return %0 : tensor<2x3x5xf32> } @@ -1482,7 +2005,7 @@ func.func @rng_normal_invalid_shape_rank(%mu: tensor, %sigma: tensor) func.func @rng_normal_invalid_type(%arg0: tensor>, %arg1: tensor) { %cst = "mhlo.constant"() {value = dense<7> : tensor<1xi64>} : () -> tensor<1xi64> - // expected-error @+1 {{#0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor>'}} + // expected-error @+1 {{#0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3FN type or f8E5M2 type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor>'}} %0 = "mhlo.rng"(%arg0, %arg1, %cst) {rng_distribution = #mhlo.rng_distribution}: (tensor>, tensor, tensor<1xi64>) -> tensor<7xf32> func.return } @@ -1524,7 +2047,7 @@ func.func @rng_uniform_invalid_shape(%arg0: tensor, %arg1: tensor, %ar 
func.func @rng_uniform_invalid_a_rank(%a: tensor<1xf32>, %b: tensor) -> tensor<2x3x5xf32> { %shape = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64> - // expected-error@+1 {{operand #0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}} + // expected-error@+1 {{operand #0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3FN type or f8E5M2 type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}} %0 = "mhlo.rng"(%a, %b, %shape) {rng_distribution = #mhlo.rng_distribution}: (tensor<1xf32>, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> func.return %0 : tensor<2x3x5xf32> } @@ -1534,7 +2057,7 @@ func.func @rng_uniform_invalid_a_rank(%a: tensor<1xf32>, %b: tensor) -> ten func.func @rng_uniform_invalid_b_rank(%a: tensor, %b: tensor<1xf32>) -> tensor<2x3x5xf32> { %shape = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64> - // expected-error@+1 {{operand #1 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}} + // expected-error@+1 {{operand #1 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3FN type or f8E5M2 type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<1xf32>'}} %0 = "mhlo.rng"(%a, %b, %shape) {rng_distribution = #mhlo.rng_distribution}: (tensor, tensor<1xf32>, tensor<3xi64>) -> tensor<2x3x5xf32> func.return %0 : tensor<2x3x5xf32> } @@ -1552,7 +2075,7 @@ func.func @rng_uniform_invalid_shape_rank(%a: tensor, %b: tensor) -> t func.func @rng_uniform_invalid_type(%a: tensor>, %b: tensor) -> tensor<2x3x5xf32> { %shape = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64> - // expected-error@+1 {{operand #0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor>'}} + // expected-error@+1 {{operand #0 must be 0D tensor of pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or f8E4M3FN type or f8E5M2 type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor>'}} %0 = "mhlo.rng"(%a, %b, %shape) {rng_distribution = #mhlo.rng_distribution}: (tensor>, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> func.return %0 : tensor<2x3x5xf32> } @@ -1670,12 +2193,24 @@ func.func @slice(%arg0: tensor<3x4xi32>) -> tensor<1x2xi32> { // ----- -// CHECK-LABEL: func @slice +// CHECK-LABEL: func @slice_dynamic_dim func.func @slice_dynamic_dim(%arg0: tensor<3x?xi32>) -> tensor<1x?xi32> { + %0 = "mhlo.slice"(%arg0) { + start_indices = dense<[1, 1]> : tensor<2xi64>, + limit_indices = dense<[2, 2]> : tensor<2xi64>, + strides = dense<[1, 1]> : tensor<2xi64> + } : (tensor<3x?xi32>) -> tensor<1x?xi32> + func.return %0 : tensor<1x?xi32> +} + +// ----- + +func.func @slice_dynamic_dim_invalid_indices(%arg0: tensor<3x?xi32>) -> tensor<1x?xi32> { + // expected-error@+1 {{negative start index -1 in dimension 1}} %0 = "mhlo.slice"(%arg0) { start_indices = dense<[1, -1]> : 
tensor<2xi64>, - limit_indices = dense<[2, -1]> : tensor<2xi64>, - strides = dense<[1, -1]> : tensor<2xi64> + limit_indices = dense<[2, 2]> : tensor<2xi64>, + strides = dense<[1, 1]> : tensor<2xi64> } : (tensor<3x?xi32>) -> tensor<1x?xi32> func.return %0 : tensor<1x?xi32> } @@ -1802,6 +2337,14 @@ func.func @dynamic_slice_mismatch_indices(%arg0: tensor<3x4xi32>, %arg1: tensor< // ----- +func.func @dynamic_slice_mismatch_indices_element_type(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<1x4xi32> { + // expected-error@+1 {{start indices must have same element type}} + %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> + func.return %0 : tensor<1x4xi32> +} + +// ----- + func.func @dynamic_slice(%arg0: tensor<3x4xi32>, %arg1: tensor) -> tensor<1x4xi32> { // expected-error@+1 {{has mismatched number of start indices (1) and the rank of operand (2)}} %0 = "mhlo.dynamic_slice"(%arg0, %arg1) {slice_sizes = dense<[1]> : tensor<1xi64>} : (tensor<3x4xi32>, tensor) -> tensor<1x4xi32> @@ -1859,29 +2402,85 @@ func.func @dynamic_slice_slice_size_too_large(%arg0: tensor<3x4xi32>, %arg1: ten // ----- // CHECK-LABEL: @dynamic_update_slice -func.func @dynamic_update_slice(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %start1: tensor, %start2: tensor) -> tensor<3x4xi64> { - %0 = "mhlo.dynamic_update_slice"(%input, %update, %start1, %start2) : (tensor<3x4xi64>, tensor<2xi64>, tensor, tensor) -> tensor<3x4xi64> +func.func @dynamic_update_slice(%input: tensor<3x4xi64>, %update: tensor<1x4xi64>, %start1: tensor, %start2: tensor) -> tensor<3x4xi64> { + %0 = "mhlo.dynamic_update_slice"(%input, %update, %start1, %start2) : (tensor<3x4xi64>, tensor<1x4xi64>, tensor, tensor) -> tensor<3x4xi64> func.return %0 : tensor<3x4xi64> } // ----- -func.func @dynamic_update_slice_invalid_start(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %start: tensor<2xi64>) -> tensor<3x4xi64> { +// CHECK-LABEL: @dynamic_update_slice_dynamic_dim +func.func @dynamic_update_slice_dynamic_dim(%input: tensor, %update: tensor<1x4xi64>, %start1: tensor, %start2: tensor) -> tensor<3x4xi64> { + %0 = "mhlo.dynamic_update_slice"(%input, %update, %start1, %start2) : (tensor, tensor<1x4xi64>, tensor, tensor) -> tensor<3x4xi64> + func.return %0 : tensor<3x4xi64> +} + +// ----- + +func.func @dynamic_update_slice_invalid_start(%input: tensor<3x4xi64>, %update: tensor<1x2xi64>, %start: tensor<2xi64>) -> tensor<3x4xi64> { // expected-error@+1 {{operand #2 must be 0D tensor of 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer values, but got 'tensor<2xi64>'}} - %0 = "mhlo.dynamic_update_slice"(%input, %update, %start) : (tensor<3x4xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<3x4xi64> + %0 = "mhlo.dynamic_update_slice"(%input, %update, %start) : (tensor<3x4xi64>, tensor<1x2xi64>, tensor<2xi64>) -> tensor<3x4xi64> + func.return %0 : tensor<3x4xi64> +} + +// ----- + +func.func @dynamic_update_slice_invalid_update(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %start1: tensor, %start2: tensor) -> tensor<3x4xi64> { + // expected-error@+1 {{update rank does not match operand rank: 1 vs 2.}} + %0 = "mhlo.dynamic_update_slice"(%input, %update, %start1, %start2) : (tensor<3x4xi64>, tensor<2xi64>, tensor, tensor) -> tensor<3x4xi64> + func.return %0 : tensor<3x4xi64> +} + +// ----- + +func.func @dynamic_update_slice_invalid_start_size(%input: tensor<3x4xi64>, %update: tensor<1x2xi64>, %start: tensor) -> tensor<3x4xi64> { + // 
expected-error@+1 {{expects number of start_indices to match operand rank: 1 vs 2.}} + %0 = "mhlo.dynamic_update_slice"(%input, %update, %start) : (tensor<3x4xi64>, tensor<1x2xi64>, tensor) -> tensor<3x4xi64> func.return %0 : tensor<3x4xi64> } // ----- func.func @dynamic_update_slice_mismatched_start(%input: tensor<11x3x4xi32>, %update: tensor<1x3x4xi32>, %start1: tensor, %start2: tensor, %start3: tensor) -> tensor<11x3x4xi32> { - // expected-error@+1 {{start indices must have same element type (encountered mismatch: 'i32' vs 'i64')}} + // expected-error@+1 {{start indices must have same element type}} %0 = "mhlo.dynamic_update_slice"(%input, %update, %start1, %start2, %start3) : (tensor<11x3x4xi32>, tensor<1x3x4xi32>, tensor, tensor, tensor) -> tensor<11x3x4xi32> func.return %0 : tensor<11x3x4xi32> } // ----- +func.func @dynamic_update_slice_invalid_update_size(%input: tensor<3x4xi64>, %update: tensor<1x5xi64>, %start1: tensor, %start2: tensor) -> tensor<3x4xi64> { + // expected-error@+1 {{expects size at dimension 1 of update to be in range [0, 4]. Got: 5.}} + %0 = "mhlo.dynamic_update_slice"(%input, %update, %start1, %start2) : (tensor<3x4xi64>, tensor<1x5xi64>, tensor, tensor) -> tensor<3x4xi64> + func.return %0 : tensor<3x4xi64> +} + +// ----- + +// CHECK-LABEL: func @dynamic_update_slice_dynamic_rank_input +func.func @dynamic_update_slice_dynamic_rank_input(%input: tensor<*xi64>, %update: tensor<1x4xi64>, %start1: tensor, %start2: tensor) -> tensor<*xi64> { + %0 = "mhlo.dynamic_update_slice"(%input, %update, %start1, %start2) : (tensor<*xi64>, tensor<1x4xi64>, tensor, tensor) -> tensor<*xi64> + func.return %0 : tensor<*xi64> +} + +// ----- + +// CHECK-LABEL: func @dynamic_update_slice_dynamic_rank_update +func.func @dynamic_update_slice_dynamic_rank_update(%input: tensor<3x4xi64>, %update: tensor<*xi64>, %start1: tensor, %start2: tensor) -> tensor<3x4xi64> { + %0 = "mhlo.dynamic_update_slice"(%input, %update, %start1, %start2) : (tensor<3x4xi64>, tensor<*xi64>, tensor, tensor) -> tensor<3x4xi64> + func.return %0 : tensor<3x4xi64> +} + +// ----- + +// CHECK-LABEL: func @dynamic_update_slice_dynamic_sizes +func.func @dynamic_update_slice_dynamic_sizes(%input: tensor, %update: tensor<1x?xi64>, %start1: tensor, %start2: tensor) -> tensor { + %0 = "mhlo.dynamic_update_slice"(%input, %update, %start1, %start2) : (tensor, tensor<1x?xi64>, tensor, tensor) -> tensor + func.return %0 : tensor +} + +// ----- + // CHECK-LABEL: func @transpose func.func @transpose(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> @@ -1960,6 +2559,30 @@ func.func @triangular_solve(%arg0: tensor<10x5x4x4xf32>, %arg1: tensor<10x5x4x4x // ----- +// CHECK-LABEL: func @triangular_solve_dynamic_dims_minor +func.func @triangular_solve_dynamic_dims_minor(%arg0: tensor<10x5x?x4xf32>, %arg1: tensor<10x5x4x4xf32>) -> tensor<10x5x4x4xf32> { + %0 = "mhlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = #mhlo, unit_diagonal = true} : (tensor<10x5x?x4xf32>, tensor<10x5x4x4xf32>) -> tensor<10x5x4x4xf32> + func.return %0 : tensor<10x5x4x4xf32> +} + +// ----- + +// CHECK-LABEL: func @triangular_solve_dynamic_dims_shared +func.func @triangular_solve_dynamic_dims_shared(%arg0: tensor<10x5x4x?xf32>, %arg1: tensor<10x5x4x4xf32>) -> tensor<10x5x4x4xf32> { + %0 = "mhlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = #mhlo, unit_diagonal = true} : 
(tensor<10x5x4x?xf32>, tensor<10x5x4x4xf32>) -> tensor<10x5x4x4xf32> + func.return %0 : tensor<10x5x4x4xf32> +} + +// ----- + +// CHECK-LABEL: func @triangular_solve_dynamic_dims_batch +func.func @triangular_solve_dynamic_dims_batch(%arg0: tensor, %arg1: tensor<10x?x4x4xf32>) -> tensor<10x5x4x4xf32> { + %0 = "mhlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = #mhlo, unit_diagonal = true} : (tensor, tensor<10x?x4x4xf32>) -> tensor<10x5x4x4xf32> + func.return %0 : tensor<10x5x4x4xf32> +} + +// ----- + // CHECK-LABEL: func @triangular_solve_unranked func.func @triangular_solve_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { %0 = "mhlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = #mhlo, unit_diagonal = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> @@ -1993,7 +2616,7 @@ func.func @triangular_solve_rank_less_than_2(%arg0: tensor<4xf32>, %arg1: tensor // ----- func.func @triangular_solve_unequal_minor_dims_a(%arg0: tensor<4x3xf32>, %arg1: tensor<4x3xf32>) -> tensor<4x3xf32> { - // expected-error@+1 {{two minor dimensions of operand 'a' must have equal size, but got 'tensor<4x3xf32>'}} + // expected-error@+1 {{two minor dimensions of operand 'a' must be compatible, but got 'tensor<4x3xf32>'}} %0 = "mhlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = #mhlo, unit_diagonal = true} : (tensor<4x3xf32>, tensor<4x3xf32>) -> tensor<4x3xf32> func.return %0 : tensor<4x3xf32> } @@ -2009,7 +2632,7 @@ func.func @triangular_solve_unequal_rank(%arg0: tensor<10x4x4xf32>, %arg1: tenso // ----- func.func @triangular_solve_mismatch_shared_dim(%arg0: tensor<4x4xf32>, %arg1: tensor<3x4xf32>) -> tensor<3x4xf32> { - // expected-error@+1 {{shared dimension of operands 'a' and 'b' does not match, but got 'tensor<4x4xf32>' and 'tensor<3x4xf32>'}} + // expected-error@+1 {{shared dimension of operands 'a' and 'b' must be compatible, but got 'tensor<4x4xf32>' and 'tensor<3x4xf32>'}} %0 = "mhlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = #mhlo, unit_diagonal = true} : (tensor<4x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> func.return %0 : tensor<3x4xf32> } @@ -2017,7 +2640,7 @@ func.func @triangular_solve_mismatch_shared_dim(%arg0: tensor<4x4xf32>, %arg1: t // ----- func.func @triangular_solve_mismatch_leading_dims(%arg0: tensor<10x5x4x4xf32>, %arg1: tensor<10x6x4x3xf32>) -> tensor<10x6x4x3xf32> { - // expected-error@+1 {{leading batch dimensions of the operands must be same, but got 'tensor<10x5x4x4xf32>' and 'tensor<10x6x4x3xf32>'}} + // expected-error@+1 {{batch dimensions of the operands must be compatible, but got 'tensor<10x5x4x4xf32>' and 'tensor<10x6x4x3xf32>'}} %0 = "mhlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = #mhlo, unit_diagonal = true} : (tensor<10x5x4x4xf32>, tensor<10x6x4x3xf32>) -> tensor<10x6x4x3xf32> func.return %0 : tensor<10x6x4x3xf32> } @@ -2056,7 +2679,7 @@ func.func @tuple_token(%arg0: tensor, %arg1: !mhlo.token) -> tuple, %arg1: tensor) -> tuple, tensor, tensor> { - // expected-error@+1 {{number of operands to tuple expected to match number of types}} + // expected-error@+1 {{inferred type(s) 'tuple, tensor>' are incompatible with return type(s) of operation 'tuple, tensor, tensor>'}} %0 = "mhlo.tuple"(%arg0, %arg1) : (tensor, tensor) -> tuple, tensor, tensor> func.return %0 : tuple, tensor, tensor> } @@ -2064,7 +2687,7 @@ func.func @tuple_arg_size_mismatch(%arg0: tensor, %arg1: tensor) -> tu // ----- 
func.func @tuple_type_mismatch(%arg0: tensor, %arg1: tensor) -> tuple, tensor> { - // expected-error@+1 {{op has return type mismatch at 1th value}} + // expected-error@+1 {{inferred type(s) 'tuple, tensor>' are incompatible with return type(s) of operation 'tuple, tensor>'}} %0 = "mhlo.tuple"(%arg0, %arg1) : (tensor, tensor) -> tuple, tensor> func.return %0 : tuple, tensor> } @@ -2086,7 +2709,7 @@ func.func @get_tuple_element_token(%arg0: tuple, !mhlo.token>) -> !m // ----- func.func @get_tuple_element_bad_type(%arg0: tuple, tensor>) -> tensor { - // expected-error@+1 {{has return type tensor, but expected tensor}} + // expected-error@+1 {{inferred type(s) 'tensor' are incompatible with return type(s) of operation 'tensor'}} %0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple, tensor>) -> tensor func.return %0 : tensor } @@ -2125,7 +2748,7 @@ func.func @or_invalid_f32_type(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> te // ----- func.func @floor_invalid_i32_type(%arg0: tensor<4xi32>) -> tensor<4xi32> { - // expected-error@+1 {{must be tensor of 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<4xi32>'}} + // expected-error@+1 {{must be tensor of f8E4M3FN type or f8E5M2 type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<4xi32>'}} %0 = "mhlo.floor"(%arg0) : (tensor<4xi32>) -> tensor<4xi32> func.return %0 : tensor<4xi32> } @@ -2356,7 +2979,48 @@ func.func @reshape_invalid_shapes(%operand: tensor<2x4xf32>) -> tensor<3x3xf32> // ----- -func.func @dot_general(%arg0: tensor, %arg1: tensor) { +func.func @reverse_duplicate_dimensions(%operand: tensor<3x2xi32>) -> tensor<3x2xi32> { + // expected-error @+1 {{dimensions should be unique. Got: 0, 0}} + %0 = "mhlo.reverse"(%operand) { + dimensions = dense<[0, 0]> : tensor<2xi64> + } : (tensor<3x2xi32>) -> tensor<3x2xi32> + func.return %0 : tensor<3x2xi32> +} + +// ----- + +func.func @reverse_invalid_dimensions_unranked(%operand: tensor<*xi32>) -> tensor<*xi32> { + // expected-error @+1 {{all dimensions should be non-negative. Got dimension: -1.}} + %0 = "mhlo.reverse"(%operand) { + dimensions = dense<-1> : tensor + } : (tensor<*xi32>) -> tensor<*xi32> + func.return %0 : tensor<*xi32> +} + +// ----- + +func.func @reverse_invalid_dimensions_negative(%operand: tensor<3x2xi32>) -> tensor<3x2xi32> { + // expected-error @+1 {{all dimensions should be non-negative. Got dimension: -1.}} + %0 = "mhlo.reverse"(%operand) { + dimensions = dense<-1> : tensor + } : (tensor<3x2xi32>) -> tensor<3x2xi32> + func.return %0 : tensor<3x2xi32> +} + +// ----- + +func.func @reverse_invalid_dimensions(%operand: tensor<3x2xi32>) -> tensor<3x2xi32> { + // expected-error @+1 {{all dimensions should be between [0, 2). 
Got dimension: 2.}} + %0 = "mhlo.reverse"(%operand) { + dimensions = dense<2> : tensor + } : (tensor<3x2xi32>) -> tensor<3x2xi32> + func.return %0 : tensor<3x2xi32> +} + +// ----- + +// CHECK-LABEL: func @dot_general +func.func @dot_general(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x5xf32>) -> tensor<2x4x5xf32> { %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #mhlo.dot< lhs_batching_dimensions = [0], @@ -2364,7 +3028,21 @@ func.func @dot_general(%arg0: tensor, %arg1: tensor) { lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [1] > - } : (tensor, tensor) -> tensor + } : (tensor<2x3x4xf32>, tensor<2x3x5xf32>) -> tensor<2x4x5xf32> + func.return %0 : tensor<2x4x5xf32> +} + +// ----- + +func.func @dot_general(%arg0: tensor<1x?x1x?xf32>, %arg1: tensor) { + %0 = "mhlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #mhlo.dot< + lhs_batching_dimensions = [0, 1], + rhs_batching_dimensions = [0, 1], + lhs_contracting_dimensions = [2, 3], + rhs_contracting_dimensions = [2, 3] + > + } : (tensor<1x?x1x?xf32>, tensor) -> tensor func.return } @@ -2414,7 +3092,7 @@ func.func @dot_general(%arg0: tensor, %arg1: tensor<*xf32>) { // ----- func.func @dot_general(%arg0: tensor, %arg1: tensor) { - // expected-error @+1 {{'mhlo.dot_general' op lhs and rhs should have the same number of batching dimensions}} + // expected-error @+1 {{lhs and rhs should have the same number of batching dimensions}} %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #mhlo.dot< lhs_batching_dimensions = [0], @@ -2429,7 +3107,7 @@ func.func @dot_general(%arg0: tensor, %arg1: tensor) { // ----- func.func @dot_general(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) { - // expected-error @+1 {{'mhlo.dot_general' op lhs and rhs should have the same number of batching dimensions}} + // expected-error @+1 {{lhs and rhs should have the same number of batching dimensions}} %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #mhlo.dot< lhs_batching_dimensions = [0], @@ -2444,7 +3122,7 @@ func.func @dot_general(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) { // ----- func.func @dot_general(%arg0: tensor, %arg1: tensor) { - // expected-error @+1 {{'mhlo.dot_general' op lhs and rhs should have the same number of contracting dimensions}} + // expected-error @+1 {{lhs and rhs should have the same number of contracting dimensions}} %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #mhlo.dot< lhs_batching_dimensions = [0], @@ -2459,7 +3137,7 @@ func.func @dot_general(%arg0: tensor, %arg1: tensor) { // ----- func.func @dot_general(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) { - // expected-error @+1 {{'mhlo.dot_general' op lhs and rhs should have the same number of contracting dimensions}} + // expected-error @+1 {{lhs and rhs should have the same number of contracting dimensions}} %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #mhlo.dot< lhs_batching_dimensions = [0], @@ -2474,7 +3152,7 @@ func.func @dot_general(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) { // ----- func.func @dot_general(%arg0: tensor, %arg1: tensor) { - // expected-error @+1 {{'mhlo.dot_general' op has duplicated dimension from lhs_batching_dimensions and lhs_contracting_dimensions: 0}} + // expected-error @+1 {{has duplicated dimension from lhs_batching_dimensions and lhs_contracting_dimensions: 0}} %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #mhlo.dot< lhs_batching_dimensions = [0, 0], @@ -2489,7 +3167,7 @@ func.func @dot_general(%arg0: tensor, %arg1: tensor) { // ----- 
func.func @dot_general(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) { - // expected-error @+1 {{'mhlo.dot_general' op has duplicated dimension from lhs_batching_dimensions and lhs_contracting_dimensions: 0}} + // expected-error @+1 {{has duplicated dimension from lhs_batching_dimensions and lhs_contracting_dimensions: 0}} %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #mhlo.dot< lhs_batching_dimensions = [0, 0], @@ -2504,7 +3182,7 @@ func.func @dot_general(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) { // ----- func.func @dot_general(%arg0: tensor, %arg1: tensor) { - // expected-error @+1 {{'mhlo.dot_general' op has duplicated dimension from lhs_batching_dimensions and lhs_contracting_dimensions: 1}} + // expected-error @+1 {{has duplicated dimension from lhs_batching_dimensions and lhs_contracting_dimensions: 1}} %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #mhlo.dot< lhs_batching_dimensions = [0], @@ -2519,7 +3197,7 @@ func.func @dot_general(%arg0: tensor, %arg1: tensor) { // ----- func.func @dot_general(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) { - // expected-error @+1 {{'mhlo.dot_general' op has duplicated dimension from lhs_batching_dimensions and lhs_contracting_dimensions: 1}} + // expected-error @+1 {{has duplicated dimension from lhs_batching_dimensions and lhs_contracting_dimensions: 1}} %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #mhlo.dot< lhs_batching_dimensions = [0], @@ -2534,7 +3212,7 @@ func.func @dot_general(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) { // ----- func.func @dot_general(%arg0: tensor, %arg1: tensor) { - // expected-error @+1 {{'mhlo.dot_general' op has duplicated dimension from lhs_batching_dimensions and lhs_contracting_dimensions: 0}} + // expected-error @+1 {{has duplicated dimension from lhs_batching_dimensions and lhs_contracting_dimensions: 0}} %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #mhlo.dot< lhs_batching_dimensions = [0], @@ -2549,7 +3227,7 @@ func.func @dot_general(%arg0: tensor, %arg1: tensor) { // ----- func.func @dot_general(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) { - // expected-error @+1 {{'mhlo.dot_general' op has duplicated dimension from lhs_batching_dimensions and lhs_contracting_dimensions: 0}} + // expected-error @+1 {{has duplicated dimension from lhs_batching_dimensions and lhs_contracting_dimensions: 0}} %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #mhlo.dot< lhs_batching_dimensions = [0], @@ -2564,7 +3242,7 @@ func.func @dot_general(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) { // ----- func.func @dot_general(%arg0: tensor, %arg1: tensor) { - // expected-error @+1 {{'mhlo.dot_general' op has duplicated dimension from rhs_batching_dimensions and rhs_contracting_dimensions: 0}} + // expected-error @+1 {{has duplicated dimension from rhs_batching_dimensions and rhs_contracting_dimensions: 0}} %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #mhlo.dot< lhs_batching_dimensions = [0], @@ -2579,7 +3257,7 @@ func.func @dot_general(%arg0: tensor, %arg1: tensor) { // ----- func.func @dot_general(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) { - // expected-error @+1 {{'mhlo.dot_general' op has duplicated dimension from rhs_batching_dimensions and rhs_contracting_dimensions: 0}} + // expected-error @+1 {{has duplicated dimension from rhs_batching_dimensions and rhs_contracting_dimensions: 0}} %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #mhlo.dot< lhs_batching_dimensions = [0], @@ -2594,7 +3272,7 @@ func.func 
@dot_general(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) { // ----- func.func @dot_general(%arg0: tensor, %arg1: tensor) { - // expected-error @+1 {{'mhlo.dot_general' op lhs_batching_dimensions value: -1 is out of range: [0, 3)}} + // expected-error @+1 {{lhs_batching_dimensions value: -1 is out of range: [0, 3)}} %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #mhlo.dot< lhs_batching_dimensions = [-1], @@ -2609,7 +3287,7 @@ func.func @dot_general(%arg0: tensor, %arg1: tensor) { // ----- func.func @dot_general(%arg0: tensor, %arg1: tensor) { - // expected-error @+1 {{'mhlo.dot_general' op lhs_batching_dimensions value: 3 is out of range: [0, 3)}} + // expected-error @+1 {{lhs_batching_dimensions value: 3 is out of range: [0, 3)}} %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #mhlo.dot< lhs_batching_dimensions = [3], @@ -2624,7 +3302,7 @@ func.func @dot_general(%arg0: tensor, %arg1: tensor) { // ----- func.func @dot_general(%arg0: tensor, %arg1: tensor) { - // expected-error @+1 {{'mhlo.dot_general' op rhs_batching_dimensions value: -1 is out of range: [0, 3)}} + // expected-error @+1 {{rhs_batching_dimensions value: -1 is out of range: [0, 3)}} %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #mhlo.dot< lhs_batching_dimensions = [0], @@ -2639,7 +3317,7 @@ func.func @dot_general(%arg0: tensor, %arg1: tensor) { // ----- func.func @dot_general(%arg0: tensor, %arg1: tensor) { - // expected-error @+1 {{'mhlo.dot_general' op rhs_batching_dimensions value: 3 is out of range: [0, 3)}} + // expected-error @+1 {{rhs_batching_dimensions value: 3 is out of range: [0, 3)}} %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #mhlo.dot< lhs_batching_dimensions = [0], @@ -2654,7 +3332,7 @@ func.func @dot_general(%arg0: tensor, %arg1: tensor) { // ----- func.func @dot_general(%arg0: tensor, %arg1: tensor) { - // expected-error @+1 {{'mhlo.dot_general' op lhs_contracting_dimensions value: -1 is out of range: [0, 3)}} + // expected-error @+1 {{lhs_contracting_dimensions value: -1 is out of range: [0, 3)}} %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #mhlo.dot< lhs_batching_dimensions = [0], @@ -2669,7 +3347,7 @@ func.func @dot_general(%arg0: tensor, %arg1: tensor) { // ----- func.func @dot_general(%arg0: tensor, %arg1: tensor) { - // expected-error @+1 {{'mhlo.dot_general' op lhs_contracting_dimensions value: 3 is out of range: [0, 3)}} + // expected-error @+1 {{lhs_contracting_dimensions value: 3 is out of range: [0, 3)}} %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #mhlo.dot< lhs_batching_dimensions = [0], @@ -2684,7 +3362,7 @@ func.func @dot_general(%arg0: tensor, %arg1: tensor) { // ----- func.func @dot_general(%arg0: tensor, %arg1: tensor) { - // expected-error @+1 {{'mhlo.dot_general' op rhs_contracting_dimensions value: -1 is out of range: [0, 3)}} + // expected-error @+1 {{rhs_contracting_dimensions value: -1 is out of range: [0, 3)}} %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #mhlo.dot< lhs_batching_dimensions = [0], @@ -2699,7 +3377,7 @@ func.func @dot_general(%arg0: tensor, %arg1: tensor) { // ----- func.func @dot_general(%arg0: tensor, %arg1: tensor) { - // expected-error @+1 {{'mhlo.dot_general' op rhs_contracting_dimensions value: 3 is out of range: [0, 3)}} + // expected-error @+1 {{rhs_contracting_dimensions value: 3 is out of range: [0, 3)}} %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #mhlo.dot< lhs_batching_dimensions = [0], @@ -2714,7 +3392,7 @@ 
func.func @dot_general(%arg0: tensor, %arg1: tensor) { // ----- func.func @dot_general(%arg0: tensor<2x?x?xf32>, %arg1: tensor<3x?x?xf32>) { - // expected-error @+1 {{'mhlo.dot_general' op batching dimension sizes must match for lhs/rhs}} + // expected-error @+1 {{batching dimension sizes must match for lhs/rhs}} %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #mhlo.dot< lhs_batching_dimensions = [0], @@ -2729,7 +3407,7 @@ func.func @dot_general(%arg0: tensor<2x?x?xf32>, %arg1: tensor<3x?x?xf32>) { // ----- func.func @dot_general(%arg0: tensor, %arg1: tensor) { - // expected-error @+1 {{'mhlo.dot_general' op contracting dimension sizes must match for lhs/rhs}} + // expected-error @+1 {{contracting dimension sizes must match for lhs/rhs}} %0 = "mhlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #mhlo.dot< lhs_batching_dimensions = [0], @@ -2743,6 +3421,54 @@ func.func @dot_general(%arg0: tensor, %arg1: tensor) { // ----- +// CHECK-LABEL: func @dot_general +func.func @dot_general(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x5xf32>) -> tensor<2x4x6xf32> { + // expected-error@+1 {{inferred shape '[2, 4, 5]' is incompatible with return type of operation 'tensor<2x4x6xf32>'}} + %0 = "mhlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #mhlo.dot< + lhs_batching_dimensions = [0], + rhs_batching_dimensions = [0], + lhs_contracting_dimensions = [1], + rhs_contracting_dimensions = [1] + > + } : (tensor<2x3x4xf32>, tensor<2x3x5xf32>) -> tensor<2x4x6xf32> + func.return %0 : tensor<2x4x6xf32> +} + + +// ----- + +func.func @dot_general_one_element_precision_config(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x5xf32>) -> tensor<2x4x5xf32> { + %0 = "mhlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #mhlo.dot< + lhs_batching_dimensions = [0], + rhs_batching_dimensions = [0], + lhs_contracting_dimensions = [1], + rhs_contracting_dimensions = [1] + >, + precision_config = [#mhlo] + } : (tensor<2x3x4xf32>, tensor<2x3x5xf32>) -> tensor<2x4x5xf32> + func.return %0 : tensor<2x4x5xf32> +} + +// ----- + +func.func @dot_general_three_element_precision_config(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x5xf32>) -> tensor<2x4x5xf32> { + // expected-error@+1 {{expects precision config to be empty or have <= 2 elements}} + %0 = "mhlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #mhlo.dot< + lhs_batching_dimensions = [0], + rhs_batching_dimensions = [0], + lhs_contracting_dimensions = [1], + rhs_contracting_dimensions = [1] + >, + precision_config = [#mhlo, #mhlo, #mhlo] + } : (tensor<2x3x4xf32>, tensor<2x3x5xf32>) -> tensor<2x4x5xf32> + func.return %0 : tensor<2x4x5xf32> +} + +// ----- + func.func @compatible_shapes(%arg0: tensor, %shape: tensor<2xindex>) -> tensor { %0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor, tensor<2xindex>) -> tensor func.return %0 : tensor @@ -2841,7 +3567,7 @@ func.func @bitcast_convert_width_mismatch(%arg: tensor) -> tensor { // ----- func.func @bitcast_convert_empty_target(%arg: tensor<1xf64>) -> tensor { - // expected-error@+1 {{op does not allow the smaller element type to be part of a 0d tensor, but got: 'tensor<1xf64>' and 'tensor'.}} + // expected-error@+1 {{does not allow the smaller element type to be part of a 0d tensor, but got: 'tensor<1xf64>' and 'tensor'.}} %0 = "mhlo.bitcast_convert"(%arg) : (tensor<1xf64>) -> tensor return %0 : tensor } @@ -2849,7 +3575,7 @@ func.func @bitcast_convert_empty_target(%arg: tensor<1xf64>) -> tensor { // ----- func.func @bitcast_convert_empty_operand(%arg: tensor) -> tensor<1xf64> { - // 
expected-error@+1 {{op does not allow the smaller element type to be part of a 0d tensor, but got: 'tensor' and 'tensor<1xf64>'.}} + // expected-error@+1 {{does not allow the smaller element type to be part of a 0d tensor, but got: 'tensor' and 'tensor<1xf64>'.}} %0 = "mhlo.bitcast_convert"(%arg) : (tensor) -> tensor<1xf64> return %0 : tensor<1xf64> } @@ -2910,6 +3636,14 @@ func.func @reduce_precision_invalid_exponent(%arg: tensor<2x4xf32>) -> tensor<2x // ----- +func.func @reduce_precision_invalid_mantissa(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> { + // expected-error @+1 {{mantissa_bits must be at least 0.}} + %0 = "mhlo.reduce_precision"(%arg) {exponent_bits=1 : i32, mantissa_bits=-1 : i32} : (tensor<2x4xf32>) -> tensor<2x4xf32> + func.return %0 : tensor<2x4xf32> +} + +// ----- + func.func @gather(%operand : tensor<2x4x9xi32>, %start_indices : tensor<1x5x2xi32>) -> tensor<1x5x8xi32> { %res = "mhlo.gather"(%operand, %start_indices) { dimension_numbers = #mhlo.gather< @@ -3575,9 +4309,33 @@ func.func @set_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<1x128x512xf32 // ----- +func.func @custom_call_with_dictionary_backend_config() { + // CHECK: mhlo.custom_call @foo() {api_version = 4 : i32, backend_config = {foo = 42 : i32}} + "mhlo.custom_call"() {api_version = 4 : i32, backend_config={foo = 42 : i32}, call_target_name = "foo"} : () -> () + func.return +} + +// ----- + +func.func @custom_call_with_incompatible_backend_config() { + // expected-error@+1 {{unsupported user-encoded backend config, backend config must be a dictionary attribute}} + "mhlo.custom_call"() {api_version = 4 : i32, backend_config="bar=42", call_target_name = "foo"} : () -> () + func.return +} + +// ----- + +func.func @custom_call_with_incompatible_backend_config() { + // expected-error@+1 {{unsupported dictionary attribute backend config, backend config must be a user-encoded string attribute}} + "mhlo.custom_call"() {api_version = 3 : i32, backend_config={bar = 42 : i32}, call_target_name = "foo"} : () -> () + func.return +} + +// ----- + // CHECK: func @custom_call_multiple_inputs_outputs func.func @custom_call_multiple_inputs_outputs(%x: tensor<2xf32>, %token: !mhlo.token) -> tensor<2xf32> { - %0:3 = "mhlo.custom_call"(%x, %token) {backend_config="", call_target_name = "foo", has_side_effect = false} : (tensor<2xf32>, !mhlo.token) -> (tensor<2xf32>, tensor<2xf32>, !mhlo.token) + %0:3 = "mhlo.custom_call"(%x, %token) {backend_config="", call_target_name = "foo", has_side_effect = false, custom_call_schedule = #mhlo} : (tensor<2xf32>, !mhlo.token) -> (tensor<2xf32>, tensor<2xf32>, !mhlo.token) %1 = "mhlo.add"(%0#0, %0#1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> func.return %1 : tensor<2xf32> } @@ -3705,7 +4463,7 @@ func.func @custom_call_mismatch_tensor_and_layout_permutation(%arg: tensor<1x2x3 // CHECK-LABEL: func @custom_call_output_operand_alias func.func @custom_call_output_operand_alias(%arg0: tuple, tensor<2x3xf32>>, %arg1: tensor<5x5xf32>) { - // CHECK: "mhlo.custom_call" + // CHECK: mhlo.custom_call // CHECK-SAME{LITERAL}: output_operand_aliases = [#mhlo.output_operand_alias]} %0 = "mhlo.custom_call"(%arg0, %arg1) { call_target_name = "foo", @@ -4335,7 +5093,6 @@ func.func @error_incompatible_alias_element_types (%arg0: tensor<2xf32> {mhlo.re // ----- - // mhlo.batch_norm_training // CHECK-LABEL: @batch_norm_train @@ -4346,8 +5103,18 @@ func.func @batch_norm_train(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, // ----- +// CHECK-LABEL: @batch_norm_train_dynamic +func.func 
@batch_norm_train_dynamic(%input: tensor, %scale: tensor<2xf32>, %offset: tensor<2xf32>) -> tensor { + %0:3 = "mhlo.batch_norm_training" (%input, %scale, %offset) { + epsilon = 0.001 : f32, feature_index = 1 : i64 + } : (tensor, tensor<2xf32>, tensor<2xf32>) -> (tensor, tensor<2xf32>, tensor<2xf32>) + func.return %0#0 : tensor +} + +// ----- + func.func @error_batch_norm_train(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %offset: tensor<2xf32>) -> tensor<2x2x2x2xf32> { - // expected-error@+1 {{expects feature_index to be smaller than the rank of operand type; got feature_index 4, and rank 4.}} + // expected-error@+1 {{expects featureIndex to be smaller than the rank of multi-dimensional operands; got featureIndex 4, and rank 4.}} %0:3 = "mhlo.batch_norm_training" (%input, %scale, %offset) {epsilon = 0.001 : f32, feature_index = 4 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>) func.return %0#0 : tensor<2x2x2x2xf32> } @@ -4355,7 +5122,7 @@ func.func @error_batch_norm_train(%input: tensor<2x2x2x2xf32>, %scale: tensor<2x // ----- func.func @error_batch_norm_train(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %offset: tensor<2xf32>) -> tensor<2x2x2x2xf32> { - // expected-error@+1 {{expects feature_index to be a non-negative number, got -1.}} + // expected-error@+1 {{expects featureIndex to be a non-negative number, got -1.}} %0:3 = "mhlo.batch_norm_training" (%input, %scale, %offset) {epsilon = 0.001 : f32, feature_index = -1 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>) func.return %0#0 : tensor<2x2x2x2xf32> } @@ -4363,7 +5130,7 @@ func.func @error_batch_norm_train(%input: tensor<2x2x2x2xf32>, %scale: tensor<2x // ----- func.func @error_batch_norm_train(%input: tensor<2x2x2x2xf32>, %scale: tensor<3xf32>, %offset: tensor<3xf32>) -> tensor<2x2x2x2xf32> { - // expected-error@+1 {{expects the size of scale factor to be same as the feature count, but the size of scale factor is 3 and the feature count is 2.}} + // expected-error@+1 {{expects the size of single-dimensional operands to be compatible with feature count, but the size of single-dimensional operands is 3 and the feature count is 2.}} %0:3 = "mhlo.batch_norm_training" (%input, %scale, %offset) {epsilon = 0.001 : f32, feature_index = 3 : i64} : (tensor<2x2x2x2xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<2x2x2x2xf32>, tensor<3xf32>, tensor<3xf32>) func.return %0#0 : tensor<2x2x2x2xf32> } @@ -4382,8 +5149,18 @@ func.func @batch_norm_inference(%input: tensor<4x256xf32>, %scale: tensor<256xf3 // ----- +// CHECK-LABEL: @batch_norm_inference_dynamic +func.func @batch_norm_inference_dynamic(%input: tensor<4x?xf32>, %scale: tensor, %offset: tensor, %mean: tensor, %variance: tensor) -> (tensor<4x?xf32>) { + %0 = "mhlo.batch_norm_inference" (%input, %scale, %offset, %mean, %variance) { + epsilon = 1.001000e-05 : f32, feature_index = 1 : i64 + } : (tensor<4x?xf32>, tensor, tensor, tensor, tensor) -> tensor<4x?xf32> + func.return %0 : tensor<4x?xf32> +} + +// ----- + func.func @error_batch_norm_inference(%input: tensor<4x256xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>, %mean: tensor<256xf32>, %variance: tensor<256xf32>) -> (tensor<4x256xf32>) { - // expected-error@+1 {{expects feature_index to be smaller than the rank of operand type; got feature_index 2, and rank 2.}} + // expected-error@+1 {{expects featureIndex to be smaller than the rank of multi-dimensional operands; got 
featureIndex 2, and rank 2.}} %0 = "mhlo.batch_norm_inference" (%input, %scale, %offset, %mean, %variance) {epsilon = 1.001000e-05 : f32, feature_index = 2 : i64} : (tensor<4x256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>) -> tensor<4x256xf32> @@ -4393,7 +5170,7 @@ func.func @error_batch_norm_inference(%input: tensor<4x256xf32>, %scale: tensor< // ----- func.func @error_batch_norm_inference(%input: tensor<4x256xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>, %mean: tensor<256xf32>, %variance: tensor<256xf32>) -> (tensor<4x256xf32>) { - // expected-error@+1 {{expects feature_index to be a non-negative number, got -1.}} + // expected-error@+1 {{expects featureIndex to be a non-negative number, got -1.}} %0 = "mhlo.batch_norm_inference" (%input, %scale, %offset, %mean, %variance) {epsilon = 1.001000e-05 : f32, feature_index = -1 : i64} : (tensor<4x256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>) -> tensor<4x256xf32> @@ -4403,7 +5180,7 @@ func.func @error_batch_norm_inference(%input: tensor<4x256xf32>, %scale: tensor< // ----- func.func @error_batch_norm_inference(%input: tensor<4x256xf32>, %scale: tensor<25xf32>, %offset: tensor<25xf32>, %mean: tensor<25xf32>, %variance: tensor<25xf32>) -> (tensor<4x256xf32>) { - // expected-error@+1 {{expects the size of scale factor to be same as the feature count, but the size of scale factor is 25 and the feature count is 256.}} + // expected-error@+1 {{expects the size of single-dimensional operands to be compatible with feature count, but the size of single-dimensional operands is 25 and the feature count is 256.}} %0 = "mhlo.batch_norm_inference" (%input, %scale, %offset, %mean, %variance) {epsilon = 1.001000e-05 : f32, feature_index = 1 : i64} : (tensor<4x256xf32>, tensor<25xf32>, tensor<25xf32>, tensor<25xf32>, tensor<25xf32>) -> tensor<4x256xf32> @@ -4422,8 +5199,17 @@ func.func @batch_norm_grad(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, % // ----- +func.func @batch_norm_grad_dynamic(%input: tensor, %scale: tensor, %mean: tensor, %variance: tensor, %grad_output: tensor) -> tensor { + %0:3 = "mhlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) { + epsilon = 0.001 : f32, feature_index = 0 : i64 + } : (tensor, tensor, tensor, tensor, tensor) -> (tensor, tensor, tensor) + func.return %0#0 : tensor +} + +// ----- + func.func @error_batch_norm_grad(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %mean: tensor<2xf32>, %variance: tensor<2xf32>, %grad_output: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> { - // expected-error@+1 {{expects feature_index to be smaller than the rank of operand type; got feature_index 4, and rank 4.}} + // expected-error@+1 {{expects featureIndex to be smaller than the rank of multi-dimensional operands; got featureIndex 4, and rank 4.}} %0:3 = "mhlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = 4 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2x2x2x2xf32>) -> (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>) func.return %0#0 : tensor<2x2x2x2xf32> } @@ -4431,7 +5217,7 @@ func.func @error_batch_norm_grad(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf // ----- func.func @error_batch_norm_grad(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %mean: tensor<2xf32>, %variance: tensor<2xf32>, %grad_output: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> { - // expected-error@+1 {{expects feature_index to be a non-negative number, got -1.}} + 
// expected-error@+1 {{expects featureIndex to be a non-negative number, got -1.}} %0:3 = "mhlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = -1 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2x2x2x2xf32>) -> (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>) func.return %0#0 : tensor<2x2x2x2xf32> } @@ -4439,7 +5225,7 @@ func.func @error_batch_norm_grad(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf // ----- func.func @error_batch_norm_grad(%input: tensor<2x2x2x2xf32>, %scale: tensor<4xf32>, %mean: tensor<4xf32>, %variance: tensor<4xf32>, %grad_output: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> { - // expected-error@+1 {{expects the size of scale factor to be same as the feature count, but the size of scale factor is 4 and the feature count is 2.}} + // expected-error@+1 {{expects the size of single-dimensional operands to be compatible with feature count, but the size of single-dimensional operands is 4 and the feature count is 2.}} %0:3 = "mhlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor<2x2x2x2xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<2x2x2x2xf32>) -> (tensor<2x2x2x2xf32>, tensor<4xf32>, tensor<4xf32>) func.return %0#0 : tensor<2x2x2x2xf32> } @@ -4447,7 +5233,7 @@ func.func @error_batch_norm_grad(%input: tensor<2x2x2x2xf32>, %scale: tensor<4xf // ----- func.func @error_batch_norm_grad(%input: tensor<2x2x2x2xf32>, %scale: tensor<4xf32>, %mean: tensor<2xf32>, %variance: tensor<2xf32>, %grad_output: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> { - // expected-error@+1 {{failed to verify that all of {scale, mean, variance, grad_scale, grad_offset} have same shape}} + // expected-error@+1 {{expects single-dimensional operands to have compatible shapes}} %0:3 = "mhlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor<2x2x2x2xf32>, tensor<4xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2x2x2x2xf32>) -> (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>) func.return %0#0 : tensor<2x2x2x2xf32> } @@ -4455,7 +5241,7 @@ func.func @error_batch_norm_grad(%input: tensor<2x2x2x2xf32>, %scale: tensor<4xf // ----- func.func @error_batch_norm_grad(%input: tensor<2x2x2x2xi32>, %scale: tensor<2xf32>, %mean: tensor<2xf32>, %variance: tensor<2xf32>, %grad_output: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> { - // expected-error@+1 {{operand #0 must be ranked tensor of 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<2x2x2x2xi32>'}} + // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3FN type or f8E5M2 type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<2x2x2x2xi32>'}} %0:3 = "mhlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor<2x2x2x2xi32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2x2x2x2xf32>) -> (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>) func.return %0#0 : tensor<2x2x2x2xf32> } @@ -4463,7 +5249,7 @@ func.func @error_batch_norm_grad(%input: tensor<2x2x2x2xi32>, %scale: tensor<2xf // ----- func.func @error_batch_norm_grad(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %mean: tensor<2xf32>, %variance: tensor<2xf32>, %grad_output: tensor<2x2x2xf32>) -> tensor<2x2x2x2xf32> { - // expected-error@+1 {{failed to verify that all of {operand, grad_output} have 
same shape}} + // expected-error@+1 {{expects multi-dimensional operands to have compatible shapes}} %0:3 = "mhlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2x2x2xf32>) -> (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>) func.return %0#0 : tensor<2x2x2x2xf32> } @@ -4471,15 +5257,15 @@ func.func @error_batch_norm_grad(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf // ----- func.func @error_batch_norm_grad(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %mean: tensor<2xf32>, %variance: tensor<2xf32>, %grad_output: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> { - // expected-error@+1 {{failed to verify that all of {operand, grad_scale, grad_offset} have same element type}} + // expected-error@+1 {{failed to verify that all of {operand, grad_operand, grad_scale, grad_offset} have same element type}} %0:3 = "mhlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2x2x2x2xf32>) -> (tensor<2x2x2x2xf32>, tensor<2xf64>, tensor<2xf64>) func.return %0#0 : tensor<2x2x2x2xf32> } // ----- -func.func @error_batch_norm_grad(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %mean: tensor<2xf32>, %variance: tensor<2xf32>, %grad_output: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> { - // expected-error@+1 {{failed to verify that all of {operand, grad_operand} have same type}} +func.func @error_batch_norm_grad(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %mean: tensor<2xf32>, %variance: tensor<2xf32>, %grad_output: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf64> { + // expected-error@+1 {{failed to verify that all of {operand, grad_operand, grad_scale, grad_offset} have same element type}} %0:3 = "mhlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2x2x2x2xf32>) -> (tensor<2x2x2x2xf64>, tensor<2xf32>, tensor<2xf32>) func.return %0#0 : tensor<2x2x2x2xf64> } @@ -4487,28 +5273,21 @@ func.func @error_batch_norm_grad(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf // ----- func.func @error_batch_norm_grad(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %mean: tensor<2xf32>, %variance: tensor<2xf32>, %grad_output: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> { - // expected-error@+1 {{result #1 must be 1D tensor of 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<2x2xf32>'}} + // expected-error@+1 {{result #1 must be 1D tensor of f8E4M3FN type or f8E5M2 type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<2x2xf32>'}} %0:3 = "mhlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2x2x2x2xf32>) -> (tensor<2x2x2x2xf32>, tensor<2x2xf32>, tensor<2xf32>) func.return %0#0 : tensor<2x2x2x2xf32> } // ----- -func.func @error_batch_norm_grad(%input: tensor<*xf32>, %scale: tensor<2xf32>, %mean: tensor<2xf32>, %variance: tensor<2xf32>, %grad_output: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> { - // expected-error@+1 {{operand #0 must be ranked tensor of 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<*xf32>'}} +func.func 
@error_batch_norm_grad(%input: tensor<*xf32>, %scale: tensor<2xf32>, %mean: tensor<2xf32>, %variance: tensor<2xf32>, %grad_output: tensor<2x2x2x2xf32>) -> tensor<*xf32> { + // expected-error@+1 {{operand #0 must be ranked tensor of f8E4M3FN type or f8E5M2 type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<*xf32>'}} %0:3 = "mhlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor<*xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2x2x2x2xf32>) -> (tensor<*xf32>, tensor<2xf32>, tensor<2xf32>) func.return %0#0 : tensor<*xf32> } // ----- -func.func @error_batch_norm_grad(%input: tensor, %scale: tensor<2xf32>, %mean: tensor<2xf32>, %variance: tensor<2xf32>, %grad_output: tensor) -> tensor { - // expected-error@+1 {{expects the size of scale factor to be same as the feature count, but the size of scale factor is 2 and the feature count is -1.}} - %0:3 = "mhlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor) -> (tensor, tensor<2xf32>, tensor<2xf32>) - func.return %0#0 : tensor -} - -// ----- // Test rng_get_and_update_state_op // CHECK-LABEL: xla.rng_get_and_update_state func.func @xla.rng_get_and_update_state() -> tensor<2xui64> { @@ -4584,7 +5363,7 @@ func.func @fft_rank_mismatch(%arg0: tensor<3x9xf32>) -> tensor<3x9xcomplex> // ----- func.func @rfft_invalid_dim(%arg0: tensor<3x9xf32>) -> tensor<3x9xcomplex> { - // expected-error@+1 {{RFFT requires innermost dimensions match fft_length. Got: 3, 9 but wanted 9, 9.}} + // expected-error@+1 {{RFFT requires innermost dimensions to be compatible with fft_length. Got: 3, 9 but wanted 9, 9.}} %0 = "mhlo.fft"(%arg0) { fft_length = dense<9> : tensor<2xi64>, fft_type = #mhlo } : (tensor<3x9xf32>) -> tensor<3x9xcomplex> func.return %0 : tensor<3x9xcomplex> } @@ -4592,7 +5371,7 @@ func.func @rfft_invalid_dim(%arg0: tensor<3x9xf32>) -> tensor<3x9xcomplex> // ----- func.func @irfft_invalid_dim(%arg0: tensor<3x9xcomplex>) -> tensor<3x9xf32> { - // expected-error@+1 {{IRFFT requires non-final dimensions match fft_length. Got: 3, 9 but wanted 9, 9, and 3 != 9.}} + // expected-error@+1 {{IRFFT requires non-final dimensions to be compatible with fft_length. Got: 3, 9 but wanted 9, 9, and 3 != 9.}} %0 = "mhlo.fft"(%arg0) { fft_length = dense<9> : tensor<2xi64>, fft_type = #mhlo } : (tensor<3x9xcomplex>) -> tensor<3x9xf32> func.return %0 : tensor<3x9xf32> } @@ -4600,7 +5379,7 @@ func.func @irfft_invalid_dim(%arg0: tensor<3x9xcomplex>) -> tensor<3x9xf32> // ----- func.func @irfft_invalid_dim(%arg0: tensor<3x9xcomplex>) -> tensor<3x9xf32> { - // expected-error@+1 {{IRFFT requires innermost dimension match fft_length[-1]/2+1. Got: 3, 9 but fft_length is 9.}} + // expected-error@+1 {{IRFFT requires innermost dimension to be compatible with fft_length[-1]/2+1. 
Got: 9 but fft_length is 9.}} %0 = "mhlo.fft"(%arg0) { fft_length = dense<9> : tensor<1xi64>, fft_type = #mhlo } : (tensor<3x9xcomplex>) -> tensor<3x9xf32> func.return %0 : tensor<3x9xf32> } @@ -4608,7 +5387,7 @@ func.func @irfft_invalid_dim(%arg0: tensor<3x9xcomplex>) -> tensor<3x9xf32> // ----- func.func @irfft_invalid_elt(%arg0: tensor<3x9xf32>) -> tensor<3x9xcomplex> { - // expected-error@+1 {{IRFFT takes a complex tensor as input, but is given 'tensor<3x9xf32>'}} + // expected-error@+1 {{FFT/IFFT/IRFFT take a complex tensor as input, but is given 'tensor<3x9xf32>'}} %0 = "mhlo.fft"(%arg0) { fft_length = dense<16> : tensor<1xi64>, fft_type = #mhlo } : (tensor<3x9xf32>) -> tensor<3x9xcomplex> func.return %0 : tensor<3x9xcomplex> } @@ -4631,6 +5410,54 @@ func.func @rfft_invalid_ret_elt(%arg0: tensor<3x9xf32>) -> tensor<3x9xf32> { // ----- +// CHECK-LABEL: @rfft_dynamic +func.func @rfft_dynamic(%arg0: tensor) -> tensor> { + %0 = "mhlo.fft"(%arg0) { fft_length = dense<9> : tensor<1xi64>, fft_type = #mhlo } : (tensor) -> tensor> + func.return %0 : tensor> +} + +// ----- + +func.func @rfft_dynamic_incompatible_dims(%arg0: tensor<3x10xf32>) -> tensor> { + // expected-error@+1{{RFFT requires innermost dimensions to be compatible with fft_length. Got: 3, 10 but wanted 9.}} + %0 = "mhlo.fft"(%arg0) { fft_length = dense<9> : tensor<1xi64>, fft_type = #mhlo } : (tensor<3x10xf32>) -> tensor> + func.return %0 : tensor> +} + +// ----- + +// CHECK-LABEL: @irfft_dynamic +func.func @irfft_dynamic(%arg0: tensor>) -> tensor { + %0 = "mhlo.fft"(%arg0) { fft_length = dense<16> : tensor<1xi64>, fft_type = #mhlo } : (tensor>) -> tensor + func.return %0 : tensor +} + +// ----- + +func.func @irfft_dynamic_incompatible_non_final_dims(%arg0: tensor>) -> tensor { + // expected-error@+1{{IRFFT requires non-final dimensions to be compatible with fft_length. Got: -9223372036854775808, 3, 15 but wanted 4, 16, and 3 != 4}} + %0 = "mhlo.fft"(%arg0) { fft_length = dense<[4, 16]> : tensor<2xi64>, fft_type = #mhlo } : (tensor>) -> tensor + func.return %0 : tensor +} + +// ----- + +func.func @irfft_dynamic_incompatible_final_dim(%arg0: tensor>) -> tensor { + // expected-error@+1{{IRFFT requires innermost dimension to be compatible with fft_length[-1]/2+1. 
Got: 8 but fft_length is 16.}} + %0 = "mhlo.fft"(%arg0) { fft_length = dense<16> : tensor<1xi64>, fft_type = #mhlo } : (tensor>) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @irfft_dynamic +func.func @irfft_dynamic(%arg0: tensor>) -> tensor { + %0 = "mhlo.fft"(%arg0) { fft_length = dense<16> : tensor<1xi64>, fft_type = #mhlo } : (tensor>) -> tensor + func.return %0 : tensor +} + +// ----- + // CHECK-LABEL: @eltwise_static_and_dynamic_type( // CHECK-SAME: %[[A:.*]]: tensor<10x10xf32>, %[[B:.*]]: tensor) -> tensor<10x10xf32> // CHECK: %[[R:.*]] = mhlo.add %[[A]], %[[B]] : (tensor<10x10xf32>, tensor) -> tensor<10x10xf32> @@ -5484,7 +6311,7 @@ func.func @is_finite(%arg0: tensor<3xf32>) -> tensor<3xi1> { // ----- func.func @is_finite_int_input(%arg0: tensor<3xi32>) -> tensor<3xi1> { - // expected-error@+1 {{operand #0 must be tensor of 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<3xi32>'}} + // expected-error@+1 {{operand #0 must be tensor of f8E4M3FN type or f8E5M2 type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type values, but got 'tensor<3xi32>'}} %0 = "mhlo.is_finite"(%arg0) {} : (tensor<3xi32>) -> tensor<3xi1> func.return %0 : tensor<3xi1> } @@ -5512,3 +6339,17 @@ func.func @invalid_dimension_attr(%arg0: tensor>, tensor) -> tensor<*xf32> func.return %result : tensor<*xf32> } + +// ----- + +func.func @f8e4m3fn(%arg0: tensor) -> tensor { + %0 = "mhlo.convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +func.func @f8e5m2(%arg0: tensor) -> tensor { + %0 = "mhlo.convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/optimize-hlo.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/optimize-hlo.mlir index be3d7ffc968..f002d87fde8 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/optimize-hlo.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/optimize-hlo.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt %s -pass-pipeline='func.func(mhlo-test-optimize)' | FileCheck %s --dump-input-context=30 +// RUN: mlir-hlo-opt %s -pass-pipeline='builtin.module(func.func(mhlo-test-optimize))' | FileCheck %s --dump-input-context=30 // CHECK-LABEL: @gather_is_slice_no_rank func.func @gather_is_slice_no_rank(%arg0: tensor<2x1x2xi32>, %arg1: tensor) -> tensor<1x2xi32> { diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir index 01552900c23..da1b657a4ce 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt --stablehlo-legalize-to-hlo --mlir-print-op-generic --split-input-file %s | FileCheck %s +// RUN: mlir-hlo-opt --stablehlo-legalize-to-hlo --mlir-print-op-generic --split-input-file --verify-diagnostics %s | FileCheck %s // ============ ATTRIBUTES ============ @@ -307,10 +307,13 @@ func.func @attr_transpose_adjoint(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16x // TypeExtensionsAttr aka #stablehlo.type_extensions is covered below. 
-func.func @attr_type_extensions_bounds(%arg0: tensor>) -> tensor> { - // CHECK: "func.return"(%arg0) : (tensor>) -> () - func.return %arg0 : tensor> +func.func @attr_type_extensions_bounds( + %arg0: tensor>) + -> tensor> { + // CHECK: "func.return"(%arg0) : (tensor>) -> () + func.return %arg0 : tensor> } + // CHECK-LABEL: "attr_type_extensions_bounds" // ============ OPS ============ @@ -378,6 +381,7 @@ func.func @op_all_reduce(%arg0: tensor) -> tensor { func.func @op_all_to_all(%arg0: tensor<4x16xf32>) -> tensor<16x4xf32> { // CHECK: "mhlo.all_to_all"(%arg0) { + // CHECK-SAME: channel_handle = #mhlo.channel_handle, // CHECK-SAME: concat_dimension = 0 : i64, // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, // CHECK-SAME: split_count = 4 : i64, @@ -387,7 +391,8 @@ func.func @op_all_to_all(%arg0: tensor<4x16xf32>) -> tensor<16x4xf32> { split_dimension = 1 : i64, concat_dimension = 0 : i64, split_count = 4 : i64, - replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64> + replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + channel_handle = #stablehlo.channel_handle } : (tensor<4x16xf32>) -> tensor<16x4xf32> func.return %0 : tensor<16x4xf32> } @@ -594,10 +599,10 @@ func.func @op_convert(%arg0: tensor) -> tensor { } // CHECK-LABEL: "op_convert" -func.func @op_convolution(%arg0: tensor<1x8x8x32x207xf32>, %arg1: tensor<3x3x32x207x16xf32>) -> tensor<32x1x8x8x16xf32> { +func.func @op_convolution(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { // CHECK: "mhlo.convolution"(%arg0, %arg1) { // CHECK-SAME: batch_group_count = 1 : i64, - // CHECK-SAME: dimension_numbers = #mhlo.conv<[b, 0, 1, ?, f]x[0, 1, ?, i, o]->[?, b, 0, 1, f]>, + // CHECK-SAME: dimension_numbers = #mhlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, // CHECK-SAME: feature_group_count = 1 : i64, // CHECK-SAME: lhs_dilation = dense<1> : tensor<2xi64>, // CHECK-SAME: padding = dense<1> : tensor<2x2xi64>, @@ -605,19 +610,19 @@ func.func @op_convolution(%arg0: tensor<1x8x8x32x207xf32>, %arg1: tensor<3x3x32x // CHECK-SAME: rhs_dilation = dense<1> : tensor<2xi64>, // CHECK-SAME: window_reversal = dense : tensor<2xi1>, // CHECK-SAME: window_strides = dense<1> : tensor<2xi64> - // CHECK-SAME: } : (tensor<1x8x8x32x207xf32>, tensor<3x3x32x207x16xf32>) -> tensor<32x1x8x8x16xf32> + // CHECK-SAME: } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> %0 = "stablehlo.convolution"(%arg0, %arg1) { window_strides = dense<1> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, lhs_dilation = dense<1> : tensor<2xi64>, rhs_dilation = dense<1> : tensor<2xi64>, window_reversal = dense : tensor<2xi1>, - dimension_numbers = #stablehlo.conv<[b, 0, 1, ?, f]x[0, 1, ?, i, o]->[?, b, 0, 1, f]>, + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, feature_group_count = 1 : i64, batch_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo] - } : (tensor<1x8x8x32x207xf32>, tensor<3x3x32x207x16xf32>) -> tensor<32x1x8x8x16xf32> - func.return %0 : tensor<32x1x8x8x16xf32> + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> + func.return %0 : tensor<1x8x8x16xf32> } // CHECK-LABEL: "op_convolution" @@ -654,7 +659,7 @@ func.func @op_cstr_reshapable(%arg0: index, %arg1: tensor<1xindex>) -> !shape.wi // CHECK-LABEL: "op_cstr_reshapable" func.func @called_computation() { func.return } -func.func @op_custom_call(%arg0: tensor) -> tensor { +func.func @op_custom_call_api_version_original(%arg0: tensor) -> 
tensor { // CHECK: "mhlo.custom_call"(%arg0) { // CHECK-SAME: api_version = 1 : i32, // CHECK-SAME: backend_config = "", @@ -662,6 +667,11 @@ func.func @op_custom_call(%arg0: tensor) -> tensor { // CHECK-SAME: called_computations = [@foo], // CHECK-SAME: has_side_effect = false, // CHECK-SAME: operand_layouts = [dense<> : tensor<0xindex>], + // CHECK-SAME: output_operand_aliases = [ + // CHECK-SAME: #mhlo.output_operand_alias< + // CHECK-SAME: output_tuple_indices = [], + // CHECK-SAME: operand_index = 0, + // CHECK-SAME: operand_tuple_indices = []>] // CHECK-SAME: result_layouts = [dense<> : tensor<0xindex>] // CHECK-SAME: } : (tensor) -> tensor %0 = "stablehlo.custom_call"(%arg0) { @@ -671,11 +681,29 @@ func.func @op_custom_call(%arg0: tensor) -> tensor { api_version = 1 : i32, called_computations = [@foo], operand_layouts = [dense<> : tensor<0xindex>], + output_operand_aliases = [ + #stablehlo.output_operand_alias], result_layouts = [dense<> : tensor<0xindex>] } : (tensor) -> tensor func.return %0 : tensor } -// CHECK-LABEL: "op_custom_call" +// CHECK-LABEL: "op_custom_call_api_version_original" + +func.func @op_custom_call_api_version_typed_ffi(%arg0: tensor) -> tensor { + // CHECK: "mhlo.custom_call"(%arg0) { + // CHECK-SAME: api_version = 4 : i32, + // CHECK-SAME: backend_config = {foo = "bar"}, + // CHECK-SAME: call_target_name = "foo" + // CHECK-SAME: } : (tensor) -> tensor + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "mhlo.custom_call", + backend_config = "{api_version = 4 : i32, backend_config = {foo = \22bar\22}, call_target_name = \22foo\22}" + } : (tensor) -> tensor + return %0 : tensor +} +// CHECK-LABEL: "op_custom_call_api_version_typed_ffi" func.func @op_divide(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: "mhlo.divide"(%arg0, %arg1) : (tensor, tensor) -> tensor @@ -733,10 +761,10 @@ func.func @op_dynamic_broadcast_in_dim(%arg0: tensor, %arg1: tensor<2xind } // CHECK-LABEL: "op_dynamic_broadcast_in_dim" -func.func @op_dynamic_conv(%arg0: tensor<1x8x8x32x207xf32>, %arg1: tensor<3x3x32x207x16xf32>, %arg2: tensor<4xi32>) -> tensor<32x1x?x?x16xf32> { +func.func @op_dynamic_conv(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>, %arg2: tensor<4xi32>) -> tensor<1x?x?x16xf32> { // CHECK: "mhlo.dynamic_conv"(%arg0, %arg1, %arg2) { // CHECK-SAME: batch_group_count = 1 : i64, - // CHECK-SAME: dimension_numbers = #mhlo.conv<[b, 0, 1, ?, f]x[0, 1, ?, i, o]->[?, b, 0, 1, f]>, + // CHECK-SAME: dimension_numbers = #mhlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, // CHECK-SAME: feature_group_count = 1 : i64, // CHECK-SAME: lhs_dilation = dense<1> : tensor<2xi64>, // CHECK-SAME: padding = dense<1> : tensor<2x2xi64>, @@ -744,19 +772,19 @@ func.func @op_dynamic_conv(%arg0: tensor<1x8x8x32x207xf32>, %arg1: tensor<3x3x32 // CHECK-SAME: rhs_dilation = dense<1> : tensor<2xi64>, // CHECK-SAME: window_reversal = dense : tensor<2xi1>, // CHECK-SAME: window_strides = dense<1> : tensor<2xi64> - // CHECK-SAME: } : (tensor<1x8x8x32x207xf32>, tensor<3x3x32x207x16xf32>, tensor<4xi32>) -> tensor<32x1x?x?x16xf32> + // CHECK-SAME: } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<4xi32>) -> tensor<1x?x?x16xf32> %0 = "stablehlo.dynamic_conv"(%arg0, %arg1, %arg2) { window_strides = dense<1> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, lhs_dilation = dense<1> : tensor<2xi64>, rhs_dilation = dense<1> : tensor<2xi64>, window_reversal = dense : tensor<2xi1>, - dimension_numbers = #stablehlo.conv<[b, 0, 1, ?, f]x[0, 1, ?, i, o]->[?, b, 0, 1, f]>, + 
dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, feature_group_count = 1 : i64, batch_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo] - } : (tensor<1x8x8x32x207xf32>, tensor<3x3x32x207x16xf32>, tensor<4xi32>) -> tensor<32x1x?x?x16xf32> - func.return %0 : tensor<32x1x?x?x16xf32> + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<4xi32>) -> tensor<1x?x?x16xf32> + func.return %0 : tensor<1x?x?x16xf32> } // CHECK-LABEL: "op_dynamic_conv" @@ -907,13 +935,13 @@ func.func @op_get_dimension_size(%arg0: tensor) -> tensor { } // CHECK-LABEL: "op_get_dimension_size" -func.func @op_get_tuple_element(%arg0: tuple>) -> tensor { +func.func @op_get_tuple_element(%arg0: tuple, tensor, tensor, tensor, tensor>) -> tensor { // CHECK: "mhlo.get_tuple_element"(%arg0) { - // CHECK-SAME: index = 0 : i32 - // CHECK-SAME: } : (tuple>) -> tensor + // CHECK-SAME: index = 4 : i32 + // CHECK-SAME: } : (tuple, tensor, tensor, tensor, tensor>) -> tensor %0 = "stablehlo.get_tuple_element"(%arg0) { - index = 0 : i32 - } : (tuple>) -> tensor + index = 4 : i32 + } : (tuple, tensor, tensor, tensor, tensor>) -> tensor func.return %0 : tensor } // CHECK-LABEL: "op_get_tuple_element" @@ -1086,6 +1114,13 @@ func.func @op_pad(%arg0: tensor<8xf32>, %arg1: tensor) -> tensor<16xf32> { } // CHECK-LABEL: "op_pad" +func.func @op_partition_id() -> tensor { + // CHECK: "mhlo.partition_id"() : () -> tensor + %0 = "stablehlo.partition_id"() : () -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "op_partition_id" + func.func @op_popcnt(%arg0: tensor) -> tensor { // CHECK: "mhlo.popcnt"(%arg0) : (tensor) -> tensor %0 = "stablehlo.popcnt"(%arg0) : (tensor) -> tensor @@ -1664,6 +1699,20 @@ func.func @type_ui64(%arg0: tensor, %arg1: tensor) -> tensor { } // CHECK-LABEL: "type_ui64" +func.func @type_f8E4M3FN(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "type_f8E4M3FN" + +func.func @type_f8E5M2(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: "type_f8E5M2" + func.func @type_bf16(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor @@ -1749,6 +1798,25 @@ func.func @type_token_caller(%arg0: !stablehlo.token) -> !stablehlo.token { // CHECK: function_type = (!mhlo.token) -> !mhlo.token // CHECK-LABEL: "type_token_caller" +func.func @type_token_region(%arg0: tensor, %arg1: !stablehlo.token) { + // CHECK: "mhlo.while"(%arg1) ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG2:arg.*]]: !mhlo.token): + // CHECK-NEXT: "mhlo.return"(%arg0) : (tensor) -> () + // CHECK-NEXT: }, { + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG2:arg.*]]: !mhlo.token): + // CHECK-NEXT: "mhlo.return"(%[[ARG2]]) : (!mhlo.token) -> () + // CHECK-NEXT: }) : (!mhlo.token) -> !mhlo.token + %0 = "stablehlo.while"(%arg1) ({ + ^bb0(%arg2: !stablehlo.token): + stablehlo.return %arg0 : tensor + }, { + ^bb0(%arg2: !stablehlo.token): + stablehlo.return %arg2 : !stablehlo.token + }) : (!stablehlo.token) -> !stablehlo.token + return +} +// CHECK-LABEL: "type_token_region" + func.func @type_tuple(%arg0: tuple>) -> tuple { %0 = "stablehlo.custom_call"(%arg0) { call_target_name = "foo" @@ -1757,3 
+1825,19 @@ func.func @type_tuple(%arg0: tuple>) -> tuple { return %0 : tuple } // CHECK-LABEL: "type_tuple" + +// ============ NEGATIVE TESTS ============ +// Some ops, attributes and types used in StableHLO programs are not supported in MHLO. +// For those cases, we have negative tests below. + +// ----- + +func.func @op_custom_call_botched_extensibility_protocol(%arg0: tensor) -> tensor { + // expected-error@+1 {{failed to legalize operation 'stablehlo.custom_call' that was explicitly marked illegal}} + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "mhlo.custom_call", + backend_config = "{api_version = 4 : i32, backend_config = {foo = \22bar\22}, call_target_name = \22foo\22}", + has_side_effect = false + } : (tensor) -> tensor + return %0 : tensor +} diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/symbolic-shape-optimization.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/symbolic-shape-optimization.mlir index 026e6f287f3..4eb4d418396 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/symbolic-shape-optimization.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/symbolic-shape-optimization.mlir @@ -645,27 +645,25 @@ func.func @empty_bcast(%arg0 : tensor, %arg1 : tensor) -> tensor<0xind // ----- // CHECK-LABEL: @simplifiable_bcast -// CHECK-SAME: %[[ARG0:.*]]: tensor -// CHECK-SAME: %[[ARG1:.*]]: tensor<1x8x1x?x1x?xf32> +// CHECK-SAME: %[[ARG0:.*]]: tensor {rt.symbolic_shape = dense<[-2, 1, 1, 4, -2, -3, 1]> : tensor<7xi64>} +// CHECK-SAME: %[[ARG1:.*]]: tensor<1x8x1x?x1x?xf32> {rt.symbolic_shape = dense<[1, 8, 1, -2, 1, -4]> : tensor<6xi64>} func.func @simplifiable_bcast( %arg0 : tensor {rt.symbolic_shape = dense<[-2, 1, 1, 4, -2, -3, 1]> : tensor<7xi64>}, %arg1 : tensor<1x8x1x?x1x?xf32> {rt.symbolic_shape = dense<[ 1, 8, 1, -2, 1, -4]> : tensor<6xi64>}) -> tensor<7xindex> { - // CHECK-DAG: %[[C0:.*]] = arith.constant 0 - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 - // CHECK-DAG: %[[C4:.*]] = arith.constant 4 - // CHECK-DAG: %[[C5:.*]] = arith.constant 5 - // CHECK-DAG: %[[C8:.*]] = arith.constant 8 - // CHECK-DAG: %[[S0:.*]] = shape.shape_of %[[ARG0]] - // CHECK-DAG: %[[S1:.*]] = shape.shape_of %[[ARG1]] - // CHECK-DAG: %[[S0D0:.*]] = tensor.extract %[[S0]][%[[C0]]] - // CHECK-DAG: %[[S0D4:.*]] = tensor.extract %[[S0]][%[[C4]]] - // CHECK-DAG: %[[S0D5:.*]] = tensor.extract %[[S0]][%[[C5]]] - // CHECK-DAG: %[[S1D5:.*]] = tensor.extract %[[S1]][%[[C5]]] - // CHECK-DAG: %[[RES:.*]] = tensor.from_elements %[[S0D0]], %[[C1]], %[[C8]], %[[C4]], %[[S0D4]], %[[S0D5]], %[[S1D5]] - // CHECK: return %[[RES]] + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index + // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index + // CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index + // CHECK-DAG: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C0]] + // CHECK-DAG: %[[DIM_0:.*]] = tensor.dim %[[ARG0]], %[[C4]] + // CHECK-DAG: %[[DIM_1:.*]] = tensor.dim %[[ARG0]], %[[C5]] + // CHECK-DAG: %[[DIM_2:.*]] = tensor.dim %[[ARG1]], %[[C5]] + // CHECK-DAG: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[DIM]], %[[C1]], %[[C8]], %[[C4]], %[[DIM_0]], %[[DIM_1]], %[[DIM_2]] + // CHECK: return %[[FROM_ELEMENTS]] : tensor<7xindex> %0 = shape.shape_of %arg0 : tensor -> tensor<7xindex> %1 = shape.shape_of %arg1 : tensor<1x8x1x?x1x?xf32> -> tensor<6xindex> %2 = shape.broadcast %0, %1 : tensor<7xindex>, tensor<6xindex> @@ -695,10 +693,7 @@ func.func 
@very_dynamic_bcast(%arg0 : tensor, %arg1 : tensor) // CHECK-LABEL: @broadcast_w_dyn_ty // CHECK-SAME: %[[ARG:.*]]: tensor<1xindex> func.func @broadcast_w_dyn_ty(%arg0: tensor<1xindex>) -> tensor{ - // CHECK: %[[C0:.*]] = arith.constant 0 - // CHECK: %[[D0:.*]] = tensor.extract %[[ARG]][%[[C0]]] - // CHECK: %[[UNCAST:.*]] = tensor.from_elements %[[D0]] - // CHECK: %[[CAST:.*]] = tensor.cast %[[UNCAST]] : tensor<1xindex> to tensor + // CHECK: %[[CAST:.*]] = tensor.cast %[[ARG]] : tensor<1xindex> to tensor // CHECK: return %[[CAST]] %0 = shape.broadcast %arg0, %arg0 : tensor<1xindex>, tensor<1xindex> -> tensor @@ -717,3 +712,118 @@ func.func @broadcast_scalar_w_dyn_ty(%arg0: tensor<0xindex>) -> tensor{ : tensor<0xindex>, tensor<0xindex> -> tensor func.return %0 : tensor } + +// ----- + +// CHECK-LABEL: @optimize_1dx1d_bcast +// CHECK-SAME: %[[ARG0:.*]]: tensor {rt.symbolic_shape = dense<-2> : tensor<1xi64>}, %[[ARG1:.*]]: tensor {rt.symbolic_shape = dense<-2> : tensor<1xi64>} +func.func @optimize_1dx1d_bcast( + %arg0: tensor + {rt.symbolic_shape = dense<[-2]> : tensor<1xi64>}, + %arg1: tensor + {rt.symbolic_shape = dense<[-2]> : tensor<1xi64>} +) -> tensor { + // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG0]] + // CHECK: %[[DYNAMIC:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[SHAPE]]) + // CHECK-SAME: broadcast_dimensions = dense<0> + // CHECK-SAME: known_expanding_dimensions = dense<> + // CHECK-SAME: known_nonexpanding_dimensions = dense<0> + // CHECK: return %[[DYNAMIC]] + %0 = shape.shape_of %arg0 : tensor -> tensor<1xindex> + %1 = shape.shape_of %arg1 : tensor -> tensor<1xindex> + %2 = shape.broadcast %0, %1 : tensor<1xindex>, tensor<1xindex> + -> tensor<1xindex> + %3 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %2) + {broadcast_dimensions = dense<[0]> : tensor<1xi64>} + : (tensor, tensor<1xindex>) -> tensor + func.return %3: tensor +} + +// ----- + +// CHECK-LABEL: @optimize_1dx2d_bcast_const_shape +// CHECK-SAME: %[[ARG0_0:.*]]: tensor<512xf32>, %[[ARG1_0:.*]]: tensor {rt.symbolic_shape = dense<[-2, 512]> : tensor<2xi64>} +func.func @optimize_1dx2d_bcast_const_shape( + %arg0: tensor<512xf32>, + %arg1: tensor + {rt.symbolic_shape = dense<[-2, 512]> : tensor<2xi64>} +) -> tensor { + // CHECK: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG1_0]] + // CHECK: %[[DYNAMIC_0:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0_0]], %[[SHAPE_0]]) + // CHECK-SAME: broadcast_dimensions = dense<1> + // CHECK-SAME: known_expanding_dimensions = dense<> + // CHECK-SAME: known_nonexpanding_dimensions = dense<0> + // CHECK: return %[[DYNAMIC_0]] + %0 = shape.const_shape [512] : tensor<1xindex> + %1 = shape.shape_of %arg1 : tensor -> tensor<2xindex> + %2 = shape.broadcast %0, %1 : tensor<1xindex>, tensor<2xindex> + -> tensor<2xindex> + %3 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %2) + {broadcast_dimensions = dense<[1]> : tensor<1xi64>} + : (tensor<512xf32>, tensor<2xindex>) -> tensor + func.return %3: tensor +} + +// ----- + +// CHECK-LABEL: @optimize_1dx1dx1d_bcast +// CHECK: %[[ARG0_1:.*]]: tensor {rt.symbolic_shape = dense<-2> : tensor<1xi64>}, %[[ARG1_1:.*]]: tensor {rt.symbolic_shape = dense<-2> : tensor<1xi64>}, %[[ARG2:.*]]: tensor {rt.symbolic_shape = dense<-2> : tensor<1xi64>} +func.func @optimize_1dx1dx1d_bcast( + %arg0: tensor + {rt.symbolic_shape = dense<[-2]> : tensor<1xi64>}, + %arg1: tensor + {rt.symbolic_shape = dense<[-2]> : tensor<1xi64>}, + %arg2: tensor + {rt.symbolic_shape = dense<[-2]> : tensor<1xi64>} +) -> tensor { + // CHECK: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG0_1]] + // CHECK: 
%[[DYNAMIC_1:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0_1]], %[[SHAPE_1]]) + // CHECK-SAME: broadcast_dimensions = dense<0> + // CHECK-SAME: known_expanding_dimensions = dense<> + // CHECK-SAME: known_nonexpanding_dimensions = dense<0> + // CHECK: return %[[DYNAMIC_1]] + %0 = shape.shape_of %arg0 : tensor -> tensor<1xindex> + %1 = shape.shape_of %arg1 : tensor -> tensor<1xindex> + %2 = shape.shape_of %arg2 : tensor -> tensor<1xindex> + %3 = shape.broadcast %0, %1 : tensor<1xindex>, tensor<1xindex> + -> tensor<1xindex> + %4 = shape.broadcast %3, %2 : tensor<1xindex>, tensor<1xindex> + -> tensor<1xindex> + %5 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %4) + {broadcast_dimensions = dense<[0]> : tensor<1xi64>} + : (tensor, tensor<1xindex>) -> tensor + func.return %5: tensor +} + +// ----- + +// CHECK-LABEL: @optimize_2dx1d_bcast +// CHECK-SAME: %[[ARG0_2:.*]]: tensor<10x?xf32> {rt.symbolic_shape = dense<[10, -2]> : tensor<2xi64>}, %[[ARG1_2:.*]]: tensor {rt.symbolic_shape = dense<-2> : tensor<1xi64>} +func.func @optimize_2dx1d_bcast( + %arg0: tensor<10x?xf32> + {rt.symbolic_shape = dense<[10, -2]> : tensor<2xi64>}, + %arg1: tensor + {rt.symbolic_shape = dense<[-2]> : tensor<1xi64>} +) -> (tensor<10x?xf32>, tensor<10x?xf32>) { + // CHECK: %[[SHAPE_2:.*]] = shape.shape_of %[[ARG0_2]] + // CHECK: %[[DYNAMIC_2:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0_2]], %[[SHAPE_2]]) + // CHECK-SAME: broadcast_dimensions = dense<[0, 1]> + // CHECK-SAME: known_expanding_dimensions = dense<> + // CHECK-SAME: known_nonexpanding_dimensions = dense<[0, 1]> + // CHECK: %[[DYNAMIC_3:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1_2]], %[[SHAPE_2]]) + // CHECK-SAME: broadcast_dimensions = dense<1> + // CHECK-SAME: known_expanding_dimensions = dense<> + // CHECK-SAME: known_nonexpanding_dimensions = dense<0> + // CHECK: return %[[DYNAMIC_2]], %[[DYNAMIC_3]] + %0 = shape.shape_of %arg0 : tensor<10x?xf32> -> tensor<2xindex> + %1 = shape.shape_of %arg1 : tensor -> tensor<1xindex> + %2 = shape.broadcast %0, %1 : tensor<2xindex>, tensor<1xindex> + -> tensor<2xindex> + %3 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %2) + {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} + : (tensor<10x?xf32>, tensor<2xindex>) -> tensor<10x?xf32> + %4 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %2) + {broadcast_dimensions = dense<[1]> : tensor<1xi64>} + : (tensor, tensor<2xindex>) -> tensor<10x?xf32> + func.return %3, %4: tensor<10x?xf32>, tensor<10x?xf32> +} diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/verifier_bounds.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/verifier_bounds.mlir new file mode 100644 index 00000000000..01c898b7820 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/verifier_bounds.mlir @@ -0,0 +1,13 @@ +// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file + +// expected-error@+1 {{Bounds length is 1, expected to be equal to rank(2) of the tensor}} +func.func @incorrect_bounds_length(%arg0: tensor>) -> tensor> { + func.return %arg0 : tensor> +} + +// ----- + +// expected-error@+1 {{Static dimension 0 cannot have a bound, use ShapedType::kDynamic to indicate a missing bound}} +func.func @static_dim_with_bound(%arg0: tensor<3xf32, #mhlo.type_extensions>) -> tensor> { + func.return %arg0 : tensor> +} diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/verifier_conv_op.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/verifier_conv_op.mlir index 713b200fd20..5b22e609fa5 100644 --- 
a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/verifier_conv_op.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/verifier_conv_op.mlir @@ -1,9 +1,8 @@ // RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | FileCheck %s -// ----- - // Valid: Generic convolution +// CHECK-LABEL: func @main func.func @main(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> { %result = "mhlo.convolution"(%arg0, %arg1) { @@ -29,8 +28,11 @@ func.func @main(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x32xf32>) func.return %result : tensor<100x28x28x1xf32> } +// ----- + // Valid: Test convolution i8xi8 -> i32. +// CHECK-LABEL: func @convolution_upcast func.func @convolution_upcast(%arg0 : tensor<100x26x26x32xi8>, %arg1 : tensor<3x3x1x32xi8>) -> tensor<100x28x28x1xi32> { %result = "mhlo.convolution"(%arg0, %arg1) { @@ -55,6 +57,8 @@ func.func @convolution_upcast(%arg0 : tensor<100x26x26x32xi8>, func.return %result : tensor<100x28x28x1xi32> } +// ----- + // Valid: Empty spatial dimensions // CHECK: func @conv_empty_spatial_dimensions @@ -274,7 +278,7 @@ func.func @invalid_conv_dimensions(%arg0 : tensor<100x26x26x32xf32>, func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { - // expected-error@+1 {{op expects batch_group_count to be a positive number, got 0.}} + // expected-error@+1 {{expects batch_group_count to be a positive number, got 0.}} %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], @@ -292,7 +296,7 @@ func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>, func.func @invalid_conv_dimensions(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { - // expected-error@+1 {{op expects feature_group_count to be a positive number, got 0.}} + // expected-error@+1 {{expects feature_group_count to be a positive number, got 0.}} %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], @@ -468,6 +472,23 @@ func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>, // ----- +func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>, + %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { + // expected-error@+1 {{expects window-reversal to have same dimension-size as size of window dimensions (2), but got: 1.}} + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = {stride = [1, 1], pad = [[1, 1], [1, 1]], + lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [false]} + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64, + precision_config = [#mhlo, #mhlo]} : + (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> + func.return %0 : tensor<1x8x8x16xf32> +} + +// ----- + func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { // expected-error@+1 {{expects padding-entries to have same dimension-size as size of window dimensions (2), but got: 1.}} @@ -605,7 +626,7 @@ func.func @invalid_conv_window_attributes(%arg0: tensor<1x8x8x207xf32>, func.func @invalid_conv_return_type(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x16xf32> { - // expected-error @+1 {{expects rank of convolution return-type to be equal to input-ranks (4), but got 3.}} + // 
expected-error @+1 {{expects rank of convolution return-type to be equal to input-ranks (4), but got 3}} %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], @@ -626,7 +647,7 @@ func.func @invalid_conv_return_type(%arg0: tensor<1x8x8x207xf32>, func.func @invalid_conv_return_type(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<2x8x8x16xf32> { - // expected-error@+1 {{nvolution' op has shape mismatch between the expected return-type ('tensor<1x8x8x16xf32>') and actual return-type ('tensor<2x8x8x16xf32>').}} + // expected-error@+1 {{inferred shape '[1, 8, 8, 16]' is incompatible with return type of operation 'tensor<2x8x8x16xf32>'}} %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1,1]], @@ -647,7 +668,7 @@ func.func @invalid_conv_return_type(%arg0: tensor<1x8x8x207xf32>, func.func @invalid_conv_return_type(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x32xf32> { - // expected-error@+1 {{has shape mismatch between the expected return-type ('tensor<1x8x8x16xf32>') and actual return-type ('tensor<1x8x8x32xf32>').}} + // expected-error@+1 {{inferred shape '[1, 8, 8, 16]' is incompatible with return type of operation 'tensor<1x8x8x32xf32>'}} %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1,1]], @@ -670,7 +691,7 @@ func.func @invalid_conv_return_type(%arg0: tensor<1x8x8x207xf32>, // Dynamic input-batch-dimension func.func @invalid_conv_dynamic_shapes(%arg0: tensor, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x1x1x1xf32> { - // expected-error@+1 {{has shape mismatch between the expected return-type ('tensor') and actual return-type ('tensor<1x1x1x1xf32>').}} + // expected-error@+1 {{inferred shape '[?, 8, 8, 16]' is incompatible with return type of operation 'tensor<1x1x1x1xf32>'}} %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1,1]], @@ -689,7 +710,7 @@ func.func @invalid_conv_dynamic_shapes(%arg0: tensor, // Dynamic input-feature-dimension: No effect on output dimensions. 
func.func @invalid_conv_dynamic_shapes(%arg0: tensor<1x8x8x?xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x1x1x1xf32> { - // expected-error@+1 {{has shape mismatch between the expected return-type ('tensor<1x8x8x16xf32>') and actual return-type ('tensor<1x1x1x1xf32>').}} + // expected-error@+1 {{inferred shape '[1, 8, 8, 16]' is incompatible with return type of operation 'tensor<1x1x1x1xf32>'}} %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1,1]], @@ -708,7 +729,7 @@ func.func @invalid_conv_dynamic_shapes(%arg0: tensor<1x8x8x?xf32>, // Dynamic input-spatial-dimension func.func @invalid_conv_dynamic_shapes(%arg0: tensor<1x?x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x1x1x1xf32> { - // expected-error@+1 {{has shape mismatch between the expected return-type ('tensor<1x?x8x16xf32>') and actual return-type ('tensor<1x1x1x1xf32>').}} + // expected-error@+1 {{inferred shape '[1, ?, 8, 16]' is incompatible with return type of operation 'tensor<1x1x1x1xf32>'}} %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1,1]], @@ -727,7 +748,7 @@ func.func @invalid_conv_dynamic_shapes(%arg0: tensor<1x?x8x207xf32>, // Dynamic kernel-input-feature-dimension: No effect on output dimensions. func.func @invalid_conv_dynamic_shapes(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x?x16xf32>) -> tensor<1x1x1x1xf32> { - // expected-error@+1 {{has shape mismatch between the expected return-type ('tensor<1x8x8x16xf32>') and actual return-type ('tensor<1x1x1x1xf32>').}} + // expected-error@+1 {{inferred shape '[1, 8, 8, 16]' is incompatible with return type of operation 'tensor<1x1x1x1xf32>'}} %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1,1]], @@ -746,7 +767,7 @@ func.func @invalid_conv_dynamic_shapes(%arg0: tensor<1x8x8x207xf32>, // Dynamic kernel-output-feature-dimension func.func @check_inferred_type_with_dynamic_input_dims(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x?xf32>) -> tensor<1x1x1x1xf32> { - // expected-error@+1 {{has shape mismatch between the expected return-type ('tensor<1x8x8x?xf32>') and actual return-type ('tensor<1x1x1x1xf32>').}} + // expected-error@+1 {{inferred shape '[1, 8, 8, ?]' is incompatible with return type of operation 'tensor<1x1x1x1xf32>'}} %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1,1]], @@ -765,7 +786,7 @@ func.func @check_inferred_type_with_dynamic_input_dims(%arg0: tensor<1x8x8x207xf // Dynamic kernel-spatial-dimension func.func @check_inferred_type_with_dynamic_input_dims(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x?x207x16xf32>) -> tensor<1x1x1x1xf32> { - // expected-error@+1 {{has shape mismatch between the expected return-type ('tensor<1x8x?x16xf32>') and actual return-type ('tensor<1x1x1x1xf32>').}} + // expected-error@+1 {{inferred shape '[1, 8, ?, 16]' is incompatible with return type of operation 'tensor<1x1x1x1xf32>'}} %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1,1]], @@ -779,3 +800,155 @@ func.func @check_inferred_type_with_dynamic_input_dims(%arg0: tensor<1x8x8x207xf func.return %0 : tensor<1x1x1x1xf32> } +// ----- + +func.func @conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> 
tensor<1x8x8x16xf32> { + // expected-error @+3 {{'mhlo.convolution' Expected array with 2 elements, got 3 elements instead}} + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = {stride = [1, 1], pad = [[1, 1, 1], [1, 1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} + {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#mhlo, #mhlo]} : + (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> + func.return %0 : tensor<1x8x8x16xf32> +} + +// ----- + +// CHECK: module +// CHECK-SAME: mhlo.conv = #mhlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 1, 0, f]> +module attributes { mhlo.conv = #mhlo.conv} {} + +// ----- + +// CHECK: module +// CHECK: mhlo.conv = #mhlo.conv<[b, 1, 0, f]x[0, 1, i, o]->[b, 0, 1, f]> +module attributes { + mhlo.conv = #mhlo.conv<[b, 1, 0, f]x[0, 1, i, o]->[b, 0, 1, f]> +} {} + +// ----- + +module attributes { + // expected-error@+1{{Unexpected dimension c, expecting b, f}} + mhlo.conv = #mhlo.conv<[c, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]> +} {} + +// ----- + +module attributes { + // expected-error@+1{{Unexpected dimension b, expecting i, o}} + mhlo.conv = #mhlo.conv<[b, 0, 1, f]x[0, 1, b, o]->[b, 0, 1, f]> +} {} + +// ----- + +module attributes { + // expected-error@+1{{Unexpected dimension i, expecting o}} + mhlo.conv = #mhlo.conv<[b, 0, 1, f]x[0, 1, i, i]->[b, 0, 1, f]> +} {} + +// ----- + +module attributes { + // expected-error@+1{{Expected dimensions f not specified}} + mhlo.conv = #mhlo.conv<[b, 0, 1]x[0, 1, i, o]->[b, 0, 1, f]> +} {} + +// ----- + +module attributes { + // expected-error@+1{{Unexpected keyword b}} + mhlo.conv = #mhlo.conv<[b, 0, 1, f]x[0, 1, i, o, b]->[b, 0, 1, f]> +} {} + +// ----- + +module attributes { + // expected-error@+1{{expected '['}} + mhlo.conv = #mhlo.conv<{b, 0, 1, f}x[0, 1, i, o]->[b, 0, 1, f]> +} {} + +// ----- + +module attributes { + // expected-error@+1{{Expected spatial dimensions 0 not specified}} + mhlo.conv = #mhlo.conv<[b, f, 1]x[o, 0, 1, i]->[f, b, 0, 1]> +} {} + +// ----- + +module attributes { + // expected-error@+1{{Duplicate entries for spatial dimension 1}} + mhlo.conv = #mhlo.conv<[b, f, 1, 0, 1]x[o, 0, 1, i]->[f, b, 0, 1]> +} {} + +// ----- + +module attributes { + // expected-error@+1{{Unexpected dimension -2}} + mhlo.conv = #mhlo.conv<[b, f, 1, -2]x[o, 0, 1, i]->[f, b, 0, 1]> +} {} + +// ----- + +func.func @convolution(%arg0: tensor<2x2x3x4xf32>, %arg1: tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32> { + // expected-error@+3{{Unexpected keyword stide}} + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = {stide = [2, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2]} + { batch_group_count = 1 : i64, feature_group_count = 1 : i64} + : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32> + func.return %0 : tensor<3x5x5x4xf32> +} + +// ----- + +func.func @convolution(%arg0: tensor<2x2x3x4xf32>, %arg1: tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32> { + // expected-error@+3{{expected integer value}} + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = {stride = [2, b], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2]} + { batch_group_count = 1 : i64, feature_group_count = 1 : i64} + : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32> + func.return %0 : tensor<3x5x5x4xf32> +} + +// ----- + +func.func @convolution(%arg0: tensor<2x2x3x4xf32>, %arg1: tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32> { + // 
expected-error@+3{{Unexpected keyword stride}} + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = {stride = [2, 1], pad = [[0, 1], [0, 1]], rhs_dilate = [1, 2], stride=[2,1]} + { batch_group_count = 1 : i64, feature_group_count = 1 : i64} + : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32> + func.return %0 : tensor<3x5x5x4xf32> +} + +// ----- + +func.func @conv_invalid_precision_config(%arg0: tensor<3x2xf16>, + %arg1: tensor<2x2xf16>) -> tuple> { + // expected-error@+1 {{expects precision config to be empty or have <= 2 elements}} + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, f]x[i, o]->[b, f], + window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], + reverse = []} + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64, + precision_config = [#mhlo, #mhlo, #mhlo] + } + : (tensor<3x2xf16>, tensor<2x2xf16>) -> tensor<3x2xf16> + %1 = "mhlo.tuple"(%0) : (tensor<3x2xf16>) -> tuple> + func.return %1 : tuple> +} diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/verifier_reduce_op.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/verifier_reduce_op.mlir index 50c8ca186ab..1d9a912fee6 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/verifier_reduce_op.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/verifier_reduce_op.mlir @@ -381,6 +381,23 @@ func.func @verify_reducer_function(%arg0: tensor<8x5xf32>, %arg1 : tensor<4xf32> // ----- +// Verifies that dynamic input type is allowed with reducer function with static shapes. +func.func @verify_dynamic_operand(%arg0: tensor<8x?xf32>, %arg1 : tensor<4xf32>) + -> (tensor) { + + %0 = "mhlo.reduce"(%arg0, %arg1) ({ + + ^bb0(%arg2: tensor<4xf32>, %arg3: tensor<4xf32> ): + %1 = "mhlo.add"(%arg2, %arg3) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + "mhlo.return"(%1) : (tensor<4xf32>) -> () + + }) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor<8x?xf32>, tensor<4xf32>) -> tensor + + func.return %0: tensor +} + +// ----- + func.func @reduce_verify_rettype(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> (tensor) { diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/verifier_scatter_op.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/verifier_scatter_op.mlir index ce18f90ebae..137a6609845 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/verifier_scatter_op.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/verifier_scatter_op.mlir @@ -481,7 +481,7 @@ func.func @invalid_scatter_return_type(%input_tensor: tensor<200x100x300xf32>, %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) -> tensor<200x100xf32> { - // expected-error @+1 {{expects the return type to be same as the operand type: 'tensor<200x100x300xf32>', but got 'tensor<200x100xf32>'.}} + // expected-error @+1 {{inferred type(s) 'tensor<200x100x300xf32>' are incompatible with return type(s) of operation 'tensor<200x100xf32>'}} %0 = "mhlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ ^bb0(%lhs: tensor, %rhs: tensor): %add = mhlo.add %lhs, %rhs : tensor diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/verifier_select_and_scatter_op.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/verifier_select_and_scatter_op.mlir index b5892a95a74..1eebafb939d 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/verifier_select_and_scatter_op.mlir +++ 
b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/verifier_select_and_scatter_op.mlir @@ -394,7 +394,7 @@ func.func @select_and_scatter_invalid_ret_type( %arg1: tensor<10x12x12x64xf32>) -> () { %0 = mhlo.constant dense<0.000000e+00> : tensor<f32> - // expected-error @+1 {{expects the return-type to match the operand-type, but got 'tensor<10x24x24x32xf32>' and 'tensor<10x24x24x64xf32>' resp.}} + // expected-error @+1 {{inferred type(s) 'tensor<10x24x24x64xf32>' are incompatible with return type(s) of operation 'tensor<10x24x24x32xf32>'}} %1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): %2 = "mhlo.compare"(%arg3, %arg4) { @@ -422,7 +422,7 @@ func.func @select_and_scatter_invalid_ret_type( %arg1: tensor<10x12x12x64xf32>) -> () { %0 = mhlo.constant dense<0.000000e+00> : tensor<f32> - // expected-error @+1 {{expects the return-type to match the operand-type, but got 'tensor<10x24x24x64xi32>' and 'tensor<10x24x24x64xf32>' resp.}} + // expected-error @+1 {{inferred type(s) 'tensor<10x24x24x64xf32>' are incompatible with return type(s) of operation 'tensor<10x24x24x64xi32>'}} %1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): %2 = "mhlo.compare"(%arg3, %arg4) { diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/verifier_while_op.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/verifier_while_op.mlir index e4dc05894f7..882f7cf66fc 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/verifier_while_op.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/mhlo/verifier_while_op.mlir @@ -121,7 +121,7 @@ func.func @while_with_invalid_tuples(%arg0: tensor<3xf32>) -> tensor<3xf32> { %cst_2 = arith.constant dense<1.00> : tensor<1xf32> %0 = "mhlo.tuple"(%arg0, %cst_2) : (tensor<3xf32>, tensor<1xf32>) -> tuple<tensor<3xf32>, tensor<1xf32>> %1 = "mhlo.tuple"(%cst_1, %0) : (tensor<2xi32>, tuple<tensor<3xf32>, tensor<1xf32>>) -> tuple<tensor<2xi32>, tuple<tensor<3xf32>, tensor<1xf32>>> - // expected-error @+1 {{op operand #1 must be tensor of 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 4/8/16/32-bit uniform quantized signed integer or 4/8/16/32-bit uniform quantized unsigned integer values or token}} + // expected-error @+1 {{op operand #1 must be tensor of f8E4M3FN type or f8E5M2 type or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or pred (AKA boolean or 1-bit integer) or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer or complex type with 32-bit float or 64-bit float elements or 4/8/16/32-bit uniform quantized signed integer or 4/8/16/32-bit uniform quantized unsigned integer values or token}} %2:2 = "mhlo.while"(%cst_0, %1) ({ ^bb0(%arg1: tensor<1xi32>, %arg2: tuple, tuple, tensor<3xf32>>>): %t0 = "mhlo.get_tuple_element"(%arg2) {index = 0 : i32} : (tuple, tuple, tensor<3xf32>>>) -> tensor<2xi32> diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/thlo/bufferize.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/thlo/bufferize.mlir index 860013026e4..5e88a2dd320 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/thlo/bufferize.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/thlo/bufferize.mlir @@ -8,7 +8,8 @@ func.func @sort(%input1: tensor<?x?xf32>, %input2: tensor<?x?xi32>, %sorted1, %sorted2 = thlo.sort ins(%input1: tensor<?x?xf32>, %input2: tensor<?x?xi32>) outs(%init1: tensor<?x?xf32>, %init2: tensor<?x?xi32>) - { dimension = 1 : i64, 
is_stable = true } + dimension = 1 + is_stable = true (%e11: f32, %e12: f32, %e21: i32, %e22: i32) { %gt = arith.cmpf ogt, %e11, %e12: f32 thlo.yield %gt : i1 @@ -28,12 +29,13 @@ func.func @sort(%input1: tensor, %input2: tensor, // CHECK-DAG: memref.copy %[[INIT2]], %[[OUTPUT2]] // CHECK: thlo.sort // CHECK-SAME: ins(%[[INPUT1]] : memref, -// CHECK-SAME: %[[INPUT2]] : memref) +// CHECK-SAME: %[[INPUT2]] : memref) // CHECK-SAME: outs(%[[OUTPUT1]] : memref, -// CHECK-SAME: %[[OUTPUT2]] : memref) -// CHECK-SAME: {dimension = 1 : i64, is_stable = true} -// CHECK-SAME: (%[[FLOAT1:[A-Za-z_0-9]*]]: f32, %[[FLOAT2:.*]]: f32, +// CHECK-SAME: %[[OUTPUT2]] : memref) +// CHECK-SAME: dimension = 1 +// CHECK-SAME: is_stable = true +// CHECK-NEXT: (%[[FLOAT1:[A-Za-z_0-9]*]]: f32, %[[FLOAT2:.*]]: f32, // CHECK-SAME: %[[INT1:[A-Za-z_0-9]*]]: i32, %[[INT2:.*]]: i32) // CHECK: %[[RESULT:.*]] = arith.cmpf ogt, %[[FLOAT1]], %[[FLOAT2]] : f32 // CHECK: thlo.yield %[[RESULT]] : i1 -// CHECK: return %[[OUTPUT1]], %[[OUTPUT2]] \ No newline at end of file +// CHECK: return %[[OUTPUT1]], %[[OUTPUT2]] diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/thlo/canonicalize.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/thlo/canonicalize.mlir new file mode 100644 index 00000000000..aaec8cc5b1e --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/thlo/canonicalize.mlir @@ -0,0 +1,15 @@ +// RUN: mlir-hlo-opt %s --split-input-file \ +// RUN: --canonicalize | FileCheck %s + +func.func @reverse_dynamic_fold(%input: tensor<1x?xf32>, %init: tensor<1x?xf32>) + -> tensor<1x?xf32> { + %res = thlo.reverse + ins(%input: tensor<1x?xf32>) + outs(%init: tensor<1x?xf32>) + reverse_dimensions = [0] + func.return %res : tensor<1x?xf32> +} + +// CHECK-LABEL: func @reverse_dynamic_fold +// CHECK-SAME: %[[ARG0:.*]]: tensor<1x?xf32>, %[[ARG1:.*]]: tensor<1x?xf32> +// CHECK: return %[[ARG0]] \ No newline at end of file diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/thlo/invalid.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/thlo/invalid.mlir index 4f1c8d8fca6..8582b40adab 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/thlo/invalid.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/thlo/invalid.mlir @@ -7,7 +7,7 @@ func.func @concatenate(%arg1: tensor, %cat = thlo.concatenate ins(%arg1: tensor, %arg2: tensor) outs(%dst: tensor) - { dimension = 0 : i64 } + dimension = 0 func.return %cat : tensor } @@ -20,7 +20,7 @@ func.func @concatenate_mismatch_rank(%arg1: tensor, %cat = thlo.concatenate ins(%arg1: tensor, %arg2: tensor) outs(%dst: tensor) - { dimension = 0 : i64 } + dimension = 0 func.return %cat : tensor } @@ -33,7 +33,7 @@ func.func @concatenate_mismatch_shape(%arg1: tensor, %cat = thlo.concatenate ins(%arg1: tensor, %arg2: tensor) outs(%dst: tensor) - { dimension = 0 : i64 } + dimension = 0 func.return %cat : tensor } @@ -210,7 +210,8 @@ func.func @sort_mismatched_number_of_inputs_and_outputs( %sorted = thlo.sort ins(%input1: tensor, %input2: tensor) outs(%init1: tensor) - { dimension = 0 : i64, is_stable = true } + dimension = 0 + is_stable = true (%e11: f32, %e12: f32) { %gt = arith.cmpf ogt, %e11, %e12: f32 thlo.yield %gt : i1 @@ -228,7 +229,8 @@ func.func @sort_mismatched_number_of_inputs_and_comparator_arguments( %sorted1, %sorted2 = thlo.sort ins(%input1: tensor, %input2: tensor) outs(%init1: tensor, %init2: tensor) - { dimension = 0 : i64, is_stable = true } + dimension = 0 + is_stable = true (%e11: f32, %e12: f32, %e21: i32) { %gt = arith.cmpf ogt, %e11, %e12: f32 
thlo.yield %gt : i1 @@ -246,7 +248,8 @@ func.func @sort_mismatched_input_and_comparator_type( %sorted1, %sorted2 = thlo.sort ins(%input1: tensor, %input2: tensor) outs(%init1: tensor, %init2: tensor) - { dimension = 0 : i64, is_stable = true } + dimension = 0 + is_stable = true (%e11: f32, %e12: f32, %e21: i32, %e22: f32) { %gt = arith.cmpf ogt, %e11, %e12: f32 thlo.yield %gt : i1 @@ -263,7 +266,8 @@ func.func @sort_comparator_yields_different_than_one_output( %sorted1, %sorted2 = thlo.sort ins(%input1: tensor, %input2: tensor) outs(%init1: tensor, %init2: tensor) - { dimension = 0 : i64, is_stable = true } + dimension = 0 + is_stable = true (%e11: f32, %e12: f32, %e21: i32, %e22: i32) { %gt = arith.cmpf ogt, %e11, %e12: f32 // expected-error@+1{{'thlo.yield' op expects number of tensor output args = 1 to match the number of yield operands = 2}} @@ -281,7 +285,8 @@ func.func @sort_comparator_yields_non_boolean( %sorted1, %sorted2 = thlo.sort ins(%input1: tensor, %input2: tensor) outs(%init1: tensor, %init2: tensor) - { dimension = 0 : i64, is_stable = true } + dimension = 0 + is_stable = true (%e11: f32, %e12: f32, %e21: i32, %e22: i32) { // expected-error@+1{{'thlo.yield' op expects yield operand 0 with type = 'f32' to match output arg element type = 'i1'}} thlo.yield %e11 : f32 @@ -299,7 +304,8 @@ func.func @sort_inputs_have_different_shapes( %sorted1, %sorted2 = thlo.sort ins(%input1: tensor<64x32xf32>, %input2: tensor<32x32xi32>) outs(%init1: tensor, %init2: tensor) - { dimension = 0 : i64, is_stable = true } + dimension = 0 + is_stable = true (%e11: f32, %e12: f32, %e21: i32, %e22: i32) { %gt = arith.cmpf ogt, %e11, %e12: f32 thlo.yield %gt : i1 @@ -317,7 +323,8 @@ func.func @sort_output_has_different_shape_from_inputs( %sorted1, %sorted2 = thlo.sort ins(%input1: tensor<64x32xf32>, %input2: tensor<64x32xi32>) outs(%init1: tensor<32x64xf32>, %init2: tensor) - { dimension = 0 : i64, is_stable = true } + dimension = 0 + is_stable = true (%e11: f32, %e12: f32, %e21: i32, %e22: i32) { %gt = arith.cmpf ogt, %e11, %e12: f32 thlo.yield %gt : i1 @@ -335,7 +342,8 @@ func.func @sort_dimension_is_incompatible_with_rank_of_inputs( %sorted1, %sorted2 = thlo.sort ins(%input1: tensor, %input2: tensor) outs(%init1: tensor, %init2: tensor) - { dimension = 2 : i64, is_stable = true } + dimension = 2 + is_stable = true (%e11: f32, %e12: f32, %e21: i32, %e22: i32) { %gt = arith.cmpf ogt, %e11, %e12: f32 thlo.yield %gt : i1 diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/thlo/legalize_sort.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/thlo/legalize_sort.mlir index f893096d564..7a7c90dfe8a 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/thlo/legalize_sort.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/thlo/legalize_sort.mlir @@ -5,7 +5,8 @@ func.func @sort(%input1: memref, %input2: memref, thlo.sort ins(%input1: memref, %input2: memref) outs(%init1: memref, %init2: memref) - { dimension = 0 : i64, is_stable = true } + dimension = 0 + is_stable = true (%e11: f32, %e12: f32, %e21: i32, %e22: i32) { %gt = arith.cmpf ogt, %e11, %e12: f32 thlo.yield %gt : i1 diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/thlo/ops.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/thlo/ops.mlir index 1292dc85d08..94bfedea885 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/thlo/ops.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/thlo/ops.mlir @@ -9,7 +9,7 @@ func.func @concatenate(%arg1: tensor, %cat = thlo.concatenate ins(%arg1: tensor, %arg2: 
tensor) outs(%dst: tensor) - { dimension = 0 : i64 } + dimension = 0 func.return %cat : tensor } // CHECK-LABEL: func @concatenate @@ -22,7 +22,7 @@ func.func @concatenate_memref(%arg1: memref, thlo.concatenate ins(%arg1: memref, %arg2: memref) outs(%dst: memref) - { dimension = 0 : i64 } + dimension = 0 func.return } // CHECK-LABEL: func @concatenate_memref @@ -112,7 +112,8 @@ func.func @sort(%input1: tensor, %input2: tensor, %sorted1, %sorted2 = thlo.sort ins(%input1: tensor, %input2: tensor) outs(%init1: tensor, %init2: tensor) - { dimension = 0 : i64, is_stable = true } + dimension = 0 + is_stable = true (%e11: f32, %e12: f32, %e21: i32, %e22: i32) { %gt = arith.cmpf ogt, %e11, %e12: f32 thlo.yield %gt : i1 @@ -120,6 +121,9 @@ func.func @sort(%input1: tensor, %input2: tensor, func.return %sorted1, %sorted2 : tensor, tensor } // CHECK-LABEL: func @sort +// CHECK: %[[RES1:sorted0]], %[[RES2:sorted1]] = thlo.sort +// CHECK: %[[LHS0:lhs0: f32]], %[[RHS0:rhs0: f32]], +// CHECK-SAME: %[[LHS1:lhs1: i32]], %[[RHS1:rhs1: i32]] // ----- @@ -128,7 +132,8 @@ func.func @sort_memref(%input1: memref, %input2: memref, thlo.sort ins(%input1: memref, %input2: memref) outs(%init1: memref, %init2: memref) - { dimension = 0 : i64, is_stable = true } + dimension = 0 + is_stable = true (%e11: f32, %e12: f32, %e21: i32, %e22: i32) { %gt = arith.cmpf ogt, %e11, %e12: f32 thlo.yield %gt : i1 @@ -136,3 +141,27 @@ func.func @sort_memref(%input1: memref, %input2: memref, func.return } // CHECK-LABEL: func @sort_memref + +// ----- + +func.func @reverse_static(%input: tensor<100xf32>, %init: tensor<100xf32>) + -> tensor<100xf32> { + %res = thlo.reverse + ins(%input: tensor<100xf32>) + outs(%init: tensor<100xf32>) + reverse_dimensions = [0] + func.return %res : tensor<100xf32> +} +// CHECK-LABEL: func @reverse_static + +// ----- + +func.func @reverse_dynamic(%input: tensor, %init: tensor) + -> tensor { + %res = thlo.reverse + ins(%input: tensor) + outs(%init: tensor) + reverse_dimensions = [0, 1] + func.return %res : tensor +} +// CHECK-LABEL: func @reverse_dynamic \ No newline at end of file diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/alloc_to_arg.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/alloc_to_arg.mlir index 1248d7b5a9e..f477ecc2292 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/alloc_to_arg.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/alloc_to_arg.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt --alloc-to-arg %s -verify-diagnostics -split-input-file \ +// RUN: mlir-hlo-opt %s --alloc-to-arg -verify-diagnostics -split-input-file -allow-unregistered-dialect \ // RUN: | FileCheck %s // CHECK-LABEL: func @alloc_to_arg @@ -12,6 +12,18 @@ func.func @alloc_to_arg(%arg0: memref<8xf32>) -> (memref<8xf32> {my.attr}) { // ----- func.func @not_alloc(%arg0: memref<8xf32>) -> memref<8xf32> { - // expected-error@+1 {{expected operand #0 to be defined by an memref.alloc}} + // expected-error@+1 {{expected operand #0 to be defined by (shape-expanded) memref.alloc}} return %arg0 : memref<8xf32> } + +// ----- + +// CHECK: @fusion(%[[ARG0:.*]]: memref<4x4x8x32xf32>) +func.func @fusion() -> memref<4x4x8x32xf32> { + // CHECK: %[[COLLAPSE_SHAPE:.*]] = memref.collapse_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3{{\]\]}} + // CHECK: "some.use"(%[[COLLAPSE_SHAPE]], %[[ARG0]]) + %alloc = memref.alloc() {alignment = 64 : i64} : memref<128x32xf32> + %expand_shape = memref.expand_shape %alloc [[0, 1, 2], [3]] : memref<128x32xf32> into memref<4x4x8x32xf32> + "some.use"(%alloc, %expand_shape) : (memref<128x32xf32>, 
memref<4x4x8x32xf32>) -> () + return %expand_shape : memref<4x4x8x32xf32> +} diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/bufferize_one_shot.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/bufferize_one_shot.mlir index 5c625ace62a..91d04a84872 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/bufferize_one_shot.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/bufferize_one_shot.mlir @@ -159,216 +159,6 @@ func.func @init_tensor_multiple_users(%lhs: tensor<10xf32>, // ----- -// CHECK-LABEL: func @tiled_dot -func.func @tiled_dot(%A: tensor<10xf32>, %B: tensor<10xf32>, - %C: tensor) -> tensor { - %c0 = arith.constant 0 : index - %c2 = arith.constant 2 : index - %c10 = arith.constant 10 : index - - %dot = gml_st.loop (%i) = (%c0) to (%c10) step (%c2) - ins (%A_ = %A: tensor<10xf32>, %B_ = %B: tensor<10xf32>) - outs (%C_ = %C: tensor) - iterators[#gml_st.iterator_type] { - %A_sub = tensor.extract_slice %A_[%i] [%c2] [1] - : tensor<10xf32> to tensor - %B_sub = tensor.extract_slice %B_[%i] [%c2] [1] - : tensor<10xf32> to tensor - %dot_sub = linalg.dot ins(%A_sub, %B_sub : tensor, tensor) - outs(%C_ : tensor) -> tensor - gml_st.yield %dot_sub : tensor - } - // CHECK: gml_st.loop - // CHECK-SAME: ins (%[[A:arg[0-9]]] = %{{arg[0-9]}}: memref<10xf32>, - // CHECK-SAME: %[[B:arg[0-9]]] = %{{arg[0-9]}}: memref<10xf32> - // CHECK-SAME: outs (%[[C:arg[0-9]]] = %{{arg[0-9]}}: memref) - - // CHECK-NEXT: %[[SV_A:.*]] = memref.subview %[[A]] - // CHECK-NEXT: %[[SV_B:.*]] = memref.subview %[[B]] - // CHECK-NEXT: linalg.dot ins(%[[SV_A]], %[[SV_B]] - // CHECK-SAME: outs(%[[C]] : memref) - // CHECK-NEXT: memref.copy - // CHECK-NEXT: gml_st.yield - func.return %dot : tensor -} - -// ----- - -#map0 = affine_map<(d0) -> (d0)> - -func.func @tiled_add(%A: tensor<10xf32>, %B: tensor<10xf32>, - %C: tensor<10xf32>) -> tensor<10xf32> { - %c0 = arith.constant 0 : index - %c2 = arith.constant 2 : index - %c10 = arith.constant 10 : index - - %sum = gml_st.loop (%i) = (%c0) to (%c10) step (%c2) - ins (%A_ = %A: tensor<10xf32>, %B_ = %B: tensor<10xf32>) - outs (%C_ = %C: tensor<10xf32>) { - %A_sub = tensor.extract_slice %A_[%i] [%c2] [1] - : tensor<10xf32> to tensor - %B_sub = tensor.extract_slice %B_[%i] [%c2] [1] - : tensor<10xf32> to tensor - %C_sub = tensor.extract_slice %C_[%i] [%c2] [1] - : tensor<10xf32> to tensor - %sum_sub = linalg.generic { - indexing_maps = [#map0, #map0, #map0], - iterator_types = ["parallel"] - } ins(%A_sub, %B_sub : tensor, tensor) - outs(%C_sub : tensor) { - ^bb0(%a: f32, %b: f32, %c: f32): - %0 = arith.addf %a, %b : f32 - linalg.yield %0 : f32 - } -> tensor - %update = tensor.insert_slice %sum_sub into %C_[%i] [%c2] [1] - : tensor into tensor<10xf32> - gml_st.yield %update : tensor<10xf32> - } - // CHECK: gml_st.loop - // CHECK-SAME: ins (%[[A:arg[0-9]]] = %{{arg[0-9]}}: memref<10xf32>, - // CHECK-SAME: %[[B:arg[0-9]]] = %{{arg[0-9]}}: memref<10xf32> - // CHECK-SAME: outs (%[[C:arg[0-9]]] = %{{arg[0-9]}}: memref<10xf32>) - - // CHECK-NEXT: %[[SV_A:.*]] = memref.subview %[[A]] - // CHECK-NEXT: %[[SV_B:.*]] = memref.subview %[[B]] - // CHECK-NEXT: %[[SV_C:.*]] = memref.subview %[[C]] - // CHECK-NEXT: linalg.generic - // CHECK-SAME: ins(%[[SV_A]], %[[SV_B]] - // CHECK-SAME: outs(%[[SV_C]] : memref<2xf32, strided{{.*}}>) - // CHECK: linalg.yield %{{[0-9]}} : f32 - // CHECK: gml_st.yield - func.return %sum : tensor<10xf32> -} - -// ----- - -func.func @tiled_add_broadcast(%A: tensor<1x?x12xf32>, %B: tensor, - %shape: tensor<3xi32>) -> tensor { - %c0 = arith.constant 0 : index - %c1 = 
arith.constant 1 : index - %c8 = arith.constant 8 : index - %cst = arith.constant 0.000000e+00 : f32 - %AA = "mhlo.dynamic_broadcast_in_dim"(%A, %shape) - {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} - : (tensor<1x?x12xf32>, tensor<3xi32>) -> tensor - - %d0 = tensor.dim %AA, %c0 : tensor - %d1 = tensor.dim %AA, %c1 : tensor - %sum = gml_st.loop (%i0, %i1, %i2) = (%c0, %c0, %c0) to (%d0, %d1, %c8) - step (%c1, %c1, %c8) - ins (%A_ = %AA: tensor) - outs (%B_ = %B: tensor) { - %v_in = vector.transfer_read %A_[%i0, %i1, %i2], %cst - {in_bounds = [true, true, true]} - : tensor, vector<1x1x8xf32> - %v_add = arith.addf %v_in, %v_in : vector<1x1x8xf32> - %v_out = vector.transfer_write %v_add, %B_[%i0, %i1, %i2] - {in_bounds = [true, true, true]} - : vector<1x1x8xf32>, tensor - gml_st.yield %v_out : tensor - } - // CHECK: gml_st.loop - // CHECK-SAME: ins (%[[A:arg[0-9]]] = %{{[0-9a-zA-Z_]+}}: memref) - // CHECK: memref.copy - func.return %sum : tensor -} - -// ----- - -#map0 = affine_map<()[s0] -> ((s0 floordiv 8) * 8)> -#map1 = affine_map<(d0)[s0] -> (-d0 + s0)> -#map2 = affine_map<(d0, d1) -> (d0, d1)> -func.func @init_tensor_multiple_users(%arg0: tensor<1x?xf32>) - -> (tensor<1x?xf32>, tensor<1x?xf32>) { - %cst = arith.constant 0.000000e+00 : f32 - %cst_0 = arith.constant dense<1.000000e+00> : vector<1x8xf32> - %cst_1 = arith.constant 1.000000e+00 : f32 - %c0 = arith.constant 0 : index - %c8 = arith.constant 8 : index - %c1 = arith.constant 1 : index - %0 = tensor.dim %arg0, %c1 : tensor<1x?xf32> - %init = bufferization.alloc_tensor(%0) : tensor<1x?xf32> - %2 = affine.apply #map0()[%0] - %3 = gml_st.loop (%i, %j) = (%c0, %c0) to (%c1, %2) step (%c1, %c8) - ins (%arg3 = %arg0: tensor<1x?xf32>) - outs (%arg4 = %init: tensor<1x?xf32>) { - %7 = vector.transfer_read %arg3[%i, %j], %cst {in_bounds = [true, true]} - : tensor<1x?xf32>, vector<1x8xf32> - %8 = arith.subf %cst_0, %7 : vector<1x8xf32> - %9 = vector.transfer_write %8, %arg4[%i, %j] {in_bounds = [true, true]} - : vector<1x8xf32>, tensor<1x?xf32> - gml_st.yield %9 : tensor<1x?xf32> - } - %4 = gml_st.loop (%i, %j) = (%c0, %2) to (%c1, %0) step (%c1, %c8) - ins (%arg3 = %arg0: tensor<1x?xf32>) - outs (%arg4 = %3: tensor<1x?xf32>) { - %7 = affine.apply #map1(%j)[%0] - %8 = tensor.extract_slice %arg3[%i, %j] [1, %7] [1, 1] - : tensor<1x?xf32> to tensor<1x?xf32> - %9 = tensor.extract_slice %arg4[%i, %j] [1, %7] [1, 1] - : tensor<1x?xf32> to tensor<1x?xf32> - %10 = linalg.generic { - indexing_maps = [#map2, #map2], - iterator_types = ["parallel", "parallel"]} - ins(%8 : tensor<1x?xf32>) outs(%9 : tensor<1x?xf32>) { - ^bb0(%arg5: f32, %arg6: f32): - %12 = arith.subf %cst_1, %arg5 : f32 - linalg.yield %12 : f32 - } -> tensor<1x?xf32> - %11 = tensor.insert_slice %10 into %arg4[%i, %j] [1, %7] [1, 1] - : tensor<1x?xf32> into tensor<1x?xf32> - gml_st.yield %11 : tensor<1x?xf32> - } - %5 = gml_st.loop (%i, %j) = (%c0, %c0) to (%c1, %2) step (%c1, %c8) - ins (%arg3 = %arg0: tensor<1x?xf32>) - outs (%arg4 = %init: tensor<1x?xf32>) { - %7 = vector.transfer_read %arg3[%i, %j], %cst - {in_bounds = [true, true]} : tensor<1x?xf32>, vector<1x8xf32> - %8 = arith.subf %cst_0, %7 : vector<1x8xf32> - %9 = arith.subf %cst_0, %8 : vector<1x8xf32> - %10 = vector.transfer_write %9, %arg4[%i, %j] - {in_bounds = [true, true]} : vector<1x8xf32>, tensor<1x?xf32> - gml_st.yield %10 : tensor<1x?xf32> - } - %6 = gml_st.loop (%i, %j) = (%c0, %2) to (%c1, %0) step (%c1, %c8) - ins (%arg3 = %arg0: tensor<1x?xf32>) - outs (%arg4 = %5: tensor<1x?xf32>) { - %7 = 
affine.apply #map1(%j)[%0] - %8 = tensor.extract_slice %arg3[%i, %j] [1, %7] [1, 1] - : tensor<1x?xf32> to tensor<1x?xf32> - %9 = tensor.extract_slice %arg4[%i, %j] [1, %7] [1, 1] - : tensor<1x?xf32> to tensor<1x?xf32> - %10 = linalg.generic { - indexing_maps = [#map2, #map2], - iterator_types = ["parallel", "parallel"]} - ins(%8 : tensor<1x?xf32>) outs(%9 : tensor<1x?xf32>) { - ^bb0(%arg5: f32, %arg6: f32): - %12 = arith.subf %cst_1, %arg5 : f32 - %13 = arith.subf %cst_1, %12 : f32 - linalg.yield %13 : f32 - } -> tensor<1x?xf32> - %11 = tensor.insert_slice %10 into %arg4[%i, %j] [1, %7] [1, 1] - : tensor<1x?xf32> into tensor<1x?xf32> - gml_st.yield %11 : tensor<1x?xf32> - } - func.return %4, %6 : tensor<1x?xf32>, tensor<1x?xf32> -} -// CHECK-LABEL: init_tensor_multiple_users -// CHECK: %[[BUF2:.*]] = memref.alloc -// CHECK: %[[BUF1:.*]] = memref.alloc -// CHECK: gml_st.loop -// CHECK: %[[BUF1]] -// CHECK: gml_st.loop -// CHECK: %[[BUF1]] -// CHECK: gml_st.loop -// CHECK: %[[BUF2]] -// CHECK: gml_st.loop -// CHECK: %[[BUF2]] -// CHECK: return %[[BUF1]], %[[BUF2]] - -// ----- - // Test that scf ops are bufferized // CHECK-LABEL: func @if( // CHECK-SAME: %[[PRED:.*]]: i1, diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/capi_test.c b/tensorflow/compiler/xla/mlir_hlo/tests/capi_test.c index 4ca78b69299..a37d4a8aab9 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/capi_test.c +++ b/tensorflow/compiler/xla/mlir_hlo/tests/capi_test.c @@ -12,7 +12,7 @@ limitations under the License. // This file checks that the MHLO CAPI can actually be compiled by a C compiler. // At the moment, this check is only implemented in the Bazel build. -#include "mlir-hlo-c/Attributes.h" -#include "mlir-hlo-c/Dialects.h" -#include "mlir-hlo-c/Passes.h" -#include "mlir-hlo-c/Types.h" +#include "bindings/c/Attributes.h" +#include "bindings/c/Dialects.h" +#include "bindings/c/Passes.h" +#include "bindings/c/Types.h" diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/gpu_fusion_rewrite.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/gpu_fusion_rewrite.mlir index 83eabdc919f..c90a82e4ccc 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/gpu_fusion_rewrite.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/gpu_fusion_rewrite.mlir @@ -25,15 +25,11 @@ func.func @log( // We do however, need some index computations to convert from warp and thread // indices to offset in input/output that that thread should operate on. // TODO(b/247482325): Optimize this better if needed. 
-// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) -// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) // CHECK-DAG: %[[C32:.*]] = llvm.mlir.constant(32 : index) // CHECK-DAG: %[[TIDX:.*]] = nvvm.read.ptx.sreg.tid.x // CHECK-DAG: %[[TIDY:.*]] = nvvm.read.ptx.sreg.tid.y -// CHECK-DAG: %[[TMP1:.*]] = llvm.mul %[[TIDY]], %[[C32]] -// CHECK-DAG: %[[TMP2:.*]] = llvm.mul %[[TMP1]], %[[C1]] -// CHECK-DAG: %[[WARPOFS:.*]] = llvm.add %[[TMP2]], %[[C0]] -// CHECK: llvm.add %[[WARPOFS]], %[[TIDX]] +// CHECK-DAG: %[[WARPOFS:.*]] = llvm.mul %[[TIDY]], %[[C32]] +// CHECK: llvm.add %[[TIDX]], %[[WARPOFS]] // CHECK-NOT: llvm.mul // CHECK-NOT: llvm.add // CHECK-LABEL: func.func @multidimensional @@ -192,3 +188,93 @@ func.func @softmax( }) {fusion_type = "softmax_fusion"} : () -> () "lmhlo.terminator"() : () -> () } + +// ----- + +// CHECK: gpu.container_module +// CHECK: gpu.module @fusion_kernel +// CHECK: llvm.func @fusion_kernel +// CHECK-SAME: gpu.kernel +// CHECK-LABEL: func.func @complete_softmax +// CHECK: gpu.launch_func @fusion_kernel::@fusion_kernel +func.func @complete_softmax( + %arg0: memref<640xi8> {lmhlo.params = 0 : index}, + %arg1: memref<640xi8> {lmhlo.output_index = dense<> : tensor<0xi64>} +) attributes {result_xla_shape = "f32[32,5]{1,0}"} { + %c0 = arith.constant 0 : index + %view = memref.view %arg0[%c0][] : memref<640xi8> to memref<32x5xf32> + %c0_0 = arith.constant 0 : index + %view_1 = memref.view %arg1[%c0_0][] : memref<640xi8> to memref<32x5xf32> + "lmhlo.fusion"() ({ + %0 = bufferization.to_tensor %view : memref<32x5xf32> + %1 = mhlo.constant dense<0xFF800000> : tensor<f32> + %2 = mhlo.reduce(%0 init: %1) across dimensions = [1] : (tensor<32x5xf32>, tensor<f32>) -> tensor<32xf32> + reducer(%arg4: tensor<f32>, %arg5: tensor<f32>) { + %10 = mhlo.maximum %arg4, %arg5 : tensor<f32> + mhlo.return %10 : tensor<f32> + } + %3 = "mhlo.broadcast_in_dim"(%2) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<32xf32>) -> tensor<32x5xf32> + %4 = mhlo.subtract %0, %3 : tensor<32x5xf32> + %5 = mhlo.exponential %4 : tensor<32x5xf32> + %6 = mhlo.constant dense<0.000000e+00> : tensor<f32> + %7 = mhlo.reduce(%5 init: %6) across dimensions = [1] : (tensor<32x5xf32>, tensor<f32>) -> tensor<32xf32> + reducer(%arg4: tensor<f32>, %arg5: tensor<f32>) { + %10 = mhlo.add %arg4, %arg5 : tensor<f32> + mhlo.return %10 : tensor<f32> + } + %8 = "mhlo.broadcast_in_dim"(%7) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<32xf32>) -> tensor<32x5xf32> + %9 = mhlo.divide %5, %8 : tensor<32x5xf32> + memref.tensor_store %9, %view_1 : memref<32x5xf32> + "lmhlo.terminator"() : () -> () + }) {fusion_type = "softmax_fusion"} : () -> () + "lmhlo.terminator"() : () -> () +} + +// ----- + +func.func @softmax_4d( + %arg0: memref<16384xi8> {lmhlo.params = 0 : index}, + %arg1: memref<4xi8> {lmhlo.params = 1 : index}, + %arg2: memref<16384xi8> {lmhlo.output_index = dense<> : tensor<0xi64>} +) attributes {result_xla_shape = "f32[128,32]{1,0}"} { + %c0 = arith.constant 0 : index + %view = memref.view %arg0[%c0][] : memref<16384xi8> to memref<4x4x8x32xf32> + %view2 = memref.view %arg1[%c0][] : memref<4xi8> to memref<f32> + %c0_0 = arith.constant 0 : index + %view_1 = memref.view %arg2[%c0_0][] + : memref<16384xi8> to memref<4x4x8x32xf32> + "lmhlo.fusion"() ({ + %0 = bufferization.to_tensor %view : memref<4x4x8x32xf32> + %1 = bufferization.to_tensor %view2 : memref<f32> + %2 = "mhlo.broadcast_in_dim"(%1) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<4x4x8x32xf32> + %3 = mhlo.add %0, %2 : tensor<4x4x8x32xf32> + %4 = mhlo.constant dense<0.0> : 
tensor + %5 = mhlo.reduce(%3 init: %4) across dimensions = [3] + : (tensor<4x4x8x32xf32>, tensor) -> tensor<4x4x8xf32> + reducer(%arg3: tensor, %arg4: tensor) { + %6 = mhlo.maximum %arg3, %arg4 : tensor + mhlo.return %6 : tensor + } + %7 = "mhlo.broadcast_in_dim"(%5) {broadcast_dimensions = dense<[0, 1, 2]> + : tensor<3xi64>} : (tensor<4x4x8xf32>) -> tensor<4x4x8x32xf32> + %8 = mhlo.subtract %3, %7 : tensor<4x4x8x32xf32> + memref.tensor_store %8, %view_1 : memref<4x4x8x32xf32> + "lmhlo.terminator"() : () -> () + }) {fusion_type = "softmax_fusion"} : () -> () + "lmhlo.terminator"() : () -> () +} + +// CHECK: module attributes {gpu.container_module} +// CHECK: gpu.module @fusion_kernel +// CHECK: llvm.func @fusion_kernel(%[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr) + +// CHECK: func.func @softmax_4d(%[[ARG_0:.*]]: memref<16384xi8> {lmhlo.params = 0 : index}, %[[ARG_1:.*]]: memref<4xi8> {lmhlo.params = 1 : index}, %[[ARG_2:.*]]: memref<16384xi8> {lmhlo.output_index = dense<> : tensor<0xi64>}) attributes {result_xla_shape = "f32[128,32]{1,0}"} +// CHECK: %[[VIEW:.*]] = memref.view %[[ARG_0]][%{{.*}}][] : memref<16384xi8> to memref<4x4x8x32xf32> +// CHECK: %[[VIEW_1:.*]] = memref.view %[[ARG_1]][%{{.*}}][] : memref<4xi8> to memref +// CHECK: %[[VIEW_2:.*]] = memref.view %[[ARG_2]][%{{.*}}][] : memref<16384xi8> to memref<4x4x8x32xf32> +// CHECK: %[[COLLAPSE_SHAPE:.*]] = memref.collapse_shape %[[VIEW_2]] [ +// CHECK-SAME: [0, 1, 2], [3]] : memref<4x4x8x32xf32> into memref<128x32xf32> +// CHECK: %[[COLLAPSE_SHAPE_1:.*]] = memref.collapse_shape %[[VIEW]] [ +// CHECK-SAME: [0, 1, 2], [3]] : memref<4x4x8x32xf32> into memref<128x32xf32> +// CHECK: gpu.launch_func @fusion_kernel::@fusion_kernel +// CHECK-SAME: args(%[[VIEW_1]] : memref, %[[COLLAPSE_SHAPE]] : memref<128x32xf32>, %[[COLLAPSE_SHAPE_1]] : memref<128x32xf32>) diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/hlo_to_gpu_pipeline.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/hlo_to_gpu_pipeline.mlir index d03bd9e6066..d86ff8dcabc 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/hlo_to_gpu_pipeline.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/hlo_to_gpu_pipeline.mlir @@ -19,8 +19,14 @@ func.func @simple_op(%arg0: memref<2048xf32>, %arg1: memref<2048xf32>) { "lmhlo.terminator"() : () -> () } // CHECK: gpu.module @[[MODULE]] attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry>} -// CHECK: llvm.func @[[KERNEL]]({{.*}}) attributes {gpu.kernel, nvvm.kernel} +// CHECK: llvm.func @[[KERNEL]]({{.*}}) attributes {gpu.kernel, gpu.known_block_size = array, gpu.known_grid_size = array, nvvm.kernel} // CHECK: llvm.call @__nv_logf +// Make sure we successfully unrolled the loop 4 times +// CHECK: llvm.call @__nv_logf +// CHECK: llvm.call @__nv_logf +// CHECK: llvm.call @__nv_logf +// CHECK-NOT: llvm.call @__nv_logf + // ----- @@ -42,10 +48,15 @@ func.func @fusion(%arg0: memref<2048xf32>, %arg1: memref<2048xf32>) { "lmhlo.terminator"() : () -> () } // CHECK: gpu.module @[[MODULE]] -// CHECK: llvm.func @[[KERNEL]]({{.*}}) attributes {gpu.kernel, nvvm.kernel} +// CHECK: llvm.func @[[KERNEL]]({{.*}}) attributes {gpu.kernel, gpu.known_block_size = array, gpu.known_grid_size = array, nvvm.kernel} // CHECK: %[[ABS:.*]] = llvm.call @__nv_fabsf // CHECK-NOT: llvm.return // CHECK: %[[ADD:.*]] = llvm.fadd %[[ABS]], %[[ABS]] +// Make sure we successfully unrolled the loop 4 times: +// CHECK: llvm.call @__nv_fabsf +// CHECK: llvm.call @__nv_fabsf +// CHECK: llvm.call @__nv_fabsf +// CHECK-NOT: llvm.call @__nv_fabsf // ----- @@ -67,5 +78,13 @@ 
func.func @imperfect_tiling(%arg0: memref<2051xf32>, %arg1: memref<2051xf32>) { "lmhlo.terminator"() : () -> () } // CHECK: gpu.module @[[MODULE]] attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry>} -// CHECK: llvm.func @[[KERNEL]]({{.*}}) attributes {gpu.kernel, nvvm.kernel} +// CHECK: llvm.func @[[KERNEL]]({{.*}}) attributes {gpu.kernel, gpu.known_block_size = array, gpu.known_grid_size = array, nvvm.kernel} +// CHECK: llvm.call @__nv_logf +// Make sure we successfully unrolled the loop 4 times: +// CHECK: llvm.call @__nv_logf +// CHECK: llvm.call @__nv_logf +// CHECK: llvm.call @__nv_logf +// ... and that we have an imperfect-tile loop at the end: +// CHECK: llvm.cond_br // CHECK: llvm.call @__nv_logf +// CHECK-NOT: llvm.call @__nv_logf diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/hlo_to_gpu_pipeline_softmax.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/hlo_to_gpu_pipeline_softmax.mlir index 21e80e6d577..312f850c790 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/hlo_to_gpu_pipeline_softmax.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/hlo_to_gpu_pipeline_softmax.mlir @@ -42,7 +42,7 @@ func.func @perfectly_tiled_softmax(%argbuffer : memref<2048x4096xf32>, "lmhlo.terminator"() : () -> () } // CHECK: gpu.module @[[MODULE]] -// CHECK: llvm.func @[[KERNEL]]({{.*}}) attributes {gpu.kernel, nvvm.kernel} +// CHECK: llvm.func @[[KERNEL]]({{.*}}) attributes {gpu.kernel, gpu.known_block_size = array, gpu.known_grid_size = array, nvvm.kernel} // CHECK: nvvm.shfl.sync bfly // CHECK: llvm.fcmp // CHECK: llvm.select @@ -92,4 +92,44 @@ func.func @imperfectly_tiled_softmax(%argbuffer : memref<2047x4095xf32>, "lmhlo.terminator"() : () -> () } // CHECK: gpu.module @[[MODULE]] -// CHECK: llvm.func @[[KERNEL]]({{.*}}) attributes {gpu.kernel, nvvm.kernel} +// CHECK: llvm.func @[[KERNEL]]({{.*}}) attributes {gpu.kernel, gpu.known_block_size = array, gpu.known_grid_size = array, nvvm.kernel} + +// ----- + +// CHECK-LABEL: @imperfectly_tiled_softmax_4d +// CHECK-SAME: %[[ARG0:.*]]: memref<6x4x2047x4095xf32>, %[[ARG1:.*]]: memref<6x4x2047x4095xf32> +func.func @imperfectly_tiled_softmax_4d(%argbuffer : memref<6x4x2047x4095xf32>, + %resbuffer : memref<6x4x2047x4095xf32>) { + %arg = bufferization.to_tensor %argbuffer : memref<6x4x2047x4095xf32> + %0 = mhlo.constant dense<-1> : tensor<1xi64> + %1 = mhlo.convert %arg : tensor<6x4x2047x4095xf32> + %2 = mhlo.constant dense<0xFF800000> : tensor + %3 = mhlo.reduce(%1 init: %2) applies mhlo.maximum across dimensions = [3] + : (tensor<6x4x2047x4095xf32>, tensor) -> tensor<6x4x2047xf32> + %4 = mhlo.convert %3 : tensor<6x4x2047xf32> + %cst = arith.constant dense<1> : tensor<1xi32> + %5 = mhlo.reshape %4 : (tensor<6x4x2047xf32>) -> tensor<6x4x2047x1xf32> + %6 = chlo.broadcast_subtract %arg, %5 + : (tensor<6x4x2047x4095xf32>, tensor<6x4x2047x1xf32>) + -> tensor<6x4x2047x4095xf32> + %7 = mhlo.exponential %6 : tensor<6x4x2047x4095xf32> + %8 = mhlo.convert %7 : tensor<6x4x2047x4095xf32> + %9 = mhlo.constant dense<-0.000000e+00> : tensor + %10 = mhlo.reduce(%8 init: %9) applies mhlo.add across dimensions = [3] + : (tensor<6x4x2047x4095xf32>, tensor) -> tensor<6x4x2047xf32> + %11 = mhlo.convert %10 : tensor<6x4x2047xf32> + %cst_0 = arith.constant dense<1> : tensor<1xi32> + %12 = mhlo.reshape %11 : (tensor<6x4x2047xf32>) -> tensor<6x4x2047x1xf32> + %13 = chlo.broadcast_divide %7, %12 + : (tensor<6x4x2047x4095xf32>, tensor<6x4x2047x1xf32>) + -> tensor<6x4x2047x4095xf32> + // CHECK: %[[COLLAPSE_SHAPE:.*]] = memref.collapse_shape %[[ARG0]] {{\[\[}}0, 1, 2], 
[3{{\]\]}} : memref<6x4x2047x4095xf32> into memref<49128x4095xf32> + // CHECK: %[[COLLAPSE_SHAPE_2:.*]] = memref.collapse_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3{{\]\]}} : memref<6x4x2047x4095xf32> into memref<49128x4095xf32> + // CHECK: gpu.launch_func @imperfectly_tiled_softmax_4d_kernel::@imperfectly_tiled_softmax_4d_kernel + // CHECK-SAME: args(%[[COLLAPSE_SHAPE_2]] : memref<49128x4095xf32>, %[[COLLAPSE_SHAPE]] : memref<49128x4095xf32>) + // CHECK: return + memref.tensor_store %13, %resbuffer : memref<6x4x2047x4095xf32> + "lmhlo.terminator"() : () -> () +} +// CHECK: gpu.module @imperfectly_tiled_softmax_4d_kernel +// CHECK: llvm.func @imperfectly_tiled_softmax_4d_kernel(%[[ARG0_0:.*]]: !llvm.ptr, %[[ARG1_0:.*]]: !llvm.ptr diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/hlo_to_triton_pipeline_softmax.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/hlo_to_triton_pipeline_softmax.mlir new file mode 100644 index 00000000000..b00eed66c3d --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/tests/hlo_to_triton_pipeline_softmax.mlir @@ -0,0 +1,56 @@ +// RUN: mlir-hlo-opt --split-input-file %s \ +// RUN: --hlo-to-triton-pipeline="block-tile=1" \ +// RUN: | FileCheck %s + +// CHECK: gpu.container_module +// CHECK-LABEL: @perfectly_tiled_softmax( +func.func @perfectly_tiled_softmax(%argbuffer : memref<2048x4096xf32>, + %resbuffer : memref<2048x4096xf32>) { + %arg = bufferization.to_tensor %argbuffer : memref<2048x4096xf32> + %0 = mhlo.constant dense<-1> : tensor<1xi64> + %1 = mhlo.convert %arg : tensor<2048x4096xf32> + %2 = mhlo.constant dense<0xFF800000> : tensor + %3 = mhlo.reduce(%1 init: %2) applies mhlo.maximum across dimensions = [1] + : (tensor<2048x4096xf32>, tensor) -> tensor<2048xf32> + %4 = mhlo.convert %3 : tensor<2048xf32> + %cst = arith.constant dense<1> : tensor<1xi32> + %5 = mhlo.reshape %4 : (tensor<2048xf32>) -> tensor<2048x1xf32> + %6 = chlo.broadcast_subtract %arg, %5 + : (tensor<2048x4096xf32>, tensor<2048x1xf32>) -> tensor<2048x4096xf32> + %7 = mhlo.exponential %6 : tensor<2048x4096xf32> + %8 = mhlo.convert %7 : tensor<2048x4096xf32> + %9 = mhlo.constant dense<-0.000000e+00> : tensor + %10 = mhlo.reduce(%8 init: %9) applies mhlo.add across dimensions = [1] + : (tensor<2048x4096xf32>, tensor) -> tensor<2048xf32> + %11 = mhlo.convert %10 : tensor<2048xf32> + %cst_0 = arith.constant dense<1> : tensor<1xi32> + %12 = mhlo.reshape %11 : (tensor<2048xf32>) -> tensor<2048x1xf32> + %13 = chlo.broadcast_divide %7, %12 + : (tensor<2048x4096xf32>, tensor<2048x1xf32>) -> tensor<2048x4096xf32> + // CHECK-DAG: %[[ONE:.*]] = arith.constant 1 + // CHECK-DAG: %[[GRID:.*]] = arith.constant 2048 + // CHECK: gpu.launch_func @[[MODULE:.*]]::@[[KERNEL:.*]] blocks + // CHECK-SAME: in (%[[GRID]], %[[ONE]], %[[ONE]]) + // CHECK-SAME: threads in (%[[ONE]], %[[ONE]], %[[ONE]]) + // CHECK-SAME: args({{.*}} : memref<2048x4096xf32>, + // CHECK-SAME: {{.*}} : memref<2048x4096xf32>) + memref.tensor_store %13, %resbuffer : memref<2048x4096xf32> + "lmhlo.terminator"() : () -> () +} +// CHECK: gpu.module @[[MODULE]] +// CHECK: gpu.func @[[KERNEL]] + /// TODO(b/261710844): This should be Triton Dialect. 
+// CHECK-SAME: %[[IN:.*]]: memref<2048x4096xf32>, +// CHECK-SAME: %[[OUT:.*]]: memref<2048x4096xf32>) kernel +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[BID:.*]] = gpu.block_id x +// CHECK: %[[VEC:.*]] = vector.transfer_read %[[IN]][%[[BID]], %[[C0]]] +// CHECK: %[[MAX:.*]] = vector.multi_reduction , %[[VEC]] +// CHECK: %[[BMAX:.*]] = vector.broadcast %[[MAX]] +// CHECK: %[[SUB:.*]] = arith.subf %[[VEC]], %[[BMAX]] +// CHECK: %[[EXP:.*]] = math.exp %[[SUB]] +// CHECK: %[[SUM:.*]] = vector.multi_reduction , %[[EXP]] +// CHECK: %[[BSUM:.*]] = vector.broadcast %[[SUM]] +// CHECK: %[[OUT_SV:.*]] = memref.subview %[[OUT]][%[[BID]], 0] [1, 4096] +// CHECK: %[[DIV:.*]] = arith.divf %[[EXP]], %[[BSUM]] +// CHECK: vector.transfer_write %[[DIV]], %[[OUT_SV]][%[[C0]], %[[C0]]] diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/index_type_llvm_lowering.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/index_type_llvm_lowering.mlir index 56fc8a77c57..35923372513 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/index_type_llvm_lowering.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/index_type_llvm_lowering.mlir @@ -1,16 +1,16 @@ // RUN: mlir-hlo-opt %s -gpu-kernel-to-nvvm | FileCheck %s gpu.module @test_module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry>} { - gpu.func @test_kernel() kernel { + gpu.func @test_kernel(%out: memref<32xf32>) kernel { %0 = gpu.block_id x + %cst = arith.constant 0.0 : f32 + memref.store %cst, %out[%0] : memref<32xf32> gpu.return } } // CHECK-LABEL: gpu.module @test_module // CHECK-SAME: attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry>} { -// CHECK-NEXT: llvm.func @test_kernel() attributes {gpu.kernel, nvvm.kernel} { -// CHECK-NEXT: %0 = nvvm.read.ptx.sreg.ctaid.x : i32 -// CHECK-NEXT: llvm.return -// CHECK-NEXT: } -// CHECK-NEXT: } +// CHECK-NEXT: llvm.func @test_kernel +// CHECK-SAME attributes {gpu.kernel, nvvm.kernel} +// CHECK: %[[VAR:.*]] = nvvm.read.ptx.sreg.ctaid.x : i32 diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/inline_fusion.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/inline_fusion.mlir deleted file mode 100644 index f8bc1c484af..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/tests/inline_fusion.mlir +++ /dev/null @@ -1,15 +0,0 @@ -// RUN: mlir-hlo-opt --inline-fusion %s | FileCheck %s - -// CHECK-LABEL: func @fusion -func.func @fusion(%arg0: tensor<8xf32>) -> tensor<8xf32> { - - // CHECK-NEXT: %[[EXP:.*]] = mhlo.exponential %arg0 - %0 = "mhlo.fusion"(%arg0) ({ - ^bb0(%arg1: tensor<8xf32>): - %1 = mhlo.exponential %arg1 : tensor<8xf32> - mhlo.return %1 : tensor<8xf32> - }) : (tensor<8xf32>) -> tensor<8xf32> - - // CHECK-NEXT: return %[[EXP]] : tensor<8xf32> - return %0 : tensor<8xf32> -} diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/lower_index_cast.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/lower_index_cast.mlir index 93ed4f277fa..8974b6ea8b0 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/lower_index_cast.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/lower_index_cast.mlir @@ -30,3 +30,31 @@ func.func @f(%arg0 : tensor) -> tensor { %0 = arith.index_cast %arg0 : tensor to tensor func.return %0 : tensor } + +// ----- + +// index_cast of dynamic multidimensional tensor +func.func @f(%arg0 : tensor<42x?xi32>) -> tensor<42x?xindex> { + // CHECK: %[[C1:.*]] = arith.constant 1 : index + // CHECK: %[[DIM:.*]] = tensor.dim %arg0, %[[C1]] : tensor<42x?xi32> + // CHECK: %[[TENSOR:.*]] = tensor.generate %[[DIM]] { + // CHECK: ^bb0(%arg1: index, %arg2: index): + // CHECK: %[[E:.*]] = tensor.extract 
%arg0[%arg1, %arg2] : tensor<42x?xi32> + // CHECK: %[[C:.*]] = arith.index_cast %[[E]] : i32 to index + // CHECK: tensor.yield %[[C]] : index + // CHECK: } : tensor<42x?xindex> + // CHECK: return %[[TENSOR]] : tensor<42x?xindex> + %0 = arith.index_cast %arg0 : tensor<42x?xi32> to tensor<42x?xindex> + func.return %0 : tensor<42x?xindex> +} + +// ----- + +// CHECK-LABEL: func @index_castui +func.func @index_castui(%arg0 : tensor<10xi32>) -> tensor<10xindex> { + // CHECK: tensor.generate { + // CHECK: %[[C:.*]] = arith.index_castui + // CHECK: tensor.yield %[[C]] : index + %0 = arith.index_castui %arg0 : tensor<10xi32> to tensor<10xindex> + func.return %0 : tensor<10xindex> +} diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/python/attributes.py b/tensorflow/compiler/xla/mlir_hlo/tests/python/attributes.py index d21a990531f..706dafc74b1 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/python/attributes.py +++ b/tensorflow/compiler/xla/mlir_hlo/tests/python/attributes.py @@ -14,226 +14,205 @@ # ============================================================================== """Test for Python APIs accessing MHLO attributes.""" -# pylint: disable=wildcard-import,undefined-variable +# pylint: disable=wildcard-import,undefined-variable,missing-function-docstring -from mlir.dialects.mhlo import * -from mlir.ir import * +from mlir import ir +from mlir.dialects import mhlo def run(f): - with Context() as context: - register_mhlo_dialect(context) + with ir.Context() as context: + mhlo.register_mhlo_dialect(context) f() return f @run -def test_scatter_dimension_numbers(): - """Check that ScatterDimensionNumbers attributes is available and usable.""" - - attr = ScatterDimensionNumbers.get( - update_window_dims=[1, 2, 3], - inserted_window_dims=[4, 6], - scattered_dims_to_operand_dims=[6, 7], - index_vector_dim=8) +def test_channel_handle(): + attr = mhlo.ChannelHandle.get(handle=1, type=2) assert attr is not None - assert str(attr) == ("#mhlo.scatter") - assert attr.update_window_dims == [1, 2, 3] - assert attr.inserted_window_dims == [4, 6] - assert attr.scattered_dims_to_operand_dims == [6, 7] - assert attr.index_vector_dim == 8 + assert attr.handle == 1 + assert attr.channel_type == 2 @run -def test_gather_dimension_numbers(): - """Check that GatherDimensionNumbers attributes is available and usable.""" - - attr = GatherDimensionNumbers.get( - offset_dims=[1, 2], - collapsed_slice_dims=[3, 4, 5], - start_index_map=[6], - index_vector_dim=7) +def test_comparison_direction_attr(): + attr = mhlo.ComparisonDirectionAttr.get("EQ") assert attr is not None - assert str(attr) == ("#mhlo.gather") - assert attr.offset_dims == [1, 2] - assert attr.collapsed_slice_dims == [3, 4, 5] - assert attr.start_index_map == [6] - assert attr.index_vector_dim == 7 + assert str(attr) == "#mhlo" + assert attr.value == "EQ" @run -def test_dot_dimension_numbers(): - """Check that DotDimensionNumbers attributes is available and usable.""" - - attr = DotDimensionNumbers.get( - lhs_batching_dimensions=[0, 1], - rhs_batching_dimensions=[2, 3], - lhs_contracting_dimensions=[4, 5], - rhs_contracting_dimensions=[6, 7]) +def test_comparison_type_attr(): + attr = mhlo.ComparisonTypeAttr.get("FLOAT") assert attr is not None - assert str(attr) == ("#mhlo.dot") - assert attr.lhs_batching_dimensions == [0, 1] - assert attr.rhs_batching_dimensions == [2, 3] - assert attr.lhs_contracting_dimensions == [4, 5] - assert attr.rhs_contracting_dimensions == [6, 7] + assert str(attr) == "#mhlo" + assert attr.value == "FLOAT" @run def 
test_conv_dimension_numbers(): - """Check that DotDimensionNumbers attributes is available and usable.""" - - attr = ConvDimensionNumbers.get( + attr = mhlo.ConvDimensionNumbers.get( input_batch_dimension=0, - input_feature_dimension=4, - input_spatial_dimensions=[1, 2, 3], - kernel_input_feature_dimension=1, - kernel_output_feature_dimension=2, - kernel_spatial_dimensions=[0, 3], - output_batch_dimension=1, - output_feature_dimension=3, - output_spatial_dimensions=[0, 2]) - assert str(attr) == "#mhlo.conv<[b, 0, 1, 2, f]x[0, i, o, 1]->[0, b, 1, f]>" + input_feature_dimension=1, + input_spatial_dimensions=[2, 3, 4], + kernel_input_feature_dimension=0, + kernel_output_feature_dimension=1, + kernel_spatial_dimensions=[2, 3], + output_batch_dimension=0, + output_feature_dimension=1, + output_spatial_dimensions=[2, 3], + ) + assert str(attr) == "#mhlo.conv<[b, f, 0, 1, 2]x[i, o, 0, 1]->[b, f, 0, 1]>" assert attr is not None assert attr.input_batch_dimension == 0 - assert attr.input_feature_dimension == 4 - assert attr.input_spatial_dimensions == [1, 2, 3] - assert attr.kernel_input_feature_dimension == 1 - assert attr.kernel_output_feature_dimension == 2 - assert attr.kernel_spatial_dimensions == [0, 3] - assert attr.output_batch_dimension == 1 - assert attr.output_feature_dimension == 3 - assert attr.output_spatial_dimensions == [0, 2] + assert attr.input_feature_dimension == 1 + assert attr.input_spatial_dimensions == [2, 3, 4] + assert attr.kernel_input_feature_dimension == 0 + assert attr.kernel_output_feature_dimension == 1 + assert attr.kernel_spatial_dimensions == [2, 3] + assert attr.output_batch_dimension == 0 + assert attr.output_feature_dimension == 1 + assert attr.output_spatial_dimensions == [2, 3] @run -def test_output_operand_alias(): - """Check that OutputOperandAlias attributes is available and usable.""" - - attr = OutputOperandAlias.get( - output_tuple_indices=[0], - operand_index=0, - operand_tuple_indices=[1]) - assert str(attr) == ("#mhlo.output_operand_alias") - assert attr.output_tuple_indices == [0] - assert attr.operand_index == 0 - assert attr.operand_tuple_indices == [1] +def test_dequantize_mode_attr(): + attr = mhlo.DequantizeModeAttr.get("MIN_COMBINED") + assert attr is not None + assert str(attr) == "#mhlo" + assert attr.value == "MIN_COMBINED" @run -def test_comparison_direction(): - """Check that ComparisonDirection attribute is available and usable.""" - - attr = ComparisonDirectionAttr.get("EQ") +def test_dot_dimension_numbers(): + attr = mhlo.DotDimensionNumbers.get( + lhs_batching_dimensions=[0, 1], + rhs_batching_dimensions=[2, 3], + lhs_contracting_dimensions=[4, 5], + rhs_contracting_dimensions=[6, 7], + ) assert attr is not None - assert str(attr) == ("#mhlo") - assert attr.comparison_direction == "EQ" + assert str(attr) == ("#mhlo.dot") + assert attr.lhs_batching_dimensions == [0, 1] + assert attr.rhs_batching_dimensions == [2, 3] + assert attr.lhs_contracting_dimensions == [4, 5] + assert attr.rhs_contracting_dimensions == [6, 7] @run -def test_comparison_type(): - """Check that ComparisonType attribute is available and usable.""" - - attr = ComparisonTypeAttr.get("TOTALORDER") +def test_fft_type_attr(): + attr = mhlo.FftTypeAttr.get("FFT") assert attr is not None - assert str(attr) == ("#mhlo") - assert attr.comparison_type == "TOTALORDER" + assert str(attr) == "#mhlo" + assert attr.value == "FFT" @run -def test_precision(): - """Check that Precision attribute is available and usable.""" - - attr = PrecisionAttr.get("DEFAULT") +def 
test_fusion_kind_attr(): + attr = mhlo.FusionKindAttr.get("kLoop") assert attr is not None - assert str(attr) == ("#mhlo") - assert attr.precision_type == "DEFAULT" + assert str(attr) == "#mhlo" + assert attr.value == "kLoop" @run -def test_fft_type(): - """Check that FftType attribute is available and usable.""" - - attr = FftTypeAttr.get("FFT") +def test_gather_dimension_numbers(): + attr = mhlo.GatherDimensionNumbers.get( + offset_dims=[1, 2], + collapsed_slice_dims=[3, 4, 5], + start_index_map=[6], + index_vector_dim=7, + ) assert attr is not None - assert str(attr) == ("#mhlo") - assert attr.fft_type == "FFT" + assert ( + str(attr) + == "#mhlo.gather" + ) + assert attr.offset_dims == [1, 2] + assert attr.collapsed_slice_dims == [3, 4, 5] + assert attr.start_index_map == [6] + assert attr.index_vector_dim == 7 @run -def test_dequantize_mode(): - """Check that DequantizeMode attribute is available and usable.""" - - attr = DequantizeModeAttr.get("MIN_COMBINED") +def test_output_operand_alias(): + attr = mhlo.OutputOperandAlias.get( + output_tuple_indices=[0], operand_index=0, operand_tuple_indices=[1] + ) assert attr is not None - assert str(attr) == ("#mhlo") - assert attr.dequantize_mode == "MIN_COMBINED" + assert str(attr) == ("#mhlo.output_operand_alias") + assert attr.output_tuple_indices == [0] + assert attr.operand_index == 0 + assert attr.operand_tuple_indices == [1] @run -def test_transpose_type(): - """Check that Transpose attribute is available and usable.""" - - attr = TransposeAttr.get("TRANSPOSE") +def test_precision_attr(): + attr = mhlo.PrecisionAttr.get("DEFAULT") assert attr is not None - assert str(attr) == ("#mhlo") - assert attr.transpose_type == "TRANSPOSE" + assert str(attr) == "#mhlo" + assert attr.value == "DEFAULT" @run -def test_fusion_kind(): - """Check that FusionKind attribute is available and usable.""" - - attr = FusionKindAttr.get("kLoop") +def test_rng_algorithm_attr(): + attr = mhlo.RngAlgorithmAttr.get("DEFAULT") assert attr is not None - assert str(attr) == ("#mhlo") - assert attr.fusion_kind == "kLoop" + assert str(attr) == "#mhlo.rng_algorithm" + assert attr.value == "DEFAULT" @run -def test_rng_distribution(): - """Check that RngDistribution attribute is available and usable.""" - - attr = RngDistributionAttr.get("UNIFORM") +def test_rng_distribution_attr(): + attr = mhlo.RngDistributionAttr.get("UNIFORM") assert attr is not None - assert str(attr) == ("#mhlo.rng_distribution") - assert attr.rng_distribution == "UNIFORM" + assert str(attr) == "#mhlo.rng_distribution" + assert attr.value == "UNIFORM" @run -def test_rng_algorithm(): - """Check that RngAlgorithm attribute is available and usable.""" - - attr = RngAlgorithmAttr.get("DEFAULT") +def test_scatter_dimension_numbers(): + attr = mhlo.ScatterDimensionNumbers.get( + update_window_dims=[1, 2, 3], + inserted_window_dims=[4, 5], + scattered_dims_to_operand_dims=[6, 7], + index_vector_dim=8, + ) assert attr is not None - assert str(attr) == ("#mhlo.rng_algorithm") - assert attr.rng_algorithm == "DEFAULT" + assert ( + str(attr) + == "#mhlo.scatter" + ) + assert attr.update_window_dims == [1, 2, 3] + assert attr.inserted_window_dims == [4, 5] + assert attr.scattered_dims_to_operand_dims == [6, 7] + assert attr.index_vector_dim == 8 @run -def test_channel_handle(): - """Check that ChannelHandle attribute is available and usable.""" - - attr = ChannelHandle.get(handle=1, type=2) +def test_transpose_attr(): + attr = mhlo.TransposeAttr.get("TRANSPOSE") assert attr is not None - assert attr.handle == 1 
- assert attr.channel_type == 2 + assert str(attr) == "#mhlo" + assert attr.value == "TRANSPOSE" @run def test_type_extensions(): - """Check that TypeExtensions attribute is available and usable.""" - - attr = TypeExtensions.get(bounds=[128, -1]) + dyn_size = ir.ShapedType.get_dynamic_size() + attr = mhlo.TypeExtensions.get(bounds=[128, dyn_size]) assert attr is not None - assert attr.bounds == [128, -1] + assert attr.bounds == [128, dyn_size] diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/python/types.py b/tensorflow/compiler/xla/mlir_hlo/tests/python/types.py index 8458fc5b70c..d66dd70cf54 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/python/types.py +++ b/tensorflow/compiler/xla/mlir_hlo/tests/python/types.py @@ -14,20 +14,21 @@ # ============================================================================== """Test for Python APIs accessing MHLO types.""" -# pylint: disable=wildcard-import,undefined-variable +# pylint: disable=wildcard-import,undefined-variable,missing-function-docstring -from mlir.dialects.mhlo import * -from mlir.ir import * +from mlir import ir +from mlir.dialects import mhlo def run(f): - with Context() as context: - register_mhlo_dialect(context) + with ir.Context() as context: + mhlo.register_mhlo_dialect(context) f() return f @run def test_token_type(): - """Check that the Token type is available.""" - assert str(TokenType.get()) == "!mhlo.token" + token_type = mhlo.TokenType.get() + assert token_type is not None + assert str(token_type) == "!mhlo.token" diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/scalarization.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/scalarization.mlir index affc8e8e105..4d350c8da71 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tests/scalarization.mlir +++ b/tensorflow/compiler/xla/mlir_hlo/tests/scalarization.mlir @@ -208,60 +208,73 @@ func.func @scatter_f32(%indices: tensor<1x2xindex>, // CHECK-DAG: %[[UPDATES_DIM_2:.*]] = tensor.dim %[[UPDATES]], %[[C2]] // CHECK-DAG: %[[INIT_DIM_0:.*]] = tensor.dim %[[INIT]], %[[C0]] // CHECK-DAG: %[[INIT_DIM_1:.*]] = tensor.dim %[[INIT]], %[[C1]] -// CHECK-DAG: %[[INIT_TILE:.*]] = gml_st.tile [0, 0] [%[[INIT_DIM_0]], %[[INIT_DIM_1]] // Extract scatter indices from `indices` arg. // CHECK-DAG: %[[INDEX_0:.*]] = tensor.extract %[[INDICES]][%[[C0]], // CHECK-DAG: %[[INDEX_1:.*]] = tensor.extract %[[INDICES]][%[[C0]], -// Iterate over indow dimensions.. -// CHECK-NEXT: %[[SCATTER:.*]] = gml_st.for (%[[I:.*]], %[[J:.*]]) = (%[[C0]], -// CHECK-SAME: %[[C0]]) to (%[[UPDATES_DIM_1]], %[[UPDATES_DIM_2]]) -// CHECK-SAME: step (%[[C1]], %[[C1]]) -// CHECK-SAME: outs (%[[INIT_:.*]] = %[[INIT]]: tensor) { - -// Check whetherthe index to update is not out-of-bounds. -// CHECK-NEXT: %[[I_PLUS_INDEX_0:.*]] = arith.addi %[[I]], %[[INDEX_0]] -// CHECK-NEXT: %[[J_PLUS_INDEX_1:.*]] = arith.addi %[[J]], %[[INDEX_1]] -// CHECK-NEXT: arith.cmpi sge, %[[I_PLUS_INDEX_0]], %[[C0]] -// CHECK-NEXT: arith.cmpi slt, %[[I_PLUS_INDEX_0]], %[[INIT_DIM_0]] +// Check bounds of the slice. 
+// CHECK-NEXT: %[[DIM_1_PLUS_INDEX_0:.*]] = arith.addi %[[UPDATES_DIM_1]], %[[INDEX_0]] +// CHECK-NEXT: %[[DIM_2_PLUS_INDEX_1:.*]] = arith.addi %[[UPDATES_DIM_2]], %[[INDEX_1]] +// CHECK-NEXT: %[[LIMIT_DIM_0:.*]] = arith.subi %[[DIM_1_PLUS_INDEX_0]], %[[C1]] +// CHECK-NEXT: %[[LIMIT_DIM_1:.*]] = arith.subi %[[DIM_2_PLUS_INDEX_1]], %[[C1]] +// CHECK-NEXT: arith.cmpi sge, %[[LIMIT_DIM_0]], %[[C0]] +// CHECK-NEXT: arith.cmpi slt, %[[LIMIT_DIM_0]], %[[INIT_DIM_0]] +// CHECK-NEXT: arith.andi +// CHECK-NEXT: arith.cmpi sge, %[[LIMIT_DIM_1]], %[[C0]] +// CHECK-NEXT: arith.cmpi slt, %[[LIMIT_DIM_1]], %[[INIT_DIM_1]] +// CHECK-NEXT: arith.andi +// CHECK-NEXT: arith.andi +// CHECK-NEXT: arith.cmpi sge, %[[INDEX_0]], %[[C0]] +// CHECK-NEXT: arith.cmpi slt, %[[INDEX_0]], %[[INIT_DIM_0]] +// CHECK-NEXT: arith.andi +// CHECK-NEXT: arith.cmpi sge, %[[INDEX_1]], %[[C0]] +// CHECK-NEXT: arith.cmpi slt, %[[INDEX_1]], %[[INIT_DIM_1]] // CHECK-NEXT: arith.andi -// CHECK-NEXT: arith.cmpi sge, %[[J_PLUS_INDEX_1]], %[[C0]] -// CHECK-NEXT: arith.cmpi slt, %[[J_PLUS_INDEX_1]], %[[INIT_DIM_1]] // CHECK-NEXT: arith.andi // CHECK-NEXT: %[[VALID_ACCESS:.*]] = arith.andi +// CHECK-NEXT: %[[RESULT:.*]] = scf.if %[[VALID_ACCESS]] + +// Iterate over window dimensions.. +// CHECK-NEXT: %[[SCATTER:.*]] = scf.for %[[K:.*]] = %[[C0]] +// CHECK-SAME: to %[[C1]] step %[[C1]] +// CHECK-SAME: iter_args(%[[INIT_:.*]] = %[[INIT]]) + +// CHECK-NEXT: %[[SCATTER_:.*]] = scf.for %[[I:.*]] = %[[C0]] +// CHECK-SAME: to %[[UPDATES_DIM_1]] step %[[C1]] +// CHECK-SAME: iter_args(%[[INIT__:.*]] = %[[INIT_]]) + +// CHECK-NEXT: %[[SCATTER__:.*]] = scf.for %[[J:.*]] = %[[C0]] +// CHECK-SAME: to %[[UPDATES_DIM_2]] step %[[C1]] +// CHECK-SAME: iter_args(%[[INIT___:.*]] = %[[INIT__]]) + +// CHECK-NEXT: %[[I_PLUS_INDEX_0:.*]] = arith.addi %[[I]], %[[INDEX_0]] +// CHECK-NEXT: %[[J_PLUS_INDEX_1:.*]] = arith.addi %[[J]], %[[INDEX_1]] // Extracts elements of `updates` and `init` tensors and combine. 
-// CHECK-NEXT: %[[INIT_AFTER_INSERTION:.*]] = scf.if %[[VALID_ACCESS]] -// CHECK-NEXT: %[[UPDATES_ELEM_TILE:.*]] = gml_st.tile -// CHECK-SAME: [%[[C0]], %[[I]], %[[J]]] [1, 1, 1] [1, 1, 1] -// CHECK-SAME: : !gml_st.tile<1x1x1> -// CHECK-NEXT: %[[UPDATES_ELEM:.*]] = gml_st.materialize %[[UPDATES]] -// CHECK-SAME: [%[[UPDATES_ELEM_TILE]]] -// CHECK-SAME: : tensor<1x?x?xf32>[!gml_st.tile<1x1x1>] to f32 - -// CHECK-NEXT: %[[INIT_ELEM_TILE:.*]] = gml_st.tile -// CHECK-SAME: [%[[I_PLUS_INDEX_0]], %[[J_PLUS_INDEX_1]]] [1, 1] [1, 1] -// CHECK-SAME: : !gml_st.tile<1x1> -// CHECK-NEXT: %[[INIT_ELEM:.*]] = gml_st.materialize %[[INIT_]] -// CHECK-SAME: [%[[INIT_ELEM_TILE]]] : tensor[!gml_st.tile<1x1>] to f32 - -// CHECK-NEXT: %[[COMBINED_ELEMS:.*]] = arith.addf %[[UPDATES_ELEM]], -// CHECK-SAME: %[[INIT_ELEM]] : f32 - -// CHECK-NEXT: %[[UPDATED_INIT:.*]] = tensor.insert %[[COMBINED_ELEMS]] -// CHECK-SAME: into %[[INIT_]][%[[I_PLUS_INDEX_0]], %[[J_PLUS_INDEX_1]]] -// CHECK-SAME: : tensor -// CHECK-NEXT: scf.yield %[[UPDATED_INIT]] : tensor +// CHECK-NEXT: %[[UPDATES_SLICE:.*]] = tensor.extract_slice %[[UPDATES]] +// CHECK-SAME: [%[[K]], %[[I]], %[[J]]] [1, 1, 1] [1, 1, 1] +// CHECK-SAME: : tensor<1x?x?xf32> to tensor<1x1x1xf32> +// CHECK-NEXT: %[[UPDATES_ELEM:.*]] = tensor.extract %[[UPDATES_SLICE]] + +// CHECK-NEXT: %[[INIT_SLICE:.*]] = tensor.extract_slice %[[INIT___]] +// CHECK-SAME: [%[[I_PLUS_INDEX_0]], %[[J_PLUS_INDEX_1]]] [1, 1] [1, 1] +// CHECK-SAME: : tensor to tensor<1x1xf32> +// CHECK-NEXT: %[[INIT_ELEM:.*]] = tensor.extract %[[INIT_SLICE]] + +// CHECK-NEXT: %[[COMBINED_ELEMS:.*]] = arith.addf %[[UPDATES_ELEM]], +// CHECK-SAME: %[[INIT_ELEM]] : f32 + +// CHECK-NEXT: %[[UPDATED_INIT:.*]] = tensor.insert %[[COMBINED_ELEMS]] +// CHECK-SAME: into %[[INIT___]][%[[I_PLUS_INDEX_0]], %[[J_PLUS_INDEX_1]]] +// CHECK-SAME: : tensor +// CHECK-NEXT: scf.yield %[[UPDATED_INIT]] : tensor +// CHECK: scf.yield %[[SCATTER_]] : tensor +// CHECK: scf.yield %[[SCATTER]] : tensor // CHECK-NEXT: } else { -// CHECK-NEXT: scf.yield %[[INIT_]] : tensor +// CHECK-NEXT: scf.yield %[[INIT]] : tensor // CHECK-NEXT: } - -// CHECK-NEXT: gml_st.set_yield %[[INIT_AFTER_INSERTION]] -// CHECK-SAME: into %[[INIT_]][%[[INIT_TILE]]] -// CHECK-SAME: : tensor into tensor[!gml_st.tile] -// CHECK-NEXT: } : tensor -// CHECK-NEXT: return %[[SCATTER]] : tensor +// CHECK-NEXT: return %[[RESULT]] : tensor // ----- @@ -286,74 +299,70 @@ func.func @scatter_i64(%indices: tensor<1x1xindex>, // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[INIT_TILE:.*]] = gml_st.tile [0, 0, 0] [3, 3, 4] -// CHECK-DAG: %[[INDEX_0:.*]] = tensor.extract %[[INDICES]] - -// CHECK: gml_st.for (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) -// CHECK-SAME: to (%[[C3]], %[[C4]]) step (%[[C1]], %[[C1]]) -// CHECK-SAME: outs (%[[INIT_:.*]] = %[[INIT]]: tensor<3x3x4xi64>) { -// CHECK: arith.cmpi sge, %[[INDEX_0]], %[[C0]] -// CHECK: arith.cmpi slt, %[[INDEX_0]], %[[C3]] -// CHECK: arith.andi -// CHECK: arith.cmpi sge, %[[I]], %[[C0]] -// CHECK: arith.cmpi slt, %[[I]], %[[C3]] -// CHECK: arith.andi -// CHECK: arith.andi -// CHECK: arith.cmpi sge, %[[J]], %[[C0]] -// CHECK: arith.cmpi slt, %[[J]], %[[C4]] -// CHECK: arith.andi -// CHECK: %[[VALID_ACCESS:.*]] = arith.andi - -// CHECK: %[[INIT_AFTER_INSERTION:.*]] = scf.if %[[VALID_ACCESS]] -// CHECK: %[[UPDATES_TILE:.*]] = gml_st.tile -// CHECK-SAME: [%[[C0]], %[[C0]], %[[I]], %[[J]]] -// CHECK-SAME: [1, 1, 1, 1] [1, 1, 1, 1] -// CHECK-SAME: : 
!gml_st.tile<1x1x1x1> -// CHECK: %[[UPDATES_ELEM:.*]] = gml_st.materialize %[[UPDATES]] -// CHECK-SAME: [%[[UPDATES_TILE]]] : tensor<1x1x3x4xi64>[{{.*}}] to i64 - -// CHECK: %[[UPDATED_INIT:.*]] = tensor.insert %[[UPDATES_ELEM]] into -// CHECK-SAME: %[[INIT_]][%[[INDEX_0]], %[[I]], %[[J]]] : tensor<3x3x4xi64> - -// CHECK: scf.yield %[[UPDATED_INIT]] : tensor<3x3x4xi64> -// CHECK: } else { -// CHECK: scf.yield %[[INIT_]] : tensor<3x3x4xi64> -// CHECK: } -// CHECK: gml_st.set_yield %[[INIT_AFTER_INSERTION:.*]] into %[[INIT_]] -// CHECK-SAME: [%[[INIT_TILE]]] : tensor<3x3x4xi64> -// CHECK-SAME: into tensor<3x3x4xi64>[!gml_st.tile<3x3x4>] -// CHECK: } : tensor<3x3x4xi64> +// CHECK-DAG: %[[INDEX_0:.*]] = tensor.extract %[[INDICES]] + +// CHECK: %[[SCATTER:.*]] = scf.for %[[K:.*]] = %[[C0]] +// CHECK-SAME: to %[[C1]] step %[[C1]] +// CHECK-SAME: iter_args(%[[INIT_:.*]] = %[[INIT]]) + +// CHECK-NEXT: %[[SCATTER_:.*]] = scf.for %[[L:.*]] = %[[C0]] +// CHECK-SAME: to %[[C1]] step %[[C1]] +// CHECK-SAME: iter_args(%[[INIT__:.*]] = %[[INIT_]]) + +// CHECK-NEXT: %[[SCATTER__:.*]] = scf.for %[[I:.*]] = %[[C0]] +// CHECK-SAME: to %[[C3]] step %[[C1]] +// CHECK-SAME: iter_args(%[[INIT___:.*]] = %[[INIT__]]) + +// CHECK-NEXT: %[[SCATTER___:.*]] = scf.for %[[J:.*]] = %[[C0]] +// CHECK-SAME: to %[[C4]] step %[[C1]] +// CHECK-SAME: iter_args(%[[INIT____:.*]] = %[[INIT___]]) + + +// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[L]], %[[INDEX_0]] +// CHECK: %[[UPDATES_SLICE:.*]] = tensor.extract_slice %[[UPDATES]] +// CHECK-SAME: [%[[K]], %[[L]], %[[I]], %[[J]]] +// CHECK-SAME: [1, 1, 1, 1] [1, 1, 1, 1] +// CHECK-SAME: : tensor<1x1x3x4xi64> to tensor<1x1x1x1xi64> +// CHECK-NEXT: %[[UPDATES_ELEM:.*]] = tensor.extract %[[UPDATES_SLICE]] + +// CHECK: %[[UPDATED_INIT:.*]] = tensor.insert %[[UPDATES_ELEM]] into +// CHECK-SAME: %[[INIT____]][%[[OFFSET]], %[[I]], %[[J]]] : tensor<3x3x4xi64> + +// CHECK-NEXT: scf.yield %[[UPDATED_INIT]] // ----- func.func @gather(%indices: tensor<1x2xindex>, - %operand: tensor<4x5x6xi64>, - %init: tensor<1x4xi64>) -> tensor<1x4xi64> { - %0 = thlo.gather ins(%operand : tensor<4x5x6xi64>, + %operand: tensor<5x6x7xi64>, + %init: tensor<1x3xi64>) -> tensor<1x3xi64> { + %0 = thlo.gather ins(%operand : tensor<5x6x7xi64>, %indices : tensor<1x2xindex>) - outs(%init : tensor<1x4xi64>) - func.return %0 : tensor<1x4xi64> + outs(%init : tensor<1x3xi64>) + func.return %0 : tensor<1x3xi64> } // CHECK-LABEL: func.func @gather( // CHECK-SAME: %[[INDICES:.*]]: tensor<1x2xindex> -// CHECK-SAME: %[[OPERAND:.*]]: tensor<4x5x6xi64> -// CHECK-SAME: %[[INIT:.*]]: tensor<1x4xi64> +// CHECK-SAME: %[[OPERAND:.*]]: tensor<5x6x7xi64> +// CHECK-SAME: %[[INIT:.*]]: tensor<1x3xi64> // CHECK-DAG: %[[C0:.*]] = arith.constant 0 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 // CHECK-DAG: %[[C3:.*]] = arith.constant 3 -// CHECK-DAG: %[[C4:.*]] = arith.constant 4 +// CHECK-DAG: %[[C5:.*]] = arith.constant 5 // CHECK-DAG: %[[INDEX0:.*]] = tensor.extract %[[INDICES]][%[[C0]], %[[C0]]] // CHECK-DAG: %[[INDEX1:.*]] = tensor.extract %[[INDICES]][%[[C0]], %[[C1]]] -// CHECK: gml_st.for (%[[J:.*]]) = (%[[C0]]) to (%[[C4]]) -// CHECK-DAG: %[[OFFSET_J:.*]] = arith.addi %[[J]], %[[INDEX1]] -// CHECK-DAG: %[[MIN_J:.*]] = arith.minsi %[[OFFSET_J]], %[[C4]] -// CHECK-DAG: %[[CLAMPED_J:.*]] = arith.maxsi %[[MIN_J]], %[[C0]] -// CHECK-DAG: %[[MIN_I:.*]] = arith.minsi %[[INDEX0]], %[[C3]] -// CHECK-DAG: %[[CLAMPED_I:.*]] = arith.maxsi %[[MIN_I]], %[[C0]] +// CHECK-DAG: %[[CLAMPED_INDEX0:.*]] = 
arith.minsi %[[INDEX0]], %[[C2]] +// CHECK-DAG: %[[CLAMPED_INDEX0_:.*]] = arith.maxsi %[[CLAMPED_INDEX0]], %[[C0]] +// CHECK-DAG: %[[CLAMPED_INDEX1:.*]] = arith.minsi %[[INDEX1]], %[[C5]] +// CHECK-DAG: %[[CLAMPED_INDEX1_:.*]] = arith.maxsi %[[CLAMPED_INDEX1]], %[[C0]] +// CHECK: gml_st.for (%[[J:.*]]) = (%[[C0]]) to (%[[C3]]) +// CHECK-DAG: %[[OFFSET_J:.*]] = arith.addi %[[J]], %[[CLAMPED_INDEX0_]] // CHECK: %[[INIT_TILE:.*]] = gml_st.tile [%[[C0]], %[[J]]] -// CHECK: %[[OPERAND_TILE:.*]] = gml_st.tile [%[[CLAMPED_I]], %[[CLAMPED_J]], %[[C0]]] -// CHECK: %[[VAL:.*]] = gml_st.materialize %[[OPERAND]][%[[OPERAND_TILE]]] + +// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[OPERAND]] +// CHECK-SAME: [%[[OFFSET_J]], %[[CLAMPED_INDEX1_]], 0] +// CHECK-NEXT: %[[VAL:.*]] = tensor.extract %[[SLICE]] // CHECK: gml_st.set_yield %[[VAL]] into {{.*}}[%[[INIT_TILE]]] // ----- @@ -365,26 +374,27 @@ func.func @fold_extract_from_elements_into_gml_st(%in: tensor<8x2xf32>, %c2 = arith.constant 2 : index %c8 = arith.constant 8 : index - %copy = gml_st.parallel (%i, %j) = (%c0, %c0) to (%c8, %c2) step (%c1, %c1) { - %tile = gml_st.tile [%i, %j] [1, 1] [1, 1] : !gml_st.tile<1x1> - - %in_sub = gml_st.materialize %in[%tile] - : tensor<8x2xf32>[!gml_st.tile<1x1>] to tensor<1x1xf32> + %copy = gml_st.parallel (%i, %j) = (%c0, %c0) to (%c8, %c2) step (%c1, %c1) + outs (%out_ = %out: tensor<8x2xf32>) { + %in_sub = tensor.extract_slice %in[%i, %j] [1, 1] [1, 1] + : tensor<8x2xf32> to tensor<1x1xf32> %elem = tensor.extract %in_sub[%c0, %c0] : tensor<1x1xf32> %out_sub = tensor.from_elements %elem : tensor<1x1xf32> - gml_st.set_yield %out_sub into %out[%tile] + %tile = gml_st.tile [%i, %j] [1, 1] [1, 1] : !gml_st.tile<1x1> + gml_st.set_yield %out_sub into %out_[%tile] : tensor<1x1xf32> into tensor<8x2xf32>[!gml_st.tile<1x1>] } : tensor<8x2xf32> func.return %copy: tensor<8x2xf32> } // CHECK-LABEL: func @fold_extract_from_elements_into_gml_st -// CHECK: = gml_st.tile -// CHECK-NEXT: %[[ELEM:.*]] = gml_st.materialize -// CHECK-SAME: : tensor<8x2xf32>[!gml_st.tile<1x1>] to f32 +// CHECK: %[[SLICE:.*]] = tensor.extract_slice +// CHECK-SAME: : tensor<8x2xf32> to tensor<1x1xf32> +// CHECK-NEXT: %[[ELEM:.*]] = tensor.extract %[[SLICE]] +// CHECK-NEXT: = gml_st.tile // CHECK-NEXT: gml_st.set_yield %[[ELEM]] // CHECK-SAME: : f32 into tensor<8x2xf32>[!gml_st.tile<1x1>] @@ -403,7 +413,7 @@ func.func @dynamic_broadcast_in_dim(%arg : tensor<1x1xf32>, // CHECK-SAME: %[[ARG:.*]]: tensor<1x1xf32>, %[[INIT:.*]]: tensor<1x1x1xf32>) // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK-NEXT: %[[ELEM:.*]] = tensor.extract %[[ARG]][%[[C0]], %[[C0]]] -// CHECK-NEXT: %[[UPDATED:.*]] = tensor.insert %[[ELEM]] into %[[INIT]][%[[C0]], %[[C0]], %[[C0]]] +// CHECK-NEXT: %[[UPDATED:.*]] = tensor.from_elements %[[ELEM]] // ----- @@ -415,7 +425,7 @@ func.func @concatenate( %arg1: tensor, %arg2: tensor) outs(%init: tensor) - { dimension = 1 : i64 } + dimension = 1 func.return %cat : tensor } @@ -432,14 +442,13 @@ func.func @concatenate( // CHECK-DAG: %[[DIM0:.*]] = tensor.dim %[[INIT]], %[[C0]] // CHECK-DAG: %[[DIM2:.*]] = tensor.dim %[[INIT]], %[[C2]] -// CHECK-NEXT: %[[TILE:.*]] = gml_st.tile [0, 0, 0] -// CHECK-SAME: [%[[DIM0]], 1, %[[DIM2]]] [1, 1, 1] // Extract elements from arg0 is it's not empty. 
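// The lowered concatenate checks each input's extent along the concatenation
// dimension and takes the slice from the first non-empty input, falling
// through to the next input via nested scf.if.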
// CHECK-NEXT: %[[DIM_ARG_0:.*]] = tensor.dim %[[ARG_0]], %[[C1]] // CHECK-NEXT: %[[CMP_0:.*]] = arith.cmpi ne, %[[DIM_ARG_0]], %[[C0]] // CHECK: %[[RESULT:.*]] = scf.if %[[CMP_0]] -// CHECK: %[[MAT_0:.*]] = gml_st.materialize %[[ARG_0]][%[[TILE]]] +// CHECK: %[[MAT_0:.*]] = tensor.extract_slice %[[ARG_0]] +// CHECK-SAME: [0, 0, 0] [%[[DIM0]], 1, %[[DIM2]]] [1, 1, 1] // CHECK: %[[RES_0:.*]] = tensor.insert_slice %[[MAT_0]] into %[[INIT]] // CHECK-NEXT: scf.yield %[[RES_0]] // CHECK-NEXT: } else { @@ -448,13 +457,15 @@ func.func @concatenate( // CHECK-NEXT: %[[DIM_ARG_1:.*]] = tensor.dim %[[ARG_1]], %[[C1]] // CHECK-NEXT: %[[CMP_1:.*]] = arith.cmpi ne, %[[DIM_ARG_1]], %[[C0]] // CHECK-NEXT: %[[RESULT_1:.*]] = scf.if %[[CMP_1]] -// CHECK-NEXT: %[[MAT_1:.*]] = gml_st.materialize %[[ARG_1]][%[[TILE]]] +// CHECK-NEXT: %[[MAT_1:.*]] = tensor.extract_slice %[[ARG_1]] +// CHECK-SAME: [0, 0, 0] [%[[DIM0]], 1, %[[DIM2]]] [1, 1, 1] // CHECK-NEXT: %[[RES_1:.*]] = tensor.insert_slice %[[MAT_1]] into %[[INIT]] // CHECK-NEXT: scf.yield %[[RES_1]] // CHECK-NEXT: } else { // Otherwise extract elements from arg2, because arg0 and arg1 are empty. -// CHECK-NEXT: %[[MAT_2:.*]] = gml_st.materialize %[[ARG_2]][%[[TILE]]] +// CHECK-NEXT: %[[MAT_2:.*]] = tensor.extract_slice %[[ARG_2]] +// CHECK-SAME: [0, 0, 0] [%[[DIM0]], 1, %[[DIM2]]] [1, 1, 1] // CHECK-NEXT: %[[RES_2:.*]] = tensor.insert_slice %[[MAT_2]] into %[[INIT]] // CHECK-NEXT: scf.yield %[[RES_2]] // CHECK-NEXT: } @@ -462,3 +473,150 @@ func.func @concatenate( // CHECK-NEXT: } // CHECK-NEXT: return %[[RESULT]] : tensor + +// ----- + +func.func @linalg_map(%lhs : tensor<1x1xf32>, + %rhs: tensor<1x1xf32>, + %init: tensor<1x1xf32>) + -> tensor<1x1xf32> { + %add = linalg.map + ins(%lhs, %rhs : tensor<1x1xf32>, tensor<1x1xf32>) + outs(%init: tensor<1x1xf32>) + (%lhs_elem: f32, %rhs_elem: f32) { + %0 = arith.addf %lhs_elem, %rhs_elem: f32 + linalg.yield %0: f32 + } + func.return %add : tensor<1x1xf32> +} + +// CHECK-LABEL: @linalg_map( +// CHECK-SAME: %[[LHS:.*]]: tensor<1x1xf32>, %[[RHS:.*]]: tensor<1x1xf32>, %[[INIT:.*]]: tensor<1x1xf32>) +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[L_ELEM:.*]] = tensor.extract %[[LHS]][%[[C0]], %[[C0]]] +// CHECK-NEXT: %[[R_ELEM:.*]] = tensor.extract %[[RHS]][%[[C0]], %[[C0]]] +// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[L_ELEM]], %[[R_ELEM]] +// CHECK-NEXT: tensor.from_elements %[[ADD]] + +// ----- + +func.func @linalg_reduce(%ins: tensor<1x1x1xf32>, + %outs: tensor<1x1xf32>) + -> tensor<1x1xf32> { + %reduce = linalg.reduce + ins(%ins: tensor<1x1x1xf32>) + outs(%outs: tensor<1x1xf32>) + dimensions = [1] + (%in: f32, %out: f32) { + %0 = arith.addf %in, %out: f32 + linalg.yield %0: f32 + } + func.return %reduce : tensor<1x1xf32> +} + +// CHECK-LABEL: @linalg_reduce( +// CHECK-SAME: %[[INS:.*]]: tensor<1x1x1xf32>, %[[OUTS:.*]]: tensor<1x1xf32>) +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[L_ELEM:.*]] = tensor.extract %[[INS]][%[[C0]], %[[C0]], %[[C0]]] +// CHECK-NEXT: %[[R_ELEM:.*]] = tensor.extract %[[OUTS]][%[[C0]], %[[C0]]] +// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[L_ELEM]], %[[R_ELEM]] +// CHECK-NEXT: tensor.from_elements %[[ADD]] + +// ----- + +func.func @linalg_transpose(%ins: tensor<1x1xf32>, + %outs: tensor<1x1xf32>) + -> tensor<1x1xf32> { + %transpose = linalg.transpose + ins(%ins: tensor<1x1xf32>) + outs(%outs: tensor<1x1xf32>) + permutation = [1, 0] + func.return %transpose : tensor<1x1xf32> +} + +// CHECK-LABEL: @linalg_transpose( +// CHECK-SAME: %[[INS:.*]]: 
tensor<1x1xf32>, %[[OUTS:.*]]: tensor<1x1xf32>) +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[EXTRACTED:.*]] = tensor.extract %[[INS]][%[[C0]], %[[C0]]] +// CHECK-NEXT: tensor.from_elements %[[EXTRACTED]] + +// ----- + +func.func @linalg_matmul(%lhs: tensor<1x1xf32>, + %rhs: tensor<1x1xf32>, + %out : tensor<1x1xf32>) -> tensor<1x1xf32> { + %0 = linalg.matmul + ins(%lhs, %rhs : tensor<1x1xf32>, tensor<1x1xf32>) + outs(%out : tensor<1x1xf32>) -> tensor<1x1xf32> + return %0 : tensor<1x1xf32> +} + +// CHECK-LABEL: @linalg_matmul( +// CHECK-SAME: %[[LHS:.*]]: tensor<1x1xf32>, %[[RHS:.*]]: tensor<1x1xf32>, %[[OUT:.*]]: tensor<1x1xf32>) +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[LHS_ELEM:.*]] = tensor.extract %[[LHS]][%[[C0]], %[[C0]]] +// CHECK-NEXT: %[[RHS_ELEM:.*]] = tensor.extract %[[RHS]][%[[C0]], %[[C0]]] +// CHECK-NEXT: %[[OUT_ELEM:.*]] = tensor.extract %[[OUT]][%[[C0]], %[[C0]]] +// CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[LHS_ELEM]], %[[RHS_ELEM]] +// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[OUT_ELEM]], %[[MUL]] +// CHECK-NEXT: tensor.from_elements %[[ADD]] + +// ----- + +func.func @thlo_reverse(%arg : tensor<1x1xf32>, %init: tensor<1x1xf32>) + -> tensor<1x1xf32> { + %0 = thlo.reverse ins(%arg : tensor<1x1xf32>) + outs(%init : tensor<1x1xf32>) + reverse_dimensions = [0, 1] + func.return %0 : tensor<1x1xf32> +} + +// CHECK-LABEL: @thlo_reverse( +// CHECK-SAME: %[[ARG:.*]]: tensor<1x1xf32>, %[[INIT:.*]]: tensor<1x1xf32>) +// CHECK: return %[[ARG]] + +// ----- + +func.func @ite_1d(%arg0: i1, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) + -> tensor<1xf32> { + %0 = scf.if %arg0 -> (tensor<1xf32>) { + scf.yield %arg2 : tensor<1xf32> + } else { + scf.yield %arg1 : tensor<1xf32> + } + return %0 : tensor<1xf32> +} + +// CHECK: func.func @ite_1d(%[[ARG0:.*]]: i1, %[[ARG1:.*]]: tensor<1xf32>, %[[ARG2:.*]]: tensor<1xf32>) +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[IF:.*]] = scf.if %[[ARG0]] -> (f32) +// CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[ARG2]][%[[C0]]] +// CHECK: scf.yield %[[EXTRACTED]] : f32 +// CHECK: else +// CHECK: %[[EXTRACTED_0:.*]] = tensor.extract %[[ARG1]][%[[C0]]] +// CHECK: scf.yield %[[EXTRACTED_0]] : f32 +// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[IF]] +// CHECK: return %[[FROM_ELEMENTS]] + +// ----- + +func.func @ite_2d(%arg0: i1, %arg1: tensor<1x1xf32>, %arg2: tensor<1x1xf32>) + -> tensor<1x1xf32> { + %0 = scf.if %arg0 -> (tensor<1x1xf32>) { + scf.yield %arg2 : tensor<1x1xf32> + } else { + scf.yield %arg1 : tensor<1x1xf32> + } + return %0 : tensor<1x1xf32> +} + +// CHECK: func.func @ite_2d(%[[ARG0:.*]]: i1, %[[ARG1:.*]]: tensor<1x1xf32>, %[[ARG2:.*]]: tensor<1x1xf32>) +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[IF:.*]] = scf.if %[[ARG0]] -> (f32) +// CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[ARG2]][%[[C0]], %[[C0]]] +// CHECK: scf.yield %[[EXTRACTED]] : f32 +// CHECK: else +// CHECK: %[[EXTRACTED_0:.*]] = tensor.extract %[[ARG1]][%[[C0]], %[[C0]]] +// CHECK: scf.yield %[[EXTRACTED_0]] : f32 +// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[IF]] +// CHECK: return %[[FROM_ELEMENTS]] diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/warp_reduce.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/warp_reduce.mlir deleted file mode 100644 index 019c62b2fca..00000000000 --- a/tensorflow/compiler/xla/mlir_hlo/tests/warp_reduce.mlir +++ /dev/null @@ -1,101 +0,0 @@ -// RUN: mlir-hlo-opt -split-input-file -gml-st-to-gpu %s | FileCheck %s - -// CHECK-LABEL: func 
@vector_reduce -func.func @vector_reduce(%arg0 : memref<1xf32>) { - - %c0 = arith.constant 0 : index - %cst = arith.constant 1.0 : f32 - %init = vector.broadcast %cst : f32 to vector<1xf32> - %lane = gpu.lane_id - %tile = gml_st.tile [%lane] [1] [1] : !gml_st.tile<1> - %bcast = vector.broadcast %cst : f32 to vector<1xf32> - - // CHECK: %[[CST:.*]] = arith.constant 1.0 - // CHECK: %[[Y0:.*]], %{{.*}} = gpu.shuffle xor %[[CST]], %c1 - // CHECK: %[[X1:.*]] = arith.addf %[[Y0]], %[[CST]] - // CHECK: %[[Y1:.*]], %{{.*}} = gpu.shuffle xor %[[X1]], %c2 - // CHECK: %[[X2:.*]] = arith.addf %[[X1]], %[[Y1]] - // CHECK: %[[Y2:.*]], %{{.*}} = gpu.shuffle xor %[[X2]], %c4 - // CHECK: %[[X3:.*]] = arith.addf %[[X2]], %[[Y2]] - // CHECK: %[[Y3:.*]], %{{.*}} = gpu.shuffle xor %[[X3]], %c8 - // CHECK: %[[X4:.*]] = arith.addf %[[X3]], %[[Y3]] - // CHECK: %[[Y4:.*]], %{{.*}} = gpu.shuffle xor %[[X4]], %c16 - // CHECK: %[[X5:.*]] = arith.addf %[[X4]], %[[Y4]] - // CHECK: %[[Y5:.*]] = arith.addf %[[X5]], %[[CST]] - // CHECK: %[[SUM:.*]] = vector.broadcast %[[Y5]] - %dist = gml_st.distribute %bcast into[%tile] - : vector<1xf32> into vector<1x32xf32>[!gml_st.tile<1>] - %sum = vector.multi_reduction , %dist, %init [1] - : vector<1x32xf32> to vector<1xf32> - // CHECK: vector.transfer_write %[[SUM]], %arg0[%c0] - vector.transfer_write %sum, %arg0[%c0] : vector<1xf32>, memref<1xf32> - - func.return -} - -// ----- - -// CHECK-LABEL: func @vector_reduce_small -func.func @vector_reduce_small() -> vector<1xf32> { - - %cst = arith.constant 1.0 : f32 - %init = vector.broadcast %cst : f32 to vector<1xf32> - %lane = gpu.lane_id - %tile = gml_st.tile [%lane] [1] [1] : !gml_st.tile<1> - %dist = gml_st.distribute %init into[%tile] - : vector<1xf32> into vector<1x4xf32>[!gml_st.tile<1>] - - // CHECK: %[[CST:.*]] = arith.constant 1.0 - // CHECK: %[[Y0:.*]], %{{.*}} = gpu.shuffle xor %[[CST]], %c1 - // CHECK: %[[X1:.*]] = arith.addf %[[Y0]], %[[CST]] - // CHECK: %[[Y1:.*]], %{{.*}} = gpu.shuffle xor %[[X1]], %c2 - // CHECK: %[[X2:.*]] = arith.addf %[[X1]], %[[Y1]] - // CHECK: %[[Y2:.*]] = arith.addf %[[X2]], %[[CST]] - // CHECK: %[[SUM:.*]] = vector.broadcast %[[Y2]] - %sum = vector.multi_reduction , %dist, %init [1] - : vector<1x4xf32> to vector<1xf32> - - // CHECK: return %[[SUM]] - func.return %sum : vector<1xf32> -} - -// ----- - -#stride1 = strided<[1], offset: ?> - -// CHECK-LABEL: func @gpu_launch -func.func @gpu_launch() -> memref<64xf32> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c4 = arith.constant 4 : index - %c32 = arith.constant 32 : index - %c64 = arith.constant 64 : index - %cst = arith.constant dense<0.0> : vector<1xf32> - %0 = memref.alloc() : memref<64xf32> - // CHECK: gpu.launch - gml_st.parallel (%arg1) = (%c0) to (%c64) step (%c4) { - %1 = memref.subview %0[%arg1] [4] [1] - : memref<64xf32> to memref<4xf32, #stride1> - gml_st.parallel (%arg2) = (%c0) to (%c4) step (%c1) { - %2 = memref.subview %1[%arg2] [1] [1] - : memref<4xf32, #stride1> to memref<1xf32, #stride1> - - %init = vector.broadcast %cst : vector<1xf32> to vector<1x32xf32> - %3 = gml_st.parallel (%arg3) = (%c0) to (%c32) step (%c1) { - %tile = gml_st.tile [0, %arg3] [1, 1] [1, 1] : !gml_st.tile<1x1> - %elem = arith.constant dense<1.0> : vector<1x1xf32> - gml_st.set_yield %elem into %init[%tile] - : vector<1x1xf32> into vector<1x32xf32>[!gml_st.tile<1x1>] - } : vector<1x32xf32> - - // CHECK-NOT: vector.multi_reduction - %sum = vector.multi_reduction , %3, %cst [1] - : vector<1x32xf32> to vector<1xf32> - 
vector.transfer_write %sum, %2[%c0] {in_bounds = [true]} - : vector<1xf32>, memref<1xf32, #stride1> - gml_st.set_yield - } - gml_st.set_yield - } - return %0 : memref<64xf32> -} diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/thlo/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/thlo/CMakeLists.txt similarity index 95% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/thlo/CMakeLists.txt rename to tensorflow/compiler/xla/mlir_hlo/thlo/CMakeLists.txt index 88672e5e298..c3f701100cc 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/thlo/CMakeLists.txt +++ b/tensorflow/compiler/xla/mlir_hlo/thlo/CMakeLists.txt @@ -14,3 +14,4 @@ add_subdirectory(IR) add_subdirectory(transforms) +add_subdirectory(interfaces) diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/thlo/IR/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/thlo/IR/CMakeLists.txt similarity index 73% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/thlo/IR/CMakeLists.txt rename to tensorflow/compiler/xla/mlir_hlo/thlo/IR/CMakeLists.txt index 94e3aa5acd0..5a4f76478c2 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/thlo/IR/CMakeLists.txt +++ b/tensorflow/compiler/xla/mlir_hlo/thlo/IR/CMakeLists.txt @@ -20,3 +20,24 @@ mlir_tablegen(thlo_dialect.cc.inc -gen-dialect-defs) add_public_tablegen_target(MLIRthlo_opsIncGen) add_dependencies(mlir-headers MLIRthlo_opsIncGen) + + +include_directories(BEFORE + ${CMAKE_CURRENT_BINARY_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}) + +add_mlir_dialect_library(THLODialect + thlo_ops.cc + + DEPENDS + MLIRthlo_opsIncGen + + LINK_LIBS PUBLIC + GmlStDialect + MLIRDestinationStyleOpInterface + MLIRIR + MLIRMemRefDialect + MLIRSideEffectInterfaces + MLIRSupport + MLIRTensorDialect +) diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/thlo/IR/thlo_ops.cc b/tensorflow/compiler/xla/mlir_hlo/thlo/IR/thlo_ops.cc similarity index 59% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/thlo/IR/thlo_ops.cc rename to tensorflow/compiler/xla/mlir_hlo/thlo/IR/thlo_ops.cc index db65d0bc878..5a038349bb2 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/thlo/IR/thlo_ops.cc +++ b/tensorflow/compiler/xla/mlir_hlo/thlo/IR/thlo_ops.cc @@ -13,38 +13,52 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "mlir-hlo/Dialect/thlo/IR/thlo_ops.h" +#include "thlo/IR/thlo_ops.h" #include +#include #include #include #include +#include #include #include #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" -#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h" -#include "mlir-hlo/Dialect/gml_st/transforms/tiling_interface.h" -#include "mlir-hlo/Dialect/thlo/IR/thlo_ops.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" -#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Utils/Utils.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpImplementation.h" -#include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h" +#include "mlir/Interfaces/TilingInterface.h" namespace mlir { namespace { +Value materializeSlice(OpBuilder &b, Location loc, Value valueToTile, + ArrayRef offsets, + ArrayRef sizes, + ArrayRef strides) { + return b.create(loc, valueToTile, offsets, sizes, + strides); +} + +Value materializeSlice(OpBuilder &b, Location loc, Value valueToTile, + ArrayRef offsets, + ArrayRef sizes) { + SmallVector strides(offsets.size(), b.getIndexAttr(1)); + return materializeSlice(b, loc, valueToTile, offsets, sizes, strides); +} + //===----------------------------------------------------------------------===// // Destination-style ops tools //===----------------------------------------------------------------------===// @@ -79,7 +93,10 @@ void printDstStyleOp( // Print attributes with custom printing logic. SmallVector elidedAttrs; - if (printAttrsFn) elidedAttrs = printAttrsFn(op, p); + if (printAttrsFn) { + p << ' '; + elidedAttrs = printAttrsFn(op, p); + } p.printOptionalAttrDict(op->getAttrs(), elidedAttrs); } @@ -146,11 +163,7 @@ ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef attributeValue) { - p << " " << attributeName << " = [" << attributeValue << "] "; -} - -bool dimensionsMatch(int64_t d1, int64_t d2) { - return ShapedType::isDynamic(d1) || ShapedType::isDynamic(d2) || d1 == d2; + p << attributeName << " = [" << attributeValue << "] "; } SmallVector getParallelIteratorTypes(int64_t dimCount) { @@ -168,20 +181,11 @@ SmallVector getIterationDomainForTensor(OpBuilder &b, Location loc, })); } -Value getMaterializedTile(OpBuilder &b, Location loc, - TypedValue tensor, - ArrayRef offsets, - ArrayRef sizes) { - SmallVector strides(offsets.size(), b.getIndexAttr(1)); - Value tile = b.create(loc, offsets, sizes, strides); - return b.create(loc, tensor, tile); -} - } // namespace } // namespace mlir // Generated dialect definitions. 
-#include "mlir-hlo/Dialect/thlo/IR/thlo_dialect.cc.inc" +#include "thlo/IR/thlo_dialect.cc.inc" namespace mlir { namespace thlo { @@ -189,7 +193,7 @@ namespace thlo { void THLODialect::initialize() { addOperations< #define GET_OP_LIST -#include "mlir-hlo/Dialect/thlo/IR/thlo_ops.cc.inc" +#include "thlo/IR/thlo_ops.cc.inc" >(); } @@ -226,243 +230,217 @@ LogicalResult YieldOp::verify() { return success(); } // ConcatenateOp //===----------------------------------------------------------------------===// -namespace { - -gml_st::TileOp createTileOp(OpBuilder &b, Location loc, Value tensor, - ArrayRef offsets, - ArrayRef sizes) { - auto initTy = tensor.getType().cast(); - SmallVector unitStrides(initTy.getRank(), b.getIndexAttr(1)); - return b.create(loc, offsets, sizes, unitStrides); -} - -} // namespace - SmallVector ConcatenateOp::getLoopIteratorTypes() { return getParallelIteratorTypes(getInit().getType().getRank()); } -SmallVector ConcatenateOp::getDestinationOperands(OpBuilder &) { - return {getInit()}; -} - SmallVector ConcatenateOp::getIterationDomain(OpBuilder &b) { return getIterationDomainForTensor(b, getLoc(), getInit()); } namespace { -// TODO(frgossen): Fuse this as a switch statement if all the operands are unit -// size in the concatenation dimension. -Value fuseConcatenateOpThroughTile(ConcatenateOp op, OpBuilder &builder, - Location loc, Value tile) { - uint64_t concatDim = op.getDimension(); - RankedTensorType resultTy = op.getType(0).cast(); - int64_t rank = resultTy.getRank(); - OperandRange allOperands = op.getInputs(); - Value anyOperand = allOperands.front(); - - // Create the shared tile strides, which are the exact same for every operand - // tile. Also create a basis for the space sizes, tile offsets, and tile - // sizes. These hold the shared values in all non-concat dimensions and can be - // amended in the concat dimension to create the individual operand tiles. - SmallVector sharedTileStrides(rank); - SmallVector baseSpaceSizes(rank); - SmallVector baseTileOffsets(rank); - SmallVector baseTileSizes(rank); - auto tileOp = tile.getDefiningOp(); - auto tileOffsets = - getValueOrCreateConstantIndexOp(builder, loc, tileOp.getMixedOffsets()); - auto tileSizes = - getValueOrCreateConstantIndexOp(builder, loc, tileOp.getMixedSizes()); - auto tileStrides = - getValueOrCreateConstantIndexOp(builder, loc, tileOp.getMixedStrides()); - for (int64_t i = 0; i < rank; ++i) { - Value iCst = builder.create(loc, i); - sharedTileStrides[i] = tileStrides[i]; - - // The space sizes, tile offsets, and tile sizes differ in the concat - // dimension. Do not populate these. - if (i == static_cast(concatDim)) continue; - - baseSpaceSizes[i] = - builder.createOrFold(loc, anyOperand, iCst); - baseTileOffsets[i] = tileOffsets[i]; - baseTileSizes[i] = tileSizes[i]; +Value getSingleOperandTiledImplementationForConcatRecursively( + OpBuilder &b, Location loc, int64_t concatDim, ValueRange remainingOperands, + SmallVector &remainingOffsets, ArrayRef sizes) { + assert(!remainingOperands.empty() && "expect at least one remaining operand"); + assert(sizes[concatDim].get().cast().getInt() == 1 && + "expect unit size in concat dim"); + + // Terminal case of exactly one operand. + Value leadingOperand = remainingOperands.front(); + if (remainingOperands.size() == 1) { + return materializeSlice(b, loc, leadingOperand, remainingOffsets, sizes); } + // For more than one operand, distinguish between the leading operand and the + // remainder. 
+ assert(remainingOperands.size() > 1 && + "expect more than one operand at this point"); + Value leadingOperandSizeInConcatDim = + b.create(loc, leadingOperand, concatDim); + Value remainingOffsetInConcatDim = + getValueOrCreateConstantIndexOp(b, loc, remainingOffsets[concatDim]); + Value leadingOperandPredicate = b.create( + loc, arith::CmpIPredicate::ult, remainingOffsetInConcatDim, + leadingOperandSizeInConcatDim); + auto ifOp = b.create( + loc, leadingOperandPredicate, + [&](OpBuilder &b, Location loc) { + Value tiledConcat = + getSingleOperandTiledImplementationForConcatRecursively( + b, loc, concatDim, {leadingOperand}, remainingOffsets, sizes); + b.create(loc, tiledConcat); + }, + [&](OpBuilder &b, Location loc) { + remainingOffsets[concatDim] = + b.create(loc, remainingOffsetInConcatDim, + leadingOperandSizeInConcatDim) + .getResult(); + Value tiledConcat = + getSingleOperandTiledImplementationForConcatRecursively( + b, loc, concatDim, remainingOperands.drop_front(), + remainingOffsets, sizes); + b.create(loc, tiledConcat); + }); + return ifOp.getResults().front(); +} + +Value getSingleOperandTiledImplementationForConcat( + ConcatenateOp op, OpBuilder &b, Location loc, + ArrayRef offsets, ArrayRef sizes) { + int64_t concatDim = op.getDimension().getSExtValue(); + SmallVector remainingOffsets(offsets); + return getSingleOperandTiledImplementationForConcatRecursively( + b, loc, concatDim, op.getInputs(), remainingOffsets, sizes); +} + +Value getGenericTiledImplementationForConcat(ConcatenateOp op, OpBuilder &b, + Location loc, + ArrayRef offsets, + ArrayRef sizes) { + // Create a basis for the tile offsets and sizes. These hold the shared values + // in all non-concat dimensions and are amended in the concat dimension to + // create the individual operand tiles. Also, create the shared tile strides, + // which are the exact same for every operand tile. + SmallVector operandTileOffsetsBase(offsets); + SmallVector operandTileSizesBase(sizes); + SmallVector operandTileStrides(sizes.size(), b.getIndexAttr(1)); + // Some shared values. - ArrayAttr allDynamicStridesOrOffsetsAttr = builder.getI64ArrayAttr( - SmallVector(rank, ShapedType::kDynamicStrideOrOffset)); - ArrayAttr allDynamicSizesAttr = builder.getI64ArrayAttr( - SmallVector(rank, ShapedType::kDynamicSize)); - Value zeroCst = builder.create(loc, 0); - Value concatDimCst = builder.create(loc, concatDim); - Value maxTileSizeInConcatDim = tileSizes[concatDim]; + Value zeroCst = b.create(loc, 0); + int64_t concatDim = op.getDimension().getSExtValue(); + Value concatDimCst = b.create(loc, concatDim); + Value maxTileSizeInConcatDim = + getValueOrCreateConstantIndexOp(b, loc, sizes[concatDim]); // The remaining tile offset in the concat dimension is subtracted by each // operand's size in that dimension. We maintain the invariant // remainingTileOffsetInConcatDim >= 0. - Value remainingTileOffsetInConcatDim = tileOffsets[concatDim]; + Value remainingTileOffsetInConcatDim = + getValueOrCreateConstantIndexOp(b, loc, offsets[concatDim]); // Create the relevant subsets per operand. These tiles can be empty at // runtime. - SmallVector subOperands; - subOperands.reserve(allOperands.size()); - for (Value operand : allOperands) { - // Create operand space. 
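// Each operand's tile offset along the concat dimension is the remaining
// offset clamped to that operand's extent, and its tile size is the operand's
// remaining extent clamped to the requested tile size, so an individual
// operand tile may turn out to be empty.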
- Value operandSizeInConcatDim = - builder.create(loc, operand, concatDimCst); - baseSpaceSizes[concatDim] = operandSizeInConcatDim; - + SmallVector tiledOperands; + tiledOperands.reserve(op.getNumDpsInputs()); + for (Value operand : op.getInputs()) { // Find the current operand's tile offset in the concat dimension. This is // the remaining offset clamped into the bounds of the operand. Note that // the remaining offset is always >= 0. - Value operandTileOffsetInConcatDim = builder.create( + Value operandSizeInConcatDim = + b.create(loc, operand, concatDimCst); + Value operandTileOffsetInConcatDim = b.create( loc, remainingTileOffsetInConcatDim, operandSizeInConcatDim); - baseTileOffsets[concatDim] = operandTileOffsetInConcatDim; + operandTileOffsetsBase[concatDim] = operandTileOffsetInConcatDim; // Find the current operand's tile size in the concat dimension. - Value remainingOperandSizeInConcatDim = builder.create( + Value remainingOperandSizeInConcatDim = b.create( loc, operandSizeInConcatDim, operandTileOffsetInConcatDim); - baseTileSizes[concatDim] = builder.create( + operandTileSizesBase[concatDim] = b.createOrFold( loc, remainingOperandSizeInConcatDim, maxTileSizeInConcatDim); // Create the operand tile and materialize the subset for this operand. - Value tile = builder.create( - loc, baseTileOffsets, baseTileSizes, sharedTileStrides, - allDynamicStridesOrOffsetsAttr, allDynamicSizesAttr, - allDynamicStridesOrOffsetsAttr); - subOperands.push_back( - builder.create(loc, operand, tile)); + tiledOperands.push_back( + materializeSlice(b, loc, operand, operandTileOffsetsBase, + operandTileSizesBase, operandTileStrides)); // Unless it is the last operand, update the remaining tile offset in the // concat dimension. The remaining offset is subtracted by the operand's // size but must remain >= 0. - if (operand != allOperands.back()) { - Value cmp = builder.create(loc, arith::CmpIPredicate::ule, - remainingTileOffsetInConcatDim, - operandSizeInConcatDim); - Value sub = builder.create( - loc, remainingTileOffsetInConcatDim, operandSizeInConcatDim); + if (operand != op.getInputs().back()) { + Value cmp = b.create(loc, arith::CmpIPredicate::ule, + remainingTileOffsetInConcatDim, + operandSizeInConcatDim); + Value sub = b.create(loc, remainingTileOffsetInConcatDim, + operandSizeInConcatDim); remainingTileOffsetInConcatDim = - builder.create(loc, cmp, zeroCst, sub); + b.create(loc, cmp, zeroCst, sub); } } // Create the tiled concat op. - auto tileType = tile.getType().cast(); - Value subInit = - builder.create(loc, op.getInit(), tile); - auto subResultType = - RankedTensorType::get(tileType.getShape(), resultTy.getElementType()); - return builder - .create(loc, subResultType, subOperands, subInit, - concatDim) - ->getResult(0); -} - -Value fuseConcatenateOpThroughPointRecursively( - OpBuilder &builder, Location loc, RankedTensorType rankedTy, - uint64_t concatDim, SmallVector &remainingOffsets, - ValueRange remainingOperands) { - // Bail if called for no operands. - if (remainingOperands.empty()) { - return {}; - } - Value leadingOperand = remainingOperands.front(); - - // Terminal case of exactly one operand. - if (remainingOperands.size() == 1) { - // Create operand point. 
- SmallVector allDynamicOffsets(rankedTy.getRank(), - ShapedType::kDynamicStrideOrOffset); - - auto sizeOrStride = builder.getI64ArrayAttr({1}); - Value operandPoint = builder.create( - loc, remainingOffsets, ValueRange{}, ValueRange{}, - builder.getI64ArrayAttr(allDynamicOffsets), sizeOrStride, sizeOrStride); - - return builder.create(loc, rankedTy.getElementType(), - leadingOperand, operandPoint); - } - - // For more than 1 operand, distinguish between the leading operand and the - // remainder. - assert(remainingOperands.size() > 1 && - "expect more than 1 operand at this point"); - Value leadingOperandConcatDim = - builder.create(loc, leadingOperand, concatDim); - Value leadingOperandPredicate = builder.create( - loc, arith::CmpIPredicate::ult, remainingOffsets[concatDim], - leadingOperandConcatDim); - auto ifOp = builder.create( - loc, rankedTy.getElementType(), leadingOperandPredicate, - [&](OpBuilder &builder, Location loc) { - // For the leading operand, recur with the current offsets. - Value fused = fuseConcatenateOpThroughPointRecursively( - builder, loc, rankedTy, concatDim, remainingOffsets, - leadingOperand); - builder.create(loc, fused); - }, - [&](OpBuilder &builder, Location loc) { - // For the remaining operands, substract the leading operand's size from - // the remaining offsets in the concatenation dimension. - SmallVector thenRemainingOffsets(remainingOffsets.begin(), - remainingOffsets.end()); - thenRemainingOffsets[concatDim] = builder.create( - loc, remainingOffsets[concatDim], leadingOperandConcatDim); - Value fused = fuseConcatenateOpThroughPointRecursively( - builder, loc, rankedTy, concatDim, thenRemainingOffsets, - remainingOperands.drop_front()); - builder.create(loc, fused); - }); - return ifOp.getResults().front(); + Value tiledInit = materializeSlice(b, loc, op.getInit(), offsets, sizes); + auto tiledConcat = + b.create(loc, tiledInit.getType(), tiledOperands, + tiledInit, b.getIndexAttr(concatDim)); + return tiledConcat.getResults().front(); } -Value fuseConcatenateOpThroughPoint(ConcatenateOp op, OpBuilder &builder, - Location loc, Value subset) { - auto resultTy = op.getType(0).cast(); - uint64_t concatDim = op.getDimension(); - - // Materialize initial offsets. - auto tileOp = subset.getDefiningOp(); - SmallVector initialOffsets = - getValueOrCreateConstantIndexOp(builder, loc, tileOp.getMixedOffsets()); +Value getTiledImplementationForConcat(ConcatenateOp op, OpBuilder &b, + Location loc, + ArrayRef offsets, + ArrayRef sizes) { + // If the tile is of unit size in the concatenation dimension, we can generate + // the tiled implementation based on a single operand. + int64_t concatDim = op.getDimension().getSExtValue(); + OpFoldResult tileSizeInConcatDim = sizes[concatDim]; + if (tileSizeInConcatDim.is() && + tileSizeInConcatDim.get().cast().getInt() == 1) { + return getSingleOperandTiledImplementationForConcat(op, b, loc, offsets, + sizes); + } - ValueRange initialOperands = op.getInputs(); - return fuseConcatenateOpThroughPointRecursively( - builder, loc, resultTy, concatDim, initialOffsets, initialOperands); + // Otherwise, rely on the generic implementation. + return getGenericTiledImplementationForConcat(op, b, loc, offsets, sizes); } } // namespace -gml_st::TilingInterface ConcatenateOp::getTiledImplementation( +SmallVector ConcatenateOp::getTiledImplementation( OpBuilder &b, ArrayRef offsets, ArrayRef sizes) { - // Create tile subset. 
- auto loc = getLoc(); - gml_st::TileOp tile = createTileOp(b, loc, getInit(), offsets, sizes); + auto tiled = + getTiledImplementationForConcat(*this, b, getLoc(), offsets, sizes); + return {tiled.getDefiningOp()}; +} - auto tiled = fuseConcatenateOpThroughTile(*this, b, loc, tile); - return llvm::cast(tiled.getDefiningOp()); +LogicalResult ConcatenateOp::getResultTilePosition( + OpBuilder & /*b*/, unsigned /*resultNumber*/, + ArrayRef offsets, ArrayRef sizes, + SmallVector &resultOffsets, + SmallVector &resultSizes) { + resultOffsets = llvm::to_vector(offsets); + resultSizes = llvm::to_vector(sizes); + return success(); } FailureOr ConcatenateOp::generateResultTileValue( OpBuilder &b, unsigned resultNumber, ArrayRef offsets, ArrayRef sizes) { assert(resultNumber == 0 && "expect unique result idx"); - return getTiledImplementation(b, offsets, sizes)->getResults().front(); + return getTiledImplementation(b, offsets, sizes) + .front() + ->getResults() + .front(); } ParseResult ConcatenateOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDstStyleOp(parser, result); + return parseDstStyleOp( + parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) { + int64_t dimension = 0; + if (parser.parseKeyword("dimension") || parser.parseEqual() || + parser.parseInteger(dimension)) + return failure(); + + attributes.set("dimension", + parser.getBuilder().getIndexAttr(dimension)); + return success(); + }); } -void ConcatenateOp::print(OpAsmPrinter &p) { printDstStyleOp(*this, p); } +void ConcatenateOp::print(OpAsmPrinter &p) { + printDstStyleOp( + *this, p, + [](ConcatenateOp op, OpAsmPrinter &p) -> SmallVector { + p << op.getDimensionAttrName().str() << " = " << op.getDimension(); + + return {op.getDimensionAttrName()}; + }); +} LogicalResult ConcatenateOp::verify() { - int64_t concatDim = getDimension(); + int64_t concatDim = getDimension().getSExtValue(); ShapedType inputType = getDpsInputOperand(0)->get().getType().cast(); @@ -539,32 +517,18 @@ DynamicBroadcastInDimOp::getLoopIteratorTypes() { return getParallelIteratorTypes(getInit().getType().getRank()); } -SmallVector DynamicBroadcastInDimOp::getDestinationOperands( - OpBuilder &) { - return {getInit()}; -} - SmallVector DynamicBroadcastInDimOp::getIterationDomain(OpBuilder &b) { return getIterationDomainForTensor(b, getLoc(), getInit()); } -gml_st::TilingInterface DynamicBroadcastInDimOp::getTiledImplementation( +SmallVector DynamicBroadcastInDimOp::getTiledImplementation( OpBuilder &b, ArrayRef offsets, ArrayRef sizes) { // Create tile subset. auto loc = getLoc(); - auto tile = createTileOp(b, loc, getInit(), offsets, sizes); auto initRank = getInit().getType().cast().getRank(); - // Create the needed constants only once. DenseMap localIndexConstants; - auto getIndexConstant = [&](uint64_t c) -> Value { - auto it = localIndexConstants.find(c); - if (it != localIndexConstants.end()) return it->second; - auto cst = b.create(loc, c); - localIndexConstants[c] = cst; - return cst; - }; DenseSet dimensionsThatStay(getBroadcastDimensions().begin(), getBroadcastDimensions().end()); @@ -579,8 +543,9 @@ gml_st::TilingInterface DynamicBroadcastInDimOp::getTiledImplementation( operandDims.reserve(operandTy.getRank()); for (const auto &it : llvm::enumerate(operandTy.getShape())) { int64_t d = it.value(); - Value dim = d == ShapedType::kDynamicSize ? dynamicDims[dynamicDimsIdx++] - : getIndexConstant(d); + Value dim = d == ShapedType::kDynamic + ? 
dynamicDims[dynamicDimsIdx++] + : b.create(loc, d); operandDims.push_back(dim); } @@ -591,20 +556,18 @@ gml_st::TilingInterface DynamicBroadcastInDimOp::getTiledImplementation( SmallVector operandExpandingDims; for (const auto &it : llvm::enumerate(getBroadcastDimensions())) { auto operandDim = operandDims[it.index()]; - auto resultDim = - b.create(loc, getInit(), getIndexConstant(it.value())); + auto resultDim = b.create( + loc, getInit(), b.create(loc, it.value())); operandExpandingDims.push_back(b.create( loc, arith::CmpIPredicate::ne, operandDim, resultDim)); } // Compute operand tile offsets. - auto tileOpOffsets = - getValueOrCreateConstantIndexOp(b, loc, tile.getMixedOffsets()); + auto tileOpOffsets = getValueOrCreateConstantIndexOp(b, loc, offsets); int64_t operandRank = operandTy.getRank(); - auto staticOffsets = b.getI64ArrayAttr( - SmallVector(operandRank, ShapedType::kDynamicStrideOrOffset)); + auto staticOffsets = SmallVector(operandRank, ShapedType::kDynamic); SmallVector operandOffsets; - Value zero = getIndexConstant(0); + Value zero = b.create(loc, 0); for (int initId = 0, operandId = 0; initId < initRank; ++initId) { if (!dimensionsThatStay.contains(initId)) continue; Value isExpanding = operandExpandingDims[operandId++]; @@ -614,12 +577,11 @@ gml_st::TilingInterface DynamicBroadcastInDimOp::getTiledImplementation( } // Compute operand tile sizes. - auto staticTileSizes = b.getI64ArrayAttr( - SmallVector(operandRank, ShapedType::kDynamicSize)); + auto staticTileSizes = + SmallVector(operandRank, ShapedType::kDynamic); SmallVector tileSizes; - Value one = getIndexConstant(1); - auto tileOpSizes = - getValueOrCreateConstantIndexOp(b, loc, tile.getMixedSizes()); + Value one = b.create(loc, 1); + auto tileOpSizes = getValueOrCreateConstantIndexOp(b, loc, sizes); for (int initId = 0, operandId = 0; initId < initRank; ++initId) { if (!dimensionsThatStay.contains(initId)) continue; Value isExpanding = operandExpandingDims[operandId++]; @@ -629,36 +591,45 @@ gml_st::TilingInterface DynamicBroadcastInDimOp::getTiledImplementation( } // Create operand tile. - auto staticTileStrides = - b.getI64ArrayAttr(SmallVector(operandRank, 1)); + auto staticTileStrides = SmallVector(operandRank, 1); SmallVector tileStrides = {}; - auto operandTileTy = b.getType( - SmallVector(operandRank, ShapedType::kDynamicSize)); - auto operandTile = b.create( - loc, operandTileTy, operandOffsets, tileSizes, tileStrides, staticOffsets, - staticTileSizes, staticTileStrides); // Materialize operand tiles. - Value tiledInit = b.create(loc, getInit(), tile); - Value tiledOperand = - b.create(loc, getOperand(), operandTile); + Value tiledInit = materializeSlice(b, loc, getInit(), offsets, sizes); + Value tiledOperand = materializeSlice( + b, loc, getOperand(), getMixedValues(staticOffsets, operandOffsets, b), + getMixedValues(staticTileSizes, tileSizes, b), + getMixedValues(staticTileStrides, tileStrides, b)); // Finally, materialize tiled broadcast. 
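// The tiled broadcast keeps the original broadcast_dimensions and
// known-(non)expanding attributes; only the result type is recomputed from
// the shape of the tiled init slice.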
- auto tileTy = tile.getType(); auto resultTy = getType(0).cast(); auto tiledResultTy = - RankedTensorType::get(tileTy.getShape(), resultTy.getElementType()); - return b.create( + RankedTensorType::get(tiledInit.getType().cast().getShape(), + resultTy.getElementType()); + return {b.create( loc, TypeRange{tiledResultTy}, tiledOperand, tiledInit, getBroadcastDimensionsAttr(), getKnownExpandingDimensionsAttr(), - getKnownNonexpandingDimensionsAttr()); + getKnownNonexpandingDimensionsAttr())}; +} + +LogicalResult DynamicBroadcastInDimOp::getResultTilePosition( + OpBuilder & /*b*/, unsigned /*resultNumber*/, + ArrayRef offsets, ArrayRef sizes, + SmallVector &resultOffsets, + SmallVector &resultSizes) { + resultOffsets = llvm::to_vector(offsets); + resultSizes = llvm::to_vector(sizes); + return success(); } FailureOr DynamicBroadcastInDimOp::generateResultTileValue( OpBuilder &b, unsigned resultNumber, ArrayRef offsets, ArrayRef sizes) { assert(resultNumber == 0 && "expect unique result idx"); - return getTiledImplementation(b, offsets, sizes)->getResults().front(); + return getTiledImplementation(b, offsets, sizes) + .front() + ->getResults() + .front(); } //===----------------------------------------------------------------------===// @@ -683,12 +654,15 @@ ParseResult ScatterOp::parse(OpAsmParser &parser, OperationState &result) { void ScatterOp::print(OpAsmPrinter &p) { printDstStyleOp(*this, p); + p.increaseIndent(); + p.printNewline(); p << "("; llvm::interleaveComma(getUpdateComputation().getArguments(), p, [&](auto arg) { p.printRegionArgument(arg); }); p << ") "; p.printRegion(getUpdateComputation(), /*printEntryBlockArgs=*/false); + p.decreaseIndent(); } LogicalResult ScatterOp::verify() { @@ -744,30 +718,12 @@ SmallVector ScatterOp::getLoopIteratorTypes() { return {utils::IteratorType::reduction}; } -SmallVector ScatterOp::getDestinationOperands(OpBuilder &) { - return {getInit()}; -} - SmallVector ScatterOp::getIterationDomain(OpBuilder &b) { Value indicesCount = b.create(getLoc(), getIndices(), 0); return {Range{b.getIndexAttr(0), indicesCount, b.getIndexAttr(1)}}; } -static Value getSlice(OpBuilder &b, Location loc, Value tensor, - ArrayRef offsets, - ArrayRef sizes) { - SmallVector ones(offsets.size(), b.getIndexAttr(1)); - Value tile = b.create(loc, offsets, sizes, ones); - return b.create(loc, tensor, tile); -} - -static Value getFullSpace(OpBuilder &b, Location loc, Value tensor) { - SmallVector sizes = tensor::getMixedSizes(b, loc, tensor); - SmallVector offsets(sizes.size(), b.getIndexAttr(0)); - return getSlice(b, loc, tensor, offsets, sizes); -} - -mlir::gml_st::TilingInterface ScatterOp::getTiledImplementation( +SmallVector ScatterOp::getTiledImplementation( OpBuilder &b, ArrayRef offsets, ArrayRef sizes) { Location loc = getLoc(); @@ -785,7 +741,8 @@ mlir::gml_st::TilingInterface ScatterOp::getTiledImplementation( SmallVector updateSizes = tensor::getMixedSizes(b, loc, update); updateSizes.front() = tileSize; - Value updateSlice = getSlice(b, loc, update, updateOffsets, updateSizes); + Value updateSlice = + materializeSlice(b, loc, update, updateOffsets, updateSizes); // Tile outer dimension of indices. 
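// Only the leading dimension (the number of index tuples) is tiled; the
// trailing index-vector dimension keeps its full size.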
Value indices = this->getIndices(); @@ -796,22 +753,37 @@ mlir::gml_st::TilingInterface ScatterOp::getTiledImplementation( tensor::getMixedSizes(b, loc, indices); indicesSizes.front() = tileSize; - Value indicesSlice = getSlice(b, loc, indices, indicesOffsets, indicesSizes); + Value indicesSlice = + materializeSlice(b, loc, indices, indicesOffsets, indicesSizes); - // Get full space of the `init` tensor. - Value init = this->getInit(); - Value initSlice = getFullSpace(b, loc, init); + // Get full space of the `init` tensor. We use an extract_slice op because + // otherwise, tileUsingSCFForOp won't replace the arg with the bbarg. + int64_t initRank = getInit().getType().getRank(); + Value init = materializeSlice(b, loc, this->getInit(), + SmallVector(initRank, zeroAttr), + tensor::getMixedSizes(b, loc, this->getInit())); - auto dpsInterface = cast(this->getOperation()); - return dpsInterface.clone(b, loc, TypeRange{initSlice.getType()}, - ValueRange{indicesSlice, updateSlice, initSlice}); + return {mlir::clone(b, this->getOperation(), TypeRange{init.getType()}, + ValueRange{indicesSlice, updateSlice, init})}; +} + +LogicalResult ScatterOp::getResultTilePosition( + OpBuilder &b, unsigned /*resultNumber*/, ArrayRef /*offsets*/, + ArrayRef /*sizes*/, SmallVector &resultOffsets, + SmallVector &resultSizes) { + ScatterOp scatterOp = cast(this->getOperation()); + auto init = scatterOp.getInit(); + resultOffsets = + SmallVector(init.getType().getRank(), b.getIndexAttr(0)); + resultSizes = tensor::createDimValues(b, scatterOp.getLoc(), init); + return success(); } FailureOr ScatterOp::generateResultTileValue( OpBuilder &b, unsigned resultNumber, ArrayRef offsets, ArrayRef sizes) { assert(resultNumber == 0 && "variadic scatter is not implemented"); - return getTiledImplementation(b, offsets, sizes)->getResult(0); + return getTiledImplementation(b, offsets, sizes).front()->getResult(0); } //===----------------------------------------------------------------------===// @@ -855,16 +827,12 @@ SmallVector GatherOp::getLoopIteratorTypes() { return {utils::IteratorType::parallel}; } -SmallVector GatherOp::getDestinationOperands(OpBuilder &) { - return {getInit()}; -} - SmallVector GatherOp::getIterationDomain(OpBuilder &b) { Value indicesCount = b.create(getLoc(), getStartIndices(), 0); return {Range{b.getIndexAttr(0), indicesCount, b.getIndexAttr(1)}}; } -mlir::gml_st::TilingInterface GatherOp::getTiledImplementation( +SmallVector GatherOp::getTiledImplementation( OpBuilder &b, ArrayRef offsets, ArrayRef sizes) { SmallVector startIndexOffsets{offsets.front(), @@ -872,8 +840,8 @@ mlir::gml_st::TilingInterface GatherOp::getTiledImplementation( SmallVector startIndexSizes{ sizes.front(), b.getIndexAttr(getStartIndices().getType().getShape().back())}; - auto subStartIndices = getMaterializedTile( - b, getLoc(), getStartIndices(), startIndexOffsets, startIndexSizes); + auto subStartIndices = materializeSlice(b, getLoc(), getStartIndices(), + startIndexOffsets, startIndexSizes); int64_t initRank = getInit().getType().getRank(); SmallVector initOffsets(initRank, b.getIndexAttr(0)); @@ -881,27 +849,70 @@ mlir::gml_st::TilingInterface GatherOp::getTiledImplementation( auto initSizes = tensor::getMixedSizes(b, getLoc(), getInit()); initSizes[0] = sizes.front(); Value initSlice = - getMaterializedTile(b, getLoc(), getInit(), initOffsets, initSizes); + materializeSlice(b, getLoc(), getInit(), initOffsets, initSizes); - return b - .create(getLoc(), TypeRange{initSlice.getType()}, - ValueRange{getOperand(), 
subStartIndices, initSlice}) - .getOperation(); + return { + b.create(getLoc(), TypeRange{initSlice.getType()}, + ValueRange{getOperand(), subStartIndices, initSlice})}; +} + +LogicalResult GatherOp::getResultTilePosition( + OpBuilder &b, unsigned /*resultNumber*/, ArrayRef offsets, + ArrayRef sizes, SmallVector &resultOffsets, + SmallVector &resultSizes) { + GatherOp gatherOp = cast(this->getOperation()); + auto init = gatherOp.getInit(); + resultOffsets = + SmallVector(init.getType().getRank(), b.getIndexAttr(0)); + resultOffsets.front() = offsets.front(); + resultSizes = tensor::createDimValues(b, gatherOp.getLoc(), init); + resultSizes.front() = sizes.front(); + return success(); } FailureOr GatherOp::generateResultTileValue( OpBuilder &b, unsigned resultNumber, ArrayRef offsets, ArrayRef sizes) { assert(resultNumber == 0 && "resultNumber > 0 not implemented"); - return getTiledImplementation(b, offsets, sizes)->getResult(0); + return getTiledImplementation(b, offsets, sizes).front()->getResult(0); } //===----------------------------------------------------------------------===// // SortOp //===----------------------------------------------------------------------===// +void SortOp::getAsmResultNames(function_ref setNameFn) { + ResultRange results = getResults(); + for (size_t i = 0; i < results.size(); i++) { + setNameFn(results[i], "sorted" + std::to_string(i)); + } +} + +void SortOp::getAsmBlockArgumentNames(Region ®ion, + OpAsmSetValueNameFn setNameFn) { + for (int i = 0, e = region.getNumArguments(); i < e; i += 2) { + setNameFn(region.getArgument(i), "lhs" + std::to_string(i / 2)); + setNameFn(region.getArgument(i + 1), "rhs" + std::to_string(i / 2)); + } +} + ParseResult SortOp::parse(OpAsmParser &parser, OperationState &result) { - if (parseDstStyleOp(parser, result)) return failure(); + if (parseDstStyleOp( + parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) { + int64_t dimension = 0; + int64_t isStable = 0; + if (parser.parseKeyword("dimension") || parser.parseEqual() || + parser.parseInteger(dimension) || + parser.parseKeyword("is_stable") || parser.parseEqual() || + parser.parseInteger(isStable)) + return failure(); + + auto b = parser.getBuilder(); + attributes.set("dimension", b.getIndexAttr(dimension)); + attributes.set("is_stable", b.getBoolAttr(isStable != 0)); + return success(); + })) + return failure(); SmallVector regionArgs; if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren, @@ -916,14 +927,22 @@ ParseResult SortOp::parse(OpAsmParser &parser, OperationState &result) { } void SortOp::print(OpAsmPrinter &p) { - printDstStyleOp(*this, p); + printDstStyleOp( + *this, p, [](SortOp op, OpAsmPrinter &p) -> SmallVector { + p << op.getDimensionAttrName().str() << " = " << op.getDimension() + << ' ' << op.getIsStableAttrName().str() << " = " << op.getIsStable(); + return {op.getDimensionAttrName(), op.getIsStableAttrName()}; + }); + p.increaseIndent(); + p.printNewline(); p << "("; llvm::interleaveComma(getComparator().getArguments(), p, [&](auto arg) { p.printRegionArgument(arg); }); p << ") "; p.printRegion(getComparator(), /*printEntryBlockArgs=*/false); + p.decreaseIndent(); } LogicalResult SortOp::verify() { @@ -938,7 +957,7 @@ LogicalResult SortOp::verify() { return emitOpError() << "expected the number of inputs " << numInputs << " to match the number of outputs " << numOutputs; } - if (comparatorArgs.size() != numInputs * 2) { + if (static_cast(comparatorArgs.size()) != numInputs * 2) { return emitOpError() << "expected the 
number of block arguments " << comparatorArgs.size() << " to be twice the number " << "of inputs (2*" << numInputs << ")"; @@ -995,11 +1014,13 @@ LogicalResult SortOp::verify() { // Checks that the rank of the reference shape is larger than the absolute // value of the sorting dimension. This is enough to ensure that the dimension - // is valid, since all inputs are known to have the same shape. - int64_t referenceRank = referenceShape.size(); - if (getDimension() >= referenceRank || getDimension() < 0) { + // is valid, since all inputs are known to have the same shape. `getDimension` + // returns an unsigned int, so no need to check for negative values. + size_t referenceRank = referenceShape.size(); + if (getDimension().getSExtValue() >= (int64_t)referenceRank) { return emitOpError() << "sorting dimension must be in range [0, " - << referenceRank << ") but got " << getDimension(); + << referenceRank << ") but got " + << getDimension().getSExtValue(); } return verifyDestinationStyleOp(getOperation()); @@ -1009,10 +1030,6 @@ SmallVector SortOp::getLoopIteratorTypes() { return getParallelIteratorTypes(getType(0).cast().getRank() - 1); } -SmallVector SortOp::getDestinationOperands(OpBuilder &) { - return {getInits()}; -} - SmallVector SortOp::getIterationDomain(OpBuilder &b) { Location loc = getLoc(); auto oneInit = getInits().front(); @@ -1022,7 +1039,7 @@ SmallVector SortOp::getIterationDomain(OpBuilder &b) { IntegerAttr zero = b.getIndexAttr(0); IntegerAttr one = b.getIndexAttr(1); - int64_t sortDimension = getDimension(); + int64_t sortDimension = getDimension().getSExtValue(); for (auto axis : llvm::seq(0, operandsRank - 1)) { int64_t operandAxis = (axis >= sortDimension) ? axis + 1 : axis; @@ -1034,7 +1051,7 @@ SmallVector SortOp::getIterationDomain(OpBuilder &b) { return iterationDomain; } -mlir::gml_st::TilingInterface SortOp::getTiledImplementation( +SmallVector SortOp::getTiledImplementation( OpBuilder &b, ArrayRef offsets, ArrayRef sizes) { auto loc = getLoc(); @@ -1042,7 +1059,7 @@ mlir::gml_st::TilingInterface SortOp::getTiledImplementation( SmallVector tileSizes = llvm::to_vector(sizes); size_t numOutputs = getNumDpsInits(); - int64_t sortDimension = getDimension(); + int64_t sortDimension = getDimension().getSExtValue(); Value oneInput = getInputs().front(); @@ -1053,37 +1070,166 @@ mlir::gml_st::TilingInterface SortOp::getTiledImplementation( b.createOrFold(loc, oneInput, sortDimension); tileSizes.insert(tileSizes.begin() + sortDimension, sortDimensionSize); - gml_st::TileOp tile = createTileOp(b, loc, oneInput, tileOffsets, tileSizes); - // Materialize the tile for each input and init. 
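// Every input and init is sliced with the same offsets and sizes; the sort
// dimension keeps its full extent so each tile can be sorted independently.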
SmallVector tiledInputsAndInits; SmallVector tiledResultTypes; tiledInputsAndInits.reserve(numOutputs * 2); tiledResultTypes.reserve(numOutputs); - auto tileShape = tile.getType().cast().getShape(); - for (const auto &input : getInputs()) { tiledInputsAndInits.push_back( - b.create(loc, input, tile)); + materializeSlice(b, loc, input, tileOffsets, tileSizes)); + auto tileShape = + tiledInputsAndInits.back().getType().cast().getShape(); tiledResultTypes.push_back(RankedTensorType::get( tileShape, input.getType().cast().getElementType())); } for (const auto &init : getInits()) { tiledInputsAndInits.push_back( - b.create(loc, init, tile)); + materializeSlice(b, loc, init, tileOffsets, tileSizes)); } - auto dpsInterface = cast(this->getOperation()); - return dpsInterface.clone(b, loc, tiledResultTypes, tiledInputsAndInits); + return {mlir::clone(b, this->getOperation(), tiledResultTypes, + tiledInputsAndInits)}; +} + +LogicalResult SortOp::getResultTilePosition( + OpBuilder &b, unsigned /*resultNumber*/, ArrayRef offsets, + ArrayRef sizes, SmallVector &resultOffsets, + SmallVector &resultSizes) { + SortOp sortOp = cast(this->getOperation()); + resultOffsets = llvm::to_vector(offsets); + resultSizes = llvm::to_vector(sizes); + + int64_t sortDimIndex = sortOp.getDimension().getSExtValue(); + Value sortDimValue = b.create( + sortOp.getLoc(), sortOp.getInputs().front(), sortDimIndex); + resultOffsets.insert(resultOffsets.begin() + sortDimIndex, b.getIndexAttr(0)); + resultSizes.insert(resultSizes.begin() + sortDimIndex, sortDimValue); + return success(); } FailureOr SortOp::generateResultTileValue(OpBuilder &b, unsigned resultNumber, ArrayRef offsets, ArrayRef sizes) { - return getTiledImplementation(b, offsets, sizes)->getResult(resultNumber); + return getTiledImplementation(b, offsets, sizes) + .front() + ->getResult(resultNumber); +} + +//===----------------------------------------------------------------------===// +// ReverseOp +//===----------------------------------------------------------------------===// + +ParseResult ReverseOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDstStyleOp( + parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) { + return parseDenseI64ArrayAttr(parser, attributes, "reverse_dimensions"); + }); +} + +void ReverseOp::print(OpAsmPrinter &p) { + printDstStyleOp( + *this, p, [](ReverseOp op, OpAsmPrinter &p) -> SmallVector { + printDenseI64ArrayAttr(p, op.getReverseDimensionsAttrName(), + op.getReverseDimensions()); + return {op.getReverseDimensionsAttrName()}; + }); +} + +LogicalResult ReverseOp::verify() { + return verifyDestinationStyleOp(getOperation()); +} + +void ReverseOp::getAsmResultNames( + function_ref setNameFn) { + setNameFn(getResult(), "reversed"); +} + +SmallVector ReverseOp::getLoopIteratorTypes() { + return getParallelIteratorTypes(getType().cast().getRank() - 1); +} + +SmallVector ReverseOp::getIterationDomain(OpBuilder &b) { + return getIterationDomainForTensor(b, getLoc(), getInit()); +} + +namespace { +SmallVector getInputTileOffsetsForReverse( + OpBuilder &b, Location loc, ArrayRef offsets, + ArrayRef tileSizes, ArrayRef reverseDimensions, + TypedValue &input) { + auto tileOpOffsets = getValueOrCreateConstantIndexOp(b, loc, offsets); + auto sizes = getValueOrCreateConstantIndexOp(b, loc, tileSizes); + SmallVector inputTileOffsets; + for (size_t i = 0; i < tileOpOffsets.size(); ++i) { + if (llvm::is_contained(reverseDimensions, i)) { + inputTileOffsets.push_back(OpFoldResult{b.createOrFold( + loc, + 
b.createOrFold( + loc, b.createOrFold(loc, input, i), + Value(tileOpOffsets[i])), + sizes[i])}); + } else { + inputTileOffsets.push_back(tileOpOffsets[i]); + } + } + + return inputTileOffsets; +} +} // namespace + +SmallVector ReverseOp::getTiledImplementation( + OpBuilder &b, ArrayRef offsets, + ArrayRef sizes) { + auto loc = getLoc(); + auto input = getInput(); + SmallVector inputTileOffsets = getInputTileOffsetsForReverse( + b, loc, offsets, sizes, getReverseDimensions(), input); + + // Materialize the tile for input and init. + SmallVector tiledInputsAndInits; + + tiledInputsAndInits.push_back( + materializeSlice(b, loc, input, inputTileOffsets, sizes)); + tiledInputsAndInits.push_back( + materializeSlice(b, loc, getInit(), offsets, sizes)); + auto tileShape = + tiledInputsAndInits.back().getType().cast().getShape(); + auto tiledResultType = RankedTensorType::get( + tileShape, input.getType().cast().getElementType()); + + return {mlir::clone(b, this->getOperation(), tiledResultType, + tiledInputsAndInits)}; +} + +LogicalResult ReverseOp::getResultTilePosition( + OpBuilder & /*b*/, unsigned /*resultNumber*/, + ArrayRef offsets, ArrayRef sizes, + SmallVector &resultOffsets, + SmallVector &resultSizes) { + resultOffsets = llvm::to_vector(offsets); + resultSizes = llvm::to_vector(sizes); + return success(); +} + +FailureOr ReverseOp::generateResultTileValue( + OpBuilder &b, unsigned resultNumber, ArrayRef offsets, + ArrayRef sizes) { + return getTiledImplementation(b, offsets, sizes) + .front() + ->getResult(resultNumber); +} + +OpFoldResult ReverseOp::fold( + ReverseOpGenericAdaptor>) /*operands*/ { + auto inputType = getInput().getType(); + for (unsigned i = 0; i < getReverseDimensions().size(); ++i) { + if (inputType.getDimSize(getReverseDimensions()[i]) != 1) return nullptr; + } + return getInput(); } } // namespace thlo @@ -1091,4 +1237,4 @@ FailureOr SortOp::generateResultTileValue(OpBuilder &b, // Generated op classes. #define GET_OP_CLASSES -#include "mlir-hlo/Dialect/thlo/IR/thlo_ops.cc.inc" +#include "thlo/IR/thlo_ops.cc.inc" diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/thlo/IR/thlo_ops.h b/tensorflow/compiler/xla/mlir_hlo/thlo/IR/thlo_ops.h similarity index 75% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/thlo/IR/thlo_ops.h rename to tensorflow/compiler/xla/mlir_hlo/thlo/IR/thlo_ops.h index 08d0cdd0b68..d844aeb884f 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/thlo/IR/thlo_ops.h +++ b/tensorflow/compiler/xla/mlir_hlo/thlo/IR/thlo_ops.h @@ -15,23 +15,22 @@ limitations under the License. // This file defines the operations used in the THLO dialect. -#ifndef MLIR_HLO_DIALECT_THLO_IR_THLO_OPS_H -#define MLIR_HLO_DIALECT_THLO_IR_THLO_OPS_H +#ifndef MLIR_HLO_THLO_IR_THLO_OPS_H +#define MLIR_HLO_THLO_IR_THLO_OPS_H -#include "mlir-hlo/Dialect/gml_st/transforms/tiling_interface.h" -#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/TilingInterface.h" // Generated dialect declarations. -#include "mlir-hlo/Dialect/thlo/IR/thlo_dialect.h.inc" +#include "thlo/IR/thlo_dialect.h.inc" // Generated operation classes. 
#define GET_OP_CLASSES -#include "mlir-hlo/Dialect/thlo/IR/thlo_ops.h.inc" +#include "thlo/IR/thlo_ops.h.inc" -#endif // MLIR_HLO_DIALECT_THLO_IR_THLO_OPS_H +#endif // MLIR_HLO_THLO_IR_THLO_OPS_H diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/thlo/IR/thlo_ops.td b/tensorflow/compiler/xla/mlir_hlo/thlo/IR/thlo_ops.td similarity index 59% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/thlo/IR/thlo_ops.td rename to tensorflow/compiler/xla/mlir_hlo/thlo/IR/thlo_ops.td index 6588c90e40e..362e540d1a6 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/thlo/IR/thlo_ops.td +++ b/tensorflow/compiler/xla/mlir_hlo/thlo/IR/thlo_ops.td @@ -16,12 +16,12 @@ limitations under the License. #ifndef THLO_OPS #define THLO_OPS -include "mlir-hlo/Dialect/gml_st/transforms/tiling_interface.td" -include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td" +include "mlir/IR/OpAsmInterface.td" include "mlir/IR/OpBase.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/DestinationStyleOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/TilingInterface.td" def TensorOrMemref : AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">; @@ -33,8 +33,7 @@ class TensorOrMemrefOf allowedTypes> : def THLO_Dialect : Dialect { let name = "thlo"; let cppNamespace = "::mlir::thlo"; - - let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; + let useFoldAPI = kEmitFoldAdaptorFolder; } class THLO_Op traits> : @@ -47,21 +46,35 @@ class THLO_DstStyleOp traits> : THLO_Op]>, - Arguments<(ins Variadic:$values)> { - let summary = "Yield operation for tHLO ops with regions."; - let assemblyFormat = "attr-dict $values `:` type($values)"; - let hasVerifier = 1; -} - def THLO_ConcatenateOp : THLO_DstStyleOp<"concatenate", [ - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods + ]> { let summary = "Destination-style twin for `mhlo.concatenate`"; + let description = [{ + tHLO ConcatenateOp composes a tensor or a memref from multiple tensors or + memrefs. + + Example: + ``` + %concat = thlo.concatenate + ins(%T1 : tensor<100x?xf32>, %T2 : tensor<300x?xf32>) + outs(%init : tensor<400x?xf32>) + dimension = 0 + ``` + + See https://www.tensorflow.org/xla/operation_semantics#concatenate + }]; + let arguments = (ins Variadic:$inputs, TensorOrMemref:$init, - I64Attr:$dimension + IndexAttr:$dimension ); let results = (outs Variadic:$result); @@ -75,8 +88,29 @@ def THLO_ConcatenateOp : THLO_DstStyleOp<"concatenate", [ } def THLO_DynamicBroadcastInDimOp : THLO_DstStyleOp<"dynamic_broadcast_in_dim", [ - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods + ]> { let summary = "Destination-style twin for `mhlo.dynamic_broadcast_in_dim`"; + let description = [{ + tHLO DynamicBroadcastInDimOp specifies a map how to broadcast input + dimensions. It also supports broadcasting size-1 dimensions. + + Example: + ``` + %dyn_bcast = thlo.dynamic_broadcast_in_dim + ins(%input : tensor) + outs(%init : tensor) + broadcast_dimensions = [0, 2] + ``` + + See https://www.tensorflow.org/xla/operation_semantics#broadcastindim + }]; let arguments = (ins // Input args @@ -101,7 +135,14 @@ def THLO_DynamicBroadcastInDimOp : THLO_DstStyleOp<"dynamic_broadcast_in_dim", [ } def THLO_GatherOp : THLO_DstStyleOp<"gather", [ - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods + ]> { let summary = "Destination-style twin for `mhlo.gather`"; let description = [{ tHLO GatherOp corresponds to the canonicalized mHLO GatherOp, i.e. 
@@ -111,6 +152,15 @@ def THLO_GatherOp : THLO_DstStyleOp<"gather", [ - offset_dims is [1, 2, ...] - collapsed_slice_dims is [] - start_index_map is range(start_indices.shape[1]) + + Example: + ``` + %gathered = thlo.gather + ins(%input : tensor<100xf32>, %indices : tensor<42x1xindex>) + outs(%init : tensor<42xf32>) + ``` + + See https://www.tensorflow.org/xla/operation_semantics#gather. }]; let arguments = (ins // Input args @@ -131,8 +181,15 @@ def THLO_GatherOp : THLO_DstStyleOp<"gather", [ } def THLO_ScatterOp : THLO_DstStyleOp<"scatter", [ - DeclareOpInterfaceMethods, - SingleBlockImplicitTerminator<"YieldOp">]> { + DeclareOpInterfaceMethods, + SingleBlockImplicitTerminator<"YieldOp"> + ]> { let summary = "Destination-style twin for `mhlo.scatter`"; let description = [{ tHLO ScatterOp corresponds to the canonicalized mHLO ScatterOp, i.e. @@ -143,7 +200,21 @@ def THLO_ScatterOp : THLO_DstStyleOp<"scatter", [ - index_vector_dim is rank(indices) - 1 At the moment, the variadic case is not supported. + + Example: + ``` + %scattered = thlo.scatter + ins(%indices : tensor<2x2xindex>, %input : tensor<2x1x3xf32>) + outs(%init : tensor<3x3xf32>) + (%arg3: f32, %arg4: f32) { + %0 = arith.addf %arg3, %arg4 : f32 + thlo.yield %0 : f32 + } + ``` + + See https://www.tensorflow.org/xla/operation_semantics#scatter. }]; + let arguments = (ins // Input args TensorOrMemrefOf<[Index]>:$indices, @@ -174,8 +245,20 @@ def THLO_ScatterOp : THLO_DstStyleOp<"scatter", [ } def THLO_SortOp : THLO_DstStyleOp<"sort", [ - DeclareOpInterfaceMethods, SameVariadicOperandSize, - SingleBlockImplicitTerminator<"YieldOp">]> { + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + SameVariadicOperandSize, + SingleBlockImplicitTerminator<"YieldOp"> + ]> { let summary = "Destination-style twin for the `mhlo.sort`"; let description = [{ Sorts the given `operands` along the given `dimension` using the given @@ -184,14 +267,16 @@ def THLO_SortOp : THLO_DstStyleOp<"sort", [ Example: ``` %sorted1, %sorted2 = thlo.sort - ins(%input1: tensor, %input2: tensor) - outs(%init1: tensor, %init2: tensor) - { dimension = 0 : i64, is_stable = true } - (%e11: f32, %e12: f32, %e21: i32, %e22: i32) { - %gt = arith.cmpf ogt, %e11, %e12: f32 - thlo.yield %gt : i1 - } + ins(%input1: tensor, %input2: tensor) + outs(%init1: tensor, %init2: tensor) + dimension = 0 + is_stable = true + (%lhs0: f32, %rhs0: f32, %lhs1: i32, %rhs1: i32) { + %0 = arith.cmpf ogt, %lhs0, %rhs0 : f32 + thlo.yield %0 : i1 + } ``` + See https://www.tensorflow.org/xla/operation_semantics#sort. }]; @@ -201,7 +286,7 @@ def THLO_SortOp : THLO_DstStyleOp<"sort", [ // Output args Variadic:$inits, - I64Attr:$dimension, + IndexAttr:$dimension, BoolAttr:$is_stable ); @@ -217,4 +302,48 @@ def THLO_SortOp : THLO_DstStyleOp<"sort", [ }]; } +def THLO_ReverseOp : THLO_DstStyleOp<"reverse", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods,]> { + let summary = "Destination-style twin for the `mhlo.reverse`"; + let description = [{ + Reverses the specified dimensions of `input` according to the given + `dimensions`. + + See https://www.tensorflow.org/xla/operation_semantics#rev_reverse. + }]; + + let arguments = (ins + TensorOrMemref:$input, + TensorOrMemref:$init, + DenseI64ArrayAttr:$reverse_dimensions + ); + + let results = (outs TensorOrMemref:$result); + + let hasFolder = 1; + + let extraClassDeclaration = [{ + // Implement method necessary for DestinationStyleOpInterface. 
+ std::pair getDpsInitsPositionRange() { + int64_t getNumOperands = this->getNumOperands(); + return {getNumOperands - 1, getNumOperands}; + } + }]; +} + +def THLO_YieldOp : THLO_Op<"yield", [Pure, ReturnLike, Terminator, + ParentOneOf<["ScatterOp", "SortOp"]>]>, + Arguments<(ins Variadic:$values)> { + let summary = "Yield operation for tHLO ops with regions."; + let assemblyFormat = "attr-dict $values `:` type($values)"; + let hasVerifier = 1; +} + #endif // THLO_OPS diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/thlo/IR/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/thlo/interfaces/CMakeLists.txt similarity index 77% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/thlo/IR/CMakeLists.txt rename to tensorflow/compiler/xla/mlir_hlo/thlo/interfaces/CMakeLists.txt index 907ab9407ef..6ee9828d2a5 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/thlo/IR/CMakeLists.txt +++ b/tensorflow/compiler/xla/mlir_hlo/thlo/interfaces/CMakeLists.txt @@ -11,24 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# include_directories(BEFORE ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR}) -add_mlir_dialect_library(THLODialect - thlo_ops.cc - - DEPENDS - MLIRthlo_opsIncGen - MLIRGmlStTilingInterfaceIncGen +add_mlir_library(ThloBufferizableOpInterface + bufferizable_op_interface_impl.cc LINK_LIBS PUBLIC - GmlStDialect + THLODialect + MLIRBufferizationDialect MLIRDestinationStyleOpInterface - MLIRIR - MLIRMemRefDialect - MLIRSideEffectInterfaces - MLIRSupport - MLIRTensorDialect ) diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/thlo/transforms/bufferizable_op_interface_impl.cc b/tensorflow/compiler/xla/mlir_hlo/thlo/interfaces/bufferizable_op_interface_impl.cc similarity index 92% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/thlo/transforms/bufferizable_op_interface_impl.cc rename to tensorflow/compiler/xla/mlir_hlo/thlo/interfaces/bufferizable_op_interface_impl.cc index 9f63e594236..1996644b618 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/thlo/transforms/bufferizable_op_interface_impl.cc +++ b/tensorflow/compiler/xla/mlir_hlo/thlo/interfaces/bufferizable_op_interface_impl.cc @@ -13,16 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mlir-hlo/Dialect/thlo/transforms/bufferizable_op_interface_impl.h" +#include "thlo/interfaces/bufferizable_op_interface_impl.h" -#include "mlir-hlo/Dialect/thlo/IR/thlo_ops.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h" +#include "thlo/IR/thlo_ops.h" namespace mlir { namespace thlo { namespace { +using mlir::bufferization::AliasingOpOperandList; +using mlir::bufferization::AliasingOpResultList; using mlir::bufferization::AnalysisState; using mlir::bufferization::BufferizableOpInterface; using mlir::bufferization::BufferizationOptions; @@ -79,8 +81,8 @@ static LogicalResult bufferizeDestinationStyleOpInterface( // Clone the op, but use the new operands. Move the existing block into the // new op. Since the new op does not have any tensor results, it does not // return anything. 
- auto newOp = cast(op.cloneWithoutRegions( - rewriter, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands)); + auto newOp = cast(cloneWithoutRegions( + rewriter, op, /*resultTypes=*/TypeRange{}, newOperands)); assert(op->getNumRegions() <= 1); if (op->getNumRegions() == 1) { @@ -107,7 +109,7 @@ struct ThloSortOpBufferizationModel return cast(op).isDpsInit(&opOperand); } - SmallVector getAliasingOpOperand( + AliasingOpOperandList getAliasingOpOperands( Operation *op, OpResult opResult, const AnalysisState & /*state*/) const { auto dstStyleOp = cast(op); @@ -115,7 +117,7 @@ struct ThloSortOpBufferizationModel return {dstStyleOp.getDpsInitOperand(opResult.getResultNumber())}; } - SmallVector getAliasingOpResult( + AliasingOpResultList getAliasingOpResults( Operation *op, OpOperand &opOperand, const AnalysisState & /*state*/) const { auto dstStyleOp = cast(op); diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/thlo/transforms/bufferizable_op_interface_impl.h b/tensorflow/compiler/xla/mlir_hlo/thlo/interfaces/bufferizable_op_interface_impl.h similarity index 79% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/thlo/transforms/bufferizable_op_interface_impl.h rename to tensorflow/compiler/xla/mlir_hlo/thlo/interfaces/bufferizable_op_interface_impl.h index 9b033298725..ee35b031ac5 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/thlo/transforms/bufferizable_op_interface_impl.h +++ b/tensorflow/compiler/xla/mlir_hlo/thlo/interfaces/bufferizable_op_interface_impl.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MLIR_HLO_DIALECT_THLO_TRANSFORMS_BUFFERIZABLE_OP_INTERFACE_IMPL_H -#define MLIR_HLO_DIALECT_THLO_TRANSFORMS_BUFFERIZABLE_OP_INTERFACE_IMPL_H +#ifndef MLIR_HLO_THLO_INTERFACES_BUFFERIZABLE_OP_INTERFACE_IMPL_H +#define MLIR_HLO_THLO_INTERFACES_BUFFERIZABLE_OP_INTERFACE_IMPL_H namespace mlir { class DialectRegistry; @@ -26,4 +26,4 @@ void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry); } // namespace thlo } // namespace mlir -#endif // MLIR_HLO_DIALECT_THLO_TRANSFORMS_BUFFERIZABLE_OP_INTERFACE_IMPL_H +#endif // MLIR_HLO_THLO_INTERFACES_BUFFERIZABLE_OP_INTERFACE_IMPL_H diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/thlo/transforms/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/thlo/transforms/CMakeLists.txt similarity index 82% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/thlo/transforms/CMakeLists.txt rename to tensorflow/compiler/xla/mlir_hlo/thlo/transforms/CMakeLists.txt index 4a9bb5fe4f8..d582d86a724 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/thlo/transforms/CMakeLists.txt +++ b/tensorflow/compiler/xla/mlir_hlo/thlo/transforms/CMakeLists.txt @@ -12,21 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# + +set(LLVM_TARGET_DEFINITIONS thlo_passes.td) +mlir_tablegen(thlo_passes.h.inc -gen-pass-decls -name AllThlo) +add_public_tablegen_target(MLIRThloPassIncGen) + include_directories(BEFORE ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR}) -add_mlir_library(ThloBufferizableOpInterface - bufferizable_op_interface_impl.cc - - LINK_LIBS PUBLIC - THLODialect - MLIRBufferizationDialect - MLIRDestinationStyleOpInterface -) - add_mlir_library(ThloPasses - legalize_sort.cc + legalize_sort/legalize_sort.cc DEPENDS MLIRThloPassIncGen @@ -39,4 +35,4 @@ add_mlir_library(ThloPasses MLIRPass MLIRSCFDialect MLIRTransforms -) \ No newline at end of file +) diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/thlo/transforms/legalize_sort.cc b/tensorflow/compiler/xla/mlir_hlo/thlo/transforms/legalize_sort/legalize_sort.cc similarity index 69% rename from tensorflow/compiler/xla/mlir_hlo/lib/Dialect/thlo/transforms/legalize_sort.cc rename to tensorflow/compiler/xla/mlir_hlo/thlo/transforms/legalize_sort/legalize_sort.cc index 9c45a9d0577..43f9a8cc2b7 100644 --- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/thlo/transforms/legalize_sort.cc +++ b/tensorflow/compiler/xla/mlir_hlo/thlo/transforms/legalize_sort/legalize_sort.cc @@ -15,23 +15,25 @@ limitations under the License. #include #include +#include #include -#include "mlir-hlo/Dialect/thlo/IR/thlo_ops.h" -#include "mlir-hlo/Dialect/thlo/transforms/passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" +#include "thlo/IR/thlo_ops.h" +#include "thlo/transforms/passes.h" namespace mlir { namespace thlo { #define GEN_PASS_DEF_THLOLEGALIZESORTPASS -#include "mlir-hlo/Dialect/thlo/transforms/thlo_passes.h.inc" +#include "thlo/transforms/thlo_passes.h.inc" namespace { @@ -50,7 +52,7 @@ Value emitComparison(ImplicitLocOpBuilder& b, SmallVector& lhs, assert(block.getTerminator()->getOperands().size() == 1 && "Comparator must return a single value"); - BlockAndValueMapping mapping; + IRMapping mapping; for (auto [idx, arg] : llvm::enumerate(comparator.getArguments())) { Value value = idx % 2 == 0 ? 
lhs[idx / 2] : rhs[idx / 2]; mapping.map(arg, value); @@ -71,44 +73,42 @@ Value emitBinarySearch(ImplicitLocOpBuilder& b, Value leftInit, Value rightInit, ArithBuilder arith(b, b.getLoc()); // while ( - auto whileOp = - b.create(types, SmallVector{leftInit, rightInit}); - OpBuilder::InsertionGuard guard(b); - - // left < right) { - Block* before = b.createBlock(&whileOp.getBefore(), {}, types, - {whileOp.getLoc(), whileOp.getLoc()}); - { - Value left = before->getArgument(0), right = before->getArgument(1); - b.setInsertionPointToEnd(before); - b.create(arith.slt(left, right), before->getArguments()); - } - - Block* after = b.createBlock(&whileOp.getAfter(), {}, types, - {whileOp.getLoc(), whileOp.getLoc()}); - { - Value left = after->getArgument(0), right = after->getArgument(1); - b.setInsertionPointToEnd(after); - // int mid = (left + right) >> 1; - Value one = b.create(1); - Value mid = b.create(arith.add(left, right), one); - Value midPlusOne = b.create(mid, one); - - auto arraysAtMid = llvm::to_vector( - llvm::map_range(arrayMemrefs, [&](Value arrayMemref) -> Value { - return b.create(arrayMemref, mid); - })); - Value cond = emitComparison(b, pivots, arraysAtMid, comparator); - // if (comparator(pivot, array[mid])) - // right = mid; - // else - // left = mid + 1; - Value newLeft = arith.select(cond, left, midPlusOne); - Value newRight = arith.select(cond, mid, right); - - // } - b.create(ValueRange{newLeft, newRight}); - } + auto whileOp = b.create( + types, SmallVector{leftInit, rightInit}, + [&](OpBuilder& beforeBuilder, Location beforeLoc, ValueRange args) { + // left < right) { + Value left = args[0], right = args[1]; + beforeBuilder.create(beforeLoc, + arith.slt(left, right), args); + }, + [&](OpBuilder& afterBuilder, Location afterLoc, ValueRange args) { + ImplicitLocOpBuilder impLocAfterBuilder = + ImplicitLocOpBuilder(afterLoc, afterBuilder); + Value left = args[0], right = args[1]; + // int mid = (left + right) >> 1; + Value one = impLocAfterBuilder.create(1); + Value mid = impLocAfterBuilder.create( + arith.add(left, right), one); + Value midPlusOne = impLocAfterBuilder.create(mid, one); + + auto arraysAtMid = llvm::to_vector( + llvm::map_range(arrayMemrefs, [&](Value arrayMemref) -> Value { + return impLocAfterBuilder.create(arrayMemref, + mid); + })); + + Value cond = + emitComparison(impLocAfterBuilder, pivots, arraysAtMid, comparator); + // if (comparator(pivot, array[mid])) + // right = mid; + // else + // left = mid + 1; + Value newLeft = arith.select(cond, left, midPlusOne); + Value newRight = arith.select(cond, mid, right); + + // } + impLocAfterBuilder.create(ValueRange{newLeft, newRight}); + }); return whileOp.getResult(0); } @@ -195,51 +195,48 @@ void emitMerge(ImplicitLocOpBuilder& b, Value lo, Value mid, Value hi, SmallVector whileArgLocs(whileArgTypes.size(), b.getLoc()); // while( - auto whileOp = b.create(whileArgTypes, whileInitArgs); - { - OpBuilder::InsertionGuard guard(b); - { - Block* before = - b.createBlock(&whileOp.getBefore(), {}, whileArgTypes, whileArgLocs); - Value i0 = before->getArgument(1), i1 = before->getArgument(2); - b.setInsertionPointToEnd(before); - - // i0 < mid && i1 < hi) { - Value inbounds0 = arith.slt(i0, mid); - Value inbounds1 = arith.slt(i1, hi); - - b.create(arith._and(inbounds0, inbounds1), - before->getArguments()); - } - - { - Block* after = - b.createBlock(&whileOp.getAfter(), {}, whileArgTypes, whileArgLocs); - Value iOut = after->getArgument(0), i0 = after->getArgument(1), - i1 = after->getArgument(2); - 
b.setInsertionPointToEnd(after); - - // auto vals0 = readBufs[i0], vals1 = readBufs[i1]; - SmallVector vals0 = loadMemrefElements(b, readBufs, i0); - SmallVector vals1 = loadMemrefElements(b, readBufs, i1); - - // writeBufs[iOut] = comparator(vals1, vals0) - // ? readBufs[i1++] : readBufs[i0++]; - Value cmp = emitComparison(b, vals1, vals0, comparator); - SmallVector pickedVals; - for (auto [val0, val1] : llvm::zip(vals0, vals1)) { - pickedVals.push_back(b.create(cmp, val1, val0)); - } - storeMemrefElements(b, writeBufs, iOut, pickedVals); - - Value one = b.create(1); - Value nexti0 = b.create(cmp, i0, arith.add(i0, one)); - Value nexti1 = b.create(cmp, arith.add(i1, one), i1); - // ++iOut; - Value nextIOut = b.create(iOut, one); - b.create(ValueRange{nextIOut, nexti0, nexti1}); - } - } + auto whileOp = b.create( + whileArgTypes, whileInitArgs, + [&](OpBuilder& beforeBuilder, Location beforeLoc, ValueRange args) { + Value i0 = args[1], i1 = args[2]; + + // i0 < mid && i1 < hi) { + Value inbounds0 = arith.slt(i0, mid); + Value inbounds1 = arith.slt(i1, hi); + beforeBuilder.create( + beforeLoc, arith._and(inbounds0, inbounds1), args); + }, + [&](OpBuilder& afterBuilder, Location afterLoc, ValueRange args) { + ImplicitLocOpBuilder impLocAfterBuilder(afterLoc, afterBuilder); + Value iOut = args[0], i0 = args[1], i1 = args[2]; + + // auto vals0 = readBufs[i0], vals1 = readBufs[i1]; + SmallVector vals0 = + loadMemrefElements(impLocAfterBuilder, readBufs, i0); + SmallVector vals1 = + loadMemrefElements(impLocAfterBuilder, readBufs, i1); + + // writeBufs[iOut] = comparator(vals1, vals0) + // ? readBufs[i1++] : readBufs[i0++]; + Value cmp = + emitComparison(impLocAfterBuilder, vals1, vals0, comparator); + SmallVector pickedVals; + for (auto [val0, val1] : llvm::zip(vals0, vals1)) { + pickedVals.push_back( + impLocAfterBuilder.create(cmp, val1, val0)); + } + storeMemrefElements(impLocAfterBuilder, writeBufs, iOut, pickedVals); + Value one = impLocAfterBuilder.create(1); + Value nexti0 = + impLocAfterBuilder.create(cmp, i0, arith.add(i0, one)); + Value nexti1 = + impLocAfterBuilder.create(cmp, arith.add(i1, one), i1); + + // ++iOut; + Value nextIOut = impLocAfterBuilder.create(iOut, one); + impLocAfterBuilder.create( + ValueRange{nextIOut, nexti0, nexti1}); + }); // At this point, exactly one of the input ranges will have leftover elements. 
Value iOut = whileOp->getResult(0); @@ -286,12 +283,13 @@ Value emitBottomUpMergeSort(ImplicitLocOpBuilder& b, Value lo, Value hi, b.create(ValueRange{}); }; b.create(/*lowerBound=*/zero, /*upperBound=*/size, - /*step=*/insertionSortSize, /*iterArgs=*/llvm::None, + /*step=*/insertionSortSize, /*iterArgs=*/std::nullopt, forBody); } Value initParity = b.create(/*value=*/0, /*width=*/1); - if (staticSortDimSize >= 0 && staticSortDimSize < kInsertionSortSize) { + if (staticSortDimSize >= 0 && + staticSortDimSize < static_cast(kInsertionSortSize)) { return initParity; } @@ -314,54 +312,53 @@ Value emitBottomUpMergeSort(ImplicitLocOpBuilder& b, Value lo, Value hi, SmallVector whileArgLocs(whileArgTypes.size(), b.getLoc()); // while ( - auto whileOp = b.create(whileArgTypes, whileInitArgs); - OpBuilder::InsertionGuard guard(b); - - // currentSize < totalSize) - { - Block* before = - b.createBlock(&whileOp.getBefore(), {}, whileArgTypes, whileArgLocs); - Value currentSize = before->getArgument(0); - b.setInsertionPointToEnd(before); - b.create(arith.slt(currentSize, size), - before->getArguments()); - } - - size_t numArgs = inputMemrefs.size(); - // { - { - Block* after = - b.createBlock(&whileOp.getAfter(), {}, whileArgTypes, whileArgLocs); - - Value currentSize = after->getArgument(0); - Value parity = after->getArgument(1); - auto readBufs = after->getArguments().drop_front(2).take_front(numArgs); - auto writeBufs = after->getArguments().take_back(numArgs); - - Value twoCurrentSize = arith.add(currentSize, currentSize); - - // for (int start = 0; start < size; start += 2*currentSize) { - { - auto forOp = b.create(zero, size, twoCurrentSize); - OpBuilder::InsertionGuard guard(b); - b.setInsertionPointToStart(forOp.getBody()); - Value start = forOp.getInductionVar(); - - Value mid = b.create(size, arith.add(start, currentSize)); - Value end = b.create(size, arith.add(start, twoCurrentSize)); - emitMerge(b, start, mid, end, readBufs, writeBufs, comparator); - } - // } - - // parity = !parity; - Value one = b.create(1, 1); - Value notParity = arith.sub(one, parity); - // currentSize *= 2; - SmallVector nextWhileArgs{twoCurrentSize, notParity}; - llvm::copy(writeBufs, std::back_inserter(nextWhileArgs)); - llvm::copy(readBufs, std::back_inserter(nextWhileArgs)); - b.create(nextWhileArgs); - } + auto whileOp = b.create( + whileArgTypes, whileInitArgs, + [&](OpBuilder& beforeBuilder, Location beforeLoc, ValueRange args) { + // currentSize < totalSize) + Value currentSize = args[0]; + beforeBuilder.create( + beforeLoc, arith.slt(currentSize, size), args); + }, + [&](OpBuilder& afterBuilder, Location afterLoc, ValueRange args) { + ImplicitLocOpBuilder impLocAfterBuilder = + ImplicitLocOpBuilder(afterLoc, afterBuilder); + ArithBuilder localArithBuilder(impLocAfterBuilder, afterLoc); + size_t numArgs = inputMemrefs.size(); + + // { + Value currentSize = args[0], parity = args[1]; + auto readBufs = args.drop_front(2).take_front(numArgs); + auto writeBufs = args.take_back(numArgs); + + Value twoCurrentSize = arith.add(currentSize, currentSize); + + // for (int start = 0; start < size; start += 2*currentSize) { + { + auto forOp = + impLocAfterBuilder.create(zero, size, twoCurrentSize); + OpBuilder::InsertionGuard guard(impLocAfterBuilder); + impLocAfterBuilder.setInsertionPointToStart(forOp.getBody()); + Value start = forOp.getInductionVar(); + + Value mid = impLocAfterBuilder.create( + size, localArithBuilder.add(start, currentSize)); + Value end = impLocAfterBuilder.create( + size, 
localArithBuilder.add(start, twoCurrentSize)); + emitMerge(impLocAfterBuilder, start, mid, end, readBufs, writeBufs, + comparator); + } + // } + + // parity = !parity; + Value one = impLocAfterBuilder.create(1, 1); + Value notParity = arith.sub(one, parity); + // currentSize *= 2; + SmallVector nextWhileArgs{twoCurrentSize, notParity}; + llvm::copy(writeBufs, std::back_inserter(nextWhileArgs)); + llvm::copy(readBufs, std::back_inserter(nextWhileArgs)); + impLocAfterBuilder.create(nextWhileArgs); + }); // } // The result is the parity bit. @@ -369,27 +366,27 @@ Value emitBottomUpMergeSort(ImplicitLocOpBuilder& b, Value lo, Value hi, } struct Slicer { - Slicer(OpBuilder& b, uint64_t sortDim, Value sortDimSize, + Slicer(OpBuilder& b, int64_t sortDim, Value sortDimSize, ValueRange inductionVariables) : sizes(inductionVariables.size() + 1, b.getI64IntegerAttr(1)), strides(inductionVariables.size() + 1, b.getI64IntegerAttr(1)) { sizes[sortDim] = sortDimSize; for (size_t i = 0; i < inductionVariables.size() + 1; ++i) { - if (i == sortDim) { + if ((int64_t)i == sortDim) { offsets.push_back(b.getI64IntegerAttr(0)); } else { offsets.push_back( - inductionVariables[i - static_cast(i > sortDim)]); + inductionVariables[i - static_cast((int64_t)i > sortDim)]); } } } Value slice(ImplicitLocOpBuilder& b, Value input) { auto ty = input.getType().cast(); - auto slicedType = memref::SubViewOp::inferRankReducedResultType( - {ShapedType::kDynamicSize} /*1D output*/, ty, offsets, - sizes, strides) - .cast(); + auto slicedType = + memref::SubViewOp::inferRankReducedResultType( + {ShapedType::kDynamic} /*1D output*/, ty, offsets, sizes, strides) + .cast(); return b .create(slicedType, input, offsets, sizes, strides) .getResult(); @@ -407,11 +404,10 @@ SmallVector sliceMemrefs(ImplicitLocOpBuilder& b, if (inductionVariables.empty()) return memrefs; SmallVector slices; - Slicer slicer(b, op.getDimension(), sortDimSize, inductionVariables); + Slicer slicer(b, op.getDimension().getSExtValue(), sortDimSize, + inductionVariables); - for (Value out : memrefs) { - slices.push_back(slicer.slice(b, out)); - } + for (Value out : memrefs) slices.push_back(slicer.slice(b, out)); return slices; } @@ -437,7 +433,7 @@ struct SortOpPattern : public OpRewritePattern { auto firstInputType = firstInput.getType().cast(); int64_t inputRank = firstInputType.getRank(); - int64_t sortDim = op.getDimension(); + int64_t sortDim = op.getDimension().getSExtValue(); Value sortDimSize = b.createOrFold( firstInput, b.create(sortDim)); int64_t staticSortDimSize = firstInputType.getDimSize(sortDim); diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/thlo/transforms/passes.h b/tensorflow/compiler/xla/mlir_hlo/thlo/transforms/passes.h similarity index 79% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/thlo/transforms/passes.h rename to tensorflow/compiler/xla/mlir_hlo/thlo/transforms/passes.h index 67b99a75f8c..7ac8499f714 100644 --- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/thlo/transforms/passes.h +++ b/tensorflow/compiler/xla/mlir_hlo/thlo/transforms/passes.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef MLIR_HLO_DIALECT_THLO_TRANSFORMS_PASSES_H -#define MLIR_HLO_DIALECT_THLO_TRANSFORMS_PASSES_H +#ifndef MLIR_HLO_THLO_TRANSFORMS_PASSES_H +#define MLIR_HLO_THLO_TRANSFORMS_PASSES_H #include @@ -32,15 +32,15 @@ class FuncOp; namespace thlo { #define GEN_PASS_DECL_THLOLEGALIZESORTPASS -#include "mlir-hlo/Dialect/thlo/transforms/thlo_passes.h.inc" +#include "thlo/transforms/thlo_passes.h.inc" /// Lowers sort to Arith, MemRef, and SCF std::unique_ptr> createLegalizeSortPass(); #define GEN_PASS_REGISTRATION -#include "mlir-hlo/Dialect/thlo/transforms/thlo_passes.h.inc" +#include "thlo/transforms/thlo_passes.h.inc" } // namespace thlo } // namespace mlir -#endif // MLIR_HLO_DIALECT_THLO_TRANSFORMS_PASSES_H +#endif // MLIR_HLO_THLO_TRANSFORMS_PASSES_H diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/thlo/transforms/thlo_passes.td b/tensorflow/compiler/xla/mlir_hlo/thlo/transforms/thlo_passes.td similarity index 100% rename from tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/thlo/transforms/thlo_passes.td rename to tensorflow/compiler/xla/mlir_hlo/thlo/transforms/thlo_passes.td diff --git a/tensorflow/compiler/xla/mlir_hlo/tools/mlir-hlo-opt/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/tools/mlir-hlo-opt/CMakeLists.txt index 6f12b949feb..74ea46a5dc0 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tools/mlir-hlo-opt/CMakeLists.txt +++ b/tensorflow/compiler/xla/mlir_hlo/tools/mlir-hlo-opt/CMakeLists.txt @@ -22,9 +22,11 @@ set(LIBS AllGmlStPasses GmlStDialect + GmlStPasses LmhloDialect LmhloGPUDialect MhloRegisterDialects + MhloTestAnalysis LmhloPasses AllMhloPasses AllThloPasses diff --git a/tensorflow/compiler/xla/mlir_hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cc b/tensorflow/compiler/xla/mlir_hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cc index c66fbf7db45..d1ea67141a8 100644 --- a/tensorflow/compiler/xla/mlir_hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cc +++ b/tensorflow/compiler/xla/mlir_hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cc @@ -13,26 +13,26 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h" -#include "mlir-hlo/Dialect/gml_st/transforms/passes.h" -#include "mlir-hlo/Dialect/gml_st/transforms/test_passes.h" -#include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h" -#include "mlir-hlo/Dialect/lhlo/transforms/passes.h" -#include "mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h" -#include "mlir-hlo/Dialect/mhlo/IR/register.h" -#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "mlir-hlo/Dialect/thlo/IR/thlo_ops.h" -#include "mlir-hlo/Dialect/thlo/transforms/passes.h" -#include "mlir-hlo/Transforms/gpu_passes.h" -#include "mlir-hlo/Transforms/passes.h" +#include "gml_st/IR/gml_st_ops.h" +#include "gml_st/transforms/passes.h" +#include "gml_st/transforms/test_passes.h" +#include "lhlo/IR/lhlo_ops.h" +#include "lhlo/transforms/passes.h" +#include "lhlo_gpu/IR/lhlo_gpu_ops.h" +#include "mhlo/IR/register.h" +#include "mhlo/transforms/passes.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" #include "stablehlo/dialect/Register.h" +#include "thlo/IR/thlo_ops.h" +#include "thlo/transforms/passes.h" +#include "transforms/gpu_passes.h" +#include "transforms/passes.h" using namespace mlir; -int main(int argc, char **argv) { +int main(int argc, char** argv) { mlir::registerAllPasses(); mlir::hlo::registerLMHLOTransformsPasses(); mlir::registerLMHLOGPUTransformsPasses(); @@ -67,6 +67,22 @@ int main(int argc, char **argv) { opts.threadTileDim, opts.experimentalSoftmax); }); + mlir::PassPipelineRegistration + gmlStCpuTilingPipeline("gml-st-cpu-tiling-pipeline", + "Tiles, fuses, vectorizes tileable ops for CPU", + gml_st::addCPUTilingPipeline); + + struct HloToTritonPipelineOptions + : public PassPipelineOptions { + ListOption blockTileDim{ + *this, "block-tile", + llvm::cl::desc("dimensions of the subproblem processed by the block")}; + }; + mlir::PassPipelineRegistration( + "hlo-to-triton-pipeline", "Pipeline to transform HLO to Triton dialect.", + [](OpPassManager& pm, const HloToTritonPipelineOptions& opts) { + return createHloToTritonPipeline(pm, opts.blockTileDim); + }); mlir::DialectRegistry registry; mlir::registerAllDialects(registry); diff --git a/tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/affine.cc b/tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/affine.cc new file mode 100644 index 00000000000..f237e6e7aba --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/affine.cc @@ -0,0 +1,49 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "tools/mlir_interpreter/dialects/util.h" +#include "tools/mlir_interpreter/framework/interpreter.h" +#include "tools/mlir_interpreter/framework/interpreter_value_util.h" +#include "tools/mlir_interpreter/framework/registration.h" + +namespace mlir { +namespace interpreter { +namespace { + +llvm::SmallVector apply(InterpreterState&, AffineApplyOp op, + ArrayRef operands) { + return evalAffineMap(op.getAffineMap(), operands); +} + +int64_t min(InterpreterState&, AffineMinOp op, ArrayRef operands) { + auto results = evalAffineMap(op.getAffineMap(), operands); + return *std::min_element(results.begin(), results.end()); +} + +int64_t max(InterpreterState&, AffineMaxOp op, ArrayRef operands) { + auto results = evalAffineMap(op.getAffineMap(), operands); + return *std::max_element(results.begin(), results.end()); +} + +REGISTER_MLIR_INTERPRETER_OP(apply); +REGISTER_MLIR_INTERPRETER_OP(max); +REGISTER_MLIR_INTERPRETER_OP(min); + +} // namespace +} // namespace interpreter +} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/arith.cc b/tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/arith.cc new file mode 100644 index 00000000000..65bc9c4cdff --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/arith.cc @@ -0,0 +1,237 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/Dialect/Arith/IR/Arith.h" + +#include // NOLINT + +#include "llvm/Support/ErrorHandling.h" +#include "tools/mlir_interpreter/dialects/comparators.h" +#include "tools/mlir_interpreter/dialects/cwise_math.h" +#include "tools/mlir_interpreter/framework/interpreter.h" +#include "tools/mlir_interpreter/framework/interpreter_value.h" +#include "tools/mlir_interpreter/framework/interpreter_value_util.h" +#include "tools/mlir_interpreter/framework/registration.h" + +namespace mlir { +namespace interpreter { +namespace { + +InterpreterValue constant(InterpreterState&, arith::ConstantOp constant) { + auto ty = constant->getResultTypes()[0]; + auto shapedType = ty.dyn_cast(); + auto elemTy = shapedType ? 
shapedType.getElementType() : ty; + return dispatchScalarType(elemTy, [&](auto dummy) -> InterpreterValue { + using T = decltype(dummy); + if (shapedType) { + auto values = + constant.getValue().cast().getValues(); + auto result = TensorOrMemref::empty(shapedType.getShape()); + auto valueIt = values.begin(); + result.view.isVector = shapedType.isa(); + for (const auto& index : result.view.indices(true)) { + result.at(index) = *valueIt; + ++valueIt; + } + return {result}; + } + + auto value = constant.getValue(); + if (auto integer = value.dyn_cast()) { + return {static_cast(integer.getInt())}; + } + if (auto floatValue = value.dyn_cast()) { + return {static_cast(floatValue.getValueAsDouble())}; + } + + llvm_unreachable("unsupported constant type"); + }); +} + +template +InterpreterValue intCast(InterpreterState&, Op op, + const InterpreterValue& arg) { + if (arg.isTensor()) { + return dispatchScalarType( + op->getResultTypes()[0], [&](auto dummy) -> InterpreterValue { + auto result = + TensorOrMemref::empty(arg.view().sizes); + for (const auto& index : result.view.indices()) { + result.at(index) = + static_cast(arg.extractElement(index).asInt()); + } + return {result}; + }); + } + + return dispatchScalarType( + op->getResultTypes()[0], [&](auto dummy) -> InterpreterValue { + return {static_cast(arg.asInt())}; + }); +} + +llvm::SmallVector uiToFP( + MutableArrayRef args, mlir::Operation* op, + InterpreterState&) { + if (args[0].isTensor()) { + auto ty = op->getResultTypes()[0].cast(); + return {dispatchScalarType( + ty.getElementType(), [&](auto dummy) -> InterpreterValue { + auto result = TensorOrMemref::empty(ty.getShape()); + for (const auto& index : result.view.indices()) { + result.at(index) = static_cast( + args[0].extractElement(index).asUInt()); + } + return {result}; + })}; + } + + return {dispatchScalarType( + op->getResultTypes()[0], [&](auto dummy) -> InterpreterValue { + return {static_cast(args[0].asUInt())}; + })}; +} + +InterpreterValue cmpI(InterpreterState&, arith::CmpIOp compare, + const InterpreterValue& lhs, + const InterpreterValue& rhs) { + switch (compare.getPredicate()) { + case arith::CmpIPredicate::eq: + return applyCwiseBinaryMap(lhs, rhs); + case arith::CmpIPredicate::ne: + return applyCwiseBinaryMap(lhs, rhs); + case arith::CmpIPredicate::slt: + return applyCwiseBinaryMap(lhs, rhs); + case arith::CmpIPredicate::sle: + return applyCwiseBinaryMap(lhs, rhs); + case arith::CmpIPredicate::sgt: + return applyCwiseBinaryMap(lhs, rhs); + case arith::CmpIPredicate::sge: + return applyCwiseBinaryMap(lhs, rhs); + case arith::CmpIPredicate::ult: + return applyCwiseBinaryMap(lhs, rhs); + case arith::CmpIPredicate::ule: + return applyCwiseBinaryMap(lhs, rhs); + case arith::CmpIPredicate::ugt: + return applyCwiseBinaryMap(lhs, rhs); + case arith::CmpIPredicate::uge: + return applyCwiseBinaryMap(lhs, rhs); + } +} + +template +struct ConstFunctor : CwiseAll { + template + static bool apply(T, T) { + return value; + } +}; + +InterpreterValue cmpF(InterpreterState&, arith::CmpFOp compare, + const InterpreterValue& lhs, + const InterpreterValue& rhs) { + switch (compare.getPredicate()) { + case arith::CmpFPredicate::AlwaysFalse: + return applyCwiseBinaryMap>(lhs, rhs); + case arith::CmpFPredicate::OEQ: + return applyCwiseBinaryMap(lhs, rhs); + case arith::CmpFPredicate::OGT: + return applyCwiseBinaryMap(lhs, rhs); + case arith::CmpFPredicate::OGE: + return applyCwiseBinaryMap(lhs, rhs); + case arith::CmpFPredicate::OLT: + return applyCwiseBinaryMap(lhs, rhs); + case 
arith::CmpFPredicate::OLE: + return applyCwiseBinaryMap(lhs, rhs); + case arith::CmpFPredicate::ONE: + return applyCwiseBinaryMap(lhs, rhs); + case arith::CmpFPredicate::ORD: + return applyCwiseBinaryMap(lhs, rhs); + case arith::CmpFPredicate::UEQ: + return applyCwiseBinaryMap(lhs, rhs); + case arith::CmpFPredicate::UGT: + return applyCwiseBinaryMap(lhs, rhs); + case arith::CmpFPredicate::UGE: + return applyCwiseBinaryMap(lhs, rhs); + case arith::CmpFPredicate::ULT: + return applyCwiseBinaryMap(lhs, rhs); + case arith::CmpFPredicate::ULE: + return applyCwiseBinaryMap(lhs, rhs); + case arith::CmpFPredicate::UNE: + return applyCwiseBinaryMap(lhs, rhs); + case arith::CmpFPredicate::UNO: + return applyCwiseBinaryMap(lhs, rhs); + case arith::CmpFPredicate::AlwaysTrue: + return applyCwiseBinaryMap>(lhs, rhs); + } +} + +InterpreterValue select(InterpreterState&, arith::SelectOp, + const InterpreterValue& cond, + const InterpreterValue& trueValue, + const InterpreterValue& falseValue) { + return std::get(cond.storage) ? trueValue : falseValue; +} + +template +struct ExtFFunctor : CwiseFloat { + template + static R apply(A v) { + return v; + } +}; + +InterpreterValue extF(InterpreterState&, arith::ExtFOp op, + const InterpreterValue& in) { + return dispatchScalarType( + op->getResultTypes()[0], [&](auto dummy) -> InterpreterValue { + return applyCwiseMap>(in); + }); +} + +REGISTER_MLIR_INTERPRETER_OP("arith.addf", applyCwiseBinaryMap); +REGISTER_MLIR_INTERPRETER_OP("arith.andi", applyCwiseBinaryMap); +REGISTER_MLIR_INTERPRETER_OP("arith.divf", applyCwiseBinaryMap); +REGISTER_MLIR_INTERPRETER_OP("arith.extui", uiToFP); +REGISTER_MLIR_INTERPRETER_OP("arith.maxf", applyCwiseBinaryMap); +REGISTER_MLIR_INTERPRETER_OP("arith.minf", applyCwiseBinaryMap); +REGISTER_MLIR_INTERPRETER_OP("arith.mulf", applyCwiseBinaryMap); +REGISTER_MLIR_INTERPRETER_OP("arith.negf", applyCwiseMap); +REGISTER_MLIR_INTERPRETER_OP("arith.ori", applyCwiseBinaryMap); +REGISTER_MLIR_INTERPRETER_OP("arith.remf", applyCwiseBinaryMap); +REGISTER_MLIR_INTERPRETER_OP("arith.subf", applyCwiseBinaryMap); +REGISTER_MLIR_INTERPRETER_OP("arith.uitofp", uiToFP); +REGISTER_MLIR_INTERPRETER_OP("arith.xori", applyCwiseBinaryMap); + +// The float implementations support ints too. +REGISTER_MLIR_INTERPRETER_OP("arith.addi", "arith.addf"); +REGISTER_MLIR_INTERPRETER_OP("arith.divsi", "arith.divf"); +REGISTER_MLIR_INTERPRETER_OP("arith.maxsi", "arith.maxf"); +REGISTER_MLIR_INTERPRETER_OP("arith.minsi", "arith.minf"); +REGISTER_MLIR_INTERPRETER_OP("arith.muli", "arith.mulf"); +REGISTER_MLIR_INTERPRETER_OP("arith.subi", "arith.subf"); + +REGISTER_MLIR_INTERPRETER_OP(cmpF); +REGISTER_MLIR_INTERPRETER_OP(cmpI); +REGISTER_MLIR_INTERPRETER_OP(constant); +REGISTER_MLIR_INTERPRETER_OP(extF); +REGISTER_MLIR_INTERPRETER_OP(intCast); +REGISTER_MLIR_INTERPRETER_OP(intCast); +REGISTER_MLIR_INTERPRETER_OP(intCast); +REGISTER_MLIR_INTERPRETER_OP(select); + +} // namespace +} // namespace interpreter +} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/bufferization.cc b/tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/bufferization.cc new file mode 100644 index 00000000000..1b65950c4d7 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/bufferization.cc @@ -0,0 +1,69 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" + +#include // NOLINT +#include // NOLINT + +#include "tools/mlir_interpreter/dialects/util.h" +#include "tools/mlir_interpreter/framework/interpreter.h" +#include "tools/mlir_interpreter/framework/registration.h" + +namespace mlir { +namespace interpreter { +namespace { + +InterpreterValue toTensor(InterpreterState&, bufferization::ToTensorOp, + const InterpreterValue& in) { + return in.clone(); +} + +InterpreterValue toMemref(InterpreterState&, bufferization::ToMemrefOp, + const InterpreterValue& in) { + return in; +} + +InterpreterValue allocTensor( + InterpreterState&, bufferization::AllocTensorOp alloc, + ArrayRef dynamicSizes, std::optional copy, + const std::optional& /*sizeHint*/) { + auto ty = alloc->getResultTypes().front().cast(); + auto shape = replaceDynamicVals(ty.getShape(), dynamicSizes); + + if (copy) { + return copy->clone(); + } + return InterpreterValue::makeTensor(ty.getElementType(), shape); +} + +InterpreterValue clone(InterpreterState& state, bufferization::CloneOp, + const InterpreterValue& in) { + if (auto* stats = state.getOptions().stats) { + stats->heapSize += in.buffer()->getByteSize(); + stats->peakHeapSize = std::max(stats->peakHeapSize, stats->heapSize); + ++stats->numAllocations; + } + return in.clone(); +} + +REGISTER_MLIR_INTERPRETER_OP(allocTensor); +REGISTER_MLIR_INTERPRETER_OP(clone); +REGISTER_MLIR_INTERPRETER_OP(toMemref); +REGISTER_MLIR_INTERPRETER_OP(toTensor); + +} // namespace +} // namespace interpreter +} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/builtin.cc b/tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/builtin.cc new file mode 100644 index 00000000000..85303b4f702 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/builtin.cc @@ -0,0 +1,50 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tools/mlir_interpreter/framework/interpreter.h" +#include "tools/mlir_interpreter/framework/registration.h" + +namespace mlir { +namespace interpreter { +namespace { + +llvm::SmallVector unrealizedConversionCast( + MutableArrayRef args, mlir::Operation* op, + InterpreterState& state) { + auto resultTy = op->getResultTypes()[0]; + auto operandTy = op->getOperandTypes()[0]; + if (resultTy == operandTy) { + return {args[0]}; + } + + if (auto r = llvm::dyn_cast(resultTy)) { + if (auto o = llvm::dyn_cast(operandTy)) { + if (r.getElementType() == o.getElementType() && + r.getRank() == o.getRank()) { + return {args[0]}; + } + } + } + + llvm::errs() << "Unimplemented cast: " << *op << "\n"; + llvm_unreachable("unimplemented cast"); +} + +REGISTER_MLIR_INTERPRETER_OP("builtin.unrealized_conversion_cast", + unrealizedConversionCast); + +} // namespace +} // namespace interpreter +} // namespace mlir diff --git a/tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/comparators.h b/tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/comparators.h new file mode 100644 index 00000000000..00ca86c22fc --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/comparators.h @@ -0,0 +1,104 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MLIR_HLO_TOOLS_MLIR_INTERPRETER_DIALECTS_COMPARATORS_H_ +#define MLIR_HLO_TOOLS_MLIR_INTERPRETER_DIALECTS_COMPARATORS_H_ + +#include +#include + +#include "llvm/Support/ErrorHandling.h" +#include "tools/mlir_interpreter/framework/interpreter_value_util.h" + +namespace mlir { +namespace interpreter { + +// Despite the name, this works on integers and complex too. +template +struct FloatCompare : CwiseArith { + template + static bool apply(T a, T b) { + if (isnan(a) || isnan(b)) return nan_result; + if constexpr (v == 0) { + // For complex eq/ne. + return (a == b) == r; + } else if constexpr (std::is_floating_point_v || std::is_integral_v) { + auto cmp = a > b ? 1 : (a < b ? 
-1 : 0); + return (cmp == v) == r; + } else { + llvm_unreachable("operation not supported for this type"); + } + } + + template + static bool isnan(T a) { + return std::isnan(a); + } + template + static bool isnan(std::complex a) { + return std::isnan(std::real(a)) || std::isnan(std::imag(a)); + } +}; + +using Foeq = FloatCompare<0, true, false>; +using Foge = FloatCompare<-1, false, false>; +using Fogt = FloatCompare<1, true, false>; +using Fole = FloatCompare<1, false, false>; +using Folt = FloatCompare<-1, true, false>; +using Fone = FloatCompare<0, false, false>; +using Ford = FloatCompare<99, false, false>; +using Fueq = FloatCompare<0, true, true>; +using Fuge = FloatCompare<-1, false, true>; +using Fugt = FloatCompare<1, true, true>; +using Fule = FloatCompare<1, false, true>; +using Fult = FloatCompare<-1, true, true>; +using Fune = FloatCompare<0, false, true>; +using Funo = FloatCompare<99, true, true>; + +template +struct UnsignedCompare : CwiseInt { + template + static bool apply(T a, T b) { + using U = std::make_unsigned_t; + auto aU = static_cast(a); + auto bU = static_cast(b); + auto cmp = aU > bU ? 1 : (aU < bU ? -1 : 0); + return (cmp == v) == r; + } +}; + +using Iuge = UnsignedCompare<-1, false>; +using Iule = UnsignedCompare<1, false>; +using Iugt = UnsignedCompare<1, true>; +using Iult = UnsignedCompare<-1, true>; + +struct Iumax { + template + static T apply(T a, T b) { + return Iuge::apply(a, b) ? a : b; + } +}; + +struct Iumin { + template + static T apply(T a, T b) { + return Iule::apply(a, b) ? a : b; + } +}; + +} // namespace interpreter +} // namespace mlir + +#endif // MLIR_HLO_TOOLS_MLIR_INTERPRETER_DIALECTS_COMPARATORS_H_ diff --git a/tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/cwise_math.h b/tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/cwise_math.h new file mode 100644 index 00000000000..c539f33b410 --- /dev/null +++ b/tensorflow/compiler/xla/mlir_hlo/tools/mlir_interpreter/dialects/cwise_math.h @@ -0,0 +1,193 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef MLIR_HLO_TOOLS_MLIR_INTERPRETER_DIALECTS_CWISE_MATH_H_ +#define MLIR_HLO_TOOLS_MLIR_INTERPRETER_DIALECTS_CWISE_MATH_H_ + +#include +#include + +#include "tools/mlir_interpreter/framework/interpreter_value_util.h" + +namespace mlir { +namespace interpreter { + +struct ATan2 : CwiseReal { + template + static T apply(T a, T b) { + return std::atan2(a, b); + } +}; + +struct Complex : CwiseFloat { + template + static std::complex apply(T a, T b) { + return {a, b}; + } +}; + +struct Max : CwiseReal { + template + static T apply(T a, T b) { + return std::max(a, b); + } +}; + +struct Min : CwiseReal { + template + static T apply(T a, T b) { + return std::min(a, b); + } +}; + +struct Power : CwiseArith { + template + static T apply(T a, T b) { + if constexpr (std::is_integral_v) { + if constexpr (std::is_signed_v) { + if (b < 0) { + return a == 1 ? 1 : 0; + } + } + T result = 1; + while (b > 0) { + if (b & 1) result *= a; + b >>= 1; + if (b) { + a *= a; + } + } + return result; + } else { + return std::pow(a, b); + } + } +}; + +struct Remainder : CwiseReal { + template + static T apply(T a, T b) { + if constexpr (std::is_integral_v) { + return a % b; + } else { + return std::fmod(a, b); + } + } +}; + +struct ShiftRightArith : CwiseInt { + template + static T apply(T a, T b) { + return b >= sizeof(T) * CHAR_BIT ? 0 : (a >> b); + } +}; + +struct ShiftRightLogical : CwiseInt { + template + static T apply(T a, T b) { + return b >= sizeof(T) * CHAR_BIT + ? 0 + : static_cast>(a) >> b; + } +}; + +struct ShiftLeft : CwiseInt { + template + static T apply(T a, T b) { + return b >= sizeof(T) * CHAR_BIT ? 0 : (a << b); + } +}; + +namespace detail { +template